citus/src/test/regress/citus_tests/common.py

320 lines
8.9 KiB
Python

import atexit
import concurrent.futures
import os
import shutil
import subprocess
import sys
import utils
from utils import USER, cd
def initialize_temp_dir(temp_dir):
    """Recreate ``temp_dir`` as an empty directory with mode 0o777.

    If something already exists at ``temp_dir`` it is removed with
    ``shutil.rmtree`` before the directory is created again.
    """
    leftover_present = os.path.exists(temp_dir)
    if leftover_present:
        shutil.rmtree(temp_dir)

    os.mkdir(temp_dir)
    # Give full access to TEMP_DIR so that postgres user can use it.
    os.chmod(temp_dir, 0o777)
def initialize_temp_dir_if_not_exists(temp_dir):
    """Create ``temp_dir`` with mode 0o777 unless it already exists.

    An existing path is left completely untouched (contents and permissions).

    The directory is created first and "already exists" is handled afterwards,
    instead of checking ``os.path.exists`` up front: a separate check leaves a
    window in which another process can create the directory, which would
    make ``os.mkdir`` fail with ``FileExistsError``.
    """
    try:
        os.mkdir(temp_dir)
    except FileExistsError:
        # Somebody (possibly a concurrent run) already created it.
        return
    # Give full access to TEMP_DIR so that postgres user can use it.
    os.chmod(temp_dir, 0o777)
def parallel_run(function, items, *args, **kwargs):
    """Call ``function(item, *args, **kwargs)`` for every item on a thread pool.

    All calls are submitted before any result is awaited. Results are then
    collected in submission order, so an exception raised by one of the calls
    is re-raised here. Return values of ``function`` are discarded.
    """
    with concurrent.futures.ThreadPoolExecutor() as pool:
        pending = []
        for item in items:
            pending.append(pool.submit(function, item, *args, **kwargs))

        for task in pending:
            # result() blocks until the call finishes and re-raises its error.
            task.result()
def initialize_db_for_cluster(pg_path, rel_data_path, settings, node_names):
    """Run ``initdb`` for every node and append ``settings`` to its config.

    ``rel_data_path`` is created first (via ``mkdir``); each node then gets
    its own data directory ``<rel_data_path>/<node_name>``. The nodes are
    initialized concurrently through ``parallel_run``.
    """
    subprocess.run(["mkdir", rel_data_path], check=True)

    def init_node(node_name):
        data_dir = os.path.abspath(os.path.join(rel_data_path, node_name))
        initdb_options = [
            "--pgdata",
            data_dir,
            "--username",
            USER,
            "--no-sync",
            # --allow-group-access is used to ensure we set permissions on
            # private keys correctly
            "--allow-group-access",
            "--encoding",
            "UTF8",
            "--locale",
            "POSIX",
        ]
        initdb_binary = os.path.join(pg_path, "initdb")
        subprocess.run([initdb_binary, *initdb_options], check=True)
        add_settings(data_dir, settings)

    parallel_run(init_node, node_names)
def add_settings(abs_data_path, settings):
    """Append ``settings`` to ``postgresql.conf`` inside ``abs_data_path``.

    Every entry of the ``settings`` mapping becomes one line of the form
    ``key = 'value'``, written in the mapping's iteration order.
    """
    conf_path = os.path.join(abs_data_path, "postgresql.conf")
    with open(conf_path, "a") as conf_file:
        conf_file.writelines(
            f"{name} = '{value}'\n" for name, value in settings.items()
        )
def create_role(pg_path, node_ports, user_name):
    """Create the role ``user_name`` on every node listed in ``node_ports``.

    On each port two statements are sent through ``utils.psql``, both with
    ``citus.enable_ddl_propagation`` switched off: one creating the role via
    ``worker_create_or_alter_role`` and one granting it CREATE on the
    ``postgres`` database. The ports are processed concurrently.
    """

    def create(port):
        role_sql = (
            "SET citus.enable_ddl_propagation TO OFF; "
            "SELECT worker_create_or_alter_role('{}', "
            "'CREATE ROLE {} WITH LOGIN CREATEROLE CREATEDB;', NULL)"
        ).format(user_name, user_name)
        utils.psql(pg_path, port, role_sql)

        grant_sql = (
            "SET citus.enable_ddl_propagation TO OFF; "
            "GRANT CREATE ON DATABASE postgres to {}"
        ).format(user_name)
        utils.psql(pg_path, port, grant_sql)

    parallel_run(create, node_ports)
def coordinator_should_haveshards(pg_path, port):
    """Set the ``shouldhaveshards`` node property to true for ``localhost:port``.

    The statement is executed through ``utils.psql`` against the same port.
    """
    query = (
        f"SELECT citus_set_node_property('localhost', {port}, "
        "'shouldhaveshards', true)"
    )
    utils.psql(pg_path, port, query)
