format include grouping script

pull/7326/head
Nils Dijk 2023-11-07 12:18:07 +00:00
parent 59eb067bb0
commit a16b604843
No known key found for this signature in database
GPG Key ID: CA1177EF9434F241
1 changed files with 56 additions and 52 deletions

View File

@ -6,6 +6,7 @@
import os import os
import sys import sys
def main(args): def main(args):
if len(args) < 2: if len(args) < 2:
print("Usage: include-grouping.py <file>") print("Usage: include-grouping.py <file>")
@ -16,8 +17,8 @@ def main(args):
print("File does not exist", file=sys.stderr) print("File does not exist", file=sys.stderr)
return sys.exit(1) return sys.exit(1)
with open(file, 'r') as f: with open(file, "r") as f:
with open(file + ".tmp", 'w') as out_file: with open(file + ".tmp", "w") as out_file:
lines = f.readlines() lines = f.readlines()
includes = [] includes = []
skipped_lines = [] skipped_lines = []
@ -36,16 +37,17 @@ def main(args):
# print skipped lines # print skipped lines
for skipped_line in skipped_lines: for skipped_line in skipped_lines:
print(skipped_line, end='', file=out_file) print(skipped_line, end="", file=out_file)
skipped_lines = [] skipped_lines = []
print(line, end='', file=out_file) print(line, end="", file=out_file)
# move out_file to file # move out_file to file
os.rename(file + ".tmp", file) os.rename(file + ".tmp", file)
pass pass
def print_sorted_includes(includes, file=sys.stdout): def print_sorted_includes(includes, file=sys.stdout):
default_group_key = 1 default_group_key = 1
groups = {} groups = {}
@ -53,60 +55,61 @@ def print_sorted_includes(includes, file=sys.stdout):
matches = [ matches = [
{ {
"name": "system includes", "name": "system includes",
"matcher": lambda x: x.startswith('<'), "matcher": lambda x: x.startswith("<"),
"group_key": -2, "group_key": -2,
"priority": 0 "priority": 0,
}, },
{ {
"name": "naked postgres includes", "name": "naked postgres includes",
"matcher": lambda x: not '/' in x, "matcher": lambda x: not "/" in x,
"group_key": 0, "group_key": 0,
"priority": 9 "priority": 9,
}, },
{ {
"name": "postgres.h", "name": "postgres.h",
"list": ['"postgres.h"'], "list": ['"postgres.h"'],
"group_key": -1, "group_key": -1,
"priority": -1 "priority": -1,
}, },
{ {
"name": "naked citus inlcudes", "name": "naked citus inlcudes",
"list": ['"citus_version.h"', '"pg_version_compat.h"'], "list": ['"citus_version.h"', '"pg_version_compat.h"'],
"group_key": 3, "group_key": 3,
"priority": 0 "priority": 0,
}, },
{ {
"name": "positional citus includes", "name": "positional citus includes",
"list": ['"distributed/pg_version_constants.h"'], "list": ['"distributed/pg_version_constants.h"'],
"group_key": 4, "group_key": 4,
"priority": 0 "priority": 0,
}, },
{ {
"name": "columnar includes", "name": "columnar includes",
"matcher": lambda x: x.startswith('"columnar/'), "matcher": lambda x: x.startswith('"columnar/'),
"group_key": 5, "group_key": 5,
"priority": 1 "priority": 1,
}, },
{ {
"name": "distributed includes", "name": "distributed includes",
"matcher": lambda x: x.startswith('"distributed/'), "matcher": lambda x: x.startswith('"distributed/'),
"group_key": 6, "group_key": 6,
"priority": 1 "priority": 1,
}] },
]
matches.sort(key=lambda x: x["priority"]) matches.sort(key=lambda x: x["priority"])
common_system_include_error_prefixes = ['<nodes/', '<distributed/'] common_system_include_error_prefixes = ["<nodes/", "<distributed/"]
for include in includes: for include in includes:
# extract the group key from the include # extract the group key from the include
include_content = include.split(' ')[1] include_content = include.split(" ")[1]
# fix common system includes which are secretly postgres or citus includes # fix common system includes which are secretly postgres or citus includes
for common_prefix in common_system_include_error_prefixes: for common_prefix in common_system_include_error_prefixes:
if include_content.startswith(common_prefix): if include_content.startswith(common_prefix):
include_content = '"' + include_content.strip()[1:-1] + '"' include_content = '"' + include_content.strip()[1:-1] + '"'
include = include.split(' ')[0] + ' ' + include_content + '\n' include = include.split(" ")[0] + " " + include_content + "\n"
break break
group_key = default_group_key group_key = default_group_key
@ -140,8 +143,9 @@ def print_sorted_includes(includes, file=sys.stdout):
# remove duplicates # remove duplicates
if prev == include: if prev == include:
continue continue
print(include, end='', file=file) print(include, end="", file=file)
prev = include prev = include
if __name__ == '__main__':
if __name__ == "__main__":
main(sys.argv) main(sys.argv)