From 1545cb312c6ea53a1860b3ffaade9b355871426a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Tue, 15 Oct 2019 02:13:32 +0000 Subject: [PATCH] Sketched out rest of how mark_aggregate_for_distributed_execution should look like --- src/backend/distributed/commands/function.c | 369 +++++++++++++++++- .../planner/multi_logical_optimizer.c | 20 +- .../sql/citus--9.0-1--9.0-customagg.sql | 34 +- .../distributed/utils/metadata_cache.c | 2 + .../distributed/metadata/pg_dist_object.h | 10 +- src/include/distributed/metadata_cache.h | 2 + 6 files changed, 399 insertions(+), 38 deletions(-) diff --git a/src/backend/distributed/commands/function.c b/src/backend/distributed/commands/function.c index 7fa02d9f7..d2e7bb09a 100644 --- a/src/backend/distributed/commands/function.c +++ b/src/backend/distributed/commands/function.c @@ -41,6 +41,7 @@ #include "distributed/metadata/pg_dist_object.h" #include "distributed/metadata_sync.h" #include "distributed/multi_executor.h" +#include "distributed/multi_logical_optimizer.h" #include "distributed/relation_access_tracking.h" #include "distributed/worker_transaction.h" #include "parser/parse_coerce.h" @@ -59,6 +60,9 @@ static char * GetAggregateDDLCommand(const RegProcedure funcOid); static char * GetFunctionDDLCommand(const RegProcedure funcOid); static char * GetFunctionAlterOwnerCommand(const RegProcedure funcOid); +static void CreateAggregateHelper(const char *helperName, const char *helperPrefix, + Form_pg_proc proc, Form_pg_aggregate agg, int numargs, + Oid *argtypes); static int GetDistributionArgIndex(Oid functionOid, char *distributionArgumentName, Oid *distributionArgumentOid); static int GetFunctionColocationId(Oid functionOid, char *colocateWithName, Oid @@ -66,6 +70,7 @@ static int GetFunctionColocationId(Oid functionOid, char *colocateWithName, Oid static void EnsureFunctionCanBeColocatedWithTable(Oid functionOid, Oid distributionColumnType, Oid sourceRelationId); +static void UpdateDistObjectAggregationStrategy(Oid funcOid, int aggregationStrategy); static void UpdateFunctionDistributionInfo(const ObjectAddress *distAddress, int *distribution_argument_index, int *colocationId); @@ -82,6 +87,7 @@ static char * quote_qualified_func_name(Oid funcOid); PG_FUNCTION_INFO_V1(create_distributed_function); +PG_FUNCTION_INFO_v1(mark_aggregate_for_distributed_execution); #define AssertIsFunctionOrProcedure(objtype) \ Assert((objtype) == OBJECT_FUNCTION || (objtype) == OBJECT_PROCEDURE || (objtype) == \ @@ -212,6 +218,305 @@ create_distributed_function(PG_FUNCTION_ARGS) } +/* + * mark_aggregate_for_distributed_execution(regproc) signals the aggregate is safe to + * execute across worker nodes. This requires superuser because an aggregate + * which does not support distributed execution has undefined behavior. + * Also makes sure necessary helper functions exist across nodes. + */ +Datum +mark_aggregate_for_distributed_execution(PG_FUNCTION_ARGS) +{ + RegProcedure funcOid = PG_GETARG_OID(0); + StringInfoData helperName; + StringInfoData helperSuffix; + Form_pg_proc proc = NULL; + Form_pg_aggregate agg = NULL; + HeapTuple proctup = SearchSysCache1(PROCOID, funcOid); + HeapTuple aggtup = NULL; + int numargs = 0; + Oid *argtypes = NULL; + char **argnames = NULL; + char *argmodes = NULL; + FuncCandidateList clist; + + if (!HeapTupleIsValid(proctup)) + { + goto early_exit; + } + proc = (Form_pg_proc) GETSTRUCT(proctup); + + if (proc->prokind != PROKIND_AGGREGATE) + { + goto early_exit; + } + + aggtup = SearchSysCache1(AGGFNOID, funcOid); + if (!HeapTupleIsValid(aggtup)) + { + goto early_exit; + } + agg = (Form_pg_aggregate) GETSTRUCT(aggtup); + + initStringInfo(&helperSuffix); + + switch (proc->proparallel) + { + case PROPARALLEL_SAFE: + { + appendStringInfoString(&helperSuffix, "_ps"); + break; + } + + case PROPARALLEL_RESTRICTED: + { + appendStringInfoString(&helperSuffix, "_pr"); + break; + } + + case PROPARALLEL_UNSAFE: + { + appendStringInfoString(&helperSuffix, "_pu"); + break; + } + } + + if (agg->aggfinalfn != InvalidOid) + { + switch (agg->aggfinalmodify) + { + case AGGMODIFY_READ_ONLY: + { + appendStringInfoString(&helperSuffix, "_ro"); + break; + } + + case AGGMODIFY_SHAREABLE: + { + appendStringInfoString(&helperSuffix, "_rs"); + break; + } + + case AGGMODIFY_READ_WRITE: + { + appendStringInfoString(&helperSuffix, "_rw"); + break; + } + } + + if (agg->aggfinalextra) + { + appendStringInfoString(&helperSuffix, "_fx"); + } + } + + if (agg->aggmfinalfn != InvalidOid) + { + appendStringInfoString(&helperSuffix, "_m"); + + switch (agg->aggmfinalmodify) + { + case AGGMODIFY_READ_ONLY: + { + appendStringInfoString(&helperSuffix, "_ro"); + break; + } + + case AGGMODIFY_SHAREABLE: + { + appendStringInfoString(&helperSuffix, "_rs"); + break; + } + + case AGGMODIFY_READ_WRITE: + { + appendStringInfoString(&helperSuffix, "_rw"); + break; + } + } + + if (agg->aggmfinalextra) + { + appendStringInfoString(&helperSuffix, "_fx"); + } + } + + /* Parameters, borrows heavily from print_function_arguments in postgres */ + + /* coordinator_combine_agg */ + initStringInfo(&helperName); + appendStringInfo(&helperName, "%s%s", COORD_COMBINE_AGGREGATE_NAME, + helperSuffix.data); + + numargs = get_func_arg_info(proctup, &argtypes, &argnames, &argmodes); + clist = FuncnameGetCandidates(list_make2(makeString("citus"), makeString( + helperName.data)), proc->pronargs + 1, + NIL, false, false, true); + + for (; clist; clist = clist->next) + { + if (clist->args[0] == OIDOID && memcmp(clist->args + 1, argtypes, numargs * + sizeof(numargs)) == 0) + { + break; + } + } + + if (clist == NULL) + { + CreateAggregateHelper(helperName.data, COORD_COMBINE_AGGREGATE_NAME, proc, agg, + numargs, argtypes); + } + + /* worker_partial_agg */ + resetStringInfo(&helperName); + appendStringInfo(&helperName, "%s%s", WORKER_PARTIAL_AGGREGATE_NAME, + helperSuffix.data); + + numargs = get_func_arg_info(proctup, &argtypes, &argnames, &argmodes); + clist = FuncnameGetCandidates(list_make2(makeString("citus"), makeString( + helperName.data)), proc->pronargs + 1, + NIL, false, false, true); + + for (; clist; clist = clist->next) + { + if (clist->args[0] == OIDOID && memcmp(clist->args + 1, argtypes, numargs * + sizeof(numargs)) == 0) + { + break; + } + } + + if (clist == NULL) + { + CreateAggregateHelper(helperName.data, WORKER_PARTIAL_AGGREGATE_NAME, proc, agg, + numargs, argtypes); + } + + /* set strategy column value */ + UpdateDistObjectAggregationStrategy(funcOid, AGGREGATION_STRATEGY_COMBINE); + +early_exit: + if (aggtup && HeapTupleIsValid(aggtup)) + { + ReleaseSysCache(aggtup); + } + if (proctup && HeapTupleIsValid(proctup)) + { + ReleaseSysCache(proctup); + } + + PG_RETURN_VOID(); +} + + +/* + * CreateAggregateHelper creates helper aggregates across nodes + */ +static void +CreateAggregateHelper(const char *helperName, const char *helperPrefix, + Form_pg_proc proc, Form_pg_aggregate agg, int numargs, + Oid *argtypes) +{ + StringInfoData command; + + initStringInfo(&command); + + appendStringInfo(&command, "CREATE AGGREGATE %s(" + "STYPE = internal, SFUNC = citus.%s_sfunc, FINALFUNC = %s_ffunc" + ", COMBINEFUNC = citus.citus_stype_combine" + ", SERIALFUNC = citus.citus_stype_serialize" + ", DESERIALFUNC = citus.citus_stype_deserialize", + quote_qualified_identifier("citus", helperName), + helperPrefix, helperPrefix); + + switch (proc->proparallel) + { + case PROPARALLEL_SAFE: + { + appendStringInfoString(&command, ", PARALLEL = SAFE"); + break; + } + + case PROPARALLEL_RESTRICTED: + { + appendStringInfoString(&command, ", PARALLEL = RESTRICTED"); + break; + } + + case PROPARALLEL_UNSAFE: + { + appendStringInfoString(&command, ", PARALLEL = UNSAFE"); + break; + } + } + + if (agg->aggfinalfn != InvalidOid) + { + switch (agg->aggfinalmodify) + { + case AGGMODIFY_READ_ONLY: + { + appendStringInfoString(&command, ", FINALFUNC_MODIFY = READ_ONLY"); + break; + } + + case AGGMODIFY_SHAREABLE: + { + appendStringInfoString(&command, ", FINALFUNC_MODIFY = SHAREABLE"); + break; + } + + case AGGMODIFY_READ_WRITE: + { + appendStringInfoString(&command, ", FINALFUNC = READ_WRITE"); + break; + } + } + + if (agg->aggfinalextra) + { + appendStringInfoString(&command, ", FINALFUNC_EXTRA"); + } + } + + if (agg->aggmfinalfn != InvalidOid) + { + switch (agg->aggmfinalmodify) + { + case AGGMODIFY_READ_ONLY: + { + appendStringInfoString(&command, ", MFINALFUNC_MODIFY = READ_ONLY"); + break; + } + + case AGGMODIFY_SHAREABLE: + { + appendStringInfoString(&command, ", MFINALFUNC_MODIFY = SHAREABLE"); + break; + } + + case AGGMODIFY_READ_WRITE: + { + appendStringInfoString(&command, ", MFINALFUNC_MODIFY = READ_WRITE"); + break; + } + } + + if (agg->aggmfinalextra) + { + appendStringInfoString(&command, ", MFINALFUNC_EXTRA"); + } + } + + appendStringInfoChar(&command, ')'); + + SendCommandToWorkers(ALL_WORKERS, command.data); + + pfree(command.data); +} + + /* * CreateFunctionDDLCommandsIdempotent returns a list of DDL statements (const char *) to be * executed on a node to recreate the function addressed by the functionAddress. @@ -449,6 +754,68 @@ EnsureFunctionCanBeColocatedWithTable(Oid functionOid, Oid distributionColumnTyp } +/* + * UpdateFunctionDistributionInfo gets object address of a function and + * updates its distribution_argument_index and colocationId in pg_dist_object. + */ +static void +UpdateDistObjectAggregationStrategy(Oid funcOid, int aggregationStrategy) +{ + const bool indexOK = true; + + Relation pgDistObjectRel = NULL; + TupleDesc tupleDescriptor = NULL; + ScanKeyData scanKey[3]; + SysScanDesc scanDescriptor = NULL; + HeapTuple heapTuple = NULL; + Datum values[Natts_pg_dist_object]; + bool isnull[Natts_pg_dist_object]; + bool replace[Natts_pg_dist_object]; + + pgDistObjectRel = heap_open(DistObjectRelationId(), RowExclusiveLock); + tupleDescriptor = RelationGetDescr(pgDistObjectRel); + + /* scan pg_dist_object for classid = $1 AND objid = $2 AND objsubid = $3 via index */ + ScanKeyInit(&scanKey[0], Anum_pg_dist_object_classid, BTEqualStrategyNumber, F_OIDEQ, + ObjectIdGetDatum(ProcedureRelationId)); + ScanKeyInit(&scanKey[1], Anum_pg_dist_object_objid, BTEqualStrategyNumber, F_OIDEQ, + ObjectIdGetDatum(funcOid)); + ScanKeyInit(&scanKey[2], Anum_pg_dist_object_objsubid, BTEqualStrategyNumber, + F_INT4EQ, ObjectIdGetDatum(0)); + + scanDescriptor = systable_beginscan(pgDistObjectRel, DistObjectPrimaryKeyIndexId(), + indexOK, + NULL, 3, scanKey); + + heapTuple = systable_getnext(scanDescriptor); + if (!HeapTupleIsValid(heapTuple)) + { + ereport(ERROR, (errmsg("could not find valid entry for \"%d,%d,%d\" " + "in pg_dist_object", ProcedureRelationId, + funcOid, 0))); + } + + memset(replace, 0, sizeof(replace)); + + replace[Anum_pg_dist_object_distribution_argument_index - 1] = true; + values[Anum_pg_dist_object_distribution_argument_index - 1] = Int32GetDatum( + aggregationStrategy); + isnull[Anum_pg_dist_object_distribution_argument_index - 1] = false; + + heapTuple = heap_modify_tuple(heapTuple, tupleDescriptor, values, isnull, replace); + + CatalogTupleUpdate(pgDistObjectRel, &heapTuple->t_self, heapTuple); + + CitusInvalidateRelcacheByRelid(DistObjectRelationId()); + + CommandCounterIncrement(); + + systable_endscan(scanDescriptor); + + heap_close(pgDistObjectRel, NoLock); +} + + /* * UpdateFunctionDistributionInfo gets object address of a function and * updates its distribution_argument_index and colocationId in pg_dist_object. @@ -487,7 +854,7 @@ UpdateFunctionDistributionInfo(const ObjectAddress *distAddress, heapTuple = systable_getnext(scanDescriptor); if (!HeapTupleIsValid(heapTuple)) { - ereport(ERROR, (errmsg("could not find valid entry for node \"%d,%d,%d\" " + ereport(ERROR, (errmsg("could not find valid entry for \"%d,%d,%d\" " "in pg_dist_object", distAddress->classId, distAddress->objectId, distAddress->objectSubId))); } diff --git a/src/backend/distributed/planner/multi_logical_optimizer.c b/src/backend/distributed/planner/multi_logical_optimizer.c index 8e755c20b..865280a1a 100644 --- a/src/backend/distributed/planner/multi_logical_optimizer.c +++ b/src/backend/distributed/planner/multi_logical_optimizer.c @@ -24,6 +24,7 @@ #include "catalog/pg_proc.h" #include "catalog/pg_type.h" #include "commands/extension.h" +#include "distributed/metadata/pg_dist_object.h" #include "distributed/citus_nodes.h" #include "distributed/citus_ruleutils.h" #include "distributed/colocation_utils.h" @@ -253,7 +254,7 @@ static List * WorkerAggregateExpressionList(Aggref *originalAggregate, static AggregateType GetAggregateType(Oid aggFunctionId); static Oid AggregateArgumentType(Aggref *aggregate); static bool AggregateEnabledCustom(Oid aggregateOid); -static Oid AggregateFunctionOidWithoutInput(const char *functionName); +static Oid AggregateFunctionHelperOid(const char *helperPrefix, Oid aggOid); static Oid AggregateFunctionOid(const char *functionName, Oid inputType); static Oid TypeOid(Oid schemaId, const char *typeName); static SortGroupClause * CreateSortGroupClause(Var *column); @@ -1867,8 +1868,8 @@ MasterAggregateExpression(Aggref *originalAggregate, Var *column = NULL; List *aggArguments = NIL; Aggref *newMasterAggregate = NULL; - Oid coordCombineId = AggregateFunctionOidWithoutInput( - COORD_COMBINE_AGGREGATE_NAME); + Oid coordCombineId = AggregateFunctionHelperOid( + COORD_COMBINE_AGGREGATE_NAME, originalAggregate->aggfnoid); Oid workerReturnType = BYTEAOID; int32 workerReturnTypeMod = -1; @@ -2947,8 +2948,8 @@ WorkerAggregateExpressionList(Aggref *originalAggregate, Aggref *newWorkerAggregate = NULL; List *aggArguments = NIL; ListCell *originalAggArgCell; - Oid workerPartialId = AggregateFunctionOidWithoutInput( - WORKER_PARTIAL_AGGREGATE_NAME); + Oid workerPartialId = AggregateFunctionHelperOid( + WORKER_PARTIAL_AGGREGATE_NAME, originalAggregate->aggfnoid); aggparam = makeConst(REGPROCEDUREOID, -1, InvalidOid, sizeof(Oid), ObjectIdGetDatum(originalAggregate->aggfnoid), false, @@ -3067,7 +3068,8 @@ AggregateEnabledCustom(Oid aggregateOid) DistObjectCacheEntry *cacheEntry = LookupDistObjectCacheEntry(AggregateRelationId, aggregateOid, 0); - return cacheEntry != NULL; + return cacheEntry != NULL && cacheEntry->aggregationStrategy == + AGGREGATION_STRATEGY_COMBINE; } @@ -3133,12 +3135,10 @@ AggregateFunctionOid(const char *functionName, Oid inputType) /* - * AggregateFunctionOid performs a reverse lookup on aggregate function name, - * and returns the corresponding aggregate function oid for the given function - * name and input type. + * AggregateFunctionHelperOid finds the aggregate helper for a given aggregate. */ static Oid -AggregateFunctionOidWithoutInput(const char *functionName) +AggregateFunctionHelperOid(const char *helperPrefix, Oid aggOid) { Oid functionOid = InvalidOid; Relation procRelation = NULL; diff --git a/src/backend/distributed/sql/citus--9.0-1--9.0-customagg.sql b/src/backend/distributed/sql/citus--9.0-1--9.0-customagg.sql index 35f2e8d43..25a3600b2 100644 --- a/src/backend/distributed/sql/citus--9.0-1--9.0-customagg.sql +++ b/src/backend/distributed/sql/citus--9.0-1--9.0-customagg.sql @@ -1,5 +1,12 @@ SET search_path = 'pg_catalog'; +CREATE FUNCTION mark_aggregate_for_distributed_execution(internal) +RETURNS void +AS 'MODULE_PATHNAME' +LANGUAGE C STRICT IMMUTABLE PARALLEL SAFE; + +SET search_path = 'citus'; + CREATE FUNCTION citus_stype_serialize(internal) RETURNS bytea AS 'MODULE_PATHNAME' @@ -35,31 +42,6 @@ RETURNS anyelement AS 'MODULE_PATHNAME' LANGUAGE C PARALLEL SAFE; --- select worker_partial_agg(agg, ...) --- equivalent to --- select serialize_stype(agg_without_ffunc(...)) -CREATE AGGREGATE worker_partial_agg(oid, anyelement) ( - STYPE = internal, - SFUNC = worker_partial_agg_sfunc, - FINALFUNC = worker_partial_agg_ffunc, - COMBINEFUNC = citus_stype_combine, - SERIALFUNC = citus_stype_serialize, - DESERIALFUNC = citus_stype_deserialize, - PARALLEL = SAFE -); - --- select coord_combine_agg(agg, col) --- equivalent to --- select agg_ffunc(agg_combine(col)) -CREATE AGGREGATE coord_combine_agg(oid, bytea, anyelement) ( - STYPE = internal, - SFUNC = coord_combine_agg_sfunc, - FINALFUNC = coord_combine_agg_ffunc, - FINALFUNC_EXTRA, - COMBINEFUNC = citus_stype_combine, - SERIALFUNC = citus_stype_serialize, - DESERIALFUNC = citus_stype_deserialize, - PARALLEL = SAFE -); +ALTER TABLE pg_dist_object ADD aggregation_strategy int; RESET search_path; diff --git a/src/backend/distributed/utils/metadata_cache.c b/src/backend/distributed/utils/metadata_cache.c index 3672f7371..afaf9bbf8 100644 --- a/src/backend/distributed/utils/metadata_cache.c +++ b/src/backend/distributed/utils/metadata_cache.c @@ -998,6 +998,8 @@ LookupDistObjectCacheEntry(Oid classid, Oid objid, int32 objsubid) 1]); cacheEntry->colocationId = DatumGetInt32(datumArray[Anum_pg_dist_object_colocationid - 1]); + cacheEntry->aggregationStrategy = + DatumGetInt32(datumArray[Anum_pg_dist_object_aggregation_strategy - 1]); } else { diff --git a/src/include/distributed/metadata/pg_dist_object.h b/src/include/distributed/metadata/pg_dist_object.h index 442bb408b..4d41d4d01 100644 --- a/src/include/distributed/metadata/pg_dist_object.h +++ b/src/include/distributed/metadata/pg_dist_object.h @@ -34,7 +34,8 @@ typedef struct FormData_pg_dist_object text[] object_arguments; uint32 distribution_argument_index; /* only valid for distributed functions/procedures */ - uint32 colocationid; /* only valid for distributed functions/procedures */ + uint32 colocationid; /* only valid for distributed functions/procedures */ + uint32 aggregation_strategy; /* only valid for distributed aggregates */ #endif } FormData_pg_dist_object; @@ -58,5 +59,12 @@ typedef FormData_pg_dist_object *Form_pg_dist_object; #define Anum_pg_dist_object_object_args 6 #define Anum_pg_dist_object_distribution_argument_index 7 #define Anum_pg_dist_object_colocationid 8 +#define Anum_pg_dist_object_aggregation_strategy 9 + +/* + * Values for aggregation_strategy + */ +#define AGGREGATION_STRATEGY_NONE 0 +#define AGGREGATION_STRATEGY_COMBINE 1 #endif /* PG_DIST_OBJECT_H */ diff --git a/src/include/distributed/metadata_cache.h b/src/include/distributed/metadata_cache.h index ce2d628ea..9a0272e33 100644 --- a/src/include/distributed/metadata_cache.h +++ b/src/include/distributed/metadata_cache.h @@ -114,6 +114,8 @@ typedef struct DistObjectCacheEntry int distributionArgIndex; int colocationId; + + int aggregationStrategy; } DistObjectCacheEntry;