1+ import contextlib
12import copy
23import logging
34import threading
45import time
56from dataclasses import dataclass , field
6- from typing import Any , Dict , Iterator , List , Optional , Sequence , Tuple , TypedDict
7+ from typing import Any , ClassVar , Dict , Iterator , List , Sequence , Tuple
78
89import numpy as np
910
2526except ImportError :
2627 HAS_PYNPUT = False
2728
28- from rcs .envs .base import (
29- ArmWithGripper ,
30- ControlMode ,
31- GripperDictType ,
32- JointsDictType ,
33- RelativeTo ,
34- )
29+ from rcs .envs .base import ControlMode , RelativeTo
3530from rcs .operator .interface import BaseOperator , BaseOperatorConfig , TeleopCommands
3631from rcs .sim .sim import Sim
3732from rcs .utils import SimpleFrameRate
@@ -64,14 +59,15 @@ def __init__(
6459 pulses_per_revolution : int = 4095 ,
6560 ):
6661 if not HAS_DYNAMIXEL_SDK :
67- raise ImportError ("dynamixel_sdk is not installed. Please install it to use GelloOperator." )
62+ msg = "dynamixel_sdk is not installed. Please install it to use GelloOperator."
63+ raise ImportError (msg )
6864
6965 self ._ids = ids
7066 self ._port = port
7167 self ._baudrate = baudrate
7268 self ._pulses_per_revolution = pulses_per_revolution
7369 self ._lock = threading .Lock ()
74- self ._buffered_joint_positions = None
70+ self ._buffered_joint_positions : np . ndarray | None = None
7571
7672 self ._portHandler = PortHandler (self ._port )
7773 self ._packetHandler = PacketHandler (2.0 )
@@ -86,29 +82,32 @@ def __init__(
8682 for dxl_id in self ._ids :
8783 self ._groupSyncReadHandlers [key ].addParam (dxl_id )
8884
89- if key != "model_number" and key != "present_position" :
85+ if key not in { "model_number" , "present_position" } :
9086 self ._groupSyncWriteHandlers [key ] = GroupSyncWrite (
9187 self ._portHandler , self ._packetHandler , entry ["addr" ], entry ["len" ]
9288 )
9389
9490 if not self ._portHandler .openPort ():
95- raise ConnectionError (f"Failed to open port { self ._port } " )
91+ msg = f"Failed to open port { self ._port } "
92+ raise ConnectionError (msg )
9693 if not self ._portHandler .setBaudRate (self ._baudrate ):
97- raise ConnectionError (f"Failed to set baudrate { self ._baudrate } " )
94+ msg = f"Failed to set baudrate { self ._baudrate } "
95+ raise ConnectionError (msg )
9896
9997 self ._stop_thread = threading .Event ()
100- self ._polling_thread = None
98+ self ._polling_thread : threading . Thread | None = None
10199 self ._is_polling = False
102100
103101 def write_value_by_name (self , name : str , values : Sequence [int | None ]):
104102 if len (values ) != len (self ._ids ):
105- raise ValueError (f"The length of { name } must match the number of servos" )
103+ msg = f"The length of { name } must match the number of servos"
104+ raise ValueError (msg )
106105
107106 handler = self ._groupSyncWriteHandlers [name ]
108107 value_len = XL330_CONTROL_TABLE [name ]["len" ]
109108
110109 with self ._lock :
111- for dxl_id , value in zip (self ._ids , values ):
110+ for dxl_id , value in zip (self ._ids , values , strict = False ):
112111 if value is None :
113112 continue
114113 param = [(value >> (8 * i )) & 0xFF for i in range (value_len )]
@@ -117,7 +116,8 @@ def write_value_by_name(self, name: str, values: Sequence[int | None]):
117116 comm_result = handler .txPacket ()
118117 if comm_result != COMM_SUCCESS :
119118 handler .clearParam ()
120- raise RuntimeError (f"Failed to syncwrite { name } : { self ._packetHandler .getTxRxResult (comm_result )} " )
119+ msg = f"Failed to syncwrite { name } : { self ._packetHandler .getTxRxResult (comm_result )} "
120+ raise RuntimeError (msg )
121121 handler .clearParam ()
122122
123123 def read_value_by_name (self , name : str ) -> List [int ]:
@@ -128,7 +128,8 @@ def read_value_by_name(self, name: str) -> List[int]:
128128 with self ._lock :
129129 comm_result = handler .txRxPacket ()
130130 if comm_result != COMM_SUCCESS :
131- raise RuntimeError (f"Failed to sync read { name } : { self ._packetHandler .getTxRxResult (comm_result )} " )
131+ msg = f"Failed to sync read { name } : { self ._packetHandler .getTxRxResult (comm_result )} "
132+ raise RuntimeError (msg )
132133
133134 values = []
134135 for dxl_id in self ._ids :
@@ -137,7 +138,8 @@ def read_value_by_name(self, name: str) -> List[int]:
137138 value = int (np .int32 (np .uint32 (value )))
138139 values .append (value )
139140 else :
140- raise RuntimeError (f"Failed to get { name } for ID { dxl_id } " )
141+ msg = f"Failed to get { name } for ID { dxl_id } "
142+ raise RuntimeError (msg )
141143 return values
142144
143145 def start_joint_polling (self ):
@@ -213,7 +215,7 @@ class DynamixelControlConfig:
213215 goal_current : List [int ] = field (default_factory = list )
214216 operating_mode : List [int ] = field (default_factory = list )
215217
216- _UPDATE_ORDER = [
218+ _UPDATE_ORDER : ClassVar [ list [ str ]] = [
217219 "operating_mode" ,
218220 "goal_current" ,
219221 "kp_p" ,
@@ -329,10 +331,8 @@ def _goal_position_to_pulses(self, goals):
329331 return [self ._driver ._rad_to_pulses (rad ) for rad in goals_raw ]
330332
331333 def close (self ):
332- try :
334+ with contextlib . suppress ( Exception ) :
333335 self ._driver .write_value_by_name ("torque_enable" , [0 ] * self ._num_total_joints )
334- except :
335- pass
336336 self ._driver .close ()
337337
338338
@@ -342,7 +342,7 @@ def close(self):
342342class GelloOperator (BaseOperator ):
343343 control_mode = (ControlMode .JOINTS , RelativeTo .NONE )
344344
345- def __init__ (self , config : GelloConfig , sim : Sim | None = None ):
345+ def __init__ (self , config : " GelloConfig" , sim : Sim | None = None ):
346346 super ().__init__ (config , sim )
347347 self .config : GelloConfig
348348 self ._resource_lock = threading .Lock ()
@@ -353,7 +353,7 @@ def __init__(self, config: GelloConfig, sim: Sim | None = None):
353353
354354 self .controller_names = list (self .config .arms .keys ())
355355
356- self ._last_joints = {name : None for name in self .controller_names }
356+ self ._last_joints : Dict [ str , np . ndarray | None ] = {name : None for name in self .controller_names }
357357 self ._last_gripper = {name : 1.0 for name in self .controller_names }
358358 self ._hws : Dict [str , GelloHardware ] = {}
359359
@@ -386,12 +386,13 @@ def reset_operator_state(self):
386386 pass
387387
388388 def consume_action (self ) -> Dict [str , Any ]:
389- actions = {}
389+ actions : Dict [ str , Any ] = {}
390390 with self ._resource_lock :
391391 for name in self .controller_names :
392- if self ._last_joints [name ] is not None :
392+ joints = self ._last_joints [name ]
393+ if joints is not None :
393394 actions [name ] = {
394- "joints" : self . _last_joints [ name ] .copy (),
395+ "joints" : joints .copy (),
395396 "gripper" : np .array ([self ._last_gripper [name ]]),
396397 }
397398 return actions
@@ -434,6 +435,6 @@ def close(self):
434435
435436@dataclass (kw_only = True )
436437class GelloConfig (BaseOperatorConfig ):
437- operator_class = GelloOperator
438+ operator_class : type [ BaseOperator ] = field ( default = GelloOperator )
438439 # Dictionary for multi-arm setups: {"left": GelloArmConfig(...), "right": GelloArmConfig(...)}
439440 arms : Dict [str , GelloArmConfig ] = field (default_factory = lambda : {"right" : GelloArmConfig ()})
0 commit comments