From b823f2127d0e996402a11e8010a2c12357d9b23f Mon Sep 17 00:00:00 2001 From: Marco Slot Date: Thu, 16 Jul 2020 12:56:06 +0200 Subject: [PATCH] Prevent integer overflow in FindShardIntervalIndex --- .../distributed/utils/shardinterval_utils.c | 59 +++++++++++++++---- 1 file changed, 46 insertions(+), 13 deletions(-) diff --git a/src/backend/distributed/utils/shardinterval_utils.c b/src/backend/distributed/utils/shardinterval_utils.c index 6339f882d..43b92f2f7 100644 --- a/src/backend/distributed/utils/shardinterval_utils.c +++ b/src/backend/distributed/utils/shardinterval_utils.c @@ -27,6 +27,9 @@ #include "utils/memutils.h" +static int CalculateUniformHashRangeIndex(int hashedValue, int shardCount); + + /* * LowestShardIntervalById returns the shard interval with the lowest shard * ID from a list of shard intervals. @@ -348,20 +351,8 @@ FindShardIntervalIndex(Datum searchedValue, CitusTableCacheEntry *cacheEntry) else { int hashedValue = DatumGetInt32(searchedValue); - uint64 hashTokenIncrement = HASH_TOKEN_COUNT / shardCount; - shardIndex = (uint32) (hashedValue - INT32_MIN) / hashTokenIncrement; - Assert(shardIndex <= shardCount); - - /* - * If the shard count is not power of 2, the range of the last - * shard becomes larger than others. For that extra piece of range, - * we still need to use the last shard. - */ - if (shardIndex == shardCount) - { - shardIndex = shardCount - 1; - } + shardIndex = CalculateUniformHashRangeIndex(hashedValue, shardCount); } } else if (partitionMethod == DISTRIBUTE_BY_NONE) @@ -442,6 +433,48 @@ SearchCachedShardInterval(Datum partitionColumnValue, ShardInterval **shardInter } +/* + * CalculateUniformHashRangeIndex returns the index of the hash range in + * which hashedValue falls, assuming shardCount uniform hash ranges. + * + * We use 64-bit integers to avoid overflow issues during arithmetic. + * + * NOTE: This function is ONLY for hash-distributed tables with uniform + * hash ranges. 
+ */
+static int
+CalculateUniformHashRangeIndex(int hashedValue, int shardCount)
+{
+	int64 hashedValue64 = (int64) hashedValue;
+
+	/* normalize to the 0-UINT32_MAX range; computed in 64 bits to avoid overflow */
+	int64 normalizedHashValue = hashedValue64 - INT32_MIN;
+
+	/* size of each hash range */
+	int64 hashRangeSize = HASH_TOKEN_COUNT / shardCount;
+
+	/* index of the hash range into which the hash value falls */
+	int shardIndex = (int) (normalizedHashValue / hashRangeSize);
+
+	if (shardIndex < 0 || shardIndex > shardCount)
+	{
+		ereport(ERROR, (errmsg("bug: shard index %d out of bounds", shardIndex)));
+	}
+
+	/*
+	 * If the shard count is not a power of 2, the range of the last
+	 * shard becomes larger than the others. For that extra piece of
+	 * range, we still need to use the last shard.
+	 */
+	if (shardIndex == shardCount)
+	{
+		shardIndex = shardCount - 1;
+	}
+
+	return shardIndex;
+}
+
+
 /*
  * SingleReplicatedTable checks whether all shards of a distributed table, do not have
  * more than one replica. If even one shard has more than one replica, this function