diff --git a/Makefile b/Makefile
index 10841d743..9efad50b5 100644
--- a/Makefile
+++ b/Makefile
@@ -10,7 +10,7 @@ endif
 include Makefile.global
 
-all: extension
+all: extension bload
 
 # build extension
 extension:
@@ -30,6 +30,18 @@ clean-extension:
 install: install-extension install-headers
 clean: clean-extension
 
+# build bload binary
+bload:
+	$(MAKE) -C src/bin/bload/ all
+install-bload: bload
+	$(MAKE) -C src/bin/bload/ install
+clean-bload:
+	$(MAKE) -C src/bin/bload/ clean
+.PHONY: bload install-bload clean-bload
+# Add to generic targets
+install: install-bload
+clean: clean-bload
+
 # apply or check style
 reindent:
 	cd ${citus_abs_top_srcdir} && citus_indent --quiet
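As a usage sketch of the targets above (paths assume a standard source checkout): "make bload" builds the client binary under src/bin/bload/, "make install-bload" installs it, and because install and clean now depend on the bload variants, a plain "make install" or "make clean" covers the new binary as well.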
diff --git a/src/backend/distributed/commands/multi_copy.c b/src/backend/distributed/commands/multi_copy.c
index bfc73c676..6583f1315 100644
--- a/src/backend/distributed/commands/multi_copy.c
+++ b/src/backend/distributed/commands/multi_copy.c
@@ -46,33 +46,49 @@
 #include "postgres.h"
 #include "libpq-fe.h"
+#include "libpq/libpq.h"
+#include "libpq/pqformat.h"
 #include "miscadmin.h"
 
 #include <arpa/inet.h> /* for htons */
 #include <netinet/in.h> /* for htons */
+#include <errno.h>
+#include <stdio.h>
 #include <string.h>
+#include <stdlib.h>
+#include <sys/epoll.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+#include <zmq.h>
 
 #include "access/htup_details.h"
 #include "access/htup.h"
+#include "access/nbtree.h"
 #include "access/sdir.h"
 #include "catalog/namespace.h"
 #include "catalog/pg_type.h"
 #include "commands/copy.h"
 #include "commands/defrem.h"
+#include "distributed/bload.h"
+#include "distributed/colocation_utils.h"
 #include "distributed/master_protocol.h"
 #include "distributed/metadata_cache.h"
 #include "distributed/multi_copy.h"
 #include "distributed/multi_physical_planner.h"
 #include "distributed/multi_shard_transaction.h"
 #include "distributed/placement_connection.h"
+#include "distributed/pg_dist_shard.h"
 #include "distributed/remote_commands.h"
 #include "distributed/resource_lock.h"
+#include "distributed/worker_protocol.h"
 #include "executor/executor.h"
 #include "tsearch/ts_locale.h"
 #include "utils/builtins.h"
 #include "utils/lsyscache.h"
 #include "utils/rel.h"
 #include "utils/memutils.h"
+#include "utils/typcache.h"
 
 
 /* constant used in binary protocol */
@@ -81,10 +97,14 @@ static const char BinarySignature[11] = "PGCOPY\n\377\r\n\0";
 
 /* use a global connection to the master node in order to skip passing it around */
 static MultiConnection *masterConnection = NULL;
+static int MaxEvents = 64; /* up to MaxEvents returned by epoll_wait() */
+static int EpollTimeout = 100; /* wait for a maximum time of EpollTimeout ms */
+static int ZeromqPortCount = 100; /* number of available zeromq ports */
+static int ZeromqStartPort = 10240; /* start port of zeromq */
 
 /* Local functions forward declarations */
 static void CopyFromWorkerNode(CopyStmt *copyStatement, char *completionTag);
-static void CopyToExistingShards(CopyStmt *copyStatement, char *completionTag);
+static void CopyToExistingShards(CopyStmt *copyStatement, char *completionTag, Oid relationId);
 static void CopyToNewShards(CopyStmt *copyStatement, char *completionTag, Oid relationId);
 static char MasterPartitionMethod(RangeVar *relation);
 static void RemoveMasterOptions(CopyStmt *copyStatement);
@@ -103,7 +123,7 @@ static void SendCopyBinaryFooters(CopyOutState copyOutState, int64 shardId,
 								  List *connectionList);
 static StringInfo ConstructCopyStatement(CopyStmt *copyStatement, int64 shardId,
-										 bool useBinaryCopyFormat);
+										 bool useBinaryCopyFormat, bool useFreeze);
 static void SendCopyDataToAll(StringInfo dataBuffer, int64 shardId,
 							  List *connectionList);
 static void SendCopyDataToPlacement(StringInfo dataBuffer, int64 shardId,
 									MultiConnection *connection);
@@ -127,6 +147,37 @@ static void CopySendInt16(CopyOutState outputState, int16 val);
 static void CopyAttributeOutText(CopyOutState outputState, char *string);
 static inline void CopyFlushOutput(CopyOutState outputState, char *start, char *pointer);
 
+/* Functions for bulkload copy */
+static void RemoveBulkloadOptions(CopyStmt *copyStatement);
+
+static StringInfo ConstructBulkloadCopyStmt(CopyStmt *copyStatement,
+											NodeAddress *masterNodeAddress,
+											char *nodeName, uint32 nodePort);
+static void RebuildBulkloadCopyStatement(CopyStmt *copyStatement,
+										 NodeAddress *bulkloadServer);
+static StringInfo DeparseCopyStatementOptions(List *options);
+
+static NodeAddress * LocalAddress(void);
+static NodeAddress * BulkloadServerAddress(CopyStmt *copyStatement);
+
+static void BulkloadCopyToNewShards(CopyStmt *copyStatement, char *completionTag,
+									Oid relationId);
+static void BulkloadCopyToExistingShards(CopyStmt *copyStatement, char *completionTag,
+										 Oid relationId);
+static void BulkloadCopyServer(CopyStmt *copyStatement, char *completionTag,
+							   NodeAddress *masterNodeAddress, Oid relationId);
+
+static List * MasterWorkerNodeList(void);
+static List * RemoteWorkerNodeList(void);
+static DistTableCacheEntry * MasterDistributedTableCacheEntry(RangeVar *relation);
+
+static void StartZeroMQServer(ZeroMQServer *zeromqServer, bool is_program, bool binary,
+							  int natts);
+static void SendMessage(ZeroMQServer *zeromqServer, char *buf, size_t len, bool kill);
+static void StopZeroMQServer(ZeroMQServer *zeromqServer);
+
+static int CopyGetAttnums(Oid relationId, List *attnamelist);
+static PGconn * GetConnectionBySock(List *connList, int sock, int *connIdx);
+
 
 /*
  * CitusCopyFrom implements the COPY table_name FROM. It dispatches the copy
@@ -137,6 +188,7 @@ void
 CitusCopyFrom(CopyStmt *copyStatement, char *completionTag)
 {
 	bool isCopyFromWorker = false;
+	bool isBulkloadCopy = false;
 
 	BeginOrContinueCoordinatedTransaction();
 	if (MultiShardCommitProtocol == COMMIT_PROTOCOL_2PC)
@@ -166,8 +218,13 @@ CitusCopyFrom(CopyStmt *copyStatement, char *completionTag)
 	}
 
 	masterConnection = NULL; /* reset, might still be set after error */
+	isBulkloadCopy = IsBulkloadCopy(copyStatement);
 	isCopyFromWorker = IsCopyFromWorker(copyStatement);
-	if (isCopyFromWorker)
+	if (isBulkloadCopy)
+	{
+		CitusBulkloadCopy(copyStatement, completionTag);
+	}
+	else if (isCopyFromWorker)
 	{
 		CopyFromWorkerNode(copyStatement, completionTag);
 	}
@@ -179,7 +236,7 @@ CitusCopyFrom(CopyStmt *copyStatement, char *completionTag)
 		if (partitionMethod == DISTRIBUTE_BY_HASH || partitionMethod ==
 			DISTRIBUTE_BY_RANGE || partitionMethod == DISTRIBUTE_BY_NONE)
 		{
-			CopyToExistingShards(copyStatement, completionTag);
+			CopyToExistingShards(copyStatement, completionTag, InvalidOid);
 		}
 		else if (partitionMethod == DISTRIBUTE_BY_APPEND)
 		{
@@ -270,9 +327,17 @@ CopyFromWorkerNode(CopyStmt *copyStatement, char *completionTag)
 * rows.
 */
 static void
-CopyToExistingShards(CopyStmt *copyStatement, char *completionTag)
+CopyToExistingShards(CopyStmt *copyStatement, char *completionTag, Oid relationId)
 {
-	Oid tableId = RangeVarGetRelid(copyStatement->relation, NoLock, false);
+	Oid tableId = InvalidOid;
+	if (relationId != InvalidOid)
+	{
+		tableId = relationId;
+	}
+	else
+	{
+		tableId = RangeVarGetRelid(copyStatement->relation, NoLock, false);
+	}
 	char *relationName = get_rel_name(tableId);
 	Relation distributedRelation = NULL;
 	TupleDesc tupleDescriptor = NULL;
@@ -899,7 +964,7 @@ OpenCopyConnections(CopyStmt *copyStatement, ShardConnections *shardConnections,
 		ClaimConnectionExclusively(connection);
 		RemoteTransactionBeginIfNecessary(connection);
 		copyCommand = ConstructCopyStatement(copyStatement, shardConnections->shardId,
-											 useBinaryCopyFormat);
+											 useBinaryCopyFormat, false);
 		result = PQexec(connection->pgConn, copyCommand->data);
 
 		if (PQresultStatus(result) != PGRES_COPY_IN)
@@ -1025,7 +1090,8 @@ RemoteFinalizedShardPlacementList(uint64 shardId)
 				(ShardPlacement *) palloc0(sizeof(ShardPlacement));
 
 			shardPlacement->placementId = placementId;
-			shardPlacement->nodeName = nodeName;
+			shardPlacement->nodeName = (char *) palloc0(strlen(nodeName) + 1);
+			strcpy(shardPlacement->nodeName, nodeName);
 			shardPlacement->nodePort = nodePort;
 
 			finalizedPlacementList = lappend(finalizedPlacementList, shardPlacement);
@@ -1065,7 +1131,8 @@ SendCopyBinaryFooters(CopyOutState copyOutState, int64 shardId, List *connection
 * shard.
 */
static StringInfo
-ConstructCopyStatement(CopyStmt *copyStatement, int64 shardId, bool useBinaryCopyFormat)
+ConstructCopyStatement(CopyStmt *copyStatement, int64 shardId, bool useBinaryCopyFormat,
+					   bool useFreeze)
 {
 	StringInfo command = makeStringInfo();
 
@@ -1075,6 +1142,7 @@ ConstructCopyStatement(CopyStmt *copyStatement, int64 shardId, bool useBinaryCop
 	char *shardName = pstrdup(relationName);
 	char *shardQualifiedName = NULL;
 	const char *copyFormat = NULL;
+	const char *freeze = NULL;
 
 	AppendShardIdToName(&shardName, shardId);
 
@@ -1088,8 +1156,16 @@ ConstructCopyStatement(CopyStmt *copyStatement, int64 shardId, bool useBinaryCop
 	{
 		copyFormat = "TEXT";
 	}
-	appendStringInfo(command, "COPY %s FROM STDIN WITH (FORMAT %s)", shardQualifiedName,
-					 copyFormat);
+	if (useFreeze)
+	{
+		freeze = "TRUE";
+	}
+	else
+	{
+		freeze = "FALSE";
+	}
+	appendStringInfo(command, "COPY %s FROM STDIN WITH (FORMAT %s, FREEZE %s)",
+					 shardQualifiedName, copyFormat, freeze);
 
 	return command;
 }
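For context, the per-placement command built by ConstructCopyStatement() now always spells out the FREEZE flag. For a hypothetical shard 102008 of public.orders, the existing-shards path (useFreeze = false) produces:

    COPY public.orders_102008 FROM STDIN WITH (FORMAT BINARY, FREEZE FALSE)

while FREEZE TRUE is reserved for callers that create their shards within the same transaction, as the bulkload comment below describes.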
@@ -1694,3 +1770,1512 @@ CopyFlushOutput(CopyOutState cstate, char *start, char *pointer)
 		CopySendData(cstate, start, pointer - start);
 	}
 }
+
+/*
+ * CitusBulkloadCopy implements COPY table_name FROM xxx WITH (method 'bulkload').
+ * For the bulkload server, it dispatches the copy statement to all workers,
+ * streams the records from the FROM clause to them, and waits for them to finish.
+ * Bulkload clients pull records from the server and copy them into shards. A
+ * bulkload client handles append- and hash-distributed tables differently.
+ * For APPEND distributed tables, there are two copy policies:
+ * 1. the bulkload client creates a shard for each tablespace and inserts records
+ *    into these shards in a round-robin fashion; when a shard reaches
+ *    ShardMaxSize, the client creates a new shard in that tablespace, and so on.
+ *    In this policy, since the DDL commands that create shards and the DML
+ *    commands that copy run in one transaction, we can use COPY FREEZE and lazy
+ *    indexing to improve ingestion performance.
+ * 2. each bulkload client acts just like CopyToNewShards(): it calls the master
+ *    to create a new shard and inserts records into it; when that shard reaches
+ *    ShardMaxSize, it asks the master to create another new shard, and so on.
+ * For HASH distributed tables, clients get the metadata of the table from the
+ * master node and send records to the different shards according to the hash
+ * value.
+ */
+void
+CitusBulkloadCopy(CopyStmt *copyStatement, char *completionTag)
+{
+	bool isCopyFromWorker = false;
+	bool isBulkloadClient = false;
+	NodeAddress *masterNodeAddress = NULL;
+	Oid relationId = InvalidOid;
+	char partitionMethod = 0;
+	char *nodeName = NULL;
+	uint32 nodePort = 0;
+	char *nodeUser = NULL;
+	char *schemaName = NULL;
+
+	/*
+	 * From postgres/src/bin/psql/copy.c:handleCopyIn() we know that a pq_message
+	 * contains exactly one record in csv and text format, but not in binary
+	 * format. Since it is hard for StartZeroMQServer() to handle pq_messages
+	 * that may contain incomplete records, we currently don't support COPY FROM
+	 * STDIN with binary format for bulkload copy.
+	 */
+	if (copyStatement->filename == NULL && IsBinaryCopy(copyStatement))
+	{
+		elog(ERROR, "bulkload doesn't support copy from stdin with binary format");
+	}
+
+	isCopyFromWorker = IsCopyFromWorker(copyStatement);
+	if (isCopyFromWorker)
+	{
+		masterNodeAddress = MasterNodeAddress(copyStatement);
+		nodeName = masterNodeAddress->nodeName;
+		nodePort = masterNodeAddress->nodePort;
+		nodeUser = CurrentUserName();
+		masterConnection = GetNodeConnection(FORCE_NEW_CONNECTION, nodeName, nodePort);
+		if (masterConnection == NULL)
+		{
+			elog(ERROR, "Can't connect to master server %s:%d as user %s",
+				 nodeName, nodePort, nodeUser);
+		}
+		RemoveMasterOptions(copyStatement);
+
+		/* strip schema name for local reference */
+		schemaName = copyStatement->relation->schemaname;
+		copyStatement->relation->schemaname = NULL;
+		relationId = RangeVarGetRelid(copyStatement->relation, NoLock, false);
+		/* put schema name back */
+		copyStatement->relation->schemaname = schemaName;
+		partitionMethod = MasterPartitionMethod(copyStatement->relation);
+	}
+	else
+	{
+		masterNodeAddress = LocalAddress();
+		relationId = RangeVarGetRelid(copyStatement->relation, NoLock, false);
+		partitionMethod = PartitionMethod(relationId);
+	}
+
+	isBulkloadClient = IsBulkloadClient(copyStatement);
+	PG_TRY();
+	{
+		if (isBulkloadClient)
+		{
+			if (partitionMethod == DISTRIBUTE_BY_APPEND
+				|| partitionMethod == DISTRIBUTE_BY_RANGE)
+			{
+				PGresult *queryResult = NULL;
+				Assert(masterConnection != NULL);
+				/* run all metadata commands in a transaction */
+				queryResult = PQexec(masterConnection->pgConn, "BEGIN");
+				if (PQresultStatus(queryResult) != PGRES_COMMAND_OK)
+				{
+					elog(ERROR, "could not start to update master node metadata");
+				}
+				PQclear(queryResult);
+
+				/* there are two policies for copying into new shards */
+				// BulkloadCopyToNewShardsV1(copyStatement, completionTag,
+				//						   masterNodeAddress, relationId);
+				BulkloadCopyToNewShards(copyStatement, completionTag, relationId);
+
+				/* commit metadata transactions */
+				queryResult = PQexec(masterConnection->pgConn, "COMMIT");
+				if (PQresultStatus(queryResult) != PGRES_COMMAND_OK)
+				{
+					elog(ERROR, "could not commit master node metadata changes");
+				}
+				PQclear(queryResult);
+			}
+			else if (partitionMethod == DISTRIBUTE_BY_HASH)
+			{
+				BulkloadCopyToExistingShards(copyStatement, completionTag, relationId);
+			}
+			else
+			{
+				elog(ERROR, "Unknown partition method: %d", partitionMethod);
+			}
+		}
+		else
+		{
+			BulkloadCopyServer(copyStatement, completionTag, masterNodeAddress,
+							   relationId);
+		}
+	}
+	PG_CATCH();
+	{
+		PG_RE_THROW();
+	}
+	PG_END_TRY();
+}
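To make the dispatch concrete, suppose a server node receives (names and paths below are invented for illustration):

    COPY public.orders FROM '/data/orders.csv' WITH (method 'bulkload', format csv);

ConstructBulkloadCopyStmt() then sends every worker a client statement of roughly this shape:

    COPY public.orders FROM PROGRAM '/usr/local/pgsql/bin/bload'
        WITH (master_host 'master-host', master_port 5432, method 'bulkload',
              bulkload_host 'server-host', bulkload_port 10250, format csv);

and on each client, RebuildBulkloadCopyStatement() later appends the server address (and 'binary' when applicable) to the PROGRAM string, e.g. '/usr/local/pgsql/bin/bload server-host 10250'.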
+
+/*
+ * CopyGetAttnums - get the number of non-dropped columns of a relation.
+ * Adapted from postgresql/src/backend/commands/copy.c.
+ */
+static int
+CopyGetAttnums(Oid relationId, List *attnamelist)
+{
+	int attnums = list_length(attnamelist);
+	if (attnums != 0)
+	{
+		return attnums;
+	}
+	else
+	{
+		Relation rel = heap_open(relationId, AccessShareLock);
+		TupleDesc tupDesc = RelationGetDescr(rel);
+		Form_pg_attribute *attr = tupDesc->attrs;
+		int attr_count = tupDesc->natts;
+		int i;
+		for (i = 0; i < attr_count; i++)
+		{
+			if (attr[i]->attisdropped)
+				continue;
+			attnums++;
+		}
+		heap_close(rel, NoLock);
+		return attnums;
+	}
+}
+
+/*
+ * BulkloadCopyServer rebuilds the COPY statement with the 'bulkload_host' and
+ * 'bulkload_port' options and dispatches it to all worker nodes for
+ * asynchronous execution. It also starts a zeromq server that dispatches
+ * records from the FROM clause to all worker nodes.
+ */
+static void
+BulkloadCopyServer(CopyStmt *copyStatement, char *completionTag,
+				   NodeAddress *masterNodeAddress, Oid relationId)
+{
+	List *workerNodeList = NULL;
+	ListCell *workerCell = NULL;
+	List *workerConnectionList = NIL;
+	NodeAddress *serverAddress = NULL;
+	StringInfo clientCopyCommand = NULL;
+	struct ZeroMQServer *zeromqServer = NULL;
+	uint64 processedRowCount = 0;
+	int loopIndex;
+	char *nodeName = NULL;
+	uint32 nodePort = 0;
+	WorkerNode *workerNode = NULL;
+	MultiConnection *multiConn = NULL;
+	PGconn *conn = NULL;
+	PGresult *res = NULL;
+	int workerConnectionCount = 0;
+	int finishCount = 0;
+	int failCount = 0;
+	int *finish = NULL;
+	int rc;
+	int efd;
+	int nevents;
+	int sock;
+	int connIdx;
+	struct epoll_event *event = NULL;
+	struct epoll_event *events = NULL;
+
+	workerNodeList = MasterWorkerNodeList();
+	serverAddress = LocalAddress();
+
+	zeromqServer = (struct ZeroMQServer *) palloc0(sizeof(ZeroMQServer));
+	strcpy(zeromqServer->host, serverAddress->nodeName);
+	/*
+	 * use a port number between ZeromqStartPort and ZeromqStartPort + ZeromqPortCount
+	 * as the zeromq server port
+	 */
+	zeromqServer->port = random() % ZeromqPortCount + ZeromqStartPort;
+
+	if (copyStatement->filename != NULL)
+	{
+		strcpy(zeromqServer->file, copyStatement->filename);
+	}
+
+	clientCopyCommand = ConstructBulkloadCopyStmt(copyStatement, masterNodeAddress,
+												  serverAddress->nodeName,
+												  zeromqServer->port);
+
+	events = (struct epoll_event *) palloc0(MaxEvents * sizeof(struct epoll_event));
+	efd = epoll_create1(0);
+	if (efd == -1)
+	{
+		elog(ERROR, "epoll_create failed");
+	}
+
+	foreach(workerCell, workerNodeList)
+	{
+		workerNode = (WorkerNode *) lfirst(workerCell);
+		nodeName = workerNode->workerName;
+		nodePort = workerNode->workerPort;
+		multiConn = GetNodeConnection(FOR_DML, nodeName, nodePort);
+		conn = multiConn->pgConn;
+		if (conn == NULL)
+		{
+			elog(WARNING, "connect to %s:%d failed", nodeName, nodePort);
+		}
+		else
+		{
+			int querySent = PQsendQuery(conn, clientCopyCommand->data);
+			if (querySent == 0)
+			{
+				elog(WARNING, "send bulkload copy to %s:%d failed: %s", nodeName,
+					 nodePort, PQerrorMessage(conn));
+			}
+			else
+			{
+				if (PQsetnonblocking(conn, 1) == -1)
+				{
+					/*
+					 * make sure it wouldn't cause a fatal error even in blocking mode
+					 */
+					elog(WARNING, "%s:%d set non-blocking failed", nodeName, nodePort);
+				}
+				sock = PQsocket(conn);
+				if (sock < 0)
+				{
+					elog(WARNING, "%s:%d get socket failed", nodeName, nodePort);
+				}
+				else
+				{
+					event = (struct epoll_event *) palloc0(sizeof(struct epoll_event));
+					event->events = EPOLLIN | EPOLLERR | EPOLLET;
+					event->data.fd = sock;
+					if (epoll_ctl(efd, EPOLL_CTL_ADD, sock, event) != 0)
+					{
+						elog(WARNING, "epoll_ctl add socket of %s:%d failed",
+							 nodeName, nodePort);
+					}
+					else
+					{
+						/*
+						 * finally, append connections whose query has been sent
+						 * and whose socket file descriptor was obtained
+						 * successfully to the connection list.
+						 */
+						workerConnectionList = lappend(workerConnectionList, conn);
+					}
+				}
+			}
+		}
+	}
+	workerConnectionCount = list_length(workerConnectionList);
+	if (workerConnectionCount == 0)
+	{
+		elog(ERROR, "Can't send bulkload copy to any worker");
+	}
+
+	/*
+	 * array representing the status of a worker connection:
+	 * -1: worker failed
+	 *  0: worker still running
+	 *  1: worker succeeded
+	 */
+	finish = (int *) palloc0(workerConnectionCount * sizeof(int));
+
+	PG_TRY();
+	{
+		int natts = CopyGetAttnums(relationId, copyStatement->attlist);
+		/*
+		 * check the status of the workers before starting the zeromq server, in
+		 * case the bulkload copy command is improper.
+		 * TODO@luoyuanhao: if an error occurs after EpollTimeout, there's still
+		 * a possibility of deadlock; refactor this code properly.
+		 */
+		do
+		{
+			nevents = epoll_wait(efd, events, MaxEvents, EpollTimeout * 10);
+			if (nevents == -1)
+			{
+				elog(ERROR, "epoll_wait error(%d): %s", errno, strerror(errno));
+			}
+			for (loopIndex = 0; loopIndex < nevents; loopIndex++)
+			{
+				conn = GetConnectionBySock(workerConnectionList,
+										   events[loopIndex].data.fd, &connIdx);
+				Assert(conn != NULL);
+				if (finish[connIdx] != 0)
+					continue;
+				/*
+				 * if the bulkload copy command is okay, there should be neither
+				 * output nor an error message on the socket; otherwise the
+				 * bulkload copy command is wrong.
+				 */
+				elog(WARNING, "bulkload copy on %s:%s failed, read the log file "
+					 "to get the error message", PQhost(conn), PQport(conn));
+				finish[connIdx] = -1;
+				finishCount++;
+			}
+		} while (nevents != 0);
+		if (finishCount == workerConnectionCount)
+		{
+			elog(ERROR, "bulkload copy commands failed on all workers");
+		}
+
+		StartZeroMQServer(zeromqServer, copyStatement->is_program,
+						  IsBinaryCopy(copyStatement), natts);
+
+		while (finishCount < workerConnectionCount)
+		{
+			CHECK_FOR_INTERRUPTS();
+
+			/* send EOF message */
+			SendMessage(zeromqServer, "KILL", 4, true);
+
+			/*
+			 * Waiting indefinitely may cause a deadlock: we send a 'KILL'
+			 * signal, but bload may not catch it and may wait indefinitely, so
+			 * the COPY command wouldn't finish and therefore there would be no
+			 * responses (events) from the pg connection.
+			 */
+			nevents = epoll_wait(efd, events, MaxEvents, EpollTimeout);
+			if (nevents == -1)
+			{
+				elog(ERROR, "epoll_wait error(%d): %s", errno, strerror(errno));
+			}
+			for (loopIndex = 0; loopIndex < nevents; loopIndex++)
+			{
+				conn = GetConnectionBySock(workerConnectionList,
+										   events[loopIndex].data.fd, &connIdx);
+				Assert(conn != NULL);
+				if (finish[connIdx] != 0)
+					continue;
+				if (events[loopIndex].events & EPOLLERR)
+				{
+					elog(WARNING, "socket of %s:%s error", PQhost(conn), PQport(conn));
+					finish[connIdx] = -1;
+					finishCount++;
+					continue;
+				}
+				if (events[loopIndex].events & EPOLLIN)
+				{
+					rc = PQconsumeInput(conn);
+					if (rc == 0)
+					{
+						elog(WARNING, "%s:%s error:%s", PQhost(conn), PQport(conn),
+							 PQerrorMessage(conn));
+						finish[connIdx] = -1;
+						finishCount++;
+					}
+					else
+					{
+						if (!PQisBusy(conn))
+						{
+							res = PQgetResult(conn);
+							if (res == NULL)
+							{
+								finish[connIdx] = 1;
+							}
+							else
+							{
+								if (PQresultStatus(res) != PGRES_COMMAND_OK &&
+									PQresultStatus(res) != PGRES_TUPLES_OK)
+								{
+									elog(WARNING, "%s:%s error:%s", PQhost(conn),
+										 PQport(conn), PQresultErrorMessage(res));
+									finish[connIdx] = -1;
+								}
+								else
+								{
+									processedRowCount += atol(PQcmdTuples(res));
+									finish[connIdx] = 1;
+								}
+								PQclear(res);
+							}
+							finishCount++;
+						}
+					}
+				}
+			}
+		}
+	}
+	PG_CATCH();
+	{
+		char *errbuf = (char *) palloc0(NAMEDATALEN);
+		foreach(workerCell, workerConnectionList)
+		{
+			PGconn *conn = (PGconn *) lfirst(workerCell);
+			PGcancel *cancel = PQgetCancel(conn);
+			if (cancel != NULL && PQcancel(cancel, errbuf, NAMEDATALEN) != 1)
+			{
+				elog(WARNING, "%s", errbuf);
+			}
+			PQfreeCancel(cancel);
+		}
+		StopZeroMQServer(zeromqServer);
+
+		PG_RE_THROW();
+	}
+	PG_END_TRY();
+
+	for (loopIndex = 0; loopIndex < workerConnectionCount; loopIndex++)
+	{
+		if (finish[loopIndex] == -1)
+		{
+			failCount++;
+		}
+	}
+	/*
+	 * TODO@luoyuanhao: two-phase commit; if failCount > 0, roll back.
+	 */
+	StopZeroMQServer(zeromqServer);
+	if (completionTag != NULL)
+	{
+		snprintf(completionTag, COMPLETION_TAG_BUFSIZE,
+				 "COPY " UINT64_FORMAT, processedRowCount);
+	}
+}
+
+/*
+ * IsBulkloadCopy checks if the given copy statement has the 'method' option
+ * and the value is 'bulkload'.
+ */
+bool
+IsBulkloadCopy(CopyStmt *copyStatement)
+{
+	ListCell *optionCell = NULL;
+	DefElem *defel = NULL;
+	foreach(optionCell, copyStatement->options)
+	{
+		defel = (DefElem *) lfirst(optionCell);
+		if (strcasecmp(defel->defname, "method") == 0)
+		{
+			char *method = defGetString(defel);
+			if (strcasecmp(method, "bulkload") != 0)
+			{
+				elog(ERROR, "Unsupported method: %s. Valid value: 'bulkload'", method);
+			}
+			else
+			{
+				return true;
+			}
+		}
+	}
+	return false;
+}
+
+/*
+ * IsBinaryCopy checks if the given copy statement has the 'format' option
+ * and the value is 'binary'.
+ */
+bool
+IsBinaryCopy(CopyStmt *copyStatement)
+{
+	ListCell *optionCell = NULL;
+	DefElem *defel = NULL;
+	foreach(optionCell, copyStatement->options)
+	{
+		defel = (DefElem *) lfirst(optionCell);
+		if (strcasecmp(defel->defname, "format") == 0)
+		{
+			char *method = defGetString(defel);
+			if (strcasecmp(method, "binary") == 0)
+			{
+				return true;
+			}
+		}
+	}
+	return false;
+}
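The server loop above leans on a standard pairing of epoll with libpq's asynchronous API. The following self-contained sketch shows that pairing in isolation; the DSN, the single connection, and the single-shot epoll_wait() are illustrative assumptions (the patch keeps polling until every worker's PQgetResult() returns NULL):

    #include <stdio.h>
    #include <sys/epoll.h>
    #include <libpq-fe.h>

    /* drain whatever result data is currently readable on one connection */
    static void
    DrainConnection(PGconn *conn)
    {
        if (PQconsumeInput(conn) == 0)
        {
            fprintf(stderr, "error: %s", PQerrorMessage(conn));
            return;
        }
        while (!PQisBusy(conn))
        {
            PGresult *res = PQgetResult(conn);
            if (res == NULL)
            {
                return; /* query fully processed */
            }
            PQclear(res); /* a real caller would inspect PQresultStatus() here */
        }
        /* PQisBusy() returned true: more input is needed, poll again */
    }

    int
    main(void)
    {
        PGconn *conn = PQconnectdb("dbname=postgres"); /* hypothetical DSN */
        struct epoll_event ev = { 0 };
        struct epoll_event events[8];
        int efd = epoll_create1(0);
        int nevents;
        int i;

        if (conn == NULL || PQstatus(conn) != CONNECTION_OK || efd == -1)
        {
            return 1;
        }
        PQsendQuery(conn, "SELECT 1"); /* asynchronous: returns immediately */
        PQsetnonblocking(conn, 1);

        /* register the connection's socket, as BulkloadCopyServer does */
        ev.events = EPOLLIN;
        ev.data.fd = PQsocket(conn);
        epoll_ctl(efd, EPOLL_CTL_ADD, PQsocket(conn), &ev);

        nevents = epoll_wait(efd, events, 8, -1);
        for (i = 0; i < nevents; i++)
        {
            DrainConnection(conn);
        }
        PQfinish(conn);
        return 0;
    }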
+
+/*
+ * IsBulkloadClient checks if the given copy statement has the 'bulkload_host'
+ * option.
+ */
+bool
+IsBulkloadClient(CopyStmt *copyStatement)
+{
+	ListCell *optionCell = NULL;
+	DefElem *defel = NULL;
+	foreach(optionCell, copyStatement->options)
+	{
+		defel = (DefElem *) lfirst(optionCell);
+		if (strcasecmp(defel->defname, "bulkload_host") == 0)
+		{
+			return true;
+		}
+	}
+	return false;
+}
+
+/*
+ * RemoveBulkloadOptions removes bulkload related copy options from the option
+ * list of the copy statement.
+ */
+static void
+RemoveBulkloadOptions(CopyStmt *copyStatement)
+{
+	List *newOptionList = NIL;
+	ListCell *optionCell = NULL;
+
+	/* walk over the list of all options */
+	foreach(optionCell, copyStatement->options)
+	{
+		DefElem *option = (DefElem *) lfirst(optionCell);
+
+		/* skip bulkload related options */
+		if ((strcmp(option->defname, "bulkload_host") == 0) ||
+			(strcmp(option->defname, "bulkload_port") == 0) ||
+			(strcmp(option->defname, "method") == 0))
+		{
+			continue;
+		}
+
+		newOptionList = lappend(newOptionList, option);
+	}
+
+	copyStatement->options = newOptionList;
+}
+
+/*
+ * BulkloadServerAddress gets the bulkload zeromq server address from the copy
+ * options and returns it. Note that if bulkload_port is not provided, we use
+ * 5557 as the default port.
+ */
+static NodeAddress *
+BulkloadServerAddress(CopyStmt *copyStatement)
+{
+	NodeAddress *bulkloadServer = (NodeAddress *) palloc0(sizeof(NodeAddress));
+	char *nodeName = NULL;
+
+	/* set default port to 5557 */
+	uint32 nodePort = 5557;
+
+	ListCell *optionCell = NULL;
+	foreach(optionCell, copyStatement->options)
+	{
+		DefElem *defel = (DefElem *) lfirst(optionCell);
+		if (strncmp(defel->defname, "bulkload_host", NAMEDATALEN) == 0)
+		{
+			nodeName = defGetString(defel);
+		}
+		else if (strncmp(defel->defname, "bulkload_port", NAMEDATALEN) == 0)
+		{
+			nodePort = defGetInt32(defel);
+		}
+	}
+
+	bulkloadServer->nodeName = nodeName;
+	bulkloadServer->nodePort = nodePort;
+	return bulkloadServer;
+}
+
+/*
+ * ConstructBulkloadCopyStmt constructs the text of a bulkload COPY statement
+ * to be executed on the bulkload copy clients.
+ */
+static StringInfo
+ConstructBulkloadCopyStmt(CopyStmt *copyStatement, NodeAddress *masterNodeAddress,
+						  char *nodeName, uint32 nodePort)
+{
+	char *schemaName = copyStatement->relation->schemaname;
+	char *relationName = copyStatement->relation->relname;
+	char *qualifiedName = quote_qualified_identifier(schemaName, relationName);
+	List *attlist = copyStatement->attlist;
+	ListCell *lc = NULL;
+	char *binaryPath = NULL;
+	StringInfo optionsString = NULL;
+	StringInfo command = NULL;
+	int res;
+	bool isfirst = true;
+
+	RemoveBulkloadOptions(copyStatement);
+
+	binaryPath = (char *) palloc0(NAMEDATALEN);
+	res = readlink("/proc/self/exe", binaryPath, NAMEDATALEN);
+	if (res == -1)
+	{
+		elog(ERROR, "Can't get absolute path of PG_HOME");
+	}
+	else
+	{
+		/*
+		 * the original string is "/path_of_pg_home/bin/postgres"; after cutting
+		 * off "postgres" it becomes "/path_of_pg_home/bin/"
+		 */
+		binaryPath[res - 8] = '\0';
+		/* append 'bload' */
+		strcat(binaryPath, "bload");
+	}
+
+	optionsString = DeparseCopyStatementOptions(copyStatement->options);
+	command = makeStringInfo();
+	appendStringInfo(command, "COPY %s", qualifiedName);
+	if (list_length(attlist) != 0)
+	{
+		appendStringInfoChar(command, '(');
+		foreach(lc, attlist)
+		{
+			if (isfirst)
+			{
+				isfirst = false;
+			}
+			else
+			{
+				appendStringInfoString(command, ", ");
+			}
+			appendStringInfoString(command, strVal(lfirst(lc)));
+		}
+		appendStringInfoChar(command, ')');
+	}
+	appendStringInfo(command, " FROM PROGRAM '%s' WITH (master_host '%s', "
+					 "master_port %d, method 'bulkload', bulkload_host '%s', "
+					 "bulkload_port %d",
+					 binaryPath,
+					 masterNodeAddress->nodeName,
+					 masterNodeAddress->nodePort,
+					 nodeName,
+					 nodePort);
+	if (strlen(optionsString->data) != 0)
+	{
+		appendStringInfo(command, ", %s)", optionsString->data);
+	}
+	else
+	{
+		appendStringInfoChar(command, ')');
+	}
+	return command;
+}
+
+/*
+ * DeparseCopyStatementOptions constructs the option text for the WITH clause
+ * of the COPY statement.
+ */
+static StringInfo
+DeparseCopyStatementOptions(List *options)
+{
+	StringInfo optionsStr = makeStringInfo();
+	ListCell *option;
+	bool isfirst = true;
+	DefElem *defel = NULL;
+	foreach(option, options)
+	{
+		if (isfirst)
+			isfirst = false;
+		else
+			appendStringInfoString(optionsStr, ", ");
+
+		defel = (DefElem *) lfirst(option);
+
+		if (strcmp(defel->defname, "format") == 0)
+		{
+			appendStringInfo(optionsStr, "format %s", defGetString(defel));
+		}
+		else if (strcmp(defel->defname, "oids") == 0)
+		{
+			appendStringInfo(optionsStr, "oids %s",
+							 defGetBoolean(defel) ? "true" : "false");
+		}
+		else if (strcmp(defel->defname, "freeze") == 0)
+		{
+			appendStringInfo(optionsStr, "freeze %s",
+							 defGetBoolean(defel) ? "true" : "false");
+		}
+		else if (strcmp(defel->defname, "delimiter") == 0)
+		{
+			appendStringInfo(optionsStr, "delimiter '%s'", defGetString(defel));
+		}
+		else if (strcmp(defel->defname, "null") == 0)
+		{
+			appendStringInfo(optionsStr, "null '%s'", defGetString(defel));
+		}
"true" : "false"); + } + else if (strcmp(defel->defname, "quote") == 0) + { + appendStringInfo(optionsStr, "quote '%s'", defGetString(defel)); + } + else if (strcmp(defel->defname, "escape") == 0) + { + if (strcmp(defGetString(defel), "\\") == 0) + { + appendStringInfo(optionsStr, "quote '\\%s'", defGetString(defel)); + } + else + { + appendStringInfo(optionsStr, "quote '%s'", defGetString(defel)); + } + } + /* unhandle force_quote/force_not_null/force_null and convert_selectively*/ + //else if (strcmp(defel->defname, "force_quote") == 0) + //{ + // if (cstate->force_quote || cstate->force_quote_all) + // ereport(ERROR, + // (errcode(ERRCODE_SYNTAX_ERROR), + // errmsg("conflicting or redundant options"))); + // if (defel->arg && IsA(defel->arg, A_Star)) + // cstate->force_quote_all = true; + // else if (defel->arg && IsA(defel->arg, List)) + // cstate->force_quote = (List *) defel->arg; + // else + // ereport(ERROR, + // (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + // errmsg("argument to option \"%s\" must be a list of column names", + // defel->defname))); + //} + //else if (strcmp(defel->defname, "force_not_null") == 0) + //{ + // if (cstate->force_notnull) + // ereport(ERROR, + // (errcode(ERRCODE_SYNTAX_ERROR), + // errmsg("conflicting or redundant options"))); + // if (defel->arg && IsA(defel->arg, List)) + // cstate->force_notnull = (List *) defel->arg; + // else + // ereport(ERROR, + // (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + // errmsg("argument to option \"%s\" must be a list of column names", + // defel->defname))); + //} + //else if (strcmp(defel->defname, "force_null") == 0) + //{ + // if (cstate->force_null) + // ereport(ERROR, + // (errcode(ERRCODE_SYNTAX_ERROR), + // errmsg("conflicting or redundant options"))); + // if (defel->arg && IsA(defel->arg, List)) + // cstate->force_null = (List *) defel->arg; + // else + // ereport(ERROR, + // (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + // errmsg("argument to option \"%s\" must be a list of column names", + // defel->defname))); + //} + //else if (strcmp(defel->defname, "convert_selectively") == 0) + //{ + // /* + // * Undocumented, not-accessible-from-SQL option: convert only the + // * named columns to binary form, storing the rest as NULLs. It's + // * allowed for the column list to be NIL. + // */ + // if (cstate->convert_selectively) + // ereport(ERROR, + // (errcode(ERRCODE_SYNTAX_ERROR), + // errmsg("conflicting or redundant options"))); + // cstate->convert_selectively = true; + // if (defel->arg == NULL || IsA(defel->arg, List)) + // cstate->convert_select = (List *) defel->arg; + // else + // ereport(ERROR, + // (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + // errmsg("argument to option \"%s\" must be a list of column names", + // defel->defname))); + //} + //else if (strcmp(defel->defname, "encoding") == 0) + //{ + // appendStringInfo(optionsStr, "encoding '%s'", defGetString(defel)); + //} + else + ereport(ERROR, + (errcode(ERRCODE_SYNTAX_ERROR), + errmsg("option \"%s\" not recognized", + defel->defname))); + } + return optionsStr; +} +/* + * BulkloadCopyToNewShards executes the bulkload COPY command sent by bulkload server + * for APPEND distributed table. + * It acts just like CopyToNewShards() but records are received from zeromq server. 
+
+/*
+ * BulkloadCopyToNewShards executes the bulkload COPY command sent by the
+ * bulkload server for APPEND distributed tables.
+ * It acts just like CopyToNewShards() except that records are received from
+ * the zeromq server.
+ */
+static void
+BulkloadCopyToNewShards(CopyStmt *copyStatement, char *completionTag, Oid relationId)
+{
+	NodeAddress *bulkloadServer = BulkloadServerAddress(copyStatement);
+	RemoveBulkloadOptions(copyStatement);
+	RebuildBulkloadCopyStatement(copyStatement, bulkloadServer);
+	CopyToNewShards(copyStatement, completionTag, relationId);
+}
+
+/*
+ * LocalAddress gets the host and port of the currently running postgres.
+ */
+static NodeAddress *
+LocalAddress(void)
+{
+	NodeAddress *node = (NodeAddress *) palloc0(sizeof(NodeAddress));
+	char *host = (char *) palloc0(32);
+	const char *portStr = GetConfigOption("port", true, false);
+	int rc = gethostname(host, 32);
+	if (rc != 0)
+	{
+		strcpy(host, "localhost");
+		elog(WARNING, "gethostname failed: %s, using 'localhost'", strerror(errno));
+	}
+	node->nodeName = host;
+	if (portStr == NULL)
+	{
+		node->nodePort = 5432;
+	}
+	else
+	{
+		node->nodePort = atoi(portStr);
+	}
+	return node;
+}
+
+/*
+ * MasterWorkerNodeList runs master_get_active_worker_nodes either locally or
+ * on the remote master node, depending on the master connection state.
+ */
+static List *
+MasterWorkerNodeList(void)
+{
+	List *workerNodeList = NIL;
+	if (masterConnection == NULL)
+	{
+		workerNodeList = WorkerNodeList();
+	}
+	else
+	{
+		workerNodeList = RemoteWorkerNodeList();
+	}
+
+	return workerNodeList;
+}
+
+/*
+ * RemoteWorkerNodeList gets the active worker node list from the remote
+ * master node.
+ */
+static List *
+RemoteWorkerNodeList(void)
+{
+	List *workerNodeList = NIL;
+	PGresult *queryResult = NULL;
+
+	StringInfo workerNodeCommand = makeStringInfo();
+	appendStringInfoString(workerNodeCommand, ACTIVE_WORKER_NODE_QUERY);
+
+	queryResult = PQexec(masterConnection->pgConn, workerNodeCommand->data);
+	if (PQresultStatus(queryResult) == PGRES_TUPLES_OK)
+	{
+		int rowCount = PQntuples(queryResult);
+		int rowIndex = 0;
+
+		for (rowIndex = 0; rowIndex < rowCount; rowIndex++)
+		{
+			WorkerNode *workerNode =
+				(WorkerNode *) palloc0(sizeof(WorkerNode));
+
+			char *host = PQgetvalue(queryResult, rowIndex, 0);
+			char *port = PQgetvalue(queryResult, rowIndex, 1);
+			strcpy(workerNode->workerName, host);
+			workerNode->workerPort = atoi(port);
+
+			workerNodeList = lappend(workerNodeList, workerNode);
+		}
+	}
+	else
+	{
+		elog(ERROR, "could not get active worker node list from the master node: %s",
+			 PQresultErrorMessage(queryResult));
+	}
+	PQclear(queryResult);
+
+	return workerNodeList;
+}
+
+/*
+ * StartZeroMQServer starts a zeromq socket, reads data from a file, the remote
+ * frontend, or the output of a program, and then sends it to the zeromq clients.
+ * TODO@luoyuanhao: currently we don't support bulkload copy from stdin with
+ * binary format.
+ */
+static void
+StartZeroMQServer(ZeroMQServer *zeromqServer, bool is_program, bool binary, int natts)
+{
+	uint64_t start = 0, read = 0;
+	FILE *fp = NULL;
+	char *buf = NULL;
+	char zeroaddr[32];
+	void *context = NULL;
+	void *sender = NULL;
+	void *controller = NULL;
+	char *file = zeromqServer->file;
+	bool pipe = (strlen(file) == 0);
+	StringInfoData msgbuf;
+	int16 format = binary ?
1 : 0; + int loopIdx; + bool copyDone = false; + + context = zmq_ctx_new (); + Assert(context != NULL); + // Socket to send messages on + sender = zmq_socket(context, ZMQ_PUSH); + if (sender == NULL) + { + elog(ERROR, "zmq_socket() error(%d): %s", errno, zmq_strerror(errno)); + } + // Socket for control signal + controller = zmq_socket(context, ZMQ_PUB); + if (controller == NULL) + { + elog(ERROR, "zmq_socket() error(%d): %s", errno, zmq_strerror(errno)); + } + + zeromqServer->context = context; + zeromqServer->sender = sender; + zeromqServer->controller = controller; + + sprintf(zeroaddr, "tcp://*:%d", zeromqServer->port); + if (zmq_bind (sender, zeroaddr) != 0) + { + elog(ERROR, "zmq_bind() error(%d): %s", errno, zmq_strerror(errno)); + } + sprintf(zeroaddr, "tcp://*:%d", zeromqServer->port + 1); + if (zmq_bind (controller, zeroaddr) != 0) + { + elog(ERROR, "zmq_bind() error(%d): %s", errno, zmq_strerror(errno)); + } + + if (pipe) + { + /* + * inspired by ReceivedCopyBegin() + */ + Assert(!binary); + pq_beginmessage(&msgbuf, 'G'); + pq_sendbyte(&msgbuf, format); /* overall format */ + pq_sendint(&msgbuf, natts, 2); + for (loopIdx = 0; loopIdx < natts; loopIdx++) + pq_sendint(&msgbuf, format, 2); /* per-column formats */ + pq_endmessage(&msgbuf); + pq_flush(); + + initStringInfo(&msgbuf); + /* get records from fe */ + while (!copyDone) + { + int mtype; + CHECK_FOR_INTERRUPTS(); + HOLD_CANCEL_INTERRUPTS(); + /* + * inspired by CopyGetData() + */ + pq_startmsgread(); + mtype = pq_getbyte(); + if (mtype == EOF) + elog(ERROR, "unexpected EOF on client connection with an open transaction"); + if (pq_getmessage(&msgbuf, 0)) + elog(ERROR, "unexpected EOF on client connection with an open transaction"); + RESUME_CANCEL_INTERRUPTS(); + switch (mtype) + { + case 'd': /* CopyData */ + SendMessage(zeromqServer, msgbuf.data, msgbuf.len, false); + break; + case 'c': /* CopyDone */ + /* COPY IN correctly terminated by frontend */ + copyDone = true; + break; + case 'f': /* CopyFail */ + elog(ERROR, "COPY from stdin failed: %s", pq_getmsgstring(&msgbuf)); + break; + case 'H': /* Flush */ + case 'S': /* Sync */ + break; + default: + elog(ERROR, "unexpected message type 0x%02X during COPY from stdin", mtype); + break; + } + } + return; + } + + Assert(!pipe); + if (is_program) + { + fp = popen(file, PG_BINARY_R); + if (fp == NULL) + { + elog(ERROR, "could not execute command \"%s\"", file); + } + } + else + { + struct stat st; + fp = fopen(file, PG_BINARY_R); + if (fp == NULL) + { + elog(ERROR, "could not open file \"%s\": %s", file, strerror(errno)); + } + if (fstat(fileno(fp), &st)) + { + elog(ERROR, "could not stat file \"%s\"", file); + } + if (S_ISDIR(st.st_mode)) + { + elog(ERROR, "\"%s\" is a directory", file); + } + } + + buf = (char *) palloc0(BatchSize + MaxRecordSize + 1); + Assert(buf != NULL); + + if (!binary) + { + while (true) + { + uint64_t i; + CHECK_FOR_INTERRUPTS(); + start = 0; + read = fread(buf + start, 1, BatchSize, fp); + start += read; + if (read < BatchSize) break; + for (i = 0; i < MaxRecordSize; i++) + { + read = fread(buf + start, 1, 1, fp); + if (read == 0) break; + Assert(read == 1); + start += read; + if (buf[start - 1] == '\n') break; + } + if (i == MaxRecordSize) + { + char *tmp = (char*) palloc0(MaxRecordSize + 1); + strncpy(tmp, buf + start - MaxRecordSize, MaxRecordSize); + tmp[MaxRecordSize] = '\0'; + elog(ERROR, "Too large record: %s", tmp); + } + else + { + SendMessage(zeromqServer, buf, start, false); + } + } + if (start > 0) + { + SendMessage(zeromqServer, buf, 
start, false); + } + } + else + { + int32 flag, elen; + int16 fld_count; + + /* Signature */ + read = fread(buf, 1, 11, fp); + if (read != 11 || strncmp(buf, BinarySignature, 11) != 0) + { + elog(ERROR, "COPY file signature not recognized"); + } + /* Flags field */ + read = fread(buf, 1, 4, fp); + if (read != 4) + { + elog(ERROR, "invalid COPY file header (missing flags)"); + } + flag = (int32) ntohl(*(uint32 *)buf); + if ((flag & (1 << 16)) != 0) + { + elog(ERROR, "bulkload COPY can't set OID flag"); + } + flag &= ~(1 << 16); + if ((flag >> 16) != 0) + { + elog(ERROR, "unrecognized critical flags in COPY file header"); + } + /* Header extension length */ + read = fread(buf, 1, 4, fp); + if (read != 4) + { + elog(ERROR, "invalid COPY file header (missing length)"); + } + elen = (int32) ntohl(*(uint32 *)buf); + /* Skip extension header, if present */ + read = fread(buf, 1, elen, fp); + if (read != elen) + { + elog(ERROR, "invalid COPY file header (wrong length)"); + } + + /* handle tuples one by one */ + while (true) + { + int16 fld_index; + int32 fld_size; + + CHECK_FOR_INTERRUPTS(); + start = 0; + read = fread(buf + start, 1, 2, fp); + if (read != 2) + { + /* EOF detected (end of file, or protocol-level EOR) */ + break; + } + fld_count = (int16) ntohs(*(uint16 *)buf); + if (fld_count == -1) + { + read = fread(buf + start, 1, 1, fp); + if (read == 1) + { + elog(ERROR, "received copy data after EOF marker"); + } + /* Received EOF marker */ + break; + } + start += 2; + for (fld_index = 0; fld_index < fld_count; fld_index++) + { + read = fread(buf + start, 1, 4, fp); + if (read != 4) + { + elog(ERROR, "unexpected EOF in COPY data"); + } + fld_size = (int32) ntohl(*(uint32 *)(buf + start)); + if (fld_size == -1) + { + /* null value */ + start += 4; + } + else if (fld_size < 0) + { + elog(ERROR, "invalid field size %d", fld_size); + } + else + { + start += 4; + read = fread(buf + start, 1, fld_size, fp); + if (read != fld_size) + { + elog(ERROR, "unexpected EOF in COPY data"); + } + else + { + /* skip field value */ + start += fld_size; + } + } + + if (start >= MaxRecordSize + BatchSize) + { + elog(ERROR, "Too large binary record: %s", buf); + } + } + SendMessage(zeromqServer, buf, start, false); + } + } + if (is_program) + { + int rc = pclose(fp); + if (rc == -1) + { + elog(WARNING, "could not close pipe to external command \"%s\"", file); + } + else if (rc != 0) + { + elog(WARNING, "program \"%s\" failed", file); + } + } + else if (fclose(fp) != 0) + { + elog(WARNING, "close file error: %s", strerror(errno)); + } +} + +/* + * SendMessage sends message to zeromq socket. + * If kill is true, send KILL signal. + */ +static void +SendMessage(ZeroMQServer *zeromqServer, char *buf, size_t len, bool kill) +{ + int rc; + if (kill) + { + rc = zmq_send(zeromqServer->controller, buf, len, 0); + } + else + { + rc = zmq_send(zeromqServer->sender, buf, len, 0); + } + if (rc != len) + { + elog(LOG, "zmq_send() error(%d): %s", errno, zmq_strerror(errno)); + } +} + +/* + * StopZeroMQServer stops zeromq server and releases related resources. 
+ */
+static void
+StopZeroMQServer(ZeroMQServer *zeromqServer)
+{
+	if (zeromqServer->sender)
+	{
+		zmq_close(zeromqServer->sender);
+		zeromqServer->sender = NULL;
+	}
+	if (zeromqServer->controller)
+	{
+		zmq_close(zeromqServer->controller);
+		zeromqServer->controller = NULL;
+	}
+	if (zeromqServer->context)
+	{
+		zmq_ctx_destroy(zeromqServer->context);
+		zeromqServer->context = NULL;
+	}
+}
+
+/*
+ * RebuildBulkloadCopyStatement adds the bulkload server address to the
+ * PROGRAM's arguments.
+ */
+static void
+RebuildBulkloadCopyStatement(CopyStmt *copyStatement, NodeAddress *bulkloadServer)
+{
+	StringInfo tmp = makeStringInfo();
+	appendStringInfo(tmp, "%s %s %d", copyStatement->filename,
+					 bulkloadServer->nodeName, bulkloadServer->nodePort);
+	if (IsBinaryCopy(copyStatement))
+	{
+		appendStringInfoString(tmp, " binary");
+	}
+	copyStatement->filename = tmp->data;
+}
+
+/*
+ * MasterRelationId gets the relationId of the relation from the master node.
+ */
+static Oid
+MasterRelationId(char *qualifiedName)
+{
+	Oid relationId = 0;
+	PGresult *queryResult = NULL;
+
+	StringInfo relationIdCommand = makeStringInfo();
+	appendStringInfo(relationIdCommand, RELATIONID_QUERY, qualifiedName);
+
+	queryResult = PQexec(masterConnection->pgConn, relationIdCommand->data);
+	if (PQresultStatus(queryResult) == PGRES_TUPLES_OK)
+	{
+		char *relationIdString = PQgetvalue(queryResult, 0, 0);
+		if (relationIdString == NULL || (*relationIdString) == '\0')
+		{
+			elog(ERROR, "could not find relationId for the table %s", qualifiedName);
+		}
+
+		relationId = (Oid) atoi(relationIdString);
+	}
+	else
+	{
+		elog(ERROR, "could not get the relationId of the distributed table %s: %s",
+			 qualifiedName, PQresultErrorMessage(queryResult));
+	}
+	PQclear(queryResult);
+	return relationId;
+}
+
+/*
+ * MasterDistributedTableCacheEntry gets metadata from the master node and
+ * builds a DistTableCacheEntry for the relation.
+ */ +static DistTableCacheEntry * +MasterDistributedTableCacheEntry(RangeVar *relation) +{ + DistTableCacheEntry *cacheEntry = NULL; + Oid relationId = 0; + + /* temporary value */ + char *partmethod = NULL; + char *colocationid = NULL; + char *repmodel = NULL; + char *shardidString = NULL; + char *storageString = NULL; + char *minValueString = NULL; + char *maxValueString = NULL; + char *shardidStringEnd = NULL; + + /* members of pg_dist_partition and DistTableCacheEntry */ + char *partitionKeyString = NULL; + char partitionMethod = 0; + uint32 colocationId = INVALID_COLOCATION_ID; + char replicationModel = 0; + + /* members of pg_dist_shard */ + int64 shardId; + char storageType; + Datum minValue = 0; + Datum maxValue = 0; + bool minValueExists = true; + bool maxValueExists = true; + + /* members of DistTableCacheEntry */ + int shardIntervalArrayLength = 0; + ShardInterval **shardIntervalArray = NULL; + ShardInterval **sortedShardIntervalArray = NULL; + FmgrInfo *shardIntervalCompareFunction = NULL; + FmgrInfo *hashFunction = NULL; + bool hasUninitializedShardInterval = false; + bool hasUniformHashDistribution = false; + + ShardInterval *shardInterval = NULL; + Oid intervalTypeId = INT4OID; + int32 intervalTypeMod = -1; + int16 intervalTypeLen = 0; + bool intervalByVal = false; + char intervalAlign = '0'; + char intervalDelim = '0'; + Oid typeIoParam = InvalidOid; + Oid inputFunctionId = InvalidOid; + PGresult *queryResult = NULL; + StringInfo partitionKeyStringInfo = makeStringInfo(); + StringInfo queryString = makeStringInfo(); + + char *relationName = relation->relname; + char *schemaName = relation->schemaname; + char *qualifiedName = quote_qualified_identifier(schemaName, relationName); + relationId = MasterRelationId(qualifiedName); + + Assert(masterConnection != NULL); + + appendStringInfo(queryString, "SELECT * FROM pg_dist_partition WHERE logicalrelid=%d", + relationId); + queryResult = PQexec(masterConnection->pgConn, queryString->data); + if (PQresultStatus(queryResult) == PGRES_TUPLES_OK) + { + int rowCount = PQntuples(queryResult); + Assert(rowCount == 1); + + partmethod = PQgetvalue(queryResult, 0, Anum_pg_dist_partition_partmethod - 1); + partitionKeyString = PQgetvalue(queryResult, 0, Anum_pg_dist_partition_partkey - 1); + colocationid = PQgetvalue(queryResult, 0, Anum_pg_dist_partition_colocationid - 1); + repmodel = PQgetvalue(queryResult, 0, Anum_pg_dist_partition_repmodel - 1); + + partitionMethod = partmethod[0]; + appendStringInfoString(partitionKeyStringInfo, partitionKeyString); + partitionKeyString = partitionKeyStringInfo->data; + colocationId = (uint32) atoi(colocationid); + replicationModel = repmodel[0]; + } + else + { + elog(ERROR, "could not get metadata of table %s: %s", + qualifiedName, PQresultErrorMessage(queryResult)); + } + PQclear(queryResult); + + get_type_io_data(intervalTypeId, IOFunc_input, &intervalTypeLen, &intervalByVal, + &intervalAlign, &intervalDelim, &typeIoParam, &inputFunctionId); + + resetStringInfo(queryString); + appendStringInfo(queryString, "SELECT * FROM pg_dist_shard WHERE logicalrelid=%d", + relationId); + queryResult = PQexec(masterConnection->pgConn, queryString->data); + if (PQresultStatus(queryResult) == PGRES_TUPLES_OK) + { + int arrayIndex = 0; + + shardIntervalArrayLength = PQntuples(queryResult); + shardIntervalArray = (ShardInterval **) palloc0( + shardIntervalArrayLength * sizeof(ShardInterval *)); + + for (arrayIndex = 0; arrayIndex < shardIntervalArrayLength; arrayIndex++) + { + shardidString = + 
PQgetvalue(queryResult, arrayIndex, Anum_pg_dist_shard_shardid - 1); + storageString = + PQgetvalue(queryResult, arrayIndex, Anum_pg_dist_shard_shardstorage - 1); + minValueString = + PQgetvalue(queryResult, arrayIndex, Anum_pg_dist_shard_shardminvalue - 2); + maxValueString = + PQgetvalue(queryResult, arrayIndex, Anum_pg_dist_shard_shardmaxvalue - 2); + + shardId = strtoul(shardidString, &shardidStringEnd, 0); + storageType = storageString[0]; + /* finally convert min/max values to their actual types */ + minValue = OidInputFunctionCall(inputFunctionId, minValueString, + typeIoParam, intervalTypeMod); + maxValue = OidInputFunctionCall(inputFunctionId, maxValueString, + typeIoParam, intervalTypeMod); + + shardInterval = CitusMakeNode(ShardInterval); + shardInterval->relationId = relationId; + shardInterval->storageType = storageType; + shardInterval->valueTypeId = intervalTypeId; + shardInterval->valueTypeLen = intervalTypeLen; + shardInterval->valueByVal = intervalByVal; + shardInterval->minValueExists = minValueExists; + shardInterval->maxValueExists = maxValueExists; + shardInterval->minValue = minValue; + shardInterval->maxValue = maxValue; + shardInterval->shardId = shardId; + + shardIntervalArray[arrayIndex] = shardInterval; + } + } + else + { + elog(ERROR, "could not get metadata of table %s: %s", + qualifiedName, PQresultErrorMessage(queryResult)); + } + PQclear(queryResult); + + /* decide and allocate interval comparison function */ + if (shardIntervalArrayLength > 0) + { + shardIntervalCompareFunction = GetFunctionInfo(INT4OID, BTREE_AM_OID, + BTORDER_PROC); + } + + /* sort the interval array */ + sortedShardIntervalArray = SortShardIntervalArray(shardIntervalArray, + shardIntervalArrayLength, + shardIntervalCompareFunction); + + /* check if there exists any shard intervals with no min/max values */ + hasUninitializedShardInterval = + HasUninitializedShardInterval(sortedShardIntervalArray, shardIntervalArrayLength); + + /* we only need hash functions for hash distributed tables */ + if (partitionMethod == DISTRIBUTE_BY_HASH) + { + TypeCacheEntry *typeEntry = NULL; + Node *partitionNode = stringToNode(partitionKeyString); + Var *partitionColumn = (Var *) partitionNode; + Assert(IsA(partitionNode, Var)); + typeEntry = lookup_type_cache(partitionColumn->vartype, + TYPECACHE_HASH_PROC_FINFO); + + hashFunction = (FmgrInfo *) palloc0(sizeof(FmgrInfo)); + + fmgr_info_copy(hashFunction, &(typeEntry->hash_proc_finfo), CurrentMemoryContext); + + /* check the shard distribution for hash partitioned tables */ + hasUniformHashDistribution = + HasUniformHashDistribution(sortedShardIntervalArray, shardIntervalArrayLength); + } + + cacheEntry = (DistTableCacheEntry *) palloc0(sizeof(DistTableCacheEntry)); + cacheEntry->relationId = relationId; + cacheEntry->isValid = true; + cacheEntry->isDistributedTable = true; + cacheEntry->partitionKeyString = partitionKeyString; + cacheEntry->partitionMethod = partitionMethod; + cacheEntry->colocationId = colocationId; + cacheEntry->replicationModel = replicationModel; + cacheEntry->shardIntervalArrayLength = shardIntervalArrayLength; + cacheEntry->sortedShardIntervalArray = sortedShardIntervalArray; + cacheEntry->shardIntervalCompareFunction = shardIntervalCompareFunction; + cacheEntry->hashFunction = hashFunction; + cacheEntry->hasUninitializedShardInterval = hasUninitializedShardInterval; + cacheEntry->hasUniformHashDistribution = hasUniformHashDistribution; + + return cacheEntry; +} + +/* + * BulkloadCopyToExistingShards implements the COPY table_name 
FROM ... for HASH
+ * distributed tables, where there are already shards into which to copy. It
+ * works just like CopyToExistingShards(), except that the latter runs only on
+ * the master node while the former runs on every worker node, so
+ * BulkloadCopyToExistingShards is roughly #workers times faster than
+ * CopyToExistingShards for HASH distributed tables.
+ */
+static void
+BulkloadCopyToExistingShards(CopyStmt *copyStatement, char *completionTag,
+							 Oid relationId)
+{
+	DistTableCacheEntry *cacheEntry = NULL;
+	NodeAddress *bulkloadServer = NULL;
+
+	Assert(masterConnection != NULL);
+	cacheEntry = MasterDistributedTableCacheEntry(copyStatement->relation);
+	InsertDistTableCacheEntry(relationId, cacheEntry);
+
+	bulkloadServer = BulkloadServerAddress(copyStatement);
+	RemoveBulkloadOptions(copyStatement);
+	RebuildBulkloadCopyStatement(copyStatement, bulkloadServer);
+
+	CopyToExistingShards(copyStatement, completionTag, relationId);
+}
+
+/*
+ * Get the PGconn* and its index from a PGconn* list by socket descriptor.
+ */
+static PGconn *
+GetConnectionBySock(List *connList, int sock, int *connIdx)
+{
+	PGconn *conn = NULL;
+	int idx;
+	int n = list_length(connList);
+	for (idx = 0; idx < n; idx++)
+	{
+		conn = (PGconn *) list_nth(connList, idx);
+		if (PQsocket(conn) == sock)
+		{
+			*connIdx = idx;
+			return conn;
+		}
+	}
+	return NULL;
+}
diff --git a/src/backend/distributed/executor/multi_utility.c b/src/backend/distributed/executor/multi_utility.c
index 849b7093c..62fb87db8 100644
--- a/src/backend/distributed/executor/multi_utility.c
+++ b/src/backend/distributed/executor/multi_utility.c
@@ -585,6 +585,10 @@ ProcessCopyStmt(CopyStmt *copyStatement, char *completionTag, bool *commandMustR
 		SelectStmt *selectStmt = makeNode(SelectStmt);
 		ResTarget *selectTarget = makeNode(ResTarget);
 
+		if (IsBulkloadCopy(copyStatement))
+		{
+			elog(ERROR, "Bulkload copy is only supported for COPY FROM");
+		}
 		allColumns->fields = list_make1(makeNode(A_Star));
 		allColumns->location = -1;
diff --git a/src/backend/distributed/utils/metadata_cache.c b/src/backend/distributed/utils/metadata_cache.c
index 2939ade41..e802ec6e9 100644
--- a/src/backend/distributed/utils/metadata_cache.c
+++ b/src/backend/distributed/utils/metadata_cache.c
@@ -125,14 +125,6 @@ static void BuildDistTableCacheEntry(DistTableCacheEntry *cacheEntry);
 static void BuildCachedShardList(DistTableCacheEntry *cacheEntry);
 static FmgrInfo * ShardIntervalCompareFunction(ShardInterval **shardIntervalArray,
 											   char partitionMethod);
-static ShardInterval ** SortShardIntervalArray(ShardInterval **shardIntervalArray,
-											   int shardCount,
-											   FmgrInfo *
-											   shardIntervalSortCompareFunction);
-static bool HasUniformHashDistribution(ShardInterval **shardIntervalArray,
-									   int shardIntervalArrayLength);
-static bool HasUninitializedShardInterval(ShardInterval **sortedShardIntervalArray,
-										  int shardCount);
 static void InitializeDistTableCache(void);
 static void InitializeWorkerNodeCache(void);
 static uint32 WorkerNodeHashCode(const void *key, Size keySize);
@@ -434,6 +426,28 @@ DistributedTableCacheEntry(Oid distributedRelationId)
 	}
 }
 
+/*
+ * InsertDistTableCacheEntry inserts the distributed table metadata for the
+ * passed relationId.
+ */ +void +InsertDistTableCacheEntry(Oid relationId, DistTableCacheEntry *ent) +{ + DistTableCacheEntry *cacheEntry = NULL; + bool foundInCache = false; + + if (DistTableCacheHash == NULL) + { + InitializeDistTableCache(); + } + + cacheEntry = hash_search(DistTableCacheHash, (const void *)&relationId, HASH_ENTER, + &foundInCache); + Assert(foundInCache == false); + memcpy(cacheEntry, ent, sizeof(DistTableCacheEntry)); + /* restore relationId */ + cacheEntry->relationId = relationId; +} /* * LookupDistTableCacheEntry returns the distributed table metadata for the @@ -819,7 +833,7 @@ ShardIntervalCompareFunction(ShardInterval **shardIntervalArray, char partitionM * SortedShardIntervalArray sorts the input shardIntervalArray. Shard intervals with * no min/max values are placed at the end of the array. */ -static ShardInterval ** +ShardInterval ** SortShardIntervalArray(ShardInterval **shardIntervalArray, int shardCount, FmgrInfo *shardIntervalSortCompareFunction) { @@ -847,7 +861,7 @@ SortShardIntervalArray(ShardInterval **shardIntervalArray, int shardCount, * has a uniform hash distribution, as produced by master_create_worker_shards for * hash partitioned tables. */ -static bool +bool HasUniformHashDistribution(ShardInterval **shardIntervalArray, int shardIntervalArrayLength) { @@ -891,7 +905,7 @@ HasUniformHashDistribution(ShardInterval **shardIntervalArray, * ensure that input shard interval array is sorted on shardminvalue and uninitialized * shard intervals are at the end of the array. */ -static bool +bool HasUninitializedShardInterval(ShardInterval **sortedShardIntervalArray, int shardCount) { bool hasUninitializedShardInterval = false; diff --git a/src/bin/bload/Makefile b/src/bin/bload/Makefile new file mode 100644 index 000000000..cf7498caa --- /dev/null +++ b/src/bin/bload/Makefile @@ -0,0 +1,29 @@ +#------------------------------------------------------------------------- +# +# Makefile for src/bin/bload +# +# Portions Copyright (c) 1996-2014, PostgreSQL Global Development Group +# Portions Copyright (c) 1994, Regents of the University of California +# +# src/bin/bload/Makefile +# +#------------------------------------------------------------------------- + +citus_subdir = src/bin/bload +citus_top_builddir = ../../.. + +PROGRAM = bload + +PGFILEDESC = "bload - the zeromq client for bulkload" + +OBJS = bload.o + +PG_LIBS = $(libpq) + +override CFLAGS += -lzmq + +include $(citus_top_builddir)/Makefile.global + +clean: bload-clean +bload-clean: + rm -f bload$(X) $(OBJS) diff --git a/src/bin/bload/bload.c b/src/bin/bload/bload.c new file mode 100644 index 000000000..f38f795e2 --- /dev/null +++ b/src/bin/bload/bload.c @@ -0,0 +1,176 @@ +/*------------------------------------------------------------------------- + * + * bload.c + * + * This is the zeromq client of bulkload copy. It pulls data from zeromq server + * and outputs the message to stdout. + * + * Copyright (c) 2012-2016, Citus Data, Inc. 
+ *
+ * $Id$
+ *
+ *-------------------------------------------------------------------------
+ */
+
+#include "distributed/bload.h"
+
+#include <errno.h>
+#include <stdbool.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <zmq.h>
+
+/* constant used in binary protocol */
+static const char BinarySignature[11] = "PGCOPY\n\377\r\n\0";
+
+int main(int argc, char *argv[])
+{
+	FILE *fp = NULL;
+	uint64_t buffer_size = BatchSize + MaxRecordSize + 1;
+	char buf[buffer_size];
+	char connstr[64];
+	int rc;
+	const int zero = 0;
+	const short negative = -1;
+	bool binary = false;
+
+	/* variables for zeromq */
+	void *context = NULL;
+	void *receiver = NULL;
+	void *controller = NULL;
+	int port = 0;
+	int nbytes;
+
+	fp = fopen("/tmp/bload.log", "a");
+	if (!fp)
+		fp = stderr;
+	if (argc < 3)
+	{
+		fprintf(fp, "Usage: %s host port [binary]\n", argv[0]);
+		fflush(fp);
+		fclose(fp);
+		return 1;
+	}
+
+	if (argc == 4 && strcmp(argv[3], "binary") == 0)
+	{
+		binary = true;
+	}
+
+	context = zmq_ctx_new();
+
+	// Socket to receive messages on
+	receiver = zmq_socket(context, ZMQ_PULL);
+	port = atoi(argv[2]);
+	sprintf(connstr, "tcp://%s:%d", argv[1], port);
+	rc = zmq_connect(receiver, connstr);
+	if (rc != 0)
+	{
+		fprintf(fp, "zmq_connect() error(%d): %s\n", errno, strerror(errno));
+		fflush(fp);
+		fclose(fp);
+		zmq_close(receiver);
+		zmq_ctx_destroy(context);
+		return 1;
+	}
+
+	// Socket to receive control messages on
+	controller = zmq_socket(context, ZMQ_SUB);
+	sprintf(connstr, "tcp://%s:%d", argv[1], port + 1);
+	rc = zmq_connect(controller, connstr);
+	if (rc != 0)
+	{
+		fprintf(fp, "zmq_connect() error(%d): %s\n", errno, strerror(errno));
+		fflush(fp);
+		fclose(fp);
+		zmq_close(receiver);
+		zmq_close(controller);
+		zmq_ctx_destroy(context);
+		return 1;
+	}
+	zmq_setsockopt(controller, ZMQ_SUBSCRIBE, "", 0);
+
+	zmq_pollitem_t items[] = {
+		{ receiver, 0, ZMQ_POLLIN, 0 },
+		{ controller, 0, ZMQ_POLLIN, 0 }
+	};
+
+	if (binary)
+	{
+		/* Signature */
+		fwrite(BinarySignature, 1, 11, stdout);
+		/* Flags field (no OIDs) */
+		fwrite((void *) &zero, 1, 4, stdout);
+		/* No header extension */
+		fwrite((void *) &zero, 1, 4, stdout);
+		fflush(stdout);
+	}
+	while (true)
+	{
+		/* wait indefinitely for an event to occur */
+		rc = zmq_poll(items, 2, -1);
+
+		if (rc == -1) /* an error occurred */
+		{
+			fprintf(fp, "zmq_poll() error(%d): %s\n", errno, strerror(errno));
+			fflush(fp);
+			break;
+		}
+		if (items[0].revents & ZMQ_POLLIN) /* received a message */
+		{
+			nbytes = zmq_recv(receiver, buf, buffer_size - 1, 0);
+			if (nbytes == -1)
+			{
+				fprintf(fp, "zmq_recv() error(%d): %s\n", errno, strerror(errno));
+				fflush(fp);
+				break;
+			}
+			fwrite(buf, 1, nbytes, stdout);
+			fflush(stdout);
+		}
+		if (items[1].revents & ZMQ_POLLIN) /* received the KILL signal */
+		{
+			fprintf(fp, "received KILL signal, waiting to exhaust all messages\n");
+			fflush(fp);
+			/* consume all messages before exit */
+			while (true)
+			{
+				/* wait 100 milliseconds for an event to occur */
+				rc = zmq_poll(items, 1, 100);
+				if (rc == 0) /* no more messages */
+				{
+					break;
+				}
+				else if (rc == -1) /* an error occurred */
+				{
+					fprintf(fp, "zmq_poll() error(%d): %s\n", errno, strerror(errno));
+					fflush(fp);
+					break;
+				}
+				if (items[0].revents & ZMQ_POLLIN) /* received a message */
+				{
+					nbytes = zmq_recv(receiver, buf, buffer_size - 1, 0);
+					if (nbytes == -1)
+					{
+						fprintf(fp, "zmq_recv() error(%d): %s\n", errno,
+								strerror(errno));
+						fflush(fp);
+						break;
+					}
+					fwrite(buf, 1, nbytes, stdout);
+					fflush(stdout);
+				}
+			}
+			fprintf(fp, "we have consumed all messages, exit now\n");
+			fflush(fp);
+			if (binary)
+			{
/* Binary footers */ + fwrite((void *)&negative, 1, 2, stdout); + fflush(stdout); + } + break; + } + } + zmq_close(receiver); + zmq_close(controller); + zmq_ctx_destroy(context); + fclose(fp); + return 0; +} diff --git a/src/include/distributed/bload.h b/src/include/distributed/bload.h new file mode 100644 index 000000000..be4b4f491 --- /dev/null +++ b/src/include/distributed/bload.h @@ -0,0 +1,18 @@ +/*------------------------------------------------------------------------- + * + * bload.h + * Definitions of const variables used in bulkload copy for + * distributed tables. + * + * Copyright (c) 2016, Citus Data, Inc. + * + *------------------------------------------------------------------------- + */ + +#ifndef BLOAD_H +#define BLOAD_H + +#define BatchSize 1024 /* size of a zeromq message in bytes */ +#define MaxRecordSize 256 /* size of max acceptable record in bytes */ + +#endif diff --git a/src/include/distributed/master_protocol.h b/src/include/distributed/master_protocol.h index 6eddde757..a75a7bd4c 100644 --- a/src/include/distributed/master_protocol.h +++ b/src/include/distributed/master_protocol.h @@ -75,6 +75,8 @@ #define UPDATE_SHARD_STATISTICS_QUERY \ "SELECT master_update_shard_statistics(%ld)" #define PARTITION_METHOD_QUERY "SELECT part_method FROM master_get_table_metadata('%s');" +#define ACTIVE_WORKER_NODE_QUERY "SELECT * FROM master_get_active_worker_nodes();" +#define RELATIONID_QUERY "SELECT logical_relid FROM master_get_table_metadata('%s');" /* Enumeration that defines the shard placement policy to use while staging */ typedef enum diff --git a/src/include/distributed/metadata_cache.h b/src/include/distributed/metadata_cache.h index abd1861c0..ecabd1705 100644 --- a/src/include/distributed/metadata_cache.h +++ b/src/include/distributed/metadata_cache.h @@ -61,11 +61,20 @@ extern List * DistributedTableList(void); extern ShardInterval * LoadShardInterval(uint64 shardId); extern ShardPlacement * LoadShardPlacement(uint64 shardId, uint64 placementId); extern DistTableCacheEntry * DistributedTableCacheEntry(Oid distributedRelationId); +extern void InsertDistTableCacheEntry(Oid relationId, DistTableCacheEntry *ent); extern int GetLocalGroupId(void); extern List * DistTableOidList(void); extern List * ShardPlacementList(uint64 shardId); extern void CitusInvalidateRelcacheByRelid(Oid relationId); extern void CitusInvalidateRelcacheByShardId(int64 shardId); +extern ShardInterval ** SortShardIntervalArray(ShardInterval **shardIntervalArray, + int shardCount, + FmgrInfo * + shardIntervalSortCompareFunction); +extern bool HasUniformHashDistribution(ShardInterval **shardIntervalArray, + int shardIntervalArrayLength); +extern bool HasUninitializedShardInterval(ShardInterval **sortedShardIntervalArray, + int shardCount); extern bool CitusHasBeenLoaded(void); diff --git a/src/include/distributed/multi_copy.h b/src/include/distributed/multi_copy.h index d27926bbf..295e024eb 100644 --- a/src/include/distributed/multi_copy.h +++ b/src/include/distributed/multi_copy.h @@ -43,6 +43,18 @@ typedef struct NodeAddress int32 nodePort; } NodeAddress; +/* struct type to keep zeromq related value */ +typedef struct ZeroMQServer +{ + char host[NAMEDATALEN]; + int32 port; + char file[NAMEDATALEN]; + void *context; + void *sender; + void *controller; + +} ZeroMQServer; + /* function declarations for copying into a distributed table */ extern FmgrInfo * ColumnOutputFunctions(TupleDesc rowDescriptor, bool binaryFormat); @@ -56,5 +68,10 @@ extern void CitusCopyFrom(CopyStmt *copyStatement, char 
*completionTag);
 extern bool IsCopyFromWorker(CopyStmt *copyStatement);
 extern NodeAddress * MasterNodeAddress(CopyStmt *copyStatement);
 
+/* function declarations for bulkload copy */
+extern bool IsBulkloadCopy(CopyStmt *copyStatement);
+extern bool IsBinaryCopy(CopyStmt *copyStatement);
+extern bool IsBulkloadClient(CopyStmt *copyStatement);
+extern void CitusBulkloadCopy(CopyStmt *copyStatement, char *completionTag);
 
 #endif /* MULTI_COPY_H */
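For readers unfamiliar with the PUSH/PULL-plus-PUB/SUB arrangement shared by StartZeroMQServer() and bload.c, the following minimal sketch distills the client side of the pattern. It is a simplified stand-in, not the patch's actual wiring: the endpoints, the buffer size, and treating any control message as an immediate shutdown are assumptions here (bload additionally drains pending data messages after the control message arrives, as the patch shows above).

    #include <stdbool.h>
    #include <stdio.h>
    #include <zmq.h>

    int
    main(void)
    {
        void *context = zmq_ctx_new();
        void *receiver = zmq_socket(context, ZMQ_PULL);  /* data channel */
        void *controller = zmq_socket(context, ZMQ_SUB); /* control channel */
        char buf[256];                                   /* hypothetical size */
        bool done = false;

        zmq_connect(receiver, "tcp://localhost:10240");  /* hypothetical ports */
        zmq_connect(controller, "tcp://localhost:10241");
        zmq_setsockopt(controller, ZMQ_SUBSCRIBE, "", 0); /* accept any control msg */

        while (!done)
        {
            zmq_pollitem_t items[] = {
                { receiver, 0, ZMQ_POLLIN, 0 },
                { controller, 0, ZMQ_POLLIN, 0 }
            };

            if (zmq_poll(items, 2, -1) == -1)
            {
                break; /* polling error */
            }
            if (items[0].revents & ZMQ_POLLIN)
            {
                int nbytes = zmq_recv(receiver, buf, sizeof(buf), 0);
                if (nbytes > 0)
                {
                    fwrite(buf, 1, nbytes, stdout); /* forward the record batch */
                }
            }
            if (items[1].revents & ZMQ_POLLIN)
            {
                done = true; /* any control message means shut down */
            }
        }
        zmq_close(receiver);
        zmq_close(controller);
        zmq_ctx_destroy(context);
        return 0;
    }

The split into two sockets is what lets the server's SendMessage(..., kill = true) terminate idle clients: PUSH/PULL load-balances data across however many workers connected, while the PUB/SUB channel broadcasts the end-of-stream signal to all of them at once.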