Add complex distinct count support for repartitioned subqueries

Single-table repartition subqueries now support count(distinct column)
and count(distinct (case when ...)) expressions. The repartition query
extracts the columns used in the aggregate expression and adds them to its
target list and group by list. The master query stays the same
(count (distinct ...)), but the attribute numbers inside the aggregate
expression are modified to reflect the changes in the repartition query.
pull/516/head
Murat Tuncer 2016-04-04 15:17:16 +03:00
parent b520fb7448
commit 2b0d6473b9
11 changed files with 1041 additions and 74 deletions

View File

@ -54,6 +54,20 @@ int LimitClauseRowFetchCount = -1; /* number of rows to fetch from each task */
double CountDistinctErrorRate = 0.0; /* precision of count(distinct) approximate */
/*
 * MasterAggregateWalkerContext is the state passed to MasterAggregateMutator
 * and MasterAggregateExpression while master target entries are rebuilt.
 */
typedef struct MasterAggregateWalkerContext
{
/* true when the operator's parent is a MultiTable and its child is a MultiCollect */
bool repartitionSubquery;
/* next column attribute number to assign on the master; starts at 1 and is advanced as columns are consumed */
AttrNumber columnId;
} MasterAggregateWalkerContext;
/*
 * WorkerAggregateWalkerContext is the state passed to WorkerAggregateWalker
 * and WorkerAggregateExpressionList while worker target entries are built.
 */
