The custom aggregate path is now able to figure out that the average of 1, 2, 3 is 2. Next steps: clean up all the log spam, and make sure value_init is handled properly everywhere.

fix_120_custom_aggregates_distribute_multiarg
Philip Dubé 2019-08-30 23:29:10 +00:00
parent 5b25204dda
commit d1aefe240b
2 changed files with 123 additions and 32 deletions

View File

@ -2835,7 +2835,7 @@ WorkerAggregateExpressionList(Aggref *originalAggregate,
aggArguments = list_make1(makeTargetEntry((Expr *) aggparam, 1, NULL, false)); aggArguments = list_make1(makeTargetEntry((Expr *) aggparam, 1, NULL, false));
foreach(originalAggArgCell, originalAggregate->args) foreach(originalAggArgCell, originalAggregate->args)
{ {
TargetEntry *arg = (TargetEntry *) lfirst(originalAggArgCell); TargetEntry *arg = lfirst(originalAggArgCell);
TargetEntry *newArg = copyObject(arg); TargetEntry *newArg = copyObject(arg);
newArg->resno++; newArg->resno++;
aggArguments = lappend(aggArguments, newArg); aggArguments = lappend(aggArguments, newArg);

View File

@ -6,11 +6,14 @@
#include "catalog/pg_type.h" #include "catalog/pg_type.h"
#include "distributed/version_compat.h" #include "distributed/version_compat.h"
#include "utils/builtins.h" #include "utils/builtins.h"
#include "utils/datum.h"
#include "utils/lsyscache.h" #include "utils/lsyscache.h"
#include "utils/syscache.h" #include "utils/syscache.h"
#include "fmgr.h" #include "fmgr.h"
#include "pg_config_manual.h" #include "pg_config_manual.h"
#include "utils/array.h"
PG_FUNCTION_INFO_V1(citus_stype_serialize); PG_FUNCTION_INFO_V1(citus_stype_serialize);
PG_FUNCTION_INFO_V1(citus_stype_deserialize); PG_FUNCTION_INFO_V1(citus_stype_deserialize);
PG_FUNCTION_INFO_V1(citus_stype_combine); PG_FUNCTION_INFO_V1(citus_stype_combine);
@ -26,14 +29,19 @@ typedef struct StypeBox
{ {
Datum value; Datum value;
Oid agg; Oid agg;
Oid transtype;
int16_t transtypeLen;
bool transtypeByVal;
bool value_null; bool value_null;
bool value_init;
} StypeBox; } StypeBox;
static HeapTuple get_aggform(Oid oid, Form_pg_aggregate *form); static HeapTuple get_aggform(Oid oid, Form_pg_aggregate *form);
static HeapTuple get_procform(Oid oid, Form_pg_proc *form); static HeapTuple get_procform(Oid oid, Form_pg_proc *form);
static HeapTuple get_typeform(Oid oid, Form_pg_type *form); static HeapTuple get_typeform(Oid oid, Form_pg_type *form);
static void * pallocInAggContext(FunctionCallInfo fcinfo, size_t size); static void * pallocInAggContext(FunctionCallInfo fcinfo, size_t size);
static void InitializeStypeBox(StypeBox *box, HeapTuple aggTuple, Oid transtype); static void InitializeStypeBox(FunctionCallInfo fcinfo, StypeBox *box, HeapTuple aggTuple,
Oid transtype);
static HeapTuple static HeapTuple
get_aggform(Oid oid, Form_pg_aggregate *form) get_aggform(Oid oid, Form_pg_aggregate *form)
@ -96,11 +104,14 @@ pallocInAggContext(FunctionCallInfo fcinfo, size_t size)
* See GetAggInitVal from pg's nodeAgg.c * See GetAggInitVal from pg's nodeAgg.c
*/ */
static void static void
InitializeStypeBox(StypeBox *box, HeapTuple aggTuple, Oid transtype) InitializeStypeBox(FunctionCallInfo fcinfo, StypeBox *box, HeapTuple aggTuple, Oid
transtype)
{ {
Datum textInitVal = SysCacheGetAttr(AGGFNOID, aggTuple, Datum textInitVal = SysCacheGetAttr(AGGFNOID, aggTuple,
Anum_pg_aggregate_agginitval, Anum_pg_aggregate_agginitval,
&box->value_null); &box->value_null);
box->transtype = transtype;
box->value_init = !box->value_null;
if (box->value_null) if (box->value_null)
{ {
box->value = (Datum) 0; box->value = (Datum) 0;
@ -111,11 +122,22 @@ InitializeStypeBox(StypeBox *box, HeapTuple aggTuple, Oid transtype)
typioparam; typioparam;
char *strInitVal; char *strInitVal;
MemoryContext aggregateContext;
MemoryContext oldContext;
elog(WARNING, "\tworker sfunc seed");
if (!AggCheckCallContext(fcinfo, &aggregateContext))
{
elog(WARNING, "worker_partiail_agg_sfunc called from non aggregate context");
}
oldContext = MemoryContextSwitchTo(aggregateContext);
getTypeInputInfo(transtype, &typinput, &typioparam); getTypeInputInfo(transtype, &typinput, &typioparam);
strInitVal = TextDatumGetCString(textInitVal); strInitVal = TextDatumGetCString(textInitVal);
box->value = OidInputFunctionCall(typinput, strInitVal, box->value = OidInputFunctionCall(typinput, strInitVal,
typioparam, -1); typioparam, -1);
pfree(strInitVal); pfree(strInitVal);
MemoryContextSwitchTo(oldContext);
} }
} }
@ -138,7 +160,6 @@ citus_stype_serialize(PG_FUNCTION_ARGS)
bytea *valbytes; bytea *valbytes;
bytea *realbytes; bytea *realbytes;
Oid serial; Oid serial;
Oid transtype;
Size valbyteslen_exhdr; Size valbyteslen_exhdr;
Size realbyteslen; Size realbyteslen;
Datum result; Datum result;
@ -148,16 +169,15 @@ citus_stype_serialize(PG_FUNCTION_ARGS)
aggtuple = get_aggform(box->agg, &aggform); aggtuple = get_aggform(box->agg, &aggform);
serial = aggform->aggserialfn; serial = aggform->aggserialfn;
transtype = aggform->aggtranstype;
ReleaseSysCache(aggtuple); ReleaseSysCache(aggtuple);
if (serial == InvalidOid) if (serial == InvalidOid)
{ {
elog(WARNING, "\tnoserial, load %d", transtype); elog(WARNING, "\tnoserial, load %d", box->transtype);
/* TODO do we have to fallback to output/receive if not set? */ /* TODO do we have to fallback to output/receive if not set? */
/* ie is it possible for send/recv to be unset? */ /* ie is it possible for send/recv to be unset? */
transtypetuple = get_typeform(transtype, &transtypeform); transtypetuple = get_typeform(box->transtype, &transtypeform);
serial = transtypeform->typsend; serial = transtypeform->typsend;
ReleaseSysCache(transtypetuple); ReleaseSysCache(transtypetuple);
} }
@ -224,7 +244,6 @@ citus_stype_deserialize(PG_FUNCTION_ARGS)
Form_pg_aggregate aggform; Form_pg_aggregate aggform;
Form_pg_type transtypeform; Form_pg_type transtypeform;
Oid deserial; Oid deserial;
Oid transtype;
Oid ioparam; Oid ioparam;
Oid recv; Oid recv;
StringInfoData buf; StringInfoData buf;
@ -237,6 +256,16 @@ citus_stype_deserialize(PG_FUNCTION_ARGS)
box = pallocInAggContext(fcinfo, sizeof(StypeBox)); box = pallocInAggContext(fcinfo, sizeof(StypeBox));
box->agg = agg; box->agg = agg;
aggtuple = get_aggform(agg, &aggform);
deserial = aggform->aggdeserialfn;
box->transtype = aggform->aggtranstype;
ReleaseSysCache(aggtuple);
get_typlenbyval(box->transtype,
&box->transtypeLen,
&box->transtypeByVal);
if (value_null) if (value_null)
{ {
box->value = (Datum) 0; box->value = (Datum) 0;
@ -244,11 +273,6 @@ citus_stype_deserialize(PG_FUNCTION_ARGS)
PG_RETURN_POINTER(box); PG_RETURN_POINTER(box);
} }
aggtuple = get_aggform(agg, &aggform);
deserial = aggform->aggdeserialfn;
transtype = aggform->aggtranstype;
ReleaseSysCache(aggtuple);
if (deserial != InvalidOid) if (deserial != InvalidOid)
{ {
FmgrInfo fdeserialinfo; FmgrInfo fdeserialinfo;
@ -266,7 +290,6 @@ citus_stype_deserialize(PG_FUNCTION_ARGS)
box->value = FunctionCallInvoke(fdeserial_callinfo); box->value = FunctionCallInvoke(fdeserial_callinfo);
box->value_null = fdeserial_callinfo->isnull; box->value_null = fdeserial_callinfo->isnull;
} }
/* TODO Correct null handling */ /* TODO Correct null handling */
else if (value_null) else if (value_null)
{ {
@ -275,7 +298,7 @@ citus_stype_deserialize(PG_FUNCTION_ARGS)
} }
else else
{ {
transtypetuple = get_typeform(transtype, &transtypeform); transtypetuple = get_typeform(box->transtype, &transtypeform);
ioparam = getTypeIOParam(transtypetuple); ioparam = getTypeIOParam(transtypetuple);
recv = transtypeform->typreceive; recv = transtypeform->typreceive;
ReleaseSysCache(transtypetuple); ReleaseSysCache(transtypetuple);
@ -327,9 +350,9 @@ citus_stype_combine(PG_FUNCTION_ARGS)
} }
box1 = pallocInAggContext(fcinfo, sizeof(StypeBox)); box1 = pallocInAggContext(fcinfo, sizeof(StypeBox));
memcpy(box1, box2, sizeof(StypeBox));
box1->value = (Datum) 0; box1->value = (Datum) 0;
box1->value_null = true; box1->value_null = true;
box1->agg = box2->agg;
} }
aggtuple = get_aggform(box1->agg, &aggform); aggtuple = get_aggform(box1->agg, &aggform);
@ -385,6 +408,7 @@ worker_partial_agg_sfunc(PG_FUNCTION_ARGS)
FmgrInfo info; FmgrInfo info;
int i; int i;
bool is_initial_call = PG_ARGISNULL(0); bool is_initial_call = PG_ARGISNULL(0);
Datum newVal;
elog(WARNING, "worker_partial_agg_sfunc"); elog(WARNING, "worker_partial_agg_sfunc");
@ -405,9 +429,15 @@ worker_partial_agg_sfunc(PG_FUNCTION_ARGS)
aggsfunc = aggform->aggtransfn; aggsfunc = aggform->aggtransfn;
if (is_initial_call) if (is_initial_call)
{ {
InitializeStypeBox(box, aggtuple, aggform->aggtranstype); InitializeStypeBox(fcinfo, box, aggtuple, aggform->aggtranstype);
} }
ReleaseSysCache(aggtuple); ReleaseSysCache(aggtuple);
if (is_initial_call)
{
get_typlenbyval(box->transtype,
&box->transtypeLen,
&box->transtypeByVal);
}
fmgr_info(aggsfunc, &info); fmgr_info(aggsfunc, &info);
InitFunctionCallInfoData(*inner_fcinfo, &info, fcinfo->nargs - 1, fcinfo->fncollation, InitFunctionCallInfoData(*inner_fcinfo, &info, fcinfo->nargs - 1, fcinfo->fncollation,
@ -422,11 +452,26 @@ worker_partial_agg_sfunc(PG_FUNCTION_ARGS)
PG_RETURN_POINTER(box); PG_RETURN_POINTER(box);
} }
} }
if (!box->value_init)
{
MemoryContext aggregateContext;
MemoryContext oldContext;
elog(WARNING, "\tworker sfunc seed");
if (!AggCheckCallContext(fcinfo, &aggregateContext))
{
elog(WARNING,
"worker_partiail_agg_sfunc called from non aggregate context");
}
oldContext = MemoryContextSwitchTo(aggregateContext);
box->value = datumCopy(PG_GETARG_DATUM(2), box->transtypeByVal,
box->transtypeLen);
MemoryContextSwitchTo(oldContext);
box->value_null = false;
box->value_init = true;
PG_RETURN_POINTER(box);
}
if (box->value_null) if (box->value_null)
{ {
elog(WARNING, "\tworker sfunc seed");
box->value = PG_GETARG_DATUM(2);
box->value_null = false;
PG_RETURN_POINTER(box); PG_RETURN_POINTER(box);
} }
} }
@ -439,9 +484,35 @@ worker_partial_agg_sfunc(PG_FUNCTION_ARGS)
i + 1)); i + 1));
} }
elog(WARNING, "invoke sfunc"); elog(WARNING, "invoke sfunc");
box->value = FunctionCallInvoke(inner_fcinfo); newVal = FunctionCallInvoke(inner_fcinfo);
box->value_null = inner_fcinfo->isnull; box->value_null = inner_fcinfo->isnull;
if (!box->value_null)
{
MemoryContext aggregateContext;
MemoryContext oldContext;
if (!AggCheckCallContext(fcinfo, &aggregateContext))
{
elog(WARNING, "worker_partial_agg_sfunc called from non aggregate context");
}
oldContext = MemoryContextSwitchTo(aggregateContext);
if (!(DatumIsReadWriteExpandedObject(box->value,
false,
box->transtypeLen) &&
MemoryContextGetParent(DatumGetEOHP(newVal)->eoh_context) ==
CurrentMemoryContext))
{
newVal = datumCopy(newVal, box->transtypeByVal, box->transtypeLen);
}
MemoryContextSwitchTo(oldContext);
}
box->value = newVal;
elog(WARNING, "\tworker sfunc ret: %ld", box->value);
elog(WARNING, "\tworker sfunc ret type: %d", AARR_ELEMTYPE(DatumGetAnyArrayP(
box->value)));
elog(WARNING, "\tworker sfunc agg: %d", box->agg); elog(WARNING, "\tworker sfunc agg: %d", box->agg);
elog(WARNING, "\tworker sfunc null: %d", box->value_null); elog(WARNING, "\tworker sfunc null: %d", box->value_null);
@ -483,7 +554,7 @@ worker_partial_agg_ffunc(PG_FUNCTION_ARGS)
if (serial == InvalidOid) if (serial == InvalidOid)
{ {
elog(WARNING, "\tload typeform %d", transtype); elog(WARNING, "\tload typeform %d", box->transtype);
/* TODO do we have to fallback to output/receive if not set? */ /* TODO do we have to fallback to output/receive if not set? */
/* ie is it possible for send/recv to be unset? */ /* ie is it possible for send/recv to be unset? */
@ -508,7 +579,14 @@ worker_partial_agg_ffunc(PG_FUNCTION_ARGS)
elog(WARNING, "\t\tinvoke inner_fcinfo %p %p", info.fn_addr, array_send); elog(WARNING, "\t\tinvoke inner_fcinfo %p %p", info.fn_addr, array_send);
result = FunctionCallInvoke(inner_fcinfo); result = FunctionCallInvoke(inner_fcinfo);
elog(WARNING, "\t\t& done %d", VARSIZE(DatumGetByteaPP(result)));
elog(WARNING, "\t\t& done %ld %d", result, VARSIZE(DatumGetByteaPP(result)) -
VARHDRSZ);
for (int i = 0; i < VARSIZE(DatumGetByteaPP(result)) - VARHDRSZ; i++)
{
elog(WARNING, "\t\t%d\t%d", i, VARDATA(DatumGetByteaPP(result))[i]);
}
if (inner_fcinfo->isnull) if (inner_fcinfo->isnull)
{ {
PG_RETURN_NULL(); PG_RETURN_NULL();
@ -534,7 +612,6 @@ coord_combine_agg_sfunc(PG_FUNCTION_ARGS)
Form_pg_type transtypeform; Form_pg_type transtypeform;
Oid combine; Oid combine;
Oid deserial; Oid deserial;
Oid transtype;
Oid ioparam; Oid ioparam;
Datum value; Datum value;
bool value_null; bool value_null;
@ -545,24 +622,32 @@ coord_combine_agg_sfunc(PG_FUNCTION_ARGS)
if (PG_ARGISNULL(0)) if (PG_ARGISNULL(0))
{ {
box = pallocInAggContext(fcinfo, sizeof(StypeBox)); box = pallocInAggContext(fcinfo, sizeof(StypeBox));
box->agg = PG_GETARG_OID(1); box->agg = PG_GETARG_OID(1);
box->value = (Datum) 0; elog(WARNING, "\tinit");
box->value_null = true;
} }
else else
{ {
box = (StypeBox *) PG_GETARG_POINTER(0); box = (StypeBox *) PG_GETARG_POINTER(0);
Assert(box->agg == PG_GETARG_OID(1)); Assert(box->agg == PG_GETARG_OID(1));
elog(WARNING, "\talready %d", box->value_null);
} }
elog(WARNING, "\tbox->agg = %u", box->agg); elog(WARNING, "\tbox->agg = %u", box->agg);
aggtuple = get_aggform(box->agg, &aggform); aggtuple = get_aggform(box->agg, &aggform);
if (PG_ARGISNULL(0))
{
InitializeStypeBox(fcinfo, box, aggtuple, aggform->aggtranstype);
}
deserial = aggform->aggdeserialfn; deserial = aggform->aggdeserialfn;
combine = aggform->aggcombinefn; combine = aggform->aggcombinefn;
transtype = aggform->aggtranstype;
ReleaseSysCache(aggtuple); ReleaseSysCache(aggtuple);
if (PG_ARGISNULL(0))
{
get_typlenbyval(box->transtype,
&box->transtypeLen,
&box->transtypeByVal);
}
value_null = PG_ARGISNULL(2); value_null = PG_ARGISNULL(2);
if (deserial != InvalidOid) if (deserial != InvalidOid)
@ -585,7 +670,7 @@ coord_combine_agg_sfunc(PG_FUNCTION_ARGS)
} }
else else
{ {
transtypetuple = get_typeform(transtype, &transtypeform); transtypetuple = get_typeform(box->transtype, &transtypeform);
ioparam = getTypeIOParam(transtypetuple); ioparam = getTypeIOParam(transtypetuple);
deserial = transtypeform->typreceive; deserial = transtypeform->typreceive;
ReleaseSysCache(transtypetuple); ReleaseSysCache(transtypetuple);
@ -599,9 +684,13 @@ coord_combine_agg_sfunc(PG_FUNCTION_ARGS)
{ {
bytea *data = PG_GETARG_BYTEA_PP(2); bytea *data = PG_GETARG_BYTEA_PP(2);
initStringInfo(&buf); initStringInfo(&buf);
appendBinaryStringInfo(&buf, (char *) VARDATA(data), VARSIZE(data) - appendBinaryStringInfo(&buf, (char *) VARDATA_ANY(data),
VARHDRSZ); VARSIZE_ANY_EXHDR(data));
elog(WARNING, "\treceive %d", buf.len); elog(WARNING, "\treceive %ld %d", PG_GETARG_DATUM(2), buf.len);
for (int i = 0; i < buf.len; i++)
{
elog(WARNING, "\t%d\t%d", i, buf.data[i]);
}
} }
InitFunctionCallInfoData(*inner_fcinfo, &info, 3, fcinfo->fncollation, InitFunctionCallInfoData(*inner_fcinfo, &info, 3, fcinfo->fncollation,
@ -643,6 +732,8 @@ coord_combine_agg_sfunc(PG_FUNCTION_ARGS)
} }
elog(WARNING, "\tcombine %u", box->agg); elog(WARNING, "\tcombine %u", box->agg);
elog(WARNING, "\t\t %d", *AARR_DIMS(DatumGetAnyArrayP(value)));
elog(WARNING, "\t\t %d", *AARR_DIMS(DatumGetAnyArrayP(box->value)));
InitFunctionCallInfoData(*inner_fcinfo, &info, 2, fcinfo->fncollation, InitFunctionCallInfoData(*inner_fcinfo, &info, 2, fcinfo->fncollation,
fcinfo->context, fcinfo->resultinfo); fcinfo->context, fcinfo->resultinfo);
fcSetArgExt(inner_fcinfo, 0, box->value, box->value_null); fcSetArgExt(inner_fcinfo, 0, box->value, box->value_null);