diff --git a/src/backend/distributed/citus--8.4-1--8.4-2.sql b/src/backend/distributed/citus--8.4-1--8.4-2.sql deleted file mode 100644 index e7a11b615..000000000 --- a/src/backend/distributed/citus--8.4-1--8.4-2.sql +++ /dev/null @@ -1,61 +0,0 @@ -CREATE FUNCTION stype_serialize(internal, oid, ...) -RETURNS internal -AS 'MODULE_PATHNAME' -LANGUAGE C STRICT IMMUTABLE PARALLEL SAFE; - -CREATE FUNCTION stype_deserialize(internal, oid, ...) -RETURNS internal -AS 'MODULE_PATHNAME' -LANGUAGE C STRICT IMMUTABLE PARALLEL SAFE; - -CREATE FUNCTION stype_combine(internal, oid, ...) -RETURNS internal -AS 'MODULE_PATHNAME' -LANGUAGE C PARALLEL SAFE; - -CREATE FUNCTION worker_partial_agg_sfunc(internal, oid, ...) -RETURNS internal -AS 'MODULE_PATHNAME' -LANGUAGE C PARALLEL SAFE; - -CREATE FUNCTION worker_partial_agg_ffunc(internal, oid, ...) -RETURNS internal -AS 'MODULE_PATHNAME' -LANGUAGE C PARALLEL SAFE; - -CREATE FUNCTION coord_combine_agg_sfunc(internal, oid, ...) -RETURNS internal -AS 'MODULE_PATHNAME' -LANGUAGE C PARALLEL SAFE; - -CREATE FUNCTION coord_combine_agg_ffunc(internal, oid, ...) -RETURNS internal -AS 'MODULE_PATHNAME' -LANGUAGE C PARALLEL SAFE; - --- select worker_partial_agg(agg, ...) --- equivalent to --- select serialize_stype(agg_without_ffunc(...)) -CREATE AGGREGATE worker_partial_agg(oid, ...) ( - STYPE = internal, - SFUNC = worker_partial_agg_sfunc, - FINALFUNC = worker_partial_agg_ffunc, - COMBINEFUNC = stypebox_combine, - SERIALFUNC = stypebox_serialize, - DESERIALFUNC = stypebox_deserialize, - PARALLEL = SAFE -) - --- select coord_combine_agg(agg, col) --- equivalent to --- select agg_ffunc(agg_combine(col)) -CREATE AGGREGATE coord_combine_agg(oid, ...) ( - STYPE = internal, - SFUNC = coord_combine_sfunc, - FINALFUNC = coord_combine_agg_ffunc, - FINALFUNC_EXTRA, - COMBINEFUNC = stypebox_combine, - SERIALFUNC = stypebox_serialize, - DESERIALFUNC = stypebox_deserialize, - PARALLEL = SAFE -) \ No newline at end of file diff --git a/src/backend/distributed/citus--8.4-1--8.4-customagg.sql b/src/backend/distributed/citus--8.4-1--8.4-customagg.sql new file mode 100644 index 000000000..35f2e8d43 --- /dev/null +++ b/src/backend/distributed/citus--8.4-1--8.4-customagg.sql @@ -0,0 +1,65 @@ +SET search_path = 'pg_catalog'; + +CREATE FUNCTION citus_stype_serialize(internal) +RETURNS bytea +AS 'MODULE_PATHNAME' +LANGUAGE C STRICT IMMUTABLE PARALLEL SAFE; + +CREATE FUNCTION citus_stype_deserialize(bytea, internal) +RETURNS internal +AS 'MODULE_PATHNAME' +LANGUAGE C STRICT IMMUTABLE PARALLEL SAFE; + +CREATE FUNCTION citus_stype_combine(internal, internal) +RETURNS internal +AS 'MODULE_PATHNAME' +LANGUAGE C PARALLEL SAFE; + +CREATE FUNCTION worker_partial_agg_sfunc(internal, oid, anyelement) +RETURNS internal +AS 'MODULE_PATHNAME' +LANGUAGE C PARALLEL SAFE; + +CREATE FUNCTION worker_partial_agg_ffunc(internal) +RETURNS bytea +AS 'MODULE_PATHNAME' +LANGUAGE C PARALLEL SAFE; + +CREATE FUNCTION coord_combine_agg_sfunc(internal, oid, bytea, anyelement) +RETURNS internal +AS 'MODULE_PATHNAME' +LANGUAGE C PARALLEL SAFE; + +CREATE FUNCTION coord_combine_agg_ffunc(internal, oid, bytea, anyelement) +RETURNS anyelement +AS 'MODULE_PATHNAME' +LANGUAGE C PARALLEL SAFE; + +-- select worker_partial_agg(agg, ...) +-- equivalent to +-- select serialize_stype(agg_without_ffunc(...)) +CREATE AGGREGATE worker_partial_agg(oid, anyelement) ( + STYPE = internal, + SFUNC = worker_partial_agg_sfunc, + FINALFUNC = worker_partial_agg_ffunc, + COMBINEFUNC = citus_stype_combine, + SERIALFUNC = citus_stype_serialize, + DESERIALFUNC = citus_stype_deserialize, + PARALLEL = SAFE +); + +-- select coord_combine_agg(agg, col) +-- equivalent to +-- select agg_ffunc(agg_combine(col)) +CREATE AGGREGATE coord_combine_agg(oid, bytea, anyelement) ( + STYPE = internal, + SFUNC = coord_combine_agg_sfunc, + FINALFUNC = coord_combine_agg_ffunc, + FINALFUNC_EXTRA, + COMBINEFUNC = citus_stype_combine, + SERIALFUNC = citus_stype_serialize, + DESERIALFUNC = citus_stype_deserialize, + PARALLEL = SAFE +); + +RESET search_path; diff --git a/src/backend/distributed/planner/multi_logical_optimizer.c b/src/backend/distributed/planner/multi_logical_optimizer.c index e4aefa1a5..5e09031a8 100644 --- a/src/backend/distributed/planner/multi_logical_optimizer.c +++ b/src/backend/distributed/planner/multi_logical_optimizer.c @@ -1512,8 +1512,6 @@ MasterAggregateExpression(Aggref *originalAggregate, HeapTuple aggTuple; Form_pg_aggregate aggform; Oid combine; - Oid serial = InvalidOid; - Oid deserial = InvalidOid; aggTuple = SearchSysCache1(AGGFNOID, ObjectIdGetDatum(originalAggregate->aggfnoid)); @@ -1527,11 +1525,6 @@ MasterAggregateExpression(Aggref *originalAggregate, { aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple); combine = aggform->aggcombinefn; - if (combine != InvalidOid && originalAggregate->aggtranstype == INTERNALOID) - { - serial = aggform->aggserialfn; - deserial = aggform->aggdeserialfn; - } ReleaseSysCache(aggTuple); } @@ -1548,6 +1541,8 @@ MasterAggregateExpression(Aggref *originalAggregate, int32 workerReturnTypeMod = -1; Oid workerCollationId = InvalidOid; + elog(WARNING, "coord_combine_agg %d %d", coordCombineId, + originalAggregate->aggfnoid); aggparam = makeConst(OIDOID, -1, InvalidOid, sizeof(Oid), ObjectIdGetDatum( originalAggregate->aggfnoid), false, true); @@ -2808,8 +2803,6 @@ WorkerAggregateExpressionList(Aggref *originalAggregate, HeapTuple aggTuple; Form_pg_aggregate aggform; Oid combine; - Oid serial = InvalidOid; - Oid deserial = InvalidOid; aggTuple = SearchSysCache1(AGGFNOID, ObjectIdGetDatum(originalAggregate->aggfnoid)); @@ -2823,11 +2816,6 @@ WorkerAggregateExpressionList(Aggref *originalAggregate, { aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple); combine = aggform->aggcombinefn; - if (combine != InvalidOid && originalAggregate->aggtranstype == INTERNALOID) - { - serial = aggform->aggserialfn; - deserial = aggform->aggdeserialfn; - } ReleaseSysCache(aggTuple); } @@ -2840,6 +2828,8 @@ WorkerAggregateExpressionList(Aggref *originalAggregate, Oid workerPartialId = AggregateFunctionOidWithoutInput( WORKER_PARTIAL_AGGREGATE_NAME); + elog(WARNING, "worker_partial_agg %d %d", workerPartialId, + originalAggregate->aggfnoid); aggparam = makeConst(OIDOID, -1, InvalidOid, sizeof(Oid), ObjectIdGetDatum( originalAggregate->aggfnoid), false, true); aggArguments = list_make1(makeTargetEntry((Expr *) aggparam, 1, NULL, false)); @@ -2854,7 +2844,7 @@ WorkerAggregateExpressionList(Aggref *originalAggregate, /* worker_partial_agg(agg, ...args) */ newWorkerAggregate = makeNode(Aggref); newWorkerAggregate->aggfnoid = workerPartialId; - newWorkerAggregate->aggtype = originalAggregate->aggtype; + newWorkerAggregate->aggtype = BYTEAOID; newWorkerAggregate->args = aggArguments; newWorkerAggregate->aggkind = AGGKIND_NORMAL; newWorkerAggregate->aggfilter = originalAggregate->aggfilter; diff --git a/src/backend/distributed/utils/aggregate_utils.c b/src/backend/distributed/utils/aggregate_utils.c index 66a11bc9f..0a9a891e1 100644 --- a/src/backend/distributed/utils/aggregate_utils.c +++ b/src/backend/distributed/utils/aggregate_utils.c @@ -4,13 +4,14 @@ #include "catalog/pg_aggregate.h" #include "catalog/pg_proc.h" #include "catalog/pg_type.h" +#include "utils/builtins.h" #include "utils/lsyscache.h" #include "utils/syscache.h" #include "fmgr.h" -PG_FUNCTION_INFO_V1(stypebox_serialize); -PG_FUNCTION_INFO_V1(stypebox_deserialize); -PG_FUNCTION_INFO_V1(stypebox_combine); +PG_FUNCTION_INFO_V1(citus_stype_serialize); +PG_FUNCTION_INFO_V1(citus_stype_deserialize); +PG_FUNCTION_INFO_V1(citus_stype_combine); PG_FUNCTION_INFO_V1(worker_partial_agg_sfunc); PG_FUNCTION_INFO_V1(worker_partial_agg_ffunc); PG_FUNCTION_INFO_V1(coord_combine_agg_sfunc); @@ -69,12 +70,40 @@ get_typeform(Oid oid, Form_pg_type *form) } +/* + * See GetAggInitVal from pg's nodeAgg.c + */ +static void +InitializeStypeBox(StypeBox *box, HeapTuple aggTuple, Oid transtype) +{ + Datum textInitVal = SysCacheGetAttr(AGGFNOID, aggTuple, + Anum_pg_aggregate_agginitval, + &box->value_null); + if (box->value_null) + { + box->value = (Datum) 0; + } + else + { + Oid typinput, + typioparam; + char *strInitVal; + + getTypeInputInfo(transtype, &typinput, &typioparam); + strInitVal = TextDatumGetCString(textInitVal); + box->value = OidInputFunctionCall(typinput, strInitVal, + typioparam, -1); + pfree(strInitVal); + } +} + + /* * (box) -> bytea * return bytes(box.agg.oid, box.agg.serial(box.value)) */ Datum -stypebox_serialize(PG_FUNCTION_ARGS) +citus_stype_serialize(PG_FUNCTION_ARGS) { FunctionCallInfoData inner_fcinfodata; FunctionCallInfo inner_fcinfo = &inner_fcinfodata; @@ -93,12 +122,14 @@ stypebox_serialize(PG_FUNCTION_ARGS) Size realbyteslen; Datum result; + elog(WARNING, "citus_stype_serialize"); + aggtuple = get_aggform(box->agg, &aggform); serial = aggform->aggserialfn; transtype = aggform->aggtranstype; ReleaseSysCache(aggtuple); - if (serial != InvalidOid) + if (serial == InvalidOid) { /* TODO do we have to fallback to output/receive if not set? */ /* ie is it possible for send/recv to be unset? */ @@ -107,6 +138,8 @@ stypebox_serialize(PG_FUNCTION_ARGS) ReleaseSysCache(transtypetuple); } + Assert(serial != InvalidOid); + fmgr_info(serial, info); if (info->fn_strict && box->value_null) { @@ -157,7 +190,7 @@ stypebox_serialize(PG_FUNCTION_ARGS) * return box */ Datum -stypebox_deserialize(PG_FUNCTION_ARGS) +citus_stype_deserialize(PG_FUNCTION_ARGS) { StypeBox *box; bytea *bytes = PG_GETARG_BYTEA_PP(0); @@ -174,6 +207,8 @@ stypebox_deserialize(PG_FUNCTION_ARGS) StringInfoData buf; bool value_null; + elog(WARNING, "citus_stype_deserialize"); + memcpy(&agg, VARDATA(bytes), sizeof(Oid)); memcpy(&value_null, VARDATA(bytes) + sizeof(Oid), sizeof(bool)); @@ -198,6 +233,7 @@ stypebox_deserialize(PG_FUNCTION_ARGS) inner_bytes = PG_GETARG_BYTEA_P_SLICE(0, sizeof(Oid), VARSIZE(bytes) - sizeof(Oid)); + elog(WARNING, "deserial %d", VARSIZE(inner_bytes)); fmgr_info(deserial, &fdeserialinfo); InitFunctionCallInfoData(fdeserial_callinfodata, &fdeserialinfo, 2, fcinfo->fncollation, fcinfo->context, @@ -209,6 +245,13 @@ stypebox_deserialize(PG_FUNCTION_ARGS) box->value = FunctionCallInvoke(&fdeserial_callinfodata); box->value_null = fdeserial_callinfodata.isnull; } + + /* TODO Correct null handling */ + else if (value_null) + { + box->value = (Datum) 0; + box->value_null = true; + } else { transtypetuple = get_typeform(transtype, &transtypeform); @@ -222,7 +265,6 @@ stypebox_deserialize(PG_FUNCTION_ARGS) VARSIZE(bytes) - VARHDRSZ - sizeof(Oid) - sizeof(bool)); box->value = OidReceiveFunctionCall(recv, &buf, ioparam, -1); - box->value_null = value_null; } PG_RETURN_POINTER(box); @@ -235,7 +277,7 @@ stypebox_deserialize(PG_FUNCTION_ARGS) * return box */ Datum -stypebox_combine(PG_FUNCTION_ARGS) +citus_stype_combine(PG_FUNCTION_ARGS) { StypeBox *box1 = NULL; StypeBox *box2 = NULL; @@ -246,6 +288,8 @@ stypebox_combine(PG_FUNCTION_ARGS) HeapTuple aggtuple; Form_pg_aggregate aggform; + elog(WARNING, "citus_stype_combine"); + if (!PG_ARGISNULL(0)) { box1 = (StypeBox *) PG_GETARG_POINTER(0); @@ -315,38 +359,56 @@ Datum worker_partial_agg_sfunc(PG_FUNCTION_ARGS) { StypeBox *box; + Form_pg_aggregate aggform; + HeapTuple aggtuple; + Oid aggsfunc; FunctionCallInfoData inner_fcinfodata; FunctionCallInfo inner_fcinfo = &inner_fcinfodata; FmgrInfo info; int i; - if (PG_ARGISNULL(0)) + bool is_initial_call = PG_ARGISNULL(0); + + elog(WARNING, "worker_partial_agg_sfunc"); + + if (is_initial_call) { box = palloc(sizeof(StypeBox)); box->agg = PG_GETARG_OID(1); - box->value = (Datum) 0; - box->value_null = true; } else { box = (StypeBox *) PG_GETARG_POINTER(0); Assert(box->agg == PG_GETARG_OID(1)); } - fmgr_info(box->agg, &info); + + aggtuple = get_aggform(box->agg, &aggform); + aggsfunc = aggform->aggtransfn; + if (is_initial_call) + { + InitializeStypeBox(box, aggtuple, aggform->aggtranstype); + } + ReleaseSysCache(aggtuple); + + fmgr_info(aggsfunc, &info); InitFunctionCallInfoData(*inner_fcinfo, &info, fcinfo->nargs - 1, fcinfo->fncollation, fcinfo->context, fcinfo->resultinfo); if (info.fn_strict) { - if (box->value_null) - { - PG_RETURN_NULL(); - } for (i = 2; i < PG_NARGS(); i++) { if (PG_ARGISNULL(i)) { - PG_RETURN_NULL(); + elog(WARNING, "\tworker sfunc retnull %i", i); + PG_RETURN_POINTER(box); } } + if (box->value_null) + { + elog(WARNING, "\tworker sfunc seed"); + box->value = PG_GETARG_DATUM(2); + box->value_null = false; + PG_RETURN_POINTER(box); + } } /* Deal with memory management juggling (see executor/nodeAgg) */ @@ -358,6 +420,9 @@ worker_partial_agg_sfunc(PG_FUNCTION_ARGS) (inner_fcinfo->nargs - 1)); box->value = FunctionCallInvoke(inner_fcinfo); box->value_null = inner_fcinfo->isnull; + + elog(WARNING, "\tworker sfunc null: %d", box->value_null); + PG_RETURN_POINTER(box); } @@ -374,48 +439,60 @@ worker_partial_agg_ffunc(PG_FUNCTION_ARGS) FmgrInfo info; StypeBox *box = (StypeBox *) PG_GETARG_POINTER(0); HeapTuple aggtuple; + HeapTuple transtypetuple; Form_pg_aggregate aggform; + Form_pg_type transtypeform; Oid serial; + Oid transtype; Datum result; + elog(WARNING, "worker_partial_agg_ffunc %p", box); + + if (box == NULL) + { + PG_RETURN_NULL(); + } + aggtuple = get_aggform(box->agg, &aggform); serial = aggform->aggserialfn; + transtype = aggform->aggtranstype; ReleaseSysCache(aggtuple); - if (serial != InvalidOid) + if (serial == InvalidOid) { - fmgr_info(serial, &info); - if (info.fn_strict && box->value_null) - { - PG_RETURN_NULL(); - } - InitFunctionCallInfoData(*inner_fcinfo, &info, 1, fcinfo->fncollation, - fcinfo->context, fcinfo->resultinfo); - inner_fcinfo->arg[0] = box->value; - inner_fcinfo->argnull[0] = box->value_null; - result = FunctionCallInvoke(inner_fcinfo); - if (inner_fcinfo->isnull) - { - PG_RETURN_NULL(); - } - PG_RETURN_DATUM(result); + /* TODO do we have to fallback to output/receive if not set? */ + /* ie is it possible for send/recv to be unset? */ + transtypetuple = get_typeform(transtype, &transtypeform); + serial = transtypeform->typsend; + ReleaseSysCache(transtypetuple); } - else + + Assert(serial != InvalidOid); + + elog(WARNING, "calling serial %d", serial); + fmgr_info(serial, &info); + if (info.fn_strict && box->value_null) { - if (box->value_null) - { - PG_RETURN_NULL(); - } - PG_RETURN_DATUM(box->value); + PG_RETURN_NULL(); } + InitFunctionCallInfoData(*inner_fcinfo, &info, 1, fcinfo->fncollation, + fcinfo->context, fcinfo->resultinfo); + inner_fcinfo->arg[0] = box->value; + inner_fcinfo->argnull[0] = box->value_null; + result = FunctionCallInvoke(inner_fcinfo); + elog(WARNING, "& done %d", VARSIZE(DatumGetByteaPP(result))); + if (inner_fcinfo->isnull) + { + PG_RETURN_NULL(); + } + PG_RETURN_DATUM(result); } /* - * (box, agg, valbytes|value) -> box + * (box, agg, valbytes) -> box * box->agg = agg - * if agg.deserialize: box->value = agg.combine(box->value, agg.deserialize(valbytes)) - * else: box->value = agg.combine(box->value, value) + * box->value = agg.combine(box->value, agg.deserialize(valbytes)) * return box */ Datum @@ -425,13 +502,19 @@ coord_combine_agg_sfunc(PG_FUNCTION_ARGS) FunctionCallInfo inner_fcinfo = &inner_fcinfodata; FmgrInfo info; HeapTuple aggtuple; + HeapTuple transtypetuple; Form_pg_aggregate aggform; + Form_pg_type transtypeform; Oid combine; Oid deserial; + Oid transtype; + Oid ioparam; Datum value; bool value_null; StypeBox *box; + elog(WARNING, "coord_combine_agg_sfunc"); + if (PG_ARGISNULL(0)) { box = palloc(sizeof(StypeBox)); @@ -445,9 +528,12 @@ coord_combine_agg_sfunc(PG_FUNCTION_ARGS) Assert(box->agg == PG_GETARG_OID(1)); } + elog(WARNING, "\tbox->agg = %u", box->agg); + aggtuple = get_aggform(box->agg, &aggform); deserial = aggform->aggdeserialfn; combine = aggform->aggcombinefn; + transtype = aggform->aggtranstype; ReleaseSysCache(aggtuple); value_null = PG_ARGISNULL(2); @@ -472,7 +558,41 @@ coord_combine_agg_sfunc(PG_FUNCTION_ARGS) } else { - value = value_null ? (Datum) 0 : PG_GETARG_DATUM(2); + transtypetuple = get_typeform(transtype, &transtypeform); + ioparam = getTypeIOParam(transtypetuple); + deserial = transtypeform->typreceive; + ReleaseSysCache(transtypetuple); + + fmgr_info(deserial, &info); + if (!value_null || !info.fn_strict) + { + StringInfoData buf; + + if (!value_null) + { + bytea *data = PG_GETARG_BYTEA_PP(2); + initStringInfo(&buf); + appendBinaryStringInfo(&buf, (char *) VARDATA(data), VARSIZE(data) - + VARHDRSZ); + elog(WARNING, "\treceive %d", buf.len); + } + + InitFunctionCallInfoData(*inner_fcinfo, &info, 3, fcinfo->fncollation, + fcinfo->context, fcinfo->resultinfo); + inner_fcinfo->arg[0] = PointerGetDatum(value_null ? NULL : &buf); + inner_fcinfo->arg[1] = ObjectIdGetDatum(ioparam); + inner_fcinfo->arg[2] = Int32GetDatum(-1); /* typmod */ + inner_fcinfo->argnull[0] = value_null; + inner_fcinfo->argnull[1] = false; + inner_fcinfo->argnull[2] = false; + + value = FunctionCallInvoke(inner_fcinfo); + value_null = inner_fcinfo->isnull; + } + else + { + value = (Datum) 0; + } } fmgr_info(combine, &info); @@ -481,6 +601,7 @@ coord_combine_agg_sfunc(PG_FUNCTION_ARGS) { if (box->value_null) { + elog(WARNING, "\tbox null"); if (value_null) { PG_RETURN_NULL(); @@ -491,10 +612,12 @@ coord_combine_agg_sfunc(PG_FUNCTION_ARGS) } if (value_null) { + elog(WARNING, "\tvalue null"); PG_RETURN_POINTER(box); } } + elog(WARNING, "\tcombine %u", box->agg); InitFunctionCallInfoData(*inner_fcinfo, &info, 2, fcinfo->fncollation, fcinfo->context, fcinfo->resultinfo); inner_fcinfo->arg[0] = box->value; @@ -530,6 +653,13 @@ coord_combine_agg_ffunc(PG_FUNCTION_ARGS) bool final_strict; int i; + elog(WARNING, "coord_combine_agg_ffunc %p", box); + + if (box == NULL) + { + PG_RETURN_NULL(); + } + aggtuple = get_aggform(box->agg, &aggform); ffunc = aggform->aggfinalfn; fextra = aggform->aggfinalextra;