LocalJoin

multi-column-distribution
Jelte Fennema 2021-06-10 10:19:12 +02:00
parent 3797092ae4
commit b43c17b2de
6 changed files with 157 additions and 95 deletions

View File

@ -31,6 +31,11 @@
- [ ] PostprocessAlterTableStmtAttachPartition() - [ ] PostprocessAlterTableStmtAttachPartition()
- [x] TargetListOnPartitionColumn() - [x] TargetListOnPartitionColumn()
- [ ] PartitionColumnForPushedDownSubquery() - [ ] PartitionColumnForPushedDownSubquery()
- [ ] CoPartitionedTables()
- [x] LocalJoin()
- [ ] SinglePartitionJoin()
- [ ] ApplySinglePartitionJoin()
- [ ] MultiJoinTree()
# query pushdown planner # query pushdown planner
- [x] RestrictionEquivalenceForPartitionKeys() - [x] RestrictionEquivalenceForPartitionKeys()

View File

@ -97,7 +97,8 @@ static JoinOrderNode * CartesianProduct(JoinOrderNode *joinNode,
JoinType joinType); JoinType joinType);
static JoinOrderNode * MakeJoinOrderNode(TableEntry *tableEntry, static JoinOrderNode * MakeJoinOrderNode(TableEntry *tableEntry,
JoinRuleType joinRuleType, JoinRuleType joinRuleType,
List *partitionColumnList, char partitionMethod, List *partitionColumnListList,
char partitionMethod,
TableEntry *anchorTable); TableEntry *anchorTable);
@ -343,11 +344,18 @@ JoinOrderForTable(TableEntry *firstTable, List *tableEntryList, List *joinClause
/* create join node for the first table */ /* create join node for the first table */
Oid firstRelationId = firstTable->relationId; Oid firstRelationId = firstTable->relationId;
uint32 firstTableId = firstTable->rangeTableId; uint32 firstTableId = firstTable->rangeTableId;
Var *firstPartitionColumn = PartitionColumn(firstRelationId, firstTableId); List *firstPartitionColumnListList = NIL;
Var *partitionColumn = NULL;
List *partitionColumnList = PartitionColumns(firstRelationId, firstTableId);
foreach_ptr(partitionColumn, partitionColumnList)
{
firstPartitionColumnListList = lappend(
firstPartitionColumnListList, list_make1(partitionColumn));
}
char firstPartitionMethod = PartitionMethod(firstRelationId); char firstPartitionMethod = PartitionMethod(firstRelationId);
JoinOrderNode *firstJoinNode = MakeJoinOrderNode(firstTable, firstJoinRule, JoinOrderNode *firstJoinNode = MakeJoinOrderNode(firstTable, firstJoinRule,
list_make1(firstPartitionColumn), firstPartitionColumnListList,
firstPartitionMethod, firstPartitionMethod,
firstTable); firstTable);
@ -829,7 +837,7 @@ ReferenceJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
return NULL; return NULL;
} }
return MakeJoinOrderNode(candidateTable, REFERENCE_JOIN, return MakeJoinOrderNode(candidateTable, REFERENCE_JOIN,
currentJoinNode->partitionColumnList, currentJoinNode->partitionColumnListList,
currentJoinNode->partitionMethod, currentJoinNode->partitionMethod,
currentJoinNode->anchorTable); currentJoinNode->anchorTable);
} }
@ -881,7 +889,7 @@ CartesianProductReferenceJoin(JoinOrderNode *currentJoinNode, TableEntry *candid
return NULL; return NULL;
} }
return MakeJoinOrderNode(candidateTable, CARTESIAN_PRODUCT_REFERENCE_JOIN, return MakeJoinOrderNode(candidateTable, CARTESIAN_PRODUCT_REFERENCE_JOIN,
currentJoinNode->partitionColumnList, currentJoinNode->partitionColumnListList,
currentJoinNode->partitionMethod, currentJoinNode->partitionMethod,
currentJoinNode->anchorTable); currentJoinNode->anchorTable);
} }
@ -905,8 +913,8 @@ LocalJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
{ {
Oid relationId = candidateTable->relationId; Oid relationId = candidateTable->relationId;
uint32 tableId = candidateTable->rangeTableId; uint32 tableId = candidateTable->rangeTableId;
Var *candidatePartitionColumn = PartitionColumn(relationId, tableId); List *candidatePartitionColumnList = PartitionColumns(relationId, tableId);
List *currentPartitionColumnList = currentJoinNode->partitionColumnList; List *currentPartitionColumnListList = currentJoinNode->partitionColumnListList;
char candidatePartitionMethod = PartitionMethod(relationId); char candidatePartitionMethod = PartitionMethod(relationId);
char currentPartitionMethod = currentJoinNode->partitionMethod; char currentPartitionMethod = currentJoinNode->partitionMethod;
TableEntry *currentAnchorTable = currentJoinNode->anchorTable; TableEntry *currentAnchorTable = currentJoinNode->anchorTable;
@ -926,14 +934,30 @@ LocalJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
return NULL; return NULL;
} }
bool joinOnPartitionColumns = JoinOnColumns(currentPartitionColumnList, /*
candidatePartitionColumn, * If the number of partition columns don't match then we cannot do a local
applicableJoinClauses); * join.
if (!joinOnPartitionColumns) */
if (list_length(candidatePartitionColumnList) != list_length(
currentPartitionColumnListList))
{ {
return NULL; return NULL;
} }
Var *candidatePartitionColumn = NULL;
List *currentPartitionColumnList = NIL;
forboth_ptr(candidatePartitionColumn, candidatePartitionColumnList,
currentPartitionColumnList, currentPartitionColumnListList)
{
bool joinOnPartitionColumns = JoinOnColumns(currentPartitionColumnList,
candidatePartitionColumn,
applicableJoinClauses);
if (!joinOnPartitionColumns)
{
return NULL;
}
}
/* shard interval lists must have 1-1 matching for local joins */ /* shard interval lists must have 1-1 matching for local joins */
bool coPartitionedTables = CoPartitionedTables(currentAnchorTable->relationId, bool coPartitionedTables = CoPartitionedTables(currentAnchorTable->relationId,
relationId); relationId);
@ -949,11 +973,20 @@ LocalJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
* subsequent joins colocated with this candidate table to correctly be recognized as * subsequent joins colocated with this candidate table to correctly be recognized as
* a local join as well. * a local join as well.
*/ */
currentPartitionColumnList = list_append_unique(currentPartitionColumnList, ListCell *candidatePartitionColumnCell = NULL;
candidatePartitionColumn); ListCell *currentPartitionColumnListCell = NULL;
forboth(candidatePartitionColumnCell, candidatePartitionColumnList,
currentPartitionColumnListCell, currentPartitionColumnListList)
{
candidatePartitionColumn = lfirst(candidatePartitionColumnCell);
currentPartitionColumnList = lfirst(currentPartitionColumnListCell);
lfirst(currentPartitionColumnListCell) =
list_append_unique(currentPartitionColumnList,
candidatePartitionColumn);
}
JoinOrderNode *nextJoinNode = MakeJoinOrderNode(candidateTable, LOCAL_PARTITION_JOIN, JoinOrderNode *nextJoinNode = MakeJoinOrderNode(candidateTable, LOCAL_PARTITION_JOIN,
currentPartitionColumnList, currentPartitionColumnListList,
currentPartitionMethod, currentPartitionMethod,
currentAnchorTable); currentAnchorTable);
@ -973,7 +1006,7 @@ static JoinOrderNode *
SinglePartitionJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable, SinglePartitionJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
List *applicableJoinClauses, JoinType joinType) List *applicableJoinClauses, JoinType joinType)
{ {
List *currentPartitionColumnList = currentJoinNode->partitionColumnList; List *currentPartitionColumnListList = currentJoinNode->partitionColumnListList;
char currentPartitionMethod = currentJoinNode->partitionMethod; char currentPartitionMethod = currentJoinNode->partitionMethod;
TableEntry *currentAnchorTable = currentJoinNode->anchorTable; TableEntry *currentAnchorTable = currentJoinNode->anchorTable;
JoinRuleType currentJoinRuleType = currentJoinNode->joinRuleType; JoinRuleType currentJoinRuleType = currentJoinNode->joinRuleType;
@ -981,8 +1014,6 @@ SinglePartitionJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
Oid relationId = candidateTable->relationId; Oid relationId = candidateTable->relationId;
uint32 tableId = candidateTable->rangeTableId; uint32 tableId = candidateTable->rangeTableId;
Var *candidatePartitionColumn = PartitionColumn(relationId, tableId);
char candidatePartitionMethod = PartitionMethod(relationId);
/* outer joins are not supported yet */ /* outer joins are not supported yet */
if (IS_OUTER_JOIN(joinType)) if (IS_OUTER_JOIN(joinType))
@ -1000,8 +1031,17 @@ SinglePartitionJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
return NULL; return NULL;
} }
if (list_length(currentPartitionColumnListList) > 1)
{
/*
* TODO: Implement single partition join for multi column distributed
* tables.
*/
return NULL;
}
OpExpr *joinClause = OpExpr *joinClause =
SinglePartitionJoinClause(currentPartitionColumnList, applicableJoinClauses); SinglePartitionJoinClause(currentPartitionColumnListList,
applicableJoinClauses);
if (joinClause != NULL) if (joinClause != NULL)
{ {
if (currentPartitionMethod == DISTRIBUTE_BY_HASH) if (currentPartitionMethod == DISTRIBUTE_BY_HASH)
@ -1016,29 +1056,41 @@ SinglePartitionJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
} }
return MakeJoinOrderNode(candidateTable, SINGLE_HASH_PARTITION_JOIN, return MakeJoinOrderNode(candidateTable, SINGLE_HASH_PARTITION_JOIN,
currentPartitionColumnList, currentPartitionColumnListList,
currentPartitionMethod, currentPartitionMethod,
currentAnchorTable); currentAnchorTable);
} }
else else
{ {
return MakeJoinOrderNode(candidateTable, SINGLE_RANGE_PARTITION_JOIN, return MakeJoinOrderNode(candidateTable, SINGLE_RANGE_PARTITION_JOIN,
currentPartitionColumnList, currentPartitionColumnListList,
currentPartitionMethod, currentPartitionMethod,
currentAnchorTable); currentAnchorTable);
} }
} }
char candidatePartitionMethod = PartitionMethod(relationId);
/* evaluate re-partitioning the current table only if the rule didn't apply above */ /* evaluate re-partitioning the current table only if the rule didn't apply above */
if (candidatePartitionMethod != DISTRIBUTE_BY_NONE) if (candidatePartitionMethod != DISTRIBUTE_BY_NONE)
{ {
List *candidatePartitionColumnList = PartitionColumns(relationId, tableId);
if (list_length(currentPartitionColumnListList) > 1)
{
/*
* TODO: Implement single partition join for multi column distributed
* tables.
*/
return NULL;
}
/* /*
* Create a new unique list (set) with the partition column of the candidate table * Create a new unique list (set) with the partition column of the candidate table
* to check if a single repartition join will work for this table. When it works * to check if a single repartition join will work for this table. When it works
* the set is retained on the MultiJoinNode for later local join verification. * the set is retained on the MultiJoinNode for later local join verification.
*/ */
List *candidatePartitionColumnList = list_make1(candidatePartitionColumn); List *candidatePartitionColumnListList = list_make1(candidatePartitionColumnList);
joinClause = SinglePartitionJoinClause(candidatePartitionColumnList, joinClause = SinglePartitionJoinClause(candidatePartitionColumnListList,
applicableJoinClauses); applicableJoinClauses);
if (joinClause != NULL) if (joinClause != NULL)
{ {
@ -1055,7 +1107,7 @@ SinglePartitionJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
return MakeJoinOrderNode(candidateTable, return MakeJoinOrderNode(candidateTable,
SINGLE_HASH_PARTITION_JOIN, SINGLE_HASH_PARTITION_JOIN,
candidatePartitionColumnList, candidatePartitionColumnListList,
candidatePartitionMethod, candidatePartitionMethod,
candidateTable); candidateTable);
} }
@ -1063,7 +1115,7 @@ SinglePartitionJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
{ {
return MakeJoinOrderNode(candidateTable, return MakeJoinOrderNode(candidateTable,
SINGLE_RANGE_PARTITION_JOIN, SINGLE_RANGE_PARTITION_JOIN,
candidatePartitionColumnList, candidatePartitionColumnListList,
candidatePartitionMethod, candidatePartitionMethod,
candidateTable); candidateTable);
} }
@ -1080,13 +1132,19 @@ SinglePartitionJoin(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
* clause exists, the function returns NULL. * clause exists, the function returns NULL.
*/ */
OpExpr * OpExpr *
SinglePartitionJoinClause(List *partitionColumnList, List *applicableJoinClauses) SinglePartitionJoinClause(List *partitionColumnListList, List *applicableJoinClauses)
{ {
if (list_length(partitionColumnList) == 0) if (list_length(partitionColumnListList) == 0)
{ {
return NULL; return NULL;
} }
/*
* TODO: Support multi column distributed tables.
*/
Assert(list_length(partitionColumnListList) == 1);
List *partitionColumnList = linitial(partitionColumnListList);
Var *partitionColumn = NULL; Var *partitionColumn = NULL;
foreach_ptr(partitionColumn, partitionColumnList) foreach_ptr(partitionColumn, partitionColumnList)
{ {
@ -1210,7 +1268,7 @@ CartesianProduct(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
{ {
/* Because of the cartesian product, anchor table information got lost */ /* Because of the cartesian product, anchor table information got lost */
return MakeJoinOrderNode(candidateTable, CARTESIAN_PRODUCT, return MakeJoinOrderNode(candidateTable, CARTESIAN_PRODUCT,
currentJoinNode->partitionColumnList, currentJoinNode->partitionColumnListList,
currentJoinNode->partitionMethod, currentJoinNode->partitionMethod,
NULL); NULL);
} }
@ -1222,14 +1280,14 @@ CartesianProduct(JoinOrderNode *currentJoinNode, TableEntry *candidateTable,
/* Constructs and returns a join-order node with the given arguments */ /* Constructs and returns a join-order node with the given arguments */
JoinOrderNode * JoinOrderNode *
MakeJoinOrderNode(TableEntry *tableEntry, JoinRuleType joinRuleType, MakeJoinOrderNode(TableEntry *tableEntry, JoinRuleType joinRuleType,
List *partitionColumnList, char partitionMethod, List *partitionColumnListList, char partitionMethod,
TableEntry *anchorTable) TableEntry *anchorTable)
{ {
JoinOrderNode *joinOrderNode = palloc0(sizeof(JoinOrderNode)); JoinOrderNode *joinOrderNode = palloc0(sizeof(JoinOrderNode));
joinOrderNode->tableEntry = tableEntry; joinOrderNode->tableEntry = tableEntry;
joinOrderNode->joinRuleType = joinRuleType; joinOrderNode->joinRuleType = joinRuleType;
joinOrderNode->joinType = JOIN_INNER; joinOrderNode->joinType = JOIN_INNER;
joinOrderNode->partitionColumnList = partitionColumnList; joinOrderNode->partitionColumnListList = partitionColumnListList;
joinOrderNode->partitionMethod = partitionMethod; joinOrderNode->partitionMethod = partitionMethod;
joinOrderNode->joinClauseList = NIL; joinOrderNode->joinClauseList = NIL;
joinOrderNode->anchorTable = anchorTable; joinOrderNode->anchorTable = anchorTable;

View File

@ -62,7 +62,7 @@ typedef struct QualifierWalkerContext
/* Function pointer type definition for apply join rule functions */ /* Function pointer type definition for apply join rule functions */
typedef MultiNode *(*RuleApplyFunction) (MultiNode *leftNode, MultiNode *rightNode, typedef MultiNode *(*RuleApplyFunction) (MultiNode *leftNode, MultiNode *rightNode,
List *partitionColumnList, JoinType joinType, List *partitionColumnListList, JoinType joinType,
List *joinClauses); List *joinClauses);
typedef bool (*CheckNodeFunc)(Node *); typedef bool (*CheckNodeFunc)(Node *);
@ -94,38 +94,40 @@ static bool IsSelectClause(Node *clause);
/* Local functions forward declarations for applying joins */ /* Local functions forward declarations for applying joins */
static MultiNode * ApplyJoinRule(MultiNode *leftNode, MultiNode *rightNode, static MultiNode * ApplyJoinRule(MultiNode *leftNode, MultiNode *rightNode,
JoinRuleType ruleType, List *partitionColumnList, JoinRuleType ruleType, List *partitionColumnListList,
JoinType joinType, List *joinClauseList); JoinType joinType, List *joinClauseList);
static RuleApplyFunction JoinRuleApplyFunction(JoinRuleType ruleType); static RuleApplyFunction JoinRuleApplyFunction(JoinRuleType ruleType);
static MultiNode * ApplyReferenceJoin(MultiNode *leftNode, MultiNode *rightNode, static MultiNode * ApplyReferenceJoin(MultiNode *leftNode, MultiNode *rightNode,
List *partitionColumnList, JoinType joinType, List *partitionColumnListList, JoinType joinType,
List *joinClauses); List *joinClauses);
static MultiNode * ApplyLocalJoin(MultiNode *leftNode, MultiNode *rightNode, static MultiNode * ApplyLocalJoin(MultiNode *leftNode, MultiNode *rightNode,
List *partitionColumnList, JoinType joinType, List *partitionColumnListList, JoinType joinType,
List *joinClauses); List *joinClauses);
static MultiNode * ApplySingleRangePartitionJoin(MultiNode *leftNode, static MultiNode * ApplySingleRangePartitionJoin(MultiNode *leftNode,
MultiNode *rightNode, MultiNode *rightNode,
List *partitionColumnList, List *partitionColumnListList,
JoinType joinType, JoinType joinType,
List *applicableJoinClauses); List *applicableJoinClauses);
static MultiNode * ApplySingleHashPartitionJoin(MultiNode *leftNode, static MultiNode * ApplySingleHashPartitionJoin(MultiNode *leftNode,
MultiNode *rightNode, MultiNode *rightNode,
List *partitionColumnList, List *partitionColumnListList,
JoinType joinType, JoinType joinType,
List *applicableJoinClauses); List *applicableJoinClauses);
static MultiJoin * ApplySinglePartitionJoin(MultiNode *leftNode, MultiNode *rightNode, static MultiJoin * ApplySinglePartitionJoin(MultiNode *leftNode, MultiNode *rightNode,
List *partitionColumnList, JoinType joinType, List *partitionColumnListList, JoinType
joinType,
List *joinClauses); List *joinClauses);
static MultiNode * ApplyDualPartitionJoin(MultiNode *leftNode, MultiNode *rightNode, static MultiNode * ApplyDualPartitionJoin(MultiNode *leftNode, MultiNode *rightNode,
List *partitionColumnList, JoinType joinType, List *partitionColumnListList, JoinType
joinType,
List *joinClauses); List *joinClauses);
static MultiNode * ApplyCartesianProductReferenceJoin(MultiNode *leftNode, static MultiNode * ApplyCartesianProductReferenceJoin(MultiNode *leftNode,
MultiNode *rightNode, MultiNode *rightNode,
List *partitionColumnList, List *partitionColumnListList,
JoinType joinType, JoinType joinType,
List *joinClauses); List *joinClauses);
static MultiNode * ApplyCartesianProduct(MultiNode *leftNode, MultiNode *rightNode, static MultiNode * ApplyCartesianProduct(MultiNode *leftNode, MultiNode *rightNode,
List *partitionColumnList, JoinType joinType, List *partitionColumnListList, JoinType joinType,
List *joinClauses); List *joinClauses);
@ -1663,16 +1665,18 @@ MultiJoinTree(List *joinOrderList, List *collectTableList, List *joinWhereClause
{ {
JoinRuleType joinRuleType = joinOrderNode->joinRuleType; JoinRuleType joinRuleType = joinOrderNode->joinRuleType;
JoinType joinType = joinOrderNode->joinType; JoinType joinType = joinOrderNode->joinType;
List *partitionColumnList = joinOrderNode->partitionColumnList; List *partitionColumnListList = joinOrderNode->partitionColumnListList;
List *joinClauseList = joinOrderNode->joinClauseList; List *joinClauseList = joinOrderNode->joinClauseList;
/* /*
* Build a join node between the top of our join tree and the next * Build a join node between the top of our join tree and the next
* table in the join order. * table in the join order.
* TODO: Don't use linitial(partitionColumnListList)
*/ */
MultiNode *newJoinNode = ApplyJoinRule(currentTopNode, MultiNode *newJoinNode = ApplyJoinRule(currentTopNode,
(MultiNode *) collectNode, (MultiNode *) collectNode,
joinRuleType, partitionColumnList, joinRuleType,
partitionColumnListList,
joinType, joinType,
joinClauseList); joinClauseList);
@ -2025,7 +2029,7 @@ pull_var_clause_default(Node *node)
*/ */
static MultiNode * static MultiNode *
ApplyJoinRule(MultiNode *leftNode, MultiNode *rightNode, JoinRuleType ruleType, ApplyJoinRule(MultiNode *leftNode, MultiNode *rightNode, JoinRuleType ruleType,
List *partitionColumnList, JoinType joinType, List *joinClauseList) List *partitionColumnListList, JoinType joinType, List *joinClauseList)
{ {
List *leftTableIdList = OutputTableIdList(leftNode); List *leftTableIdList = OutputTableIdList(leftNode);
List *rightTableIdList = OutputTableIdList(rightNode); List *rightTableIdList = OutputTableIdList(rightNode);
@ -2041,7 +2045,8 @@ ApplyJoinRule(MultiNode *leftNode, MultiNode *rightNode, JoinRuleType ruleType,
/* call the join rule application function to create the new join node */ /* call the join rule application function to create the new join node */
RuleApplyFunction ruleApplyFunction = JoinRuleApplyFunction(ruleType); RuleApplyFunction ruleApplyFunction = JoinRuleApplyFunction(ruleType);
MultiNode *multiNode = (*ruleApplyFunction)(leftNode, rightNode, partitionColumnList, MultiNode *multiNode = (*ruleApplyFunction)(leftNode, rightNode,
partitionColumnListList,
joinType, applicableJoinClauses); joinType, applicableJoinClauses);
if (joinType != JOIN_INNER && CitusIsA(multiNode, MultiJoin)) if (joinType != JOIN_INNER && CitusIsA(multiNode, MultiJoin))
@ -2096,7 +2101,7 @@ JoinRuleApplyFunction(JoinRuleType ruleType)
*/ */
static MultiNode * static MultiNode *
ApplyReferenceJoin(MultiNode *leftNode, MultiNode *rightNode, ApplyReferenceJoin(MultiNode *leftNode, MultiNode *rightNode,
List *partitionColumnList, JoinType joinType, List *partitionColumnListList, JoinType joinType,
List *applicableJoinClauses) List *applicableJoinClauses)
{ {
MultiJoin *joinNode = CitusMakeNode(MultiJoin); MultiJoin *joinNode = CitusMakeNode(MultiJoin);
@ -2118,7 +2123,7 @@ ApplyReferenceJoin(MultiNode *leftNode, MultiNode *rightNode,
*/ */
static MultiNode * static MultiNode *
ApplyCartesianProductReferenceJoin(MultiNode *leftNode, MultiNode *rightNode, ApplyCartesianProductReferenceJoin(MultiNode *leftNode, MultiNode *rightNode,
List *partitionColumnList, JoinType joinType, List *partitionColumnListList, JoinType joinType,
List *applicableJoinClauses) List *applicableJoinClauses)
{ {
MultiJoin *joinNode = CitusMakeNode(MultiJoin); MultiJoin *joinNode = CitusMakeNode(MultiJoin);
@ -2139,7 +2144,7 @@ ApplyCartesianProductReferenceJoin(MultiNode *leftNode, MultiNode *rightNode,
*/ */
static MultiNode * static MultiNode *
ApplyLocalJoin(MultiNode *leftNode, MultiNode *rightNode, ApplyLocalJoin(MultiNode *leftNode, MultiNode *rightNode,
List *partitionColumnList, JoinType joinType, List *partitionColumnListList, JoinType joinType,
List *applicableJoinClauses) List *applicableJoinClauses)
{ {
MultiJoin *joinNode = CitusMakeNode(MultiJoin); MultiJoin *joinNode = CitusMakeNode(MultiJoin);
@ -2160,11 +2165,11 @@ ApplyLocalJoin(MultiNode *leftNode, MultiNode *rightNode,
*/ */
static MultiNode * static MultiNode *
ApplySingleRangePartitionJoin(MultiNode *leftNode, MultiNode *rightNode, ApplySingleRangePartitionJoin(MultiNode *leftNode, MultiNode *rightNode,
List *partitionColumnList, JoinType joinType, List *partitionColumnListList, JoinType joinType,
List *applicableJoinClauses) List *applicableJoinClauses)
{ {
MultiJoin *joinNode = MultiJoin *joinNode =
ApplySinglePartitionJoin(leftNode, rightNode, partitionColumnList, joinType, ApplySinglePartitionJoin(leftNode, rightNode, partitionColumnListList, joinType,
applicableJoinClauses); applicableJoinClauses);
joinNode->joinRuleType = SINGLE_RANGE_PARTITION_JOIN; joinNode->joinRuleType = SINGLE_RANGE_PARTITION_JOIN;
@ -2179,11 +2184,11 @@ ApplySingleRangePartitionJoin(MultiNode *leftNode, MultiNode *rightNode,
*/ */
static MultiNode * static MultiNode *
ApplySingleHashPartitionJoin(MultiNode *leftNode, MultiNode *rightNode, ApplySingleHashPartitionJoin(MultiNode *leftNode, MultiNode *rightNode,
List *partitionColumnList, JoinType joinType, List *partitionColumnListList, JoinType joinType,
List *applicableJoinClauses) List *applicableJoinClauses)
{ {
MultiJoin *joinNode = MultiJoin *joinNode =
ApplySinglePartitionJoin(leftNode, rightNode, partitionColumnList, joinType, ApplySinglePartitionJoin(leftNode, rightNode, partitionColumnListList, joinType,
applicableJoinClauses); applicableJoinClauses);
joinNode->joinRuleType = SINGLE_HASH_PARTITION_JOIN; joinNode->joinRuleType = SINGLE_HASH_PARTITION_JOIN;
@ -2199,10 +2204,11 @@ ApplySingleHashPartitionJoin(MultiNode *leftNode, MultiNode *rightNode,
*/ */
static MultiJoin * static MultiJoin *
ApplySinglePartitionJoin(MultiNode *leftNode, MultiNode *rightNode, ApplySinglePartitionJoin(MultiNode *leftNode, MultiNode *rightNode,
List *partitionColumnList, JoinType joinType, List *partitionColumnListList, JoinType joinType,
List *applicableJoinClauses) List *applicableJoinClauses)
{ {
Var *partitionColumn = linitial(partitionColumnList); Assert(list_length(partitionColumnListList) == 1);
Var *partitionColumn = linitial(linitial(partitionColumnListList));
uint32 partitionTableId = partitionColumn->varno; uint32 partitionTableId = partitionColumn->varno;
/* create all operator structures up front */ /* create all operator structures up front */
@ -2215,7 +2221,7 @@ ApplySinglePartitionJoin(MultiNode *leftNode, MultiNode *rightNode,
* column against the join clause's columns. If one of the columns matches, * column against the join clause's columns. If one of the columns matches,
* we introduce a (re-)partition operator for the other column. * we introduce a (re-)partition operator for the other column.
*/ */
OpExpr *joinClause = SinglePartitionJoinClause(partitionColumnList, OpExpr *joinClause = SinglePartitionJoinClause(partitionColumnListList,
applicableJoinClauses); applicableJoinClauses);
Assert(joinClause != NULL); Assert(joinClause != NULL);
@ -2339,7 +2345,7 @@ ApplyDualPartitionJoin(MultiNode *leftNode, MultiNode *rightNode,
/* Creates a cartesian product node that joins the left and the right node. */ /* Creates a cartesian product node that joins the left and the right node. */
static MultiNode * static MultiNode *
ApplyCartesianProduct(MultiNode *leftNode, MultiNode *rightNode, ApplyCartesianProduct(MultiNode *leftNode, MultiNode *rightNode,
List *partitionColumnList, JoinType joinType, List *partitionColumnListList, JoinType joinType,
List *applicableJoinClauses) List *applicableJoinClauses)
{ {
MultiCartesianProduct *cartesianNode = CitusMakeNode(MultiCartesianProduct); MultiCartesianProduct *cartesianNode = CitusMakeNode(MultiCartesianProduct);

View File

@ -54,7 +54,6 @@ typedef struct ListCellAndListWrapper
(((var) = lfirst(var ## CellDoNotUse)) || true); \ (((var) = lfirst(var ## CellDoNotUse)) || true); \
var ## CellDoNotUse = lnext_compat(l, var ## CellDoNotUse)) var ## CellDoNotUse = lnext_compat(l, var ## CellDoNotUse))
/* /*
* foreach_int - * foreach_int -
* a convenience macro which loops through an int list without needing a * a convenience macro which loops through an int list without needing a
@ -80,6 +79,35 @@ typedef struct ListCellAndListWrapper
(((var) = lfirst_oid(var ## CellDoNotUse)) || true); \ (((var) = lfirst_oid(var ## CellDoNotUse)) || true); \
var ## CellDoNotUse = lnext_compat(l, var ## CellDoNotUse)) var ## CellDoNotUse = lnext_compat(l, var ## CellDoNotUse))
/*
* forboth_ptr -
* a convenience macro which loops through a pointer list without needing a
* ListCell, just a declared pointer variable to store the pointer of the
* cell in.
*
* How it works:
* - A ListCell is declared with the name {var}CellDoNotUse and used
* throughout the for loop using ## to concat.
* - To assign to var it needs to be done in the condition of the for loop,
* because we cannot use the initializer since a ListCell* variable is
* declared there.
* - || true is used to always enter the loop when cell is not null even if
* var is NULL.
*/
#define forboth_ptr(var1, l1, var2, l2) \
for (ListCell \
*(var1 ## CellDoNotUse) = list_head(l1) \
, *(var2 ## CellDoNotUse) = list_head(l2) \
; \
(var1 ## CellDoNotUse) != NULL \
&& (((var1) = lfirst(var1 ## CellDoNotUse)) || true) \
&& (var2 ## CellDoNotUse) != NULL \
&& (((var2) = lfirst(var2 ## CellDoNotUse)) || true) \
; \
var1 ## CellDoNotUse = lnext_compat(l1, var1 ## CellDoNotUse) \
, var2 ## CellDoNotUse = lnext_compat(l2, var2 ## CellDoNotUse) \
)
/* /*
* foreach_ptr_append - * foreach_ptr_append -
* a convenience macro which loops through a pointer List and can append list * a convenience macro which loops through a pointer List and can append list

View File

@ -75,7 +75,7 @@ typedef struct JoinOrderNode
* We keep track of all unique partition columns in the relation to correctly find * We keep track of all unique partition columns in the relation to correctly find
* join clauses that can be applied locally. * join clauses that can be applied locally.
*/ */
List *partitionColumnList; List *partitionColumnListList;
char partitionMethod; char partitionMethod;
List *joinClauseList; /* not relevant for the first table */ List *joinClauseList; /* not relevant for the first table */

View File

@ -252,51 +252,16 @@ SELECT * FROM (
FROM t2 JOIN t3 ON t2.id = t3.id FROM t2 JOIN t3 ON t2.id = t3.id
) foo ) foo
ORDER BY 1, 2, 3, 4; ORDER BY 1, 2, 3, 4;
id | id2 | a | b ERROR: the query contains a join that requires repartitioning
--------------------------------------------------------------------- HINT: Set citus.enable_repartition_joins to on to enable repartitioning
1 | 1 | 1 | 1
1 | 1 | 2 | 1
1 | 1 | 4 | 1
2 | 3 | 4 | 2
2 | 3 | 4 | 4
2 | 3 | 5 | 2
2 | 3 | 5 | 4
2 | 4 | 5 | 2
2 | 4 | 5 | 4
(9 rows)
EXPLAIN EXPLAIN
SELECT * FROM ( SELECT * FROM (
SELECT t2.id, t2.id2, a, b SELECT t2.id, t2.id2, a, b
FROM t2 JOIN t3 ON t2.id = t3.id FROM t2 JOIN t3 ON t2.id = t3.id
) foo ) foo
ORDER BY 1, 2, 3, 4; ORDER BY 1, 2, 3, 4;
QUERY PLAN ERROR: the query contains a join that requires repartitioning
--------------------------------------------------------------------- HINT: Set citus.enable_repartition_joins to on to enable repartitioning
Custom Scan (Citus Adaptive) (cost=0.00..0.00 rows=0 width=0)
-> Distributed Subplan XXX_1
-> Custom Scan (Citus Adaptive) (cost=0.00..0.00 rows=100000 width=16)
Task Count: 4
Tasks Shown: One of 4
-> Task
Node: host=localhost port=xxxxx dbname=regression
-> Merge Join (cost=285.08..607.40 rows=20808 width=16)
Merge Cond: (t2.id = t3.id)
-> Sort (cost=142.54..147.64 rows=2040 width=12)
Sort Key: t2.id
-> Seq Scan on t2_27905504 t2 (cost=0.00..30.40 rows=2040 width=12)
-> Sort (cost=142.54..147.64 rows=2040 width=8)
Sort Key: t3.id
-> Seq Scan on t3_27905508 t3 (cost=0.00..30.40 rows=2040 width=8)
Task Count: 1
Tasks Shown: All
-> Task
Node: host=localhost port=xxxxx dbname=regression
-> Sort (cost=59.83..62.33 rows=1000 width=16)
Sort Key: intermediate_result.id, intermediate_result.id2, intermediate_result.a, intermediate_result.b
-> Function Scan on read_intermediate_result intermediate_result (cost=0.00..10.00 rows=1000 width=16)
(22 rows)
-- Cannot pushdown if not joining on both distribution columns -- Cannot pushdown if not joining on both distribution columns
SELECT * FROM ( SELECT * FROM (
SELECT t2.id, t2.id2, a, b SELECT t2.id, t2.id2, a, b