/*-------------------------------------------------------------------------
*
* shard_rebalancer.c
*
* Function definitions for the shard rebalancer tool.
*
* Copyright (c) Citus Data, Inc.
*
*-------------------------------------------------------------------------
*/
#include "postgres.h"
#include "libpq-fe.h"
#include <math.h>
#include "distributed/pg_version_constants.h"
#include "access/htup_details.h"
#include "access/genam.h"
#include "catalog/pg_type.h"
#include "catalog/pg_proc.h"
#include "commands/dbcommands.h"
#include "commands/sequence.h"
#include "distributed/argutils.h"
#include "distributed/citus_safe_lib.h"
#include "distributed/citus_ruleutils.h"
#include "distributed/colocation_utils.h"
#include "distributed/connection_management.h"
#include "distributed/enterprise.h"
#include "distributed/hash_helpers.h"
#include "distributed/listutils.h"
#include "distributed/coordinator_protocol.h"
#include "distributed/metadata_cache.h"
#include "distributed/multi_client_executor.h"
#include "distributed/multi_progress.h"
#include "distributed/multi_server_executor.h"
#include "distributed/pg_dist_rebalance_strategy.h"
#include "distributed/reference_table_utils.h"
#include "distributed/remote_commands.h"
#include "distributed/resource_lock.h"
#include "distributed/shard_rebalancer.h"
#include "distributed/tuplestore.h"
#include "distributed/worker_protocol.h"
#include "funcapi.h"
#include "miscadmin.h"
#include "postmaster/postmaster.h"
#include "storage/lmgr.h"
#include "utils/builtins.h"
#include "utils/fmgroids.h"
#include "utils/int8.h"
#include "utils/json.h"
#include "utils/lsyscache.h"
#include "utils/memutils.h"
#include "utils/syscache.h"
#if PG_VERSION_NUM >= PG_VERSION_13
#include "common/hashfn.h"
#endif
/* RebalanceOptions are the options used to control the rebalance algorithm */
typedef struct RebalanceOptions
{
List *relationIdList;
float4 threshold;
int32 maxShardMoves;
ArrayType *excludedShardArray;
bool drainOnly;
Form_pg_dist_rebalance_strategy rebalanceStrategy;
} RebalanceOptions;
/*
* RebalanceState is used to keep the internal state of the rebalance
* algorithm in one place.
*/
typedef struct RebalanceState
{
HTAB *placementsHash;
List *placementUpdateList;
RebalancePlanFunctions *functions;
List *fillStateListDesc;
List *fillStateListAsc;
List *disallowedPlacementList;
float4 totalCost;
float4 totalCapacity;
} RebalanceState;
/* RebalanceContext stores the context for the function callbacks */
typedef struct RebalanceContext
{
FmgrInfo shardCostUDF;
FmgrInfo nodeCapacityUDF;
FmgrInfo shardAllowedOnNodeUDF;
} RebalanceContext;
/* static declarations for main logic */
static int ShardActivePlacementCount(HTAB *activePlacementsHash, uint64 shardId,
List *activeWorkerNodeList);
static bool UpdateShardPlacement(PlacementUpdateEvent *placementUpdateEvent,
List *responsiveNodeList, Oid shardReplicationModeOid);
/* static declarations for main logic's utility functions */
static HTAB * ActivePlacementsHash(List *shardPlacementList);
static bool PlacementsHashFind(HTAB *placementsHash, uint64 shardId,
WorkerNode *workerNode);
static void PlacementsHashEnter(HTAB *placementsHash, uint64 shardId,
WorkerNode *workerNode);
static void PlacementsHashRemove(HTAB *placementsHash, uint64 shardId,
WorkerNode *workerNode);
static int PlacementsHashCompare(const void *lhsKey, const void *rhsKey, Size keySize);
static uint32 PlacementsHashHashCode(const void *key, Size keySize);
static bool WorkerNodeListContains(List *workerNodeList, const char *workerName,
uint32 workerPort);
static void UpdateColocatedShardPlacementProgress(uint64 shardId, char *sourceName,
int sourcePort, uint64 progress);
static bool IsPlacementOnWorkerNode(ShardPlacement *placement, WorkerNode *workerNode);
static NodeFillState * FindFillStateForPlacement(RebalanceState *state,
ShardPlacement *placement);
static RebalanceState * InitRebalanceState(List *workerNodeList, List *shardPlacementList,
RebalancePlanFunctions *functions);
static void MoveShardsAwayFromDisallowedNodes(RebalanceState *state);
static bool FindAndMoveShardCost(float4 utilizationLowerBound,
float4 utilizationUpperBound,
RebalanceState *state);
static NodeFillState * FindAllowedTargetFillState(RebalanceState *state, uint64 shardId);
static void MoveShardCost(NodeFillState *sourceFillState, NodeFillState *targetFillState,
ShardCost *shardCost, RebalanceState *state);
static int CompareNodeFillStateAsc(const void *void1, const void *void2);
static int CompareNodeFillStateDesc(const void *void1, const void *void2);
static int CompareShardCostAsc(const void *void1, const void *void2);
static int CompareShardCostDesc(const void *void1, const void *void2);
static int CompareDisallowedPlacementAsc(const void *void1, const void *void2);
static int CompareDisallowedPlacementDesc(const void *void1, const void *void2);
static bool ShardAllowedOnNode(uint64 shardId, WorkerNode *workerNode, void *context);
static float4 NodeCapacity(WorkerNode *workerNode, void *context);
static ShardCost GetShardCost(uint64 shardId, void *context);
static List * NonColocatedDistRelationIdList(void);
static void RebalanceTableShards(RebalanceOptions *options, Oid shardReplicationModeOid);
static void AcquireColocationLock(Oid relationId, const char *operationName);
static void ExecutePlacementUpdates(List *placementUpdateList, Oid
shardReplicationModeOid, char *noticeOperation);
static float4 CalculateUtilization(float4 totalCost, float4 capacity);
static Form_pg_dist_rebalance_strategy GetRebalanceStrategy(Name name);
static void EnsureShardCostUDF(Oid functionOid);
static void EnsureNodeCapacityUDF(Oid functionOid);
static void EnsureShardAllowedOnNodeUDF(Oid functionOid);
/* declarations for dynamic loading */
PG_FUNCTION_INFO_V1(rebalance_table_shards);
PG_FUNCTION_INFO_V1(replicate_table_shards);
PG_FUNCTION_INFO_V1(get_rebalance_table_shards_plan);
PG_FUNCTION_INFO_V1(get_rebalance_progress);
PG_FUNCTION_INFO_V1(citus_drain_node);
PG_FUNCTION_INFO_V1(master_drain_node);
PG_FUNCTION_INFO_V1(citus_shard_cost_by_disk_size);
PG_FUNCTION_INFO_V1(citus_validate_rebalance_strategy_functions);
PG_FUNCTION_INFO_V1(pg_dist_rebalance_strategy_enterprise_check);
#ifdef USE_ASSERT_CHECKING
/*
* Check that all the invariants of the state hold.
*/
static void
CheckRebalanceStateInvariants(const RebalanceState *state)
{
NodeFillState *fillState = NULL;
NodeFillState *prevFillState = NULL;
int fillStateIndex = 0;
int fillStateLength = list_length(state->fillStateListAsc);
Assert(state != NULL);
Assert(list_length(state->fillStateListAsc) == list_length(state->fillStateListDesc));
foreach_ptr(fillState, state->fillStateListAsc)
{
float4 totalCost = 0;
ShardCost *shardCost = NULL;
ShardCost *prevShardCost = NULL;
if (prevFillState != NULL)
{
/* Check that the previous fill state is more empty than this one */
bool higherUtilization = fillState->utilization > prevFillState->utilization;
bool sameUtilization = fillState->utilization == prevFillState->utilization;
bool lowerOrSameCapacity = fillState->capacity <= prevFillState->capacity;
Assert(higherUtilization || (sameUtilization && lowerOrSameCapacity));
}
/* Check that fillStateListDesc is the reversed version of fillStateListAsc */
Assert(list_nth(state->fillStateListDesc, fillStateLength - fillStateIndex - 1) ==
fillState);
foreach_ptr(shardCost, fillState->shardCostListDesc)
{
if (prevShardCost != NULL)
{
/* Check that shard costs are sorted in descending order */
Assert(shardCost->cost <= prevShardCost->cost);
}
totalCost += shardCost->cost;
}
/* Check that utilization field is up to date. */
Assert(fillState->utilization == CalculateUtilization(fillState->totalCost,
fillState->capacity));
/*
* Check that fillState->totalCost is within 0.1% difference of
* sum(fillState->shardCostListDesc->cost)
* We cannot compare exactly, because these numbers are floats and
* fillState->totalCost is modified by doing + and - on it. So instead
* we check that the numbers are roughly the same.
*/
float4 absoluteDifferenceBetweenTotalCosts =
fabsf(fillState->totalCost - totalCost);
float4 maximumAbsoluteValueOfTotalCosts =
fmaxf(fabsf(fillState->totalCost), fabsf(totalCost));
Assert(absoluteDifferenceBetweenTotalCosts <= maximumAbsoluteValueOfTotalCosts /
1000);
prevFillState = fillState;
fillStateIndex++;
}
}
#else
#define CheckRebalanceStateInvariants(l) ((void) 0)
#endif /* USE_ASSERT_CHECKING */
/*
* BigIntArrayDatumContains checks if the array contains the given number.
*/
static bool
BigIntArrayDatumContains(Datum *array, int arrayLength, uint64 toFind)
{
for (int i = 0; i < arrayLength; i++)
{
if (DatumGetInt64(array[i]) == toFind)
{
return true;
}
}
return false;
}
/*
* FullShardPlacementList returns a List containing all the shard placements of
* a specific table (excluding the excludedShardArray)
*/
static List *
FullShardPlacementList(Oid relationId, ArrayType *excludedShardArray)
{
List *shardPlacementList = NIL;
CitusTableCacheEntry *citusTableCacheEntry = GetCitusTableCacheEntry(relationId);
int shardIntervalArrayLength = citusTableCacheEntry->shardIntervalArrayLength;
int excludedShardIdCount = ArrayObjectCount(excludedShardArray);
Datum *excludedShardArrayDatum = DeconstructArrayObject(excludedShardArray);
for (int shardIndex = 0; shardIndex < shardIntervalArrayLength; shardIndex++)
{
ShardInterval *shardInterval =
citusTableCacheEntry->sortedShardIntervalArray[shardIndex];
GroupShardPlacement *placementArray =
citusTableCacheEntry->arrayOfPlacementArrays[shardIndex];
int numberOfPlacements =
citusTableCacheEntry->arrayOfPlacementArrayLengths[shardIndex];
if (BigIntArrayDatumContains(excludedShardArrayDatum, excludedShardIdCount,
shardInterval->shardId))
{
continue;
}
for (int placementIndex = 0; placementIndex < numberOfPlacements;
placementIndex++)
{
GroupShardPlacement *groupPlacement = &placementArray[placementIndex];
WorkerNode *worker = LookupNodeForGroup(groupPlacement->groupId);
ShardPlacement *placement = CitusMakeNode(ShardPlacement);
placement->shardId = groupPlacement->shardId;
placement->shardLength = groupPlacement->shardLength;
placement->shardState = groupPlacement->shardState;
placement->nodeName = pstrdup(worker->workerName);
placement->nodePort = worker->workerPort;
placement->placementId = groupPlacement->placementId;
shardPlacementList = lappend(shardPlacementList, placement);
}
}
return SortList(shardPlacementList, CompareShardPlacements);
}
/*
* SortedActiveWorkers returns all the active workers like
* ActiveReadableNodeList, but sorted.
*/
static List *
SortedActiveWorkers()
{
List *activeWorkerList = ActiveReadableNodeList();
return SortList(activeWorkerList, CompareWorkerNodes);
}
/*
* GetRebalanceSteps returns a List of PlacementUpdateEvents that are needed to
* rebalance a list of tables.
*/
static List *
GetRebalanceSteps(RebalanceOptions *options)
{
EnsureShardCostUDF(options->rebalanceStrategy->shardCostFunction);
EnsureNodeCapacityUDF(options->rebalanceStrategy->nodeCapacityFunction);
EnsureShardAllowedOnNodeUDF(options->rebalanceStrategy->shardAllowedOnNodeFunction);
RebalanceContext context;
memset(&context, 0, sizeof(RebalanceContext));
fmgr_info(options->rebalanceStrategy->shardCostFunction, &context.shardCostUDF);
fmgr_info(options->rebalanceStrategy->nodeCapacityFunction, &context.nodeCapacityUDF);
fmgr_info(options->rebalanceStrategy->shardAllowedOnNodeFunction,
&context.shardAllowedOnNodeUDF);
RebalancePlanFunctions rebalancePlanFunctions = {
.shardAllowedOnNode = ShardAllowedOnNode,
.nodeCapacity = NodeCapacity,
.shardCost = GetShardCost,
.context = &context,
};
/* sort the lists to make the function more deterministic */
List *activeWorkerList = SortedActiveWorkers();
List *shardPlacementListList = NIL;
Oid relationId = InvalidOid;
foreach_oid(relationId, options->relationIdList)
{
List *shardPlacementList = FullShardPlacementList(relationId,
options->excludedShardArray);
shardPlacementListList = lappend(shardPlacementListList, shardPlacementList);
}
if (options->threshold < options->rebalanceStrategy->minimumThreshold)
{
ereport(WARNING, (errmsg(
"the given threshold is lower than the minimum "
"threshold allowed by the rebalance strategy, "
"using the minimum allowed threshold instead"
),
errdetail("Using threshold of %.2f",
options->rebalanceStrategy->minimumThreshold
)
));
options->threshold = options->rebalanceStrategy->minimumThreshold;
}
return RebalancePlacementUpdates(activeWorkerList,
shardPlacementListList,
options->threshold,
options->maxShardMoves,
options->drainOnly,
&rebalancePlanFunctions);
}
/*
* ShardAllowedOnNode determines if shard is allowed on a specific worker node.
*/
static bool
ShardAllowedOnNode(uint64 shardId, WorkerNode *workerNode, void *voidContext)
{
if (!workerNode->shouldHaveShards)
{
return false;
}
RebalanceContext *context = voidContext;
Datum allowed = FunctionCall2(&context->shardAllowedOnNodeUDF, shardId,
workerNode->nodeId);
return DatumGetBool(allowed);
}
/*
* NodeCapacity returns the relative capacity of a node. A node with capacity 2
* can contain twice as many shards as a node with capacity 1. The actual
* capacity can be a number grounded in reality, like the disk size or the
* number of cores, but it doesn't have to be.
*/
static float4
NodeCapacity(WorkerNode *workerNode, void *voidContext)
{
if (!workerNode->shouldHaveShards)
{
return 0;
}
RebalanceContext *context = voidContext;
Datum capacity = FunctionCall1(&context->nodeCapacityUDF, workerNode->nodeId);
return DatumGetFloat4(capacity);
}
/*
* GetShardCost returns the cost of the given shard. A shard with cost 2 will
* be weighted as heavily as two shards with cost 1. This cost number can be a
* number grounded in reality, like the shard size on disk, but it doesn't have
* to be.
*/
static ShardCost
GetShardCost(uint64 shardId, void *voidContext)
{
ShardCost shardCost;
memset_struct_0(shardCost);
shardCost.shardId = shardId;
RebalanceContext *context = voidContext;
Datum shardCostDatum = FunctionCall1(&context->shardCostUDF, UInt64GetDatum(shardId));
shardCost.cost = DatumGetFloat4(shardCostDatum);
return shardCost;
}
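/*
 * Illustrative sketch, not used by this file: the three callbacks above
 * correspond to the SQL-level functions of a rebalance strategy. Assuming
 * user-defined SQL functions my_cost(bigint) RETURNS float4,
 * my_capacity(int) RETURNS float4 and my_allowed(bigint, int) RETURNS
 * boolean, a custom strategy could be registered roughly like this:
 *
 * SELECT citus_add_rebalance_strategy(
 *     'my_strategy',
 *     'my_cost'::regproc,
 *     'my_capacity'::regproc,
 *     'my_allowed'::regproc,
 *     0.1,   -- default_threshold
 *     0.05   -- minimum_threshold
 * );
 */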
/*
* citus_shard_cost_by_disk_size gets the cost for a shard based on the disk
* size of the shard on a worker. The worker to check the disk size is
* determined by choosing the first active placement for the shard. The disk
* size is calculated using pg_total_relation_size, so it includes indexes.
*
* SQL signature:
* citus_shard_cost_by_disk_size(shardid bigint) returns float4
*/
Datum
citus_shard_cost_by_disk_size(PG_FUNCTION_ARGS)
{
uint64 shardId = PG_GETARG_INT64(0);
bool missingOk = false;
ShardPlacement *shardPlacement = ActiveShardPlacement(shardId, missingOk);
char *workerNodeName = shardPlacement->nodeName;
uint32 workerNodePort = shardPlacement->nodePort;
uint32 connectionFlag = 0;
PGresult *result = NULL;
bool raiseErrors = true;
char *sizeQuery = PG_TOTAL_RELATION_SIZE_FUNCTION;
ShardInterval *shardInterval = LoadShardInterval(shardId);
List *colocatedShardList = ColocatedShardIntervalList(shardInterval);
StringInfo tableSizeQuery = GenerateSizeQueryOnMultiplePlacements(colocatedShardList,
sizeQuery);
MultiConnection *connection = GetNodeConnection(connectionFlag, workerNodeName,
workerNodePort);
int queryResult = ExecuteOptionalRemoteCommand(connection, tableSizeQuery->data,
&result);
if (queryResult != RESPONSE_OKAY)
{
ereport(ERROR, (errcode(ERRCODE_CONNECTION_FAILURE),
errmsg("cannot get the size because of a connection error")));
}
List *sizeList = ReadFirstColumnAsText(result);
if (list_length(sizeList) != 1)
{
ereport(ERROR, (errmsg(
"received wrong number of rows from worker, expected 1 received %d",
list_length(sizeList))));
}
StringInfo tableSizeStringInfo = (StringInfo) linitial(sizeList);
char *tableSizeString = tableSizeStringInfo->data;
uint64 tableSize = SafeStringToUint64(tableSizeString);
PQclear(result);
ClearResults(connection, raiseErrors);
if (tableSize <= 0)
{
PG_RETURN_FLOAT4(1);
}
PG_RETURN_FLOAT4(tableSize);
}
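/*
 * Example usage (the shard id is illustrative):
 *
 * SELECT citus_shard_cost_by_disk_size(102008);
 */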
/*
* GetColocatedRebalanceSteps takes a List of PlacementUpdateEvents and creates
* a new List containing those and all the updates for colocated shards.
*/
static List *
GetColocatedRebalanceSteps(List *placementUpdateList)
{
ListCell *placementUpdateCell = NULL;
List *colocatedUpdateList = NIL;
foreach(placementUpdateCell, placementUpdateList)
{
PlacementUpdateEvent *placementUpdate = lfirst(placementUpdateCell);
ShardInterval *shardInterval = LoadShardInterval(placementUpdate->shardId);
List *colocatedShardList = ColocatedShardIntervalList(shardInterval);
ListCell *colocatedShardCell = NULL;
foreach(colocatedShardCell, colocatedShardList)
{
ShardInterval *colocatedShard = lfirst(colocatedShardCell);
PlacementUpdateEvent *colocatedUpdate = palloc0(sizeof(PlacementUpdateEvent));
colocatedUpdate->shardId = colocatedShard->shardId;
colocatedUpdate->sourceNode = placementUpdate->sourceNode;
colocatedUpdate->targetNode = placementUpdate->targetNode;
colocatedUpdate->updateType = placementUpdate->updateType;
colocatedUpdateList = lappend(colocatedUpdateList, colocatedUpdate);
}
}
return colocatedUpdateList;
}
/*
* AcquireColocationLock tries to acquire a lock for rebalance/replication. If
* this is not possible it fails instantly, because this means another
* rebalance/replication is currently happening. This would really mess up
* planning.
*/
static void
AcquireColocationLock(Oid relationId, const char *operationName)
{
uint32 lockId = relationId;
LOCKTAG tag;
CitusTableCacheEntry *citusTableCacheEntry = GetCitusTableCacheEntry(relationId);
if (citusTableCacheEntry->colocationId != INVALID_COLOCATION_ID)
{
lockId = citusTableCacheEntry->colocationId;
}
SET_LOCKTAG_REBALANCE_COLOCATION(tag, (int64) lockId);
LockAcquireResult lockAcquired = LockAcquire(&tag, ExclusiveLock, false, true);
if (!lockAcquired)
{
ereport(ERROR, (errmsg("could not acquire the lock required to %s %s",
operationName, generate_qualified_relation_name(
relationId))));
}
}
/*
* GetResponsiveWorkerList returns a List of workers that respond to new
* connection requests.
*/
static List *
GetResponsiveWorkerList()
{
List *activeWorkerList = ActiveReadableNodeList();
ListCell *activeWorkerCell = NULL;
List *responsiveWorkerList = NIL;
foreach(activeWorkerCell, activeWorkerList)
{
WorkerNode *worker = lfirst(activeWorkerCell);
int connectionFlag = FORCE_NEW_CONNECTION;
MultiConnection *connection = GetNodeConnection(connectionFlag,
worker->workerName,
worker->workerPort);
if (connection != NULL && connection->pgConn != NULL)
{
if (PQstatus(connection->pgConn) == CONNECTION_OK)
{
responsiveWorkerList = lappend(responsiveWorkerList, worker);
}
CloseConnection(connection);
}
}
return responsiveWorkerList;
}
/*
* ExecutePlacementUpdates copies or moves a shard placement by calling the
* corresponding functions in Citus in a separate subtransaction for each
* update.
*/
static void
ExecutePlacementUpdates(List *placementUpdateList, Oid shardReplicationModeOid,
char *noticeOperation)
{
List *responsiveWorkerList = GetResponsiveWorkerList();
ListCell *placementUpdateCell = NULL;
char shardReplicationMode = LookupShardTransferMode(shardReplicationModeOid);
if (shardReplicationMode == TRANSFER_MODE_FORCE_LOGICAL)
{
ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("the force_logical transfer mode is currently "
"unsupported")));
}
foreach(placementUpdateCell, placementUpdateList)
{
PlacementUpdateEvent *placementUpdate = lfirst(placementUpdateCell);
ereport(NOTICE, (errmsg(
"%s shard %lu from %s:%u to %s:%u ...",
noticeOperation,
placementUpdate->shardId,
placementUpdate->sourceNode->workerName,
placementUpdate->sourceNode->workerPort,
placementUpdate->targetNode->workerName,
placementUpdate->targetNode->workerPort
)));
UpdateShardPlacement(placementUpdate, responsiveWorkerList,
shardReplicationModeOid);
}
}
/*
* SetupRebalanceMonitor initializes the dynamic shared memory required for storing the
* progress information of a rebalance process. The function takes a List of
* PlacementUpdateEvents for all shards that will be moved (including colocated
* ones) and the relation id of the target table. The dynamic shared memory
* portion consists of a RebalanceMonitorHeader and multiple
* PlacementUpdateEventProgress, one for each planned shard placement move. The
* dsm_handle of the created segment is saved in the progress info of the
* current backend so that it can be read by external agents, such as the
* get_rebalance_progress function, through the pg_stat_get_progress_info UDF.
* Since currently only VACUUM commands are officially allowed as the command
* type, we describe ourselves as a VACUUM command and, to distinguish
* rebalancer progress from regular VACUUM progress, we put a magic number in
* the first progress field as an indicator. The dsm handle is kept around so
* that it can be used for updating the progress and cleaning things up.
*/
static void
SetupRebalanceMonitor(List *placementUpdateList, Oid relationId)
{
List *colocatedUpdateList = GetColocatedRebalanceSteps(placementUpdateList);
ListCell *colocatedUpdateCell = NULL;
ProgressMonitorData *monitor = CreateProgressMonitor(REBALANCE_ACTIVITY_MAGIC_NUMBER,
list_length(colocatedUpdateList),
sizeof(
PlacementUpdateEventProgress),
relationId);
PlacementUpdateEventProgress *rebalanceSteps = monitor->steps;
int32 eventIndex = 0;
foreach(colocatedUpdateCell, colocatedUpdateList)
{
PlacementUpdateEvent *colocatedUpdate = lfirst(colocatedUpdateCell);
PlacementUpdateEventProgress *event = rebalanceSteps + eventIndex;
strlcpy(event->sourceName, colocatedUpdate->sourceNode->workerName, 255);
strlcpy(event->targetName, colocatedUpdate->targetNode->workerName, 255);
event->shardId = colocatedUpdate->shardId;
event->sourcePort = colocatedUpdate->sourceNode->workerPort;
event->targetPort = colocatedUpdate->targetNode->workerPort;
event->shardSize = ShardLength(colocatedUpdate->shardId);
eventIndex++;
}
}
/*
* rebalance_table_shards rebalances the shards across the workers.
*
* SQL signature:
*
* rebalance_table_shards(
* relation regclass,
* threshold float4,
* max_shard_moves int,
* excluded_shard_list bigint[],
* shard_transfer_mode citus.shard_transfer_mode,
* drain_only boolean,
* rebalance_strategy name
* ) RETURNS VOID
*/
Datum
rebalance_table_shards(PG_FUNCTION_ARGS)
{
List *relationIdList = NIL;
if (!PG_ARGISNULL(0))
{
Oid relationId = PG_GETARG_OID(0);
ErrorIfMoveCitusLocalTable(relationId);
relationIdList = list_make1_oid(relationId);
}
else
{
/*
* Note that we don't need to do any checks to error out for
* citus local tables here as NonColocatedDistRelationIdList
* already doesn't return non-distributed tables.
*/
relationIdList = NonColocatedDistRelationIdList();
}
PG_ENSURE_ARGNOTNULL(2, "max_shard_moves");
PG_ENSURE_ARGNOTNULL(3, "excluded_shard_list");
PG_ENSURE_ARGNOTNULL(4, "shard_transfer_mode");
PG_ENSURE_ARGNOTNULL(5, "drain_only");
Form_pg_dist_rebalance_strategy strategy = GetRebalanceStrategy(
PG_GETARG_NAME_OR_NULL(6));
RebalanceOptions options = {
.relationIdList = relationIdList,
.threshold = PG_GETARG_FLOAT4_OR_DEFAULT(1, strategy->defaultThreshold),
.maxShardMoves = PG_GETARG_INT32(2),
.excludedShardArray = PG_GETARG_ARRAYTYPE_P(3),
.drainOnly = PG_GETARG_BOOL(5),
.rebalanceStrategy = strategy,
};
Oid shardTransferModeOid = PG_GETARG_OID(4);
RebalanceTableShards(&options, shardTransferModeOid);
PG_RETURN_VOID();
}
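/*
 * Example usage (the table name is illustrative):
 *
 * SELECT rebalance_table_shards('dist_table', threshold := 0.1);
 */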
/*
* GetRebalanceStrategy returns the rebalance strategy from
* pg_dist_rebalance_strategy matching the given name. If name is NULL it
* returns the default rebalance strategy from pg_dist_rebalance_strategy.
*/
static Form_pg_dist_rebalance_strategy
GetRebalanceStrategy(Name name)
{
Relation pgDistRebalanceStrategy = table_open(DistRebalanceStrategyRelationId(),
AccessShareLock);
const int scanKeyCount = 1;
ScanKeyData scanKey[1];
if (name == NULL)
{
/* WHERE default_strategy=true */
ScanKeyInit(&scanKey[0], Anum_pg_dist_rebalance_strategy_default_strategy,
BTEqualStrategyNumber, F_BOOLEQ, BoolGetDatum(true));
}
else
{
/* WHERE name=$name */
ScanKeyInit(&scanKey[0], Anum_pg_dist_rebalance_strategy_name,
BTEqualStrategyNumber, F_NAMEEQ, NameGetDatum(name));
}
SysScanDesc scanDescriptor = systable_beginscan(pgDistRebalanceStrategy,
InvalidOid, false,
NULL, scanKeyCount, scanKey);
HeapTuple heapTuple = systable_getnext(scanDescriptor);
if (!HeapTupleIsValid(heapTuple))
{
if (name == NULL)
{
ereport(ERROR, (errmsg(
"no rebalance_strategy was provided, but there is also no default strategy set")));
}
ereport(ERROR, (errmsg("could not find rebalance strategy with name %s",
(char *) name)));
}
Form_pg_dist_rebalance_strategy strategy =
(Form_pg_dist_rebalance_strategy) GETSTRUCT(heapTuple);
Form_pg_dist_rebalance_strategy strategy_copy =
palloc0(sizeof(FormData_pg_dist_rebalance_strategy));
/* Copy data over by dereferencing */
*strategy_copy = *strategy;
systable_endscan(scanDescriptor);
table_close(pgDistRebalanceStrategy, NoLock);
return strategy_copy;
}
/*
* citus_drain_node drains a node by setting shouldhaveshards to false and
* running the rebalancer after in drain_only mode.
*/
Datum
citus_drain_node(PG_FUNCTION_ARGS)
{
PG_ENSURE_ARGNOTNULL(0, "nodename");
PG_ENSURE_ARGNOTNULL(1, "nodeport");
PG_ENSURE_ARGNOTNULL(2, "shard_transfer_mode");
text *nodeNameText = PG_GETARG_TEXT_P(0);
int32 nodePort = PG_GETARG_INT32(1);
Oid shardTransferModeOid = PG_GETARG_OID(2);
Form_pg_dist_rebalance_strategy strategy = GetRebalanceStrategy(
PG_GETARG_NAME_OR_NULL(3));
RebalanceOptions options = {
.relationIdList = NonColocatedDistRelationIdList(),
.threshold = strategy->defaultThreshold,
.maxShardMoves = 0,
.excludedShardArray = construct_empty_array(INT4OID),
.drainOnly = true,
.rebalanceStrategy = strategy,
};
char *nodeName = text_to_cstring(nodeNameText);
int connectionFlag = FORCE_NEW_CONNECTION;
MultiConnection *connection = GetNodeConnection(connectionFlag, LocalHostName,
PostPortNumber);
/*
* This is done in a separate session. This way it's not undone if the
* draining fails midway through.
*/
ExecuteCriticalRemoteCommand(connection, psprintf(
"SELECT master_set_node_property(%s, %i, 'shouldhaveshards', false)",
quote_literal_cstr(nodeName), nodePort));
RebalanceTableShards(&options, shardTransferModeOid);
PG_RETURN_VOID();
}
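/*
 * Example usage (host and port are illustrative):
 *
 * SELECT citus_drain_node('worker-1', 5432);
 */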
/*
* replicate_table_shards replicates under-replicated shards of the specified
* table.
*/
Datum
replicate_table_shards(PG_FUNCTION_ARGS)
{
Oid relationId = PG_GETARG_OID(0);
uint32 shardReplicationFactor = PG_GETARG_INT32(1);
int32 maxShardCopies = PG_GETARG_INT32(2);
ArrayType *excludedShardArray = PG_GETARG_ARRAYTYPE_P(3);
Oid shardReplicationModeOid = PG_GETARG_OID(4);
char transferMode = LookupShardTransferMode(shardReplicationModeOid);
EnsureReferenceTablesExistOnAllNodesExtended(transferMode);
AcquireColocationLock(relationId, "replicate");
List *activeWorkerList = SortedActiveWorkers();
List *shardPlacementList = FullShardPlacementList(relationId, excludedShardArray);
List *placementUpdateList = ReplicationPlacementUpdates(activeWorkerList,
shardPlacementList,
shardReplicationFactor);
placementUpdateList = list_truncate(placementUpdateList, maxShardCopies);
ExecutePlacementUpdates(placementUpdateList, shardReplicationModeOid, "Copying");
PG_RETURN_VOID();
}
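/*
 * Example usage (the table name is illustrative):
 *
 * SELECT replicate_table_shards('dist_table');
 */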
/*
* master_drain_node is a wrapper function for old UDF name.
*/
Datum
master_drain_node(PG_FUNCTION_ARGS)
{
return citus_drain_node(fcinfo);
}
/*
* get_rebalance_table_shards_plan function calculates the shard move steps
* required for the rebalance operations including the ones for colocated
* tables.
*
* SQL signature:
*
* get_rebalance_table_shards_plan(
* relation regclass,
* threshold float4,
* max_shard_moves int,
* excluded_shard_list bigint[],
* drain_only boolean,
* rebalance_strategy name
* )
*/
Datum
get_rebalance_table_shards_plan(PG_FUNCTION_ARGS)
{
List *relationIdList = NIL;
if (!PG_ARGISNULL(0))
{
Oid relationId = PG_GETARG_OID(0);
ErrorIfMoveCitusLocalTable(relationId);
relationIdList = list_make1_oid(relationId);
}
else
{
/*
* Note that we don't need to do any checks to error out for
* citus local tables here as NonColocatedDistRelationIdList
* already doesn't return non-distributed tables.
*/
relationIdList = NonColocatedDistRelationIdList();
}
PG_ENSURE_ARGNOTNULL(2, "max_shard_moves");
PG_ENSURE_ARGNOTNULL(3, "excluded_shard_list");
PG_ENSURE_ARGNOTNULL(4, "drain_only");
Form_pg_dist_rebalance_strategy strategy = GetRebalanceStrategy(
PG_GETARG_NAME_OR_NULL(5));
RebalanceOptions options = {
.relationIdList = relationIdList,
.threshold = PG_GETARG_FLOAT4_OR_DEFAULT(1, strategy->defaultThreshold),
.maxShardMoves = PG_GETARG_INT32(2),
.excludedShardArray = PG_GETARG_ARRAYTYPE_P(3),
.drainOnly = PG_GETARG_BOOL(4),
.rebalanceStrategy = strategy,
};
List *placementUpdateList = GetRebalanceSteps(&options);
List *colocatedUpdateList = GetColocatedRebalanceSteps(placementUpdateList);
ListCell *colocatedUpdateCell = NULL;
TupleDesc tupdesc;
Tuplestorestate *tupstore = SetupTuplestore(fcinfo, &tupdesc);
foreach(colocatedUpdateCell, colocatedUpdateList)
{
PlacementUpdateEvent *colocatedUpdate = lfirst(colocatedUpdateCell);
Datum values[7];
bool nulls[7];
memset(values, 0, sizeof(values));
memset(nulls, 0, sizeof(nulls));
values[0] = ObjectIdGetDatum(RelationIdForShard(colocatedUpdate->shardId));
values[1] = UInt64GetDatum(colocatedUpdate->shardId);
values[2] = UInt64GetDatum(ShardLength(colocatedUpdate->shardId));
values[3] = PointerGetDatum(cstring_to_text(
colocatedUpdate->sourceNode->workerName));
values[4] = UInt32GetDatum(colocatedUpdate->sourceNode->workerPort);
values[5] = PointerGetDatum(cstring_to_text(
colocatedUpdate->targetNode->workerName));
values[6] = UInt32GetDatum(colocatedUpdate->targetNode->workerPort);
tuplestore_putvalues(tupstore, tupdesc, values, nulls);
}
tuplestore_donestoring(tupstore);
return (Datum) 0;
}
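/*
 * Example usage, to preview the planned moves without executing them (the
 * table name is illustrative):
 *
 * SELECT * FROM get_rebalance_table_shards_plan('dist_table');
 */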
/*
* get_rebalance_progress collects information about the ongoing rebalance operations and
* returns the concatenated list of steps involved in the operations, along with their
* progress information. Currently the progress field can take 4 integer values
* (-1: error, 0: waiting, 1: moving, 2: moved). The progress field is of type bigint
* because we may implement a more granular, byte-level progress as a future improvement.
*/
Datum
get_rebalance_progress(PG_FUNCTION_ARGS)
{
List *segmentList = NIL;
ListCell *rebalanceMonitorCell = NULL;
TupleDesc tupdesc;
Tuplestorestate *tupstore = SetupTuplestore(fcinfo, &tupdesc);
/* get the addresses of all current rebalance monitors */
List *rebalanceMonitorList = ProgressMonitorList(REBALANCE_ACTIVITY_MAGIC_NUMBER,
&segmentList);
foreach(rebalanceMonitorCell, rebalanceMonitorList)
{
ProgressMonitorData *monitor = lfirst(rebalanceMonitorCell);
PlacementUpdateEventProgress *placementUpdateEvents = monitor->steps;
for (int eventIndex = 0; eventIndex < monitor->stepCount; eventIndex++)
{
PlacementUpdateEventProgress *step = placementUpdateEvents + eventIndex;
uint64 shardId = step->shardId;
ShardInterval *shardInterval = LoadShardInterval(shardId);
Datum values[9];
bool nulls[9];
memset(values, 0, sizeof(values));
memset(nulls, 0, sizeof(nulls));
values[0] = monitor->processId;
values[1] = ObjectIdGetDatum(shardInterval->relationId);
values[2] = UInt64GetDatum(shardId);
values[3] = UInt64GetDatum(step->shardSize);
values[4] = PointerGetDatum(cstring_to_text(step->sourceName));
values[5] = UInt32GetDatum(step->sourcePort);
values[6] = PointerGetDatum(cstring_to_text(step->targetName));
values[7] = UInt32GetDatum(step->targetPort);
values[8] = UInt64GetDatum(step->progress);
tuplestore_putvalues(tupstore, tupdesc, values, nulls);
}
}
tuplestore_donestoring(tupstore);
DetachFromDSMSegments(segmentList);
return (Datum) 0;
}
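/*
 * Example usage, typically run from a separate session while a rebalance is
 * in progress:
 *
 * SELECT * FROM get_rebalance_progress();
 */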
/*
* NonColocatedDistRelationIdList returns a list of distributed table oids, one
* for each existing colocation group.
*/
static List *
NonColocatedDistRelationIdList(void)
{
List *relationIdList = NIL;
List *allCitusTablesList = CitusTableTypeIdList(ANY_CITUS_TABLE_TYPE);
Oid tableId = InvalidOid;
/* allocate sufficient capacity for O(1) expected look-up time */
int capacity = (int) (list_length(allCitusTablesList) / 0.75) + 1;
int flags = HASH_ELEM | HASH_CONTEXT | HASH_BLOBS;
HASHCTL info = {
.keysize = sizeof(Oid),
.entrysize = sizeof(Oid),
.hcxt = CurrentMemoryContext
};
HTAB *alreadySelectedColocationIds = hash_create("RebalanceColocationIdSet",
capacity, &info, flags);
foreach_oid(tableId, allCitusTablesList)
{
bool foundInSet = false;
CitusTableCacheEntry *citusTableCacheEntry = GetCitusTableCacheEntry(
tableId);
if (!IsCitusTableTypeCacheEntry(citusTableCacheEntry, DISTRIBUTED_TABLE))
{
/*
* We're only interested in distributed tables, should ignore
* reference tables and citus local tables.
*/
continue;
}
if (citusTableCacheEntry->colocationId != INVALID_COLOCATION_ID)
{
hash_search(alreadySelectedColocationIds,
&citusTableCacheEntry->colocationId, HASH_ENTER,
&foundInSet);
if (foundInSet)
{
continue;
}
}
relationIdList = lappend_oid(relationIdList, tableId);
}
return relationIdList;
}
/*
* RebalanceTableShards rebalances the shards for the relations inside the
* relationIdList across the different workers.
*/
static void
RebalanceTableShards(RebalanceOptions *options, Oid shardReplicationModeOid)
{
char transferMode = LookupShardTransferMode(shardReplicationModeOid);
EnsureReferenceTablesExistOnAllNodesExtended(transferMode);
if (list_length(options->relationIdList) == 0)
{
return;
}
Oid relationId = InvalidOid;
char *operationName = "rebalance";
if (options->drainOnly)
{
operationName = "move";
}
foreach_oid(relationId, options->relationIdList)
{
AcquireColocationLock(relationId, operationName);
}
List *placementUpdateList = GetRebalanceSteps(options);
if (list_length(placementUpdateList) == 0)
{
return;
}
/*
* This uses the first relationId from the list, it's only used for display
* purposes so it does not really matter which to show
*/
SetupRebalanceMonitor(placementUpdateList, linitial_oid(options->relationIdList));
ExecutePlacementUpdates(placementUpdateList, shardReplicationModeOid, "Moving");
FinalizeCurrentProgressMonitor();
}
/*
* UpdateShardPlacement copies or moves a shard placement by calling
* the corresponding functions in Citus in a subtransaction.
*/
static bool
UpdateShardPlacement(PlacementUpdateEvent *placementUpdateEvent,
List *responsiveNodeList, Oid shardReplicationModeOid)
{
PlacementUpdateType updateType = placementUpdateEvent->updateType;
uint64 shardId = placementUpdateEvent->shardId;
WorkerNode *sourceNode = placementUpdateEvent->sourceNode;
WorkerNode *targetNode = placementUpdateEvent->targetNode;
const char *doRepair = "false";
int connectionFlag = FORCE_NEW_CONNECTION;
Datum shardTransferModeLabelDatum =
DirectFunctionCall1(enum_out, shardReplicationModeOid);
char *shardTransferModeLabel = DatumGetCString(shardTransferModeLabelDatum);
StringInfo placementUpdateCommand = makeStringInfo();
/* if target node is not responsive, don't continue */
bool targetResponsive = WorkerNodeListContains(responsiveNodeList,
targetNode->workerName,
targetNode->workerPort);
if (!targetResponsive)
{
ereport(WARNING, (errmsg("%s:%d is not responsive", targetNode->workerName,
targetNode->workerPort)));
UpdateColocatedShardPlacementProgress(shardId,
sourceNode->workerName,
sourceNode->workerPort,
REBALANCE_PROGRESS_ERROR);
return false;
}
/* if source node is not responsive, don't continue */
bool sourceResponsive = WorkerNodeListContains(responsiveNodeList,
sourceNode->workerName,
sourceNode->workerPort);
if (!sourceResponsive)
{
ereport(WARNING, (errmsg("%s:%d is not responsive", sourceNode->workerName,
sourceNode->workerPort)));
UpdateColocatedShardPlacementProgress(shardId,
sourceNode->workerName,
sourceNode->workerPort,
REBALANCE_PROGRESS_ERROR);
return false;
}
if (updateType == PLACEMENT_UPDATE_MOVE)
{
appendStringInfo(placementUpdateCommand,
"SELECT citus_move_shard_placement(%ld,%s,%u,%s,%u,%s)",
shardId,
quote_literal_cstr(sourceNode->workerName),
sourceNode->workerPort,
quote_literal_cstr(targetNode->workerName),
targetNode->workerPort,
quote_literal_cstr(shardTransferModeLabel));
}
else if (updateType == PLACEMENT_UPDATE_COPY)
{
appendStringInfo(placementUpdateCommand,
"SELECT citus_copy_shard_placement(%ld,%s,%u,%s,%u,%s,%s)",
shardId,
quote_literal_cstr(sourceNode->workerName),
sourceNode->workerPort,
quote_literal_cstr(targetNode->workerName),
targetNode->workerPort,
doRepair,
quote_literal_cstr(shardTransferModeLabel));
}
else
{
ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("only moving or copying shards is supported")));
}
UpdateColocatedShardPlacementProgress(shardId,
sourceNode->workerName,
sourceNode->workerPort,
REBALANCE_PROGRESS_MOVING);
MultiConnection *connection = GetNodeConnection(connectionFlag, LocalHostName,
PostPortNumber);
/*
* In case of failure, we throw an error such that rebalance_table_shards
* fails early.
*/
ExecuteCriticalRemoteCommand(connection, placementUpdateCommand->data);
UpdateColocatedShardPlacementProgress(shardId,
sourceNode->workerName,
sourceNode->workerPort,
REBALANCE_PROGRESS_MOVED);
return true;
}
/*
* RebalancePlacementUpdates returns a list of placement updates which makes the
* cluster balanced. We move shards between the nodes until all of them are
* evenly utilized. We consider a node under-utilized if it has less than
* floor((1.0 - threshold) * placementCountAverage) shard placements. In each
* iteration we choose the node with the maximum number of shard placements as
* the source, and we choose the node with the minimum number of shard
* placements as the target. Then we choose a shard
* which is placed in the source node but not in the target node as the shard to
* move.
*
* The shardPlacementListList argument contains a list of lists of shard
* placements. Each of these lists is balanced independently. This is used to
* make sure different colocation groups are balanced separately, so each list
* contains the placements of a colocation group.
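*
* A worked example with illustrative numbers: if the shard costs of a
* colocation group sum to 100 and the node capacities sum to 10, the average
* utilization is 10. With a threshold of 0.1 the algorithm then moves shard
* costs until every node's utilization lies within the [9, 11] band.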
*/
List *
RebalancePlacementUpdates(List *workerNodeList, List *shardPlacementListList,
double threshold,
int32 maxShardMoves,
bool drainOnly,
RebalancePlanFunctions *functions)
{
List *rebalanceStates = NIL;
RebalanceState *state = NULL;
List *shardPlacementList = NIL;
List *placementUpdateList = NIL;
foreach_ptr(shardPlacementList, shardPlacementListList)
{
state = InitRebalanceState(workerNodeList, shardPlacementList,
functions);
rebalanceStates = lappend(rebalanceStates, state);
}
foreach_ptr(state, rebalanceStates)
{
state->placementUpdateList = placementUpdateList;
MoveShardsAwayFromDisallowedNodes(state);
placementUpdateList = state->placementUpdateList;
}
if (!drainOnly)
{
foreach_ptr(state, rebalanceStates)
{
state->placementUpdateList = placementUpdateList;
/* calculate the utilization bounds based on the threshold */
float4 averageUtilization = (state->totalCost / state->totalCapacity);
float4 utilizationLowerBound = ((1.0 - threshold) * averageUtilization);
float4 utilizationUpperBound = ((1.0 + threshold) * averageUtilization);
bool moreMovesAvailable = true;
while (list_length(state->placementUpdateList) < maxShardMoves &&
moreMovesAvailable)
{
moreMovesAvailable = FindAndMoveShardCost(utilizationLowerBound,
utilizationUpperBound,
state);
}
placementUpdateList = state->placementUpdateList;
if (moreMovesAvailable)
{
ereport(NOTICE, (errmsg(
"Stopped searching before we were out of moves. "
"Please rerun the rebalancer after it's finished "
"for a more optimal placement.")));
break;
}
}
}
foreach_ptr(state, rebalanceStates)
{
hash_destroy(state->placementsHash);
}
return placementUpdateList;
}
/*
* InitRebalanceState sets up a RebalanceState for its arguments. The
* RebalanceState contains the information needed to calculate shard moves.
*/
static RebalanceState *
InitRebalanceState(List *workerNodeList, List *shardPlacementList,
RebalancePlanFunctions *functions)
{
ShardPlacement *placement = NULL;
HASH_SEQ_STATUS status;
WorkerNode *workerNode = NULL;
RebalanceState *state = palloc0(sizeof(RebalanceState));
state->functions = functions;
state->placementsHash = ActivePlacementsHash(shardPlacementList);
/* create empty fill state for all of the worker nodes */
foreach_ptr(workerNode, workerNodeList)
{
NodeFillState *fillState = palloc0(sizeof(NodeFillState));
fillState->node = workerNode;
fillState->capacity = functions->nodeCapacity(workerNode, functions->context);
/*
* Set the utilization here although the totalCost is not set yet. This is
* important to set the utilization to INFINITY when the capacity is 0.
*/
fillState->utilization = CalculateUtilization(fillState->totalCost,
fillState->capacity);
state->fillStateListAsc = lappend(state->fillStateListAsc, fillState);
state->fillStateListDesc = lappend(state->fillStateListDesc, fillState);
state->totalCapacity += fillState->capacity;
}
/* Fill the fill states for all of the worker nodes based on the placements */
foreach_htab(placement, &status, state->placementsHash)
{
ShardCost *shardCost = palloc0(sizeof(ShardCost));
NodeFillState *fillState = FindFillStateForPlacement(state, placement);
Assert(fillState != NULL);
*shardCost = functions->shardCost(placement->shardId, functions->context);
fillState->totalCost += shardCost->cost;
fillState->utilization = CalculateUtilization(fillState->totalCost,
fillState->capacity);
fillState->shardCostListDesc = lappend(fillState->shardCostListDesc,
shardCost);
fillState->shardCostListDesc = SortList(fillState->shardCostListDesc,
CompareShardCostDesc);
state->totalCost += shardCost->cost;
if (!functions->shardAllowedOnNode(placement->shardId, fillState->node,
functions->context))
{
DisallowedPlacement *disallowed = palloc0(sizeof(DisallowedPlacement));
disallowed->shardCost = shardCost;
disallowed->fillState = fillState;
state->disallowedPlacementList = lappend(state->disallowedPlacementList,
disallowed);
}
}
foreach_htab_cleanup(placement, &status);
state->fillStateListAsc = SortList(state->fillStateListAsc, CompareNodeFillStateAsc);
state->fillStateListDesc = SortList(state->fillStateListDesc,
CompareNodeFillStateDesc);
CheckRebalanceStateInvariants(state);
return state;
}
/*
* CalculateUtilization returns INFINITY when capacity is 0 and
* totalCost/capacity otherwise.
*/
static float4
CalculateUtilization(float4 totalCost, float4 capacity)
{
if (capacity <= 0)
{
return INFINITY;
}
return totalCost / capacity;
}
/*
* FindFillStateForPlacement finds the fillState for the workernode that
* matches the placement.
*/
static NodeFillState *
FindFillStateForPlacement(RebalanceState *state, ShardPlacement *placement)
{
NodeFillState *fillState = NULL;
/* Find the correct fill state to add the placement to and do that */
foreach_ptr(fillState, state->fillStateListAsc)
{
if (IsPlacementOnWorkerNode(placement, fillState->node))
{
return fillState;
}
}
return NULL;
}
/*
* IsPlacementOnWorkerNode checks if the shard placement is on the given
* worker node.
*/
static bool
IsPlacementOnWorkerNode(ShardPlacement *placement, WorkerNode *workerNode)
{
if (strncmp(workerNode->workerName, placement->nodeName, WORKER_LENGTH) != 0)
{
return false;
}
return workerNode->workerPort == placement->nodePort;
}
/*
* CompareNodeFillStateAsc can be used to sort fill states from empty to full.
*/
static int
CompareNodeFillStateAsc(const void *void1, const void *void2)
{
const NodeFillState *a = *((const NodeFillState **) void1);
const NodeFillState *b = *((const NodeFillState **) void2);
if (a->utilization < b->utilization)
{
return -1;
}
if (a->utilization > b->utilization)
{
return 1;
}
/*
* If utilization is the same, prefer nodes with more capacity, since
* utilization will grow more slowly on those
*/
if (a->capacity > b->capacity)
{
return -1;
}
if (a->capacity < b->capacity)
{
return 1;
}
/* Finally differentiate by node id */
if (a->node->nodeId < b->node->nodeId)
{
return -1;
}
return a->node->nodeId > b->node->nodeId;
}
/*
* CompareNodeFillStateDesc can be used to sort fill states from full to empty.
*/
static int
CompareNodeFillStateDesc(const void *a, const void *b)
{
return -CompareNodeFillStateAsc(a, b);
}
/*
* CompareShardCostAsc can be used to sort shard costs from low cost to high
* cost.
*/
static int
CompareShardCostAsc(const void *void1, const void *void2)
{
const ShardCost *a = *((const ShardCost **) void1);
const ShardCost *b = *((const ShardCost **) void2);
if (a->cost < b->cost)
{
return -1;
}
if (a->cost > b->cost)
{
return 1;
}
/* make compare function (more) stable for tests */
if (a->shardId > b->shardId)
{
return -1;
}
return a->shardId < b->shardId;
}
/*
* CompareShardCostDesc can be used to sort shard costs from high cost to low
* cost.
*/
static int
CompareShardCostDesc(const void *a, const void *b)
{
return -CompareShardCostAsc(a, b);
}
/*
* MoveShardsAwayFromDisallowedNodes returns a list of placement updates that
* move any shards that are not allowed on their current node to a node that
* they are allowed on.
*/
static void
MoveShardsAwayFromDisallowedNodes(RebalanceState *state)
{
DisallowedPlacement *disallowedPlacement = NULL;
state->disallowedPlacementList = SortList(state->disallowedPlacementList,
CompareDisallowedPlacementDesc);
/* Move shards off of nodes they are not allowed on */
foreach_ptr(disallowedPlacement, state->disallowedPlacementList)
{
NodeFillState *targetFillState = FindAllowedTargetFillState(
state, disallowedPlacement->shardCost->shardId);
if (targetFillState == NULL)
{
ereport(WARNING, (errmsg(
"Not allowed to move shard " UINT64_FORMAT
" anywhere from %s:%d",
disallowedPlacement->shardCost->shardId,
disallowedPlacement->fillState->node->workerName,
disallowedPlacement->fillState->node->workerPort
)));
continue;
}
MoveShardCost(disallowedPlacement->fillState,
targetFillState,
disallowedPlacement->shardCost,
state);
}
}
/*
* CompareDisallowedPlacementAsc can be used to sort disallowed placements from
* low cost to high cost.
*/
static int
CompareDisallowedPlacementAsc(const void *void1, const void *void2)
{
const DisallowedPlacement *a = *((const DisallowedPlacement **) void1);
const DisallowedPlacement *b = *((const DisallowedPlacement **) void2);
return CompareShardCostAsc(&(a->shardCost), &(b->shardCost));
}
/*
* CompareDisallowedPlacementDesc can be used to sort disallowed placements
* from high cost to low cost.
*/
static int
CompareDisallowedPlacementDesc(const void *a, const void *b)
{
return -CompareDisallowedPlacementAsc(a, b);
}
/*
* FindAllowedTargetFillState finds the first fill state in fillStateListAsc
* where the shard can be moved to.
*/
static NodeFillState *
FindAllowedTargetFillState(RebalanceState *state, uint64 shardId)
{
NodeFillState *targetFillState = NULL;
foreach_ptr(targetFillState, state->fillStateListAsc)
{
bool hasShard = PlacementsHashFind(
state->placementsHash,
shardId,
targetFillState->node);
if (!hasShard && state->functions->shardAllowedOnNode(
shardId,
targetFillState->node,
state->functions->context))
{
return targetFillState;
}
}
return NULL;
}
/*
* MoveShardCost moves a shard cost from the source fill state to the target
* fill state and updates the RebalanceState accordingly. What it does in detail is:
* 1. add a placement update to state->placementUpdateList
* 2. update state->placementsHash
* 3. update totalcost, utilization and shardCostListDesc in source and target
* 4. resort state->fillStateListAsc/Desc
*/
static void
MoveShardCost(NodeFillState *sourceFillState,
NodeFillState *targetFillState,
ShardCost *shardCost,
RebalanceState *state)
{
uint64 shardIdToMove = shardCost->shardId;
/* construct the placement update */
PlacementUpdateEvent *placementUpdateEvent = palloc0(sizeof(PlacementUpdateEvent));
placementUpdateEvent->updateType = PLACEMENT_UPDATE_MOVE;
placementUpdateEvent->shardId = shardIdToMove;
placementUpdateEvent->sourceNode = sourceFillState->node;
placementUpdateEvent->targetNode = targetFillState->node;
/* record the placement update */
state->placementUpdateList = lappend(state->placementUpdateList,
placementUpdateEvent);
/* update the placements hash and the node shard lists */
PlacementsHashRemove(state->placementsHash, shardIdToMove, sourceFillState->node);
PlacementsHashEnter(state->placementsHash, shardIdToMove, targetFillState->node);
sourceFillState->totalCost -= shardCost->cost;
sourceFillState->utilization = CalculateUtilization(sourceFillState->totalCost,
sourceFillState->capacity);
sourceFillState->shardCostListDesc = list_delete_ptr(
sourceFillState->shardCostListDesc,
shardCost);
targetFillState->totalCost += shardCost->cost;
targetFillState->utilization = CalculateUtilization(targetFillState->totalCost,
targetFillState->capacity);
targetFillState->shardCostListDesc = lappend(targetFillState->shardCostListDesc,
shardCost);
targetFillState->shardCostListDesc = SortList(targetFillState->shardCostListDesc,
CompareShardCostDesc);
state->fillStateListAsc = SortList(state->fillStateListAsc, CompareNodeFillStateAsc);
state->fillStateListDesc = SortList(state->fillStateListDesc,
CompareNodeFillStateDesc);
CheckRebalanceStateInvariants(state);
}
/*
* FindAndMoveShardCost is the main rebalancing algorithm. This takes the
* current state and appends a new move to it that improves the balance of
* shards. The algorithm is greedy and will use the first new move that
* improves the balance. It finds nodes by trying to move a shard from the
* fullest node to the emptiest node. If no moves are possible it will try the
* second emptiest node, until it has tried all of them. Then it will try the
* second fullest node. It returns true if it was able to find a move and
* false if it couldn't.
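*
* A small illustrative run with equal capacities: nodes A and B hold total
* costs 8 and 2, all shards have cost 1, and the utilization bounds are
* [4.5, 5.5]. The algorithm repeatedly picks A (fullest) as the source and B
* (emptiest) as the target, moving one cost-1 shard per call, until A and B
* both hold 5; the next call then finds no improving move and returns false.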
*/
static bool
FindAndMoveShardCost(float4 utilizationLowerBound, float4 utilizationUpperBound,
RebalanceState *state)
{
NodeFillState *sourceFillState = NULL;
NodeFillState *targetFillState = NULL;
/*
* find a source node for the move, starting at the node with the highest
* utilization
*/
foreach_ptr(sourceFillState, state->fillStateListDesc)
{
/* Don't move shards away from nodes that are already too empty, we're
* done searching */
if (sourceFillState->utilization <= utilizationLowerBound)
{
return false;
}
/* find a target node for the move, starting at the node with the
* lowest utilization */
foreach_ptr(targetFillState, state->fillStateListAsc)
{
ShardCost *shardCost = NULL;
/* Don't add more shards to nodes that are already at the upper
* bound. We should try the next source node now because further
* target nodes will also be above the upper bound */
if (targetFillState->utilization >= utilizationUpperBound)
{
break;
}
/* Don't move a shard between nodes that both have decent
* utilization. We should try the next source node now because
* further target nodes will also have decent utilization */
if (targetFillState->utilization >= utilizationLowerBound &&
sourceFillState->utilization <= utilizationUpperBound)
{
break;
}
/* find a shard cost that can be moved between the nodes and that
* makes the cost distribution more equal */
foreach_ptr(shardCost, sourceFillState->shardCostListDesc)
{
bool targetHasShard = PlacementsHashFind(state->placementsHash,
shardCost->shardId,
targetFillState->node);
float4 newTargetTotalCost = targetFillState->totalCost + shardCost->cost;
float4 newTargetUtilization = CalculateUtilization(
newTargetTotalCost,
targetFillState->capacity);
float4 newSourceTotalCost = sourceFillState->totalCost - shardCost->cost;
float4 newSourceUtilization = CalculateUtilization(
newSourceTotalCost,
sourceFillState->capacity);
/* Skip shards that already are on the node */
if (targetHasShard)
{
continue;
}
/* Skip shards that already are not allowed on the node */
if (!state->functions->shardAllowedOnNode(shardCost->shardId,
targetFillState->node,
state->functions->context))
{
continue;
}
/*
* Ensure that the cost distribution is actually better
* after the move, i.e. the new highest utilization of
* source and target is lower than the previous highest, or
* the highest utilization is the same, but the lowest
* increased.
*/
if (newTargetUtilization > sourceFillState->utilization)
{
continue;
}
if (newTargetUtilization == sourceFillState->utilization &&
newSourceUtilization <= targetFillState->utilization
)
{
/*
* this can trigger when capacity of the nodes is not the
* same. Example (also a test):
* - node with capacity 3
* - node with capacity 1
* - 3 shards with cost 1
* Best distribution would be 2 shards on node with
* capacity 3 and one on node with capacity 1
*/
continue;
}
MoveShardCost(sourceFillState, targetFillState,
shardCost, state);
return true;
}
}
}
return false;
}
/*
* ReplicationPlacementUpdates returns a list of placement updates which
* replicates shard placements that need re-replication. To do this, the
* function loops over the shard placements, and for each shard placement
* which needs to be re-replicated, it chooses an active worker node with the
* smallest number of shards as the target node.
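*
* For example (illustrative): with a replication factor of 2 and a shard that
* has a single active placement on worker A, the function picks the worker
* with the fewest shards that does not already hold the shard, say B, and
* emits a COPY update from A to B.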
*/
List *
ReplicationPlacementUpdates(List *workerNodeList, List *shardPlacementList,
int shardReplicationFactor)
{
List *placementUpdateList = NIL;
ListCell *shardPlacementCell = NULL;
uint32 workerNodeIndex = 0;
HTAB *placementsHash = ActivePlacementsHash(shardPlacementList);
uint32 workerNodeCount = list_length(workerNodeList);
/* get number of shards per node */
uint32 *shardCountArray = palloc0(workerNodeCount * sizeof(uint32));
foreach(shardPlacementCell, shardPlacementList)
{
ShardPlacement *placement = lfirst(shardPlacementCell);
if (placement->shardState != SHARD_STATE_ACTIVE)
{
continue;
}
for (workerNodeIndex = 0; workerNodeIndex < workerNodeCount; workerNodeIndex++)
{
WorkerNode *node = list_nth(workerNodeList, workerNodeIndex);
if (strncmp(node->workerName, placement->nodeName, WORKER_LENGTH) == 0 &&
node->workerPort == placement->nodePort)
{
shardCountArray[workerNodeIndex]++;
break;
}
}
}
foreach(shardPlacementCell, shardPlacementList)
{
WorkerNode *sourceNode = NULL;
WorkerNode *targetNode = NULL;
uint32 targetNodeShardCount = UINT_MAX;
uint32 targetNodeIndex = 0;
ShardPlacement *placement = (ShardPlacement *) lfirst(shardPlacementCell);
uint64 shardId = placement->shardId;
/* skip the shard placement if it has enough replications */
int activePlacementCount = ShardActivePlacementCount(placementsHash, shardId,
workerNodeList);
if (activePlacementCount >= shardReplicationFactor)
{
continue;
}
/*
* We can copy the shard from any active worker node that contains the
* shard.
*/
for (workerNodeIndex = 0; workerNodeIndex < workerNodeCount; workerNodeIndex++)
{
WorkerNode *workerNode = list_nth(workerNodeList, workerNodeIndex);
bool placementExists = PlacementsHashFind(placementsHash, shardId,
workerNode);
if (placementExists)
{
sourceNode = workerNode;
break;
}
}
/*
* If we couldn't find any worker node which contains the shard, then
* all copies of the shard are lost and we should error out.
*/
if (sourceNode == NULL)
{
ereport(ERROR, (errmsg("could not find a source for shard " UINT64_FORMAT,
shardId)));
}
/*
* We can copy the shard to any worker node that doesn't contain the shard.
* Among such worker nodes, we choose the worker node with minimum shard
* count as the target.
*/
for (workerNodeIndex = 0; workerNodeIndex < workerNodeCount; workerNodeIndex++)
{
WorkerNode *workerNode = list_nth(workerNodeList, workerNodeIndex);
if (!NodeCanHaveDistTablePlacements(workerNode))
{
/* never replicate placements to nodes that should not have placements */
continue;
}
/* skip this node if it already contains the shard */
bool placementExists = PlacementsHashFind(placementsHash, shardId,
workerNode);
if (placementExists)
{
continue;
}
/* compare and change the target node */
if (shardCountArray[workerNodeIndex] < targetNodeShardCount)
{
targetNode = workerNode;
targetNodeShardCount = shardCountArray[workerNodeIndex];
targetNodeIndex = workerNodeIndex;
}
}
/*
* If there is no worker node which doesn't contain the shard, then the
* shard replication factor is greater than number of worker nodes, and
* we should error out.
*/
if (targetNode == NULL)
{
ereport(ERROR, (errmsg("could not find a target for shard " UINT64_FORMAT,
shardId)));
}
/* construct the placement update */
PlacementUpdateEvent *placementUpdateEvent = palloc0(
sizeof(PlacementUpdateEvent));
placementUpdateEvent->updateType = PLACEMENT_UPDATE_COPY;
placementUpdateEvent->shardId = shardId;
placementUpdateEvent->sourceNode = sourceNode;
placementUpdateEvent->targetNode = targetNode;
/* record the placement update */
placementUpdateList = lappend(placementUpdateList, placementUpdateEvent);
/* update the placements hash and the shard count array */
PlacementsHashEnter(placementsHash, shardId, targetNode);
shardCountArray[targetNodeIndex]++;
}
hash_destroy(placementsHash);
return placementUpdateList;
}
/*
* ShardActivePlacementCount returns the number of active placements for the
* given shard which are placed at the active worker nodes.
*/
static int
ShardActivePlacementCount(HTAB *activePlacementsHash, uint64 shardId,
List *activeWorkerNodeList)
{
int shardActivePlacementCount = 0;
ListCell *workerNodeCell = NULL;
foreach(workerNodeCell, activeWorkerNodeList)
{
WorkerNode *workerNode = lfirst(workerNodeCell);
bool placementExists = PlacementsHashFind(activePlacementsHash, shardId,
workerNode);
if (placementExists)
{
shardActivePlacementCount++;
}
}
return shardActivePlacementCount;
}
/*
* ActivePlacementsHash creates and returns a hash set for the placements in
* the given shard placement list that are in the active state.
*/
static HTAB *
ActivePlacementsHash(List *shardPlacementList)
{
ListCell *shardPlacementCell = NULL;
HASHCTL info;
int shardPlacementCount = list_length(shardPlacementList);
memset(&info, 0, sizeof(info));
info.keysize = sizeof(ShardPlacement);
info.entrysize = sizeof(ShardPlacement);
info.hash = PlacementsHashHashCode;
info.match = PlacementsHashCompare;
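/*
 * Full ShardPlacement structs are stored as entries, but only shardId,
 * nodeName and nodePort take part in hashing and comparison; see
 * PlacementsHashHashCode and PlacementsHashCompare below.
 */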
int hashFlags = (HASH_ELEM | HASH_FUNCTION | HASH_COMPARE);
HTAB *shardPlacementsHash = hash_create("ActivePlacements Hash",
shardPlacementCount, &info, hashFlags);
foreach(shardPlacementCell, shardPlacementList)
{
ShardPlacement *shardPlacement = (ShardPlacement *) lfirst(shardPlacementCell);
if (shardPlacement->shardState == SHARD_STATE_ACTIVE)
{
void *hashKey = (void *) shardPlacement;
hash_search(shardPlacementsHash, hashKey, HASH_ENTER, NULL);
}
}
return shardPlacementsHash;
}
/*
* PlacementsHashFind returns true if there exists a shard placement with the
* given workerNode and shard id in the given placements hash, otherwise it
* returns false.
*/
static bool
PlacementsHashFind(HTAB *placementsHash, uint64 shardId, WorkerNode *workerNode)
{
bool placementFound = false;
ShardPlacement shardPlacement;
memset(&shardPlacement, 0, sizeof(shardPlacement));
shardPlacement.shardId = shardId;
shardPlacement.nodeName = workerNode->workerName;
shardPlacement.nodePort = workerNode->workerPort;
void *hashKey = (void *) (&shardPlacement);
hash_search(placementsHash, hashKey, HASH_FIND, &placementFound);
return placementFound;
}
/*
* PlacementsHashEnter enters a shard placement for the given worker node and
* shard id into the given placements hash.
*/
static void
PlacementsHashEnter(HTAB *placementsHash, uint64 shardId, WorkerNode *workerNode)
{
ShardPlacement shardPlacement;
memset(&shardPlacement, 0, sizeof(shardPlacement));
shardPlacement.shardId = shardId;
shardPlacement.nodeName = workerNode->workerName;
shardPlacement.nodePort = workerNode->workerPort;
void *hashKey = (void *) (&shardPlacement);
hash_search(placementsHash, hashKey, HASH_ENTER, NULL);
}
/*
* PlacementsHashRemove removes the shard placement for the given worker node and
* shard id from the given placements hash.
*/
static void
PlacementsHashRemove(HTAB *placementsHash, uint64 shardId, WorkerNode *workerNode)
{
ShardPlacement shardPlacement;
memset(&shardPlacement, 0, sizeof(shardPlacement));
shardPlacement.shardId = shardId;
shardPlacement.nodeName = workerNode->workerName;
shardPlacement.nodePort = workerNode->workerPort;
void *hashKey = (void *) (&shardPlacement);
hash_search(placementsHash, hashKey, HASH_REMOVE, NULL);
}
/*
* PlacementsHashCompare compares two shard placements using shard id, node name,
* and node port number.
*/
static int
PlacementsHashCompare(const void *lhsKey, const void *rhsKey, Size keySize)
{
const ShardPlacement *placementLhs = (const ShardPlacement *) lhsKey;
const ShardPlacement *placementRhs = (const ShardPlacement *) rhsKey;
int shardIdCompare = 0;
/* first, compare by shard id */
if (placementLhs->shardId < placementRhs->shardId)
{
shardIdCompare = -1;
}
else if (placementLhs->shardId > placementRhs->shardId)
{
shardIdCompare = 1;
}
if (shardIdCompare != 0)
{
return shardIdCompare;
}
/* then, compare by node name */
int nodeNameCompare = strncmp(placementLhs->nodeName, placementRhs->nodeName,
WORKER_LENGTH);
if (nodeNameCompare != 0)
{
return nodeNameCompare;
}
/* finally, compare by node port */
int nodePortCompare = placementLhs->nodePort - placementRhs->nodePort;
return nodePortCompare;
}
/*
* PlacementsHashHashCode computes the hash code for a shard placement from the
* placement's shard id, node name, and node port number.
*/
static uint32
PlacementsHashHashCode(const void *key, Size keySize)
{
const ShardPlacement *placement = (const ShardPlacement *) key;
const uint64 *shardId = &(placement->shardId);
const char *nodeName = placement->nodeName;
const uint32 *nodePort = &(placement->nodePort);
/* standard hash function outlined in Effective Java, Item 8 */
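/* unrolled: result = 37 * (37 * (37 * 17 + hash(shardId)) + hash(nodeName)) + hash(nodePort) */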
uint32 result = 17;
result = 37 * result + tag_hash(shardId, sizeof(uint64));
result = 37 * result + string_hash(nodeName, WORKER_LENGTH);
result = 37 * result + tag_hash(nodePort, sizeof(uint32));
return result;
}
/* WorkerNodeListContains checks if the worker node exists in the given list. */
static bool
WorkerNodeListContains(List *workerNodeList, const char *workerName, uint32 workerPort)
{
bool workerNodeListContains = false;
ListCell *workerNodeCell = NULL;
foreach(workerNodeCell, workerNodeList)
{
WorkerNode *workerNode = (WorkerNode *) lfirst(workerNodeCell);
if ((strncmp(workerNode->workerName, workerName, WORKER_LENGTH) == 0) &&
(workerNode->workerPort == workerPort))
{
workerNodeListContains = true;
break;
}
}
return workerNodeListContains;
}
/*
* UpdateColocatedShardPlacementProgress updates the progress of the given
* placement, along with its colocated placements, to the given value.
*/
static void
UpdateColocatedShardPlacementProgress(uint64 shardId, char *sourceName, int sourcePort,
uint64 progress)
{
ProgressMonitorData *header = GetCurrentProgressMonitor();
if (header != NULL && header->steps != NULL)
{
PlacementUpdateEventProgress *steps = header->steps;
ListCell *colocatedShardIntervalCell = NULL;
ShardInterval *shardInterval = LoadShardInterval(shardId);
List *colocatedShardIntervalList = ColocatedShardIntervalList(shardInterval);
for (int moveIndex = 0; moveIndex < header->stepCount; moveIndex++)
{
PlacementUpdateEventProgress *step = steps + moveIndex;
uint64 currentShardId = step->shardId;
bool colocatedShard = false;
foreach(colocatedShardIntervalCell, colocatedShardIntervalList)
{
ShardInterval *candidateShard = lfirst(colocatedShardIntervalCell);
if (candidateShard->shardId == currentShardId)
{
colocatedShard = true;
break;
}
}
if (colocatedShard &&
strcmp(step->sourceName, sourceName) == 0 &&
step->sourcePort == sourcePort)
{
step->progress = progress;
}
}
}
}
/*
* pg_dist_rebalance_strategy_enterprise_check is a trigger function, intended
* for use in prohibiting writes to pg_dist_rebalance_strategy in Citus Community.
*/
Datum
pg_dist_rebalance_strategy_enterprise_check(PG_FUNCTION_ARGS)
{
/* This is Enterprise, so this check is a no-op */
PG_RETURN_VOID();
}
/*
* citus_validate_rebalance_strategy_functions checks all the functions for
* their correct signature.
*
* SQL signature:
*
* citus_validate_rebalance_strategy_functions(
* shard_cost_function regproc,
* node_capacity_function regproc,
* shard_allowed_on_node_function regproc
* ) RETURNS VOID
*/
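/*
 * Usage sketch; the my_* function names below are hypothetical, not shipped
 * with Citus:
 *
 * SELECT citus_validate_rebalance_strategy_functions(
 * 'my_shard_cost'::regproc,
 * 'my_node_capacity'::regproc,
 * 'my_shard_allowed'::regproc);
 */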
Datum
citus_validate_rebalance_strategy_functions(PG_FUNCTION_ARGS)
{
EnsureShardCostUDF(PG_GETARG_OID(0));
EnsureNodeCapacityUDF(PG_GETARG_OID(1));
EnsureShardAllowedOnNodeUDF(PG_GETARG_OID(2));
PG_RETURN_VOID();
}
/*
* EnsureShardCostUDF checks that the UDF matching the oid has the correct
* signature to be used as a ShardCost function. The expected signature is:
*
* shard_cost(shardid bigint) returns float4
*/
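/*
 * A minimal matching definition could look like this (illustrative only,
 * the name my_shard_cost is hypothetical):
 *
 * CREATE FUNCTION my_shard_cost(shardid bigint) RETURNS real
 * AS $$ SELECT 1.0::real $$ LANGUAGE sql;
 */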
static void
EnsureShardCostUDF(Oid functionOid)
{
HeapTuple proctup = SearchSysCache1(PROCOID, ObjectIdGetDatum(functionOid));
if (!HeapTupleIsValid(proctup))
{
ereport(ERROR, (errmsg("cache lookup failed for shard_cost_function with oid %u",
functionOid)));
}
Form_pg_proc procForm = (Form_pg_proc) GETSTRUCT(proctup);
char *name = NameStr(procForm->proname);
if (procForm->pronargs != 1)
{
ereport(ERROR, (errmsg("signature for shard_cost_function is incorrect"),
errdetail(
"number of arguments of %s should be 1, not %i",
name, procForm->pronargs)));
}
if (procForm->proargtypes.values[0] != INT8OID)
{
ereport(ERROR, (errmsg("signature for shard_cost_function is incorrect"),
errdetail(
"argument type of %s should be bigint", name)));
}
if (procForm->prorettype != FLOAT4OID)
{
ereport(ERROR, (errmsg("signature for shard_cost_function is incorrect"),
errdetail("return type of %s should be real", name)));
}
ReleaseSysCache(proctup);
}
/*
* EnsureNodeCapacityUDF checks that the UDF matching the oid has the correct
* signature to be used as a NodeCapacity function. The expected signature is:
*
* node_capacity(nodeid int) returns float4
*/
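/*
 * A minimal matching definition could look like this (illustrative only,
 * the name my_node_capacity is hypothetical):
 *
 * CREATE FUNCTION my_node_capacity(nodeid int) RETURNS real
 * AS $$ SELECT 1.0::real $$ LANGUAGE sql;
 */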
static void
EnsureNodeCapacityUDF(Oid functionOid)
{
HeapTuple proctup = SearchSysCache1(PROCOID, ObjectIdGetDatum(functionOid));
if (!HeapTupleIsValid(proctup))
{
ereport(ERROR, (errmsg(
"cache lookup failed for node_capacity_function with oid %u",
functionOid)));
}
Form_pg_proc procForm = (Form_pg_proc) GETSTRUCT(proctup);
char *name = NameStr(procForm->proname);
if (procForm->pronargs != 1)
{
ereport(ERROR, (errmsg("signature for node_capacity_function is incorrect"),
errdetail(
"number of arguments of %s should be 1, not %i",
name, procForm->pronargs)));
}
if (procForm->proargtypes.values[0] != INT4OID)
{
ereport(ERROR, (errmsg("signature for node_capacity_function is incorrect"),
errdetail("argument type of %s should be int", name)));
}
if (procForm->prorettype != FLOAT4OID)
{
ereport(ERROR, (errmsg("signature for node_capacity_function is incorrect"),
errdetail("return type of %s should be real", name)));
}
ReleaseSysCache(proctup);
}
/*
* EnsureShardAllowedOnNodeUDF checks that the UDF matching the oid has the
* correct signature to be used as a ShardAllowedOnNode function. The expected
* signature is:
*
* shard_allowed_on_node(shardid bigint, nodeid int) returns boolean
*/
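/*
 * A minimal matching definition could look like this (illustrative only,
 * the name my_shard_allowed is hypothetical):
 *
 * CREATE FUNCTION my_shard_allowed(shardid bigint, nodeid int) RETURNS boolean
 * AS $$ SELECT true $$ LANGUAGE sql;
 */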
static void
EnsureShardAllowedOnNodeUDF(Oid functionOid)
{
HeapTuple proctup = SearchSysCache1(PROCOID, ObjectIdGetDatum(functionOid));
if (!HeapTupleIsValid(proctup))
{
ereport(ERROR, (errmsg(
"cache lookup failed for shard_allowed_on_node_function with oid %u",
functionOid)));
}
Form_pg_proc procForm = (Form_pg_proc) GETSTRUCT(proctup);
char *name = NameStr(procForm->proname);
if (procForm->pronargs != 2)
{
ereport(ERROR, (errmsg(
"signature for shard_allowed_on_node_function is incorrect"),
errdetail(
"number of arguments of %s should be 2, not %i",
name, procForm->pronargs)));
}
if (procForm->proargtypes.values[0] != INT8OID)
{
ereport(ERROR, (errmsg(
"signature for shard_allowed_on_node_function is incorrect"),
errdetail(
"type of first argument of %s should be bigint", name)));
}
if (procForm->proargtypes.values[1] != INT4OID)
{
ereport(ERROR, (errmsg(
"signature for shard_allowed_on_node_function is incorrect"),
errdetail(
"type of second argument of %s should be int", name)));
}
if (procForm->prorettype != BOOLOID)
{
ereport(ERROR, (errmsg(
"signature for shard_allowed_on_node_function is incorrect"),
errdetail(
"return type of %s should be boolean", name)));
}
ReleaseSysCache(proctup);
}