Execute transmit commands as superuser during task-tracker queries

release-6.1
Marco Slot 2017-09-27 20:55:35 -07:00
parent ee9e064eef
commit 671f52ee6f
7 changed files with 66 additions and 24 deletions

View File

@ -133,7 +133,8 @@ MultiClientConnect(const char *nodeName, uint32 nodePort, const char *nodeDataba
* error and returns INVALID_CONNECTION_ID.
*/
int32
MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeDatabase)
MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeDatabase,
const char *userName)
{
MultiConnection *connection = NULL;
ConnStatusType connStatusType = CONNECTION_OK;
@ -154,7 +155,8 @@ MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeD
}
/* prepare asynchronous request for worker node connection */
connection = StartNodeConnection(connectionFlags, nodeName, nodePort);
connection = StartNodeUserDatabaseConnection(connectionFlags, nodeName, nodePort,
userName, nodeDatabase);
connStatusType = PQstatus(connection->pgConn);
/*

View File

@ -273,7 +273,8 @@ ManageTaskExecution(Task *task, TaskExecution *taskExecution,
/* we use the same database name on the master and worker nodes */
nodeDatabase = get_database_name(MyDatabaseId);
connectionId = MultiClientConnectStart(nodeName, nodePort, nodeDatabase);
connectionId = MultiClientConnectStart(nodeName, nodePort, nodeDatabase,
NULL);
connectionIdArray[currentIndex] = connectionId;
/* if valid, poll the connection until the connection is initiated */

View File

