From ba93d32c8afb9a7a2fa2c55109f0e94161d663ce Mon Sep 17 00:00:00 2001 From: Andres Freund Date: Thu, 27 Apr 2017 16:03:05 -0700 Subject: [PATCH 1/5] Fix: Make FindShardIntervalIndex robust against 0 shards. --- .../distributed/utils/shardinterval_utils.c | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/backend/distributed/utils/shardinterval_utils.c b/src/backend/distributed/utils/shardinterval_utils.c index 8a63375e8..5133da261 100644 --- a/src/backend/distributed/utils/shardinterval_utils.c +++ b/src/backend/distributed/utils/shardinterval_utils.c @@ -247,11 +247,12 @@ FindShardInterval(Datum partitionColumnValue, DistTableCacheEntry *cacheEntry) * the searched value. Note that the searched value must be the hashed value * of the original value if the distribution method is hash. * - * Note that, if the searched value can not be found for hash partitioned tables, - * we error out. This should only happen if something is terribly wrong, either - * metadata tables are corrupted or we have a bug somewhere. Such as a hash - * function which returns a value not in the range of [INT32_MIN, INT32_MAX] can - * fire this. + * Note that, if the searched value can not be found for hash partitioned + * tables, we error out (unless there are no shards, in which case + * INVALID_SHARD_INDEX is returned). This should only happen if something is + * terribly wrong, either metadata tables are corrupted or we have a bug + * somewhere. Such as a hash function which returns a value not in the range + * of [INT32_MIN, INT32_MAX] can fire this. */ static int FindShardIntervalIndex(Datum searchedValue, DistTableCacheEntry *cacheEntry) @@ -264,6 +265,11 @@ FindShardIntervalIndex(Datum searchedValue, DistTableCacheEntry *cacheEntry) !cacheEntry->hasUniformHashDistribution); int shardIndex = INVALID_SHARD_INDEX; + if (shardCount == 0) + { + return INVALID_SHARD_INDEX; + } + if (partitionMethod == DISTRIBUTE_BY_HASH) { if (useBinarySearch) From 52571c00adeddddeca10d32e71104810caf96970 Mon Sep 17 00:00:00 2001 From: Andres Freund Date: Tue, 25 Apr 2017 20:26:40 -0700 Subject: [PATCH 2/5] Build DistTableCacheEntry->shardIntervalCompareFunction even for 0 shards. Previously we, unnecessarily, used the first shard's type information to look up the comparison function. But that information is already available, so use it. That's helpful because we sometimes want to access the comparator function even if there are no shards.
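For context, the lookup boils down to finding the default btree operator class for the partition column's type and fetching its ordering support procedure (BTORDER_PROC). Below is a minimal, illustrative sketch of such a lookup using PostgreSQL's standard catalog helpers; it is not the GetFunctionInfo() used by the patch, and the helper name is made up:

#include "postgres.h"
#include "access/nbtree.h"
#include "catalog/pg_am.h"
#include "commands/defrem.h"
#include "fmgr.h"
#include "utils/lsyscache.h"

/* hypothetical helper: resolve the btree 3-way comparator for a type OID */
static FmgrInfo *
LookupBtreeComparator(Oid typeId)
{
	/* find the default btree operator class and family for the type */
	Oid opClassId = GetDefaultOpClass(typeId, BTREE_AM_OID);
	Oid opFamilyId = get_opclass_family(opClassId);
	Oid opcInputType = get_opclass_input_type(opClassId);

	/* support function 1 of a btree opclass is the ordering comparator */
	Oid cmpProcId = get_opfamily_proc(opFamilyId, opcInputType, opcInputType,
									  BTORDER_PROC);
	FmgrInfo *cmpFunction = palloc0(sizeof(FmgrInfo));

	if (!OidIsValid(cmpProcId))
	{
		ereport(ERROR, (errmsg("no btree comparator found for type %u", typeId)));
	}

	fmgr_info(cmpProcId, cmpFunction);

	return cmpFunction;
}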
--- .../distributed/utils/metadata_cache.c | 68 ++++--------------- src/include/distributed/metadata_cache.h | 2 +- 2 files changed, 15 insertions(+), 55 deletions(-) diff --git a/src/backend/distributed/utils/metadata_cache.c b/src/backend/distributed/utils/metadata_cache.c index a48405597..710526aaa 100644 --- a/src/backend/distributed/utils/metadata_cache.c +++ b/src/backend/distributed/utils/metadata_cache.c @@ -134,8 +134,6 @@ static ShardCacheEntry * LookupShardCacheEntry(int64 shardId); static DistTableCacheEntry * LookupDistTableCacheEntry(Oid relationId); static void BuildDistTableCacheEntry(DistTableCacheEntry *cacheEntry); static void BuildCachedShardList(DistTableCacheEntry *cacheEntry); -static FmgrInfo * ShardIntervalCompareFunction(ShardInterval **shardIntervalArray, - char partitionMethod); static ShardInterval ** SortShardIntervalArray(ShardInterval **shardIntervalArray, int shardCount, FmgrInfo * @@ -622,6 +620,10 @@ BuildCachedShardList(DistTableCacheEntry *cacheEntry) List *distShardTupleList = NIL; int shardIntervalArrayLength = 0; int shardIndex = 0; + GetPartitionTypeInputInfo(cacheEntry->partitionKeyString, + cacheEntry->partitionMethod, + &intervalTypeId, + &intervalTypeMod); distShardTupleList = LookupDistShardTuples(cacheEntry->relationId); shardIntervalArrayLength = list_length(distShardTupleList); @@ -631,13 +633,6 @@ BuildCachedShardList(DistTableCacheEntry *cacheEntry) TupleDesc distShardTupleDesc = RelationGetDescr(distShardRelation); ListCell *distShardTupleCell = NULL; int arrayIndex = 0; - Oid intervalTypeId = InvalidOid; - int32 intervalTypeMod = -1; - - GetPartitionTypeInputInfo(cacheEntry->partitionKeyString, - cacheEntry->partitionMethod, - &intervalTypeId, - &intervalTypeMod); shardIntervalArray = MemoryContextAllocZero(CacheMemoryContext, shardIntervalArrayLength * @@ -677,23 +672,19 @@ BuildCachedShardList(DistTableCacheEntry *cacheEntry) } /* decide and allocate interval comparison function */ - if (cacheEntry->partitionMethod == DISTRIBUTE_BY_NONE) + if (intervalTypeId != InvalidOid) + { + /* allocate the comparison function in the cache context */ + MemoryContext oldContext = MemoryContextSwitchTo(CacheMemoryContext); + + shardIntervalCompareFunction = GetFunctionInfo(intervalTypeId, BTREE_AM_OID, + BTORDER_PROC); + MemoryContextSwitchTo(oldContext); + } + else { shardIntervalCompareFunction = NULL; } - else if (shardIntervalArrayLength > 0) - { - MemoryContext oldContext = CurrentMemoryContext; - - /* allocate the comparison function in the cache context */ - oldContext = MemoryContextSwitchTo(CacheMemoryContext); - - shardIntervalCompareFunction = - ShardIntervalCompareFunction(shardIntervalArray, - cacheEntry->partitionMethod); - - MemoryContextSwitchTo(oldContext); - } /* reference tables has a single shard which is not initialized */ if (cacheEntry->partitionMethod == DISTRIBUTE_BY_NONE) @@ -798,37 +789,6 @@ BuildCachedShardList(DistTableCacheEntry *cacheEntry) } -/* - * ShardIntervalCompareFunction returns the appropriate compare function for the - * partition column type. In case of hash-partitioning, it always returns the compare - * function for integers. Callers of this function has to ensure that shardIntervalArray - * has at least one element. 
- */ -static FmgrInfo * -ShardIntervalCompareFunction(ShardInterval **shardIntervalArray, char partitionMethod) -{ - FmgrInfo *shardIntervalCompareFunction = NULL; - Oid comparisonTypeId = InvalidOid; - - Assert(shardIntervalArray != NULL); - - if (partitionMethod == DISTRIBUTE_BY_HASH) - { - comparisonTypeId = INT4OID; - } - else - { - ShardInterval *shardInterval = shardIntervalArray[0]; - comparisonTypeId = shardInterval->valueTypeId; - } - - shardIntervalCompareFunction = GetFunctionInfo(comparisonTypeId, BTREE_AM_OID, - BTORDER_PROC); - - return shardIntervalCompareFunction; -} - - /* * SortedShardIntervalArray sorts the input shardIntervalArray. Shard intervals with * no min/max values are placed at the end of the array. diff --git a/src/include/distributed/metadata_cache.h b/src/include/distributed/metadata_cache.h index 24324aa56..d80b4a038 100644 --- a/src/include/distributed/metadata_cache.h +++ b/src/include/distributed/metadata_cache.h @@ -49,7 +49,7 @@ typedef struct int shardIntervalArrayLength; ShardInterval **sortedShardIntervalArray; - FmgrInfo *shardIntervalCompareFunction; /* NULL if no shard intervals exist */ + FmgrInfo *shardIntervalCompareFunction; /* NULL if DISTRIBUTE_BY_NONE */ FmgrInfo *hashFunction; /* NULL if table is not distributed by hash */ /* pg_dist_shard_placement metadata */ From 105483ec5640d2bf037541e9bcd18011c962dfcb Mon Sep 17 00:00:00 2001 From: Andres Freund Date: Tue, 25 Apr 2017 23:29:10 -0700 Subject: [PATCH 3/5] Add DistTableCacheEntry->shardValueCompareFunction. That's useful when comparing the values a hash-partitioned table is filtered by. The existing shardIntervalCompareFunction compares hashed values, not unhashed ones. The btree opclass support function added to the regression tests' composite type is there so a comparator can be looked up for it. This should be changed much more widely, but is not necessary so far.
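To illustrate the distinction between the two comparators, here is a sketch only; the helper below is hypothetical and not part of the patch. For a hash-partitioned table the shard boundaries are hashed int4 values, so a query constant has to be hashed before it can be compared against them with shardIntervalCompareFunction, whereas for range/append tables the constant is compared directly with shardColumnCompareFunction:

/* hypothetical: does columnValue fall into shard's [minValue, maxValue]? */
static bool
ValueFallsInShard(DistTableCacheEntry *cacheEntry, ShardInterval *shard,
				  Datum columnValue)
{
	FmgrInfo *compareFunction = cacheEntry->shardColumnCompareFunction;
	Datum searchedValue = columnValue;

	/* assumes shard->minValueExists and shard->maxValueExists */
	if (cacheEntry->partitionMethod == DISTRIBUTE_BY_HASH)
	{
		/* boundaries are stored as hashed values, so hash the constant too */
		searchedValue = Int32GetDatum(DatumGetInt32(
			FunctionCall1(cacheEntry->hashFunction, columnValue)));
		compareFunction = cacheEntry->shardIntervalCompareFunction;
	}

	return DatumGetInt32(FunctionCall2Coll(compareFunction, DEFAULT_COLLATION_OID,
										   shard->minValue, searchedValue)) <= 0 &&
		   DatumGetInt32(FunctionCall2Coll(compareFunction, DEFAULT_COLLATION_OID,
										   searchedValue, shard->maxValue)) <= 0;
}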
--- .../distributed/utils/metadata_cache.c | 39 ++++++++++++++++++- src/include/distributed/metadata_cache.h | 10 ++++- .../regress/expected/multi_data_types.out | 12 ++++-- src/test/regress/sql/multi_data_types.sql | 13 +++++-- 4 files changed, 65 insertions(+), 9 deletions(-) diff --git a/src/backend/distributed/utils/metadata_cache.c b/src/backend/distributed/utils/metadata_cache.c index 710526aaa..a36dd5c46 100644 --- a/src/backend/distributed/utils/metadata_cache.c +++ b/src/backend/distributed/utils/metadata_cache.c @@ -156,6 +156,7 @@ static HeapTuple LookupDistPartitionTuple(Relation pgDistPartition, Oid relation static List * LookupDistShardTuples(Oid relationId); static Oid LookupShardRelation(int64 shardId); static void GetPartitionTypeInputInfo(char *partitionKeyString, char partitionMethod, + Oid *columnTypeId, int32 *columnTypeMod, Oid *intervalTypeId, int32 *intervalTypeMod); static ShardInterval * TupleToShardInterval(HeapTuple heapTuple, TupleDesc tupleDescriptor, Oid intervalTypeId, @@ -617,11 +618,19 @@ BuildCachedShardList(DistTableCacheEntry *cacheEntry) ShardInterval **shardIntervalArray = NULL; ShardInterval **sortedShardIntervalArray = NULL; FmgrInfo *shardIntervalCompareFunction = NULL; + FmgrInfo *shardColumnCompareFunction = NULL; List *distShardTupleList = NIL; int shardIntervalArrayLength = 0; int shardIndex = 0; + Oid columnTypeId = InvalidOid; + int32 columnTypeMod = -1; + Oid intervalTypeId = InvalidOid; + int32 intervalTypeMod = -1; + GetPartitionTypeInputInfo(cacheEntry->partitionKeyString, cacheEntry->partitionMethod, + &columnTypeId, + &columnTypeMod, &intervalTypeId, &intervalTypeMod); @@ -671,7 +680,22 @@ BuildCachedShardList(DistTableCacheEntry *cacheEntry) heap_close(distShardRelation, AccessShareLock); } - /* decide and allocate interval comparison function */ + /* look up value comparison function */ + if (columnTypeId != InvalidOid) + { + /* allocate the comparison function in the cache context */ + MemoryContext oldContext = MemoryContextSwitchTo(CacheMemoryContext); + + shardColumnCompareFunction = GetFunctionInfo(columnTypeId, BTREE_AM_OID, + BTORDER_PROC); + MemoryContextSwitchTo(oldContext); + } + else + { + shardColumnCompareFunction = NULL; + } + + /* look up interval comparison function */ if (intervalTypeId != InvalidOid) { /* allocate the comparison function in the cache context */ @@ -785,6 +809,7 @@ BuildCachedShardList(DistTableCacheEntry *cacheEntry) cacheEntry->shardIntervalArrayLength = shardIntervalArrayLength; cacheEntry->sortedShardIntervalArray = sortedShardIntervalArray; + cacheEntry->shardColumnCompareFunction = shardColumnCompareFunction; cacheEntry->shardIntervalCompareFunction = shardIntervalCompareFunction; } @@ -2375,8 +2400,11 @@ LookupShardRelation(int64 shardId) */ static void GetPartitionTypeInputInfo(char *partitionKeyString, char partitionMethod, + Oid *columnTypeId, int32 *columnTypeMod, Oid *intervalTypeId, int32 *intervalTypeMod) { + *columnTypeId = InvalidOid; + *columnTypeMod = -1; *intervalTypeId = InvalidOid; *intervalTypeMod = -1; @@ -2391,18 +2419,25 @@ GetPartitionTypeInputInfo(char *partitionKeyString, char partitionMethod, *intervalTypeId = partitionColumn->vartype; *intervalTypeMod = partitionColumn->vartypmod; + *columnTypeId = partitionColumn->vartype; + *columnTypeMod = partitionColumn->vartypmod; break; } case DISTRIBUTE_BY_HASH: { + Node *partitionNode = stringToNode(partitionKeyString); + Var *partitionColumn = (Var *) partitionNode; + Assert(IsA(partitionNode, Var)); + *intervalTypeId = INT4OID; 
+ *columnTypeId = partitionColumn->vartype; + *columnTypeMod = partitionColumn->vartypmod; break; } case DISTRIBUTE_BY_NONE: { - *intervalTypeId = InvalidOid; break; } diff --git a/src/include/distributed/metadata_cache.h b/src/include/distributed/metadata_cache.h index d80b4a038..5cc2bbf26 100644 --- a/src/include/distributed/metadata_cache.h +++ b/src/include/distributed/metadata_cache.h @@ -49,7 +49,15 @@ typedef struct int shardIntervalArrayLength; ShardInterval **sortedShardIntervalArray; - FmgrInfo *shardIntervalCompareFunction; /* NULL if DISTRIBUTE_BY_NONE */ + /* comparator for partition column's type, NULL if DISTRIBUTE_BY_NONE */ + FmgrInfo *shardColumnCompareFunction; + + /* + * Comparator for partition interval type (different from + * shardValueCompareFunction if hash-partitioned), NULL if + * DISTRIBUTE_BY_NONE. + */ + FmgrInfo *shardIntervalCompareFunction; FmgrInfo *hashFunction; /* NULL if table is not distributed by hash */ /* pg_dist_shard_placement metadata */ diff --git a/src/test/regress/expected/multi_data_types.out b/src/test/regress/expected/multi_data_types.out index 76fd44367..1cb4d1f9b 100644 --- a/src/test/regress/expected/multi_data_types.out +++ b/src/test/regress/expected/multi_data_types.out @@ -10,8 +10,13 @@ CREATE TYPE test_composite_type AS ( ); -- ... as well as a function to use as its comparator... CREATE FUNCTION equal_test_composite_type_function(test_composite_type, test_composite_type) RETURNS boolean -AS 'select $1.i = $2.i AND $1.i2 = $2.i2;' -LANGUAGE SQL +LANGUAGE 'internal' +AS 'record_eq' +IMMUTABLE +RETURNS NULL ON NULL INPUT; +CREATE FUNCTION cmp_test_composite_type_function(test_composite_type, test_composite_type) RETURNS int +LANGUAGE 'internal' +AS 'btrecordcmp' IMMUTABLE RETURNS NULL ON NULL INPUT; -- ... use that function to create a custom equality operator... @@ -34,7 +39,8 @@ RETURNS NULL ON NULL INPUT; -- One uses BTREE the other uses HASH CREATE OPERATOR CLASS cats_op_fam_clas3 DEFAULT FOR TYPE test_composite_type USING BTREE AS -OPERATOR 3 = (test_composite_type, test_composite_type); +OPERATOR 3 = (test_composite_type, test_composite_type), +FUNCTION 1 cmp_test_composite_type_function(test_composite_type, test_composite_type); CREATE OPERATOR CLASS cats_op_fam_class DEFAULT FOR TYPE test_composite_type USING HASH AS OPERATOR 1 = (test_composite_type, test_composite_type), diff --git a/src/test/regress/sql/multi_data_types.sql b/src/test/regress/sql/multi_data_types.sql index 315550474..3dd6eb0f7 100644 --- a/src/test/regress/sql/multi_data_types.sql +++ b/src/test/regress/sql/multi_data_types.sql @@ -15,8 +15,14 @@ CREATE TYPE test_composite_type AS ( -- ... as well as a function to use as its comparator... 
CREATE FUNCTION equal_test_composite_type_function(test_composite_type, test_composite_type) RETURNS boolean -AS 'select $1.i = $2.i AND $1.i2 = $2.i2;' -LANGUAGE SQL +LANGUAGE 'internal' +AS 'record_eq' +IMMUTABLE +RETURNS NULL ON NULL INPUT; + +CREATE FUNCTION cmp_test_composite_type_function(test_composite_type, test_composite_type) RETURNS int +LANGUAGE 'internal' +AS 'btrecordcmp' IMMUTABLE RETURNS NULL ON NULL INPUT; @@ -44,7 +50,8 @@ RETURNS NULL ON NULL INPUT; -- One uses BTREE the other uses HASH CREATE OPERATOR CLASS cats_op_fam_clas3 DEFAULT FOR TYPE test_composite_type USING BTREE AS -OPERATOR 3 = (test_composite_type, test_composite_type); +OPERATOR 3 = (test_composite_type, test_composite_type), +FUNCTION 1 cmp_test_composite_type_function(test_composite_type, test_composite_type); CREATE OPERATOR CLASS cats_op_fam_class DEFAULT FOR TYPE test_composite_type USING HASH AS From 6bd2e3ed302e83f7dacab0b6d76695691c1d31d1 Mon Sep 17 00:00:00 2001 From: Andres Freund Date: Thu, 27 Apr 2017 18:17:14 -0700 Subject: [PATCH 4/5] Add DistTableCacheEntry->hasOverlappingShardInterval. This determines whether it's possible to perform binary search on sortedShardIntervalArray or not. If e.g. two shards have overlapping ranges, that'd be prohibitive. That'll be useful in later commit introducing faster shard pruning. --- .../distributed/utils/metadata_cache.c | 80 +++++++++++++++++++ src/include/distributed/metadata_cache.h | 1 + .../input/multi_outer_join_reference.source | 2 +- .../output/multi_outer_join_reference.source | 5 +- 4 files changed, 84 insertions(+), 4 deletions(-) diff --git a/src/backend/distributed/utils/metadata_cache.c b/src/backend/distributed/utils/metadata_cache.c index a36dd5c46..15062f78e 100644 --- a/src/backend/distributed/utils/metadata_cache.c +++ b/src/backend/distributed/utils/metadata_cache.c @@ -145,6 +145,9 @@ static bool HasUninitializedShardInterval(ShardInterval **sortedShardIntervalArr static void ErrorIfInstalledVersionMismatch(void); static char * AvailableExtensionVersion(void); static char * InstalledExtensionVersion(void); +static bool HasOverlappingShardInterval(ShardInterval **shardIntervalArray, + int shardIntervalArrayLength, + FmgrInfo *shardIntervalSortCompareFunction); static void InitializeDistTableCache(void); static void InitializeWorkerNodeCache(void); static uint32 WorkerNodeHashCode(const void *key, Size keySize); @@ -714,6 +717,7 @@ BuildCachedShardList(DistTableCacheEntry *cacheEntry) if (cacheEntry->partitionMethod == DISTRIBUTE_BY_NONE) { cacheEntry->hasUninitializedShardInterval = true; + cacheEntry->hasOverlappingShardInterval = true; /* * Note that during create_reference_table() call, @@ -742,6 +746,35 @@ BuildCachedShardList(DistTableCacheEntry *cacheEntry) cacheEntry->hasUninitializedShardInterval = HasUninitializedShardInterval(sortedShardIntervalArray, shardIntervalArrayLength); + + if (!cacheEntry->hasUninitializedShardInterval) + { + cacheEntry->hasOverlappingShardInterval = + HasOverlappingShardInterval(sortedShardIntervalArray, + shardIntervalArrayLength, + shardIntervalCompareFunction); + } + else + { + cacheEntry->hasOverlappingShardInterval = true; + } + + /* + * If table is hash-partitioned and has shards, there never should be + * any uninitalized shards. Historically we've not prevented that for + * range partitioned tables, but it might be a good idea to start + * doing so. 
+ */ + if (cacheEntry->partitionMethod == DISTRIBUTE_BY_HASH && + cacheEntry->hasUninitializedShardInterval) + { + ereport(ERROR, (errmsg("hash partitioned table has uninitialized shards"))); + } + if (cacheEntry->partitionMethod == DISTRIBUTE_BY_HASH && + cacheEntry->hasOverlappingShardInterval) + { + ereport(ERROR, (errmsg("hash partitioned table has overlapping shards"))); + } } @@ -917,6 +950,52 @@ HasUninitializedShardInterval(ShardInterval **sortedShardIntervalArray, int shar } +/* + * HasOverlappingShardInterval determines whether the given list of sorted + * shards has overlapping ranges. + */ +static bool +HasOverlappingShardInterval(ShardInterval **shardIntervalArray, + int shardIntervalArrayLength, + FmgrInfo *shardIntervalSortCompareFunction) +{ + int shardIndex = 0; + ShardInterval *lastShardInterval = NULL; + Datum comparisonDatum = 0; + int comparisonResult = 0; + + /* zero/a single shard can't overlap */ + if (shardIntervalArrayLength < 2) + { + return false; + } + + lastShardInterval = shardIntervalArray[0]; + for (shardIndex = 1; shardIndex < shardIntervalArrayLength; shardIndex++) + { + ShardInterval *curShardInterval = shardIntervalArray[shardIndex]; + + /* only called if !hasUninitializedShardInterval */ + Assert(lastShardInterval->minValueExists && lastShardInterval->maxValueExists); + Assert(curShardInterval->minValueExists && curShardInterval->maxValueExists); + + comparisonDatum = CompareCall2(shardIntervalSortCompareFunction, + lastShardInterval->maxValue, + curShardInterval->minValue); + comparisonResult = DatumGetInt32(comparisonDatum); + + if (comparisonResult >= 0) + { + return true; + } + + lastShardInterval = curShardInterval; + } + + return false; +} + + /* * CitusHasBeenLoaded returns true if the citus extension has been created * in the current database and the extension script has been executed. Otherwise, @@ -2138,6 +2217,7 @@ ResetDistTableCacheEntry(DistTableCacheEntry *cacheEntry) cacheEntry->shardIntervalArrayLength = 0; cacheEntry->hasUninitializedShardInterval = false; cacheEntry->hasUniformHashDistribution = false; + cacheEntry->hasOverlappingShardInterval = false; } diff --git a/src/include/distributed/metadata_cache.h b/src/include/distributed/metadata_cache.h index 5cc2bbf26..ae83940d4 100644 --- a/src/include/distributed/metadata_cache.h +++ b/src/include/distributed/metadata_cache.h @@ -38,6 +38,7 @@ typedef struct bool isDistributedTable; bool hasUninitializedShardInterval; bool hasUniformHashDistribution; /* valid for hash partitioned tables */ + bool hasOverlappingShardInterval; /* pg_dist_partition metadata for this table */ char *partitionKeyString; diff --git a/src/test/regress/input/multi_outer_join_reference.source b/src/test/regress/input/multi_outer_join_reference.source index f1e5946c4..9a106eb61 100644 --- a/src/test/regress/input/multi_outer_join_reference.source +++ b/src/test/regress/input/multi_outer_join_reference.source @@ -166,7 +166,7 @@ FROM -- load some more data \copy multi_outer_join_right_reference FROM '@abs_srcdir@/data/customer-21-30.data' with delimiter '|' --- Update shards so that they do not have 1-1 matching. We should error here. +-- Update shards so that they do not have 1-1 matching, triggering an error. 
UPDATE pg_dist_shard SET shardminvalue = '2147483646' WHERE shardid = 1260006; UPDATE pg_dist_shard SET shardmaxvalue = '2147483647' WHERE shardid = 1260006; SELECT diff --git a/src/test/regress/output/multi_outer_join_reference.source b/src/test/regress/output/multi_outer_join_reference.source index 82d1b88ed..a2b17a0fa 100644 --- a/src/test/regress/output/multi_outer_join_reference.source +++ b/src/test/regress/output/multi_outer_join_reference.source @@ -228,15 +228,14 @@ LOG: join order: [ "multi_outer_join_left_hash" ][ broadcast join "multi_outer_ -- load some more data \copy multi_outer_join_right_reference FROM '@abs_srcdir@/data/customer-21-30.data' with delimiter '|' --- Update shards so that they do not have 1-1 matching. We should error here. +-- Update shards so that they do not have 1-1 matching, triggering an error. UPDATE pg_dist_shard SET shardminvalue = '2147483646' WHERE shardid = 1260006; UPDATE pg_dist_shard SET shardmaxvalue = '2147483647' WHERE shardid = 1260006; SELECT min(l_custkey), max(l_custkey) FROM multi_outer_join_left_hash a LEFT JOIN multi_outer_join_right_hash b ON (l_custkey = r_custkey); -ERROR: cannot perform distributed planning on this query -DETAIL: Shards of relations in outer join queries must have 1-to-1 shard partitioning +ERROR: hash partitioned table has overlapping shards UPDATE pg_dist_shard SET shardminvalue = '-2147483648' WHERE shardid = 1260006; UPDATE pg_dist_shard SET shardmaxvalue = '-1073741825' WHERE shardid = 1260006; -- empty tables From d399f395f7c409fc43a1affcd37a127557e1bb5c Mon Sep 17 00:00:00 2001 From: Andres Freund Date: Fri, 28 Apr 2017 14:40:41 -0700 Subject: [PATCH 5/5] Faster shard pruning. So far Citus used postgres' predicate proving logic for shard pruning, except for INSERT and COPY which were already optimized for speed. That turns out to be too slow: * Shard pruning for SELECTs is currently O(#shards), because PruneShardList calls predicate_refuted_by() for every shard. Obviously using an O(N) type algorithm for general pruning isn't good. * predicate_refuted_by() is quite expensive in its own right. That's primarily because it's optimized for doing a single refutation proof, rather than performing the same proof over and over. * predicate_refuted_by() does not keep persistent state (see the previous point) for function calls, which means that a lot of syscache lookups will be performed. That's particularly bad if the partitioning key is a composite key, because without a persistent FunctionCallInfo record_cmp() has to repeatedly look up the type definition of the composite key. That's quite expensive. Thus replace this with custom code that works in two phases: 1) Search restrictions for constraints that can be pruned upon 2) Use those restrictions to search for matching shards in the most efficient manner available: a) Binary search / hash lookup in case of hash partitioned tables. b) Binary search for equality clauses in case of range or append tables without overlapping shards. c) Binary search for inequality clauses, searching for both lower and upper boundaries, again in case of range or append tables without overlapping shards. d) Exhaustive search, testing each ShardInterval. My measurements suggest that we are considerably, often orders of magnitude, faster than the previous solution, even if we have to fall back to exhaustive pruning.
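As a rough illustration of step 2 a): when a hash-partitioned table has uniformly sized hash ranges, the shard matching a hashed value can be computed arithmetically, with no search at all. This is only a sketch of the idea, assuming 2^32 hash tokens split evenly across shardCount shards; the helper name is made up and the real logic lives in FindShardIntervalIndex()/the new shard_pruning.c code:

/* sketch: map a hashed value onto one of shardCount uniform hash ranges */
static int
UniformHashShardIndex(int32 hashedValue, int shardCount)
{
	/* each shard covers 2^32 / shardCount consecutive hash tokens */
	uint64 hashTokenIncrement = UINT64CONST(4294967296) / shardCount;

	/* shift into [0, 2^32) to avoid signed overflow, then divide */
	int64 shiftedValue = (int64) hashedValue + INT64CONST(2147483648);
	int shardIndex = (int) (shiftedValue / (int64) hashTokenIncrement);

	/* the last shard absorbs any remainder of the integer division */
	if (shardIndex >= shardCount)
	{
		shardIndex = shardCount - 1;
	}

	return shardIndex;
}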
--- src/backend/distributed/commands/multi_copy.c | 1 + .../master/master_modify_multiple_shards.c | 5 +- .../planner/multi_physical_planner.c | 301 +--- .../planner/multi_router_planner.c | 28 +- .../distributed/planner/shard_pruning.c | 1319 +++++++++++++++++ .../distributed/test/prune_shard_list.c | 5 +- .../distributed/utils/shardinterval_utils.c | 4 +- .../distributed/multi_physical_planner.h | 4 - src/include/distributed/shard_pruning.h | 23 + src/include/distributed/shardinterval_utils.h | 1 + .../expected/multi_prune_shard_list.out | 8 +- .../regress/sql/multi_prune_shard_list.sql | 4 +- 12 files changed, 1369 insertions(+), 334 deletions(-) create mode 100644 src/backend/distributed/planner/shard_pruning.c create mode 100644 src/include/distributed/shard_pruning.h diff --git a/src/backend/distributed/commands/multi_copy.c b/src/backend/distributed/commands/multi_copy.c index 770ff4799..a0e1f0b65 100644 --- a/src/backend/distributed/commands/multi_copy.c +++ b/src/backend/distributed/commands/multi_copy.c @@ -67,6 +67,7 @@ #include "distributed/placement_connection.h" #include "distributed/remote_commands.h" #include "distributed/resource_lock.h" +#include "distributed/shard_pruning.h" #include "executor/executor.h" #include "nodes/makefuncs.h" #include "tsearch/ts_locale.h" diff --git a/src/backend/distributed/master/master_modify_multiple_shards.c b/src/backend/distributed/master/master_modify_multiple_shards.c index d2695a7b4..086846fde 100644 --- a/src/backend/distributed/master/master_modify_multiple_shards.c +++ b/src/backend/distributed/master/master_modify_multiple_shards.c @@ -40,6 +40,7 @@ #include "distributed/pg_dist_partition.h" #include "distributed/resource_lock.h" #include "distributed/shardinterval_utils.h" +#include "distributed/shard_pruning.h" #include "distributed/worker_protocol.h" #include "optimizer/clauses.h" #include "optimizer/predtest.h" @@ -81,7 +82,6 @@ master_modify_multiple_shards(PG_FUNCTION_ARGS) Node *queryTreeNode; List *restrictClauseList = NIL; bool failOK = false; - List *shardIntervalList = NIL; List *prunedShardIntervalList = NIL; List *taskList = NIL; int32 affectedTupleCount = 0; @@ -156,11 +156,10 @@ master_modify_multiple_shards(PG_FUNCTION_ARGS) ExecuteMasterEvaluableFunctions(modifyQuery, NULL); - shardIntervalList = LoadShardIntervalList(relationId); restrictClauseList = WhereClauseList(modifyQuery->jointree); prunedShardIntervalList = - PruneShardList(relationId, tableId, restrictClauseList, shardIntervalList); + PruneShards(relationId, tableId, restrictClauseList); CHECK_FOR_INTERRUPTS(); diff --git a/src/backend/distributed/planner/multi_physical_planner.c b/src/backend/distributed/planner/multi_physical_planner.c index 505a240d4..9d2bb9f5b 100644 --- a/src/backend/distributed/planner/multi_physical_planner.c +++ b/src/backend/distributed/planner/multi_physical_planner.c @@ -41,6 +41,7 @@ #include "distributed/pg_dist_partition.h" #include "distributed/pg_dist_shard.h" #include "distributed/shardinterval_utils.h" +#include "distributed/shard_pruning.h" #include "distributed/task_tracker.h" #include "distributed/worker_manager.h" #include "distributed/worker_protocol.h" @@ -133,9 +134,6 @@ static List * RangeTableFragmentsList(List *rangeTableList, List *whereClauseLis static OperatorCacheEntry * LookupOperatorByType(Oid typeId, Oid accessMethodId, int16 strategyNumber); static Oid GetOperatorByType(Oid typeId, Oid accessMethodId, int16 strategyNumber); -static Node * HashableClauseMutator(Node *originalNode, Var *partitionColumn); 
-static OpExpr * MakeHashedOperatorExpression(OpExpr *operatorExpression); -static List * BuildRestrictInfoList(List *qualList); static List * FragmentCombinationList(List *rangeTableFragmentsList, Query *jobQuery, List *dependedJobList); static JoinSequenceNode * JoinSequenceArray(List *rangeTableFragmentsList, @@ -2060,7 +2058,6 @@ SubquerySqlTaskList(Job *job) { RangeTblEntry *rangeTableEntry = (RangeTblEntry *) lfirst(rangeTableCell); Oid relationId = rangeTableEntry->relid; - List *shardIntervalList = LoadShardIntervalList(relationId); List *finalShardIntervalList = NIL; ListCell *fragmentCombinationCell = NULL; ListCell *shardIntervalCell = NULL; @@ -2073,12 +2070,11 @@ SubquerySqlTaskList(Job *job) Var *partitionColumn = PartitionColumn(relationId, tableId); List *whereClauseList = ReplaceColumnsInOpExpressionList(opExpressionList, partitionColumn); - finalShardIntervalList = PruneShardList(relationId, tableId, whereClauseList, - shardIntervalList); + finalShardIntervalList = PruneShards(relationId, tableId, whereClauseList); } else { - finalShardIntervalList = shardIntervalList; + finalShardIntervalList = LoadShardIntervalList(relationId); } /* if all shards are pruned away, we return an empty task list */ @@ -2513,11 +2509,8 @@ RangeTableFragmentsList(List *rangeTableList, List *whereClauseList, Oid relationId = rangeTableEntry->relid; ListCell *shardIntervalCell = NULL; List *shardFragmentList = NIL; - - List *shardIntervalList = LoadShardIntervalList(relationId); - List *prunedShardIntervalList = PruneShardList(relationId, tableId, - whereClauseList, - shardIntervalList); + List *prunedShardIntervalList = PruneShards(relationId, tableId, + whereClauseList); /* * If we prune all shards for one table, query results will be empty. @@ -2586,114 +2579,6 @@ RangeTableFragmentsList(List *rangeTableList, List *whereClauseList, } -/* - * PruneShardList prunes shard intervals from given list based on the selection criteria, - * and returns remaining shard intervals in another list. - * - * For reference tables, the function simply returns the single shard that the table has. - */ -List * -PruneShardList(Oid relationId, Index tableId, List *whereClauseList, - List *shardIntervalList) -{ - List *remainingShardList = NIL; - ListCell *shardIntervalCell = NULL; - List *restrictInfoList = NIL; - Node *baseConstraint = NULL; - - Var *partitionColumn = PartitionColumn(relationId, tableId); - char partitionMethod = PartitionMethod(relationId); - - /* short circuit for reference tables */ - if (partitionMethod == DISTRIBUTE_BY_NONE) - { - return shardIntervalList; - } - - if (ContainsFalseClause(whereClauseList)) - { - /* always return empty result if WHERE clause is of the form: false (AND ..) 
*/ - return NIL; - } - - /* build the filter clause list for the partition method */ - if (partitionMethod == DISTRIBUTE_BY_HASH) - { - Node *hashedNode = HashableClauseMutator((Node *) whereClauseList, - partitionColumn); - - List *hashedClauseList = (List *) hashedNode; - restrictInfoList = BuildRestrictInfoList(hashedClauseList); - } - else - { - restrictInfoList = BuildRestrictInfoList(whereClauseList); - } - - /* override the partition column for hash partitioning */ - if (partitionMethod == DISTRIBUTE_BY_HASH) - { - partitionColumn = MakeInt4Column(); - } - - /* build the base expression for constraint */ - baseConstraint = BuildBaseConstraint(partitionColumn); - - /* walk over shard list and check if shards can be pruned */ - foreach(shardIntervalCell, shardIntervalList) - { - ShardInterval *shardInterval = (ShardInterval *) lfirst(shardIntervalCell); - List *constraintList = NIL; - bool shardPruned = false; - - if (shardInterval->minValueExists && shardInterval->maxValueExists) - { - /* set the min/max values in the base constraint */ - UpdateConstraint(baseConstraint, shardInterval); - constraintList = list_make1(baseConstraint); - - shardPruned = predicate_refuted_by(constraintList, restrictInfoList); - } - - if (!shardPruned) - { - remainingShardList = lappend(remainingShardList, shardInterval); - } - } - - return remainingShardList; -} - - -/* - * ContainsFalseClause returns whether the flattened where clause list - * contains false as a clause. - */ -bool -ContainsFalseClause(List *whereClauseList) -{ - bool containsFalseClause = false; - ListCell *clauseCell = NULL; - - foreach(clauseCell, whereClauseList) - { - Node *clause = (Node *) lfirst(clauseCell); - - if (IsA(clause, Const)) - { - Const *constant = (Const *) clause; - if (constant->consttype == BOOLOID && !DatumGetBool(constant->constvalue)) - { - containsFalseClause = true; - break; - } - } - } - - return containsFalseClause; -} - - /* * BuildBaseConstraint builds and returns a base constraint. This constraint * implements an expression in the form of (column <= max && column >= min), @@ -2916,87 +2801,6 @@ SimpleOpExpression(Expr *clause) } -/* - * HashableClauseMutator walks over the original where clause list, replaces - * hashable nodes with hashed versions and keeps other nodes as they are. - */ -static Node * -HashableClauseMutator(Node *originalNode, Var *partitionColumn) -{ - Node *newNode = NULL; - if (originalNode == NULL) - { - return NULL; - } - - if (IsA(originalNode, OpExpr)) - { - OpExpr *operatorExpression = (OpExpr *) originalNode; - bool hasPartitionColumn = false; - - Oid leftHashFunction = InvalidOid; - Oid rightHashFunction = InvalidOid; - - /* - * If operatorExpression->opno is NOT the registered '=' operator for - * any hash opfamilies, then get_op_hash_functions will return false. - * This means this function both ensures a hash function exists for the - * types in question AND filters out any clauses lacking equality ops. 
- */ - bool hasHashFunction = get_op_hash_functions(operatorExpression->opno, - &leftHashFunction, - &rightHashFunction); - - bool simpleOpExpression = SimpleOpExpression((Expr *) operatorExpression); - if (simpleOpExpression) - { - hasPartitionColumn = OpExpressionContainsColumn(operatorExpression, - partitionColumn); - } - - if (hasHashFunction && hasPartitionColumn) - { - OpExpr *hashedOperatorExpression = - MakeHashedOperatorExpression((OpExpr *) originalNode); - newNode = (Node *) hashedOperatorExpression; - } - } - else if (IsA(originalNode, ScalarArrayOpExpr)) - { - ScalarArrayOpExpr *arrayOperatorExpression = (ScalarArrayOpExpr *) originalNode; - Node *leftOpExpression = linitial(arrayOperatorExpression->args); - Node *strippedLeftOpExpression = strip_implicit_coercions(leftOpExpression); - bool usingEqualityOperator = OperatorImplementsEquality( - arrayOperatorExpression->opno); - - /* - * Citus cannot prune hash-distributed shards with ANY/ALL. We show a NOTICE - * if the expression is ANY/ALL performed on the partition column with equality. - */ - if (usingEqualityOperator && strippedLeftOpExpression != NULL && - equal(strippedLeftOpExpression, partitionColumn)) - { - ereport(NOTICE, (errmsg("cannot use shard pruning with " - "ANY/ALL (array expression)"), - errhint("Consider rewriting the expression with " - "OR/AND clauses."))); - } - } - - /* - * If this node is not hashable, continue walking down the expression tree - * to find and hash clauses which are eligible. - */ - if (newNode == NULL) - { - newNode = expression_tree_mutator(originalNode, HashableClauseMutator, - (void *) partitionColumn); - } - - return newNode; -} - - /* * OpExpressionContainsColumn checks if the operator expression contains the * given partition column. We assume that given operator expression is a simple @@ -3027,77 +2831,6 @@ OpExpressionContainsColumn(OpExpr *operatorExpression, Var *partitionColumn) } -/* - * MakeHashedOperatorExpression creates a new operator expression with a column - * of int4 type and hashed constant value. - */ -static OpExpr * -MakeHashedOperatorExpression(OpExpr *operatorExpression) -{ - const Oid hashResultTypeId = INT4OID; - TypeCacheEntry *hashResultTypeEntry = NULL; - Oid operatorId = InvalidOid; - OpExpr *hashedExpression = NULL; - Var *hashedColumn = NULL; - Datum hashedValue = 0; - Const *hashedConstant = NULL; - FmgrInfo *hashFunction = NULL; - TypeCacheEntry *typeEntry = NULL; - - Node *leftOperand = get_leftop((Expr *) operatorExpression); - Node *rightOperand = get_rightop((Expr *) operatorExpression); - Const *constant = NULL; - - if (IsA(rightOperand, Const)) - { - constant = (Const *) rightOperand; - } - else - { - constant = (Const *) leftOperand; - } - - /* Load the operator from type cache */ - hashResultTypeEntry = lookup_type_cache(hashResultTypeId, TYPECACHE_EQ_OPR); - operatorId = hashResultTypeEntry->eq_opr; - - /* Get a column with int4 type */ - hashedColumn = MakeInt4Column(); - - /* Load the hash function from type cache */ - typeEntry = lookup_type_cache(constant->consttype, TYPECACHE_HASH_PROC_FINFO); - hashFunction = &(typeEntry->hash_proc_finfo); - if (!OidIsValid(hashFunction->fn_oid)) - { - ereport(ERROR, (errcode(ERRCODE_UNDEFINED_FUNCTION), - errmsg("could not identify a hash function for type %s", - format_type_be(constant->consttype)), - errdatatype(constant->consttype))); - } - - /* - * Note that any changes to PostgreSQL's hashing functions will change the - * new value created by this function. 
- */ - hashedValue = FunctionCall1(hashFunction, constant->constvalue); - hashedConstant = MakeInt4Constant(hashedValue); - - /* Now create the expression with modified partition column and hashed constant */ - hashedExpression = (OpExpr *) make_opclause(operatorId, - InvalidOid, /* no result type yet */ - false, /* no return set */ - (Expr *) hashedColumn, - (Expr *) hashedConstant, - InvalidOid, InvalidOid); - - /* Set implementing function id and result type */ - hashedExpression->opfuncid = get_opcode(operatorId); - hashedExpression->opresulttype = get_func_rettype(hashedExpression->opfuncid); - - return hashedExpression; -} - - /* * MakeInt4Column creates a column of int4 type with invalid table id and max * attribute number. @@ -3139,30 +2872,6 @@ MakeInt4Constant(Datum constantValue) } -/* - * BuildRestrictInfoList builds restrict info list using the selection criteria, - * and then return this list. Note that this function assumes there is only one - * relation for now. - */ -static List * -BuildRestrictInfoList(List *qualList) -{ - List *restrictInfoList = NIL; - ListCell *qualCell = NULL; - - foreach(qualCell, qualList) - { - RestrictInfo *restrictInfo = NULL; - Node *qualNode = (Node *) lfirst(qualCell); - - restrictInfo = make_simple_restrictinfo((Expr *) qualNode); - restrictInfoList = lappend(restrictInfoList, restrictInfo); - } - - return restrictInfoList; -} - - /* Updates the base constraint with the given min/max values. */ void UpdateConstraint(Node *baseConstraint, ShardInterval *shardInterval) diff --git a/src/backend/distributed/planner/multi_router_planner.c b/src/backend/distributed/planner/multi_router_planner.c index 2829ecdcc..e094ad8d0 100644 --- a/src/backend/distributed/planner/multi_router_planner.c +++ b/src/backend/distributed/planner/multi_router_planner.c @@ -41,6 +41,7 @@ #include "distributed/relay_utility.h" #include "distributed/resource_lock.h" #include "distributed/shardinterval_utils.h" +#include "distributed/shard_pruning.h" #include "executor/execdesc.h" #include "lib/stringinfo.h" #include "nodes/makefuncs.h" @@ -558,6 +559,8 @@ RouterModifyTaskForShardInterval(Query *originalQuery, ShardInterval *shardInter * * The function errors out if the given shard interval does not belong to a hash, * range and append distributed tables. + * + * NB: If you update this, also look at PrunableExpressionsWalker(). 
*/ static List * ShardIntervalOpExpressions(ShardInterval *shardInterval, Index rteIndex) @@ -1998,9 +2001,7 @@ FindShardForInsert(Query *query, DeferredErrorMessage **planningError) restrictClauseList = list_make1(equalityExpr); - shardIntervalList = LoadShardIntervalList(distributedTableId); - prunedShardList = PruneShardList(distributedTableId, tableId, restrictClauseList, - shardIntervalList); + prunedShardList = PruneShards(distributedTableId, tableId, restrictClauseList); } prunedShardCount = list_length(prunedShardList); @@ -2060,7 +2061,6 @@ FindShardForUpdateOrDelete(Query *query, DeferredErrorMessage **planningError) DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(distributedTableId); char partitionMethod = cacheEntry->partitionMethod; CmdType commandType = query->commandType; - List *shardIntervalList = NIL; List *restrictClauseList = NIL; Index tableId = 1; List *prunedShardList = NIL; @@ -2068,11 +2068,8 @@ FindShardForUpdateOrDelete(Query *query, DeferredErrorMessage **planningError) Assert(commandType == CMD_UPDATE || commandType == CMD_DELETE); - shardIntervalList = LoadShardIntervalList(distributedTableId); - restrictClauseList = QueryRestrictList(query, partitionMethod); - prunedShardList = PruneShardList(distributedTableId, tableId, restrictClauseList, - shardIntervalList); + prunedShardList = PruneShards(distributedTableId, tableId, restrictClauseList); prunedShardCount = list_length(prunedShardList); if (prunedShardCount != 1) @@ -2412,7 +2409,6 @@ TargetShardIntervalsForSelect(Query *query, List *baseRestrictionList = relationRestriction->relOptInfo->baserestrictinfo; List *restrictClauseList = get_all_actual_clauses(baseRestrictionList); List *prunedShardList = NIL; - int shardIndex = 0; List *joinInfoList = relationRestriction->relOptInfo->joininfo; List *pseudoRestrictionList = extract_actual_clauses(joinInfoList, true); bool whereFalseQuery = false; @@ -2428,18 +2424,8 @@ TargetShardIntervalsForSelect(Query *query, whereFalseQuery = ContainsFalseClause(pseudoRestrictionList); if (!whereFalseQuery && shardCount > 0) { - List *shardIntervalList = NIL; - - for (shardIndex = 0; shardIndex < shardCount; shardIndex++) - { - ShardInterval *shardInterval = - cacheEntry->sortedShardIntervalArray[shardIndex]; - shardIntervalList = lappend(shardIntervalList, shardInterval); - } - - prunedShardList = PruneShardList(relationId, tableId, - restrictClauseList, - shardIntervalList); + prunedShardList = PruneShards(relationId, tableId, + restrictClauseList); /* * Quick bail out. The query can not be router plannable if one diff --git a/src/backend/distributed/planner/shard_pruning.c b/src/backend/distributed/planner/shard_pruning.c new file mode 100644 index 000000000..807b4f71e --- /dev/null +++ b/src/backend/distributed/planner/shard_pruning.c @@ -0,0 +1,1319 @@ +/*------------------------------------------------------------------------- + * + * shard_pruning.c + * Shard pruning related code. + * + * The goal of shard pruning is to find a minimal (super)set of shards that + * need to be queried to find rows matching the expression in a query. + * + * In PruneShards, we first compute a simplified disjunctive normal form (DNF) + * of the expression as a list of pruning instances. Each pruning instance + * contains all AND-ed constraints on the partition column. An OR expression + * will result in two or more new pruning instances being added for the + * subexpressions. The "parent" instance is marked isPartial and ignored + * during pruning. 
+ * + * We use the distributive property for constraints of the form P AND (Q OR R) + * to rewrite it to (P AND Q) OR (P AND R) by copying constraints from parent + * to "child" pruning instances. However, we do not distribute nested + * expressions. While (P OR Q) AND (R OR S) is logically equivalent to (P AND + * R) OR (P AND S) OR (Q AND R) OR (Q AND S), in our implementation it becomes + * P OR Q OR R OR S. This is acceptable since this will always result in a + * superset of shards. If this proves to be a issue in practice, a more + * complete algorithm could be implemented. + * + * We then evaluate each non-partial pruning instance in the disjunction + * through the following, increasingly expensive, steps: + * + * 1) If there is a constant equality constraint on the partition column, and + * no overlapping shards exist, find the shard interval in which the + * constant falls + * + * 2) If there is a hash range constraint on the partition column, find the + * shard interval matching the range + * + * 3) If there are range constraints (e.g. (a > 0 AND a < 10)) on the + * partition column, find the shard intervals that overlap with the range + * + * 4) If there are overlapping shards, exhaustively search all shards that are + * not excluded by constraints + * + * Finally, the union of the shards found by each pruning instance is + * returned. + * + * Copyright (c) 2014-2017, Citus Data, Inc. + * + *------------------------------------------------------------------------- + */ +#include "postgres.h" + +#include "distributed/shard_pruning.h" + +#include "access/nbtree.h" +#include "catalog/pg_am.h" +#include "catalog/pg_collation.h" +#include "catalog/pg_type.h" +#include "distributed/metadata_cache.h" +#include "distributed/multi_planner.h" +#include "distributed/multi_join_order.h" +#include "distributed/multi_physical_planner.h" +#include "distributed/shardinterval_utils.h" +#include "distributed/pg_dist_partition.h" +#include "distributed/worker_protocol.h" +#include "nodes/nodeFuncs.h" +#include "optimizer/clauses.h" +#include "utils/catcache.h" +#include "utils/lsyscache.h" +#include "utils/memutils.h" + +/* + * A pruning instance is a set of ANDed constraints on a partition key. + */ +typedef struct PruningInstance +{ + /* Does this instance contain any prunable expressions? */ + bool hasValidConstraint; + + /* + * This constraint never evaluates to true, i.e. pruning does not have to + * be performed. + */ + bool evaluatesToFalse; + + /* + * Constraints on the partition column value. If multiple values are + * found the more restrictive one should be stored here. Even in case of + * a hash-partitioned table, actual column-values are stored here, *not* + * hashed values. + */ + Const *lessConsts; + Const *lessEqualConsts; + Const *equalConsts; + Const *greaterEqualConsts; + Const *greaterConsts; + + /* + * Constraint using a pre-hashed column value. The constant will store the + * hashed value, not the original value of the restriction. + */ + Const *hashedEqualConsts; + + /* + * Types of constraints not understood. We could theoretically try more + * expensive methods of pruning if any such restrictions are found. + * + * TODO: any actual use for this? Right now there seems little point. + */ + List *otherRestrictions; + + /* + * Has this PruningInstance been added to + * ClauseWalkerContext->pruningInstances? This is not done immediately, + * but the first time a constraint (independent of us being able to handle + * that constraint) is found. 
+ */ + bool addedToPruningInstances; + + /* + * When OR clauses are found, the non-ORed part (think of a < 3 AND (a > 5 + * OR a > 7)) of the expression is stored in one PruningInstance which is + * then copied for the ORed expressions. The original is marked as + * isPartial, to avoid it being used for pruning. + */ + bool isPartial; +} PruningInstance; + + +/* + * Partial instances that need to be finished building. This is used to + * collect all ANDed restrictions, before looking into ORed expressions. + */ +typedef struct PendingPruningInstance +{ + PruningInstance *instance; + Node *continueAt; +} PendingPruningInstance; + + +/* + * Data necessary to perform a single PruneShards(). + */ +typedef struct ClauseWalkerContext +{ + Var *partitionColumn; + char partitionMethod; + + /* ORed list of pruning targets */ + List *pruningInstances; + + /* + * Partially built PruningInstances, that need to be completed by doing a + * separate PrunableExpressionsWalker() pass. + */ + List *pendingInstances; + + /* PruningInstance currently being built, all elegible constraints are added here */ + PruningInstance *currentPruningInstance; + + /* + * Information about function calls we need to perform. Re-using the same + * FunctionCallInfoData, instead of using FunctionCall2Coll, is often + * cheaper. + */ + FunctionCallInfoData compareValueFunctionCall; + FunctionCallInfoData compareIntervalFunctionCall; +} ClauseWalkerContext; + +static void PrunableExpressions(Node *originalNode, ClauseWalkerContext *context); +static bool PrunableExpressionsWalker(Node *originalNode, ClauseWalkerContext *context); +static void AddPartitionKeyRestrictionToInstance(ClauseWalkerContext *context, + OpExpr *opClause, Var *varClause, + Const *constantClause); +static void AddHashRestrictionToInstance(ClauseWalkerContext *context, OpExpr *opClause, + Var *varClause, Const *constantClause); +static PruningInstance * CopyPartialPruningInstance(PruningInstance *sourceInstance); +static List * ShardArrayToList(ShardInterval **shardArray, int length); +static List * DeepCopyShardIntervalList(List *originalShardIntervalList); +static int PerformValueCompare(FunctionCallInfoData *compareFunctionCall, Datum a, + Datum b); +static int PerformCompare(FunctionCallInfoData *compareFunctionCall); + +static List * PruneOne(DistTableCacheEntry *cacheEntry, ClauseWalkerContext *context, + PruningInstance *prune); +static List * PruneWithBoundaries(DistTableCacheEntry *cacheEntry, + ClauseWalkerContext *context, + PruningInstance *prune); +static List * ExhaustivePrune(DistTableCacheEntry *cacheEntry, + ClauseWalkerContext *context, + PruningInstance *prune); +static int UpperShardBoundary(Datum partitionColumnValue, + ShardInterval **shardIntervalCache, + int shardCount, FunctionCallInfoData *compareFunction, + bool includeMin); +static int LowerShardBoundary(Datum partitionColumnValue, + ShardInterval **shardIntervalCache, + int shardCount, FunctionCallInfoData *compareFunction, + bool includeMax); + + +/* + * PruneShards returns all shards from a distributed table that cannot be + * proven to be eliminated by whereClauseList. + * + * For reference tables, the function simply returns the single shard that the + * table has. 
+ */ +List * +PruneShards(Oid relationId, Index rangeTableId, List *whereClauseList) +{ + DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(relationId); + char partitionMethod = cacheEntry->partitionMethod; + ClauseWalkerContext context = { 0 }; + ListCell *pruneCell; + List *prunedList = NIL; + bool foundRestriction = false; + + /* always return empty result if WHERE clause is of the form: false (AND ..) */ + if (ContainsFalseClause(whereClauseList)) + { + return NIL; + } + + /* short circuit for reference tables */ + if (partitionMethod == DISTRIBUTE_BY_NONE) + { + prunedList = ShardArrayToList(cacheEntry->sortedShardIntervalArray, + cacheEntry->shardIntervalArrayLength); + return DeepCopyShardIntervalList(prunedList); + } + + + context.partitionMethod = partitionMethod; + context.partitionColumn = PartitionColumn(relationId, rangeTableId); + context.currentPruningInstance = palloc0(sizeof(PruningInstance)); + + if (cacheEntry->shardIntervalCompareFunction) + { + /* initiate function call info once (allows comparators to cache metadata) */ + InitFunctionCallInfoData(context.compareIntervalFunctionCall, + cacheEntry->shardIntervalCompareFunction, + 2, DEFAULT_COLLATION_OID, NULL, NULL); + } + else + { + ereport(ERROR, (errmsg("shard pruning not possible without " + "a shard interval comparator"))); + } + + if (cacheEntry->shardColumnCompareFunction) + { + /* initiate function call info once (allows comparators to cache metadata) */ + InitFunctionCallInfoData(context.compareValueFunctionCall, + cacheEntry->shardColumnCompareFunction, + 2, DEFAULT_COLLATION_OID, NULL, NULL); + } + else + { + ereport(ERROR, (errmsg("shard pruning not possible without " + "a partition column comparator"))); + } + + /* Figure out what we can prune on */ + PrunableExpressions((Node *) whereClauseList, &context); + + /* + * Prune using each of the PrunableInstances we found, and OR results + * together. + */ + foreach(pruneCell, context.pruningInstances) + { + PruningInstance *prune = (PruningInstance *) lfirst(pruneCell); + List *pruneOneList; + + /* + * If this is a partial instance, a fully built one has also been + * added. Skip. + */ + if (prune->isPartial) + { + continue; + } + + /* + * If the current instance has no prunable expressions, we'll have to + * return all shards. No point in continuing pruning in that case. + */ + if (!prune->hasValidConstraint) + { + foundRestriction = false; + break; + } + + /* + * Similar to the above, if hash-partitioned and there's nothing to + * prune by, we're done. + */ + if (context.partitionMethod == DISTRIBUTE_BY_HASH && + !prune->evaluatesToFalse && !prune->equalConsts && !prune->hashedEqualConsts) + { + foundRestriction = false; + break; + } + + pruneOneList = PruneOne(cacheEntry, &context, prune); + + if (prunedList) + { + /* + * We can use list_union_ptr, which is a lot faster than doing + * comparing shards by value, because all the ShardIntervals are + * guaranteed to be from + * DistTableCacheEntry->sortedShardIntervalArray (thus having the + * same pointer values). + */ + prunedList = list_union_ptr(prunedList, pruneOneList); + } + else + { + prunedList = pruneOneList; + } + foundRestriction = true; + } + + /* found no valid restriction, build list of all shards */ + if (!foundRestriction) + { + prunedList = ShardArrayToList(cacheEntry->sortedShardIntervalArray, + cacheEntry->shardIntervalArrayLength); + } + + /* + * Deep copy list, so it's independent of the DistTableCacheEntry + * contents. 
+ */ + return DeepCopyShardIntervalList(prunedList); +} + + +/* + * ContainsFalseClause returns whether the flattened where clause list + * contains false as a clause. + */ +bool +ContainsFalseClause(List *whereClauseList) +{ + bool containsFalseClause = false; + ListCell *clauseCell = NULL; + + foreach(clauseCell, whereClauseList) + { + Node *clause = (Node *) lfirst(clauseCell); + + if (IsA(clause, Const)) + { + Const *constant = (Const *) clause; + if (constant->consttype == BOOLOID && !DatumGetBool(constant->constvalue)) + { + containsFalseClause = true; + break; + } + } + } + + return containsFalseClause; +} + + +/* + * PrunableExpressions builds a list of all prunable expressions in node, + * storing them in context->pruningInstances. + */ +static void +PrunableExpressions(Node *node, ClauseWalkerContext *context) +{ + /* + * Build initial list of prunable expressions. As long as only, + * implicitly or explicitly, ANDed expressions are found, this perform a + * depth-first search. When an ORed expression is found, the current + * PruningInstance is added to context->pruningInstances (once for each + * ORed expression), then the tree-traversal is continued without + * recursing. Once at the top-level again, we'll process all pending + * expressions - that allows us to find all ANDed expressions, before + * recursing into an ORed expression. + */ + PrunableExpressionsWalker(node, context); + + /* + * Process all pending instances. While processing, new ones might be + * added to the list, so don't use foreach(). + * + * Check the places in PruningInstanceWalker that push onto + * context->pendingInstances why construction of the PruningInstance might + * be pending. + * + * We copy the partial PruningInstance, and continue adding information by + * calling PrunableExpressionsWalker() on the copy, continuing at the the + * node stored in PendingPruningInstance->continueAt. + */ + while (context->pendingInstances != NIL) + { + PendingPruningInstance *instance = + (PendingPruningInstance *) linitial(context->pendingInstances); + PruningInstance *newPrune = CopyPartialPruningInstance(instance->instance); + + context->pendingInstances = list_delete_first(context->pendingInstances); + + context->currentPruningInstance = newPrune; + PrunableExpressionsWalker(instance->continueAt, context); + context->currentPruningInstance = NULL; + } +} + + +/* + * PrunableExpressionsWalker() is the main work horse for + * PrunableExpressions(). + */ +static bool +PrunableExpressionsWalker(Node *node, ClauseWalkerContext *context) +{ + if (node == NULL) + { + return false; + } + + /* + * Check for expressions understood by this routine. + */ + if (IsA(node, List)) + { + /* at the top of quals we'll frequently see lists, those are to be treated as ANDs */ + } + else if (IsA(node, BoolExpr)) + { + BoolExpr *boolExpr = (BoolExpr *) node; + + if (boolExpr->boolop == NOT_EXPR) + { + return false; + } + else if (boolExpr->boolop == AND_EXPR) + { + return expression_tree_walker((Node *) boolExpr->args, + PrunableExpressionsWalker, context); + } + else if (boolExpr->boolop == OR_EXPR) + { + ListCell *opCell = NULL; + + /* + * "Queue" partial pruning instances. This is used to convert + * expressions like (A AND (B OR C) AND D) into (A AND B AND D), + * (A AND C AND D), with A, B, C, D being restrictions. When the + * OR is encountered, a reference to the partially built + * PruningInstance (containing A at this point), is added to + * context->pendingInstances once for B and once for C. 
Once a + * full tree-walk completed, PrunableExpressions() will complete + * the pending instances, which'll now also know about restriction + * D, by calling PrunableExpressionsWalker() once for B and once + * for C. + */ + foreach(opCell, boolExpr->args) + { + PendingPruningInstance *instance = + palloc0(sizeof(PendingPruningInstance)); + + instance->instance = context->currentPruningInstance; + instance->continueAt = lfirst(opCell); + + /* + * Signal that this instance is not to be used for pruning on + * its own. Once the pending instance is processed, it'll be + * used. + */ + instance->instance->isPartial = true; + + context->pendingInstances = lappend(context->pendingInstances, instance); + } + + return false; + } + } + else if (IsA(node, OpExpr)) + { + OpExpr *opClause = (OpExpr *) node; + PruningInstance *prune = context->currentPruningInstance; + Node *leftOperand = NULL; + Node *rightOperand = NULL; + Const *constantClause = NULL; + Var *varClause = NULL; + + if (!prune->addedToPruningInstances) + { + context->pruningInstances = lappend(context->pruningInstances, + prune); + prune->addedToPruningInstances = true; + } + + if (list_length(opClause->args) == 2) + { + leftOperand = get_leftop((Expr *) opClause); + rightOperand = get_rightop((Expr *) opClause); + + leftOperand = strip_implicit_coercions(leftOperand); + rightOperand = strip_implicit_coercions(rightOperand); + + if (IsA(rightOperand, Const) && IsA(leftOperand, Var)) + { + constantClause = (Const *) rightOperand; + varClause = (Var *) leftOperand; + } + else if (IsA(leftOperand, Const) && IsA(rightOperand, Var)) + { + constantClause = (Const *) leftOperand; + varClause = (Var *) rightOperand; + } + } + + if (constantClause && varClause && equal(varClause, context->partitionColumn)) + { + /* + * Found a restriction on the partition column itself. Update the + * current constraint with the new information. + */ + AddPartitionKeyRestrictionToInstance(context, + opClause, varClause, constantClause); + } + else if (constantClause && varClause && + varClause->varattno == RESERVED_HASHED_COLUMN_ID) + { + /* + * Found restriction that directly specifies the boundaries of a + * hashed column. + */ + AddHashRestrictionToInstance(context, opClause, varClause, constantClause); + } + + return false; + } + else if (IsA(node, ScalarArrayOpExpr)) + { + PruningInstance *prune = context->currentPruningInstance; + ScalarArrayOpExpr *arrayOperatorExpression = (ScalarArrayOpExpr *) node; + Node *leftOpExpression = linitial(arrayOperatorExpression->args); + Node *strippedLeftOpExpression = strip_implicit_coercions(leftOpExpression); + bool usingEqualityOperator = OperatorImplementsEquality( + arrayOperatorExpression->opno); + + /* + * Citus cannot prune hash-distributed shards with ANY/ALL. We show a NOTICE + * if the expression is ANY/ALL performed on the partition column with equality. + * + * TODO: this'd now be easy to implement, similar to the OR_EXPR case + * above, except that one would push an appropriately constructed + * OpExpr(LHS = $array_element) as continueAt. + */ + if (usingEqualityOperator && strippedLeftOpExpression != NULL && + equal(strippedLeftOpExpression, context->partitionColumn)) + { + ereport(NOTICE, (errmsg("cannot use shard pruning with " + "ANY/ALL (array expression)"), + errhint("Consider rewriting the expression with " + "OR/AND clauses."))); + } + + /* + * Mark expression as added, so we'll fail pruning if there's no ANDed + * restrictions that we can deal with. 
+ */ + if (!prune->addedToPruningInstances) + { + context->pruningInstances = lappend(context->pruningInstances, + prune); + prune->addedToPruningInstances = true; + } + + return false; + } + else + { + PruningInstance *prune = context->currentPruningInstance; + + /* + * Mark expression as added, so we'll fail pruning if there's no ANDed + * restrictions that we know how to deal with. + */ + if (!prune->addedToPruningInstances) + { + context->pruningInstances = lappend(context->pruningInstances, + prune); + prune->addedToPruningInstances = true; + } + + return false; + } + + return expression_tree_walker(node, PrunableExpressionsWalker, context); +} + + +/* + * AddPartitionKeyRestrictionToInstance adds information about a PartitionKey + * $op Const restriction to the current pruning instance. + */ +static void +AddPartitionKeyRestrictionToInstance(ClauseWalkerContext *context, OpExpr *opClause, + Var *varClause, Const *constantClause) +{ + PruningInstance *prune = context->currentPruningInstance; + List *btreeInterpretationList = NULL; + ListCell *btreeInterpretationCell = NULL; + bool matchedOp = false; + + btreeInterpretationList = + get_op_btree_interpretation(opClause->opno); + foreach(btreeInterpretationCell, btreeInterpretationList) + { + OpBtreeInterpretation *btreeInterpretation = + (OpBtreeInterpretation *) lfirst(btreeInterpretationCell); + + switch (btreeInterpretation->strategy) + { + case BTLessStrategyNumber: + { + if (!prune->lessConsts || + PerformValueCompare(&context->compareValueFunctionCall, + constantClause->constvalue, + prune->lessConsts->constvalue) < 0) + { + prune->lessConsts = constantClause; + } + matchedOp = true; + } + break; + + case BTLessEqualStrategyNumber: + { + if (!prune->lessEqualConsts || + PerformValueCompare(&context->compareValueFunctionCall, + constantClause->constvalue, + prune->lessEqualConsts->constvalue) < 0) + { + prune->lessEqualConsts = constantClause; + } + matchedOp = true; + } + break; + + case BTEqualStrategyNumber: + { + if (!prune->equalConsts) + { + prune->equalConsts = constantClause; + } + else if (PerformValueCompare(&context->compareValueFunctionCall, + constantClause->constvalue, + prune->equalConsts->constvalue) != 0) + { + /* key can't be equal to two values */ + prune->evaluatesToFalse = true; + } + matchedOp = true; + } + break; + + case BTGreaterEqualStrategyNumber: + { + if (!prune->greaterEqualConsts || + PerformValueCompare(&context->compareValueFunctionCall, + constantClause->constvalue, + prune->greaterEqualConsts->constvalue) > 0 + ) + { + prune->greaterEqualConsts = constantClause; + } + matchedOp = true; + } + break; + + case BTGreaterStrategyNumber: + { + if (!prune->greaterConsts || + PerformValueCompare(&context->compareValueFunctionCall, + constantClause->constvalue, + prune->greaterConsts->constvalue) > 0) + { + prune->greaterConsts = constantClause; + } + matchedOp = true; + } + break; + + case ROWCOMPARE_NE: + { + /* TODO: could add support for this, if we feel like it */ + matchedOp = false; + } + break; + + default: + Assert(false); + } + } + + if (!matchedOp) + { + prune->otherRestrictions = + lappend(prune->otherRestrictions, opClause); + } + else + { + prune->hasValidConstraint = true; + } +} + + +/* + * AddHashRestrictionToInstance adds information about a + * RESERVED_HASHED_COLUMN_ID = Const restriction to the current pruning + * instance. 
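+ * Such restrictions are not typically written by users; they are generated
+ * internally, e.g. by ShardIntervalOpExpressions(), to target one specific
+ * shard by its exact boundaries.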
+ */
+static void
+AddHashRestrictionToInstance(ClauseWalkerContext *context, OpExpr *opClause,
+							 Var *varClause, Const *constantClause)
+{
+	PruningInstance *prune = context->currentPruningInstance;
+	List *btreeInterpretationList = NULL;
+	ListCell *btreeInterpretationCell = NULL;
+
+	btreeInterpretationList =
+		get_op_btree_interpretation(opClause->opno);
+	foreach(btreeInterpretationCell, btreeInterpretationList)
+	{
+		OpBtreeInterpretation *btreeInterpretation =
+			(OpBtreeInterpretation *) lfirst(btreeInterpretationCell);
+
+		/*
+		 * Ladidadida, dirty hackety hack. We only add such
+		 * constraints (in ShardIntervalOpExpressions()) to select a
+		 * shard based on its exact boundaries. For efficient binary
+		 * search it's better to simply use one representative value
+		 * to look up the shard. In practice, this is sufficient for
+		 * now.
+		 */
+		if (btreeInterpretation->strategy == BTGreaterEqualStrategyNumber)
+		{
+			Assert(!prune->hashedEqualConsts);
+			prune->hashedEqualConsts = constantClause;
+			prune->hasValidConstraint = true;
+		}
+	}
+}
+
+
+/*
+ * CopyPartialPruningInstance copies a partial PruningInstance, so it can be
+ * completed.
+ */
+static PruningInstance *
+CopyPartialPruningInstance(PruningInstance *sourceInstance)
+{
+	PruningInstance *newInstance = palloc(sizeof(PruningInstance));
+
+	Assert(sourceInstance->isPartial);
+
+	/*
+	 * To make the new PruningInstance useful for pruning, we have to clear
+	 * its isPartial flag - if necessary it'll be set again by
+	 * PrunableExpressionsWalker().
+	 */
+	memcpy(newInstance, sourceInstance, sizeof(PruningInstance));
+	newInstance->addedToPruningInstances = false;
+	newInstance->isPartial = false;
+
+	return newInstance;
+}
+
+
+/*
+ * ShardArrayToList builds a list out of the array of ShardInterval*.
+ */
+static List *
+ShardArrayToList(ShardInterval **shardArray, int length)
+{
+	List *shardIntervalList = NIL;
+	int shardIndex;
+
+	for (shardIndex = 0; shardIndex < length; shardIndex++)
+	{
+		ShardInterval *shardInterval =
+			shardArray[shardIndex];
+		shardIntervalList = lappend(shardIntervalList, shardInterval);
+	}
+
+	return shardIntervalList;
+}
+
+
+/*
+ * DeepCopyShardIntervalList copies originalShardIntervalList and the
+ * contained ShardIntervals into a new list.
+ */
+static List *
+DeepCopyShardIntervalList(List *originalShardIntervalList)
+{
+	List *copiedShardIntervalList = NIL;
+	ListCell *shardIntervalCell = NULL;
+
+	foreach(shardIntervalCell, originalShardIntervalList)
+	{
+		ShardInterval *originalShardInterval =
+			(ShardInterval *) lfirst(shardIntervalCell);
+		ShardInterval *copiedShardInterval =
+			(ShardInterval *) palloc0(sizeof(ShardInterval));
+
+		CopyShardInterval(originalShardInterval, copiedShardInterval);
+		copiedShardIntervalList = lappend(copiedShardIntervalList, copiedShardInterval);
+	}
+
+	return copiedShardIntervalList;
+}
+
+
+/*
+ * PruneOne returns all shards in the table that match a single
+ * PruningInstance.
+ */
+static List *
+PruneOne(DistTableCacheEntry *cacheEntry, ClauseWalkerContext *context,
+		 PruningInstance *prune)
+{
+	ShardInterval *shardInterval = NULL;
+
+	/* Well, if life always were this easy... */
+	if (prune->evaluatesToFalse)
+	{
+		return NIL;
+	}
+
+	/*
+	 * For an equality constraint, if there are no overlapping shards (always
+	 * the case for hash and range partitioning, sometimes for append), we can
+	 * perform a binary search for the right interval. That's usually the
+	 * fastest, so try that first.
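+	 * (For hash partitioning, FindShardInterval() takes care of hashing the
+	 * constant before searching, so the lookup runs on hashed values.)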
+ */ + if (prune->equalConsts && + !cacheEntry->hasOverlappingShardInterval) + { + shardInterval = FindShardInterval(prune->equalConsts->constvalue, cacheEntry); + + /* + * If pruned down to nothing, we're done. Otherwise see if other + * methods prune down further / to nothing. + */ + if (!shardInterval) + { + return NIL; + } + } + + /* + * If the hash value we're looking for is known, we can search for the + * interval directly. That's fast and should only ever be the case for a + * hash-partitioned table. + */ + if (prune->hashedEqualConsts) + { + int shardIndex = INVALID_SHARD_INDEX; + ShardInterval **sortedShardIntervalArray = cacheEntry->sortedShardIntervalArray; + + Assert(context->partitionMethod == DISTRIBUTE_BY_HASH); + + shardIndex = FindShardIntervalIndex(prune->hashedEqualConsts->constvalue, + cacheEntry); + + if (shardIndex == INVALID_SHARD_INDEX) + { + return NIL; + } + else if (shardInterval && + sortedShardIntervalArray[shardIndex]->shardId != shardInterval->shardId) + { + /* + * equalConst based pruning above yielded a different shard than + * pruning based on pre-hashed equality. This is useful in case + * of INSERT ... SELECT, where both can occur together (one via + * join/colocation, the other via a plain equality restriction). + */ + return NIL; + } + else + { + return list_make1(sortedShardIntervalArray[shardIndex]); + } + } + + /* + * If previous pruning method yielded a single shard, we could also + * attempt range based pruning to exclude it further. But that seems + * rarely useful in practice, and thus likely a waste of runtime and code + * complexity. + */ + if (shardInterval) + { + return list_make1(shardInterval); + } + + /* + * Should never get here for hashing, we've filtered down to either zero + * or one shard, and returned. + */ + Assert(context->partitionMethod != DISTRIBUTE_BY_HASH); + + /* + * Next method: binary search with fuzzy boundaries. Can't trivially do so + * if shards have overlapping boundaries. + * + * TODO: If we kept shard intervals separately sorted by both upper and + * lower boundaries, this should be possible? + */ + if (!cacheEntry->hasOverlappingShardInterval && ( + prune->greaterConsts || prune->greaterEqualConsts || + prune->lessConsts || prune->lessEqualConsts)) + { + return PruneWithBoundaries(cacheEntry, context, prune); + } + + /* + * Brute force: Check each shard. + */ + return ExhaustivePrune(cacheEntry, context, prune); +} + + +/* + * PerformCompare invokes comparator with prepared values, check for + * unexpected NULL returns. + */ +static int +PerformCompare(FunctionCallInfoData *compareFunctionCall) +{ + Datum result = FunctionCallInvoke(compareFunctionCall); + + if (compareFunctionCall->isnull) + { + elog(ERROR, "function %u returned NULL", compareFunctionCall->flinfo->fn_oid); + } + + return DatumGetInt32(result); +} + + +/* + * PerformValueCompare invokes comparator with a/b, and checks for unexpected + * NULL returns. + */ +static int +PerformValueCompare(FunctionCallInfoData *compareFunctionCall, Datum a, Datum b) +{ + compareFunctionCall->arg[0] = a; + compareFunctionCall->argnull[0] = false; + compareFunctionCall->arg[1] = b; + compareFunctionCall->argnull[1] = false; + + return PerformCompare(compareFunctionCall); +} + + +/* + * LowerShardBoundary returns the index of the first ShardInterval that's >= + * (if includeMax) or > partitionColumnValue. 
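+ *
+ * For example, given the sorted intervals [0,9], [10,19], [20,29]: a value
+ * of 15 yields index 1, a value below all minimums yields 0, and a value
+ * above all maximums yields INVALID_SHARD_INDEX.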
+ */ +static int +LowerShardBoundary(Datum partitionColumnValue, ShardInterval **shardIntervalCache, + int shardCount, FunctionCallInfoData *compareFunction, bool includeMax) +{ + int lowerBoundIndex = 0; + int upperBoundIndex = shardCount; + + Assert(shardCount != 0); + + /* setup partitionColumnValue argument once */ + compareFunction->arg[0] = partitionColumnValue; + compareFunction->argnull[0] = false; + + while (lowerBoundIndex < upperBoundIndex) + { + int middleIndex = lowerBoundIndex + ((upperBoundIndex - lowerBoundIndex) / 2); + int maxValueComparison = 0; + int minValueComparison = 0; + + /* setup minValue as argument */ + compareFunction->arg[1] = shardIntervalCache[middleIndex]->minValue; + compareFunction->argnull[1] = false; + + /* execute cmp(partitionValue, lowerBound) */ + minValueComparison = PerformCompare(compareFunction); + + /* and evaluate results */ + if (minValueComparison < 0) + { + /* value smaller than entire range */ + upperBoundIndex = middleIndex; + continue; + } + + /* setup maxValue as argument */ + compareFunction->arg[1] = shardIntervalCache[middleIndex]->maxValue; + compareFunction->argnull[1] = false; + + /* execute cmp(partitionValue, upperBound) */ + maxValueComparison = PerformCompare(compareFunction); + + if ((maxValueComparison == 0 && !includeMax) || + maxValueComparison > 0) + { + /* value bigger than entire range */ + lowerBoundIndex = middleIndex + 1; + continue; + } + + /* found interval containing partitionValue */ + return middleIndex; + } + + Assert(lowerBoundIndex == upperBoundIndex); + + /* + * If we get here, none of the ShardIntervals exactly contain the value + * (we'd have hit the return middleIndex; case otherwise). Figure out + * whether there's possibly any interval containing a value that's bigger + * than the partition key one. + */ + if (lowerBoundIndex == 0) + { + /* all intervals are bigger, thus return 0 */ + return 0; + } + else if (lowerBoundIndex == shardCount) + { + /* partition value is bigger than all partition values */ + return INVALID_SHARD_INDEX; + } + + /* value falls inbetween intervals */ + return lowerBoundIndex + 1; +} + + +/* + * UpperShardBoundary returns the index of the last ShardInterval that's <= + * (if includeMin) or < partitionColumnValue. 
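+ *
+ * For example, given the sorted intervals [0,9], [10,19], [20,29]: a value
+ * of 15 yields index 1, a value above all maximums yields 2 (the last
+ * shard), and a value below all minimums yields INVALID_SHARD_INDEX.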
+ */
+static int
+UpperShardBoundary(Datum partitionColumnValue, ShardInterval **shardIntervalCache,
+				   int shardCount, FunctionCallInfoData *compareFunction, bool includeMin)
+{
+	int lowerBoundIndex = 0;
+	int upperBoundIndex = shardCount;
+
+	Assert(shardCount != 0);
+
+	/* setup partitionColumnValue argument once */
+	compareFunction->arg[0] = partitionColumnValue;
+	compareFunction->argnull[0] = false;
+
+	while (lowerBoundIndex < upperBoundIndex)
+	{
+		int middleIndex = lowerBoundIndex + ((upperBoundIndex - lowerBoundIndex) / 2);
+		int maxValueComparison = 0;
+		int minValueComparison = 0;
+
+		/* setup minValue as argument */
+		compareFunction->arg[1] = shardIntervalCache[middleIndex]->minValue;
+		compareFunction->argnull[1] = false;
+
+		/* execute cmp(partitionValue, lowerBound) */
+		minValueComparison = PerformCompare(compareFunction);
+
+		/* and evaluate results */
+		if ((minValueComparison == 0 && !includeMin) ||
+			minValueComparison < 0)
+		{
+			/* value smaller than entire range */
+			upperBoundIndex = middleIndex;
+			continue;
+		}
+
+		/* setup maxValue as argument */
+		compareFunction->arg[1] = shardIntervalCache[middleIndex]->maxValue;
+		compareFunction->argnull[1] = false;
+
+		/* execute cmp(partitionValue, upperBound) */
+		maxValueComparison = PerformCompare(compareFunction);
+
+		if (maxValueComparison > 0)
+		{
+			/* value bigger than entire range */
+			lowerBoundIndex = middleIndex + 1;
+			continue;
+		}
+
+		/* found interval containing partitionValue */
+		return middleIndex;
+	}
+
+	Assert(lowerBoundIndex == upperBoundIndex);
+
+	/*
+	 * If we get here, none of the ShardIntervals exactly contain the value
+	 * (we'd have hit the return middleIndex; case otherwise). Figure out
+	 * whether there's possibly any interval containing a value that's smaller
+	 * than the partition key one.
+	 */
+	if (upperBoundIndex == shardCount)
+	{
+		/* all intervals are smaller, thus return the last shard */
+		return shardCount - 1;
+	}
+	else if (upperBoundIndex == 0)
+	{
+		/* partition value is smaller than all partition values */
+		return INVALID_SHARD_INDEX;
+	}
+
+	/* value falls in between intervals, return the interval one smaller as bound */
+	return upperBoundIndex - 1;
+}
+
+
+/*
+ * PruneWithBoundaries searches for shards that match inequality constraints,
+ * using binary search on both the upper and lower boundary, and returns a
+ * list of surviving shards.
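+ *
+ * For example, with intervals [0,9], [10,19], [20,29], a restriction like
+ * "a >= 10 AND a < 25" resolves to lowerBoundIdx = 1 and upperBoundIdx = 2,
+ * so the shards at indexes 1 and 2 survive.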
+ */ +static List * +PruneWithBoundaries(DistTableCacheEntry *cacheEntry, ClauseWalkerContext *context, + PruningInstance *prune) +{ + List *remainingShardList = NIL; + int shardCount = cacheEntry->shardIntervalArrayLength; + ShardInterval **sortedShardIntervalArray = cacheEntry->sortedShardIntervalArray; + bool hasLowerBound = false; + bool hasUpperBound = false; + Datum lowerBound = 0; + Datum upperBound = 0; + bool lowerBoundInclusive = false; + bool upperBoundInclusive = false; + int lowerBoundIdx = -1; + int upperBoundIdx = -1; + int curIdx = 0; + FunctionCallInfo compareFunctionCall = &context->compareIntervalFunctionCall; + + if (prune->greaterEqualConsts) + { + lowerBound = prune->greaterEqualConsts->constvalue; + lowerBoundInclusive = true; + hasLowerBound = true; + } + if (prune->greaterConsts) + { + lowerBound = prune->greaterConsts->constvalue; + lowerBoundInclusive = false; + hasLowerBound = true; + } + if (prune->lessEqualConsts) + { + upperBound = prune->lessEqualConsts->constvalue; + upperBoundInclusive = true; + hasUpperBound = true; + } + if (prune->lessConsts) + { + upperBound = prune->lessConsts->constvalue; + upperBoundInclusive = false; + hasUpperBound = true; + } + + Assert(hasLowerBound || hasUpperBound); + + /* find lower bound */ + if (hasLowerBound) + { + lowerBoundIdx = LowerShardBoundary(lowerBound, sortedShardIntervalArray, + shardCount, compareFunctionCall, + lowerBoundInclusive); + } + else + { + lowerBoundIdx = 0; + } + + /* find upper bound */ + if (hasUpperBound) + { + upperBoundIdx = UpperShardBoundary(upperBound, sortedShardIntervalArray, + shardCount, compareFunctionCall, + upperBoundInclusive); + } + else + { + upperBoundIdx = shardCount - 1; + } + + if (lowerBoundIdx == INVALID_SHARD_INDEX) + { + return NIL; + } + else if (upperBoundIdx == INVALID_SHARD_INDEX) + { + return NIL; + } + + /* + * Build list of all shards that are in the range of shards (possibly 0). + */ + for (curIdx = lowerBoundIdx; curIdx <= upperBoundIdx; curIdx++) + { + remainingShardList = lappend(remainingShardList, + sortedShardIntervalArray[curIdx]); + } + + return remainingShardList; +} + + +/* + * ExhaustivePrune returns a list of shards matching PruningInstances + * constraints, by simply checking them for each individual shard. 
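+ *
+ * This is the fallback when binary search is not applicable, e.g. for
+ * append-distributed tables with overlapping shard intervals; if no
+ * supported restriction is present at all, every shard survives.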
+ */ +static List * +ExhaustivePrune(DistTableCacheEntry *cacheEntry, ClauseWalkerContext *context, + PruningInstance *prune) +{ + List *remainingShardList = NIL; + FunctionCallInfo compareFunctionCall = &context->compareIntervalFunctionCall; + int shardCount = cacheEntry->shardIntervalArrayLength; + ShardInterval **sortedShardIntervalArray = cacheEntry->sortedShardIntervalArray; + int curIdx = 0; + + for (curIdx = 0; curIdx < shardCount; curIdx++) + { + Datum compareWith = 0; + ShardInterval *curInterval = sortedShardIntervalArray[curIdx]; + + /* NULL boundaries can't be compared to */ + if (!curInterval->minValueExists || !curInterval->maxValueExists) + { + remainingShardList = lappend(remainingShardList, curInterval); + continue; + } + + if (prune->equalConsts) + { + compareWith = prune->equalConsts->constvalue; + + if (PerformValueCompare(compareFunctionCall, + compareWith, + curInterval->minValue) < 0) + { + continue; + } + + if (PerformValueCompare(compareFunctionCall, + compareWith, + curInterval->maxValue) > 0) + { + continue; + } + } + if (prune->greaterEqualConsts) + { + compareWith = prune->greaterEqualConsts->constvalue; + + if (PerformValueCompare(compareFunctionCall, + curInterval->maxValue, + compareWith) < 0) + { + continue; + } + } + if (prune->greaterConsts) + { + compareWith = prune->greaterConsts->constvalue; + + if (PerformValueCompare(compareFunctionCall, + curInterval->maxValue, + compareWith) <= 0) + { + continue; + } + } + if (prune->lessEqualConsts) + { + compareWith = prune->lessEqualConsts->constvalue; + + if (PerformValueCompare(compareFunctionCall, + curInterval->minValue, + compareWith) > 0) + { + continue; + } + } + if (prune->lessConsts) + { + compareWith = prune->lessConsts->constvalue; + + if (PerformValueCompare(compareFunctionCall, + curInterval->minValue, + compareWith) >= 0) + { + continue; + } + } + + remainingShardList = lappend(remainingShardList, curInterval); + } + + return remainingShardList; +} diff --git a/src/backend/distributed/test/prune_shard_list.c b/src/backend/distributed/test/prune_shard_list.c index 88827f155..d68ebe516 100644 --- a/src/backend/distributed/test/prune_shard_list.c +++ b/src/backend/distributed/test/prune_shard_list.c @@ -24,6 +24,7 @@ #include "distributed/multi_physical_planner.h" #include "distributed/resource_lock.h" #include "distributed/test_helper_functions.h" /* IWYU pragma: keep */ +#include "distributed/shard_pruning.h" #include "nodes/pg_list.h" #include "nodes/primnodes.h" #include "nodes/nodes.h" @@ -203,11 +204,11 @@ PrunedShardIdsForTable(Oid distributedTableId, List *whereClauseList) Oid shardIdTypeId = INT8OID; Index tableId = 1; - List *shardList = LoadShardIntervalList(distributedTableId); + List *shardList = NIL; int shardIdCount = -1; Datum *shardIdDatumArray = NULL; - shardList = PruneShardList(distributedTableId, tableId, whereClauseList, shardList); + shardList = PruneShards(distributedTableId, tableId, whereClauseList); shardIdCount = list_length(shardList); shardIdDatumArray = palloc0(shardIdCount * sizeof(Datum)); diff --git a/src/backend/distributed/utils/shardinterval_utils.c b/src/backend/distributed/utils/shardinterval_utils.c index 5133da261..046ba3208 100644 --- a/src/backend/distributed/utils/shardinterval_utils.c +++ b/src/backend/distributed/utils/shardinterval_utils.c @@ -16,6 +16,7 @@ #include "catalog/pg_type.h" #include "distributed/metadata_cache.h" #include "distributed/multi_planner.h" +#include "distributed/shard_pruning.h" #include "distributed/shardinterval_utils.h" #include 
"distributed/pg_dist_partition.h" #include "distributed/worker_protocol.h" @@ -23,7 +24,6 @@ #include "utils/memutils.h" -static int FindShardIntervalIndex(Datum searchedValue, DistTableCacheEntry *cacheEntry); static int SearchCachedShardInterval(Datum partitionColumnValue, ShardInterval **shardIntervalCache, int shardCount, FmgrInfo *compareFunction); @@ -254,7 +254,7 @@ FindShardInterval(Datum partitionColumnValue, DistTableCacheEntry *cacheEntry) * somewhere. Such as a hash function which returns a value not in the range * of [INT32_MIN, INT32_MAX] can fire this. */ -static int +int FindShardIntervalIndex(Datum searchedValue, DistTableCacheEntry *cacheEntry) { ShardInterval **shardIntervalCache = cacheEntry->sortedShardIntervalArray; diff --git a/src/include/distributed/multi_physical_planner.h b/src/include/distributed/multi_physical_planner.h index 20512bc3e..0c701c479 100644 --- a/src/include/distributed/multi_physical_planner.h +++ b/src/include/distributed/multi_physical_planner.h @@ -253,10 +253,6 @@ extern StringInfo ShardFetchQueryString(uint64 shardId); extern Task * CreateBasicTask(uint64 jobId, uint32 taskId, TaskType taskType, char *queryString); -/* Function declarations for shard pruning */ -extern List * PruneShardList(Oid relationId, Index tableId, List *whereClauseList, - List *shardList); -extern bool ContainsFalseClause(List *whereClauseList); extern OpExpr * MakeOpExpression(Var *variable, int16 strategyNumber); /* diff --git a/src/include/distributed/shard_pruning.h b/src/include/distributed/shard_pruning.h new file mode 100644 index 000000000..3c26c4662 --- /dev/null +++ b/src/include/distributed/shard_pruning.h @@ -0,0 +1,23 @@ +/*------------------------------------------------------------------------- + * + * shard_pruning.h + * Shard pruning infrastructure. + * + * Copyright (c) 2014-2017, Citus Data, Inc. 
+ * + *------------------------------------------------------------------------- + */ + +#ifndef SHARD_PRUNING_H_ +#define SHARD_PRUNING_H_ + +#include "distributed/metadata_cache.h" +#include "nodes/primnodes.h" + +#define INVALID_SHARD_INDEX -1 + +/* Function declarations for shard pruning */ +extern List * PruneShards(Oid relationId, Index rangeTableId, List *whereClauseList); +extern bool ContainsFalseClause(List *whereClauseList); + +#endif /* SHARD_PRUNING_H_ */ diff --git a/src/include/distributed/shardinterval_utils.h b/src/include/distributed/shardinterval_utils.h index 54b96f2e7..5c9d6cde8 100644 --- a/src/include/distributed/shardinterval_utils.h +++ b/src/include/distributed/shardinterval_utils.h @@ -35,6 +35,7 @@ extern int CompareRelationShards(const void *leftElement, extern int ShardIndex(ShardInterval *shardInterval); extern ShardInterval * FindShardInterval(Datum partitionColumnValue, DistTableCacheEntry *cacheEntry); +extern int FindShardIntervalIndex(Datum searchedValue, DistTableCacheEntry *cacheEntry); extern bool SingleReplicatedTable(Oid relationId); #endif /* SHARDINTERVAL_UTILS_H_ */ diff --git a/src/test/regress/expected/multi_prune_shard_list.out b/src/test/regress/expected/multi_prune_shard_list.out index 07d293498..4ddaaae9b 100644 --- a/src/test/regress/expected/multi_prune_shard_list.out +++ b/src/test/regress/expected/multi_prune_shard_list.out @@ -69,21 +69,21 @@ SELECT prune_using_single_value('pruning', NULL); SELECT prune_using_either_value('pruning', 'tomato', 'petunia'); prune_using_either_value -------------------------- - {800001,800002} + {800002,800001} (1 row) --- an AND clause with incompatible values returns no shards +-- an AND clause with values on different shards returns no shards SELECT prune_using_both_values('pruning', 'tomato', 'petunia'); prune_using_both_values ------------------------- {} (1 row) --- but if both values are on the same shard, should get back that shard +-- even if both values are on the same shard, a value can't be equal to two others SELECT prune_using_both_values('pruning', 'tomato', 'rose'); prune_using_both_values ------------------------- - {800002} + {} (1 row) -- unit test of the equality expression generation code diff --git a/src/test/regress/sql/multi_prune_shard_list.sql b/src/test/regress/sql/multi_prune_shard_list.sql index 0e9b9f599..4f42ee9f7 100644 --- a/src/test/regress/sql/multi_prune_shard_list.sql +++ b/src/test/regress/sql/multi_prune_shard_list.sql @@ -59,10 +59,10 @@ SELECT prune_using_single_value('pruning', NULL); -- build an OR clause and expect more than one sahrd SELECT prune_using_either_value('pruning', 'tomato', 'petunia'); --- an AND clause with incompatible values returns no shards +-- an AND clause with values on different shards returns no shards SELECT prune_using_both_values('pruning', 'tomato', 'petunia'); --- but if both values are on the same shard, should get back that shard +-- even if both values are on the same shard, a value can't be equal to two others SELECT prune_using_both_values('pruning', 'tomato', 'rose'); -- unit test of the equality expression generation code