From 758a70a8ff51bb94c5d6e127620f33ff9d19fc20 Mon Sep 17 00:00:00 2001
From: Andres Freund
Date: Tue, 26 Apr 2016 14:42:31 -0700
Subject: [PATCH] Create new shards as owned by the distributed table's owner.

That's important because ownership of relations implies special
privileges. Without this change, a distributed table can be accessible
to its owner, but a shard created by another user might not be.
---
 .../distributed/executor/multi_utility.c      |  3 +-
 .../distributed/master/master_create_shards.c |  5 +-
 .../master/master_metadata_utility.c          | 29 +++++++++
 .../distributed/master/master_repair_shards.c | 10 ++-
 .../master/master_stage_protocol.c            | 31 ++++++----
 .../worker/worker_data_fetch_protocol.c       | 61 +++++++++++++++++--
 .../distributed/master_metadata_utility.h     |  1 +
 src/include/distributed/master_protocol.h     |  1 +
 src/include/distributed/worker_protocol.h     |  5 +-
 9 files changed, 123 insertions(+), 23 deletions(-)

diff --git a/src/backend/distributed/executor/multi_utility.c b/src/backend/distributed/executor/multi_utility.c
index da65e3ef5..58de7ae0f 100644
--- a/src/backend/distributed/executor/multi_utility.c
+++ b/src/backend/distributed/executor/multi_utility.c
@@ -949,6 +949,7 @@ ExecuteCommandOnWorkerShards(Oid relationId, const char *commandString,
     bool isFirstPlacement = true;
     ListCell *shardCell = NULL;
     List *shardList = NIL;
+    char *relationOwner = TableOwner(relationId);
 
     shardList = LoadShardList(relationId);
     foreach(shardCell, shardList)
@@ -972,7 +973,7 @@ ExecuteCommandOnWorkerShards(Oid relationId, const char *commandString,
         uint32 workerPort = placement->nodePort;
 
         List *queryResultList = ExecuteRemoteQuery(workerName, workerPort,
-                                                   applyCommand);
+                                                   relationOwner, applyCommand);
         if (queryResultList == NIL)
         {
             /*
diff --git a/src/backend/distributed/master/master_create_shards.c b/src/backend/distributed/master/master_create_shards.c
index 82de910b5..7dd8008cf 100644
--- a/src/backend/distributed/master/master_create_shards.c
+++ b/src/backend/distributed/master/master_create_shards.c
@@ -76,6 +76,7 @@ master_create_worker_shards(PG_FUNCTION_ARGS)
     Oid distributedTableId = ResolveRelationId(tableNameText);
     char relationKind = get_rel_relkind(distributedTableId);
     char *tableName = text_to_cstring(tableNameText);
+    char *relationOwner = NULL;
     char shardStorageType = '\0';
     List *workerNodeList = NIL;
     List *ddlCommandList = NIL;
@@ -99,6 +100,8 @@ master_create_worker_shards(PG_FUNCTION_ARGS)
     /* we plan to add shards: get an exclusive metadata lock */
     LockRelationDistributionMetadata(distributedTableId, ExclusiveLock);
 
+    relationOwner = TableOwner(distributedTableId);
+
     /* validate that shards haven't already been created for this table */
     existingShardList = LoadShardList(distributedTableId);
     if (existingShardList != NIL)
@@ -192,7 +195,7 @@ master_create_worker_shards(PG_FUNCTION_ARGS)
          */
         LockShardDistributionMetadata(shardId, ExclusiveLock);
 
-        CreateShardPlacements(shardId, ddlCommandList, workerNodeList,
+        CreateShardPlacements(shardId, ddlCommandList, relationOwner, workerNodeList,
                               roundRobinNodeIndex, replicationFactor);
 
         InsertShardRow(distributedTableId, shardId, shardStorageType,
diff --git a/src/backend/distributed/master/master_metadata_utility.c b/src/backend/distributed/master/master_metadata_utility.c
index ac4ca44b3..fce154c2a 100644
--- a/src/backend/distributed/master/master_metadata_utility.c
+++ b/src/backend/distributed/master/master_metadata_utility.c
@@ -654,6 +654,35 @@ EnsureTableOwner(Oid relationId)
     }
 }
 
+
+/*
+ * TableOwner returns the given relation's owner as a string.
+ */
+char *
+TableOwner(Oid relationId)
+{
+    Oid userId = InvalidOid;
+    HeapTuple tuple;
+
+    tuple = SearchSysCache1(RELOID, ObjectIdGetDatum(relationId));
+    if (!HeapTupleIsValid(tuple))
+    {
+        ereport(ERROR, (errcode(ERRCODE_UNDEFINED_TABLE),
+                        errmsg("relation with OID %u does not exist", relationId)));
+    }
+
+    userId = ((Form_pg_class) GETSTRUCT(tuple))->relowner;
+
+    ReleaseSysCache(tuple);
+
+#if (PG_VERSION_NUM < 90500)
+    return GetUserNameFromId(userId);
+#else
+    return GetUserNameFromId(userId, false);
+#endif
+}
+
+
 /*
  * master_stage_shard_row() inserts a row into pg_dist_shard, after performing
  * basic permission checks.
diff --git a/src/backend/distributed/master/master_repair_shards.c b/src/backend/distributed/master/master_repair_shards.c
index 39e14a1e8..2fbb8cd9f 100644
--- a/src/backend/distributed/master/master_repair_shards.c
+++ b/src/backend/distributed/master/master_repair_shards.c
@@ -68,6 +68,7 @@ master_copy_shard_placement(PG_FUNCTION_ARGS)
     ShardInterval *shardInterval = LoadShardInterval(shardId);
     Oid distributedTableId = shardInterval->relationId;
+    char *relationOwner = NULL;
 
     List *shardPlacementList = NIL;
     ShardPlacement *sourcePlacement = NULL;
     ShardPlacement *targetPlacement = NULL;
@@ -92,6 +93,8 @@ master_copy_shard_placement(PG_FUNCTION_ARGS)
      */
     LockShardDistributionMetadata(shardId, ExclusiveLock);
 
+    relationOwner = TableOwner(distributedTableId);
+
     shardPlacementList = ShardPlacementList(shardId);
     sourcePlacement = SearchShardPlacementInList(shardPlacementList, sourceNodeName,
                                                  sourceNodePort);
@@ -131,7 +134,8 @@ master_copy_shard_placement(PG_FUNCTION_ARGS)
                                      targetPlacement->nodePort);
 
     /* finally, drop/recreate remote table and add back row (in healthy state) */
-    CreateShardPlacements(shardId, ddlCommandList, list_make1(targetNode), 0, 1);
+    CreateShardPlacements(shardId, ddlCommandList, relationOwner,
+                          list_make1(targetNode), 0, 1);
 
     HOLD_INTERRUPTS();
 
@@ -256,7 +260,9 @@ CopyDataFromFinalizedPlacement(Oid distributedTableId, int64 shardId,
                                      healthyPlacement->nodePort); /* remote port */
 
     queryResultList = ExecuteRemoteQuery(placementToRepair->nodeName,
-                                         placementToRepair->nodePort, copyRelationQuery);
+                                         placementToRepair->nodePort,
+                                         NULL, /* current user, just data manipulation */
+                                         copyRelationQuery);
     if (queryResultList != NIL)
     {
         copySuccessful = true;
diff --git a/src/backend/distributed/master/master_stage_protocol.c b/src/backend/distributed/master/master_stage_protocol.c
index 29d0af05b..80785ea7e 100644
--- a/src/backend/distributed/master/master_stage_protocol.c
+++ b/src/backend/distributed/master/master_stage_protocol.c
@@ -40,8 +40,8 @@
 
 
 /* Local functions forward declarations */
-static bool WorkerCreateShard(char *nodeName, uint32 nodePort,
-                              uint64 shardId, List *ddlCommandList);
+static bool WorkerCreateShard(char *nodeName, uint32 nodePort, uint64 shardId,
+                              char *newShardOwner, List *ddlCommandList);
 static bool WorkerShardStats(char *nodeName, uint32 nodePort, Oid relationId,
                              char *shardName, uint64 *shardSize,
                              text **shardMinValue, text **shardMaxValue);
@@ -82,6 +82,7 @@ master_create_empty_shard(PG_FUNCTION_ARGS)
     char storageType = SHARD_STORAGE_TABLE;
 
     Oid relationId = ResolveRelationId(relationNameText);
+    char *relationOwner = TableOwner(relationId);
 
     EnsureTablePermissions(relationId, ACL_INSERT);
     CheckDistributedTable(relationId);
@@ -129,8 +130,8 @@ master_create_empty_shard(PG_FUNCTION_ARGS)
         candidateNodeCount++;
     }
 
