Detect distribution column joins when RLS policies split equivalence classes

When a table with Row Level Security (RLS) policies is joined with another table
on a column protected by an RLS policy using volatile functions
(e.g., current_setting()), PostgreSQL's planner sometimes splits what would
normally be a single equivalence class into separate equivalence
classes for each table.

Problem:
For a query like:

    SELECT c.id, t.id
    FROM table_a AS c
    LEFT OUTER JOIN table_b AS t ON c.tenant_id = t.tenant_id;

When table_a has an RLS policy on tenant_id using a volatile function,
PostgreSQL creates separate equivalence classes:

EC1: [table_a.tenant_id, current_setting(...)]
EC2: [table_b.tenant_id, current_setting(...)]

Instead of a single equivalence class that would contain both columns:

EC: [table_a.tenant_id, table_b.tenant_id, current_setting(...)]

This prevents Citus from recognizing that the tables are joined on their
distribution columns, causing the query to fail with: "complex joins are
only supported when all distributed tables are co-located and joined on
their distribution columns".

The solution is to detect when multiple equivalence classes share identical
non-Var expressions (volatile RLS functions) and to merge the Var members of
those classes, reconstructing the implicit join condition.
This allows Citus to recognize that table_a.tenant_id = table_b.tenant_id
for query planning purposes.

This commit also adds a regression test, multi_rls_join_distribution_key,
which verifies that joins on distribution columns are detected under RLS.

Fixes: #7969
muusama/7969
Muhammad Usama 2025-11-23 16:44:47 +03:00
parent 662b7248db
commit 42e34fa870
7 changed files with 533 additions and 0 deletions

View File

