Skip to content
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
v1.4.3
v2.0.1
40 changes: 40 additions & 0 deletions data_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,45 @@ def _parse(cls, stream: BytesIO, ctx: OrderedDict):
def _build(cls, obj, ctx: OrderedDotDict):
return StarByteArray.build(obj.encode("utf-8"), ctx)

class StarJson(Struct):
Comment thread
evonzee marked this conversation as resolved.
Outdated
@classmethod
def _parse(cls, stream: BytesIO, ctx: OrderedDict):

type_index = Byte.parse(stream, ctx)
if type_index > 0:
type_index -= 1

match type_index:
case 1:
return BFloat32.parse(stream, ctx)
case 2:
return Flag.parse(stream, ctx)
case 3:
return SignedVLQ.parse(stream, ctx)
case 4:
return StarString.parse(stream, ctx)
case 5:
l = VLQ.parse(stream, ctx)
c = []
for _ in range(l):
c.append(StarJson.parse(stream, ctx))
return c
case 6:
data = {}
l = VLQ.parse(stream, ctx)
for _ in range(l):
key = StarString.parse(stream, ctx)
value = StarJson.parse(stream, ctx)
data[key] = value
return data
case _:
return None # Invalid


@classmethod
def _build(cls, obj, ctx: OrderedDotDict):
raise NotImplementedError


class Byte(Struct):
@classmethod
Expand Down Expand Up @@ -748,6 +787,7 @@ class ProtocolRequest(Struct):
class ProtocolResponse(Struct):
"""packet type 1 """
server_response = Byte
info = StarJson


class ServerDisconnect(Struct):
Expand Down
27 changes: 27 additions & 0 deletions plugins/opensb_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""
StarryPy OpenSB Detector Plugin

Detects zstd compression for the stream and sets server configuration accordingly
"""

import asyncio

from base_plugin import SimpleCommandPlugin
from utilities import send_message, Command


class OpenSBDetector(SimpleCommandPlugin):
name = "opensb_detector"

def __init__(self):
super().__init__()

async def activate(self):
await super().activate()

async def on_protocol_response(self, data, connection):
# self.logger.debug("Received protocol response: {} from connection {}".format(data, connection))
if data["parsed"]["info"]["compression"] == "Zstd":
Comment thread
evonzee marked this conversation as resolved.
Outdated
self.logger.info("Detected Zstd compression. Setting server configuration.")
connection.start_zstd()
return True
22 changes: 12 additions & 10 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
aiohttp==3.8.4
aiohappyeyeballs==2.3.7
aiohttp==3.10.4
aiosignal==1.3.1
async-timeout==4.0.2
attrs==23.1.0
charset-normalizer==3.1.0
discord.py==2.3.1
async-timeout==4.0.3
attrs==24.2.0
charset-normalizer==3.3.2
discord.py==2.4.0
docopt==0.6.2
frozenlist==1.3.3
idna==3.4
frozenlist==1.4.1
idna==3.7
irc3==1.1.10
multidict==6.0.4
venusian==3.0.0
yarl==1.9.2
multidict==6.0.5
venusian==3.1.0
yarl==1.9.4
zstandard==0.23.0
51 changes: 40 additions & 11 deletions server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
import logging
import sys
import signal
import traceback

from configuration_manager import ConfigurationManager
from data_parser import ChatReceived
from packets import packets
from pparser import build_packet
from plugin_manager import PluginManager
from utilities import path, read_packet, State, Direction, ChatReceiveMode
from zstd_reader import ZstdFrameReader
from zstd_writer import ZstdFrameWriter


DEBUG = True
Expand All @@ -21,18 +24,21 @@
logger = logging.getLogger('starrypy')
logger.setLevel(loglevel)

class SwitchToZstdException(Exception):
pass

class StarryPyServer:
"""
Primary server class. Handles all the things.
"""
def __init__(self, reader, writer, config, factory):
logger.debug("Initializing connection.")
self._reader = reader
self._writer = writer
self._client_reader = None
self._client_writer = None
self._reader = reader # read packets from client
self._writer = writer # writes packets to client
self._client_reader = None # read packets from server (acting as client)
self._client_writer = None # write packets to server
self.factory = factory
self._client_loop_future = None
self._client_loop_future = asyncio.create_task(self.client_loop())
self._server_loop_future = asyncio.create_task(self.server_loop())
self.state = None
self._alive = True
Expand All @@ -42,8 +48,20 @@ def __init__(self, reader, writer, config, factory):
self._client_read_future = None
self._server_write_future = None
self._client_write_future = None
self._expect_server_loop_death = False
logger.info("Received connection from {}".format(self.client_ip))

def start_zstd(self):
self._reader = ZstdFrameReader(self._reader, Direction.TO_SERVER)
self._client_reader= ZstdFrameReader(self._client_reader, Direction.TO_CLIENT)
self._writer = ZstdFrameWriter(self._writer, skip_packets=1)
self._client_writer = ZstdFrameWriter(self._client_writer)
self._expect_server_loop_death = True
self._server_loop_future.cancel()
self._server_loop_future = asyncio.create_task(self.server_loop())
logger.info("Switched to zstd")


async def server_loop(self):
"""
Main server loop. As clients connect to the proxy, pass the
Expand All @@ -52,14 +70,15 @@ async def server_loop(self):

