mirror of https://github.com/citusdata/citus.git
Execute transmit commands as superuser during task-tracker queries
parent
4c45c87819
commit
0f490b13f2
|
@ -167,12 +167,12 @@ MultiClientConnect(const char *nodeName, uint32 nodePort, const char *nodeDataba
|
||||||
* error and returns INVALID_CONNECTION_ID.
|
* error and returns INVALID_CONNECTION_ID.
|
||||||
*/
|
*/
|
||||||
int32
|
int32
|
||||||
MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeDatabase)
|
MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeDatabase,
|
||||||
|
const char *userName)
|
||||||
{
|
{
|
||||||
PGconn *connection = NULL;
|
PGconn *connection = NULL;
|
||||||
char connInfoString[STRING_BUFFER_SIZE];
|
char connInfoString[STRING_BUFFER_SIZE];
|
||||||
ConnStatusType connStatusType = CONNECTION_BAD;
|
ConnStatusType connStatusType = CONNECTION_BAD;
|
||||||
char *userName = CurrentUserName();
|
|
||||||
|
|
||||||
int32 connectionId = AllocateConnectionId();
|
int32 connectionId = AllocateConnectionId();
|
||||||
if (connectionId == INVALID_CONNECTION_ID)
|
if (connectionId == INVALID_CONNECTION_ID)
|
||||||
|
@ -188,6 +188,11 @@ MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeD
|
||||||
"command within a transaction")));
|
"command within a transaction")));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (userName == NULL)
|
||||||
|
{
|
||||||
|
userName = CurrentUserName();
|
||||||
|
}
|
||||||
|
|
||||||
/* transcribe connection parameters to string */
|
/* transcribe connection parameters to string */
|
||||||
snprintf(connInfoString, STRING_BUFFER_SIZE, CONN_INFO_TEMPLATE,
|
snprintf(connInfoString, STRING_BUFFER_SIZE, CONN_INFO_TEMPLATE,
|
||||||
nodeName, nodePort, nodeDatabase, userName, CLIENT_CONNECT_TIMEOUT);
|
nodeName, nodePort, nodeDatabase, userName, CLIENT_CONNECT_TIMEOUT);
|
||||||
|
|
|
@ -272,7 +272,8 @@ ManageTaskExecution(Task *task, TaskExecution *taskExecution,
|
||||||
/* we use the same database name on the master and worker nodes */
|
/* we use the same database name on the master and worker nodes */
|
||||||
nodeDatabase = get_database_name(MyDatabaseId);
|
nodeDatabase = get_database_name(MyDatabaseId);
|
||||||
|
|
||||||
connectionId = MultiClientConnectStart(nodeName, nodePort, nodeDatabase);
|
connectionId = MultiClientConnectStart(nodeName, nodePort, nodeDatabase,
|
||||||
|
NULL);
|
||||||
connectionIdArray[currentIndex] = connectionId;
|
connectionIdArray[currentIndex] = connectionId;
|
||||||
|
|
||||||
/* if valid, poll the connection until the connection is initiated */
|
/* if valid, poll the connection until the connection is initiated */
|
||||||
|
|
|
@ -25,6 +25,7 @@
|
||||||
|
|
||||||
#include "commands/dbcommands.h"
|
#include "commands/dbcommands.h"
|
||||||
#include "distributed/citus_nodes.h"
|
#include "distributed/citus_nodes.h"
|
||||||
|
#include "distributed/metadata_cache.h"
|
||||||
#include "distributed/multi_client_executor.h"
|
#include "distributed/multi_client_executor.h"
|
||||||
#include "distributed/multi_physical_planner.h"
|
#include "distributed/multi_physical_planner.h"
|
||||||
#include "distributed/multi_server_executor.h"
|
#include "distributed/multi_server_executor.h"
|
||||||
|
@ -65,7 +66,8 @@ static Task * TaskHashLookup(HTAB *trackerHash, TaskType taskType, uint64 jobId,
|
||||||
uint32 taskId);
|
uint32 taskId);
|
||||||
static bool TopLevelTask(Task *task);
|
static bool TopLevelTask(Task *task);
|
||||||
static bool TransmitExecutionCompleted(TaskExecution *taskExecution);
|
static bool TransmitExecutionCompleted(TaskExecution *taskExecution);
|
||||||
static HTAB * TrackerHash(const char *taskTrackerHashName, List *workerNodeList);
|
static HTAB * TrackerHash(const char *taskTrackerHashName, List *workerNodeList,
|
||||||
|
char *userName);
|
||||||
static HTAB * TrackerHashCreate(const char *taskTrackerHashName,
|
static HTAB * TrackerHashCreate(const char *taskTrackerHashName,
|
||||||
uint32 taskTrackerHashSize);
|
uint32 taskTrackerHashSize);
|
||||||
static TaskTracker * TrackerHashEnter(HTAB *taskTrackerHash, char *nodeName,
|
static TaskTracker * TrackerHashEnter(HTAB *taskTrackerHash, char *nodeName,
|
||||||
|
@ -154,6 +156,7 @@ MultiTaskTrackerExecute(Job *job)
|
||||||
List *workerNodeList = NIL;
|
List *workerNodeList = NIL;
|
||||||
HTAB *taskTrackerHash = NULL;
|
HTAB *taskTrackerHash = NULL;
|
||||||
HTAB *transmitTrackerHash = NULL;
|
HTAB *transmitTrackerHash = NULL;
|
||||||
|
char *extensionOwner = CitusExtensionOwnerName();
|
||||||
const char *taskTrackerHashName = "Task Tracker Hash";
|
const char *taskTrackerHashName = "Task Tracker Hash";
|
||||||
const char *transmitTrackerHashName = "Transmit Tracker Hash";
|
const char *transmitTrackerHashName = "Transmit Tracker Hash";
|
||||||
List *jobIdList = NIL;
|
List *jobIdList = NIL;
|
||||||
|
@ -188,8 +191,12 @@ MultiTaskTrackerExecute(Job *job)
|
||||||
workerNodeList = WorkerNodeList();
|
workerNodeList = WorkerNodeList();
|
||||||
taskTrackerCount = (uint32) list_length(workerNodeList);
|
taskTrackerCount = (uint32) list_length(workerNodeList);
|
||||||
|
|
||||||
taskTrackerHash = TrackerHash(taskTrackerHashName, workerNodeList);
|
/* connect as the current user for running queries */
|
||||||
transmitTrackerHash = TrackerHash(transmitTrackerHashName, workerNodeList);
|
taskTrackerHash = TrackerHash(taskTrackerHashName, workerNodeList, NULL);
|
||||||
|
|
||||||
|
/* connect as the superuser for fetching result files */
|
||||||
|
transmitTrackerHash = TrackerHash(transmitTrackerHashName, workerNodeList,
|
||||||
|
extensionOwner);
|
||||||
|
|
||||||
TrackerHashConnect(taskTrackerHash);
|
TrackerHashConnect(taskTrackerHash);
|
||||||
TrackerHashConnect(transmitTrackerHash);
|
TrackerHashConnect(transmitTrackerHash);
|
||||||
|
@ -657,10 +664,11 @@ TransmitExecutionCompleted(TaskExecution *taskExecution)
|
||||||
/*
|
/*
|
||||||
* TrackerHash creates a task tracker hash with the given name. The function
|
* TrackerHash creates a task tracker hash with the given name. The function
|
||||||
* then inserts one task tracker entry for each node in the given worker node
|
* then inserts one task tracker entry for each node in the given worker node
|
||||||
* list, and initializes state for each task tracker.
|
* list, and initializes state for each task tracker. The userName argument
|
||||||
|
* indicates which user to connect as.
|
||||||
*/
|
*/
|
||||||
static HTAB *
|
static HTAB *
|
||||||
TrackerHash(const char *taskTrackerHashName, List *workerNodeList)
|
TrackerHash(const char *taskTrackerHashName, List *workerNodeList, char *userName)
|
||||||
{
|
{
|
||||||
/* create task tracker hash */
|
/* create task tracker hash */
|
||||||
uint32 taskTrackerHashSize = list_length(workerNodeList);
|
uint32 taskTrackerHashSize = list_length(workerNodeList);
|
||||||
|
@ -702,6 +710,7 @@ TrackerHash(const char *taskTrackerHashName, List *workerNodeList)
|
||||||
}
|
}
|
||||||
|
|
||||||
taskTracker->taskStateHash = taskStateHash;
|
taskTracker->taskStateHash = taskStateHash;
|
||||||
|
taskTracker->userName = userName;
|
||||||
}
|
}
|
||||||
|
|
||||||
return taskTrackerHash;
|
return taskTrackerHash;
|
||||||
|
@ -835,9 +844,10 @@ TrackerConnectPoll(TaskTracker *taskTracker)
|
||||||
char *nodeName = taskTracker->workerName;
|
char *nodeName = taskTracker->workerName;
|
||||||
uint32 nodePort = taskTracker->workerPort;
|
uint32 nodePort = taskTracker->workerPort;
|
||||||
char *nodeDatabase = get_database_name(MyDatabaseId);
|
char *nodeDatabase = get_database_name(MyDatabaseId);
|
||||||
|
char *nodeUser = taskTracker->userName;
|
||||||
|
|
||||||
int32 connectionId = MultiClientConnectStart(nodeName, nodePort,
|
int32 connectionId = MultiClientConnectStart(nodeName, nodePort,
|
||||||
nodeDatabase);
|
nodeDatabase, nodeUser);
|
||||||
if (connectionId != INVALID_CONNECTION_ID)
|
if (connectionId != INVALID_CONNECTION_ID)
|
||||||
{
|
{
|
||||||
taskTracker->connectionId = connectionId;
|
taskTracker->connectionId = connectionId;
|
||||||
|
|
|
@ -401,6 +401,8 @@ VerifyTransmitStmt(CopyStmt *copyStatement)
|
||||||
{
|
{
|
||||||
char *fileName = NULL;
|
char *fileName = NULL;
|
||||||
|
|
||||||
|
EnsureSuperUser();
|
||||||
|
|
||||||
/* do some minimal option verification */
|
/* do some minimal option verification */
|
||||||
if (copyStatement->relation == NULL ||
|
if (copyStatement->relation == NULL ||
|
||||||
copyStatement->relation->relname == NULL)
|
copyStatement->relation->relname == NULL)
|
||||||
|
|
|
@ -26,6 +26,7 @@
|
||||||
#include "commands/extension.h"
|
#include "commands/extension.h"
|
||||||
#include "distributed/citus_ruleutils.h"
|
#include "distributed/citus_ruleutils.h"
|
||||||
#include "distributed/master_protocol.h"
|
#include "distributed/master_protocol.h"
|
||||||
|
#include "distributed/metadata_cache.h"
|
||||||
#include "distributed/multi_client_executor.h"
|
#include "distributed/multi_client_executor.h"
|
||||||
#include "distributed/multi_logical_optimizer.h"
|
#include "distributed/multi_logical_optimizer.h"
|
||||||
#include "distributed/multi_server_executor.h"
|
#include "distributed/multi_server_executor.h"
|
||||||
|
@ -46,10 +47,12 @@ bool ExpireCachedShards = false;
|
||||||
|
|
||||||
|
|
||||||
/* Local functions forward declarations */
|
/* Local functions forward declarations */
|
||||||
static void FetchRegularFile(const char *nodeName, uint32 nodePort,
|
static void FetchRegularFileAsSuperUser(const char *nodeName, uint32 nodePort,
|
||||||
StringInfo remoteFilename, StringInfo localFilename);
|
StringInfo remoteFilename,
|
||||||
|
StringInfo localFilename);
|
||||||
static bool ReceiveRegularFile(const char *nodeName, uint32 nodePort,
|
static bool ReceiveRegularFile(const char *nodeName, uint32 nodePort,
|
||||||
StringInfo transmitCommand, StringInfo filePath);
|
const char *nodeUser, StringInfo transmitCommand,
|
||||||
|
StringInfo filePath);
|
||||||
static void ReceiveResourceCleanup(int32 connectionId, const char *filename,
|
static void ReceiveResourceCleanup(int32 connectionId, const char *filename,
|
||||||
int32 fileDescriptor);
|
int32 fileDescriptor);
|
||||||
static void DeleteFile(const char *filename);
|
static void DeleteFile(const char *filename);
|
||||||
|
@ -115,7 +118,9 @@ worker_fetch_partition_file(PG_FUNCTION_ARGS)
|
||||||
}
|
}
|
||||||
|
|
||||||
nodeName = text_to_cstring(nodeNameText);
|
nodeName = text_to_cstring(nodeNameText);
|
||||||
FetchRegularFile(nodeName, nodePort, remoteFilename, taskFilename);
|
|
||||||
|
/* we've made sure the file names are sanitized, safe to fetch as superuser */
|
||||||
|
FetchRegularFileAsSuperUser(nodeName, nodePort, remoteFilename, taskFilename);
|
||||||
|
|
||||||
PG_RETURN_VOID();
|
PG_RETURN_VOID();
|
||||||
}
|
}
|
||||||
|
@ -156,7 +161,9 @@ worker_fetch_query_results_file(PG_FUNCTION_ARGS)
|
||||||
}
|
}
|
||||||
|
|
||||||
nodeName = text_to_cstring(nodeNameText);
|
nodeName = text_to_cstring(nodeNameText);
|
||||||
FetchRegularFile(nodeName, nodePort, remoteFilename, taskFilename);
|
|
||||||
|
/* we've made sure the file names are sanitized, safe to fetch as superuser */
|
||||||
|
FetchRegularFileAsSuperUser(nodeName, nodePort, remoteFilename, taskFilename);
|
||||||
|
|
||||||
PG_RETURN_VOID();
|
PG_RETURN_VOID();
|
||||||
}
|
}
|
||||||
|
@ -175,11 +182,16 @@ TaskFilename(StringInfo directoryName, uint32 taskId)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/* Helper function to transfer the remote file in an idempotent manner. */
|
/*
|
||||||
|
* FetchRegularFileAsSuperUser copies a file from a remote node in an idempotent
|
||||||
|
* manner. It connects to the remote node as superuser to give file access.
|
||||||
|
* Callers must make sure that the file names are sanitized.
|
||||||
|
*/
|
||||||
static void
|
static void
|
||||||
FetchRegularFile(const char *nodeName, uint32 nodePort,
|
FetchRegularFileAsSuperUser(const char *nodeName, uint32 nodePort,
|
||||||
StringInfo remoteFilename, StringInfo localFilename)
|
StringInfo remoteFilename, StringInfo localFilename)
|
||||||
{
|
{
|
||||||
|
char *nodeUser = NULL;
|
||||||
StringInfo attemptFilename = NULL;
|
StringInfo attemptFilename = NULL;
|
||||||
StringInfo transmitCommand = NULL;
|
StringInfo transmitCommand = NULL;
|
||||||
uint32 randomId = (uint32) random();
|
uint32 randomId = (uint32) random();
|
||||||
|
@ -198,7 +210,11 @@ FetchRegularFile(const char *nodeName, uint32 nodePort,
|
||||||
transmitCommand = makeStringInfo();
|
transmitCommand = makeStringInfo();
|
||||||
appendStringInfo(transmitCommand, TRANSMIT_REGULAR_COMMAND, remoteFilename->data);
|
appendStringInfo(transmitCommand, TRANSMIT_REGULAR_COMMAND, remoteFilename->data);
|
||||||
|
|
||||||
received = ReceiveRegularFile(nodeName, nodePort, transmitCommand, attemptFilename);
|
/* connect as superuser to give file access */
|
||||||
|
nodeUser = CitusExtensionOwnerName();
|
||||||
|
|
||||||
|
received = ReceiveRegularFile(nodeName, nodePort, nodeUser, transmitCommand,
|
||||||
|
attemptFilename);
|
||||||
if (!received)
|
if (!received)
|
||||||
{
|
{
|
||||||
ereport(ERROR, (errmsg("could not receive file \"%s\" from %s:%u",
|
ereport(ERROR, (errmsg("could not receive file \"%s\" from %s:%u",
|
||||||
|
@ -225,7 +241,7 @@ FetchRegularFile(const char *nodeName, uint32 nodePort,
|
||||||
* and returns false.
|
* and returns false.
|
||||||
*/
|
*/
|
||||||
static bool
|
static bool
|
||||||
ReceiveRegularFile(const char *nodeName, uint32 nodePort,
|
ReceiveRegularFile(const char *nodeName, uint32 nodePort, const char *nodeUser,
|
||||||
StringInfo transmitCommand, StringInfo filePath)
|
StringInfo transmitCommand, StringInfo filePath)
|
||||||
{
|
{
|
||||||
int32 fileDescriptor = -1;
|
int32 fileDescriptor = -1;
|
||||||
|
@ -257,7 +273,7 @@ ReceiveRegularFile(const char *nodeName, uint32 nodePort,
|
||||||
nodeDatabase = get_database_name(MyDatabaseId);
|
nodeDatabase = get_database_name(MyDatabaseId);
|
||||||
|
|
||||||
/* connect to remote node */
|
/* connect to remote node */
|
||||||
connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase, NULL);
|
connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase, nodeUser);
|
||||||
if (connectionId == INVALID_CONNECTION_ID)
|
if (connectionId == INVALID_CONNECTION_ID)
|
||||||
{
|
{
|
||||||
ReceiveResourceCleanup(connectionId, filename, fileDescriptor);
|
ReceiveResourceCleanup(connectionId, filename, fileDescriptor);
|
||||||
|
@ -745,7 +761,8 @@ FetchRegularTable(const char *nodeName, uint32 nodePort, const char *tableName)
|
||||||
remoteCopyCommand = makeStringInfo();
|
remoteCopyCommand = makeStringInfo();
|
||||||
appendStringInfo(remoteCopyCommand, COPY_OUT_COMMAND, tableName);
|
appendStringInfo(remoteCopyCommand, COPY_OUT_COMMAND, tableName);
|
||||||
|
|
||||||
received = ReceiveRegularFile(nodeName, nodePort, remoteCopyCommand, localFilePath);
|
received = ReceiveRegularFile(nodeName, nodePort, NULL, remoteCopyCommand,
|
||||||
|
localFilePath);
|
||||||
if (!received)
|
if (!received)
|
||||||
{
|
{
|
||||||
return false;
|
return false;
|
||||||
|
@ -820,6 +837,7 @@ FetchRegularTable(const char *nodeName, uint32 nodePort, const char *tableName)
|
||||||
static bool
|
static bool
|
||||||
FetchForeignTable(const char *nodeName, uint32 nodePort, const char *tableName)
|
FetchForeignTable(const char *nodeName, uint32 nodePort, const char *tableName)
|
||||||
{
|
{
|
||||||
|
const char *nodeUser = NULL;
|
||||||
StringInfo localFilePath = NULL;
|
StringInfo localFilePath = NULL;
|
||||||
StringInfo remoteFilePath = NULL;
|
StringInfo remoteFilePath = NULL;
|
||||||
StringInfo transmitCommand = NULL;
|
StringInfo transmitCommand = NULL;
|
||||||
|
@ -846,7 +864,16 @@ FetchForeignTable(const char *nodeName, uint32 nodePort, const char *tableName)
|
||||||
transmitCommand = makeStringInfo();
|
transmitCommand = makeStringInfo();
|
||||||
appendStringInfo(transmitCommand, TRANSMIT_REGULAR_COMMAND, remoteFilePath->data);
|
appendStringInfo(transmitCommand, TRANSMIT_REGULAR_COMMAND, remoteFilePath->data);
|
||||||
|
|
||||||
received = ReceiveRegularFile(nodeName, nodePort, transmitCommand, localFilePath);
|
/*
|
||||||
|
* We allow some arbitrary input in the file name and connect to the remote
|
||||||
|
* node as superuser to transmit. Therefore, we only allow calling this
|
||||||
|
* function when already running as superuser.
|
||||||
|
*/
|
||||||
|
EnsureSuperUser();
|
||||||
|
nodeUser = CitusExtensionOwnerName();
|
||||||
|
|
||||||
|
received = ReceiveRegularFile(nodeName, nodePort, nodeUser, transmitCommand,
|
||||||
|
localFilePath);
|
||||||
if (!received)
|
if (!received)
|
||||||
{
|
{
|
||||||
return false;
|
return false;
|
||||||
|
@ -1183,7 +1210,7 @@ worker_append_table_to_shard(PG_FUNCTION_ARGS)
|
||||||
sourceCopyCommand = makeStringInfo();
|
sourceCopyCommand = makeStringInfo();
|
||||||
appendStringInfo(sourceCopyCommand, COPY_OUT_COMMAND, sourceQualifiedName);
|
appendStringInfo(sourceCopyCommand, COPY_OUT_COMMAND, sourceQualifiedName);
|
||||||
|
|
||||||
received = ReceiveRegularFile(sourceNodeName, sourceNodePort, sourceCopyCommand,
|
received = ReceiveRegularFile(sourceNodeName, sourceNodePort, NULL, sourceCopyCommand,
|
||||||
localFilePath);
|
localFilePath);
|
||||||
if (!received)
|
if (!received)
|
||||||
{
|
{
|
||||||
|
|
|
@ -100,7 +100,7 @@ typedef struct WaitInfo
|
||||||
extern int32 MultiClientConnect(const char *nodeName, uint32 nodePort,
|
extern int32 MultiClientConnect(const char *nodeName, uint32 nodePort,
|
||||||
const char *nodeDatabase, const char *nodeUser);
|
const char *nodeDatabase, const char *nodeUser);
|
||||||
extern int32 MultiClientConnectStart(const char *nodeName, uint32 nodePort,
|
extern int32 MultiClientConnectStart(const char *nodeName, uint32 nodePort,
|
||||||
const char *nodeDatabase);
|
const char *nodeDatabase, const char *nodeUser);
|
||||||
extern ConnectStatus MultiClientConnectPoll(int32 connectionId);
|
extern ConnectStatus MultiClientConnectPoll(int32 connectionId);
|
||||||
extern void MultiClientDisconnect(int32 connectionId);
|
extern void MultiClientDisconnect(int32 connectionId);
|
||||||
extern bool MultiClientConnectionUp(int32 connectionId);
|
extern bool MultiClientConnectionUp(int32 connectionId);
|
||||||
|
|
|
@ -153,6 +153,7 @@ typedef struct TaskTracker
|
||||||
{
|
{
|
||||||
uint32 workerPort; /* node's port; part of hash table key */
|
uint32 workerPort; /* node's port; part of hash table key */
|
||||||
char workerName[WORKER_LENGTH]; /* node's name; part of hash table key */
|
char workerName[WORKER_LENGTH]; /* node's name; part of hash table key */
|
||||||
|
char *userName; /* which user to connect as */
|
||||||
TrackerStatus trackerStatus;
|
TrackerStatus trackerStatus;
|
||||||
int32 connectionId;
|
int32 connectionId;
|
||||||
uint32 connectPollCount;
|
uint32 connectPollCount;
|
||||||
|
|
Loading…
Reference in New Issue