-    CreateShardPlacements(shardId, ddlEventList, candidateNodeList, 0,
-                          ShardReplicationFactor);
+    CreateShardPlacements(shardId, ddlEventList, relationOwner,
+                          candidateNodeList, 0, ShardReplicationFactor);
 
     InsertShardRow(relationId, shardId, storageType, nullMinValue, nullMaxValue);
 
@@ -226,7 +227,9 @@ master_append_table_to_shard(PG_FUNCTION_ARGS)
                                      quote_literal_cstr(sourceTableName),
                                      quote_literal_cstr(sourceNodeName), sourceNodePort);
 
-        queryResultList = ExecuteRemoteQuery(workerName, workerPort, workerAppendQuery);
+        /* inserting data should be performed by the current user */
+        queryResultList = ExecuteRemoteQuery(workerName, workerPort, NULL,
+                                             workerAppendQuery);
         if (queryResultList != NIL)
         {
             succeededPlacementList = lappend(succeededPlacementList, shardPlacement);
@@ -310,8 +313,8 @@ CheckDistributedTable(Oid relationId)
  * nodes if some DDL commands had been successful).
  */
 void
-CreateShardPlacements(int64 shardId, List *ddlEventList, List *workerNodeList,
-                      int workerStartIndex, int replicationFactor)
+CreateShardPlacements(int64 shardId, List *ddlEventList, char *newPlacementOwner,
+                      List *workerNodeList, int workerStartIndex, int replicationFactor)
 {
     int attemptCount = replicationFactor;
     int workerNodeCount = list_length(workerNodeList);
@@ -331,7 +334,8 @@ CreateShardPlacements(int64 shardId, List *ddlEventList, List *workerNodeList,
         char *nodeName = workerNode->workerName;
         uint32 nodePort = workerNode->workerPort;
 
-        bool created = WorkerCreateShard(nodeName, nodePort, shardId, ddlEventList);
+        bool created = WorkerCreateShard(nodeName, nodePort, shardId, newPlacementOwner,
+                                         ddlEventList);
         if (created)
         {
             const RelayFileState shardState = FILE_FINALIZED;
@@ -367,8 +371,8 @@ CreateShardPlacements(int64 shardId, List *ddlEventList, List *workerNodeList,
  * each DDL command, and could leave the shard in an half-initialized state.
  */
 static bool
-WorkerCreateShard(char *nodeName, uint32 nodePort,
-                  uint64 shardId, List *ddlCommandList)
+WorkerCreateShard(char *nodeName, uint32 nodePort, uint64 shardId,
+                  char *newShardOwner, List *ddlCommandList)
 {
     bool shardCreated = true;
     ListCell *ddlCommandCell = NULL;
@@ -383,7 +387,8 @@ WorkerCreateShard(char *nodeName, uint32 nodePort,
         appendStringInfo(applyDDLCommand, WORKER_APPLY_SHARD_DDL_COMMAND, shardId,
                          escapedDDLCommand);
 
-        queryResultList = ExecuteRemoteQuery(nodeName, nodePort, applyDDLCommand);
+        queryResultList = ExecuteRemoteQuery(nodeName, nodePort, newShardOwner,
+                                             applyDDLCommand);
         if (queryResultList == NIL)
         {
             shardCreated = false;
@@ -537,7 +542,7 @@ WorkerTableSize(char *nodeName, uint32 nodePort, Oid relationId, char *tableName
         appendStringInfo(tableSizeQuery, SHARD_TABLE_SIZE_QUERY, tableName);
     }
 
-    queryResultList = ExecuteRemoteQuery(nodeName, nodePort, tableSizeQuery);
+    queryResultList = ExecuteRemoteQuery(nodeName, nodePort, NULL, tableSizeQuery);
     if (queryResultList == NIL)
     {
         ereport(ERROR, (errmsg("could not receive table size from node "
@@ -583,7 +588,7 @@ WorkerPartitionValue(char *nodeName, uint32 nodePort, Oid relationId,
      * simply casts the results to a (char *). If the user partitioned the table
     * on a binary byte array, this approach fails and should be fixed.
     */
-    queryResultList = ExecuteRemoteQuery(nodeName, nodePort, partitionValueQuery);
+    queryResultList = ExecuteRemoteQuery(nodeName, nodePort, NULL, partitionValueQuery);
     if (queryResultList == NIL)
     {
         ereport(ERROR, (errmsg("could not receive shard min/max values from node "
diff --git a/src/backend/distributed/worker/worker_data_fetch_protocol.c b/src/backend/distributed/worker/worker_data_fetch_protocol.c
index 3bb557c84..7e927f75c 100644
--- a/src/backend/distributed/worker/worker_data_fetch_protocol.c
+++ b/src/backend/distributed/worker/worker_data_fetch_protocol.c
@@ -61,6 +61,8 @@ static bool FetchRegularTable(const char *nodeName, uint32 nodePort,
                               StringInfo tableName);
 static bool FetchForeignTable(const char *nodeName, uint32 nodePort,
                               StringInfo tableName);
+static const char * RemoteTableOwner(const char *nodeName, uint32 nodePort,
+                                     StringInfo tableName);
 static List * TableDDLCommandList(const char *nodeName, uint32 nodePort,
                                   StringInfo tableName);
 static StringInfo ForeignFilePath(const char *nodeName, uint32 nodePort,
@@ -689,6 +691,10 @@ FetchRegularTable(const char *nodeName, uint32 nodePort, StringInfo tableName)
     char *quotedTableName = NULL;
     StringInfo queryString = NULL;
     const char *schemaName = NULL;
+    const char *tableOwner = NULL;
+    Oid tableOwnerId = InvalidOid;
+    Oid savedUserId = InvalidOid;
+    int savedSecurityContext = 0;
 
     /* copy remote table's data to this node in an idempotent manner */
     shardId = ExtractShardId(tableName);
@@ -706,6 +712,14 @@ FetchRegularTable(const char *nodeName, uint32 nodePort, StringInfo tableName)
         return false;
     }
 
+    /* fetch the owner of the remote table */
+    tableOwner = RemoteTableOwner(nodeName, nodePort, tableName);
+    if (tableOwner == NULL)
+    {
+        return false;
+    }
+    tableOwnerId = get_role_oid(tableOwner, false);
+
     /* fetch the ddl commands needed to create the table */
     ddlCommandList = TableDDLCommandList(nodeName, nodePort, tableName);
     if (ddlCommandList == NIL)
     {
@@ -715,8 +729,13 @@ FetchRegularTable(const char *nodeName, uint32 nodePort, StringInfo tableName)
 
     /*
      * Apply DDL commands against the database. Note that on failure from here
-     * on, we immediately error out instead of returning false.
+     * on, we immediately error out instead of returning false. Have to do
+     * this as the table's owner to ensure the local table is created with
+     * compatible permissions.
      */
+    GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
+    SetUserIdAndSecContext(tableOwnerId, SECURITY_LOCAL_USERID_CHANGE);
+
     foreach(ddlCommandCell, ddlCommandList)
     {
         StringInfo ddlCommand = (StringInfo) lfirst(ddlCommandCell);
@@ -727,6 +746,8 @@ FetchRegularTable(const char *nodeName, uint32 nodePort, StringInfo tableName)
         CommandCounterIncrement();
     }
 
+    SetUserIdAndSecContext(savedUserId, savedSecurityContext);
+
     /*
      * Copy local file into the relation. We call ProcessUtility() instead of
      * directly calling DoCopy() because some extensions (e.g. cstore_fdw) hook
@@ -817,6 +838,33 @@ FetchForeignTable(const char *nodeName, uint32 nodePort, StringInfo tableName)
 }
 
 
+/*
+ * RemoteTableOwner takes in the given table name, and fetches the owner of
+ * the table. If an error occurs during fetching, the function returns NULL.
+ */
+static const char *
+RemoteTableOwner(const char *nodeName, uint32 nodePort, StringInfo tableName)
+{
+    List *ownerList = NIL;
+    StringInfo queryString = NULL;
+    const char *escapedTableName = quote_literal_cstr(tableName->data);
+    StringInfo relationOwner;
+
+    queryString = makeStringInfo();
+    appendStringInfo(queryString, GET_TABLE_OWNER, escapedTableName);
+
+    ownerList = ExecuteRemoteQuery(nodeName, nodePort, NULL, queryString);
+    if (list_length(ownerList) != 1)
+    {
+        return NULL;
+    }
+
+    relationOwner = (StringInfo) linitial(ownerList);
+
+    return relationOwner->data;
+}
+
+
 /*
  * TableDDLCommandList takes in the given table name, and fetches the list of
  * DDL commands used in creating the table. If an error occurs during fetching,
@@ -831,7 +879,7 @@ TableDDLCommandList(const char *nodeName, uint32 nodePort, StringInfo tableName)
     queryString = makeStringInfo();
     appendStringInfo(queryString, GET_TABLE_DDL_EVENTS, tableName->data);
 
-    ddlCommandList = ExecuteRemoteQuery(nodeName, nodePort, queryString);
+    ddlCommandList = ExecuteRemoteQuery(nodeName, nodePort, NULL, queryString);
 
     return ddlCommandList;
 }
@@ -851,7 +899,7 @@ ForeignFilePath(const char *nodeName, uint32 nodePort, StringInfo tableName)
     foreignPathCommand = makeStringInfo();
     appendStringInfo(foreignPathCommand, FOREIGN_FILE_PATH_COMMAND, tableName->data);
 
-    foreignPathList = ExecuteRemoteQuery(nodeName, nodePort, foreignPathCommand);
+    foreignPathList = ExecuteRemoteQuery(nodeName, nodePort, NULL, foreignPathCommand);
     if (foreignPathList != NIL)
     {
         foreignPath = (StringInfo) linitial(foreignPathList);
@@ -866,9 +914,12 @@ ForeignFilePath(const char *nodeName, uint32 nodePort, StringInfo tableName)
  * sorted list, and returns this list. The function assumes that query results
  * have a single column, and asserts on that assumption. If results are empty,
  * or an error occurs during query runtime, the function returns an empty list.
+ * If runAsUser is NULL the connection is established as the current user,
+ * otherwise as the specified user.
  */
 List *
-ExecuteRemoteQuery(const char *nodeName, uint32 nodePort, StringInfo queryString)
+ExecuteRemoteQuery(const char *nodeName, uint32 nodePort, char *runAsUser,
+                   StringInfo queryString)
 {
     int32 connectionId = -1;
     bool querySent = false;
@@ -880,7 +931,7 @@ ExecuteRemoteQuery(const char *nodeName, uint32 nodePort, StringInfo queryString
     int columnCount = 0;
     List *resultList = NIL;
 
-    connectionId = MultiClientConnect(nodeName, nodePort, NULL, NULL);
+    connectionId = MultiClientConnect(nodeName, nodePort, NULL, runAsUser);
     if (connectionId == INVALID_CONNECTION_ID)
     {
         return NIL;
diff --git a/src/include/distributed/master_metadata_utility.h b/src/include/distributed/master_metadata_utility.h
index 3dd8e4927..3df148fbc 100644
--- a/src/include/distributed/master_metadata_utility.h
+++ b/src/include/distributed/master_metadata_utility.h
@@ -78,6 +78,7 @@ extern void DeleteShardPlacementRow(uint64 shardId, char *workerName, uint32 wor
 /* Remaining metadata utility functions */
 extern Node * BuildDistributionKeyFromColumnName(Relation distributedRelation,
                                                  char *columnName);
+extern char * TableOwner(Oid relationId);
 extern void EnsureTablePermissions(Oid relationId, AclMode mode);
 extern void EnsureTableOwner(Oid relationId);
 
diff --git a/src/include/distributed/master_protocol.h b/src/include/distributed/master_protocol.h
index e9f76a093..95ecb5cbf 100644
--- a/src/include/distributed/master_protocol.h
+++ b/src/include/distributed/master_protocol.h
@@ -83,6 +83,7 @@ extern Oid ResolveRelationId(text *relationName);
 extern List * GetTableDDLEvents(Oid relationId);
 extern void CheckDistributedTable(Oid relationId);
 extern void CreateShardPlacements(int64 shardId, List *ddlEventList,
+                                  char *newPlacementOwner,
                                   List *workerNodeList, int workerStartIndex,
                                   int replicationFactor);
 extern uint64 UpdateShardStatistics(Oid relationId, int64 shardId);
diff --git a/src/include/distributed/worker_protocol.h b/src/include/distributed/worker_protocol.h
index a09b27fbf..b74fa0158 100644
--- a/src/include/distributed/worker_protocol.h
+++ b/src/include/distributed/worker_protocol.h
@@ -48,6 +48,9 @@
 
 /* Defines that relate to fetching foreign tables */
 #define FOREIGN_CACHED_FILE_PATH "pg_foreign_file/cached/%s"
+#define GET_TABLE_OWNER \
+	"SELECT rolname FROM pg_class JOIN pg_roles ON (pg_roles.oid = pg_class.relowner) " \
+	"WHERE pg_class.oid = %s::regclass"
 #define GET_TABLE_DDL_EVENTS "SELECT master_get_table_ddl_events('%s')"
 #define SET_FOREIGN_TABLE_FILENAME "ALTER FOREIGN TABLE %s OPTIONS (SET filename '%s')"
 #define FOREIGN_FILE_PATH_COMMAND "SELECT worker_foreign_file_path('%s')"
@@ -119,7 +122,7 @@ extern FmgrInfo * GetFunctionInfo(Oid typeId, Oid accessMethodId, int16 procedur
 
 /* Function declarations shared with the master planner */
 extern StringInfo TaskFilename(StringInfo directoryName, uint32 taskId);
-extern List * ExecuteRemoteQuery(const char *nodeName, uint32 nodePort,
+extern List * ExecuteRemoteQuery(const char *nodeName, uint32 nodePort, char *runAsUser,
                                  StringInfo queryString);
 extern List * ColumnDefinitionList(List *columnNameList, List *columnTypeList);
 extern CreateStmt * CreateStatement(RangeVar *relation, List *columnDefinitionList);