diff --git a/src/backend/distributed/citus--8.4-1--8.4-2.sql b/src/backend/distributed/citus--8.4-1--8.4-2.sql index 9449077cf..fd95268de 100644 --- a/src/backend/distributed/citus--8.4-1--8.4-2.sql +++ b/src/backend/distributed/citus--8.4-1--8.4-2.sql @@ -40,7 +40,6 @@ CREATE AGGREGATE worker_partial_agg(oid, ...) ( STYPE = internal, SFUNC = worker_partial_agg_sfunc, FINALFUNC = worker_partial_agg_ffunc, - FINALFUNC_EXTRA, COMBINEFUNC = stypebox_combine, SERIALFUNC = stypebox_serialize, DESERIALFUNC = stypebox_deserialize, diff --git a/src/backend/distributed/utils/aggregate_utils.c b/src/backend/distributed/utils/aggregate_utils.c index 54efdce30..7d4023105 100644 --- a/src/backend/distributed/utils/aggregate_utils.c +++ b/src/backend/distributed/utils/aggregate_utils.c @@ -1,6 +1,10 @@ #include "postgres.h" +#include "access/htup_details.h" +#include "catalog/pg_aggregate.h" +#include "catalog/pg_proc.h" #include "utils/fmgr.h" +#include "utils/syscache.h" PG_FUNCTION_INFO_V1(stypebox_serialize); PG_FUNCTION_INFO_V1(stypebox_deserialize); @@ -12,17 +16,50 @@ PG_FUNCTION_INFO_V1(coord_combine_agg_ffunc); typedef struct StypeBox { Datum value; - bool value_null; Oid agg; + bool value_null; } StypeBox; +static Form_pg_aggregate get_aggform(Oid aggfnoid); +static Form_pg_proc get_procform(Oid aggfnoid); + +static Form_pg_aggregate +get_aggform(Oid aggfnoid) +{ + /* Fetch the pg_aggregate row */ + HeapTuple tuple = SearchSysCache1(AGGFNOID, ObjectIdGetDatum(aggfnoid)); + if (!HeapTupleIsValid(tuple)) + elog(ERROR, "cache lookup failed for aggregate %u", + aggfnoid); + return (Form_pg_aggregate) GETSTRUCT(tuple); +} + +static Form_pg_proc +get_procform(Oid fnoid) +{ + Form_pg_proc = SearchSysCache1(PROCID, ObjectIdGetDatum(fnoid)); + if (!HeapTupleIsValid(tuple)) + elog(ERROR, "cache lookup failed for function %u", + fnoid); + return (Form_pg_proc) GETSTRUCT(tuple); +} /* * (box) -> bytea - * return bytes(box.agg.name, box.agg.serial(box.value)) + * return bytes(box.agg.oid, box.agg.serial(box.value)) */ Datum stypebox_serialize(PG_FUNCTION_ARGS) { + StypeBox *box = PG_GETARG_POINTER(0); + Form_pg_aggregate aggform = get_aggform(box->agg); + // TODO return null if box null? + byteap *valbytes = DatumGetByteaPP(DirectFunctionCall1(aggform->serialfunc, box->value)); + byteap *realbytes = palloc(VARSIZE(valbytes) + sizeof(Oid)); + SET_VARSIZE(realbytes, VARSIZE(valbytes) + sizeof(Oid)); + memcpy(VARDATA(realbytes), &box->agg, sizeof(Oid)); + memcpy(VARDATA(realbytes) + sizeof(Oid), VARDATA(valbytes), VARSIZE(valbytes) - VARHDRSZ); + pfree(valbytes); // TODO I get to free this right? + PG_RETURN_BYTEA_P(valbytes); } /* @@ -34,6 +71,20 @@ stypebox_serialize(PG_FUNCTION_ARGS) Datum stypebox_deserialize(PG_FUNCTION_ARGS) { + StypeBox *box; + byteap *bytes = PG_GETARG_BYTEA_PP(0); + byteap *inner_bytes = PG_GETARG_BYTEA_P_SLICE(0, sizeof(Oid), VARSIZE(bytes) - sizeof(Oid)) + Oid agg; + Form_pg_aggregate aggform; + memcpy(&agg, VARDATA(bytes), sizeof(Oid)); + aggform = get_aggform(agg); + // Can deserialize be called with NULL? + + box = palloc(sizeof(StypeBox)); + box->agg = agg; + box->value = DirectFunctionCall2(aggform->deserialfunc, inner_bytes, PG_GETARG_DATUM(1)); + box->null_value = false; + PG_RETURN_POINTER(box); } /* @@ -46,7 +97,11 @@ stypebox_combine(PG_FUNCTION_ARGS) { StypeBox *box1 = NULL; StypeBox *box2 = NULL; + FunctionCallInfo inner_fcinfo; Oid aggOid; + Form_pg_aggregate aggform; + Form_pg_proc combineform; + if (!PG_ISARGNULL(0)) { box1 = PG_GETARG_POINTER(0); @@ -55,6 +110,7 @@ stypebox_combine(PG_FUNCTION_ARGS) { box2 = PG_GETARG_POINTER(1); } + if (box1 == NULL) { if (box2 == NULL) @@ -66,8 +122,37 @@ stypebox_combine(PG_FUNCTION_ARGS) box1->value_null = true; box1->agg = box2->agg; } - // TODO - // box1.agg = box1.agg.combine(box1.value, box2.value) + + aggform = get_aggform(box->agg); + combineform = get_procform(aggform->combinefn); + + // TODO respect strictness + Assert(IsValidOid(aggform->combineefn)); + + if (combineform->proisstrict) + { + if (box1->value_null) + { + if (box2->value_null) + { + PG_RETURN_NULL(); + } + PG_RETURN_DATUM(box2->value); + } + if (box2->value_null) + { + PG_RETURN_DATUM(box1->value); + } + } + + InitFunctionCallInfoData(&inner_fcinfo, &info, fcinfo->nargs - 1, fcinfo->collation, fcinfo->context, fcinfo->resultinfo); + inner_fcinfo.arg[0] = box1->value; + inner_fcinfo.argnull[0] = box1->value_null; + inner_fcinfo.arg[1] = box2->value; + inner_fcinfo.argnull[1] = box2->value_null; + // TODO Deal with memory management juggling (see executor/nodeAgg) + box1->value = FunctionCallInvoke(inner_fcinfo); + box1->value_null = inner_fcinfo.isnull; PG_RETURN_POINTER(box1); } @@ -81,6 +166,8 @@ Datum worker_partial_agg_sfunc(PG_FUNCTION_ARGS) { StypeBox *box; + FunctionCallInfo inner_fcinfo; + FmgrInfo info; int i; if (PG_ARGISNULL(0)) { box = palloc(sizeof(StypeBox)); @@ -91,9 +178,7 @@ worker_partial_agg_sfunc(PG_FUNCTION_ARGS) box = PG_GETARG_POINTER(0); Assert(box->agg == PG_GETARG_OID(1)); } - FmgrInfo info; fmgr_info(box->agg, &info); - FunctionCallInfo inner_fcinfo; InitFunctionCallInfoData(&inner_fcinfo, &info, fcinfo->nargs - 1, fcinfo->collation, fcinfo->context, fcinfo->resultinfo); // TODO if strict, deal with it // Deal with memory management juggling (see executor/nodeAgg) @@ -113,12 +198,48 @@ worker_partial_agg_sfunc(PG_FUNCTION_ARGS) Datum worker_partial_agg_ffunc(PG_FUNCTION_ARGS) { + StypeBox *box = PG_GETARG_POINTER(0); + Form_pg_aggregate aggform = get_aggform(box->agg); + PG_RETURN_DATUM(DirectFunctionCall1(aggform->serialfunc, box->value)); } /* * (box, agg, valbytes) -> box + * box->agg = agg + * box->value = agg.sfunc(box->value, agg.deserialize(valbytes)) + * return box */ Datum coord_combine_agg_sfunc(PG_FUNCTION_ARGS) { + // TODO } + +/* + * box -> fval + * return box.agg.ffunc(box.value) + */ +Datum +coord_combine_agg_ffunc(PG_FUNCTION_ARGS) +{ + StypeBox = PG_GETARG_POINTER(0); + FunctionCallInfo inner_fcinfo; + FmgrInfo info; + Form_pg_aggregate aggform = get_aggform(box->agg); + Form_pg_proc ffuncform; + + if (!IsValidOid(aggform->aggfinalfn)) + { + if (box->value_null) { + return NULL; + } + PG_RETURN_DATUM(box->value); + } + + ffuncform = get_aggform(aggform->aggfinalfn); + // TODO FINALFUNC EXTRA & stuff + fmgr_info(aggform->aggfinalfn, &info); + InitFunctionCallInfoData(&inner_fcinfo, &info, fcinfo->nargs - 1, fcinfo->collation, fcinfo->context, fcinfo->resultinfo); + +} +