Initial sketch of aggregate expression transforms

fix_120_custom_aggregates_distribute_multiarg
Philip Dubé 2019-08-01 23:17:17 +00:00
parent c801889e08
commit 1a7bf4ca2c
2 changed files with 163 additions and 27 deletions

View File

@ -252,12 +252,13 @@ static List * WorkerAggregateExpressionList(Aggref *originalAggregate,
WorkerAggregateWalkerContext *walkerContextry);
static AggregateType GetAggregateType(Oid aggFunctionId);
static Oid AggregateArgumentType(Aggref *aggregate);
static Oid AggregateFunctionOidWithoutInput(const char *functionName);
static Oid AggregateFunctionOid(const char *functionName, Oid inputType);
static Oid TypeOid(Oid schemaId, const char *typeName);
static SortGroupClause * CreateSortGroupClause(Var *column);
/* Local functions forward declarations for count(distinct) approximations */
static char * CountDistinctHashFunctionName(Oid argumentType);
static const char * CountDistinctHashFunctionName(Oid argumentType);
static int CountDistinctStorageSize(double approximationErrorRate);
static Const * MakeIntegerConst(int32 integerValue);
static Const * MakeIntegerConstInt64(int64 integerValue);
@ -1518,20 +1519,59 @@ MasterAggregateExpression(Aggref *originalAggregate,
ObjectIdGetDatum(originalAggregate->aggfnoid));
if (!HeapTupleIsValid(aggTuple))
{
elog(ERROR, "cache lookup failed for aggregate %u",
elog(WARNING, "cache lookup failed for aggregate %u",
originalAggregate->aggfnoid);
combine = InvalidOid;
}
aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple);
combine = aggform->aggcombinefn;
if (combine != InvalidOid && originalAggregate->aggtranstype == INTERNALOID)
else
{
serial = aggform->aggserialfn;
deserial = aggform->aggdeserialfn;
aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple);
combine = aggform->aggcombinefn;
if (combine != InvalidOid && originalAggregate->aggtranstype == INTERNALOID)
{
serial = aggform->aggserialfn;
deserial = aggform->aggdeserialfn;
}
ReleaseSysCache(aggTuple);
}
ReleaseSysCache(aggTuple);
if (combine != InvalidOid)
{ }
{
Const *aggparam = NULL;
Var *column = NULL;
List *aggArguments = NIL;
Aggref *newMasterAggregate = NULL;
Oid coordCombineId = AggregateFunctionOidWithoutInput(
COORD_COMBINE_AGGREGATE_NAME);
Oid workerReturnType = BYTEAOID;
int32 workerReturnTypeMod = -1;
Oid workerCollationId = InvalidOid;
aggparam = makeConst(OIDOID, -1, InvalidOid, sizeof(Oid), ObjectIdGetDatum(
originalAggregate->aggfnoid), false, true);
column = makeVar(masterTableId, walkerContext->columnId, workerReturnType,
workerReturnTypeMod, workerCollationId, columnLevelsUp);
aggArguments = list_make1(makeTargetEntry((Expr *) aggparam, 1, NULL, false));
aggArguments = lappend(aggArguments, makeTargetEntry((Expr *) column, 2, NULL,
false));
/* coord_combine_agg(agg, workercol) */
newMasterAggregate = makeNode(Aggref);
newMasterAggregate->aggfnoid = coordCombineId;
newMasterAggregate->aggtype = originalAggregate->aggtype;
newMasterAggregate->args = aggArguments;
newMasterAggregate->aggkind = AGGKIND_NORMAL;
newMasterAggregate->aggfilter = originalAggregate->aggfilter;
newMasterAggregate->aggtranstype = INTERNALOID;
newMasterAggregate->aggargtypes = list_concat(list_make1_oid(OIDOID),
list_make1_oid(BYTEAOID));
newMasterAggregate->aggsplit = AGGSPLIT_SIMPLE;
newMasterExpression = (Expr *) newMasterAggregate;
}
else if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct &&
CountDistinctErrorRate == DISABLE_DISTINCT_APPROXIMATION &&
walkerContext->pullDistinctColumns)
@ -1585,7 +1625,7 @@ MasterAggregateExpression(Aggref *originalAggregate,
{
/*
* If enabled, we check for count(distinct) approximations before count
* distincts. For this, we first compute hll_add_agg(hll_hash(column) on
* distincts. For this, we first compute hll_add_agg(hll_hash(column)) on
* worker nodes, and get hll values. We then gather hlls on the master
* node, and compute hll_cardinality(hll_union_agg(hll)).
*/
@ -2765,10 +2805,69 @@ WorkerAggregateExpressionList(Aggref *originalAggregate,
AggregateType aggregateType = GetAggregateType(originalAggregate->aggfnoid);
List *workerAggregateList = NIL;
AggClauseCosts aggregateCosts;
HeapTuple aggTuple;
Form_pg_aggregate aggform;
Oid combine;
Oid serial = InvalidOid;
Oid deserial = InvalidOid;
if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct &&
CountDistinctErrorRate == DISABLE_DISTINCT_APPROXIMATION &&
walkerContext->pullDistinctColumns)
aggTuple = SearchSysCache1(AGGFNOID,
ObjectIdGetDatum(originalAggregate->aggfnoid));
if (!HeapTupleIsValid(aggTuple))
{
elog(WARNING, "cache lookup failed for aggregate %u",
originalAggregate->aggfnoid);
combine = InvalidOid;
}
else
{
aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple);
combine = aggform->aggcombinefn;
if (combine != InvalidOid && originalAggregate->aggtranstype == INTERNALOID)
{
serial = aggform->aggserialfn;
deserial = aggform->aggdeserialfn;
}
ReleaseSysCache(aggTuple);
}
if (combine != InvalidOid)
{
Const *aggparam = NULL;
Aggref *newWorkerAggregate = NULL;
List *aggArguments = NIL;
ListCell *originalAggArgCell;
Oid workerPartialId = AggregateFunctionOidWithoutInput(
WORKER_PARTIAL_AGGREGATE_NAME);
aggparam = makeConst(OIDOID, -1, InvalidOid, sizeof(Oid), ObjectIdGetDatum(
originalAggregate->aggfnoid), false, true);
aggArguments = list_make1(makeTargetEntry((Expr *) aggparam, 1, NULL, false));
foreach(originalAggArgCell, originalAggregate->args)
{
TargetEntry *arg = (TargetEntry *) lfirst(originalAggArgCell);
TargetEntry *newArg = copyObject(arg);
newArg->resno++;
aggArguments = lappend(aggArguments, newArg);
}
/* worker_partial_agg(agg, ...args) */
newWorkerAggregate = makeNode(Aggref);
newWorkerAggregate->aggfnoid = workerPartialId;
newWorkerAggregate->aggtype = originalAggregate->aggtype;
newWorkerAggregate->args = aggArguments;
newWorkerAggregate->aggkind = AGGKIND_NORMAL;
newWorkerAggregate->aggfilter = originalAggregate->aggfilter;
newWorkerAggregate->aggtranstype = INTERNALOID;
newWorkerAggregate->aggargtypes = list_concat(list_make1_oid(OIDOID),
originalAggregate->aggargtypes);
newWorkerAggregate->aggsplit = AGGSPLIT_SIMPLE;
workerAggregateList = list_make1(newWorkerAggregate);
}
else if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct &&
CountDistinctErrorRate == DISABLE_DISTINCT_APPROXIMATION &&
walkerContext->pullDistinctColumns)
{
Aggref *aggregate = (Aggref *) copyObject(originalAggregate);
List *columnList = pull_var_clause_default((Node *) aggregate);
@ -2808,7 +2907,7 @@ WorkerAggregateExpressionList(Aggref *originalAggregate,
Oid hllSchemaOid = get_extension_schema(hllId);
const char *hllSchemaName = get_namespace_name(hllSchemaOid);
char *hashFunctionName = CountDistinctHashFunctionName(argumentType);
const char *hashFunctionName = CountDistinctHashFunctionName(argumentType);
Oid hashFunctionId = FunctionOid(hllSchemaName, hashFunctionName,
hashArgumentCount);
Oid hashFunctionReturnType = get_func_rettype(hashFunctionId);
@ -3018,6 +3117,49 @@ AggregateFunctionOid(const char *functionName, Oid inputType)
}
/*
* AggregateFunctionOid performs a reverse lookup on aggregate function name,
* and returns the corresponding aggregate function oid for the given function
* name and input type.
*/
static Oid
AggregateFunctionOidWithoutInput(const char *functionName)
{
Oid functionOid = InvalidOid;
Relation procRelation = NULL;
SysScanDesc scanDescriptor = NULL;
ScanKeyData scanKey[1];
int scanKeyCount = 1;
HeapTuple heapTuple = NULL;
procRelation = heap_open(ProcedureRelationId, AccessShareLock);
ScanKeyInit(&scanKey[0], Anum_pg_proc_proname,
BTEqualStrategyNumber, F_NAMEEQ, CStringGetDatum(functionName));
scanDescriptor = systable_beginscan(procRelation,
ProcedureNameArgsNspIndexId, true,
NULL, scanKeyCount, scanKey);
/* loop until we find the right function */
heapTuple = systable_getnext(scanDescriptor);
if (HeapTupleIsValid(heapTuple))
{
functionOid = HeapTupleGetOid(heapTuple);
}
if (functionOid == InvalidOid)
{
ereport(ERROR, (errmsg("no matching oid for function: %s", functionName)));
}
systable_endscan(scanDescriptor);
heap_close(procRelation, AccessShareLock);
return functionOid;
}
/*
* TypeOid looks for a type that has the given name and schema, and returns the
* corresponding type's oid.
@ -3064,42 +3206,34 @@ CreateSortGroupClause(Var *column)
* CountDistinctHashFunctionName resolves the hll_hash function name to use for
* the given input type, and returns this function name.
*/
static char *
static const char *
CountDistinctHashFunctionName(Oid argumentType)
{
char *hashFunctionName = NULL;
/* resolve hash function name based on input argument type */
switch (argumentType)
{
case INT4OID:
{
hashFunctionName = pstrdup(HLL_HASH_INTEGER_FUNC_NAME);
break;
return HLL_HASH_INTEGER_FUNC_NAME;
}
case INT8OID:
{
hashFunctionName = pstrdup(HLL_HASH_BIGINT_FUNC_NAME);
break;
return HLL_HASH_BIGINT_FUNC_NAME;
}
case TEXTOID:
case BPCHAROID:
case VARCHAROID:
{
hashFunctionName = pstrdup(HLL_HASH_TEXT_FUNC_NAME);
break;
return HLL_HASH_TEXT_FUNC_NAME;
}
default:
{
hashFunctionName = pstrdup(HLL_HASH_ANY_FUNC_NAME);
break;
return HLL_HASH_ANY_FUNC_NAME;
}
}
return hashFunctionName;
}

View File

@ -26,6 +26,8 @@
#define ARRAY_CAT_AGGREGATE_NAME "array_cat_agg"
#define JSONB_CAT_AGGREGATE_NAME "jsonb_cat_agg"
#define JSON_CAT_AGGREGATE_NAME "json_cat_agg"
#define WORKER_PARTIAL_AGGREGATE_NAME "worker_partial_agg"
#define COORD_COMBINE_AGGREGATE_NAME "coord_combine_agg"
#define WORKER_COLUMN_FORMAT "worker_column_%d"
/* Definitions related to count(distinct) approximations */