diff --git a/src/backend/distributed/planner/shard_pruning.c b/src/backend/distributed/planner/shard_pruning.c
index 02c799b63..18288348f 100644
--- a/src/backend/distributed/planner/shard_pruning.c
+++ b/src/backend/distributed/planner/shard_pruning.c
@@ -1571,6 +1571,22 @@ LowerShardBoundary(Datum partitionColumnValue, ShardInterval **shardIntervalCach
 	/* setup partitionColumnValue argument once */
 	fcSetArg(compareFunction, 0, partitionColumnValue);
 
+	/*
+	 * Now we test the partitionColumnValue used in a WHERE clause such as
+	 * partCol > partitionColumnValue (or partCol >= partitionColumnValue)
+	 * against four possibilities:
+	 * 1) partitionColumnValue falls into a specific shard, such that:
+	 *    partitionColumnValue >= shard[x].min, and
+	 *    partitionColumnValue < shard[x].max (or partitionColumnValue <= shard[x].max).
+	 * 2) partitionColumnValue < shard[x].min for all the shards
+	 * 3) partitionColumnValue > shard[x].max for all the shards
+	 * 4) partitionColumnValue falls in between two shards, such that:
+	 *    partitionColumnValue > shard[x].max and
+	 *    partitionColumnValue < shard[x+1].min
+	 *
+	 * For 1), the loop below finds that shard via binary search and
+	 * returns its index. For the others, see the end of this function.
+	 */
 	while (lowerBoundIndex < upperBoundIndex)
 	{
 		int middleIndex = lowerBoundIndex + ((upperBoundIndex - lowerBoundIndex) / 2);
@@ -1603,7 +1619,7 @@ LowerShardBoundary(Datum partitionColumnValue, ShardInterval **shardIntervalCach
 			continue;
 		}
 
-		/* found interval containing partitionValue */
+		/* partitionColumnValue falls into a specific shard, possibility 1) */
 		return middleIndex;
 	}
 
@@ -1614,20 +1630,30 @@ LowerShardBoundary(Datum partitionColumnValue, ShardInterval **shardIntervalCach
 	 * (we'd have hit the return middleIndex; case otherwise). Figure out
 	 * whether there's possibly any interval containing a value that's bigger
 	 * than the partition key one.
+	 *
+	 * Also note that we initialized lowerBoundIndex to 0 and, during the
+	 * binary search, always set it to the index of the shard that we
+	 * consider as our lower boundary.
 	 */
-	if (lowerBoundIndex == 0)
+	if (lowerBoundIndex == shardCount)
 	{
-		/* all intervals are bigger, thus return 0 */
-		return 0;
-	}
-	else if (lowerBoundIndex == shardCount)
-	{
-		/* partition value is bigger than all partition values */
+		/*
+		 * Since lowerBoundIndex is an inclusive index, being equal to shardCount
+		 * means all the shards have smaller values than partitionColumnValue,
+		 * which corresponds to possibility 3).
+		 * In that case, since we can't have a lower bound shard, we return
+		 * INVALID_SHARD_INDEX here.
+		 */
 		return INVALID_SHARD_INDEX;
 	}
 
-	/* value falls inbetween intervals */
-	return lowerBoundIndex + 1;
+	/*
+	 * partitionColumnValue is either smaller than all the shards or falls in
+	 * between two shards, which corresponds to possibility 2) or 4).
+	 * Knowing that lowerBoundIndex is an inclusive index, we directly return
+	 * it as the index for the lower bound shard here.
+	 */
+	return lowerBoundIndex;
 }
 
 
@@ -1647,6 +1673,23 @@ UpperShardBoundary(Datum partitionColumnValue, ShardInterval **shardIntervalCach
 	/* setup partitionColumnValue argument once */
 	fcSetArg(compareFunction, 0, partitionColumnValue);
 
+	/*
+	 * Now we test the partitionColumnValue used in a WHERE clause such as
+	 * partCol < partitionColumnValue (or partCol <= partitionColumnValue)
+	 * against four possibilities:
+	 * 1) partitionColumnValue falls into a specific shard, such that:
+	 *    partitionColumnValue <= shard[x].max, and
+	 *    partitionColumnValue > shard[x].min (or partitionColumnValue >= shard[x].min).
+	 * 2) partitionColumnValue > shard[x].max for all the shards
+	 * 3) partitionColumnValue < shard[x].min for all the shards
+	 * 4) partitionColumnValue falls in between two shards, such that:
+	 *    partitionColumnValue > shard[x].max and
+	 *    partitionColumnValue < shard[x+1].min
+	 *
+	 * For 1), the loop below finds that shard via binary search and
+	 * returns its index. For the others, see the end of this function.
+	 */
+
 	while (lowerBoundIndex < upperBoundIndex)
 	{
 		int middleIndex = lowerBoundIndex + ((upperBoundIndex - lowerBoundIndex) / 2);
@@ -1679,7 +1722,7 @@ UpperShardBoundary(Datum partitionColumnValue, ShardInterval **shardIntervalCach
 			continue;
 		}
 
-		/* found interval containing partitionValue */
+		/* partitionColumnValue falls into a specific shard, possibility 1) */
 		return middleIndex;
 	}
 
@@ -1690,19 +1733,29 @@ UpperShardBoundary(Datum partitionColumnValue, ShardInterval **shardIntervalCach
 	 * (we'd have hit the return middleIndex; case otherwise). Figure out
 	 * whether there's possibly any interval containing a value that's smaller
 	 * than the partition key one.
+	 *
+	 * Also note that we initialized upperBoundIndex to shardCount and, during
+	 * the binary search, always set it to the index right after the shard
+	 * that we consider as our upper boundary.
 	 */
-	if (upperBoundIndex == shardCount)
+	if (upperBoundIndex == 0)
 	{
-		/* all intervals are smaller, thus return 0 */
-		return shardCount - 1;
-	}
-	else if (upperBoundIndex == 0)
-	{
-		/* partition value is smaller than all partition values */
+		/*
+		 * Since upperBoundIndex is an exclusive index, being equal to 0 means
+		 * all the shards have greater values than partitionColumnValue, which
+		 * corresponds to possibility 3).
+		 * In that case, since we can't have an upper bound shard, we return
+		 * INVALID_SHARD_INDEX here.
+		 */
 		return INVALID_SHARD_INDEX;
 	}
 
-	/* value falls inbetween intervals, return the inverval one smaller as bound */
+	/*
+	 * partitionColumnValue is either greater than all the shards or falls in
+	 * between two shards, which corresponds to possibility 2) or 4).
+	 * Knowing that upperBoundIndex is an exclusive index, we return the index
+	 * for the previous shard here.
+	 */
 	return upperBoundIndex - 1;
 }
 
diff --git a/src/test/regress/expected/multi_join_pruning.out b/src/test/regress/expected/multi_join_pruning.out
index 065c8cfc3..66aff4a3e 100644
--- a/src/test/regress/expected/multi_join_pruning.out
+++ b/src/test/regress/expected/multi_join_pruning.out
@@ -38,14 +38,48 @@ DEBUG: Router planner does not support append-partitioned tables.
 -- Partition pruning left three shards for the lineitem and one shard for the
 -- orders table. These shard sets don't overlap, so join pruning should prune
 -- out all the shards, and leave us with an empty task list.
+select * from pg_dist_shard
+where logicalrelid='lineitem'::regclass or
+	logicalrelid='orders'::regclass
+order by shardid;
+ logicalrelid | shardid | shardstorage | shardminvalue | shardmaxvalue
+---------------------------------------------------------------------
+ lineitem     |  290000 | t            | 1             | 5986
+ lineitem     |  290001 | t            | 8997          | 14947
+ orders       |  290002 | t            | 1             | 5986
+ orders       |  290003 | t            | 8997          | 14947
+(4 rows)
+
+set citus.explain_distributed_queries to on;
+-- explain the query before actually executing it
+EXPLAIN SELECT sum(l_linenumber), avg(l_linenumber) FROM lineitem, orders
+	WHERE l_orderkey = o_orderkey AND l_orderkey > 6000 AND o_orderkey < 6000;
+DEBUG: Router planner does not support append-partitioned tables.
+DEBUG: join prunable for intervals [8997,14947] and [1,5986]
+ QUERY PLAN
+---------------------------------------------------------------------
+ Aggregate (cost=750.01..750.02 rows=1 width=40)
+   -> Custom Scan (Citus Adaptive) (cost=0.00..0.00 rows=100000 width=24)
+        Task Count: 0
+        Tasks Shown: All
+(4 rows)
+
+set citus.explain_distributed_queries to off;
+set client_min_messages to debug3;
 SELECT sum(l_linenumber), avg(l_linenumber) FROM lineitem, orders
 	WHERE l_orderkey = o_orderkey AND l_orderkey > 6000 AND o_orderkey < 6000;
 DEBUG: Router planner does not support append-partitioned tables.
+DEBUG: constraint (gt) value: '6000'::bigint
+DEBUG: shard count after pruning for lineitem: 1
+DEBUG: constraint (lt) value: '6000'::bigint
+DEBUG: shard count after pruning for orders: 1
+DEBUG: join prunable for intervals [8997,14947] and [1,5986]
  sum | avg
 ---------------------------------------------------------------------
     |
 (1 row)
 
+set client_min_messages to debug2;
 -- Make sure that we can handle filters without a column
 SELECT sum(l_linenumber), avg(l_linenumber) FROM lineitem, orders
 	WHERE l_orderkey = o_orderkey AND false;
diff --git a/src/test/regress/expected/multi_prune_shard_list.out b/src/test/regress/expected/multi_prune_shard_list.out
index 3aa44e766..26adfa0c4 100644
--- a/src/test/regress/expected/multi_prune_shard_list.out
+++ b/src/test/regress/expected/multi_prune_shard_list.out
@@ -84,7 +84,7 @@ SELECT prune_using_both_values('pruning', 'tomato', 'rose');
 -- unit test of the equality expression generation code
 SELECT debug_equality_expression('pruning');
- debug_equality_expression
+ debug_equality_expression
 ---------------------------------------------------------------------
  {OPEXPR :opno 98 :opfuncid 67 :opresulttype 16 :opretset false :opcollid 0 :inputcollid 100 :args ({VAR :varno 1 :varattno 1 :vartype 25 :vartypmod -1 :varcollid 100 :varlevelsup 0 :varnoold 1 :varoattno 1 :location -1} {CONST :consttype 25 :consttypmod -1 :constcollid 100 :constlen -1 :constbyval false :constisnull true :location -1 :constvalue <>}) :location -1}
 (1 row)
 
@@ -543,6 +543,154 @@ SELECT * FROM numeric_test WHERE id = 21.1::numeric;
  21.1 |    87
 (1 row)
 
+CREATE TABLE range_dist_table_1 (dist_col BIGINT);
+SELECT create_distributed_table('range_dist_table_1', 'dist_col', 'range');
+ create_distributed_table
+---------------------------------------------------------------------
+
+(1 row)
+
+CALL public.create_range_partitioned_shards('range_dist_table_1', '{1000,3000,6000}', '{2000,4000,7000}');
+INSERT INTO range_dist_table_1 VALUES (1001);
+INSERT INTO range_dist_table_1 VALUES (3800);
+INSERT INTO range_dist_table_1 VALUES (6500);
+-- all were returning false before fixing #5077
+SELECT SUM(dist_col)=3800+6500 FROM range_dist_table_1 WHERE dist_col >= 2999;
+ ?column?
+---------------------------------------------------------------------
+ t
+(1 row)
+
+SELECT SUM(dist_col)=3800+6500 FROM range_dist_table_1 WHERE dist_col > 2999;
+ ?column?
+---------------------------------------------------------------------
+ t
+(1 row)
+
+SELECT SUM(dist_col)=3800+6500 FROM range_dist_table_1 WHERE dist_col >= 2500;
+ ?column?
+---------------------------------------------------------------------
+ t
+(1 row)
+
+SELECT SUM(dist_col)=3800+6500 FROM range_dist_table_1 WHERE dist_col > 2000;
+ ?column?
+---------------------------------------------------------------------
+ t
+(1 row)
+
+SELECT SUM(dist_col)=3800+6500 FROM range_dist_table_1 WHERE dist_col > 1001;
+ ?column?
+---------------------------------------------------------------------
+ t
+(1 row)
+
+SELECT SUM(dist_col)=1001+3800+6500 FROM range_dist_table_1 WHERE dist_col >= 1001;
+ ?column?
+---------------------------------------------------------------------
+ t
+(1 row)
+
+SELECT SUM(dist_col)=1001+3800+6500 FROM range_dist_table_1 WHERE dist_col > 1000;
+ ?column?
+---------------------------------------------------------------------
+ t
+(1 row)
+
+SELECT SUM(dist_col)=1001+3800+6500 FROM range_dist_table_1 WHERE dist_col >= 1000;
+ ?column?
+---------------------------------------------------------------------
+ t
+(1 row)
+
+-- we didn't have such an off-by-one error in upper bound
+-- calculation, but let's test such cases too
+SELECT SUM(dist_col)=1001+3800 FROM range_dist_table_1 WHERE dist_col <= 4001;
+ ?column?
+---------------------------------------------------------------------
+ t
+(1 row)
+
+SELECT SUM(dist_col)=1001+3800 FROM range_dist_table_1 WHERE dist_col < 4001;
+ ?column?
+---------------------------------------------------------------------
+ t
+(1 row)
+
+SELECT SUM(dist_col)=1001+3800 FROM range_dist_table_1 WHERE dist_col <= 4500;
+ ?column?
+---------------------------------------------------------------------
+ t
+(1 row)
+
+SELECT SUM(dist_col)=1001+3800 FROM range_dist_table_1 WHERE dist_col < 6000;
+ ?column?
+---------------------------------------------------------------------
+ t
+(1 row)
+
+-- now test with composite type and more shards
+CREATE TYPE comp_type AS (
+	int_field_1 BIGINT,
+	int_field_2 BIGINT
+);
+CREATE TYPE comp_type_range AS RANGE (
+	subtype = comp_type);
+CREATE TABLE range_dist_table_2 (dist_col comp_type);
+SELECT create_distributed_table('range_dist_table_2', 'dist_col', 'range');
+ create_distributed_table
+---------------------------------------------------------------------
+
+(1 row)
+
+CALL public.create_range_partitioned_shards(
+	'range_dist_table_2',
+	'{"(10,24)","(10,58)",
+	  "(10,90)","(20,100)"}',
+	'{"(10,25)","(10,65)",
+	  "(10,99)","(20,100)"}');
+INSERT INTO range_dist_table_2 VALUES ((10, 24));
+INSERT INTO range_dist_table_2 VALUES ((10, 60));
+INSERT INTO range_dist_table_2 VALUES ((10, 91));
+INSERT INTO range_dist_table_2 VALUES ((20, 100));
+SELECT dist_col='(10, 60)'::comp_type FROM range_dist_table_2
+WHERE dist_col >= '(10,26)'::comp_type AND
+	dist_col <= '(10,75)'::comp_type;
+ ?column?
+---------------------------------------------------------------------
+ t
+(1 row)
+
+SELECT * FROM range_dist_table_2
+WHERE dist_col >= '(10,57)'::comp_type AND
+	dist_col <= '(10,95)'::comp_type
+ORDER BY dist_col;
+ dist_col
+---------------------------------------------------------------------
+ (10,60)
+ (10,91)
+(2 rows)
+
+SELECT * FROM range_dist_table_2
+WHERE dist_col >= '(10,57)'::comp_type
+ORDER BY dist_col;
+ dist_col
+---------------------------------------------------------------------
+ (10,60)
+ (10,91)
+ (20,100)
+(3 rows)
+
+SELECT dist_col='(20,100)'::comp_type FROM range_dist_table_2
+WHERE dist_col > '(20,99)'::comp_type;
+ ?column?
+---------------------------------------------------------------------
+ t
+(1 row)
+
+DROP TABLE range_dist_table_1, range_dist_table_2;
+DROP TYPE comp_type CASCADE;
+NOTICE: drop cascades to type comp_type_range
 SET search_path TO public;
 DROP SCHEMA prune_shard_list CASCADE;
 NOTICE: drop cascades to 10 other objects
diff --git a/src/test/regress/sql/multi_join_pruning.sql b/src/test/regress/sql/multi_join_pruning.sql
index 69c722be7..b9feb999c 100644
--- a/src/test/regress/sql/multi_join_pruning.sql
+++ b/src/test/regress/sql/multi_join_pruning.sql
@@ -27,8 +27,21 @@ SELECT sum(l_linenumber), avg(l_linenumber) FROM lineitem, orders
 -- orders table. These shard sets don't overlap, so join pruning should prune
 -- out all the shards, and leave us with an empty task list.
 
+select * from pg_dist_shard
+where logicalrelid='lineitem'::regclass or
+	logicalrelid='orders'::regclass
+order by shardid;
+
+set citus.explain_distributed_queries to on;
+-- explain the query before actually executing it
+EXPLAIN SELECT sum(l_linenumber), avg(l_linenumber) FROM lineitem, orders
+	WHERE l_orderkey = o_orderkey AND l_orderkey > 6000 AND o_orderkey < 6000;
+set citus.explain_distributed_queries to off;
+
+set client_min_messages to debug3;
 SELECT sum(l_linenumber), avg(l_linenumber) FROM lineitem, orders
 	WHERE l_orderkey = o_orderkey AND l_orderkey > 6000 AND o_orderkey < 6000;
+set client_min_messages to debug2;
 
 -- Make sure that we can handle filters without a column
 SELECT sum(l_linenumber), avg(l_linenumber) FROM lineitem, orders
diff --git a/src/test/regress/sql/multi_prune_shard_list.sql b/src/test/regress/sql/multi_prune_shard_list.sql
index 224a759cf..651d84992 100644
--- a/src/test/regress/sql/multi_prune_shard_list.sql
+++ b/src/test/regress/sql/multi_prune_shard_list.sql
@@ -218,5 +218,75 @@ SELECT * FROM numeric_test WHERE id = 21;
 SELECT * FROM numeric_test WHERE id = 21::numeric;
 SELECT * FROM numeric_test WHERE id = 21.1::numeric;
 
+CREATE TABLE range_dist_table_1 (dist_col BIGINT);
+SELECT create_distributed_table('range_dist_table_1', 'dist_col', 'range');
+
+CALL public.create_range_partitioned_shards('range_dist_table_1', '{1000,3000,6000}', '{2000,4000,7000}');
+
+INSERT INTO range_dist_table_1 VALUES (1001);
+INSERT INTO range_dist_table_1 VALUES (3800);
+INSERT INTO range_dist_table_1 VALUES (6500);
+
+-- all were returning false before fixing #5077
+SELECT SUM(dist_col)=3800+6500 FROM range_dist_table_1 WHERE dist_col >= 2999;
+SELECT SUM(dist_col)=3800+6500 FROM range_dist_table_1 WHERE dist_col > 2999;
+SELECT SUM(dist_col)=3800+6500 FROM range_dist_table_1 WHERE dist_col >= 2500;
+SELECT SUM(dist_col)=3800+6500 FROM range_dist_table_1 WHERE dist_col > 2000;
+
+SELECT SUM(dist_col)=3800+6500 FROM range_dist_table_1 WHERE dist_col > 1001;
+SELECT SUM(dist_col)=1001+3800+6500 FROM range_dist_table_1 WHERE dist_col >= 1001;
+SELECT SUM(dist_col)=1001+3800+6500 FROM range_dist_table_1 WHERE dist_col > 1000;
+SELECT SUM(dist_col)=1001+3800+6500 FROM range_dist_table_1 WHERE dist_col >= 1000;
+
+-- we didn't have such an off-by-one error in upper bound
+-- calculation, but let's test such cases too
+SELECT SUM(dist_col)=1001+3800 FROM range_dist_table_1 WHERE dist_col <= 4001;
+SELECT SUM(dist_col)=1001+3800 FROM range_dist_table_1 WHERE dist_col < 4001;
+SELECT SUM(dist_col)=1001+3800 FROM range_dist_table_1 WHERE dist_col <= 4500;
+SELECT SUM(dist_col)=1001+3800 FROM range_dist_table_1 WHERE dist_col < 6000;
+
+-- now test with composite type and more shards
+CREATE TYPE comp_type AS (
+	int_field_1 BIGINT,
+	int_field_2 BIGINT
+);
+
+CREATE TYPE comp_type_range AS RANGE (
+	subtype = comp_type);
+
+CREATE TABLE range_dist_table_2 (dist_col comp_type);
+SELECT create_distributed_table('range_dist_table_2', 'dist_col', 'range');
+
+CALL public.create_range_partitioned_shards(
+	'range_dist_table_2',
+	'{"(10,24)","(10,58)",
+	  "(10,90)","(20,100)"}',
+	'{"(10,25)","(10,65)",
+	  "(10,99)","(20,100)"}');
+
+INSERT INTO range_dist_table_2 VALUES ((10, 24));
+INSERT INTO range_dist_table_2 VALUES ((10, 60));
+INSERT INTO range_dist_table_2 VALUES ((10, 91));
+INSERT INTO range_dist_table_2 VALUES ((20, 100));
+
+SELECT dist_col='(10, 60)'::comp_type FROM range_dist_table_2
+WHERE dist_col >= '(10,26)'::comp_type AND
+	dist_col <= '(10,75)'::comp_type;
+
+SELECT * FROM range_dist_table_2
+WHERE dist_col >= '(10,57)'::comp_type AND
+	dist_col <= '(10,95)'::comp_type
+ORDER BY dist_col;
+
+SELECT * FROM range_dist_table_2
+WHERE dist_col >= '(10,57)'::comp_type
+ORDER BY dist_col;
+
+SELECT dist_col='(20,100)'::comp_type FROM range_dist_table_2
+WHERE dist_col > '(20,99)'::comp_type;
+
+DROP TABLE range_dist_table_1, range_dist_table_2;
+DROP TYPE comp_type CASCADE;
+
 SET search_path TO public;
 DROP SCHEMA prune_shard_list CASCADE;
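For illustration, the following is a minimal, standalone sketch of the lower-bound search that the new comments in shard_pruning.c describe, and of the off-by-one that the patch removes (the old "return lowerBoundIndex + 1;"). It is not part of the patch: it uses plain integer intervals in place of Datums, the comparator FunctionCallInfo, and the includeMin/includeMax handling of the real LowerShardBoundary, and the ShardBounds struct and LowerBoundSketch name are hypothetical, made up for this sketch only. The shard layout mirrors range_dist_table_1 from the regression test above.

/* sketch.c -- simplified illustration only, not Citus code */
#include <stdio.h>

#define INVALID_SHARD_INDEX (-1)

/* hypothetical stand-in for the min/max values of a ShardInterval */
typedef struct ShardBounds
{
	int min;
	int max;
} ShardBounds;

/*
 * Return the index of the first shard that may satisfy "partCol >= value"
 * (and hence also "partCol > value"), or INVALID_SHARD_INDEX if value is
 * greater than every shard's max (possibility 3 in the comments above).
 */
static int
LowerBoundSketch(int value, const ShardBounds *shards, int shardCount)
{
	int lowerBoundIndex = 0;
	int upperBoundIndex = shardCount;

	while (lowerBoundIndex < upperBoundIndex)
	{
		int middleIndex = lowerBoundIndex + ((upperBoundIndex - lowerBoundIndex) / 2);

		if (value > shards[middleIndex].max)
		{
			/* shard lies entirely below value; the lower bound is to its right */
			lowerBoundIndex = middleIndex + 1;
		}
		else if (value < shards[middleIndex].min)
		{
			/* shard lies entirely above value; it is a candidate lower bound */
			upperBoundIndex = middleIndex;
		}
		else
		{
			/* value falls inside this shard: possibility 1 */
			return middleIndex;
		}
	}

	/* lowerBoundIndex is inclusive, so shardCount means possibility 3 */
	if (lowerBoundIndex == shardCount)
	{
		return INVALID_SHARD_INDEX;
	}

	/* possibilities 2 and 4: lowerBoundIndex already points at the answer */
	return lowerBoundIndex;
}

int
main(void)
{
	/* same layout as range_dist_table_1: [1000,2000], [3000,4000], [6000,7000] */
	ShardBounds shards[] = { { 1000, 2000 }, { 3000, 4000 }, { 6000, 7000 } };

	/*
	 * dist_col >= 2999 must start at shard 1 ([3000,4000]); the pre-patch
	 * "lowerBoundIndex + 1" would have skipped it, which is why the
	 * dist_col >= 2999 query above missed the 3800 row before the fix.
	 */
	printf("lower bound for 2999: %d\n", LowerBoundSketch(2999, shards, 3));

	/* a value above every shard's max prunes all shards */
	printf("lower bound for 8000: %d\n", LowerBoundSketch(8000, shards, 3));

	return 0;
}

Printing "1" for 2999 and "-1" for 8000 corresponds to the post-patch behaviour exercised by the range_dist_table_1 tests above; the real function additionally threads the comparator and the inclusive/exclusive bound flags through the same structure.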