def start_databases(
    pg_path, rel_data_path, node_name_to_ports, logfile_prefix, env_variables
):
    """Start a postgres server for every node and schedule their shutdown.

    Args:
        pg_path: directory containing the ``pg_ctl`` binary.
        rel_data_path: directory holding one data directory per node name.
        node_name_to_ports: mapping of node name to the port it listens on.
        logfile_prefix: prefix used to build each node's log file name.
        env_variables: extra environment variables (for example an
            application name) to export before starting the servers; may be
            empty or ``None``.

    Raises:
        subprocess.CalledProcessError: if ``pg_ctl start`` fails for a node.
    """
    # Export the extra environment once, before fanning out to threads:
    # os.environ is process-global, so updating it from every worker thread
    # is redundant and mutates shared state concurrently. The child pg_ctl
    # processes inherit it because no explicit env is passed to subprocess.
    if env_variables:
        os.environ.update(env_variables)

    def start(node_name):
        abs_data_path = os.path.abspath(os.path.join(rel_data_path, node_name))
        node_port = node_name_to_ports[node_name]
        command = [
            os.path.join(pg_path, "pg_ctl"),
            "start",
            "--pgdata",
            abs_data_path,
            "-U",
            USER,
            "-o",
            "-p {}".format(node_port),
            "--log",
            os.path.join(abs_data_path, logfile_name(logfile_prefix, node_name)),
        ]
        subprocess.run(command, check=True)

    parallel_run(start, node_name_to_ports.keys())

    # We don't want parallel shutdown here because that will fail when it's
    # tried in this atexit call with an error like:
    # cannot schedule new futures after interpreter shutdown
    atexit.register(
        stop_databases,
        pg_path,
        rel_data_path,
        node_name_to_ports,
        logfile_prefix,
        no_output=True,
        parallel=False,
    )
def create_citus_extension(pg_path, node_ports):
    """Run ``CREATE EXTENSION citus;`` on every port in ``node_ports``.

    The statements are issued concurrently through ``parallel_run``.
    """
    parallel_run(
        lambda port: utils.psql(pg_path, port, "CREATE EXTENSION citus;"),
        node_ports,
    )
def run_pg_regress(pg_path, pg_srcdir, port, schedule):
    """Run ``schedule`` with pg_regress, exiting the process on failure.

    ``bin/copy_modified`` is executed afterwards in every case, including
    when the regress run triggers an exit.
    """
    try:
        _run_pg_regress(pg_path, pg_srcdir, port, schedule, should_exit=True)
    finally:
        subprocess.run("bin/copy_modified", check=True)
def run_pg_regress_without_exit(
    pg_path,
    pg_srcdir,
    port,
    schedule,
    output_dir=".",
    input_dir=".",
    user="postgres",
    extra_tests="",
):
    """Run ``schedule`` with pg_regress and return a combined exit status.

    Unlike ``run_pg_regress`` this never exits the process. After the regress
    run, ``<input_dir>/copy_modified_wrapper`` is executed and its exit
    status is OR-ed into the returned value.
    """
    regress_status = _run_pg_regress(
        pg_path,
        pg_srcdir,
        port,
        schedule,
        False,
        output_dir,
        input_dir,
        user,
        extra_tests,
    )
    wrapper_path = os.path.join(input_dir, "copy_modified_wrapper")
    wrapper_status = subprocess.call(wrapper_path)
    return regress_status | wrapper_status
def _run_pg_regress(
pg_path,
pg_srcdir,
port,
schedule,
should_exit,
output_dir=".",
input_dir=".",
user="postgres",
extra_tests="",
):
command = [
os.path.join(pg_srcdir, "src/test/regress/pg_regress"),
"--port",
str(port),
"--schedule",
schedule,
"--bindir",
pg_path,
"--user",
user,
"--dbname",
"postgres",
"--inputdir",
input_dir,
"--outputdir",
output_dir,
"--use-existing",
]
if extra_tests != "":
command.append(extra_tests)
exit_code = subprocess.call(command)
if should_exit and exit_code != 0:
sys.exit(exit_code)
return exit_code
def save_regression_diff(name, output_dir):
    """Rename ``regression.diffs`` in ``output_dir`` to ``regression_<name>.diffs``.

    Does nothing when ``regression.diffs`` is absent. The new path is printed
    to stdout before the file is moved.
    """
    diffs_path = os.path.join(output_dir, "regression.diffs")
    if os.path.exists(diffs_path):
        renamed_path = os.path.join(
            output_dir, "./regression_{}.diffs".format(name)
        )
        print("new file path:", renamed_path)
        shutil.move(diffs_path, renamed_path)
