3434import collections
3535import contextlib
3636import threading
37- from typing import NamedTuple
37+ from typing import Callable , NamedTuple , Optional , Union
3838
3939from absl import logging
4040
@@ -177,6 +177,8 @@ def render(
177177 segmentation = False ,
178178 scene_option = None ,
179179 render_flag_overrides = None ,
180+ scene_callback : Optional [Callable [['Physics' , mujoco .MjvScene ],
181+ None ]] = None ,
180182 ):
181183 """Returns a camera view as a NumPy array of pixel values.
182184
@@ -204,12 +206,18 @@ def render(
204206 `{'wireframe': True}` or `{mujoco.mjtRndFlag.mjRND_WIREFRAME: True}`.
205207 See `mujoco.mjtRndFlag` for the set of valid flags. Must be None if
206208 either `depth` or `segmentation` is True.
209+ scene_callback: Called after the scene has been created and before
210+ it is rendered. Can be used to add more geoms to the scene.
207211
208212 Returns:
209213 The rendered RGB, depth or segmentation image.
210214 """
211215 camera = Camera (
212- physics = self , height = height , width = width , camera_id = camera_id )
216+ physics = self ,
217+ height = height ,
218+ width = width ,
219+ camera_id = camera_id ,
220+ scene_callback = scene_callback )
213221 image = camera .render (
214222 overlays = overlays , depth = depth , segmentation = segmentation ,
215223 scene_option = scene_option , render_flag_overrides = render_flag_overrides )
@@ -602,12 +610,16 @@ class Camera:
602610 `camera_id`, for example to render the same view at different resolutions.
603611 """
604612
605- def __init__ (self ,
606- physics ,
607- height = 240 ,
608- width = 320 ,
609- camera_id = - 1 ,
610- max_geom = None ):
613+ def __init__ (
614+ self ,
615+ physics ,
616+ height : int = 240 ,
617+ width : int = 320 ,
618+ camera_id : Union [int , str ] = - 1 ,
619+ max_geom : Optional [int ] = None ,
620+ scene_callback : Optional [Callable [[Physics , mujoco .MjvScene ],
621+ None ]] = None ,
622+ ):
611623 """Initializes a new `Camera`.
612624
613625 Args:
@@ -621,6 +633,8 @@ def __init__(self,
621633 max_geom: Optional integer specifying the maximum number of geoms that can
622634 be rendered in the same scene. If None this will be chosen automatically
623635 based on the estimated maximum number of renderable geoms in the model.
636+ scene_callback: Called after the scene has been created and before
637+ it is rendered. Can be used to add more geoms to the scene.
624638 Raises:
625639 ValueError: If `camera_id` is outside the valid range, or if `width` or
626640 `height` exceed the dimensions of MuJoCo's offscreen framebuffer.
@@ -652,6 +666,7 @@ def __init__(self,
652666 self ._width = width
653667 self ._height = height
654668 self ._physics = physics
669+ self ._scene_callback = scene_callback
655670
656671 # Variables corresponding to structs needed by Mujoco's rendering functions.
657672 self ._scene = wrapper .MjvScene (model = physics .model , max_geom = max_geom )
@@ -844,6 +859,9 @@ def render(
844859 # Update scene geometry.
845860 self .update (scene_option = scene_option )
846861
862+ if self ._scene_callback :
863+ self ._scene_callback (self ._physics , self ._scene )
864+
847865 # Enable flags to compute segmentation labels
848866 if segmentation :
849867 render_flag_overrides .update ({
0 commit comments