diff --git a/src/backend/distributed/commands/create_distributed_table.c b/src/backend/distributed/commands/create_distributed_table.c index 60caa58a5..aa76f338d 100644 --- a/src/backend/distributed/commands/create_distributed_table.c +++ b/src/backend/distributed/commands/create_distributed_table.c @@ -1258,12 +1258,15 @@ CopyLocalDataIntoShards(Oid distributedRelationId) ExprContext *econtext = GetPerTupleExprContext(estate); econtext->ecxt_scantuple = slot; + /* here we already have the data locally */ + bool hasCopyDataLocally = true; copyDest = (DestReceiver *) CreateCitusCopyDestReceiver(distributedRelationId, columnNameList, partitionColumnIndex, estate, stopOnFailure, - NULL); + NULL, + hasCopyDataLocally); /* initialise state for writing to shards, we'll open connections on demand */ copyDest->rStartup(copyDest, 0, tupleDescriptor); diff --git a/src/backend/distributed/commands/local_multi_copy.c b/src/backend/distributed/commands/local_multi_copy.c new file mode 100644 index 000000000..54fedc2a5 --- /dev/null +++ b/src/backend/distributed/commands/local_multi_copy.c @@ -0,0 +1,241 @@ +/*------------------------------------------------------------------------- + * + * local_multi_copy.c + * Commands for running a copy locally + * + * For each local placement, we have a buffer. When we receive a slot + * from a copy, the slot will be put to the corresponding buffer based + * on the shard id. When the buffer size exceeds the threshold a local + * copy will be done. Also If we reach to the end of copy, we will send + * the current buffer for local copy. + * + * The existing logic from multi_copy.c and format are used, therefore + * even if user did not do a copy with binary format, it is possible that + * we are going to be using binary format internally. + * + * + * Copyright (c) Citus Data, Inc. + * + *------------------------------------------------------------------------- + */ + +#include "postgres.h" +#include "commands/copy.h" +#include "catalog/namespace.h" +#include "parser/parse_relation.h" +#include "utils/lsyscache.h" +#include "nodes/makefuncs.h" +#include "safe_lib.h" +#include /* for htons */ + +#include "distributed/transmit.h" +#include "distributed/commands/multi_copy.h" +#include "distributed/multi_partitioning_utils.h" +#include "distributed/local_executor.h" +#include "distributed/local_multi_copy.h" +#include "distributed/shard_utils.h" + +/* + * LOCAL_COPY_BUFFER_SIZE is buffer size for local copy. + * There will be one buffer for each local placement, therefore + * the maximum amount of memory that might be alocated is + * LOCAL_COPY_BUFFER_SIZE * #local_placement + */ +#define LOCAL_COPY_BUFFER_SIZE (1 * 512 * 1024) + + +static int ReadFromLocalBufferCallback(void *outbuf, int minread, int maxread); +static Relation CreateCopiedShard(RangeVar *distributedRel, Relation shard); +static void AddSlotToBuffer(TupleTableSlot *slot, CitusCopyDestReceiver *copyDest, + bool isBinary); + +static bool ShouldSendCopyNow(StringInfo buffer); +static void DoLocalCopy(StringInfo buffer, Oid relationId, int64 shardId, + CopyStmt *copyStatement, bool isEndOfCopy); +static bool ShouldAddBinaryHeaders(StringInfo buffer, bool isBinary); + +/* + * localCopyBuffer is used in copy callback to return the copied rows. + * The reason this is a global variable is that we cannot pass an additional + * argument to the copy callback. + */ +StringInfo localCopyBuffer; + +/* + * ProcessLocalCopy adds the given slot and does a local copy if + * this is the end of copy, or the buffer size exceeds the threshold. + */ +void +ProcessLocalCopy(TupleTableSlot *slot, CitusCopyDestReceiver *copyDest, int64 shardId, + StringInfo buffer, bool isEndOfCopy) +{ + /* + * Here we save the previous buffer, and put the local shard's buffer + * into copyOutState. The motivation is to use the existing logic to + * serialize a row slot into buffer. + */ + StringInfo previousBuffer = copyDest->copyOutState->fe_msgbuf; + copyDest->copyOutState->fe_msgbuf = buffer; + + /* since we are doing a local copy, the following statements should use local execution to see the changes */ + TransactionAccessedLocalPlacement = true; + + bool isBinaryCopy = copyDest->copyOutState->binary; + AddSlotToBuffer(slot, copyDest, isBinaryCopy); + + if (isEndOfCopy || ShouldSendCopyNow(buffer)) + { + if (isBinaryCopy) + { + AppendCopyBinaryFooters(copyDest->copyOutState); + } + + DoLocalCopy(buffer, copyDest->distributedRelationId, shardId, + copyDest->copyStatement, isEndOfCopy); + } + copyDest->copyOutState->fe_msgbuf = previousBuffer; +} + + +/* + * AddSlotToBuffer serializes the given slot and adds it to the buffer in copyDest. + * If the copy format is binary, it adds binary headers as well. + */ +static void +AddSlotToBuffer(TupleTableSlot *slot, CitusCopyDestReceiver *copyDest, bool isBinary) +{ + if (ShouldAddBinaryHeaders(copyDest->copyOutState->fe_msgbuf, isBinary)) + { + AppendCopyBinaryHeaders(copyDest->copyOutState); + } + + if (slot != NULL) + { + Datum *columnValues = slot->tts_values; + bool *columnNulls = slot->tts_isnull; + FmgrInfo *columnOutputFunctions = copyDest->columnOutputFunctions; + CopyCoercionData *columnCoercionPaths = copyDest->columnCoercionPaths; + + AppendCopyRowData(columnValues, columnNulls, copyDest->tupleDescriptor, + copyDest->copyOutState, columnOutputFunctions, + columnCoercionPaths); + } +} + + +/* + * ShouldSendCopyNow returns true if the given buffer size exceeds the + * local copy buffer size threshold. + */ +static bool +ShouldSendCopyNow(StringInfo buffer) +{ + return buffer->len > LOCAL_COPY_BUFFER_SIZE; +} + + +/* + * DoLocalCopy finds the shard table from the distributed relation id, and copies the given + * buffer into the shard. + */ +static void +DoLocalCopy(StringInfo buffer, Oid relationId, int64 shardId, CopyStmt *copyStatement, + bool isEndOfCopy) +{ + localCopyBuffer = buffer; + + Oid shardOid = GetShardLocalTableOid(relationId, shardId); + Relation shard = heap_open(shardOid, RowExclusiveLock); + Relation copiedShard = CreateCopiedShard(copyStatement->relation, shard); + ParseState *pState = make_parsestate(NULL); + + /* p_rtable of pState is set so that we can check constraints. */ + pState->p_rtable = CreateRangeTable(copiedShard, ACL_INSERT); + + CopyState cstate = BeginCopyFrom(pState, copiedShard, NULL, false, + ReadFromLocalBufferCallback, + copyStatement->attlist, copyStatement->options); + CopyFrom(cstate); + EndCopyFrom(cstate); + + heap_close(shard, NoLock); + free_parsestate(pState); + FreeStringInfo(buffer); + if (!isEndOfCopy) + { + buffer = makeStringInfo(); + } +} + + +/* + * ShouldAddBinaryHeaders returns true if the given buffer + * is empty and the format is binary. + */ +static bool +ShouldAddBinaryHeaders(StringInfo buffer, bool isBinary) +{ + if (!isBinary) + { + return false; + } + return buffer->len == 0; +} + + +/* + * CreateCopiedShard clones deep copies the necessary fields of the given + * relation. + */ +Relation +CreateCopiedShard(RangeVar *distributedRel, Relation shard) +{ + TupleDesc tupleDescriptor = RelationGetDescr(shard); + + Relation copiedDistributedRelation = (Relation) palloc(sizeof(RelationData)); + Form_pg_class copiedDistributedRelationTuple = (Form_pg_class) palloc( + CLASS_TUPLE_SIZE); + + *copiedDistributedRelation = *shard; + *copiedDistributedRelationTuple = *shard->rd_rel; + + copiedDistributedRelation->rd_rel = copiedDistributedRelationTuple; + copiedDistributedRelation->rd_att = CreateTupleDescCopyConstr(tupleDescriptor); + + Oid tableId = RangeVarGetRelid(distributedRel, NoLock, false); + + /* + * BeginCopyFrom opens all partitions of given partitioned table with relation_open + * and it expects its caller to close those relations. We do not have direct access + * to opened relations, thus we are changing relkind of partitioned tables so that + * Postgres will treat those tables as regular relations and will not open its + * partitions. + */ + if (PartitionedTable(tableId)) + { + copiedDistributedRelationTuple->relkind = RELKIND_RELATION; + } + return copiedDistributedRelation; +} + + +/* + * ReadFromLocalBufferCallback is the copy callback. + * It always tries to copy maxread bytes. + */ +static int +ReadFromLocalBufferCallback(void *outbuf, int minread, int maxread) +{ + int bytesread = 0; + int avail = localCopyBuffer->len - localCopyBuffer->cursor; + int bytesToRead = Min(avail, maxread); + if (bytesToRead > 0) + { + memcpy_s(outbuf, bytesToRead + strlen((char *) outbuf), + &localCopyBuffer->data[localCopyBuffer->cursor], bytesToRead); + } + bytesread += bytesToRead; + localCopyBuffer->cursor += bytesToRead; + + return bytesread; +} diff --git a/src/backend/distributed/commands/multi_copy.c b/src/backend/distributed/commands/multi_copy.c index 0c5d52d0e..8d2dd9928 100644 --- a/src/backend/distributed/commands/multi_copy.c +++ b/src/backend/distributed/commands/multi_copy.c @@ -85,6 +85,8 @@ #include "distributed/shard_pruning.h" #include "distributed/version_compat.h" #include "distributed/worker_protocol.h" +#include "distributed/local_multi_copy.h" +#include "distributed/hash_helpers.h" #include "executor/executor.h" #include "foreign/foreign.h" #include "libpq/libpq.h" @@ -162,6 +164,8 @@ struct CopyPlacementState /* State of shard to which the placement belongs to. */ CopyShardState *shardState; + int32 groupId; + /* * Buffered COPY data. When the placement is activePlacementState of * some connection, this is empty. Because in that case we directly @@ -178,6 +182,12 @@ struct CopyShardState /* Used as hash key. */ uint64 shardId; + /* used for doing local copy */ + StringInfo localCopyBuffer; + + /* containsLocalPlacement is true if we have a local placement for the shard id of this state */ + bool containsLocalPlacement; + /* List of CopyPlacementStates for all active placements of the shard. */ List *placementStateList; }; @@ -232,13 +242,15 @@ static CopyConnectionState * GetConnectionState(HTAB *connectionStateHash, MultiConnection *connection); static CopyShardState * GetShardState(uint64 shardId, HTAB *shardStateHash, HTAB *connectionStateHash, bool stopOnFailure, - bool *found); + bool *found, bool shouldUseLocalCopy, MemoryContext + context); static MultiConnection * CopyGetPlacementConnection(ShardPlacement *placement, bool stopOnFailure); static List * ConnectionStateList(HTAB *connectionStateHash); static void InitializeCopyShardState(CopyShardState *shardState, HTAB *connectionStateHash, - uint64 shardId, bool stopOnFailure); + uint64 shardId, bool stopOnFailure, bool + canUseLocalCopy, MemoryContext context); static void StartPlacementStateCopyCommand(CopyPlacementState *placementState, CopyStmt *copyStatement, CopyOutState copyOutState); @@ -274,6 +286,10 @@ static bool CitusCopyDestReceiverReceive(TupleTableSlot *slot, DestReceiver *copyDest); static void CitusCopyDestReceiverShutdown(DestReceiver *destReceiver); static void CitusCopyDestReceiverDestroy(DestReceiver *destReceiver); +static bool ContainsLocalPlacement(int64 shardId); +static void FinishLocalCopy(CitusCopyDestReceiver *copyDest); +static bool ShouldExecuteCopyLocally(void); +static void LogLocalCopyExecution(uint64 shardId); /* exports for SQL callable functions */ @@ -415,9 +431,12 @@ CopyToExistingShards(CopyStmt *copyStatement, char *completionTag) stopOnFailure = true; } + bool hasCopyDataLocally = false; + /* set up the destination for the COPY */ copyDest = CreateCitusCopyDestReceiver(tableId, columnNameList, partitionColumnIndex, - executorState, stopOnFailure, NULL); + executorState, stopOnFailure, NULL, + hasCopyDataLocally); dest = (DestReceiver *) copyDest; dest->rStartup(dest, 0, tupleDescriptor); @@ -1960,11 +1979,13 @@ CopyFlushOutput(CopyOutState cstate, char *start, char *pointer) CitusCopyDestReceiver * CreateCitusCopyDestReceiver(Oid tableId, List *columnNameList, int partitionColumnIndex, EState *executorState, bool stopOnFailure, - char *intermediateResultIdPrefix) + char *intermediateResultIdPrefix, bool hasCopyDataLocally) { CitusCopyDestReceiver *copyDest = (CitusCopyDestReceiver *) palloc0( sizeof(CitusCopyDestReceiver)); + copyDest->shouldUseLocalCopy = !hasCopyDataLocally && ShouldExecuteCopyLocally(); + /* set up the DestReceiver function pointers */ copyDest->pub.receiveSlot = CitusCopyDestReceiverReceive; copyDest->pub.rStartup = CitusCopyDestReceiverStartup; @@ -1985,6 +2006,44 @@ CreateCitusCopyDestReceiver(Oid tableId, List *columnNameList, int partitionColu } +/* + * ShouldExecuteCopyLocally returns true if the current copy + * operation should be done locally for local placements. + */ +static bool +ShouldExecuteCopyLocally() +{ + if (!EnableLocalExecution) + { + return false; + } + + if (TransactionAccessedLocalPlacement) + { + /* + * For various reasons, including the transaction visibility + * rules (e.g., read-your-own-writes), we have to use local + * execution again if it has already happened within this + * transaction block. + * + * We might error out later in the execution if it is not suitable + * to execute the tasks locally. + */ + Assert(IsMultiStatementTransaction() || InCoordinatedTransaction()); + + /* + * TODO: A future improvement could be to keep track of which placements + * have been locally executed. At this point, only use local execution for + * those placements. That'd help to benefit more from parallelism. + */ + + return true; + } + + return IsMultiStatementTransaction(); +} + + /* * CitusCopyDestReceiverStartup implements the rStartup interface of * CitusCopyDestReceiver. It opens the relation, acquires necessary @@ -2013,9 +2072,6 @@ CitusCopyDestReceiverStartup(DestReceiver *dest, int operation, const char *delimiterCharacter = "\t"; const char *nullPrintCharacter = "\\N"; - /* Citus currently doesn't know how to handle COPY command locally */ - ErrorIfTransactionAccessedPlacementsLocally(); - /* look up table properties */ Relation distributedRelation = heap_open(tableId, RowExclusiveLock); CitusTableCacheEntry *cacheEntry = GetCitusTableCacheEntry(tableId); @@ -2145,6 +2201,7 @@ CitusCopyDestReceiverStartup(DestReceiver *dest, int operation, } } + copyStatement->query = NULL; copyStatement->attlist = attributeList; copyStatement->is_from = true; @@ -2228,7 +2285,9 @@ CitusSendTupleToPlacements(TupleTableSlot *slot, CitusCopyDestReceiver *copyDest CopyShardState *shardState = GetShardState(shardId, copyDest->shardStateHash, copyDest->connectionStateHash, stopOnFailure, - &cachedShardStateFound); + &cachedShardStateFound, + copyDest->shouldUseLocalCopy, + copyDest->memoryContext); if (!cachedShardStateFound) { firstTupleInShard = true; @@ -2249,6 +2308,14 @@ CitusSendTupleToPlacements(TupleTableSlot *slot, CitusCopyDestReceiver *copyDest } } + if (copyDest->shouldUseLocalCopy && shardState->containsLocalPlacement) + { + bool isEndOfCopy = false; + ProcessLocalCopy(slot, copyDest, shardId, shardState->localCopyBuffer, + isEndOfCopy); + } + + foreach(placementStateCell, shardState->placementStateList) { CopyPlacementState *currentPlacementState = lfirst(placementStateCell); @@ -2276,6 +2343,7 @@ CitusSendTupleToPlacements(TupleTableSlot *slot, CitusCopyDestReceiver *copyDest { StartPlacementStateCopyCommand(currentPlacementState, copyStatement, copyOutState); + dlist_delete(¤tPlacementState->bufferedPlacementNode); connectionState->activePlacementState = currentPlacementState; @@ -2330,6 +2398,30 @@ CitusSendTupleToPlacements(TupleTableSlot *slot, CitusCopyDestReceiver *copyDest } +/* + * ContainsLocalPlacement returns true if the current node has + * a local placement for the given shard id. + */ +static bool +ContainsLocalPlacement(int64 shardId) +{ + ListCell *placementCell = NULL; + List *activePlacementList = ActiveShardPlacementList(shardId); + int32 localGroupId = GetLocalGroupId(); + + foreach(placementCell, activePlacementList) + { + ShardPlacement *placement = (ShardPlacement *) lfirst(placementCell); + + if (placement->groupId == localGroupId) + { + return true; + } + } + return false; +} + + /* * ShardIdForTuple returns id of the shard to which the given tuple belongs to. */ @@ -2407,6 +2499,7 @@ CitusCopyDestReceiverShutdown(DestReceiver *destReceiver) Relation distributedRelation = copyDest->distributedRelation; List *connectionStateList = ConnectionStateList(connectionStateHash); + FinishLocalCopy(copyDest); PG_TRY(); { @@ -2434,6 +2527,28 @@ CitusCopyDestReceiverShutdown(DestReceiver *destReceiver) } +/* + * FinishLocalCopy sends the remaining copies for local placements. + */ +static void +FinishLocalCopy(CitusCopyDestReceiver *copyDest) +{ + HTAB *shardStateHash = copyDest->shardStateHash; + HASH_SEQ_STATUS status; + CopyShardState *copyShardState; + + bool isEndOfCopy = true; + foreach_htab(copyShardState, &status, shardStateHash) + { + if (copyShardState->localCopyBuffer->len > 0) + { + ProcessLocalCopy(NULL, copyDest, copyShardState->shardId, + copyShardState->localCopyBuffer, isEndOfCopy); + } + } +} + + /* * ShutdownCopyConnectionState ends the copy command for the current active * placement on connection, and then sends the rest of the buffers over the @@ -2864,7 +2979,6 @@ CheckCopyPermissions(CopyStmt *copyStatement) /* *INDENT-OFF* */ bool is_from = copyStatement->is_from; Relation rel; - Oid relid; List *range_table = NIL; TupleDesc tupDesc; AclMode required_access = (is_from ? ACL_INSERT : ACL_SELECT); @@ -2874,15 +2988,8 @@ CheckCopyPermissions(CopyStmt *copyStatement) rel = heap_openrv(copyStatement->relation, is_from ? RowExclusiveLock : AccessShareLock); - relid = RelationGetRelid(rel); - - RangeTblEntry *rte = makeNode(RangeTblEntry); - rte->rtekind = RTE_RELATION; - rte->relid = relid; - rte->relkind = rel->rd_rel->relkind; - rte->requiredPerms = required_access; - range_table = list_make1(rte); - + range_table = CreateRangeTable(rel, required_access); + RangeTblEntry *rte = (RangeTblEntry*) linitial(range_table); tupDesc = RelationGetDescr(rel); attnums = CopyGetAttnums(tupDesc, rel, copyStatement->attlist); @@ -2909,6 +3016,21 @@ CheckCopyPermissions(CopyStmt *copyStatement) } +/* + * CreateRangeTable creates a range table with the given relation. + */ +List * +CreateRangeTable(Relation rel, AclMode requiredAccess) +{ + RangeTblEntry *rte = makeNode(RangeTblEntry); + rte->rtekind = RTE_RELATION; + rte->relid = rel->rd_id; + rte->relkind = rel->rd_rel->relkind; + rte->requiredPerms = requiredAccess; + return list_make1(rte); +} + + /* Helper for CheckCopyPermissions(), copied from postgres */ static List * CopyGetAttnums(TupleDesc tupDesc, Relation rel, List *attnamelist) @@ -3087,14 +3209,15 @@ ConnectionStateList(HTAB *connectionStateHash) */ static CopyShardState * GetShardState(uint64 shardId, HTAB *shardStateHash, - HTAB *connectionStateHash, bool stopOnFailure, bool *found) + HTAB *connectionStateHash, bool stopOnFailure, bool *found, bool + shouldUseLocalCopy, MemoryContext context) { CopyShardState *shardState = (CopyShardState *) hash_search(shardStateHash, &shardId, HASH_ENTER, found); if (!*found) { InitializeCopyShardState(shardState, connectionStateHash, - shardId, stopOnFailure); + shardId, stopOnFailure, shouldUseLocalCopy, context); } return shardState; @@ -3109,11 +3232,16 @@ GetShardState(uint64 shardId, HTAB *shardStateHash, static void InitializeCopyShardState(CopyShardState *shardState, HTAB *connectionStateHash, uint64 shardId, - bool stopOnFailure) + bool stopOnFailure, bool shouldUseLocalCopy, MemoryContext + context) { ListCell *placementCell = NULL; int failedPlacementCount = 0; + MemoryContext oldContext = MemoryContextSwitchTo(context); + + MemoryContextSwitchTo(oldContext); + MemoryContext localContext = AllocSetContextCreateExtended(CurrentMemoryContext, "InitializeCopyShardState", @@ -3121,8 +3249,9 @@ InitializeCopyShardState(CopyShardState *shardState, ALLOCSET_DEFAULT_INITSIZE, ALLOCSET_DEFAULT_MAXSIZE); + /* release active placement list at the end of this function */ - MemoryContext oldContext = MemoryContextSwitchTo(localContext); + oldContext = MemoryContextSwitchTo(localContext); List *activePlacementList = ActiveShardPlacementList(shardId); @@ -3130,11 +3259,20 @@ InitializeCopyShardState(CopyShardState *shardState, shardState->shardId = shardId; shardState->placementStateList = NIL; + shardState->localCopyBuffer = makeStringInfo(); + shardState->containsLocalPlacement = ContainsLocalPlacement(shardId); + foreach(placementCell, activePlacementList) { ShardPlacement *placement = (ShardPlacement *) lfirst(placementCell); + if (shouldUseLocalCopy && placement->groupId == GetLocalGroupId()) + { + LogLocalCopyExecution(shardId); + continue; + } + MultiConnection *connection = CopyGetPlacementConnection(placement, stopOnFailure); if (connection == NULL) @@ -3158,6 +3296,7 @@ InitializeCopyShardState(CopyShardState *shardState, CopyPlacementState *placementState = palloc0(sizeof(CopyPlacementState)); placementState->shardState = shardState; placementState->data = makeStringInfo(); + placementState->groupId = placement->groupId; placementState->connectionState = connectionState; /* @@ -3188,6 +3327,21 @@ InitializeCopyShardState(CopyShardState *shardState, } +/* + * LogLocalCopyExecution logs that the copy will be done locally for + * the given shard. + */ +static void +LogLocalCopyExecution(uint64 shardId) +{ + if (!(LogRemoteCommands || LogLocalCommands)) + { + return; + } + ereport(NOTICE, (errmsg("executing the copy locally for shard"))); +} + + /* * CopyGetPlacementConnection assigns a connection to the given placement. If * a connection has already been assigned the placement in the current transaction diff --git a/src/backend/distributed/executor/insert_select_executor.c b/src/backend/distributed/executor/insert_select_executor.c index 6663acc49..f5b3d2cce 100644 --- a/src/backend/distributed/executor/insert_select_executor.c +++ b/src/backend/distributed/executor/insert_select_executor.c @@ -581,13 +581,16 @@ ExecutePlanIntoColocatedIntermediateResults(Oid targetRelationId, int partitionColumnIndex = PartitionColumnIndexFromColumnList(targetRelationId, columnNameList); + bool hasCopyDataLocally = true; + /* set up a DestReceiver that copies into the intermediate table */ CitusCopyDestReceiver *copyDest = CreateCitusCopyDestReceiver(targetRelationId, columnNameList, partitionColumnIndex, executorState, stopOnFailure, - intermediateResultIdPrefix); + intermediateResultIdPrefix, + hasCopyDataLocally); ExecutePlanIntoDestReceiver(selectPlan, paramListInfo, (DestReceiver *) copyDest); @@ -623,12 +626,15 @@ ExecutePlanIntoRelation(Oid targetRelationId, List *insertTargetList, int partitionColumnIndex = PartitionColumnIndexFromColumnList(targetRelationId, columnNameList); + bool hasCopyDataLocally = true; + /* set up a DestReceiver that copies into the distributed table */ CitusCopyDestReceiver *copyDest = CreateCitusCopyDestReceiver(targetRelationId, columnNameList, partitionColumnIndex, executorState, - stopOnFailure, NULL); + stopOnFailure, NULL, + hasCopyDataLocally); ExecutePlanIntoDestReceiver(selectPlan, paramListInfo, (DestReceiver *) copyDest); diff --git a/src/backend/distributed/planner/deparse_shard_query.c b/src/backend/distributed/planner/deparse_shard_query.c index 5187e5094..9d26a8702 100644 --- a/src/backend/distributed/planner/deparse_shard_query.c +++ b/src/backend/distributed/planner/deparse_shard_query.c @@ -335,7 +335,12 @@ UpdateRelationsToLocalShardTables(Node *node, List *relationShardList) return true; } - Oid shardOid = GetShardOid(relationShard->relationId, relationShard->shardId); +<<<<<<< HEAD + Oid shardOid = GetShardLocalTableOid(relationShard->relationId, relationShard->shardId); +======= + Oid shardOid = GetShardLocalTableOid(relationShard->relationId, + relationShard->shardId); +>>>>>>> add the support to execute copy locally newRte->relid = shardOid; diff --git a/src/backend/distributed/utils/shard_utils.c b/src/backend/distributed/utils/shard_utils.c index ad3acac67..aa0a17921 100644 --- a/src/backend/distributed/utils/shard_utils.c +++ b/src/backend/distributed/utils/shard_utils.c @@ -16,11 +16,11 @@ #include "distributed/shard_utils.h" /* - * GetShardOid returns the oid of the shard from the given distributed relation + * GetShardLocalTableOid returns the oid of the shard from the given distributed relation * with the shardid. */ Oid -GetShardOid(Oid distRelId, uint64 shardId) +GetShardLocalTableOid(Oid distRelId, uint64 shardId) { char *relationName = get_rel_name(distRelId); AppendShardIdToName(&relationName, shardId); diff --git a/src/include/distributed/commands/multi_copy.h b/src/include/distributed/commands/multi_copy.h index d91358839..b887cc5ad 100644 --- a/src/include/distributed/commands/multi_copy.h +++ b/src/include/distributed/commands/multi_copy.h @@ -130,6 +130,9 @@ typedef struct CitusCopyDestReceiver /* useful for tracking multi shard accesses */ bool multiShardCopy; + /* if true, should copy to local placements in the current session */ + bool shouldUseLocalCopy; + /* copy into intermediate result */ char *intermediateResultIdPrefix; } CitusCopyDestReceiver; @@ -141,7 +144,8 @@ extern CitusCopyDestReceiver * CreateCitusCopyDestReceiver(Oid relationId, int partitionColumnIndex, EState *executorState, bool stopOnFailure, - char *intermediateResultPrefix); + char *intermediateResultPrefix, + bool hasCopyDataLocally); extern FmgrInfo * ColumnOutputFunctions(TupleDesc rowDescriptor, bool binaryFormat); extern bool CanUseBinaryCopyFormat(TupleDesc tupleDescription); extern bool CanUseBinaryCopyFormatForTargetList(List *targetEntryList); @@ -154,6 +158,7 @@ extern void AppendCopyRowData(Datum *valueArray, bool *isNullArray, extern void AppendCopyBinaryHeaders(CopyOutState headerOutputState); extern void AppendCopyBinaryFooters(CopyOutState footerOutputState); extern void EndRemoteCopy(int64 shardId, List *connectionList); +extern List * CreateRangeTable(Relation rel, AclMode requiredAccess); extern Node * ProcessCopyStmt(CopyStmt *copyStatement, char *completionTag, const char *queryString); extern void CheckCopyPermissions(CopyStmt *copyStatement); diff --git a/src/include/distributed/local_multi_copy.h b/src/include/distributed/local_multi_copy.h new file mode 100644 index 000000000..83dc6a32b --- /dev/null +++ b/src/include/distributed/local_multi_copy.h @@ -0,0 +1,9 @@ + +#ifndef LOCAL_MULTI_COPY +#define LOCAL_MULTI_COPY + +extern void ProcessLocalCopy(TupleTableSlot *slot, CitusCopyDestReceiver *copyDest, int64 + shardId, + StringInfo buffer, bool isEndOfCopy); + +#endif /* LOCAL_MULTI_COPY */ diff --git a/src/include/distributed/shard_utils.h b/src/include/distributed/shard_utils.h index e4fd9a2f2..28addce15 100644 --- a/src/include/distributed/shard_utils.h +++ b/src/include/distributed/shard_utils.h @@ -13,6 +13,6 @@ #include "postgres.h" -extern Oid GetShardOid(Oid distRelId, uint64 shardId); +extern Oid GetShardLocalTableOid(Oid distRelId, uint64 shardId); #endif /* SHARD_UTILS_H */