# citus/src/test/regress/mitmscripts/structs.py
#
# 411 lines
# 12 KiB
# Python

from construct import (
Struct,
Int8ub, Int16ub, Int32ub, Int16sb, Int32sb,
Bytes, CString, Computed, Switch, Seek, this, Pointer,
GreedyRange, Enum, Byte, Probe, FixedSized, RestreamData, GreedyBytes, Array
)
import construct.lib as cl
import re
class MessageMeta(type):
    '''
    Metaclass which records every message class in the registries of its family.

    A family root (a class declaring '_msgtypes' and '_classes' but no 'struct')
    is left untouched. Every class which has a 'struct' is stored in '_classes'
    under its class name, gets a computed '_type' field appended to its struct,
    and, when it also has a 'key', is stored in '_msgtypes' under that key.
    '''

    def __init__(cls, name, bases, namespace):
        '''
        Runs once for each class declared with this metaclass (or inheriting it).

        Raises Exception when a registry attribute is missing, when the class
        name was already registered, or when the key is already taken.
        '''
        # both registries have to be reachable from the class being declared
        for required in ("_msgtypes", "_classes"):
            if not hasattr(cls, required):
                raise Exception(
                    "classes which use MessageMeta must have a '{}' field".format(required)
                )

        if not hasattr(cls, "struct"):
            # no struct: this is one of the direct subclasses, nothing to register
            return

        class_name = cls.__name__
        by_name = cls._classes
        if class_name in by_name:
            raise Exception("You've already made a class called {}".format( class_name))
        by_name[class_name] = cls

        # tag parsed containers with the class name, so printing can identify them
        cls.struct = cls.struct + ("_type" / Computed(name))

        if not hasattr(cls, "key"):
            # classes without a key are never offered to the parser
            return

        # make the class discoverable by its key, refusing to overwrite another class
        message_key = cls.key
        by_key = cls._msgtypes
        if message_key in by_key:
            raise Exception('key {} is already assigned to {}'.format(
                message_key, by_key[message_key].__name__)
            )
        by_key[message_key] = cls
class Message:
    'Do not subclass this object directly. Instead, subclass one of the types below.'

    def print(message):
        'Override on a subclass to customise how its parsed messages are rendered'
        raise NotImplementedError

    def typeof(message):
        'Override on a subclass to customise the type name reported for a message'
        return message._type

    @classmethod
    def _default_print(cls, name, msg):
        '''
        Render msg as "name(key=value,...)", skipping keys which start with '_'.
        Values are rendered recursively through print_message.
        '''
        rendered_fields = (
            "{}={}".format(field, cls.print_message(value))
            for field, value in msg.items()
            if not field.startswith('_')
        )
        return "{}({})".format(name, ",".join(rendered_fields))

    @classmethod
    def find_typeof(cls, msg):
        '''
        Return the type name of a single parsed message.

        Containers without a '_type' are reported as "Anonymous"; a '_type' which
        is not registered in this family is returned as-is; otherwise the
        registered class decides through its typeof().
        '''
        if not hasattr(cls, "_msgtypes"):
            raise Exception('Do not call this method on Message, call it on a subclass')
        if isinstance(msg, cl.ListContainer):
            raise ValueError("do not call this on a list of messages")
        if not isinstance(msg, cl.Container):
            raise ValueError("must call this on a parsed message")
        if not hasattr(msg, "_type"):
            return "Anonymous"

        type_name = msg._type
        if type_name and type_name not in cls._classes:
            return type_name
        return cls._classes[type_name].typeof(msg)

    @classmethod
    def print_message(cls, msg):
        '''
        Render a parsed value.

        Lists are rendered element by element and returned as the repr of the
        resulting list, values which are not containers are returned unchanged,
        and containers are rendered by their registered class, falling back to
        the default rendering when that class does not implement print().
        '''
        if not hasattr(cls, "_msgtypes"):
            raise Exception('Do not call this method on Message, call it on a subclass')
        if isinstance(msg, cl.ListContainer):
            return repr([cls.print_message(element) for element in msg])
        if not isinstance(msg, cl.Container):
            return msg
        if not hasattr(msg, "_type"):
            return cls._default_print("Anonymous", msg)

        type_name = msg._type
        if type_name and type_name not in cls._classes:
            return cls._default_print(type_name, msg)
        try:
            return cls._classes[type_name].print(msg)
        except NotImplementedError:
            # the class did not customise its representation
            return cls._default_print(type_name, msg)

    @classmethod
    def name_to_struct(cls):
        'Map the class name of every keyed message in this family to its struct'
        return dict(
            (message_class.__name__, message_class.struct)
            for message_class in cls._msgtypes.values()
        )

    @classmethod
    def name_to_key(cls):
        'Map the class name of every keyed message in this family to ord(key)'
        return dict(
            (message_class.__name__, ord(message_key))
            for message_key, message_class in cls._msgtypes.items()
        )
