diff --git a/src/backend/distributed/executor/multi_client_executor.c b/src/backend/distributed/executor/multi_client_executor.c index c77115b9f..91d0e9506 100644 --- a/src/backend/distributed/executor/multi_client_executor.c +++ b/src/backend/distributed/executor/multi_client_executor.c @@ -127,7 +127,8 @@ MultiClientConnect(const char *nodeName, uint32 nodePort, const char *nodeDataba * error and returns INVALID_CONNECTION_ID. */ int32 -MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeDatabase) +MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeDatabase, + const char *userName) { MultiConnection *connection = NULL; ConnStatusType connStatusType = CONNECTION_OK; @@ -148,7 +149,8 @@ MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeD } /* prepare asynchronous request for worker node connection */ - connection = StartNodeConnection(connectionFlags, nodeName, nodePort); + connection = StartNodeUserDatabaseConnection(connectionFlags, nodeName, nodePort, + userName, nodeDatabase); connStatusType = PQstatus(connection->pgConn); /* diff --git a/src/backend/distributed/executor/multi_real_time_executor.c b/src/backend/distributed/executor/multi_real_time_executor.c index 30d18b20e..be8054d7b 100644 --- a/src/backend/distributed/executor/multi_real_time_executor.c +++ b/src/backend/distributed/executor/multi_real_time_executor.c @@ -273,7 +273,8 @@ ManageTaskExecution(Task *task, TaskExecution *taskExecution, /* we use the same database name on the master and worker nodes */ nodeDatabase = get_database_name(MyDatabaseId); - connectionId = MultiClientConnectStart(nodeName, nodePort, nodeDatabase); + connectionId = MultiClientConnectStart(nodeName, nodePort, nodeDatabase, + NULL); connectionIdArray[currentIndex] = connectionId; /* if valid, poll the connection until the connection is initiated */ diff --git a/src/backend/distributed/executor/multi_task_tracker_executor.c b/src/backend/distributed/executor/multi_task_tracker_executor.c index 5b890106c..2f4fd9bfc 100644 --- a/src/backend/distributed/executor/multi_task_tracker_executor.c +++ b/src/backend/distributed/executor/multi_task_tracker_executor.c @@ -72,7 +72,8 @@ static Task * TaskHashLookup(HTAB *trackerHash, TaskType taskType, uint64 jobId, uint32 taskId); static bool TopLevelTask(Task *task); static bool TransmitExecutionCompleted(TaskExecution *taskExecution); -static HTAB * TrackerHash(const char *taskTrackerHashName, List *workerNodeList); +static HTAB * TrackerHash(const char *taskTrackerHashName, List *workerNodeList, + char *userName); static HTAB * TrackerHashCreate(const char *taskTrackerHashName, uint32 taskTrackerHashSize); static TaskTracker * TrackerHashEnter(HTAB *taskTrackerHash, char *nodeName, @@ -161,6 +162,7 @@ MultiTaskTrackerExecute(Job *job) List *workerNodeList = NIL; HTAB *taskTrackerHash = NULL; HTAB *transmitTrackerHash = NULL; + char *extensionOwner = CitusExtensionOwnerName(); const char *taskTrackerHashName = "Task Tracker Hash"; const char *transmitTrackerHashName = "Transmit Tracker Hash"; List *jobIdList = NIL; @@ -202,8 +204,12 @@ MultiTaskTrackerExecute(Job *job) workerNodeList = ActivePrimaryNodeList(); taskTrackerCount = (uint32) list_length(workerNodeList); - taskTrackerHash = TrackerHash(taskTrackerHashName, workerNodeList); - transmitTrackerHash = TrackerHash(transmitTrackerHashName, workerNodeList); + /* connect as the current user for running queries */ + taskTrackerHash = TrackerHash(taskTrackerHashName, workerNodeList, NULL); + + /* connect as the superuser for fetching result files */ + transmitTrackerHash = TrackerHash(transmitTrackerHashName, workerNodeList, + extensionOwner); TrackerHashConnect(taskTrackerHash); TrackerHashConnect(transmitTrackerHash); @@ -667,10 +673,11 @@ TransmitExecutionCompleted(TaskExecution *taskExecution) /* * TrackerHash creates a task tracker hash with the given name. The function * then inserts one task tracker entry for each node in the given worker node - * list, and initializes state for each task tracker. + * list, and initializes state for each task tracker. The userName argument + * indicates which user to connect as. */ static HTAB * -TrackerHash(const char *taskTrackerHashName, List *workerNodeList) +TrackerHash(const char *taskTrackerHashName, List *workerNodeList, char *userName) { /* create task tracker hash */ uint32 taskTrackerHashSize = list_length(workerNodeList); @@ -712,6 +719,7 @@ TrackerHash(const char *taskTrackerHashName, List *workerNodeList) } taskTracker->taskStateHash = taskStateHash; + taskTracker->userName = userName; } return taskTrackerHash; @@ -845,9 +853,10 @@ TrackerConnectPoll(TaskTracker *taskTracker) char *nodeName = taskTracker->workerName; uint32 nodePort = taskTracker->workerPort; char *nodeDatabase = get_database_name(MyDatabaseId); + char *nodeUser = taskTracker->userName; int32 connectionId = MultiClientConnectStart(nodeName, nodePort, - nodeDatabase); + nodeDatabase, nodeUser); if (connectionId != INVALID_CONNECTION_ID) { taskTracker->connectionId = connectionId; diff --git a/src/backend/distributed/executor/multi_utility.c b/src/backend/distributed/executor/multi_utility.c index 3c0b5853f..8c2467afd 100644 --- a/src/backend/distributed/executor/multi_utility.c +++ b/src/backend/distributed/executor/multi_utility.c @@ -637,6 +637,8 @@ VerifyTransmitStmt(CopyStmt *copyStatement) { char *fileName = NULL; + EnsureSuperUser(); + /* do some minimal option verification */ if (copyStatement->relation == NULL || copyStatement->relation->relname == NULL) diff --git a/src/backend/distributed/worker/worker_data_fetch_protocol.c b/src/backend/distributed/worker/worker_data_fetch_protocol.c index 5d3453144..22f8484a9 100644 --- a/src/backend/distributed/worker/worker_data_fetch_protocol.c +++ b/src/backend/distributed/worker/worker_data_fetch_protocol.c @@ -57,10 +57,12 @@ bool ExpireCachedShards = false; /* Local functions forward declarations */ -static void FetchRegularFile(const char *nodeName, uint32 nodePort, - StringInfo remoteFilename, StringInfo localFilename); +static void FetchRegularFileAsSuperUser(const char *nodeName, uint32 nodePort, + StringInfo remoteFilename, + StringInfo localFilename); static bool ReceiveRegularFile(const char *nodeName, uint32 nodePort, - StringInfo transmitCommand, StringInfo filePath); + const char *nodeUser, StringInfo transmitCommand, + StringInfo filePath); static void ReceiveResourceCleanup(int32 connectionId, const char *filename, int32 fileDescriptor); static void DeleteFile(const char *filename); @@ -132,7 +134,9 @@ worker_fetch_partition_file(PG_FUNCTION_ARGS) } nodeName = text_to_cstring(nodeNameText); - FetchRegularFile(nodeName, nodePort, remoteFilename, taskFilename); + + /* we've made sure the file names are sanitized, safe to fetch as superuser */ + FetchRegularFileAsSuperUser(nodeName, nodePort, remoteFilename, taskFilename); PG_RETURN_VOID(); } @@ -176,7 +180,9 @@ worker_fetch_query_results_file(PG_FUNCTION_ARGS) } nodeName = text_to_cstring(nodeNameText); - FetchRegularFile(nodeName, nodePort, remoteFilename, taskFilename); + + /* we've made sure the file names are sanitized, safe to fetch as superuser */ + FetchRegularFileAsSuperUser(nodeName, nodePort, remoteFilename, taskFilename); PG_RETURN_VOID(); } @@ -195,11 +201,16 @@ TaskFilename(StringInfo directoryName, uint32 taskId) } -/* Helper function to transfer the remote file in an idempotent manner. */ +/* + * FetchRegularFileAsSuperUser copies a file from a remote node in an idempotent + * manner. It connects to the remote node as superuser to give file access. + * Callers must make sure that the file names are sanitized. + */ static void -FetchRegularFile(const char *nodeName, uint32 nodePort, - StringInfo remoteFilename, StringInfo localFilename) +FetchRegularFileAsSuperUser(const char *nodeName, uint32 nodePort, + StringInfo remoteFilename, StringInfo localFilename) { + char *nodeUser = NULL; StringInfo attemptFilename = NULL; StringInfo transmitCommand = NULL; uint32 randomId = (uint32) random(); @@ -218,7 +229,11 @@ FetchRegularFile(const char *nodeName, uint32 nodePort, transmitCommand = makeStringInfo(); appendStringInfo(transmitCommand, TRANSMIT_REGULAR_COMMAND, remoteFilename->data); - received = ReceiveRegularFile(nodeName, nodePort, transmitCommand, attemptFilename); + /* connect as superuser to give file access */ + nodeUser = CitusExtensionOwnerName(); + + received = ReceiveRegularFile(nodeName, nodePort, nodeUser, transmitCommand, + attemptFilename); if (!received) { ereport(ERROR, (errmsg("could not receive file \"%s\" from %s:%u", @@ -245,7 +260,7 @@ FetchRegularFile(const char *nodeName, uint32 nodePort, * and returns false. */ static bool -ReceiveRegularFile(const char *nodeName, uint32 nodePort, +ReceiveRegularFile(const char *nodeName, uint32 nodePort, const char *nodeUser, StringInfo transmitCommand, StringInfo filePath) { int32 fileDescriptor = -1; @@ -277,7 +292,7 @@ ReceiveRegularFile(const char *nodeName, uint32 nodePort, nodeDatabase = get_database_name(MyDatabaseId); /* connect to remote node */ - connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase, NULL); + connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase, nodeUser); if (connectionId == INVALID_CONNECTION_ID) { ReceiveResourceCleanup(connectionId, filename, fileDescriptor); @@ -822,7 +837,8 @@ FetchRegularTable(const char *nodeName, uint32 nodePort, const char *tableName) remoteCopyCommand = makeStringInfo(); appendStringInfo(remoteCopyCommand, COPY_OUT_COMMAND, tableName); - received = ReceiveRegularFile(nodeName, nodePort, remoteCopyCommand, localFilePath); + received = ReceiveRegularFile(nodeName, nodePort, NULL, remoteCopyCommand, + localFilePath); if (!received) { return false; @@ -897,6 +913,7 @@ FetchRegularTable(const char *nodeName, uint32 nodePort, const char *tableName) static bool FetchForeignTable(const char *nodeName, uint32 nodePort, const char *tableName) { + const char *nodeUser = NULL; StringInfo localFilePath = NULL; StringInfo remoteFilePath = NULL; StringInfo transmitCommand = NULL; @@ -923,7 +940,16 @@ FetchForeignTable(const char *nodeName, uint32 nodePort, const char *tableName) transmitCommand = makeStringInfo(); appendStringInfo(transmitCommand, TRANSMIT_REGULAR_COMMAND, remoteFilePath->data); - received = ReceiveRegularFile(nodeName, nodePort, transmitCommand, localFilePath); + /* + * We allow some arbitrary input in the file name and connect to the remote + * node as superuser to transmit. Therefore, we only allow calling this + * function when already running as superuser. + */ + EnsureSuperUser(); + nodeUser = CitusExtensionOwnerName(); + + received = ReceiveRegularFile(nodeName, nodePort, nodeUser, transmitCommand, + localFilePath); if (!received) { return false; @@ -1240,7 +1266,7 @@ worker_append_table_to_shard(PG_FUNCTION_ARGS) sourceCopyCommand = makeStringInfo(); appendStringInfo(sourceCopyCommand, COPY_OUT_COMMAND, sourceQualifiedName); - received = ReceiveRegularFile(sourceNodeName, sourceNodePort, sourceCopyCommand, + received = ReceiveRegularFile(sourceNodeName, sourceNodePort, NULL, sourceCopyCommand, localFilePath); if (!received) { diff --git a/src/include/distributed/multi_client_executor.h b/src/include/distributed/multi_client_executor.h index 56d889f4d..0b6561ea3 100644 --- a/src/include/distributed/multi_client_executor.h +++ b/src/include/distributed/multi_client_executor.h @@ -98,7 +98,7 @@ typedef struct WaitInfo extern int32 MultiClientConnect(const char *nodeName, uint32 nodePort, const char *nodeDatabase, const char *nodeUser); extern int32 MultiClientConnectStart(const char *nodeName, uint32 nodePort, - const char *nodeDatabase); + const char *nodeDatabase, const char *nodeUser); extern ConnectStatus MultiClientConnectPoll(int32 connectionId); extern void MultiClientDisconnect(int32 connectionId); extern bool MultiClientConnectionUp(int32 connectionId); diff --git a/src/include/distributed/multi_server_executor.h b/src/include/distributed/multi_server_executor.h index 5c97d2575..eb9c87e4e 100644 --- a/src/include/distributed/multi_server_executor.h +++ b/src/include/distributed/multi_server_executor.h @@ -156,6 +156,7 @@ typedef struct TaskTracker { uint32 workerPort; /* node's port; part of hash table key */ char workerName[WORKER_LENGTH]; /* node's name; part of hash table key */ + char *userName; /* which user to connect as */ TrackerStatus trackerStatus; int32 connectionId; uint32 connectPollCount;