diff --git a/src/backend/distributed/utils/citus_clauses.c b/src/backend/distributed/utils/citus_clauses.c index 1acfc1278..17e076503 100644 --- a/src/backend/distributed/utils/citus_clauses.c +++ b/src/backend/distributed/utils/citus_clauses.c @@ -27,7 +27,7 @@ /* private function declarations */ -static bool IsVarOrParamSublink(Node *node); +static bool IsVariableExpression(Node *node); static Expr * citus_evaluate_expr(Expr *expr, Oid result_type, int32 result_typmod, Oid result_collation, MasterEvaluationContext *masterEvaluationContext); @@ -120,9 +120,19 @@ PartiallyEvaluateExpression(Node *expression, else if (ShouldEvaluateExpression((Expr *) expression) && ShouldEvaluateFunctionWithMasterContext(masterEvaluationContext)) { - /* don't call citus_evaluate_expr on nodes it cannot handle */ - if (FindNodeCheck(expression, IsVarOrParamSublink)) + if (FindNodeCheck(expression, IsVariableExpression)) { + /* + * The expression contains a variable expression (e.g. a stable function, + * which has a column reference as its input). That means that we cannot + * evaluate the expression on the coordinator, since the result depends + * on the input. + * + * Skipping function evaluation for these expressions is safe in most + * cases, since the function will always be re-evaluated for every input + * value. An exception is function calls that call another stable function + * that should not be re-evaluated, such as now(). + */ return (Node *) expression_tree_mutator(expression, PartiallyEvaluateExpression, masterEvaluationContext); @@ -211,11 +221,23 @@ ShouldEvaluateExpression(Expr *expression) /* - * IsVarOrParamSublink returns whether node is a Var or PARAM_SUBLINK param. + * IsVariableExpression returns whether the given node is a variable expression, + * meaning its result depends on the input data and is not constant for the whole + * query. */ static bool -IsVarOrParamSublink(Node *node) +IsVariableExpression(Node *node) { + if (IsA(node, Aggref)) + { + return true; + } + + if (IsA(node, WindowFunc)) + { + return true; + } + if (IsA(node, Param)) { /* ExecInitExpr cannot handle PARAM_SUBLINK */ diff --git a/src/test/regress/expected/multi_function_evaluation.out b/src/test/regress/expected/multi_function_evaluation.out index 4711ea264..6ad76423d 100644 --- a/src/test/regress/expected/multi_function_evaluation.out +++ b/src/test/regress/expected/multi_function_evaluation.out @@ -163,9 +163,31 @@ DELETE FROM table_1 WHERE key >= (SELECT min(KEY) FROM table_1) AND value > now() - interval '1 hour'; +CREATE OR REPLACE FUNCTION stable_squared(int) +RETURNS int STABLE +LANGUAGE plpgsql +AS $function$ +BEGIN + RAISE NOTICE 'stable_fn called'; + RETURN $1 * $1; +END; +$function$; +SELECT create_distributed_function('stable_squared(int)'); + create_distributed_function +--------------------------------------------------------------------- + +(1 row) + +UPDATE example SET value = timestamp '10-10-2000 00:00' +FROM (SELECT key, stable_squared(count(*)::int) y FROM example GROUP BY key) a WHERE example.key = a.key; +UPDATE example SET value = timestamp '10-10-2000 00:00' +FROM (SELECT key, stable_squared((count(*) OVER ())::int) y FROM example GROUP BY key) a WHERE example.key = a.key; +UPDATE example SET value = timestamp '10-10-2000 00:00' +FROM (SELECT key, stable_squared(grouping(key)) y FROM example GROUP BY key) a WHERE example.key = a.key; DROP SCHEMA multi_function_evaluation CASCADE; -NOTICE: drop cascades to 4 other objects +NOTICE: drop cascades to 5 other objects DETAIL: drop cascades to table example drop cascades to sequence example_value_seq drop cascades to function stable_fn() drop cascades to table table_1 +drop cascades to function stable_squared(integer) diff --git a/src/test/regress/sql/multi_function_evaluation.sql b/src/test/regress/sql/multi_function_evaluation.sql index b19afff1d..7e5480bab 100644 --- a/src/test/regress/sql/multi_function_evaluation.sql +++ b/src/test/regress/sql/multi_function_evaluation.sql @@ -145,4 +145,24 @@ FROM table_1 WHERE key >= (SELECT min(KEY) FROM table_1) AND value > now() - interval '1 hour'; +CREATE OR REPLACE FUNCTION stable_squared(int) +RETURNS int STABLE +LANGUAGE plpgsql +AS $function$ +BEGIN + RAISE NOTICE 'stable_fn called'; + RETURN $1 * $1; +END; +$function$; +SELECT create_distributed_function('stable_squared(int)'); + +UPDATE example SET value = timestamp '10-10-2000 00:00' +FROM (SELECT key, stable_squared(count(*)::int) y FROM example GROUP BY key) a WHERE example.key = a.key; + +UPDATE example SET value = timestamp '10-10-2000 00:00' +FROM (SELECT key, stable_squared((count(*) OVER ())::int) y FROM example GROUP BY key) a WHERE example.key = a.key; + +UPDATE example SET value = timestamp '10-10-2000 00:00' +FROM (SELECT key, stable_squared(grouping(key)) y FROM example GROUP BY key) a WHERE example.key = a.key; + DROP SCHEMA multi_function_evaluation CASCADE;