Add user ID suffix to intermediate files in re-partition jobs

pull/2489/head
Marco Slot 2018-11-21 17:36:48 +01:00
parent f608739b4f
commit 6aa5592e52
10 changed files with 131 additions and 14 deletions

View File

@ -2609,6 +2609,7 @@ ProcessCopyStmt(CopyStmt *copyStatement, char *completionTag, const char *queryS
Query *query = NULL; Query *query = NULL;
Node *queryNode = copyStatement->query; Node *queryNode = copyStatement->query;
List *queryTreeList = NIL; List *queryTreeList = NIL;
StringInfo userFilePath = makeStringInfo();
#if (PG_VERSION_NUM >= 100000) #if (PG_VERSION_NUM >= 100000)
RawStmt *rawStmt = makeNode(RawStmt); RawStmt *rawStmt = makeNode(RawStmt);
@ -2625,6 +2626,14 @@ ProcessCopyStmt(CopyStmt *copyStatement, char *completionTag, const char *queryS
} }
query = (Query *) linitial(queryTreeList); query = (Query *) linitial(queryTreeList);
/*
* Add a user ID suffix to prevent other users from reading/writing
* the same file. We do this consistently in all functions that interact
* with task files.
*/
appendStringInfo(userFilePath, "%s.%u", filename, GetUserId());
tuplesSent = WorkerExecuteSqlTask(query, filename, binaryCopyFormat); tuplesSent = WorkerExecuteSqlTask(query, filename, binaryCopyFormat);
snprintf(completionTag, COMPLETION_TAG_BUFSIZE, snprintf(completionTag, COMPLETION_TAG_BUFSIZE,

View File

@ -359,6 +359,32 @@ IsTransmitStmt(Node *parsetree)
} }
/*
 * TransmitStatementUser walks the option list of a
 * COPY ... (format 'transmit', user '...') statement and returns the string
 * value of its "user" option, or NULL when no such option is present.
 */
char *
TransmitStatementUser(CopyStmt *copyStatement)
{
	char *transmitUser = NULL;
	ListCell *cell = NULL;

	AssertArg(IsTransmitStmt((Node *) copyStatement));

	foreach(cell, copyStatement->options)
	{
		DefElem *option = (DefElem *) lfirst(cell);

		if (strncmp(option->defname, "user", NAMEDATALEN) != 0)
		{
			continue;
		}

		/* no early exit: if the option is repeated, the last occurrence wins */
		transmitUser = defGetString(option);
	}

	return transmitUser;
}
/* /*
* VerifyTransmitStmt checks that the passed in command is a valid transmit * VerifyTransmitStmt checks that the passed in command is a valid transmit
* statement. Raise ERROR if not. * statement. Raise ERROR if not.

View File

@ -222,17 +222,28 @@ multi_ProcessUtility(PlannedStmt *pstmt,
if (IsTransmitStmt(parsetree)) if (IsTransmitStmt(parsetree))
{ {
CopyStmt *copyStatement = (CopyStmt *) parsetree; CopyStmt *copyStatement = (CopyStmt *) parsetree;
char *userName = TransmitStatementUser(copyStatement);
bool missingOK = false;
StringInfo transmitPath = makeStringInfo();
VerifyTransmitStmt(copyStatement); VerifyTransmitStmt(copyStatement);
/* ->relation->relname is the target file in our overloaded COPY */ /* ->relation->relname is the target file in our overloaded COPY */
appendStringInfoString(transmitPath, copyStatement->relation->relname);
if (userName != NULL)
{
Oid userId = get_role_oid(userName, missingOK);
appendStringInfo(transmitPath, ".%d", userId);
}
if (copyStatement->is_from) if (copyStatement->is_from)
{ {
RedirectCopyDataToRegularFile(copyStatement->relation->relname); RedirectCopyDataToRegularFile(transmitPath->data);
} }
else else
{ {
SendRegularFile(copyStatement->relation->relname); SendRegularFile(transmitPath->data);
} }
/* Don't execute the faux copy statement */ /* Don't execute the faux copy statement */

View File