@ -27,6 +27,7 @@
#include "commands/dbcommands.h"
#include "distributed/citus_nodes.h"
#include "distributed/connection_management.h"
#include "distributed/metadata_cache.h"
#include "distributed/multi_client_executor.h"
#include "distributed/multi_physical_planner.h"
#include "distributed/multi_server_executor.h"
@ -68,7 +69,8 @@ static Task * TaskHashLookup(HTAB *trackerHash, TaskType taskType, uint64 jobId,
uint32 taskId);
static bool TopLevelTask(Task *task);
static bool TransmitExecutionCompleted(TaskExecution *taskExecution);
static HTAB * TrackerHash(const char *taskTrackerHashName, List *workerNodeList);
static HTAB * TrackerHash(const char *taskTrackerHashName, List *workerNodeList,
char *userName);
static HTAB * TrackerHashCreate(const char *taskTrackerHashName,
uint32 taskTrackerHashSize);
static TaskTracker * TrackerHashEnter(HTAB *taskTrackerHash, char *nodeName,
@ -157,6 +159,7 @@ MultiTaskTrackerExecute(Job *job)
List *workerNodeList = NIL;
HTAB *taskTrackerHash = NULL;
HTAB *transmitTrackerHash = NULL;
char *extensionOwner = CitusExtensionOwnerName();
const char *taskTrackerHashName = "Task Tracker Hash";
const char *transmitTrackerHashName = "Transmit Tracker Hash";
List *jobIdList = NIL;
@ -191,8 +194,12 @@ MultiTaskTrackerExecute(Job *job)
workerNodeList = WorkerNodeList();
taskTrackerCount = (uint32) list_length(workerNodeList);
taskTrackerHash = TrackerHash(taskTrackerHashName, workerNodeList);
transmitTrackerHash = TrackerHash(transmitTrackerHashName, workerNodeList);
/* connect as the current user for running queries */
taskTrackerHash = TrackerHash(taskTrackerHashName, workerNodeList, NULL);
/* connect as the superuser for fetching result files */
transmitTrackerHash = TrackerHash(transmitTrackerHashName, workerNodeList,
extensionOwner);
TrackerHashConnect(taskTrackerHash);
TrackerHashConnect(transmitTrackerHash);
@ -660,10 +667,11 @@ TransmitExecutionCompleted(TaskExecution *taskExecution)
/*
* TrackerHash creates a task tracker hash with the given name. The function
* then inserts one task tracker entry for each node in the given worker node
* list, and initializes state for each task tracker.
* list, and initializes state for each task tracker. The userName argument
* indicates which user to connect as.
*/
static HTAB *
TrackerHash(const char *taskTrackerHashName, List *workerNodeList)
TrackerHash(const char *taskTrackerHashName, List *workerNodeList, char *userName)
{
/* create task tracker hash */
uint32 taskTrackerHashSize = list_length(workerNodeList);
@ -705,6 +713,7 @@ TrackerHash(const char *taskTrackerHashName, List *workerNodeList)
}
taskTracker->taskStateHash = taskStateHash;
taskTracker->userName = userName;
}
return taskTrackerHash;
@ -838,9 +847,10 @@ TrackerConnectPoll(TaskTracker *taskTracker)
char *nodeName = taskTracker->workerName;
uint32 nodePort = taskTracker->workerPort;
char *nodeDatabase = get_database_name(MyDatabaseId);
char *nodeUser = taskTracker->userName;
int32 connectionId = MultiClientConnectStart(nodeName, nodePort,
nodeDatabase);
nodeDatabase, nodeUser);
if (connectionId != INVALID_CONNECTION_ID)
{
taskTracker->connectionId = connectionId;

View File

@ -482,6 +482,8 @@ VerifyTransmitStmt(CopyStmt *copyStatement)
{
char *fileName = NULL;
EnsureSuperUser();
/* do some minimal option verification */
if (copyStatement->relation == NULL ||
copyStatement->relation->relname == NULL)

View File

@ -48,10 +48,12 @@ bool ExpireCachedShards = false;
/* Local functions forward declarations */
static void FetchRegularFile(const char *nodeName, uint32 nodePort,
StringInfo remoteFilename, StringInfo localFilename);
static void FetchRegularFileAsSuperUser(const char *nodeName, uint32 nodePort,
StringInfo remoteFilename,
StringInfo localFilename);
static bool ReceiveRegularFile(const char *nodeName, uint32 nodePort,
StringInfo transmitCommand, StringInfo filePath);
const char *nodeUser, StringInfo transmitCommand,
StringInfo filePath);
static void ReceiveResourceCleanup(int32 connectionId, const char *filename,
int32 fileDescriptor);
static void DeleteFile(const char *filename);
@ -120,7 +122,9 @@ worker_fetch_partition_file(PG_FUNCTION_ARGS)
}
nodeName = text_to_cstring(nodeNameText);
FetchRegularFile(nodeName, nodePort, remoteFilename, taskFilename);
/* we've made sure the file names are sanitized, safe to fetch as superuser */
FetchRegularFileAsSuperUser(nodeName, nodePort, remoteFilename, taskFilename);
PG_RETURN_VOID();
}
@ -161,7 +165,9 @@ worker_fetch_query_results_file(PG_FUNCTION_ARGS)
}
nodeName = text_to_cstring(nodeNameText);
FetchRegularFile(nodeName, nodePort, remoteFilename, taskFilename);
/* we've made sure the file names are sanitized, safe to fetch as superuser */
FetchRegularFileAsSuperUser(nodeName, nodePort, remoteFilename, taskFilename);
PG_RETURN_VOID();
}
@ -180,11 +186,16 @@ TaskFilename(StringInfo directoryName, uint32 taskId)
}
/* Helper function to transfer the remote file in an idempotent manner. */
/*
* FetchRegularFileAsSuperUser copies a file from a remote node in an idempotent
* manner. It connects to the remote node as superuser to give file access.
* Callers must make sure that the file names are sanitized.
*/
static void
FetchRegularFile(const char *nodeName, uint32 nodePort,
StringInfo remoteFilename, StringInfo localFilename)
FetchRegularFileAsSuperUser(const char *nodeName, uint32 nodePort,
StringInfo remoteFilename, StringInfo localFilename)
{
char *nodeUser = NULL;
StringInfo attemptFilename = NULL;
StringInfo transmitCommand = NULL;
uint32 randomId = (uint32) random();
@ -203,7 +214,11 @@ FetchRegularFile(const char *nodeName, uint32 nodePort,
transmitCommand = makeStringInfo();
appendStringInfo(transmitCommand, TRANSMIT_REGULAR_COMMAND, remoteFilename->data);
received = ReceiveRegularFile(nodeName, nodePort, transmitCommand, attemptFilename);
/* connect as superuser to give file access */
nodeUser = CitusExtensionOwnerName();
received = ReceiveRegularFile(nodeName, nodePort, nodeUser, transmitCommand,
attemptFilename);
if (!received)
{
ereport(ERROR, (errmsg("could not receive file \"%s\" from %s:%u",
@ -230,7 +245,7 @@ FetchRegularFile(const char *nodeName, uint32 nodePort,
* and returns false.
*/
static bool
ReceiveRegularFile(const char *nodeName, uint32 nodePort,
ReceiveRegularFile(const char *nodeName, uint32 nodePort, const char *nodeUser,
StringInfo transmitCommand, StringInfo filePath)
{
int32 fileDescriptor = -1;
@ -262,7 +277,7 @@ ReceiveRegularFile(const char *nodeName, uint32 nodePort,
nodeDatabase = get_database_name(MyDatabaseId);
/* connect to remote node */
connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase, NULL);
connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase, nodeUser);
if (connectionId == INVALID_CONNECTION_ID)
{
ReceiveResourceCleanup(connectionId, filename, fileDescriptor);
@ -795,7 +810,8 @@ FetchRegularTable(const char *nodeName, uint32 nodePort, const char *tableName)
remoteCopyCommand = makeStringInfo();
appendStringInfo(remoteCopyCommand, COPY_OUT_COMMAND, tableName);
received = ReceiveRegularFile(nodeName, nodePort, remoteCopyCommand, localFilePath);
received = ReceiveRegularFile(nodeName, nodePort, NULL, remoteCopyCommand,
localFilePath);
if (!received)
{
return false;
@ -870,6 +886,7 @@ FetchRegularTable(const char *nodeName, uint32 nodePort, const char *tableName)
static bool
FetchForeignTable(const char *nodeName, uint32 nodePort, const char *tableName)
{
const char *nodeUser = NULL;
StringInfo localFilePath = NULL;
StringInfo remoteFilePath = NULL;
StringInfo transmitCommand = NULL;
@ -896,7 +913,16 @@ FetchForeignTable(const char *nodeName, uint32 nodePort, const char *tableName)
transmitCommand = makeStringInfo();
appendStringInfo(transmitCommand, TRANSMIT_REGULAR_COMMAND, remoteFilePath->data);
received = ReceiveRegularFile(nodeName, nodePort, transmitCommand, localFilePath);
/*
* We allow some arbitrary input in the file name and connect to the remote
* node as superuser to transmit. Therefore, we only allow calling this
* function when already running as superuser.
*/
EnsureSuperUser();
nodeUser = CitusExtensionOwnerName();
received = ReceiveRegularFile(nodeName, nodePort, nodeUser, transmitCommand,
localFilePath);
if (!received)
{
return false;
@ -1233,7 +1259,7 @@ worker_append_table_to_shard(PG_FUNCTION_ARGS)
sourceCopyCommand = makeStringInfo();
appendStringInfo(sourceCopyCommand, COPY_OUT_COMMAND, sourceQualifiedName);
received = ReceiveRegularFile(sourceNodeName, sourceNodePort, sourceCopyCommand,
received = ReceiveRegularFile(sourceNodeName, sourceNodePort, NULL, sourceCopyCommand,
localFilePath);
if (!received)
{

View File

@ -98,7 +98,7 @@ typedef struct WaitInfo
extern int32 MultiClientConnect(const char *nodeName, uint32 nodePort,
const char *nodeDatabase, const char *nodeUser);
extern int32 MultiClientConnectStart(const char *nodeName, uint32 nodePort,
const char *nodeDatabase);
const char *nodeDatabase, const char *nodeUser);
extern ConnectStatus MultiClientConnectPoll(int32 connectionId);
extern void MultiClientDisconnect(int32 connectionId);
extern bool MultiClientConnectionUp(int32 connectionId);

View File

@ -152,6 +152,7 @@ typedef struct TaskTracker
{
uint32 workerPort; /* node's port; part of hash table key */
char workerName[WORKER_LENGTH]; /* node's name; part of hash table key */
char *userName; /* which user to connect as */
TrackerStatus trackerStatus;
int32 connectionId;
uint32 connectPollCount;