eag/test-workflow
eaydingol 2025-10-17 12:48:57 +03:00
parent b9378abf0f
commit 8053f4226c
2 changed files with 143 additions and 104 deletions

View File

@ -4,12 +4,12 @@ Backward compatibility checker for Citus
Detects changes that could break existing workflows Detects changes that could break existing workflows
""" """
import json
import os import os
import re import re
import json
import subprocess import subprocess
from pathlib import Path
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path
from typing import List from typing import List
@ -27,7 +27,7 @@ class FunctionSignature:
def __str__(self):
    """Render the signature as ``name(args) -> return_type``.

    ``args`` is interpolated using its default string conversion.
    """
    return f"{self.name}({self.args}) -> {self.return_type}"
def compare(self, other: "FunctionSignature") -> "str | dict | None":
    """Compare this signature with ``other`` and describe what changed.

    Args:
        other: The signature to compare against.

    Returns:
        ``None`` when no difference is found. Otherwise a newline-joined
        string with one entry per detected difference (function name,
        return type, parameters). When ``other`` is not a
        ``FunctionSignature`` a dict with a single ``"error"`` key is
        returned instead.
    """
    if not isinstance(other, FunctionSignature):
        return {"error": "Cannot compare with non-FunctionSignature object"}

    differences = []

    if self.name != other.name:
        differences.append(
            f"Function name changed from {self.name} to {other.name}."
        )

    if self.return_type != other.return_type:
        differences.append(
            f"Return type changed from {self.return_type} to {other.return_type}."
        )

    # Parameter-level comparison is delegated; a falsy result means
    # there is nothing to report for the parameters.
    arg_diff = self._compare_parameters(other.args)
    if arg_diff:
        differences.append(f"Parameter changes detected:\n{arg_diff}")

    return "\n".join(differences) if differences else None
def _parse_parameters(self, args_string: str) -> dict[dict]: def _parse_parameters(self, args_string: str) -> dict[dict]:
"""Parse parameter string into structured data""" """Parse parameter string into structured data"""
@ -52,33 +56,37 @@ class FunctionSignature:
return {} return {}
params = {} params = {}
for param in args_string.split(','): for param in args_string.split(","):
param = param.strip() param = param.strip()
if param: if param:
default_value = None default_value = None
# Extract name, type, and default # Extract name, type, and default
if 'DEFAULT' in param.upper(): if "DEFAULT" in param.upper():
# Capture the default value (everything after DEFAULT up to a comma or closing parenthesis) # Capture the default value (everything after DEFAULT up to a comma or closing parenthesis)
m = re.search(r'\bDEFAULT\b\s+([^,)\s][^,)]*)', param, flags=re.IGNORECASE) m = re.search(
r"\bDEFAULT\b\s+([^,)\s][^,)]*)", param, flags=re.IGNORECASE
)
if m: if m:
default_value = m.group(1).strip() default_value = m.group(1).strip()
# Remove the DEFAULT clause from the parameter string for further parsing # Remove the DEFAULT clause from the parameter string for further parsing
param_clean = re.sub(r'\s+DEFAULT\s+[^,)]+', '', param, flags=re.IGNORECASE) param_clean = re.sub(
r"\s+DEFAULT\s+[^,)]+", "", param, flags=re.IGNORECASE
)
else: else:
param_clean = param param_clean = param
parts = param_clean.strip().split() parts = param_clean.strip().split()
if len(parts) >= 2: if len(parts) >= 2:
name = parts[0] name = parts[0]
type_part = ' '.join(parts[1:]) type_part = " ".join(parts[1:])
else: else:
name = param_clean name = param_clean
type_part = "" type_part = ""
params[name] = { params[name] = {
'type': type_part, "type": type_part,
'default_value': default_value, "default_value": default_value,
'original': param "original": param,
} }
return params return params
@ -92,7 +100,7 @@ class FunctionSignature:
added_without_default = [] added_without_default = []
for name in added: for name in added:
if new_params[name]['default_value'] is None: if new_params[name]["default_value"] is None:
added_without_default.append(name) added_without_default.append(name)
# Find modified parameters # Find modified parameters
@ -101,93 +109,118 @@ class FunctionSignature:
for name, old_param in self.args.items(): for name, old_param in self.args.items():
if name in new_params: if name in new_params:
new_param = new_params[name] new_param = new_params[name]
if old_param['type'] != new_param['type']: if old_param["type"] != new_param["type"]:
type_changed.append(name) type_changed.append(name)
if old_param['default_value'] and old_param['default_value'] != new_param['default_value']: if (
old_param["default_value"]
and old_param["default_value"] != new_param["default_value"]
):
default_changed.append(name) default_changed.append(name)
if removed: if removed:
differences.append(f"Removed parameters: {', '.join(removed)}") differences.append(f"Removed parameters: {', '.join(removed)}")
if added_without_default: if added_without_default:
differences.append(f"Added parameters without a default value: {', '.join(added_without_default)}") differences.append(
f"Added parameters without a default value: {', '.join(added_without_default)}"
)
if type_changed: if type_changed:
differences.append(f"Type changed for parameters: {', '.join(type_changed)}") differences.append(
f"Type changed for parameters: {', '.join(type_changed)}"
)
if default_changed: if default_changed:
differences.append(f"Default value changed for parameters: {', '.join(default_changed)}") differences.append(
f"Default value changed for parameters: {', '.join(default_changed)}"
)
return '\n'.join(differences) if differences else None return "\n".join(differences) if differences else None
class CompatibilityChecker: class CompatibilityChecker:
def __init__(self):
    """Read the commit range from the environment and start with no findings.

    ``base_sha`` and ``head_sha`` come from the ``BASE_SHA`` and
    ``HEAD_SHA`` environment variables and are ``None`` when unset.
    ``results`` collects the findings appended by the individual checks.
    """
    self.base_sha = os.environ.get("BASE_SHA")
    self.head_sha = os.environ.get("HEAD_SHA")
    self.results = []
