diff --git a/src/backend/distributed/commands/multi_copy.c b/src/backend/distributed/commands/multi_copy.c index 42bbec3c7..5daae51f3 100644 --- a/src/backend/distributed/commands/multi_copy.c +++ b/src/backend/distributed/commands/multi_copy.c @@ -793,11 +793,12 @@ OpenCopyTransactions(CopyStmt *copyStatement, ShardConnections *shardConnections ShardPlacement *placement = (ShardPlacement *) lfirst(placementCell); char *nodeName = placement->nodeName; int nodePort = placement->nodePort; + char *nodeUser = CurrentUserName(); TransactionConnection *transactionConnection = NULL; StringInfo copyCommand = NULL; PGresult *result = NULL; - PGconn *connection = ConnectToNode(nodeName, nodePort); + PGconn *connection = ConnectToNode(nodeName, nodePort, nodeUser); /* release failed placement list and copy command at the end of this function */ oldContext = MemoryContextSwitchTo(localContext); diff --git a/src/backend/distributed/executor/multi_client_executor.c b/src/backend/distributed/executor/multi_client_executor.c index 286dab089..361688348 100644 --- a/src/backend/distributed/executor/multi_client_executor.c +++ b/src/backend/distributed/executor/multi_client_executor.c @@ -15,6 +15,10 @@ #include "postgres.h" #include "fmgr.h" #include "libpq-fe.h" +#include "miscadmin.h" + +#include "commands/dbcommands.h" +#include "distributed/metadata_cache.h" #include "distributed/multi_client_executor.h" #include @@ -76,24 +80,57 @@ AllocateConnectionId(void) * MultiClientConnect synchronously tries to establish a connection. If it * succeeds, it returns the connection id. Otherwise, it reports connection * error and returns INVALID_CONNECTION_ID. + * + * nodeDatabase and userName can be NULL, in which case values from the + * current session are used. 
*/ int32 -MultiClientConnect(const char *nodeName, uint32 nodePort, const char *nodeDatabase) +MultiClientConnect(const char *nodeName, uint32 nodePort, const char *nodeDatabase, + const char *userName) { PGconn *connection = NULL; char connInfoString[STRING_BUFFER_SIZE]; ConnStatusType connStatusType = CONNECTION_OK; - int32 connectionId = AllocateConnectionId(); + char *effectiveDatabaseName = NULL; + char *effectiveUserName = NULL; + if (connectionId == INVALID_CONNECTION_ID) { ereport(WARNING, (errmsg("could not allocate connection in connection pool"))); return connectionId; } + if (nodeDatabase == NULL) + { + effectiveDatabaseName = get_database_name(MyDatabaseId); + } + else + { + effectiveDatabaseName = pstrdup(nodeDatabase); + } + + if (userName == NULL) + { + effectiveUserName = CurrentUserName(); + } + else + { + effectiveUserName = pstrdup(userName); + } + + /* + * FIXME: This code is bad on several levels. It completely forgoes any + * escaping, it misses setting a number of parameters, it works with a + * limited string size without erroring when it's too long. We shouldn't + * even build a query string this way, there's PQconnectdbParams()! 
+ */ + /* transcribe connection paremeters to string */ snprintf(connInfoString, STRING_BUFFER_SIZE, CONN_INFO_TEMPLATE, - nodeName, nodePort, nodeDatabase, CLIENT_CONNECT_TIMEOUT); + nodeName, nodePort, + effectiveDatabaseName, effectiveUserName, + CLIENT_CONNECT_TIMEOUT); /* establish synchronous connection to worker node */ connection = PQconnectdb(connInfoString); @@ -111,6 +148,9 @@ MultiClientConnect(const char *nodeName, uint32 nodePort, const char *nodeDataba connectionId = INVALID_CONNECTION_ID; } + pfree(effectiveDatabaseName); + pfree(effectiveUserName); + return connectionId; } @@ -126,6 +166,7 @@ MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeD PGconn *connection = NULL; char connInfoString[STRING_BUFFER_SIZE]; ConnStatusType connStatusType = CONNECTION_BAD; + char *userName = CurrentUserName(); int32 connectionId = AllocateConnectionId(); if (connectionId == INVALID_CONNECTION_ID) @@ -136,7 +177,7 @@ MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeD /* transcribe connection paremeters to string */ snprintf(connInfoString, STRING_BUFFER_SIZE, CONN_INFO_TEMPLATE, - nodeName, nodePort, nodeDatabase, CLIENT_CONNECT_TIMEOUT); + nodeName, nodePort, nodeDatabase, userName, CLIENT_CONNECT_TIMEOUT); /* prepare asynchronous request for worker node connection */ connection = PQconnectStart(connInfoString); diff --git a/src/backend/distributed/master/master_delete_protocol.c b/src/backend/distributed/master/master_delete_protocol.c index 8be57af52..3a9efae6f 100644 --- a/src/backend/distributed/master/master_delete_protocol.c +++ b/src/backend/distributed/master/master_delete_protocol.c @@ -494,7 +494,7 @@ ExecuteRemoteCommand(const char *nodeName, uint32 nodePort, StringInfo queryStri bool queryReady = false; bool queryDone = false; - connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase); + connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase, NULL); if (connectionId == 
INVALID_CONNECTION_ID) { return false; diff --git a/src/backend/distributed/master/worker_node_manager.c b/src/backend/distributed/master/worker_node_manager.c index 04ba68658..d9eb76734 100644 --- a/src/backend/distributed/master/worker_node_manager.c +++ b/src/backend/distributed/master/worker_node_manager.c @@ -793,9 +793,8 @@ static bool WorkerNodeResponsive(const char *workerName, uint32 workerPort) { bool workerNodeResponsive = false; - const char *databaseName = get_database_name(MyDatabaseId); - int connectionId = MultiClientConnect(workerName, workerPort, databaseName); + int connectionId = MultiClientConnect(workerName, workerPort, NULL, NULL); if (connectionId != INVALID_CONNECTION_ID) { MultiClientDisconnect(connectionId); diff --git a/src/backend/distributed/planner/multi_physical_planner.c b/src/backend/distributed/planner/multi_physical_planner.c index 9d927c99f..62eb14a1a 100644 --- a/src/backend/distributed/planner/multi_physical_planner.c +++ b/src/backend/distributed/planner/multi_physical_planner.c @@ -15,6 +15,8 @@ #include +#include "miscadmin.h" + #include "access/genam.h" #include "access/hash.h" #include "access/heapam.h" diff --git a/src/backend/distributed/utils/connection_cache.c b/src/backend/distributed/utils/connection_cache.c index db419c457..a93b21ec8 100644 --- a/src/backend/distributed/utils/connection_cache.c +++ b/src/backend/distributed/utils/connection_cache.c @@ -20,6 +20,7 @@ #include "commands/dbcommands.h" #include "distributed/connection_cache.h" +#include "distributed/metadata_cache.h" #include "lib/stringinfo.h" #include "mb/pg_wchar.h" #include "utils/builtins.h" @@ -61,6 +62,7 @@ GetOrEstablishConnection(char *nodeName, int32 nodePort) NodeConnectionEntry *nodeConnectionEntry = NULL; bool entryFound = false; bool needNewConnection = true; + char *userName = CurrentUserName(); /* check input */ if (strnlen(nodeName, MAX_NODE_LENGTH + 1) > MAX_NODE_LENGTH) @@ -79,6 +81,7 @@ GetOrEstablishConnection(char *nodeName, int32 
nodePort) memset(&nodeConnectionKey, 0, sizeof(nodeConnectionKey)); strncpy(nodeConnectionKey.nodeName, nodeName, MAX_NODE_LENGTH); nodeConnectionKey.nodePort = nodePort; + strncpy(nodeConnectionKey.nodeUser, userName, NAMEDATALEN); nodeConnectionEntry = hash_search(NodeConnectionHash, &nodeConnectionKey, HASH_FIND, &entryFound); @@ -97,7 +100,7 @@ GetOrEstablishConnection(char *nodeName, int32 nodePort) if (needNewConnection) { - connection = ConnectToNode(nodeName, nodePort); + connection = ConnectToNode(nodeName, nodePort, nodeConnectionKey.nodeUser); if (connection != NULL) { nodeConnectionEntry = hash_search(NodeConnectionHash, &nodeConnectionKey, @@ -123,6 +126,7 @@ PurgeConnection(PGconn *connection) bool entryFound = false; char *nodeNameString = NULL; char *nodePortString = NULL; + char *nodeUserString = NULL; nodeNameString = ConnectionGetOptionValue(connection, "host"); if (nodeNameString == NULL) @@ -138,12 +142,21 @@ PurgeConnection(PGconn *connection) errmsg("connection is missing port option"))); } + nodeUserString = ConnectionGetOptionValue(connection, "user"); + if (nodeUserString == NULL) + { + ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE), + errmsg("connection is missing user option"))); + } + memset(&nodeConnectionKey, 0, sizeof(nodeConnectionKey)); strncpy(nodeConnectionKey.nodeName, nodeNameString, MAX_NODE_LENGTH); nodeConnectionKey.nodePort = pg_atoi(nodePortString, sizeof(int32), 0); + strncpy(nodeConnectionKey.nodeUser, nodeUserString, NAMEDATALEN); pfree(nodeNameString); pfree(nodePortString); + pfree(nodeUserString); nodeConnectionEntry = hash_search(NodeConnectionHash, &nodeConnectionKey, HASH_REMOVE, &entryFound); @@ -253,14 +266,14 @@ CreateNodeConnectionHash(void) /* * ConnectToNode opens a connection to a remote PostgreSQL server. The function * configures the connection's fallback application name to 'citus' and sets - * the remote encoding to match the local one. 
This function requires that the - port be specified as a string for easier use with libpq functions. + * the remote encoding to match the local one. All parameters are required to + * be non-NULL. * * We attempt to connect up to MAX_CONNECT_ATTEMPT times. After that we give up * and return NULL. */ PGconn * -ConnectToNode(char *nodeName, int32 nodePort) +ConnectToNode(char *nodeName, int32 nodePort, char *nodeUser) { PGconn *connection = NULL; const char *clientEncoding = GetDatabaseEncodingName(); @@ -269,12 +282,12 @@ ConnectToNode(char *nodeName, int32 nodePort) const char *keywordArray[] = { "host", "port", "fallback_application_name", - "client_encoding", "connect_timeout", "dbname", NULL + "client_encoding", "connect_timeout", "dbname", "user", NULL }; char nodePortString[12]; const char *valueArray[] = { nodeName, nodePortString, "citus", clientEncoding, - CLIENT_CONNECT_TIMEOUT_SECONDS, dbname, NULL + CLIENT_CONNECT_TIMEOUT_SECONDS, dbname, nodeUser, NULL }; sprintf(nodePortString, "%d", nodePort); diff --git a/src/backend/distributed/utils/metadata_cache.c b/src/backend/distributed/utils/metadata_cache.c index 9c8a81edc..9260dde00 100644 --- a/src/backend/distributed/utils/metadata_cache.c +++ b/src/backend/distributed/utils/metadata_cache.c @@ -9,6 +9,8 @@ #include "postgres.h" +#include "miscadmin.h" + #include "access/genam.h" #include "access/heapam.h" #include "access/htup_details.h" @@ -649,6 +651,20 @@ CitusExtraDataContainerFuncId(void) } +/* return the username of the currently active role */ +char * +CurrentUserName(void) +{ + Oid userId = GetUserId(); + +#if (PG_VERSION_NUM < 90500) + return GetUserNameFromId(userId); +#else + return GetUserNameFromId(userId, false); +#endif +} + + /* * master_dist_partition_cache_invalidate is a trigger function that performs * relcache invalidations when the contents of pg_dist_partition are changed diff --git a/src/backend/distributed/worker/task_tracker.c b/src/backend/distributed/worker/task_tracker.c
index aa7d9f9ec..3db6194af 100644 --- a/src/backend/distributed/worker/task_tracker.c +++ b/src/backend/distributed/worker/task_tracker.c @@ -88,7 +88,7 @@ static void ManageWorkerTasksHash(HTAB *WorkerTasksHash); static void ManageWorkerTask(WorkerTask *workerTask, HTAB *WorkerTasksHash); static void RemoveWorkerTask(WorkerTask *workerTask, HTAB *WorkerTasksHash); static void CreateJobDirectoryIfNotExists(uint64 jobId); -static int32 ConnectToLocalBackend(const char *databaseName); +static int32 ConnectToLocalBackend(const char *databaseName, const char *userName); /* Organize, at startup, that the task tracker is started */ @@ -904,7 +904,8 @@ ManageWorkerTask(WorkerTask *workerTask, HTAB *WorkerTasksHash) CreateJobDirectoryIfNotExists(workerTask->jobId); /* the task is ready to run; connect to local backend */ - workerTask->connectionId = ConnectToLocalBackend(workerTask->databaseName); + workerTask->connectionId = ConnectToLocalBackend(workerTask->databaseName, + workerTask->userName); if (workerTask->connectionId != INVALID_CONNECTION_ID) { @@ -1082,7 +1083,7 @@ CreateJobDirectoryIfNotExists(uint64 jobId) /* Wrapper function to inititate connection to local backend. */ static int32 -ConnectToLocalBackend(const char *databaseName) +ConnectToLocalBackend(const char *databaseName, const char *userName) { const char *nodeName = LOCAL_HOST_NAME; const uint32 nodePort = PostPortNumber; @@ -1091,7 +1092,7 @@ ConnectToLocalBackend(const char *databaseName) * Our client library currently only handles TCP sockets. We therefore do * not use Unix domain sockets here. 
*/ - int32 connectionId = MultiClientConnect(nodeName, nodePort, databaseName); + int32 connectionId = MultiClientConnect(nodeName, nodePort, databaseName, userName); return connectionId; } diff --git a/src/backend/distributed/worker/task_tracker_protocol.c b/src/backend/distributed/worker/task_tracker_protocol.c index ba36fa737..a2b2628ea 100644 --- a/src/backend/distributed/worker/task_tracker_protocol.c +++ b/src/backend/distributed/worker/task_tracker_protocol.c @@ -22,6 +22,7 @@ #include "access/xact.h" #include "commands/dbcommands.h" #include "commands/schemacmds.h" +#include "distributed/metadata_cache.h" #include "distributed/multi_client_executor.h" #include "distributed/multi_server_executor.h" #include "distributed/resource_lock.h" @@ -288,6 +289,7 @@ CreateTask(uint64 jobId, uint32 taskId, char *taskCallString) WorkerTask *workerTask = NULL; uint32 assignmentTime = 0; char *databaseName = get_database_name(MyDatabaseId); + char *userName = CurrentUserName(); /* increase task priority for cleanup tasks */ assignmentTime = (uint32) time(NULL); @@ -305,6 +307,7 @@ CreateTask(uint64 jobId, uint32 taskId, char *taskCallString) workerTask->connectionId = INVALID_CONNECTION_ID; workerTask->failureCount = 0; strncpy(workerTask->databaseName, databaseName, NAMEDATALEN); + strncpy(workerTask->userName, userName, NAMEDATALEN); } diff --git a/src/backend/distributed/worker/worker_data_fetch_protocol.c b/src/backend/distributed/worker/worker_data_fetch_protocol.c index d6f0773a0..3bb557c84 100644 --- a/src/backend/distributed/worker/worker_data_fetch_protocol.c +++ b/src/backend/distributed/worker/worker_data_fetch_protocol.c @@ -254,7 +254,7 @@ ReceiveRegularFile(const char *nodeName, uint32 nodePort, nodeDatabase = get_database_name(MyDatabaseId); /* connect to remote node */ - connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase); + connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase, NULL); if (connectionId == INVALID_CONNECTION_ID) 
{ ReceiveResourceCleanup(connectionId, filename, fileDescriptor); @@ -870,7 +870,6 @@ ForeignFilePath(const char *nodeName, uint32 nodePort, StringInfo tableName) List * ExecuteRemoteQuery(const char *nodeName, uint32 nodePort, StringInfo queryString) { - char *nodeDatabase = get_database_name(MyDatabaseId); int32 connectionId = -1; bool querySent = false; bool queryReady = false; @@ -881,7 +880,7 @@ ExecuteRemoteQuery(const char *nodeName, uint32 nodePort, StringInfo queryString int columnCount = 0; List *resultList = NIL; - connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase); + connectionId = MultiClientConnect(nodeName, nodePort, NULL, NULL); if (connectionId == INVALID_CONNECTION_ID) { return NIL; diff --git a/src/include/distributed/connection_cache.h b/src/include/distributed/connection_cache.h index e39a0c406..ecc6b6c7b 100644 --- a/src/include/distributed/connection_cache.h +++ b/src/include/distributed/connection_cache.h @@ -39,6 +39,7 @@ typedef struct NodeConnectionKey { char nodeName[MAX_NODE_LENGTH + 1]; /* hostname of host to connect to */ int32 nodePort; /* port of host to connect to */ + char nodeUser[NAMEDATALEN + 1]; /* user name to connect as */ } NodeConnectionKey; @@ -54,7 +55,7 @@ typedef struct NodeConnectionEntry extern PGconn * GetOrEstablishConnection(char *nodeName, int32 nodePort); extern void PurgeConnection(PGconn *connection); extern void ReportRemoteError(PGconn *connection, PGresult *result); -extern PGconn * ConnectToNode(char *nodeName, int nodePort); +extern PGconn * ConnectToNode(char *nodeName, int nodePort, char *nodeUser); extern char * ConnectionGetOptionValue(PGconn *connection, char *optionKeyword); diff --git a/src/include/distributed/metadata_cache.h b/src/include/distributed/metadata_cache.h index 32485e152..8df923734 100644 --- a/src/include/distributed/metadata_cache.h +++ b/src/include/distributed/metadata_cache.h @@ -69,4 +69,6 @@ extern Oid DistShardPlacementShardidIndexId(void); /* function oids */ 
extern Oid CitusExtraDataContainerFuncId(void); +/* user related functions */ +extern char * CurrentUserName(void); #endif /* METADATA_CACHE_H */ diff --git a/src/include/distributed/multi_client_executor.h b/src/include/distributed/multi_client_executor.h index 7e91d715d..1486cd692 100644 --- a/src/include/distributed/multi_client_executor.h +++ b/src/include/distributed/multi_client_executor.h @@ -19,7 +19,7 @@ #define CLIENT_CONNECT_TIMEOUT 5 /* connection timeout in seconds */ #define MAX_CONNECTION_COUNT 2048 /* simultaneous client connection count */ #define STRING_BUFFER_SIZE 1024 /* buffer size for character arrays */ -#define CONN_INFO_TEMPLATE "host=%s port=%u dbname=%s connect_timeout=%u" +#define CONN_INFO_TEMPLATE "host=%s port=%u dbname=%s user=%s connect_timeout=%u" /* Enumeration to track one client connection's status */ @@ -74,7 +74,7 @@ typedef enum /* Function declarations for executing client-side (libpq) logic. */ extern int32 MultiClientConnect(const char *nodeName, uint32 nodePort, - const char *nodeDatabase); + const char *nodeDatabase, const char *nodeUser); extern int32 MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeDatabase); extern ConnectStatus MultiClientConnectPoll(int32 connectionId); diff --git a/src/include/distributed/task_tracker.h b/src/include/distributed/task_tracker.h index 1b04ccb31..81fa4fe43 100644 --- a/src/include/distributed/task_tracker.h +++ b/src/include/distributed/task_tracker.h @@ -82,6 +82,7 @@ typedef struct WorkerTask char taskCallString[TASK_CALL_STRING_SIZE]; /* query or function call string */ TaskStatus taskStatus; /* task's current execution status */ char databaseName[NAMEDATALEN]; /* name to use for local backend connection */ + char userName[NAMEDATALEN]; /* user to use for local backend connection */ int32 connectionId; /* connection id to local backend */ uint32 failureCount; /* number of task failures */ } WorkerTask;