Fix handling of empty intermediate results when distributing custom aggregates

pull/3316/head
Philip Dubé 2019-12-17 19:04:50 +00:00
parent bb6ba89708
commit e9bbdb8f31
3 changed files with 88 additions and 27 deletions

View File

@ -55,8 +55,10 @@ static HeapTuple GetProcForm(Oid oid, Form_pg_proc *form);
static HeapTuple GetTypeForm(Oid oid, Form_pg_type *form); static HeapTuple GetTypeForm(Oid oid, Form_pg_type *form);
static void * pallocInAggContext(FunctionCallInfo fcinfo, size_t size); static void * pallocInAggContext(FunctionCallInfo fcinfo, size_t size);
static void aclcheckAggregate(ObjectType objectType, Oid userOid, Oid funcOid); static void aclcheckAggregate(ObjectType objectType, Oid userOid, Oid funcOid);
static Datum GetAggInitVal(Datum textInitVal, Oid transtype);
static void InitializeStypeBox(FunctionCallInfo fcinfo, StypeBox *box, HeapTuple aggTuple, static void InitializeStypeBox(FunctionCallInfo fcinfo, StypeBox *box, HeapTuple aggTuple,
Oid transtype); Oid transtype);
static StypeBox * TryCreateStypeBoxFromFcinfoAggref(FunctionCallInfo fcinfo);
static void HandleTransition(StypeBox *box, FunctionCallInfo fcinfo, static void HandleTransition(StypeBox *box, FunctionCallInfo fcinfo,
FunctionCallInfo innerFcinfo); FunctionCallInfo innerFcinfo);
static void HandleStrictUninit(StypeBox *box, FunctionCallInfo fcinfo, Datum value); static void HandleStrictUninit(StypeBox *box, FunctionCallInfo fcinfo, Datum value);
@ -142,8 +144,29 @@ aclcheckAggregate(ObjectType objectType, Oid userOid, Oid funcOid)
} }
/* Copied from nodeAgg.c */
static Datum
GetAggInitVal(Datum textInitVal, Oid transtype)
{
/* *INDENT-OFF* */
Oid typinput,
typioparam;
char *strInitVal;
Datum initVal;
getTypeInputInfo(transtype, &typinput, &typioparam);
strInitVal = TextDatumGetCString(textInitVal);
initVal = OidInputFunctionCall(typinput, strInitVal,
typioparam, -1);
pfree(strInitVal);
return initVal;
/* *INDENT-ON* */
}
/* /*
* See GetAggInitVal from pg's nodeAgg.c * InitializeStypeBox fills in the rest of an StypeBox's fields besides agg,
* handling both permission checking & setting up the initial transition state.
*/ */
static void static void
InitializeStypeBox(FunctionCallInfo fcinfo, StypeBox *box, HeapTuple aggTuple, Oid InitializeStypeBox(FunctionCallInfo fcinfo, StypeBox *box, HeapTuple aggTuple, Oid
@ -171,9 +194,6 @@ InitializeStypeBox(FunctionCallInfo fcinfo, StypeBox *box, HeapTuple aggTuple, O
} }
else else
{ {
Oid typinput,
typioparam;
MemoryContext aggregateContext; MemoryContext aggregateContext;
if (!AggCheckCallContext(fcinfo, &aggregateContext)) if (!AggCheckCallContext(fcinfo, &aggregateContext))
{ {
@ -181,17 +201,52 @@ InitializeStypeBox(FunctionCallInfo fcinfo, StypeBox *box, HeapTuple aggTuple, O
} }
MemoryContext oldContext = MemoryContextSwitchTo(aggregateContext); MemoryContext oldContext = MemoryContextSwitchTo(aggregateContext);
getTypeInputInfo(transtype, &typinput, &typioparam); box->value = GetAggInitVal(textInitVal, transtype);
char *strInitVal = TextDatumGetCString(textInitVal);
box->value = OidInputFunctionCall(typinput, strInitVal,
typioparam, -1);
pfree(strInitVal);
MemoryContextSwitchTo(oldContext); MemoryContextSwitchTo(oldContext);
} }
} }
/*
* TryCreateStypeBoxFromFcinfoAggref attempts to initialize an StypeBox through
* introspection of the fcinfo's Aggref from AggGetAggref. This is required
* when we receive no intermediate rows.
*
* Returns NULL if the Aggref isn't our expected shape.
*/
static StypeBox *
TryCreateStypeBoxFromFcinfoAggref(FunctionCallInfo fcinfo)
{
Aggref *aggref = AggGetAggref(fcinfo);
if (aggref == NULL || aggref->args == NIL)
{
return NULL;
}
TargetEntry *aggArg = linitial(aggref->args);
if (!IsA(aggArg->expr, Const))
{
return NULL;
}
Const *aggConst = (Const *) aggArg->expr;
if (aggConst->consttype != OIDOID && aggConst->consttype != REGPROCEDUREOID)
{
return NULL;
}
Form_pg_aggregate aggform;
StypeBox *box = pallocInAggContext(fcinfo, sizeof(StypeBox));
box->agg = DatumGetObjectId(aggConst->constvalue);
HeapTuple aggTuple = GetAggregateForm(box->agg, &aggform);
InitializeStypeBox(fcinfo, box, aggTuple, aggform->aggtranstype);
ReleaseSysCache(aggTuple);
return box;
}
/* /*
* HandleTransition copies logic used in nodeAgg's advance_transition_function * HandleTransition copies logic used in nodeAgg's advance_transition_function
* for handling result of transition function. * for handling result of transition function.
@ -367,6 +422,11 @@ worker_partial_agg_ffunc(PG_FUNCTION_ARGS)
Oid typoutput = InvalidOid; Oid typoutput = InvalidOid;
bool typIsVarlena = false; bool typIsVarlena = false;
if (box == NULL)
{
box = TryCreateStypeBoxFromFcinfoAggref(fcinfo);
}
if (box == NULL || box->valueNull) if (box == NULL || box->valueNull)
{ {
PG_RETURN_NULL(); PG_RETURN_NULL();
@ -544,12 +604,13 @@ coord_combine_agg_ffunc(PG_FUNCTION_ARGS)
if (box == NULL) if (box == NULL)
{ {
/* box = TryCreateStypeBoxFromFcinfoAggref(fcinfo);
* Ideally we'd return initval,
* but we don't know which aggregate we're handling here if (box == NULL)
*/ {
PG_RETURN_NULL(); PG_RETURN_NULL();
} }
}
HeapTuple aggtuple = GetAggregateForm(box->agg, &aggform); HeapTuple aggtuple = GetAggregateForm(box->agg, &aggform);
Oid ffunc = aggform->aggfinalfn; Oid ffunc = aggform->aggfinalfn;

