diff --git a/src/backend/distributed/executor/multi_router_executor.c b/src/backend/distributed/executor/multi_router_executor.c index a9c833dd7..ee5b579a2 100644 --- a/src/backend/distributed/executor/multi_router_executor.c +++ b/src/backend/distributed/executor/multi_router_executor.c @@ -925,12 +925,19 @@ ExecuteModifyTasks(List *taskList, bool expectResults, ParamListInfo paramListIn { int64 totalAffectedTupleCount = 0; ListCell *taskCell = NULL; - char *userName = CurrentUserName(); + Task *firstTask = NULL; + int connectionFlags = 0; List *shardIntervalList = NIL; List *affectedTupleCountList = NIL; + HTAB *shardConnectionHash = NULL; bool tasksPending = true; int placementIndex = 0; + if (taskList == NIL) + { + return 0; + } + if (XactModificationLevel == XACT_MODIFICATION_DATA) { ereport(ERROR, (errcode(ERRCODE_ACTIVE_SQL_TRANSACTION), @@ -944,8 +951,28 @@ ExecuteModifyTasks(List *taskList, bool expectResults, ParamListInfo paramListIn /* ensure that there are no concurrent modifications on the same shards */ AcquireExecutorMultiShardLocks(taskList); + BeginOrContinueCoordinatedTransaction(); + + firstTask = (Task *) linitial(taskList); + + if (MultiShardCommitProtocol == COMMIT_PROTOCOL_2PC || + firstTask->replicationModel == REPLICATION_MODEL_2PC) + { + CoordinatedTransactionUse2PC(); + } + + if (firstTask->taskType == DDL_TASK) + { + connectionFlags = FOR_DDL; + } + else + { + connectionFlags = FOR_DML; + } + /* open connection to all relevant placements, if not already open */ - OpenTransactionsToAllShardPlacements(shardIntervalList, userName); + shardConnectionHash = OpenTransactionsToAllShardPlacements(shardIntervalList, + connectionFlags); XactModificationLevel = XACT_MODIFICATION_MULTI_SHARD; @@ -968,7 +995,8 @@ ExecuteModifyTasks(List *taskList, bool expectResults, ParamListInfo paramListIn MultiConnection *connection = NULL; bool queryOK = false; - shardConnections = GetShardConnections(shardId, &shardConnectionsFound); + shardConnections = GetShardHashConnections(shardConnectionHash, shardId, + &shardConnectionsFound); connectionList = shardConnections->connectionList; if (placementIndex >= list_length(connectionList)) @@ -1003,7 +1031,8 @@ ExecuteModifyTasks(List *taskList, bool expectResults, ParamListInfo paramListIn /* abort in case of cancellation */ CHECK_FOR_INTERRUPTS(); - shardConnections = GetShardConnections(shardId, &shardConnectionsFound); + shardConnections = GetShardHashConnections(shardConnectionHash, shardId, + &shardConnectionsFound); connectionList = shardConnections->connectionList; if (placementIndex >= list_length(connectionList)) @@ -1074,6 +1103,8 @@ ExecuteModifyTasks(List *taskList, bool expectResults, ParamListInfo paramListIn placementIndex++; } + UnclaimAllShardConnections(shardConnectionHash); + CHECK_FOR_INTERRUPTS(); return totalAffectedTupleCount; diff --git a/src/backend/distributed/executor/multi_utility.c b/src/backend/distributed/executor/multi_utility.c index 78a9898cf..e324c7e13 100644 --- a/src/backend/distributed/executor/multi_utility.c +++ b/src/backend/distributed/executor/multi_utility.c @@ -1124,7 +1124,7 @@ VacuumTaskList(Oid relationId, VacuumStmt *vacuumStmt) task = CitusMakeNode(Task); task->jobId = jobId; task->taskId = taskId++; - task->taskType = SQL_TASK; + task->taskType = DDL_TASK; task->queryString = pstrdup(vacuumString->data); task->dependedTaskList = NULL; task->replicationModel = REPLICATION_MODEL_INVALID; @@ -2094,7 +2094,7 @@ DDLTaskList(Oid relationId, const char *commandString) task = CitusMakeNode(Task); task->jobId = jobId; task->taskId = taskId++; - task->taskType = SQL_TASK; + task->taskType = DDL_TASK; task->queryString = applyCommand->data; task->replicationModel = REPLICATION_MODEL_INVALID; task->dependedTaskList = NULL; @@ -2158,7 +2158,7 @@ ForeignKeyTaskList(Oid leftRelationId, Oid rightRelationId, task = CitusMakeNode(Task); task->jobId = jobId; task->taskId = taskId++; - task->taskType = SQL_TASK; + task->taskType = DDL_TASK; task->queryString = applyCommand->data; task->dependedTaskList = NULL; task->replicationModel = REPLICATION_MODEL_INVALID; diff --git a/src/backend/distributed/transaction/multi_shard_transaction.c b/src/backend/distributed/transaction/multi_shard_transaction.c index 132b75d5e..d5962f57b 100644 --- a/src/backend/distributed/transaction/multi_shard_transaction.c +++ b/src/backend/distributed/transaction/multi_shard_transaction.c @@ -18,6 +18,7 @@ #include "distributed/master_metadata_utility.h" #include "distributed/metadata_cache.h" #include "distributed/multi_shard_transaction.h" +#include "distributed/placement_connection.h" #include "distributed/shardinterval_utils.h" #include "distributed/worker_manager.h" #include "nodes/pg_list.h" @@ -25,34 +26,23 @@ #include "utils/memutils.h" -#define INITIAL_CONNECTION_CACHE_SIZE 1001 - - -/* per-transaction state */ -static HTAB *shardConnectionHash = NULL; +#define INITIAL_SHARD_CONNECTION_HASH_SIZE 128 /* * OpenTransactionsToAllShardPlacements opens connections to all placements - * using the provided shard identifier list. Connections accumulate in a global - * shardConnectionHash variable for use (and re-use) within this transaction. + * using the provided shard identifier list and returns it as a shard ID -> + * ShardConnections hash. connectionFlags can be used to specify whether + * the command is FOR_DML or FOR_DDL. */ -void -OpenTransactionsToAllShardPlacements(List *shardIntervalList, char *userName) +HTAB * +OpenTransactionsToAllShardPlacements(List *shardIntervalList, int connectionFlags) { + HTAB *shardConnectionHash = NULL; ListCell *shardIntervalCell = NULL; List *newConnectionList = NIL; - if (shardConnectionHash == NULL) - { - shardConnectionHash = CreateShardConnectionHash(TopTransactionContext); - } - - BeginOrContinueCoordinatedTransaction(); - if (MultiShardCommitProtocol == COMMIT_PROTOCOL_2PC) - { - CoordinatedTransactionUse2PC(); - } + shardConnectionHash = CreateShardConnectionHash(CurrentMemoryContext); /* open connections to shards which don't have connections yet */ foreach(shardIntervalCell, shardIntervalList) @@ -64,7 +54,8 @@ OpenTransactionsToAllShardPlacements(List *shardIntervalList, char *userName) List *shardPlacementList = NIL; ListCell *placementCell = NULL; - shardConnections = GetShardConnections(shardId, &shardConnectionsFound); + shardConnections = GetShardHashConnections(shardConnectionHash, shardId, + &shardConnectionsFound); if (shardConnectionsFound) { continue; @@ -82,7 +73,6 @@ OpenTransactionsToAllShardPlacements(List *shardIntervalList, char *userName) { ShardPlacement *shardPlacement = (ShardPlacement *) lfirst(placementCell); MultiConnection *connection = NULL; - MemoryContext oldContext = NULL; WorkerNode *workerNode = FindWorkerNode(shardPlacement->nodeName, shardPlacement->nodePort); @@ -93,20 +83,15 @@ OpenTransactionsToAllShardPlacements(List *shardIntervalList, char *userName) shardPlacement->nodePort))); } - connection = StartNodeUserDatabaseConnection(FORCE_NEW_CONNECTION, - shardPlacement->nodeName, - shardPlacement->nodePort, - userName, - NULL); + connection = StartPlacementConnection(connectionFlags, + shardPlacement, + NULL); - /* we need to preserve the connection list for the next statement */ - oldContext = MemoryContextSwitchTo(TopTransactionContext); + ClaimConnectionExclusively(connection); shardConnections->connectionList = lappend(shardConnections->connectionList, connection); - MemoryContextSwitchTo(oldContext); - newConnectionList = lappend(newConnectionList, connection); /* @@ -125,6 +110,8 @@ OpenTransactionsToAllShardPlacements(List *shardIntervalList, char *userName) { RemoteTransactionsBeginIfNecessary(newConnectionList); } + + return shardConnectionHash; } @@ -147,36 +134,13 @@ CreateShardConnectionHash(MemoryContext memoryContext) hashFlags = (HASH_ELEM | HASH_CONTEXT | HASH_BLOBS); shardConnectionsHash = hash_create("Shard Connections Hash", - INITIAL_CONNECTION_CACHE_SIZE, &info, + INITIAL_SHARD_CONNECTION_HASH_SIZE, &info, hashFlags); return shardConnectionsHash; } -/* - * GetShardConnections finds existing connections for a shard in the global - * connection hash. If not found, then a ShardConnections structure with empty - * connectionList is returned and the shardConnectionsFound output parameter - * will be set to false. - */ -ShardConnections * -GetShardConnections(int64 shardId, bool *shardConnectionsFound) -{ - ShardConnections *shardConnections = NULL; - - ShardInterval *shardInterval = LoadShardInterval(shardId); - List *colocatedShardIds = ColocatedShardIntervalList(shardInterval); - ShardInterval *baseShardInterval = LowestShardIntervalById(colocatedShardIds); - int64 baseShardId = baseShardInterval->shardId; - - shardConnections = GetShardHashConnections(shardConnectionHash, baseShardId, - shardConnectionsFound); - - return shardConnections; -} - - /* * GetShardHashConnections finds existing connections for a shard in the * provided hash. If not found, then a ShardConnections structure with empty @@ -235,16 +199,37 @@ ShardConnectionList(HTAB *connectionHash) void ResetShardPlacementTransactionState(void) { - /* - * Now that transaction management does most of our work, nothing remains - * but to reset the connection hash, which wouldn't be valid next time - * round. - */ - shardConnectionHash = NULL; - if (MultiShardCommitProtocol == COMMIT_PROTOCOL_BARE) { MultiShardCommitProtocol = SavedMultiShardCommitProtocol; SavedMultiShardCommitProtocol = COMMIT_PROTOCOL_BARE; } } + + +/* + * UnclaimAllShardConnections unclaims all connections in the given + * shard connections hash after previously claiming them exclusively + * in OpenTransactionsToAllShardPlacements. + */ +void +UnclaimAllShardConnections(HTAB *shardConnectionHash) +{ + HASH_SEQ_STATUS status; + ShardConnections *shardConnections = NULL; + + hash_seq_init(&status, shardConnectionHash); + + while ((shardConnections = hash_seq_search(&status)) != 0) + { + List *connectionList = shardConnections->connectionList; + ListCell *connectionCell = NULL; + + foreach(connectionCell, connectionList) + { + MultiConnection *connection = (MultiConnection *) lfirst(connectionCell); + + UnclaimConnection(connection); + } + } +} diff --git a/src/include/distributed/multi_physical_planner.h b/src/include/distributed/multi_physical_planner.h index 61f5baf73..1949afadc 100644 --- a/src/include/distributed/multi_physical_planner.h +++ b/src/include/distributed/multi_physical_planner.h @@ -82,7 +82,8 @@ typedef enum MAP_OUTPUT_FETCH_TASK = 5, MERGE_FETCH_TASK = 6, MODIFY_TASK = 7, - ROUTER_TASK = 8 + ROUTER_TASK = 8, + DDL_TASK = 9 } TaskType; diff --git a/src/include/distributed/multi_shard_transaction.h b/src/include/distributed/multi_shard_transaction.h index 36c6b51cc..9f5baf0c5 100644 --- a/src/include/distributed/multi_shard_transaction.h +++ b/src/include/distributed/multi_shard_transaction.h @@ -27,13 +27,14 @@ typedef struct ShardConnections } ShardConnections; -extern void OpenTransactionsToAllShardPlacements(List *shardIdList, char *relationOwner); +extern HTAB * OpenTransactionsToAllShardPlacements(List *shardIdList, + int connectionFlags); extern HTAB * CreateShardConnectionHash(MemoryContext memoryContext); -extern ShardConnections * GetShardConnections(int64 shardId, bool *shardConnectionsFound); extern ShardConnections * GetShardHashConnections(HTAB *connectionHash, int64 shardId, bool *connectionsFound); extern List * ShardConnectionList(HTAB *connectionHash); extern void ResetShardPlacementTransactionState(void); +extern void UnclaimAllShardConnections(HTAB *shardConnectionHash); #endif /* MULTI_SHARD_TRANSACTION_H */