diff --git a/src/backend/distributed/planner/distributed_planner.c b/src/backend/distributed/planner/distributed_planner.c index fd5876e1e..13457520c 100644 --- a/src/backend/distributed/planner/distributed_planner.c +++ b/src/backend/distributed/planner/distributed_planner.c @@ -1702,6 +1702,19 @@ multi_join_restriction_hook(PlannerInfo *root, JoinType jointype, JoinPathExtraData *extra) { + if (bms_is_empty(innerrel->relids) || bms_is_empty(outerrel->relids)) + { + /* + * We do not expect empty relids. Still, ignoring such JoinRestriction is + * preferable for two reasons: + * 1. This might be a query that doesn't rely on JoinRestrictions at all (e.g., + * local query). + * 2. We cannot process them when they are empty (and likely to segfault if + * we allow as-is). + */ + ereport(DEBUG1, (errmsg("Join restriction information is NULL"))); + } + /* * Use a memory context that's guaranteed to live long enough, could be * called in a more shortly lived one (e.g. with GEQO). @@ -1711,23 +1724,22 @@ multi_join_restriction_hook(PlannerInfo *root, MemoryContext restrictionsMemoryContext = plannerRestrictionContext->memoryContext; MemoryContext oldMemoryContext = MemoryContextSwitchTo(restrictionsMemoryContext); - /* - * We create a copy of restrictInfoList because it may be created in a memory - * context which will be deleted when we still need it, thus we create a copy - * of it in our memory context. - */ - List *restrictInfoList = copyObject(extra->restrictlist); - JoinRestrictionContext *joinRestrictionContext = plannerRestrictionContext->joinRestrictionContext; Assert(joinRestrictionContext != NULL); JoinRestriction *joinRestriction = palloc0(sizeof(JoinRestriction)); joinRestriction->joinType = jointype; - joinRestriction->joinRestrictInfoList = restrictInfoList; joinRestriction->plannerInfo = root; - joinRestriction->innerrel = innerrel; - joinRestriction->outerrel = outerrel; + + /* + * We create a copy of restrictInfoList and relids because with geqo they may + * be created in a memory context which will be deleted when we still need it, + * thus we create a copy of it in our memory context. + */ + joinRestriction->joinRestrictInfoList = copyObject(extra->restrictlist); + joinRestriction->innerrelRelids = bms_copy(innerrel->relids); + joinRestriction->outerrelRelids = bms_copy(outerrel->relids); joinRestrictionContext->joinRestrictionList = lappend(joinRestrictionContext->joinRestrictionList, joinRestriction); diff --git a/src/backend/distributed/planner/query_pushdown_planning.c b/src/backend/distributed/planner/query_pushdown_planning.c index dfade2acd..e2e8be6cd 100644 --- a/src/backend/distributed/planner/query_pushdown_planning.c +++ b/src/backend/distributed/planner/query_pushdown_planning.c @@ -79,8 +79,8 @@ static DeferredErrorMessage * DeferredErrorIfUnsupportedRecurringTuplesJoin( PlannerRestrictionContext *plannerRestrictionContext); static DeferredErrorMessage * DeferErrorIfUnsupportedTableCombination(Query *queryTree); static bool ExtractSetOperationStatmentWalker(Node *node, List **setOperationList); -static RecurringTuplesType FetchFirstRecurType(PlannerInfo *plannerInfo, RelOptInfo * - relationInfo); +static RecurringTuplesType FetchFirstRecurType(PlannerInfo *plannerInfo, + Relids relids); static bool ContainsRecurringRTE(RangeTblEntry *rangeTableEntry, RecurringTuplesType *recurType); static bool ContainsRecurringRangeTable(List *rangeTable, RecurringTuplesType *recurType); @@ -93,7 +93,7 @@ static void UpdateColumnToMatchingTargetEntry(Var *column, static MultiTable * MultiSubqueryPushdownTable(Query *subquery); static List * CreateSubqueryTargetEntryList(List *columnList); static bool RelationInfoContainsOnlyRecurringTuples(PlannerInfo *plannerInfo, - RelOptInfo *relationInfo); + Relids relids); /* * ShouldUseSubqueryPushDown determines whether it's desirable to use @@ -782,8 +782,8 @@ DeferredErrorIfUnsupportedRecurringTuplesJoin( joinRestrictionCell); JoinType joinType = joinRestriction->joinType; PlannerInfo *plannerInfo = joinRestriction->plannerInfo; - RelOptInfo *innerrel = joinRestriction->innerrel; - RelOptInfo *outerrel = joinRestriction->outerrel; + Relids innerrelRelids = joinRestriction->innerrelRelids; + Relids outerrelRelids = joinRestriction->outerrelRelids; if (joinType == JOIN_SEMI || joinType == JOIN_ANTI || joinType == JOIN_LEFT) { @@ -793,7 +793,7 @@ DeferredErrorIfUnsupportedRecurringTuplesJoin( * recurring or not. Otherwise, we check the outer side for recurring * tuples. */ - if (RelationInfoContainsOnlyRecurringTuples(plannerInfo, innerrel)) + if (RelationInfoContainsOnlyRecurringTuples(plannerInfo, innerrelRelids)) { continue; } @@ -805,37 +805,37 @@ DeferredErrorIfUnsupportedRecurringTuplesJoin( * the query. The reason is that recurring tuples on every shard would * be added to the result, which is wrong. */ - if (RelationInfoContainsOnlyRecurringTuples(plannerInfo, outerrel)) + if (RelationInfoContainsOnlyRecurringTuples(plannerInfo, outerrelRelids)) { /* * Find the first (or only) recurring RTE to give a meaningful * error to the user. */ - recurType = FetchFirstRecurType(plannerInfo, outerrel); + recurType = FetchFirstRecurType(plannerInfo, outerrelRelids); break; } } else if (joinType == JOIN_FULL) { - if (RelationInfoContainsOnlyRecurringTuples(plannerInfo, innerrel)) + if (RelationInfoContainsOnlyRecurringTuples(plannerInfo, innerrelRelids)) { /* * Find the first (or only) recurring RTE to give a meaningful * error to the user. */ - recurType = FetchFirstRecurType(plannerInfo, innerrel); + recurType = FetchFirstRecurType(plannerInfo, innerrelRelids); break; } - if (RelationInfoContainsOnlyRecurringTuples(plannerInfo, outerrel)) + if (RelationInfoContainsOnlyRecurringTuples(plannerInfo, outerrelRelids)) { /* * Find the first (or only) recurring RTE to give a meaningful * error to the user. */ - recurType = FetchFirstRecurType(plannerInfo, outerrel); + recurType = FetchFirstRecurType(plannerInfo, outerrelRelids); break; } @@ -1301,10 +1301,8 @@ ExtractSetOperationStatmentWalker(Node *node, List **setOperationList) * a RelOptInfo is not recurring. */ static bool -RelationInfoContainsOnlyRecurringTuples(PlannerInfo *plannerInfo, - RelOptInfo *relationInfo) +RelationInfoContainsOnlyRecurringTuples(PlannerInfo *plannerInfo, Relids relids) { - Relids relids = relationInfo->relids; int relationId = -1; while ((relationId = bms_next_member(relids, relationId)) >= 0) @@ -1340,10 +1338,8 @@ RelationInfoContainsOnlyRecurringTuples(PlannerInfo *plannerInfo, * table entry list of planner info, planner info is also passed. */ static RecurringTuplesType -FetchFirstRecurType(PlannerInfo *plannerInfo, RelOptInfo * - relationInfo) +FetchFirstRecurType(PlannerInfo *plannerInfo, Relids relids) { - Relids relids = relationInfo->relids; RecurringTuplesType recurType = RECURRING_TUPLES_INVALID; int relationId = -1; diff --git a/src/include/distributed/distributed_planner.h b/src/include/distributed/distributed_planner.h index 2b5836789..4abe52d02 100644 --- a/src/include/distributed/distributed_planner.h +++ b/src/include/distributed/distributed_planner.h @@ -83,8 +83,8 @@ typedef struct JoinRestriction JoinType joinType; List *joinRestrictInfoList; PlannerInfo *plannerInfo; - RelOptInfo *innerrel; - RelOptInfo *outerrel; + Relids innerrelRelids; + Relids outerrelRelids; } JoinRestriction; typedef struct FastPathRestrictionContext diff --git a/src/test/regress/expected/geqo.out b/src/test/regress/expected/geqo.out new file mode 100644 index 000000000..7e400d997 --- /dev/null +++ b/src/test/regress/expected/geqo.out @@ -0,0 +1,127 @@ +-- test geqo +CREATE SCHEMA geqo_schema; +SET search_path TO geqo_schema; +CREATE TABLE dist (a int, b int); +SELECT create_distributed_table('dist', 'a'); + create_distributed_table +--------------------------------------------------------------------- + +(1 row) + +INSERT INTO dist VALUES (1, 1), (2, 2), (3, 3); +CREATE TABLE dist2 (a int, b int); +SELECT create_distributed_table('dist2', 'a'); + create_distributed_table +--------------------------------------------------------------------- + +(1 row) + +INSERT INTO dist2 VALUES (1, 1), (2, 2), (3, 3); +SET geqo_threshold TO 2; +SET geqo_pool_size TO 1000; +SET geqo_generations TO 1000; +SET citus.enable_repartition_joins to ON; +SELECT count(*) FROM dist d1 LEFT JOIN dist2 d2 ON d1.a = d2.a LEFT JOIN dist d3 ON d3.a = d2.a; + count +--------------------------------------------------------------------- + 3 +(1 row) + +-- JOINs with CTEs: +WITH cte_1 AS (SELECT * FROM dist OFFSET 0) + SELECT count(*) FROM dist d1 LEFT JOIN dist2 d2 ON d1.a = d2.a LEFT JOIN dist d3 ON d3.a = d2.a LEFT JOIN cte_1 ON true; + count +--------------------------------------------------------------------- + 9 +(1 row) + +WITH cte_1 AS (SELECT * FROM dist OFFSET 0), + cte_2 AS (SELECT * FROM dist2 OFFSET 0), + cte_3 AS (SELECT * FROM dist OFFSET 0) + SELECT count(*) FROM cte_1 d1 LEFT JOIN cte_2 d2 ON d1.a = d2.a LEFT JOIN cte_3 d3 ON d3.a = d2.a LEFT JOIN cte_1 ON true; + count +--------------------------------------------------------------------- + 9 +(1 row) + +-- Inner JOIN: +SELECT count(*) FROM dist d1 JOIN dist2 d2 ON d1.a = d2.a JOIN dist d3 ON d3.a = d2.a; + count +--------------------------------------------------------------------- + 3 +(1 row) + +-- subquery join +SELECT count(*) FROM (SELECT *, random() FROM dist) as d1 JOIN (SELECT *, random() FROM dist2) d2 ON d1.a = d2.a JOIN (SELECT *, random() FROM dist) as d3 ON d3.a = d2.a; + count +--------------------------------------------------------------------- + 3 +(1 row) + +SELECT count(*) FROM (SELECT *, random() FROM dist) as d1 LEFT JOIN (SELECT *, random() FROM dist2) d2 ON d1.a = d2.a LEFT JOIN (SELECT *, random() FROM dist) as d3 ON d3.a = d2.a; + count +--------------------------------------------------------------------- + 3 +(1 row) + +-- router query +SELECT count(*) FROM dist d1 LEFT JOIN dist2 d2 ON d1.a = d2.a LEFT JOIN dist d3 ON d3.a = d2.a WHERE d1.a = 1 AND d2.a = 1 AND d3.a = 1; + count +--------------------------------------------------------------------- + 1 +(1 row) + +-- fast path router query +SELECT count(*) FROM dist WHERE a = 1; + count +--------------------------------------------------------------------- + 1 +(1 row) + +-- simple INSERT +INSERT INTO dist (a) VALUES (1); +-- repartition join, probably not relevant, but still be defensive +SELECT count(*) FROM dist d1 JOIN dist2 d2 ON d1.b = d2.b JOIN dist d3 ON d3.b = d2.b; + count +--------------------------------------------------------------------- + 3 +(1 row) + +-- update query with join +UPDATE dist SET b = foo.a FROM (SELECT d1.a FROM dist d1 JOIN dist2 d2 USING(a)) foo WHERE foo.a = dist.a RETURNING *; + a | b | a +--------------------------------------------------------------------- + 1 | 1 | 1 + 1 | 1 | 1 + 2 | 2 | 2 + 3 | 3 | 3 +(4 rows) + +-- insert select via repartitioning +INSERT INTO dist (a) SELECT max(d1.b) FROM dist d1 JOIN dist2 d2 ON d1.a = d2.a JOIN dist d3 ON d3.a = d2.a GROUP BY d1.a; +SELECT count(*) FROM dist; + count +--------------------------------------------------------------------- + 7 +(1 row) + +-- insert select pushdown +INSERT INTO dist SELECT d1.* FROM dist d1 JOIN dist2 d2 ON d1.a = d2.a JOIN dist d3 ON d3.a = d2.a WHERE d1.b < 2 AND d2.b < 2; +SELECT count(*) FROM dist; + count +--------------------------------------------------------------------- + 13 +(1 row) + +-- insert select via coordinator +INSERT INTO dist SELECT d1.* FROM dist d1 JOIN dist2 d2 ON d1.a = d2.a JOIN dist d3 ON d3.a = d2.a WHERE d1.b < 2 AND d2.b < 2 OFFSET 0; +SELECT count(*) FROM dist; + count +--------------------------------------------------------------------- + 85 +(1 row) + +DROP SCHEMA geqo_schema CASCADE; +NOTICE: drop cascades to 2 other objects +DETAIL: drop cascades to table dist +drop cascades to table dist2 diff --git a/src/test/regress/multi_schedule b/src/test/regress/multi_schedule index 5a7ab94a8..57ac0f47a 100644 --- a/src/test/regress/multi_schedule +++ b/src/test/regress/multi_schedule @@ -92,7 +92,7 @@ test: multi_explain hyperscale_tutorial partitioned_intermediate_results distrib test: multi_basic_queries cross_join multi_complex_expressions multi_subquery multi_subquery_complex_queries multi_subquery_behavioral_analytics test: multi_subquery_complex_reference_clause multi_subquery_window_functions multi_view multi_sql_function multi_prepare_sql test: sql_procedure multi_function_in_join row_types materialized_view undistribute_table -test: multi_subquery_in_where_reference_clause join adaptive_executor propagate_set_commands +test: multi_subquery_in_where_reference_clause join geqo adaptive_executor propagate_set_commands test: multi_subquery_union multi_subquery_in_where_clause multi_subquery_misc statement_cancel_error_message test: multi_agg_distinct multi_agg_approximate_distinct multi_limit_clause_approximate multi_outer_join_reference multi_single_relation_subquery multi_prepare_plsql set_role_in_transaction test: multi_reference_table multi_select_for_update relation_access_tracking pg13_with_ties diff --git a/src/test/regress/sql/geqo.sql b/src/test/regress/sql/geqo.sql new file mode 100644 index 000000000..a6c930f5a --- /dev/null +++ b/src/test/regress/sql/geqo.sql @@ -0,0 +1,64 @@ +-- test geqo +CREATE SCHEMA geqo_schema; +SET search_path TO geqo_schema; + +CREATE TABLE dist (a int, b int); +SELECT create_distributed_table('dist', 'a'); +INSERT INTO dist VALUES (1, 1), (2, 2), (3, 3); + +CREATE TABLE dist2 (a int, b int); +SELECT create_distributed_table('dist2', 'a'); +INSERT INTO dist2 VALUES (1, 1), (2, 2), (3, 3); + +SET geqo_threshold TO 2; +SET geqo_pool_size TO 1000; +SET geqo_generations TO 1000; + +SET citus.enable_repartition_joins to ON; + +SELECT count(*) FROM dist d1 LEFT JOIN dist2 d2 ON d1.a = d2.a LEFT JOIN dist d3 ON d3.a = d2.a; + +-- JOINs with CTEs: +WITH cte_1 AS (SELECT * FROM dist OFFSET 0) + SELECT count(*) FROM dist d1 LEFT JOIN dist2 d2 ON d1.a = d2.a LEFT JOIN dist d3 ON d3.a = d2.a LEFT JOIN cte_1 ON true; + +WITH cte_1 AS (SELECT * FROM dist OFFSET 0), + cte_2 AS (SELECT * FROM dist2 OFFSET 0), + cte_3 AS (SELECT * FROM dist OFFSET 0) + SELECT count(*) FROM cte_1 d1 LEFT JOIN cte_2 d2 ON d1.a = d2.a LEFT JOIN cte_3 d3 ON d3.a = d2.a LEFT JOIN cte_1 ON true; + +-- Inner JOIN: +SELECT count(*) FROM dist d1 JOIN dist2 d2 ON d1.a = d2.a JOIN dist d3 ON d3.a = d2.a; + +-- subquery join +SELECT count(*) FROM (SELECT *, random() FROM dist) as d1 JOIN (SELECT *, random() FROM dist2) d2 ON d1.a = d2.a JOIN (SELECT *, random() FROM dist) as d3 ON d3.a = d2.a; +SELECT count(*) FROM (SELECT *, random() FROM dist) as d1 LEFT JOIN (SELECT *, random() FROM dist2) d2 ON d1.a = d2.a LEFT JOIN (SELECT *, random() FROM dist) as d3 ON d3.a = d2.a; + +-- router query +SELECT count(*) FROM dist d1 LEFT JOIN dist2 d2 ON d1.a = d2.a LEFT JOIN dist d3 ON d3.a = d2.a WHERE d1.a = 1 AND d2.a = 1 AND d3.a = 1; + +-- fast path router query +SELECT count(*) FROM dist WHERE a = 1; + +-- simple INSERT +INSERT INTO dist (a) VALUES (1); + +-- repartition join, probably not relevant, but still be defensive +SELECT count(*) FROM dist d1 JOIN dist2 d2 ON d1.b = d2.b JOIN dist d3 ON d3.b = d2.b; + +-- update query with join +UPDATE dist SET b = foo.a FROM (SELECT d1.a FROM dist d1 JOIN dist2 d2 USING(a)) foo WHERE foo.a = dist.a RETURNING *; + +-- insert select via repartitioning +INSERT INTO dist (a) SELECT max(d1.b) FROM dist d1 JOIN dist2 d2 ON d1.a = d2.a JOIN dist d3 ON d3.a = d2.a GROUP BY d1.a; +SELECT count(*) FROM dist; + +-- insert select pushdown +INSERT INTO dist SELECT d1.* FROM dist d1 JOIN dist2 d2 ON d1.a = d2.a JOIN dist d3 ON d3.a = d2.a WHERE d1.b < 2 AND d2.b < 2; +SELECT count(*) FROM dist; + +-- insert select via coordinator +INSERT INTO dist SELECT d1.* FROM dist d1 JOIN dist2 d2 ON d1.a = d2.a JOIN dist d3 ON d3.a = d2.a WHERE d1.b < 2 AND d2.b < 2 OFFSET 0; +SELECT count(*) FROM dist; + +DROP SCHEMA geqo_schema CASCADE;