From 5e58804d4429387edc8192c67348a4ba3a71dd99 Mon Sep 17 00:00:00 2001 From: Marco Slot Date: Wed, 12 Apr 2017 11:17:46 +0200 Subject: [PATCH] Support query parameters in combination with function evaluation --- .../executor/multi_router_executor.c | 15 ++- .../master/master_modify_multiple_shards.c | 2 +- src/backend/distributed/utils/citus_clauses.c | 105 +++++++++++------- src/include/distributed/citus_clauses.h | 3 +- .../regress/expected/multi_prepare_plsql.out | 80 +++++++++++++ .../regress/expected/multi_prepare_sql.out | 56 ++++++++++ src/test/regress/sql/multi_prepare_plsql.sql | 41 +++++++ src/test/regress/sql/multi_prepare_sql.sql | 38 +++++++ 8 files changed, 292 insertions(+), 48 deletions(-) diff --git a/src/backend/distributed/executor/multi_router_executor.c b/src/backend/distributed/executor/multi_router_executor.c index e546c3311..931b785a7 100644 --- a/src/backend/distributed/executor/multi_router_executor.c +++ b/src/backend/distributed/executor/multi_router_executor.c @@ -87,7 +87,7 @@ static List * TaskShardIntervalList(List *taskList); static void AcquireExecutorShardLock(Task *task, CmdType commandType); static void AcquireExecutorMultiShardLocks(List *taskList); static bool RequiresConsistentSnapshot(Task *task); -static void ProcessMasterEvaluableFunctions(Job *workerJob); +static void ProcessMasterEvaluableFunctions(Job *workerJob, PlanState *planState); static void ExtractParametersFromParamListInfo(ParamListInfo paramListInfo, Oid **parameterTypes, const char ***parameterValues); @@ -443,13 +443,14 @@ RouterSingleModifyExecScan(CustomScanState *node) if (!scanState->finishedRemoteScan) { + PlanState *planState = &(scanState->customScanState.ss.ps); MultiPlan *multiPlan = scanState->multiPlan; bool hasReturning = multiPlan->hasReturning; Job *workerJob = multiPlan->workerJob; List *taskList = workerJob->taskList; Task *task = (Task *) linitial(taskList); - ProcessMasterEvaluableFunctions(workerJob); + ProcessMasterEvaluableFunctions(workerJob, planState); ExecuteSingleModifyTask(scanState, task, hasReturning); @@ -467,14 +468,14 @@ RouterSingleModifyExecScan(CustomScanState *node) * the query strings in task lists. */ static void -ProcessMasterEvaluableFunctions(Job *workerJob) +ProcessMasterEvaluableFunctions(Job *workerJob, PlanState *planState) { if (workerJob->requiresMasterEvaluation) { Query *jobQuery = workerJob->jobQuery; List *taskList = workerJob->taskList; - ExecuteMasterEvaluableFunctions(jobQuery); + ExecuteMasterEvaluableFunctions(jobQuery, planState); RebuildQueryStrings(jobQuery, taskList); } } @@ -493,13 +494,14 @@ RouterMultiModifyExecScan(CustomScanState *node) if (!scanState->finishedRemoteScan) { + PlanState *planState = &(scanState->customScanState.ss.ps); MultiPlan *multiPlan = scanState->multiPlan; Job *workerJob = multiPlan->workerJob; List *taskList = workerJob->taskList; bool hasReturning = multiPlan->hasReturning; bool isModificationQuery = true; - ProcessMasterEvaluableFunctions(workerJob); + ProcessMasterEvaluableFunctions(workerJob, planState); ExecuteMultipleTasks(scanState, taskList, isModificationQuery, hasReturning); @@ -525,12 +527,13 @@ RouterSelectExecScan(CustomScanState *node) if (!scanState->finishedRemoteScan) { + PlanState *planState = &(scanState->customScanState.ss.ps); MultiPlan *multiPlan = scanState->multiPlan; Job *workerJob = multiPlan->workerJob; List *taskList = workerJob->taskList; Task *task = (Task *) linitial(taskList); - ProcessMasterEvaluableFunctions(workerJob); + ProcessMasterEvaluableFunctions(workerJob, planState); ExecuteSingleSelectTask(scanState, task); diff --git a/src/backend/distributed/master/master_modify_multiple_shards.c b/src/backend/distributed/master/master_modify_multiple_shards.c index 022bbe6e1..d2695a7b4 100644 --- a/src/backend/distributed/master/master_modify_multiple_shards.c +++ b/src/backend/distributed/master/master_modify_multiple_shards.c @@ -154,7 +154,7 @@ master_modify_multiple_shards(PG_FUNCTION_ARGS) errmsg("master_modify_multiple_shards() does not support RETURNING"))); } - ExecuteMasterEvaluableFunctions(modifyQuery); + ExecuteMasterEvaluableFunctions(modifyQuery, NULL); shardIntervalList = LoadShardIntervalList(relationId); restrictClauseList = WhereClauseList(modifyQuery->jointree); diff --git a/src/backend/distributed/utils/citus_clauses.c b/src/backend/distributed/utils/citus_clauses.c index f0d632e5b..61268b9ea 100644 --- a/src/backend/distributed/utils/citus_clauses.c +++ b/src/backend/distributed/utils/citus_clauses.c @@ -22,11 +22,21 @@ #include "utils/datum.h" #include "utils/lsyscache.h" -static Node * PartiallyEvaluateExpression(Node *expression); -static Node * EvaluateNodeIfReferencesFunction(Node *expression); -static Node * PartiallyEvaluateExpressionMutator(Node *expression, bool *containsVar); + +typedef struct FunctionEvaluationContext +{ + PlanState *planState; + bool containsVar; +} FunctionEvaluationContext; + + +/* private function declarations */ +static Node * PartiallyEvaluateExpression(Node *expression, PlanState *planState); +static Node * EvaluateNodeIfReferencesFunction(Node *expression, PlanState *planState); +static Node * PartiallyEvaluateExpressionMutator(Node *expression, + FunctionEvaluationContext *context); static Expr * citus_evaluate_expr(Expr *expr, Oid result_type, int32 result_typmod, - Oid result_collation); + Oid result_collation, PlanState *planState); /* @@ -88,7 +98,7 @@ RequiresMasterEvaluation(Query *query) * any sub-expressions which don't include Vars. */ void -ExecuteMasterEvaluableFunctions(Query *query) +ExecuteMasterEvaluableFunctions(Query *query, PlanState *planState) { CmdType commandType = query->commandType; ListCell *targetEntryCell = NULL; @@ -99,7 +109,8 @@ ExecuteMasterEvaluableFunctions(Query *query) if (query->jointree && query->jointree->quals) { - query->jointree->quals = PartiallyEvaluateExpression(query->jointree->quals); + query->jointree->quals = PartiallyEvaluateExpression(query->jointree->quals, + planState); } foreach(targetEntryCell, query->targetList) @@ -114,11 +125,13 @@ ExecuteMasterEvaluableFunctions(Query *query) if (commandType == CMD_INSERT && !insertSelectQuery) { - modifiedNode = EvaluateNodeIfReferencesFunction((Node *) targetEntry->expr); + modifiedNode = EvaluateNodeIfReferencesFunction((Node *) targetEntry->expr, + planState); } else { - modifiedNode = PartiallyEvaluateExpression((Node *) targetEntry->expr); + modifiedNode = PartiallyEvaluateExpression((Node *) targetEntry->expr, + planState); } targetEntry->expr = (Expr *) modifiedNode; @@ -133,14 +146,14 @@ ExecuteMasterEvaluableFunctions(Query *query) continue; } - ExecuteMasterEvaluableFunctions(rte->subquery); + ExecuteMasterEvaluableFunctions(rte->subquery, planState); } foreach(cteCell, query->cteList) { CommonTableExpr *expr = (CommonTableExpr *) lfirst(cteCell); - ExecuteMasterEvaluableFunctions((Query *) expr->ctequery); + ExecuteMasterEvaluableFunctions((Query *) expr->ctequery, planState); } } @@ -150,10 +163,11 @@ ExecuteMasterEvaluableFunctions(Query *query) * doesn't show up in the parameter list. */ static Node * -PartiallyEvaluateExpression(Node *expression) +PartiallyEvaluateExpression(Node *expression, PlanState *planState) { - bool unused; - return PartiallyEvaluateExpressionMutator(expression, &unused); + FunctionEvaluationContext globalContext = { planState, false }; + + return PartiallyEvaluateExpressionMutator(expression, &globalContext); } @@ -167,10 +181,10 @@ PartiallyEvaluateExpression(Node *expression) * only call EvaluateExpression on the top-most level and get the same result. */ static Node * -PartiallyEvaluateExpressionMutator(Node *expression, bool *containsVar) +PartiallyEvaluateExpressionMutator(Node *expression, FunctionEvaluationContext *context) { - bool childContainsVar = false; Node *copy = NULL; + FunctionEvaluationContext localContext = { context->planState, false }; if (expression == NULL) { @@ -182,30 +196,30 @@ PartiallyEvaluateExpressionMutator(Node *expression, bool *containsVar) { return expression_tree_mutator(expression, PartiallyEvaluateExpressionMutator, - containsVar); + context); } if (IsA(expression, Var)) { - *containsVar = true; + context->containsVar = true; /* makes a copy for us */ return expression_tree_mutator(expression, PartiallyEvaluateExpressionMutator, - containsVar); + context); } copy = expression_tree_mutator(expression, PartiallyEvaluateExpressionMutator, - &childContainsVar); + &localContext); - if (childContainsVar) + if (localContext.containsVar) { - *containsVar = true; + context->containsVar = true; } else { - copy = EvaluateNodeIfReferencesFunction(copy); + copy = EvaluateNodeIfReferencesFunction(copy, context->planState); } return copy; @@ -221,7 +235,7 @@ PartiallyEvaluateExpressionMutator(Node *expression, bool *containsVar) * all nodes which invoke functions which might not be IMMUTABLE. */ static Node * -EvaluateNodeIfReferencesFunction(Node *expression) +EvaluateNodeIfReferencesFunction(Node *expression, PlanState *planState) { if (IsA(expression, FuncExpr)) { @@ -230,7 +244,8 @@ EvaluateNodeIfReferencesFunction(Node *expression) return (Node *) citus_evaluate_expr((Expr *) expr, expr->funcresulttype, exprTypmod((Node *) expr), - expr->funccollid); + expr->funccollid, + planState); } if (IsA(expression, OpExpr) || @@ -242,7 +257,8 @@ EvaluateNodeIfReferencesFunction(Node *expression) return (Node *) citus_evaluate_expr((Expr *) expr, expr->opresulttype, -1, - expr->opcollid); + expr->opcollid, + planState); } if (IsA(expression, CoerceViaIO)) @@ -251,7 +267,8 @@ EvaluateNodeIfReferencesFunction(Node *expression) return (Node *) citus_evaluate_expr((Expr *) expr, expr->resulttype, -1, - expr->resultcollid); + expr->resultcollid, + planState); } if (IsA(expression, ArrayCoerceExpr)) @@ -261,21 +278,24 @@ EvaluateNodeIfReferencesFunction(Node *expression) return (Node *) citus_evaluate_expr((Expr *) expr, expr->resulttype, expr->resulttypmod, - expr->resultcollid); + expr->resultcollid, + planState); } if (IsA(expression, ScalarArrayOpExpr)) { ScalarArrayOpExpr *expr = (ScalarArrayOpExpr *) expression; - return (Node *) citus_evaluate_expr((Expr *) expr, BOOLOID, -1, InvalidOid); + return (Node *) citus_evaluate_expr((Expr *) expr, BOOLOID, -1, InvalidOid, + planState); } if (IsA(expression, RowCompareExpr)) { RowCompareExpr *expr = (RowCompareExpr *) expression; - return (Node *) citus_evaluate_expr((Expr *) expr, BOOLOID, -1, InvalidOid); + return (Node *) citus_evaluate_expr((Expr *) expr, BOOLOID, -1, InvalidOid, + planState); } return expression; @@ -292,10 +312,11 @@ EvaluateNodeIfReferencesFunction(Node *expression) */ static Expr * citus_evaluate_expr(Expr *expr, Oid result_type, int32 result_typmod, - Oid result_collation) + Oid result_collation, PlanState *planState) { - EState *estate; + EState *estate; ExprState *exprstate; + ExprContext *econtext; MemoryContext oldcontext; Datum const_val; bool const_is_null; @@ -317,19 +338,23 @@ citus_evaluate_expr(Expr *expr, Oid result_type, int32 result_typmod, * Prepare expr for execution. (Note: we can't use ExecPrepareExpr * because it'd result in recursively invoking eval_const_expressions.) */ - exprstate = ExecInitExpr(expr, NULL); + exprstate = ExecInitExpr(expr, planState); + + if (planState != NULL) + { + /* use executor's context to pass down parameters */ + econtext = planState->ps_ExprContext; + } + else + { + /* when called from a function, use a default context */ + econtext = GetPerTupleExprContext(estate); + } /* * And evaluate it. - * - * It is OK to use a default econtext because none of the ExecEvalExpr() - * code used in this situation will use econtext. That might seem - * fortuitous, but it's not so unreasonable --- a constant expression does - * not depend on context, by definition, n'est ce pas? */ - const_val = ExecEvalExprSwitchContext(exprstate, - GetPerTupleExprContext(estate), - &const_is_null, NULL); + const_val = ExecEvalExprSwitchContext(exprstate, econtext, &const_is_null, NULL); /* Get info needed about result datatype */ get_typlenbyval(result_type, &resultTypLen, &resultTypByVal); diff --git a/src/include/distributed/citus_clauses.h b/src/include/distributed/citus_clauses.h index 29905850a..5612fc3b2 100644 --- a/src/include/distributed/citus_clauses.h +++ b/src/include/distributed/citus_clauses.h @@ -11,10 +11,11 @@ #ifndef CITUS_CLAUSES_H #define CITUS_CLAUSES_H +#include "nodes/execnodes.h" #include "nodes/nodes.h" #include "nodes/parsenodes.h" extern bool RequiresMasterEvaluation(Query *query); -extern void ExecuteMasterEvaluableFunctions(Query *query); +extern void ExecuteMasterEvaluableFunctions(Query *query, PlanState *planState); #endif /* CITUS_CLAUSES_H */ diff --git a/src/test/regress/expected/multi_prepare_plsql.out b/src/test/regress/expected/multi_prepare_plsql.out index 8caf592cf..5774a3282 100644 --- a/src/test/regress/expected/multi_prepare_plsql.out +++ b/src/test/regress/expected/multi_prepare_plsql.out @@ -1106,6 +1106,86 @@ BEGIN END; $$; DROP TABLE execute_parameter_test; +-- check whether we can handle parameters + default +CREATE TABLE func_parameter_test ( + key text NOT NULL, + seq int4 NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + PRIMARY KEY (key, seq) +); +SELECT create_distributed_table('func_parameter_test', 'key'); + create_distributed_table +-------------------------- + +(1 row) + +CREATE OR REPLACE FUNCTION insert_with_max(pkey text) RETURNS VOID AS +$BODY$ + DECLARE + max_seq int4; + BEGIN + SELECT MAX(seq) INTO max_seq + FROM func_parameter_test + WHERE func_parameter_test.key = pkey; + + IF max_seq IS NULL THEN + max_seq := 0; + END IF; + + INSERT INTO func_parameter_test(key, seq) VALUES (pkey, max_seq + 1); + END; +$BODY$ +LANGUAGE plpgsql; +SELECT insert_with_max('key'); + insert_with_max +----------------- + +(1 row) + +SELECT insert_with_max('key'); + insert_with_max +----------------- + +(1 row) + +SELECT insert_with_max('key'); + insert_with_max +----------------- + +(1 row) + +SELECT insert_with_max('key'); + insert_with_max +----------------- + +(1 row) + +SELECT insert_with_max('key'); + insert_with_max +----------------- + +(1 row) + +SELECT insert_with_max('key'); + insert_with_max +----------------- + +(1 row) + +SELECT key, seq FROM func_parameter_test ORDER BY seq; + key | seq +-----+----- + key | 1 + key | 2 + key | 3 + key | 4 + key | 5 + key | 6 +(6 rows) + +DROP FUNCTION insert_with_max(text); +DROP TABLE func_parameter_test; -- clean-up functions DROP FUNCTION plpgsql_test_1(); DROP FUNCTION plpgsql_test_2(); diff --git a/src/test/regress/expected/multi_prepare_sql.out b/src/test/regress/expected/multi_prepare_sql.out index 5c5339d2e..c5510356f 100644 --- a/src/test/regress/expected/multi_prepare_sql.out +++ b/src/test/regress/expected/multi_prepare_sql.out @@ -851,6 +851,62 @@ SELECT * FROM prepare_table ORDER BY key, value; 6 | (8 rows) +-- Testing parameters + function evaluation +CREATE TABLE prepare_func_table ( + key text, + value1 int, + value2 text, + value3 timestamptz DEFAULT now() +); +SELECT create_distributed_table('prepare_func_table', 'key'); + create_distributed_table +-------------------------- + +(1 row) + +-- test function evaluation with parameters in an expression +PREPARE prepared_function_evaluation_insert(int) AS + INSERT INTO prepare_func_table (key, value1) VALUES ($1+1, 0*random()); +-- execute 6 times to trigger prepared statement usage +EXECUTE prepared_function_evaluation_insert(1); +EXECUTE prepared_function_evaluation_insert(2); +EXECUTE prepared_function_evaluation_insert(3); +EXECUTE prepared_function_evaluation_insert(4); +EXECUTE prepared_function_evaluation_insert(5); +EXECUTE prepared_function_evaluation_insert(6); +SELECT key, value1 FROM prepare_func_table ORDER BY key; + key | value1 +-----+-------- + 2 | 0 + 3 | 0 + 4 | 0 + 5 | 0 + 6 | 0 + 7 | 0 +(6 rows) + +TRUNCATE prepare_func_table; +-- make it a bit harder: parameter wrapped in a function call +PREPARE wrapped_parameter_evaluation(text,text[]) AS + INSERT INTO prepare_func_table (key,value2) VALUES ($1,array_to_string($2,'')); +EXECUTE wrapped_parameter_evaluation('key', ARRAY['value']); +EXECUTE wrapped_parameter_evaluation('key', ARRAY['value']); +EXECUTE wrapped_parameter_evaluation('key', ARRAY['value']); +EXECUTE wrapped_parameter_evaluation('key', ARRAY['value']); +EXECUTE wrapped_parameter_evaluation('key', ARRAY['value']); +EXECUTE wrapped_parameter_evaluation('key', ARRAY['value']); +SELECT key, value2 FROM prepare_func_table; + key | value2 +-----+-------- + key | value + key | value + key | value + key | value + key | value + key | value +(6 rows) + +DROP TABLE prepare_func_table; -- verify placement state updates invalidate shard state -- -- We use a immutable function to check for that. The planner will diff --git a/src/test/regress/sql/multi_prepare_plsql.sql b/src/test/regress/sql/multi_prepare_plsql.sql index 5018c2cb8..769bf2cdd 100644 --- a/src/test/regress/sql/multi_prepare_plsql.sql +++ b/src/test/regress/sql/multi_prepare_plsql.sql @@ -497,6 +497,47 @@ END; $$; DROP TABLE execute_parameter_test; +-- check whether we can handle parameters + default +CREATE TABLE func_parameter_test ( + key text NOT NULL, + seq int4 NOT NULL, + created_at timestamptz NOT NULL DEFAULT now(), + updated_at timestamptz NOT NULL DEFAULT now(), + PRIMARY KEY (key, seq) + +); +SELECT create_distributed_table('func_parameter_test', 'key'); + +CREATE OR REPLACE FUNCTION insert_with_max(pkey text) RETURNS VOID AS +$BODY$ + DECLARE + max_seq int4; + BEGIN + SELECT MAX(seq) INTO max_seq + FROM func_parameter_test + WHERE func_parameter_test.key = pkey; + + IF max_seq IS NULL THEN + max_seq := 0; + END IF; + + INSERT INTO func_parameter_test(key, seq) VALUES (pkey, max_seq + 1); + END; +$BODY$ +LANGUAGE plpgsql; + +SELECT insert_with_max('key'); +SELECT insert_with_max('key'); +SELECT insert_with_max('key'); +SELECT insert_with_max('key'); +SELECT insert_with_max('key'); +SELECT insert_with_max('key'); + +SELECT key, seq FROM func_parameter_test ORDER BY seq; + +DROP FUNCTION insert_with_max(text); +DROP TABLE func_parameter_test; + -- clean-up functions DROP FUNCTION plpgsql_test_1(); DROP FUNCTION plpgsql_test_2(); diff --git a/src/test/regress/sql/multi_prepare_sql.sql b/src/test/regress/sql/multi_prepare_sql.sql index d35c5d936..cfd50ba71 100644 --- a/src/test/regress/sql/multi_prepare_sql.sql +++ b/src/test/regress/sql/multi_prepare_sql.sql @@ -438,6 +438,44 @@ EXECUTE prepared_non_partition_parameter_delete(62); -- check after deletes SELECT * FROM prepare_table ORDER BY key, value; +-- Testing parameters + function evaluation +CREATE TABLE prepare_func_table ( + key text, + value1 int, + value2 text, + value3 timestamptz DEFAULT now() +); +SELECT create_distributed_table('prepare_func_table', 'key'); + +-- test function evaluation with parameters in an expression +PREPARE prepared_function_evaluation_insert(int) AS + INSERT INTO prepare_func_table (key, value1) VALUES ($1+1, 0*random()); + +-- execute 6 times to trigger prepared statement usage +EXECUTE prepared_function_evaluation_insert(1); +EXECUTE prepared_function_evaluation_insert(2); +EXECUTE prepared_function_evaluation_insert(3); +EXECUTE prepared_function_evaluation_insert(4); +EXECUTE prepared_function_evaluation_insert(5); +EXECUTE prepared_function_evaluation_insert(6); + +SELECT key, value1 FROM prepare_func_table ORDER BY key; +TRUNCATE prepare_func_table; + +-- make it a bit harder: parameter wrapped in a function call +PREPARE wrapped_parameter_evaluation(text,text[]) AS + INSERT INTO prepare_func_table (key,value2) VALUES ($1,array_to_string($2,'')); + +EXECUTE wrapped_parameter_evaluation('key', ARRAY['value']); +EXECUTE wrapped_parameter_evaluation('key', ARRAY['value']); +EXECUTE wrapped_parameter_evaluation('key', ARRAY['value']); +EXECUTE wrapped_parameter_evaluation('key', ARRAY['value']); +EXECUTE wrapped_parameter_evaluation('key', ARRAY['value']); +EXECUTE wrapped_parameter_evaluation('key', ARRAY['value']); + +SELECT key, value2 FROM prepare_func_table; + +DROP TABLE prepare_func_table; -- verify placement state updates invalidate shard state --