Use the current session's username when connecting to worker nodes.

So far we've always used libpq defaults when connecting to workers;
barring special environment variables being set, that will always be the
user that started the server.  That's not desirable because it prevents
using users with fewer privileges.

Thus change the various APIs creating connections to workers to always
use usernames. That means:
1) MultiClientConnect() needs to, optionally, accept a username
2) GetOrEstablishConnection(), including the underlying cache, needs to
   use the current user as part of the connection cache key. That way
   connections for separate users are distinct, and we always use one
   with the correct authorization.
3) The task tracker needs to keep track of the username associated with
   a task, so it can use it when establishing connections outside the
   originating session.
pull/471/head
Andres Freund 2016-02-25 18:40:11 -08:00
parent abb4ec019f
commit 42d232c0e8
14 changed files with 103 additions and 24 deletions

View File

@ -793,11 +793,12 @@ OpenCopyTransactions(CopyStmt *copyStatement, ShardConnections *shardConnections
ShardPlacement *placement = (ShardPlacement *) lfirst(placementCell);
char *nodeName = placement->nodeName;
int nodePort = placement->nodePort;
char *nodeUser = CurrentUserName();
TransactionConnection *transactionConnection = NULL;
StringInfo copyCommand = NULL;
PGresult *result = NULL;
PGconn *connection = ConnectToNode(nodeName, nodePort);
PGconn *connection = ConnectToNode(nodeName, nodePort, nodeUser);
/* release failed placement list and copy command at the end of this function */
oldContext = MemoryContextSwitchTo(localContext);

View File

@ -15,6 +15,10 @@
#include "postgres.h"
#include "fmgr.h"
#include "libpq-fe.h"
#include "miscadmin.h"
#include "commands/dbcommands.h"
#include "distributed/metadata_cache.h"
#include "distributed/multi_client_executor.h"
#include <errno.h>
@ -76,24 +80,57 @@ AllocateConnectionId(void)
* MultiClientConnect synchronously tries to establish a connection. If it
* succeeds, it returns the connection id. Otherwise, it reports connection
* error and returns INVALID_CONNECTION_ID.
*
* nodeDatabase and userName can be NULL, in which case values from the
* current session are used.
*/
int32
MultiClientConnect(const char *nodeName, uint32 nodePort, const char *nodeDatabase)
MultiClientConnect(const char *nodeName, uint32 nodePort, const char *nodeDatabase,
const char *userName)
{
PGconn *connection = NULL;
char connInfoString[STRING_BUFFER_SIZE];
ConnStatusType connStatusType = CONNECTION_OK;
int32 connectionId = AllocateConnectionId();
char *effectiveDatabaseName = NULL;
char *effectiveUserName = NULL;
if (connectionId == INVALID_CONNECTION_ID)
{
ereport(WARNING, (errmsg("could not allocate connection in connection pool")));
return connectionId;
}
if (nodeDatabase == NULL)
{
effectiveDatabaseName = get_database_name(MyDatabaseId);
}
else
{
effectiveDatabaseName = pstrdup(nodeDatabase);
}
if (userName == NULL)
{
effectiveUserName = CurrentUserName();
}
else
{
effectiveUserName = pstrdup(userName);
}
/*
* FIXME: This code is bad on several levels. It completely forgoes any
* escaping, it misses setting a number of parameters, it works with a
* limited string size without erroring when it's too long. We shouldn't
* even build a query string this way, there's PQconnectdbParams()!
*/
/* transcribe connection parameters to string */
snprintf(connInfoString, STRING_BUFFER_SIZE, CONN_INFO_TEMPLATE,
nodeName, nodePort, nodeDatabase, CLIENT_CONNECT_TIMEOUT);
nodeName, nodePort,
effectiveDatabaseName, effectiveUserName,
CLIENT_CONNECT_TIMEOUT);
/* establish synchronous connection to worker node */
connection = PQconnectdb(connInfoString);
@ -111,6 +148,9 @@ MultiClientConnect(const char *nodeName, uint32 nodePort, const char *nodeDataba
connectionId = INVALID_CONNECTION_ID;
}
pfree(effectiveDatabaseName);
pfree(effectiveUserName);
return connectionId;
}
@ -126,6 +166,7 @@ MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeD
PGconn *connection = NULL;
char connInfoString[STRING_BUFFER_SIZE];
ConnStatusType connStatusType = CONNECTION_BAD;
char *userName = CurrentUserName();
int32 connectionId = AllocateConnectionId();
if (connectionId == INVALID_CONNECTION_ID)
@ -136,7 +177,7 @@ MultiClientConnectStart(const char *nodeName, uint32 nodePort, const char *nodeD
/* transcribe connection parameters to string */
snprintf(connInfoString, STRING_BUFFER_SIZE, CONN_INFO_TEMPLATE,
nodeName, nodePort, nodeDatabase, CLIENT_CONNECT_TIMEOUT);
nodeName, nodePort, nodeDatabase, userName, CLIENT_CONNECT_TIMEOUT);
/* prepare asynchronous request for worker node connection */
connection = PQconnectStart(connInfoString);

