diff --git a/src/backend/distributed/planner/multi_logical_optimizer.c b/src/backend/distributed/planner/multi_logical_optimizer.c index e763f6944..ca325bbd8 100644 --- a/src/backend/distributed/planner/multi_logical_optimizer.c +++ b/src/backend/distributed/planner/multi_logical_optimizer.c @@ -2835,7 +2835,7 @@ WorkerAggregateExpressionList(Aggref *originalAggregate, aggArguments = list_make1(makeTargetEntry((Expr *) aggparam, 1, NULL, false)); foreach(originalAggArgCell, originalAggregate->args) { - TargetEntry *arg = (TargetEntry *) lfirst(originalAggArgCell); + TargetEntry *arg = lfirst(originalAggArgCell); TargetEntry *newArg = copyObject(arg); newArg->resno++; aggArguments = lappend(aggArguments, newArg); diff --git a/src/backend/distributed/utils/aggregate_utils.c b/src/backend/distributed/utils/aggregate_utils.c index 6291ec823..9ad9d7cdd 100644 --- a/src/backend/distributed/utils/aggregate_utils.c +++ b/src/backend/distributed/utils/aggregate_utils.c @@ -6,11 +6,14 @@ #include "catalog/pg_type.h" #include "distributed/version_compat.h" #include "utils/builtins.h" +#include "utils/datum.h" #include "utils/lsyscache.h" #include "utils/syscache.h" #include "fmgr.h" #include "pg_config_manual.h" +#include "utils/array.h" + PG_FUNCTION_INFO_V1(citus_stype_serialize); PG_FUNCTION_INFO_V1(citus_stype_deserialize); PG_FUNCTION_INFO_V1(citus_stype_combine); @@ -26,14 +29,19 @@ typedef struct StypeBox { Datum value; Oid agg; + Oid transtype; + int16_t transtypeLen; + bool transtypeByVal; bool value_null; + bool value_init; } StypeBox; static HeapTuple get_aggform(Oid oid, Form_pg_aggregate *form); static HeapTuple get_procform(Oid oid, Form_pg_proc *form); static HeapTuple get_typeform(Oid oid, Form_pg_type *form); static void * pallocInAggContext(FunctionCallInfo fcinfo, size_t size); -static void InitializeStypeBox(StypeBox *box, HeapTuple aggTuple, Oid transtype); +static void InitializeStypeBox(FunctionCallInfo fcinfo, StypeBox *box, HeapTuple aggTuple, + Oid transtype); static HeapTuple get_aggform(Oid oid, Form_pg_aggregate *form) @@ -96,11 +104,14 @@ pallocInAggContext(FunctionCallInfo fcinfo, size_t size) * See GetAggInitVal from pg's nodeAgg.c */ static void -InitializeStypeBox(StypeBox *box, HeapTuple aggTuple, Oid transtype) +InitializeStypeBox(FunctionCallInfo fcinfo, StypeBox *box, HeapTuple aggTuple, Oid + transtype) { Datum textInitVal = SysCacheGetAttr(AGGFNOID, aggTuple, Anum_pg_aggregate_agginitval, &box->value_null); + box->transtype = transtype; + box->value_init = !box->value_null; if (box->value_null) { box->value = (Datum) 0; @@ -111,11 +122,22 @@ InitializeStypeBox(StypeBox *box, HeapTuple aggTuple, Oid transtype) typioparam; char *strInitVal; + MemoryContext aggregateContext; + MemoryContext oldContext; + elog(WARNING, "\tworker sfunc seed"); + if (!AggCheckCallContext(fcinfo, &aggregateContext)) + { + elog(WARNING, "worker_partiail_agg_sfunc called from non aggregate context"); + } + oldContext = MemoryContextSwitchTo(aggregateContext); + getTypeInputInfo(transtype, &typinput, &typioparam); strInitVal = TextDatumGetCString(textInitVal); box->value = OidInputFunctionCall(typinput, strInitVal, typioparam, -1); pfree(strInitVal); + + MemoryContextSwitchTo(oldContext); } } @@ -138,7 +160,6 @@ citus_stype_serialize(PG_FUNCTION_ARGS) bytea *valbytes; bytea *realbytes; Oid serial; - Oid transtype; Size valbyteslen_exhdr; Size realbyteslen; Datum result; @@ -148,16 +169,15 @@ citus_stype_serialize(PG_FUNCTION_ARGS) aggtuple = get_aggform(box->agg, &aggform); serial = aggform->aggserialfn; - transtype = aggform->aggtranstype; ReleaseSysCache(aggtuple); if (serial == InvalidOid) { - elog(WARNING, "\tnoserial, load %d", transtype); + elog(WARNING, "\tnoserial, load %d", box->transtype); /* TODO do we have to fallback to output/receive if not set? */ /* ie is it possible for send/recv to be unset? */ - transtypetuple = get_typeform(transtype, &transtypeform); + transtypetuple = get_typeform(box->transtype, &transtypeform); serial = transtypeform->typsend; ReleaseSysCache(transtypetuple); } @@ -224,7 +244,6 @@ citus_stype_deserialize(PG_FUNCTION_ARGS) Form_pg_aggregate aggform; Form_pg_type transtypeform; Oid deserial; - Oid transtype; Oid ioparam; Oid recv; StringInfoData buf; @@ -237,6 +256,16 @@ citus_stype_deserialize(PG_FUNCTION_ARGS) box = pallocInAggContext(fcinfo, sizeof(StypeBox)); box->agg = agg; + + aggtuple = get_aggform(agg, &aggform); + deserial = aggform->aggdeserialfn; + box->transtype = aggform->aggtranstype; + ReleaseSysCache(aggtuple); + + get_typlenbyval(box->transtype, + &box->transtypeLen, + &box->transtypeByVal); + if (value_null) { box->value = (Datum) 0; @@ -244,11 +273,6 @@ citus_stype_deserialize(PG_FUNCTION_ARGS) PG_RETURN_POINTER(box); } - aggtuple = get_aggform(agg, &aggform); - deserial = aggform->aggdeserialfn; - transtype = aggform->aggtranstype; - ReleaseSysCache(aggtuple); - if (deserial != InvalidOid) { FmgrInfo fdeserialinfo; @@ -266,7 +290,6 @@ citus_stype_deserialize(PG_FUNCTION_ARGS) box->value = FunctionCallInvoke(fdeserial_callinfo); box->value_null = fdeserial_callinfo->isnull; } - /* TODO Correct null handling */ else if (value_null) { @@ -275,7 +298,7 @@ citus_stype_deserialize(PG_FUNCTION_ARGS) } else { - transtypetuple = get_typeform(transtype, &transtypeform); + transtypetuple = get_typeform(box->transtype, &transtypeform); ioparam = getTypeIOParam(transtypetuple); recv = transtypeform->typreceive; ReleaseSysCache(transtypetuple); @@ -327,9 +350,9 @@ citus_stype_combine(PG_FUNCTION_ARGS) } box1 = pallocInAggContext(fcinfo, sizeof(StypeBox)); + memcpy(box1, box2, sizeof(StypeBox)); box1->value = (Datum) 0; box1->value_null = true; - box1->agg = box2->agg; } aggtuple = get_aggform(box1->agg, &aggform); @@ -385,6 +408,7 @@ worker_partial_agg_sfunc(PG_FUNCTION_ARGS) FmgrInfo info; int i; bool is_initial_call = PG_ARGISNULL(0); + Datum newVal; elog(WARNING, "worker_partial_agg_sfunc"); @@ -405,9 +429,15 @@ worker_partial_agg_sfunc(PG_FUNCTION_ARGS) aggsfunc = aggform->aggtransfn; if (is_initial_call) { - InitializeStypeBox(box, aggtuple, aggform->aggtranstype); + InitializeStypeBox(fcinfo, box, aggtuple, aggform->aggtranstype); } ReleaseSysCache(aggtuple); + if (is_initial_call) + { + get_typlenbyval(box->transtype, + &box->transtypeLen, + &box->transtypeByVal); + } fmgr_info(aggsfunc, &info); InitFunctionCallInfoData(*inner_fcinfo, &info, fcinfo->nargs - 1, fcinfo->fncollation, @@ -422,11 +452,26 @@ worker_partial_agg_sfunc(PG_FUNCTION_ARGS) PG_RETURN_POINTER(box); } } + if (!box->value_init) + { + MemoryContext aggregateContext; + MemoryContext oldContext; + elog(WARNING, "\tworker sfunc seed"); + if (!AggCheckCallContext(fcinfo, &aggregateContext)) + { + elog(WARNING, + "worker_partiail_agg_sfunc called from non aggregate context"); + } + oldContext = MemoryContextSwitchTo(aggregateContext); + box->value = datumCopy(PG_GETARG_DATUM(2), box->transtypeByVal, + box->transtypeLen); + MemoryContextSwitchTo(oldContext); + box->value_null = false; + box->value_init = true; + PG_RETURN_POINTER(box); + } if (box->value_null) { - elog(WARNING, "\tworker sfunc seed"); - box->value = PG_GETARG_DATUM(2); - box->value_null = false; PG_RETURN_POINTER(box); } } @@ -439,9 +484,35 @@ worker_partial_agg_sfunc(PG_FUNCTION_ARGS) i + 1)); } elog(WARNING, "invoke sfunc"); - box->value = FunctionCallInvoke(inner_fcinfo); + newVal = FunctionCallInvoke(inner_fcinfo); box->value_null = inner_fcinfo->isnull; + if (!box->value_null) + { + MemoryContext aggregateContext; + MemoryContext oldContext; + if (!AggCheckCallContext(fcinfo, &aggregateContext)) + { + elog(WARNING, "worker_partial_agg_sfunc called from non aggregate context"); + } + + oldContext = MemoryContextSwitchTo(aggregateContext); + if (!(DatumIsReadWriteExpandedObject(box->value, + false, + box->transtypeLen) && + MemoryContextGetParent(DatumGetEOHP(newVal)->eoh_context) == + CurrentMemoryContext)) + { + newVal = datumCopy(newVal, box->transtypeByVal, box->transtypeLen); + } + MemoryContextSwitchTo(oldContext); + } + + box->value = newVal; + + elog(WARNING, "\tworker sfunc ret: %ld", box->value); + elog(WARNING, "\tworker sfunc ret type: %d", AARR_ELEMTYPE(DatumGetAnyArrayP( + box->value))); elog(WARNING, "\tworker sfunc agg: %d", box->agg); elog(WARNING, "\tworker sfunc null: %d", box->value_null); @@ -483,7 +554,7 @@ worker_partial_agg_ffunc(PG_FUNCTION_ARGS) if (serial == InvalidOid) { - elog(WARNING, "\tload typeform %d", transtype); + elog(WARNING, "\tload typeform %d", box->transtype); /* TODO do we have to fallback to output/receive if not set? */ /* ie is it possible for send/recv to be unset? */ @@ -508,7 +579,14 @@ worker_partial_agg_ffunc(PG_FUNCTION_ARGS) elog(WARNING, "\t\tinvoke inner_fcinfo %p %p", info.fn_addr, array_send); result = FunctionCallInvoke(inner_fcinfo); - elog(WARNING, "\t\t& done %d", VARSIZE(DatumGetByteaPP(result))); + + elog(WARNING, "\t\t& done %ld %d", result, VARSIZE(DatumGetByteaPP(result)) - + VARHDRSZ); + for (int i = 0; i < VARSIZE(DatumGetByteaPP(result)) - VARHDRSZ; i++) + { + elog(WARNING, "\t\t%d\t%d", i, VARDATA(DatumGetByteaPP(result))[i]); + } + if (inner_fcinfo->isnull) { PG_RETURN_NULL(); @@ -534,7 +612,6 @@ coord_combine_agg_sfunc(PG_FUNCTION_ARGS) Form_pg_type transtypeform; Oid combine; Oid deserial; - Oid transtype; Oid ioparam; Datum value; bool value_null; @@ -545,24 +622,32 @@ coord_combine_agg_sfunc(PG_FUNCTION_ARGS) if (PG_ARGISNULL(0)) { box = pallocInAggContext(fcinfo, sizeof(StypeBox)); - box->agg = PG_GETARG_OID(1); - box->value = (Datum) 0; - box->value_null = true; + elog(WARNING, "\tinit"); } else { box = (StypeBox *) PG_GETARG_POINTER(0); Assert(box->agg == PG_GETARG_OID(1)); + elog(WARNING, "\talready %d", box->value_null); } elog(WARNING, "\tbox->agg = %u", box->agg); aggtuple = get_aggform(box->agg, &aggform); + if (PG_ARGISNULL(0)) + { + InitializeStypeBox(fcinfo, box, aggtuple, aggform->aggtranstype); + } deserial = aggform->aggdeserialfn; combine = aggform->aggcombinefn; - transtype = aggform->aggtranstype; ReleaseSysCache(aggtuple); + if (PG_ARGISNULL(0)) + { + get_typlenbyval(box->transtype, + &box->transtypeLen, + &box->transtypeByVal); + } value_null = PG_ARGISNULL(2); if (deserial != InvalidOid) @@ -585,7 +670,7 @@ coord_combine_agg_sfunc(PG_FUNCTION_ARGS) } else { - transtypetuple = get_typeform(transtype, &transtypeform); + transtypetuple = get_typeform(box->transtype, &transtypeform); ioparam = getTypeIOParam(transtypetuple); deserial = transtypeform->typreceive; ReleaseSysCache(transtypetuple); @@ -599,9 +684,13 @@ coord_combine_agg_sfunc(PG_FUNCTION_ARGS) { bytea *data = PG_GETARG_BYTEA_PP(2); initStringInfo(&buf); - appendBinaryStringInfo(&buf, (char *) VARDATA(data), VARSIZE(data) - - VARHDRSZ); - elog(WARNING, "\treceive %d", buf.len); + appendBinaryStringInfo(&buf, (char *) VARDATA_ANY(data), + VARSIZE_ANY_EXHDR(data)); + elog(WARNING, "\treceive %ld %d", PG_GETARG_DATUM(2), buf.len); + for (int i = 0; i < buf.len; i++) + { + elog(WARNING, "\t%d\t%d", i, buf.data[i]); + } } InitFunctionCallInfoData(*inner_fcinfo, &info, 3, fcinfo->fncollation, @@ -643,6 +732,8 @@ coord_combine_agg_sfunc(PG_FUNCTION_ARGS) } elog(WARNING, "\tcombine %u", box->agg); + elog(WARNING, "\t\t %d", *AARR_DIMS(DatumGetAnyArrayP(value))); + elog(WARNING, "\t\t %d", *AARR_DIMS(DatumGetAnyArrayP(box->value))); InitFunctionCallInfoData(*inner_fcinfo, &info, 2, fcinfo->fncollation, fcinfo->context, fcinfo->resultinfo); fcSetArgExt(inner_fcinfo, 0, box->value, box->value_null);