pull/6977/merge
aykut-bozkurt 2025-11-19 21:26:11 -08:00 committed by GitHub
commit b7ab41731e
11 changed files with 89 additions and 99 deletions

View File

@@ -53,6 +53,7 @@ targetTables: <Table[]>
 name: <string>
 type: <string>
 distinctCopyCount: <int>
+colocateWith: Optional<string>
 ```
 Explanation:
@@ -72,6 +73,7 @@ targetTables: "array of tables that will be used in generated queries"
 name: "name of column"
 type: "name of data type of column(only support 'int' now)"
 distinctCopyCount: "how many tables with the same configuration we should create(only by changing full name, still using the same name prefix)"
+colocateWith: "colocate_with parameter passed to create_distributed_table"
 ```
@@ -87,6 +89,8 @@ Schema for Query configuration:
 ```yaml
 queryCount: <int>
 queryOutFile: <string>
+repartitionJoin: <bool>
+singleRepartitionJoin: <bool>
 semiAntiJoin: <bool>
 cartesianProduct: <bool>
 limit: <bool>
@@ -114,6 +118,8 @@ Explanation:
 ```yaml
 queryCount: "number of queries to generate"
 queryOutFile: "file to write generated queries"
+repartitionJoin: "should we enable repartition joins"
+singleRepartitionJoin: "should we make single repartition join the default repartition join mode (default is dual)"
 semiAntiJoin: "should we support semi joins (WHERE col IN (Subquery))"
 cartesianProduct: "should we support cartesian joins"
 limit: "should we support limit clause"
@@ -172,7 +178,7 @@ Tool supports following citus table types:
 targetTables:
   - Table:
     ...
-    citusType: <one of (DISTRIBUTED || REFERENCE || POSTGRES)>
+    citusType: <one of (HASH_DISTRIBUTED || SINGLE_SHARD_DISTRIBUTED || REFERENCE || POSTGRES)>
     ...
 ```
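
For reference, a minimal sketch of how the new keys could be read from such a YAML config — this is not the tool's actual parser, and it assumes PyYAML is available. `colocateWith` is optional and falls back to `'default'`, mirroring the parser change later in this commit:

```python
import yaml  # assumes PyYAML is installed

snippet = """
queryCount: 10
repartitionJoin: true
singleRepartitionJoin: false
targetTables:
  - Table:
      name: single_dist
      citusType: SINGLE_SHARD_DISTRIBUTED
      colocateWith: none
"""

cfg = yaml.safe_load(snippet)
table = cfg["targetTables"][0]["Table"]
# colocateWith is optional; fall back to 'default' when it is absent.
print(table.get("colocateWith", "default"))                  # none
print(cfg["repartitionJoin"], cfg["singleRepartitionJoin"])  # True False
```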

View File

@@ -73,6 +73,8 @@ def removeFailedQueryOutputFromFiles(distQueryOutFile, localQueryOutFile):
     acceptableErrors = [
         "ERROR: complex joins are only supported when all distributed tables are co-located and joined on their distribution columns",
         "ERROR: recursive complex joins are only supported when all distributed tables are co-located and joined on their distribution columns",
+        "ERROR: cannot push down this subquery",
+        "ERROR: cannot perform a lateral outer join when a distributed subquery references complex subqueries, CTEs or local tables",
     ]
     failedDistQueryIds = findFailedQueriesFromFile(distQueryOutFile, acceptableErrors)
     removeFailedQueryOutputFromFile(distQueryOutFile, failedDistQueryIds)
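
Illustrative only (this is not the tool's actual `findFailedQueriesFromFile`): the idea is that a query whose error matches one of the acceptable messages is dropped from both output files before they are diffed, so known-unsupported plans do not show up as regressions:

```python
acceptableErrors = [
    "ERROR: cannot push down this subquery",
    "ERROR: cannot perform a lateral outer join when a distributed subquery references complex subqueries, CTEs or local tables",
]


def isAcceptableError(errorLine):
    # an error is acceptable when it starts with one of the known messages
    return any(errorLine.startswith(err) for err in acceptableErrors)


print(isAcceptableError("ERROR: cannot push down this subquery"))  # True
print(isAcceptableError("ERROR: division by zero"))                # False
```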

View File

@@ -28,6 +28,8 @@ class Config:
         self.targetRteCount = configObj["targetRteCount"]
         self.targetCteCount = configObj["targetCteCount"]
         self.targetCteRteCount = configObj["targetCteRteCount"]
+        self.repartitionJoin = configObj["repartitionJoin"]
+        self.singleRepartitionJoin = configObj["singleRepartitionJoin"]
         self.semiAntiJoin = configObj["semiAntiJoin"]
         self.cartesianProduct = configObj["cartesianProduct"]
         self.limit = configObj["limit"]
@@ -111,8 +113,12 @@ def getMaxAllowedCountForTable(tableName):
     return filtered[0].maxAllowedUseOnQuery


-def isTableDistributed(table):
-    return table.citusType == CitusType.DISTRIBUTED
+def isTableHashDistributed(table):
+    return table.citusType == CitusType.HASH_DISTRIBUTED
+
+
+def isTableSingleShardDistributed(table):
+    return table.citusType == CitusType.SINGLE_SHARD_DISTRIBUTED


 def isTableReference(table):
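
A self-contained sketch of the split type checks — the enum values match the node definition change later in this commit; the `FakeTable` stand-in exists only for this example:

```python
from enum import Enum


class CitusType(Enum):
    HASH_DISTRIBUTED = 1
    SINGLE_SHARD_DISTRIBUTED = 2
    REFERENCE = 3
    POSTGRES = 4


def isTableHashDistributed(table):
    return table.citusType == CitusType.HASH_DISTRIBUTED


def isTableSingleShardDistributed(table):
    return table.citusType == CitusType.SINGLE_SHARD_DISTRIBUTED


class FakeTable:  # stand-in for the tool's Table class
    def __init__(self, citusType):
        self.citusType = citusType


t = FakeTable(CitusType.SINGLE_SHARD_DISTRIBUTED)
print(isTableHashDistributed(t), isTableSingleShardDistributed(t))  # False True
```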

View File

@@ -2,6 +2,8 @@ interactiveMode: false
 queryCount: 250
 queryOutFile: queries.sql
 ddlOutFile: ddls.sql
+repartitionJoin: true
+singleRepartitionJoin: false
 semiAntiJoin: true
 cartesianProduct: false
 limit: true
@@ -26,7 +28,7 @@ commonColName: id
 targetTables:
   - Table:
       name: dist
-      citusType: DISTRIBUTED
+      citusType: HASH_DISTRIBUTED
       maxAllowedUseOnQuery: 10
       rowCount: 10
       nullRate: 0.1
@@ -36,6 +38,19 @@ targetTables:
           name: id
           type: int
       distinctCopyCount: 2
+  - Table:
+      name: single_dist
+      citusType: SINGLE_SHARD_DISTRIBUTED
+      maxAllowedUseOnQuery: 10
+      rowCount: 10
+      nullRate: 0.1
+      duplicateRate: 0.1
+      columns:
+        - Column:
+            name: id
+            type: int
+      distinctCopyCount: 2
+      colocateWith: none
   - Table:
       name: ref
       citusType: REFERENCE

View File

@@ -49,6 +49,7 @@ def parseTable(targetTableDict):
         col = parseColumn(columnDict)
         columns.append(col)
     distinctCopyCount = targetTableDict["distinctCopyCount"]
+    colocateWith = targetTableDict.get("colocateWith", "default")
     return Table(
         name,
         citusType,
@@ -58,6 +59,7 @@ def parseTable(targetTableDict):
         duplicateRate,
         columns,
         distinctCopyCount,
+        colocateWith,
     )
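
The `.get("colocateWith", "default")` lookup makes the new key optional; a quick sketch of the fallback behaviour, with plain dicts standing in for the parsed YAML entries:

```python
withColocation = {"name": "single_dist", "colocateWith": "none"}
withoutColocation = {"name": "dist"}

print(withColocation.get("colocateWith", "default"))     # none
print(withoutColocation.get("colocateWith", "default"))  # default
```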

View File

@@ -17,7 +17,10 @@ def getTableData():
         dataGenerationSql += "\n"

         # generate null rows
-        if not table.citusType == CitusType.DISTRIBUTED:
+        if table.citusType not in (
+            CitusType.HASH_DISTRIBUTED,
+            CitusType.SINGLE_SHARD_DISTRIBUTED,
+        ):
             targetNullRows = int(table.rowCount * table.nullRate)
             dataGenerationSql += _genNullData(table.name, targetNullRows)
             dataGenerationSql += "\n"
@@ -32,51 +35,23 @@ def getTableData():

 def _genOverlappingData(tableName, startVal, rowCount):
     """returns string to fill table with [startVal,startVal+rowCount] range of integers"""
-    dataGenerationSql = ""
-    dataGenerationSql += "INSERT INTO " + tableName
-    dataGenerationSql += (
-        " SELECT i FROM generate_series("
-        + str(startVal)
-        + ","
-        + str(startVal + rowCount)
-        + ") i;"
-    )
-    return dataGenerationSql
+    return f"INSERT INTO {tableName} SELECT i FROM generate_series({startVal}, {startVal + rowCount}) i;"


 def _genNullData(tableName, nullCount):
     """returns string to fill table with NULLs"""
-    dataGenerationSql = ""
-    dataGenerationSql += "INSERT INTO " + tableName
-    dataGenerationSql += (
-        " SELECT NULL FROM generate_series(0," + str(nullCount) + ") i;"
+    return (
+        f"INSERT INTO {tableName} SELECT NULL FROM generate_series(0, {nullCount}) i;"
     )
-    return dataGenerationSql


 def _genDupData(tableName, dupRowCount):
     """returns string to fill table with duplicate integers which are fetched from given table"""
-    dataGenerationSql = ""
-    dataGenerationSql += "INSERT INTO " + tableName
-    dataGenerationSql += (
-        " SELECT * FROM "
-        + tableName
-        + " ORDER BY "
-        + getConfig().commonColName
-        + " LIMIT "
-        + str(dupRowCount)
-        + ";"
-    )
-    return dataGenerationSql
+    return f"INSERT INTO {tableName} SELECT * FROM {tableName} ORDER BY {getConfig().commonColName} LIMIT {dupRowCount};"


 def _genNonOverlappingData(tableName, startVal, tableIdx):
     """returns string to fill table with different integers for given table"""
     startVal = startVal + tableIdx * 20
     endVal = startVal + 20
-    dataGenerationSql = ""
-    dataGenerationSql += "INSERT INTO " + tableName
-    dataGenerationSql += (
-        " SELECT i FROM generate_series(" + str(startVal) + "," + str(endVal) + ") i;"
-    )
-    return dataGenerationSql
+    return f"INSERT INTO {tableName} SELECT i FROM generate_series({startVal}, {endVal}) i;"
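
To see what the rewritten helpers emit, here is a standalone rendering with the common column name hard-coded to `id` (the value used by the sample config); the real functions read it from `getConfig().commonColName`:

```python
def _genOverlappingData(tableName, startVal, rowCount):
    return f"INSERT INTO {tableName} SELECT i FROM generate_series({startVal}, {startVal + rowCount}) i;"


def _genDupData(tableName, dupRowCount):
    # commonColName is hard-coded to 'id' for this sketch
    return f"INSERT INTO {tableName} SELECT * FROM {tableName} ORDER BY id LIMIT {dupRowCount};"


print(_genOverlappingData("dist", 0, 10))
# INSERT INTO dist SELECT i FROM generate_series(0, 10) i;
print(_genDupData("dist", 2))
# INSERT INTO dist SELECT * FROM dist ORDER BY id LIMIT 2;
```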

View File

@@ -1,4 +1,9 @@
-from config.config import getConfig, isTableDistributed, isTableReference
+from config.config import (
+    getConfig,
+    isTableHashDistributed,
+    isTableReference,
+    isTableSingleShardDistributed,
+)


 def getTableDDLs():
@@ -12,10 +17,8 @@ def getTableDDLs():

 def _genTableDDL(table):
     ddl = ""
-    ddl += "DROP TABLE IF EXISTS " + table.name + ";"
-    ddl += "\n"
-    ddl += "CREATE TABLE " + table.name + "("
+    ddl += f"DROP TABLE IF EXISTS {table.name};\n"
+    ddl += f"CREATE TABLE {table.name}("
     for column in table.columns[:-1]:
         ddl += _genColumnDDL(column)
         ddl += ",\n"
@@ -23,26 +26,14 @@ def _genTableDDL(table):
     ddl += _genColumnDDL(table.columns[-1])
     ddl += ");\n"

-    if isTableDistributed(table):
-        ddl += (
-            "SELECT create_distributed_table("
-            + "'"
-            + table.name
-            + "','"
-            + getConfig().commonColName
-            + "'"
-            + ");"
-        )
-        ddl += "\n"
+    if isTableHashDistributed(table):
+        ddl += f"SELECT create_distributed_table('{table.name}','{getConfig().commonColName}', colocate_with=>'{table.colocateWith}');\n"
+    if isTableSingleShardDistributed(table):
+        ddl += f"SELECT create_distributed_table('{table.name}', NULL, colocate_with=>'{table.colocateWith}');\n"
     elif isTableReference(table):
-        ddl += "SELECT create_reference_table(" + "'" + table.name + "'" + ");"
-        ddl += "\n"
+        ddl += f"SELECT create_reference_table('{table.name}');\n"
     return ddl


 def _genColumnDDL(column):
-    ddl = ""
-    ddl += column.name
-    ddl += " "
-    ddl += column.type
-    return ddl
+    return f"{column.name} {column.type}"
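
The new branches produce one `create_distributed_table` call per table type. A self-contained sketch (table objects and the common column name are stand-ins) showing the DDL strings the f-strings above render to:

```python
class FakeTable:  # stand-in for the tool's Table class
    def __init__(self, name, colocateWith):
        self.name = name
        self.colocateWith = colocateWith


commonColName = "id"  # matches the sample config
hashDist = FakeTable("dist", "default")
singleShard = FakeTable("single_dist", "none")

print(f"SELECT create_distributed_table('{hashDist.name}','{commonColName}', colocate_with=>'{hashDist.colocateWith}');")
print(f"SELECT create_distributed_table('{singleShard.name}', NULL, colocate_with=>'{singleShard.colocateWith}');")
# SELECT create_distributed_table('dist','id', colocate_with=>'default');
# SELECT create_distributed_table('single_dist', NULL, colocate_with=>'none');
```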

View File

@@ -55,8 +55,15 @@ def _fileMode(ddls, data):
     )
     with open(fileName, "w") as f:
         # enable repartition joins due to https://github.com/citusdata/citus/issues/6865
-        enableRepartitionJoinCommand = "SET citus.enable_repartition_joins TO on;\n"
-        queryLines = [enableRepartitionJoinCommand]
+        queryLines = []
+        if getConfig().repartitionJoin:
+            enableRepartitionJoinCommand = "SET citus.enable_repartition_joins TO on;\n"
+            queryLines.append(enableRepartitionJoinCommand)
+        if getConfig().singleRepartitionJoin:
+            singleRepartitionJoinCommand = (
+                "SET citus.enable_single_hash_repartition_joins TO on;\n"
+            )
+            queryLines.append(singleRepartitionJoinCommand)
         queryId = 1
         for _ in range(queryCount):
             query = newQuery()
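
A minimal sketch of the preamble logic with a stand-in config object: the SET commands are written to the query file only when the corresponding flags are enabled, so repartition joins can now be turned off entirely from the config:

```python
class FakeConfig:  # stand-in for getConfig()
    repartitionJoin = True
    singleRepartitionJoin = True


config = FakeConfig()
queryLines = []
if config.repartitionJoin:
    queryLines.append("SET citus.enable_repartition_joins TO on;\n")
if config.singleRepartitionJoin:
    queryLines.append("SET citus.enable_single_hash_repartition_joins TO on;\n")

print("".join(queryLines), end="")
# SET citus.enable_repartition_joins TO on;
# SET citus.enable_single_hash_repartition_joins TO on;
```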

View File

@@ -22,9 +22,10 @@ class RestrictOp(Enum):


 class CitusType(Enum):
-    DISTRIBUTED = 1
-    REFERENCE = 2
-    POSTGRES = 3
+    HASH_DISTRIBUTED = 1
+    SINGLE_SHARD_DISTRIBUTED = 2
+    REFERENCE = 3
+    POSTGRES = 4


 class Table:
@@ -38,6 +39,7 @@ class Table:
         duplicateRate,
         columns,
         distinctCopyCount,
+        colocateWith,
     ):
         self.name = name
         self.citusType = citusType
@@ -47,6 +49,7 @@ class Table:
         self.duplicateRate = duplicateRate
         self.columns = columns
         self.distinctCopyCount = distinctCopyCount
+        self.colocateWith = colocateWith


 class Column:
class Column: class Column:

View File

@@ -114,15 +114,15 @@ class GeneratorContext:
     def randomCteName(self):
         """returns a randomly selected cte name"""
         randCteRef = random.randint(0, self.currentCteCount - 1)
-        return " cte_" + str(randCteRef)
+        return f" cte_{randCteRef}"

     def curAlias(self):
         """returns current alias name to be used for the current table"""
-        return " table_" + str(self.totalRteCount)
+        return f" table_{self.totalRteCount}"

     def curCteAlias(self):
         """returns current alias name to be used for the current cte"""
-        return " cte_" + str(self.currentCteCount)
+        return f" cte_{self.currentCteCount}"

     def hasAnyCte(self):
         """returns if context has any cte"""
@@ -153,7 +153,7 @@ class GeneratorContext:
         # do not enforce per table rte limit if we are inside cte
         if self.insideCte:
             rteName = random.choice(getAllTableNames())
-            return " " + rteName + " "
+            return f" {rteName} "

         while True:
             # keep trying to find random table by eliminating the ones which hit table limit
@@ -173,7 +173,7 @@ class GeneratorContext:
             # increment rte count for the table name
             self.perTableRtes[rteName] += 1
-            return " " + rteName + " "
+            return f" {rteName} "


 def newQuery():
@@ -206,8 +206,7 @@ def _genQuery(genCtx):
         and not genCtx.usedAvg
     ):
         genCtx.usedAvg = True
-        query += "SELECT "
-        query += "count(*), avg(avgsub." + getConfig().commonColName + ") FROM "
+        query += f"SELECT count(*), avg(avgsub.{getConfig().commonColName}) FROM "
         query += _genSubqueryRte(genCtx)
         query += " AS avgsub"
     else:
@@ -226,35 +225,24 @@ def _genQuery(genCtx):

 def _genOrderBy(genCtx):
     # 'ORDER BY' DistColName
-    query = ""
-    query += " ORDER BY "
-    query += getConfig().commonColName + " "
-    return query
+    return f" ORDER BY {getConfig().commonColName} "


 def _genLimit(genCtx):
     # 'LIMIT' 'random()'
-    query = ""
-    query += " LIMIT "
     (fromVal, toVal) = getConfig().limitRange
-    query += str(random.randint(fromVal, toVal))
-    return query
+    return f" LIMIT {random.randint(fromVal, toVal)}"


 def _genSelectExpr(genCtx):
     # 'SELECT' 'curAlias()' '.' DistColName
-    query = ""
-    query += " SELECT "
     commonColName = getConfig().commonColName
-    query += genCtx.curAlias() + "." + commonColName + " "
-    return query
+    return f" SELECT {genCtx.curAlias()}.{commonColName} "


 def _genFromExpr(genCtx):
     # 'FROM' (Rte JoinList JoinOp Rte Using || RteList) ['WHERE' 'nextRandomAlias()' '.' DistColName RestrictExpr]
-    query = ""
-    query += " FROM "
+    query = " FROM "
     if shouldSelectThatBranch():
         query += _genRte(genCtx)
@@ -267,8 +255,7 @@ def _genFromExpr(genCtx):
     alias = genCtx.removeLastAlias()
     if shouldSelectThatBranch():
-        query += " WHERE "
-        query += alias + "." + getConfig().commonColName
+        query += f" WHERE {alias}.{getConfig().commonColName}"
         query += _genRestrictExpr(genCtx)
     return query
@@ -353,9 +340,7 @@ def _genJoinList(genCtx):

 def _genUsing(genCtx):
     # 'USING' '(' DistColName ')'
-    query = ""
-    query += " USING (" + getConfig().commonColName + " ) "
-    return query
+    return f" USING ({getConfig().commonColName}) "


 def _genRte(genCtx):
@@ -392,7 +377,7 @@ def _genRte(genCtx):
         query += _genCteRte(genCtx)
     elif rteType == RTEType.VALUES:
         query += _genValuesRte(genCtx)
-        modifiedAlias = alias + "(" + getConfig().commonColName + ") "
+        modifiedAlias = f"{alias}({getConfig().commonColName}) "
     else:
         raise BaseException("unknown RTE type")
@@ -428,7 +413,5 @@ def _genCteRte(genCtx):

 def _genValuesRte(genCtx):
     # '( VALUES(random()) )'
-    query = ""
     (fromVal, toVal) = getConfig().dataRange
-    query += " ( VALUES(" + str(random.randint(fromVal, toVal)) + " ) ) "
-    return query
+    return f" ( VALUES({random.randint(fromVal, toVal)}) ) "
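
For a feel of the fragments these refactored builders return, a simplified rendering — signatures reduced (the real functions take a `genCtx`), the common column name fixed to `id`, and the LIMIT value fixed instead of random:

```python
def _genOrderBy(commonColName="id"):
    return f" ORDER BY {commonColName} "


def _genUsing(commonColName="id"):
    return f" USING ({commonColName}) "


def _genLimit(limitValue=7):
    return f" LIMIT {limitValue}"


print(_genOrderBy() + _genLimit())  # " ORDER BY id  LIMIT 7"
print(_genUsing())                  # " USING (id) "
```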

View File

@@ -25,7 +25,7 @@ def randomRteType():

 def randomJoinOp():
     """returns a randomly selected JoinOp given at config"""
     joinTypes = getConfig().targetJoinTypes
-    return " " + random.choice(joinTypes).name + " JOIN"
+    return f" {random.choice(joinTypes).name} JOIN"


 def randomRestrictOp():
@@ -42,4 +42,4 @@ def randomRestrictOp():
     else:
         raise BaseException("Unknown restrict op")
-    return " " + opText + " "
+    return f" {opText} "