@ -2590,9 +2590,11 @@ ManageTransmitTracker(TaskTracker *transmitTracker)
int32 connectionId = transmitTracker->connectionId; int32 connectionId = transmitTracker->connectionId;
StringInfo jobDirectoryName = JobDirectoryName(transmitState->jobId); StringInfo jobDirectoryName = JobDirectoryName(transmitState->jobId);
StringInfo taskFilename = TaskFilename(jobDirectoryName, transmitState->taskId); StringInfo taskFilename = TaskFilename(jobDirectoryName, transmitState->taskId);
char *userName = CurrentUserName();
StringInfo fileTransmitQuery = makeStringInfo(); StringInfo fileTransmitQuery = makeStringInfo();
appendStringInfo(fileTransmitQuery, TRANSMIT_REGULAR_COMMAND, taskFilename->data); appendStringInfo(fileTransmitQuery, TRANSMIT_WITH_USER_COMMAND,
taskFilename->data, quote_literal_cstr(userName));
fileTransmitStarted = MultiClientSendQuery(connectionId, fileTransmitQuery->data); fileTransmitStarted = MultiClientSendQuery(connectionId, fileTransmitQuery->data);
if (fileTransmitStarted) if (fileTransmitStarted)

View File

@ -108,7 +108,7 @@ worker_fetch_partition_file(PG_FUNCTION_ARGS)
/* local filename is <jobId>/<upstreamTaskId>/<partitionTaskId> */ /* local filename is <jobId>/<upstreamTaskId>/<partitionTaskId> */
StringInfo taskDirectoryName = TaskDirectoryName(jobId, upstreamTaskId); StringInfo taskDirectoryName = TaskDirectoryName(jobId, upstreamTaskId);
StringInfo taskFilename = TaskFilename(taskDirectoryName, partitionTaskId); StringInfo taskFilename = UserTaskFilename(taskDirectoryName, partitionTaskId);
/* /*
* If we are the first function to fetch a file for the upstream task, the * If we are the first function to fetch a file for the upstream task, the
@ -154,7 +154,7 @@ worker_fetch_query_results_file(PG_FUNCTION_ARGS)
/* local filename is <jobId>/<upstreamTaskId>/<queryTaskId> */ /* local filename is <jobId>/<upstreamTaskId>/<queryTaskId> */
StringInfo taskDirectoryName = TaskDirectoryName(jobId, upstreamTaskId); StringInfo taskDirectoryName = TaskDirectoryName(jobId, upstreamTaskId);
StringInfo taskFilename = TaskFilename(taskDirectoryName, queryTaskId); StringInfo taskFilename = UserTaskFilename(taskDirectoryName, queryTaskId);
/* /*
* If we are the first function to fetch a file for the upstream task, the * If we are the first function to fetch a file for the upstream task, the
@ -191,6 +191,21 @@ TaskFilename(StringInfo directoryName, uint32 taskId)
} }
/*
 * UserTaskFilename builds the task file path for the given directory and
 * task ID via TaskFilename, and then appends the current user's ID to it
 * as a ".<user ID>" suffix.
 */
StringInfo
UserTaskFilename(StringInfo directoryName, uint32 taskId)
{
	StringInfo userTaskPath = TaskFilename(directoryName, taskId);

	appendStringInfo(userTaskPath, ".%u", GetUserId());

	return userTaskPath;
}
/* /*
* FetchRegularFileAsSuperUser copies a file from a remote node in an idempotent * FetchRegularFileAsSuperUser copies a file from a remote node in an idempotent
* manner. It connects to the remote node as superuser to give file access. * manner. It connects to the remote node as superuser to give file access.
@ -203,6 +218,7 @@ FetchRegularFileAsSuperUser(const char *nodeName, uint32 nodePort,
char *nodeUser = NULL; char *nodeUser = NULL;
StringInfo attemptFilename = NULL; StringInfo attemptFilename = NULL;
StringInfo transmitCommand = NULL; StringInfo transmitCommand = NULL;
char *userName = CurrentUserName();
uint32 randomId = (uint32) random(); uint32 randomId = (uint32) random();
bool received = false; bool received = false;
int renamed = 0; int renamed = 0;
@ -217,7 +233,8 @@ FetchRegularFileAsSuperUser(const char *nodeName, uint32 nodePort,
MIN_TASK_FILENAME_WIDTH, randomId, ATTEMPT_FILE_SUFFIX); MIN_TASK_FILENAME_WIDTH, randomId, ATTEMPT_FILE_SUFFIX);
transmitCommand = makeStringInfo(); transmitCommand = makeStringInfo();
appendStringInfo(transmitCommand, TRANSMIT_REGULAR_COMMAND, remoteFilename->data); appendStringInfo(transmitCommand, TRANSMIT_WITH_USER_COMMAND, remoteFilename->data,
quote_literal_cstr(userName));
/* connect as superuser to give file access */ /* connect as superuser to give file access */
nodeUser = CitusExtensionOwnerName(); nodeUser = CitusExtensionOwnerName();

