Merge pull request #3600 from citusdata/typecheck-agg-combine

Add runtime type checking to AGGREGATE_CUSTOM_COMBINE helper functions
pull/3561/head^2
Philip Dubé 2020-03-11 17:31:17 +00:00 committed by GitHub
commit 7eb678f0f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 131 additions and 0 deletions

View File

@ -21,6 +21,7 @@
#include "catalog/pg_proc.h"
#include "catalog/pg_type.h"
#include "distributed/version_compat.h"
#include "nodes/nodeFuncs.h"
#include "utils/acl.h"
#include "utils/builtins.h"
#include "utils/datum.h"
@ -62,6 +63,9 @@ static StypeBox * TryCreateStypeBoxFromFcinfoAggref(FunctionCallInfo fcinfo);
static void HandleTransition(StypeBox *box, FunctionCallInfo fcinfo,
FunctionCallInfo innerFcinfo);
static void HandleStrictUninit(StypeBox *box, FunctionCallInfo fcinfo, Datum value);
static bool TypecheckWorkerPartialAggArgType(FunctionCallInfo fcinfo, StypeBox *box);
static bool TypecheckCoordCombineAggReturnType(FunctionCallInfo fcinfo, Oid ffunc,
StypeBox *box);
/*
* GetAggregateForm loads corresponding tuple & Form_pg_aggregate for oid
@ -346,6 +350,12 @@ worker_partial_agg_sfunc(PG_FUNCTION_ARGS)
{
box = pallocInAggContext(fcinfo, sizeof(StypeBox));
box->agg = PG_GETARG_OID(1);
if (!TypecheckWorkerPartialAggArgType(fcinfo, box))
{
ereport(ERROR, (errmsg(
"worker_partial_agg_sfunc could not confirm type correctness")));
}
}
else
{
@ -617,6 +627,12 @@ coord_combine_agg_ffunc(PG_FUNCTION_ARGS)
bool fextra = aggform->aggfinalextra;
ReleaseSysCache(aggtuple);
if (!TypecheckCoordCombineAggReturnType(fcinfo, ffunc, box))
{
ereport(ERROR, (errmsg(
"coord_combine_agg_ffunc could not confirm type correctness")));
}
if (ffunc == InvalidOid)
{
if (box->valueNull)
@ -656,3 +672,74 @@ coord_combine_agg_ffunc(PG_FUNCTION_ARGS)
fcinfo->isnull = innerFcinfo->isnull;
return result;
}
/*
 * TypecheckWorkerPartialAggArgType returns whether the argument being passed to
 * worker_partial_agg matches the argument type expected by the aggregate being
 * distributed.
 *
 * box->agg is looked up in pg_proc, and the single entry of its proargtypes is
 * compared against the expression type of the second argument of the calling
 * Aggref.
 *
 * Returns false when no Aggref is available for this call, when no pg_proc
 * tuple exists for box->agg, or when the types differ. Raises an ERROR when the
 * aggregate does not take exactly one argument, since only single-argument
 * aggregates can be checked here.
 */
static bool
TypecheckWorkerPartialAggArgType(FunctionCallInfo fcinfo, StypeBox *box)
{
	Aggref *aggref = AggGetAggref(fcinfo);
	if (aggref == NULL)
	{
		return false;
	}

	Assert(list_length(aggref->args) == 2);
	TargetEntry *aggarg = list_nth(aggref->args, 1);

	bool argtypesNull;
	HeapTuple proctuple = SearchSysCache1(PROCOID, ObjectIdGetDatum(box->agg));
	if (!HeapTupleIsValid(proctuple))
	{
		return false;
	}

	/*
	 * argtypes is a by-reference datum pointing into the cached tuple, so the
	 * tuple must stay pinned until everything needed has been read out of it.
	 */
	Datum argtypes = SysCacheGetAttr(PROCOID, proctuple,
									 Anum_pg_proc_proargtypes,
									 &argtypesNull);
	Assert(!argtypesNull);

	if (ARR_NDIM(DatumGetArrayTypeP(argtypes)) != 1 ||
		ARR_DIMS(DatumGetArrayTypeP(argtypes))[0] != 1)
	{
		/* drop the pin ourselves rather than leaving it to error cleanup */
		ReleaseSysCache(proctuple);
		elog(ERROR, "worker_partial_agg_sfunc cannot type check aggregates "
					"taking anything other than 1 argument");
	}

	int arrayIndex = 0;
	Datum argtype = array_get_element(argtypes,
									  1, &arrayIndex, -1, sizeof(Oid), true, 'i',
									  &argtypesNull);
	Assert(!argtypesNull);

	/* argtype holds the Oid by value, so the tuple is no longer needed */
	ReleaseSysCache(proctuple);

	return aggarg != NULL && exprType((Node *) aggarg->expr) == DatumGetObjectId(argtype);
}
/*
 * TypecheckCoordCombineAggReturnType checks that the third argument handed to
 * coord_combine_agg, a null constant whose only job is to tell postgres the
 * aggregate's result type, really has the type the distributed aggregate
 * produces.
 *
 * The produced type is the transition type stored in box when the aggregate
 * has no final function (ffunc is InvalidOid), and the return type of ffunc
 * otherwise. Returns false when no Aggref is available for this call, when the
 * third argument is missing or is not a Const, or when the types differ.
 */
