From 1a7bf4ca2c95df85a198c563a6094f07cf299c65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Thu, 1 Aug 2019 23:17:17 +0000 Subject: [PATCH] Initial sketch of aggregate expression transforms --- .../planner/multi_logical_optimizer.c | 188 +++++++++++++++--- .../distributed/multi_logical_optimizer.h | 2 + 2 files changed, 163 insertions(+), 27 deletions(-) diff --git a/src/backend/distributed/planner/multi_logical_optimizer.c b/src/backend/distributed/planner/multi_logical_optimizer.c index b54c1b309..e4aefa1a5 100644 --- a/src/backend/distributed/planner/multi_logical_optimizer.c +++ b/src/backend/distributed/planner/multi_logical_optimizer.c @@ -252,12 +252,13 @@ static List * WorkerAggregateExpressionList(Aggref *originalAggregate, WorkerAggregateWalkerContext *walkerContextry); static AggregateType GetAggregateType(Oid aggFunctionId); static Oid AggregateArgumentType(Aggref *aggregate); +static Oid AggregateFunctionOidWithoutInput(const char *functionName); static Oid AggregateFunctionOid(const char *functionName, Oid inputType); static Oid TypeOid(Oid schemaId, const char *typeName); static SortGroupClause * CreateSortGroupClause(Var *column); /* Local functions forward declarations for count(distinct) approximations */ -static char * CountDistinctHashFunctionName(Oid argumentType); +static const char * CountDistinctHashFunctionName(Oid argumentType); static int CountDistinctStorageSize(double approximationErrorRate); static Const * MakeIntegerConst(int32 integerValue); static Const * MakeIntegerConstInt64(int64 integerValue); @@ -1518,20 +1519,59 @@ MasterAggregateExpression(Aggref *originalAggregate, ObjectIdGetDatum(originalAggregate->aggfnoid)); if (!HeapTupleIsValid(aggTuple)) { - elog(ERROR, "cache lookup failed for aggregate %u", + elog(WARNING, "cache lookup failed for aggregate %u", originalAggregate->aggfnoid); + combine = InvalidOid; } - aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple); - combine = aggform->aggcombinefn; - if (combine != InvalidOid && originalAggregate->aggtranstype == INTERNALOID) + else { - serial = aggform->aggserialfn; - deserial = aggform->aggdeserialfn; + aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple); + combine = aggform->aggcombinefn; + if (combine != InvalidOid && originalAggregate->aggtranstype == INTERNALOID) + { + serial = aggform->aggserialfn; + deserial = aggform->aggdeserialfn; + } + ReleaseSysCache(aggTuple); } - ReleaseSysCache(aggTuple); if (combine != InvalidOid) - { } + { + Const *aggparam = NULL; + Var *column = NULL; + List *aggArguments = NIL; + Aggref *newMasterAggregate = NULL; + Oid coordCombineId = AggregateFunctionOidWithoutInput( + COORD_COMBINE_AGGREGATE_NAME); + + Oid workerReturnType = BYTEAOID; + int32 workerReturnTypeMod = -1; + Oid workerCollationId = InvalidOid; + + + aggparam = makeConst(OIDOID, -1, InvalidOid, sizeof(Oid), ObjectIdGetDatum( + originalAggregate->aggfnoid), false, true); + column = makeVar(masterTableId, walkerContext->columnId, workerReturnType, + workerReturnTypeMod, workerCollationId, columnLevelsUp); + + aggArguments = list_make1(makeTargetEntry((Expr *) aggparam, 1, NULL, false)); + aggArguments = lappend(aggArguments, makeTargetEntry((Expr *) column, 2, NULL, + false)); + + /* coord_combine_agg(agg, workercol) */ + newMasterAggregate = makeNode(Aggref); + newMasterAggregate->aggfnoid = coordCombineId; + newMasterAggregate->aggtype = originalAggregate->aggtype; + newMasterAggregate->args = aggArguments; + newMasterAggregate->aggkind = AGGKIND_NORMAL; + newMasterAggregate->aggfilter = originalAggregate->aggfilter; + newMasterAggregate->aggtranstype = INTERNALOID; + newMasterAggregate->aggargtypes = list_concat(list_make1_oid(OIDOID), + list_make1_oid(BYTEAOID)); + newMasterAggregate->aggsplit = AGGSPLIT_SIMPLE; + + newMasterExpression = (Expr *) newMasterAggregate; + } else if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct && CountDistinctErrorRate == DISABLE_DISTINCT_APPROXIMATION && walkerContext->pullDistinctColumns) @@ -1585,7 +1625,7 @@ MasterAggregateExpression(Aggref *originalAggregate, { /* * If enabled, we check for count(distinct) approximations before count - * distincts. For this, we first compute hll_add_agg(hll_hash(column) on + * distincts. For this, we first compute hll_add_agg(hll_hash(column)) on * worker nodes, and get hll values. We then gather hlls on the master * node, and compute hll_cardinality(hll_union_agg(hll)). */ @@ -2765,10 +2805,69 @@ WorkerAggregateExpressionList(Aggref *originalAggregate, AggregateType aggregateType = GetAggregateType(originalAggregate->aggfnoid); List *workerAggregateList = NIL; AggClauseCosts aggregateCosts; + HeapTuple aggTuple; + Form_pg_aggregate aggform; + Oid combine; + Oid serial = InvalidOid; + Oid deserial = InvalidOid; - if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct && - CountDistinctErrorRate == DISABLE_DISTINCT_APPROXIMATION && - walkerContext->pullDistinctColumns) + aggTuple = SearchSysCache1(AGGFNOID, + ObjectIdGetDatum(originalAggregate->aggfnoid)); + if (!HeapTupleIsValid(aggTuple)) + { + elog(WARNING, "cache lookup failed for aggregate %u", + originalAggregate->aggfnoid); + combine = InvalidOid; + } + else + { + aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple); + combine = aggform->aggcombinefn; + if (combine != InvalidOid && originalAggregate->aggtranstype == INTERNALOID) + { + serial = aggform->aggserialfn; + deserial = aggform->aggdeserialfn; + } + ReleaseSysCache(aggTuple); + } + + if (combine != InvalidOid) + { + Const *aggparam = NULL; + Aggref *newWorkerAggregate = NULL; + List *aggArguments = NIL; + ListCell *originalAggArgCell; + Oid workerPartialId = AggregateFunctionOidWithoutInput( + WORKER_PARTIAL_AGGREGATE_NAME); + + aggparam = makeConst(OIDOID, -1, InvalidOid, sizeof(Oid), ObjectIdGetDatum( + originalAggregate->aggfnoid), false, true); + aggArguments = list_make1(makeTargetEntry((Expr *) aggparam, 1, NULL, false)); + foreach(originalAggArgCell, originalAggregate->args) + { + TargetEntry *arg = (TargetEntry *) lfirst(originalAggArgCell); + TargetEntry *newArg = copyObject(arg); + newArg->resno++; + aggArguments = lappend(aggArguments, newArg); + } + + /* worker_partial_agg(agg, ...args) */ + newWorkerAggregate = makeNode(Aggref); + newWorkerAggregate->aggfnoid = workerPartialId; + newWorkerAggregate->aggtype = originalAggregate->aggtype; + newWorkerAggregate->args = aggArguments; + newWorkerAggregate->aggkind = AGGKIND_NORMAL; + newWorkerAggregate->aggfilter = originalAggregate->aggfilter; + newWorkerAggregate->aggtranstype = INTERNALOID; + newWorkerAggregate->aggargtypes = list_concat(list_make1_oid(OIDOID), + originalAggregate->aggargtypes); + newWorkerAggregate->aggsplit = AGGSPLIT_SIMPLE; + + workerAggregateList = list_make1(newWorkerAggregate); + } + else if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct && + CountDistinctErrorRate == DISABLE_DISTINCT_APPROXIMATION && + walkerContext->pullDistinctColumns) { Aggref *aggregate = (Aggref *) copyObject(originalAggregate); List *columnList = pull_var_clause_default((Node *) aggregate); @@ -2808,7 +2907,7 @@ WorkerAggregateExpressionList(Aggref *originalAggregate, Oid hllSchemaOid = get_extension_schema(hllId); const char *hllSchemaName = get_namespace_name(hllSchemaOid); - char *hashFunctionName = CountDistinctHashFunctionName(argumentType); + const char *hashFunctionName = CountDistinctHashFunctionName(argumentType); Oid hashFunctionId = FunctionOid(hllSchemaName, hashFunctionName, hashArgumentCount); Oid hashFunctionReturnType = get_func_rettype(hashFunctionId); @@ -3018,6 +3117,49 @@ AggregateFunctionOid(const char *functionName, Oid inputType) } +/* + * AggregateFunctionOid performs a reverse lookup on aggregate function name, + * and returns the corresponding aggregate function oid for the given function + * name and input type. + */ +static Oid +AggregateFunctionOidWithoutInput(const char *functionName) +{ + Oid functionOid = InvalidOid; + Relation procRelation = NULL; + SysScanDesc scanDescriptor = NULL; + ScanKeyData scanKey[1]; + int scanKeyCount = 1; + HeapTuple heapTuple = NULL; + + procRelation = heap_open(ProcedureRelationId, AccessShareLock); + + ScanKeyInit(&scanKey[0], Anum_pg_proc_proname, + BTEqualStrategyNumber, F_NAMEEQ, CStringGetDatum(functionName)); + + scanDescriptor = systable_beginscan(procRelation, + ProcedureNameArgsNspIndexId, true, + NULL, scanKeyCount, scanKey); + + /* loop until we find the right function */ + heapTuple = systable_getnext(scanDescriptor); + if (HeapTupleIsValid(heapTuple)) + { + functionOid = HeapTupleGetOid(heapTuple); + } + + if (functionOid == InvalidOid) + { + ereport(ERROR, (errmsg("no matching oid for function: %s", functionName))); + } + + systable_endscan(scanDescriptor); + heap_close(procRelation, AccessShareLock); + + return functionOid; +} + + /* * TypeOid looks for a type that has the given name and schema, and returns the * corresponding type's oid. @@ -3064,42 +3206,34 @@ CreateSortGroupClause(Var *column) * CountDistinctHashFunctionName resolves the hll_hash function name to use for * the given input type, and returns this function name. */ -static char * +static const char * CountDistinctHashFunctionName(Oid argumentType) { - char *hashFunctionName = NULL; - /* resolve hash function name based on input argument type */ switch (argumentType) { case INT4OID: { - hashFunctionName = pstrdup(HLL_HASH_INTEGER_FUNC_NAME); - break; + return HLL_HASH_INTEGER_FUNC_NAME; } case INT8OID: { - hashFunctionName = pstrdup(HLL_HASH_BIGINT_FUNC_NAME); - break; + return HLL_HASH_BIGINT_FUNC_NAME; } case TEXTOID: case BPCHAROID: case VARCHAROID: { - hashFunctionName = pstrdup(HLL_HASH_TEXT_FUNC_NAME); - break; + return HLL_HASH_TEXT_FUNC_NAME; } default: { - hashFunctionName = pstrdup(HLL_HASH_ANY_FUNC_NAME); - break; + return HLL_HASH_ANY_FUNC_NAME; } } - - return hashFunctionName; } diff --git a/src/include/distributed/multi_logical_optimizer.h b/src/include/distributed/multi_logical_optimizer.h index 7ef725b57..51cd2781d 100644 --- a/src/include/distributed/multi_logical_optimizer.h +++ b/src/include/distributed/multi_logical_optimizer.h @@ -26,6 +26,8 @@ #define ARRAY_CAT_AGGREGATE_NAME "array_cat_agg" #define JSONB_CAT_AGGREGATE_NAME "jsonb_cat_agg" #define JSON_CAT_AGGREGATE_NAME "json_cat_agg" +#define WORKER_PARTIAL_AGGREGATE_NAME "worker_partial_agg" +#define COORD_COMBINE_AGGREGATE_NAME "coord_combine_agg" #define WORKER_COLUMN_FORMAT "worker_column_%d" /* Definitions related to count(distinct) approximations */