typedef struct WorkerAggregateWalkerContext
{
/* true when the operator's parent is a MultiTable and its child is a MultiCollect */
bool repartitionSubquery;
/* expressions collected for the current target entry; reset to NIL before each entry is walked */
List *expressionList;
/* set when count(distinct) columns were added as plain Vars, so the caller adds group by clauses for them */
bool createGroupByClause;
} WorkerAggregateWalkerContext;
/* Local functions forward declarations */
static MultiSelect * AndSelectNode(MultiSelect *selectNode);
static MultiSelect * OrSelectNode(MultiSelect *selectNode);
@ -96,14 +110,18 @@ static void ApplyExtendedOpNodes(MultiExtendedOp *originalNode,
MultiExtendedOp *workerNode);
static void TransformSubqueryNode(MultiTable *subqueryNode);
static MultiExtendedOp * MasterExtendedOpNode(MultiExtendedOp *originalOpNode);
static Node * MasterAggregateMutator(Node *originalNode, AttrNumber *columnId);
static Expr * MasterAggregateExpression(Aggref *originalAggregate, AttrNumber *columnId);
static Node * MasterAggregateMutator(Node *originalNode,
MasterAggregateWalkerContext *walkerContext);
static Expr * MasterAggregateExpression(Aggref *originalAggregate,
MasterAggregateWalkerContext *walkerContext);
static Expr * MasterAverageExpression(Oid sumAggregateType, Oid countAggregateType,
AttrNumber *columnId);
static Expr * AddTypeConversion(Node *originalAggregate, Node *newExpression);
static MultiExtendedOp * WorkerExtendedOpNode(MultiExtendedOp *originalOpNode);
static bool WorkerAggregateWalker(Node *node, List **newExpressionList);
static List * WorkerAggregateExpressionList(Aggref *originalAggregate);
static bool WorkerAggregateWalker(Node *node,
WorkerAggregateWalkerContext *walkerContext);
static List * WorkerAggregateExpressionList(Aggref *originalAggregate,
WorkerAggregateWalkerContext *walkerContext);
static AggregateType GetAggregateType(Oid aggFunctionId);
static Oid AggregateArgumentType(Aggref *aggregate);
static Oid AggregateFunctionOid(const char *functionName, Oid inputType);
@ -1145,7 +1163,6 @@ TransformSubqueryNode(MultiTable *subqueryNode)
MultiExtendedOp *masterExtendedOpNode = MasterExtendedOpNode(extendedOpNode);
MultiExtendedOp *workerExtendedOpNode = WorkerExtendedOpNode(extendedOpNode);
MultiPartition *partitionNode = CitusMakeNode(MultiPartition);
List *groupClauseList = extendedOpNode->groupClauseList;
List *targetEntryList = extendedOpNode->targetList;
List *groupTargetEntryList = GroupTargetEntryList(groupClauseList, targetEntryList);
@ -1212,7 +1229,18 @@ MasterExtendedOpNode(MultiExtendedOp *originalOpNode)
List *targetEntryList = originalOpNode->targetList;
List *newTargetEntryList = NIL;
ListCell *targetEntryCell = NULL;
AttrNumber columnId = 1;
MultiNode *parentNode = ParentNode((MultiNode *) originalOpNode);
MultiNode *childNode = ChildNode((MultiUnaryNode *) originalOpNode);
MasterAggregateWalkerContext *walkerContext = palloc0(
sizeof(MasterAggregateWalkerContext));
walkerContext->columnId = 1;
walkerContext->repartitionSubquery = false;
if (CitusIsA(parentNode, MultiTable) && CitusIsA(childNode, MultiCollect))
{
walkerContext->repartitionSubquery = true;
}
/* iterate over original target entries */
foreach(targetEntryCell, targetEntryList)
@ -1226,7 +1254,7 @@ MasterExtendedOpNode(MultiExtendedOp *originalOpNode)
if (hasAggregates)
{
Node *newNode = MasterAggregateMutator((Node *) originalExpression,
&columnId);
walkerContext);
newExpression = (Expr *) newNode;
}
else
@ -1238,9 +1266,9 @@ MasterExtendedOpNode(MultiExtendedOp *originalOpNode)
const uint32 masterTableId = 1; /* only one table on master node */
Var *column = makeVarFromTargetEntry(masterTableId, originalTargetEntry);
column->varattno = columnId;
column->varoattno = columnId;
columnId++;
column->varattno = walkerContext->columnId;
column->varoattno = walkerContext->columnId;
walkerContext->columnId++;
newExpression = (Expr *) column;
}
@ -1271,7 +1299,7 @@ MasterExtendedOpNode(MultiExtendedOp *originalOpNode)
* depth first order.
*/
static Node *
MasterAggregateMutator(Node *originalNode, AttrNumber *columnId)
MasterAggregateMutator(Node *originalNode, MasterAggregateWalkerContext *walkerContext)
{
Node *newNode = NULL;
if (originalNode == NULL)
@ -1282,7 +1310,7 @@ MasterAggregateMutator(Node *originalNode, AttrNumber *columnId)
if (IsA(originalNode, Aggref))
{
Aggref *originalAggregate = (Aggref *) originalNode;
Expr *newExpression = MasterAggregateExpression(originalAggregate, columnId);
Expr *newExpression = MasterAggregateExpression(originalAggregate, walkerContext);
newNode = (Node *) newExpression;
}
@ -1291,15 +1319,15 @@ MasterAggregateMutator(Node *originalNode, AttrNumber *columnId)
uint32 masterTableId = 1; /* one table on the master node */
Var *newColumn = copyObject(originalNode);
newColumn->varno = masterTableId;
newColumn->varattno = (*columnId);
(*columnId)++;
newColumn->varattno = walkerContext->columnId;
walkerContext->columnId++;
newNode = (Node *) newColumn;
}
else
{
newNode = expression_tree_mutator(originalNode, MasterAggregateMutator,
(void *) columnId);
(void *) walkerContext);
}
return newNode;
@ -1317,7 +1345,8 @@ MasterAggregateMutator(Node *originalNode, AttrNumber *columnId)
* knowledge to create the appropriate master function with correct data types.
*/
static Expr *
MasterAggregateExpression(Aggref *originalAggregate, AttrNumber *columnId)
MasterAggregateExpression(Aggref *originalAggregate,
MasterAggregateWalkerContext *walkerContext)
{
AggregateType aggregateType = GetAggregateType(originalAggregate->aggfnoid);
Expr *newMasterExpression = NULL;
@ -1327,7 +1356,55 @@ MasterAggregateExpression(Aggref *originalAggregate, AttrNumber *columnId)
const AttrNumber argumentId = 1; /* our aggregates have single arguments */
if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct &&
CountDistinctErrorRate != DISABLE_DISTINCT_APPROXIMATION)
CountDistinctErrorRate == DISABLE_DISTINCT_APPROXIMATION &&
walkerContext->repartitionSubquery)
{
Aggref *aggregate = (Aggref *) copyObject(originalAggregate);
List *aggTargetEntryList = aggregate->args;
TargetEntry *distinctTargetEntry = linitial(aggTargetEntryList);
List *varList = pull_var_clause_default((Node *) distinctTargetEntry->expr);
ListCell *varCell = NULL;
List *uniqueVarList = NIL;
int startColumnCount = walkerContext->columnId;
/* determine unique vars that were placed in target list by worker */
foreach(varCell, varList)
{
Var *column = (Var *) lfirst(varCell);
uniqueVarList = list_append_unique(uniqueVarList, copyObject(column));
}
/*
* Go over each var inside aggregate and update their varattno's according to
* worker query target entry column index.
*/
foreach(varCell, varList)
{
Var *columnToUpdate = (Var *) lfirst(varCell);
ListCell *uniqueVarCell = NULL;
int columnIndex = 0;
foreach(uniqueVarCell, uniqueVarList)
{
Var *currentVar = (Var *) lfirst(uniqueVarCell);
if (equal(columnToUpdate, currentVar))
{
break;
}
columnIndex++;
}
columnToUpdate->varattno = startColumnCount + columnIndex;
columnToUpdate->varoattno = startColumnCount + columnIndex;
}
/* we added that many columns */
walkerContext->columnId += list_length(uniqueVarList);
newMasterExpression = (Expr *) aggregate;
}
else if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct &&
CountDistinctErrorRate != DISABLE_DISTINCT_APPROXIMATION)
{
/*
* If enabled, we check for count(distinct) approximations before count
@ -1348,9 +1425,10 @@ MasterAggregateExpression(Aggref *originalAggregate, AttrNumber *columnId)
Oid hllType = TypenameGetTypid(HLL_TYPE_NAME);
Oid hllTypeCollationId = get_typcollation(hllType);
Var *hllColumn = makeVar(masterTableId, (*columnId), hllType, defaultTypeMod,
Var *hllColumn = makeVar(masterTableId, walkerContext->columnId, hllType,
defaultTypeMod,
hllTypeCollationId, columnLevelsUp);
(*columnId)++;
walkerContext->columnId++;
hllTargetEntry = makeTargetEntry((Expr *) hllColumn, argumentId, NULL, false);
@ -1389,7 +1467,7 @@ MasterAggregateExpression(Aggref *originalAggregate, AttrNumber *columnId)
/* create the expression sum(sum(column) / sum(count(column))) */
newMasterExpression = MasterAverageExpression(workerSumReturnType,
workerCountReturnType,
columnId);
&(walkerContext->columnId));
}
else if (aggregateType == AGGREGATE_COUNT)
{
@ -1415,9 +1493,9 @@ MasterAggregateExpression(Aggref *originalAggregate, AttrNumber *columnId)
newMasterAggregate->aggfnoid = sumFunctionId;
newMasterAggregate->aggtype = masterReturnType;
column = makeVar(masterTableId, (*columnId), workerReturnType,
column = makeVar(masterTableId, walkerContext->columnId, workerReturnType,
workerReturnTypeMod, workerCollationId, columnLevelsUp);
(*columnId)++;
walkerContext->columnId++;
/* aggref expects its arguments to be wrapped in target entries */
columnTargetEntry = makeTargetEntry((Expr *) column, argumentId, NULL, false);
@ -1451,10 +1529,10 @@ MasterAggregateExpression(Aggref *originalAggregate, AttrNumber *columnId)
ANYARRAYOID);
/* create argument for the array_cat_agg() aggregate */
column = makeVar(masterTableId, (*columnId), workerReturnType,
column = makeVar(masterTableId, walkerContext->columnId, workerReturnType,
workerReturnTypeMod, workerCollationId, columnLevelsUp);
arrayCatAggArgument = makeTargetEntry((Expr *) column, argumentId, NULL, false);
(*columnId)++;
walkerContext->columnId++;
/* construct the master array_cat_agg() expression */
newMasterAggregate = copyObject(originalAggregate);
@ -1486,9 +1564,9 @@ MasterAggregateExpression(Aggref *originalAggregate, AttrNumber *columnId)
newMasterAggregate->aggfnoid = aggregateFunctionId;
newMasterAggregate->aggtype = masterReturnType;
column = makeVar(masterTableId, (*columnId), workerReturnType,
column = makeVar(masterTableId, walkerContext->columnId, workerReturnType,
workerReturnTypeMod, workerCollationId, columnLevelsUp);
(*columnId)++;
walkerContext->columnId++;
/* aggref expects its arguments to be wrapped in target entries */
columnTargetEntry = makeTargetEntry((Expr *) column, argumentId, NULL, false);
@ -1611,16 +1689,45 @@ AddTypeConversion(Node *originalAggregate, Node *newExpression)
* with aggregates in them, this function calls the recursive aggregate walker
* function to create aggregates for the worker nodes. Also, the function checks
* if we can push down the limit to worker nodes; and if we can, sets the limit
* count and sort clause list fields in the new operator node.
* count and sort clause list fields in the new operator node. It provides special
* treatment for the count distinct operator if it is used in repartition subqueries:
* each column in the count distinct aggregate is added to the target list and the
* group by list of the worker extended operator.
*/
static MultiExtendedOp *
WorkerExtendedOpNode(MultiExtendedOp *originalOpNode)
{
MultiExtendedOp *workerExtendedOpNode = NULL;
MultiNode *parentNode = ParentNode((MultiNode *) originalOpNode);
MultiNode *childNode = ChildNode((MultiUnaryNode *) originalOpNode);
List *targetEntryList = originalOpNode->targetList;
List *newTargetEntryList = NIL;
ListCell *targetEntryCell = NULL;
List *newTargetEntryList = NIL;
List *groupClauseList = copyObject(originalOpNode->groupClauseList);
AttrNumber targetProjectionNumber = 1;
WorkerAggregateWalkerContext *walkerContext =
palloc0(sizeof(WorkerAggregateWalkerContext));
walkerContext->repartitionSubquery = false;
walkerContext->expressionList = NIL;
Index nextSortGroupRefIndex = 0;
if (CitusIsA(parentNode, MultiTable) && CitusIsA(childNode, MultiCollect))
{
walkerContext->repartitionSubquery = true;
/* find max of sort group ref index */
foreach(targetEntryCell, targetEntryList)
{
TargetEntry *targetEntry = (TargetEntry *) lfirst(targetEntryCell);
if (targetEntry->ressortgroupref > nextSortGroupRefIndex)
{
nextSortGroupRefIndex = targetEntry->ressortgroupref;
}
}
/* next group ref index starts from max group ref index + 1 */
nextSortGroupRefIndex++;
}
/* iterate over original target entries */
foreach(targetEntryCell, targetEntryList)
@ -1629,11 +1736,15 @@ WorkerExtendedOpNode(MultiExtendedOp *originalOpNode)
Expr *originalExpression = originalTargetEntry->expr;
List *newExpressionList = NIL;
ListCell *newExpressionCell = NULL;
bool hasAggregates = contain_agg_clause((Node *) originalExpression);
walkerContext->expressionList = NIL;
walkerContext->createGroupByClause = false;
if (hasAggregates)
{
WorkerAggregateWalker((Node *) originalExpression, &newExpressionList);
WorkerAggregateWalker((Node *) originalExpression, walkerContext);
newExpressionList = walkerContext->expressionList;
}
else
{
@ -1647,6 +1758,37 @@ WorkerExtendedOpNode(MultiExtendedOp *originalOpNode)
TargetEntry *newTargetEntry = copyObject(originalTargetEntry);
newTargetEntry->expr = newExpression;
/*
* Detect new targets of type Var and add them to the group clause list.
* This case is expected only if the target entry has aggregates and
* it is inside a repartitioned subquery. We create a group by entry
* for each Var in the target list. This code does not check whether the
* Var was already in the target list or in the group by clauses.
*/
if (IsA(newExpression, Var) && walkerContext->createGroupByClause)
{
Var *column = (Var *) newExpression;
Oid lessThanOperator = InvalidOid;
Oid equalsOperator = InvalidOid;
bool hashable = false;
SortGroupClause *groupByClause = makeNode(SortGroupClause);
get_sort_group_operators(column->vartype, true, true, true,
&lessThanOperator, &equalsOperator, NULL,
&hashable);
groupByClause->eqop = equalsOperator;
groupByClause->hashable = hashable;
groupByClause->nulls_first = false;
groupByClause->sortop = lessThanOperator;
groupByClause->tleSortGroupRef = nextSortGroupRefIndex;
groupClauseList = lappend(groupClauseList, groupByClause);
newTargetEntry->ressortgroupref = nextSortGroupRefIndex;
nextSortGroupRefIndex++;
}
if (newTargetEntry->resname == NULL)
{
StringInfo columnNameString = makeStringInfo();
@ -1660,14 +1802,13 @@ WorkerExtendedOpNode(MultiExtendedOp *originalOpNode)
newTargetEntry->resjunk = false;
newTargetEntry->resno = targetProjectionNumber;
targetProjectionNumber++;
newTargetEntryList = lappend(newTargetEntryList, newTargetEntry);
}
}
workerExtendedOpNode = CitusMakeNode(MultiExtendedOp);
workerExtendedOpNode->targetList = newTargetEntryList;
workerExtendedOpNode->groupClauseList = originalOpNode->groupClauseList;
workerExtendedOpNode->groupClauseList = groupClauseList;
/* if we can push down the limit, also set related fields */
workerExtendedOpNode->limitCount = WorkerLimitCount(originalOpNode);
@ -1685,7 +1826,7 @@ WorkerExtendedOpNode(MultiExtendedOp *originalOpNode)
* types.
*/
static bool
WorkerAggregateWalker(Node *node, List **newExpressionList)
WorkerAggregateWalker(Node *node, WorkerAggregateWalkerContext *walkerContext)
{
bool walkerResult = false;
if (node == NULL)
@ -1696,19 +1837,22 @@ WorkerAggregateWalker(Node *node, List **newExpressionList)
if (IsA(node, Aggref))
{
Aggref *originalAggregate = (Aggref *) node;
List *workerAggregateList = WorkerAggregateExpressionList(originalAggregate);
List *workerAggregateList = WorkerAggregateExpressionList(originalAggregate,
walkerContext);
(*newExpressionList) = list_concat(*newExpressionList, workerAggregateList);
walkerContext->expressionList = list_concat(walkerContext->expressionList,
workerAggregateList);
}
else if (IsA(node, Var))
{
Var *originalColumn = (Var *) node;
(*newExpressionList) = lappend(*newExpressionList, originalColumn);
walkerContext->expressionList = lappend(walkerContext->expressionList,
originalColumn);
}
else
{
walkerResult = expression_tree_walker(node, WorkerAggregateWalker,
(void *) newExpressionList);
(void *) walkerContext);
}
return walkerResult;
@ -1718,16 +1862,44 @@ WorkerAggregateWalker(Node *node, List **newExpressionList)
/*
* WorkerAggregateExpressionList takes in the original aggregate function, and
* determines the transformed aggregate functions to execute on worker nodes.
* The function then returns these aggregates in a list.
* The function then returns these aggregates in a list. It also creates
* group by clauses for newly added targets to be placed in the extended operator
* node.
*/
static List *
WorkerAggregateExpressionList(Aggref *originalAggregate)
WorkerAggregateExpressionList(Aggref *originalAggregate,
WorkerAggregateWalkerContext *walkerContext)
{
AggregateType aggregateType = GetAggregateType(originalAggregate->aggfnoid);
List *workerAggregateList = NIL;
if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct &&
CountDistinctErrorRate != DISABLE_DISTINCT_APPROXIMATION)
CountDistinctErrorRate == DISABLE_DISTINCT_APPROXIMATION &&
walkerContext->repartitionSubquery)
{
Aggref *aggregate = (Aggref *) copyObject(originalAggregate);
List *aggTargetEntryList = aggregate->args;
TargetEntry *distinctTargetEntry = (TargetEntry *) linitial(aggTargetEntryList);
List *columnList = pull_var_clause_default((Node *) distinctTargetEntry);
ListCell *columnCell = NULL;
List *processedColumnList = NIL;
foreach(columnCell, columnList)
{
Var *column = (Var *) lfirst(columnCell);
if (list_member(processedColumnList, column))
{
continue;
}
processedColumnList = lappend(processedColumnList, column);
workerAggregateList = lappend(workerAggregateList, copyObject(column));
}
walkerContext->createGroupByClause = true;
}
else if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct &&
CountDistinctErrorRate != DISABLE_DISTINCT_APPROXIMATION)
{
/*
* If the original aggregate is a count(distinct) approximation, we want
@ -2148,9 +2320,11 @@ ErrorIfUnsupportedAggregateDistinct(Aggref *aggregateExpression,
bool distinctSupported = true;
List *repartitionNodeList = NIL;
Var *distinctColumn = NULL;
List *multiTableNodeList = NIL;
ListCell *multiTableNodeCell = NULL;
AggregateType aggregateType = AGGREGATE_INVALID_FIRST;
List *tableNodeList = NIL;
List *extendedOpNodeList = NIL;
MultiExtendedOp *extendedOpNode = NULL;
AggregateType aggregateType = GetAggregateType(aggregateExpression->aggfnoid);
/* check if logical plan includes a subquery */
List *subqueryMultiTableList = SubqueryMultiTableList(logicalPlanNode);
@ -2161,20 +2335,43 @@ ErrorIfUnsupportedAggregateDistinct(Aggref *aggregateExpression,
errdetail("distinct in the outermost query is unsupported")));
}
multiTableNodeList = FindNodesOfType(logicalPlanNode, T_MultiTable);
foreach(multiTableNodeCell, multiTableNodeList)
/*
* We partially support count(distinct) in subqueries, other distinct aggregates in
* subqueries are not supported yet.
*/
if (aggregateType == AGGREGATE_COUNT)
{
MultiTable *multiTable = (MultiTable *) lfirst(multiTableNodeCell);
if (multiTable->relationId == SUBQUERY_RELATION_ID)
Node *aggregateArgument = (Node *) linitial(aggregateExpression->args);
List *columnList = pull_var_clause_default(aggregateArgument);
ListCell *columnCell = NULL;
foreach(columnCell, columnList)
{
ereport(ERROR, (errmsg("cannot compute count (distinct)"),
errdetail("Subqueries with aggregate (distinct) are "
"not supported yet")));
Var *column = (Var *) lfirst(columnCell);
if (column->varattno <= 0)
{
ereport(ERROR, (errmsg("cannot compute count (distinct)"),
errdetail("Non-column references are not supported "
"yet")));
}
}
}
else
{
List *multiTableNodeList = FindNodesOfType(logicalPlanNode, T_MultiTable);
ListCell *multiTableNodeCell = NULL;
foreach(multiTableNodeCell, multiTableNodeList)
{
MultiTable *multiTable = (MultiTable *) lfirst(multiTableNodeCell);
if (multiTable->relationId == SUBQUERY_RELATION_ID)
{
ereport(ERROR, (errmsg("cannot compute aggregate (distinct)"),
errdetail("Only count(distinct) aggregate is "
"supported in subqueries")));
}
}
}
/* if we have a count(distinct), and distinct approximation is enabled */
aggregateType = GetAggregateType(aggregateExpression->aggfnoid);
if (aggregateType == AGGREGATE_COUNT &&
CountDistinctErrorRate != DISABLE_DISTINCT_APPROXIMATION)
{
@ -2193,6 +2390,16 @@ ErrorIfUnsupportedAggregateDistinct(Aggref *aggregateExpression,
}
}
if (aggregateType == AGGREGATE_COUNT)
{
List *aggregateVarList = pull_var_clause_default((Node *) aggregateExpression);
if (aggregateVarList == NIL)
{
distinctSupported = false;
errorDetail = "aggregate (distinct) with no columns is unsupported";
}
}
repartitionNodeList = FindNodesOfType(logicalPlanNode, T_MultiPartition);
if (repartitionNodeList != NIL)
{
@ -2200,19 +2407,27 @@ ErrorIfUnsupportedAggregateDistinct(Aggref *aggregateExpression,
errorDetail = "aggregate (distinct) with table repartitioning is unsupported";
}
tableNodeList = FindNodesOfType(logicalPlanNode, T_MultiTable);
extendedOpNodeList = FindNodesOfType(logicalPlanNode, T_MultiExtendedOp);
extendedOpNode = (MultiExtendedOp *) linitial(extendedOpNodeList);
distinctColumn = AggregateDistinctColumn(aggregateExpression);
if (distinctColumn == NULL)
if (distinctSupported && distinctColumn == NULL)
{
distinctSupported = false;
errorDetail = "aggregate (distinct) on complex expressions is unsupported";
/*
* If the query has a single table, and the table is grouped by its partition
* column, then we support count distinct even if the distinct column cannot
* be identified.
*/
distinctSupported = TablePartitioningSupportsDistinct(tableNodeList,
extendedOpNode,
distinctColumn);
if (!distinctSupported)
{
errorDetail = "aggregate (distinct) on complex expressions is unsupported";
}
}
else
else if (distinctSupported)
{
List *tableNodeList = FindNodesOfType(logicalPlanNode, T_MultiTable);
List *opNodeList = FindNodesOfType(logicalPlanNode, T_MultiExtendedOp);
MultiExtendedOp *extendedOpNode = (MultiExtendedOp *) linitial(opNodeList);
bool supports = TablePartitioningSupportsDistinct(tableNodeList, extendedOpNode,
distinctColumn);
if (!supports)
@ -2299,6 +2514,11 @@ TablePartitioningSupportsDistinct(List *tableNodeList, MultiExtendedOp *opNode,
bool tableDistinctSupported = false;
char partitionMethod = 0;
if (relationId == SUBQUERY_RELATION_ID)
{
return true;
}
/* if table has one shard, task results don't overlap */
List *shardList = LoadShardList(relationId);
if (list_length(shardList) == 1)
@ -2319,7 +2539,8 @@ TablePartitioningSupportsDistinct(List *tableNodeList, MultiExtendedOp *opNode,
bool groupedByPartitionColumn = false;
/* if distinct is on table partition column, we can push it down */
if (tablePartitionColumn->varno == distinctColumn->varno &&
if (distinctColumn != NULL &&
tablePartitionColumn->varno == distinctColumn->varno &&
tablePartitionColumn->varattno == distinctColumn->varattno)
{
tableDistinctSupported = true;

View File

@ -4343,13 +4343,16 @@ MergeTaskList(MapMergeJob *mapMergeJob, List *mapTaskList, uint32 taskIdIndex)
{
StringInfo mergeTableQueryString =
MergeTableQueryString(taskIdIndex, targetEntryList);
char *escapedMergeTableQueryString =
quote_literal_cstr(mergeTableQueryString->data);
StringInfo intermediateTableQueryString =
IntermediateTableQueryString(jobId, taskIdIndex, reduceQuery);
char *escapedIntermediateTableQueryString =
quote_literal_cstr(intermediateTableQueryString->data);
StringInfo mergeAndRunQueryString = makeStringInfo();
appendStringInfo(mergeAndRunQueryString, MERGE_FILES_AND_RUN_QUERY_COMMAND,
jobId, taskIdIndex, mergeTableQueryString->data,
intermediateTableQueryString->data);
jobId, taskIdIndex, escapedMergeTableQueryString,
escapedIntermediateTableQueryString);
mergeTask = CreateBasicTask(jobId, mergeTaskId, MERGE_TASK,
mergeAndRunQueryString->data);

View File

@ -41,7 +41,7 @@
#define MERGE_FILES_INTO_TABLE_COMMAND "SELECT worker_merge_files_into_table \
(" UINT64_FORMAT ", %d, '%s', '%s')"
#define MERGE_FILES_AND_RUN_QUERY_COMMAND \
"SELECT worker_merge_files_and_run_query(" UINT64_FORMAT ", %d, '%s', '%s')"
"SELECT worker_merge_files_and_run_query(" UINT64_FORMAT ", %d, %s, %s)"
typedef enum CitusRTEKind

View File

@ -16,3 +16,4 @@
/multi_subquery.out
/multi_subquery_0.out
/worker_copy.out
/multi_complex_count_distinct.out

View File

@ -920,9 +920,20 @@ SELECT
articles_hash
GROUP BY
author_id;
ERROR: cannot compute aggregate (distinct)
DETAIL: aggregate (distinct) on complex expressions is unsupported
HINT: You can load the hll extension from contrib packages and enable distinct approximations.
c
---
4
5
5
5
5
5
5
5
5
5
(10 rows)
-- queries inside transactions can be router plannable
BEGIN;
SELECT *

View File

@ -171,9 +171,9 @@ from
l_tax) as distributed_table;
ERROR: cannot perform distributed planning on this query
DETAIL: Subqueries without aggregates are not supported yet
-- Check that we don't support subqueries with count(distinct).
-- Check that we support subqueries with count(distinct).
select
different_shipment_days
avg(different_shipment_days)
from
(select
count(distinct l_shipdate) as different_shipment_days
@ -181,8 +181,11 @@ from
lineitem
group by
l_partkey) as distributed_table;
ERROR: cannot compute count (distinct)
DETAIL: Subqueries with aggregate (distinct) are not supported yet
avg
------------------------
1.02907126318497555956
(1 row)
-- Check that if subquery is pulled, we don't error and run query properly.
SELECT max(l_suppkey) FROM
(

View File

@ -0,0 +1,284 @@
--
-- COMPLEX_COUNT_DISTINCT
--
CREATE TABLE lineitem_hash (
l_orderkey bigint not null,
l_partkey integer not null,
l_suppkey integer not null,
l_linenumber integer not null,
l_quantity decimal(15, 2) not null,
l_extendedprice decimal(15, 2) not null,
l_discount decimal(15, 2) not null,
l_tax decimal(15, 2) not null,
l_returnflag char(1) not null,
l_linestatus char(1) not null,
l_shipdate date not null,
l_commitdate date not null,
l_receiptdate date not null,
l_shipinstruct char(25) not null,
l_shipmode char(10) not null,
l_comment varchar(44) not null,
PRIMARY KEY(l_orderkey, l_linenumber) );
SELECT master_create_distributed_table('lineitem_hash', 'l_orderkey', 'hash');
SELECT master_create_worker_shards('lineitem_hash', 8, 1);
\COPY lineitem_hash FROM '@abs_srcdir@/data/lineitem.1.data' with delimiter '|'
\COPY lineitem_hash FROM '@abs_srcdir@/data/lineitem.2.data' with delimiter '|'
SET citus.task_executor_type to "task-tracker";
-- count(distinct) is supported on top level query if there
-- is a grouping on the partition key
SELECT
l_orderkey, count(DISTINCT l_partkey)
FROM lineitem_hash
GROUP BY l_orderkey
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
-- it is not supported if there is no grouping or grouping is on non-partition field
SELECT
count(DISTINCT l_partkey)
FROM lineitem_hash
ORDER BY 1 DESC
LIMIT 10;
SELECT
l_shipmode, count(DISTINCT l_partkey)
FROM lineitem_hash
GROUP BY l_shipmode
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
-- count distinct is supported on single table subqueries
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT l_partkey)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
SELECT *
FROM (
SELECT
l_partkey, count(DISTINCT l_orderkey)
FROM lineitem_hash
GROUP BY l_partkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
-- case expr in count distinct is supported.
-- count orders partkeys if l_shipmode is air
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT CASE WHEN l_shipmode = 'AIR' THEN l_partkey ELSE NULL END) as count
FROM lineitem_hash
GROUP BY l_orderkey) sub
WHERE count > 0
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
-- text like operator is also supported
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT CASE WHEN l_shipmode like '%A%' THEN l_partkey ELSE NULL END) as count
FROM lineitem_hash
GROUP BY l_orderkey) sub
WHERE count > 0
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
-- count distinct is rejected if it does not reference any columns
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT 1)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
-- count distinct is rejected if it does not reference any columns
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT (random() * 5)::int)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
-- even non-const function calls are supported within count distinct
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT (random() * 5)::int = l_linenumber)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 0;
-- multiple nested subquery
SELECT
total,
avg(avg_count) as total_avg_count
FROM (
SELECT
number_sum,
count(DISTINCT l_suppkey) as total,
avg(total_count) avg_count
FROM (
SELECT
l_suppkey,
sum(l_linenumber) as number_sum,
count(DISTINCT l_shipmode) as total_count
FROM
lineitem_hash
WHERE
l_partkey > 100 and
l_quantity > 2 and
l_orderkey < 10000
GROUP BY
l_suppkey) as distributed_table
WHERE
number_sum >= 10
GROUP BY
number_sum) as distributed_table_2
GROUP BY
total
ORDER BY
total_avg_count DESC;
-- multiple cases query
SELECT *
FROM (
SELECT
count(DISTINCT
CASE
WHEN l_shipmode = 'TRUCK' THEN l_partkey
WHEN l_shipmode = 'AIR' THEN l_quantity
WHEN l_shipmode = 'SHIP' THEN l_discount
ELSE l_suppkey
END) as count,
l_shipdate
FROM
lineitem_hash
GROUP BY
l_shipdate) sub
WHERE
count > 0
ORDER BY
1 DESC, 2 DESC
LIMIT 10;
-- count DISTINCT expression
SELECT *
FROM (
SELECT
l_quantity, count(DISTINCT ((l_orderkey / 1000) * 1000 )) as count
FROM
lineitem_hash
GROUP BY
l_quantity) sub
WHERE
count > 0
ORDER BY
2 DESC, 1 DESC
LIMIT 10;
-- count DISTINCT is part of an expression which inclues another aggregate
SELECT *
FROM (
SELECT
sum(((l_partkey * l_tax) / 100)) /
count(DISTINCT
CASE
WHEN l_shipmode = 'TRUCK' THEN l_partkey
ELSE l_suppkey
END) as avg,
l_shipmode
FROM
lineitem_hash
GROUP BY
l_shipmode) sub
ORDER BY
1 DESC, 2 DESC
LIMIT 10;
--- count DISTINCT CASE WHEN expression
SELECT *
FROM (
SELECT
count(DISTINCT
CASE
WHEN l_shipmode = 'TRUCK' THEN l_linenumber
WHEN l_shipmode = 'AIR' THEN l_linenumber + 10
ELSE 2
END) as avg
FROM
lineitem_hash
GROUP BY l_shipdate) sub
ORDER BY 1 DESC
LIMIT 10;
-- COUNT DISTINCT (c1, c2)
SELECT *
FROM
(SELECT
l_shipmode,
count(DISTINCT (l_shipdate, l_tax))
FROM
lineitem_hash
GROUP BY
l_shipmode) t
ORDER BY
2 DESC,1 DESC
LIMIT 10;
-- other distinct aggregate are not supported
SELECT *
FROM (
SELECT
l_orderkey, sum(DISTINCT l_partkey)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
SELECT *
FROM (
SELECT
l_orderkey, avg(DISTINCT l_partkey)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
-- whole row references, oid, and ctid are not supported in count distinct
-- test table does not have oid or ctid enabled, so tests for them are skipped
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT lineitem_hash)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT lineitem_hash.*)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
DROP TABLE lineitem_hash;

