mirror of https://github.com/citusdata/citus.git
Prevent integer overflow in FindShardIntervalIndex
parent
d0b6e62c9a
commit
b823f2127d
|
@ -27,6 +27,9 @@
|
|||
#include "utils/memutils.h"
|
||||
|
||||
|
||||
static int CalculateUniformHashRangeIndex(int hashedValue, int shardCount);
|
||||
|
||||
|
||||
/*
|
||||
* LowestShardIntervalById returns the shard interval with the lowest shard
|
||||
* ID from a list of shard intervals.
|
||||
|
@ -348,20 +351,8 @@ FindShardIntervalIndex(Datum searchedValue, CitusTableCacheEntry *cacheEntry)
|
|||
else
|
||||
{
|
||||
int hashedValue = DatumGetInt32(searchedValue);
|
||||
uint64 hashTokenIncrement = HASH_TOKEN_COUNT / shardCount;
|
||||
|
||||
shardIndex = (uint32) (hashedValue - INT32_MIN) / hashTokenIncrement;
|
||||
Assert(shardIndex <= shardCount);
|
||||
|
||||
/*
|
||||
* If the shard count is not power of 2, the range of the last
|
||||
* shard becomes larger than others. For that extra piece of range,
|
||||
* we still need to use the last shard.
|
||||
*/
|
||||
if (shardIndex == shardCount)
|
||||
{
|
||||
shardIndex = shardCount - 1;
|
||||
}
|
||||
shardIndex = CalculateUniformHashRangeIndex(hashedValue, shardCount);
|
||||
}
|
||||
}
|
||||
else if (partitionMethod == DISTRIBUTE_BY_NONE)
|
||||
|
@ -442,6 +433,48 @@ SearchCachedShardInterval(Datum partitionColumnValue, ShardInterval **shardInter
|
|||
}
|
||||
|
||||
|
||||
/*
|
||||
* CalculateUniformHashRangeIndex returns the index of the hash range in
|
||||
* which hashedValue falls, assuming shardCount uniform hash ranges.
|
||||
*
|
||||
* We use 64-bit integers to avoid overflow issues during arithmetic.
|
||||
*
|
||||
* NOTE: This function is ONLY for hash-distributed tables with uniform
|
||||
* hash ranges.
|
||||
*/
|
||||
static int
|
||||
CalculateUniformHashRangeIndex(int hashedValue, int shardCount)
|
||||
{
|
||||
int64 hashedValue64 = (int64) hashedValue;
|
||||
|
||||
/* normalize to the 0-UINT32_MAX range */
|
||||
int64 normalizedHashValue = hashedValue64 - INT32_MIN;
|
||||
|
||||
/* size of each hash range */
|
||||
int64 hashRangeSize = HASH_TOKEN_COUNT / shardCount;
|
||||
|
||||
/* index of hash range into which the hash value falls */
|
||||
int shardIndex = (int) (normalizedHashValue / hashRangeSize);
|
||||
|
||||
if (shardIndex < 0 || shardIndex > shardCount)
|
||||
{
|
||||
ereport(ERROR, (errmsg("bug: shard index %d out of bounds", shardIndex)));
|
||||
}
|
||||
|
||||
/*
|
||||
* If the shard count is not power of 2, the range of the last
|
||||
* shard becomes larger than others. For that extra piece of range,
|
||||
* we still need to use the last shard.
|
||||
*/
|
||||
if (shardIndex == shardCount)
|
||||
{
|
||||
shardIndex = shardCount - 1;
|
||||
}
|
||||
|
||||
return shardIndex;
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
* SingleReplicatedTable checks whether all shards of a distributed table, do not have
|
||||
* more than one replica. If even one shard has more than one replica, this function
|
||||
|
|
Loading…
Reference in New Issue