citus/src/test/regress/mitmscripts/fluent.py

488 lines
14 KiB
Python

import logging
import os
import queue
import re
import signal
import socket
import struct
import threading
import time
import traceback
from itertools import count
import structs
from construct.lib import ListContainer
from mitmproxy import ctx, tcp
logging.basicConfig(format="%(asctime)s %(levelname)s %(message)s", level=logging.DEBUG)
# I. Command Strings
class Handler:
"""
This class hierarchy serves two purposes:
1. Allow command strings to be evaluated. Once evaluated you'll have a Handler you can
pass packets to
2. Process packets as they come in and decide what to do with them.
Subclasses which want to change how packets are handled should override _handle.
"""
def __init__(self, root=None):
# all packets are first sent to the root handler to be processed
self.root = root if root else self
# all handlers keep track of the next handler so they know where to send packets
self.next = None
def _accept(self, flow, message):
result = self._handle(flow, message)
if result == "pass":
# defer to our child
if not self.next:
raise Exception("we don't know what to do!")
if self.next._accept(flow, message) == "stop":
if self.root is not self:
return "stop"
self.next = KillHandler(self)
flow.kill()
else:
return result
def _handle(self, flow, message):
"""
Handlers can return one of three things:
- "done" tells the parent to stop processing. This performs the default action,
which is to allow the packet to be sent.
- "pass" means to delegate to self.next and do whatever it wants
- "stop" means all processing will stop, and all connections will be killed
"""
# subclasses must implement this
raise NotImplementedError()
class FilterableMixin:
def contains(self, pattern):
self.next = Contains(self.root, pattern)
return self.next
def matches(self, pattern):
self.next = Matches(self.root, pattern)
return self.next
def after(self, times):
self.next = After(self.root, times)
return self.next
def __getattr__(self, attr):
"""
Methods such as .onQuery trigger when a packet with that name is intercepted
Adds support for commands such as:
conn.onQuery(query="COPY")
Returns a function because the above command is resolved in two steps:
conn.onQuery becomes conn.__getattr__("onQuery")
conn.onQuery(query="COPY") becomes conn.__getattr__("onQuery")(query="COPY")
"""
if attr.startswith("on"):
def doit(**kwargs):
self.next = OnPacket(self.root, attr[2:], kwargs)
return self.next
return doit
raise AttributeError
class ActionsMixin:
def kill(self):
self.next = KillHandler(self.root)
return self.next
def allow(self):
self.next = AcceptHandler(self.root)
return self.next
def killall(self):
self.next = KillAllHandler(self.root)
return self.next
def reset(self):
self.next = ResetHandler(self.root)
return self.next
def cancel(self, pid):
self.next = CancelHandler(self.root, pid)
return self.next
def connect_delay(self, timeMs):
self.next = ConnectDelayHandler(self.root, timeMs)
return self.next
class AcceptHandler(Handler):
def __init__(self, root):
super().__init__(root)
def _handle(self, flow, message):
return "done"
class KillHandler(Handler):
def __init__(self, root):
super().__init__(root)
def _handle(self, flow, message):
flow.kill()
return "done"
class KillAllHandler(Handler):
def __init__(self, root):
super().__init__(root)
def _handle(self, flow, message):
return "stop"
class ResetHandler(Handler):
# try to force a RST to be sent, something went very wrong!
def __init__(self, root):
super().__init__(root)
def _handle(self, flow, message):
flow.kill() # tell mitmproxy this connection should be closed
# this is a mitmproxy.connections.ClientConnection(mitmproxy.tcp.BaseHandler)
client_conn = flow.client_conn
# this is a regular socket object
conn = client_conn.connection
# cause linux to send a RST
LINGER_ON, LINGER_TIMEOUT = 1, 0
conn.setsockopt(
socket.SOL_SOCKET,
socket.SO_LINGER,
struct.pack("ii", LINGER_ON, LINGER_TIMEOUT),
)
conn.close()
# closing the connection isn't ideal, this thread later crashes when mitmproxy
# tries to call conn.shutdown(), but there's nothing else to clean up so that's
# maybe okay
return "done"
class CancelHandler(Handler):
"Send a SIGINT to the process"
def __init__(self, root, pid):
super().__init__(root)
self.pid = pid
def _handle(self, flow, message):
os.kill(self.pid, signal.SIGINT)
# give the signal a chance to be received before we let the packet through
time.sleep(0.1)
return "done"
class ConnectDelayHandler(Handler):
"Delay the initial packet by sleeping before deciding what to do"
def __init__(self, root, timeMs):
super().__init__(root)
self.timeMs = timeMs
def _handle(self, flow, message):
if message.is_initial:
time.sleep(self.timeMs / 1000.0)
return "done"
class Contains(Handler, ActionsMixin, FilterableMixin):
def __init__(self, root, pattern):
super().__init__(root)
self.pattern = pattern
def _handle(self, flow, message):
if self.pattern in message.content:
return "pass"
return "done"
class Matches(Handler, ActionsMixin, FilterableMixin):
def __init__(self, root, pattern):
super().__init__(root)
self.pattern = re.compile(pattern)
def _handle(self, flow, message):
if self.pattern.search(message.content):
return "pass"
return "done"
class After(Handler, ActionsMixin, FilterableMixin):
"Don't pass execution to our child until we've handled 'times' messages"
def __init__(self, root, times):
super().__init__(root)
self.target = times
def _handle(self, flow, message):
if not hasattr(flow, "_after_count"):
flow._after_count = 0
if flow._after_count >= self.target:
return "pass"
flow._after_count += 1
return "done"
class OnPacket(Handler, ActionsMixin, FilterableMixin):
"""Triggers when a packet of the specified kind comes around"""
def __init__(self, root, packet_kind, kwargs):
super().__init__(root)
self.packet_kind = packet_kind
self.filters = kwargs
def _handle(self, flow, message):
if not message.parsed:
# if this is the first message in the connection we just skip it
return "done"
for msg in message.parsed:
typ = structs.message_type(msg, from_frontend=message.from_client)
if typ == self.packet_kind:
matches = structs.message_matches(
msg, self.filters, message.from_client
)
if matches:
return "pass"
return "done"
class RootHandler(Handler, ActionsMixin, FilterableMixin):
def _handle(self, flow, message):
# do whatever the next Handler tells us to do
return "pass"
class RecorderCommand:
def __init__(self):
self.root = self
self.command = None
def dump(self):
# When the user calls dump() we return everything we've captured
self.command = "dump"
return self
def reset(self):
# If the user calls reset() we dump all captured packets without returning them
self.command = "reset"
return self
# II. Utilities for interfacing with mitmproxy
def build_handler(spec):
"Turns a command string into a RootHandler ready to accept packets"
root = RootHandler()
recorder = RecorderCommand()
handler = eval(spec, {"__builtins__": {}}, {"conn": root, "recorder": recorder})
return handler.root
# a bunch of globals
handler = None # the current handler used to process packets
command_thread = None # sits on the fifo and waits for new commands to come in
captured_messages = queue.Queue() # where we store messages used for recorder.dump()
connection_count = count() # so we can give connections ids in recorder.dump()
def listen_for_commands(fifoname):
def emit_row(conn, from_client, message):
# we're using the COPY text format. It requires us to escape backslashes
cleaned = message.replace("\\", "\\\\")
source = "coordinator" if from_client else "worker"
return "{}\t{}\t{}".format(conn, source, cleaned)
def emit_message(message):
if message.is_initial:
return emit_row(
message.connection_id, message.from_client, "[initial message]"
)
pretty = structs.print(message.parsed)
return emit_row(message.connection_id, message.from_client, pretty)
def all_items(queue_):
"Pulls everything out of the queue without blocking"
try:
while True:
yield queue_.get(block=False)
except queue.Empty:
pass
def drop_terminate_messages(messages):
"""
Terminate() messages happen eventually, Citus doesn't feel any need to send them
immediately, so tests which embed them aren't reproducible and fail to timing
issues. Here we simply drop those messages.
"""
def isTerminate(msg, from_client):
kind = structs.message_type(msg, from_client)
return kind == "Terminate"
for message in messages:
if not message.parsed:
yield message
continue
message.parsed = ListContainer(
[
msg
for msg in message.parsed
if not isTerminate(msg, message.from_client)
]
)
message.parsed.from_frontend = message.from_client
if len(message.parsed) == 0:
continue
yield message
def handle_recorder(recorder):
global connection_count
result = ""
if recorder.command == "reset":
result = ""
connection_count = count()
elif recorder.command != "dump":
# this should never happen
raise Exception("Unrecognized command: {}".format(recorder.command))
results = []
messages = all_items(captured_messages)
messages = drop_terminate_messages(messages)
for message in messages:
if recorder.command == "reset":
continue
results.append(emit_message(message))
result = "\n".join(results)
logging.debug("about to write to fifo")
with open(fifoname, mode="w") as fifo:
logging.debug("successfully opened the fifo for writing")
fifo.write("{}".format(result))
while True:
logging.debug("about to read from fifo")
with open(fifoname, mode="r") as fifo:
logging.debug("successfully opened the fifo for reading")
slug = fifo.read()
logging.info("received new command: %s", slug.rstrip())
try:
handler = build_handler(slug)
if isinstance(handler, RecorderCommand):
handle_recorder(handler)
continue
except Exception as e:
traceback.print_exc()
result = str(e)
else:
result = None
if not result:
try:
ctx.options.update(slug=slug)
except Exception as e:
result = str(e)
else:
result = ""
logging.debug("about to write to fifo")
with open(fifoname, mode="w") as fifo:
logging.debug("successfully opened the fifo for writing")
fifo.write("{}\n".format(result))
logging.info("responded to command: %s", result.split("\n")[0])
def create_thread(fifoname):
global command_thread
if not fifoname:
return
if not len(fifoname):
return
if command_thread:
print("cannot change the fifo path once mitmproxy has started")
return
command_thread = threading.Thread(
target=listen_for_commands, args=(fifoname,), daemon=True
)
command_thread.start()
# III. mitmproxy callbacks
def load(loader):
loader.add_option("slug", str, "conn.allow()", "A script to run")
loader.add_option("fifo", str, "", "Which fifo to listen on for commands")
def configure(updated):
global handler
if "slug" in updated:
text = ctx.options.slug
handler = build_handler(text)
if "fifo" in updated:
fifoname = ctx.options.fifo
create_thread(fifoname)
def tcp_message(flow: tcp.TCPFlow):
"""
This callback is hit every time mitmproxy receives a packet. It's the main entrypoint
into this script.
"""
global connection_count
tcp_msg = flow.messages[-1]
# Keep track of all the different connections, assign a unique id to each
if not hasattr(flow, "connection_id"):
flow.connection_id = next(connection_count)
tcp_msg.connection_id = flow.connection_id
# The first packet the frontend sends shounld be parsed differently
tcp_msg.is_initial = len(flow.messages) == 1
if tcp_msg.is_initial:
# skip parsing initial messages for now, they're not important
tcp_msg.parsed = None
else:
tcp_msg.parsed = structs.parse(
tcp_msg.content, from_frontend=tcp_msg.from_client
)
# record the message, for debugging purposes
captured_messages.put(tcp_msg)
# okay, finally, give the packet to the command the user wants us to use
handler._accept(flow, tcp_msg)