mirror of https://github.com/citusdata/citus.git
Add security utils
We don't need to know the details of what needs to be done for switching to citus security context, so it is abstracted away.refactor/security_utils
parent
e91e745dbc
commit
620dd12137
|
@ -29,6 +29,7 @@
|
||||||
#include "commands/defrem.h"
|
#include "commands/defrem.h"
|
||||||
#include "commands/trigger.h"
|
#include "commands/trigger.h"
|
||||||
#include "distributed/metadata_cache.h"
|
#include "distributed/metadata_cache.h"
|
||||||
|
#include "distributed/security_utils.h"
|
||||||
#include "executor/executor.h"
|
#include "executor/executor.h"
|
||||||
#include "executor/spi.h"
|
#include "executor/spi.h"
|
||||||
#include "miscadmin.h"
|
#include "miscadmin.h"
|
||||||
|
@ -1199,8 +1200,6 @@ InitMetapage(Relation relation)
|
||||||
static uint64
|
static uint64
|
||||||
GetNextStorageId(void)
|
GetNextStorageId(void)
|
||||||
{
|
{
|
||||||
Oid savedUserId = InvalidOid;
|
|
||||||
int savedSecurityContext = 0;
|
|
||||||
Oid sequenceId = get_relname_relid("storageid_seq", CStoreNamespaceId());
|
Oid sequenceId = get_relname_relid("storageid_seq", CStoreNamespaceId());
|
||||||
Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
|
Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
|
||||||
|
|
||||||
|
@ -1208,15 +1207,14 @@ GetNextStorageId(void)
|
||||||
* Not all users have update access to the sequence, so switch
|
* Not all users have update access to the sequence, so switch
|
||||||
* security context.
|
* security context.
|
||||||
*/
|
*/
|
||||||
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
|
PushCitusSecurityContext();
|
||||||
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
|
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Generate new and unique storage id from sequence.
|
* Generate new and unique storage id from sequence.
|
||||||
*/
|
*/
|
||||||
Datum storageIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);
|
Datum storageIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);
|
||||||
|
|
||||||
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
|
PopCitusSecurityContext();
|
||||||
|
|
||||||
uint64 storageId = DatumGetInt64(storageIdDatum);
|
uint64 storageId = DatumGetInt64(storageIdDatum);
|
||||||
|
|
||||||
|
|
|
@ -31,6 +31,7 @@
|
||||||
#include "distributed/metadata/distobject.h"
|
#include "distributed/metadata/distobject.h"
|
||||||
#include "distributed/metadata/pg_dist_object.h"
|
#include "distributed/metadata/pg_dist_object.h"
|
||||||
#include "distributed/metadata_cache.h"
|
#include "distributed/metadata_cache.h"
|
||||||
|
#include "distributed/security_utils.h"
|
||||||
#include "distributed/version_compat.h"
|
#include "distributed/version_compat.h"
|
||||||
#include "executor/spi.h"
|
#include "executor/spi.h"
|
||||||
#include "nodes/makefuncs.h"
|
#include "nodes/makefuncs.h"
|
||||||
|
@ -189,9 +190,6 @@ static int
|
||||||
ExecuteCommandAsSuperuser(char *query, int paramCount, Oid *paramTypes,
|
ExecuteCommandAsSuperuser(char *query, int paramCount, Oid *paramTypes,
|
||||||
Datum *paramValues)
|
Datum *paramValues)
|
||||||
{
|
{
|
||||||
Oid savedUserId = InvalidOid;
|
|
||||||
int savedSecurityContext = 0;
|
|
||||||
|
|
||||||
int spiConnected = SPI_connect();
|
int spiConnected = SPI_connect();
|
||||||
if (spiConnected != SPI_OK_CONNECT)
|
if (spiConnected != SPI_OK_CONNECT)
|
||||||
{
|
{
|
||||||
|
@ -199,13 +197,11 @@ ExecuteCommandAsSuperuser(char *query, int paramCount, Oid *paramTypes,
|
||||||
}
|
}
|
||||||
|
|
||||||
/* make sure we have write access */
|
/* make sure we have write access */
|
||||||
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
|
PushCitusSecurityContext();
|
||||||
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
|
|
||||||
|
|
||||||
int spiStatus = SPI_execute_with_args(query, paramCount, paramTypes, paramValues,
|
int spiStatus = SPI_execute_with_args(query, paramCount, paramTypes, paramValues,
|
||||||
NULL, false, 0);
|
NULL, false, 0);
|
||||||
|
|
||||||
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
|
PopCitusSecurityContext();
|
||||||
|
|
||||||
int spiFinished = SPI_finish();
|
int spiFinished = SPI_finish();
|
||||||
if (spiFinished != SPI_OK_FINISH)
|
if (spiFinished != SPI_OK_FINISH)
|
||||||
|
|
|
@ -39,6 +39,7 @@
|
||||||
#include "distributed/reference_table_utils.h"
|
#include "distributed/reference_table_utils.h"
|
||||||
#include "distributed/remote_commands.h"
|
#include "distributed/remote_commands.h"
|
||||||
#include "distributed/resource_lock.h"
|
#include "distributed/resource_lock.h"
|
||||||
|
#include "distributed/security_utils.h"
|
||||||
#include "distributed/shardinterval_utils.h"
|
#include "distributed/shardinterval_utils.h"
|
||||||
#include "distributed/shared_connection_stats.h"
|
#include "distributed/shared_connection_stats.h"
|
||||||
#include "distributed/string_utils.h"
|
#include "distributed/string_utils.h"
|
||||||
|
@ -1473,16 +1474,14 @@ GetNextGroupId()
|
||||||
text *sequenceName = cstring_to_text(GROUPID_SEQUENCE_NAME);
|
text *sequenceName = cstring_to_text(GROUPID_SEQUENCE_NAME);
|
||||||
Oid sequenceId = ResolveRelationId(sequenceName, false);
|
Oid sequenceId = ResolveRelationId(sequenceName, false);
|
||||||
Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
|
Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
|
||||||
Oid savedUserId = InvalidOid;
|
|
||||||
int savedSecurityContext = 0;
|
|
||||||
|
|
||||||
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
|
|
||||||
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
|
PushCitusSecurityContext();
|
||||||
|
|
||||||
/* generate new and unique shardId from sequence */
|
/* generate new and unique shardId from sequence */
|
||||||
Datum groupIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);
|
Datum groupIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);
|
||||||
|
|
||||||
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
|
PopCitusSecurityContext();
|
||||||
|
|
||||||
int32 groupId = DatumGetInt32(groupIdDatum);
|
int32 groupId = DatumGetInt32(groupIdDatum);
|
||||||
|
|
||||||
|
@ -1505,16 +1504,13 @@ GetNextNodeId()
|
||||||
text *sequenceName = cstring_to_text(NODEID_SEQUENCE_NAME);
|
text *sequenceName = cstring_to_text(NODEID_SEQUENCE_NAME);
|
||||||
Oid sequenceId = ResolveRelationId(sequenceName, false);
|
Oid sequenceId = ResolveRelationId(sequenceName, false);
|
||||||
Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
|
Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
|
||||||
Oid savedUserId = InvalidOid;
|
|
||||||
int savedSecurityContext = 0;
|
|
||||||
|
|
||||||
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
|
PushCitusSecurityContext();
|
||||||
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
|
|
||||||
|
|
||||||
/* generate new and unique shardId from sequence */
|
/* generate new and unique shardId from sequence */
|
||||||
Datum nextNodeIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);
|
Datum nextNodeIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);
|
||||||
|
|
||||||
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
|
PopCitusSecurityContext();
|
||||||
|
|
||||||
int nextNodeId = DatumGetUInt32(nextNodeIdDatum);
|
int nextNodeId = DatumGetUInt32(nextNodeIdDatum);
|
||||||
|
|
||||||
|
|
|
@ -48,6 +48,7 @@
|
||||||
#include "distributed/metadata_sync.h"
|
#include "distributed/metadata_sync.h"
|
||||||
#include "distributed/namespace_utils.h"
|
#include "distributed/namespace_utils.h"
|
||||||
#include "distributed/pg_dist_shard.h"
|
#include "distributed/pg_dist_shard.h"
|
||||||
|
#include "distributed/security_utils.h"
|
||||||
#include "distributed/version_compat.h"
|
#include "distributed/version_compat.h"
|
||||||
#include "distributed/worker_manager.h"
|
#include "distributed/worker_manager.h"
|
||||||
#include "foreign/foreign.h"
|
#include "foreign/foreign.h"
|
||||||
|
@ -293,8 +294,6 @@ master_get_new_shardid(PG_FUNCTION_ARGS)
|
||||||
uint64
|
uint64
|
||||||
GetNextShardId()
|
GetNextShardId()
|
||||||
{
|
{
|
||||||
Oid savedUserId = InvalidOid;
|
|
||||||
int savedSecurityContext = 0;
|
|
||||||
uint64 shardId = 0;
|
uint64 shardId = 0;
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -316,13 +315,11 @@ GetNextShardId()
|
||||||
Oid sequenceId = ResolveRelationId(sequenceName, false);
|
Oid sequenceId = ResolveRelationId(sequenceName, false);
|
||||||
Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
|
Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
|
||||||
|
|
||||||
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
|
PushCitusSecurityContext();
|
||||||
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
|
|
||||||
|
|
||||||
/* generate new and unique shardId from sequence */
|
/* generate new and unique shardId from sequence */
|
||||||
Datum shardIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);
|
Datum shardIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);
|
||||||
|
PopCitusSecurityContext();
|
||||||
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
|
|
||||||
|
|
||||||
shardId = DatumGetInt64(shardIdDatum);
|
shardId = DatumGetInt64(shardIdDatum);
|
||||||
|
|
||||||
|
@ -365,8 +362,6 @@ master_get_new_placementid(PG_FUNCTION_ARGS)
|
||||||
uint64
|
uint64
|
||||||
GetNextPlacementId(void)
|
GetNextPlacementId(void)
|
||||||
{
|
{
|
||||||
Oid savedUserId = InvalidOid;
|
|
||||||
int savedSecurityContext = 0;
|
|
||||||
uint64 placementId = 0;
|
uint64 placementId = 0;
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
@ -388,13 +383,12 @@ GetNextPlacementId(void)
|
||||||
Oid sequenceId = ResolveRelationId(sequenceName, false);
|
Oid sequenceId = ResolveRelationId(sequenceName, false);
|
||||||
Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
|
Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
|
||||||
|
|
||||||
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
|
PushCitusSecurityContext();
|
||||||
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
|
|
||||||
|
|
||||||
/* generate new and unique placement id from sequence */
|
/* generate new and unique placement id from sequence */
|
||||||
Datum placementIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);
|
Datum placementIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);
|
||||||
|
|
||||||
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
|
PopCitusSecurityContext();
|
||||||
|
|
||||||
placementId = DatumGetInt64(placementIdDatum);
|
placementId = DatumGetInt64(placementIdDatum);
|
||||||
|
|
||||||
|
|
|
@ -28,6 +28,7 @@
|
||||||
#include "distributed/multi_logical_planner.h"
|
#include "distributed/multi_logical_planner.h"
|
||||||
#include "distributed/pg_dist_colocation.h"
|
#include "distributed/pg_dist_colocation.h"
|
||||||
#include "distributed/resource_lock.h"
|
#include "distributed/resource_lock.h"
|
||||||
|
#include "distributed/security_utils.h"
|
||||||
#include "distributed/shardinterval_utils.h"
|
#include "distributed/shardinterval_utils.h"
|
||||||
#include "distributed/version_compat.h"
|
#include "distributed/version_compat.h"
|
||||||
#include "distributed/worker_protocol.h"
|
#include "distributed/worker_protocol.h"
|
||||||
|
@ -614,16 +615,13 @@ GetNextColocationId()
|
||||||
text *sequenceName = cstring_to_text(COLOCATIONID_SEQUENCE_NAME);
|
text *sequenceName = cstring_to_text(COLOCATIONID_SEQUENCE_NAME);
|
||||||
Oid sequenceId = ResolveRelationId(sequenceName, false);
|
Oid sequenceId = ResolveRelationId(sequenceName, false);
|
||||||
Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
|
Datum sequenceIdDatum = ObjectIdGetDatum(sequenceId);
|
||||||
Oid savedUserId = InvalidOid;
|
|
||||||
int savedSecurityContext = 0;
|
|
||||||
|
|
||||||
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
|
PushCitusSecurityContext();
|
||||||
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
|
|
||||||
|
|
||||||
/* generate new and unique colocation id from sequence */
|
/* generate new and unique colocation id from sequence */
|
||||||
Datum colocationIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);
|
Datum colocationIdDatum = DirectFunctionCall1(nextval_oid, sequenceIdDatum);
|
||||||
|
|
||||||
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
|
PopCitusSecurityContext();
|
||||||
|
|
||||||
uint32 colocationId = DatumGetUInt32(colocationIdDatum);
|
uint32 colocationId = DatumGetUInt32(colocationIdDatum);
|
||||||
|
|
||||||
|
|
|
@ -41,6 +41,7 @@
|
||||||
#include "distributed/relay_utility.h"
|
#include "distributed/relay_utility.h"
|
||||||
#include "distributed/remote_commands.h"
|
#include "distributed/remote_commands.h"
|
||||||
#include "distributed/resource_lock.h"
|
#include "distributed/resource_lock.h"
|
||||||
|
#include "distributed/security_utils.h"
|
||||||
|
|
||||||
#include "distributed/worker_protocol.h"
|
#include "distributed/worker_protocol.h"
|
||||||
#include "distributed/version_compat.h"
|
#include "distributed/version_compat.h"
|
||||||
|
@ -594,9 +595,6 @@ worker_append_table_to_shard(PG_FUNCTION_ARGS)
|
||||||
char *sourceSchemaName = NULL;
|
char *sourceSchemaName = NULL;
|
||||||
char *sourceTableName = NULL;
|
char *sourceTableName = NULL;
|
||||||
|
|
||||||
Oid savedUserId = InvalidOid;
|
|
||||||
int savedSecurityContext = 0;
|
|
||||||
|
|
||||||
CheckCitusVersion(ERROR);
|
CheckCitusVersion(ERROR);
|
||||||
|
|
||||||
/* We extract schema names and table names from qualified names */
|
/* We extract schema names and table names from qualified names */
|
||||||
|
@ -665,13 +663,12 @@ worker_append_table_to_shard(PG_FUNCTION_ARGS)
|
||||||
CheckCopyPermissions(localCopyCommand);
|
CheckCopyPermissions(localCopyCommand);
|
||||||
|
|
||||||
/* need superuser to copy from files */
|
/* need superuser to copy from files */
|
||||||
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
|
PushCitusSecurityContext();
|
||||||
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
|
|
||||||
|
|
||||||
CitusProcessUtility((Node *) localCopyCommand, queryString->data,
|
CitusProcessUtility((Node *) localCopyCommand, queryString->data,
|
||||||
PROCESS_UTILITY_TOPLEVEL, NULL, None_Receiver, NULL);
|
PROCESS_UTILITY_TOPLEVEL, NULL, None_Receiver, NULL);
|
||||||
|
|
||||||
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
|
PopCitusSecurityContext();
|
||||||
|
|
||||||
/* finally delete the temporary file we created */
|
/* finally delete the temporary file we created */
|
||||||
CitusDeleteFile(localFilePath->data);
|
CitusDeleteFile(localFilePath->data);
|
||||||
|
|
|
@ -32,6 +32,7 @@
|
||||||
#include "commands/tablecmds.h"
|
#include "commands/tablecmds.h"
|
||||||
#include "common/string.h"
|
#include "common/string.h"
|
||||||
#include "distributed/metadata_cache.h"
|
#include "distributed/metadata_cache.h"
|
||||||
|
#include "distributed/security_utils.h"
|
||||||
#include "distributed/worker_protocol.h"
|
#include "distributed/worker_protocol.h"
|
||||||
#include "distributed/version_compat.h"
|
#include "distributed/version_compat.h"
|
||||||
|
|
||||||
|
@ -183,8 +184,6 @@ worker_merge_files_into_table(PG_FUNCTION_ARGS)
|
||||||
StringInfo jobSchemaName = JobSchemaName(jobId);
|
StringInfo jobSchemaName = JobSchemaName(jobId);
|
||||||
StringInfo taskTableName = TaskTableName(taskId);
|
StringInfo taskTableName = TaskTableName(taskId);
|
||||||
StringInfo taskDirectoryName = TaskDirectoryName(jobId, taskId);
|
StringInfo taskDirectoryName = TaskDirectoryName(jobId, taskId);
|
||||||
Oid savedUserId = InvalidOid;
|
|
||||||
int savedSecurityContext = 0;
|
|
||||||
Oid userId = GetUserId();
|
Oid userId = GetUserId();
|
||||||
|
|
||||||
/* we should have the same number of column names and types */
|
/* we should have the same number of column names and types */
|
||||||
|
@ -234,13 +233,11 @@ worker_merge_files_into_table(PG_FUNCTION_ARGS)
|
||||||
CreateTaskTable(jobSchemaName, taskTableName, columnNameList, columnTypeList);
|
CreateTaskTable(jobSchemaName, taskTableName, columnNameList, columnTypeList);
|
||||||
|
|
||||||
/* need superuser to copy from files */
|
/* need superuser to copy from files */
|
||||||
GetUserIdAndSecContext(&savedUserId, &savedSecurityContext);
|
PushCitusSecurityContext();
|
||||||
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
|
|
||||||
|
|
||||||
CopyTaskFilesFromDirectory(jobSchemaName, taskTableName, taskDirectoryName,
|
CopyTaskFilesFromDirectory(jobSchemaName, taskTableName, taskDirectoryName,
|
||||||
userId);
|
userId);
|
||||||
|
|
||||||
SetUserIdAndSecContext(savedUserId, savedSecurityContext);
|
PopCitusSecurityContext();
|
||||||
PG_RETURN_VOID();
|
PG_RETURN_VOID();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,26 @@
|
||||||
|
/*-------------------------------------------------------------------------
|
||||||
|
*
|
||||||
|
* security_utils.h
|
||||||
|
* security related utility functions.
|
||||||
|
*
|
||||||
|
* Copyright (c) Citus Data, Inc.
|
||||||
|
*
|
||||||
|
*-------------------------------------------------------------------------
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef SECURITY_UTILS_H_
|
||||||
|
#define SECURITY_UTILS_H_
|
||||||
|
|
||||||
|
#include "postgres.h"
|
||||||
|
#include "miscadmin.h"
|
||||||
|
|
||||||
|
#define PushCitusSecurityContext() \
|
||||||
|
Oid savedUserId_DONTUSE = InvalidOid; \
|
||||||
|
int savedSecurityContext_DONTUSE = 0; \
|
||||||
|
GetUserIdAndSecContext(&savedUserId_DONTUSE, &savedSecurityContext_DONTUSE); \
|
||||||
|
SetUserIdAndSecContext(CitusExtensionOwner(), SECURITY_LOCAL_USERID_CHANGE);
|
||||||
|
|
||||||
|
#define PopCitusSecurityContext() \
|
||||||
|
SetUserIdAndSecContext(savedUserId_DONTUSE, savedSecurityContext_DONTUSE);
|
||||||
|
|
||||||
|
#endif
|
Loading…
Reference in New Issue