diff --git a/src/backend/distributed/commands/multi_copy.c b/src/backend/distributed/commands/multi_copy.c index bfc73c676..d9cf8eade 100644 --- a/src/backend/distributed/commands/multi_copy.c +++ b/src/backend/distributed/commands/multi_copy.c @@ -68,6 +68,7 @@ #include "distributed/remote_commands.h" #include "distributed/resource_lock.h" #include "executor/executor.h" +#include "nodes/makefuncs.h" #include "tsearch/ts_locale.h" #include "utils/builtins.h" #include "utils/lsyscache.h" @@ -127,6 +128,19 @@ static void CopySendInt16(CopyOutState outputState, int16 val); static void CopyAttributeOutText(CopyOutState outputState, char *string); static inline void CopyFlushOutput(CopyOutState outputState, char *start, char *pointer); +/* CitusCopyDestReceiver functions */ +static void CitusCopyDestReceiverStartup(DestReceiver *copyDest, int operation, + TupleDesc inputTupleDesc); +#if PG_VERSION_NUM >= 90600 +static bool CitusCopyDestReceiverReceive(TupleTableSlot *slot, + DestReceiver *copyDest); +#else +static void CitusCopyDestReceiverReceive(TupleTableSlot *slot, + DestReceiver *copyDest); +#endif +static void CitusCopyDestReceiverShutdown(DestReceiver *destReceiver); +static void CitusCopyDestReceiverDestroy(DestReceiver *destReceiver); + /* * CitusCopyFrom implements the COPY table_name FROM. It dispacthes the copy @@ -406,6 +420,13 @@ CopyToExistingShards(CopyStmt *copyStatement, char *completionTag) /* create a mapping of shard id to a connection for each of its placements */ shardConnectionHash = CreateShardConnectionHash(TopTransactionContext); + /* + * From here on we use copyStatement as the template for the command + * that we send to workers. This command does not have an attribute + * list since NextCopyFrom will generate a value for all columns. + */ + copyStatement->attlist = NIL; + /* set up callback to identify error line number */ errorCallback.callback = CopyFromErrorCallback; errorCallback.arg = (void *) copyState; @@ -604,6 +625,13 @@ CopyToNewShards(CopyStmt *copyStatement, char *completionTag, Oid relationId) errorCallback.arg = (void *) copyState; errorCallback.previous = error_context_stack; + /* + * From here on we use copyStatement as the template for the command + * that we send to workers. This command does not have an attribute + * list since NextCopyFrom will generate a value for all columns. + */ + copyStatement->attlist = NIL; + while (true) { bool nextRowFound = false; @@ -1074,22 +1102,46 @@ ConstructCopyStatement(CopyStmt *copyStatement, int64 shardId, bool useBinaryCop char *shardName = pstrdup(relationName); char *shardQualifiedName = NULL; - const char *copyFormat = NULL; AppendShardIdToName(&shardName, shardId); shardQualifiedName = quote_qualified_identifier(schemaName, shardName); + appendStringInfo(command, "COPY %s ", shardQualifiedName); + + if (copyStatement->attlist != NIL) + { + ListCell *columnNameCell = NULL; + bool appendedFirstName = false; + + foreach(columnNameCell, copyStatement->attlist) + { + char *columnName = (char *) lfirst(columnNameCell); + + if (!appendedFirstName) + { + appendStringInfo(command, "(%s", columnName); + appendedFirstName = true; + } + else + { + appendStringInfo(command, ", %s", columnName); + } + } + + appendStringInfoString(command, ") "); + } + + appendStringInfo(command, "FROM STDIN WITH "); + if (useBinaryCopyFormat) { - copyFormat = "BINARY"; + appendStringInfoString(command, "(FORMAT BINARY)"); } else { - copyFormat = "TEXT"; + appendStringInfoString(command, "(FORMAT TEXT)"); } - appendStringInfo(command, "COPY %s FROM STDIN WITH (FORMAT %s)", shardQualifiedName, - copyFormat); return command; } @@ -1277,7 +1329,6 @@ AppendCopyRowData(Datum *valueArray, bool *isNullArray, TupleDesc rowDescriptor, { CopySendInt16(rowOutputState, availableColumnCount); } - for (columnIndex = 0; columnIndex < totalColumnCount; columnIndex++) { Form_pg_attribute currentColumn = rowDescriptor->attrs[columnIndex]; @@ -1694,3 +1745,368 @@ CopyFlushOutput(CopyOutState cstate, char *start, char *pointer) CopySendData(cstate, start, pointer - start); } } + + +/* + * CreateCitusCopyDestReceiver creates a DestReceiver that copies into + * a distributed table. + */ +CitusCopyDestReceiver * +CreateCitusCopyDestReceiver(Oid tableId, List *columnNameList, EState *executorState, + bool stopOnFailure) +{ + CitusCopyDestReceiver *copyDest = NULL; + + copyDest = (CitusCopyDestReceiver *) palloc0(sizeof(CitusCopyDestReceiver)); + + /* set up the DestReceiver function pointers */ + copyDest->pub.receiveSlot = CitusCopyDestReceiverReceive; + copyDest->pub.rStartup = CitusCopyDestReceiverStartup; + copyDest->pub.rShutdown = CitusCopyDestReceiverShutdown; + copyDest->pub.rDestroy = CitusCopyDestReceiverDestroy; + copyDest->pub.mydest = DestCopyOut; + + /* set up output parameters */ + copyDest->distributedRelationId = tableId; + copyDest->columnNameList = columnNameList; + copyDest->executorState = executorState; + copyDest->stopOnFailure = stopOnFailure; + copyDest->memoryContext = CurrentMemoryContext; + + return copyDest; +} + + +static void +CitusCopyDestReceiverStartup(DestReceiver *dest, int operation, + TupleDesc inputTupleDescriptor) +{ + CitusCopyDestReceiver *copyDest = (CitusCopyDestReceiver *) dest; + + Oid tableId = copyDest->distributedRelationId; + + char *relationName = get_rel_name(tableId); + Oid schemaOid = get_rel_namespace(tableId); + char *schemaName = get_namespace_name(schemaOid); + + Relation distributedRelation = NULL; + int columnIndex = 0; + List *columnNameList = copyDest->columnNameList; + + ListCell *columnNameCell = NULL; + + char partitionMethod = '\0'; + Var *partitionColumn = NULL; + int partitionColumnIndex = -1; + DistTableCacheEntry *cacheEntry = NULL; + + CopyStmt *copyStatement = NULL; + + List *shardIntervalList = NULL; + + CopyOutState copyOutState = NULL; + const char *delimiterCharacter = "\t"; + const char *nullPrintCharacter = "\\N"; + + /* look up table properties */ + distributedRelation = heap_open(tableId, RowExclusiveLock); + cacheEntry = DistributedTableCacheEntry(tableId); + partitionMethod = cacheEntry->partitionMethod; + + copyDest->distributedRelation = distributedRelation; + copyDest->partitionMethod = partitionMethod; + + if (partitionMethod == DISTRIBUTE_BY_NONE) + { + /* we don't support copy to reference tables from workers */ + EnsureSchemaNode(); + } + else + { + partitionColumn = PartitionColumn(tableId, 0); + } + + /* load the list of shards and verify that we have shards to copy into */ + shardIntervalList = LoadShardIntervalList(tableId); + if (shardIntervalList == NIL) + { + if (partitionMethod == DISTRIBUTE_BY_HASH) + { + ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), + errmsg("could not find any shards into which to copy"), + errdetail("No shards exist for distributed table \"%s\".", + relationName), + errhint("Run master_create_worker_shards to create shards " + "and try again."))); + } + else + { + ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), + errmsg("could not find any shards into which to copy"), + errdetail("No shards exist for distributed table \"%s\".", + relationName))); + } + } + + /* prevent concurrent placement changes and non-commutative DML statements */ + LockShardListMetadata(shardIntervalList, ShareLock); + LockShardListResources(shardIntervalList, ShareLock); + + /* error if any shard missing min/max values */ + if (partitionMethod != DISTRIBUTE_BY_NONE && + cacheEntry->hasUninitializedShardInterval) + { + ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), + errmsg("could not start copy"), + errdetail("Distributed relation \"%s\" has shards " + "with missing shardminvalue/shardmaxvalue.", + relationName))); + } + + copyDest->hashFunction = cacheEntry->hashFunction; + copyDest->compareFunction = cacheEntry->shardIntervalCompareFunction; + + /* initialize the shard interval cache */ + copyDest->shardCount = cacheEntry->shardIntervalArrayLength; + copyDest->shardIntervalCache = cacheEntry->sortedShardIntervalArray; + + /* determine whether to use binary search */ + if (partitionMethod != DISTRIBUTE_BY_HASH || !cacheEntry->hasUniformHashDistribution) + { + copyDest->useBinarySearch = true; + } + + /* define how tuples will be serialised */ + copyOutState = (CopyOutState) palloc0(sizeof(CopyOutStateData)); + copyOutState->delim = (char *) delimiterCharacter; + copyOutState->null_print = (char *) nullPrintCharacter; + copyOutState->null_print_client = (char *) nullPrintCharacter; + copyOutState->binary = CanUseBinaryCopyFormat(inputTupleDescriptor, copyOutState); + copyOutState->fe_msgbuf = makeStringInfo(); + copyOutState->rowcontext = GetPerTupleMemoryContext(copyDest->executorState); + copyDest->copyOutState = copyOutState; + + copyDest->tupleDescriptor = inputTupleDescriptor; + + /* prepare output functions */ + copyDest->columnOutputFunctions = + ColumnOutputFunctions(inputTupleDescriptor, copyOutState->binary); + + foreach(columnNameCell, columnNameList) + { + char *columnName = (char *) lfirst(columnNameCell); + + /* load the column information from pg_attribute */ + AttrNumber attrNumber = get_attnum(tableId, columnName); + + /* check whether this is the partition column */ + if (partitionColumn != NULL && attrNumber == partitionColumn->varattno) + { + Assert(partitionColumnIndex == -1); + + partitionColumnIndex = columnIndex; + } + + columnIndex++; + } + + if (partitionMethod != DISTRIBUTE_BY_NONE && partitionColumnIndex == -1) + { + ereport(ERROR, (errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED), + errmsg("the partition column of table %s should have a value", + quote_qualified_identifier(schemaName, relationName)))); + } + + copyDest->partitionColumnIndex = partitionColumnIndex; + + /* define the template for the COPY statement that is sent to workers */ + copyStatement = makeNode(CopyStmt); + copyStatement->relation = makeRangeVar(schemaName, relationName, -1); + copyStatement->query = NULL; + copyStatement->attlist = columnNameList; + copyStatement->is_from = true; + copyStatement->is_program = false; + copyStatement->filename = NULL; + copyStatement->options = NIL; + copyDest->copyStatement = copyStatement; + + copyDest->copyConnectionHash = CreateShardConnectionHash(TopTransactionContext); +} + + +#if PG_VERSION_NUM >= 90600 +static bool +#else +static void +#endif +CitusCopyDestReceiverReceive(TupleTableSlot *slot, DestReceiver *dest) +{ + CitusCopyDestReceiver *copyDest = (CitusCopyDestReceiver *) dest; + + char partitionMethod = copyDest->partitionMethod; + int partitionColumnIndex = copyDest->partitionColumnIndex; + TupleDesc tupleDescriptor = copyDest->tupleDescriptor; + CopyStmt *copyStatement = copyDest->copyStatement; + + int shardCount = copyDest->shardCount; + ShardInterval **shardIntervalCache = copyDest->shardIntervalCache; + + bool useBinarySearch = copyDest->useBinarySearch; + FmgrInfo *hashFunction = copyDest->hashFunction; + FmgrInfo *compareFunction = copyDest->compareFunction; + + HTAB *copyConnectionHash = copyDest->copyConnectionHash; + CopyOutState copyOutState = copyDest->copyOutState; + FmgrInfo *columnOutputFunctions = copyDest->columnOutputFunctions; + + bool stopOnFailure = copyDest->stopOnFailure; + + Datum *columnValues = NULL; + bool *columnNulls = NULL; + + Datum partitionColumnValue = 0; + ShardInterval *shardInterval = NULL; + int64 shardId = 0; + + bool shardConnectionsFound = false; + ShardConnections *shardConnections = NULL; + + EState *executorState = copyDest->executorState; + MemoryContext executorTupleContext = GetPerTupleMemoryContext(executorState); + MemoryContext oldContext = MemoryContextSwitchTo(executorTupleContext); + + slot_getallattrs(slot); + + columnValues = slot->tts_values; + columnNulls = slot->tts_isnull; + + /* + * Find the partition column value and corresponding shard interval + * for non-reference tables. + * Get the existing (and only a single) shard interval for the reference + * tables. Note that, reference tables has NULL partition column values so + * skip the check. + */ + if (partitionColumnIndex >= 0) + { + if (columnNulls[partitionColumnIndex]) + { + Oid relationId = copyDest->distributedRelationId; + char *relationName = get_rel_name(relationId); + Oid schemaOid = get_rel_namespace(relationId); + char *schemaName = get_namespace_name(schemaOid); + char *qualifiedTableName = quote_qualified_identifier(schemaName, + relationName); + + ereport(ERROR, (errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED), + errmsg("the partition column of table %s should have a value", + qualifiedTableName))); + } + + /* find the partition column value */ + partitionColumnValue = columnValues[partitionColumnIndex]; + } + + /* + * Find the shard interval and id for the partition column value for + * non-reference tables. + * + * For reference table, this function blindly returns the tables single + * shard. + */ + shardInterval = FindShardInterval(partitionColumnValue, shardIntervalCache, + shardCount, partitionMethod, + compareFunction, hashFunction, + useBinarySearch); + if (shardInterval == NULL) + { + ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), + errmsg("could not find shard for partition column " + "value"))); + } + + shardId = shardInterval->shardId; + + /* connections hash is kept in memory context */ + MemoryContextSwitchTo(copyDest->memoryContext); + + /* get existing connections to the shard placements, if any */ + shardConnections = GetShardHashConnections(copyConnectionHash, shardId, + &shardConnectionsFound); + if (!shardConnectionsFound) + { + /* open connections and initiate COPY on shard placements */ + OpenCopyConnections(copyStatement, shardConnections, stopOnFailure, + copyOutState->binary); + + /* send copy binary headers to shard placements */ + if (copyOutState->binary) + { + SendCopyBinaryHeaders(copyOutState, shardId, + shardConnections->connectionList); + } + } + + /* replicate row to shard placements */ + resetStringInfo(copyOutState->fe_msgbuf); + AppendCopyRowData(columnValues, columnNulls, tupleDescriptor, + copyOutState, columnOutputFunctions); + SendCopyDataToAll(copyOutState->fe_msgbuf, shardId, shardConnections->connectionList); + + MemoryContextSwitchTo(oldContext); + +#if PG_VERSION_NUM >= 90600 + return true; +#endif +} + + +static void +CitusCopyDestReceiverShutdown(DestReceiver *destReceiver) +{ + CitusCopyDestReceiver *copyDest = (CitusCopyDestReceiver *) destReceiver; + + HTAB *shardConnectionHash = copyDest->copyConnectionHash; + List *shardConnectionsList = NIL; + ListCell *shardConnectionsCell = NULL; + CopyOutState copyOutState = copyDest->copyOutState; + Relation distributedRelation = copyDest->distributedRelation; + + shardConnectionsList = ShardConnectionList(shardConnectionHash); + foreach(shardConnectionsCell, shardConnectionsList) + { + ShardConnections *shardConnections = (ShardConnections *) lfirst( + shardConnectionsCell); + + /* send copy binary footers to all shard placements */ + if (copyOutState->binary) + { + SendCopyBinaryFooters(copyOutState, shardConnections->shardId, + shardConnections->connectionList); + } + + /* close the COPY input on all shard placements */ + EndRemoteCopy(shardConnections->shardId, shardConnections->connectionList, true); + } + + heap_close(distributedRelation, NoLock); +} + + +static void +CitusCopyDestReceiverDestroy(DestReceiver *destReceiver) +{ + CitusCopyDestReceiver *copyDest = (CitusCopyDestReceiver *) destReceiver; + + if (copyDest->copyOutState) + { + pfree(copyDest->copyOutState); + } + + if (copyDest->columnOutputFunctions) + { + pfree(copyDest->columnOutputFunctions); + } + + pfree(copyDest); +} diff --git a/src/include/distributed/multi_copy.h b/src/include/distributed/multi_copy.h index d27926bbf..7b1c6dd04 100644 --- a/src/include/distributed/multi_copy.h +++ b/src/include/distributed/multi_copy.h @@ -13,7 +13,11 @@ #define MULTI_COPY_H +#include "distributed/master_metadata_utility.h" +#include "nodes/execnodes.h" #include "nodes/parsenodes.h" +#include "tcop/dest.h" + /* * A smaller version of copy.c's CopyStateData, trimmed to the elements @@ -43,8 +47,54 @@ typedef struct NodeAddress int32 nodePort; } NodeAddress; +/* CopyDestReceiver can be used to stream results into a distributed table */ +typedef struct CitusCopyDestReceiver +{ + DestReceiver pub; + + /* relation and columns to which to copy */ + Oid distributedRelationId; + List *columnNameList; + + /* EState for per-tuple memory allocation */ + EState *executorState; + + /* MemoryContext for DestReceiver session */ + MemoryContext memoryContext; + + /* distributed relation details */ + Relation distributedRelation; + char partitionMethod; + int partitionColumnIndex; + + /* descriptor of the tuples that are sent to the worker */ + TupleDesc tupleDescriptor; + + /* template for COPY statement to send to workers */ + CopyStmt *copyStatement; + + /* cached shard metadata for pruning */ + int shardCount; + ShardInterval **shardIntervalCache; + bool useBinarySearch; + FmgrInfo *hashFunction; + FmgrInfo *compareFunction; + + /* cached shard metadata for pruning */ + HTAB *copyConnectionHash; + bool stopOnFailure; + + /* state on how to copy out data types */ + CopyOutState copyOutState; + FmgrInfo *columnOutputFunctions; +} CitusCopyDestReceiver; + /* function declarations for copying into a distributed table */ +extern CitusCopyDestReceiver * CreateCitusCopyDestReceiver(Oid relationId, + List *columnNameList, + EState *executorState, + bool stopOnFailure); extern FmgrInfo * ColumnOutputFunctions(TupleDesc rowDescriptor, bool binaryFormat); extern void AppendCopyRowData(Datum *valueArray, bool *isNullArray, TupleDesc rowDescriptor,