mirror of https://github.com/citusdata/citus.git
Initial sketch of aggregate expression transforms
parent
c801889e08
commit
1a7bf4ca2c
|
@ -252,12 +252,13 @@ static List * WorkerAggregateExpressionList(Aggref *originalAggregate,
|
|||
WorkerAggregateWalkerContext *walkerContextry);
|
||||
static AggregateType GetAggregateType(Oid aggFunctionId);
|
||||
static Oid AggregateArgumentType(Aggref *aggregate);
|
||||
static Oid AggregateFunctionOidWithoutInput(const char *functionName);
|
||||
static Oid AggregateFunctionOid(const char *functionName, Oid inputType);
|
||||
static Oid TypeOid(Oid schemaId, const char *typeName);
|
||||
static SortGroupClause * CreateSortGroupClause(Var *column);
|
||||
|
||||
/* Local functions forward declarations for count(distinct) approximations */
|
||||
static char * CountDistinctHashFunctionName(Oid argumentType);
|
||||
static const char * CountDistinctHashFunctionName(Oid argumentType);
|
||||
static int CountDistinctStorageSize(double approximationErrorRate);
|
||||
static Const * MakeIntegerConst(int32 integerValue);
|
||||
static Const * MakeIntegerConstInt64(int64 integerValue);
|
||||
|
@ -1518,9 +1519,12 @@ MasterAggregateExpression(Aggref *originalAggregate,
|
|||
ObjectIdGetDatum(originalAggregate->aggfnoid));
|
||||
if (!HeapTupleIsValid(aggTuple))
|
||||
{
|
||||
elog(ERROR, "cache lookup failed for aggregate %u",
|
||||
elog(WARNING, "cache lookup failed for aggregate %u",
|
||||
originalAggregate->aggfnoid);
|
||||
combine = InvalidOid;
|
||||
}
|
||||
else
|
||||
{
|
||||
aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple);
|
||||
combine = aggform->aggcombinefn;
|
||||
if (combine != InvalidOid && originalAggregate->aggtranstype == INTERNALOID)
|
||||
|
@ -1529,9 +1533,45 @@ MasterAggregateExpression(Aggref *originalAggregate,
|
|||
deserial = aggform->aggdeserialfn;
|
||||
}
|
||||
ReleaseSysCache(aggTuple);
|
||||
}
|
||||
|
||||
if (combine != InvalidOid)
|
||||
{ }
|
||||
{
|
||||
Const *aggparam = NULL;
|
||||
Var *column = NULL;
|
||||
List *aggArguments = NIL;
|
||||
Aggref *newMasterAggregate = NULL;
|
||||
Oid coordCombineId = AggregateFunctionOidWithoutInput(
|
||||
COORD_COMBINE_AGGREGATE_NAME);
|
||||
|
||||
Oid workerReturnType = BYTEAOID;
|
||||
int32 workerReturnTypeMod = -1;
|
||||
Oid workerCollationId = InvalidOid;
|
||||
|
||||
|
||||
aggparam = makeConst(OIDOID, -1, InvalidOid, sizeof(Oid), ObjectIdGetDatum(
|
||||
originalAggregate->aggfnoid), false, true);
|
||||
column = makeVar(masterTableId, walkerContext->columnId, workerReturnType,
|
||||
workerReturnTypeMod, workerCollationId, columnLevelsUp);
|
||||
|
||||
aggArguments = list_make1(makeTargetEntry((Expr *) aggparam, 1, NULL, false));
|
||||
aggArguments = lappend(aggArguments, makeTargetEntry((Expr *) column, 2, NULL,
|
||||
false));
|
||||
|
||||
/* coord_combine_agg(agg, workercol) */
|
||||
newMasterAggregate = makeNode(Aggref);
|
||||
newMasterAggregate->aggfnoid = coordCombineId;
|
||||
newMasterAggregate->aggtype = originalAggregate->aggtype;
|
||||
newMasterAggregate->args = aggArguments;
|
||||
newMasterAggregate->aggkind = AGGKIND_NORMAL;
|
||||
newMasterAggregate->aggfilter = originalAggregate->aggfilter;
|
||||
newMasterAggregate->aggtranstype = INTERNALOID;
|
||||
newMasterAggregate->aggargtypes = list_concat(list_make1_oid(OIDOID),
|
||||
list_make1_oid(BYTEAOID));
|
||||
newMasterAggregate->aggsplit = AGGSPLIT_SIMPLE;
|
||||
|
||||
newMasterExpression = (Expr *) newMasterAggregate;
|
||||
}
|
||||
else if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct &&
|
||||
CountDistinctErrorRate == DISABLE_DISTINCT_APPROXIMATION &&
|
||||
walkerContext->pullDistinctColumns)
|
||||
|
@ -1585,7 +1625,7 @@ MasterAggregateExpression(Aggref *originalAggregate,
|
|||
{
|
||||
/*
|
||||
* If enabled, we check for count(distinct) approximations before count
|
||||
* distincts. For this, we first compute hll_add_agg(hll_hash(column) on
|
||||
* distincts. For this, we first compute hll_add_agg(hll_hash(column)) on
|
||||
* worker nodes, and get hll values. We then gather hlls on the master
|
||||
* node, and compute hll_cardinality(hll_union_agg(hll)).
|
||||
*/
|
||||
|
@ -2765,8 +2805,67 @@ WorkerAggregateExpressionList(Aggref *originalAggregate,
|
|||
AggregateType aggregateType = GetAggregateType(originalAggregate->aggfnoid);
|
||||
List *workerAggregateList = NIL;
|
||||
AggClauseCosts aggregateCosts;
|
||||
HeapTuple aggTuple;
|
||||
Form_pg_aggregate aggform;
|
||||
Oid combine;
|
||||
Oid serial = InvalidOid;
|
||||
Oid deserial = InvalidOid;
|
||||
|
||||
if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct &&
|
||||
aggTuple = SearchSysCache1(AGGFNOID,
|
||||
ObjectIdGetDatum(originalAggregate->aggfnoid));
|
||||
if (!HeapTupleIsValid(aggTuple))
|
||||
{
|
||||
elog(WARNING, "cache lookup failed for aggregate %u",
|
||||
originalAggregate->aggfnoid);
|
||||
combine = InvalidOid;
|
||||
}
|
||||
else
|
||||
{
|
||||
aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple);
|
||||
combine = aggform->aggcombinefn;
|
||||
if (combine != InvalidOid && originalAggregate->aggtranstype == INTERNALOID)
|
||||
{
|
||||
serial = aggform->aggserialfn;
|
||||
deserial = aggform->aggdeserialfn;
|
||||
}
|
||||
ReleaseSysCache(aggTuple);
|
||||
}
|
||||
|
||||
if (combine != InvalidOid)
|
||||
{
|
||||
Const *aggparam = NULL;
|
||||
Aggref *newWorkerAggregate = NULL;
|
||||
List *aggArguments = NIL;
|
||||
ListCell *originalAggArgCell;
|
||||
Oid workerPartialId = AggregateFunctionOidWithoutInput(
|
||||
WORKER_PARTIAL_AGGREGATE_NAME);
|
||||
|
||||
aggparam = makeConst(OIDOID, -1, InvalidOid, sizeof(Oid), ObjectIdGetDatum(
|
||||
originalAggregate->aggfnoid), false, true);
|
||||
aggArguments = list_make1(makeTargetEntry((Expr *) aggparam, 1, NULL, false));
|
||||
foreach(originalAggArgCell, originalAggregate->args)
|
||||
{
|
||||
TargetEntry *arg = (TargetEntry *) lfirst(originalAggArgCell);
|
||||
TargetEntry *newArg = copyObject(arg);
|
||||
newArg->resno++;
|
||||
aggArguments = lappend(aggArguments, newArg);
|
||||
}
|
||||
|
||||
/* worker_partial_agg(agg, ...args) */
|
||||
newWorkerAggregate = makeNode(Aggref);
|
||||
newWorkerAggregate->aggfnoid = workerPartialId;
|
||||
newWorkerAggregate->aggtype = originalAggregate->aggtype;
|
||||
newWorkerAggregate->args = aggArguments;
|
||||
newWorkerAggregate->aggkind = AGGKIND_NORMAL;
|
||||
newWorkerAggregate->aggfilter = originalAggregate->aggfilter;
|
||||
newWorkerAggregate->aggtranstype = INTERNALOID;
|
||||
newWorkerAggregate->aggargtypes = list_concat(list_make1_oid(OIDOID),
|
||||
originalAggregate->aggargtypes);
|
||||
newWorkerAggregate->aggsplit = AGGSPLIT_SIMPLE;
|
||||
|
||||
workerAggregateList = list_make1(newWorkerAggregate);
|
||||
}
|
||||
else if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct &&
|
||||
CountDistinctErrorRate == DISABLE_DISTINCT_APPROXIMATION &&
|
||||
walkerContext->pullDistinctColumns)
|
||||
{
|
||||
|
@ -2808,7 +2907,7 @@ WorkerAggregateExpressionList(Aggref *originalAggregate,
|
|||
Oid hllSchemaOid = get_extension_schema(hllId);
|
||||
const char *hllSchemaName = get_namespace_name(hllSchemaOid);
|
||||
|
||||
char *hashFunctionName = CountDistinctHashFunctionName(argumentType);
|
||||
const char *hashFunctionName = CountDistinctHashFunctionName(argumentType);
|
||||
Oid hashFunctionId = FunctionOid(hllSchemaName, hashFunctionName,
|
||||
hashArgumentCount);
|
||||
Oid hashFunctionReturnType = get_func_rettype(hashFunctionId);
|
||||
|
@ -3018,6 +3117,49 @@ AggregateFunctionOid(const char *functionName, Oid inputType)
|
|||
}
|
||||
|
||||
|
||||
/*
|
||||
* AggregateFunctionOid performs a reverse lookup on aggregate function name,
|
||||
* and returns the corresponding aggregate function oid for the given function
|
||||
* name and input type.
|
||||
*/
|
||||
static Oid
|
||||
AggregateFunctionOidWithoutInput(const char *functionName)
|
||||
{
|
||||
Oid functionOid = InvalidOid;
|
||||
Relation procRelation = NULL;
|
||||
SysScanDesc scanDescriptor = NULL;
|
||||
ScanKeyData scanKey[1];
|
||||
int scanKeyCount = 1;
|
||||
HeapTuple heapTuple = NULL;
|
||||
|
||||
procRelation = heap_open(ProcedureRelationId, AccessShareLock);
|
||||
|
||||
ScanKeyInit(&scanKey[0], Anum_pg_proc_proname,
|
||||
BTEqualStrategyNumber, F_NAMEEQ, CStringGetDatum(functionName));
|
||||
|
||||
scanDescriptor = systable_beginscan(procRelation,
|
||||
ProcedureNameArgsNspIndexId, true,
|
||||
NULL, scanKeyCount, scanKey);
|
||||
|
||||
/* loop until we find the right function */
|
||||
heapTuple = systable_getnext(scanDescriptor);
|
||||
if (HeapTupleIsValid(heapTuple))
|
||||
{
|
||||
functionOid = HeapTupleGetOid(heapTuple);
|
||||
}
|
||||
|
||||
if (functionOid == InvalidOid)
|
||||
{
|
||||
ereport(ERROR, (errmsg("no matching oid for function: %s", functionName)));
|
||||
}
|
||||
|
||||
systable_endscan(scanDescriptor);
|
||||
heap_close(procRelation, AccessShareLock);
|
||||
|
||||
return functionOid;
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
* TypeOid looks for a type that has the given name and schema, and returns the
|
||||
* corresponding type's oid.
|
||||
|
@ -3064,42 +3206,34 @@ CreateSortGroupClause(Var *column)
|
|||
* CountDistinctHashFunctionName resolves the hll_hash function name to use for
|
||||
* the given input type, and returns this function name.
|
||||
*/
|
||||
static char *
|
||||
static const char *
|
||||
CountDistinctHashFunctionName(Oid argumentType)
|
||||
{
|
||||
char *hashFunctionName = NULL;
|
||||
|
||||
/* resolve hash function name based on input argument type */
|
||||
switch (argumentType)
|
||||
{
|
||||
case INT4OID:
|
||||
{
|
||||
hashFunctionName = pstrdup(HLL_HASH_INTEGER_FUNC_NAME);
|
||||
break;
|
||||
return HLL_HASH_INTEGER_FUNC_NAME;
|
||||
}
|
||||
|
||||
case INT8OID:
|
||||
{
|
||||
hashFunctionName = pstrdup(HLL_HASH_BIGINT_FUNC_NAME);
|
||||
break;
|
||||
return HLL_HASH_BIGINT_FUNC_NAME;
|
||||
}
|
||||
|
||||
case TEXTOID:
|
||||
case BPCHAROID:
|
||||
case VARCHAROID:
|
||||
{
|
||||
hashFunctionName = pstrdup(HLL_HASH_TEXT_FUNC_NAME);
|
||||
break;
|
||||
return HLL_HASH_TEXT_FUNC_NAME;
|
||||
}
|
||||
|
||||
default:
|
||||
{
|
||||
hashFunctionName = pstrdup(HLL_HASH_ANY_FUNC_NAME);
|
||||
break;
|
||||
return HLL_HASH_ANY_FUNC_NAME;
|
||||
}
|
||||
}
|
||||
|
||||
return hashFunctionName;
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -26,6 +26,8 @@
|
|||
#define ARRAY_CAT_AGGREGATE_NAME "array_cat_agg"
|
||||
#define JSONB_CAT_AGGREGATE_NAME "jsonb_cat_agg"
|
||||
#define JSON_CAT_AGGREGATE_NAME "json_cat_agg"
|
||||
#define WORKER_PARTIAL_AGGREGATE_NAME "worker_partial_agg"
|
||||
#define COORD_COMBINE_AGGREGATE_NAME "coord_combine_agg"
|
||||
#define WORKER_COLUMN_FORMAT "worker_column_%d"
|
||||
|
||||
/* Definitions related to count(distinct) approximations */
|
||||
|
|
Loading…
Reference in New Issue