mirror of https://github.com/citusdata/citus.git
Merge pull request #1672 from citusdata/task_tracker_superuser
Execute transmit commands as extension owner during task-tracker queries
pull/1676/head
commit
632d0c675a
|
@ -127,7 +127,8 @@ MultiClientConnect(const char *nodeName, uint32 nodePort, const char *nodeDataba
|
|||
* error and returns INVALID_CONNECTION_ID.
|
||||
*/
|
||||
int32
|
||||
MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeDatabase)
|
||||
MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeDatabase,
|
||||
const char *userName)
|
||||
{
|
||||
MultiConnection *connection = NULL;
|
||||
ConnStatusType connStatusType = CONNECTION_OK;
|
||||
|
@ -148,7 +149,8 @@ MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeD
|
|||
}
|
||||
|
||||
/* prepare asynchronous request for worker node connection */
|
||||
connection = StartNodeConnection(connectionFlags, nodeName, nodePort);
|
||||
connection = StartNodeUserDatabaseConnection(connectionFlags, nodeName, nodePort,
|
||||
userName, nodeDatabase);
|
||||
connStatusType = PQstatus(connection->pgConn);
|
||||
|
||||
/*
|
||||
|
|
|
@ -273,7 +273,8 @@ ManageTaskExecution(Task *task, TaskExecution *taskExecution,
|
|||
/* we use the same database name on the master and worker nodes */
|
||||
nodeDatabase = get_database_name(MyDatabaseId);
|
||||
|
||||
connectionId = MultiClientConnectStart(nodeName, nodePort, nodeDatabase);
|
||||
connectionId = MultiClientConnectStart(nodeName, nodePort, nodeDatabase,
|
||||
NULL);
|
||||
connectionIdArray[currentIndex] = connectionId;
|
||||
|
||||
/* if valid, poll the connection until the connection is initiated */
|
||||
|
|
|
@ -72,7 +72,8 @@ static Task * TaskHashLookup(HTAB *trackerHash, TaskType taskType, uint64 jobId,
|
|||
uint32 taskId);
|
||||
static bool TopLevelTask(Task *task);
|
||||
static bool TransmitExecutionCompleted(TaskExecution *taskExecution);
|
||||
static HTAB * TrackerHash(const char *taskTrackerHashName, List *workerNodeList);
|
||||
static HTAB * TrackerHash(const char *taskTrackerHashName, List *workerNodeList,
|
||||
char *userName);
|
||||
static HTAB * TrackerHashCreate(const char *taskTrackerHashName,
|
||||
uint32 taskTrackerHashSize);
|
||||
static TaskTracker * TrackerHashEnter(HTAB *taskTrackerHash, char *nodeName,
|
||||
|
@ -161,6 +162,7 @@ MultiTaskTrackerExecute(Job *job)
|
|||
List *workerNodeList = NIL;
|
||||
HTAB *taskTrackerHash = NULL;
|
||||
HTAB *transmitTrackerHash = NULL;
|
||||
char *extensionOwner = CitusExtensionOwnerName();
|
||||
const char *taskTrackerHashName = "Task Tracker Hash";
|
||||
const char *transmitTrackerHashName = "Transmit Tracker Hash";
|
||||
List *jobIdList = NIL;
|
||||
|
@ -202,8 +204,12 @@ MultiTaskTrackerExecute(Job *job)
|
|||
workerNodeList = ActivePrimaryNodeList();
|
||||
taskTrackerCount = (uint32) list_length(workerNodeList);
|
||||
|
||||
taskTrackerHash = TrackerHash(taskTrackerHashName, workerNodeList);
|
||||
transmitTrackerHash = TrackerHash(transmitTrackerHashName, workerNodeList);
|
||||
/* connect as the current user for running queries */
|
||||
taskTrackerHash = TrackerHash(taskTrackerHashName, workerNodeList, NULL);
|
||||
|
||||
/* connect as the superuser for fetching result files */
|
||||
transmitTrackerHash = TrackerHash(transmitTrackerHashName, workerNodeList,
|
||||
extensionOwner);
|
||||
|
||||
TrackerHashConnect(taskTrackerHash);
|
||||
TrackerHashConnect(transmitTrackerHash);
|
||||
|
@ -667,10 +673,11 @@ TransmitExecutionCompleted(TaskExecution *taskExecution)
|
|||
/*
|
||||
* TrackerHash creates a task tracker hash with the given name. The function
|
||||
* then inserts one task tracker entry for each node in the given worker node
|
||||
* list, and initializes state for each task tracker.
|
||||
* list, and initializes state for each task tracker. The userName argument
|
||||
* indicates which user to connect as.
|
||||
*/
|
||||
static HTAB *
|
||||
TrackerHash(const char *taskTrackerHashName, List *workerNodeList)
|
||||
TrackerHash(const char *taskTrackerHashName, List *workerNodeList, char *userName)
|
||||
{
|
||||
/* create task tracker hash */
|
||||
uint32 taskTrackerHashSize = list_length(workerNodeList);
|
||||
|
@ -712,6 +719,7 @@ TrackerHash(const char *taskTrackerHashName, List *workerNodeList)
|
|||
}
|
||||
|
||||
taskTracker->taskStateHash = taskStateHash;
|
||||
taskTracker->userName = userName;
|
||||
}
|
||||
|
||||
return taskTrackerHash;
|
||||
|
@ -845,9 +853,10 @@ TrackerConnectPoll(TaskTracker *taskTracker)
|
|||
char *nodeName = taskTracker->workerName;
|
||||
uint32 nodePort = taskTracker->workerPort;
|
||||
char *nodeDatabase = get_database_name(MyDatabaseId);
|
||||
char *nodeUser = taskTracker->userName;
|
||||
|
||||
int32 connectionId = MultiClientConnectStart(nodeName, nodePort,
|
||||
nodeDatabase);
|
||||
nodeDatabase, nodeUser);
|
||||
if (connectionId != INVALID_CONNECTION_ID)
|
||||
{
|
||||
taskTracker->connectionId = connectionId;
|
||||
|
|
|
@ -635,6 +635,10 @@ IsTransmitStmt(Node *parsetree)
|
|||
static void
|
||||
VerifyTransmitStmt(CopyStmt *copyStatement)
|
||||
{
|
||||
char *fileName = NULL;
|
||||
|
||||
EnsureSuperUser();
|
||||
|
||||
/* do some minimal option verification */
|
||||
if (copyStatement->relation == NULL ||
|
||||
copyStatement->relation->relname == NULL)
|
||||
|
@ -643,6 +647,20 @@ VerifyTransmitStmt(CopyStmt *copyStatement)
|
|||
errmsg("FORMAT 'transmit' requires a target file")));
|
||||
}
|
||||
|
||||
fileName = copyStatement->relation->relname;
|
||||
|
||||
if (is_absolute_path(fileName))
|
||||
{
|
||||
ereport(ERROR, (errcode(ERRCODE_INSUFFICIENT_PRIVILEGE),
|
||||
(errmsg("absolute path not allowed"))));
|
||||
}
|
||||
else if (!path_is_relative_and_below_cwd(fileName))
|
||||
{
|
||||
ereport(ERROR,
|
||||
(errcode(ERRCODE_INSUFFICIENT_PRIVILEGE),
|
||||
(errmsg("path must be in or below the current directory"))));
|
||||
}
|
||||
|
||||
if (copyStatement->filename != NULL)
|
||||
{
|
||||
ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
|
||||
|
|
|
@ -57,10 +57,12 @@ bool ExpireCachedShards = false;
|
|||
|
||||
|
||||
/* Local functions forward declarations */
|
||||
static void FetchRegularFile(const char *nodeName, uint32 nodePort,
|
||||
StringInfo remoteFilename, StringInfo localFilename);
|
||||
static void FetchRegularFileAsSuperUser(const char *nodeName, uint32 nodePort,
|
||||
StringInfo remoteFilename,
|
||||
StringInfo localFilename);
|
||||
static bool ReceiveRegularFile(const char *nodeName, uint32 nodePort,
|
||||
StringInfo transmitCommand, StringInfo filePath);
|
||||
const char *nodeUser, StringInfo transmitCommand,
|
||||
StringInfo filePath);
|
||||
static void ReceiveResourceCleanup(int32 connectionId, const char *filename,
|
||||
int32 fileDescriptor);
|
||||
static void DeleteFile(const char *filename);
|
||||
|
@ -132,7 +134,9 @@ worker_fetch_partition_file(PG_FUNCTION_ARGS)
|
|||
}
|
||||
|
||||
nodeName = text_to_cstring(nodeNameText);
|
||||
FetchRegularFile(nodeName, nodePort, remoteFilename, taskFilename);
|
||||
|
||||
/* we've made sure the file names are sanitized, safe to fetch as superuser */
|
||||
FetchRegularFileAsSuperUser(nodeName, nodePort, remoteFilename, taskFilename);
|
||||
|
||||
PG_RETURN_VOID();
|
||||
}
|
||||
|
@ -176,7 +180,9 @@ worker_fetch_query_results_file(PG_FUNCTION_ARGS)
|
|||
}
|
||||
|
||||
nodeName = text_to_cstring(nodeNameText);
|
||||
FetchRegularFile(nodeName, nodePort, remoteFilename, taskFilename);
|
||||
|
||||
/* we've made sure the file names are sanitized, safe to fetch as superuser */
|
||||
FetchRegularFileAsSuperUser(nodeName, nodePort, remoteFilename, taskFilename);
|
||||
|
||||
PG_RETURN_VOID();
|
||||
}
|
||||
|
@ -195,11 +201,16 @@ TaskFilename(StringInfo directoryName, uint32 taskId)
|
|||
}
|
||||
|
||||
|
||||
/* Helper function to transfer the remote file in an idempotent manner. */
|
||||
/*
|
||||
* FetchRegularFileAsSuperUser copies a file from a remote node in an idempotent
|
||||
* manner. It connects to the remote node as superuser to give file access.
|
||||
* Callers must make sure that the file names are sanitized.
|
||||
*/
|
||||
static void
|
||||
FetchRegularFile(const char *nodeName, uint32 nodePort,
|
||||
FetchRegularFileAsSuperUser(const char *nodeName, uint32 nodePort,
|
||||
StringInfo remoteFilename, StringInfo localFilename)
|
||||
{
|
||||
char *nodeUser = NULL;
|
||||
StringInfo attemptFilename = NULL;
|
||||
StringInfo transmitCommand = NULL;
|
||||
uint32 randomId = (uint32) random();
|
||||
|
@ -218,7 +229,11 @@ FetchRegularFile(const char *nodeName, uint32 nodePort,
|
|||
transmitCommand = makeStringInfo();
|
||||
appendStringInfo(transmitCommand, TRANSMIT_REGULAR_COMMAND, remoteFilename->data);
|
||||
|
||||
received = ReceiveRegularFile(nodeName, nodePort, transmitCommand, attemptFilename);
|
||||
/* connect as superuser to give file access */
|
||||
nodeUser = CitusExtensionOwnerName();
|
||||
|
||||
received = ReceiveRegularFile(nodeName, nodePort, nodeUser, transmitCommand,
|
||||
attemptFilename);
|
||||
if (!received)
|
||||
{
|
||||
ereport(ERROR, (errmsg("could not receive file \"%s\" from %s:%u",
|
||||
|
@ -245,7 +260,7 @@ FetchRegularFile(const char *nodeName, uint32 nodePort,
|
|||
* and returns false.
|
||||
*/
|
||||
static bool
|
||||
ReceiveRegularFile(const char *nodeName, uint32 nodePort,
|
||||
ReceiveRegularFile(const char *nodeName, uint32 nodePort, const char *nodeUser,
|
||||
StringInfo transmitCommand, StringInfo filePath)
|
||||
{
|
||||
int32 fileDescriptor = -1;
|
||||
|
@ -277,7 +292,7 @@ ReceiveRegularFile(const char *nodeName, uint32 nodePort,
|
|||
nodeDatabase = get_database_name(MyDatabaseId);
|
||||
|
||||
/* connect to remote node */
|
||||
connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase, NULL);
|
||||
connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase, nodeUser);
|
||||
if (connectionId == INVALID_CONNECTION_ID)
|
||||
{
|
||||
ReceiveResourceCleanup(connectionId, filename, fileDescriptor);
|
||||
|
@ -822,7 +837,8 @@ FetchRegularTable(const char *nodeName, uint32 nodePort, const char *tableName)
|
|||
remoteCopyCommand = makeStringInfo();
|
||||
appendStringInfo(remoteCopyCommand, COPY_OUT_COMMAND, tableName);
|
||||
|
||||
received = ReceiveRegularFile(nodeName, nodePort, remoteCopyCommand, localFilePath);
|
||||
received = ReceiveRegularFile(nodeName, nodePort, NULL, remoteCopyCommand,
|
||||
localFilePath);
|
||||
if (!received)
|
||||
{
|
||||
return false;
|
||||
|
@ -862,8 +878,6 @@ FetchRegularTable(const char *nodeName, uint32 nodePort, const char *tableName)
|
|||
CommandCounterIncrement();
|
||||
}
|
||||
|
||||
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
|
||||
|
||||
/*
|
||||
* Copy local file into the relation. We call ProcessUtility() instead of
|
||||
* directly calling DoCopy() because some extensions (e.g. cstore_fdw) hook
|
||||
|
@ -882,6 +896,8 @@ FetchRegularTable(const char *nodeName, uint32 nodePort, const char *tableName)
|
|||
/* finally delete the temporary file we created */
|
||||
DeleteFile(localFilePath->data);
|
||||
|
||||
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -897,6 +913,7 @@ FetchRegularTable(const char *nodeName, uint32 nodePort, const char *tableName)
|
|||
static bool
|
||||
FetchForeignTable(const char *nodeName, uint32 nodePort, const char *tableName)
|
||||
{
|
||||
const char *nodeUser = NULL;
|
||||
StringInfo localFilePath = NULL;
|
||||
StringInfo remoteFilePath = NULL;
|
||||
StringInfo transmitCommand = NULL;
|
||||
|
@ -923,7 +940,16 @@ FetchForeignTable(const char *nodeName, uint32 nodePort, const char *tableName)
|
|||
transmitCommand = makeStringInfo();
|
||||
appendStringInfo(transmitCommand, TRANSMIT_REGULAR_COMMAND, remoteFilePath->data);
|
||||
|
||||
received = ReceiveRegularFile(nodeName, nodePort, transmitCommand, localFilePath);
|
||||
/*
|
||||
* We allow some arbitrary input in the file name and connect to the remote
|
||||
* node as superuser to transmit. Therefore, we only allow calling this
|
||||
* function when already running as superuser.
|
||||
*/
|
||||
EnsureSuperUser();
|
||||
nodeUser = CitusExtensionOwnerName();
|
||||
|
||||
received = ReceiveRegularFile(nodeName, nodePort, nodeUser, transmitCommand,
|
||||
localFilePath);
|
||||
if (!received)
|
||||
{
|
||||
return false;
|
||||
|
@ -1240,7 +1266,7 @@ worker_append_table_to_shard(PG_FUNCTION_ARGS)
|
|||
sourceCopyCommand = makeStringInfo();
|
||||
appendStringInfo(sourceCopyCommand, COPY_OUT_COMMAND, sourceQualifiedName);
|
||||
|
||||
received = ReceiveRegularFile(sourceNodeName, sourceNodePort, sourceCopyCommand,
|
||||
received = ReceiveRegularFile(sourceNodeName, sourceNodePort, NULL, sourceCopyCommand,
|
||||
localFilePath);
|
||||
if (!received)
|
||||
{
|
||||
|
|
|
@ -98,7 +98,7 @@ typedef struct WaitInfo
|
|||
extern int32 MultiClientConnect(const char *nodeName, uint32 nodePort,
|
||||
const char *nodeDatabase, const char *nodeUser);
|
||||
extern int32 MultiClientConnectStart(const char *nodeName, uint32 nodePort,
|
||||
const char *nodeDatabase);
|
||||
const char *nodeDatabase, const char *nodeUser);
|
||||
extern ConnectStatus MultiClientConnectPoll(int32 connectionId);
|
||||
extern void MultiClientDisconnect(int32 connectionId);
|
||||
extern bool MultiClientConnectionUp(int32 connectionId);
|
||||
|
|
|
@ -156,6 +156,7 @@ typedef struct TaskTracker
|
|||
{
|
||||
uint32 workerPort; /* node's port; part of hash table key */
|
||||
char workerName[WORKER_LENGTH]; /* node's name; part of hash table key */
|
||||
char *userName; /* which user to connect as */
|
||||
TrackerStatus trackerStatus;
|
||||
int32 connectionId;
|
||||
uint32 connectPollCount;
|
||||
|
|
|
@ -7,7 +7,7 @@ ALTER SEQUENCE pg_catalog.pg_dist_shardid_seq RESTART 1420000;
|
|||
ALTER SEQUENCE pg_catalog.pg_dist_jobid_seq RESTART 1420000;
|
||||
SET citus.shard_replication_factor TO 1;
|
||||
SET citus.shard_count TO 2;
|
||||
CREATE TABLE test (id integer);
|
||||
CREATE TABLE test (id integer, val integer);
|
||||
SELECT create_distributed_table('test', 'id');
|
||||
create_distributed_table
|
||||
--------------------------
|
||||
|
@ -56,6 +56,9 @@ GRANT SELECT ON TABLE test_1420001 TO read_access;
|
|||
-- create prepare tests
|
||||
PREPARE prepare_insert AS INSERT INTO test VALUES ($1);
|
||||
PREPARE prepare_select AS SELECT count(*) FROM test;
|
||||
-- not allowed to read absolute paths, even as superuser
|
||||
COPY "/etc/passwd" TO STDOUT WITH (format transmit);
|
||||
ERROR: absolute path not allowed
|
||||
-- check full permission
|
||||
SET ROLE full_access;
|
||||
EXECUTE prepare_insert(1);
|
||||
|
@ -85,7 +88,22 @@ SELECT count(*) FROM test;
|
|||
2
|
||||
(1 row)
|
||||
|
||||
-- test re-partition query (needs to transmit intermediate results)
|
||||
SELECT count(*) FROM test a JOIN test b ON (a.val = b.val) WHERE a.id = 1 AND b.id = 2;
|
||||
count
|
||||
-------
|
||||
0
|
||||
(1 row)
|
||||
|
||||
-- should not be able to transmit directly
|
||||
COPY "postgresql.conf" TO STDOUT WITH (format transmit);
|
||||
ERROR: operation is not allowed
|
||||
HINT: Run the command with a superuser.
|
||||
SET citus.task_executor_type TO 'real-time';
|
||||
-- should not be able to transmit directly
|
||||
COPY "postgresql.conf" TO STDOUT WITH (format transmit);
|
||||
ERROR: operation is not allowed
|
||||
HINT: Run the command with a superuser.
|
||||
-- check read permission
|
||||
SET ROLE read_access;
|
||||
EXECUTE prepare_insert(1);
|
||||
|
@ -117,6 +135,17 @@ SELECT count(*) FROM test;
|
|||
2
|
||||
(1 row)
|
||||
|
||||
-- test re-partition query (needs to transmit intermediate results)
|
||||
SELECT count(*) FROM test a JOIN test b ON (a.val = b.val) WHERE a.id = 1 AND b.id = 2;
|
||||
count
|
||||
-------
|
||||
0
|
||||
(1 row)
|
||||
|
||||
-- should not be able to transmit directly
|
||||
COPY "postgresql.conf" TO STDOUT WITH (format transmit);
|
||||
ERROR: operation is not allowed
|
||||
HINT: Run the command with a superuser.
|
||||
SET citus.task_executor_type TO 'real-time';
|
||||
-- check no permission
|
||||
SET ROLE no_access;
|
||||
|
@ -133,6 +162,13 @@ ERROR: permission denied for relation test
|
|||
SET citus.task_executor_type TO 'task-tracker';
|
||||
SELECT count(*) FROM test;
|
||||
ERROR: permission denied for relation test
|
||||
-- test re-partition query
|
||||
SELECT count(*) FROM test a JOIN test b ON (a.val = b.val) WHERE a.id = 1 AND b.id = 2;
|
||||
ERROR: permission denied for relation test
|
||||
-- should not be able to transmit directly
|
||||
COPY "postgresql.conf" TO STDOUT WITH (format transmit);
|
||||
ERROR: operation is not allowed
|
||||
HINT: Run the command with a superuser.
|
||||
SET citus.task_executor_type TO 'real-time';
|
||||
RESET ROLE;
|
||||
DROP TABLE test;
|
||||
|
|
|
@ -10,7 +10,7 @@ ALTER SEQUENCE pg_catalog.pg_dist_jobid_seq RESTART 1420000;
|
|||
SET citus.shard_replication_factor TO 1;
|
||||
SET citus.shard_count TO 2;
|
||||
|
||||
CREATE TABLE test (id integer);
|
||||
CREATE TABLE test (id integer, val integer);
|
||||
SELECT create_distributed_table('test', 'id');
|
||||
|
||||
-- turn off propagation to avoid Enterprise processing the following section
|
||||
|
@ -47,6 +47,9 @@ GRANT SELECT ON TABLE test_1420001 TO read_access;
|
|||
PREPARE prepare_insert AS INSERT INTO test VALUES ($1);
|
||||
PREPARE prepare_select AS SELECT count(*) FROM test;
|
||||
|
||||
-- not allowed to read absolute paths, even as superuser
|
||||
COPY "/etc/passwd" TO STDOUT WITH (format transmit);
|
||||
|
||||
-- check full permission
|
||||
SET ROLE full_access;
|
||||
|
||||
|
@ -59,8 +62,18 @@ SELECT count(*) FROM test WHERE id = 1;
|
|||
|
||||
SET citus.task_executor_type TO 'task-tracker';
|
||||
SELECT count(*) FROM test;
|
||||
|
||||
-- test re-partition query (needs to transmit intermediate results)
|
||||
SELECT count(*) FROM test a JOIN test b ON (a.val = b.val) WHERE a.id = 1 AND b.id = 2;
|
||||
|
||||
-- should not be able to transmit directly
|
||||
COPY "postgresql.conf" TO STDOUT WITH (format transmit);
|
||||
|
||||
SET citus.task_executor_type TO 'real-time';
|
||||
|
||||
-- should not be able to transmit directly
|
||||
COPY "postgresql.conf" TO STDOUT WITH (format transmit);
|
||||
|
||||
-- check read permission
|
||||
SET ROLE read_access;
|
||||
|
||||
|
@ -73,6 +86,13 @@ SELECT count(*) FROM test WHERE id = 1;
|
|||
|
||||
SET citus.task_executor_type TO 'task-tracker';
|
||||
SELECT count(*) FROM test;
|
||||
|
||||
-- test re-partition query (needs to transmit intermediate results)
|
||||
SELECT count(*) FROM test a JOIN test b ON (a.val = b.val) WHERE a.id = 1 AND b.id = 2;
|
||||
|
||||
-- should not be able to transmit directly
|
||||
COPY "postgresql.conf" TO STDOUT WITH (format transmit);
|
||||
|
||||
SET citus.task_executor_type TO 'real-time';
|
||||
|
||||
-- check no permission
|
||||
|
@ -87,6 +107,13 @@ SELECT count(*) FROM test WHERE id = 1;
|
|||
|
||||
SET citus.task_executor_type TO 'task-tracker';
|
||||
SELECT count(*) FROM test;
|
||||
|
||||
-- test re-partition query
|
||||
SELECT count(*) FROM test a JOIN test b ON (a.val = b.val) WHERE a.id = 1 AND b.id = 2;
|
||||
|
||||
-- should not be able to transmit directly
|
||||
COPY "postgresql.conf" TO STDOUT WITH (format transmit);
|
||||
|
||||
SET citus.task_executor_type TO 'real-time';
|
||||
|
||||
RESET ROLE;
|
||||
|
|
Loading…
Reference in New Issue