diff --git a/src/backend/distributed/planner/multi_logical_optimizer.c b/src/backend/distributed/planner/multi_logical_optimizer.c index c31cf9748..44b4cd9fd 100644 --- a/src/backend/distributed/planner/multi_logical_optimizer.c +++ b/src/backend/distributed/planner/multi_logical_optimizer.c @@ -240,7 +240,6 @@ static void ExpandWorkerTargetEntry(List *expressionList, bool addToGroupByClause, QueryTargetList *queryTargetList, QueryGroupClause *queryGroupClause); -static Index GetNextSortGroupRef(List *targetEntryList); static TargetEntry * GenerateWorkerTargetEntry(TargetEntry *targetEntry, Expr *workerExpression, AttrNumber targetProjectionNumber); @@ -254,7 +253,6 @@ static AggregateType GetAggregateType(Oid aggFunctionId); static Oid AggregateArgumentType(Aggref *aggregate); static Oid AggregateFunctionOid(const char *functionName, Oid inputType); static Oid TypeOid(Oid schemaId, const char *typeName); -static SortGroupClause * CreateSortGroupClause(Var *column); /* Local functions forward declarations for count(distinct) approximations */ static char * CountDistinctHashFunctionName(Oid argumentType); @@ -2508,7 +2506,7 @@ ExpandWorkerTargetEntry(List *expressionList, TargetEntry *originalTargetEntry, * the next ressortgroupref that should be used based on the * input target list. */ -static Index +Index GetNextSortGroupRef(List *targetEntryList) { ListCell *targetEntryCell = NULL; @@ -2976,7 +2974,7 @@ TypeOid(Oid schemaId, const char *typeName) * The caller should set tleSortGroupRef field and respective * TargetEntry->ressortgroupref fields to appropriate SortGroupRefIndex. 
*/ -static SortGroupClause * +SortGroupClause * CreateSortGroupClause(Var *column) { Oid lessThanOperator = InvalidOid; diff --git a/src/backend/distributed/planner/multi_logical_planner.c b/src/backend/distributed/planner/multi_logical_planner.c index 96f67c708..003a45d7f 100644 --- a/src/backend/distributed/planner/multi_logical_planner.c +++ b/src/backend/distributed/planner/multi_logical_planner.c @@ -840,7 +840,7 @@ DeferErrorIfQueryNotSupported(Query *queryTree) */ if (queryTree->hasSubLinks && !WhereClauseContainsSubquery(queryTree)) { - //preconditionsSatisfied = false; + /* preconditionsSatisfied = false; */ errorMessage = "could not run distributed query with subquery outside the " "FROM and WHERE clauses"; errorHint = filterHint; diff --git a/src/backend/distributed/planner/multi_physical_planner.c b/src/backend/distributed/planner/multi_physical_planner.c index 1dd6c1402..f04785791 100644 --- a/src/backend/distributed/planner/multi_physical_planner.c +++ b/src/backend/distributed/planner/multi_physical_planner.c @@ -145,8 +145,6 @@ static List * AnchorRangeTableIdList(List *rangeTableList, List *baseRangeTableI static void AdjustColumnOldAttributes(List *expressionList); static List * RangeTableFragmentsList(List *rangeTableList, List *whereClauseList, List *dependedJobList); -static OperatorCacheEntry * LookupOperatorByType(Oid typeId, Oid accessMethodId, - int16 strategyNumber); static Oid GetOperatorByType(Oid typeId, Oid accessMethodId, int16 strategyNumber); static List * FragmentCombinationList(List *rangeTableFragmentsList, Query *jobQuery, List *dependedJobList); @@ -3086,7 +3084,7 @@ MakeOpExpression(Var *variable, int16 strategyNumber) * LookupOperatorByType function errors out if it cannot find corresponding * default operator class with the given parameters on the system catalogs. 
*/ -static OperatorCacheEntry * +OperatorCacheEntry * LookupOperatorByType(Oid typeId, Oid accessMethodId, int16 strategyNumber) { OperatorCacheEntry *matchingCacheEntry = NULL; diff --git a/src/backend/distributed/planner/recursive_planning.c b/src/backend/distributed/planner/recursive_planning.c index e115bd253..231129f0a 100644 --- a/src/backend/distributed/planner/recursive_planning.c +++ b/src/backend/distributed/planner/recursive_planning.c @@ -53,6 +53,7 @@ #include "catalog/pg_type.h" #include "catalog/pg_class.h" +#include "catalog/pg_am.h" #include "distributed/citus_nodes.h" #include "distributed/citus_ruleutils.h" #include "distributed/distributed_planner.h" @@ -60,6 +61,7 @@ #include "distributed/metadata_cache.h" #include "distributed/multi_copy.h" #include "distributed/multi_logical_planner.h" +#include "distributed/multi_logical_optimizer.h" #include "distributed/multi_router_planner.h" #include "distributed/multi_physical_planner.h" #include "distributed/multi_server_executor.h" @@ -82,6 +84,7 @@ #include "rewrite/rewriteManip.h" #include "utils/builtins.h" #include "utils/guc.h" +#include "utils/lsyscache.h" #include "../../../include/distributed/query_pushdown_planning.h" @@ -233,7 +236,12 @@ static bool ShouldDeCorrelateSubqueries(Query *query, RecursivePlanningContext * static void ExamineSublinks(Query *query, Node *quals, RecursivePlanningContext *context); static bool SublinkSafeToDeCorrelate(SubLink *sublink); static Expr * ColumnMatchExpressionAtTopLevelConjunction(Node *node, Var *column); -static bool OpExpressionContainsColumnAnyPlace(OpExpr *operatorExpression, Var *partitionColumn); +static bool OpExpressionContainsColumnAnyPlace(OpExpr *operatorExpression, + Var *partitionColumn); +static int +OperatorBtreeStrategy(Oid opno); +static OpExpr * +MakeOpExpressionEquality(Var *variable, Var *secondVar, int16 strategyNumber); static bool SimpleJoinExpression(Expr *clause); static Node * RemoveMatchExpressionAtTopLevelConjunction(Node 
*quals, Node *node); @@ -494,7 +502,7 @@ ExamineSublinks(Query *query, Node *node, RecursivePlanningContext *context) } - //subselect = copyObject(subselect); + /* subselect = copyObject(subselect); */ /* * The subquery must have a nonempty jointree, else we won't have a join. @@ -512,8 +520,9 @@ ExamineSublinks(Query *query, Node *node, RecursivePlanningContext *context) Node *whereClause = subselect->jointree->quals; OpExpr *opExpr = expr; subselect->jointree->quals = - RemoveMatchExpressionAtTopLevelConjunction(subselect->jointree->quals, - opExpr); + RemoveMatchExpressionAtTopLevelConjunction( + subselect->jointree->quals, + opExpr); /* * The rest of the sub-select must not refer to any Vars of the parent @@ -521,7 +530,8 @@ ExamineSublinks(Query *query, Node *node, RecursivePlanningContext *context) */ if (contain_vars_of_level((Node *) subselect, 1)) { - elog(DEBUG2, "skipping since other parts of the query also has correlation"); + elog(DEBUG2, + "skipping since other parts of the query also has correlation"); return; } @@ -546,18 +556,18 @@ ExamineSublinks(Query *query, Node *node, RecursivePlanningContext *context) } List *rightColumnNames = NIL; - List *rightColumnVars = NIL; - List *leftColumnNames = NIL; - List *leftColumnVars = NIL; - List *joinedColumnNames = NIL; - List *joinedColumnVars = NIL; + List *rightColumnVars = NIL; + List *leftColumnNames = NIL; + List *leftColumnVars = NIL; + List *joinedColumnNames = NIL; + List *joinedColumnVars = NIL; RangeTblEntry *newRte = makeNode(RangeTblEntry); newRte->rtekind = RTE_SUBQUERY; newRte->subquery = subselect; - newRte->alias = makeAlias("new_sub", NIL); - newRte->eref = makeAlias("new_sub", NIL); + newRte->alias = makeAlias("new_sub", NIL); + newRte->eref = makeAlias("new_sub", NIL); TargetEntry *tList = makeTargetEntry((Expr *) copyObject(secondVar), @@ -569,15 +579,14 @@ ExamineSublinks(Query *query, Node *node, RecursivePlanningContext *context) int subqueryOffset = list_length(query->rtable) + 1; 
/* build the join tree using the read_intermediate_result RTE */ - RangeTblRef *subqueryRteRef = makeNode(RangeTblRef); - subqueryRteRef->rtindex = subqueryOffset; + RangeTblRef *subqueryRteRef = makeNode(RangeTblRef); + subqueryRteRef->rtindex = subqueryOffset; RangeTblEntry *l_rte = linitial(query->rtable); query->rtable = list_concat(query->rtable, list_make1(newRte)); - /* remove the NOT EXISTS part */ query->jointree->quals = RemoveMatchExpressionAtTopLevelConjunction(query->jointree->quals, @@ -600,10 +609,10 @@ ExamineSublinks(Query *query, Node *node, RecursivePlanningContext *context) elog(INFO, "rightColumnNames: %s", nodeToString(rightColumnNames)); elog(INFO, "rightColumnVars: %s", nodeToString(rightColumnVars)); - newRte->alias = makeAlias("decorralated", copyObject(rightColumnNames)); - newRte->eref = makeAlias("decorralated", copyObject(rightColumnNames)); - - + newRte->alias = makeAlias("decorralated", copyObject( + rightColumnNames)); + newRte->eref = makeAlias("decorralated", copyObject( + rightColumnNames)); /* @@ -651,28 +660,23 @@ ExamineSublinks(Query *query, Node *node, RecursivePlanningContext *context) joinedColumnVars = list_concat(joinedColumnVars, rightColumnVars); - RangeTblEntry *rteJoin = makeNode(RangeTblEntry); - rteJoin->rtekind = RTE_JOIN; - rteJoin->relid = InvalidOid; - rteJoin->subquery = NULL; - rteJoin->jointype = JOIN_LEFT; - rteJoin->joinaliasvars = joinedColumnVars; + rteJoin->rtekind = RTE_JOIN; + rteJoin->relid = InvalidOid; + rteJoin->subquery = NULL; + rteJoin->jointype = JOIN_LEFT; + rteJoin->joinaliasvars = joinedColumnVars; - rteJoin->eref = makeAlias("unnamed_citus_join", joinedColumnNames); - rteJoin->alias = makeAlias("unnamed_citus_join", joinedColumnNames); + rteJoin->eref = makeAlias("unnamed_citus_join", joinedColumnNames); + rteJoin->alias = makeAlias("unnamed_citus_join", joinedColumnNames); - query->rtable = lappend(query->rtable, rteJoin); + query->rtable = lappend(query->rtable, rteJoin); StringInfo str 
= makeStringInfo(); - //deparse_shard_query(query, 0, 0, str); - //elog(INFO, "Current subquery: %s", str->data); + /* deparse_shard_query(query, 0, 0, str); */ + /* elog(INFO, "Current subquery: %s", str->data); */ RecursivelyPlanSubquery(query, context); - - - - } } } @@ -681,25 +685,237 @@ ExamineSublinks(Query *query, Node *node, RecursivePlanningContext *context) if (is_opclause(node)) { - Node *leftOp = get_leftop((Expr *) node); - Node *rightOp = get_rightop((Expr *) node); + Node *leftOp = strip_implicit_coercions(get_leftop((Expr *) node)); + Node *rightOp = strip_implicit_coercions(get_rightop((Expr *) node)); + + OpExpr *topLevelOpExpr = (OpExpr *) node; + + SubLink *sublinkToProcess = NULL; + Var *topLevelOpClaueVar = NULL; + List *correlatedVarList = NIL; + Var *correlatedVar = NULL; + Var *subselectJoinColumn = NULL; + OpExpr *opExpr = NULL; + Expr *correlationExpr = NULL; - if (IsA(strip_implicit_coercions(leftOp), SubLink)) + List *rightColumnNames = NIL; + List *rightColumnVars = NIL; + List *leftColumnNames = NIL; + List *leftColumnVars = NIL; + List *joinedColumnNames = NIL; + List *joinedColumnVars = NIL; + + Query *subselect = NULL; + + + if (IsA(leftOp, SubLink) && IsA(rightOp, Var)) { - if (SublinkSafeToDeCorrelate((SubLink *) strip_implicit_coercions(leftOp))) - { - elog(INFO, "OpClause is found on the left"); - } + elog(INFO, "OpClause is found on the left"); + sublinkToProcess = leftOp; + topLevelOpClaueVar = rightOp; + } + else if (IsA(leftOp, Var) && IsA(rightOp, SubLink)) + { + elog(INFO, "OpClause is found on the right"); + + sublinkToProcess = rightOp; + topLevelOpClaueVar = leftOp; + } + else + { + elog(DEBUG2, "Op clause is not OK to de correlate"); + + return; } - if (IsA(strip_implicit_coercions(rightOp), SubLink)) + subselect = (Query *) sublinkToProcess->subselect; + correlatedVarList = pull_vars_of_level(subselect->jointree->quals, 1); + + if (list_length(correlatedVarList) != 1) { - if (SublinkSafeToDeCorrelate((SubLink *) 
strip_implicit_coercions(rightOp))) - { - elog(INFO, "OpClause is found on the right"); - } + elog(DEBUG2, "skipping since expression condition didn't hold"); + + return; } + + correlatedVar = linitial(correlatedVarList); + + correlationExpr = + ColumnMatchExpressionAtTopLevelConjunction( + subselect->jointree->quals, correlatedVar); + + if (correlationExpr && !IsA(correlationExpr, OpExpr)) + { + elog(DEBUG2, "skipping since OP EXPRs condition didn't hold"); + return; + } + opExpr = (OpExpr *) correlationExpr; + + + if (equal(strip_implicit_coercions(get_leftop(correlationExpr)), correlatedVar)) + { + subselectJoinColumn = (Var *) strip_implicit_coercions(get_rightop( + correlationExpr)); + } + else + { + subselectJoinColumn = (Var *) strip_implicit_coercions(get_leftop( + correlationExpr)); + } + + subselect->jointree->quals = + RemoveMatchExpressionAtTopLevelConjunction(subselect->jointree->quals, + opExpr); + + /* Add GROUP BY clause to the subselect */ + SortGroupClause *groupByClause = CreateSortGroupClause(subselectJoinColumn); + Index nextSortGroupRefIndex = GetNextSortGroupRef(subselect->targetList); + TargetEntry *newTargetEntry = makeNode(TargetEntry); + StringInfo columnNameString = makeStringInfo(); + + appendStringInfo(columnNameString, "worker_column_%d", + list_length(subselect->targetList) + 1); + + newTargetEntry->resname = columnNameString->data; + + + /* force resjunk to false as we may need this on the master */ + newTargetEntry->expr = copyObject(subselectJoinColumn); + newTargetEntry->resjunk = false; + newTargetEntry->resno = list_length(subselect->targetList) + 1; + + newTargetEntry->ressortgroupref = nextSortGroupRefIndex; + subselect->targetList = lappend(subselect->targetList, newTargetEntry); + + /* the group by clause entry should point to the correct index in the target list */ + groupByClause->tleSortGroupRef = nextSortGroupRefIndex; + + /* update the group by list and the index's value */ + subselect->groupClause = + 
lappend(subselect->groupClause, groupByClause); + + + RangeTblEntry *rteSubquery = makeNode(RangeTblEntry); + rteSubquery->rtekind = RTE_SUBQUERY; + rteSubquery->subquery = subselect; + rteSubquery->lateral = false; + rteSubquery->alias = makeAlias("new_sub_1", list_make2(makeString("Onder_col"), makeString("onder_2_col"))); + rteSubquery->eref = makeAlias("new_sub_1",list_make2(makeString("Onder_col"), makeString("onder_2_col"))); + + RangeTblRef *subqueryRteRef = makeNode(RangeTblRef); + subqueryRteRef->rtindex = list_length(query->rtable) + 1; + query->rtable = lappend(query->rtable, rteSubquery); + + expandRTE(rteSubquery, subqueryRteRef->rtindex, 0, -1, false, + &rightColumnNames, &rightColumnVars); + + rteSubquery->alias = makeAlias("new_sub", rightColumnNames); + rteSubquery->eref = makeAlias("new_sub", rightColumnNames); + elog(INFO, "rightColumnNames: %s", nodeToString(rightColumnNames)); + elog(INFO, "rightColumnVars: %s", nodeToString(rightColumnVars)); + + query->jointree->quals = + RemoveMatchExpressionAtTopLevelConjunction(query->jointree->quals, + node); + + IncrementVarSublevelsUp(correlatedVar, -1, 1); + + TargetEntry *addedGroupByTargetEntry = copyObject(list_nth(subselect->targetList, newTargetEntry->resno - 1)); + + Var *addedGroupByColumn = makeNode(Var); + + addedGroupByColumn->varno = list_length(query->rtable); + addedGroupByColumn->varattno = newTargetEntry->resno; + + addedGroupByColumn->varlevelsup = 0; + addedGroupByColumn->vartype = ((Var *)addedGroupByTargetEntry->expr)->vartype; + addedGroupByColumn->vartypmod = ((Var *)addedGroupByTargetEntry->expr)->vartypmod; + addedGroupByColumn->varcollid = ((Var *)addedGroupByTargetEntry->expr)->varcollid; + + + OpExpr *equaltyOp = MakeOpExpressionEquality(correlatedVar, addedGroupByColumn, BTEqualStrategyNumber); + + + + /* + * And finally, build the JoinExpr node. 
+ */ + JoinExpr *result = makeNode(JoinExpr); + result->jointype = JOIN_INNER; + result->isNatural = false; + + result->rarg = subqueryRteRef; + + if (list_length(query->jointree->fromlist) == 1) + { + result->larg = (Node *) linitial(query->jointree->fromlist); + } + else + { + result->larg = (Node *) query->jointree; + } + + + result->usingClause = NIL; + result->quals = equaltyOp; + + + TargetEntry *existingAggrageteColumn = copyObject(list_nth(subselect->targetList, 0)); + + Var *existingGroupByColumn = makeNode(Var); + + existingGroupByColumn->varno = list_length(query->rtable); + existingGroupByColumn->varattno = 1; + + /* TODO: Fix the following */ + existingGroupByColumn->varlevelsup = 0; + existingGroupByColumn->vartype = ((Var *)addedGroupByTargetEntry->expr)->vartype; + existingGroupByColumn->vartypmod = ((Var *)addedGroupByTargetEntry->expr)->vartypmod; + existingGroupByColumn->varcollid = ((Var *)addedGroupByTargetEntry->expr)->varcollid; + + OpExpr *otherOperator = MakeOpExpressionEquality(topLevelOpClaueVar, existingGroupByColumn, OperatorBtreeStrategy(topLevelOpExpr->opno)); + + + result->quals = make_and_qual(result->quals, otherOperator); + + + result->alias = NULL; + result->rtindex = list_length(query->rtable) + 1; + query->jointree = makeFromExpr(list_make1(result), NULL); + + expandRTE(linitial(query->rtable), 1, 0, -1, false, + &leftColumnNames, &leftColumnVars); + + joinedColumnNames = list_concat(joinedColumnNames, leftColumnNames); + joinedColumnVars = list_concat(joinedColumnVars, leftColumnVars); + + joinedColumnNames = list_concat(joinedColumnNames, rightColumnNames); + joinedColumnVars = list_concat(joinedColumnVars, rightColumnVars); + + + + RangeTblEntry *rteJoin = makeNode(RangeTblEntry); + + rteJoin->rtekind = RTE_JOIN; + rteJoin->relid = InvalidOid; + rteJoin->subquery = NULL; + rteJoin->jointype = JOIN_INNER; + rteJoin->joinaliasvars = joinedColumnVars; + + rteJoin->eref = makeAlias("unnamed_citus_join", joinedColumnNames); + 
rteJoin->alias = makeAlias("unnamed_citus_join", joinedColumnNames); + + query->rtable = lappend(query->rtable, rteJoin); + + + StringInfo str = makeStringInfo(); + deparse_shard_query(query, 0, 0, str); + elog(INFO, "Current subquery: %s", str->data); + + + RecursivelyPlanSubquery(query, context); + } if (and_clause(node)) @@ -714,6 +930,79 @@ ExamineSublinks(Query *query, Node *node, RecursivePlanningContext *context) } +/* + * OperatorBtreeStrategy returns the btree strategy number of the first entry in + * the btree interpretation list of the given opno (get_op_btree_interpretation), + * or -1 when that list is empty. Only the first interpretation is consulted. + */ +static int +OperatorBtreeStrategy(Oid opno) +{ + bool equalityOperator = false; /* NOTE(review): never read in this function */ + List *btreeIntepretationList = get_op_btree_interpretation(opno); + ListCell *btreeInterpretationCell = NULL; + foreach(btreeInterpretationCell, btreeIntepretationList) + { + OpBtreeInterpretation *btreeIntepretation = (OpBtreeInterpretation *) + lfirst(btreeInterpretationCell); + return btreeIntepretation->strategy; + } + + return -1; +} + + +/* + * MakeOpExpressionEquality builds the operator expression "variable OP secondVar". + * OP is the operator that LookupOperatorByType returns for the type of variable, + * the btree access method and the given strategy number. 
+ */ +static OpExpr * +MakeOpExpressionEquality(Var *variable, Var *secondVar, int16 strategyNumber) +{ + Oid typeId = variable->vartype; + Oid typeModId = variable->vartypmod; /* NOTE(review): never read in this function */ + Oid collationId = variable->varcollid; + + OperatorCacheEntry *operatorCacheEntry = NULL; + Oid accessMethodId = BTREE_AM_OID; + Oid operatorId = InvalidOid; + Oid operatorClassInputType = InvalidOid; + OpExpr *expression = NULL; + char typeType = 0; + + operatorCacheEntry = LookupOperatorByType(typeId, accessMethodId, strategyNumber); + + operatorId = operatorCacheEntry->operatorId; + operatorClassInputType = operatorCacheEntry->operatorClassInputType; + typeType = operatorCacheEntry->typeType; + + /* + * Relabel variable if input type of default operator class is not equal to + * the variable type. Note that we don't relabel the variable if the default + * operator class variable type is a pseudo-type. + */ + if (operatorClassInputType != typeId && typeType != TYPTYPE_PSEUDO) + { + variable = (Var *) makeRelabelType((Expr *) variable, operatorClassInputType, + -1, collationId, COERCE_IMPLICIT_CAST); + } + + /* Now make the expression with the (possibly relabeled) variable and secondVar, which is passed as is */ + expression = (OpExpr *) make_opclause(operatorId, + InvalidOid, /* no result type yet */ + false, /* no return set */ + (Expr *) variable, + (Expr *) secondVar, + InvalidOid, collationId); + + /* Set implementing function id and result type */ + expression->opfuncid = get_opcode(operatorId); + expression->opresulttype = get_func_rettype(expression->opfuncid); + + return expression; +} + /* * ColumnMatchExpressionAtTopLevelConjunction returns true if the query contains an exact * match (equal) expression on the provided column. 
The function returns true only @@ -773,6 +1062,7 @@ ColumnMatchExpressionAtTopLevelConjunction(Node *node, Var *column) return NULL; } + static bool OpExpressionContainsColumnAnyPlace(OpExpr *operatorExpression, Var *partitionColumn) { @@ -789,14 +1079,18 @@ OpExpressionContainsColumnAnyPlace(OpExpr *operatorExpression, Var *partitionCol column = (Var *) leftOperand; if (equal(column, partitionColumn)) + { return true; + } } if (IsA(rightOperand, Var)) { column = (Var *) rightOperand; if (equal(column, partitionColumn)) + { return true; + } } return equal(column, partitionColumn); diff --git a/src/include/distributed/multi_logical_optimizer.h b/src/include/distributed/multi_logical_optimizer.h index 877a0addf..d679e6507 100644 --- a/src/include/distributed/multi_logical_optimizer.h +++ b/src/include/distributed/multi_logical_optimizer.h @@ -141,4 +141,8 @@ extern void FindReferencedTableColumn(Expr *columnExpression, List *parentQueryL extern bool IsGroupBySubsetOfDistinct(List *groupClause, List *distinctClause); +extern Index GetNextSortGroupRef(List *targetEntryList); +extern SortGroupClause * CreateSortGroupClause(Var *column); + + #endif /* MULTI_LOGICAL_OPTIMIZER_H */ diff --git a/src/include/distributed/multi_physical_planner.h b/src/include/distributed/multi_physical_planner.h index a9b805a9b..b241efdea 100644 --- a/src/include/distributed/multi_physical_planner.h +++ b/src/include/distributed/multi_physical_planner.h @@ -316,6 +316,8 @@ extern Node * BuildBaseConstraint(Var *column); extern void UpdateConstraint(Node *baseConstraint, ShardInterval *shardInterval); extern bool SimpleOpExpression(Expr *clause); extern bool OpExpressionContainsColumn(OpExpr *operatorExpression, Var *partitionColumn); +extern OperatorCacheEntry * LookupOperatorByType(Oid typeId, Oid accessMethodId, + int16 strategyNumber); /* helper functions */ extern Var * MakeInt4Column(void);