diff --git a/src/backend/distributed/operations/create_shards.c b/src/backend/distributed/operations/create_shards.c index 358927a09..abcb2ad8d 100644 --- a/src/backend/distributed/operations/create_shards.c +++ b/src/backend/distributed/operations/create_shards.c @@ -82,7 +82,6 @@ CreateShardsWithRoundRobinPolicy(Oid distributedTableId, int32 shardCount, int32 replicationFactor, bool useExclusiveConnections) { CitusTableCacheEntry *cacheEntry = GetCitusTableCacheEntry(distributedTableId); - bool colocatedShard = false; List *insertedShardPlacements = NIL; /* make sure table is hash partitioned */ @@ -201,7 +200,7 @@ CreateShardsWithRoundRobinPolicy(Oid distributedTableId, int32 shardCount, } CreateShardsOnWorkers(distributedTableId, insertedShardPlacements, - useExclusiveConnections, colocatedShard); + useExclusiveConnections); } @@ -213,7 +212,6 @@ void CreateColocatedShards(Oid targetRelationId, Oid sourceRelationId, bool useExclusiveConnections) { - bool colocatedShard = true; List *insertedShardPlacements = NIL; List *insertedShardIds = NIL; @@ -306,7 +304,7 @@ CreateColocatedShards(Oid targetRelationId, Oid sourceRelationId, bool } CreateShardsOnWorkers(targetRelationId, insertedShardPlacements, - useExclusiveConnections, colocatedShard); + useExclusiveConnections); } @@ -322,7 +320,6 @@ CreateReferenceTableShard(Oid distributedTableId) text *shardMinValue = NULL; text *shardMaxValue = NULL; bool useExclusiveConnection = false; - bool colocatedShard = false; /* * In contrast to append/range partitioned tables it makes more sense to @@ -368,7 +365,7 @@ CreateReferenceTableShard(Oid distributedTableId) replicationFactor); CreateShardsOnWorkers(distributedTableId, insertedShardPlacements, - useExclusiveConnection, colocatedShard); + useExclusiveConnection); } @@ -431,10 +428,8 @@ CreateSingleShardTableShardWithRoundRobinPolicy(Oid relationId, uint32 colocatio * creating a single shard. */ bool useExclusiveConnection = false; - - bool colocatedShard = false; CreateShardsOnWorkers(relationId, insertedShardPlacements, - useExclusiveConnection, colocatedShard); + useExclusiveConnection); } diff --git a/src/backend/distributed/operations/shard_transfer.c b/src/backend/distributed/operations/shard_transfer.c index abaa00251..23925a315 100644 --- a/src/backend/distributed/operations/shard_transfer.c +++ b/src/backend/distributed/operations/shard_transfer.c @@ -1841,7 +1841,11 @@ CopyShardForeignConstraintCommandListGrouped(ShardInterval *shardInterval, char *referencedSchemaName = get_namespace_name(referencedSchemaId); char *escapedReferencedSchemaName = quote_literal_cstr(referencedSchemaName); - if (IsCitusTableType(referencedRelationId, REFERENCE_TABLE)) + if (relationId == referencedRelationId) + { + referencedShardId = shardInterval->shardId; + } + else if (IsCitusTableType(referencedRelationId, REFERENCE_TABLE)) { referencedShardId = GetFirstShardId(referencedRelationId); } diff --git a/src/backend/distributed/operations/stage_protocol.c b/src/backend/distributed/operations/stage_protocol.c index ddab3453b..715e2f2f2 100644 --- a/src/backend/distributed/operations/stage_protocol.c +++ b/src/backend/distributed/operations/stage_protocol.c @@ -312,8 +312,6 @@ CreateAppendDistributedShardPlacements(Oid relationId, int64 shardId, int attemptCount = replicationFactor; int workerNodeCount = list_length(workerNodeList); int placementsCreated = 0; - List *foreignConstraintCommandList = - GetReferencingForeignConstaintCommands(relationId); IncludeSequenceDefaults includeSequenceDefaults = NO_SEQUENCE_DEFAULTS; IncludeIdentities includeIdentityDefaults = NO_IDENTITY; @@ -346,7 +344,6 @@ CreateAppendDistributedShardPlacements(Oid relationId, int64 shardId, uint32 nodeGroupId = workerNode->groupId; char *nodeName = workerNode->workerName; uint32 nodePort = workerNode->workerPort; - int shardIndex = -1; /* not used in this code path */ const uint64 shardSize = 0; MultiConnection *connection = GetNodeUserDatabaseConnection(connectionFlag, nodeName, nodePort, @@ -360,9 +357,8 @@ CreateAppendDistributedShardPlacements(Oid relationId, int64 shardId, continue; } - List *commandList = WorkerCreateShardCommandList(relationId, shardIndex, shardId, - ddlCommandList, - foreignConstraintCommandList); + List *commandList = WorkerCreateShardCommandList(relationId, shardId, + ddlCommandList); ExecuteCriticalRemoteCommandList(connection, commandList); @@ -427,7 +423,7 @@ InsertShardPlacementRows(Oid relationId, int64 shardId, List *workerNodeList, */ void CreateShardsOnWorkers(Oid distributedRelationId, List *shardPlacements, - bool useExclusiveConnection, bool colocatedShard) + bool useExclusiveConnection) { IncludeSequenceDefaults includeSequenceDefaults = NO_SEQUENCE_DEFAULTS; IncludeIdentities includeIdentityDefaults = NO_IDENTITY; @@ -437,8 +433,6 @@ CreateShardsOnWorkers(Oid distributedRelationId, List *shardPlacements, includeSequenceDefaults, includeIdentityDefaults, creatingShellTableOnRemoteNode); - List *foreignConstraintCommandList = - GetReferencingForeignConstaintCommands(distributedRelationId); int taskId = 1; List *taskList = NIL; @@ -449,18 +443,10 @@ CreateShardsOnWorkers(Oid distributedRelationId, List *shardPlacements, { uint64 shardId = shardPlacement->shardId; ShardInterval *shardInterval = LoadShardInterval(shardId); - int shardIndex = -1; List *relationShardList = RelationShardListForShardCreate(shardInterval); - if (colocatedShard) - { - shardIndex = ShardIndex(shardInterval); - } - List *commandList = WorkerCreateShardCommandList(distributedRelationId, - shardIndex, - shardId, ddlCommandList, - foreignConstraintCommandList); + shardId, ddlCommandList); Task *task = CitusMakeNode(Task); task->jobId = INVALID_JOB_ID; @@ -604,14 +590,12 @@ RelationShardListForShardCreate(ShardInterval *shardInterval) * shardId to create the shard on the worker node. */ List * -WorkerCreateShardCommandList(Oid relationId, int shardIndex, uint64 shardId, - List *ddlCommandList, - List *foreignConstraintCommandList) +WorkerCreateShardCommandList(Oid relationId, uint64 shardId, + List *ddlCommandList) { List *commandList = NIL; Oid schemaId = get_rel_namespace(relationId); char *schemaName = get_namespace_name(schemaId); - char *escapedSchemaName = quote_literal_cstr(schemaName); TableDDLCommand *ddlCommand = NULL; foreach_ptr(ddlCommand, ddlCommandList) @@ -622,57 +606,12 @@ WorkerCreateShardCommandList(Oid relationId, int shardIndex, uint64 shardId, commandList = lappend(commandList, applyDDLCommand); } - const char *command = NULL; - foreach_ptr(command, foreignConstraintCommandList) - { - char *escapedCommand = quote_literal_cstr(command); + ShardInterval *shardInterval = LoadShardInterval(shardId); - uint64 referencedShardId = INVALID_SHARD_ID; - - StringInfo applyForeignConstraintCommand = makeStringInfo(); - - /* we need to parse the foreign constraint command to get referencing table id */ - Oid referencedRelationId = ForeignConstraintGetReferencedTableId(command); - if (referencedRelationId == InvalidOid) - { - ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE), - errmsg("cannot create foreign key constraint"), - errdetail("Referenced relation cannot be found."))); - } - - Oid referencedSchemaId = get_rel_namespace(referencedRelationId); - char *referencedSchemaName = get_namespace_name(referencedSchemaId); - char *escapedReferencedSchemaName = quote_literal_cstr(referencedSchemaName); - - /* - * In case of self referencing shards, relation itself might not be distributed - * already. Therefore we cannot use ColocatedShardIdInRelation which assumes - * given relation is distributed. Besides, since we know foreign key references - * itself, referencedShardId is actual shardId anyway. Also, if the referenced - * relation is a reference table, we cannot use ColocatedShardIdInRelation since - * reference tables only have one shard. Instead, we fetch the one and only shard - * from shardlist and use it. - */ - if (relationId == referencedRelationId) - { - referencedShardId = shardId; - } - else if (IsCitusTableType(referencedRelationId, REFERENCE_TABLE)) - { - referencedShardId = GetFirstShardId(referencedRelationId); - } - else - { - referencedShardId = ColocatedShardIdInRelation(referencedRelationId, - shardIndex); - } - - appendStringInfo(applyForeignConstraintCommand, - WORKER_APPLY_INTER_SHARD_DDL_COMMAND, shardId, escapedSchemaName, - referencedShardId, escapedReferencedSchemaName, escapedCommand); - - commandList = lappend(commandList, applyForeignConstraintCommand->data); - } + commandList = list_concat( + commandList, + CopyShardForeignConstraintCommandList(shardInterval) + ); /* * If the shard is created for a partition, send the command to create the @@ -680,7 +619,6 @@ WorkerCreateShardCommandList(Oid relationId, int shardIndex, uint64 shardId, */ if (PartitionTable(relationId)) { - ShardInterval *shardInterval = LoadShardInterval(shardId); char *attachPartitionCommand = GenerateAttachShardPartitionCommand(shardInterval); commandList = lappend(commandList, attachPartitionCommand); diff --git a/src/include/distributed/coordinator_protocol.h b/src/include/distributed/coordinator_protocol.h index e2f1e9c52..452232f3a 100644 --- a/src/include/distributed/coordinator_protocol.h +++ b/src/include/distributed/coordinator_protocol.h @@ -250,8 +250,7 @@ extern void CreateAppendDistributedShardPlacements(Oid relationId, int64 shardId List *workerNodeList, int replicationFactor); extern void CreateShardsOnWorkers(Oid distributedRelationId, List *shardPlacements, - bool useExclusiveConnection, - bool colocatedShard); + bool useExclusiveConnection); extern List * InsertShardPlacementRows(Oid relationId, int64 shardId, List *workerNodeList, int workerStartIndex, int replicationFactor); @@ -264,9 +263,8 @@ extern void CreateColocatedShards(Oid targetRelationId, Oid sourceRelationId, extern void CreateReferenceTableShard(Oid distributedTableId); extern void CreateSingleShardTableShardWithRoundRobinPolicy(Oid relationId, uint32 colocationId); -extern List * WorkerCreateShardCommandList(Oid relationId, int shardIndex, uint64 shardId, - List *ddlCommandList, - List *foreignConstraintCommandList); +extern List * WorkerCreateShardCommandList(Oid relationId, uint64 shardId, + List *ddlCommandList); extern Oid ForeignConstraintGetReferencedTableId(const char *queryString); extern void CheckHashPartitionedTable(Oid distributedTableId); extern void CheckTableSchemaNameForDrop(Oid relationId, char **schemaName,