static bool
TypecheckCoordCombineAggReturnType(FunctionCallInfo fcinfo, Oid ffunc, StypeBox *box)
{
	Aggref *aggref = AggGetAggref(fcinfo);
	if (aggref == NULL)
	{
		return false;
	}

	/* without a final function the transition value is the result */
	Oid finalType;
	if (ffunc == InvalidOid)
	{
		finalType = box->transtype;
	}
	else
	{
		finalType = get_func_rettype(ffunc);
	}

	Assert(list_length(aggref->args) == 3);
	TargetEntry *nulltag = list_nth(aggref->args, 2);

	if (nulltag == NULL)
	{
		return false;
	}

	if (!IsA(nulltag->expr, Const))
	{
		return false;
	}

	return ((Const *) nulltag->expr)->consttype == finalType;
}

View File

@ -432,5 +432,36 @@ select key, count(distinct aggdata)
from aggdata group by key order by 1, 2;
ERROR: type "aggregate_support.aggdata" does not exist
CONTEXT: while executing command on localhost:xxxxx
-- Test https://github.com/citusdata/citus/issues/3328
create table nulltable(id int);
insert into nulltable values (0);
-- These cases are not type correct
select pg_catalog.worker_partial_agg('string_agg(text,text)'::regprocedure, id) from nulltable;
ERROR: worker_partial_agg_sfunc cannot type check aggregates taking anything other than 1 argument
select pg_catalog.worker_partial_agg('sum(int8)'::regprocedure, id) from nulltable;
ERROR: worker_partial_agg_sfunc could not confirm type correctness
select pg_catalog.coord_combine_agg('sum(float8)'::regprocedure, id::text::cstring, null::text) from nulltable;
ERROR: coord_combine_agg_ffunc could not confirm type correctness
select pg_catalog.coord_combine_agg('avg(float8)'::regprocedure, ARRAY[id,id,id]::text::cstring, null::text) from nulltable;
ERROR: coord_combine_agg_ffunc could not confirm type correctness
-- These cases are type correct
select pg_catalog.worker_partial_agg('sum(int)'::regprocedure, id) from nulltable;
worker_partial_agg
---------------------------------------------------------------------
0
(1 row)
select pg_catalog.coord_combine_agg('sum(float8)'::regprocedure, id::text::cstring, null::float8) from nulltable;
coord_combine_agg
---------------------------------------------------------------------
0
(1 row)
select pg_catalog.coord_combine_agg('avg(float8)'::regprocedure, ARRAY[id,id,id]::text::cstring, null::float8) from nulltable;
coord_combine_agg
---------------------------------------------------------------------
(1 row)
set client_min_messages to error;
drop schema aggregate_support cascade;

View File

@ -204,5 +204,18 @@ RESET citus.task_executor_type;
select key, count(distinct aggdata)
from aggdata group by key order by 1, 2;
-- Test https://github.com/citusdata/citus/issues/3328
create table nulltable(id int);
insert into nulltable values (0);
-- These cases are not type correct
select pg_catalog.worker_partial_agg('string_agg(text,text)'::regprocedure, id) from nulltable;
select pg_catalog.worker_partial_agg('sum(int8)'::regprocedure, id) from nulltable;
select pg_catalog.coord_combine_agg('sum(float8)'::regprocedure, id::text::cstring, null::text) from nulltable;
select pg_catalog.coord_combine_agg('avg(float8)'::regprocedure, ARRAY[id,id,id]::text::cstring, null::text) from nulltable;
-- These cases are type correct
select pg_catalog.worker_partial_agg('sum(int)'::regprocedure, id) from nulltable;
select pg_catalog.coord_combine_agg('sum(float8)'::regprocedure, id::text::cstring, null::float8) from nulltable;
select pg_catalog.coord_combine_agg('avg(float8)'::regprocedure, ARRAY[id,id,id]::text::cstring, null::float8) from nulltable;
set client_min_messages to error;
drop schema aggregate_support cascade;