Refactor logging in utility hook and function call delegation; enhance test for parameterized CALL with multiple parameters

m3hm3t/dist_func_parameter
Mehmet Yilmaz 2025-07-21 07:58:37 +00:00
parent ecc67255ec
commit 9b90096823
3 changed files with 35 additions and 46 deletions

View File

@ -273,13 +273,6 @@ citus_ProcessUtility(PlannedStmt *pstmt,
{
CallStmt *callStmt = (CallStmt *) parsetree;
/* Log the entire parsetree as a string */
elog(DEBUG1, "Parse Tree: %s", nodeToString(parsetree));
/* Log other information as before */
elog(DEBUG1, "Processing CallStmt for procedure");
elog(DEBUG1, "Procedure context: %d", context);
/*
* If the procedure is distributed and we are using MX then we have the
* possibility of calling it on the worker. If the data is located on

View File

@ -514,32 +514,17 @@ ShardPlacementForFunctionColocatedWithDistTable(DistObjectCacheEntry *procedure,
CitusTableCacheEntry *cacheEntry,
PlannedStmt *plan)
{
/* Log distribution argument index and argument list size */
elog(DEBUG1, "Distribution Argument Index: %d, Argument List Length: %d",
procedure->distributionArgIndex, list_length(argumentList));
if (procedure->distributionArgIndex < 0 ||
procedure->distributionArgIndex >= list_length(argumentList))
{
/* Add more detailed log for invalid distribution argument index */
ereport(DEBUG1, (errmsg("Invalid distribution argument index: %d",
procedure->distributionArgIndex)));
ereport(DEBUG1, (errmsg("cannot push down invalid distribution_argument_index")));
return NULL;
}
/* Get the partition value node */
Node *partitionValueNode = (Node *) list_nth(argumentList, procedure->distributionArgIndex);
/* Log the partition value node before stripping implicit coercions */
elog(DEBUG1, "Partition Value Node before stripping: %s", nodeToString(partitionValueNode));
/* Strip implicit coercions */
Node *partitionValueNode = (Node *) list_nth(argumentList,
procedure->distributionArgIndex);
partitionValueNode = strip_implicit_coercions(partitionValueNode);
/* Log the partition value node after stripping implicit coercions */
elog(DEBUG1, "Partition Value Node after stripping: %s", nodeToString(partitionValueNode));
if (IsA(partitionValueNode, Param))
{
Param *partitionParam = (Param *) partitionValueNode;

View File

@ -33,36 +33,47 @@ def test_call_param(cluster):
def test_call_param2(cluster):
# Get the coordinator node from the Citus cluster
# Create a distributed table with two columns and an associated distributed procedure
# to ensure that a parameterized CALL succeeds, even when the first param is the
# distribution key but there are additional params.
coord = cluster.coordinator
coord.sql("DROP TABLE IF EXISTS t CASCADE")
coord.sql("CREATE TABLE t (p int, i int)")
coord.sql("SELECT create_distributed_table('t', 'p')")
# 1) create table with two columns
coord.sql("DROP TABLE IF EXISTS test CASCADE")
coord.sql("CREATE TABLE test(i int, j int)")
# 2) create a procedure taking both columns as inputs
coord.sql(
"""
CREATE PROCEDURE f(_p INT, _i INT) LANGUAGE plpgsql AS $$
CREATE PROCEDURE p(_i INT, _j INT)
LANGUAGE plpgsql AS $$
BEGIN
INSERT INTO t (p, i) VALUES (_p, _i);
PERFORM _i;
END; $$
INSERT INTO test(i, j) VALUES (_i, _j);
END;
$$;
"""
)
sql = "CALL p(2, %s)"
# 4) distribute table on column i and function on _i
coord.sql("SELECT create_distributed_table('test', 'i')")
coord.sql(
"SELECT create_distributed_function('f(int, int)', distribution_arg_name := '_p', colocate_with := 't')"
"""
SELECT create_distributed_function(
'p(int, int)',
distribution_arg_name := '_i',
colocate_with := 'test'
)
"""
)
sql_insert_and_call = "CALL f(1, 33);"
# sql_insert_and_call = "CALL f(%s, 1);"
# 5) prepare/exec after distribution
coord.sql_prepared(sql, (20,))
# cluster.coordinator.psql_debug()
# cluster.debug()
# 6) verify that both column values (i and j) were inserted
sum_i = coord.sql_value("SELECT sum(i) FROM test;")
sum_j = coord.sql_value("SELECT sum(j) FROM test;")
# After distributing the table, insert more data and call the procedure again
# coord.sql_prepared(sql_insert_and_call, (33,))
coord.sql(sql_insert_and_call)
# Step 6: Check the result
sum_i = coord.sql_value("SELECT i FROM t limit 1;")
assert sum_i == 33
assert sum_i == 2
assert sum_j == 20