diff --git a/src/backend/distributed/commands/local_multi_copy.c b/src/backend/distributed/commands/local_multi_copy.c index 60eaef3ce..f624d5e65 100644 --- a/src/backend/distributed/commands/local_multi_copy.c +++ b/src/backend/distributed/commands/local_multi_copy.c @@ -159,7 +159,7 @@ DoLocalCopy(StringInfo buffer, Oid relationId, int64 shardId, CopyStmt *copyStat */ LocalCopyBuffer = buffer; - Oid shardOid = GetShardLocalTableOid(relationId, shardId); + Oid shardOid = GetTableLocalShardOid(relationId, shardId); Relation shard = heap_open(shardOid, RowExclusiveLock); ParseState *pState = make_parsestate(NULL); diff --git a/src/backend/distributed/planner/deparse_shard_query.c b/src/backend/distributed/planner/deparse_shard_query.c index 8f271e042..a2e0628df 100644 --- a/src/backend/distributed/planner/deparse_shard_query.c +++ b/src/backend/distributed/planner/deparse_shard_query.c @@ -339,7 +339,7 @@ UpdateRelationsToLocalShardTables(Node *node, List *relationShardList) return true; } - Oid shardOid = GetShardLocalTableOid(relationShard->relationId, + Oid shardOid = GetTableLocalShardOid(relationShard->relationId, relationShard->shardId); newRte->relid = shardOid; diff --git a/src/backend/distributed/utils/shard_utils.c b/src/backend/distributed/utils/shard_utils.c index aa0a17921..1bfc8940a 100644 --- a/src/backend/distributed/utils/shard_utils.c +++ b/src/backend/distributed/utils/shard_utils.c @@ -11,21 +11,48 @@ #include "postgres.h" #include "utils/lsyscache.h" - +#include "distributed/metadata_cache.h" #include "distributed/relay_utility.h" #include "distributed/shard_utils.h" /* - * GetShardLocalTableOid returns the oid of the shard from the given distributed relation - * with the shardid. + * GetTableLocalShardOid returns the oid of the shard from the given distributed + * relation with the shardId. */ Oid -GetShardLocalTableOid(Oid distRelId, uint64 shardId) +GetTableLocalShardOid(Oid citusTableOid, uint64 shardId) { - char *relationName = get_rel_name(distRelId); - AppendShardIdToName(&relationName, shardId); + const char *citusTableName = get_rel_name(citusTableOid); - Oid schemaId = get_rel_namespace(distRelId); + Assert(citusTableName != NULL); - return get_relname_relid(relationName, schemaId); + /* construct shard relation name */ + char *shardRelationName = pstrdup(citusTableName); + AppendShardIdToName(&shardRelationName, shardId); + + Oid citusTableSchemaOid = get_rel_namespace(citusTableOid); + + Oid shardRelationOid = get_relname_relid(shardRelationName, citusTableSchemaOid); + + return shardRelationOid; +} + + +/* + * GetReferenceTableLocalShardOid returns OID of the local shard of given + * reference table. Caller of this function must ensure that referenceTableOid + * is owned by a reference table. + */ +Oid +GetReferenceTableLocalShardOid(Oid referenceTableOid) +{ + const CitusTableCacheEntry *cacheEntry = GetCitusTableCacheEntry(referenceTableOid); + + /* given OID should belong to a valid reference table */ + Assert(cacheEntry != NULL && cacheEntry->partitionMethod == DISTRIBUTE_BY_NONE); + + const ShardInterval *shardInterval = cacheEntry->sortedShardIntervalArray[0]; + uint64 referenceTableShardId = shardInterval->shardId; + + return GetTableLocalShardOid(referenceTableOid, referenceTableShardId); } diff --git a/src/include/distributed/shard_utils.h b/src/include/distributed/shard_utils.h index 28addce15..02f3b77ca 100644 --- a/src/include/distributed/shard_utils.h +++ b/src/include/distributed/shard_utils.h @@ -13,6 +13,7 @@ #include "postgres.h" -extern Oid GetShardLocalTableOid(Oid distRelId, uint64 shardId); +extern Oid GetTableLocalShardOid(Oid citusTableOid, uint64 shardId); +extern Oid GetReferenceTableLocalShardOid(Oid referenceTableOid); #endif /* SHARD_UTILS_H */