diff --git a/src/backend/distributed/planner/multi_logical_optimizer.c b/src/backend/distributed/planner/multi_logical_optimizer.c index c31cf9748..af14d87cc 100644 --- a/src/backend/distributed/planner/multi_logical_optimizer.c +++ b/src/backend/distributed/planner/multi_logical_optimizer.c @@ -1489,7 +1489,7 @@ MasterAggregateMutator(Node *originalNode, MasterAggregateWalkerContext *walkerC /* * MasterAggregateExpression creates the master aggregate expression using the * original aggregate and aggregate's type information. This function handles - * the average, count, and array_agg aggregates separately due to differences + * the average, count, array_agg, and hll aggregates separately due to differences * in these aggregate functions' transformations. * * Note that this function has implicit knowledge of the transformations applied @@ -1763,6 +1763,40 @@ MasterAggregateExpression(Aggref *originalAggregate, newMasterExpression = (Expr *) newMasterAggregate; } + else if (aggregateType == AGGREGATE_HLL_ADD || + aggregateType == AGGREGATE_HLL_UNION) + { + /* + * If hll aggregates are called, we simply create the hll_union_aggregate + * to apply in the master after running the original aggregate in + * workers. + */ + TargetEntry *hllTargetEntry = NULL; + Aggref *unionAggregate = NULL; + + Oid hllType = exprType((Node *) originalAggregate); + Oid unionFunctionId = AggregateFunctionOid(HLL_UNION_AGGREGATE_NAME, hllType); + int32 hllReturnTypeMod = exprTypmod((Node *) originalAggregate); + Oid hllTypeCollationId = exprCollation((Node *) originalAggregate); + + Var *hllColumn = makeVar(masterTableId, walkerContext->columnId, hllType, + hllReturnTypeMod, hllTypeCollationId, columnLevelsUp); + walkerContext->columnId++; + + hllTargetEntry = makeTargetEntry((Expr *) hllColumn, argumentId, NULL, false); + + unionAggregate = makeNode(Aggref); + unionAggregate->aggfnoid = unionFunctionId; + unionAggregate->aggtype = hllType; + unionAggregate->args = list_make1(hllTargetEntry); + unionAggregate->aggkind = AGGKIND_NORMAL; + unionAggregate->aggfilter = NULL; + unionAggregate->aggtranstype = InvalidOid; + unionAggregate->aggargtypes = list_make1_oid(hllType); + unionAggregate->aggsplit = AGGSPLIT_SIMPLE; + + newMasterExpression = (Expr *) unionAggregate; + } else { /* diff --git a/src/include/distributed/multi_logical_optimizer.h b/src/include/distributed/multi_logical_optimizer.h index 877a0addf..483c0cc60 100644 --- a/src/include/distributed/multi_logical_optimizer.h +++ b/src/include/distributed/multi_logical_optimizer.h @@ -66,7 +66,9 @@ typedef enum AGGREGATE_BIT_OR = 12, AGGREGATE_BOOL_AND = 13, AGGREGATE_BOOL_OR = 14, - AGGREGATE_EVERY = 15 + AGGREGATE_EVERY = 15, + AGGREGATE_HLL_ADD = 16, + AGGREGATE_HLL_UNION = 17 } AggregateType; @@ -111,7 +113,8 @@ static const char *const AggregateNames[] = { "sum", "count", "array_agg", "jsonb_agg", "jsonb_object_agg", "json_agg", "json_object_agg", - "bit_and", "bit_or", "bool_and", "bool_or", "every" + "bit_and", "bit_or", "bool_and", "bool_or", "every", + "hll_add_agg", "hll_union_agg" };