diff --git a/src/backend/distributed/commands/local_multi_copy.c b/src/backend/distributed/commands/local_multi_copy.c index 60eaef3ce..f624d5e65 100644 --- a/src/backend/distributed/commands/local_multi_copy.c +++ b/src/backend/distributed/commands/local_multi_copy.c @@ -159,7 +159,7 @@ DoLocalCopy(StringInfo buffer, Oid relationId, int64 shardId, CopyStmt *copyStat */ LocalCopyBuffer = buffer; - Oid shardOid = GetShardLocalTableOid(relationId, shardId); + Oid shardOid = GetTableLocalShardOid(relationId, shardId); Relation shard = heap_open(shardOid, RowExclusiveLock); ParseState *pState = make_parsestate(NULL); diff --git a/src/backend/distributed/planner/deparse_shard_query.c b/src/backend/distributed/planner/deparse_shard_query.c index 8f271e042..a2e0628df 100644 --- a/src/backend/distributed/planner/deparse_shard_query.c +++ b/src/backend/distributed/planner/deparse_shard_query.c @@ -339,7 +339,7 @@ UpdateRelationsToLocalShardTables(Node *node, List *relationShardList) return true; } - Oid shardOid = GetShardLocalTableOid(relationShard->relationId, + Oid shardOid = GetTableLocalShardOid(relationShard->relationId, relationShard->shardId); newRte->relid = shardOid; diff --git a/src/backend/distributed/planner/distributed_planner.c b/src/backend/distributed/planner/distributed_planner.c index 07b7648df..321fdacb1 100644 --- a/src/backend/distributed/planner/distributed_planner.c +++ b/src/backend/distributed/planner/distributed_planner.c @@ -42,6 +42,7 @@ #include "distributed/query_utils.h" #include "distributed/recursive_planning.h" #include "distributed/shardinterval_utils.h" +#include "distributed/shard_utils.h" #include "distributed/version_compat.h" #include "distributed/worker_shard_visibility.h" #include "executor/executor.h" @@ -127,30 +128,22 @@ static void ResetPlannerRestrictionContext( static bool HasUnresolvedExternParamsWalker(Node *expression, ParamListInfo boundParams); static bool IsLocalReferenceTableJoin(Query *parse, List *rangeTableList); static bool QueryIsNotSimpleSelect(Node *node); -static bool UpdateReferenceTablesWithShard(Node *node, void *context); +static void UpdateReferenceTablesWithShard(List *rangeTableList); static PlannedStmt * PlanFastPathDistributedStmt(DistributedPlanningContext *planContext, Node *distributionKeyValue); static PlannedStmt * PlanDistributedStmt(DistributedPlanningContext *planContext, - List *rangeTableList, int rteIdCounter); + int rteIdCounter); /* Distributed planner hook */ PlannedStmt * distributed_planner(Query *parse, int cursorOptions, ParamListInfo boundParams) { - PlannedStmt *result = NULL; bool needsDistributedPlanning = false; - bool setPartitionedTablesInherited = false; - List *rangeTableList = ExtractRangeTableEntryList(parse); - int rteIdCounter = 1; bool fastPathRouterQuery = false; Node *distributionKeyValue = NULL; - DistributedPlanningContext planContext = { - .query = parse, - .cursorOptions = cursorOptions, - .boundParams = boundParams, - }; + List *rangeTableList = ExtractRangeTableEntryList(parse); if (cursorOptions & CURSOR_OPT_FORCE_DISTRIBUTED) { @@ -168,8 +161,9 @@ distributed_planner(Query *parse, int cursorOptions, ParamListInfo boundParams) * reference table names with shard tables names in the query, so * we can use the standard_planner for planning it locally. */ + UpdateReferenceTablesWithShard(rangeTableList); + needsDistributedPlanning = false; - UpdateReferenceTablesWithShard((Node *) parse, NULL); } else { @@ -181,6 +175,14 @@ distributed_planner(Query *parse, int cursorOptions, ParamListInfo boundParams) } } + int rteIdCounter = 1; + + DistributedPlanningContext planContext = { + .query = parse, + .cursorOptions = cursorOptions, + .boundParams = boundParams, + }; + if (fastPathRouterQuery) { /* @@ -217,7 +219,7 @@ distributed_planner(Query *parse, int cursorOptions, ParamListInfo boundParams) rteIdCounter = AssignRTEIdentities(rangeTableList, rteIdCounter); planContext.originalQuery = copyObject(parse); - setPartitionedTablesInherited = false; + bool setPartitionedTablesInherited = false; AdjustPartitioningForDistributedPlanning(rangeTableList, setPartitionedTablesInherited); } @@ -239,6 +241,7 @@ distributed_planner(Query *parse, int cursorOptions, ParamListInfo boundParams) */ PlannerLevel++; + PlannedStmt *result = NULL; PG_TRY(); { @@ -258,7 +261,7 @@ distributed_planner(Query *parse, int cursorOptions, ParamListInfo boundParams) planContext.boundParams); if (needsDistributedPlanning) { - result = PlanDistributedStmt(&planContext, rangeTableList, rteIdCounter); + result = PlanDistributedStmt(&planContext, rteIdCounter); } else if ((result = TryToDelegateFunctionCall(&planContext)) == NULL) { @@ -309,11 +312,42 @@ distributed_planner(Query *parse, int cursorOptions, ParamListInfo boundParams) List * ExtractRangeTableEntryList(Query *query) { - List *rangeTblList = NIL; + List *rteList = NIL; - ExtractRangeTableEntryWalker((Node *) query, &rangeTblList); + ExtractRangeTableEntryWalker((Node *) query, &rteList); - return rangeTblList; + return rteList; +} + + +/* + * ExtractClassifiedRangeTableEntryList extracts reference table rte's from + * the given rte list. + * Callers of this function are responsible for passing referenceTableRTEList + * to be non-null and initially pointing to an empty list. + */ +List * +ExtractReferenceTableRTEList(List *rteList) +{ + List *referenceTableRTEList = NIL; + + RangeTblEntry *rte = NULL; + foreach_ptr(rte, rteList) + { + if (rte->rtekind != RTE_RELATION || rte->relkind != RELKIND_RELATION) + { + continue; + } + + Oid relationOid = rte->relid; + if (IsCitusTable(relationOid) && PartitionMethod(relationOid) == + DISTRIBUTE_BY_NONE) + { + referenceTableRTEList = lappend(referenceTableRTEList, rte); + } + } + + return referenceTableRTEList; } @@ -328,21 +362,20 @@ ExtractRangeTableEntryList(Query *query) bool NeedsDistributedPlanning(Query *query) { - List *allRTEs = NIL; - CmdType commandType = query->commandType; - if (!CitusHasBeenLoaded()) { return false; } + CmdType commandType = query->commandType; + if (commandType != CMD_SELECT && commandType != CMD_INSERT && commandType != CMD_UPDATE && commandType != CMD_DELETE) { return false; } - ExtractRangeTableEntryWalker((Node *) query, &allRTEs); + List *allRTEs = ExtractRangeTableEntryList(query); return ListContainsDistributedTableRTE(allRTEs); } @@ -594,11 +627,10 @@ PlanFastPathDistributedStmt(DistributedPlanningContext *planContext, */ static PlannedStmt * PlanDistributedStmt(DistributedPlanningContext *planContext, - List *rangeTableList, int rteIdCounter) { /* may've inlined new relation rtes */ - rangeTableList = ExtractRangeTableEntryList(planContext->query); + List *rangeTableList = ExtractRangeTableEntryList(planContext->query); rteIdCounter = AssignRTEIdentities(rangeTableList, rteIdCounter); @@ -2464,62 +2496,25 @@ QueryIsNotSimpleSelect(Node *node) /* * UpdateReferenceTablesWithShard recursively replaces the reference table names - * in the given query with the shard table names. + * in the given range table list with the local shard table names. */ -static bool -UpdateReferenceTablesWithShard(Node *node, void *context) +static void +UpdateReferenceTablesWithShard(List *rangeTableList) { - if (node == NULL) + List *referenceTableRTEList = ExtractReferenceTableRTEList(rangeTableList); + + RangeTblEntry *rangeTableEntry = NULL; + foreach_ptr(rangeTableEntry, referenceTableRTEList) { - return false; + Oid referenceTableLocalShardOid = GetReferenceTableLocalShardOid( + rangeTableEntry->relid); + + rangeTableEntry->relid = referenceTableLocalShardOid; + + /* + * Parser locks relations in addRangeTableEntry(). So we should lock the + * modified ones too. + */ + LockRelationOid(referenceTableLocalShardOid, AccessShareLock); } - - /* want to look at all RTEs, even in subqueries, CTEs and such */ - if (IsA(node, Query)) - { - return query_tree_walker((Query *) node, UpdateReferenceTablesWithShard, - NULL, QTW_EXAMINE_RTES_BEFORE); - } - - if (!IsA(node, RangeTblEntry)) - { - return expression_tree_walker(node, UpdateReferenceTablesWithShard, - NULL); - } - - RangeTblEntry *newRte = (RangeTblEntry *) node; - - if (newRte->rtekind != RTE_RELATION) - { - return false; - } - - Oid relationId = newRte->relid; - if (!IsCitusTable(relationId)) - { - return false; - } - - CitusTableCacheEntry *cacheEntry = GetCitusTableCacheEntry(relationId); - if (cacheEntry->partitionMethod != DISTRIBUTE_BY_NONE) - { - return false; - } - - ShardInterval *shardInterval = cacheEntry->sortedShardIntervalArray[0]; - uint64 shardId = shardInterval->shardId; - - char *relationName = get_rel_name(relationId); - AppendShardIdToName(&relationName, shardId); - - Oid schemaId = get_rel_namespace(relationId); - newRte->relid = get_relname_relid(relationName, schemaId); - - /* - * Parser locks relations in addRangeTableEntry(). So we should lock the - * modified ones too. - */ - LockRelationOid(newRte->relid, AccessShareLock); - - return false; } diff --git a/src/backend/distributed/utils/shard_utils.c b/src/backend/distributed/utils/shard_utils.c index aa0a17921..1bfc8940a 100644 --- a/src/backend/distributed/utils/shard_utils.c +++ b/src/backend/distributed/utils/shard_utils.c @@ -11,21 +11,48 @@ #include "postgres.h" #include "utils/lsyscache.h" - +#include "distributed/metadata_cache.h" #include "distributed/relay_utility.h" #include "distributed/shard_utils.h" /* - * GetShardLocalTableOid returns the oid of the shard from the given distributed relation - * with the shardid. + * GetTableLocalShardOid returns the oid of the shard from the given distributed + * relation with the shardId. */ Oid -GetShardLocalTableOid(Oid distRelId, uint64 shardId) +GetTableLocalShardOid(Oid citusTableOid, uint64 shardId) { - char *relationName = get_rel_name(distRelId); - AppendShardIdToName(&relationName, shardId); + const char *citusTableName = get_rel_name(citusTableOid); - Oid schemaId = get_rel_namespace(distRelId); + Assert(citusTableName != NULL); - return get_relname_relid(relationName, schemaId); + /* construct shard relation name */ + char *shardRelationName = pstrdup(citusTableName); + AppendShardIdToName(&shardRelationName, shardId); + + Oid citusTableSchemaOid = get_rel_namespace(citusTableOid); + + Oid shardRelationOid = get_relname_relid(shardRelationName, citusTableSchemaOid); + + return shardRelationOid; +} + + +/* + * GetReferenceTableLocalShardOid returns OID of the local shard of given + * reference table. Caller of this function must ensure that referenceTableOid + * is owned by a reference table. + */ +Oid +GetReferenceTableLocalShardOid(Oid referenceTableOid) +{ + const CitusTableCacheEntry *cacheEntry = GetCitusTableCacheEntry(referenceTableOid); + + /* given OID should belong to a valid reference table */ + Assert(cacheEntry != NULL && cacheEntry->partitionMethod == DISTRIBUTE_BY_NONE); + + const ShardInterval *shardInterval = cacheEntry->sortedShardIntervalArray[0]; + uint64 referenceTableShardId = shardInterval->shardId; + + return GetTableLocalShardOid(referenceTableOid, referenceTableShardId); } diff --git a/src/include/distributed/distributed_planner.h b/src/include/distributed/distributed_planner.h index 4197c67bf..b0ecbc2ce 100644 --- a/src/include/distributed/distributed_planner.h +++ b/src/include/distributed/distributed_planner.h @@ -186,6 +186,7 @@ typedef struct CitusCustomScanPath extern PlannedStmt * distributed_planner(Query *parse, int cursorOptions, ParamListInfo boundParams); extern List * ExtractRangeTableEntryList(Query *query); +extern List * ExtractReferenceTableRTEList(List *rteList); extern bool NeedsDistributedPlanning(Query *query); extern struct DistributedPlan * GetDistributedPlan(CustomScan *node); extern void multi_relation_restriction_hook(PlannerInfo *root, RelOptInfo *relOptInfo, diff --git a/src/include/distributed/shard_utils.h b/src/include/distributed/shard_utils.h index 28addce15..02f3b77ca 100644 --- a/src/include/distributed/shard_utils.h +++ b/src/include/distributed/shard_utils.h @@ -13,6 +13,7 @@ #include "postgres.h" -extern Oid GetShardLocalTableOid(Oid distRelId, uint64 shardId); +extern Oid GetTableLocalShardOid(Oid citusTableOid, uint64 shardId); +extern Oid GetReferenceTableLocalShardOid(Oid referenceTableOid); #endif /* SHARD_UTILS_H */