From dd5916ef3aa512cb4c62f896ee8a8d2c1b66c7a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Tue, 10 Sep 2019 19:14:28 +0000 Subject: [PATCH] Move custom aggregates into a not-default path, add pg_dist_enabled_custom_aggregates to have a set of allowed aggregates as it's possible for aggregates to not be compatible with this implementation --- .../planner/multi_logical_optimizer.c | 260 ++++++++++-------- .../sql/citus--9.0-1--9.0-customagg.sql | 6 + .../distributed/utils/aggregate_utils.c | 11 +- .../distributed/utils/metadata_cache.c | 12 + src/include/distributed/metadata_cache.h | 1 + .../distributed/multi_logical_optimizer.h | 6 +- .../pg_dist_enabled_custom_aggregates.h | 17 ++ 7 files changed, 193 insertions(+), 120 deletions(-) create mode 100644 src/include/distributed/pg_dist_enabled_custom_aggregates.h diff --git a/src/backend/distributed/planner/multi_logical_optimizer.c b/src/backend/distributed/planner/multi_logical_optimizer.c index 36bc38f4f..361080f04 100644 --- a/src/backend/distributed/planner/multi_logical_optimizer.c +++ b/src/backend/distributed/planner/multi_logical_optimizer.c @@ -34,6 +34,7 @@ #include "distributed/multi_logical_planner.h" #include "distributed/multi_physical_planner.h" #include "distributed/pg_dist_partition.h" +#include "distributed/pg_dist_enabled_custom_aggregates.h" #include "distributed/worker_protocol.h" #include "distributed/version_compat.h" #include "nodes/makefuncs.h" @@ -252,6 +253,7 @@ static List * WorkerAggregateExpressionList(Aggref *originalAggregate, WorkerAggregateWalkerContext *walkerContextry); static AggregateType GetAggregateType(Oid aggFunctionId); static Oid AggregateArgumentType(Aggref *aggregate); +static bool AggregateEnabledCustom(const char *functionName); static Oid AggregateFunctionOidWithoutInput(const char *functionName); static Oid AggregateFunctionOid(const char *functionName, Oid inputType); static Oid TypeOid(Oid schemaId, const char *typeName); @@ -1513,60 +1515,9 
@@ MasterAggregateExpression(Aggref *originalAggregate, Form_pg_aggregate aggform; Oid combine; - aggTuple = SearchSysCache1(AGGFNOID, - ObjectIdGetDatum(originalAggregate->aggfnoid)); - if (!HeapTupleIsValid(aggTuple)) - { - elog(WARNING, "citus cache lookup failed for aggregate %u", - originalAggregate->aggfnoid); - combine = InvalidOid; - } - else - { - aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple); - combine = aggform->aggcombinefn; - ReleaseSysCache(aggTuple); - } - - if (combine != InvalidOid) - { - Const *aggparam = NULL; - Var *column = NULL; - List *aggArguments = NIL; - Aggref *newMasterAggregate = NULL; - Oid coordCombineId = AggregateFunctionOidWithoutInput( - COORD_COMBINE_AGGREGATE_NAME); - - Oid workerReturnType = BYTEAOID; - int32 workerReturnTypeMod = -1; - Oid workerCollationId = InvalidOid; - - aggparam = makeConst(OIDOID, -1, InvalidOid, sizeof(Oid), ObjectIdGetDatum( - originalAggregate->aggfnoid), false, true); - column = makeVar(masterTableId, walkerContext->columnId, workerReturnType, - workerReturnTypeMod, workerCollationId, columnLevelsUp); - - aggArguments = list_make1(makeTargetEntry((Expr *) aggparam, 1, NULL, false)); - aggArguments = lappend(aggArguments, makeTargetEntry((Expr *) column, 2, NULL, - false)); - - /* coord_combine_agg(agg, workercol) */ - newMasterAggregate = makeNode(Aggref); - newMasterAggregate->aggfnoid = coordCombineId; - newMasterAggregate->aggtype = originalAggregate->aggtype; - newMasterAggregate->args = aggArguments; - newMasterAggregate->aggkind = AGGKIND_NORMAL; - newMasterAggregate->aggfilter = originalAggregate->aggfilter; - newMasterAggregate->aggtranstype = INTERNALOID; - newMasterAggregate->aggargtypes = list_concat(list_make1_oid(OIDOID), - list_make1_oid(BYTEAOID)); - newMasterAggregate->aggsplit = AGGSPLIT_SIMPLE; - - newMasterExpression = (Expr *) newMasterAggregate; - } - else if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct && - CountDistinctErrorRate == 
DISABLE_DISTINCT_APPROXIMATION && - walkerContext->pullDistinctColumns) + if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct && + CountDistinctErrorRate == DISABLE_DISTINCT_APPROXIMATION && + walkerContext->pullDistinctColumns) { Aggref *aggregate = (Aggref *) copyObject(originalAggregate); List *varList = pull_var_clause_default((Node *) aggregate); @@ -1894,6 +1845,60 @@ MasterAggregateExpression(Aggref *originalAggregate, newMasterExpression = (Expr *) unionAggregate; } + else if (aggregateType == AGGREGATE_CUSTOM) + { + aggTuple = SearchSysCache1(AGGFNOID, + ObjectIdGetDatum(originalAggregate->aggfnoid)); + if (!HeapTupleIsValid(aggTuple)) + { + elog(WARNING, "citus cache lookup failed for aggregate %u", + originalAggregate->aggfnoid); + combine = InvalidOid; + } + else + { + aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple); + combine = aggform->aggcombinefn; + ReleaseSysCache(aggTuple); + } + + if (combine != InvalidOid) + { + Const *aggparam = NULL; + Var *column = NULL; + List *aggArguments = NIL; + Aggref *newMasterAggregate = NULL; + Oid coordCombineId = AggregateFunctionOidWithoutInput( + COORD_COMBINE_AGGREGATE_NAME); + + Oid workerReturnType = BYTEAOID; + int32 workerReturnTypeMod = -1; + Oid workerCollationId = InvalidOid; + + aggparam = makeConst(OIDOID, -1, InvalidOid, sizeof(Oid), ObjectIdGetDatum( + originalAggregate->aggfnoid), false, true); + column = makeVar(masterTableId, walkerContext->columnId, workerReturnType, + workerReturnTypeMod, workerCollationId, columnLevelsUp); + + aggArguments = list_make1(makeTargetEntry((Expr *) aggparam, 1, NULL, false)); + aggArguments = lappend(aggArguments, makeTargetEntry((Expr *) column, 2, NULL, + false)); + + /* coord_combine_agg(agg, workercol) */ + newMasterAggregate = makeNode(Aggref); + newMasterAggregate->aggfnoid = coordCombineId; + newMasterAggregate->aggtype = originalAggregate->aggtype; + newMasterAggregate->args = aggArguments; + newMasterAggregate->aggkind = AGGKIND_NORMAL; 
+ newMasterAggregate->aggfilter = originalAggregate->aggfilter; + newMasterAggregate->aggtranstype = INTERNALOID; + newMasterAggregate->aggargtypes = list_concat(list_make1_oid(OIDOID), + list_make1_oid(BYTEAOID)); + newMasterAggregate->aggsplit = AGGSPLIT_SIMPLE; + + newMasterExpression = (Expr *) newMasterAggregate; + } + } else { /* @@ -1929,6 +1934,7 @@ MasterAggregateExpression(Aggref *originalAggregate, newMasterExpression = (Expr *) newMasterAggregate; } + /* * Aggregate functions could have changed the return type. If so, we wrap * the new expression with a conversion function to make it have the same @@ -2801,58 +2807,9 @@ WorkerAggregateExpressionList(Aggref *originalAggregate, Form_pg_aggregate aggform; Oid combine; - aggTuple = SearchSysCache1(AGGFNOID, - ObjectIdGetDatum(originalAggregate->aggfnoid)); - if (!HeapTupleIsValid(aggTuple)) - { - elog(WARNING, "citus cache lookup failed for aggregate %u", - originalAggregate->aggfnoid); - combine = InvalidOid; - } - else - { - aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple); - combine = aggform->aggcombinefn; - ReleaseSysCache(aggTuple); - } - - if (combine != InvalidOid) - { - Const *aggparam = NULL; - Aggref *newWorkerAggregate = NULL; - List *aggArguments = NIL; - ListCell *originalAggArgCell; - Oid workerPartialId = AggregateFunctionOidWithoutInput( - WORKER_PARTIAL_AGGREGATE_NAME); - - aggparam = makeConst(OIDOID, -1, InvalidOid, sizeof(Oid), ObjectIdGetDatum( - originalAggregate->aggfnoid), false, true); - aggArguments = list_make1(makeTargetEntry((Expr *) aggparam, 1, NULL, false)); - foreach(originalAggArgCell, originalAggregate->args) - { - TargetEntry *arg = lfirst(originalAggArgCell); - TargetEntry *newArg = copyObject(arg); - newArg->resno++; - aggArguments = lappend(aggArguments, newArg); - } - - /* worker_partial_agg(agg, ...args) */ - newWorkerAggregate = makeNode(Aggref); - newWorkerAggregate->aggfnoid = workerPartialId; - newWorkerAggregate->aggtype = BYTEAOID; - newWorkerAggregate->args 
= aggArguments; - newWorkerAggregate->aggkind = AGGKIND_NORMAL; - newWorkerAggregate->aggfilter = originalAggregate->aggfilter; - newWorkerAggregate->aggtranstype = INTERNALOID; - newWorkerAggregate->aggargtypes = list_concat(list_make1_oid(OIDOID), - originalAggregate->aggargtypes); - newWorkerAggregate->aggsplit = AGGSPLIT_SIMPLE; - - workerAggregateList = list_make1(newWorkerAggregate); - } - else if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct && - CountDistinctErrorRate == DISABLE_DISTINCT_APPROXIMATION && - walkerContext->pullDistinctColumns) + if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct && + CountDistinctErrorRate == DISABLE_DISTINCT_APPROXIMATION && + walkerContext->pullDistinctColumns) { Aggref *aggregate = (Aggref *) copyObject(originalAggregate); List *columnList = pull_var_clause_default((Node *) aggregate); @@ -2964,6 +2921,58 @@ WorkerAggregateExpressionList(Aggref *originalAggregate, workerAggregateList = lappend(workerAggregateList, sumAggregate); workerAggregateList = lappend(workerAggregateList, countAggregate); } + else if (aggregateType == AGGREGATE_CUSTOM) + { + aggTuple = SearchSysCache1(AGGFNOID, + ObjectIdGetDatum(originalAggregate->aggfnoid)); + if (!HeapTupleIsValid(aggTuple)) + { + elog(WARNING, "citus cache lookup failed for aggregate %u", + originalAggregate->aggfnoid); + combine = InvalidOid; + } + else + { + aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple); + combine = aggform->aggcombinefn; + ReleaseSysCache(aggTuple); + } + + if (combine != InvalidOid) + { + Const *aggparam = NULL; + Aggref *newWorkerAggregate = NULL; + List *aggArguments = NIL; + ListCell *originalAggArgCell; + Oid workerPartialId = AggregateFunctionOidWithoutInput( + WORKER_PARTIAL_AGGREGATE_NAME); + + aggparam = makeConst(OIDOID, -1, InvalidOid, sizeof(Oid), ObjectIdGetDatum( + originalAggregate->aggfnoid), false, true); + aggArguments = list_make1(makeTargetEntry((Expr *) aggparam, 1, NULL, false)); + 
foreach(originalAggArgCell, originalAggregate->args) + { + TargetEntry *arg = lfirst(originalAggArgCell); + TargetEntry *newArg = copyObject(arg); + newArg->resno++; + aggArguments = lappend(aggArguments, newArg); + } + + /* worker_partial_agg(agg, ...args) */ + newWorkerAggregate = makeNode(Aggref); + newWorkerAggregate->aggfnoid = workerPartialId; + newWorkerAggregate->aggtype = BYTEAOID; + newWorkerAggregate->args = aggArguments; + newWorkerAggregate->aggkind = AGGKIND_NORMAL; + newWorkerAggregate->aggfilter = originalAggregate->aggfilter; + newWorkerAggregate->aggtranstype = INTERNALOID; + newWorkerAggregate->aggargtypes = list_concat(list_make1_oid(OIDOID), + originalAggregate->aggargtypes); + newWorkerAggregate->aggsplit = AGGSPLIT_SIMPLE; + + workerAggregateList = list_make1(newWorkerAggregate); + } + } else { /* @@ -3005,9 +3014,15 @@ GetAggregateType(Oid aggFunctionId) aggFunctionId))); } + if (AggregateEnabledCustom(aggregateProcName)) + { + return AGGREGATE_CUSTOM; + } + aggregateCount = lengthof(AggregateNames); Assert(AGGREGATE_INVALID_FIRST == 0); + for (aggregateIndex = 1; aggregateIndex < aggregateCount; aggregateIndex++) { const char *aggregateName = AggregateNames[aggregateIndex]; @@ -3042,6 +3057,39 @@ AggregateArgumentType(Aggref *aggregate) } +/* + * AggregateEnabledCustom returns whether pg_dist_enabled_custom_aggregates has a + * row whose name column equals functionName (name is text, hence F_TEXTEQ). + */ +static bool +AggregateEnabledCustom(const char *functionName) +{ + SysScanDesc scanDescriptor = NULL; + ScanKeyData scanKey[1]; + bool enabled = false; + HeapTuple heapTuple = NULL; + Relation pgDistEnabledCustomAggregates = NULL; + + ScanKeyInit(&scanKey[0], Anum_pg_dist_enabled_custom_aggregates_name, + BTEqualStrategyNumber, F_TEXTEQ, CStringGetTextDatum(functionName)); + + pgDistEnabledCustomAggregates = heap_open(DistEnabledCustomAggregatesId(), + AccessShareLock); + + scanDescriptor = systable_beginscan(pgDistEnabledCustomAggregates, InvalidOid, false, + NULL, 1, scanKey); + + heapTuple = systable_getnext(scanDescriptor); + + enabled = HeapTupleIsValid(heapTuple); + + systable_endscan(scanDescriptor); +
heap_close(pgDistEnabledCustomAggregates, AccessShareLock); + + return enabled; +} + + /* * AggregateFunctionOid performs a reverse lookup on aggregate function name, * and returns the corresponding aggregate function oid for the given function @@ -3160,8 +3208,8 @@ TypeOid(Oid schemaId, const char *typeName) { Oid typeOid; - typeOid = GetSysCacheOid2Compat(TYPENAMENSP, Anum_pg_type_oid, PointerGetDatum( - typeName), + typeOid = GetSysCacheOid2Compat(TYPENAMENSP, Anum_pg_type_oid, + PointerGetDatum(typeName), ObjectIdGetDatum(schemaId)); return typeOid; diff --git a/src/backend/distributed/sql/citus--9.0-1--9.0-customagg.sql b/src/backend/distributed/sql/citus--9.0-1--9.0-customagg.sql index 35f2e8d43..6cf46cfdf 100644 --- a/src/backend/distributed/sql/citus--9.0-1--9.0-customagg.sql +++ b/src/backend/distributed/sql/citus--9.0-1--9.0-customagg.sql @@ -62,4 +62,10 @@ CREATE AGGREGATE coord_combine_agg(oid, bytea, anyelement) ( PARALLEL = SAFE ); +CREATE TABLE citus.pg_dist_enabled_custom_aggregates ( + name text not null primary key +); +ALTER TABLE citus.pg_dist_enabled_custom_aggregates SET SCHEMA pg_catalog; +GRANT SELECT ON pg_catalog.pg_dist_enabled_custom_aggregates TO public; + RESET search_path; diff --git a/src/backend/distributed/utils/aggregate_utils.c b/src/backend/distributed/utils/aggregate_utils.c index 340738485..7ad6d8ec8 100644 --- a/src/backend/distributed/utils/aggregate_utils.c +++ b/src/backend/distributed/utils/aggregate_utils.c @@ -12,8 +12,6 @@ #include "fmgr.h" #include "pg_config_manual.h" -#include "utils/array.h" - PG_FUNCTION_INFO_V1(citus_stype_serialize); PG_FUNCTION_INFO_V1(citus_stype_deserialize); PG_FUNCTION_INFO_V1(citus_stype_combine); @@ -265,8 +263,7 @@ citus_stype_deserialize(PG_FUNCTION_ARGS) box->value_null = true; PG_RETURN_POINTER(box); } - - if (deserial != InvalidOid) + else if (deserial != InvalidOid) { FmgrInfo fdeserialinfo; LOCAL_FCINFO(fdeserial_callinfo, 2); @@ -282,12 +279,6 @@ citus_stype_deserialize(PG_FUNCTION_ARGS) 
box->value = FunctionCallInvoke(fdeserial_callinfo); box->value_null = fdeserial_callinfo->isnull; } - /* TODO Correct null handling */ - else if (value_null) - { - box->value = (Datum) 0; - box->value_null = true; - } else { transtypetuple = get_typeform(box->transtype, &transtypeform); diff --git a/src/backend/distributed/utils/metadata_cache.c b/src/backend/distributed/utils/metadata_cache.c index 3672f7371..7e04472ce 100644 --- a/src/backend/distributed/utils/metadata_cache.c +++ b/src/backend/distributed/utils/metadata_cache.c @@ -128,6 +128,7 @@ typedef struct MetadataCacheData Oid distPlacementShardidIndexId; Oid distPlacementPlacementidIndexId; Oid distPlacementGroupidIndexId; + Oid distEnabledCustomAggregatesId; Oid distTransactionRelationId; Oid distTransactionGroupIndexId; Oid distTransactionRecordIndexId; @@ -2113,6 +2114,17 @@ DistPlacementGroupidIndexId(void) } +/* return oid of the pg_dist_enabled_custom_aggregates relation, as kept in MetadataCache */ +Oid +DistEnabledCustomAggregatesId(void) +{ + CachedRelationLookup("pg_dist_enabled_custom_aggregates", + &MetadataCache.distEnabledCustomAggregatesId); + + return MetadataCache.distEnabledCustomAggregatesId; +} + + /* return oid of the read_intermediate_result(text,citus_copy_format) function */ Oid CitusReadIntermediateResultFuncId(void) diff --git a/src/include/distributed/metadata_cache.h b/src/include/distributed/metadata_cache.h index f590a6035..ce2d628ea 100644 --- a/src/include/distributed/metadata_cache.h +++ b/src/include/distributed/metadata_cache.h @@ -166,6 +166,7 @@ extern Oid DistPlacementRelationId(void); extern Oid DistNodeRelationId(void); extern Oid DistLocalGroupIdRelationId(void); extern Oid DistObjectRelationId(void); +extern Oid DistEnabledCustomAggregatesId(void); /* index oids */ extern Oid DistNodeNodeIdIndexId(void); diff --git a/src/include/distributed/multi_logical_optimizer.h b/src/include/distributed/multi_logical_optimizer.h index 51cd2781d..bd6bb4918 100644 --- 
a/src/include/distributed/multi_logical_optimizer.h +++ b/src/include/distributed/multi_logical_optimizer.h @@ -77,9 +77,11 @@ typedef enum AGGREGATE_HLL_ADD = 16, AGGREGATE_HLL_UNION = 17, AGGREGATE_TOPN_ADD_AGG = 18, - AGGREGATE_TOPN_UNION_AGG = 19 -} AggregateType; + AGGREGATE_TOPN_UNION_AGG = 19, + /* AGGREGATE_CUSTOM must come last */ + AGGREGATE_CUSTOM = 20 +} AggregateType; /* * PushDownStatus indicates whether a node can be pushed down below its child diff --git a/src/include/distributed/pg_dist_enabled_custom_aggregates.h b/src/include/distributed/pg_dist_enabled_custom_aggregates.h new file mode 100644 index 000000000..7c4e603ab --- /dev/null +++ b/src/include/distributed/pg_dist_enabled_custom_aggregates.h @@ -0,0 +1,17 @@ +/*------------------------------------------------------------------------- + * + * pg_dist_enabled_custom_aggregates.h + * definition of the relation that lists which aggregates to treat as custom aggregates. + * + * Copyright (c) 2012-2019, Citus Data, Inc. + * + *------------------------------------------------------------------------- + */ + +#ifndef PG_DIST_ENABLED_CUSTOM_AGGREGATES_H +#define PG_DIST_ENABLED_CUSTOM_AGGREGATES_H + +#define Natts_pg_dist_enabled_custom_aggregates 1 +#define Anum_pg_dist_enabled_custom_aggregates_name 1 + +#endif /* PG_DIST_ENABLED_CUSTOM_AGGREGATES_H */