diff --git a/src/backend/distributed/commands/function.c b/src/backend/distributed/commands/function.c index b5c00fd9a..a82770f69 100644 --- a/src/backend/distributed/commands/function.c +++ b/src/backend/distributed/commands/function.c @@ -44,6 +44,7 @@ #include "distributed/multi_logical_optimizer.h" #include "distributed/relation_access_tracking.h" #include "distributed/worker_transaction.h" +#include "executor/spi.h" #include "parser/parse_coerce.h" #include "parser/parse_type.h" #include "storage/lmgr.h" @@ -62,7 +63,7 @@ static char * GetFunctionDDLCommand(const RegProcedure funcOid); static char * GetFunctionAlterOwnerCommand(const RegProcedure funcOid); static void CreateAggregateHelper(const char *helperName, const char *helperPrefix, Form_pg_proc proc, Form_pg_aggregate agg, int numargs, - Oid *argtypes); + Oid *argtypes, bool finalextra); static int GetDistributionArgIndex(Oid functionOid, char *distributionArgumentName, Oid *distributionArgumentOid); static int GetFunctionColocationId(Oid functionOid, char *colocateWithName, Oid @@ -238,6 +239,7 @@ mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS) int numargs = 0; Oid *argtypes = NULL; + elog(WARNING, "prelude"); if (!HeapTupleIsValid(proctup)) { @@ -257,6 +259,8 @@ mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS) } agg = (Form_pg_aggregate) GETSTRUCT(aggtup); + elog(WARNING, "begins"); + initStringInfo(&helperSuffix); appendStringInfoAggregateHelperSuffix(&helperSuffix, proc, agg); @@ -264,26 +268,80 @@ mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS) initStringInfo(&helperName); appendStringInfo(&helperName, "%s%s", COORD_COMBINE_AGGREGATE_NAME, helperSuffix.data); - helperOid = AggregateHelperOid(helperName.data, proctup, &numargs, &argtypes); + helperOid = CoordCombineAggOid(helperName.data); + + elog(WARNING, "helperOid %d", helperOid); if (helperOid == InvalidOid) { + Oid coordArgTypes[2] = { BYTEAOID, ANYELEMENTOID }; + CreateAggregateHelper(helperName.data, COORD_COMBINE_AGGREGATE_NAME, proc, agg, - numargs, argtypes); + 2, coordArgTypes, true); } /* worker_partial_agg */ resetStringInfo(&helperName); appendStringInfo(&helperName, "%s%s", WORKER_PARTIAL_AGGREGATE_NAME, helperSuffix.data); - helperOid = AggregateHelperOid(helperName.data, proctup, &numargs, &argtypes); + helperOid = WorkerPartialAggOid(helperName.data, proctup, &numargs, &argtypes); if (helperOid == InvalidOid) { + /* Also check that we have a matching worker_partial_sfunc(internal, oid, ...) */ + FuncCandidateList clist = FuncnameGetCandidates(list_make2(makeString("citus"), + makeString( + "worker_partial_agg_sfunc")), + numargs + 2, + NIL, false, false, true); + + for (; clist; clist = clist->next) + { + if (clist->args[0] == INTERNALOID && clist->args[1] == OIDOID && + memcmp(clist->args + 2, argtypes, numargs * sizeof(Oid)) == 0) + { + break; + return clist->oid; + } + } + + if (clist == NULL) + { + int i; + StringInfoData command; + + initStringInfo(&command); + appendStringInfoString(&command, + "CREATE FUNCTION citus.worker_partial_agg_sfunc(internal, oid"); + for (i = 0; i < numargs; i++) + { + appendStringInfo(&command, ", %s", format_type_be_qualified(argtypes[i])); + } + appendStringInfoString(&command, + ") RETURNS internal AS 'citus' LANGUAGE C PARALLEL SAFE"); + + SendCommandToWorkers(ALL_WORKERS, command.data); + + /* TODO execute as CitusExtensionOwner */ + if (SPI_connect() != SPI_OK_CONNECT) + { + elog(ERROR, "SPI_connect failed"); + } + if (SPI_execute(command.data, false, 0) < 0) + { + elog(ERROR, "SPI_execute %s failed", command.data); + } + SPI_finish(); + + pfree(command.data); + } + CreateAggregateHelper(helperName.data, WORKER_PARTIAL_AGGREGATE_NAME, proc, agg, - numargs, argtypes); + numargs, argtypes, false); } + elog(WARNING, "mark em"); + /* set strategy column value */ UpdateDistObjectAggregationStrategy(funcOid, AGGREGATION_STRATEGY_COMBINE); @@ -302,10 +360,36 @@ early_exit: /* - * AggregateHelperOid returns helper aggregate oid for given proc's HeapTuple + * CoordCombineAggOid returns coord_combine_agg oid with given name. */ Oid -AggregateHelperOid(char *helperName, HeapTuple proctup, int *numargs, Oid **argtypes) +CoordCombineAggOid(char *helperName) +{ + FuncCandidateList clist; + + clist = FuncnameGetCandidates(list_make2(makeString("citus"), + makeString(helperName)), + 3, NIL, false, false, true); + + for (; clist; clist = clist->next) + { + if (clist->args[0] == OIDOID && + clist->args[1] == BYTEAOID && + clist->args[2] == ANYELEMENTOID) + { + return clist->oid; + } + } + + return InvalidOid; +} + + +/* + * WorkerPartialAggOid returns worker_partial_agg oid for given proc's HeapTuple. + */ +Oid +WorkerPartialAggOid(char *helperName, HeapTuple proctup, int *numargs, Oid **argtypes) { char **argnames = NULL; char *argmodes = NULL; @@ -318,8 +402,8 @@ AggregateHelperOid(char *helperName, HeapTuple proctup, int *numargs, Oid **argt for (; clist; clist = clist->next) { - if (clist->args[0] == OIDOID && memcmp(clist->args + 1, argtypes, *numargs * - sizeof(Oid)) == 0) + if (clist->args[0] == OIDOID && + memcmp(clist->args + 1, *argtypes, *numargs * sizeof(Oid)) == 0) { return clist->oid; } @@ -332,7 +416,9 @@ AggregateHelperOid(char *helperName, HeapTuple proctup, int *numargs, Oid **argt /* * AggregateHelperName returns helper function name for a given aggregate. */ -void appendStringInfoAggregateHelperSuffix(StringInfo helperSuffix, Form_pg_proc proc, Form_pg_aggregate agg) +void +appendStringInfoAggregateHelperSuffix(StringInfo helperSuffix, Form_pg_proc proc, + Form_pg_aggregate agg) { switch (proc->proparallel) { @@ -377,11 +463,6 @@ void appendStringInfoAggregateHelperSuffix(StringInfo helperSuffix, Form_pg_proc break; } } - - if (agg->aggfinalextra) - { - appendStringInfoString(helperSuffix, "_fx"); - } } if (agg->aggmfinalfn != InvalidOid) @@ -414,7 +495,6 @@ void appendStringInfoAggregateHelperSuffix(StringInfo helperSuffix, Form_pg_proc appendStringInfoString(helperSuffix, "_fx"); } } - } @@ -424,18 +504,23 @@ void appendStringInfoAggregateHelperSuffix(StringInfo helperSuffix, Form_pg_proc static void CreateAggregateHelper(const char *helperName, const char *helperPrefix, Form_pg_proc proc, Form_pg_aggregate agg, int numargs, - Oid *argtypes) + Oid *argtypes, bool finalextra) { + int i; StringInfoData command; initStringInfo(&command); - appendStringInfo(&command, "CREATE AGGREGATE %s(" - "STYPE = internal, SFUNC = citus.%s_sfunc, FINALFUNC = %s_ffunc" - ", COMBINEFUNC = citus.citus_stype_combine" - ", SERIALFUNC = citus.citus_stype_serialize" - ", DESERIALFUNC = citus.citus_stype_deserialize", - quote_qualified_identifier("citus", helperName), + appendStringInfo(&command, "CREATE AGGREGATE citus.%s(oid", helperName); + for (i = 0; i < numargs; i++) + { + appendStringInfo(&command, ", %s", format_type_be_qualified(argtypes[i])); + } + appendStringInfo(&command, + ") (STYPE = internal, SFUNC = citus.%s_sfunc, FINALFUNC = citus.%s_ffunc" + ", COMBINEFUNC = citus.citus_stype_combine" + ", SERIALFUNC = citus.citus_stype_serialize" + ", DESERIALFUNC = citus.citus_stype_deserialize", helperPrefix, helperPrefix); switch (proc->proparallel) @@ -482,45 +567,30 @@ CreateAggregateHelper(const char *helperName, const char *helperPrefix, } } - if (agg->aggfinalextra) + if (finalextra) { appendStringInfoString(&command, ", FINALFUNC_EXTRA"); } } - if (agg->aggmfinalfn != InvalidOid) - { - switch (agg->aggmfinalmodify) - { - case AGGMODIFY_READ_ONLY: - { - appendStringInfoString(&command, ", MFINALFUNC_MODIFY = READ_ONLY"); - break; - } - - case AGGMODIFY_SHAREABLE: - { - appendStringInfoString(&command, ", MFINALFUNC_MODIFY = SHAREABLE"); - break; - } - - case AGGMODIFY_READ_WRITE: - { - appendStringInfoString(&command, ", MFINALFUNC_MODIFY = READ_WRITE"); - break; - } - } - - if (agg->aggmfinalextra) - { - appendStringInfoString(&command, ", MFINALFUNC_EXTRA"); - } - } - appendStringInfoChar(&command, ')'); + elog(WARNING, "SEND %s", command.data); SendCommandToWorkers(ALL_WORKERS, command.data); + /* TODO execute as CitusExtensionOwner */ + if (SPI_connect() != SPI_OK_CONNECT) + { + elog(ERROR, "SPI_connect failed"); + } + if (SPI_execute(command.data, false, 0) < 0) + { + elog(ERROR, "SPI_execute %s failed", command.data); + } + SPI_finish(); + + elog(WARNING, "SENT"); + pfree(command.data); } @@ -805,10 +875,10 @@ UpdateDistObjectAggregationStrategy(Oid funcOid, int aggregationStrategy) memset(replace, 0, sizeof(replace)); - replace[Anum_pg_dist_object_distribution_argument_index - 1] = true; - values[Anum_pg_dist_object_distribution_argument_index - 1] = Int32GetDatum( + replace[Anum_pg_dist_object_aggregation_strategy - 1] = true; + values[Anum_pg_dist_object_aggregation_strategy - 1] = Int32GetDatum( aggregationStrategy); - isnull[Anum_pg_dist_object_distribution_argument_index - 1] = false; + isnull[Anum_pg_dist_object_aggregation_strategy - 1] = false; heapTuple = heap_modify_tuple(heapTuple, tupleDescriptor, values, isnull, replace); @@ -821,6 +891,8 @@ UpdateDistObjectAggregationStrategy(Oid funcOid, int aggregationStrategy) systable_endscan(scanDescriptor); heap_close(pgDistObjectRel, NoLock); + + elog(WARNING, "marked %d %d", funcOid, aggregationStrategy); } @@ -1148,7 +1220,7 @@ GetAggregateDDLCommand(const RegProcedure funcOid) appendStringInfo(&buf, "%s ", quote_identifier(argname)); } - appendStringInfoString(&buf, format_type_be(argtype)); + appendStringInfoString(&buf, format_type_be_qualified(argtype)); argsprinted++; diff --git a/src/backend/distributed/planner/multi_logical_optimizer.c b/src/backend/distributed/planner/multi_logical_optimizer.c index 28347733a..bc4117684 100644 --- a/src/backend/distributed/planner/multi_logical_optimizer.c +++ b/src/backend/distributed/planner/multi_logical_optimizer.c @@ -255,7 +255,9 @@ static List * WorkerAggregateExpressionList(Aggref *originalAggregate, static AggregateType GetAggregateType(Oid aggFunctionId); static Oid AggregateArgumentType(Aggref *aggregate); static bool AggregateEnabledCustom(Oid aggregateOid); -static Oid AggregateFunctionHelperOid(const char *helperPrefix, Oid aggOid); +static HeapTuple AggregateFunctionHelperOidHelper(Oid aggOid, StringInfo helperName); +static Oid AggregateCoordCombineOid(Oid aggOid); +static Oid AggregateWorkerPartialOid(Oid aggOid); static Oid AggregateFunctionOid(const char *functionName, Oid inputType); static Oid TypeOid(Oid schemaId, const char *typeName); static SortGroupClause * CreateSortGroupClause(Var *column); @@ -1852,6 +1854,7 @@ MasterAggregateExpression(Aggref *originalAggregate, ObjectIdGetDatum(originalAggregate->aggfnoid)); if (!HeapTupleIsValid(aggTuple)) { + elog(WARNING, "!@#"); elog(WARNING, "citus cache lookup failed for aggregate %u", originalAggregate->aggfnoid); combine = InvalidOid; @@ -1869,13 +1872,18 @@ MasterAggregateExpression(Aggref *originalAggregate, Var *column = NULL; List *aggArguments = NIL; Aggref *newMasterAggregate = NULL; - Oid coordCombineId = AggregateFunctionHelperOid( - COORD_COMBINE_AGGREGATE_NAME, originalAggregate->aggfnoid); - + Oid coordCombineId = AggregateCoordCombineOid(originalAggregate->aggfnoid); Oid workerReturnType = BYTEAOID; int32 workerReturnTypeMod = -1; Oid workerCollationId = InvalidOid; + if (coordCombineId == InvalidOid) + { + elog(ERROR, + "Could not find " COORD_COMBINE_AGGREGATE_NAME + " with correct signature."); + } + aggparam = makeConst(OIDOID, -1, InvalidOid, sizeof(Oid), ObjectIdGetDatum( originalAggregate->aggfnoid), false, true); column = makeVar(masterTableId, walkerContext->columnId, workerReturnType, @@ -1939,7 +1947,6 @@ MasterAggregateExpression(Aggref *originalAggregate, newMasterExpression = (Expr *) newMasterAggregate; } - /* * Aggregate functions could have changed the return type. If so, we wrap * the new expression with a conversion function to make it have the same @@ -2932,6 +2939,7 @@ WorkerAggregateExpressionList(Aggref *originalAggregate, ObjectIdGetDatum(originalAggregate->aggfnoid)); if (!HeapTupleIsValid(aggTuple)) { + elog(WARNING, "!3434"); elog(WARNING, "citus cache lookup failed for aggregate %u", originalAggregate->aggfnoid); combine = InvalidOid; @@ -2949,8 +2957,14 @@ WorkerAggregateExpressionList(Aggref *originalAggregate, Aggref *newWorkerAggregate = NULL; List *aggArguments = NIL; ListCell *originalAggArgCell; - Oid workerPartialId = AggregateFunctionHelperOid( - WORKER_PARTIAL_AGGREGATE_NAME, originalAggregate->aggfnoid); + Oid workerPartialId = AggregateWorkerPartialOid(originalAggregate->aggfnoid); + + if (workerPartialId == InvalidOid) + { + elog(ERROR, + "Could not find " WORKER_PARTIAL_AGGREGATE_NAME + " with correct signature."); + } aggparam = makeConst(REGPROCEDUREOID, -1, InvalidOid, sizeof(Oid), ObjectIdGetDatum(originalAggregate->aggfnoid), false, @@ -3066,11 +3080,11 @@ AggregateArgumentType(Aggref *aggregate) static bool AggregateEnabledCustom(Oid aggregateOid) { - DistObjectCacheEntry *cacheEntry = LookupDistObjectCacheEntry(AggregateRelationId, + DistObjectCacheEntry *cacheEntry = LookupDistObjectCacheEntry(ProcedureRelationId, aggregateOid, 0); - return cacheEntry != NULL && cacheEntry->aggregationStrategy == - AGGREGATION_STRATEGY_COMBINE; + return cacheEntry != NULL && cacheEntry->isDistributed && + cacheEntry->aggregationStrategy == AGGREGATION_STRATEGY_COMBINE; } @@ -3138,12 +3152,9 @@ AggregateFunctionOid(const char *functionName, Oid inputType) /* * AggregateFunctionHelperOid finds the aggregate helper for a given aggregate. */ -static Oid -AggregateFunctionHelperOid(const char *helperPrefix, Oid aggOid) +static HeapTuple +AggregateFunctionHelperOidHelper(Oid aggOid, StringInfo helperName) { - StringInfoData helperName; - int numargs; - Oid *argtypes; HeapTuple proctup; HeapTuple aggtup; Form_pg_proc proc; @@ -3152,22 +3163,73 @@ AggregateFunctionHelperOid(const char *helperPrefix, Oid aggOid) proctup = SearchSysCache1(PROCOID, aggOid); if (!HeapTupleIsValid(proctup)) { - return InvalidOid; + return NULL; } proc = (Form_pg_proc) GETSTRUCT(proctup); aggtup = SearchSysCache1(AGGFNOID, aggOid); if (!HeapTupleIsValid(aggtup)) { - return InvalidOid; + ReleaseSysCache(proctup); + return NULL; } agg = (Form_pg_aggregate) GETSTRUCT(aggtup); - initStringInfo(&helperName); - appendStringInfoString(&helperName, helperPrefix); - appendStringInfoAggregateHelperSuffix(&helperName, proc, agg); + appendStringInfoAggregateHelperSuffix(helperName, proc, agg); - return AggregateHelperOid(helperName.data, proctup, &numargs, &argtypes); + ReleaseSysCache(aggtup); + return proctup; +} + + +/* + * AggregateFunctionHelperOid finds the aggregate helper for a given aggregate. + */ +static Oid +AggregateCoordCombineOid(Oid aggOid) +{ + StringInfoData helperName; + HeapTuple proctup; + + initStringInfo(&helperName); + appendStringInfoString(&helperName, COORD_COMBINE_AGGREGATE_NAME); + proctup = AggregateFunctionHelperOidHelper(aggOid, &helperName); + + if (proctup == NULL) + { + elog(ERROR, "Failed to locate appropriate " COORD_COMBINE_AGGREGATE_NAME); + } + + ReleaseSysCache(proctup); + return CoordCombineAggOid(helperName.data); +} + + +/* + * AggregateFunctionHelperOid finds the aggregate helper for a given aggregate. + */ +static Oid +AggregateWorkerPartialOid(Oid aggOid) +{ + Oid result; + int numargs; + Oid *argtypes; + StringInfoData helperName; + HeapTuple proctup; + + initStringInfo(&helperName); + appendStringInfoString(&helperName, WORKER_PARTIAL_AGGREGATE_NAME); + proctup = AggregateFunctionHelperOidHelper(aggOid, &helperName); + + if (proctup == NULL) + { + elog(ERROR, "Failed to locate appropriate " WORKER_PARTIAL_AGGREGATE_NAME); + } + + result = WorkerPartialAggOid(helperName.data, proctup, &numargs, &argtypes); + + ReleaseSysCache(proctup); + return result; } diff --git a/src/backend/distributed/sql/citus--9.0-1--9.0-customagg.sql b/src/backend/distributed/sql/citus--9.0-1--9.0-customagg.sql index 25a3600b2..ab0cd013c 100644 --- a/src/backend/distributed/sql/citus--9.0-1--9.0-customagg.sql +++ b/src/backend/distributed/sql/citus--9.0-1--9.0-customagg.sql @@ -1,6 +1,6 @@ SET search_path = 'pg_catalog'; -CREATE FUNCTION mark_aggregate_for_distributed_execution(internal) +CREATE FUNCTION mark_aggregate_for_distributed_execution(regprocedure) RETURNS void AS 'MODULE_PATHNAME' LANGUAGE C STRICT IMMUTABLE PARALLEL SAFE; @@ -22,11 +22,6 @@ RETURNS internal AS 'MODULE_PATHNAME' LANGUAGE C PARALLEL SAFE; -CREATE FUNCTION worker_partial_agg_sfunc(internal, oid, anyelement) -RETURNS internal -AS 'MODULE_PATHNAME' -LANGUAGE C PARALLEL SAFE; - CREATE FUNCTION worker_partial_agg_ffunc(internal) RETURNS bytea AS 'MODULE_PATHNAME' diff --git a/src/include/distributed/commands.h b/src/include/distributed/commands.h index 52de313b1..1225ba30e 100644 --- a/src/include/distributed/commands.h +++ b/src/include/distributed/commands.h @@ -53,7 +53,9 @@ extern bool ConstraintIsAForeignKey(char *constraintName, Oid relationId); extern void appendStringInfoAggregateHelperSuffix(StringInfo helperSuffix, Form_pg_proc proc, Form_pg_aggregate agg); -extern Oid AggregateHelperOid(char *helperName, HeapTuple proctup, int *numargs, Oid **argtypes); +extern Oid CoordCombineAggOid(char *helperName); +extern Oid WorkerPartialAggOid(char *helperName, HeapTuple proctup, int *numargs, + Oid **argtypes); extern List * PlanCreateFunctionStmt(CreateFunctionStmt *stmt, const char *queryString); extern List * ProcessCreateFunctionStmt(CreateFunctionStmt *stmt, const char *queryString);