Prevent integer overflow in FindShardIntervalIndex

pull/4036/head
Marco Slot 2020-07-16 12:56:06 +02:00
parent d0b6e62c9a
commit b823f2127d
1 changed files with 46 additions and 13 deletions

View File

@ -27,6 +27,9 @@
#include "utils/memutils.h" #include "utils/memutils.h"
static int CalculateUniformHashRangeIndex(int hashedValue, int shardCount);
/* /*
* LowestShardIntervalById returns the shard interval with the lowest shard * LowestShardIntervalById returns the shard interval with the lowest shard
* ID from a list of shard intervals. * ID from a list of shard intervals.
@ -348,20 +351,8 @@ FindShardIntervalIndex(Datum searchedValue, CitusTableCacheEntry *cacheEntry)
else else
{ {
int hashedValue = DatumGetInt32(searchedValue); int hashedValue = DatumGetInt32(searchedValue);
uint64 hashTokenIncrement = HASH_TOKEN_COUNT / shardCount;
shardIndex = (uint32) (hashedValue - INT32_MIN) / hashTokenIncrement; shardIndex = CalculateUniformHashRangeIndex(hashedValue, shardCount);
Assert(shardIndex <= shardCount);
/*
* If the shard count is not power of 2, the range of the last
* shard becomes larger than others. For that extra piece of range,
* we still need to use the last shard.
*/
if (shardIndex == shardCount)
{
shardIndex = shardCount - 1;
}
} }
} }
else if (partitionMethod == DISTRIBUTE_BY_NONE) else if (partitionMethod == DISTRIBUTE_BY_NONE)
@ -442,6 +433,48 @@ SearchCachedShardInterval(Datum partitionColumnValue, ShardInterval **shardInter
} }
/*
* CalculateUniformHashRangeIndex returns the index of the hash range in
* which hashedValue falls, assuming shardCount uniform hash ranges.
*
* We use 64-bit integers to avoid overflow issues during arithmetic.
*
* NOTE: This function is ONLY for hash-distributed tables with uniform
* hash ranges.
*/
static int
CalculateUniformHashRangeIndex(int hashedValue, int shardCount)
{
int64 hashedValue64 = (int64) hashedValue;
/* normalize to the 0-UINT32_MAX range */
int64 normalizedHashValue = hashedValue64 - INT32_MIN;
/* size of each hash range */
int64 hashRangeSize = HASH_TOKEN_COUNT / shardCount;
/* index of hash range into which the hash value falls */
int shardIndex = (int) (normalizedHashValue / hashRangeSize);
if (shardIndex < 0 || shardIndex > shardCount)
{
ereport(ERROR, (errmsg("bug: shard index %d out of bounds", shardIndex)));
}
/*
* If the shard count is not power of 2, the range of the last
* shard becomes larger than others. For that extra piece of range,
* we still need to use the last shard.
*/
if (shardIndex == shardCount)
{
shardIndex = shardCount - 1;
}
return shardIndex;
}
/* /*
* SingleReplicatedTable checks whether all shards of a distributed table, do not have * SingleReplicatedTable checks whether all shards of a distributed table, do not have
* more than one replica. If even one shard has more than one replica, this function * more than one replica. If even one shard has more than one replica, this function