mirror of https://github.com/citusdata/citus.git
Prevent integer overflow in FindShardIntervalIndex
parent
4b493f088b
commit
77b4534c72
|
@ -27,6 +27,9 @@
|
||||||
#include "utils/memutils.h"
|
#include "utils/memutils.h"
|
||||||
|
|
||||||
|
|
||||||
|
static int CalculateUniformHashRangeIndex(int hashedValue, int shardCount);
|
||||||
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* LowestShardIntervalById returns the shard interval with the lowest shard
|
* LowestShardIntervalById returns the shard interval with the lowest shard
|
||||||
* ID from a list of shard intervals.
|
* ID from a list of shard intervals.
|
||||||
|
@ -348,20 +351,8 @@ FindShardIntervalIndex(Datum searchedValue, CitusTableCacheEntry *cacheEntry)
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
int hashedValue = DatumGetInt32(searchedValue);
|
int hashedValue = DatumGetInt32(searchedValue);
|
||||||
uint64 hashTokenIncrement = HASH_TOKEN_COUNT / shardCount;
|
|
||||||
|
|
||||||
shardIndex = (uint32) (hashedValue - INT32_MIN) / hashTokenIncrement;
|
shardIndex = CalculateUniformHashRangeIndex(hashedValue, shardCount);
|
||||||
Assert(shardIndex <= shardCount);
|
|
||||||
|
|
||||||
/*
|
|
||||||
* If the shard count is not power of 2, the range of the last
|
|
||||||
* shard becomes larger than others. For that extra piece of range,
|
|
||||||
* we still need to use the last shard.
|
|
||||||
*/
|
|
||||||
if (shardIndex == shardCount)
|
|
||||||
{
|
|
||||||
shardIndex = shardCount - 1;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if (partitionMethod == DISTRIBUTE_BY_NONE)
|
else if (partitionMethod == DISTRIBUTE_BY_NONE)
|
||||||
|
@ -442,6 +433,48 @@ SearchCachedShardInterval(Datum partitionColumnValue, ShardInterval **shardInter
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
/*
|
||||||
|
* CalculateUniformHashRangeIndex returns the index of the hash range in
|
||||||
|
* which hashedValue falls, assuming shardCount uniform hash ranges.
|
||||||
|
*
|
||||||
|
* We use 64-bit integers to avoid overflow issues during arithmetic.
|
||||||
|
*
|
||||||
|
* NOTE: This function is ONLY for hash-distributed tables with uniform
|
||||||
|
* hash ranges.
|
||||||
|
*/
|
||||||
|
static int
|
||||||
|
CalculateUniformHashRangeIndex(int hashedValue, int shardCount)
|
||||||
|
{
|
||||||
|
int64 hashedValue64 = (int64) hashedValue;
|
||||||
|
|
||||||
|
/* normalize to the 0-UINT32_MAX range */
|
||||||
|
int64 normalizedHashValue = hashedValue64 - INT32_MIN;
|
||||||
|
|
||||||
|
/* size of each hash range */
|
||||||
|
int64 hashRangeSize = HASH_TOKEN_COUNT / shardCount;
|
||||||
|
|
||||||
|
/* index of hash range into which the hash value falls */
|
||||||
|
int shardIndex = (int) (normalizedHashValue / hashRangeSize);
|
||||||
|
|
||||||
|
if (shardIndex < 0 || shardIndex > shardCount)
|
||||||
|
{
|
||||||
|
ereport(ERROR, (errmsg("bug: shard index %d out of bounds", shardIndex)));
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
* If the shard count is not power of 2, the range of the last
|
||||||
|
* shard becomes larger than others. For that extra piece of range,
|
||||||
|
* we still need to use the last shard.
|
||||||
|
*/
|
||||||
|
if (shardIndex == shardCount)
|
||||||
|
{
|
||||||
|
shardIndex = shardCount - 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return shardIndex;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* SingleReplicatedTable checks whether all shards of a distributed table, do not have
|
* SingleReplicatedTable checks whether all shards of a distributed table, do not have
|
||||||
* more than one replica. If even one shard has more than one replica, this function
|
* more than one replica. If even one shard has more than one replica, this function
|
||||||
|
|
Loading…
Reference in New Issue