View File

@ -144,3 +144,8 @@ test: multi_large_shardid
# multi_drop_extension makes sure we can safely drop and recreate the extension
# ----------
test: multi_drop_extension
# ----------
# multi_complex_count_distinct creates table lineitem_hash, creates shards and loads data
# ----------
test: multi_complex_count_distinct

View File

@ -0,0 +1,438 @@
--
-- COMPLEX_COUNT_DISTINCT
--
CREATE TABLE lineitem_hash (
l_orderkey bigint not null,
l_partkey integer not null,
l_suppkey integer not null,
l_linenumber integer not null,
l_quantity decimal(15, 2) not null,
l_extendedprice decimal(15, 2) not null,
l_discount decimal(15, 2) not null,
l_tax decimal(15, 2) not null,
l_returnflag char(1) not null,
l_linestatus char(1) not null,
l_shipdate date not null,
l_commitdate date not null,
l_receiptdate date not null,
l_shipinstruct char(25) not null,
l_shipmode char(10) not null,
l_comment varchar(44) not null,
PRIMARY KEY(l_orderkey, l_linenumber) );
SELECT master_create_distributed_table('lineitem_hash', 'l_orderkey', 'hash');
master_create_distributed_table
---------------------------------
(1 row)
SELECT master_create_worker_shards('lineitem_hash', 8, 1);
master_create_worker_shards
-----------------------------
(1 row)
\COPY lineitem_hash FROM '@abs_srcdir@/data/lineitem.1.data' with delimiter '|'
\COPY lineitem_hash FROM '@abs_srcdir@/data/lineitem.2.data' with delimiter '|'
SET citus.task_executor_type to "task-tracker";
-- count(distinct) is supported on top level query if there
-- is a grouping on the partition key
SELECT
l_orderkey, count(DISTINCT l_partkey)
FROM lineitem_hash
GROUP BY l_orderkey
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
l_orderkey | count
------------+-------
14885 | 7
14884 | 7
14821 | 7
14790 | 7
14785 | 7
14755 | 7
14725 | 7
14694 | 7
14627 | 7
14624 | 7
(10 rows)
-- it is not supported if there is no grouping or grouping is on non-partition field
SELECT
count(DISTINCT l_partkey)
FROM lineitem_hash
ORDER BY 1 DESC
LIMIT 10;
ERROR: cannot compute aggregate (distinct)
DETAIL: table partitioning is unsuitable for aggregate (distinct)
HINT: You can load the hll extension from contrib packages and enable distinct approximations.
SELECT
l_shipmode, count(DISTINCT l_partkey)
FROM lineitem_hash
GROUP BY l_shipmode
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
ERROR: cannot compute aggregate (distinct)
DETAIL: table partitioning is unsuitable for aggregate (distinct)
HINT: You can load the hll extension from contrib packages and enable distinct approximations.
-- count distinct is supported on single table subqueries
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT l_partkey)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
l_orderkey | count
------------+-------
14885 | 7
14884 | 7
14821 | 7
14790 | 7
14785 | 7
14755 | 7
14725 | 7
14694 | 7
14627 | 7
14624 | 7
(10 rows)
SELECT *
FROM (
SELECT
l_partkey, count(DISTINCT l_orderkey)
FROM lineitem_hash
GROUP BY l_partkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
l_partkey | count
-----------+-------
199146 | 3
188804 | 3
177771 | 3
160895 | 3
149926 | 3
136884 | 3
87761 | 3
15283 | 3
6983 | 3
1927 | 3
(10 rows)
-- case expr in count distinct is supported.
-- count orders partkeys if l_shipmode is air
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT CASE WHEN l_shipmode = 'AIR' THEN l_partkey ELSE NULL END) as count
FROM lineitem_hash
GROUP BY l_orderkey) sub
WHERE count > 0
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
l_orderkey | count
------------+-------
12005 | 4
5409 | 4
4964 | 4
14848 | 3
14496 | 3
13473 | 3
13122 | 3
12929 | 3
12645 | 3
12417 | 3
(10 rows)
-- text like operator is also supported
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT CASE WHEN l_shipmode like '%A%' THEN l_partkey ELSE NULL END) as count
FROM lineitem_hash
GROUP BY l_orderkey) sub
WHERE count > 0
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
l_orderkey | count
------------+-------
14275 | 7
14181 | 7
13605 | 7
12707 | 7
12384 | 7
11746 | 7
10727 | 7
10467 | 7
5636 | 7
4614 | 7
(10 rows)
-- count distinct is rejected if it does not reference any columns
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT 1)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
ERROR: cannot compute aggregate (distinct)
DETAIL: aggregate (distinct) with no columns is unsupported
HINT: You can load the hll extension from contrib packages and enable distinct approximations.
-- count distinct is rejected if it does not reference any columns
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT (random() * 5)::int)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
ERROR: cannot compute aggregate (distinct)
DETAIL: aggregate (distinct) with no columns is unsupported
HINT: You can load the hll extension from contrib packages and enable distinct approximations.
-- even non-const function calls are supported within count distinct
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT (random() * 5)::int = l_linenumber)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 0;
l_orderkey | count
------------+-------
(0 rows)
-- multiple nested subquery
SELECT
total,
avg(avg_count) as total_avg_count
FROM (
SELECT
number_sum,
count(DISTINCT l_suppkey) as total,
avg(total_count) avg_count
FROM (
SELECT
l_suppkey,
sum(l_linenumber) as number_sum,
count(DISTINCT l_shipmode) as total_count
FROM
lineitem_hash
WHERE
l_partkey > 100 and
l_quantity > 2 and
l_orderkey < 10000
GROUP BY
l_suppkey) as distributed_table
WHERE
number_sum >= 10
GROUP BY
number_sum) as distributed_table_2
GROUP BY
total
ORDER BY
total_avg_count DESC;
total | total_avg_count
-------+--------------------
1 | 3.6000000000000000
6 | 2.8333333333333333
10 | 2.6000000000000000
27 | 2.5555555555555556
32 | 2.4687500000000000
77 | 2.1948051948051948
57 | 2.1754385964912281
(7 rows)
-- multiple cases query
SELECT *
FROM (
SELECT
count(DISTINCT
CASE
WHEN l_shipmode = 'TRUCK' THEN l_partkey
WHEN l_shipmode = 'AIR' THEN l_quantity
WHEN l_shipmode = 'SHIP' THEN l_discount
ELSE l_suppkey
END) as count,
l_shipdate
FROM
lineitem_hash
GROUP BY
l_shipdate) sub
WHERE
count > 0
ORDER BY
1 DESC, 2 DESC
LIMIT 10;
count | l_shipdate
-------+------------
14 | 07-30-1997
13 | 05-26-1998
13 | 08-08-1997
13 | 11-17-1995
13 | 01-09-1993
12 | 01-15-1998
12 | 10-15-1997
12 | 09-07-1997
12 | 06-02-1997
12 | 03-14-1997
(10 rows)
-- count DISTINCT expression
SELECT *
FROM (
SELECT
l_quantity, count(DISTINCT ((l_orderkey / 1000) * 1000 )) as count
FROM
lineitem_hash
GROUP BY
l_quantity) sub
WHERE
count > 0
ORDER BY
2 DESC, 1 DESC
LIMIT 10;
l_quantity | count
------------+-------
48.00 | 13
47.00 | 13
37.00 | 13
33.00 | 13
26.00 | 13
25.00 | 13
23.00 | 13
21.00 | 13
15.00 | 13
12.00 | 13
(10 rows)
-- count DISTINCT is part of an expression which inclues another aggregate
SELECT *
FROM (
SELECT
sum(((l_partkey * l_tax) / 100)) /
count(DISTINCT
CASE
WHEN l_shipmode = 'TRUCK' THEN l_partkey
ELSE l_suppkey
END) as avg,
l_shipmode
FROM
lineitem_hash
GROUP BY
l_shipmode) sub
ORDER BY
1 DESC, 2 DESC
LIMIT 10;
avg | l_shipmode
-------------------------+------------
44.82904609027336300064 | MAIL
44.80704536679536679537 | SHIP
44.68891732736572890026 | AIR
44.34106724470134874759 | REG AIR
43.12739987269255251432 | FOB
43.07299253636938646426 | RAIL
40.50298377916903813318 | TRUCK
(7 rows)
--- count DISTINCT CASE WHEN expression
SELECT *
FROM (
SELECT
count(DISTINCT
CASE
WHEN l_shipmode = 'TRUCK' THEN l_linenumber
WHEN l_shipmode = 'AIR' THEN l_linenumber + 10
ELSE 2
END) as avg
FROM
lineitem_hash
GROUP BY l_shipdate) sub
ORDER BY 1 DESC
LIMIT 10;
avg
-----
7
6
6
6
6
6
6
6
5
5
(10 rows)
-- COUNT DISTINCT (c1, c2)
SELECT *
FROM
(SELECT
l_shipmode,
count(DISTINCT (l_shipdate, l_tax))
FROM
lineitem_hash
GROUP BY
l_shipmode) t
ORDER BY
2 DESC,1 DESC
LIMIT 10;
l_shipmode | count
------------+-------
TRUCK | 1689
MAIL | 1683
FOB | 1655
AIR | 1650
SHIP | 1644
RAIL | 1636
REG AIR | 1607
(7 rows)
-- other distinct aggregate are not supported
SELECT *
FROM (
SELECT
l_orderkey, sum(DISTINCT l_partkey)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
ERROR: cannot compute aggregate (distinct)
DETAIL: Only count(distinct) aggregate is supported in subqueries
SELECT *
FROM (
SELECT
l_orderkey, avg(DISTINCT l_partkey)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
ERROR: cannot compute aggregate (distinct)
DETAIL: Only count(distinct) aggregate is supported in subqueries
-- whole row references, oid, and ctid are not supported in count distinct
-- test table does not have oid or ctid enabled, so tests for them are skipped
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT lineitem_hash)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
ERROR: cannot compute count (distinct)
DETAIL: Non-column references are not supported yet
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT lineitem_hash.*)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
ERROR: cannot compute count (distinct)
DETAIL: Non-column references are not supported yet
DROP TABLE lineitem_hash;

View File

@ -14,3 +14,4 @@
/multi_stage_more_data.sql
/multi_subquery.sql
/worker_copy.sql
/multi_complex_count_distinct.sql

View File

@ -125,10 +125,10 @@ from
group by
l_tax) as distributed_table;
-- Check that we don't support subqueries with count(distinct).
-- Check that we support subqueries with count(distinct).
select
different_shipment_days
avg(different_shipment_days)
from
(select
count(distinct l_shipdate) as different_shipment_days