diff --git a/src/backend/distributed/executor/multi_client_executor.c b/src/backend/distributed/executor/multi_client_executor.c index da852d5c1..fadfb7072 100644 --- a/src/backend/distributed/executor/multi_client_executor.c +++ b/src/backend/distributed/executor/multi_client_executor.c @@ -167,12 +167,12 @@ MultiClientConnect(const char *nodeName, uint32 nodePort, const char *nodeDataba * error and returns INVALID_CONNECTION_ID. */ int32 -MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeDatabase) +MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeDatabase, + const char *userName) { PGconn *connection = NULL; char connInfoString[STRING_BUFFER_SIZE]; ConnStatusType connStatusType = CONNECTION_BAD; - char *userName = CurrentUserName(); int32 connectionId = AllocateConnectionId(); if (connectionId == INVALID_CONNECTION_ID) @@ -188,6 +188,11 @@ MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeD "command within a transaction"))); } + if (userName == NULL) + { + userName = CurrentUserName(); + } + /* transcribe connection paremeters to string */ snprintf(connInfoString, STRING_BUFFER_SIZE, CONN_INFO_TEMPLATE, nodeName, nodePort, nodeDatabase, userName, CLIENT_CONNECT_TIMEOUT); diff --git a/src/backend/distributed/executor/multi_real_time_executor.c b/src/backend/distributed/executor/multi_real_time_executor.c index 86a069584..e1a93cdfb 100644 --- a/src/backend/distributed/executor/multi_real_time_executor.c +++ b/src/backend/distributed/executor/multi_real_time_executor.c @@ -272,7 +272,8 @@ ManageTaskExecution(Task *task, TaskExecution *taskExecution, /* we use the same database name on the master and worker nodes */ nodeDatabase = get_database_name(MyDatabaseId); - connectionId = MultiClientConnectStart(nodeName, nodePort, nodeDatabase); + connectionId = MultiClientConnectStart(nodeName, nodePort, nodeDatabase, + NULL); connectionIdArray[currentIndex] = connectionId; /* if valid, poll the connection until the connection is initiated */ diff --git a/src/backend/distributed/executor/multi_task_tracker_executor.c b/src/backend/distributed/executor/multi_task_tracker_executor.c index 8bdcdacd3..0c8dd282e 100644 --- a/src/backend/distributed/executor/multi_task_tracker_executor.c +++ b/src/backend/distributed/executor/multi_task_tracker_executor.c @@ -25,6 +25,7 @@ #include "commands/dbcommands.h" #include "distributed/citus_nodes.h" +#include "distributed/metadata_cache.h" #include "distributed/multi_client_executor.h" #include "distributed/multi_physical_planner.h" #include "distributed/multi_server_executor.h" @@ -65,7 +66,8 @@ static Task * TaskHashLookup(HTAB *trackerHash, TaskType taskType, uint64 jobId, uint32 taskId); static bool TopLevelTask(Task *task); static bool TransmitExecutionCompleted(TaskExecution *taskExecution); -static HTAB * TrackerHash(const char *taskTrackerHashName, List *workerNodeList); +static HTAB * TrackerHash(const char *taskTrackerHashName, List *workerNodeList, + char *userName); static HTAB * TrackerHashCreate(const char *taskTrackerHashName, uint32 taskTrackerHashSize); static TaskTracker * TrackerHashEnter(HTAB *taskTrackerHash, char *nodeName, @@ -154,6 +156,7 @@ MultiTaskTrackerExecute(Job *job) List *workerNodeList = NIL; HTAB *taskTrackerHash = NULL; HTAB *transmitTrackerHash = NULL; + char *extensionOwner = CitusExtensionOwnerName(); const char *taskTrackerHashName = "Task Tracker Hash"; const char *transmitTrackerHashName = "Transmit Tracker Hash"; List *jobIdList = NIL; @@ -188,8 +191,12 @@ MultiTaskTrackerExecute(Job *job) workerNodeList = WorkerNodeList(); taskTrackerCount = (uint32) list_length(workerNodeList); - taskTrackerHash = TrackerHash(taskTrackerHashName, workerNodeList); - transmitTrackerHash = TrackerHash(transmitTrackerHashName, workerNodeList); + /* connect as the current user for running queries */ + taskTrackerHash = TrackerHash(taskTrackerHashName, workerNodeList, NULL); + + /* connect as the superuser for fetching result files */ + transmitTrackerHash = TrackerHash(transmitTrackerHashName, workerNodeList, + extensionOwner); TrackerHashConnect(taskTrackerHash); TrackerHashConnect(transmitTrackerHash); @@ -657,10 +664,11 @@ TransmitExecutionCompleted(TaskExecution *taskExecution) /* * TrackerHash creates a task tracker hash with the given name. The function * then inserts one task tracker entry for each node in the given worker node - * list, and initializes state for each task tracker. + * list, and initializes state for each task tracker. The userName argument + * indicates which user to connect as. */ static HTAB * -TrackerHash(const char *taskTrackerHashName, List *workerNodeList) +TrackerHash(const char *taskTrackerHashName, List *workerNodeList, char *userName) { /* create task tracker hash */ uint32 taskTrackerHashSize = list_length(workerNodeList); @@ -702,6 +710,7 @@ TrackerHash(const char *taskTrackerHashName, List *workerNodeList) } taskTracker->taskStateHash = taskStateHash; + taskTracker->userName = userName; } return taskTrackerHash; @@ -835,9 +844,10 @@ TrackerConnectPoll(TaskTracker *taskTracker) char *nodeName = taskTracker->workerName; uint32 nodePort = taskTracker->workerPort; char *nodeDatabase = get_database_name(MyDatabaseId); + char *nodeUser = taskTracker->userName; int32 connectionId = MultiClientConnectStart(nodeName, nodePort, - nodeDatabase); + nodeDatabase, nodeUser); if (connectionId != INVALID_CONNECTION_ID) { taskTracker->connectionId = connectionId; diff --git a/src/backend/distributed/executor/multi_utility.c b/src/backend/distributed/executor/multi_utility.c index 791f03099..b4a20d3f5 100644 --- a/src/backend/distributed/executor/multi_utility.c +++ b/src/backend/distributed/executor/multi_utility.c @@ -401,6 +401,8 @@ VerifyTransmitStmt(CopyStmt *copyStatement) { char *fileName = NULL; + EnsureSuperUser(); + /* do some minimal option verification */ if (copyStatement->relation == NULL || copyStatement->relation->relname == NULL) diff --git a/src/backend/distributed/worker/worker_data_fetch_protocol.c b/src/backend/distributed/worker/worker_data_fetch_protocol.c index 55be5783b..562fcc564 100644 --- a/src/backend/distributed/worker/worker_data_fetch_protocol.c +++ b/src/backend/distributed/worker/worker_data_fetch_protocol.c @@ -26,6 +26,7 @@ #include "commands/extension.h" #include "distributed/citus_ruleutils.h" #include "distributed/master_protocol.h" +#include "distributed/metadata_cache.h" #include "distributed/multi_client_executor.h" #include "distributed/multi_logical_optimizer.h" #include "distributed/multi_server_executor.h" @@ -46,10 +47,12 @@ bool ExpireCachedShards = false; /* Local functions forward declarations */ -static void FetchRegularFile(const char *nodeName, uint32 nodePort, - StringInfo remoteFilename, StringInfo localFilename); +static void FetchRegularFileAsSuperUser(const char *nodeName, uint32 nodePort, + StringInfo remoteFilename, + StringInfo localFilename); static bool ReceiveRegularFile(const char *nodeName, uint32 nodePort, - StringInfo transmitCommand, StringInfo filePath); + const char *nodeUser, StringInfo transmitCommand, + StringInfo filePath); static void ReceiveResourceCleanup(int32 connectionId, const char *filename, int32 fileDescriptor); static void DeleteFile(const char *filename); @@ -115,7 +118,9 @@ worker_fetch_partition_file(PG_FUNCTION_ARGS) } nodeName = text_to_cstring(nodeNameText); - FetchRegularFile(nodeName, nodePort, remoteFilename, taskFilename); + + /* we've made sure the file names are sanitized, safe to fetch as superuser */ + FetchRegularFileAsSuperUser(nodeName, nodePort, remoteFilename, taskFilename); PG_RETURN_VOID(); } @@ -156,7 +161,9 @@ worker_fetch_query_results_file(PG_FUNCTION_ARGS) } nodeName = text_to_cstring(nodeNameText); - FetchRegularFile(nodeName, nodePort, remoteFilename, taskFilename); + + /* we've made sure the file names are sanitized, safe to fetch as superuser */ + FetchRegularFileAsSuperUser(nodeName, nodePort, remoteFilename, taskFilename); PG_RETURN_VOID(); } @@ -175,11 +182,16 @@ TaskFilename(StringInfo directoryName, uint32 taskId) } -/* Helper function to transfer the remote file in an idempotent manner. */ +/* + * FetchRegularFileAsSuperUser copies a file from a remote node in an idempotent + * manner. It connects to the remote node as superuser to give file access. + * Callers must make sure that the file names are sanitized. + */ static void -FetchRegularFile(const char *nodeName, uint32 nodePort, - StringInfo remoteFilename, StringInfo localFilename) +FetchRegularFileAsSuperUser(const char *nodeName, uint32 nodePort, + StringInfo remoteFilename, StringInfo localFilename) { + char *nodeUser = NULL; StringInfo attemptFilename = NULL; StringInfo transmitCommand = NULL; uint32 randomId = (uint32) random(); @@ -198,7 +210,11 @@ FetchRegularFile(const char *nodeName, uint32 nodePort, transmitCommand = makeStringInfo(); appendStringInfo(transmitCommand, TRANSMIT_REGULAR_COMMAND, remoteFilename->data); - received = ReceiveRegularFile(nodeName, nodePort, transmitCommand, attemptFilename); + /* connect as superuser to give file access */ + nodeUser = CitusExtensionOwnerName(); + + received = ReceiveRegularFile(nodeName, nodePort, nodeUser, transmitCommand, + attemptFilename); if (!received) { ereport(ERROR, (errmsg("could not receive file \"%s\" from %s:%u", @@ -225,7 +241,7 @@ FetchRegularFile(const char *nodeName, uint32 nodePort, * and returns false. */ static bool -ReceiveRegularFile(const char *nodeName, uint32 nodePort, +ReceiveRegularFile(const char *nodeName, uint32 nodePort, const char *nodeUser, StringInfo transmitCommand, StringInfo filePath) { int32 fileDescriptor = -1; @@ -257,7 +273,7 @@ ReceiveRegularFile(const char *nodeName, uint32 nodePort, nodeDatabase = get_database_name(MyDatabaseId); /* connect to remote node */ - connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase, NULL); + connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase, nodeUser); if (connectionId == INVALID_CONNECTION_ID) { ReceiveResourceCleanup(connectionId, filename, fileDescriptor); @@ -745,7 +761,8 @@ FetchRegularTable(const char *nodeName, uint32 nodePort, const char *tableName) remoteCopyCommand = makeStringInfo(); appendStringInfo(remoteCopyCommand, COPY_OUT_COMMAND, tableName); - received = ReceiveRegularFile(nodeName, nodePort, remoteCopyCommand, localFilePath); + received = ReceiveRegularFile(nodeName, nodePort, NULL, remoteCopyCommand, + localFilePath); if (!received) { return false; @@ -820,6 +837,7 @@ FetchRegularTable(const char *nodeName, uint32 nodePort, const char *tableName) static bool FetchForeignTable(const char *nodeName, uint32 nodePort, const char *tableName) { + const char *nodeUser = NULL; StringInfo localFilePath = NULL; StringInfo remoteFilePath = NULL; StringInfo transmitCommand = NULL; @@ -846,7 +864,16 @@ FetchForeignTable(const char *nodeName, uint32 nodePort, const char *tableName) transmitCommand = makeStringInfo(); appendStringInfo(transmitCommand, TRANSMIT_REGULAR_COMMAND, remoteFilePath->data); - received = ReceiveRegularFile(nodeName, nodePort, transmitCommand, localFilePath); + /* + * We allow some arbitrary input in the file name and connect to the remote + * node as superuser to transmit. Therefore, we only allow calling this + * function when already running as superuser. + */ + EnsureSuperUser(); + nodeUser = CitusExtensionOwnerName(); + + received = ReceiveRegularFile(nodeName, nodePort, nodeUser, transmitCommand, + localFilePath); if (!received) { return false; @@ -1183,7 +1210,7 @@ worker_append_table_to_shard(PG_FUNCTION_ARGS) sourceCopyCommand = makeStringInfo(); appendStringInfo(sourceCopyCommand, COPY_OUT_COMMAND, sourceQualifiedName); - received = ReceiveRegularFile(sourceNodeName, sourceNodePort, sourceCopyCommand, + received = ReceiveRegularFile(sourceNodeName, sourceNodePort, NULL, sourceCopyCommand, localFilePath); if (!received) { diff --git a/src/include/distributed/multi_client_executor.h b/src/include/distributed/multi_client_executor.h index e9658ee22..a414ed35f 100644 --- a/src/include/distributed/multi_client_executor.h +++ b/src/include/distributed/multi_client_executor.h @@ -100,7 +100,7 @@ typedef struct WaitInfo extern int32 MultiClientConnect(const char *nodeName, uint32 nodePort, const char *nodeDatabase, const char *nodeUser); extern int32 MultiClientConnectStart(const char *nodeName, uint32 nodePort, - const char *nodeDatabase); + const char *nodeDatabase, const char *nodeUser); extern ConnectStatus MultiClientConnectPoll(int32 connectionId); extern void MultiClientDisconnect(int32 connectionId); extern bool MultiClientConnectionUp(int32 connectionId); diff --git a/src/include/distributed/multi_server_executor.h b/src/include/distributed/multi_server_executor.h index 9067125bf..c2949f0e3 100644 --- a/src/include/distributed/multi_server_executor.h +++ b/src/include/distributed/multi_server_executor.h @@ -153,6 +153,7 @@ typedef struct TaskTracker { uint32 workerPort; /* node's port; part of hash table key */ char workerName[WORKER_LENGTH]; /* node's name; part of hash table key */ + char *userName; /* which user to connect as */ TrackerStatus trackerStatus; int32 connectionId; uint32 connectPollCount;