/*-------------------------------------------------------------------------
*
* multi_logical_optimizer.c
* Routines for optimizing logical plan trees based on multi-relational
* algebra.
*
* Copyright (c) Citus Data, Inc.
*
* $Id$
*
*-------------------------------------------------------------------------
*/
#include <math.h>
#include "postgres.h"
#include "access/genam.h"
#include "access/heapam.h"
#include "access/htup_details.h"
#include "access/nbtree.h"
#include "catalog/indexing.h"
#include "catalog/pg_aggregate.h"
#include "catalog/pg_am.h"
#include "catalog/pg_proc.h"
#include "catalog/pg_type.h"
#include "commands/extension.h"
#include "nodes/makefuncs.h"
#include "nodes/nodeFuncs.h"
#include "optimizer/clauses.h"
#include "optimizer/optimizer.h"
#include "optimizer/tlist.h"
#include "parser/parse_agg.h"
#include "parser/parse_coerce.h"
#include "parser/parse_oper.h"
#include "parser/parsetree.h"
#include "rewrite/rewriteManip.h"
#include "utils/fmgroids.h"
#include "utils/lsyscache.h"
#include "utils/rel.h"
#include "utils/syscache.h"
#include "pg_version_constants.h"
#include "distributed/citus_nodes.h"
#include "distributed/citus_ruleutils.h"
#include "distributed/colocation_utils.h"
#include "distributed/errormessage.h"
#include "distributed/extended_op_node_utils.h"
#include "distributed/function_utils.h"
#include "distributed/listutils.h"
#include "distributed/metadata_cache.h"
#include "distributed/multi_logical_optimizer.h"
#include "distributed/multi_logical_planner.h"
#include "distributed/multi_physical_planner.h"
#include "distributed/pg_dist_partition.h"
#include "distributed/query_pushdown_planning.h"
#include "distributed/string_utils.h"
#include "distributed/tdigest_extension.h"
#include "distributed/version_compat.h"
#include "distributed/worker_protocol.h"
/* Config variable managed via guc.c */
int LimitClauseRowFetchCount = -1; /* number of rows to fetch from each task */
double CountDistinctErrorRate = 0.0; /* desired error rate for count(distinct) approximation */
int CoordinatorAggregationStrategy = COORDINATOR_AGGREGATION_ROW_GATHER;
/* Constant used throughout file */
static const uint32 masterTableId = 1; /* first range table reference on the master node */
typedef struct MasterAggregateWalkerContext
{
const ExtendedOpNodeProperties *extendedOpNodeProperties;
AttrNumber columnId;
} MasterAggregateWalkerContext;
typedef struct WorkerAggregateWalkerContext
{
const ExtendedOpNodeProperties *extendedOpNodeProperties;
List *expressionList;
bool createGroupByClause;
} WorkerAggregateWalkerContext;
/*
* QueryTargetList encapsulates the necessary fields to form
* the worker query's target list.
*/
typedef struct QueryTargetList
{
List *targetEntryList; /* the list of target entries */
AttrNumber targetProjectionNumber; /* the index of the last entry */
} QueryTargetList;
/*
* QueryGroupClause encapsulates the necessary fields to form
* the worker query's group by clause.
*/
typedef struct QueryGroupClause
{
List *groupClauseList; /* the list of group clause entries */
Index *nextSortGroupRefIndex; /* pointer to the next available sort group reference index */
} QueryGroupClause;
/*
* QueryDistinctClause encapsulates the necessary fields to form
* the worker query's DISTINCT/DISTINCT ON parts.
*/
typedef struct QueryDistinctClause
{
List *workerDistinctClause; /* the list of distinct clause entries */
bool workerHasDistinctOn;
} QueryDistinctClause;
/*
* QueryWindowClause encapsulates the necessary fields to form
* the worker query's window clause.
*/
typedef struct QueryWindowClause
{
List *workerWindowClauseList; /* the list of window clause entries */
bool hasWindowFunctions;
Index *nextSortGroupRefIndex; /* see QueryGroupClause */
} QueryWindowClause;
/*
* QueryOrderByLimit encapsulates the necessary fields to form
* the worker query's order by and limit clauses. Note that we don't keep
* track of the limit offset clause separately, since it is incorporated
* into the limit clause during processing.
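*
* For example (a sketch, assuming the limit is pushable at all): for a
* query ending in LIMIT 10 OFFSET 5, each worker is asked for LIMIT 15
* (limit + offset), and the master re-applies LIMIT 10 OFFSET 5 on the
* gathered rows.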
*/
typedef struct QueryOrderByLimit
{
Node *workerLimitCount;
List *workerSortClauseList;
Index *nextSortGroupRefIndex; /* see QueryGroupClause */
} QueryOrderByLimit;
/*
* LimitPushdownable tells us how a limit can be pushed down.
* See WorkerLimitCount for details.
*/
typedef enum LimitPushdownable
{
LIMIT_CANNOT_PUSHDOWN,
LIMIT_CAN_PUSHDOWN,
LIMIT_CAN_APPROXIMATE,
} LimitPushdownable;
/*
* OrderByLimitReference is a structure that is commonly used while
* processing sort and limit clauses.
*/
typedef struct OrderByLimitReference
{
bool groupedByDisjointPartitionColumn;
bool onlyPushableWindowFunctions;
bool groupClauseIsEmpty;
bool sortClauseIsEmpty;
bool hasOrderByAggregate;
bool canApproximate;
bool hasDistinctOn;
} OrderByLimitReference;
/* Local functions forward declarations */
static MultiSelect * AndSelectNode(MultiSelect *selectNode);
static MultiSelect * OrSelectNode(MultiSelect *selectNode);
static List * OrSelectClauseList(List *selectClauseList);
static void PushDownNodeLoop(MultiUnaryNode *currentNode);
static void PullUpCollectLoop(MultiCollect *collectNode);
static void AddressProjectSpecialConditions(MultiProject *projectNode);
static PushDownStatus CanPushDown(MultiUnaryNode *parentNode);
static PullUpStatus CanPullUp(MultiUnaryNode *childNode);
static PushDownStatus Commutative(MultiUnaryNode *parentNode,
MultiUnaryNode *childNode);
static PushDownStatus Distributive(MultiUnaryNode *parentNode,
MultiBinaryNode *childNode);
static PullUpStatus Factorizable(MultiBinaryNode *parentNode,
MultiUnaryNode *childNode);
static List * SelectClauseTableIdList(List *selectClauseList);
static MultiUnaryNode * GenerateLeftNode(MultiUnaryNode *currentNode,
MultiBinaryNode *binaryNode);
static MultiUnaryNode * GenerateRightNode(MultiUnaryNode *currentNode,
MultiBinaryNode *binaryNode);
static MultiUnaryNode * GenerateNode(MultiUnaryNode *currentNode, MultiNode *childNode);
static List * TableIdListColumns(List *tableIdList, List *columnList);
static List * TableIdListSelectClauses(List *tableIdList, List *selectClauseList);
static void PushDownBelowUnaryChild(MultiUnaryNode *currentNode,
MultiUnaryNode *childNode);
static void PlaceUnaryNodeChild(MultiUnaryNode *unaryNode, MultiUnaryNode *childNode);
static void PlaceBinaryNodeLeftChild(MultiBinaryNode *binaryNode,
MultiUnaryNode *newLeftChildNode);
static void PlaceBinaryNodeRightChild(MultiBinaryNode *binaryNode,
MultiUnaryNode *newRightChildNode);
static void RemoveUnaryNode(MultiUnaryNode *unaryNode);
static void PullUpUnaryNode(MultiUnaryNode *unaryNode);
static void ParentSetNewChild(MultiNode *parentNode, MultiNode *oldChildNode,
MultiNode *newChildNode);
/* Local functions forward declarations for aggregate expressions */
static void ApplyExtendedOpNodes(MultiExtendedOp *originalNode,
MultiExtendedOp *masterNode,
MultiExtendedOp *workerNode);
static void TransformSubqueryNode(MultiTable *subqueryNode,
bool subqueryHasNonDistributableAggregates);
static MultiExtendedOp * MasterExtendedOpNode(MultiExtendedOp *originalOpNode,
ExtendedOpNodeProperties *
extendedOpNodeProperties);
static Node * MasterAggregateMutator(Node *originalNode,
MasterAggregateWalkerContext *walkerContext);
static Expr * MasterAggregateExpression(Aggref *originalAggregate,
MasterAggregateWalkerContext *walkerContext);
static Expr * MasterAverageExpression(Oid sumAggregateType, Oid countAggregateType,
AttrNumber *columnId);
static Expr * AddTypeConversion(Node *originalAggregate, Node *newExpression);
static MultiExtendedOp * WorkerExtendedOpNode(MultiExtendedOp *originalOpNode,
ExtendedOpNodeProperties *
extendedOpNodeProperties);
static void ProcessTargetListForWorkerQuery(List *targetEntryList,
ExtendedOpNodeProperties *
extendedOpNodeProperties,
QueryTargetList *queryTargetList,
QueryGroupClause *queryGroupClause);
static void ProcessHavingClauseForWorkerQuery(Node *havingQual,
ExtendedOpNodeProperties *
extendedOpNodeProperties,
Node **workerHavingQual,
QueryTargetList *queryTargetList,
QueryGroupClause *queryGroupClause);
static void ProcessDistinctClauseForWorkerQuery(List *distinctClause, bool hasDistinctOn,
List *groupClauseList,
bool queryHasAggregates,
QueryDistinctClause *queryDistinctClause,
bool *distinctPreventsLimitPushdown);
static void ProcessWindowFunctionsForWorkerQuery(List *windowClauseList,
List *originalTargetEntryList,
QueryWindowClause *queryWindowClause,
QueryTargetList *queryTargetList);
static void ProcessWindowFunctionPullUpForWorkerQuery(List *windowClause,
QueryTargetList *queryTargetList);
static void ProcessLimitOrderByForWorkerQuery(OrderByLimitReference orderByLimitReference,
Node *originalLimitCount, Node *limitOffset,
List *sortClauseList, List *groupClauseList,
List *originalTargetList,
QueryOrderByLimit *queryOrderByLimit,
QueryTargetList *queryTargetList);
static OrderByLimitReference BuildOrderByLimitReference(bool hasDistinctOn, bool
groupedByDisjointPartitionColumn,
bool onlyPushableWindowFunctions,
List *groupClause,
List *sortClauseList,
List *targetList);
static void ExpandWorkerTargetEntry(List *expressionList,
TargetEntry *originalTargetEntry,
bool addToGroupByClause,
QueryTargetList *queryTargetList,
QueryGroupClause *queryGroupClause);
static Index GetNextSortGroupRef(List *targetEntryList);
static TargetEntry * GenerateWorkerTargetEntry(TargetEntry *targetEntry,
Expr *workerExpression,
AttrNumber targetProjectionNumber);
static void AppendTargetEntryToGroupClause(TargetEntry *targetEntry,
QueryGroupClause *queryGroupClause);
static bool WorkerAggregateWalker(Node *node,
WorkerAggregateWalkerContext *walkerContext);
static List * WorkerAggregateExpressionList(Aggref *originalAggregate,
WorkerAggregateWalkerContext *
walkerContext);
static AggregateType GetAggregateType(Aggref *aggregateExpression);
static Oid AggregateArgumentType(Aggref *aggregate);
static Expr * FirstAggregateArgument(Aggref *aggregate);
static bool AggregateEnabledCustom(Aggref *aggregateExpression);
static Oid CitusFunctionOidWithSignature(char *functionName, int numargs, Oid *argtypes);
static Oid WorkerPartialAggOid(void);
static Oid CoordCombineAggOid(void);
static Oid AggregateFunctionOid(const char *functionName, Oid inputType);
static Oid TypeOid(Oid schemaId, const char *typeName);
static SortGroupClause * CreateSortGroupClause(Var *column);
/* Local functions forward declarations for count(distinct) approximations */
static const char * CountDistinctHashFunctionName(Oid argumentType);
static int CountDistinctStorageSize(double approximationErrorRate);
static Const * MakeIntegerConstInt64(int64 integerValue);
static Const * MakeIntegerConst(int32 integerValue);
/* Local functions forward declarations for aggregate expression checks */
static bool HasNonDistributableAggregates(MultiNode *logicalPlanNode);
static bool CanPushDownExpression(Node *expression,
const ExtendedOpNodeProperties *
extendedOpNodeProperties);
static DeferredErrorMessage * DeferErrorIfHasNonDistributableAggregates(MultiNode *
logicalPlanNode);
static DeferredErrorMessage * DeferErrorIfUnsupportedArrayAggregate(Aggref *
arrayAggregateExpression);
static DeferredErrorMessage * DeferErrorIfUnsupportedJsonAggregate(AggregateType type,
Aggref *
aggregateExpression);
static DeferredErrorMessage * DeferErrorIfUnsupportedAggregateDistinct(Aggref *
aggregateExpression,
MultiNode *
logicalPlanNode);
static Var * AggregateDistinctColumn(Aggref *aggregateExpression);
static bool TablePartitioningSupportsDistinct(List *tableNodeList,
MultiExtendedOp *opNode,
Var *distinctColumn,
AggregateType aggregateType);
/* Local functions forward declarations for limit clauses */
static Node * WorkerLimitCount(Node *limitCount, Node *limitOffset, OrderByLimitReference
orderByLimitReference);
static List * WorkerSortClauseList(Node *limitCount,
List *groupClauseList, List *sortClauseList,
OrderByLimitReference orderByLimitReference);
static bool CanPushDownLimitApproximate(List *sortClauseList, List *targetList);
static bool HasOrderByAggregate(List *sortClauseList, List *targetList);
static bool HasOrderByNonCommutativeAggregate(List *sortClauseList, List *targetList);
static bool HasOrderByComplexExpression(List *sortClauseList, List *targetList);
static bool HasOrderByHllType(List *sortClauseList, List *targetList);
static bool ShouldProcessDistinctOrderAndLimitForWorker(ExtendedOpNodeProperties *
extendedOpNodeProperties,
bool pushingDownOriginalGrouping,
Node *havingQual);
static bool IsIndexInRange(const List *list, int index);
/*
* MultiLogicalPlanOptimize applies multi-relational algebra optimizations on
* the given logical plan tree. Specifically, the function applies four sets of
* optimizations in a particular order.
*
* First, the function splits the select node into two nodes that contain And
* and Or clauses, and pushes down the node that contains And clauses. Second,
* the function pushes down the project node; this node either contains columns
* to return to the user, or aggregate expressions used by the aggregate node.
* Third, the function pulls up the collect operators in the tree. Fourth, the
* function finds the extended operator node, and splits this node into master
* and worker extended operator nodes.
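*
* As a rough end-to-end sketch (table and column names are illustrative):
* for SELECT l_quantity, sum(l_price) FROM lineitem
* WHERE l_quantity > 10 AND (l_tax = 0 OR l_discount = 0) GROUP BY l_quantity,
* the select node splits so the And clause (l_quantity > 10) can travel
* down the tree while the Or clause stays put; the project node is pushed
* down so only referenced columns flow up; collect nodes are pulled up and
* merged; and the extended operator splits so workers compute per-shard
* sum(l_price) ... GROUP BY l_quantity while the master sums the partial
* sums over the gathered results.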
*/
void
MultiLogicalPlanOptimize(MultiTreeRoot *multiLogicalPlan)
{
MultiNode *logicalPlanNode = (MultiNode *) multiLogicalPlan;
bool hasNonDistributableAggregates = HasNonDistributableAggregates(
logicalPlanNode);
List *extendedOpNodeList = FindNodesOfType(logicalPlanNode, T_MultiExtendedOp);
MultiExtendedOp *extendedOpNode = (MultiExtendedOp *) linitial(extendedOpNodeList);
ExtendedOpNodeProperties extendedOpNodeProperties = BuildExtendedOpNodeProperties(
extendedOpNode, hasNonDistributableAggregates);
if (!extendedOpNodeProperties.groupedByDisjointPartitionColumn &&
!extendedOpNodeProperties.pullUpIntermediateRows)
{
DeferredErrorMessage *aggregatePushdownError =
DeferErrorIfHasNonDistributableAggregates(logicalPlanNode);
if (aggregatePushdownError != NULL)
{
if (CoordinatorAggregationStrategy == COORDINATOR_AGGREGATION_DISABLED)
{
RaiseDeferredError(aggregatePushdownError, ERROR);
}
else
{
extendedOpNodeProperties.pullUpIntermediateRows = true;
extendedOpNodeProperties.pushDownGroupingAndHaving = false;
}
}
}
/*
* If a select node exists, we use the idempower property to split the node
* into two nodes that contain And and Or clauses. If both And and Or nodes
* exist, we modify the tree in place to swap the original select node with
* And and Or nodes. We then push down the And select node if it exists.
*/
List *selectNodeList = FindNodesOfType(logicalPlanNode, T_MultiSelect);
if (selectNodeList != NIL)
{
MultiSelect *selectNode = (MultiSelect *) linitial(selectNodeList);
MultiSelect *andSelectNode = AndSelectNode(selectNode);
MultiSelect *orSelectNode = OrSelectNode(selectNode);
if (andSelectNode != NULL && orSelectNode != NULL)
{
MultiNode *parentNode = ParentNode((MultiNode *) selectNode);
MultiNode *childNode = ChildNode((MultiUnaryNode *) selectNode);
Assert(UnaryOperator(parentNode));
SetChild((MultiUnaryNode *) parentNode, (MultiNode *) orSelectNode);
SetChild((MultiUnaryNode *) orSelectNode, (MultiNode *) andSelectNode);
SetChild((MultiUnaryNode *) andSelectNode, (MultiNode *) childNode);
}
else if (andSelectNode != NULL && orSelectNode == NULL)
{
andSelectNode = selectNode; /* no need to modify the tree */
}
if (andSelectNode != NULL)
{
PushDownNodeLoop((MultiUnaryNode *) andSelectNode);
}
}
/* push down the multi project node */
List *projectNodeList = FindNodesOfType(logicalPlanNode, T_MultiProject);
MultiProject *projectNode = (MultiProject *) linitial(projectNodeList);
PushDownNodeLoop((MultiUnaryNode *) projectNode);
/* pull up collect nodes and merge duplicate collects */
List *collectNodeList = FindNodesOfType(logicalPlanNode, T_MultiCollect);
MultiCollect *collectNode = NULL;
foreach_declared_ptr(collectNode, collectNodeList)
{
PullUpCollectLoop(collectNode);
}
/*
* We split the extended operator node into its equivalent master and worker
* operator nodes; and if the extended operator has aggregates, we transform
* aggregate functions accordingly for the master and worker operator nodes.
* If we can push down the limit clause, we also add limit count and sort
* clause list to the worker operator node. We then push the worker operator
* node below the collect node.
*/
MultiExtendedOp *masterExtendedOpNode =
MasterExtendedOpNode(extendedOpNode, &extendedOpNodeProperties);
MultiExtendedOp *workerExtendedOpNode =
WorkerExtendedOpNode(extendedOpNode, &extendedOpNodeProperties);
ApplyExtendedOpNodes(extendedOpNode, masterExtendedOpNode, workerExtendedOpNode);
List *tableNodeList = FindNodesOfType(logicalPlanNode, T_MultiTable);
MultiTable *tableNode = NULL;
foreach_declared_ptr(tableNode, tableNodeList)
{
if (tableNode->relationId == SUBQUERY_RELATION_ID)
{
DeferredErrorMessage *error =
DeferErrorIfHasNonDistributableAggregates((MultiNode *) tableNode);
bool subqueryHasNonDistributableAggregates = false;
if (error != NULL)
{
if (CoordinatorAggregationStrategy == COORDINATOR_AGGREGATION_DISABLED)
{
RaiseDeferredError(error, ERROR);
}
else
{
subqueryHasNonDistributableAggregates = true;
}
}
TransformSubqueryNode(tableNode, subqueryHasNonDistributableAggregates);
}
}
/*
* When enabled, count(distinct) approximation uses hll as the intermediate
* data type. We currently have a mismatch between hll target entry and sort
* clause's sortop oid, so we can't push an order by on the hll data type to
* the worker node. We check that here and error out if necessary.
*/
bool hasOrderByHllType = HasOrderByHllType(workerExtendedOpNode->sortClauseList,
workerExtendedOpNode->targetList);
if (hasOrderByHllType)
{
ereport(ERROR, (errmsg("cannot approximate count(distinct) and order by it"),
errhint("You might need to disable approximations for either "
"count(distinct) or limit through configuration.")));
}
if (TargetListContainsSubquery(masterExtendedOpNode->targetList))
{
ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("cannot push down subquery on the target list"),
errdetail("Subqueries in the SELECT part of the query can only "
"be pushed down if they happen before aggregates and "
"window functions")));
}
}
/*
* AndSelectNode looks for AND clauses in the given select node. If they exist,
* the function returns these clauses in a new node. Otherwise, the function
* returns null.
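*
* For example (illustrative): given WHERE x = 1 AND (y = 2 OR z = 3), the
* select clause list is { x = 1, (y = 2 OR z = 3) }; this function returns
* a new node holding only { x = 1 }, while OrSelectNode returns one holding
* the Or clause.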
*/
static MultiSelect *
AndSelectNode(MultiSelect *selectNode)
{
MultiSelect *andSelectNode = NULL;
List *selectClauseList = selectNode->selectClauseList;
List *orSelectClauseList = OrSelectClauseList(selectClauseList);
/* AND clauses are select clauses that are not OR clauses */
List *andSelectClauseList = list_difference(selectClauseList, orSelectClauseList);
if (andSelectClauseList != NIL)
{
andSelectNode = CitusMakeNode(MultiSelect);
andSelectNode->selectClauseList = andSelectClauseList;
}
return andSelectNode;
}
/*
* OrSelectNode looks for OR clauses in the given select node. If they exist,
* the function returns these clauses in a new node. Otherwise, the function
* returns null.
*/
static MultiSelect *
OrSelectNode(MultiSelect *selectNode)
{
MultiSelect *orSelectNode = NULL;
List *selectClauseList = selectNode->selectClauseList;
List *orSelectClauseList = OrSelectClauseList(selectClauseList);
if (orSelectClauseList != NIL)
{
orSelectNode = CitusMakeNode(MultiSelect);
orSelectNode->selectClauseList = orSelectClauseList;
}
return orSelectNode;
}
/*
* OrSelectClauseList walks over the select clause list, and returns all clauses
* that have OR expressions in them.
*/
static List *
OrSelectClauseList(List *selectClauseList)
{
List *orSelectClauseList = NIL;
Node *selectClause = NULL;
foreach_declared_ptr(selectClause, selectClauseList)
{
bool orClause = is_orclause(selectClause);
if (orClause)
{
orSelectClauseList = lappend(orSelectClauseList, selectClause);
}
}
return orSelectClauseList;
}
/*
* PushDownNodeLoop pushes down the current node as far down the plan tree as
* possible. For this, the function first addresses any special conditions that
* may apply on the current node. Then, the function pushes down the current
* node if its child node is unary. If the child is binary, the function splits
* the current node into two nodes by applying generation rules, and recurses
* into itself to push down these two nodes.
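*
* As a sketch: pushing a select node down through a join splits it into two
* generated select nodes, one per join branch, each keeping only the
* clauses whose tables are output by that branch; the function then
* recurses to push each generated node further down.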
*/
static void
PushDownNodeLoop(MultiUnaryNode *currentNode)
{
MultiUnaryNode *projectNodeGenerated = NULL;
MultiUnaryNode *leftNodeGenerated = NULL;
MultiUnaryNode *rightNodeGenerated = NULL;
PushDownStatus pushDownStatus = CanPushDown(currentNode);
while (pushDownStatus == PUSH_DOWN_VALID ||
pushDownStatus == PUSH_DOWN_SPECIAL_CONDITIONS)
{
MultiNode *childNode = currentNode->childNode;
bool unaryChild = UnaryOperator(childNode);
bool binaryChild = BinaryOperator(childNode);
/*
* We first check if we can use the idempower property to split the
* project node. We split at a partition node as it captures the
* minimal set of columns needed from a partition job. After the split
* we break from the loop and recursively call pushdown for the
* generated project node.
*/
MultiNode *parentNode = ParentNode((MultiNode *) currentNode);
CitusNodeTag currentNodeType = CitusNodeTag(currentNode);
CitusNodeTag parentNodeType = CitusNodeTag(parentNode);
if (currentNodeType == T_MultiProject && parentNodeType == T_MultiPartition)
{
projectNodeGenerated = GenerateNode(currentNode, childNode);
PlaceUnaryNodeChild(currentNode, projectNodeGenerated);
break;
}
/* address any special conditions before we can perform the pushdown */
if (pushDownStatus == PUSH_DOWN_SPECIAL_CONDITIONS)
{
MultiProject *projectNode = (MultiProject *) currentNode;
Assert(currentNodeType == T_MultiProject);
AddressProjectSpecialConditions(projectNode);
}
if (unaryChild)
{
MultiUnaryNode *unaryChildNode = (MultiUnaryNode *) childNode;
PushDownBelowUnaryChild(currentNode, unaryChildNode);
}
else if (binaryChild)
{
MultiBinaryNode *binaryChildNode = (MultiBinaryNode *) childNode;
leftNodeGenerated = GenerateLeftNode(currentNode, binaryChildNode);
rightNodeGenerated = GenerateRightNode(currentNode, binaryChildNode);
/* push down the generated nodes below the binary child node */
PlaceBinaryNodeLeftChild(binaryChildNode, leftNodeGenerated);
PlaceBinaryNodeRightChild(binaryChildNode, rightNodeGenerated);
/*
* Remove the current node, and break out of the push down loop for
* the current node. Then, recurse into the push down function for
* the newly generated nodes.
*/
RemoveUnaryNode(currentNode);
break;
}
pushDownStatus = CanPushDown(currentNode);
}
/* recursively perform pushdown of any nodes generated in the loop */
if (projectNodeGenerated != NULL)
{
PushDownNodeLoop(projectNodeGenerated);
}
if (leftNodeGenerated != NULL)
{
PushDownNodeLoop(leftNodeGenerated);
}
if (rightNodeGenerated != NULL)
{
PushDownNodeLoop(rightNodeGenerated);
}
}
/*
* PullUpCollectLoop pulls up the collect node as far up as possible in the plan
* tree. The function also merges two collect nodes that are direct descendants
* of each other by removing the given collect node from the tree.
*/
static void
PullUpCollectLoop(MultiCollect *collectNode)
{
MultiUnaryNode *currentNode = (MultiUnaryNode *) collectNode;
PullUpStatus pullUpStatus = CanPullUp(currentNode);
while (pullUpStatus == PULL_UP_VALID)
{
PullUpUnaryNode(currentNode);
pullUpStatus = CanPullUp(currentNode);
}
/*
* After pulling up the collect node, if we find that our child node is also
* a collect, we merge the two collect nodes together by removing this node.
*/
MultiNode *childNode = currentNode->childNode;
if (CitusIsA(childNode, MultiCollect))
{
RemoveUnaryNode(currentNode);
}
}
/*
* AddressProjectSpecialConditions adds columns to the project node if necessary
* to make the node commutative and distributive with its child node. For this,
* the function checks for any special conditions between the project and child
* node, and determines the child node columns to add for the special conditions
* to apply. The function then adds these columns to the project node.
*/
static void
AddressProjectSpecialConditions(MultiProject *projectNode)
{
MultiNode *childNode = ChildNode((MultiUnaryNode *) projectNode);
CitusNodeTag childNodeTag = CitusNodeTag(childNode);
List *childColumnList = NIL;
/*
* We check if we need to include any child columns in the project node to
* address the following special conditions.
*
* SNC1: project node must include child node's projected columns, or
* SNC2: project node must include child node's partition column, or
* SNC3: project node must include child node's selection columns, or
* NSC1: project node must include child node's join columns.
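*
* For example (illustrative): when the child is a join on t1.b = t2.b,
* NSC1 requires this project node to also carry t1.b and t2.b, even if the
* query's own target list never references them.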
*/
if (childNodeTag == T_MultiProject)
{
MultiProject *projectChildNode = (MultiProject *) childNode;
List *projectColumnList = projectChildNode->columnList;
childColumnList = copyObject(projectColumnList);
}
else if (childNodeTag == T_MultiPartition)
{
MultiPartition *partitionNode = (MultiPartition *) childNode;
Var *partitionColumn = partitionNode->partitionColumn;
List *partitionColumnList = list_make1(partitionColumn);
childColumnList = copyObject(partitionColumnList);
}
else if (childNodeTag == T_MultiSelect)
{
MultiSelect *selectNode = (MultiSelect *) childNode;
Node *selectClauseList = (Node *) selectNode->selectClauseList;
List *selectList = pull_var_clause_default(selectClauseList);
childColumnList = copyObject(selectList);
}
else if (childNodeTag == T_MultiJoin)
{
MultiJoin *joinNode = (MultiJoin *) childNode;
Node *joinClauseList = (Node *) joinNode->joinClauseList;
List *joinList = pull_var_clause_default(joinClauseList);
childColumnList = copyObject(joinList);
}
/*
* If we need to include any child columns, then find the columns that are
* not already in the project column list, and add them.
*/
if (childColumnList != NIL)
{
List *projectColumnList = projectNode->columnList;
List *newColumnList = list_concat_unique(projectColumnList, childColumnList);
projectNode->columnList = newColumnList;
}
}
/*
* CanPushDown determines if a particular node can be moved below its child. The
* criteria for pushing down a node are determined by multi-relational algebra's
* rules for commutativity and distributivity.
*/
static PushDownStatus
CanPushDown(MultiUnaryNode *parentNode)
{
PushDownStatus pushDownStatus = PUSH_DOWN_INVALID_FIRST;
MultiNode *childNode = parentNode->childNode;
bool unaryChild = UnaryOperator(childNode);
bool binaryChild = BinaryOperator(childNode);
if (unaryChild)
{
pushDownStatus = Commutative(parentNode, (MultiUnaryNode *) childNode);
}
else if (binaryChild)
{
pushDownStatus = Distributive(parentNode, (MultiBinaryNode *) childNode);
}
Assert(pushDownStatus != PUSH_DOWN_INVALID_FIRST);
return pushDownStatus;
}
/*
* CanPullUp determines if a particular node can be moved above its parent. The
* criteria for pulling up a node are determined by multi-relational algebra's
* rules for commutativity and factorizability.
*/
static PullUpStatus
CanPullUp(MultiUnaryNode *childNode)
{
PullUpStatus pullUpStatus = PULL_UP_INVALID_FIRST;
MultiNode *parentNode = ParentNode((MultiNode *) childNode);
bool unaryParent = UnaryOperator(parentNode);
bool binaryParent = BinaryOperator(parentNode);
if (unaryParent)
{
/*
* Evaluate if parent can be pushed down below the child node, since it
* is equivalent to pulling up the child above its parent.
*/
PushDownStatus parentPushDownStatus = Commutative((MultiUnaryNode *) parentNode,
childNode);
if (parentPushDownStatus == PUSH_DOWN_VALID)
{
pullUpStatus = PULL_UP_VALID;
}
else
{
pullUpStatus = PULL_UP_NOT_VALID;
}
}
else if (binaryParent)
{
pullUpStatus = Factorizable((MultiBinaryNode *) parentNode, childNode);
}
Assert(pullUpStatus != PULL_UP_INVALID_FIRST);
return pullUpStatus;
}
/*
* Commutative returns a status which denotes whether the given parent node can
* be pushed down below its child node using the commutative property.
*/
static PushDownStatus
Commutative(MultiUnaryNode *parentNode, MultiUnaryNode *childNode)
{
PushDownStatus pushDownStatus = PUSH_DOWN_NOT_VALID;
CitusNodeTag parentNodeTag = CitusNodeTag(parentNode);
CitusNodeTag childNodeTag = CitusNodeTag(childNode);
/* we cannot be commutative with non-query operators */
if (childNodeTag == T_MultiTreeRoot || childNodeTag == T_MultiTable)
{
return PUSH_DOWN_NOT_VALID;
}
/* first check for commutative operators and no special conditions */
if ((parentNodeTag == T_MultiPartition && childNodeTag == T_MultiProject) ||
(parentNodeTag == T_MultiPartition && childNodeTag == T_MultiPartition) ||
(parentNodeTag == T_MultiPartition && childNodeTag == T_MultiSelect))
{
pushDownStatus = PUSH_DOWN_VALID;
}
if ((parentNodeTag == T_MultiCollect && childNodeTag == T_MultiProject) ||
(parentNodeTag == T_MultiCollect && childNodeTag == T_MultiCollect) ||
(parentNodeTag == T_MultiCollect && childNodeTag == T_MultiSelect))
{
pushDownStatus = PUSH_DOWN_VALID;
}
if (parentNodeTag == T_MultiSelect)
{
pushDownStatus = PUSH_DOWN_VALID;
}
if (parentNodeTag == T_MultiProject && childNodeTag == T_MultiCollect)
{
pushDownStatus = PUSH_DOWN_VALID;
}
/*
* The project node is commutative with the below operators given that
* its special conditions apply.
*/
if ((parentNodeTag == T_MultiProject && childNodeTag == T_MultiProject) ||
(parentNodeTag == T_MultiProject && childNodeTag == T_MultiPartition) ||
(parentNodeTag == T_MultiProject && childNodeTag == T_MultiSelect) ||
(parentNodeTag == T_MultiProject && childNodeTag == T_MultiJoin))
{
pushDownStatus = PUSH_DOWN_SPECIAL_CONDITIONS;
}
return pushDownStatus;
}
/*
* Distributive returns a status which denotes whether the given parent node can
* be pushed down below its binary child node using the distributive property.
*/
static PushDownStatus
Distributive(MultiUnaryNode *parentNode, MultiBinaryNode *childNode)
{
PushDownStatus pushDownStatus = PUSH_DOWN_NOT_VALID;
CitusNodeTag parentNodeTag = CitusNodeTag(parentNode);
CitusNodeTag childNodeTag = CitusNodeTag(childNode);
/* special condition checks for partition operator are not implemented */
Assert(parentNodeTag != T_MultiPartition);
/*
* The project node is distributive with the join operator given that its
* special conditions apply.
*/
if (parentNodeTag == T_MultiProject)
{
pushDownStatus = PUSH_DOWN_SPECIAL_CONDITIONS;
}
/* collect node is distributive without special conditions */
if ((parentNodeTag == T_MultiCollect && childNodeTag == T_MultiJoin) ||
(parentNodeTag == T_MultiCollect && childNodeTag == T_MultiCartesianProduct))
{
pushDownStatus = PUSH_DOWN_VALID;
}
/*
* The select node is distributive with a binary operator if all tables in
* the select clauses are output by the binary child. The select clauses are
* individually AND'd; and therefore this check is sufficient to implement
* the NSC3 special condition in multi-relational algebra.
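*
* For example (illustrative): if the select clauses reference tables
* {t1, t2} and the join below outputs tables {t1, t2, t3}, the difference
* is empty and the select node distributes over the join; a clause that
* references a table the child does not output blocks the push down.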
*/
if ((parentNodeTag == T_MultiSelect && childNodeTag == T_MultiJoin) ||
(parentNodeTag == T_MultiSelect && childNodeTag == T_MultiCartesianProduct))
{
MultiSelect *selectNode = (MultiSelect *) parentNode;
List *selectClauseList = selectNode->selectClauseList;
List *selectTableIdList = SelectClauseTableIdList(selectClauseList);
List *childTableIdList = OutputTableIdList((MultiNode *) childNode);
/* find tables that are in select clause list, but not in child list */
List *diffList = list_difference_int(selectTableIdList, childTableIdList);
if (diffList == NIL)
{
pushDownStatus = PUSH_DOWN_VALID;
}
}
return pushDownStatus;
}
/*
* Factorizable returns a status which denotes whether the given unary child
* node can be pulled up above its binary parent node using the factorizability
* property. The function currently performs this check only for collect node
* types; other node types have generation rules that are not yet implemented.
*/
static PullUpStatus
Factorizable(MultiBinaryNode *parentNode, MultiUnaryNode *childNode)
{
PullUpStatus pullUpStatus = PULL_UP_NOT_VALID;
CitusNodeTag parentNodeTag = CitusNodeTag(parentNode);
CitusNodeTag childNodeTag = CitusNodeTag(childNode);
/*
* The following nodes are factorizable with their parents, but we don't
* have their generation rules implemented. We therefore assert here.
*/
Assert(childNodeTag != T_MultiProject);
Assert(childNodeTag != T_MultiPartition);
Assert(childNodeTag != T_MultiSelect);
if ((childNodeTag == T_MultiCollect && parentNodeTag == T_MultiJoin) ||
(childNodeTag == T_MultiCollect && parentNodeTag == T_MultiCartesianProduct))
{
pullUpStatus = PULL_UP_VALID;
}
return pullUpStatus;
}
/*
* SelectClauseTableIdList finds the (range) table identifier for each select
* clause in the given list, and returns these identifiers in a new list.
*/
static List *
SelectClauseTableIdList(List *selectClauseList)
{
List *tableIdList = NIL;
Node *selectClause = NULL;
foreach_declared_ptr(selectClause, selectClauseList)
{
List *selectColumnList = pull_var_clause_default(selectClause);
if (list_length(selectColumnList) == 0)
{
/* filter is a constant, e.g. false or 1=0 */
continue;
}
Var *selectColumn = (Var *) linitial(selectColumnList);
int selectColumnTableId = (int) selectColumn->varno;
tableIdList = lappend_int(tableIdList, selectColumnTableId);
}
return tableIdList;
}
/*
* GenerateLeftNode splits the current node over the binary node by applying the
* generation rule for distributivity in multi-relational algebra. After the
* split, the function returns the left node.
*/
static MultiUnaryNode *
GenerateLeftNode(MultiUnaryNode *currentNode, MultiBinaryNode *binaryNode)
{
MultiNode *leftChildNode = binaryNode->leftChildNode;
MultiUnaryNode *leftNodeGenerated = GenerateNode(currentNode, leftChildNode);
return leftNodeGenerated;
}
/*
* GenerateRightNode splits the current node over the binary node by applying
* the generation rule for distributivity in multi-relational algebra. After the
* split, the function returns the right node.
*/
static MultiUnaryNode *
GenerateRightNode(MultiUnaryNode *currentNode, MultiBinaryNode *binaryNode)
{
MultiNode *rightChildNode = binaryNode->rightChildNode;
MultiUnaryNode *rightNodeGenerated = GenerateNode(currentNode, rightChildNode);
return rightNodeGenerated;
}
/*
* GenerateNode determines the current node's type, and applies the relevant
* generation rule for that node type. If the current node is a project node,
* the function creates a new project node with attributes that only have the
* child subtree's tables. Else if the current node is a select node, the
* function creates a new select node with select clauses that only belong to
* the tables output by the child node's subtree.
*/
static MultiUnaryNode *
GenerateNode(MultiUnaryNode *currentNode, MultiNode *childNode)
{
MultiUnaryNode *generatedNode = NULL;
CitusNodeTag currentNodeType = CitusNodeTag(currentNode);
List *tableIdList = OutputTableIdList(childNode);
if (currentNodeType == T_MultiProject)
{
MultiProject *projectNode = (MultiProject *) currentNode;
List *columnList = copyObject(projectNode->columnList);
List *newColumnList = TableIdListColumns(tableIdList, columnList);
if (newColumnList != NIL)
{
MultiProject *newProjectNode = CitusMakeNode(MultiProject);
newProjectNode->columnList = newColumnList;
generatedNode = (MultiUnaryNode *) newProjectNode;
}
}
else if (currentNodeType == T_MultiSelect)
{
MultiSelect *selectNode = (MultiSelect *) currentNode;
List *selectClauseList = copyObject(selectNode->selectClauseList);
List *newSelectClauseList = TableIdListSelectClauses(tableIdList,
selectClauseList);
if (newSelectClauseList != NIL)
{
MultiSelect *newSelectNode = CitusMakeNode(MultiSelect);
newSelectNode->selectClauseList = newSelectClauseList;
generatedNode = (MultiUnaryNode *) newSelectNode;
}
}
return generatedNode;
}
/*
* TableIdListColumns walks over the given column list, finds columns belonging
* to the given table id list, and returns the found columns in a new list.
*/
static List *
TableIdListColumns(List *tableIdList, List *columnList)
{
List *tableColumnList = NIL;
Var *column = NULL;
foreach_declared_ptr(column, columnList)
{
int columnTableId = (int) column->varno;
bool tableListMember = list_member_int(tableIdList, columnTableId);
if (tableListMember)
{
tableColumnList = lappend(tableColumnList, column);
}
}
return tableColumnList;
}
/*
* TableIdListSelectClauses walks over the given select clause list, finds the
* select clauses whose column references belong to the given table list, and
* returns the found clauses in a new list.
*/
static List *
TableIdListSelectClauses(List *tableIdList, List *selectClauseList)
{
List *tableSelectClauseList = NIL;
Node *selectClause = NULL;
foreach_declared_ptr(selectClause, selectClauseList)
{
List *selectColumnList = pull_var_clause_default(selectClause);
if (list_length(selectColumnList) == 0)
{
/* filter is a constant, e.g. false or 1=0, always include it */
tableSelectClauseList = lappend(tableSelectClauseList, selectClause);
}
else
{
Var *selectColumn = (Var *) linitial(selectColumnList);
int selectClauseTableId = (int) selectColumn->varno;
bool tableIdListMember = list_member_int(tableIdList, selectClauseTableId);
if (tableIdListMember)
{
tableSelectClauseList = lappend(tableSelectClauseList, selectClause);
}
}
}
return tableSelectClauseList;
}
/* Pushes down the current node below its unary child node. */
static void
PushDownBelowUnaryChild(MultiUnaryNode *currentNode, MultiUnaryNode *childNode)
{
MultiNode *parentNode = ParentNode((MultiNode *) currentNode);
MultiNode *childChildNode = ChildNode(childNode);
/* current node's parent now points to the child node */
ParentSetNewChild(parentNode, (MultiNode *) currentNode, (MultiNode *) childNode);
/* current node's child becomes its parent */
SetChild(childNode, (MultiNode *) currentNode);
/* current node points to the child node's child */
SetChild(currentNode, childChildNode);
}
/*
* PlaceUnaryNodeChild inserts the new node as a child node under the given
* unary node. The function also places the previous child node under the new
* child node.
*/
static void
PlaceUnaryNodeChild(MultiUnaryNode *unaryNode, MultiUnaryNode *newChildNode)
{
MultiNode *oldChildNode = ChildNode(unaryNode);
SetChild(unaryNode, (MultiNode *) newChildNode);
SetChild(newChildNode, oldChildNode);
}
/*
* PlaceBinaryNodeLeftChild inserts the new left child as the binary node's left
* child. The function also places the previous left child below the new child
* node.
*/
static void
PlaceBinaryNodeLeftChild(MultiBinaryNode *binaryNode, MultiUnaryNode *newLeftChildNode)
{
if (newLeftChildNode == NULL)
{
return;
}
SetChild(newLeftChildNode, binaryNode->leftChildNode);
SetLeftChild(binaryNode, (MultiNode *) newLeftChildNode);
}
/*
* PlaceBinaryNodeRightChild inserts the new right child as the binary node's
* right child. The function also places the previous right child below the new
* child node.
*/
static void
PlaceBinaryNodeRightChild(MultiBinaryNode *binaryNode, MultiUnaryNode *newRightChildNode)
{
if (newRightChildNode == NULL)
{
return;
}
SetChild(newRightChildNode, binaryNode->rightChildNode);
SetRightChild(binaryNode, (MultiNode *) newRightChildNode);
}
/* Removes the given unary node from the logical plan, and frees the node. */
static void
RemoveUnaryNode(MultiUnaryNode *unaryNode)
{
MultiNode *parentNode = ParentNode((MultiNode *) unaryNode);
MultiNode *childNode = ChildNode(unaryNode);
/* set parent to directly point to unary node's child */
ParentSetNewChild(parentNode, (MultiNode *) unaryNode, childNode);
pfree(unaryNode);
}
/* Pulls up the given current node above its parent node. */
static void
PullUpUnaryNode(MultiUnaryNode *unaryNode)
{
MultiNode *parentNode = ParentNode((MultiNode *) unaryNode);
bool unaryParent = UnaryOperator(parentNode);
bool binaryParent = BinaryOperator(parentNode);
if (unaryParent)
{
/* pulling up a node is the same as pushing down the node's unary parent */
MultiUnaryNode *unaryParentNode = (MultiUnaryNode *) parentNode;
PushDownBelowUnaryChild(unaryParentNode, unaryNode);
}
else if (binaryParent)
{
MultiBinaryNode *binaryParentNode = (MultiBinaryNode *) parentNode;
MultiNode *parentParentNode = ParentNode((MultiNode *) binaryParentNode);
MultiNode *childNode = unaryNode->childNode;
/* make the parent node point to the unary node's child node */
if (binaryParentNode->leftChildNode == ((MultiNode *) unaryNode))
{
SetLeftChild(binaryParentNode, childNode);
}
else
{
SetRightChild(binaryParentNode, childNode);
}
/* make the parent parent node point to the unary node */
ParentSetNewChild(parentParentNode, parentNode, (MultiNode *) unaryNode);
/* make the unary node point to the (old) parent node */
SetChild(unaryNode, parentNode);
}
}
/*
* ParentSetNewChild takes in the given parent node, and replaces the parent's
* old child node with the new child node. The function needs the old child node
* in case the parent is a binary node and the function needs to determine which
* side of the parent node the new child node needs to go to.
*/
static void
ParentSetNewChild(MultiNode *parentNode, MultiNode *oldChildNode,
MultiNode *newChildNode)
{
bool unaryParent = UnaryOperator(parentNode);
bool binaryParent = BinaryOperator(parentNode);
if (unaryParent)
{
MultiUnaryNode *unaryParentNode = (MultiUnaryNode *) parentNode;
SetChild(unaryParentNode, newChildNode);
}
else if (binaryParent)
{
MultiBinaryNode *binaryParentNode = (MultiBinaryNode *) parentNode;
/* determine which side of the parent the old child is on */
if (binaryParentNode->leftChildNode == oldChildNode)
{
SetLeftChild(binaryParentNode, newChildNode);
}
else
{
SetRightChild(binaryParentNode, newChildNode);
}
}
}
/*
* ApplyExtendedOpNodes replaces the original extended operator node with the
* master and worker extended operator nodes. The function then pushes down the
* worker node below the original node's child node. Note that for the push down
* to apply, the original node's child must be a collect node.
*/
static void
ApplyExtendedOpNodes(MultiExtendedOp *originalNode, MultiExtendedOp *masterNode,
MultiExtendedOp *workerNode)
{
MultiNode *parentNode = ParentNode((MultiNode *) originalNode);
MultiNode *collectNode = ChildNode((MultiUnaryNode *) originalNode);
MultiNode *collectChildNode = ChildNode((MultiUnaryNode *) collectNode);
/* original node's child must be a collect node */
Assert(CitusIsA(collectNode, MultiCollect));
Assert(UnaryOperator(parentNode));
/* swap the original aggregate node with the master extended node */
SetChild((MultiUnaryNode *) parentNode, (MultiNode *) masterNode);
SetChild((MultiUnaryNode *) masterNode, (MultiNode *) collectNode);
/* add the worker extended node below the collect node */
SetChild((MultiUnaryNode *) collectNode, (MultiNode *) workerNode);
SetChild((MultiUnaryNode *) workerNode, (MultiNode *) collectChildNode);
/* clean up the original extended operator node */
pfree(originalNode);
}
/*
* TransformSubqueryNode splits the extended operator node under a subquery
* multi table node into its equivalent master and worker operator nodes, and
* transforms aggregate functions accordingly for each of them. The function
* then creates a partition node based on the first group by column of the
* extended operator node and sets it as the child of the master operator
* node.
*/
static void
TransformSubqueryNode(MultiTable *subqueryNode,
bool subqueryHasNonDistributableAggregates)
{
if (CoordinatorAggregationStrategy != COORDINATOR_AGGREGATION_DISABLED &&
HasNonDistributableAggregates((MultiNode *) subqueryNode))
{
subqueryHasNonDistributableAggregates = true;
}
MultiExtendedOp *extendedOpNode =
(MultiExtendedOp *) ChildNode((MultiUnaryNode *) subqueryNode);
MultiNode *collectNode = ChildNode((MultiUnaryNode *) extendedOpNode);
MultiNode *collectChildNode = ChildNode((MultiUnaryNode *) collectNode);
ExtendedOpNodeProperties extendedOpNodeProperties =
BuildExtendedOpNodeProperties(extendedOpNode,
subqueryHasNonDistributableAggregates);
MultiExtendedOp *masterExtendedOpNode =
MasterExtendedOpNode(extendedOpNode, &extendedOpNodeProperties);
MultiExtendedOp *workerExtendedOpNode =
WorkerExtendedOpNode(extendedOpNode, &extendedOpNodeProperties);
List *groupClauseList = extendedOpNode->groupClauseList;
List *targetEntryList = extendedOpNode->targetList;
List *groupTargetEntryList = GroupTargetEntryList(groupClauseList, targetEntryList);
TargetEntry *groupByTargetEntry = (TargetEntry *) linitial(groupTargetEntryList);
Expr *groupByExpression = groupByTargetEntry->expr;
MultiPartition *partitionNode = CitusMakeNode(MultiPartition);
/*
* If the group by is on a function expression, we create a new column whose
* type is the function expression's result type, because later, while
* creating partition tasks, we expect a column type to partition the
* intermediate results by. Note that only the partition type is needed, so
* we set the column type to the function expression's result type and leave
* the other fields of the column at their default values.
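*
* For example (illustrative): GROUP BY date_trunc('day', created_at)
* produces a synthetic Var whose vartype is the function's result type
* (timestamp with time zone here); only that type is consulted when
* repartitioning the intermediate results.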
*/
if (IsA(groupByExpression, Var))
{
partitionNode->partitionColumn = (Var *) groupByExpression;
}
else if (IsA(groupByExpression, FuncExpr))
{
FuncExpr *functionExpression = (FuncExpr *) groupByExpression;
Index tableId = 0;
AttrNumber columnAttributeNumber = InvalidAttrNumber;
Oid columnType = functionExpression->funcresulttype;
int32 columnTypeMod = -1;
Oid columnCollationOid = InvalidOid;
Index columnLevelSup = 0;
Var *partitionColumn = makeVar(tableId, columnAttributeNumber, columnType,
columnTypeMod, columnCollationOid, columnLevelSup);
partitionNode->partitionColumn = partitionColumn;
}
else
{
ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
errmsg("cannot run this subquery"),
errdetail("Currently only columns and function expressions "
"are allowed in group by expression of subqueries")));
}
SetChild((MultiUnaryNode *) subqueryNode, (MultiNode *) masterExtendedOpNode);
SetChild((MultiUnaryNode *) masterExtendedOpNode, (MultiNode *) partitionNode);
SetChild((MultiUnaryNode *) partitionNode, (MultiNode *) collectNode);
SetChild((MultiUnaryNode *) collectNode, (MultiNode *) workerExtendedOpNode);
SetChild((MultiUnaryNode *) workerExtendedOpNode, (MultiNode *) collectChildNode);
}
/*
* MasterExtendedOpNode creates the master extended operator node from the given
* target entries. The function walks over these target entries; and for entries
* with aggregates in them, this function calls the aggregate expression mutator
* function.
*
* Note that the function logically depends on the worker extended operator node
* function. If the target entry does not contain aggregate functions, we assume
* all work is done on the worker side, and create a column that references the
* worker nodes' results.
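*
* As a concrete sketch: for avg(price) (an illustrative column), the worker
* extended operator node emits sum(price) and count(price), and this
* function rewrites the master-side expression as
* (sum(sum_column) / sum(count_column)) over the columns gathered from the
* workers, converted back to avg's original return type.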
*/
static MultiExtendedOp *
MasterExtendedOpNode(MultiExtendedOp *originalOpNode,
ExtendedOpNodeProperties *extendedOpNodeProperties)
{
List *targetEntryList = originalOpNode->targetList;
List *newTargetEntryList = NIL;
List *newGroupClauseList = NIL;
Node *originalHavingQual = originalOpNode->havingQual;
Node *newHavingQual = NULL;
MasterAggregateWalkerContext walkerContext = {
.extendedOpNodeProperties = extendedOpNodeProperties,
.columnId = 1,
};
/* iterate over original target entries */
TargetEntry *originalTargetEntry = NULL;
foreach_declared_ptr(originalTargetEntry, targetEntryList)
{
TargetEntry *newTargetEntry = flatCopyTargetEntry(originalTargetEntry);
Expr *originalExpression = originalTargetEntry->expr;
Expr *newExpression = NULL;
if (CanPushDownExpression((Node *) originalExpression, extendedOpNodeProperties))
{
/*
* The expression was entirely pushed down to worker.
* We simply make it reference the output generated by worker nodes.
*/
Var *column = makeVarFromTargetEntry(masterTableId, originalTargetEntry);
column->varattno = walkerContext.columnId;
column->varattnosyn = walkerContext.columnId;
walkerContext.columnId++;
if (column->vartype == RECORDOID || column->vartype == RECORDARRAYOID)
{
column->vartypmod = BlessRecordExpression(originalTargetEntry->expr);
}
newExpression = (Expr *) column;
}
else
{
Node *newNode = MasterAggregateMutator((Node *) originalExpression,
&walkerContext);
newExpression = (Expr *) newNode;
}
newTargetEntry->expr = newExpression;
newTargetEntryList = lappend(newTargetEntryList, newTargetEntry);
}
if (!extendedOpNodeProperties->pushDownGroupingAndHaving)
{
/*
* Not pushing down GROUP BY, need to regroup on coordinator
* and apply having on the coordinator.
*/
newGroupClauseList = originalOpNode->groupClauseList;
if (originalHavingQual != NULL)
{
newHavingQual = MasterAggregateMutator(originalHavingQual, &walkerContext);
if (IsA(newHavingQual, List))
{
/*
* unflatten having qual to allow standard planner to work when transforming
* the master query to a plan
*/
newHavingQual = (Node *) make_ands_explicit(
castNode(List, newHavingQual));
}
}
}
MultiExtendedOp *masterExtendedOpNode = CitusMakeNode(MultiExtendedOp);
masterExtendedOpNode->targetList = newTargetEntryList;
masterExtendedOpNode->groupClauseList = newGroupClauseList;
masterExtendedOpNode->sortClauseList = originalOpNode->sortClauseList;
masterExtendedOpNode->distinctClause = originalOpNode->distinctClause;
masterExtendedOpNode->hasDistinctOn = originalOpNode->hasDistinctOn;
masterExtendedOpNode->limitCount = originalOpNode->limitCount;
masterExtendedOpNode->limitOffset = originalOpNode->limitOffset;
masterExtendedOpNode->limitOption = originalOpNode->limitOption;
masterExtendedOpNode->havingQual = newHavingQual;
if (!extendedOpNodeProperties->onlyPushableWindowFunctions)
{
masterExtendedOpNode->hasWindowFuncs = originalOpNode->hasWindowFuncs;
masterExtendedOpNode->windowClause = originalOpNode->windowClause;
masterExtendedOpNode->onlyPushableWindowFunctions = false;
}
return masterExtendedOpNode;
}
/*
* MasterAggregateMutator walks over the original target entry expression, and
* creates the new expression tree to execute on the master node. The function
* transforms aggregates, and copies columns; and recurses into the expression
* mutator function for all other expression types.
*
* Please note that the recursive mutator function traverses the expression tree
* in depth first order. For this function to set attribute numbers correctly,
* WorkerAggregateWalker() *must* walk over the expression tree in the same
* depth first order.
*/
static Node *
MasterAggregateMutator(Node *originalNode, MasterAggregateWalkerContext *walkerContext)
{
Node *newNode = NULL;
if (originalNode == NULL)
{
return NULL;
}
if (IsA(originalNode, Aggref))
{
Aggref *originalAggregate = (Aggref *) originalNode;
if (CanPushDownExpression(originalNode,
walkerContext->extendedOpNodeProperties))
{
/*
* The expression was entirely pushed down to worker.
* We simply make it reference the output generated by worker nodes.
*/
Var *column = makeVar(masterTableId, walkerContext->columnId,
originalAggregate->aggtype,
-1, originalAggregate->aggcollid, 0);
walkerContext->columnId++;
if (column->vartype == RECORDOID || column->vartype == RECORDARRAYOID)
{
column->vartypmod = BlessRecordExpression((Expr *) originalNode);
}
newNode = (Node *) column;
}
else
{
Expr *newExpression = MasterAggregateExpression(originalAggregate,
walkerContext);
newNode = (Node *) newExpression;
}
}
else if (IsA(originalNode, Var))
{
Var *origColumn = (Var *) originalNode;
Var *newColumn = makeVar(masterTableId, walkerContext->columnId,
origColumn->vartype, origColumn->vartypmod,
origColumn->varcollid, origColumn->varlevelsup);
walkerContext->columnId++;
newNode = (Node *) newColumn;
}
else
{
newNode = expression_tree_mutator(originalNode, MasterAggregateMutator,
(void *) walkerContext);
}
return newNode;
}
/*
* MasterAggregateExpression creates the master aggregate expression using the
* original aggregate and aggregate's type information. This function handles
* the average, count, array_agg, hll and topn aggregates separately due to
* differences in these aggregate functions' transformations.
*
* Note that this function has implicit knowledge of the transformations applied
* for worker nodes on the original aggregate. The function uses this implicit
* knowledge to create the appropriate master function with correct data types.
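*
* For example (see the count branch below): a worker-side count(*) arrives
* as a bigint column; the master wraps that column in sum(), and since
* sum(bigint) returns numeric, the result is coerced back to bigint with an
* implicit CoerceViaIO cast.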
*/
static Expr *
MasterAggregateExpression(Aggref *originalAggregate,
MasterAggregateWalkerContext *walkerContext)
{
const Index columnLevelsUp = 0; /* normal column */
const AttrNumber argumentId = 1; /* our aggregates have single arguments */
AggregateType aggregateType = GetAggregateType(originalAggregate);
Expr *newMasterExpression = NULL;
if (walkerContext->extendedOpNodeProperties->pullUpIntermediateRows)
{
Aggref *aggregate = (Aggref *) copyObject(originalAggregate);
TargetEntry *targetEntry;
foreach_declared_ptr(targetEntry, aggregate->args)
{
targetEntry->expr = (Expr *)
makeVar(masterTableId, walkerContext->columnId,
exprType((Node *) targetEntry->expr),
exprTypmod((Node *) targetEntry->expr),
exprCollation((Node *) targetEntry->expr),
columnLevelsUp);
walkerContext->columnId++;
}
aggregate->aggdirectargs = NIL;
Expr *directarg;
foreach_declared_ptr(directarg, originalAggregate->aggdirectargs)
{
/*
* Need to replace nodes that contain any Vars with Vars referring
* to the related column of the result set returned for the worker
* aggregation.
*
* When there are no Vars, then the expression can be fully evaluated
* on the coordinator, so we skip it here. This is not just an
* optimization, but the result of the expression might require
* calling the final function of the aggregate, and doing so when
* there are no input rows (i.e.: with an empty tuple slot) is not
* desirable for the node-executor methods.
*/
if (pull_var_clause_default((Node *) directarg) != NIL)
{
Var *var = makeVar(masterTableId, walkerContext->columnId,
exprType((Node *) directarg),
exprTypmod((Node *) directarg),
exprCollation((Node *) directarg),
columnLevelsUp);
aggregate->aggdirectargs = lappend(aggregate->aggdirectargs, var);
walkerContext->columnId++;
}
else
{
aggregate->aggdirectargs = lappend(aggregate->aggdirectargs, directarg);
}
}
if (aggregate->aggfilter)
{
aggregate->aggfilter = (Expr *)
makeVar(masterTableId, walkerContext->columnId,
BOOLOID, -1, InvalidOid, columnLevelsUp);
walkerContext->columnId++;
}
newMasterExpression = (Expr *) aggregate;
}
else if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct &&
CountDistinctErrorRate == DISABLE_DISTINCT_APPROXIMATION &&
walkerContext->extendedOpNodeProperties->pullDistinctColumns)
{
Aggref *aggregate = (Aggref *) copyObject(originalAggregate);
List *varList = pull_var_clause_default((Node *) aggregate);
List *uniqueVarList = NIL;
int startColumnCount = walkerContext->columnId;
/* determine unique vars that were placed in target list by worker */
Var *column = NULL;
foreach_declared_ptr(column, varList)
{
uniqueVarList = list_append_unique(uniqueVarList, copyObject(column));
}
/*
* Go over each var inside aggregate and update their varattno's according to
* worker query target entry column index.
*/
Var *columnToUpdate = NULL;
foreach_declared_ptr(columnToUpdate, varList)
{
int columnIndex = 0;
Var *currentVar = NULL;
foreach_declared_ptr(currentVar, uniqueVarList)
{
if (equal(columnToUpdate, currentVar))
{
break;
}
columnIndex++;
}
columnToUpdate->varno = masterTableId;
columnToUpdate->varnosyn = masterTableId;
columnToUpdate->varattno = startColumnCount + columnIndex;
columnToUpdate->varattnosyn = startColumnCount + columnIndex;
}
/* we added that many columns */
walkerContext->columnId += list_length(uniqueVarList);
newMasterExpression = (Expr *) aggregate;
}
else if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct &&
CountDistinctErrorRate != DISABLE_DISTINCT_APPROXIMATION)
{
/*
* If enabled, we use the count(distinct) approximation rather than an exact
* count(distinct). For this, we first compute hll_add_agg(hll_hash(column)) on
* worker nodes, and get hll values. We then gather hlls on the master
* node, and compute hll_cardinality(hll_union_agg(hll)).
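*
* In SQL terms (illustrative): SELECT count(DISTINCT a) FROM t runs as
* SELECT hll_add_agg(hll_hash_<type>(a)) on each shard, and the master
* evaluates SELECT hll_cardinality(hll_union_agg(hll_column)) over the
* gathered hll values.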
*/
const int argCount = 1;
const int defaultTypeMod = -1;
/* extract schema name of hll */
Oid hllId = get_extension_oid(HLL_EXTENSION_NAME, false);
Oid hllSchemaOid = get_extension_schema(hllId);
const char *hllSchemaName = get_namespace_name(hllSchemaOid);
Oid unionFunctionId = FunctionOid(hllSchemaName, HLL_UNION_AGGREGATE_NAME,
argCount);
Oid cardinalityFunctionId = FunctionOid(hllSchemaName, HLL_CARDINALITY_FUNC_NAME,
argCount);
Oid cardinalityReturnType = get_func_rettype(cardinalityFunctionId);
Oid hllType = TypeOid(hllSchemaOid, HLL_TYPE_NAME);
Oid hllTypeCollationId = get_typcollation(hllType);
Var *hllColumn = makeVar(masterTableId, walkerContext->columnId, hllType,
defaultTypeMod,
hllTypeCollationId, columnLevelsUp);
walkerContext->columnId++;
TargetEntry *hllTargetEntry = makeTargetEntry((Expr *) hllColumn, argumentId,
NULL, false);
Aggref *unionAggregate = makeNode(Aggref);
unionAggregate->aggfnoid = unionFunctionId;
unionAggregate->aggtype = hllType;
unionAggregate->args = list_make1(hllTargetEntry);
unionAggregate->aggkind = AGGKIND_NORMAL;
unionAggregate->aggfilter = NULL;
unionAggregate->aggtranstype = InvalidOid;
unionAggregate->aggargtypes = list_make1_oid(unionAggregate->aggtype);
unionAggregate->aggsplit = AGGSPLIT_SIMPLE;
FuncExpr *cardinalityExpression = makeNode(FuncExpr);
cardinalityExpression->funcid = cardinalityFunctionId;
cardinalityExpression->funcresulttype = cardinalityReturnType;
cardinalityExpression->args = list_make1(unionAggregate);
newMasterExpression = (Expr *) cardinalityExpression;
}
else if (aggregateType == AGGREGATE_AVERAGE)
{
/*
* If the original aggregate is an average, we first compute sum(column)
* and count(column) on worker nodes. Then, we compute (sum(sum(column))
* / sum(count(column))) on the master node.
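*
* For example (column names illustrative): avg(price) becomes
* sum(price) and count(price) on the workers, and
* sum(sum_column) / sum(count_column) on the master.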
*/
const char *sumAggregateName = AggregateNames[AGGREGATE_SUM];
const char *countAggregateName = AggregateNames[AGGREGATE_COUNT];
Oid argumentType = AggregateArgumentType(originalAggregate);
Oid sumFunctionId = AggregateFunctionOid(sumAggregateName, argumentType);
Oid countFunctionId = AggregateFunctionOid(countAggregateName, ANYOID);
/* calculate the aggregate types that worker nodes are going to return */
Oid workerSumReturnType = get_func_rettype(sumFunctionId);
Oid workerCountReturnType = get_func_rettype(countFunctionId);
/* create the expression sum(sum(column)) / sum(count(column)) */
newMasterExpression = MasterAverageExpression(workerSumReturnType,
workerCountReturnType,
&(walkerContext->columnId));
}
else if (aggregateType == AGGREGATE_COUNT)
{
/*
* Count aggregates are handled in two steps. First, worker nodes report
* their count results. Then, the master node sums up these results.
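*
* For example (column name illustrative): count(*) runs as-is on each
* worker, and the master computes
* coalesce((sum(count_column))::bigint, 0) over the gathered counts.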
*/
/* worker aggregate and original aggregate have the same return type */
Oid workerReturnType = exprType((Node *) originalAggregate);
int32 workerReturnTypeMod = exprTypmod((Node *) originalAggregate);
Oid workerCollationId = exprCollation((Node *) originalAggregate);
const char *sumAggregateName = AggregateNames[AGGREGATE_SUM];
Oid sumFunctionId = AggregateFunctionOid(sumAggregateName, workerReturnType);
Oid masterReturnType = get_func_rettype(sumFunctionId);
Aggref *newMasterAggregate = copyObject(originalAggregate);
newMasterAggregate->aggstar = false;
newMasterAggregate->aggdistinct = NULL;
newMasterAggregate->aggfnoid = sumFunctionId;
newMasterAggregate->aggtype = masterReturnType;
newMasterAggregate->aggfilter = NULL;
newMasterAggregate->aggtranstype = InvalidOid;
newMasterAggregate->aggargtypes = list_make1_oid(newMasterAggregate->aggtype);
newMasterAggregate->aggsplit = AGGSPLIT_SIMPLE;
Var *column = makeVar(masterTableId, walkerContext->columnId, workerReturnType,
workerReturnTypeMod, workerCollationId, columnLevelsUp);
walkerContext->columnId++;
/* aggref expects its arguments to be wrapped in target entries */
TargetEntry *columnTargetEntry = makeTargetEntry((Expr *) column, argumentId,
NULL, false);
newMasterAggregate->args = list_make1(columnTargetEntry);
/* cast numeric sum result to bigint (count's return type) */
CoerceViaIO *coerceExpr = makeNode(CoerceViaIO);
coerceExpr->arg = (Expr *) newMasterAggregate;
coerceExpr->resulttype = INT8OID;
coerceExpr->resultcollid = InvalidOid;
coerceExpr->coerceformat = COERCE_IMPLICIT_CAST;
coerceExpr->location = -1;
/* convert NULL to 0 in case of no rows */
Const *zeroConst = MakeIntegerConstInt64(0);
List *coalesceArgs = list_make2(coerceExpr, zeroConst);
CoalesceExpr *coalesceExpr = makeNode(CoalesceExpr);
coalesceExpr->coalescetype = INT8OID;
coalesceExpr->coalescecollid = InvalidOid;
coalesceExpr->args = coalesceArgs;
coalesceExpr->location = -1;
newMasterExpression = (Expr *) coalesceExpr;
}
else if (aggregateType == AGGREGATE_ARRAY_AGG ||
aggregateType == AGGREGATE_JSONB_AGG ||
aggregateType == AGGREGATE_JSONB_OBJECT_AGG ||
aggregateType == AGGREGATE_JSON_AGG ||
aggregateType == AGGREGATE_JSON_OBJECT_AGG)
{
/*
* Array and json aggregates are handled in two steps. First, we compute
* array_agg() or json aggregate on the worker nodes. Then, we gather
* the arrays or jsons on the master and compute the array_cat_agg()
* or jsonb_cat_agg() aggregate on them to get the final array or json.
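*
* For example (sketch): array_agg(val) runs as-is on the workers, and
* the master concatenates the per-worker arrays with
* array_cat_agg(worker_column).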
*/
const char *catAggregateName = NULL;
Oid catInputType = InvalidOid;
/* worker aggregate and original aggregate have same return type */
Oid workerReturnType = exprType((Node *) originalAggregate);
int32 workerReturnTypeMod = exprTypmod((Node *) originalAggregate);
Oid workerCollationId = exprCollation((Node *) originalAggregate);
/* assert that we do not support array or json aggregation with
* distinct or order by */
Assert(!originalAggregate->aggorder);
Assert(!originalAggregate->aggdistinct);
if (aggregateType == AGGREGATE_ARRAY_AGG)
{
/* array_cat_agg() takes anycompatiblearray as input */
catAggregateName = ARRAY_CAT_AGGREGATE_NAME;
catInputType = ANYCOMPATIBLEARRAYOID;
}
else if (aggregateType == AGGREGATE_JSONB_AGG ||
aggregateType == AGGREGATE_JSONB_OBJECT_AGG)
{
/* jsonb_cat_agg() takes jsonb as input */
catAggregateName = JSONB_CAT_AGGREGATE_NAME;
catInputType = JSONBOID;
}
else
{
/* json_cat_agg() takes json as input */
catAggregateName = JSON_CAT_AGGREGATE_NAME;
catInputType = JSONOID;
}
Assert(catAggregateName != NULL);
Assert(catInputType != InvalidOid);
Oid aggregateFunctionId = AggregateFunctionOid(catAggregateName,
catInputType);
/* create argument for the array_cat_agg() or jsonb_cat_agg() aggregate */
Var *column = makeVar(masterTableId, walkerContext->columnId, workerReturnType,
workerReturnTypeMod, workerCollationId, columnLevelsUp);
TargetEntry *catAggArgument = makeTargetEntry((Expr *) column, argumentId, NULL,
false);
walkerContext->columnId++;
/* construct the master array_cat_agg() or jsonb_cat_agg() expression */
Aggref *newMasterAggregate = copyObject(originalAggregate);
newMasterAggregate->aggfnoid = aggregateFunctionId;
newMasterAggregate->args = list_make1(catAggArgument);
newMasterAggregate->aggfilter = NULL;
newMasterAggregate->aggtranstype = InvalidOid;
if (aggregateType == AGGREGATE_ARRAY_AGG)
{
/*
* Postgres expects the type of the array here such as INT4ARRAYOID.
* Hence we set it to workerReturnType. If we set this to
* ANYCOMPATIBLEARRAYOID then we will get the following error:
* "argument declared anycompatiblearray is not an array but type anycompatiblearray"
*/
newMasterAggregate->aggargtypes = list_make1_oid(workerReturnType);
}
else
{
newMasterAggregate->aggargtypes = list_make1_oid(ANYARRAYOID);
}
newMasterAggregate->aggsplit = AGGSPLIT_SIMPLE;
newMasterExpression = (Expr *) newMasterAggregate;
}
else if (aggregateType == AGGREGATE_HLL_ADD ||
aggregateType == AGGREGATE_HLL_UNION)
{
/*
* If hll aggregates are called, we simply create an hll_union_agg()
* aggregate to apply on the master after running the original aggregate
* on the workers.
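*
* For example (sketch): hll_add_agg(hll_hash_integer(x)) runs on the
* workers, and the master merges the returned hll values with
* hll_union_agg(hll_column).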
*/
Oid hllType = exprType((Node *) originalAggregate);
Oid unionFunctionId = AggregateFunctionOid(HLL_UNION_AGGREGATE_NAME, hllType);
int32 hllReturnTypeMod = exprTypmod((Node *) originalAggregate);
Oid hllTypeCollationId = exprCollation((Node *) originalAggregate);
Var *hllColumn = makeVar(masterTableId, walkerContext->columnId, hllType,
hllReturnTypeMod, hllTypeCollationId, columnLevelsUp);
walkerContext->columnId++;
TargetEntry *hllTargetEntry = makeTargetEntry((Expr *) hllColumn, argumentId,
NULL, false);
Aggref *unionAggregate = makeNode(Aggref);
unionAggregate->aggfnoid = unionFunctionId;
unionAggregate->aggtype = hllType;
unionAggregate->args = list_make1(hllTargetEntry);
unionAggregate->aggkind = AGGKIND_NORMAL;
unionAggregate->aggfilter = NULL;
unionAggregate->aggtranstype = InvalidOid;
unionAggregate->aggargtypes = list_make1_oid(hllType);
unionAggregate->aggsplit = AGGSPLIT_SIMPLE;
newMasterExpression = (Expr *) unionAggregate;
}
else if (aggregateType == AGGREGATE_TOPN_UNION_AGG ||
aggregateType == AGGREGATE_TOPN_ADD_AGG)
{
/*
* Top-N aggregates are handled in two steps. First, we compute
* topn_add_agg() or topn_union_agg() aggregates on the worker nodes.
* Then, we gather the Top-Ns on the master and take the union of all
* to get the final topn.
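*
* For example (sketch): topn_add_agg(category) runs on the workers, and
* the master merges the per-worker results with
* topn_union_agg(topn_column).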
*/
/* worker aggregate and original aggregate have same return type */
Oid topnType = exprType((Node *) originalAggregate);
Oid unionFunctionId = AggregateFunctionOid(TOPN_UNION_AGGREGATE_NAME,
topnType);
int32 topnReturnTypeMod = exprTypmod((Node *) originalAggregate);
Oid topnTypeCollationId = exprCollation((Node *) originalAggregate);
/* create argument for the topn_union_agg() aggregate */
Var *topnColumn = makeVar(masterTableId, walkerContext->columnId, topnType,
topnReturnTypeMod, topnTypeCollationId, columnLevelsUp);
walkerContext->columnId++;
TargetEntry *topNTargetEntry = makeTargetEntry((Expr *) topnColumn, argumentId,
NULL, false);
/* construct the master topn_union_agg() expression */
Aggref *unionAggregate = makeNode(Aggref);
unionAggregate->aggfnoid = unionFunctionId;
unionAggregate->aggtype = topnType;
unionAggregate->args = list_make1(topNTargetEntry);
unionAggregate->aggkind = AGGKIND_NORMAL;
unionAggregate->aggfilter = NULL;
unionAggregate->aggtranstype = InvalidOid;
unionAggregate->aggargtypes = list_make1_oid(topnType);
unionAggregate->aggsplit = AGGSPLIT_SIMPLE;
newMasterExpression = (Expr *) unionAggregate;
}
else if (aggregateType == AGGREGATE_TDIGEST_COMBINE ||
aggregateType == AGGREGATE_TDIGEST_ADD_DOUBLE)
{
/* tdigest of column */
Oid tdigestType = TDigestExtensionTypeOid(); /* tdigest type */
Oid unionFunctionId = TDigestExtensionAggTDigest1();
int32 tdigestReturnTypeMod = exprTypmod((Node *) originalAggregate);
Oid tdigestTypeCollationId = exprCollation((Node *) originalAggregate);
/* create the argument for the master tdigest(tdigest) aggregate */
Var *tdigestColumn = makeVar(masterTableId, walkerContext->columnId, tdigestType,
tdigestReturnTypeMod, tdigestTypeCollationId,
columnLevelsUp);
TargetEntry *tdigestTargetEntry = makeTargetEntry((Expr *) tdigestColumn,
argumentId,
NULL, false);
walkerContext->columnId++;
/* construct the master tdigest(tdigest) expression */
Aggref *unionAggregate = makeNode(Aggref);
unionAggregate->aggfnoid = unionFunctionId;
unionAggregate->aggtype = originalAggregate->aggtype;
unionAggregate->args = list_make1(tdigestTargetEntry);
unionAggregate->aggkind = AGGKIND_NORMAL;
unionAggregate->aggfilter = NULL;
unionAggregate->aggtranstype = InvalidOid;
unionAggregate->aggargtypes = list_make1_oid(tdigestType);
unionAggregate->aggsplit = AGGSPLIT_SIMPLE;
newMasterExpression = (Expr *) unionAggregate;
}
else if (aggregateType == AGGREGATE_TDIGEST_PERCENTILE_ADD_DOUBLE ||
aggregateType == AGGREGATE_TDIGEST_PERCENTILE_ADD_DOUBLEARRAY ||
aggregateType == AGGREGATE_TDIGEST_PERCENTILE_OF_ADD_DOUBLE ||
aggregateType == AGGREGATE_TDIGEST_PERCENTILE_OF_ADD_DOUBLEARRAY)
{
/* tdigest of column */
Oid tdigestType = TDigestExtensionTypeOid();
Oid unionFunctionId = InvalidOid;
if (aggregateType == AGGREGATE_TDIGEST_PERCENTILE_ADD_DOUBLE)
{
unionFunctionId = TDigestExtensionAggTDigestPercentile2();
}
else if (aggregateType == AGGREGATE_TDIGEST_PERCENTILE_ADD_DOUBLEARRAY)
{
unionFunctionId = TDigestExtensionAggTDigestPercentile2a();
}
else if (aggregateType == AGGREGATE_TDIGEST_PERCENTILE_OF_ADD_DOUBLE)
{
unionFunctionId = TDigestExtensionAggTDigestPercentileOf2();
}
else if (aggregateType == AGGREGATE_TDIGEST_PERCENTILE_OF_ADD_DOUBLEARRAY)
{
unionFunctionId = TDigestExtensionAggTDigestPercentileOf2a();
}
Assert(OidIsValid(unionFunctionId));
int32 tdigestReturnTypeMod = exprTypmod((Node *) originalAggregate);
Oid tdigestTypeCollationId = exprCollation((Node *) originalAggregate);
/* create first argument for tdigest_percentile(tdigest, double) */
Var *tdigestColumn = makeVar(masterTableId, walkerContext->columnId, tdigestType,
tdigestReturnTypeMod, tdigestTypeCollationId,
columnLevelsUp);
TargetEntry *tdigestTargetEntry = makeTargetEntry((Expr *) tdigestColumn,
argumentId, NULL, false);
walkerContext->columnId++;
/* construct the master tdigest_percentile(tdigest, double) expression */
Aggref *unionAggregate = makeNode(Aggref);
unionAggregate->aggfnoid = unionFunctionId;
unionAggregate->aggtype = originalAggregate->aggtype;
unionAggregate->args = list_make2(
tdigestTargetEntry,
list_nth(originalAggregate->args, 2));
unionAggregate->aggkind = AGGKIND_NORMAL;
unionAggregate->aggfilter = NULL;
unionAggregate->aggtranstype = InvalidOid;
unionAggregate->aggargtypes = list_make2_oid(
tdigestType,
list_nth_oid(originalAggregate->aggargtypes, 2));
unionAggregate->aggsplit = AGGSPLIT_SIMPLE;
newMasterExpression = (Expr *) unionAggregate;
}
else if (aggregateType == AGGREGATE_TDIGEST_PERCENTILE_TDIGEST_DOUBLE ||
aggregateType == AGGREGATE_TDIGEST_PERCENTILE_TDIGEST_DOUBLEARRAY ||
aggregateType == AGGREGATE_TDIGEST_PERCENTILE_OF_TDIGEST_DOUBLE ||
aggregateType == AGGREGATE_TDIGEST_PERCENTILE_OF_TDIGEST_DOUBLEARRAY)
{
/* tdigest of column */
Oid tdigestType = TDigestExtensionTypeOid();
/* these functions will already combine the tdigest values returned by the workers */
Oid unionFunctionId = originalAggregate->aggfnoid;
int32 tdigestReturnTypeMod = exprTypmod((Node *) originalAggregate);
Oid tdigestTypeCollationId = exprCollation((Node *) originalAggregate);
/* create first argument for tdigest_percentile(tdigest, double) */
Var *tdigestColumn = makeVar(masterTableId, walkerContext->columnId, tdigestType,
tdigestReturnTypeMod, tdigestTypeCollationId,
columnLevelsUp);
TargetEntry *tdigestTargetEntry = makeTargetEntry((Expr *) tdigestColumn,
argumentId, NULL, false);
walkerContext->columnId++;
/* construct the master tdigest_percentile(tdigest, double) expression */
Aggref *unionAggregate = makeNode(Aggref);
unionAggregate->aggfnoid = unionFunctionId;
unionAggregate->aggtype = originalAggregate->aggtype;
unionAggregate->args = list_make2(
tdigestTargetEntry,
list_nth(originalAggregate->args, 1));
unionAggregate->aggkind = AGGKIND_NORMAL;
unionAggregate->aggfilter = NULL;
unionAggregate->aggtranstype = InvalidOid;
unionAggregate->aggargtypes = list_make2_oid(
tdigestType,
list_nth_oid(originalAggregate->aggargtypes, 1));
unionAggregate->aggsplit = AGGSPLIT_SIMPLE;
newMasterExpression = (Expr *) unionAggregate;
}
else if (aggregateType == AGGREGATE_CUSTOM_COMBINE)
{
HeapTuple aggTuple =
SearchSysCache1(AGGFNOID, ObjectIdGetDatum(originalAggregate->aggfnoid));
Form_pg_aggregate aggform;
Oid combine;
if (!HeapTupleIsValid(aggTuple))
{
elog(ERROR, "citus cache lookup failed for aggregate %u",
originalAggregate->aggfnoid);
return NULL;
}
else
{
aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple);
combine = aggform->aggcombinefn;
ReleaseSysCache(aggTuple);
}
if (combine != InvalidOid)
{
Oid coordCombineId = CoordCombineAggOid();
Oid workerReturnType = CSTRINGOID;
int32 workerReturnTypeMod = -1;
Oid workerCollationId = InvalidOid;
Oid resultType = exprType((Node *) originalAggregate);
Const *aggOidParam = makeConst(OIDOID, -1, InvalidOid, sizeof(Oid),
ObjectIdGetDatum(originalAggregate->aggfnoid),
false, true);
Var *column = makeVar(masterTableId, walkerContext->columnId,
workerReturnType,
workerReturnTypeMod, workerCollationId, columnLevelsUp);
walkerContext->columnId++;
Const *nullTag = makeNullConst(resultType, -1, InvalidOid);
List *aggArguments =
list_make3(makeTargetEntry((Expr *) aggOidParam, 1, NULL, false),
makeTargetEntry((Expr *) column, 2, NULL, false),
makeTargetEntry((Expr *) nullTag, 3, NULL, false));
/* coord_combine_agg(agg, workercol) */
Aggref *newMasterAggregate = makeNode(Aggref);
newMasterAggregate->aggfnoid = coordCombineId;
newMasterAggregate->aggtype = originalAggregate->aggtype;
newMasterAggregate->args = aggArguments;
newMasterAggregate->aggkind = AGGKIND_NORMAL;
newMasterAggregate->aggfilter = NULL;
newMasterAggregate->aggtranstype = INTERNALOID;
newMasterAggregate->aggargtypes = list_make3_oid(OIDOID, CSTRINGOID,
resultType);
newMasterAggregate->aggsplit = AGGSPLIT_SIMPLE;
newMasterExpression = (Expr *) newMasterAggregate;
}
else
{
elog(ERROR, "Aggregate lacks COMBINEFUNC");
}
}
else
{
/*
* All other aggregates are handled as they are. These include sum, min,
* and max.
*/
/* worker aggregate and original aggregate have the same return type */
Oid workerReturnType = exprType((Node *) originalAggregate);
int32 workerReturnTypeMod = exprTypmod((Node *) originalAggregate);
Oid workerCollationId = exprCollation((Node *) originalAggregate);
const char *aggregateName = AggregateNames[aggregateType];
Oid aggregateFunctionId = AggregateFunctionOid(aggregateName, workerReturnType);
Oid masterReturnType = get_func_rettype(aggregateFunctionId);
Aggref *newMasterAggregate = copyObject(originalAggregate);
newMasterAggregate->aggdistinct = NULL;
newMasterAggregate->aggfnoid = aggregateFunctionId;
newMasterAggregate->aggtype = masterReturnType;
newMasterAggregate->aggfilter = NULL;
/*
* If the aggregate's return type is anyelement, its actual return type is
* determined by the type of its argument. So we replace it with the
* argument type in that case.
*/
if (masterReturnType == ANYELEMENTOID)
{
newMasterAggregate->aggtype = workerReturnType;
Expr *firstArg = FirstAggregateArgument(originalAggregate);
newMasterAggregate->aggcollid = exprCollation((Node *) firstArg);
}
Var *column = makeVar(masterTableId, walkerContext->columnId, workerReturnType,
workerReturnTypeMod, workerCollationId, columnLevelsUp);
walkerContext->columnId++;
/* aggref expects its arguments to be wrapped in target entries */
TargetEntry *columnTargetEntry = makeTargetEntry((Expr *) column, argumentId,
NULL, false);
newMasterAggregate->args = list_make1(columnTargetEntry);
newMasterExpression = (Expr *) newMasterAggregate;
}
/*
* Aggregate functions could have changed the return type. If so, we wrap
* the new expression with a conversion function to make it have the same
* type as the original aggregate. We need this since functions like sorting
* and grouping have already been chosen based on the original type.
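*
* For example (sketch): for sum(int4_column) the workers return bigint,
* the master-side sum(bigint) returns numeric, so the result is cast
* back to bigint to match the original aggregate's type.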
*/
Expr *typeConvertedExpression = AddTypeConversion((Node *) originalAggregate,
(Node *) newMasterExpression);
if (typeConvertedExpression != NULL)
{
newMasterExpression = typeConvertedExpression;
}
return newMasterExpression;
}
/*
* MasterAverageExpression creates an expression of the form (sum(column1) /
* sum(column2)), where column1 is the sum of the original value, and column2 is
* the count of that value. This expression allows us to evaluate the average
* function over distributed data.
*/
static Expr *
MasterAverageExpression(Oid sumAggregateType, Oid countAggregateType,
AttrNumber *columnId)
{
const char *sumAggregateName = AggregateNames[AGGREGATE_SUM];
const int32 defaultTypeMod = -1;
const Index defaultLevelsUp = 0;
const AttrNumber argumentId = 1;
Oid sumTypeCollationId = get_typcollation(sumAggregateType);
Oid countTypeCollationId = get_typcollation(countAggregateType);
/* create the first argument for sum(column1) */
Var *firstColumn = makeVar(masterTableId, (*columnId), sumAggregateType,
defaultTypeMod, sumTypeCollationId, defaultLevelsUp);
TargetEntry *firstTargetEntry = makeTargetEntry((Expr *) firstColumn, argumentId,
NULL, false);
(*columnId)++;
Aggref *firstSum = makeNode(Aggref);
firstSum->aggfnoid = AggregateFunctionOid(sumAggregateName, sumAggregateType);
firstSum->aggtype = get_func_rettype(firstSum->aggfnoid);
firstSum->args = list_make1(firstTargetEntry);
firstSum->aggkind = AGGKIND_NORMAL;
firstSum->aggtranstype = InvalidOid;
firstSum->aggargtypes = list_make1_oid(firstSum->aggtype);
firstSum->aggsplit = AGGSPLIT_SIMPLE;
/* create the second argument for sum(column2) */
Var *secondColumn = makeVar(masterTableId, (*columnId), countAggregateType,
defaultTypeMod, countTypeCollationId, defaultLevelsUp);
TargetEntry *secondTargetEntry = makeTargetEntry((Expr *) secondColumn, argumentId,
NULL, false);
(*columnId)++;
Aggref *secondSum = makeNode(Aggref);
secondSum->aggfnoid = AggregateFunctionOid(sumAggregateName, countAggregateType);
secondSum->aggtype = get_func_rettype(secondSum->aggfnoid);
secondSum->args = list_make1(secondTargetEntry);
secondSum->aggkind = AGGKIND_NORMAL;
secondSum->aggtranstype = InvalidOid;
secondSum->aggargtypes = list_make1_oid(firstSum->aggtype);
secondSum->aggsplit = AGGSPLIT_SIMPLE;
/*
* Build the division operator between these two aggregates. This function
* will convert the types of the aggregates if necessary.
*/
List *operatorNameList = list_make1(makeString(DIVISION_OPER_NAME));
Expr *opExpr = make_op(NULL, operatorNameList, (Node *) firstSum, (Node *) secondSum,
NULL,
-1);
return opExpr;
}
/*
* AddTypeConversion checks if the given expressions generate the same types. If
* they don't, the function adds a type conversion function on top of the new
* expression to have it generate the same type as the original aggregate.
*/
static Expr *
AddTypeConversion(Node *originalAggregate, Node *newExpression)
{
Oid newTypeId = exprType(newExpression);
Oid originalTypeId = exprType(originalAggregate);
int32 originalTypeMod = exprTypmod(originalAggregate);
/* nothing to do if the two types are the same */
if (originalTypeId == newTypeId)
{
return NULL;
}
/* otherwise, add a type conversion function */
Node *typeConvertedExpression = coerce_to_target_type(NULL, newExpression, newTypeId,
originalTypeId, originalTypeMod,
COERCION_EXPLICIT,
COERCE_EXPLICIT_CAST, -1);
Assert(typeConvertedExpression != NULL);
return (Expr *) typeConvertedExpression;
}
/*
* WorkerExtendedOpNode creates the worker extended operator node from the given
* originalOpNode and extendedOpNodeProperties.
*
* For the details of the processing see the comments of the functions that
* are called from this function.
*/
static MultiExtendedOp *
WorkerExtendedOpNode(MultiExtendedOp *originalOpNode,
ExtendedOpNodeProperties *extendedOpNodeProperties)
{
bool distinctPreventsLimitPushdown = false;
QueryTargetList queryTargetList;
QueryGroupClause queryGroupClause;
QueryDistinctClause queryDistinctClause;
QueryWindowClause queryWindowClause;
QueryOrderByLimit queryOrderByLimit;
Node *queryHavingQual = NULL;
List *originalTargetEntryList = originalOpNode->targetList;
List *originalGroupClauseList = originalOpNode->groupClauseList;
List *originalSortClauseList = originalOpNode->sortClauseList;
Node *originalHavingQual = originalOpNode->havingQual;
Node *originalLimitCount = originalOpNode->limitCount;
Node *originalLimitOffset = originalOpNode->limitOffset;
List *originalWindowClause = originalOpNode->windowClause;
List *originalDistinctClause = originalOpNode->distinctClause;
bool hasDistinctOn = originalOpNode->hasDistinctOn;
int originalGroupClauseLength = list_length(originalGroupClauseList);
/* initialize to default values */
memset(&queryTargetList, 0, sizeof(queryTargetList));
memset(&queryGroupClause, 0, sizeof(queryGroupClause));
memset(&queryDistinctClause, 0, sizeof(queryDistinctClause));
memset(&queryWindowClause, 0, sizeof(queryWindowClause));
memset(&queryOrderByLimit, 0, sizeof(queryOrderByLimit));
/* calculate the next sort group index based on the original target list */
Index nextSortGroupRefIndex = GetNextSortGroupRef(originalTargetEntryList);
/* targetProjectionNumber starts from 1 */
queryTargetList.targetProjectionNumber = 1;
if (!extendedOpNodeProperties->pullUpIntermediateRows)
{
queryGroupClause.groupClauseList = copyObject(originalGroupClauseList);
}
else
{
queryGroupClause.groupClauseList = NIL;
}
/*
* For the purposes of this variable, having no groups to push down counts
* as pushing down the original grouping, i.e. the worker's GROUP BY
* matches the master's GROUP BY.
*/
bool pushingDownOriginalGrouping =
list_length(queryGroupClause.groupClauseList) == originalGroupClauseLength;
/*
* nextSortGroupRefIndex is used by group by, window and order by clauses.
* Thus, we pass a reference to a single nextSortGroupRefIndex and expect
* it modified separately while processing those parts of the query.
*/
queryGroupClause.nextSortGroupRefIndex = &nextSortGroupRefIndex;
queryWindowClause.nextSortGroupRefIndex = &nextSortGroupRefIndex;
queryOrderByLimit.nextSortGroupRefIndex = &nextSortGroupRefIndex;
/* process each part of the query in order to generate the worker query's parts */
ProcessTargetListForWorkerQuery(originalTargetEntryList, extendedOpNodeProperties,
&queryTargetList, &queryGroupClause);
ProcessHavingClauseForWorkerQuery(originalHavingQual, extendedOpNodeProperties,
&queryHavingQual, &queryTargetList,
&queryGroupClause);
/*
* Planner optimizations may leave window clauses with hasWindowFuncs as false.
* Ignore window clauses in that case.
*/
if (extendedOpNodeProperties->hasWindowFuncs)
{
if (extendedOpNodeProperties->onlyPushableWindowFunctions)
{
ProcessWindowFunctionsForWorkerQuery(originalWindowClause,
originalTargetEntryList,
&queryWindowClause, &queryTargetList);
}
else
{
ProcessWindowFunctionPullUpForWorkerQuery(originalWindowClause,
&queryTargetList);
}
}
if (ShouldProcessDistinctOrderAndLimitForWorker(extendedOpNodeProperties,
pushingDownOriginalGrouping,
originalHavingQual))
{
bool queryHasAggregates = TargetListHasAggregates(originalTargetEntryList);
ProcessDistinctClauseForWorkerQuery(originalDistinctClause, hasDistinctOn,
queryGroupClause.groupClauseList,
queryHasAggregates, &queryDistinctClause,
&distinctPreventsLimitPushdown);
/*
* Order by and limit clauses are related to each other, so it is
* convenient to process them together.
*
* The other parts of the query might have already prohibited pushing down
* LIMIT and ORDER BY clauses as described below:
* (1) Creating a new group by clause during aggregate mutation, or
* (2) Distinct clause is not pushed down
*/
bool groupByExtended =
list_length(queryGroupClause.groupClauseList) > originalGroupClauseLength;
if (pushingDownOriginalGrouping && !groupByExtended &&
!distinctPreventsLimitPushdown)
{
/* both sort and limit clauses rely on similar information */
OrderByLimitReference limitOrderByReference =
BuildOrderByLimitReference(hasDistinctOn,
extendedOpNodeProperties->
groupedByDisjointPartitionColumn,
extendedOpNodeProperties->
onlyPushableWindowFunctions,
originalGroupClauseList,
originalSortClauseList,
originalTargetEntryList);
ProcessLimitOrderByForWorkerQuery(limitOrderByReference, originalLimitCount,
originalLimitOffset,
originalSortClauseList,
originalGroupClauseList,
originalTargetEntryList,
&queryOrderByLimit,
&queryTargetList);
}
}
/* finally, fill the extended op node with the data we gathered */
MultiExtendedOp *workerExtendedOpNode = CitusMakeNode(MultiExtendedOp);
workerExtendedOpNode->targetList = queryTargetList.targetEntryList;
workerExtendedOpNode->groupClauseList = queryGroupClause.groupClauseList;
workerExtendedOpNode->havingQual = queryHavingQual;
workerExtendedOpNode->hasDistinctOn = queryDistinctClause.workerHasDistinctOn;
workerExtendedOpNode->distinctClause = queryDistinctClause.workerDistinctClause;
workerExtendedOpNode->hasWindowFuncs = queryWindowClause.hasWindowFunctions;
workerExtendedOpNode->windowClause = queryWindowClause.workerWindowClauseList;
workerExtendedOpNode->sortClauseList = queryOrderByLimit.workerSortClauseList;
workerExtendedOpNode->limitCount = queryOrderByLimit.workerLimitCount;
/*
* If the limitCount cannot be pushed down it will be NULL, so the deparser will
* ignore the limitOption.
*/
workerExtendedOpNode->limitOption = originalOpNode->limitOption;
return workerExtendedOpNode;
}
/*
* ProcessTargetListForWorkerQuery gets the inputs and modifies the outputs
* such that the worker query's target list and group by clauses are extended
* for the given inputs.
*
* The function walks over the input targetEntryList. For the entries
* with aggregates in them, it calls the recursive aggregate walker function to
* create aggregates for the worker nodes. For example, the avg() is sent to
* the worker with two expressions count() and sum(). Thus, a single target entry
* might end up with multiple expressions in the worker query.
*
* The function doesn't change the aggregates in the window functions and sends them
* as-is. The reason is that Citus only supports pushing down window functions when
* this is safe to do.
*
* The function also handles the count distinct operator when it is used in
* repartition subqueries or on non-partition columns (i.e., when it cannot
* be pushed down). Each
* column in count distinct aggregate is added to target list, and group by
* list of worker extended operator. This approach guarantees the distinctness
* in the worker queries.
*
* inputs: targetEntryList, extendedOpNodeProperties
* outputs: queryTargetList, queryGroupClause
*/
static void
ProcessTargetListForWorkerQuery(List *targetEntryList,
ExtendedOpNodeProperties *extendedOpNodeProperties,
QueryTargetList *queryTargetList,
QueryGroupClause *queryGroupClause)
{
WorkerAggregateWalkerContext workerAggContext = {
.extendedOpNodeProperties = extendedOpNodeProperties,
};
/* iterate over original target entries */
TargetEntry *originalTargetEntry = NULL;
foreach_declared_ptr(originalTargetEntry, targetEntryList)
{
Expr *originalExpression = originalTargetEntry->expr;
List *newExpressionList = NIL;
/* reset walker context */
workerAggContext.expressionList = NIL;
workerAggContext.createGroupByClause = false;
/*
* If we can push down the expression, we copy it into the worker query's
* target list. Otherwise the expression is processed so that it can be
* combined on the coordinator.
*/
if (CanPushDownExpression((Node *) originalExpression, extendedOpNodeProperties))
{
newExpressionList = list_make1(originalExpression);
}
else
{
WorkerAggregateWalker((Node *) originalExpression, &workerAggContext);
newExpressionList = workerAggContext.expressionList;
}
ExpandWorkerTargetEntry(newExpressionList, originalTargetEntry,
workerAggContext.createGroupByClause,
queryTargetList, queryGroupClause);
}
}
/*
* ProcessHavingClauseForWorkerQuery gets the inputs and modifies the outputs
* such that the worker query's target list and group by clauses are extended
* based on the inputs.
*
* The rule is that Citus applies the HAVING clause on the coordinator,
* and thus pulls the necessary data from the workers. Additionally, when
* the HAVING clause is safe to push down to the workers, workerHavingQual
* is set to the original having clause.
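*
* For example (sketch): with GROUP BY country HAVING sum(amount) > 100,
* where country is not the distribution column, the workers add
* sum(amount) to their target lists and the coordinator evaluates
* sum(sum_column) > 100.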
*
* inputs: originalHavingQual, extendedOpNodeProperties
* outputs: workerHavingQual, queryTargetList, queryGroupClause
*/
static void
ProcessHavingClauseForWorkerQuery(Node *originalHavingQual,
ExtendedOpNodeProperties *extendedOpNodeProperties,
Node **workerHavingQual,
QueryTargetList *queryTargetList,
QueryGroupClause *queryGroupClause)
{
*workerHavingQual = NULL;
if (originalHavingQual == NULL)
{
return;
}
if (extendedOpNodeProperties->pushDownGroupingAndHaving)
{
/*
* We converted the having expression to a list in subquery pushdown
* planner. However, the query cannot be parsed as-is on the worker.
* We should convert this back to an explicit expression for the worker
* query so that it can be parsed when it hits the standard planner there.
*/
if (IsA(originalHavingQual, List))
{
*workerHavingQual =
(Node *) make_ands_explicit((List *) originalHavingQual);
}
else
{
*workerHavingQual = originalHavingQual;
}
}
else
{
/*
* If the GROUP BY or PARTITION BY is not on the distribution column
* then we need to combine the aggregates in the HAVING across shards.
*/
WorkerAggregateWalkerContext workerAggContext = {
.extendedOpNodeProperties = extendedOpNodeProperties,
};
WorkerAggregateWalker(originalHavingQual, &workerAggContext);
List *newExpressionList = workerAggContext.expressionList;
TargetEntry *targetEntry = NULL;
ExpandWorkerTargetEntry(newExpressionList, targetEntry,
workerAggContext.createGroupByClause,
queryTargetList, queryGroupClause);
}
}
/*
* ProcessDistinctClauseForWorkerQuery gets the inputs and modifies the outputs
* such that worker query's DISTINCT and DISTINCT ON clauses are set accordingly.
* Note that the function may or may not decide to push down the DISTINCT
* and DISTINCT ON clauses based on the inputs.
*
* See the detailed comments in the function for the rules of pushing down DISTINCT
* and DISTINCT ON clauses to the worker queries.
*
* The function also sets distinctPreventsLimitPushdown. As the name reveals,
* distinct could prevent pushing down LIMIT clauses later in the planning.
* For the details, see the comments in the function.
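*
* For example (sketch): SELECT DISTINCT a, b ... GROUP BY a may push the
* DISTINCT down, whereas SELECT DISTINCT a ... GROUP BY a, b may not,
* and the latter also prevents pushing down LIMIT.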
*
* inputs: distinctClause, hasDistinctOn, groupClauseList, queryHasAggregates
* outputs: queryDistinctClause, distinctPreventsLimitPushdown
*
*/
static void
ProcessDistinctClauseForWorkerQuery(List *distinctClause, bool hasDistinctOn,
List *groupClauseList,
bool queryHasAggregates,
QueryDistinctClause *queryDistinctClause,
bool *distinctPreventsLimitPushdown)
{
*distinctPreventsLimitPushdown = false;
if (distinctClause == NIL)
{
return;
}
bool distinctClauseSupersetofGroupClause = false;
if (groupClauseList == NIL ||
IsGroupBySubsetOfDistinct(groupClauseList, distinctClause))
{
distinctClauseSupersetofGroupClause = true;
}
else
{
distinctClauseSupersetofGroupClause = false;
/*
* GROUP BY being a subset of DISTINCT guarantees
* distinctness on the workers. Otherwise, pushing down
* LIMIT might cause the worker query to miss some of
* the necessary data.
*/
*distinctPreventsLimitPushdown = true;
}
/*
* Distinct is pushed down to the worker query only if the query does not
* contain an aggregate for which master processing would be required to
* complete the final result before the distinct operation. We also prevent
* distinct pushdown if the distinct clause is missing some entries that
* the group by clause has.
*/
bool shouldPushdownDistinct = !queryHasAggregates &&
distinctClauseSupersetofGroupClause;
if (shouldPushdownDistinct)
{
queryDistinctClause->workerDistinctClause = distinctClause;
queryDistinctClause->workerHasDistinctOn = hasDistinctOn;
}
}
/*
* ProcessWindowFunctionsForWorkerQuery gets the inputs and modifies the outputs such
* that worker query's workerWindowClauseList is set when the window clauses are safe to
* pushdown.
*
* Note that even though Citus only pushes down the window functions, it may need to
* modify the target list of the worker query when the window function refers to
* an avg(). The reason is that any aggregate which is also referred by other
* target entries would be mutated by Citus. Thus, we add a copy of the same aggregate
* to the worker target list to make sure that the window function refers to the
* non-mutated aggregate.
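*
* For example (sketch): given avg(x) OVER w alongside a plain avg(x),
* the plain avg(x) is rewritten into sum(x) and count(x) for the worker,
* while the window function still needs the original avg(x) in the
* worker target list.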
*
* inputs: windowClauseList, originalTargetEntryList
* outputs: queryWindowClause, queryTargetList
*
*/
static void
ProcessWindowFunctionsForWorkerQuery(List *windowClauseList,
List *originalTargetEntryList,
QueryWindowClause *queryWindowClause,
QueryTargetList *queryTargetList)
{
if (windowClauseList == NIL)
{
return;
}
queryWindowClause->workerWindowClauseList = windowClauseList;
queryWindowClause->hasWindowFunctions = true;
}
/* ProcessWindowFunctionPullUpForWorkerQuery pulls up inputs for window functions */
static void
ProcessWindowFunctionPullUpForWorkerQuery(List *windowClause,
QueryTargetList *queryTargetList)
{
if (windowClause != NIL)
{
List *columnList = pull_var_clause_default((Node *) windowClause);
Expr *newExpression = NULL;
foreach_declared_ptr(newExpression, columnList)
{
TargetEntry *newTargetEntry = makeNode(TargetEntry);
newTargetEntry->expr = newExpression;
newTargetEntry->resname =
WorkerColumnName(queryTargetList->targetProjectionNumber);
/* force resjunk to false as we may need this on the master */
newTargetEntry->resjunk = false;
newTargetEntry->resno = queryTargetList->targetProjectionNumber;
queryTargetList->targetEntryList =
lappend(queryTargetList->targetEntryList, newTargetEntry);
queryTargetList->targetProjectionNumber++;
}
}
}
/*
* ProcessLimitOrderByForWorkerQuery gets the inputs and modifies the outputs
* such that worker query's LIMIT and ORDER BY clauses are set accordingly.
* Adding entries to ORDER BY might trigger adding new entries to newTargetEntryList.
* See GenerateNewTargetEntriesForSortClauses() for the details.
*
* The decisions on whether and how to push down LIMIT and ORDER BY are
* documented in the functions that are called from this function.
*
* inputs: sortLimitReference, originalLimitCount, limitOffset,
* sortClauseList, groupClauseList, originalTargetList
* outputs: queryOrderByLimit, queryTargetList
*/
static void
ProcessLimitOrderByForWorkerQuery(OrderByLimitReference orderByLimitReference,
Node *originalLimitCount, Node *limitOffset,
List *sortClauseList, List *groupClauseList,
List *originalTargetList,
QueryOrderByLimit *queryOrderByLimit,
QueryTargetList *queryTargetList)
{
queryOrderByLimit->workerLimitCount =
WorkerLimitCount(originalLimitCount, limitOffset, orderByLimitReference);
queryOrderByLimit->workerSortClauseList =
WorkerSortClauseList(originalLimitCount,
groupClauseList,
sortClauseList,
orderByLimitReference);
}
/*
* BuildOrderByLimitReference is a helper function that simply builds
* the necessary information for processing the limit and order by.
* The return value should be used in a read-only manner.
*/
static OrderByLimitReference
BuildOrderByLimitReference(bool hasDistinctOn, bool groupedByDisjointPartitionColumn,
bool onlyPushableWindowFunctions,
List *groupClause, List *sortClauseList, List *targetList)
{
OrderByLimitReference limitOrderByReference;
limitOrderByReference.groupedByDisjointPartitionColumn =
groupedByDisjointPartitionColumn;
limitOrderByReference.onlyPushableWindowFunctions =
onlyPushableWindowFunctions;
limitOrderByReference.hasDistinctOn = hasDistinctOn;
limitOrderByReference.groupClauseIsEmpty = (groupClause == NIL);
limitOrderByReference.sortClauseIsEmpty = (sortClauseList == NIL);
limitOrderByReference.canApproximate =
CanPushDownLimitApproximate(sortClauseList, targetList);
limitOrderByReference.hasOrderByAggregate =
HasOrderByAggregate(sortClauseList, targetList);
return limitOrderByReference;
}
/*
* TargetListHasAggregates returns true if any of the elements in the
* target list contain aggregates that are not inside the window functions.
* This function should not be called if window functions are being pulled up.
*/
bool
TargetListHasAggregates(List *targetEntryList)
{
TargetEntry *targetEntry = NULL;
foreach_declared_ptr(targetEntry, targetEntryList)
{
Expr *targetExpr = targetEntry->expr;
bool hasAggregates = contain_aggs_of_level((Node *) targetExpr, 0);
bool hasWindowFunction = contain_window_function((Node *) targetExpr);
/*
* contain_aggs_of_level() returns true even when the aggregates are
* inside a window function. We want to make sure the aggregate is not
* part of a window function before we proceed.
*/
if (hasAggregates && !hasWindowFunction)
{
return true;
}
}
return false;
}
/*
* ExpandWorkerTargetEntry is a utility function which processes the
* expressions that are intended to be added to the worker target list.
*
* In summary, the function gets a list of expressions, converts them to target
* entries and updates all the necessary fields such that the expression is correctly
* added to the worker query's target list.
*
* Inputs:
* - expressionList: The list of expressions that should be added to the worker query's
* target list.
* - originalTargetEntry: Target entry that the expressionList generated for. NULL
* if the expressionList is not generated from any target entry.
* - addToGroupByClause: True if the expressionList should also be added to the
* worker query's GROUP BY clause.
*/
static void
ExpandWorkerTargetEntry(List *expressionList, TargetEntry *originalTargetEntry,
bool addToGroupByClause, QueryTargetList *queryTargetList,
QueryGroupClause *queryGroupClause)
{
/* now create target entries for each new expression */
Expr *newExpression = NULL;
foreach_declared_ptr(newExpression, expressionList)
{
/* generate and add the new target entry to the target list */
TargetEntry *newTargetEntry =
GenerateWorkerTargetEntry(originalTargetEntry, newExpression,
queryTargetList->targetProjectionNumber);
queryTargetList->targetProjectionNumber++;
queryTargetList->targetEntryList =
lappend(queryTargetList->targetEntryList, newTargetEntry);
/*
* Detect new targets of type Var and add them to the group clause list.
* This case is expected only if the target entry has aggregates and
* it is inside a repartitioned subquery. We create a group by entry
* for each Var in the target list. This code does not check whether the
* Var was already in the target list or in the group by clauses.
*/
if (IsA(newExpression, Var) && addToGroupByClause)
{
AppendTargetEntryToGroupClause(newTargetEntry, queryGroupClause);
}
}
}
/*
* GetNextSortGroupRef gets a target list entry and returns
* the next ressortgroupref that should be used based on the
* input target list.
*/
static Index
GetNextSortGroupRef(List *targetEntryList)
{
Index nextSortGroupRefIndex = 0;
/* find max of sort group ref index */
TargetEntry *targetEntry = NULL;
foreach_declared_ptr(targetEntry, targetEntryList)
{
if (targetEntry->ressortgroupref > nextSortGroupRefIndex)
{
nextSortGroupRefIndex = targetEntry->ressortgroupref;
}
}
/* next group ref index starts from max group ref index + 1 */
nextSortGroupRefIndex++;
return nextSortGroupRefIndex;
}
/*
* GenerateWorkerTargetEntry is a simple utility function which gets a
* target entry, an expression and a targetProjectionNumber.
*
* The function returns a newly allocated target entry which can be added
* to the worker's target list.
*/
static TargetEntry *
GenerateWorkerTargetEntry(TargetEntry *targetEntry, Expr *workerExpression,
AttrNumber targetProjectionNumber)
{
TargetEntry *newTargetEntry = NULL;
/*
* If a target entry is already provided, use a copy of
* it because some of the callers rely on resorigtbl and
* resorigcol.
*/
if (targetEntry)
{
newTargetEntry = flatCopyTargetEntry(targetEntry);
}
else
{
newTargetEntry = makeNode(TargetEntry);
}
if (newTargetEntry->resname == NULL)
{
newTargetEntry->resname = WorkerColumnName(targetProjectionNumber);
}
/* we can't generate a target entry without an expression */
Assert(workerExpression != NULL);
/* force resjunk to false as we may need this on the master */
newTargetEntry->expr = workerExpression;
newTargetEntry->resjunk = false;
newTargetEntry->resno = targetProjectionNumber;
return newTargetEntry;
}
/*
* AppendTargetEntryToGroupClause gets a target entry, pointer to group list
* and the ressortgroupref index.
*
* The function modifies all three inputs such that the target entry is
* appended to the group clause and the index is incremented by one.
*/
static void
AppendTargetEntryToGroupClause(TargetEntry *targetEntry,
QueryGroupClause *queryGroupClause)
{
Expr *targetExpr PG_USED_FOR_ASSERTS_ONLY = targetEntry->expr;
/* we currently only support appending Var target entries */
Assert(IsA(targetExpr, Var));
Var *targetColumn = (Var *) targetEntry->expr;
SortGroupClause *groupByClause = CreateSortGroupClause(targetColumn);
/* the target entry should have an index */
targetEntry->ressortgroupref = *queryGroupClause->nextSortGroupRefIndex;
/* the group by clause entry should point to the correct index in the target list */
groupByClause->tleSortGroupRef = *queryGroupClause->nextSortGroupRefIndex;
/* update the group by list and the index's value */
queryGroupClause->groupClauseList =
lappend(queryGroupClause->groupClauseList, groupByClause);
(*queryGroupClause->nextSortGroupRefIndex)++;
}
/*
* WorkerAggregateWalker walks over the original target entry expression, and
* creates the list of expression trees (potentially more than one) to execute
* on the worker nodes. The function creates new expressions for aggregates and
* columns; and recurses into expression_tree_walker() for all other expression
* types.
*/
static bool
WorkerAggregateWalker(Node *node, WorkerAggregateWalkerContext *walkerContext)
{
bool walkerResult = false;
if (node == NULL)
{
return false;
}
if (IsA(node, Aggref))
{
if (CanPushDownExpression(node, walkerContext->extendedOpNodeProperties))
{
walkerContext->expressionList = lappend(walkerContext->expressionList,
node);
}
else
{
Aggref *originalAggregate = (Aggref *) node;
List *workerAggregateList = WorkerAggregateExpressionList(originalAggregate,
walkerContext);
walkerContext->expressionList = list_concat(walkerContext->expressionList,
workerAggregateList);
}
}
else if (IsA(node, Var))
{
Var *originalColumn = (Var *) node;
walkerContext->expressionList = lappend(walkerContext->expressionList,
originalColumn);
}
else
{
walkerResult = expression_tree_walker(node, WorkerAggregateWalker,
(void *) walkerContext);
}
return walkerResult;
}
/*
* WorkerAggregateExpressionList takes in the original aggregate function, and
* determines the transformed aggregate functions to execute on worker nodes.
* The function then returns these aggregates in a list. It also creates
* group by clauses for newly added targets to be placed in the extended operator
* node.
*/
static List *
WorkerAggregateExpressionList(Aggref *originalAggregate,
WorkerAggregateWalkerContext *walkerContext)
{
List *workerAggregateList = NIL;
if (walkerContext->extendedOpNodeProperties->pullUpIntermediateRows)
{
TargetEntry *targetEntry;
foreach_declared_ptr(targetEntry, originalAggregate->args)
{
workerAggregateList = lappend(workerAggregateList, targetEntry->expr);
}
Expr *directarg;
foreach_declared_ptr(directarg, originalAggregate->aggdirectargs)
{
/*
* The worker aggregation should execute any node that contains any
* Var nodes and return the result in the targetlist, so that the
* combine query can then fetch the result via remote scan; see
* MasterAggregateExpression.
*/
if (pull_var_clause_default((Node *) directarg) != NIL)
{
workerAggregateList = lappend(workerAggregateList, directarg);
}
}
if (originalAggregate->aggfilter)
{
workerAggregateList = lappend(workerAggregateList,
originalAggregate->aggfilter);
}
return workerAggregateList;
}
AggregateType aggregateType = GetAggregateType(originalAggregate);
if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct &&
CountDistinctErrorRate == DISABLE_DISTINCT_APPROXIMATION &&
walkerContext->extendedOpNodeProperties->pullDistinctColumns)
{
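/*
 * Exact count(distinct) that cannot be pushed down: place each distinct
 * column in the worker target list (deduplicated) and have the caller
 * group by it, so per-group distinctness is preserved across shards.
 */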
Aggref *aggregate = (Aggref *) copyObject(originalAggregate);
List *columnList = pull_var_clause_default((Node *) aggregate);
Var *column = NULL;
foreach_declared_ptr(column, columnList)
{
workerAggregateList = list_append_unique(workerAggregateList, column);
}
walkerContext->createGroupByClause = true;
}
else if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct &&
CountDistinctErrorRate != DISABLE_DISTINCT_APPROXIMATION)
{
/*
* If the original aggregate is a count(distinct) approximation, we want
* to compute hll_add_agg(hll_hash(var), storageSize) on worker nodes.
*/
const AttrNumber firstArgumentId = 1;
const AttrNumber secondArgumentId = 2;
const int hashArgumentCount = 2;
const int addArgumentCount = 2;
/* init hll_hash() related variables */
Oid argumentType = AggregateArgumentType(originalAggregate);
TargetEntry *argument = (TargetEntry *) linitial(originalAggregate->args);
Expr *argumentExpression = copyObject(argument->expr);
/* extract schema name of hll */
Oid hllId = get_extension_oid(HLL_EXTENSION_NAME, false);
Oid hllSchemaOid = get_extension_schema(hllId);
const char *hllSchemaName = get_namespace_name(hllSchemaOid);
const char *hashFunctionName = CountDistinctHashFunctionName(argumentType);
Oid hashFunctionId = FunctionOid(hllSchemaName, hashFunctionName,
hashArgumentCount);
Oid hashFunctionReturnType = get_func_rettype(hashFunctionId);
/* init hll_add_agg() related variables */
Oid addFunctionId = FunctionOid(hllSchemaName, HLL_ADD_AGGREGATE_NAME,
addArgumentCount);
Oid hllType = TypeOid(hllSchemaOid, HLL_TYPE_NAME);
int logOfStorageSize = CountDistinctStorageSize(CountDistinctErrorRate);
Const *logOfStorageSizeConst = MakeIntegerConst(logOfStorageSize);
/* construct hll_hash() expression */
FuncExpr *hashFunction = makeNode(FuncExpr);
hashFunction->funcid = hashFunctionId;
hashFunction->funcresulttype = hashFunctionReturnType;
hashFunction->args = list_make1(argumentExpression);
/* construct hll_add_agg() expression */
TargetEntry *hashedColumnArgument = makeTargetEntry((Expr *) hashFunction,
firstArgumentId, NULL, false);
TargetEntry *storageSizeArgument = makeTargetEntry((Expr *) logOfStorageSizeConst,
secondArgumentId, NULL, false);
List *addAggregateArgumentList = list_make2(hashedColumnArgument,
storageSizeArgument);
Aggref *addAggregateFunction = makeNode(Aggref);
addAggregateFunction->aggfnoid = addFunctionId;
addAggregateFunction->aggtype = hllType;
addAggregateFunction->args = addAggregateArgumentList;
addAggregateFunction->aggkind = AGGKIND_NORMAL;
addAggregateFunction->aggfilter = (Expr *) copyObject(
originalAggregate->aggfilter);
workerAggregateList = lappend(workerAggregateList, addAggregateFunction);
}
else if (aggregateType == AGGREGATE_AVERAGE)
{
/*
* If the original aggregate is an average, we want to compute sum(var)
* and count(var) on worker nodes.
*/
Aggref *sumAggregate = copyObject(originalAggregate);
Aggref *countAggregate = copyObject(originalAggregate);
/* extract function names for sum and count */
const char *sumAggregateName = AggregateNames[AGGREGATE_SUM];
const char *countAggregateName = AggregateNames[AGGREGATE_COUNT];
/*
* Find the type of the expression over which we execute the aggregate.
* We then need to find the right sum function for that type.
*/
Oid argumentType = AggregateArgumentType(originalAggregate);
/* find function implementing sum over the original type */
sumAggregate->aggfnoid = AggregateFunctionOid(sumAggregateName, argumentType);
sumAggregate->aggtype = get_func_rettype(sumAggregate->aggfnoid);
sumAggregate->aggtranstype = InvalidOid;
sumAggregate->aggargtypes = list_make1_oid(argumentType);
sumAggregate->aggsplit = AGGSPLIT_SIMPLE;
/* count has any input type */
countAggregate->aggfnoid = AggregateFunctionOid(countAggregateName, ANYOID);
countAggregate->aggtype = get_func_rettype(countAggregate->aggfnoid);
countAggregate->aggtranstype = InvalidOid;
countAggregate->aggargtypes = list_make1_oid(argumentType);
countAggregate->aggsplit = AGGSPLIT_SIMPLE;
workerAggregateList = lappend(workerAggregateList, sumAggregate);
workerAggregateList = lappend(workerAggregateList, countAggregate);
}
else if (aggregateType == AGGREGATE_TDIGEST_PERCENTILE_ADD_DOUBLE ||
aggregateType == AGGREGATE_TDIGEST_PERCENTILE_ADD_DOUBLEARRAY ||
aggregateType == AGGREGATE_TDIGEST_PERCENTILE_OF_ADD_DOUBLE ||
aggregateType == AGGREGATE_TDIGEST_PERCENTILE_OF_ADD_DOUBLEARRAY)
{
/*
* The original query has an aggregate in the form of either
* - tdigest_percentile(column, compression, quantile)
* - tdigest_percentile(column, compression, quantile[])
* - tdigest_percentile_of(column, compression, value)
* - tdigest_percentile_of(column, compression, value[])
*
* We are creating the worker part of this query by creating a
* - tdigest(column, compression)
*
* One could see we are passing argument 0 and argument 1 from the original query
* in here. This corresponds with the list_nth calls in the args and aggargtypes
* list construction. The tdigest function and type are read from the catalog.
*/
Aggref *newWorkerAggregate = copyObject(originalAggregate);
newWorkerAggregate->aggfnoid = TDigestExtensionAggTDigest2();
newWorkerAggregate->aggtype = TDigestExtensionTypeOid();
newWorkerAggregate->args = list_make2(
list_nth(newWorkerAggregate->args, 0),
list_nth(newWorkerAggregate->args, 1));
newWorkerAggregate->aggkind = AGGKIND_NORMAL;
newWorkerAggregate->aggtranstype = InvalidOid;
newWorkerAggregate->aggargtypes = list_make2_oid(
list_nth_oid(newWorkerAggregate->aggargtypes, 0),
list_nth_oid(newWorkerAggregate->aggargtypes, 1));
newWorkerAggregate->aggsplit = AGGSPLIT_SIMPLE;
workerAggregateList = lappend(workerAggregateList, newWorkerAggregate);
}
else if (aggregateType == AGGREGATE_TDIGEST_PERCENTILE_TDIGEST_DOUBLE ||
aggregateType == AGGREGATE_TDIGEST_PERCENTILE_TDIGEST_DOUBLEARRAY ||
aggregateType == AGGREGATE_TDIGEST_PERCENTILE_OF_TDIGEST_DOUBLE ||
aggregateType == AGGREGATE_TDIGEST_PERCENTILE_OF_TDIGEST_DOUBLEARRAY)
{
/*
* The original query has an aggregate in the form of either
* - tdigest_percentile(tdigest, quantile)
* - tdigest_percentile(tdigest, quantile[])
* - tdigest_percentile_of(tdigest, value)
* - tdigest_percentile_of(tdigest, value[])
*
* We are creating the worker part of this query by creating a
* - tdigest(tdigest)
*
* One could see we are passing argument 0 from the original query in here. This
* corresponds with the list_nth calls in the args and aggargtypes list
* construction. The tdigest function and type are read from the catalog.
*/
Aggref *newWorkerAggregate = copyObject(originalAggregate);
newWorkerAggregate->aggfnoid = TDigestExtensionAggTDigest1();
newWorkerAggregate->aggtype = TDigestExtensionTypeOid();
newWorkerAggregate->args = list_make1(list_nth(newWorkerAggregate->args, 0));
newWorkerAggregate->aggkind = AGGKIND_NORMAL;
newWorkerAggregate->aggtranstype = InvalidOid;
newWorkerAggregate->aggargtypes = list_make1_oid(
list_nth_oid(newWorkerAggregate->aggargtypes, 0));
newWorkerAggregate->aggsplit = AGGSPLIT_SIMPLE;
workerAggregateList = lappend(workerAggregateList, newWorkerAggregate);
}
else if (aggregateType == AGGREGATE_CUSTOM_COMBINE)
{
HeapTuple aggTuple =
SearchSysCache1(AGGFNOID, ObjectIdGetDatum(originalAggregate->aggfnoid));
Form_pg_aggregate aggform;
Oid combine;
if (!HeapTupleIsValid(aggTuple))
{
elog(ERROR, "citus cache lookup failed for aggregate %u",
originalAggregate->aggfnoid);
return NULL;
}
else
{
aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple);
combine = aggform->aggcombinefn;
ReleaseSysCache(aggTuple);
}
if (combine != InvalidOid)
{
Oid workerPartialId = WorkerPartialAggOid();
Const *aggOidParam = makeConst(REGPROCEDUREOID, -1, InvalidOid, sizeof(Oid),
ObjectIdGetDatum(originalAggregate->aggfnoid),
false, true);
List *newWorkerAggregateArgs =
list_make1(makeTargetEntry((Expr *) aggOidParam, 1, NULL, false));
if (list_length(originalAggregate->args) == 1)
{
/*
* Single argument case, append 'arg' to worker_partial_agg(agg, arg).
* We don't wrap the single argument in a row expression because
* unwrapping arguments on each SFUNC invocation has performance
* implications.
*/
TargetEntry *newArg =
copyObject((TargetEntry *) linitial(originalAggregate->args));
newArg->resno++;
newWorkerAggregateArgs = lappend(newWorkerAggregateArgs, newArg);
}
else
{
/*
* Aggregation on workers assumes a single aggregation parameter.
* To still be able to handle multiple parameters, we combine
* parameters into a single row expression, i.e., append 'ROW(...args)'
* to worker_partial_agg(agg, ROW(...args)).
*/
RowExpr *rowExpr = makeNode(RowExpr);
rowExpr->row_typeid = RECORDOID;
rowExpr->row_format = COERCE_EXPLICIT_CALL;
rowExpr->location = -1;
rowExpr->colnames = NIL;
TargetEntry *arg = NULL;
foreach_declared_ptr(arg, originalAggregate->args)
{
rowExpr->args = lappend(rowExpr->args, copyObject(arg->expr));
}
newWorkerAggregateArgs =
lappend(newWorkerAggregateArgs,
makeTargetEntry((Expr *) rowExpr, 2, NULL, false));
}
/* worker_partial_agg(agg, arg) or worker_partial_agg(agg, ROW(...args)) */
Aggref *newWorkerAggregate = copyObject(originalAggregate);
newWorkerAggregate->aggfnoid = workerPartialId;
newWorkerAggregate->aggtype = CSTRINGOID;
newWorkerAggregate->args = newWorkerAggregateArgs;
newWorkerAggregate->aggkind = AGGKIND_NORMAL;
newWorkerAggregate->aggtranstype = INTERNALOID;
newWorkerAggregate->aggargtypes = lcons_oid(OIDOID,
newWorkerAggregate->aggargtypes);
newWorkerAggregate->aggsplit = AGGSPLIT_SIMPLE;
workerAggregateList = list_make1(newWorkerAggregate);
}
else
{
elog(ERROR, "Aggregate lacks COMBINEFUNC");
}
}
else
{
/*
* All other aggregates are sent as they are to the worker nodes.
*/
Aggref *workerAggregate = copyObject(originalAggregate);
workerAggregateList = lappend(workerAggregateList, workerAggregate);
}
return workerAggregateList;
}
/*
* GetAggregateType scans pg_catalog.pg_proc for the given aggregate oid, and
* finds the aggregate's name. The function then matches the aggregate's name to
* previously stored strings, and returns the appropriate aggregate type.
*/
static AggregateType
GetAggregateType(Aggref *aggregateExpression)
{
Oid aggFunctionId = aggregateExpression->aggfnoid;
/* custom aggregates with combine func take precedence over name-based logic */
if (aggFunctionId >= FirstNormalObjectId &&
AggregateEnabledCustom(aggregateExpression))
{
return AGGREGATE_CUSTOM_COMBINE;
}
/* look up the function name */
char *aggregateProcName = get_func_name(aggFunctionId);
if (aggregateProcName == NULL)
{
ereport(ERROR, (errmsg("citus cache lookup failed for function %u",
aggFunctionId)));
}
uint32 aggregateCount = lengthof(AggregateNames);
for (uint32 aggregateIndex = 1; aggregateIndex < aggregateCount; aggregateIndex++)
{
const char *aggregateName = AggregateNames[aggregateIndex];
if (strncmp(aggregateName, aggregateProcName, NAMEDATALEN) == 0)
{
return aggregateIndex;
}
}
/*
* All functions from github.com/tvondra/tdigest start with the "tdigest" prefix.
* Since it requires lookups of function names in a schema we would like to only
* perform these checks if there is some chance it will actually result in a positive
* hit.
*/
if (StringStartsWith(aggregateProcName, "tdigest"))
{
if (aggFunctionId == TDigestExtensionAggTDigest1())
{
return AGGREGATE_TDIGEST_COMBINE;
}
if (aggFunctionId == TDigestExtensionAggTDigest2())
{
return AGGREGATE_TDIGEST_ADD_DOUBLE;
}
if (aggFunctionId == TDigestExtensionAggTDigestPercentile3())
{
return AGGREGATE_TDIGEST_PERCENTILE_ADD_DOUBLE;
}
if (aggFunctionId == TDigestExtensionAggTDigestPercentile3a())
{
return AGGREGATE_TDIGEST_PERCENTILE_ADD_DOUBLEARRAY;
}
if (aggFunctionId == TDigestExtensionAggTDigestPercentile2())
{
return AGGREGATE_TDIGEST_PERCENTILE_TDIGEST_DOUBLE;
}
if (aggFunctionId == TDigestExtensionAggTDigestPercentile2a())
{
return AGGREGATE_TDIGEST_PERCENTILE_TDIGEST_DOUBLEARRAY;
}
if (aggFunctionId == TDigestExtensionAggTDigestPercentileOf3())
{
return AGGREGATE_TDIGEST_PERCENTILE_OF_ADD_DOUBLE;
}
if (aggFunctionId == TDigestExtensionAggTDigestPercentileOf3a())
{
return AGGREGATE_TDIGEST_PERCENTILE_OF_ADD_DOUBLEARRAY;
}
if (aggFunctionId == TDigestExtensionAggTDigestPercentileOf2())
{
return AGGREGATE_TDIGEST_PERCENTILE_OF_TDIGEST_DOUBLE;
}
if (aggFunctionId == TDigestExtensionAggTDigestPercentileOf2a())
{
return AGGREGATE_TDIGEST_PERCENTILE_OF_TDIGEST_DOUBLEARRAY;
}
}
/* handle any remaining built-in aggregates with a suitable combinefn */
if (AggregateEnabledCustom(aggregateExpression))
{
return AGGREGATE_CUSTOM_COMBINE;
}
if (CoordinatorAggregationStrategy == COORDINATOR_AGGREGATION_DISABLED)
{
ereport(ERROR, (errmsg("unsupported aggregate function %s", aggregateProcName)));
}
else
{
return AGGREGATE_CUSTOM_ROW_GATHER;
}
}
/* Extracts the type of the argument over which the aggregate is operating. */
static Oid
AggregateArgumentType(Aggref *aggregate)
{
List *argumentList = aggregate->args;
TargetEntry *argument = (TargetEntry *) linitial(argumentList);
Oid returnTypeId = exprType((Node *) argument->expr);
/* Here we currently support aggregates with only one argument; assert that. */
Assert(list_length(argumentList) == 1);
return returnTypeId;
}
/*
* FirstAggregateArgument returns the first argument of the aggregate.
*/
static Expr *
FirstAggregateArgument(Aggref *aggregate)
{
List *argumentList = aggregate->args;
Assert(list_length(argumentList) >= 1);
TargetEntry *argument = (TargetEntry *) linitial(argumentList);
return argument->expr;
}
/*
* AggregateEnabledCustom returns whether given aggregate can be
* distributed across workers using worker_partial_agg & coord_combine_agg.
*/
static bool
AggregateEnabledCustom(Aggref *aggregateExpression)
{
if (aggregateExpression->aggorder != NIL ||
list_length(aggregateExpression->args) == 0)
{
return false;
}
Oid aggregateOid = aggregateExpression->aggfnoid;
HeapTuple aggTuple = SearchSysCache1(AGGFNOID, ObjectIdGetDatum(aggregateOid));
if (!HeapTupleIsValid(aggTuple))
{
elog(ERROR, "citus cache lookup failed.");
}
Form_pg_aggregate aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple);
if (aggform->aggcombinefn == InvalidOid)
{
ReleaseSysCache(aggTuple);
return false;
}
HeapTuple typeTuple = SearchSysCache1(TYPEOID,
									  ObjectIdGetDatum(aggform->aggtranstype));
if (!HeapTupleIsValid(typeTuple))
{
elog(ERROR, "citus cache lookup failed.");
}
Form_pg_type typeform = (Form_pg_type) GETSTRUCT(typeTuple);
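/*
 * A pseudo-type transition state (such as internal) cannot be serialized
 * and shipped between nodes, so aggregates using one are not safe to
 * combine via worker_partial_agg.
 */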
bool supportsSafeCombine = typeform->typtype != TYPTYPE_PSEUDO;
ReleaseSysCache(aggTuple);
ReleaseSysCache(typeTuple);
return supportsSafeCombine;
}
/*
 * AggregateFunctionOid scans pg_proc for the given aggregate function name,
 * and returns the oid of the single-argument aggregate that accepts the
 * given input type.
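 *
 * For example, the count(distinct) approximation path resolves hll
 * aggregates such as hll_add_agg and hll_union_agg through this lookup.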
*/
static Oid
AggregateFunctionOid(const char *functionName, Oid inputType)
{
Oid functionOid = InvalidOid;
ScanKeyData scanKey[1];
int scanKeyCount = 1;
Relation procRelation = table_open(ProcedureRelationId, AccessShareLock);
ScanKeyInit(&scanKey[0], Anum_pg_proc_proname,
BTEqualStrategyNumber, F_NAMEEQ, CStringGetDatum(functionName));
SysScanDesc scanDescriptor = systable_beginscan(procRelation,
ProcedureNameArgsNspIndexId, true,
NULL, scanKeyCount, scanKey);
/* loop until we find the right function */
HeapTuple heapTuple = systable_getnext(scanDescriptor);
while (HeapTupleIsValid(heapTuple))
{
Form_pg_proc procForm = (Form_pg_proc) GETSTRUCT(heapTuple);
int argumentCount = procForm->pronargs;
if (argumentCount == 1)
{
/* check if input type and found value type match */
if (procForm->proargtypes.values[0] == inputType ||
procForm->proargtypes.values[0] == ANYELEMENTOID)
{
functionOid = procForm->oid;
break;
}
}
Assert(argumentCount <= 1);
heapTuple = systable_getnext(scanDescriptor);
}
if (functionOid == InvalidOid)
{
ereport(ERROR, (errmsg("no matching oid for function: %s", functionName)));
}
systable_endscan(scanDescriptor);
table_close(procRelation, AccessShareLock);
return functionOid;
}
/*
* CitusFunctionOidWithSignature looks up a function with given input types.
 * It looks in the pg_catalog schema, as this function's sole purpose is to
 * support aggregate lookup.
*/
static Oid
CitusFunctionOidWithSignature(char *functionName, int numargs, Oid *argtypes)
{
List *aggregateName = list_make2(makeString("pg_catalog"), makeString(functionName));
FuncCandidateList clist = FuncnameGetCandidates(aggregateName, numargs, NIL,
false, false, false, true);
for (; clist; clist = clist->next)
{
if (memcmp(clist->args, argtypes, numargs * sizeof(Oid)) == 0)
{
return clist->oid;
}
}
ereport(ERROR, (errmsg("no matching oid for function: %s", functionName)));
return InvalidOid;
}
/*
 * WorkerPartialAggOid looks up the oid of pg_catalog.worker_partial_agg.
*/
static Oid
WorkerPartialAggOid()
{
Oid argtypes[] = {
OIDOID,
ANYELEMENTOID,
};
return CitusFunctionOidWithSignature(WORKER_PARTIAL_AGGREGATE_NAME, 2, argtypes);
}
/*
 * CoordCombineAggOid looks up the oid of pg_catalog.coord_combine_agg.
*/
static Oid
CoordCombineAggOid()
{
Oid argtypes[] = {
OIDOID,
CSTRINGOID,
ANYELEMENTOID,
};
return CitusFunctionOidWithSignature(COORD_COMBINE_AGGREGATE_NAME, 3, argtypes);
}
/*
* TypeOid looks for a type that has the given name and schema, and returns the
* corresponding type's oid.
*/
static Oid
TypeOid(Oid schemaId, const char *typeName)
{
Oid typeOid = GetSysCacheOid2(TYPENAMENSP, Anum_pg_type_oid,
PointerGetDatum(typeName),
ObjectIdGetDatum(schemaId));
return typeOid;
}
/*
* CreateSortGroupClause creates SortGroupClause for a given column Var.
* The caller should set tleSortGroupRef field and respective
* TargetEntry->ressortgroupref fields to appropriate SortGroupRefIndex.
*/
static SortGroupClause *
CreateSortGroupClause(Var *column)
{
Oid lessThanOperator = InvalidOid;
Oid equalsOperator = InvalidOid;
bool hashable = false;
SortGroupClause *groupByClause = makeNode(SortGroupClause);
get_sort_group_operators(column->vartype, true, true, true,
&lessThanOperator, &equalsOperator, NULL,
&hashable);
groupByClause->eqop = equalsOperator;
groupByClause->hashable = hashable;
groupByClause->nulls_first = false;
groupByClause->sortop = lessThanOperator;
return groupByClause;
}
/*
* CountDistinctHashFunctionName resolves the hll_hash function name to use for
* the given input type, and returns this function name.
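 *
 * As a sketch of how the resolved function is used: for an int4 argument,
 * workers aggregate hll_add_agg(hll_hash_int(column)) and the coordinator
 * merges the resulting sketches with hll_union_agg() before
 * hll_cardinality() yields the approximate distinct count.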
*/
static const char *
CountDistinctHashFunctionName(Oid argumentType)
{
/* resolve hash function name based on input argument type */
switch (argumentType)
{
case INT4OID:
{
return HLL_HASH_INTEGER_FUNC_NAME;
}
case INT8OID:
{
return HLL_HASH_BIGINT_FUNC_NAME;
}
case TEXTOID:
case BPCHAROID:
case VARCHAROID:
{
return HLL_HASH_TEXT_FUNC_NAME;
}
default:
{
return HLL_HASH_ANY_FUNC_NAME;
}
}
}
/*
* CountDistinctStorageSize takes in the desired precision for count distinct
* approximations, and returns the log-base-2 of storage space needed for the
* HyperLogLog algorithm.
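 *
 * For example, an error rate of 0.005 needs (1.04 / 0.005)^2 = 43264
 * registers; log2(43264) is roughly 15.4, which rounds to 15.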
*/
static int
CountDistinctStorageSize(double approximationErrorRate)
{
double desiredStorageSize = pow((1.04 / approximationErrorRate), 2);
double logOfDesiredStorageSize = log(desiredStorageSize) / log(2);
/* keep log2(storage size) inside allowed range */
int logOfStorageSize = (int) rint(logOfDesiredStorageSize);
if (logOfStorageSize < 4)
{
logOfStorageSize = 4;
}
else if (logOfStorageSize > 17)
{
logOfStorageSize = 17;
}
return logOfStorageSize;
}
/* Makes an integer constant node from the given value, and returns that node. */
static Const *
MakeIntegerConst(int32 integerValue)
{
const Oid typeCollationId = get_typcollation(INT4OID);
const int16 typeLength = get_typlen(INT4OID);
const int32 typeModifier = -1;
const bool typeIsNull = false;
const bool typePassByValue = true;
Datum integerDatum = Int32GetDatum(integerValue);
Const *integerConst = makeConst(INT4OID, typeModifier, typeCollationId, typeLength,
integerDatum, typeIsNull, typePassByValue);
return integerConst;
}
/* Makes a 64-bit integer constant node from the given value, and returns that node. */
static Const *
MakeIntegerConstInt64(int64 integerValue)
{
const Oid typeCollationId = get_typcollation(INT8OID);
const int16 typeLength = get_typlen(INT8OID);
const int32 typeModifier = -1;
const bool typeIsNull = false;
const bool typePassByValue = true;
Datum integer64Datum = Int64GetDatum(integerValue);
Const *integer64Const = makeConst(INT8OID, typeModifier, typeCollationId, typeLength,
integer64Datum, typeIsNull, typePassByValue);
return integer64Const;
}
/*
 * HasNonDistributableAggregates checks whether any aggregate cannot be pushed down.
* This only checks with GetAggregateType. DeferErrorIfHasNonDistributableAggregates
* performs further checks which should be done if aggregates are not being pushed down.
*/
static bool
HasNonDistributableAggregates(MultiNode *logicalPlanNode)
{
if (CoordinatorAggregationStrategy == COORDINATOR_AGGREGATION_DISABLED)
{
return false;
}
List *opNodeList = FindNodesOfType(logicalPlanNode, T_MultiExtendedOp);
MultiExtendedOp *extendedOpNode = (MultiExtendedOp *) linitial(opNodeList);
List *targetList = extendedOpNode->targetList;
Node *havingQual = extendedOpNode->havingQual;
/*
* PVC_REJECT_PLACEHOLDERS is implicit if PVC_INCLUDE_PLACEHOLDERS isn't
* specified.
*/
List *expressionList = pull_var_clause((Node *) targetList, PVC_INCLUDE_AGGREGATES |
PVC_INCLUDE_WINDOWFUNCS);
expressionList = list_concat(expressionList,
pull_var_clause(havingQual, PVC_INCLUDE_AGGREGATES));
Node *expression = NULL;
foreach_declared_ptr(expression, expressionList)
{
/* only consider aggregate expressions */
if (!IsA(expression, Aggref))
{
continue;
}
AggregateType aggregateType = GetAggregateType((Aggref *) expression);
Assert(aggregateType != AGGREGATE_INVALID_FIRST);
if (aggregateType == AGGREGATE_CUSTOM_ROW_GATHER)
{
return true;
}
}
return false;
}
/*
* CanPushDownExpression returns whether the expression can be pushed down to workers.
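 *
 * For example, in SELECT now(), count(*) FROM t, now() must be evaluated
 * on the coordinator, while count(*) can be partially computed on the
 * workers.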
*/
static bool
CanPushDownExpression(Node *expression,
const ExtendedOpNodeProperties *extendedOpNodeProperties)
{
if (contain_nextval_expression_walker(expression, NULL))
{
/* nextval can only be evaluated on the coordinator */
return false;
}
bool hasAggregate = contain_aggs_of_level(expression, 0);
bool hasWindowFunction = contain_window_function(expression);
if (!hasAggregate && !hasWindowFunction)
{
/*
* If the query has the form SELECT expression, agg(..) FROM table;
* then expression should be evaluated on the coordinator.
*
* Other than the efficiency part of this, we could also crash if
* we pushed down the expression to the workers. When pushing down
* expressions to workers we create a Var reference to the worker
* tuples. If the result from worker is empty, but we need to have
* at least a row in coordinator result, postgres will crash when
* trying to evaluate the Var.
*
* For details, see https://github.com/citusdata/citus/pull/3961
*/
if (!extendedOpNodeProperties->hasAggregate ||
extendedOpNodeProperties->hasGroupBy)
{
return true;
}
}
/* aggregates inside pushed down window functions can be pushed down */
bool hasPushableWindowFunction =
hasWindowFunction && extendedOpNodeProperties->onlyPushableWindowFunctions;
if (hasPushableWindowFunction)
{
return true;
}
if (extendedOpNodeProperties->pushDownGroupingAndHaving && !hasWindowFunction)
{
return true;
}
if (hasAggregate && !hasWindowFunction &&
extendedOpNodeProperties->groupedByDisjointPartitionColumn)
{
return true;
}
return false;
}
/*
* DeferErrorIfHasNonDistributableAggregates extracts aggregate expressions from
* the logical plan, walks over them and uses helper functions to check if we
* can transform these aggregate expressions and push them down to worker nodes.
*/
static DeferredErrorMessage *
DeferErrorIfHasNonDistributableAggregates(MultiNode *logicalPlanNode)
{
DeferredErrorMessage *error = NULL;
List *opNodeList = FindNodesOfType(logicalPlanNode, T_MultiExtendedOp);
MultiExtendedOp *extendedOpNode = (MultiExtendedOp *) linitial(opNodeList);
List *targetList = extendedOpNode->targetList;
Node *havingQual = extendedOpNode->havingQual;
/*
* PVC_REJECT_PLACEHOLDERS is implicit if PVC_INCLUDE_PLACEHOLDERS isn't
* specified.
*/
List *expressionList = pull_var_clause((Node *) targetList, PVC_INCLUDE_AGGREGATES |
PVC_INCLUDE_WINDOWFUNCS);
expressionList = list_concat(expressionList,
pull_var_clause(havingQual, PVC_INCLUDE_AGGREGATES));
Node *expression = NULL;
foreach_declared_ptr(expression, expressionList)
{
/* only consider aggregate expressions */
if (!IsA(expression, Aggref))
{
continue;
}
/* GetAggregateType errors out on unsupported aggregate types */
Aggref *aggregateExpression = (Aggref *) expression;
AggregateType aggregateType = GetAggregateType(aggregateExpression);
Assert(aggregateType != AGGREGATE_INVALID_FIRST);
/*
* Check that we can transform the current aggregate expression. These
* functions error out on unsupported array_agg and aggregate (distinct)
* clauses.
*/
if (aggregateType == AGGREGATE_ARRAY_AGG)
{
error = DeferErrorIfUnsupportedArrayAggregate(aggregateExpression);
}
else if (aggregateType == AGGREGATE_JSONB_AGG ||
aggregateType == AGGREGATE_JSON_AGG)
{
error = DeferErrorIfUnsupportedJsonAggregate(aggregateType,
aggregateExpression);
}
else if (aggregateType == AGGREGATE_JSONB_OBJECT_AGG ||
aggregateType == AGGREGATE_JSON_OBJECT_AGG)
{
error = DeferErrorIfUnsupportedJsonAggregate(aggregateType,
aggregateExpression);
}
else if (aggregateExpression->aggdistinct)
{
error = DeferErrorIfUnsupportedAggregateDistinct(aggregateExpression,
logicalPlanNode);
}
if (error != NULL)
{
return error;
}
}
return NULL;
}
/*
* DeferErrorIfUnsupportedArrayAggregate checks if we can transform the array aggregate
* expression and push it down to the worker node. If we cannot transform the
* aggregate, this function errors.
*/
static DeferredErrorMessage *
DeferErrorIfUnsupportedArrayAggregate(Aggref *arrayAggregateExpression)
{
/* if array_agg has order by, we error out */
if (arrayAggregateExpression->aggorder)
{
return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
"array_agg with order by is unsupported",
NULL, NULL);
}
/* if array_agg has distinct, we error out */
if (arrayAggregateExpression->aggdistinct)
{
return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
"array_agg (distinct) is unsupported",
NULL, NULL);
}
return NULL;
}
/*
* DeferErrorIfUnsupportedJsonAggregate checks if we can transform the json
* aggregate expression and push it down to the worker node. If we cannot
* transform the aggregate, this function errors.
*/
static DeferredErrorMessage *
DeferErrorIfUnsupportedJsonAggregate(AggregateType type,
Aggref *aggregateExpression)
{
/* if the json aggregate has distinct or order by, we error out */
if (aggregateExpression->aggdistinct || aggregateExpression->aggorder)
{
StringInfoData errorDetail;
initStringInfo(&errorDetail);
const char *name = AggregateNames[type];
appendStringInfoString(&errorDetail, name);
if (aggregateExpression->aggorder)
{
appendStringInfoString(&errorDetail, " with order by is unsupported");
}
else
{
appendStringInfoString(&errorDetail, " (distinct) is unsupported");
}
return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED, errorDetail.data,
NULL, NULL);
}
return NULL;
}
/*
* DeferErrorIfUnsupportedAggregateDistinct checks if we can transform the aggregate
* (distinct expression) and push it down to the worker node. It handles count
* (distinct) separately to check if we can use distinct approximations. If we
* cannot transform the aggregate, this function errors.
*/
static DeferredErrorMessage *
DeferErrorIfUnsupportedAggregateDistinct(Aggref *aggregateExpression,
MultiNode *logicalPlanNode)
{
const char *errorDetail = NULL;
bool distinctSupported = true;
AggregateType aggregateType = GetAggregateType(aggregateExpression);
/* If we're aggregating on coordinator, this becomes simple. */
if (aggregateType == AGGREGATE_CUSTOM_ROW_GATHER)
{
return NULL;
}
/*
 * We partially support count(distinct) in subqueries; other distinct
 * aggregates in subqueries are not supported yet.
*/
if (aggregateType == AGGREGATE_COUNT)
{
Node *aggregateArgument = (Node *) linitial(aggregateExpression->args);
List *columnList = pull_var_clause_default(aggregateArgument);
Var *column = NULL;
foreach_declared_ptr(column, columnList)
{
if (column->varattno <= 0)
{
return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
"cannot compute count (distinct)",
"Non-column references are not supported yet",
NULL);
}
}
}
else
{
List *multiTableNodeList = FindNodesOfType(logicalPlanNode, T_MultiTable);
MultiTable *multiTable = NULL;
foreach_declared_ptr(multiTable, multiTableNodeList)
{
if (multiTable->relationId == SUBQUERY_RELATION_ID ||
multiTable->relationId == SUBQUERY_PUSHDOWN_RELATION_ID)
{
return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
"cannot compute aggregate (distinct)",
"Only count(distinct) aggregate is "
"supported in subqueries", NULL);
}
}
}
/* if we have a count(distinct), and distinct approximation is enabled */
if (aggregateType == AGGREGATE_COUNT &&
CountDistinctErrorRate != DISABLE_DISTINCT_APPROXIMATION)
{
bool missingOK = true;
Oid distinctExtensionId = get_extension_oid(HLL_EXTENSION_NAME, missingOK);
/* if extension for distinct approximation is loaded, we are good */
if (distinctExtensionId != InvalidOid)
{
return NULL;
}
else
{
return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
"cannot compute count (distinct) approximation",
NULL,
"You need to have the hll extension loaded.");
}
}
if (aggregateType == AGGREGATE_COUNT)
{
List *aggregateVarList = pull_var_clause_default((Node *) aggregateExpression);
if (aggregateVarList == NIL)
{
distinctSupported = false;
errorDetail = "aggregate (distinct) with no columns is unsupported";
}
}
List *repartitionNodeList = FindNodesOfType(logicalPlanNode, T_MultiPartition);
if (repartitionNodeList != NIL)
{
distinctSupported = false;
errorDetail = "aggregate (distinct) with table repartitioning is unsupported";
}
List *tableNodeList = FindNodesOfType(logicalPlanNode, T_MultiTable);
List *extendedOpNodeList = FindNodesOfType(logicalPlanNode, T_MultiExtendedOp);
MultiExtendedOp *extendedOpNode = (MultiExtendedOp *) linitial(extendedOpNodeList);
Var *distinctColumn = AggregateDistinctColumn(aggregateExpression);
if (distinctSupported)
{
if (distinctColumn == NULL)
{
/*
 * If the query has a single table, and the table is grouped by the
 * partition column, then we support count(distinct) even when the
 * distinct column cannot be identified.
*/
distinctSupported = TablePartitioningSupportsDistinct(tableNodeList,
extendedOpNode,
distinctColumn,
aggregateType);
if (!distinctSupported)
{
errorDetail = "aggregate (distinct) on complex expressions is"
" unsupported";
}
}
else if (aggregateType != AGGREGATE_COUNT)
{
bool supports = TablePartitioningSupportsDistinct(tableNodeList,
extendedOpNode,
distinctColumn,
aggregateType);
if (!supports)
{
distinctSupported = false;
errorDetail = "table partitioning is unsuitable for aggregate (distinct)";
}
}
}
/* if current aggregate expression isn't supported, error out */
if (!distinctSupported)
{
const char *errorHint = NULL;
if (aggregateType == AGGREGATE_COUNT)
{
errorHint = "You can load the hll extension from contrib "
"packages and enable distinct approximations.";
}
return DeferredError(ERRCODE_FEATURE_NOT_SUPPORTED,
"cannot compute aggregate (distinct)",
errorDetail, errorHint);
}
return NULL;
}
/*
* AggregateDistinctColumn checks if the given aggregate expression's distinct
* clause is on a single column. If it is, the function finds and returns that
* column. Otherwise, the function returns null.
 * The function expects to find a plain column reference here; FieldSelect
 * or other expressions are not accepted as a column.
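 *
 * For example, count(distinct a) yields the Var for a, whereas
 * count(distinct a + b) and multi-argument distinct aggregates yield NULL.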
*/
static Var *
AggregateDistinctColumn(Aggref *aggregateExpression)
{
/* only consider aggregates with distincts */
if (!aggregateExpression->aggdistinct)
{
return NULL;
}
int aggregateArgumentCount = list_length(aggregateExpression->args);
if (aggregateArgumentCount != 1)
{
return NULL;
}
TargetEntry *aggregateTargetEntry = (TargetEntry *) linitial(
aggregateExpression->args);
if (!IsA(aggregateTargetEntry->expr, Var))
{
return NULL;
}
Var *aggregateColumn = (Var *) aggregateTargetEntry->expr;
return aggregateColumn;
}
/*
* TablePartitioningSupportsDistinct walks over all tables in the given list and
* checks that each table's partitioning method is suitable for pushing down an
* aggregate (distinct) expression to worker nodes. For this, the function needs
* to check that task results do not overlap with one another on the distinct
* column.
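 *
 * For example, sum(distinct dist_col) on a hash-distributed table is
 * pushable when dist_col is the distribution column, since shard results
 * cannot overlap on that column.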
*/
static bool
TablePartitioningSupportsDistinct(List *tableNodeList, MultiExtendedOp *opNode,
Var *distinctColumn, AggregateType aggregateType)
{
bool distinctSupported = true;
MultiTable *tableNode = NULL;
foreach_declared_ptr(tableNode, tableNodeList)
{
Oid relationId = tableNode->relationId;
bool tableDistinctSupported = false;
if (relationId == SUBQUERY_RELATION_ID ||
relationId == SUBQUERY_PUSHDOWN_RELATION_ID)
{
return true;
}
/* if table has one shard, task results don't overlap */
List *shardList = LoadShardList(relationId);
if (list_length(shardList) == 1)
{
continue;
}
/*
* We need to check that task results don't overlap. We can only do this
 * if the table is range or hash distributed.
*/
if (IsCitusTableType(relationId, RANGE_DISTRIBUTED) ||
IsCitusTableType(relationId, HASH_DISTRIBUTED))
{
Var *tablePartitionColumn = tableNode->partitionColumn;
if (aggregateType == AGGREGATE_COUNT)
{
tableDistinctSupported = true;
}
/* if distinct is on table partition column, we can push it down */
if (distinctColumn != NULL &&
tablePartitionColumn->varno == distinctColumn->varno &&
tablePartitionColumn->varattno == distinctColumn->varattno)
{
tableDistinctSupported = true;
}
/* if results are grouped by partition column, we can push down */
bool groupedByPartitionColumn = GroupedByColumn(opNode->groupClauseList,
opNode->targetList,
tablePartitionColumn);
if (groupedByPartitionColumn)
{
tableDistinctSupported = true;
}
}
if (!tableDistinctSupported)
{
distinctSupported = false;
break;
}
}
return distinctSupported;
}
/*
* GroupedByColumn walks over group clauses in the given list, and checks if any
* of the group clauses is on the given column.
*/
bool
GroupedByColumn(List *groupClauseList, List *targetList, Var *column)
{
bool groupedByColumn = false;
if (column == NULL)
{
return false;
}
SortGroupClause *groupClause = NULL;
foreach_declared_ptr(groupClause, groupClauseList)
{
TargetEntry *groupTargetEntry = get_sortgroupclause_tle(groupClause, targetList);
Expr *groupExpression = (Expr *) groupTargetEntry->expr;
if (IsA(groupExpression, Var))
{
Var *groupColumn = (Var *) groupExpression;
if (groupColumn->varno == column->varno &&
groupColumn->varattno == column->varattno)
{
groupedByColumn = true;
break;
}
}
}
return groupedByColumn;
}
/*
* SubqueryMultiTableList extracts multi tables in the given logical plan tree
* and returns subquery multi tables in a new list.
*/
List *
SubqueryMultiTableList(MultiNode *multiNode)
{
List *subqueryMultiTableList = NIL;
List *multiTableNodeList = FindNodesOfType(multiNode, T_MultiTable);
MultiTable *multiTable = NULL;
foreach_declared_ptr(multiTable, multiTableNodeList)
{
Query *subquery = multiTable->subquery;
if (subquery != NULL)
{
subqueryMultiTableList = lappend(subqueryMultiTableList, multiTable);
}
}
return subqueryMultiTableList;
}
/*
* GroupTargetEntryList walks over group clauses in the given list, finds
* matching target entries and return them in a new list.
*/
List *
GroupTargetEntryList(List *groupClauseList, List *targetEntryList)
{
List *groupTargetEntryList = NIL;
SortGroupClause *groupClause = NULL;
foreach_declared_ptr(groupClause, groupClauseList)
{
TargetEntry *groupTargetEntry =
get_sortgroupclause_tle(groupClause, targetEntryList);
groupTargetEntryList = lappend(groupTargetEntryList, groupTargetEntry);
}
return groupTargetEntryList;
}
/*
* IsPartitionColumn returns true if the given column is a partition column.
* The function uses FindReferencedTableColumn to find the original relation
* id and column that the column expression refers to. It then checks whether
* that column is a partition column of the relation.
*
 * Also, the function always returns false for reference tables, given that
 * reference tables do not have a partition column. The function does not
 * support queries with CTEs; it returns false if columnExpression
 * refers to a column returned by a CTE.
*
* If skipOuterVars is true, then it doesn't process the outervars.
*/
bool
IsPartitionColumn(Expr *columnExpression, Query *query, bool skipOuterVars)
{
bool isPartitionColumn = false;
Var *column = NULL;
RangeTblEntry *relationRTE = NULL;
FindReferencedTableColumn(columnExpression, NIL, query, &column, &relationRTE,
skipOuterVars);
Oid relationId = relationRTE ? relationRTE->relid : InvalidOid;
if (relationId != InvalidOid && column != NULL)
{
Var *partitionColumn = DistPartitionKey(relationId);
/* not all distributed tables have partition column */
if (partitionColumn != NULL && column->varattno == partitionColumn->varattno)
{
isPartitionColumn = true;
}
}
return isPartitionColumn;
}
/*
* FindReferencedTableColumn recursively traverses query tree to find actual relation
* id, and column that columnExpression refers to. If columnExpression is a
* non-relational or computed/derived expression, the function returns NULL for
 * rte and NULL for column. The caller should provide the parent query list
 * from the top of the tree down to this particular Query's parent; it is
 * used to look into CTEs that may be present in the query.
*
* If skipOuterVars is true, then it doesn't check vars coming from outer queries.
 * We probably don't need this skipOuterVars check, but we wanted to be on
 * the safe side and only used it in the UNION path; verifying that it
 * doesn't break anything existing can be done separately.
*/
void
FindReferencedTableColumn(Expr *columnExpression, List *parentQueryList, Query *query,
Var **column, RangeTblEntry **rteContainingReferencedColumn,
bool skipOuterVars)
{
Var *candidateColumn = NULL;
Expr *strippedColumnExpression = (Expr *) strip_implicit_coercions(
(Node *) columnExpression);
*rteContainingReferencedColumn = NULL;
*column = NULL;
if (IsA(strippedColumnExpression, Var))
{
candidateColumn = (Var *) strippedColumnExpression;
}
else if (IsA(strippedColumnExpression, FieldSelect))
{
FieldSelect *compositeField = (FieldSelect *) strippedColumnExpression;
Expr *fieldExpression = compositeField->arg;
if (IsA(fieldExpression, Var))
{
candidateColumn = (Var *) fieldExpression;
}
}
if (candidateColumn == NULL)
{
return;
}
/* If this is an outer Var, walk up to the query that owns it */
while (candidateColumn->varlevelsup > 0)
{
/* caller asked us to ignore any outer Vars, so just bail out */
if (skipOuterVars)
{
return;
}
/* Locate the parent query that owns this Var */
int parentIdx =
list_length(parentQueryList) - candidateColumn->varlevelsup;
if (!IsIndexInRange(parentQueryList, parentIdx))
{
return; /* malformed tree */
}
/* Work on a fresh copy of the Var with varlevelsup reset */
candidateColumn = copyObject(candidateColumn);
candidateColumn->varlevelsup = 0;
/*
* Make a *completely private* copy of parentQueryList for the
* next recursion step. We copy the whole list and then truncate
* so every recursive branch owns its own list cells.
*/
List *newParent =
list_copy(parentQueryList); /* duplicates every cell */
newParent = list_truncate(newParent, parentIdx);
query = list_nth(parentQueryList, parentIdx);
parentQueryList = newParent; /* hand private copy down */
/* varlevelsup was reset to 0 above, so the loop exits after one pass */
}
if (candidateColumn->varattno == InvalidAttrNumber)
{
/*
* varattno can be 0 in case of SELECT table FROM table, but that Var
* definitely does not correspond to a specific column.
*/
return;
}
List *rangetableList = query->rtable;
int rangeTableEntryIndex = candidateColumn->varno - 1;
RangeTblEntry *rangeTableEntry = list_nth(rangetableList, rangeTableEntryIndex);
if (rangeTableEntry->rtekind == RTE_RELATION)
{
*rteContainingReferencedColumn = rangeTableEntry;
*column = candidateColumn;
}
else if (rangeTableEntry->rtekind == RTE_SUBQUERY)
{
Query *subquery = rangeTableEntry->subquery;
List *targetEntryList = subquery->targetList;
AttrNumber targetEntryIndex = candidateColumn->varattno - 1;
TargetEntry *subqueryTargetEntry = list_nth(targetEntryList, targetEntryIndex);
Expr *subColumnExpression = subqueryTargetEntry->expr;
/* append current query to parent query list */
parentQueryList = lappend(parentQueryList, query);
FindReferencedTableColumn(subColumnExpression, parentQueryList,
subquery, column, rteContainingReferencedColumn,
skipOuterVars);
}
else if (rangeTableEntry->rtekind == RTE_JOIN)
{
List *joinColumnList = rangeTableEntry->joinaliasvars;
AttrNumber joinColumnIndex = candidateColumn->varattno - 1;
Expr *joinColumn = list_nth(joinColumnList, joinColumnIndex);
/* parent query list stays the same since still in the same query boundary */
FindReferencedTableColumn(joinColumn, parentQueryList, query, column,
rteContainingReferencedColumn, skipOuterVars);
}
#if PG_VERSION_NUM >= PG_VERSION_18
else if (rangeTableEntry->rtekind == RTE_GROUP)
{
/*
* PG 18: synthetic GROUP RTE. Each groupexprs item corresponds to the
* columns produced by the grouping step, in the *same ordinal order* as
* the Vars that reference them.
*/
List *groupexprs = rangeTableEntry->groupexprs;
AttrNumber groupIndex = candidateColumn->varattno - 1;
/* this must always hold unless upstream Postgres mis-constructed the RTE_GROUP */
Assert(groupIndex >= 0 && groupIndex < list_length(groupexprs));
Expr *groupExpr = (Expr *) list_nth(groupexprs, groupIndex);
/* Recurse on the underlying expression (stay in the same query) */
FindReferencedTableColumn(groupExpr, parentQueryList, query,
column, rteContainingReferencedColumn,
skipOuterVars);
}
#endif /* PG_VERSION_NUM >= PG_VERSION_18 */
else if (rangeTableEntry->rtekind == RTE_CTE)
{
/*
* Resolve through a CTE even when skipOuterVars == false.
* Maintain the invariant that each recursion level owns a private,
* correctly-bounded copy of parentQueryList.
*/
int cteParentListIndex = list_length(parentQueryList) -
rangeTableEntry->ctelevelsup - 1;
Query *cteParentQuery = NULL;
List *cteList = NIL;
CommonTableExpr *cte = NULL;
/*
* This should have been an error case, not marking it as error at the
* moment due to usage from IsPartitionColumn. Callers of that function
* do not have access to parent query list.
*/
if (IsIndexInRange(parentQueryList, cteParentListIndex))
{
cteParentQuery = list_nth(parentQueryList, cteParentListIndex);
cteList = cteParentQuery->cteList;
}
CommonTableExpr *candidateCte = NULL;
foreach_declared_ptr(candidateCte, cteList)
{
if (strcmp(candidateCte->ctename, rangeTableEntry->ctename) == 0)
{
cte = candidateCte;
break;
}
}
if (cte != NULL)
{
Query *cteQuery = (Query *) cte->ctequery;
AttrNumber targetEntryIndex = candidateColumn->varattno - 1;
if (targetEntryIndex >= 0 &&
targetEntryIndex < list_length(cteQuery->targetList))
{
TargetEntry *targetEntry =
list_nth(cteQuery->targetList, targetEntryIndex);
/*
 * Build a private, bounded parentQueryList before recursing into the CTE.
 * Invariant: the list is [top ... current], owned by this call (no
 * aliasing). For RTE_CTE we truncate the copied list to
 * cteParentListIndex + 1 entries and then append the current query.
 * Example (Q0 owns the CTE; we are in Q2 via a nested subquery):
 *   parent = [Q0, Q1, Q2], ctelevelsup = 2 => cteParentListIndex = 0 =>
 *   newParent = [Q0, Q2].
 * This keeps the outer-Var level arithmetic correct without mutating the
 * caller's list.
 */
List *newParent = list_copy(parentQueryList);
newParent = list_truncate(newParent, cteParentListIndex + 1);
newParent = lappend(newParent, query);
FindReferencedTableColumn(targetEntry->expr,
newParent,
cteQuery,
column,
rteContainingReferencedColumn,
skipOuterVars);
}
}
}
}
/*
* IsIndexInRange returns true if the given index is within the
* range of the given list.
*/
static bool
IsIndexInRange(const List *list, int index)
{
return index >= 0 && index < list_length(list);
}
/*
* ExtractQueryWalker walks over a query, and finds all queries in the query
* tree and returns these queries. Note that the function also recurses into
* the subqueries in WHERE clause.
*/
bool
ExtractQueryWalker(Node *node, List **queryList)
{
if (node == NULL)
{
return false;
}
if (IsA(node, Query))
{
Query *query = (Query *) node;
(*queryList) = lappend(*queryList, query);
return query_tree_walker(query, ExtractQueryWalker, queryList, 0);
}
return expression_tree_walker(node, ExtractQueryWalker, queryList);
}
/*
* WorkerLimitCount checks if the given input contains a valid limit node, and
* if that node can be pushed down. For this, the function checks if this limit
* count or a meaningful approximation of it can be pushed down to worker nodes.
* If they can, the function returns the limit count.
*
* The limit push-down decision tree is as follows:
 *                           group by?
 *                         1/         \0
 *     group by partition column?     (exact pd)
 *          0/          \1
 *      order by?        (exact pd)
 *       1/      \0
 *  has order by agg?    (no pd)
 *      1/       \0
 *  can approximate?    (exact pd)
 *     1/      \0
 * (approx pd)  (no pd)
*
* When an offset is present, the offset value is added to limit because for a query
* with LIMIT x OFFSET y, (x+y) records should be pulled from the workers.
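 * For example, LIMIT 3 OFFSET 2 pushes LIMIT 5 down to each worker.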
*
 * If no limit is present or the limit cannot be pushed down, then
 * WorkerLimitCount returns null.
*/
static Node *
WorkerLimitCount(Node *limitCount, Node *limitOffset, OrderByLimitReference
orderByLimitReference)
{
Node *workerLimitNode = NULL;
LimitPushdownable canPushDownLimit = LIMIT_CANNOT_PUSHDOWN;
if (limitCount == NULL)
{
/* no limit node to push down */
return NULL;
}
if (!IsA(limitCount, Const))
{
/*
* We only push down constant LIMIT clauses to make sure we get back
* the minimum number of rows.
*/
return NULL;
}
if (limitOffset != NULL && !IsA(limitOffset, Const))
{
/*
* If OFFSET is not a constant then we cannot calculate the LIMIT to
* push down.
*/
return NULL;
}
/*
* If window functions are computed on coordinator, we cannot push down LIMIT.
* If we don't have group by clauses, or we have group by partition column,
* or if we have order by clauses without aggregates, we can push down the
* original limit. Else if we have order by clauses with commutative aggregates,
* we can push down approximate limits.
*/
if (!orderByLimitReference.onlyPushableWindowFunctions)
{
canPushDownLimit = LIMIT_CANNOT_PUSHDOWN;
}
else if (orderByLimitReference.groupClauseIsEmpty ||
orderByLimitReference.groupedByDisjointPartitionColumn)
{
canPushDownLimit = LIMIT_CAN_PUSHDOWN;
}
else if (orderByLimitReference.sortClauseIsEmpty)
{
canPushDownLimit = LIMIT_CANNOT_PUSHDOWN;
}
else if (!orderByLimitReference.hasOrderByAggregate)
{
canPushDownLimit = LIMIT_CAN_PUSHDOWN;
}
else if (orderByLimitReference.canApproximate)
{
canPushDownLimit = LIMIT_CAN_APPROXIMATE;
}
/* create the workerLimitNode according to the decisions above */
if (canPushDownLimit == LIMIT_CAN_PUSHDOWN)
{
workerLimitNode = (Node *) copyObject(limitCount);
}
else if (canPushDownLimit == LIMIT_CAN_APPROXIMATE)
{
Const *workerLimitConst = (Const *) copyObject(limitCount);
int64 workerLimitCount = (int64) LimitClauseRowFetchCount;
workerLimitConst->constvalue = Int64GetDatum(workerLimitCount);
workerLimitNode = (Node *) workerLimitConst;
}
/*
* If offset clause is present and limit can be pushed down (whether exactly or
* approximately), add the offset value to limit on workers
*/
if (workerLimitNode != NULL && limitOffset != NULL)
{
Const *workerLimitConst = (Const *) workerLimitNode;
/* only update the worker limit if the const is not null */
if (!workerLimitConst->constisnull)
{
Const *workerOffsetConst = (Const *) limitOffset;
int64 workerLimitCount = DatumGetInt64(workerLimitConst->constvalue);
/* If the offset is null, it defaults to 0 when cast to int64. */
int64 workerOffsetCount = DatumGetInt64(workerOffsetConst->constvalue);
workerLimitCount = workerLimitCount + workerOffsetCount;
workerLimitNode = (Node *) MakeIntegerConstInt64(workerLimitCount);
}
}
/* display debug message on limit push down */
if (workerLimitNode != NULL)
{
Const *workerLimitConst = (Const *) workerLimitNode;
if (!workerLimitConst->constisnull)
{
int64 workerLimitCount = DatumGetInt64(workerLimitConst->constvalue);
ereport(DEBUG1, (errmsg("push down of limit count: " INT64_FORMAT,
workerLimitCount)));
}
else
{
ereport(DEBUG1, (errmsg("push down of limit count: ALL")));
}
}
return workerLimitNode;
}
/*
* WorkerSortClauseList first checks if the given input contains a limit
 * or a DISTINCT ON clause that can be pushed down. If it does, the function then
* checks if we need to add any sorting and grouping clauses to the sort list we
* push down for the limit. If we do, the function adds these clauses and
* returns them. Otherwise, the function returns null.
*/
static List *
WorkerSortClauseList(Node *limitCount, List *groupClauseList, List *sortClauseList,
OrderByLimitReference orderByLimitReference)
{
List *workerSortClauseList = NIL;
/* if no limit node and no hasDistinctOn, no need to push down sort clauses */
if (limitCount == NULL && !orderByLimitReference.hasDistinctOn)
{
return NIL;
}
/* If window functions are computed on coordinator, we cannot push down sorting. */
if (!orderByLimitReference.onlyPushableWindowFunctions)
{
return NIL;
}
sortClauseList = copyObject(sortClauseList);
/*
* If we are pushing down the limit, push down any order by clauses. Also if
* we are pushing down the limit because the order by clauses don't have any
* aggregates, add group by clauses to the order by list. We do this because
* rows that belong to the same grouping may appear in different "offsets"
* in different task results. By ordering on the group by clause, we ensure
* that query results are consistent.
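 *
 * For example, with GROUP BY a ORDER BY b LIMIT 5, workers sort by (b, a)
 * so that rows from the same group appear at consistent positions across
 * task results.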
*/
if (orderByLimitReference.groupClauseIsEmpty ||
orderByLimitReference.groupedByDisjointPartitionColumn)
{
workerSortClauseList = sortClauseList;
}
else if (sortClauseList != NIL)
{
bool orderByNonAggregates = !orderByLimitReference.hasOrderByAggregate;
bool canApproximate = orderByLimitReference.canApproximate;
if (orderByNonAggregates)
{
workerSortClauseList = sortClauseList;
workerSortClauseList = list_concat(workerSortClauseList, groupClauseList);
}
else if (canApproximate)
{
workerSortClauseList = sortClauseList;
}
}
return workerSortClauseList;
}
/*
* CanPushDownLimitApproximate checks if we can push down the limit clause to
* the worker nodes, and get approximate and meaningful results. We can do this
* only when: (1) the user has enabled the limit approximation and (2) the query
 * has order by clauses on plain columns or commutative aggregates.
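 *
 * For example, ORDER BY sum(x) LIMIT 10 can be approximated by fetching
 * citus.limit_clause_row_fetch_count rows per task, since per-shard sums
 * order rows approximately but not exactly.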
*/
static bool
CanPushDownLimitApproximate(List *sortClauseList, List *targetList)
{
bool canApproximate = false;
/* user hasn't enabled the limit approximation */
if (LimitClauseRowFetchCount == DISABLE_LIMIT_APPROXIMATION)
{
return false;
}
if (sortClauseList != NIL)
{
bool orderByNonCommutativeAggregate =
HasOrderByNonCommutativeAggregate(sortClauseList, targetList);
bool orderByComplex = HasOrderByComplexExpression(sortClauseList, targetList);
if (!orderByNonCommutativeAggregate && !orderByComplex)
{
canApproximate = true;
}
}
return canApproximate;
}
/*
* HasOrderByAggregate walks over the given order by clauses, and checks if we
* have an order by an aggregate function. If we do, the function returns true.
*/
static bool
HasOrderByAggregate(List *sortClauseList, List *targetList)
{
bool hasOrderByAggregate = false;
SortGroupClause *sortClause = NULL;
foreach_declared_ptr(sortClause, sortClauseList)
{
Node *sortExpression = get_sortgroupclause_expr(sortClause, targetList);
bool containsAggregate = contain_aggs_of_level(sortExpression, 0);
if (containsAggregate)
{
hasOrderByAggregate = true;
break;
}
}
return hasOrderByAggregate;
}
/*
* HasOrderByNonCommutativeAggregate walks over the given order by clauses,
* and checks if we have an order by an aggregate which is not commutative.
*/
static bool
HasOrderByNonCommutativeAggregate(List *sortClauseList, List *targetList)
{
bool hasOrderByNonCommutativeAggregate = false;
SortGroupClause *sortClause = NULL;
foreach_declared_ptr(sortClause, sortClauseList)
{
Node *sortExpression = get_sortgroupclause_expr(sortClause, targetList);
/* if sort expression is an aggregate, check its type */
if (IsA(sortExpression, Aggref))
{
Aggref *aggregate = (Aggref *) sortExpression;
AggregateType aggregateType = GetAggregateType(aggregate);
if (aggregateType != AGGREGATE_MIN &&
aggregateType != AGGREGATE_MAX &&
aggregateType != AGGREGATE_SUM &&
aggregateType != AGGREGATE_COUNT &&
aggregateType != AGGREGATE_BIT_AND &&
aggregateType != AGGREGATE_BIT_OR &&
aggregateType != AGGREGATE_EVERY &&
aggregateType != AGGREGATE_ANY_VALUE)
{
hasOrderByNonCommutativeAggregate = true;
break;
}
}
}
return hasOrderByNonCommutativeAggregate;
}
/*
* HasOrderByComplexExpression walks over the given order by clauses, and checks
* if we have a nested expression that contains an aggregate function within it.
* If we do, the function returns true.
*/
static bool
HasOrderByComplexExpression(List *sortClauseList, List *targetList)
{
bool hasOrderByComplexExpression = false;
SortGroupClause *sortClause = NULL;
foreach_declared_ptr(sortClause, sortClauseList)
{
Node *sortExpression = get_sortgroupclause_expr(sortClause, targetList);
/* simple aggregate functions are ok */
if (IsA(sortExpression, Aggref))
{
continue;
}
bool nestedAggregate = contain_aggs_of_level(sortExpression, 0);
if (nestedAggregate)
{
hasOrderByComplexExpression = true;
break;
}
}
return hasOrderByComplexExpression;
}
/*
* HasOrderByHllType walks over the given order by clauses, and checks if any of
* those clauses operate on hll data type. If they do, the function returns true.
*/
static bool
HasOrderByHllType(List *sortClauseList, List *targetList)
{
bool hasOrderByHllType = false;
/* check whether HLL is loaded */
Oid hllId = get_extension_oid(HLL_EXTENSION_NAME, true);
if (!OidIsValid(hllId))
{
return hasOrderByHllType;
}
Oid hllSchemaOid = get_extension_schema(hllId);
Oid hllTypeId = TypeOid(hllSchemaOid, HLL_TYPE_NAME);
SortGroupClause *sortClause = NULL;
foreach_declared_ptr(sortClause, sortClauseList)
{
Node *sortExpression = get_sortgroupclause_expr(sortClause, targetList);
Oid sortColumnTypeId = exprType(sortExpression);
if (sortColumnTypeId == hllTypeId)
{
hasOrderByHllType = true;
break;
}
}
return hasOrderByHllType;
}
/*
* ShouldProcessDistinctOrderAndLimitForWorker returns whether
* ProcessDistinctClauseForWorkerQuery should be called. If not,
* neither should ProcessLimitOrderByForWorkerQuery.
*/
static bool
ShouldProcessDistinctOrderAndLimitForWorker(ExtendedOpNodeProperties *
extendedOpNodeProperties,
bool pushingDownOriginalGrouping,
Node *havingQual)
{
if (extendedOpNodeProperties->pullUpIntermediateRows)
{
return false;
}
/* window functions must be evaluated beforehand */
if (!extendedOpNodeProperties->onlyPushableWindowFunctions)
{
return false;
}
if (extendedOpNodeProperties->pushDownGroupingAndHaving)
{
return true;
}
/*
 * If the same GROUP BY is being pushed down and there's no HAVING,
 * then the push down logic will be able to handle this scenario.
 */
if (pushingDownOriginalGrouping && havingQual == NULL)
{
return true;
}
return false;
}
/*
 * WorkerColumnName returns a palloc'd string to be used as the resname of a TargetEntry.
*/
char *
WorkerColumnName(AttrNumber resno)
{
StringInfoData name = { 0 };
initStringInfo(&name);
appendStringInfo(&name, WORKER_COLUMN_FORMAT, resno);
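/* e.g., resno 3 yields "worker_column_3" under the current WORKER_COLUMN_FORMAT */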
return name.data;
}
/*
 * IsGroupBySubsetOfDistinct checks whether each clause in the group clauses
 * also exists in the distinct clauses. Note that an empty group clause is
 * not considered a subset of the distinct clauses.
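 *
 * For example, GROUP BY a against SELECT DISTINCT a, b returns true, while
 * GROUP BY a, b against SELECT DISTINCT a returns false.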
*/
bool
IsGroupBySubsetOfDistinct(List *groupClauses, List *distinctClauses)
{
/* There must be a group clause */
if (list_length(groupClauses) == 0)
{
return false;
}
SortGroupClause *groupClause = NULL;
foreach_declared_ptr(groupClause, groupClauses)
{
bool isFound = false;
SortGroupClause *distinctClause = NULL;
foreach_declared_ptr(distinctClause, distinctClauses)
{
if (groupClause->tleSortGroupRef == distinctClause->tleSortGroupRef)
{
isFound = true;
break;
}
}
/*
* If we can't find any member of group clause in the distinct clause,
* that means group clause is not a subset of distinct clause.
*/
if (!isFound)
{
return false;
}
}
return true;
}