diff --git a/src/backend/distributed/utils/aggregate_utils.c b/src/backend/distributed/utils/aggregate_utils.c index 3771f428d..95c42e18d 100644 --- a/src/backend/distributed/utils/aggregate_utils.c +++ b/src/backend/distributed/utils/aggregate_utils.c @@ -55,8 +55,10 @@ static HeapTuple GetProcForm(Oid oid, Form_pg_proc *form); static HeapTuple GetTypeForm(Oid oid, Form_pg_type *form); static void * pallocInAggContext(FunctionCallInfo fcinfo, size_t size); static void aclcheckAggregate(ObjectType objectType, Oid userOid, Oid funcOid); +static Datum GetAggInitVal(Datum textInitVal, Oid transtype); static void InitializeStypeBox(FunctionCallInfo fcinfo, StypeBox *box, HeapTuple aggTuple, Oid transtype); +static StypeBox * TryCreateStypeBoxFromFcinfoAggref(FunctionCallInfo fcinfo); static void HandleTransition(StypeBox *box, FunctionCallInfo fcinfo, FunctionCallInfo innerFcinfo); static void HandleStrictUninit(StypeBox *box, FunctionCallInfo fcinfo, Datum value); @@ -142,8 +144,29 @@ aclcheckAggregate(ObjectType objectType, Oid userOid, Oid funcOid) } +/* Copied from nodeAgg.c */ +static Datum +GetAggInitVal(Datum textInitVal, Oid transtype) +{ + /* *INDENT-OFF* */ + Oid typinput, + typioparam; + char *strInitVal; + Datum initVal; + + getTypeInputInfo(transtype, &typinput, &typioparam); + strInitVal = TextDatumGetCString(textInitVal); + initVal = OidInputFunctionCall(typinput, strInitVal, + typioparam, -1); + pfree(strInitVal); + return initVal; + /* *INDENT-ON* */ +} + + /* - * See GetAggInitVal from pg's nodeAgg.c + * InitializeStypeBox fills in the rest of an StypeBox's fields besides agg, + * handling both permission checking & setting up the initial transition state. */ static void InitializeStypeBox(FunctionCallInfo fcinfo, StypeBox *box, HeapTuple aggTuple, Oid @@ -171,9 +194,6 @@ InitializeStypeBox(FunctionCallInfo fcinfo, StypeBox *box, HeapTuple aggTuple, O } else { - Oid typinput, - typioparam; - MemoryContext aggregateContext; if (!AggCheckCallContext(fcinfo, &aggregateContext)) { @@ -181,17 +201,52 @@ InitializeStypeBox(FunctionCallInfo fcinfo, StypeBox *box, HeapTuple aggTuple, O } MemoryContext oldContext = MemoryContextSwitchTo(aggregateContext); - getTypeInputInfo(transtype, &typinput, &typioparam); - char *strInitVal = TextDatumGetCString(textInitVal); - box->value = OidInputFunctionCall(typinput, strInitVal, - typioparam, -1); - pfree(strInitVal); + box->value = GetAggInitVal(textInitVal, transtype); MemoryContextSwitchTo(oldContext); } } +/* + * TryCreateStypeBoxFromFcinfoAggref attempts to initialize an StypeBox through + * introspection of the fcinfo's Aggref from AggGetAggref. This is required + * when we receive no intermediate rows. + * + * Returns NULL if the Aggref isn't our expected shape. + */ +static StypeBox * +TryCreateStypeBoxFromFcinfoAggref(FunctionCallInfo fcinfo) +{ + Aggref *aggref = AggGetAggref(fcinfo); + if (aggref == NULL || aggref->args == NIL) + { + return NULL; + } + + TargetEntry *aggArg = linitial(aggref->args); + if (!IsA(aggArg->expr, Const)) + { + return NULL; + } + + Const *aggConst = (Const *) aggArg->expr; + if (aggConst->consttype != OIDOID && aggConst->consttype != REGPROCEDUREOID) + { + return NULL; + } + + Form_pg_aggregate aggform; + StypeBox *box = pallocInAggContext(fcinfo, sizeof(StypeBox)); + box->agg = DatumGetObjectId(aggConst->constvalue); + HeapTuple aggTuple = GetAggregateForm(box->agg, &aggform); + InitializeStypeBox(fcinfo, box, aggTuple, aggform->aggtranstype); + ReleaseSysCache(aggTuple); + + return box; +} + + /* * HandleTransition copies logic used in nodeAgg's advance_transition_function * for handling result of transition function. @@ -367,6 +422,11 @@ worker_partial_agg_ffunc(PG_FUNCTION_ARGS) Oid typoutput = InvalidOid; bool typIsVarlena = false; + if (box == NULL) + { + box = TryCreateStypeBoxFromFcinfoAggref(fcinfo); + } + if (box == NULL || box->valueNull) { PG_RETURN_NULL(); @@ -544,11 +604,12 @@ coord_combine_agg_ffunc(PG_FUNCTION_ARGS) if (box == NULL) { - /* - * Ideally we'd return initval, - * but we don't know which aggregate we're handling here - */ - PG_RETURN_NULL(); + box = TryCreateStypeBoxFromFcinfoAggref(fcinfo); + + if (box == NULL) + { + PG_RETURN_NULL(); + } } HeapTuple aggtuple = GetAggregateForm(box->agg, &aggform); diff --git a/src/test/regress/expected/aggregate_support.out b/src/test/regress/expected/aggregate_support.out index e4c69e77f..e33f16a1f 100644 --- a/src/test/regress/expected/aggregate_support.out +++ b/src/test/regress/expected/aggregate_support.out @@ -75,11 +75,11 @@ select key, sum2(val) filter (where valf < 5), sum2_strict(val) filter (where va key | sum2 | sum2_strict -----+------+------------- 1 | | - 2 | | 10 - 3 | | - 5 | | - 6 | | - 7 | | + 2 | 10 | 10 + 3 | 0 | + 5 | 0 | + 6 | 0 | + 7 | 0 | 9 | 0 | 0 (7 rows) @@ -106,11 +106,11 @@ select id, sum2(distinct val), sum2_strict(distinct val) from aggdata group by i -- ORDER BY unsupported select key, sum2(val order by valf), sum2_strict(val order by valf) from aggdata group by key order by key; ERROR: unsupported aggregate function sum2 --- Without intermediate results we return NULL, even though the correct result is 0 -select sum2(val) from aggdata where valf = 0; - sum2 ------- - +-- Test handling a lack of intermediate results +select sum2(val), sum2_strict(val) from aggdata where valf = 0; + sum2 | sum2_strict +------+------------- + 0 | (1 row) -- test polymorphic aggregates from https://github.com/citusdata/citus/issues/2397 @@ -165,7 +165,7 @@ create aggregate sumstring(text) ( ); select sumstring(valf::text) from aggdata where valf is not null; ERROR: function "aggregate_support.sumstring(text)" does not exist -CONTEXT: while executing command on localhost:57637 +CONTEXT: while executing command on localhost:57638 select create_distributed_function('sumstring(text)'); create_distributed_function ----------------------------- diff --git a/src/test/regress/sql/aggregate_support.sql b/src/test/regress/sql/aggregate_support.sql index 8bed659c3..f0360bfc9 100644 --- a/src/test/regress/sql/aggregate_support.sql +++ b/src/test/regress/sql/aggregate_support.sql @@ -61,8 +61,8 @@ select key, sum2(distinct val), sum2_strict(distinct val) from aggdata group by select id, sum2(distinct val), sum2_strict(distinct val) from aggdata group by id order by id; -- ORDER BY unsupported select key, sum2(val order by valf), sum2_strict(val order by valf) from aggdata group by key order by key; --- Without intermediate results we return NULL, even though the correct result is 0 -select sum2(val) from aggdata where valf = 0; +-- Test handling a lack of intermediate results +select sum2(val), sum2_strict(val) from aggdata where valf = 0; -- test polymorphic aggregates from https://github.com/citusdata/citus/issues/2397