diff --git a/src/backend/distributed/utils/aggregate_utils.c b/src/backend/distributed/utils/aggregate_utils.c index 95c42e18d..5daccd2ca 100644 --- a/src/backend/distributed/utils/aggregate_utils.c +++ b/src/backend/distributed/utils/aggregate_utils.c @@ -21,6 +21,7 @@ #include "catalog/pg_proc.h" #include "catalog/pg_type.h" #include "distributed/version_compat.h" +#include "nodes/nodeFuncs.h" #include "utils/acl.h" #include "utils/builtins.h" #include "utils/datum.h" @@ -62,6 +63,9 @@ static StypeBox * TryCreateStypeBoxFromFcinfoAggref(FunctionCallInfo fcinfo); static void HandleTransition(StypeBox *box, FunctionCallInfo fcinfo, FunctionCallInfo innerFcinfo); static void HandleStrictUninit(StypeBox *box, FunctionCallInfo fcinfo, Datum value); +static bool TypecheckWorkerPartialAggArgType(FunctionCallInfo fcinfo, StypeBox *box); +static bool TypecheckCoordCombineAggReturnType(FunctionCallInfo fcinfo, Oid ffunc, + StypeBox *box); /* * GetAggregateForm loads corresponding tuple & Form_pg_aggregate for oid @@ -346,6 +350,12 @@ worker_partial_agg_sfunc(PG_FUNCTION_ARGS) { box = pallocInAggContext(fcinfo, sizeof(StypeBox)); box->agg = PG_GETARG_OID(1); + + if (!TypecheckWorkerPartialAggArgType(fcinfo, box)) + { + ereport(ERROR, (errmsg( + "worker_partial_agg_sfunc could not confirm type correctness"))); + } } else { @@ -617,6 +627,12 @@ coord_combine_agg_ffunc(PG_FUNCTION_ARGS) bool fextra = aggform->aggfinalextra; ReleaseSysCache(aggtuple); + if (!TypecheckCoordCombineAggReturnType(fcinfo, ffunc, box)) + { + ereport(ERROR, (errmsg( + "coord_combine_agg_ffunc could not confirm type correctness"))); + } + if (ffunc == InvalidOid) { if (box->valueNull) @@ -656,3 +672,74 @@ coord_combine_agg_ffunc(PG_FUNCTION_ARGS) fcinfo->isnull = innerFcinfo->isnull; return result; } + + +/* + * TypecheckWorkerPartialAggArgType returns whether the arguments being passed to + * worker_partial_agg match the arguments expected by the aggregate being distributed. + */ +static bool +TypecheckWorkerPartialAggArgType(FunctionCallInfo fcinfo, StypeBox *box) +{ + Aggref *aggref = AggGetAggref(fcinfo); + if (aggref == NULL) + { + return false; + } + + Assert(list_length(aggref->args) == 2); + TargetEntry *aggarg = list_nth(aggref->args, 1); + + bool argtypesNull; + HeapTuple proctuple = SearchSysCache1(PROCOID, ObjectIdGetDatum(box->agg)); + if (!HeapTupleIsValid(proctuple)) + { + return false; + } + + Datum argtypes = SysCacheGetAttr(PROCOID, proctuple, + Anum_pg_proc_proargtypes, + &argtypesNull); + Assert(!argtypesNull); + ReleaseSysCache(proctuple); + + if (ARR_NDIM(DatumGetArrayTypeP(argtypes)) != 1 || + ARR_DIMS(DatumGetArrayTypeP(argtypes))[0] != 1) + { + elog(ERROR, "worker_partial_agg_sfunc cannot type check aggregates " + "taking anything other than 1 argument"); + } + + int arrayIndex = 0; + Datum argtype = array_get_element(argtypes, + 1, &arrayIndex, -1, sizeof(Oid), true, 'i', + &argtypesNull); + Assert(!argtypesNull); + + return aggarg != NULL && exprType((Node *) aggarg->expr) == DatumGetObjectId(argtype); +} + + +/* + * TypecheckCoordCombineAggReturnType returns whether the return type of the aggregate + * being distributed by coord_combine_agg matches the null constant used to inform postgres + * what the aggregate's expected return type is. + */ +static bool +TypecheckCoordCombineAggReturnType(FunctionCallInfo fcinfo, Oid ffunc, StypeBox *box) +{ + Aggref *aggref = AggGetAggref(fcinfo); + if (aggref == NULL) + { + return false; + } + + Oid finalType = ffunc == InvalidOid ? + box->transtype : get_func_rettype(ffunc); + + Assert(list_length(aggref->args) == 3); + TargetEntry *nulltag = list_nth(aggref->args, 2); + + return nulltag != NULL && IsA(nulltag->expr, Const) && + ((Const *) nulltag->expr)->consttype == finalType; +} diff --git a/src/test/regress/expected/aggregate_support.out b/src/test/regress/expected/aggregate_support.out index bf4ddb41d..e8cc08593 100644 --- a/src/test/regress/expected/aggregate_support.out +++ b/src/test/regress/expected/aggregate_support.out @@ -432,5 +432,36 @@ select key, count(distinct aggdata) from aggdata group by key order by 1, 2; ERROR: type "aggregate_support.aggdata" does not exist CONTEXT: while executing command on localhost:xxxxx +-- Test https://github.com/citusdata/citus/issues/3328 +create table nulltable(id int); +insert into nulltable values (0); +-- These cases are not type correct +select pg_catalog.worker_partial_agg('string_agg(text,text)'::regprocedure, id) from nulltable; +ERROR: worker_partial_agg_sfunc cannot type check aggregates taking anything other than 1 argument +select pg_catalog.worker_partial_agg('sum(int8)'::regprocedure, id) from nulltable; +ERROR: worker_partial_agg_sfunc could not confirm type correctness +select pg_catalog.coord_combine_agg('sum(float8)'::regprocedure, id::text::cstring, null::text) from nulltable; +ERROR: coord_combine_agg_ffunc could not confirm type correctness +select pg_catalog.coord_combine_agg('avg(float8)'::regprocedure, ARRAY[id,id,id]::text::cstring, null::text) from nulltable; +ERROR: coord_combine_agg_ffunc could not confirm type correctness +-- These cases are type correct +select pg_catalog.worker_partial_agg('sum(int)'::regprocedure, id) from nulltable; + worker_partial_agg +--------------------------------------------------------------------- + 0 +(1 row) + +select pg_catalog.coord_combine_agg('sum(float8)'::regprocedure, id::text::cstring, null::float8) from nulltable; + coord_combine_agg +--------------------------------------------------------------------- + 0 +(1 row) + +select pg_catalog.coord_combine_agg('avg(float8)'::regprocedure, ARRAY[id,id,id]::text::cstring, null::float8) from nulltable; + coord_combine_agg +--------------------------------------------------------------------- + +(1 row) + set client_min_messages to error; drop schema aggregate_support cascade; diff --git a/src/test/regress/sql/aggregate_support.sql b/src/test/regress/sql/aggregate_support.sql index 016f9979b..faec5f612 100644 --- a/src/test/regress/sql/aggregate_support.sql +++ b/src/test/regress/sql/aggregate_support.sql @@ -204,5 +204,18 @@ RESET citus.task_executor_type; select key, count(distinct aggdata) from aggdata group by key order by 1, 2; +-- Test https://github.com/citusdata/citus/issues/3328 +create table nulltable(id int); +insert into nulltable values (0); +-- These cases are not type correct +select pg_catalog.worker_partial_agg('string_agg(text,text)'::regprocedure, id) from nulltable; +select pg_catalog.worker_partial_agg('sum(int8)'::regprocedure, id) from nulltable; +select pg_catalog.coord_combine_agg('sum(float8)'::regprocedure, id::text::cstring, null::text) from nulltable; +select pg_catalog.coord_combine_agg('avg(float8)'::regprocedure, ARRAY[id,id,id]::text::cstring, null::text) from nulltable; +-- These cases are type correct +select pg_catalog.worker_partial_agg('sum(int)'::regprocedure, id) from nulltable; +select pg_catalog.coord_combine_agg('sum(float8)'::regprocedure, id::text::cstring, null::float8) from nulltable; +select pg_catalog.coord_combine_agg('avg(float8)'::regprocedure, ARRAY[id,id,id]::text::cstring, null::float8) from nulltable; + set client_min_messages to error; drop schema aggregate_support cascade;