View File

@ -23,6 +23,7 @@
#include "catalog/pg_namespace.h" #include "catalog/pg_namespace.h"
#include "commands/copy.h" #include "commands/copy.h"
#include "commands/tablecmds.h" #include "commands/tablecmds.h"
#include "common/string.h"
#include "distributed/metadata_cache.h" #include "distributed/metadata_cache.h"
#include "distributed/worker_protocol.h" #include "distributed/worker_protocol.h"
#include "distributed/version_compat.h" #include "distributed/version_compat.h"
@ -42,7 +43,7 @@ static List * ArrayObjectToCStringList(ArrayType *arrayObject);
static void CreateTaskTable(StringInfo schemaName, StringInfo relationName, static void CreateTaskTable(StringInfo schemaName, StringInfo relationName,
List *columnNameList, List *columnTypeList); List *columnNameList, List *columnTypeList);
static void CopyTaskFilesFromDirectory(StringInfo schemaName, StringInfo relationName, static void CopyTaskFilesFromDirectory(StringInfo schemaName, StringInfo relationName,
StringInfo sourceDirectoryName); StringInfo sourceDirectoryName, Oid userId);
/* exports for SQL callable functions */ /* exports for SQL callable functions */
@ -78,6 +79,7 @@ worker_merge_files_into_table(PG_FUNCTION_ARGS)
List *columnTypeList = NIL; List *columnTypeList = NIL;
Oid savedUserId = InvalidOid; Oid savedUserId = InvalidOid;
int savedSecurityContext = 0; int savedSecurityContext = 0;
Oid userId = GetUserId();
/* we should have the same number of column names and types */ /* we should have the same number of column names and types */
int32 columnNameCount = ArrayObjectCount(columnNameObject); int32 columnNameCount = ArrayObjectCount(columnNameObject);
@ -112,7 +114,8 @@ worker_merge_files_into_table(PG_FUNCTION_ARGS)
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext); GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE); SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
CopyTaskFilesFromDirectory(jobSchemaName, taskTableName, taskDirectoryName); CopyTaskFilesFromDirectory(jobSchemaName, taskTableName, taskDirectoryName,
userId);
SetUserIdAndSecContext(savedUserId, savedSecurityContext); SetUserIdAndSecContext(savedUserId, savedSecurityContext);
@ -155,6 +158,7 @@ worker_merge_files_and_run_query(PG_FUNCTION_ARGS)
int createMergeTableResult = 0; int createMergeTableResult = 0;
int createIntermediateTableResult = 0; int createIntermediateTableResult = 0;
int finished = 0; int finished = 0;
Oid userId = GetUserId();
CheckCitusVersion(ERROR); CheckCitusVersion(ERROR);
@ -196,7 +200,8 @@ worker_merge_files_and_run_query(PG_FUNCTION_ARGS)
appendStringInfo(mergeTableName, "%s%s", intermediateTableName->data, appendStringInfo(mergeTableName, "%s%s", intermediateTableName->data,
MERGE_TABLE_SUFFIX); MERGE_TABLE_SUFFIX);
CopyTaskFilesFromDirectory(jobSchemaName, mergeTableName, taskDirectoryName); CopyTaskFilesFromDirectory(jobSchemaName, mergeTableName, taskDirectoryName,
userId);
createIntermediateTableResult = SPI_exec(createIntermediateTableQuery, 0); createIntermediateTableResult = SPI_exec(createIntermediateTableQuery, 0);
if (createIntermediateTableResult < 0) if (createIntermediateTableResult < 0)
@ -482,14 +487,20 @@ CreateStatement(RangeVar *relation, List *columnDefinitionList)
* CopyTaskFilesFromDirectory finds all files in the given directory, except for * CopyTaskFilesFromDirectory finds all files in the given directory, except for
* those having an attempt suffix. The function then copies these files into the * those having an attempt suffix. The function then copies these files into the
* database table identified by the given schema and table name. * database table identified by the given schema and table name.
*
* The function makes sure all files were generated by the current user by checking
* whether the filename ends with the user ID, since this is added to local file
* names by functions such as worker_fetch_partition_file. Files that were generated
* by other users calling worker_fetch_partition_file directly are skipped.
*/ */
static void static void
CopyTaskFilesFromDirectory(StringInfo schemaName, StringInfo relationName, CopyTaskFilesFromDirectory(StringInfo schemaName, StringInfo relationName,
StringInfo sourceDirectoryName) StringInfo sourceDirectoryName, Oid userId)
{ {
const char *directoryName = sourceDirectoryName->data; const char *directoryName = sourceDirectoryName->data;
struct dirent *directoryEntry = NULL; struct dirent *directoryEntry = NULL;
uint64 copiedRowTotal = 0; uint64 copiedRowTotal = 0;
StringInfo expectedFileSuffix = makeStringInfo();
DIR *directory = AllocateDir(directoryName); DIR *directory = AllocateDir(directoryName);
if (directory == NULL) if (directory == NULL)
@ -498,6 +509,8 @@ CopyTaskFilesFromDirectory(StringInfo schemaName, StringInfo relationName,
errmsg("could not open directory \"%s\": %m", directoryName))); errmsg("could not open directory \"%s\": %m", directoryName)));
} }
appendStringInfo(expectedFileSuffix, ".%u", userId);
directoryEntry = ReadDir(directory, directoryName); directoryEntry = ReadDir(directory, directoryName);
for (; directoryEntry != NULL; directoryEntry = ReadDir(directory, directoryName)) for (; directoryEntry != NULL; directoryEntry = ReadDir(directory, directoryName))
{ {
@ -516,6 +529,18 @@ CopyTaskFilesFromDirectory(StringInfo schemaName, StringInfo relationName,
continue; continue;
} }
if (!pg_str_endswith(baseFilename, expectedFileSuffix->data))
{
/*
* Someone is trying to tamper with our results. We don't throw an error
* here because we don't want to allow users to prevent each other from
* running queries.
*/
ereport(WARNING, (errmsg("Task file \"%s\" does not have expected suffix "
"\"%s\"", baseFilename, expectedFileSuffix->data)));
continue;
}
fullFilename = makeStringInfo(); fullFilename = makeStringInfo();
appendStringInfo(fullFilename, "%s/%s", directoryName, baseFilename); appendStringInfo(fullFilename, "%s/%s", directoryName, baseFilename);

View File

@ -16,6 +16,7 @@
#include "postgres.h" #include "postgres.h"
#include "funcapi.h" #include "funcapi.h"
#include "miscadmin.h"
#include "pgstat.h" #include "pgstat.h"
#include <arpa/inet.h> #include <arpa/inet.h>
@ -80,6 +81,7 @@ static void OutputBinaryFooters(FileOutputStream *partitionFileArray, uint32 fil
static uint32 RangePartitionId(Datum partitionValue, const void *context); static uint32 RangePartitionId(Datum partitionValue, const void *context);
static uint32 HashPartitionId(Datum partitionValue, const void *context); static uint32 HashPartitionId(Datum partitionValue, const void *context);
static uint32 HashPartitionIdViaDeprecatedAPI(Datum partitionValue, const void *context); static uint32 HashPartitionIdViaDeprecatedAPI(Datum partitionValue, const void *context);
static StringInfo UserPartitionFilename(StringInfo directoryName, uint32 partitionId);
static bool FileIsLink(char *filename, struct stat filestat); static bool FileIsLink(char *filename, struct stat filestat);
@ -509,7 +511,7 @@ OpenPartitionFiles(StringInfo directoryName, uint32 fileCount)
for (fileIndex = 0; fileIndex < fileCount; fileIndex++) for (fileIndex = 0; fileIndex < fileCount; fileIndex++)
{ {
StringInfo filePath = PartitionFilename(directoryName, fileIndex); StringInfo filePath = UserPartitionFilename(directoryName, fileIndex);
fileDescriptor = PathNameOpenFilePerm(filePath->data, fileFlags, fileMode); fileDescriptor = PathNameOpenFilePerm(filePath->data, fileFlags, fileMode);
if (fileDescriptor < 0) if (fileDescriptor < 0)
@ -606,7 +608,14 @@ TaskDirectoryName(uint64 jobId, uint32 taskId)
} }
/* Constructs a standardized partition file path for given directory and id. */ /*
* PartitionFilename returns a partition file path for given directory and id
* which is suitable for use in worker_fetch_partition_file and transmit.
*
* It excludes the user ID part at the end of the filename, since that is
* added by worker_fetch_partition_file itself based on the current user.
* For the full path use UserPartitionFilename.
*/
StringInfo StringInfo
PartitionFilename(StringInfo directoryName, uint32 partitionId) PartitionFilename(StringInfo directoryName, uint32 partitionId)
{ {
@ -619,6 +628,21 @@ PartitionFilename(StringInfo directoryName, uint32 partitionId)
} }
/*
 * UserPartitionFilename builds the partition file path for the given
 * directory and partition ID via PartitionFilename, and then appends the
 * current user's ID to it as a ".<user ID>" suffix.
 */
static StringInfo
UserPartitionFilename(StringInfo directoryName, uint32 partitionId)
{
	StringInfo userPartitionPath = PartitionFilename(directoryName, partitionId);

	appendStringInfo(userPartitionPath, ".%u", GetUserId());

	return userPartitionPath;
}
/* /*
* JobDirectoryElement takes in a filename, and checks if this name lives in the * JobDirectoryElement takes in a filename, and checks if this name lives in the
* directory path that is used for task output files. Note that this function's * directory path that is used for task output files. Note that this function's

View File

@ -82,7 +82,7 @@ worker_execute_sql_task(PG_FUNCTION_ARGS)
/* job directory is created prior to scheduling the task */ /* job directory is created prior to scheduling the task */
StringInfo jobDirectoryName = JobDirectoryName(jobId); StringInfo jobDirectoryName = JobDirectoryName(jobId);
StringInfo taskFilename = TaskFilename(jobDirectoryName, taskId); StringInfo taskFilename = UserTaskFilename(jobDirectoryName, taskId);
query = ParseQueryString(queryString); query = ParseQueryString(queryString);
tuplesSent = WorkerExecuteSqlTask(query, taskFilename->data, binaryCopyFormat); tuplesSent = WorkerExecuteSqlTask(query, taskFilename->data, binaryCopyFormat);

View File

@ -28,6 +28,7 @@ extern void FreeStringInfo(StringInfo stringInfo);
/* Local functions forward declarations for Transmit statement */ /* Local functions forward declarations for Transmit statement */
extern bool IsTransmitStmt(Node *parsetree); extern bool IsTransmitStmt(Node *parsetree);
extern char * TransmitStatementUser(CopyStmt *copyStatement);
extern void VerifyTransmitStmt(CopyStmt *copyStatement); extern void VerifyTransmitStmt(CopyStmt *copyStatement);

View File

@ -44,7 +44,8 @@
/* Defines used for fetching files and tables */ /* Defines used for fetching files and tables */
/* the tablename in the overloaded COPY statement is the to-be-transferred file */ /* the tablename in the overloaded COPY statement is the to-be-transferred file */
#define TRANSMIT_REGULAR_COMMAND "COPY \"%s\" TO STDOUT WITH (format 'transmit')" #define TRANSMIT_WITH_USER_COMMAND \
"COPY \"%s\" TO STDOUT WITH (format 'transmit', user %s)"
#define COPY_OUT_COMMAND "COPY %s TO STDOUT" #define COPY_OUT_COMMAND "COPY %s TO STDOUT"
#define COPY_SELECT_ALL_OUT_COMMAND "COPY (SELECT * FROM %s) TO STDOUT" #define COPY_SELECT_ALL_OUT_COMMAND "COPY (SELECT * FROM %s) TO STDOUT"
#define COPY_IN_COMMAND "COPY %s FROM '%s'" #define COPY_IN_COMMAND "COPY %s FROM '%s'"
@ -128,6 +129,7 @@ extern int64 WorkerExecuteSqlTask(Query *query, char *taskFilename,
/* Function declarations shared with the master planner */ /* Function declarations shared with the master planner */
extern StringInfo TaskFilename(StringInfo directoryName, uint32 taskId); extern StringInfo TaskFilename(StringInfo directoryName, uint32 taskId);
extern StringInfo UserTaskFilename(StringInfo directoryName, uint32 taskId);
extern List * ExecuteRemoteQuery(const char *nodeName, uint32 nodePort, char *runAsUser, extern List * ExecuteRemoteQuery(const char *nodeName, uint32 nodePort, char *runAsUser,
StringInfo queryString); StringInfo queryString);
extern List * ColumnDefinitionList(List *columnNameList, List *columnTypeList); extern List * ColumnDefinitionList(List *columnNameList, List *columnTypeList);