mirror of https://github.com/citusdata/citus.git

Format python files with black

commit 530b24a887 (parent 42970665fc)
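Note: everything below is mechanical reformatting. As a rough sketch of how such a commit is produced and verified (the exact path is an assumption, not recorded in the commit):

    black src/test/regress          # rewrite the Python files in place (path assumed)
    black --check src/test/regress  # exits non-zero if anything is still unformatted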
@@ -10,64 +10,67 @@ import re

class FileScanner:
    """
    FileScanner is an iterator over the lines of a file.
    It can apply a rewrite rule which can be used to skip lines.
    """
    def __init__(self, file, rewrite = lambda x:x):
        self.file = file
        self.line = 1
        self.rewrite = rewrite
    """
    FileScanner is an iterator over the lines of a file.
    It can apply a rewrite rule which can be used to skip lines.
    """

    def __next__(self):
        while True:
            nextline = self.rewrite(next(self.file))
            if nextline is not None:
                self.line += 1
                return nextline
    def __init__(self, file, rewrite=lambda x: x):
        self.file = file
        self.line = 1
        self.rewrite = rewrite

    def __next__(self):
        while True:
            nextline = self.rewrite(next(self.file))
            if nextline is not None:
                self.line += 1
                return nextline


def main():
    # we only test //d rules, as we need to ignore those lines
    regexregex = re.compile(r"^/(?P<rule>.*)/d$")
    regexpipeline = []
    for line in open(argv[1]):
        line = line.strip()
        if not line or line.startswith('#') or not line.endswith('d'):
            continue
        rule = regexregex.match(line)
        if not rule:
            raise 'Failed to parse regex rule: %s' % line
        regexpipeline.append(re.compile(rule.group('rule')))
    # we only test //d rules, as we need to ignore those lines
    regexregex = re.compile(r"^/(?P<rule>.*)/d$")
    regexpipeline = []
    for line in open(argv[1]):
        line = line.strip()
        if not line or line.startswith("#") or not line.endswith("d"):
            continue
        rule = regexregex.match(line)
        if not rule:
            raise "Failed to parse regex rule: %s" % line
        regexpipeline.append(re.compile(rule.group("rule")))

    def sed(line):
        if any(regex.search(line) for regex in regexpipeline):
            return None
        return line
    def sed(line):
        if any(regex.search(line) for regex in regexpipeline):
            return None
        return line

    for line in stdin:
        if line.startswith('+++ '):
            tab = line.rindex('\t')
            fname = line[4:tab]
            file2 = FileScanner(open(fname.replace('.modified', ''), encoding='utf8'), sed)
            stdout.write(line)
        elif line.startswith('@@ '):
            idx_start = line.index('+') + 1
            idx_end = idx_start + 1
            while line[idx_end].isdigit():
                idx_end += 1
            linenum = int(line[idx_start:idx_end])
            while file2.line < linenum:
                next(file2)
            stdout.write(line)
        elif line.startswith(' '):
            stdout.write(' ')
            stdout.write(next(file2))
        elif line.startswith('+'):
            stdout.write('+')
            stdout.write(next(file2))
        else:
            stdout.write(line)
    for line in stdin:
        if line.startswith("+++ "):
            tab = line.rindex("\t")
            fname = line[4:tab]
            file2 = FileScanner(
                open(fname.replace(".modified", ""), encoding="utf8"), sed
            )
            stdout.write(line)
        elif line.startswith("@@ "):
            idx_start = line.index("+") + 1
            idx_end = idx_start + 1
            while line[idx_end].isdigit():
                idx_end += 1
            linenum = int(line[idx_start:idx_end])
            while file2.line < linenum:
                next(file2)
            stdout.write(line)
        elif line.startswith(" "):
            stdout.write(" ")
            stdout.write(next(file2))
        elif line.startswith("+"):
            stdout.write("+")
            stdout.write(next(file2))
        else:
            stdout.write(line)


main()
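Note: a hypothetical, minimal use of the FileScanner above, to make the rewrite rule concrete (the file name and rule are made up): a rewrite returning None makes the scanner silently skip that line while still consuming it from the file.

    scanner = FileScanner(open("expected.out"), lambda l: None if "NOTICE:" in l else l)
    first_kept_line = next(scanner)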
@@ -115,7 +115,6 @@ def copy_copy_modified_binary(datadir):

def copy_test_files(config):

    sql_dir_path = os.path.join(config.datadir, "sql")
    expected_dir_path = os.path.join(config.datadir, "expected")

@@ -132,7 +131,9 @@ def copy_test_files(config):
        line = line[colon_index + 1 :].strip()
        test_names = line.split(" ")
        copy_test_files_with_names(test_names, sql_dir_path, expected_dir_path, config)
        copy_test_files_with_names(
            test_names, sql_dir_path, expected_dir_path, config
        )


def copy_test_files_with_names(test_names, sql_dir_path, expected_dir_path, config):

@@ -140,10 +141,10 @@ def copy_test_files_with_names(test_names, sql_dir_path, expected_dir_path, conf
        # make empty files for the skipped tests
        if test_name in config.skip_tests:
            expected_sql_file = os.path.join(sql_dir_path, test_name + ".sql")
            open(expected_sql_file, 'x').close()
            open(expected_sql_file, "x").close()

            expected_out_file = os.path.join(expected_dir_path, test_name + ".out")
            open(expected_out_file, 'x').close()
            open(expected_out_file, "x").close()

            continue
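Note: the placeholders for skipped tests are created with mode "x", which raises FileExistsError rather than truncating when the file already exists. A tiny illustration (path hypothetical):

    open("/tmp/example_skip.out", "x").close()  # creates the empty placeholder
    open("/tmp/example_skip.out", "x").close()  # second call raises FileExistsError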
@@ -27,13 +27,11 @@ def initialize_temp_dir_if_not_exists(temp_dir):

def parallel_run(function, items, *args, **kwargs):
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(function, item, *args, **kwargs)
            for item in items
        ]
        futures = [executor.submit(function, item, *args, **kwargs) for item in items]
        for future in futures:
            future.result()


def initialize_db_for_cluster(pg_path, rel_data_path, settings, node_names):
    subprocess.run(["mkdir", rel_data_path], check=True)

@@ -52,7 +50,7 @@ def initialize_db_for_cluster(pg_path, rel_data_path, settings, node_names):
        "--encoding",
        "UTF8",
        "--locale",
        "POSIX"
        "POSIX",
    ]
    subprocess.run(command, check=True)
    add_settings(abs_data_path, settings)

@@ -76,7 +74,9 @@ def create_role(pg_path, node_ports, user_name):
            user_name, user_name
        )
        utils.psql(pg_path, port, command)
        command = "SET citus.enable_ddl_propagation TO OFF; GRANT CREATE ON DATABASE postgres to {}".format(user_name)
        command = "SET citus.enable_ddl_propagation TO OFF; GRANT CREATE ON DATABASE postgres to {}".format(
            user_name
        )
        utils.psql(pg_path, port, command)

    parallel_run(create, node_ports)

@@ -89,7 +89,9 @@ def coordinator_should_haveshards(pg_path, port):
    utils.psql(pg_path, port, command)


def start_databases(pg_path, rel_data_path, node_name_to_ports, logfile_prefix, env_variables):
def start_databases(
    pg_path, rel_data_path, node_name_to_ports, logfile_prefix, env_variables
):
    def start(node_name):
        abs_data_path = os.path.abspath(os.path.join(rel_data_path, node_name))
        node_port = node_name_to_ports[node_name]

@@ -248,7 +250,12 @@ def logfile_name(logfile_prefix, node_name):

def stop_databases(
    pg_path, rel_data_path, node_name_to_ports, logfile_prefix, no_output=False, parallel=True
    pg_path,
    rel_data_path,
    node_name_to_ports,
    logfile_prefix,
    no_output=False,
    parallel=True,
):
    def stop(node_name):
        abs_data_path = os.path.abspath(os.path.join(rel_data_path, node_name))

@@ -287,7 +294,9 @@ def initialize_citus_cluster(bindir, datadir, settings, config):
    initialize_db_for_cluster(
        bindir, datadir, settings, config.node_name_to_ports.keys()
    )
    start_databases(bindir, datadir, config.node_name_to_ports, config.name, config.env_variables)
    start_databases(
        bindir, datadir, config.node_name_to_ports, config.name, config.env_variables
    )
    create_citus_extension(bindir, config.node_name_to_ports.values())
    add_workers(bindir, config.worker_ports, config.coordinator_port())
    if not config.is_mx:

@@ -296,6 +305,7 @@ def initialize_citus_cluster(bindir, datadir, settings, config):
    add_coordinator_to_metadata(bindir, config.coordinator_port())
    config.setup_steps()


def eprint(*args, **kwargs):
    """eprint prints to stderr"""
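Note: the condensed comprehension in parallel_run behaves exactly like the multi-line original: submit one future per item, then block on each result so worker exceptions propagate. A hypothetical call (stop_db is an assumed function, not part of this file):

    # runs stop_db("worker_1") and stop_db("worker_2") concurrently and
    # re-raises the first exception any of them hit
    parallel_run(stop_db, ["worker_1", "worker_2"])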
@@ -57,8 +57,9 @@ port_lock = threading.Lock()

def should_include_config(class_name):

    if inspect.isclass(class_name) and issubclass(class_name, CitusDefaultClusterConfig):
    if inspect.isclass(class_name) and issubclass(
        class_name, CitusDefaultClusterConfig
    ):
        return True
    return False

@@ -167,7 +168,9 @@ class CitusDefaultClusterConfig(CitusBaseClusterConfig):
        self.add_coordinator_to_metadata = True
        self.skip_tests = [
            # Alter Table statement cannot be run from an arbitrary node so this test will fail
            "arbitrary_configs_alter_table_add_constraint_without_name_create", "arbitrary_configs_alter_table_add_constraint_without_name"]
            "arbitrary_configs_alter_table_add_constraint_without_name_create",
            "arbitrary_configs_alter_table_add_constraint_without_name",
        ]


class CitusUpgradeConfig(CitusBaseClusterConfig):

@@ -190,9 +193,13 @@ class PostgresConfig(CitusDefaultClusterConfig):
        self.new_settings = {
            "citus.use_citus_managed_tables": False,
        }
        self.skip_tests = ["nested_execution",
        self.skip_tests = [
            "nested_execution",
            # Alter Table statement cannot be run from an arbitrary node so this test will fail
            "arbitrary_configs_alter_table_add_constraint_without_name_create", "arbitrary_configs_alter_table_add_constraint_without_name"]
            "arbitrary_configs_alter_table_add_constraint_without_name_create",
            "arbitrary_configs_alter_table_add_constraint_without_name",
        ]


class CitusSingleNodeClusterConfig(CitusDefaultClusterConfig):
    def __init__(self, arguments):

@@ -229,7 +236,7 @@ class CitusSmallSharedPoolSizeConfig(CitusDefaultClusterConfig):
    def __init__(self, arguments):
        super().__init__(arguments)
        self.new_settings = {
            "citus.local_shared_pool_size": 5,
            "citus.local_shared_pool_size": 5,
            "citus.max_shared_pool_size": 5,
        }

@@ -275,7 +282,7 @@ class CitusUnusualExecutorConfig(CitusDefaultClusterConfig):

        # this setting does not necessarily need to be here
        # could go any other test
        self.env_variables = {'PGAPPNAME' : 'test_app'}
        self.env_variables = {"PGAPPNAME": "test_app"}


class CitusSmallCopyBuffersConfig(CitusDefaultClusterConfig):

@@ -307,9 +314,13 @@ class CitusUnusualQuerySettingsConfig(CitusDefaultClusterConfig):
            # requires the table with the fk to be converted to a citus_local_table.
            # As of c11, there is no way to do that through remote execution so this test
            # will fail
            "arbitrary_configs_truncate_cascade_create", "arbitrary_configs_truncate_cascade",
            # Alter Table statement cannot be run from an arbitrary node so this test will fail
            "arbitrary_configs_alter_table_add_constraint_without_name_create", "arbitrary_configs_alter_table_add_constraint_without_name"]
            "arbitrary_configs_truncate_cascade_create",
            "arbitrary_configs_truncate_cascade",
            # Alter Table statement cannot be run from an arbitrary node so this test will fail
            "arbitrary_configs_alter_table_add_constraint_without_name_create",
            "arbitrary_configs_alter_table_add_constraint_without_name",
        ]


class CitusSingleNodeSingleShardClusterConfig(CitusDefaultClusterConfig):
    def __init__(self, arguments):

@@ -328,15 +339,20 @@ class CitusShardReplicationFactorClusterConfig(CitusDefaultClusterConfig):
        self.skip_tests = [
            # citus does not support foreign keys in distributed tables
            # when citus.shard_replication_factor >= 2
            "arbitrary_configs_truncate_partition_create", "arbitrary_configs_truncate_partition",
            "arbitrary_configs_truncate_partition_create",
            "arbitrary_configs_truncate_partition",
            # citus does not support modifying a partition when
            # citus.shard_replication_factor >= 2
            "arbitrary_configs_truncate_cascade_create", "arbitrary_configs_truncate_cascade",
            "arbitrary_configs_truncate_cascade_create",
            "arbitrary_configs_truncate_cascade",
            # citus does not support colocating functions with distributed tables when
            # citus.shard_replication_factor >= 2
            "function_create", "functions",
            # Alter Table statement cannot be run from an arbitrary node so this test will fail
            "arbitrary_configs_alter_table_add_constraint_without_name_create", "arbitrary_configs_alter_table_add_constraint_without_name"]
            "function_create",
            "functions",
            # Alter Table statement cannot be run from an arbitrary node so this test will fail
            "arbitrary_configs_alter_table_add_constraint_without_name_create",
            "arbitrary_configs_alter_table_add_constraint_without_name",
        ]


class CitusSingleShardClusterConfig(CitusDefaultClusterConfig):
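Note: a sketch of how a runner could consume should_include_config (the module name "config" and the use of inspect.getmembers are assumptions, not shown in this diff):

    import inspect
    # collect every arbitrary-config class defined in the module
    configs = [cls for _, cls in inspect.getmembers(config) if should_include_config(cls)]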
@@ -12,23 +12,53 @@ import common
import config

args = argparse.ArgumentParser()
args.add_argument("test_name", help="Test name (must be included in a schedule.)", nargs='?')
args.add_argument("-p", "--path", required=False, help="Relative path for test file (must have a .sql or .spec extension)", type=pathlib.Path)
args.add_argument(
    "test_name", help="Test name (must be included in a schedule.)", nargs="?"
)
args.add_argument(
    "-p",
    "--path",
    required=False,
    help="Relative path for test file (must have a .sql or .spec extension)",
    type=pathlib.Path,
)
args.add_argument("-r", "--repeat", help="Number of test to run", type=int, default=1)
args.add_argument("-b", "--use-base-schedule", required=False, help="Choose base-schedules rather than minimal-schedules", action='store_true')
args.add_argument("-w", "--use-whole-schedule-line", required=False, help="Use the whole line found in related schedule", action='store_true')
args.add_argument("--valgrind", required=False, help="Run the test with valgrind enabled", action='store_true')
args.add_argument(
    "-b",
    "--use-base-schedule",
    required=False,
    help="Choose base-schedules rather than minimal-schedules",
    action="store_true",
)
args.add_argument(
    "-w",
    "--use-whole-schedule-line",
    required=False,
    help="Use the whole line found in related schedule",
    action="store_true",
)
args.add_argument(
    "--valgrind",
    required=False,
    help="Run the test with valgrind enabled",
    action="store_true",
)

args = vars(args.parse_args())

regress_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
test_file_path = args['path']
test_file_name = args['test_name']
use_base_schedule = args['use_base_schedule']
use_whole_schedule_line = args['use_whole_schedule_line']
test_file_path = args["path"]
test_file_name = args["test_name"]
use_base_schedule = args["use_base_schedule"]
use_whole_schedule_line = args["use_whole_schedule_line"]

test_files_to_skip = ['multi_cluster_management', 'multi_extension', 'multi_test_helpers', 'multi_insert_select']
test_files_to_run_without_schedule = ['single_node_enterprise']
test_files_to_skip = [
    "multi_cluster_management",
    "multi_extension",
    "multi_test_helpers",
    "multi_insert_select",
]
test_files_to_run_without_schedule = ["single_node_enterprise"]

if not (test_file_name or test_file_path):
    print(f"FATAL: No test given.")

@@ -36,7 +66,7 @@ if not (test_file_name or test_file_path):

if test_file_path:
    test_file_path = os.path.join(os.getcwd(), args['path'])
    test_file_path = os.path.join(os.getcwd(), args["path"])

    if not os.path.isfile(test_file_path):
        print(f"ERROR: test file '{test_file_path}' does not exist")

@@ -45,7 +75,7 @@ if test_file_path:
    test_file_extension = pathlib.Path(test_file_path).suffix
    test_file_name = pathlib.Path(test_file_path).stem

    if not test_file_extension in '.spec.sql':
    if not test_file_extension in ".spec.sql":
        print(
            "ERROR: Unrecognized test extension. Valid extensions are: .sql and .spec"
        )

@@ -56,73 +86,73 @@ if test_file_name in test_files_to_skip:
    print(f"WARNING: Skipping exceptional test: '{test_file_name}'")
    sys.exit(0)

test_schedule = ''
test_schedule = ""

# find related schedule
for schedule_file_path in sorted(glob(os.path.join(regress_dir, "*_schedule"))):
    for schedule_line in open(schedule_file_path, 'r'):
        if re.search(r'\b' + test_file_name + r'\b', schedule_line):
            test_schedule = pathlib.Path(schedule_file_path).stem
            if use_whole_schedule_line:
                test_schedule_line = schedule_line
            else:
                test_schedule_line = f"test: {test_file_name}\n"
            break
    else:
        continue
    break
    for schedule_line in open(schedule_file_path, "r"):
        if re.search(r"\b" + test_file_name + r"\b", schedule_line):
            test_schedule = pathlib.Path(schedule_file_path).stem
            if use_whole_schedule_line:
                test_schedule_line = schedule_line
            else:
                test_schedule_line = f"test: {test_file_name}\n"
            break
    else:
        continue
    break

# map suitable schedule
if not test_schedule:
    print(
        f"WARNING: Could not find any schedule for '{test_file_name}'"
    )
    print(f"WARNING: Could not find any schedule for '{test_file_name}'")
    sys.exit(0)
elif "isolation" in test_schedule:
    test_schedule = 'base_isolation_schedule'
    test_schedule = "base_isolation_schedule"
elif "failure" in test_schedule:
    test_schedule = 'failure_base_schedule'
    test_schedule = "failure_base_schedule"
elif "enterprise" in test_schedule:
    test_schedule = 'enterprise_minimal_schedule'
    test_schedule = "enterprise_minimal_schedule"
elif "split" in test_schedule:
    test_schedule = 'minimal_schedule'
    test_schedule = "minimal_schedule"
elif "mx" in test_schedule:
    if use_base_schedule:
        test_schedule = 'mx_base_schedule'
        test_schedule = "mx_base_schedule"
    else:
        test_schedule = 'mx_minimal_schedule'
        test_schedule = "mx_minimal_schedule"
elif "operations" in test_schedule:
    test_schedule = 'minimal_schedule'
    test_schedule = "minimal_schedule"
elif test_schedule in config.ARBITRARY_SCHEDULE_NAMES:
    print(f"WARNING: Arbitrary config schedule ({test_schedule}) is not supported.")
    sys.exit(0)
else:
    if use_base_schedule:
        test_schedule = 'base_schedule'
        test_schedule = "base_schedule"
    else:
        test_schedule = 'minimal_schedule'
        test_schedule = "minimal_schedule"

# copy base schedule to a temp file and append test_schedule_line
# to be able to run tests in parallel (if test_schedule_line is a parallel group.)
tmp_schedule_path = os.path.join(regress_dir, f"tmp_schedule_{ random.randint(1, 10000)}")
tmp_schedule_path = os.path.join(
    regress_dir, f"tmp_schedule_{ random.randint(1, 10000)}"
)
# some tests don't need a schedule to run
# e.g tests that are in the first place in their own schedule
if test_file_name not in test_files_to_run_without_schedule:
    shutil.copy2(os.path.join(regress_dir, test_schedule), tmp_schedule_path)
with open(tmp_schedule_path, "a") as myfile:
    for i in range(args['repeat']):
        myfile.write(test_schedule_line)
    for i in range(args["repeat"]):
        myfile.write(test_schedule_line)

# find suitable make recipe
if "isolation" in test_schedule:
    make_recipe = 'check-isolation-custom-schedule'
    make_recipe = "check-isolation-custom-schedule"
elif "failure" in test_schedule:
    make_recipe = 'check-failure-custom-schedule'
    make_recipe = "check-failure-custom-schedule"
else:
    make_recipe = 'check-custom-schedule'
    make_recipe = "check-custom-schedule"

if args['valgrind']:
    make_recipe += '-vg'
if args["valgrind"]:
    make_recipe += "-vg"

# prepare command to run tests
test_command = f"make -C {regress_dir} {make_recipe} SCHEDULE='{pathlib.Path(tmp_schedule_path).stem}'"
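Note: as a concrete (hypothetical) run of the script above, invoking it with test name "multi_utilities" and --repeat 2 appends "test: multi_utilities" twice to a copy of the chosen schedule and ends up executing something like:

    make -C <regress_dir> check-custom-schedule SCHEDULE='tmp_schedule_4821'

where the schedule name's random suffix differs on every run.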
@@ -112,7 +112,11 @@ def main(config):
        config.node_name_to_ports.keys(),
    )
    common.start_databases(
        config.new_bindir, config.new_datadir, config.node_name_to_ports, config.name, {}
        config.new_bindir,
        config.new_datadir,
        config.node_name_to_ports,
        config.name,
        {},
    )
    citus_finish_pg_upgrade(config.new_bindir, config.node_name_to_ports.values())
@@ -20,15 +20,17 @@ logging.basicConfig(format="%(asctime)s %(levelname)s %(message)s", level=loggin

# I. Command Strings


class Handler:
    '''
    """
    This class hierarchy serves two purposes:
    1. Allow command strings to be evaluated. Once evaluated you'll have a Handler you can
       pass packets to
    2. Process packets as they come in and decide what to do with them.

    Subclasses which want to change how packets are handled should override _handle.
    '''
    """

    def __init__(self, root=None):
        # all packets are first sent to the root handler to be processed
        self.root = root if root else self

@@ -38,30 +40,31 @@ class Handler:
    def _accept(self, flow, message):
        result = self._handle(flow, message)

        if result == 'pass':
        if result == "pass":
            # defer to our child
            if not self.next:
                raise Exception("we don't know what to do!")

            if self.next._accept(flow, message) == 'stop':
            if self.next._accept(flow, message) == "stop":
                if self.root is not self:
                    return 'stop'
                    return "stop"
                self.next = KillHandler(self)
                flow.kill()
        else:
            return result

    def _handle(self, flow, message):
        '''
        """
        Handlers can return one of three things:
        - "done" tells the parent to stop processing. This performs the default action,
          which is to allow the packet to be sent.
        - "pass" means to delegate to self.next and do whatever it wants
        - "stop" means all processing will stop, and all connections will be killed
        '''
        """
        # subclasses must implement this
        raise NotImplementedError()


class FilterableMixin:
    def contains(self, pattern):
        self.next = Contains(self.root, pattern)

@@ -76,7 +79,7 @@ class FilterableMixin:
        return self.next

    def __getattr__(self, attr):
        '''
        """
        Methods such as .onQuery trigger when a packet with that name is intercepted

        Adds support for commands such as:

@@ -85,14 +88,17 @@ class FilterableMixin:
        Returns a function because the above command is resolved in two steps:
        conn.onQuery becomes conn.__getattr__("onQuery")
        conn.onQuery(query="COPY") becomes conn.__getattr__("onQuery")(query="COPY")
        '''
        if attr.startswith('on'):
        """
        if attr.startswith("on"):

            def doit(**kwargs):
                self.next = OnPacket(self.root, attr[2:], kwargs)
                return self.next

            return doit
        raise AttributeError


class ActionsMixin:
    def kill(self):
        self.next = KillHandler(self.root)

@@ -118,31 +124,39 @@ class ActionsMixin:
        self.next = ConnectDelayHandler(self.root, timeMs)
        return self.next


class AcceptHandler(Handler):
    def __init__(self, root):
        super().__init__(root)

    def _handle(self, flow, message):
        return 'done'
        return "done"


class KillHandler(Handler):
    def __init__(self, root):
        super().__init__(root)

    def _handle(self, flow, message):
        flow.kill()
        return 'done'
        return "done"


class KillAllHandler(Handler):
    def __init__(self, root):
        super().__init__(root)

    def _handle(self, flow, message):
        return 'stop'
        return "stop"


class ResetHandler(Handler):
    # try to force a RST to be sent, something went very wrong!
    def __init__(self, root):
        super().__init__(root)

    def _handle(self, flow, message):
        flow.kill() # tell mitmproxy this connection should be closed
        flow.kill()  # tell mitmproxy this connection should be closed

        # this is a mitmproxy.connections.ClientConnection(mitmproxy.tcp.BaseHandler)
        client_conn = flow.client_conn

@@ -152,8 +166,9 @@ class ResetHandler(Handler):
        # cause linux to send a RST
        LINGER_ON, LINGER_TIMEOUT = 1, 0
        conn.setsockopt(
            socket.SOL_SOCKET, socket.SO_LINGER,
            struct.pack('ii', LINGER_ON, LINGER_TIMEOUT)
            socket.SOL_SOCKET,
            socket.SO_LINGER,
            struct.pack("ii", LINGER_ON, LINGER_TIMEOUT),
        )
        conn.close()

@@ -161,28 +176,35 @@ class ResetHandler(Handler):
        # tries to call conn.shutdown(), but there's nothing else to clean up so that's
        # maybe okay

        return 'done'
        return "done"


class CancelHandler(Handler):
    'Send a SIGINT to the process'
    "Send a SIGINT to the process"

    def __init__(self, root, pid):
        super().__init__(root)
        self.pid = pid

    def _handle(self, flow, message):
        os.kill(self.pid, signal.SIGINT)
        # give the signal a chance to be received before we let the packet through
        time.sleep(0.1)
        return 'done'
        return "done"


class ConnectDelayHandler(Handler):
    'Delay the initial packet by sleeping before deciding what to do'
    "Delay the initial packet by sleeping before deciding what to do"

    def __init__(self, root, timeMs):
        super().__init__(root)
        self.timeMs = timeMs

    def _handle(self, flow, message):
        if message.is_initial:
            time.sleep(self.timeMs/1000.0)
        return 'done'
            time.sleep(self.timeMs / 1000.0)
        return "done"


class Contains(Handler, ActionsMixin, FilterableMixin):
    def __init__(self, root, pattern):

@@ -191,8 +213,9 @@ class Contains(Handler, ActionsMixin, FilterableMixin):

    def _handle(self, flow, message):
        if self.pattern in message.content:
            return 'pass'
        return 'done'
            return "pass"
        return "done"


class Matches(Handler, ActionsMixin, FilterableMixin):
    def __init__(self, root, pattern):

@@ -201,47 +224,56 @@ class Matches(Handler, ActionsMixin, FilterableMixin):

    def _handle(self, flow, message):
        if self.pattern.search(message.content):
            return 'pass'
        return 'done'
            return "pass"
        return "done"


class After(Handler, ActionsMixin, FilterableMixin):
    "Don't pass execution to our child until we've handled 'times' messages"

    def __init__(self, root, times):
        super().__init__(root)
        self.target = times

    def _handle(self, flow, message):
        if not hasattr(flow, '_after_count'):
        if not hasattr(flow, "_after_count"):
            flow._after_count = 0

        if flow._after_count >= self.target:
            return 'pass'
            return "pass"

        flow._after_count += 1
        return 'done'
        return "done"


class OnPacket(Handler, ActionsMixin, FilterableMixin):
    '''Triggers when a packet of the specified kind comes around'''
    """Triggers when a packet of the specified kind comes around"""

    def __init__(self, root, packet_kind, kwargs):
        super().__init__(root)
        self.packet_kind = packet_kind
        self.filters = kwargs

    def _handle(self, flow, message):
        if not message.parsed:
            # if this is the first message in the connection we just skip it
            return 'done'
            return "done"
        for msg in message.parsed:
            typ = structs.message_type(msg, from_frontend=message.from_client)
            if typ == self.packet_kind:
                matches = structs.message_matches(msg, self.filters, message.from_client)
                matches = structs.message_matches(
                    msg, self.filters, message.from_client
                )
                if matches:
                    return 'pass'
        return 'done'
                    return "pass"
        return "done"


class RootHandler(Handler, ActionsMixin, FilterableMixin):
    def _handle(self, flow, message):
        # do whatever the next Handler tells us to do
        return 'pass'
        return "pass"


class RecorderCommand:
    def __init__(self):

@@ -250,23 +282,26 @@ class RecorderCommand:

    def dump(self):
        # When the user calls dump() we return everything we've captured
        self.command = 'dump'
        self.command = "dump"
        return self

    def reset(self):
        # If the user calls reset() we dump all captured packets without returning them
        self.command = 'reset'
        self.command = "reset"
        return self


# II. Utilities for interfacing with mitmproxy


def build_handler(spec):
    'Turns a command string into a RootHandler ready to accept packets'
    "Turns a command string into a RootHandler ready to accept packets"
    root = RootHandler()
    recorder = RecorderCommand()
    handler = eval(spec, {'__builtins__': {}}, {'conn': root, 'recorder': recorder})
    handler = eval(spec, {"__builtins__": {}}, {"conn": root, "recorder": recorder})
    return handler.root


# a bunch of globals

handler = None # the current handler used to process packets

@@ -274,25 +309,25 @@ command_thread = None # sits on the fifo and waits for new commands to come in
captured_messages = queue.Queue() # where we store messages used for recorder.dump()
connection_count = count() # so we can give connections ids in recorder.dump()

def listen_for_commands(fifoname):

def listen_for_commands(fifoname):
    def emit_row(conn, from_client, message):
        # we're using the COPY text format. It requires us to escape backslashes
        cleaned = message.replace('\\', '\\\\')
        source = 'coordinator' if from_client else 'worker'
        return '{}\t{}\t{}'.format(conn, source, cleaned)
        cleaned = message.replace("\\", "\\\\")
        source = "coordinator" if from_client else "worker"
        return "{}\t{}\t{}".format(conn, source, cleaned)

    def emit_message(message):
        if message.is_initial:
            return emit_row(
                message.connection_id, message.from_client, '[initial message]'
                message.connection_id, message.from_client, "[initial message]"
            )

        pretty = structs.print(message.parsed)
        return emit_row(message.connection_id, message.from_client, pretty)

    def all_items(queue_):
        'Pulls everything out of the queue without blocking'
        "Pulls everything out of the queue without blocking"
        try:
            while True:
                yield queue_.get(block=False)

@@ -300,23 +335,27 @@ def listen_for_commands(fifoname):
            pass

    def drop_terminate_messages(messages):
        '''
        """
        Terminate() messages happen eventually, Citus doesn't feel any need to send them
        immediately, so tests which embed them aren't reproducible and fail to timing
        issues. Here we simply drop those messages.
        '''
        """

        def isTerminate(msg, from_client):
            kind = structs.message_type(msg, from_client)
            return kind == 'Terminate'
            return kind == "Terminate"

        for message in messages:
            if not message.parsed:
                yield message
                continue
            message.parsed = ListContainer([
                msg for msg in message.parsed
                if not isTerminate(msg, message.from_client)
            ])
            message.parsed = ListContainer(
                [
                    msg
                    for msg in message.parsed
                    if not isTerminate(msg, message.from_client)
                ]
            )
            message.parsed.from_frontend = message.from_client
            if len(message.parsed) == 0:
                continue

@@ -324,35 +363,35 @@ def listen_for_commands(fifoname):

    def handle_recorder(recorder):
        global connection_count
        result = ''
        result = ""

        if recorder.command == 'reset':
            result = ''
        if recorder.command == "reset":
            result = ""
            connection_count = count()
        elif recorder.command != 'dump':
        elif recorder.command != "dump":
            # this should never happen
            raise Exception('Unrecognized command: {}'.format(recorder.command))
            raise Exception("Unrecognized command: {}".format(recorder.command))

        results = []
        messages = all_items(captured_messages)
        messages = drop_terminate_messages(messages)
        for message in messages:
            if recorder.command == 'reset':
            if recorder.command == "reset":
                continue
            results.append(emit_message(message))
        result = '\n'.join(results)
        result = "\n".join(results)

        logging.debug('about to write to fifo')
        with open(fifoname, mode='w') as fifo:
            logging.debug('successfully opened the fifo for writing')
            fifo.write('{}'.format(result))
        logging.debug("about to write to fifo")
        with open(fifoname, mode="w") as fifo:
            logging.debug("successfully opened the fifo for writing")
            fifo.write("{}".format(result))

    while True:
        logging.debug('about to read from fifo')
        with open(fifoname, mode='r') as fifo:
            logging.debug('successfully opened the fifo for reading')
        logging.debug("about to read from fifo")
        with open(fifoname, mode="r") as fifo:
            logging.debug("successfully opened the fifo for reading")
            slug = fifo.read()
            logging.info('received new command: %s', slug.rstrip())
            logging.info("received new command: %s", slug.rstrip())

        try:
            handler = build_handler(slug)

@@ -371,13 +410,14 @@ def listen_for_commands(fifoname):
        except Exception as e:
            result = str(e)
        else:
            result = ''
            result = ""

        logging.debug("about to write to fifo")
        with open(fifoname, mode="w") as fifo:
            logging.debug("successfully opened the fifo for writing")
            fifo.write("{}\n".format(result))
        logging.info("responded to command: %s", result.split("\n")[0])

        logging.debug('about to write to fifo')
        with open(fifoname, mode='w') as fifo:
            logging.debug('successfully opened the fifo for writing')
            fifo.write('{}\n'.format(result))
        logging.info('responded to command: %s', result.split("\n")[0])


def create_thread(fifoname):
    global command_thread

@@ -388,42 +428,46 @@ def create_thread(fifoname):
        return

    if command_thread:
        print('cannot change the fifo path once mitmproxy has started');
        print("cannot change the fifo path once mitmproxy has started")
        return

    command_thread = threading.Thread(target=listen_for_commands, args=(fifoname,), daemon=True)
    command_thread = threading.Thread(
        target=listen_for_commands, args=(fifoname,), daemon=True
    )
    command_thread.start()


# III. mitmproxy callbacks


def load(loader):
    loader.add_option('slug', str, 'conn.allow()', "A script to run")
    loader.add_option('fifo', str, '', "Which fifo to listen on for commands")
    loader.add_option("slug", str, "conn.allow()", "A script to run")
    loader.add_option("fifo", str, "", "Which fifo to listen on for commands")


def configure(updated):
    global handler

    if 'slug' in updated:
    if "slug" in updated:
        text = ctx.options.slug
        handler = build_handler(text)

    if 'fifo' in updated:
    if "fifo" in updated:
        fifoname = ctx.options.fifo
        create_thread(fifoname)


def tcp_message(flow: tcp.TCPFlow):
    '''
    """
    This callback is hit every time mitmproxy receives a packet. It's the main entrypoint
    into this script.
    '''
    """
    global connection_count

    tcp_msg = flow.messages[-1]

    # Keep track of all the different connections, assign a unique id to each
    if not hasattr(flow, 'connection_id'):
    if not hasattr(flow, "connection_id"):
        flow.connection_id = next(connection_count)
    tcp_msg.connection_id = flow.connection_id


@@ -434,7 +478,9 @@ def tcp_message(flow: tcp.TCPFlow):
        # skip parsing initial messages for now, they're not important
        tcp_msg.parsed = None
    else:
        tcp_msg.parsed = structs.parse(tcp_msg.content, from_frontend=tcp_msg.from_client)
        tcp_msg.parsed = structs.parse(
            tcp_msg.content, from_frontend=tcp_msg.from_client
        )

    # record the message, for debugging purposes
    captured_messages.put(tcp_msg)
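Note: to make the fluent API above concrete, a command string is evaluated with no builtins and only conn/recorder in scope, so a spec chains the mixin methods into a handler list. A hypothetical spec:

    # builds RootHandler -> OnPacket("Query", {"query": "COPY"}) -> KillHandler
    handler = build_handler('conn.onQuery(query="COPY").kill()')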
@@ -1,8 +1,25 @@
from construct import (
    Struct,
    Int8ub, Int16ub, Int32ub, Int16sb, Int32sb,
    Bytes, CString, Computed, Switch, Seek, this, Pointer,
    GreedyRange, Enum, Byte, Probe, FixedSized, RestreamData, GreedyBytes, Array
    Int8ub,
    Int16ub,
    Int32ub,
    Int16sb,
    Int32sb,
    Bytes,
    CString,
    Computed,
    Switch,
    Seek,
    this,
    Pointer,
    GreedyRange,
    Enum,
    Byte,
    Probe,
    FixedSized,
    RestreamData,
    GreedyBytes,
    Array,
)
import construct.lib as cl

@@ -11,23 +28,30 @@ import re
# For all possible message formats see:
# https://www.postgresql.org/docs/current/protocol-message-formats.html


class MessageMeta(type):
    def __init__(cls, name, bases, namespace):
        '''
        """
        __init__ is called every time a subclass of MessageMeta is declared
        '''
        """
        if not hasattr(cls, "_msgtypes"):
            raise Exception("classes which use MessageMeta must have a '_msgtypes' field")
            raise Exception(
                "classes which use MessageMeta must have a '_msgtypes' field"
            )

        if not hasattr(cls, "_classes"):
            raise Exception("classes which use MessageMeta must have a '_classes' field")
            raise Exception(
                "classes which use MessageMeta must have a '_classes' field"
            )

        if not hasattr(cls, "struct"):
            # This is one of the direct subclasses
            return

        if cls.__name__ in cls._classes:
            raise Exception("You've already made a class called {}".format( cls.__name__))
            raise Exception(
                "You've already made a class called {}".format(cls.__name__)
            )
        cls._classes[cls.__name__] = cls

        # add a _type field to the struct so we can identify it while printing structs

@@ -39,34 +63,41 @@ class MessageMeta(type):
        # register the type, so we can tell the parser about it
        key = cls.key
        if key in cls._msgtypes:
            raise Exception('key {} is already assigned to {}'.format(
                key, cls._msgtypes[key].__name__)
            raise Exception(
                "key {} is already assigned to {}".format(
                    key, cls._msgtypes[key].__name__
                )
            )
        cls._msgtypes[key] = cls


class Message:
    'Do not subclass this object directly. Instead, subclass of one of the below types'
    "Do not subclass this object directly. Instead, subclass of one of the below types"

    def print(message):
        'Define this on subclasses you want to change the representation of'
        "Define this on subclasses you want to change the representation of"
        raise NotImplementedError

    def typeof(message):
        'Define this on subclasses you want to change the expressed type of'
        "Define this on subclasses you want to change the expressed type of"
        return message._type

    @classmethod
    def _default_print(cls, name, msg):
        recur = cls.print_message
        return "{}({})".format(name, ",".join(
            "{}={}".format(key, recur(value)) for key, value in msg.items()
            if not key.startswith('_')
        ))
        return "{}({})".format(
            name,
            ",".join(
                "{}={}".format(key, recur(value))
                for key, value in msg.items()
                if not key.startswith("_")
            ),
        )

    @classmethod
    def find_typeof(cls, msg):
        if not hasattr(cls, "_msgtypes"):
            raise Exception('Do not call this method on Message, call it on a subclass')
            raise Exception("Do not call this method on Message, call it on a subclass")
        if isinstance(msg, cl.ListContainer):
            raise ValueError("do not call this on a list of messages")
        if not isinstance(msg, cl.Container):

@@ -80,7 +111,7 @@ class Message:
    @classmethod
    def print_message(cls, msg):
        if not hasattr(cls, "_msgtypes"):
            raise Exception('Do not call this method on Message, call it on a subclass')
            raise Exception("Do not call this method on Message, call it on a subclass")

        if isinstance(msg, cl.ListContainer):
            return repr([cls.print_message(message) for message in msg])

@@ -101,38 +132,34 @@ class Message:

    @classmethod
    def name_to_struct(cls):
        return {
            _class.__name__: _class.struct
            for _class in cls._msgtypes.values()
        }
        return {_class.__name__: _class.struct for _class in cls._msgtypes.values()}

    @classmethod
    def name_to_key(cls):
        return {
            _class.__name__ : ord(key)
            for key, _class in cls._msgtypes.items()
        }
        return {_class.__name__: ord(key) for key, _class in cls._msgtypes.items()}


class SharedMessage(Message, metaclass=MessageMeta):
    'A message which could be sent by either the frontend or the backend'
    "A message which could be sent by either the frontend or the backend"
    _msgtypes = dict()
    _classes = dict()


class FrontendMessage(Message, metaclass=MessageMeta):
    'A message which will only be sent be a backend'
    "A message which will only be sent be a backend"
    _msgtypes = dict()
    _classes = dict()


class BackendMessage(Message, metaclass=MessageMeta):
    'A message which will only be sent be a frontend'
    "A message which will only be sent be a frontend"
    _msgtypes = dict()
    _classes = dict()


class Query(FrontendMessage):
    key = 'Q'
    struct = Struct(
        "query" / CString("ascii")
    )
    key = "Q"
    struct = Struct("query" / CString("ascii"))

    @staticmethod
    def print(message):

@@ -144,132 +171,151 @@ class Query(FrontendMessage):

    @staticmethod
    def normalize_shards(content):
        '''
        """
        For example:
        >>> normalize_shards(
        >>> 'COPY public.copy_test_120340 (key, value) FROM STDIN WITH (FORMAT BINARY))'
        >>> )
        'COPY public.copy_test_XXXXXX (key, value) FROM STDIN WITH (FORMAT BINARY))'
        '''
        """
        result = content
        pattern = re.compile('public\.[a-z_]+(?P<shardid>[0-9]+)')
        pattern = re.compile("public\.[a-z_]+(?P<shardid>[0-9]+)")
        for match in pattern.finditer(content):
            span = match.span('shardid')
            replacement = 'X'*( span[1] - span[0] )
            result = result[:span[0]] + replacement + result[span[1]:]
            span = match.span("shardid")
            replacement = "X" * (span[1] - span[0])
            result = result[: span[0]] + replacement + result[span[1] :]
        return result

    @staticmethod
    def normalize_timestamps(content):
        '''
        """
        For example:
        >>> normalize_timestamps('2018-06-07 05:18:19.388992-07')
        'XXXX-XX-XX XX:XX:XX.XXXXXX-XX'
        >>> normalize_timestamps('2018-06-11 05:30:43.01382-07')
        'XXXX-XX-XX XX:XX:XX.XXXXXX-XX'
        '''
        """

        pattern = re.compile(
            '[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}.[0-9]{2,6}-[0-9]{2}'
            "[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}.[0-9]{2,6}-[0-9]{2}"
        )

        return re.sub(pattern, 'XXXX-XX-XX XX:XX:XX.XXXXXX-XX', content)
        return re.sub(pattern, "XXXX-XX-XX XX:XX:XX.XXXXXX-XX", content)

    @staticmethod
    def normalize_assign_txn_id(content):
        '''
        """
        For example:
        >>> normalize_assign_txn_id('SELECT assign_distributed_transaction_id(0, 52, ...')
        'SELECT assign_distributed_transaction_id(0, XX, ...'
        '''
        """

        pattern = re.compile(
            'assign_distributed_transaction_id\s*\(' # a method call
            '\s*[0-9]+\s*,' # an integer first parameter
            '\s*(?P<transaction_id>[0-9]+)' # an integer second parameter
            "assign_distributed_transaction_id\s*\("  # a method call
            "\s*[0-9]+\s*,"  # an integer first parameter
            "\s*(?P<transaction_id>[0-9]+)"  # an integer second parameter
        )
        result = content
        for match in pattern.finditer(content):
            span = match.span('transaction_id')
            result = result[:span[0]] + 'XX' + result[span[1]:]
            span = match.span("transaction_id")
            result = result[: span[0]] + "XX" + result[span[1] :]
        return result


class Terminate(FrontendMessage):
    key = 'X'
    key = "X"
    struct = Struct()


class CopyData(SharedMessage):
    key = 'd'
    key = "d"
    struct = Struct(
        'data' / GreedyBytes # reads all of the data left in this substream
        "data" / GreedyBytes  # reads all of the data left in this substream
    )


class CopyDone(SharedMessage):
    key = 'c'
    key = "c"
    struct = Struct()


class EmptyQueryResponse(BackendMessage):
    key = 'I'
    key = "I"
    struct = Struct()


class CopyOutResponse(BackendMessage):
    key = 'H'
    key = "H"
    struct = Struct(
        "format" / Int8ub,
        "columncount" / Int16ub,
        "columns" / Array(this.columncount, Struct(
            "format" / Int16ub
        ))
        "columns" / Array(this.columncount, Struct("format" / Int16ub)),
    )


class ReadyForQuery(BackendMessage):
    key='Z'
    struct = Struct("state"/Enum(Byte,
        idle=ord('I'),
        in_transaction_block=ord('T'),
        in_failed_transaction_block=ord('E')
    ))
    key = "Z"
    struct = Struct(
        "state"
        / Enum(
            Byte,
            idle=ord("I"),
            in_transaction_block=ord("T"),
            in_failed_transaction_block=ord("E"),
        )
    )


class CommandComplete(BackendMessage):
    key = 'C'
    struct = Struct(
        "command" / CString("ascii")
    )
    key = "C"
    struct = Struct("command" / CString("ascii"))


class RowDescription(BackendMessage):
    key = 'T'
    key = "T"
    struct = Struct(
        "fieldcount" / Int16ub,
        "fields" / Array(this.fieldcount, Struct(
            "_type" / Computed("F"),
            "name" / CString("ascii"),
            "tableoid" / Int32ub,
            "colattrnum" / Int16ub,
            "typoid" / Int32ub,
            "typlen" / Int16sb,
            "typmod" / Int32sb,
            "format_code" / Int16ub,
        ))
        "fields"
        / Array(
            this.fieldcount,
            Struct(
                "_type" / Computed("F"),
                "name" / CString("ascii"),
                "tableoid" / Int32ub,
                "colattrnum" / Int16ub,
                "typoid" / Int32ub,
                "typlen" / Int16sb,
                "typmod" / Int32sb,
                "format_code" / Int16ub,
            ),
        ),
    )


class DataRow(BackendMessage):
    key = 'D'
    key = "D"
    struct = Struct(
        "_type" / Computed("data_row"),
        "columncount" / Int16ub,
        "columns" / Array(this.columncount, Struct(
            "_type" / Computed("C"),
            "length" / Int16sb,
            "value" / Bytes(this.length)
        ))
        "columns"
        / Array(
            this.columncount,
            Struct(
                "_type" / Computed("C"),
                "length" / Int16sb,
                "value" / Bytes(this.length),
            ),
        ),
    )


class AuthenticationOk(BackendMessage):
    key = 'R'
    key = "R"
    struct = Struct()


class ParameterStatus(BackendMessage):
    key = 'S'
    key = "S"
    struct = Struct(
        "name" / CString("ASCII"),
        "value" / CString("ASCII"),

@@ -281,161 +327,156 @@ class ParameterStatus(BackendMessage):

    @staticmethod
    def normalize(name, value):
        if name in ('TimeZone', 'server_version'):
            value = 'XXX'
        if name in ("TimeZone", "server_version"):
            value = "XXX"
        return (name, value)


class BackendKeyData(BackendMessage):
    key = 'K'
    struct = Struct(
        "pid" / Int32ub,
        "key" / Bytes(4)
    )
    key = "K"
    struct = Struct("pid" / Int32ub, "key" / Bytes(4))

    def print(message):
        # Both of these should be censored, for reproducible regression test output
        return "BackendKeyData(XXX)"


class NoticeResponse(BackendMessage):
    key = 'N'
    key = "N"
    struct = Struct(
        "notices" / GreedyRange(
        "notices"
        / GreedyRange(
            Struct(
                "key" / Enum(Byte,
                    severity=ord('S'),
                    _severity_not_localized=ord('V'),
                    _sql_state=ord('C'),
                    message=ord('M'),
                    detail=ord('D'),
                    hint=ord('H'),
                    _position=ord('P'),
                    _internal_position=ord('p'),
                    _internal_query=ord('q'),
                    _where=ord('W'),
                    schema_name=ord('s'),
                    table_name=ord('t'),
                    column_name=ord('c'),
                    data_type_name=ord('d'),
                    constraint_name=ord('n'),
                    _file_name=ord('F'),
                    _line_no=ord('L'),
                    _routine_name=ord('R')
                ),
                "value" / CString("ASCII")
                "key"
                / Enum(
                    Byte,
                    severity=ord("S"),
                    _severity_not_localized=ord("V"),
                    _sql_state=ord("C"),
                    message=ord("M"),
                    detail=ord("D"),
                    hint=ord("H"),
                    _position=ord("P"),
                    _internal_position=ord("p"),
                    _internal_query=ord("q"),
                    _where=ord("W"),
                    schema_name=ord("s"),
                    table_name=ord("t"),
                    column_name=ord("c"),
                    data_type_name=ord("d"),
                    constraint_name=ord("n"),
                    _file_name=ord("F"),
                    _line_no=ord("L"),
                    _routine_name=ord("R"),
                ),
                "value" / CString("ASCII"),
            )
        )
    )

    def print(message):
        return "NoticeResponse({})".format(", ".join(
            "{}={}".format(response.key, response.value)
            for response in message.notices
            if not response.key.startswith('_')
        ))
        return "NoticeResponse({})".format(
            ", ".join(
                "{}={}".format(response.key, response.value)
                for response in message.notices
                if not response.key.startswith("_")
            )
        )


class Parse(FrontendMessage):
    key = 'P'
    key = "P"
    struct = Struct(
        "name" / CString("ASCII"),
        "query" / CString("ASCII"),
        "_parametercount" / Int16ub,
        "parameters" / Array(
            this._parametercount,
            Int32ub
        )
        "parameters" / Array(this._parametercount, Int32ub),
    )


class ParseComplete(BackendMessage):
    key = '1'
    key = "1"
    struct = Struct()


class Bind(FrontendMessage):
    key = 'B'
    key = "B"
    struct = Struct(
        "destination_portal" / CString("ASCII"),
        "prepared_statement" / CString("ASCII"),
        "_parameter_format_code_count" / Int16ub,
        "parameter_format_codes" / Array(this._parameter_format_code_count,
                                         Int16ub),
        "parameter_format_codes" / Array(this._parameter_format_code_count, Int16ub),
        "_parameter_value_count" / Int16ub,
        "parameter_values" / Array(
        "parameter_values"
        / Array(
            this._parameter_value_count,
            Struct(
                "length" / Int32ub,
                "value" / Bytes(this.length)
            )
            Struct("length" / Int32ub, "value" / Bytes(this.length)),
        ),
        "result_column_format_count" / Int16ub,
        "result_column_format_codes" / Array(this.result_column_format_count,
                                             Int16ub)
        "result_column_format_codes" / Array(this.result_column_format_count, Int16ub),
    )


class BindComplete(BackendMessage):
    key = '2'
    key = "2"
    struct = Struct()


class NoData(BackendMessage):
    key = 'n'
    key = "n"
    struct = Struct()


class Describe(FrontendMessage):
    key = 'D'
    key = "D"
    struct = Struct(
        "type" / Enum(Byte,
            prepared_statement=ord('S'),
            portal=ord('P')
        ),
        "name" / CString("ASCII")
        "type" / Enum(Byte, prepared_statement=ord("S"), portal=ord("P")),
        "name" / CString("ASCII"),
    )

    def print(message):
        return "Describe({}={})".format(
            message.type,
            message.name or "<unnamed>"
        )
        return "Describe({}={})".format(message.type, message.name or "<unnamed>")


class Execute(FrontendMessage):
    key = 'E'
    struct = Struct(
        "name" / CString("ASCII"),
        "max_rows_to_return" / Int32ub
    )
    key = "E"
    struct = Struct("name" / CString("ASCII"), "max_rows_to_return" / Int32ub)

    def print(message):
        return "Execute({}, max_rows_to_return={})".format(
            message.name or "<unnamed>",
            message.max_rows_to_return
            message.name or "<unnamed>", message.max_rows_to_return
        )


class Sync(FrontendMessage):
    key = 'S'
    key = "S"
    struct = Struct()


frontend_switch = Switch(
    this.type,
    { **FrontendMessage.name_to_struct(), **SharedMessage.name_to_struct() },
    default=Bytes(this.length - 4)
    {**FrontendMessage.name_to_struct(), **SharedMessage.name_to_struct()},
    default=Bytes(this.length - 4),
)

backend_switch = Switch(
    this.type,
    {**BackendMessage.name_to_struct(), **SharedMessage.name_to_struct()},
    default=Bytes(this.length - 4)
    default=Bytes(this.length - 4),
)

frontend_msgtypes = Enum(Byte, **{
    **FrontendMessage.name_to_key(),
    **SharedMessage.name_to_key()
})
frontend_msgtypes = Enum(
    Byte, **{**FrontendMessage.name_to_key(), **SharedMessage.name_to_key()}
)

backend_msgtypes = Enum(Byte, **{
    **BackendMessage.name_to_key(),
    **SharedMessage.name_to_key()
})
backend_msgtypes = Enum(
    Byte, **{**BackendMessage.name_to_key(), **SharedMessage.name_to_key()}
)

# It might seem a little circuitous to say a frontend message is a kind of frontend
# message but this lets us easily customize how they're printed


class Frontend(FrontendMessage):
    struct = Struct(
        "type" / frontend_msgtypes,

@@ -447,9 +488,7 @@ class Frontend(FrontendMessage):

    def print(message):
        if isinstance(message.body, bytes):
            return "Frontend(type={},body={})".format(
                chr(message.type), message.body
            )
            return "Frontend(type={},body={})".format(chr(message.type), message.body)
        return FrontendMessage.print_message(message.body)

    def typeof(message):

@@ -457,6 +496,7 @@ class Frontend(FrontendMessage):
            return "Unknown"
        return message.body._type


class Backend(BackendMessage):
    struct = Struct(
        "type" / backend_msgtypes,

@@ -468,9 +508,7 @@ class Backend(BackendMessage):

    def print(message):
        if isinstance(message.body, bytes):
            return "Backend(type={},body={})".format(
                chr(message.type), message.body
            )
            return "Backend(type={},body={})".format(chr(message.type), message.body)
        return BackendMessage.print_message(message.body)

    def typeof(message):

@@ -478,10 +516,12 @@ class Backend(BackendMessage):
            return "Unknown"
        return message.body._type


# GreedyRange keeps reading messages until we hit EOF
frontend_messages = GreedyRange(Frontend.struct)
backend_messages = GreedyRange(Backend.struct)


def parse(message, from_frontend=True):
    if from_frontend:
        message = frontend_messages.parse(message)

@@ -491,24 +531,27 @@ def parse(message, from_frontend=True):

    return message


def print(message):
    if message.from_frontend:
        return FrontendMessage.print_message(message)
    return BackendMessage.print_message(message)


def message_type(message, from_frontend):
    if from_frontend:
        return FrontendMessage.find_typeof(message)
    return BackendMessage.find_typeof(message)


def message_matches(message, filters, from_frontend):
    '''
    """
    Message is something like Backend(Query)) and fiters is something like query="COPY".

    For now we only support strings, and treat them like a regex, which is matched against
    the content of the wrapped message
    '''
    if message._type != 'Backend' and message._type != 'Frontend':
    """
    if message._type != "Backend" and message._type != "Frontend":
        raise ValueError("can't handle {}".format(message._type))

    wrapped = message.body
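Note: a quick, illustrative sanity check of the metaclass registration above (not part of the test suite): MessageMeta records each message class under its one-byte key at class-creation time, so the name/key registries stay consistent.

    # Query was declared with key = "Q", so the registry maps it back to ord("Q")
    assert FrontendMessage.name_to_key()["Query"] == ord("Q")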