@@ -97,16 +97,19 @@ class GripperDictType(RCSpaceType):
9797class CameraDictType (RCSpaceType ):
9898 frames : dict [
9999 Annotated [str , "camera_names" ],
100- Annotated [
101- np .ndarray ,
102- # needs to be filled with values downstream
103- lambda height , width : gym .spaces .Box (
104- low = 0 ,
105- high = 255 ,
106- shape = (height , width , 3 ),
107- dtype = np .uint8 ,
108- ),
109- "frame" ,
100+ dict [
101+ Annotated [str , "camera_type" ], # "rgb" or "depth"
102+ Annotated [
103+ np .ndarray ,
104+ # needs to be filled with values downstream
105+ lambda height , width , color_dim = 3 , dtype = np .uint8 , low = 0 , high = 255 : gym .spaces .Box (
106+ low = low ,
107+ high = high ,
108+ shape = (height , width , color_dim ),
109+ dtype = dtype ,
110+ ),
111+ "frame" ,
112+ ],
110113 ],
111114 ]
112115
@@ -387,22 +390,46 @@ def action(self, action: dict[str, Any]) -> dict[str, Any]:
387390
388391
389392class CameraSetWrapper (ActObsInfoWrapper ):
390- def __init__ (self , env , camera_set : BaseCameraSet ):
393+ RGB_KEY = "rgb"
394+ DEPTH_KEY = "depth"
395+
396+ def __init__ (self , env , camera_set : BaseCameraSet , include_depth : bool = False ):
391397 super ().__init__ (env )
392398 self .unwrapped : FR3Env
393399 self .camera_set = camera_set
400+ self .include_depth = include_depth
394401
395402 self .observation_space : gym .spaces .Dict
403+ # rgb is always included
404+ params : dict = {
405+ "frame" : {
406+ "height" : camera_set .config .resolution_height ,
407+ "width" : camera_set .config .resolution_width ,
408+ }
409+ }
410+ if self .include_depth :
411+ # depth is optional
412+ params .update (
413+ {
414+ f"/{ name } /{ self .DEPTH_KEY } /frame" : {
415+ "height" : camera_set .config .resolution_height ,
416+ "width" : camera_set .config .resolution_width ,
417+ "color_dim" : 1 ,
418+ "dtype" : np .float32 ,
419+ "low" : 0.0 ,
420+ "high" : 1.0 ,
421+ }
422+ for name in camera_set .camera_names
423+ }
424+ )
396425 self .observation_space .spaces .update (
397426 get_space (
398427 CameraDictType ,
399- child_dict_keys_to_unfold = {"camera_names" : camera_set .camera_names },
400- params = {
401- "frame" : {
402- "height" : camera_set .config .resolution_height ,
403- "width" : camera_set .config .resolution_height ,
404- }
428+ child_dict_keys_to_unfold = {
429+ "camera_names" : camera_set .camera_names ,
430+ "camera_type" : [self .RGB_KEY , self .DEPTH_KEY ] if self .include_depth else [self .RGB_KEY ],
405431 },
432+ params = params ,
406433 ).spaces
407434 )
408435 self .camera_key = get_space_keys (CameraDictType )[0 ]
@@ -419,11 +446,27 @@ def observation(self, observation: dict, info: dict[str, Any]) -> tuple[dict[str
419446 observation [self .camera_key ] = {}
420447 info ["camera_available" ] = False
421448 return observation , info
422- assert frameset is not None , "No frame available."
423- color_frame_dict : dict [str , np .ndarray ] = {
424- camera_name : frame .camera .color .data for camera_name , frame in frameset .frames .items ()
449+
450+ def check_depth (depth ):
451+ if self .include_depth and depth is None :
452+ msg = "Depth is not available in data but still requested."
453+ raise ValueError (msg )
454+ return self .include_depth
455+
456+ frame_dict : dict [str , dict [str , np .ndarray ]] = {
457+ camera_name : (
458+ {
459+ self .RGB_KEY : frame .camera .color .data ,
460+ }
461+ if check_depth (frame .camera .depth )
462+ else {
463+ self .RGB_KEY : frame .camera .color .data ,
464+ self .DEPTH_KEY : frame .camera .depth .data , # type: ignore
465+ }
466+ )
467+ for camera_name , frame in frameset .frames .items ()
425468 }
426- observation [self .camera_key ] = color_frame_dict
469+ observation [self .camera_key ] = frame_dict
427470
428471 info ["camera_available" ] = True
429472 if frameset .avg_timestamp is not None :
0 commit comments