diff --git a/src/backend/distributed/metadata/metadata_cache.c b/src/backend/distributed/metadata/metadata_cache.c index b9f274495..432eb87fc 100644 --- a/src/backend/distributed/metadata/metadata_cache.c +++ b/src/backend/distributed/metadata/metadata_cache.c @@ -4379,6 +4379,8 @@ InitializeWorkerNodeCache(void) workerNode->isActive = currentNode->isActive; workerNode->nodeRole = currentNode->nodeRole; workerNode->shouldHaveShards = currentNode->shouldHaveShards; + workerNode->nodeprimarynodeid = currentNode->nodeprimarynodeid; + workerNode->nodeisreplica = currentNode->nodeisreplica; strlcpy(workerNode->nodeCluster, currentNode->nodeCluster, NAMEDATALEN); newWorkerNodeArray[workerNodeIndex++] = workerNode; diff --git a/src/backend/distributed/metadata/node_metadata.c b/src/backend/distributed/metadata/node_metadata.c index fa4de2f1d..d85adce1f 100644 --- a/src/backend/distributed/metadata/node_metadata.c +++ b/src/backend/distributed/metadata/node_metadata.c @@ -135,8 +135,6 @@ static void MarkNodesNotSyncedInLoopBackConnection(MetadataSyncContext *context, static void EnsureParentSessionHasExclusiveLockOnPgDistNode(pid_t parentSessionPid); static void SetNodeMetadata(MetadataSyncContext *context, bool localOnly); static void EnsureTransactionalMetadataSyncMode(void); -static void LockShardsInWorkerPlacementList(WorkerNode *workerNode, LOCKMODE - lockMode); static BackgroundWorkerHandle * CheckBackgroundWorkerToObtainLocks(int32 lock_cooldown); static BackgroundWorkerHandle * LockPlacementsWithBackgroundWorkersInPrimaryNode( WorkerNode *workerNode, bool force, int32 lock_cooldown); @@ -1189,6 +1187,27 @@ ActivateNodeList(MetadataSyncContext *context) SetNodeMetadata(context, localOnly); } +/* + * ActivateReplicaNodeAsPrimary sets the given worker node as primary and active + * in the pg_dist_node catalog and make the replica node as first class citizen. + */ +void +ActivateReplicaNodeAsPrimary(WorkerNode *workerNode) +{ + /* + * Set the node as primary and active. 
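+	 * Besides flipping noderole to primary and isactive to true, this clears
+	 * nodeisreplica and nodeprimarynodeid and sets shouldhaveshards, so the
+	 * promoted replica is tracked as a regular, first-class primary from this
+	 * point on.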
+ */ + SetWorkerColumnLocalOnly(workerNode, Anum_pg_dist_node_noderole, + ObjectIdGetDatum(PrimaryNodeRoleId())); + SetWorkerColumnLocalOnly(workerNode, Anum_pg_dist_node_isactive, + BoolGetDatum(true)); + SetWorkerColumnLocalOnly(workerNode, Anum_pg_dist_node_nodeisreplica, + BoolGetDatum(false)); + SetWorkerColumnLocalOnly(workerNode, Anum_pg_dist_node_shouldhaveshards, + BoolGetDatum(true)); + SetWorkerColumnLocalOnly(workerNode, Anum_pg_dist_node_nodeprimarynodeid, + Int32GetDatum(0)); +} /* * Acquires shard metadata locks on all shards residing in the given worker node @@ -3710,7 +3729,7 @@ EnsureValidStreamingReplica(WorkerNode *primaryWorkerNode, char* replicaHostname StringInfo sysidQueryResInfo = (StringInfo) linitial(sysidList); char *sysidQueryResStr = sysidQueryResInfo->data; - ereport (NOTICE, (errmsg("system identifier of %s:%d is %s", + ereport (DEBUG2, (errmsg("system identifier of %s:%d is %s", replicaHostname, replicaPort, sysidQueryResStr))); /* We do not need the connection anymore */ @@ -3757,7 +3776,7 @@ EnsureValidStreamingReplica(WorkerNode *primaryWorkerNode, char* replicaHostname StringInfo primarySysidQueryResInfo = (StringInfo) linitial(primarySizeList); char *primarySysidQueryResStr = primarySysidQueryResInfo->data; - ereport (NOTICE, (errmsg("system identifier of %s:%d is %s", + ereport (DEBUG2, (errmsg("system identifier of %s:%d is %s", primaryWorkerNode->workerName, primaryWorkerNode->workerPort, primarySysidQueryResStr))); /* verify both identifiers */ if (strcmp(sysidQueryResStr, primarySysidQueryResStr) != 0) diff --git a/src/backend/distributed/operations/node_promotion.c b/src/backend/distributed/operations/node_promotion.c new file mode 100644 index 000000000..5d9934bde --- /dev/null +++ b/src/backend/distributed/operations/node_promotion.c @@ -0,0 +1,380 @@ +#include "postgres.h" +#include "utils/fmgrprotos.h" +#include "utils/pg_lsn.h" + +#include "distributed/argutils.h" +#include "distributed/remote_commands.h" +#include "distributed/metadata_cache.h" +#include "distributed/metadata_sync.h" +#include "distributed/shard_rebalancer.h" + + +static int64 GetReplicationLag(WorkerNode *primaryWorkerNode, WorkerNode *replicaWorkerNode); +static void BlockAllWritesToWorkerNode(WorkerNode *workerNode); +static bool GetNodeIsInRecoveryStatus(WorkerNode *workerNode); +static void PromoteReplicaNode(WorkerNode *replicaWorkerNode); + + +PG_FUNCTION_INFO_V1(citus_promote_replica_and_rebalance); + +Datum +citus_promote_replica_and_rebalance(PG_FUNCTION_ARGS) +{ + // Ensure superuser and coordinator + EnsureSuperUser(); + EnsureCoordinator(); + + // Get replica_nodeid argument + int32 replicaNodeIdArg = PG_GETARG_INT32(0); + + WorkerNode *replicaNode = NULL; + WorkerNode *primaryNode = NULL; + + // Lock pg_dist_node to prevent concurrent modifications during this operation + LockRelationOid(DistNodeRelationId(), RowExclusiveLock); + + replicaNode = FindNodeAnyClusterByNodeId(replicaNodeIdArg); + if (replicaNode == NULL) + { + ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), + errmsg("Replica node with ID %d not found.", replicaNodeIdArg))); + } + + if (!replicaNode->nodeisreplica || replicaNode->nodeprimarynodeid == 0) + { + ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), + errmsg("Node %s:%d (ID %d) is not a valid replica or its primary node ID is not set.", + replicaNode->workerName, replicaNode->workerPort, replicaNode->nodeId))); + } + + if (replicaNode->isActive) + { + ereport(ERROR, 
(errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), + errmsg("Replica node %s:%d (ID %d) is already active and cannot be promoted.", + replicaNode->workerName, replicaNode->workerPort, replicaNode->nodeId))); + } + + primaryNode = FindNodeAnyClusterByNodeId(replicaNode->nodeprimarynodeid); + if (primaryNode == NULL) + { + ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), + errmsg("Primary node with ID %d (for replica %s:%d) not found.", + replicaNode->nodeprimarynodeid, replicaNode->workerName, replicaNode->workerPort))); + } + + if (primaryNode->nodeisreplica) + { + ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), + errmsg("Primary node %s:%d (ID %d) is itself a replica.", + primaryNode->workerName, primaryNode->workerPort, primaryNode->nodeId))); + } + + if (!primaryNode->isActive) + { + ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), + errmsg("Primary node %s:%d (ID %d) is not active.", + primaryNode->workerName, primaryNode->workerPort, primaryNode->nodeId))); + } + /* Ensure the primary node is related to the replica node */ + if (primaryNode->nodeId != replicaNode->nodeprimarynodeid) + { + ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), + errmsg("Replica node %s:%d (ID %d) is not replica of the primary node %s:%d (ID %d).", + replicaNode->workerName, replicaNode->workerPort, replicaNode->nodeId, + primaryNode->workerName, primaryNode->workerPort, primaryNode->nodeId))); + } + + ereport(NOTICE, (errmsg("Starting promotion process for replica node %s:%d (ID %d), original primary %s:%d (ID %d)", + replicaNode->workerName, replicaNode->workerPort, replicaNode->nodeId, + primaryNode->workerName, primaryNode->workerPort, primaryNode->nodeId))); + + /* Step 1: Block Writes on Original Primary's Shards */ + ereport(NOTICE, (errmsg("Blocking writes on shards of original primary node %s:%d (group %d)", + primaryNode->workerName, primaryNode->workerPort, primaryNode->groupId))); + + BlockAllWritesToWorkerNode(primaryNode); + + /* Step 2: Wait for Replica to Catch Up */ + ereport(NOTICE, (errmsg("Waiting for replica %s:%d to catch up with primary %s:%d", + replicaNode->workerName, replicaNode->workerPort, + primaryNode->workerName, primaryNode->workerPort))); + + bool caughtUp = false; + const int catchUpTimeoutSeconds = 300; // 5 minutes, TODO: Make GUC + const int sleepIntervalSeconds = 5; + int elapsedTimeSeconds = 0; + + while (elapsedTimeSeconds < catchUpTimeoutSeconds) + { + uint64 repLag = GetReplicationLag(primaryNode, replicaNode); + if (repLag <= 0) + { + caughtUp = true; + break; + } + pg_usleep(sleepIntervalSeconds * 1000000L); + elapsedTimeSeconds += sleepIntervalSeconds; + } + + if (!caughtUp) + { + ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), + errmsg("Replica %s:%d failed to catch up with primary %s:%d within %d seconds.", + replicaNode->workerName, replicaNode->workerPort, + primaryNode->workerName, primaryNode->workerPort, + catchUpTimeoutSeconds))); + } + + ereport(NOTICE, (errmsg("Replica %s:%d is now caught up with primary %s:%d.", + replicaNode->workerName, replicaNode->workerPort, + primaryNode->workerName, primaryNode->workerPort))); + + + + /* Step 3: PostgreSQL Replica Promotion */ + ereport(NOTICE, (errmsg("Attempting to promote replica %s:%d via pg_promote().", + replicaNode->workerName, replicaNode->workerPort))); + + PromoteReplicaNode(replicaNode); + + /* Step 4: Update Replica Metadata in pg_dist_node on Coordinator */ + + ereport(NOTICE, (errmsg("Updating metadata for 
promoted replica %s:%d (ID %d)", + replicaNode->workerName, replicaNode->workerPort, replicaNode->nodeId))); + ActivateReplicaNodeAsPrimary(replicaNode); + + /* We need to sync metadata changes to all nodes before rebalancing shards + * since the rebalancing algorithm depends on the latest metadata. + */ + SyncNodeMetadataToNodes(); + + /* Step 5: Split Shards Between Primary and Replica */ + SplitShardsBetweenPrimaryAndReplica(primaryNode, replicaNode, PG_GETARG_NAME_OR_NULL(1)); + + + TransactionModifiedNodeMetadata = true; // Inform Citus about metadata change + TriggerNodeMetadataSyncOnCommit(); // Ensure changes are propagated + + + + ereport(NOTICE, (errmsg("Replica node %s:%d (ID %d) metadata updated. It is now a primary", + replicaNode->workerName, replicaNode->workerPort, replicaNode->nodeId))); + + + + /* TODO: Step 6: Unblock Writes (should be handled by transaction commit) */ + ereport(NOTICE, (errmsg("TODO: Step 6: Unblock Writes"))); + + PG_RETURN_VOID(); +} + + +/* + * GetReplicationLag calculates the replication lag between the primary and replica nodes. + * It returns the lag in bytes. + */ +static int64 +GetReplicationLag(WorkerNode *primaryWorkerNode, WorkerNode *replicaWorkerNode) +{ + +#if PG_VERSION_NUM >= 100000 + const char *primary_lsn_query = "SELECT pg_current_wal_lsn()"; + const char *replica_lsn_query = "SELECT pg_last_wal_replay_lsn()"; +#else + const char *primary_lsn_query = "SELECT pg_current_xlog_location()"; + const char *replica_lsn_query = "SELECT pg_last_xlog_replay_location()"; +#endif + + int connectionFlag = 0; + MultiConnection *primaryConnection = GetNodeConnection(connectionFlag, + primaryWorkerNode->workerName, + primaryWorkerNode->workerPort); + if (PQstatus(primaryConnection->pgConn) != CONNECTION_OK) + { + ereport(ERROR, (errmsg("cannot connect to %s:%d to fetch replication status", + primaryWorkerNode->workerName, primaryWorkerNode->workerPort))); + } + MultiConnection *replicaConnection = GetNodeConnection(connectionFlag, + replicaWorkerNode->workerName, + replicaWorkerNode->workerPort); + + if (PQstatus(replicaConnection->pgConn) != CONNECTION_OK) + { + ereport(ERROR, (errmsg("cannot connect to %s:%d to fetch replication status", + replicaWorkerNode->workerName, replicaWorkerNode->workerPort))); + } + + int primaryResultCode = SendRemoteCommand(primaryConnection, primary_lsn_query); + if (primaryResultCode == 0) + { + ReportConnectionError(primaryConnection, ERROR); + } + + PGresult *primaryResult = GetRemoteCommandResult(primaryConnection, true); + if (!IsResponseOK(primaryResult)) + { + ReportResultError(primaryConnection, primaryResult, ERROR); + } + + int replicaResultCode = SendRemoteCommand(replicaConnection, replica_lsn_query); + if (replicaResultCode == 0) + { + ReportConnectionError(replicaConnection, ERROR); + } + PGresult *replicaResult = GetRemoteCommandResult(replicaConnection, true); + if (!IsResponseOK(replicaResult)) + { + ReportResultError(replicaConnection, replicaResult, ERROR); + } + + + List *primaryLsnList = ReadFirstColumnAsText(primaryResult); + if (list_length(primaryLsnList) != 1) + { + PQclear(primaryResult); + ClearResults(primaryConnection, true); + CloseConnection(primaryConnection); + + ereport(ERROR, (errcode(ERRCODE_CONNECTION_FAILURE), + errmsg("cannot parse get primary LSN result from %s:%d", + primaryWorkerNode->workerName, + primaryWorkerNode->workerPort))); + + } + StringInfo primaryLsnQueryResInfo = (StringInfo) linitial(primaryLsnList); + char *primary_lsn_str = primaryLsnQueryResInfo->data; + + List 
*replicaLsnList = ReadFirstColumnAsText(replicaResult); + if (list_length(replicaLsnList) != 1) + { + PQclear(replicaResult); + ClearResults(replicaConnection, true); + CloseConnection(replicaConnection); + + ereport(ERROR, (errcode(ERRCODE_CONNECTION_FAILURE), + errmsg("cannot parse get replica LSN result from %s:%d", + replicaWorkerNode->workerName, + replicaWorkerNode->workerPort))); + + } + StringInfo replicaLsnQueryResInfo = (StringInfo) linitial(replicaLsnList); + char *replica_lsn_str = replicaLsnQueryResInfo->data; + + if (!primary_lsn_str || !replica_lsn_str) + return -1; + + int64 primary_lsn = DatumGetLSN(DirectFunctionCall1(pg_lsn_in, CStringGetDatum(primary_lsn_str))); + int64 replica_lsn = DatumGetLSN(DirectFunctionCall1(pg_lsn_in, CStringGetDatum(replica_lsn_str))); + + int64 lag_bytes = primary_lsn - replica_lsn; + + PQclear(primaryResult); + ForgetResults(primaryConnection); + CloseConnection(primaryConnection); + + PQclear(replicaResult); + ForgetResults(replicaConnection); + CloseConnection(replicaConnection); + + ereport(NOTICE, (errmsg("replication lag between %s:%d and %s:%d is %ld bytes", + primaryWorkerNode->workerName, primaryWorkerNode->workerPort, + replicaWorkerNode->workerName, replicaWorkerNode->workerPort, + lag_bytes))); + return lag_bytes; +} + +static void +PromoteReplicaNode(WorkerNode *replicaWorkerNode) +{ + int connectionFlag = 0; + MultiConnection *replicaConnection = GetNodeConnection(connectionFlag, + replicaWorkerNode->workerName, + replicaWorkerNode->workerPort); + + if (PQstatus(replicaConnection->pgConn) != CONNECTION_OK) + { + ereport(ERROR, (errmsg("cannot connect to %s:%d to promote replica", + replicaWorkerNode->workerName, replicaWorkerNode->workerPort))); + } + + const char *promoteQuery = "SELECT pg_promote(wait := true);"; + int resultCode = SendRemoteCommand(replicaConnection, promoteQuery); + if (resultCode == 0) + { + ReportConnectionError(replicaConnection, ERROR); + } + ForgetResults(replicaConnection); + CloseConnection(replicaConnection); + /* connect again and verify the replica is promoted */ + if ( GetNodeIsInRecoveryStatus(replicaWorkerNode) ) + { + ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), + errmsg("Failed to promote replica %s:%d (ID %d). 
It is still in recovery.", + replicaWorkerNode->workerName, replicaWorkerNode->workerPort, replicaWorkerNode->nodeId))); + } + else + { + ereport(NOTICE, (errmsg("Replica node %s:%d (ID %d) has been successfully promoted.", + replicaWorkerNode->workerName, replicaWorkerNode->workerPort, replicaWorkerNode->nodeId))); + } +} + +static void +BlockAllWritesToWorkerNode(WorkerNode *workerNode) +{ + ereport(NOTICE, (errmsg("Blocking all writes to worker node %s:%d (ID %d)", + workerNode->workerName, workerNode->workerPort, workerNode->nodeId))); + // List *placementsOnOldPrimaryGroup = AllShardPlacementsOnNodeGroup(workerNode->groupId); + + LockShardsInWorkerPlacementList(workerNode, AccessExclusiveLock); +} + +bool +GetNodeIsInRecoveryStatus(WorkerNode *workerNode) +{ + int connectionFlag = 0; + MultiConnection *nodeConnection = GetNodeConnection(connectionFlag, + workerNode->workerName, + workerNode->workerPort); + + if (PQstatus(nodeConnection->pgConn) != CONNECTION_OK) + { + ereport(ERROR, (errmsg("cannot connect to %s:%d to check recovery status", + workerNode->workerName, workerNode->workerPort))); + } + + const char *recoveryQuery = "SELECT pg_is_in_recovery();"; + int resultCode = SendRemoteCommand(nodeConnection, recoveryQuery); + if (resultCode == 0) + { + ReportConnectionError(nodeConnection, ERROR); + } + + PGresult *result = GetRemoteCommandResult(nodeConnection, true); + if (!IsResponseOK(result)) + { + ReportResultError(nodeConnection, result, ERROR); + } + + List *recoveryStatusList = ReadFirstColumnAsText(result); + if (list_length(recoveryStatusList) != 1) + { + PQclear(result); + ClearResults(nodeConnection, true); + CloseConnection(nodeConnection); + + ereport(ERROR, (errcode(ERRCODE_CONNECTION_FAILURE), + errmsg("cannot parse recovery status result from %s:%d", + workerNode->workerName, + workerNode->workerPort))); + } + + StringInfo recoveryStatusInfo = (StringInfo) linitial(recoveryStatusList); + bool isInRecovery = (strcmp(recoveryStatusInfo->data, "t") == 0) || (strcmp(recoveryStatusInfo->data, "true") == 0); + + PQclear(result); + ForgetResults(nodeConnection); + CloseConnection(nodeConnection); + + return isInRecovery; +} \ No newline at end of file diff --git a/src/backend/distributed/operations/shard_rebalancer.c b/src/backend/distributed/operations/shard_rebalancer.c index 074f1bed0..20dc7299e 100644 --- a/src/backend/distributed/operations/shard_rebalancer.c +++ b/src/backend/distributed/operations/shard_rebalancer.c @@ -81,9 +81,27 @@ typedef struct RebalanceOptions Form_pg_dist_rebalance_strategy rebalanceStrategy; const char *operationName; WorkerNode *workerNode; + List *involvedWorkerNodeList; } RebalanceOptions; +typedef struct SplitPrimaryReplicaShards +{ + /* + * primaryShardPlacementList contains the placements that + * should stay on primary worker node. + */ + List *primaryShardIdList; + /* + * replicaShardPlacementList contains the placements that should stay on + * replica worker node. + */ + List *replicaShardIdList; +} SplitPrimaryReplicaShards; + + +static SplitPrimaryReplicaShards * +GetPrimaryReplicaSplitRebalanceSteps(RebalanceOptions *options, WorkerNode* replicaNode); /* * RebalanceState is used to keep the internal state of the rebalance * algorithm in one place. 
@@ -318,6 +336,7 @@ PG_FUNCTION_INFO_V1(pg_dist_rebalance_strategy_enterprise_check); PG_FUNCTION_INFO_V1(citus_rebalance_start); PG_FUNCTION_INFO_V1(citus_rebalance_stop); PG_FUNCTION_INFO_V1(citus_rebalance_wait); +PG_FUNCTION_INFO_V1(get_snapshot_based_node_split_plan); bool RunningUnderCitusTestSuite = false; int MaxRebalancerLoggedIgnoredMoves = 5; @@ -517,8 +536,16 @@ GetRebalanceSteps(RebalanceOptions *options) .context = &context, }; + if (options->involvedWorkerNodeList == NULL) + { + /* + * If the user did not specify a list of worker nodes, we use all the + * active worker nodes. + */ + options->involvedWorkerNodeList = SortedActiveWorkers(); + } /* sort the lists to make the function more deterministic */ - List *activeWorkerList = SortedActiveWorkers(); + List *activeWorkerList = options->involvedWorkerNodeList; //SortedActiveWorkers(); int shardAllowedNodeCount = 0; WorkerNode *workerNode = NULL; foreach_declared_ptr(workerNode, activeWorkerList) @@ -981,6 +1008,7 @@ rebalance_table_shards(PG_FUNCTION_ARGS) .excludedShardArray = PG_GETARG_ARRAYTYPE_P(3), .drainOnly = PG_GETARG_BOOL(5), .rebalanceStrategy = strategy, + .involvedWorkerNodeList = NULL, .improvementThreshold = strategy->improvementThreshold, }; Oid shardTransferModeOid = PG_GETARG_OID(4); @@ -3546,6 +3574,342 @@ EnsureShardCostUDF(Oid functionOid) ReleaseSysCache(proctup); } +/* + * SplitShardsBetweenPrimaryAndReplica splits the shards in shardPlacementList + * between the primary and replica nodes, adding them to the respective lists. + */ +void +SplitShardsBetweenPrimaryAndReplica(WorkerNode *primaryNode, + WorkerNode *replicaNode, + Name strategyName) +{ + CheckCitusVersion(ERROR); + + List *relationIdList = NIL; + relationIdList = NonColocatedDistRelationIdList(); + + Form_pg_dist_rebalance_strategy strategy = GetRebalanceStrategy(strategyName);/* We use default strategy for now */ + + RebalanceOptions options = { + .relationIdList = relationIdList, + .threshold = 0, /* Threshold is not strictly needed for two nodes */ + .maxShardMoves = -1, /* No limit on moves between these two nodes */ + .excludedShardArray = construct_empty_array(INT8OID), + .drainOnly = false, /* Not a drain operation */ + .rebalanceStrategy = strategy, + .improvementThreshold = 0, /* Consider all beneficial moves */ + .workerNode = primaryNode /* indicate Primary node as a source node */ + }; + + SplitPrimaryReplicaShards *splitShards = NULL; + splitShards = GetPrimaryReplicaSplitRebalanceSteps(&options, replicaNode); + AdjustShardsForPrimaryReplicaNodeSplit(primaryNode, replicaNode, + splitShards->primaryShardIdList, splitShards->replicaShardIdList); +} + +/* + * GetPrimaryReplicaSplitRebalanceSteps returns a List of PlacementUpdateEvents that are needed to + * rebalance a list of tables. 
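+ * In this primary/replica variant the plan is returned as a SplitPrimaryReplicaShards
+ * struct: primaryShardIdList holds the shard IDs that stay on the source (primary)
+ * node, and replicaShardIdList holds the shard IDs assigned to the promoted replica,
+ * chosen by a greedy cost-balancing pass over the placements currently on the primary.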
+ */ +static SplitPrimaryReplicaShards * +GetPrimaryReplicaSplitRebalanceSteps(RebalanceOptions *options, WorkerNode* replicaNode) +{ + WorkerNode *sourceNode = options->workerNode; + WorkerNode *targetNode = replicaNode; + + /* Initialize rebalance plan functions and context */ + EnsureShardCostUDF(options->rebalanceStrategy->shardCostFunction); + EnsureNodeCapacityUDF(options->rebalanceStrategy->nodeCapacityFunction); + EnsureShardAllowedOnNodeUDF(options->rebalanceStrategy->shardAllowedOnNodeFunction); + + RebalanceContext context; + memset(&context, 0, sizeof(RebalanceContext)); + fmgr_info(options->rebalanceStrategy->shardCostFunction, &context.shardCostUDF); + fmgr_info(options->rebalanceStrategy->nodeCapacityFunction, &context.nodeCapacityUDF); + fmgr_info(options->rebalanceStrategy->shardAllowedOnNodeFunction, + &context.shardAllowedOnNodeUDF); + + RebalancePlanFunctions rebalancePlanFunctions = { + .shardAllowedOnNode = ShardAllowedOnNode, + .nodeCapacity = NodeCapacity, + .shardCost = GetShardCost, + .context = &context, + }; + + /* + * Collect all active shard placements on the source node for the given relations. + * Unlike the main rebalancer, we build a single list of all relevant source placements + * across all specified relations (or all relations if none specified). + */ + List *allSourcePlacements = NIL; + Oid relationIdItr = InvalidOid; + foreach_declared_oid(relationIdItr, options->relationIdList) + { + List *shardPlacementList = FullShardPlacementList(relationIdItr, + options->excludedShardArray); + List *activeShardPlacementsForRelation = + FilterShardPlacementList(shardPlacementList, IsActiveShardPlacement); + + ShardPlacement *placement = NULL; + foreach_declared_ptr(placement, activeShardPlacementsForRelation) + { + if (placement->nodeId == sourceNode->nodeId) + { + /* Ensure we don't add duplicate shardId if it's somehow listed under multiple relations */ + bool alreadyAdded = false; + ShardPlacement *existingPlacement = NULL; + foreach_declared_ptr(existingPlacement, allSourcePlacements) + { + if (existingPlacement->shardId == placement->shardId) + { + alreadyAdded = true; + break; + } + } + if (!alreadyAdded) + { + allSourcePlacements = lappend(allSourcePlacements, placement); + } + } + } + } + + List *activeWorkerList = list_make2(options->workerNode, replicaNode); + SplitPrimaryReplicaShards *splitShards = palloc0(sizeof(SplitPrimaryReplicaShards)); + splitShards->primaryShardIdList = NIL; + splitShards->replicaShardIdList = NIL; + + if (list_length(allSourcePlacements) > 0) + { + /* + * Initialize RebalanceState considering only the source node's shards + * and the two active workers (source and target). + */ + RebalanceState *state = InitRebalanceState(activeWorkerList, allSourcePlacements, &rebalancePlanFunctions); + + NodeFillState *sourceFillState = NULL; + NodeFillState *targetFillState = NULL; + ListCell *fsc = NULL; + + /* Identify the fill states for our specific source and target nodes */ + foreach(fsc, state->fillStateListAsc) /* Could be fillStateListDesc too, order doesn't matter here */ + { + NodeFillState *fs = (NodeFillState *) lfirst(fsc); + if (fs->node->nodeId == sourceNode->nodeId) + { + sourceFillState = fs; + } + else if (fs->node->nodeId == targetNode->nodeId) + { + targetFillState = fs; + } + } + + if (sourceFillState != NULL && targetFillState != NULL) + { + /* + * The goal is to move roughly half the total cost from source to target. 
+ * The target node is assumed to be empty or its existing load is not + * considered for this specific two-node balancing plan's shard distribution. + * We calculate costs based *only* on the shards currently on the source node. + */ + /* + * The core idea is to simulate the balancing process between these two nodes. + * We have all shards on sourceFillState. TargetFillState is initially empty (in terms of these specific shards). + * We want to move shards from source to target until their costs are as balanced as possible. + */ + float4 sourceCurrentCost = sourceFillState->totalCost; + float4 targetCurrentCost = 0; /* Representing cost on target from these source shards */ + + /* Sort shards on source node by cost (descending). This is a common heuristic. */ + sourceFillState->shardCostListDesc = SortList(sourceFillState->shardCostListDesc, CompareShardCostDesc); + + List *potentialMoves = NIL; + ListCell *lc_shardcost = NULL; + + /* + * Iterate through each shard on the source node. For each shard, decide if moving it + * to the target node would improve the balance (or is necessary to reach balance). + * A simple greedy approach: move shard if target node's current cost is less than source's. + */ + foreach(lc_shardcost, sourceFillState->shardCostListDesc) + { + ShardCost *shardToConsider = (ShardCost *) lfirst(lc_shardcost); + + /* Check if shard is allowed on the target node */ + // if (!state->functions->shardAllowedOnNode(shardToConsider->shardId, + // targetNode, + // state->functions->context)) + // { + // splitShards->primaryShardIdList = lappend_int(splitShards->primaryShardIdList, shardToConsider->shardId); + // continue; /* Cannot move this shard to the target */ + // } + + /* + * If moving this shard makes the target less loaded than the source would become, + * or if target is simply less loaded currently, consider the move. + * More accurately, we move if target's cost + shard's cost < source's cost - shard's cost (approximately) + * or if target is significantly emptier. + * The condition (targetCurrentCost < sourceCurrentCost - shardToConsider->cost) is a greedy choice. + * A better check: would moving this shard reduce the difference in costs? + * Current difference: abs(sourceCurrentCost - targetCurrentCost) + * Difference after move: abs((sourceCurrentCost - shardToConsider->cost) - (targetCurrentCost + shardToConsider->cost)) + * Move if new difference is smaller. 
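+			 * For illustration (costs assumed for this example, not read from the
+			 * strategy): with shard costs {4, 3, 2, 1} on the source, the pass moves
+			 * the cost-4 shard (diff 10 -> 2), keeps the cost-3 and cost-2 shards
+			 * (moving either would not shrink the diff), then moves the cost-1 shard
+			 * (diff 2 -> 0), leaving a 5 vs 5 split.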
+ */ + float4 costOfShard = shardToConsider->cost; + float4 diffBefore = fabsf(sourceCurrentCost - targetCurrentCost); + float4 diffAfter = fabsf((sourceCurrentCost - costOfShard) - (targetCurrentCost + costOfShard)); + + if (diffAfter < diffBefore) + { + PlacementUpdateEvent *update = palloc0(sizeof(PlacementUpdateEvent)); + update->shardId = shardToConsider->shardId; + update->sourceNode = sourceNode; + update->targetNode = targetNode; + update->updateType = PLACEMENT_UPDATE_MOVE; + potentialMoves = lappend(potentialMoves, update); + splitShards->replicaShardIdList = lappend_int(splitShards->replicaShardIdList, shardToConsider->shardId); + + + /* Update simulated costs for the next iteration */ + sourceCurrentCost -= costOfShard; + targetCurrentCost += costOfShard; + } + else + { + splitShards->primaryShardIdList = lappend_int(splitShards->primaryShardIdList, shardToConsider->shardId); + } + } + } + /* RebalanceState is in memory context, will be cleaned up */ + } + return splitShards; +} + +/* + * Snapshot-based node split plan outputs the shard placement plan + * for primary and replica based node split + * + * SQL signature: + * get_snapshot_based_node_split_plan( + * primary_node_name text, + * primary_node_port integer, + * replica_node_name text, + * replica_node_port integer, + * rebalance_strategy name DEFAULT NULL + * + */ +Datum +get_snapshot_based_node_split_plan(PG_FUNCTION_ARGS) +{ + CheckCitusVersion(ERROR); + + text *primaryNodeNameText = PG_GETARG_TEXT_P(0); + int32 primaryNodePort = PG_GETARG_INT32(1); + text *replicaNodeNameText = PG_GETARG_TEXT_P(2); + int32 replicaNodePort = PG_GETARG_INT32(3); + + char *primaryNodeName = text_to_cstring(primaryNodeNameText); + char *replicaNodeName = text_to_cstring(replicaNodeNameText); + + WorkerNode *primaryNode = FindWorkerNodeOrError(primaryNodeName, primaryNodePort); + WorkerNode *replicaNode = FindWorkerNodeOrError(replicaNodeName, replicaNodePort); + + if (!replicaNode->nodeisreplica || replicaNode->nodeprimarynodeid == 0) + { + ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), + errmsg("Node %s:%d (ID %d) is not a valid replica or its primary node ID is not set.", + replicaNode->workerName, replicaNode->workerPort, replicaNode->nodeId))); + } + if (primaryNode->nodeisreplica) + { + ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), + errmsg("Primary node %s:%d (ID %d) is itself a replica.", + primaryNode->workerName, primaryNode->workerPort, primaryNode->nodeId))); + } + /* Ensure the primary node is related to the replica node */ + if (primaryNode->nodeId != replicaNode->nodeprimarynodeid) + { + ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), + errmsg("Replica node %s:%d (ID %d) is not replica of the primary node %s:%d (ID %d).", + replicaNode->workerName, replicaNode->workerPort, replicaNode->nodeId, + primaryNode->workerName, primaryNode->workerPort, primaryNode->nodeId))); + } + + List *relationIdList = NIL; + relationIdList = NonColocatedDistRelationIdList(); + + Form_pg_dist_rebalance_strategy strategy = GetRebalanceStrategy( + PG_GETARG_NAME_OR_NULL(4)); + + RebalanceOptions options = { + .relationIdList = relationIdList, + .threshold = 0, /* Threshold is not strictly needed for two nodes */ + .maxShardMoves = -1, /* No limit on moves between these two nodes */ + .excludedShardArray = construct_empty_array(INT8OID), + .drainOnly = false, /* Not a drain operation */ + .rebalanceStrategy = strategy, + .improvementThreshold = 0, /* Consider all beneficial moves */ + 
.workerNode = primaryNode /* indicate Primary node as a source node */ + }; + + SplitPrimaryReplicaShards *splitShards = NULL; + splitShards = GetPrimaryReplicaSplitRebalanceSteps(&options, replicaNode); + + if (splitShards == NULL) + { + ereport(ERROR, (errmsg("No shards to split between primary and replica nodes."))); + } + + int shardId = 0; + TupleDesc tupdesc; + Tuplestorestate *tupstore = SetupTuplestore(fcinfo, &tupdesc); + Datum values[4]; + bool nulls[4]; + + + foreach_declared_int(shardId, splitShards->primaryShardIdList) + { + ShardInterval *shardInterval = LoadShardInterval(shardId); + List *colocatedShardList = ColocatedShardIntervalList(shardInterval); + ListCell *colocatedShardCell = NULL; + foreach(colocatedShardCell, colocatedShardList) + { + ShardInterval *colocatedShard = lfirst(colocatedShardCell); + int colocatedShardId = colocatedShard->shardId; + memset(values, 0, sizeof(values)); + memset(nulls, 0, sizeof(nulls)); + + values[0] = ObjectIdGetDatum(RelationIdForShard(colocatedShardId)); + values[1] = UInt64GetDatum(colocatedShardId); + values[2] = UInt64GetDatum(ShardLength(colocatedShardId)); + values[3] = PointerGetDatum(cstring_to_text("Primary Node")); + tuplestore_putvalues(tupstore, tupdesc, values, nulls); + } + } + + foreach_declared_int(shardId, splitShards->replicaShardIdList) + { + ShardInterval *shardInterval = LoadShardInterval(shardId); + List *colocatedShardList = ColocatedShardIntervalList(shardInterval); + ListCell *colocatedShardCell = NULL; + foreach(colocatedShardCell, colocatedShardList) + { + ShardInterval *colocatedShard = lfirst(colocatedShardCell); + int colocatedShardId = colocatedShard->shardId; + memset(values, 0, sizeof(values)); + memset(nulls, 0, sizeof(nulls)); + + values[0] = ObjectIdGetDatum(RelationIdForShard(colocatedShardId)); + values[1] = UInt64GetDatum(colocatedShardId); + values[2] = UInt64GetDatum(ShardLength(colocatedShardId)); + values[3] = PointerGetDatum(cstring_to_text("Replica Node")); + tuplestore_putvalues(tupstore, tupdesc, values, nulls); + } + } + + return (Datum) 0; +} /* * EnsureNodeCapacityUDF checks that the UDF matching the oid has the correct diff --git a/src/backend/distributed/operations/shard_transfer.c b/src/backend/distributed/operations/shard_transfer.c index b7d07b2cf..3922b862e 100644 --- a/src/backend/distributed/operations/shard_transfer.c +++ b/src/backend/distributed/operations/shard_transfer.c @@ -573,6 +573,44 @@ TransferShards(int64 shardId, char *sourceNodeName, FinalizeCurrentProgressMonitor(); } +/* + * AdjustShardsForPrimaryReplicaNodeSplit is called when a primary-replica node split + * occurs. It adjusts the shard placements such that the shards that should be on the + * primary node are removed from the replica node, and vice versa. + * + * This function does not move any data; it only updates the shard placement metadata. + */ +void +AdjustShardsForPrimaryReplicaNodeSplit(WorkerNode *primaryNode, + WorkerNode *replicaNode, + List* primaryShardList, + List* replicaShardList) +{ + int shardId = 0; + /* + * Remove all shards from the replica that should reside on the primary node, + * and update the shard placement metadata for shards that will now be served + * from the replica node. No data movement is required; we only need to drop + * the relevant shards from the replica and primary nodes and update the + * corresponding shard placement metadata. 
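+	 * The drops themselves are left as TODOs below; one possible approach (an
+	 * assumption, not implemented here) would be to register deferred cleanup
+	 * records via InsertDeferredDropCleanupRecordsForShards() rather than
+	 * dropping the shard tables inline.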
+ */ + foreach_declared_int(shardId, primaryShardList) + { + ShardInterval *shardInterval = LoadShardInterval(shardId); + List *colocatedShardList = ColocatedShardIntervalList(shardInterval); + /* TODO: Drops shard table here */ + } + /* Now drop all shards from primary that need to be on the replica node */ + foreach_declared_int(shardId, replicaShardList) + { + ShardInterval *shardInterval = LoadShardInterval(shardId); + List *colocatedShardList = ColocatedShardIntervalList(shardInterval); + UpdateColocatedShardPlacementMetadataOnWorkers(shardId, + primaryNode->workerName, primaryNode->workerPort, + replicaNode->workerName, replicaNode->workerPort); + /* TODO: Drop the not required table on primary here */ + } +} /* * Insert deferred cleanup records. diff --git a/src/backend/distributed/sql/citus--13.0-1--13.1-1.sql b/src/backend/distributed/sql/citus--13.0-1--13.1-1.sql index 76ae1f596..009dc2dd2 100644 --- a/src/backend/distributed/sql/citus--13.0-1--13.1-1.sql +++ b/src/backend/distributed/sql/citus--13.0-1--13.1-1.sql @@ -55,7 +55,8 @@ DROP VIEW IF EXISTS pg_catalog.citus_lock_waits; #include "cat_upgrades/add_replica_info_to_pg_dist_node.sql" #include "udfs/citus_add_replica_node/13.1-1.sql" #include "udfs/citus_remove_replica_node/13.1-1.sql" - +#include "udfs/citus_promote_replica_and_rebalance/13.1-1.sql" +#include "udfs/get_snapshot_based_node_split_plan/13.1-1.sql" -- Since shard_name/13.1-1.sql first drops the function and then creates it, we first -- need to drop citus_shards view since that view depends on this function. And immediately -- after creating the function, we recreate citus_shards view again. diff --git a/src/backend/distributed/sql/udfs/citus_promote_replica_and_rebalance/13.1-1.sql b/src/backend/distributed/sql/udfs/citus_promote_replica_and_rebalance/13.1-1.sql new file mode 100644 index 000000000..274e1f727 --- /dev/null +++ b/src/backend/distributed/sql/udfs/citus_promote_replica_and_rebalance/13.1-1.sql @@ -0,0 +1,12 @@ +CREATE OR REPLACE FUNCTION pg_catalog.citus_promote_replica_and_rebalance( + replica_nodeid integer, + rebalance_strategy name DEFAULT NULL +) +RETURNS VOID +AS 'MODULE_PATHNAME' +LANGUAGE C VOLATILE; + +COMMENT ON FUNCTION pg_catalog.citus_promote_replica_and_rebalance(integer, name) IS +'Promotes a registered replica node to a primary, performs necessary metadata updates, and rebalances a portion of shards from its original primary to the newly promoted node.'; + +REVOKE ALL ON FUNCTION pg_catalog.citus_promote_replica_and_rebalance(integer, name) FROM PUBLIC; diff --git a/src/backend/distributed/sql/udfs/get_snapshot_based_node_split_plan/13.1-1.sql b/src/backend/distributed/sql/udfs/get_snapshot_based_node_split_plan/13.1-1.sql new file mode 100644 index 000000000..f2d294315 --- /dev/null +++ b/src/backend/distributed/sql/udfs/get_snapshot_based_node_split_plan/13.1-1.sql @@ -0,0 +1,18 @@ +CREATE OR REPLACE FUNCTION pg_catalog.get_snapshot_based_node_split_plan( + primary_node_name text, + primary_node_port integer, + replica_node_name text, + replica_node_port integer, + rebalance_strategy name DEFAULT NULL + ) + RETURNS TABLE (table_name regclass, + shardid bigint, + shard_size bigint, + placement_node text) + AS 'MODULE_PATHNAME' + LANGUAGE C VOLATILE; + +COMMENT ON FUNCTION pg_catalog.get_snapshot_based_node_split_plan(text, int, text, int, name) + IS 'shows the shard placements to balance shards between primary and replica worker nodes'; + +REVOKE ALL ON FUNCTION pg_catalog.get_snapshot_based_node_split_plan(text, int, text, int, 
name) FROM PUBLIC; diff --git a/src/include/distributed/metadata_utility.h b/src/include/distributed/metadata_utility.h index 38c13eb51..fd146b576 100644 --- a/src/include/distributed/metadata_utility.h +++ b/src/include/distributed/metadata_utility.h @@ -466,4 +466,7 @@ extern bool IsBackgroundJobStatusTerminal(BackgroundJobStatus status); extern bool IsBackgroundTaskStatusTerminal(BackgroundTaskStatus status); extern Oid BackgroundJobStatusOid(BackgroundJobStatus status); extern Oid BackgroundTaskStatusOid(BackgroundTaskStatus status); +/* from node_metadata.c */ +extern void LockShardsInWorkerPlacementList(WorkerNode *workerNode, LOCKMODE lockMode); +extern void ActivateReplicaNodeAsPrimary(WorkerNode *workerNode); #endif /* METADATA_UTILITY_H */ diff --git a/src/include/distributed/shard_rebalancer.h b/src/include/distributed/shard_rebalancer.h index 79414eb3c..8ea5fb1d0 100644 --- a/src/include/distributed/shard_rebalancer.h +++ b/src/include/distributed/shard_rebalancer.h @@ -222,4 +222,7 @@ extern void SetupRebalanceMonitor(List *placementUpdateList, uint64 initialProgressState, PlacementUpdateStatus initialStatus); +extern void SplitShardsBetweenPrimaryAndReplica(WorkerNode *primaryNode, + WorkerNode *replicaNode, + Name strategyName); #endif /* SHARD_REBALANCER_H */ diff --git a/src/include/distributed/shard_transfer.h b/src/include/distributed/shard_transfer.h index c1621879b..0d7b641a9 100644 --- a/src/include/distributed/shard_transfer.h +++ b/src/include/distributed/shard_transfer.h @@ -41,3 +41,9 @@ extern void UpdatePlacementUpdateStatusForShardIntervalList(List *shardIntervalL extern void InsertDeferredDropCleanupRecordsForShards(List *shardIntervalList); extern void InsertCleanupRecordsForShardPlacementsOnNode(List *shardIntervalList, int32 groupId); + +extern void +AdjustShardsForPrimaryReplicaNodeSplit(WorkerNode *primaryNode, + WorkerNode *replicaNode, + List* primaryShardList, + List* replicaShardList); \ No newline at end of file