class SharedMessage(Message, metaclass=MessageMeta):
    'A message which may be sent by either the frontend or the backend'
    # registries filled in by MessageMeta: key -> class and class name -> class
    _msgtypes = {}
    _classes = {}
class FrontendMessage(Message, metaclass=MessageMeta):
    'A message which will only be sent by a frontend'
    # registries filled in by MessageMeta: key -> class and class name -> class
    _msgtypes = {}
    _classes = {}
class BackendMessage(Message, metaclass=MessageMeta):
    'A message which will only be sent by a backend'
    # registries filled in by MessageMeta: key -> class and class name -> class
    _msgtypes = {}
    _classes = {}
class Query(FrontendMessage):
    '''
    Frontend message with key 'Q' carrying a single NUL-terminated query string.

    Printing masks the volatile parts of the query (shard ids, timestamps and
    distributed transaction ids) so that the output is reproducible.
    '''
    key = 'Q'
    struct = Struct(
        "query" / CString("ascii")
    )

    # Patterns are raw strings (so '\.' and '\s' are regex escapes rather than
    # invalid string escapes) and are compiled once instead of on every call.
    _SHARD_ID = re.compile(r'public\.[a-z_]+(?P<shardid>[0-9]+)')
    _TIMESTAMP = re.compile(
        # the '.' in front of the fractional seconds is escaped so it only
        # matches a literal dot
        r'[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}\.[0-9]{2,6}-[0-9]{2}'
    )
    _ASSIGN_TXN_ID = re.compile(
        r'assign_distributed_transaction_id\s*\('  # a method call
        r'\s*[0-9]+\s*,'                           # an integer first parameter
        r'\s*(?P<transaction_id>[0-9]+)'           # an integer second parameter
    )

    @staticmethod
    def print(message):
        'Render the message as "Query(query=...)" with volatile values masked'
        query = message.query
        query = Query.normalize_shards(query)
        query = Query.normalize_timestamps(query)
        query = Query.normalize_assign_txn_id(query)
        return "Query(query={})".format(query)

    @staticmethod
    def normalize_shards(content):
        '''
        Replace every digit of a shard id suffix with 'X' (the length is kept).

        For example:
        >>> normalize_shards(
        >>>   'COPY public.copy_test_120340 (key, value) FROM STDIN WITH (FORMAT BINARY))'
        >>> )
        'COPY public.copy_test_XXXXXX (key, value) FROM STDIN WITH (FORMAT BINARY))'
        '''
        def mask_shard_id(match):
            # the shardid group is always the tail of the whole match
            digits = match.group('shardid')
            return match.group(0)[:-len(digits)] + 'X' * len(digits)

        return Query._SHARD_ID.sub(mask_shard_id, content)

    @staticmethod
    def normalize_timestamps(content):
        '''
        Replace every timestamp with a fixed-width mask.

        For example:
        >>> normalize_timestamps('2018-06-07 05:18:19.388992-07')
        'XXXX-XX-XX XX:XX:XX.XXXXXX-XX'
        >>> normalize_timestamps('2018-06-11 05:30:43.01382-07')
        'XXXX-XX-XX XX:XX:XX.XXXXXX-XX'
        '''
        return Query._TIMESTAMP.sub('XXXX-XX-XX XX:XX:XX.XXXXXX-XX', content)

    @staticmethod
    def normalize_assign_txn_id(content):
        '''
        Replace the second argument of assign_distributed_transaction_id with 'XX'.

        The replacement is done with re.sub, so several calls in one query are
        all masked correctly even though 'XX' may be shorter or longer than the
        id it replaces.

        For example:
        >>> normalize_assign_txn_id('SELECT assign_distributed_transaction_id(0, 52, ...')
        'SELECT assign_distributed_transaction_id(0, XX, ...'
        '''
        def mask_txn_id(match):
            # the transaction_id group is always the tail of the whole match
            prefix_length = match.start('transaction_id') - match.start()
            return match.group(0)[:prefix_length] + 'XX'

        return Query._ASSIGN_TXN_ID.sub(mask_txn_id, content)
class Terminate(FrontendMessage):
    "Frontend message with key 'X'; its body has no fields"
    struct = Struct()
    key = 'X'
