aggregate_support test: test DISTINCT, ORDER BY, FILTER, & no intermediate results

Previously,
- we'd push down ORDER BY, but this doesn't order intermediate results between workers
- we'd keep FILTER on master aggregate, which would raise an error about unexpected cstrings
pull/3249/head
Philip Dubé 2019-11-30 16:11:05 +00:00
parent ffacefc2ad
commit 1597fbb369
3 changed files with 92 additions and 41 deletions

View File

@ -250,9 +250,9 @@ static bool WorkerAggregateWalker(Node *node,
WorkerAggregateWalkerContext *walkerContext);
static List * WorkerAggregateExpressionList(Aggref *originalAggregate,
WorkerAggregateWalkerContext *walkerContextry);
static AggregateType GetAggregateType(Oid aggFunctionId);
static AggregateType GetAggregateType(Aggref *aggregatExpression);
static Oid AggregateArgumentType(Aggref *aggregate);
static bool AggregateEnabledCustom(Oid aggregateOid);
static bool AggregateEnabledCustom(Aggref *aggregateExpression);
static Oid CitusFunctionOidWithSignature(char *functionName, int numargs, Oid *argtypes);
static Oid WorkerPartialAggOid(void);
static Oid CoordCombineAggOid(void);
@ -1494,7 +1494,7 @@ static Expr *
MasterAggregateExpression(Aggref *originalAggregate,
MasterAggregateWalkerContext *walkerContext)
{
AggregateType aggregateType = GetAggregateType(originalAggregate->aggfnoid);
AggregateType aggregateType = GetAggregateType(originalAggregate);
Expr *newMasterExpression = NULL;
const uint32 masterTableId = 1; /* one table on the master node */
const Index columnLevelsUp = 0; /* normal column */
@ -1821,9 +1821,8 @@ MasterAggregateExpression(Aggref *originalAggregate,
}
else if (aggregateType == AGGREGATE_CUSTOM)
{
HeapTuple aggTuple = SearchSysCache1(AGGFNOID,
ObjectIdGetDatum(
originalAggregate->aggfnoid));
HeapTuple aggTuple =
SearchSysCache1(AGGFNOID, ObjectIdGetDatum(originalAggregate->aggfnoid));
Form_pg_aggregate aggform;
Oid combine;
@ -1857,12 +1856,10 @@ MasterAggregateExpression(Aggref *originalAggregate,
walkerContext->columnId++;
Const *nullTag = makeNullConst(resultType, -1, InvalidOid);
List *aggArguments = list_make3(makeTargetEntry((Expr *) aggOidParam, 1, NULL,
false),
makeTargetEntry((Expr *) column, 2, NULL,
false),
makeTargetEntry((Expr *) nullTag, 3, NULL,
false));
List *aggArguments =
list_make3(makeTargetEntry((Expr *) aggOidParam, 1, NULL, false),
makeTargetEntry((Expr *) column, 2, NULL, false),
makeTargetEntry((Expr *) nullTag, 3, NULL, false));
/* coord_combine_agg(agg, workercol) */
Aggref *newMasterAggregate = makeNode(Aggref);
@ -1870,7 +1867,7 @@ MasterAggregateExpression(Aggref *originalAggregate,
newMasterAggregate->aggtype = originalAggregate->aggtype;
newMasterAggregate->args = aggArguments;
newMasterAggregate->aggkind = AGGKIND_NORMAL;
newMasterAggregate->aggfilter = originalAggregate->aggfilter;
newMasterAggregate->aggfilter = NULL;
newMasterAggregate->aggtranstype = INTERNALOID;
newMasterAggregate->aggargtypes = list_make3_oid(OIDOID, CSTRINGOID,
resultType);
@ -2776,7 +2773,7 @@ static List *
WorkerAggregateExpressionList(Aggref *originalAggregate,
WorkerAggregateWalkerContext *walkerContext)
{
AggregateType aggregateType = GetAggregateType(originalAggregate->aggfnoid);
AggregateType aggregateType = GetAggregateType(originalAggregate);
List *workerAggregateList = NIL;
AggClauseCosts aggregateCosts;
@ -2893,9 +2890,8 @@ WorkerAggregateExpressionList(Aggref *originalAggregate,
}
else if (aggregateType == AGGREGATE_CUSTOM)
{
HeapTuple aggTuple = SearchSysCache1(AGGFNOID,
ObjectIdGetDatum(
originalAggregate->aggfnoid));
HeapTuple aggTuple =
SearchSysCache1(AGGFNOID, ObjectIdGetDatum(originalAggregate->aggfnoid));
Form_pg_aggregate aggform;
Oid combine;
@ -2919,8 +2915,7 @@ WorkerAggregateExpressionList(Aggref *originalAggregate,
Const *aggOidParam = makeConst(REGPROCEDUREOID, -1, InvalidOid, sizeof(Oid),
ObjectIdGetDatum(originalAggregate->aggfnoid),
false,
true);
false, true);
List *aggArguments = list_make1(makeTargetEntry((Expr *) aggOidParam, 1, NULL,
false));
foreach(originalAggArgCell, originalAggregate->args)
@ -2932,15 +2927,14 @@ WorkerAggregateExpressionList(Aggref *originalAggregate,
}
/* worker_partial_agg(agg, ...args) */
Aggref *newWorkerAggregate = makeNode(Aggref);
Aggref *newWorkerAggregate = copyObject(originalAggregate);
newWorkerAggregate->aggfnoid = workerPartialId;
newWorkerAggregate->aggtype = CSTRINGOID;
newWorkerAggregate->args = aggArguments;
newWorkerAggregate->aggkind = AGGKIND_NORMAL;
newWorkerAggregate->aggfilter = originalAggregate->aggfilter;
newWorkerAggregate->aggtranstype = INTERNALOID;
newWorkerAggregate->aggargtypes = lcons_oid(OIDOID,
originalAggregate->aggargtypes);
newWorkerAggregate->aggargtypes);
newWorkerAggregate->aggsplit = AGGSPLIT_SIMPLE;
workerAggregateList = list_make1(newWorkerAggregate);
@ -2976,8 +2970,10 @@ WorkerAggregateExpressionList(Aggref *originalAggregate,
* previously stored strings, and returns the appropriate aggregate type.
*/
static AggregateType
GetAggregateType(Oid aggFunctionId)
GetAggregateType(Aggref *aggregateExpression)
{
Oid aggFunctionId = aggregateExpression->aggfnoid;
/* look up the function name */
char *aggregateProcName = get_func_name(aggFunctionId);
if (aggregateProcName == NULL)
@ -2999,7 +2995,7 @@ GetAggregateType(Oid aggFunctionId)
}
}
if (AggregateEnabledCustom(aggFunctionId))
if (AggregateEnabledCustom(aggregateExpression))
{
return AGGREGATE_CUSTOM;
}
@ -3028,8 +3024,15 @@ AggregateArgumentType(Aggref *aggregate)
* distributed across workers using worker_partial_agg & coord_combine_agg.
*/
static bool
AggregateEnabledCustom(Oid aggregateOid)
AggregateEnabledCustom(Aggref *aggregateExpression)
{
if (aggregateExpression->aggorder != NIL ||
list_length(aggregateExpression->args) != 1)
{
return false;
}
Oid aggregateOid = aggregateExpression->aggfnoid;
HeapTuple aggTuple = SearchSysCache1(AGGFNOID, aggregateOid);
if (!HeapTupleIsValid(aggTuple))
{
@ -3345,7 +3348,7 @@ ErrorIfContainsUnsupportedAggregate(MultiNode *logicalPlanNode)
/* GetAggregateType errors out on unsupported aggregate types */
Aggref *aggregateExpression = (Aggref *) expression;
AggregateType aggregateType = GetAggregateType(aggregateExpression->aggfnoid);
AggregateType aggregateType = GetAggregateType(aggregateExpression);
Assert(aggregateType != AGGREGATE_INVALID_FIRST);
/*
@ -3439,7 +3442,7 @@ ErrorIfUnsupportedAggregateDistinct(Aggref *aggregateExpression,
char *errorDetail = NULL;
bool distinctSupported = true;
AggregateType aggregateType = GetAggregateType(aggregateExpression->aggfnoid);
AggregateType aggregateType = GetAggregateType(aggregateExpression);
/*
* We partially support count(distinct) in subqueries, other distinct aggregates in
@ -4173,7 +4176,7 @@ GenerateNewTargetEntriesForSortClauses(List *originalTargetList,
else
{
Aggref *aggNode = (Aggref *) targetExpr;
AggregateType aggregateType = GetAggregateType(aggNode->aggfnoid);
AggregateType aggregateType = GetAggregateType(aggNode);
if (aggregateType == AGGREGATE_AVERAGE)
{
createNewTargetEntry = true;
@ -4285,7 +4288,7 @@ HasOrderByAverage(List *sortClauseList, List *targetList)
{
Aggref *aggregate = (Aggref *) sortExpression;
AggregateType aggregateType = GetAggregateType(aggregate->aggfnoid);
AggregateType aggregateType = GetAggregateType(aggregate);
if (aggregateType == AGGREGATE_AVERAGE)
{
hasOrderByAverage = true;

View File

@ -58,8 +58,7 @@ select create_distributed_table('aggdata', 'id');
(1 row)
insert into aggdata (id, key, val, valf) values (1, 1, 2, 11.2), (2, 1, NULL, 2.1), (3, 2, 2, 3.22), (4, 2, 3, 4.23), (5, 2, 5, 5.25), (6, 3, 4, 63.4), (7, 5, NULL, 75), (8, 6, NULL, NULL), (9, 6, NULL, 96), (10, 7, 8, 1078), (11, 9, 0, 1.19);
select key, sum2(val), sum2_strict(val), stddev(valf)
from aggdata group by key order by key;
select key, sum2(val), sum2_strict(val), stddev(valf) from aggdata group by key order by key;
key | sum2 | sum2_strict | stddev
-----+------+-------------+------------------
1 | | 4 | 6.43467170879758
@ -71,6 +70,49 @@ from aggdata group by key order by key;
9 | 0 | 0 |
(7 rows)
-- FILTER supported
select key, sum2(val) filter (where valf < 5), sum2_strict(val) filter (where valf < 5) from aggdata group by key order by key;
key | sum2 | sum2_strict
-----+------+-------------
1 | |
2 | | 10
3 | |
5 | |
6 | |
7 | |
9 | 0 | 0
(7 rows)
-- DISTINCT unsupported, unless grouped by partition key
select key, sum2(distinct val), sum2_strict(distinct val) from aggdata group by key order by key;
ERROR: cannot compute aggregate (distinct)
DETAIL: table partitioning is unsuitable for aggregate (distinct)
select id, sum2(distinct val), sum2_strict(distinct val) from aggdata group by id order by id;
id | sum2 | sum2_strict
----+------+-------------
1 | 4 | 4
2 | |
3 | 4 | 4
4 | 6 | 6
5 | 10 | 10
6 | 8 | 8
7 | |
8 | |
9 | |
10 | 16 | 16
11 | 0 | 0
(11 rows)
-- ORDER BY unsupported
select key, sum2(val order by valf), sum2_strict(val order by valf) from aggdata group by key order by key;
ERROR: unsupported aggregate function sum2
-- Without intermediate results we return NULL, even though the correct result is 0
select sum2(val) from aggdata where valf = 0;
sum2
------
(1 row)
-- test polymorphic aggregates from https://github.com/citusdata/citus/issues/2397
-- we do not currently support pseudotypes for transition types, so this errors for now
CREATE OR REPLACE FUNCTION first_agg(anyelement, anyelement)
@ -121,8 +163,7 @@ create aggregate sumstring(text) (
combinefunc = sumstring_sfunc,
initcond = '0'
);
select sumstring(valf::text order by id)
from aggdata where valf is not null;
select sumstring(valf::text) from aggdata where valf is not null;
ERROR: function "aggregate_support.sumstring(text)" does not exist
CONTEXT: while executing command on localhost:57637
select create_distributed_function('sumstring(text)');
@ -131,8 +172,7 @@ select create_distributed_function('sumstring(text)');
(1 row)
select sumstring(valf::text order by id)
from aggdata where valf is not null;
select sumstring(valf::text) from aggdata where valf is not null;
sumstring
-----------
1339.59

View File

@ -52,8 +52,18 @@ select create_distributed_function('sum2_strict(int)');
create table aggdata (id int, key int, val int, valf float8);
select create_distributed_table('aggdata', 'id');
insert into aggdata (id, key, val, valf) values (1, 1, 2, 11.2), (2, 1, NULL, 2.1), (3, 2, 2, 3.22), (4, 2, 3, 4.23), (5, 2, 5, 5.25), (6, 3, 4, 63.4), (7, 5, NULL, 75), (8, 6, NULL, NULL), (9, 6, NULL, 96), (10, 7, 8, 1078), (11, 9, 0, 1.19);
select key, sum2(val), sum2_strict(val), stddev(valf)
from aggdata group by key order by key;
select key, sum2(val), sum2_strict(val), stddev(valf) from aggdata group by key order by key;
-- FILTER supported
select key, sum2(val) filter (where valf < 5), sum2_strict(val) filter (where valf < 5) from aggdata group by key order by key;
-- DISTINCT unsupported, unless grouped by partition key
select key, sum2(distinct val), sum2_strict(distinct val) from aggdata group by key order by key;
select id, sum2(distinct val), sum2_strict(distinct val) from aggdata group by id order by id;
-- ORDER BY unsupported
select key, sum2(val order by valf), sum2_strict(val order by valf) from aggdata group by key order by key;
-- Without intermediate results we return NULL, even though the correct result is 0
select sum2(val) from aggdata where valf = 0;
-- test polymorphic aggregates from https://github.com/citusdata/citus/issues/2397
-- we do not currently support pseudotypes for transition types, so this errors for now
@ -102,11 +112,9 @@ create aggregate sumstring(text) (
initcond = '0'
);
select sumstring(valf::text order by id)
from aggdata where valf is not null;
select sumstring(valf::text) from aggdata where valf is not null;
select create_distributed_function('sumstring(text)');
select sumstring(valf::text order by id)
from aggdata where valf is not null;
select sumstring(valf::text) from aggdata where valf is not null;
-- test aggregate with stype that has an expanded read-write form
CREATE FUNCTION array_sort (int[])