From bb456d4002f36acf6ada26ceac0c54d60acd4569 Mon Sep 17 00:00:00 2001
From: Andres Freund
Date: Fri, 28 Apr 2017 14:40:41 -0700
Subject: [PATCH] Faster shard pruning.

So far citus used postgres' predicate proofing logic for shard pruning, except for INSERT and COPY, which were already optimized for speed. That turns out to be too slow:

* Shard pruning for SELECTs is currently O(#shards), because PruneShardList calls predicate_refuted_by() for every shard. Obviously an O(N) algorithm for general pruning isn't good.

* predicate_refuted_by() is quite expensive in its own right. That's primarily because it's optimized for doing a single refutation proof, rather than performing the same proof over and over.

* predicate_refuted_by() does not keep persistent state (see the previous point) for function calls, which means that a lot of syscache lookups will be performed. That's particularly bad if the partitioning key is a composite key, because without a persistent FunctionCallInfo, record_cmp() has to repeatedly look up the type definition of the composite key. That's quite expensive.

Thus replace this with custom code that works in two phases:

1) Search restrictions for constraints that can be pruned upon

2) Use those restrictions to search for matching shards in the most efficient manner available:
   a) Binary search / hash lookup in case of hash partitioned tables
   b) Binary search for equality clauses in case of range or append tables without overlapping shards.
   c) Binary search for inequality clauses, searching for both lower and upper boundaries, again in case of range or append tables without overlapping shards.
   d) Exhaustive search, testing each ShardInterval

My measurements suggest that we are considerably, often orders of magnitude, faster than the previous solution, even if we have to fall back to exhaustive pruning.
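Illustrative sketch (not Citus code): the fast path in 2b boils down to a plain binary search over sorted, non-overlapping [min,max] intervals. The standalone program below shows the idea with ints; the real code compares Datums through a cached FunctionCallInfo and works on ShardInterval arrays.

    #include <stdio.h>

    typedef struct { int min; int max; } Interval;

    /*
     * Return the index of the sorted, non-overlapping interval containing
     * value, or -1 if no interval contains it.
     */
    static int
    FindInterval(const Interval *intervals, int count, int value)
    {
        int low = 0;
        int high = count;

        while (low < high)
        {
            int middle = low + (high - low) / 2;

            if (value < intervals[middle].min)
            {
                high = middle;          /* value below entire interval */
            }
            else if (value > intervals[middle].max)
            {
                low = middle + 1;       /* value above entire interval */
            }
            else
            {
                return middle;          /* interval contains value */
            }
        }

        return -1;
    }

    int
    main(void)
    {
        Interval shards[] = { { 0, 9 }, { 10, 19 }, { 20, 29 }, { 30, 39 } };

        printf("%d\n", FindInterval(shards, 4, 15));    /* prints 1 */
        printf("%d\n", FindInterval(shards, 4, 42));    /* prints -1 */

        return 0;
    }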
--- src/backend/distributed/commands/multi_copy.c | 1 + .../master/master_modify_multiple_shards.c | 5 +- .../planner/multi_physical_planner.c | 306 +--- .../planner/multi_router_planner.c | 29 +- .../distributed/planner/shard_pruning.c | 1319 +++++++++++++++++ .../distributed/test/prune_shard_list.c | 5 +- .../distributed/utils/shardinterval_utils.c | 4 +- .../distributed/multi_physical_planner.h | 4 - src/include/distributed/shard_pruning.h | 23 + src/include/distributed/shardinterval_utils.h | 1 + .../expected/multi_prune_shard_list.out | 8 +- .../regress/sql/multi_prune_shard_list.sql | 4 +- 12 files changed, 1373 insertions(+), 336 deletions(-) create mode 100644 src/backend/distributed/planner/shard_pruning.c create mode 100644 src/include/distributed/shard_pruning.h diff --git a/src/backend/distributed/commands/multi_copy.c b/src/backend/distributed/commands/multi_copy.c index fb4379c17..0b6d6cc0d 100644 --- a/src/backend/distributed/commands/multi_copy.c +++ b/src/backend/distributed/commands/multi_copy.c @@ -67,6 +67,7 @@ #include "distributed/placement_connection.h" #include "distributed/remote_commands.h" #include "distributed/resource_lock.h" +#include "distributed/shard_pruning.h" #include "executor/executor.h" #include "tsearch/ts_locale.h" #include "utils/builtins.h" diff --git a/src/backend/distributed/master/master_modify_multiple_shards.c b/src/backend/distributed/master/master_modify_multiple_shards.c index 022bbe6e1..22580a0db 100644 --- a/src/backend/distributed/master/master_modify_multiple_shards.c +++ b/src/backend/distributed/master/master_modify_multiple_shards.c @@ -40,6 +40,7 @@ #include "distributed/pg_dist_partition.h" #include "distributed/resource_lock.h" #include "distributed/shardinterval_utils.h" +#include "distributed/shard_pruning.h" #include "distributed/worker_protocol.h" #include "optimizer/clauses.h" #include "optimizer/predtest.h" @@ -81,7 +82,6 @@ master_modify_multiple_shards(PG_FUNCTION_ARGS) Node *queryTreeNode; List *restrictClauseList = NIL; bool failOK = false; - List *shardIntervalList = NIL; List *prunedShardIntervalList = NIL; List *taskList = NIL; int32 affectedTupleCount = 0; @@ -156,11 +156,10 @@ master_modify_multiple_shards(PG_FUNCTION_ARGS) ExecuteMasterEvaluableFunctions(modifyQuery); - shardIntervalList = LoadShardIntervalList(relationId); restrictClauseList = WhereClauseList(modifyQuery->jointree); prunedShardIntervalList = - PruneShardList(relationId, tableId, restrictClauseList, shardIntervalList); + PruneShards(relationId, tableId, restrictClauseList); CHECK_FOR_INTERRUPTS(); diff --git a/src/backend/distributed/planner/multi_physical_planner.c b/src/backend/distributed/planner/multi_physical_planner.c index 22ee7aa60..14bedee08 100644 --- a/src/backend/distributed/planner/multi_physical_planner.c +++ b/src/backend/distributed/planner/multi_physical_planner.c @@ -40,6 +40,7 @@ #include "distributed/pg_dist_partition.h" #include "distributed/pg_dist_shard.h" #include "distributed/shardinterval_utils.h" +#include "distributed/shard_pruning.h" #include "distributed/task_tracker.h" #include "distributed/worker_manager.h" #include "distributed/worker_protocol.h" @@ -131,9 +132,6 @@ static List * RangeTableFragmentsList(List *rangeTableList, List *whereClauseLis static OperatorCacheEntry * LookupOperatorByType(Oid typeId, Oid accessMethodId, int16 strategyNumber); static Oid GetOperatorByType(Oid typeId, Oid accessMethodId, int16 strategyNumber); -static Node * HashableClauseMutator(Node *originalNode, Var *partitionColumn); -static 
OpExpr * MakeHashedOperatorExpression(OpExpr *operatorExpression); -static List * BuildRestrictInfoList(List *qualList); static List * FragmentCombinationList(List *rangeTableFragmentsList, Query *jobQuery, List *dependedJobList); static JoinSequenceNode * JoinSequenceArray(List *rangeTableFragmentsList, @@ -2044,7 +2042,6 @@ SubquerySqlTaskList(Job *job) { RangeTblEntry *rangeTableEntry = (RangeTblEntry *) lfirst(rangeTableCell); Oid relationId = rangeTableEntry->relid; - List *shardIntervalList = LoadShardIntervalList(relationId); List *finalShardIntervalList = NIL; ListCell *fragmentCombinationCell = NULL; ListCell *shardIntervalCell = NULL; @@ -2057,12 +2054,11 @@ SubquerySqlTaskList(Job *job) Var *partitionColumn = PartitionColumn(relationId, tableId); List *whereClauseList = ReplaceColumnsInOpExpressionList(opExpressionList, partitionColumn); - finalShardIntervalList = PruneShardList(relationId, tableId, whereClauseList, - shardIntervalList); + finalShardIntervalList = PruneShards(relationId, tableId, whereClauseList); } else { - finalShardIntervalList = shardIntervalList; + finalShardIntervalList = LoadShardIntervalList(relationId); } /* if all shards are pruned away, we return an empty task list */ @@ -2499,11 +2495,8 @@ RangeTableFragmentsList(List *rangeTableList, List *whereClauseList, Oid relationId = rangeTableEntry->relid; ListCell *shardIntervalCell = NULL; List *shardFragmentList = NIL; - - List *shardIntervalList = LoadShardIntervalList(relationId); - List *prunedShardIntervalList = PruneShardList(relationId, tableId, - whereClauseList, - shardIntervalList); + List *prunedShardIntervalList = PruneShards(relationId, tableId, + whereClauseList); /* * If we prune all shards for one table, query results will be empty. @@ -2572,119 +2565,6 @@ RangeTableFragmentsList(List *rangeTableList, List *whereClauseList, } -/* - * PruneShardList prunes shard intervals from given list based on the selection criteria, - * and returns remaining shard intervals in another list. - * - * For reference tables, the function simply returns the single shard that the table has. - */ -List * -PruneShardList(Oid relationId, Index tableId, List *whereClauseList, - List *shardIntervalList) -{ - List *remainingShardList = NIL; - ListCell *shardIntervalCell = NULL; - List *restrictInfoList = NIL; - Node *baseConstraint = NULL; - - Var *partitionColumn = PartitionColumn(relationId, tableId); - char partitionMethod = PartitionMethod(relationId); - - /* short circuit for reference tables */ - if (partitionMethod == DISTRIBUTE_BY_NONE) - { - return shardIntervalList; - } - - if (ContainsFalseClause(whereClauseList)) - { - /* always return empty result if WHERE clause is of the form: false (AND ..) 
*/ - return NIL; - } - - /* build the filter clause list for the partition method */ - if (partitionMethod == DISTRIBUTE_BY_HASH) - { - Node *hashedNode = HashableClauseMutator((Node *) whereClauseList, - partitionColumn); - - List *hashedClauseList = (List *) hashedNode; - restrictInfoList = BuildRestrictInfoList(hashedClauseList); - } - else - { - restrictInfoList = BuildRestrictInfoList(whereClauseList); - } - - /* override the partition column for hash partitioning */ - if (partitionMethod == DISTRIBUTE_BY_HASH) - { - partitionColumn = MakeInt4Column(); - } - - /* build the base expression for constraint */ - baseConstraint = BuildBaseConstraint(partitionColumn); - - /* walk over shard list and check if shards can be pruned */ - foreach(shardIntervalCell, shardIntervalList) - { - ShardInterval *shardInterval = (ShardInterval *) lfirst(shardIntervalCell); - List *constraintList = NIL; - bool shardPruned = false; - - if (shardInterval->minValueExists && shardInterval->maxValueExists) - { - /* set the min/max values in the base constraint */ - UpdateConstraint(baseConstraint, shardInterval); - constraintList = list_make1(baseConstraint); - - shardPruned = predicate_refuted_by(constraintList, restrictInfoList); - } - - if (shardPruned) - { - ereport(DEBUG2, (errmsg("predicate pruning for shardId " - UINT64_FORMAT, shardInterval->shardId))); - } - else - { - remainingShardList = lappend(remainingShardList, shardInterval); - } - } - - return remainingShardList; -} - - -/* - * ContainsFalseClause returns whether the flattened where clause list - * contains false as a clause. - */ -bool -ContainsFalseClause(List *whereClauseList) -{ - bool containsFalseClause = false; - ListCell *clauseCell = NULL; - - foreach(clauseCell, whereClauseList) - { - Node *clause = (Node *) lfirst(clauseCell); - - if (IsA(clause, Const)) - { - Const *constant = (Const *) clause; - if (constant->consttype == BOOLOID && !DatumGetBool(constant->constvalue)) - { - containsFalseClause = true; - break; - } - } - } - - return containsFalseClause; -} - - /* * BuildBaseConstraint builds and returns a base constraint. This constraint * implements an expression in the form of (column <= max && column >= min), @@ -2907,87 +2787,6 @@ SimpleOpExpression(Expr *clause) } -/* - * HashableClauseMutator walks over the original where clause list, replaces - * hashable nodes with hashed versions and keeps other nodes as they are. - */ -static Node * -HashableClauseMutator(Node *originalNode, Var *partitionColumn) -{ - Node *newNode = NULL; - if (originalNode == NULL) - { - return NULL; - } - - if (IsA(originalNode, OpExpr)) - { - OpExpr *operatorExpression = (OpExpr *) originalNode; - bool hasPartitionColumn = false; - - Oid leftHashFunction = InvalidOid; - Oid rightHashFunction = InvalidOid; - - /* - * If operatorExpression->opno is NOT the registered '=' operator for - * any hash opfamilies, then get_op_hash_functions will return false. - * This means this function both ensures a hash function exists for the - * types in question AND filters out any clauses lacking equality ops. 
- */ - bool hasHashFunction = get_op_hash_functions(operatorExpression->opno, - &leftHashFunction, - &rightHashFunction); - - bool simpleOpExpression = SimpleOpExpression((Expr *) operatorExpression); - if (simpleOpExpression) - { - hasPartitionColumn = OpExpressionContainsColumn(operatorExpression, - partitionColumn); - } - - if (hasHashFunction && hasPartitionColumn) - { - OpExpr *hashedOperatorExpression = - MakeHashedOperatorExpression((OpExpr *) originalNode); - newNode = (Node *) hashedOperatorExpression; - } - } - else if (IsA(originalNode, ScalarArrayOpExpr)) - { - ScalarArrayOpExpr *arrayOperatorExpression = (ScalarArrayOpExpr *) originalNode; - Node *leftOpExpression = linitial(arrayOperatorExpression->args); - Node *strippedLeftOpExpression = strip_implicit_coercions(leftOpExpression); - bool usingEqualityOperator = OperatorImplementsEquality( - arrayOperatorExpression->opno); - - /* - * Citus cannot prune hash-distributed shards with ANY/ALL. We show a NOTICE - * if the expression is ANY/ALL performed on the partition column with equality. - */ - if (usingEqualityOperator && strippedLeftOpExpression != NULL && - equal(strippedLeftOpExpression, partitionColumn)) - { - ereport(NOTICE, (errmsg("cannot use shard pruning with " - "ANY/ALL (array expression)"), - errhint("Consider rewriting the expression with " - "OR/AND clauses."))); - } - } - - /* - * If this node is not hashable, continue walking down the expression tree - * to find and hash clauses which are eligible. - */ - if (newNode == NULL) - { - newNode = expression_tree_mutator(originalNode, HashableClauseMutator, - (void *) partitionColumn); - } - - return newNode; -} - - /* * OpExpressionContainsColumn checks if the operator expression contains the * given partition column. We assume that given operator expression is a simple @@ -3018,77 +2817,6 @@ OpExpressionContainsColumn(OpExpr *operatorExpression, Var *partitionColumn) } -/* - * MakeHashedOperatorExpression creates a new operator expression with a column - * of int4 type and hashed constant value. - */ -static OpExpr * -MakeHashedOperatorExpression(OpExpr *operatorExpression) -{ - const Oid hashResultTypeId = INT4OID; - TypeCacheEntry *hashResultTypeEntry = NULL; - Oid operatorId = InvalidOid; - OpExpr *hashedExpression = NULL; - Var *hashedColumn = NULL; - Datum hashedValue = 0; - Const *hashedConstant = NULL; - FmgrInfo *hashFunction = NULL; - TypeCacheEntry *typeEntry = NULL; - - Node *leftOperand = get_leftop((Expr *) operatorExpression); - Node *rightOperand = get_rightop((Expr *) operatorExpression); - Const *constant = NULL; - - if (IsA(rightOperand, Const)) - { - constant = (Const *) rightOperand; - } - else - { - constant = (Const *) leftOperand; - } - - /* Load the operator from type cache */ - hashResultTypeEntry = lookup_type_cache(hashResultTypeId, TYPECACHE_EQ_OPR); - operatorId = hashResultTypeEntry->eq_opr; - - /* Get a column with int4 type */ - hashedColumn = MakeInt4Column(); - - /* Load the hash function from type cache */ - typeEntry = lookup_type_cache(constant->consttype, TYPECACHE_HASH_PROC_FINFO); - hashFunction = &(typeEntry->hash_proc_finfo); - if (!OidIsValid(hashFunction->fn_oid)) - { - ereport(ERROR, (errcode(ERRCODE_UNDEFINED_FUNCTION), - errmsg("could not identify a hash function for type %s", - format_type_be(constant->consttype)), - errdatatype(constant->consttype))); - } - - /* - * Note that any changes to PostgreSQL's hashing functions will change the - * new value created by this function. 
- */ - hashedValue = FunctionCall1(hashFunction, constant->constvalue); - hashedConstant = MakeInt4Constant(hashedValue); - - /* Now create the expression with modified partition column and hashed constant */ - hashedExpression = (OpExpr *) make_opclause(operatorId, - InvalidOid, /* no result type yet */ - false, /* no return set */ - (Expr *) hashedColumn, - (Expr *) hashedConstant, - InvalidOid, InvalidOid); - - /* Set implementing function id and result type */ - hashedExpression->opfuncid = get_opcode(operatorId); - hashedExpression->opresulttype = get_func_rettype(hashedExpression->opfuncid); - - return hashedExpression; -} - - /* * MakeInt4Column creates a column of int4 type with invalid table id and max * attribute number. @@ -3130,30 +2858,6 @@ MakeInt4Constant(Datum constantValue) } -/* - * BuildRestrictInfoList builds restrict info list using the selection criteria, - * and then return this list. Note that this function assumes there is only one - * relation for now. - */ -static List * -BuildRestrictInfoList(List *qualList) -{ - List *restrictInfoList = NIL; - ListCell *qualCell = NULL; - - foreach(qualCell, qualList) - { - RestrictInfo *restrictInfo = NULL; - Node *qualNode = (Node *) lfirst(qualCell); - - restrictInfo = make_simple_restrictinfo((Expr *) qualNode); - restrictInfoList = lappend(restrictInfoList, restrictInfo); - } - - return restrictInfoList; -} - - /* Updates the base constraint with the given min/max values. */ void UpdateConstraint(Node *baseConstraint, ShardInterval *shardInterval) diff --git a/src/backend/distributed/planner/multi_router_planner.c b/src/backend/distributed/planner/multi_router_planner.c index 5c765c2d8..2f563cfbe 100644 --- a/src/backend/distributed/planner/multi_router_planner.c +++ b/src/backend/distributed/planner/multi_router_planner.c @@ -39,6 +39,7 @@ #include "distributed/relay_utility.h" #include "distributed/resource_lock.h" #include "distributed/shardinterval_utils.h" +#include "distributed/shard_pruning.h" #include "executor/execdesc.h" #include "lib/stringinfo.h" #include "nodes/makefuncs.h" @@ -533,8 +534,13 @@ RouterModifyTaskForShardInterval(Query *originalQuery, ShardInterval *shardInter * items in it. The list consists of shard interval ranges with hashed columns * such as (hashColumn >= shardMinValue) and (hashedColumn <= shardMaxValue). * - * The function errors out if the given shard interval does not belong to a hash - * distributed table. + * The function returns hashed columns generated by MakeInt4Column() for the hash + * partitioned tables in place of partition columns. + * + * The function errors out if the given shard interval does not belong to a hash, + * range and append distributed tables. + * + * NB: If you update this, also look at PrunableExpressionsWalker(). 
*/ static List * HashedShardIntervalOpExpressions(ShardInterval *shardInterval) @@ -2070,10 +2076,8 @@ TargetShardIntervalForModify(Query *query, DeferredErrorMessage **planningError) { List *restrictClauseList = QueryRestrictList(query); Index tableId = 1; - List *shardIntervalList = LoadShardIntervalList(distributedTableId); - prunedShardList = PruneShardList(distributedTableId, tableId, restrictClauseList, - shardIntervalList); + prunedShardList = PruneShards(distributedTableId, tableId, restrictClauseList); } prunedShardCount = list_length(prunedShardList); @@ -2467,7 +2471,6 @@ TargetShardIntervalsForSelect(Query *query, List *baseRestrictionList = relationRestriction->relOptInfo->baserestrictinfo; List *restrictClauseList = get_all_actual_clauses(baseRestrictionList); List *prunedShardList = NIL; - int shardIndex = 0; List *joinInfoList = relationRestriction->relOptInfo->joininfo; List *pseudoRestrictionList = extract_actual_clauses(joinInfoList, true); bool whereFalseQuery = false; @@ -2483,18 +2486,8 @@ TargetShardIntervalsForSelect(Query *query, whereFalseQuery = ContainsFalseClause(pseudoRestrictionList); if (!whereFalseQuery && shardCount > 0) { - List *shardIntervalList = NIL; - - for (shardIndex = 0; shardIndex < shardCount; shardIndex++) - { - ShardInterval *shardInterval = - cacheEntry->sortedShardIntervalArray[shardIndex]; - shardIntervalList = lappend(shardIntervalList, shardInterval); - } - - prunedShardList = PruneShardList(relationId, tableId, - restrictClauseList, - shardIntervalList); + prunedShardList = PruneShards(relationId, tableId, + restrictClauseList); /* * Quick bail out. The query can not be router plannable if one diff --git a/src/backend/distributed/planner/shard_pruning.c b/src/backend/distributed/planner/shard_pruning.c new file mode 100644 index 000000000..807b4f71e --- /dev/null +++ b/src/backend/distributed/planner/shard_pruning.c @@ -0,0 +1,1319 @@ +/*------------------------------------------------------------------------- + * + * shard_pruning.c + * Shard pruning related code. + * + * The goal of shard pruning is to find a minimal (super)set of shards that + * need to be queried to find rows matching the expression in a query. + * + * In PruneShards, we first compute a simplified disjunctive normal form (DNF) + * of the expression as a list of pruning instances. Each pruning instance + * contains all AND-ed constraints on the partition column. An OR expression + * will result in two or more new pruning instances being added for the + * subexpressions. The "parent" instance is marked isPartial and ignored + * during pruning. + * + * We use the distributive property for constraints of the form P AND (Q OR R) + * to rewrite it to (P AND Q) OR (P AND R) by copying constraints from parent + * to "child" pruning instances. However, we do not distribute nested + * expressions. While (P OR Q) AND (R OR S) is logically equivalent to (P AND + * R) OR (P AND S) OR (Q AND R) OR (Q AND S), in our implementation it becomes + * P OR Q OR R OR S. This is acceptable since this will always result in a + * superset of shards. If this proves to be a issue in practice, a more + * complete algorithm could be implemented. 
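+ *
+ * For example, for a < 3 AND (a > 5 OR a > 7) (the same expression used
+ * in the PruningInstance comment below), this produces the two complete
+ * pruning instances {a < 3, a > 5} and {a < 3, a > 7}; the original,
+ * partially built {a < 3} is marked isPartial and skipped.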
+ * + * We then evaluate each non-partial pruning instance in the disjunction + * through the following, increasingly expensive, steps: + * + * 1) If there is a constant equality constraint on the partition column, and + * no overlapping shards exist, find the shard interval in which the + * constant falls + * + * 2) If there is a hash range constraint on the partition column, find the + * shard interval matching the range + * + * 3) If there are range constraints (e.g. (a > 0 AND a < 10)) on the + * partition column, find the shard intervals that overlap with the range + * + * 4) If there are overlapping shards, exhaustively search all shards that are + * not excluded by constraints + * + * Finally, the union of the shards found by each pruning instance is + * returned. + * + * Copyright (c) 2014-2017, Citus Data, Inc. + * + *------------------------------------------------------------------------- + */ +#include "postgres.h" + +#include "distributed/shard_pruning.h" + +#include "access/nbtree.h" +#include "catalog/pg_am.h" +#include "catalog/pg_collation.h" +#include "catalog/pg_type.h" +#include "distributed/metadata_cache.h" +#include "distributed/multi_planner.h" +#include "distributed/multi_join_order.h" +#include "distributed/multi_physical_planner.h" +#include "distributed/shardinterval_utils.h" +#include "distributed/pg_dist_partition.h" +#include "distributed/worker_protocol.h" +#include "nodes/nodeFuncs.h" +#include "optimizer/clauses.h" +#include "utils/catcache.h" +#include "utils/lsyscache.h" +#include "utils/memutils.h" + +/* + * A pruning instance is a set of ANDed constraints on a partition key. + */ +typedef struct PruningInstance +{ + /* Does this instance contain any prunable expressions? */ + bool hasValidConstraint; + + /* + * This constraint never evaluates to true, i.e. pruning does not have to + * be performed. + */ + bool evaluatesToFalse; + + /* + * Constraints on the partition column value. If multiple values are + * found the more restrictive one should be stored here. Even in case of + * a hash-partitioned table, actual column-values are stored here, *not* + * hashed values. + */ + Const *lessConsts; + Const *lessEqualConsts; + Const *equalConsts; + Const *greaterEqualConsts; + Const *greaterConsts; + + /* + * Constraint using a pre-hashed column value. The constant will store the + * hashed value, not the original value of the restriction. + */ + Const *hashedEqualConsts; + + /* + * Types of constraints not understood. We could theoretically try more + * expensive methods of pruning if any such restrictions are found. + * + * TODO: any actual use for this? Right now there seems little point. + */ + List *otherRestrictions; + + /* + * Has this PruningInstance been added to + * ClauseWalkerContext->pruningInstances? This is not done immediately, + * but the first time a constraint (independent of us being able to handle + * that constraint) is found. + */ + bool addedToPruningInstances; + + /* + * When OR clauses are found, the non-ORed part (think of a < 3 AND (a > 5 + * OR a > 7)) of the expression is stored in one PruningInstance which is + * then copied for the ORed expressions. The original is marked as + * isPartial, to avoid it being used for pruning. + */ + bool isPartial; +} PruningInstance; + + +/* + * Partial instances that need to be finished building. This is used to + * collect all ANDed restrictions, before looking into ORed expressions. 
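+ *
+ * instance is the partially built PruningInstance to be copied and
+ * completed; continueAt points at the ORed subexpression at which the
+ * tree walk is to be resumed.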
+ */ +typedef struct PendingPruningInstance +{ + PruningInstance *instance; + Node *continueAt; +} PendingPruningInstance; + + +/* + * Data necessary to perform a single PruneShards(). + */ +typedef struct ClauseWalkerContext +{ + Var *partitionColumn; + char partitionMethod; + + /* ORed list of pruning targets */ + List *pruningInstances; + + /* + * Partially built PruningInstances, that need to be completed by doing a + * separate PrunableExpressionsWalker() pass. + */ + List *pendingInstances; + + /* PruningInstance currently being built, all elegible constraints are added here */ + PruningInstance *currentPruningInstance; + + /* + * Information about function calls we need to perform. Re-using the same + * FunctionCallInfoData, instead of using FunctionCall2Coll, is often + * cheaper. + */ + FunctionCallInfoData compareValueFunctionCall; + FunctionCallInfoData compareIntervalFunctionCall; +} ClauseWalkerContext; + +static void PrunableExpressions(Node *originalNode, ClauseWalkerContext *context); +static bool PrunableExpressionsWalker(Node *originalNode, ClauseWalkerContext *context); +static void AddPartitionKeyRestrictionToInstance(ClauseWalkerContext *context, + OpExpr *opClause, Var *varClause, + Const *constantClause); +static void AddHashRestrictionToInstance(ClauseWalkerContext *context, OpExpr *opClause, + Var *varClause, Const *constantClause); +static PruningInstance * CopyPartialPruningInstance(PruningInstance *sourceInstance); +static List * ShardArrayToList(ShardInterval **shardArray, int length); +static List * DeepCopyShardIntervalList(List *originalShardIntervalList); +static int PerformValueCompare(FunctionCallInfoData *compareFunctionCall, Datum a, + Datum b); +static int PerformCompare(FunctionCallInfoData *compareFunctionCall); + +static List * PruneOne(DistTableCacheEntry *cacheEntry, ClauseWalkerContext *context, + PruningInstance *prune); +static List * PruneWithBoundaries(DistTableCacheEntry *cacheEntry, + ClauseWalkerContext *context, + PruningInstance *prune); +static List * ExhaustivePrune(DistTableCacheEntry *cacheEntry, + ClauseWalkerContext *context, + PruningInstance *prune); +static int UpperShardBoundary(Datum partitionColumnValue, + ShardInterval **shardIntervalCache, + int shardCount, FunctionCallInfoData *compareFunction, + bool includeMin); +static int LowerShardBoundary(Datum partitionColumnValue, + ShardInterval **shardIntervalCache, + int shardCount, FunctionCallInfoData *compareFunction, + bool includeMax); + + +/* + * PruneShards returns all shards from a distributed table that cannot be + * proven to be eliminated by whereClauseList. + * + * For reference tables, the function simply returns the single shard that the + * table has. + */ +List * +PruneShards(Oid relationId, Index rangeTableId, List *whereClauseList) +{ + DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(relationId); + char partitionMethod = cacheEntry->partitionMethod; + ClauseWalkerContext context = { 0 }; + ListCell *pruneCell; + List *prunedList = NIL; + bool foundRestriction = false; + + /* always return empty result if WHERE clause is of the form: false (AND ..) 
*/ + if (ContainsFalseClause(whereClauseList)) + { + return NIL; + } + + /* short circuit for reference tables */ + if (partitionMethod == DISTRIBUTE_BY_NONE) + { + prunedList = ShardArrayToList(cacheEntry->sortedShardIntervalArray, + cacheEntry->shardIntervalArrayLength); + return DeepCopyShardIntervalList(prunedList); + } + + + context.partitionMethod = partitionMethod; + context.partitionColumn = PartitionColumn(relationId, rangeTableId); + context.currentPruningInstance = palloc0(sizeof(PruningInstance)); + + if (cacheEntry->shardIntervalCompareFunction) + { + /* initiate function call info once (allows comparators to cache metadata) */ + InitFunctionCallInfoData(context.compareIntervalFunctionCall, + cacheEntry->shardIntervalCompareFunction, + 2, DEFAULT_COLLATION_OID, NULL, NULL); + } + else + { + ereport(ERROR, (errmsg("shard pruning not possible without " + "a shard interval comparator"))); + } + + if (cacheEntry->shardColumnCompareFunction) + { + /* initiate function call info once (allows comparators to cache metadata) */ + InitFunctionCallInfoData(context.compareValueFunctionCall, + cacheEntry->shardColumnCompareFunction, + 2, DEFAULT_COLLATION_OID, NULL, NULL); + } + else + { + ereport(ERROR, (errmsg("shard pruning not possible without " + "a partition column comparator"))); + } + + /* Figure out what we can prune on */ + PrunableExpressions((Node *) whereClauseList, &context); + + /* + * Prune using each of the PrunableInstances we found, and OR results + * together. + */ + foreach(pruneCell, context.pruningInstances) + { + PruningInstance *prune = (PruningInstance *) lfirst(pruneCell); + List *pruneOneList; + + /* + * If this is a partial instance, a fully built one has also been + * added. Skip. + */ + if (prune->isPartial) + { + continue; + } + + /* + * If the current instance has no prunable expressions, we'll have to + * return all shards. No point in continuing pruning in that case. + */ + if (!prune->hasValidConstraint) + { + foundRestriction = false; + break; + } + + /* + * Similar to the above, if hash-partitioned and there's nothing to + * prune by, we're done. + */ + if (context.partitionMethod == DISTRIBUTE_BY_HASH && + !prune->evaluatesToFalse && !prune->equalConsts && !prune->hashedEqualConsts) + { + foundRestriction = false; + break; + } + + pruneOneList = PruneOne(cacheEntry, &context, prune); + + if (prunedList) + { + /* + * We can use list_union_ptr, which is a lot faster than doing + * comparing shards by value, because all the ShardIntervals are + * guaranteed to be from + * DistTableCacheEntry->sortedShardIntervalArray (thus having the + * same pointer values). + */ + prunedList = list_union_ptr(prunedList, pruneOneList); + } + else + { + prunedList = pruneOneList; + } + foundRestriction = true; + } + + /* found no valid restriction, build list of all shards */ + if (!foundRestriction) + { + prunedList = ShardArrayToList(cacheEntry->sortedShardIntervalArray, + cacheEntry->shardIntervalArrayLength); + } + + /* + * Deep copy list, so it's independent of the DistTableCacheEntry + * contents. + */ + return DeepCopyShardIntervalList(prunedList); +} + + +/* + * ContainsFalseClause returns whether the flattened where clause list + * contains false as a clause. 
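+ *
+ * E.g. the flattened clause list for WHERE false, or for WHERE false
+ * AND a = 1 after constant folding, contains a bare boolean false
+ * Const, allowing pruning to bail out early with an empty result.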
+ */ +bool +ContainsFalseClause(List *whereClauseList) +{ + bool containsFalseClause = false; + ListCell *clauseCell = NULL; + + foreach(clauseCell, whereClauseList) + { + Node *clause = (Node *) lfirst(clauseCell); + + if (IsA(clause, Const)) + { + Const *constant = (Const *) clause; + if (constant->consttype == BOOLOID && !DatumGetBool(constant->constvalue)) + { + containsFalseClause = true; + break; + } + } + } + + return containsFalseClause; +} + + +/* + * PrunableExpressions builds a list of all prunable expressions in node, + * storing them in context->pruningInstances. + */ +static void +PrunableExpressions(Node *node, ClauseWalkerContext *context) +{ + /* + * Build initial list of prunable expressions. As long as only, + * implicitly or explicitly, ANDed expressions are found, this perform a + * depth-first search. When an ORed expression is found, the current + * PruningInstance is added to context->pruningInstances (once for each + * ORed expression), then the tree-traversal is continued without + * recursing. Once at the top-level again, we'll process all pending + * expressions - that allows us to find all ANDed expressions, before + * recursing into an ORed expression. + */ + PrunableExpressionsWalker(node, context); + + /* + * Process all pending instances. While processing, new ones might be + * added to the list, so don't use foreach(). + * + * Check the places in PruningInstanceWalker that push onto + * context->pendingInstances why construction of the PruningInstance might + * be pending. + * + * We copy the partial PruningInstance, and continue adding information by + * calling PrunableExpressionsWalker() on the copy, continuing at the the + * node stored in PendingPruningInstance->continueAt. + */ + while (context->pendingInstances != NIL) + { + PendingPruningInstance *instance = + (PendingPruningInstance *) linitial(context->pendingInstances); + PruningInstance *newPrune = CopyPartialPruningInstance(instance->instance); + + context->pendingInstances = list_delete_first(context->pendingInstances); + + context->currentPruningInstance = newPrune; + PrunableExpressionsWalker(instance->continueAt, context); + context->currentPruningInstance = NULL; + } +} + + +/* + * PrunableExpressionsWalker() is the main work horse for + * PrunableExpressions(). + */ +static bool +PrunableExpressionsWalker(Node *node, ClauseWalkerContext *context) +{ + if (node == NULL) + { + return false; + } + + /* + * Check for expressions understood by this routine. + */ + if (IsA(node, List)) + { + /* at the top of quals we'll frequently see lists, those are to be treated as ANDs */ + } + else if (IsA(node, BoolExpr)) + { + BoolExpr *boolExpr = (BoolExpr *) node; + + if (boolExpr->boolop == NOT_EXPR) + { + return false; + } + else if (boolExpr->boolop == AND_EXPR) + { + return expression_tree_walker((Node *) boolExpr->args, + PrunableExpressionsWalker, context); + } + else if (boolExpr->boolop == OR_EXPR) + { + ListCell *opCell = NULL; + + /* + * "Queue" partial pruning instances. This is used to convert + * expressions like (A AND (B OR C) AND D) into (A AND B AND D), + * (A AND C AND D), with A, B, C, D being restrictions. When the + * OR is encountered, a reference to the partially built + * PruningInstance (containing A at this point), is added to + * context->pendingInstances once for B and once for C. 
Once a + * full tree-walk completed, PrunableExpressions() will complete + * the pending instances, which'll now also know about restriction + * D, by calling PrunableExpressionsWalker() once for B and once + * for C. + */ + foreach(opCell, boolExpr->args) + { + PendingPruningInstance *instance = + palloc0(sizeof(PendingPruningInstance)); + + instance->instance = context->currentPruningInstance; + instance->continueAt = lfirst(opCell); + + /* + * Signal that this instance is not to be used for pruning on + * its own. Once the pending instance is processed, it'll be + * used. + */ + instance->instance->isPartial = true; + + context->pendingInstances = lappend(context->pendingInstances, instance); + } + + return false; + } + } + else if (IsA(node, OpExpr)) + { + OpExpr *opClause = (OpExpr *) node; + PruningInstance *prune = context->currentPruningInstance; + Node *leftOperand = NULL; + Node *rightOperand = NULL; + Const *constantClause = NULL; + Var *varClause = NULL; + + if (!prune->addedToPruningInstances) + { + context->pruningInstances = lappend(context->pruningInstances, + prune); + prune->addedToPruningInstances = true; + } + + if (list_length(opClause->args) == 2) + { + leftOperand = get_leftop((Expr *) opClause); + rightOperand = get_rightop((Expr *) opClause); + + leftOperand = strip_implicit_coercions(leftOperand); + rightOperand = strip_implicit_coercions(rightOperand); + + if (IsA(rightOperand, Const) && IsA(leftOperand, Var)) + { + constantClause = (Const *) rightOperand; + varClause = (Var *) leftOperand; + } + else if (IsA(leftOperand, Const) && IsA(rightOperand, Var)) + { + constantClause = (Const *) leftOperand; + varClause = (Var *) rightOperand; + } + } + + if (constantClause && varClause && equal(varClause, context->partitionColumn)) + { + /* + * Found a restriction on the partition column itself. Update the + * current constraint with the new information. + */ + AddPartitionKeyRestrictionToInstance(context, + opClause, varClause, constantClause); + } + else if (constantClause && varClause && + varClause->varattno == RESERVED_HASHED_COLUMN_ID) + { + /* + * Found restriction that directly specifies the boundaries of a + * hashed column. + */ + AddHashRestrictionToInstance(context, opClause, varClause, constantClause); + } + + return false; + } + else if (IsA(node, ScalarArrayOpExpr)) + { + PruningInstance *prune = context->currentPruningInstance; + ScalarArrayOpExpr *arrayOperatorExpression = (ScalarArrayOpExpr *) node; + Node *leftOpExpression = linitial(arrayOperatorExpression->args); + Node *strippedLeftOpExpression = strip_implicit_coercions(leftOpExpression); + bool usingEqualityOperator = OperatorImplementsEquality( + arrayOperatorExpression->opno); + + /* + * Citus cannot prune hash-distributed shards with ANY/ALL. We show a NOTICE + * if the expression is ANY/ALL performed on the partition column with equality. + * + * TODO: this'd now be easy to implement, similar to the OR_EXPR case + * above, except that one would push an appropriately constructed + * OpExpr(LHS = $array_element) as continueAt. + */ + if (usingEqualityOperator && strippedLeftOpExpression != NULL && + equal(strippedLeftOpExpression, context->partitionColumn)) + { + ereport(NOTICE, (errmsg("cannot use shard pruning with " + "ANY/ALL (array expression)"), + errhint("Consider rewriting the expression with " + "OR/AND clauses."))); + } + + /* + * Mark expression as added, so we'll fail pruning if there's no ANDed + * restrictions that we can deal with. 
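+ *
+ * E.g. the clause partition_col = ANY ('{1,2,3}') (hypothetical
+ * literals) currently only raises the NOTICE above; rewritten as
+ * (partition_col = 1 OR partition_col = 2 OR partition_col = 3), the
+ * same condition is pruned via the OR_EXPR path instead.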
+ */ + if (!prune->addedToPruningInstances) + { + context->pruningInstances = lappend(context->pruningInstances, + prune); + prune->addedToPruningInstances = true; + } + + return false; + } + else + { + PruningInstance *prune = context->currentPruningInstance; + + /* + * Mark expression as added, so we'll fail pruning if there's no ANDed + * restrictions that we know how to deal with. + */ + if (!prune->addedToPruningInstances) + { + context->pruningInstances = lappend(context->pruningInstances, + prune); + prune->addedToPruningInstances = true; + } + + return false; + } + + return expression_tree_walker(node, PrunableExpressionsWalker, context); +} + + +/* + * AddPartitionKeyRestrictionToInstance adds information about a PartitionKey + * $op Const restriction to the current pruning instance. + */ +static void +AddPartitionKeyRestrictionToInstance(ClauseWalkerContext *context, OpExpr *opClause, + Var *varClause, Const *constantClause) +{ + PruningInstance *prune = context->currentPruningInstance; + List *btreeInterpretationList = NULL; + ListCell *btreeInterpretationCell = NULL; + bool matchedOp = false; + + btreeInterpretationList = + get_op_btree_interpretation(opClause->opno); + foreach(btreeInterpretationCell, btreeInterpretationList) + { + OpBtreeInterpretation *btreeInterpretation = + (OpBtreeInterpretation *) lfirst(btreeInterpretationCell); + + switch (btreeInterpretation->strategy) + { + case BTLessStrategyNumber: + { + if (!prune->lessConsts || + PerformValueCompare(&context->compareValueFunctionCall, + constantClause->constvalue, + prune->lessConsts->constvalue) < 0) + { + prune->lessConsts = constantClause; + } + matchedOp = true; + } + break; + + case BTLessEqualStrategyNumber: + { + if (!prune->lessEqualConsts || + PerformValueCompare(&context->compareValueFunctionCall, + constantClause->constvalue, + prune->lessEqualConsts->constvalue) < 0) + { + prune->lessEqualConsts = constantClause; + } + matchedOp = true; + } + break; + + case BTEqualStrategyNumber: + { + if (!prune->equalConsts) + { + prune->equalConsts = constantClause; + } + else if (PerformValueCompare(&context->compareValueFunctionCall, + constantClause->constvalue, + prune->equalConsts->constvalue) != 0) + { + /* key can't be equal to two values */ + prune->evaluatesToFalse = true; + } + matchedOp = true; + } + break; + + case BTGreaterEqualStrategyNumber: + { + if (!prune->greaterEqualConsts || + PerformValueCompare(&context->compareValueFunctionCall, + constantClause->constvalue, + prune->greaterEqualConsts->constvalue) > 0 + ) + { + prune->greaterEqualConsts = constantClause; + } + matchedOp = true; + } + break; + + case BTGreaterStrategyNumber: + { + if (!prune->greaterConsts || + PerformValueCompare(&context->compareValueFunctionCall, + constantClause->constvalue, + prune->greaterConsts->constvalue) > 0) + { + prune->greaterConsts = constantClause; + } + matchedOp = true; + } + break; + + case ROWCOMPARE_NE: + { + /* TODO: could add support for this, if we feel like it */ + matchedOp = false; + } + break; + + default: + Assert(false); + } + } + + if (!matchedOp) + { + prune->otherRestrictions = + lappend(prune->otherRestrictions, opClause); + } + else + { + prune->hasValidConstraint = true; + } +} + + +/* + * AddHashRestrictionToInstance adds information about a + * RESERVED_HASHED_COLUMN_ID = Const restriction to the current pruning + * instance. 
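+ *
+ * Such restrictions are generated internally rather than written by
+ * users, e.g. by HashedShardIntervalOpExpressions() in
+ * multi_router_planner.c, which emits (hashedColumn >= shardMinValue)
+ * and (hashedColumn <= shardMaxValue) clauses over the int4 column
+ * built by MakeInt4Column().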
+ */ +static void +AddHashRestrictionToInstance(ClauseWalkerContext *context, OpExpr *opClause, + Var *varClause, Const *constantClause) +{ + PruningInstance *prune = context->currentPruningInstance; + List *btreeInterpretationList = NULL; + ListCell *btreeInterpretationCell = NULL; + + btreeInterpretationList = + get_op_btree_interpretation(opClause->opno); + foreach(btreeInterpretationCell, btreeInterpretationList) + { + OpBtreeInterpretation *btreeInterpretation = + (OpBtreeInterpretation *) lfirst(btreeInterpretationCell); + + /* + * Ladidadida, dirty hackety hack. We only add such + * constraints (in ShardIntervalOpExpressions()) to select a + * shard based on its exact boundaries. For efficient binary + * search it's better to simply use one representative value + * to look up the shard. In practice, this is sufficient for + * now. + */ + if (btreeInterpretation->strategy == BTGreaterEqualStrategyNumber) + { + Assert(!prune->hashedEqualConsts); + prune->hashedEqualConsts = constantClause; + prune->hasValidConstraint = true; + } + } +} + + +/* + * CopyPartialPruningInstance copies a partial PruningInstance, so it can be + * completed. + */ +static PruningInstance * +CopyPartialPruningInstance(PruningInstance *sourceInstance) +{ + PruningInstance *newInstance = palloc(sizeof(PruningInstance)); + + Assert(sourceInstance->isPartial); + + /* + * To make the new PruningInstance useful for pruning, we have to reset it + * being partial - if necessary it'll be marked so again by + * PrunableExpressionsWalker(). + */ + memcpy(newInstance, sourceInstance, sizeof(PruningInstance)); + newInstance->addedToPruningInstances = false; + newInstance->isPartial = false; + + return newInstance; +} + + +/* + * ShardArrayToList builds a list of out the array of ShardInterval*. + */ +static List * +ShardArrayToList(ShardInterval **shardArray, int length) +{ + List *shardIntervalList = NIL; + int shardIndex; + + for (shardIndex = 0; shardIndex < length; shardIndex++) + { + ShardInterval *shardInterval = + shardArray[shardIndex]; + shardIntervalList = lappend(shardIntervalList, shardInterval); + } + + return shardIntervalList; +} + + +/* + * DeepCopyShardIntervalList copies originalShardIntervalList and the + * contained ShardIntervals, into a new list. + */ +static List * +DeepCopyShardIntervalList(List *originalShardIntervalList) +{ + List *copiedShardIntervalList = NIL; + ListCell *shardIntervalCell = NULL; + + foreach(shardIntervalCell, originalShardIntervalList) + { + ShardInterval *originalShardInterval = + (ShardInterval *) lfirst(shardIntervalCell); + ShardInterval *copiedShardInterval = + (ShardInterval *) palloc0(sizeof(ShardInterval)); + + CopyShardInterval(originalShardInterval, copiedShardInterval); + copiedShardIntervalList = lappend(copiedShardIntervalList, copiedShardInterval); + } + + return copiedShardIntervalList; +} + + +/* + * PruneOne returns all shards in the table that match a single + * PruningInstance. + */ +static List * +PruneOne(DistTableCacheEntry *cacheEntry, ClauseWalkerContext *context, + PruningInstance *prune) +{ + ShardInterval *shardInterval = NULL; + + /* Well, if life always were this easy... */ + if (prune->evaluatesToFalse) + { + return NIL; + } + + /* + * For an equal constraints, if there's no overlapping shards (always the + * case for hash and range partitioning, sometimes for append), can + * perform binary search for the right interval. That's usually the + * fastest, so try that first. 
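+	 *
+	 * E.g. (hypothetical ranges): given the non-overlapping intervals
+	 * [0,9], [10,19] and [20,29], the constraint a = 15 is resolved by
+	 * FindShardInterval() below to the single [10,19] shard.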
+ */ + if (prune->equalConsts && + !cacheEntry->hasOverlappingShardInterval) + { + shardInterval = FindShardInterval(prune->equalConsts->constvalue, cacheEntry); + + /* + * If pruned down to nothing, we're done. Otherwise see if other + * methods prune down further / to nothing. + */ + if (!shardInterval) + { + return NIL; + } + } + + /* + * If the hash value we're looking for is known, we can search for the + * interval directly. That's fast and should only ever be the case for a + * hash-partitioned table. + */ + if (prune->hashedEqualConsts) + { + int shardIndex = INVALID_SHARD_INDEX; + ShardInterval **sortedShardIntervalArray = cacheEntry->sortedShardIntervalArray; + + Assert(context->partitionMethod == DISTRIBUTE_BY_HASH); + + shardIndex = FindShardIntervalIndex(prune->hashedEqualConsts->constvalue, + cacheEntry); + + if (shardIndex == INVALID_SHARD_INDEX) + { + return NIL; + } + else if (shardInterval && + sortedShardIntervalArray[shardIndex]->shardId != shardInterval->shardId) + { + /* + * equalConst based pruning above yielded a different shard than + * pruning based on pre-hashed equality. This is useful in case + * of INSERT ... SELECT, where both can occur together (one via + * join/colocation, the other via a plain equality restriction). + */ + return NIL; + } + else + { + return list_make1(sortedShardIntervalArray[shardIndex]); + } + } + + /* + * If previous pruning method yielded a single shard, we could also + * attempt range based pruning to exclude it further. But that seems + * rarely useful in practice, and thus likely a waste of runtime and code + * complexity. + */ + if (shardInterval) + { + return list_make1(shardInterval); + } + + /* + * Should never get here for hashing, we've filtered down to either zero + * or one shard, and returned. + */ + Assert(context->partitionMethod != DISTRIBUTE_BY_HASH); + + /* + * Next method: binary search with fuzzy boundaries. Can't trivially do so + * if shards have overlapping boundaries. + * + * TODO: If we kept shard intervals separately sorted by both upper and + * lower boundaries, this should be possible? + */ + if (!cacheEntry->hasOverlappingShardInterval && ( + prune->greaterConsts || prune->greaterEqualConsts || + prune->lessConsts || prune->lessEqualConsts)) + { + return PruneWithBoundaries(cacheEntry, context, prune); + } + + /* + * Brute force: Check each shard. + */ + return ExhaustivePrune(cacheEntry, context, prune); +} + + +/* + * PerformCompare invokes comparator with prepared values, check for + * unexpected NULL returns. + */ +static int +PerformCompare(FunctionCallInfoData *compareFunctionCall) +{ + Datum result = FunctionCallInvoke(compareFunctionCall); + + if (compareFunctionCall->isnull) + { + elog(ERROR, "function %u returned NULL", compareFunctionCall->flinfo->fn_oid); + } + + return DatumGetInt32(result); +} + + +/* + * PerformValueCompare invokes comparator with a/b, and checks for unexpected + * NULL returns. + */ +static int +PerformValueCompare(FunctionCallInfoData *compareFunctionCall, Datum a, Datum b) +{ + compareFunctionCall->arg[0] = a; + compareFunctionCall->argnull[0] = false; + compareFunctionCall->arg[1] = b; + compareFunctionCall->argnull[1] = false; + + return PerformCompare(compareFunctionCall); +} + + +/* + * LowerShardBoundary returns the index of the first ShardInterval that's >= + * (if includeMax) or > partitionColumnValue. 
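+ *
+ * E.g. (hypothetical ranges): for sorted, non-overlapping intervals
+ * [0,9], [10,19] and [20,29], a searched value of 10 yields index 1,
+ * while a value of 35, beyond all intervals, yields
+ * INVALID_SHARD_INDEX.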
+ */ +static int +LowerShardBoundary(Datum partitionColumnValue, ShardInterval **shardIntervalCache, + int shardCount, FunctionCallInfoData *compareFunction, bool includeMax) +{ + int lowerBoundIndex = 0; + int upperBoundIndex = shardCount; + + Assert(shardCount != 0); + + /* setup partitionColumnValue argument once */ + compareFunction->arg[0] = partitionColumnValue; + compareFunction->argnull[0] = false; + + while (lowerBoundIndex < upperBoundIndex) + { + int middleIndex = lowerBoundIndex + ((upperBoundIndex - lowerBoundIndex) / 2); + int maxValueComparison = 0; + int minValueComparison = 0; + + /* setup minValue as argument */ + compareFunction->arg[1] = shardIntervalCache[middleIndex]->minValue; + compareFunction->argnull[1] = false; + + /* execute cmp(partitionValue, lowerBound) */ + minValueComparison = PerformCompare(compareFunction); + + /* and evaluate results */ + if (minValueComparison < 0) + { + /* value smaller than entire range */ + upperBoundIndex = middleIndex; + continue; + } + + /* setup maxValue as argument */ + compareFunction->arg[1] = shardIntervalCache[middleIndex]->maxValue; + compareFunction->argnull[1] = false; + + /* execute cmp(partitionValue, upperBound) */ + maxValueComparison = PerformCompare(compareFunction); + + if ((maxValueComparison == 0 && !includeMax) || + maxValueComparison > 0) + { + /* value bigger than entire range */ + lowerBoundIndex = middleIndex + 1; + continue; + } + + /* found interval containing partitionValue */ + return middleIndex; + } + + Assert(lowerBoundIndex == upperBoundIndex); + + /* + * If we get here, none of the ShardIntervals exactly contain the value + * (we'd have hit the return middleIndex; case otherwise). Figure out + * whether there's possibly any interval containing a value that's bigger + * than the partition key one. + */ + if (lowerBoundIndex == 0) + { + /* all intervals are bigger, thus return 0 */ + return 0; + } + else if (lowerBoundIndex == shardCount) + { + /* partition value is bigger than all partition values */ + return INVALID_SHARD_INDEX; + } + + /* value falls inbetween intervals */ + return lowerBoundIndex + 1; +} + + +/* + * UpperShardBoundary returns the index of the last ShardInterval that's <= + * (if includeMin) or < partitionColumnValue. 
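+ *
+ * E.g. (hypothetical ranges): for the same intervals [0,9], [10,19]
+ * and [20,29], a searched value of 19 yields index 1, while a value
+ * below all intervals, such as -5, yields INVALID_SHARD_INDEX.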
+ */ +static int +UpperShardBoundary(Datum partitionColumnValue, ShardInterval **shardIntervalCache, + int shardCount, FunctionCallInfoData *compareFunction, bool includeMin) +{ + int lowerBoundIndex = 0; + int upperBoundIndex = shardCount; + + Assert(shardCount != 0); + + /* setup partitionColumnValue argument once */ + compareFunction->arg[0] = partitionColumnValue; + compareFunction->argnull[0] = false; + + while (lowerBoundIndex < upperBoundIndex) + { + int middleIndex = lowerBoundIndex + ((upperBoundIndex - lowerBoundIndex) / 2); + int maxValueComparison = 0; + int minValueComparison = 0; + + /* setup minValue as argument */ + compareFunction->arg[1] = shardIntervalCache[middleIndex]->minValue; + compareFunction->argnull[1] = false; + + /* execute cmp(partitionValue, lowerBound) */ + minValueComparison = PerformCompare(compareFunction); + + /* and evaluate results */ + if ((minValueComparison == 0 && !includeMin) || + minValueComparison < 0) + { + /* value smaller than entire range */ + upperBoundIndex = middleIndex; + continue; + } + + /* setup maxValue as argument */ + compareFunction->arg[1] = shardIntervalCache[middleIndex]->maxValue; + compareFunction->argnull[1] = false; + + /* execute cmp(partitionValue, upperBound) */ + maxValueComparison = PerformCompare(compareFunction); + + if (maxValueComparison > 0) + { + /* value bigger than entire range */ + lowerBoundIndex = middleIndex + 1; + continue; + } + + /* found interval containing partitionValue */ + return middleIndex; + } + + Assert(lowerBoundIndex == upperBoundIndex); + + /* + * If we get here, none of the ShardIntervals exactly contain the value + * (we'd have hit the return middleIndex; case otherwise). Figure out + * whether there's possibly any interval containing a value that's smaller + * than the partition key one. + */ + if (upperBoundIndex == shardCount) + { + /* all intervals are smaller, thus return 0 */ + return shardCount - 1; + } + else if (upperBoundIndex == 0) + { + /* partition value is smaller than all partition values */ + return INVALID_SHARD_INDEX; + } + + /* value falls inbetween intervals, return the inverval one smaller as bound */ + return upperBoundIndex - 1; +} + + +/* + * PruneWithBoundaries searches for shards that match inequality constraints, + * using binary search on both the upper and lower boundary, and returns a + * list of surviving shards. 
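+ *
+ * E.g. (hypothetical ranges): for intervals [0,9], [10,19] and
+ * [20,29], the constraints a >= 10 AND a < 25 yield lowerBoundIdx = 1
+ * and upperBoundIdx = 2, so the [10,19] and [20,29] shards survive.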
+ */ +static List * +PruneWithBoundaries(DistTableCacheEntry *cacheEntry, ClauseWalkerContext *context, + PruningInstance *prune) +{ + List *remainingShardList = NIL; + int shardCount = cacheEntry->shardIntervalArrayLength; + ShardInterval **sortedShardIntervalArray = cacheEntry->sortedShardIntervalArray; + bool hasLowerBound = false; + bool hasUpperBound = false; + Datum lowerBound = 0; + Datum upperBound = 0; + bool lowerBoundInclusive = false; + bool upperBoundInclusive = false; + int lowerBoundIdx = -1; + int upperBoundIdx = -1; + int curIdx = 0; + FunctionCallInfo compareFunctionCall = &context->compareIntervalFunctionCall; + + if (prune->greaterEqualConsts) + { + lowerBound = prune->greaterEqualConsts->constvalue; + lowerBoundInclusive = true; + hasLowerBound = true; + } + if (prune->greaterConsts) + { + lowerBound = prune->greaterConsts->constvalue; + lowerBoundInclusive = false; + hasLowerBound = true; + } + if (prune->lessEqualConsts) + { + upperBound = prune->lessEqualConsts->constvalue; + upperBoundInclusive = true; + hasUpperBound = true; + } + if (prune->lessConsts) + { + upperBound = prune->lessConsts->constvalue; + upperBoundInclusive = false; + hasUpperBound = true; + } + + Assert(hasLowerBound || hasUpperBound); + + /* find lower bound */ + if (hasLowerBound) + { + lowerBoundIdx = LowerShardBoundary(lowerBound, sortedShardIntervalArray, + shardCount, compareFunctionCall, + lowerBoundInclusive); + } + else + { + lowerBoundIdx = 0; + } + + /* find upper bound */ + if (hasUpperBound) + { + upperBoundIdx = UpperShardBoundary(upperBound, sortedShardIntervalArray, + shardCount, compareFunctionCall, + upperBoundInclusive); + } + else + { + upperBoundIdx = shardCount - 1; + } + + if (lowerBoundIdx == INVALID_SHARD_INDEX) + { + return NIL; + } + else if (upperBoundIdx == INVALID_SHARD_INDEX) + { + return NIL; + } + + /* + * Build list of all shards that are in the range of shards (possibly 0). + */ + for (curIdx = lowerBoundIdx; curIdx <= upperBoundIdx; curIdx++) + { + remainingShardList = lappend(remainingShardList, + sortedShardIntervalArray[curIdx]); + } + + return remainingShardList; +} + + +/* + * ExhaustivePrune returns a list of shards matching PruningInstances + * constraints, by simply checking them for each individual shard. 
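+ *
+ * This is the fallback when binary search is not applicable, notably
+ * for append-distributed tables with overlapping intervals, e.g.
+ * (hypothetical ranges) [0,15] and [10,29], where a value of 12 could
+ * fall into either shard.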
+ */ +static List * +ExhaustivePrune(DistTableCacheEntry *cacheEntry, ClauseWalkerContext *context, + PruningInstance *prune) +{ + List *remainingShardList = NIL; + FunctionCallInfo compareFunctionCall = &context->compareIntervalFunctionCall; + int shardCount = cacheEntry->shardIntervalArrayLength; + ShardInterval **sortedShardIntervalArray = cacheEntry->sortedShardIntervalArray; + int curIdx = 0; + + for (curIdx = 0; curIdx < shardCount; curIdx++) + { + Datum compareWith = 0; + ShardInterval *curInterval = sortedShardIntervalArray[curIdx]; + + /* NULL boundaries can't be compared to */ + if (!curInterval->minValueExists || !curInterval->maxValueExists) + { + remainingShardList = lappend(remainingShardList, curInterval); + continue; + } + + if (prune->equalConsts) + { + compareWith = prune->equalConsts->constvalue; + + if (PerformValueCompare(compareFunctionCall, + compareWith, + curInterval->minValue) < 0) + { + continue; + } + + if (PerformValueCompare(compareFunctionCall, + compareWith, + curInterval->maxValue) > 0) + { + continue; + } + } + if (prune->greaterEqualConsts) + { + compareWith = prune->greaterEqualConsts->constvalue; + + if (PerformValueCompare(compareFunctionCall, + curInterval->maxValue, + compareWith) < 0) + { + continue; + } + } + if (prune->greaterConsts) + { + compareWith = prune->greaterConsts->constvalue; + + if (PerformValueCompare(compareFunctionCall, + curInterval->maxValue, + compareWith) <= 0) + { + continue; + } + } + if (prune->lessEqualConsts) + { + compareWith = prune->lessEqualConsts->constvalue; + + if (PerformValueCompare(compareFunctionCall, + curInterval->minValue, + compareWith) > 0) + { + continue; + } + } + if (prune->lessConsts) + { + compareWith = prune->lessConsts->constvalue; + + if (PerformValueCompare(compareFunctionCall, + curInterval->minValue, + compareWith) >= 0) + { + continue; + } + } + + remainingShardList = lappend(remainingShardList, curInterval); + } + + return remainingShardList; +} diff --git a/src/backend/distributed/test/prune_shard_list.c b/src/backend/distributed/test/prune_shard_list.c index 88827f155..d68ebe516 100644 --- a/src/backend/distributed/test/prune_shard_list.c +++ b/src/backend/distributed/test/prune_shard_list.c @@ -24,6 +24,7 @@ #include "distributed/multi_physical_planner.h" #include "distributed/resource_lock.h" #include "distributed/test_helper_functions.h" /* IWYU pragma: keep */ +#include "distributed/shard_pruning.h" #include "nodes/pg_list.h" #include "nodes/primnodes.h" #include "nodes/nodes.h" @@ -203,11 +204,11 @@ PrunedShardIdsForTable(Oid distributedTableId, List *whereClauseList) Oid shardIdTypeId = INT8OID; Index tableId = 1; - List *shardList = LoadShardIntervalList(distributedTableId); + List *shardList = NIL; int shardIdCount = -1; Datum *shardIdDatumArray = NULL; - shardList = PruneShardList(distributedTableId, tableId, whereClauseList, shardList); + shardList = PruneShards(distributedTableId, tableId, whereClauseList); shardIdCount = list_length(shardList); shardIdDatumArray = palloc0(shardIdCount * sizeof(Datum)); diff --git a/src/backend/distributed/utils/shardinterval_utils.c b/src/backend/distributed/utils/shardinterval_utils.c index 5133da261..046ba3208 100644 --- a/src/backend/distributed/utils/shardinterval_utils.c +++ b/src/backend/distributed/utils/shardinterval_utils.c @@ -16,6 +16,7 @@ #include "catalog/pg_type.h" #include "distributed/metadata_cache.h" #include "distributed/multi_planner.h" +#include "distributed/shard_pruning.h" #include "distributed/shardinterval_utils.h" #include 
"distributed/pg_dist_partition.h" #include "distributed/worker_protocol.h" @@ -23,7 +24,6 @@ #include "utils/memutils.h" -static int FindShardIntervalIndex(Datum searchedValue, DistTableCacheEntry *cacheEntry); static int SearchCachedShardInterval(Datum partitionColumnValue, ShardInterval **shardIntervalCache, int shardCount, FmgrInfo *compareFunction); @@ -254,7 +254,7 @@ FindShardInterval(Datum partitionColumnValue, DistTableCacheEntry *cacheEntry) * somewhere. Such as a hash function which returns a value not in the range * of [INT32_MIN, INT32_MAX] can fire this. */ -static int +int FindShardIntervalIndex(Datum searchedValue, DistTableCacheEntry *cacheEntry) { ShardInterval **shardIntervalCache = cacheEntry->sortedShardIntervalArray; diff --git a/src/include/distributed/multi_physical_planner.h b/src/include/distributed/multi_physical_planner.h index 1949afadc..85280cb6c 100644 --- a/src/include/distributed/multi_physical_planner.h +++ b/src/include/distributed/multi_physical_planner.h @@ -249,10 +249,6 @@ extern StringInfo ShardFetchQueryString(uint64 shardId); extern Task * CreateBasicTask(uint64 jobId, uint32 taskId, TaskType taskType, char *queryString); -/* Function declarations for shard pruning */ -extern List * PruneShardList(Oid relationId, Index tableId, List *whereClauseList, - List *shardList); -extern bool ContainsFalseClause(List *whereClauseList); extern OpExpr * MakeOpExpression(Var *variable, int16 strategyNumber); /* diff --git a/src/include/distributed/shard_pruning.h b/src/include/distributed/shard_pruning.h new file mode 100644 index 000000000..3c26c4662 --- /dev/null +++ b/src/include/distributed/shard_pruning.h @@ -0,0 +1,23 @@ +/*------------------------------------------------------------------------- + * + * shard_pruning.h + * Shard pruning infrastructure. + * + * Copyright (c) 2014-2017, Citus Data, Inc. 
+ * + *------------------------------------------------------------------------- + */ + +#ifndef SHARD_PRUNING_H_ +#define SHARD_PRUNING_H_ + +#include "distributed/metadata_cache.h" +#include "nodes/primnodes.h" + +#define INVALID_SHARD_INDEX -1 + +/* Function declarations for shard pruning */ +extern List * PruneShards(Oid relationId, Index rangeTableId, List *whereClauseList); +extern bool ContainsFalseClause(List *whereClauseList); + +#endif /* SHARD_PRUNING_H_ */ diff --git a/src/include/distributed/shardinterval_utils.h b/src/include/distributed/shardinterval_utils.h index 54b96f2e7..5c9d6cde8 100644 --- a/src/include/distributed/shardinterval_utils.h +++ b/src/include/distributed/shardinterval_utils.h @@ -35,6 +35,7 @@ extern int CompareRelationShards(const void *leftElement, extern int ShardIndex(ShardInterval *shardInterval); extern ShardInterval * FindShardInterval(Datum partitionColumnValue, DistTableCacheEntry *cacheEntry); +extern int FindShardIntervalIndex(Datum searchedValue, DistTableCacheEntry *cacheEntry); extern bool SingleReplicatedTable(Oid relationId); #endif /* SHARDINTERVAL_UTILS_H_ */ diff --git a/src/test/regress/expected/multi_prune_shard_list.out b/src/test/regress/expected/multi_prune_shard_list.out index a440478e5..e809cd139 100644 --- a/src/test/regress/expected/multi_prune_shard_list.out +++ b/src/test/regress/expected/multi_prune_shard_list.out @@ -70,21 +70,21 @@ SELECT prune_using_single_value('pruning', NULL); SELECT prune_using_either_value('pruning', 'tomato', 'petunia'); prune_using_either_value -------------------------- - {800001,800002} + {800002,800001} (1 row) --- an AND clause with incompatible values returns no shards +-- an AND clause with values on different shards returns no shards SELECT prune_using_both_values('pruning', 'tomato', 'petunia'); prune_using_both_values ------------------------- {} (1 row) --- but if both values are on the same shard, should get back that shard +-- even if both values are on the same shard, a value can't be equal to two others SELECT prune_using_both_values('pruning', 'tomato', 'rose'); prune_using_both_values ------------------------- - {800002} + {} (1 row) -- unit test of the equality expression generation code diff --git a/src/test/regress/sql/multi_prune_shard_list.sql b/src/test/regress/sql/multi_prune_shard_list.sql index 62e0432c6..99e5f95ca 100644 --- a/src/test/regress/sql/multi_prune_shard_list.sql +++ b/src/test/regress/sql/multi_prune_shard_list.sql @@ -60,10 +60,10 @@ SELECT prune_using_single_value('pruning', NULL); -- build an OR clause and expect more than one sahrd SELECT prune_using_either_value('pruning', 'tomato', 'petunia'); --- an AND clause with incompatible values returns no shards +-- an AND clause with values on different shards returns no shards SELECT prune_using_both_values('pruning', 'tomato', 'petunia'); --- but if both values are on the same shard, should get back that shard +-- even if both values are on the same shard, a value can't be equal to two others SELECT prune_using_both_values('pruning', 'tomato', 'rose'); -- unit test of the equality expression generation code