Fix lower boundary calculation when pruning range dist table shards (#5082)

This happens only when we have a ">" or ">=" filter on the distribution
column of a range-distributed table and the filter value falls in between
two shards.

When the filter falls in between two shards:

  If the filter is "<" or "<=", then UpperShardBoundary was
  returning "upperBoundIndex - 1", where upperBoundIndex is
  the exclusive shard index used during the binary search.
  This is expected, since upperBoundIndex is an exclusive
  index.

  If the filter is ">" or ">=", then LowerShardBoundary was
  returning "lowerBoundIndex + 1", where lowerBoundIndex is
  the inclusive shard index used during the binary search.
  Since lowerBoundIndex is already an inclusive index, we
  should return lowerBoundIndex itself instead of adding
  "+ 1". Before this commit, such queries were missing the
  leftmost matching shard (see the sketch below).
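To make the off-by-one concrete, here is a minimal, self-contained C sketch
of the same lower-boundary search. It is a simplified stand-in for
LowerShardBoundary, not the Citus code itself: "Interval" and "LowerBound"
are hypothetical names, and plain longs replace Datums and comparator calls.

#include <stdio.h>

typedef struct { long min; long max; } Interval;

#define INVALID_SHARD_INDEX -1

/*
 * Simplified stand-in for LowerShardBoundary: return the index of the
 * leftmost shard that can satisfy "col >= value" (or "col > value"),
 * or INVALID_SHARD_INDEX if no shard can.
 */
static int
LowerBound(Interval *shards, int shardCount, long value)
{
	int lowerBoundIndex = 0;          /* inclusive */
	int upperBoundIndex = shardCount; /* exclusive */

	while (lowerBoundIndex < upperBoundIndex)
	{
		int middleIndex = lowerBoundIndex + (upperBoundIndex - lowerBoundIndex) / 2;

		if (value > shards[middleIndex].max)
		{
			/* value is to the right of this shard, discard it */
			lowerBoundIndex = middleIndex + 1;
		}
		else if (value < shards[middleIndex].min)
		{
			/* value is to the left, but this shard stays a candidate */
			upperBoundIndex = middleIndex;
		}
		else
		{
			/* value falls inside this shard (possibility 1) */
			return middleIndex;
		}
	}

	if (lowerBoundIndex == shardCount)
	{
		/* value is greater than every shard's max */
		return INVALID_SHARD_INDEX;
	}

	/*
	 * value is smaller than all shards or falls between two shards;
	 * lowerBoundIndex is inclusive, so return it as-is. The old
	 * "lowerBoundIndex + 1" skipped the leftmost matching shard here,
	 * and it also made a separate "lowerBoundIndex == 0" branch redundant.
	 */
	return lowerBoundIndex;
}

int
main(void)
{
	/* shard intervals as in the new regression test */
	Interval shards[] = { { 1000, 2000 }, { 3000, 4000 }, { 6000, 7000 } };

	/* 2999 falls between shard 0 and shard 1; the correct answer is 1 */
	printf("%d\n", LowerBound(shards, 3, 2999));
	return 0;
}

With the old "lowerBoundIndex + 1" in the final return, this sketch would
print 2 instead of 1, i.e. a "WHERE dist_col >= 2999" filter would miss the
[3000, 4000] shard, matching the new regression test cases below.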

* Remove useless conditional branches

The branch deleted from UpperShardBoundary was already redundant: when
upperBoundIndex == shardCount, the fall-through "return upperBoundIndex - 1"
yields shardCount - 1 anyway.

The branch in LowerShardBoundary became redundant once the "+ 1" was
removed, since "return lowerBoundIndex" already returns 0 when
lowerBoundIndex == 0.

This is further evidence for what this PR fixes and how, as the check below
demonstrates.
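As a hedged illustration only (not code from the commit), the tails of the
two functions can be checked for equivalence with the branches removed;
"UpperTailOld" and friends are hypothetical helper names used just here.

#include <assert.h>

/* Tail of UpperShardBoundary: deleted branch vs. plain fall-through */
static int
UpperTailOld(int upperBoundIndex, int shardCount)
{
	if (upperBoundIndex == shardCount)
	{
		return shardCount - 1;
	}
	return upperBoundIndex - 1;
}

static int
UpperTailNew(int upperBoundIndex, int shardCount)
{
	(void) shardCount;
	return upperBoundIndex - 1;
}

/* Tail of LowerShardBoundary once the "+ 1" is gone */
static int
LowerTailOld(int lowerBoundIndex)
{
	if (lowerBoundIndex == 0)
	{
		return 0;
	}
	return lowerBoundIndex;
}

static int
LowerTailNew(int lowerBoundIndex)
{
	return lowerBoundIndex;
}

int
main(void)
{
	for (int shardCount = 1; shardCount <= 8; shardCount++)
	{
		for (int boundIndex = 0; boundIndex <= shardCount; boundIndex++)
		{
			assert(UpperTailOld(boundIndex, shardCount) == UpperTailNew(boundIndex, shardCount));
			assert(LowerTailOld(boundIndex) == LowerTailNew(boundIndex));
		}
	}
	return 0;
}

Compiling and running this exits cleanly, confirming both deleted branches
only duplicated the fall-through returns.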

* Improve comments and add more

* Add some tests for upper bound calculation too

(cherry picked from commit b118d4188e)
pull/5098/head
Onur Tirtir 2021-07-02 14:48:21 +03:00
parent 9efd8e05d6
commit 3f6e903722
5 changed files with 338 additions and 20 deletions


@@ -1571,6 +1571,22 @@ LowerShardBoundary(Datum partitionColumnValue, ShardInterval **shardIntervalCach
/* setup partitionColumnValue argument once */
fcSetArg(compareFunction, 0, partitionColumnValue);
/*
* Now we test partitionColumnValue used in where clause such as
* partCol > partitionColumnValue (or partCol >= partitionColumnValue)
* against four possibilities, these are:
* 1) partitionColumnValue falls into a specific shard, such that:
* partitionColumnValue >= shard[x].min, and
* partitionColumnValue < shard[x].max (or partitionColumnValue <= shard[x].max).
* 2) partitionColumnValue < shard[x].min for all the shards
* 3) partitionColumnValue > shard[x].max for all the shards
* 4) partitionColumnValue falls in between two shards, such that:
* partitionColumnValue > shard[x].max and
* partitionColumnValue < shard[x+1].min
*
* For 1), we find that shard in below loop using binary search and
* return the index of it. For the others, see the end of this function.
*/
while (lowerBoundIndex < upperBoundIndex)
{
int middleIndex = lowerBoundIndex + ((upperBoundIndex - lowerBoundIndex) / 2);
@@ -1603,7 +1619,7 @@ LowerShardBoundary(Datum partitionColumnValue, ShardInterval **shardIntervalCach
continue;
}
/* found interval containing partitionValue */
/* partitionColumnValue falls into a specific shard, possibility 1) */
return middleIndex;
}
@@ -1614,20 +1630,30 @@ LowerShardBoundary(Datum partitionColumnValue, ShardInterval **shardIntervalCach
* (we'd have hit the return middleIndex; case otherwise). Figure out
* whether there's possibly any interval containing a value that's bigger
* than the partition key one.
*
* Also note that we initialized lowerBoundIndex with 0. Similarly,
* we always set it to the index of the shard that we consider as our
* lower boundary during binary search.
*/
if (lowerBoundIndex == 0)
if (lowerBoundIndex == shardCount)
{
/* all intervals are bigger, thus return 0 */
return 0;
}
else if (lowerBoundIndex == shardCount)
{
/* partition value is bigger than all partition values */
/*
* Since lowerBoundIndex is an inclusive index, being equal to shardCount
* means all the shards have smaller values than partitionColumnValue,
* which corresponds to possibility 3).
* In that case, since we can't have a lower bound shard, we return
* INVALID_SHARD_INDEX here.
*/
return INVALID_SHARD_INDEX;
}
/* value falls inbetween intervals */
return lowerBoundIndex + 1;
/*
* partitionColumnValue is either smaller than all the shards or falls in
* between two shards, which corresponds to possibility 2) or 4).
* Knowing that lowerBoundIndex is an inclusive index, we directly return
* it as the index for the lower bound shard here.
*/
return lowerBoundIndex;
}
@@ -1647,6 +1673,23 @@ UpperShardBoundary(Datum partitionColumnValue, ShardInterval **shardIntervalCach
/* setup partitionColumnValue argument once */
fcSetArg(compareFunction, 0, partitionColumnValue);
/*
* Now we test partitionColumnValue used in where clause such as
* partCol < partitionColumnValue (or partCol <= partitionColumnValue)
* against four possibilities, these are:
* 1) partitionColumnValue falls into a specific shard, such that:
* partitionColumnValue <= shard[x].max, and
* partitionColumnValue > shard[x].min (or partitionColumnValue >= shard[x].min).
* 2) partitionColumnValue > shard[x].max for all the shards
* 3) partitionColumnValue < shard[x].min for all the shards
* 4) partitionColumnValue falls in between two shards, such that:
* partitionColumnValue > shard[x].max and
* partitionColumnValue < shard[x+1].min
*
* For 1), we find that shard in below loop using binary search and
* return the index of it. For the others, see the end of this function.
*/
while (lowerBoundIndex < upperBoundIndex)
{
int middleIndex = lowerBoundIndex + ((upperBoundIndex - lowerBoundIndex) / 2);
@@ -1679,7 +1722,7 @@ UpperShardBoundary(Datum partitionColumnValue, ShardInterval **shardIntervalCach
continue;
}
/* found interval containing partitionValue */
/* partitionColumnValue falls into a specific shard, possibility 1) */
return middleIndex;
}
@@ -1690,19 +1733,29 @@ UpperShardBoundary(Datum partitionColumnValue, ShardInterval **shardIntervalCach
* (we'd have hit the return middleIndex; case otherwise). Figure out
* whether there's possibly any interval containing a value that's smaller
* than the partition key one.
*
* Also note that we initialized upperBoundIndex with shardCount. Similarly,
* we always set it to the index of the next shard that we consider as our
* upper boundary during binary search.
*/
if (upperBoundIndex == shardCount)
if (upperBoundIndex == 0)
{
/* all intervals are smaller, thus return 0 */
return shardCount - 1;
}
else if (upperBoundIndex == 0)
{
/* partition value is smaller than all partition values */
/*
* Since upperBoundIndex is an exclusive index, being equal to 0 means
* all the shards have greater values than partitionColumnValue, which
* corresponds to possibility 3).
* In that case, since we can't have an upper bound shard, we return
* INVALID_SHARD_INDEX here.
*/
return INVALID_SHARD_INDEX;
}
/* value falls inbetween intervals, return the inverval one smaller as bound */
/*
* partitionColumnValue is either greater than all the shards or falls in
* between two shards, which corresponds to possibility 2) or 4).
* Knowing that upperBoundIndex is an exclusive index, we return the index
* for the previous shard here.
*/
return upperBoundIndex - 1;
}


@@ -38,14 +38,48 @@ DEBUG: Router planner does not support append-partitioned tables.
-- Partition pruning left three shards for the lineitem and one shard for the
-- orders table. These shard sets don't overlap, so join pruning should prune
-- out all the shards, and leave us with an empty task list.
select * from pg_dist_shard
where logicalrelid='lineitem'::regclass or
logicalrelid='orders'::regclass
order by shardid;
logicalrelid | shardid | shardstorage | shardminvalue | shardmaxvalue
---------------------------------------------------------------------
lineitem | 290000 | t | 1 | 5986
lineitem | 290001 | t | 8997 | 14947
orders | 290002 | t | 1 | 5986
orders | 290003 | t | 8997 | 14947
(4 rows)
set citus.explain_distributed_queries to on;
-- explain the query before actually executing it
EXPLAIN SELECT sum(l_linenumber), avg(l_linenumber) FROM lineitem, orders
WHERE l_orderkey = o_orderkey AND l_orderkey > 6000 AND o_orderkey < 6000;
DEBUG: Router planner does not support append-partitioned tables.
DEBUG: join prunable for intervals [8997,14947] and [1,5986]
QUERY PLAN
---------------------------------------------------------------------
Aggregate (cost=750.01..750.02 rows=1 width=40)
-> Custom Scan (Citus Adaptive) (cost=0.00..0.00 rows=100000 width=24)
Task Count: 0
Tasks Shown: All
(4 rows)
set citus.explain_distributed_queries to off;
set client_min_messages to debug3;
SELECT sum(l_linenumber), avg(l_linenumber) FROM lineitem, orders
WHERE l_orderkey = o_orderkey AND l_orderkey > 6000 AND o_orderkey < 6000;
DEBUG: Router planner does not support append-partitioned tables.
DEBUG: constraint (gt) value: '6000'::bigint
DEBUG: shard count after pruning for lineitem: 1
DEBUG: constraint (lt) value: '6000'::bigint
DEBUG: shard count after pruning for orders: 1
DEBUG: join prunable for intervals [8997,14947] and [1,5986]
sum | avg
---------------------------------------------------------------------
|
(1 row)
set client_min_messages to debug2;
-- Make sure that we can handle filters without a column
SELECT sum(l_linenumber), avg(l_linenumber) FROM lineitem, orders
WHERE l_orderkey = o_orderkey AND false;


@@ -84,7 +84,7 @@ SELECT prune_using_both_values('pruning', 'tomato', 'rose');
-- unit test of the equality expression generation code
SELECT debug_equality_expression('pruning');
debug_equality_expression
debug_equality_expression
---------------------------------------------------------------------
{OPEXPR :opno 98 :opfuncid 67 :opresulttype 16 :opretset false :opcollid 0 :inputcollid 100 :args ({VAR :varno 1 :varattno 1 :vartype 25 :vartypmod -1 :varcollid 100 :varlevelsup 0 :varnoold 1 :varoattno 1 :location -1} {CONST :consttype 25 :consttypmod -1 :constcollid 100 :constlen -1 :constbyval false :constisnull true :location -1 :constvalue <>}) :location -1}
(1 row)
@@ -543,6 +543,154 @@ SELECT * FROM numeric_test WHERE id = 21.1::numeric;
21.1 | 87
(1 row)
CREATE TABLE range_dist_table_1 (dist_col BIGINT);
SELECT create_distributed_table('range_dist_table_1', 'dist_col', 'range');
create_distributed_table
---------------------------------------------------------------------
(1 row)
CALL public.create_range_partitioned_shards('range_dist_table_1', '{1000,3000,6000}', '{2000,4000,7000}');
INSERT INTO range_dist_table_1 VALUES (1001);
INSERT INTO range_dist_table_1 VALUES (3800);
INSERT INTO range_dist_table_1 VALUES (6500);
-- all were returning false before fixing #5077
SELECT SUM(dist_col)=3800+6500 FROM range_dist_table_1 WHERE dist_col >= 2999;
?column?
---------------------------------------------------------------------
t
(1 row)
SELECT SUM(dist_col)=3800+6500 FROM range_dist_table_1 WHERE dist_col > 2999;
?column?
---------------------------------------------------------------------
t
(1 row)
SELECT SUM(dist_col)=3800+6500 FROM range_dist_table_1 WHERE dist_col >= 2500;
?column?
---------------------------------------------------------------------
t
(1 row)
SELECT SUM(dist_col)=3800+6500 FROM range_dist_table_1 WHERE dist_col > 2000;
?column?
---------------------------------------------------------------------
t
(1 row)
SELECT SUM(dist_col)=3800+6500 FROM range_dist_table_1 WHERE dist_col > 1001;
?column?
---------------------------------------------------------------------
t
(1 row)
SELECT SUM(dist_col)=1001+3800+6500 FROM range_dist_table_1 WHERE dist_col >= 1001;
?column?
---------------------------------------------------------------------
t
(1 row)
SELECT SUM(dist_col)=1001+3800+6500 FROM range_dist_table_1 WHERE dist_col > 1000;
?column?
---------------------------------------------------------------------
t
(1 row)
SELECT SUM(dist_col)=1001+3800+6500 FROM range_dist_table_1 WHERE dist_col >= 1000;
?column?
---------------------------------------------------------------------
t
(1 row)
-- we didn't have such an off-by-one error in upper bound
-- calculation, but let's test such cases too
SELECT SUM(dist_col)=1001+3800 FROM range_dist_table_1 WHERE dist_col <= 4001;
?column?
---------------------------------------------------------------------
t
(1 row)
SELECT SUM(dist_col)=1001+3800 FROM range_dist_table_1 WHERE dist_col < 4001;
?column?
---------------------------------------------------------------------
t
(1 row)
SELECT SUM(dist_col)=1001+3800 FROM range_dist_table_1 WHERE dist_col <= 4500;
?column?
---------------------------------------------------------------------
t
(1 row)
SELECT SUM(dist_col)=1001+3800 FROM range_dist_table_1 WHERE dist_col < 6000;
?column?
---------------------------------------------------------------------
t
(1 row)
-- now test with composite type and more shards
CREATE TYPE comp_type AS (
int_field_1 BIGINT,
int_field_2 BIGINT
);
CREATE TYPE comp_type_range AS RANGE (
subtype = comp_type);
CREATE TABLE range_dist_table_2 (dist_col comp_type);
SELECT create_distributed_table('range_dist_table_2', 'dist_col', 'range');
create_distributed_table
---------------------------------------------------------------------
(1 row)
CALL public.create_range_partitioned_shards(
'range_dist_table_2',
'{"(10,24)","(10,58)",
"(10,90)","(20,100)"}',
'{"(10,25)","(10,65)",
"(10,99)","(20,100)"}');
INSERT INTO range_dist_table_2 VALUES ((10, 24));
INSERT INTO range_dist_table_2 VALUES ((10, 60));
INSERT INTO range_dist_table_2 VALUES ((10, 91));
INSERT INTO range_dist_table_2 VALUES ((20, 100));
SELECT dist_col='(10, 60)'::comp_type FROM range_dist_table_2
WHERE dist_col >= '(10,26)'::comp_type AND
dist_col <= '(10,75)'::comp_type;
?column?
---------------------------------------------------------------------
t
(1 row)
SELECT * FROM range_dist_table_2
WHERE dist_col >= '(10,57)'::comp_type AND
dist_col <= '(10,95)'::comp_type
ORDER BY dist_col;
dist_col
---------------------------------------------------------------------
(10,60)
(10,91)
(2 rows)
SELECT * FROM range_dist_table_2
WHERE dist_col >= '(10,57)'::comp_type
ORDER BY dist_col;
dist_col
---------------------------------------------------------------------
(10,60)
(10,91)
(20,100)
(3 rows)
SELECT dist_col='(20,100)'::comp_type FROM range_dist_table_2
WHERE dist_col > '(20,99)'::comp_type;
?column?
---------------------------------------------------------------------
t
(1 row)
DROP TABLE range_dist_table_1, range_dist_table_2;
DROP TYPE comp_type CASCADE;
NOTICE: drop cascades to type comp_type_range
SET search_path TO public;
DROP SCHEMA prune_shard_list CASCADE;
NOTICE: drop cascades to 10 other objects


@@ -27,8 +27,21 @@ SELECT sum(l_linenumber), avg(l_linenumber) FROM lineitem, orders
-- orders table. These shard sets don't overlap, so join pruning should prune
-- out all the shards, and leave us with an empty task list.
select * from pg_dist_shard
where logicalrelid='lineitem'::regclass or
logicalrelid='orders'::regclass
order by shardid;
set citus.explain_distributed_queries to on;
-- explain the query before actually executing it
EXPLAIN SELECT sum(l_linenumber), avg(l_linenumber) FROM lineitem, orders
WHERE l_orderkey = o_orderkey AND l_orderkey > 6000 AND o_orderkey < 6000;
set citus.explain_distributed_queries to off;
set client_min_messages to debug3;
SELECT sum(l_linenumber), avg(l_linenumber) FROM lineitem, orders
WHERE l_orderkey = o_orderkey AND l_orderkey > 6000 AND o_orderkey < 6000;
set client_min_messages to debug2;
-- Make sure that we can handle filters without a column
SELECT sum(l_linenumber), avg(l_linenumber) FROM lineitem, orders


@@ -218,5 +218,75 @@ SELECT * FROM numeric_test WHERE id = 21;
SELECT * FROM numeric_test WHERE id = 21::numeric;
SELECT * FROM numeric_test WHERE id = 21.1::numeric;
CREATE TABLE range_dist_table_1 (dist_col BIGINT);
SELECT create_distributed_table('range_dist_table_1', 'dist_col', 'range');
CALL public.create_range_partitioned_shards('range_dist_table_1', '{1000,3000,6000}', '{2000,4000,7000}');
INSERT INTO range_dist_table_1 VALUES (1001);
INSERT INTO range_dist_table_1 VALUES (3800);
INSERT INTO range_dist_table_1 VALUES (6500);
-- all were returning false before fixing #5077
SELECT SUM(dist_col)=3800+6500 FROM range_dist_table_1 WHERE dist_col >= 2999;
SELECT SUM(dist_col)=3800+6500 FROM range_dist_table_1 WHERE dist_col > 2999;
SELECT SUM(dist_col)=3800+6500 FROM range_dist_table_1 WHERE dist_col >= 2500;
SELECT SUM(dist_col)=3800+6500 FROM range_dist_table_1 WHERE dist_col > 2000;
SELECT SUM(dist_col)=3800+6500 FROM range_dist_table_1 WHERE dist_col > 1001;
SELECT SUM(dist_col)=1001+3800+6500 FROM range_dist_table_1 WHERE dist_col >= 1001;
SELECT SUM(dist_col)=1001+3800+6500 FROM range_dist_table_1 WHERE dist_col > 1000;
SELECT SUM(dist_col)=1001+3800+6500 FROM range_dist_table_1 WHERE dist_col >= 1000;
-- we didn't have such an off-by-one error in upper bound
-- calculation, but let's test such cases too
SELECT SUM(dist_col)=1001+3800 FROM range_dist_table_1 WHERE dist_col <= 4001;
SELECT SUM(dist_col)=1001+3800 FROM range_dist_table_1 WHERE dist_col < 4001;
SELECT SUM(dist_col)=1001+3800 FROM range_dist_table_1 WHERE dist_col <= 4500;
SELECT SUM(dist_col)=1001+3800 FROM range_dist_table_1 WHERE dist_col < 6000;
-- now test with composite type and more shards
CREATE TYPE comp_type AS (
int_field_1 BIGINT,
int_field_2 BIGINT
);
CREATE TYPE comp_type_range AS RANGE (
subtype = comp_type);
CREATE TABLE range_dist_table_2 (dist_col comp_type);
SELECT create_distributed_table('range_dist_table_2', 'dist_col', 'range');
CALL public.create_range_partitioned_shards(
'range_dist_table_2',
'{"(10,24)","(10,58)",
"(10,90)","(20,100)"}',
'{"(10,25)","(10,65)",
"(10,99)","(20,100)"}');
INSERT INTO range_dist_table_2 VALUES ((10, 24));
INSERT INTO range_dist_table_2 VALUES ((10, 60));
INSERT INTO range_dist_table_2 VALUES ((10, 91));
INSERT INTO range_dist_table_2 VALUES ((20, 100));
SELECT dist_col='(10, 60)'::comp_type FROM range_dist_table_2
WHERE dist_col >= '(10,26)'::comp_type AND
dist_col <= '(10,75)'::comp_type;
SELECT * FROM range_dist_table_2
WHERE dist_col >= '(10,57)'::comp_type AND
dist_col <= '(10,95)'::comp_type
ORDER BY dist_col;
SELECT * FROM range_dist_table_2
WHERE dist_col >= '(10,57)'::comp_type
ORDER BY dist_col;
SELECT dist_col='(20,100)'::comp_type FROM range_dist_table_2
WHERE dist_col > '(20,99)'::comp_type;
DROP TABLE range_dist_table_1, range_dist_table_2;
DROP TYPE comp_type CASCADE;
SET search_path TO public;
DROP SCHEMA prune_shard_list CASCADE;