diff --git a/src/backend/distributed/planner/multi_physical_planner.c b/src/backend/distributed/planner/multi_physical_planner.c index 1ef7554ac..2a12a8b8a 100644 --- a/src/backend/distributed/planner/multi_physical_planner.c +++ b/src/backend/distributed/planner/multi_physical_planner.c @@ -50,8 +50,10 @@ #include "distributed/pg_dist_partition.h" #include "distributed/pg_dist_shard.h" #include "distributed/query_pushdown_planning.h" +#include "distributed/query_utils.h" #include "distributed/shardinterval_utils.h" #include "distributed/shard_pruning.h" +#include "distributed/string_utils.h" #include "distributed/worker_manager.h" #include "distributed/worker_protocol.h" @@ -225,8 +227,7 @@ static void AssignDataFetchDependencies(List *taskList); static uint32 TaskListHighestTaskId(List *taskList); static List * MapTaskList(MapMergeJob *mapMergeJob, List *filterTaskList); static StringInfo CreateMapQueryString(MapMergeJob *mapMergeJob, Task *filterTask, - char *partitionColumnName); -static char * ColumnName(Var *column, List *rangeTableList); + uint32 partitionColumnIndex); static List * MergeTaskList(MapMergeJob *mapMergeJob, List *mapTaskList, uint32 taskIdIndex); static StringInfo ColumnNameArrayString(uint32 columnCount, uint64 generatingJobId); @@ -237,6 +238,7 @@ static bool CoPlacedShardIntervals(ShardInterval *firstInterval, static List * FetchEqualityAttrNumsForRTEOpExpr(OpExpr *opExpr); static List * FetchEqualityAttrNumsForRTEBoolExpr(BoolExpr *boolExpr); static List * FetchEqualityAttrNumsForList(List *nodeList); +static int PartitionColumnIndex(Var *targetVar, List *targetList); #if PG_VERSION_NUM >= PG_VERSION_13 static List * GetColumnOriginalIndexes(Oid relationId); #endif @@ -4477,11 +4479,10 @@ MapTaskList(MapMergeJob *mapMergeJob, List *filterTaskList) { List *mapTaskList = NIL; Query *filterQuery = mapMergeJob->job.jobQuery; - List *rangeTableList = filterQuery->rtable; ListCell *filterTaskCell = NULL; Var *partitionColumn = mapMergeJob->partitionColumn; - char *partitionColumnName = NULL; + uint32 partitionColumnResNo = 0; List *groupClauseList = filterQuery->groupClause; if (groupClauseList != NIL) { @@ -4490,29 +4491,19 @@ MapTaskList(MapMergeJob *mapMergeJob, List *filterTaskList) targetEntryList); TargetEntry *groupByTargetEntry = (TargetEntry *) linitial(groupTargetEntryList); - partitionColumnName = groupByTargetEntry->resname; + partitionColumnResNo = groupByTargetEntry->resno; } else { - TargetEntry *targetEntry = tlist_member((Expr *) partitionColumn, - filterQuery->targetList); - if (targetEntry != NULL) - { - /* targetEntry->resname may be NULL */ - partitionColumnName = targetEntry->resname; - } - - if (partitionColumnName == NULL) - { - partitionColumnName = ColumnName(partitionColumn, rangeTableList); - } + partitionColumnResNo = PartitionColumnIndex(partitionColumn, + filterQuery->targetList); } foreach(filterTaskCell, filterTaskList) { Task *filterTask = (Task *) lfirst(filterTaskCell); StringInfo mapQueryString = CreateMapQueryString(mapMergeJob, filterTask, - partitionColumnName); + partitionColumnResNo); /* convert filter query task into map task */ Task *mapTask = filterTask; @@ -4526,12 +4517,40 @@ MapTaskList(MapMergeJob *mapMergeJob, List *filterTaskList) } +/* + * PartitionColumnIndex finds the index of the given target var. 
+ */
+static int
+PartitionColumnIndex(Var *targetVar, List *targetList)
+{
+	TargetEntry *targetEntry = NULL;
+	int resNo = 1;
+	foreach_ptr(targetEntry, targetList)
+	{
+		if (IsA(targetEntry->expr, Var))
+		{
+			Var *candidateVar = (Var *) targetEntry->expr;
+			if (candidateVar->varattno == targetVar->varattno &&
+				candidateVar->varno == targetVar->varno)
+			{
+				return resNo;
+			}
+			resNo++;
+		}
+	}
+
+	ereport(ERROR, (errmsg("unexpected state: could not find partition column "
+						   "(varno %d, varattno %d)", targetVar->varno, targetVar->varattno)));
+	return resNo;
+}
+
+
 /*
  * CreateMapQueryString creates and returns the map query string for the given filterTask.
  */
 static StringInfo
 CreateMapQueryString(MapMergeJob *mapMergeJob, Task *filterTask,
-					 char *partitionColumnName)
+					 uint32 partitionColumnIndex)
 {
 	uint64 jobId = filterTask->jobId;
 	uint32 taskId = filterTask->taskId;
@@ -4577,8 +4596,9 @@ CreateMapQueryString(MapMergeJob *mapMergeJob, Task *filterTask,
 		partitionCommand = HASH_PARTITION_COMMAND;
 	}
 
+	char *partitionColumnIndexText = ConvertIntToString(partitionColumnIndex);
 	appendStringInfo(mapQueryString, partitionCommand, jobId, taskId,
-					 filterQueryEscapedText, partitionColumnName,
+					 filterQueryEscapedText, partitionColumnIndexText,
 					 partitionColumnTypeFullName, splitPointString->data);
 	return mapQueryString;
 }
@@ -4674,40 +4694,6 @@ RowModifyLevelForQuery(Query *query)
 }
 
 
-/*
- * ColumnName resolves the given column's name. The given column could belong to
- * a regular table or to an intermediate table formed to execute a distributed
- * query.
- */
-static char *
-ColumnName(Var *column, List *rangeTableList)
-{
-	char *columnName = NULL;
-	Index tableId = column->varno;
-	AttrNumber columnNumber = column->varattno;
-	RangeTblEntry *rangeTableEntry = rt_fetch(tableId, rangeTableList);
-
-	CitusRTEKind rangeTableKind = GetRangeTblKind(rangeTableEntry);
-	if (rangeTableKind == CITUS_RTE_REMOTE_QUERY)
-	{
-		Alias *referenceNames = rangeTableEntry->eref;
-		List *columnNameList = referenceNames->colnames;
-		int columnIndex = columnNumber - 1;
-
-		Value *columnValue = (Value *) list_nth(columnNameList, columnIndex);
-		columnName = strVal(columnValue);
-	}
-	else if (rangeTableKind == CITUS_RTE_RELATION)
-	{
-		Oid relationId = rangeTableEntry->relid;
-		columnName = get_attname(relationId, columnNumber, false);
-	}
-
-	Assert(columnName != NULL);
-	return columnName;
-}
-
-
 /*
  * ArrayObjectToString converts an SQL object to its string representation.
*/ diff --git a/src/backend/distributed/sql/citus--9.5-1--10.0-1.sql b/src/backend/distributed/sql/citus--9.5-1--10.0-1.sql index c305fec8d..f770bb57e 100644 --- a/src/backend/distributed/sql/citus--9.5-1--10.0-1.sql +++ b/src/backend/distributed/sql/citus--9.5-1--10.0-1.sql @@ -67,4 +67,3 @@ DROP FUNCTION pg_catalog.master_create_worker_shards(text, integer, integer); DROP FUNCTION pg_catalog.mark_tables_colocated(regclass, regclass[]); #include "udfs/citus_shard_sizes/10.0-1.sql" #include "udfs/citus_shards/10.0-1.sql" - diff --git a/src/backend/distributed/utils/query_utils.c b/src/backend/distributed/utils/query_utils.c index f4741e36c..4ae49ed81 100644 --- a/src/backend/distributed/utils/query_utils.c +++ b/src/backend/distributed/utils/query_utils.c @@ -11,16 +11,17 @@ */ #include "postgres.h" +#include "nodes/primnodes.h" #include "catalog/pg_class.h" #include "distributed/query_utils.h" #include "distributed/version_compat.h" +#include "distributed/listutils.h" #include "nodes/nodeFuncs.h" static bool CitusQueryableRangeTableRelation(RangeTblEntry *rangeTableEntry); - /* * ExtractRangeTableList walks over a tree to gather entries. * Execution is parameterized by passing walkerMode flag via ExtractRangeTableWalkerContext diff --git a/src/backend/distributed/worker/worker_partition_protocol.c b/src/backend/distributed/worker/worker_partition_protocol.c index 671e8e04c..0da3a5bb0 100644 --- a/src/backend/distributed/worker/worker_partition_protocol.c +++ b/src/backend/distributed/worker/worker_partition_protocol.c @@ -70,7 +70,8 @@ static void RenameDirectory(StringInfo oldDirectoryName, StringInfo newDirectory static void FileOutputStreamWrite(FileOutputStream *file, StringInfo dataToWrite); static void FileOutputStreamFlush(FileOutputStream *file); static void FilterAndPartitionTable(const char *filterQuery, - const char *columnName, Oid columnType, + char *partitionColumnName, + int partitionColumnIndex, Oid columnType, PartitionIdFunction partitionIdFunction, const void *partitionIdContext, FileOutputStream *partitionFileArray, @@ -86,7 +87,9 @@ static uint32 HashPartitionId(Datum partitionValue, Oid partitionCollation, const void *context); static StringInfo UserPartitionFilename(StringInfo directoryName, uint32 partitionId); static bool FileIsLink(const char *filename, struct stat filestat); - +static void PartitionColumnIndexOrPartitionColumnName(char *partitionColumnNameCandidate, + char **partitionColumnName, + uint32 *partitionColumnIndex); /* exports for SQL callable functions */ PG_FUNCTION_INFO_V1(worker_range_partition_table); @@ -110,12 +113,19 @@ worker_range_partition_table(PG_FUNCTION_ARGS) uint32 taskId = PG_GETARG_UINT32(1); text *filterQueryText = PG_GETARG_TEXT_P(2); text *partitionColumnText = PG_GETARG_TEXT_P(3); + char *partitionColumnNameCandidate = text_to_cstring(partitionColumnText); + + char *partitionColumnName = NULL; + uint32 partitionColumnIndex = 0; + PartitionColumnIndexOrPartitionColumnName(partitionColumnNameCandidate, + &partitionColumnName, + &partitionColumnIndex); + Oid partitionColumnType = PG_GETARG_OID(4); ArrayType *splitPointObject = PG_GETARG_ARRAYTYPE_P(5); - const char *filterQuery = text_to_cstring(filterQueryText); - const char *partitionColumn = text_to_cstring(partitionColumnText); + const char *filterQuery = text_to_cstring(filterQueryText); /* first check that array element's and partition column's types match */ Oid splitPointType = ARR_ELEMTYPE(splitPointObject); @@ -152,7 +162,8 @@ 
worker_range_partition_table(PG_FUNCTION_ARGS)
 	FileBufferSizeInBytes = FileBufferSize(PartitionBufferSize, fileCount);
 
 	/* call the partitioning function that does the actual work */
-	FilterAndPartitionTable(filterQuery, partitionColumn, partitionColumnType,
+	FilterAndPartitionTable(filterQuery, partitionColumnName, partitionColumnIndex,
+							partitionColumnType,
 							&RangePartitionId, (const void *) partitionContext,
 							partitionFileArray, fileCount);
 
@@ -160,7 +171,6 @@ worker_range_partition_table(PG_FUNCTION_ARGS)
 	ClosePartitionFiles(partitionFileArray, fileCount);
 	CitusRemoveDirectory(taskDirectory->data);
 	RenameDirectory(taskAttemptDirectory, taskDirectory);
-
 	PG_RETURN_VOID();
 }
 
@@ -182,11 +192,19 @@ worker_hash_partition_table(PG_FUNCTION_ARGS)
 	uint32 taskId = PG_GETARG_UINT32(1);
 	text *filterQueryText = PG_GETARG_TEXT_P(2);
 	text *partitionColumnText = PG_GETARG_TEXT_P(3);
+	char *partitionColumnNameCandidate = text_to_cstring(partitionColumnText);
+
+	char *partitionColumnName = NULL;
+	uint32 partitionColumnIndex = 0;
+	PartitionColumnIndexOrPartitionColumnName(partitionColumnNameCandidate,
+											  &partitionColumnName,
+											  &partitionColumnIndex);
+
 	Oid partitionColumnType = PG_GETARG_OID(4);
 	ArrayType *hashRangeObject = PG_GETARG_ARRAYTYPE_P(5);
+
 	const char *filterQuery = text_to_cstring(filterQueryText);
-	const char *partitionColumn = text_to_cstring(partitionColumnText);
 
 	Datum *hashRangeArray = DeconstructArrayObject(hashRangeObject);
 	int32 partitionCount = ArrayObjectCount(hashRangeObject);
@@ -226,7 +244,8 @@ worker_hash_partition_table(PG_FUNCTION_ARGS)
 	FileBufferSizeInBytes = FileBufferSize(PartitionBufferSize, fileCount);
 
 	/* call the partitioning function that does the actual work */
-	FilterAndPartitionTable(filterQuery, partitionColumn, partitionColumnType,
+	FilterAndPartitionTable(filterQuery, partitionColumnName, partitionColumnIndex,
+							partitionColumnType,
 							&HashPartitionId, (const void *) partitionContext,
 							partitionFileArray, fileCount);
 
@@ -234,11 +253,43 @@ worker_hash_partition_table(PG_FUNCTION_ARGS)
 	ClosePartitionFiles(partitionFileArray, fileCount);
 	CitusRemoveDirectory(taskDirectory->data);
 	RenameDirectory(taskAttemptDirectory, taskDirectory);
-
 	PG_RETURN_VOID();
 }
 
 
+/*
+ * PartitionColumnIndexOrPartitionColumnName sets either partitionColumnName or
+ * partitionColumnIndex, depending on what the given text holds. See below for why.
+ */
+static void
+PartitionColumnIndexOrPartitionColumnName(char *partitionColumnNameCandidate,
+										  char **partitionColumnName,
+										  uint32 *partitionColumnIndex)
+{
+	char *endptr = NULL;
+	uint32 partitionColumnIndexCandidate =
+		strtoul(partitionColumnNameCandidate, &endptr, 10 /*base*/);
+	if (endptr == partitionColumnNameCandidate)
+	{
+		/*
+		 * There was a bug around using the column name in the
+		 * worker_[hash|range]_partition_table APIs, and one fix considered was
+		 * to send the partition column index directly to these APIs. However,
+		 * that would change the API signature and complicate upgrade paths.
+		 * Instead, we send the partition column index as text. During a rolling
+		 * upgrade, when a worker is upgraded but the coordinator is not, the
+		 * text may still hold a column name, which we detect as a parse error here.
+ * + */ + *partitionColumnName = partitionColumnNameCandidate; + } + else + { + *partitionColumnIndex = partitionColumnIndexCandidate; + } +} + + /* * SyntheticShardIntervalArrayForShardMinValues returns a shard interval pointer array * which gets the shardMinValues from the input shardMinValues array. Note that @@ -845,14 +896,14 @@ FileOutputStreamFlush(FileOutputStream *file) */ static void FilterAndPartitionTable(const char *filterQuery, - const char *partitionColumnName, Oid partitionColumnType, + char *partitionColumnName, + int partitionColumnIndex, Oid partitionColumnType, PartitionIdFunction partitionIdFunction, const void *partitionIdContext, FileOutputStream *partitionFileArray, uint32 fileCount) { FmgrInfo *columnOutputFunctions = NULL; - int partitionColumnIndex = 0; Oid partitionColumnTypeId = InvalidOid; Oid partitionColumnCollation = InvalidOid; @@ -888,8 +939,14 @@ FilterAndPartitionTable(const char *filterQuery, { ereport(ERROR, (errmsg("no partition to read into"))); } - - partitionColumnIndex = ColumnIndex(rowDescriptor, partitionColumnName); + if (partitionColumnName != NULL) + { + /* + * in old API, the partition column name is used + * to determine partitionColumnIndex + */ + partitionColumnIndex = ColumnIndex(rowDescriptor, partitionColumnName); + } partitionColumnTypeId = SPI_gettypeid(rowDescriptor, partitionColumnIndex); partitionColumnCollation = TupleDescAttr(rowDescriptor, partitionColumnIndex - 1)->attcollation; diff --git a/src/include/distributed/query_utils.h b/src/include/distributed/query_utils.h index 5e4a55bf7..7e1ba54e6 100644 --- a/src/include/distributed/query_utils.h +++ b/src/include/distributed/query_utils.h @@ -13,6 +13,7 @@ #include "postgres.h" #include "nodes/pg_list.h" +#include "nodes/primnodes.h" /* Enum to define execution flow of ExtractRangeTableList */ typedef enum ExtractRangeTableMode diff --git a/src/test/regress/expected/adaptive_executor_repartition.out b/src/test/regress/expected/adaptive_executor_repartition.out index a70e97688..c5b583bef 100644 --- a/src/test/regress/expected/adaptive_executor_repartition.out +++ b/src/test/regress/expected/adaptive_executor_repartition.out @@ -117,9 +117,58 @@ WHERE (7 rows) SET citus.enable_single_hash_repartition_joins TO OFF; +--issue 4315 +create table cars (car_id int); +insert into cars select s from generate_series(1,10) s; +create table trips (trip_id int, car_id int); +insert into trips select s % 10, s % 11 from generate_series(1, 100) s; +-- the result of this should be the same when the tables are distributed +select count(*) from trips t1, cars r1, trips t2, cars r2 where t1.trip_id = t2.trip_id and t1.car_id = r1.car_id and t2.car_id = r2.car_id; + count +--------------------------------------------------------------------- + 829 +(1 row) + +select create_distributed_table('trips', 'trip_id'); +NOTICE: Copying data from local table... +NOTICE: copying the data has completed +DETAIL: The local data in the table is no longer visible, but is still on disk. +HINT: To remove the local data, run: SELECT truncate_local_data_after_distributing_table($$adaptive_executor.trips$$) + create_distributed_table +--------------------------------------------------------------------- + +(1 row) + +select create_distributed_table('cars', 'car_id'); +NOTICE: Copying data from local table... +NOTICE: copying the data has completed +DETAIL: The local data in the table is no longer visible, but is still on disk. 
+HINT: To remove the local data, run: SELECT truncate_local_data_after_distributing_table($$adaptive_executor.cars$$) + create_distributed_table +--------------------------------------------------------------------- + +(1 row) + +set citus.enable_repartition_joins to on; +set citus.enable_single_hash_repartition_joins to off; +select count(*) from trips t1, cars r1, trips t2, cars r2 where t1.trip_id = t2.trip_id and t1.car_id = r1.car_id and t2.car_id = r2.car_id; + count +--------------------------------------------------------------------- + 829 +(1 row) + +set citus.enable_single_hash_repartition_joins to on; +select count(*) from trips t1, cars r1, trips t2, cars r2 where t1.trip_id = t2.trip_id and t1.car_id = r1.car_id and t2.car_id = r2.car_id; + count +--------------------------------------------------------------------- + 829 +(1 row) + DROP SCHEMA adaptive_executor CASCADE; -NOTICE: drop cascades to 4 other objects +NOTICE: drop cascades to 6 other objects DETAIL: drop cascades to table ab drop cascades to table single_hash_repartition_first drop cascades to table single_hash_repartition_second drop cascades to table ref_table +drop cascades to table cars +drop cascades to table trips diff --git a/src/test/regress/sql/adaptive_executor_repartition.sql b/src/test/regress/sql/adaptive_executor_repartition.sql index fb9c5a208..bb625ae6f 100644 --- a/src/test/regress/sql/adaptive_executor_repartition.sql +++ b/src/test/regress/sql/adaptive_executor_repartition.sql @@ -60,4 +60,24 @@ WHERE SET citus.enable_single_hash_repartition_joins TO OFF; +--issue 4315 +create table cars (car_id int); +insert into cars select s from generate_series(1,10) s; + +create table trips (trip_id int, car_id int); +insert into trips select s % 10, s % 11 from generate_series(1, 100) s; + +-- the result of this should be the same when the tables are distributed +select count(*) from trips t1, cars r1, trips t2, cars r2 where t1.trip_id = t2.trip_id and t1.car_id = r1.car_id and t2.car_id = r2.car_id; + +select create_distributed_table('trips', 'trip_id'); +select create_distributed_table('cars', 'car_id'); + +set citus.enable_repartition_joins to on; +set citus.enable_single_hash_repartition_joins to off; +select count(*) from trips t1, cars r1, trips t2, cars r2 where t1.trip_id = t2.trip_id and t1.car_id = r1.car_id and t2.car_id = r2.car_id; + +set citus.enable_single_hash_repartition_joins to on; +select count(*) from trips t1, cars r1, trips t2, cars r2 where t1.trip_id = t2.trip_id and t1.car_id = r1.car_id and t2.car_id = r2.car_id; + DROP SCHEMA adaptive_executor CASCADE;
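
Reviewer note: the patch keeps the worker_[hash|range]_partition_table signatures
unchanged and passes the partition column index through the existing text argument,
falling back to name lookup when parsing fails. The standalone C program below is a
minimal sketch of just that dispatch, assuming only the standard library; the names
(IndexOrName, etc.) are hypothetical stand-ins, not the Citus implementation.

#include <stdio.h>
#include <stdlib.h>

/*
 * If candidate parses as an unsigned integer, treat it as a column index
 * (new coordinator); if strtoul consumes no characters, treat it as a
 * column name (old coordinator during a rolling upgrade).
 */
static void
IndexOrName(const char *candidate, const char **name, unsigned long *index)
{
	char *endptr = NULL;
	unsigned long parsed = strtoul(candidate, &endptr, 10);

	if (endptr == candidate)
	{
		/* no digits consumed: an old coordinator sent a column name */
		*name = candidate;
	}
	else
	{
		/* digits consumed: a new coordinator sent a column index */
		*index = parsed;
	}
}

int
main(void)
{
	const char *name = NULL;
	unsigned long index = 0;

	IndexOrName("2", &name, &index);
	printf("\"2\"      -> index=%lu, name=%s\n", index, name ? name : "(none)");

	name = NULL;
	index = 0;
	IndexOrName("car_id", &name, &index);
	printf("\"car_id\" -> index=%lu, name=%s\n", index, name ? name : "(none)");

	return 0;
}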
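
A second sketch of why MapTaskList now resolves the partition column by position
rather than by resname: in the issue-4315 query, t1.car_id and r1.car_id produce
target entries with identical names (and resname can even be NULL), so a name scan
can bind to the wrong column, while the (varno, varattno) pair is unambiguous. The
struct below is a hypothetical, simplified stand-in for PostgreSQL's
TargetEntry/Var; the real PartitionColumnIndex additionally counts only plain Vars.

#include <stdio.h>

/* simplified stand-in for a TargetEntry wrapping a Var */
typedef struct
{
	int varno;           /* which range table entry the column comes from */
	int varattno;        /* attribute number within that relation */
	const char *resname; /* display name; may repeat or be NULL */
} Entry;

/* return the 1-based position of the (varno, varattno) pair, or -1 */
static int
IndexByVar(const Entry *list, int count, int varno, int varattno)
{
	int resNo = 1;
	for (int i = 0; i < count; i++)
	{
		if (list[i].varno == varno && list[i].varattno == varattno)
		{
			return resNo;
		}
		resNo++;
	}
	return -1; /* the real code errors out instead */
}

int
main(void)
{
	/* t1.car_id and r1.car_id: same display name, different origin */
	Entry targetList[] = {
		{ 1, 2, "car_id" },
		{ 2, 1, "car_id" },
	};

	/* a resname scan would match position 1 for either column, */
	/* but the (varno, varattno) pair distinguishes them: */
	printf("t1.car_id -> %d\n", IndexByVar(targetList, 2, 1, 2)); /* prints 1 */
	printf("r1.car_id -> %d\n", IndexByVar(targetList, 2, 2, 1)); /* prints 2 */

	return 0;
}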