From 9b3260f4dff8472ceadb65432c0d6acb1548ddf8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Wed, 16 Oct 2019 22:05:55 +0000 Subject: [PATCH] AGGREGATE_STRATEGY_COMMUTE --- src/backend/distributed/commands/function.c | 205 ++++++++++-------- .../planner/multi_logical_optimizer.c | 41 ++-- .../sql/citus--9.0-1--9.0-customagg.sql | 2 +- .../distributed/metadata/pg_dist_object.h | 1 + .../distributed/multi_logical_optimizer.h | 6 +- 5 files changed, 149 insertions(+), 106 deletions(-) diff --git a/src/backend/distributed/commands/function.c b/src/backend/distributed/commands/function.c index 656dfda28..48aab1541 100644 --- a/src/backend/distributed/commands/function.c +++ b/src/backend/distributed/commands/function.c @@ -102,7 +102,7 @@ PG_FUNCTION_INFO_V1(mark_aggregate_for_distributed_execution); Datum create_distributed_function(PG_FUNCTION_ARGS) { - RegProcedure funcOid = PG_GETARG_OID(0); + RegProcedure funcOid = PG_ARGISNULL(0) ? InvalidOid : PG_GETARG_OID(0); text *distributionArgumentNameText = NULL; /* optional */ text *colocateWithText = NULL; /* optional */ @@ -228,115 +228,146 @@ create_distributed_function(PG_FUNCTION_ARGS) Datum mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS) { - RegProcedure funcOid = PG_GETARG_OID(0); + RegProcedure funcOid = PG_ARGISNULL(0) ? InvalidOid : PG_GETARG_OID(0); StringInfoData helperName; StringInfoData helperSuffix; Form_pg_proc proc = NULL; Form_pg_aggregate agg = NULL; - HeapTuple proctup = SearchSysCache1(PROCOID, funcOid); + HeapTuple proctup = NULL; HeapTuple aggtup = NULL; Oid helperOid = InvalidOid; int numargs = 0; Oid *argtypes = NULL; + int aggregationStrategy = -1; - - if (!HeapTupleIsValid(proctup)) + if (!PG_ARGISNULL(1)) { - goto early_exit; - } - proc = (Form_pg_proc) GETSTRUCT(proctup); - - if (proc->prokind != PROKIND_AGGREGATE) - { - goto early_exit; - } - - aggtup = SearchSysCache1(AGGFNOID, funcOid); - if (!HeapTupleIsValid(aggtup)) - { - goto early_exit; - } - agg = (Form_pg_aggregate) GETSTRUCT(aggtup); - - initStringInfo(&helperSuffix); - appendStringInfoAggregateHelperSuffix(&helperSuffix, proc, agg); - - /* coordinator_combine_agg */ - initStringInfo(&helperName); - appendStringInfo(&helperName, "%s%s", COORD_COMBINE_AGGREGATE_NAME, - helperSuffix.data); - helperOid = CoordCombineAggOid(helperName.data); - - if (helperOid == InvalidOid) - { - Oid coordArgTypes[2] = { BYTEAOID, ANYELEMENTOID }; - - CreateAggregateHelper(helperName.data, COORD_COMBINE_AGGREGATE_NAME, proc, agg, - 2, coordArgTypes, true); - } - - /* worker_partial_agg */ - resetStringInfo(&helperName); - appendStringInfo(&helperName, "%s%s", WORKER_PARTIAL_AGGREGATE_NAME, - helperSuffix.data); - helperOid = WorkerPartialAggOid(helperName.data, proctup, &numargs, &argtypes); - - if (helperOid == InvalidOid) - { - /* Also check that we have a matching worker_partial_sfunc(internal, oid, ...) */ - FuncCandidateList clist = FuncnameGetCandidates(list_make2(makeString("citus"), - makeString( - "worker_partial_agg_sfunc")), - numargs + 2, - NIL, false, false, true); - - for (; clist; clist = clist->next) + char *strategyParam = TextDatumGetCString(PG_GETARG_TEXT_P(1)); + if (strcmp(strategyParam, "none") == 0) { - if (clist->args[0] == INTERNALOID && clist->args[1] == OIDOID && - memcmp(clist->args + 2, argtypes, numargs * sizeof(Oid)) == 0) - { - break; - return clist->oid; - } + aggregationStrategy = AGGREGATION_STRATEGY_NONE; + } + else if (strcmp(strategyParam, "combine") == 0) + { + aggregationStrategy = AGGREGATION_STRATEGY_COMBINE; + } + else if (strcmp(strategyParam, "commute") == 0) + { + aggregationStrategy = AGGREGATION_STRATEGY_COMMUTE; + } + } + + if (aggregationStrategy == -1) + { + elog(ERROR, + "mark_aggregate_for_distributed_execution expects a strategy that is one of: 'none', 'combine', 'commute'"); + } + + if (aggregationStrategy == AGGREGATION_STRATEGY_COMBINE) + { + proctup = SearchSysCache1(PROCOID, funcOid); + if (!HeapTupleIsValid(proctup)) + { + goto early_exit; + } + proc = (Form_pg_proc) GETSTRUCT(proctup); + + if (proc->prokind != PROKIND_AGGREGATE) + { + goto early_exit; } - if (clist == NULL) + aggtup = SearchSysCache1(AGGFNOID, funcOid); + if (!HeapTupleIsValid(aggtup)) { - int i; - StringInfoData command; + goto early_exit; + } + agg = (Form_pg_aggregate) GETSTRUCT(aggtup); - initStringInfo(&command); - appendStringInfoString(&command, - "CREATE FUNCTION citus.worker_partial_agg_sfunc(internal, oid"); - for (i = 0; i < numargs; i++) - { - appendStringInfo(&command, ", %s", format_type_be_qualified(argtypes[i])); - } - appendStringInfoString(&command, - ") RETURNS internal AS 'citus' LANGUAGE C PARALLEL SAFE"); + initStringInfo(&helperSuffix); + appendStringInfoAggregateHelperSuffix(&helperSuffix, proc, agg); - SendCommandToWorkers(ALL_WORKERS, command.data); + /* coordinator_combine_agg */ + initStringInfo(&helperName); + appendStringInfo(&helperName, "%s%s", COORD_COMBINE_AGGREGATE_NAME, + helperSuffix.data); + helperOid = CoordCombineAggOid(helperName.data); - /* TODO execute as CitusExtensionOwner */ - if (SPI_connect() != SPI_OK_CONNECT) - { - elog(ERROR, "SPI_connect failed"); - } - if (SPI_execute(command.data, false, 0) < 0) - { - elog(ERROR, "SPI_execute %s failed", command.data); - } - SPI_finish(); + if (helperOid == InvalidOid) + { + Oid coordArgTypes[2] = { BYTEAOID, ANYELEMENTOID }; - pfree(command.data); + CreateAggregateHelper(helperName.data, COORD_COMBINE_AGGREGATE_NAME, proc, + agg, + 2, coordArgTypes, true); } - CreateAggregateHelper(helperName.data, WORKER_PARTIAL_AGGREGATE_NAME, proc, agg, - numargs, argtypes, false); + /* worker_partial_agg */ + resetStringInfo(&helperName); + appendStringInfo(&helperName, "%s%s", WORKER_PARTIAL_AGGREGATE_NAME, + helperSuffix.data); + helperOid = WorkerPartialAggOid(helperName.data, proctup, &numargs, &argtypes); + + if (helperOid == InvalidOid) + { + /* Also check that we have a matching worker_partial_sfunc(internal, oid, ...) */ + FuncCandidateList clist = FuncnameGetCandidates(list_make2(makeString( + "citus"), + makeString( + "worker_partial_agg_sfunc")), + numargs + 2, + NIL, false, false, true); + + for (; clist; clist = clist->next) + { + if (clist->args[0] == INTERNALOID && clist->args[1] == OIDOID && + memcmp(clist->args + 2, argtypes, numargs * sizeof(Oid)) == 0) + { + break; + return clist->oid; + } + } + + if (clist == NULL) + { + int i; + StringInfoData command; + + initStringInfo(&command); + appendStringInfoString(&command, + "CREATE FUNCTION citus.worker_partial_agg_sfunc(internal, oid"); + for (i = 0; i < numargs; i++) + { + appendStringInfo(&command, ", %s", format_type_be_qualified( + argtypes[i])); + } + appendStringInfoString(&command, + ") RETURNS internal AS 'citus' LANGUAGE C PARALLEL SAFE"); + + SendCommandToWorkers(ALL_WORKERS, command.data); + + /* TODO execute as CitusExtensionOwner */ + if (SPI_connect() != SPI_OK_CONNECT) + { + elog(ERROR, "SPI_connect failed"); + } + if (SPI_execute(command.data, false, 0) < 0) + { + elog(ERROR, "SPI_execute %s failed", command.data); + } + SPI_finish(); + + pfree(command.data); + } + + CreateAggregateHelper(helperName.data, WORKER_PARTIAL_AGGREGATE_NAME, proc, + agg, + numargs, argtypes, false); + } } /* set strategy column value */ - UpdateDistObjectAggregationStrategy(funcOid, AGGREGATION_STRATEGY_COMBINE); + UpdateDistObjectAggregationStrategy(funcOid, aggregationStrategy); early_exit: if (aggtup && HeapTupleIsValid(aggtup)) diff --git a/src/backend/distributed/planner/multi_logical_optimizer.c b/src/backend/distributed/planner/multi_logical_optimizer.c index 5bd0aea30..0e176260e 100644 --- a/src/backend/distributed/planner/multi_logical_optimizer.c +++ b/src/backend/distributed/planner/multi_logical_optimizer.c @@ -254,7 +254,7 @@ static List * WorkerAggregateExpressionList(Aggref *originalAggregate, WorkerAggregateWalkerContext *walkerContextry); static AggregateType GetAggregateType(Oid aggFunctionId); static Oid AggregateArgumentType(Aggref *aggregate); -static bool AggregateEnabledCustom(Oid aggregateOid); +static int AggregateEnabledCustom(Oid aggregateOid); static HeapTuple AggregateFunctionHelperOidHelper(Oid aggOid, StringInfo helperName); static Oid AggregateCoordCombineOid(Oid aggOid); static Oid AggregateWorkerPartialOid(Oid aggOid); @@ -1848,7 +1848,7 @@ MasterAggregateExpression(Aggref *originalAggregate, newMasterExpression = (Expr *) unionAggregate; } - else if (aggregateType == AGGREGATE_CUSTOM) + else if (aggregateType == AGGREGATE_CUSTOM_COMBINE) { aggTuple = SearchSysCache1(AGGFNOID, ObjectIdGetDatum(originalAggregate->aggfnoid)); @@ -1869,12 +1869,14 @@ MasterAggregateExpression(Aggref *originalAggregate, { Const *aggparam = NULL; Var *column = NULL; + Const *nulltag = NULL; List *aggArguments = NIL; Aggref *newMasterAggregate = NULL; Oid coordCombineId = AggregateCoordCombineOid(originalAggregate->aggfnoid); Oid workerReturnType = BYTEAOID; int32 workerReturnTypeMod = -1; Oid workerCollationId = InvalidOid; + Oid resultType = exprType((Node *) originalAggregate); if (coordCombineId == InvalidOid) { @@ -1887,10 +1889,13 @@ MasterAggregateExpression(Aggref *originalAggregate, originalAggregate->aggfnoid), false, true); column = makeVar(masterTableId, walkerContext->columnId, workerReturnType, workerReturnTypeMod, workerCollationId, columnLevelsUp); + nulltag = makeNullConst(resultType, -1, InvalidOid); aggArguments = list_make1(makeTargetEntry((Expr *) aggparam, 1, NULL, false)); aggArguments = lappend(aggArguments, makeTargetEntry((Expr *) column, 2, NULL, false)); + aggArguments = lappend(aggArguments, makeTargetEntry((Expr *) nulltag, 3, + NULL, false)); /* coord_combine_agg(agg, workercol) */ newMasterAggregate = makeNode(Aggref); @@ -1900,8 +1905,8 @@ MasterAggregateExpression(Aggref *originalAggregate, newMasterAggregate->aggkind = AGGKIND_NORMAL; newMasterAggregate->aggfilter = originalAggregate->aggfilter; newMasterAggregate->aggtranstype = INTERNALOID; - newMasterAggregate->aggargtypes = list_concat(list_make1_oid(OIDOID), - list_make1_oid(BYTEAOID)); + newMasterAggregate->aggargtypes = list_make3_oid(OIDOID, BYTEAOID, + resultType); newMasterAggregate->aggsplit = AGGSPLIT_SIMPLE; newMasterExpression = (Expr *) newMasterAggregate; @@ -1925,14 +1930,8 @@ MasterAggregateExpression(Aggref *originalAggregate, int32 workerReturnTypeMod = exprTypmod((Node *) originalAggregate); Oid workerCollationId = exprCollation((Node *) originalAggregate); - const char *aggregateName = AggregateNames[aggregateType]; - Oid aggregateFunctionId = AggregateFunctionOid(aggregateName, workerReturnType); - Oid masterReturnType = get_func_rettype(aggregateFunctionId); - Aggref *newMasterAggregate = copyObject(originalAggregate); newMasterAggregate->aggdistinct = NULL; - newMasterAggregate->aggfnoid = aggregateFunctionId; - newMasterAggregate->aggtype = masterReturnType; newMasterAggregate->aggfilter = NULL; column = makeVar(masterTableId, walkerContext->columnId, workerReturnType, @@ -2932,7 +2931,7 @@ WorkerAggregateExpressionList(Aggref *originalAggregate, workerAggregateList = lappend(workerAggregateList, sumAggregate); workerAggregateList = lappend(workerAggregateList, countAggregate); } - else if (aggregateType == AGGREGATE_CUSTOM) + else if (aggregateType == AGGREGATE_CUSTOM_COMBINE) { aggTuple = SearchSysCache1(AGGFNOID, ObjectIdGetDatum(originalAggregate->aggfnoid)); @@ -2990,11 +2989,16 @@ WorkerAggregateExpressionList(Aggref *originalAggregate, workerAggregateList = list_make1(newWorkerAggregate); } + else + { + elog(ERROR, "Aggregate lacks COMBINEFUNC"); + } } else { /* * All other aggregates are sent as they are to the worker nodes. + * This includes AGGREGATE_CUSTOM_COMMUTE. */ Aggref *workerAggregate = copyObject(originalAggregate); workerAggregateList = lappend(workerAggregateList, workerAggregate); @@ -3023,6 +3027,7 @@ GetAggregateType(Oid aggFunctionId) uint32 aggregateCount = 0; uint32 aggregateIndex = 0; bool found = false; + int customAggregationStrategy; /* look up the function name */ aggregateProcName = get_func_name(aggFunctionId); @@ -3032,9 +3037,10 @@ GetAggregateType(Oid aggFunctionId) aggFunctionId))); } - if (AggregateEnabledCustom(aggFunctionId)) + customAggregationStrategy = AggregateEnabledCustom(aggFunctionId); + if (customAggregationStrategy != AGGREGATION_STRATEGY_NONE) { - return AGGREGATE_CUSTOM; + return AGGREGATE_CUSTOM + customAggregationStrategy; } aggregateCount = lengthof(AggregateNames); @@ -3075,14 +3081,17 @@ AggregateArgumentType(Aggref *aggregate) } -static bool +static int AggregateEnabledCustom(Oid aggregateOid) { DistObjectCacheEntry *cacheEntry = LookupDistObjectCacheEntry(ProcedureRelationId, aggregateOid, 0); - return cacheEntry != NULL && cacheEntry->isDistributed && - cacheEntry->aggregationStrategy == AGGREGATION_STRATEGY_COMBINE; + if (cacheEntry == NULL || !cacheEntry->isDistributed) + { + return AGGREGATION_STRATEGY_NONE; + } + return cacheEntry->aggregationStrategy; } diff --git a/src/backend/distributed/sql/citus--9.0-1--9.0-customagg.sql b/src/backend/distributed/sql/citus--9.0-1--9.0-customagg.sql index ab0cd013c..ef4218cc4 100644 --- a/src/backend/distributed/sql/citus--9.0-1--9.0-customagg.sql +++ b/src/backend/distributed/sql/citus--9.0-1--9.0-customagg.sql @@ -1,6 +1,6 @@ SET search_path = 'pg_catalog'; -CREATE FUNCTION mark_aggregate_for_distributed_execution(regprocedure) +CREATE FUNCTION mark_aggregate_for_distributed_execution(regprocedure, strategy text) RETURNS void AS 'MODULE_PATHNAME' LANGUAGE C STRICT IMMUTABLE PARALLEL SAFE; diff --git a/src/include/distributed/metadata/pg_dist_object.h b/src/include/distributed/metadata/pg_dist_object.h index 6926c8316..8ec5025ca 100644 --- a/src/include/distributed/metadata/pg_dist_object.h +++ b/src/include/distributed/metadata/pg_dist_object.h @@ -66,5 +66,6 @@ typedef FormData_pg_dist_object *Form_pg_dist_object; */ #define AGGREGATION_STRATEGY_NONE 0 #define AGGREGATION_STRATEGY_COMBINE 1 +#define AGGREGATION_STRATEGY_COMMUTE 2 #endif /* PG_DIST_OBJECT_H */ diff --git a/src/include/distributed/multi_logical_optimizer.h b/src/include/distributed/multi_logical_optimizer.h index bd6bb4918..3628aead6 100644 --- a/src/include/distributed/multi_logical_optimizer.h +++ b/src/include/distributed/multi_logical_optimizer.h @@ -79,8 +79,10 @@ typedef enum AGGREGATE_TOPN_ADD_AGG = 18, AGGREGATE_TOPN_UNION_AGG = 19, - /* AGGREGATE_CUSTOM must come last */ - AGGREGATE_CUSTOM = 20 + /* AGGREGATE_CUSTOM must come last. AGGREGATE_CUSTOM + AGGREGATION_STRATEGY = AGGREGATE_CUSTOM_STRATEGY */ + AGGREGATE_CUSTOM = 20, + AGGREGATE_CUSTOM_COMBINE = 21, + AGGREGATE_CUSTOM_COMMUTE = 22 } AggregateType; /*