From 5bdf19f51786c5ac7e0888c0c5262830d244b0ee Mon Sep 17 00:00:00 2001 From: Onur Tirtir Date: Fri, 18 Aug 2023 15:00:02 +0300 Subject: [PATCH] Use CopyShardForeignConstraintCommandList in WorkerCreateShardCommandList What we do to collect foreign key constraint commands in WorkerCreateShardCommandList is quite similar to what we do in CopyShardForeignConstraintCommandList. Plus, the code that we used in WorkerCreateShardCommandList before was not able to properly handle foreign key constraints between Citus local tables when creating a reference table from the referencing one. With a few slight modifications made to CopyShardForeignConstraintCommandList, we can use the same logic in WorkerCreateShardCommandList too. --- .../distributed/operations/create_shards.c | 13 +-- .../distributed/operations/shard_transfer.c | 6 +- .../distributed/operations/stage_protocol.c | 84 +++---------------- .../distributed/coordinator_protocol.h | 8 +- 4 files changed, 23 insertions(+), 88 deletions(-) diff --git a/src/backend/distributed/operations/create_shards.c b/src/backend/distributed/operations/create_shards.c index 358927a09..abcb2ad8d 100644 --- a/src/backend/distributed/operations/create_shards.c +++ b/src/backend/distributed/operations/create_shards.c @@ -82,7 +82,6 @@ CreateShardsWithRoundRobinPolicy(Oid distributedTableId, int32 shardCount, int32 replicationFactor, bool useExclusiveConnections) { CitusTableCacheEntry *cacheEntry = GetCitusTableCacheEntry(distributedTableId); - bool colocatedShard = false; List *insertedShardPlacements = NIL; /* make sure table is hash partitioned */ @@ -201,7 +200,7 @@ CreateShardsWithRoundRobinPolicy(Oid distributedTableId, int32 shardCount, } CreateShardsOnWorkers(distributedTableId, insertedShardPlacements, - useExclusiveConnections, colocatedShard); + useExclusiveConnections); } @@ -213,7 +212,6 @@ void CreateColocatedShards(Oid targetRelationId, Oid sourceRelationId, bool useExclusiveConnections) { - bool colocatedShard = true;
List *insertedShardPlacements = NIL; List *insertedShardIds = NIL; @@ -306,7 +304,7 @@ CreateColocatedShards(Oid targetRelationId, Oid sourceRelationId, bool } CreateShardsOnWorkers(targetRelationId, insertedShardPlacements, - useExclusiveConnections, colocatedShard); + useExclusiveConnections); } @@ -322,7 +320,6 @@ CreateReferenceTableShard(Oid distributedTableId) text *shardMinValue = NULL; text *shardMaxValue = NULL; bool useExclusiveConnection = false; - bool colocatedShard = false; /* * In contrast to append/range partitioned tables it makes more sense to @@ -368,7 +365,7 @@ CreateReferenceTableShard(Oid distributedTableId) replicationFactor); CreateShardsOnWorkers(distributedTableId, insertedShardPlacements, - useExclusiveConnection, colocatedShard); + useExclusiveConnection); } @@ -431,10 +428,8 @@ CreateSingleShardTableShardWithRoundRobinPolicy(Oid relationId, uint32 colocatio * creating a single shard. */ bool useExclusiveConnection = false; - - bool colocatedShard = false; CreateShardsOnWorkers(relationId, insertedShardPlacements, - useExclusiveConnection, colocatedShard); + useExclusiveConnection); } diff --git a/src/backend/distributed/operations/shard_transfer.c b/src/backend/distributed/operations/shard_transfer.c index abaa00251..23925a315 100644 --- a/src/backend/distributed/operations/shard_transfer.c +++ b/src/backend/distributed/operations/shard_transfer.c @@ -1841,7 +1841,11 @@ CopyShardForeignConstraintCommandListGrouped(ShardInterval *shardInterval, char *referencedSchemaName = get_namespace_name(referencedSchemaId); char *escapedReferencedSchemaName = quote_literal_cstr(referencedSchemaName); - if (IsCitusTableType(referencedRelationId, REFERENCE_TABLE)) + if (relationId == referencedRelationId) + { + referencedShardId = shardInterval->shardId; + } + else if (IsCitusTableType(referencedRelationId, REFERENCE_TABLE)) { referencedShardId = GetFirstShardId(referencedRelationId); } diff --git a/src/backend/distributed/operations/stage_protocol.c 
b/src/backend/distributed/operations/stage_protocol.c index ddab3453b..715e2f2f2 100644 --- a/src/backend/distributed/operations/stage_protocol.c +++ b/src/backend/distributed/operations/stage_protocol.c @@ -312,8 +312,6 @@ CreateAppendDistributedShardPlacements(Oid relationId, int64 shardId, int attemptCount = replicationFactor; int workerNodeCount = list_length(workerNodeList); int placementsCreated = 0; - List *foreignConstraintCommandList = - GetReferencingForeignConstaintCommands(relationId); IncludeSequenceDefaults includeSequenceDefaults = NO_SEQUENCE_DEFAULTS; IncludeIdentities includeIdentityDefaults = NO_IDENTITY; @@ -346,7 +344,6 @@ CreateAppendDistributedShardPlacements(Oid relationId, int64 shardId, uint32 nodeGroupId = workerNode->groupId; char *nodeName = workerNode->workerName; uint32 nodePort = workerNode->workerPort; - int shardIndex = -1; /* not used in this code path */ const uint64 shardSize = 0; MultiConnection *connection = GetNodeUserDatabaseConnection(connectionFlag, nodeName, nodePort, @@ -360,9 +357,8 @@ CreateAppendDistributedShardPlacements(Oid relationId, int64 shardId, continue; } - List *commandList = WorkerCreateShardCommandList(relationId, shardIndex, shardId, - ddlCommandList, - foreignConstraintCommandList); + List *commandList = WorkerCreateShardCommandList(relationId, shardId, + ddlCommandList); ExecuteCriticalRemoteCommandList(connection, commandList); @@ -427,7 +423,7 @@ InsertShardPlacementRows(Oid relationId, int64 shardId, List *workerNodeList, */ void CreateShardsOnWorkers(Oid distributedRelationId, List *shardPlacements, - bool useExclusiveConnection, bool colocatedShard) + bool useExclusiveConnection) { IncludeSequenceDefaults includeSequenceDefaults = NO_SEQUENCE_DEFAULTS; IncludeIdentities includeIdentityDefaults = NO_IDENTITY; @@ -437,8 +433,6 @@ CreateShardsOnWorkers(Oid distributedRelationId, List *shardPlacements, includeSequenceDefaults, includeIdentityDefaults, creatingShellTableOnRemoteNode); - List 
*foreignConstraintCommandList = - GetReferencingForeignConstaintCommands(distributedRelationId); int taskId = 1; List *taskList = NIL; @@ -449,18 +443,10 @@ CreateShardsOnWorkers(Oid distributedRelationId, List *shardPlacements, { uint64 shardId = shardPlacement->shardId; ShardInterval *shardInterval = LoadShardInterval(shardId); - int shardIndex = -1; List *relationShardList = RelationShardListForShardCreate(shardInterval); - if (colocatedShard) - { - shardIndex = ShardIndex(shardInterval); - } - List *commandList = WorkerCreateShardCommandList(distributedRelationId, - shardIndex, - shardId, ddlCommandList, - foreignConstraintCommandList); + shardId, ddlCommandList); Task *task = CitusMakeNode(Task); task->jobId = INVALID_JOB_ID; @@ -604,14 +590,12 @@ RelationShardListForShardCreate(ShardInterval *shardInterval) * shardId to create the shard on the worker node. */ List * -WorkerCreateShardCommandList(Oid relationId, int shardIndex, uint64 shardId, - List *ddlCommandList, - List *foreignConstraintCommandList) +WorkerCreateShardCommandList(Oid relationId, uint64 shardId, + List *ddlCommandList) { List *commandList = NIL; Oid schemaId = get_rel_namespace(relationId); char *schemaName = get_namespace_name(schemaId); - char *escapedSchemaName = quote_literal_cstr(schemaName); TableDDLCommand *ddlCommand = NULL; foreach_ptr(ddlCommand, ddlCommandList) @@ -622,57 +606,12 @@ WorkerCreateShardCommandList(Oid relationId, int shardIndex, uint64 shardId, commandList = lappend(commandList, applyDDLCommand); } - const char *command = NULL; - foreach_ptr(command, foreignConstraintCommandList) - { - char *escapedCommand = quote_literal_cstr(command); + ShardInterval *shardInterval = LoadShardInterval(shardId); - uint64 referencedShardId = INVALID_SHARD_ID; - - StringInfo applyForeignConstraintCommand = makeStringInfo(); - - /* we need to parse the foreign constraint command to get referencing table id */ - Oid referencedRelationId = ForeignConstraintGetReferencedTableId(command); 
- if (referencedRelationId == InvalidOid) - { - ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE), - errmsg("cannot create foreign key constraint"), - errdetail("Referenced relation cannot be found."))); - } - - Oid referencedSchemaId = get_rel_namespace(referencedRelationId); - char *referencedSchemaName = get_namespace_name(referencedSchemaId); - char *escapedReferencedSchemaName = quote_literal_cstr(referencedSchemaName); - - /* - * In case of self referencing shards, relation itself might not be distributed - * already. Therefore we cannot use ColocatedShardIdInRelation which assumes - * given relation is distributed. Besides, since we know foreign key references - * itself, referencedShardId is actual shardId anyway. Also, if the referenced - * relation is a reference table, we cannot use ColocatedShardIdInRelation since - * reference tables only have one shard. Instead, we fetch the one and only shard - * from shardlist and use it. - */ - if (relationId == referencedRelationId) - { - referencedShardId = shardId; - } - else if (IsCitusTableType(referencedRelationId, REFERENCE_TABLE)) - { - referencedShardId = GetFirstShardId(referencedRelationId); - } - else - { - referencedShardId = ColocatedShardIdInRelation(referencedRelationId, - shardIndex); - } - - appendStringInfo(applyForeignConstraintCommand, - WORKER_APPLY_INTER_SHARD_DDL_COMMAND, shardId, escapedSchemaName, - referencedShardId, escapedReferencedSchemaName, escapedCommand); - - commandList = lappend(commandList, applyForeignConstraintCommand->data); - } + commandList = list_concat( + commandList, + CopyShardForeignConstraintCommandList(shardInterval) + ); /* * If the shard is created for a partition, send the command to create the @@ -680,7 +619,6 @@ WorkerCreateShardCommandList(Oid relationId, int shardIndex, uint64 shardId, */ if (PartitionTable(relationId)) { - ShardInterval *shardInterval = LoadShardInterval(shardId); char *attachPartitionCommand = 
GenerateAttachShardPartitionCommand(shardInterval); commandList = lappend(commandList, attachPartitionCommand); diff --git a/src/include/distributed/coordinator_protocol.h b/src/include/distributed/coordinator_protocol.h index e2f1e9c52..452232f3a 100644 --- a/src/include/distributed/coordinator_protocol.h +++ b/src/include/distributed/coordinator_protocol.h @@ -250,8 +250,7 @@ extern void CreateAppendDistributedShardPlacements(Oid relationId, int64 shardId List *workerNodeList, int replicationFactor); extern void CreateShardsOnWorkers(Oid distributedRelationId, List *shardPlacements, - bool useExclusiveConnection, - bool colocatedShard); + bool useExclusiveConnection); extern List * InsertShardPlacementRows(Oid relationId, int64 shardId, List *workerNodeList, int workerStartIndex, int replicationFactor); @@ -264,9 +263,8 @@ extern void CreateColocatedShards(Oid targetRelationId, Oid sourceRelationId, extern void CreateReferenceTableShard(Oid distributedTableId); extern void CreateSingleShardTableShardWithRoundRobinPolicy(Oid relationId, uint32 colocationId); -extern List * WorkerCreateShardCommandList(Oid relationId, int shardIndex, uint64 shardId, - List *ddlCommandList, - List *foreignConstraintCommandList); +extern List * WorkerCreateShardCommandList(Oid relationId, uint64 shardId, + List *ddlCommandList); extern Oid ForeignConstraintGetReferencedTableId(const char *queryString); extern void CheckHashPartitionedTable(Oid distributedTableId); extern void CheckTableSchemaNameForDrop(Oid relationId, char **schemaName,