class CopyData(SharedMessage):
    "Shared message with key 'd' whose whole body is kept as raw bytes"
    key = 'd'
    # GreedyBytes consumes everything which is left in this substream
    struct = Struct('data' / GreedyBytes)
class CopyDone(SharedMessage):
    "Shared message with key 'c'; its body has no fields"
    struct = Struct()
    key = 'c'
class EmptyQueryResponse(BackendMessage):
    "Backend message with key 'I'; its body has no fields"
    struct = Struct()
    key = 'I'
class CopyOutResponse(BackendMessage):
    '''
    Backend message with key 'H': an overall format byte, a column count, and
    one 16-bit format value per column.
    '''
    key = 'H'
    struct = Struct(
        "format" / Int8ub,
        "columncount" / Int16ub,
        # exactly `columncount` entries follow
        "columns" / Array(this.columncount, Struct("format" / Int16ub)),
    )
class ReadyForQuery(BackendMessage):
    "Backend message with key 'Z' carrying a one-byte state indicator"
    key = 'Z'
    struct = Struct(
        # the state byte is one of the characters 'I', 'T' or 'E'
        "state" / Enum(
            Byte,
            idle=ord('I'),
            in_transaction_block=ord('T'),
            in_failed_transaction_block=ord('E'),
        ),
    )
class CommandComplete(BackendMessage):
    "Backend message with key 'C' carrying one NUL-terminated command string"
    key = 'C'
    struct = Struct("command" / CString("ascii"))
class RowDescription(BackendMessage):
    '''
    Backend message with key 'T': a field count followed by that many field
    descriptors. Each descriptor is tagged with _type "F" for printing.
    '''
    key = 'T'
    struct = Struct(
        "fieldcount" / Int16ub,
        "fields" / Array(
            this.fieldcount,
            Struct(
                "_type" / Computed("F"),
                "name" / CString("ascii"),
                "tableoid" / Int32ub,
                "colattrnum" / Int16ub,
                "typoid" / Int32ub,
                # typlen and typmod are parsed as signed integers
                "typlen" / Int16sb,
                "typmod" / Int32sb,
                "format_code" / Int16ub,
            ),
        ),
    )
class DataRow(BackendMessage):
    '''
    Backend message with key 'D': a column count followed by that many column
    values. Each column is tagged with _type "C" for printing.

    Per the PostgreSQL message format every column value is preceded by a
    signed 32-bit length, where -1 marks a NULL value which carries no bytes.
    '''
    key = 'D'
    struct = Struct(
        "_type" / Computed("data_row"),
        "columncount" / Int16ub,
        "columns" / Array(this.columncount, Struct(
            "_type" / Computed("C"),
            # the protocol defines this as Int32; reading only 16 bits made every
            # column parse as a pair of bogus half-length columns
            "length" / Int32sb,
            # a NULL column has length -1 and no value bytes, so read nothing
            "value" / Bytes(lambda ctx: max(ctx.length, 0)),
        ))
    )
class AuthenticationOk(BackendMessage):
    "Backend message with key 'R'; no fields of its body are parsed"
    struct = Struct()
    key = 'R'
class ParameterStatus(BackendMessage):
    "Backend message with key 'S' carrying a parameter name and its value"
    key = 'S'
    struct = Struct(
        "name" / CString("ASCII"),
        "value" / CString("ASCII"),
    )

    def print(message):
        'Render the message as "ParameterStatus(name=value)" after normalizing it'
        return "ParameterStatus({}={})".format(
            *ParameterStatus.normalize(message.name, message.value)
        )

    @staticmethod
    def normalize(name, value):
        '''
        Return (name, value), with the value replaced by 'XXX' for the
        'TimeZone' and 'server_version' parameters.
        '''
        masked = 'XXX' if name in ('TimeZone', 'server_version') else value
        return (name, masked)
class BackendKeyData(BackendMessage):
    "Backend message with key 'K' carrying a 32-bit pid and a 4-byte key"
    key = 'K'
    struct = Struct(
        "pid" / Int32ub,
        "key" / Bytes(4),
    )

    def print(message):
        'Always render the same text, whatever the parsed values are'
        # pid and key are both censored, so regression test output is reproducible
        return "BackendKeyData(XXX)"
# Body parsers, selected by the parsed `type` field. Each direction accepts its
# own messages plus the shared ones; any other type falls back to raw bytes
# (the `length` field counts itself, hence the "- 4").
frontend_switch = Switch(
    this.type,
    dict(FrontendMessage.name_to_struct(), **SharedMessage.name_to_struct()),
    default=Bytes(this.length - 4)
)
backend_switch = Switch(
    this.type,
    dict(BackendMessage.name_to_struct(), **SharedMessage.name_to_struct()),
    default=Bytes(this.length - 4)
)

