diff --git a/src/backend/distributed/planner/relation_restriction_equivalence.c b/src/backend/distributed/planner/relation_restriction_equivalence.c index 5a63503f0..39452c53e 100644 --- a/src/backend/distributed/planner/relation_restriction_equivalence.c +++ b/src/backend/distributed/planner/relation_restriction_equivalence.c @@ -83,6 +83,16 @@ typedef struct AttributeEquivalenceClassMember AttrNumber varattno; } AttributeEquivalenceClassMember; +/* + * ECGroupByExpr + * Helper structure to group EquivalenceClasses by their non-Var expressions. + */ +typedef struct ECGroupByExpr +{ + Node *strippedExpr; /* The canonical non-Var expression (stripped) */ + List *ecsWithThisExpr; /* List of EquivalenceClass* sharing this expr */ + List *varsInTheseECs; /* Cached list of Var* from all these ECs */ +} ECGroupByExpr; static bool ContextContainsLocalRelation(RelationRestrictionContext *restrictionContext); static bool ContextContainsAppendRelation(RelationRestrictionContext *restrictionContext); @@ -93,6 +103,8 @@ static bool ContainsMultipleDistributedRelations(PlannerRestrictionContext * plannerRestrictionContext); static List * GenerateAttributeEquivalencesForRelationRestrictions( RelationRestrictionContext *restrictionContext); +static List * MergeEquivalenceClassesWithSameFunctions( + RelationRestrictionContext *restrictionContext); static AttributeEquivalenceClass * AttributeEquivalenceClassForEquivalenceClass( EquivalenceClass *plannerEqClass, RelationRestriction *relationRestriction); static void AddToAttributeEquivalenceClass(AttributeEquivalenceClass * @@ -826,16 +838,34 @@ GenerateAttributeEquivalencesForRelationRestrictions(RelationRestrictionContext { List *attributeEquivalenceList = NIL; ListCell *relationRestrictionCell = NULL; + bool foundRLSPattern = false; if (restrictionContext == NULL) { return attributeEquivalenceList; } + /* + * First pass: Process equivalence classes using the original algorithm. 
+ * This builds the standard attribute equivalence list. + * + * Skip RLS pattern detection entirely if the query doesn't + * use Row Level Security. The hasRowSecurity flag is checked from the query's + * parse tree when any table has RLS policies active. This allows us to skip + * both the pattern detection loop AND the expensive merge pass for non-RLS + * queries (common case). + * + * For RLS queries, detect patterns efficiently. We only need + * to find one EC with both Var + non-Var members to justify the merge pass. + * Once found, skip further pattern checks and focus on building equivalences. + */ + bool skipRLSProcessing = true; foreach(relationRestrictionCell, restrictionContext->relationRestrictionList) { RelationRestriction *relationRestriction = (RelationRestriction *) lfirst(relationRestrictionCell); + + skipRLSProcessing = !relationRestriction->plannerInfo->parse->hasRowSecurity; List *equivalenceClasses = relationRestriction->plannerInfo->eq_classes; ListCell *equivalenceClassCell = NULL; @@ -844,6 +874,44 @@ GenerateAttributeEquivalencesForRelationRestrictions(RelationRestrictionContext EquivalenceClass *plannerEqClass = (EquivalenceClass *) lfirst(equivalenceClassCell); + /* + * RLS pattern = EC with both Var and non-Var (function) members. + * Finding even one such pattern means we need the merge pass. + */ + if (!skipRLSProcessing && !foundRLSPattern) + { + bool hasVar = false; + bool hasNonVar = false; + ListCell *memberCell = NULL; + + foreach(memberCell, plannerEqClass->ec_members) + { + EquivalenceMember *member = (EquivalenceMember *) lfirst(memberCell); + Node *expr = strip_implicit_coercions((Node *) member->em_expr); + + if (IsA(expr, Var)) + { + hasVar = true; + } + else if (member->em_is_const && + !IsA(expr, Param) && !IsA(expr, Const)) + { + /* + * Found a pseudoconstant expression (no Vars) that's not a + * Param or Const - this is the RLS function pattern. 
+						 */
+						hasNonVar = true;
+					}
+
+					/* Early exit: If we've found both, we have the pattern */
+					if (hasVar && hasNonVar)
+					{
+						foundRLSPattern = true;
+						break;
+					}
+				}
+			}
+
 			AttributeEquivalenceClass *attributeEquivalence =
 				AttributeEquivalenceClassForEquivalenceClass(plannerEqClass,
 															 relationRestriction);
@@ -854,10 +922,262 @@ GenerateAttributeEquivalencesForRelationRestrictions(RelationRestrictionContext
 		}
 	}
 
+	/*
+	 * Second pass: Handle RLS-specific case where PostgreSQL splits join conditions
+	 * across multiple EquivalenceClasses due to function calls in RLS policies.
+	 *
+	 * When RLS policies compare columns to a function (e.g., current_setting()), PostgreSQL
+	 * creates separate EquivalenceClasses that both contain the same function call:
+	 *   EC1: [table_a.tenant_id, current_setting(...)]
+	 *   EC2: [table_b.tenant_id, current_setting(...)]
+	 *
+	 * We need to recognize that these should be merged to detect that tables are
+	 * joined on their distribution columns: [table_a.tenant_id, table_b.tenant_id]
+	 */
+	if (foundRLSPattern)
+	{
+		List *rlsMergedList = MergeEquivalenceClassesWithSameFunctions(
+			restrictionContext);
+
+		/* Append any newly created merged classes to the original list */
+		attributeEquivalenceList = list_concat(attributeEquivalenceList, rlsMergedList);
+	}
+
 	return attributeEquivalenceList;
 }
 
 
+/*
+ * RelationRestrictionForEquivalenceClassVar returns the RelationRestriction that a
+ * Var member of the given EquivalenceClass refers to, or NULL when the Var does
+ * not point at a relation tracked in restrictionContext.
+ *
+ * A varno is only meaningful relative to the range table of the PlannerInfo the
+ * EquivalenceClass belongs to, and restrictions that come from different
+ * PlannerInfos may carry the same index. A restriction therefore only matches
+ * when its index equals the Var's varno and its plannerInfo lists the
+ * EquivalenceClass in eq_classes.
+ */
+static RelationRestriction *
+RelationRestrictionForEquivalenceClassVar(RelationRestrictionContext *restrictionContext,
+										  EquivalenceClass *equivalenceClass, Var *var)
+{
+	ListCell *restrictionCell = NULL;
+
+	foreach(restrictionCell, restrictionContext->relationRestrictionList)
+	{
+		RelationRestriction *candidate =
+			(RelationRestriction *) lfirst(restrictionCell);
+
+		if (var->varno != candidate->index)
+		{
+			continue;
+		}
+
+		/* an equal index under another PlannerInfo is a different range table entry */
+		if (list_member_ptr(candidate->plannerInfo->eq_classes, equivalenceClass))
+		{
+			return candidate;
+		}
+	}
+
+	return NULL;
+}
+
+
+/*
+ * MergeEquivalenceClassesWithSameFunctions builds additional attribute
+ * equivalences for the RLS case in which PostgreSQL keeps what is logically one
+ * join condition in several EquivalenceClasses:
+ *
+ *   EC1: [table_a.col, func(...)]
+ *   EC2: [table_b.col, func(...)]
+ *
+ * Members are compared with equal() after strip_implicit_coercions(); casts that
+ * are not implicit stay part of the compared expression. Only pseudoconstant
+ * members (em_is_const) that are neither Param nor Const are considered, e.g.
+ * current_setting('session.current_tenant_id')::uuid.
+ *
+ * When two or more EquivalenceClasses share such an expression, their Var
+ * members are all equal to that expression and hence to each other. For every
+ * such group a new AttributeEquivalenceClass holding those Vars is created,
+ * e.g. [table_a.col, table_b.col].
+ *
+ * Returns the list of newly created AttributeEquivalenceClasses, or NIL when
+ * restrictionContext is NULL or nothing could be merged. Only classes that end
+ * up with at least two attributes are returned.
+ */
+static List *
+MergeEquivalenceClassesWithSameFunctions(RelationRestrictionContext *restrictionContext)
+{
+	List *newlyMergedClasses = NIL;
+	List *ecGroupList = NIL; /* List of ECGroupByExpr* */
+	ListCell *relationRestrictionCell = NULL;
+
+	if (restrictionContext == NULL)
+	{
+		return NIL;
+	}
+
+	/*
+	 * Phase 1: group the candidate EquivalenceClasses by their non-Var
+	 * expressions. An EquivalenceClass is a candidate when it has at least one
+	 * Var member and at least one qualifying non-Var member.
+	 */
+	foreach(relationRestrictionCell, restrictionContext->relationRestrictionList)
+	{
+		RelationRestriction *relationRestriction =
+			(RelationRestriction *) lfirst(relationRestrictionCell);
+		List *equivalenceClasses = relationRestriction->plannerInfo->eq_classes;
+		ListCell *equivalenceClassCell = NULL;
+
+		foreach(equivalenceClassCell, equivalenceClasses)
+		{
+			EquivalenceClass *ec = (EquivalenceClass *) lfirst(equivalenceClassCell);
+			bool hasVar = false;
+			List *nonVarExprs = NIL;
+			ListCell *memberCell = NULL;
+
+			foreach(memberCell, ec->ec_members)
+			{
+				EquivalenceMember *member = (EquivalenceMember *) lfirst(memberCell);
+				Node *expr = strip_implicit_coercions((Node *) member->em_expr);
+
+				if (IsA(expr, Var))
+				{
+					hasVar = true;
+				}
+				else if (member->em_is_const && !IsA(expr, Param) && !IsA(expr, Const))
+				{
+					nonVarExprs = lappend(nonVarExprs, expr);
+				}
+			}
+
+			if (!hasVar || nonVarExprs == NIL)
+			{
+				continue;
+			}
+
+			/* find or create the group of every non-Var expression of this EC */
+			ListCell *exprCell = NULL;
+			foreach(exprCell, nonVarExprs)
+			{
+				Node *strippedExpr = (Node *) lfirst(exprCell);
+				ECGroupByExpr *matchingGroup = NULL;
+				ListCell *existingGroupCell = NULL;
+
+				foreach(existingGroupCell, ecGroupList)
+				{
+					ECGroupByExpr *group = (ECGroupByExpr *) lfirst(existingGroupCell);
+
+					if (equal(group->strippedExpr, strippedExpr))
+					{
+						matchingGroup = group;
+						break;
+					}
+				}
+
+				if (matchingGroup == NULL)
+				{
+					/* palloc0 leaves both lists of the new group empty (NIL) */
+					matchingGroup = palloc0(sizeof(ECGroupByExpr));
+					matchingGroup->strippedExpr = strippedExpr;
+					ecGroupList = lappend(ecGroupList, matchingGroup);
+				}
+
+				/*
+				 * The same EC is visited once per restriction that shares its
+				 * PlannerInfo, so add it to the group only once.
+				 */
+				if (!list_member_ptr(matchingGroup->ecsWithThisExpr, ec))
+				{
+					matchingGroup->ecsWithThisExpr =
+						lappend(matchingGroup->ecsWithThisExpr, ec);
+				}
+			}
+		}
+	}
+
+	/*
+	 * Phase 2: every group with two or more ECs yields one merged
+	 * AttributeEquivalenceClass made of the Vars of those ECs.
+	 */
+	ListCell *groupCell = NULL;
+	foreach(groupCell, ecGroupList)
+	{
+		ECGroupByExpr *group = (ECGroupByExpr *) lfirst(groupCell);
+		List *varOwnerECs = NIL; /* owning EC of each entry in varsInTheseECs */
+		ListCell *ecCell = NULL;
+
+		/* a single EC has nothing to be merged with */
+		if (list_length(group->ecsWithThisExpr) < 2)
+		{
+			continue;
+		}
+
+		/* collect the Vars and remember which EC each of them came from */
+		foreach(ecCell, group->ecsWithThisExpr)
+		{
+			EquivalenceClass *ec = (EquivalenceClass *) lfirst(ecCell);
+			ListCell *memberCell = NULL;
+
+			foreach(memberCell, ec->ec_members)
+			{
+				EquivalenceMember *member = (EquivalenceMember *) lfirst(memberCell);
+				Node *expr = strip_implicit_coercions((Node *) member->em_expr);
+
+				if (IsA(expr, Var))
+				{
+					group->varsInTheseECs = lappend(group->varsInTheseECs, expr);
+					varOwnerECs = lappend(varOwnerECs, ec);
+				}
+			}
+		}
+
+		/* a join needs at least two Vars */
+		if (list_length(group->varsInTheseECs) < 2)
+		{
+			continue;
+		}
+
+		AttributeEquivalenceClass *mergedClass =
+			palloc0(sizeof(AttributeEquivalenceClass));
+		mergedClass->equivalenceId = AttributeEquivalenceId++;
+
+		ListCell *varCell = NULL;
+		ListCell *ownerCell = NULL;
+		forboth(varCell, group->varsInTheseECs, ownerCell, varOwnerECs)
+		{
+			Var *var = (Var *) lfirst(varCell);
+			EquivalenceClass *ownerEC = (EquivalenceClass *) lfirst(ownerCell);
+			RelationRestriction *relRestriction =
+				RelationRestrictionForEquivalenceClassVar(restrictionContext,
+														  ownerEC, var);
+
+			if (relRestriction == NULL)
+			{
+				/* not a relation tracked in the restriction list; leave it out */
+				elog(DEBUG2, "Skipping Var with varno=%d in RLS merge - "
+					 "no matching RelationRestriction found", var->varno);
+				continue;
+			}
+
+			/* resolve the Var against the PlannerInfo it was planned with */
+			AddToAttributeEquivalenceClass(mergedClass,
+										   relRestriction->plannerInfo, var);
+		}
+
+		/* only emit classes that actually relate two or more attributes */
+		if (list_length(mergedClass->equivalentAttributes) >= 2)
+		{
+			newlyMergedClasses = lappend(newlyMergedClasses, mergedClass);
+		}
+	}
+
+	return newlyMergedClasses;
+}
+
+
 /*
  * AttributeEquivalenceClassForEquivalenceClass is a helper function for
  * GenerateAttributeEquivalencesForRelationRestrictions. The function takes an
diff --git a/src/test/regress/enterprise_schedule b/src/test/regress/enterprise_schedule index 9a832c4d6..265040f36 100644 --- a/src/test/regress/enterprise_schedule +++ b/src/test/regress/enterprise_schedule @@ -34,6 +34,7 @@ test: multi_multiuser_auth test: multi_poolinfo_usage test: multi_alter_table_row_level_security test: multi_alter_table_row_level_security_escape +test: multi_rls_join_distribution_key test: stat_statements test: shard_move_constraints test: shard_move_constraints_blocking diff --git a/src/test/regress/expected/multi_rls_join_distribution_key.out b/src/test/regress/expected/multi_rls_join_distribution_key.out new file mode 100644 index 000000000..6d645e8a3 --- /dev/null +++ b/src/test/regress/expected/multi_rls_join_distribution_key.out @@ -0,0 +1,292 @@ +-- +-- MULTI_RLS_JOIN_DISTRIBUTION_KEY +-- +-- Test that RLS policies with volatile functions don't prevent +-- Citus from recognizing joins on distribution columns. +-- +SET citus.next_shard_id TO 1900000; +SET citus.shard_replication_factor TO 1; +-- Create test user if not exists +DO $$ +BEGIN + IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = 'app_user') THEN + CREATE USER app_user; + END IF; +END +$$; +-- Create schema and grant privileges +CREATE SCHEMA IF NOT EXISTS rls_join_test; +GRANT ALL PRIVILEGES ON SCHEMA rls_join_test TO app_user; +SET search_path TO rls_join_test; +-- Create and distribute tables +CREATE TABLE table_a (tenant_id uuid, id int); +CREATE TABLE table_b (tenant_id uuid, id int); +CREATE TABLE table_c (tenant_id uuid, id int); +CREATE TABLE table_d (tenant_id uuid, id int); +SELECT create_distributed_table('table_a', 'tenant_id'); + create_distributed_table +--------------------------------------------------------------------- + +(1 row) + +SELECT create_distributed_table('table_b', 'tenant_id', colocate_with => 'table_a'); + create_distributed_table +--------------------------------------------------------------------- + +(1
row) + +SELECT create_distributed_table('table_c', 'tenant_id', colocate_with => 'table_a'); + create_distributed_table +--------------------------------------------------------------------- + +(1 row) + +SELECT create_distributed_table('table_d', 'tenant_id', colocate_with => 'table_a'); + create_distributed_table +--------------------------------------------------------------------- + +(1 row) + +-- Grant privileges on tables +GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA rls_join_test TO app_user; +-- Insert test data +INSERT INTO table_a VALUES + ('0194d116-5dd5-74af-be74-7f5e8468eeb7', 1), + ('0194d116-5dd5-74af-be74-7f5e8468eeb8', 2); +INSERT INTO table_b VALUES + ('0194d116-5dd5-74af-be74-7f5e8468eeb7', 10), + ('0194d116-5dd5-74af-be74-7f5e8468eeb8', 20); +INSERT INTO table_c VALUES + ('0194d116-5dd5-74af-be74-7f5e8468eeb7', 100), + ('0194d116-5dd5-74af-be74-7f5e8468eeb8', 200); +INSERT INTO table_d VALUES + ('0194d116-5dd5-74af-be74-7f5e8468eeb7', 1000), + ('0194d116-5dd5-74af-be74-7f5e8468eeb8', 2000); +-- Enable RLS and create policies on multiple tables +ALTER TABLE table_a ENABLE ROW LEVEL SECURITY; +CREATE POLICY tenant_isolation_a ON table_a TO app_user + USING (tenant_id = current_setting('session.current_tenant_id')::UUID); +ALTER TABLE table_c ENABLE ROW LEVEL SECURITY; +CREATE POLICY tenant_isolation_c ON table_c TO app_user + USING (tenant_id = current_setting('session.current_tenant_id')::UUID); +-- Test scenario that previously failed +-- Switch to app_user and execute the query with RLS +SET ROLE app_user; +SET citus.propagate_set_commands = local; +SET application_name = '0194d116-5dd5-74af-be74-7f5e8468eeb7'; +BEGIN; +-- Set session.current_tenant_id from application_name +DO $$ +DECLARE +BEGIN + EXECUTE 'SET LOCAL session.current_tenant_id = ' || quote_literal(current_setting('application_name', true)); +END; +$$; +-- Simple 2-way join (original test case) +SELECT a.id, b.id +FROM table_a AS a +LEFT OUTER JOIN table_b AS b ON a.tenant_id = 
b.tenant_id +ORDER BY a.id, b.id; + id | id +--------------------------------------------------------------------- + 1 | 10 +(1 row) + +SELECT a.id, b.id +FROM table_a AS a +RIGHT OUTER JOIN table_b AS b ON a.tenant_id = b.tenant_id +ORDER BY a.id, b.id; + id | id +--------------------------------------------------------------------- + 1 | 10 + | 20 +(2 rows) + +-- 3-way join with RLS on multiple tables +SELECT a.id, b.id, c.id +FROM table_a AS a +LEFT OUTER JOIN table_b AS b ON a.tenant_id = b.tenant_id +JOIN table_c AS c ON b.tenant_id = c.tenant_id +ORDER BY a.id, b.id, c.id; + id | id | id +--------------------------------------------------------------------- + 1 | 10 | 100 +(1 row) + +SELECT a.id, b.id, c.id +FROM table_a AS a +LEFT OUTER JOIN table_b AS b ON a.tenant_id = b.tenant_id +LEFT OUTER JOIN table_c AS c ON b.tenant_id = c.tenant_id +ORDER BY a.id, b.id, c.id; + id | id | id +--------------------------------------------------------------------- + 1 | 10 | 100 +(1 row) + +SELECT a.id, b.id, d.id +FROM table_a AS a +LEFT OUTER JOIN table_b AS b ON a.tenant_id = b.tenant_id +JOIN table_d AS d ON b.tenant_id = d.tenant_id +ORDER BY a.id, b.id, d.id; + id | id | id +--------------------------------------------------------------------- + 1 | 10 | 1000 +(1 row) + +SELECT a.id, b.id, d.id +FROM table_a AS a +LEFT OUTER JOIN table_b AS b ON a.tenant_id = b.tenant_id +LEFT OUTER JOIN table_d AS d ON b.tenant_id = d.tenant_id +ORDER BY a.id, b.id, d.id; + id | id | id +--------------------------------------------------------------------- + 1 | 10 | 1000 +(1 row) + +SELECT a.id, b.id, d.id +FROM table_a AS a +RIGHT JOIN table_b AS b ON a.tenant_id = b.tenant_id +RIGHT OUTER JOIN table_d AS d ON b.tenant_id = d.tenant_id +ORDER BY a.id, b.id, d.id; + id | id | id +--------------------------------------------------------------------- + 1 | 10 | 1000 + | 20 | 2000 +(2 rows) + +SELECT a.id, b.id, d.id +FROM table_a AS a +RIGHT JOIN table_b AS b ON a.tenant_id = 
b.tenant_id +LEFT OUTER JOIN table_d AS d ON b.tenant_id = d.tenant_id +ORDER BY a.id, b.id, d.id; + id | id | id +--------------------------------------------------------------------- + 1 | 10 | 1000 + | 20 | 2000 +(2 rows) + +SELECT a.id, c.id, d.id +FROM table_a AS a +LEFT OUTER JOIN table_c AS c ON a.tenant_id = c.tenant_id +JOIN table_d AS d ON c.tenant_id = d.tenant_id +ORDER BY a.id, c.id, d.id; + id | id | id +--------------------------------------------------------------------- + 1 | 100 | 1000 +(1 row) + +SELECT a.id, c.id, d.id +FROM table_a AS a +LEFT OUTER JOIN table_c AS c ON a.tenant_id = c.tenant_id +LEFT OUTER JOIN table_d AS d ON c.tenant_id = d.tenant_id +ORDER BY a.id, c.id, d.id; + id | id | id +--------------------------------------------------------------------- + 1 | 100 | 1000 +(1 row) + +-- 4-way join with different join types +SELECT a.id, b.id, c.id, d.id +FROM table_a AS a +INNER JOIN table_b AS b ON a.tenant_id = b.tenant_id +LEFT JOIN table_c AS c ON b.tenant_id = c.tenant_id +INNER JOIN table_d AS d ON c.tenant_id = d.tenant_id +ORDER BY a.id, b.id, c.id, d.id; + id | id | id | id +--------------------------------------------------------------------- + 1 | 10 | 100 | 1000 +(1 row) + +SELECT a.id, b.id, c.id, d.id +FROM table_a AS a +INNER JOIN table_b AS b ON a.tenant_id = b.tenant_id +LEFT JOIN table_c AS c ON b.tenant_id = c.tenant_id +RIGHT JOIN table_d AS d ON c.tenant_id = d.tenant_id +ORDER BY a.id, b.id, c.id, d.id; + id | id | id | id +--------------------------------------------------------------------- + 1 | 10 | 100 | 1000 + | | | 2000 +(2 rows) + +SELECT a.id, b.id, c.id, d.id +FROM table_a AS a +LEFT OUTER JOIN table_b AS b ON a.tenant_id = b.tenant_id +LEFT OUTER JOIN table_c AS c ON b.tenant_id = c.tenant_id +LEFT OUTER JOIN table_d AS d ON c.tenant_id = d.tenant_id +ORDER BY a.id, b.id, c.id, d.id; + id | id | id | id +--------------------------------------------------------------------- + 1 | 10 | 100 | 1000 +(1 row) 
+ +SELECT a.id, b.id, c.id, d.id +FROM table_a AS a +LEFT OUTER JOIN table_b AS b ON a.tenant_id = b.tenant_id +RIGHT OUTER JOIN table_c AS c ON b.tenant_id = c.tenant_id +LEFT OUTER JOIN table_d AS d ON c.tenant_id = d.tenant_id +ORDER BY a.id, b.id, c.id, d.id; + id | id | id | id +--------------------------------------------------------------------- + 1 | 10 | 100 | 1000 +(1 row) + +-- IN subquery that can be transformed to semi-join +SELECT a.id +FROM table_a a +WHERE a.tenant_id IN ( + SELECT b.tenant_id + FROM table_b b + JOIN table_c c USING (tenant_id) +) +ORDER BY a.id; + id +--------------------------------------------------------------------- + 1 +(1 row) + +SELECT a.id +FROM table_a a +WHERE a.tenant_id IN ( + SELECT b.tenant_id + FROM table_b b + LEFT OUTER JOIN table_c c USING (tenant_id) +) +ORDER BY a.id; + id +--------------------------------------------------------------------- + 1 +(1 row) + +-- Another multi-way join variation +SELECT a.id, b.id, c.id +FROM table_a AS a +INNER JOIN table_b AS b ON a.tenant_id = b.tenant_id +INNER JOIN table_c AS c ON a.tenant_id = c.tenant_id +ORDER BY a.id, b.id, c.id; + id | id | id +--------------------------------------------------------------------- + 1 | 10 | 100 +(1 row) + +SELECT a.id, b.id, c.id +FROM table_a AS a +LEFT OUTER JOIN table_b AS b ON a.tenant_id = b.tenant_id +LEFT OUTER JOIN table_c AS c ON a.tenant_id = c.tenant_id +ORDER BY a.id, b.id, c.id; + id | id | id +--------------------------------------------------------------------- + 1 | 10 | 100 +(1 row) + +ROLLBACK; +-- Switch back to superuser for cleanup +RESET ROLE; +-- Cleanup: Drop schema and all objects +DROP SCHEMA rls_join_test CASCADE; +NOTICE: drop cascades to 4 other objects +DETAIL: drop cascades to table table_a +drop cascades to table table_b +drop cascades to table table_c +drop cascades to table table_d +DROP USER IF EXISTS app_user; diff --git a/src/test/regress/sql/multi_rls_join_distribution_key.sql 
b/src/test/regress/sql/multi_rls_join_distribution_key.sql new file mode 100644 index 000000000..66f6afd1b --- /dev/null +++ b/src/test/regress/sql/multi_rls_join_distribution_key.sql @@ -0,0 +1,211 @@ +-- +-- MULTI_RLS_JOIN_DISTRIBUTION_KEY +-- +-- Test that RLS policies with volatile functions don't prevent +-- Citus from recognizing joins on distribution columns. +-- + +SET citus.next_shard_id TO 1900000; +SET citus.shard_replication_factor TO 1; + +-- Create test user if not exists +DO $$ +BEGIN + IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = 'app_user') THEN + CREATE USER app_user; + END IF; +END +$$; + +-- Create schema and grant privileges +CREATE SCHEMA IF NOT EXISTS rls_join_test; +GRANT ALL PRIVILEGES ON SCHEMA rls_join_test TO app_user; + +SET search_path TO rls_join_test; + +-- Create and distribute tables +CREATE TABLE table_a (tenant_id uuid, id int); +CREATE TABLE table_b (tenant_id uuid, id int); +CREATE TABLE table_c (tenant_id uuid, id int); +CREATE TABLE table_d (tenant_id uuid, id int); + +SELECT create_distributed_table('table_a', 'tenant_id'); +SELECT create_distributed_table('table_b', 'tenant_id', colocate_with => 'table_a'); +SELECT create_distributed_table('table_c', 'tenant_id', colocate_with => 'table_a'); +SELECT create_distributed_table('table_d', 'tenant_id', colocate_with => 'table_a'); + +-- Grant privileges on tables +GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA rls_join_test TO app_user; + +-- Insert test data +INSERT INTO table_a VALUES + ('0194d116-5dd5-74af-be74-7f5e8468eeb7', 1), + ('0194d116-5dd5-74af-be74-7f5e8468eeb8', 2); + +INSERT INTO table_b VALUES + ('0194d116-5dd5-74af-be74-7f5e8468eeb7', 10), + ('0194d116-5dd5-74af-be74-7f5e8468eeb8', 20); + +INSERT INTO table_c VALUES + ('0194d116-5dd5-74af-be74-7f5e8468eeb7', 100), + ('0194d116-5dd5-74af-be74-7f5e8468eeb8', 200); + +INSERT INTO table_d VALUES + ('0194d116-5dd5-74af-be74-7f5e8468eeb7', 1000), + ('0194d116-5dd5-74af-be74-7f5e8468eeb8', 2000); + +-- 
Enable RLS and create policies on multiple tables +ALTER TABLE table_a ENABLE ROW LEVEL SECURITY; +CREATE POLICY tenant_isolation_a ON table_a TO app_user + USING (tenant_id = current_setting('session.current_tenant_id')::UUID); + +ALTER TABLE table_c ENABLE ROW LEVEL SECURITY; +CREATE POLICY tenant_isolation_c ON table_c TO app_user + USING (tenant_id = current_setting('session.current_tenant_id')::UUID); + +-- Test scenario that previously failed +-- Switch to app_user and execute the query with RLS +SET ROLE app_user; + +SET citus.propagate_set_commands = local; +SET application_name = '0194d116-5dd5-74af-be74-7f5e8468eeb7'; + +BEGIN; +-- Set session.current_tenant_id from application_name +DO $$ +DECLARE +BEGIN + EXECUTE 'SET LOCAL session.current_tenant_id = ' || quote_literal(current_setting('application_name', true)); +END; +$$; + +-- Simple 2-way join (original test case) +SELECT a.id, b.id +FROM table_a AS a +LEFT OUTER JOIN table_b AS b ON a.tenant_id = b.tenant_id +ORDER BY a.id, b.id; + +SELECT a.id, b.id +FROM table_a AS a +RIGHT OUTER JOIN table_b AS b ON a.tenant_id = b.tenant_id +ORDER BY a.id, b.id; + +-- 3-way join with RLS on multiple tables +SELECT a.id, b.id, c.id +FROM table_a AS a +LEFT OUTER JOIN table_b AS b ON a.tenant_id = b.tenant_id +JOIN table_c AS c ON b.tenant_id = c.tenant_id +ORDER BY a.id, b.id, c.id; + +SELECT a.id, b.id, c.id +FROM table_a AS a +LEFT OUTER JOIN table_b AS b ON a.tenant_id = b.tenant_id +LEFT OUTER JOIN table_c AS c ON b.tenant_id = c.tenant_id +ORDER BY a.id, b.id, c.id; + +SELECT a.id, b.id, d.id +FROM table_a AS a +LEFT OUTER JOIN table_b AS b ON a.tenant_id = b.tenant_id +JOIN table_d AS d ON b.tenant_id = d.tenant_id +ORDER BY a.id, b.id, d.id; + +SELECT a.id, b.id, d.id +FROM table_a AS a +LEFT OUTER JOIN table_b AS b ON a.tenant_id = b.tenant_id +LEFT OUTER JOIN table_d AS d ON b.tenant_id = d.tenant_id +ORDER BY a.id, b.id, d.id; + +SELECT a.id, b.id, d.id +FROM table_a AS a +RIGHT JOIN table_b AS b ON 
a.tenant_id = b.tenant_id +RIGHT OUTER JOIN table_d AS d ON b.tenant_id = d.tenant_id +ORDER BY a.id, b.id, d.id; + +SELECT a.id, b.id, d.id +FROM table_a AS a +RIGHT JOIN table_b AS b ON a.tenant_id = b.tenant_id +LEFT OUTER JOIN table_d AS d ON b.tenant_id = d.tenant_id +ORDER BY a.id, b.id, d.id; + +SELECT a.id, c.id, d.id +FROM table_a AS a +LEFT OUTER JOIN table_c AS c ON a.tenant_id = c.tenant_id +JOIN table_d AS d ON c.tenant_id = d.tenant_id +ORDER BY a.id, c.id, d.id; + +SELECT a.id, c.id, d.id +FROM table_a AS a +LEFT OUTER JOIN table_c AS c ON a.tenant_id = c.tenant_id +LEFT OUTER JOIN table_d AS d ON c.tenant_id = d.tenant_id +ORDER BY a.id, c.id, d.id; + +-- 4-way join with different join types +SELECT a.id, b.id, c.id, d.id +FROM table_a AS a +INNER JOIN table_b AS b ON a.tenant_id = b.tenant_id +LEFT JOIN table_c AS c ON b.tenant_id = c.tenant_id +INNER JOIN table_d AS d ON c.tenant_id = d.tenant_id +ORDER BY a.id, b.id, c.id, d.id; + +SELECT a.id, b.id, c.id, d.id +FROM table_a AS a +INNER JOIN table_b AS b ON a.tenant_id = b.tenant_id +LEFT JOIN table_c AS c ON b.tenant_id = c.tenant_id +RIGHT JOIN table_d AS d ON c.tenant_id = d.tenant_id +ORDER BY a.id, b.id, c.id, d.id; + +SELECT a.id, b.id, c.id, d.id +FROM table_a AS a +LEFT OUTER JOIN table_b AS b ON a.tenant_id = b.tenant_id +LEFT OUTER JOIN table_c AS c ON b.tenant_id = c.tenant_id +LEFT OUTER JOIN table_d AS d ON c.tenant_id = d.tenant_id +ORDER BY a.id, b.id, c.id, d.id; + +SELECT a.id, b.id, c.id, d.id +FROM table_a AS a +LEFT OUTER JOIN table_b AS b ON a.tenant_id = b.tenant_id +RIGHT OUTER JOIN table_c AS c ON b.tenant_id = c.tenant_id +LEFT OUTER JOIN table_d AS d ON c.tenant_id = d.tenant_id +ORDER BY a.id, b.id, c.id, d.id; + +-- IN subquery that can be transformed to semi-join +SELECT a.id +FROM table_a a +WHERE a.tenant_id IN ( + SELECT b.tenant_id + FROM table_b b + JOIN table_c c USING (tenant_id) +) +ORDER BY a.id; + +SELECT a.id +FROM table_a a +WHERE a.tenant_id IN ( + SELECT 
b.tenant_id + FROM table_b b + LEFT OUTER JOIN table_c c USING (tenant_id) +) +ORDER BY a.id; + +-- Another multi-way join variation +SELECT a.id, b.id, c.id +FROM table_a AS a +INNER JOIN table_b AS b ON a.tenant_id = b.tenant_id +INNER JOIN table_c AS c ON a.tenant_id = c.tenant_id +ORDER BY a.id, b.id, c.id; + +SELECT a.id, b.id, c.id +FROM table_a AS a +LEFT OUTER JOIN table_b AS b ON a.tenant_id = b.tenant_id +LEFT OUTER JOIN table_c AS c ON a.tenant_id = c.tenant_id +ORDER BY a.id, b.id, c.id; + +ROLLBACK; + +-- Switch back to superuser for cleanup +RESET ROLE; + +-- Cleanup: Drop schema and all objects +DROP SCHEMA rls_join_test CASCADE; + +DROP USER IF EXISTS app_user;