diff --git a/src/backend/distributed/commands/multi_copy.c b/src/backend/distributed/commands/multi_copy.c index 8887513cc..ee8cd5700 100644 --- a/src/backend/distributed/commands/multi_copy.c +++ b/src/backend/distributed/commands/multi_copy.c @@ -2609,6 +2609,7 @@ ProcessCopyStmt(CopyStmt *copyStatement, char *completionTag, const char *queryS Query *query = NULL; Node *queryNode = copyStatement->query; List *queryTreeList = NIL; + StringInfo userFilePath = makeStringInfo(); #if (PG_VERSION_NUM >= 100000) RawStmt *rawStmt = makeNode(RawStmt); @@ -2625,6 +2626,14 @@ ProcessCopyStmt(CopyStmt *copyStatement, char *completionTag, const char *queryS } query = (Query *) linitial(queryTreeList); + + /* + * Add a user ID suffix to prevent other users from reading/writing + * the same file. We do this consistently in all functions that interact + * with task files. + */ + appendStringInfo(userFilePath, "%s.%u", filename, GetUserId()); + tuplesSent = WorkerExecuteSqlTask(query, filename, binaryCopyFormat); snprintf(completionTag, COMPLETION_TAG_BUFSIZE, diff --git a/src/backend/distributed/commands/transmit.c b/src/backend/distributed/commands/transmit.c index 95c90bba9..fa976661c 100644 --- a/src/backend/distributed/commands/transmit.c +++ b/src/backend/distributed/commands/transmit.c @@ -359,6 +359,32 @@ IsTransmitStmt(Node *parsetree) } +/* + * TransmitStatementUser extracts the user attribute from a + * COPY ... (format 'transmit', user '...') statement. + */ +char * +TransmitStatementUser(CopyStmt *copyStatement) +{ + ListCell *optionCell = NULL; + char *userName = NULL; + + AssertArg(IsTransmitStmt((Node *) copyStatement)); + + foreach(optionCell, copyStatement->options) + { + DefElem *defel = (DefElem *) lfirst(optionCell); + + if (strncmp(defel->defname, "user", NAMEDATALEN) == 0) + { + userName = defGetString(defel); + } + } + + return userName; +} + + /* * VerifyTransmitStmt checks that the passed in command is a valid transmit * statement. Raise ERROR if not. diff --git a/src/backend/distributed/commands/utility_hook.c b/src/backend/distributed/commands/utility_hook.c index b247e4952..f7699a94e 100644 --- a/src/backend/distributed/commands/utility_hook.c +++ b/src/backend/distributed/commands/utility_hook.c @@ -222,17 +222,28 @@ multi_ProcessUtility(PlannedStmt *pstmt, if (IsTransmitStmt(parsetree)) { CopyStmt *copyStatement = (CopyStmt *) parsetree; + char *userName = TransmitStatementUser(copyStatement); + bool missingOK = false; + StringInfo transmitPath = makeStringInfo(); VerifyTransmitStmt(copyStatement); /* ->relation->relname is the target file in our overloaded COPY */ + appendStringInfoString(transmitPath, copyStatement->relation->relname); + + if (userName != NULL) + { + Oid userId = get_role_oid(userName, missingOK); + appendStringInfo(transmitPath, ".%d", userId); + } + if (copyStatement->is_from) { - RedirectCopyDataToRegularFile(copyStatement->relation->relname); + RedirectCopyDataToRegularFile(transmitPath->data); } else { - SendRegularFile(copyStatement->relation->relname); + SendRegularFile(transmitPath->data); } /* Don't execute the faux copy statement */ diff --git a/src/backend/distributed/executor/multi_task_tracker_executor.c b/src/backend/distributed/executor/multi_task_tracker_executor.c index 724359b57..ae4dcd682 100644 --- a/src/backend/distributed/executor/multi_task_tracker_executor.c +++ b/src/backend/distributed/executor/multi_task_tracker_executor.c @@ -2590,9 +2590,11 @@ ManageTransmitTracker(TaskTracker *transmitTracker) int32 connectionId = transmitTracker->connectionId; StringInfo jobDirectoryName = JobDirectoryName(transmitState->jobId); StringInfo taskFilename = TaskFilename(jobDirectoryName, transmitState->taskId); + char *userName = CurrentUserName(); StringInfo fileTransmitQuery = makeStringInfo(); - appendStringInfo(fileTransmitQuery, TRANSMIT_REGULAR_COMMAND, taskFilename->data); + appendStringInfo(fileTransmitQuery, TRANSMIT_WITH_USER_COMMAND, + taskFilename->data, quote_literal_cstr(userName)); fileTransmitStarted = MultiClientSendQuery(connectionId, fileTransmitQuery->data); if (fileTransmitStarted) diff --git a/src/backend/distributed/worker/worker_data_fetch_protocol.c b/src/backend/distributed/worker/worker_data_fetch_protocol.c index b90e19f9c..d0533158a 100644 --- a/src/backend/distributed/worker/worker_data_fetch_protocol.c +++ b/src/backend/distributed/worker/worker_data_fetch_protocol.c @@ -108,7 +108,7 @@ worker_fetch_partition_file(PG_FUNCTION_ARGS) /* local filename is // */ StringInfo taskDirectoryName = TaskDirectoryName(jobId, upstreamTaskId); - StringInfo taskFilename = TaskFilename(taskDirectoryName, partitionTaskId); + StringInfo taskFilename = UserTaskFilename(taskDirectoryName, partitionTaskId); /* * If we are the first function to fetch a file for the upstream task, the @@ -154,7 +154,7 @@ worker_fetch_query_results_file(PG_FUNCTION_ARGS) /* local filename is // */ StringInfo taskDirectoryName = TaskDirectoryName(jobId, upstreamTaskId); - StringInfo taskFilename = TaskFilename(taskDirectoryName, queryTaskId); + StringInfo taskFilename = UserTaskFilename(taskDirectoryName, queryTaskId); /* * If we are the first function to fetch a file for the upstream task, the @@ -191,6 +191,21 @@ TaskFilename(StringInfo directoryName, uint32 taskId) } +/* + * UserTaskFilename returns a full file path for a task file including the + * current user ID as a suffix. + */ +StringInfo +UserTaskFilename(StringInfo directoryName, uint32 taskId) +{ + StringInfo taskFilename = TaskFilename(directoryName, taskId); + + appendStringInfo(taskFilename, ".%u", GetUserId()); + + return taskFilename; +} + + /* * FetchRegularFileAsSuperUser copies a file from a remote node in an idempotent * manner. It connects to the remote node as superuser to give file access. @@ -203,6 +218,7 @@ FetchRegularFileAsSuperUser(const char *nodeName, uint32 nodePort, char *nodeUser = NULL; StringInfo attemptFilename = NULL; StringInfo transmitCommand = NULL; + char *userName = CurrentUserName(); uint32 randomId = (uint32) random(); bool received = false; int renamed = 0; @@ -217,7 +233,8 @@ FetchRegularFileAsSuperUser(const char *nodeName, uint32 nodePort, MIN_TASK_FILENAME_WIDTH, randomId, ATTEMPT_FILE_SUFFIX); transmitCommand = makeStringInfo(); - appendStringInfo(transmitCommand, TRANSMIT_REGULAR_COMMAND, remoteFilename->data); + appendStringInfo(transmitCommand, TRANSMIT_WITH_USER_COMMAND, remoteFilename->data, + quote_literal_cstr(userName)); /* connect as superuser to give file access */ nodeUser = CitusExtensionOwnerName(); diff --git a/src/backend/distributed/worker/worker_merge_protocol.c b/src/backend/distributed/worker/worker_merge_protocol.c index f95d19d16..74331f984 100644 --- a/src/backend/distributed/worker/worker_merge_protocol.c +++ b/src/backend/distributed/worker/worker_merge_protocol.c @@ -23,6 +23,7 @@ #include "catalog/pg_namespace.h" #include "commands/copy.h" #include "commands/tablecmds.h" +#include "common/string.h" #include "distributed/metadata_cache.h" #include "distributed/worker_protocol.h" #include "distributed/version_compat.h" @@ -42,7 +43,7 @@ static List * ArrayObjectToCStringList(ArrayType *arrayObject); static void CreateTaskTable(StringInfo schemaName, StringInfo relationName, List *columnNameList, List *columnTypeList); static void CopyTaskFilesFromDirectory(StringInfo schemaName, StringInfo relationName, - StringInfo sourceDirectoryName); + StringInfo sourceDirectoryName, Oid userId); /* exports for SQL callable functions */ @@ -78,6 +79,7 @@ worker_merge_files_into_table(PG_FUNCTION_ARGS) List *columnTypeList = NIL; Oid savedUserId = InvalidOid; int savedSecurityContext = 0; + Oid userId = GetUserId(); /* we should have the same number of column names and types */ int32 columnNameCount = ArrayObjectCount(columnNameObject); @@ -112,7 +114,8 @@ worker_merge_files_into_table(PG_FUNCTION_ARGS) GetUserIdAndSecContext(&savedUserId, &savedSecurityContext); SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE); - CopyTaskFilesFromDirectory(jobSchemaName, taskTableName, taskDirectoryName); + CopyTaskFilesFromDirectory(jobSchemaName, taskTableName, taskDirectoryName, + userId); SetUserIdAndSecContext(savedUserId, savedSecurityContext); @@ -155,6 +158,7 @@ worker_merge_files_and_run_query(PG_FUNCTION_ARGS) int createMergeTableResult = 0; int createIntermediateTableResult = 0; int finished = 0; + Oid userId = GetUserId(); CheckCitusVersion(ERROR); @@ -196,7 +200,8 @@ worker_merge_files_and_run_query(PG_FUNCTION_ARGS) appendStringInfo(mergeTableName, "%s%s", intermediateTableName->data, MERGE_TABLE_SUFFIX); - CopyTaskFilesFromDirectory(jobSchemaName, mergeTableName, taskDirectoryName); + CopyTaskFilesFromDirectory(jobSchemaName, mergeTableName, taskDirectoryName, + userId); createIntermediateTableResult = SPI_exec(createIntermediateTableQuery, 0); if (createIntermediateTableResult < 0) @@ -482,14 +487,20 @@ CreateStatement(RangeVar *relation, List *columnDefinitionList) * CopyTaskFilesFromDirectory finds all files in the given directory, except for * those having an attempt suffix. The function then copies these files into the * database table identified by the given schema and table name. + * + * The function makes sure all files were generated by the current user by checking + * whether the filename ends with the username, since this is added to local file + * names by functions such as worker_fetch_partition-file. Files that were generated + * by other users calling worker_fetch_partition_file directly are skipped. */ static void CopyTaskFilesFromDirectory(StringInfo schemaName, StringInfo relationName, - StringInfo sourceDirectoryName) + StringInfo sourceDirectoryName, Oid userId) { const char *directoryName = sourceDirectoryName->data; struct dirent *directoryEntry = NULL; uint64 copiedRowTotal = 0; + StringInfo expectedFileSuffix = makeStringInfo(); DIR *directory = AllocateDir(directoryName); if (directory == NULL) @@ -498,6 +509,8 @@ CopyTaskFilesFromDirectory(StringInfo schemaName, StringInfo relationName, errmsg("could not open directory \"%s\": %m", directoryName))); } + appendStringInfo(expectedFileSuffix, ".%u", userId); + directoryEntry = ReadDir(directory, directoryName); for (; directoryEntry != NULL; directoryEntry = ReadDir(directory, directoryName)) { @@ -516,6 +529,18 @@ CopyTaskFilesFromDirectory(StringInfo schemaName, StringInfo relationName, continue; } + if (!pg_str_endswith(baseFilename, expectedFileSuffix->data)) + { + /* + * Someone is trying to tamper with our results. We don't throw an error + * here because we don't want to allow users to prevent each other from + * running queries. + */ + ereport(WARNING, (errmsg("Task file \"%s\" does not have expected suffix " + "\"%s\"", baseFilename, expectedFileSuffix->data))); + continue; + } + fullFilename = makeStringInfo(); appendStringInfo(fullFilename, "%s/%s", directoryName, baseFilename); diff --git a/src/backend/distributed/worker/worker_partition_protocol.c b/src/backend/distributed/worker/worker_partition_protocol.c index 386e7fd28..d5ca9b44f 100644 --- a/src/backend/distributed/worker/worker_partition_protocol.c +++ b/src/backend/distributed/worker/worker_partition_protocol.c @@ -16,6 +16,7 @@ #include "postgres.h" #include "funcapi.h" +#include "miscadmin.h" #include "pgstat.h" #include @@ -80,6 +81,7 @@ static void OutputBinaryFooters(FileOutputStream *partitionFileArray, uint32 fil static uint32 RangePartitionId(Datum partitionValue, const void *context); static uint32 HashPartitionId(Datum partitionValue, const void *context); static uint32 HashPartitionIdViaDeprecatedAPI(Datum partitionValue, const void *context); +static StringInfo UserPartitionFilename(StringInfo directoryName, uint32 partitionId); static bool FileIsLink(char *filename, struct stat filestat); @@ -509,7 +511,7 @@ OpenPartitionFiles(StringInfo directoryName, uint32 fileCount) for (fileIndex = 0; fileIndex < fileCount; fileIndex++) { - StringInfo filePath = PartitionFilename(directoryName, fileIndex); + StringInfo filePath = UserPartitionFilename(directoryName, fileIndex); fileDescriptor = PathNameOpenFilePerm(filePath->data, fileFlags, fileMode); if (fileDescriptor < 0) @@ -606,7 +608,14 @@ TaskDirectoryName(uint64 jobId, uint32 taskId) } -/* Constructs a standardized partition file path for given directory and id. */ +/* + * PartitionFilename returns a partition file path for given directory and id + * which is suitable for use in worker_fetch_partition_file and tranmsit. + * + * It excludes the user ID part at the end of the filename, since that is + * added by worker_fetch_partition_file itself based on the current user. + * For the full path use UserPartitionFilename. + */ StringInfo PartitionFilename(StringInfo directoryName, uint32 partitionId) { @@ -619,6 +628,21 @@ PartitionFilename(StringInfo directoryName, uint32 partitionId) } +/* + * UserPartitionFilename returns the path of a partition file for the given + * partition ID and the current user. + */ +static StringInfo +UserPartitionFilename(StringInfo directoryName, uint32 partitionId) +{ + StringInfo partitionFilename = PartitionFilename(directoryName, partitionId); + + appendStringInfo(partitionFilename, ".%u", GetUserId()); + + return partitionFilename; +} + + /* * JobDirectoryElement takes in a filename, and checks if this name lives in the * directory path that is used for task output files. Note that this function's diff --git a/src/backend/distributed/worker/worker_sql_task_protocol.c b/src/backend/distributed/worker/worker_sql_task_protocol.c index 1648a0c2c..c9246038c 100644 --- a/src/backend/distributed/worker/worker_sql_task_protocol.c +++ b/src/backend/distributed/worker/worker_sql_task_protocol.c @@ -82,7 +82,7 @@ worker_execute_sql_task(PG_FUNCTION_ARGS) /* job directory is created prior to scheduling the task */ StringInfo jobDirectoryName = JobDirectoryName(jobId); - StringInfo taskFilename = TaskFilename(jobDirectoryName, taskId); + StringInfo taskFilename = UserTaskFilename(jobDirectoryName, taskId); query = ParseQueryString(queryString); tuplesSent = WorkerExecuteSqlTask(query, taskFilename->data, binaryCopyFormat); diff --git a/src/include/distributed/transmit.h b/src/include/distributed/transmit.h index 28ed0dad3..5aadafec6 100644 --- a/src/include/distributed/transmit.h +++ b/src/include/distributed/transmit.h @@ -28,6 +28,7 @@ extern void FreeStringInfo(StringInfo stringInfo); /* Local functions forward declarations for Transmit statement */ extern bool IsTransmitStmt(Node *parsetree); +extern char * TransmitStatementUser(CopyStmt *copyStatement); extern void VerifyTransmitStmt(CopyStmt *copyStatement); diff --git a/src/include/distributed/worker_protocol.h b/src/include/distributed/worker_protocol.h index b92119b7f..5dca935cb 100644 --- a/src/include/distributed/worker_protocol.h +++ b/src/include/distributed/worker_protocol.h @@ -44,7 +44,8 @@ /* Defines used for fetching files and tables */ /* the tablename in the overloaded COPY statement is the to-be-transferred file */ -#define TRANSMIT_REGULAR_COMMAND "COPY \"%s\" TO STDOUT WITH (format 'transmit')" +#define TRANSMIT_WITH_USER_COMMAND \ + "COPY \"%s\" TO STDOUT WITH (format 'transmit', user %s)" #define COPY_OUT_COMMAND "COPY %s TO STDOUT" #define COPY_SELECT_ALL_OUT_COMMAND "COPY (SELECT * FROM %s) TO STDOUT" #define COPY_IN_COMMAND "COPY %s FROM '%s'" @@ -128,6 +129,7 @@ extern int64 WorkerExecuteSqlTask(Query *query, char *taskFilename, /* Function declarations shared with the master planner */ extern StringInfo TaskFilename(StringInfo directoryName, uint32 taskId); +extern StringInfo UserTaskFilename(StringInfo directoryName, uint32 taskId); extern List * ExecuteRemoteQuery(const char *nodeName, uint32 nodePort, char *runAsUser, StringInfo queryString); extern List * ColumnDefinitionList(List *columnNameList, List *columnTypeList);