66import json
77
88from galaxy .reader import StreamLineReader
9+ from galaxy .task_manager import TaskManager
910
1011class JsonRpcError (Exception ):
1112 def __init__ (self , code , message , data = None ):
@@ -52,7 +53,8 @@ def __init__(self, data=None):
5253 super ().__init__ (0 , "Unknown error" , data )
5354
5455Request = namedtuple ("Request" , ["method" , "params" , "id" ], defaults = [{}, None ])
55- Method = namedtuple ("Method" , ["callback" , "signature" , "internal" , "sensitive_params" ])
56+ Method = namedtuple ("Method" , ["callback" , "signature" , "immediate" , "sensitive_params" ])
57+
5658
5759def anonymise_sensitive_params (params , sensitive_params ):
5860 anomized_data = "****"
@@ -74,9 +76,9 @@ def __init__(self, reader, writer, encoder=json.JSONEncoder()):
7476 self ._encoder = encoder
7577 self ._methods = {}
7678 self ._notifications = {}
77- self ._eof_listeners = []
79+ self ._task_manager = TaskManager ( "jsonrpc server" )
7880
79- def register_method (self , name , callback , internal , sensitive_params = False ):
81+ def register_method (self , name , callback , immediate , sensitive_params = False ):
8082 """
8183 Register method
8284
@@ -86,9 +88,9 @@ def register_method(self, name, callback, internal, sensitive_params=False):
8688 :param sensitive_params: list of parameters that are anonymized before logging; \
8789 if False - no params are considered sensitive, if True - all params are considered sensitive
8890 """
89- self ._methods [name ] = Method (callback , inspect .signature (callback ), internal , sensitive_params )
91+ self ._methods [name ] = Method (callback , inspect .signature (callback ), immediate , sensitive_params )
9092
91- def register_notification (self , name , callback , internal , sensitive_params = False ):
93+ def register_notification (self , name , callback , immediate , sensitive_params = False ):
9294 """
9395 Register notification
9496
@@ -98,10 +100,7 @@ def register_notification(self, name, callback, internal, sensitive_params=False
98100 :param sensitive_params: list of parameters that are anonymized before logging; \
99101 if False - no params are considered sensitive, if True - all params are considered sensitive
100102 """
101- self ._notifications [name ] = Method (callback , inspect .signature (callback ), internal , sensitive_params )
102-
103- def register_eof (self , callback ):
104- self ._eof_listeners .append (callback )
103+ self ._notifications [name ] = Method (callback , inspect .signature (callback ), immediate , sensitive_params )
105104
106105 async def run (self ):
107106 while self ._active :
@@ -118,14 +117,16 @@ async def run(self):
118117 self ._handle_input (data )
119118 await asyncio .sleep (0 ) # To not starve task queue
120119
121- def stop (self ):
120+ def close (self ):
121+ logging .info ("Closing JSON-RPC server - not more messages will be read" )
122122 self ._active = False
123123
124+ async def wait_closed (self ):
125+ await self ._task_manager .wait ()
126+
124127 def _eof (self ):
125128 logging .info ("Received EOF" )
126- self .stop ()
127- for listener in self ._eof_listeners :
128- listener ()
129+ self .close ()
129130
130131 def _handle_input (self , data ):
131132 try :
@@ -145,20 +146,19 @@ def _handle_notification(self, request):
145146 logging .error ("Received unknown notification: %s" , request .method )
146147 return
147148
148- callback , signature , internal , sensitive_params = method
149+ callback , signature , immediate , sensitive_params = method
149150 self ._log_request (request , sensitive_params )
150151
151152 try :
152153 bound_args = signature .bind (** request .params )
153154 except TypeError :
154155 self ._send_error (request .id , InvalidParams ())
155156
156- if internal :
157- # internal requests are handled immediately
157+ if immediate :
158158 callback (* bound_args .args , ** bound_args .kwargs )
159159 else :
160160 try :
161- asyncio . create_task (callback (* bound_args .args , ** bound_args .kwargs ))
161+ self . _task_manager . create_task (callback (* bound_args .args , ** bound_args .kwargs ), request . method )
162162 except Exception :
163163 logging .exception ("Unexpected exception raised in notification handler" )
164164
@@ -169,16 +169,15 @@ def _handle_request(self, request):
169169 self ._send_error (request .id , MethodNotFound ())
170170 return
171171
172- callback , signature , internal , sensitive_params = method
172+ callback , signature , immediate , sensitive_params = method
173173 self ._log_request (request , sensitive_params )
174174
175175 try :
176176 bound_args = signature .bind (** request .params )
177177 except TypeError :
178178 self ._send_error (request .id , InvalidParams ())
179179
180- if internal :
181- # internal requests are handled immediately
180+ if immediate :
182181 response = callback (* bound_args .args , ** bound_args .kwargs )
183182 self ._send_response (request .id , response )
184183 else :
@@ -190,11 +189,13 @@ async def handle():
190189 self ._send_error (request .id , MethodNotFound ())
191190 except JsonRpcError as error :
192191 self ._send_error (request .id , error )
192+ except asyncio .CancelledError :
193+ self ._send_error (request .id , Aborted ())
193194 except Exception as e : #pylint: disable=broad-except
194195 logging .exception ("Unexpected exception raised in plugin handler" )
195196 self ._send_error (request .id , UnknownError (str (e )))
196197
197- asyncio . create_task (handle ())
198+ self . _task_manager . create_task (handle (), request . method )
198199
199200 @staticmethod
200201 def _parse_request (data ):
@@ -215,7 +216,7 @@ def _send(self, data):
215216 logging .debug ("Sending data: %s" , line )
216217 data = (line + "\n " ).encode ("utf-8" )
217218 self ._writer .write (data )
218- asyncio . create_task (self ._writer .drain ())
219+ self . _task_manager . create_task (self ._writer .drain (), "drain" )
219220 except TypeError as error :
220221 logging .error (str (error ))
221222
@@ -255,6 +256,7 @@ def __init__(self, writer, encoder=json.JSONEncoder()):
255256 self ._writer = writer
256257 self ._encoder = encoder
257258 self ._methods = {}
259+ self ._task_manager = TaskManager ("notification client" )
258260
259261 def notify (self , method , params , sensitive_params = False ):
260262 """
@@ -273,13 +275,16 @@ def notify(self, method, params, sensitive_params=False):
273275 self ._log (method , params , sensitive_params )
274276 self ._send (notification )
275277
278+ async def close (self ):
279+ await self ._task_manager .wait ()
280+
276281 def _send (self , data ):
277282 try :
278283 line = self ._encoder .encode (data )
279284 data = (line + "\n " ).encode ("utf-8" )
280285 logging .debug ("Sending %d byte of data" , len (data ))
281286 self ._writer .write (data )
282- asyncio . create_task (self ._writer .drain ())
287+ self . _task_manager . create_task (self ._writer .drain (), "drain" )
283288 except TypeError as error :
284289 logging .error ("Failed to parse outgoing message: %s" , str (error ))
285290
0 commit comments