import asyncio
import atexit
import concurrent.futures
import os
import pathlib
import platform
import random
import re
import shutil
import socket
import subprocess
import sys
import time
import typing
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager, closing, contextmanager
from datetime import datetime, timedelta
from pathlib import Path
from tempfile import gettempdir

import filelock
import psycopg
import psycopg.sql
import utils
from psycopg import sql
from utils import USER

# This SQL returns true ( 't' ) if the Citus version >= 11.0.
IS_CITUS_VERSION_11_SQL = "SELECT (split_part(extversion, '.', 1)::int >= 11) as is_11 FROM pg_extension WHERE extname = 'citus';"

LINUX = False
MACOS = False
FREEBSD = False
OPENBSD = False

if platform.system() == "Linux":
    LINUX = True
elif platform.system() == "Darwin":
    MACOS = True
elif platform.system() == "FreeBSD":
    FREEBSD = True
elif platform.system() == "OpenBSD":
    OPENBSD = True

BSD = MACOS or FREEBSD or OPENBSD

TIMEOUT_DEFAULT = timedelta(seconds=int(os.getenv("PG_TEST_TIMEOUT_DEFAULT", "10")))
FORCE_PORTS = os.getenv("PG_FORCE_PORTS", "NO").lower() not in ("no", "0", "n", "")

REGRESS_DIR = pathlib.Path(os.path.realpath(__file__)).parent.parent
REPO_ROOT = REGRESS_DIR.parent.parent.parent
CI = os.environ.get("CI") == "true"


def eprint(*args, **kwargs):
    """eprint prints to stderr"""
    print(*args, file=sys.stderr, **kwargs)


def run(command, *args, check=True, shell=True, silent=False, **kwargs):
    """run runs the given command and prints it to stderr"""
    if not silent:
        eprint(f"+ {command} ")
    if silent:
        kwargs.setdefault("stdout", subprocess.DEVNULL)
    return subprocess.run(command, *args, check=check, shell=shell, **kwargs)


def capture(command, *args, **kwargs):
    """runs the given command and returns its output as a string"""
    return run(command, *args, stdout=subprocess.PIPE, text=True, **kwargs).stdout


PG_CONFIG = os.environ.get("PG_CONFIG", "pg_config")
PG_BINDIR = capture([PG_CONFIG, "--bindir"], shell=False).rstrip()
os.environ["PATH"] = PG_BINDIR + os.pathsep + os.environ["PATH"]


def get_pg_major_version():
    full_version_string = run(
        "initdb --version", stdout=subprocess.PIPE, encoding="utf-8", silent=True
    ).stdout
    major_version_string = re.search("[0-9]+", full_version_string)
    assert major_version_string is not None
    return int(major_version_string.group(0))


PG_MAJOR_VERSION = get_pg_major_version()

OLDEST_SUPPORTED_CITUS_VERSION_MATRIX = {
    14: "10.2.0",
    15: "11.1.5",
    16: "12.1.1",
}

OLDEST_SUPPORTED_CITUS_VERSION = OLDEST_SUPPORTED_CITUS_VERSION_MATRIX[PG_MAJOR_VERSION]


def initialize_temp_dir(temp_dir):
    if os.path.exists(temp_dir):
        shutil.rmtree(temp_dir)
    os.mkdir(temp_dir)
    # Give full access to TEMP_DIR so that postgres user can use it.
    os.chmod(temp_dir, 0o777)


def initialize_temp_dir_if_not_exists(temp_dir):
    if os.path.exists(temp_dir):
        return
    os.mkdir(temp_dir)
    # Give full access to TEMP_DIR so that postgres user can use it.
    os.chmod(temp_dir, 0o777)


def parallel_run(function, items, *args, **kwargs):
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [executor.submit(function, item, *args, **kwargs) for item in items]
        for future in futures:
            future.result()


def initialize_db_for_cluster(pg_path, rel_data_path, settings, node_names):
    subprocess.run(["mkdir", rel_data_path], check=True)

    def initialize(node_name):
        abs_data_path = os.path.abspath(os.path.join(rel_data_path, node_name))
        command = [
            os.path.join(pg_path, "initdb"),
            "--pgdata",
            abs_data_path,
            "--username",
            USER,
            "--no-sync",
            # --allow-group-access is used to ensure we set permissions on
            # private keys correctly
            "--allow-group-access",
            "--encoding",
            "UTF8",
            "--locale",
            "POSIX",
        ]
        subprocess.run(command, check=True)
        add_settings(abs_data_path, settings)

    parallel_run(initialize, node_names)


def add_settings(abs_data_path, settings):
    conf_path = os.path.join(abs_data_path, "postgresql.conf")
    with open(conf_path, "a") as conf_file:
        for setting_key, setting_val in settings.items():
            setting = "{setting_key} = '{setting_val}'\n".format(
                setting_key=setting_key, setting_val=setting_val
            )
            conf_file.write(setting)


def create_role(pg_path, node_ports, user_name):
    def create(port):
        command = (
            "SET citus.enable_ddl_propagation TO OFF;"
            + "SELECT worker_create_or_alter_role('{}', 'CREATE ROLE {} WITH LOGIN CREATEROLE CREATEDB;', NULL)".format(
                user_name, user_name
            )
        )
        utils.psql(pg_path, port, command)
        command = "SET citus.enable_ddl_propagation TO OFF; GRANT CREATE ON DATABASE postgres to {}".format(
            user_name
        )
        utils.psql(pg_path, port, command)

    parallel_run(create, node_ports)


def coordinator_should_haveshards(pg_path, port):
    command = "SELECT citus_set_node_property('localhost', {}, 'shouldhaveshards', true)".format(
        port
    )
    utils.psql(pg_path, port, command)


def start_databases(
    pg_path, rel_data_path, node_name_to_ports, logfile_prefix, env_variables
):
    def start(node_name):
        abs_data_path = os.path.abspath(os.path.join(rel_data_path, node_name))
        node_port = node_name_to_ports[node_name]
        command = [
            os.path.join(pg_path, "pg_ctl"),
            "start",
            "--pgdata",
            abs_data_path,
            "-U",
            USER,
            "-o",
            "-p {}".format(node_port),
            "--log",
            os.path.join(abs_data_path, logfile_name(logfile_prefix, node_name)),
        ]

        # set the application name if required
        if env_variables != {}:
            os.environ.update(env_variables)

        subprocess.run(command, check=True)

    parallel_run(start, node_name_to_ports.keys())

    # We don't want parallel shutdown here because that will fail when it's
    # tried in this atexit call with an error like:
    # cannot schedule new futures after interpreter shutdown
    atexit.register(
        stop_databases,
        pg_path,
        rel_data_path,
        node_name_to_ports,
        logfile_prefix,
        no_output=True,
        parallel=False,
    )


def create_citus_extension(pg_path, node_ports):
    def create(port):
        utils.psql(pg_path, port, "CREATE EXTENSION citus;")

    parallel_run(create, node_ports)


def run_pg_regress(pg_path, pg_srcdir, port, schedule):
    should_exit = True
    try:
        _run_pg_regress(pg_path, pg_srcdir, port, schedule, should_exit)
    finally:
        subprocess.run("bin/copy_modified", check=True)


def run_pg_regress_without_exit(
    pg_path,
    pg_srcdir,
    port,
    schedule,
    output_dir=".",
    input_dir=".",
    user="postgres",
    extra_tests="",
):
    should_exit = False
    exit_code = _run_pg_regress(
        pg_path,
        pg_srcdir,
        port,
        schedule,
        should_exit,
        output_dir,
        input_dir,
        user,
        extra_tests,
    )
    copy_binary_path = os.path.join(input_dir, "copy_modified_wrapper")

    exit_code |= subprocess.call(copy_binary_path)
    return exit_code


def _run_pg_regress(
    pg_path,
    pg_srcdir,
    port,
    schedule,
    should_exit,
    output_dir=".",
    input_dir=".",
    user="postgres",
    extra_tests="",
):
    command = [
        os.path.join(pg_srcdir, "src/test/regress/pg_regress"),
        "--port",
        str(port),
        "--schedule",
        schedule,
        "--bindir",
        pg_path,
        "--user",
        user,
        "--dbname",
        "postgres",
        "--inputdir",
        input_dir,
        "--outputdir",
        output_dir,
        "--use-existing",
    ]
    if PG_MAJOR_VERSION >= 16:
        command.append("--expecteddir")
        command.append(output_dir)

    if extra_tests != "":
        command.append(extra_tests)

    exit_code = subprocess.call(command)
    if should_exit and exit_code != 0:
        sys.exit(exit_code)
    return exit_code


def save_regression_diff(name, output_dir):
    path = os.path.join(output_dir, "regression.diffs")
    if not os.path.exists(path):
        return
    new_file_path = os.path.join(output_dir, "./regression_{}.diffs".format(name))
    print("new file path:", new_file_path)
    shutil.move(path, new_file_path)


def stop_metadata_to_workers(pg_path, worker_ports, coordinator_port):
    for port in worker_ports:
        command = (
            "SELECT * from stop_metadata_sync_to_node('localhost', {port});".format(
                port=port
            )
        )
        utils.psql(pg_path, coordinator_port, command)


def add_coordinator_to_metadata(pg_path, coordinator_port):
    command = "SELECT citus_set_coordinator_host('localhost');"
    utils.psql(pg_path, coordinator_port, command)


def add_workers(pg_path, worker_ports, coordinator_port):
    for port in worker_ports:
        command = "SELECT * from master_add_node('localhost', {port});".format(
            port=port
        )
        utils.psql(pg_path, coordinator_port, command)


def logfile_name(logfile_prefix, node_name):
    return "logfile_" + logfile_prefix + "_" + node_name


def stop_databases(
    pg_path,
    rel_data_path,
    node_name_to_ports,
    logfile_prefix,
    no_output=False,
    parallel=True,
):
    def stop(node_name):
        abs_data_path = os.path.abspath(os.path.join(rel_data_path, node_name))
        node_port = node_name_to_ports[node_name]
        command = [
            os.path.join(pg_path, "pg_ctl"),
            "stop",
            "--pgdata",
            abs_data_path,
            "-U",
            USER,
            "-o",
            "-p {}".format(node_port),
            "--log",
            os.path.join(abs_data_path, logfile_name(logfile_prefix, node_name)),
        ]
        if no_output:
            subprocess.call(
                command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
            )
        else:
            subprocess.call(command)

    if parallel:
        parallel_run(stop, node_name_to_ports.keys())
    else:
        for node_name in node_name_to_ports.keys():
            stop(node_name)


def is_citus_set_coordinator_host_udf_exist(pg_path, port):
    return utils.psql_capture(pg_path, port, IS_CITUS_VERSION_11_SQL) == b" t\n\n"


def initialize_citus_cluster(bindir, datadir, settings, config):
    # In case there was a leftover from previous runs, stop the databases
    stop_databases(
        bindir, datadir, config.node_name_to_ports, config.name, no_output=True
    )
    initialize_db_for_cluster(
        bindir, datadir, settings, config.node_name_to_ports.keys()
    )
    start_databases(
        bindir, datadir, config.node_name_to_ports, config.name, config.env_variables
    )
    create_citus_extension(bindir, config.node_name_to_ports.values())

    # In upgrade tests, it is possible that Citus version < 11.0
    # where the citus_set_coordinator_host UDF does not exist.
    if is_citus_set_coordinator_host_udf_exist(bindir, config.coordinator_port()):
        add_coordinator_to_metadata(bindir, config.coordinator_port())

    add_workers(bindir, config.worker_ports, config.coordinator_port())
    if not config.is_mx:
        stop_metadata_to_workers(bindir, config.worker_ports, config.coordinator_port())

    config.setup_steps()


def sudo(command, *args, shell=True, **kwargs):
    """
    A version of run that prefixes the command with sudo when the process
    is not already run as root
    """
    effective_user_id = os.geteuid()
    if effective_user_id == 0:
        return run(command, *args, shell=shell, **kwargs)
    if shell:
        return run(f"sudo {command}", *args, shell=shell, **kwargs)
    else:
        # command is a list here, so prefix it with sudo and forward the
        # remaining arguments unchanged
        return run(["sudo", *command], *args, shell=shell, **kwargs)


# This is out of the ephemeral port range for many systems, hence there is a
# lower chance that it will conflict with ports that are already in use.
PORT_LOWER_BOUND = 10200

# ephemeral port start on many Linux systems
PORT_UPPER_BOUND = 32768

next_port = PORT_LOWER_BOUND


def notice_handler(diag: psycopg.errors.Diagnostic):
    print(f"{diag.severity}: {diag.message_primary}")
    if diag.message_detail:
        print(f"DETAIL: {diag.message_detail}")
    if diag.message_hint:
        print(f"HINT: {diag.message_hint}")
    if diag.context:
        print(f"CONTEXT: {diag.context}")


def cleanup_test_leftovers(nodes):
    """
    Cleaning up test leftovers needs to be done in a specific order, because
    some of these leftovers depend on others having been removed. They might
    even depend on leftovers on other nodes being removed. So this takes a
    list of nodes, so that we can clean up all test leftovers globally in the
    correct order.
    """
    for node in nodes:
        node.cleanup_subscriptions()

    for node in nodes:
        node.cleanup_publications()

    for node in nodes:
        node.cleanup_replication_slots()

    for node in nodes:
        node.cleanup_schemas()

    for node in nodes:
        node.cleanup_databases()

    for node in nodes:
        node.cleanup_users()


class PortLock:
    """PortLock allows you to take a lock on a specific port.

    While a port is locked by one process, other processes using PortLock
    won't get the same port.
    """

    def __init__(self):
        global next_port
        first_port = next_port
        while True:
            next_port += 1
            if next_port >= PORT_UPPER_BOUND:
                next_port = PORT_LOWER_BOUND

            # avoid infinite loop
            if first_port == next_port:
                raise Exception("Could not find port")

            self.lock = filelock.FileLock(Path(gettempdir()) / f"port-{next_port}.lock")
            try:
                self.lock.acquire(timeout=0)
            except filelock.Timeout:
                continue

            if FORCE_PORTS:
                self.port = next_port
                break

            with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
                try:
                    s.bind(("127.0.0.1", next_port))
                    self.port = next_port
                    break
                except Exception:
                    self.lock.release()
                    continue

    def release(self):
        """Call release when you are done with the port.

        This way other processes can use it again.
        """
        self.lock.release()


class QueryRunner(ABC):
    """A subclassable interface class that can be used to run queries.

    This is mostly useful to be generic across different types of things that
    implement the Postgres interface, such as Postgres, PgBouncer, or a Citus
    cluster.

    This implements some helpers to send queries in a simpler manner than
    psycopg allows by default.
    """

    @abstractmethod
    def set_default_connection_options(self, options: dict[str, typing.Any]):
        """Sets the default connection options on the given options dictionary

        This is the only method that the class that subclasses QueryRunner
        needs to implement.
        """
        ...
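
    # Illustrative sketch (not part of the original module): a subclass only
    # needs to fill in connection defaults. A hypothetical PgBouncer wrapper
    # could implement the abstract method roughly like this:
    #
    #     def set_default_connection_options(self, options):
    #         options.setdefault("host", self.host)
    #         options.setdefault("port", self.port)
    #         options.setdefault("dbname", "pgbouncer")
    #         options.setdefault("user", "postgres")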
    def make_conninfo(self, **kwargs) -> str:
        self.set_default_connection_options(kwargs)
        return psycopg.conninfo.make_conninfo(**kwargs)

    def conn(self, *, autocommit=True, **kwargs):
        """Open a psycopg connection to this server"""
        self.set_default_connection_options(kwargs)
        conn = psycopg.connect(
            autocommit=autocommit,
            **kwargs,
        )
        conn.add_notice_handler(notice_handler)
        return conn

    def aconn(self, *, autocommit=True, **kwargs):
        """Open an asynchronous psycopg connection to this server"""
        self.set_default_connection_options(kwargs)
        return psycopg.AsyncConnection.connect(
            autocommit=autocommit,
            **kwargs,
        )

    @contextmanager
    def cur(self, autocommit=True, **kwargs):
        """Open a psycopg cursor to this server

        The connection and the cursors automatically close once you leave the
        "with" block
        """
        with self.conn(
            autocommit=autocommit,
            **kwargs,
        ) as conn:
            with conn.cursor() as cur:
                yield cur

    @asynccontextmanager
    async def acur(self, **kwargs):
        """Open an asynchronous psycopg cursor to this server

        The connection and the cursors automatically close once you leave the
        "async with" block
        """
        async with await self.aconn(**kwargs) as conn:
            async with conn.cursor() as cur:
                yield cur

    def sql(self, query, params=None, **kwargs):
        """Run an SQL query

        This opens a new connection and closes it once the query is done
        """
        with self.cur(**kwargs) as cur:
            cur.execute(query, params=params)

    def sql_prepared(self, query, params=None, **kwargs):
        """Run an SQL query, with prepare=True

        This opens a new connection and closes it once the query is done
        """
        with self.cur(**kwargs) as cur:
            cur.execute(query, params=params, prepare=True)

    def sql_row(self, query, params=None, allow_empty_result=False, **kwargs):
        """Run an SQL query that returns a single row and returns this row

        This opens a new connection and closes it once the query is done
        """
        with self.cur(**kwargs) as cur:
            cur.execute(query, params=params)
            result = cur.fetchall()
            if allow_empty_result and len(result) == 0:
                return None
            assert len(result) == 1, "sql_row returns more than one row"
            return result[0]

    def sql_value(self, query, params=None, allow_empty_result=False, **kwargs):
        """Run an SQL query that returns a single cell and return this value

        This opens a new connection and closes it once the query is done
        """
        with self.cur(**kwargs) as cur:
            cur.execute(query, params=params)
            result = cur.fetchall()
            if allow_empty_result and len(result) == 0:
                return None
            assert len(result) == 1
            assert len(result[0]) == 1
            value = result[0][0]
            return value

    def asql(self, query, **kwargs):
        """Run an SQL query in an asynchronous task

        This opens a new connection and closes it once the query is done
        """
        return asyncio.ensure_future(self.asql_coroutine(query, **kwargs))

    async def asql_coroutine(
        self, query, params=None, **kwargs
    ) -> typing.Optional[typing.List[typing.Any]]:
        async with self.acur(**kwargs) as cur:
            await cur.execute(query, params=params)
            try:
                return await cur.fetchall()
            except psycopg.ProgrammingError as e:
                if "the last operation didn't produce a result" == str(e):
                    return None
                raise

    def psql(self, query, **kwargs):
        """Run an SQL query using psql instead of psycopg

        This opens a new connection and closes it once the query is done
        """
        conninfo = self.make_conninfo(**kwargs)

        run(
            ["psql", "-X", f"{conninfo}", "-c", query],
            shell=False,
            silent=True,
        )

    def poll_query_until(self, query, params=None, expected=True, **kwargs):
        """Run query repeatedly until it returns the expected result"""
        start = datetime.now()
        result = None

        while start + TIMEOUT_DEFAULT > datetime.now():
            result = self.sql_value(
                query, params=params, allow_empty_result=True, **kwargs
            )
            if result == expected:
                return

            time.sleep(0.1)

        raise Exception(
            f"Timeout reached while polling query, last result was: {result}"
        )

    @contextmanager
    def transaction(self, **kwargs):
        with self.cur(**kwargs) as cur:
            with cur.connection.transaction():
                yield cur

    def sleep(self, duration=3, **kwargs):
        """Run pg_sleep"""
        return self.sql(f"select pg_sleep({duration})", **kwargs)

    def asleep(self, duration=3, times=1, sequentially=False, **kwargs):
        """Run pg_sleep asynchronously in a task.

        times:
            You can create a single task that opens multiple connections,
            which run pg_sleep concurrently. The asynchronous task will only
            complete once all these pg_sleep calls are finished.
        sequentially:
            Instead of running all pg_sleep calls spawned by providing
            times > 1 concurrently, this will run them sequentially.
        """
        return asyncio.ensure_future(
            self.asleep_coroutine(
                duration=duration, times=times, sequentially=sequentially, **kwargs
            )
        )

    async def asleep_coroutine(self, duration=3, times=1, sequentially=False, **kwargs):
        """This is the coroutine that the asleep task runs internally"""
        if not sequentially:
            await asyncio.gather(
                *[
                    self.asql(f"select pg_sleep({duration})", **kwargs)
                    for _ in range(times)
                ]
            )
        else:
            for _ in range(times):
                await self.asql(f"select pg_sleep({duration})", **kwargs)

    def test(self, **kwargs):
        """Test if you can connect"""
        return self.sql("select 1", **kwargs)

    def atest(self, **kwargs):
        """Test if you can connect asynchronously"""
        return self.asql("select 1", **kwargs)

    def psql_test(self, **kwargs):
        """Test if you can connect with psql instead of psycopg"""
        return self.psql("select 1", **kwargs)

    def debug(self):
        print("Connect manually to:\n ", repr(self.make_conninfo()))
        print("Press Enter to continue running the test...")
        input()

    def psql_debug(self, **kwargs):
        conninfo = self.make_conninfo(**kwargs)
        run(
            ["psql", f"{conninfo}"],
            shell=False,
            silent=True,
        )


class Postgres(QueryRunner):
    """A class that represents a Postgres instance on this machine

    You can query it by using the interface provided by QueryRunner or use
    many of the helper methods.
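
    A minimal usage sketch (illustrative only; the data directory path below
    is an example and not part of the original module):

        pg = Postgres(Path(gettempdir()) / "pg-example")
        pg.init_with_citus()
        assert pg.sql_value("SELECT 1") == 1
        pg.cleanup()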
""" def __init__(self, pgdata): self.port_lock = PortLock() # These values should almost never be changed after initialization self.host = "127.0.0.1" self.port = self.port_lock.port # These values can be changed when needed self.dbname = "postgres" self.user = "postgres" self.schema = None self.pgdata = pgdata self.log_path = self.pgdata / "pg.log" # Used to track objects that we want to clean up at the end of a test self.subscriptions = set() self.publications = set() self.replication_slots = set() self.databases = set() self.schemas = set() self.users = set() def set_default_connection_options(self, options): options.setdefault("host", self.host) options.setdefault("port", self.port) options.setdefault("dbname", self.dbname) options.setdefault("user", self.user) if self.schema is not None: options.setdefault("options", f"-c search_path={self.schema}") options.setdefault("connect_timeout", 3) # needed for Ubuntu 18.04 options.setdefault("client_encoding", "UTF8") def initdb(self): run( f"initdb -A trust --nosync --username postgres --pgdata {self.pgdata} --allow-group-access --encoding UTF8 --locale POSIX", stdout=subprocess.DEVNULL, ) with self.conf_path.open(mode="a") as pgconf: # Allow connecting over unix sockets pgconf.write("unix_socket_directories = '/tmp'\n") # Useful logs for debugging issues pgconf.write("log_replication_commands = on\n") # The following to are also useful for debugging, but quite noisy. # So better to enable them manually by uncommenting. # pgconf.write("log_connections = on\n") # pgconf.write("log_disconnections = on\n") # Enable citus pgconf.write("shared_preload_libraries = 'citus'\n") # Allow CREATE SUBSCRIPTION to work pgconf.write("wal_level = 'logical'\n") # Faster logical replication status update so tests with logical replication # run faster pgconf.write("wal_receiver_status_interval = 1\n") # Faster logical replication apply worker launch so tests with logical # replication run faster. This is used in ApplyLauncherMain in # src/backend/replication/logical/launcher.c. pgconf.write("wal_retrieve_retry_interval = '250ms'\n") # Make sure there's enough logical replication resources for most # of our tests pgconf.write("max_logical_replication_workers = 50\n") pgconf.write("max_wal_senders = 50\n") pgconf.write("max_worker_processes = 50\n") pgconf.write("max_replication_slots = 50\n") # We need to make the log go to stderr so that the tests can # check what is being logged. This should be the default, but # some packagings change the default configuration. pgconf.write("log_destination = stderr\n") # We don't need the logs anywhere else than stderr pgconf.write("logging_collector = off\n") # This makes tests run faster and we don't care about crash safety # of our test data. pgconf.write("fsync = false\n") # conservative settings to ensure we can run multiple postmasters: pgconf.write("shared_buffers = 1MB\n") # limit disk space consumption, too: pgconf.write("max_wal_size = 128MB\n") # don't restart after crashes to make it obvious that a crash # happened pgconf.write("restart_after_crash = off\n") os.truncate(self.hba_path, 0) self.ssl_access("all", "trust") self.nossl_access("all", "trust") self.commit_hba() def init_with_citus(self): self.initdb() self.start() self.sql("CREATE EXTENSION citus") # Manually turn on ssl, so that we can safely truncate # postgresql.auto.conf later. We can only do this after creating the # citus extension because that creates the self signed certificates. 
with self.conf_path.open(mode="a") as pgconf: pgconf.write("ssl = on\n") def pgctl(self, command, **kwargs): run(f"pg_ctl -w --pgdata {self.pgdata} {command}", **kwargs) def apgctl(self, command, **kwargs): return asyncio.create_subprocess_shell( f"pg_ctl -w --pgdata {self.pgdata} {command}", **kwargs ) def start(self): try: self.pgctl(f'-o "-p {self.port}" -l {self.log_path} start') except Exception: print(f"\n\nPG_LOG: {self.pgdata}\n") with self.log_path.open() as f: print(f.read()) raise def stop(self, mode="fast"): self.pgctl(f"-m {mode} stop", check=False) def cleanup(self): self.stop() self.port_lock.release() def restart(self): self.stop() self.start() def reload(self): self.pgctl("reload") # Sadly UNIX signals are asynchronous, so we sleep a bit and hope that # Postgres actually processed the SIGHUP signal after the sleep. time.sleep(0.1) async def arestart(self): process = await self.apgctl("-m fast restart") await process.communicate() def nossl_access(self, dbname, auth_type): """Prepends a local non-SSL access to the HBA file""" with self.hba_path.open() as pghba: old_contents = pghba.read() with self.hba_path.open(mode="w") as pghba: pghba.write(f"local {dbname} all {auth_type}\n") pghba.write(f"hostnossl {dbname} all 127.0.0.1/32 {auth_type}\n") pghba.write(f"hostnossl {dbname} all ::1/128 {auth_type}\n") pghba.write(old_contents) def ssl_access(self, dbname, auth_type): """Prepends a local SSL access rule to the HBA file""" with self.hba_path.open() as pghba: old_contents = pghba.read() with self.hba_path.open(mode="w") as pghba: pghba.write(f"hostssl {dbname} all 127.0.0.1/32 {auth_type}\n") pghba.write(f"hostssl {dbname} all ::1/128 {auth_type}\n") pghba.write(old_contents) @property def hba_path(self): return self.pgdata / "pg_hba.conf" @property def conf_path(self): return self.pgdata / "postgresql.conf" def commit_hba(self): """Mark the current HBA contents as non-resetable by reset_hba""" with self.hba_path.open() as pghba: old_contents = pghba.read() with self.hba_path.open(mode="w") as pghba: pghba.write("# committed-rules\n") pghba.write(old_contents) def reset_hba(self): """Remove any HBA rules that were added after the last call to commit_hba""" with self.hba_path.open() as f: hba_contents = f.read() committed = hba_contents[hba_contents.find("# committed-rules\n") :] with self.hba_path.open("w") as f: f.write(committed) def prepare_reset(self): """Prepares all changes to reset Postgres settings and objects To actually apply the prepared changes a restart might still be needed. """ self.reset_hba() os.truncate(self.pgdata / "postgresql.auto.conf", 0) def reset(self): """Resets any changes to Postgres settings from previous tests""" self.prepare_reset() self.restart() async def delayed_start(self, delay=1): """Start Postgres after a delay NOTE: The sleep is asynchronous, but while waiting for Postgres to start the pg_ctl start command will block the event loop. This is currently acceptable for our usage of this method in the existing tests and this way it was easiest to implement. However, it seems totally reasonable to change this behaviour in the future if necessary. """ await asyncio.sleep(delay) self.start() def configure(self, *configs): """Configure specific Postgres settings using ALTER SYSTEM SET NOTE: after configuring a call to reload or restart is needed for the settings to become effective. 
""" for config in configs: self.sql(f"alter system set {config}") def log_handle(self): """Returns the opened logfile at the current end of the log By later calling read on this file you can read the contents that were written from this moment on. IMPORTANT: This handle should be closed once it's not needed anymore """ f = self.log_path.open() f.seek(0, os.SEEK_END) return f @contextmanager def log_contains(self, re_string, times=None): """Checks if during this with block the log matches re_string re_string: The regex to search for. times: If None, any number of matches is accepted. If a number, only that specific number of matches is accepted. """ with self.log_handle() as f: yield content = f.read() if times is None: assert re.search(re_string, content) else: match_count = len(re.findall(re_string, content)) assert match_count == times def create_user(self, name, args: typing.Optional[psycopg.sql.Composable] = None): self.users.add(name) if args is None: args = sql.SQL("") self.sql(sql.SQL("CREATE USER {} {}").format(sql.Identifier(name), args)) def create_database(self, name): self.databases.add(name) self.sql(sql.SQL("CREATE DATABASE {}").format(sql.Identifier(name))) def create_schema(self, name): self.schemas.add(name) self.sql(sql.SQL("CREATE SCHEMA {}").format(sql.Identifier(name))) def create_publication(self, name: str, args: psycopg.sql.Composable): self.publications.add(name) self.sql(sql.SQL("CREATE PUBLICATION {} {}").format(sql.Identifier(name), args)) def create_logical_replication_slot( self, name, plugin, temporary=False, twophase=False ): self.replication_slots.add(name) self.sql( "SELECT pg_catalog.pg_create_logical_replication_slot(%s,%s,%s,%s)", (name, plugin, temporary, twophase), ) def create_subscription(self, name: str, args: psycopg.sql.Composable): self.subscriptions.add(name) self.sql( sql.SQL("CREATE SUBSCRIPTION {} {}").format(sql.Identifier(name), args) ) def cleanup_users(self): for user in self.users: self.sql(sql.SQL("DROP USER IF EXISTS {}").format(sql.Identifier(user))) def cleanup_databases(self): for database in self.databases: self.sql( sql.SQL("DROP DATABASE IF EXISTS {}").format(sql.Identifier(database)) ) def cleanup_schemas(self): for schema in self.schemas: self.sql( sql.SQL("DROP SCHEMA IF EXISTS {} CASCADE").format( sql.Identifier(schema) ) ) def cleanup_publications(self): for publication in self.publications: self.sql( sql.SQL("DROP PUBLICATION IF EXISTS {}").format( sql.Identifier(publication) ) ) def cleanup_replication_slots(self): for slot in self.replication_slots: start = time.time() while True: try: self.sql( "SELECT pg_drop_replication_slot(slot_name) FROM pg_replication_slots WHERE slot_name = %s", (slot,), ) except psycopg.errors.ObjectInUse: if time.time() < start + 10: time.sleep(0.5) continue raise break def cleanup_subscriptions(self): for subscription in self.subscriptions: try: self.sql( sql.SQL("ALTER SUBSCRIPTION {} DISABLE").format( sql.Identifier(subscription) ) ) except psycopg.errors.UndefinedObject: # Subscription didn't exist already continue self.sql( sql.SQL("ALTER SUBSCRIPTION {} SET (slot_name = NONE)").format( sql.Identifier(subscription) ) ) self.sql( sql.SQL("DROP SUBSCRIPTION {}").format(sql.Identifier(subscription)) ) def lsn(self, mode): """Returns the lsn for the given mode""" queries = { "insert": "SELECT pg_current_wal_insert_lsn()", "flush": "SELECT pg_current_wal_flush_lsn()", "write": "SELECT pg_current_wal_lsn()", "receive": "SELECT pg_last_wal_receive_lsn()", "replay": "SELECT 
pg_last_wal_replay_lsn()", } return self.sql_value(queries[mode]) def wait_for_catchup(self, subscription_name, mode="replay", target_lsn=None): """Waits until the subscription has caught up""" if target_lsn is None: target_lsn = self.lsn("write") # Before release 12 walreceiver just set the application name to # "walreceiver" self.poll_query_until( sql.SQL( """ SELECT {} <= {} AND state = 'streaming' FROM pg_catalog.pg_stat_replication WHERE application_name IN ({}, 'walreceiver') """ ).format(target_lsn, sql.Identifier(f"{mode}_lsn"), subscription_name) ) @contextmanager def _enable_firewall(self): """Enables the firewall for the platform that you are running Normally this should not be called directly, and instead drop_traffic or reject_traffic should be used. """ fw_token = None if BSD: if MACOS: command_stderr = sudo( "pfctl -E", stderr=subprocess.PIPE, text=True ).stderr match = re.search(r"^Token : (\d+)", command_stderr, flags=re.MULTILINE) assert match is not None fw_token = match.group(1) sudo( 'bash -c "' f"echo 'anchor \\\"port_{self.port}\\\"'" f' | pfctl -a citus_test -f -"' ) try: yield finally: if MACOS: sudo(f"pfctl -X {fw_token}") @contextmanager def drop_traffic(self): """Drops all TCP packets to this query runner""" with self._enable_firewall(): if LINUX: sudo( "iptables --append OUTPUT " "--protocol tcp " f"--destination {self.host} " f"--destination-port {self.port} " "--jump DROP " ) elif BSD: sudo( "bash -c '" f'echo "block drop out proto tcp from any to {self.host} port {self.port}"' f"| pfctl -a citus_test/port_{self.port} -f -'" ) else: raise Exception("This OS cannot run this test") try: yield finally: if LINUX: sudo( "iptables --delete OUTPUT " "--protocol tcp " f"--destination {self.host} " f"--destination-port {self.port} " "--jump DROP " ) elif BSD: sudo(f"pfctl -a citus_test/port_{self.port} -F all") @contextmanager def reject_traffic(self): """Rejects all traffic to this query runner with a TCP RST message""" with self._enable_firewall(): if LINUX: sudo( "iptables --append OUTPUT " "--protocol tcp " f"--destination {self.host} " f"--destination-port {self.port} " "--jump REJECT " "--reject-with tcp-reset" ) elif BSD: sudo( "bash -c '" f'echo "block return-rst out out proto tcp from any to {self.host} port {self.port}"' f"| pfctl -a citus_test/port_{self.port} -f -'" ) else: raise Exception("This OS cannot run this test") try: yield finally: if LINUX: sudo( "iptables --delete OUTPUT " "--protocol tcp " f"--destination {self.host} " f"--destination-port {self.port} " "--jump REJECT " "--reject-with tcp-reset" ) elif BSD: sudo(f"pfctl -a citus_test/port_{self.port} -F all") class CitusCluster(QueryRunner): """A class that represents a Citus cluster on this machine The nodes in the cluster can be accessed directly using the coordinator, workers, and nodes properties. If it doesn't matter which of the nodes in the cluster is used to run a query, then you can use the methods provided by QueryRunner directly on the cluster. In that case a random node will be chosen to run your query. 
""" def __init__(self, basedir: Path, worker_count: int): self.coordinator = Postgres(basedir / "coordinator") self.workers = [Postgres(basedir / f"worker{i}") for i in range(worker_count)] self.nodes = [self.coordinator] + self.workers self._schema = None self.failed_reset = False parallel_run(Postgres.init_with_citus, self.nodes) with self.coordinator.cur() as cur: cur.execute( "SELECT pg_catalog.citus_set_coordinator_host(%s, %s)", (self.coordinator.host, self.coordinator.port), ) for worker in self.workers: cur.execute( "SELECT pg_catalog.citus_add_node(%s, %s)", (worker.host, worker.port), ) def set_default_connection_options(self, options): random.choice(self.nodes).set_default_connection_options(options) @property def schema(self): return self._schema @schema.setter def schema(self, value): self._schema = value for node in self.nodes: node.schema = value def reset(self): """Resets any changes to Postgres settings from previous tests""" parallel_run(Postgres.prepare_reset, self.nodes) parallel_run(Postgres.restart, self.nodes) def cleanup(self): parallel_run(Postgres.cleanup, self.nodes) def debug(self): """Print information to stdout to help with debugging your cluster""" print("The nodes in this cluster and their connection strings are:") for node in self.nodes: print(f"{node.pgdata}:\n ", repr(node.make_conninfo())) print("Press Enter to continue running the test...") input()