diff --git a/src/backend/distributed/commands/seclabel.c b/src/backend/distributed/commands/seclabel.c index 205342184..b520f6fac 100644 --- a/src/backend/distributed/commands/seclabel.c +++ b/src/backend/distributed/commands/seclabel.c @@ -28,7 +28,7 @@ */ List * PreprocessSecLabelStmt(Node *node, const char *queryString, - ProcessUtilityContext processUtilityContext) + ProcessUtilityContext processUtilityContext) { if (!IsCoordinator() || !ShouldPropagate()) { @@ -37,8 +37,8 @@ PreprocessSecLabelStmt(Node *node, const char *queryString, SecLabelStmt *secLabelStmt = castNode(SecLabelStmt, node); - List * objectAddresses = GetObjectAddressListFromParseTree(node, false, false); - if(!IsAnyObjectDistributed(objectAddresses)) + List *objectAddresses = GetObjectAddressListFromParseTree(node, false, false); + if (!IsAnyObjectDistributed(objectAddresses)) { return NIL; } @@ -70,6 +70,7 @@ PreprocessSecLabelStmt(Node *node, const char *queryString, return NodeDDLTaskList(NON_COORDINATOR_NODES, commandList); } + /* * PostprocessSecLabelStmt */ @@ -88,8 +89,8 @@ PostprocessSecLabelStmt(Node *node, const char *queryString) return NIL; } - List * objectAddresses = GetObjectAddressListFromParseTree(node, false, false); - if(IsAnyObjectDistributed(objectAddresses)) + List *objectAddresses = GetObjectAddressListFromParseTree(node, false, false); + if (IsAnyObjectDistributed(objectAddresses)) { EnsureAllObjectDependenciesExistOnAllNodes(objectAddresses); } @@ -107,10 +108,30 @@ SecLabelStmtObjectAddress(Node *node, bool missing_ok, bool isPostprocess) SecLabelStmt *secLabelStmt = castNode(SecLabelStmt, node); Relation rel = NULL; /* not used, but required to pass to get_object_address */ - ObjectAddress address = get_object_address(secLabelStmt->objtype, secLabelStmt->object, &rel, - AccessShareLock, missing_ok); + ObjectAddress address = get_object_address(secLabelStmt->objtype, + secLabelStmt->object, &rel, + AccessShareLock, missing_ok); ObjectAddress *addressPtr = palloc0(sizeof(ObjectAddress)); *addressPtr = address; return list_make1(addressPtr); } + + +/* + * citus_test_object_relabel + */ +void +citus_test_object_relabel(const ObjectAddress *object, const char *seclabel) +{ + if (seclabel == NULL || + strcmp(seclabel, "citus_unclassified") == 0 || + strcmp(seclabel, "citus_classified") == 0) + { + return; + } + + ereport(ERROR, + (errcode(ERRCODE_INVALID_NAME), + errmsg("'%s' is not a valid security label for Citus tests.", seclabel))); +} diff --git a/src/backend/distributed/deparser/deparse_seclabel_stmts.c b/src/backend/distributed/deparser/deparse_seclabel_stmts.c index 9119af677..d98f3fb4f 100644 --- a/src/backend/distributed/deparser/deparse_seclabel_stmts.c +++ b/src/backend/distributed/deparser/deparse_seclabel_stmts.c @@ -12,6 +12,7 @@ #include "distributed/deparser.h" #include "nodes/parsenodes.h" +#include "utils/builtins.h" static void AppendSecLabelStmt(StringInfo buf, SecLabelStmt *stmt); @@ -41,10 +42,10 @@ AppendSecLabelStmt(StringInfo buf, SecLabelStmt *stmt) { appendStringInfoString(buf, "SECURITY LABEL "); - if (stmt->provider != NULL) - { - appendStringInfo(buf, "FOR %s ", stmt->provider); - } + if (stmt->provider != NULL) + { + appendStringInfo(buf, "FOR %s ", stmt->provider); + } appendStringInfoString(buf, "ON "); @@ -52,7 +53,7 @@ AppendSecLabelStmt(StringInfo buf, SecLabelStmt *stmt) { case OBJECT_ROLE: { - appendStringInfo(buf, "ROLE %s ", strVal(castNode(String, stmt->object))); + appendStringInfo(buf, "ROLE %s ", strVal(stmt->object)); break; } @@ -67,7 +68,7 @@ AppendSecLabelStmt(StringInfo buf, SecLabelStmt *stmt) if (stmt->label != NULL) { - appendStringInfo(buf, "%s", stmt->label); + appendStringInfo(buf, "%s", quote_literal_cstr(stmt->label)); } else { diff --git a/src/backend/distributed/shared_library_init.c b/src/backend/distributed/shared_library_init.c index 9b5768ee7..6052c7870 100644 --- a/src/backend/distributed/shared_library_init.c +++ b/src/backend/distributed/shared_library_init.c @@ -29,6 +29,7 @@ #include "citus_version.h" #include "commands/explain.h" #include "commands/extension.h" +#include "commands/seclabel.h" #include "common/string.h" #include "executor/executor.h" #include "distributed/backend_data.h" @@ -574,6 +575,8 @@ _PG_init(void) INIT_COLUMNAR_SYMBOL(PGFunction, columnar_storage_info); INIT_COLUMNAR_SYMBOL(PGFunction, columnar_store_memory_stats); INIT_COLUMNAR_SYMBOL(PGFunction, test_columnar_storage_write_new_page); + + register_label_provider("citus_tests_label_provider", citus_test_object_relabel); } diff --git a/src/include/distributed/commands.h b/src/include/distributed/commands.h index 55f8f1867..5b5ae87da 100644 --- a/src/include/distributed/commands.h +++ b/src/include/distributed/commands.h @@ -526,6 +526,7 @@ extern List * PreprocessSecLabelStmt(Node *node, const char *queryString, ProcessUtilityContext processUtilityContext); extern List * PostprocessSecLabelStmt(Node *node, const char *queryString); extern List * SecLabelStmtObjectAddress(Node *node, bool missing_ok, bool isPostprocess); +extern void citus_test_object_relabel(const ObjectAddress *object, const char *seclabel); /* sequence.c - forward declarations */ extern List * PreprocessAlterSequenceStmt(Node *node, const char *queryString,