diff --git a/src/test/regress/citus_tests/query_generator/README.md b/src/test/regress/citus_tests/query_generator/README.md
index b35a96c0a..a59e79542 100644
--- a/src/test/regress/citus_tests/query_generator/README.md
+++ b/src/test/regress/citus_tests/query_generator/README.md
@@ -53,6 +53,7 @@ targetTables:
name:
type:
distinctCopyCount:
+ colocateWith: Optional
```
Explanation:
@@ -72,6 +73,7 @@ targetTables: "array of tables that will be used in generated queries"
name: "name of column"
type: "name of data type of column(only support 'int' now)"
distinctCopyCount: "how many tables with the same configuration we should create(only by changing full name, still using the same name prefix)"
+ colocateWith: "value passed as the colocate_with parameter of create_distributed_table (optional; 'default' is used when omitted)"
```
@@ -87,6 +89,8 @@ Schema for Query configuration:
```yaml
queryCount:
queryOutFile:
+repartitionJoin:
+singleRepartitionJoin:
semiAntiJoin:
cartesianProduct:
limit:
@@ -114,6 +118,8 @@ Explanation:
```yaml
queryCount: "number of queries to generate"
queryOutFile: "file to write generated queries"
+repartitionJoin: "should we enable repartition joins (emits SET citus.enable_repartition_joins TO on)"
+singleRepartitionJoin: "should we use single hash repartition joins instead of the default dual repartition joins (emits SET citus.enable_single_hash_repartition_joins TO on)"
semiAntiJoin: "should we support semi joins (WHERE col IN (Subquery))"
cartesianProduct: "should we support cartesian joins"
limit: "should we support limit clause"
@@ -172,7 +178,7 @@ Tool supports following citus table types:
targetTables:
- Table:
...
- citusType:
+ citusType:
...
```
diff --git a/src/test/regress/citus_tests/query_generator/bin/diff-checker.py b/src/test/regress/citus_tests/query_generator/bin/diff-checker.py
index 5bd2898a9..45df94971 100644
--- a/src/test/regress/citus_tests/query_generator/bin/diff-checker.py
+++ b/src/test/regress/citus_tests/query_generator/bin/diff-checker.py
@@ -73,6 +73,8 @@ def removeFailedQueryOutputFromFiles(distQueryOutFile, localQueryOutFile):
acceptableErrors = [
"ERROR: complex joins are only supported when all distributed tables are co-located and joined on their distribution columns",
"ERROR: recursive complex joins are only supported when all distributed tables are co-located and joined on their distribution columns",
+ "ERROR: cannot push down this subquery",
+ "ERROR: cannot perform a lateral outer join when a distributed subquery references complex subqueries, CTEs or local tables",
]
failedDistQueryIds = findFailedQueriesFromFile(distQueryOutFile, acceptableErrors)
removeFailedQueryOutputFromFile(distQueryOutFile, failedDistQueryIds)
diff --git a/src/test/regress/citus_tests/query_generator/config/config.py b/src/test/regress/citus_tests/query_generator/config/config.py
index fec93ef19..6abab84ff 100644
--- a/src/test/regress/citus_tests/query_generator/config/config.py
+++ b/src/test/regress/citus_tests/query_generator/config/config.py
@@ -28,6 +28,8 @@ class Config:
self.targetRteCount = configObj["targetRteCount"]
self.targetCteCount = configObj["targetCteCount"]
self.targetCteRteCount = configObj["targetCteRteCount"]
+ self.repartitionJoin = configObj["repartitionJoin"]
+ self.singleRepartitionJoin = configObj["singleRepartitionJoin"]
self.semiAntiJoin = configObj["semiAntiJoin"]
self.cartesianProduct = configObj["cartesianProduct"]
self.limit = configObj["limit"]
@@ -111,8 +113,12 @@ def getMaxAllowedCountForTable(tableName):
return filtered[0].maxAllowedUseOnQuery
-def isTableDistributed(table):
- return table.citusType == CitusType.DISTRIBUTED
+def isTableHashDistributed(table):
+ return table.citusType == CitusType.HASH_DISTRIBUTED
+
+
+def isTableSingleShardDistributed(table):
+ return table.citusType == CitusType.SINGLE_SHARD_DISTRIBUTED
def isTableReference(table):
diff --git a/src/test/regress/citus_tests/query_generator/config/config.yaml b/src/test/regress/citus_tests/query_generator/config/config.yaml
index 1920966ee..1d1b5138e 100644
--- a/src/test/regress/citus_tests/query_generator/config/config.yaml
+++ b/src/test/regress/citus_tests/query_generator/config/config.yaml
@@ -2,6 +2,8 @@ interactiveMode: false
queryCount: 250
queryOutFile: queries.sql
ddlOutFile: ddls.sql
+repartitionJoin: true
+singleRepartitionJoin: false
semiAntiJoin: true
cartesianProduct: false
limit: true
@@ -26,7 +28,7 @@ commonColName: id
targetTables:
- Table:
name: dist
- citusType: DISTRIBUTED
+ citusType: HASH_DISTRIBUTED
maxAllowedUseOnQuery: 10
rowCount: 10
nullRate: 0.1
@@ -36,6 +38,19 @@ targetTables:
name: id
type: int
distinctCopyCount: 2
+ - Table:
+ name: single_dist
+ citusType: SINGLE_SHARD_DISTRIBUTED
+ maxAllowedUseOnQuery: 10
+ rowCount: 10
+ nullRate: 0.1
+ duplicateRate: 0.1
+ columns:
+ - Column:
+ name: id
+ type: int
+ distinctCopyCount: 2
+ colocateWith: none
- Table:
name: ref
citusType: REFERENCE
diff --git a/src/test/regress/citus_tests/query_generator/config/config_parser.py b/src/test/regress/citus_tests/query_generator/config/config_parser.py
index 0b4f3837f..09a1438b7 100755
--- a/src/test/regress/citus_tests/query_generator/config/config_parser.py
+++ b/src/test/regress/citus_tests/query_generator/config/config_parser.py
@@ -49,6 +49,7 @@ def parseTable(targetTableDict):
col = parseColumn(columnDict)
columns.append(col)
distinctCopyCount = targetTableDict["distinctCopyCount"]
+ colocateWith = targetTableDict.get("colocateWith", "default")
return Table(
name,
citusType,
@@ -58,6 +59,7 @@ def parseTable(targetTableDict):
duplicateRate,
columns,
distinctCopyCount,
+ colocateWith,
)
diff --git a/src/test/regress/citus_tests/query_generator/data_gen.py b/src/test/regress/citus_tests/query_generator/data_gen.py
index 96f7a1366..d48b3ace4 100644
--- a/src/test/regress/citus_tests/query_generator/data_gen.py
+++ b/src/test/regress/citus_tests/query_generator/data_gen.py
@@ -17,7 +17,10 @@ def getTableData():
dataGenerationSql += "\n"
# generate null rows
- if not table.citusType == CitusType.DISTRIBUTED:
+ if table.citusType not in (
+ CitusType.HASH_DISTRIBUTED,
+ CitusType.SINGLE_SHARD_DISTRIBUTED,
+ ):
targetNullRows = int(table.rowCount * table.nullRate)
dataGenerationSql += _genNullData(table.name, targetNullRows)
dataGenerationSql += "\n"
@@ -32,51 +35,23 @@ def getTableData():
def _genOverlappingData(tableName, startVal, rowCount):
"""returns string to fill table with [startVal,startVal+rowCount] range of integers"""
- dataGenerationSql = ""
- dataGenerationSql += "INSERT INTO " + tableName
- dataGenerationSql += (
- " SELECT i FROM generate_series("
- + str(startVal)
- + ","
- + str(startVal + rowCount)
- + ") i;"
- )
- return dataGenerationSql
+ return f"INSERT INTO {tableName} SELECT i FROM generate_series({startVal}, {startVal + rowCount}) i;"
def _genNullData(tableName, nullCount):
"""returns string to fill table with NULLs"""
- dataGenerationSql = ""
- dataGenerationSql += "INSERT INTO " + tableName
- dataGenerationSql += (
- " SELECT NULL FROM generate_series(0," + str(nullCount) + ") i;"
+ return (
+ f"INSERT INTO {tableName} SELECT NULL FROM generate_series(0, {nullCount}) i;"
)
- return dataGenerationSql
def _genDupData(tableName, dupRowCount):
"""returns string to fill table with duplicate integers which are fetched from given table"""
- dataGenerationSql = ""
- dataGenerationSql += "INSERT INTO " + tableName
- dataGenerationSql += (
- " SELECT * FROM "
- + tableName
- + " ORDER BY "
- + getConfig().commonColName
- + " LIMIT "
- + str(dupRowCount)
- + ";"
- )
- return dataGenerationSql
+ return f"INSERT INTO {tableName} SELECT * FROM {tableName} ORDER BY {getConfig().commonColName} LIMIT {dupRowCount};"
def _genNonOverlappingData(tableName, startVal, tableIdx):
"""returns string to fill table with different integers for given table"""
startVal = startVal + tableIdx * 20
endVal = startVal + 20
- dataGenerationSql = ""
- dataGenerationSql += "INSERT INTO " + tableName
- dataGenerationSql += (
- " SELECT i FROM generate_series(" + str(startVal) + "," + str(endVal) + ") i;"
- )
- return dataGenerationSql
+ return f"INSERT INTO {tableName} SELECT i FROM generate_series({startVal}, {endVal}) i;"
diff --git a/src/test/regress/citus_tests/query_generator/ddl_gen.py b/src/test/regress/citus_tests/query_generator/ddl_gen.py
index b2f97f694..8d7b05950 100755
--- a/src/test/regress/citus_tests/query_generator/ddl_gen.py
+++ b/src/test/regress/citus_tests/query_generator/ddl_gen.py
@@ -1,4 +1,9 @@
-from config.config import getConfig, isTableDistributed, isTableReference
+from config.config import (
+ getConfig,
+ isTableHashDistributed,
+ isTableReference,
+ isTableSingleShardDistributed,
+)
def getTableDDLs():
@@ -12,10 +17,8 @@ def getTableDDLs():
def _genTableDDL(table):
ddl = ""
- ddl += "DROP TABLE IF EXISTS " + table.name + ";"
- ddl += "\n"
-
- ddl += "CREATE TABLE " + table.name + "("
+ ddl += f"DROP TABLE IF EXISTS {table.name};\n"
+ ddl += f"CREATE TABLE {table.name}("
for column in table.columns[:-1]:
ddl += _genColumnDDL(column)
ddl += ",\n"
@@ -23,26 +26,14 @@ def _genTableDDL(table):
ddl += _genColumnDDL(table.columns[-1])
ddl += ");\n"
- if isTableDistributed(table):
- ddl += (
- "SELECT create_distributed_table("
- + "'"
- + table.name
- + "','"
- + getConfig().commonColName
- + "'"
- + ");"
- )
- ddl += "\n"
+ if isTableHashDistributed(table):
+ ddl += f"SELECT create_distributed_table('{table.name}','{getConfig().commonColName}', colocate_with=>'{table.colocateWith}');\n"
+ if isTableSingleShardDistributed(table):
+ ddl += f"SELECT create_distributed_table('{table.name}', NULL, colocate_with=>'{table.colocateWith}');\n"
elif isTableReference(table):
- ddl += "SELECT create_reference_table(" + "'" + table.name + "'" + ");"
- ddl += "\n"
+ ddl += f"SELECT create_reference_table('{table.name}');\n"
return ddl
def _genColumnDDL(column):
- ddl = ""
- ddl += column.name
- ddl += " "
- ddl += column.type
- return ddl
+ return f"{column.name} {column.type}"
diff --git a/src/test/regress/citus_tests/query_generator/generate_queries.py b/src/test/regress/citus_tests/query_generator/generate_queries.py
index dd63b17ec..2029ace5b 100755
--- a/src/test/regress/citus_tests/query_generator/generate_queries.py
+++ b/src/test/regress/citus_tests/query_generator/generate_queries.py
@@ -55,8 +55,15 @@ def _fileMode(ddls, data):
)
with open(fileName, "w") as f:
# enable repartition joins due to https://github.com/citusdata/citus/issues/6865
- enableRepartitionJoinCommand = "SET citus.enable_repartition_joins TO on;\n"
- queryLines = [enableRepartitionJoinCommand]
+ queryLines = []
+ if getConfig().repartitionJoin:
+ enableRepartitionJoinCommand = "SET citus.enable_repartition_joins TO on;\n"
+ queryLines.append(enableRepartitionJoinCommand)
+ if getConfig().singleRepartitionJoin:
+ singleRepartitionJoinCommand = (
+ "SET citus.enable_single_hash_repartition_joins TO on;\n"
+ )
+ queryLines.append(singleRepartitionJoinCommand)
queryId = 1
for _ in range(queryCount):
query = newQuery()
diff --git a/src/test/regress/citus_tests/query_generator/node_defs.py b/src/test/regress/citus_tests/query_generator/node_defs.py
index b0db1da63..3b8e2acfa 100755
--- a/src/test/regress/citus_tests/query_generator/node_defs.py
+++ b/src/test/regress/citus_tests/query_generator/node_defs.py
@@ -22,9 +22,10 @@ class RestrictOp(Enum):
class CitusType(Enum):
- DISTRIBUTED = 1
- REFERENCE = 2
- POSTGRES = 3
+ HASH_DISTRIBUTED = 1
+ SINGLE_SHARD_DISTRIBUTED = 2
+ REFERENCE = 3
+ POSTGRES = 4
class Table:
@@ -38,6 +39,7 @@ class Table:
duplicateRate,
columns,
distinctCopyCount,
+ colocateWith,
):
self.name = name
self.citusType = citusType
@@ -47,6 +49,7 @@ class Table:
self.duplicateRate = duplicateRate
self.columns = columns
self.distinctCopyCount = distinctCopyCount
+ self.colocateWith = colocateWith
class Column:
diff --git a/src/test/regress/citus_tests/query_generator/query_gen.py b/src/test/regress/citus_tests/query_generator/query_gen.py
index e25525d29..bf5da3c05 100644
--- a/src/test/regress/citus_tests/query_generator/query_gen.py
+++ b/src/test/regress/citus_tests/query_generator/query_gen.py
@@ -114,15 +114,15 @@ class GeneratorContext:
def randomCteName(self):
"""returns a randomly selected cte name"""
randCteRef = random.randint(0, self.currentCteCount - 1)
- return " cte_" + str(randCteRef)
+ return f" cte_{randCteRef}"
def curAlias(self):
"""returns current alias name to be used for the current table"""
- return " table_" + str(self.totalRteCount)
+ return f" table_{self.totalRteCount}"
def curCteAlias(self):
"""returns current alias name to be used for the current cte"""
- return " cte_" + str(self.currentCteCount)
+ return f" cte_{self.currentCteCount}"
def hasAnyCte(self):
"""returns if context has any cte"""
@@ -153,7 +153,7 @@ class GeneratorContext:
# do not enforce per table rte limit if we are inside cte
if self.insideCte:
rteName = random.choice(getAllTableNames())
- return " " + rteName + " "
+ return f" {rteName} "
while True:
# keep trying to find random table by eliminating the ones which hit table limit
@@ -173,7 +173,7 @@ class GeneratorContext:
# increment rte count for the table name
self.perTableRtes[rteName] += 1
- return " " + rteName + " "
+ return f" {rteName} "
def newQuery():
@@ -206,8 +206,7 @@ def _genQuery(genCtx):
and not genCtx.usedAvg
):
genCtx.usedAvg = True
- query += "SELECT "
- query += "count(*), avg(avgsub." + getConfig().commonColName + ") FROM "
+ query += f"SELECT count(*), avg(avgsub.{getConfig().commonColName}) FROM "
query += _genSubqueryRte(genCtx)
query += " AS avgsub"
else:
@@ -226,35 +225,24 @@ def _genQuery(genCtx):
def _genOrderBy(genCtx):
# 'ORDER BY' DistColName
- query = ""
- query += " ORDER BY "
- query += getConfig().commonColName + " "
- return query
+ return f" ORDER BY {getConfig().commonColName} "
def _genLimit(genCtx):
# 'LIMIT' 'random()'
- query = ""
- query += " LIMIT "
(fromVal, toVal) = getConfig().limitRange
- query += str(random.randint(fromVal, toVal))
- return query
+ return f" LIMIT {random.randint(fromVal, toVal)}"
def _genSelectExpr(genCtx):
# 'SELECT' 'curAlias()' '.' DistColName
- query = ""
- query += " SELECT "
commonColName = getConfig().commonColName
- query += genCtx.curAlias() + "." + commonColName + " "
-
- return query
+ return f" SELECT {genCtx.curAlias()}.{commonColName} "
def _genFromExpr(genCtx):
# 'FROM' (Rte JoinList JoinOp Rte Using || RteList) ['WHERE' 'nextRandomAlias()' '.' DistColName RestrictExpr]
- query = ""
- query += " FROM "
+ query = " FROM "
if shouldSelectThatBranch():
query += _genRte(genCtx)
@@ -267,8 +255,7 @@ def _genFromExpr(genCtx):
alias = genCtx.removeLastAlias()
if shouldSelectThatBranch():
- query += " WHERE "
- query += alias + "." + getConfig().commonColName
+ query += f" WHERE {alias}.{getConfig().commonColName}"
query += _genRestrictExpr(genCtx)
return query
@@ -353,9 +340,7 @@ def _genJoinList(genCtx):
def _genUsing(genCtx):
# 'USING' '(' DistColName ')'
- query = ""
- query += " USING (" + getConfig().commonColName + " ) "
- return query
+ return f" USING ({getConfig().commonColName}) "
def _genRte(genCtx):
@@ -392,7 +377,7 @@ def _genRte(genCtx):
query += _genCteRte(genCtx)
elif rteType == RTEType.VALUES:
query += _genValuesRte(genCtx)
- modifiedAlias = alias + "(" + getConfig().commonColName + ") "
+ modifiedAlias = f"{alias}({getConfig().commonColName}) "
else:
raise BaseException("unknown RTE type")
@@ -428,7 +413,5 @@ def _genCteRte(genCtx):
def _genValuesRte(genCtx):
# '( VALUES(random()) )'
- query = ""
(fromVal, toVal) = getConfig().dataRange
- query += " ( VALUES(" + str(random.randint(fromVal, toVal)) + " ) ) "
- return query
+ return f" ( VALUES({random.randint(fromVal, toVal)}) ) "
diff --git a/src/test/regress/citus_tests/query_generator/random_selections.py b/src/test/regress/citus_tests/query_generator/random_selections.py
index ee32c620a..47f55b3e5 100644
--- a/src/test/regress/citus_tests/query_generator/random_selections.py
+++ b/src/test/regress/citus_tests/query_generator/random_selections.py
@@ -25,7 +25,7 @@ def randomRteType():
def randomJoinOp():
"""returns a randomly selected JoinOp given at config"""
joinTypes = getConfig().targetJoinTypes
- return " " + random.choice(joinTypes).name + " JOIN"
+ return f" {random.choice(joinTypes).name} JOIN"
def randomRestrictOp():
@@ -42,4 +42,4 @@ def randomRestrictOp():
else:
raise BaseException("Unknown restrict op")
- return " " + opText + " "
+ return f" {opText} "