diff --git a/src/backend/distributed/planner/multi_logical_planner.c b/src/backend/distributed/planner/multi_logical_planner.c index b6dd0dcce..f2faacb1d 100644 --- a/src/backend/distributed/planner/multi_logical_planner.c +++ b/src/backend/distributed/planner/multi_logical_planner.c @@ -2080,12 +2080,6 @@ ExtractRangeTableRelationWalker(Node *node, List **rangeTableRelationList) walkIsComplete = false; } - else - { - walkIsComplete = range_table_walker(list_make1(rangeTable), - ExtractRangeTableRelationWalker, - rangeTableRelationList, 0); - } } else if (IsA(node, Query)) { diff --git a/src/backend/distributed/planner/relation_restriction_equivalence.c b/src/backend/distributed/planner/relation_restriction_equivalence.c index 23a7a3f47..42196933b 100644 --- a/src/backend/distributed/planner/relation_restriction_equivalence.c +++ b/src/backend/distributed/planner/relation_restriction_equivalence.c @@ -1825,8 +1825,21 @@ RangeTableArrayContainsAnyRTEIdentities(RangeTblEntry **rangeTableEntries, int * (i.e.,rangeTableEntry could be a subquery where we're interested * in relations). */ - ExtractRangeTableRelationWalker((Node *) rangeTableEntry, - &rangeTableRelationList); + if (rangeTableEntry->rtekind == RTE_SUBQUERY) + { + ExtractRangeTableRelationWalker((Node *) rangeTableEntry->subquery, + &rangeTableRelationList); + } + else if (rangeTableEntry->rtekind == RTE_RELATION) + { + ExtractRangeTableRelationWalker((Node *) rangeTableEntry, + &rangeTableRelationList); + } + else + { + /* we currently do not accept any other RTE types here */ + continue; + } foreach(rteRelationCell, rangeTableRelationList) { diff --git a/src/backend/distributed/test/distribution_metadata.c b/src/backend/distributed/test/distribution_metadata.c index cae3d2eea..8702b863c 100644 --- a/src/backend/distributed/test/distribution_metadata.c +++ b/src/backend/distributed/test/distribution_metadata.c @@ -32,6 +32,7 @@ #include "nodes/pg_list.h" #include "nodes/primnodes.h" #include "storage/lock.h" +#include "tcop/tcopprot.h" #include "utils/array.h" #include "utils/elog.h" #include "utils/errcodes.h" @@ -48,6 +49,7 @@ PG_FUNCTION_INFO_V1(partition_type); PG_FUNCTION_INFO_V1(is_distributed_table); PG_FUNCTION_INFO_V1(create_monolithic_shard_row); PG_FUNCTION_INFO_V1(acquire_shared_shard_lock); +PG_FUNCTION_INFO_V1(relation_count_in_query); /* @@ -249,3 +251,43 @@ acquire_shared_shard_lock(PG_FUNCTION_ARGS) PG_RETURN_VOID(); } + + +/* + * relation_count_in_query return the first query's relation count. + */ +Datum +relation_count_in_query(PG_FUNCTION_ARGS) +{ + text *queryString = PG_GETARG_TEXT_P(0); + + char *queryStringChar = text_to_cstring(queryString); + List *parseTreeList = pg_parse_query(queryStringChar); + ListCell *parseTreeCell = NULL; + + foreach(parseTreeCell, parseTreeList) + { + Node *parsetree = (Node *) lfirst(parseTreeCell); + ListCell *queryTreeCell = NULL; + List *queryTreeList = NIL; + +#if (PG_VERSION_NUM >= 100000) + queryTreeList = pg_analyze_and_rewrite((RawStmt *) parsetree, queryStringChar, + NULL, 0, NULL); +#else + queryTreeList = pg_analyze_and_rewrite(parsetree, queryStringChar, NULL, 0); +#endif + + foreach(queryTreeCell, queryTreeList) + { + Query *query = lfirst(queryTreeCell); + List *rangeTableList = NIL; + + ExtractRangeTableRelationWalker((Node *) query, &rangeTableList); + + PG_RETURN_INT32(list_length(rangeTableList)); + } + } + + PG_RETURN_INT32(0); +} diff --git a/src/test/regress/expected/multi_distribution_metadata.out b/src/test/regress/expected/multi_distribution_metadata.out index bc9ea521e..ca7f8bb58 100644 --- a/src/test/regress/expected/multi_distribution_metadata.out +++ b/src/test/regress/expected/multi_distribution_metadata.out @@ -38,6 +38,10 @@ CREATE FUNCTION acquire_shared_shard_lock(bigint) RETURNS void AS 'citus' LANGUAGE C STRICT; +CREATE FUNCTION relation_count_in_query(text) + RETURNS int + AS 'citus' + LANGUAGE C STRICT; -- =================================================================== -- test distribution metadata functionality -- =================================================================== @@ -464,5 +468,96 @@ SELECT get_shard_id_for_distribution_column('get_shardid_test_table5', -999); 0 (1 row) +SET citus.shard_count TO 2; +CREATE TABLE events_table_count (user_id int, time timestamp, event_type int, value_2 int, value_3 float, value_4 bigint); +SELECT create_distributed_table('events_table_count', 'user_id'); + create_distributed_table +-------------------------- + +(1 row) + +CREATE TABLE users_table_count (user_id int, time timestamp, value_1 int, value_2 int, value_3 float, value_4 bigint); +SELECT create_distributed_table('users_table_count', 'user_id'); + create_distributed_table +-------------------------- + +(1 row) + +SELECT relation_count_in_query($$-- we can support arbitrary subqueries within UNIONs +SELECT ("final_query"."event_types") as types, count(*) AS sumOfEventType +FROM + ( SELECT + *, random() + FROM + (SELECT + "t"."user_id", "t"."time", unnest("t"."collected_events") AS "event_types" + FROM + ( SELECT + "t1"."user_id", min("t1"."time") AS "time", array_agg(("t1"."event") ORDER BY TIME ASC, event DESC) AS collected_events + FROM ( + (SELECT + * + FROM + (SELECT + events_table."time", 0 AS event, events_table."user_id" + FROM + "events_table_count" as events_table + WHERE + events_table.event_type IN (1, 2) ) events_subquery_1) + UNION + (SELECT * + FROM + ( + SELECT * FROM + ( + SELECT + max("events_table_count"."time"), + 0 AS event, + "events_table_count"."user_id" + FROM + "events_table_count", users_table_count as "users" + WHERE + "events_table_count".user_id = users.user_id AND + "events_table_count".event_type IN (1, 2) + GROUP BY "events_table_count"."user_id" + ) as events_subquery_5 + ) events_subquery_2) + UNION + (SELECT * + FROM + (SELECT + "events_table_count"."time", 2 AS event, "events_table_count"."user_id" + FROM + "events_table_count" + WHERE + event_type IN (3, 4) ) events_subquery_3) + UNION + (SELECT * + FROM + (SELECT + "events_table_count"."time", 3 AS event, "events_table_count"."user_id" + FROM + "events_table_count" + WHERE + event_type IN (5, 6)) events_subquery_4) + ) t1 + GROUP BY "t1"."user_id") AS t) "q" +INNER JOIN + (SELECT + "events_table_count"."user_id" + FROM + users_table_count as "events_table_count" + WHERE + value_1 > 0 and value_1 < 4) AS t + ON (t.user_id = q.user_id)) as final_query +GROUP BY + types +ORDER BY + types;$$); + relation_count_in_query +------------------------- + 6 +(1 row) + -- clear unnecessary tables; -DROP TABLE get_shardid_test_table1, get_shardid_test_table2, get_shardid_test_table3, get_shardid_test_table4, get_shardid_test_table5; +DROP TABLE get_shardid_test_table1, get_shardid_test_table2, get_shardid_test_table3, get_shardid_test_table4, get_shardid_test_table5, events_table_count; diff --git a/src/test/regress/sql/multi_distribution_metadata.sql b/src/test/regress/sql/multi_distribution_metadata.sql index 7100d567b..b5b4aa417 100644 --- a/src/test/regress/sql/multi_distribution_metadata.sql +++ b/src/test/regress/sql/multi_distribution_metadata.sql @@ -51,6 +51,11 @@ CREATE FUNCTION acquire_shared_shard_lock(bigint) AS 'citus' LANGUAGE C STRICT; +CREATE FUNCTION relation_count_in_query(text) + RETURNS int + AS 'citus' + LANGUAGE C STRICT; + -- =================================================================== -- test distribution metadata functionality -- =================================================================== @@ -259,5 +264,85 @@ SELECT get_shard_id_for_distribution_column('get_shardid_test_table5', 3248); SELECT get_shard_id_for_distribution_column('get_shardid_test_table5', 4001); SELECT get_shard_id_for_distribution_column('get_shardid_test_table5', -999); + +SET citus.shard_count TO 2; +CREATE TABLE events_table_count (user_id int, time timestamp, event_type int, value_2 int, value_3 float, value_4 bigint); +SELECT create_distributed_table('events_table_count', 'user_id'); + +CREATE TABLE users_table_count (user_id int, time timestamp, value_1 int, value_2 int, value_3 float, value_4 bigint); +SELECT create_distributed_table('users_table_count', 'user_id'); + +SELECT relation_count_in_query($$-- we can support arbitrary subqueries within UNIONs +SELECT ("final_query"."event_types") as types, count(*) AS sumOfEventType +FROM + ( SELECT + *, random() + FROM + (SELECT + "t"."user_id", "t"."time", unnest("t"."collected_events") AS "event_types" + FROM + ( SELECT + "t1"."user_id", min("t1"."time") AS "time", array_agg(("t1"."event") ORDER BY TIME ASC, event DESC) AS collected_events + FROM ( + (SELECT + * + FROM + (SELECT + events_table."time", 0 AS event, events_table."user_id" + FROM + "events_table_count" as events_table + WHERE + events_table.event_type IN (1, 2) ) events_subquery_1) + UNION + (SELECT * + FROM + ( + SELECT * FROM + ( + SELECT + max("events_table_count"."time"), + 0 AS event, + "events_table_count"."user_id" + FROM + "events_table_count", users_table_count as "users" + WHERE + "events_table_count".user_id = users.user_id AND + "events_table_count".event_type IN (1, 2) + GROUP BY "events_table_count"."user_id" + ) as events_subquery_5 + ) events_subquery_2) + UNION + (SELECT * + FROM + (SELECT + "events_table_count"."time", 2 AS event, "events_table_count"."user_id" + FROM + "events_table_count" + WHERE + event_type IN (3, 4) ) events_subquery_3) + UNION + (SELECT * + FROM + (SELECT + "events_table_count"."time", 3 AS event, "events_table_count"."user_id" + FROM + "events_table_count" + WHERE + event_type IN (5, 6)) events_subquery_4) + ) t1 + GROUP BY "t1"."user_id") AS t) "q" +INNER JOIN + (SELECT + "events_table_count"."user_id" + FROM + users_table_count as "events_table_count" + WHERE + value_1 > 0 and value_1 < 4) AS t + ON (t.user_id = q.user_id)) as final_query +GROUP BY + types +ORDER BY + types;$$); + -- clear unnecessary tables; -DROP TABLE get_shardid_test_table1, get_shardid_test_table2, get_shardid_test_table3, get_shardid_test_table4, get_shardid_test_table5; +DROP TABLE get_shardid_test_table1, get_shardid_test_table2, get_shardid_test_table3, get_shardid_test_table4, get_shardid_test_table5, events_table_count;