mirror of https://github.com/citusdata/citus.git
Add security utils
We don't need to know the details of what needs to be done for switching to citus security context, so it is abstracted away.refactor/security_utils
parent
e91e745dbc
commit
620dd12137
|
@ -29,6 +29,7 @@
|
|||
#include "commands/defrem.h"
|
||||
#include "commands/trigger.h"
|
||||
#include "distributed/metadata_cache.h"
|
||||
#include "distributed/security_utils.h"
|
||||
#include "executor/executor.h"
|
||||
#include "executor/spi.h"
|
||||
#include "miscadmin.h"
|
||||
|
@ -1199,8 +1200,6 @@ InitMetapage(Relation relation)
|
|||
static uint64
|
||||
GetNextStorageId(void)
|
||||
{
|
||||
Oid savedUserId = InvalidOid;
|
||||
int savedSecurityContext = 0;
|
||||
Oid sequenceId = get_relname_relid("storageid_seq", CStoreNamespaceId());
|
||||
Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
|
||||
|
||||
|
@ -1208,15 +1207,14 @@ GetNextStorageId(void)
|
|||
* Not all users have update access to the sequence, so switch
|
||||
* security context.
|
||||
*/
|
||||
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
|
||||
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
|
||||
PushCitusSecurityContext();
|
||||
|
||||
/*
|
||||
* Generate new and unique storage id from sequence.
|
||||
*/
|
||||
Datum storageIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);
|
||||
|
||||
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
|
||||
PopCitusSecurityContext();
|
||||
|
||||
uint64 storageId = DatumGetInt64(storageIdDatum);
|
||||
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
#include "distributed/metadata/distobject.h"
|
||||
#include "distributed/metadata/pg_dist_object.h"
|
||||
#include "distributed/metadata_cache.h"
|
||||
#include "distributed/security_utils.h"
|
||||
#include "distributed/version_compat.h"
|
||||
#include "executor/spi.h"
|
||||
#include "nodes/makefuncs.h"
|
||||
|
@ -189,9 +190,6 @@ static int
|
|||
ExecuteCommandAsSuperuser(char *query, int paramCount, Oid *paramTypes,
|
||||
Datum *paramValues)
|
||||
{
|
||||
Oid savedUserId = InvalidOid;
|
||||
int savedSecurityContext = 0;
|
||||
|
||||
int spiConnected = SPI_connect();
|
||||
if (spiConnected != SPI_OK_CONNECT)
|
||||
{
|
||||
|
@ -199,13 +197,11 @@ ExecuteCommandAsSuperuser(char *query, int paramCount, Oid *paramTypes,
|
|||
}
|
||||
|
||||
/* make sure we have write access */
|
||||
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
|
||||
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
|
||||
|
||||
PushCitusSecurityContext();
|
||||
int spiStatus = SPI_execute_with_args(query, paramCount, paramTypes, paramValues,
|
||||
NULL, false, 0);
|
||||
|
||||
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
|
||||
PopCitusSecurityContext();
|
||||
|
||||
int spiFinished = SPI_finish();
|
||||
if (spiFinished != SPI_OK_FINISH)
|
||||
|
|
|
@ -39,6 +39,7 @@
|
|||
#include "distributed/reference_table_utils.h"
|
||||
#include "distributed/remote_commands.h"
|
||||
#include "distributed/resource_lock.h"
|
||||
#include "distributed/security_utils.h"
|
||||
#include "distributed/shardinterval_utils.h"
|
||||
#include "distributed/shared_connection_stats.h"
|
||||
#include "distributed/string_utils.h"
|
||||
|
@ -1473,16 +1474,14 @@ GetNextGroupId()
|
|||
text *sequenceName = cstring_to_text(GROUPID_SEQUENCE_NAME);
|
||||
Oid sequenceId = ResolveRelationId(sequenceName, false);
|
||||
Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
|
||||
Oid savedUserId = InvalidOid;
|
||||
int savedSecurityContext = 0;
|
||||
|
||||
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
|
||||
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
|
||||
|
||||
PushCitusSecurityContext();
|
||||
|
||||
/* generate new and unique shardId from sequence */
|
||||
Datum groupIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);
|
||||
|
||||
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
|
||||
PopCitusSecurityContext();
|
||||
|
||||
int32 groupId = DatumGetInt32(groupIdDatum);
|
||||
|
||||
|
@ -1505,16 +1504,13 @@ GetNextNodeId()
|
|||
text *sequenceName = cstring_to_text(NODEID_SEQUENCE_NAME);
|
||||
Oid sequenceId = ResolveRelationId(sequenceName, false);
|
||||
Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
|
||||
Oid savedUserId = InvalidOid;
|
||||
int savedSecurityContext = 0;
|
||||
|
||||
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
|
||||
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
|
||||
PushCitusSecurityContext();
|
||||
|
||||
/* generate new and unique shardId from sequence */
|
||||
Datum nextNodeIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);
|
||||
|
||||
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
|
||||
PopCitusSecurityContext();
|
||||
|
||||
int nextNodeId = DatumGetUInt32(nextNodeIdDatum);
|
||||
|
||||
|
|
|
@ -48,6 +48,7 @@
|
|||
#include "distributed/metadata_sync.h"
|
||||
#include "distributed/namespace_utils.h"
|
||||
#include "distributed/pg_dist_shard.h"
|
||||
#include "distributed/security_utils.h"
|
||||
#include "distributed/version_compat.h"
|
||||
#include "distributed/worker_manager.h"
|
||||
#include "foreign/foreign.h"
|
||||
|
@ -293,8 +294,6 @@ master_get_new_shardid(PG_FUNCTION_ARGS)
|
|||
uint64
|
||||
GetNextShardId()
|
||||
{
|
||||
Oid savedUserId = InvalidOid;
|
||||
int savedSecurityContext = 0;
|
||||
uint64 shardId = 0;
|
||||
|
||||
/*
|
||||
|
@ -316,13 +315,11 @@ GetNextShardId()
|
|||
Oid sequenceId = ResolveRelationId(sequenceName, false);
|
||||
Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
|
||||
|
||||
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
|
||||
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
|
||||
PushCitusSecurityContext();
|
||||
|
||||
/* generate new and unique shardId from sequence */
|
||||
Datum shardIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);
|
||||
|
||||
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
|
||||
PopCitusSecurityContext();
|
||||
|
||||
shardId = DatumGetInt64(shardIdDatum);
|
||||
|
||||
|
@ -365,8 +362,6 @@ master_get_new_placementid(PG_FUNCTION_ARGS)
|
|||
uint64
|
||||
GetNextPlacementId(void)
|
||||
{
|
||||
Oid savedUserId = InvalidOid;
|
||||
int savedSecurityContext = 0;
|
||||
uint64 placementId = 0;
|
||||
|
||||
/*
|
||||
|
@ -388,13 +383,12 @@ GetNextPlacementId(void)
|
|||
Oid sequenceId = ResolveRelationId(sequenceName, false);
|
||||
Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
|
||||
|
||||
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
|
||||
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
|
||||
PushCitusSecurityContext();
|
||||
|
||||
/* generate new and unique placement id from sequence */
|
||||
Datum placementIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);
|
||||
|
||||
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
|
||||
PopCitusSecurityContext();
|
||||
|
||||
placementId = DatumGetInt64(placementIdDatum);
|
||||
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include "distributed/multi_logical_planner.h"
|
||||
#include "distributed/pg_dist_colocation.h"
|
||||
#include "distributed/resource_lock.h"
|
||||
#include "distributed/security_utils.h"
|
||||
#include "distributed/shardinterval_utils.h"
|
||||
#include "distributed/version_compat.h"
|
||||
#include "distributed/worker_protocol.h"
|
||||
|
@ -614,16 +615,13 @@ GetNextColocationId()
|
|||
text *sequenceName = cstring_to_text(COLOCATIONID_SEQUENCE_NAME);
|
||||
Oid sequenceId = ResolveRelationId(sequenceName, false);
|
||||
Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
|
||||
Oid savedUserId = InvalidOid;
|
||||
int savedSecurityContext = 0;
|
||||
|
||||
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
|
||||
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
|
||||
PushCitusSecurityContext();
|
||||
|
||||
/* generate new and unique colocation id from sequence */
|
||||
Datum colocationIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);
|
||||
|
||||
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
|
||||
PopCitusSecurityContext();
|
||||
|
||||
uint32 colocationId = DatumGetUInt32(colocationIdDatum);
|
||||
|
||||
|
|
|
@ -41,6 +41,7 @@
|
|||
#include "distributed/relay_utility.h"
|
||||
#include "distributed/remote_commands.h"
|
||||
#include "distributed/resource_lock.h"
|
||||
#include "distributed/security_utils.h"
|
||||
|
||||
#include "distributed/worker_protocol.h"
|
||||
#include "distributed/version_compat.h"
|
||||
|
@ -594,9 +595,6 @@ worker_append_table_to_shard(PG_FUNCTION_ARGS)
|
|||
char *sourceSchemaName = NULL;
|
||||
char *sourceTableName = NULL;
|
||||
|
||||
Oid savedUserId = InvalidOid;
|
||||
int savedSecurityContext = 0;
|
||||
|
||||
CheckCitusVersion(ERROR);
|
||||
|
||||
/* We extract schema names and table names from qualified names */
|
||||
|
@ -665,13 +663,12 @@ worker_append_table_to_shard(PG_FUNCTION_ARGS)
|
|||
CheckCopyPermissions(localCopyCommand);
|
||||
|
||||
/* need superuser to copy from files */
|
||||
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
|
||||
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
|
||||
PushCitusSecurityContext();
|
||||
|
||||
CitusProcessUtility((Node *) localCopyCommand, queryString->data,
|
||||
PROCESS_UTILITY_TOPLEVEL, NULL, None_Receiver, NULL);
|
||||
|
||||
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
|
||||
PopCitusSecurityContext();
|
||||
|
||||
/* finally delete the temporary file we created */
|
||||
CitusDeleteFile(localFilePath->data);
|
||||
|
|
|
@ -32,6 +32,7 @@
|
|||
#include "commands/tablecmds.h"
|
||||
#include "common/string.h"
|
||||
#include "distributed/metadata_cache.h"
|
||||
#include "distributed/security_utils.h"
|
||||
#include "distributed/worker_protocol.h"
|
||||
#include "distributed/version_compat.h"
|
||||
|
||||
|
@ -183,8 +184,6 @@ worker_merge_files_into_table(PG_FUNCTION_ARGS)
|
|||
StringInfo jobSchemaName = JobSchemaName(jobId);
|
||||
StringInfo taskTableName = TaskTableName(taskId);
|
||||
StringInfo taskDirectoryName = TaskDirectoryName(jobId, taskId);
|
||||
Oid savedUserId = InvalidOid;
|
||||
int savedSecurityContext = 0;
|
||||
Oid userId = GetUserId();
|
||||
|
||||
/* we should have the same number of column names and types */
|
||||
|
@ -234,13 +233,11 @@ worker_merge_files_into_table(PG_FUNCTION_ARGS)
|
|||
CreateTaskTable(jobSchemaName, taskTableName, columnNameList, columnTypeList);
|
||||
|
||||
/* need superuser to copy from files */
|
||||
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
|
||||
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
|
||||
|
||||
PushCitusSecurityContext();
|
||||
CopyTaskFilesFromDirectory(jobSchemaName, taskTableName, taskDirectoryName,
|
||||
userId);
|
||||
|
||||
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
|
||||
PopCitusSecurityContext();
|
||||
PG_RETURN_VOID();
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
/*-------------------------------------------------------------------------
|
||||
*
|
||||
* security_utils.h
|
||||
* security related utility functions.
|
||||
*
|
||||
* Copyright (c) Citus Data, Inc.
|
||||
*
|
||||
*-------------------------------------------------------------------------
|
||||
*/
|
||||
|
||||
#ifndef SECURITY_UTILS_H_
|
||||
#define SECURITY_UTILS_H_
|
||||
|
||||
#include "postgres.h"
|
||||
#include "miscadmin.h"
|
||||
|
||||
#define PushCitusSecurityContext() \
|
||||
Oid savedUserId_DONTUSE = InvalidOid; \
|
||||
int savedSecurityContext_DONTUSE = 0; \
|
||||
GetUserIdAndSecContext(&savedUserId_DONTUSE, &savedSecurityContext_DONTUSE); \
|
||||
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
|
||||
|
||||
#define PopCitusSecurityContext() \
|
||||
SetUserIdAndSecContext(savedUserId_DONTUSE, savedSecurityContext_DONTUSE);
|
||||
|
||||
#endif
|
Loading…
Reference in New Issue