From d25ee8fbd8c84fafb81a3e7c38cecf02fb367f07 Mon Sep 17 00:00:00 2001 From: Marco Slot Date: Tue, 29 Mar 2016 18:21:29 +0200 Subject: [PATCH 1/2] Support for COPY FROM, based on pg_shard PR by Postres Pro --- src/backend/distributed/commands/multi_copy.c | 1062 +++++++++++++++++ .../distributed/executor/multi_utility.c | 17 +- src/backend/distributed/shared_library_init.c | 25 +- .../distributed/utils/connection_cache.c | 18 +- .../distributed/utils/multi_transaction.c | 211 ++++ src/include/distributed/connection_cache.h | 2 + src/include/distributed/multi_copy.h | 27 + src/include/distributed/multi_transaction.h | 57 + src/test/regress/expected/multi_utilities.out | 3 - .../expected/multi_utility_statements.out | 3 - src/test/regress/input/multi_copy.source | 142 +++ src/test/regress/multi_schedule | 6 + src/test/regress/output/multi_copy.source | 174 +++ src/test/regress/sql/multi_utilities.sql | 3 - .../regress/sql/multi_utility_statements.sql | 3 - 15 files changed, 1724 insertions(+), 29 deletions(-) create mode 100644 src/backend/distributed/commands/multi_copy.c create mode 100644 src/backend/distributed/utils/multi_transaction.c create mode 100644 src/include/distributed/multi_copy.h create mode 100644 src/include/distributed/multi_transaction.h create mode 100644 src/test/regress/input/multi_copy.source create mode 100644 src/test/regress/output/multi_copy.source diff --git a/src/backend/distributed/commands/multi_copy.c b/src/backend/distributed/commands/multi_copy.c new file mode 100644 index 000000000..2e6c0ee14 --- /dev/null +++ b/src/backend/distributed/commands/multi_copy.c @@ -0,0 +1,1062 @@ +/*------------------------------------------------------------------------- + * + * multi_copy.c + * This file contains implementation of COPY utility for distributed + * tables. + * + * Contributed by Konstantin Knizhnik, Postgres Professional + * + * Copyright (c) 2016, Citus Data, Inc. + * + *------------------------------------------------------------------------- + */ + +#include "postgres.h" +#include "c.h" +#include "fmgr.h" +#include "funcapi.h" +#include "libpq-fe.h" +#include "miscadmin.h" +#include "plpgsql.h" + +#include + +#include "access/heapam.h" +#include "access/htup_details.h" +#include "access/htup.h" +#include "access/nbtree.h" +#include "access/sdir.h" +#include "access/tupdesc.h" +#include "access/xact.h" +#include "catalog/namespace.h" +#include "catalog/pg_class.h" +#include "catalog/pg_type.h" +#include "catalog/pg_am.h" +#include "catalog/pg_collation.h" +#include "commands/extension.h" +#include "commands/copy.h" +#include "commands/defrem.h" +#include "distributed/citus_ruleutils.h" +#include "distributed/connection_cache.h" +#include "distributed/listutils.h" +#include "distributed/master_metadata_utility.h" +#include "distributed/master_protocol.h" +#include "distributed/multi_copy.h" +#include "distributed/multi_physical_planner.h" +#include "distributed/multi_transaction.h" +#include "distributed/pg_dist_partition.h" +#include "distributed/resource_lock.h" +#include "distributed/worker_protocol.h" +#include "executor/execdesc.h" +#include "executor/executor.h" +#include "executor/instrument.h" +#include "executor/tuptable.h" +#include "lib/stringinfo.h" +#include "nodes/execnodes.h" +#include "nodes/makefuncs.h" +#include "nodes/memnodes.h" +#include "nodes/nodeFuncs.h" +#include "nodes/nodes.h" +#include "nodes/params.h" +#include "nodes/parsenodes.h" +#include "nodes/pg_list.h" +#include "nodes/plannodes.h" +#include "nodes/primnodes.h" +#include "optimizer/clauses.h" +#include "optimizer/cost.h" +#include "optimizer/planner.h" +#include "optimizer/var.h" +#include "parser/parser.h" +#include "parser/analyze.h" +#include "parser/parse_node.h" +#include "parser/parsetree.h" +#include "parser/parse_type.h" +#include "storage/lock.h" +#include "tcop/dest.h" +#include "tcop/tcopprot.h" +#include "tcop/utility.h" +#include "tsearch/ts_locale.h" +#include "utils/builtins.h" +#include "utils/elog.h" +#include "utils/errcodes.h" +#include "utils/guc.h" +#include "utils/lsyscache.h" +#include "utils/typcache.h" +#include "utils/palloc.h" +#include "utils/rel.h" +#include "utils/relcache.h" +#include "utils/snapmgr.h" +#include "utils/tuplestore.h" +#include "utils/memutils.h" + + +#define INITIAL_CONNECTION_CACHE_SIZE 1001 + + +/* the transaction manager to use for COPY commands */ +int CopyTransactionManager = TRANSACTION_MANAGER_1PC; + + +/* Data structures from copy.c, to keep track of COPY processing state */ +typedef enum CopyDest +{ + COPY_FILE, /* to/from file (or a piped program) */ + COPY_OLD_FE, /* to/from frontend (2.0 protocol) */ + COPY_NEW_FE /* to/from frontend (3.0 protocol) */ +} CopyDest; + +typedef enum EolType +{ + EOL_UNKNOWN, + EOL_NL, + EOL_CR, + EOL_CRNL +} EolType; + +typedef struct CopyStateData +{ + /* low-level state data */ + CopyDest copy_dest; /* type of copy source/destination */ + FILE *copy_file; /* used if copy_dest == COPY_FILE */ + StringInfo fe_msgbuf; /* used for all dests during COPY TO, only for + * dest == COPY_NEW_FE in COPY FROM */ + bool fe_eof; /* true if detected end of copy data */ + EolType eol_type; /* EOL type of input */ + int file_encoding; /* file or remote side's character encoding */ + bool need_transcoding; /* file encoding diff from server? */ + bool encoding_embeds_ascii; /* ASCII can be non-first byte? */ + + /* parameters from the COPY command */ + Relation rel; /* relation to copy to or from */ + QueryDesc *queryDesc; /* executable query to copy from */ + List *attnumlist; /* integer list of attnums to copy */ + char *filename; /* filename, or NULL for STDIN/STDOUT */ + bool is_program; /* is 'filename' a program to popen? */ + bool binary; /* binary format? */ + bool oids; /* include OIDs? */ + bool freeze; /* freeze rows on loading? */ + bool csv_mode; /* Comma Separated Value format? */ + bool header_line; /* CSV header line? */ + char *null_print; /* NULL marker string (server encoding!) */ + int null_print_len; /* length of same */ + char *null_print_client; /* same converted to file encoding */ + char *delim; /* column delimiter (must be 1 byte) */ + char *quote; /* CSV quote char (must be 1 byte) */ + char *escape; /* CSV escape char (must be 1 byte) */ + List *force_quote; /* list of column names */ + bool force_quote_all; /* FORCE QUOTE *? */ + bool *force_quote_flags; /* per-column CSV FQ flags */ + List *force_notnull; /* list of column names */ + bool *force_notnull_flags; /* per-column CSV FNN flags */ +#if PG_VERSION_NUM >= 90400 + List *force_null; /* list of column names */ + bool *force_null_flags; /* per-column CSV FN flags */ +#endif + bool convert_selectively; /* do selective binary conversion? */ + List *convert_select; /* list of column names (can be NIL) */ + bool *convert_select_flags; /* per-column CSV/TEXT CS flags */ + + /* these are just for error messages, see CopyFromErrorCallback */ + const char *cur_relname; /* table name for error messages */ + int cur_lineno; /* line number for error messages */ + const char *cur_attname; /* current att for error messages */ + const char *cur_attval; /* current att value for error messages */ + + /* + * Working state for COPY TO/FROM + */ + MemoryContext copycontext; /* per-copy execution context */ + + /* + * Working state for COPY TO + */ + FmgrInfo *out_functions; /* lookup info for output functions */ + MemoryContext rowcontext; /* per-row evaluation context */ + + /* + * Working state for COPY FROM + */ + AttrNumber num_defaults; + bool file_has_oids; + FmgrInfo oid_in_function; + Oid oid_typioparam; + FmgrInfo *in_functions; /* array of input functions for each attrs */ + Oid *typioparams; /* array of element types for in_functions */ + int *defmap; /* array of default att numbers */ + ExprState **defexprs; /* array of default att expressions */ + bool volatile_defexprs; /* is any of defexprs volatile? */ + List *range_table; + + /* + * These variables are used to reduce overhead in textual COPY FROM. + * + * attribute_buf holds the separated, de-escaped text for each field of + * the current line. The CopyReadAttributes functions return arrays of + * pointers into this buffer. We avoid palloc/pfree overhead by re-using + * the buffer on each cycle. + */ + StringInfoData attribute_buf; + + /* field raw data pointers found by COPY FROM */ + + int max_fields; + char **raw_fields; + + /* + * Similarly, line_buf holds the whole input line being processed. The + * input cycle is first to read the whole line into line_buf, convert it + * to server encoding there, and then extract the individual attribute + * fields into attribute_buf. line_buf is preserved unmodified so that we + * can display it in error messages if appropriate. + */ + StringInfoData line_buf; + bool line_buf_converted; /* converted to server encoding? */ + bool line_buf_valid; /* contains the row being processed? */ + + /* + * Finally, raw_buf holds raw data read from the data source (file or + * client connection). CopyReadLine parses this data sufficiently to + * locate line boundaries, then transfers the data to line_buf and + * converts it. Note: we guarantee that there is a \0 at + * raw_buf[raw_buf_len]. + */ +#define RAW_BUF_SIZE 65536 /* we palloc RAW_BUF_SIZE+1 bytes */ + char *raw_buf; + int raw_buf_index; /* next byte to process */ + int raw_buf_len; /* total # of bytes stored */ +} CopyStateData; + + +/* ShardConnections represents a set of connections for each placement of a shard */ +typedef struct ShardConnections +{ + int64 shardId; + List *connectionList; +} ShardConnections; + + +/* Local functions forward declarations */ +static HTAB * CreateShardConnectionHash(void); +static int CompareShardIntervalsById(const void *leftElement, const void *rightElement); +static bool IsUniformHashDistribution(ShardInterval **shardIntervalArray, + int shardCount); +static FmgrInfo * ShardIntervalCompareFunction(Var *partitionColumn, char + partitionMethod); +static ShardInterval * FindShardInterval(Datum partitionColumnValue, + ShardInterval **shardIntervalCache, + int shardCount, char partitionMethod, + FmgrInfo *compareFunction, + FmgrInfo *hashFunction, bool useBinarySearch); +static ShardInterval * SearchCachedShardInterval(Datum partitionColumnValue, + ShardInterval **shardIntervalCache, + int shardCount, + FmgrInfo *compareFunction); +static void OpenCopyTransactions(CopyStmt *copyStatement, + ShardConnections *shardConnections, + int64 shardId); +static StringInfo ConstructCopyStatement(CopyStmt *copyStatement, int64 shardId); +static void AppendColumnNames(StringInfo command, List *columnList); +static void AppendCopyOptions(StringInfo command, List *copyOptionList); +static void CopyRowToPlacements(StringInfo lineBuf, ShardConnections *shardConnections); +static List * ConnectionList(HTAB *connectionHash); +static void EndRemoteCopy(List *connectionList, bool stopOnFailure); +static void ReportCopyError(PGconn *connection, PGresult *result); + + +/* + * CitusCopyFrom implements the COPY table_name FROM ... for hash-partitioned + * and range-partitioned tables. + */ +void +CitusCopyFrom(CopyStmt *copyStatement, char *completionTag) +{ + RangeVar *relation = copyStatement->relation; + Oid tableId = RangeVarGetRelid(relation, NoLock, false); + char *relationName = get_rel_name(tableId); + List *shardIntervalList = NULL; + ListCell *shardIntervalCell = NULL; + char partitionMethod = '\0'; + Var *partitionColumn = NULL; + HTAB *shardConnectionHash = NULL; + List *connectionList = NIL; + MemoryContext tupleContext = NULL; + CopyState copyState = NULL; + TupleDesc tupleDescriptor = NULL; + uint32 columnCount = 0; + Datum *columnValues = NULL; + bool *columnNulls = NULL; + Relation rel = NULL; + ShardInterval **shardIntervalCache = NULL; + bool useBinarySearch = false; + TypeCacheEntry *typeEntry = NULL; + FmgrInfo *hashFunction = NULL; + FmgrInfo *compareFunction = NULL; + int shardCount = 0; + uint64 processedRowCount = 0; + ErrorContextCallback errorCallback; + + /* disallow COPY to/from file or program except for superusers */ + if (copyStatement->filename != NULL && !superuser()) + { + if (copyStatement->is_program) + { + ereport(ERROR, + (errcode(ERRCODE_INSUFFICIENT_PRIVILEGE), + errmsg("must be superuser to COPY to or from an external program"), + errhint("Anyone can COPY to stdout or from stdin. " + "psql's \\copy command also works for anyone."))); + } + else + { + ereport(ERROR, + (errcode(ERRCODE_INSUFFICIENT_PRIVILEGE), + errmsg("must be superuser to COPY to or from a file"), + errhint("Anyone can COPY to stdout or from stdin. " + "psql's \\copy command also works for anyone."))); + } + } + + partitionColumn = PartitionColumn(tableId, 0); + partitionMethod = PartitionMethod(tableId); + if (partitionMethod != DISTRIBUTE_BY_RANGE && partitionMethod != DISTRIBUTE_BY_HASH) + { + ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED), + errmsg("COPY is only supported for hash- and " + "range-partitioned tables"))); + } + + /* resolve hash function for partition column */ + typeEntry = lookup_type_cache(partitionColumn->vartype, TYPECACHE_HASH_PROC_FINFO); + hashFunction = &(typeEntry->hash_proc_finfo); + + /* resolve compare function for shard intervals */ + compareFunction = ShardIntervalCompareFunction(partitionColumn, partitionMethod); + + /* allocate column values and nulls arrays */ + rel = heap_open(tableId, RowExclusiveLock); + tupleDescriptor = RelationGetDescr(rel); + columnCount = tupleDescriptor->natts; + columnValues = palloc0(columnCount * sizeof(Datum)); + columnNulls = palloc0(columnCount * sizeof(bool)); + + /* load the list of shards and verify that we have shards to copy into */ + shardIntervalList = LoadShardIntervalList(tableId); + if (shardIntervalList == NIL) + { + if (partitionMethod == DISTRIBUTE_BY_HASH) + { + ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), + errmsg("could not find any shards for query"), + errdetail("No shards exist for distributed table \"%s\".", + relationName), + errhint("Run master_create_worker_shards to create shards " + "and try again."))); + } + else + { + ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), + errmsg("could not find any shards for query"), + errdetail("No shards exist for distributed table \"%s\".", + relationName))); + } + } + + /* create a mapping of shard id to a connection for each of its placements */ + shardConnectionHash = CreateShardConnectionHash(); + + /* lock shards in order of shard id to prevent deadlock */ + shardIntervalList = SortList(shardIntervalList, CompareShardIntervalsById); + + foreach(shardIntervalCell, shardIntervalList) + { + ShardInterval *shardInterval = (ShardInterval *) lfirst(shardIntervalCell); + int64 shardId = shardInterval->shardId; + + /* prevent concurrent changes to number of placements */ + LockShardDistributionMetadata(shardId, ShareLock); + + /* prevent concurrent update/delete statements */ + LockShardResource(shardId, ShareLock); + } + + /* initialize the shard interval cache */ + shardCount = list_length(shardIntervalList); + shardIntervalCache = SortedShardIntervalArray(shardIntervalList); + + /* determine whether to use binary search */ + if (partitionMethod != DISTRIBUTE_BY_HASH || + !IsUniformHashDistribution(shardIntervalCache, shardCount)) + { + useBinarySearch = true; + } + + /* initialize copy state to read from COPY data source */ + copyState = BeginCopyFrom(rel, copyStatement->filename, + copyStatement->is_program, + copyStatement->attlist, + copyStatement->options); + + if (copyState->binary) + { + ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED), + errmsg("Copy in binary mode is not currently supported"))); + } + + /* set up callback to identify error line number */ + errorCallback.callback = CopyFromErrorCallback; + errorCallback.arg = (void *) copyState; + errorCallback.previous = error_context_stack; + error_context_stack = &errorCallback; + + /* + * We create a new memory context called tuple context, and read and write + * each row's values within this memory context. After each read and write, + * we reset the memory context. That way, we immediately release memory + * allocated for each row, and don't bloat memory usage with large input + * files. + */ + tupleContext = AllocSetContextCreate(CurrentMemoryContext, + "COPY Row Memory Context", + ALLOCSET_DEFAULT_MINSIZE, + ALLOCSET_DEFAULT_INITSIZE, + ALLOCSET_DEFAULT_MAXSIZE); + + /* we use a PG_TRY block to roll back on errors (e.g. in NextCopyFrom) */ + PG_TRY(); + { + while (true) + { + bool nextRowFound = false; + Datum partitionColumnValue = 0; + ShardInterval *shardInterval = NULL; + int64 shardId = 0; + ShardConnections *shardConnections = NULL; + bool found = false; + StringInfo lineBuf = NULL; + MemoryContext oldContext = NULL; + + oldContext = MemoryContextSwitchTo(tupleContext); + + /* parse a row from the input */ + nextRowFound = NextCopyFrom(copyState, NULL, columnValues, columnNulls, NULL); + + MemoryContextSwitchTo(oldContext); + + if (!nextRowFound) + { + MemoryContextReset(tupleContext); + break; + } + + CHECK_FOR_INTERRUPTS(); + + /* find the partition column value */ + + if (columnNulls[partitionColumn->varattno - 1]) + { + ereport(ERROR, (errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED), + errmsg("cannot copy row with NULL value " + "in partition column"))); + } + + partitionColumnValue = columnValues[partitionColumn->varattno - 1]; + + /* find the shard interval and id for the partition column value */ + shardInterval = FindShardInterval(partitionColumnValue, shardIntervalCache, + shardCount, partitionMethod, + compareFunction, hashFunction, + useBinarySearch); + if (shardInterval == NULL) + { + ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), + errmsg("no shard for partition column value"))); + } + + shardId = shardInterval->shardId; + + /* find the connections to the shard placements */ + shardConnections = (ShardConnections *) hash_search(shardConnectionHash, + &shardInterval->shardId, + HASH_ENTER, + &found); + if (!found) + { + /* intialize COPY transactions on shard placements */ + shardConnections->shardId = shardId; + shardConnections->connectionList = NIL; + + OpenCopyTransactions(copyStatement, shardConnections, shardId); + } + + /* get the (truncated) line buffer */ + lineBuf = ©State->line_buf; + lineBuf->data[lineBuf->len++] = '\n'; + + /* Replicate row to all shard placements */ + CopyRowToPlacements(lineBuf, shardConnections); + + processedRowCount += 1; + + MemoryContextReset(tupleContext); + } + + connectionList = ConnectionList(shardConnectionHash); + + EndRemoteCopy(connectionList, true); + + if (CopyTransactionManager == TRANSACTION_MANAGER_2PC) + { + PrepareTransactions(connectionList); + } + + CHECK_FOR_INTERRUPTS(); + } + PG_CATCH(); + { + EndCopyFrom(copyState); + + /* roll back all transactions */ + connectionList = ConnectionList(shardConnectionHash); + EndRemoteCopy(connectionList, false); + AbortTransactions(connectionList); + CloseConnections(connectionList); + + PG_RE_THROW(); + } + PG_END_TRY(); + + EndCopyFrom(copyState); + heap_close(rel, NoLock); + + error_context_stack = errorCallback.previous; + + CommitTransactions(connectionList); + CloseConnections(connectionList); + + if (completionTag != NULL) + { + snprintf(completionTag, COMPLETION_TAG_BUFSIZE, + "COPY " UINT64_FORMAT, processedRowCount); + } +} + + +/* + * CreateShardConnectionHash constructs a hash table used for shardId->Connection + * mapping. + */ +static HTAB * +CreateShardConnectionHash(void) +{ + HTAB *shardConnectionsHash = NULL; + HASHCTL info; + + memset(&info, 0, sizeof(info)); + info.keysize = sizeof(int64); + info.entrysize = sizeof(ShardConnections); + info.hash = tag_hash; + + shardConnectionsHash = hash_create("Shard Connections Hash", + INITIAL_CONNECTION_CACHE_SIZE, &info, + HASH_ELEM | HASH_FUNCTION); + + return shardConnectionsHash; +} + + +/* + * CompareShardIntervalsById is a comparison function for sort shard + * intervals by their shard ID. + */ +static int +CompareShardIntervalsById(const void *leftElement, const void *rightElement) +{ + ShardInterval *leftInterval = *((ShardInterval **) leftElement); + ShardInterval *rightInterval = *((ShardInterval **) rightElement); + int64 leftShardId = leftInterval->shardId; + int64 rightShardId = rightInterval->shardId; + + /* we compare 64-bit integers, instead of casting their difference to int */ + if (leftShardId > rightShardId) + { + return 1; + } + else if (leftShardId < rightShardId) + { + return -1; + } + else + { + return 0; + } +} + + +/* + * ShardIntervalCompareFunction returns the appropriate compare function for the + * partition column type. In case of hash-partitioning, it always returns the compare + * function for integers. + */ +static FmgrInfo * +ShardIntervalCompareFunction(Var *partitionColumn, char partitionMethod) +{ + FmgrInfo *compareFunction = NULL; + + if (partitionMethod == DISTRIBUTE_BY_HASH) + { + compareFunction = GetFunctionInfo(INT4OID, BTREE_AM_OID, BTORDER_PROC); + } + else + { + compareFunction = GetFunctionInfo(partitionColumn->vartype, + BTREE_AM_OID, BTORDER_PROC); + } + + return compareFunction; +} + + +/* + * IsUniformHashDistribution determines whether the given list of sorted shards + * has a uniform hash distribution, as produced by master_create_worker_shards. + */ +static bool +IsUniformHashDistribution(ShardInterval **shardIntervalArray, int shardCount) +{ + uint64 hashTokenIncrement = HASH_TOKEN_COUNT / shardCount; + int shardIndex = 0; + + for (shardIndex = 0; shardIndex < shardCount; shardIndex++) + { + ShardInterval *shardInterval = shardIntervalArray[shardIndex]; + int32 shardMinHashToken = INT32_MIN + (shardIndex * hashTokenIncrement); + int32 shardMaxHashToken = shardMinHashToken + (hashTokenIncrement - 1); + + if (shardIndex == (shardCount - 1)) + { + shardMaxHashToken = INT32_MAX; + } + + if (DatumGetInt32(shardInterval->minValue) != shardMinHashToken || + DatumGetInt32(shardInterval->maxValue) != shardMaxHashToken) + { + return false; + } + } + + return true; +} + + +/* + * FindShardInterval finds a single shard interval in the cache for the + * given partition column value. + */ +static ShardInterval * +FindShardInterval(Datum partitionColumnValue, ShardInterval **shardIntervalCache, + int shardCount, char partitionMethod, FmgrInfo *compareFunction, + FmgrInfo *hashFunction, bool useBinarySearch) +{ + ShardInterval *shardInterval = NULL; + + if (partitionMethod == DISTRIBUTE_BY_HASH) + { + int hashedValue = DatumGetInt32(FunctionCall1(hashFunction, + partitionColumnValue)); + if (useBinarySearch) + { + shardInterval = SearchCachedShardInterval(Int32GetDatum(hashedValue), + shardIntervalCache, shardCount, + compareFunction); + } + else + { + uint64 hashTokenIncrement = HASH_TOKEN_COUNT / shardCount; + int shardIndex = (uint32) (hashedValue - INT32_MIN) / hashTokenIncrement; + + shardInterval = shardIntervalCache[shardIndex]; + } + } + else + { + shardInterval = SearchCachedShardInterval(partitionColumnValue, + shardIntervalCache, shardCount, + compareFunction); + } + + return shardInterval; +} + + +/* + * SearchCachedShardInterval performs a binary search for a shard interval matching a + * given partition column value and returns it. + */ +static ShardInterval * +SearchCachedShardInterval(Datum partitionColumnValue, ShardInterval **shardIntervalCache, + int shardCount, FmgrInfo *compareFunction) +{ + int lowerBoundIndex = 0; + int upperBoundIndex = shardCount; + + while (lowerBoundIndex < upperBoundIndex) + { + int middleIndex = (lowerBoundIndex + upperBoundIndex) >> 1; + int maxValueComparison = 0; + int minValueComparison = 0; + + minValueComparison = FunctionCall2Coll(compareFunction, + DEFAULT_COLLATION_OID, + partitionColumnValue, + shardIntervalCache[middleIndex]->minValue); + + if (DatumGetInt32(minValueComparison) < 0) + { + upperBoundIndex = middleIndex; + continue; + } + + maxValueComparison = FunctionCall2Coll(compareFunction, + DEFAULT_COLLATION_OID, + partitionColumnValue, + shardIntervalCache[middleIndex]->maxValue); + + if (DatumGetInt32(maxValueComparison) <= 0) + { + return shardIntervalCache[middleIndex]; + } + + lowerBoundIndex = middleIndex + 1; + } + + return NULL; +} + + +/* + * OpenCopyTransactions opens a connection for each placement of a shard and + * starts a COPY transaction. If a connection cannot be opened, then the shard + * placement is marked as inactive and the COPY continues with the remaining + * shard placements. + */ +static void +OpenCopyTransactions(CopyStmt *copyStatement, ShardConnections *shardConnections, + int64 shardId) +{ + List *finalizedPlacementList = NIL; + List *failedPlacementList = NIL; + ListCell *placementCell = NULL; + ListCell *failedPlacementCell = NULL; + List *connectionList = NIL; + + finalizedPlacementList = FinalizedShardPlacementList(shardId); + + foreach(placementCell, finalizedPlacementList) + { + ShardPlacement *placement = (ShardPlacement *) lfirst(placementCell); + char *nodeName = placement->nodeName; + int nodePort = placement->nodePort; + TransactionConnection *transactionConnection = NULL; + StringInfo copyCommand = NULL; + PGresult *result = NULL; + + PGconn *connection = ConnectToNode(nodeName, nodePort); + if (connection == NULL) + { + failedPlacementList = lappend(failedPlacementList, placement); + continue; + } + + result = PQexec(connection, "BEGIN"); + if (PQresultStatus(result) != PGRES_COMMAND_OK) + { + ReportRemoteError(connection, result); + failedPlacementList = lappend(failedPlacementList, placement); + continue; + } + + copyCommand = ConstructCopyStatement(copyStatement, shardId); + + result = PQexec(connection, copyCommand->data); + if (PQresultStatus(result) != PGRES_COPY_IN) + { + ReportRemoteError(connection, result); + failedPlacementList = lappend(failedPlacementList, placement); + continue; + } + + transactionConnection = palloc0(sizeof(TransactionConnection)); + + transactionConnection->connectionId = shardId; + transactionConnection->transactionState = TRANSACTION_STATE_COPY_STARTED; + transactionConnection->connection = connection; + + connectionList = lappend(connectionList, transactionConnection); + } + + /* if all placements failed, error out */ + if (list_length(failedPlacementList) == list_length(finalizedPlacementList)) + { + ereport(ERROR, (errmsg("could not modify any active placements"))); + } + + /* otherwise, mark failed placements as inactive: they're stale */ + foreach(failedPlacementCell, failedPlacementList) + { + ShardPlacement *failedPlacement = (ShardPlacement *) lfirst(failedPlacementCell); + uint64 shardLength = 0; + + DeleteShardPlacementRow(failedPlacement->shardId, failedPlacement->nodeName, + failedPlacement->nodePort); + InsertShardPlacementRow(failedPlacement->shardId, FILE_INACTIVE, shardLength, + failedPlacement->nodeName, failedPlacement->nodePort); + } + + shardConnections->connectionList = connectionList; +} + + +/* + * ConstructCopyStattement constructs the text of a COPY statement for a particular + * shard. + */ +static StringInfo +ConstructCopyStatement(CopyStmt *copyStatement, int64 shardId) +{ + StringInfo command = makeStringInfo(); + char *qualifiedName = NULL; + + qualifiedName = quote_qualified_identifier(copyStatement->relation->schemaname, + copyStatement->relation->relname); + + appendStringInfo(command, "COPY %s_%ld ", qualifiedName, shardId); + + if (copyStatement->attlist != NIL) + { + AppendColumnNames(command, copyStatement->attlist); + } + + appendStringInfoString(command, "FROM STDIN"); + + if (copyStatement->options) + { + appendStringInfoString(command, " WITH "); + + AppendCopyOptions(command, copyStatement->options); + } + + return command; +} + + +/* + * AppendCopyOptions deparses a list of CopyStmt options and appends them to command. + */ +static void +AppendCopyOptions(StringInfo command, List *copyOptionList) +{ + ListCell *optionCell = NULL; + char separator = '('; + + foreach(optionCell, copyOptionList) + { + DefElem *option = (DefElem *) lfirst(optionCell); + + if (strcmp(option->defname, "header") == 0 && defGetBoolean(option)) + { + /* worker should not skip header again */ + continue; + } + + appendStringInfo(command, "%c%s ", separator, option->defname); + + if (strcmp(option->defname, "force_not_null") == 0 || + strcmp(option->defname, "force_null") == 0) + { + if (!option->arg) + { + ereport(ERROR, + (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + errmsg( + "argument to option \"%s\" must be a list of column names", + option->defname))); + } + else + { + AppendColumnNames(command, (List *) option->arg); + } + } + else + { + appendStringInfo(command, "'%s'", defGetString(option)); + } + + separator = ','; + } + + appendStringInfoChar(command, ')'); +} + + +/* + * AppendColumnList deparses a list of column names into a StringInfo. + */ +static void +AppendColumnNames(StringInfo command, List *columnList) +{ + ListCell *attributeCell = NULL; + char separator = '('; + + foreach(attributeCell, columnList) + { + char *columnName = strVal(lfirst(attributeCell)); + appendStringInfo(command, "%c%s", separator, quote_identifier(columnName)); + separator = ','; + } + + appendStringInfoChar(command, ')'); +} + + +/* + * CopyRowToPlacements copies a row to a list of placements for a shard. + */ +static void +CopyRowToPlacements(StringInfo lineBuf, ShardConnections *shardConnections) +{ + ListCell *connectionCell = NULL; + foreach(connectionCell, shardConnections->connectionList) + { + TransactionConnection *transactionConnection = + (TransactionConnection *) lfirst(connectionCell); + PGconn *connection = transactionConnection->connection; + int64 shardId = shardConnections->shardId; + + /* copy the line buffer into the placement */ + int copyResult = PQputCopyData(connection, lineBuf->data, lineBuf->len); + if (copyResult != 1) + { + char *nodeName = ConnectionGetOptionValue(connection, "host"); + char *nodePort = ConnectionGetOptionValue(connection, "port"); + ereport(ERROR, (errcode(ERRCODE_IO_ERROR), + errmsg("Failed to COPY to shard %ld on %s:%s", + shardId, nodeName, nodePort))); + } + } +} + + +/* + * ConnectionList flattens the connection hash to a list of placement connections. + */ +static List * +ConnectionList(HTAB *connectionHash) +{ + List *connectionList = NIL; + HASH_SEQ_STATUS status; + ShardConnections *shardConnections = NULL; + + hash_seq_init(&status, connectionHash); + + shardConnections = (ShardConnections *) hash_seq_search(&status); + while (shardConnections != NULL) + { + List *shardConnectionsList = list_copy(shardConnections->connectionList); + connectionList = list_concat(connectionList, shardConnectionsList); + + shardConnections = (ShardConnections *) hash_seq_search(&status); + } + + return connectionList; +} + + +/* + * EndRemoteCopy ends the COPY input on all connections. If stopOnFailure + * is true, then EndRemoteCopy reports an error on failure, otherwise it + * reports a warning or continues. + */ +static void +EndRemoteCopy(List *connectionList, bool stopOnFailure) +{ + ListCell *connectionCell = NULL; + + foreach(connectionCell, connectionList) + { + TransactionConnection *transactionConnection = + (TransactionConnection *) lfirst(connectionCell); + PGconn *connection = transactionConnection->connection; + int64 shardId = transactionConnection->connectionId; + char *nodeName = ConnectionGetOptionValue(connection, "host"); + char *nodePort = ConnectionGetOptionValue(connection, "port"); + int copyEndResult = 0; + PGresult *result = NULL; + + if (transactionConnection->transactionState != TRANSACTION_STATE_COPY_STARTED) + { + /* COPY already ended during the prepare phase */ + continue; + } + + /* end the COPY input */ + copyEndResult = PQputCopyEnd(connection, NULL); + transactionConnection->transactionState = TRANSACTION_STATE_OPEN; + + if (copyEndResult != 1) + { + if (stopOnFailure) + { + ereport(ERROR, (errcode(ERRCODE_IO_ERROR), + errmsg("Failed to COPY to shard %ld on %s:%s", + shardId, nodeName, nodePort))); + } + + continue; + } + + /* check whether there were any COPY errors */ + result = PQgetResult(connection); + if (PQresultStatus(result) != PGRES_COMMAND_OK && stopOnFailure) + { + ReportCopyError(connection, result); + } + + PQclear(result); + } +} + + +/* + * ReportCopyError tries to report a useful error message for the user from + * the remote COPY error messages. + */ +static void +ReportCopyError(PGconn *connection, PGresult *result) +{ + char *remoteMessage = PQresultErrorField(result, PG_DIAG_MESSAGE_PRIMARY); + + if (remoteMessage != NULL) + { + /* probably a constraint violation, show remote message and detail */ + char *remoteDetail = PQresultErrorField(result, PG_DIAG_MESSAGE_DETAIL); + + ereport(ERROR, (errmsg("%s", remoteMessage), + errdetail("%s", remoteDetail))); + } + else + { + /* probably a connection problem, get the message from the connection */ + char *lastNewlineIndex = NULL; + + remoteMessage = PQerrorMessage(connection); + lastNewlineIndex = strrchr(remoteMessage, '\n'); + + /* trim trailing newline, if any */ + if (lastNewlineIndex != NULL) + { + *lastNewlineIndex = '\0'; + } + + ereport(ERROR, (errmsg("%s", remoteMessage))); + } +} diff --git a/src/backend/distributed/executor/multi_utility.c b/src/backend/distributed/executor/multi_utility.c index c08819adf..f6d5af9e0 100644 --- a/src/backend/distributed/executor/multi_utility.c +++ b/src/backend/distributed/executor/multi_utility.c @@ -17,6 +17,7 @@ #include "commands/tablecmds.h" #include "distributed/master_protocol.h" #include "distributed/metadata_cache.h" +#include "distributed/multi_copy.h" #include "distributed/multi_utility.h" #include "distributed/multi_join_order.h" #include "distributed/transmit.h" @@ -50,7 +51,7 @@ static bool IsTransmitStmt(Node *parsetree); static void VerifyTransmitStmt(CopyStmt *copyStatement); /* Local functions forward declarations for processing distributed table commands */ -static Node * ProcessCopyStmt(CopyStmt *copyStatement); +static Node * ProcessCopyStmt(CopyStmt *copyStatement, char *completionTag); static Node * ProcessIndexStmt(IndexStmt *createIndexStatement, const char *createIndexCommand); static Node * ProcessDropIndexStmt(DropStmt *dropIndexStatement, @@ -122,7 +123,12 @@ multi_ProcessUtility(Node *parsetree, if (IsA(parsetree, CopyStmt)) { - parsetree = ProcessCopyStmt((CopyStmt *) parsetree); + parsetree = ProcessCopyStmt((CopyStmt *) parsetree, completionTag); + + if (parsetree == NULL) + { + return; + } } if (IsA(parsetree, IndexStmt)) @@ -300,7 +306,7 @@ VerifyTransmitStmt(CopyStmt *copyStatement) * COPYing from distributed tables and preventing unsupported actions. */ static Node * -ProcessCopyStmt(CopyStmt *copyStatement) +ProcessCopyStmt(CopyStmt *copyStatement, char *completionTag) { /* * We first check if we have a "COPY (query) TO filename". If we do, copy doesn't @@ -344,9 +350,8 @@ ProcessCopyStmt(CopyStmt *copyStatement) { if (copyStatement->is_from) { - ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED), - errmsg("cannot execute COPY FROM on a distributed table " - "on master node"))); + CitusCopyFrom(copyStatement, completionTag); + return NULL; } else if (!copyStatement->is_from) { diff --git a/src/backend/distributed/shared_library_init.c b/src/backend/distributed/shared_library_init.c index 8f2eac018..01b961066 100644 --- a/src/backend/distributed/shared_library_init.c +++ b/src/backend/distributed/shared_library_init.c @@ -20,13 +20,15 @@ #include "executor/executor.h" #include "distributed/master_protocol.h" #include "distributed/modify_planner.h" +#include "distributed/multi_copy.h" #include "distributed/multi_executor.h" #include "distributed/multi_explain.h" #include "distributed/multi_join_order.h" #include "distributed/multi_logical_optimizer.h" #include "distributed/multi_planner.h" -#include "distributed/multi_router_executor.h" +#include "distributed/multi_router_executor.h" #include "distributed/multi_server_executor.h" +#include "distributed/multi_transaction.h" #include "distributed/multi_utility.h" #include "distributed/task_tracker.h" #include "distributed/worker_manager.h" @@ -67,6 +69,12 @@ static const struct config_enum_entry shard_placement_policy_options[] = { { NULL, 0, false } }; +static const struct config_enum_entry transaction_manager_options[] = { + { "1pc", TRANSACTION_MANAGER_1PC, false }, + { "2pc", TRANSACTION_MANAGER_2PC, false }, + { NULL, 0, false } +}; + /* shared library initialization function */ void @@ -437,6 +445,21 @@ RegisterCitusConfigVariables(void) 0, NULL, NULL, NULL); + DefineCustomEnumVariable( + "citus.copy_transaction_manager", + gettext_noop("Sets the transaction manager for COPY into distributed tables."), + gettext_noop("When a failure occurs during when copying into a distributed " + "table, 2PC is required to ensure data is never lost. Change " + "this setting to '2pc' from its default '1pc' to enable 2PC." + "You must also set max_prepared_transactions on the worker " + "nodes. Recovery from failed 2PCs is currently manual."), + &CopyTransactionManager, + TRANSACTION_MANAGER_1PC, + transaction_manager_options, + PGC_USERSET, + 0, + NULL, NULL, NULL); + DefineCustomEnumVariable( "citus.task_assignment_policy", gettext_noop("Sets the policy to use when assigning tasks to worker nodes."), diff --git a/src/backend/distributed/utils/connection_cache.c b/src/backend/distributed/utils/connection_cache.c index 31a11155e..db419c457 100644 --- a/src/backend/distributed/utils/connection_cache.c +++ b/src/backend/distributed/utils/connection_cache.c @@ -39,8 +39,6 @@ static HTAB *NodeConnectionHash = NULL; /* local function forward declarations */ static HTAB * CreateNodeConnectionHash(void); -static PGconn * ConnectToNode(char *nodeName, char *nodePort); -static char * ConnectionGetOptionValue(PGconn *connection, char *optionKeyword); /* @@ -99,10 +97,7 @@ GetOrEstablishConnection(char *nodeName, int32 nodePort) if (needNewConnection) { - StringInfo nodePortString = makeStringInfo(); - appendStringInfo(nodePortString, "%d", nodePort); - - connection = ConnectToNode(nodeName, nodePortString->data); + connection = ConnectToNode(nodeName, nodePort); if (connection != NULL) { nodeConnectionEntry = hash_search(NodeConnectionHash, &nodeConnectionKey, @@ -264,8 +259,8 @@ CreateNodeConnectionHash(void) * We attempt to connect up to MAX_CONNECT_ATTEMPT times. After that we give up * and return NULL. */ -static PGconn * -ConnectToNode(char *nodeName, char *nodePort) +PGconn * +ConnectToNode(char *nodeName, int32 nodePort) { PGconn *connection = NULL; const char *clientEncoding = GetDatabaseEncodingName(); @@ -276,11 +271,14 @@ ConnectToNode(char *nodeName, char *nodePort) "host", "port", "fallback_application_name", "client_encoding", "connect_timeout", "dbname", NULL }; + char nodePortString[12]; const char *valueArray[] = { - nodeName, nodePort, "citus", clientEncoding, + nodeName, nodePortString, "citus", clientEncoding, CLIENT_CONNECT_TIMEOUT_SECONDS, dbname, NULL }; + sprintf(nodePortString, "%d", nodePort); + Assert(sizeof(keywordArray) == sizeof(valueArray)); for (attemptIndex = 0; attemptIndex < MAX_CONNECT_ATTEMPTS; attemptIndex++) @@ -313,7 +311,7 @@ ConnectToNode(char *nodeName, char *nodePort) * The function returns NULL if the connection has no setting for an option with * the provided keyword. */ -static char * +char * ConnectionGetOptionValue(PGconn *connection, char *optionKeyword) { char *optionValue = NULL; diff --git a/src/backend/distributed/utils/multi_transaction.c b/src/backend/distributed/utils/multi_transaction.c new file mode 100644 index 000000000..667542ada --- /dev/null +++ b/src/backend/distributed/utils/multi_transaction.c @@ -0,0 +1,211 @@ +/*------------------------------------------------------------------------- + * + * multi_transaction.c + * This file contains functions for managing 1PC or 2PC transactions + * across many shard placements. + * + * Copyright (c) 2016, Citus Data, Inc. + * + *------------------------------------------------------------------------- + */ + +#include "postgres.h" +#include "libpq-fe.h" +#include "miscadmin.h" + +#include "access/xact.h" +#include "distributed/connection_cache.h" +#include "distributed/multi_transaction.h" +#include "lib/stringinfo.h" +#include "nodes/pg_list.h" + + +/* Local functions forward declarations */ +static StringInfo BuildTransactionName(int connectionId); + + +/* + * PrepareTransactions prepares all transactions on connections in + * connectionList for commit if the 2PC transaction manager is enabled. + * On failure, it reports an error and stops. + */ +void +PrepareTransactions(List *connectionList) +{ + ListCell *connectionCell = NULL; + + foreach(connectionCell, connectionList) + { + TransactionConnection *transactionConnection = + (TransactionConnection *) lfirst(connectionCell); + PGconn *connection = transactionConnection->connection; + int64 connectionId = transactionConnection->connectionId; + + PGresult *result = NULL; + StringInfo command = makeStringInfo(); + StringInfo transactionName = BuildTransactionName(connectionId); + + appendStringInfo(command, "PREPARE TRANSACTION '%s'", transactionName->data); + + result = PQexec(connection, command->data); + if (PQresultStatus(result) != PGRES_COMMAND_OK) + { + /* a failure to prepare is an implicit rollback */ + transactionConnection->transactionState = TRANSACTION_STATE_CLOSED; + + ReportRemoteError(connection, result); + PQclear(result); + + ereport(ERROR, (errcode(ERRCODE_IO_ERROR), + errmsg("Failed to prepare transaction"))); + } + + PQclear(result); + + transactionConnection->transactionState = TRANSACTION_STATE_PREPARED; + } +} + + +/* + * AbortTransactions aborts all transactions on connections in connectionList. + * On failure, it reports a warning and continues to abort all of them. + */ +void +AbortTransactions(List *connectionList) +{ + ListCell *connectionCell = NULL; + + foreach(connectionCell, connectionList) + { + TransactionConnection *transactionConnection = + (TransactionConnection *) lfirst(connectionCell); + PGconn *connection = transactionConnection->connection; + int64 connectionId = transactionConnection->connectionId; + char *nodeName = ConnectionGetOptionValue(connection, "host"); + char *nodePort = ConnectionGetOptionValue(connection, "port"); + PGresult *result = NULL; + + if (transactionConnection->transactionState == TRANSACTION_STATE_PREPARED) + { + StringInfo command = makeStringInfo(); + StringInfo transactionName = BuildTransactionName(connectionId); + + appendStringInfo(command, "ROLLBACK PREPARED '%s'", transactionName->data); + + result = PQexec(connection, command->data); + if (PQresultStatus(result) != PGRES_COMMAND_OK) + { + /* log a warning so the user may abort the transaction later */ + ereport(WARNING, (errmsg("Failed to roll back prepared transaction '%s'", + transactionName->data), + errhint("Run ROLLBACK TRANSACTION '%s' on %s:%s", + transactionName->data, nodeName, nodePort))); + } + + PQclear(result); + } + else if (transactionConnection->transactionState == TRANSACTION_STATE_OPEN) + { + /* try to roll back cleanly, if it fails then we won't commit anyway */ + result = PQexec(connection, "ROLLBACK"); + PQclear(result); + } + + transactionConnection->transactionState = TRANSACTION_STATE_CLOSED; + } +} + + +/* + * CommitTransactions commits all transactions on connections in connectionList. + * On failure, it reports a warning and continues committing all of them. + */ +void +CommitTransactions(List *connectionList) +{ + ListCell *connectionCell = NULL; + + foreach(connectionCell, connectionList) + { + TransactionConnection *transactionConnection = + (TransactionConnection *) lfirst(connectionCell); + PGconn *connection = transactionConnection->connection; + int64 connectionId = transactionConnection->connectionId; + char *nodeName = ConnectionGetOptionValue(connection, "host"); + char *nodePort = ConnectionGetOptionValue(connection, "port"); + PGresult *result = NULL; + + if (transactionConnection->transactionState == TRANSACTION_STATE_PREPARED) + { + StringInfo command = makeStringInfo(); + StringInfo transactionName = BuildTransactionName(connectionId); + + /* we shouldn't be committing if any transactions are not prepared */ + Assert(transactionConnection->transactionState == TRANSACTION_STATE_PREPARED); + + appendStringInfo(command, "COMMIT PREPARED '%s'", transactionName->data); + + result = PQexec(connection, command->data); + if (PQresultStatus(result) != PGRES_COMMAND_OK) + { + /* log a warning so the user may commit the transaction later */ + ereport(WARNING, (errmsg("Failed to commit prepared transaction '%s'", + transactionName->data), + errhint("Run COMMIT TRANSACTION '%s' on %s:%s", + transactionName->data, nodeName, nodePort))); + } + } + else + { + /* we shouldn't be committing if any transactions are not open */ + Assert(transactionConnection->transactionState == TRANSACTION_STATE_OPEN); + + /* try to commit, if it fails then the user might lose data */ + result = PQexec(connection, "COMMIT"); + if (PQresultStatus(result) != PGRES_COMMAND_OK) + { + ereport(WARNING, (errmsg("Failed to commit transaction on %s:%s", + nodeName, nodePort))); + } + } + + PQclear(result); + + transactionConnection->transactionState = TRANSACTION_STATE_CLOSED; + } +} + + +/* + * BuildTransactionName constructs a unique transaction name from an ID. + */ +static StringInfo +BuildTransactionName(int connectionId) +{ + StringInfo commandString = makeStringInfo(); + + appendStringInfo(commandString, "citus_%d_%u_%d", MyProcPid, + GetCurrentTransactionId(), connectionId); + + return commandString; +} + + +/* + * CloseConnections closes all connections in connectionList. + */ +void +CloseConnections(List *connectionList) +{ + ListCell *connectionCell = NULL; + + foreach(connectionCell, connectionList) + { + TransactionConnection *transactionConnection = + (TransactionConnection *) lfirst(connectionCell); + PGconn *connection = transactionConnection->connection; + + PQfinish(connection); + } +} diff --git a/src/include/distributed/connection_cache.h b/src/include/distributed/connection_cache.h index db97a8b1b..e39a0c406 100644 --- a/src/include/distributed/connection_cache.h +++ b/src/include/distributed/connection_cache.h @@ -54,6 +54,8 @@ typedef struct NodeConnectionEntry extern PGconn * GetOrEstablishConnection(char *nodeName, int32 nodePort); extern void PurgeConnection(PGconn *connection); extern void ReportRemoteError(PGconn *connection, PGresult *result); +extern PGconn * ConnectToNode(char *nodeName, int nodePort); +extern char * ConnectionGetOptionValue(PGconn *connection, char *optionKeyword); #endif /* CONNECTION_CACHE_H */ diff --git a/src/include/distributed/multi_copy.h b/src/include/distributed/multi_copy.h new file mode 100644 index 000000000..279b8c165 --- /dev/null +++ b/src/include/distributed/multi_copy.h @@ -0,0 +1,27 @@ +/*------------------------------------------------------------------------- + * + * multi_copy.h + * Declarations for public functions and variables used in COPY for + * distributed tables. + * + * Copyright (c) 2016, Citus Data, Inc. + * + *------------------------------------------------------------------------- + */ + +#ifndef MULTI_COPY_H +#define MULTI_COPY_H + + +#include "nodes/parsenodes.h" + + +/* config variable managed via guc.c */ +extern int CopyTransactionManager; + + +/* function declarations for copying into a distributed table */ +extern void CitusCopyFrom(CopyStmt *copyStatement, char *completionTag); + + +#endif /* MULTI_COPY_H */ diff --git a/src/include/distributed/multi_transaction.h b/src/include/distributed/multi_transaction.h new file mode 100644 index 000000000..827f0fcaa --- /dev/null +++ b/src/include/distributed/multi_transaction.h @@ -0,0 +1,57 @@ +/*------------------------------------------------------------------------- + * + * multi_transaction.h + * Type and function declarations used in performing transactions across + * shard placements. + * + * Copyright (c) 2016, Citus Data, Inc. + * + *------------------------------------------------------------------------- + */ + +#ifndef MULTI_TRANSACTION_H +#define MULTI_TRANSACTION_H + + +#include "libpq-fe.h" +#include "lib/stringinfo.h" +#include "nodes/pg_list.h" + + +/* Enumeration that defines the different transaction managers available */ +typedef enum +{ + TRANSACTION_MANAGER_1PC = 0, + TRANSACTION_MANAGER_2PC = 1 +} TransactionManagerType; + +/* Enumeration that defines different remote transaction states */ +typedef enum +{ + TRANSACTION_STATE_INVALID = 0, + TRANSACTION_STATE_OPEN, + TRANSACTION_STATE_COPY_STARTED, + TRANSACTION_STATE_PREPARED, + TRANSACTION_STATE_CLOSED +} TransactionState; + +/* + * TransactionConnection represents a connection to a remote node which is + * used to perform a transaction on shard placements. + */ +typedef struct TransactionConnection +{ + int64 connectionId; + TransactionState transactionState; + PGconn *connection; +} TransactionConnection; + + +/* Functions declarations for transaction and connection management */ +extern void PrepareTransactions(List *connectionList); +extern void AbortTransactions(List *connectionList); +extern void CommitTransactions(List *connectionList); +extern void CloseConnections(List *connectionList); + + +#endif /* MULTI_TRANSACTION_H */ diff --git a/src/test/regress/expected/multi_utilities.out b/src/test/regress/expected/multi_utilities.out index 6deec6f9f..6aa4fa18e 100644 --- a/src/test/regress/expected/multi_utilities.out +++ b/src/test/regress/expected/multi_utilities.out @@ -18,9 +18,6 @@ SELECT master_create_worker_shards('sharded_table', 2, 1); COPY sharded_table TO STDOUT; COPY (SELECT COUNT(*) FROM sharded_table) TO STDOUT; 0 --- but COPY in is not -COPY sharded_table FROM STDIN; -ERROR: cannot execute COPY FROM on a distributed table on master node -- cursors may not involve distributed tables DECLARE all_sharded_rows CURSOR FOR SELECT * FROM sharded_table; ERROR: DECLARE CURSOR can only be used in transaction blocks diff --git a/src/test/regress/expected/multi_utility_statements.out b/src/test/regress/expected/multi_utility_statements.out index ca5d7b4ad..1b502f20b 100644 --- a/src/test/regress/expected/multi_utility_statements.out +++ b/src/test/regress/expected/multi_utility_statements.out @@ -188,9 +188,6 @@ VIETNAM RUSSIA UNITED KINGDOM UNITED STATES --- Ensure that preventing COPY FROM against distributed tables works -COPY customer FROM STDIN; -ERROR: cannot execute COPY FROM on a distributed table on master node -- Test that we can create on-commit drop tables, and also test creating with -- oids, along with changing column names BEGIN; diff --git a/src/test/regress/input/multi_copy.source b/src/test/regress/input/multi_copy.source new file mode 100644 index 000000000..92544b9a8 --- /dev/null +++ b/src/test/regress/input/multi_copy.source @@ -0,0 +1,142 @@ +-- +-- MULTI_COPY +-- +-- Create a new hash-partitioned table into which to COPY +CREATE TABLE customer_copy_hash ( + c_custkey integer, + c_name varchar(25) not null, + c_address varchar(40), + c_nationkey integer, + c_phone char(15), + c_acctbal decimal(15,2), + c_mktsegment char(10), + c_comment varchar(117), + primary key (c_custkey)); +SELECT master_create_distributed_table('customer_copy_hash', 'c_custkey', 'hash'); + +-- Test COPY into empty hash-partitioned table +COPY customer_copy_hash FROM '@abs_srcdir@/data/customer.1.data' WITH (DELIMITER '|'); + +SELECT master_create_worker_shards('customer_copy_hash', 64, 1); + +-- Test empty copy +COPY customer_copy_hash FROM STDIN; +\. + +-- Test syntax error +COPY customer_copy_hash (c_custkey,c_name) FROM STDIN; +1,customer1 +2,customer2, +notinteger,customernot +\. + +-- Confirm that no data was copied +SELECT count(*) FROM customer_copy_hash; + +-- Test primary key violation +COPY customer_copy_hash (c_custkey, c_name) FROM STDIN +WITH (FORMAT 'csv'); +1,customer1 +2,customer2 +2,customer2 +\. + +-- Confirm that no data was copied +SELECT count(*) FROM customer_copy_hash; + +-- Test headers option +COPY customer_copy_hash (c_custkey, c_name) FROM STDIN +WITH (FORMAT 'csv', HEADER true, FORCE_NULL (c_custkey)); +# header +1,customer1 +2,customer2 +3,customer3 +\. + +-- Confirm that only first row was skipped +SELECT count(*) FROM customer_copy_hash; + +-- Test force_not_null option +COPY customer_copy_hash (c_custkey, c_name, c_address) FROM STDIN +WITH (FORMAT 'csv', QUOTE '"', FORCE_NOT_NULL (c_address)); +"4","customer4","" +\. + +-- Confirm that value is not null +SELECT count(c_address) FROM customer_copy_hash WHERE c_custkey = 4; + +-- Test force_null option +COPY customer_copy_hash (c_custkey, c_name, c_address) FROM STDIN +WITH (FORMAT 'csv', QUOTE '"', FORCE_NULL (c_address)); +"5","customer5","" +\. + +-- Confirm that value is null +SELECT count(c_address) FROM customer_copy_hash WHERE c_custkey = 5; + +-- Test null violation +COPY customer_copy_hash (c_custkey, c_name) FROM STDIN +WITH (FORMAT 'csv'); +6,customer6 +7,customer7 +8, +\. + +-- Confirm that no data was copied +SELECT count(*) FROM customer_copy_hash; + +-- Test server-side copy from program +COPY customer_copy_hash (c_custkey, c_name) FROM PROGRAM 'echo 9 customer9' +WITH (DELIMITER ' '); + +-- Confirm that data was copied +SELECT count(*) FROM customer_copy_hash WHERE c_custkey = 9; + +-- Test server-side copy from file +COPY customer_copy_hash FROM '@abs_srcdir@/data/customer.2.data' WITH (DELIMITER '|'); + +-- Confirm that data was copied +SELECT count(*) FROM customer_copy_hash; + +-- Test client-side copy from file +\COPY customer_copy_hash FROM '@abs_srcdir@/data/customer.3.data' WITH (DELIMITER '|'); + +-- Confirm that data was copied +SELECT count(*) FROM customer_copy_hash; + +-- Create a new range-partitioned table into which to COPY +CREATE TABLE customer_copy_range ( + c_custkey integer, + c_name varchar(25), + c_address varchar(40), + c_nationkey integer, + c_phone char(15), + c_acctbal decimal(15,2), + c_mktsegment char(10), + c_comment varchar(117), + primary key (c_custkey)); + +SELECT master_create_distributed_table('customer_copy_range', 'c_custkey', 'range'); + +-- Test COPY into empty range-partitioned table +COPY customer_copy_range FROM '@abs_srcdir@/data/customer.1.data' WITH (DELIMITER '|'); + +SELECT master_create_empty_shard('customer_copy_range') AS new_shard_id +\gset +UPDATE pg_dist_shard SET shardminvalue = 1, shardmaxvalue = 500 +WHERE shardid = :new_shard_id; + +SELECT master_create_empty_shard('customer_copy_range') AS new_shard_id +\gset +UPDATE pg_dist_shard SET shardminvalue = 501, shardmaxvalue = 1000 +WHERE shardid = :new_shard_id; + +-- Test copy into range-partitioned table +COPY customer_copy_range FROM '@abs_srcdir@/data/customer.1.data' WITH (DELIMITER '|'); + +-- Check whether data went into the right shard (maybe) +SELECT min(c_custkey), max(c_custkey), avg(c_custkey), count(*) +FROM customer_copy_range WHERE c_custkey <= 500; + +-- Check whether data was copied +SELECT count(*) FROM customer_copy_range; diff --git a/src/test/regress/multi_schedule b/src/test/regress/multi_schedule index c690f5da2..ae71426f4 100644 --- a/src/test/regress/multi_schedule +++ b/src/test/regress/multi_schedule @@ -100,6 +100,7 @@ test: multi_append_table_to_shard # --------- test: multi_outer_join +# # --- # Tests covering mostly modification queries and required preliminary # functionality related to metadata, shard creation, shard pruning and @@ -121,6 +122,11 @@ test: multi_create_insert_proxy test: multi_data_types test: multi_repartitioned_subquery_udf +# --------- +# multi_copy creates hash and range-partitioned tables and performs COPY +# --------- +test: multi_copy + # ---------- # multi_large_shardid stages more shards into lineitem # ---------- diff --git a/src/test/regress/output/multi_copy.source b/src/test/regress/output/multi_copy.source new file mode 100644 index 000000000..7edce8025 --- /dev/null +++ b/src/test/regress/output/multi_copy.source @@ -0,0 +1,174 @@ +-- +-- MULTI_COPY +-- +-- Create a new hash-partitioned table into which to COPY +CREATE TABLE customer_copy_hash ( + c_custkey integer, + c_name varchar(25) not null, + c_address varchar(40), + c_nationkey integer, + c_phone char(15), + c_acctbal decimal(15,2), + c_mktsegment char(10), + c_comment varchar(117), + primary key (c_custkey)); +SELECT master_create_distributed_table('customer_copy_hash', 'c_custkey', 'hash'); + master_create_distributed_table +--------------------------------- + +(1 row) + +-- Test COPY into empty hash-partitioned table +COPY customer_copy_hash FROM '@abs_srcdir@/data/customer.1.data' WITH (DELIMITER '|'); +ERROR: could not find any shards for query +DETAIL: No shards exist for distributed table "customer_copy_hash". +HINT: Run master_create_worker_shards to create shards and try again. +SELECT master_create_worker_shards('customer_copy_hash', 64, 1); + master_create_worker_shards +----------------------------- + +(1 row) + +-- Test empty copy +COPY customer_copy_hash FROM STDIN; +-- Test syntax error +COPY customer_copy_hash (c_custkey,c_name) FROM STDIN; +ERROR: invalid input syntax for integer: "1,customer1" +CONTEXT: COPY customer_copy_hash, line 1, column c_custkey: "1,customer1" +-- Confirm that no data was copied +SELECT count(*) FROM customer_copy_hash; + count +------- + 0 +(1 row) + +-- Test primary key violation +COPY customer_copy_hash (c_custkey, c_name) FROM STDIN +WITH (FORMAT 'csv'); +ERROR: duplicate key value violates unique constraint "customer_copy_hash_pkey_103160" +DETAIL: Key (c_custkey)=(2) already exists. +CONTEXT: COPY customer_copy_hash, line 4: "" +-- Confirm that no data was copied +SELECT count(*) FROM customer_copy_hash; + count +------- + 0 +(1 row) + +-- Test headers option +COPY customer_copy_hash (c_custkey, c_name) FROM STDIN +WITH (FORMAT 'csv', HEADER true, FORCE_NULL (c_custkey)); +-- Confirm that only first row was skipped +SELECT count(*) FROM customer_copy_hash; + count +------- + 3 +(1 row) + +-- Test force_not_null option +COPY customer_copy_hash (c_custkey, c_name, c_address) FROM STDIN +WITH (FORMAT 'csv', QUOTE '"', FORCE_NOT_NULL (c_address)); +-- Confirm that value is not null +SELECT count(c_address) FROM customer_copy_hash WHERE c_custkey = 4; + count +------- + 1 +(1 row) + +-- Test force_null option +COPY customer_copy_hash (c_custkey, c_name, c_address) FROM STDIN +WITH (FORMAT 'csv', QUOTE '"', FORCE_NULL (c_address)); +-- Confirm that value is null +SELECT count(c_address) FROM customer_copy_hash WHERE c_custkey = 5; + count +------- + 0 +(1 row) + +-- Test null violation +COPY customer_copy_hash (c_custkey, c_name) FROM STDIN +WITH (FORMAT 'csv'); +ERROR: null value in column "c_name" violates not-null constraint +DETAIL: Failing row contains (8, null, null, null, null, null, null, null). +CONTEXT: COPY customer_copy_hash, line 4: "" +-- Confirm that no data was copied +SELECT count(*) FROM customer_copy_hash; + count +------- + 5 +(1 row) + +-- Test server-side copy from program +COPY customer_copy_hash (c_custkey, c_name) FROM PROGRAM 'echo 9 customer9' +WITH (DELIMITER ' '); +-- Confirm that data was copied +SELECT count(*) FROM customer_copy_hash WHERE c_custkey = 9; + count +------- + 1 +(1 row) + +-- Test server-side copy from file +COPY customer_copy_hash FROM '@abs_srcdir@/data/customer.2.data' WITH (DELIMITER '|'); +-- Confirm that data was copied +SELECT count(*) FROM customer_copy_hash; + count +------- + 1006 +(1 row) + +-- Test client-side copy from file +\COPY customer_copy_hash FROM '@abs_srcdir@/data/customer.3.data' WITH (DELIMITER '|'); +-- Confirm that data was copied +SELECT count(*) FROM customer_copy_hash; + count +------- + 2006 +(1 row) + +-- Create a new range-partitioned table into which to COPY +CREATE TABLE customer_copy_range ( + c_custkey integer, + c_name varchar(25), + c_address varchar(40), + c_nationkey integer, + c_phone char(15), + c_acctbal decimal(15,2), + c_mktsegment char(10), + c_comment varchar(117), + primary key (c_custkey)); +SELECT master_create_distributed_table('customer_copy_range', 'c_custkey', 'range'); + master_create_distributed_table +--------------------------------- + +(1 row) + +-- Test COPY into empty range-partitioned table +COPY customer_copy_range FROM '@abs_srcdir@/data/customer.1.data' WITH (DELIMITER '|'); +ERROR: could not find any shards for query +DETAIL: No shards exist for distributed table "customer_copy_range". +SELECT master_create_empty_shard('customer_copy_range') AS new_shard_id +\gset +UPDATE pg_dist_shard SET shardminvalue = 1, shardmaxvalue = 500 +WHERE shardid = :new_shard_id; +SELECT master_create_empty_shard('customer_copy_range') AS new_shard_id +\gset +UPDATE pg_dist_shard SET shardminvalue = 501, shardmaxvalue = 1000 +WHERE shardid = :new_shard_id; +-- Test copy into range-partitioned table +COPY customer_copy_range FROM '@abs_srcdir@/data/customer.1.data' WITH (DELIMITER '|'); +-- Check whether data went into the right shard (maybe) +SELECT min(c_custkey), max(c_custkey), avg(c_custkey), count(*) +FROM customer_copy_range WHERE c_custkey <= 500; + min | max | avg | count +-----+-----+----------------------+------- + 1 | 500 | 250.5000000000000000 | 500 +(1 row) + +-- Check whether data was copied +SELECT count(*) FROM customer_copy_range; + count +------- + 1000 +(1 row) + diff --git a/src/test/regress/sql/multi_utilities.sql b/src/test/regress/sql/multi_utilities.sql index a503b2c0d..d2161f246 100644 --- a/src/test/regress/sql/multi_utilities.sql +++ b/src/test/regress/sql/multi_utilities.sql @@ -10,9 +10,6 @@ SELECT master_create_worker_shards('sharded_table', 2, 1); COPY sharded_table TO STDOUT; COPY (SELECT COUNT(*) FROM sharded_table) TO STDOUT; --- but COPY in is not -COPY sharded_table FROM STDIN; - -- cursors may not involve distributed tables DECLARE all_sharded_rows CURSOR FOR SELECT * FROM sharded_table; diff --git a/src/test/regress/sql/multi_utility_statements.sql b/src/test/regress/sql/multi_utility_statements.sql index ad67f2b36..e3ebdcec6 100644 --- a/src/test/regress/sql/multi_utility_statements.sql +++ b/src/test/regress/sql/multi_utility_statements.sql @@ -104,9 +104,6 @@ COPY nation TO STDOUT; -- ensure individual cols can be copied out, too COPY nation(n_name) TO STDOUT; --- Ensure that preventing COPY FROM against distributed tables works -COPY customer FROM STDIN; - -- Test that we can create on-commit drop tables, and also test creating with -- oids, along with changing column names From 1150ce641462afb1e461bf7e56baab19a393a5c3 Mon Sep 17 00:00:00 2001 From: Metin Doslu Date: Tue, 29 Mar 2016 18:25:33 +0200 Subject: [PATCH 2/2] Send COPY rows in binary format --- src/backend/distributed/commands/multi_copy.c | 969 +++++++++++------- .../distributed/executor/multi_utility.c | 5 +- .../distributed/utils/multi_transaction.c | 67 +- .../worker/worker_partition_protocol.c | 351 +------ src/include/distributed/multi_copy.h | 29 + src/include/distributed/multi_transaction.h | 7 +- src/include/distributed/worker_protocol.h | 22 - src/test/regress/input/multi_copy.source | 48 + src/test/regress/multi_schedule | 1 - src/test/regress/output/multi_copy.source | 61 +- 10 files changed, 848 insertions(+), 712 deletions(-) diff --git a/src/backend/distributed/commands/multi_copy.c b/src/backend/distributed/commands/multi_copy.c index 2e6c0ee14..5f90cf783 100644 --- a/src/backend/distributed/commands/multi_copy.c +++ b/src/backend/distributed/commands/multi_copy.c @@ -4,15 +4,38 @@ * This file contains implementation of COPY utility for distributed * tables. * - * Contributed by Konstantin Knizhnik, Postgres Professional + * The CitusCopyFrom function should be called from the utility hook to + * process COPY ... FROM commands on distributed tables. CitusCopyFrom + * parses the input from stdin, a program executed on the master, or a file + * on the master, and decides into which shard to put the data. It opens a + * new connection for every shard placement and uses the PQputCopyData + * function to copy the data. Because PQputCopyData transmits data, + * asynchronously, the workers will ingest data at least partially in + * parallel. + * + * When failing to connect to a worker, the master marks the placement for + * which it was trying to open a connection as inactive, similar to the way + * DML statements are handled. If a failure occurs after connecting, the + * transaction is rolled back on all the workers. + * + * By default, COPY uses normal transactions on the workers. This can cause + * a problem when some of the transactions fail to commit while others have + * succeeded. To ensure no data is lost, COPY can use two-phase commit, by + * increasing max_prepared_transactions on the worker and setting + * citus.copy_transaction_manager to '2pc'. The default is '1pc'. + * + * Parsing options are processed and enforced on the master, while + * constraints are enforced on the worker. In either case, failure causes + * the whole COPY to roll back. * * Copyright (c) 2016, Citus Data, Inc. * + * With contributions from Postgres Professional. + * *------------------------------------------------------------------------- */ #include "postgres.h" -#include "c.h" #include "fmgr.h" #include "funcapi.h" #include "libpq-fe.h" @@ -96,135 +119,8 @@ /* the transaction manager to use for COPY commands */ int CopyTransactionManager = TRANSACTION_MANAGER_1PC; - -/* Data structures from copy.c, to keep track of COPY processing state */ -typedef enum CopyDest -{ - COPY_FILE, /* to/from file (or a piped program) */ - COPY_OLD_FE, /* to/from frontend (2.0 protocol) */ - COPY_NEW_FE /* to/from frontend (3.0 protocol) */ -} CopyDest; - -typedef enum EolType -{ - EOL_UNKNOWN, - EOL_NL, - EOL_CR, - EOL_CRNL -} EolType; - -typedef struct CopyStateData -{ - /* low-level state data */ - CopyDest copy_dest; /* type of copy source/destination */ - FILE *copy_file; /* used if copy_dest == COPY_FILE */ - StringInfo fe_msgbuf; /* used for all dests during COPY TO, only for - * dest == COPY_NEW_FE in COPY FROM */ - bool fe_eof; /* true if detected end of copy data */ - EolType eol_type; /* EOL type of input */ - int file_encoding; /* file or remote side's character encoding */ - bool need_transcoding; /* file encoding diff from server? */ - bool encoding_embeds_ascii; /* ASCII can be non-first byte? */ - - /* parameters from the COPY command */ - Relation rel; /* relation to copy to or from */ - QueryDesc *queryDesc; /* executable query to copy from */ - List *attnumlist; /* integer list of attnums to copy */ - char *filename; /* filename, or NULL for STDIN/STDOUT */ - bool is_program; /* is 'filename' a program to popen? */ - bool binary; /* binary format? */ - bool oids; /* include OIDs? */ - bool freeze; /* freeze rows on loading? */ - bool csv_mode; /* Comma Separated Value format? */ - bool header_line; /* CSV header line? */ - char *null_print; /* NULL marker string (server encoding!) */ - int null_print_len; /* length of same */ - char *null_print_client; /* same converted to file encoding */ - char *delim; /* column delimiter (must be 1 byte) */ - char *quote; /* CSV quote char (must be 1 byte) */ - char *escape; /* CSV escape char (must be 1 byte) */ - List *force_quote; /* list of column names */ - bool force_quote_all; /* FORCE QUOTE *? */ - bool *force_quote_flags; /* per-column CSV FQ flags */ - List *force_notnull; /* list of column names */ - bool *force_notnull_flags; /* per-column CSV FNN flags */ -#if PG_VERSION_NUM >= 90400 - List *force_null; /* list of column names */ - bool *force_null_flags; /* per-column CSV FN flags */ -#endif - bool convert_selectively; /* do selective binary conversion? */ - List *convert_select; /* list of column names (can be NIL) */ - bool *convert_select_flags; /* per-column CSV/TEXT CS flags */ - - /* these are just for error messages, see CopyFromErrorCallback */ - const char *cur_relname; /* table name for error messages */ - int cur_lineno; /* line number for error messages */ - const char *cur_attname; /* current att for error messages */ - const char *cur_attval; /* current att value for error messages */ - - /* - * Working state for COPY TO/FROM - */ - MemoryContext copycontext; /* per-copy execution context */ - - /* - * Working state for COPY TO - */ - FmgrInfo *out_functions; /* lookup info for output functions */ - MemoryContext rowcontext; /* per-row evaluation context */ - - /* - * Working state for COPY FROM - */ - AttrNumber num_defaults; - bool file_has_oids; - FmgrInfo oid_in_function; - Oid oid_typioparam; - FmgrInfo *in_functions; /* array of input functions for each attrs */ - Oid *typioparams; /* array of element types for in_functions */ - int *defmap; /* array of default att numbers */ - ExprState **defexprs; /* array of default att expressions */ - bool volatile_defexprs; /* is any of defexprs volatile? */ - List *range_table; - - /* - * These variables are used to reduce overhead in textual COPY FROM. - * - * attribute_buf holds the separated, de-escaped text for each field of - * the current line. The CopyReadAttributes functions return arrays of - * pointers into this buffer. We avoid palloc/pfree overhead by re-using - * the buffer on each cycle. - */ - StringInfoData attribute_buf; - - /* field raw data pointers found by COPY FROM */ - - int max_fields; - char **raw_fields; - - /* - * Similarly, line_buf holds the whole input line being processed. The - * input cycle is first to read the whole line into line_buf, convert it - * to server encoding there, and then extract the individual attribute - * fields into attribute_buf. line_buf is preserved unmodified so that we - * can display it in error messages if appropriate. - */ - StringInfoData line_buf; - bool line_buf_converted; /* converted to server encoding? */ - bool line_buf_valid; /* contains the row being processed? */ - - /* - * Finally, raw_buf holds raw data read from the data source (file or - * client connection). CopyReadLine parses this data sufficiently to - * locate line boundaries, then transfers the data to line_buf and - * converts it. Note: we guarantee that there is a \0 at - * raw_buf[raw_buf_len]. - */ -#define RAW_BUF_SIZE 65536 /* we palloc RAW_BUF_SIZE+1 bytes */ - char *raw_buf; - int raw_buf_index; /* next byte to process */ - int raw_buf_len; /* total # of bytes stored */ -} CopyStateData; +/* constant used in binary protocol */ +static const char BinarySignature[11] = "PGCOPY\n\377\r\n\0"; /* ShardConnections represents a set of connections for each placement of a shard */ @@ -236,6 +132,7 @@ typedef struct ShardConnections /* Local functions forward declarations */ +static void LockAllShards(List *shardIntervalList); static HTAB * CreateShardConnectionHash(void); static int CompareShardIntervalsById(const void *leftElement, const void *rightElement); static bool IsUniformHashDistribution(ShardInterval **shardIntervalArray, @@ -251,16 +148,28 @@ static ShardInterval * SearchCachedShardInterval(Datum partitionColumnValue, ShardInterval **shardIntervalCache, int shardCount, FmgrInfo *compareFunction); +static ShardConnections * GetShardConnections(HTAB *shardConnectionHash, + int64 shardId, + bool *shardConnectionsFound); static void OpenCopyTransactions(CopyStmt *copyStatement, - ShardConnections *shardConnections, - int64 shardId); + ShardConnections *shardConnections); static StringInfo ConstructCopyStatement(CopyStmt *copyStatement, int64 shardId); -static void AppendColumnNames(StringInfo command, List *columnList); -static void AppendCopyOptions(StringInfo command, List *copyOptionList); -static void CopyRowToPlacements(StringInfo lineBuf, ShardConnections *shardConnections); +static void SendCopyDataToAll(StringInfo dataBuffer, List *connectionList); +static void SendCopyDataToPlacement(StringInfo dataBuffer, PGconn *connection, + int64 shardId); static List * ConnectionList(HTAB *connectionHash); static void EndRemoteCopy(List *connectionList, bool stopOnFailure); static void ReportCopyError(PGconn *connection, PGresult *result); +static uint32 AvailableColumnCount(TupleDesc tupleDescriptor); + +/* Private functions copied and adapted from copy.c in PostgreSQL */ +static void CopySendData(CopyOutState outputState, const void *databuf, int datasize); +static void CopySendString(CopyOutState outputState, const char *str); +static void CopySendChar(CopyOutState outputState, char c); +static void CopySendInt32(CopyOutState outputState, int32 val); +static void CopySendInt16(CopyOutState outputState, int16 val); +static void CopyAttributeOutText(CopyOutState outputState, char *string); +static inline void CopyFlushOutput(CopyOutState outputState, char *start, char *pointer); /* @@ -270,30 +179,36 @@ static void ReportCopyError(PGconn *connection, PGresult *result); void CitusCopyFrom(CopyStmt *copyStatement, char *completionTag) { - RangeVar *relation = copyStatement->relation; - Oid tableId = RangeVarGetRelid(relation, NoLock, false); + Oid tableId = RangeVarGetRelid(copyStatement->relation, NoLock, false); char *relationName = get_rel_name(tableId); - List *shardIntervalList = NULL; - ListCell *shardIntervalCell = NULL; + Relation distributedRelation = NULL; char partitionMethod = '\0'; Var *partitionColumn = NULL; - HTAB *shardConnectionHash = NULL; - List *connectionList = NIL; - MemoryContext tupleContext = NULL; - CopyState copyState = NULL; TupleDesc tupleDescriptor = NULL; uint32 columnCount = 0; Datum *columnValues = NULL; bool *columnNulls = NULL; - Relation rel = NULL; - ShardInterval **shardIntervalCache = NULL; - bool useBinarySearch = false; TypeCacheEntry *typeEntry = NULL; FmgrInfo *hashFunction = NULL; FmgrInfo *compareFunction = NULL; + int shardCount = 0; + List *shardIntervalList = NULL; + ShardInterval **shardIntervalCache = NULL; + bool useBinarySearch = false; + + HTAB *shardConnectionHash = NULL; + ShardConnections *shardConnections = NULL; + List *connectionList = NIL; + + EState *executorState = NULL; + MemoryContext executorTupleContext = NULL; + ExprContext *executorExpressionContext = NULL; + + CopyState copyState = NULL; + CopyOutState copyOutState = NULL; + FmgrInfo *columnOutputFunctions = NULL; uint64 processedRowCount = 0; - ErrorContextCallback errorCallback; /* disallow COPY to/from file or program except for superusers */ if (copyStatement->filename != NULL && !superuser()) @@ -333,8 +248,8 @@ CitusCopyFrom(CopyStmt *copyStatement, char *completionTag) compareFunction = ShardIntervalCompareFunction(partitionColumn, partitionMethod); /* allocate column values and nulls arrays */ - rel = heap_open(tableId, RowExclusiveLock); - tupleDescriptor = RelationGetDescr(rel); + distributedRelation = heap_open(tableId, RowExclusiveLock); + tupleDescriptor = RelationGetDescr(distributedRelation); columnCount = tupleDescriptor->natts; columnValues = palloc0(columnCount * sizeof(Datum)); columnNulls = palloc0(columnCount * sizeof(bool)); @@ -346,7 +261,7 @@ CitusCopyFrom(CopyStmt *copyStatement, char *completionTag) if (partitionMethod == DISTRIBUTE_BY_HASH) { ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), - errmsg("could not find any shards for query"), + errmsg("could not find any shards into which to copy"), errdetail("No shards exist for distributed table \"%s\".", relationName), errhint("Run master_create_worker_shards to create shards " @@ -355,29 +270,14 @@ CitusCopyFrom(CopyStmt *copyStatement, char *completionTag) else { ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), - errmsg("could not find any shards for query"), + errmsg("could not find any shards into which to copy"), errdetail("No shards exist for distributed table \"%s\".", relationName))); } } - /* create a mapping of shard id to a connection for each of its placements */ - shardConnectionHash = CreateShardConnectionHash(); - - /* lock shards in order of shard id to prevent deadlock */ - shardIntervalList = SortList(shardIntervalList, CompareShardIntervalsById); - - foreach(shardIntervalCell, shardIntervalList) - { - ShardInterval *shardInterval = (ShardInterval *) lfirst(shardIntervalCell); - int64 shardId = shardInterval->shardId; - - /* prevent concurrent changes to number of placements */ - LockShardDistributionMetadata(shardId, ShareLock); - - /* prevent concurrent update/delete statements */ - LockShardResource(shardId, ShareLock); - } + /* prevent concurrent placement changes and non-commutative DML statements */ + LockAllShards(shardIntervalList); /* initialize the shard interval cache */ shardCount = list_length(shardIntervalList); @@ -391,60 +291,65 @@ CitusCopyFrom(CopyStmt *copyStatement, char *completionTag) } /* initialize copy state to read from COPY data source */ - copyState = BeginCopyFrom(rel, copyStatement->filename, + copyState = BeginCopyFrom(distributedRelation, + copyStatement->filename, copyStatement->is_program, copyStatement->attlist, copyStatement->options); - if (copyState->binary) - { - ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED), - errmsg("Copy in binary mode is not currently supported"))); - } + executorState = CreateExecutorState(); + executorTupleContext = GetPerTupleMemoryContext(executorState); + executorExpressionContext = GetPerTupleExprContext(executorState); - /* set up callback to identify error line number */ - errorCallback.callback = CopyFromErrorCallback; - errorCallback.arg = (void *) copyState; - errorCallback.previous = error_context_stack; - error_context_stack = &errorCallback; + copyOutState = (CopyOutState) palloc0(sizeof(CopyOutStateData)); + copyOutState->binary = true; + copyOutState->fe_msgbuf = makeStringInfo(); + copyOutState->rowcontext = executorTupleContext; + + columnOutputFunctions = ColumnOutputFunctions(tupleDescriptor, copyOutState->binary); /* - * We create a new memory context called tuple context, and read and write - * each row's values within this memory context. After each read and write, - * we reset the memory context. That way, we immediately release memory - * allocated for each row, and don't bloat memory usage with large input - * files. + * Create a mapping of shard id to a connection for each of its placements. + * The hash should be initialized before the PG_TRY, since it is used and + * PG_CATCH. Otherwise, it may be undefined in the PG_CATCH (see sigsetjmp + * documentation). */ - tupleContext = AllocSetContextCreate(CurrentMemoryContext, - "COPY Row Memory Context", - ALLOCSET_DEFAULT_MINSIZE, - ALLOCSET_DEFAULT_INITSIZE, - ALLOCSET_DEFAULT_MAXSIZE); + shardConnectionHash = CreateShardConnectionHash(); /* we use a PG_TRY block to roll back on errors (e.g. in NextCopyFrom) */ PG_TRY(); { + ErrorContextCallback errorCallback; + + /* set up callback to identify error line number */ + errorCallback.callback = CopyFromErrorCallback; + errorCallback.arg = (void *) copyState; + errorCallback.previous = error_context_stack; + error_context_stack = &errorCallback; + + /* ensure transactions have unique names on worker nodes */ + InitializeDistributedTransaction(); + while (true) { bool nextRowFound = false; Datum partitionColumnValue = 0; ShardInterval *shardInterval = NULL; int64 shardId = 0; - ShardConnections *shardConnections = NULL; - bool found = false; - StringInfo lineBuf = NULL; + bool shardConnectionsFound = false; MemoryContext oldContext = NULL; - oldContext = MemoryContextSwitchTo(tupleContext); + ResetPerTupleExprContext(executorState); + + oldContext = MemoryContextSwitchTo(executorTupleContext); /* parse a row from the input */ - nextRowFound = NextCopyFrom(copyState, NULL, columnValues, columnNulls, NULL); - - MemoryContextSwitchTo(oldContext); + nextRowFound = NextCopyFrom(copyState, executorExpressionContext, + columnValues, columnNulls, NULL); if (!nextRowFound) { - MemoryContextReset(tupleContext); + MemoryContextSwitchTo(oldContext); break; } @@ -469,68 +374,83 @@ CitusCopyFrom(CopyStmt *copyStatement, char *completionTag) if (shardInterval == NULL) { ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE), - errmsg("no shard for partition column value"))); + errmsg("could not find shard for partition column " + "value"))); } shardId = shardInterval->shardId; - /* find the connections to the shard placements */ - shardConnections = (ShardConnections *) hash_search(shardConnectionHash, - &shardInterval->shardId, - HASH_ENTER, - &found); - if (!found) - { - /* intialize COPY transactions on shard placements */ - shardConnections->shardId = shardId; - shardConnections->connectionList = NIL; + MemoryContextSwitchTo(oldContext); - OpenCopyTransactions(copyStatement, shardConnections, shardId); + /* get existing connections to the shard placements, if any */ + shardConnections = GetShardConnections(shardConnectionHash, + shardId, + &shardConnectionsFound); + if (!shardConnectionsFound) + { + /* open connections and initiate COPY on shard placements */ + OpenCopyTransactions(copyStatement, shardConnections); + + /* send binary headers to shard placements */ + resetStringInfo(copyOutState->fe_msgbuf); + AppendCopyBinaryHeaders(copyOutState); + SendCopyDataToAll(copyOutState->fe_msgbuf, + shardConnections->connectionList); } - /* get the (truncated) line buffer */ - lineBuf = ©State->line_buf; - lineBuf->data[lineBuf->len++] = '\n'; - - /* Replicate row to all shard placements */ - CopyRowToPlacements(lineBuf, shardConnections); + /* replicate row to shard placements */ + resetStringInfo(copyOutState->fe_msgbuf); + AppendCopyRowData(columnValues, columnNulls, tupleDescriptor, + copyOutState, columnOutputFunctions); + SendCopyDataToAll(copyOutState->fe_msgbuf, shardConnections->connectionList); processedRowCount += 1; - - MemoryContextReset(tupleContext); } connectionList = ConnectionList(shardConnectionHash); + /* send binary footers to all shard placements */ + resetStringInfo(copyOutState->fe_msgbuf); + AppendCopyBinaryFooters(copyOutState); + SendCopyDataToAll(copyOutState->fe_msgbuf, connectionList); + + /* all lines have been copied, stop showing line number in errors */ + error_context_stack = errorCallback.previous; + + /* close the COPY input on all shard placements */ EndRemoteCopy(connectionList, true); if (CopyTransactionManager == TRANSACTION_MANAGER_2PC) { - PrepareTransactions(connectionList); + PrepareRemoteTransactions(connectionList); } + EndCopyFrom(copyState); + heap_close(distributedRelation, NoLock); + + /* check for cancellation one last time before committing */ CHECK_FOR_INTERRUPTS(); } PG_CATCH(); { - EndCopyFrom(copyState); + List *abortConnectionList = NIL; /* roll back all transactions */ - connectionList = ConnectionList(shardConnectionHash); - EndRemoteCopy(connectionList, false); - AbortTransactions(connectionList); - CloseConnections(connectionList); + abortConnectionList = ConnectionList(shardConnectionHash); + EndRemoteCopy(abortConnectionList, false); + AbortRemoteTransactions(abortConnectionList); + CloseConnections(abortConnectionList); PG_RE_THROW(); } PG_END_TRY(); - EndCopyFrom(copyState); - heap_close(rel, NoLock); - - error_context_stack = errorCallback.previous; - - CommitTransactions(connectionList); + /* + * Ready to commit the transaction, this code is below the PG_TRY block because + * we do not want any of the transactions rolled back if a failure occurs. Instead, + * they should be rolled forward. + */ + CommitRemoteTransactions(connectionList); CloseConnections(connectionList); if (completionTag != NULL) @@ -541,6 +461,33 @@ CitusCopyFrom(CopyStmt *copyStatement, char *completionTag) } +/* + * LockAllShards takes shared locks on the metadata and the data of all shards in + * shardIntervalList. This prevents concurrent placement changes and concurrent + * DML statements that require an exclusive lock. + */ +static void +LockAllShards(List *shardIntervalList) +{ + ListCell *shardIntervalCell = NULL; + + /* lock shards in order of shard id to prevent deadlock */ + shardIntervalList = SortList(shardIntervalList, CompareShardIntervalsById); + + foreach(shardIntervalCell, shardIntervalList) + { + ShardInterval *shardInterval = (ShardInterval *) lfirst(shardIntervalCell); + int64 shardId = shardInterval->shardId; + + /* prevent concurrent changes to number of placements */ + LockShardDistributionMetadata(shardId, ShareLock); + + /* prevent concurrent update/delete statements */ + LockShardResource(shardId, ShareLock); + } +} + + /* * CreateShardConnectionHash constructs a hash table used for shardId->Connection * mapping. @@ -549,6 +496,7 @@ static HTAB * CreateShardConnectionHash(void) { HTAB *shardConnectionsHash = NULL; + int hashFlags = 0; HASHCTL info; memset(&info, 0, sizeof(info)); @@ -556,9 +504,10 @@ CreateShardConnectionHash(void) info.entrysize = sizeof(ShardConnections); info.hash = tag_hash; + hashFlags = HASH_ELEM | HASH_FUNCTION | HASH_CONTEXT; shardConnectionsHash = hash_create("Shard Connections Hash", INITIAL_CONNECTION_CACHE_SIZE, &info, - HASH_ELEM | HASH_FUNCTION); + hashFlags); return shardConnectionsHash; } @@ -606,11 +555,16 @@ ShardIntervalCompareFunction(Var *partitionColumn, char partitionMethod) { compareFunction = GetFunctionInfo(INT4OID, BTREE_AM_OID, BTORDER_PROC); } - else + else if (partitionMethod == DISTRIBUTE_BY_RANGE) { compareFunction = GetFunctionInfo(partitionColumn->vartype, BTREE_AM_OID, BTORDER_PROC); } + else + { + ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + errmsg("unsupported partition method %d", partitionMethod))); + } return compareFunction; } @@ -674,6 +628,18 @@ FindShardInterval(Datum partitionColumnValue, ShardInterval **shardIntervalCache uint64 hashTokenIncrement = HASH_TOKEN_COUNT / shardCount; int shardIndex = (uint32) (hashedValue - INT32_MIN) / hashTokenIncrement; + Assert(shardIndex <= shardCount); + + /* + * If the shard count is not power of 2, the range of the last + * shard becomes larger than others. For that extra piece of range, + * we still need to use the last shard. + */ + if (shardIndex == shardCount) + { + shardIndex = shardCount - 1; + } + shardInterval = shardIntervalCache[shardIndex]; } } @@ -701,7 +667,7 @@ SearchCachedShardInterval(Datum partitionColumnValue, ShardInterval **shardInter while (lowerBoundIndex < upperBoundIndex) { - int middleIndex = (lowerBoundIndex + upperBoundIndex) >> 1; + int middleIndex = (lowerBoundIndex + upperBoundIndex) / 2; int maxValueComparison = 0; int minValueComparison = 0; @@ -733,6 +699,31 @@ SearchCachedShardInterval(Datum partitionColumnValue, ShardInterval **shardInter } +/* + * GetShardConnections finds existing connections for a shard in the hash + * or opens new connections to each active placement and starts a (binary) COPY + * transaction on each of them. + */ +static ShardConnections * +GetShardConnections(HTAB *shardConnectionHash, int64 shardId, + bool *shardConnectionsFound) +{ + ShardConnections *shardConnections = NULL; + + shardConnections = (ShardConnections *) hash_search(shardConnectionHash, + &shardId, + HASH_ENTER, + shardConnectionsFound); + if (!*shardConnectionsFound) + { + shardConnections->shardId = shardId; + shardConnections->connectionList = NIL; + } + + return shardConnections; +} + + /* * OpenCopyTransactions opens a connection for each placement of a shard and * starts a COPY transaction. If a connection cannot be opened, then the shard @@ -740,16 +731,26 @@ SearchCachedShardInterval(Datum partitionColumnValue, ShardInterval **shardInter * shard placements. */ static void -OpenCopyTransactions(CopyStmt *copyStatement, ShardConnections *shardConnections, - int64 shardId) +OpenCopyTransactions(CopyStmt *copyStatement, ShardConnections *shardConnections) { List *finalizedPlacementList = NIL; List *failedPlacementList = NIL; ListCell *placementCell = NULL; ListCell *failedPlacementCell = NULL; - List *connectionList = NIL; + List *connectionList = NULL; - finalizedPlacementList = FinalizedShardPlacementList(shardId); + MemoryContext localContext = AllocSetContextCreate(CurrentMemoryContext, + "OpenCopyTransactions", + ALLOCSET_DEFAULT_MINSIZE, + ALLOCSET_DEFAULT_INITSIZE, + ALLOCSET_DEFAULT_MAXSIZE); + + /* release finalized placement list at the end of this function */ + MemoryContext oldContext = MemoryContextSwitchTo(localContext); + + finalizedPlacementList = FinalizedShardPlacementList(shardConnections->shardId); + + MemoryContextSwitchTo(oldContext); foreach(placementCell, finalizedPlacementList) { @@ -761,6 +762,10 @@ OpenCopyTransactions(CopyStmt *copyStatement, ShardConnections *shardConnections PGresult *result = NULL; PGconn *connection = ConnectToNode(nodeName, nodePort); + + /* release failed placement list and copy command at the end of this function */ + oldContext = MemoryContextSwitchTo(localContext); + if (connection == NULL) { failedPlacementList = lappend(failedPlacementList, placement); @@ -775,7 +780,7 @@ OpenCopyTransactions(CopyStmt *copyStatement, ShardConnections *shardConnections continue; } - copyCommand = ConstructCopyStatement(copyStatement, shardId); + copyCommand = ConstructCopyStatement(copyStatement, shardConnections->shardId); result = PQexec(connection, copyCommand->data); if (PQresultStatus(result) != PGRES_COPY_IN) @@ -785,9 +790,12 @@ OpenCopyTransactions(CopyStmt *copyStatement, ShardConnections *shardConnections continue; } + /* preserve transaction connection in regular memory context */ + MemoryContextSwitchTo(oldContext); + transactionConnection = palloc0(sizeof(TransactionConnection)); - transactionConnection->connectionId = shardId; + transactionConnection->connectionId = shardConnections->shardId; transactionConnection->transactionState = TRANSACTION_STATE_COPY_STARTED; transactionConnection->connection = connection; @@ -797,7 +805,7 @@ OpenCopyTransactions(CopyStmt *copyStatement, ShardConnections *shardConnections /* if all placements failed, error out */ if (list_length(failedPlacementList) == list_length(finalizedPlacementList)) { - ereport(ERROR, (errmsg("could not modify any active placements"))); + ereport(ERROR, (errmsg("could not find any active placements"))); } /* otherwise, mark failed placements as inactive: they're stale */ @@ -813,11 +821,13 @@ OpenCopyTransactions(CopyStmt *copyStatement, ShardConnections *shardConnections } shardConnections->connectionList = connectionList; + + MemoryContextReset(localContext); } /* - * ConstructCopyStattement constructs the text of a COPY statement for a particular + * ConstructCopyStatement constructs the text of a COPY statement for a particular * shard. */ static StringInfo @@ -831,117 +841,47 @@ ConstructCopyStatement(CopyStmt *copyStatement, int64 shardId) appendStringInfo(command, "COPY %s_%ld ", qualifiedName, shardId); - if (copyStatement->attlist != NIL) - { - AppendColumnNames(command, copyStatement->attlist); - } - - appendStringInfoString(command, "FROM STDIN"); - - if (copyStatement->options) - { - appendStringInfoString(command, " WITH "); - - AppendCopyOptions(command, copyStatement->options); - } + appendStringInfoString(command, "FROM STDIN WITH (FORMAT BINARY)"); return command; } /* - * AppendCopyOptions deparses a list of CopyStmt options and appends them to command. + * SendCopyDataToAll sends copy data to all connections in a list. */ static void -AppendCopyOptions(StringInfo command, List *copyOptionList) -{ - ListCell *optionCell = NULL; - char separator = '('; - - foreach(optionCell, copyOptionList) - { - DefElem *option = (DefElem *) lfirst(optionCell); - - if (strcmp(option->defname, "header") == 0 && defGetBoolean(option)) - { - /* worker should not skip header again */ - continue; - } - - appendStringInfo(command, "%c%s ", separator, option->defname); - - if (strcmp(option->defname, "force_not_null") == 0 || - strcmp(option->defname, "force_null") == 0) - { - if (!option->arg) - { - ereport(ERROR, - (errcode(ERRCODE_INVALID_PARAMETER_VALUE), - errmsg( - "argument to option \"%s\" must be a list of column names", - option->defname))); - } - else - { - AppendColumnNames(command, (List *) option->arg); - } - } - else - { - appendStringInfo(command, "'%s'", defGetString(option)); - } - - separator = ','; - } - - appendStringInfoChar(command, ')'); -} - - -/* - * AppendColumnList deparses a list of column names into a StringInfo. - */ -static void -AppendColumnNames(StringInfo command, List *columnList) -{ - ListCell *attributeCell = NULL; - char separator = '('; - - foreach(attributeCell, columnList) - { - char *columnName = strVal(lfirst(attributeCell)); - appendStringInfo(command, "%c%s", separator, quote_identifier(columnName)); - separator = ','; - } - - appendStringInfoChar(command, ')'); -} - - -/* - * CopyRowToPlacements copies a row to a list of placements for a shard. - */ -static void -CopyRowToPlacements(StringInfo lineBuf, ShardConnections *shardConnections) +SendCopyDataToAll(StringInfo dataBuffer, List *connectionList) { ListCell *connectionCell = NULL; - foreach(connectionCell, shardConnections->connectionList) + foreach(connectionCell, connectionList) { TransactionConnection *transactionConnection = (TransactionConnection *) lfirst(connectionCell); PGconn *connection = transactionConnection->connection; - int64 shardId = shardConnections->shardId; + int64 shardId = transactionConnection->connectionId; - /* copy the line buffer into the placement */ - int copyResult = PQputCopyData(connection, lineBuf->data, lineBuf->len); - if (copyResult != 1) - { - char *nodeName = ConnectionGetOptionValue(connection, "host"); - char *nodePort = ConnectionGetOptionValue(connection, "port"); - ereport(ERROR, (errcode(ERRCODE_IO_ERROR), - errmsg("Failed to COPY to shard %ld on %s:%s", - shardId, nodeName, nodePort))); - } + SendCopyDataToPlacement(dataBuffer, connection, shardId); + } +} + + +/* + * SendCopyDataToPlacement sends serialized COPY data to a specific shard placement + * over the given connection. + */ +static void +SendCopyDataToPlacement(StringInfo dataBuffer, PGconn *connection, int64 shardId) +{ + int copyResult = PQputCopyData(connection, dataBuffer->data, dataBuffer->len); + if (copyResult != 1) + { + char *nodeName = ConnectionGetOptionValue(connection, "host"); + char *nodePort = ConnectionGetOptionValue(connection, "port"); + + ereport(ERROR, (errcode(ERRCODE_IO_ERROR), + errmsg("failed to COPY to shard %ld on %s:%s", + shardId, nodeName, nodePort))); } } @@ -987,14 +927,12 @@ EndRemoteCopy(List *connectionList, bool stopOnFailure) (TransactionConnection *) lfirst(connectionCell); PGconn *connection = transactionConnection->connection; int64 shardId = transactionConnection->connectionId; - char *nodeName = ConnectionGetOptionValue(connection, "host"); - char *nodePort = ConnectionGetOptionValue(connection, "port"); int copyEndResult = 0; PGresult *result = NULL; if (transactionConnection->transactionState != TRANSACTION_STATE_COPY_STARTED) { - /* COPY already ended during the prepare phase */ + /* a failure occurred after having previously called EndRemoteCopy */ continue; } @@ -1006,8 +944,11 @@ EndRemoteCopy(List *connectionList, bool stopOnFailure) { if (stopOnFailure) { + char *nodeName = ConnectionGetOptionValue(connection, "host"); + char *nodePort = ConnectionGetOptionValue(connection, "port"); + ereport(ERROR, (errcode(ERRCODE_IO_ERROR), - errmsg("Failed to COPY to shard %ld on %s:%s", + errmsg("failed to COPY to shard %ld on %s:%s", shardId, nodeName, nodePort))); } @@ -1047,6 +988,8 @@ ReportCopyError(PGconn *connection, PGresult *result) { /* probably a connection problem, get the message from the connection */ char *lastNewlineIndex = NULL; + char *nodeName = ConnectionGetOptionValue(connection, "host"); + char *nodePort = ConnectionGetOptionValue(connection, "port"); remoteMessage = PQerrorMessage(connection); lastNewlineIndex = strrchr(remoteMessage, '\n'); @@ -1057,6 +1000,348 @@ ReportCopyError(PGconn *connection, PGresult *result) *lastNewlineIndex = '\0'; } - ereport(ERROR, (errmsg("%s", remoteMessage))); + ereport(ERROR, (errcode(ERRCODE_IO_ERROR), + errmsg("failed to complete COPY on %s:%s", nodeName, nodePort), + errdetail("%s", remoteMessage))); + } +} + + +/* + * ColumnOutputFunctions walks over a table's columns, and finds each column's + * type information. The function then resolves each type's output function, + * and stores and returns these output functions in an array. + */ +FmgrInfo * +ColumnOutputFunctions(TupleDesc rowDescriptor, bool binaryFormat) +{ + uint32 columnCount = (uint32) rowDescriptor->natts; + FmgrInfo *columnOutputFunctions = palloc0(columnCount * sizeof(FmgrInfo)); + + uint32 columnIndex = 0; + for (columnIndex = 0; columnIndex < columnCount; columnIndex++) + { + FmgrInfo *currentOutputFunction = &columnOutputFunctions[columnIndex]; + Form_pg_attribute currentColumn = rowDescriptor->attrs[columnIndex]; + Oid columnTypeId = currentColumn->atttypid; + Oid outputFunctionId = InvalidOid; + bool typeVariableLength = false; + + if (currentColumn->attisdropped) + { + /* dropped column, leave the output function NULL */ + continue; + } + else if (binaryFormat) + { + getTypeBinaryOutputInfo(columnTypeId, &outputFunctionId, &typeVariableLength); + } + else + { + getTypeOutputInfo(columnTypeId, &outputFunctionId, &typeVariableLength); + } + + fmgr_info(outputFunctionId, currentOutputFunction); + } + + return columnOutputFunctions; +} + + +/* + * AppendCopyRowData serializes one row using the column output functions, + * and appends the data to the row output state object's message buffer. + * This function is modeled after the CopyOneRowTo() function in + * commands/copy.c, but only implements a subset of that functionality. + * Note that the caller of this function should reset row memory context + * to not bloat memory usage. + */ +void +AppendCopyRowData(Datum *valueArray, bool *isNullArray, TupleDesc rowDescriptor, + CopyOutState rowOutputState, FmgrInfo *columnOutputFunctions) +{ + MemoryContext oldContext = NULL; + uint32 totalColumnCount = (uint32) rowDescriptor->natts; + uint32 columnIndex = 0; + uint32 availableColumnCount = AvailableColumnCount(rowDescriptor); + uint32 appendedColumnCount = 0; + + oldContext = MemoryContextSwitchTo(rowOutputState->rowcontext); + + if (rowOutputState->binary) + { + CopySendInt16(rowOutputState, availableColumnCount); + } + + for (columnIndex = 0; columnIndex < totalColumnCount; columnIndex++) + { + Form_pg_attribute currentColumn = rowDescriptor->attrs[columnIndex]; + Datum value = valueArray[columnIndex]; + bool isNull = isNullArray[columnIndex]; + bool lastColumn = false; + + if (currentColumn->attisdropped) + { + continue; + } + else if (rowOutputState->binary) + { + if (!isNull) + { + FmgrInfo *outputFunctionPointer = &columnOutputFunctions[columnIndex]; + bytea *outputBytes = SendFunctionCall(outputFunctionPointer, value); + + CopySendInt32(rowOutputState, VARSIZE(outputBytes) - VARHDRSZ); + CopySendData(rowOutputState, VARDATA(outputBytes), + VARSIZE(outputBytes) - VARHDRSZ); + } + else + { + CopySendInt32(rowOutputState, -1); + } + } + else + { + if (!isNull) + { + FmgrInfo *outputFunctionPointer = &columnOutputFunctions[columnIndex]; + char *columnText = OutputFunctionCall(outputFunctionPointer, value); + + CopyAttributeOutText(rowOutputState, columnText); + } + else + { + CopySendString(rowOutputState, rowOutputState->null_print_client); + } + + lastColumn = ((appendedColumnCount + 1) == availableColumnCount); + if (!lastColumn) + { + CopySendChar(rowOutputState, rowOutputState->delim[0]); + } + } + + appendedColumnCount++; + } + + if (!rowOutputState->binary) + { + /* append default line termination string depending on the platform */ +#ifndef WIN32 + CopySendChar(rowOutputState, '\n'); +#else + CopySendString(rowOutputState, "\r\n"); +#endif + } + + MemoryContextSwitchTo(oldContext); +} + + +/* + * AvailableColumnCount returns the number of columns in a tuple descriptor, excluding + * columns that were dropped. + */ +static uint32 +AvailableColumnCount(TupleDesc tupleDescriptor) +{ + uint32 columnCount = 0; + uint32 columnIndex = 0; + + for (columnIndex = 0; columnIndex < tupleDescriptor->natts; columnIndex++) + { + Form_pg_attribute currentColumn = tupleDescriptor->attrs[columnIndex]; + + if (!currentColumn->attisdropped) + { + columnCount++; + } + } + + return columnCount; +} + + +/* + * AppendCopyBinaryHeaders appends binary headers to the copy buffer in + * headerOutputState. + */ +void +AppendCopyBinaryHeaders(CopyOutState headerOutputState) +{ + const int32 zero = 0; + + /* Signature */ + CopySendData(headerOutputState, BinarySignature, 11); + + /* Flags field (no OIDs) */ + CopySendInt32(headerOutputState, zero); + + /* No header extension */ + CopySendInt32(headerOutputState, zero); +} + + +/* + * AppendCopyBinaryFooters appends binary footers to the copy buffer in + * footerOutputState. + */ +void +AppendCopyBinaryFooters(CopyOutState footerOutputState) +{ + int16 negative = -1; + + CopySendInt16(footerOutputState, negative); +} + + +/* *INDENT-OFF* */ +/* Append data to the copy buffer in outputState */ +static void +CopySendData(CopyOutState outputState, const void *databuf, int datasize) +{ + appendBinaryStringInfo(outputState->fe_msgbuf, databuf, datasize); +} + + +/* Append a striong to the copy buffer in outputState. */ +static void +CopySendString(CopyOutState outputState, const char *str) +{ + appendBinaryStringInfo(outputState->fe_msgbuf, str, strlen(str)); +} + + +/* Append a char to the copy buffer in outputState. */ +static void +CopySendChar(CopyOutState outputState, char c) +{ + appendStringInfoCharMacro(outputState->fe_msgbuf, c); +} + + +/* Append an int32 to the copy buffer in outputState. */ +static void +CopySendInt32(CopyOutState outputState, int32 val) +{ + uint32 buf = htonl((uint32) val); + CopySendData(outputState, &buf, sizeof(buf)); +} + + +/* Append an int16 to the copy buffer in outputState. */ +static void +CopySendInt16(CopyOutState outputState, int16 val) +{ + uint16 buf = htons((uint16) val); + CopySendData(outputState, &buf, sizeof(buf)); +} + + +/* + * Send text representation of one column, with conversion and escaping. + * + * NB: This function is based on commands/copy.c and doesn't fully conform to + * our coding style. The function should be kept in sync with copy.c. + */ +static void +CopyAttributeOutText(CopyOutState cstate, char *string) +{ + char *pointer = NULL; + char *start = NULL; + char c = '\0'; + char delimc = cstate->delim[0]; + + if (cstate->need_transcoding) + { + pointer = pg_server_to_any(string, strlen(string), cstate->file_encoding); + } + else + { + pointer = string; + } + + /* + * We have to grovel through the string searching for control characters + * and instances of the delimiter character. In most cases, though, these + * are infrequent. To avoid overhead from calling CopySendData once per + * character, we dump out all characters between escaped characters in a + * single call. The loop invariant is that the data from "start" to "pointer" + * can be sent literally, but hasn't yet been. + * + * As all encodings here are safe, i.e. backend supported ones, we can + * skip doing pg_encoding_mblen(), because in valid backend encodings, + * extra bytes of a multibyte character never look like ASCII. + */ + start = pointer; + while ((c = *pointer) != '\0') + { + if ((unsigned char) c < (unsigned char) 0x20) + { + /* + * \r and \n must be escaped, the others are traditional. We + * prefer to dump these using the C-like notation, rather than + * a backslash and the literal character, because it makes the + * dump file a bit more proof against Microsoftish data + * mangling. + */ + switch (c) + { + case '\b': + c = 'b'; + break; + case '\f': + c = 'f'; + break; + case '\n': + c = 'n'; + break; + case '\r': + c = 'r'; + break; + case '\t': + c = 't'; + break; + case '\v': + c = 'v'; + break; + default: + /* If it's the delimiter, must backslash it */ + if (c == delimc) + break; + /* All ASCII control chars are length 1 */ + pointer++; + continue; /* fall to end of loop */ + } + /* if we get here, we need to convert the control char */ + CopyFlushOutput(cstate, start, pointer); + CopySendChar(cstate, '\\'); + CopySendChar(cstate, c); + start = ++pointer; /* do not include char in next run */ + } + else if (c == '\\' || c == delimc) + { + CopyFlushOutput(cstate, start, pointer); + CopySendChar(cstate, '\\'); + start = pointer++; /* we include char in next run */ + } + else + { + pointer++; + } + } + + CopyFlushOutput(cstate, start, pointer); +} + + +/* *INDENT-ON* */ +/* Helper function to send pending copy output */ +static inline void +CopyFlushOutput(CopyOutState cstate, char *start, char *pointer) +{ + if (pointer > start) + { + CopySendData(cstate, start, pointer - start); } } diff --git a/src/backend/distributed/executor/multi_utility.c b/src/backend/distributed/executor/multi_utility.c index f6d5af9e0..c6fca17b5 100644 --- a/src/backend/distributed/executor/multi_utility.c +++ b/src/backend/distributed/executor/multi_utility.c @@ -303,7 +303,9 @@ VerifyTransmitStmt(CopyStmt *copyStatement) /* * ProcessCopyStmt handles Citus specific concerns for COPY like supporting - * COPYing from distributed tables and preventing unsupported actions. + * COPYing from distributed tables and preventing unsupported actions. The + * function returns a modified COPY statement to be executed, or NULL if no + * further processing is needed. */ static Node * ProcessCopyStmt(CopyStmt *copyStatement, char *completionTag) @@ -588,6 +590,7 @@ ErrorIfUnsupportedIndexStmt(IndexStmt *createIndexStatement) { RangeVar *relation = createIndexStatement->relation; bool missingOk = false; + /* caller uses ShareLock for non-concurrent indexes, use the same lock here */ LOCKMODE lockMode = ShareLock; Oid relationId = RangeVarGetRelid(relation, lockMode, missingOk); diff --git a/src/backend/distributed/utils/multi_transaction.c b/src/backend/distributed/utils/multi_transaction.c index 667542ada..46b129a52 100644 --- a/src/backend/distributed/utils/multi_transaction.c +++ b/src/backend/distributed/utils/multi_transaction.c @@ -20,17 +20,32 @@ #include "nodes/pg_list.h" +/* Local functions forward declarations */ +static uint32 DistributedTransactionId = 0; + + /* Local functions forward declarations */ static StringInfo BuildTransactionName(int connectionId); /* - * PrepareTransactions prepares all transactions on connections in + * InitializeDistributedTransaction prepares the distributed transaction ID + * used in transaction names. + */ +void +InitializeDistributedTransaction(void) +{ + DistributedTransactionId++; +} + + +/* + * PrepareRemoteTransactions prepares all transactions on connections in * connectionList for commit if the 2PC transaction manager is enabled. * On failure, it reports an error and stops. */ void -PrepareTransactions(List *connectionList) +PrepareRemoteTransactions(List *connectionList) { ListCell *connectionCell = NULL; @@ -57,7 +72,7 @@ PrepareTransactions(List *connectionList) PQclear(result); ereport(ERROR, (errcode(ERRCODE_IO_ERROR), - errmsg("Failed to prepare transaction"))); + errmsg("failed to prepare transaction"))); } PQclear(result); @@ -68,11 +83,11 @@ PrepareTransactions(List *connectionList) /* - * AbortTransactions aborts all transactions on connections in connectionList. + * AbortRemoteTransactions aborts all transactions on connections in connectionList. * On failure, it reports a warning and continues to abort all of them. */ void -AbortTransactions(List *connectionList) +AbortRemoteTransactions(List *connectionList) { ListCell *connectionCell = NULL; @@ -82,8 +97,6 @@ AbortTransactions(List *connectionList) (TransactionConnection *) lfirst(connectionCell); PGconn *connection = transactionConnection->connection; int64 connectionId = transactionConnection->connectionId; - char *nodeName = ConnectionGetOptionValue(connection, "host"); - char *nodePort = ConnectionGetOptionValue(connection, "port"); PGresult *result = NULL; if (transactionConnection->transactionState == TRANSACTION_STATE_PREPARED) @@ -96,11 +109,14 @@ AbortTransactions(List *connectionList) result = PQexec(connection, command->data); if (PQresultStatus(result) != PGRES_COMMAND_OK) { + char *nodeName = ConnectionGetOptionValue(connection, "host"); + char *nodePort = ConnectionGetOptionValue(connection, "port"); + /* log a warning so the user may abort the transaction later */ - ereport(WARNING, (errmsg("Failed to roll back prepared transaction '%s'", + ereport(WARNING, (errmsg("failed to roll back prepared transaction '%s'", transactionName->data), - errhint("Run ROLLBACK TRANSACTION '%s' on %s:%s", - transactionName->data, nodeName, nodePort))); + errhint("Run \"%s\" on %s:%s", + command->data, nodeName, nodePort))); } PQclear(result); @@ -118,11 +134,11 @@ AbortTransactions(List *connectionList) /* - * CommitTransactions commits all transactions on connections in connectionList. + * CommitRemoteTransactions commits all transactions on connections in connectionList. * On failure, it reports a warning and continues committing all of them. */ void -CommitTransactions(List *connectionList) +CommitRemoteTransactions(List *connectionList) { ListCell *connectionCell = NULL; @@ -132,8 +148,6 @@ CommitTransactions(List *connectionList) (TransactionConnection *) lfirst(connectionCell); PGconn *connection = transactionConnection->connection; int64 connectionId = transactionConnection->connectionId; - char *nodeName = ConnectionGetOptionValue(connection, "host"); - char *nodePort = ConnectionGetOptionValue(connection, "port"); PGresult *result = NULL; if (transactionConnection->transactionState == TRANSACTION_STATE_PREPARED) @@ -149,11 +163,14 @@ CommitTransactions(List *connectionList) result = PQexec(connection, command->data); if (PQresultStatus(result) != PGRES_COMMAND_OK) { + char *nodeName = ConnectionGetOptionValue(connection, "host"); + char *nodePort = ConnectionGetOptionValue(connection, "port"); + /* log a warning so the user may commit the transaction later */ - ereport(WARNING, (errmsg("Failed to commit prepared transaction '%s'", + ereport(WARNING, (errmsg("failed to commit prepared transaction '%s'", transactionName->data), - errhint("Run COMMIT TRANSACTION '%s' on %s:%s", - transactionName->data, nodeName, nodePort))); + errhint("Run \"%s\" on %s:%s", + command->data, nodeName, nodePort))); } } else @@ -165,7 +182,10 @@ CommitTransactions(List *connectionList) result = PQexec(connection, "COMMIT"); if (PQresultStatus(result) != PGRES_COMMAND_OK) { - ereport(WARNING, (errmsg("Failed to commit transaction on %s:%s", + char *nodeName = ConnectionGetOptionValue(connection, "host"); + char *nodePort = ConnectionGetOptionValue(connection, "port"); + + ereport(WARNING, (errmsg("failed to commit transaction on %s:%s", nodeName, nodePort))); } } @@ -178,7 +198,14 @@ CommitTransactions(List *connectionList) /* - * BuildTransactionName constructs a unique transaction name from an ID. + * BuildTransactionName constructs a transaction name that ensures there are no + * collisions with concurrent transactions by the same master node, subsequent + * transactions by the same backend, or transactions on a different shard. + * + * Collisions may occur over time if transactions fail to commit or abort and + * are left to linger. This would cause a PREPARE failure for the second + * transaction, which causes it to be rolled back. In general, the user + * should ensure that prepared transactions do not linger. */ static StringInfo BuildTransactionName(int connectionId) @@ -186,7 +213,7 @@ BuildTransactionName(int connectionId) StringInfo commandString = makeStringInfo(); appendStringInfo(commandString, "citus_%d_%u_%d", MyProcPid, - GetCurrentTransactionId(), connectionId); + DistributedTransactionId, connectionId); return commandString; } diff --git a/src/backend/distributed/worker/worker_partition_protocol.c b/src/backend/distributed/worker/worker_partition_protocol.c index 0751233c0..f2d5fa1cc 100644 --- a/src/backend/distributed/worker/worker_partition_protocol.c +++ b/src/backend/distributed/worker/worker_partition_protocol.c @@ -29,6 +29,7 @@ #include "catalog/pg_collation.h" #include "commands/copy.h" #include "commands/defrem.h" +#include "distributed/multi_copy.h" #include "distributed/resource_lock.h" #include "distributed/transmit.h" #include "distributed/worker_protocol.h" @@ -45,7 +46,6 @@ bool BinaryWorkerCopyFormat = false; /* binary format for copying between work int PartitionBufferSize = 16384; /* total partitioning buffer size in KB */ /* Local variables */ -static const char BinarySignature[11] = "PGCOPY\n\377\r\n\0"; static uint32 FileBufferSizeInBytes = 0; /* file buffer size to init later */ @@ -64,21 +64,10 @@ static void FilterAndPartitionTable(const char *filterQuery, FileOutputStream *partitionFileArray, uint32 fileCount); static int ColumnIndex(TupleDesc rowDescriptor, const char *columnName); -static FmgrInfo * ColumnOutputFunctions(TupleDesc rowDescriptor, bool binaryFormat); -static PartialCopyState InitRowOutputState(void); -static void ClearRowOutputState(PartialCopyState copyState); -static void OutputRow(HeapTuple row, TupleDesc rowDescriptor, - PartialCopyState rowOutputState, FmgrInfo *columnOutputFunctions); +static CopyOutState InitRowOutputState(void); +static void ClearRowOutputState(CopyOutState copyState); static void OutputBinaryHeaders(FileOutputStream *partitionFileArray, uint32 fileCount); static void OutputBinaryFooters(FileOutputStream *partitionFileArray, uint32 fileCount); -static void CopySendData(PartialCopyState outputState, const void *databuf, int datasize); -static void CopySendString(PartialCopyState outputState, const char *str); -static void CopySendChar(PartialCopyState outputState, char c); -static void CopySendInt32(PartialCopyState outputState, int32 val); -static void CopySendInt16(PartialCopyState outputState, int16 val); -static void CopyAttributeOutText(PartialCopyState outputState, char *string); -static inline void CopyFlushOutput(PartialCopyState outputState, - char *start, char *pointer); static uint32 RangePartitionId(Datum partitionValue, const void *context); static uint32 HashPartitionId(Datum partitionValue, const void *context); @@ -731,13 +720,16 @@ FilterAndPartitionTable(const char *filterQuery, FileOutputStream *partitionFileArray, uint32 fileCount) { - PartialCopyState rowOutputState = NULL; + CopyOutState rowOutputState = NULL; FmgrInfo *columnOutputFunctions = NULL; int partitionColumnIndex = 0; Oid partitionColumnTypeId = InvalidOid; Portal queryPortal = NULL; int connected = 0; int finished = 0; + uint32 columnCount = 0; + Datum *valueArray = NULL; + bool *isNullArray = NULL; const char *noPortalName = NULL; const bool readOnly = true; @@ -784,6 +776,10 @@ FilterAndPartitionTable(const char *filterQuery, OutputBinaryHeaders(partitionFileArray, fileCount); } + columnCount = (uint32) SPI_tuptable->tupdesc->natts; + valueArray = (Datum *) palloc0(columnCount * sizeof(Datum)); + isNullArray = (bool *) palloc0(columnCount * sizeof(bool)); + while (SPI_processed > 0) { int rowIndex = 0; @@ -815,13 +811,19 @@ FilterAndPartitionTable(const char *filterQuery, partitionId = 0; } - OutputRow(row, rowDescriptor, rowOutputState, columnOutputFunctions); + /* deconstruct the tuple; this is faster than repeated heap_getattr */ + heap_deform_tuple(row, rowDescriptor, valueArray, isNullArray); + + AppendCopyRowData(valueArray, isNullArray, rowDescriptor, + rowOutputState, columnOutputFunctions); + rowText = rowOutputState->fe_msgbuf; partitionFile = partitionFileArray[partitionId]; FileOutputStreamWrite(partitionFile, rowText); resetStringInfo(rowText); + MemoryContextReset(rowOutputState->rowcontext); } SPI_freetuptable(SPI_tuptable); @@ -829,6 +831,9 @@ FilterAndPartitionTable(const char *filterQuery, SPI_cursor_fetch(queryPortal, fetchForward, prefetchCount); } + pfree(valueArray); + pfree(isNullArray); + SPI_cursor_close(queryPortal); if (BinaryWorkerCopyFormat) @@ -866,44 +871,6 @@ ColumnIndex(TupleDesc rowDescriptor, const char *columnName) } -/* - * ColumnOutputFunctions walks over a table's columns, and finds each column's - * type information. The function then resolves each type's output function, - * and stores and returns these output functions in an array. - */ -static FmgrInfo * -ColumnOutputFunctions(TupleDesc rowDescriptor, bool binaryFormat) -{ - uint32 columnCount = (uint32) rowDescriptor->natts; - FmgrInfo *columnOutputFunctions = palloc0(columnCount * sizeof(FmgrInfo)); - - uint32 columnIndex = 0; - for (columnIndex = 0; columnIndex < columnCount; columnIndex++) - { - FmgrInfo *currentOutputFunction = &columnOutputFunctions[columnIndex]; - Form_pg_attribute currentColumn = rowDescriptor->attrs[columnIndex]; - Oid columnTypeId = currentColumn->atttypid; - Oid outputFunctionId = InvalidOid; - bool typeVariableLength = false; - - if (binaryFormat) - { - getTypeBinaryOutputInfo(columnTypeId, &outputFunctionId, &typeVariableLength); - } - else - { - getTypeOutputInfo(columnTypeId, &outputFunctionId, &typeVariableLength); - } - - Assert(currentColumn->attisdropped == false); - - fmgr_info(outputFunctionId, currentOutputFunction); - } - - return columnOutputFunctions; -} - - /* * InitRowOutputState creates and initializes a copy state object. This object * is internal to the copy command's implementation in Postgres; and we refactor @@ -914,11 +881,10 @@ ColumnOutputFunctions(TupleDesc rowDescriptor, bool binaryFormat) * must match one another. Therefore, any changes to the default values in the * copy command must be propagated to this function. */ -static PartialCopyState +static CopyOutState InitRowOutputState(void) { - PartialCopyState rowOutputState = - (PartialCopyState) palloc0(sizeof(PartialCopyStateData)); + CopyOutState rowOutputState = (CopyOutState) palloc0(sizeof(CopyOutStateData)); int fileEncoding = pg_get_client_encoding(); int databaseEncoding = GetDatabaseEncoding(); @@ -975,7 +941,7 @@ InitRowOutputState(void) /* Clears copy state used for outputting row data. */ static void -ClearRowOutputState(PartialCopyState rowOutputState) +ClearRowOutputState(CopyOutState rowOutputState) { Assert(rowOutputState != NULL); @@ -990,98 +956,6 @@ ClearRowOutputState(PartialCopyState rowOutputState) } -/* - * OutputRow serializes one row using the column output functions, - * and appends the data to the row output state object's message buffer. - * This function is modeled after the CopyOneRowTo() function in - * commands/copy.c, but only implements a subset of that functionality. - */ -static void -OutputRow(HeapTuple row, TupleDesc rowDescriptor, - PartialCopyState rowOutputState, FmgrInfo *columnOutputFunctions) -{ - MemoryContext oldContext = NULL; - uint32 columnIndex = 0; - - uint32 columnCount = (uint32) rowDescriptor->natts; - Datum *valueArray = (Datum *) palloc0(columnCount * sizeof(Datum)); - bool *isNullArray = (bool *) palloc0(columnCount * sizeof(bool)); - - /* deconstruct the tuple; this is faster than repeated heap_getattr */ - heap_deform_tuple(row, rowDescriptor, valueArray, isNullArray); - - /* reset previous tuple's output data, and the temporary memory context */ - resetStringInfo(rowOutputState->fe_msgbuf); - - MemoryContextReset(rowOutputState->rowcontext); - oldContext = MemoryContextSwitchTo(rowOutputState->rowcontext); - - if (rowOutputState->binary) - { - CopySendInt16(rowOutputState, rowDescriptor->natts); - } - - for (columnIndex = 0; columnIndex < columnCount; columnIndex++) - { - Datum value = valueArray[columnIndex]; - bool isNull = isNullArray[columnIndex]; - bool lastColumn = false; - - if (rowOutputState->binary) - { - if (!isNull) - { - FmgrInfo *outputFunctionPointer = &columnOutputFunctions[columnIndex]; - bytea *outputBytes = SendFunctionCall(outputFunctionPointer, value); - - CopySendInt32(rowOutputState, VARSIZE(outputBytes) - VARHDRSZ); - CopySendData(rowOutputState, VARDATA(outputBytes), - VARSIZE(outputBytes) - VARHDRSZ); - } - else - { - CopySendInt32(rowOutputState, -1); - } - } - else - { - if (!isNull) - { - FmgrInfo *outputFunctionPointer = &columnOutputFunctions[columnIndex]; - char *columnText = OutputFunctionCall(outputFunctionPointer, value); - - CopyAttributeOutText(rowOutputState, columnText); - } - else - { - CopySendString(rowOutputState, rowOutputState->null_print_client); - } - - lastColumn = ((columnIndex + 1) == columnCount); - if (!lastColumn) - { - CopySendChar(rowOutputState, rowOutputState->delim[0]); - } - } - } - - if (!rowOutputState->binary) - { - /* append default line termination string depending on the platform */ -#ifndef WIN32 - CopySendChar(rowOutputState, '\n'); -#else - CopySendString(rowOutputState, "\r\n"); -#endif - } - - MemoryContextSwitchTo(oldContext); - - pfree(valueArray); - pfree(isNullArray); -} - - /* * Write the header of postgres' binary serialization format to each partition file. * This function is used when binary_worker_copy_format is enabled. @@ -1093,22 +967,14 @@ OutputBinaryHeaders(FileOutputStream *partitionFileArray, uint32 fileCount) for (fileIndex = 0; fileIndex < fileCount; fileIndex++) { /* Generate header for a binary copy */ - const int32 zero = 0; FileOutputStream partitionFile = { 0, 0, 0 }; - PartialCopyStateData headerOutputStateData; - PartialCopyState headerOutputState = (PartialCopyState) & headerOutputStateData; + CopyOutStateData headerOutputStateData; + CopyOutState headerOutputState = (CopyOutState) & headerOutputStateData; - memset(headerOutputState, 0, sizeof(PartialCopyStateData)); + memset(headerOutputState, 0, sizeof(CopyOutStateData)); headerOutputState->fe_msgbuf = makeStringInfo(); - /* Signature */ - CopySendData(headerOutputState, BinarySignature, 11); - - /* Flags field (no OIDs) */ - CopySendInt32(headerOutputState, zero); - - /* No header extension */ - CopySendInt32(headerOutputState, zero); + AppendCopyBinaryHeaders(headerOutputState); partitionFile = partitionFileArray[fileIndex]; FileOutputStreamWrite(partitionFile, headerOutputState->fe_msgbuf); @@ -1127,15 +993,14 @@ OutputBinaryFooters(FileOutputStream *partitionFileArray, uint32 fileCount) for (fileIndex = 0; fileIndex < fileCount; fileIndex++) { /* Generate footer for a binary copy */ - int16 negative = -1; FileOutputStream partitionFile = { 0, 0, 0 }; - PartialCopyStateData footerOutputStateData; - PartialCopyState footerOutputState = (PartialCopyState) & footerOutputStateData; + CopyOutStateData footerOutputStateData; + CopyOutState footerOutputState = (CopyOutState) & footerOutputStateData; - memset(footerOutputState, 0, sizeof(PartialCopyStateData)); + memset(footerOutputState, 0, sizeof(CopyOutStateData)); footerOutputState->fe_msgbuf = makeStringInfo(); - CopySendInt16(footerOutputState, negative); + AppendCopyBinaryFooters(footerOutputState); partitionFile = partitionFileArray[fileIndex]; FileOutputStreamWrite(partitionFile, footerOutputState->fe_msgbuf); @@ -1143,158 +1008,6 @@ OutputBinaryFooters(FileOutputStream *partitionFileArray, uint32 fileCount) } -/* *INDENT-OFF* */ -/* Append data to the copy buffer in outputState */ -static void -CopySendData(PartialCopyState outputState, const void *databuf, int datasize) -{ - appendBinaryStringInfo(outputState->fe_msgbuf, databuf, datasize); -} - - -/* Append a striong to the copy buffer in outputState. */ -static void -CopySendString(PartialCopyState outputState, const char *str) -{ - appendBinaryStringInfo(outputState->fe_msgbuf, str, strlen(str)); -} - - -/* Append a char to the copy buffer in outputState. */ -static void -CopySendChar(PartialCopyState outputState, char c) -{ - appendStringInfoCharMacro(outputState->fe_msgbuf, c); -} - - -/* Append an int32 to the copy buffer in outputState. */ -static void -CopySendInt32(PartialCopyState outputState, int32 val) -{ - uint32 buf = htonl((uint32) val); - CopySendData(outputState, &buf, sizeof(buf)); -} - - -/* Append an int16 to the copy buffer in outputState. */ -static void -CopySendInt16(PartialCopyState outputState, int16 val) -{ - uint16 buf = htons((uint16) val); - CopySendData(outputState, &buf, sizeof(buf)); -} - - -/* - * Send text representation of one column, with conversion and escaping. - * - * NB: This function is based on commands/copy.c and doesn't fully conform to - * our coding style. The function should be kept in sync with copy.c. - */ -static void -CopyAttributeOutText(PartialCopyState cstate, char *string) -{ - char *pointer = NULL; - char *start = NULL; - char c = '\0'; - char delimc = cstate->delim[0]; - - if (cstate->need_transcoding) - { - pointer = pg_server_to_any(string, strlen(string), cstate->file_encoding); - } - else - { - pointer = string; - } - - /* - * We have to grovel through the string searching for control characters - * and instances of the delimiter character. In most cases, though, these - * are infrequent. To avoid overhead from calling CopySendData once per - * character, we dump out all characters between escaped characters in a - * single call. The loop invariant is that the data from "start" to "pointer" - * can be sent literally, but hasn't yet been. - * - * As all encodings here are safe, i.e. backend supported ones, we can - * skip doing pg_encoding_mblen(), because in valid backend encodings, - * extra bytes of a multibyte character never look like ASCII. - */ - start = pointer; - while ((c = *pointer) != '\0') - { - if ((unsigned char) c < (unsigned char) 0x20) - { - /* - * \r and \n must be escaped, the others are traditional. We - * prefer to dump these using the C-like notation, rather than - * a backslash and the literal character, because it makes the - * dump file a bit more proof against Microsoftish data - * mangling. - */ - switch (c) - { - case '\b': - c = 'b'; - break; - case '\f': - c = 'f'; - break; - case '\n': - c = 'n'; - break; - case '\r': - c = 'r'; - break; - case '\t': - c = 't'; - break; - case '\v': - c = 'v'; - break; - default: - /* If it's the delimiter, must backslash it */ - if (c == delimc) - break; - /* All ASCII control chars are length 1 */ - pointer++; - continue; /* fall to end of loop */ - } - /* if we get here, we need to convert the control char */ - CopyFlushOutput(cstate, start, pointer); - CopySendChar(cstate, '\\'); - CopySendChar(cstate, c); - start = ++pointer; /* do not include char in next run */ - } - else if (c == '\\' || c == delimc) - { - CopyFlushOutput(cstate, start, pointer); - CopySendChar(cstate, '\\'); - start = pointer++; /* we include char in next run */ - } - else - { - pointer++; - } - } - - CopyFlushOutput(cstate, start, pointer); -} - - -/* *INDENT-ON* */ -/* Helper function to send pending copy output */ -static inline void -CopyFlushOutput(PartialCopyState cstate, char *start, char *pointer) -{ - if (pointer > start) - { - CopySendData(cstate, start, pointer - start); - } -} - - /* Helper function that invokes a function with the default collation oid. */ Datum CompareCall2(FmgrInfo *functionInfo, Datum leftArgument, Datum rightArgument) diff --git a/src/include/distributed/multi_copy.h b/src/include/distributed/multi_copy.h index 279b8c165..2bb13048f 100644 --- a/src/include/distributed/multi_copy.h +++ b/src/include/distributed/multi_copy.h @@ -20,7 +20,36 @@ extern int CopyTransactionManager; +/* + * A smaller version of copy.c's CopyStateData, trimmed to the elements + * necessary to copy out results. While it'd be a bit nicer to share code, + * it'd require changing core postgres code. + */ +typedef struct CopyOutStateData +{ + StringInfo fe_msgbuf; /* used for all dests during COPY TO, only for + * dest == COPY_NEW_FE in COPY FROM */ + int file_encoding; /* file or remote side's character encoding */ + bool need_transcoding; /* file encoding diff from server? */ + bool binary; /* binary format? */ + char *null_print; /* NULL marker string (server encoding!) */ + char *null_print_client; /* same converted to file encoding */ + char *delim; /* column delimiter (must be 1 byte) */ + + MemoryContext rowcontext; /* per-row evaluation context */ +} CopyOutStateData; + +typedef struct CopyOutStateData *CopyOutState; + + /* function declarations for copying into a distributed table */ +extern FmgrInfo * ColumnOutputFunctions(TupleDesc rowDescriptor, bool binaryFormat); +extern void AppendCopyRowData(Datum *valueArray, bool *isNullArray, + TupleDesc rowDescriptor, + CopyOutState rowOutputState, + FmgrInfo *columnOutputFunctions); +extern void AppendCopyBinaryHeaders(CopyOutState headerOutputState); +extern void AppendCopyBinaryFooters(CopyOutState footerOutputState); extern void CitusCopyFrom(CopyStmt *copyStatement, char *completionTag); diff --git a/src/include/distributed/multi_transaction.h b/src/include/distributed/multi_transaction.h index 827f0fcaa..760899e44 100644 --- a/src/include/distributed/multi_transaction.h +++ b/src/include/distributed/multi_transaction.h @@ -48,9 +48,10 @@ typedef struct TransactionConnection /* Functions declarations for transaction and connection management */ -extern void PrepareTransactions(List *connectionList); -extern void AbortTransactions(List *connectionList); -extern void CommitTransactions(List *connectionList); +extern void InitializeDistributedTransaction(void); +extern void PrepareRemoteTransactions(List *connectionList); +extern void AbortRemoteTransactions(List *connectionList); +extern void CommitRemoteTransactions(List *connectionList); extern void CloseConnections(List *connectionList); diff --git a/src/include/distributed/worker_protocol.h b/src/include/distributed/worker_protocol.h index db78f8138..ac526a2be 100644 --- a/src/include/distributed/worker_protocol.h +++ b/src/include/distributed/worker_protocol.h @@ -79,28 +79,6 @@ typedef struct HashPartitionContext } HashPartitionContext; -/* - * A smaller version of copy.c's CopyStateData, trimmed to the elements - * necessary for re-partition jobs. While it'd be a bit nicer to share code, - * it'd require changing core postgres code. - */ -typedef struct PartialCopyStateData -{ - StringInfo fe_msgbuf; /* used for all dests during COPY TO, only for - * dest == COPY_NEW_FE in COPY FROM */ - int file_encoding; /* file or remote side's character encoding */ - bool need_transcoding; /* file encoding diff from server? */ - bool binary; /* binary format? */ - char *null_print; /* NULL marker string (server encoding!) */ - char *null_print_client; /* same converted to file encoding */ - char *delim; /* column delimiter (must be 1 byte) */ - - MemoryContext rowcontext; /* per-row evaluation context */ -} PartialCopyStateData; - -typedef struct PartialCopyStateData *PartialCopyState; - - /* * FileOutputStream helps buffer write operations to a file; these writes are * then regularly flushed to the underlying file. This structure differs from diff --git a/src/test/regress/input/multi_copy.source b/src/test/regress/input/multi_copy.source index 92544b9a8..8a690fba8 100644 --- a/src/test/regress/input/multi_copy.source +++ b/src/test/regress/input/multi_copy.source @@ -104,6 +104,54 @@ SELECT count(*) FROM customer_copy_hash; -- Confirm that data was copied SELECT count(*) FROM customer_copy_hash; +-- Create a new hash-partitioned table with default now() function +CREATE TABLE customer_with_default( + c_custkey integer, + c_name varchar(25) not null, + c_time timestamp default now()); + +SELECT master_create_distributed_table('customer_with_default', 'c_custkey', 'hash'); + +SELECT master_create_worker_shards('customer_with_default', 64, 1); + +-- Test with default values for now() function +COPY customer_with_default (c_custkey, c_name) FROM STDIN +WITH (FORMAT 'csv'); +1,customer1 +2,customer2 +\. + +-- Confirm that data was copied with now() function +SELECT count(*) FROM customer_with_default where c_time IS NOT NULL; + +-- Add columns to the table and perform a COPY +ALTER TABLE customer_copy_hash ADD COLUMN extra1 INT DEFAULT 0; +ALTER TABLE customer_copy_hash ADD COLUMN extra2 INT DEFAULT 0; + +COPY customer_copy_hash (c_custkey, c_name, extra1, extra2) FROM STDIN CSV; +10,customer10,1,5 +\. + +SELECT * FROM customer_copy_hash WHERE extra1 = 1; + +-- Test dropping an intermediate column +ALTER TABLE customer_copy_hash DROP COLUMN extra1; + +COPY customer_copy_hash (c_custkey, c_name, extra2) FROM STDIN CSV; +11,customer11,5 +\. + +SELECT * FROM customer_copy_hash WHERE c_custkey = 11; + +-- Test dropping the last column +ALTER TABLE customer_copy_hash DROP COLUMN extra2; + +COPY customer_copy_hash (c_custkey, c_name) FROM STDIN CSV; +12,customer12 +\. + +SELECT * FROM customer_copy_hash WHERE c_custkey = 12; + -- Create a new range-partitioned table into which to COPY CREATE TABLE customer_copy_range ( c_custkey integer, diff --git a/src/test/regress/multi_schedule b/src/test/regress/multi_schedule index ae71426f4..e493fc739 100644 --- a/src/test/regress/multi_schedule +++ b/src/test/regress/multi_schedule @@ -100,7 +100,6 @@ test: multi_append_table_to_shard # --------- test: multi_outer_join -# # --- # Tests covering mostly modification queries and required preliminary # functionality related to metadata, shard creation, shard pruning and diff --git a/src/test/regress/output/multi_copy.source b/src/test/regress/output/multi_copy.source index 7edce8025..c43e04292 100644 --- a/src/test/regress/output/multi_copy.source +++ b/src/test/regress/output/multi_copy.source @@ -20,7 +20,7 @@ SELECT master_create_distributed_table('customer_copy_hash', 'c_custkey', 'hash' -- Test COPY into empty hash-partitioned table COPY customer_copy_hash FROM '@abs_srcdir@/data/customer.1.data' WITH (DELIMITER '|'); -ERROR: could not find any shards for query +ERROR: could not find any shards into which to copy DETAIL: No shards exist for distributed table "customer_copy_hash". HINT: Run master_create_worker_shards to create shards and try again. SELECT master_create_worker_shards('customer_copy_hash', 64, 1); @@ -47,7 +47,6 @@ COPY customer_copy_hash (c_custkey, c_name) FROM STDIN WITH (FORMAT 'csv'); ERROR: duplicate key value violates unique constraint "customer_copy_hash_pkey_103160" DETAIL: Key (c_custkey)=(2) already exists. -CONTEXT: COPY customer_copy_hash, line 4: "" -- Confirm that no data was copied SELECT count(*) FROM customer_copy_hash; count @@ -90,7 +89,6 @@ COPY customer_copy_hash (c_custkey, c_name) FROM STDIN WITH (FORMAT 'csv'); ERROR: null value in column "c_name" violates not-null constraint DETAIL: Failing row contains (8, null, null, null, null, null, null, null). -CONTEXT: COPY customer_copy_hash, line 4: "" -- Confirm that no data was copied SELECT count(*) FROM customer_copy_hash; count @@ -126,6 +124,61 @@ SELECT count(*) FROM customer_copy_hash; 2006 (1 row) +-- Create a new hash-partitioned table with default now() function +CREATE TABLE customer_with_default( + c_custkey integer, + c_name varchar(25) not null, + c_time timestamp default now()); +SELECT master_create_distributed_table('customer_with_default', 'c_custkey', 'hash'); + master_create_distributed_table +--------------------------------- + +(1 row) + +SELECT master_create_worker_shards('customer_with_default', 64, 1); + master_create_worker_shards +----------------------------- + +(1 row) + +-- Test with default values for now() function +COPY customer_with_default (c_custkey, c_name) FROM STDIN +WITH (FORMAT 'csv'); +-- Confirm that data was copied with now() function +SELECT count(*) FROM customer_with_default where c_time IS NOT NULL; + count +------- + 2 +(1 row) + +-- Add columns to the table and perform a COPY +ALTER TABLE customer_copy_hash ADD COLUMN extra1 INT DEFAULT 0; +ALTER TABLE customer_copy_hash ADD COLUMN extra2 INT DEFAULT 0; +COPY customer_copy_hash (c_custkey, c_name, extra1, extra2) FROM STDIN CSV; +SELECT * FROM customer_copy_hash WHERE extra1 = 1; + c_custkey | c_name | c_address | c_nationkey | c_phone | c_acctbal | c_mktsegment | c_comment | extra1 | extra2 +-----------+------------+-----------+-------------+---------+-----------+--------------+-----------+--------+-------- + 10 | customer10 | | | | | | | 1 | 5 +(1 row) + +-- Test dropping an intermediate column +ALTER TABLE customer_copy_hash DROP COLUMN extra1; +COPY customer_copy_hash (c_custkey, c_name, extra2) FROM STDIN CSV; +SELECT * FROM customer_copy_hash WHERE c_custkey = 11; + c_custkey | c_name | c_address | c_nationkey | c_phone | c_acctbal | c_mktsegment | c_comment | extra2 +-----------+------------+-----------+-------------+---------+-----------+--------------+-----------+-------- + 11 | customer11 | | | | | | | 5 +(1 row) + +-- Test dropping the last column +ALTER TABLE customer_copy_hash DROP COLUMN extra2; +COPY customer_copy_hash (c_custkey, c_name) FROM STDIN CSV; +SELECT * FROM customer_copy_hash WHERE c_custkey = 12; + c_custkey | c_name | c_address | c_nationkey | c_phone | c_acctbal | c_mktsegment | c_comment +-----------+------------+-----------+-------------+---------+-----------+--------------+----------- + 12 | customer12 | | | | | | +(1 row) + -- Create a new range-partitioned table into which to COPY CREATE TABLE customer_copy_range ( c_custkey integer, @@ -145,7 +198,7 @@ SELECT master_create_distributed_table('customer_copy_range', 'c_custkey', 'rang -- Test COPY into empty range-partitioned table COPY customer_copy_range FROM '@abs_srcdir@/data/customer.1.data' WITH (DELIMITER '|'); -ERROR: could not find any shards for query +ERROR: could not find any shards into which to copy DETAIL: No shards exist for distributed table "customer_copy_range". SELECT master_create_empty_shard('customer_copy_range') AS new_shard_id \gset