diff --git a/src/backend/distributed/worker/worker_merge_protocol.c b/src/backend/distributed/worker/worker_merge_protocol.c index a539a9c90..606c14e76 100644 --- a/src/backend/distributed/worker/worker_merge_protocol.c +++ b/src/backend/distributed/worker/worker_merge_protocol.c @@ -35,6 +35,7 @@ #include "executor/spi.h" #include "nodes/makefuncs.h" +#include "parser/parse_relation.h" #include "parser/parse_type.h" #include "storage/lmgr.h" #include "utils/acl.h" @@ -183,8 +184,6 @@ worker_merge_files_into_table(PG_FUNCTION_ARGS) StringInfo jobSchemaName = JobSchemaName(jobId); StringInfo taskTableName = TaskTableName(taskId); StringInfo taskDirectoryName = TaskDirectoryName(jobId, taskId); - Oid savedUserId = InvalidOid; - int savedSecurityContext = 0; Oid userId = GetUserId(); /* we should have the same number of column names and types */ @@ -231,14 +230,9 @@ worker_merge_files_into_table(PG_FUNCTION_ARGS) CreateTaskTable(jobSchemaName, taskTableName, columnNameList, columnTypeList); - /* need superuser to copy from files */ - GetUserIdAndSecContext(&savedUserId, &savedSecurityContext); - SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE); - CopyTaskFilesFromDirectory(jobSchemaName, taskTableName, taskDirectoryName, userId); - SetUserIdAndSecContext(savedUserId, savedSecurityContext); PG_RETURN_VOID(); } @@ -557,8 +551,8 @@ CopyTaskFilesFromDirectory(StringInfo schemaName, StringInfo relationName, appendStringInfo(fullFilename, "%s/%s", directoryName, baseFilename); /* build relation object and copy statement */ - RangeVar *relation = makeRangeVar(schemaName->data, relationName->data, -1); - CopyStmt *copyStatement = CopyStatement(relation, fullFilename->data); + RangeVar *rangeVar = makeRangeVar(schemaName->data, relationName->data, -1); + CopyStmt *copyStatement = CopyStatement(rangeVar, fullFilename->data); if (BinaryWorkerCopyFormat) { DefElem *copyOption = makeDefElem("format", (Node *) makeString("binary"), @@ -567,12 +561,25 @@ CopyTaskFilesFromDirectory(StringInfo schemaName, StringInfo relationName, } { - ParseState *pstate = make_parsestate(NULL); - pstate->p_sourcetext = queryString; + ParseState *parseState = make_parsestate(NULL); + parseState->p_sourcetext = queryString; - DoCopy(pstate, copyStatement, -1, -1, &copiedRowCount); + Relation relation = table_openrv(rangeVar, RowExclusiveLock); + (void) addRangeTableEntryForRelation(parseState, relation, RowExclusiveLock, + NULL, false, false); - free_parsestate(pstate); + CopyState copyState = BeginCopyFrom(parseState, + relation, + copyStatement->filename, + copyStatement->is_program, + NULL, + copyStatement->attlist, + copyStatement->options); + copiedRowCount = CopyFrom(copyState); + EndCopyFrom(copyState); + + free_parsestate(parseState); + table_close(relation, NoLock); } copiedRowTotal += copiedRowCount;