From d37ccff15b0c92e3b7a596e2e1e7bbdb1c2bfa86 Mon Sep 17 00:00:00 2001 From: Marco Slot Date: Tue, 23 Feb 2016 14:07:39 +0100 Subject: [PATCH] Support for COPY FROM, based on pg_shard PR by Postres Pro --- src/backend/distributed/commands/multi_copy.c | 1101 +++++++++++++++++ .../distributed/executor/multi_utility.c | 17 +- .../planner/multi_physical_planner.c | 7 +- src/backend/distributed/shared_library_init.c | 23 + .../distributed/utils/connection_cache.c | 18 +- .../distributed/utils/transaction_manager.c | 173 +++ src/include/distributed/connection_cache.h | 2 + src/include/distributed/multi_copy.h | 29 + .../distributed/multi_physical_planner.h | 3 + src/include/distributed/transaction_manager.h | 49 + src/test/regress/expected/multi_utilities.out | 3 - .../expected/multi_utility_statements.out | 3 - src/test/regress/sql/multi_utilities.sql | 3 - .../regress/sql/multi_utility_statements.sql | 3 - 14 files changed, 1401 insertions(+), 33 deletions(-) create mode 100644 src/backend/distributed/commands/multi_copy.c create mode 100644 src/backend/distributed/utils/transaction_manager.c create mode 100644 src/include/distributed/multi_copy.h create mode 100644 src/include/distributed/transaction_manager.h diff --git a/src/backend/distributed/commands/multi_copy.c b/src/backend/distributed/commands/multi_copy.c new file mode 100644 index 000000000..6d3d46e73 --- /dev/null +++ b/src/backend/distributed/commands/multi_copy.c @@ -0,0 +1,1101 @@ +/*------------------------------------------------------------------------- + * + * multi_copy.c + * This file contains implementation of COPY utility for distributed + * tables. + * + * Contributed by Konstantin Knizhnik, Postgres Professional + * + * Copyright (c) 2016, Citus Data, Inc. 
+ * + *------------------------------------------------------------------------- + */ + +#include "postgres.h" +#include "c.h" +#include "fmgr.h" +#include "funcapi.h" +#include "libpq-fe.h" +#include "miscadmin.h" +#include "plpgsql.h" + +#include + +#include "access/heapam.h" +#include "access/htup_details.h" +#include "access/htup.h" +#include "access/nbtree.h" +#include "access/sdir.h" +#include "access/tupdesc.h" +#include "access/xact.h" +#include "catalog/namespace.h" +#include "catalog/pg_class.h" +#include "catalog/pg_type.h" +#include "catalog/pg_am.h" +#include "catalog/pg_collation.h" +#include "commands/extension.h" +#include "commands/copy.h" +#include "commands/defrem.h" +#include "distributed/citus_ruleutils.h" +#include "distributed/connection_cache.h" +#include "distributed/listutils.h" +#include "distributed/master_metadata_utility.h" +#include "distributed/master_protocol.h" +#include "distributed/multi_copy.h" +#include "distributed/multi_physical_planner.h" +#include "distributed/pg_dist_partition.h" +#include "distributed/resource_lock.h" +#include "distributed/transaction_manager.h" +#include "distributed/worker_protocol.h" +#include "executor/execdesc.h" +#include "executor/executor.h" +#include "executor/instrument.h" +#include "executor/tuptable.h" +#include "lib/stringinfo.h" +#include "nodes/execnodes.h" +#include "nodes/makefuncs.h" +#include "nodes/memnodes.h" +#include "nodes/nodeFuncs.h" +#include "nodes/nodes.h" +#include "nodes/params.h" +#include "nodes/parsenodes.h" +#include "nodes/pg_list.h" +#include "nodes/plannodes.h" +#include "nodes/primnodes.h" +#include "optimizer/clauses.h" +#include "optimizer/cost.h" +#include "optimizer/planner.h" +#include "optimizer/var.h" +#include "parser/parser.h" +#include "parser/analyze.h" +#include "parser/parse_node.h" +#include "parser/parsetree.h" +#include "parser/parse_type.h" +#include "storage/lock.h" +#include "tcop/dest.h" +#include "tcop/tcopprot.h" +#include "tcop/utility.h" 
+#include "tsearch/ts_locale.h"
+#include "utils/builtins.h"
+#include "utils/elog.h"
+#include "utils/errcodes.h"
+#include "utils/guc.h"
+#include "utils/lsyscache.h"
+#include "utils/typcache.h"
+#include "utils/palloc.h"
+#include "utils/rel.h"
+#include "utils/relcache.h"
+#include "utils/snapmgr.h"
+#include "utils/tuplestore.h"
+#include "utils/memutils.h"
+
+
+/* initial size passed to hash_create for the shard id -> connections hash */
+#define INITIAL_CONNECTION_CACHE_SIZE 1001
+
+
+/* the transaction manager to use for COPY commands */
+int CopyTransactionManager = TRANSACTION_MANAGER_1PC;
+
+
+/*
+ * Data structures from copy.c, to keep track of COPY processing state.
+ *
+ * NOTE(review): CitusCopyFrom reads fields (binary, line_buf) of the CopyState
+ * returned by BeginCopyFrom through this definition, so it presumably has to
+ * stay identical to the definition in PostgreSQL's copy.c -- confirm before
+ * changing any field.
+ */
+typedef enum CopyDest
+{
+	COPY_FILE,					/* to/from file (or a piped program) */
+	COPY_OLD_FE,				/* to/from frontend (2.0 protocol) */
+	COPY_NEW_FE					/* to/from frontend (3.0 protocol) */
+} CopyDest;
+
+/* end-of-line convention detected in the input */
+typedef enum EolType
+{
+	EOL_UNKNOWN,
+	EOL_NL,
+	EOL_CR,
+	EOL_CRNL
+} EolType;
+
+typedef struct CopyStateData
+{
+	/* low-level state data */
+	CopyDest	copy_dest;		/* type of copy source/destination */
+	FILE	   *copy_file;		/* used if copy_dest == COPY_FILE */
+	StringInfo	fe_msgbuf;		/* used for all dests during COPY TO, only for
+								 * dest == COPY_NEW_FE in COPY FROM */
+	bool		fe_eof;			/* true if detected end of copy data */
+	EolType		eol_type;		/* EOL type of input */
+	int			file_encoding;	/* file or remote side's character encoding */
+	bool		need_transcoding;		/* file encoding diff from server? */
+	bool		encoding_embeds_ascii;	/* ASCII can be non-first byte? */
+
+	/* parameters from the COPY command */
+	Relation	rel;			/* relation to copy to or from */
+	QueryDesc  *queryDesc;		/* executable query to copy from */
+	List	   *attnumlist;		/* integer list of attnums to copy */
+	char	   *filename;		/* filename, or NULL for STDIN/STDOUT */
+	bool		is_program;		/* is 'filename' a program to popen? */
+	bool		binary;			/* binary format? */
+	bool		oids;			/* include OIDs? */
+	bool		freeze;			/* freeze rows on loading? */
+	bool		csv_mode;		/* Comma Separated Value format? */
+	bool		header_line;	/* CSV header line? */
+	char	   *null_print;		/* NULL marker string (server encoding!) */
+	int			null_print_len; /* length of same */
+	char	   *null_print_client;		/* same converted to file encoding */
+	char	   *delim;			/* column delimiter (must be 1 byte) */
+	char	   *quote;			/* CSV quote char (must be 1 byte) */
+	char	   *escape;			/* CSV escape char (must be 1 byte) */
+	List	   *force_quote;	/* list of column names */
+	bool		force_quote_all;	/* FORCE QUOTE *? */
+	bool	   *force_quote_flags;		/* per-column CSV FQ flags */
+	List	   *force_notnull;	/* list of column names */
+	bool	   *force_notnull_flags;	/* per-column CSV FNN flags */
+#if PG_VERSION_NUM >= 90400
+	List	   *force_null;		/* list of column names */
+	bool	   *force_null_flags;		/* per-column CSV FN flags */
+#endif
+	bool		convert_selectively;	/* do selective binary conversion? */
+	List	   *convert_select; /* list of column names (can be NIL) */
+	bool	   *convert_select_flags;	/* per-column CSV/TEXT CS flags */
+
+	/* these are just for error messages, see CopyFromErrorCallback */
+	const char *cur_relname;	/* table name for error messages */
+	int			cur_lineno;		/* line number for error messages */
+	const char *cur_attname;	/* current att for error messages */
+	const char *cur_attval;		/* current att value for error messages */
+
+	/*
+	 * Working state for COPY TO/FROM
+	 */
+	MemoryContext copycontext;	/* per-copy execution context */
+
+	/*
+	 * Working state for COPY TO
+	 */
+	FmgrInfo   *out_functions;	/* lookup info for output functions */
+	MemoryContext rowcontext;	/* per-row evaluation context */
+
+	/*
+	 * Working state for COPY FROM
+	 */
+	AttrNumber	num_defaults;
+	bool		file_has_oids;
+	FmgrInfo	oid_in_function;
+	Oid			oid_typioparam;
+	FmgrInfo   *in_functions;	/* array of input functions for each attrs */
+	Oid		   *typioparams;	/* array of element types for in_functions */
+	int		   *defmap;			/* array of default att numbers */
+	ExprState **defexprs;		/* array of default att expressions */
+	bool		volatile_defexprs;		/* is any of defexprs volatile? */
+	List	   *range_table;
+
+	/*
+	 * These variables are used to reduce overhead in textual COPY FROM.
+	 *
+	 * attribute_buf holds the separated, de-escaped text for each field of
+	 * the current line.  The CopyReadAttributes functions return arrays of
+	 * pointers into this buffer.  We avoid palloc/pfree overhead by re-using
+	 * the buffer on each cycle.
+	 */
+	StringInfoData attribute_buf;
+
+	/* field raw data pointers found by COPY FROM */
+
+	int			max_fields;
+	char	  **raw_fields;
+
+	/*
+	 * Similarly, line_buf holds the whole input line being processed. The
+	 * input cycle is first to read the whole line into line_buf, convert it
+	 * to server encoding there, and then extract the individual attribute
+	 * fields into attribute_buf.  line_buf is preserved unmodified so that we
+	 * can display it in error messages if appropriate.
+	 */
+	StringInfoData line_buf;
+	bool		line_buf_converted;		/* converted to server encoding? */
+	bool		line_buf_valid; /* contains the row being processed? */
+
+	/*
+	 * Finally, raw_buf holds raw data read from the data source (file or
+	 * client connection).  CopyReadLine parses this data sufficiently to
+	 * locate line boundaries, then transfers the data to line_buf and
+	 * converts it.  Note: we guarantee that there is a \0 at
+	 * raw_buf[raw_buf_len].
+	 */
+#define RAW_BUF_SIZE 65536		/* we palloc RAW_BUF_SIZE+1 bytes */
+	char	   *raw_buf;
+	int			raw_buf_index;	/* next byte to process */
+	int			raw_buf_len;	/* total # of bytes stored */
+} CopyStateData;
+
+
+/* Data structures for keeping track of connections to placements */
+
+/* one open connection to a single placement of a shard */
+typedef struct PlacementConnection
+{
+	int64 shardId;		/* shard the remote COPY on this connection targets */
+	bool prepared;		/* set by PrepareCopyTransaction once Prepare succeeded */
+	PGconn* connection;	/* connection on which the remote COPY was started */
+} PlacementConnection;
+
+/* entry of the shard connection hash; shardId is the hash key */
+typedef struct ShardConnections
+{
+	int64 shardId;
+	List *connectionList;	/* list of PlacementConnection pointers */
+} ShardConnections;
+
+
+/* Local functions forward declarations */
+static HTAB * CreateShardConnectionHash(void);
+static bool IsUniformHashDistribution(ShardInterval **shardIntervalArray,
+									  int shardCount);
+static FmgrInfo * ShardIntervalCompareFunction(Var *partitionColumn, char partitionMethod);
+static ShardInterval * FindShardInterval(Datum partitionColumnValue,
+										 ShardInterval **shardIntervalCache,
+										 int shardCount, char partitionMethod,
+										 FmgrInfo *compareFunction,
+										 FmgrInfo *hashFunction, bool useBinarySearch);
+static ShardInterval * SearchCachedShardInterval(Datum partitionColumnValue,
+												 ShardInterval** shardIntervalCache,
+												 int shardCount,
+												 FmgrInfo *compareFunction);
+static void OpenShardConnections(CopyStmt *copyStatement,
+								 ShardConnections *shardConnections,
+								 int64 shardId);
+static char * ConstructCopyStatement(CopyStmt *copyStatement, int64 shardId);
+static void AppendColumnNames(StringInfo buf, List *columnList);
+static void AppendCopyOptions(StringInfo buf, List *copyOptionList);
+static void CopyRowToPlacements(StringInfo lineBuf, ShardConnections *shardConnections);
+static List * ConnectionList(HTAB *connectionHash);
+static void PrepareCopyTransaction(List *connectionList);
+static bool EndRemoteCopy(PGconn *connection);
+static void AbortCopyTransaction(List *connectionList);
+static void CommitCopyTransaction(List *connectionList);
+
+
+/*
+ * CitusCopyFrom implements the COPY table_name FROM ... for hash-partitioned
+ * and range-partitioned tables.
+ */
+void
+CitusCopyFrom(CopyStmt *copyStatement, char* completionTag)
+{
+	RangeVar *relation = copyStatement->relation;
+	Oid tableId = RangeVarGetRelid(relation, NoLock, false);
+	char *relationName = get_rel_name(tableId);
+	List *shardIntervalList = NULL;
+	ListCell *shardIntervalCell = NULL;
+	char partitionMethod = '\0';
+	Var *partitionColumn = NULL;
+	HTAB *shardConnectionHash = NULL;
+	List *connectionList = NIL;
+	MemoryContext tupleContext = NULL;
+	CopyState copyState = NULL;
+	TupleDesc tupleDescriptor = NULL;
+	uint32 columnCount = 0;
+	Datum *columnValues = NULL;
+	bool *columnNulls = NULL;
+	Relation rel = NULL;
+	ShardInterval **shardIntervalCache = NULL;
+	bool useBinarySearch = false;
+	TypeCacheEntry *typeEntry = NULL;
+	FmgrInfo *hashFunction = NULL;
+	FmgrInfo *compareFunction = NULL;
+	int shardCount = 0;
+	uint64 processedRowCount = 0;
+	ErrorContextCallback errorCallback;
+
+	/* disallow COPY to/from file or program except for superusers */
+	if (copyStatement->filename != NULL && !superuser())
+	{
+		if (copyStatement->is_program)
+		{
+			ereport(ERROR,
+					(errcode(ERRCODE_INSUFFICIENT_PRIVILEGE),
+					 errmsg("must be superuser to COPY to or from an external program"),
+					 errhint("Anyone can COPY to stdout or from stdin. "
+							 "psql's \\copy command also works for anyone.")));
+		}
+		else
+		{
+			ereport(ERROR,
+					(errcode(ERRCODE_INSUFFICIENT_PRIVILEGE),
+					 errmsg("must be superuser to COPY to or from a file"),
+					 errhint("Anyone can COPY to stdout or from stdin. "
+							 "psql's \\copy command also works for anyone.")));
+		}
+	}
+
+	/* load the list of shards and verify that we have shards to copy into */
+	shardIntervalList = LoadShardIntervalList(tableId);
+	if (shardIntervalList == NIL)
+	{
+		ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE),
+						errmsg("could not find any shards for query"),
+						errdetail("No shards exist for distributed table \"%s\".",
+								  relationName),
+						errhint("Run master_create_worker_shards to create shards "
+								"and try again.")));
+	}
+
+	partitionMethod = PartitionMethod(tableId);
+	if (partitionMethod != DISTRIBUTE_BY_RANGE && partitionMethod != DISTRIBUTE_BY_HASH)
+	{
+		ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+						errmsg("COPY is only supported for hash- and range-partitioned tables")));
+	}
+
+	partitionColumn = PartitionColumn(tableId, 0);
+
+	/* resolve hash function for partition column */
+	typeEntry = lookup_type_cache(partitionColumn->vartype, TYPECACHE_HASH_PROC_FINFO);
+	hashFunction = &(typeEntry->hash_proc_finfo);
+
+	/* resolve compare function for shard intervals */
+	compareFunction = ShardIntervalCompareFunction(partitionColumn, partitionMethod);
+
+	/* allocate column values and nulls arrays */
+	rel = heap_open(tableId, AccessShareLock);
+	tupleDescriptor = RelationGetDescr(rel);
+	columnCount = tupleDescriptor->natts;
+	columnValues = palloc0(columnCount * sizeof(Datum));
+	columnNulls = palloc0(columnCount * sizeof(bool));
+
+	/*
+	 * We create a new memory context called tuple context, and read and write
+	 * each row's values within this memory context. After each read and write,
+	 * we reset the memory context. That way, we immediately release memory
+	 * allocated for each row, and don't bloat memory usage with large input
+	 * files.
+	 */
+	tupleContext = AllocSetContextCreate(CurrentMemoryContext,
+										 "COPY Row Memory Context",
+										 ALLOCSET_DEFAULT_MINSIZE,
+										 ALLOCSET_DEFAULT_INITSIZE,
+										 ALLOCSET_DEFAULT_MAXSIZE);
+
+	/* create a mapping of shard id to a connection for each of its placements */
+	shardConnectionHash = CreateShardConnectionHash();
+
+	/* initialize copy state to read from COPY data source */
+	copyState = BeginCopyFrom(rel, copyStatement->filename,
+							  copyStatement->is_program,
+							  copyStatement->attlist,
+							  copyStatement->options);
+
+	if (copyState->binary)
+	{
+		ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
+						errmsg("Copy in binary mode is not currently supported")));
+	}
+
+	/* set up callback to identify error line number */
+	errorCallback.callback = CopyFromErrorCallback;
+	errorCallback.arg = (void *) copyState;
+	errorCallback.previous = error_context_stack;
+	error_context_stack = &errorCallback;
+
+	/* lock shards in order of shard id to prevent deadlock */
+	shardIntervalList = SortList(shardIntervalList, CompareTasksByShardId);
+
+	foreach(shardIntervalCell, shardIntervalList)
+	{
+		ShardInterval *shardInterval = (ShardInterval *) lfirst(shardIntervalCell);
+		int64 shardId = shardInterval->shardId;
+
+		/* prevent concurrent changes to number of placements */
+		LockShardDistributionMetadata(shardId, ShareLock);
+
+		/* prevent concurrent update/delete statements */
+		LockShardResource(shardId, ShareLock);
+	}
+
+	/* initialize the shard interval cache */
+	shardCount = list_length(shardIntervalList);
+	shardIntervalCache = SortedShardIntervalArray(shardIntervalList);
+
+	/* determine whether to use binary search */
+	if (partitionMethod != DISTRIBUTE_BY_HASH ||
+		!IsUniformHashDistribution(shardIntervalCache, shardCount))
+	{
+		useBinarySearch = true;
+	}
+
+	/* we use a PG_TRY block to roll back on errors (e.g. in NextCopyFrom) */
+	PG_TRY();
+	{
+		while (true)
+		{
+			bool nextRowFound = false;
+			Datum partitionColumnValue = 0;
+			ShardInterval *shardInterval = NULL;
+			int64 shardId = 0;
+			ShardConnections *shardConnections = NULL;
+			bool found = false;
+			StringInfo lineBuf = NULL;
+			MemoryContext oldContext = NULL;
+
+			oldContext = MemoryContextSwitchTo(tupleContext);
+
+			/* parse a row from the input */
+			nextRowFound = NextCopyFrom(copyState, NULL, columnValues, columnNulls, NULL);
+
+			MemoryContextSwitchTo(oldContext);
+
+			if (!nextRowFound)
+			{
+				MemoryContextReset(tupleContext);
+				break;
+			}
+
+			CHECK_FOR_INTERRUPTS();
+
+			/* find the partition column value */
+
+			if (columnNulls[partitionColumn->varattno-1])
+			{
+				ereport(ERROR, (errcode(ERRCODE_NULL_VALUE_NOT_ALLOWED),
+								errmsg("cannot copy row with NULL value "
+									   "in partition column")));
+			}
+
+			partitionColumnValue = columnValues[partitionColumn->varattno-1];
+
+			/* find the shard interval and id for the partition column value */
+			shardInterval = FindShardInterval(partitionColumnValue, shardIntervalCache,
+											  shardCount, partitionMethod,
+											  compareFunction, hashFunction,
+											  useBinarySearch);
+			if (shardInterval == NULL)
+			{
+				ereport(ERROR, (errcode(ERRCODE_OBJECT_NOT_IN_PREREQUISITE_STATE),
+								errmsg("no shard for partition column value")));
+			}
+
+			shardId = shardInterval->shardId;
+
+			/* find the connections to the shard placements */
+			shardConnections = (ShardConnections *) hash_search(shardConnectionHash,
+																&shardInterval->shardId,
+																HASH_ENTER,
+																&found);
+			if (!found)
+			{
+				OpenShardConnections(copyStatement, shardConnections, shardId);
+			}
+
+			/*
+			 * Get the (truncated) line buffer and put the line terminator back.
+			 * appendStringInfoChar keeps the buffer NUL-terminated, which the
+			 * error context callback relies on when it prints the line after a
+			 * failure in CopyRowToPlacements.
+			 */
+			lineBuf = &copyState->line_buf;
+			appendStringInfoChar(lineBuf, '\n');
+
+			/* Replicate row to all shard placements */
+			CopyRowToPlacements(lineBuf, shardConnections);
+
+			processedRowCount += 1;
+
+			MemoryContextReset(tupleContext);
+		}
+
+		/* prepare two phase commit in replicas */
+		connectionList = ConnectionList(shardConnectionHash);
+		PrepareCopyTransaction(connectionList);
+	}
+	PG_CATCH();
+	{
+		EndCopyFrom(copyState);
+		heap_close(rel, AccessShareLock);
+
+		/* roll back all transactions */
+		connectionList = ConnectionList(shardConnectionHash);
+		AbortCopyTransaction(connectionList);
+		PG_RE_THROW();
+	}
+	PG_END_TRY();
+
+	EndCopyFrom(copyState);
+	heap_close(rel, AccessShareLock);
+
+	error_context_stack = errorCallback.previous;
+
+	/* a cancel that arrived after the last row must still abort the copy */
+	if (QueryCancelPending)
+	{
+		AbortCopyTransaction(connectionList);
+		ereport(ERROR, (errcode(ERRCODE_QUERY_CANCELED),
+						errmsg("canceling statement due to user request")));
+	}
+	else
+	{
+		CommitCopyTransaction(connectionList);
+	}
+
+	if (completionTag != NULL)
+	{
+		snprintf(completionTag, COMPLETION_TAG_BUFSIZE,
+				 "COPY " UINT64_FORMAT, processedRowCount);
+	}
+}
+
+
+/*
+ * CreateShardConnectionHash constructs a hash table used for shardId->Connection
+ * mapping. The table is allocated in the caller's current memory context so
+ * that it is released together with that context instead of outliving the
+ * COPY command.
+ */
+static HTAB *
+CreateShardConnectionHash(void)
+{
+	HTAB *shardConnectionsHash = NULL;
+	HASHCTL info;
+
+	memset(&info, 0, sizeof(info));
+	info.keysize = sizeof(int64);
+	info.entrysize = sizeof(ShardConnections);
+	info.hash = tag_hash;
+	info.hcxt = CurrentMemoryContext;
+
+	shardConnectionsHash = hash_create("shardConnectionHash",
+									   INITIAL_CONNECTION_CACHE_SIZE, &info,
+									   HASH_ELEM | HASH_FUNCTION | HASH_CONTEXT);
+
+	return shardConnectionsHash;
+}
+
+
+/*
+ * ShardIntervalCompareFunction returns the appropriate compare function for the
+ * partition column type. In case of hash-partitioning, it always returns the compare
+ * function for integers.
+ */
+static FmgrInfo *
+ShardIntervalCompareFunction(Var *partitionColumn, char partitionMethod)
+{
+	FmgrInfo *compareFunction = NULL;
+
+	if (partitionMethod == DISTRIBUTE_BY_HASH)
+	{
+		/* hash-partitioned shard intervals are compared as int4 hash tokens */
+		compareFunction = GetFunctionInfo(INT4OID, BTREE_AM_OID, BTORDER_PROC);
+	}
+	else
+	{
+		compareFunction = GetFunctionInfo(partitionColumn->vartype,
+										  BTREE_AM_OID, BTORDER_PROC);
+	}
+
+	return compareFunction;
+}
+
+
+/*
+ * IsUniformHashDistribution determines whether the given array of sorted shards
+ * has a uniform hash distribution, as produced by master_create_worker_shards.
+ *
+ * Every shard is checked: shard i must start at INT32_MIN + i * increment and
+ * end one token before the next shard starts, except for the last shard which
+ * must end at INT32_MAX. Returns false as soon as one shard deviates, and also
+ * when shardCount is not positive.
+ */
+static bool
+IsUniformHashDistribution(ShardInterval **shardIntervalArray, int shardCount)
+{
+	uint32 hashTokenIncrement = 0;
+	int shardIndex = 0;
+
+	/* avoid dividing by zero; no shards means no uniform distribution */
+	if (shardCount <= 0)
+	{
+		return false;
+	}
+
+	hashTokenIncrement = (uint32) (HASH_TOKEN_COUNT / shardCount);
+
+	/* the loop header alone advances shardIndex, so no shard is skipped */
+	for (shardIndex = 0; shardIndex < shardCount; shardIndex++)
+	{
+		ShardInterval *shardInterval = shardIntervalArray[shardIndex];
+		int32 shardMinHashToken = INT32_MIN + (shardIndex * hashTokenIncrement);
+		int32 shardMaxHashToken = shardMinHashToken + (hashTokenIncrement - 1);
+
+		if (shardIndex == (shardCount - 1))
+		{
+			shardMaxHashToken = INT32_MAX;
+		}
+
+		if (DatumGetInt32(shardInterval->minValue) != shardMinHashToken ||
+			DatumGetInt32(shardInterval->maxValue) != shardMaxHashToken)
+		{
+			return false;
+		}
+	}
+
+	return true;
+}
+
+
+/*
+ * FindShardInterval finds a single shard interval in the cache for the
+ * given partition column value.
+ */
+static ShardInterval *
+FindShardInterval(Datum partitionColumnValue, ShardInterval **shardIntervalCache,
+				  int shardCount, char partitionMethod, FmgrInfo *compareFunction,
+				  FmgrInfo *hashFunction, bool useBinarySearch)
+{
+	ShardInterval *shardInterval = NULL;
+
+	/* nothing to find, and the direct lookup below would divide by zero */
+	if (shardCount <= 0)
+	{
+		return NULL;
+	}
+
+	if (partitionMethod == DISTRIBUTE_BY_HASH)
+	{
+		int hashedValue = DatumGetInt32(FunctionCall1(hashFunction,
+													  partitionColumnValue));
+		if (useBinarySearch)
+		{
+			shardInterval = SearchCachedShardInterval(Int32GetDatum(hashedValue),
+													  shardIntervalCache, shardCount,
+													  compareFunction);
+		}
+		else
+		{
+			/*
+			 * Direct lookup for uniform distributions. The arithmetic is done
+			 * in unsigned/64-bit types: subtracting INT32_MIN from a
+			 * non-negative int overflows, and a 32-bit increment would
+			 * truncate to zero for a single shard (assuming HASH_TOKEN_COUNT
+			 * is 2^32, as the original uint32 cast suggests -- TODO confirm).
+			 */
+			uint64 hashTokenIncrement = (uint64) HASH_TOKEN_COUNT / (uint64) shardCount;
+			uint32 hashOffset = (uint32) hashedValue - (uint32) INT32_MIN;
+			uint64 shardIndex = (uint64) hashOffset / hashTokenIncrement;
+
+			/*
+			 * When shardCount does not divide the token space evenly, the
+			 * last shard also covers the remainder (its upper bound is
+			 * INT32_MAX), so offsets past the last full increment map to it.
+			 */
+			if (shardIndex >= (uint64) shardCount)
+			{
+				shardIndex = (uint64) shardCount - 1;
+			}
+
+			shardInterval = shardIntervalCache[shardIndex];
+		}
+	}
+	else
+	{
+		shardInterval = SearchCachedShardInterval(partitionColumnValue,
+												  shardIntervalCache, shardCount,
+												  compareFunction);
+	}
+
+	return shardInterval;
+}
+
+
+/*
+ * SearchCachedShardInterval performs a binary search for a shard interval matching a
+ * given partition column value and returns it. The cache must be sorted by
+ * interval; NULL is returned when no interval contains the value.
+ */
+static ShardInterval *
+SearchCachedShardInterval(Datum partitionColumnValue, ShardInterval** shardIntervalCache,
+						  int shardCount, FmgrInfo *compareFunction)
+{
+	int lowerBoundIndex = 0;
+	int upperBoundIndex = shardCount;
+
+	while (lowerBoundIndex < upperBoundIndex)
+	{
+		int middleIndex = (lowerBoundIndex + upperBoundIndex) >> 1;
+
+		if (DatumGetInt32(FunctionCall2Coll(compareFunction,
+											DEFAULT_COLLATION_OID,
+											partitionColumnValue,
+											shardIntervalCache[middleIndex]->minValue)) < 0)
+		{
+			/* value sorts before this interval: continue in the lower half */
+			upperBoundIndex = middleIndex;
+		}
+		else if (DatumGetInt32(FunctionCall2Coll(compareFunction,
+												 DEFAULT_COLLATION_OID,
+												 partitionColumnValue,
+												 shardIntervalCache[middleIndex]->maxValue)) <= 0)
+		{
+			/* minValue <= value <= maxValue: found the interval */
+			return shardIntervalCache[middleIndex];
+		}
+		else
+		{
+			/* value sorts after this interval: continue in the upper half */
+			lowerBoundIndex = middleIndex + 1;
+		}
+	}
+
+	return NULL;
+}
+
+
+/*
+ * Open connections for each placement of a shard.
+ * If a connection cannot be opened,
+ * the shard placement is marked as inactive and the COPY continues with the
+ * remaining shard placements.
+ *
+ * On return, shardConnections holds the shard id and one PlacementConnection
+ * per placement on which a transaction and the remote COPY were started. An
+ * ERROR is raised when no placement could be used at all.
+ */
+static void
+OpenShardConnections(CopyStmt *copyStatement, ShardConnections *shardConnections,
+					 int64 shardId)
+{
+	CitusTransactionManager const *transactionManager =
+		&CitusTransactionManagerImpl[CopyTransactionManager];
+
+	List *placementList = ShardPlacementList(shardId);
+	List *unusablePlacementList = NIL;
+	List *openConnectionList = NIL;
+	ListCell *cell = NULL;
+
+	foreach(cell, placementList)
+	{
+		ShardPlacement *placement = (ShardPlacement *) lfirst(cell);
+		char *nodeName = placement->nodeName;
+		int nodePort = placement->nodePort;
+		PlacementConnection *placementConnection = NULL;
+		char *copyCommand = NULL;
+
+		PGconn *connection = ConnectToNode(nodeName, nodePort);
+		if (connection == NULL)
+		{
+			unusablePlacementList = lappend(unusablePlacementList, placement);
+			ereport(WARNING, (errcode(ERRCODE_IO_ERROR),
+							  errmsg("Failed to connect to node %s:%d",
+									 nodeName, nodePort)));
+			continue;
+		}
+
+		/* the COPY command targets the shard: shard id is appended to the table name */
+		copyCommand = ConstructCopyStatement(copyStatement, shardId);
+
+		/* new connection: start a transaction and then the COPY command on it */
+		if (!transactionManager->Begin(connection) ||
+			!ExecuteCommand(connection, PGRES_COPY_IN, copyCommand))
+		{
+			unusablePlacementList = lappend(unusablePlacementList, placement);
+			ereport(WARNING, (errcode(ERRCODE_IO_ERROR),
+							  errmsg("Failed to start '%s' on node %s:%d",
+									 copyCommand, nodeName, nodePort)));
+			continue;
+		}
+
+		placementConnection = (PlacementConnection *) palloc0(sizeof(PlacementConnection));
+		placementConnection->shardId = shardId;
+		placementConnection->prepared = false;
+		placementConnection->connection = connection;
+
+		openConnectionList = lappend(openConnectionList, placementConnection);
+	}
+
+	/* if all placements failed, error out */
+	if (list_length(unusablePlacementList) == list_length(placementList))
+	{
+		ereport(ERROR, (errmsg("could not modify any active placements")));
+	}
+
+	/* otherwise, mark failed placements as inactive: they're stale */
+	foreach(cell, unusablePlacementList)
+	{
+		ShardPlacement *stalePlacement = (ShardPlacement *) lfirst(cell);
+		uint64 shardLength = 0;
+
+		/* replace the placement row with one in FILE_INACTIVE state */
+		DeleteShardPlacementRow(stalePlacement->shardId, stalePlacement->nodeName,
+								stalePlacement->nodePort);
+		InsertShardPlacementRow(stalePlacement->shardId, FILE_INACTIVE, shardLength,
+								stalePlacement->nodeName, stalePlacement->nodePort);
+	}
+
+	shardConnections->shardId = shardId;
+	shardConnections->connectionList = openConnectionList;
+}
+
+
+/*
+ * ConstructCopyStatement constructs the text of a COPY statement for a
+ * particular shard.
+ */
+static char *
+ConstructCopyStatement(CopyStmt *copyStatement, int64 shardId)
+{
+	StringInfo buf = makeStringInfo();
+	StringInfo shardName = makeStringInfo();
+	char *qualifiedName = NULL;
+
+	/*
+	 * Append the shard id to the bare relation name before quoting. Quoting
+	 * first would put the suffix outside the quotes ("My Table"_102008) for
+	 * any name that needs quoting, which is not a valid identifier.
+	 */
+	appendStringInfo(shardName, "%s_%ld", copyStatement->relation->relname,
+					 (long) shardId);
+
+	qualifiedName = quote_qualified_identifier(copyStatement->relation->schemaname,
+											   shardName->data);
+
+	appendStringInfo(buf, "COPY %s ", qualifiedName);
+
+	if (copyStatement->attlist != NIL)
+	{
+		AppendColumnNames(buf, copyStatement->attlist);
+	}
+
+	appendStringInfoString(buf, "FROM STDIN");
+
+	if (copyStatement->options)
+	{
+		appendStringInfoString(buf, " WITH ");
+
+		AppendCopyOptions(buf, copyStatement->options);
+	}
+
+	return buf->data;
+}
+
+
+/*
+ * AppendCopyOptions deparses a list of CopyStmt options and appends them to buf
+ * as a parenthesized, comma-separated option list.
+ *
+ * A true "header" option is dropped, column-list options are written as
+ * quoted identifier lists, string values are written as quoted literals so
+ * that values such as ',' or '' survive the round trip, and options given
+ * without a value are written as the bare option name.
+ */
+static void
+AppendCopyOptions(StringInfo buf, List *copyOptionList)
+{
+	ListCell *optionCell = NULL;
+	char separator = '(';
+
+	foreach(optionCell, copyOptionList)
+	{
+		DefElem *defel = (DefElem *) lfirst(optionCell);
+
+		if (strcmp(defel->defname, "header") == 0 && defGetBoolean(defel))
+		{
+			/* worker should not skip header again */
+			continue;
+		}
+
+		appendStringInfo(buf, "%c%s ", separator, defel->defname);
+
+		if (strcmp(defel->defname, "force_quote") == 0)
+		{
+			if (!defel->arg)
+			{
+				ereport(ERROR,
+						(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
+						 errmsg("argument to option \"%s\" must be a list of column names",
+								defel->defname)));
+			}
+			else if (IsA(defel->arg, A_Star))
+			{
+				appendStringInfoString(buf, "*");
+			}
+			else
+			{
+				AppendColumnNames(buf, (List *) defel->arg);
+			}
+		}
+		else if (strcmp(defel->defname, "force_not_null") == 0 ||
+				 strcmp(defel->defname, "force_null") == 0)
+		{
+			if (!defel->arg)
+			{
+				ereport(ERROR,
+						(errcode(ERRCODE_INVALID_PARAMETER_VALUE),
+						 errmsg("argument to option \"%s\" must be a list of column names",
+								defel->defname)));
+			}
+			else
+			{
+				AppendColumnNames(buf, (List *) defel->arg);
+			}
+		}
+		else if (defel->arg == NULL)
+		{
+			/* option without a value, e.g. (FREEZE): the name alone is enough */
+		}
+		else if (IsA(defel->arg, String))
+		{
+			/* string values must be literals, e.g. delimiter ',' or null '' */
+			appendStringInfoString(buf, quote_literal_cstr(defGetString(defel)));
+		}
+		else
+		{
+			/* numeric values are valid as they are */
+			appendStringInfoString(buf, defGetString(defel));
+		}
+
+		separator = ',';
+	}
+
+	appendStringInfoChar(buf, ')');
+}
+
+
+/*
+ * AppendColumnNames deparses a list of column names into a StringInfo as a
+ * parenthesized, comma-separated list of quoted identifiers.
+ */
+static void
+AppendColumnNames(StringInfo buf, List *columnList)
+{
+	ListCell *attributeCell = NULL;
+	char separator = '(';
+
+	foreach(attributeCell, columnList)
+	{
+		char *columnName = strVal(lfirst(attributeCell));
+		appendStringInfo(buf, "%c%s", separator, quote_identifier(columnName));
+		separator = ',';
+	}
+
+	appendStringInfoChar(buf, ')');
+}
+
+
+/*
+ * CopyRowToPlacements copies a row to a list of placements for a shard. An
+ * ERROR is raised as soon as sending the row to one placement fails.
+ */
+static void
+CopyRowToPlacements(StringInfo lineBuf, ShardConnections *shardConnections)
+{
+	ListCell *connectionCell = NULL;
+	foreach(connectionCell, shardConnections->connectionList)
+	{
+		PlacementConnection *placementConnection =
+			(PlacementConnection *) lfirst(connectionCell);
+		PGconn *connection = placementConnection->connection;
+		int64 shardId = shardConnections->shardId;
+
+		/* copy the line buffer into the placement */
+		int copyResult = PQputCopyData(connection, lineBuf->data, lineBuf->len);
+		if (copyResult != 1)
+		{
+			char *nodeName = ConnectionGetOptionValue(connection, "host");
+			char *nodePort = ConnectionGetOptionValue(connection, "port");
+			ereport(ERROR, (errcode(ERRCODE_IO_ERROR),
+							errmsg("COPY to shard %ld on %s:%s failed",
+								   (long) shardId, nodeName, nodePort)));
+		}
+	}
+}
+
+
+/*
+ * ConnectionList flattens the connection hash to a list of placement connections.
+ */
+static List *
+ConnectionList(HTAB *connectionHash)
+{
+	List *flattenedList = NIL;
+	HASH_SEQ_STATUS scanStatus;
+	ShardConnections *entry = NULL;
+
+	hash_seq_init(&scanStatus, connectionHash);
+
+	/* walk every hash entry and collect its placement connections in order */
+	for (entry = (ShardConnections *) hash_seq_search(&scanStatus);
+		 entry != NULL;
+		 entry = (ShardConnections *) hash_seq_search(&scanStatus))
+	{
+		ListCell *cell = NULL;
+
+		foreach(cell, entry->connectionList)
+		{
+			flattenedList = lappend(flattenedList, (PlacementConnection *) lfirst(cell));
+		}
+	}
+
+	return flattenedList;
+}
+
+
+/*
+ * PrepareCopyTransaction ends the remote COPY and prepares the transaction on
+ * every placement connection in the list. A connection's prepared flag is set
+ * once both steps succeeded on it; the first failure raises an ERROR.
+ */
+static void
+PrepareCopyTransaction(List *connectionList)
+{
+	CitusTransactionManager const *transactionManager =
+		&CitusTransactionManagerImpl[CopyTransactionManager];
+	ListCell *cell = NULL;
+
+	foreach(cell, connectionList)
+	{
+		PlacementConnection *placementConnection = (PlacementConnection *) lfirst(cell);
+		PGconn *connection = placementConnection->connection;
+		int64 shardId = placementConnection->shardId;
+		bool transactionPrepared = false;
+
+		/* only try to prepare when the COPY itself ended cleanly */
+		if (EndRemoteCopy(connection))
+		{
+			char *transactionId = BuildTransactionId(shardId);
+
+			transactionPrepared = transactionManager->Prepare(connection, transactionId);
+		}
+
+		if (!transactionPrepared)
+		{
+			ereport(ERROR, (errcode(ERRCODE_IO_ERROR),
+							errmsg("Failed to prepare transaction for shard %ld",
+								   (long) shardId)));
+		}
+
+		placementConnection->prepared = true;
+	}
+}
+
+
+/*
+ * EndRemoteCopy ends the COPY on the given connection with PQputCopyEnd and
+ * checks the result.
+ */
+static bool
+EndRemoteCopy(PGconn *connection)
+{
+	bool copySucceeded = true;
+	PGresult *result = NULL;
+
+	int copyEndResult = PQputCopyEnd(connection, NULL);
+	if (copyEndResult != 1)
+	{
+		return false;
+	}
+
+	/*
+	 * Consume every pending result, also after a failure: each PGresult has
+	 * to be cleared, and the connection can only take another command (such
+	 * as a rollback) once PQgetResult has returned NULL. Only the first
+	 * failure is reported.
+	 */
+	while ((result = PQgetResult(connection)) != NULL)
+	{
+		int resultStatus = PQresultStatus(result);
+		if (resultStatus != PGRES_COMMAND_OK && copySucceeded)
+		{
+			ReportRemoteError(connection, result);
+			copySucceeded = false;
+		}
+
+		PQclear(result);
+	}
+
+	return copySucceeded;
+}
+
+
+/*
+ * AbortCopyTransaction aborts a two-phase commit. It attempts to roll back
+ * all transactions even if some of them fail, in which case a warning is given
+ * for each of them. Every connection in the list is closed afterwards.
+ */
+static void
+AbortCopyTransaction(List *connectionList)
+{
+	CitusTransactionManager const *transactionManager =
+		&CitusTransactionManagerImpl[CopyTransactionManager];
+
+	ListCell *connectionCell = NULL;
+	foreach(connectionCell, connectionList)
+	{
+		PlacementConnection *placementConnection =
+			(PlacementConnection *) lfirst(connectionCell);
+		PGconn *connection = placementConnection->connection;
+		int64 shardId = placementConnection->shardId;
+		char *nodeName = ConnectionGetOptionValue(connection, "host");
+		char *nodePort = ConnectionGetOptionValue(connection, "port");
+
+		if (placementConnection->prepared)
+		{
+			char *transactionId = BuildTransactionId(shardId);
+
+			if (!transactionManager->RollbackPrepared(connection, transactionId))
+			{
+				ereport(WARNING, (errcode(ERRCODE_IO_ERROR),
+								  errmsg("Failed to roll back transaction '%s' on %s:%s",
+										 transactionId, nodeName, nodePort)));
+			}
+		}
+		/*
+		 * NOTE(review): when EndRemoteCopy succeeds no explicit rollback is
+		 * sent; presumably closing the connection below is relied upon to
+		 * abort the open remote transaction -- confirm.
+		 */
+		else if (!EndRemoteCopy(connection) &&
+				 !transactionManager->Rollback(connection))
+		{
+			/* shardId is int64: cast to match %ld, as the other messages do */
+			ereport(WARNING, (errcode(ERRCODE_IO_ERROR),
+							  errmsg("Failed to COPY to shard %ld on %s:%s",
+									 (long) shardId, nodeName, nodePort)));
+		}
+
+		PQfinish(connection);
+	}
+}
+
+
+/*
+ * CommitCopyTransaction commits a two-phase commit. It attempts to commit all
+ * transactions even if some of them fail, in which case a warning is given
+ * for each of them.
+ */ +static void +CommitCopyTransaction(List *connectionList) +{ + CitusTransactionManager const *transactionManager = + &CitusTransactionManagerImpl[CopyTransactionManager]; + + ListCell *connectionCell = NULL; + foreach(connectionCell, connectionList) + { + PlacementConnection *placementConnection = + (PlacementConnection *) lfirst(connectionCell); + PGconn *connection = placementConnection->connection; + int64 shardId = placementConnection->shardId; + char *transactionId = BuildTransactionId(shardId); + + Assert(placementConnection->prepared); + + if (!transactionManager->CommitPrepared(connection, transactionId)) + { + char *nodeName = ConnectionGetOptionValue(connection, "host"); + char *nodePort = ConnectionGetOptionValue(connection, "port"); + ereport(WARNING, (errcode(ERRCODE_IO_ERROR), + errmsg("Failed to commit transaction '%s' on %s:%s", + transactionId, nodeName, nodePort))); + } + + PQfinish(connection); + } +} + diff --git a/src/backend/distributed/executor/multi_utility.c b/src/backend/distributed/executor/multi_utility.c index be4777085..39c4d27c5 100644 --- a/src/backend/distributed/executor/multi_utility.c +++ b/src/backend/distributed/executor/multi_utility.c @@ -17,6 +17,7 @@ #include "commands/tablecmds.h" #include "distributed/master_protocol.h" #include "distributed/metadata_cache.h" +#include "distributed/multi_copy.h" #include "distributed/multi_utility.h" #include "distributed/multi_join_order.h" #include "distributed/transmit.h" @@ -50,7 +51,7 @@ static bool IsTransmitStmt(Node *parsetree); static void VerifyTransmitStmt(CopyStmt *copyStatement); /* Local functions forward declarations for processing distributed table commands */ -static Node * ProcessCopyStmt(CopyStmt *copyStatement); +static Node * ProcessCopyStmt(CopyStmt *copyStatement, char *completionTag); static Node * ProcessIndexStmt(IndexStmt *createIndexStatement, const char *createIndexCommand); static Node * ProcessDropIndexStmt(DropStmt *dropIndexStatement, @@ -122,7 
+123,12 @@ multi_ProcessUtility(Node *parsetree, if (IsA(parsetree, CopyStmt)) { - parsetree = ProcessCopyStmt((CopyStmt *) parsetree); + parsetree = ProcessCopyStmt((CopyStmt *) parsetree, completionTag); + + if (parsetree == NULL) + { + return; + } } if (IsA(parsetree, IndexStmt)) @@ -300,7 +306,7 @@ VerifyTransmitStmt(CopyStmt *copyStatement) * COPYing from distributed tables and preventing unsupported actions. */ static Node * -ProcessCopyStmt(CopyStmt *copyStatement) +ProcessCopyStmt(CopyStmt *copyStatement, char *completionTag) { /* * We first check if we have a "COPY (query) TO filename". If we do, copy doesn't @@ -344,9 +350,8 @@ ProcessCopyStmt(CopyStmt *copyStatement) { if (copyStatement->is_from) { - ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED), - errmsg("cannot execute COPY FROM on a distributed table " - "on master node"))); + CitusCopyFrom(copyStatement, completionTag); + return NULL; } else if (!copyStatement->is_from) { diff --git a/src/backend/distributed/planner/multi_physical_planner.c b/src/backend/distributed/planner/multi_physical_planner.c index b34c36b4c..64d7a5f16 100644 --- a/src/backend/distributed/planner/multi_physical_planner.c +++ b/src/backend/distributed/planner/multi_physical_planner.c @@ -110,8 +110,6 @@ static MapMergeJob * BuildMapMergeJob(Query *jobQuery, List *dependedJobList, Oid baseRelationId, BoundaryNodeJobType boundaryNodeJobType); static uint32 HashPartitionCount(void); -static int CompareShardIntervals(const void *leftElement, const void *rightElement, - FmgrInfo *typeCompareFunction); static ArrayType * SplitPointObject(ShardInterval **shardIntervalArray, uint32 shardIntervalCount); @@ -169,7 +167,6 @@ static List * RoundRobinAssignTaskList(List *taskList); static List * RoundRobinReorder(Task *task, List *placementList); static List * ReorderAndAssignTaskList(List *taskList, List * (*reorderFunction)(Task *, List *)); -static int CompareTasksByShardId(const void *leftElement, const void *rightElement); 
static List * ActiveShardPlacementLists(List *taskList); static List * ActivePlacementList(List *placementList); static List * LeftRotateList(List *list, uint32 rotateCount); @@ -1810,7 +1807,7 @@ SortedShardIntervalArray(List *shardIntervalList) * CompareShardIntervals acts as a helper function to compare two shard interval * pointers by their minimum values, using the value's type comparison function. */ -static int +int CompareShardIntervals(const void *leftElement, const void *rightElement, FmgrInfo *typeCompareFunction) { @@ -5073,7 +5070,7 @@ ReorderAndAssignTaskList(List *taskList, List * (*reorderFunction)(Task *, List /* Helper function to compare two tasks by their anchor shardId. */ -static int +int CompareTasksByShardId(const void *leftElement, const void *rightElement) { const Task *leftTask = *((const Task **) leftElement); diff --git a/src/backend/distributed/shared_library_init.c b/src/backend/distributed/shared_library_init.c index 332319b2a..0a5db8697 100644 --- a/src/backend/distributed/shared_library_init.c +++ b/src/backend/distributed/shared_library_init.c @@ -20,6 +20,7 @@ #include "executor/executor.h" #include "distributed/master_protocol.h" #include "distributed/modify_planner.h" +#include "distributed/multi_copy.h" #include "distributed/multi_executor.h" #include "distributed/multi_explain.h" #include "distributed/multi_join_order.h" @@ -29,6 +30,7 @@ #include "distributed/multi_server_executor.h" #include "distributed/multi_utility.h" #include "distributed/task_tracker.h" +#include "distributed/transaction_manager.h" #include "distributed/worker_manager.h" #include "distributed/worker_protocol.h" #include "postmaster/postmaster.h" @@ -67,6 +69,12 @@ static const struct config_enum_entry shard_placement_policy_options[] = { { NULL, 0, false } }; +static const struct config_enum_entry transaction_manager_options[] = { + { "1pc", TRANSACTION_MANAGER_1PC, false }, + { "2pc", TRANSACTION_MANAGER_2PC, false }, + { NULL, 0, false } +}; + /* 
shared library initialization function */
 void
@@ -437,6 +445,21 @@ RegisterCitusConfigVariables(void)
 		0,
 		NULL, NULL, NULL);
 
+	DefineCustomEnumVariable(
+		"citus.copy_transaction_manager",
+		gettext_noop("Sets the transaction manager for COPY into distributed tables."),
+		gettext_noop("When a failure occurs while copying into a distributed "
+					 "table, 2PC is required to ensure data is never lost. Change "
+					 "this setting to '2pc' from its default '1pc' to enable 2PC. "
+					 "You must also set max_prepared_transactions on the worker "
+					 "nodes. Recovery from failed 2PCs is currently manual."),
+		&CopyTransactionManager,
+		TRANSACTION_MANAGER_1PC,
+		transaction_manager_options,
+		PGC_USERSET,
+		0,
+		NULL, NULL, NULL);
+
 	DefineCustomEnumVariable(
 		"citus.task_assignment_policy",
 		gettext_noop("Sets the policy to use when assigning tasks to worker nodes."),
diff --git a/src/backend/distributed/utils/connection_cache.c b/src/backend/distributed/utils/connection_cache.c
index a7391aa5b..8995f0b88 100644
--- a/src/backend/distributed/utils/connection_cache.c
+++ b/src/backend/distributed/utils/connection_cache.c
@@ -39,8 +39,6 @@ static HTAB *NodeConnectionHash = NULL;
 
 /* local function forward declarations */
 static HTAB * CreateNodeConnectionHash(void);
-static PGconn * ConnectToNode(char *nodeName, char *nodePort);
-static char * ConnectionGetOptionValue(PGconn *connection, char *optionKeyword);
 
 
 /*
@@ -99,10 +97,7 @@ GetOrEstablishConnection(char *nodeName, int32 nodePort)
 
 	if (needNewConnection)
 	{
-		StringInfo nodePortString = makeStringInfo();
-		appendStringInfo(nodePortString, "%d", nodePort);
-
-		connection = ConnectToNode(nodeName, nodePortString->data);
+		connection = ConnectToNode(nodeName, nodePort);
 		if (connection != NULL)
 		{
 			nodeConnectionEntry = hash_search(NodeConnectionHash, &nodeConnectionKey,
@@ -264,8 +259,8 @@ CreateNodeConnectionHash(void)
  * We attempt to connect up to MAX_CONNECT_ATTEMPT times. After that we give up
  * and return NULL.
*/ -static PGconn * -ConnectToNode(char *nodeName, char *nodePort) +PGconn * +ConnectToNode(char *nodeName, int32 nodePort) { PGconn *connection = NULL; const char *clientEncoding = GetDatabaseEncodingName(); @@ -276,11 +271,14 @@ ConnectToNode(char *nodeName, char *nodePort) "host", "port", "fallback_application_name", "client_encoding", "connect_timeout", "dbname", NULL }; + char nodePortString[12]; const char *valueArray[] = { - nodeName, nodePort, "citus", clientEncoding, + nodeName, nodePortString, "citus", clientEncoding, CLIENT_CONNECT_TIMEOUT_SECONDS, dbname, NULL }; + sprintf(nodePortString, "%d", nodePort); + Assert(sizeof(keywordArray) == sizeof(valueArray)); for (attemptIndex = 0; attemptIndex < MAX_CONNECT_ATTEMPTS; attemptIndex++) @@ -313,7 +311,7 @@ ConnectToNode(char *nodeName, char *nodePort) * The function returns NULL if the connection has no setting for an option with * the provided keyword. */ -static char * +char * ConnectionGetOptionValue(PGconn *connection, char *optionKeyword) { char *optionValue = NULL; diff --git a/src/backend/distributed/utils/transaction_manager.c b/src/backend/distributed/utils/transaction_manager.c new file mode 100644 index 000000000..e89aba54b --- /dev/null +++ b/src/backend/distributed/utils/transaction_manager.c @@ -0,0 +1,173 @@ +/*------------------------------------------------------------------------- + * + * transaction_manager.c + * This file contains functions that comprise a pluggable API for + * managing transactions across many worker nodes using 1PC or 2PC. + * + * Contributed by Konstantin Knizhnik, Postgres Professional + * + * Copyright (c) 2016, Citus Data, Inc. 
+ *
+ *-------------------------------------------------------------------------
+ */
+
+#include "postgres.h"
+#include "libpq-fe.h"
+#include "miscadmin.h"
+
+#include "access/xact.h"
+#include "distributed/connection_cache.h"
+#include "distributed/transaction_manager.h"
+
+static bool BeginTransaction(PGconn *connection);
+static bool Prepare1PC(PGconn *connection, char *transactionId);
+static bool CommitPrepared1PC(PGconn *connection, char *transactionId);
+static bool RollbackPrepared1PC(PGconn *connection, char *transactionId);
+static bool RollbackTransaction(PGconn *connection);
+
+static bool Prepare2PC(PGconn *connection, char *transactionId);
+static bool CommitPrepared2PC(PGconn *connection, char *transactionId);
+static bool RollbackPrepared2PC(PGconn *connection, char *transactionId);
+
+static char * Build2PCCommand(char const *command, char *transactionId);
+
+CitusTransactionManager const CitusTransactionManagerImpl[] =
+{
+	{ BeginTransaction, Prepare1PC, CommitPrepared1PC,
+	  RollbackPrepared1PC, RollbackTransaction }, /* TRANSACTION_MANAGER_1PC */
+	{ BeginTransaction, Prepare2PC, CommitPrepared2PC,
+	  RollbackPrepared2PC, RollbackTransaction } /* TRANSACTION_MANAGER_2PC */
+};
+
+
+/*
+ * BeginTransaction sends a BEGIN command to start a transaction.
+ */
+static bool
+BeginTransaction(PGconn *connection)
+{
+	return ExecuteCommand(connection, PGRES_COMMAND_OK, "BEGIN");
+}
+
+
+/*
+ * Prepare1PC does nothing since 1PC mode does not have a prepare phase.
+ * This function is provided for compatibility with the 2PC API.
+ */
+static bool
+Prepare1PC(PGconn *connection, char *transactionId)
+{
+	return true;
+}
+
+
+/*
+ * CommitPrepared1PC sends a COMMIT command to commit a transaction.
+ */
+static bool
+CommitPrepared1PC(PGconn *connection, char *transactionId)
+{
+	return ExecuteCommand(connection, PGRES_COMMAND_OK, "COMMIT");
+}
+
+
+/*
+ * RollbackPrepared1PC sends a ROLLBACK command to roll a transaction
+ * back. This function is provided for compatibility with the 2PC API.
+ */
+static bool
+RollbackPrepared1PC(PGconn *connection, char *transactionId)
+{
+	return ExecuteCommand(connection, PGRES_COMMAND_OK, "ROLLBACK");
+}
+
+
+/*
+ * Prepare2PC sends a PREPARE TRANSACTION command to prepare a 2PC.
+ */
+static bool
+Prepare2PC(PGconn *connection, char *transactionId)
+{
+	return ExecuteCommand(connection, PGRES_COMMAND_OK,
+						  Build2PCCommand("PREPARE TRANSACTION", transactionId));
+}
+
+
+/*
+ * CommitPrepared2PC sends a COMMIT PREPARED command to commit a 2PC.
+ */
+static bool
+CommitPrepared2PC(PGconn *connection, char *transactionId)
+{
+	return ExecuteCommand(connection, PGRES_COMMAND_OK,
+						  Build2PCCommand("COMMIT PREPARED", transactionId));
+}
+
+
+/*
+ * RollbackPrepared2PC sends a ROLLBACK PREPARED command to roll back a 2PC.
+ */
+static bool
+RollbackPrepared2PC(PGconn *connection, char *transactionId)
+{
+	return ExecuteCommand(connection, PGRES_COMMAND_OK,
+						  Build2PCCommand("ROLLBACK PREPARED", transactionId));
+}
+
+
+/*
+ * RollbackTransaction sends a ROLLBACK command to roll a transaction back.
+ */
+static bool
+RollbackTransaction(PGconn *connection)
+{
+	return ExecuteCommand(connection, PGRES_COMMAND_OK, "ROLLBACK");
+}
+
+
+/*
+ * Build2PCCommand builds "<command> '<transactionId>'" for a two-phase commit.
+ */
+static char *
+Build2PCCommand(char const *command, char *transactionId)
+{
+	StringInfo commandString = makeStringInfo();
+
+	appendStringInfo(commandString, "%s '%s'", command, transactionId);
+
+	return commandString->data;
+}
+
+
+/*
+ * BuildTransactionId builds "citus_<pid>_<xid>_<localId>" from an application-
+ * specific id. NOTE(review): localId is int, so int64 shard ids get narrowed.
+ */
+char *
+BuildTransactionId(int localId)
+{
+	StringInfo commandString = makeStringInfo();
+
+	appendStringInfo(commandString, "citus_%d_%u_%d", MyProcPid,
+					 GetCurrentTransactionId(), localId);
+
+	return commandString->data;
+}
+
+
+/*
+ * ExecuteCommand executes a statement on a remote node and checks its result.
+ */
+bool
+ExecuteCommand(PGconn *connection, ExecStatusType expectedResult, char const *command)
+{
+	bool ret = true;
+	PGresult *result = PQexec(connection, command);
+	if (PQresultStatus(result) != expectedResult)
+	{
+		ReportRemoteError(connection, result);
+		ret = false;
+	}
+	PQclear(result); /* released on both the success and the failure path */
+	return ret;
+}
diff --git a/src/include/distributed/connection_cache.h b/src/include/distributed/connection_cache.h
index 45f61742b..b99e1a69f 100644
--- a/src/include/distributed/connection_cache.h
+++ b/src/include/distributed/connection_cache.h
@@ -54,6 +54,8 @@ typedef struct NodeConnectionEntry
 extern PGconn * GetOrEstablishConnection(char *nodeName, int32 nodePort);
 extern void PurgeConnection(PGconn *connection);
 extern void ReportRemoteError(PGconn *connection, PGresult *result);
+extern PGconn * ConnectToNode(char *nodeName, int32 nodePort);
+extern char * ConnectionGetOptionValue(PGconn *connection, char *optionKeyword);
 
 
 #endif /* CONNECTION_CACHE_H */
diff --git a/src/include/distributed/multi_copy.h b/src/include/distributed/multi_copy.h
new file mode 100644
index 000000000..c9f06c33a
--- /dev/null
+++ b/src/include/distributed/multi_copy.h
@@ -0,0 +1,29 @@
+/*-------------------------------------------------------------------------
+ *
+ * multi_copy.h
+ *
+ * Declarations for public functions and variables used in COPY for
+ * distributed tables.
+ *
+ * Copyright (c) 2016, Citus Data, Inc.
+ *
+ *-------------------------------------------------------------------------
+ */
+
+#ifndef MULTI_COPY_H
+#define MULTI_COPY_H
+
+
+#include "libpq-fe.h"
+#include "distributed/transaction_manager.h"
+
+
+/* config variable managed via guc.c */
+extern int CopyTransactionManager;
+
+
+/* function declarations for copying into a distributed table */
+extern void CitusCopyFrom(CopyStmt *copyStatement, char* completionTag);
+
+
+#endif /* MULTI_COPY_H */
diff --git a/src/include/distributed/multi_physical_planner.h b/src/include/distributed/multi_physical_planner.h
index 3ad053b22..92bcd1354 100644
--- a/src/include/distributed/multi_physical_planner.h
+++ b/src/include/distributed/multi_physical_planner.h
@@ -238,6 +238,8 @@ extern int CompareShardPlacements(const void *leftElement, const void *rightElem
 extern ShardInterval ** SortedShardIntervalArray(List *shardList);
 extern bool ShardIntervalsOverlap(ShardInterval *firstInterval,
 								  ShardInterval *secondInterval);
+extern int CompareShardIntervals(const void *leftElement, const void *rightElement,
+								 FmgrInfo *typeCompareFunction);
 
 /* function declarations for Task and Task list operations */
 extern bool TasksEqual(const Task *a, const Task *b);
@@ -247,6 +249,7 @@ extern bool TaskListMember(const List *taskList, const Task *task);
 extern List * TaskListDifference(const List *list1, const List *list2);
 extern List * TaskListUnion(const List *list1, const List *list2);
 extern List * FirstReplicaAssignTaskList(List *taskList);
+extern int CompareTasksByShardId(const void *leftElement, const void *rightElement);
 
 
 #endif /* MULTI_PHYSICAL_PLANNER_H */
diff --git a/src/include/distributed/transaction_manager.h b/src/include/distributed/transaction_manager.h
new file mode 100644
index 000000000..d3e6834ed
--- /dev/null
+++ b/src/include/distributed/transaction_manager.h
@@ -0,0 +1,49 @@
+/*-------------------------------------------------------------------------
+ *
+ * transaction_manager.h
+ *
+ * Transaction manager API.
+ *
+ * Copyright (c) 2016, Citus Data, Inc.
+ *
+ *-------------------------------------------------------------------------
+ */
+
+#ifndef TRANSACTION_MANAGER_H
+#define TRANSACTION_MANAGER_H
+
+
+#include "libpq-fe.h"
+#include "utils/guc.h"
+
+
+/* pluggable transaction manager API; each callback returns true on success */
+typedef struct CitusTransactionManager
+{
+	bool (*Begin)(PGconn *conn);
+	bool (*Prepare)(PGconn *conn, char *transactionId);
+	bool (*CommitPrepared)(PGconn *conn, char *transactionId);
+	bool (*RollbackPrepared)(PGconn *conn, char *transactionId);
+	bool (*Rollback)(PGconn *conn);
+} CitusTransactionManager;
+
+
+/* Transaction manager to use; values index into CitusTransactionManagerImpl */
+typedef enum
+{
+	TRANSACTION_MANAGER_1PC = 0,
+	TRANSACTION_MANAGER_2PC = 1
+} TransactionManagerType;
+
+
+/* Implementations of the transaction manager API */
+extern CitusTransactionManager const CitusTransactionManagerImpl[];
+
+
+/* Function declarations for running remote commands and building transaction ids */
+extern bool ExecuteCommand(PGconn *connection, ExecStatusType expectedResult,
+						   char const *command);
+extern char * BuildTransactionId(int localId);
+
+
+#endif /* TRANSACTION_MANAGER_H */
diff --git a/src/test/regress/expected/multi_utilities.out b/src/test/regress/expected/multi_utilities.out
index 6deec6f9f..6aa4fa18e 100644
--- a/src/test/regress/expected/multi_utilities.out
+++ b/src/test/regress/expected/multi_utilities.out
@@ -18,9 +18,6 @@ SELECT master_create_worker_shards('sharded_table', 2, 1);
 COPY sharded_table TO STDOUT;
 COPY (SELECT COUNT(*) FROM sharded_table) TO STDOUT;
 0
--- but COPY in is not
-COPY sharded_table FROM STDIN;
-ERROR: cannot execute COPY FROM on a distributed table on master node
 -- cursors may not involve distributed tables
 DECLARE all_sharded_rows CURSOR FOR SELECT * FROM sharded_table;
 ERROR: DECLARE CURSOR can only be used in transaction blocks
diff --git a/src/test/regress/expected/multi_utility_statements.out b/src/test/regress/expected/multi_utility_statements.out
index ca5d7b4ad..1b502f20b 100644 --- a/src/test/regress/expected/multi_utility_statements.out +++ b/src/test/regress/expected/multi_utility_statements.out @@ -188,9 +188,6 @@ VIETNAM RUSSIA UNITED KINGDOM UNITED STATES --- Ensure that preventing COPY FROM against distributed tables works -COPY customer FROM STDIN; -ERROR: cannot execute COPY FROM on a distributed table on master node -- Test that we can create on-commit drop tables, and also test creating with -- oids, along with changing column names BEGIN; diff --git a/src/test/regress/sql/multi_utilities.sql b/src/test/regress/sql/multi_utilities.sql index a503b2c0d..d2161f246 100644 --- a/src/test/regress/sql/multi_utilities.sql +++ b/src/test/regress/sql/multi_utilities.sql @@ -10,9 +10,6 @@ SELECT master_create_worker_shards('sharded_table', 2, 1); COPY sharded_table TO STDOUT; COPY (SELECT COUNT(*) FROM sharded_table) TO STDOUT; --- but COPY in is not -COPY sharded_table FROM STDIN; - -- cursors may not involve distributed tables DECLARE all_sharded_rows CURSOR FOR SELECT * FROM sharded_table; diff --git a/src/test/regress/sql/multi_utility_statements.sql b/src/test/regress/sql/multi_utility_statements.sql index ad67f2b36..e3ebdcec6 100644 --- a/src/test/regress/sql/multi_utility_statements.sql +++ b/src/test/regress/sql/multi_utility_statements.sql @@ -104,9 +104,6 @@ COPY nation TO STDOUT; -- ensure individual cols can be copied out, too COPY nation(n_name) TO STDOUT; --- Ensure that preventing COPY FROM against distributed tables works -COPY customer FROM STDIN; - -- Test that we can create on-commit drop tables, and also test creating with -- oids, along with changing column names