def get_changed_files(self):
    """Get list of changed files between base and head.

    Runs ``git diff --name-only <base_sha>...<head_sha>``.

    Returns:
        One entry per non-blank output line. An empty list when git prints
        nothing or exits with a non-zero status; in the latter case a
        warning with git's stderr is printed so the failure is visible in
        the log instead of looking like "nothing changed".
    """
    cmd = ["git", "diff", "--name-only", f"{self.base_sha}...{self.head_sha}"]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        print(f"Warning: {' '.join(cmd)} failed: {result.stderr.strip()}")
        return []
    # Drop blank lines so callers never receive an empty path.
    return [line for line in result.stdout.splitlines() if line.strip()]
def get_file_diff(self, file_path):
    """Get diff for a specific file.

    Runs ``git diff <base_sha>...<head_sha> -- <file_path>``.

    Args:
        file_path: Path handed to git after the ``--`` separator.

    Returns:
        Whatever git wrote to stdout. When git exits with a non-zero
        status a warning with its stderr is printed and the (typically
        empty) stdout is still returned.
    """
    cmd = ["git", "diff", f"{self.base_sha}...{self.head_sha}", "--", file_path]
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        print(f"Warning: {' '.join(cmd)} failed: {result.stderr.strip()}")
    return result.stdout
def get_base_file_content(self, file_path):
    """Get file content from base commit.

    Args:
        file_path: Path passed to git as ``<base_sha>:<file_path>``.

    Returns:
        The stdout of ``git show``, or an empty string when git exits with
        a non-zero status.
    """
    cmd = ["git", "show", f"{self.base_sha}:{file_path}"]
    result = subprocess.run(cmd, capture_output=True, text=True)
    return result.stdout if result.returncode == 0 else ""
def get_head_file_content(self, file_path):
    """Get file content from head commit.

    Args:
        file_path: Path passed to git as ``<head_sha>:<file_path>``.

    Returns:
        The stdout of ``git show``, or an empty string when git exits with
        a non-zero status.
    """
    cmd = ["git", "show", f"{self.head_sha}:{file_path}"]
    result = subprocess.run(cmd, capture_output=True, text=True)
    return result.stdout if result.returncode == 0 else ""
