diff --git a/src/backend/distributed/commands/multi_copy.c b/src/backend/distributed/commands/multi_copy.c index 44442c0ab..ea29b9ba7 100644 --- a/src/backend/distributed/commands/multi_copy.c +++ b/src/backend/distributed/commands/multi_copy.c @@ -253,7 +253,8 @@ static CopyShardState * GetShardState(uint64 shardId, HTAB *shardStateHash, copyOutState, bool isCopyToIntermediateFile); static MultiConnection * CopyGetPlacementConnection(HTAB *connectionStateHash, ShardPlacement *placement, - bool stopOnFailure); + bool stopOnFailure, + bool colocatedIntermediateResult); static bool HasReachedAdaptiveExecutorPoolSize(List *connectionStateHash); static MultiConnection * GetLeastUtilisedCopyConnection(List *connectionStateList, char *nodeName, int nodePort); @@ -2230,7 +2231,10 @@ CitusCopyDestReceiverStartup(DestReceiver *dest, int operation, /* define the template for the COPY statement that is sent to workers */ CopyStmt *copyStatement = makeNode(CopyStmt); - if (copyDest->intermediateResultIdPrefix != NULL) + + bool colocatedIntermediateResults = + copyDest->intermediateResultIdPrefix != NULL; + if (colocatedIntermediateResults) { copyStatement->relation = makeRangeVar(NULL, copyDest->intermediateResultIdPrefix, -1); @@ -3448,7 +3452,8 @@ InitializeCopyShardState(CopyShardState *shardState, } MultiConnection *connection = - CopyGetPlacementConnection(connectionStateHash, placement, stopOnFailure); + CopyGetPlacementConnection(connectionStateHash, placement, stopOnFailure, + isCopyToIntermediateFile); if (connection == NULL) { failedPlacementCount++; @@ -3544,11 +3549,40 @@ LogLocalCopyExecution(uint64 shardId) * then it reuses the connection. Otherwise, it requests a connection for placement. 
*/ static MultiConnection * -CopyGetPlacementConnection(HTAB *connectionStateHash, ShardPlacement *placement, bool - stopOnFailure) +CopyGetPlacementConnection(HTAB *connectionStateHash, ShardPlacement *placement, + bool stopOnFailure, bool colocatedIntermediateResult) { - uint32 connectionFlags = FOR_DML; - char *nodeUser = CurrentUserName(); + if (colocatedIntermediateResult) + { + /* + * Colocated intermediate results are just files and not required to use + * the same connections with their co-located shards. So, we are free to + * use any connection we can get. + * + * Also, the current connection re-use logic does not know how to handle + * intermediate results as the intermediate results always truncate the + * existing files. That's why we use one connection per intermediate + * result. + * + * Also note that we are breaking the guarantees of citus.shared_pool_size + * as we cannot rely on optional connections. + */ + uint32 connectionFlagsForIntermediateResult = 0; + MultiConnection *connection = + GetNodeConnection(connectionFlagsForIntermediateResult, placement->nodeName, + placement->nodePort); + + /* + * As noted above, we want each intermediate file to go over + * a separate connection. 
+ */ + ClaimConnectionExclusively(connection); + + /* and, we cannot afford to handle failures when anything goes wrong */ + MarkRemoteTransactionCritical(connection); + + return connection; + } /* * Determine whether the task has to be assigned to a particular connection @@ -3556,10 +3590,10 @@ CopyGetPlacementConnection(HTAB *connectionStateHash, ShardPlacement *placement, */ ShardPlacementAccess *placementAccess = CreatePlacementAccess(placement, PLACEMENT_ACCESS_DML); - MultiConnection *connection = GetConnectionIfPlacementAccessedInXact(connectionFlags, - list_make1( - placementAccess), - NULL); + uint32 connectionFlags = FOR_DML; + MultiConnection *connection = + GetConnectionIfPlacementAccessedInXact(connectionFlags, + list_make1(placementAccess), NULL); if (connection != NULL) { return connection; @@ -3605,6 +3639,7 @@ CopyGetPlacementConnection(HTAB *connectionStateHash, ShardPlacement *placement, connectionFlags |= CONNECTION_PER_PLACEMENT; } + char *nodeUser = CurrentUserName(); connection = GetPlacementConnection(connectionFlags, placement, nodeUser); if (PQstatus(connection->pgConn) != CONNECTION_OK)