View File

@ -494,7 +494,7 @@ ExecuteRemoteCommand(const char *nodeName, uint32 nodePort, StringInfo queryStri
bool queryReady = false;
bool queryDone = false;
connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase);
connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase, NULL);
if (connectionId == INVALID_CONNECTION_ID)
{
return false;

View File

@ -793,9 +793,8 @@ static bool
WorkerNodeResponsive(const char *workerName, uint32 workerPort)
{
bool workerNodeResponsive = false;
const char *databaseName = get_database_name(MyDatabaseId);
int connectionId = MultiClientConnect(workerName, workerPort, databaseName);
int connectionId = MultiClientConnect(workerName, workerPort, NULL, NULL);
if (connectionId != INVALID_CONNECTION_ID)
{
MultiClientDisconnect(connectionId);

View File

@ -15,6 +15,8 @@
#include <math.h>
#include "miscadmin.h"
#include "access/genam.h"
#include "access/hash.h"
#include "access/heapam.h"

View File

@ -20,6 +20,7 @@
#include "commands/dbcommands.h"
#include "distributed/connection_cache.h"
#include "distributed/metadata_cache.h"
#include "lib/stringinfo.h"
#include "mb/pg_wchar.h"
#include "utils/builtins.h"
@ -61,6 +62,7 @@ GetOrEstablishConnection(char *nodeName, int32 nodePort)
NodeConnectionEntry *nodeConnectionEntry = NULL;
bool entryFound = false;
bool needNewConnection = true;
char *userName = CurrentUserName();
/* check input */
if (strnlen(nodeName, MAX_NODE_LENGTH + 1) > MAX_NODE_LENGTH)
@ -79,6 +81,7 @@ GetOrEstablishConnection(char *nodeName, int32 nodePort)
memset(&nodeConnectionKey, 0, sizeof(nodeConnectionKey));
strncpy(nodeConnectionKey.nodeName, nodeName, MAX_NODE_LENGTH);
nodeConnectionKey.nodePort = nodePort;
strncpy(nodeConnectionKey.nodeUser, userName, NAMEDATALEN);
nodeConnectionEntry = hash_search(NodeConnectionHash, &nodeConnectionKey,
HASH_FIND, &entryFound);
@ -97,7 +100,7 @@ GetOrEstablishConnection(char *nodeName, int32 nodePort)
if (needNewConnection)
{
connection = ConnectToNode(nodeName, nodePort);
connection = ConnectToNode(nodeName, nodePort, nodeConnectionKey.nodeUser);
if (connection != NULL)
{
nodeConnectionEntry = hash_search(NodeConnectionHash, &nodeConnectionKey,
@ -123,6 +126,7 @@ PurgeConnection(PGconn *connection)
bool entryFound = false;
char *nodeNameString = NULL;
char *nodePortString = NULL;
char *nodeUserString = NULL;
nodeNameString = ConnectionGetOptionValue(connection, "host");
if (nodeNameString == NULL)
@ -138,12 +142,21 @@ PurgeConnection(PGconn *connection)
errmsg("connection is missing port option")));
}
nodeUserString = ConnectionGetOptionValue(connection, "user");
if (nodeUserString == NULL)
{
ereport(ERROR, (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("connection is missing user option")));
}
memset(&nodeConnectionKey, 0, sizeof(nodeConnectionKey));
strncpy(nodeConnectionKey.nodeName, nodeNameString, MAX_NODE_LENGTH);
nodeConnectionKey.nodePort = pg_atoi(nodePortString, sizeof(int32), 0);
strncpy(nodeConnectionKey.nodeUser, nodeUserString, NAMEDATALEN);
pfree(nodeNameString);
pfree(nodePortString);
pfree(nodeUserString);
nodeConnectionEntry = hash_search(NodeConnectionHash, &nodeConnectionKey,
HASH_REMOVE, &entryFound);
@ -253,14 +266,14 @@ CreateNodeConnectionHash(void)
/*
* ConnectToNode opens a connection to a remote PostgreSQL server. The function
* configures the connection's fallback application name to 'citus' and sets
* the remote encoding to match the local one. This function requires that the
* port be specified as a string for easier use with libpq functions.
* the remote encoding to match the local one. All parameters are required to
* be non NULL.
*
* We attempt to connect up to MAX_CONNECT_ATTEMPT times. After that we give up
* and return NULL.
*/
PGconn *
ConnectToNode(char *nodeName, int32 nodePort)
ConnectToNode(char *nodeName, int32 nodePort, char *nodeUser)
{
PGconn *connection = NULL;
const char *clientEncoding = GetDatabaseEncodingName();
@ -269,12 +282,12 @@ ConnectToNode(char *nodeName, int32 nodePort)
const char *keywordArray[] = {
"host", "port", "fallback_application_name",
"client_encoding", "connect_timeout", "dbname", NULL
"client_encoding", "connect_timeout", "dbname", "user", NULL
};
char nodePortString[12];
const char *valueArray[] = {
nodeName, nodePortString, "citus", clientEncoding,
CLIENT_CONNECT_TIMEOUT_SECONDS, dbname, NULL
CLIENT_CONNECT_TIMEOUT_SECONDS, dbname, nodeUser, NULL
};
sprintf(nodePortString, "%d", nodePort);

View File

@ -9,6 +9,8 @@
#include "postgres.h"
#include "miscadmin.h"
#include "access/genam.h"
#include "access/heapam.h"
#include "access/htup_details.h"
@ -649,6 +651,20 @@ CitusExtraDataContainerFuncId(void)
}
/*
 * CurrentUserName returns the name of the role identified by GetUserId().
 * GetUserNameFromId() takes one argument before PostgreSQL 9.5 and two from
 * 9.5 on, so the call is selected according to PG_VERSION_NUM.
 */
char *
CurrentUserName(void)
{
#if (PG_VERSION_NUM >= 90500)
	return GetUserNameFromId(GetUserId(), false);
#else
	return GetUserNameFromId(GetUserId());
#endif
}
/*
* master_dist_partition_cache_invalidate is a trigger function that performs
* relcache invalidations when the contents of pg_dist_partition are changed

View File

@ -88,7 +88,7 @@ static void ManageWorkerTasksHash(HTAB *WorkerTasksHash);
static void ManageWorkerTask(WorkerTask *workerTask, HTAB *WorkerTasksHash);
static void RemoveWorkerTask(WorkerTask *workerTask, HTAB *WorkerTasksHash);
static void CreateJobDirectoryIfNotExists(uint64 jobId);
static int32 ConnectToLocalBackend(const char *databaseName);
static int32 ConnectToLocalBackend(const char *databaseName, const char *userName);
/* Organize, at startup, that the task tracker is started */
@ -904,7 +904,8 @@ ManageWorkerTask(WorkerTask *workerTask, HTAB *WorkerTasksHash)
CreateJobDirectoryIfNotExists(workerTask->jobId);
/* the task is ready to run; connect to local backend */
workerTask->connectionId = ConnectToLocalBackend(workerTask->databaseName);
workerTask->connectionId = ConnectToLocalBackend(workerTask->databaseName,
workerTask->userName);
if (workerTask->connectionId != INVALID_CONNECTION_ID)
{
@ -1082,7 +1083,7 @@ CreateJobDirectoryIfNotExists(uint64 jobId)
/* Wrapper function to initiate connection to local backend. */
static int32
ConnectToLocalBackend(const char *databaseName)
ConnectToLocalBackend(const char *databaseName, const char *userName)
{
const char *nodeName = LOCAL_HOST_NAME;
const uint32 nodePort = PostPortNumber;
@ -1091,7 +1092,7 @@ ConnectToLocalBackend(const char *databaseName)
* Our client library currently only handles TCP sockets. We therefore do
* not use Unix domain sockets here.
*/
int32 connectionId = MultiClientConnect(nodeName, nodePort, databaseName);
int32 connectionId = MultiClientConnect(nodeName, nodePort, databaseName, userName);
return connectionId;
}

View File

@ -22,6 +22,7 @@
#include "access/xact.h"
#include "commands/dbcommands.h"
#include "commands/schemacmds.h"
#include "distributed/metadata_cache.h"
#include "distributed/multi_client_executor.h"
#include "distributed/multi_server_executor.h"
#include "distributed/resource_lock.h"
@ -288,6 +289,7 @@ CreateTask(uint64 jobId, uint32 taskId, char *taskCallString)
WorkerTask *workerTask = NULL;
uint32 assignmentTime = 0;
char *databaseName = get_database_name(MyDatabaseId);
char *userName = CurrentUserName();
/* increase task priority for cleanup tasks */
assignmentTime = (uint32) time(NULL);
@ -305,6 +307,7 @@ CreateTask(uint64 jobId, uint32 taskId, char *taskCallString)
workerTask->connectionId = INVALID_CONNECTION_ID;
workerTask->failureCount = 0;
strncpy(workerTask->databaseName, databaseName, NAMEDATALEN);
strncpy(workerTask->userName, userName, NAMEDATALEN);
}

View File

@ -254,7 +254,7 @@ ReceiveRegularFile(const char *nodeName, uint32 nodePort,
nodeDatabase = get_database_name(MyDatabaseId);
/* connect to remote node */
connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase);
connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase, NULL);
if (connectionId == INVALID_CONNECTION_ID)
{
ReceiveResourceCleanup(connectionId, filename, fileDescriptor);
@ -870,7 +870,6 @@ ForeignFilePath(const char *nodeName, uint32 nodePort, StringInfo tableName)
List *
ExecuteRemoteQuery(const char *nodeName, uint32 nodePort, StringInfo queryString)
{
char *nodeDatabase = get_database_name(MyDatabaseId);
int32 connectionId = -1;
bool querySent = false;
bool queryReady = false;
@ -881,7 +880,7 @@ ExecuteRemoteQuery(const char *nodeName, uint32 nodePort, StringInfo queryString
int columnCount = 0;
List *resultList = NIL;
connectionId = MultiClientConnect(nodeName, nodePort, nodeDatabase);
connectionId = MultiClientConnect(nodeName, nodePort, NULL, NULL);
if (connectionId == INVALID_CONNECTION_ID)
{
return NIL;

View File

@ -39,6 +39,7 @@ typedef struct NodeConnectionKey
{
char nodeName[MAX_NODE_LENGTH + 1]; /* hostname of host to connect to */
int32 nodePort; /* port of host to connect to */
char nodeUser[NAMEDATALEN + 1]; /* user name to connect as */
} NodeConnectionKey;
@ -54,7 +55,7 @@ typedef struct NodeConnectionEntry
extern PGconn * GetOrEstablishConnection(char *nodeName, int32 nodePort);
extern void PurgeConnection(PGconn *connection);
extern void ReportRemoteError(PGconn *connection, PGresult *result);
extern PGconn * ConnectToNode(char *nodeName, int nodePort);
extern PGconn * ConnectToNode(char *nodeName, int nodePort, char *nodeUser);
extern char * ConnectionGetOptionValue(PGconn *connection, char *optionKeyword);

View File

@ -69,4 +69,6 @@ extern Oid DistShardPlacementShardidIndexId(void);
/* function oids */
extern Oid CitusExtraDataContainerFuncId(void);
/* user related functions */
extern char * CurrentUserName(void);
#endif /* METADATA_CACHE_H */

View File

@ -19,7 +19,7 @@
#define CLIENT_CONNECT_TIMEOUT 5 /* connection timeout in seconds */
#define MAX_CONNECTION_COUNT 2048 /* simultaneous client connection count */
#define STRING_BUFFER_SIZE 1024 /* buffer size for character arrays */
#define CONN_INFO_TEMPLATE "host=%s port=%u dbname=%s connect_timeout=%u"
#define CONN_INFO_TEMPLATE "host=%s port=%u dbname=%s user=%s connect_timeout=%u"
/* Enumeration to track one client connection's status */
@ -74,7 +74,7 @@ typedef enum
/* Function declarations for executing client-side (libpq) logic. */
extern int32 MultiClientConnect(const char *nodeName, uint32 nodePort,
const char *nodeDatabase);
const char *nodeDatabase, const char *nodeUser);
extern int32 MultiClientConnectStart(const char *nodeName, uint32 nodePort,
const char *nodeDatabase);
extern ConnectStatus MultiClientConnectPoll(int32 connectionId);

View File

@ -82,6 +82,7 @@ typedef struct WorkerTask
char taskCallString[TASK_CALL_STRING_SIZE]; /* query or function call string */
TaskStatus taskStatus; /* task's current execution status */
char databaseName[NAMEDATALEN]; /* name to use for local backend connection */
char userName[NAMEDATALEN]; /* user to use for local backend connection */
int32 connectionId; /* connection id to local backend */
uint32 failureCount; /* number of task failures */
} WorkerTask;