Add auto-formatting and linting to our python code (#6700)

We're getting more and more Python code in the repo. This adds some tooling to
make sure that styling stays consistent and that we catch easy-to-miss
mistakes; a short before/after sketch follows the commit metadata below.

- Format python files with black
- Run python files through isort
- Fix issues reported by flake8
- Add .venv to gitignore
Jelte Fennema 2023-02-10 13:25:44 +01:00 committed by GitHub
commit dd51938f20
19 changed files with 823 additions and 532 deletions
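
For readers unfamiliar with these tools, here is a minimal sketch (hypothetical code, not taken from the diff below) of the kind of cleanup black and isort perform:

    # Before (unsorted imports, inconsistent quoting and spacing):
    #
    #     import sys
    #     import random
    #     import os
    #     def pick(items, default = 'none'):
    #         return random.choice(items) if items else default
    #
    # After black + isort (sorted imports, double quotes, normalized spacing):
    import os
    import random
    import sys


    def pick(items, default="none"):
        return random.choice(items) if items else default


    print(pick(sys.argv[1:], default=os.sep))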


@@ -6,7 +6,7 @@ orbs:
parameters:
image_suffix:
type: string
default: '-v7e4468f'
default: '-v9a494cd'
pg13_version:
type: string
default: '13.9'
@@ -157,8 +157,17 @@ jobs:
steps:
- checkout
- run:
name: 'Check Style'
name: 'Check C Style'
command: citus_indent --check
- run:
name: 'Check Python style'
command: black --check .
- run:
name: 'Check Python import order'
command: isort --check .
- run:
name: 'Check Python lints'
command: flake8 .
- run:
name: 'Fix whitespace'
command: ci/editorconfig.sh && git diff --exit-code
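
These three Python checks can also be run locally before pushing. A minimal sketch, assuming black, isort, and flake8 are installed in the active environment:

    # Run the same Python style checks CI runs, stopping at the first failure.
    import subprocess
    import sys

    CHECKS = [
        ["black", "--check", "."],
        ["isort", "--check", "."],
        ["flake8", "."],
    ]

    for cmd in CHECKS:
        print("running:", " ".join(cmd))
        result = subprocess.run(cmd)  # tool output goes straight to the terminal
        if result.returncode != 0:
            sys.exit(result.returncode)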

.flake8 (new file, +7)

@@ -0,0 +1,7 @@
[flake8]
# E203 is ignored for black
# E402 is ignored because of the way we do relative imports
extend-ignore = E203, E402
# black usually wraps lines at 88 characters, but it may keep long string
# literals intact. That's fine in most cases unless it gets really excessive.
max-line-length = 150
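
To illustrate the two ignored codes, a small sketch (hypothetical code, not from the repo):

    # E203 ("whitespace before ':'"): black formats complex slices with spaces
    # around the colon, which E203 would otherwise flag.
    def tail(items, offset):
        return items[offset + 1 :]


    # E402 ("module level import not at top of file"): the test scripts below
    # extend sys.path before importing sibling modules, so some imports
    # legitimately appear after executable code.
    import os
    import sys

    sys.path.append(os.path.dirname(os.path.abspath(__file__)))

    print(tail([1, 2, 3, 4], 1))  # [3, 4]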

.gitignore (+1)

@@ -51,3 +51,4 @@ lib*.pc
# style related temporary outputs
*.uncrustify
.venv


@@ -61,6 +61,9 @@ clean: clean-extension clean-pg_send_cancellation
reindent:
${citus_abs_top_srcdir}/ci/fix_style.sh
check-style:
black . --check --quiet
isort . --check --quiet
flake8
cd ${citus_abs_top_srcdir} && citus_indent --quiet --check
.PHONY: reindent check-style

View File

@@ -9,6 +9,8 @@ cidir="${0%/*}"
cd ${cidir}/..
citus_indent . --quiet
black . --quiet
isort . --quiet
ci/editorconfig.sh
ci/remove_useless_declarations.sh
ci/disallow_c_comments_in_migrations.sh

pyproject.toml (new file, +5)

@@ -0,0 +1,5 @@
[tool.isort]
profile = 'black'
[tool.black]
include = '(src/test/regress/bin/diff-filter|\.pyi?|\.ipynb)$'
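
The 'black' profile makes isort emit output that black will not re-wrap. A rough sketch of the behavior (the third-party and first-party lines are illustrative comments, since those modules aren't importable here):

    # isort groups imports into sections and alphabetizes within each, which is
    # exactly the reshuffling visible throughout this diff:
    import os   # stdlib section, sorted alphabetically
    import sys

    # import docopt        # third-party section would follow, after a blank line
    #
    # import common        # then first-party modules such as common and utils

    print(os.pathsep, sys.platform)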


@@ -10,6 +10,10 @@ docopt = "==0.6.2"
cryptography = "==3.4.8"
[dev-packages]
black = "*"
isort = "*"
flake8 = "*"
flake8-bugbear = "*"
[requires]
python_version = "3.9"
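
flake8-bugbear layers extra bug-oriented checks on top of flake8. A classic one is B006, which flags mutable default arguments; a hypothetical sketch:

    # B006: a mutable default argument is created once and shared across calls.
    def append_bad(item, bucket=[]):  # flake8-bugbear flags this line
        bucket.append(item)
        return bucket


    def append_good(item, bucket=None):
        if bucket is None:  # fresh list per call
            bucket = []
        bucket.append(item)
        return bucket


    print(append_bad(1), append_bad(2))    # [1, 2] [1, 2] -- shared state!
    print(append_good(1), append_good(2))  # [1] [2]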


@@ -1,7 +1,7 @@
{
"_meta": {
"hash": {
"sha256": "09fefbec76b9344107dfed06002546b50adca38da844b604fe8581a7a14fd656"
"sha256": "635b4c111e3bca87373fcdf308febf0a816dde15b14f6bf078f2b456630e5ef1"
},
"pipfile-spec": 6,
"requires": {
@@ -122,7 +122,7 @@
"sha256:35824b4c3a97115964b408844d64aa14db1cc518f6562e8d7261699d1350a9e3",
"sha256:4ad3232f5e926d6718ec31cfc1fcadfde020920e278684144551c91769c7bc18"
],
"index": "pypi",
"markers": "python_version >= '3.6'",
"version": "==2022.12.7"
},
"cffi": {
@@ -315,49 +315,59 @@
},
"markupsafe": {
"hashes": [
"sha256:0212a68688482dc52b2d45013df70d169f542b7394fc744c02a57374a4207003",
"sha256:089cf3dbf0cd6c100f02945abeb18484bd1ee57a079aefd52cffd17fba910b88",
"sha256:10c1bfff05d95783da83491be968e8fe789263689c02724e0c691933c52994f5",
"sha256:33b74d289bd2f5e527beadcaa3f401e0df0a89927c1559c8566c066fa4248ab7",
"sha256:3799351e2336dc91ea70b034983ee71cf2f9533cdff7c14c90ea126bfd95d65a",
"sha256:3ce11ee3f23f79dbd06fb3d63e2f6af7b12db1d46932fe7bd8afa259a5996603",
"sha256:421be9fbf0ffe9ffd7a378aafebbf6f4602d564d34be190fc19a193232fd12b1",
"sha256:43093fb83d8343aac0b1baa75516da6092f58f41200907ef92448ecab8825135",
"sha256:46d00d6cfecdde84d40e572d63735ef81423ad31184100411e6e3388d405e247",
"sha256:4a33dea2b688b3190ee12bd7cfa29d39c9ed176bda40bfa11099a3ce5d3a7ac6",
"sha256:4b9fe39a2ccc108a4accc2676e77da025ce383c108593d65cc909add5c3bd601",
"sha256:56442863ed2b06d19c37f94d999035e15ee982988920e12a5b4ba29b62ad1f77",
"sha256:671cd1187ed5e62818414afe79ed29da836dde67166a9fac6d435873c44fdd02",
"sha256:694deca8d702d5db21ec83983ce0bb4b26a578e71fbdbd4fdcd387daa90e4d5e",
"sha256:6a074d34ee7a5ce3effbc526b7083ec9731bb3cbf921bbe1d3005d4d2bdb3a63",
"sha256:6d0072fea50feec76a4c418096652f2c3238eaa014b2f94aeb1d56a66b41403f",
"sha256:6fbf47b5d3728c6aea2abb0589b5d30459e369baa772e0f37a0320185e87c980",
"sha256:7f91197cc9e48f989d12e4e6fbc46495c446636dfc81b9ccf50bb0ec74b91d4b",
"sha256:86b1f75c4e7c2ac2ccdaec2b9022845dbb81880ca318bb7a0a01fbf7813e3812",
"sha256:8dc1c72a69aa7e082593c4a203dcf94ddb74bb5c8a731e4e1eb68d031e8498ff",
"sha256:8e3dcf21f367459434c18e71b2a9532d96547aef8a871872a5bd69a715c15f96",
"sha256:8e576a51ad59e4bfaac456023a78f6b5e6e7651dcd383bcc3e18d06f9b55d6d1",
"sha256:96e37a3dc86e80bf81758c152fe66dbf60ed5eca3d26305edf01892257049925",
"sha256:97a68e6ada378df82bc9f16b800ab77cbf4b2fada0081794318520138c088e4a",
"sha256:99a2a507ed3ac881b975a2976d59f38c19386d128e7a9a18b7df6fff1fd4c1d6",
"sha256:a49907dd8420c5685cfa064a1335b6754b74541bbb3706c259c02ed65b644b3e",
"sha256:b09bf97215625a311f669476f44b8b318b075847b49316d3e28c08e41a7a573f",
"sha256:b7bd98b796e2b6553da7225aeb61f447f80a1ca64f41d83612e6139ca5213aa4",
"sha256:b87db4360013327109564f0e591bd2a3b318547bcef31b468a92ee504d07ae4f",
"sha256:bcb3ed405ed3222f9904899563d6fc492ff75cce56cba05e32eff40e6acbeaa3",
"sha256:d4306c36ca495956b6d568d276ac11fdd9c30a36f1b6eb928070dc5360b22e1c",
"sha256:d5ee4f386140395a2c818d149221149c54849dfcfcb9f1debfe07a8b8bd63f9a",
"sha256:dda30ba7e87fbbb7eab1ec9f58678558fd9a6b8b853530e176eabd064da81417",
"sha256:e04e26803c9c3851c931eac40c695602c6295b8d432cbe78609649ad9bd2da8a",
"sha256:e1c0b87e09fa55a220f058d1d49d3fb8df88fbfab58558f1198e08c1e1de842a",
"sha256:e72591e9ecd94d7feb70c1cbd7be7b3ebea3f548870aa91e2732960fa4d57a37",
"sha256:e8c843bbcda3a2f1e3c2ab25913c80a3c5376cd00c6e8c4a86a89a28c8dc5452",
"sha256:efc1913fd2ca4f334418481c7e595c00aad186563bbc1ec76067848c7ca0a933",
"sha256:f121a1420d4e173a5d96e47e9a0c0dcff965afdf1626d28de1460815f7c4ee7a",
"sha256:fc7b548b17d238737688817ab67deebb30e8073c95749d55538ed473130ec0c7"
"sha256:0576fe974b40a400449768941d5d0858cc624e3249dfd1e0c33674e5c7ca7aed",
"sha256:085fd3201e7b12809f9e6e9bc1e5c96a368c8523fad5afb02afe3c051ae4afcc",
"sha256:090376d812fb6ac5f171e5938e82e7f2d7adc2b629101cec0db8b267815c85e2",
"sha256:0b462104ba25f1ac006fdab8b6a01ebbfbce9ed37fd37fd4acd70c67c973e460",
"sha256:137678c63c977754abe9086a3ec011e8fd985ab90631145dfb9294ad09c102a7",
"sha256:1bea30e9bf331f3fef67e0a3877b2288593c98a21ccb2cf29b74c581a4eb3af0",
"sha256:22152d00bf4a9c7c83960521fc558f55a1adbc0631fbb00a9471e097b19d72e1",
"sha256:22731d79ed2eb25059ae3df1dfc9cb1546691cc41f4e3130fe6bfbc3ecbbecfa",
"sha256:2298c859cfc5463f1b64bd55cb3e602528db6fa0f3cfd568d3605c50678f8f03",
"sha256:28057e985dace2f478e042eaa15606c7efccb700797660629da387eb289b9323",
"sha256:2e7821bffe00aa6bd07a23913b7f4e01328c3d5cc0b40b36c0bd81d362faeb65",
"sha256:2ec4f2d48ae59bbb9d1f9d7efb9236ab81429a764dedca114f5fdabbc3788013",
"sha256:340bea174e9761308703ae988e982005aedf427de816d1afe98147668cc03036",
"sha256:40627dcf047dadb22cd25ea7ecfe9cbf3bbbad0482ee5920b582f3809c97654f",
"sha256:40dfd3fefbef579ee058f139733ac336312663c6706d1163b82b3003fb1925c4",
"sha256:4cf06cdc1dda95223e9d2d3c58d3b178aa5dacb35ee7e3bbac10e4e1faacb419",
"sha256:50c42830a633fa0cf9e7d27664637532791bfc31c731a87b202d2d8ac40c3ea2",
"sha256:55f44b440d491028addb3b88f72207d71eeebfb7b5dbf0643f7c023ae1fba619",
"sha256:608e7073dfa9e38a85d38474c082d4281f4ce276ac0010224eaba11e929dd53a",
"sha256:63ba06c9941e46fa389d389644e2d8225e0e3e5ebcc4ff1ea8506dce646f8c8a",
"sha256:65608c35bfb8a76763f37036547f7adfd09270fbdbf96608be2bead319728fcd",
"sha256:665a36ae6f8f20a4676b53224e33d456a6f5a72657d9c83c2aa00765072f31f7",
"sha256:6d6607f98fcf17e534162f0709aaad3ab7a96032723d8ac8750ffe17ae5a0666",
"sha256:7313ce6a199651c4ed9d7e4cfb4aa56fe923b1adf9af3b420ee14e6d9a73df65",
"sha256:7668b52e102d0ed87cb082380a7e2e1e78737ddecdde129acadb0eccc5423859",
"sha256:7df70907e00c970c60b9ef2938d894a9381f38e6b9db73c5be35e59d92e06625",
"sha256:7e007132af78ea9df29495dbf7b5824cb71648d7133cf7848a2a5dd00d36f9ff",
"sha256:835fb5e38fd89328e9c81067fd642b3593c33e1e17e2fdbf77f5676abb14a156",
"sha256:8bca7e26c1dd751236cfb0c6c72d4ad61d986e9a41bbf76cb445f69488b2a2bd",
"sha256:8db032bf0ce9022a8e41a22598eefc802314e81b879ae093f36ce9ddf39ab1ba",
"sha256:99625a92da8229df6d44335e6fcc558a5037dd0a760e11d84be2260e6f37002f",
"sha256:9cad97ab29dfc3f0249b483412c85c8ef4766d96cdf9dcf5a1e3caa3f3661cf1",
"sha256:a4abaec6ca3ad8660690236d11bfe28dfd707778e2442b45addd2f086d6ef094",
"sha256:a6e40afa7f45939ca356f348c8e23048e02cb109ced1eb8420961b2f40fb373a",
"sha256:a6f2fcca746e8d5910e18782f976489939d54a91f9411c32051b4aab2bd7c513",
"sha256:a806db027852538d2ad7555b203300173dd1b77ba116de92da9afbc3a3be3eed",
"sha256:abcabc8c2b26036d62d4c746381a6f7cf60aafcc653198ad678306986b09450d",
"sha256:b8526c6d437855442cdd3d87eede9c425c4445ea011ca38d937db299382e6fa3",
"sha256:bb06feb762bade6bf3c8b844462274db0c76acc95c52abe8dbed28ae3d44a147",
"sha256:c0a33bc9f02c2b17c3ea382f91b4db0e6cde90b63b296422a939886a7a80de1c",
"sha256:c4a549890a45f57f1ebf99c067a4ad0cb423a05544accaf2b065246827ed9603",
"sha256:ca244fa73f50a800cf8c3ebf7fd93149ec37f5cb9596aa8873ae2c1d23498601",
"sha256:cf877ab4ed6e302ec1d04952ca358b381a882fbd9d1b07cccbfd61783561f98a",
"sha256:d9d971ec1e79906046aa3ca266de79eac42f1dbf3612a05dc9368125952bd1a1",
"sha256:da25303d91526aac3672ee6d49a2f3db2d9502a4a60b55519feb1a4c7714e07d",
"sha256:e55e40ff0cc8cc5c07996915ad367fa47da6b3fc091fdadca7f5403239c5fec3",
"sha256:f03a532d7dee1bed20bc4884194a16160a2de9ffc6354b3878ec9682bb623c54",
"sha256:f1cd098434e83e656abf198f103a8207a8187c0fc110306691a2e94a78d0abb2",
"sha256:f2bfb563d0211ce16b63c7cb9395d2c682a23187f54c3d79bfec33e6705473c6",
"sha256:f8ffb705ffcf5ddd0e80b65ddf7bed7ee4f5a441ea7d3419e861a12eaf41af58"
],
"markers": "python_version >= '3.7'",
"version": "==2.1.1"
"version": "==2.1.2"
},
"mitmproxy": {
"editable": true,
@@ -518,46 +528,6 @@
"markers": "python_version >= '3'",
"version": "==0.17.16"
},
"ruamel.yaml.clib": {
"hashes": [
"sha256:045e0626baf1c52e5527bd5db361bc83180faaba2ff586e763d3d5982a876a9e",
"sha256:15910ef4f3e537eea7fe45f8a5d19997479940d9196f357152a09031c5be59f3",
"sha256:184faeaec61dbaa3cace407cffc5819f7b977e75360e8d5ca19461cd851a5fc5",
"sha256:1f08fd5a2bea9c4180db71678e850b995d2a5f4537be0e94557668cf0f5f9497",
"sha256:2aa261c29a5545adfef9296b7e33941f46aa5bbd21164228e833412af4c9c75f",
"sha256:3110a99e0f94a4a3470ff67fc20d3f96c25b13d24c6980ff841e82bafe827cac",
"sha256:3243f48ecd450eddadc2d11b5feb08aca941b5cd98c9b1db14b2fd128be8c697",
"sha256:370445fd795706fd291ab00c9df38a0caed0f17a6fb46b0f607668ecb16ce763",
"sha256:40d030e2329ce5286d6b231b8726959ebbe0404c92f0a578c0e2482182e38282",
"sha256:41d0f1fa4c6830176eef5b276af04c89320ea616655d01327d5ce65e50575c94",
"sha256:4a4d8d417868d68b979076a9be6a38c676eca060785abaa6709c7b31593c35d1",
"sha256:4b3a93bb9bc662fc1f99c5c3ea8e623d8b23ad22f861eb6fce9377ac07ad6072",
"sha256:5bc0667c1eb8f83a3752b71b9c4ba55ef7c7058ae57022dd9b29065186a113d9",
"sha256:721bc4ba4525f53f6a611ec0967bdcee61b31df5a56801281027a3a6d1c2daf5",
"sha256:763d65baa3b952479c4e972669f679fe490eee058d5aa85da483ebae2009d231",
"sha256:7bdb4c06b063f6fd55e472e201317a3bb6cdeeee5d5a38512ea5c01e1acbdd93",
"sha256:8831a2cedcd0f0927f788c5bdf6567d9dc9cc235646a434986a852af1cb54b4b",
"sha256:91a789b4aa0097b78c93e3dc4b40040ba55bef518f84a40d4442f713b4094acb",
"sha256:92460ce908546ab69770b2e576e4f99fbb4ce6ab4b245345a3869a0a0410488f",
"sha256:99e77daab5d13a48a4054803d052ff40780278240a902b880dd37a51ba01a307",
"sha256:a234a20ae07e8469da311e182e70ef6b199d0fbeb6c6cc2901204dd87fb867e8",
"sha256:a7b301ff08055d73223058b5c46c55638917f04d21577c95e00e0c4d79201a6b",
"sha256:be2a7ad8fd8f7442b24323d24ba0b56c51219513cfa45b9ada3b87b76c374d4b",
"sha256:bf9a6bc4a0221538b1a7de3ed7bca4c93c02346853f44e1cd764be0023cd3640",
"sha256:c3ca1fbba4ae962521e5eb66d72998b51f0f4d0f608d3c0347a48e1af262efa7",
"sha256:d000f258cf42fec2b1bbf2863c61d7b8918d31ffee905da62dede869254d3b8a",
"sha256:d5859983f26d8cd7bb5c287ef452e8aacc86501487634573d260968f753e1d71",
"sha256:d5e51e2901ec2366b79f16c2299a03e74ba4531ddcfacc1416639c557aef0ad8",
"sha256:debc87a9516b237d0466a711b18b6ebeb17ba9f391eb7f91c649c5c4ec5006c7",
"sha256:df5828871e6648db72d1c19b4bd24819b80a755c4541d3409f0f7acd0f335c80",
"sha256:ecdf1a604009bd35c674b9225a8fa609e0282d9b896c03dd441a91e5f53b534e",
"sha256:efa08d63ef03d079dcae1dfe334f6c8847ba8b645d08df286358b1f5293d24ab",
"sha256:f01da5790e95815eb5a8a138508c01c758e5f5bc0ce4286c4f7028b8dd7ac3d0",
"sha256:f34019dced51047d6f70cb9383b2ae2853b7fc4dce65129a5acd49f4f9256646"
],
"markers": "python_version < '3.10' and platform_python_implementation == 'CPython'",
"version": "==0.2.7"
},
"six": {
"hashes": [
"sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926",
@@ -667,5 +637,141 @@
"version": "==0.15.2"
}
},
"develop": {}
"develop": {
"attrs": {
"hashes": [
"sha256:29e95c7f6778868dbd49170f98f8818f78f3dc5e0e37c0b1f474e3561b240836",
"sha256:c9227bfc2f01993c03f68db37d1d15c9690188323c067c641f1a35ca58185f99"
],
"markers": "python_version >= '3.6'",
"version": "==22.2.0"
},
"black": {
"hashes": [
"sha256:0052dba51dec07ed029ed61b18183942043e00008ec65d5028814afaab9a22fd",
"sha256:0680d4380db3719ebcfb2613f34e86c8e6d15ffeabcf8ec59355c5e7b85bb555",
"sha256:121ca7f10b4a01fd99951234abdbd97728e1240be89fde18480ffac16503d481",
"sha256:162e37d49e93bd6eb6f1afc3e17a3d23a823042530c37c3c42eeeaf026f38468",
"sha256:2a951cc83ab535d248c89f300eccbd625e80ab880fbcfb5ac8afb5f01a258ac9",
"sha256:2bf649fda611c8550ca9d7592b69f0637218c2369b7744694c5e4902873b2f3a",
"sha256:382998821f58e5c8238d3166c492139573325287820963d2f7de4d518bd76958",
"sha256:49f7b39e30f326a34b5c9a4213213a6b221d7ae9d58ec70df1c4a307cf2a1580",
"sha256:57c18c5165c1dbe291d5306e53fb3988122890e57bd9b3dcb75f967f13411a26",
"sha256:7a0f701d314cfa0896b9001df70a530eb2472babb76086344e688829efd97d32",
"sha256:8178318cb74f98bc571eef19068f6ab5613b3e59d4f47771582f04e175570ed8",
"sha256:8b70eb40a78dfac24842458476135f9b99ab952dd3f2dab738c1881a9b38b753",
"sha256:9880d7d419bb7e709b37e28deb5e68a49227713b623c72b2b931028ea65f619b",
"sha256:9afd3f493666a0cd8f8df9a0200c6359ac53940cbde049dcb1a7eb6ee2dd7074",
"sha256:a29650759a6a0944e7cca036674655c2f0f63806ddecc45ed40b7b8aa314b651",
"sha256:a436e7881d33acaf2536c46a454bb964a50eff59b21b51c6ccf5a40601fbef24",
"sha256:a59db0a2094d2259c554676403fa2fac3473ccf1354c1c63eccf7ae65aac8ab6",
"sha256:a8471939da5e824b891b25751955be52ee7f8a30a916d570a5ba8e0f2eb2ecad",
"sha256:b0bd97bea8903f5a2ba7219257a44e3f1f9d00073d6cc1add68f0beec69692ac",
"sha256:b6a92a41ee34b883b359998f0c8e6eb8e99803aa8bf3123bf2b2e6fec505a221",
"sha256:bb460c8561c8c1bec7824ecbc3ce085eb50005883a6203dcfb0122e95797ee06",
"sha256:bfffba28dc52a58f04492181392ee380e95262af14ee01d4bc7bb1b1c6ca8d27",
"sha256:c1c476bc7b7d021321e7d93dc2cbd78ce103b84d5a4cf97ed535fbc0d6660648",
"sha256:c91dfc2c2a4e50df0026f88d2215e166616e0c80e86004d0003ece0488db2739",
"sha256:e6663f91b6feca5d06f2ccd49a10f254f9298cc1f7f49c46e498a0771b507104"
],
"index": "pypi",
"version": "==23.1.0"
},
"click": {
"hashes": [
"sha256:6a7a62563bbfabfda3a38f3023a1db4a35978c0abd76f6c9605ecd6554d6d9b1",
"sha256:8458d7b1287c5fb128c90e23381cf99dcde74beaf6c7ff6384ce84d6fe090adb"
],
"markers": "python_version >= '3.6'",
"version": "==8.0.4"
},
"flake8": {
"hashes": [
"sha256:3833794e27ff64ea4e9cf5d410082a8b97ff1a06c16aa3d2027339cd0f1195c7",
"sha256:c61007e76655af75e6785a931f452915b371dc48f56efd765247c8fe68f2b181"
],
"index": "pypi",
"version": "==6.0.0"
},
"flake8-bugbear": {
"hashes": [
"sha256:04a115e5f9c8e87c38bdbbcdf9f58223ffe05469c07c9a7bd8633330bc4d078b",
"sha256:55902ab5a48c5ea53d8689ecd146eda548e72f2724192b9c1d68f6d975d13c06"
],
"index": "pypi",
"version": "==23.1.20"
},
"isort": {
"hashes": [
"sha256:8bef7dde241278824a6d83f44a544709b065191b95b6e50894bdc722fcba0504",
"sha256:f84c2818376e66cf843d497486ea8fed8700b340f308f076c6fb1229dff318b6"
],
"index": "pypi",
"version": "==5.12.0"
},
"mccabe": {
"hashes": [
"sha256:348e0240c33b60bbdf4e523192ef919f28cb2c3d7d5c7794f74009290f236325",
"sha256:6c2d30ab6be0e4a46919781807b4f0d834ebdd6c6e3dca0bda5a15f863427b6e"
],
"markers": "python_version >= '3.6'",
"version": "==0.7.0"
},
"mypy-extensions": {
"hashes": [
"sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d",
"sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"
],
"markers": "python_version >= '3.5'",
"version": "==1.0.0"
},
"packaging": {
"hashes": [
"sha256:714ac14496c3e68c99c29b00845f7a2b85f3bb6f1078fd9f72fd20f0570002b2",
"sha256:b6ad297f8907de0fa2fe1ccbd26fdaf387f5f47c7275fedf8cce89f99446cf97"
],
"markers": "python_version >= '3.7'",
"version": "==23.0"
},
"pathspec": {
"hashes": [
"sha256:3a66eb970cbac598f9e5ccb5b2cf58930cd8e3ed86d393d541eaf2d8b1705229",
"sha256:64d338d4e0914e91c1792321e6907b5a593f1ab1851de7fc269557a21b30ebbc"
],
"markers": "python_version >= '3.7'",
"version": "==0.11.0"
},
"platformdirs": {
"hashes": [
"sha256:8a1228abb1ef82d788f74139988b137e78692984ec7b08eaa6c65f1723af28f9",
"sha256:b1d5eb14f221506f50d6604a561f4c5786d9e80355219694a1b244bcd96f4567"
],
"markers": "python_version >= '3.7'",
"version": "==3.0.0"
},
"pycodestyle": {
"hashes": [
"sha256:347187bdb476329d98f695c213d7295a846d1152ff4fe9bacb8a9590b8ee7053",
"sha256:8a4eaf0d0495c7395bdab3589ac2db602797d76207242c17d470186815706610"
],
"markers": "python_version >= '3.6'",
"version": "==2.10.0"
},
"pyflakes": {
"hashes": [
"sha256:ec55bf7fe21fff7f1ad2f7da62363d749e2a470500eab1b555334b67aa1ef8cf",
"sha256:ec8b276a6b60bd80defed25add7e439881c19e64850afd9b346283d4165fd0fd"
],
"markers": "python_version >= '3.6'",
"version": "==3.0.1"
},
"tomli": {
"hashes": [
"sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc",
"sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"
],
"markers": "python_version < '3.11'",
"version": "==2.0.1"
}
}
}


@@ -1,8 +1,8 @@
#!/usr/bin/env python3
import sys
import random
import os
import random
import sys
if len(sys.argv) != 2:
print(


@@ -5,69 +5,72 @@ diff-filter denormalizes diff output by having lines beginning with ' ' or '+'
come from file2's unmodified version.
"""
from sys import argv, stdin, stdout
import re
from sys import argv, stdin, stdout
class FileScanner:
"""
FileScanner is an iterator over the lines of a file.
It can apply a rewrite rule which can be used to skip lines.
"""
def __init__(self, file, rewrite = lambda x:x):
self.file = file
self.line = 1
self.rewrite = rewrite
"""
FileScanner is an iterator over the lines of a file.
It can apply a rewrite rule which can be used to skip lines.
"""
def __next__(self):
while True:
nextline = self.rewrite(next(self.file))
if nextline is not None:
self.line += 1
return nextline
def __init__(self, file, rewrite=lambda x: x):
self.file = file
self.line = 1
self.rewrite = rewrite
def __next__(self):
while True:
nextline = self.rewrite(next(self.file))
if nextline is not None:
self.line += 1
return nextline
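
As a usage sketch (assuming the FileScanner class above is in scope): the rewrite rule returns None to skip a line and a string to keep it.

    import io

    def drop_comments(line):
        return None if line.startswith("#") else line

    scanner = FileScanner(io.StringIO("# header\nkeep me\n"), rewrite=drop_comments)
    print(next(scanner), end="")  # keep me
    print(scanner.line)           # 2: the counter advances per yielded line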
def main():
# we only test //d rules, as we need to ignore those lines
regexregex = re.compile(r"^/(?P<rule>.*)/d$")
regexpipeline = []
for line in open(argv[1]):
line = line.strip()
if not line or line.startswith('#') or not line.endswith('d'):
continue
rule = regexregex.match(line)
if not rule:
raise 'Failed to parse regex rule: %s' % line
regexpipeline.append(re.compile(rule.group('rule')))
# we only test //d rules, as we need to ignore those lines
regexregex = re.compile(r"^/(?P<rule>.*)/d$")
regexpipeline = []
for line in open(argv[1]):
line = line.strip()
if not line or line.startswith("#") or not line.endswith("d"):
continue
rule = regexregex.match(line)
if not rule:
raise "Failed to parse regex rule: %s" % line
regexpipeline.append(re.compile(rule.group("rule")))
def sed(line):
if any(regex.search(line) for regex in regexpipeline):
return None
return line
def sed(line):
if any(regex.search(line) for regex in regexpipeline):
return None
return line
for line in stdin:
if line.startswith('+++ '):
tab = line.rindex('\t')
fname = line[4:tab]
file2 = FileScanner(open(fname.replace('.modified', ''), encoding='utf8'), sed)
stdout.write(line)
elif line.startswith('@@ '):
idx_start = line.index('+') + 1
idx_end = idx_start + 1
while line[idx_end].isdigit():
idx_end += 1
linenum = int(line[idx_start:idx_end])
while file2.line < linenum:
next(file2)
stdout.write(line)
elif line.startswith(' '):
stdout.write(' ')
stdout.write(next(file2))
elif line.startswith('+'):
stdout.write('+')
stdout.write(next(file2))
else:
stdout.write(line)
for line in stdin:
if line.startswith("+++ "):
tab = line.rindex("\t")
fname = line[4:tab]
file2 = FileScanner(
open(fname.replace(".modified", ""), encoding="utf8"), sed
)
stdout.write(line)
elif line.startswith("@@ "):
idx_start = line.index("+") + 1
idx_end = idx_start + 1
while line[idx_end].isdigit():
idx_end += 1
linenum = int(line[idx_start:idx_end])
while file2.line < linenum:
next(file2)
stdout.write(line)
elif line.startswith(" "):
stdout.write(" ")
stdout.write(next(file2))
elif line.startswith("+"):
stdout.write("+")
stdout.write(next(file2))
else:
stdout.write(line)
main()
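
The rule parser in main() only understands sed-style deletion rules of the form /pattern/d. A standalone sketch of what regexregex extracts (the example rule is hypothetical):

    import re

    regexregex = re.compile(r"^/(?P<rule>.*)/d$")
    match = regexregex.match("/^Sort Method:/d")  # a typical normalization rule
    print(match.group("rule"))  # ^Sort Method: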


@@ -12,21 +12,23 @@ Options:
--seed=<seed> random number seed
--base whether to use the base sql schedule or not
"""
import os
import shutil
import sys
import os, shutil
# https://stackoverflow.com/questions/14132789/relative-imports-for-the-billionth-time/14132912#14132912
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import common
import config as cfg
import concurrent.futures
import multiprocessing
from docopt import docopt
import time
import random
import time
import common
from docopt import docopt
import config as cfg
testResults = {}
parallel_thread_amount = 1
@@ -115,7 +117,6 @@ def copy_copy_modified_binary(datadir):
def copy_test_files(config):
sql_dir_path = os.path.join(config.datadir, "sql")
expected_dir_path = os.path.join(config.datadir, "expected")
@@ -132,7 +133,9 @@ def copy_test_files(config):
line = line[colon_index + 1 :].strip()
test_names = line.split(" ")
copy_test_files_with_names(test_names, sql_dir_path, expected_dir_path, config)
copy_test_files_with_names(
test_names, sql_dir_path, expected_dir_path, config
)
def copy_test_files_with_names(test_names, sql_dir_path, expected_dir_path, config):
@@ -140,10 +143,10 @@ def copy_test_files_with_names(test_names, sql_dir_path, expected_dir_path, conf
# make empty files for the skipped tests
if test_name in config.skip_tests:
expected_sql_file = os.path.join(sql_dir_path, test_name + ".sql")
open(expected_sql_file, 'x').close()
open(expected_sql_file, "x").close()
expected_out_file = os.path.join(expected_dir_path, test_name + ".out")
open(expected_out_file, 'x').close()
open(expected_out_file, "x").close()
continue


@@ -1,12 +1,12 @@
import os
import shutil
import sys
import subprocess
import atexit
import concurrent.futures
import os
import shutil
import subprocess
import sys
import utils
from utils import USER, cd
from utils import USER
def initialize_temp_dir(temp_dir):
@@ -27,13 +27,11 @@ def initialize_temp_dir_if_not_exists(temp_dir):
def parallel_run(function, items, *args, **kwargs):
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [
executor.submit(function, item, *args, **kwargs)
for item in items
]
futures = [executor.submit(function, item, *args, **kwargs) for item in items]
for future in futures:
future.result()
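
parallel_run fans a function out over items on a thread pool and propagates the first failure through future.result(). A self-contained sketch reusing the same definition:

    import concurrent.futures

    def parallel_run(function, items, *args, **kwargs):
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = [executor.submit(function, item, *args, **kwargs) for item in items]
            for future in futures:
                future.result()  # blocks, and re-raises any exception

    def greet(name, punctuation="!"):
        print("hello " + name + punctuation)

    parallel_run(greet, ["worker1", "worker2"], punctuation="?")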
def initialize_db_for_cluster(pg_path, rel_data_path, settings, node_names):
subprocess.run(["mkdir", rel_data_path], check=True)
@@ -52,7 +50,7 @@ def initialize_db_for_cluster(pg_path, rel_data_path, settings, node_names):
"--encoding",
"UTF8",
"--locale",
"POSIX"
"POSIX",
]
subprocess.run(command, check=True)
add_settings(abs_data_path, settings)
@@ -72,11 +70,16 @@ def add_settings(abs_data_path, settings):
def create_role(pg_path, node_ports, user_name):
def create(port):
command = "SET citus.enable_ddl_propagation TO OFF; SELECT worker_create_or_alter_role('{}', 'CREATE ROLE {} WITH LOGIN CREATEROLE CREATEDB;', NULL)".format(
user_name, user_name
command = (
"SET citus.enable_ddl_propagation TO OFF;"
+ "SELECT worker_create_or_alter_role('{}', 'CREATE ROLE {} WITH LOGIN CREATEROLE CREATEDB;', NULL)".format(
user_name, user_name
)
)
utils.psql(pg_path, port, command)
command = "SET citus.enable_ddl_propagation TO OFF; GRANT CREATE ON DATABASE postgres to {}".format(user_name)
command = "SET citus.enable_ddl_propagation TO OFF; GRANT CREATE ON DATABASE postgres to {}".format(
user_name
)
utils.psql(pg_path, port, command)
parallel_run(create, node_ports)
@@ -89,7 +92,9 @@ def coordinator_should_haveshards(pg_path, port):
utils.psql(pg_path, port, command)
def start_databases(pg_path, rel_data_path, node_name_to_ports, logfile_prefix, env_variables):
def start_databases(
pg_path, rel_data_path, node_name_to_ports, logfile_prefix, env_variables
):
def start(node_name):
abs_data_path = os.path.abspath(os.path.join(rel_data_path, node_name))
node_port = node_name_to_ports[node_name]
@@ -248,7 +253,12 @@ def logfile_name(logfile_prefix, node_name):
def stop_databases(
pg_path, rel_data_path, node_name_to_ports, logfile_prefix, no_output=False, parallel=True
pg_path,
rel_data_path,
node_name_to_ports,
logfile_prefix,
no_output=False,
parallel=True,
):
def stop(node_name):
abs_data_path = os.path.abspath(os.path.join(rel_data_path, node_name))
@@ -287,7 +297,9 @@ def initialize_citus_cluster(bindir, datadir, settings, config):
initialize_db_for_cluster(
bindir, datadir, settings, config.node_name_to_ports.keys()
)
start_databases(bindir, datadir, config.node_name_to_ports, config.name, config.env_variables)
start_databases(
bindir, datadir, config.node_name_to_ports, config.name, config.env_variables
)
create_citus_extension(bindir, config.node_name_to_ports.values())
add_workers(bindir, config.worker_ports, config.coordinator_port())
if not config.is_mx:
@@ -296,6 +308,7 @@ def initialize_citus_cluster(bindir, datadir, settings, config):
add_coordinator_to_metadata(bindir, config.coordinator_port())
config.setup_steps()
def eprint(*args, **kwargs):
"""eprint prints to stderr"""


@@ -1,11 +1,12 @@
from os.path import expanduser
import inspect
import os
import random
import socket
from contextlib import closing
import os
import threading
from contextlib import closing
from os.path import expanduser
import common
import inspect
COORDINATOR_NAME = "coordinator"
WORKER1 = "worker1"
@@ -57,8 +58,9 @@ port_lock = threading.Lock()
def should_include_config(class_name):
if inspect.isclass(class_name) and issubclass(class_name, CitusDefaultClusterConfig):
if inspect.isclass(class_name) and issubclass(
class_name, CitusDefaultClusterConfig
):
return True
return False
@@ -73,7 +75,7 @@ def find_free_port():
port = next_port
next_port += 1
return port
except:
except Exception:
next_port += 1
# we couldn't find a port
raise Exception("Couldn't find a port to use")
@@ -167,7 +169,9 @@ class CitusDefaultClusterConfig(CitusBaseClusterConfig):
self.add_coordinator_to_metadata = True
self.skip_tests = [
# Alter Table statement cannot be run from an arbitrary node so this test will fail
"arbitrary_configs_alter_table_add_constraint_without_name_create", "arbitrary_configs_alter_table_add_constraint_without_name"]
"arbitrary_configs_alter_table_add_constraint_without_name_create",
"arbitrary_configs_alter_table_add_constraint_without_name",
]
class CitusUpgradeConfig(CitusBaseClusterConfig):
@@ -190,9 +194,13 @@ class PostgresConfig(CitusDefaultClusterConfig):
self.new_settings = {
"citus.use_citus_managed_tables": False,
}
self.skip_tests = ["nested_execution",
self.skip_tests = [
"nested_execution",
# Alter Table statement cannot be run from an arbitrary node so this test will fail
"arbitrary_configs_alter_table_add_constraint_without_name_create", "arbitrary_configs_alter_table_add_constraint_without_name"]
"arbitrary_configs_alter_table_add_constraint_without_name_create",
"arbitrary_configs_alter_table_add_constraint_without_name",
]
class CitusSingleNodeClusterConfig(CitusDefaultClusterConfig):
def __init__(self, arguments):
@@ -229,7 +237,7 @@ class CitusSmallSharedPoolSizeConfig(CitusDefaultClusterConfig):
def __init__(self, arguments):
super().__init__(arguments)
self.new_settings = {
"citus.local_shared_pool_size": 5,
"citus.local_shared_pool_size": 5,
"citus.max_shared_pool_size": 5,
}
@@ -275,7 +283,7 @@ class CitusUnusualExecutorConfig(CitusDefaultClusterConfig):
# this setting does not necessarily need to be here
# could go any other test
self.env_variables = {'PGAPPNAME' : 'test_app'}
self.env_variables = {"PGAPPNAME": "test_app"}
class CitusSmallCopyBuffersConfig(CitusDefaultClusterConfig):
@@ -307,9 +315,13 @@ class CitusUnusualQuerySettingsConfig(CitusDefaultClusterConfig):
# requires the table with the fk to be converted to a citus_local_table.
# As of c11, there is no way to do that through remote execution so this test
# will fail
"arbitrary_configs_truncate_cascade_create", "arbitrary_configs_truncate_cascade",
# Alter Table statement cannot be run from an arbitrary node so this test will fail
"arbitrary_configs_alter_table_add_constraint_without_name_create", "arbitrary_configs_alter_table_add_constraint_without_name"]
"arbitrary_configs_truncate_cascade_create",
"arbitrary_configs_truncate_cascade",
# Alter Table statement cannot be run from an arbitrary node so this test will fail
"arbitrary_configs_alter_table_add_constraint_without_name_create",
"arbitrary_configs_alter_table_add_constraint_without_name",
]
class CitusSingleNodeSingleShardClusterConfig(CitusDefaultClusterConfig):
def __init__(self, arguments):
@@ -328,15 +340,20 @@ class CitusShardReplicationFactorClusterConfig(CitusDefaultClusterConfig):
self.skip_tests = [
# citus does not support foreign keys in distributed tables
# when citus.shard_replication_factor >= 2
"arbitrary_configs_truncate_partition_create", "arbitrary_configs_truncate_partition",
"arbitrary_configs_truncate_partition_create",
"arbitrary_configs_truncate_partition",
# citus does not support modifying a partition when
# citus.shard_replication_factor >= 2
"arbitrary_configs_truncate_cascade_create", "arbitrary_configs_truncate_cascade",
"arbitrary_configs_truncate_cascade_create",
"arbitrary_configs_truncate_cascade",
# citus does not support colocating functions with distributed tables when
# citus.shard_replication_factor >= 2
"function_create", "functions",
# Alter Table statement cannot be run from an arbitrary node so this test will fail
"arbitrary_configs_alter_table_add_constraint_without_name_create", "arbitrary_configs_alter_table_add_constraint_without_name"]
"function_create",
"functions",
# Alter Table statement cannot be run from an arbitrary node so this test will fail
"arbitrary_configs_alter_table_add_constraint_without_name_create",
"arbitrary_configs_alter_table_add_constraint_without_name",
]
class CitusSingleShardClusterConfig(CitusDefaultClusterConfig):


@@ -1,42 +1,74 @@
#!/usr/bin/env python3
import sys
import argparse
import os
import pathlib
from glob import glob
import argparse
import shutil
import random
import re
import shutil
import sys
from glob import glob
import common
import config
args = argparse.ArgumentParser()
args.add_argument("test_name", help="Test name (must be included in a schedule.)", nargs='?')
args.add_argument("-p", "--path", required=False, help="Relative path for test file (must have a .sql or .spec extension)", type=pathlib.Path)
args.add_argument(
"test_name", help="Test name (must be included in a schedule.)", nargs="?"
)
args.add_argument(
"-p",
"--path",
required=False,
help="Relative path for test file (must have a .sql or .spec extension)",
type=pathlib.Path,
)
args.add_argument("-r", "--repeat", help="Number of test to run", type=int, default=1)
args.add_argument("-b", "--use-base-schedule", required=False, help="Choose base-schedules rather than minimal-schedules", action='store_true')
args.add_argument("-w", "--use-whole-schedule-line", required=False, help="Use the whole line found in related schedule", action='store_true')
args.add_argument("--valgrind", required=False, help="Run the test with valgrind enabled", action='store_true')
args.add_argument(
"-b",
"--use-base-schedule",
required=False,
help="Choose base-schedules rather than minimal-schedules",
action="store_true",
)
args.add_argument(
"-w",
"--use-whole-schedule-line",
required=False,
help="Use the whole line found in related schedule",
action="store_true",
)
args.add_argument(
"--valgrind",
required=False,
help="Run the test with valgrind enabled",
action="store_true",
)
args = vars(args.parse_args())
regress_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
test_file_path = args['path']
test_file_name = args['test_name']
use_base_schedule = args['use_base_schedule']
use_whole_schedule_line = args['use_whole_schedule_line']
test_file_path = args["path"]
test_file_name = args["test_name"]
use_base_schedule = args["use_base_schedule"]
use_whole_schedule_line = args["use_whole_schedule_line"]
test_files_to_skip = ['multi_cluster_management', 'multi_extension', 'multi_test_helpers', 'multi_insert_select']
test_files_to_run_without_schedule = ['single_node_enterprise']
test_files_to_skip = [
"multi_cluster_management",
"multi_extension",
"multi_test_helpers",
"multi_insert_select",
]
test_files_to_run_without_schedule = ["single_node_enterprise"]
if not (test_file_name or test_file_path):
print(f"FATAL: No test given.")
print("FATAL: No test given.")
sys.exit(2)
if test_file_path:
test_file_path = os.path.join(os.getcwd(), args['path'])
test_file_path = os.path.join(os.getcwd(), args["path"])
if not os.path.isfile(test_file_path):
print(f"ERROR: test file '{test_file_path}' does not exist")
@@ -45,7 +77,7 @@ if test_file_path:
test_file_extension = pathlib.Path(test_file_path).suffix
test_file_name = pathlib.Path(test_file_path).stem
if not test_file_extension in '.spec.sql':
if test_file_extension not in ".spec.sql":
print(
"ERROR: Unrecognized test extension. Valid extensions are: .sql and .spec"
)
@@ -56,73 +88,73 @@ if test_file_name in test_files_to_skip:
print(f"WARNING: Skipping exceptional test: '{test_file_name}'")
sys.exit(0)
test_schedule = ''
test_schedule = ""
# find related schedule
for schedule_file_path in sorted(glob(os.path.join(regress_dir, "*_schedule"))):
for schedule_line in open(schedule_file_path, 'r'):
if re.search(r'\b' + test_file_name + r'\b', schedule_line):
test_schedule = pathlib.Path(schedule_file_path).stem
if use_whole_schedule_line:
test_schedule_line = schedule_line
else:
test_schedule_line = f"test: {test_file_name}\n"
break
else:
continue
break
for schedule_line in open(schedule_file_path, "r"):
if re.search(r"\b" + test_file_name + r"\b", schedule_line):
test_schedule = pathlib.Path(schedule_file_path).stem
if use_whole_schedule_line:
test_schedule_line = schedule_line
else:
test_schedule_line = f"test: {test_file_name}\n"
break
else:
continue
break
# map suitable schedule
if not test_schedule:
print(
f"WARNING: Could not find any schedule for '{test_file_name}'"
)
print(f"WARNING: Could not find any schedule for '{test_file_name}'")
sys.exit(0)
elif "isolation" in test_schedule:
test_schedule = 'base_isolation_schedule'
test_schedule = "base_isolation_schedule"
elif "failure" in test_schedule:
test_schedule = 'failure_base_schedule'
test_schedule = "failure_base_schedule"
elif "enterprise" in test_schedule:
test_schedule = 'enterprise_minimal_schedule'
test_schedule = "enterprise_minimal_schedule"
elif "split" in test_schedule:
test_schedule = 'minimal_schedule'
test_schedule = "minimal_schedule"
elif "mx" in test_schedule:
if use_base_schedule:
test_schedule = 'mx_base_schedule'
test_schedule = "mx_base_schedule"
else:
test_schedule = 'mx_minimal_schedule'
test_schedule = "mx_minimal_schedule"
elif "operations" in test_schedule:
test_schedule = 'minimal_schedule'
test_schedule = "minimal_schedule"
elif test_schedule in config.ARBITRARY_SCHEDULE_NAMES:
print(f"WARNING: Arbitrary config schedule ({test_schedule}) is not supported.")
sys.exit(0)
else:
if use_base_schedule:
test_schedule = 'base_schedule'
test_schedule = "base_schedule"
else:
test_schedule = 'minimal_schedule'
test_schedule = "minimal_schedule"
# copy base schedule to a temp file and append test_schedule_line
# to be able to run tests in parallel (if test_schedule_line is a parallel group.)
tmp_schedule_path = os.path.join(regress_dir, f"tmp_schedule_{ random.randint(1, 10000)}")
tmp_schedule_path = os.path.join(
regress_dir, f"tmp_schedule_{ random.randint(1, 10000)}"
)
# some tests don't need a schedule to run
# e.g tests that are in the first place in their own schedule
if test_file_name not in test_files_to_run_without_schedule:
shutil.copy2(os.path.join(regress_dir, test_schedule), tmp_schedule_path)
with open(tmp_schedule_path, "a") as myfile:
for i in range(args['repeat']):
myfile.write(test_schedule_line)
for _ in range(args["repeat"]):
myfile.write(test_schedule_line)
# find suitable make recipe
if "isolation" in test_schedule:
make_recipe = 'check-isolation-custom-schedule'
make_recipe = "check-isolation-custom-schedule"
elif "failure" in test_schedule:
make_recipe = 'check-failure-custom-schedule'
make_recipe = "check-failure-custom-schedule"
else:
make_recipe = 'check-custom-schedule'
make_recipe = "check-custom-schedule"
if args['valgrind']:
make_recipe += '-vg'
if args["valgrind"]:
make_recipe += "-vg"
# prepare command to run tests
test_command = f"make -C {regress_dir} {make_recipe} SCHEDULE='{pathlib.Path(tmp_schedule_path).stem}'"


@@ -13,32 +13,30 @@ Options:
--mixed Run the verification phase with one node not upgraded.
"""
import subprocess
import atexit
import os
import re
import subprocess
import sys
# https://stackoverflow.com/questions/14132789/relative-imports-for-the-billionth-time/14132912#14132912
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import common
import utils
from docopt import docopt
from utils import USER
from docopt import docopt
from config import (
CitusUpgradeConfig,
CITUS_VERSION_SQL,
MASTER_VERSION,
AFTER_CITUS_UPGRADE_COORD_SCHEDULE,
BEFORE_CITUS_UPGRADE_COORD_SCHEDULE,
CITUS_VERSION_SQL,
MASTER_VERSION,
MIXED_AFTER_CITUS_UPGRADE_SCHEDULE,
MIXED_BEFORE_CITUS_UPGRADE_SCHEDULE,
CitusUpgradeConfig,
)
import common
def main(config):
install_citus(config.pre_tar_path)
@@ -96,7 +94,7 @@ def remove_citus(tar_path):
def remove_tar_files(tar_path):
ps = subprocess.Popen(("tar", "tf", tar_path), stdout=subprocess.PIPE)
output = subprocess.check_output(("xargs", "rm", "-v"), stdin=ps.stdout)
subprocess.check_output(("xargs", "rm", "-v"), stdin=ps.stdout)
ps.wait()


@@ -10,23 +10,25 @@ Options:
--pgxsdir=<pgxsdir> Path to the PGXS directory(ex: ~/.pgenv/src/postgresql-11.3)
"""
import sys, os
import os
import sys
# https://stackoverflow.com/questions/14132789/relative-imports-for-the-billionth-time/14132912#14132912
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from config import (
PGUpgradeConfig,
AFTER_PG_UPGRADE_SCHEDULE,
BEFORE_PG_UPGRADE_SCHEDULE,
)
from docopt import docopt
import utils
from utils import USER
import atexit
import subprocess
import common
import utils
from docopt import docopt
from utils import USER
from config import (
AFTER_PG_UPGRADE_SCHEDULE,
BEFORE_PG_UPGRADE_SCHEDULE,
PGUpgradeConfig,
)
def citus_prepare_pg_upgrade(pg_path, node_ports):
@@ -112,7 +114,11 @@ def main(config):
config.node_name_to_ports.keys(),
)
common.start_databases(
config.new_bindir, config.new_datadir, config.node_name_to_ports, config.name, {}
config.new_bindir,
config.new_datadir,
config.node_name_to_ports,
config.name,
{},
)
citus_finish_pg_upgrade(config.new_bindir, config.node_name_to_ports.values())


@@ -1,6 +1,5 @@
import subprocess
import os
import subprocess
USER = "postgres"


@@ -1,34 +1,34 @@
from collections import defaultdict
from itertools import count
import logging
import re
import os
import queue
import re
import signal
import socket
import struct
import threading
import time
import traceback
import queue
from construct.lib import ListContainer
from mitmproxy import ctx, tcp
from itertools import count
import structs
from construct.lib import ListContainer
from mitmproxy import ctx, tcp
logging.basicConfig(format="%(asctime)s %(levelname)s %(message)s", level=logging.DEBUG)
# I. Command Strings
class Handler:
'''
"""
This class hierarchy serves two purposes:
1. Allow command strings to be evaluated. Once evaluated you'll have a Handler you can
pass packets to
2. Process packets as they come in and decide what to do with them.
Subclasses which want to change how packets are handled should override _handle.
'''
"""
def __init__(self, root=None):
# all packets are first sent to the root handler to be processed
self.root = root if root else self
@@ -38,30 +38,31 @@ class Handler:
def _accept(self, flow, message):
result = self._handle(flow, message)
if result == 'pass':
if result == "pass":
# defer to our child
if not self.next:
raise Exception("we don't know what to do!")
if self.next._accept(flow, message) == 'stop':
if self.next._accept(flow, message) == "stop":
if self.root is not self:
return 'stop'
return "stop"
self.next = KillHandler(self)
flow.kill()
else:
return result
def _handle(self, flow, message):
'''
"""
Handlers can return one of three things:
- "done" tells the parent to stop processing. This performs the default action,
which is to allow the packet to be sent.
- "pass" means to delegate to self.next and do whatever it wants
- "stop" means all processing will stop, and all connections will be killed
'''
"""
# subclasses must implement this
raise NotImplementedError()
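
A toy model of the done/pass/stop protocol described in that docstring, reduced to plain functions rather than the real Handler chain:

    def run_chain(handlers, message):
        for handle in handlers:
            result = handle(message)
            if result == "done":
                return "delivered"  # default action: let the packet through
            if result == "stop":
                return "killed"     # all processing stops
            # "pass": defer to the next handler in the chain

    def contains_copy(message):
        return "pass" if "COPY" in message else "done"

    def kill(message):
        return "stop"

    print(run_chain([contains_copy, kill], "SELECT 1"))   # delivered
    print(run_chain([contains_copy, kill], "COPY data"))  # killed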
class FilterableMixin:
def contains(self, pattern):
self.next = Contains(self.root, pattern)
@@ -76,7 +77,7 @@ class FilterableMixin:
return self.next
def __getattr__(self, attr):
'''
"""
Methods such as .onQuery trigger when a packet with that name is intercepted
Adds support for commands such as:
@@ -85,14 +86,17 @@
Returns a function because the above command is resolved in two steps:
conn.onQuery becomes conn.__getattr__("onQuery")
conn.onQuery(query="COPY") becomes conn.__getattr__("onQuery")(query="COPY")
'''
if attr.startswith('on'):
"""
if attr.startswith("on"):
def doit(**kwargs):
self.next = OnPacket(self.root, attr[2:], kwargs)
return self.next
return doit
raise AttributeError
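
The two-step resolution that docstring describes can be reproduced in isolation with a stripped-down stand-in (a hypothetical class mirroring the mixin above):

    class Fluent:
        def __getattr__(self, attr):
            if attr.startswith("on"):
                def doit(**kwargs):
                    # the real mixin builds an OnPacket handler here
                    print("handler for", attr[2:], "with filters", kwargs)
                    return self  # returning an object keeps chaining possible
                return doit
            raise AttributeError(attr)

    conn = Fluent()
    conn.onQuery(query="COPY")  # handler for Query with filters {'query': 'COPY'}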
class ActionsMixin:
def kill(self):
self.next = KillHandler(self.root)
@@ -118,31 +122,39 @@ class ActionsMixin:
self.next = ConnectDelayHandler(self.root, timeMs)
return self.next
class AcceptHandler(Handler):
def __init__(self, root):
super().__init__(root)
def _handle(self, flow, message):
return 'done'
return "done"
class KillHandler(Handler):
def __init__(self, root):
super().__init__(root)
def _handle(self, flow, message):
flow.kill()
return 'done'
return "done"
class KillAllHandler(Handler):
def __init__(self, root):
super().__init__(root)
def _handle(self, flow, message):
return 'stop'
return "stop"
class ResetHandler(Handler):
# try to force a RST to be sent, something went very wrong!
def __init__(self, root):
super().__init__(root)
def _handle(self, flow, message):
flow.kill() # tell mitmproxy this connection should be closed
flow.kill() # tell mitmproxy this connection should be closed
# this is a mitmproxy.connections.ClientConnection(mitmproxy.tcp.BaseHandler)
client_conn = flow.client_conn
@@ -152,8 +164,9 @@ class ResetHandler(Handler):
# cause linux to send a RST
LINGER_ON, LINGER_TIMEOUT = 1, 0
conn.setsockopt(
socket.SOL_SOCKET, socket.SO_LINGER,
struct.pack('ii', LINGER_ON, LINGER_TIMEOUT)
socket.SOL_SOCKET,
socket.SO_LINGER,
struct.pack("ii", LINGER_ON, LINGER_TIMEOUT),
)
conn.close()
@@ -161,28 +174,35 @@
# tries to call conn.shutdown(), but there's nothing else to clean up so that's
# maybe okay
return 'done'
return "done"
class CancelHandler(Handler):
'Send a SIGINT to the process'
"Send a SIGINT to the process"
def __init__(self, root, pid):
super().__init__(root)
self.pid = pid
def _handle(self, flow, message):
os.kill(self.pid, signal.SIGINT)
# give the signal a chance to be received before we let the packet through
time.sleep(0.1)
return 'done'
return "done"
class ConnectDelayHandler(Handler):
'Delay the initial packet by sleeping before deciding what to do'
"Delay the initial packet by sleeping before deciding what to do"
def __init__(self, root, timeMs):
super().__init__(root)
self.timeMs = timeMs
def _handle(self, flow, message):
if message.is_initial:
time.sleep(self.timeMs/1000.0)
return 'done'
time.sleep(self.timeMs / 1000.0)
return "done"
class Contains(Handler, ActionsMixin, FilterableMixin):
def __init__(self, root, pattern):
@@ -191,8 +211,9 @@ class Contains(Handler, ActionsMixin, FilterableMixin):
def _handle(self, flow, message):
if self.pattern in message.content:
return 'pass'
return 'done'
return "pass"
return "done"
class Matches(Handler, ActionsMixin, FilterableMixin):
def __init__(self, root, pattern):
@@ -201,47 +222,56 @@ class Matches(Handler, ActionsMixin, FilterableMixin):
def _handle(self, flow, message):
if self.pattern.search(message.content):
return 'pass'
return 'done'
return "pass"
return "done"
class After(Handler, ActionsMixin, FilterableMixin):
"Don't pass execution to our child until we've handled 'times' messages"
def __init__(self, root, times):
super().__init__(root)
self.target = times
def _handle(self, flow, message):
if not hasattr(flow, '_after_count'):
if not hasattr(flow, "_after_count"):
flow._after_count = 0
if flow._after_count >= self.target:
return 'pass'
return "pass"
flow._after_count += 1
return 'done'
return "done"
class OnPacket(Handler, ActionsMixin, FilterableMixin):
'''Triggers when a packet of the specified kind comes around'''
"""Triggers when a packet of the specified kind comes around"""
def __init__(self, root, packet_kind, kwargs):
super().__init__(root)
self.packet_kind = packet_kind
self.filters = kwargs
def _handle(self, flow, message):
if not message.parsed:
# if this is the first message in the connection we just skip it
return 'done'
return "done"
for msg in message.parsed:
typ = structs.message_type(msg, from_frontend=message.from_client)
if typ == self.packet_kind:
matches = structs.message_matches(msg, self.filters, message.from_client)
matches = structs.message_matches(
msg, self.filters, message.from_client
)
if matches:
return 'pass'
return 'done'
return "pass"
return "done"
class RootHandler(Handler, ActionsMixin, FilterableMixin):
def _handle(self, flow, message):
# do whatever the next Handler tells us to do
return 'pass'
return "pass"
class RecorderCommand:
def __init__(self):
@@ -250,23 +280,26 @@ class RecorderCommand:
def dump(self):
# When the user calls dump() we return everything we've captured
self.command = 'dump'
self.command = "dump"
return self
def reset(self):
# If the user calls reset() we dump all captured packets without returning them
self.command = 'reset'
self.command = "reset"
return self
# II. Utilities for interfacing with mitmproxy
def build_handler(spec):
'Turns a command string into a RootHandler ready to accept packets'
"Turns a command string into a RootHandler ready to accept packets"
root = RootHandler()
recorder = RecorderCommand()
handler = eval(spec, {'__builtins__': {}}, {'conn': root, 'recorder': recorder})
handler = eval(spec, {"__builtins__": {}}, {"conn": root, "recorder": recorder})
return handler.root
# a bunch of globals
handler = None # the current handler used to process packets
@@ -274,25 +307,25 @@ command_thread = None # sits on the fifo and waits for new commands to come in
captured_messages = queue.Queue() # where we store messages used for recorder.dump()
connection_count = count() # so we can give connections ids in recorder.dump()
def listen_for_commands(fifoname):
def emit_row(conn, from_client, message):
# we're using the COPY text format. It requires us to escape backslashes
cleaned = message.replace('\\', '\\\\')
source = 'coordinator' if from_client else 'worker'
return '{}\t{}\t{}'.format(conn, source, cleaned)
cleaned = message.replace("\\", "\\\\")
source = "coordinator" if from_client else "worker"
return "{}\t{}\t{}".format(conn, source, cleaned)
def emit_message(message):
if message.is_initial:
return emit_row(
message.connection_id, message.from_client, '[initial message]'
message.connection_id, message.from_client, "[initial message]"
)
pretty = structs.print(message.parsed)
return emit_row(message.connection_id, message.from_client, pretty)
def all_items(queue_):
'Pulls everything out of the queue without blocking'
"Pulls everything out of the queue without blocking"
try:
while True:
yield queue_.get(block=False)
@@ -300,23 +333,27 @@ def listen_for_commands(fifoname):
pass
def drop_terminate_messages(messages):
'''
"""
Terminate() messages happen eventually, Citus doesn't feel any need to send them
immediately, so tests which embed them aren't reproducible and fail to timing
issues. Here we simply drop those messages.
'''
"""
def isTerminate(msg, from_client):
kind = structs.message_type(msg, from_client)
return kind == 'Terminate'
return kind == "Terminate"
for message in messages:
if not message.parsed:
yield message
continue
message.parsed = ListContainer([
msg for msg in message.parsed
if not isTerminate(msg, message.from_client)
])
message.parsed = ListContainer(
[
msg
for msg in message.parsed
if not isTerminate(msg, message.from_client)
]
)
message.parsed.from_frontend = message.from_client
if len(message.parsed) == 0:
continue
@@ -324,35 +361,35 @@ def listen_for_commands(fifoname):
def handle_recorder(recorder):
global connection_count
result = ''
result = ""
if recorder.command == 'reset':
result = ''
if recorder.command == "reset":
result = ""
connection_count = count()
elif recorder.command != 'dump':
elif recorder.command != "dump":
# this should never happen
raise Exception('Unrecognized command: {}'.format(recorder.command))
raise Exception("Unrecognized command: {}".format(recorder.command))
results = []
messages = all_items(captured_messages)
messages = drop_terminate_messages(messages)
for message in messages:
if recorder.command == 'reset':
if recorder.command == "reset":
continue
results.append(emit_message(message))
result = '\n'.join(results)
result = "\n".join(results)
logging.debug('about to write to fifo')
with open(fifoname, mode='w') as fifo:
logging.debug('successfully opened the fifo for writing')
fifo.write('{}'.format(result))
logging.debug("about to write to fifo")
with open(fifoname, mode="w") as fifo:
logging.debug("successfully opened the fifo for writing")
fifo.write("{}".format(result))
while True:
logging.debug('about to read from fifo')
with open(fifoname, mode='r') as fifo:
logging.debug('successfully opened the fifo for reading')
logging.debug("about to read from fifo")
with open(fifoname, mode="r") as fifo:
logging.debug("successfully opened the fifo for reading")
slug = fifo.read()
logging.info('received new command: %s', slug.rstrip())
logging.info("received new command: %s", slug.rstrip())
try:
handler = build_handler(slug)
@@ -371,13 +408,14 @@ def listen_for_commands(fifoname):
except Exception as e:
result = str(e)
else:
result = ''
result = ""
logging.debug("about to write to fifo")
with open(fifoname, mode="w") as fifo:
logging.debug("successfully opened the fifo for writing")
fifo.write("{}\n".format(result))
logging.info("responded to command: %s", result.split("\n")[0])
logging.debug('about to write to fifo')
with open(fifoname, mode='w') as fifo:
logging.debug('successfully opened the fifo for writing')
fifo.write('{}\n'.format(result))
logging.info('responded to command: %s', result.split("\n")[0])
def create_thread(fifoname):
global command_thread
@@ -388,42 +426,46 @@
return
if command_thread:
print('cannot change the fifo path once mitmproxy has started');
print("cannot change the fifo path once mitmproxy has started")
return
command_thread = threading.Thread(target=listen_for_commands, args=(fifoname,), daemon=True)
command_thread = threading.Thread(
target=listen_for_commands, args=(fifoname,), daemon=True
)
command_thread.start()
# III. mitmproxy callbacks
def load(loader):
loader.add_option('slug', str, 'conn.allow()', "A script to run")
loader.add_option('fifo', str, '', "Which fifo to listen on for commands")
loader.add_option("slug", str, "conn.allow()", "A script to run")
loader.add_option("fifo", str, "", "Which fifo to listen on for commands")
def configure(updated):
global handler
if 'slug' in updated:
if "slug" in updated:
text = ctx.options.slug
handler = build_handler(text)
if 'fifo' in updated:
if "fifo" in updated:
fifoname = ctx.options.fifo
create_thread(fifoname)
def tcp_message(flow: tcp.TCPFlow):
'''
"""
This callback is hit every time mitmproxy receives a packet. It's the main entrypoint
into this script.
'''
"""
global connection_count
tcp_msg = flow.messages[-1]
# Keep track of all the different connections, assign a unique id to each
if not hasattr(flow, 'connection_id'):
if not hasattr(flow, "connection_id"):
flow.connection_id = next(connection_count)
tcp_msg.connection_id = flow.connection_id
@@ -434,7 +476,9 @@ def tcp_message(flow: tcp.TCPFlow):
# skip parsing initial messages for now, they're not important
tcp_msg.parsed = None
else:
tcp_msg.parsed = structs.parse(tcp_msg.content, from_frontend=tcp_msg.from_client)
tcp_msg.parsed = structs.parse(
tcp_msg.content, from_frontend=tcp_msg.from_client
)
# record the message, for debugging purposes
captured_messages.put(tcp_msg)


@@ -1,33 +1,53 @@
from construct import (
Struct,
Int8ub, Int16ub, Int32ub, Int16sb, Int32sb,
Bytes, CString, Computed, Switch, Seek, this, Pointer,
GreedyRange, Enum, Byte, Probe, FixedSized, RestreamData, GreedyBytes, Array
)
import construct.lib as cl
import re
import construct.lib as cl
from construct import (
Array,
Byte,
Bytes,
Computed,
CString,
Enum,
GreedyBytes,
GreedyRange,
Int8ub,
Int16sb,
Int16ub,
Int32sb,
Int32ub,
RestreamData,
Struct,
Switch,
this,
)
# For all possible message formats see:
# https://www.postgresql.org/docs/current/protocol-message-formats.html
class MessageMeta(type):
def __init__(cls, name, bases, namespace):
'''
"""
__init__ is called every time a subclass of MessageMeta is declared
'''
"""
if not hasattr(cls, "_msgtypes"):
raise Exception("classes which use MessageMeta must have a '_msgtypes' field")
raise Exception(
"classes which use MessageMeta must have a '_msgtypes' field"
)
if not hasattr(cls, "_classes"):
raise Exception("classes which use MessageMeta must have a '_classes' field")
raise Exception(
"classes which use MessageMeta must have a '_classes' field"
)
if not hasattr(cls, "struct"):
# This is one of the direct subclasses
return
if cls.__name__ in cls._classes:
raise Exception("You've already made a class called {}".format( cls.__name__))
raise Exception(
"You've already made a class called {}".format(cls.__name__)
)
cls._classes[cls.__name__] = cls
# add a _type field to the struct so we can identify it while printing structs
@@ -39,34 +59,41 @@ class MessageMeta(type):
# register the type, so we can tell the parser about it
key = cls.key
if key in cls._msgtypes:
raise Exception('key {} is already assigned to {}'.format(
key, cls._msgtypes[key].__name__)
raise Exception(
"key {} is already assigned to {}".format(
key, cls._msgtypes[key].__name__
)
)
cls._msgtypes[key] = cls
class Message:
'Do not subclass this object directly. Instead, subclass of one of the below types'
"Do not subclass this object directly. Instead, subclass of one of the below types"
def print(message):
'Define this on subclasses you want to change the representation of'
"Define this on subclasses you want to change the representation of"
raise NotImplementedError
def typeof(message):
'Define this on subclasses you want to change the expressed type of'
"Define this on subclasses you want to change the expressed type of"
return message._type
@classmethod
def _default_print(cls, name, msg):
recur = cls.print_message
return "{}({})".format(name, ",".join(
"{}={}".format(key, recur(value)) for key, value in msg.items()
if not key.startswith('_')
))
return "{}({})".format(
name,
",".join(
"{}={}".format(key, recur(value))
for key, value in msg.items()
if not key.startswith("_")
),
)
@classmethod
def find_typeof(cls, msg):
if not hasattr(cls, "_msgtypes"):
raise Exception('Do not call this method on Message, call it on a subclass')
raise Exception("Do not call this method on Message, call it on a subclass")
if isinstance(msg, cl.ListContainer):
raise ValueError("do not call this on a list of messages")
if not isinstance(msg, cl.Container):
@@ -80,7 +107,7 @@ class Message:
@classmethod
def print_message(cls, msg):
if not hasattr(cls, "_msgtypes"):
raise Exception('Do not call this method on Message, call it on a subclass')
raise Exception("Do not call this method on Message, call it on a subclass")
if isinstance(msg, cl.ListContainer):
return repr([cls.print_message(message) for message in msg])
@ -101,38 +128,34 @@ class Message:
@classmethod
def name_to_struct(cls):
return {_class.__name__: _class.struct for _class in cls._msgtypes.values()}
@classmethod
def name_to_key(cls):
return {_class.__name__: ord(key) for key, _class in cls._msgtypes.items()}
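# A small usage sketch of these helpers (using the Query class defined below):
#
#   FrontendMessage.name_to_struct()["Query"]  # the Struct that parses Query
#   FrontendMessage.name_to_key()["Query"]  # 81, i.e. ord("Q")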
class SharedMessage(Message, metaclass=MessageMeta):
"A message which could be sent by either the frontend or the backend"
_msgtypes = dict()
_classes = dict()
class FrontendMessage(Message, metaclass=MessageMeta):
"A message which will only be sent by a frontend"
_msgtypes = dict()
_classes = dict()
class BackendMessage(Message, metaclass=MessageMeta):
"A message which will only be sent by a backend"
_msgtypes = dict()
_classes = dict()
class Query(FrontendMessage):
key = "Q"
struct = Struct("query" / CString("ascii"))
@staticmethod
def print(message):
@ -144,132 +167,151 @@ class Query(FrontendMessage):
@staticmethod
def normalize_shards(content):
"""
For example:
>>> normalize_shards(
>>> 'COPY public.copy_test_120340 (key, value) FROM STDIN WITH (FORMAT BINARY))'
>>> )
'COPY public.copy_test_XXXXXX (key, value) FROM STDIN WITH (FORMAT BINARY))'
"""
result = content
pattern = re.compile(r"public\.[a-z_]+(?P<shardid>[0-9]+)")
for match in pattern.finditer(content):
span = match.span("shardid")
replacement = "X" * (span[1] - span[0])
result = result[: span[0]] + replacement + result[span[1] :]
return result
@staticmethod
def normalize_timestamps(content):
"""
For example:
>>> normalize_timestamps('2018-06-07 05:18:19.388992-07')
'XXXX-XX-XX XX:XX:XX.XXXXXX-XX'
>>> normalize_timestamps('2018-06-11 05:30:43.01382-07')
'XXXX-XX-XX XX:XX:XX.XXXXXX-XX'
"""
pattern = re.compile(
"[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}.[0-9]{2,6}-[0-9]{2}"
)
return re.sub(pattern, "XXXX-XX-XX XX:XX:XX.XXXXXX-XX", content)
@staticmethod
def normalize_assign_txn_id(content):
"""
For example:
>>> normalize_assign_txn_id('SELECT assign_distributed_transaction_id(0, 52, ...')
'SELECT assign_distributed_transaction_id(0, XX, ...'
"""
pattern = re.compile(
r"assign_distributed_transaction_id\s*\(" # a method call
r"\s*[0-9]+\s*," # an integer first parameter
r"\s*(?P<transaction_id>[0-9]+)" # an integer second parameter
)
result = content
for match in pattern.finditer(content):
span = match.span("transaction_id")
result = result[: span[0]] + "XX" + result[span[1] :]
return result
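# A short usage sketch (bytes chosen for illustration): the struct parses a
# NUL-terminated query string, and the normalizers make test output stable.
#
#   Query.struct.parse(b"SELECT 1\x00").query  # 'SELECT 1'
#   Query.normalize_shards("COPY public.copy_test_120340 FROM STDIN")
#   # 'COPY public.copy_test_XXXXXX FROM STDIN'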
class Terminate(FrontendMessage):
key = "X"
struct = Struct()
class CopyData(SharedMessage):
key = "d"
struct = Struct(
"data" / GreedyBytes # reads all of the data left in this substream
)
class CopyDone(SharedMessage):
key = "c"
struct = Struct()
class EmptyQueryResponse(BackendMessage):
key = "I"
struct = Struct()
class CopyOutResponse(BackendMessage):
key = "H"
struct = Struct(
"format" / Int8ub,
"columncount" / Int16ub,
"columns" / Array(this.columncount, Struct(
"format" / Int16ub
))
"columns" / Array(this.columncount, Struct("format" / Int16ub)),
)
class ReadyForQuery(BackendMessage):
key = "Z"
struct = Struct(
"state"
/ Enum(
Byte,
idle=ord("I"),
in_transaction_block=ord("T"),
in_failed_transaction_block=ord("E"),
)
)
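# A parsing sketch: the payload of ReadyForQuery is a single status byte (the
# type/length framing is handled by the Backend wrapper below), so parsing
# b"I" yields the idle state.
#
#   ReadyForQuery.struct.parse(b"I").state == "idle"  # True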
class CommandComplete(BackendMessage):
key = "C"
struct = Struct("command" / CString("ascii"))
class RowDescription(BackendMessage):
key = "T"
struct = Struct(
"fieldcount" / Int16ub,
"fields" / Array(this.fieldcount, Struct(
"_type" / Computed("F"),
"name" / CString("ascii"),
"tableoid" / Int32ub,
"colattrnum" / Int16ub,
"typoid" / Int32ub,
"typlen" / Int16sb,
"typmod" / Int32sb,
"format_code" / Int16ub,
))
"fields"
/ Array(
this.fieldcount,
Struct(
"_type" / Computed("F"),
"name" / CString("ascii"),
"tableoid" / Int32ub,
"colattrnum" / Int16ub,
"typoid" / Int32ub,
"typlen" / Int16sb,
"typmod" / Int32sb,
"format_code" / Int16ub,
),
),
)
class DataRow(BackendMessage):
key = "D"
struct = Struct(
"_type" / Computed("data_row"),
"columncount" / Int16ub,
"columns" / Array(this.columncount, Struct(
"_type" / Computed("C"),
"length" / Int16sb,
"value" / Bytes(this.length)
))
"columns"
/ Array(
this.columncount,
Struct(
"_type" / Computed("C"),
"length" / Int16sb,
"value" / Bytes(this.length),
),
),
)
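# A parsing sketch with hand-built bytes: one column holding the 3-byte value
# b"abc". The counts and lengths are big-endian, so 1 is b"\x00\x01".
#
#   row = DataRow.struct.parse(b"\x00\x01\x00\x03abc")
#   row.columns[0].value  # b'abc'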
class AuthenticationOk(BackendMessage):
key = "R"
struct = Struct()
class ParameterStatus(BackendMessage):
key = "S"
struct = Struct(
"name" / CString("ASCII"),
"value" / CString("ASCII"),
@ -281,161 +323,156 @@ class ParameterStatus(BackendMessage):
@staticmethod
def normalize(name, value):
if name in ("TimeZone", "server_version"):
value = "XXX"
return (name, value)
class BackendKeyData(BackendMessage):
key = "K"
struct = Struct("pid" / Int32ub, "key" / Bytes(4))
def print(message):
# Both of these should be censored, for reproducible regression test output
return "BackendKeyData(XXX)"
class NoticeResponse(BackendMessage):
key = "N"
struct = Struct(
"notices" / GreedyRange(
"notices"
/ GreedyRange(
Struct(
"key" / Enum(Byte,
severity=ord('S'),
_severity_not_localized=ord('V'),
_sql_state=ord('C'),
message=ord('M'),
detail=ord('D'),
hint=ord('H'),
_position=ord('P'),
_internal_position=ord('p'),
_internal_query=ord('q'),
_where=ord('W'),
schema_name=ord('s'),
table_name=ord('t'),
column_name=ord('c'),
data_type_name=ord('d'),
constraint_name=ord('n'),
_file_name=ord('F'),
_line_no=ord('L'),
_routine_name=ord('R')
),
"value" / CString("ASCII")
"key"
/ Enum(
Byte,
severity=ord("S"),
_severity_not_localized=ord("V"),
_sql_state=ord("C"),
message=ord("M"),
detail=ord("D"),
hint=ord("H"),
_position=ord("P"),
_internal_position=ord("p"),
_internal_query=ord("q"),
_where=ord("W"),
schema_name=ord("s"),
table_name=ord("t"),
column_name=ord("c"),
data_type_name=ord("d"),
constraint_name=ord("n"),
_file_name=ord("F"),
_line_no=ord("L"),
_routine_name=ord("R"),
),
"value" / CString("ASCII"),
)
)
)
def print(message):
return "NoticeResponse({})".format(", ".join(
"{}={}".format(response.key, response.value)
for response in message.notices
if not response.key.startswith('_')
))
return "NoticeResponse({})".format(
", ".join(
"{}={}".format(response.key, response.value)
for response in message.notices
if not response.key.startswith("_")
)
)
class Parse(FrontendMessage):
key = "P"
struct = Struct(
"name" / CString("ASCII"),
"query" / CString("ASCII"),
"_parametercount" / Int16ub,
"parameters" / Array(
this._parametercount,
Int32ub
)
"parameters" / Array(this._parametercount, Int32ub),
)
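# A parsing sketch with hand-built bytes (assuming OID 23 is int4): an unnamed
# prepared statement with one typed parameter.
#
#   msg = Parse.struct.parse(b"\x00SELECT $1\x00\x00\x01\x00\x00\x00\x17")
#   (msg.name, msg.query, list(msg.parameters))  # ('', 'SELECT $1', [23])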
class ParseComplete(BackendMessage):
key = "1"
struct = Struct()
class Bind(FrontendMessage):
key = "B"
struct = Struct(
"destination_portal" / CString("ASCII"),
"prepared_statement" / CString("ASCII"),
"_parameter_format_code_count" / Int16ub,
"parameter_format_codes" / Array(this._parameter_format_code_count,
Int16ub),
"parameter_format_codes" / Array(this._parameter_format_code_count, Int16ub),
"_parameter_value_count" / Int16ub,
"parameter_values" / Array(
"parameter_values"
/ Array(
this._parameter_value_count,
Struct("length" / Int32ub, "value" / Bytes(this.length)),
),
"result_column_format_count" / Int16ub,
"result_column_format_codes" / Array(this.result_column_format_count,
Int16ub)
"result_column_format_codes" / Array(this.result_column_format_count, Int16ub),
)
class BindComplete(BackendMessage):
key = "2"
struct = Struct()
class NoData(BackendMessage):
key = "n"
struct = Struct()
class Describe(FrontendMessage):
key = "D"
struct = Struct(
"type" / Enum(Byte,
prepared_statement=ord('S'),
portal=ord('P')
),
"name" / CString("ASCII")
"type" / Enum(Byte, prepared_statement=ord("S"), portal=ord("P")),
"name" / CString("ASCII"),
)
def print(message):
return "Describe({}={})".format(
message.type,
message.name or "<unnamed>"
)
return "Describe({}={})".format(message.type, message.name or "<unnamed>")
class Execute(FrontendMessage):
key = "E"
struct = Struct("name" / CString("ASCII"), "max_rows_to_return" / Int32ub)
def print(message):
return "Execute({}, max_rows_to_return={})".format(
message.name or "<unnamed>",
message.max_rows_to_return
message.name or "<unnamed>", message.max_rows_to_return
)
class Sync(FrontendMessage):
key = "S"
struct = Struct()
frontend_switch = Switch(
this.type,
{**FrontendMessage.name_to_struct(), **SharedMessage.name_to_struct()},
default=Bytes(this.length - 4),
)
backend_switch = Switch(
this.type,
{**BackendMessage.name_to_struct(), **SharedMessage.name_to_struct()},
default=Bytes(this.length - 4),
)
frontend_msgtypes = Enum(
Byte, **{**FrontendMessage.name_to_key(), **SharedMessage.name_to_key()}
)
backend_msgtypes = Enum(
Byte, **{**BackendMessage.name_to_key(), **SharedMessage.name_to_key()}
)
# It might seem a little circuitous to say a frontend message is a kind of frontend
# message but this lets us easily customize how they're printed
class Frontend(FrontendMessage):
struct = Struct(
"type" / frontend_msgtypes,
@ -447,9 +484,7 @@ class Frontend(FrontendMessage):
def print(message):
if isinstance(message.body, bytes):
return "Frontend(type={},body={})".format(
chr(message.type), message.body
)
return "Frontend(type={},body={})".format(chr(message.type), message.body)
return FrontendMessage.print_message(message.body)
def typeof(message):
@ -457,6 +492,7 @@ class Frontend(FrontendMessage):
return "Unknown"
return message.body._type
class Backend(BackendMessage):
struct = Struct(
"type" / backend_msgtypes,
@ -468,9 +504,7 @@ class Backend(BackendMessage):
def print(message):
if isinstance(message.body, bytes):
return "Backend(type={},body={})".format(
chr(message.type), message.body
)
return "Backend(type={},body={})".format(chr(message.type), message.body)
return BackendMessage.print_message(message.body)
def typeof(message):
@ -478,10 +512,12 @@ class Backend(BackendMessage):
return "Unknown"
return message.body._type
# GreedyRange keeps reading messages until we hit EOF
frontend_messages = GreedyRange(Frontend.struct)
backend_messages = GreedyRange(Backend.struct)
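# A usage sketch, assuming the usual type/length framing parsed by the Backend
# struct above: b"Z\x00\x00\x00\x05I" is a complete ReadyForQuery message on
# the wire (type byte, 4-byte length that includes itself, then the payload).
#
#   msgs = backend_messages.parse(b"Z\x00\x00\x00\x05I")
#   msgs[0].body.state == "idle"  # True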
def parse(message, from_frontend=True):
if from_frontend:
message = frontend_messages.parse(message)
@ -491,24 +527,27 @@ def parse(message, from_frontend=True):
return message
def print(message):
if message.from_frontend:
return FrontendMessage.print_message(message)
return BackendMessage.print_message(message)
def message_type(message, from_frontend):
if from_frontend:
return FrontendMessage.find_typeof(message)
return BackendMessage.find_typeof(message)
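# For example (a sketch, assuming the _type field computed by MessageMeta
# holds the class name, and reusing the parsed message from the sketch above):
#
#   message_type(msgs[0], from_frontend=False)  # 'ReadyForQuery'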
def message_matches(message, filters, from_frontend):
"""
Message is something like Backend(Query) and filters is something like query="COPY".
For now we only support strings, and treat them like a regex, which is matched against
the content of the wrapped message
"""
if message._type != "Backend" and message._type != "Frontend":
raise ValueError("can't handle {}".format(message._type))
wrapped = message.body