diff --git a/src/backend/distributed/planner/distributed_planner.c b/src/backend/distributed/planner/distributed_planner.c index be046bf9b..99026972a 100644 --- a/src/backend/distributed/planner/distributed_planner.c +++ b/src/backend/distributed/planner/distributed_planner.c @@ -249,6 +249,13 @@ distributed_planner(Query *parse, planContext.plannerRestrictionContext = CreateAndPushPlannerRestrictionContext( &fastPathContext); + /* + * Set RLS flag from the query. This is used to optimize equivalence class + * processing by skipping expensive RLS-specific merging for non-RLS queries. + */ + planContext.plannerRestrictionContext->relationRestrictionContext->hasRowSecurity = + parse->hasRowSecurity; + /* * We keep track of how many times we've recursed into the planner, primarily * to detect whether we are in a function call. We need to make sure that the @@ -2448,6 +2455,9 @@ CreateAndPushPlannerRestrictionContext( /* we'll apply logical AND as we add tables */ plannerRestrictionContext->relationRestrictionContext->allReferenceTables = true; + /* hasRowSecurity will be set later once we have the Query object */ + plannerRestrictionContext->relationRestrictionContext->hasRowSecurity = false; + plannerRestrictionContextList = lcons(plannerRestrictionContext, plannerRestrictionContextList); @@ -2535,6 +2545,9 @@ ResetPlannerRestrictionContext(PlannerRestrictionContext *plannerRestrictionCont /* we'll apply logical AND as we add tables */ plannerRestrictionContext->relationRestrictionContext->allReferenceTables = true; + + /* hasRowSecurity defaults to false, will be set by caller if needed */ + plannerRestrictionContext->relationRestrictionContext->hasRowSecurity = false; } diff --git a/src/backend/distributed/planner/multi_router_planner.c b/src/backend/distributed/planner/multi_router_planner.c index 14ce199c8..84eb72825 100644 --- a/src/backend/distributed/planner/multi_router_planner.c +++ b/src/backend/distributed/planner/multi_router_planner.c @@ -4200,6 +4200,7 
@@ CopyRelationRestrictionContext(RelationRestrictionContext *oldContext) ListCell *relationRestrictionCell = NULL; newContext->allReferenceTables = oldContext->allReferenceTables; + newContext->hasRowSecurity = oldContext->hasRowSecurity; newContext->relationRestrictionList = NIL; foreach(relationRestrictionCell, oldContext->relationRestrictionList) diff --git a/src/backend/distributed/planner/relation_restriction_equivalence.c b/src/backend/distributed/planner/relation_restriction_equivalence.c index 5a63503f0..0b475fbef 100644 --- a/src/backend/distributed/planner/relation_restriction_equivalence.c +++ b/src/backend/distributed/planner/relation_restriction_equivalence.c @@ -83,6 +83,16 @@ typedef struct AttributeEquivalenceClassMember AttrNumber varattno; } AttributeEquivalenceClassMember; +/* + * ECGroupByExpr + * Helper structure to group EquivalenceClasses by their non-Var expressions. + */ +typedef struct ECGroupByExpr +{ + Node *strippedExpr; /* The canonical non-Var expression (stripped) */ + List *ecsWithThisExpr; /* List of EquivalenceClass* sharing this expr */ + List *varsInTheseECs; /* Cached list of Var* from all these ECs */ +} ECGroupByExpr; static bool ContextContainsLocalRelation(RelationRestrictionContext *restrictionContext); static bool ContextContainsAppendRelation(RelationRestrictionContext *restrictionContext); @@ -93,6 +103,8 @@ static bool ContainsMultipleDistributedRelations(PlannerRestrictionContext * plannerRestrictionContext); static List * GenerateAttributeEquivalencesForRelationRestrictions( RelationRestrictionContext *restrictionContext); +static List * MergeEquivalenceClassesWithSameFunctions( + RelationRestrictionContext *restrictionContext); static AttributeEquivalenceClass * AttributeEquivalenceClassForEquivalenceClass( EquivalenceClass *plannerEqClass, RelationRestriction *relationRestriction); static void AddToAttributeEquivalenceClass(AttributeEquivalenceClass * @@ -826,12 +838,27 @@ 
GenerateAttributeEquivalencesForRelationRestrictions(RelationRestrictionContext { List *attributeEquivalenceList = NIL; ListCell *relationRestrictionCell = NULL; + bool foundRLSPattern = false; if (restrictionContext == NULL) { return attributeEquivalenceList; } + /* + * First pass: Process equivalence classes using the original algorithm. + * This builds the standard attribute equivalence list. + * + * Skip RLS pattern detection entirely if the query doesn't + * use Row Level Security. The hasRowSecurity flag is set during query planning + * when any table has RLS policies active. This allows us to skip both the pattern + * detection loop AND the expensive merge pass for non-RLS queries (common case). + * + * For RLS queries, detect patterns efficiently. We only need + * to find one EC with both Var + non-Var members to justify the merge pass. + * Once found, skip further pattern checks and focus on building equivalences. + */ + bool skipRLSProcessing = !restrictionContext->hasRowSecurity; foreach(relationRestrictionCell, restrictionContext->relationRestrictionList) { RelationRestriction *relationRestriction = @@ -844,6 +871,39 @@ GenerateAttributeEquivalencesForRelationRestrictions(RelationRestrictionContext EquivalenceClass *plannerEqClass = (EquivalenceClass *) lfirst(equivalenceClassCell); + /* + * RLS pattern = EC with both Var and non-Var (function) members. + * Finding even one such pattern means we need the merge pass. 
+ */ + if (!skipRLSProcessing && !foundRLSPattern) + { + bool hasVar = false; + bool hasNonVar = false; + ListCell *memberCell = NULL; + + foreach(memberCell, plannerEqClass->ec_members) + { + EquivalenceMember *member = (EquivalenceMember *) lfirst(memberCell); + Node *expr = strip_implicit_coercions((Node *) member->em_expr); + + if (IsA(expr, Var)) + { + hasVar = true; + } + else if (!IsA(expr, Param) && !IsA(expr, Const)) + { + hasNonVar = true; + } + + /* Early exit: If we've found both, we have the pattern */ + if (hasVar && hasNonVar) + { + foundRLSPattern = true; + break; + } + } + } + AttributeEquivalenceClass *attributeEquivalence = AttributeEquivalenceClassForEquivalenceClass(plannerEqClass, relationRestriction); @@ -854,10 +914,259 @@ GenerateAttributeEquivalencesForRelationRestrictions(RelationRestrictionContext } } + /* + * Second pass: Handle RLS-specific case where PostgreSQL splits join conditions + * across multiple EquivalenceClasses due to volatile functions in RLS policies. 
+ * + * When RLS policies use volatile functions (e.g., current_setting()), PostgreSQL + * creates separate EquivalenceClasses that both contain the same volatile function: + * EC1: [table_a.tenant_id, current_setting(...)] + * EC2: [table_b.tenant_id, current_setting(...)] + * + * We need to recognize that these should be merged to detect that tables are + * joined on their distribution columns: [table_a.tenant_id, table_b.tenant_id] + */ + if (foundRLSPattern) + { + List *rlsMergedList = MergeEquivalenceClassesWithSameFunctions( + restrictionContext); + + /* Append any newly created merged classes to the original list */ + attributeEquivalenceList = list_concat(attributeEquivalenceList, rlsMergedList); + } + return attributeEquivalenceList; } +/* + * MergeEquivalenceClassesWithSameFunctions scans equivalence classes + * looking for RLS-specific patterns where volatile functions cause PostgreSQL to + * split what should be a single join condition across multiple EquivalenceClasses. + * + * This function specifically targets the pattern: + * EC1: [table_a.col, COERCEVIAIO(func(...))] + * EC2: [table_b.col, COERCEVIAIO(func(...))] + * + * Where the underlying function calls are identical after stripping implicit coercions + * (e.g., both resolve to current_setting('session.current_tenant_id')). + * + * PostgreSQL wraps RLS policy expressions in COERCEVIAIO nodes to handle type + * conversions (e.g., text → UUID). We strip these to compare the actual function calls. + * + * Returns a list of newly created merged AttributeEquivalenceClasses. Each merged + * class contains the Var members from pairs of EquivalenceClasses that share identical + * non-Var expressions. For example, if EC1 contains [table_a.tenant_id, func()] and + * EC2 contains [table_b.tenant_id, func()], the returned list will include a new + * AttributeEquivalenceClass with [table_a.tenant_id, table_b.tenant_id]. Only classes + * with 2+ members are returned (indicating an actual join between tables). 
+ */ +static List * +MergeEquivalenceClassesWithSameFunctions(RelationRestrictionContext *restrictionContext) +{ + List *newlyMergedClasses = NIL; + List *ecGroupList = NIL; /* List of ECGroupByExpr* */ + ListCell *relationRestrictionCell = NULL; + + if (restrictionContext == NULL) + { + return NIL; + } + + /* + * Phase 1: Collect candidate ECs and group them by their non-Var expressions. + * + * Strategy: For each EC, extract and strip all non-Var expressions, then + * find or create a group for each unique expression. This gives us direct + * access to all ECs sharing the same expression. + */ + foreach(relationRestrictionCell, restrictionContext->relationRestrictionList) + { + RelationRestriction *relationRestriction = + (RelationRestriction *) lfirst(relationRestrictionCell); + List *equivalenceClasses = relationRestriction->plannerInfo->eq_classes; + ListCell *equivalenceClassCell = NULL; + + foreach(equivalenceClassCell, equivalenceClasses) + { + EquivalenceClass *ec = (EquivalenceClass *) lfirst(equivalenceClassCell); + bool hasVar = false; + List *nonVarExprs = NIL; + ListCell *memberCell = NULL; + + /* + * Single pass through EC members: note whether any Var is present and + * collect the coercion-stripped non-Var expressions for grouping. + */ + foreach(memberCell, ec->ec_members) + { + EquivalenceMember *member = (EquivalenceMember *) lfirst(memberCell); + Node *expr = strip_implicit_coercions((Node *) member->em_expr); + + if (IsA(expr, Var)) + { + hasVar = true; + } + else if (!IsA(expr, Param) && !IsA(expr, Const)) + { + /* + * Found a non-Var expression (potential RLS function). + * After stripping, this is typically a FUNCEXPR like + * current_setting('session.current_tenant_id'). + */ + nonVarExprs = lappend(nonVarExprs, expr); + } + } + + /* Only process ECs with both Var and non-Var members (RLS pattern) */ + if (!hasVar || nonVarExprs == NIL) + { + continue; + } + + /* + * For each non-Var expression in this EC, find or create a group. 
+ * Multiple ECs with the same expression will be grouped together. + */ + ListCell *exprCell = NULL; + foreach(exprCell, nonVarExprs) + { + Node *strippedExpr = (Node *) lfirst(exprCell); + ECGroupByExpr *matchingGroup = NULL; + ListCell *groupCell = NULL; + + /* Search for an existing group with this expression */ + foreach(groupCell, ecGroupList) + { + ECGroupByExpr *group = (ECGroupByExpr *) lfirst(groupCell); + + if (equal(group->strippedExpr, strippedExpr)) + { + matchingGroup = group; + break; + } + } + + /* Create a new group if this is the first EC with this expression */ + if (matchingGroup == NULL) + { + matchingGroup = palloc0(sizeof(ECGroupByExpr)); + matchingGroup->strippedExpr = strippedExpr; + matchingGroup->ecsWithThisExpr = NIL; + matchingGroup->varsInTheseECs = NIL; + ecGroupList = lappend(ecGroupList, matchingGroup); + } + + /* Add this EC to the group (avoid duplicates) */ + if (!list_member_ptr(matchingGroup->ecsWithThisExpr, ec)) + { + matchingGroup->ecsWithThisExpr = + lappend(matchingGroup->ecsWithThisExpr, ec); + } + } + } + } + + /* + * Phase 2: For each group with 2+ ECs, extract all Vars and create a merged + * AttributeEquivalenceClass. This is where we detect the join pattern. + * + * The idea is that if multiple ECs share the same non-Var expression (e.g., RLS + * function), then all Vars in those ECs are implicitly equal to each other. + */ + ListCell *groupCell = NULL; + foreach(groupCell, ecGroupList) + { + ECGroupByExpr *group = (ECGroupByExpr *) lfirst(groupCell); + + /* Skip groups with only one EC - no join to detect */ + if (list_length(group->ecsWithThisExpr) < 2) + { + continue; + } + + /* + * Extract all Vars from all ECs in this group. + * These Vars are implicitly equal via the shared expression. 
+ */ + ListCell *ecCell = NULL; + foreach(ecCell, group->ecsWithThisExpr) + { + EquivalenceClass *ec = (EquivalenceClass *) lfirst(ecCell); + ListCell *memberCell = NULL; + + foreach(memberCell, ec->ec_members) + { + EquivalenceMember *member = (EquivalenceMember *) lfirst(memberCell); + Node *expr = strip_implicit_coercions((Node *) member->em_expr); + + if (IsA(expr, Var)) + { + /* Cache this Var for later processing */ + group->varsInTheseECs = lappend(group->varsInTheseECs, expr); + } + } + } + + /* Fewer than 2 Vars cannot describe a join; skip this group */ + if (list_length(group->varsInTheseECs) < 2) + { + continue; + } + + /* + * Create the merged AttributeEquivalenceClass. + */ + AttributeEquivalenceClass *mergedClass = + palloc0(sizeof(AttributeEquivalenceClass)); + mergedClass->equivalenceId = AttributeEquivalenceId++; + mergedClass->equivalentAttributes = NIL; + + /* + * Resolve each cached Var to a PlannerInfo and add it to the merged class. + * The lookup is a linear scan of the restriction list for every Var. + */ + ListCell *varCell = NULL; + foreach(varCell, group->varsInTheseECs) + { + Var *var = (Var *) lfirst(varCell); + ListCell *relResCell = NULL; + + /* + * Use the first RelationRestriction whose PlannerInfo has an RTE at this + * varno. NOTE(review): may not be the planner the Var's EC came from; confirm. + */ + foreach(relResCell, restrictionContext->relationRestrictionList) + { + RelationRestriction *relRestriction = + (RelationRestriction *) lfirst(relResCell); + PlannerInfo *root = relRestriction->plannerInfo; + + /* Check if this planner's range table has an entry at this varno */ + if (var->varno < root->simple_rel_array_size && + root->simple_rte_array[var->varno] != NULL) + { + /* + * Process this Var through AddToAttributeEquivalenceClass. + * This handles subqueries, UNION ALL, LATERAL joins, etc. 
+ */ + AddToAttributeEquivalenceClass(mergedClass, root, var); + break; /* Found the right planner, move to next Var */ + } + } + } + + /* Only emit if we successfully merged attributes from multiple sources */ + if (list_length(mergedClass->equivalentAttributes) >= 2) + { + newlyMergedClasses = lappend(newlyMergedClasses, mergedClass); + } + } + + return newlyMergedClasses; +} + + /* * AttributeEquivalenceClassForEquivalenceClass is a helper function for * GenerateAttributeEquivalencesForRelationRestrictions. The function takes an @@ -2395,6 +2704,10 @@ FilterRelationRestrictionContext(RelationRestrictionContext *relationRestriction RelationRestrictionContext *filteredRestrictionContext = palloc0(sizeof(RelationRestrictionContext)); + /* Preserve RLS flag from the original context */ + filteredRestrictionContext->hasRowSecurity = + relationRestrictionContext->hasRowSecurity; + ListCell *relationRestrictionCell = NULL; foreach(relationRestrictionCell, relationRestrictionContext->relationRestrictionList) diff --git a/src/include/distributed/distributed_planner.h b/src/include/distributed/distributed_planner.h index 67637cd78..e36651317 100644 --- a/src/include/distributed/distributed_planner.h +++ b/src/include/distributed/distributed_planner.h @@ -48,6 +48,13 @@ typedef enum RouterPlanType typedef struct RelationRestrictionContext { bool allReferenceTables; + + /* + * Set to true when any table in the query + * has Row Level Security policies active. 
+ */ + bool hasRowSecurity; + List *relationRestrictionList; } RelationRestrictionContext; diff --git a/src/test/regress/enterprise_schedule b/src/test/regress/enterprise_schedule index 9a832c4d6..265040f36 100644 --- a/src/test/regress/enterprise_schedule +++ b/src/test/regress/enterprise_schedule @@ -34,6 +34,7 @@ test: multi_multiuser_auth test: multi_poolinfo_usage test: multi_alter_table_row_level_security test: multi_alter_table_row_level_security_escape +test: multi_rls_join_distribution_key test: stat_statements test: shard_move_constraints test: shard_move_constraints_blocking diff --git a/src/test/regress/expected/multi_rls_join_distribution_key.out b/src/test/regress/expected/multi_rls_join_distribution_key.out new file mode 100644 index 000000000..60255488f --- /dev/null +++ b/src/test/regress/expected/multi_rls_join_distribution_key.out @@ -0,0 +1,106 @@ +-- +-- MULTI_RLS_JOIN_DISTRIBUTION_KEY +-- +-- Test that RLS policies with volatile functions don't prevent +-- Citus from recognizing joins on distribution columns. +-- +-- This addresses GitHub issue #7969 where RLS policies using +-- current_setting() caused PostgreSQL to split equivalence classes, +-- preventing Citus from detecting proper distribution column joins. 
+SET citus.next_shard_id TO 1900000; +SET citus.shard_replication_factor TO 1; +-- Create test user if not exists +DO $$ +BEGIN + IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = 'app_user') THEN + CREATE USER app_user; + END IF; +END +$$; +-- Create schema and grant privileges +CREATE SCHEMA IF NOT EXISTS rls_join_test; +GRANT ALL PRIVILEGES ON SCHEMA rls_join_test TO app_user; +SET search_path TO rls_join_test; +-- Create and distribute tables +CREATE TABLE table_a (tenant_id uuid, id int); +CREATE TABLE table_b (tenant_id uuid, id int); +SELECT create_distributed_table('table_a', 'tenant_id'); + create_distributed_table +--------------------------------------------------------------------- + +(1 row) + +SELECT create_distributed_table('table_b', 'tenant_id', colocate_with => 'table_a'); + create_distributed_table +--------------------------------------------------------------------- + +(1 row) + +-- Grant privileges on tables +GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA rls_join_test TO app_user; +-- Insert test data +INSERT INTO table_a VALUES + ('0194d116-5dd5-74af-be74-7f5e8468eeb7', 1), + ('0194d116-5dd5-74af-be74-7f5e8468eeb8', 2); +INSERT INTO table_b VALUES + ('0194d116-5dd5-74af-be74-7f5e8468eeb7', 10), + ('0194d116-5dd5-74af-be74-7f5e8468eeb8', 20); +-- Enable RLS and create policy +ALTER TABLE table_a ENABLE ROW LEVEL SECURITY; +CREATE POLICY tenant_isolation_0 ON table_a TO app_user + USING (tenant_id = current_setting('session.current_tenant_id')::UUID); +-- Test scenario that previously failed +-- Switch to app_user and execute the query with RLS +SET ROLE app_user; +SET citus.propagate_set_commands = local; +SET application_name = '0194d116-5dd5-74af-be74-7f5e8468eeb7'; +BEGIN; +-- Set session.current_tenant_id from application_name +DO $$ +DECLARE +BEGIN + EXECUTE 'SET LOCAL session.current_tenant_id = ' || quote_literal(current_setting('application_name', true)); +END; +$$; +-- This query should work with RLS enabled +-- Before the 
fix, this would fail with: +-- "complex joins are only supported when all distributed tables are +-- co-located and joined on their distribution columns" +EXPLAIN (COSTS OFF) +SELECT c.id, t.id +FROM table_a AS c +LEFT OUTER JOIN table_b AS t ON c.tenant_id = t.tenant_id; + QUERY PLAN +--------------------------------------------------------------------- + Custom Scan (Citus Adaptive) + Task Count: 4 + Tasks Shown: One of 4 + -> Task + Node: host=localhost port=xxxxx dbname=regression + -> Nested Loop Left Join + -> Seq Scan on table_a_1900000 c + Filter: (tenant_id = (current_setting('session.current_tenant_id'::text))::uuid) + -> Materialize + -> Seq Scan on table_b_1900004 t + Filter: (tenant_id = (current_setting('session.current_tenant_id'::text))::uuid) +(11 rows) + +SELECT c.id, t.id +FROM table_a AS c +LEFT OUTER JOIN table_b AS t ON c.tenant_id = t.tenant_id +ORDER BY c.id, t.id; + id | id +--------------------------------------------------------------------- + 1 | 10 +(1 row) + +ROLLBACK; +-- Switch back to superuser for cleanup +RESET ROLE; +-- Cleanup: Drop schema and all objects +DROP SCHEMA rls_join_test CASCADE; +NOTICE: drop cascades to 2 other objects +DETAIL: drop cascades to table table_a +drop cascades to table table_b +-- Drop user (suppress error if it has dependencies) +DROP USER IF EXISTS app_user; diff --git a/src/test/regress/sql/multi_rls_join_distribution_key.sql b/src/test/regress/sql/multi_rls_join_distribution_key.sql new file mode 100644 index 000000000..31fbba971 --- /dev/null +++ b/src/test/regress/sql/multi_rls_join_distribution_key.sql @@ -0,0 +1,92 @@ +-- +-- MULTI_RLS_JOIN_DISTRIBUTION_KEY +-- +-- Test that RLS policies with volatile functions don't prevent +-- Citus from recognizing joins on distribution columns. +-- +-- This addresses GitHub issue #7969 where RLS policies using +-- current_setting() caused PostgreSQL to split equivalence classes, +-- preventing Citus from detecting proper distribution column joins. 
+ +SET citus.next_shard_id TO 1900000; +SET citus.shard_replication_factor TO 1; + +-- Create test user if not exists +DO $$ +BEGIN + IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = 'app_user') THEN + CREATE USER app_user; + END IF; +END +$$; + +-- Create schema and grant privileges +CREATE SCHEMA IF NOT EXISTS rls_join_test; +GRANT ALL PRIVILEGES ON SCHEMA rls_join_test TO app_user; + +SET search_path TO rls_join_test; + +-- Create and distribute tables +CREATE TABLE table_a (tenant_id uuid, id int); +CREATE TABLE table_b (tenant_id uuid, id int); + +SELECT create_distributed_table('table_a', 'tenant_id'); +SELECT create_distributed_table('table_b', 'tenant_id', colocate_with => 'table_a'); + +-- Grant privileges on tables +GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA rls_join_test TO app_user; + +-- Insert test data +INSERT INTO table_a VALUES + ('0194d116-5dd5-74af-be74-7f5e8468eeb7', 1), + ('0194d116-5dd5-74af-be74-7f5e8468eeb8', 2); + +INSERT INTO table_b VALUES + ('0194d116-5dd5-74af-be74-7f5e8468eeb7', 10), + ('0194d116-5dd5-74af-be74-7f5e8468eeb8', 20); + +-- Enable RLS and create policy +ALTER TABLE table_a ENABLE ROW LEVEL SECURITY; +CREATE POLICY tenant_isolation_0 ON table_a TO app_user + USING (tenant_id = current_setting('session.current_tenant_id')::UUID); + +-- Test scenario that previously failed +-- Switch to app_user and execute the query with RLS +SET ROLE app_user; + +SET citus.propagate_set_commands = local; +SET application_name = '0194d116-5dd5-74af-be74-7f5e8468eeb7'; + +BEGIN; +-- Set session.current_tenant_id from application_name +DO $$ +DECLARE +BEGIN + EXECUTE 'SET LOCAL session.current_tenant_id = ' || quote_literal(current_setting('application_name', true)); +END; +$$; + +-- This query should work with RLS enabled +-- Before the fix, this would fail with: +-- "complex joins are only supported when all distributed tables are +-- co-located and joined on their distribution columns" +EXPLAIN (COSTS OFF) +SELECT c.id, t.id 
+FROM table_a AS c +LEFT OUTER JOIN table_b AS t ON c.tenant_id = t.tenant_id; + +SELECT c.id, t.id +FROM table_a AS c +LEFT OUTER JOIN table_b AS t ON c.tenant_id = t.tenant_id +ORDER BY c.id, t.id; + +ROLLBACK; + +-- Switch back to superuser for cleanup +RESET ROLE; + +-- Cleanup: Drop schema and all objects +DROP SCHEMA rls_join_test CASCADE; + +-- Drop user (suppress error if it has dependencies) +DROP USER IF EXISTS app_user;