mirror of https://github.com/citusdata/citus.git
Use the current session's username when connecting to worker nodes.
So far we've always used libpq defaults when connecting to workers; bar special environment variables being set that'll always be the user that started the server. That's not desirable because it prevents using users with fewer privileges. Thus change the various APIs creating connections to workers to always use usernames. That means: 1) MultiClientConnect() needs to, optionally, accept a username 2) GetOrEstablishConnection(), including the underlying cache, need to use the current user as part of the connection cache key. That way connections for separate users are distinct, and we always use one with the correct authorization. 3) The task tracker needs to keep track of the username associated with a task, so it can use it when establishing connections outside the originating session.pull/471/head
parent
abb4ec019f
commit
42d232c0e8
|
@ -793,11 +793,12 @@ OpenCopyTransactions(CopyStmt *copyStatement, ShardConnections *shardConnections
|
|||
ShardPlacement *placement = (ShardPlacement *) lfirst(placementCell);
|
||||
char *nodeName = placement->nodeName;
|
||||
int nodePort = placement->nodePort;
|
||||
char *nodeUser = CurrentUserName();
|
||||
TransactionConnection *transactionConnection = NULL;
|
||||
StringInfo copyCommand = NULL;
|
||||
PGresult *result = NULL;
|
||||
|
||||
PGconn *connection = ConnectToNode(nodeName, nodePort);
|
||||
PGconn *connection = ConnectToNode(nodeName, nodePort, nodeUser);
|
||||
|
||||
/* release failed placement list and copy command at the end of this function */
|
||||
oldContext = MemoryContextSwitchTo(localContext);
|
||||
|
|
|
@ -15,6 +15,10 @@
|
|||
#include "postgres.h"
|
||||
#include "fmgr.h"
|
||||
#include "libpq-fe.h"
|
||||
#include "miscadmin.h"
|
||||
|
||||
#include "commands/dbcommands.h"
|
||||
#include "distributed/metadata_cache.h"
|
||||
#include "distributed/multi_client_executor.h"
|
||||
|
||||
#include <errno.h>
|
||||
|
@ -76,24 +80,57 @@ AllocateConnectionId(void)
|
|||
* MultiClientConnect synchronously tries to establish a connection. If it
|
||||
* succeeds, it returns the connection id. Otherwise, it reports connection
|
||||
* error and returns INVALID_CONNECTION_ID.
|
||||
*
|
||||
* nodeDatabase and userName can be NULL, in which case values from the
|
||||
* current session are used.
|
||||
*/
|
||||
int32
|
||||
MultiClientConnect(const char *nodeName, uint32 nodePort, const char *nodeDatabase)
|
||||
MultiClientConnect(const char *nodeName, uint32 nodePort, const char *nodeDatabase,
|
||||
const char *userName)
|
||||
{
|
||||
PGconn *connection = NULL;
|
||||
char connInfoString[STRING_BUFFER_SIZE];
|
||||
ConnStatusType connStatusType = CONNECTION_OK;
|
||||
|
||||
int32 connectionId = AllocateConnectionId();
|
||||
char *effectiveDatabaseName = NULL;
|
||||
char *effectiveUserName = NULL;
|
||||
|
||||
if (connectionId == INVALID_CONNECTION_ID)
|
||||
{
|
||||
ereport(WARNING, (errmsg("could not allocate connection in connection pool")));
|
||||
return connectionId;
|
||||
}
|
||||
|
||||
if (nodeDatabase == NULL)
|
||||
{
|
||||
effectiveDatabaseName = get_database_name(MyDatabaseId);
|
||||
}
|
||||
else
|
||||
{
|
||||
effectiveDatabaseName = pstrdup(nodeDatabase);
|
||||
}
|
||||
|
||||
if (userName == NULL)
|
||||
{
|
||||
effectiveUserName = CurrentUserName();
|
||||
}
|
||||
else
|
||||
{
|
||||
effectiveUserName = pstrdup(userName);
|
||||
}
|
||||
|
||||
/*
|
||||
* FIXME: This code is bad on several levels. It completely forgoes any
|
||||
* escaping, it misses setting a number of parameters, it works with a
|
||||
* limited string size without erroring when it's too long. We shouldn't
|
||||
* even build a query string this way, there's PQconnectdbParams()!
|
||||
*/
|
||||
|
||||
/* transcribe connection paremeters to string */
|
||||
snprintf(connInfoString, STRING_BUFFER_SIZE, CONN_INFO_TEMPLATE,
|
||||
nodeName, nodePort, nodeDatabase, CLIENT_CONNECT_TIMEOUT);
|
||||
nodeName, nodePort,
|
||||
effectiveDatabaseName, effectiveUserName,
|
||||
CLIENT_CONNECT_TIMEOUT);
|
||||
|
||||
/* establish synchronous connection to worker node */
|
||||
connection = PQconnectdb(connInfoString);
|
||||
|
@ -111,6 +148,9 @@ MultiClientConnect(const char *nodeName, uint32 nodePort, const char *nodeDataba
|
|||
connectionId = INVALID_CONNECTION_ID;
|
||||
}
|
||||
|
||||
pfree(effectiveDatabaseName);
|
||||
pfree(effectiveUserName);
|
||||
|
||||
return connectionId;
|
||||
}
|
||||
|
||||
|
@ -126,6 +166,7 @@ MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeD
|
|||
PGconn *connection = NULL;
|
||||
char connInfoString[STRING_BUFFER_SIZE];
|
||||
ConnStatusType connStatusType = CONNECTION_BAD;
|
||||
char *userName = CurrentUserName();
|
||||
|
||||
int32 connectionId = AllocateConnectionId();
|
||||
if (connectionId == INVALID_CONNECTION_ID)
|
||||
|
@ -136,7 +177,7 @@ MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeD
|
|||
|
||||
/* transcribe connection paremeters to string */
|
||||
snprintf(connInfoString, STRING_BUFFER_SIZE, CONN_INFO_TEMPLATE,
|
||||
nodeName, nodePort, nodeDatabase, CLIENT_CONNECT_TIMEOUT);
|
||||
nodeName, nodePort, nodeDatabase, userName, CLIENT_CONNECT_TIMEOUT);
|
||||
|
||||
/* prepare asynchronous request for worker node connection */
|
||||
connection = PQconnectStart(connInfoString);
|
||||
|
|
|
@ -494,7 +494,7 @@ ExecuteRemoteCommand(const char *nodeName, uint32 nodePort, StringInfo queryStri
|
|||
bool queryReady = false;
|
||||
bool queryDone = false;
|
||||
|
||||
connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase);
|
||||
connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase, NULL);
|
||||
if (connectionId == INVALID_CONNECTION_ID)
|
||||
{
|
||||
return false;
|
||||
|
|
|
@ -793,9 +793,8 @@ static bool
|
|||
WorkerNodeResponsive(const char *workerName, uint32 workerPort)
|
||||
{
|
||||
bool workerNodeResponsive = false;
|
||||
const char *databaseName = get_database_name(MyDatabaseId);
|
||||
|
||||
int connectionId = MultiClientConnect(workerName, workerPort, databaseName);
|
||||
int connectionId = MultiClientConnect(workerName, workerPort, NULL, NULL);
|
||||
if (connectionId != INVALID_CONNECTION_ID)
|
||||
{
|
||||
MultiClientDisconnect(connectionId);
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
|
||||
#include <math.h>
|
||||
|
||||
#include "miscadmin.h"
|
||||
|
||||
#include "access/genam.h"
|
||||
#include "access/hash.h"
|
||||
#include "access/heapam.h"
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
|
||||
#include "commands/dbcommands.h"
|
||||
#include "distributed/connection_cache.h"
|
||||
#include "distributed/metadata_cache.h"
|
||||
#include "lib/stringinfo.h"
|
||||
#include "mb/pg_wchar.h"
|
||||
#include "utils/builtins.h"
|
||||
|
@ -61,6 +62,7 @@ GetOrEstablishConnection(char *nodeName, int32 nodePort)
|
|||
NodeConnectionEntry *nodeConnectionEntry = NULL;
|
||||
bool entryFound = false;
|
||||
bool needNewConnection = true;
|
||||
char *userName = CurrentUserName();
|
||||
|
||||
/* check input */
|
||||
if (strnlen(nodeName, MAX_NODE_LENGTH + 1) > MAX_NODE_LENGTH)
|
||||
|
@ -79,6 +81,7 @@ GetOrEstablishConnection(char *nodeName, int32 nodePort)
|
|||
memset(&nodeConnectionKey, 0, sizeof(nodeConnectionKey));
|
||||
strncpy(nodeConnectionKey.nodeName, nodeName, MAX_NODE_LENGTH);
|
||||
nodeConnectionKey.nodePort = nodePort;
|
||||
strncpy(nodeConnectionKey.nodeUser, userName, NAMEDATALEN);
|
||||
|
||||
nodeConnectionEntry = hash_search(NodeConnectionHash, &nodeConnectionKey,
|
||||
HASH_FIND, &entryFound);
|
||||
|
@ -97,7 +100,7 @@ GetOrEstablishConnection(char *nodeName, int32 nodePort)
|
|||
|
||||
if (needNewConnection)
|
||||
{
|
||||
connection = ConnectToNode(nodeName, nodePort);
|
||||
connection = ConnectToNode(nodeName, nodePort, nodeConnectionKey.nodeUser);
|
||||
if (connection != NULL)
|
||||
{
|
||||
nodeConnectionEntry = hash_search(NodeConnectionHash, &nodeConnectionKey,
|
||||
|
@ -123,6 +126,7 @@ PurgeConnection(PGconn *connection)
|
|||
bool entryFound = false;
|
||||
char *nodeNameString = NULL;
|
||||
char *nodePortString = NULL;
|
||||
char *nodeUserString = NULL;
|
||||
|
||||
nodeNameString = ConnectionGetOptionValue(connection, "host");
|
||||
if (nodeNameString == NULL)
|
||||
|
@ -138,12 +142,21 @@ PurgeConnection(PGconn *connection)
|
|||
errmsg("connection is missing port option")));
|
||||
}
|
||||
|
||||
nodeUserString = ConnectionGetOptionValue(connection, "user");
|
||||
if (nodeUserString == NULL)
|
||||
{
|
||||
ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
|
||||
errmsg("connection is missing user option")));
|
||||
}
|
||||
|
||||
memset(&nodeConnectionKey, 0, sizeof(nodeConnectionKey));
|
||||
strncpy(nodeConnectionKey.nodeName, nodeNameString, MAX_NODE_LENGTH);
|
||||
nodeConnectionKey.nodePort = pg_atoi(nodePortString, sizeof(int32), 0);
|
||||
strncpy(nodeConnectionKey.nodeUser, nodeUserString, NAMEDATALEN);
|
||||
|
||||
pfree(nodeNameString);
|
||||
pfree(nodePortString);
|
||||
pfree(nodeUserString);
|
||||
|
||||
nodeConnectionEntry = hash_search(NodeConnectionHash, &nodeConnectionKey,
|
||||
HASH_REMOVE, &entryFound);
|
||||
|
@ -253,14 +266,14 @@ CreateNodeConnectionHash(void)
|
|||
/*
|
||||
* ConnectToNode opens a connection to a remote PostgreSQL server. The function
|
||||
* configures the connection's fallback application name to 'citus' and sets
|
||||
* the remote encoding to match the local one. This function requires that the
|
||||
* port be specified as a string for easier use with libpq functions.
|
||||
* the remote encoding to match the local one. All parameters are required to
|
||||
* be non NULL.
|
||||
*
|
||||
* We attempt to connect up to MAX_CONNECT_ATTEMPT times. After that we give up
|
||||
* and return NULL.
|
||||
*/
|
||||
PGconn *
|
||||
ConnectToNode(char *nodeName, int32 nodePort)
|
||||
ConnectToNode(char *nodeName, int32 nodePort, char *nodeUser)
|
||||
{
|
||||
PGconn *connection = NULL;
|
||||
const char *clientEncoding = GetDatabaseEncodingName();
|
||||
|
@ -269,12 +282,12 @@ ConnectToNode(char *nodeName, int32 nodePort)
|
|||
|
||||
const char *keywordArray[] = {
|
||||
"host", "port", "fallback_application_name",
|
||||
"client_encoding", "connect_timeout", "dbname", NULL
|
||||
"client_encoding", "connect_timeout", "dbname", "user", NULL
|
||||
};
|
||||
char nodePortString[12];
|
||||
const char *valueArray[] = {
|
||||
nodeName, nodePortString, "citus", clientEncoding,
|
||||
CLIENT_CONNECT_TIMEOUT_SECONDS, dbname, NULL
|
||||
CLIENT_CONNECT_TIMEOUT_SECONDS, dbname, nodeUser, NULL
|
||||
};
|
||||
|
||||
sprintf(nodePortString, "%d", nodePort);
|
||||
|
|
|
@ -9,6 +9,8 @@
|
|||
|
||||
#include "postgres.h"
|
||||
|
||||
#include "miscadmin.h"
|
||||
|
||||
#include "access/genam.h"
|
||||
#include "access/heapam.h"
|
||||
#include "access/htup_details.h"
|
||||
|
@ -649,6 +651,20 @@ CitusExtraDataContainerFuncId(void)
|
|||
}
|
||||
|
||||
|
||||
/* return the username of the currently active role */
|
||||
char *
|
||||
CurrentUserName(void)
|
||||
{
|
||||
Oid userId = GetUserId();
|
||||
|
||||
#if (PG_VERSION_NUM < 90500)
|
||||
return GetUserNameFromId(userId);
|
||||
#else
|
||||
return GetUserNameFromId(userId, false);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
* master_dist_partition_cache_invalidate is a trigger function that performs
|
||||
* relcache invalidations when the contents of pg_dist_partition are changed
|
||||
|
|
|
@ -88,7 +88,7 @@ static void ManageWorkerTasksHash(HTAB *WorkerTasksHash);
|
|||
static void ManageWorkerTask(WorkerTask *workerTask, HTAB *WorkerTasksHash);
|
||||
static void RemoveWorkerTask(WorkerTask *workerTask, HTAB *WorkerTasksHash);
|
||||
static void CreateJobDirectoryIfNotExists(uint64 jobId);
|
||||
static int32 ConnectToLocalBackend(const char *databaseName);
|
||||
static int32 ConnectToLocalBackend(const char *databaseName, const char *userName);
|
||||
|
||||
|
||||
/* Organize, at startup, that the task tracker is started */
|
||||
|
@ -904,7 +904,8 @@ ManageWorkerTask(WorkerTask *workerTask, HTAB *WorkerTasksHash)
|
|||
CreateJobDirectoryIfNotExists(workerTask->jobId);
|
||||
|
||||
/* the task is ready to run; connect to local backend */
|
||||
workerTask->connectionId = ConnectToLocalBackend(workerTask->databaseName);
|
||||
workerTask->connectionId = ConnectToLocalBackend(workerTask->databaseName,
|
||||
workerTask->userName);
|
||||
|
||||
if (workerTask->connectionId != INVALID_CONNECTION_ID)
|
||||
{
|
||||
|
@ -1082,7 +1083,7 @@ CreateJobDirectoryIfNotExists(uint64 jobId)
|
|||
|
||||
/* Wrapper function to inititate connection to local backend. */
|
||||
static int32
|
||||
ConnectToLocalBackend(const char *databaseName)
|
||||
ConnectToLocalBackend(const char *databaseName, const char *userName)
|
||||
{
|
||||
const char *nodeName = LOCAL_HOST_NAME;
|
||||
const uint32 nodePort = PostPortNumber;
|
||||
|
@ -1091,7 +1092,7 @@ ConnectToLocalBackend(const char *databaseName)
|
|||
* Our client library currently only handles TCP sockets. We therefore do
|
||||
* not use Unix domain sockets here.
|
||||
*/
|
||||
int32 connectionId = MultiClientConnect(nodeName, nodePort, databaseName);
|
||||
int32 connectionId = MultiClientConnect(nodeName, nodePort, databaseName, userName);
|
||||
|
||||
return connectionId;
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include "access/xact.h"
|
||||
#include "commands/dbcommands.h"
|
||||
#include "commands/schemacmds.h"
|
||||
#include "distributed/metadata_cache.h"
|
||||
#include "distributed/multi_client_executor.h"
|
||||
#include "distributed/multi_server_executor.h"
|
||||
#include "distributed/resource_lock.h"
|
||||
|
@ -288,6 +289,7 @@ CreateTask(uint64 jobId, uint32 taskId, char *taskCallString)
|
|||
WorkerTask *workerTask = NULL;
|
||||
uint32 assignmentTime = 0;
|
||||
char *databaseName = get_database_name(MyDatabaseId);
|
||||
char *userName = CurrentUserName();
|
||||
|
||||
/* increase task priority for cleanup tasks */
|
||||
assignmentTime = (uint32) time(NULL);
|
||||
|
@ -305,6 +307,7 @@ CreateTask(uint64 jobId, uint32 taskId, char *taskCallString)
|
|||
workerTask->connectionId = INVALID_CONNECTION_ID;
|
||||
workerTask->failureCount = 0;
|
||||
strncpy(workerTask->databaseName, databaseName, NAMEDATALEN);
|
||||
strncpy(workerTask->userName, userName, NAMEDATALEN);
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -254,7 +254,7 @@ ReceiveRegularFile(const char *nodeName, uint32 nodePort,
|
|||
nodeDatabase = get_database_name(MyDatabaseId);
|
||||
|
||||
/* connect to remote node */
|
||||
connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase);
|
||||
connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase, NULL);
|
||||
if (connectionId == INVALID_CONNECTION_ID)
|
||||
{
|
||||
ReceiveResourceCleanup(connectionId, filename, fileDescriptor);
|
||||
|
@ -870,7 +870,6 @@ ForeignFilePath(const char *nodeName, uint32 nodePort, StringInfo tableName)
|
|||
List *
|
||||
ExecuteRemoteQuery(const char *nodeName, uint32 nodePort, StringInfo queryString)
|
||||
{
|
||||
char *nodeDatabase = get_database_name(MyDatabaseId);
|
||||
int32 connectionId = -1;
|
||||
bool querySent = false;
|
||||
bool queryReady = false;
|
||||
|
@ -881,7 +880,7 @@ ExecuteRemoteQuery(const char *nodeName, uint32 nodePort, StringInfo queryString
|
|||
int columnCount = 0;
|
||||
List *resultList = NIL;
|
||||
|
||||
connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase);
|
||||
connectionId = MultiClientConnect(nodeName, nodePort, NULL, NULL);
|
||||
if (connectionId == INVALID_CONNECTION_ID)
|
||||
{
|
||||
return NIL;
|
||||
|
|
|
@ -39,6 +39,7 @@ typedef struct NodeConnectionKey
|
|||
{
|
||||
char nodeName[MAX_NODE_LENGTH + 1]; /* hostname of host to connect to */
|
||||
int32 nodePort; /* port of host to connect to */
|
||||
char nodeUser[NAMEDATALEN + 1]; /* user name to connect as */
|
||||
} NodeConnectionKey;
|
||||
|
||||
|
||||
|
@ -54,7 +55,7 @@ typedef struct NodeConnectionEntry
|
|||
extern PGconn * GetOrEstablishConnection(char *nodeName, int32 nodePort);
|
||||
extern void PurgeConnection(PGconn *connection);
|
||||
extern void ReportRemoteError(PGconn *connection, PGresult *result);
|
||||
extern PGconn * ConnectToNode(char *nodeName, int nodePort);
|
||||
extern PGconn * ConnectToNode(char *nodeName, int nodePort, char *nodeUser);
|
||||
extern char * ConnectionGetOptionValue(PGconn *connection, char *optionKeyword);
|
||||
|
||||
|
||||
|
|
|
@ -69,4 +69,6 @@ extern Oid DistShardPlacementShardidIndexId(void);
|
|||
/* function oids */
|
||||
extern Oid CitusExtraDataContainerFuncId(void);
|
||||
|
||||
/* user related functions */
|
||||
extern char * CurrentUserName(void);
|
||||
#endif /* METADATA_CACHE_H */
|
||||
|
|
|
@ -19,7 +19,7 @@
|
|||
#define CLIENT_CONNECT_TIMEOUT 5 /* connection timeout in seconds */
|
||||
#define MAX_CONNECTION_COUNT 2048 /* simultaneous client connection count */
|
||||
#define STRING_BUFFER_SIZE 1024 /* buffer size for character arrays */
|
||||
#define CONN_INFO_TEMPLATE "host=%s port=%u dbname=%s connect_timeout=%u"
|
||||
#define CONN_INFO_TEMPLATE "host=%s port=%u dbname=%s user=%s connect_timeout=%u"
|
||||
|
||||
|
||||
/* Enumeration to track one client connection's status */
|
||||
|
@ -74,7 +74,7 @@ typedef enum
|
|||
|
||||
/* Function declarations for executing client-side (libpq) logic. */
|
||||
extern int32 MultiClientConnect(const char *nodeName, uint32 nodePort,
|
||||
const char *nodeDatabase);
|
||||
const char *nodeDatabase, const char *nodeUser);
|
||||
extern int32 MultiClientConnectStart(const char *nodeName, uint32 nodePort,
|
||||
const char *nodeDatabase);
|
||||
extern ConnectStatus MultiClientConnectPoll(int32 connectionId);
|
||||
|
|
|
@ -82,6 +82,7 @@ typedef struct WorkerTask
|
|||
char taskCallString[TASK_CALL_STRING_SIZE]; /* query or function call string */
|
||||
TaskStatus taskStatus; /* task's current execution status */
|
||||
char databaseName[NAMEDATALEN]; /* name to use for local backend connection */
|
||||
char userName[NAMEDATALEN]; /* user to use for local backend connection */
|
||||
int32 connectionId; /* connection id to local backend */
|
||||
uint32 failureCount; /* number of task failures */
|
||||
} WorkerTask;
|
||||
|
|
Loading…
Reference in New Issue