diff --git a/src/backend/distributed/commands/create_distributed_table.c b/src/backend/distributed/commands/create_distributed_table.c index e38395296..ab8011a6b 100644 --- a/src/backend/distributed/commands/create_distributed_table.c +++ b/src/backend/distributed/commands/create_distributed_table.c @@ -1185,7 +1185,7 @@ CreateCitusTable(Oid relationId, CitusTableType tableType, } else if (tableType == REFERENCE_TABLE) { - CreateReferenceTableShard(relationId); + CreateReferenceTableShard(relationId, colocatedTableId, colocationId); } if (ShouldSyncTableMetadata(relationId)) diff --git a/src/backend/distributed/operations/create_shards.c b/src/backend/distributed/operations/create_shards.c index 18978c671..40a1ca7c0 100644 --- a/src/backend/distributed/operations/create_shards.c +++ b/src/backend/distributed/operations/create_shards.c @@ -350,7 +350,8 @@ CreateColocatedShards(Oid targetRelationId, Oid sourceRelationId, bool * Also, the shard is replicated to the all active nodes in the cluster. */ void -CreateReferenceTableShard(Oid distributedTableId) +CreateReferenceTableShard(Oid distributedTableId, Oid colocatedTableId, + uint32 colocationId) { int workerStartIndex = 0; text *shardMinValue = NULL; @@ -382,24 +383,96 @@ CreateReferenceTableShard(Oid distributedTableId) tableName))); } - /* - * load and sort the worker node list for deterministic placements - * create_reference_table has already acquired pg_dist_node lock - */ - List *nodeList = ReferenceTablePlacementNodeList(ShareLock); - nodeList = SortList(nodeList, CompareWorkerNodes); + List *insertedShardPlacements = NIL; + if (!OidIsValid(colocatedTableId)) + { + /* create first reference table, place them on all active nodes */ - int replicationFactor = list_length(nodeList); + /* + * load and sort the worker node list for deterministic placements + * create_reference_table has already acquired pg_dist_node lock + */ + List *nodeList = ReferenceTablePlacementNodeList(ShareLock); + nodeList = SortList(nodeList, CompareWorkerNodes); - /* get the next shard id */ - uint64 shardId = GetNextShardId(); + int replicationFactor = list_length(nodeList); - InsertShardRow(distributedTableId, shardId, shardStorageType, shardMinValue, - shardMaxValue, NULL); + /* get the next shard id */ + uint64 shardId = GetNextShardId(); + uint64 shardGroupId = shardId; - List *insertedShardPlacements = InsertShardPlacementRows(distributedTableId, shardId, - nodeList, workerStartIndex, - replicationFactor); + StringInfoData shardgroupQuery = { 0 }; + initStringInfo(&shardgroupQuery); + + appendStringInfoString(&shardgroupQuery, + "WITH shardgroup_data(shardgroupid, colocationid, " + "shardminvalue, shardmaxvalue) AS (VALUES "); + + InsertShardGroupRow(shardGroupId, colocationId, shardMinValue, shardMaxValue); + appendStringInfo(&shardgroupQuery, "(%ld, %d, %s, %s)", + shardGroupId, + colocationId, + TextToSQLLiteral(shardMinValue), + TextToSQLLiteral(shardMaxValue)); + + appendStringInfoString(&shardgroupQuery, ") "); + appendStringInfoString(&shardgroupQuery, + "SELECT pg_catalog.citus_internal_add_shardgroup_metadata(" + "shardgroupid, colocationid, shardminvalue, shardmaxvalue)" + "FROM shardgroup_data;"); + + SendCommandToWorkersWithMetadata(shardgroupQuery.data); + + InsertShardRow(distributedTableId, shardId, shardStorageType, + shardMinValue, shardMaxValue, &shardGroupId); + + insertedShardPlacements = InsertShardPlacementRows(distributedTableId, shardId, + nodeList, workerStartIndex, + replicationFactor); + } + else + { + /* add reference table as colocated table to already existing reference table */ + + /* prevent placement changes of the source relation until we colocate with them */ + List *sourceShardIntervalList = LoadShardIntervalList(colocatedTableId); + LockShardListMetadata(sourceShardIntervalList, ShareLock); + + if (list_length(sourceShardIntervalList) != 1) + { + elog(ERROR, "colocating a reference table to a table with shardcount > 1"); + } + + ShardInterval *sourceInterval = + (ShardInterval *) linitial(sourceShardIntervalList); + uint64 newShardId = GetNextShardId(); + + InsertShardRow(distributedTableId, newShardId, shardStorageType, + shardMinValue, shardMaxValue, &sourceInterval->shardGroupId); + + List *sourceShardPlacementList = + ShardPlacementList(sourceInterval->shardId); + + ShardPlacement *sourceShardPlacement = NULL; + foreach_ptr(sourceShardPlacement, sourceShardPlacementList) + { + int32 groupId = sourceShardPlacement->groupId; + const uint64 shardSize = 0; + + /* + * Optimistically add shard placement row the pg_dist_shard_placement, in case + * of any error it will be roll-backed. + */ + uint64 shardPlacementId = InsertShardPlacementRow(newShardId, + INVALID_PLACEMENT_ID, + shardSize, + groupId); + + ShardPlacement *shardPlacement = + LoadShardPlacement(newShardId, shardPlacementId); + insertedShardPlacements = lappend(insertedShardPlacements, shardPlacement); + } + } CreateShardsOnWorkers(distributedTableId, insertedShardPlacements, useExclusiveConnection, colocatedShard); diff --git a/src/include/distributed/coordinator_protocol.h b/src/include/distributed/coordinator_protocol.h index 1444bff91..5aff81487 100644 --- a/src/include/distributed/coordinator_protocol.h +++ b/src/include/distributed/coordinator_protocol.h @@ -262,7 +262,8 @@ extern void CreateShardsWithRoundRobinPolicy(Oid distributedTableId, int32 shard bool useExclusiveConnections); extern void CreateColocatedShards(Oid targetRelationId, Oid sourceRelationId, bool useExclusiveConnections); -extern void CreateReferenceTableShard(Oid distributedTableId); +extern void CreateReferenceTableShard(Oid distributedTableId, Oid colocatedTableId, + uint32 colocationId); extern List * WorkerCreateShardCommandList(Oid relationId, int shardIndex, uint64 shardId, List *ddlCommandList, List *foreignConstraintCommandList);