From 94ccc8b7b51e54144a853358dafb4d017c6748e9 Mon Sep 17 00:00:00 2001 From: Burak Velioglu Date: Mon, 24 Jan 2022 17:05:12 +0300 Subject: [PATCH] Rename sequence instead of drop --- src/backend/distributed/commands/sequence.c | 42 ++++++++++++++ .../worker/worker_create_or_replace.c | 57 +++++++++++++++++-- .../worker/worker_data_fetch_protocol.c | 23 +++++--- src/include/distributed/commands.h | 1 + .../distributed/worker_create_or_replace.h | 4 ++ 5 files changed, 114 insertions(+), 13 deletions(-) diff --git a/src/backend/distributed/commands/sequence.c b/src/backend/distributed/commands/sequence.c index f1c2339a2..3a1e949f3 100644 --- a/src/backend/distributed/commands/sequence.c +++ b/src/backend/distributed/commands/sequence.c @@ -656,3 +656,45 @@ PostprocessAlterSequenceOwnerStmt(Node *node, const char *queryString) return NIL; } + + +/* + * GenerateBackupNameForSequenceCollision generates a new sequence name for an existing + * sequence. The name is generated in such a way that the new name doesn't overlap with + * an existing relation by adding a suffix with incrementing number after the new name. + */ +char * +GenerateBackupNameForSequenceCollision(const ObjectAddress *address) +{ + char *newName = palloc0(NAMEDATALEN); + char suffix[NAMEDATALEN] = { 0 }; + int count = 0; + char *namespaceName = get_namespace_name(get_rel_namespace(address->objectId)); + Oid schemaId = get_namespace_oid(namespaceName, false); + + char *baseName = get_rel_name(address->objectId); + int baseLength = strlen(baseName); + + while (true) + { + int suffixLength = SafeSnprintf(suffix, NAMEDATALEN - 1, "(citus_backup_%d)", + count); + + /* trim the base name at the end to leave space for the suffix and trailing \0 */ + baseLength = Min(baseLength, NAMEDATALEN - suffixLength - 1); + + /* clear newName before copying the potentially trimmed baseName and suffix */ + memset(newName, 0, NAMEDATALEN); + strncpy_s(newName, NAMEDATALEN, baseName, baseLength); + strncpy_s(newName + baseLength, NAMEDATALEN - baseLength, suffix, + suffixLength); + + Oid typeOid = get_relname_relid(newName, schemaId); + if (typeOid == InvalidOid) + { + return newName; + } + + count++; + } +} diff --git a/src/backend/distributed/worker/worker_create_or_replace.c b/src/backend/distributed/worker/worker_create_or_replace.c index c067abc11..82883c468 100644 --- a/src/backend/distributed/worker/worker_create_or_replace.c +++ b/src/backend/distributed/worker/worker_create_or_replace.c @@ -33,8 +33,6 @@ #include "distributed/worker_protocol.h" static const char * CreateStmtByObjectAddress(const ObjectAddress *address); -static RenameStmt * CreateRenameStatement(const ObjectAddress *address, char *newName); -static char * GenerateBackupNameForCollision(const ObjectAddress *address); PG_FUNCTION_INFO_V1(worker_create_or_replace_object); @@ -166,7 +164,7 @@ CreateStmtByObjectAddress(const ObjectAddress *address) * address. This name should be used when renaming an existing object before creating the * new object locally on the worker. */ -static char * +char * GenerateBackupNameForCollision(const ObjectAddress *address) { switch (getObjectClass(address)) @@ -186,6 +184,15 @@ GenerateBackupNameForCollision(const ObjectAddress *address) return GenerateBackupNameForTypeCollision(address); } + case OCLASS_CLASS: + { + char relKind = get_rel_relkind(address->objectId); + if (relKind == RELKIND_SEQUENCE) + { + return GenerateBackupNameForSequenceCollision(address); + } + } + default: { ereport(ERROR, (errmsg("unsupported object to construct a rename statement"), @@ -243,6 +250,7 @@ CreateRenameTypeStmt(const ObjectAddress *address, char *newName) address->objectId)); stmt->newname = newName; + return stmt; } @@ -265,11 +273,43 @@ CreateRenameProcStmt(const ObjectAddress *address, char *newName) } +/* + * CreateRenameSequenceStmt creates a rename statement for a sequence based on its + * ObjectAddress. The rename statement will rename the existing object on its address + * to the value provided in newName. + */ +static RenameStmt * +CreateRenameSequenceStmt(const ObjectAddress *address, char *newName) +{ + RenameStmt *stmt = makeNode(RenameStmt); + Oid seqOid = address->objectId; + + HeapTuple seqClassTuple = SearchSysCache1(RELOID, seqOid); + if (!HeapTupleIsValid(seqClassTuple)) + { + ereport(ERROR, (errmsg("citus cache lookup error"))); + } + Form_pg_class seqClassForm = (Form_pg_class) GETSTRUCT(seqClassTuple); + + char *schemaName = get_namespace_name(seqClassForm->relnamespace); + char *seqName = NameStr(seqClassForm->relname); + List *name = list_make2(makeString(schemaName), makeString(seqName)); + ReleaseSysCache(seqClassTuple); + + stmt->renameType = OBJECT_SEQUENCE; + stmt->object = (Node *) name; + stmt->relation = makeRangeVar(schemaName, seqName, -1); + stmt->newname = newName; + + return stmt; +} + + /* * CreateRenameStatement creates a rename statement for an existing object to rename the * object to newName. */ -static RenameStmt * +RenameStmt * CreateRenameStatement(const ObjectAddress *address, char *newName) { switch (getObjectClass(address)) @@ -289,6 +329,15 @@ CreateRenameStatement(const ObjectAddress *address, char *newName) return CreateRenameTypeStmt(address, newName); } + case OCLASS_CLASS: + { + char relKind = get_rel_relkind(address->objectId); + if (relKind == RELKIND_SEQUENCE) + { + return CreateRenameSequenceStmt(address, newName); + } + } + default: { ereport(ERROR, (errmsg("unsupported object to construct a rename statement"), diff --git a/src/backend/distributed/worker/worker_data_fetch_protocol.c b/src/backend/distributed/worker/worker_data_fetch_protocol.c index a24af380a..ae8898f6b 100644 --- a/src/backend/distributed/worker/worker_data_fetch_protocol.c +++ b/src/backend/distributed/worker/worker_data_fetch_protocol.c @@ -32,6 +32,7 @@ #include "distributed/commands/utility_hook.h" #include "distributed/connection_management.h" #include "distributed/coordinator_protocol.h" +#include "distributed/deparser.h" #include "distributed/intermediate_results.h" #include "distributed/listutils.h" #include "distributed/metadata_cache.h" @@ -44,6 +45,7 @@ #include "distributed/remote_commands.h" #include "distributed/resource_lock.h" +#include "distributed/worker_create_or_replace.h" #include "distributed/worker_protocol.h" #include "distributed/version_compat.h" #include "executor/spi.h" @@ -480,13 +482,16 @@ worker_apply_sequence_command(PG_FUNCTION_ARGS) Form_pg_sequence pgSequenceForm = pg_get_sequencedef(sequenceOid); if (pgSequenceForm->seqtypid != sequenceTypeId) { - StringInfo dropSequenceString = makeStringInfo(); - char *qualifiedSequenceName = quote_qualified_identifier(sequenceSchema, - sequenceName); - appendStringInfoString(dropSequenceString, "DROP SEQUENCE "); - appendStringInfoString(dropSequenceString, qualifiedSequenceName); - appendStringInfoString(dropSequenceString, ";"); - ExecuteQueryViaSPI(dropSequenceString->data, SPI_OK_UTILITY); + ObjectAddress sequenceAddress = { 0 }; + ObjectAddressSet(sequenceAddress, RelationRelationId, sequenceOid); + + char *newName = GenerateBackupNameForCollision(&sequenceAddress); + + RenameStmt *renameStmt = CreateRenameStatement(&sequenceAddress, newName); + const char *sqlRenameStmt = DeparseTreeNode((Node *) renameStmt); + ProcessUtilityParseTree((Node *) renameStmt, sqlRenameStmt, + PROCESS_UTILITY_QUERY, + NULL, None_Receiver, NULL); } } @@ -495,8 +500,8 @@ worker_apply_sequence_command(PG_FUNCTION_ARGS) None_Receiver, NULL); CommandCounterIncrement(); - sequenceRelationId = RangeVarGetRelid(createSequenceStatement->sequence, - AccessShareLock, false); + Oid sequenceRelationId = RangeVarGetRelid(createSequenceStatement->sequence, + AccessShareLock, false); Assert(sequenceRelationId != InvalidOid); AlterSequenceMinMax(sequenceRelationId, sequenceSchema, sequenceName, sequenceTypeId); diff --git a/src/include/distributed/commands.h b/src/include/distributed/commands.h index 8d04ae4c4..f415a8866 100644 --- a/src/include/distributed/commands.h +++ b/src/include/distributed/commands.h @@ -397,6 +397,7 @@ extern ObjectAddress AlterSequenceOwnerStmtObjectAddress(Node *node, bool missin extern ObjectAddress RenameSequenceStmtObjectAddress(Node *node, bool missing_ok); extern void ErrorIfUnsupportedSeqStmt(CreateSeqStmt *createSeqStmt); extern void ErrorIfDistributedAlterSeqOwnedBy(AlterSeqStmt *alterSeqStmt); +extern char * GenerateBackupNameForSequenceCollision(const ObjectAddress *address); /* statistics.c - forward declarations */ extern List * PreprocessCreateStatisticsStmt(Node *node, const char *queryString, diff --git a/src/include/distributed/worker_create_or_replace.h b/src/include/distributed/worker_create_or_replace.h index 403efc5dc..60323d172 100644 --- a/src/include/distributed/worker_create_or_replace.h +++ b/src/include/distributed/worker_create_or_replace.h @@ -14,8 +14,12 @@ #ifndef WORKER_CREATE_OR_REPLACE_H #define WORKER_CREATE_OR_REPLACE_H +#include "catalog/objectaddress.h" + #define CREATE_OR_REPLACE_COMMAND "SELECT worker_create_or_replace_object(%s);" extern char * WrapCreateOrReplace(const char *sql); +extern char * GenerateBackupNameForCollision(const ObjectAddress *address); +extern RenameStmt * CreateRenameStatement(const ObjectAddress *address, char *newName); #endif /* WORKER_CREATE_OR_REPLACE_H */