From 42d232c0e80656cb35fe07aeef4d40426bcc7e1f Mon Sep 17 00:00:00 2001 From: Andres Freund Date: Thu, 25 Feb 2016 18:40:11 -0800 Subject: [PATCH] Use the current session's username when connecting to worker nodes. So far we've always used libpq defaults when connecting to workers; bar special environment variables being set that'll always be the user that started the server. That's not desirable because it prevents using users with fewer privileges. Thus change the various APIs creating connections to workers to always use usernames. That means: 1) MultiClientConnect() needs to, optionally, accept a username 2) GetOrEstablishConnection(), including the underlying cache, need to use the current user as part of the connection cache key. That way connections for separate users are distinct, and we always use one with the correct authorization. 3) The task tracker needs to keep track of the username associated with a task, so it can use it when establishing connections outside the originating session. --- src/backend/distributed/commands/multi_copy.c | 3 +- .../executor/multi_client_executor.c | 49 +++++++++++++++++-- .../master/master_delete_protocol.c | 2 +- .../distributed/master/worker_node_manager.c | 3 +- .../planner/multi_physical_planner.c | 2 + .../distributed/utils/connection_cache.c | 25 +++++++--- .../distributed/utils/metadata_cache.c | 16 ++++++ src/backend/distributed/worker/task_tracker.c | 9 ++-- .../worker/task_tracker_protocol.c | 3 ++ .../worker/worker_data_fetch_protocol.c | 5 +- src/include/distributed/connection_cache.h | 3 +- src/include/distributed/metadata_cache.h | 2 + .../distributed/multi_client_executor.h | 4 +- src/include/distributed/task_tracker.h | 1 + 14 files changed, 103 insertions(+), 24 deletions(-) diff --git a/src/backend/distributed/commands/multi_copy.c b/src/backend/distributed/commands/multi_copy.c index 42bbec3c7..5daae51f3 100644 --- a/src/backend/distributed/commands/multi_copy.c +++ b/src/backend/distributed/commands/multi_copy.c @@ -793,11 +793,12 @@ OpenCopyTransactions(CopyStmt *copyStatement, ShardConnections *shardConnections ShardPlacement *placement = (ShardPlacement *) lfirst(placementCell); char *nodeName = placement->nodeName; int nodePort = placement->nodePort; + char *nodeUser = CurrentUserName(); TransactionConnection *transactionConnection = NULL; StringInfo copyCommand = NULL; PGresult *result = NULL; - PGconn *connection = ConnectToNode(nodeName, nodePort); + PGconn *connection = ConnectToNode(nodeName, nodePort, nodeUser); /* release failed placement list and copy command at the end of this function */ oldContext = MemoryContextSwitchTo(localContext); diff --git a/src/backend/distributed/executor/multi_client_executor.c b/src/backend/distributed/executor/multi_client_executor.c index 286dab089..361688348 100644 --- a/src/backend/distributed/executor/multi_client_executor.c +++ b/src/backend/distributed/executor/multi_client_executor.c @@ -15,6 +15,10 @@ #include "postgres.h" #include "fmgr.h" #include "libpq-fe.h" +#include "miscadmin.h" + +#include "commands/dbcommands.h" +#include "distributed/metadata_cache.h" #include "distributed/multi_client_executor.h" #include @@ -76,24 +80,57 @@ AllocateConnectionId(void) * MultiClientConnect synchronously tries to establish a connection. If it * succeeds, it returns the connection id. Otherwise, it reports connection * error and returns INVALID_CONNECTION_ID. + * + * nodeDatabase and userName can be NULL, in which case values from the + * current session are used. */ int32 -MultiClientConnect(const char *nodeName, uint32 nodePort, const char *nodeDatabase) +MultiClientConnect(const char *nodeName, uint32 nodePort, const char *nodeDatabase, + const char *userName) { PGconn *connection = NULL; char connInfoString[STRING_BUFFER_SIZE]; ConnStatusType connStatusType = CONNECTION_OK; - int32 connectionId = AllocateConnectionId(); + char *effectiveDatabaseName = NULL; + char *effectiveUserName = NULL; + if (connectionId == INVALID_CONNECTION_ID) { ereport(WARNING, (errmsg("could not allocate connection in connection pool"))); return connectionId; } + if (nodeDatabase == NULL) + { + effectiveDatabaseName = get_database_name(MyDatabaseId); + } + else + { + effectiveDatabaseName = pstrdup(nodeDatabase); + } + + if (userName == NULL) + { + effectiveUserName = CurrentUserName(); + } + else + { + effectiveUserName = pstrdup(userName); + } + + /* + * FIXME: This code is bad on several levels. It completely forgoes any + * escaping, it misses setting a number of parameters, it works with a + * limited string size without erroring when it's too long. We shouldn't + * even build a query string this way, there's PQconnectdbParams()! + */ + /* transcribe connection paremeters to string */ snprintf(connInfoString, STRING_BUFFER_SIZE, CONN_INFO_TEMPLATE, - nodeName, nodePort, nodeDatabase, CLIENT_CONNECT_TIMEOUT); + nodeName, nodePort, + effectiveDatabaseName, effectiveUserName, + CLIENT_CONNECT_TIMEOUT); /* establish synchronous connection to worker node */ connection = PQconnectdb(connInfoString); @@ -111,6 +148,9 @@ MultiClientConnect(const char *nodeName, uint32 nodePort, const char *nodeDataba connectionId = INVALID_CONNECTION_ID; } + pfree(effectiveDatabaseName); + pfree(effectiveUserName); + return connectionId; } @@ -126,6 +166,7 @@ MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeD PGconn *connection = NULL; char connInfoString[STRING_BUFFER_SIZE]; ConnStatusType connStatusType = CONNECTION_BAD; + char *userName = CurrentUserName(); int32 connectionId = AllocateConnectionId(); if (connectionId == INVALID_CONNECTION_ID) @@ -136,7 +177,7 @@ MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeD /* transcribe connection paremeters to string */ snprintf(connInfoString, STRING_BUFFER_SIZE, CONN_INFO_TEMPLATE, - nodeName, nodePort, nodeDatabase, CLIENT_CONNECT_TIMEOUT); + nodeName, nodePort, nodeDatabase, userName, CLIENT_CONNECT_TIMEOUT); /* prepare asynchronous request for worker node connection */ connection = PQconnectStart(connInfoString); diff --git a/src/backend/distributed/master/master_delete_protocol.c b/src/backend/distributed/master/master_delete_protocol.c index 8be57af52..3a9efae6f 100644 --- a/src/backend/distributed/master/master_delete_protocol.c +++ b/src/backend/distributed/master/master_delete_protocol.c @@ -494,7 +494,7 @@ ExecuteRemoteCommand(const char *nodeName, uint32 nodePort, StringInfo queryStri bool queryReady = false; bool queryDone = false; - connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase); + connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase, NULL); if (connectionId == INVALID_CONNECTION_ID) { return false; diff --git a/src/backend/distributed/master/worker_node_manager.c b/src/backend/distributed/master/worker_node_manager.c index 04ba68658..d9eb76734 100644 --- a/src/backend/distributed/master/worker_node_manager.c +++ b/src/backend/distributed/master/worker_node_manager.c @@ -793,9 +793,8 @@ static bool WorkerNodeResponsive(const char *workerName, uint32 workerPort) { bool workerNodeResponsive = false; - const char *databaseName = get_database_name(MyDatabaseId); - int connectionId = MultiClientConnect(workerName, workerPort, databaseName); + int connectionId = MultiClientConnect(workerName, workerPort, NULL, NULL); if (connectionId != INVALID_CONNECTION_ID) { MultiClientDisconnect(connectionId); diff --git a/src/backend/distributed/planner/multi_physical_planner.c b/src/backend/distributed/planner/multi_physical_planner.c index 9d927c99f..62eb14a1a 100644 --- a/src/backend/distributed/planner/multi_physical_planner.c +++ b/src/backend/distributed/planner/multi_physical_planner.c @@ -15,6 +15,8 @@ #include +#include "miscadmin.h" + #include "access/genam.h" #include "access/hash.h" #include "access/heapam.h" diff --git a/src/backend/distributed/utils/connection_cache.c b/src/backend/distributed/utils/connection_cache.c index db419c457..a93b21ec8 100644 --- a/src/backend/distributed/utils/connection_cache.c +++ b/src/backend/distributed/utils/connection_cache.c @@ -20,6 +20,7 @@ #include "commands/dbcommands.h" #include "distributed/connection_cache.h" +#include "distributed/metadata_cache.h" #include "lib/stringinfo.h" #include "mb/pg_wchar.h" #include "utils/builtins.h" @@ -61,6 +62,7 @@ GetOrEstablishConnection(char *nodeName, int32 nodePort) NodeConnectionEntry *nodeConnectionEntry = NULL; bool entryFound = false; bool needNewConnection = true; + char *userName = CurrentUserName(); /* check input */ if (strnlen(nodeName, MAX_NODE_LENGTH + 1) > MAX_NODE_LENGTH) @@ -79,6 +81,7 @@ GetOrEstablishConnection(char *nodeName, int32 nodePort) memset(&nodeConnectionKey, 0, sizeof(nodeConnectionKey)); strncpy(nodeConnectionKey.nodeName, nodeName, MAX_NODE_LENGTH); nodeConnectionKey.nodePort = nodePort; + strncpy(nodeConnectionKey.nodeUser, userName, NAMEDATALEN); nodeConnectionEntry = hash_search(NodeConnectionHash, &nodeConnectionKey, HASH_FIND, &entryFound); @@ -97,7 +100,7 @@ GetOrEstablishConnection(char *nodeName, int32 nodePort) if (needNewConnection) { - connection = ConnectToNode(nodeName, nodePort); + connection = ConnectToNode(nodeName, nodePort, nodeConnectionKey.nodeUser); if (connection != NULL) { nodeConnectionEntry = hash_search(NodeConnectionHash, &nodeConnectionKey, @@ -123,6 +126,7 @@ PurgeConnection(PGconn *connection) bool entryFound = false; char *nodeNameString = NULL; char *nodePortString = NULL; + char *nodeUserString = NULL; nodeNameString = ConnectionGetOptionValue(connection, "host"); if (nodeNameString == NULL) @@ -138,12 +142,21 @@ PurgeConnection(PGconn *connection) errmsg("connection is missing port option"))); } + nodeUserString = ConnectionGetOptionValue(connection, "user"); + if (nodeUserString == NULL) + { + ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + errmsg("connection is missing user option"))); + } + memset(&nodeConnectionKey, 0, sizeof(nodeConnectionKey)); strncpy(nodeConnectionKey.nodeName, nodeNameString, MAX_NODE_LENGTH); nodeConnectionKey.nodePort = pg_atoi(nodePortString, sizeof(int32), 0); + strncpy(nodeConnectionKey.nodeUser, nodeUserString, NAMEDATALEN); pfree(nodeNameString); pfree(nodePortString); + pfree(nodeUserString); nodeConnectionEntry = hash_search(NodeConnectionHash, &nodeConnectionKey, HASH_REMOVE, &entryFound); @@ -253,14 +266,14 @@ CreateNodeConnectionHash(void) /* * ConnectToNode opens a connection to a remote PostgreSQL server. The function * configures the connection's fallback application name to 'citus' and sets - * the remote encoding to match the local one. This function requires that the - * port be specified as a string for easier use with libpq functions. + * the remote encoding to match the local one. All parameters are required to + * be non NULL. * * We attempt to connect up to MAX_CONNECT_ATTEMPT times. After that we give up * and return NULL. */ PGconn * -ConnectToNode(char *nodeName, int32 nodePort) +ConnectToNode(char *nodeName, int32 nodePort, char *nodeUser) { PGconn *connection = NULL; const char *clientEncoding = GetDatabaseEncodingName(); @@ -269,12 +282,12 @@ ConnectToNode(char *nodeName, int32 nodePort) const char *keywordArray[] = { "host", "port", "fallback_application_name", - "client_encoding", "connect_timeout", "dbname", NULL + "client_encoding", "connect_timeout", "dbname", "user", NULL }; char nodePortString[12]; const char *valueArray[] = { nodeName, nodePortString, "citus", clientEncoding, - CLIENT_CONNECT_TIMEOUT_SECONDS, dbname, NULL + CLIENT_CONNECT_TIMEOUT_SECONDS, dbname, nodeUser, NULL }; sprintf(nodePortString, "%d", nodePort); diff --git a/src/backend/distributed/utils/metadata_cache.c b/src/backend/distributed/utils/metadata_cache.c index 9c8a81edc..9260dde00 100644 --- a/src/backend/distributed/utils/metadata_cache.c +++ b/src/backend/distributed/utils/metadata_cache.c @@ -9,6 +9,8 @@ #include "postgres.h" +#include "miscadmin.h" + #include "access/genam.h" #include "access/heapam.h" #include "access/htup_details.h" @@ -649,6 +651,20 @@ CitusExtraDataContainerFuncId(void) } +/* return the username of the currently active role */ +char * +CurrentUserName(void) +{ + Oid userId = GetUserId(); + +#if (PG_VERSION_NUM < 90500) + return GetUserNameFromId(userId); +#else + return GetUserNameFromId(userId, false); +#endif +} + + /* * master_dist_partition_cache_invalidate is a trigger function that performs * relcache invalidations when the contents of pg_dist_partition are changed diff --git a/src/backend/distributed/worker/task_tracker.c b/src/backend/distributed/worker/task_tracker.c index aa7d9f9ec..3db6194af 100644 --- a/src/backend/distributed/worker/task_tracker.c +++ b/src/backend/distributed/worker/task_tracker.c @@ -88,7 +88,7 @@ static void ManageWorkerTasksHash(HTAB *WorkerTasksHash); static void ManageWorkerTask(WorkerTask *workerTask, HTAB *WorkerTasksHash); static void RemoveWorkerTask(WorkerTask *workerTask, HTAB *WorkerTasksHash); static void CreateJobDirectoryIfNotExists(uint64 jobId); -static int32 ConnectToLocalBackend(const char *databaseName); +static int32 ConnectToLocalBackend(const char *databaseName, const char *userName); /* Organize, at startup, that the task tracker is started */ @@ -904,7 +904,8 @@ ManageWorkerTask(WorkerTask *workerTask, HTAB *WorkerTasksHash) CreateJobDirectoryIfNotExists(workerTask->jobId); /* the task is ready to run; connect to local backend */ - workerTask->connectionId = ConnectToLocalBackend(workerTask->databaseName); + workerTask->connectionId = ConnectToLocalBackend(workerTask->databaseName, + workerTask->userName); if (workerTask->connectionId != INVALID_CONNECTION_ID) { @@ -1082,7 +1083,7 @@ CreateJobDirectoryIfNotExists(uint64 jobId) /* Wrapper function to inititate connection to local backend. */ static int32 -ConnectToLocalBackend(const char *databaseName) +ConnectToLocalBackend(const char *databaseName, const char *userName) { const char *nodeName = LOCAL_HOST_NAME; const uint32 nodePort = PostPortNumber; @@ -1091,7 +1092,7 @@ ConnectToLocalBackend(const char *databaseName) * Our client library currently only handles TCP sockets. We therefore do * not use Unix domain sockets here. */ - int32 connectionId = MultiClientConnect(nodeName, nodePort, databaseName); + int32 connectionId = MultiClientConnect(nodeName, nodePort, databaseName, userName); return connectionId; } diff --git a/src/backend/distributed/worker/task_tracker_protocol.c b/src/backend/distributed/worker/task_tracker_protocol.c index ba36fa737..a2b2628ea 100644 --- a/src/backend/distributed/worker/task_tracker_protocol.c +++ b/src/backend/distributed/worker/task_tracker_protocol.c @@ -22,6 +22,7 @@ #include "access/xact.h" #include "commands/dbcommands.h" #include "commands/schemacmds.h" +#include "distributed/metadata_cache.h" #include "distributed/multi_client_executor.h" #include "distributed/multi_server_executor.h" #include "distributed/resource_lock.h" @@ -288,6 +289,7 @@ CreateTask(uint64 jobId, uint32 taskId, char *taskCallString) WorkerTask *workerTask = NULL; uint32 assignmentTime = 0; char *databaseName = get_database_name(MyDatabaseId); + char *userName = CurrentUserName(); /* increase task priority for cleanup tasks */ assignmentTime = (uint32) time(NULL); @@ -305,6 +307,7 @@ CreateTask(uint64 jobId, uint32 taskId, char *taskCallString) workerTask->connectionId = INVALID_CONNECTION_ID; workerTask->failureCount = 0; strncpy(workerTask->databaseName, databaseName, NAMEDATALEN); + strncpy(workerTask->userName, userName, NAMEDATALEN); } diff --git a/src/backend/distributed/worker/worker_data_fetch_protocol.c b/src/backend/distributed/worker/worker_data_fetch_protocol.c index d6f0773a0..3bb557c84 100644 --- a/src/backend/distributed/worker/worker_data_fetch_protocol.c +++ b/src/backend/distributed/worker/worker_data_fetch_protocol.c @@ -254,7 +254,7 @@ ReceiveRegularFile(const char *nodeName, uint32 nodePort, nodeDatabase = get_database_name(MyDatabaseId); /* connect to remote node */ - connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase); + connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase, NULL); if (connectionId == INVALID_CONNECTION_ID) { ReceiveResourceCleanup(connectionId, filename, fileDescriptor); @@ -870,7 +870,6 @@ ForeignFilePath(const char *nodeName, uint32 nodePort, StringInfo tableName) List * ExecuteRemoteQuery(const char *nodeName, uint32 nodePort, StringInfo queryString) { - char *nodeDatabase = get_database_name(MyDatabaseId); int32 connectionId = -1; bool querySent = false; bool queryReady = false; @@ -881,7 +880,7 @@ ExecuteRemoteQuery(const char *nodeName, uint32 nodePort, StringInfo queryString int columnCount = 0; List *resultList = NIL; - connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase); + connectionId = MultiClientConnect(nodeName, nodePort, NULL, NULL); if (connectionId == INVALID_CONNECTION_ID) { return NIL; diff --git a/src/include/distributed/connection_cache.h b/src/include/distributed/connection_cache.h index e39a0c406..ecc6b6c7b 100644 --- a/src/include/distributed/connection_cache.h +++ b/src/include/distributed/connection_cache.h @@ -39,6 +39,7 @@ typedef struct NodeConnectionKey { char nodeName[MAX_NODE_LENGTH + 1]; /* hostname of host to connect to */ int32 nodePort; /* port of host to connect to */ + char nodeUser[NAMEDATALEN + 1]; /* user name to connect as */ } NodeConnectionKey; @@ -54,7 +55,7 @@ typedef struct NodeConnectionEntry extern PGconn * GetOrEstablishConnection(char *nodeName, int32 nodePort); extern void PurgeConnection(PGconn *connection); extern void ReportRemoteError(PGconn *connection, PGresult *result); -extern PGconn * ConnectToNode(char *nodeName, int nodePort); +extern PGconn * ConnectToNode(char *nodeName, int nodePort, char *nodeUser); extern char * ConnectionGetOptionValue(PGconn *connection, char *optionKeyword); diff --git a/src/include/distributed/metadata_cache.h b/src/include/distributed/metadata_cache.h index 32485e152..8df923734 100644 --- a/src/include/distributed/metadata_cache.h +++ b/src/include/distributed/metadata_cache.h @@ -69,4 +69,6 @@ extern Oid DistShardPlacementShardidIndexId(void); /* function oids */ extern Oid CitusExtraDataContainerFuncId(void); +/* user related functions */ +extern char * CurrentUserName(void); #endif /* METADATA_CACHE_H */ diff --git a/src/include/distributed/multi_client_executor.h b/src/include/distributed/multi_client_executor.h index 7e91d715d..1486cd692 100644 --- a/src/include/distributed/multi_client_executor.h +++ b/src/include/distributed/multi_client_executor.h @@ -19,7 +19,7 @@ #define CLIENT_CONNECT_TIMEOUT 5 /* connection timeout in seconds */ #define MAX_CONNECTION_COUNT 2048 /* simultaneous client connection count */ #define STRING_BUFFER_SIZE 1024 /* buffer size for character arrays */ -#define CONN_INFO_TEMPLATE "host=%s port=%u dbname=%s connect_timeout=%u" +#define CONN_INFO_TEMPLATE "host=%s port=%u dbname=%s user=%s connect_timeout=%u" /* Enumeration to track one client connection's status */ @@ -74,7 +74,7 @@ typedef enum /* Function declarations for executing client-side (libpq) logic. */ extern int32 MultiClientConnect(const char *nodeName, uint32 nodePort, - const char *nodeDatabase); + const char *nodeDatabase, const char *nodeUser); extern int32 MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeDatabase); extern ConnectStatus MultiClientConnectPoll(int32 connectionId); diff --git a/src/include/distributed/task_tracker.h b/src/include/distributed/task_tracker.h index 1b04ccb31..81fa4fe43 100644 --- a/src/include/distributed/task_tracker.h +++ b/src/include/distributed/task_tracker.h @@ -82,6 +82,7 @@ typedef struct WorkerTask char taskCallString[TASK_CALL_STRING_SIZE]; /* query or function call string */ TaskStatus taskStatus; /* task's current execution status */ char databaseName[NAMEDATALEN]; /* name to use for local backend connection */ + char userName[NAMEDATALEN]; /* user to use for local backend connection */ int32 connectionId; /* connection id to local backend */ uint32 failureCount; /* number of task failures */ } WorkerTask;