# Enums which translate the leading type byte into the name of the message class
frontend_msgtypes = Enum(
    Byte,
    **dict(FrontendMessage.name_to_key(), **SharedMessage.name_to_key())
)
backend_msgtypes = Enum(
    Byte,
    **dict(BackendMessage.name_to_key(), **SharedMessage.name_to_key())
)
# Declaring the frontend envelope as a frontend message itself looks circular,
# but it lets the envelope customise its printing like any other message does
class Frontend(FrontendMessage):
    '''
    Envelope around one frontend message: a type byte, a length, the raw body
    and the body parsed according to the type (raw bytes when the type is
    not a known one). It has no 'key', so it is never a Switch target itself.
    '''
    struct = Struct(
        "type" / frontend_msgtypes,
        "length" / Int32ub, # "32-bit unsigned big-endian"
        "raw_body" / Bytes(this.length - 4),
        # re-read the raw bytes, trying to turn them into something structured
        "body" / RestreamData(this.raw_body, frontend_switch),
    )

    def print(message):
        'Render the wrapped message, or the type and raw bytes when it is unparsed'
        body = message.body
        if not isinstance(body, bytes):
            return FrontendMessage.print_message(body)
        return "Frontend(type={},body={})".format(
            chr(message.type), body
        )

    def typeof(message):
        'Report the type of the wrapped message, or "Unknown" when it is unparsed'
        body = message.body
        return "Unknown" if isinstance(body, bytes) else body._type
class Backend(BackendMessage):
    '''
    Envelope around one backend message: a type byte, a length, the raw body
    and the body parsed according to the type (raw bytes when the type is
    not a known one). It has no 'key', so it is never a Switch target itself.
    '''
    struct = Struct(
        "type" / backend_msgtypes,
        "length" / Int32ub, # "32-bit unsigned big-endian"
        "raw_body" / Bytes(this.length - 4),
        # re-read the raw bytes, trying to turn them into something structured
        "body" / RestreamData(this.raw_body, backend_switch),
    )

    def print(message):
        'Render the wrapped message, or the type and raw bytes when it is unparsed'
        body = message.body
        if not isinstance(body, bytes):
            return BackendMessage.print_message(body)
        return "Backend(type={},body={})".format(
            chr(message.type), body
        )

    def typeof(message):
        'Report the type of the wrapped message, or "Unknown" when it is unparsed'
        body = message.body
        return "Unknown" if isinstance(body, bytes) else body._type
# Stream parsers, one per direction, used by parse() below.
# GreedyRange keeps reading messages until we hit EOF
frontend_messages = GreedyRange(Frontend.struct)
backend_messages = GreedyRange(Backend.struct)
def parse(message, from_frontend=True):
    '''
    Parse a chunk of bytes into a list of messages, using the frontend parser
    when from_frontend is true and the backend parser otherwise. The direction
    is stored on the result as the `from_frontend` attribute.
    '''
    stream_parser = frontend_messages if from_frontend else backend_messages
    parsed = stream_parser.parse(message)
    # print() below relies on this attribute to pick the right message family
    parsed.from_frontend = from_frontend
    return parsed
def print(message):
    '''
    Render a result of parse(), dispatching on its `from_frontend` attribute.
    Note that this name shadows the builtin print within this module.
    '''
    family = FrontendMessage if message.from_frontend else BackendMessage
    return family.print_message(message)
def message_type(message, from_frontend):
    'Return the type name of one parsed message, for the given direction'
    family = FrontendMessage if from_frontend else BackendMessage
    return family.find_typeof(message)
def message_matches(message, filters, from_frontend):
    '''
    Tell whether the message wrapped by a Frontend or Backend envelope satisfies
    every filter.

    Message is something like Backend(Query) and filters is something like
    query="COPY". For now only string filters are supported: each one is treated
    as a regex and searched for in the attribute of the wrapped message which has
    the same name as the filter. from_frontend is accepted but not used.

    Raises ValueError when message is not an envelope, or when a filter value is
    not a string.
    '''
    if message._type not in ('Backend', 'Frontend'):
        raise ValueError("can't handle {}".format(message._type))

    wrapped = message.body
    if isinstance(wrapped, bytes):
        # the body was never parsed, so there are no fields to match against
        return False

    for field, pattern in filters.items():
        if not isinstance(pattern, str):
            raise ValueError("don't yet know how to handle {}".format(type(pattern)))
        if re.search(pattern, getattr(wrapped, field)) is None:
            return False

    return True