Merge pull request #516 from citusdata/feature/fix_434_support_count_distinct

Add complex count distinct support
pull/523/merge
Murat Tuncer 2016-05-27 15:55:47 +03:00
commit 24e1224eac
11 changed files with 1041 additions and 74 deletions

View File

@ -54,6 +54,20 @@ int LimitClauseRowFetchCount = -1; /* number of rows to fetch from each task */
double CountDistinctErrorRate = 0.0; /* precision of count(distinct) approximate */
/*
 * MasterAggregateWalkerContext carries the state that is threaded through
 * MasterAggregateMutator() and MasterAggregateExpression() while the master
 * node's target list is being built.
 */
typedef struct MasterAggregateWalkerContext
{
/*
 * Set by MasterExtendedOpNode() when the extended operator's parent is a
 * MultiTable node and its child is a MultiCollect node. When set, and when
 * count(distinct) approximation is disabled, MasterAggregateExpression()
 * keeps the count(distinct) aggregate and re-points the columns inside it.
 */
bool repartitionSubquery;
/*
 * Next column number to hand out. Starts at 1 and is advanced each time a
 * master-side column reference is created.
 */
AttrNumber columnId;
} MasterAggregateWalkerContext;
/*
 * WorkerAggregateWalkerContext carries the state that is threaded through
 * WorkerAggregateWalker() and WorkerAggregateExpressionList() while the worker
 * node's target list is being built.
 */
