diff --git a/src/backend/distributed/master/master_stage_protocol.c b/src/backend/distributed/master/master_stage_protocol.c index 60f39103a..2f964d603 100644 --- a/src/backend/distributed/master/master_stage_protocol.c +++ b/src/backend/distributed/master/master_stage_protocol.c @@ -483,6 +483,7 @@ void CreateShardsOnWorkers(Oid distributedRelationId, List *shardPlacements, bool useExclusiveConnection, bool colocatedShard) { + DistTableCacheEntry *cacheEntry = DistributedTableCacheEntry(distributedRelationId); char *placementOwner = TableOwner(distributedRelationId); bool includeSequenceDefaults = false; List *ddlCommandList = GetTableDDLEvents(distributedRelationId, @@ -509,6 +510,12 @@ CreateShardsOnWorkers(Oid distributedRelationId, List *shardPlacements, BeginOrContinueCoordinatedTransaction(); + if (MultiShardCommitProtocol == COMMIT_PROTOCOL_2PC || + cacheEntry->replicationModel == REPLICATION_MODEL_2PC) + { + CoordinatedTransactionUse2PC(); + } + foreach(shardPlacementCell, shardPlacements) { ShardPlacement *shardPlacement = (ShardPlacement *) lfirst(shardPlacementCell); diff --git a/src/test/regress/expected/multi_transaction_recovery.out b/src/test/regress/expected/multi_transaction_recovery.out index a5ec64d3f..e156d3724 100644 --- a/src/test/regress/expected/multi_transaction_recovery.out +++ b/src/test/regress/expected/multi_transaction_recovery.out @@ -63,6 +63,7 @@ SELECT count(*) FROM pg_tables WHERE tablename = 'should_commit'; SET citus.shard_replication_factor TO 2; SET citus.shard_count TO 2; SET citus.multi_shard_commit_protocol TO '2pc'; +-- create_distributed_table should add 2 recovery records (1 connection per node) CREATE TABLE test_recovery (x text); SELECT create_distributed_table('test_recovery', 'x'); create_distributed_table @@ -70,7 +71,40 @@ SELECT create_distributed_table('test_recovery', 'x'); (1 row) +SELECT count(*) FROM pg_dist_transaction; + count +------- + 2 +(1 row) + +-- create_reference_table should add another 2 recovery records +CREATE TABLE test_recovery_ref (x text); +SELECT create_reference_table('test_recovery_ref'); + create_reference_table +------------------------ + +(1 row) + +SELECT count(*) FROM pg_dist_transaction; + count +------- + 4 +(1 row) + +SELECT recover_prepared_transactions(); + recover_prepared_transactions +------------------------------- + 0 +(1 row) + +-- plain INSERT does not use 2PC INSERT INTO test_recovery VALUES ('hello'); +SELECT count(*) FROM pg_dist_transaction; + count +------- + 0 +(1 row) + -- Committed DDL commands should write 4 transaction recovery records BEGIN; ALTER TABLE test_recovery ADD COLUMN y text; diff --git a/src/test/regress/sql/multi_transaction_recovery.sql b/src/test/regress/sql/multi_transaction_recovery.sql index b566080c9..29d97c799 100644 --- a/src/test/regress/sql/multi_transaction_recovery.sql +++ b/src/test/regress/sql/multi_transaction_recovery.sql @@ -40,9 +40,21 @@ SET citus.shard_replication_factor TO 2; SET citus.shard_count TO 2; SET citus.multi_shard_commit_protocol TO '2pc'; +-- create_distributed_table should add 2 recovery records (1 connection per node) CREATE TABLE test_recovery (x text); SELECT create_distributed_table('test_recovery', 'x'); +SELECT count(*) FROM pg_dist_transaction; + +-- create_reference_table should add another 2 recovery records +CREATE TABLE test_recovery_ref (x text); +SELECT create_reference_table('test_recovery_ref'); +SELECT count(*) FROM pg_dist_transaction; + +SELECT recover_prepared_transactions(); + +-- plain INSERT does not use 2PC INSERT INTO test_recovery VALUES ('hello'); +SELECT count(*) FROM pg_dist_transaction; -- Committed DDL commands should write 4 transaction recovery records BEGIN;