From 8afb4b9f332472fba521bac42bc05dcfe8b52d65 Mon Sep 17 00:00:00 2001 From: Jason Petersen Date: Wed, 26 Apr 2017 11:33:04 -0600 Subject: [PATCH] Refactor FindShardInterval to use cacheEntry All callers fetch a cache entry and extract/compute arguments for the eventual FindShardInterval call, so it makes more sense to refactor into that function itself; this solves the use-after-free bug, too. Based on 42ee7c05f520a45c20f311b54e2cd8aa47d7d95f --- src/backend/distributed/commands/multi_copy.c | 32 +------------- .../planner/multi_router_planner.c | 24 +---------- .../distributed/utils/shardinterval_utils.c | 43 ++++++------------- src/include/distributed/shardinterval_utils.h | 6 +-- 4 files changed, 19 insertions(+), 86 deletions(-) diff --git a/src/backend/distributed/commands/multi_copy.c b/src/backend/distributed/commands/multi_copy.c index ec198e778..fb4379c17 100644 --- a/src/backend/distributed/commands/multi_copy.c +++ b/src/backend/distributed/commands/multi_copy.c @@ -279,18 +279,11 @@ CopyToExistingShards(CopyStmt *copyStatement, char *completionTag) uint32 columnCount = 0; Datum *columnValues = NULL; bool *columnNulls = NULL; - FmgrInfo *hashFunction = NULL; - FmgrInfo *compareFunction = NULL; - bool hasUniformHashDistribution = false; DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(tableId); const char *delimiterCharacter = "\t"; const char *nullPrintCharacter = "\\N"; - int shardCount = 0; List *shardIntervalList = NULL; - ShardInterval **shardIntervalCache = NULL; - bool useBinarySearch = false; - HTAB *shardConnectionHash = NULL; ShardConnections *shardConnections = NULL; List *shardConnectionsList = NIL; @@ -306,16 +299,10 @@ CopyToExistingShards(CopyStmt *copyStatement, char *completionTag) uint64 processedRowCount = 0; Var *partitionColumn = PartitionColumn(tableId, 0); - char partitionMethod = PartitionMethod(tableId); + char partitionMethod = cacheEntry->partitionMethod; ErrorContextCallback errorCallback; - /* get hash function for partition column */ - hashFunction = cacheEntry->hashFunction; - - /* get compare function for shard intervals */ - compareFunction = cacheEntry->shardIntervalCompareFunction; - /* allocate column values and nulls arrays */ distributedRelation = heap_open(tableId, RowExclusiveLock); tupleDescriptor = RelationGetDescr(distributedRelation); @@ -366,17 +353,6 @@ CopyToExistingShards(CopyStmt *copyStatement, char *completionTag) LockShardListMetadata(shardIntervalList, ShareLock); LockShardListResources(shardIntervalList, ShareLock); - /* initialize the shard interval cache */ - shardCount = cacheEntry->shardIntervalArrayLength; - shardIntervalCache = cacheEntry->sortedShardIntervalArray; - hasUniformHashDistribution = cacheEntry->hasUniformHashDistribution; - - /* determine whether to use binary search */ - if (partitionMethod != DISTRIBUTE_BY_HASH || !hasUniformHashDistribution) - { - useBinarySearch = true; - } - if (cacheEntry->replicationModel == REPLICATION_MODEL_2PC) { CoordinatedTransactionUse2PC(); @@ -462,11 +438,7 @@ CopyToExistingShards(CopyStmt *copyStatement, char *completionTag) * For reference table, this function blindly returns the tables single * shard. */ - shardInterval = FindShardInterval(partitionColumnValue, - shardIntervalCache, - shardCount, partitionMethod, - compareFunction, hashFunction, - useBinarySearch); + shardInterval = FindShardInterval(partitionColumnValue, cacheEntry); if (shardInterval == NULL) { diff --git a/src/backend/distributed/planner/multi_router_planner.c b/src/backend/distributed/planner/multi_router_planner.c index e178f1b3d..294771898 100644 --- a/src/backend/distributed/planner/multi_router_planner.c +++ b/src/backend/distributed/planner/multi_router_planner.c @@ -2162,35 +2162,13 @@ ShardInterval * FastShardPruning(Oid distributedTableId, Datum partitionValue) { DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(distributedTableId); - int shardCount = cacheEntry->shardIntervalArrayLength; - ShardInterval **sortedShardIntervalArray = cacheEntry->sortedShardIntervalArray; - bool useBinarySearch = false; - char partitionMethod = cacheEntry->partitionMethod; - FmgrInfo *shardIntervalCompareFunction = cacheEntry->shardIntervalCompareFunction; - bool hasUniformHashDistribution = cacheEntry->hasUniformHashDistribution; - FmgrInfo *hashFunction = NULL; ShardInterval *shardInterval = NULL; - /* determine whether to use binary search */ - if (partitionMethod != DISTRIBUTE_BY_HASH || !hasUniformHashDistribution) - { - useBinarySearch = true; - } - - /* we only need hash functions for hash distributed tables */ - if (partitionMethod == DISTRIBUTE_BY_HASH) - { - hashFunction = cacheEntry->hashFunction; - } - /* * Call FindShardInterval to find the corresponding shard interval for the * given partition value. */ - shardInterval = FindShardInterval(partitionValue, sortedShardIntervalArray, - shardCount, partitionMethod, - shardIntervalCompareFunction, hashFunction, - useBinarySearch); + shardInterval = FindShardInterval(partitionValue, cacheEntry); return shardInterval; } diff --git a/src/backend/distributed/utils/shardinterval_utils.c b/src/backend/distributed/utils/shardinterval_utils.c index f38f218c9..8a63375e8 100644 --- a/src/backend/distributed/utils/shardinterval_utils.c +++ b/src/backend/distributed/utils/shardinterval_utils.c @@ -23,9 +23,7 @@ #include "utils/memutils.h" -static int FindShardIntervalIndex(Datum searchedValue, ShardInterval **shardIntervalCache, - int shardCount, char partitionMethod, - FmgrInfo *compareFunction, bool useBinarySearch); +static int FindShardIntervalIndex(Datum searchedValue, DistTableCacheEntry *cacheEntry); static int SearchCachedShardInterval(Datum partitionColumnValue, ShardInterval **shardIntervalCache, int shardCount, FmgrInfo *compareFunction); @@ -188,12 +186,7 @@ ShardIndex(ShardInterval *shardInterval) Datum shardMinValue = shardInterval->minValue; DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(distributedTableId); - ShardInterval **shardIntervalCache = cacheEntry->sortedShardIntervalArray; - int shardCount = cacheEntry->shardIntervalArrayLength; char partitionMethod = cacheEntry->partitionMethod; - FmgrInfo *compareFunction = cacheEntry->shardIntervalCompareFunction; - bool hasUniformHashDistribution = cacheEntry->hasUniformHashDistribution; - bool useBinarySearch = false; /* * Note that, we can also support append and range distributed tables, but @@ -215,15 +208,7 @@ ShardIndex(ShardInterval *shardInterval) return shardIndex; } - /* determine whether to use binary search */ - if (partitionMethod != DISTRIBUTE_BY_HASH || !hasUniformHashDistribution) - { - useBinarySearch = true; - } - - shardIndex = FindShardIntervalIndex(shardMinValue, shardIntervalCache, - shardCount, partitionMethod, - compareFunction, useBinarySearch); + shardIndex = FindShardIntervalIndex(shardMinValue, cacheEntry); return shardIndex; } @@ -236,28 +221,24 @@ ShardIndex(ShardInterval *shardInterval) * as NULL for them. */ ShardInterval * -FindShardInterval(Datum partitionColumnValue, ShardInterval **shardIntervalCache, - int shardCount, char partitionMethod, FmgrInfo *compareFunction, - FmgrInfo *hashFunction, bool useBinarySearch) +FindShardInterval(Datum partitionColumnValue, DistTableCacheEntry *cacheEntry) { Datum searchedValue = partitionColumnValue; int shardIndex = INVALID_SHARD_INDEX; - if (partitionMethod == DISTRIBUTE_BY_HASH) + if (cacheEntry->partitionMethod == DISTRIBUTE_BY_HASH) { - searchedValue = FunctionCall1(hashFunction, partitionColumnValue); + searchedValue = FunctionCall1(cacheEntry->hashFunction, partitionColumnValue); } - shardIndex = FindShardIntervalIndex(searchedValue, shardIntervalCache, - shardCount, partitionMethod, - compareFunction, useBinarySearch); + shardIndex = FindShardIntervalIndex(searchedValue, cacheEntry); if (shardIndex == INVALID_SHARD_INDEX) { return NULL; } - return shardIntervalCache[shardIndex]; + return cacheEntry->sortedShardIntervalArray[shardIndex]; } @@ -273,10 +254,14 @@ FindShardInterval(Datum partitionColumnValue, ShardInterval **shardIntervalCache * fire this. */ static int -FindShardIntervalIndex(Datum searchedValue, ShardInterval **shardIntervalCache, - int shardCount, char partitionMethod, FmgrInfo *compareFunction, - bool useBinarySearch) +FindShardIntervalIndex(Datum searchedValue, DistTableCacheEntry *cacheEntry) { + ShardInterval **shardIntervalCache = cacheEntry->sortedShardIntervalArray; + int shardCount = cacheEntry->shardIntervalArrayLength; + char partitionMethod = cacheEntry->partitionMethod; + FmgrInfo *compareFunction = cacheEntry->shardIntervalCompareFunction; + bool useBinarySearch = (partitionMethod != DISTRIBUTE_BY_HASH || + !cacheEntry->hasUniformHashDistribution); int shardIndex = INVALID_SHARD_INDEX; if (partitionMethod == DISTRIBUTE_BY_HASH) diff --git a/src/include/distributed/shardinterval_utils.h b/src/include/distributed/shardinterval_utils.h index 0d3b8fa15..54b96f2e7 100644 --- a/src/include/distributed/shardinterval_utils.h +++ b/src/include/distributed/shardinterval_utils.h @@ -13,6 +13,7 @@ #define SHARDINTERVAL_UTILS_H_ #include "distributed/master_metadata_utility.h" +#include "distributed/metadata_cache.h" #include "nodes/primnodes.h" #define INVALID_SHARD_INDEX -1 @@ -33,10 +34,7 @@ extern int CompareRelationShards(const void *leftElement, const void *rightElement); extern int ShardIndex(ShardInterval *shardInterval); extern ShardInterval * FindShardInterval(Datum partitionColumnValue, - ShardInterval **shardIntervalCache, - int shardCount, char partitionMethod, - FmgrInfo *compareFunction, - FmgrInfo *hashFunction, bool useBinarySearch); + DistTableCacheEntry *cacheEntry); extern bool SingleReplicatedTable(Oid relationId); #endif /* SHARDINTERVAL_UTILS_H_ */