diff --git a/src/backend/distributed/executor/multi_router_executor.c b/src/backend/distributed/executor/multi_router_executor.c index d8cfd88db..ccc19e616 100644 --- a/src/backend/distributed/executor/multi_router_executor.c +++ b/src/backend/distributed/executor/multi_router_executor.c @@ -80,7 +80,7 @@ static ShardPlacementAccess * CreatePlacementAccess(ShardPlacement *placement, static void ExecuteSingleModifyTask(CitusScanState *scanState, Task *task, bool expectResults); static void ExecuteSingleSelectTask(CitusScanState *scanState, Task *task); -static List * GetModifyConnections(List *taskPlacementList, bool markCritical, +static List * GetModifyConnections(Task *task, bool markCritical, bool startedInTransaction); static void ExecuteMultipleTasks(CitusScanState *scanState, List *taskList, bool isModificationQuery, bool expectResults); @@ -781,7 +781,7 @@ ExecuteSingleModifyTask(CitusScanState *scanState, Task *task, bool expectResult * establish the connection, mark as critical (when modifying reference * table) and start a transaction (when in a transaction). */ - connectionList = GetModifyConnections(taskPlacementList, + connectionList = GetModifyConnections(task, taskRequiresTwoPhaseCommit, startedInTransaction); @@ -884,10 +884,12 @@ ExecuteSingleModifyTask(CitusScanState *scanState, Task *task, bool expectResult * transaction in progress. */ static List * -GetModifyConnections(List *taskPlacementList, bool markCritical, bool noNewTransactions) +GetModifyConnections(Task *task, bool markCritical, bool noNewTransactions) { + List *taskPlacementList = task->taskPlacementList; ListCell *taskPlacementCell = NULL; List *multiConnectionList = NIL; + List *relationShardList = task->relationShardList; /* first initiate connection establishment for all necessary connections */ foreach(taskPlacementCell, taskPlacementList) @@ -895,14 +897,22 @@ GetModifyConnections(List *taskPlacementList, bool markCritical, bool noNewTrans ShardPlacement *taskPlacement = (ShardPlacement *) lfirst(taskPlacementCell); int connectionFlags = SESSION_LIFESPAN | FOR_DML; MultiConnection *multiConnection = NULL; + List *placementAccessList = NIL; + ShardPlacementAccess *placementModification = NULL; - /* - * FIXME: It's not actually correct to use only one shard placement - * here for router queries involving multiple relations. We should - * check that this connection is the only modifying one associated - * with all the involved shards. - */ - multiConnection = StartPlacementConnection(connectionFlags, taskPlacement, NULL); + /* create placement accesses for placements that appear in a subselect */ + placementAccessList = BuildPlacementSelectList(taskPlacement->nodeName, + taskPlacement->nodePort, + relationShardList); + + /* create placement access for the placement that we're modifying */ + placementModification = CreatePlacementAccess(taskPlacement, + PLACEMENT_ACCESS_DML); + placementAccessList = lappend(placementAccessList, placementModification); + + /* get an appropriate connection for the DML statement */ + multiConnection = GetPlacementListConnection(connectionFlags, placementAccessList, + NULL); /* * If already in a transaction, disallow expanding set of remote