From 42ee7c05f520a45c20f311b54e2cd8aa47d7d95f Mon Sep 17 00:00:00 2001 From: Jason Petersen Date: Wed, 26 Apr 2017 11:33:04 -0600 Subject: [PATCH 1/2] Refactor FindShardInterval to use cacheEntry All callers fetch a cache entry and extract/compute arguments for the eventual FindShardInterval call, so it makes more sense to refactor into that function itself; this solves the use-after-free bug, too. --- src/backend/distributed/commands/multi_copy.c | 20 +-------- .../planner/multi_router_planner.c | 24 +---------- .../distributed/utils/shardinterval_utils.c | 43 ++++++------------- src/include/distributed/multi_copy.h | 1 - src/include/distributed/shardinterval_utils.h | 6 +-- 5 files changed, 18 insertions(+), 76 deletions(-) diff --git a/src/backend/distributed/commands/multi_copy.c b/src/backend/distributed/commands/multi_copy.c index c769bf3e0..770ff4799 100644 --- a/src/backend/distributed/commands/multi_copy.c +++ b/src/backend/distributed/commands/multi_copy.c @@ -1753,12 +1753,6 @@ CitusCopyDestReceiverStartup(DestReceiver *dest, int operation, /* keep the table metadata to avoid looking it up for every tuple */ copyDest->tableMetadata = cacheEntry; - /* determine whether to use binary search */ - if (partitionMethod != DISTRIBUTE_BY_HASH || !cacheEntry->hasUniformHashDistribution) - { - copyDest->useBinarySearch = true; - } - if (cacheEntry->replicationModel == REPLICATION_MODEL_2PC) { CoordinatedTransactionUse2PC(); @@ -1835,19 +1829,10 @@ CitusCopyDestReceiverReceive(TupleTableSlot *slot, DestReceiver *dest) { CitusCopyDestReceiver *copyDest = (CitusCopyDestReceiver *) dest; - DistTableCacheEntry *tableMetadata = copyDest->tableMetadata; - char partitionMethod = tableMetadata->partitionMethod; int partitionColumnIndex = copyDest->partitionColumnIndex; TupleDesc tupleDescriptor = copyDest->tupleDescriptor; CopyStmt *copyStatement = copyDest->copyStatement; - int shardCount = tableMetadata->shardIntervalArrayLength; - ShardInterval **shardIntervalCache = tableMetadata->sortedShardIntervalArray; - - bool useBinarySearch = copyDest->useBinarySearch; - FmgrInfo *hashFunction = tableMetadata->hashFunction; - FmgrInfo *compareFunction = tableMetadata->shardIntervalCompareFunction; - HTAB *shardConnectionHash = copyDest->shardConnectionHash; CopyOutState copyOutState = copyDest->copyOutState; FmgrInfo *columnOutputFunctions = copyDest->columnOutputFunctions; @@ -1907,10 +1892,7 @@ CitusCopyDestReceiverReceive(TupleTableSlot *slot, DestReceiver *dest) * For reference table, this function blindly returns the tables single * shard. */ - shardInterval = FindShardInterval(partitionColumnValue, shardIntervalCache, - shardCount, partitionMethod, - compareFunction, hashFunction, - useBinarySearch); + shardInterval = FindShardInterval(partitionColumnValue, copyDest->tableMetadata); if (shardInterval == NULL) { ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), diff --git a/src/backend/distributed/planner/multi_router_planner.c b/src/backend/distributed/planner/multi_router_planner.c index 8d0ded0f4..7cfbc8efc 100644 --- a/src/backend/distributed/planner/multi_router_planner.c +++ b/src/backend/distributed/planner/multi_router_planner.c @@ -2057,35 +2057,13 @@ ShardInterval * FastShardPruning(Oid distributedTableId, Datum partitionValue) { DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(distributedTableId); - int shardCount = cacheEntry->shardIntervalArrayLength; - ShardInterval **sortedShardIntervalArray = cacheEntry->sortedShardIntervalArray; - bool useBinarySearch = false; - char partitionMethod = cacheEntry->partitionMethod; - FmgrInfo *shardIntervalCompareFunction = cacheEntry->shardIntervalCompareFunction; - bool hasUniformHashDistribution = cacheEntry->hasUniformHashDistribution; - FmgrInfo *hashFunction = NULL; ShardInterval *shardInterval = NULL; - /* determine whether to use binary search */ - if (partitionMethod != DISTRIBUTE_BY_HASH || !hasUniformHashDistribution) - { - useBinarySearch = true; - } - - /* we only need hash functions for hash distributed tables */ - if (partitionMethod == DISTRIBUTE_BY_HASH) - { - hashFunction = cacheEntry->hashFunction; - } - /* * Call FindShardInterval to find the corresponding shard interval for the * given partition value. */ - shardInterval = FindShardInterval(partitionValue, sortedShardIntervalArray, - shardCount, partitionMethod, - shardIntervalCompareFunction, hashFunction, - useBinarySearch); + shardInterval = FindShardInterval(partitionValue, cacheEntry); return shardInterval; } diff --git a/src/backend/distributed/utils/shardinterval_utils.c b/src/backend/distributed/utils/shardinterval_utils.c index f38f218c9..8a63375e8 100644 --- a/src/backend/distributed/utils/shardinterval_utils.c +++ b/src/backend/distributed/utils/shardinterval_utils.c @@ -23,9 +23,7 @@ #include "utils/memutils.h" -static int FindShardIntervalIndex(Datum searchedValue, ShardInterval **shardIntervalCache, - int shardCount, char partitionMethod, - FmgrInfo *compareFunction, bool useBinarySearch); +static int FindShardIntervalIndex(Datum searchedValue, DistTableCacheEntry *cacheEntry); static int SearchCachedShardInterval(Datum partitionColumnValue, ShardInterval **shardIntervalCache, int shardCount, FmgrInfo *compareFunction); @@ -188,12 +186,7 @@ ShardIndex(ShardInterval *shardInterval) Datum shardMinValue = shardInterval->minValue; DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(distributedTableId); - ShardInterval **shardIntervalCache = cacheEntry->sortedShardIntervalArray; - int shardCount = cacheEntry->shardIntervalArrayLength; char partitionMethod = cacheEntry->partitionMethod; - FmgrInfo *compareFunction = cacheEntry->shardIntervalCompareFunction; - bool hasUniformHashDistribution = cacheEntry->hasUniformHashDistribution; - bool useBinarySearch = false; /* * Note that, we can also support append and range distributed tables, but @@ -215,15 +208,7 @@ ShardIndex(ShardInterval *shardInterval) return shardIndex; } - /* determine whether to use binary search */ - if (partitionMethod != DISTRIBUTE_BY_HASH || !hasUniformHashDistribution) - { - useBinarySearch = true; - } - - shardIndex = FindShardIntervalIndex(shardMinValue, shardIntervalCache, - shardCount, partitionMethod, - compareFunction, useBinarySearch); + shardIndex = FindShardIntervalIndex(shardMinValue, cacheEntry); return shardIndex; } @@ -236,28 +221,24 @@ ShardIndex(ShardInterval *shardInterval) * as NULL for them. */ ShardInterval * -FindShardInterval(Datum partitionColumnValue, ShardInterval **shardIntervalCache, - int shardCount, char partitionMethod, FmgrInfo *compareFunction, - FmgrInfo *hashFunction, bool useBinarySearch) +FindShardInterval(Datum partitionColumnValue, DistTableCacheEntry *cacheEntry) { Datum searchedValue = partitionColumnValue; int shardIndex = INVALID_SHARD_INDEX; - if (partitionMethod == DISTRIBUTE_BY_HASH) + if (cacheEntry->partitionMethod == DISTRIBUTE_BY_HASH) { - searchedValue = FunctionCall1(hashFunction, partitionColumnValue); + searchedValue = FunctionCall1(cacheEntry->hashFunction, partitionColumnValue); } - shardIndex = FindShardIntervalIndex(searchedValue, shardIntervalCache, - shardCount, partitionMethod, - compareFunction, useBinarySearch); + shardIndex = FindShardIntervalIndex(searchedValue, cacheEntry); if (shardIndex == INVALID_SHARD_INDEX) { return NULL; } - return shardIntervalCache[shardIndex]; + return cacheEntry->sortedShardIntervalArray[shardIndex]; } @@ -273,10 +254,14 @@ FindShardInterval(Datum partitionColumnValue, ShardInterval **shardIntervalCache * fire this. */ static int -FindShardIntervalIndex(Datum searchedValue, ShardInterval **shardIntervalCache, - int shardCount, char partitionMethod, FmgrInfo *compareFunction, - bool useBinarySearch) +FindShardIntervalIndex(Datum searchedValue, DistTableCacheEntry *cacheEntry) { + ShardInterval **shardIntervalCache = cacheEntry->sortedShardIntervalArray; + int shardCount = cacheEntry->shardIntervalArrayLength; + char partitionMethod = cacheEntry->partitionMethod; + FmgrInfo *compareFunction = cacheEntry->shardIntervalCompareFunction; + bool useBinarySearch = (partitionMethod != DISTRIBUTE_BY_HASH || + !cacheEntry->hasUniformHashDistribution); int shardIndex = INVALID_SHARD_INDEX; if (partitionMethod == DISTRIBUTE_BY_HASH) diff --git a/src/include/distributed/multi_copy.h b/src/include/distributed/multi_copy.h index cd3b56b73..1b6fcc412 100644 --- a/src/include/distributed/multi_copy.h +++ b/src/include/distributed/multi_copy.h @@ -61,7 +61,6 @@ typedef struct CitusCopyDestReceiver /* distributed table metadata */ DistTableCacheEntry *tableMetadata; - bool useBinarySearch; /* open relation handle */ Relation distributedRelation; diff --git a/src/include/distributed/shardinterval_utils.h b/src/include/distributed/shardinterval_utils.h index 0d3b8fa15..54b96f2e7 100644 --- a/src/include/distributed/shardinterval_utils.h +++ b/src/include/distributed/shardinterval_utils.h @@ -13,6 +13,7 @@ #define SHARDINTERVAL_UTILS_H_ #include "distributed/master_metadata_utility.h" +#include "distributed/metadata_cache.h" #include "nodes/primnodes.h" #define INVALID_SHARD_INDEX -1 @@ -33,10 +34,7 @@ extern int CompareRelationShards(const void *leftElement, const void *rightElement); extern int ShardIndex(ShardInterval *shardInterval); extern ShardInterval * FindShardInterval(Datum partitionColumnValue, - ShardInterval **shardIntervalCache, - int shardCount, char partitionMethod, - FmgrInfo *compareFunction, - FmgrInfo *hashFunction, bool useBinarySearch); + DistTableCacheEntry *cacheEntry); extern bool SingleReplicatedTable(Oid relationId); #endif /* SHARDINTERVAL_UTILS_H_ */ From 93e3afc25c425a3d6ea7cdd67eeafb48a4299629 Mon Sep 17 00:00:00 2001 From: Jason Petersen Date: Thu, 27 Apr 2017 13:15:48 -0600 Subject: [PATCH 2/2] Remove FastShardPruning method With the other simplifications, it doesn't make sense to keep around. --- .../planner/multi_router_planner.c | 27 +++---------------- src/backend/distributed/utils/node_metadata.c | 3 ++- .../distributed/multi_router_planner.h | 1 - 3 files changed, 5 insertions(+), 26 deletions(-) diff --git a/src/backend/distributed/planner/multi_router_planner.c b/src/backend/distributed/planner/multi_router_planner.c index 7cfbc8efc..2829ecdcc 100644 --- a/src/backend/distributed/planner/multi_router_planner.c +++ b/src/backend/distributed/planner/multi_router_planner.c @@ -1974,8 +1974,9 @@ FindShardForInsert(Query *query, DeferredErrorMessage **planningError) if (partitionMethod == DISTRIBUTE_BY_HASH || partitionMethod == DISTRIBUTE_BY_RANGE) { Datum partitionValue = partitionValueConst->constvalue; - ShardInterval *shardInterval = FastShardPruning(distributedTableId, - partitionValue); + DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(distributedTableId); + ShardInterval *shardInterval = FindShardInterval(partitionValue, cacheEntry); + if (shardInterval != NULL) { prunedShardList = list_make1(shardInterval); @@ -2047,28 +2048,6 @@ FindShardForInsert(Query *query, DeferredErrorMessage **planningError) } -/* - * FastShardPruning is a higher level API for FindShardInterval function. Given the - * relationId of the distributed table and partitionValue, FastShardPruning function finds - * the corresponding shard interval that the partitionValue should be in. FastShardPruning - * returns NULL if no ShardIntervals exist for the given partitionValue. - */ -ShardInterval * -FastShardPruning(Oid distributedTableId, Datum partitionValue) -{ - DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(distributedTableId); - ShardInterval *shardInterval = NULL; - - /* - * Call FindShardInterval to find the corresponding shard interval for the - * given partition value. - */ - shardInterval = FindShardInterval(partitionValue, cacheEntry); - - return shardInterval; -} - - /* * FindShardForUpdateOrDelete finds the shard interval in which an UPDATE or * DELETE command should be applied, or sets planningError when the query diff --git a/src/backend/distributed/utils/node_metadata.c b/src/backend/distributed/utils/node_metadata.c index aa16e1502..fd4d8970b 100644 --- a/src/backend/distributed/utils/node_metadata.c +++ b/src/backend/distributed/utils/node_metadata.c @@ -332,6 +332,7 @@ get_shard_id_for_distribution_column(PG_FUNCTION_ARGS) char *distributionValueString = NULL; Datum inputDatum = 0; Datum distributionValueDatum = 0; + DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(relationId); /* if given table is not reference table, distributionValue cannot be NULL */ if (PG_ARGISNULL(1)) @@ -351,7 +352,7 @@ get_shard_id_for_distribution_column(PG_FUNCTION_ARGS) distributionValueDatum = StringToDatum(distributionValueString, distributionDataType); - shardInterval = FastShardPruning(relationId, distributionValueDatum); + shardInterval = FindShardInterval(distributionValueDatum, cacheEntry); } else { diff --git a/src/include/distributed/multi_router_planner.h b/src/include/distributed/multi_router_planner.h index 322cbf791..589724464 100644 --- a/src/include/distributed/multi_router_planner.h +++ b/src/include/distributed/multi_router_planner.h @@ -44,7 +44,6 @@ extern void AddShardIntervalRestrictionToSelect(Query *subqery, ShardInterval *shardInterval); extern ShardInterval * FindShardForInsert(Query *query, DeferredErrorMessage **planningError); -extern ShardInterval * FastShardPruning(Oid distributedTableId, Datum partitionValue); #endif /* MULTI_ROUTER_PLANNER_H */