diff --git a/src/backend/distributed/planner/distributed_planner.c b/src/backend/distributed/planner/distributed_planner.c index 65278d1ea..4f7612f8f 100644 --- a/src/backend/distributed/planner/distributed_planner.c +++ b/src/backend/distributed/planner/distributed_planner.c @@ -702,6 +702,7 @@ DissuadePlannerFromUsingPlan(PlannedStmt *plan) * Arbitrarily high cost, but low enough that it can be added up * without overflowing by choose_custom_plan(). */ + Assert(plan != NULL); plan->planTree->total_cost = FLT_MAX / 100000000; } diff --git a/src/backend/distributed/planner/function_call_delegation.c b/src/backend/distributed/planner/function_call_delegation.c index 2f8da29c0..ce9c818d7 100644 --- a/src/backend/distributed/planner/function_call_delegation.c +++ b/src/backend/distributed/planner/function_call_delegation.c @@ -525,8 +525,16 @@ ShardPlacementForFunctionColocatedWithDistTable(DistObjectCacheEntry *procedure, if (partitionParam->paramkind == PARAM_EXTERN) { - /* Don't log a message, we should end up here again without a parameter */ - DissuadePlannerFromUsingPlan(plan); + /* + * Don't log a message, we should end up here again without a + * parameter. + * Note that "plan" can be null, for example when a CALL statement + * is prepared. + */ + if (plan) + { + DissuadePlannerFromUsingPlan(plan); + } return NULL; } } diff --git a/src/test/regress/citus_tests/common.py b/src/test/regress/citus_tests/common.py index 53c9c7944..40c727189 100644 --- a/src/test/regress/citus_tests/common.py +++ b/src/test/regress/citus_tests/common.py @@ -581,6 +581,14 @@ class QueryRunner(ABC): with self.cur(**kwargs) as cur: cur.execute(query, params=params) + def sql_prepared(self, query, params=None, **kwargs): + """Run an SQL query, with prepare=True + + This opens a new connection and closes it once the query is done + """ + with self.cur(**kwargs) as cur: + cur.execute(query, params=params, prepare=True) + def sql_row(self, query, params=None, allow_empty_result=False, **kwargs): """Run an SQL query that returns a single row and returns this row diff --git a/src/test/regress/citus_tests/test/test_prepared_statements.py b/src/test/regress/citus_tests/test/test_prepared_statements.py new file mode 100644 index 000000000..761ecc30c --- /dev/null +++ b/src/test/regress/citus_tests/test/test_prepared_statements.py @@ -0,0 +1,30 @@ +def test_call_param(cluster): + # create a distributed table and an associated distributed procedure + # to ensure parameterized CALL succeed, even when the param is the + # distribution key. + coord = cluster.coordinator + coord.sql("CREATE TABLE test(i int)") + coord.sql( + """ + CREATE PROCEDURE p(_i INT) LANGUAGE plpgsql AS $$ + BEGIN + INSERT INTO test(i) VALUES (_i); + END; $$ + """ + ) + sql = "CALL p(%s)" + + # prepare/exec before distributing + coord.sql_prepared(sql, (1,)) + + coord.sql("SELECT create_distributed_table('test', 'i')") + coord.sql( + "SELECT create_distributed_function('p(int)', distribution_arg_name := '_i', colocate_with := 'test')" + ) + + # prepare/exec after distribution + coord.sql_prepared(sql, (2,)) + + sum_i = coord.sql_value("select sum(i) from test;") + + assert sum_i == 3