From 7c0589abb8e1b8779515857c02d74b7f6971b659 Mon Sep 17 00:00:00 2001 From: Marco Slot Date: Mon, 3 Apr 2023 19:43:09 +0200 Subject: [PATCH] Do not override combinefunc of custom aggregates with common names (#6805) DESCRIPTION: Fix an issue that caused some queries with custom aggregates to fail While playing around with https://github.com/pgvector/pgvector I noticed that the AVG query was broken. That's because we treat it as any other AVG by breaking it down into SUM and COUNT; there are no SUM/COUNT functions in this case, but there is a perfectly usable combinefunc. This PR changes our aggregate logic to prefer custom aggregates with a combinefunc even if they have a common name. Co-authored-by: Marco Slot --- .../planner/multi_logical_optimizer.c | 11 +++++-- .../regress/expected/aggregate_support.out | 33 ++++++++++++++++++- src/test/regress/sql/aggregate_support.sql | 31 +++++++++++++++++ 3 files changed, 71 insertions(+), 4 deletions(-) diff --git a/src/backend/distributed/planner/multi_logical_optimizer.c b/src/backend/distributed/planner/multi_logical_optimizer.c index 19b4aea4d..851afc4b6 100644 --- a/src/backend/distributed/planner/multi_logical_optimizer.c +++ b/src/backend/distributed/planner/multi_logical_optimizer.c @@ -3385,6 +3385,13 @@ GetAggregateType(Aggref *aggregateExpression) { Oid aggFunctionId = aggregateExpression->aggfnoid; + /* custom aggregates with combine func take precedence over name-based logic */ + if (aggFunctionId >= FirstNormalObjectId && + AggregateEnabledCustom(aggregateExpression)) + { + return AGGREGATE_CUSTOM_COMBINE; + } + /* look up the function name */ char *aggregateProcName = get_func_name(aggFunctionId); if (aggregateProcName == NULL) @@ -3395,8 +3402,6 @@ GetAggregateType(Aggref *aggregateExpression) uint32 aggregateCount = lengthof(AggregateNames); - Assert(AGGREGATE_INVALID_FIRST == 0); - for (uint32 aggregateIndex = 1; aggregateIndex < aggregateCount; aggregateIndex++) { const char *aggregateName = 
AggregateNames[aggregateIndex]; @@ -3465,7 +3470,7 @@ GetAggregateType(Aggref *aggregateExpression) } } - + /* handle any remaining built-in aggregates with a suitable combinefn */ if (AggregateEnabledCustom(aggregateExpression)) { return AGGREGATE_CUSTOM_COMBINE; diff --git a/src/test/regress/expected/aggregate_support.out b/src/test/regress/expected/aggregate_support.out index ba9d9e2f3..57bbbbe78 100644 --- a/src/test/regress/expected/aggregate_support.out +++ b/src/test/regress/expected/aggregate_support.out @@ -665,6 +665,8 @@ select array_collect_sort(val) from aggdata; (1 row) reset role; +drop owned by notsuper; +drop user notsuper; -- Test aggregation on coordinator set citus.coordinator_aggregation_strategy to 'row-gather'; select key, first(val order by id), last(val order by id) @@ -1233,11 +1235,40 @@ CREATE AGGREGATE newavg ( initcond1 = '{0,0}' ); SELECT run_command_on_workers($$select aggfnoid from pg_aggregate where aggfnoid::text like '%newavg%';$$); - run_command_on_workers + run_command_on_workers --------------------------------------------------------------------- (localhost,57637,t,aggregate_support.newavg) (localhost,57638,t,aggregate_support.newavg) (2 rows) +CREATE TYPE coord AS (x int, y int); +CREATE FUNCTION coord_minx_sfunc(state coord, new coord) +returns coord immutable language plpgsql as $$ +BEGIN + IF (state IS NULL OR new.x < state.x) THEN + RETURN new; + ELSE + RETURN state; + END IF; +END +$$; +create function coord_minx_finalfunc(state coord) +returns coord immutable language plpgsql as $$ +begin return state; +end; +$$; +-- custom aggregate that has the same name as a built-in function, but with a combinefunc +create aggregate min (coord) ( + sfunc = coord_minx_sfunc, + stype = coord, + finalfunc = coord_minx_finalfunc, + combinefunc = coord_minx_sfunc +); +select min((id,val)::coord) from aggdata; + min +--------------------------------------------------------------------- + (1,2) +(1 row) + set client_min_messages to error; 
drop schema aggregate_support cascade; diff --git a/src/test/regress/sql/aggregate_support.sql b/src/test/regress/sql/aggregate_support.sql index a991d856e..bccac491e 100644 --- a/src/test/regress/sql/aggregate_support.sql +++ b/src/test/regress/sql/aggregate_support.sql @@ -364,6 +364,8 @@ $$); set role notsuper; select array_collect_sort(val) from aggdata; reset role; +drop owned by notsuper; +drop user notsuper; -- Test aggregation on coordinator set citus.coordinator_aggregation_strategy to 'row-gather'; @@ -645,5 +647,34 @@ CREATE AGGREGATE newavg ( SELECT run_command_on_workers($$select aggfnoid from pg_aggregate where aggfnoid::text like '%newavg%';$$); +CREATE TYPE coord AS (x int, y int); + +CREATE FUNCTION coord_minx_sfunc(state coord, new coord) +returns coord immutable language plpgsql as $$ +BEGIN + IF (state IS NULL OR new.x < state.x) THEN + RETURN new; + ELSE + RETURN state; + END IF; +END +$$; + +create function coord_minx_finalfunc(state coord) +returns coord immutable language plpgsql as $$ +begin return state; +end; +$$; + +-- custom aggregate that has the same name as a built-in function, but with a combinefunc +create aggregate min (coord) ( + sfunc = coord_minx_sfunc, + stype = coord, + finalfunc = coord_minx_finalfunc, + combinefunc = coord_minx_sfunc +); + +select min((id,val)::coord) from aggdata; + set client_min_messages to error; drop schema aggregate_support cascade;