@@ -256,19 +256,35 @@ def __init__(self,
256256 maxlen : int = 500 ) -> None :
257257 super ().__init__ (runtime , maxlen )
258258 self .fig .title = 'Reward'
259- self .maxlines = 1
260- self .y .append (deque (maxlen = self .maxlen ))
259+ self .maxlines = None
260+
261+ def _init_buffer (self ):
262+ if isinstance (self ._rt ._time_step .reward , np .ndarray ):
263+ self .maxlines = self ._rt ._time_step .reward .shape [0 ]
264+ else :
265+ self .maxlines = 1
266+ for _1 in range (self .maxlines ):
267+ self .y .append (deque (maxlen = self .maxlen ))
261268 self .reset_data ()
262269
263270 def render (self , context , viewport ):
264- if self ._rt ._time_step is None :
271+ if self ._rt ._time_step is None or self . _rt . _time_step . reward is None :
265272 return
266- r = self ._rt ._time_step .reward
267- self .fig .linepnt [0 ] = self .maxlen
268- self .y [0 ].append (r )
269- self .fig .linedata [0 ][:self .maxlen * 2 ] = np .array ([self .x ,
270- self .y [0 ]]).T .reshape (
271- (- 1 ,))
273+ if self .maxlines is None :
274+ self ._init_buffer ()
275+ if self .maxlines > 1 :
276+ for i , r in enumerate (self ._rt ._time_step .reward ):
277+ self .fig .linepnt [i ] = self .maxlen
278+ self .y [i ].append (r )
279+ self .fig .linedata [i ][:self .maxlen * 2 ] = np .array ([self .x , self .y [i ]
280+ ]).T .reshape ((- 1 ,))
281+ else :
282+ r = self ._rt ._time_step .reward
283+ self .fig .linepnt [0 ] = self .maxlen
284+ self .y [0 ].append (r )
285+ self .fig .linedata [0 ][:self .maxlen * 2 ] = np .array ([self .x ,
286+ self .y [0 ]]).T .reshape (
287+ (- 1 ,))
272288 pos = mujoco .MjrRect (2 * 300 + 5 , viewport .height - 200 - 5 , 300 , 200 )
273289 mujoco .mjr_figure (pos , self .fig , context .ptr )
274290
0 commit comments