def get_function_signatures(self, sql_text: str) -> List[FunctionSignature]:
    """Extract all function signatures from SQL text.

    Every ``CREATE [OR REPLACE] FUNCTION name(args) RETURNS type`` found in
    ``sql_text`` (case-insensitively) yields one ``FunctionSignature`` built
    from the matched name, raw argument text and return type.

    Args:
        sql_text: SQL source to scan.

    Returns:
        The signatures in the order they appear; empty when none match.
    """
    pattern = re.compile(
        r"CREATE\s+(?:OR\s+REPLACE\s+)?FUNCTION\s+"
        r"(?P<name>[^\s(]+)\s*"  # function name, e.g. public.add_numbers
        # Argument list. One level of nested parentheses is accepted so
        # that arguments such as "numeric(10,2)" or "DEFAULT now()" do not
        # cause the whole function to be skipped.
        r"\((?P<args>(?:[^()]|\([^()]*\))*)\)"
        # Return type. The dot keeps schema-qualified types such as
        # pg_catalog.void from being cut off at the schema name.
        r"\s*RETURNS\s+(?P<return>(?:SETOF\s+)?(?:TABLE\s*\([^)]+\)|[\w.\[\]]+(?:\s*\[\s*\])*))",
        re.IGNORECASE | re.MULTILINE,
    )
    return [
        FunctionSignature(
            match.group("name"), match.group("args"), match.group("return")
        )
        for match in pattern.finditer(sql_text)
    ]
