Format python files with black

pull/6700/head
Jelte Fennema 2023-02-08 11:34:49 +01:00
parent 42970665fc
commit 530b24a887
8 changed files with 542 additions and 389 deletions
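
Everything below is mechanical reformatting by black: string quotes are normalized to double quotes, calls and literals longer than the default 88-column limit are split one element per line, and "magic" trailing commas are added. A hand-written before/after sketch of the pattern (names invented for illustration, not taken from the commit):

    # Before black: single quotes, one long line.
    skip = ['test_a', 'test_b', 'test_c', 'a_rather_long_test_name', 'another_long_test_name']

    # After black: double quotes; once the line exceeds 88 columns the list is
    # exploded to one element per line, with a trailing comma so later
    # additions produce one-line diffs.
    skip = [
        "test_a",
        "test_b",
        "test_c",
        "a_rather_long_test_name",
        "another_long_test_name",
    ]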

View File

@@ -14,6 +14,7 @@ class FileScanner:
     FileScanner is an iterator over the lines of a file.
     It can apply a rewrite rule which can be used to skip lines.
     """
+
     def __init__(self, file, rewrite=lambda x: x):
         self.file = file
         self.line = 1
@@ -33,12 +34,12 @@ def main():
     regexpipeline = []
     for line in open(argv[1]):
         line = line.strip()
-        if not line or line.startswith('#') or not line.endswith('d'):
+        if not line or line.startswith("#") or not line.endswith("d"):
             continue
         rule = regexregex.match(line)
         if not rule:
-            raise 'Failed to parse regex rule: %s' % line
-        regexpipeline.append(re.compile(rule.group('rule')))
+            raise "Failed to parse regex rule: %s" % line
+        regexpipeline.append(re.compile(rule.group("rule")))

     def sed(line):
         if any(regex.search(line) for regex in regexpipeline):
@@ -46,13 +47,15 @@ def main():
         return line

     for line in stdin:
-        if line.startswith('+++ '):
-            tab = line.rindex('\t')
+        if line.startswith("+++ "):
+            tab = line.rindex("\t")
             fname = line[4:tab]
-            file2 = FileScanner(open(fname.replace('.modified', ''), encoding='utf8'), sed)
+            file2 = FileScanner(
+                open(fname.replace(".modified", ""), encoding="utf8"), sed
+            )
             stdout.write(line)
-        elif line.startswith('@@ '):
-            idx_start = line.index('+') + 1
+        elif line.startswith("@@ "):
+            idx_start = line.index("+") + 1
             idx_end = idx_start + 1
             while line[idx_end].isdigit():
                 idx_end += 1
@@ -60,11 +63,11 @@ def main():
             while file2.line < linenum:
                 next(file2)
             stdout.write(line)
-        elif line.startswith(' '):
-            stdout.write(' ')
+        elif line.startswith(" "):
+            stdout.write(" ")
             stdout.write(next(file2))
-        elif line.startswith('+'):
-            stdout.write('+')
+        elif line.startswith("+"):
+            stdout.write("+")
             stdout.write(next(file2))
         else:
             stdout.write(line)
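
The FileScanner reformatted above is a line iterator that tracks its position, so main() can fast-forward to each hunk's start with next(). A self-contained sketch of that protocol, assuming __next__ advances the 1-based line counter and that a rewrite returning None skips the line (only __init__ is visible in this hunk, so those details are assumptions):

    import io

    class FileScanner:
        def __init__(self, file, rewrite=lambda x: x):
            self.file = file
            self.line = 1
            self.rewrite = rewrite

        def __iter__(self):
            return self

        def __next__(self):
            for line in self.file:
                self.line += 1
                rewritten = self.rewrite(line)  # a rewrite rule may skip this line
                if rewritten is not None:
                    return rewritten
            raise StopIteration

    scanner = FileScanner(io.StringIO("a\nb\nc\n"))
    while scanner.line < 3:  # fast-forward, as main() does with hunk start lines
        next(scanner)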

View File

@@ -115,7 +115,6 @@ def copy_copy_modified_binary(datadir):
-

 def copy_test_files(config):
     sql_dir_path = os.path.join(config.datadir, "sql")
     expected_dir_path = os.path.join(config.datadir, "expected")
@@ -132,7 +131,9 @@ def copy_test_files(config):
                line = line[colon_index + 1 :].strip()
                test_names = line.split(" ")
-                copy_test_files_with_names(test_names, sql_dir_path, expected_dir_path, config)
+                copy_test_files_with_names(
+                    test_names, sql_dir_path, expected_dir_path, config
+                )


 def copy_test_files_with_names(test_names, sql_dir_path, expected_dir_path, config):
@@ -140,10 +141,10 @@ def copy_test_files_with_names(test_names, sql_dir_path, expected_dir_path, config):
         # make empty files for the skipped tests
         if test_name in config.skip_tests:
             expected_sql_file = os.path.join(sql_dir_path, test_name + ".sql")
-            open(expected_sql_file, 'x').close()
+            open(expected_sql_file, "x").close()
             expected_out_file = os.path.join(expected_dir_path, test_name + ".out")
-            open(expected_out_file, 'x').close()
+            open(expected_out_file, "x").close()
             continue

View File

@@ -27,13 +27,11 @@ def initialize_temp_dir_if_not_exists(temp_dir):
 def parallel_run(function, items, *args, **kwargs):
     with concurrent.futures.ThreadPoolExecutor() as executor:
-        futures = [
-            executor.submit(function, item, *args, **kwargs)
-            for item in items
-        ]
+        futures = [executor.submit(function, item, *args, **kwargs) for item in items]
         for future in futures:
             future.result()


 def initialize_db_for_cluster(pg_path, rel_data_path, settings, node_names):
     subprocess.run(["mkdir", rel_data_path], check=True)
@@ -52,7 +50,7 @@ def initialize_db_for_cluster(pg_path, rel_data_path, settings, node_names):
         "--encoding",
         "UTF8",
         "--locale",
-        "POSIX"
+        "POSIX",
     ]
     subprocess.run(command, check=True)
     add_settings(abs_data_path, settings)
@@ -76,7 +74,9 @@ def create_role(pg_path, node_ports, user_name):
             user_name, user_name
         )
         utils.psql(pg_path, port, command)
-        command = "SET citus.enable_ddl_propagation TO OFF; GRANT CREATE ON DATABASE postgres to {}".format(user_name)
+        command = "SET citus.enable_ddl_propagation TO OFF; GRANT CREATE ON DATABASE postgres to {}".format(
+            user_name
+        )
         utils.psql(pg_path, port, command)

     parallel_run(create, node_ports)
@@ -89,7 +89,9 @@ def coordinator_should_haveshards(pg_path, port):
     utils.psql(pg_path, port, command)


-def start_databases(pg_path, rel_data_path, node_name_to_ports, logfile_prefix, env_variables):
+def start_databases(
+    pg_path, rel_data_path, node_name_to_ports, logfile_prefix, env_variables
+):
     def start(node_name):
         abs_data_path = os.path.abspath(os.path.join(rel_data_path, node_name))
         node_port = node_name_to_ports[node_name]
@@ -248,7 +250,12 @@ def logfile_name(logfile_prefix, node_name):
 def stop_databases(
-    pg_path, rel_data_path, node_name_to_ports, logfile_prefix, no_output=False, parallel=True
+    pg_path,
+    rel_data_path,
+    node_name_to_ports,
+    logfile_prefix,
+    no_output=False,
+    parallel=True,
 ):
     def stop(node_name):
         abs_data_path = os.path.abspath(os.path.join(rel_data_path, node_name))
@@ -287,7 +294,9 @@ def initialize_citus_cluster(bindir, datadir, settings, config):
     initialize_db_for_cluster(
         bindir, datadir, settings, config.node_name_to_ports.keys()
     )
-    start_databases(bindir, datadir, config.node_name_to_ports, config.name, config.env_variables)
+    start_databases(
+        bindir, datadir, config.node_name_to_ports, config.name, config.env_variables
+    )
     create_citus_extension(bindir, config.node_name_to_ports.values())
     add_workers(bindir, config.worker_ports, config.coordinator_port())
     if not config.is_mx:
@@ -296,6 +305,7 @@ def initialize_citus_cluster(bindir, datadir, settings, config):
         add_coordinator_to_metadata(bindir, config.coordinator_port())
     config.setup_steps()
+

 def eprint(*args, **kwargs):
     """eprint prints to stderr"""

View File

@@ -57,8 +57,9 @@ port_lock = threading.Lock()
 def should_include_config(class_name):
-    if inspect.isclass(class_name) and issubclass(class_name, CitusDefaultClusterConfig):
+    if inspect.isclass(class_name) and issubclass(
+        class_name, CitusDefaultClusterConfig
+    ):
         return True
     return False
@@ -167,7 +168,9 @@ class CitusDefaultClusterConfig(CitusBaseClusterConfig):
         self.add_coordinator_to_metadata = True
         self.skip_tests = [
             # Alter Table statement cannot be run from an arbitrary node so this test will fail
-            "arbitrary_configs_alter_table_add_constraint_without_name_create", "arbitrary_configs_alter_table_add_constraint_without_name"]
+            "arbitrary_configs_alter_table_add_constraint_without_name_create",
+            "arbitrary_configs_alter_table_add_constraint_without_name",
+        ]


 class CitusUpgradeConfig(CitusBaseClusterConfig):
@@ -190,9 +193,13 @@ class PostgresConfig(CitusDefaultClusterConfig):
         self.new_settings = {
             "citus.use_citus_managed_tables": False,
         }
-        self.skip_tests = ["nested_execution",
+        self.skip_tests = [
+            "nested_execution",
             # Alter Table statement cannot be run from an arbitrary node so this test will fail
-            "arbitrary_configs_alter_table_add_constraint_without_name_create", "arbitrary_configs_alter_table_add_constraint_without_name"]
+            "arbitrary_configs_alter_table_add_constraint_without_name_create",
+            "arbitrary_configs_alter_table_add_constraint_without_name",
+        ]


 class CitusSingleNodeClusterConfig(CitusDefaultClusterConfig):
     def __init__(self, arguments):
@@ -275,7 +282,7 @@ class CitusUnusualExecutorConfig(CitusDefaultClusterConfig):
         # this setting does not necessarily need to be here
         # could go any other test
-        self.env_variables = {'PGAPPNAME' : 'test_app'}
+        self.env_variables = {"PGAPPNAME": "test_app"}


 class CitusSmallCopyBuffersConfig(CitusDefaultClusterConfig):
@@ -307,9 +314,13 @@ class CitusUnusualQuerySettingsConfig(CitusDefaultClusterConfig):
             # requires the table with the fk to be converted to a citus_local_table.
             # As of c11, there is no way to do that through remote execution so this test
             # will fail
-            "arbitrary_configs_truncate_cascade_create", "arbitrary_configs_truncate_cascade",
+            "arbitrary_configs_truncate_cascade_create",
+            "arbitrary_configs_truncate_cascade",
             # Alter Table statement cannot be run from an arbitrary node so this test will fail
-            "arbitrary_configs_alter_table_add_constraint_without_name_create", "arbitrary_configs_alter_table_add_constraint_without_name"]
+            "arbitrary_configs_alter_table_add_constraint_without_name_create",
+            "arbitrary_configs_alter_table_add_constraint_without_name",
+        ]


 class CitusSingleNodeSingleShardClusterConfig(CitusDefaultClusterConfig):
     def __init__(self, arguments):
@@ -328,15 +339,20 @@ class CitusShardReplicationFactorClusterConfig(CitusDefaultClusterConfig):
         self.skip_tests = [
             # citus does not support foreign keys in distributed tables
             # when citus.shard_replication_factor >= 2
-            "arbitrary_configs_truncate_partition_create", "arbitrary_configs_truncate_partition",
+            "arbitrary_configs_truncate_partition_create",
+            "arbitrary_configs_truncate_partition",
             # citus does not support modifying a partition when
             # citus.shard_replication_factor >= 2
-            "arbitrary_configs_truncate_cascade_create", "arbitrary_configs_truncate_cascade",
+            "arbitrary_configs_truncate_cascade_create",
+            "arbitrary_configs_truncate_cascade",
             # citus does not support colocating functions with distributed tables when
             # citus.shard_replication_factor >= 2
-            "function_create", "functions",
+            "function_create",
+            "functions",
             # Alter Table statement cannot be run from an arbitrary node so this test will fail
-            "arbitrary_configs_alter_table_add_constraint_without_name_create", "arbitrary_configs_alter_table_add_constraint_without_name"]
+            "arbitrary_configs_alter_table_add_constraint_without_name_create",
+            "arbitrary_configs_alter_table_add_constraint_without_name",
+        ]


 class CitusSingleShardClusterConfig(CitusDefaultClusterConfig):

View File

@@ -12,23 +12,53 @@ import common
 import config

 args = argparse.ArgumentParser()
-args.add_argument("test_name", help="Test name (must be included in a schedule.)", nargs='?')
-args.add_argument("-p", "--path", required=False, help="Relative path for test file (must have a .sql or .spec extension)", type=pathlib.Path)
+args.add_argument(
+    "test_name", help="Test name (must be included in a schedule.)", nargs="?"
+)
+args.add_argument(
+    "-p",
+    "--path",
+    required=False,
+    help="Relative path for test file (must have a .sql or .spec extension)",
+    type=pathlib.Path,
+)
 args.add_argument("-r", "--repeat", help="Number of test to run", type=int, default=1)
-args.add_argument("-b", "--use-base-schedule", required=False, help="Choose base-schedules rather than minimal-schedules", action='store_true')
-args.add_argument("-w", "--use-whole-schedule-line", required=False, help="Use the whole line found in related schedule", action='store_true')
-args.add_argument("--valgrind", required=False, help="Run the test with valgrind enabled", action='store_true')
+args.add_argument(
+    "-b",
+    "--use-base-schedule",
+    required=False,
+    help="Choose base-schedules rather than minimal-schedules",
+    action="store_true",
+)
+args.add_argument(
+    "-w",
+    "--use-whole-schedule-line",
+    required=False,
+    help="Use the whole line found in related schedule",
+    action="store_true",
+)
+args.add_argument(
+    "--valgrind",
+    required=False,
+    help="Run the test with valgrind enabled",
+    action="store_true",
+)
 args = vars(args.parse_args())

 regress_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
-test_file_path = args['path']
-test_file_name = args['test_name']
-use_base_schedule = args['use_base_schedule']
-use_whole_schedule_line = args['use_whole_schedule_line']
+test_file_path = args["path"]
+test_file_name = args["test_name"]
+use_base_schedule = args["use_base_schedule"]
+use_whole_schedule_line = args["use_whole_schedule_line"]

-test_files_to_skip = ['multi_cluster_management', 'multi_extension', 'multi_test_helpers', 'multi_insert_select']
-test_files_to_run_without_schedule = ['single_node_enterprise']
+test_files_to_skip = [
+    "multi_cluster_management",
+    "multi_extension",
+    "multi_test_helpers",
+    "multi_insert_select",
+]
+test_files_to_run_without_schedule = ["single_node_enterprise"]

 if not (test_file_name or test_file_path):
     print(f"FATAL: No test given.")
@@ -36,7 +66,7 @@ if not (test_file_name or test_file_path):

 if test_file_path:
-    test_file_path = os.path.join(os.getcwd(), args['path'])
+    test_file_path = os.path.join(os.getcwd(), args["path"])

     if not os.path.isfile(test_file_path):
         print(f"ERROR: test file '{test_file_path}' does not exist")
@@ -45,7 +75,7 @@ if test_file_path:
     test_file_extension = pathlib.Path(test_file_path).suffix
     test_file_name = pathlib.Path(test_file_path).stem

-    if not test_file_extension in '.spec.sql':
+    if not test_file_extension in ".spec.sql":
         print(
             "ERROR: Unrecognized test extension. Valid extensions are: .sql and .spec"
         )
@@ -56,12 +86,12 @@ if test_file_name in test_files_to_skip:
     print(f"WARNING: Skipping exceptional test: '{test_file_name}'")
     sys.exit(0)

-test_schedule = ''
+test_schedule = ""

 # find related schedule
 for schedule_file_path in sorted(glob(os.path.join(regress_dir, "*_schedule"))):
-    for schedule_line in open(schedule_file_path, 'r'):
-        if re.search(r'\b' + test_file_name + r'\b', schedule_line):
+    for schedule_line in open(schedule_file_path, "r"):
+        if re.search(r"\b" + test_file_name + r"\b", schedule_line):
             test_schedule = pathlib.Path(schedule_file_path).stem
             if use_whole_schedule_line:
                 test_schedule_line = schedule_line
@@ -74,55 +104,55 @@ for schedule_file_path in sorted(glob(os.path.join(regress_dir, "*_schedule"))):

 # map suitable schedule
 if not test_schedule:
-    print(
-        f"WARNING: Could not find any schedule for '{test_file_name}'"
-    )
+    print(f"WARNING: Could not find any schedule for '{test_file_name}'")
     sys.exit(0)
 elif "isolation" in test_schedule:
-    test_schedule = 'base_isolation_schedule'
+    test_schedule = "base_isolation_schedule"
 elif "failure" in test_schedule:
-    test_schedule = 'failure_base_schedule'
+    test_schedule = "failure_base_schedule"
 elif "enterprise" in test_schedule:
-    test_schedule = 'enterprise_minimal_schedule'
+    test_schedule = "enterprise_minimal_schedule"
 elif "split" in test_schedule:
-    test_schedule = 'minimal_schedule'
+    test_schedule = "minimal_schedule"
 elif "mx" in test_schedule:
     if use_base_schedule:
-        test_schedule = 'mx_base_schedule'
+        test_schedule = "mx_base_schedule"
     else:
-        test_schedule = 'mx_minimal_schedule'
+        test_schedule = "mx_minimal_schedule"
 elif "operations" in test_schedule:
-    test_schedule = 'minimal_schedule'
+    test_schedule = "minimal_schedule"
 elif test_schedule in config.ARBITRARY_SCHEDULE_NAMES:
     print(f"WARNING: Arbitrary config schedule ({test_schedule}) is not supported.")
     sys.exit(0)
 else:
     if use_base_schedule:
-        test_schedule = 'base_schedule'
+        test_schedule = "base_schedule"
     else:
-        test_schedule = 'minimal_schedule'
+        test_schedule = "minimal_schedule"

 # copy base schedule to a temp file and append test_schedule_line
 # to be able to run tests in parallel (if test_schedule_line is a parallel group.)
-tmp_schedule_path = os.path.join(regress_dir, f"tmp_schedule_{ random.randint(1, 10000)}")
+tmp_schedule_path = os.path.join(
    regress_dir, f"tmp_schedule_{ random.randint(1, 10000)}"
+)

 # some tests don't need a schedule to run
 # e.g tests that are in the first place in their own schedule
 if test_file_name not in test_files_to_run_without_schedule:
     shutil.copy2(os.path.join(regress_dir, test_schedule), tmp_schedule_path)
     with open(tmp_schedule_path, "a") as myfile:
-        for i in range(args['repeat']):
+        for i in range(args["repeat"]):
             myfile.write(test_schedule_line)

 # find suitable make recipe
 if "isolation" in test_schedule:
-    make_recipe = 'check-isolation-custom-schedule'
+    make_recipe = "check-isolation-custom-schedule"
 elif "failure" in test_schedule:
-    make_recipe = 'check-failure-custom-schedule'
+    make_recipe = "check-failure-custom-schedule"
 else:
-    make_recipe = 'check-custom-schedule'
+    make_recipe = "check-custom-schedule"

-if args['valgrind']:
-    make_recipe += '-vg'
+if args["valgrind"]:
+    make_recipe += "-vg"

 # prepare command to run tests
 test_command = f"make -C {regress_dir} {make_recipe} SCHEDULE='{pathlib.Path(tmp_schedule_path).stem}'"
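
Most of the churn above is black splitting one add_argument() call per option onto multiple lines, plus switching args['...'] indexing to double quotes. A self-contained sketch of how such a parser resolves a command line (the test name and values are invented):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("test_name", nargs="?")
    parser.add_argument("-r", "--repeat", type=int, default=1)
    parser.add_argument("-b", "--use-base-schedule", action="store_true")

    # vars() turns the Namespace into the dict the script indexes as
    # args["repeat"], args["use_base_schedule"], and so on.
    args = vars(parser.parse_args(["some_test", "-r", "3", "-b"]))
    assert args == {"test_name": "some_test", "repeat": 3, "use_base_schedule": True}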

View File

@@ -112,7 +112,11 @@ def main(config):
         config.node_name_to_ports.keys(),
     )
     common.start_databases(
-        config.new_bindir, config.new_datadir, config.node_name_to_ports, config.name, {}
+        config.new_bindir,
+        config.new_datadir,
+        config.node_name_to_ports,
+        config.name,
+        {},
     )
     citus_finish_pg_upgrade(config.new_bindir, config.node_name_to_ports.values())

View File

@@ -20,15 +20,17 @@ logging.basicConfig(format="%(asctime)s %(levelname)s %(message)s", level=loggin
 # I. Command Strings


 class Handler:
-    '''
+    """
     This class hierarchy serves two purposes:
     1. Allow command strings to be evaluated. Once evaluated you'll have a Handler you can
        pass packets to
     2. Process packets as they come in and decide what to do with them.
        Subclasses which want to change how packets are handled should override _handle.
-    '''
+    """
+
     def __init__(self, root=None):
         # all packets are first sent to the root handler to be processed
         self.root = root if root else self
@@ -38,30 +40,31 @@ class Handler:
     def _accept(self, flow, message):
         result = self._handle(flow, message)
-        if result == 'pass':
+        if result == "pass":
             # defer to our child
             if not self.next:
                 raise Exception("we don't know what to do!")
-            if self.next._accept(flow, message) == 'stop':
+            if self.next._accept(flow, message) == "stop":
                 if self.root is not self:
-                    return 'stop'
+                    return "stop"
                 self.next = KillHandler(self)
                 flow.kill()
         else:
             return result

     def _handle(self, flow, message):
-        '''
+        """
         Handlers can return one of three things:
         - "done" tells the parent to stop processing. This performs the default action,
           which is to allow the packet to be sent.
         - "pass" means to delegate to self.next and do whatever it wants
         - "stop" means all processing will stop, and all connections will be killed
-        '''
+        """
         # subclasses must implement this
         raise NotImplementedError()


 class FilterableMixin:
     def contains(self, pattern):
         self.next = Contains(self.root, pattern)
@@ -76,7 +79,7 @@ class FilterableMixin:
         return self.next

     def __getattr__(self, attr):
-        '''
+        """
         Methods such as .onQuery trigger when a packet with that name is intercepted

         Adds support for commands such as:
@@ -85,14 +88,17 @@ class FilterableMixin:
         Returns a function because the above command is resolved in two steps:
         conn.onQuery becomes conn.__getattr__("onQuery")
         conn.onQuery(query="COPY") becomes conn.__getattr__("onQuery")(query="COPY")
-        '''
-        if attr.startswith('on'):
+        """
+        if attr.startswith("on"):
+
             def doit(**kwargs):
                 self.next = OnPacket(self.root, attr[2:], kwargs)
                 return self.next
+
             return doit
         raise AttributeError


 class ActionsMixin:
     def kill(self):
         self.next = KillHandler(self.root)
@@ -118,29 +124,37 @@ class ActionsMixin:
         self.next = ConnectDelayHandler(self.root, timeMs)
         return self.next


 class AcceptHandler(Handler):
     def __init__(self, root):
         super().__init__(root)
+
     def _handle(self, flow, message):
-        return 'done'
+        return "done"


 class KillHandler(Handler):
     def __init__(self, root):
         super().__init__(root)
+
     def _handle(self, flow, message):
         flow.kill()
-        return 'done'
+        return "done"


 class KillAllHandler(Handler):
     def __init__(self, root):
         super().__init__(root)
+
     def _handle(self, flow, message):
-        return 'stop'
+        return "stop"


 class ResetHandler(Handler):
     # try to force a RST to be sent, something went very wrong!
     def __init__(self, root):
         super().__init__(root)
+
     def _handle(self, flow, message):
         flow.kill()  # tell mitmproxy this connection should be closed
@@ -152,8 +166,9 @@ class ResetHandler(Handler):
             # cause linux to send a RST
             LINGER_ON, LINGER_TIMEOUT = 1, 0
             conn.setsockopt(
-                socket.SOL_SOCKET, socket.SO_LINGER,
-                struct.pack('ii', LINGER_ON, LINGER_TIMEOUT)
+                socket.SOL_SOCKET,
+                socket.SO_LINGER,
+                struct.pack("ii", LINGER_ON, LINGER_TIMEOUT),
             )
             conn.close()
@@ -161,28 +176,35 @@ class ResetHandler(Handler):
         # tries to call conn.shutdown(), but there's nothing else to clean up so that's
         # maybe okay
-        return 'done'
+        return "done"


 class CancelHandler(Handler):
-    'Send a SIGINT to the process'
+    "Send a SIGINT to the process"
+
     def __init__(self, root, pid):
         super().__init__(root)
         self.pid = pid
+
     def _handle(self, flow, message):
         os.kill(self.pid, signal.SIGINT)
         # give the signal a chance to be received before we let the packet through
         time.sleep(0.1)
-        return 'done'
+        return "done"


 class ConnectDelayHandler(Handler):
-    'Delay the initial packet by sleeping before deciding what to do'
+    "Delay the initial packet by sleeping before deciding what to do"
+
     def __init__(self, root, timeMs):
         super().__init__(root)
         self.timeMs = timeMs
+
     def _handle(self, flow, message):
         if message.is_initial:
             time.sleep(self.timeMs / 1000.0)
-        return 'done'
+        return "done"


 class Contains(Handler, ActionsMixin, FilterableMixin):
     def __init__(self, root, pattern):
@@ -191,8 +213,9 @@ class Contains(Handler, ActionsMixin, FilterableMixin):
     def _handle(self, flow, message):
         if self.pattern in message.content:
-            return 'pass'
-        return 'done'
+            return "pass"
+        return "done"


 class Matches(Handler, ActionsMixin, FilterableMixin):
     def __init__(self, root, pattern):
@@ -201,47 +224,56 @@ class Matches(Handler, ActionsMixin, FilterableMixin):
     def _handle(self, flow, message):
         if self.pattern.search(message.content):
-            return 'pass'
-        return 'done'
+            return "pass"
+        return "done"


 class After(Handler, ActionsMixin, FilterableMixin):
     "Don't pass execution to our child until we've handled 'times' messages"

     def __init__(self, root, times):
         super().__init__(root)
         self.target = times

     def _handle(self, flow, message):
-        if not hasattr(flow, '_after_count'):
+        if not hasattr(flow, "_after_count"):
             flow._after_count = 0
         if flow._after_count >= self.target:
-            return 'pass'
+            return "pass"
         flow._after_count += 1
-        return 'done'
+        return "done"


 class OnPacket(Handler, ActionsMixin, FilterableMixin):
-    '''Triggers when a packet of the specified kind comes around'''
+    """Triggers when a packet of the specified kind comes around"""
+
     def __init__(self, root, packet_kind, kwargs):
         super().__init__(root)
         self.packet_kind = packet_kind
         self.filters = kwargs

     def _handle(self, flow, message):
         if not message.parsed:
             # if this is the first message in the connection we just skip it
-            return 'done'
+            return "done"
         for msg in message.parsed:
             typ = structs.message_type(msg, from_frontend=message.from_client)
             if typ == self.packet_kind:
-                matches = structs.message_matches(msg, self.filters, message.from_client)
+                matches = structs.message_matches(
+                    msg, self.filters, message.from_client
+                )
                 if matches:
-                    return 'pass'
-        return 'done'
+                    return "pass"
+        return "done"


 class RootHandler(Handler, ActionsMixin, FilterableMixin):
     def _handle(self, flow, message):
         # do whatever the next Handler tells us to do
-        return 'pass'
+        return "pass"


 class RecorderCommand:
     def __init__(self):
@@ -250,23 +282,26 @@ class RecorderCommand:
     def dump(self):
         # When the user calls dump() we return everything we've captured
-        self.command = 'dump'
+        self.command = "dump"
         return self

     def reset(self):
         # If the user calls reset() we dump all captured packets without returning them
-        self.command = 'reset'
+        self.command = "reset"
         return self


 # II. Utilities for interfacing with mitmproxy


 def build_handler(spec):
-    'Turns a command string into a RootHandler ready to accept packets'
+    "Turns a command string into a RootHandler ready to accept packets"
     root = RootHandler()
     recorder = RecorderCommand()
-    handler = eval(spec, {'__builtins__': {}}, {'conn': root, 'recorder': recorder})
+    handler = eval(spec, {"__builtins__": {}}, {"conn": root, "recorder": recorder})
     return handler.root


 # a bunch of globals
 handler = None  # the current handler used to process packets
@@ -274,25 +309,25 @@ command_thread = None  # sits on the fifo and waits for new commands to come in
 captured_messages = queue.Queue()  # where we store messages used for recorder.dump()
 connection_count = count()  # so we can give connections ids in recorder.dump()

+
 def listen_for_commands(fifoname):
     def emit_row(conn, from_client, message):
         # we're using the COPY text format. It requires us to escape backslashes
-        cleaned = message.replace('\\', '\\\\')
-        source = 'coordinator' if from_client else 'worker'
-        return '{}\t{}\t{}'.format(conn, source, cleaned)
+        cleaned = message.replace("\\", "\\\\")
+        source = "coordinator" if from_client else "worker"
+        return "{}\t{}\t{}".format(conn, source, cleaned)

     def emit_message(message):
         if message.is_initial:
             return emit_row(
-                message.connection_id, message.from_client, '[initial message]'
+                message.connection_id, message.from_client, "[initial message]"
             )
         pretty = structs.print(message.parsed)
         return emit_row(message.connection_id, message.from_client, pretty)

     def all_items(queue_):
-        'Pulls everything out of the queue without blocking'
+        "Pulls everything out of the queue without blocking"
         try:
             while True:
                 yield queue_.get(block=False)
@@ -300,23 +335,27 @@ def listen_for_commands(fifoname):
             pass

     def drop_terminate_messages(messages):
-        '''
+        """
         Terminate() messages happen eventually, Citus doesn't feel any need to send them
         immediately, so tests which embed them aren't reproducible and fail to timing
         issues. Here we simply drop those messages.
-        '''
+        """
+
         def isTerminate(msg, from_client):
             kind = structs.message_type(msg, from_client)
-            return kind == 'Terminate'
+            return kind == "Terminate"

         for message in messages:
             if not message.parsed:
                 yield message
                 continue
-            message.parsed = ListContainer([
-                msg for msg in message.parsed
-                if not isTerminate(msg, message.from_client)
-            ])
+            message.parsed = ListContainer(
+                [
+                    msg
+                    for msg in message.parsed
+                    if not isTerminate(msg, message.from_client)
+                ]
+            )
             message.parsed.from_frontend = message.from_client
             if len(message.parsed) == 0:
                 continue
@@ -324,35 +363,35 @@ def listen_for_commands(fifoname):
     def handle_recorder(recorder):
         global connection_count
-        result = ''
+        result = ""

-        if recorder.command == 'reset':
-            result = ''
+        if recorder.command == "reset":
+            result = ""
             connection_count = count()
-        elif recorder.command != 'dump':
+        elif recorder.command != "dump":
             # this should never happen
-            raise Exception('Unrecognized command: {}'.format(recorder.command))
+            raise Exception("Unrecognized command: {}".format(recorder.command))

         results = []
         messages = all_items(captured_messages)
         messages = drop_terminate_messages(messages)
         for message in messages:
-            if recorder.command == 'reset':
+            if recorder.command == "reset":
                 continue
             results.append(emit_message(message))
-        result = '\n'.join(results)
+        result = "\n".join(results)

-        logging.debug('about to write to fifo')
-        with open(fifoname, mode='w') as fifo:
-            logging.debug('successfully opened the fifo for writing')
-            fifo.write('{}'.format(result))
+        logging.debug("about to write to fifo")
+        with open(fifoname, mode="w") as fifo:
+            logging.debug("successfully opened the fifo for writing")
+            fifo.write("{}".format(result))

     while True:
-        logging.debug('about to read from fifo')
-        with open(fifoname, mode='r') as fifo:
-            logging.debug('successfully opened the fifo for reading')
+        logging.debug("about to read from fifo")
+        with open(fifoname, mode="r") as fifo:
+            logging.debug("successfully opened the fifo for reading")
             slug = fifo.read()
-        logging.info('received new command: %s', slug.rstrip())
+        logging.info("received new command: %s", slug.rstrip())

         try:
             handler = build_handler(slug)
@@ -371,13 +410,14 @@ def listen_for_commands(fifoname):
         except Exception as e:
             result = str(e)
         else:
-            result = ''
+            result = ""

-        logging.debug('about to write to fifo')
-        with open(fifoname, mode='w') as fifo:
-            logging.debug('successfully opened the fifo for writing')
-            fifo.write('{}\n'.format(result))
-        logging.info('responded to command: %s', result.split("\n")[0])
+        logging.debug("about to write to fifo")
+        with open(fifoname, mode="w") as fifo:
+            logging.debug("successfully opened the fifo for writing")
+            fifo.write("{}\n".format(result))
+
+        logging.info("responded to command: %s", result.split("\n")[0])


 def create_thread(fifoname):
     global command_thread
@@ -388,42 +428,46 @@ def create_thread(fifoname):
         return
     if command_thread:
-        print('cannot change the fifo path once mitmproxy has started');
+        print("cannot change the fifo path once mitmproxy has started")
         return

-    command_thread = threading.Thread(target=listen_for_commands, args=(fifoname,), daemon=True)
+    command_thread = threading.Thread(
+        target=listen_for_commands, args=(fifoname,), daemon=True
+    )
     command_thread.start()


 # III. mitmproxy callbacks


 def load(loader):
-    loader.add_option('slug', str, 'conn.allow()', "A script to run")
-    loader.add_option('fifo', str, '', "Which fifo to listen on for commands")
+    loader.add_option("slug", str, "conn.allow()", "A script to run")
+    loader.add_option("fifo", str, "", "Which fifo to listen on for commands")


 def configure(updated):
     global handler

-    if 'slug' in updated:
+    if "slug" in updated:
         text = ctx.options.slug
         handler = build_handler(text)

-    if 'fifo' in updated:
+    if "fifo" in updated:
         fifoname = ctx.options.fifo
         create_thread(fifoname)


 def tcp_message(flow: tcp.TCPFlow):
-    '''
+    """
     This callback is hit every time mitmproxy receives a packet. It's the main entrypoint
     into this script.
-    '''
+    """
     global connection_count

     tcp_msg = flow.messages[-1]

     # Keep track of all the different connections, assign a unique id to each
-    if not hasattr(flow, 'connection_id'):
+    if not hasattr(flow, "connection_id"):
         flow.connection_id = next(connection_count)
     tcp_msg.connection_id = flow.connection_id
@@ -434,7 +478,9 @@ def tcp_message(flow: tcp.TCPFlow):
         # skip parsing initial messages for now, they're not important
         tcp_msg.parsed = None
     else:
-        tcp_msg.parsed = structs.parse(tcp_msg.content, from_frontend=tcp_msg.from_client)
+        tcp_msg.parsed = structs.parse(
+            tcp_msg.content, from_frontend=tcp_msg.from_client
+        )

     # record the message, for debugging purposes
     captured_messages.put(tcp_msg)
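
The quote changes above run through the heart of this script: build_handler() evaluates a command string with eval(), exposing only the names conn (a RootHandler) and recorder. A sketch of what one such command builds (the command string is an example; onQuery comes from the FilterableMixin docstring and kill() from ActionsMixin):

    # eval(spec, {"__builtins__": {}}, {"conn": root, "recorder": recorder})
    # with the spec below chains:
    #     RootHandler -> OnPacket(kind="Query", filters={"query": "COPY"}) -> KillHandler
    spec = 'conn.onQuery(query="COPY").kill()'
    # Each intercepted packet then walks the chain through _accept():
    # RootHandler._handle returns "pass", OnPacket returns "pass" only for
    # matching Query packets, and KillHandler kills the flow and returns "done".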

View File

@ -1,8 +1,25 @@
from construct import ( from construct import (
Struct, Struct,
Int8ub, Int16ub, Int32ub, Int16sb, Int32sb, Int8ub,
Bytes, CString, Computed, Switch, Seek, this, Pointer, Int16ub,
GreedyRange, Enum, Byte, Probe, FixedSized, RestreamData, GreedyBytes, Array Int32ub,
Int16sb,
Int32sb,
Bytes,
CString,
Computed,
Switch,
Seek,
this,
Pointer,
GreedyRange,
Enum,
Byte,
Probe,
FixedSized,
RestreamData,
GreedyBytes,
Array,
) )
import construct.lib as cl import construct.lib as cl
@ -11,23 +28,30 @@ import re
# For all possible message formats see: # For all possible message formats see:
# https://www.postgresql.org/docs/current/protocol-message-formats.html # https://www.postgresql.org/docs/current/protocol-message-formats.html
class MessageMeta(type): class MessageMeta(type):
def __init__(cls, name, bases, namespace): def __init__(cls, name, bases, namespace):
''' """
__init__ is called every time a subclass of MessageMeta is declared __init__ is called every time a subclass of MessageMeta is declared
''' """
if not hasattr(cls, "_msgtypes"): if not hasattr(cls, "_msgtypes"):
raise Exception("classes which use MessageMeta must have a '_msgtypes' field") raise Exception(
"classes which use MessageMeta must have a '_msgtypes' field"
)
if not hasattr(cls, "_classes"): if not hasattr(cls, "_classes"):
raise Exception("classes which use MessageMeta must have a '_classes' field") raise Exception(
"classes which use MessageMeta must have a '_classes' field"
)
if not hasattr(cls, "struct"): if not hasattr(cls, "struct"):
# This is one of the direct subclasses # This is one of the direct subclasses
return return
if cls.__name__ in cls._classes: if cls.__name__ in cls._classes:
raise Exception("You've already made a class called {}".format( cls.__name__)) raise Exception(
"You've already made a class called {}".format(cls.__name__)
)
cls._classes[cls.__name__] = cls cls._classes[cls.__name__] = cls
# add a _type field to the struct so we can identify it while printing structs # add a _type field to the struct so we can identify it while printing structs
@ -39,34 +63,41 @@ class MessageMeta(type):
# register the type, so we can tell the parser about it # register the type, so we can tell the parser about it
key = cls.key key = cls.key
if key in cls._msgtypes: if key in cls._msgtypes:
raise Exception('key {} is already assigned to {}'.format( raise Exception(
key, cls._msgtypes[key].__name__) "key {} is already assigned to {}".format(
key, cls._msgtypes[key].__name__
)
) )
cls._msgtypes[key] = cls cls._msgtypes[key] = cls
class Message: class Message:
'Do not subclass this object directly. Instead, subclass of one of the below types' "Do not subclass this object directly. Instead, subclass of one of the below types"
def print(message): def print(message):
'Define this on subclasses you want to change the representation of' "Define this on subclasses you want to change the representation of"
raise NotImplementedError raise NotImplementedError
def typeof(message): def typeof(message):
'Define this on subclasses you want to change the expressed type of' "Define this on subclasses you want to change the expressed type of"
return message._type return message._type
@classmethod @classmethod
def _default_print(cls, name, msg): def _default_print(cls, name, msg):
recur = cls.print_message recur = cls.print_message
return "{}({})".format(name, ",".join( return "{}({})".format(
"{}={}".format(key, recur(value)) for key, value in msg.items() name,
if not key.startswith('_') ",".join(
)) "{}={}".format(key, recur(value))
for key, value in msg.items()
if not key.startswith("_")
),
)
@classmethod @classmethod
def find_typeof(cls, msg): def find_typeof(cls, msg):
if not hasattr(cls, "_msgtypes"): if not hasattr(cls, "_msgtypes"):
raise Exception('Do not call this method on Message, call it on a subclass') raise Exception("Do not call this method on Message, call it on a subclass")
if isinstance(msg, cl.ListContainer): if isinstance(msg, cl.ListContainer):
raise ValueError("do not call this on a list of messages") raise ValueError("do not call this on a list of messages")
if not isinstance(msg, cl.Container): if not isinstance(msg, cl.Container):
@ -80,7 +111,7 @@ class Message:
@classmethod @classmethod
def print_message(cls, msg): def print_message(cls, msg):
if not hasattr(cls, "_msgtypes"): if not hasattr(cls, "_msgtypes"):
raise Exception('Do not call this method on Message, call it on a subclass') raise Exception("Do not call this method on Message, call it on a subclass")
if isinstance(msg, cl.ListContainer): if isinstance(msg, cl.ListContainer):
return repr([cls.print_message(message) for message in msg]) return repr([cls.print_message(message) for message in msg])
@ -101,38 +132,34 @@ class Message:
@classmethod @classmethod
def name_to_struct(cls): def name_to_struct(cls):
return { return {_class.__name__: _class.struct for _class in cls._msgtypes.values()}
_class.__name__: _class.struct
for _class in cls._msgtypes.values()
}
@classmethod @classmethod
def name_to_key(cls): def name_to_key(cls):
return { return {_class.__name__: ord(key) for key, _class in cls._msgtypes.items()}
_class.__name__ : ord(key)
for key, _class in cls._msgtypes.items()
}
class SharedMessage(Message, metaclass=MessageMeta): class SharedMessage(Message, metaclass=MessageMeta):
'A message which could be sent by either the frontend or the backend' "A message which could be sent by either the frontend or the backend"
_msgtypes = dict() _msgtypes = dict()
_classes = dict() _classes = dict()
class FrontendMessage(Message, metaclass=MessageMeta): class FrontendMessage(Message, metaclass=MessageMeta):
'A message which will only be sent be a backend' "A message which will only be sent be a backend"
_msgtypes = dict() _msgtypes = dict()
_classes = dict() _classes = dict()
class BackendMessage(Message, metaclass=MessageMeta): class BackendMessage(Message, metaclass=MessageMeta):
'A message which will only be sent be a frontend' "A message which will only be sent be a frontend"
_msgtypes = dict() _msgtypes = dict()
_classes = dict() _classes = dict()
class Query(FrontendMessage): class Query(FrontendMessage):
key = 'Q' key = "Q"
struct = Struct( struct = Struct("query" / CString("ascii"))
"query" / CString("ascii")
)
@staticmethod @staticmethod
def print(message): def print(message):
@ -144,103 +171,114 @@ class Query(FrontendMessage):
@staticmethod @staticmethod
def normalize_shards(content): def normalize_shards(content):
''' """
For example: For example:
>>> normalize_shards( >>> normalize_shards(
>>> 'COPY public.copy_test_120340 (key, value) FROM STDIN WITH (FORMAT BINARY))' >>> 'COPY public.copy_test_120340 (key, value) FROM STDIN WITH (FORMAT BINARY))'
>>> ) >>> )
'COPY public.copy_test_XXXXXX (key, value) FROM STDIN WITH (FORMAT BINARY))' 'COPY public.copy_test_XXXXXX (key, value) FROM STDIN WITH (FORMAT BINARY))'
''' """
result = content result = content
pattern = re.compile('public\.[a-z_]+(?P<shardid>[0-9]+)') pattern = re.compile("public\.[a-z_]+(?P<shardid>[0-9]+)")
for match in pattern.finditer(content): for match in pattern.finditer(content):
span = match.span('shardid') span = match.span("shardid")
replacement = 'X'*( span[1] - span[0] ) replacement = "X" * (span[1] - span[0])
result = result[: span[0]] + replacement + result[span[1] :] result = result[: span[0]] + replacement + result[span[1] :]
return result return result
@staticmethod @staticmethod
def normalize_timestamps(content): def normalize_timestamps(content):
''' """
For example: For example:
>>> normalize_timestamps('2018-06-07 05:18:19.388992-07') >>> normalize_timestamps('2018-06-07 05:18:19.388992-07')
'XXXX-XX-XX XX:XX:XX.XXXXXX-XX' 'XXXX-XX-XX XX:XX:XX.XXXXXX-XX'
>>> normalize_timestamps('2018-06-11 05:30:43.01382-07') >>> normalize_timestamps('2018-06-11 05:30:43.01382-07')
'XXXX-XX-XX XX:XX:XX.XXXXXX-XX' 'XXXX-XX-XX XX:XX:XX.XXXXXX-XX'
''' """
pattern = re.compile( pattern = re.compile(
'[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}.[0-9]{2,6}-[0-9]{2}' "[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}.[0-9]{2,6}-[0-9]{2}"
) )
return re.sub(pattern, 'XXXX-XX-XX XX:XX:XX.XXXXXX-XX', content) return re.sub(pattern, "XXXX-XX-XX XX:XX:XX.XXXXXX-XX", content)
@staticmethod @staticmethod
def normalize_assign_txn_id(content): def normalize_assign_txn_id(content):
''' """
For example: For example:
>>> normalize_assign_txn_id('SELECT assign_distributed_transaction_id(0, 52, ...') >>> normalize_assign_txn_id('SELECT assign_distributed_transaction_id(0, 52, ...')
'SELECT assign_distributed_transaction_id(0, XX, ...' 'SELECT assign_distributed_transaction_id(0, XX, ...'
''' """
pattern = re.compile( pattern = re.compile(
'assign_distributed_transaction_id\s*\(' # a method call "assign_distributed_transaction_id\s*\(" # a method call
'\s*[0-9]+\s*,' # an integer first parameter "\s*[0-9]+\s*," # an integer first parameter
'\s*(?P<transaction_id>[0-9]+)' # an integer second parameter "\s*(?P<transaction_id>[0-9]+)" # an integer second parameter
) )
result = content result = content
for match in pattern.finditer(content): for match in pattern.finditer(content):
span = match.span('transaction_id') span = match.span("transaction_id")
result = result[:span[0]] + 'XX' + result[span[1]:] result = result[: span[0]] + "XX" + result[span[1] :]
return result return result
class Terminate(FrontendMessage): class Terminate(FrontendMessage):
key = 'X' key = "X"
struct = Struct() struct = Struct()
class CopyData(SharedMessage): class CopyData(SharedMessage):
key = 'd' key = "d"
struct = Struct( struct = Struct(
'data' / GreedyBytes # reads all of the data left in this substream "data" / GreedyBytes # reads all of the data left in this substream
) )
class CopyDone(SharedMessage): class CopyDone(SharedMessage):
key = 'c' key = "c"
struct = Struct() struct = Struct()
class EmptyQueryResponse(BackendMessage): class EmptyQueryResponse(BackendMessage):
key = 'I' key = "I"
struct = Struct() struct = Struct()
class CopyOutResponse(BackendMessage): class CopyOutResponse(BackendMessage):
key = 'H' key = "H"
struct = Struct( struct = Struct(
"format" / Int8ub, "format" / Int8ub,
"columncount" / Int16ub, "columncount" / Int16ub,
"columns" / Array(this.columncount, Struct( "columns" / Array(this.columncount, Struct("format" / Int16ub)),
"format" / Int16ub
))
) )
class ReadyForQuery(BackendMessage): class ReadyForQuery(BackendMessage):
key='Z' key = "Z"
struct = Struct("state"/Enum(Byte,
idle=ord('I'),
in_transaction_block=ord('T'),
in_failed_transaction_block=ord('E')
))
class CommandComplete(BackendMessage):
key = 'C'
struct = Struct( struct = Struct(
"command" / CString("ascii") "state"
/ Enum(
Byte,
idle=ord("I"),
in_transaction_block=ord("T"),
in_failed_transaction_block=ord("E"),
)
) )
class CommandComplete(BackendMessage):
key = "C"
struct = Struct("command" / CString("ascii"))
class RowDescription(BackendMessage): class RowDescription(BackendMessage):
key = 'T' key = "T"
struct = Struct( struct = Struct(
"fieldcount" / Int16ub, "fieldcount" / Int16ub,
"fields" / Array(this.fieldcount, Struct( "fields"
/ Array(
this.fieldcount,
Struct(
"_type" / Computed("F"), "_type" / Computed("F"),
"name" / CString("ascii"), "name" / CString("ascii"),
"tableoid" / Int32ub, "tableoid" / Int32ub,
@ -249,27 +287,35 @@ class RowDescription(BackendMessage):
"typlen" / Int16sb, "typlen" / Int16sb,
"typmod" / Int32sb, "typmod" / Int32sb,
"format_code" / Int16ub, "format_code" / Int16ub,
)) ),
),
) )
class DataRow(BackendMessage): class DataRow(BackendMessage):
key = 'D' key = "D"
struct = Struct( struct = Struct(
"_type" / Computed("data_row"), "_type" / Computed("data_row"),
"columncount" / Int16ub, "columncount" / Int16ub,
"columns" / Array(this.columncount, Struct( "columns"
/ Array(
this.columncount,
Struct(
"_type" / Computed("C"), "_type" / Computed("C"),
"length" / Int16sb, "length" / Int16sb,
"value" / Bytes(this.length) "value" / Bytes(this.length),
)) ),
),
) )
class AuthenticationOk(BackendMessage): class AuthenticationOk(BackendMessage):
key = 'R' key = "R"
struct = Struct() struct = Struct()
class ParameterStatus(BackendMessage): class ParameterStatus(BackendMessage):
key = 'S' key = "S"
struct = Struct( struct = Struct(
"name" / CString("ASCII"), "name" / CString("ASCII"),
"value" / CString("ASCII"), "value" / CString("ASCII"),
@ -281,161 +327,156 @@ class ParameterStatus(BackendMessage):
@staticmethod @staticmethod
def normalize(name, value): def normalize(name, value):
if name in ('TimeZone', 'server_version'): if name in ("TimeZone", "server_version"):
value = 'XXX' value = "XXX"
return (name, value) return (name, value)
class BackendKeyData(BackendMessage): class BackendKeyData(BackendMessage):
key = 'K' key = "K"
struct = Struct( struct = Struct("pid" / Int32ub, "key" / Bytes(4))
"pid" / Int32ub,
"key" / Bytes(4)
)
def print(message): def print(message):
# Both of these should be censored, for reproducible regression test output # Both of these should be censored, for reproducible regression test output
return "BackendKeyData(XXX)" return "BackendKeyData(XXX)"


class NoticeResponse(BackendMessage):
    key = "N"
    struct = Struct(
        "notices"
        / GreedyRange(
            Struct(
                "key"
                / Enum(
                    Byte,
                    severity=ord("S"),
                    _severity_not_localized=ord("V"),
                    _sql_state=ord("C"),
                    message=ord("M"),
                    detail=ord("D"),
                    hint=ord("H"),
                    _position=ord("P"),
                    _internal_position=ord("p"),
                    _internal_query=ord("q"),
                    _where=ord("W"),
                    schema_name=ord("s"),
                    table_name=ord("t"),
                    column_name=ord("c"),
                    data_type_name=ord("d"),
                    constraint_name=ord("n"),
                    _file_name=ord("F"),
                    _line_no=ord("L"),
                    _routine_name=ord("R"),
                ),
                "value" / CString("ASCII"),
            )
        )
    )

    def print(message):
        return "NoticeResponse({})".format(
            ", ".join(
                "{}={}".format(response.key, response.value)
                for response in message.notices
                if not response.key.startswith("_")
            )
        )
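
# print() above skips keys with a leading underscore, so a notice renders
# roughly like this (hypothetical field values):
#
#   NoticeResponse(severity=NOTICE, message=relation "foo" already exists, skipping)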


class Parse(FrontendMessage):
    key = "P"
    struct = Struct(
        "name" / CString("ASCII"),
        "query" / CString("ASCII"),
        "_parametercount" / Int16ub,
        "parameters" / Array(this._parametercount, Int32ub),
    )


class ParseComplete(BackendMessage):
    key = "1"
    struct = Struct()


class Bind(FrontendMessage):
    key = "B"
    struct = Struct(
        "destination_portal" / CString("ASCII"),
        "prepared_statement" / CString("ASCII"),
        "_parameter_format_code_count" / Int16ub,
        "parameter_format_codes" / Array(this._parameter_format_code_count, Int16ub),
        "_parameter_value_count" / Int16ub,
        "parameter_values"
        / Array(
            this._parameter_value_count,
            Struct("length" / Int32ub, "value" / Bytes(this.length)),
        ),
        "result_column_format_count" / Int16ub,
        "result_column_format_codes" / Array(this.result_column_format_count, Int16ub),
    )
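
# "parameter_values" above is the classic length-prefixed blob encoding. A
# minimal sketch of that inner Struct on its own (assumes `construct`):
#
#   from construct import Bytes, Int32ub, Struct, this
#
#   blob = Struct("length" / Int32ub, "value" / Bytes(this.length))
#   blob.parse(b"\x00\x00\x00\x02hi")  # -> Container(length=2, value=b"hi")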


class BindComplete(BackendMessage):
    key = "2"
    struct = Struct()


class NoData(BackendMessage):
    key = "n"
    struct = Struct()


class Describe(FrontendMessage):
    key = "D"
    struct = Struct(
        "type" / Enum(Byte, prepared_statement=ord("S"), portal=ord("P")),
        "name" / CString("ASCII"),
    )

    def print(message):
        return "Describe({}={})".format(message.type, message.name or "<unnamed>")


class Execute(FrontendMessage):
    key = "E"
    struct = Struct("name" / CString("ASCII"), "max_rows_to_return" / Int32ub)

    def print(message):
        return "Execute({}, max_rows_to_return={})".format(
            message.name or "<unnamed>", message.max_rows_to_return
        )


class Sync(FrontendMessage):
    key = "S"
    struct = Struct()


frontend_switch = Switch(
    this.type,
    {**FrontendMessage.name_to_struct(), **SharedMessage.name_to_struct()},
    default=Bytes(this.length - 4),
)
backend_switch = Switch(
    this.type,
    {**BackendMessage.name_to_struct(), **SharedMessage.name_to_struct()},
    default=Bytes(this.length - 4),
)

frontend_msgtypes = Enum(
    Byte, **{**FrontendMessage.name_to_key(), **SharedMessage.name_to_key()}
)

backend_msgtypes = Enum(
    Byte, **{**BackendMessage.name_to_key(), **SharedMessage.name_to_key()}
)
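
# The Enum/Switch pair behaves like a tagged union: the Enum turns the one-byte
# message type into a name, and the Switch picks the body struct for that name,
# falling back to raw bytes. A minimal sketch with made-up message types
# (assumes `construct`):
#
#   from construct import Byte, Bytes, Enum, Int16ub, Struct, Switch, this
#
#   msgtype = Enum(Byte, ping=ord("p"), data=ord("d"))
#   msg = Struct(
#       "type" / msgtype,
#       "body" / Switch(this.type, {"data": Struct("n" / Int16ub)}, default=Bytes(0)),
#   )
#   msg.parse(b"d\x00\x07")  # -> Container(type="data", body=Container(n=7))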


# It might seem a little circuitous to say a frontend message is a kind of frontend
# message but this lets us easily customize how they're printed
class Frontend(FrontendMessage):
    struct = Struct(
        "type" / frontend_msgtypes,
@@ -447,9 +488,7 @@ class Frontend(FrontendMessage):

    def print(message):
        if isinstance(message.body, bytes):
            return "Frontend(type={},body={})".format(chr(message.type), message.body)
        return FrontendMessage.print_message(message.body)

    def typeof(message):
@@ -457,6 +496,7 @@ class Frontend(FrontendMessage):
            return "Unknown"
        return message.body._type


class Backend(BackendMessage):
    struct = Struct(
        "type" / backend_msgtypes,
@@ -468,9 +508,7 @@ class Backend(BackendMessage):

    def print(message):
        if isinstance(message.body, bytes):
            return "Backend(type={},body={})".format(chr(message.type), message.body)
        return BackendMessage.print_message(message.body)

    def typeof(message):
@@ -478,10 +516,12 @@ class Backend(BackendMessage):
            return "Unknown"
        return message.body._type


# GreedyRange keeps reading messages until we hit EOF
frontend_messages = GreedyRange(Frontend.struct)
backend_messages = GreedyRange(Backend.struct)
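
# A capture usually carries many messages back to back; GreedyRange parses the
# same struct repeatedly until the input runs out. Minimal sketch (assumes
# `construct`; the two-byte struct is illustrative):
#
#   from construct import Byte, GreedyRange, Struct
#
#   pair = Struct("a" / Byte, "b" / Byte)
#   GreedyRange(pair).parse(b"\x01\x02\x03\x04")
#   # -> [Container(a=1, b=2), Container(a=3, b=4)]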


def parse(message, from_frontend=True):
    if from_frontend:
        message = frontend_messages.parse(message)
@@ -491,24 +531,27 @@ def parse(message, from_frontend=True):
    return message


def print(message):
    if message.from_frontend:
        return FrontendMessage.print_message(message)
    return BackendMessage.print_message(message)


def message_type(message, from_frontend):
    if from_frontend:
        return FrontendMessage.find_typeof(message)
    return BackendMessage.find_typeof(message)


def message_matches(message, filters, from_frontend):
    """
    Message is something like Backend(Query) and filters is something like
    query="COPY". For now we only support strings, and treat them like a regex,
    which is matched against the content of the wrapped message.
    """
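    # A hypothetical call, shaped the way the docstring suggests (the real call
    # sites and the exact filter shape are outside this hunk):
    #
    #   message_matches(msg, {"query": "COPY"}, from_frontend=True)
    #
    # would check whether msg wraps a Query whose text matches the regex "COPY".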
    if message._type != "Backend" and message._type != "Frontend":
        raise ValueError("can't handle {}".format(message._type))

    wrapped = message.body