Add security utils

We don't need to know the details of what needs to be done for switching
to citus security context, so it is abstracted away.
refactor/security_utils
Sait Talha Nisanci 2020-12-30 14:18:40 +03:00
parent e91e745dbc
commit 620dd12137
8 changed files with 52 additions and 50 deletions

View File

@ -29,6 +29,7 @@
#include "commands/defrem.h"
#include "commands/trigger.h"
#include "distributed/metadata_cache.h"
#include "distributed/security_utils.h"
#include "executor/executor.h"
#include "executor/spi.h"
#include "miscadmin.h"
@ -1199,8 +1200,6 @@ InitMetapage(Relation relation)
static uint64
GetNextStorageId(void)
{
Oid savedUserId = InvalidOid;
int savedSecurityContext = 0;
Oid sequenceId = get_relname_relid("storageid_seq", CStoreNamespaceId());
Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
@ -1208,15 +1207,14 @@ GetNextStorageId(void)
* Not all users have update access to the sequence, so switch
* security context.
*/
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
PushCitusSecurityContext();
/*
* Generate new and unique storage id from sequence.
*/
Datum storageIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
PopCitusSecurityContext();
uint64 storageId = DatumGetInt64(storageIdDatum);

View File

@ -31,6 +31,7 @@
#include "distributed/metadata/distobject.h"
#include "distributed/metadata/pg_dist_object.h"
#include "distributed/metadata_cache.h"
#include "distributed/security_utils.h"
#include "distributed/version_compat.h"
#include "executor/spi.h"
#include "nodes/makefuncs.h"
@ -189,9 +190,6 @@ static int
ExecuteCommandAsSuperuser(char *query, int paramCount, Oid *paramTypes,
Datum *paramValues)
{
Oid savedUserId = InvalidOid;
int savedSecurityContext = 0;
int spiConnected = SPI_connect();
if (spiConnected != SPI_OK_CONNECT)
{
@ -199,13 +197,11 @@ ExecuteCommandAsSuperuser(char *query, int paramCount, Oid *paramTypes,
}
/* make sure we have write access */
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
PushCitusSecurityContext();
int spiStatus = SPI_execute_with_args(query, paramCount, paramTypes, paramValues,
NULL, false, 0);
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
PopCitusSecurityContext();
int spiFinished = SPI_finish();
if (spiFinished != SPI_OK_FINISH)

View File

@ -39,6 +39,7 @@
#include "distributed/reference_table_utils.h"
#include "distributed/remote_commands.h"
#include "distributed/resource_lock.h"
#include "distributed/security_utils.h"
#include "distributed/shardinterval_utils.h"
#include "distributed/shared_connection_stats.h"
#include "distributed/string_utils.h"
@ -1473,16 +1474,14 @@ GetNextGroupId()
text *sequenceName = cstring_to_text(GROUPID_SEQUENCE_NAME);
Oid sequenceId = ResolveRelationId(sequenceName, false);
Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
Oid savedUserId = InvalidOid;
int savedSecurityContext = 0;
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
PushCitusSecurityContext();
/* generate new and unique shardId from sequence */
Datum groupIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
PopCitusSecurityContext();
int32 groupId = DatumGetInt32(groupIdDatum);
@ -1505,16 +1504,13 @@ GetNextNodeId()
text *sequenceName = cstring_to_text(NODEID_SEQUENCE_NAME);
Oid sequenceId = ResolveRelationId(sequenceName, false);
Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
Oid savedUserId = InvalidOid;
int savedSecurityContext = 0;
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
PushCitusSecurityContext();
/* generate new and unique shardId from sequence */
Datum nextNodeIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
PopCitusSecurityContext();
int nextNodeId = DatumGetUInt32(nextNodeIdDatum);

View File

@ -48,6 +48,7 @@
#include "distributed/metadata_sync.h"
#include "distributed/namespace_utils.h"
#include "distributed/pg_dist_shard.h"
#include "distributed/security_utils.h"
#include "distributed/version_compat.h"
#include "distributed/worker_manager.h"
#include "foreign/foreign.h"
@ -293,8 +294,6 @@ master_get_new_shardid(PG_FUNCTION_ARGS)
uint64
GetNextShardId()
{
Oid savedUserId = InvalidOid;
int savedSecurityContext = 0;
uint64 shardId = 0;
/*
@ -316,13 +315,11 @@ GetNextShardId()
Oid sequenceId = ResolveRelationId(sequenceName, false);
Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
PushCitusSecurityContext();
/* generate new and unique shardId from sequence */
Datum shardIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
PopCitusSecurityContext();
shardId = DatumGetInt64(shardIdDatum);
@ -365,8 +362,6 @@ master_get_new_placementid(PG_FUNCTION_ARGS)
uint64
GetNextPlacementId(void)
{
Oid savedUserId = InvalidOid;
int savedSecurityContext = 0;
uint64 placementId = 0;
/*
@ -388,13 +383,12 @@ GetNextPlacementId(void)
Oid sequenceId = ResolveRelationId(sequenceName, false);
Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
PushCitusSecurityContext();
/* generate new and unique placement id from sequence */
Datum placementIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
PopCitusSecurityContext();
placementId = DatumGetInt64(placementIdDatum);

View File

@ -28,6 +28,7 @@
#include "distributed/multi_logical_planner.h"
#include "distributed/pg_dist_colocation.h"
#include "distributed/resource_lock.h"
#include "distributed/security_utils.h"
#include "distributed/shardinterval_utils.h"
#include "distributed/version_compat.h"
#include "distributed/worker_protocol.h"
@ -614,16 +615,13 @@ GetNextColocationId()
text *sequenceName = cstring_to_text(COLOCATIONID_SEQUENCE_NAME);
Oid sequenceId = ResolveRelationId(sequenceName, false);
Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
Oid savedUserId = InvalidOid;
int savedSecurityContext = 0;
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
PushCitusSecurityContext();
/* generate new and unique colocation id from sequence */
Datum colocationIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
PopCitusSecurityContext();
uint32 colocationId = DatumGetUInt32(colocationIdDatum);

View File

@ -41,6 +41,7 @@
#include "distributed/relay_utility.h"
#include "distributed/remote_commands.h"
#include "distributed/resource_lock.h"
#include "distributed/security_utils.h"
#include "distributed/worker_protocol.h"
#include "distributed/version_compat.h"
@ -594,9 +595,6 @@ worker_append_table_to_shard(PG_FUNCTION_ARGS)
char *sourceSchemaName = NULL;
char *sourceTableName = NULL;
Oid savedUserId = InvalidOid;
int savedSecurityContext = 0;
CheckCitusVersion(ERROR);
/* We extract schema names and table names from qualified names */
@ -665,13 +663,12 @@ worker_append_table_to_shard(PG_FUNCTION_ARGS)
CheckCopyPermissions(localCopyCommand);
/* need superuser to copy from files */
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
PushCitusSecurityContext();
CitusProcessUtility((Node *) localCopyCommand, queryString->data,
PROCESS_UTILITY_TOPLEVEL, NULL, None_Receiver, NULL);
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
PopCitusSecurityContext();
/* finally delete the temporary file we created */
CitusDeleteFile(localFilePath->data);

View File

@ -32,6 +32,7 @@
#include "commands/tablecmds.h"
#include "common/string.h"
#include "distributed/metadata_cache.h"
#include "distributed/security_utils.h"
#include "distributed/worker_protocol.h"
#include "distributed/version_compat.h"
@ -183,8 +184,6 @@ worker_merge_files_into_table(PG_FUNCTION_ARGS)
StringInfo jobSchemaName = JobSchemaName(jobId);
StringInfo taskTableName = TaskTableName(taskId);
StringInfo taskDirectoryName = TaskDirectoryName(jobId, taskId);
Oid savedUserId = InvalidOid;
int savedSecurityContext = 0;
Oid userId = GetUserId();
/* we should have the same number of column names and types */
@ -234,13 +233,11 @@ worker_merge_files_into_table(PG_FUNCTION_ARGS)
CreateTaskTable(jobSchemaName, taskTableName, columnNameList, columnTypeList);
/* need superuser to copy from files */
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
PushCitusSecurityContext();
CopyTaskFilesFromDirectory(jobSchemaName, taskTableName, taskDirectoryName,
userId);
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
PopCitusSecurityContext();
PG_RETURN_VOID();
}

View File

@ -0,0 +1,26 @@
/*-------------------------------------------------------------------------
*
* security_utils.h
* security related utility functions.
*
* Copyright (c) Citus Data, Inc.
*
*-------------------------------------------------------------------------
*/
#ifndef SECURITY_UTILS_H_
#define SECURITY_UTILS_H_
#include "postgres.h"
#include "miscadmin.h"
#define PushCitusSecurityContext() \
Oid savedUserId_DONTUSE = InvalidOid; \
int savedSecurityContext_DONTUSE = 0; \
GetUserIdAndSecContext(&savedUserId_DONTUSE, &savedSecurityContext_DONTUSE); \
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
#define PopCitusSecurityContext() \
SetUserIdAndSecContext(savedUserId_DONTUSE, savedSecurityContext_DONTUSE);
#endif