def check_sql_migrations(self, changed_files): def check_sql_migrations(self, changed_files):
"""Check for potentially breaking SQL migration changes""" """Check for potentially breaking SQL migration changes"""
breaking_patterns = [ breaking_patterns = [
(r'DROP\s+TABLE', 'Table removal'), (r"DROP\s+TABLE", "Table removal"),
(r'ALTER\s+TABLE\s+pg_catalog\.\w+\s+(ADD|DROP)\s+COLUMN\b', 'Column addition/removal in pg_catalog'), (
(r'ALTER\s+TABLE\s+\w+\s+ALTER\s+COLUMN', 'Column type change'), r"ALTER\s+TABLE\s+pg_catalog\.\w+\s+(ADD|DROP)\s+COLUMN\b",
(r'ALTER\s+FUNCTION.*RENAME', 'Function rename'), "Column addition/removal in pg_catalog",
(r'ALTER\s+TABLE\s+\w+\s+RENAME\s+TO\s+\w+', 'Table rename'), ),
(r'REVOKE', 'Permission revocation') (r"ALTER\s+TABLE\s+\w+\s+ALTER\s+COLUMN", "Column type change"),
(r"ALTER\s+FUNCTION.*RENAME", "Function rename"),
(r"ALTER\s+TABLE\s+\w+\s+RENAME\s+TO\s+\w+", "Table rename"),
(r"REVOKE", "Permission revocation"),
] ]
upgrade_scripts = [f for f in changed_files if 'sql' in f and '/downgrades/' not in f and 'citus--' in f] upgrade_scripts = [
f
for f in changed_files
if "sql" in f and "/downgrades/" not in f and "citus--" in f
]
udf_files = [f for f in changed_files if f.endswith('latest.sql')] udf_files = [f for f in changed_files if f.endswith("latest.sql")]
for file_path in upgrade_scripts: for file_path in upgrade_scripts:
diff = self.get_file_diff(file_path) diff = self.get_file_diff(file_path)
added_lines = [line[1:] for line in diff.split('\n') if line.startswith('+')] added_lines = [
line[1:] for line in diff.split("\n") if line.startswith("+")
]
for pattern, description in breaking_patterns: for pattern, description in breaking_patterns:
for line in added_lines: for line in added_lines:
if re.search(pattern, line, re.IGNORECASE): if re.search(pattern, line, re.IGNORECASE):
self.results.append({ self.results.append(
'type': 'SQL Migration', {
'description': description, "type": "SQL Migration",
'file': file_path, "description": description,
'details': f'Line: {line.strip()}' "file": file_path,
}) "details": f"Line: {line.strip()}",
}
)
for file_path in udf_files: for file_path in udf_files:
udf_directory = Path(file_path).parent.name udf_directory = Path(file_path).parent.name
@ -196,12 +229,14 @@ class CompatibilityChecker:
continue # File did not exist in base, likely a new file continue # File did not exist in base, likely a new file
head_content = self.get_head_file_content(file_path) head_content = self.get_head_file_content(file_path)
if not head_content: if not head_content:
self.results.append({ self.results.append(
'type': 'UDF Removal', {
'description': f'UDF file removed: {udf_directory}', "type": "UDF Removal",
'file': file_path, "description": f"UDF file removed: {udf_directory}",
'details': 'The UDF file is missing in the new version' "file": file_path,
}) "details": "The UDF file is missing in the new version",
}
)
continue continue
# Extract function signatures from base and head # Extract function signatures from base and head
@ -220,38 +255,41 @@ class CompatibilityChecker:
found = True found = True
break # Found one for the previous signature break # Found one for the previous signature
if not found and differences: if not found and differences:
self.results.append({ self.results.append(
'type': 'UDF Change', {
'description': f'UDF changed: {udf_directory}', "type": "UDF Change",
'file': file_path, "description": f"UDF changed: {udf_directory}",
'details': differences "file": file_path,
}) "details": differences,
}
)
def check_guc_changes(self, changed_files):
    """Check for GUC (configuration) changes.

    The old and new sides of the diff of ``shared_library_init.c`` are
    rebuilt from its hunk lines (context + removed lines, context + added
    lines). Every name given as the leading string argument of a
    ``DefineCustom...Variable(`` call on the old side that has no such call
    on the new side is appended to ``self.results`` as a removed GUC.

    Comparing both sides, instead of matching removed lines only, keeps a
    definition that was merely re-indented or moved from being reported,
    and catches a rename in which only the name line changed.

    Args:
        changed_files: Paths changed between base and head. Only the first
            one ending in ``shared_library_init.c`` is inspected; nothing
            happens when there is none.
    """
    c_files = [f for f in changed_files if f.endswith("shared_library_init.c")]
    if not c_files:
        return

    file_path = c_files[0]  # There should be only one shared_library_init.c

    guc_pattern = re.compile(r'DefineCustom\w+Variable\s*\(\s*"([^"]+)"')

    old_side = []
    new_side = []
    in_hunk = False
    for line in self.get_file_diff(file_path).splitlines():
        if line.startswith("@@"):
            in_hunk = True
            # Separator so a call can never be stitched together from the
            # end of one hunk and the start of the next.
            old_side.append(";")
            new_side.append(";")
            continue
        # Everything before the first hunk is the diff/file header.
        if not in_hunk or not line:
            continue
        marker, text = line[0], line[1:]
        if marker in " -":
            old_side.append(text)
        if marker in " +":
            new_side.append(text)

    still_defined = set(guc_pattern.findall("\n".join(new_side)))
    reported = set()
    for guc_name in guc_pattern.findall("\n".join(old_side)):
        if guc_name in still_defined or guc_name in reported:
            continue
        reported.add(guc_name)
        self.results.append(
            {
                "type": "Configuration",
                "description": f"GUC change: {guc_name}",
                "file": file_path,
                "details": "GUC removed",
            }
        )
def run_checks(self): def run_checks(self):
"""Run all compatibility checks""" """Run all compatibility checks"""
@ -267,7 +305,7 @@ class CompatibilityChecker:
self.check_guc_changes(changed_files) self.check_guc_changes(changed_files)
# Write results to file for GitHub Action to read # Write results to file for GitHub Action to read
with open('/tmp/compat-results.json', 'w') as f: with open("/tmp/compat-results.json", "w") as f:
json.dump(self.results, f, indent=2) json.dump(self.results, f, indent=2)
# Print summary # Print summary
@ -280,6 +318,7 @@ class CompatibilityChecker:
else: else:
print("\n No backward compatibility issues detected") print("\n No backward compatibility issues detected")
if __name__ == "__main__":
    # Script entry point: build the checker and run every check.
    checker = CompatibilityChecker()
    checker.run_checks()