Refactor FindShardInterval to use cacheEntry

All callers fetch a cache entry and extract/compute arguments for the
eventual FindShardInterval call, so it makes more sense to refactor
into that function itself; this solves the use-after-free bug, too.

Based on 42ee7c05f5
pull/1385/head
Jason Petersen 2017-04-26 11:33:04 -06:00
parent 9a2a9664ca
commit 8afb4b9f33
No known key found for this signature in database
GPG Key ID: 9F1D3510D110ABA9
4 changed files with 19 additions and 86 deletions

View File

@ -279,18 +279,11 @@ CopyToExistingShards(CopyStmt *copyStatement, char *completionTag)
uint32 columnCount = 0; uint32 columnCount = 0;
Datum *columnValues = NULL; Datum *columnValues = NULL;
bool *columnNulls = NULL; bool *columnNulls = NULL;
FmgrInfo *hashFunction = NULL;
FmgrInfo *compareFunction = NULL;
bool hasUniformHashDistribution = false;
DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(tableId); DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(tableId);
const char *delimiterCharacter = "\t"; const char *delimiterCharacter = "\t";
const char *nullPrintCharacter = "\\N"; const char *nullPrintCharacter = "\\N";
int shardCount = 0;
List *shardIntervalList = NULL; List *shardIntervalList = NULL;
ShardInterval **shardIntervalCache = NULL;
bool useBinarySearch = false;
HTAB *shardConnectionHash = NULL; HTAB *shardConnectionHash = NULL;
ShardConnections *shardConnections = NULL; ShardConnections *shardConnections = NULL;
List *shardConnectionsList = NIL; List *shardConnectionsList = NIL;
@ -306,16 +299,10 @@ CopyToExistingShards(CopyStmt *copyStatement, char *completionTag)
uint64 processedRowCount = 0; uint64 processedRowCount = 0;
Var *partitionColumn = PartitionColumn(tableId, 0); Var *partitionColumn = PartitionColumn(tableId, 0);
char partitionMethod = PartitionMethod(tableId); char partitionMethod = cacheEntry->partitionMethod;
ErrorContextCallback errorCallback; ErrorContextCallback errorCallback;
/* get hash function for partition column */
hashFunction = cacheEntry->hashFunction;
/* get compare function for shard intervals */
compareFunction = cacheEntry->shardIntervalCompareFunction;
/* allocate column values and nulls arrays */ /* allocate column values and nulls arrays */
distributedRelation = heap_open(tableId, RowExclusiveLock); distributedRelation = heap_open(tableId, RowExclusiveLock);
tupleDescriptor = RelationGetDescr(distributedRelation); tupleDescriptor = RelationGetDescr(distributedRelation);
@ -366,17 +353,6 @@ CopyToExistingShards(CopyStmt *copyStatement, char *completionTag)
LockShardListMetadata(shardIntervalList, ShareLock); LockShardListMetadata(shardIntervalList, ShareLock);
LockShardListResources(shardIntervalList, ShareLock); LockShardListResources(shardIntervalList, ShareLock);
/* initialize the shard interval cache */
shardCount = cacheEntry->shardIntervalArrayLength;
shardIntervalCache = cacheEntry->sortedShardIntervalArray;
hasUniformHashDistribution = cacheEntry->hasUniformHashDistribution;
/* determine whether to use binary search */
if (partitionMethod != DISTRIBUTE_BY_HASH || !hasUniformHashDistribution)
{
useBinarySearch = true;
}
if (cacheEntry->replicationModel == REPLICATION_MODEL_2PC) if (cacheEntry->replicationModel == REPLICATION_MODEL_2PC)
{ {
CoordinatedTransactionUse2PC(); CoordinatedTransactionUse2PC();
@ -462,11 +438,7 @@ CopyToExistingShards(CopyStmt *copyStatement, char *completionTag)
* For reference table, this function blindly returns the tables single * For reference table, this function blindly returns the tables single
* shard. * shard.
*/ */
shardInterval = FindShardInterval(partitionColumnValue, shardInterval = FindShardInterval(partitionColumnValue, cacheEntry);
shardIntervalCache,
shardCount, partitionMethod,
compareFunction, hashFunction,
useBinarySearch);
if (shardInterval == NULL) if (shardInterval == NULL)
{ {

View File

@ -2162,35 +2162,13 @@ ShardInterval *
FastShardPruning(Oid distributedTableId, Datum partitionValue) FastShardPruning(Oid distributedTableId, Datum partitionValue)
{ {
DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(distributedTableId); DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(distributedTableId);
int shardCount = cacheEntry->shardIntervalArrayLength;
ShardInterval **sortedShardIntervalArray = cacheEntry->sortedShardIntervalArray;
bool useBinarySearch = false;
char partitionMethod = cacheEntry->partitionMethod;
FmgrInfo *shardIntervalCompareFunction = cacheEntry->shardIntervalCompareFunction;
bool hasUniformHashDistribution = cacheEntry->hasUniformHashDistribution;
FmgrInfo *hashFunction = NULL;
ShardInterval *shardInterval = NULL; ShardInterval *shardInterval = NULL;
/* determine whether to use binary search */
if (partitionMethod != DISTRIBUTE_BY_HASH || !hasUniformHashDistribution)
{
useBinarySearch = true;
}
/* we only need hash functions for hash distributed tables */
if (partitionMethod == DISTRIBUTE_BY_HASH)
{
hashFunction = cacheEntry->hashFunction;
}
/* /*
* Call FindShardInterval to find the corresponding shard interval for the * Call FindShardInterval to find the corresponding shard interval for the
* given partition value. * given partition value.
*/ */
shardInterval = FindShardInterval(partitionValue, sortedShardIntervalArray, shardInterval = FindShardInterval(partitionValue, cacheEntry);
shardCount, partitionMethod,
shardIntervalCompareFunction, hashFunction,
useBinarySearch);
return shardInterval; return shardInterval;
} }

View File

@ -23,9 +23,7 @@
#include "utils/memutils.h" #include "utils/memutils.h"
static int FindShardIntervalIndex(Datum searchedValue, ShardInterval **shardIntervalCache, static int FindShardIntervalIndex(Datum searchedValue, DistTableCacheEntry *cacheEntry);
int shardCount, char partitionMethod,
FmgrInfo *compareFunction, bool useBinarySearch);
static int SearchCachedShardInterval(Datum partitionColumnValue, static int SearchCachedShardInterval(Datum partitionColumnValue,
ShardInterval **shardIntervalCache, ShardInterval **shardIntervalCache,
int shardCount, FmgrInfo *compareFunction); int shardCount, FmgrInfo *compareFunction);
@ -188,12 +186,7 @@ ShardIndex(ShardInterval *shardInterval)
Datum shardMinValue = shardInterval->minValue; Datum shardMinValue = shardInterval->minValue;
DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(distributedTableId); DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(distributedTableId);
ShardInterval **shardIntervalCache = cacheEntry->sortedShardIntervalArray;
int shardCount = cacheEntry->shardIntervalArrayLength;
char partitionMethod = cacheEntry->partitionMethod; char partitionMethod = cacheEntry->partitionMethod;
FmgrInfo *compareFunction = cacheEntry->shardIntervalCompareFunction;
bool hasUniformHashDistribution = cacheEntry->hasUniformHashDistribution;
bool useBinarySearch = false;
/* /*
* Note that, we can also support append and range distributed tables, but * Note that, we can also support append and range distributed tables, but
@ -215,15 +208,7 @@ ShardIndex(ShardInterval *shardInterval)
return shardIndex; return shardIndex;
} }
/* determine whether to use binary search */ shardIndex = FindShardIntervalIndex(shardMinValue, cacheEntry);
if (partitionMethod != DISTRIBUTE_BY_HASH || !hasUniformHashDistribution)
{
useBinarySearch = true;
}
shardIndex = FindShardIntervalIndex(shardMinValue, shardIntervalCache,
shardCount, partitionMethod,
compareFunction, useBinarySearch);
return shardIndex; return shardIndex;
} }
@ -236,28 +221,24 @@ ShardIndex(ShardInterval *shardInterval)
* as NULL for them. * as NULL for them.
*/ */
ShardInterval * ShardInterval *
FindShardInterval(Datum partitionColumnValue, ShardInterval **shardIntervalCache, FindShardInterval(Datum partitionColumnValue, DistTableCacheEntry *cacheEntry)
int shardCount, char partitionMethod, FmgrInfo *compareFunction,
FmgrInfo *hashFunction, bool useBinarySearch)
{ {
Datum searchedValue = partitionColumnValue; Datum searchedValue = partitionColumnValue;
int shardIndex = INVALID_SHARD_INDEX; int shardIndex = INVALID_SHARD_INDEX;
if (partitionMethod == DISTRIBUTE_BY_HASH) if (cacheEntry->partitionMethod == DISTRIBUTE_BY_HASH)
{ {
searchedValue = FunctionCall1(hashFunction, partitionColumnValue); searchedValue = FunctionCall1(cacheEntry->hashFunction, partitionColumnValue);
} }
shardIndex = FindShardIntervalIndex(searchedValue, shardIntervalCache, shardIndex = FindShardIntervalIndex(searchedValue, cacheEntry);
shardCount, partitionMethod,
compareFunction, useBinarySearch);
if (shardIndex == INVALID_SHARD_INDEX) if (shardIndex == INVALID_SHARD_INDEX)
{ {
return NULL; return NULL;
} }
return shardIntervalCache[shardIndex]; return cacheEntry->sortedShardIntervalArray[shardIndex];
} }
@ -273,10 +254,14 @@ FindShardInterval(Datum partitionColumnValue, ShardInterval **shardIntervalCache
* fire this. * fire this.
*/ */
static int static int
FindShardIntervalIndex(Datum searchedValue, ShardInterval **shardIntervalCache, FindShardIntervalIndex(Datum searchedValue, DistTableCacheEntry *cacheEntry)
int shardCount, char partitionMethod, FmgrInfo *compareFunction,
bool useBinarySearch)
{ {
ShardInterval **shardIntervalCache = cacheEntry->sortedShardIntervalArray;
int shardCount = cacheEntry->shardIntervalArrayLength;
char partitionMethod = cacheEntry->partitionMethod;
FmgrInfo *compareFunction = cacheEntry->shardIntervalCompareFunction;
bool useBinarySearch = (partitionMethod != DISTRIBUTE_BY_HASH ||
!cacheEntry->hasUniformHashDistribution);
int shardIndex = INVALID_SHARD_INDEX; int shardIndex = INVALID_SHARD_INDEX;
if (partitionMethod == DISTRIBUTE_BY_HASH) if (partitionMethod == DISTRIBUTE_BY_HASH)

View File

@ -13,6 +13,7 @@
#define SHARDINTERVAL_UTILS_H_ #define SHARDINTERVAL_UTILS_H_
#include "distributed/master_metadata_utility.h" #include "distributed/master_metadata_utility.h"
#include "distributed/metadata_cache.h"
#include "nodes/primnodes.h" #include "nodes/primnodes.h"
#define INVALID_SHARD_INDEX -1 #define INVALID_SHARD_INDEX -1
@ -33,10 +34,7 @@ extern int CompareRelationShards(const void *leftElement,
const void *rightElement); const void *rightElement);
extern int ShardIndex(ShardInterval *shardInterval); extern int ShardIndex(ShardInterval *shardInterval);
extern ShardInterval * FindShardInterval(Datum partitionColumnValue, extern ShardInterval * FindShardInterval(Datum partitionColumnValue,
ShardInterval **shardIntervalCache, DistTableCacheEntry *cacheEntry);
int shardCount, char partitionMethod,
FmgrInfo *compareFunction,
FmgrInfo *hashFunction, bool useBinarySearch);
extern bool SingleReplicatedTable(Oid relationId); extern bool SingleReplicatedTable(Oid relationId);
#endif /* SHARDINTERVAL_UTILS_H_ */ #endif /* SHARDINTERVAL_UTILS_H_ */