@@ -33,8 +33,8 @@ class StarryPyServer:
3333 """
3434 def __init__ (self , reader , writer , config , factory ):
3535 logger .debug ("Initializing connection." )
36- self ._reader = reader # read packets from client
37- self ._writer = writer # writes packets to client
36+ self ._reader = ZstdFrameReader ( reader , Direction . TO_SERVER ) # read packets from client
37+ self ._writer = ZstdFrameWriter ( writer ) # writes packets to client
3838 self ._client_reader = None # read packets from server (acting as client)
3939 self ._client_writer = None # write packets to server
4040 self .factory = factory
@@ -52,13 +52,10 @@ def __init__(self, reader, writer, config, factory):
5252 logger .info ("Received connection from {}" .format (self .client_ip ))
5353
5454 def start_zstd (self ):
55- self ._reader = ZstdFrameReader (self ._reader , Direction .TO_SERVER )
56- self ._client_reader = ZstdFrameReader (self ._client_reader , Direction .TO_CLIENT )
57- self ._writer = ZstdFrameWriter (self ._writer , skip_packets = 1 )
58- self ._client_writer = ZstdFrameWriter (self ._client_writer )
59- self ._expect_server_loop_death = True
60- self ._server_loop_future .cancel ()
61- self ._server_loop_future = asyncio .create_task (self .server_loop ())
55+ self ._reader .enable_zstd ()
56+ self ._client_reader .enable_zstd ()
57+ self ._writer .enable_zstd (skip_packets = 1 ) # skip this packet
58+ self ._client_writer .enable_zstd ()
6259 logger .info ("Switched to zstd" )
6360
6461
@@ -109,9 +106,11 @@ async def client_loop(self):
109106
110107 :return:
111108 """
112- (self ._client_reader , self ._client_writer ) = \
113- await asyncio .open_connection (self .config ['upstream_host' ],
109+ (reader , writer ) = await asyncio .open_connection (self .config ['upstream_host' ],
114110 self .config ['upstream_port' ])
111+
112+ self ._client_reader = ZstdFrameReader (reader , Direction .TO_CLIENT )
113+ self ._client_writer = ZstdFrameWriter (writer )
115114
116115 try :
117116 while True :
0 commit comments