diff --git a/src/backend/distributed/utils/citus_clauses.c b/src/backend/distributed/utils/citus_clauses.c index f88b173af..7086cec4e 100644 --- a/src/backend/distributed/utils/citus_clauses.c +++ b/src/backend/distributed/utils/citus_clauses.c @@ -41,6 +41,7 @@ static bool ShouldEvaluateExpression(Expr *expression); static bool ShouldEvaluateFunctions(CoordinatorEvaluationContext *evaluationContext); static void FixFunctionArguments(Node *expr); static bool FixFunctionArgumentsWalker(Node *expr, void *context); +static bool CheckExprExecutorSafe(Node *expr); /* @@ -99,15 +100,18 @@ PartiallyEvaluateExpression(Node *expression, } NodeTag nodeTag = nodeTag(expression); - if (nodeTag == T_Param) - { - Param *param = (Param *) expression; - if (param->paramkind == PARAM_SUBLINK) - { - /* ExecInitExpr cannot handle PARAM_SUBLINK */ - return expression; - } + /* ExecInitExpr cannot handle some expressions (PARAM_MULTIEXPR and PARAM_SUBLINK) */ + if (!CheckExprExecutorSafe(expression)) + { + return expression; + } + + /* ExecInitExpr cannot handle PARAM_MULTIEXPR and PARAM_SUBLINK but we have guards */ + else if (nodeTag == T_Param) + { + Assert(((Param *) expression)->paramkind != PARAM_MULTIEXPR && + ((Param *) expression)->paramkind != PARAM_SUBLINK); return (Node *) citus_evaluate_expr((Expr *) expression, exprType(expression), exprTypmod(expression), @@ -260,7 +264,9 @@ ShouldEvaluateExpression(Expr *expression) } default: + { return false; + } } } @@ -537,3 +543,48 @@ FixFunctionArgumentsWalker(Node *expr, void *context) return expression_tree_walker(expr, FixFunctionArgumentsWalker, NULL); } + + +/* + * Recursively explore an expression to ensure it can be used in the PostgreSQL + * ExecInitExpr. + * Currently only search for PARAM_MULTIEXPR or PARAM_SUBLINK. + */ +static bool +CheckExprExecutorSafe(Node *expr) +{ + if (expr == NULL) + { + return true; + } + + /* + * If it's a Param, we're done traversing the tree. + * Just check if it contins a sublink or multiexpr. + */ + else if (IsA(expr, Param)) + { + Param *param = (Param *) expr; + if (param->paramkind == PARAM_MULTIEXPR || + param->paramkind == PARAM_SUBLINK) + { + return false; + } + } + + /* If it's a FuncExpr, search in arguments */ + else if (IsA(expr, FuncExpr)) + { + FuncExpr *func = (FuncExpr *) expr; + ListCell *lc; + + foreach(lc, func->args) + { + if (!CheckExprExecutorSafe((Node *) lfirst(lc))) + { + return false; + } + } + } + return true; +} diff --git a/src/test/regress/expected/multi_modifications.out b/src/test/regress/expected/multi_modifications.out index 887003a97..93f6c8c45 100644 --- a/src/test/regress/expected/multi_modifications.out +++ b/src/test/regress/expected/multi_modifications.out @@ -812,6 +812,33 @@ SELECT * FROM app_analytics_events ORDER BY id; (2 rows) DROP TABLE app_analytics_events; +-- test function call in UPDATE SET +-- https://github.com/citusdata/citus/issues/7676 +CREATE FUNCTION citus_is_coordinator_stable() returns bool as $$ + select citus_is_coordinator(); +$$ language sql stable; +CREATE TABLE bool_test ( + id bigint primary key, + col_bool bool + ); +SELECT create_reference_table('bool_test'); + create_reference_table +--------------------------------------------------------------------- + +(1 row) + +INSERT INTO bool_test values (1, true); +UPDATE bool_test +SET (col_bool) + = (SELECT citus_is_coordinator_stable()) +RETURNING id, col_bool; + id | col_bool +--------------------------------------------------------------------- + 1 | t +(1 row) + +DROP TABLE bool_test; +DROP FUNCTION citus_is_coordinator_stable(); -- Test multi-row insert with serial in a non-partition column CREATE TABLE app_analytics_events (id int, app_id serial, name text); SELECT create_distributed_table('app_analytics_events', 'id'); diff --git a/src/test/regress/sql/multi_modifications.sql b/src/test/regress/sql/multi_modifications.sql index 7977325ea..2a00e7992 100644 --- a/src/test/regress/sql/multi_modifications.sql +++ b/src/test/regress/sql/multi_modifications.sql @@ -505,6 +505,28 @@ VALUES (104, 'Wayz'), (105, 'Mynt') RETURNING *; SELECT * FROM app_analytics_events ORDER BY id; DROP TABLE app_analytics_events; +-- test function call in UPDATE SET +-- https://github.com/citusdata/citus/issues/7676 +CREATE FUNCTION citus_is_coordinator_stable() returns bool as $$ + select citus_is_coordinator(); +$$ language sql stable; + +CREATE TABLE bool_test ( + id bigint primary key, + col_bool bool + ); +SELECT create_reference_table('bool_test'); + +INSERT INTO bool_test values (1, true); + +UPDATE bool_test +SET (col_bool) + = (SELECT citus_is_coordinator_stable()) +RETURNING id, col_bool; + +DROP TABLE bool_test; +DROP FUNCTION citus_is_coordinator_stable(); + -- Test multi-row insert with serial in a non-partition column CREATE TABLE app_analytics_events (id int, app_id serial, name text); SELECT create_distributed_table('app_analytics_events', 'id');