def stop_metadata_to_workers(pg_path, worker_ports, coordinator_port):
    """Call ``stop_metadata_sync_to_node`` for each worker, one after another.

    Every statement is executed through ``utils.psql`` on ``coordinator_port``.
    """
    for worker_port in worker_ports:
        query = (
            f"SELECT * from stop_metadata_sync_to_node('localhost', {worker_port});"
        )
        utils.psql(pg_path, coordinator_port, query)
def add_coordinator_to_metadata(pg_path, coordinator_port):
    """Register ``localhost:coordinator_port`` via ``citus_add_node`` with group id 0.

    The statement is executed through ``utils.psql`` on the same port.
    """
    query = f"SELECT citus_add_node('localhost', {coordinator_port}, groupId := 0)"
    utils.psql(pg_path, coordinator_port, query)
def add_workers(pg_path, worker_ports, coordinator_port):
    """Call ``master_add_node`` for each worker port, one after another.

    Every statement is executed through ``utils.psql`` on ``coordinator_port``.
    """
    for worker_port in worker_ports:
        query = f"SELECT * from master_add_node('localhost', {worker_port});"
        utils.psql(pg_path, coordinator_port, query)
def logfile_name(logfile_prefix, node_name):
    """Return the log file name ``logfile_<logfile_prefix>_<node_name>``."""
    return "_".join(("logfile", logfile_prefix, node_name))
def stop_databases(
    pg_path,
    rel_data_path,
    node_name_to_ports,
    logfile_prefix,
    no_output=False,
    parallel=True,
):
    """Run ``pg_ctl stop`` for every node in ``node_name_to_ports``.

    The exit status of ``pg_ctl`` is ignored, so a node that fails to stop
    does not raise. With ``no_output`` set, stdout and stderr of ``pg_ctl``
    are discarded. With ``parallel`` set the nodes are stopped through
    ``parallel_run``; otherwise they are stopped one by one in this thread.
    """
    # Redirections applied to every pg_ctl call; empty means "inherit".
    if no_output:
        redirections = {"stdout": subprocess.DEVNULL, "stderr": subprocess.DEVNULL}
    else:
        redirections = {}

    def stop(node_name):
        data_dir = os.path.abspath(os.path.join(rel_data_path, node_name))
        port = node_name_to_ports[node_name]
        log_path = os.path.join(data_dir, logfile_name(logfile_prefix, node_name))
        pg_ctl_command = [
            os.path.join(pg_path, "pg_ctl"),
            "stop",
            "--pgdata",
            data_dir,
            "-U",
            USER,
            "-o",
            "-p {}".format(port),
            "--log",
            log_path,
        ]
        subprocess.call(pg_ctl_command, **redirections)

    node_names = node_name_to_ports.keys()
    if not parallel:
        for node_name in node_names:
            stop(node_name)
        return
    parallel_run(stop, node_names)
def initialize_citus_cluster(bindir, datadir, settings, config):
    """Bring up a Citus cluster as described by ``config``.

    Steps, in order: stop leftover servers, initialize the data directories,
    start the servers, create the citus extension on every node, add the
    workers on the coordinator, stop metadata sync to the workers unless
    ``config.is_mx``, optionally add the coordinator to the metadata, and
    finally run ``config.setup_steps()``.
    """
    node_ports = config.node_name_to_ports

    # In case there was a leftover from previous runs, stop the databases
    stop_databases(bindir, datadir, node_ports, config.name, no_output=True)

    initialize_db_for_cluster(bindir, datadir, settings, node_ports.keys())
    start_databases(bindir, datadir, node_ports, config.name, config.env_variables)
    create_citus_extension(bindir, node_ports.values())

    add_workers(bindir, config.worker_ports, config.coordinator_port())
    if not config.is_mx:
        stop_metadata_to_workers(
            bindir, config.worker_ports, config.coordinator_port()
        )
    if config.add_coordinator_to_metadata:
        add_coordinator_to_metadata(bindir, config.coordinator_port())

    config.setup_steps()
def eprint(*args, **kwargs):
    """Print to stderr; accepts the same arguments as ``print`` except ``file``."""
    error_stream = sys.stderr
    print(*args, file=error_stream, **kwargs)
def run(command, *args, shell=True, **kwargs):
    """Echo ``command`` to stderr, then execute it with ``subprocess.run``.

    The command runs with ``check=True``, so a non-zero exit status raises
    ``subprocess.CalledProcessError``. It runs through the shell unless
    ``shell=False`` is passed. Returns the ``CompletedProcess`` result.
    """
    eprint(f"+ {command} ")
    completed = subprocess.run(command, *args, check=True, shell=shell, **kwargs)
    return completed