From 620dd121371a845dde56b4117c2552939417549f Mon Sep 17 00:00:00 2001 From: Sait Talha Nisanci Date: Wed, 30 Dec 2020 14:18:40 +0300 Subject: [PATCH] Add security utils We don't need to know the details of what needs to be done for switching to citus security context, so it is abstracted away. --- src/backend/columnar/cstore_metadata_tables.c | 8 +++--- src/backend/distributed/metadata/distobject.c | 10 +++---- .../distributed/metadata/node_metadata.c | 16 +++++------- .../distributed/operations/node_protocol.c | 16 ++++-------- .../distributed/utils/colocation_utils.c | 8 +++--- .../worker/worker_data_fetch_protocol.c | 9 +++---- .../worker/worker_merge_protocol.c | 9 +++---- src/include/distributed/security_utils.h | 26 +++++++++++++++++++ 8 files changed, 52 insertions(+), 50 deletions(-) create mode 100644 src/include/distributed/security_utils.h diff --git a/src/backend/columnar/cstore_metadata_tables.c b/src/backend/columnar/cstore_metadata_tables.c index 9f561f38c..a20c9bc66 100644 --- a/src/backend/columnar/cstore_metadata_tables.c +++ b/src/backend/columnar/cstore_metadata_tables.c @@ -29,6 +29,7 @@ #include "commands/defrem.h" #include "commands/trigger.h" #include "distributed/metadata_cache.h" +#include "distributed/security_utils.h" #include "executor/executor.h" #include "executor/spi.h" #include "miscadmin.h" @@ -1199,8 +1200,6 @@ InitMetapage(Relation relation) static uint64 GetNextStorageId(void) { - Oid savedUserId = InvalidOid; - int savedSecurityContext = 0; Oid sequenceId = get_relname_relid("storageid_seq", CStoreNamespaceId()); Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId); @@ -1208,15 +1207,14 @@ GetNextStorageId(void) * Not all users have update access to the sequence, so switch * security context. */ - GetUserIdAndSecContext(&savedUserId, &savedSecurityContext); - SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE); + PushCitusSecurityContext(); /* * Generate new and unique storage id from sequence. */ Datum storageIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum); - SetUserIdAndSecContext(savedUserId, savedSecurityContext); + PopCitusSecurityContext(); uint64 storageId = DatumGetInt64(storageIdDatum); diff --git a/src/backend/distributed/metadata/distobject.c b/src/backend/distributed/metadata/distobject.c index 86823c62e..274b64b45 100644 --- a/src/backend/distributed/metadata/distobject.c +++ b/src/backend/distributed/metadata/distobject.c @@ -31,6 +31,7 @@ #include "distributed/metadata/distobject.h" #include "distributed/metadata/pg_dist_object.h" #include "distributed/metadata_cache.h" +#include "distributed/security_utils.h" #include "distributed/version_compat.h" #include "executor/spi.h" #include "nodes/makefuncs.h" @@ -189,9 +190,6 @@ static int ExecuteCommandAsSuperuser(char *query, int paramCount, Oid *paramTypes, Datum *paramValues) { - Oid savedUserId = InvalidOid; - int savedSecurityContext = 0; - int spiConnected = SPI_connect(); if (spiConnected != SPI_OK_CONNECT) { @@ -199,13 +197,11 @@ ExecuteCommandAsSuperuser(char *query, int paramCount, Oid *paramTypes, } /* make sure we have write access */ - GetUserIdAndSecContext(&savedUserId, &savedSecurityContext); - SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE); - + PushCitusSecurityContext(); int spiStatus = SPI_execute_with_args(query, paramCount, paramTypes, paramValues, NULL, false, 0); - SetUserIdAndSecContext(savedUserId, savedSecurityContext); + PopCitusSecurityContext(); int spiFinished = SPI_finish(); if (spiFinished != SPI_OK_FINISH) diff --git a/src/backend/distributed/metadata/node_metadata.c b/src/backend/distributed/metadata/node_metadata.c index 335cb115b..af391eda9 100644 --- a/src/backend/distributed/metadata/node_metadata.c +++ b/src/backend/distributed/metadata/node_metadata.c @@ -39,6 +39,7 @@ #include "distributed/reference_table_utils.h" #include "distributed/remote_commands.h" #include "distributed/resource_lock.h" +#include "distributed/security_utils.h" #include "distributed/shardinterval_utils.h" #include "distributed/shared_connection_stats.h" #include "distributed/string_utils.h" @@ -1473,16 +1474,14 @@ GetNextGroupId() text *sequenceName = cstring_to_text(GROUPID_SEQUENCE_NAME); Oid sequenceId = ResolveRelationId(sequenceName, false); Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId); - Oid savedUserId = InvalidOid; - int savedSecurityContext = 0; - GetUserIdAndSecContext(&savedUserId, &savedSecurityContext); - SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE); + + PushCitusSecurityContext(); /* generate new and unique shardId from sequence */ Datum groupIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum); - SetUserIdAndSecContext(savedUserId, savedSecurityContext); + PopCitusSecurityContext(); int32 groupId = DatumGetInt32(groupIdDatum); @@ -1505,16 +1504,13 @@ GetNextNodeId() text *sequenceName = cstring_to_text(NODEID_SEQUENCE_NAME); Oid sequenceId = ResolveRelationId(sequenceName, false); Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId); - Oid savedUserId = InvalidOid; - int savedSecurityContext = 0; - GetUserIdAndSecContext(&savedUserId, &savedSecurityContext); - SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE); + PushCitusSecurityContext(); /* generate new and unique shardId from sequence */ Datum nextNodeIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum); - SetUserIdAndSecContext(savedUserId, savedSecurityContext); + PopCitusSecurityContext(); int nextNodeId = DatumGetUInt32(nextNodeIdDatum); diff --git a/src/backend/distributed/operations/node_protocol.c b/src/backend/distributed/operations/node_protocol.c index 217f272e0..c3348be1d 100644 --- a/src/backend/distributed/operations/node_protocol.c +++ b/src/backend/distributed/operations/node_protocol.c @@ -48,6 +48,7 @@ #include "distributed/metadata_sync.h" #include "distributed/namespace_utils.h" #include "distributed/pg_dist_shard.h" +#include "distributed/security_utils.h" #include "distributed/version_compat.h" #include "distributed/worker_manager.h" #include "foreign/foreign.h" @@ -293,8 +294,6 @@ master_get_new_shardid(PG_FUNCTION_ARGS) uint64 GetNextShardId() { - Oid savedUserId = InvalidOid; - int savedSecurityContext = 0; uint64 shardId = 0; /* @@ -316,13 +315,11 @@ GetNextShardId() Oid sequenceId = ResolveRelationId(sequenceName, false); Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId); - GetUserIdAndSecContext(&savedUserId, &savedSecurityContext); - SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE); + PushCitusSecurityContext(); /* generate new and unique shardId from sequence */ Datum shardIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum); - - SetUserIdAndSecContext(savedUserId, savedSecurityContext); + PopCitusSecurityContext(); shardId = DatumGetInt64(shardIdDatum); @@ -365,8 +362,6 @@ master_get_new_placementid(PG_FUNCTION_ARGS) uint64 GetNextPlacementId(void) { - Oid savedUserId = InvalidOid; - int savedSecurityContext = 0; uint64 placementId = 0; /* @@ -388,13 +383,12 @@ GetNextPlacementId(void) Oid sequenceId = ResolveRelationId(sequenceName, false); Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId); - GetUserIdAndSecContext(&savedUserId, &savedSecurityContext); - SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE); + PushCitusSecurityContext(); /* generate new and unique placement id from sequence */ Datum placementIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum); - SetUserIdAndSecContext(savedUserId, savedSecurityContext); + PopCitusSecurityContext(); placementId = DatumGetInt64(placementIdDatum); diff --git a/src/backend/distributed/utils/colocation_utils.c b/src/backend/distributed/utils/colocation_utils.c index b1038fd5f..8294523b0 100644 --- a/src/backend/distributed/utils/colocation_utils.c +++ b/src/backend/distributed/utils/colocation_utils.c @@ -28,6 +28,7 @@ #include "distributed/multi_logical_planner.h" #include "distributed/pg_dist_colocation.h" #include "distributed/resource_lock.h" +#include "distributed/security_utils.h" #include "distributed/shardinterval_utils.h" #include "distributed/version_compat.h" #include "distributed/worker_protocol.h" @@ -614,16 +615,13 @@ GetNextColocationId() text *sequenceName = cstring_to_text(COLOCATIONID_SEQUENCE_NAME); Oid sequenceId = ResolveRelationId(sequenceName, false); Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId); - Oid savedUserId = InvalidOid; - int savedSecurityContext = 0; - GetUserIdAndSecContext(&savedUserId, &savedSecurityContext); - SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE); + PushCitusSecurityContext(); /* generate new and unique colocation id from sequence */ Datum colocationIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum); - SetUserIdAndSecContext(savedUserId, savedSecurityContext); + PopCitusSecurityContext(); uint32 colocationId = DatumGetUInt32(colocationIdDatum); diff --git a/src/backend/distributed/worker/worker_data_fetch_protocol.c b/src/backend/distributed/worker/worker_data_fetch_protocol.c index 9c9982a6f..657c0dc17 100644 --- a/src/backend/distributed/worker/worker_data_fetch_protocol.c +++ b/src/backend/distributed/worker/worker_data_fetch_protocol.c @@ -41,6 +41,7 @@ #include "distributed/relay_utility.h" #include "distributed/remote_commands.h" #include "distributed/resource_lock.h" +#include "distributed/security_utils.h" #include "distributed/worker_protocol.h" #include "distributed/version_compat.h" @@ -594,9 +595,6 @@ worker_append_table_to_shard(PG_FUNCTION_ARGS) char *sourceSchemaName = NULL; char *sourceTableName = NULL; - Oid savedUserId = InvalidOid; - int savedSecurityContext = 0; - CheckCitusVersion(ERROR); /* We extract schema names and table names from qualified names */ @@ -665,13 +663,12 @@ worker_append_table_to_shard(PG_FUNCTION_ARGS) CheckCopyPermissions(localCopyCommand); /* need superuser to copy from files */ - GetUserIdAndSecContext(&savedUserId, &savedSecurityContext); - SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE); + PushCitusSecurityContext(); CitusProcessUtility((Node *) localCopyCommand, queryString->data, PROCESS_UTILITY_TOPLEVEL, NULL, None_Receiver, NULL); - SetUserIdAndSecContext(savedUserId, savedSecurityContext); + PopCitusSecurityContext(); /* finally delete the temporary file we created */ CitusDeleteFile(localFilePath->data); diff --git a/src/backend/distributed/worker/worker_merge_protocol.c b/src/backend/distributed/worker/worker_merge_protocol.c index b32204aa6..9324158b1 100644 --- a/src/backend/distributed/worker/worker_merge_protocol.c +++ b/src/backend/distributed/worker/worker_merge_protocol.c @@ -32,6 +32,7 @@ #include "commands/tablecmds.h" #include "common/string.h" #include "distributed/metadata_cache.h" +#include "distributed/security_utils.h" #include "distributed/worker_protocol.h" #include "distributed/version_compat.h" @@ -183,8 +184,6 @@ worker_merge_files_into_table(PG_FUNCTION_ARGS) StringInfo jobSchemaName = JobSchemaName(jobId); StringInfo taskTableName = TaskTableName(taskId); StringInfo taskDirectoryName = TaskDirectoryName(jobId, taskId); - Oid savedUserId = InvalidOid; - int savedSecurityContext = 0; Oid userId = GetUserId(); /* we should have the same number of column names and types */ @@ -234,13 +233,11 @@ worker_merge_files_into_table(PG_FUNCTION_ARGS) CreateTaskTable(jobSchemaName, taskTableName, columnNameList, columnTypeList); /* need superuser to copy from files */ - GetUserIdAndSecContext(&savedUserId, &savedSecurityContext); - SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE); - + PushCitusSecurityContext(); CopyTaskFilesFromDirectory(jobSchemaName, taskTableName, taskDirectoryName, userId); - SetUserIdAndSecContext(savedUserId, savedSecurityContext); + PopCitusSecurityContext(); PG_RETURN_VOID(); } diff --git a/src/include/distributed/security_utils.h b/src/include/distributed/security_utils.h new file mode 100644 index 000000000..63b9b551f --- /dev/null +++ b/src/include/distributed/security_utils.h @@ -0,0 +1,26 @@ +/*------------------------------------------------------------------------- + * + * security_utils.h + * security related utility functions. + * + * Copyright (c) Citus Data, Inc. + * + *------------------------------------------------------------------------- + */ + +#ifndef SECURITY_UTILS_H_ +#define SECURITY_UTILS_H_ + +#include "postgres.h" +#include "miscadmin.h" + +#define PushCitusSecurityContext() \ + Oid savedUserId_DONTUSE = InvalidOid; \ + int savedSecurityContext_DONTUSE = 0; \ + GetUserIdAndSecContext(&savedUserId_DONTUSE, &savedSecurityContext_DONTUSE); \ + SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE); + +#define PopCitusSecurityContext() \ + SetUserIdAndSecContext(savedUserId_DONTUSE, savedSecurityContext_DONTUSE); + +#endif