diff --git a/src/backend/distributed/master/master_modify_multiple_shards.c b/src/backend/distributed/master/master_modify_multiple_shards.c index d54407341..e5827098d 100644 --- a/src/backend/distributed/master/master_modify_multiple_shards.c +++ b/src/backend/distributed/master/master_modify_multiple_shards.c @@ -26,8 +26,7 @@ #include "commands/event_trigger.h" #include "distributed/citus_clauses.h" #include "distributed/citus_ruleutils.h" -#include "distributed/commit_protocol.h" -#include "distributed/connection_cache.h" +#include "distributed/connection_management.h" #include "distributed/listutils.h" #include "distributed/master_metadata_utility.h" #include "distributed/master_protocol.h" @@ -37,9 +36,9 @@ #include "distributed/multi_router_executor.h" #include "distributed/multi_router_planner.h" #include "distributed/multi_server_executor.h" -#include "distributed/multi_shard_transaction.h" #include "distributed/pg_dist_shard.h" #include "distributed/pg_dist_partition.h" +#include "distributed/remote_commands.h" #include "distributed/resource_lock.h" #include "distributed/worker_protocol.h" #include "optimizer/clauses.h" @@ -57,9 +56,6 @@ static void LockShardsForModify(List *shardIntervalList); static bool HasReplication(List *shardIntervalList); -static int SendQueryToShards(Query *query, List *shardIntervalList, Oid relationId); -static int SendQueryToPlacements(char *shardQueryString, - ShardConnections *shardConnections); PG_FUNCTION_INFO_V1(master_modify_multiple_shards); @@ -91,6 +87,8 @@ master_modify_multiple_shards(PG_FUNCTION_ARGS) PreventTransactionChain(isTopLevel, "master_modify_multiple_shards"); + BeginCoordinatedTransaction(); + queryTreeNode = ParseTreeNode(queryString); if (IsA(queryTreeNode, DeleteStmt)) { @@ -161,12 +159,10 @@ master_modify_multiple_shards(PG_FUNCTION_ARGS) prunedShardIntervalList = PruneShardList(relationId, tableId, restrictClauseList, shardIntervalList); - CHECK_FOR_INTERRUPTS(); - LockShardsForModify(prunedShardIntervalList); - affectedTupleCount = SendQueryToShards(modifyQuery, prunedShardIntervalList, - relationId); + affectedTupleCount = ExecuteQueryOnPlacements(modifyQuery, prunedShardIntervalList, + relationId); PG_RETURN_INT32(affectedTupleCount); } @@ -227,119 +223,3 @@ HasReplication(List *shardIntervalList) return hasReplication; } - - -/* - * SendQueryToShards executes the given query in all placements of the given - * shard list and returns the total affected tuple count. The execution is done - * in a distributed transaction and the commit protocol is decided according to - * the value of citus.multi_shard_commit_protocol parameter. SendQueryToShards - * does not acquire locks for the shards so it is advised to acquire locks to - * the shards when necessary before calling SendQueryToShards. - */ -static int -SendQueryToShards(Query *query, List *shardIntervalList, Oid relationId) -{ - int affectedTupleCount = 0; - char *relationOwner = TableOwner(relationId); - ListCell *shardIntervalCell = NULL; - - OpenTransactionsToAllShardPlacements(shardIntervalList, relationOwner); - - foreach(shardIntervalCell, shardIntervalList) - { - ShardInterval *shardInterval = (ShardInterval *) lfirst( - shardIntervalCell); - Oid relationId = shardInterval->relationId; - uint64 shardId = shardInterval->shardId; - bool shardConnectionsFound = false; - ShardConnections *shardConnections = NULL; - StringInfo shardQueryString = makeStringInfo(); - char *shardQueryStringData = NULL; - int shardAffectedTupleCount = -1; - - shardConnections = GetShardConnections(shardId, &shardConnectionsFound); - Assert(shardConnectionsFound); - - deparse_shard_query(query, relationId, shardId, shardQueryString); - shardQueryStringData = shardQueryString->data; - shardAffectedTupleCount = SendQueryToPlacements(shardQueryStringData, - shardConnections); - affectedTupleCount += shardAffectedTupleCount; - } - - /* check for cancellation one last time before returning */ - CHECK_FOR_INTERRUPTS(); - - return affectedTupleCount; -} - - -/* - * SendQueryToPlacements sends the given query string to all given placement - * connections of a shard. CommitRemoteTransactions or AbortRemoteTransactions - * should be called after all queries have been sent successfully. - */ -static int -SendQueryToPlacements(char *shardQueryString, ShardConnections *shardConnections) -{ - uint64 shardId = shardConnections->shardId; - List *connectionList = shardConnections->connectionList; - ListCell *connectionCell = NULL; - int32 shardAffectedTupleCount = -1; - - Assert(connectionList != NIL); - - foreach(connectionCell, connectionList) - { - TransactionConnection *transactionConnection = - (TransactionConnection *) lfirst(connectionCell); - PGconn *connection = transactionConnection->connection; - PGresult *result = NULL; - char *placementAffectedTupleString = NULL; - int32 placementAffectedTupleCount = -1; - - CHECK_FOR_INTERRUPTS(); - - /* send the query */ - result = PQexec(connection, shardQueryString); - if (PQresultStatus(result) != PGRES_COMMAND_OK) - { - WarnRemoteError(connection, result); - ereport(ERROR, (errmsg("could not send query to shard placement"))); - } - - placementAffectedTupleString = PQcmdTuples(result); - - /* returned tuple count is empty for utility commands, use 0 as affected count */ - if (*placementAffectedTupleString == '\0') - { - placementAffectedTupleCount = 0; - } - else - { - placementAffectedTupleCount = pg_atoi(placementAffectedTupleString, - sizeof(int32), 0); - } - - if ((shardAffectedTupleCount == -1) || - (shardAffectedTupleCount == placementAffectedTupleCount)) - { - shardAffectedTupleCount = placementAffectedTupleCount; - } - else - { - ereport(ERROR, - (errmsg("modified %d tuples, but expected to modify %d", - placementAffectedTupleCount, shardAffectedTupleCount), - errdetail("Affected tuple counts at placements of shard " - UINT64_FORMAT " are different.", shardId))); - } - - PQclear(result); - - transactionConnection->transactionState = TRANSACTION_STATE_OPEN; - } - - return shardAffectedTupleCount; -} diff --git a/src/test/regress/expected/multi_shard_modify.out b/src/test/regress/expected/multi_shard_modify.out index d722c3ad3..23a781f16 100644 --- a/src/test/regress/expected/multi_shard_modify.out +++ b/src/test/regress/expected/multi_shard_modify.out @@ -103,10 +103,6 @@ SELECT master_modify_multiple_shards('DELETE FROM multi_shard_modify_test WHERE DEBUG: predicate pruning for shardId 350001 DEBUG: predicate pruning for shardId 350002 DEBUG: predicate pruning for shardId 350003 -DEBUG: sent PREPARE TRANSACTION over connection 350000 -DEBUG: sent PREPARE TRANSACTION over connection 350000 -DEBUG: sent COMMIT PREPARED over connection 350000 -DEBUG: sent COMMIT PREPARED over connection 350000 master_modify_multiple_shards ------------------------------- 1