diff --git a/src/test/regress/citus_tests/query_generator/README.md b/src/test/regress/citus_tests/query_generator/README.md index b35a96c0a..a59e79542 100644 --- a/src/test/regress/citus_tests/query_generator/README.md +++ b/src/test/regress/citus_tests/query_generator/README.md @@ -53,6 +53,7 @@ targetTables: name: type: distinctCopyCount: + colocateWith: Optional ``` Explanation: @@ -72,6 +73,7 @@ targetTables: "array of tables that will be used in generated queries" name: "name of column" type: "name of data type of column(only support 'int' now)" distinctCopyCount: "how many tables with the same configuration we should create(only by changing full name, still using the same name prefix)" + colocateWith: "colocated_with parameter" ``` @@ -87,6 +89,8 @@ Schema for Query configuration: ```yaml queryCount: queryOutFile: +repartitionJoin: +singleRepartitionJoin: semiAntiJoin: cartesianProduct: limit: @@ -114,6 +118,8 @@ Explanation: ```yaml queryCount: "number of queries to generate" queryOutFile: "file to write generated queries" +repartitionJoin: "should we enable repartition join" +singleRepartitionJoin: "should we make default repartition join mode as single repartition join (default is dual)" semiAntiJoin: "should we support semi joins (WHERE col IN (Subquery))" cartesianProduct: "should we support cartesian joins" limit: "should we support limit clause" @@ -172,7 +178,7 @@ Tool supports following citus table types: targetTables: - Table: ... - citusType: + citusType: ... ``` diff --git a/src/test/regress/citus_tests/query_generator/bin/diff-checker.py b/src/test/regress/citus_tests/query_generator/bin/diff-checker.py index 5bd2898a9..45df94971 100644 --- a/src/test/regress/citus_tests/query_generator/bin/diff-checker.py +++ b/src/test/regress/citus_tests/query_generator/bin/diff-checker.py @@ -73,6 +73,8 @@ def removeFailedQueryOutputFromFiles(distQueryOutFile, localQueryOutFile): acceptableErrors = [ "ERROR: complex joins are only supported when all distributed tables are co-located and joined on their distribution columns", "ERROR: recursive complex joins are only supported when all distributed tables are co-located and joined on their distribution columns", + "ERROR: cannot push down this subquery", + "ERROR: cannot perform a lateral outer join when a distributed subquery references complex subqueries, CTEs or local tables", ] failedDistQueryIds = findFailedQueriesFromFile(distQueryOutFile, acceptableErrors) removeFailedQueryOutputFromFile(distQueryOutFile, failedDistQueryIds) diff --git a/src/test/regress/citus_tests/query_generator/config/config.py b/src/test/regress/citus_tests/query_generator/config/config.py index fec93ef19..6abab84ff 100644 --- a/src/test/regress/citus_tests/query_generator/config/config.py +++ b/src/test/regress/citus_tests/query_generator/config/config.py @@ -28,6 +28,8 @@ class Config: self.targetRteCount = configObj["targetRteCount"] self.targetCteCount = configObj["targetCteCount"] self.targetCteRteCount = configObj["targetCteRteCount"] + self.repartitionJoin = configObj["repartitionJoin"] + self.singleRepartitionJoin = configObj["singleRepartitionJoin"] self.semiAntiJoin = configObj["semiAntiJoin"] self.cartesianProduct = configObj["cartesianProduct"] self.limit = configObj["limit"] @@ -111,8 +113,12 @@ def getMaxAllowedCountForTable(tableName): return filtered[0].maxAllowedUseOnQuery -def isTableDistributed(table): - return table.citusType == CitusType.DISTRIBUTED +def isTableHashDistributed(table): + return table.citusType == CitusType.HASH_DISTRIBUTED + + +def isTableSingleShardDistributed(table): + return table.citusType == CitusType.SINGLE_SHARD_DISTRIBUTED def isTableReference(table): diff --git a/src/test/regress/citus_tests/query_generator/config/config.yaml b/src/test/regress/citus_tests/query_generator/config/config.yaml index 1920966ee..1d1b5138e 100644 --- a/src/test/regress/citus_tests/query_generator/config/config.yaml +++ b/src/test/regress/citus_tests/query_generator/config/config.yaml @@ -2,6 +2,8 @@ interactiveMode: false queryCount: 250 queryOutFile: queries.sql ddlOutFile: ddls.sql +repartitionJoin: true +singleRepartitionJoin: false semiAntiJoin: true cartesianProduct: false limit: true @@ -26,7 +28,7 @@ commonColName: id targetTables: - Table: name: dist - citusType: DISTRIBUTED + citusType: HASH_DISTRIBUTED maxAllowedUseOnQuery: 10 rowCount: 10 nullRate: 0.1 @@ -36,6 +38,19 @@ targetTables: name: id type: int distinctCopyCount: 2 + - Table: + name: single_dist + citusType: SINGLE_SHARD_DISTRIBUTED + maxAllowedUseOnQuery: 10 + rowCount: 10 + nullRate: 0.1 + duplicateRate: 0.1 + columns: + - Column: + name: id + type: int + distinctCopyCount: 2 + colocateWith: none - Table: name: ref citusType: REFERENCE diff --git a/src/test/regress/citus_tests/query_generator/config/config_parser.py b/src/test/regress/citus_tests/query_generator/config/config_parser.py index 0b4f3837f..09a1438b7 100755 --- a/src/test/regress/citus_tests/query_generator/config/config_parser.py +++ b/src/test/regress/citus_tests/query_generator/config/config_parser.py @@ -49,6 +49,7 @@ def parseTable(targetTableDict): col = parseColumn(columnDict) columns.append(col) distinctCopyCount = targetTableDict["distinctCopyCount"] + colocateWith = targetTableDict.get("colocateWith", "default") return Table( name, citusType, @@ -58,6 +59,7 @@ def parseTable(targetTableDict): duplicateRate, columns, distinctCopyCount, + colocateWith, ) diff --git a/src/test/regress/citus_tests/query_generator/data_gen.py b/src/test/regress/citus_tests/query_generator/data_gen.py index 96f7a1366..d48b3ace4 100644 --- a/src/test/regress/citus_tests/query_generator/data_gen.py +++ b/src/test/regress/citus_tests/query_generator/data_gen.py @@ -17,7 +17,10 @@ def getTableData(): dataGenerationSql += "\n" # generate null rows - if not table.citusType == CitusType.DISTRIBUTED: + if table.citusType not in ( + CitusType.HASH_DISTRIBUTED, + CitusType.SINGLE_SHARD_DISTRIBUTED, + ): targetNullRows = int(table.rowCount * table.nullRate) dataGenerationSql += _genNullData(table.name, targetNullRows) dataGenerationSql += "\n" @@ -32,51 +35,23 @@ def getTableData(): def _genOverlappingData(tableName, startVal, rowCount): """returns string to fill table with [startVal,startVal+rowCount] range of integers""" - dataGenerationSql = "" - dataGenerationSql += "INSERT INTO " + tableName - dataGenerationSql += ( - " SELECT i FROM generate_series(" - + str(startVal) - + "," - + str(startVal + rowCount) - + ") i;" - ) - return dataGenerationSql + return f"INSERT INTO {tableName} SELECT i FROM generate_series({startVal}, {startVal + rowCount}) i;" def _genNullData(tableName, nullCount): """returns string to fill table with NULLs""" - dataGenerationSql = "" - dataGenerationSql += "INSERT INTO " + tableName - dataGenerationSql += ( - " SELECT NULL FROM generate_series(0," + str(nullCount) + ") i;" + return ( + f"INSERT INTO {tableName} SELECT NULL FROM generate_series(0, {nullCount}) i;" ) - return dataGenerationSql def _genDupData(tableName, dupRowCount): """returns string to fill table with duplicate integers which are fetched from given table""" - dataGenerationSql = "" - dataGenerationSql += "INSERT INTO " + tableName - dataGenerationSql += ( - " SELECT * FROM " - + tableName - + " ORDER BY " - + getConfig().commonColName - + " LIMIT " - + str(dupRowCount) - + ";" - ) - return dataGenerationSql + return f"INSERT INTO {tableName} SELECT * FROM {tableName} ORDER BY {getConfig().commonColName} LIMIT {dupRowCount};" def _genNonOverlappingData(tableName, startVal, tableIdx): """returns string to fill table with different integers for given table""" startVal = startVal + tableIdx * 20 endVal = startVal + 20 - dataGenerationSql = "" - dataGenerationSql += "INSERT INTO " + tableName - dataGenerationSql += ( - " SELECT i FROM generate_series(" + str(startVal) + "," + str(endVal) + ") i;" - ) - return dataGenerationSql + return f"INSERT INTO {tableName} SELECT i FROM generate_series({startVal}, {endVal}) i;" diff --git a/src/test/regress/citus_tests/query_generator/ddl_gen.py b/src/test/regress/citus_tests/query_generator/ddl_gen.py index b2f97f694..8d7b05950 100755 --- a/src/test/regress/citus_tests/query_generator/ddl_gen.py +++ b/src/test/regress/citus_tests/query_generator/ddl_gen.py @@ -1,4 +1,9 @@ -from config.config import getConfig, isTableDistributed, isTableReference +from config.config import ( + getConfig, + isTableHashDistributed, + isTableReference, + isTableSingleShardDistributed, +) def getTableDDLs(): @@ -12,10 +17,8 @@ def getTableDDLs(): def _genTableDDL(table): ddl = "" - ddl += "DROP TABLE IF EXISTS " + table.name + ";" - ddl += "\n" - - ddl += "CREATE TABLE " + table.name + "(" + ddl += f"DROP TABLE IF EXISTS {table.name};\n" + ddl += f"CREATE TABLE {table.name}(" for column in table.columns[:-1]: ddl += _genColumnDDL(column) ddl += ",\n" @@ -23,26 +26,14 @@ def _genTableDDL(table): ddl += _genColumnDDL(table.columns[-1]) ddl += ");\n" - if isTableDistributed(table): - ddl += ( - "SELECT create_distributed_table(" - + "'" - + table.name - + "','" - + getConfig().commonColName - + "'" - + ");" - ) - ddl += "\n" + if isTableHashDistributed(table): + ddl += f"SELECT create_distributed_table('{table.name}','{getConfig().commonColName}', colocate_with=>'{table.colocateWith}');\n" + if isTableSingleShardDistributed(table): + ddl += f"SELECT create_distributed_table('{table.name}', NULL, colocate_with=>'{table.colocateWith}');\n" elif isTableReference(table): - ddl += "SELECT create_reference_table(" + "'" + table.name + "'" + ");" - ddl += "\n" + ddl += f"SELECT create_reference_table('{table.name}');\n" return ddl def _genColumnDDL(column): - ddl = "" - ddl += column.name - ddl += " " - ddl += column.type - return ddl + return f"{column.name} {column.type}" diff --git a/src/test/regress/citus_tests/query_generator/generate_queries.py b/src/test/regress/citus_tests/query_generator/generate_queries.py index dd63b17ec..2029ace5b 100755 --- a/src/test/regress/citus_tests/query_generator/generate_queries.py +++ b/src/test/regress/citus_tests/query_generator/generate_queries.py @@ -55,8 +55,15 @@ def _fileMode(ddls, data): ) with open(fileName, "w") as f: # enable repartition joins due to https://github.com/citusdata/citus/issues/6865 - enableRepartitionJoinCommand = "SET citus.enable_repartition_joins TO on;\n" - queryLines = [enableRepartitionJoinCommand] + queryLines = [] + if getConfig().repartitionJoin: + enableRepartitionJoinCommand = "SET citus.enable_repartition_joins TO on;\n" + queryLines.append(enableRepartitionJoinCommand) + if getConfig().singleRepartitionJoin: + singleRepartitionJoinCommand = ( + "SET citus.enable_single_hash_repartition_joins TO on;\n" + ) + queryLines.append(singleRepartitionJoinCommand) queryId = 1 for _ in range(queryCount): query = newQuery() diff --git a/src/test/regress/citus_tests/query_generator/node_defs.py b/src/test/regress/citus_tests/query_generator/node_defs.py index b0db1da63..3b8e2acfa 100755 --- a/src/test/regress/citus_tests/query_generator/node_defs.py +++ b/src/test/regress/citus_tests/query_generator/node_defs.py @@ -22,9 +22,10 @@ class RestrictOp(Enum): class CitusType(Enum): - DISTRIBUTED = 1 - REFERENCE = 2 - POSTGRES = 3 + HASH_DISTRIBUTED = 1 + SINGLE_SHARD_DISTRIBUTED = 2 + REFERENCE = 3 + POSTGRES = 4 class Table: @@ -38,6 +39,7 @@ class Table: duplicateRate, columns, distinctCopyCount, + colocateWith, ): self.name = name self.citusType = citusType @@ -47,6 +49,7 @@ class Table: self.duplicateRate = duplicateRate self.columns = columns self.distinctCopyCount = distinctCopyCount + self.colocateWith = colocateWith class Column: diff --git a/src/test/regress/citus_tests/query_generator/query_gen.py b/src/test/regress/citus_tests/query_generator/query_gen.py index e25525d29..bf5da3c05 100644 --- a/src/test/regress/citus_tests/query_generator/query_gen.py +++ b/src/test/regress/citus_tests/query_generator/query_gen.py @@ -114,15 +114,15 @@ class GeneratorContext: def randomCteName(self): """returns a randomly selected cte name""" randCteRef = random.randint(0, self.currentCteCount - 1) - return " cte_" + str(randCteRef) + return f" cte_{randCteRef}" def curAlias(self): """returns current alias name to be used for the current table""" - return " table_" + str(self.totalRteCount) + return f" table_{self.totalRteCount}" def curCteAlias(self): """returns current alias name to be used for the current cte""" - return " cte_" + str(self.currentCteCount) + return f" cte_{self.currentCteCount}" def hasAnyCte(self): """returns if context has any cte""" @@ -153,7 +153,7 @@ class GeneratorContext: # do not enforce per table rte limit if we are inside cte if self.insideCte: rteName = random.choice(getAllTableNames()) - return " " + rteName + " " + return f" {rteName} " while True: # keep trying to find random table by eliminating the ones which hit table limit @@ -173,7 +173,7 @@ class GeneratorContext: # increment rte count for the table name self.perTableRtes[rteName] += 1 - return " " + rteName + " " + return f" {rteName} " def newQuery(): @@ -206,8 +206,7 @@ def _genQuery(genCtx): and not genCtx.usedAvg ): genCtx.usedAvg = True - query += "SELECT " - query += "count(*), avg(avgsub." + getConfig().commonColName + ") FROM " + query += f"SELECT count(*), avg(avgsub.{getConfig().commonColName}) FROM " query += _genSubqueryRte(genCtx) query += " AS avgsub" else: @@ -226,35 +225,24 @@ def _genQuery(genCtx): def _genOrderBy(genCtx): # 'ORDER BY' DistColName - query = "" - query += " ORDER BY " - query += getConfig().commonColName + " " - return query + return f" ORDER BY {getConfig().commonColName} " def _genLimit(genCtx): # 'LIMIT' 'random()' - query = "" - query += " LIMIT " (fromVal, toVal) = getConfig().limitRange - query += str(random.randint(fromVal, toVal)) - return query + return f" LIMIT {random.randint(fromVal, toVal)}" def _genSelectExpr(genCtx): # 'SELECT' 'curAlias()' '.' DistColName - query = "" - query += " SELECT " commonColName = getConfig().commonColName - query += genCtx.curAlias() + "." + commonColName + " " - - return query + return f" SELECT {genCtx.curAlias()}.{commonColName} " def _genFromExpr(genCtx): # 'FROM' (Rte JoinList JoinOp Rte Using || RteList) ['WHERE' 'nextRandomAlias()' '.' DistColName RestrictExpr] - query = "" - query += " FROM " + query = " FROM " if shouldSelectThatBranch(): query += _genRte(genCtx) @@ -267,8 +255,7 @@ def _genFromExpr(genCtx): alias = genCtx.removeLastAlias() if shouldSelectThatBranch(): - query += " WHERE " - query += alias + "." + getConfig().commonColName + query += f" WHERE {alias}.{getConfig().commonColName}" query += _genRestrictExpr(genCtx) return query @@ -353,9 +340,7 @@ def _genJoinList(genCtx): def _genUsing(genCtx): # 'USING' '(' DistColName ')' - query = "" - query += " USING (" + getConfig().commonColName + " ) " - return query + return f" USING ({getConfig().commonColName}) " def _genRte(genCtx): @@ -392,7 +377,7 @@ def _genRte(genCtx): query += _genCteRte(genCtx) elif rteType == RTEType.VALUES: query += _genValuesRte(genCtx) - modifiedAlias = alias + "(" + getConfig().commonColName + ") " + modifiedAlias = f"{alias}({getConfig().commonColName}) " else: raise BaseException("unknown RTE type") @@ -428,7 +413,5 @@ def _genCteRte(genCtx): def _genValuesRte(genCtx): # '( VALUES(random()) )' - query = "" (fromVal, toVal) = getConfig().dataRange - query += " ( VALUES(" + str(random.randint(fromVal, toVal)) + " ) ) " - return query + return f" ( VALUES({random.randint(fromVal, toVal)}) ) " diff --git a/src/test/regress/citus_tests/query_generator/random_selections.py b/src/test/regress/citus_tests/query_generator/random_selections.py index ee32c620a..47f55b3e5 100644 --- a/src/test/regress/citus_tests/query_generator/random_selections.py +++ b/src/test/regress/citus_tests/query_generator/random_selections.py @@ -25,7 +25,7 @@ def randomRteType(): def randomJoinOp(): """returns a randomly selected JoinOp given at config""" joinTypes = getConfig().targetJoinTypes - return " " + random.choice(joinTypes).name + " JOIN" + return f" {random.choice(joinTypes).name} JOIN" def randomRestrictOp(): @@ -42,4 +42,4 @@ def randomRestrictOp(): else: raise BaseException("Unknown restrict op") - return " " + opText + " " + return f" {opText} "