diff --git a/src/backend/distributed/planner/multi_join_order.c b/src/backend/distributed/planner/multi_join_order.c index f9d25db1d..bb39d5678 100644 --- a/src/backend/distributed/planner/multi_join_order.c +++ b/src/backend/distributed/planner/multi_join_order.c @@ -212,7 +212,6 @@ ExtractLeftMostRangeTableIndex(Node *node, int *rangeTableIndex) static bool JoinOnColumns(Var *currentColumn, Var *candidateColumn, List *joinClauseList) { - ListCell *joinClauseCell = NULL; if (currentColumn == NULL || candidateColumn == NULL) { /* @@ -222,15 +221,16 @@ JoinOnColumns(Var *currentColumn, Var *candidateColumn, List *joinClauseList) return false; } - foreach(joinClauseCell, joinClauseList) + Node *joinClause = NULL; + foreach_ptr(joinClause, joinClauseList) { - OpExpr *joinClause = castNode(OpExpr, lfirst(joinClauseCell)); - Var *leftColumn = LeftColumnOrNULL(joinClause); - Var *rightColumn = RightColumnOrNULL(joinClause); - if (!OperatorImplementsEquality(joinClause->opno)) + if (!NodeIsEqualsOpExpr(joinClause)) { continue; } + OpExpr *joinClauseOpExpr = castNode(OpExpr, joinClause); + Var *leftColumn = LeftColumnOrNULL(joinClauseOpExpr); + Var *rightColumn = RightColumnOrNULL(joinClauseOpExpr); /* * Check if both join columns and both partition key columns match, since the @@ -253,6 +253,22 @@ JoinOnColumns(Var *currentColumn, Var *candidateColumn, List *joinClauseList) } +/* + * NodeIsEqualsOpExpr checks if the node is an OpExpr, where the operator + * matches OperatorImplementsEquality. + */ +bool +NodeIsEqualsOpExpr(Node *node) +{ + if (!IsA(node, OpExpr)) + { + return false; + } + OpExpr *opExpr = castNode(OpExpr, node); + return OperatorImplementsEquality(opExpr->opno); +} + + /* * JoinOrderList calculates the best join order and join rules that apply given * the list of tables and join clauses. First, the function generates a set of @@ -1010,21 +1026,21 @@ SinglePartitionJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable, OpExpr * SinglePartitionJoinClause(Var *partitionColumn, List *applicableJoinClauses) { - ListCell *applicableJoinClauseCell = NULL; if (partitionColumn == NULL) { return NULL; } - foreach(applicableJoinClauseCell, applicableJoinClauses) + Node *applicableJoinClause = NULL; + foreach_ptr(applicableJoinClause, applicableJoinClauses) { - OpExpr *applicableJoinClause = castNode(OpExpr, lfirst(applicableJoinClauseCell)); - if (!OperatorImplementsEquality(applicableJoinClause->opno)) + if (!NodeIsEqualsOpExpr(applicableJoinClause)) { continue; } - Var *leftColumn = LeftColumnOrNULL(applicableJoinClause); - Var *rightColumn = RightColumnOrNULL(applicableJoinClause); + OpExpr *applicableJoinOpExpr = castNode(OpExpr, applicableJoinClause); + Var *leftColumn = LeftColumnOrNULL(applicableJoinOpExpr); + Var *rightColumn = RightColumnOrNULL(applicableJoinOpExpr); if (leftColumn == NULL || rightColumn == NULL) { /* not a simple partition column join */ @@ -1042,7 +1058,7 @@ SinglePartitionJoinClause(Var *partitionColumn, List *applicableJoinClauses) { if (leftColumn->vartype == rightColumn->vartype) { - return applicableJoinClause; + return applicableJoinOpExpr; } else { @@ -1089,17 +1105,16 @@ DualPartitionJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable, OpExpr * DualPartitionJoinClause(List *applicableJoinClauses) { - ListCell *applicableJoinClauseCell = NULL; - - foreach(applicableJoinClauseCell, applicableJoinClauses) + Node *applicableJoinClause = NULL; + foreach_ptr(applicableJoinClause, applicableJoinClauses) { - OpExpr *applicableJoinClause = (OpExpr *) lfirst(applicableJoinClauseCell); - if (!OperatorImplementsEquality(applicableJoinClause->opno)) + if (!NodeIsEqualsOpExpr(applicableJoinClause)) { continue; } - Var *leftColumn = LeftColumnOrNULL(applicableJoinClause); - Var *rightColumn = RightColumnOrNULL(applicableJoinClause); + OpExpr *applicableJoinOpExpr = castNode(OpExpr, applicableJoinClause); + Var *leftColumn = LeftColumnOrNULL(applicableJoinOpExpr); + Var *rightColumn = RightColumnOrNULL(applicableJoinOpExpr); if (leftColumn == NULL || rightColumn == NULL) { @@ -1109,7 +1124,7 @@ DualPartitionJoinClause(List *applicableJoinClauses) /* we only need to check that the join column types match */ if (leftColumn->vartype == rightColumn->vartype) { - return applicableJoinClause; + return applicableJoinOpExpr; } else { @@ -1170,9 +1185,9 @@ MakeJoinOrderNode(TableEntry *tableEntry, JoinRuleType joinRuleType, * in either the list of tables on the left *or* in the right hand table. */ bool -IsApplicableJoinClause(List *leftTableIdList, uint32 rightTableId, OpExpr *joinClause) +IsApplicableJoinClause(List *leftTableIdList, uint32 rightTableId, Node *joinClause) { - List *varList = pull_var_clause_default((Node *) joinClause); + List *varList = pull_var_clause_default(joinClause); Var *var = NULL; bool joinContainsRightTable = false; foreach_ptr(var, varList) @@ -1208,15 +1223,14 @@ IsApplicableJoinClause(List *leftTableIdList, uint32 rightTableId, OpExpr *joinC List * ApplicableJoinClauses(List *leftTableIdList, uint32 rightTableId, List *joinClauseList) { - ListCell *joinClauseCell = NULL; List *applicableJoinClauses = NIL; /* make sure joinClauseList contains only join clauses */ joinClauseList = JoinClauseList(joinClauseList); - foreach(joinClauseCell, joinClauseList) + Node *joinClause = NULL; + foreach_ptr(joinClause, joinClauseList) { - OpExpr *joinClause = castNode(OpExpr, lfirst(joinClauseCell)); if (IsApplicableJoinClause(leftTableIdList, rightTableId, joinClause)) { applicableJoinClauses = lappend(applicableJoinClauses, joinClause); diff --git a/src/backend/distributed/planner/multi_logical_planner.c b/src/backend/distributed/planner/multi_logical_planner.c index 44832dd71..1372086c5 100644 --- a/src/backend/distributed/planner/multi_logical_planner.c +++ b/src/backend/distributed/planner/multi_logical_planner.c @@ -1429,11 +1429,6 @@ IsJoinClause(Node *clause) { Var *var = NULL; - if (!IsA(clause, OpExpr)) - { - return false; - } - /* * take all column references from the clause, if we find 2 column references from a * different relation we assume this is a join clause @@ -1689,7 +1684,7 @@ MultiSelectNode(List *whereClauseList) foreach(whereClauseCell, whereClauseList) { Node *whereClause = (Node *) lfirst(whereClauseCell); - if (IsSelectClause(whereClause) || or_clause(whereClause)) + if (IsSelectClause(whereClause)) { selectClauseList = lappend(selectClauseList, whereClause); } diff --git a/src/backend/distributed/planner/multi_physical_planner.c b/src/backend/distributed/planner/multi_physical_planner.c index 703dfacf3..5a9752245 100644 --- a/src/backend/distributed/planner/multi_physical_planner.c +++ b/src/backend/distributed/planner/multi_physical_planner.c @@ -40,6 +40,7 @@ #include "distributed/master_protocol.h" #include "distributed/metadata_cache.h" #include "distributed/multi_router_planner.h" +#include "distributed/multi_join_order.h" #include "distributed/multi_logical_optimizer.h" #include "distributed/multi_logical_planner.h" #include "distributed/multi_physical_planner.h" @@ -3497,7 +3498,6 @@ JoinSequenceArray(List *rangeTableFragmentsList, Query *jobQuery, List *dependen JoinExpr *joinExpr = (JoinExpr *) lfirst(joinExprCell); RangeTblRef *rightTableRef = (RangeTblRef *) joinExpr->rarg; uint32 nextRangeTableId = rightTableRef->rtindex; - ListCell *nextJoinClauseCell = NULL; Index existingRangeTableId = 0; bool applyJoinPruning = false; @@ -3518,17 +3518,23 @@ JoinSequenceArray(List *rangeTableFragmentsList, Query *jobQuery, List *dependen * We now determine if we can apply join pruning between existing range * tables and this new one. */ - foreach(nextJoinClauseCell, nextJoinClauseList) + Node *nextJoinClause = NULL; + foreach_ptr(nextJoinClause, nextJoinClauseList) { - OpExpr *nextJoinClause = (OpExpr *) lfirst(nextJoinClauseCell); - - if (!IsJoinClause((Node *) nextJoinClause)) + if (!NodeIsEqualsOpExpr(nextJoinClause)) { continue; } - Var *leftColumn = LeftColumnOrNULL(nextJoinClause); - Var *rightColumn = RightColumnOrNULL(nextJoinClause); + OpExpr *nextJoinClauseOpExpr = castNode(OpExpr, nextJoinClause); + + if (!IsJoinClause((Node *) nextJoinClauseOpExpr)) + { + continue; + } + + Var *leftColumn = LeftColumnOrNULL(nextJoinClauseOpExpr); + Var *rightColumn = RightColumnOrNULL(nextJoinClauseOpExpr); if (leftColumn == NULL || rightColumn == NULL) { continue; @@ -3567,7 +3573,7 @@ JoinSequenceArray(List *rangeTableFragmentsList, Query *jobQuery, List *dependen if (leftPartitioned && rightPartitioned) { /* make sure this join clause references only simple columns */ - CheckJoinBetweenColumns(nextJoinClause); + CheckJoinBetweenColumns(nextJoinClauseOpExpr); applyJoinPruning = true; break; diff --git a/src/include/distributed/multi_join_order.h b/src/include/distributed/multi_join_order.h index 230726ef0..f5d4a7f66 100644 --- a/src/include/distributed/multi_join_order.h +++ b/src/include/distributed/multi_join_order.h @@ -15,6 +15,8 @@ #ifndef MULTI_JOIN_ORDER_H #define MULTI_JOIN_ORDER_H +#include "postgres.h" + #include "nodes/pg_list.h" #include "nodes/primnodes.h" @@ -83,9 +85,10 @@ extern bool EnableSingleHashRepartitioning; extern List * JoinExprList(FromExpr *fromExpr); extern List * JoinOrderList(List *rangeTableEntryList, List *joinClauseList); extern bool IsApplicableJoinClause(List *leftTableIdList, uint32 rightTableId, - OpExpr *joinClause); + Node *joinClause); extern List * ApplicableJoinClauses(List *leftTableIdList, uint32 rightTableId, List *joinClauseList); +extern bool NodeIsEqualsOpExpr(Node *node); extern OpExpr * SinglePartitionJoinClause(Var *partitionColumn, List *applicableJoinClauses); extern OpExpr * DualPartitionJoinClause(List *applicableJoinClauses);