From 71a7f39b05a3dc24d4c9c92ffb82aa955286b7ed Mon Sep 17 00:00:00 2001 From: Andres Freund Date: Mon, 17 Apr 2017 20:21:07 -0700 Subject: [PATCH] Skip exhaustive test in CoPartitionedTables() if declared colocated. That's considerably cheaper. --- .../planner/multi_logical_optimizer.c | 23 +++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/src/backend/distributed/planner/multi_logical_optimizer.c b/src/backend/distributed/planner/multi_logical_optimizer.c index c5028ba9c..00183f201 100644 --- a/src/backend/distributed/planner/multi_logical_optimizer.c +++ b/src/backend/distributed/planner/multi_logical_optimizer.c @@ -27,6 +27,7 @@ #include "commands/extension.h" #include "distributed/citus_nodes.h" #include "distributed/citus_ruleutils.h" +#include "distributed/colocation_utils.h" #include "distributed/metadata_cache.h" #include "distributed/multi_logical_optimizer.h" #include "distributed/multi_logical_planner.h" @@ -3960,10 +3961,8 @@ RelationIdList(Query *query) /* - * CoPartitionedTables checks if given two distributed tables have 1-to-1 shard - * partitioning. It uses shard interval array that are sorted on interval minimum - * values. Then it compares every shard interval in order and if any pair of - * shard intervals are not equal it returns false. + * CoPartitionedTables checks if given two distributed tables have 1-to-1 + * shard partitioning. */ static bool CoPartitionedTables(Oid firstRelationId, Oid secondRelationId) @@ -3992,6 +3991,22 @@ CoPartitionedTables(Oid firstRelationId, Oid secondRelationId) Assert(comparisonFunction != NULL); + /* + * Check if the tables have the same colocation ID - if so, we know + * they're colocated. + */ + if (firstTableCache->colocationId != INVALID_COLOCATION_ID && + firstTableCache->colocationId == secondTableCache->colocationId) + { + return true; + } + + /* + * If not known to be colocated check if the remaining shards are + * anyway. Do so by comparing the shard interval arrays that are sorted on + * interval minimum values. Then it compares every shard interval in order + * and if any pair of shard intervals are not equal it returns false. + */ for (intervalIndex = 0; intervalIndex < firstListShardCount; intervalIndex++) { ShardInterval *firstInterval = sortedFirstIntervalArray[intervalIndex];