citus/src/backend/distributed/test/shard_rebalancer.c

629 lines
19 KiB
C

/*-------------------------------------------------------------------------
*
* test/shard_rebalancer.c
*
* This file contains functions used for unit testing the planning part of the
* shard rebalancer.
*
* Copyright (c) 2014-2019, Citus Data, Inc.
*
*-------------------------------------------------------------------------
*/
#include "postgres.h"
#include "libpq-fe.h"
#include "safe_lib.h"
#include "catalog/pg_type.h"
#include "distributed/citus_safe_lib.h"
#include "distributed/citus_ruleutils.h"
#include "distributed/connection_management.h"
#include "distributed/listutils.h"
#include "distributed/multi_physical_planner.h"
#include "distributed/shard_rebalancer.h"
#include "funcapi.h"
#include "miscadmin.h"
#include "utils/builtins.h"
#include "utils/int8.h"
#include "utils/json.h"
#include "utils/lsyscache.h"
#include "utils/memutils.h"
/* static declarations for json conversion */
static List * JsonArrayToShardPlacementTestInfoList(
ArrayType *shardPlacementJsonArrayObject);
static List * JsonArrayToWorkerTestInfoList(ArrayType *workerNodeJsonArrayObject);
static bool JsonFieldValueBool(Datum jsonDocument, const char *key);
static uint32 JsonFieldValueUInt32(Datum jsonDocument, const char *key);
static uint64 JsonFieldValueUInt64(Datum jsonDocument, const char *key);
static char * JsonFieldValueString(Datum jsonDocument, const char *key);
static ArrayType * PlacementUpdateListToJsonArray(List *placementUpdateList);
static bool ShardAllowedOnNode(uint64 shardId, WorkerNode *workerNode, void *context);
static float NodeCapacity(WorkerNode *workerNode, void *context);
static ShardCost GetShardCost(uint64 shardId, void *context);
PG_FUNCTION_INFO_V1(shard_placement_rebalance_array);
PG_FUNCTION_INFO_V1(shard_placement_replication_array);
PG_FUNCTION_INFO_V1(worker_node_responsive);
typedef struct ShardPlacementTestInfo
{
ShardPlacement *placement;
uint64 cost;
bool nextColocationGroup;
} ShardPlacementTestInfo;
typedef struct WorkerTestInfo
{
WorkerNode *node;
List *disallowedShardIds;
float capacity;
} WorkerTestInfo;
typedef struct RebalancePlanContext
{
List *workerTestInfoList;
List *shardPlacementTestInfoList;
} RebalancePlacementContext;
/*
* shard_placement_rebalance_array returns a list of operations which can make a
* cluster consisting of given shard placements and worker nodes balanced with
* respect to the given threshold. Threshold is a value between 0 and 1 which
* determines the evenness in shard distribution. When threshold is 0, then all
* nodes should have equal number of shards. As threshold increases, cluster's
* evenness requirements decrease, and we can rebalance the cluster using less
* operations.
*/
Datum
shard_placement_rebalance_array(PG_FUNCTION_ARGS)
{
ArrayType *workerNodeJsonArray = PG_GETARG_ARRAYTYPE_P(0);
ArrayType *shardPlacementJsonArray = PG_GETARG_ARRAYTYPE_P(1);
float threshold = PG_GETARG_FLOAT4(2);
int32 maxShardMoves = PG_GETARG_INT32(3);
bool drainOnly = PG_GETARG_BOOL(4);
List *workerNodeList = NIL;
List *shardPlacementListList = NIL;
List *shardPlacementList = NIL;
WorkerTestInfo *workerTestInfo = NULL;
ShardPlacementTestInfo *shardPlacementTestInfo = NULL;
RebalancePlanFunctions rebalancePlanFunctions = {
.shardAllowedOnNode = ShardAllowedOnNode,
.nodeCapacity = NodeCapacity,
.shardCost = GetShardCost,
};
RebalancePlacementContext context = {
.workerTestInfoList = NULL,
};
context.workerTestInfoList = JsonArrayToWorkerTestInfoList(workerNodeJsonArray);
context.shardPlacementTestInfoList = JsonArrayToShardPlacementTestInfoList(
shardPlacementJsonArray);
/* we don't need original arrays any more, so we free them to save memory */
pfree(workerNodeJsonArray);
pfree(shardPlacementJsonArray);
/* map workerTestInfoList to a list of its WorkerNodes */
foreach_ptr(workerTestInfo, context.workerTestInfoList)
{
workerNodeList = lappend(workerNodeList, workerTestInfo->node);
}
/* map shardPlacementTestInfoList to a list of list of its ShardPlacements */
foreach_ptr(shardPlacementTestInfo, context.shardPlacementTestInfoList)
{
if (shardPlacementTestInfo->nextColocationGroup)
{
shardPlacementList = SortList(shardPlacementList, CompareShardPlacements);
shardPlacementListList = lappend(shardPlacementListList, shardPlacementList);
shardPlacementList = NIL;
}
shardPlacementList = lappend(shardPlacementList,
shardPlacementTestInfo->placement);
}
shardPlacementList = SortList(shardPlacementList, CompareShardPlacements);
shardPlacementListList = lappend(shardPlacementListList, shardPlacementList);
rebalancePlanFunctions.context = &context;
/* sort the lists to make the function more deterministic */
workerNodeList = SortList(workerNodeList, CompareWorkerNodes);
List *placementUpdateList = RebalancePlacementUpdates(workerNodeList,
shardPlacementListList,
threshold,
maxShardMoves,
drainOnly,
&rebalancePlanFunctions);
ArrayType *placementUpdateJsonArray = PlacementUpdateListToJsonArray(
placementUpdateList);
PG_RETURN_ARRAYTYPE_P(placementUpdateJsonArray);
}
/*
* ShardAllowedOnNode is the function that checks if shard is allowed to be on
* a worker when running the shard rebalancer unit tests.
*/
static bool
ShardAllowedOnNode(uint64 shardId, WorkerNode *workerNode, void *voidContext)
{
RebalancePlacementContext *context = voidContext;
WorkerTestInfo *workerTestInfo = NULL;
uint64 *disallowedShardIdPtr = NULL;
foreach_ptr(workerTestInfo, context->workerTestInfoList)
{
if (workerTestInfo->node == workerNode)
{
break;
}
}
Assert(workerTestInfo != NULL);
foreach_ptr(disallowedShardIdPtr, workerTestInfo->disallowedShardIds)
{
if (shardId == *disallowedShardIdPtr)
{
return false;
}
}
return true;
}
/*
* NodeCapacity is the function that gets the capacity of a worker when running
* the shard rebalancer unit tests.
*/
static float
NodeCapacity(WorkerNode *workerNode, void *voidContext)
{
RebalancePlacementContext *context = voidContext;
WorkerTestInfo *workerTestInfo = NULL;
foreach_ptr(workerTestInfo, context->workerTestInfoList)
{
if (workerTestInfo->node == workerNode)
{
break;
}
}
Assert(workerTestInfo != NULL);
return workerTestInfo->capacity;
}
/*
* GetShardCost is the function that gets the ShardCost of a shard when running
* the shard rebalancer unit tests.
*/
static ShardCost
GetShardCost(uint64 shardId, void *voidContext)
{
RebalancePlacementContext *context = voidContext;
ShardCost shardCost;
memset_struct_0(shardCost);
shardCost.shardId = shardId;
ShardPlacementTestInfo *shardPlacementTestInfo = NULL;
foreach_ptr(shardPlacementTestInfo, context->shardPlacementTestInfoList)
{
if (shardPlacementTestInfo->placement->shardId == shardId)
{
break;
}
}
Assert(shardPlacementTestInfo != NULL);
shardCost.cost = shardPlacementTestInfo->cost;
return shardCost;
}
/*
* shard_placement_replication_array returns a list of operations which will
* replicate under-replicated shards in a cluster consisting of given shard
* placements and worker nodes. A shard is under-replicated if it has less
* active placements than the given shard replication factor.
*/
Datum
shard_placement_replication_array(PG_FUNCTION_ARGS)
{
ArrayType *workerNodeJsonArray = PG_GETARG_ARRAYTYPE_P(0);
ArrayType *shardPlacementJsonArray = PG_GETARG_ARRAYTYPE_P(1);
uint32 shardReplicationFactor = PG_GETARG_INT32(2);
List *workerNodeList = NIL;
List *shardPlacementList = NIL;
WorkerTestInfo *workerTestInfo = NULL;
ShardPlacementTestInfo *shardPlacementTestInfo = NULL;
/* validate shard replication factor */
if (shardReplicationFactor < SHARD_REPLICATION_FACTOR_MINIMUM ||
shardReplicationFactor > SHARD_REPLICATION_FACTOR_MAXIMUM)
{
ereport(ERROR, (errmsg("invalid shard replication factor"),
errhint("Shard replication factor must be an integer "
"between %d and %d", SHARD_REPLICATION_FACTOR_MINIMUM,
SHARD_REPLICATION_FACTOR_MAXIMUM)));
}
List *workerTestInfoList = JsonArrayToWorkerTestInfoList(workerNodeJsonArray);
List *shardPlacementTestInfoList = JsonArrayToShardPlacementTestInfoList(
shardPlacementJsonArray);
/* we don't need original arrays any more, so we free them to save memory */
pfree(workerNodeJsonArray);
pfree(shardPlacementJsonArray);
foreach_ptr(workerTestInfo, workerTestInfoList)
{
workerNodeList = lappend(workerNodeList, workerTestInfo->node);
}
foreach_ptr(shardPlacementTestInfo, shardPlacementTestInfoList)
{
shardPlacementList = lappend(shardPlacementList,
shardPlacementTestInfo->placement);
}
/* sort the lists to make the function more deterministic */
workerNodeList = SortList(workerNodeList, CompareWorkerNodes);
shardPlacementList = SortList(shardPlacementList, CompareShardPlacements);
List *placementUpdateList = ReplicationPlacementUpdates(workerNodeList,
shardPlacementList,
shardReplicationFactor);
ArrayType *placementUpdateJsonArray = PlacementUpdateListToJsonArray(
placementUpdateList);
PG_RETURN_ARRAYTYPE_P(placementUpdateJsonArray);
}
/*
* JsonArrayToShardPlacementTestInfoList converts the given shard placement json array
* to a list of ShardPlacement structs.
*/
static List *
JsonArrayToShardPlacementTestInfoList(ArrayType *shardPlacementJsonArrayObject)
{
List *shardPlacementTestInfoList = NIL;
Datum *shardPlacementJsonArray = NULL;
int placementCount = 0;
/*
* Memory is not automatically freed when we call UDFs using DirectFunctionCall.
* We call these functions in functionCallContext, so we can free the memory
* once they return.
*/
MemoryContext functionCallContext = AllocSetContextCreate(CurrentMemoryContext,
"Function Call Context",
ALLOCSET_DEFAULT_MINSIZE,
ALLOCSET_DEFAULT_INITSIZE,
ALLOCSET_DEFAULT_MAXSIZE);
deconstruct_array(shardPlacementJsonArrayObject, JSONOID, -1, false, 'i',
&shardPlacementJsonArray, NULL, &placementCount);
for (int placementIndex = 0; placementIndex < placementCount; placementIndex++)
{
Datum placementJson = shardPlacementJsonArray[placementIndex];
ShardPlacementTestInfo *placementTestInfo = palloc0(
sizeof(ShardPlacementTestInfo));
MemoryContext oldContext = MemoryContextSwitchTo(functionCallContext);
uint64 shardId = JsonFieldValueUInt64(placementJson, FIELD_NAME_SHARD_ID);
uint64 shardLength = JsonFieldValueUInt64(placementJson, FIELD_NAME_SHARD_LENGTH);
int shardState = JsonFieldValueUInt32(placementJson, FIELD_NAME_SHARD_STATE);
char *nodeName = JsonFieldValueString(placementJson, FIELD_NAME_NODE_NAME);
int nodePort = JsonFieldValueUInt32(placementJson, FIELD_NAME_NODE_PORT);
uint64 placementId = JsonFieldValueUInt64(placementJson, FIELD_NAME_PLACEMENT_ID);
MemoryContextSwitchTo(oldContext);
placementTestInfo->placement = palloc0(sizeof(ShardPlacement));
placementTestInfo->placement->shardId = shardId;
placementTestInfo->placement->shardLength = shardLength;
placementTestInfo->placement->shardState = shardState;
placementTestInfo->placement->nodeName = pstrdup(nodeName);
placementTestInfo->placement->nodePort = nodePort;
placementTestInfo->placement->placementId = placementId;
/*
* We have copied whatever we needed from the UDF calls, so we can free
* the memory allocated by them.
*/
MemoryContextReset(functionCallContext);
shardPlacementTestInfoList = lappend(shardPlacementTestInfoList,
placementTestInfo);
PG_TRY();
{
placementTestInfo->cost = JsonFieldValueUInt64(placementJson,
"cost");
}
PG_CATCH();
{
/* Ignore errors about not being able to find the key in that case cost is 1 */
FlushErrorState();
MemoryContextSwitchTo(oldContext);
placementTestInfo->cost = 1;
}
PG_END_TRY();
PG_TRY();
{
placementTestInfo->nextColocationGroup = JsonFieldValueBool(
placementJson, "next_colocation");
}
PG_CATCH();
{
/* Ignore errors about not being able to find the key in that case cost is 1 */
FlushErrorState();
MemoryContextSwitchTo(oldContext);
}
PG_END_TRY();
}
pfree(shardPlacementJsonArray);
return shardPlacementTestInfoList;
}
/*
* JsonArrayToWorkerNodeList converts the given worker node json array to a list
* of WorkerNode structs.
*/
static List *
JsonArrayToWorkerTestInfoList(ArrayType *workerNodeJsonArrayObject)
{
List *workerTestInfoList = NIL;
Datum *workerNodeJsonArray = NULL;
int workerNodeCount = 0;
deconstruct_array(workerNodeJsonArrayObject, JSONOID, -1, false, 'i',
&workerNodeJsonArray, NULL, &workerNodeCount);
for (int workerNodeIndex = 0; workerNodeIndex < workerNodeCount; workerNodeIndex++)
{
Datum workerNodeJson = workerNodeJsonArray[workerNodeIndex];
char *workerName = JsonFieldValueString(workerNodeJson, FIELD_NAME_WORKER_NAME);
uint32 workerPort = JsonFieldValueUInt32(workerNodeJson,
FIELD_NAME_WORKER_PORT);
List *disallowedShardIdList = NIL;
char *disallowedShardsString = NULL;
MemoryContext savedContext = CurrentMemoryContext;
WorkerTestInfo *workerTestInfo = palloc0(sizeof(WorkerTestInfo));
WorkerNode *workerNode = palloc0(sizeof(WorkerNode));
strncpy_s(workerNode->workerName, sizeof(workerNode->workerName), workerName,
WORKER_LENGTH);
workerNode->nodeId = workerNodeIndex;
workerNode->workerPort = workerPort;
workerNode->shouldHaveShards = true;
workerNode->nodeRole = PrimaryNodeRoleId();
workerTestInfo->node = workerNode;
PG_TRY();
{
workerTestInfo->capacity = JsonFieldValueUInt64(workerNodeJson,
"capacity");
}
PG_CATCH();
{
/* Ignore errors about not being able to find the key in that case capacity is 1 */
FlushErrorState();
MemoryContextSwitchTo(savedContext);
workerTestInfo->capacity = 1;
}
PG_END_TRY();
workerTestInfoList = lappend(workerTestInfoList, workerTestInfo);
PG_TRY();
{
disallowedShardsString = JsonFieldValueString(workerNodeJson,
"disallowed_shards");
}
PG_CATCH();
{
/* Ignore errors about not being able to find the key in that case all shards are allowed */
FlushErrorState();
MemoryContextSwitchTo(savedContext);
disallowedShardsString = NULL;
}
PG_END_TRY();
if (disallowedShardsString == NULL)
{
continue;
}
char *strtokPosition = NULL;
char *shardString = strtok_r(disallowedShardsString, ",", &strtokPosition);
while (shardString != NULL)
{
uint64 *shardInt = palloc0(sizeof(uint64));
*shardInt = SafeStringToUint64(shardString);
disallowedShardIdList = lappend(disallowedShardIdList, shardInt);
shardString = strtok_r(NULL, ",", &strtokPosition);
}
workerTestInfo->disallowedShardIds = disallowedShardIdList;
}
return workerTestInfoList;
}
/*
* JsonFieldValueBool gets the value of the given key in the given json
* document and returns it as a boolean.
*/
static bool
JsonFieldValueBool(Datum jsonDocument, const char *key)
{
char *valueString = JsonFieldValueString(jsonDocument, key);
Datum valueBoolDatum = DirectFunctionCall1(boolin, CStringGetDatum(valueString));
return DatumGetBool(valueBoolDatum);
}
/*
* JsonFieldValueUInt32 gets the value of the given key in the given json
* document and returns it as an unsigned 32-bit integer.
*/
static uint32
JsonFieldValueUInt32(Datum jsonDocument, const char *key)
{
char *valueString = JsonFieldValueString(jsonDocument, key);
Datum valueInt4Datum = DirectFunctionCall1(int4in, CStringGetDatum(valueString));
uint32 valueUInt32 = DatumGetInt32(valueInt4Datum);
return valueUInt32;
}
/*
* JsonFieldValueUInt64 gets the value of the given key in the given json
* document and returns it as an unsigned 64-bit integer.
*/
static uint64
JsonFieldValueUInt64(Datum jsonDocument, const char *key)
{
char *valueString = JsonFieldValueString(jsonDocument, key);
Datum valueInt8Datum = DirectFunctionCall1(int8in, CStringGetDatum(valueString));
uint64 valueUInt64 = DatumGetInt64(valueInt8Datum);
return valueUInt64;
}
/*
* JsonFieldValueString gets the value of the given key in the given json
* document and returns it as a string.
*/
static char *
JsonFieldValueString(Datum jsonDocument, const char *key)
{
Datum valueTextDatum = 0;
bool valueFetched = false;
Datum keyDatum = PointerGetDatum(cstring_to_text(key));
/*
* json_object_field_text can return NULL, but DirectFunctionalCall2 raises
* cryptic errors when the function returns NULL. We catch this error and
* raise a more meaningful error.
*/
PG_TRY();
{
valueTextDatum = DirectFunctionCall2(json_object_field_text,
jsonDocument, keyDatum);
valueFetched = true;
}
PG_CATCH();
{
FlushErrorState();
valueFetched = false;
}
PG_END_TRY();
if (!valueFetched)
{
ereport(ERROR, (errmsg("could not get value for '%s'", key)));
}
char *valueString = text_to_cstring(DatumGetTextP(valueTextDatum));
return valueString;
}
/*
* PlacementUpdateListToJsonArray converts the given list of placement update
* data to a json array.
*/
static ArrayType *
PlacementUpdateListToJsonArray(List *placementUpdateList)
{
ListCell *placementUpdateCell = NULL;
int placementUpdateIndex = 0;
int placementUpdateCount = list_length(placementUpdateList);
Datum *placementUpdateJsonArray = palloc0(placementUpdateCount * sizeof(Datum));
foreach(placementUpdateCell, placementUpdateList)
{
PlacementUpdateEvent *placementUpdateEvent = lfirst(placementUpdateCell);
WorkerNode *sourceNode = placementUpdateEvent->sourceNode;
WorkerNode *targetNode = placementUpdateEvent->targetNode;
StringInfo escapedSourceName = makeStringInfo();
escape_json(escapedSourceName, sourceNode->workerName);
StringInfo escapedTargetName = makeStringInfo();
escape_json(escapedTargetName, targetNode->workerName);
StringInfo placementUpdateJsonString = makeStringInfo();
appendStringInfo(placementUpdateJsonString, PLACEMENT_UPDATE_JSON_FORMAT,
placementUpdateEvent->updateType, placementUpdateEvent->shardId,
escapedSourceName->data, sourceNode->workerPort,
escapedTargetName->data, targetNode->workerPort);
Datum placementUpdateStringDatum = CStringGetDatum(
placementUpdateJsonString->data);
Datum placementUpdateJsonDatum = DirectFunctionCall1(json_in,
placementUpdateStringDatum);
placementUpdateJsonArray[placementUpdateIndex] = placementUpdateJsonDatum;
placementUpdateIndex++;
}
ArrayType *placementUpdateObject = construct_array(placementUpdateJsonArray,
placementUpdateCount, JSONOID,
-1, false, 'i');
return placementUpdateObject;
}
/*
* worker_node_responsive returns true if the given worker node is responsive.
* Otherwise, it returns false.
*/
Datum
worker_node_responsive(PG_FUNCTION_ARGS)
{
text *workerNameText = PG_GETARG_TEXT_PP(0);
uint32 workerPort = PG_GETARG_INT32(1);
int connectionFlag = FORCE_NEW_CONNECTION;
bool workerNodeResponsive = false;
const char *workerName = text_to_cstring(workerNameText);
MultiConnection *connection = GetNodeConnection(connectionFlag, workerName,
workerPort);
if (connection != NULL && connection->pgConn != NULL)
{
if (PQstatus(connection->pgConn) == CONNECTION_OK)
{
workerNodeResponsive = true;
}
CloseConnection(connection);
}
PG_RETURN_BOOL(workerNodeResponsive);
}