mirror of https://github.com/citusdata/citus.git
Merge pull request #3600 from citusdata/typecheck-agg-combine
Add runtime type checking to AGGREGATE_CUSTOM_COMBINE helper functionspull/3561/head^2
commit
7eb678f0f7
|
@ -21,6 +21,7 @@
|
|||
#include "catalog/pg_proc.h"
|
||||
#include "catalog/pg_type.h"
|
||||
#include "distributed/version_compat.h"
|
||||
#include "nodes/nodeFuncs.h"
|
||||
#include "utils/acl.h"
|
||||
#include "utils/builtins.h"
|
||||
#include "utils/datum.h"
|
||||
|
@ -62,6 +63,9 @@ static StypeBox * TryCreateStypeBoxFromFcinfoAggref(FunctionCallInfo fcinfo);
|
|||
static void HandleTransition(StypeBox *box, FunctionCallInfo fcinfo,
|
||||
FunctionCallInfo innerFcinfo);
|
||||
static void HandleStrictUninit(StypeBox *box, FunctionCallInfo fcinfo, Datum value);
|
||||
static bool TypecheckWorkerPartialAggArgType(FunctionCallInfo fcinfo, StypeBox *box);
|
||||
static bool TypecheckCoordCombineAggReturnType(FunctionCallInfo fcinfo, Oid ffunc,
|
||||
StypeBox *box);
|
||||
|
||||
/*
|
||||
* GetAggregateForm loads corresponding tuple & Form_pg_aggregate for oid
|
||||
|
@ -346,6 +350,12 @@ worker_partial_agg_sfunc(PG_FUNCTION_ARGS)
|
|||
{
|
||||
box = pallocInAggContext(fcinfo, sizeof(StypeBox));
|
||||
box->agg = PG_GETARG_OID(1);
|
||||
|
||||
if (!TypecheckWorkerPartialAggArgType(fcinfo, box))
|
||||
{
|
||||
ereport(ERROR, (errmsg(
|
||||
"worker_partial_agg_sfunc could not confirm type correctness")));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
@ -617,6 +627,12 @@ coord_combine_agg_ffunc(PG_FUNCTION_ARGS)
|
|||
bool fextra = aggform->aggfinalextra;
|
||||
ReleaseSysCache(aggtuple);
|
||||
|
||||
if (!TypecheckCoordCombineAggReturnType(fcinfo, ffunc, box))
|
||||
{
|
||||
ereport(ERROR, (errmsg(
|
||||
"coord_combine_agg_ffunc could not confirm type correctness")));
|
||||
}
|
||||
|
||||
if (ffunc == InvalidOid)
|
||||
{
|
||||
if (box->valueNull)
|
||||
|
@ -656,3 +672,74 @@ coord_combine_agg_ffunc(PG_FUNCTION_ARGS)
|
|||
fcinfo->isnull = innerFcinfo->isnull;
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
* TypecheckWorkerPartialAggArgType returns whether the arguments being passed to
|
||||
* worker_partial_agg match the arguments expected by the aggregate being distributed.
|
||||
*/
|
||||
static bool
|
||||
TypecheckWorkerPartialAggArgType(FunctionCallInfo fcinfo, StypeBox *box)
|
||||
{
|
||||
Aggref *aggref = AggGetAggref(fcinfo);
|
||||
if (aggref == NULL)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
Assert(list_length(aggref->args) == 2);
|
||||
TargetEntry *aggarg = list_nth(aggref->args, 1);
|
||||
|
||||
bool argtypesNull;
|
||||
HeapTuple proctuple = SearchSysCache1(PROCOID, ObjectIdGetDatum(box->agg));
|
||||
if (!HeapTupleIsValid(proctuple))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
Datum argtypes = SysCacheGetAttr(PROCOID, proctuple,
|
||||
Anum_pg_proc_proargtypes,
|
||||
&argtypesNull);
|
||||
Assert(!argtypesNull);
|
||||
ReleaseSysCache(proctuple);
|
||||
|
||||
if (ARR_NDIM(DatumGetArrayTypeP(argtypes)) != 1 ||
|
||||
ARR_DIMS(DatumGetArrayTypeP(argtypes))[0] != 1)
|
||||
{
|
||||
elog(ERROR, "worker_partial_agg_sfunc cannot type check aggregates "
|
||||
"taking anything other than 1 argument");
|
||||
}
|
||||
|
||||
int arrayIndex = 0;
|
||||
Datum argtype = array_get_element(argtypes,
|
||||
1, &arrayIndex, -1, sizeof(Oid), true, 'i',
|
||||
&argtypesNull);
|
||||
Assert(!argtypesNull);
|
||||
|
||||
return aggarg != NULL && exprType((Node *) aggarg->expr) == DatumGetObjectId(argtype);
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
* TypecheckCoordCombineAggReturnType returns whether the return type of the aggregate
|
||||
* being distributed by coord_combine_agg matches the null constant used to inform postgres
|
||||
* what the aggregate's expected return type is.
|
||||
*/
|
||||
static bool
|
||||
TypecheckCoordCombineAggReturnType(FunctionCallInfo fcinfo, Oid ffunc, StypeBox *box)
|
||||
{
|
||||
Aggref *aggref = AggGetAggref(fcinfo);
|
||||
if (aggref == NULL)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
Oid finalType = ffunc == InvalidOid ?
|
||||
box->transtype : get_func_rettype(ffunc);
|
||||
|
||||
Assert(list_length(aggref->args) == 3);
|
||||
TargetEntry *nulltag = list_nth(aggref->args, 2);
|
||||
|
||||
return nulltag != NULL && IsA(nulltag->expr, Const) &&
|
||||
((Const *) nulltag->expr)->consttype == finalType;
|
||||
}
|
||||
|
|
|
@ -432,5 +432,36 @@ select key, count(distinct aggdata)
|
|||
from aggdata group by key order by 1, 2;
|
||||
ERROR: type "aggregate_support.aggdata" does not exist
|
||||
CONTEXT: while executing command on localhost:xxxxx
|
||||
-- Test https://github.com/citusdata/citus/issues/3328
|
||||
create table nulltable(id int);
|
||||
insert into nulltable values (0);
|
||||
-- These cases are not type correct
|
||||
select pg_catalog.worker_partial_agg('string_agg(text,text)'::regprocedure, id) from nulltable;
|
||||
ERROR: worker_partial_agg_sfunc cannot type check aggregates taking anything other than 1 argument
|
||||
select pg_catalog.worker_partial_agg('sum(int8)'::regprocedure, id) from nulltable;
|
||||
ERROR: worker_partial_agg_sfunc could not confirm type correctness
|
||||
select pg_catalog.coord_combine_agg('sum(float8)'::regprocedure, id::text::cstring, null::text) from nulltable;
|
||||
ERROR: coord_combine_agg_ffunc could not confirm type correctness
|
||||
select pg_catalog.coord_combine_agg('avg(float8)'::regprocedure, ARRAY[id,id,id]::text::cstring, null::text) from nulltable;
|
||||
ERROR: coord_combine_agg_ffunc could not confirm type correctness
|
||||
-- These cases are type correct
|
||||
select pg_catalog.worker_partial_agg('sum(int)'::regprocedure, id) from nulltable;
|
||||
worker_partial_agg
|
||||
---------------------------------------------------------------------
|
||||
0
|
||||
(1 row)
|
||||
|
||||
select pg_catalog.coord_combine_agg('sum(float8)'::regprocedure, id::text::cstring, null::float8) from nulltable;
|
||||
coord_combine_agg
|
||||
---------------------------------------------------------------------
|
||||
0
|
||||
(1 row)
|
||||
|
||||
select pg_catalog.coord_combine_agg('avg(float8)'::regprocedure, ARRAY[id,id,id]::text::cstring, null::float8) from nulltable;
|
||||
coord_combine_agg
|
||||
---------------------------------------------------------------------
|
||||
|
||||
(1 row)
|
||||
|
||||
set client_min_messages to error;
|
||||
drop schema aggregate_support cascade;
|
||||
|
|
|
@ -204,5 +204,18 @@ RESET citus.task_executor_type;
|
|||
select key, count(distinct aggdata)
|
||||
from aggdata group by key order by 1, 2;
|
||||
|
||||
-- Test https://github.com/citusdata/citus/issues/3328
|
||||
create table nulltable(id int);
|
||||
insert into nulltable values (0);
|
||||
-- These cases are not type correct
|
||||
select pg_catalog.worker_partial_agg('string_agg(text,text)'::regprocedure, id) from nulltable;
|
||||
select pg_catalog.worker_partial_agg('sum(int8)'::regprocedure, id) from nulltable;
|
||||
select pg_catalog.coord_combine_agg('sum(float8)'::regprocedure, id::text::cstring, null::text) from nulltable;
|
||||
select pg_catalog.coord_combine_agg('avg(float8)'::regprocedure, ARRAY[id,id,id]::text::cstring, null::text) from nulltable;
|
||||
-- These cases are type correct
|
||||
select pg_catalog.worker_partial_agg('sum(int)'::regprocedure, id) from nulltable;
|
||||
select pg_catalog.coord_combine_agg('sum(float8)'::regprocedure, id::text::cstring, null::float8) from nulltable;
|
||||
select pg_catalog.coord_combine_agg('avg(float8)'::regprocedure, ARRAY[id,id,id]::text::cstring, null::float8) from nulltable;
|
||||
|
||||
set client_min_messages to error;
|
||||
drop schema aggregate_support cascade;
|
||||
|
|
Loading…
Reference in New Issue