@ -249,6 +249,13 @@ distributed_planner(Query *parse,
planContext.plannerRestrictionContext = CreateAndPushPlannerRestrictionContext(
&fastPathContext);
/*
* Set RLS flag from the query. This is used to optimize equivalence class
* processing by skipping expensive RLS-specific merging for non-RLS queries.
*/
planContext.plannerRestrictionContext->relationRestrictionContext->hasRowSecurity =
parse->hasRowSecurity;
/*
* We keep track of how many times we've recursed into the planner, primarily
* to detect whether we are in a function call. We need to make sure that the
@ -2448,6 +2455,9 @@ CreateAndPushPlannerRestrictionContext(
/* we'll apply logical AND as we add tables */
plannerRestrictionContext->relationRestrictionContext->allReferenceTables = true;
/* hasRowSecurity will be set later once we have the Query object */
plannerRestrictionContext->relationRestrictionContext->hasRowSecurity = false;
plannerRestrictionContextList = lcons(plannerRestrictionContext,
plannerRestrictionContextList);
@ -2535,6 +2545,9 @@ ResetPlannerRestrictionContext(PlannerRestrictionContext *plannerRestrictionCont
/* we'll apply logical AND as we add tables */
plannerRestrictionContext->relationRestrictionContext->allReferenceTables = true;
/* hasRowSecurity defaults to false, will be set by caller if needed */
plannerRestrictionContext->relationRestrictionContext->hasRowSecurity = false;
}

View File

@ -4200,6 +4200,7 @@ CopyRelationRestrictionContext(RelationRestrictionContext *oldContext)
ListCell *relationRestrictionCell = NULL;
newContext->allReferenceTables = oldContext->allReferenceTables;
newContext->hasRowSecurity = oldContext->hasRowSecurity;
newContext->relationRestrictionList = NIL;
foreach(relationRestrictionCell, oldContext->relationRestrictionList)

View File

@ -83,6 +83,16 @@ typedef struct AttributeEquivalenceClassMember
AttrNumber varattno;
} AttributeEquivalenceClassMember;
/*
* ECGroupByExpr
* Helper structure to group EquivalenceClasses by their non-Var expressions.
*
* One instance exists per distinct expression (compared with equal()) found
* among EquivalenceClass members that are neither Var, Param nor Const after
* strip_implicit_coercions(). It is used only while merging equivalence
* classes that share such an expression.
*/
typedef struct ECGroupByExpr
{
Node *strippedExpr; /* The shared non-Var expression, after strip_implicit_coercions() */
List *ecsWithThisExpr; /* List of EquivalenceClass* that have strippedExpr as a member (no duplicates) */
List *varsInTheseECs; /* List of Var* members collected from all ECs in ecsWithThisExpr */
} ECGroupByExpr;
static bool ContextContainsLocalRelation(RelationRestrictionContext *restrictionContext);
static bool ContextContainsAppendRelation(RelationRestrictionContext *restrictionContext);
@ -93,6 +103,8 @@ static bool ContainsMultipleDistributedRelations(PlannerRestrictionContext *
plannerRestrictionContext);
static List * GenerateAttributeEquivalencesForRelationRestrictions(
RelationRestrictionContext *restrictionContext);
static List * MergeEquivalenceClassesWithSameFunctions(
RelationRestrictionContext *restrictionContext);
static AttributeEquivalenceClass * AttributeEquivalenceClassForEquivalenceClass(
EquivalenceClass *plannerEqClass, RelationRestriction *relationRestriction);
static void AddToAttributeEquivalenceClass(AttributeEquivalenceClass *
@ -826,12 +838,27 @@ GenerateAttributeEquivalencesForRelationRestrictions(RelationRestrictionContext
{
List *attributeEquivalenceList = NIL;
ListCell *relationRestrictionCell = NULL;
bool foundRLSPattern = false;
if (restrictionContext == NULL)
{
return attributeEquivalenceList;
}
/*
* First pass: Process equivalence classes using the original algorithm.
* This builds the standard attribute equivalence list.
*
* Skip RLS pattern detection entirely if the query doesn't
* use Row Level Security. The hasRowSecurity flag is set during query planning
* when any table has RLS policies active. This allows us to skip both the pattern
* detection loop AND the expensive merge pass for non-RLS queries (common case).
*
* For RLS queries, detect patterns efficiently. We only need
* to find one EC with both Var + non-Var members to justify the merge pass.
* Once found, skip further pattern checks and focus on building equivalences.
*/
bool skipRLSProcessing = !restrictionContext->hasRowSecurity;
foreach(relationRestrictionCell, restrictionContext->relationRestrictionList)
{
RelationRestriction *relationRestriction =
@ -844,6 +871,39 @@ GenerateAttributeEquivalencesForRelationRestrictions(RelationRestrictionContext
EquivalenceClass *plannerEqClass =
(EquivalenceClass *) lfirst(equivalenceClassCell);
/*
* RLS pattern = EC with both Var and non-Var (function) members.
* Finding even one such pattern means we need the merge pass.
*/
if (!skipRLSProcessing && !foundRLSPattern)
{
bool hasVar = false;
bool hasNonVar = false;
ListCell *memberCell = NULL;
foreach(memberCell, plannerEqClass->ec_members)
{
EquivalenceMember *member = (EquivalenceMember *) lfirst(memberCell);
Node *expr = strip_implicit_coercions((Node *) member->em_expr);
if (IsA(expr, Var))
{
hasVar = true;
}
else if (!IsA(expr, Param) && !IsA(expr, Const))
{
hasNonVar = true;
}
/* Early exit: If we've found both, we have the pattern */
if (hasVar && hasNonVar)
{
foundRLSPattern = true;
break;
}
}
}
AttributeEquivalenceClass *attributeEquivalence =
AttributeEquivalenceClassForEquivalenceClass(plannerEqClass,
relationRestriction);
@ -854,10 +914,259 @@ GenerateAttributeEquivalencesForRelationRestrictions(RelationRestrictionContext
}
}
/*
* Second pass: Handle RLS-specific case where PostgreSQL splits join conditions
* across multiple EquivalenceClasses due to volatile functions in RLS policies.
*
* When RLS policies use volatile functions (e.g., current_setting()), PostgreSQL
* creates separate EquivalenceClasses that both contain the same volatile function:
* EC1: [table_a.tenant_id, current_setting(...)]
* EC2: [table_b.tenant_id, current_setting(...)]
*
* We need to recognize that these should be merged to detect that tables are
* joined on their distribution columns: [table_a.tenant_id, table_b.tenant_id]
*/
if (foundRLSPattern)
{
List *rlsMergedList = MergeEquivalenceClassesWithSameFunctions(
restrictionContext);
/* Append any newly created merged classes to the original list */
attributeEquivalenceList = list_concat(attributeEquivalenceList, rlsMergedList);
}
return attributeEquivalenceList;
}
/*
* MergeEquivalenceClassesWithSameFunctions scans equivalence classes
* looking for RLS-specific patterns where volatile functions cause PostgreSQL to
* split what should be a single join condition across multiple EquivalenceClasses.
*
* This function specifically targets the pattern:
* EC1: [table_a.col, COERCEVIAIO(func(...))]
* EC2: [table_b.col, COERCEVIAIO(func(...))]
*
* Where the underlying function calls are identical after stripping implicit coercions
* (e.g., both resolve to current_setting('session.current_tenant_id')).
*
* PostgreSQL wraps RLS policy expressions in COERCEVIAIO nodes to handle type
* conversions (e.g., text UUID). We strip these to compare the actual function calls.
*
* Returns a list of newly created merged AttributeEquivalenceClasses. Each merged
* class contains the Var members from pairs of EquivalenceClasses that share identical
* non-Var expressions. For example, if EC1 contains [table_a.tenant_id, func()] and
* EC2 contains [table_b.tenant_id, func()], the returned list will include a new
* AttributeEquivalenceClass with [table_a.tenant_id, table_b.tenant_id]. Only classes
* with 2+ members are returned (indicating an actual join between tables).
*/
static List *
MergeEquivalenceClassesWithSameFunctions(RelationRestrictionContext *restrictionContext)
{
List *newlyMergedClasses = NIL;
List *ecGroupList = NIL; /* List of ECGroupByExpr* */
ListCell *relationRestrictionCell = NULL;
if (restrictionContext == NULL)
{
return NIL;
}
/*
* Phase 1: Collect candidate ECs and group them by their non-Var expressions.
*
* Strategy: For each EC, extract and strip all non-Var expressions, then
* find or create a group for each unique expression. This gives us direct
* access to all ECs sharing the same expression.
*/
foreach(relationRestrictionCell, restrictionContext->relationRestrictionList)
{
RelationRestriction *relationRestriction =
(RelationRestriction *) lfirst(relationRestrictionCell);
List *equivalenceClasses = relationRestriction->plannerInfo->eq_classes;
ListCell *equivalenceClassCell = NULL;
foreach(equivalenceClassCell, equivalenceClasses)
{
EquivalenceClass *ec = (EquivalenceClass *) lfirst(equivalenceClassCell);
bool hasVar = false;
List *nonVarExprs = NIL;
ListCell *memberCell = NULL;
/*
* Single pass through EC members: collect Vars and non-Var expressions.
* Strip coercions once and cache the results.
*/
foreach(memberCell, ec->ec_members)
{
EquivalenceMember *member = (EquivalenceMember *) lfirst(memberCell);
Node *expr = strip_implicit_coercions((Node *) member->em_expr);
if (IsA(expr, Var))
{
hasVar = true;
}
else if (!IsA(expr, Param) && !IsA(expr, Const))
{
/*
* Found a non-Var expression (potential RLS function).
* After stripping, this is typically a FUNCEXPR like
* current_setting('session.current_tenant_id').
*/
nonVarExprs = lappend(nonVarExprs, expr);
}
}
/* Only process ECs with both Var and non-Var members (RLS pattern) */
if (!hasVar || nonVarExprs == NIL)
{
continue;
}
/*
* For each non-Var expression in this EC, find or create a group.
* Multiple ECs with the same expression will be grouped together.
*/
ListCell *exprCell = NULL;
foreach(exprCell, nonVarExprs)
{
Node *strippedExpr = (Node *) lfirst(exprCell);
ECGroupByExpr *matchingGroup = NULL;
ListCell *groupCell = NULL;
/* Search for existing group with this expression */
foreach(groupCell, ecGroupList)
{
ECGroupByExpr *group = (ECGroupByExpr *) lfirst(groupCell);
if (equal(group->strippedExpr, strippedExpr))
{
matchingGroup = group;
break;
}
}
/* Create new group if this is the first EC with this expression */
if (matchingGroup == NULL)
{
matchingGroup = palloc0(sizeof(ECGroupByExpr));
matchingGroup->strippedExpr = strippedExpr;
matchingGroup->ecsWithThisExpr = NIL;
matchingGroup->varsInTheseECs = NIL;
ecGroupList = lappend(ecGroupList, matchingGroup);
}
/* Add this EC to the group (avoid duplicates) */
if (!list_member_ptr(matchingGroup->ecsWithThisExpr, ec))
{
matchingGroup->ecsWithThisExpr =
lappend(matchingGroup->ecsWithThisExpr, ec);
}
}
}
}
/*
* Phase 2: For each group with 2+ ECs, extract all Vars and create a merged
* AttributeEquivalenceClass. This is where we detect the join pattern.
*
* Idea here is that if multiple ECs share the same non-Var expression (e.g., RLS
* function), then all Vars in those ECs are implicitly equal to each other.
*/
ListCell *groupCell = NULL;
foreach(groupCell, ecGroupList)
{
ECGroupByExpr *group = (ECGroupByExpr *) lfirst(groupCell);
/* Skip groups with only one EC - no join to detect */
if (list_length(group->ecsWithThisExpr) < 2)
{
continue;
}
/*
* Extract all Vars from all ECs in this group.
* These Vars are implicitly equal via the shared expression.
*/
ListCell *ecCell = NULL;
foreach(ecCell, group->ecsWithThisExpr)
{
EquivalenceClass *ec = (EquivalenceClass *) lfirst(ecCell);
ListCell *memberCell = NULL;
foreach(memberCell, ec->ec_members)
{
EquivalenceMember *member = (EquivalenceMember *) lfirst(memberCell);
Node *expr = strip_implicit_coercions((Node *) member->em_expr);
if (IsA(expr, Var))
{
/* Cache this Var for later processing */
group->varsInTheseECs = lappend(group->varsInTheseECs, expr);
}
}
}
/* Need at least 2 Vars from different tables to represent a join */
if (list_length(group->varsInTheseECs) < 2)
{
continue;
}
/*
* Create the merged AttributeEquivalenceClass.
*/
AttributeEquivalenceClass *mergedClass =
palloc0(sizeof(AttributeEquivalenceClass));
mergedClass->equivalenceId = AttributeEquivalenceId++;
mergedClass->equivalentAttributes = NIL;
/*
* Build a PlannerInfo lookup map for quick access.
* Map varno RelationRestriction for fast lookups.
*/
ListCell *varCell = NULL;
foreach(varCell, group->varsInTheseECs)
{
Var *var = (Var *) lfirst(varCell);
ListCell *relResCell = NULL;
/*
* Find the appropriate RelationRestriction for this Var.
* We need the correct PlannerInfo context to process the Var.
*/
foreach(relResCell, restrictionContext->relationRestrictionList)
{
RelationRestriction *relRestriction =
(RelationRestriction *) lfirst(relResCell);
PlannerInfo *root = relRestriction->plannerInfo;
/* Check if this Var belongs to this planner's range table */
if (var->varno < root->simple_rel_array_size &&
root->simple_rte_array[var->varno] != NULL)
{
/*
* Process this Var through AddToAttributeEquivalenceClass.
* This handles subqueries, UNION ALL, LATERAL joins, etc.
*/
AddToAttributeEquivalenceClass(mergedClass, root, var);
break; /* Found the right planner, move to next Var */
}
}
}
/* Only emit if we successfully merged attributes from multiple sources */
if (list_length(mergedClass->equivalentAttributes) >= 2)
{
newlyMergedClasses = lappend(newlyMergedClasses, mergedClass);
}
}
return newlyMergedClasses;
}
/*
* AttributeEquivalenceClassForEquivalenceClass is a helper function for
* GenerateAttributeEquivalencesForRelationRestrictions. The function takes an
@ -2395,6 +2704,10 @@ FilterRelationRestrictionContext(RelationRestrictionContext *relationRestriction
RelationRestrictionContext *filteredRestrictionContext =
palloc0(sizeof(RelationRestrictionContext));
/* Preserve RLS flag from the original context */
filteredRestrictionContext->hasRowSecurity =
relationRestrictionContext->hasRowSecurity;
ListCell *relationRestrictionCell = NULL;
foreach(relationRestrictionCell, relationRestrictionContext->relationRestrictionList)

View File

@ -48,6 +48,13 @@ typedef enum RouterPlanType
/*
* RelationRestrictionContext holds the relation restrictions gathered for a
* query, together with flags that summarize them.
*/
typedef struct RelationRestrictionContext
{
/* initialized to true; logically AND-ed as tables are added */
bool allReferenceTables;
/*
* Set to true when any table in the query
* has Row Level Security policies active.
* Copied from the Query's hasRowSecurity field by the distributed planner;
* when false, the RLS-specific equivalence class merging is skipped.
*/
bool hasRowSecurity;
/* list of RelationRestriction * */
List *relationRestrictionList;
} RelationRestrictionContext;

View File

@ -34,6 +34,7 @@ test: multi_multiuser_auth
test: multi_poolinfo_usage
test: multi_alter_table_row_level_security
test: multi_alter_table_row_level_security_escape
test: multi_rls_join_distribution_key
test: stat_statements
test: shard_move_constraints
test: shard_move_constraints_blocking

View File

@ -0,0 +1,106 @@
--
-- MULTI_RLS_JOIN_DISTRIBUTION_KEY
--
-- Test that RLS policies with volatile functions don't prevent
-- Citus from recognizing joins on distribution columns.
--
-- This addresses GitHub issue #7969 where RLS policies using
-- current_setting() caused PostgreSQL to split equivalence classes,
-- preventing Citus from detecting proper distribution column joins.
SET citus.next_shard_id TO 1900000;
SET citus.shard_replication_factor TO 1;
-- Create test user if not exists
DO $$
BEGIN
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = 'app_user') THEN
CREATE USER app_user;
END IF;
END
$$;
-- Create schema and grant privileges
CREATE SCHEMA IF NOT EXISTS rls_join_test;
GRANT ALL PRIVILEGES ON SCHEMA rls_join_test TO app_user;
SET search_path TO rls_join_test;
-- Create and distribute tables
CREATE TABLE table_a (tenant_id uuid, id int);
CREATE TABLE table_b (tenant_id uuid, id int);
SELECT create_distributed_table('table_a', 'tenant_id');
create_distributed_table
---------------------------------------------------------------------
(1 row)
SELECT create_distributed_table('table_b', 'tenant_id', colocate_with => 'table_a');
create_distributed_table
---------------------------------------------------------------------
(1 row)
-- Grant privileges on tables
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA rls_join_test TO app_user;
-- Insert test data
INSERT INTO table_a VALUES
('0194d116-5dd5-74af-be74-7f5e8468eeb7', 1),
('0194d116-5dd5-74af-be74-7f5e8468eeb8', 2);
INSERT INTO table_b VALUES
('0194d116-5dd5-74af-be74-7f5e8468eeb7', 10),
('0194d116-5dd5-74af-be74-7f5e8468eeb8', 20);
-- Enable RLS and create policy
ALTER TABLE table_a ENABLE ROW LEVEL SECURITY;
CREATE POLICY tenant_isolation_0 ON table_a TO app_user
USING (tenant_id = current_setting('session.current_tenant_id')::UUID);
-- Test scenario that previously failed
-- Switch to app_user and execute the query with RLS
SET ROLE app_user;
SET citus.propagate_set_commands = local;
SET application_name = '0194d116-5dd5-74af-be74-7f5e8468eeb7';
BEGIN;
-- Set session.current_tenant_id from application_name
DO $$
DECLARE
BEGIN
EXECUTE 'SET LOCAL session.current_tenant_id = ' || quote_literal(current_setting('application_name', true));
END;
$$;
-- This query should work with RLS enabled
-- Before the fix, this would fail with:
-- "complex joins are only supported when all distributed tables are
-- co-located and joined on their distribution columns"
EXPLAIN (COSTS OFF)
SELECT c.id, t.id
FROM table_a AS c
LEFT OUTER JOIN table_b AS t ON c.tenant_id = t.tenant_id;
QUERY PLAN
---------------------------------------------------------------------
Custom Scan (Citus Adaptive)
Task Count: 4
Tasks Shown: One of 4
-> Task
Node: host=localhost port=xxxxx dbname=regression
-> Nested Loop Left Join
-> Seq Scan on table_a_1900000 c
Filter: (tenant_id = (current_setting('session.current_tenant_id'::text))::uuid)
-> Materialize
-> Seq Scan on table_b_1900004 t
Filter: (tenant_id = (current_setting('session.current_tenant_id'::text))::uuid)
(11 rows)
SELECT c.id, t.id
FROM table_a AS c
LEFT OUTER JOIN table_b AS t ON c.tenant_id = t.tenant_id
ORDER BY c.id, t.id;
id | id
---------------------------------------------------------------------
1 | 10
(1 row)
ROLLBACK;
-- Switch back to superuser for cleanup
RESET ROLE;
-- Cleanup: Drop schema and all objects
DROP SCHEMA rls_join_test CASCADE;
NOTICE: drop cascades to 2 other objects
DETAIL: drop cascades to table table_a
drop cascades to table table_b
-- Drop user (suppress error if it has dependencies)
DROP USER IF EXISTS app_user;

View File

@ -0,0 +1,92 @@
--
-- MULTI_RLS_JOIN_DISTRIBUTION_KEY
--
-- Test that RLS policies with volatile functions don't prevent
-- Citus from recognizing joins on distribution columns.
--
-- This addresses GitHub issue #7969 where RLS policies using
-- current_setting() caused PostgreSQL to split equivalence classes,
-- preventing Citus from detecting proper distribution column joins.
SET citus.next_shard_id TO 1900000;
SET citus.shard_replication_factor TO 1;
-- Create test user if not exists
DO $$
BEGIN
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = 'app_user') THEN
CREATE USER app_user;
END IF;
END
$$;
-- Create schema and grant privileges
CREATE SCHEMA IF NOT EXISTS rls_join_test;
GRANT ALL PRIVILEGES ON SCHEMA rls_join_test TO app_user;
SET search_path TO rls_join_test;
-- Create and distribute tables
CREATE TABLE table_a (tenant_id uuid, id int);
CREATE TABLE table_b (tenant_id uuid, id int);
SELECT create_distributed_table('table_a', 'tenant_id');
SELECT create_distributed_table('table_b', 'tenant_id', colocate_with => 'table_a');
-- Grant privileges on tables
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA rls_join_test TO app_user;
-- Insert test data
INSERT INTO table_a VALUES
('0194d116-5dd5-74af-be74-7f5e8468eeb7', 1),
('0194d116-5dd5-74af-be74-7f5e8468eeb8', 2);
INSERT INTO table_b VALUES
('0194d116-5dd5-74af-be74-7f5e8468eeb7', 10),
('0194d116-5dd5-74af-be74-7f5e8468eeb8', 20);
-- Enable RLS and create policy
ALTER TABLE table_a ENABLE ROW LEVEL SECURITY;
CREATE POLICY tenant_isolation_0 ON table_a TO app_user
USING (tenant_id = current_setting('session.current_tenant_id')::UUID);
-- Test scenario that previously failed
-- Switch to app_user and execute the query with RLS
SET ROLE app_user;
SET citus.propagate_set_commands = local;
SET application_name = '0194d116-5dd5-74af-be74-7f5e8468eeb7';
BEGIN;
-- Set session.current_tenant_id from application_name
DO $$
DECLARE
BEGIN
EXECUTE 'SET LOCAL session.current_tenant_id = ' || quote_literal(current_setting('application_name', true));
END;
$$;
-- This query should work with RLS enabled
-- Before the fix, this would fail with:
-- "complex joins are only supported when all distributed tables are
-- co-located and joined on their distribution columns"
EXPLAIN (COSTS OFF)
SELECT c.id, t.id
FROM table_a AS c
LEFT OUTER JOIN table_b AS t ON c.tenant_id = t.tenant_id;
SELECT c.id, t.id
FROM table_a AS c
LEFT OUTER JOIN table_b AS t ON c.tenant_id = t.tenant_id
ORDER BY c.id, t.id;
ROLLBACK;
-- Switch back to superuser for cleanup
RESET ROLE;
-- Cleanup: Drop schema and all objects
DROP SCHEMA rls_join_test CASCADE;
-- Drop user (suppress error if it has dependencies)
DROP USER IF EXISTS app_user;