diff --git a/src/backend/distributed/commands/table.c b/src/backend/distributed/commands/table.c index 0c9e3d552..635909fd2 100644 --- a/src/backend/distributed/commands/table.c +++ b/src/backend/distributed/commands/table.c @@ -39,7 +39,6 @@ #include "utils/lsyscache.h" #include "utils/syscache.h" - /* Local functions forward declarations for unsupported command checks */ static void ErrorIfUnsupportedAlterTableStmt(AlterTableStmt *alterTableStatement); static List * InterShardDDLTaskList(Oid leftRelationId, Oid rightRelationId, diff --git a/src/backend/distributed/master/master_metadata_utility.c b/src/backend/distributed/master/master_metadata_utility.c index 196a33b97..6f1c6fdf9 100644 --- a/src/backend/distributed/master/master_metadata_utility.c +++ b/src/backend/distributed/master/master_metadata_utility.c @@ -503,6 +503,35 @@ LoadShardIntervalList(Oid relationId) } +/* + * GetOnlyShardOidOfReferenceTable returns OID of the one and only placement + * of the given reference table. Caller of this function must ensure that + * referenceTableOid is owned by a reference table. + */ +Oid +GetOnlyShardOidOfReferenceTable(Oid referenceTableOid) +{ + DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(referenceTableOid); + + /* assert that it is a "valid" "reference table" */ + Assert(cacheEntry != NULL && cacheEntry->shardIntervalArrayLength == 1); + + const ShardInterval *shardInterval = cacheEntry->sortedShardIntervalArray[0]; + uint64 referenceTableShardId = shardInterval->shardId; + + /* construct reference table's one & only placement's relation name */ + char *referenceTableShardName = get_rel_name(referenceTableOid); + AppendShardIdToName(&referenceTableShardName, referenceTableShardId); + + Oid referenceTableSchemaOid = get_rel_namespace(referenceTableOid); + + Oid referenceTableShardOid = get_relname_relid(referenceTableShardName, + referenceTableSchemaOid); + + return referenceTableShardOid; +} + + /* * ShardIntervalCount returns number of shard intervals for a given distributed table. * The function returns 0 if table is not distributed, or no shards can be found for diff --git a/src/backend/distributed/planner/distributed_planner.c b/src/backend/distributed/planner/distributed_planner.c index c552da1f2..76276be36 100644 --- a/src/backend/distributed/planner/distributed_planner.c +++ b/src/backend/distributed/planner/distributed_planner.c @@ -2503,14 +2503,7 @@ UpdateReferenceTablesWithShard(Node *node, void *context) return false; } - ShardInterval *shardInterval = cacheEntry->sortedShardIntervalArray[0]; - uint64 shardId = shardInterval->shardId; - - char *relationName = get_rel_name(relationId); - AppendShardIdToName(&relationName, shardId); - - Oid schemaId = get_rel_namespace(relationId); - newRte->relid = get_relname_relid(relationName, schemaId); + newRte->relid = GetOnlyShardOidOfReferenceTable(relationId); /* * Parser locks relations in addRangeTableEntry(). So we should lock the diff --git a/src/include/distributed/master_metadata_utility.h b/src/include/distributed/master_metadata_utility.h index ca0d60f7c..72db1df10 100644 --- a/src/include/distributed/master_metadata_utility.h +++ b/src/include/distributed/master_metadata_utility.h @@ -99,6 +99,7 @@ extern Datum citus_relation_size(PG_FUNCTION_ARGS); /* Function declarations to read shard and shard placement data */ extern uint32 TableShardReplicationFactor(Oid relationId); extern List * LoadShardIntervalList(Oid relationId); +extern Oid GetOnlyShardOidOfReferenceTable(Oid referenceTableOid); extern int ShardIntervalCount(Oid relationId); extern List * LoadShardList(Oid relationId); extern void CopyShardInterval(ShardInterval *srcInterval, ShardInterval *destInterval);