mirror of https://github.com/citusdata/citus.git
Format python files with black
parent
42970665fc
commit
530b24a887
|
@ -14,6 +14,7 @@ class FileScanner:
|
||||||
FileScanner is an iterator over the lines of a file.
|
FileScanner is an iterator over the lines of a file.
|
||||||
It can apply a rewrite rule which can be used to skip lines.
|
It can apply a rewrite rule which can be used to skip lines.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, file, rewrite=lambda x: x):
|
def __init__(self, file, rewrite=lambda x: x):
|
||||||
self.file = file
|
self.file = file
|
||||||
self.line = 1
|
self.line = 1
|
||||||
|
@ -33,12 +34,12 @@ def main():
|
||||||
regexpipeline = []
|
regexpipeline = []
|
||||||
for line in open(argv[1]):
|
for line in open(argv[1]):
|
||||||
line = line.strip()
|
line = line.strip()
|
||||||
if not line or line.startswith('#') or not line.endswith('d'):
|
if not line or line.startswith("#") or not line.endswith("d"):
|
||||||
continue
|
continue
|
||||||
rule = regexregex.match(line)
|
rule = regexregex.match(line)
|
||||||
if not rule:
|
if not rule:
|
||||||
raise 'Failed to parse regex rule: %s' % line
|
raise "Failed to parse regex rule: %s" % line
|
||||||
regexpipeline.append(re.compile(rule.group('rule')))
|
regexpipeline.append(re.compile(rule.group("rule")))
|
||||||
|
|
||||||
def sed(line):
|
def sed(line):
|
||||||
if any(regex.search(line) for regex in regexpipeline):
|
if any(regex.search(line) for regex in regexpipeline):
|
||||||
|
@ -46,13 +47,15 @@ def main():
|
||||||
return line
|
return line
|
||||||
|
|
||||||
for line in stdin:
|
for line in stdin:
|
||||||
if line.startswith('+++ '):
|
if line.startswith("+++ "):
|
||||||
tab = line.rindex('\t')
|
tab = line.rindex("\t")
|
||||||
fname = line[4:tab]
|
fname = line[4:tab]
|
||||||
file2 = FileScanner(open(fname.replace('.modified', ''), encoding='utf8'), sed)
|
file2 = FileScanner(
|
||||||
|
open(fname.replace(".modified", ""), encoding="utf8"), sed
|
||||||
|
)
|
||||||
stdout.write(line)
|
stdout.write(line)
|
||||||
elif line.startswith('@@ '):
|
elif line.startswith("@@ "):
|
||||||
idx_start = line.index('+') + 1
|
idx_start = line.index("+") + 1
|
||||||
idx_end = idx_start + 1
|
idx_end = idx_start + 1
|
||||||
while line[idx_end].isdigit():
|
while line[idx_end].isdigit():
|
||||||
idx_end += 1
|
idx_end += 1
|
||||||
|
@ -60,11 +63,11 @@ def main():
|
||||||
while file2.line < linenum:
|
while file2.line < linenum:
|
||||||
next(file2)
|
next(file2)
|
||||||
stdout.write(line)
|
stdout.write(line)
|
||||||
elif line.startswith(' '):
|
elif line.startswith(" "):
|
||||||
stdout.write(' ')
|
stdout.write(" ")
|
||||||
stdout.write(next(file2))
|
stdout.write(next(file2))
|
||||||
elif line.startswith('+'):
|
elif line.startswith("+"):
|
||||||
stdout.write('+')
|
stdout.write("+")
|
||||||
stdout.write(next(file2))
|
stdout.write(next(file2))
|
||||||
else:
|
else:
|
||||||
stdout.write(line)
|
stdout.write(line)
|
||||||
|
|
|
@ -115,7 +115,6 @@ def copy_copy_modified_binary(datadir):
|
||||||
|
|
||||||
|
|
||||||
def copy_test_files(config):
|
def copy_test_files(config):
|
||||||
|
|
||||||
sql_dir_path = os.path.join(config.datadir, "sql")
|
sql_dir_path = os.path.join(config.datadir, "sql")
|
||||||
expected_dir_path = os.path.join(config.datadir, "expected")
|
expected_dir_path = os.path.join(config.datadir, "expected")
|
||||||
|
|
||||||
|
@ -132,7 +131,9 @@ def copy_test_files(config):
|
||||||
|
|
||||||
line = line[colon_index + 1 :].strip()
|
line = line[colon_index + 1 :].strip()
|
||||||
test_names = line.split(" ")
|
test_names = line.split(" ")
|
||||||
copy_test_files_with_names(test_names, sql_dir_path, expected_dir_path, config)
|
copy_test_files_with_names(
|
||||||
|
test_names, sql_dir_path, expected_dir_path, config
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def copy_test_files_with_names(test_names, sql_dir_path, expected_dir_path, config):
|
def copy_test_files_with_names(test_names, sql_dir_path, expected_dir_path, config):
|
||||||
|
@ -140,10 +141,10 @@ def copy_test_files_with_names(test_names, sql_dir_path, expected_dir_path, conf
|
||||||
# make empty files for the skipped tests
|
# make empty files for the skipped tests
|
||||||
if test_name in config.skip_tests:
|
if test_name in config.skip_tests:
|
||||||
expected_sql_file = os.path.join(sql_dir_path, test_name + ".sql")
|
expected_sql_file = os.path.join(sql_dir_path, test_name + ".sql")
|
||||||
open(expected_sql_file, 'x').close()
|
open(expected_sql_file, "x").close()
|
||||||
|
|
||||||
expected_out_file = os.path.join(expected_dir_path, test_name + ".out")
|
expected_out_file = os.path.join(expected_dir_path, test_name + ".out")
|
||||||
open(expected_out_file, 'x').close()
|
open(expected_out_file, "x").close()
|
||||||
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
|
@ -27,13 +27,11 @@ def initialize_temp_dir_if_not_exists(temp_dir):
|
||||||
|
|
||||||
def parallel_run(function, items, *args, **kwargs):
|
def parallel_run(function, items, *args, **kwargs):
|
||||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||||
futures = [
|
futures = [executor.submit(function, item, *args, **kwargs) for item in items]
|
||||||
executor.submit(function, item, *args, **kwargs)
|
|
||||||
for item in items
|
|
||||||
]
|
|
||||||
for future in futures:
|
for future in futures:
|
||||||
future.result()
|
future.result()
|
||||||
|
|
||||||
|
|
||||||
def initialize_db_for_cluster(pg_path, rel_data_path, settings, node_names):
|
def initialize_db_for_cluster(pg_path, rel_data_path, settings, node_names):
|
||||||
subprocess.run(["mkdir", rel_data_path], check=True)
|
subprocess.run(["mkdir", rel_data_path], check=True)
|
||||||
|
|
||||||
|
@ -52,7 +50,7 @@ def initialize_db_for_cluster(pg_path, rel_data_path, settings, node_names):
|
||||||
"--encoding",
|
"--encoding",
|
||||||
"UTF8",
|
"UTF8",
|
||||||
"--locale",
|
"--locale",
|
||||||
"POSIX"
|
"POSIX",
|
||||||
]
|
]
|
||||||
subprocess.run(command, check=True)
|
subprocess.run(command, check=True)
|
||||||
add_settings(abs_data_path, settings)
|
add_settings(abs_data_path, settings)
|
||||||
|
@ -76,7 +74,9 @@ def create_role(pg_path, node_ports, user_name):
|
||||||
user_name, user_name
|
user_name, user_name
|
||||||
)
|
)
|
||||||
utils.psql(pg_path, port, command)
|
utils.psql(pg_path, port, command)
|
||||||
command = "SET citus.enable_ddl_propagation TO OFF; GRANT CREATE ON DATABASE postgres to {}".format(user_name)
|
command = "SET citus.enable_ddl_propagation TO OFF; GRANT CREATE ON DATABASE postgres to {}".format(
|
||||||
|
user_name
|
||||||
|
)
|
||||||
utils.psql(pg_path, port, command)
|
utils.psql(pg_path, port, command)
|
||||||
|
|
||||||
parallel_run(create, node_ports)
|
parallel_run(create, node_ports)
|
||||||
|
@ -89,7 +89,9 @@ def coordinator_should_haveshards(pg_path, port):
|
||||||
utils.psql(pg_path, port, command)
|
utils.psql(pg_path, port, command)
|
||||||
|
|
||||||
|
|
||||||
def start_databases(pg_path, rel_data_path, node_name_to_ports, logfile_prefix, env_variables):
|
def start_databases(
|
||||||
|
pg_path, rel_data_path, node_name_to_ports, logfile_prefix, env_variables
|
||||||
|
):
|
||||||
def start(node_name):
|
def start(node_name):
|
||||||
abs_data_path = os.path.abspath(os.path.join(rel_data_path, node_name))
|
abs_data_path = os.path.abspath(os.path.join(rel_data_path, node_name))
|
||||||
node_port = node_name_to_ports[node_name]
|
node_port = node_name_to_ports[node_name]
|
||||||
|
@ -248,7 +250,12 @@ def logfile_name(logfile_prefix, node_name):
|
||||||
|
|
||||||
|
|
||||||
def stop_databases(
|
def stop_databases(
|
||||||
pg_path, rel_data_path, node_name_to_ports, logfile_prefix, no_output=False, parallel=True
|
pg_path,
|
||||||
|
rel_data_path,
|
||||||
|
node_name_to_ports,
|
||||||
|
logfile_prefix,
|
||||||
|
no_output=False,
|
||||||
|
parallel=True,
|
||||||
):
|
):
|
||||||
def stop(node_name):
|
def stop(node_name):
|
||||||
abs_data_path = os.path.abspath(os.path.join(rel_data_path, node_name))
|
abs_data_path = os.path.abspath(os.path.join(rel_data_path, node_name))
|
||||||
|
@ -287,7 +294,9 @@ def initialize_citus_cluster(bindir, datadir, settings, config):
|
||||||
initialize_db_for_cluster(
|
initialize_db_for_cluster(
|
||||||
bindir, datadir, settings, config.node_name_to_ports.keys()
|
bindir, datadir, settings, config.node_name_to_ports.keys()
|
||||||
)
|
)
|
||||||
start_databases(bindir, datadir, config.node_name_to_ports, config.name, config.env_variables)
|
start_databases(
|
||||||
|
bindir, datadir, config.node_name_to_ports, config.name, config.env_variables
|
||||||
|
)
|
||||||
create_citus_extension(bindir, config.node_name_to_ports.values())
|
create_citus_extension(bindir, config.node_name_to_ports.values())
|
||||||
add_workers(bindir, config.worker_ports, config.coordinator_port())
|
add_workers(bindir, config.worker_ports, config.coordinator_port())
|
||||||
if not config.is_mx:
|
if not config.is_mx:
|
||||||
|
@ -296,6 +305,7 @@ def initialize_citus_cluster(bindir, datadir, settings, config):
|
||||||
add_coordinator_to_metadata(bindir, config.coordinator_port())
|
add_coordinator_to_metadata(bindir, config.coordinator_port())
|
||||||
config.setup_steps()
|
config.setup_steps()
|
||||||
|
|
||||||
|
|
||||||
def eprint(*args, **kwargs):
|
def eprint(*args, **kwargs):
|
||||||
"""eprint prints to stderr"""
|
"""eprint prints to stderr"""
|
||||||
|
|
||||||
|
|
|
@ -57,8 +57,9 @@ port_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
def should_include_config(class_name):
|
def should_include_config(class_name):
|
||||||
|
if inspect.isclass(class_name) and issubclass(
|
||||||
if inspect.isclass(class_name) and issubclass(class_name, CitusDefaultClusterConfig):
|
class_name, CitusDefaultClusterConfig
|
||||||
|
):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -167,7 +168,9 @@ class CitusDefaultClusterConfig(CitusBaseClusterConfig):
|
||||||
self.add_coordinator_to_metadata = True
|
self.add_coordinator_to_metadata = True
|
||||||
self.skip_tests = [
|
self.skip_tests = [
|
||||||
# Alter Table statement cannot be run from an arbitrary node so this test will fail
|
# Alter Table statement cannot be run from an arbitrary node so this test will fail
|
||||||
"arbitrary_configs_alter_table_add_constraint_without_name_create", "arbitrary_configs_alter_table_add_constraint_without_name"]
|
"arbitrary_configs_alter_table_add_constraint_without_name_create",
|
||||||
|
"arbitrary_configs_alter_table_add_constraint_without_name",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class CitusUpgradeConfig(CitusBaseClusterConfig):
|
class CitusUpgradeConfig(CitusBaseClusterConfig):
|
||||||
|
@ -190,9 +193,13 @@ class PostgresConfig(CitusDefaultClusterConfig):
|
||||||
self.new_settings = {
|
self.new_settings = {
|
||||||
"citus.use_citus_managed_tables": False,
|
"citus.use_citus_managed_tables": False,
|
||||||
}
|
}
|
||||||
self.skip_tests = ["nested_execution",
|
self.skip_tests = [
|
||||||
|
"nested_execution",
|
||||||
# Alter Table statement cannot be run from an arbitrary node so this test will fail
|
# Alter Table statement cannot be run from an arbitrary node so this test will fail
|
||||||
"arbitrary_configs_alter_table_add_constraint_without_name_create", "arbitrary_configs_alter_table_add_constraint_without_name"]
|
"arbitrary_configs_alter_table_add_constraint_without_name_create",
|
||||||
|
"arbitrary_configs_alter_table_add_constraint_without_name",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class CitusSingleNodeClusterConfig(CitusDefaultClusterConfig):
|
class CitusSingleNodeClusterConfig(CitusDefaultClusterConfig):
|
||||||
def __init__(self, arguments):
|
def __init__(self, arguments):
|
||||||
|
@ -275,7 +282,7 @@ class CitusUnusualExecutorConfig(CitusDefaultClusterConfig):
|
||||||
|
|
||||||
# this setting does not necessarily need to be here
|
# this setting does not necessarily need to be here
|
||||||
# could go any other test
|
# could go any other test
|
||||||
self.env_variables = {'PGAPPNAME' : 'test_app'}
|
self.env_variables = {"PGAPPNAME": "test_app"}
|
||||||
|
|
||||||
|
|
||||||
class CitusSmallCopyBuffersConfig(CitusDefaultClusterConfig):
|
class CitusSmallCopyBuffersConfig(CitusDefaultClusterConfig):
|
||||||
|
@ -307,9 +314,13 @@ class CitusUnusualQuerySettingsConfig(CitusDefaultClusterConfig):
|
||||||
# requires the table with the fk to be converted to a citus_local_table.
|
# requires the table with the fk to be converted to a citus_local_table.
|
||||||
# As of c11, there is no way to do that through remote execution so this test
|
# As of c11, there is no way to do that through remote execution so this test
|
||||||
# will fail
|
# will fail
|
||||||
"arbitrary_configs_truncate_cascade_create", "arbitrary_configs_truncate_cascade",
|
"arbitrary_configs_truncate_cascade_create",
|
||||||
|
"arbitrary_configs_truncate_cascade",
|
||||||
# Alter Table statement cannot be run from an arbitrary node so this test will fail
|
# Alter Table statement cannot be run from an arbitrary node so this test will fail
|
||||||
"arbitrary_configs_alter_table_add_constraint_without_name_create", "arbitrary_configs_alter_table_add_constraint_without_name"]
|
"arbitrary_configs_alter_table_add_constraint_without_name_create",
|
||||||
|
"arbitrary_configs_alter_table_add_constraint_without_name",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class CitusSingleNodeSingleShardClusterConfig(CitusDefaultClusterConfig):
|
class CitusSingleNodeSingleShardClusterConfig(CitusDefaultClusterConfig):
|
||||||
def __init__(self, arguments):
|
def __init__(self, arguments):
|
||||||
|
@ -328,15 +339,20 @@ class CitusShardReplicationFactorClusterConfig(CitusDefaultClusterConfig):
|
||||||
self.skip_tests = [
|
self.skip_tests = [
|
||||||
# citus does not support foreign keys in distributed tables
|
# citus does not support foreign keys in distributed tables
|
||||||
# when citus.shard_replication_factor >= 2
|
# when citus.shard_replication_factor >= 2
|
||||||
"arbitrary_configs_truncate_partition_create", "arbitrary_configs_truncate_partition",
|
"arbitrary_configs_truncate_partition_create",
|
||||||
|
"arbitrary_configs_truncate_partition",
|
||||||
# citus does not support modifying a partition when
|
# citus does not support modifying a partition when
|
||||||
# citus.shard_replication_factor >= 2
|
# citus.shard_replication_factor >= 2
|
||||||
"arbitrary_configs_truncate_cascade_create", "arbitrary_configs_truncate_cascade",
|
"arbitrary_configs_truncate_cascade_create",
|
||||||
|
"arbitrary_configs_truncate_cascade",
|
||||||
# citus does not support colocating functions with distributed tables when
|
# citus does not support colocating functions with distributed tables when
|
||||||
# citus.shard_replication_factor >= 2
|
# citus.shard_replication_factor >= 2
|
||||||
"function_create", "functions",
|
"function_create",
|
||||||
|
"functions",
|
||||||
# Alter Table statement cannot be run from an arbitrary node so this test will fail
|
# Alter Table statement cannot be run from an arbitrary node so this test will fail
|
||||||
"arbitrary_configs_alter_table_add_constraint_without_name_create", "arbitrary_configs_alter_table_add_constraint_without_name"]
|
"arbitrary_configs_alter_table_add_constraint_without_name_create",
|
||||||
|
"arbitrary_configs_alter_table_add_constraint_without_name",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class CitusSingleShardClusterConfig(CitusDefaultClusterConfig):
|
class CitusSingleShardClusterConfig(CitusDefaultClusterConfig):
|
||||||
|
|
|
@ -12,23 +12,53 @@ import common
|
||||||
import config
|
import config
|
||||||
|
|
||||||
args = argparse.ArgumentParser()
|
args = argparse.ArgumentParser()
|
||||||
args.add_argument("test_name", help="Test name (must be included in a schedule.)", nargs='?')
|
args.add_argument(
|
||||||
args.add_argument("-p", "--path", required=False, help="Relative path for test file (must have a .sql or .spec extension)", type=pathlib.Path)
|
"test_name", help="Test name (must be included in a schedule.)", nargs="?"
|
||||||
|
)
|
||||||
|
args.add_argument(
|
||||||
|
"-p",
|
||||||
|
"--path",
|
||||||
|
required=False,
|
||||||
|
help="Relative path for test file (must have a .sql or .spec extension)",
|
||||||
|
type=pathlib.Path,
|
||||||
|
)
|
||||||
args.add_argument("-r", "--repeat", help="Number of test to run", type=int, default=1)
|
args.add_argument("-r", "--repeat", help="Number of test to run", type=int, default=1)
|
||||||
args.add_argument("-b", "--use-base-schedule", required=False, help="Choose base-schedules rather than minimal-schedules", action='store_true')
|
args.add_argument(
|
||||||
args.add_argument("-w", "--use-whole-schedule-line", required=False, help="Use the whole line found in related schedule", action='store_true')
|
"-b",
|
||||||
args.add_argument("--valgrind", required=False, help="Run the test with valgrind enabled", action='store_true')
|
"--use-base-schedule",
|
||||||
|
required=False,
|
||||||
|
help="Choose base-schedules rather than minimal-schedules",
|
||||||
|
action="store_true",
|
||||||
|
)
|
||||||
|
args.add_argument(
|
||||||
|
"-w",
|
||||||
|
"--use-whole-schedule-line",
|
||||||
|
required=False,
|
||||||
|
help="Use the whole line found in related schedule",
|
||||||
|
action="store_true",
|
||||||
|
)
|
||||||
|
args.add_argument(
|
||||||
|
"--valgrind",
|
||||||
|
required=False,
|
||||||
|
help="Run the test with valgrind enabled",
|
||||||
|
action="store_true",
|
||||||
|
)
|
||||||
|
|
||||||
args = vars(args.parse_args())
|
args = vars(args.parse_args())
|
||||||
|
|
||||||
regress_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
regress_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||||
test_file_path = args['path']
|
test_file_path = args["path"]
|
||||||
test_file_name = args['test_name']
|
test_file_name = args["test_name"]
|
||||||
use_base_schedule = args['use_base_schedule']
|
use_base_schedule = args["use_base_schedule"]
|
||||||
use_whole_schedule_line = args['use_whole_schedule_line']
|
use_whole_schedule_line = args["use_whole_schedule_line"]
|
||||||
|
|
||||||
test_files_to_skip = ['multi_cluster_management', 'multi_extension', 'multi_test_helpers', 'multi_insert_select']
|
test_files_to_skip = [
|
||||||
test_files_to_run_without_schedule = ['single_node_enterprise']
|
"multi_cluster_management",
|
||||||
|
"multi_extension",
|
||||||
|
"multi_test_helpers",
|
||||||
|
"multi_insert_select",
|
||||||
|
]
|
||||||
|
test_files_to_run_without_schedule = ["single_node_enterprise"]
|
||||||
|
|
||||||
if not (test_file_name or test_file_path):
|
if not (test_file_name or test_file_path):
|
||||||
print(f"FATAL: No test given.")
|
print(f"FATAL: No test given.")
|
||||||
|
@ -36,7 +66,7 @@ if not (test_file_name or test_file_path):
|
||||||
|
|
||||||
|
|
||||||
if test_file_path:
|
if test_file_path:
|
||||||
test_file_path = os.path.join(os.getcwd(), args['path'])
|
test_file_path = os.path.join(os.getcwd(), args["path"])
|
||||||
|
|
||||||
if not os.path.isfile(test_file_path):
|
if not os.path.isfile(test_file_path):
|
||||||
print(f"ERROR: test file '{test_file_path}' does not exist")
|
print(f"ERROR: test file '{test_file_path}' does not exist")
|
||||||
|
@ -45,7 +75,7 @@ if test_file_path:
|
||||||
test_file_extension = pathlib.Path(test_file_path).suffix
|
test_file_extension = pathlib.Path(test_file_path).suffix
|
||||||
test_file_name = pathlib.Path(test_file_path).stem
|
test_file_name = pathlib.Path(test_file_path).stem
|
||||||
|
|
||||||
if not test_file_extension in '.spec.sql':
|
if not test_file_extension in ".spec.sql":
|
||||||
print(
|
print(
|
||||||
"ERROR: Unrecognized test extension. Valid extensions are: .sql and .spec"
|
"ERROR: Unrecognized test extension. Valid extensions are: .sql and .spec"
|
||||||
)
|
)
|
||||||
|
@ -56,12 +86,12 @@ if test_file_name in test_files_to_skip:
|
||||||
print(f"WARNING: Skipping exceptional test: '{test_file_name}'")
|
print(f"WARNING: Skipping exceptional test: '{test_file_name}'")
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
|
||||||
test_schedule = ''
|
test_schedule = ""
|
||||||
|
|
||||||
# find related schedule
|
# find related schedule
|
||||||
for schedule_file_path in sorted(glob(os.path.join(regress_dir, "*_schedule"))):
|
for schedule_file_path in sorted(glob(os.path.join(regress_dir, "*_schedule"))):
|
||||||
for schedule_line in open(schedule_file_path, 'r'):
|
for schedule_line in open(schedule_file_path, "r"):
|
||||||
if re.search(r'\b' + test_file_name + r'\b', schedule_line):
|
if re.search(r"\b" + test_file_name + r"\b", schedule_line):
|
||||||
test_schedule = pathlib.Path(schedule_file_path).stem
|
test_schedule = pathlib.Path(schedule_file_path).stem
|
||||||
if use_whole_schedule_line:
|
if use_whole_schedule_line:
|
||||||
test_schedule_line = schedule_line
|
test_schedule_line = schedule_line
|
||||||
|
@ -74,55 +104,55 @@ for schedule_file_path in sorted(glob(os.path.join(regress_dir, "*_schedule"))):
|
||||||
|
|
||||||
# map suitable schedule
|
# map suitable schedule
|
||||||
if not test_schedule:
|
if not test_schedule:
|
||||||
print(
|
print(f"WARNING: Could not find any schedule for '{test_file_name}'")
|
||||||
f"WARNING: Could not find any schedule for '{test_file_name}'"
|
|
||||||
)
|
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
elif "isolation" in test_schedule:
|
elif "isolation" in test_schedule:
|
||||||
test_schedule = 'base_isolation_schedule'
|
test_schedule = "base_isolation_schedule"
|
||||||
elif "failure" in test_schedule:
|
elif "failure" in test_schedule:
|
||||||
test_schedule = 'failure_base_schedule'
|
test_schedule = "failure_base_schedule"
|
||||||
elif "enterprise" in test_schedule:
|
elif "enterprise" in test_schedule:
|
||||||
test_schedule = 'enterprise_minimal_schedule'
|
test_schedule = "enterprise_minimal_schedule"
|
||||||
elif "split" in test_schedule:
|
elif "split" in test_schedule:
|
||||||
test_schedule = 'minimal_schedule'
|
test_schedule = "minimal_schedule"
|
||||||
elif "mx" in test_schedule:
|
elif "mx" in test_schedule:
|
||||||
if use_base_schedule:
|
if use_base_schedule:
|
||||||
test_schedule = 'mx_base_schedule'
|
test_schedule = "mx_base_schedule"
|
||||||
else:
|
else:
|
||||||
test_schedule = 'mx_minimal_schedule'
|
test_schedule = "mx_minimal_schedule"
|
||||||
elif "operations" in test_schedule:
|
elif "operations" in test_schedule:
|
||||||
test_schedule = 'minimal_schedule'
|
test_schedule = "minimal_schedule"
|
||||||
elif test_schedule in config.ARBITRARY_SCHEDULE_NAMES:
|
elif test_schedule in config.ARBITRARY_SCHEDULE_NAMES:
|
||||||
print(f"WARNING: Arbitrary config schedule ({test_schedule}) is not supported.")
|
print(f"WARNING: Arbitrary config schedule ({test_schedule}) is not supported.")
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
else:
|
else:
|
||||||
if use_base_schedule:
|
if use_base_schedule:
|
||||||
test_schedule = 'base_schedule'
|
test_schedule = "base_schedule"
|
||||||
else:
|
else:
|
||||||
test_schedule = 'minimal_schedule'
|
test_schedule = "minimal_schedule"
|
||||||
|
|
||||||
# copy base schedule to a temp file and append test_schedule_line
|
# copy base schedule to a temp file and append test_schedule_line
|
||||||
# to be able to run tests in parallel (if test_schedule_line is a parallel group.)
|
# to be able to run tests in parallel (if test_schedule_line is a parallel group.)
|
||||||
tmp_schedule_path = os.path.join(regress_dir, f"tmp_schedule_{ random.randint(1, 10000)}")
|
tmp_schedule_path = os.path.join(
|
||||||
|
regress_dir, f"tmp_schedule_{ random.randint(1, 10000)}"
|
||||||
|
)
|
||||||
# some tests don't need a schedule to run
|
# some tests don't need a schedule to run
|
||||||
# e.g tests that are in the first place in their own schedule
|
# e.g tests that are in the first place in their own schedule
|
||||||
if test_file_name not in test_files_to_run_without_schedule:
|
if test_file_name not in test_files_to_run_without_schedule:
|
||||||
shutil.copy2(os.path.join(regress_dir, test_schedule), tmp_schedule_path)
|
shutil.copy2(os.path.join(regress_dir, test_schedule), tmp_schedule_path)
|
||||||
with open(tmp_schedule_path, "a") as myfile:
|
with open(tmp_schedule_path, "a") as myfile:
|
||||||
for i in range(args['repeat']):
|
for i in range(args["repeat"]):
|
||||||
myfile.write(test_schedule_line)
|
myfile.write(test_schedule_line)
|
||||||
|
|
||||||
# find suitable make recipe
|
# find suitable make recipe
|
||||||
if "isolation" in test_schedule:
|
if "isolation" in test_schedule:
|
||||||
make_recipe = 'check-isolation-custom-schedule'
|
make_recipe = "check-isolation-custom-schedule"
|
||||||
elif "failure" in test_schedule:
|
elif "failure" in test_schedule:
|
||||||
make_recipe = 'check-failure-custom-schedule'
|
make_recipe = "check-failure-custom-schedule"
|
||||||
else:
|
else:
|
||||||
make_recipe = 'check-custom-schedule'
|
make_recipe = "check-custom-schedule"
|
||||||
|
|
||||||
if args['valgrind']:
|
if args["valgrind"]:
|
||||||
make_recipe += '-vg'
|
make_recipe += "-vg"
|
||||||
|
|
||||||
# prepare command to run tests
|
# prepare command to run tests
|
||||||
test_command = f"make -C {regress_dir} {make_recipe} SCHEDULE='{pathlib.Path(tmp_schedule_path).stem}'"
|
test_command = f"make -C {regress_dir} {make_recipe} SCHEDULE='{pathlib.Path(tmp_schedule_path).stem}'"
|
||||||
|
|
|
@ -112,7 +112,11 @@ def main(config):
|
||||||
config.node_name_to_ports.keys(),
|
config.node_name_to_ports.keys(),
|
||||||
)
|
)
|
||||||
common.start_databases(
|
common.start_databases(
|
||||||
config.new_bindir, config.new_datadir, config.node_name_to_ports, config.name, {}
|
config.new_bindir,
|
||||||
|
config.new_datadir,
|
||||||
|
config.node_name_to_ports,
|
||||||
|
config.name,
|
||||||
|
{},
|
||||||
)
|
)
|
||||||
citus_finish_pg_upgrade(config.new_bindir, config.node_name_to_ports.values())
|
citus_finish_pg_upgrade(config.new_bindir, config.node_name_to_ports.values())
|
||||||
|
|
||||||
|
|
|
@ -20,15 +20,17 @@ logging.basicConfig(format="%(asctime)s %(levelname)s %(message)s", level=loggin
|
||||||
|
|
||||||
# I. Command Strings
|
# I. Command Strings
|
||||||
|
|
||||||
|
|
||||||
class Handler:
|
class Handler:
|
||||||
'''
|
"""
|
||||||
This class hierarchy serves two purposes:
|
This class hierarchy serves two purposes:
|
||||||
1. Allow command strings to be evaluated. Once evaluated you'll have a Handler you can
|
1. Allow command strings to be evaluated. Once evaluated you'll have a Handler you can
|
||||||
pass packets to
|
pass packets to
|
||||||
2. Process packets as they come in and decide what to do with them.
|
2. Process packets as they come in and decide what to do with them.
|
||||||
|
|
||||||
Subclasses which want to change how packets are handled should override _handle.
|
Subclasses which want to change how packets are handled should override _handle.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(self, root=None):
|
def __init__(self, root=None):
|
||||||
# all packets are first sent to the root handler to be processed
|
# all packets are first sent to the root handler to be processed
|
||||||
self.root = root if root else self
|
self.root = root if root else self
|
||||||
|
@ -38,30 +40,31 @@ class Handler:
|
||||||
def _accept(self, flow, message):
|
def _accept(self, flow, message):
|
||||||
result = self._handle(flow, message)
|
result = self._handle(flow, message)
|
||||||
|
|
||||||
if result == 'pass':
|
if result == "pass":
|
||||||
# defer to our child
|
# defer to our child
|
||||||
if not self.next:
|
if not self.next:
|
||||||
raise Exception("we don't know what to do!")
|
raise Exception("we don't know what to do!")
|
||||||
|
|
||||||
if self.next._accept(flow, message) == 'stop':
|
if self.next._accept(flow, message) == "stop":
|
||||||
if self.root is not self:
|
if self.root is not self:
|
||||||
return 'stop'
|
return "stop"
|
||||||
self.next = KillHandler(self)
|
self.next = KillHandler(self)
|
||||||
flow.kill()
|
flow.kill()
|
||||||
else:
|
else:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _handle(self, flow, message):
|
def _handle(self, flow, message):
|
||||||
'''
|
"""
|
||||||
Handlers can return one of three things:
|
Handlers can return one of three things:
|
||||||
- "done" tells the parent to stop processing. This performs the default action,
|
- "done" tells the parent to stop processing. This performs the default action,
|
||||||
which is to allow the packet to be sent.
|
which is to allow the packet to be sent.
|
||||||
- "pass" means to delegate to self.next and do whatever it wants
|
- "pass" means to delegate to self.next and do whatever it wants
|
||||||
- "stop" means all processing will stop, and all connections will be killed
|
- "stop" means all processing will stop, and all connections will be killed
|
||||||
'''
|
"""
|
||||||
# subclasses must implement this
|
# subclasses must implement this
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class FilterableMixin:
|
class FilterableMixin:
|
||||||
def contains(self, pattern):
|
def contains(self, pattern):
|
||||||
self.next = Contains(self.root, pattern)
|
self.next = Contains(self.root, pattern)
|
||||||
|
@ -76,7 +79,7 @@ class FilterableMixin:
|
||||||
return self.next
|
return self.next
|
||||||
|
|
||||||
def __getattr__(self, attr):
|
def __getattr__(self, attr):
|
||||||
'''
|
"""
|
||||||
Methods such as .onQuery trigger when a packet with that name is intercepted
|
Methods such as .onQuery trigger when a packet with that name is intercepted
|
||||||
|
|
||||||
Adds support for commands such as:
|
Adds support for commands such as:
|
||||||
|
@ -85,14 +88,17 @@ class FilterableMixin:
|
||||||
Returns a function because the above command is resolved in two steps:
|
Returns a function because the above command is resolved in two steps:
|
||||||
conn.onQuery becomes conn.__getattr__("onQuery")
|
conn.onQuery becomes conn.__getattr__("onQuery")
|
||||||
conn.onQuery(query="COPY") becomes conn.__getattr__("onQuery")(query="COPY")
|
conn.onQuery(query="COPY") becomes conn.__getattr__("onQuery")(query="COPY")
|
||||||
'''
|
"""
|
||||||
if attr.startswith('on'):
|
if attr.startswith("on"):
|
||||||
|
|
||||||
def doit(**kwargs):
|
def doit(**kwargs):
|
||||||
self.next = OnPacket(self.root, attr[2:], kwargs)
|
self.next = OnPacket(self.root, attr[2:], kwargs)
|
||||||
return self.next
|
return self.next
|
||||||
|
|
||||||
return doit
|
return doit
|
||||||
raise AttributeError
|
raise AttributeError
|
||||||
|
|
||||||
|
|
||||||
class ActionsMixin:
|
class ActionsMixin:
|
||||||
def kill(self):
|
def kill(self):
|
||||||
self.next = KillHandler(self.root)
|
self.next = KillHandler(self.root)
|
||||||
|
@ -118,29 +124,37 @@ class ActionsMixin:
|
||||||
self.next = ConnectDelayHandler(self.root, timeMs)
|
self.next = ConnectDelayHandler(self.root, timeMs)
|
||||||
return self.next
|
return self.next
|
||||||
|
|
||||||
|
|
||||||
class AcceptHandler(Handler):
|
class AcceptHandler(Handler):
|
||||||
def __init__(self, root):
|
def __init__(self, root):
|
||||||
super().__init__(root)
|
super().__init__(root)
|
||||||
|
|
||||||
def _handle(self, flow, message):
|
def _handle(self, flow, message):
|
||||||
return 'done'
|
return "done"
|
||||||
|
|
||||||
|
|
||||||
class KillHandler(Handler):
|
class KillHandler(Handler):
|
||||||
def __init__(self, root):
|
def __init__(self, root):
|
||||||
super().__init__(root)
|
super().__init__(root)
|
||||||
|
|
||||||
def _handle(self, flow, message):
|
def _handle(self, flow, message):
|
||||||
flow.kill()
|
flow.kill()
|
||||||
return 'done'
|
return "done"
|
||||||
|
|
||||||
|
|
||||||
class KillAllHandler(Handler):
|
class KillAllHandler(Handler):
|
||||||
def __init__(self, root):
|
def __init__(self, root):
|
||||||
super().__init__(root)
|
super().__init__(root)
|
||||||
|
|
||||||
def _handle(self, flow, message):
|
def _handle(self, flow, message):
|
||||||
return 'stop'
|
return "stop"
|
||||||
|
|
||||||
|
|
||||||
class ResetHandler(Handler):
|
class ResetHandler(Handler):
|
||||||
# try to force a RST to be sent, something went very wrong!
|
# try to force a RST to be sent, something went very wrong!
|
||||||
def __init__(self, root):
|
def __init__(self, root):
|
||||||
super().__init__(root)
|
super().__init__(root)
|
||||||
|
|
||||||
def _handle(self, flow, message):
|
def _handle(self, flow, message):
|
||||||
flow.kill() # tell mitmproxy this connection should be closed
|
flow.kill() # tell mitmproxy this connection should be closed
|
||||||
|
|
||||||
|
@ -152,8 +166,9 @@ class ResetHandler(Handler):
|
||||||
# cause linux to send a RST
|
# cause linux to send a RST
|
||||||
LINGER_ON, LINGER_TIMEOUT = 1, 0
|
LINGER_ON, LINGER_TIMEOUT = 1, 0
|
||||||
conn.setsockopt(
|
conn.setsockopt(
|
||||||
socket.SOL_SOCKET, socket.SO_LINGER,
|
socket.SOL_SOCKET,
|
||||||
struct.pack('ii', LINGER_ON, LINGER_TIMEOUT)
|
socket.SO_LINGER,
|
||||||
|
struct.pack("ii", LINGER_ON, LINGER_TIMEOUT),
|
||||||
)
|
)
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
@ -161,28 +176,35 @@ class ResetHandler(Handler):
|
||||||
# tries to call conn.shutdown(), but there's nothing else to clean up so that's
|
# tries to call conn.shutdown(), but there's nothing else to clean up so that's
|
||||||
# maybe okay
|
# maybe okay
|
||||||
|
|
||||||
return 'done'
|
return "done"
|
||||||
|
|
||||||
|
|
||||||
class CancelHandler(Handler):
|
class CancelHandler(Handler):
|
||||||
'Send a SIGINT to the process'
|
"Send a SIGINT to the process"
|
||||||
|
|
||||||
def __init__(self, root, pid):
|
def __init__(self, root, pid):
|
||||||
super().__init__(root)
|
super().__init__(root)
|
||||||
self.pid = pid
|
self.pid = pid
|
||||||
|
|
||||||
def _handle(self, flow, message):
|
def _handle(self, flow, message):
|
||||||
os.kill(self.pid, signal.SIGINT)
|
os.kill(self.pid, signal.SIGINT)
|
||||||
# give the signal a chance to be received before we let the packet through
|
# give the signal a chance to be received before we let the packet through
|
||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
return 'done'
|
return "done"
|
||||||
|
|
||||||
|
|
||||||
class ConnectDelayHandler(Handler):
|
class ConnectDelayHandler(Handler):
|
||||||
'Delay the initial packet by sleeping before deciding what to do'
|
"Delay the initial packet by sleeping before deciding what to do"
|
||||||
|
|
||||||
def __init__(self, root, timeMs):
|
def __init__(self, root, timeMs):
|
||||||
super().__init__(root)
|
super().__init__(root)
|
||||||
self.timeMs = timeMs
|
self.timeMs = timeMs
|
||||||
|
|
||||||
def _handle(self, flow, message):
|
def _handle(self, flow, message):
|
||||||
if message.is_initial:
|
if message.is_initial:
|
||||||
time.sleep(self.timeMs / 1000.0)
|
time.sleep(self.timeMs / 1000.0)
|
||||||
return 'done'
|
return "done"
|
||||||
|
|
||||||
|
|
||||||
class Contains(Handler, ActionsMixin, FilterableMixin):
|
class Contains(Handler, ActionsMixin, FilterableMixin):
|
||||||
def __init__(self, root, pattern):
|
def __init__(self, root, pattern):
|
||||||
|
@ -191,8 +213,9 @@ class Contains(Handler, ActionsMixin, FilterableMixin):
|
||||||
|
|
||||||
def _handle(self, flow, message):
|
def _handle(self, flow, message):
|
||||||
if self.pattern in message.content:
|
if self.pattern in message.content:
|
||||||
return 'pass'
|
return "pass"
|
||||||
return 'done'
|
return "done"
|
||||||
|
|
||||||
|
|
||||||
class Matches(Handler, ActionsMixin, FilterableMixin):
|
class Matches(Handler, ActionsMixin, FilterableMixin):
|
||||||
def __init__(self, root, pattern):
|
def __init__(self, root, pattern):
|
||||||
|
@ -201,47 +224,56 @@ class Matches(Handler, ActionsMixin, FilterableMixin):
|
||||||
|
|
||||||
def _handle(self, flow, message):
|
def _handle(self, flow, message):
|
||||||
if self.pattern.search(message.content):
|
if self.pattern.search(message.content):
|
||||||
return 'pass'
|
return "pass"
|
||||||
return 'done'
|
return "done"
|
||||||
|
|
||||||
|
|
||||||
class After(Handler, ActionsMixin, FilterableMixin):
|
class After(Handler, ActionsMixin, FilterableMixin):
|
||||||
"Don't pass execution to our child until we've handled 'times' messages"
|
"Don't pass execution to our child until we've handled 'times' messages"
|
||||||
|
|
||||||
def __init__(self, root, times):
|
def __init__(self, root, times):
|
||||||
super().__init__(root)
|
super().__init__(root)
|
||||||
self.target = times
|
self.target = times
|
||||||
|
|
||||||
def _handle(self, flow, message):
|
def _handle(self, flow, message):
|
||||||
if not hasattr(flow, '_after_count'):
|
if not hasattr(flow, "_after_count"):
|
||||||
flow._after_count = 0
|
flow._after_count = 0
|
||||||
|
|
||||||
if flow._after_count >= self.target:
|
if flow._after_count >= self.target:
|
||||||
return 'pass'
|
return "pass"
|
||||||
|
|
||||||
flow._after_count += 1
|
flow._after_count += 1
|
||||||
return 'done'
|
return "done"
|
||||||
|
|
||||||
|
|
||||||
class OnPacket(Handler, ActionsMixin, FilterableMixin):
|
class OnPacket(Handler, ActionsMixin, FilterableMixin):
|
||||||
'''Triggers when a packet of the specified kind comes around'''
|
"""Triggers when a packet of the specified kind comes around"""
|
||||||
|
|
||||||
def __init__(self, root, packet_kind, kwargs):
|
def __init__(self, root, packet_kind, kwargs):
|
||||||
super().__init__(root)
|
super().__init__(root)
|
||||||
self.packet_kind = packet_kind
|
self.packet_kind = packet_kind
|
||||||
self.filters = kwargs
|
self.filters = kwargs
|
||||||
|
|
||||||
def _handle(self, flow, message):
|
def _handle(self, flow, message):
|
||||||
if not message.parsed:
|
if not message.parsed:
|
||||||
# if this is the first message in the connection we just skip it
|
# if this is the first message in the connection we just skip it
|
||||||
return 'done'
|
return "done"
|
||||||
for msg in message.parsed:
|
for msg in message.parsed:
|
||||||
typ = structs.message_type(msg, from_frontend=message.from_client)
|
typ = structs.message_type(msg, from_frontend=message.from_client)
|
||||||
if typ == self.packet_kind:
|
if typ == self.packet_kind:
|
||||||
matches = structs.message_matches(msg, self.filters, message.from_client)
|
matches = structs.message_matches(
|
||||||
|
msg, self.filters, message.from_client
|
||||||
|
)
|
||||||
if matches:
|
if matches:
|
||||||
return 'pass'
|
return "pass"
|
||||||
return 'done'
|
return "done"
|
||||||
|
|
||||||
|
|
||||||
class RootHandler(Handler, ActionsMixin, FilterableMixin):
|
class RootHandler(Handler, ActionsMixin, FilterableMixin):
|
||||||
def _handle(self, flow, message):
|
def _handle(self, flow, message):
|
||||||
# do whatever the next Handler tells us to do
|
# do whatever the next Handler tells us to do
|
||||||
return 'pass'
|
return "pass"
|
||||||
|
|
||||||
|
|
||||||
class RecorderCommand:
|
class RecorderCommand:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
@ -250,23 +282,26 @@ class RecorderCommand:
|
||||||
|
|
||||||
def dump(self):
|
def dump(self):
|
||||||
# When the user calls dump() we return everything we've captured
|
# When the user calls dump() we return everything we've captured
|
||||||
self.command = 'dump'
|
self.command = "dump"
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
# If the user calls reset() we dump all captured packets without returning them
|
# If the user calls reset() we dump all captured packets without returning them
|
||||||
self.command = 'reset'
|
self.command = "reset"
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
# II. Utilities for interfacing with mitmproxy
|
# II. Utilities for interfacing with mitmproxy
|
||||||
|
|
||||||
|
|
||||||
def build_handler(spec):
|
def build_handler(spec):
|
||||||
'Turns a command string into a RootHandler ready to accept packets'
|
"Turns a command string into a RootHandler ready to accept packets"
|
||||||
root = RootHandler()
|
root = RootHandler()
|
||||||
recorder = RecorderCommand()
|
recorder = RecorderCommand()
|
||||||
handler = eval(spec, {'__builtins__': {}}, {'conn': root, 'recorder': recorder})
|
handler = eval(spec, {"__builtins__": {}}, {"conn": root, "recorder": recorder})
|
||||||
return handler.root
|
return handler.root
|
||||||
|
|
||||||
|
|
||||||
# a bunch of globals
|
# a bunch of globals
|
||||||
|
|
||||||
handler = None # the current handler used to process packets
|
handler = None # the current handler used to process packets
|
||||||
|
@ -274,25 +309,25 @@ command_thread = None # sits on the fifo and waits for new commands to come in
|
||||||
captured_messages = queue.Queue() # where we store messages used for recorder.dump()
|
captured_messages = queue.Queue() # where we store messages used for recorder.dump()
|
||||||
connection_count = count() # so we can give connections ids in recorder.dump()
|
connection_count = count() # so we can give connections ids in recorder.dump()
|
||||||
|
|
||||||
def listen_for_commands(fifoname):
|
|
||||||
|
|
||||||
|
def listen_for_commands(fifoname):
|
||||||
def emit_row(conn, from_client, message):
|
def emit_row(conn, from_client, message):
|
||||||
# we're using the COPY text format. It requires us to escape backslashes
|
# we're using the COPY text format. It requires us to escape backslashes
|
||||||
cleaned = message.replace('\\', '\\\\')
|
cleaned = message.replace("\\", "\\\\")
|
||||||
source = 'coordinator' if from_client else 'worker'
|
source = "coordinator" if from_client else "worker"
|
||||||
return '{}\t{}\t{}'.format(conn, source, cleaned)
|
return "{}\t{}\t{}".format(conn, source, cleaned)
|
||||||
|
|
||||||
def emit_message(message):
|
def emit_message(message):
|
||||||
if message.is_initial:
|
if message.is_initial:
|
||||||
return emit_row(
|
return emit_row(
|
||||||
message.connection_id, message.from_client, '[initial message]'
|
message.connection_id, message.from_client, "[initial message]"
|
||||||
)
|
)
|
||||||
|
|
||||||
pretty = structs.print(message.parsed)
|
pretty = structs.print(message.parsed)
|
||||||
return emit_row(message.connection_id, message.from_client, pretty)
|
return emit_row(message.connection_id, message.from_client, pretty)
|
||||||
|
|
||||||
def all_items(queue_):
|
def all_items(queue_):
|
||||||
'Pulls everything out of the queue without blocking'
|
"Pulls everything out of the queue without blocking"
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
yield queue_.get(block=False)
|
yield queue_.get(block=False)
|
||||||
|
@ -300,23 +335,27 @@ def listen_for_commands(fifoname):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def drop_terminate_messages(messages):
|
def drop_terminate_messages(messages):
|
||||||
'''
|
"""
|
||||||
Terminate() messages happen eventually, Citus doesn't feel any need to send them
|
Terminate() messages happen eventually, Citus doesn't feel any need to send them
|
||||||
immediately, so tests which embed them aren't reproducible and fail to timing
|
immediately, so tests which embed them aren't reproducible and fail to timing
|
||||||
issues. Here we simply drop those messages.
|
issues. Here we simply drop those messages.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def isTerminate(msg, from_client):
|
def isTerminate(msg, from_client):
|
||||||
kind = structs.message_type(msg, from_client)
|
kind = structs.message_type(msg, from_client)
|
||||||
return kind == 'Terminate'
|
return kind == "Terminate"
|
||||||
|
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if not message.parsed:
|
if not message.parsed:
|
||||||
yield message
|
yield message
|
||||||
continue
|
continue
|
||||||
message.parsed = ListContainer([
|
message.parsed = ListContainer(
|
||||||
msg for msg in message.parsed
|
[
|
||||||
|
msg
|
||||||
|
for msg in message.parsed
|
||||||
if not isTerminate(msg, message.from_client)
|
if not isTerminate(msg, message.from_client)
|
||||||
])
|
]
|
||||||
|
)
|
||||||
message.parsed.from_frontend = message.from_client
|
message.parsed.from_frontend = message.from_client
|
||||||
if len(message.parsed) == 0:
|
if len(message.parsed) == 0:
|
||||||
continue
|
continue
|
||||||
|
@ -324,35 +363,35 @@ def listen_for_commands(fifoname):
|
||||||
|
|
||||||
def handle_recorder(recorder):
|
def handle_recorder(recorder):
|
||||||
global connection_count
|
global connection_count
|
||||||
result = ''
|
result = ""
|
||||||
|
|
||||||
if recorder.command == 'reset':
|
if recorder.command == "reset":
|
||||||
result = ''
|
result = ""
|
||||||
connection_count = count()
|
connection_count = count()
|
||||||
elif recorder.command != 'dump':
|
elif recorder.command != "dump":
|
||||||
# this should never happen
|
# this should never happen
|
||||||
raise Exception('Unrecognized command: {}'.format(recorder.command))
|
raise Exception("Unrecognized command: {}".format(recorder.command))
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
messages = all_items(captured_messages)
|
messages = all_items(captured_messages)
|
||||||
messages = drop_terminate_messages(messages)
|
messages = drop_terminate_messages(messages)
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if recorder.command == 'reset':
|
if recorder.command == "reset":
|
||||||
continue
|
continue
|
||||||
results.append(emit_message(message))
|
results.append(emit_message(message))
|
||||||
result = '\n'.join(results)
|
result = "\n".join(results)
|
||||||
|
|
||||||
logging.debug('about to write to fifo')
|
logging.debug("about to write to fifo")
|
||||||
with open(fifoname, mode='w') as fifo:
|
with open(fifoname, mode="w") as fifo:
|
||||||
logging.debug('successfully opened the fifo for writing')
|
logging.debug("successfully opened the fifo for writing")
|
||||||
fifo.write('{}'.format(result))
|
fifo.write("{}".format(result))
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
logging.debug('about to read from fifo')
|
logging.debug("about to read from fifo")
|
||||||
with open(fifoname, mode='r') as fifo:
|
with open(fifoname, mode="r") as fifo:
|
||||||
logging.debug('successfully opened the fifo for reading')
|
logging.debug("successfully opened the fifo for reading")
|
||||||
slug = fifo.read()
|
slug = fifo.read()
|
||||||
logging.info('received new command: %s', slug.rstrip())
|
logging.info("received new command: %s", slug.rstrip())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
handler = build_handler(slug)
|
handler = build_handler(slug)
|
||||||
|
@ -371,13 +410,14 @@ def listen_for_commands(fifoname):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
result = str(e)
|
result = str(e)
|
||||||
else:
|
else:
|
||||||
result = ''
|
result = ""
|
||||||
|
|
||||||
|
logging.debug("about to write to fifo")
|
||||||
|
with open(fifoname, mode="w") as fifo:
|
||||||
|
logging.debug("successfully opened the fifo for writing")
|
||||||
|
fifo.write("{}\n".format(result))
|
||||||
|
logging.info("responded to command: %s", result.split("\n")[0])
|
||||||
|
|
||||||
logging.debug('about to write to fifo')
|
|
||||||
with open(fifoname, mode='w') as fifo:
|
|
||||||
logging.debug('successfully opened the fifo for writing')
|
|
||||||
fifo.write('{}\n'.format(result))
|
|
||||||
logging.info('responded to command: %s', result.split("\n")[0])
|
|
||||||
|
|
||||||
def create_thread(fifoname):
|
def create_thread(fifoname):
|
||||||
global command_thread
|
global command_thread
|
||||||
|
@ -388,42 +428,46 @@ def create_thread(fifoname):
|
||||||
return
|
return
|
||||||
|
|
||||||
if command_thread:
|
if command_thread:
|
||||||
print('cannot change the fifo path once mitmproxy has started');
|
print("cannot change the fifo path once mitmproxy has started")
|
||||||
return
|
return
|
||||||
|
|
||||||
command_thread = threading.Thread(target=listen_for_commands, args=(fifoname,), daemon=True)
|
command_thread = threading.Thread(
|
||||||
|
target=listen_for_commands, args=(fifoname,), daemon=True
|
||||||
|
)
|
||||||
command_thread.start()
|
command_thread.start()
|
||||||
|
|
||||||
|
|
||||||
# III. mitmproxy callbacks
|
# III. mitmproxy callbacks
|
||||||
|
|
||||||
|
|
||||||
def load(loader):
|
def load(loader):
|
||||||
loader.add_option('slug', str, 'conn.allow()', "A script to run")
|
loader.add_option("slug", str, "conn.allow()", "A script to run")
|
||||||
loader.add_option('fifo', str, '', "Which fifo to listen on for commands")
|
loader.add_option("fifo", str, "", "Which fifo to listen on for commands")
|
||||||
|
|
||||||
|
|
||||||
def configure(updated):
|
def configure(updated):
|
||||||
global handler
|
global handler
|
||||||
|
|
||||||
if 'slug' in updated:
|
if "slug" in updated:
|
||||||
text = ctx.options.slug
|
text = ctx.options.slug
|
||||||
handler = build_handler(text)
|
handler = build_handler(text)
|
||||||
|
|
||||||
if 'fifo' in updated:
|
if "fifo" in updated:
|
||||||
fifoname = ctx.options.fifo
|
fifoname = ctx.options.fifo
|
||||||
create_thread(fifoname)
|
create_thread(fifoname)
|
||||||
|
|
||||||
|
|
||||||
def tcp_message(flow: tcp.TCPFlow):
|
def tcp_message(flow: tcp.TCPFlow):
|
||||||
'''
|
"""
|
||||||
This callback is hit every time mitmproxy receives a packet. It's the main entrypoint
|
This callback is hit every time mitmproxy receives a packet. It's the main entrypoint
|
||||||
into this script.
|
into this script.
|
||||||
'''
|
"""
|
||||||
global connection_count
|
global connection_count
|
||||||
|
|
||||||
tcp_msg = flow.messages[-1]
|
tcp_msg = flow.messages[-1]
|
||||||
|
|
||||||
# Keep track of all the different connections, assign a unique id to each
|
# Keep track of all the different connections, assign a unique id to each
|
||||||
if not hasattr(flow, 'connection_id'):
|
if not hasattr(flow, "connection_id"):
|
||||||
flow.connection_id = next(connection_count)
|
flow.connection_id = next(connection_count)
|
||||||
tcp_msg.connection_id = flow.connection_id
|
tcp_msg.connection_id = flow.connection_id
|
||||||
|
|
||||||
|
@ -434,7 +478,9 @@ def tcp_message(flow: tcp.TCPFlow):
|
||||||
# skip parsing initial messages for now, they're not important
|
# skip parsing initial messages for now, they're not important
|
||||||
tcp_msg.parsed = None
|
tcp_msg.parsed = None
|
||||||
else:
|
else:
|
||||||
tcp_msg.parsed = structs.parse(tcp_msg.content, from_frontend=tcp_msg.from_client)
|
tcp_msg.parsed = structs.parse(
|
||||||
|
tcp_msg.content, from_frontend=tcp_msg.from_client
|
||||||
|
)
|
||||||
|
|
||||||
# record the message, for debugging purposes
|
# record the message, for debugging purposes
|
||||||
captured_messages.put(tcp_msg)
|
captured_messages.put(tcp_msg)
|
||||||
|
|
|
@ -1,8 +1,25 @@
|
||||||
from construct import (
|
from construct import (
|
||||||
Struct,
|
Struct,
|
||||||
Int8ub, Int16ub, Int32ub, Int16sb, Int32sb,
|
Int8ub,
|
||||||
Bytes, CString, Computed, Switch, Seek, this, Pointer,
|
Int16ub,
|
||||||
GreedyRange, Enum, Byte, Probe, FixedSized, RestreamData, GreedyBytes, Array
|
Int32ub,
|
||||||
|
Int16sb,
|
||||||
|
Int32sb,
|
||||||
|
Bytes,
|
||||||
|
CString,
|
||||||
|
Computed,
|
||||||
|
Switch,
|
||||||
|
Seek,
|
||||||
|
this,
|
||||||
|
Pointer,
|
||||||
|
GreedyRange,
|
||||||
|
Enum,
|
||||||
|
Byte,
|
||||||
|
Probe,
|
||||||
|
FixedSized,
|
||||||
|
RestreamData,
|
||||||
|
GreedyBytes,
|
||||||
|
Array,
|
||||||
)
|
)
|
||||||
import construct.lib as cl
|
import construct.lib as cl
|
||||||
|
|
||||||
|
@ -11,23 +28,30 @@ import re
|
||||||
# For all possible message formats see:
|
# For all possible message formats see:
|
||||||
# https://www.postgresql.org/docs/current/protocol-message-formats.html
|
# https://www.postgresql.org/docs/current/protocol-message-formats.html
|
||||||
|
|
||||||
|
|
||||||
class MessageMeta(type):
|
class MessageMeta(type):
|
||||||
def __init__(cls, name, bases, namespace):
|
def __init__(cls, name, bases, namespace):
|
||||||
'''
|
"""
|
||||||
__init__ is called every time a subclass of MessageMeta is declared
|
__init__ is called every time a subclass of MessageMeta is declared
|
||||||
'''
|
"""
|
||||||
if not hasattr(cls, "_msgtypes"):
|
if not hasattr(cls, "_msgtypes"):
|
||||||
raise Exception("classes which use MessageMeta must have a '_msgtypes' field")
|
raise Exception(
|
||||||
|
"classes which use MessageMeta must have a '_msgtypes' field"
|
||||||
|
)
|
||||||
|
|
||||||
if not hasattr(cls, "_classes"):
|
if not hasattr(cls, "_classes"):
|
||||||
raise Exception("classes which use MessageMeta must have a '_classes' field")
|
raise Exception(
|
||||||
|
"classes which use MessageMeta must have a '_classes' field"
|
||||||
|
)
|
||||||
|
|
||||||
if not hasattr(cls, "struct"):
|
if not hasattr(cls, "struct"):
|
||||||
# This is one of the direct subclasses
|
# This is one of the direct subclasses
|
||||||
return
|
return
|
||||||
|
|
||||||
if cls.__name__ in cls._classes:
|
if cls.__name__ in cls._classes:
|
||||||
raise Exception("You've already made a class called {}".format( cls.__name__))
|
raise Exception(
|
||||||
|
"You've already made a class called {}".format(cls.__name__)
|
||||||
|
)
|
||||||
cls._classes[cls.__name__] = cls
|
cls._classes[cls.__name__] = cls
|
||||||
|
|
||||||
# add a _type field to the struct so we can identify it while printing structs
|
# add a _type field to the struct so we can identify it while printing structs
|
||||||
|
@ -39,34 +63,41 @@ class MessageMeta(type):
|
||||||
# register the type, so we can tell the parser about it
|
# register the type, so we can tell the parser about it
|
||||||
key = cls.key
|
key = cls.key
|
||||||
if key in cls._msgtypes:
|
if key in cls._msgtypes:
|
||||||
raise Exception('key {} is already assigned to {}'.format(
|
raise Exception(
|
||||||
key, cls._msgtypes[key].__name__)
|
"key {} is already assigned to {}".format(
|
||||||
|
key, cls._msgtypes[key].__name__
|
||||||
|
)
|
||||||
)
|
)
|
||||||
cls._msgtypes[key] = cls
|
cls._msgtypes[key] = cls
|
||||||
|
|
||||||
|
|
||||||
class Message:
|
class Message:
|
||||||
'Do not subclass this object directly. Instead, subclass of one of the below types'
|
"Do not subclass this object directly. Instead, subclass of one of the below types"
|
||||||
|
|
||||||
def print(message):
|
def print(message):
|
||||||
'Define this on subclasses you want to change the representation of'
|
"Define this on subclasses you want to change the representation of"
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def typeof(message):
|
def typeof(message):
|
||||||
'Define this on subclasses you want to change the expressed type of'
|
"Define this on subclasses you want to change the expressed type of"
|
||||||
return message._type
|
return message._type
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _default_print(cls, name, msg):
|
def _default_print(cls, name, msg):
|
||||||
recur = cls.print_message
|
recur = cls.print_message
|
||||||
return "{}({})".format(name, ",".join(
|
return "{}({})".format(
|
||||||
"{}={}".format(key, recur(value)) for key, value in msg.items()
|
name,
|
||||||
if not key.startswith('_')
|
",".join(
|
||||||
))
|
"{}={}".format(key, recur(value))
|
||||||
|
for key, value in msg.items()
|
||||||
|
if not key.startswith("_")
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def find_typeof(cls, msg):
|
def find_typeof(cls, msg):
|
||||||
if not hasattr(cls, "_msgtypes"):
|
if not hasattr(cls, "_msgtypes"):
|
||||||
raise Exception('Do not call this method on Message, call it on a subclass')
|
raise Exception("Do not call this method on Message, call it on a subclass")
|
||||||
if isinstance(msg, cl.ListContainer):
|
if isinstance(msg, cl.ListContainer):
|
||||||
raise ValueError("do not call this on a list of messages")
|
raise ValueError("do not call this on a list of messages")
|
||||||
if not isinstance(msg, cl.Container):
|
if not isinstance(msg, cl.Container):
|
||||||
|
@ -80,7 +111,7 @@ class Message:
|
||||||
@classmethod
|
@classmethod
|
||||||
def print_message(cls, msg):
|
def print_message(cls, msg):
|
||||||
if not hasattr(cls, "_msgtypes"):
|
if not hasattr(cls, "_msgtypes"):
|
||||||
raise Exception('Do not call this method on Message, call it on a subclass')
|
raise Exception("Do not call this method on Message, call it on a subclass")
|
||||||
|
|
||||||
if isinstance(msg, cl.ListContainer):
|
if isinstance(msg, cl.ListContainer):
|
||||||
return repr([cls.print_message(message) for message in msg])
|
return repr([cls.print_message(message) for message in msg])
|
||||||
|
@ -101,38 +132,34 @@ class Message:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def name_to_struct(cls):
|
def name_to_struct(cls):
|
||||||
return {
|
return {_class.__name__: _class.struct for _class in cls._msgtypes.values()}
|
||||||
_class.__name__: _class.struct
|
|
||||||
for _class in cls._msgtypes.values()
|
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def name_to_key(cls):
|
def name_to_key(cls):
|
||||||
return {
|
return {_class.__name__: ord(key) for key, _class in cls._msgtypes.items()}
|
||||||
_class.__name__ : ord(key)
|
|
||||||
for key, _class in cls._msgtypes.items()
|
|
||||||
}
|
|
||||||
|
|
||||||
class SharedMessage(Message, metaclass=MessageMeta):
|
class SharedMessage(Message, metaclass=MessageMeta):
|
||||||
'A message which could be sent by either the frontend or the backend'
|
"A message which could be sent by either the frontend or the backend"
|
||||||
_msgtypes = dict()
|
_msgtypes = dict()
|
||||||
_classes = dict()
|
_classes = dict()
|
||||||
|
|
||||||
|
|
||||||
class FrontendMessage(Message, metaclass=MessageMeta):
|
class FrontendMessage(Message, metaclass=MessageMeta):
|
||||||
'A message which will only be sent be a backend'
|
"A message which will only be sent be a backend"
|
||||||
_msgtypes = dict()
|
_msgtypes = dict()
|
||||||
_classes = dict()
|
_classes = dict()
|
||||||
|
|
||||||
|
|
||||||
class BackendMessage(Message, metaclass=MessageMeta):
|
class BackendMessage(Message, metaclass=MessageMeta):
|
||||||
'A message which will only be sent be a frontend'
|
"A message which will only be sent be a frontend"
|
||||||
_msgtypes = dict()
|
_msgtypes = dict()
|
||||||
_classes = dict()
|
_classes = dict()
|
||||||
|
|
||||||
|
|
||||||
class Query(FrontendMessage):
|
class Query(FrontendMessage):
|
||||||
key = 'Q'
|
key = "Q"
|
||||||
struct = Struct(
|
struct = Struct("query" / CString("ascii"))
|
||||||
"query" / CString("ascii")
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def print(message):
|
def print(message):
|
||||||
|
@ -144,103 +171,114 @@ class Query(FrontendMessage):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def normalize_shards(content):
|
def normalize_shards(content):
|
||||||
'''
|
"""
|
||||||
For example:
|
For example:
|
||||||
>>> normalize_shards(
|
>>> normalize_shards(
|
||||||
>>> 'COPY public.copy_test_120340 (key, value) FROM STDIN WITH (FORMAT BINARY))'
|
>>> 'COPY public.copy_test_120340 (key, value) FROM STDIN WITH (FORMAT BINARY))'
|
||||||
>>> )
|
>>> )
|
||||||
'COPY public.copy_test_XXXXXX (key, value) FROM STDIN WITH (FORMAT BINARY))'
|
'COPY public.copy_test_XXXXXX (key, value) FROM STDIN WITH (FORMAT BINARY))'
|
||||||
'''
|
"""
|
||||||
result = content
|
result = content
|
||||||
pattern = re.compile('public\.[a-z_]+(?P<shardid>[0-9]+)')
|
pattern = re.compile("public\.[a-z_]+(?P<shardid>[0-9]+)")
|
||||||
for match in pattern.finditer(content):
|
for match in pattern.finditer(content):
|
||||||
span = match.span('shardid')
|
span = match.span("shardid")
|
||||||
replacement = 'X'*( span[1] - span[0] )
|
replacement = "X" * (span[1] - span[0])
|
||||||
result = result[: span[0]] + replacement + result[span[1] :]
|
result = result[: span[0]] + replacement + result[span[1] :]
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def normalize_timestamps(content):
|
def normalize_timestamps(content):
|
||||||
'''
|
"""
|
||||||
For example:
|
For example:
|
||||||
>>> normalize_timestamps('2018-06-07 05:18:19.388992-07')
|
>>> normalize_timestamps('2018-06-07 05:18:19.388992-07')
|
||||||
'XXXX-XX-XX XX:XX:XX.XXXXXX-XX'
|
'XXXX-XX-XX XX:XX:XX.XXXXXX-XX'
|
||||||
>>> normalize_timestamps('2018-06-11 05:30:43.01382-07')
|
>>> normalize_timestamps('2018-06-11 05:30:43.01382-07')
|
||||||
'XXXX-XX-XX XX:XX:XX.XXXXXX-XX'
|
'XXXX-XX-XX XX:XX:XX.XXXXXX-XX'
|
||||||
'''
|
"""
|
||||||
|
|
||||||
pattern = re.compile(
|
pattern = re.compile(
|
||||||
'[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}.[0-9]{2,6}-[0-9]{2}'
|
"[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}.[0-9]{2,6}-[0-9]{2}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return re.sub(pattern, 'XXXX-XX-XX XX:XX:XX.XXXXXX-XX', content)
|
return re.sub(pattern, "XXXX-XX-XX XX:XX:XX.XXXXXX-XX", content)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def normalize_assign_txn_id(content):
|
def normalize_assign_txn_id(content):
|
||||||
'''
|
"""
|
||||||
For example:
|
For example:
|
||||||
>>> normalize_assign_txn_id('SELECT assign_distributed_transaction_id(0, 52, ...')
|
>>> normalize_assign_txn_id('SELECT assign_distributed_transaction_id(0, 52, ...')
|
||||||
'SELECT assign_distributed_transaction_id(0, XX, ...'
|
'SELECT assign_distributed_transaction_id(0, XX, ...'
|
||||||
'''
|
"""
|
||||||
|
|
||||||
pattern = re.compile(
|
pattern = re.compile(
|
||||||
'assign_distributed_transaction_id\s*\(' # a method call
|
"assign_distributed_transaction_id\s*\(" # a method call
|
||||||
'\s*[0-9]+\s*,' # an integer first parameter
|
"\s*[0-9]+\s*," # an integer first parameter
|
||||||
'\s*(?P<transaction_id>[0-9]+)' # an integer second parameter
|
"\s*(?P<transaction_id>[0-9]+)" # an integer second parameter
|
||||||
)
|
)
|
||||||
result = content
|
result = content
|
||||||
for match in pattern.finditer(content):
|
for match in pattern.finditer(content):
|
||||||
span = match.span('transaction_id')
|
span = match.span("transaction_id")
|
||||||
result = result[:span[0]] + 'XX' + result[span[1]:]
|
result = result[: span[0]] + "XX" + result[span[1] :]
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
class Terminate(FrontendMessage):
|
class Terminate(FrontendMessage):
|
||||||
key = 'X'
|
key = "X"
|
||||||
struct = Struct()
|
struct = Struct()
|
||||||
|
|
||||||
|
|
||||||
class CopyData(SharedMessage):
|
class CopyData(SharedMessage):
|
||||||
key = 'd'
|
key = "d"
|
||||||
struct = Struct(
|
struct = Struct(
|
||||||
'data' / GreedyBytes # reads all of the data left in this substream
|
"data" / GreedyBytes # reads all of the data left in this substream
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class CopyDone(SharedMessage):
|
class CopyDone(SharedMessage):
|
||||||
key = 'c'
|
key = "c"
|
||||||
struct = Struct()
|
struct = Struct()
|
||||||
|
|
||||||
|
|
||||||
class EmptyQueryResponse(BackendMessage):
|
class EmptyQueryResponse(BackendMessage):
|
||||||
key = 'I'
|
key = "I"
|
||||||
struct = Struct()
|
struct = Struct()
|
||||||
|
|
||||||
|
|
||||||
class CopyOutResponse(BackendMessage):
|
class CopyOutResponse(BackendMessage):
|
||||||
key = 'H'
|
key = "H"
|
||||||
struct = Struct(
|
struct = Struct(
|
||||||
"format" / Int8ub,
|
"format" / Int8ub,
|
||||||
"columncount" / Int16ub,
|
"columncount" / Int16ub,
|
||||||
"columns" / Array(this.columncount, Struct(
|
"columns" / Array(this.columncount, Struct("format" / Int16ub)),
|
||||||
"format" / Int16ub
|
|
||||||
))
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ReadyForQuery(BackendMessage):
|
class ReadyForQuery(BackendMessage):
|
||||||
key='Z'
|
key = "Z"
|
||||||
struct = Struct("state"/Enum(Byte,
|
|
||||||
idle=ord('I'),
|
|
||||||
in_transaction_block=ord('T'),
|
|
||||||
in_failed_transaction_block=ord('E')
|
|
||||||
))
|
|
||||||
|
|
||||||
class CommandComplete(BackendMessage):
|
|
||||||
key = 'C'
|
|
||||||
struct = Struct(
|
struct = Struct(
|
||||||
"command" / CString("ascii")
|
"state"
|
||||||
|
/ Enum(
|
||||||
|
Byte,
|
||||||
|
idle=ord("I"),
|
||||||
|
in_transaction_block=ord("T"),
|
||||||
|
in_failed_transaction_block=ord("E"),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CommandComplete(BackendMessage):
|
||||||
|
key = "C"
|
||||||
|
struct = Struct("command" / CString("ascii"))
|
||||||
|
|
||||||
|
|
||||||
class RowDescription(BackendMessage):
|
class RowDescription(BackendMessage):
|
||||||
key = 'T'
|
key = "T"
|
||||||
struct = Struct(
|
struct = Struct(
|
||||||
"fieldcount" / Int16ub,
|
"fieldcount" / Int16ub,
|
||||||
"fields" / Array(this.fieldcount, Struct(
|
"fields"
|
||||||
|
/ Array(
|
||||||
|
this.fieldcount,
|
||||||
|
Struct(
|
||||||
"_type" / Computed("F"),
|
"_type" / Computed("F"),
|
||||||
"name" / CString("ascii"),
|
"name" / CString("ascii"),
|
||||||
"tableoid" / Int32ub,
|
"tableoid" / Int32ub,
|
||||||
|
@ -249,27 +287,35 @@ class RowDescription(BackendMessage):
|
||||||
"typlen" / Int16sb,
|
"typlen" / Int16sb,
|
||||||
"typmod" / Int32sb,
|
"typmod" / Int32sb,
|
||||||
"format_code" / Int16ub,
|
"format_code" / Int16ub,
|
||||||
))
|
),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class DataRow(BackendMessage):
|
class DataRow(BackendMessage):
|
||||||
key = 'D'
|
key = "D"
|
||||||
struct = Struct(
|
struct = Struct(
|
||||||
"_type" / Computed("data_row"),
|
"_type" / Computed("data_row"),
|
||||||
"columncount" / Int16ub,
|
"columncount" / Int16ub,
|
||||||
"columns" / Array(this.columncount, Struct(
|
"columns"
|
||||||
|
/ Array(
|
||||||
|
this.columncount,
|
||||||
|
Struct(
|
||||||
"_type" / Computed("C"),
|
"_type" / Computed("C"),
|
||||||
"length" / Int16sb,
|
"length" / Int16sb,
|
||||||
"value" / Bytes(this.length)
|
"value" / Bytes(this.length),
|
||||||
))
|
),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationOk(BackendMessage):
|
class AuthenticationOk(BackendMessage):
|
||||||
key = 'R'
|
key = "R"
|
||||||
struct = Struct()
|
struct = Struct()
|
||||||
|
|
||||||
|
|
||||||
class ParameterStatus(BackendMessage):
|
class ParameterStatus(BackendMessage):
|
||||||
key = 'S'
|
key = "S"
|
||||||
struct = Struct(
|
struct = Struct(
|
||||||
"name" / CString("ASCII"),
|
"name" / CString("ASCII"),
|
||||||
"value" / CString("ASCII"),
|
"value" / CString("ASCII"),
|
||||||
|
@ -281,161 +327,156 @@ class ParameterStatus(BackendMessage):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def normalize(name, value):
|
def normalize(name, value):
|
||||||
if name in ('TimeZone', 'server_version'):
|
if name in ("TimeZone", "server_version"):
|
||||||
value = 'XXX'
|
value = "XXX"
|
||||||
return (name, value)
|
return (name, value)
|
||||||
|
|
||||||
|
|
||||||
class BackendKeyData(BackendMessage):
|
class BackendKeyData(BackendMessage):
|
||||||
key = 'K'
|
key = "K"
|
||||||
struct = Struct(
|
struct = Struct("pid" / Int32ub, "key" / Bytes(4))
|
||||||
"pid" / Int32ub,
|
|
||||||
"key" / Bytes(4)
|
|
||||||
)
|
|
||||||
|
|
||||||
def print(message):
|
def print(message):
|
||||||
# Both of these should be censored, for reproducible regression test output
|
# Both of these should be censored, for reproducible regression test output
|
||||||
return "BackendKeyData(XXX)"
|
return "BackendKeyData(XXX)"
|
||||||
|
|
||||||
|
|
||||||
class NoticeResponse(BackendMessage):
|
class NoticeResponse(BackendMessage):
|
||||||
key = 'N'
|
key = "N"
|
||||||
struct = Struct(
|
struct = Struct(
|
||||||
"notices" / GreedyRange(
|
"notices"
|
||||||
|
/ GreedyRange(
|
||||||
Struct(
|
Struct(
|
||||||
"key" / Enum(Byte,
|
"key"
|
||||||
severity=ord('S'),
|
/ Enum(
|
||||||
_severity_not_localized=ord('V'),
|
Byte,
|
||||||
_sql_state=ord('C'),
|
severity=ord("S"),
|
||||||
message=ord('M'),
|
_severity_not_localized=ord("V"),
|
||||||
detail=ord('D'),
|
_sql_state=ord("C"),
|
||||||
hint=ord('H'),
|
message=ord("M"),
|
||||||
_position=ord('P'),
|
detail=ord("D"),
|
||||||
_internal_position=ord('p'),
|
hint=ord("H"),
|
||||||
_internal_query=ord('q'),
|
_position=ord("P"),
|
||||||
_where=ord('W'),
|
_internal_position=ord("p"),
|
||||||
schema_name=ord('s'),
|
_internal_query=ord("q"),
|
||||||
table_name=ord('t'),
|
_where=ord("W"),
|
||||||
column_name=ord('c'),
|
schema_name=ord("s"),
|
||||||
data_type_name=ord('d'),
|
table_name=ord("t"),
|
||||||
constraint_name=ord('n'),
|
column_name=ord("c"),
|
||||||
_file_name=ord('F'),
|
data_type_name=ord("d"),
|
||||||
_line_no=ord('L'),
|
constraint_name=ord("n"),
|
||||||
_routine_name=ord('R')
|
_file_name=ord("F"),
|
||||||
|
_line_no=ord("L"),
|
||||||
|
_routine_name=ord("R"),
|
||||||
),
|
),
|
||||||
"value" / CString("ASCII")
|
"value" / CString("ASCII"),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def print(message):
|
def print(message):
|
||||||
return "NoticeResponse({})".format(", ".join(
|
return "NoticeResponse({})".format(
|
||||||
|
", ".join(
|
||||||
"{}={}".format(response.key, response.value)
|
"{}={}".format(response.key, response.value)
|
||||||
for response in message.notices
|
for response in message.notices
|
||||||
if not response.key.startswith('_')
|
if not response.key.startswith("_")
|
||||||
))
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Parse(FrontendMessage):
|
class Parse(FrontendMessage):
|
||||||
key = 'P'
|
key = "P"
|
||||||
struct = Struct(
|
struct = Struct(
|
||||||
"name" / CString("ASCII"),
|
"name" / CString("ASCII"),
|
||||||
"query" / CString("ASCII"),
|
"query" / CString("ASCII"),
|
||||||
"_parametercount" / Int16ub,
|
"_parametercount" / Int16ub,
|
||||||
"parameters" / Array(
|
"parameters" / Array(this._parametercount, Int32ub),
|
||||||
this._parametercount,
|
|
||||||
Int32ub
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ParseComplete(BackendMessage):
|
class ParseComplete(BackendMessage):
|
||||||
key = '1'
|
key = "1"
|
||||||
struct = Struct()
|
struct = Struct()
|
||||||
|
|
||||||
|
|
||||||
class Bind(FrontendMessage):
|
class Bind(FrontendMessage):
|
||||||
key = 'B'
|
key = "B"
|
||||||
struct = Struct(
|
struct = Struct(
|
||||||
"destination_portal" / CString("ASCII"),
|
"destination_portal" / CString("ASCII"),
|
||||||
"prepared_statement" / CString("ASCII"),
|
"prepared_statement" / CString("ASCII"),
|
||||||
"_parameter_format_code_count" / Int16ub,
|
"_parameter_format_code_count" / Int16ub,
|
||||||
"parameter_format_codes" / Array(this._parameter_format_code_count,
|
"parameter_format_codes" / Array(this._parameter_format_code_count, Int16ub),
|
||||||
Int16ub),
|
|
||||||
"_parameter_value_count" / Int16ub,
|
"_parameter_value_count" / Int16ub,
|
||||||
"parameter_values" / Array(
|
"parameter_values"
|
||||||
|
/ Array(
|
||||||
this._parameter_value_count,
|
this._parameter_value_count,
|
||||||
Struct(
|
Struct("length" / Int32ub, "value" / Bytes(this.length)),
|
||||||
"length" / Int32ub,
|
|
||||||
"value" / Bytes(this.length)
|
|
||||||
)
|
|
||||||
),
|
),
|
||||||
"result_column_format_count" / Int16ub,
|
"result_column_format_count" / Int16ub,
|
||||||
"result_column_format_codes" / Array(this.result_column_format_count,
|
"result_column_format_codes" / Array(this.result_column_format_count, Int16ub),
|
||||||
Int16ub)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class BindComplete(BackendMessage):
|
class BindComplete(BackendMessage):
|
||||||
key = '2'
|
key = "2"
|
||||||
struct = Struct()
|
struct = Struct()
|
||||||
|
|
||||||
|
|
||||||
class NoData(BackendMessage):
|
class NoData(BackendMessage):
|
||||||
key = 'n'
|
key = "n"
|
||||||
struct = Struct()
|
struct = Struct()
|
||||||
|
|
||||||
|
|
||||||
class Describe(FrontendMessage):
|
class Describe(FrontendMessage):
|
||||||
key = 'D'
|
key = "D"
|
||||||
struct = Struct(
|
struct = Struct(
|
||||||
"type" / Enum(Byte,
|
"type" / Enum(Byte, prepared_statement=ord("S"), portal=ord("P")),
|
||||||
prepared_statement=ord('S'),
|
"name" / CString("ASCII"),
|
||||||
portal=ord('P')
|
|
||||||
),
|
|
||||||
"name" / CString("ASCII")
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def print(message):
|
def print(message):
|
||||||
return "Describe({}={})".format(
|
return "Describe({}={})".format(message.type, message.name or "<unnamed>")
|
||||||
message.type,
|
|
||||||
message.name or "<unnamed>"
|
|
||||||
)
|
|
||||||
|
|
||||||
class Execute(FrontendMessage):
|
class Execute(FrontendMessage):
|
||||||
key = 'E'
|
key = "E"
|
||||||
struct = Struct(
|
struct = Struct("name" / CString("ASCII"), "max_rows_to_return" / Int32ub)
|
||||||
"name" / CString("ASCII"),
|
|
||||||
"max_rows_to_return" / Int32ub
|
|
||||||
)
|
|
||||||
|
|
||||||
def print(message):
|
def print(message):
|
||||||
return "Execute({}, max_rows_to_return={})".format(
|
return "Execute({}, max_rows_to_return={})".format(
|
||||||
message.name or "<unnamed>",
|
message.name or "<unnamed>", message.max_rows_to_return
|
||||||
message.max_rows_to_return
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class Sync(FrontendMessage):
|
class Sync(FrontendMessage):
|
||||||
key = 'S'
|
key = "S"
|
||||||
struct = Struct()
|
struct = Struct()
|
||||||
|
|
||||||
|
|
||||||
frontend_switch = Switch(
|
frontend_switch = Switch(
|
||||||
this.type,
|
this.type,
|
||||||
{**FrontendMessage.name_to_struct(), **SharedMessage.name_to_struct()},
|
{**FrontendMessage.name_to_struct(), **SharedMessage.name_to_struct()},
|
||||||
default=Bytes(this.length - 4)
|
default=Bytes(this.length - 4),
|
||||||
)
|
)
|
||||||
|
|
||||||
backend_switch = Switch(
|
backend_switch = Switch(
|
||||||
this.type,
|
this.type,
|
||||||
{**BackendMessage.name_to_struct(), **SharedMessage.name_to_struct()},
|
{**BackendMessage.name_to_struct(), **SharedMessage.name_to_struct()},
|
||||||
default=Bytes(this.length - 4)
|
default=Bytes(this.length - 4),
|
||||||
)
|
)
|
||||||
|
|
||||||
frontend_msgtypes = Enum(Byte, **{
|
frontend_msgtypes = Enum(
|
||||||
**FrontendMessage.name_to_key(),
|
Byte, **{**FrontendMessage.name_to_key(), **SharedMessage.name_to_key()}
|
||||||
**SharedMessage.name_to_key()
|
)
|
||||||
})
|
|
||||||
|
|
||||||
backend_msgtypes = Enum(Byte, **{
|
backend_msgtypes = Enum(
|
||||||
**BackendMessage.name_to_key(),
|
Byte, **{**BackendMessage.name_to_key(), **SharedMessage.name_to_key()}
|
||||||
**SharedMessage.name_to_key()
|
)
|
||||||
})
|
|
||||||
|
|
||||||
# It might seem a little circuitous to say a frontend message is a kind of frontend
|
# It might seem a little circuitous to say a frontend message is a kind of frontend
|
||||||
# message but this lets us easily customize how they're printed
|
# message but this lets us easily customize how they're printed
|
||||||
|
|
||||||
|
|
||||||
class Frontend(FrontendMessage):
|
class Frontend(FrontendMessage):
|
||||||
struct = Struct(
|
struct = Struct(
|
||||||
"type" / frontend_msgtypes,
|
"type" / frontend_msgtypes,
|
||||||
|
@ -447,9 +488,7 @@ class Frontend(FrontendMessage):
|
||||||
|
|
||||||
def print(message):
|
def print(message):
|
||||||
if isinstance(message.body, bytes):
|
if isinstance(message.body, bytes):
|
||||||
return "Frontend(type={},body={})".format(
|
return "Frontend(type={},body={})".format(chr(message.type), message.body)
|
||||||
chr(message.type), message.body
|
|
||||||
)
|
|
||||||
return FrontendMessage.print_message(message.body)
|
return FrontendMessage.print_message(message.body)
|
||||||
|
|
||||||
def typeof(message):
|
def typeof(message):
|
||||||
|
@ -457,6 +496,7 @@ class Frontend(FrontendMessage):
|
||||||
return "Unknown"
|
return "Unknown"
|
||||||
return message.body._type
|
return message.body._type
|
||||||
|
|
||||||
|
|
||||||
class Backend(BackendMessage):
|
class Backend(BackendMessage):
|
||||||
struct = Struct(
|
struct = Struct(
|
||||||
"type" / backend_msgtypes,
|
"type" / backend_msgtypes,
|
||||||
|
@ -468,9 +508,7 @@ class Backend(BackendMessage):
|
||||||
|
|
||||||
def print(message):
|
def print(message):
|
||||||
if isinstance(message.body, bytes):
|
if isinstance(message.body, bytes):
|
||||||
return "Backend(type={},body={})".format(
|
return "Backend(type={},body={})".format(chr(message.type), message.body)
|
||||||
chr(message.type), message.body
|
|
||||||
)
|
|
||||||
return BackendMessage.print_message(message.body)
|
return BackendMessage.print_message(message.body)
|
||||||
|
|
||||||
def typeof(message):
|
def typeof(message):
|
||||||
|
@ -478,10 +516,12 @@ class Backend(BackendMessage):
|
||||||
return "Unknown"
|
return "Unknown"
|
||||||
return message.body._type
|
return message.body._type
|
||||||
|
|
||||||
|
|
||||||
# GreedyRange keeps reading messages until we hit EOF
|
# GreedyRange keeps reading messages until we hit EOF
|
||||||
frontend_messages = GreedyRange(Frontend.struct)
|
frontend_messages = GreedyRange(Frontend.struct)
|
||||||
backend_messages = GreedyRange(Backend.struct)
|
backend_messages = GreedyRange(Backend.struct)
|
||||||
|
|
||||||
|
|
||||||
def parse(message, from_frontend=True):
|
def parse(message, from_frontend=True):
|
||||||
if from_frontend:
|
if from_frontend:
|
||||||
message = frontend_messages.parse(message)
|
message = frontend_messages.parse(message)
|
||||||
|
@ -491,24 +531,27 @@ def parse(message, from_frontend=True):
|
||||||
|
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
|
||||||
def print(message):
|
def print(message):
|
||||||
if message.from_frontend:
|
if message.from_frontend:
|
||||||
return FrontendMessage.print_message(message)
|
return FrontendMessage.print_message(message)
|
||||||
return BackendMessage.print_message(message)
|
return BackendMessage.print_message(message)
|
||||||
|
|
||||||
|
|
||||||
def message_type(message, from_frontend):
|
def message_type(message, from_frontend):
|
||||||
if from_frontend:
|
if from_frontend:
|
||||||
return FrontendMessage.find_typeof(message)
|
return FrontendMessage.find_typeof(message)
|
||||||
return BackendMessage.find_typeof(message)
|
return BackendMessage.find_typeof(message)
|
||||||
|
|
||||||
|
|
||||||
def message_matches(message, filters, from_frontend):
|
def message_matches(message, filters, from_frontend):
|
||||||
'''
|
"""
|
||||||
Message is something like Backend(Query)) and fiters is something like query="COPY".
|
Message is something like Backend(Query)) and fiters is something like query="COPY".
|
||||||
|
|
||||||
For now we only support strings, and treat them like a regex, which is matched against
|
For now we only support strings, and treat them like a regex, which is matched against
|
||||||
the content of the wrapped message
|
the content of the wrapped message
|
||||||
'''
|
"""
|
||||||
if message._type != 'Backend' and message._type != 'Frontend':
|
if message._type != "Backend" and message._type != "Frontend":
|
||||||
raise ValueError("can't handle {}".format(message._type))
|
raise ValueError("can't handle {}".format(message._type))
|
||||||
|
|
||||||
wrapped = message.body
|
wrapped = message.body
|
||||||
|
|
Loading…
Reference in New Issue