From caf402d506318ac9c28fee527d10055c205a4aac Mon Sep 17 00:00:00 2001 From: Marco Slot Date: Thu, 22 Nov 2018 01:26:31 +0100 Subject: [PATCH] COPY to a task file no longer switches to superuser --- Makefile | 1 + src/backend/distributed/commands/multi_copy.c | 95 +++--- .../distributed/commands/utility_hook.c | 18 +- .../distributed/executor/multi_executor.c | 14 +- .../worker/worker_data_fetch_protocol.c | 12 + .../worker/worker_sql_task_protocol.c | 280 ++++++++++++++++++ src/include/distributed/commands/multi_copy.h | 3 +- src/include/distributed/multi_executor.h | 1 + src/include/distributed/worker_protocol.h | 2 + 9 files changed, 361 insertions(+), 65 deletions(-) create mode 100644 src/backend/distributed/worker/worker_sql_task_protocol.c diff --git a/Makefile b/Makefile index b3bcb59f9..e0a2b108d 100644 --- a/Makefile +++ b/Makefile @@ -128,6 +128,7 @@ OBJS = src/backend/distributed/shared_library_init.o \ src/backend/distributed/worker/worker_file_access_protocol.o \ src/backend/distributed/worker/worker_merge_protocol.o \ src/backend/distributed/worker/worker_partition_protocol.o \ + src/backend/distributed/worker/worker_sql_task_protocol.o \ src/backend/distributed/worker/worker_truncate_trigger_protocol.o \ $(WIN32RES) diff --git a/src/backend/distributed/commands/multi_copy.c b/src/backend/distributed/commands/multi_copy.c index e50441f7d..8887513cc 100644 --- a/src/backend/distributed/commands/multi_copy.c +++ b/src/backend/distributed/commands/multi_copy.c @@ -140,9 +140,9 @@ static FmgrInfo * TypeOutputFunctions(uint32 columnCount, Oid *typeIdArray, bool binaryFormat); static Datum CoerceColumnValue(Datum inputValue, CopyCoercionData *coercionPath); static void CreateLocalTable(RangeVar *relation, char *nodeName, int32 nodePort); -static void CheckCopyPermissions(CopyStmt *copyStatement); static List * CopyGetAttnums(TupleDesc tupDesc, Relation rel, List *attnamelist); static bool IsCopyResultStmt(CopyStmt *copyStatement); +static bool CopyStatementHasFormat(CopyStmt *copyStatement, char *formatName); static bool IsCopyFromWorker(CopyStmt *copyStatement); static NodeAddress * MasterNodeAddress(CopyStmt *copyStatement); static void CitusCopyFrom(CopyStmt *copyStatement, char *completionTag); @@ -2445,9 +2445,20 @@ CitusCopyDestReceiverDestroy(DestReceiver *destReceiver) */ static bool IsCopyResultStmt(CopyStmt *copyStatement) +{ + return CopyStatementHasFormat(copyStatement, "result"); +} + + +/* + * CopyStatementHasFormat checks whether the COPY statement has the given + * format. + */ +static bool +CopyStatementHasFormat(CopyStmt *copyStatement, char *formatName) { ListCell *optionCell = NULL; - bool hasFormatReceive = false; + bool hasFormat = false; /* extract WITH (...) options from the COPY statement */ foreach(optionCell, copyStatement->options) @@ -2455,14 +2466,14 @@ IsCopyResultStmt(CopyStmt *copyStatement) DefElem *defel = (DefElem *) lfirst(optionCell); if (strncmp(defel->defname, "format", NAMEDATALEN) == 0 && - strncmp(defGetString(defel), "result", NAMEDATALEN) == 0) + strncmp(defGetString(defel), formatName, NAMEDATALEN) == 0) { - hasFormatReceive = true; + hasFormat = true; break; } } - return hasFormatReceive; + return hasFormat; } @@ -2471,18 +2482,10 @@ IsCopyResultStmt(CopyStmt *copyStatement) * COPYing from distributed tables and preventing unsupported actions. The * function returns a modified COPY statement to be executed, or NULL if no * further processing is needed. - * - * commandMustRunAsOwner is an output parameter used to communicate to the caller whether - * the copy statement should be executed using elevated privileges. If - * ProcessCopyStmt that is required, a call to CheckCopyPermissions will take - * care of verifying the current user's permissions before ProcessCopyStmt - * returns. */ Node * -ProcessCopyStmt(CopyStmt *copyStatement, char *completionTag, bool *commandMustRunAsOwner) +ProcessCopyStmt(CopyStmt *copyStatement, char *completionTag, const char *queryString) { - *commandMustRunAsOwner = false; /* make sure variable is initialized */ - /* * Handle special COPY "resultid" FROM STDIN WITH (format result) commands * for sending intermediate results to workers. @@ -2591,43 +2594,43 @@ ProcessCopyStmt(CopyStmt *copyStatement, char *completionTag, bool *commandMustR if (copyStatement->filename != NULL && !copyStatement->is_program) { - const char *filename = copyStatement->filename; + char *filename = copyStatement->filename; - if (CacheDirectoryElement(filename)) + /* + * We execute COPY commands issued by the task-tracker executor here + * because we're not normally allowed to write to a file as a regular + * user and we don't want to execute the query as superuser. + */ + if (CacheDirectoryElement(filename) && copyStatement->query != NULL && + !copyStatement->is_from && !is_absolute_path(filename)) { - /* - * Only superusers are allowed to copy from a file, so we have to - * become superuser to execute copies to/from files used by citus' - * query execution. - * - * XXX: This is a decidedly suboptimal solution, as that means - * that triggers, input functions, etc. run with elevated - * privileges. But this is better than not being able to run - * queries as normal user. - */ - *commandMustRunAsOwner = true; + bool binaryCopyFormat = CopyStatementHasFormat(copyStatement, "binary"); + int64 tuplesSent = 0; + Query *query = NULL; + Node *queryNode = copyStatement->query; + List *queryTreeList = NIL; - /* - * Have to manually check permissions here as the COPY is will be - * run as a superuser. - */ - if (copyStatement->relation != NULL) +#if (PG_VERSION_NUM >= 100000) + RawStmt *rawStmt = makeNode(RawStmt); + rawStmt->stmt = queryNode; + + queryTreeList = pg_analyze_and_rewrite(rawStmt, queryString, NULL, 0, NULL); +#else + queryTreeList = pg_analyze_and_rewrite(queryNode, queryString, NULL, 0); +#endif + + if (list_length(queryTreeList) != 1) { - CheckCopyPermissions(copyStatement); + ereport(ERROR, (errmsg("can only execute a single query"))); } - /* - * Check if we have a "COPY (query) TO filename". If we do, copy - * doesn't accept relative file paths. However, SQL tasks that get - * assigned to worker nodes have relative paths. We therefore - * convert relative paths to absolute ones here. - */ - if (copyStatement->relation == NULL && - !copyStatement->is_from && - !is_absolute_path(filename)) - { - copyStatement->filename = make_absolute_path(filename); - } + query = (Query *) linitial(queryTreeList); + tuplesSent = WorkerExecuteSqlTask(query, filename, binaryCopyFormat); + + snprintf(completionTag, COMPLETION_TAG_BUFSIZE, + "COPY " UINT64_FORMAT, tuplesSent); + + return NULL; } } @@ -2724,7 +2727,7 @@ CreateLocalTable(RangeVar *relation, char *nodeName, int32 nodePort) * * Copied from postgres, where it's part of DoCopy(). */ -static void +void CheckCopyPermissions(CopyStmt *copyStatement) { /* *INDENT-OFF* */ diff --git a/src/backend/distributed/commands/utility_hook.c b/src/backend/distributed/commands/utility_hook.c index b0d52383a..b247e4952 100644 --- a/src/backend/distributed/commands/utility_hook.c +++ b/src/backend/distributed/commands/utility_hook.c @@ -128,9 +128,6 @@ multi_ProcessUtility(PlannedStmt *pstmt, char *completionTag) { Node *parsetree = pstmt->utilityStmt; - bool commandMustRunAsOwner = false; - Oid savedUserId = InvalidOid; - int savedSecurityContext = 0; List *ddlJobs = NIL; bool checkExtensionVersion = false; @@ -248,8 +245,7 @@ multi_ProcessUtility(PlannedStmt *pstmt, MemoryContext previousContext; parsetree = copyObject(parsetree); - parsetree = ProcessCopyStmt((CopyStmt *) parsetree, completionTag, - &commandMustRunAsOwner); + parsetree = ProcessCopyStmt((CopyStmt *) parsetree, completionTag, queryString); previousContext = MemoryContextSwitchTo(planContext); parsetree = copyObject(parsetree); @@ -450,13 +446,6 @@ multi_ProcessUtility(PlannedStmt *pstmt, } } - /* set user if needed and go ahead and run local utility using standard hook */ - if (commandMustRunAsOwner) - { - GetUserIdAndSecContext(&savedUserId, &savedSecurityContext); - SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE); - } - #if (PG_VERSION_NUM >= 100000) pstmt->utilityStmt = parsetree; standard_ProcessUtility(pstmt, queryString, context, @@ -495,11 +484,6 @@ multi_ProcessUtility(PlannedStmt *pstmt, PostProcessUtility(parsetree); } - if (commandMustRunAsOwner) - { - SetUserIdAndSecContext(savedUserId, savedSecurityContext); - } - /* * Re-forming the foreign key graph relies on the command being executed * on the local table first. However, in order to decide whether the diff --git a/src/backend/distributed/executor/multi_executor.c b/src/backend/distributed/executor/multi_executor.c index 73d2d88e3..172cf4598 100644 --- a/src/backend/distributed/executor/multi_executor.c +++ b/src/backend/distributed/executor/multi_executor.c @@ -337,6 +337,18 @@ StubRelation(TupleDesc tupleDescriptor) void ExecuteQueryStringIntoDestReceiver(const char *queryString, ParamListInfo params, DestReceiver *dest) +{ + Query *query = ParseQueryString(queryString); + + ExecuteQueryIntoDestReceiver(query, params, dest); +} + + +/* + * ParseQuery parses query string and returns a Query struct. + */ +Query * +ParseQueryString(const char *queryString) { Query *query = NULL; @@ -355,7 +367,7 @@ ExecuteQueryStringIntoDestReceiver(const char *queryString, ParamListInfo params query = (Query *) linitial(queryTreeList); - ExecuteQueryIntoDestReceiver(query, params, dest); + return query; } diff --git a/src/backend/distributed/worker/worker_data_fetch_protocol.c b/src/backend/distributed/worker/worker_data_fetch_protocol.c index 2ebc56c3e..b90e19f9c 100644 --- a/src/backend/distributed/worker/worker_data_fetch_protocol.c +++ b/src/backend/distributed/worker/worker_data_fetch_protocol.c @@ -32,6 +32,7 @@ #include "distributed/master_protocol.h" #include "distributed/metadata_cache.h" #include "distributed/multi_client_executor.h" +#include "distributed/commands/multi_copy.h" #include "distributed/multi_logical_optimizer.h" #include "distributed/multi_partitioning_utils.h" #include "distributed/multi_server_executor.h" @@ -766,6 +767,8 @@ worker_append_table_to_shard(PG_FUNCTION_ARGS) StringInfo queryString = NULL; Oid sourceShardRelationId = InvalidOid; Oid sourceSchemaId = InvalidOid; + Oid savedUserId = InvalidOid; + int savedSecurityContext = 0; CheckCitusVersion(ERROR); @@ -829,9 +832,18 @@ worker_append_table_to_shard(PG_FUNCTION_ARGS) appendStringInfo(queryString, COPY_IN_COMMAND, shardQualifiedName, localFilePath->data); + /* make sure we are allowed to execute the COPY command */ + CheckCopyPermissions(localCopyCommand); + + /* need superuser to copy from files */ + GetUserIdAndSecContext(&savedUserId, &savedSecurityContext); + SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE); + CitusProcessUtility((Node *) localCopyCommand, queryString->data, PROCESS_UTILITY_TOPLEVEL, NULL, None_Receiver, NULL); + SetUserIdAndSecContext(savedUserId, savedSecurityContext); + /* finally delete the temporary file we created */ CitusDeleteFile(localFilePath->data); diff --git a/src/backend/distributed/worker/worker_sql_task_protocol.c b/src/backend/distributed/worker/worker_sql_task_protocol.c new file mode 100644 index 000000000..dfbeca81f --- /dev/null +++ b/src/backend/distributed/worker/worker_sql_task_protocol.c @@ -0,0 +1,280 @@ +/*------------------------------------------------------------------------- + * + * worker_sql_task_protocol.c + * + * Routines for executing SQL tasks during task-tracker execution. + * + * Copyright (c) 2012-2018, Citus Data, Inc. + * + *------------------------------------------------------------------------- + */ + +#include "postgres.h" +#include "funcapi.h" +#include "pgstat.h" + +#include "distributed/commands/multi_copy.h" +#include "distributed/multi_executor.h" +#include "distributed/transmit.h" +#include "distributed/worker_protocol.h" +#include "utils/builtins.h" +#include "utils/memutils.h" + + +/* TaskFileDestReceiver can be used to stream results into a file */ +typedef struct TaskFileDestReceiver +{ + /* public DestReceiver interface */ + DestReceiver pub; + + /* descriptor of the tuples that are sent to the worker */ + TupleDesc tupleDescriptor; + + /* EState for per-tuple memory allocation */ + EState *executorState; + + /* MemoryContext for DestReceiver session */ + MemoryContext memoryContext; + + /* output file */ + char *filePath; + File fileDesc; + bool binaryCopyFormat; + + /* state on how to copy out data types */ + CopyOutState copyOutState; + FmgrInfo *columnOutputFunctions; + + /* number of tuples sent */ + uint64 tuplesSent; +} TaskFileDestReceiver; + + +static DestReceiver * CreateTaskFileDestReceiver(char *filePath, EState *executorState, + bool binaryCopyFormat); +static void TaskFileDestReceiverStartup(DestReceiver *dest, int operation, + TupleDesc inputTupleDescriptor); +static bool TaskFileDestReceiverReceive(TupleTableSlot *slot, DestReceiver *dest); +static void WriteToLocalFile(StringInfo copyData, File fileDesc); +static void TaskFileDestReceiverShutdown(DestReceiver *destReceiver); +static void TaskFileDestReceiverDestroy(DestReceiver *destReceiver); + + +/* + * WorkerExecuteSqlTask executes an already-parsed query and writes the result + * to the given task file. + */ +int64 +WorkerExecuteSqlTask(Query *query, char *taskFilename, bool binaryCopyFormat) +{ + EState *estate = NULL; + TaskFileDestReceiver *taskFileDest = NULL; + ParamListInfo paramListInfo = NULL; + int64 tuplesSent = 0L; + + estate = CreateExecutorState(); + taskFileDest = + (TaskFileDestReceiver *) CreateTaskFileDestReceiver(taskFilename, estate, + binaryCopyFormat); + + ExecuteQueryIntoDestReceiver(query, paramListInfo, (DestReceiver *) taskFileDest); + + tuplesSent = taskFileDest->tuplesSent; + + taskFileDest->pub.rDestroy((DestReceiver *) taskFileDest); + FreeExecutorState(estate); + + return tuplesSent; +} + + +/* + * CreateTaskFileDestReceiver creates a DestReceiver for writing query results + * to a task file. + */ +static DestReceiver * +CreateTaskFileDestReceiver(char *filePath, EState *executorState, bool binaryCopyFormat) +{ + TaskFileDestReceiver *taskFileDest = NULL; + + taskFileDest = (TaskFileDestReceiver *) palloc0(sizeof(TaskFileDestReceiver)); + + /* set up the DestReceiver function pointers */ + taskFileDest->pub.receiveSlot = TaskFileDestReceiverReceive; + taskFileDest->pub.rStartup = TaskFileDestReceiverStartup; + taskFileDest->pub.rShutdown = TaskFileDestReceiverShutdown; + taskFileDest->pub.rDestroy = TaskFileDestReceiverDestroy; + taskFileDest->pub.mydest = DestCopyOut; + + /* set up output parameters */ + taskFileDest->executorState = executorState; + taskFileDest->memoryContext = CurrentMemoryContext; + taskFileDest->filePath = pstrdup(filePath); + taskFileDest->binaryCopyFormat = binaryCopyFormat; + + return (DestReceiver *) taskFileDest; +} + + +/* + * TaskFileDestReceiverStartup implements the rStartup interface of + * TaskFileDestReceiver. It opens the destination file and sets up + * the CopyOutState. + */ +static void +TaskFileDestReceiverStartup(DestReceiver *dest, int operation, + TupleDesc inputTupleDescriptor) +{ + TaskFileDestReceiver *taskFileDest = (TaskFileDestReceiver *) dest; + + CopyOutState copyOutState = NULL; + const char *delimiterCharacter = "\t"; + const char *nullPrintCharacter = "\\N"; + + const int fileFlags = (O_APPEND | O_CREAT | O_RDWR | O_TRUNC | PG_BINARY); + const int fileMode = (S_IRUSR | S_IWUSR); + + /* use the memory context that was in place when the DestReceiver was created */ + MemoryContext oldContext = MemoryContextSwitchTo(taskFileDest->memoryContext); + + taskFileDest->tupleDescriptor = inputTupleDescriptor; + + /* define how tuples will be serialised */ + copyOutState = (CopyOutState) palloc0(sizeof(CopyOutStateData)); + copyOutState->delim = (char *) delimiterCharacter; + copyOutState->null_print = (char *) nullPrintCharacter; + copyOutState->null_print_client = (char *) nullPrintCharacter; + copyOutState->binary = taskFileDest->binaryCopyFormat; + copyOutState->fe_msgbuf = makeStringInfo(); + copyOutState->rowcontext = GetPerTupleMemoryContext(taskFileDest->executorState); + taskFileDest->copyOutState = copyOutState; + + taskFileDest->columnOutputFunctions = ColumnOutputFunctions(inputTupleDescriptor, + copyOutState->binary); + + taskFileDest->fileDesc = FileOpenForTransmit(taskFileDest->filePath, fileFlags, + fileMode); + + if (copyOutState->binary) + { + /* write headers when using binary encoding */ + resetStringInfo(copyOutState->fe_msgbuf); + AppendCopyBinaryHeaders(copyOutState); + + WriteToLocalFile(copyOutState->fe_msgbuf, taskFileDest->fileDesc); + } + + MemoryContextSwitchTo(oldContext); +} + + +/* + * TaskFileDestReceiverReceive implements the receiveSlot function of + * TaskFileDestReceiver. It takes a TupleTableSlot and writes the contents + * to a local file. + */ +static bool +TaskFileDestReceiverReceive(TupleTableSlot *slot, DestReceiver *dest) +{ + TaskFileDestReceiver *taskFileDest = (TaskFileDestReceiver *) dest; + + TupleDesc tupleDescriptor = taskFileDest->tupleDescriptor; + + CopyOutState copyOutState = taskFileDest->copyOutState; + FmgrInfo *columnOutputFunctions = taskFileDest->columnOutputFunctions; + + Datum *columnValues = NULL; + bool *columnNulls = NULL; + StringInfo copyData = copyOutState->fe_msgbuf; + + EState *executorState = taskFileDest->executorState; + MemoryContext executorTupleContext = GetPerTupleMemoryContext(executorState); + MemoryContext oldContext = MemoryContextSwitchTo(executorTupleContext); + + slot_getallattrs(slot); + + columnValues = slot->tts_values; + columnNulls = slot->tts_isnull; + + resetStringInfo(copyData); + + /* construct row in COPY format */ + AppendCopyRowData(columnValues, columnNulls, tupleDescriptor, + copyOutState, columnOutputFunctions, NULL); + + WriteToLocalFile(copyOutState->fe_msgbuf, taskFileDest->fileDesc); + + MemoryContextSwitchTo(oldContext); + + taskFileDest->tuplesSent++; + + ResetPerTupleExprContext(executorState); + + return true; +} + + +/* + * WriteToLocalResultsFile writes the bytes in a StringInfo to a local file. + */ +static void +WriteToLocalFile(StringInfo copyData, File fileDesc) +{ +#if (PG_VERSION_NUM >= 100000) + int bytesWritten = FileWrite(fileDesc, copyData->data, copyData->len, PG_WAIT_IO); +#else + int bytesWritten = FileWrite(fileDesc, copyData->data, copyData->len); +#endif + if (bytesWritten < 0) + { + ereport(ERROR, (errcode_for_file_access(), + errmsg("could not append to file: %m"))); + } +} + + +/* + * TaskFileDestReceiverShutdown implements the rShutdown interface of + * TaskFileDestReceiver. It writes the footer and closes the file. + * the relation. + */ +static void +TaskFileDestReceiverShutdown(DestReceiver *destReceiver) +{ + TaskFileDestReceiver *taskFileDest = (TaskFileDestReceiver *) destReceiver; + CopyOutState copyOutState = taskFileDest->copyOutState; + + if (copyOutState->binary) + { + /* write footers when using binary encoding */ + resetStringInfo(copyOutState->fe_msgbuf); + AppendCopyBinaryFooters(copyOutState); + WriteToLocalFile(copyOutState->fe_msgbuf, taskFileDest->fileDesc); + } + + FileClose(taskFileDest->fileDesc); +} + + +/* + * TaskFileDestReceiverDestroy frees memory allocated as part of the + * TaskFileDestReceiver and closes file descriptors. + */ +static void +TaskFileDestReceiverDestroy(DestReceiver *destReceiver) +{ + TaskFileDestReceiver *taskFileDest = (TaskFileDestReceiver *) destReceiver; + + if (taskFileDest->copyOutState) + { + pfree(taskFileDest->copyOutState); + } + + if (taskFileDest->columnOutputFunctions) + { + pfree(taskFileDest->columnOutputFunctions); + } + + pfree(taskFileDest->filePath); + pfree(taskFileDest); +} diff --git a/src/include/distributed/commands/multi_copy.h b/src/include/distributed/commands/multi_copy.h index c4d7a8f4b..6ecc63500 100644 --- a/src/include/distributed/commands/multi_copy.h +++ b/src/include/distributed/commands/multi_copy.h @@ -129,7 +129,8 @@ extern void AppendCopyBinaryHeaders(CopyOutState headerOutputState); extern void AppendCopyBinaryFooters(CopyOutState footerOutputState); extern void EndRemoteCopy(int64 shardId, List *connectionList, bool stopOnFailure); extern Node * ProcessCopyStmt(CopyStmt *copyStatement, char *completionTag, - bool *commandMustRunAsOwner); + const char *queryString); +extern void CheckCopyPermissions(CopyStmt *copyStatement); #endif /* MULTI_COPY_H */ diff --git a/src/include/distributed/multi_executor.h b/src/include/distributed/multi_executor.h index 6c798eb26..ad1aca23e 100644 --- a/src/include/distributed/multi_executor.h +++ b/src/include/distributed/multi_executor.h @@ -36,6 +36,7 @@ extern TupleTableSlot * ReturnTupleFromTuplestore(CitusScanState *scanState); extern void LoadTuplesIntoTupleStore(CitusScanState *citusScanState, Job *workerJob); extern void ReadFileIntoTupleStore(char *fileName, char *copyFormat, TupleDesc tupleDescriptor, Tuplestorestate *tupstore); +extern Query * ParseQueryString(const char *queryString); extern void ExecuteQueryStringIntoDestReceiver(const char *queryString, ParamListInfo params, DestReceiver *dest); diff --git a/src/include/distributed/worker_protocol.h b/src/include/distributed/worker_protocol.h index 30bef0f68..b92119b7f 100644 --- a/src/include/distributed/worker_protocol.h +++ b/src/include/distributed/worker_protocol.h @@ -122,6 +122,8 @@ extern FmgrInfo * GetFunctionInfo(Oid typeId, Oid accessMethodId, int16 procedur extern uint64 ExtractShardIdFromTableName(const char *tableName, bool missingOk); extern List * TableDDLCommandList(const char *nodeName, uint32 nodePort, const char *tableName); +extern int64 WorkerExecuteSqlTask(Query *query, char *taskFilename, + bool binaryCopyFormat); /* Function declarations shared with the master planner */