:return:
"""
(self._client_reader, self._client_writer) = \
await asyncio.open_connection(self.config['upstream_host'],
self.config['upstream_port'])
self._client_loop_future = asyncio.create_task(self.client_loop())

# wait until client is available
while self._client_writer is None:
await asyncio.sleep(0.1)

try:
while True:
packet = await read_packet(self._reader,
Direction.TO_SERVER)
Direction.TO_SERVER)
# Break in case of emergencies:
# if packet['type'] not in [17, 40, 41, 43, 48, 51]:
# logger.debug('c->s {}'.format(packet['type']))
Expand All @@ -74,8 +93,14 @@ async def server_loop(self):
except Exception as err:
logger.error("Server loop exception occurred:"
"{}: {}".format(err.__class__.__name__, err))
logger.error("Error details and traceback: {}".format(traceback.format_exc()))
finally:
self.die()
if not self._expect_server_loop_death:
logger.info("Server loop ended.")
self.die()
else:
logger.info("Restarting server loop for switch to zstd.")
self._expect_server_loop_death = False

async def client_loop(self):
"""
Expand All @@ -84,6 +109,10 @@ async def client_loop(self):

:return:
"""
(self._client_reader, self._client_writer) = \
await asyncio.open_connection(self.config['upstream_host'],
self.config['upstream_port'])

try:
while True:
packet = await read_packet(self._client_reader,
Expand Down
71 changes: 71 additions & 0 deletions zstd_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import asyncio
from io import BufferedReader
import io
import zstandard as zstd

from utilities import Direction

class ZstdFrameReader:
def __init__(self, reader: asyncio.StreamReader, direction: Direction):
self.outputbuffer = NonSeekableMemoryStream()
self.decompressor = zstd.ZstdDecompressor().stream_writer(self.outputbuffer)
self.raw_reader = reader
self.direction = direction

async def readexactly(self, count):
# print(f"Reading exactly {count} bytes")

while True:
# if there are enough bytes, return them
if self.outputbuffer.remaining() >= count:
# print (f"Returning {count} bytes from buffer {self.direction}")
return self.outputbuffer.read(count)

# print(f"Reading from network since there are only {self.remaining} bytes in buffer")
await self.read_from_network(count)

async def read_from_network(self, target_count):
while self.outputbuffer.remaining() < target_count:

chunk = await self.raw_reader.read(32768) # Read in chunks; we'll only get what's available
# print(f"Read {len(chunk)} bytes from network")
if not chunk:
raise asyncio.CancelledError("Connection closed")
try:
self.decompressor.write(chunk)
except zstd.ZstdError:
print("Zstd error, dropping connection")
raise asyncio.CancelledError("Error in compressed data stream!")

class NonSeekableMemoryStream(io.RawIOBase):
def __init__(self):
self.buffer = bytearray()
self.read_pos = 0
self.write_pos = 0

def write(self, b):
self.buffer.extend(b)
self.write_pos += len(b)
return len(b)

def read(self, size=-1):
if size == -1 or size > self.write_pos - self.read_pos:
size = self.write_pos - self.read_pos
if size == 0:
return b''
data = self.buffer[self.read_pos:self.read_pos + size]
self.read_pos += size
if self.read_pos == self.write_pos:
self.buffer = bytearray()
self.read_pos = 0
self.write_pos = 0
return bytes(data)

def remaining(self):
return self.write_pos - self.read_pos

def readable(self):
return True

def writable(self):
return True
25 changes: 25 additions & 0 deletions zstd_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import asyncio
from io import BufferedReader, BytesIO
import zstandard as zstd

class ZstdFrameWriter:
def __init__(self, raw_writer: asyncio.StreamWriter, skip_packets=0):
self.compressor = zstd.ZstdCompressor()
self.raw_writer = raw_writer
self.skip_packets = skip_packets

async def drain(self):
await self.raw_writer.drain()

def close(self):
self.raw_writer.close()
self.compressor = None

def write(self, data):

if self.skip_packets > 0:
self.skip_packets -= 1
self.raw_writer.write(data)
return

self.raw_writer.write(self.compressor.compress(data))