diff --git a/src/backend/distributed/operations/worker_shard_copy.c b/src/backend/distributed/operations/worker_shard_copy.c index a891b2ee2..173d89c55 100644 --- a/src/backend/distributed/operations/worker_shard_copy.c +++ b/src/backend/distributed/operations/worker_shard_copy.c @@ -17,6 +17,8 @@ #include "parser/parse_relation.h" #include "utils/builtins.h" #include "utils/lsyscache.h" +#include "executor/spi.h" +#include "miscadmin.h" #include "distributed/commands/multi_copy.h" #include "distributed/connection_management.h" @@ -85,6 +87,7 @@ static void LocalCopyToShard(ShardCopyDestReceiver *copyDest, CopyOutState static void ConnectToRemoteAndStartCopy(ShardCopyDestReceiver *copyDest); static StringInfo ConstructShardTruncateStatement( List *destinationShardFullyQualifiedName); +static void TruncateShardForCopy(Oid shardOid); static bool @@ -366,20 +369,10 @@ ShardCopyDestReceiverShutdown(DestReceiver *dest) ResetReplicationOriginRemoteSession(copyDest->connection); - /* End the transaction by sending a COMMIT command */ - if (!SendRemoteCommand(copyDest->connection, "COMMIT")) - { - HandleRemoteTransactionConnectionError(copyDest->connection, true); - } - - PGresult *commitResult = GetRemoteCommandResult(copyDest->connection, true); - if (!IsResponseOK(result)) - { - ereport(ERROR, (errcode(ERRCODE_TRANSACTION_RESOLUTION_UNKNOWN), - errmsg("Failed to commit transaction"))); - } - - PQclear(commitResult); + PQclear(result); + ForgetResults(copyDest->connection); + RemoteTransactionCommit(copyDest->connection); + ResetRemoteTransaction(copyDest->connection); CloseConnection(copyDest->connection); } @@ -548,60 +541,104 @@ WriteLocalTuple(TupleTableSlot *slot, ShardCopyDestReceiver *copyDest) /* - * LocalCopyToShard performs local copy for the given destination shard. + * Truncate the table before starting the COPY with FREEZE. */ +static void +TruncateShardForCopy(Oid shardOid) +{ + Relation shard = table_open(shardOid, AccessExclusiveLock); + + /* Execute the TRUNCATE */ + char *shardRelationName = RelationGetRelationName(shard); + char *schemaName = get_namespace_name(RelationGetNamespace(shard)); + StringInfo truncateQuery = makeStringInfo(); + appendStringInfo(truncateQuery, "TRUNCATE %s.%s", quote_identifier(schemaName), quote_identifier(shardRelationName)); + + /* Initialize SPI */ + if (SPI_connect() != SPI_OK_CONNECT) + { + ereport(ERROR, (errcode(ERRCODE_INTERNAL_ERROR), + errmsg("could not connect to SPI manager"))); + } + + /* Execute the TRUNCATE command */ + int spiResult = SPI_execute(truncateQuery->data, false, 0); + if (spiResult != SPI_OK_UTILITY) + { + ereport(ERROR, (errcode(ERRCODE_INTERNAL_ERROR), + errmsg("TRUNCATE command failed"))); + } + + /* Finalize SPI */ + SPI_finish(); + + /* Release lock */ + table_close(shard, NoLock); +} + + + static void LocalCopyToShard(ShardCopyDestReceiver *copyDest, CopyOutState localCopyOutState) { - bool isBinaryCopy = localCopyOutState->binary; - if (isBinaryCopy) - { - AppendCopyBinaryFooters(localCopyOutState); - } + bool isBinaryCopy = localCopyOutState->binary; + + if (isBinaryCopy) + { + AppendCopyBinaryFooters(localCopyOutState); + } - /* - * Set the buffer as a global variable to allow ReadFromLocalBufferCallback - * to read from it. We cannot pass additional arguments to - * ReadFromLocalBufferCallback. - */ - LocalCopyBuffer = localCopyOutState->fe_msgbuf; + LocalCopyBuffer = localCopyOutState->fe_msgbuf; - char *destinationShardSchemaName = linitial( - copyDest->destinationShardFullyQualifiedName); - char *destinationShardRelationName = lsecond( - copyDest->destinationShardFullyQualifiedName); + char *destinationShardSchemaName = linitial(copyDest->destinationShardFullyQualifiedName); + char *destinationShardRelationName = lsecond(copyDest->destinationShardFullyQualifiedName); - Oid destinationSchemaOid = get_namespace_oid(destinationShardSchemaName, - false /* missing_ok */); - Oid destinationShardOid = get_relname_relid(destinationShardRelationName, - destinationSchemaOid); + Oid destinationSchemaOid = get_namespace_oid(destinationShardSchemaName, false); + Oid destinationShardOid = get_relname_relid(destinationShardRelationName, destinationSchemaOid); - DefElem *binaryFormatOption = NULL; - if (isBinaryCopy) - { - binaryFormatOption = makeDefElem("format", (Node *) makeString("binary"), -1); - } + /* Truncate the destination shard before performing COPY FREEZE */ + set_config_option("citus.enable_manual_changes_to_shards", + "on", /* Always set to "on" */ + (superuser() ? PGC_SUSET : PGC_USERSET), /* Allow superusers to change the setting at SUSET level */ + PGC_S_SESSION, /* Session level scope */ + GUC_ACTION_LOCAL, /* Local action within the session */ + true, /* Change in the current transaction */ + 0, /* No GUC source specified */ + false /* Do not report errors if already set */ + ); - Relation shard = table_open(destinationShardOid, RowExclusiveLock); - ParseState *pState = make_parsestate(NULL /* parentParseState */); - (void) addRangeTableEntryForRelation(pState, shard, AccessShareLock, - NULL /* alias */, false /* inh */, - false /* inFromCl */); + TruncateShardForCopy(destinationShardOid); - List *options = (isBinaryCopy) ? list_make1(binaryFormatOption) : NULL; - CopyFromState cstate = BeginCopyFrom(pState, shard, - NULL /* whereClause */, - NULL /* fileName */, - false /* is_program */, - ReadFromLocalBufferCallback, - NULL /* attlist (NULL is all columns) */, - options); - CopyFrom(cstate); - EndCopyFrom(cstate); - resetStringInfo(localCopyOutState->fe_msgbuf); + DefElem *binaryFormatOption = NULL; + if (isBinaryCopy) + { + binaryFormatOption = makeDefElem("format", (Node *) makeString("binary"), -1); + } - table_close(shard, NoLock); - free_parsestate(pState); + DefElem *freezeOption = makeDefElem("freeze", (Node *) makeString("true"), -1); + + Relation shard = table_open(destinationShardOid, RowExclusiveLock); + ParseState *pState = make_parsestate(NULL); + (void) addRangeTableEntryForRelation(pState, shard, AccessShareLock, NULL, false, false); + + List *options = NIL; + if (isBinaryCopy) + { + options = list_make2(binaryFormatOption, freezeOption); + } + else + { + options = list_make1(freezeOption); + } + + CopyFromState cstate = BeginCopyFrom(pState, shard, NULL, NULL, false, ReadFromLocalBufferCallback, NULL, options); + CopyFrom(cstate); + EndCopyFrom(cstate); + + resetStringInfo(localCopyOutState->fe_msgbuf); + + table_close(shard, NoLock); + free_parsestate(pState); }