From 530b24a887296765cd18a64c715e8eb861bd6c51 Mon Sep 17 00:00:00 2001 From: Jelte Fennema Date: Wed, 8 Feb 2023 11:34:49 +0100 Subject: [PATCH] Format python files with black --- src/test/regress/bin/diff-filter | 107 ++--- .../citus_arbitrary_configs.py | 9 +- src/test/regress/citus_tests/common.py | 28 +- src/test/regress/citus_tests/config.py | 46 +- src/test/regress/citus_tests/run_test.py | 120 ++++-- .../citus_tests/upgrade/pg_upgrade_test.py | 6 +- src/test/regress/mitmscripts/fluent.py | 212 +++++---- src/test/regress/mitmscripts/structs.py | 403 ++++++++++-------- 8 files changed, 542 insertions(+), 389 deletions(-) diff --git a/src/test/regress/bin/diff-filter b/src/test/regress/bin/diff-filter index 960726285..3908ca37f 100755 --- a/src/test/regress/bin/diff-filter +++ b/src/test/regress/bin/diff-filter @@ -10,64 +10,67 @@ import re class FileScanner: - """ - FileScanner is an iterator over the lines of a file. - It can apply a rewrite rule which can be used to skip lines. - """ - def __init__(self, file, rewrite = lambda x:x): - self.file = file - self.line = 1 - self.rewrite = rewrite + """ + FileScanner is an iterator over the lines of a file. + It can apply a rewrite rule which can be used to skip lines. + """ - def __next__(self): - while True: - nextline = self.rewrite(next(self.file)) - if nextline is not None: - self.line += 1 - return nextline + def __init__(self, file, rewrite=lambda x: x): + self.file = file + self.line = 1 + self.rewrite = rewrite + + def __next__(self): + while True: + nextline = self.rewrite(next(self.file)) + if nextline is not None: + self.line += 1 + return nextline def main(): - # we only test //d rules, as we need to ignore those lines - regexregex = re.compile(r"^/(?P.*)/d$") - regexpipeline = [] - for line in open(argv[1]): - line = line.strip() - if not line or line.startswith('#') or not line.endswith('d'): - continue - rule = regexregex.match(line) - if not rule: - raise 'Failed to parse regex rule: %s' % line - regexpipeline.append(re.compile(rule.group('rule'))) + # we only test //d rules, as we need to ignore those lines + regexregex = re.compile(r"^/(?P.*)/d$") + regexpipeline = [] + for line in open(argv[1]): + line = line.strip() + if not line or line.startswith("#") or not line.endswith("d"): + continue + rule = regexregex.match(line) + if not rule: + raise "Failed to parse regex rule: %s" % line + regexpipeline.append(re.compile(rule.group("rule"))) - def sed(line): - if any(regex.search(line) for regex in regexpipeline): - return None - return line + def sed(line): + if any(regex.search(line) for regex in regexpipeline): + return None + return line - for line in stdin: - if line.startswith('+++ '): - tab = line.rindex('\t') - fname = line[4:tab] - file2 = FileScanner(open(fname.replace('.modified', ''), encoding='utf8'), sed) - stdout.write(line) - elif line.startswith('@@ '): - idx_start = line.index('+') + 1 - idx_end = idx_start + 1 - while line[idx_end].isdigit(): - idx_end += 1 - linenum = int(line[idx_start:idx_end]) - while file2.line < linenum: - next(file2) - stdout.write(line) - elif line.startswith(' '): - stdout.write(' ') - stdout.write(next(file2)) - elif line.startswith('+'): - stdout.write('+') - stdout.write(next(file2)) - else: - stdout.write(line) + for line in stdin: + if line.startswith("+++ "): + tab = line.rindex("\t") + fname = line[4:tab] + file2 = FileScanner( + open(fname.replace(".modified", ""), encoding="utf8"), sed + ) + stdout.write(line) + elif line.startswith("@@ "): + idx_start = line.index("+") + 1 + idx_end = idx_start + 1 + while line[idx_end].isdigit(): + idx_end += 1 + linenum = int(line[idx_start:idx_end]) + while file2.line < linenum: + next(file2) + stdout.write(line) + elif line.startswith(" "): + stdout.write(" ") + stdout.write(next(file2)) + elif line.startswith("+"): + stdout.write("+") + stdout.write(next(file2)) + else: + stdout.write(line) main() diff --git a/src/test/regress/citus_tests/arbitrary_configs/citus_arbitrary_configs.py b/src/test/regress/citus_tests/arbitrary_configs/citus_arbitrary_configs.py index 375298148..2d07c0b63 100755 --- a/src/test/regress/citus_tests/arbitrary_configs/citus_arbitrary_configs.py +++ b/src/test/regress/citus_tests/arbitrary_configs/citus_arbitrary_configs.py @@ -115,7 +115,6 @@ def copy_copy_modified_binary(datadir): def copy_test_files(config): - sql_dir_path = os.path.join(config.datadir, "sql") expected_dir_path = os.path.join(config.datadir, "expected") @@ -132,7 +131,9 @@ def copy_test_files(config): line = line[colon_index + 1 :].strip() test_names = line.split(" ") - copy_test_files_with_names(test_names, sql_dir_path, expected_dir_path, config) + copy_test_files_with_names( + test_names, sql_dir_path, expected_dir_path, config + ) def copy_test_files_with_names(test_names, sql_dir_path, expected_dir_path, config): @@ -140,10 +141,10 @@ def copy_test_files_with_names(test_names, sql_dir_path, expected_dir_path, conf # make empty files for the skipped tests if test_name in config.skip_tests: expected_sql_file = os.path.join(sql_dir_path, test_name + ".sql") - open(expected_sql_file, 'x').close() + open(expected_sql_file, "x").close() expected_out_file = os.path.join(expected_dir_path, test_name + ".out") - open(expected_out_file, 'x').close() + open(expected_out_file, "x").close() continue diff --git a/src/test/regress/citus_tests/common.py b/src/test/regress/citus_tests/common.py index cf625c4f7..1b8c69a45 100644 --- a/src/test/regress/citus_tests/common.py +++ b/src/test/regress/citus_tests/common.py @@ -27,13 +27,11 @@ def initialize_temp_dir_if_not_exists(temp_dir): def parallel_run(function, items, *args, **kwargs): with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [ - executor.submit(function, item, *args, **kwargs) - for item in items - ] + futures = [executor.submit(function, item, *args, **kwargs) for item in items] for future in futures: future.result() + def initialize_db_for_cluster(pg_path, rel_data_path, settings, node_names): subprocess.run(["mkdir", rel_data_path], check=True) @@ -52,7 +50,7 @@ def initialize_db_for_cluster(pg_path, rel_data_path, settings, node_names): "--encoding", "UTF8", "--locale", - "POSIX" + "POSIX", ] subprocess.run(command, check=True) add_settings(abs_data_path, settings) @@ -76,7 +74,9 @@ def create_role(pg_path, node_ports, user_name): user_name, user_name ) utils.psql(pg_path, port, command) - command = "SET citus.enable_ddl_propagation TO OFF; GRANT CREATE ON DATABASE postgres to {}".format(user_name) + command = "SET citus.enable_ddl_propagation TO OFF; GRANT CREATE ON DATABASE postgres to {}".format( + user_name + ) utils.psql(pg_path, port, command) parallel_run(create, node_ports) @@ -89,7 +89,9 @@ def coordinator_should_haveshards(pg_path, port): utils.psql(pg_path, port, command) -def start_databases(pg_path, rel_data_path, node_name_to_ports, logfile_prefix, env_variables): +def start_databases( + pg_path, rel_data_path, node_name_to_ports, logfile_prefix, env_variables +): def start(node_name): abs_data_path = os.path.abspath(os.path.join(rel_data_path, node_name)) node_port = node_name_to_ports[node_name] @@ -248,7 +250,12 @@ def logfile_name(logfile_prefix, node_name): def stop_databases( - pg_path, rel_data_path, node_name_to_ports, logfile_prefix, no_output=False, parallel=True + pg_path, + rel_data_path, + node_name_to_ports, + logfile_prefix, + no_output=False, + parallel=True, ): def stop(node_name): abs_data_path = os.path.abspath(os.path.join(rel_data_path, node_name)) @@ -287,7 +294,9 @@ def initialize_citus_cluster(bindir, datadir, settings, config): initialize_db_for_cluster( bindir, datadir, settings, config.node_name_to_ports.keys() ) - start_databases(bindir, datadir, config.node_name_to_ports, config.name, config.env_variables) + start_databases( + bindir, datadir, config.node_name_to_ports, config.name, config.env_variables + ) create_citus_extension(bindir, config.node_name_to_ports.values()) add_workers(bindir, config.worker_ports, config.coordinator_port()) if not config.is_mx: @@ -296,6 +305,7 @@ def initialize_citus_cluster(bindir, datadir, settings, config): add_coordinator_to_metadata(bindir, config.coordinator_port()) config.setup_steps() + def eprint(*args, **kwargs): """eprint prints to stderr""" diff --git a/src/test/regress/citus_tests/config.py b/src/test/regress/citus_tests/config.py index 40de2a3b6..88b47be45 100644 --- a/src/test/regress/citus_tests/config.py +++ b/src/test/regress/citus_tests/config.py @@ -57,8 +57,9 @@ port_lock = threading.Lock() def should_include_config(class_name): - - if inspect.isclass(class_name) and issubclass(class_name, CitusDefaultClusterConfig): + if inspect.isclass(class_name) and issubclass( + class_name, CitusDefaultClusterConfig + ): return True return False @@ -167,7 +168,9 @@ class CitusDefaultClusterConfig(CitusBaseClusterConfig): self.add_coordinator_to_metadata = True self.skip_tests = [ # Alter Table statement cannot be run from an arbitrary node so this test will fail - "arbitrary_configs_alter_table_add_constraint_without_name_create", "arbitrary_configs_alter_table_add_constraint_without_name"] + "arbitrary_configs_alter_table_add_constraint_without_name_create", + "arbitrary_configs_alter_table_add_constraint_without_name", + ] class CitusUpgradeConfig(CitusBaseClusterConfig): @@ -190,9 +193,13 @@ class PostgresConfig(CitusDefaultClusterConfig): self.new_settings = { "citus.use_citus_managed_tables": False, } - self.skip_tests = ["nested_execution", + self.skip_tests = [ + "nested_execution", # Alter Table statement cannot be run from an arbitrary node so this test will fail - "arbitrary_configs_alter_table_add_constraint_without_name_create", "arbitrary_configs_alter_table_add_constraint_without_name"] + "arbitrary_configs_alter_table_add_constraint_without_name_create", + "arbitrary_configs_alter_table_add_constraint_without_name", + ] + class CitusSingleNodeClusterConfig(CitusDefaultClusterConfig): def __init__(self, arguments): @@ -229,7 +236,7 @@ class CitusSmallSharedPoolSizeConfig(CitusDefaultClusterConfig): def __init__(self, arguments): super().__init__(arguments) self.new_settings = { - "citus.local_shared_pool_size": 5, + "citus.local_shared_pool_size": 5, "citus.max_shared_pool_size": 5, } @@ -275,7 +282,7 @@ class CitusUnusualExecutorConfig(CitusDefaultClusterConfig): # this setting does not necessarily need to be here # could go any other test - self.env_variables = {'PGAPPNAME' : 'test_app'} + self.env_variables = {"PGAPPNAME": "test_app"} class CitusSmallCopyBuffersConfig(CitusDefaultClusterConfig): @@ -307,9 +314,13 @@ class CitusUnusualQuerySettingsConfig(CitusDefaultClusterConfig): # requires the table with the fk to be converted to a citus_local_table. # As of c11, there is no way to do that through remote execution so this test # will fail - "arbitrary_configs_truncate_cascade_create", "arbitrary_configs_truncate_cascade", - # Alter Table statement cannot be run from an arbitrary node so this test will fail - "arbitrary_configs_alter_table_add_constraint_without_name_create", "arbitrary_configs_alter_table_add_constraint_without_name"] + "arbitrary_configs_truncate_cascade_create", + "arbitrary_configs_truncate_cascade", + # Alter Table statement cannot be run from an arbitrary node so this test will fail + "arbitrary_configs_alter_table_add_constraint_without_name_create", + "arbitrary_configs_alter_table_add_constraint_without_name", + ] + class CitusSingleNodeSingleShardClusterConfig(CitusDefaultClusterConfig): def __init__(self, arguments): @@ -328,15 +339,20 @@ class CitusShardReplicationFactorClusterConfig(CitusDefaultClusterConfig): self.skip_tests = [ # citus does not support foreign keys in distributed tables # when citus.shard_replication_factor >= 2 - "arbitrary_configs_truncate_partition_create", "arbitrary_configs_truncate_partition", + "arbitrary_configs_truncate_partition_create", + "arbitrary_configs_truncate_partition", # citus does not support modifying a partition when # citus.shard_replication_factor >= 2 - "arbitrary_configs_truncate_cascade_create", "arbitrary_configs_truncate_cascade", + "arbitrary_configs_truncate_cascade_create", + "arbitrary_configs_truncate_cascade", # citus does not support colocating functions with distributed tables when # citus.shard_replication_factor >= 2 - "function_create", "functions", - # Alter Table statement cannot be run from an arbitrary node so this test will fail - "arbitrary_configs_alter_table_add_constraint_without_name_create", "arbitrary_configs_alter_table_add_constraint_without_name"] + "function_create", + "functions", + # Alter Table statement cannot be run from an arbitrary node so this test will fail + "arbitrary_configs_alter_table_add_constraint_without_name_create", + "arbitrary_configs_alter_table_add_constraint_without_name", + ] class CitusSingleShardClusterConfig(CitusDefaultClusterConfig): diff --git a/src/test/regress/citus_tests/run_test.py b/src/test/regress/citus_tests/run_test.py index a2a7fad8f..aba11156e 100755 --- a/src/test/regress/citus_tests/run_test.py +++ b/src/test/regress/citus_tests/run_test.py @@ -12,23 +12,53 @@ import common import config args = argparse.ArgumentParser() -args.add_argument("test_name", help="Test name (must be included in a schedule.)", nargs='?') -args.add_argument("-p", "--path", required=False, help="Relative path for test file (must have a .sql or .spec extension)", type=pathlib.Path) +args.add_argument( + "test_name", help="Test name (must be included in a schedule.)", nargs="?" +) +args.add_argument( + "-p", + "--path", + required=False, + help="Relative path for test file (must have a .sql or .spec extension)", + type=pathlib.Path, +) args.add_argument("-r", "--repeat", help="Number of test to run", type=int, default=1) -args.add_argument("-b", "--use-base-schedule", required=False, help="Choose base-schedules rather than minimal-schedules", action='store_true') -args.add_argument("-w", "--use-whole-schedule-line", required=False, help="Use the whole line found in related schedule", action='store_true') -args.add_argument("--valgrind", required=False, help="Run the test with valgrind enabled", action='store_true') +args.add_argument( + "-b", + "--use-base-schedule", + required=False, + help="Choose base-schedules rather than minimal-schedules", + action="store_true", +) +args.add_argument( + "-w", + "--use-whole-schedule-line", + required=False, + help="Use the whole line found in related schedule", + action="store_true", +) +args.add_argument( + "--valgrind", + required=False, + help="Run the test with valgrind enabled", + action="store_true", +) args = vars(args.parse_args()) regress_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) -test_file_path = args['path'] -test_file_name = args['test_name'] -use_base_schedule = args['use_base_schedule'] -use_whole_schedule_line = args['use_whole_schedule_line'] +test_file_path = args["path"] +test_file_name = args["test_name"] +use_base_schedule = args["use_base_schedule"] +use_whole_schedule_line = args["use_whole_schedule_line"] -test_files_to_skip = ['multi_cluster_management', 'multi_extension', 'multi_test_helpers', 'multi_insert_select'] -test_files_to_run_without_schedule = ['single_node_enterprise'] +test_files_to_skip = [ + "multi_cluster_management", + "multi_extension", + "multi_test_helpers", + "multi_insert_select", +] +test_files_to_run_without_schedule = ["single_node_enterprise"] if not (test_file_name or test_file_path): print(f"FATAL: No test given.") @@ -36,7 +66,7 @@ if not (test_file_name or test_file_path): if test_file_path: - test_file_path = os.path.join(os.getcwd(), args['path']) + test_file_path = os.path.join(os.getcwd(), args["path"]) if not os.path.isfile(test_file_path): print(f"ERROR: test file '{test_file_path}' does not exist") @@ -45,7 +75,7 @@ if test_file_path: test_file_extension = pathlib.Path(test_file_path).suffix test_file_name = pathlib.Path(test_file_path).stem - if not test_file_extension in '.spec.sql': + if not test_file_extension in ".spec.sql": print( "ERROR: Unrecognized test extension. Valid extensions are: .sql and .spec" ) @@ -56,73 +86,73 @@ if test_file_name in test_files_to_skip: print(f"WARNING: Skipping exceptional test: '{test_file_name}'") sys.exit(0) -test_schedule = '' +test_schedule = "" # find related schedule for schedule_file_path in sorted(glob(os.path.join(regress_dir, "*_schedule"))): - for schedule_line in open(schedule_file_path, 'r'): - if re.search(r'\b' + test_file_name + r'\b', schedule_line): - test_schedule = pathlib.Path(schedule_file_path).stem - if use_whole_schedule_line: - test_schedule_line = schedule_line - else: - test_schedule_line = f"test: {test_file_name}\n" - break - else: - continue - break + for schedule_line in open(schedule_file_path, "r"): + if re.search(r"\b" + test_file_name + r"\b", schedule_line): + test_schedule = pathlib.Path(schedule_file_path).stem + if use_whole_schedule_line: + test_schedule_line = schedule_line + else: + test_schedule_line = f"test: {test_file_name}\n" + break + else: + continue + break # map suitable schedule if not test_schedule: - print( - f"WARNING: Could not find any schedule for '{test_file_name}'" - ) + print(f"WARNING: Could not find any schedule for '{test_file_name}'") sys.exit(0) elif "isolation" in test_schedule: - test_schedule = 'base_isolation_schedule' + test_schedule = "base_isolation_schedule" elif "failure" in test_schedule: - test_schedule = 'failure_base_schedule' + test_schedule = "failure_base_schedule" elif "enterprise" in test_schedule: - test_schedule = 'enterprise_minimal_schedule' + test_schedule = "enterprise_minimal_schedule" elif "split" in test_schedule: - test_schedule = 'minimal_schedule' + test_schedule = "minimal_schedule" elif "mx" in test_schedule: if use_base_schedule: - test_schedule = 'mx_base_schedule' + test_schedule = "mx_base_schedule" else: - test_schedule = 'mx_minimal_schedule' + test_schedule = "mx_minimal_schedule" elif "operations" in test_schedule: - test_schedule = 'minimal_schedule' + test_schedule = "minimal_schedule" elif test_schedule in config.ARBITRARY_SCHEDULE_NAMES: print(f"WARNING: Arbitrary config schedule ({test_schedule}) is not supported.") sys.exit(0) else: if use_base_schedule: - test_schedule = 'base_schedule' + test_schedule = "base_schedule" else: - test_schedule = 'minimal_schedule' + test_schedule = "minimal_schedule" # copy base schedule to a temp file and append test_schedule_line # to be able to run tests in parallel (if test_schedule_line is a parallel group.) -tmp_schedule_path = os.path.join(regress_dir, f"tmp_schedule_{ random.randint(1, 10000)}") +tmp_schedule_path = os.path.join( + regress_dir, f"tmp_schedule_{ random.randint(1, 10000)}" +) # some tests don't need a schedule to run # e.g tests that are in the first place in their own schedule if test_file_name not in test_files_to_run_without_schedule: shutil.copy2(os.path.join(regress_dir, test_schedule), tmp_schedule_path) with open(tmp_schedule_path, "a") as myfile: - for i in range(args['repeat']): - myfile.write(test_schedule_line) + for i in range(args["repeat"]): + myfile.write(test_schedule_line) # find suitable make recipe if "isolation" in test_schedule: - make_recipe = 'check-isolation-custom-schedule' + make_recipe = "check-isolation-custom-schedule" elif "failure" in test_schedule: - make_recipe = 'check-failure-custom-schedule' + make_recipe = "check-failure-custom-schedule" else: - make_recipe = 'check-custom-schedule' + make_recipe = "check-custom-schedule" -if args['valgrind']: - make_recipe += '-vg' +if args["valgrind"]: + make_recipe += "-vg" # prepare command to run tests test_command = f"make -C {regress_dir} {make_recipe} SCHEDULE='{pathlib.Path(tmp_schedule_path).stem}'" diff --git a/src/test/regress/citus_tests/upgrade/pg_upgrade_test.py b/src/test/regress/citus_tests/upgrade/pg_upgrade_test.py index 045ae42e1..b7fc75a08 100755 --- a/src/test/regress/citus_tests/upgrade/pg_upgrade_test.py +++ b/src/test/regress/citus_tests/upgrade/pg_upgrade_test.py @@ -112,7 +112,11 @@ def main(config): config.node_name_to_ports.keys(), ) common.start_databases( - config.new_bindir, config.new_datadir, config.node_name_to_ports, config.name, {} + config.new_bindir, + config.new_datadir, + config.node_name_to_ports, + config.name, + {}, ) citus_finish_pg_upgrade(config.new_bindir, config.node_name_to_ports.values()) diff --git a/src/test/regress/mitmscripts/fluent.py b/src/test/regress/mitmscripts/fluent.py index 2e6be9d34..363af680e 100644 --- a/src/test/regress/mitmscripts/fluent.py +++ b/src/test/regress/mitmscripts/fluent.py @@ -20,15 +20,17 @@ logging.basicConfig(format="%(asctime)s %(levelname)s %(message)s", level=loggin # I. Command Strings + class Handler: - ''' + """ This class hierarchy serves two purposes: 1. Allow command strings to be evaluated. Once evaluated you'll have a Handler you can pass packets to 2. Process packets as they come in and decide what to do with them. Subclasses which want to change how packets are handled should override _handle. - ''' + """ + def __init__(self, root=None): # all packets are first sent to the root handler to be processed self.root = root if root else self @@ -38,30 +40,31 @@ class Handler: def _accept(self, flow, message): result = self._handle(flow, message) - if result == 'pass': + if result == "pass": # defer to our child if not self.next: raise Exception("we don't know what to do!") - if self.next._accept(flow, message) == 'stop': + if self.next._accept(flow, message) == "stop": if self.root is not self: - return 'stop' + return "stop" self.next = KillHandler(self) flow.kill() else: return result def _handle(self, flow, message): - ''' + """ Handlers can return one of three things: - "done" tells the parent to stop processing. This performs the default action, which is to allow the packet to be sent. - "pass" means to delegate to self.next and do whatever it wants - "stop" means all processing will stop, and all connections will be killed - ''' + """ # subclasses must implement this raise NotImplementedError() + class FilterableMixin: def contains(self, pattern): self.next = Contains(self.root, pattern) @@ -76,7 +79,7 @@ class FilterableMixin: return self.next def __getattr__(self, attr): - ''' + """ Methods such as .onQuery trigger when a packet with that name is intercepted Adds support for commands such as: @@ -85,14 +88,17 @@ class FilterableMixin: Returns a function because the above command is resolved in two steps: conn.onQuery becomes conn.__getattr__("onQuery") conn.onQuery(query="COPY") becomes conn.__getattr__("onQuery")(query="COPY") - ''' - if attr.startswith('on'): + """ + if attr.startswith("on"): + def doit(**kwargs): self.next = OnPacket(self.root, attr[2:], kwargs) return self.next + return doit raise AttributeError + class ActionsMixin: def kill(self): self.next = KillHandler(self.root) @@ -118,31 +124,39 @@ class ActionsMixin: self.next = ConnectDelayHandler(self.root, timeMs) return self.next + class AcceptHandler(Handler): def __init__(self, root): super().__init__(root) + def _handle(self, flow, message): - return 'done' + return "done" + class KillHandler(Handler): def __init__(self, root): super().__init__(root) + def _handle(self, flow, message): flow.kill() - return 'done' + return "done" + class KillAllHandler(Handler): def __init__(self, root): super().__init__(root) + def _handle(self, flow, message): - return 'stop' + return "stop" + class ResetHandler(Handler): # try to force a RST to be sent, something went very wrong! def __init__(self, root): super().__init__(root) + def _handle(self, flow, message): - flow.kill() # tell mitmproxy this connection should be closed + flow.kill() # tell mitmproxy this connection should be closed # this is a mitmproxy.connections.ClientConnection(mitmproxy.tcp.BaseHandler) client_conn = flow.client_conn @@ -152,8 +166,9 @@ class ResetHandler(Handler): # cause linux to send a RST LINGER_ON, LINGER_TIMEOUT = 1, 0 conn.setsockopt( - socket.SOL_SOCKET, socket.SO_LINGER, - struct.pack('ii', LINGER_ON, LINGER_TIMEOUT) + socket.SOL_SOCKET, + socket.SO_LINGER, + struct.pack("ii", LINGER_ON, LINGER_TIMEOUT), ) conn.close() @@ -161,28 +176,35 @@ class ResetHandler(Handler): # tries to call conn.shutdown(), but there's nothing else to clean up so that's # maybe okay - return 'done' + return "done" + class CancelHandler(Handler): - 'Send a SIGINT to the process' + "Send a SIGINT to the process" + def __init__(self, root, pid): super().__init__(root) self.pid = pid + def _handle(self, flow, message): os.kill(self.pid, signal.SIGINT) # give the signal a chance to be received before we let the packet through time.sleep(0.1) - return 'done' + return "done" + class ConnectDelayHandler(Handler): - 'Delay the initial packet by sleeping before deciding what to do' + "Delay the initial packet by sleeping before deciding what to do" + def __init__(self, root, timeMs): super().__init__(root) self.timeMs = timeMs + def _handle(self, flow, message): if message.is_initial: - time.sleep(self.timeMs/1000.0) - return 'done' + time.sleep(self.timeMs / 1000.0) + return "done" + class Contains(Handler, ActionsMixin, FilterableMixin): def __init__(self, root, pattern): @@ -191,8 +213,9 @@ class Contains(Handler, ActionsMixin, FilterableMixin): def _handle(self, flow, message): if self.pattern in message.content: - return 'pass' - return 'done' + return "pass" + return "done" + class Matches(Handler, ActionsMixin, FilterableMixin): def __init__(self, root, pattern): @@ -201,47 +224,56 @@ class Matches(Handler, ActionsMixin, FilterableMixin): def _handle(self, flow, message): if self.pattern.search(message.content): - return 'pass' - return 'done' + return "pass" + return "done" + class After(Handler, ActionsMixin, FilterableMixin): "Don't pass execution to our child until we've handled 'times' messages" + def __init__(self, root, times): super().__init__(root) self.target = times def _handle(self, flow, message): - if not hasattr(flow, '_after_count'): + if not hasattr(flow, "_after_count"): flow._after_count = 0 if flow._after_count >= self.target: - return 'pass' + return "pass" flow._after_count += 1 - return 'done' + return "done" + class OnPacket(Handler, ActionsMixin, FilterableMixin): - '''Triggers when a packet of the specified kind comes around''' + """Triggers when a packet of the specified kind comes around""" + def __init__(self, root, packet_kind, kwargs): super().__init__(root) self.packet_kind = packet_kind self.filters = kwargs + def _handle(self, flow, message): if not message.parsed: # if this is the first message in the connection we just skip it - return 'done' + return "done" for msg in message.parsed: typ = structs.message_type(msg, from_frontend=message.from_client) if typ == self.packet_kind: - matches = structs.message_matches(msg, self.filters, message.from_client) + matches = structs.message_matches( + msg, self.filters, message.from_client + ) if matches: - return 'pass' - return 'done' + return "pass" + return "done" + class RootHandler(Handler, ActionsMixin, FilterableMixin): def _handle(self, flow, message): # do whatever the next Handler tells us to do - return 'pass' + return "pass" + class RecorderCommand: def __init__(self): @@ -250,23 +282,26 @@ class RecorderCommand: def dump(self): # When the user calls dump() we return everything we've captured - self.command = 'dump' + self.command = "dump" return self def reset(self): # If the user calls reset() we dump all captured packets without returning them - self.command = 'reset' + self.command = "reset" return self + # II. Utilities for interfacing with mitmproxy + def build_handler(spec): - 'Turns a command string into a RootHandler ready to accept packets' + "Turns a command string into a RootHandler ready to accept packets" root = RootHandler() recorder = RecorderCommand() - handler = eval(spec, {'__builtins__': {}}, {'conn': root, 'recorder': recorder}) + handler = eval(spec, {"__builtins__": {}}, {"conn": root, "recorder": recorder}) return handler.root + # a bunch of globals handler = None # the current handler used to process packets @@ -274,25 +309,25 @@ command_thread = None # sits on the fifo and waits for new commands to come in captured_messages = queue.Queue() # where we store messages used for recorder.dump() connection_count = count() # so we can give connections ids in recorder.dump() -def listen_for_commands(fifoname): +def listen_for_commands(fifoname): def emit_row(conn, from_client, message): # we're using the COPY text format. It requires us to escape backslashes - cleaned = message.replace('\\', '\\\\') - source = 'coordinator' if from_client else 'worker' - return '{}\t{}\t{}'.format(conn, source, cleaned) + cleaned = message.replace("\\", "\\\\") + source = "coordinator" if from_client else "worker" + return "{}\t{}\t{}".format(conn, source, cleaned) def emit_message(message): if message.is_initial: return emit_row( - message.connection_id, message.from_client, '[initial message]' + message.connection_id, message.from_client, "[initial message]" ) pretty = structs.print(message.parsed) return emit_row(message.connection_id, message.from_client, pretty) def all_items(queue_): - 'Pulls everything out of the queue without blocking' + "Pulls everything out of the queue without blocking" try: while True: yield queue_.get(block=False) @@ -300,23 +335,27 @@ def listen_for_commands(fifoname): pass def drop_terminate_messages(messages): - ''' + """ Terminate() messages happen eventually, Citus doesn't feel any need to send them immediately, so tests which embed them aren't reproducible and fail to timing issues. Here we simply drop those messages. - ''' + """ + def isTerminate(msg, from_client): kind = structs.message_type(msg, from_client) - return kind == 'Terminate' + return kind == "Terminate" for message in messages: if not message.parsed: yield message continue - message.parsed = ListContainer([ - msg for msg in message.parsed - if not isTerminate(msg, message.from_client) - ]) + message.parsed = ListContainer( + [ + msg + for msg in message.parsed + if not isTerminate(msg, message.from_client) + ] + ) message.parsed.from_frontend = message.from_client if len(message.parsed) == 0: continue @@ -324,35 +363,35 @@ def listen_for_commands(fifoname): def handle_recorder(recorder): global connection_count - result = '' + result = "" - if recorder.command == 'reset': - result = '' + if recorder.command == "reset": + result = "" connection_count = count() - elif recorder.command != 'dump': + elif recorder.command != "dump": # this should never happen - raise Exception('Unrecognized command: {}'.format(recorder.command)) + raise Exception("Unrecognized command: {}".format(recorder.command)) results = [] messages = all_items(captured_messages) messages = drop_terminate_messages(messages) for message in messages: - if recorder.command == 'reset': + if recorder.command == "reset": continue results.append(emit_message(message)) - result = '\n'.join(results) + result = "\n".join(results) - logging.debug('about to write to fifo') - with open(fifoname, mode='w') as fifo: - logging.debug('successfully opened the fifo for writing') - fifo.write('{}'.format(result)) + logging.debug("about to write to fifo") + with open(fifoname, mode="w") as fifo: + logging.debug("successfully opened the fifo for writing") + fifo.write("{}".format(result)) while True: - logging.debug('about to read from fifo') - with open(fifoname, mode='r') as fifo: - logging.debug('successfully opened the fifo for reading') + logging.debug("about to read from fifo") + with open(fifoname, mode="r") as fifo: + logging.debug("successfully opened the fifo for reading") slug = fifo.read() - logging.info('received new command: %s', slug.rstrip()) + logging.info("received new command: %s", slug.rstrip()) try: handler = build_handler(slug) @@ -371,13 +410,14 @@ def listen_for_commands(fifoname): except Exception as e: result = str(e) else: - result = '' + result = "" + + logging.debug("about to write to fifo") + with open(fifoname, mode="w") as fifo: + logging.debug("successfully opened the fifo for writing") + fifo.write("{}\n".format(result)) + logging.info("responded to command: %s", result.split("\n")[0]) - logging.debug('about to write to fifo') - with open(fifoname, mode='w') as fifo: - logging.debug('successfully opened the fifo for writing') - fifo.write('{}\n'.format(result)) - logging.info('responded to command: %s', result.split("\n")[0]) def create_thread(fifoname): global command_thread @@ -388,42 +428,46 @@ def create_thread(fifoname): return if command_thread: - print('cannot change the fifo path once mitmproxy has started'); + print("cannot change the fifo path once mitmproxy has started") return - command_thread = threading.Thread(target=listen_for_commands, args=(fifoname,), daemon=True) + command_thread = threading.Thread( + target=listen_for_commands, args=(fifoname,), daemon=True + ) command_thread.start() + # III. mitmproxy callbacks + def load(loader): - loader.add_option('slug', str, 'conn.allow()', "A script to run") - loader.add_option('fifo', str, '', "Which fifo to listen on for commands") + loader.add_option("slug", str, "conn.allow()", "A script to run") + loader.add_option("fifo", str, "", "Which fifo to listen on for commands") def configure(updated): global handler - if 'slug' in updated: + if "slug" in updated: text = ctx.options.slug handler = build_handler(text) - if 'fifo' in updated: + if "fifo" in updated: fifoname = ctx.options.fifo create_thread(fifoname) def tcp_message(flow: tcp.TCPFlow): - ''' + """ This callback is hit every time mitmproxy receives a packet. It's the main entrypoint into this script. - ''' + """ global connection_count tcp_msg = flow.messages[-1] # Keep track of all the different connections, assign a unique id to each - if not hasattr(flow, 'connection_id'): + if not hasattr(flow, "connection_id"): flow.connection_id = next(connection_count) tcp_msg.connection_id = flow.connection_id @@ -434,7 +478,9 @@ def tcp_message(flow: tcp.TCPFlow): # skip parsing initial messages for now, they're not important tcp_msg.parsed = None else: - tcp_msg.parsed = structs.parse(tcp_msg.content, from_frontend=tcp_msg.from_client) + tcp_msg.parsed = structs.parse( + tcp_msg.content, from_frontend=tcp_msg.from_client + ) # record the message, for debugging purposes captured_messages.put(tcp_msg) diff --git a/src/test/regress/mitmscripts/structs.py b/src/test/regress/mitmscripts/structs.py index a6333a3fa..10ef25bc3 100644 --- a/src/test/regress/mitmscripts/structs.py +++ b/src/test/regress/mitmscripts/structs.py @@ -1,8 +1,25 @@ from construct import ( Struct, - Int8ub, Int16ub, Int32ub, Int16sb, Int32sb, - Bytes, CString, Computed, Switch, Seek, this, Pointer, - GreedyRange, Enum, Byte, Probe, FixedSized, RestreamData, GreedyBytes, Array + Int8ub, + Int16ub, + Int32ub, + Int16sb, + Int32sb, + Bytes, + CString, + Computed, + Switch, + Seek, + this, + Pointer, + GreedyRange, + Enum, + Byte, + Probe, + FixedSized, + RestreamData, + GreedyBytes, + Array, ) import construct.lib as cl @@ -11,23 +28,30 @@ import re # For all possible message formats see: # https://www.postgresql.org/docs/current/protocol-message-formats.html + class MessageMeta(type): def __init__(cls, name, bases, namespace): - ''' + """ __init__ is called every time a subclass of MessageMeta is declared - ''' + """ if not hasattr(cls, "_msgtypes"): - raise Exception("classes which use MessageMeta must have a '_msgtypes' field") + raise Exception( + "classes which use MessageMeta must have a '_msgtypes' field" + ) if not hasattr(cls, "_classes"): - raise Exception("classes which use MessageMeta must have a '_classes' field") + raise Exception( + "classes which use MessageMeta must have a '_classes' field" + ) if not hasattr(cls, "struct"): # This is one of the direct subclasses return if cls.__name__ in cls._classes: - raise Exception("You've already made a class called {}".format( cls.__name__)) + raise Exception( + "You've already made a class called {}".format(cls.__name__) + ) cls._classes[cls.__name__] = cls # add a _type field to the struct so we can identify it while printing structs @@ -39,34 +63,41 @@ class MessageMeta(type): # register the type, so we can tell the parser about it key = cls.key if key in cls._msgtypes: - raise Exception('key {} is already assigned to {}'.format( - key, cls._msgtypes[key].__name__) + raise Exception( + "key {} is already assigned to {}".format( + key, cls._msgtypes[key].__name__ + ) ) cls._msgtypes[key] = cls + class Message: - 'Do not subclass this object directly. Instead, subclass of one of the below types' + "Do not subclass this object directly. Instead, subclass of one of the below types" def print(message): - 'Define this on subclasses you want to change the representation of' + "Define this on subclasses you want to change the representation of" raise NotImplementedError def typeof(message): - 'Define this on subclasses you want to change the expressed type of' + "Define this on subclasses you want to change the expressed type of" return message._type @classmethod def _default_print(cls, name, msg): recur = cls.print_message - return "{}({})".format(name, ",".join( - "{}={}".format(key, recur(value)) for key, value in msg.items() - if not key.startswith('_') - )) + return "{}({})".format( + name, + ",".join( + "{}={}".format(key, recur(value)) + for key, value in msg.items() + if not key.startswith("_") + ), + ) @classmethod def find_typeof(cls, msg): if not hasattr(cls, "_msgtypes"): - raise Exception('Do not call this method on Message, call it on a subclass') + raise Exception("Do not call this method on Message, call it on a subclass") if isinstance(msg, cl.ListContainer): raise ValueError("do not call this on a list of messages") if not isinstance(msg, cl.Container): @@ -80,7 +111,7 @@ class Message: @classmethod def print_message(cls, msg): if not hasattr(cls, "_msgtypes"): - raise Exception('Do not call this method on Message, call it on a subclass') + raise Exception("Do not call this method on Message, call it on a subclass") if isinstance(msg, cl.ListContainer): return repr([cls.print_message(message) for message in msg]) @@ -101,38 +132,34 @@ class Message: @classmethod def name_to_struct(cls): - return { - _class.__name__: _class.struct - for _class in cls._msgtypes.values() - } + return {_class.__name__: _class.struct for _class in cls._msgtypes.values()} @classmethod def name_to_key(cls): - return { - _class.__name__ : ord(key) - for key, _class in cls._msgtypes.items() - } + return {_class.__name__: ord(key) for key, _class in cls._msgtypes.items()} + class SharedMessage(Message, metaclass=MessageMeta): - 'A message which could be sent by either the frontend or the backend' + "A message which could be sent by either the frontend or the backend" _msgtypes = dict() _classes = dict() + class FrontendMessage(Message, metaclass=MessageMeta): - 'A message which will only be sent be a backend' + "A message which will only be sent be a backend" _msgtypes = dict() _classes = dict() + class BackendMessage(Message, metaclass=MessageMeta): - 'A message which will only be sent be a frontend' + "A message which will only be sent be a frontend" _msgtypes = dict() _classes = dict() + class Query(FrontendMessage): - key = 'Q' - struct = Struct( - "query" / CString("ascii") - ) + key = "Q" + struct = Struct("query" / CString("ascii")) @staticmethod def print(message): @@ -144,132 +171,151 @@ class Query(FrontendMessage): @staticmethod def normalize_shards(content): - ''' + """ For example: >>> normalize_shards( >>> 'COPY public.copy_test_120340 (key, value) FROM STDIN WITH (FORMAT BINARY))' >>> ) 'COPY public.copy_test_XXXXXX (key, value) FROM STDIN WITH (FORMAT BINARY))' - ''' + """ result = content - pattern = re.compile('public\.[a-z_]+(?P[0-9]+)') + pattern = re.compile("public\.[a-z_]+(?P[0-9]+)") for match in pattern.finditer(content): - span = match.span('shardid') - replacement = 'X'*( span[1] - span[0] ) - result = result[:span[0]] + replacement + result[span[1]:] + span = match.span("shardid") + replacement = "X" * (span[1] - span[0]) + result = result[: span[0]] + replacement + result[span[1] :] return result @staticmethod def normalize_timestamps(content): - ''' + """ For example: >>> normalize_timestamps('2018-06-07 05:18:19.388992-07') 'XXXX-XX-XX XX:XX:XX.XXXXXX-XX' >>> normalize_timestamps('2018-06-11 05:30:43.01382-07') 'XXXX-XX-XX XX:XX:XX.XXXXXX-XX' - ''' + """ pattern = re.compile( - '[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}.[0-9]{2,6}-[0-9]{2}' + "[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}.[0-9]{2,6}-[0-9]{2}" ) - return re.sub(pattern, 'XXXX-XX-XX XX:XX:XX.XXXXXX-XX', content) + return re.sub(pattern, "XXXX-XX-XX XX:XX:XX.XXXXXX-XX", content) @staticmethod def normalize_assign_txn_id(content): - ''' + """ For example: >>> normalize_assign_txn_id('SELECT assign_distributed_transaction_id(0, 52, ...') 'SELECT assign_distributed_transaction_id(0, XX, ...' - ''' + """ pattern = re.compile( - 'assign_distributed_transaction_id\s*\(' # a method call - '\s*[0-9]+\s*,' # an integer first parameter - '\s*(?P[0-9]+)' # an integer second parameter + "assign_distributed_transaction_id\s*\(" # a method call + "\s*[0-9]+\s*," # an integer first parameter + "\s*(?P[0-9]+)" # an integer second parameter ) result = content for match in pattern.finditer(content): - span = match.span('transaction_id') - result = result[:span[0]] + 'XX' + result[span[1]:] + span = match.span("transaction_id") + result = result[: span[0]] + "XX" + result[span[1] :] return result + class Terminate(FrontendMessage): - key = 'X' + key = "X" struct = Struct() + class CopyData(SharedMessage): - key = 'd' + key = "d" struct = Struct( - 'data' / GreedyBytes # reads all of the data left in this substream + "data" / GreedyBytes # reads all of the data left in this substream ) + class CopyDone(SharedMessage): - key = 'c' + key = "c" struct = Struct() + class EmptyQueryResponse(BackendMessage): - key = 'I' + key = "I" struct = Struct() + class CopyOutResponse(BackendMessage): - key = 'H' + key = "H" struct = Struct( "format" / Int8ub, "columncount" / Int16ub, - "columns" / Array(this.columncount, Struct( - "format" / Int16ub - )) + "columns" / Array(this.columncount, Struct("format" / Int16ub)), ) + class ReadyForQuery(BackendMessage): - key='Z' - struct = Struct("state"/Enum(Byte, - idle=ord('I'), - in_transaction_block=ord('T'), - in_failed_transaction_block=ord('E') - )) + key = "Z" + struct = Struct( + "state" + / Enum( + Byte, + idle=ord("I"), + in_transaction_block=ord("T"), + in_failed_transaction_block=ord("E"), + ) + ) + class CommandComplete(BackendMessage): - key = 'C' - struct = Struct( - "command" / CString("ascii") - ) + key = "C" + struct = Struct("command" / CString("ascii")) + class RowDescription(BackendMessage): - key = 'T' + key = "T" struct = Struct( "fieldcount" / Int16ub, - "fields" / Array(this.fieldcount, Struct( - "_type" / Computed("F"), - "name" / CString("ascii"), - "tableoid" / Int32ub, - "colattrnum" / Int16ub, - "typoid" / Int32ub, - "typlen" / Int16sb, - "typmod" / Int32sb, - "format_code" / Int16ub, - )) + "fields" + / Array( + this.fieldcount, + Struct( + "_type" / Computed("F"), + "name" / CString("ascii"), + "tableoid" / Int32ub, + "colattrnum" / Int16ub, + "typoid" / Int32ub, + "typlen" / Int16sb, + "typmod" / Int32sb, + "format_code" / Int16ub, + ), + ), ) + class DataRow(BackendMessage): - key = 'D' + key = "D" struct = Struct( "_type" / Computed("data_row"), "columncount" / Int16ub, - "columns" / Array(this.columncount, Struct( - "_type" / Computed("C"), - "length" / Int16sb, - "value" / Bytes(this.length) - )) + "columns" + / Array( + this.columncount, + Struct( + "_type" / Computed("C"), + "length" / Int16sb, + "value" / Bytes(this.length), + ), + ), ) + class AuthenticationOk(BackendMessage): - key = 'R' + key = "R" struct = Struct() + class ParameterStatus(BackendMessage): - key = 'S' + key = "S" struct = Struct( "name" / CString("ASCII"), "value" / CString("ASCII"), @@ -281,161 +327,156 @@ class ParameterStatus(BackendMessage): @staticmethod def normalize(name, value): - if name in ('TimeZone', 'server_version'): - value = 'XXX' + if name in ("TimeZone", "server_version"): + value = "XXX" return (name, value) + class BackendKeyData(BackendMessage): - key = 'K' - struct = Struct( - "pid" / Int32ub, - "key" / Bytes(4) - ) + key = "K" + struct = Struct("pid" / Int32ub, "key" / Bytes(4)) def print(message): # Both of these should be censored, for reproducible regression test output return "BackendKeyData(XXX)" + class NoticeResponse(BackendMessage): - key = 'N' + key = "N" struct = Struct( - "notices" / GreedyRange( + "notices" + / GreedyRange( Struct( - "key" / Enum(Byte, - severity=ord('S'), - _severity_not_localized=ord('V'), - _sql_state=ord('C'), - message=ord('M'), - detail=ord('D'), - hint=ord('H'), - _position=ord('P'), - _internal_position=ord('p'), - _internal_query=ord('q'), - _where=ord('W'), - schema_name=ord('s'), - table_name=ord('t'), - column_name=ord('c'), - data_type_name=ord('d'), - constraint_name=ord('n'), - _file_name=ord('F'), - _line_no=ord('L'), - _routine_name=ord('R') - ), - "value" / CString("ASCII") + "key" + / Enum( + Byte, + severity=ord("S"), + _severity_not_localized=ord("V"), + _sql_state=ord("C"), + message=ord("M"), + detail=ord("D"), + hint=ord("H"), + _position=ord("P"), + _internal_position=ord("p"), + _internal_query=ord("q"), + _where=ord("W"), + schema_name=ord("s"), + table_name=ord("t"), + column_name=ord("c"), + data_type_name=ord("d"), + constraint_name=ord("n"), + _file_name=ord("F"), + _line_no=ord("L"), + _routine_name=ord("R"), + ), + "value" / CString("ASCII"), ) ) ) def print(message): - return "NoticeResponse({})".format(", ".join( - "{}={}".format(response.key, response.value) - for response in message.notices - if not response.key.startswith('_') - )) + return "NoticeResponse({})".format( + ", ".join( + "{}={}".format(response.key, response.value) + for response in message.notices + if not response.key.startswith("_") + ) + ) + class Parse(FrontendMessage): - key = 'P' + key = "P" struct = Struct( "name" / CString("ASCII"), "query" / CString("ASCII"), "_parametercount" / Int16ub, - "parameters" / Array( - this._parametercount, - Int32ub - ) + "parameters" / Array(this._parametercount, Int32ub), ) + class ParseComplete(BackendMessage): - key = '1' + key = "1" struct = Struct() + class Bind(FrontendMessage): - key = 'B' + key = "B" struct = Struct( "destination_portal" / CString("ASCII"), "prepared_statement" / CString("ASCII"), "_parameter_format_code_count" / Int16ub, - "parameter_format_codes" / Array(this._parameter_format_code_count, - Int16ub), + "parameter_format_codes" / Array(this._parameter_format_code_count, Int16ub), "_parameter_value_count" / Int16ub, - "parameter_values" / Array( + "parameter_values" + / Array( this._parameter_value_count, - Struct( - "length" / Int32ub, - "value" / Bytes(this.length) - ) + Struct("length" / Int32ub, "value" / Bytes(this.length)), ), "result_column_format_count" / Int16ub, - "result_column_format_codes" / Array(this.result_column_format_count, - Int16ub) + "result_column_format_codes" / Array(this.result_column_format_count, Int16ub), ) + class BindComplete(BackendMessage): - key = '2' + key = "2" struct = Struct() + class NoData(BackendMessage): - key = 'n' + key = "n" struct = Struct() + class Describe(FrontendMessage): - key = 'D' + key = "D" struct = Struct( - "type" / Enum(Byte, - prepared_statement=ord('S'), - portal=ord('P') - ), - "name" / CString("ASCII") + "type" / Enum(Byte, prepared_statement=ord("S"), portal=ord("P")), + "name" / CString("ASCII"), ) def print(message): - return "Describe({}={})".format( - message.type, - message.name or "" - ) + return "Describe({}={})".format(message.type, message.name or "") + class Execute(FrontendMessage): - key = 'E' - struct = Struct( - "name" / CString("ASCII"), - "max_rows_to_return" / Int32ub - ) + key = "E" + struct = Struct("name" / CString("ASCII"), "max_rows_to_return" / Int32ub) def print(message): return "Execute({}, max_rows_to_return={})".format( - message.name or "", - message.max_rows_to_return + message.name or "", message.max_rows_to_return ) + class Sync(FrontendMessage): - key = 'S' + key = "S" struct = Struct() + frontend_switch = Switch( this.type, - { **FrontendMessage.name_to_struct(), **SharedMessage.name_to_struct() }, - default=Bytes(this.length - 4) + {**FrontendMessage.name_to_struct(), **SharedMessage.name_to_struct()}, + default=Bytes(this.length - 4), ) backend_switch = Switch( this.type, {**BackendMessage.name_to_struct(), **SharedMessage.name_to_struct()}, - default=Bytes(this.length - 4) + default=Bytes(this.length - 4), ) -frontend_msgtypes = Enum(Byte, **{ - **FrontendMessage.name_to_key(), - **SharedMessage.name_to_key() -}) +frontend_msgtypes = Enum( + Byte, **{**FrontendMessage.name_to_key(), **SharedMessage.name_to_key()} +) -backend_msgtypes = Enum(Byte, **{ - **BackendMessage.name_to_key(), - **SharedMessage.name_to_key() -}) +backend_msgtypes = Enum( + Byte, **{**BackendMessage.name_to_key(), **SharedMessage.name_to_key()} +) # It might seem a little circuitous to say a frontend message is a kind of frontend # message but this lets us easily customize how they're printed + class Frontend(FrontendMessage): struct = Struct( "type" / frontend_msgtypes, @@ -447,9 +488,7 @@ class Frontend(FrontendMessage): def print(message): if isinstance(message.body, bytes): - return "Frontend(type={},body={})".format( - chr(message.type), message.body - ) + return "Frontend(type={},body={})".format(chr(message.type), message.body) return FrontendMessage.print_message(message.body) def typeof(message): @@ -457,6 +496,7 @@ class Frontend(FrontendMessage): return "Unknown" return message.body._type + class Backend(BackendMessage): struct = Struct( "type" / backend_msgtypes, @@ -468,9 +508,7 @@ class Backend(BackendMessage): def print(message): if isinstance(message.body, bytes): - return "Backend(type={},body={})".format( - chr(message.type), message.body - ) + return "Backend(type={},body={})".format(chr(message.type), message.body) return BackendMessage.print_message(message.body) def typeof(message): @@ -478,10 +516,12 @@ class Backend(BackendMessage): return "Unknown" return message.body._type + # GreedyRange keeps reading messages until we hit EOF frontend_messages = GreedyRange(Frontend.struct) backend_messages = GreedyRange(Backend.struct) + def parse(message, from_frontend=True): if from_frontend: message = frontend_messages.parse(message) @@ -491,24 +531,27 @@ def parse(message, from_frontend=True): return message + def print(message): if message.from_frontend: return FrontendMessage.print_message(message) return BackendMessage.print_message(message) + def message_type(message, from_frontend): if from_frontend: return FrontendMessage.find_typeof(message) return BackendMessage.find_typeof(message) + def message_matches(message, filters, from_frontend): - ''' + """ Message is something like Backend(Query)) and fiters is something like query="COPY". For now we only support strings, and treat them like a regex, which is matched against the content of the wrapped message - ''' - if message._type != 'Backend' and message._type != 'Frontend': + """ + if message._type != "Backend" and message._type != "Frontend": raise ValueError("can't handle {}".format(message._type)) wrapped = message.body