diff --git a/src/backend/distributed/commands/multi_copy.c b/src/backend/distributed/commands/multi_copy.c index 770ff4799..a0e1f0b65 100644 --- a/src/backend/distributed/commands/multi_copy.c +++ b/src/backend/distributed/commands/multi_copy.c @@ -67,6 +67,7 @@ #include "distributed/placement_connection.h" #include "distributed/remote_commands.h" #include "distributed/resource_lock.h" +#include "distributed/shard_pruning.h" #include "executor/executor.h" #include "nodes/makefuncs.h" #include "tsearch/ts_locale.h" diff --git a/src/backend/distributed/master/master_modify_multiple_shards.c b/src/backend/distributed/master/master_modify_multiple_shards.c index d2695a7b4..086846fde 100644 --- a/src/backend/distributed/master/master_modify_multiple_shards.c +++ b/src/backend/distributed/master/master_modify_multiple_shards.c @@ -40,6 +40,7 @@ #include "distributed/pg_dist_partition.h" #include "distributed/resource_lock.h" #include "distributed/shardinterval_utils.h" +#include "distributed/shard_pruning.h" #include "distributed/worker_protocol.h" #include "optimizer/clauses.h" #include "optimizer/predtest.h" @@ -81,7 +82,6 @@ master_modify_multiple_shards(PG_FUNCTION_ARGS) Node *queryTreeNode; List *restrictClauseList = NIL; bool failOK = false; - List *shardIntervalList = NIL; List *prunedShardIntervalList = NIL; List *taskList = NIL; int32 affectedTupleCount = 0; @@ -156,11 +156,10 @@ master_modify_multiple_shards(PG_FUNCTION_ARGS) ExecuteMasterEvaluableFunctions(modifyQuery, NULL); - shardIntervalList = LoadShardIntervalList(relationId); restrictClauseList = WhereClauseList(modifyQuery->jointree); prunedShardIntervalList = - PruneShardList(relationId, tableId, restrictClauseList, shardIntervalList); + PruneShards(relationId, tableId, restrictClauseList); CHECK_FOR_INTERRUPTS(); diff --git a/src/backend/distributed/planner/multi_physical_planner.c b/src/backend/distributed/planner/multi_physical_planner.c index 505a240d4..9d2bb9f5b 100644 --- a/src/backend/distributed/planner/multi_physical_planner.c +++ b/src/backend/distributed/planner/multi_physical_planner.c @@ -41,6 +41,7 @@ #include "distributed/pg_dist_partition.h" #include "distributed/pg_dist_shard.h" #include "distributed/shardinterval_utils.h" +#include "distributed/shard_pruning.h" #include "distributed/task_tracker.h" #include "distributed/worker_manager.h" #include "distributed/worker_protocol.h" @@ -133,9 +134,6 @@ static List * RangeTableFragmentsList(List *rangeTableList, List *whereClauseLis static OperatorCacheEntry * LookupOperatorByType(Oid typeId, Oid accessMethodId, int16 strategyNumber); static Oid GetOperatorByType(Oid typeId, Oid accessMethodId, int16 strategyNumber); -static Node * HashableClauseMutator(Node *originalNode, Var *partitionColumn); -static OpExpr * MakeHashedOperatorExpression(OpExpr *operatorExpression); -static List * BuildRestrictInfoList(List *qualList); static List * FragmentCombinationList(List *rangeTableFragmentsList, Query *jobQuery, List *dependedJobList); static JoinSequenceNode * JoinSequenceArray(List *rangeTableFragmentsList, @@ -2060,7 +2058,6 @@ SubquerySqlTaskList(Job *job) { RangeTblEntry *rangeTableEntry = (RangeTblEntry *) lfirst(rangeTableCell); Oid relationId = rangeTableEntry->relid; - List *shardIntervalList = LoadShardIntervalList(relationId); List *finalShardIntervalList = NIL; ListCell *fragmentCombinationCell = NULL; ListCell *shardIntervalCell = NULL; @@ -2073,12 +2070,11 @@ SubquerySqlTaskList(Job *job) Var *partitionColumn = PartitionColumn(relationId, tableId); List 
*whereClauseList = ReplaceColumnsInOpExpressionList(opExpressionList, partitionColumn); - finalShardIntervalList = PruneShardList(relationId, tableId, whereClauseList, - shardIntervalList); + finalShardIntervalList = PruneShards(relationId, tableId, whereClauseList); } else { - finalShardIntervalList = shardIntervalList; + finalShardIntervalList = LoadShardIntervalList(relationId); } /* if all shards are pruned away, we return an empty task list */ @@ -2513,11 +2509,8 @@ RangeTableFragmentsList(List *rangeTableList, List *whereClauseList, Oid relationId = rangeTableEntry->relid; ListCell *shardIntervalCell = NULL; List *shardFragmentList = NIL; - - List *shardIntervalList = LoadShardIntervalList(relationId); - List *prunedShardIntervalList = PruneShardList(relationId, tableId, - whereClauseList, - shardIntervalList); + List *prunedShardIntervalList = PruneShards(relationId, tableId, + whereClauseList); /* * If we prune all shards for one table, query results will be empty. @@ -2586,114 +2579,6 @@ RangeTableFragmentsList(List *rangeTableList, List *whereClauseList, } -/* - * PruneShardList prunes shard intervals from given list based on the selection criteria, - * and returns remaining shard intervals in another list. - * - * For reference tables, the function simply returns the single shard that the table has. - */ -List * -PruneShardList(Oid relationId, Index tableId, List *whereClauseList, - List *shardIntervalList) -{ - List *remainingShardList = NIL; - ListCell *shardIntervalCell = NULL; - List *restrictInfoList = NIL; - Node *baseConstraint = NULL; - - Var *partitionColumn = PartitionColumn(relationId, tableId); - char partitionMethod = PartitionMethod(relationId); - - /* short circuit for reference tables */ - if (partitionMethod == DISTRIBUTE_BY_NONE) - { - return shardIntervalList; - } - - if (ContainsFalseClause(whereClauseList)) - { - /* always return empty result if WHERE clause is of the form: false (AND ..) */ - return NIL; - } - - /* build the filter clause list for the partition method */ - if (partitionMethod == DISTRIBUTE_BY_HASH) - { - Node *hashedNode = HashableClauseMutator((Node *) whereClauseList, - partitionColumn); - - List *hashedClauseList = (List *) hashedNode; - restrictInfoList = BuildRestrictInfoList(hashedClauseList); - } - else - { - restrictInfoList = BuildRestrictInfoList(whereClauseList); - } - - /* override the partition column for hash partitioning */ - if (partitionMethod == DISTRIBUTE_BY_HASH) - { - partitionColumn = MakeInt4Column(); - } - - /* build the base expression for constraint */ - baseConstraint = BuildBaseConstraint(partitionColumn); - - /* walk over shard list and check if shards can be pruned */ - foreach(shardIntervalCell, shardIntervalList) - { - ShardInterval *shardInterval = (ShardInterval *) lfirst(shardIntervalCell); - List *constraintList = NIL; - bool shardPruned = false; - - if (shardInterval->minValueExists && shardInterval->maxValueExists) - { - /* set the min/max values in the base constraint */ - UpdateConstraint(baseConstraint, shardInterval); - constraintList = list_make1(baseConstraint); - - shardPruned = predicate_refuted_by(constraintList, restrictInfoList); - } - - if (!shardPruned) - { - remainingShardList = lappend(remainingShardList, shardInterval); - } - } - - return remainingShardList; -} - - -/* - * ContainsFalseClause returns whether the flattened where clause list - * contains false as a clause. 
- */ -bool -ContainsFalseClause(List *whereClauseList) -{ - bool containsFalseClause = false; - ListCell *clauseCell = NULL; - - foreach(clauseCell, whereClauseList) - { - Node *clause = (Node *) lfirst(clauseCell); - - if (IsA(clause, Const)) - { - Const *constant = (Const *) clause; - if (constant->consttype == BOOLOID && !DatumGetBool(constant->constvalue)) - { - containsFalseClause = true; - break; - } - } - } - - return containsFalseClause; -} - - /* * BuildBaseConstraint builds and returns a base constraint. This constraint * implements an expression in the form of (column <= max && column >= min), @@ -2916,87 +2801,6 @@ SimpleOpExpression(Expr *clause) } -/* - * HashableClauseMutator walks over the original where clause list, replaces - * hashable nodes with hashed versions and keeps other nodes as they are. - */ -static Node * -HashableClauseMutator(Node *originalNode, Var *partitionColumn) -{ - Node *newNode = NULL; - if (originalNode == NULL) - { - return NULL; - } - - if (IsA(originalNode, OpExpr)) - { - OpExpr *operatorExpression = (OpExpr *) originalNode; - bool hasPartitionColumn = false; - - Oid leftHashFunction = InvalidOid; - Oid rightHashFunction = InvalidOid; - - /* - * If operatorExpression->opno is NOT the registered '=' operator for - * any hash opfamilies, then get_op_hash_functions will return false. - * This means this function both ensures a hash function exists for the - * types in question AND filters out any clauses lacking equality ops. - */ - bool hasHashFunction = get_op_hash_functions(operatorExpression->opno, - &leftHashFunction, - &rightHashFunction); - - bool simpleOpExpression = SimpleOpExpression((Expr *) operatorExpression); - if (simpleOpExpression) - { - hasPartitionColumn = OpExpressionContainsColumn(operatorExpression, - partitionColumn); - } - - if (hasHashFunction && hasPartitionColumn) - { - OpExpr *hashedOperatorExpression = - MakeHashedOperatorExpression((OpExpr *) originalNode); - newNode = (Node *) hashedOperatorExpression; - } - } - else if (IsA(originalNode, ScalarArrayOpExpr)) - { - ScalarArrayOpExpr *arrayOperatorExpression = (ScalarArrayOpExpr *) originalNode; - Node *leftOpExpression = linitial(arrayOperatorExpression->args); - Node *strippedLeftOpExpression = strip_implicit_coercions(leftOpExpression); - bool usingEqualityOperator = OperatorImplementsEquality( - arrayOperatorExpression->opno); - - /* - * Citus cannot prune hash-distributed shards with ANY/ALL. We show a NOTICE - * if the expression is ANY/ALL performed on the partition column with equality. - */ - if (usingEqualityOperator && strippedLeftOpExpression != NULL && - equal(strippedLeftOpExpression, partitionColumn)) - { - ereport(NOTICE, (errmsg("cannot use shard pruning with " - "ANY/ALL (array expression)"), - errhint("Consider rewriting the expression with " - "OR/AND clauses."))); - } - } - - /* - * If this node is not hashable, continue walking down the expression tree - * to find and hash clauses which are eligible. - */ - if (newNode == NULL) - { - newNode = expression_tree_mutator(originalNode, HashableClauseMutator, - (void *) partitionColumn); - } - - return newNode; -} - - /* * OpExpressionContainsColumn checks if the operator expression contains the * given partition column. We assume that given operator expression is a simple @@ -3027,77 +2831,6 @@ OpExpressionContainsColumn(OpExpr *operatorExpression, Var *partitionColumn) } -/* - * MakeHashedOperatorExpression creates a new operator expression with a column - * of int4 type and hashed constant value. 
- */ -static OpExpr * -MakeHashedOperatorExpression(OpExpr *operatorExpression) -{ - const Oid hashResultTypeId = INT4OID; - TypeCacheEntry *hashResultTypeEntry = NULL; - Oid operatorId = InvalidOid; - OpExpr *hashedExpression = NULL; - Var *hashedColumn = NULL; - Datum hashedValue = 0; - Const *hashedConstant = NULL; - FmgrInfo *hashFunction = NULL; - TypeCacheEntry *typeEntry = NULL; - - Node *leftOperand = get_leftop((Expr *) operatorExpression); - Node *rightOperand = get_rightop((Expr *) operatorExpression); - Const *constant = NULL; - - if (IsA(rightOperand, Const)) - { - constant = (Const *) rightOperand; - } - else - { - constant = (Const *) leftOperand; - } - - /* Load the operator from type cache */ - hashResultTypeEntry = lookup_type_cache(hashResultTypeId, TYPECACHE_EQ_OPR); - operatorId = hashResultTypeEntry->eq_opr; - - /* Get a column with int4 type */ - hashedColumn = MakeInt4Column(); - - /* Load the hash function from type cache */ - typeEntry = lookup_type_cache(constant->consttype, TYPECACHE_HASH_PROC_FINFO); - hashFunction = &(typeEntry->hash_proc_finfo); - if (!OidIsValid(hashFunction->fn_oid)) - { - ereport(ERROR, (errcode(ERRCODE_UNDEFINED_FUNCTION), - errmsg("could not identify a hash function for type %s", - format_type_be(constant->consttype)), - errdatatype(constant->consttype))); - } - - /* - * Note that any changes to PostgreSQL's hashing functions will change the - * new value created by this function. - */ - hashedValue = FunctionCall1(hashFunction, constant->constvalue); - hashedConstant = MakeInt4Constant(hashedValue); - - /* Now create the expression with modified partition column and hashed constant */ - hashedExpression = (OpExpr *) make_opclause(operatorId, - InvalidOid, /* no result type yet */ - false, /* no return set */ - (Expr *) hashedColumn, - (Expr *) hashedConstant, - InvalidOid, InvalidOid); - - /* Set implementing function id and result type */ - hashedExpression->opfuncid = get_opcode(operatorId); - hashedExpression->opresulttype = get_func_rettype(hashedExpression->opfuncid); - - return hashedExpression; -} - - /* * MakeInt4Column creates a column of int4 type with invalid table id and max * attribute number. @@ -3139,30 +2872,6 @@ MakeInt4Constant(Datum constantValue) } -/* - * BuildRestrictInfoList builds restrict info list using the selection criteria, - * and then return this list. Note that this function assumes there is only one - * relation for now. - */ -static List * -BuildRestrictInfoList(List *qualList) -{ - List *restrictInfoList = NIL; - ListCell *qualCell = NULL; - - foreach(qualCell, qualList) - { - RestrictInfo *restrictInfo = NULL; - Node *qualNode = (Node *) lfirst(qualCell); - - restrictInfo = make_simple_restrictinfo((Expr *) qualNode); - restrictInfoList = lappend(restrictInfoList, restrictInfo); - } - - return restrictInfoList; -} - - /* Updates the base constraint with the given min/max values. 
*/ void UpdateConstraint(Node *baseConstraint, ShardInterval *shardInterval) diff --git a/src/backend/distributed/planner/multi_router_planner.c b/src/backend/distributed/planner/multi_router_planner.c index 2829ecdcc..e094ad8d0 100644 --- a/src/backend/distributed/planner/multi_router_planner.c +++ b/src/backend/distributed/planner/multi_router_planner.c @@ -41,6 +41,7 @@ #include "distributed/relay_utility.h" #include "distributed/resource_lock.h" #include "distributed/shardinterval_utils.h" +#include "distributed/shard_pruning.h" #include "executor/execdesc.h" #include "lib/stringinfo.h" #include "nodes/makefuncs.h" @@ -558,6 +559,8 @@ RouterModifyTaskForShardInterval(Query *originalQuery, ShardInterval *shardInter * * The function errors out if the given shard interval does not belong to a hash, * range and append distributed tables. + * + * NB: If you update this, also look at PrunableExpressionsWalker(). */ static List * ShardIntervalOpExpressions(ShardInterval *shardInterval, Index rteIndex) @@ -1998,9 +2001,7 @@ FindShardForInsert(Query *query, DeferredErrorMessage **planningError) restrictClauseList = list_make1(equalityExpr); - shardIntervalList = LoadShardIntervalList(distributedTableId); - prunedShardList = PruneShardList(distributedTableId, tableId, restrictClauseList, - shardIntervalList); + prunedShardList = PruneShards(distributedTableId, tableId, restrictClauseList); } prunedShardCount = list_length(prunedShardList); @@ -2060,7 +2061,6 @@ FindShardForUpdateOrDelete(Query *query, DeferredErrorMessage **planningError) DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(distributedTableId); char partitionMethod = cacheEntry->partitionMethod; CmdType commandType = query->commandType; - List *shardIntervalList = NIL; List *restrictClauseList = NIL; Index tableId = 1; List *prunedShardList = NIL; @@ -2068,11 +2068,8 @@ FindShardForUpdateOrDelete(Query *query, DeferredErrorMessage **planningError) Assert(commandType == CMD_UPDATE || commandType == CMD_DELETE); - shardIntervalList = LoadShardIntervalList(distributedTableId); - restrictClauseList = QueryRestrictList(query, partitionMethod); - prunedShardList = PruneShardList(distributedTableId, tableId, restrictClauseList, - shardIntervalList); + prunedShardList = PruneShards(distributedTableId, tableId, restrictClauseList); prunedShardCount = list_length(prunedShardList); if (prunedShardCount != 1) @@ -2412,7 +2409,6 @@ TargetShardIntervalsForSelect(Query *query, List *baseRestrictionList = relationRestriction->relOptInfo->baserestrictinfo; List *restrictClauseList = get_all_actual_clauses(baseRestrictionList); List *prunedShardList = NIL; - int shardIndex = 0; List *joinInfoList = relationRestriction->relOptInfo->joininfo; List *pseudoRestrictionList = extract_actual_clauses(joinInfoList, true); bool whereFalseQuery = false; @@ -2428,18 +2424,8 @@ TargetShardIntervalsForSelect(Query *query, whereFalseQuery = ContainsFalseClause(pseudoRestrictionList); if (!whereFalseQuery && shardCount > 0) { - List *shardIntervalList = NIL; - - for (shardIndex = 0; shardIndex < shardCount; shardIndex++) - { - ShardInterval *shardInterval = - cacheEntry->sortedShardIntervalArray[shardIndex]; - shardIntervalList = lappend(shardIntervalList, shardInterval); - } - - prunedShardList = PruneShardList(relationId, tableId, - restrictClauseList, - shardIntervalList); + prunedShardList = PruneShards(relationId, tableId, + restrictClauseList); /* * Quick bail out. 
The query can not be router plannable if one diff --git a/src/backend/distributed/planner/shard_pruning.c b/src/backend/distributed/planner/shard_pruning.c new file mode 100644 index 000000000..807b4f71e --- /dev/null +++ b/src/backend/distributed/planner/shard_pruning.c @@ -0,0 +1,1319 @@ +/*------------------------------------------------------------------------- + * + * shard_pruning.c + * Shard pruning related code. + * + * The goal of shard pruning is to find a minimal (super)set of shards that + * need to be queried to find rows matching the expression in a query. + * + * In PruneShards, we first compute a simplified disjunctive normal form (DNF) + * of the expression as a list of pruning instances. Each pruning instance + * contains all AND-ed constraints on the partition column. An OR expression + * will result in two or more new pruning instances being added for the + * subexpressions. The "parent" instance is marked isPartial and ignored + * during pruning. + * + * We use the distributive property for constraints of the form P AND (Q OR R) + * to rewrite it to (P AND Q) OR (P AND R) by copying constraints from parent + * to "child" pruning instances. However, we do not distribute nested + * expressions. While (P OR Q) AND (R OR S) is logically equivalent to (P AND + * R) OR (P AND S) OR (Q AND R) OR (Q AND S), in our implementation it becomes + * P OR Q OR R OR S. This is acceptable since this will always result in a + * superset of shards. If this proves to be a issue in practice, a more + * complete algorithm could be implemented. + * + * We then evaluate each non-partial pruning instance in the disjunction + * through the following, increasingly expensive, steps: + * + * 1) If there is a constant equality constraint on the partition column, and + * no overlapping shards exist, find the shard interval in which the + * constant falls + * + * 2) If there is a hash range constraint on the partition column, find the + * shard interval matching the range + * + * 3) If there are range constraints (e.g. (a > 0 AND a < 10)) on the + * partition column, find the shard intervals that overlap with the range + * + * 4) If there are overlapping shards, exhaustively search all shards that are + * not excluded by constraints + * + * Finally, the union of the shards found by each pruning instance is + * returned. + * + * Copyright (c) 2014-2017, Citus Data, Inc. + * + *------------------------------------------------------------------------- + */ +#include "postgres.h" + +#include "distributed/shard_pruning.h" + +#include "access/nbtree.h" +#include "catalog/pg_am.h" +#include "catalog/pg_collation.h" +#include "catalog/pg_type.h" +#include "distributed/metadata_cache.h" +#include "distributed/multi_planner.h" +#include "distributed/multi_join_order.h" +#include "distributed/multi_physical_planner.h" +#include "distributed/shardinterval_utils.h" +#include "distributed/pg_dist_partition.h" +#include "distributed/worker_protocol.h" +#include "nodes/nodeFuncs.h" +#include "optimizer/clauses.h" +#include "utils/catcache.h" +#include "utils/lsyscache.h" +#include "utils/memutils.h" + +/* + * A pruning instance is a set of ANDed constraints on a partition key. + */ +typedef struct PruningInstance +{ + /* Does this instance contain any prunable expressions? */ + bool hasValidConstraint; + + /* + * This constraint never evaluates to true, i.e. pruning does not have to + * be performed. + */ + bool evaluatesToFalse; + + /* + * Constraints on the partition column value. 
If multiple values are
+ * found, the more restrictive one is stored here. Even in case of a
+ * hash-partitioned table, actual column values are stored here, *not*
+ * hashed values.
+ */
+ Const *lessConsts;
+ Const *lessEqualConsts;
+ Const *equalConsts;
+ Const *greaterEqualConsts;
+ Const *greaterConsts;
+
+ /*
+ * Constraint using a pre-hashed column value. The constant will store the
+ * hashed value, not the original value of the restriction.
+ */
+ Const *hashedEqualConsts;
+
+ /*
+ * Types of constraints not understood. We could theoretically try more
+ * expensive methods of pruning if any such restrictions are found.
+ *
+ * TODO: any actual use for this? Right now there seems little point.
+ */
+ List *otherRestrictions;
+
+ /*
+ * Has this PruningInstance been added to
+ * ClauseWalkerContext->pruningInstances? This is not done immediately,
+ * but the first time a constraint (independent of us being able to handle
+ * that constraint) is found.
+ */
+ bool addedToPruningInstances;
+
+ /*
+ * When OR clauses are found, the non-ORed part (think of a < 3 AND (a > 5
+ * OR a > 7)) of the expression is stored in one PruningInstance which is
+ * then copied for the ORed expressions. The original is marked as
+ * isPartial, to avoid it being used for pruning.
+ */
+ bool isPartial;
+} PruningInstance;
+
+
+/*
+ * Partial instances that need to be finished building. This is used to
+ * collect all ANDed restrictions, before looking into ORed expressions.
+ */
+typedef struct PendingPruningInstance
+{
+ PruningInstance *instance;
+ Node *continueAt;
+} PendingPruningInstance;
+
+
+/*
+ * Data necessary to perform a single PruneShards().
+ */
+typedef struct ClauseWalkerContext
+{
+ Var *partitionColumn;
+ char partitionMethod;
+
+ /* ORed list of pruning targets */
+ List *pruningInstances;
+
+ /*
+ * Partially built PruningInstances that need to be completed by doing a
+ * separate PrunableExpressionsWalker() pass.
+ */
+ List *pendingInstances;
+
+ /* PruningInstance currently being built, all eligible constraints are added here */
+ PruningInstance *currentPruningInstance;
+
+ /*
+ * Information about function calls we need to perform. Re-using the same
+ * FunctionCallInfoData, instead of using FunctionCall2Coll, is often
+ * cheaper.
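+ * compareValueFunctionCall compares actual partition column values;
+ * compareIntervalFunctionCall compares shard boundary values, which for
+ * hash-partitioned tables are int4 hash values rather than column
+ * values.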
+ */ + FunctionCallInfoData compareValueFunctionCall; + FunctionCallInfoData compareIntervalFunctionCall; +} ClauseWalkerContext; + +static void PrunableExpressions(Node *originalNode, ClauseWalkerContext *context); +static bool PrunableExpressionsWalker(Node *originalNode, ClauseWalkerContext *context); +static void AddPartitionKeyRestrictionToInstance(ClauseWalkerContext *context, + OpExpr *opClause, Var *varClause, + Const *constantClause); +static void AddHashRestrictionToInstance(ClauseWalkerContext *context, OpExpr *opClause, + Var *varClause, Const *constantClause); +static PruningInstance * CopyPartialPruningInstance(PruningInstance *sourceInstance); +static List * ShardArrayToList(ShardInterval **shardArray, int length); +static List * DeepCopyShardIntervalList(List *originalShardIntervalList); +static int PerformValueCompare(FunctionCallInfoData *compareFunctionCall, Datum a, + Datum b); +static int PerformCompare(FunctionCallInfoData *compareFunctionCall); + +static List * PruneOne(DistTableCacheEntry *cacheEntry, ClauseWalkerContext *context, + PruningInstance *prune); +static List * PruneWithBoundaries(DistTableCacheEntry *cacheEntry, + ClauseWalkerContext *context, + PruningInstance *prune); +static List * ExhaustivePrune(DistTableCacheEntry *cacheEntry, + ClauseWalkerContext *context, + PruningInstance *prune); +static int UpperShardBoundary(Datum partitionColumnValue, + ShardInterval **shardIntervalCache, + int shardCount, FunctionCallInfoData *compareFunction, + bool includeMin); +static int LowerShardBoundary(Datum partitionColumnValue, + ShardInterval **shardIntervalCache, + int shardCount, FunctionCallInfoData *compareFunction, + bool includeMax); + + +/* + * PruneShards returns all shards from a distributed table that cannot be + * proven to be eliminated by whereClauseList. + * + * For reference tables, the function simply returns the single shard that the + * table has. + */ +List * +PruneShards(Oid relationId, Index rangeTableId, List *whereClauseList) +{ + DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(relationId); + char partitionMethod = cacheEntry->partitionMethod; + ClauseWalkerContext context = { 0 }; + ListCell *pruneCell; + List *prunedList = NIL; + bool foundRestriction = false; + + /* always return empty result if WHERE clause is of the form: false (AND ..) 
*/ + if (ContainsFalseClause(whereClauseList)) + { + return NIL; + } + + /* short circuit for reference tables */ + if (partitionMethod == DISTRIBUTE_BY_NONE) + { + prunedList = ShardArrayToList(cacheEntry->sortedShardIntervalArray, + cacheEntry->shardIntervalArrayLength); + return DeepCopyShardIntervalList(prunedList); + } + + + context.partitionMethod = partitionMethod; + context.partitionColumn = PartitionColumn(relationId, rangeTableId); + context.currentPruningInstance = palloc0(sizeof(PruningInstance)); + + if (cacheEntry->shardIntervalCompareFunction) + { + /* initiate function call info once (allows comparators to cache metadata) */ + InitFunctionCallInfoData(context.compareIntervalFunctionCall, + cacheEntry->shardIntervalCompareFunction, + 2, DEFAULT_COLLATION_OID, NULL, NULL); + } + else + { + ereport(ERROR, (errmsg("shard pruning not possible without " + "a shard interval comparator"))); + } + + if (cacheEntry->shardColumnCompareFunction) + { + /* initiate function call info once (allows comparators to cache metadata) */ + InitFunctionCallInfoData(context.compareValueFunctionCall, + cacheEntry->shardColumnCompareFunction, + 2, DEFAULT_COLLATION_OID, NULL, NULL); + } + else + { + ereport(ERROR, (errmsg("shard pruning not possible without " + "a partition column comparator"))); + } + + /* Figure out what we can prune on */ + PrunableExpressions((Node *) whereClauseList, &context); + + /* + * Prune using each of the PrunableInstances we found, and OR results + * together. + */ + foreach(pruneCell, context.pruningInstances) + { + PruningInstance *prune = (PruningInstance *) lfirst(pruneCell); + List *pruneOneList; + + /* + * If this is a partial instance, a fully built one has also been + * added. Skip. + */ + if (prune->isPartial) + { + continue; + } + + /* + * If the current instance has no prunable expressions, we'll have to + * return all shards. No point in continuing pruning in that case. + */ + if (!prune->hasValidConstraint) + { + foundRestriction = false; + break; + } + + /* + * Similar to the above, if hash-partitioned and there's nothing to + * prune by, we're done. + */ + if (context.partitionMethod == DISTRIBUTE_BY_HASH && + !prune->evaluatesToFalse && !prune->equalConsts && !prune->hashedEqualConsts) + { + foundRestriction = false; + break; + } + + pruneOneList = PruneOne(cacheEntry, &context, prune); + + if (prunedList) + { + /* + * We can use list_union_ptr, which is a lot faster than doing + * comparing shards by value, because all the ShardIntervals are + * guaranteed to be from + * DistTableCacheEntry->sortedShardIntervalArray (thus having the + * same pointer values). + */ + prunedList = list_union_ptr(prunedList, pruneOneList); + } + else + { + prunedList = pruneOneList; + } + foundRestriction = true; + } + + /* found no valid restriction, build list of all shards */ + if (!foundRestriction) + { + prunedList = ShardArrayToList(cacheEntry->sortedShardIntervalArray, + cacheEntry->shardIntervalArrayLength); + } + + /* + * Deep copy list, so it's independent of the DistTableCacheEntry + * contents. + */ + return DeepCopyShardIntervalList(prunedList); +} + + +/* + * ContainsFalseClause returns whether the flattened where clause list + * contains false as a clause. 
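+ * For example, quals folded down to a constant false (as in WHERE false,
+ * possibly AND-ed with further clauses) produce such a list; no row can
+ * match, so PruneShards() returns an empty shard list.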
+ */ +bool +ContainsFalseClause(List *whereClauseList) +{ + bool containsFalseClause = false; + ListCell *clauseCell = NULL; + + foreach(clauseCell, whereClauseList) + { + Node *clause = (Node *) lfirst(clauseCell); + + if (IsA(clause, Const)) + { + Const *constant = (Const *) clause; + if (constant->consttype == BOOLOID && !DatumGetBool(constant->constvalue)) + { + containsFalseClause = true; + break; + } + } + } + + return containsFalseClause; +} + + +/* + * PrunableExpressions builds a list of all prunable expressions in node, + * storing them in context->pruningInstances. + */ +static void +PrunableExpressions(Node *node, ClauseWalkerContext *context) +{ + /* + * Build initial list of prunable expressions. As long as only, + * implicitly or explicitly, ANDed expressions are found, this perform a + * depth-first search. When an ORed expression is found, the current + * PruningInstance is added to context->pruningInstances (once for each + * ORed expression), then the tree-traversal is continued without + * recursing. Once at the top-level again, we'll process all pending + * expressions - that allows us to find all ANDed expressions, before + * recursing into an ORed expression. + */ + PrunableExpressionsWalker(node, context); + + /* + * Process all pending instances. While processing, new ones might be + * added to the list, so don't use foreach(). + * + * Check the places in PruningInstanceWalker that push onto + * context->pendingInstances why construction of the PruningInstance might + * be pending. + * + * We copy the partial PruningInstance, and continue adding information by + * calling PrunableExpressionsWalker() on the copy, continuing at the the + * node stored in PendingPruningInstance->continueAt. + */ + while (context->pendingInstances != NIL) + { + PendingPruningInstance *instance = + (PendingPruningInstance *) linitial(context->pendingInstances); + PruningInstance *newPrune = CopyPartialPruningInstance(instance->instance); + + context->pendingInstances = list_delete_first(context->pendingInstances); + + context->currentPruningInstance = newPrune; + PrunableExpressionsWalker(instance->continueAt, context); + context->currentPruningInstance = NULL; + } +} + + +/* + * PrunableExpressionsWalker() is the main work horse for + * PrunableExpressions(). + */ +static bool +PrunableExpressionsWalker(Node *node, ClauseWalkerContext *context) +{ + if (node == NULL) + { + return false; + } + + /* + * Check for expressions understood by this routine. + */ + if (IsA(node, List)) + { + /* at the top of quals we'll frequently see lists, those are to be treated as ANDs */ + } + else if (IsA(node, BoolExpr)) + { + BoolExpr *boolExpr = (BoolExpr *) node; + + if (boolExpr->boolop == NOT_EXPR) + { + return false; + } + else if (boolExpr->boolop == AND_EXPR) + { + return expression_tree_walker((Node *) boolExpr->args, + PrunableExpressionsWalker, context); + } + else if (boolExpr->boolop == OR_EXPR) + { + ListCell *opCell = NULL; + + /* + * "Queue" partial pruning instances. This is used to convert + * expressions like (A AND (B OR C) AND D) into (A AND B AND D), + * (A AND C AND D), with A, B, C, D being restrictions. When the + * OR is encountered, a reference to the partially built + * PruningInstance (containing A at this point), is added to + * context->pendingInstances once for B and once for C. 
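+ * (For the a < 3 AND (a > 5 OR a > 7) example above, two pending
+ * instances are queued: one continuing at a > 5, the other at a > 7,
+ * both referencing the partially built instance that holds a < 3.)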
Once a + * full tree-walk completed, PrunableExpressions() will complete + * the pending instances, which'll now also know about restriction + * D, by calling PrunableExpressionsWalker() once for B and once + * for C. + */ + foreach(opCell, boolExpr->args) + { + PendingPruningInstance *instance = + palloc0(sizeof(PendingPruningInstance)); + + instance->instance = context->currentPruningInstance; + instance->continueAt = lfirst(opCell); + + /* + * Signal that this instance is not to be used for pruning on + * its own. Once the pending instance is processed, it'll be + * used. + */ + instance->instance->isPartial = true; + + context->pendingInstances = lappend(context->pendingInstances, instance); + } + + return false; + } + } + else if (IsA(node, OpExpr)) + { + OpExpr *opClause = (OpExpr *) node; + PruningInstance *prune = context->currentPruningInstance; + Node *leftOperand = NULL; + Node *rightOperand = NULL; + Const *constantClause = NULL; + Var *varClause = NULL; + + if (!prune->addedToPruningInstances) + { + context->pruningInstances = lappend(context->pruningInstances, + prune); + prune->addedToPruningInstances = true; + } + + if (list_length(opClause->args) == 2) + { + leftOperand = get_leftop((Expr *) opClause); + rightOperand = get_rightop((Expr *) opClause); + + leftOperand = strip_implicit_coercions(leftOperand); + rightOperand = strip_implicit_coercions(rightOperand); + + if (IsA(rightOperand, Const) && IsA(leftOperand, Var)) + { + constantClause = (Const *) rightOperand; + varClause = (Var *) leftOperand; + } + else if (IsA(leftOperand, Const) && IsA(rightOperand, Var)) + { + constantClause = (Const *) leftOperand; + varClause = (Var *) rightOperand; + } + } + + if (constantClause && varClause && equal(varClause, context->partitionColumn)) + { + /* + * Found a restriction on the partition column itself. Update the + * current constraint with the new information. + */ + AddPartitionKeyRestrictionToInstance(context, + opClause, varClause, constantClause); + } + else if (constantClause && varClause && + varClause->varattno == RESERVED_HASHED_COLUMN_ID) + { + /* + * Found restriction that directly specifies the boundaries of a + * hashed column. + */ + AddHashRestrictionToInstance(context, opClause, varClause, constantClause); + } + + return false; + } + else if (IsA(node, ScalarArrayOpExpr)) + { + PruningInstance *prune = context->currentPruningInstance; + ScalarArrayOpExpr *arrayOperatorExpression = (ScalarArrayOpExpr *) node; + Node *leftOpExpression = linitial(arrayOperatorExpression->args); + Node *strippedLeftOpExpression = strip_implicit_coercions(leftOpExpression); + bool usingEqualityOperator = OperatorImplementsEquality( + arrayOperatorExpression->opno); + + /* + * Citus cannot prune hash-distributed shards with ANY/ALL. We show a NOTICE + * if the expression is ANY/ALL performed on the partition column with equality. + * + * TODO: this'd now be easy to implement, similar to the OR_EXPR case + * above, except that one would push an appropriately constructed + * OpExpr(LHS = $array_element) as continueAt. + */ + if (usingEqualityOperator && strippedLeftOpExpression != NULL && + equal(strippedLeftOpExpression, context->partitionColumn)) + { + ereport(NOTICE, (errmsg("cannot use shard pruning with " + "ANY/ALL (array expression)"), + errhint("Consider rewriting the expression with " + "OR/AND clauses."))); + } + + /* + * Mark expression as added, so we'll fail pruning if there's no ANDed + * restrictions that we can deal with. 
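+ * For example, WHERE a = ANY (ARRAY[1, 2, 3]) on partition column a
+ * emits the NOTICE above but adds no constraint, so in the absence of
+ * other restrictions PruneShards() falls back to returning all shards.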
+ */ + if (!prune->addedToPruningInstances) + { + context->pruningInstances = lappend(context->pruningInstances, + prune); + prune->addedToPruningInstances = true; + } + + return false; + } + else + { + PruningInstance *prune = context->currentPruningInstance; + + /* + * Mark expression as added, so we'll fail pruning if there's no ANDed + * restrictions that we know how to deal with. + */ + if (!prune->addedToPruningInstances) + { + context->pruningInstances = lappend(context->pruningInstances, + prune); + prune->addedToPruningInstances = true; + } + + return false; + } + + return expression_tree_walker(node, PrunableExpressionsWalker, context); +} + + +/* + * AddPartitionKeyRestrictionToInstance adds information about a PartitionKey + * $op Const restriction to the current pruning instance. + */ +static void +AddPartitionKeyRestrictionToInstance(ClauseWalkerContext *context, OpExpr *opClause, + Var *varClause, Const *constantClause) +{ + PruningInstance *prune = context->currentPruningInstance; + List *btreeInterpretationList = NULL; + ListCell *btreeInterpretationCell = NULL; + bool matchedOp = false; + + btreeInterpretationList = + get_op_btree_interpretation(opClause->opno); + foreach(btreeInterpretationCell, btreeInterpretationList) + { + OpBtreeInterpretation *btreeInterpretation = + (OpBtreeInterpretation *) lfirst(btreeInterpretationCell); + + switch (btreeInterpretation->strategy) + { + case BTLessStrategyNumber: + { + if (!prune->lessConsts || + PerformValueCompare(&context->compareValueFunctionCall, + constantClause->constvalue, + prune->lessConsts->constvalue) < 0) + { + prune->lessConsts = constantClause; + } + matchedOp = true; + } + break; + + case BTLessEqualStrategyNumber: + { + if (!prune->lessEqualConsts || + PerformValueCompare(&context->compareValueFunctionCall, + constantClause->constvalue, + prune->lessEqualConsts->constvalue) < 0) + { + prune->lessEqualConsts = constantClause; + } + matchedOp = true; + } + break; + + case BTEqualStrategyNumber: + { + if (!prune->equalConsts) + { + prune->equalConsts = constantClause; + } + else if (PerformValueCompare(&context->compareValueFunctionCall, + constantClause->constvalue, + prune->equalConsts->constvalue) != 0) + { + /* key can't be equal to two values */ + prune->evaluatesToFalse = true; + } + matchedOp = true; + } + break; + + case BTGreaterEqualStrategyNumber: + { + if (!prune->greaterEqualConsts || + PerformValueCompare(&context->compareValueFunctionCall, + constantClause->constvalue, + prune->greaterEqualConsts->constvalue) > 0 + ) + { + prune->greaterEqualConsts = constantClause; + } + matchedOp = true; + } + break; + + case BTGreaterStrategyNumber: + { + if (!prune->greaterConsts || + PerformValueCompare(&context->compareValueFunctionCall, + constantClause->constvalue, + prune->greaterConsts->constvalue) > 0) + { + prune->greaterConsts = constantClause; + } + matchedOp = true; + } + break; + + case ROWCOMPARE_NE: + { + /* TODO: could add support for this, if we feel like it */ + matchedOp = false; + } + break; + + default: + Assert(false); + } + } + + if (!matchedOp) + { + prune->otherRestrictions = + lappend(prune->otherRestrictions, opClause); + } + else + { + prune->hasValidConstraint = true; + } +} + + +/* + * AddHashRestrictionToInstance adds information about a + * RESERVED_HASHED_COLUMN_ID = Const restriction to the current pruning + * instance. 
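+ * Such constraints are not written by users; they are generated
+ * internally, in ShardIntervalOpExpressions(), to select a shard by its
+ * exact hash boundaries.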
+ */ +static void +AddHashRestrictionToInstance(ClauseWalkerContext *context, OpExpr *opClause, + Var *varClause, Const *constantClause) +{ + PruningInstance *prune = context->currentPruningInstance; + List *btreeInterpretationList = NULL; + ListCell *btreeInterpretationCell = NULL; + + btreeInterpretationList = + get_op_btree_interpretation(opClause->opno); + foreach(btreeInterpretationCell, btreeInterpretationList) + { + OpBtreeInterpretation *btreeInterpretation = + (OpBtreeInterpretation *) lfirst(btreeInterpretationCell); + + /* + * Ladidadida, dirty hackety hack. We only add such + * constraints (in ShardIntervalOpExpressions()) to select a + * shard based on its exact boundaries. For efficient binary + * search it's better to simply use one representative value + * to look up the shard. In practice, this is sufficient for + * now. + */ + if (btreeInterpretation->strategy == BTGreaterEqualStrategyNumber) + { + Assert(!prune->hashedEqualConsts); + prune->hashedEqualConsts = constantClause; + prune->hasValidConstraint = true; + } + } +} + + +/* + * CopyPartialPruningInstance copies a partial PruningInstance, so it can be + * completed. + */ +static PruningInstance * +CopyPartialPruningInstance(PruningInstance *sourceInstance) +{ + PruningInstance *newInstance = palloc(sizeof(PruningInstance)); + + Assert(sourceInstance->isPartial); + + /* + * To make the new PruningInstance useful for pruning, we have to reset it + * being partial - if necessary it'll be marked so again by + * PrunableExpressionsWalker(). + */ + memcpy(newInstance, sourceInstance, sizeof(PruningInstance)); + newInstance->addedToPruningInstances = false; + newInstance->isPartial = false; + + return newInstance; +} + + +/* + * ShardArrayToList builds a list of out the array of ShardInterval*. + */ +static List * +ShardArrayToList(ShardInterval **shardArray, int length) +{ + List *shardIntervalList = NIL; + int shardIndex; + + for (shardIndex = 0; shardIndex < length; shardIndex++) + { + ShardInterval *shardInterval = + shardArray[shardIndex]; + shardIntervalList = lappend(shardIntervalList, shardInterval); + } + + return shardIntervalList; +} + + +/* + * DeepCopyShardIntervalList copies originalShardIntervalList and the + * contained ShardIntervals, into a new list. + */ +static List * +DeepCopyShardIntervalList(List *originalShardIntervalList) +{ + List *copiedShardIntervalList = NIL; + ListCell *shardIntervalCell = NULL; + + foreach(shardIntervalCell, originalShardIntervalList) + { + ShardInterval *originalShardInterval = + (ShardInterval *) lfirst(shardIntervalCell); + ShardInterval *copiedShardInterval = + (ShardInterval *) palloc0(sizeof(ShardInterval)); + + CopyShardInterval(originalShardInterval, copiedShardInterval); + copiedShardIntervalList = lappend(copiedShardIntervalList, copiedShardInterval); + } + + return copiedShardIntervalList; +} + + +/* + * PruneOne returns all shards in the table that match a single + * PruningInstance. + */ +static List * +PruneOne(DistTableCacheEntry *cacheEntry, ClauseWalkerContext *context, + PruningInstance *prune) +{ + ShardInterval *shardInterval = NULL; + + /* Well, if life always were this easy... */ + if (prune->evaluatesToFalse) + { + return NIL; + } + + /* + * For an equal constraints, if there's no overlapping shards (always the + * case for hash and range partitioning, sometimes for append), can + * perform binary search for the right interval. That's usually the + * fastest, so try that first. 
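+ * E.g. for WHERE a = 5, FindShardInterval() directly returns the one
+ * interval whose range covers the value 5 (or, for hash partitioning,
+ * covers the hash of 5).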
+ */ + if (prune->equalConsts && + !cacheEntry->hasOverlappingShardInterval) + { + shardInterval = FindShardInterval(prune->equalConsts->constvalue, cacheEntry); + + /* + * If pruned down to nothing, we're done. Otherwise see if other + * methods prune down further / to nothing. + */ + if (!shardInterval) + { + return NIL; + } + } + + /* + * If the hash value we're looking for is known, we can search for the + * interval directly. That's fast and should only ever be the case for a + * hash-partitioned table. + */ + if (prune->hashedEqualConsts) + { + int shardIndex = INVALID_SHARD_INDEX; + ShardInterval **sortedShardIntervalArray = cacheEntry->sortedShardIntervalArray; + + Assert(context->partitionMethod == DISTRIBUTE_BY_HASH); + + shardIndex = FindShardIntervalIndex(prune->hashedEqualConsts->constvalue, + cacheEntry); + + if (shardIndex == INVALID_SHARD_INDEX) + { + return NIL; + } + else if (shardInterval && + sortedShardIntervalArray[shardIndex]->shardId != shardInterval->shardId) + { + /* + * equalConst based pruning above yielded a different shard than + * pruning based on pre-hashed equality. This is useful in case + * of INSERT ... SELECT, where both can occur together (one via + * join/colocation, the other via a plain equality restriction). + */ + return NIL; + } + else + { + return list_make1(sortedShardIntervalArray[shardIndex]); + } + } + + /* + * If previous pruning method yielded a single shard, we could also + * attempt range based pruning to exclude it further. But that seems + * rarely useful in practice, and thus likely a waste of runtime and code + * complexity. + */ + if (shardInterval) + { + return list_make1(shardInterval); + } + + /* + * Should never get here for hashing, we've filtered down to either zero + * or one shard, and returned. + */ + Assert(context->partitionMethod != DISTRIBUTE_BY_HASH); + + /* + * Next method: binary search with fuzzy boundaries. Can't trivially do so + * if shards have overlapping boundaries. + * + * TODO: If we kept shard intervals separately sorted by both upper and + * lower boundaries, this should be possible? + */ + if (!cacheEntry->hasOverlappingShardInterval && ( + prune->greaterConsts || prune->greaterEqualConsts || + prune->lessConsts || prune->lessEqualConsts)) + { + return PruneWithBoundaries(cacheEntry, context, prune); + } + + /* + * Brute force: Check each shard. + */ + return ExhaustivePrune(cacheEntry, context, prune); +} + + +/* + * PerformCompare invokes comparator with prepared values, check for + * unexpected NULL returns. + */ +static int +PerformCompare(FunctionCallInfoData *compareFunctionCall) +{ + Datum result = FunctionCallInvoke(compareFunctionCall); + + if (compareFunctionCall->isnull) + { + elog(ERROR, "function %u returned NULL", compareFunctionCall->flinfo->fn_oid); + } + + return DatumGetInt32(result); +} + + +/* + * PerformValueCompare invokes comparator with a/b, and checks for unexpected + * NULL returns. + */ +static int +PerformValueCompare(FunctionCallInfoData *compareFunctionCall, Datum a, Datum b) +{ + compareFunctionCall->arg[0] = a; + compareFunctionCall->argnull[0] = false; + compareFunctionCall->arg[1] = b; + compareFunctionCall->argnull[1] = false; + + return PerformCompare(compareFunctionCall); +} + + +/* + * LowerShardBoundary returns the index of the first ShardInterval that's >= + * (if includeMax) or > partitionColumnValue. 
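+ * Returns INVALID_SHARD_INDEX if the searched value is greater than the
+ * maximum value of every shard, i.e. no shard can satisfy the lower
+ * bound.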
+ */ +static int +LowerShardBoundary(Datum partitionColumnValue, ShardInterval **shardIntervalCache, + int shardCount, FunctionCallInfoData *compareFunction, bool includeMax) +{ + int lowerBoundIndex = 0; + int upperBoundIndex = shardCount; + + Assert(shardCount != 0); + + /* setup partitionColumnValue argument once */ + compareFunction->arg[0] = partitionColumnValue; + compareFunction->argnull[0] = false; + + while (lowerBoundIndex < upperBoundIndex) + { + int middleIndex = lowerBoundIndex + ((upperBoundIndex - lowerBoundIndex) / 2); + int maxValueComparison = 0; + int minValueComparison = 0; + + /* setup minValue as argument */ + compareFunction->arg[1] = shardIntervalCache[middleIndex]->minValue; + compareFunction->argnull[1] = false; + + /* execute cmp(partitionValue, lowerBound) */ + minValueComparison = PerformCompare(compareFunction); + + /* and evaluate results */ + if (minValueComparison < 0) + { + /* value smaller than entire range */ + upperBoundIndex = middleIndex; + continue; + } + + /* setup maxValue as argument */ + compareFunction->arg[1] = shardIntervalCache[middleIndex]->maxValue; + compareFunction->argnull[1] = false; + + /* execute cmp(partitionValue, upperBound) */ + maxValueComparison = PerformCompare(compareFunction); + + if ((maxValueComparison == 0 && !includeMax) || + maxValueComparison > 0) + { + /* value bigger than entire range */ + lowerBoundIndex = middleIndex + 1; + continue; + } + + /* found interval containing partitionValue */ + return middleIndex; + } + + Assert(lowerBoundIndex == upperBoundIndex); + + /* + * If we get here, none of the ShardIntervals exactly contain the value + * (we'd have hit the return middleIndex; case otherwise). Figure out + * whether there's possibly any interval containing a value that's bigger + * than the partition key one. + */ + if (lowerBoundIndex == 0) + { + /* all intervals are bigger, thus return 0 */ + return 0; + } + else if (lowerBoundIndex == shardCount) + { + /* partition value is bigger than all partition values */ + return INVALID_SHARD_INDEX; + } + + /* value falls inbetween intervals */ + return lowerBoundIndex + 1; +} + + +/* + * UpperShardBoundary returns the index of the last ShardInterval that's <= + * (if includeMin) or < partitionColumnValue. 
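+ * Returns INVALID_SHARD_INDEX if the searched value is smaller than the
+ * minimum value of every shard, i.e. no shard can satisfy the upper
+ * bound.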
+ */ +static int +UpperShardBoundary(Datum partitionColumnValue, ShardInterval **shardIntervalCache, + int shardCount, FunctionCallInfoData *compareFunction, bool includeMin) +{ + int lowerBoundIndex = 0; + int upperBoundIndex = shardCount; + + Assert(shardCount != 0); + + /* setup partitionColumnValue argument once */ + compareFunction->arg[0] = partitionColumnValue; + compareFunction->argnull[0] = false; + + while (lowerBoundIndex < upperBoundIndex) + { + int middleIndex = lowerBoundIndex + ((upperBoundIndex - lowerBoundIndex) / 2); + int maxValueComparison = 0; + int minValueComparison = 0; + + /* setup minValue as argument */ + compareFunction->arg[1] = shardIntervalCache[middleIndex]->minValue; + compareFunction->argnull[1] = false; + + /* execute cmp(partitionValue, lowerBound) */ + minValueComparison = PerformCompare(compareFunction); + + /* and evaluate results */ + if ((minValueComparison == 0 && !includeMin) || + minValueComparison < 0) + { + /* value smaller than entire range */ + upperBoundIndex = middleIndex; + continue; + } + + /* setup maxValue as argument */ + compareFunction->arg[1] = shardIntervalCache[middleIndex]->maxValue; + compareFunction->argnull[1] = false; + + /* execute cmp(partitionValue, upperBound) */ + maxValueComparison = PerformCompare(compareFunction); + + if (maxValueComparison > 0) + { + /* value bigger than entire range */ + lowerBoundIndex = middleIndex + 1; + continue; + } + + /* found interval containing partitionValue */ + return middleIndex; + } + + Assert(lowerBoundIndex == upperBoundIndex); + + /* + * If we get here, none of the ShardIntervals exactly contain the value + * (we'd have hit the return middleIndex; case otherwise). Figure out + * whether there's possibly any interval containing a value that's smaller + * than the partition key one. + */ + if (upperBoundIndex == shardCount) + { + /* all intervals are smaller, thus return 0 */ + return shardCount - 1; + } + else if (upperBoundIndex == 0) + { + /* partition value is smaller than all partition values */ + return INVALID_SHARD_INDEX; + } + + /* value falls inbetween intervals, return the inverval one smaller as bound */ + return upperBoundIndex - 1; +} + + +/* + * PruneWithBoundaries searches for shards that match inequality constraints, + * using binary search on both the upper and lower boundary, and returns a + * list of surviving shards. 
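+ * E.g. for WHERE a > 5 AND a <= 10, all shards whose
+ * [minValue, maxValue] range intersects the interval (5, 10] survive.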
+ */ +static List * +PruneWithBoundaries(DistTableCacheEntry *cacheEntry, ClauseWalkerContext *context, + PruningInstance *prune) +{ + List *remainingShardList = NIL; + int shardCount = cacheEntry->shardIntervalArrayLength; + ShardInterval **sortedShardIntervalArray = cacheEntry->sortedShardIntervalArray; + bool hasLowerBound = false; + bool hasUpperBound = false; + Datum lowerBound = 0; + Datum upperBound = 0; + bool lowerBoundInclusive = false; + bool upperBoundInclusive = false; + int lowerBoundIdx = -1; + int upperBoundIdx = -1; + int curIdx = 0; + FunctionCallInfo compareFunctionCall = &context->compareIntervalFunctionCall; + + if (prune->greaterEqualConsts) + { + lowerBound = prune->greaterEqualConsts->constvalue; + lowerBoundInclusive = true; + hasLowerBound = true; + } + if (prune->greaterConsts) + { + lowerBound = prune->greaterConsts->constvalue; + lowerBoundInclusive = false; + hasLowerBound = true; + } + if (prune->lessEqualConsts) + { + upperBound = prune->lessEqualConsts->constvalue; + upperBoundInclusive = true; + hasUpperBound = true; + } + if (prune->lessConsts) + { + upperBound = prune->lessConsts->constvalue; + upperBoundInclusive = false; + hasUpperBound = true; + } + + Assert(hasLowerBound || hasUpperBound); + + /* find lower bound */ + if (hasLowerBound) + { + lowerBoundIdx = LowerShardBoundary(lowerBound, sortedShardIntervalArray, + shardCount, compareFunctionCall, + lowerBoundInclusive); + } + else + { + lowerBoundIdx = 0; + } + + /* find upper bound */ + if (hasUpperBound) + { + upperBoundIdx = UpperShardBoundary(upperBound, sortedShardIntervalArray, + shardCount, compareFunctionCall, + upperBoundInclusive); + } + else + { + upperBoundIdx = shardCount - 1; + } + + if (lowerBoundIdx == INVALID_SHARD_INDEX) + { + return NIL; + } + else if (upperBoundIdx == INVALID_SHARD_INDEX) + { + return NIL; + } + + /* + * Build list of all shards that are in the range of shards (possibly 0). + */ + for (curIdx = lowerBoundIdx; curIdx <= upperBoundIdx; curIdx++) + { + remainingShardList = lappend(remainingShardList, + sortedShardIntervalArray[curIdx]); + } + + return remainingShardList; +} + + +/* + * ExhaustivePrune returns a list of shards matching PruningInstances + * constraints, by simply checking them for each individual shard. 
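+ * This is the fallback used when binary search over the sorted shard
+ * boundaries is not possible, e.g. because shard intervals overlap
+ * (which can happen for append-distributed tables).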
+ */ +static List * +ExhaustivePrune(DistTableCacheEntry *cacheEntry, ClauseWalkerContext *context, + PruningInstance *prune) +{ + List *remainingShardList = NIL; + FunctionCallInfo compareFunctionCall = &context->compareIntervalFunctionCall; + int shardCount = cacheEntry->shardIntervalArrayLength; + ShardInterval **sortedShardIntervalArray = cacheEntry->sortedShardIntervalArray; + int curIdx = 0; + + for (curIdx = 0; curIdx < shardCount; curIdx++) + { + Datum compareWith = 0; + ShardInterval *curInterval = sortedShardIntervalArray[curIdx]; + + /* NULL boundaries can't be compared to */ + if (!curInterval->minValueExists || !curInterval->maxValueExists) + { + remainingShardList = lappend(remainingShardList, curInterval); + continue; + } + + if (prune->equalConsts) + { + compareWith = prune->equalConsts->constvalue; + + if (PerformValueCompare(compareFunctionCall, + compareWith, + curInterval->minValue) < 0) + { + continue; + } + + if (PerformValueCompare(compareFunctionCall, + compareWith, + curInterval->maxValue) > 0) + { + continue; + } + } + if (prune->greaterEqualConsts) + { + compareWith = prune->greaterEqualConsts->constvalue; + + if (PerformValueCompare(compareFunctionCall, + curInterval->maxValue, + compareWith) < 0) + { + continue; + } + } + if (prune->greaterConsts) + { + compareWith = prune->greaterConsts->constvalue; + + if (PerformValueCompare(compareFunctionCall, + curInterval->maxValue, + compareWith) <= 0) + { + continue; + } + } + if (prune->lessEqualConsts) + { + compareWith = prune->lessEqualConsts->constvalue; + + if (PerformValueCompare(compareFunctionCall, + curInterval->minValue, + compareWith) > 0) + { + continue; + } + } + if (prune->lessConsts) + { + compareWith = prune->lessConsts->constvalue; + + if (PerformValueCompare(compareFunctionCall, + curInterval->minValue, + compareWith) >= 0) + { + continue; + } + } + + remainingShardList = lappend(remainingShardList, curInterval); + } + + return remainingShardList; +} diff --git a/src/backend/distributed/test/prune_shard_list.c b/src/backend/distributed/test/prune_shard_list.c index 88827f155..d68ebe516 100644 --- a/src/backend/distributed/test/prune_shard_list.c +++ b/src/backend/distributed/test/prune_shard_list.c @@ -24,6 +24,7 @@ #include "distributed/multi_physical_planner.h" #include "distributed/resource_lock.h" #include "distributed/test_helper_functions.h" /* IWYU pragma: keep */ +#include "distributed/shard_pruning.h" #include "nodes/pg_list.h" #include "nodes/primnodes.h" #include "nodes/nodes.h" @@ -203,11 +204,11 @@ PrunedShardIdsForTable(Oid distributedTableId, List *whereClauseList) Oid shardIdTypeId = INT8OID; Index tableId = 1; - List *shardList = LoadShardIntervalList(distributedTableId); + List *shardList = NIL; int shardIdCount = -1; Datum *shardIdDatumArray = NULL; - shardList = PruneShardList(distributedTableId, tableId, whereClauseList, shardList); + shardList = PruneShards(distributedTableId, tableId, whereClauseList); shardIdCount = list_length(shardList); shardIdDatumArray = palloc0(shardIdCount * sizeof(Datum)); diff --git a/src/backend/distributed/utils/metadata_cache.c b/src/backend/distributed/utils/metadata_cache.c index a48405597..15062f78e 100644 --- a/src/backend/distributed/utils/metadata_cache.c +++ b/src/backend/distributed/utils/metadata_cache.c @@ -134,8 +134,6 @@ static ShardCacheEntry * LookupShardCacheEntry(int64 shardId); static DistTableCacheEntry * LookupDistTableCacheEntry(Oid relationId); static void BuildDistTableCacheEntry(DistTableCacheEntry *cacheEntry); static void 
BuildCachedShardList(DistTableCacheEntry *cacheEntry); -static FmgrInfo * ShardIntervalCompareFunction(ShardInterval **shardIntervalArray, - char partitionMethod); static ShardInterval ** SortShardIntervalArray(ShardInterval **shardIntervalArray, int shardCount, FmgrInfo * @@ -147,6 +145,9 @@ static bool HasUninitializedShardInterval(ShardInterval **sortedShardIntervalArr static void ErrorIfInstalledVersionMismatch(void); static char * AvailableExtensionVersion(void); static char * InstalledExtensionVersion(void); +static bool HasOverlappingShardInterval(ShardInterval **shardIntervalArray, + int shardIntervalArrayLength, + FmgrInfo *shardIntervalSortCompareFunction); static void InitializeDistTableCache(void); static void InitializeWorkerNodeCache(void); static uint32 WorkerNodeHashCode(const void *key, Size keySize); @@ -158,6 +159,7 @@ static HeapTuple LookupDistPartitionTuple(Relation pgDistPartition, Oid relation static List * LookupDistShardTuples(Oid relationId); static Oid LookupShardRelation(int64 shardId); static void GetPartitionTypeInputInfo(char *partitionKeyString, char partitionMethod, + Oid *columnTypeId, int32 *columnTypeMod, Oid *intervalTypeId, int32 *intervalTypeMod); static ShardInterval * TupleToShardInterval(HeapTuple heapTuple, TupleDesc tupleDescriptor, Oid intervalTypeId, @@ -619,9 +621,21 @@ BuildCachedShardList(DistTableCacheEntry *cacheEntry) ShardInterval **shardIntervalArray = NULL; ShardInterval **sortedShardIntervalArray = NULL; FmgrInfo *shardIntervalCompareFunction = NULL; + FmgrInfo *shardColumnCompareFunction = NULL; List *distShardTupleList = NIL; int shardIntervalArrayLength = 0; int shardIndex = 0; + Oid columnTypeId = InvalidOid; + int32 columnTypeMod = -1; + Oid intervalTypeId = InvalidOid; + int32 intervalTypeMod = -1; + + GetPartitionTypeInputInfo(cacheEntry->partitionKeyString, + cacheEntry->partitionMethod, + &columnTypeId, + &columnTypeMod, + &intervalTypeId, + &intervalTypeMod); distShardTupleList = LookupDistShardTuples(cacheEntry->relationId); shardIntervalArrayLength = list_length(distShardTupleList); @@ -631,13 +645,6 @@ BuildCachedShardList(DistTableCacheEntry *cacheEntry) TupleDesc distShardTupleDesc = RelationGetDescr(distShardRelation); ListCell *distShardTupleCell = NULL; int arrayIndex = 0; - Oid intervalTypeId = InvalidOid; - int32 intervalTypeMod = -1; - - GetPartitionTypeInputInfo(cacheEntry->partitionKeyString, - cacheEntry->partitionMethod, - &intervalTypeId, - &intervalTypeMod); shardIntervalArray = MemoryContextAllocZero(CacheMemoryContext, shardIntervalArrayLength * @@ -676,29 +683,41 @@ BuildCachedShardList(DistTableCacheEntry *cacheEntry) heap_close(distShardRelation, AccessShareLock); } - /* decide and allocate interval comparison function */ - if (cacheEntry->partitionMethod == DISTRIBUTE_BY_NONE) + /* look up value comparison function */ + if (columnTypeId != InvalidOid) + { + /* allocate the comparison function in the cache context */ + MemoryContext oldContext = MemoryContextSwitchTo(CacheMemoryContext); + + shardColumnCompareFunction = GetFunctionInfo(columnTypeId, BTREE_AM_OID, + BTORDER_PROC); + MemoryContextSwitchTo(oldContext); + } + else + { + shardColumnCompareFunction = NULL; + } + + /* look up interval comparison function */ + if (intervalTypeId != InvalidOid) + { + /* allocate the comparison function in the cache context */ + MemoryContext oldContext = MemoryContextSwitchTo(CacheMemoryContext); + + shardIntervalCompareFunction = GetFunctionInfo(intervalTypeId, BTREE_AM_OID, + BTORDER_PROC); + 
@@ -676,29 +683,41 @@ BuildCachedShardList(DistTableCacheEntry *cacheEntry)
 		heap_close(distShardRelation, AccessShareLock);
 	}
 
-	/* decide and allocate interval comparison function */
-	if (cacheEntry->partitionMethod == DISTRIBUTE_BY_NONE)
+	/* look up value comparison function */
+	if (columnTypeId != InvalidOid)
+	{
+		/* allocate the comparison function in the cache context */
+		MemoryContext oldContext = MemoryContextSwitchTo(CacheMemoryContext);
+
+		shardColumnCompareFunction = GetFunctionInfo(columnTypeId, BTREE_AM_OID,
+													 BTORDER_PROC);
+		MemoryContextSwitchTo(oldContext);
+	}
+	else
+	{
+		shardColumnCompareFunction = NULL;
+	}
+
+	/* look up interval comparison function */
+	if (intervalTypeId != InvalidOid)
+	{
+		/* allocate the comparison function in the cache context */
+		MemoryContext oldContext = MemoryContextSwitchTo(CacheMemoryContext);
+
+		shardIntervalCompareFunction = GetFunctionInfo(intervalTypeId, BTREE_AM_OID,
+													   BTORDER_PROC);
+
+		MemoryContextSwitchTo(oldContext);
+	}
+	else
 	{
 		shardIntervalCompareFunction = NULL;
 	}
-	else if (shardIntervalArrayLength > 0)
-	{
-		MemoryContext oldContext = CurrentMemoryContext;
-
-		/* allocate the comparison function in the cache context */
-		oldContext = MemoryContextSwitchTo(CacheMemoryContext);
-
-		shardIntervalCompareFunction =
-			ShardIntervalCompareFunction(shardIntervalArray,
-										 cacheEntry->partitionMethod);
-
-		MemoryContextSwitchTo(oldContext);
-	}
 
 	/* reference tables has a single shard which is not initialized */
 	if (cacheEntry->partitionMethod == DISTRIBUTE_BY_NONE)
 	{
 		cacheEntry->hasUninitializedShardInterval = true;
+		cacheEntry->hasOverlappingShardInterval = true;
 
 		/*
 		 * Note that during create_reference_table() call,
@@ -727,6 +746,35 @@ BuildCachedShardList(DistTableCacheEntry *cacheEntry)
 		cacheEntry->hasUninitializedShardInterval =
 			HasUninitializedShardInterval(sortedShardIntervalArray,
 										  shardIntervalArrayLength);
+
+		if (!cacheEntry->hasUninitializedShardInterval)
+		{
+			cacheEntry->hasOverlappingShardInterval =
+				HasOverlappingShardInterval(sortedShardIntervalArray,
+											shardIntervalArrayLength,
+											shardIntervalCompareFunction);
+		}
+		else
+		{
+			cacheEntry->hasOverlappingShardInterval = true;
+		}
+
+		/*
+		 * If the table is hash-partitioned and has shards, there should never
+		 * be any uninitialized shards. Historically we've not prevented that
+		 * for range-partitioned tables, but it might be a good idea to start
+		 * doing so.
+		 */
+		if (cacheEntry->partitionMethod == DISTRIBUTE_BY_HASH &&
+			cacheEntry->hasUninitializedShardInterval)
+		{
+			ereport(ERROR, (errmsg("hash partitioned table has uninitialized shards")));
+		}
+		if (cacheEntry->partitionMethod == DISTRIBUTE_BY_HASH &&
+			cacheEntry->hasOverlappingShardInterval)
+		{
+			ereport(ERROR, (errmsg("hash partitioned table has overlapping shards")));
+		}
 	}
@@ -794,41 +842,11 @@ BuildCachedShardList(DistTableCacheEntry *cacheEntry)
 	cacheEntry->shardIntervalArrayLength = shardIntervalArrayLength;
 	cacheEntry->sortedShardIntervalArray = sortedShardIntervalArray;
+	cacheEntry->shardColumnCompareFunction = shardColumnCompareFunction;
 	cacheEntry->shardIntervalCompareFunction = shardIntervalCompareFunction;
 }
 
-/*
- * ShardIntervalCompareFunction returns the appropriate compare function for the
- * partition column type. In case of hash-partitioning, it always returns the compare
- * function for integers. Callers of this function has to ensure that shardIntervalArray
- * has at least one element.
- */
-static FmgrInfo *
-ShardIntervalCompareFunction(ShardInterval **shardIntervalArray, char partitionMethod)
-{
-	FmgrInfo *shardIntervalCompareFunction = NULL;
-	Oid comparisonTypeId = InvalidOid;
-
-	Assert(shardIntervalArray != NULL);
-
-	if (partitionMethod == DISTRIBUTE_BY_HASH)
-	{
-		comparisonTypeId = INT4OID;
-	}
-	else
-	{
-		ShardInterval *shardInterval = shardIntervalArray[0];
-		comparisonTypeId = shardInterval->valueTypeId;
-	}
-
-	shardIntervalCompareFunction = GetFunctionInfo(comparisonTypeId, BTREE_AM_OID,
-												   BTORDER_PROC);
-
-	return shardIntervalCompareFunction;
-}
-
-
 /*
  * SortedShardIntervalArray sorts the input shardIntervalArray. Shard intervals with
  * no min/max values are placed at the end of the array.
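The two new ereport() calls turn a planner assumption into a hard invariant: a hash-partitioned table's sorted shards must be initialized and non-overlapping. Well-formed hash shards satisfy this by construction, since they split the int32 hash space into contiguous ranges. The sketch below checks that property for an even four-way split; the increment arithmetic is the conventional division, assumed here for illustration rather than lifted from this patch.

#include <assert.h>
#include <stdint.h>
#include <stdio.h>

#define HASH_TOKEN_COUNT 4294967296ULL /* 2^32 possible hash values */

int
main(void)
{
	int shardCount = 4;
	uint64_t increment = HASH_TOKEN_COUNT / shardCount;
	int64_t previousMax = (int64_t) INT32_MIN - 1;
	int shardIndex;

	for (shardIndex = 0; shardIndex < shardCount; shardIndex++)
	{
		int32_t min = (int32_t) (INT32_MIN + (int64_t) (shardIndex * increment));
		int32_t max = (int32_t) ((int64_t) min + (int64_t) increment - 1);

		if (shardIndex == shardCount - 1)
		{
			max = INT32_MAX; /* last shard absorbs any remainder */
		}

		/* ranges are contiguous and strictly increasing: no overlap possible */
		assert((int64_t) min == previousMax + 1);
		previousMax = max;

		printf("shard %d: [%d, %d]\n", shardIndex, (int) min, (int) max);
	}
	return 0;
}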
@@ -932,6 +950,52 @@ HasUninitializedShardInterval(ShardInterval **sortedShardIntervalArray, int shar
 }
 
+
+/*
+ * HasOverlappingShardInterval determines whether the given list of sorted
+ * shards has overlapping ranges.
+ */
+static bool
+HasOverlappingShardInterval(ShardInterval **shardIntervalArray,
+							int shardIntervalArrayLength,
+							FmgrInfo *shardIntervalSortCompareFunction)
+{
+	int shardIndex = 0;
+	ShardInterval *lastShardInterval = NULL;
+	Datum comparisonDatum = 0;
+	int comparisonResult = 0;
+
+	/* zero shards or a single shard can't overlap */
+	if (shardIntervalArrayLength < 2)
+	{
+		return false;
+	}
+
+	lastShardInterval = shardIntervalArray[0];
+	for (shardIndex = 1; shardIndex < shardIntervalArrayLength; shardIndex++)
+	{
+		ShardInterval *curShardInterval = shardIntervalArray[shardIndex];
+
+		/* only called if !hasUninitializedShardInterval */
+		Assert(lastShardInterval->minValueExists && lastShardInterval->maxValueExists);
+		Assert(curShardInterval->minValueExists && curShardInterval->maxValueExists);
+
+		comparisonDatum = CompareCall2(shardIntervalSortCompareFunction,
+									   lastShardInterval->maxValue,
+									   curShardInterval->minValue);
+		comparisonResult = DatumGetInt32(comparisonDatum);
+
+		if (comparisonResult >= 0)
+		{
+			return true;
+		}
+
+		lastShardInterval = curShardInterval;
+	}
+
+	return false;
+}
+
+
 /*
  * CitusHasBeenLoaded returns true if the citus extension has been created
  * in the current database and the extension script has been executed. Otherwise,
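Because the cached array is sorted by minValue, HasOverlappingShardInterval only needs to compare each shard with its immediate predecessor: if any two intervals intersect at all, some adjacent pair must. The same scan on plain int ranges, with hypothetical names standing in for the FmgrInfo-based comparator:

#include <stdbool.h>
#include <stdio.h>

typedef struct IntInterval { int min; int max; } IntInterval;

/* assumes intervals are sorted ascending by min, mirroring the cached shard array */
static bool
HasOverlap(const IntInterval *intervals, int count)
{
	int i;

	for (i = 1; i < count; i++)
	{
		/* the previous max reaching the next min means the ranges intersect */
		if (intervals[i - 1].max >= intervals[i].min)
		{
			return true;
		}
	}
	return false;
}

int
main(void)
{
	IntInterval disjoint[] = { { 0, 9 }, { 10, 19 }, { 20, 29 } };
	IntInterval overlapping[] = { { 0, 10 }, { 10, 19 }, { 20, 29 } };

	printf("%d %d\n", HasOverlap(disjoint, 3), HasOverlap(overlapping, 3)); /* 0 1 */
	return 0;
}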
@@ -2153,6 +2217,7 @@ ResetDistTableCacheEntry(DistTableCacheEntry *cacheEntry)
 	cacheEntry->shardIntervalArrayLength = 0;
 	cacheEntry->hasUninitializedShardInterval = false;
 	cacheEntry->hasUniformHashDistribution = false;
+	cacheEntry->hasOverlappingShardInterval = false;
 }
@@ -2415,8 +2480,11 @@ LookupShardRelation(int64 shardId)
  */
 static void
 GetPartitionTypeInputInfo(char *partitionKeyString, char partitionMethod,
+						  Oid *columnTypeId, int32 *columnTypeMod,
 						  Oid *intervalTypeId, int32 *intervalTypeMod)
 {
+	*columnTypeId = InvalidOid;
+	*columnTypeMod = -1;
 	*intervalTypeId = InvalidOid;
 	*intervalTypeMod = -1;
 
@@ -2431,18 +2499,25 @@ GetPartitionTypeInputInfo(char *partitionKeyString, char partitionMethod,
 			*intervalTypeId = partitionColumn->vartype;
 			*intervalTypeMod = partitionColumn->vartypmod;
+			*columnTypeId = partitionColumn->vartype;
+			*columnTypeMod = partitionColumn->vartypmod;
 			break;
 		}
 
 		case DISTRIBUTE_BY_HASH:
 		{
+			Node *partitionNode = stringToNode(partitionKeyString);
+			Var *partitionColumn = (Var *) partitionNode;
+			Assert(IsA(partitionNode, Var));
+
 			*intervalTypeId = INT4OID;
+			*columnTypeId = partitionColumn->vartype;
+			*columnTypeMod = partitionColumn->vartypmod;
 			break;
 		}
 
 		case DISTRIBUTE_BY_NONE:
 		{
-			*intervalTypeId = InvalidOid;
 			break;
 		}
diff --git a/src/backend/distributed/utils/shardinterval_utils.c b/src/backend/distributed/utils/shardinterval_utils.c
index 8a63375e8..046ba3208 100644
--- a/src/backend/distributed/utils/shardinterval_utils.c
+++ b/src/backend/distributed/utils/shardinterval_utils.c
@@ -16,6 +16,7 @@
 #include "catalog/pg_type.h"
 #include "distributed/metadata_cache.h"
 #include "distributed/multi_planner.h"
+#include "distributed/shard_pruning.h"
 #include "distributed/shardinterval_utils.h"
 #include "distributed/pg_dist_partition.h"
 #include "distributed/worker_protocol.h"
@@ -23,7 +24,6 @@
 #include "utils/memutils.h"
 
-static int FindShardIntervalIndex(Datum searchedValue, DistTableCacheEntry *cacheEntry);
 static int SearchCachedShardInterval(Datum partitionColumnValue,
 									 ShardInterval **shardIntervalCache, int shardCount,
 									 FmgrInfo *compareFunction);
@@ -247,13 +247,14 @@ FindShardInterval(Datum partitionColumnValue, DistTableCacheEntry *cacheEntry)
  * the searched value. Note that the searched value must be the hashed value
  * of the original value if the distribution method is hash.
  *
- * Note that, if the searched value can not be found for hash partitioned tables,
- * we error out. This should only happen if something is terribly wrong, either
- * metadata tables are corrupted or we have a bug somewhere. Such as a hash
- * function which returns a value not in the range of [INT32_MIN, INT32_MAX] can
- * fire this.
+ * Note that if the searched value cannot be found for hash partitioned
+ * tables, we error out (unless there are no shards, in which case
+ * INVALID_SHARD_INDEX is returned). This should only happen if something is
+ * terribly wrong: either the metadata tables are corrupted or we have a bug
+ * somewhere. For example, a hash function that returns a value outside the
+ * range [INT32_MIN, INT32_MAX] could trigger this.
  */
-static int
+int
 FindShardIntervalIndex(Datum searchedValue, DistTableCacheEntry *cacheEntry)
 {
 	ShardInterval **shardIntervalCache = cacheEntry->sortedShardIntervalArray;
@@ -264,6 +265,11 @@ FindShardIntervalIndex(Datum searchedValue, DistTableCacheEntry *cacheEntry)
 								!cacheEntry->hasUniformHashDistribution);
 	int shardIndex = INVALID_SHARD_INDEX;
 
+	if (shardCount == 0)
+	{
+		return INVALID_SHARD_INDEX;
+	}
+
 	if (partitionMethod == DISTRIBUTE_BY_HASH)
 	{
 		if (useBinarySearch)
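With the zero-shard guard in place, FindShardIntervalIndex has three outcomes: INVALID_SHARD_INDEX when the table has no shards, a direct computation when hash shards uniformly divide the hash space, and a binary search otherwise. A sketch of the direct-computation path follows; the increment formula and the clamping are the conventional even split, assumed here for illustration rather than quoted from the patch.

#include <stdint.h>
#include <stdio.h>

#define INVALID_SHARD_INDEX -1
#define HASH_TOKEN_COUNT 4294967296ULL /* 2^32 */

/* maps a hashed value onto one of shardCount uniformly sized hash shards */
static int
UniformHashShardIndex(int32_t hashedValue, int shardCount)
{
	uint64_t increment;
	int shardIndex;

	if (shardCount == 0)
	{
		return INVALID_SHARD_INDEX; /* mirrors the new early return above */
	}

	increment = HASH_TOKEN_COUNT / shardCount;
	shardIndex = (int) (((uint32_t) hashedValue - (uint32_t) INT32_MIN) / increment);

	/* a division remainder can push the last few values past the end; clamp them */
	if (shardIndex >= shardCount)
	{
		shardIndex = shardCount - 1;
	}
	return shardIndex;
}

int
main(void)
{
	printf("%d\n", UniformHashShardIndex(0, 4));         /* 2 */
	printf("%d\n", UniformHashShardIndex(INT32_MIN, 4)); /* 0 */
	printf("%d\n", UniformHashShardIndex(INT32_MAX, 4)); /* 3 */
	return 0;
}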
diff --git a/src/include/distributed/metadata_cache.h b/src/include/distributed/metadata_cache.h
index 24324aa56..ae83940d4 100644
--- a/src/include/distributed/metadata_cache.h
+++ b/src/include/distributed/metadata_cache.h
@@ -38,6 +38,7 @@ typedef struct
 	bool isDistributedTable;
 	bool hasUninitializedShardInterval;
 	bool hasUniformHashDistribution; /* valid for hash partitioned tables */
+	bool hasOverlappingShardInterval;
 
 	/* pg_dist_partition metadata for this table */
 	char *partitionKeyString;
@@ -49,7 +50,15 @@ typedef struct
 	int shardIntervalArrayLength;
 	ShardInterval **sortedShardIntervalArray;
-	FmgrInfo *shardIntervalCompareFunction; /* NULL if no shard intervals exist */
+	/* comparator for the partition column's type, NULL if DISTRIBUTE_BY_NONE */
+	FmgrInfo *shardColumnCompareFunction;
+
+	/*
+	 * Comparator for the partition interval type (different from
+	 * shardColumnCompareFunction if hash-partitioned), NULL if
+	 * DISTRIBUTE_BY_NONE.
+	 */
+	FmgrInfo *shardIntervalCompareFunction;
 	FmgrInfo *hashFunction; /* NULL if table is not distributed by hash */
 
 	/* pg_dist_shard_placement metadata */
diff --git a/src/include/distributed/multi_physical_planner.h b/src/include/distributed/multi_physical_planner.h
index 20512bc3e..0c701c479 100644
--- a/src/include/distributed/multi_physical_planner.h
+++ b/src/include/distributed/multi_physical_planner.h
@@ -253,10 +253,6 @@ extern StringInfo ShardFetchQueryString(uint64 shardId);
 extern Task * CreateBasicTask(uint64 jobId, uint32 taskId, TaskType taskType,
 							  char *queryString);
 
-/* Function declarations for shard pruning */
-extern List * PruneShardList(Oid relationId, Index tableId, List *whereClauseList,
-							 List *shardList);
-extern bool ContainsFalseClause(List *whereClauseList);
 extern OpExpr * MakeOpExpression(Var *variable, int16 strategyNumber);
 
 /*
diff --git a/src/include/distributed/shard_pruning.h b/src/include/distributed/shard_pruning.h
new file mode 100644
index 000000000..3c26c4662
--- /dev/null
+++ b/src/include/distributed/shard_pruning.h
@@ -0,0 +1,23 @@
+/*-------------------------------------------------------------------------
+ *
+ * shard_pruning.h
+ *   Shard pruning infrastructure.
+ *
+ * Copyright (c) 2014-2017, Citus Data, Inc.
+ *
+ *-------------------------------------------------------------------------
+ */
+
+#ifndef SHARD_PRUNING_H_
+#define SHARD_PRUNING_H_
+
+#include "distributed/metadata_cache.h"
+#include "nodes/primnodes.h"
+
+#define INVALID_SHARD_INDEX -1
+
+/* Function declarations for shard pruning */
+extern List * PruneShards(Oid relationId, Index rangeTableId, List *whereClauseList);
+extern bool ContainsFalseClause(List *whereClauseList);
+
+#endif /* SHARD_PRUNING_H_ */
diff --git a/src/include/distributed/shardinterval_utils.h b/src/include/distributed/shardinterval_utils.h
index 54b96f2e7..5c9d6cde8 100644
--- a/src/include/distributed/shardinterval_utils.h
+++ b/src/include/distributed/shardinterval_utils.h
@@ -35,6 +35,7 @@ extern int CompareRelationShards(const void *leftElement,
 extern int ShardIndex(ShardInterval *shardInterval);
 extern ShardInterval * FindShardInterval(Datum partitionColumnValue,
 										 DistTableCacheEntry *cacheEntry);
+extern int FindShardIntervalIndex(Datum searchedValue, DistTableCacheEntry *cacheEntry);
 extern bool SingleReplicatedTable(Oid relationId);
 
 #endif /* SHARDINTERVAL_UTILS_H_ */
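The two headers above now carry the pruner's whole public surface. Below is a sketch of how a caller might drive it, modeled on the PrunedShardIdsForTable test helper earlier in this patch; the ShardsForValue helper, its include set, and the way the Const is filled in are assumptions for illustration, not part of the patch.

#include "postgres.h"

#include "access/stratnum.h"
#include "distributed/multi_physical_planner.h" /* MakeOpExpression() */
#include "distributed/shard_pruning.h"          /* PruneShards() */
#include "nodes/pg_list.h"
#include "optimizer/clauses.h"                  /* get_rightop() */

/*
 * ShardsForValue (hypothetical) returns the shard intervals of relationId
 * that might contain partitionValue, by handing PruneShards a single
 * equality clause on the given partition column.
 */
static List *
ShardsForValue(Oid relationId, Var *partitionColumn, Datum partitionValue)
{
	Index rangeTableId = 1;
	OpExpr *equalityExpr = MakeOpExpression(partitionColumn,
											BTEqualStrategyNumber);
	Const *rightConst = (Const *) get_rightop((Expr *) equalityExpr);

	/* fill the placeholder constant with the value being searched for */
	rightConst->constvalue = partitionValue;
	rightConst->constisnull = false;

	return PruneShards(relationId, rangeTableId, list_make1(equalityExpr));
}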
diff --git a/src/test/regress/expected/multi_data_types.out b/src/test/regress/expected/multi_data_types.out
index 76fd44367..1cb4d1f9b 100644
--- a/src/test/regress/expected/multi_data_types.out
+++ b/src/test/regress/expected/multi_data_types.out
@@ -10,8 +10,13 @@ CREATE TYPE test_composite_type AS (
 );
 -- ... as well as a function to use as its comparator...
 CREATE FUNCTION equal_test_composite_type_function(test_composite_type, test_composite_type) RETURNS boolean
-AS 'select $1.i = $2.i AND $1.i2 = $2.i2;'
-LANGUAGE SQL
+LANGUAGE 'internal'
+AS 'record_eq'
+IMMUTABLE
+RETURNS NULL ON NULL INPUT;
+CREATE FUNCTION cmp_test_composite_type_function(test_composite_type, test_composite_type) RETURNS int
+LANGUAGE 'internal'
+AS 'btrecordcmp'
 IMMUTABLE
 RETURNS NULL ON NULL INPUT;
 -- ... use that function to create a custom equality operator...
@@ -34,7 +39,8 @@ RETURNS NULL ON NULL INPUT;
 -- One uses BTREE the other uses HASH
 CREATE OPERATOR CLASS cats_op_fam_clas3
 DEFAULT FOR TYPE test_composite_type USING BTREE AS
-OPERATOR 3 = (test_composite_type, test_composite_type);
+OPERATOR 3 = (test_composite_type, test_composite_type),
+FUNCTION 1 cmp_test_composite_type_function(test_composite_type, test_composite_type);
 CREATE OPERATOR CLASS cats_op_fam_class
 DEFAULT FOR TYPE test_composite_type USING HASH AS
 OPERATOR 1 = (test_composite_type, test_composite_type),
diff --git a/src/test/regress/expected/multi_prune_shard_list.out b/src/test/regress/expected/multi_prune_shard_list.out
index 07d293498..4ddaaae9b 100644
--- a/src/test/regress/expected/multi_prune_shard_list.out
+++ b/src/test/regress/expected/multi_prune_shard_list.out
@@ -69,21 +69,21 @@ SELECT prune_using_single_value('pruning', NULL);
 SELECT prune_using_either_value('pruning', 'tomato', 'petunia');
  prune_using_either_value
 --------------------------
- {800001,800002}
+ {800002,800001}
 (1 row)
 
--- an AND clause with incompatible values returns no shards
+-- an AND clause with values on different shards returns no shards
 SELECT prune_using_both_values('pruning', 'tomato', 'petunia');
  prune_using_both_values
 -------------------------
  {}
 (1 row)
 
--- but if both values are on the same shard, should get back that shard
+-- even if both values are on the same shard, a value can't be equal to two others
 SELECT prune_using_both_values('pruning', 'tomato', 'rose');
  prune_using_both_values
 -------------------------
- {800002}
+ {}
 (1 row)
 
 -- unit test of the equality expression generation code
diff --git a/src/test/regress/input/multi_outer_join_reference.source b/src/test/regress/input/multi_outer_join_reference.source
index f1e5946c4..9a106eb61 100644
--- a/src/test/regress/input/multi_outer_join_reference.source
+++ b/src/test/regress/input/multi_outer_join_reference.source
@@ -166,7 +166,7 @@ FROM
 -- load some more data
 \copy multi_outer_join_right_reference FROM '@abs_srcdir@/data/customer-21-30.data' with delimiter '|'
 
--- Update shards so that they do not have 1-1 matching. We should error here.
+-- Update shards so that they do not have 1-1 matching, triggering an error.
 UPDATE pg_dist_shard SET shardminvalue = '2147483646' WHERE shardid = 1260006;
 UPDATE pg_dist_shard SET shardmaxvalue = '2147483647' WHERE shardid = 1260006;
 SELECT
diff --git a/src/test/regress/output/multi_outer_join_reference.source b/src/test/regress/output/multi_outer_join_reference.source
index 82d1b88ed..a2b17a0fa 100644
--- a/src/test/regress/output/multi_outer_join_reference.source
+++ b/src/test/regress/output/multi_outer_join_reference.source
@@ -228,15 +228,14 @@ LOG:  join order: [ "multi_outer_join_left_hash" ][ broadcast join "multi_outer_
 -- load some more data
 \copy multi_outer_join_right_reference FROM '@abs_srcdir@/data/customer-21-30.data' with delimiter '|'
 
--- Update shards so that they do not have 1-1 matching. We should error here.
+-- Update shards so that they do not have 1-1 matching, triggering an error.
 UPDATE pg_dist_shard SET shardminvalue = '2147483646' WHERE shardid = 1260006;
 UPDATE pg_dist_shard SET shardmaxvalue = '2147483647' WHERE shardid = 1260006;
 SELECT
 	min(l_custkey), max(l_custkey)
 FROM
 	multi_outer_join_left_hash a LEFT JOIN multi_outer_join_right_hash b ON (l_custkey = r_custkey);
-ERROR:  cannot perform distributed planning on this query
-DETAIL:  Shards of relations in outer join queries must have 1-to-1 shard partitioning
+ERROR:  hash partitioned table has overlapping shards
 UPDATE pg_dist_shard SET shardminvalue = '-2147483648' WHERE shardid = 1260006;
 UPDATE pg_dist_shard SET shardmaxvalue = '-1073741825' WHERE shardid = 1260006;
 -- empty tables
diff --git a/src/test/regress/sql/multi_data_types.sql b/src/test/regress/sql/multi_data_types.sql
index 315550474..3dd6eb0f7 100644
--- a/src/test/regress/sql/multi_data_types.sql
+++ b/src/test/regress/sql/multi_data_types.sql
@@ -15,8 +15,14 @@ CREATE TYPE test_composite_type AS (
 -- ... as well as a function to use as its comparator...
 CREATE FUNCTION equal_test_composite_type_function(test_composite_type, test_composite_type) RETURNS boolean
-AS 'select $1.i = $2.i AND $1.i2 = $2.i2;'
-LANGUAGE SQL
+LANGUAGE 'internal'
+AS 'record_eq'
+IMMUTABLE
+RETURNS NULL ON NULL INPUT;
+
+CREATE FUNCTION cmp_test_composite_type_function(test_composite_type, test_composite_type) RETURNS int
+LANGUAGE 'internal'
+AS 'btrecordcmp'
 IMMUTABLE
 RETURNS NULL ON NULL INPUT;
 
@@ -44,7 +50,8 @@ RETURNS NULL ON NULL INPUT;
 -- One uses BTREE the other uses HASH
 CREATE OPERATOR CLASS cats_op_fam_clas3
 DEFAULT FOR TYPE test_composite_type USING BTREE AS
-OPERATOR 3 = (test_composite_type, test_composite_type);
+OPERATOR 3 = (test_composite_type, test_composite_type),
+FUNCTION 1 cmp_test_composite_type_function(test_composite_type, test_composite_type);
 
 CREATE OPERATOR CLASS cats_op_fam_class
 DEFAULT FOR TYPE test_composite_type USING HASH AS
diff --git a/src/test/regress/sql/multi_prune_shard_list.sql b/src/test/regress/sql/multi_prune_shard_list.sql
index 0e9b9f599..4f42ee9f7 100644
--- a/src/test/regress/sql/multi_prune_shard_list.sql
+++ b/src/test/regress/sql/multi_prune_shard_list.sql
@@ -59,10 +59,10 @@ SELECT prune_using_single_value('pruning', NULL);
 -- build an OR clause and expect more than one sahrd
 SELECT prune_using_either_value('pruning', 'tomato', 'petunia');
 
--- an AND clause with incompatible values returns no shards
+-- an AND clause with values on different shards returns no shards
 SELECT prune_using_both_values('pruning', 'tomato', 'petunia');
 
--- but if both values are on the same shard, should get back that shard
+-- even if both values are on the same shard, a value can't be equal to two others
 SELECT prune_using_both_values('pruning', 'tomato', 'rose');
 
 -- unit test of the equality expression generation code