View File

@ -75,11 +75,11 @@ select key, sum2(val) filter (where valf < 5), sum2_strict(val) filter (where va
key | sum2 | sum2_strict key | sum2 | sum2_strict
-----+------+------------- -----+------+-------------
1 | | 1 | |
2 | | 10 2 | 10 | 10
3 | | 3 | 0 |
5 | | 5 | 0 |
6 | | 6 | 0 |
7 | | 7 | 0 |
9 | 0 | 0 9 | 0 | 0
(7 rows) (7 rows)
@ -106,11 +106,11 @@ select id, sum2(distinct val), sum2_strict(distinct val) from aggdata group by i
-- ORDER BY unsupported -- ORDER BY unsupported
select key, sum2(val order by valf), sum2_strict(val order by valf) from aggdata group by key order by key; select key, sum2(val order by valf), sum2_strict(val order by valf) from aggdata group by key order by key;
ERROR: unsupported aggregate function sum2 ERROR: unsupported aggregate function sum2
-- Without intermediate results we return NULL, even though the correct result is 0 -- Test handling a lack of intermediate results
select sum2(val) from aggdata where valf = 0; select sum2(val), sum2_strict(val) from aggdata where valf = 0;
sum2 sum2 | sum2_strict
------ ------+-------------
0 |
(1 row) (1 row)
-- test polymorphic aggregates from https://github.com/citusdata/citus/issues/2397 -- test polymorphic aggregates from https://github.com/citusdata/citus/issues/2397
@ -165,7 +165,7 @@ create aggregate sumstring(text) (
); );
select sumstring(valf::text) from aggdata where valf is not null; select sumstring(valf::text) from aggdata where valf is not null;
ERROR: function "aggregate_support.sumstring(text)" does not exist ERROR: function "aggregate_support.sumstring(text)" does not exist
CONTEXT: while executing command on localhost:57637 CONTEXT: while executing command on localhost:57638
select create_distributed_function('sumstring(text)'); select create_distributed_function('sumstring(text)');
create_distributed_function create_distributed_function
----------------------------- -----------------------------

View File

@ -61,8 +61,8 @@ select key, sum2(distinct val), sum2_strict(distinct val) from aggdata group by
select id, sum2(distinct val), sum2_strict(distinct val) from aggdata group by id order by id; select id, sum2(distinct val), sum2_strict(distinct val) from aggdata group by id order by id;
-- ORDER BY unsupported -- ORDER BY unsupported
select key, sum2(val order by valf), sum2_strict(val order by valf) from aggdata group by key order by key; select key, sum2(val order by valf), sum2_strict(val order by valf) from aggdata group by key order by key;
-- Without intermediate results we return NULL, even though the correct result is 0 -- Test handling a lack of intermediate results
select sum2(val) from aggdata where valf = 0; select sum2(val), sum2_strict(val) from aggdata where valf = 0;
-- test polymorphic aggregates from https://github.com/citusdata/citus/issues/2397 -- test polymorphic aggregates from https://github.com/citusdata/citus/issues/2397