diff --git a/src/backend/distributed/planner/multi_router_planner.c b/src/backend/distributed/planner/multi_router_planner.c index 0c5541995..02d14c58b 100644 --- a/src/backend/distributed/planner/multi_router_planner.c +++ b/src/backend/distributed/planner/multi_router_planner.c @@ -144,7 +144,7 @@ static bool SelectsFromDistributedTable(List *rangeTableList, Query *query); static List * get_all_actual_clauses(List *restrictinfo_list); static int CompareInsertValuesByShardId(const void *leftElement, const void *rightElement); -static uint64 GetInitialShardId(List *relationShardList); +static uint64 GetAnchorShardId(List *relationShardList); static List * TargetShardIntervalForFastPathQuery(Query *query, Const **partitionValueConst, bool *isMultiShardQuery); @@ -2004,7 +2004,7 @@ PlanRouterQuery(Query *originalQuery, } /* we need anchor shard id for select queries with router planner */ - shardId = GetInitialShardId(prunedRelationShardList); + shardId = GetAnchorShardId(prunedRelationShardList); /* * Determine the worker that has all shard placements if a shard placement found. @@ -2076,13 +2076,20 @@ PlanRouterQuery(Query *originalQuery, /* - * GetInitialShardId returns the initial shard id given relation shard list. If - * there is no relation shard exist in the list returns INAVLID_SHARD_ID. + * GetAnchorShardId returns the anchor shard id given relation shard list. + * The desired anchor shard is found as follows: + * + * - Return the first distributed table shard id in the relationShardList if + * there is any. + * - Return a random reference table shard id if all the shards belong to + * reference tables + * - Return INVALID_SHARD_ID on empty lists */ static uint64 -GetInitialShardId(List *relationShardList) +GetAnchorShardId(List *relationShardList) { ListCell *prunedRelationShardListCell = NULL; + uint64 referenceShardId = INVALID_SHARD_ID; foreach(prunedRelationShardListCell, relationShardList) { @@ -2096,10 +2103,18 @@ GetInitialShardId(List *relationShardList) } shardInterval = linitial(prunedShardList); - return shardInterval->shardId; + + if (ReferenceTableShardId(shardInterval->shardId)) + { + referenceShardId = shardInterval->shardId; + } + else + { + return shardInterval->shardId; + } } - return INVALID_SHARD_ID; + return referenceShardId; } diff --git a/src/test/regress/expected/multi_task_assignment_policy.out b/src/test/regress/expected/multi_task_assignment_policy.out index 686b1f9cc..70bae799b 100644 --- a/src/test/regress/expected/multi_task_assignment_policy.out +++ b/src/test/regress/expected/multi_task_assignment_policy.out @@ -184,7 +184,7 @@ RESET client_min_messages; -- which might change and we don't have any control over it. -- the important thing that we look for is that round-robin policy -- should give the same output for executions in the same transaction --- and different output for executions that are not insdie the +-- and different output for executions that are not inside the -- same transaction. To ensure that, we define a helper function BEGIN; SET LOCAL citus.explain_distributed_queries TO on; @@ -201,12 +201,11 @@ SELECT count(DISTINCT value) FROM explain_outputs; 1 (1 row) -DROP TABLE explain_outputs; +TRUNCATE explain_outputs; COMMIT; -- now test round-robin policy outside --- a transaction, we should see the assignements +-- a transaction, we should see the assignments -- change on every execution -CREATE TEMPORARY TABLE explain_outputs (value text); SET citus.task_assignment_policy TO 'round-robin'; SET citus.explain_distributed_queries TO ON; INSERT INTO explain_outputs @@ -248,12 +247,11 @@ SELECT count(DISTINCT value) FROM explain_outputs; 1 (1 row) -DROP TABLE explain_outputs; +TRUNCATE explain_outputs; COMMIT; -- now test round-robin policy outside --- a transaction, we should see the assignements +-- a transaction, we should see the assignments -- change on every execution -CREATE TEMPORARY TABLE explain_outputs (value text); SET citus.task_assignment_policy TO 'round-robin'; SET citus.explain_distributed_queries TO ON; INSERT INTO explain_outputs @@ -268,6 +266,40 @@ SELECT count(DISTINCT value) FROM explain_outputs; 2 (1 row) +TRUNCATE explain_outputs; +-- test that the round robin policy detects the anchor shard correctly +-- we should not pick a reference table shard as the anchor shard when joining with a distributed table +SET citus.shard_replication_factor TO 1; +CREATE TABLE task_assignment_nonreplicated_hash (test_id integer, ref_id integer); +SELECT create_distributed_table('task_assignment_nonreplicated_hash', 'test_id'); + create_distributed_table +-------------------------- + +(1 row) + +-- run the query two times to make sure that it hits the correct worker every time +INSERT INTO explain_outputs +SELECT parse_explain_output($cmd$ +EXPLAIN SELECT * +FROM (SELECT * FROM task_assignment_nonreplicated_hash WHERE test_id = 3) AS dist + LEFT JOIN task_assignment_reference_table ref + ON dist.ref_id = ref.test_id +$cmd$, 'task_assignment_nonreplicated_hash'); +INSERT INTO explain_outputs +SELECT parse_explain_output($cmd$ +EXPLAIN SELECT * +FROM (SELECT * FROM task_assignment_nonreplicated_hash WHERE test_id = 3) AS dist + LEFT JOIN task_assignment_reference_table ref + ON dist.ref_id = ref.test_id +$cmd$, 'task_assignment_nonreplicated_hash'); +-- The count should be 1 since the shard exists in only one worker node +SELECT count(DISTINCT value) FROM explain_outputs; + count +------- + 1 +(1 row) + +TRUNCATE explain_outputs; RESET citus.task_assignment_policy; RESET client_min_messages; -- we should be able to use round-robin with router queries that @@ -292,4 +324,5 @@ WITH q1 AS (SELECT * FROM task_assignment_test_table_2) SELECT * FROM q1; (0 rows) ROLLBACK; -DROP TABLE task_assignment_replicated_hash, task_assignment_reference_table; +DROP TABLE task_assignment_replicated_hash, task_assignment_nonreplicated_hash, + task_assignment_reference_table, explain_outputs; diff --git a/src/test/regress/sql/multi_task_assignment_policy.sql b/src/test/regress/sql/multi_task_assignment_policy.sql index ccbc72018..67341b8cc 100644 --- a/src/test/regress/sql/multi_task_assignment_policy.sql +++ b/src/test/regress/sql/multi_task_assignment_policy.sql @@ -133,7 +133,7 @@ RESET client_min_messages; -- which might change and we don't have any control over it. -- the important thing that we look for is that round-robin policy -- should give the same output for executions in the same transaction --- and different output for executions that are not insdie the +-- and different output for executions that are not inside the -- same transaction. To ensure that, we define a helper function BEGIN; @@ -149,15 +149,12 @@ INSERT INTO explain_outputs -- given that we're in the same transaction, the count should be 1 SELECT count(DISTINCT value) FROM explain_outputs; - -DROP TABLE explain_outputs; +TRUNCATE explain_outputs; COMMIT; -- now test round-robin policy outside --- a transaction, we should see the assignements +-- a transaction, we should see the assignments -- change on every execution -CREATE TEMPORARY TABLE explain_outputs (value text); - SET citus.task_assignment_policy TO 'round-robin'; SET citus.explain_distributed_queries TO ON; @@ -192,15 +189,12 @@ INSERT INTO explain_outputs -- given that we're in the same transaction, the count should be 1 SELECT count(DISTINCT value) FROM explain_outputs; - -DROP TABLE explain_outputs; +TRUNCATE explain_outputs; COMMIT; -- now test round-robin policy outside --- a transaction, we should see the assignements +-- a transaction, we should see the assignments -- change on every execution -CREATE TEMPORARY TABLE explain_outputs (value text); - SET citus.task_assignment_policy TO 'round-robin'; SET citus.explain_distributed_queries TO ON; @@ -212,7 +206,35 @@ INSERT INTO explain_outputs -- given that we're in the same transaction, the count should be 2 -- since there are two different worker nodes SELECT count(DISTINCT value) FROM explain_outputs; +TRUNCATE explain_outputs; +-- test that the round robin policy detects the anchor shard correctly +-- we should not pick a reference table shard as the anchor shard when joining with a distributed table +SET citus.shard_replication_factor TO 1; + +CREATE TABLE task_assignment_nonreplicated_hash (test_id integer, ref_id integer); +SELECT create_distributed_table('task_assignment_nonreplicated_hash', 'test_id'); + +-- run the query two times to make sure that it hits the correct worker every time +INSERT INTO explain_outputs +SELECT parse_explain_output($cmd$ +EXPLAIN SELECT * +FROM (SELECT * FROM task_assignment_nonreplicated_hash WHERE test_id = 3) AS dist + LEFT JOIN task_assignment_reference_table ref + ON dist.ref_id = ref.test_id +$cmd$, 'task_assignment_nonreplicated_hash'); + +INSERT INTO explain_outputs +SELECT parse_explain_output($cmd$ +EXPLAIN SELECT * +FROM (SELECT * FROM task_assignment_nonreplicated_hash WHERE test_id = 3) AS dist + LEFT JOIN task_assignment_reference_table ref + ON dist.ref_id = ref.test_id +$cmd$, 'task_assignment_nonreplicated_hash'); + +-- The count should be 1 since the shard exists in only one worker node +SELECT count(DISTINCT value) FROM explain_outputs; +TRUNCATE explain_outputs; RESET citus.task_assignment_policy; RESET client_min_messages; @@ -228,4 +250,5 @@ SET LOCAL citus.task_assignment_policy TO 'round-robin'; WITH q1 AS (SELECT * FROM task_assignment_test_table_2) SELECT * FROM q1; ROLLBACK; -DROP TABLE task_assignment_replicated_hash, task_assignment_reference_table; +DROP TABLE task_assignment_replicated_hash, task_assignment_nonreplicated_hash, + task_assignment_reference_table, explain_outputs;