typedef struct WorkerAggregateWalkerContext
{
/*
 * Set by WorkerExtendedOpNode() when the extended operator's parent is a
 * MultiTable node and its child is a MultiCollect node.
 */
bool repartitionSubquery;
/*
 * Worker expressions collected for the target entry currently being
 * processed; WorkerExtendedOpNode() resets it to NIL before each entry.
 */
List *expressionList;
/*
 * Reset to false before each target entry. WorkerAggregateExpressionList()
 * sets it to true when it replaces a count(distinct) with its columns, and
 * WorkerExtendedOpNode() then adds a group by clause for each resulting
 * Var target.
 */
bool createGroupByClause;
} WorkerAggregateWalkerContext;
/* Local functions forward declarations */
static MultiSelect * AndSelectNode(MultiSelect *selectNode);
static MultiSelect * OrSelectNode(MultiSelect *selectNode);
@ -96,14 +110,18 @@ static void ApplyExtendedOpNodes(MultiExtendedOp *originalNode,
MultiExtendedOp *workerNode);
static void TransformSubqueryNode(MultiTable *subqueryNode);
static MultiExtendedOp * MasterExtendedOpNode(MultiExtendedOp *originalOpNode);
static Node * MasterAggregateMutator(Node *originalNode, AttrNumber *columnId);
static Expr * MasterAggregateExpression(Aggref *originalAggregate, AttrNumber *columnId);
static Node * MasterAggregateMutator(Node *originalNode,
MasterAggregateWalkerContext *walkerContext);
static Expr * MasterAggregateExpression(Aggref *originalAggregate,
MasterAggregateWalkerContext *walkerContext);
static Expr * MasterAverageExpression(Oid sumAggregateType, Oid countAggregateType,
AttrNumber *columnId);
static Expr * AddTypeConversion(Node *originalAggregate, Node *newExpression);
static MultiExtendedOp * WorkerExtendedOpNode(MultiExtendedOp *originalOpNode);
static bool WorkerAggregateWalker(Node *node, List **newExpressionList);
static List * WorkerAggregateExpressionList(Aggref *originalAggregate);
static bool WorkerAggregateWalker(Node *node,
WorkerAggregateWalkerContext *walkerContext);
static List * WorkerAggregateExpressionList(Aggref *originalAggregate,
WorkerAggregateWalkerContext *walkerContext);
static AggregateType GetAggregateType(Oid aggFunctionId);
static Oid AggregateArgumentType(Aggref *aggregate);
static Oid AggregateFunctionOid(const char *functionName, Oid inputType);
@ -1145,7 +1163,6 @@ TransformSubqueryNode(MultiTable *subqueryNode)
MultiExtendedOp *masterExtendedOpNode = MasterExtendedOpNode(extendedOpNode);
MultiExtendedOp *workerExtendedOpNode = WorkerExtendedOpNode(extendedOpNode);
MultiPartition *partitionNode = CitusMakeNode(MultiPartition);
List *groupClauseList = extendedOpNode->groupClauseList;
List *targetEntryList = extendedOpNode->targetList;
List *groupTargetEntryList = GroupTargetEntryList(groupClauseList, targetEntryList);
@ -1212,7 +1229,18 @@ MasterExtendedOpNode(MultiExtendedOp *originalOpNode)
List *targetEntryList = originalOpNode->targetList;
List *newTargetEntryList = NIL;
ListCell *targetEntryCell = NULL;
AttrNumber columnId = 1;
MultiNode *parentNode = ParentNode((MultiNode *) originalOpNode);
MultiNode *childNode = ChildNode((MultiUnaryNode *) originalOpNode);
MasterAggregateWalkerContext *walkerContext = palloc0(
sizeof(MasterAggregateWalkerContext));
walkerContext->columnId = 1;
walkerContext->repartitionSubquery = false;
if (CitusIsA(parentNode, MultiTable) && CitusIsA(childNode, MultiCollect))
{
walkerContext->repartitionSubquery = true;
}
/* iterate over original target entries */
foreach(targetEntryCell, targetEntryList)
@ -1226,7 +1254,7 @@ MasterExtendedOpNode(MultiExtendedOp *originalOpNode)
if (hasAggregates)
{
Node *newNode = MasterAggregateMutator((Node *) originalExpression,
&columnId);
walkerContext);
newExpression = (Expr *) newNode;
}
else
@ -1238,9 +1266,9 @@ MasterExtendedOpNode(MultiExtendedOp *originalOpNode)
const uint32 masterTableId = 1; /* only one table on master node */
Var *column = makeVarFromTargetEntry(masterTableId, originalTargetEntry);
column->varattno = columnId;
column->varoattno = columnId;
columnId++;
column->varattno = walkerContext->columnId;
column->varoattno = walkerContext->columnId;
walkerContext->columnId++;
newExpression = (Expr *) column;
}
@ -1271,7 +1299,7 @@ MasterExtendedOpNode(MultiExtendedOp *originalOpNode)
* depth first order.
*/
static Node *
MasterAggregateMutator(Node *originalNode, AttrNumber *columnId)
MasterAggregateMutator(Node *originalNode, MasterAggregateWalkerContext *walkerContext)
{
Node *newNode = NULL;
if (originalNode == NULL)
@ -1282,7 +1310,7 @@ MasterAggregateMutator(Node *originalNode, AttrNumber *columnId)
if (IsA(originalNode, Aggref))
{
Aggref *originalAggregate = (Aggref *) originalNode;
Expr *newExpression = MasterAggregateExpression(originalAggregate, columnId);
Expr *newExpression = MasterAggregateExpression(originalAggregate, walkerContext);
newNode = (Node *) newExpression;
}
@ -1291,15 +1319,15 @@ MasterAggregateMutator(Node *originalNode, AttrNumber *columnId)
uint32 masterTableId = 1; /* one table on the master node */
Var *newColumn = copyObject(originalNode);
newColumn->varno = masterTableId;
newColumn->varattno = (*columnId);
(*columnId)++;
newColumn->varattno = walkerContext->columnId;
walkerContext->columnId++;
newNode = (Node *) newColumn;
}
else
{
newNode = expression_tree_mutator(originalNode, MasterAggregateMutator,
(void *) columnId);
(void *) walkerContext);
}
return newNode;
@ -1317,7 +1345,8 @@ MasterAggregateMutator(Node *originalNode, AttrNumber *columnId)
* knowledge to create the appropriate master function with correct data types.
*/
static Expr *
MasterAggregateExpression(Aggref *originalAggregate, AttrNumber *columnId)
MasterAggregateExpression(Aggref *originalAggregate,
MasterAggregateWalkerContext *walkerContext)
{
AggregateType aggregateType = GetAggregateType(originalAggregate->aggfnoid);
Expr *newMasterExpression = NULL;
@ -1327,7 +1356,55 @@ MasterAggregateExpression(Aggref *originalAggregate, AttrNumber *columnId)
const AttrNumber argumentId = 1; /* our aggregates have single arguments */
if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct &&
CountDistinctErrorRate != DISABLE_DISTINCT_APPROXIMATION)
CountDistinctErrorRate == DISABLE_DISTINCT_APPROXIMATION &&
walkerContext->repartitionSubquery)
{
Aggref *aggregate = (Aggref *) copyObject(originalAggregate);
List *aggTargetEntryList = aggregate->args;
TargetEntry *distinctTargetEntry = linitial(aggTargetEntryList);
List *varList = pull_var_clause_default((Node *) distinctTargetEntry->expr);
ListCell *varCell = NULL;
List *uniqueVarList = NIL;
int startColumnCount = walkerContext->columnId;
/* determine unique vars that were placed in target list by worker */
foreach(varCell, varList)
{
Var *column = (Var *) lfirst(varCell);
uniqueVarList = list_append_unique(uniqueVarList, copyObject(column));
}
/*
* Go over each var inside aggregate and update their varattno's according to
* worker query target entry column index.
*/
foreach(varCell, varList)
{
Var *columnToUpdate = (Var *) lfirst(varCell);
ListCell *uniqueVarCell = NULL;
int columnIndex = 0;
foreach(uniqueVarCell, uniqueVarList)
{
Var *currentVar = (Var *) lfirst(uniqueVarCell);
if (equal(columnToUpdate, currentVar))
{
break;
}
columnIndex++;
}
columnToUpdate->varattno = startColumnCount + columnIndex;
columnToUpdate->varoattno = startColumnCount + columnIndex;
}
/* we added that many columns */
walkerContext->columnId += list_length(uniqueVarList);
newMasterExpression = (Expr *) aggregate;
}
else if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct &&
CountDistinctErrorRate != DISABLE_DISTINCT_APPROXIMATION)
{
/*
* If enabled, we check for count(distinct) approximations before count
@ -1348,9 +1425,10 @@ MasterAggregateExpression(Aggref *originalAggregate, AttrNumber *columnId)
Oid hllType = TypenameGetTypid(HLL_TYPE_NAME);
Oid hllTypeCollationId = get_typcollation(hllType);
Var *hllColumn = makeVar(masterTableId, (*columnId), hllType, defaultTypeMod,
Var *hllColumn = makeVar(masterTableId, walkerContext->columnId, hllType,
defaultTypeMod,
hllTypeCollationId, columnLevelsUp);
(*columnId)++;
walkerContext->columnId++;
hllTargetEntry = makeTargetEntry((Expr *) hllColumn, argumentId, NULL, false);
@ -1389,7 +1467,7 @@ MasterAggregateExpression(Aggref *originalAggregate, AttrNumber *columnId)
/* create the expression sum(sum(column) / sum(count(column))) */
newMasterExpression = MasterAverageExpression(workerSumReturnType,
workerCountReturnType,
columnId);
&(walkerContext->columnId));
}
else if (aggregateType == AGGREGATE_COUNT)
{
@ -1415,9 +1493,9 @@ MasterAggregateExpression(Aggref *originalAggregate, AttrNumber *columnId)
newMasterAggregate->aggfnoid = sumFunctionId;
newMasterAggregate->aggtype = masterReturnType;
column = makeVar(masterTableId, (*columnId), workerReturnType,
column = makeVar(masterTableId, walkerContext->columnId, workerReturnType,
workerReturnTypeMod, workerCollationId, columnLevelsUp);
(*columnId)++;
walkerContext->columnId++;
/* aggref expects its arguments to be wrapped in target entries */
columnTargetEntry = makeTargetEntry((Expr *) column, argumentId, NULL, false);
@ -1451,10 +1529,10 @@ MasterAggregateExpression(Aggref *originalAggregate, AttrNumber *columnId)
ANYARRAYOID);
/* create argument for the array_cat_agg() aggregate */
column = makeVar(masterTableId, (*columnId), workerReturnType,
column = makeVar(masterTableId, walkerContext->columnId, workerReturnType,
workerReturnTypeMod, workerCollationId, columnLevelsUp);
arrayCatAggArgument = makeTargetEntry((Expr *) column, argumentId, NULL, false);
(*columnId)++;
walkerContext->columnId++;
/* construct the master array_cat_agg() expression */
newMasterAggregate = copyObject(originalAggregate);
@ -1486,9 +1564,9 @@ MasterAggregateExpression(Aggref *originalAggregate, AttrNumber *columnId)
newMasterAggregate->aggfnoid = aggregateFunctionId;
newMasterAggregate->aggtype = masterReturnType;
column = makeVar(masterTableId, (*columnId), workerReturnType,
column = makeVar(masterTableId, walkerContext->columnId, workerReturnType,
workerReturnTypeMod, workerCollationId, columnLevelsUp);
(*columnId)++;
walkerContext->columnId++;
/* aggref expects its arguments to be wrapped in target entries */
columnTargetEntry = makeTargetEntry((Expr *) column, argumentId, NULL, false);
@ -1611,16 +1689,45 @@ AddTypeConversion(Node *originalAggregate, Node *newExpression)
* with aggregates in them, this function calls the recursive aggregate walker
* function to create aggregates for the worker nodes. Also, the function checks
* if we can push down the limit to worker nodes; and if we can, sets the limit
* count and sort clause list fields in the new operator node.
* count and sort clause list fields in the new operator node. It provides special
* treatment for count distinct operator if it is used in repartition subqueries.
* Each column in count distinct aggregate is added to target list, and group by
* list of worker extended operator.
*/
static MultiExtendedOp *
WorkerExtendedOpNode(MultiExtendedOp *originalOpNode)
{
MultiExtendedOp *workerExtendedOpNode = NULL;
MultiNode *parentNode = ParentNode((MultiNode *) originalOpNode);
MultiNode *childNode = ChildNode((MultiUnaryNode *) originalOpNode);
List *targetEntryList = originalOpNode->targetList;
List *newTargetEntryList = NIL;
ListCell *targetEntryCell = NULL;
List *newTargetEntryList = NIL;
List *groupClauseList = copyObject(originalOpNode->groupClauseList);
AttrNumber targetProjectionNumber = 1;
WorkerAggregateWalkerContext *walkerContext =
palloc0(sizeof(WorkerAggregateWalkerContext));
walkerContext->repartitionSubquery = false;
walkerContext->expressionList = NIL;
Index nextSortGroupRefIndex = 0;
if (CitusIsA(parentNode, MultiTable) && CitusIsA(childNode, MultiCollect))
{
walkerContext->repartitionSubquery = true;
/* find max of sort group ref index */
foreach(targetEntryCell, targetEntryList)
{
TargetEntry *targetEntry = (TargetEntry *) lfirst(targetEntryCell);
if (targetEntry->ressortgroupref > nextSortGroupRefIndex)
{
nextSortGroupRefIndex = targetEntry->ressortgroupref;
}
}
/* next group ref index starts from max group ref index + 1 */
nextSortGroupRefIndex++;
}
/* iterate over original target entries */
foreach(targetEntryCell, targetEntryList)
@ -1629,11 +1736,15 @@ WorkerExtendedOpNode(MultiExtendedOp *originalOpNode)
Expr *originalExpression = originalTargetEntry->expr;
List *newExpressionList = NIL;
ListCell *newExpressionCell = NULL;
bool hasAggregates = contain_agg_clause((Node *) originalExpression);
walkerContext->expressionList = NIL;
walkerContext->createGroupByClause = false;
if (hasAggregates)
{
WorkerAggregateWalker((Node *) originalExpression, &newExpressionList);
WorkerAggregateWalker((Node *) originalExpression, walkerContext);
newExpressionList = walkerContext->expressionList;
}
else
{
@ -1647,6 +1758,37 @@ WorkerExtendedOpNode(MultiExtendedOp *originalOpNode)
TargetEntry *newTargetEntry = copyObject(originalTargetEntry);
newTargetEntry->expr = newExpression;
/*
* Detect new targets of type Var and add it to group clause list.
* This case is expected only if the target entry has aggregates and
* it is inside a repartitioned subquery. We create group by entry
* for each Var in target list. This code does not check if this
* Var was already in the target list or in group by clauses.
*/
if (IsA(newExpression, Var) && walkerContext->createGroupByClause)
{
Var *column = (Var *) newExpression;
Oid lessThanOperator = InvalidOid;
Oid equalsOperator = InvalidOid;
bool hashable = false;
SortGroupClause *groupByClause = makeNode(SortGroupClause);
get_sort_group_operators(column->vartype, true, true, true,
&lessThanOperator, &equalsOperator, NULL,
&hashable);
groupByClause->eqop = equalsOperator;
groupByClause->hashable = hashable;
groupByClause->nulls_first = false;
groupByClause->sortop = lessThanOperator;
groupByClause->tleSortGroupRef = nextSortGroupRefIndex;
groupClauseList = lappend(groupClauseList, groupByClause);
newTargetEntry->ressortgroupref = nextSortGroupRefIndex;
nextSortGroupRefIndex++;
}
if (newTargetEntry->resname == NULL)
{
StringInfo columnNameString = makeStringInfo();
@ -1660,14 +1802,13 @@ WorkerExtendedOpNode(MultiExtendedOp *originalOpNode)
newTargetEntry->resjunk = false;
newTargetEntry->resno = targetProjectionNumber;
targetProjectionNumber++;
newTargetEntryList = lappend(newTargetEntryList, newTargetEntry);
}
}
workerExtendedOpNode = CitusMakeNode(MultiExtendedOp);
workerExtendedOpNode->targetList = newTargetEntryList;
workerExtendedOpNode->groupClauseList = originalOpNode->groupClauseList;
workerExtendedOpNode->groupClauseList = groupClauseList;
/* if we can push down the limit, also set related fields */
workerExtendedOpNode->limitCount = WorkerLimitCount(originalOpNode);
@ -1685,7 +1826,7 @@ WorkerExtendedOpNode(MultiExtendedOp *originalOpNode)
* types.
*/
static bool
WorkerAggregateWalker(Node *node, List **newExpressionList)
WorkerAggregateWalker(Node *node, WorkerAggregateWalkerContext *walkerContext)
{
bool walkerResult = false;
if (node == NULL)
@ -1696,19 +1837,22 @@ WorkerAggregateWalker(Node *node, List **newExpressionList)
if (IsA(node, Aggref))
{
Aggref *originalAggregate = (Aggref *) node;
List *workerAggregateList = WorkerAggregateExpressionList(originalAggregate);
List *workerAggregateList = WorkerAggregateExpressionList(originalAggregate,
walkerContext);
(*newExpressionList) = list_concat(*newExpressionList, workerAggregateList);
walkerContext->expressionList = list_concat(walkerContext->expressionList,
workerAggregateList);
}
else if (IsA(node, Var))
{
Var *originalColumn = (Var *) node;
(*newExpressionList) = lappend(*newExpressionList, originalColumn);
walkerContext->expressionList = lappend(walkerContext->expressionList,
originalColumn);
}
else
{
walkerResult = expression_tree_walker(node, WorkerAggregateWalker,
(void *) newExpressionList);
(void *) walkerContext);
}
return walkerResult;
@ -1718,16 +1862,44 @@ WorkerAggregateWalker(Node *node, List **newExpressionList)
/*
* WorkerAggregateExpressionList takes in the original aggregate function, and
* determines the transformed aggregate functions to execute on worker nodes.
* The function then returns these aggregates in a list.
* The function then returns these aggregates in a list. It also creates
* group by clauses for newly added targets to be placed in the extended operator
* node.
*/
static List *
WorkerAggregateExpressionList(Aggref *originalAggregate)
WorkerAggregateExpressionList(Aggref *originalAggregate,
WorkerAggregateWalkerContext *walkerContext)
{
AggregateType aggregateType = GetAggregateType(originalAggregate->aggfnoid);
List *workerAggregateList = NIL;
if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct &&
CountDistinctErrorRate != DISABLE_DISTINCT_APPROXIMATION)
CountDistinctErrorRate == DISABLE_DISTINCT_APPROXIMATION &&
walkerContext->repartitionSubquery)
{
Aggref *aggregate = (Aggref *) copyObject(originalAggregate);
List *aggTargetEntryList = aggregate->args;
TargetEntry *distinctTargetEntry = (TargetEntry *) linitial(aggTargetEntryList);
List *columnList = pull_var_clause_default((Node *) distinctTargetEntry);
ListCell *columnCell = NULL;
List *processedColumnList = NIL;
foreach(columnCell, columnList)
{
Var *column = (Var *) lfirst(columnCell);
if (list_member(processedColumnList, column))
{
continue;
}
processedColumnList = lappend(processedColumnList, column);
workerAggregateList = lappend(workerAggregateList, copyObject(column));
}
walkerContext->createGroupByClause = true;
}
else if (aggregateType == AGGREGATE_COUNT && originalAggregate->aggdistinct &&
CountDistinctErrorRate != DISABLE_DISTINCT_APPROXIMATION)
{
/*
* If the original aggregate is a count(distinct) approximation, we want
@ -2148,9 +2320,11 @@ ErrorIfUnsupportedAggregateDistinct(Aggref *aggregateExpression,
bool distinctSupported = true;
List *repartitionNodeList = NIL;
Var *distinctColumn = NULL;
List *multiTableNodeList = NIL;
ListCell *multiTableNodeCell = NULL;
AggregateType aggregateType = AGGREGATE_INVALID_FIRST;
List *tableNodeList = NIL;
List *extendedOpNodeList = NIL;
MultiExtendedOp *extendedOpNode = NULL;
AggregateType aggregateType = GetAggregateType(aggregateExpression->aggfnoid);
/* check if logical plan includes a subquery */
List *subqueryMultiTableList = SubqueryMultiTableList(logicalPlanNode);
@ -2161,20 +2335,43 @@ ErrorIfUnsupportedAggregateDistinct(Aggref *aggregateExpression,
errdetail("distinct in the outermost query is unsupported")));
}
multiTableNodeList = FindNodesOfType(logicalPlanNode, T_MultiTable);
foreach(multiTableNodeCell, multiTableNodeList)
/*
* We partially support count(distinct) in subqueries, other distinct aggregates in
* subqueries are not supported yet.
*/
if (aggregateType == AGGREGATE_COUNT)
{
MultiTable *multiTable = (MultiTable *) lfirst(multiTableNodeCell);
if (multiTable->relationId == SUBQUERY_RELATION_ID)
Node *aggregateArgument = (Node *) linitial(aggregateExpression->args);
List *columnList = pull_var_clause_default(aggregateArgument);
ListCell *columnCell = NULL;
foreach(columnCell, columnList)
{
ereport(ERROR, (errmsg("cannot compute count (distinct)"),
errdetail("Subqueries with aggregate (distinct) are "
"not supported yet")));
Var *column = (Var *) lfirst(columnCell);
if (column->varattno <= 0)
{
ereport(ERROR, (errmsg("cannot compute count (distinct)"),
errdetail("Non-column references are not supported "
"yet")));
}
}
}
else
{
List *multiTableNodeList = FindNodesOfType(logicalPlanNode, T_MultiTable);
ListCell *multiTableNodeCell = NULL;
foreach(multiTableNodeCell, multiTableNodeList)
{
MultiTable *multiTable = (MultiTable *) lfirst(multiTableNodeCell);
if (multiTable->relationId == SUBQUERY_RELATION_ID)
{
ereport(ERROR, (errmsg("cannot compute aggregate (distinct)"),
errdetail("Only count(distinct) aggregate is "
"supported in subqueries")));
}
}
}
/* if we have a count(distinct), and distinct approximation is enabled */
aggregateType = GetAggregateType(aggregateExpression->aggfnoid);
if (aggregateType == AGGREGATE_COUNT &&
CountDistinctErrorRate != DISABLE_DISTINCT_APPROXIMATION)
{
@ -2193,6 +2390,16 @@ ErrorIfUnsupportedAggregateDistinct(Aggref *aggregateExpression,
}
}
if (aggregateType == AGGREGATE_COUNT)
{
List *aggregateVarList = pull_var_clause_default((Node *) aggregateExpression);
if (aggregateVarList == NIL)
{
distinctSupported = false;
errorDetail = "aggregate (distinct) with no columns is unsupported";
}
}
repartitionNodeList = FindNodesOfType(logicalPlanNode, T_MultiPartition);
if (repartitionNodeList != NIL)
{
@ -2200,19 +2407,27 @@ ErrorIfUnsupportedAggregateDistinct(Aggref *aggregateExpression,
errorDetail = "aggregate (distinct) with table repartitioning is unsupported";
}
tableNodeList = FindNodesOfType(logicalPlanNode, T_MultiTable);
extendedOpNodeList = FindNodesOfType(logicalPlanNode, T_MultiExtendedOp);
extendedOpNode = (MultiExtendedOp *) linitial(extendedOpNodeList);
distinctColumn = AggregateDistinctColumn(aggregateExpression);
if (distinctColumn == NULL)
if (distinctSupported && distinctColumn == NULL)
{
distinctSupported = false;
errorDetail = "aggregate (distinct) on complex expressions is unsupported";
/*
* If the query has a single table that is grouped by its partition column, then
* we support count(distinct) even if the distinct column cannot be identified.
*/
distinctSupported = TablePartitioningSupportsDistinct(tableNodeList,
extendedOpNode,
distinctColumn);
if (!distinctSupported)
{
errorDetail = "aggregate (distinct) on complex expressions is unsupported";
}
}
else
else if (distinctSupported)
{
List *tableNodeList = FindNodesOfType(logicalPlanNode, T_MultiTable);
List *opNodeList = FindNodesOfType(logicalPlanNode, T_MultiExtendedOp);
MultiExtendedOp *extendedOpNode = (MultiExtendedOp *) linitial(opNodeList);
bool supports = TablePartitioningSupportsDistinct(tableNodeList, extendedOpNode,
distinctColumn);
if (!supports)
@ -2299,6 +2514,11 @@ TablePartitioningSupportsDistinct(List *tableNodeList, MultiExtendedOp *opNode,
bool tableDistinctSupported = false;
char partitionMethod = 0;
if (relationId == SUBQUERY_RELATION_ID)
{
return true;
}
/* if table has one shard, task results don't overlap */
List *shardList = LoadShardList(relationId);
if (list_length(shardList) == 1)
@ -2319,7 +2539,8 @@ TablePartitioningSupportsDistinct(List *tableNodeList, MultiExtendedOp *opNode,
bool groupedByPartitionColumn = false;
/* if distinct is on table partition column, we can push it down */
if (tablePartitionColumn->varno == distinctColumn->varno &&
if (distinctColumn != NULL &&
tablePartitionColumn->varno == distinctColumn->varno &&
tablePartitionColumn->varattno == distinctColumn->varattno)
{
tableDistinctSupported = true;

View File

@ -4343,13 +4343,16 @@ MergeTaskList(MapMergeJob *mapMergeJob, List *mapTaskList, uint32 taskIdIndex)
{
StringInfo mergeTableQueryString =
MergeTableQueryString(taskIdIndex, targetEntryList);
char *escapedMergeTableQueryString =
quote_literal_cstr(mergeTableQueryString->data);
StringInfo intermediateTableQueryString =
IntermediateTableQueryString(jobId, taskIdIndex, reduceQuery);
char *escapedIntermediateTableQueryString =
quote_literal_cstr(intermediateTableQueryString->data);
StringInfo mergeAndRunQueryString = makeStringInfo();
appendStringInfo(mergeAndRunQueryString, MERGE_FILES_AND_RUN_QUERY_COMMAND,
jobId, taskIdIndex, mergeTableQueryString->data,
intermediateTableQueryString->data);
jobId, taskIdIndex, escapedMergeTableQueryString,
escapedIntermediateTableQueryString);
mergeTask = CreateBasicTask(jobId, mergeTaskId, MERGE_TASK,
mergeAndRunQueryString->data);

View File

@ -41,7 +41,7 @@
#define MERGE_FILES_INTO_TABLE_COMMAND "SELECT worker_merge_files_into_table \
(" UINT64_FORMAT ", %d, '%s', '%s')"
#define MERGE_FILES_AND_RUN_QUERY_COMMAND \
"SELECT worker_merge_files_and_run_query(" UINT64_FORMAT ", %d, '%s', '%s')"
"SELECT worker_merge_files_and_run_query(" UINT64_FORMAT ", %d, %s, %s)"
typedef enum CitusRTEKind

View File

@ -16,3 +16,4 @@
/multi_subquery.out
/multi_subquery_0.out
/worker_copy.out
/multi_complex_count_distinct.out

View File

@ -920,9 +920,20 @@ SELECT
articles_hash
GROUP BY
author_id;
ERROR: cannot compute aggregate (distinct)
DETAIL: aggregate (distinct) on complex expressions is unsupported
HINT: You can load the hll extension from contrib packages and enable distinct approximations.
c
---
4
5
5
5
5
5
5
5
5
5
(10 rows)
-- queries inside transactions can be router plannable
BEGIN;
SELECT *

View File

@ -171,9 +171,9 @@ from
l_tax) as distributed_table;
ERROR: cannot perform distributed planning on this query
DETAIL: Subqueries without aggregates are not supported yet
-- Check that we don't support subqueries with count(distinct).
-- Check that we support subqueries with count(distinct).
select
different_shipment_days
avg(different_shipment_days)
from
(select
count(distinct l_shipdate) as different_shipment_days
@ -181,8 +181,11 @@ from
lineitem
group by
l_partkey) as distributed_table;
ERROR: cannot compute count (distinct)
DETAIL: Subqueries with aggregate (distinct) are not supported yet
avg
------------------------
1.02907126318497555956
(1 row)
-- Check that if subquery is pulled, we don't error and run query properly.
SELECT max(l_suppkey) FROM
(

View File

@ -0,0 +1,284 @@
--
-- COMPLEX_COUNT_DISTINCT
--
CREATE TABLE lineitem_hash (
l_orderkey bigint not null,
l_partkey integer not null,
l_suppkey integer not null,
l_linenumber integer not null,
l_quantity decimal(15, 2) not null,
l_extendedprice decimal(15, 2) not null,
l_discount decimal(15, 2) not null,
l_tax decimal(15, 2) not null,
l_returnflag char(1) not null,
l_linestatus char(1) not null,
l_shipdate date not null,
l_commitdate date not null,
l_receiptdate date not null,
l_shipinstruct char(25) not null,
l_shipmode char(10) not null,
l_comment varchar(44) not null,
PRIMARY KEY(l_orderkey, l_linenumber) );
SELECT master_create_distributed_table('lineitem_hash', 'l_orderkey', 'hash');
SELECT master_create_worker_shards('lineitem_hash', 8, 1);
\COPY lineitem_hash FROM '@abs_srcdir@/data/lineitem.1.data' with delimiter '|'
\COPY lineitem_hash FROM '@abs_srcdir@/data/lineitem.2.data' with delimiter '|'
SET citus.task_executor_type to "task-tracker";
-- count(distinct) is supported on top level query if there
-- is a grouping on the partition key
SELECT
l_orderkey, count(DISTINCT l_partkey)
FROM lineitem_hash
GROUP BY l_orderkey
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
-- it is not supported if there is no grouping or grouping is on non-partition field
SELECT
count(DISTINCT l_partkey)
FROM lineitem_hash
ORDER BY 1 DESC
LIMIT 10;
SELECT
l_shipmode, count(DISTINCT l_partkey)
FROM lineitem_hash
GROUP BY l_shipmode
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
-- count distinct is supported on single table subqueries
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT l_partkey)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
SELECT *
FROM (
SELECT
l_partkey, count(DISTINCT l_orderkey)
FROM lineitem_hash
GROUP BY l_partkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
-- case expr in count distinct is supported.
-- count orders partkeys if l_shipmode is air
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT CASE WHEN l_shipmode = 'AIR' THEN l_partkey ELSE NULL END) as count
FROM lineitem_hash
GROUP BY l_orderkey) sub
WHERE count > 0
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
-- text like operator is also supported
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT CASE WHEN l_shipmode like '%A%' THEN l_partkey ELSE NULL END) as count
FROM lineitem_hash
GROUP BY l_orderkey) sub
WHERE count > 0
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
-- count distinct is rejected if it does not reference any columns
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT 1)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
-- count distinct is rejected if it does not reference any columns
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT (random() * 5)::int)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
-- even non-const function calls are supported within count distinct
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT (random() * 5)::int = l_linenumber)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 0;
-- multiple nested subquery
SELECT
total,
avg(avg_count) as total_avg_count
FROM (
SELECT
number_sum,
count(DISTINCT l_suppkey) as total,
avg(total_count) avg_count
FROM (
SELECT
l_suppkey,
sum(l_linenumber) as number_sum,
count(DISTINCT l_shipmode) as total_count
FROM
lineitem_hash
WHERE
l_partkey > 100 and
l_quantity > 2 and
l_orderkey < 10000
GROUP BY
l_suppkey) as distributed_table
WHERE
number_sum >= 10
GROUP BY
number_sum) as distributed_table_2
GROUP BY
total
ORDER BY
total_avg_count DESC;
-- multiple cases query
SELECT *
FROM (
SELECT
count(DISTINCT
CASE
WHEN l_shipmode = 'TRUCK' THEN l_partkey
WHEN l_shipmode = 'AIR' THEN l_quantity
WHEN l_shipmode = 'SHIP' THEN l_discount
ELSE l_suppkey
END) as count,
l_shipdate
FROM
lineitem_hash
GROUP BY
l_shipdate) sub
WHERE
count > 0
ORDER BY
1 DESC, 2 DESC
LIMIT 10;
-- count DISTINCT expression
SELECT *
FROM (
SELECT
l_quantity, count(DISTINCT ((l_orderkey / 1000) * 1000 )) as count
FROM
lineitem_hash
GROUP BY
l_quantity) sub
WHERE
count > 0
ORDER BY
2 DESC, 1 DESC
LIMIT 10;
-- count DISTINCT is part of an expression which inclues another aggregate
SELECT *
FROM (
SELECT
sum(((l_partkey * l_tax) / 100)) /
count(DISTINCT
CASE
WHEN l_shipmode = 'TRUCK' THEN l_partkey
ELSE l_suppkey
END) as avg,
l_shipmode
FROM
lineitem_hash
GROUP BY
l_shipmode) sub
ORDER BY
1 DESC, 2 DESC
LIMIT 10;
--- count DISTINCT CASE WHEN expression
SELECT *
FROM (
SELECT
count(DISTINCT
CASE
WHEN l_shipmode = 'TRUCK' THEN l_linenumber
WHEN l_shipmode = 'AIR' THEN l_linenumber + 10
ELSE 2
END) as avg
FROM
lineitem_hash
GROUP BY l_shipdate) sub
ORDER BY 1 DESC
LIMIT 10;
-- COUNT DISTINCT (c1, c2)
SELECT *
FROM
(SELECT
l_shipmode,
count(DISTINCT (l_shipdate, l_tax))
FROM
lineitem_hash
GROUP BY
l_shipmode) t
ORDER BY
2 DESC,1 DESC
LIMIT 10;
-- other distinct aggregate are not supported
SELECT *
FROM (
SELECT
l_orderkey, sum(DISTINCT l_partkey)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
SELECT *
FROM (
SELECT
l_orderkey, avg(DISTINCT l_partkey)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
-- whole row references, oid, and ctid are not supported in count distinct
-- test table does not have oid or ctid enabled, so tests for them are skipped
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT lineitem_hash)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT lineitem_hash.*)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
DROP TABLE lineitem_hash;

View File

@ -144,3 +144,8 @@ test: multi_large_shardid
# multi_drop_extension makes sure we can safely drop and recreate the extension
# ----------
test: multi_drop_extension
# ----------
# multi_complex_count_distinct creates table lineitem_hash, creates shards, and loads data
# ----------
test: multi_complex_count_distinct

View File

@ -0,0 +1,438 @@
--
-- COMPLEX_COUNT_DISTINCT
--
CREATE TABLE lineitem_hash (
l_orderkey bigint not null,
l_partkey integer not null,
l_suppkey integer not null,
l_linenumber integer not null,
l_quantity decimal(15, 2) not null,
l_extendedprice decimal(15, 2) not null,
l_discount decimal(15, 2) not null,
l_tax decimal(15, 2) not null,
l_returnflag char(1) not null,
l_linestatus char(1) not null,
l_shipdate date not null,
l_commitdate date not null,
l_receiptdate date not null,
l_shipinstruct char(25) not null,
l_shipmode char(10) not null,
l_comment varchar(44) not null,
PRIMARY KEY(l_orderkey, l_linenumber) );
SELECT master_create_distributed_table('lineitem_hash', 'l_orderkey', 'hash');
master_create_distributed_table
---------------------------------
(1 row)
SELECT master_create_worker_shards('lineitem_hash', 8, 1);
master_create_worker_shards
-----------------------------
(1 row)
\COPY lineitem_hash FROM '@abs_srcdir@/data/lineitem.1.data' with delimiter '|'
\COPY lineitem_hash FROM '@abs_srcdir@/data/lineitem.2.data' with delimiter '|'
SET citus.task_executor_type to "task-tracker";
-- count(distinct) is supported on top level query if there
-- is a grouping on the partition key
SELECT
l_orderkey, count(DISTINCT l_partkey)
FROM lineitem_hash
GROUP BY l_orderkey
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
l_orderkey | count
------------+-------
14885 | 7
14884 | 7
14821 | 7
14790 | 7
14785 | 7
14755 | 7
14725 | 7
14694 | 7
14627 | 7
14624 | 7
(10 rows)
-- it is not supported if there is no grouping or grouping is on non-partition field
SELECT
count(DISTINCT l_partkey)
FROM lineitem_hash
ORDER BY 1 DESC
LIMIT 10;
ERROR: cannot compute aggregate (distinct)
DETAIL: table partitioning is unsuitable for aggregate (distinct)
HINT: You can load the hll extension from contrib packages and enable distinct approximations.
SELECT
l_shipmode, count(DISTINCT l_partkey)
FROM lineitem_hash
GROUP BY l_shipmode
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
ERROR: cannot compute aggregate (distinct)
DETAIL: table partitioning is unsuitable for aggregate (distinct)
HINT: You can load the hll extension from contrib packages and enable distinct approximations.
-- count distinct is supported on single table subqueries
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT l_partkey)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
l_orderkey | count
------------+-------
14885 | 7
14884 | 7
14821 | 7
14790 | 7
14785 | 7
14755 | 7
14725 | 7
14694 | 7
14627 | 7
14624 | 7
(10 rows)
SELECT *
FROM (
SELECT
l_partkey, count(DISTINCT l_orderkey)
FROM lineitem_hash
GROUP BY l_partkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
l_partkey | count
-----------+-------
199146 | 3
188804 | 3
177771 | 3
160895 | 3
149926 | 3
136884 | 3
87761 | 3
15283 | 3
6983 | 3
1927 | 3
(10 rows)
-- case expr in count distinct is supported.
-- count orders partkeys if l_shipmode is air
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT CASE WHEN l_shipmode = 'AIR' THEN l_partkey ELSE NULL END) as count
FROM lineitem_hash
GROUP BY l_orderkey) sub
WHERE count > 0
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
l_orderkey | count
------------+-------
12005 | 4
5409 | 4
4964 | 4
14848 | 3
14496 | 3
13473 | 3
13122 | 3
12929 | 3
12645 | 3
12417 | 3
(10 rows)
-- text like operator is also supported
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT CASE WHEN l_shipmode like '%A%' THEN l_partkey ELSE NULL END) as count
FROM lineitem_hash
GROUP BY l_orderkey) sub
WHERE count > 0
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
l_orderkey | count
------------+-------
14275 | 7
14181 | 7
13605 | 7
12707 | 7
12384 | 7
11746 | 7
10727 | 7
10467 | 7
5636 | 7
4614 | 7
(10 rows)
-- count distinct is rejected if it does not reference any columns
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT 1)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
ERROR: cannot compute aggregate (distinct)
DETAIL: aggregate (distinct) with no columns is unsupported
HINT: You can load the hll extension from contrib packages and enable distinct approximations.
-- count distinct is rejected if it does not reference any columns
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT (random() * 5)::int)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
ERROR: cannot compute aggregate (distinct)
DETAIL: aggregate (distinct) with no columns is unsupported
HINT: You can load the hll extension from contrib packages and enable distinct approximations.
-- even non-const function calls are supported within count distinct
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT (random() * 5)::int = l_linenumber)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 0;
l_orderkey | count
------------+-------
(0 rows)
-- multiple nested subquery
SELECT
total,
avg(avg_count) as total_avg_count
FROM (
SELECT
number_sum,
count(DISTINCT l_suppkey) as total,
avg(total_count) avg_count
FROM (
SELECT
l_suppkey,
sum(l_linenumber) as number_sum,
count(DISTINCT l_shipmode) as total_count
FROM
lineitem_hash
WHERE
l_partkey > 100 and
l_quantity > 2 and
l_orderkey < 10000
GROUP BY
l_suppkey) as distributed_table
WHERE
number_sum >= 10
GROUP BY
number_sum) as distributed_table_2
GROUP BY
total
ORDER BY
total_avg_count DESC;
total | total_avg_count
-------+--------------------
1 | 3.6000000000000000
6 | 2.8333333333333333
10 | 2.6000000000000000
27 | 2.5555555555555556
32 | 2.4687500000000000
77 | 2.1948051948051948
57 | 2.1754385964912281
(7 rows)
-- multiple cases query
SELECT *
FROM (
SELECT
count(DISTINCT
CASE
WHEN l_shipmode = 'TRUCK' THEN l_partkey
WHEN l_shipmode = 'AIR' THEN l_quantity
WHEN l_shipmode = 'SHIP' THEN l_discount
ELSE l_suppkey
END) as count,
l_shipdate
FROM
lineitem_hash
GROUP BY
l_shipdate) sub
WHERE
count > 0
ORDER BY
1 DESC, 2 DESC
LIMIT 10;
count | l_shipdate
-------+------------
14 | 07-30-1997
13 | 05-26-1998
13 | 08-08-1997
13 | 11-17-1995
13 | 01-09-1993
12 | 01-15-1998
12 | 10-15-1997
12 | 09-07-1997
12 | 06-02-1997
12 | 03-14-1997
(10 rows)
-- count DISTINCT expression
SELECT *
FROM (
SELECT
l_quantity, count(DISTINCT ((l_orderkey / 1000) * 1000 )) as count
FROM
lineitem_hash
GROUP BY
l_quantity) sub
WHERE
count > 0
ORDER BY
2 DESC, 1 DESC
LIMIT 10;
l_quantity | count
------------+-------
48.00 | 13
47.00 | 13
37.00 | 13
33.00 | 13
26.00 | 13
25.00 | 13
23.00 | 13
21.00 | 13
15.00 | 13
12.00 | 13
(10 rows)
-- count DISTINCT is part of an expression which inclues another aggregate
SELECT *
FROM (
SELECT
sum(((l_partkey * l_tax) / 100)) /
count(DISTINCT
CASE
WHEN l_shipmode = 'TRUCK' THEN l_partkey
ELSE l_suppkey
END) as avg,
l_shipmode
FROM
lineitem_hash
GROUP BY
l_shipmode) sub
ORDER BY
1 DESC, 2 DESC
LIMIT 10;
avg | l_shipmode
-------------------------+------------
44.82904609027336300064 | MAIL
44.80704536679536679537 | SHIP
44.68891732736572890026 | AIR
44.34106724470134874759 | REG AIR
43.12739987269255251432 | FOB
43.07299253636938646426 | RAIL
40.50298377916903813318 | TRUCK
(7 rows)
--- count DISTINCT CASE WHEN expression
SELECT *
FROM (
SELECT
count(DISTINCT
CASE
WHEN l_shipmode = 'TRUCK' THEN l_linenumber
WHEN l_shipmode = 'AIR' THEN l_linenumber + 10
ELSE 2
END) as avg
FROM
lineitem_hash
GROUP BY l_shipdate) sub
ORDER BY 1 DESC
LIMIT 10;
avg
-----
7
6
6
6
6
6
6
6
5
5
(10 rows)
-- COUNT DISTINCT (c1, c2)
SELECT *
FROM
(SELECT
l_shipmode,
count(DISTINCT (l_shipdate, l_tax))
FROM
lineitem_hash
GROUP BY
l_shipmode) t
ORDER BY
2 DESC,1 DESC
LIMIT 10;
l_shipmode | count
------------+-------
TRUCK | 1689
MAIL | 1683
FOB | 1655
AIR | 1650
SHIP | 1644
RAIL | 1636
REG AIR | 1607
(7 rows)
-- other distinct aggregate are not supported
SELECT *
FROM (
SELECT
l_orderkey, sum(DISTINCT l_partkey)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
ERROR: cannot compute aggregate (distinct)
DETAIL: Only count(distinct) aggregate is supported in subqueries
SELECT *
FROM (
SELECT
l_orderkey, avg(DISTINCT l_partkey)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
ERROR: cannot compute aggregate (distinct)
DETAIL: Only count(distinct) aggregate is supported in subqueries
-- whole row references, oid, and ctid are not supported in count distinct
-- test table does not have oid or ctid enabled, so tests for them are skipped
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT lineitem_hash)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
ERROR: cannot compute count (distinct)
DETAIL: Non-column references are not supported yet
SELECT *
FROM (
SELECT
l_orderkey, count(DISTINCT lineitem_hash.*)
FROM lineitem_hash
GROUP BY l_orderkey) sub
ORDER BY 2 DESC, 1 DESC
LIMIT 10;
ERROR: cannot compute count (distinct)
DETAIL: Non-column references are not supported yet
DROP TABLE lineitem_hash;

View File

@ -14,3 +14,4 @@
/multi_stage_more_data.sql
/multi_subquery.sql
/worker_copy.sql
/multi_complex_count_distinct.sql

View File

@ -125,10 +125,10 @@ from
group by
l_tax) as distributed_table;
-- Check that we don't support subqueries with count(distinct).
-- Check that we support subqueries with count(distinct).
select
different_shipment_days
avg(different_shipment_days)
from
(select
count(distinct l_shipdate) as different_shipment_days