# citus/src/test/regress/citus_tests/common.py

import asyncio
import atexit
import concurrent.futures
import os
import pathlib
import platform
import random
import re
import shutil
import socket
import subprocess
import sys
import time
import typing
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager, closing, contextmanager
from datetime import datetime, timedelta
from pathlib import Path
from tempfile import gettempdir
import filelock
import psycopg
import psycopg.sql
import utils
from psycopg import sql
from utils import USER
# This SQL returns true ('t') if the installed Citus version is >= 11.0.
IS_CITUS_VERSION_11_SQL = "SELECT (split_part(extversion, '.', 1)::int >= 11) as is_11 FROM pg_extension WHERE extname = 'citus';"
LINUX = False
MACOS = False
FREEBSD = False
OPENBSD = False
if platform.system() == "Linux":
LINUX = True
elif platform.system() == "Darwin":
MACOS = True
elif platform.system() == "FreeBSD":
FREEBSD = True
elif platform.system() == "OpenBSD":
OPENBSD = True
BSD = MACOS or FREEBSD or OPENBSD
TIMEOUT_DEFAULT = timedelta(seconds=int(os.getenv("PG_TEST_TIMEOUT_DEFAULT", "10")))
FORCE_PORTS = os.getenv("PG_FORCE_PORTS", "NO").lower() not in ("no", "0", "n", "")
REGRESS_DIR = pathlib.Path(os.path.realpath(__file__)).parent.parent
REPO_ROOT = REGRESS_DIR.parent.parent.parent
CI = os.environ.get("CI") == "true"
def eprint(*args, **kwargs):
"""eprint prints to stderr"""
print(*args, file=sys.stderr, **kwargs)
def run(command, *args, check=True, shell=True, silent=False, **kwargs):
"""run runs the given command and prints it to stderr"""
if not silent:
eprint(f"+ {command} ")
if silent:
kwargs.setdefault("stdout", subprocess.DEVNULL)
return subprocess.run(command, *args, check=check, shell=shell, **kwargs)
def capture(command, *args, **kwargs):
"""runs the given command and returns its output as a string"""
return run(command, *args, stdout=subprocess.PIPE, text=True, **kwargs).stdout
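# Illustrative usage of run/capture (a sketch, not part of the test helpers;
# "echo" is just a stand-in command):
#
#     run("echo hello")                              # prints "+ echo hello" to stderr
#     out = capture(["echo", "hello"], shell=False)  # out == "hello\n"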
PG_CONFIG = os.environ.get("PG_CONFIG", "pg_config")
PG_BINDIR = capture([PG_CONFIG, "--bindir"], shell=False).rstrip()
os.environ["PATH"] = PG_BINDIR + os.pathsep + os.environ["PATH"]
def get_pg_major_version():
full_version_string = run(
"initdb --version", stdout=subprocess.PIPE, encoding="utf-8", silent=True
).stdout
major_version_string = re.search("[0-9]+", full_version_string)
assert major_version_string is not None
return int(major_version_string.group(0))
PG_MAJOR_VERSION = get_pg_major_version()
OLDEST_SUPPORTED_CITUS_VERSION_MATRIX = {
14: "10.2.0",
15: "11.1.5",
16: "12.1devel",
}
OLDEST_SUPPORTED_CITUS_VERSION = OLDEST_SUPPORTED_CITUS_VERSION_MATRIX[PG_MAJOR_VERSION]
def initialize_temp_dir(temp_dir):
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir)
os.mkdir(temp_dir)
    # Give full access to temp_dir so that the postgres user can use it.
os.chmod(temp_dir, 0o777)
def initialize_temp_dir_if_not_exists(temp_dir):
if os.path.exists(temp_dir):
return
os.mkdir(temp_dir)
    # Give full access to temp_dir so that the postgres user can use it.
os.chmod(temp_dir, 0o777)
def parallel_run(function, items, *args, **kwargs):
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [executor.submit(function, item, *args, **kwargs) for item in items]
for future in futures:
future.result()
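# A minimal sketch of how parallel_run is used throughout this file ("ping"
# is a hypothetical function, the ports are made up):
#
#     def ping(port):
#         utils.psql(pg_path, port, "SELECT 1;")
#
#     parallel_run(ping, [10201, 10202])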
def initialize_db_for_cluster(pg_path, rel_data_path, settings, node_names):
subprocess.run(["mkdir", rel_data_path], check=True)
def initialize(node_name):
abs_data_path = os.path.abspath(os.path.join(rel_data_path, node_name))
command = [
os.path.join(pg_path, "initdb"),
"--pgdata",
abs_data_path,
"--username",
USER,
"--no-sync",
# --allow-group-access is used to ensure we set permissions on
# private keys correctly
"--allow-group-access",
"--encoding",
"UTF8",
"--locale",
"POSIX",
]
subprocess.run(command, check=True)
add_settings(abs_data_path, settings)
parallel_run(initialize, node_names)
def add_settings(abs_data_path, settings):
conf_path = os.path.join(abs_data_path, "postgresql.conf")
with open(conf_path, "a") as conf_file:
for setting_key, setting_val in settings.items():
setting = "{setting_key} = '{setting_val}'\n".format(
setting_key=setting_key, setting_val=setting_val
)
conf_file.write(setting)
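# For example, add_settings(abs_data_path, {"shared_preload_libraries": "citus"})
# appends this line to postgresql.conf:
#
#     shared_preload_libraries = 'citus'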
def create_role(pg_path, node_ports, user_name):
def create(port):
command = (
"SET citus.enable_ddl_propagation TO OFF;"
+ "SELECT worker_create_or_alter_role('{}', 'CREATE ROLE {} WITH LOGIN CREATEROLE CREATEDB;', NULL)".format(
user_name, user_name
)
)
utils.psql(pg_path, port, command)
command = "SET citus.enable_ddl_propagation TO OFF; GRANT CREATE ON DATABASE postgres to {}".format(
user_name
)
utils.psql(pg_path, port, command)
parallel_run(create, node_ports)
def coordinator_should_haveshards(pg_path, port):
command = "SELECT citus_set_node_property('localhost', {}, 'shouldhaveshards', true)".format(
port
)
utils.psql(pg_path, port, command)
def start_databases(
pg_path, rel_data_path, node_name_to_ports, logfile_prefix, env_variables
):
def start(node_name):
abs_data_path = os.path.abspath(os.path.join(rel_data_path, node_name))
node_port = node_name_to_ports[node_name]
command = [
os.path.join(pg_path, "pg_ctl"),
"start",
"--pgdata",
abs_data_path,
"-U",
USER,
"-o",
"-p {}".format(node_port),
"--log",
os.path.join(abs_data_path, logfile_name(logfile_prefix, node_name)),
]
        # Set any extra environment variables (e.g. the application name) if given
if env_variables != {}:
os.environ.update(env_variables)
subprocess.run(command, check=True)
parallel_run(start, node_name_to_ports.keys())
# We don't want parallel shutdown here because that will fail when it's
# tried in this atexit call with an error like:
# cannot schedule new futures after interpreter shutdown
atexit.register(
stop_databases,
pg_path,
rel_data_path,
node_name_to_ports,
logfile_prefix,
no_output=True,
parallel=False,
)
def create_citus_extension(pg_path, node_ports):
def create(port):
utils.psql(pg_path, port, "CREATE EXTENSION citus;")
parallel_run(create, node_ports)
def run_pg_regress(pg_path, pg_srcdir, port, schedule):
should_exit = True
try:
_run_pg_regress(pg_path, pg_srcdir, port, schedule, should_exit)
finally:
subprocess.run("bin/copy_modified", check=True)
def run_pg_regress_without_exit(
pg_path,
pg_srcdir,
port,
schedule,
output_dir=".",
input_dir=".",
user="postgres",
extra_tests="",
):
should_exit = False
exit_code = _run_pg_regress(
pg_path,
pg_srcdir,
port,
schedule,
should_exit,
output_dir,
input_dir,
user,
extra_tests,
)
copy_binary_path = os.path.join(input_dir, "copy_modified_wrapper")
exit_code |= subprocess.call(copy_binary_path)
return exit_code
def _run_pg_regress(
pg_path,
pg_srcdir,
port,
schedule,
should_exit,
output_dir=".",
input_dir=".",
user="postgres",
extra_tests="",
):
command = [
os.path.join(pg_srcdir, "src/test/regress/pg_regress"),
"--port",
str(port),
"--schedule",
schedule,
"--bindir",
pg_path,
"--user",
user,
"--dbname",
"postgres",
"--inputdir",
input_dir,
"--outputdir",
output_dir,
"--use-existing",
]
if extra_tests != "":
command.append(extra_tests)
exit_code = subprocess.call(command)
if should_exit and exit_code != 0:
sys.exit(exit_code)
return exit_code
def save_regression_diff(name, output_dir):
path = os.path.join(output_dir, "regression.diffs")
if not os.path.exists(path):
return
new_file_path = os.path.join(output_dir, "./regression_{}.diffs".format(name))
print("new file path:", new_file_path)
shutil.move(path, new_file_path)
def stop_metadata_to_workers(pg_path, worker_ports, coordinator_port):
for port in worker_ports:
command = (
"SELECT * from stop_metadata_sync_to_node('localhost', {port});".format(
port=port
)
)
utils.psql(pg_path, coordinator_port, command)
def add_coordinator_to_metadata(pg_path, coordinator_port):
command = "SELECT citus_set_coordinator_host('localhost');"
utils.psql(pg_path, coordinator_port, command)
def add_workers(pg_path, worker_ports, coordinator_port):
for port in worker_ports:
command = "SELECT * from master_add_node('localhost', {port});".format(
port=port
)
utils.psql(pg_path, coordinator_port, command)
def logfile_name(logfile_prefix, node_name):
return "logfile_" + logfile_prefix + "_" + node_name
def stop_databases(
pg_path,
rel_data_path,
node_name_to_ports,
logfile_prefix,
no_output=False,
parallel=True,
):
def stop(node_name):
abs_data_path = os.path.abspath(os.path.join(rel_data_path, node_name))
node_port = node_name_to_ports[node_name]
command = [
os.path.join(pg_path, "pg_ctl"),
"stop",
"--pgdata",
abs_data_path,
"-U",
USER,
"-o",
"-p {}".format(node_port),
"--log",
os.path.join(abs_data_path, logfile_name(logfile_prefix, node_name)),
]
if no_output:
subprocess.call(
command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
)
else:
subprocess.call(command)
if parallel:
parallel_run(stop, node_name_to_ports.keys())
else:
for node_name in node_name_to_ports.keys():
stop(node_name)
def is_citus_set_coordinator_host_udf_exist(pg_path, port):
return utils.psql_capture(pg_path, port, IS_CITUS_VERSION_11_SQL) == b" t\n\n"
def initialize_citus_cluster(bindir, datadir, settings, config):
# In case there was a leftover from previous runs, stop the databases
stop_databases(
bindir, datadir, config.node_name_to_ports, config.name, no_output=True
)
initialize_db_for_cluster(
bindir, datadir, settings, config.node_name_to_ports.keys()
)
start_databases(
bindir, datadir, config.node_name_to_ports, config.name, config.env_variables
)
create_citus_extension(bindir, config.node_name_to_ports.values())
    # In upgrade tests the Citus version may be older than 11.0, in which
    # case the citus_set_coordinator_host UDF does not exist.
if is_citus_set_coordinator_host_udf_exist(bindir, config.coordinator_port()):
add_coordinator_to_metadata(bindir, config.coordinator_port())
add_workers(bindir, config.worker_ports, config.coordinator_port())
if not config.is_mx:
stop_metadata_to_workers(bindir, config.worker_ports, config.coordinator_port())
config.setup_steps()
def sudo(command, *args, shell=True, **kwargs):
"""
A version of run that prefixes the command with sudo when the process is
not already run as root
"""
effective_user_id = os.geteuid()
if effective_user_id == 0:
return run(command, *args, shell=shell, **kwargs)
if shell:
return run(f"sudo {command}", *args, shell=shell, **kwargs)
else:
return run(["sudo", *command])
# This range starts outside the ephemeral port range of many systems, so
# there's a lower chance that these ports conflict with ports that are
# already in use.
PORT_LOWER_BOUND = 10200
# The ephemeral port range starts here on many Linux systems
PORT_UPPER_BOUND = 32768
next_port = PORT_LOWER_BOUND
def notice_handler(diag: psycopg.errors.Diagnostic):
print(f"{diag.severity}: {diag.message_primary}")
def cleanup_test_leftovers(nodes):
"""
Cleaning up test leftovers needs to be done in a specific order, because
some of these leftovers depend on others having been removed. They might
    even depend on leftovers on other nodes being removed. That's why this
    function takes a list of nodes: so we can clean up all test leftovers
    globally in the correct order.
    """
for node in nodes:
node.cleanup_subscriptions()
for node in nodes:
node.cleanup_publications()
for node in nodes:
node.cleanup_replication_slots()
for node in nodes:
node.cleanup_schemas()
for node in nodes:
node.cleanup_databases()
for node in nodes:
node.cleanup_users()
class PortLock:
"""PortLock allows you to take a lock an a specific port.
While a port is locked by one process, other processes using PortLock won't
get the same port.
"""
def __init__(self):
global next_port
first_port = next_port
while True:
next_port += 1
if next_port >= PORT_UPPER_BOUND:
next_port = PORT_LOWER_BOUND
# avoid infinite loop
if first_port == next_port:
raise Exception("Could not find port")
self.lock = filelock.FileLock(Path(gettempdir()) / f"port-{next_port}.lock")
try:
self.lock.acquire(timeout=0)
except filelock.Timeout:
continue
if FORCE_PORTS:
self.port = next_port
break
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
try:
s.bind(("127.0.0.1", next_port))
self.port = next_port
break
except Exception:
self.lock.release()
continue
def release(self):
"""Call release when you are done with the port.
This way other processes can use it again.
"""
self.lock.release()
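# A minimal sketch of PortLock usage: reserve a free port, use it, and
# release it when done so other test processes can reuse it.
#
#     port_lock = PortLock()
#     try:
#         use_port(port_lock.port)  # hypothetical helper
#     finally:
#         port_lock.release()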
class QueryRunner(ABC):
"""A subclassable interface class that can be used to run queries.
    This is mostly useful to be generic across different types of things that
    implement the Postgres interface, such as Postgres, PgBouncer, or a Citus
    cluster.
    This implements some helpers to send queries in a simpler manner than
    psycopg allows by default.
"""
@abstractmethod
def set_default_connection_options(self, options: dict[str, typing.Any]):
"""Sets the default connection options on the given options dictionary
        This is the only method that a subclass of QueryRunner needs to
        implement.
"""
...
def make_conninfo(self, **kwargs) -> str:
self.set_default_connection_options(kwargs)
return psycopg.conninfo.make_conninfo(**kwargs)
def conn(self, *, autocommit=True, **kwargs):
"""Open a psycopg connection to this server"""
self.set_default_connection_options(kwargs)
conn = psycopg.connect(
autocommit=autocommit,
**kwargs,
)
conn.add_notice_handler(notice_handler)
return conn
def aconn(self, *, autocommit=True, **kwargs):
"""Open an asynchronous psycopg connection to this server"""
self.set_default_connection_options(kwargs)
return psycopg.AsyncConnection.connect(
autocommit=autocommit,
**kwargs,
)
@contextmanager
def cur(self, autocommit=True, **kwargs):
"""Open an psycopg cursor to this server
The connection and the cursors automatically close once you leave the
"with" block
"""
with self.conn(
autocommit=autocommit,
**kwargs,
) as conn:
with conn.cursor() as cur:
yield cur
@asynccontextmanager
async def acur(self, **kwargs):
"""Open an asynchronous psycopg cursor to this server
The connection and the cursors automatically close once you leave the
"async with" block
"""
async with await self.aconn(**kwargs) as conn:
async with conn.cursor() as cur:
yield cur
def sql(self, query, params=None, **kwargs):
"""Run an SQL query
This opens a new connection and closes it once the query is done
"""
with self.cur(**kwargs) as cur:
cur.execute(query, params=params)
def sql_prepared(self, query, params=None, **kwargs):
"""Run an SQL query, with prepare=True
This opens a new connection and closes it once the query is done
"""
with self.cur(**kwargs) as cur:
cur.execute(query, params=params, prepare=True)
def sql_row(self, query, params=None, allow_empty_result=False, **kwargs):
"""Run an SQL query that returns a single row and returns this row
This opens a new connection and closes it once the query is done
"""
with self.cur(**kwargs) as cur:
cur.execute(query, params=params)
result = cur.fetchall()
if allow_empty_result and len(result) == 0:
return None
        assert len(result) == 1, f"sql_row expected exactly one row, got {len(result)}"
return result[0]
def sql_value(self, query, params=None, allow_empty_result=False, **kwargs):
"""Run an SQL query that returns a single cell and return this value
This opens a new connection and closes it once the query is done
"""
with self.cur(**kwargs) as cur:
cur.execute(query, params=params)
result = cur.fetchall()
if allow_empty_result and len(result) == 0:
return None
assert len(result) == 1
assert len(result[0]) == 1
value = result[0][0]
return value
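    # A minimal sketch of the sql helpers above ("node" stands for any
    # QueryRunner subclass instance, table "t" is hypothetical):
    #
    #     node.sql("CREATE TABLE t (a int)")
    #     row = node.sql_row("SELECT 1, 2")                 # (1, 2)
    #     count = node.sql_value("SELECT count(*) FROM t")  # 0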
def asql(self, query, **kwargs):
"""Run an SQL query in asynchronous task
This opens a new connection and closes it once the query is done
"""
return asyncio.ensure_future(self.asql_coroutine(query, **kwargs))
async def asql_coroutine(
self, query, params=None, **kwargs
) -> typing.Optional[typing.List[typing.Any]]:
async with self.acur(**kwargs) as cur:
await cur.execute(query, params=params)
try:
return await cur.fetchall()
except psycopg.ProgrammingError as e:
if "the last operation didn't produce a result" == str(e):
return None
raise
def psql(self, query, **kwargs):
"""Run an SQL query using psql instead of psycopg
This opens a new connection and closes it once the query is done
"""
conninfo = self.make_conninfo(**kwargs)
run(
["psql", "-X", f"{conninfo}", "-c", query],
shell=False,
silent=True,
)
def poll_query_until(self, query, params=None, expected=True, **kwargs):
"""Run query repeatedly until it returns the expected result"""
start = datetime.now()
result = None
while start + TIMEOUT_DEFAULT > datetime.now():
result = self.sql_value(
query, params=params, allow_empty_result=True, **kwargs
)
if result == expected:
return
time.sleep(0.1)
raise Exception(
f"Timeout reached while polling query, last result was: {result}"
)
@contextmanager
def transaction(self, **kwargs):
with self.cur(**kwargs) as cur:
with cur.connection.transaction():
yield cur
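    # A minimal sketch of transaction usage: both inserts commit together
    # when the "with" block exits without an error ("node" and table "t"
    # are hypothetical):
    #
    #     with node.transaction() as cur:
    #         cur.execute("INSERT INTO t VALUES (1)")
    #         cur.execute("INSERT INTO t VALUES (2)")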
def sleep(self, duration=3, **kwargs):
"""Run pg_sleep"""
return self.sql(f"select pg_sleep({duration})", **kwargs)
def asleep(self, duration=3, times=1, sequentially=False, **kwargs):
"""Run pg_sleep asynchronously in a task.
times:
You can create a single task that opens multiple connections, which
run pg_sleep concurrently. The asynchronous task will only complete
once all these pg_sleep calls are finished.
sequentially:
Instead of running all pg_sleep calls spawned by providing
times > 1 concurrently, this will run them sequentially.
"""
return asyncio.ensure_future(
self.asleep_coroutine(
duration=duration, times=times, sequentially=sequentially, **kwargs
)
)
async def asleep_coroutine(self, duration=3, times=1, sequentially=False, **kwargs):
"""This is the coroutine that the asleep task runs internally"""
if not sequentially:
await asyncio.gather(
*[
self.asql(f"select pg_sleep({duration})", **kwargs)
for _ in range(times)
]
)
else:
for _ in range(times):
await self.asql(f"select pg_sleep({duration})", **kwargs)
def test(self, **kwargs):
"""Test if you can connect"""
return self.sql("select 1", **kwargs)
def atest(self, **kwargs):
"""Test if you can connect asynchronously"""
return self.asql("select 1", **kwargs)
def psql_test(self, **kwargs):
"""Test if you can connect with psql instead of psycopg"""
return self.psql("select 1", **kwargs)
def debug(self):
print("Connect manually to:\n ", repr(self.make_conninfo()))
print("Press Enter to continue running the test...")
input()
def psql_debug(self, **kwargs):
conninfo = self.make_conninfo(**kwargs)
run(
["psql", f"{conninfo}"],
shell=False,
silent=True,
)
class Postgres(QueryRunner):
"""A class that represents a Postgres instance on this machine
You can query it by using the interface provided by QueryRunner or use many
of the helper methods.
"""
def __init__(self, pgdata):
self.port_lock = PortLock()
# These values should almost never be changed after initialization
self.host = "127.0.0.1"
self.port = self.port_lock.port
# These values can be changed when needed
self.dbname = "postgres"
self.user = "postgres"
self.schema = None
self.pgdata = pgdata
self.log_path = self.pgdata / "pg.log"
# Used to track objects that we want to clean up at the end of a test
self.subscriptions = set()
self.publications = set()
self.replication_slots = set()
self.databases = set()
self.schemas = set()
self.users = set()
def set_default_connection_options(self, options):
options.setdefault("host", self.host)
options.setdefault("port", self.port)
options.setdefault("dbname", self.dbname)
options.setdefault("user", self.user)
if self.schema is not None:
options.setdefault("options", f"-c search_path={self.schema}")
options.setdefault("connect_timeout", 3)
# needed for Ubuntu 18.04
options.setdefault("client_encoding", "UTF8")
def initdb(self):
run(
f"initdb -A trust --nosync --username postgres --pgdata {self.pgdata} --allow-group-access --encoding UTF8 --locale POSIX",
stdout=subprocess.DEVNULL,
)
with self.conf_path.open(mode="a") as pgconf:
# Allow connecting over unix sockets
pgconf.write("unix_socket_directories = '/tmp'\n")
# Useful logs for debugging issues
pgconf.write("log_replication_commands = on\n")
            # The following two are also useful for debugging, but quite noisy,
            # so it's better to enable them manually by uncommenting.
# pgconf.write("log_connections = on\n")
# pgconf.write("log_disconnections = on\n")
# Enable citus
pgconf.write("shared_preload_libraries = 'citus'\n")
# Allow CREATE SUBSCRIPTION to work
pgconf.write("wal_level = 'logical'\n")
# Faster logical replication status update so tests with logical replication
# run faster
pgconf.write("wal_receiver_status_interval = 1\n")
# Faster logical replication apply worker launch so tests with logical
# replication run faster. This is used in ApplyLauncherMain in
# src/backend/replication/logical/launcher.c.
pgconf.write("wal_retrieve_retry_interval = '250ms'\n")
            # Make sure there are enough logical replication resources for most
# of our tests
pgconf.write("max_logical_replication_workers = 50\n")
pgconf.write("max_wal_senders = 50\n")
pgconf.write("max_worker_processes = 50\n")
pgconf.write("max_replication_slots = 50\n")
# We need to make the log go to stderr so that the tests can
# check what is being logged. This should be the default, but
# some packagings change the default configuration.
pgconf.write("log_destination = stderr\n")
# We don't need the logs anywhere else than stderr
pgconf.write("logging_collector = off\n")
# This makes tests run faster and we don't care about crash safety
# of our test data.
pgconf.write("fsync = false\n")
# conservative settings to ensure we can run multiple postmasters:
pgconf.write("shared_buffers = 1MB\n")
# limit disk space consumption, too:
pgconf.write("max_wal_size = 128MB\n")
# don't restart after crashes to make it obvious that a crash
# happened
pgconf.write("restart_after_crash = off\n")
os.truncate(self.hba_path, 0)
self.ssl_access("all", "trust")
self.nossl_access("all", "trust")
self.commit_hba()
def init_with_citus(self):
self.initdb()
self.start()
self.sql("CREATE EXTENSION citus")
# Manually turn on ssl, so that we can safely truncate
# postgresql.auto.conf later. We can only do this after creating the
        # citus extension because that creates the self-signed certificates.
with self.conf_path.open(mode="a") as pgconf:
pgconf.write("ssl = on\n")
def pgctl(self, command, **kwargs):
run(f"pg_ctl -w --pgdata {self.pgdata} {command}", **kwargs)
def apgctl(self, command, **kwargs):
return asyncio.create_subprocess_shell(
f"pg_ctl -w --pgdata {self.pgdata} {command}", **kwargs
)
def start(self):
try:
self.pgctl(f'-o "-p {self.port}" -l {self.log_path} start')
except Exception:
print(f"\n\nPG_LOG: {self.pgdata}\n")
with self.log_path.open() as f:
print(f.read())
raise
def stop(self, mode="fast"):
self.pgctl(f"-m {mode} stop", check=False)
def cleanup(self):
self.stop()
self.port_lock.release()
def restart(self):
self.stop()
self.start()
def reload(self):
self.pgctl("reload")
# Sadly UNIX signals are asynchronous, so we sleep a bit and hope that
# Postgres actually processed the SIGHUP signal after the sleep.
time.sleep(0.1)
async def arestart(self):
process = await self.apgctl("-m fast restart")
await process.communicate()
def nossl_access(self, dbname, auth_type):
"""Prepends a local non-SSL access to the HBA file"""
with self.hba_path.open() as pghba:
old_contents = pghba.read()
with self.hba_path.open(mode="w") as pghba:
pghba.write(f"local {dbname} all {auth_type}\n")
pghba.write(f"hostnossl {dbname} all 127.0.0.1/32 {auth_type}\n")
pghba.write(f"hostnossl {dbname} all ::1/128 {auth_type}\n")
pghba.write(old_contents)
def ssl_access(self, dbname, auth_type):
"""Prepends a local SSL access rule to the HBA file"""
with self.hba_path.open() as pghba:
old_contents = pghba.read()
with self.hba_path.open(mode="w") as pghba:
pghba.write(f"hostssl {dbname} all 127.0.0.1/32 {auth_type}\n")
pghba.write(f"hostssl {dbname} all ::1/128 {auth_type}\n")
pghba.write(old_contents)
@property
def hba_path(self):
return self.pgdata / "pg_hba.conf"
@property
def conf_path(self):
return self.pgdata / "postgresql.conf"
def commit_hba(self):
"""Mark the current HBA contents as non-resetable by reset_hba"""
with self.hba_path.open() as pghba:
old_contents = pghba.read()
with self.hba_path.open(mode="w") as pghba:
pghba.write("# committed-rules\n")
pghba.write(old_contents)
def reset_hba(self):
"""Remove any HBA rules that were added after the last call to commit_hba"""
with self.hba_path.open() as f:
hba_contents = f.read()
committed = hba_contents[hba_contents.find("# committed-rules\n") :]
with self.hba_path.open("w") as f:
f.write(committed)
def prepare_reset(self):
"""Prepares all changes to reset Postgres settings and objects
To actually apply the prepared changes a restart might still be needed.
"""
self.reset_hba()
os.truncate(self.pgdata / "postgresql.auto.conf", 0)
def reset(self):
"""Resets any changes to Postgres settings from previous tests"""
self.prepare_reset()
self.restart()
async def delayed_start(self, delay=1):
"""Start Postgres after a delay
NOTE: The sleep is asynchronous, but while waiting for Postgres to
start the pg_ctl start command will block the event loop. This is
currently acceptable for our usage of this method in the existing
tests and this way it was easiest to implement. However, it seems
totally reasonable to change this behaviour in the future if necessary.
"""
await asyncio.sleep(delay)
self.start()
def configure(self, *configs):
"""Configure specific Postgres settings using ALTER SYSTEM SET
NOTE: after configuring a call to reload or restart is needed for the
settings to become effective.
"""
for config in configs:
self.sql(f"alter system set {config}")
def log_handle(self):
"""Returns the opened logfile at the current end of the log
By later calling read on this file you can read the contents that were
written from this moment on.
IMPORTANT: This handle should be closed once it's not needed anymore
"""
f = self.log_path.open()
f.seek(0, os.SEEK_END)
return f
@contextmanager
def log_contains(self, re_string, times=None):
"""Checks if during this with block the log matches re_string
re_string:
The regex to search for.
times:
If None, any number of matches is accepted. If a number, only that
specific number of matches is accepted.
"""
with self.log_handle() as f:
yield
content = f.read()
if times is None:
assert re.search(re_string, content)
else:
match_count = len(re.findall(re_string, content))
assert match_count == times
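    # A minimal sketch of log_contains usage: assert that the regex was
    # logged exactly once while the block ran (the function called inside
    # the block is hypothetical):
    #
    #     with node.log_contains("START_REPLICATION", times=1):
    #         start_replication_somehow()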
def create_user(self, name, args: typing.Optional[psycopg.sql.Composable] = None):
self.users.add(name)
if args is None:
args = sql.SQL("")
self.sql(sql.SQL("CREATE USER {} {}").format(sql.Identifier(name), args))
def create_database(self, name):
self.databases.add(name)
self.sql(sql.SQL("CREATE DATABASE {}").format(sql.Identifier(name)))
def create_schema(self, name):
self.schemas.add(name)
self.sql(sql.SQL("CREATE SCHEMA {}").format(sql.Identifier(name)))
def create_publication(self, name: str, args: psycopg.sql.Composable):
self.publications.add(name)
self.sql(sql.SQL("CREATE PUBLICATION {} {}").format(sql.Identifier(name), args))
def create_logical_replication_slot(
self, name, plugin, temporary=False, twophase=False
):
self.replication_slots.add(name)
self.sql(
"SELECT pg_catalog.pg_create_logical_replication_slot(%s,%s,%s,%s)",
(name, plugin, temporary, twophase),
)
def create_subscription(self, name: str, args: psycopg.sql.Composable):
self.subscriptions.add(name)
self.sql(
sql.SQL("CREATE SUBSCRIPTION {} {}").format(sql.Identifier(name), args)
)
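    # A minimal sketch of the helpers above, wiring up logical replication
    # between two hypothetical nodes "alice" and "bob":
    #
    #     alice.create_publication("pub", sql.SQL("FOR ALL TABLES"))
    #     bob.create_subscription(
    #         "sub",
    #         sql.SQL("CONNECTION {} PUBLICATION pub").format(alice.make_conninfo()),
    #     )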
def cleanup_users(self):
for user in self.users:
self.sql(sql.SQL("DROP USER IF EXISTS {}").format(sql.Identifier(user)))
def cleanup_databases(self):
for database in self.databases:
self.sql(
sql.SQL("DROP DATABASE IF EXISTS {}").format(sql.Identifier(database))
)
def cleanup_schemas(self):
for schema in self.schemas:
self.sql(
sql.SQL("DROP SCHEMA IF EXISTS {} CASCADE").format(
sql.Identifier(schema)
)
)
def cleanup_publications(self):
for publication in self.publications:
self.sql(
sql.SQL("DROP PUBLICATION IF EXISTS {}").format(
sql.Identifier(publication)
)
)
def cleanup_replication_slots(self):
for slot in self.replication_slots:
start = time.time()
while True:
try:
self.sql(
"SELECT pg_drop_replication_slot(slot_name) FROM pg_replication_slots WHERE slot_name = %s",
(slot,),
)
except psycopg.errors.ObjectInUse:
if time.time() < start + 10:
time.sleep(0.5)
continue
raise
break
def cleanup_subscriptions(self):
for subscription in self.subscriptions:
try:
self.sql(
sql.SQL("ALTER SUBSCRIPTION {} DISABLE").format(
sql.Identifier(subscription)
)
)
except psycopg.errors.UndefinedObject:
                # The subscription was already gone
continue
self.sql(
sql.SQL("ALTER SUBSCRIPTION {} SET (slot_name = NONE)").format(
sql.Identifier(subscription)
)
)
self.sql(
sql.SQL("DROP SUBSCRIPTION {}").format(sql.Identifier(subscription))
)
def lsn(self, mode):
"""Returns the lsn for the given mode"""
queries = {
"insert": "SELECT pg_current_wal_insert_lsn()",
"flush": "SELECT pg_current_wal_flush_lsn()",
"write": "SELECT pg_current_wal_lsn()",
"receive": "SELECT pg_last_wal_receive_lsn()",
"replay": "SELECT pg_last_wal_replay_lsn()",
}
return self.sql_value(queries[mode])
def wait_for_catchup(self, subscription_name, mode="replay", target_lsn=None):
"""Waits until the subscription has caught up"""
if target_lsn is None:
target_lsn = self.lsn("write")
# Before release 12 walreceiver just set the application name to
# "walreceiver"
self.poll_query_until(
sql.SQL(
"""
SELECT {} <= {} AND state = 'streaming'
FROM pg_catalog.pg_stat_replication
WHERE application_name IN ({}, 'walreceiver')
"""
).format(target_lsn, sql.Identifier(f"{mode}_lsn"), subscription_name)
)
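    # A minimal sketch of wait_for_catchup, run on the publisher after a
    # write ("alice", table "t", and subscription "sub" are hypothetical):
    #
    #     alice.sql("INSERT INTO t VALUES (1)")
    #     alice.wait_for_catchup("sub")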
@contextmanager
def _enable_firewall(self):
"""Enables the firewall for the platform that you are running
Normally this should not be called directly, and instead drop_traffic
or reject_traffic should be used.
"""
fw_token = None
if BSD:
if MACOS:
command_stderr = sudo(
"pfctl -E", stderr=subprocess.PIPE, text=True
).stderr
match = re.search(r"^Token : (\d+)", command_stderr, flags=re.MULTILINE)
assert match is not None
fw_token = match.group(1)
sudo(
'bash -c "'
f"echo 'anchor \\\"port_{self.port}\\\"'"
f' | pfctl -a citus_test -f -"'
)
try:
yield
finally:
if MACOS:
sudo(f"pfctl -X {fw_token}")
@contextmanager
def drop_traffic(self):
"""Drops all TCP packets to this query runner"""
with self._enable_firewall():
if LINUX:
sudo(
"iptables --append OUTPUT "
"--protocol tcp "
f"--destination {self.host} "
f"--destination-port {self.port} "
"--jump DROP "
)
elif BSD:
sudo(
"bash -c '"
f'echo "block drop out proto tcp from any to {self.host} port {self.port}"'
f"| pfctl -a citus_test/port_{self.port} -f -'"
)
else:
raise Exception("This OS cannot run this test")
try:
yield
finally:
if LINUX:
sudo(
"iptables --delete OUTPUT "
"--protocol tcp "
f"--destination {self.host} "
f"--destination-port {self.port} "
"--jump DROP "
)
elif BSD:
sudo(f"pfctl -a citus_test/port_{self.port} -F all")
@contextmanager
def reject_traffic(self):
"""Rejects all traffic to this query runner with a TCP RST message"""
with self._enable_firewall():
if LINUX:
sudo(
"iptables --append OUTPUT "
"--protocol tcp "
f"--destination {self.host} "
f"--destination-port {self.port} "
"--jump REJECT "
"--reject-with tcp-reset"
)
elif BSD:
sudo(
"bash -c '"
f'echo "block return-rst out out proto tcp from any to {self.host} port {self.port}"'
f"| pfctl -a citus_test/port_{self.port} -f -'"
)
else:
raise Exception("This OS cannot run this test")
try:
yield
finally:
if LINUX:
sudo(
"iptables --delete OUTPUT "
"--protocol tcp "
f"--destination {self.host} "
f"--destination-port {self.port} "
"--jump REJECT "
"--reject-with tcp-reset"
)
elif BSD:
sudo(f"pfctl -a citus_test/port_{self.port} -F all")
class CitusCluster(QueryRunner):
"""A class that represents a Citus cluster on this machine
The nodes in the cluster can be accessed directly using the coordinator,
workers, and nodes properties.
If it doesn't matter which of the nodes in the cluster is used to run a
query, then you can use the methods provided by QueryRunner directly on the
cluster. In that case a random node will be chosen to run your query.
"""
def __init__(self, basedir: Path, worker_count: int):
self.coordinator = Postgres(basedir / "coordinator")
self.workers = [Postgres(basedir / f"worker{i}") for i in range(worker_count)]
self.nodes = [self.coordinator] + self.workers
self._schema = None
self.failed_reset = False
parallel_run(Postgres.init_with_citus, self.nodes)
with self.coordinator.cur() as cur:
cur.execute(
"SELECT pg_catalog.citus_set_coordinator_host(%s, %s)",
(self.coordinator.host, self.coordinator.port),
)
for worker in self.workers:
cur.execute(
"SELECT pg_catalog.citus_add_node(%s, %s)",
(worker.host, worker.port),
)
def set_default_connection_options(self, options):
random.choice(self.nodes).set_default_connection_options(options)
@property
def schema(self):
return self._schema
@schema.setter
def schema(self, value):
self._schema = value
for node in self.nodes:
node.schema = value
def reset(self):
"""Resets any changes to Postgres settings from previous tests"""
parallel_run(Postgres.prepare_reset, self.nodes)
parallel_run(Postgres.restart, self.nodes)
def cleanup(self):
parallel_run(Postgres.cleanup, self.nodes)
def debug(self):
"""Print information to stdout to help with debugging your cluster"""
print("The nodes in this cluster and their connection strings are:")
for node in self.nodes:
print(f"{node.pgdata}:\n ", repr(node.make_conninfo()))
print("Press Enter to continue running the test...")
input()
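# A minimal sketch of CitusCluster usage (the base directory is made up):
#
#     cluster = CitusCluster(Path("/tmp/citus-test"), worker_count=2)
#     try:
#         cluster.coordinator.sql("CREATE TABLE t (a int)")
#         cluster.coordinator.sql("SELECT create_distributed_table('t', 'a')")
#     finally:
#         cluster.cleanup()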