9494_RAGGED_ADDRS = {
9595 'nq' : 'jnt_qposadr' ,
9696 'nv' : 'jnt_dofadr' ,
97+ 'na' : 'actuator_actadr' ,
9798 'nsensordata' : 'sensor_adr' ,
9899 'nnumericdata' : 'numeric_adr' ,
99100}
@@ -218,6 +219,8 @@ def _get_size_name_to_element_names(model):
218219 # For example, the element names for "nv" axis come from "njnt".
219220 for size_name , address_field_name in _RAGGED_ADDRS .items ():
220221 donor = 'n' + address_field_name .split ('_' )[0 ]
222+ if donor == 'nactuator' :
223+ donor = 'nu'
221224 if donor in size_name_to_element_names :
222225 size_name_to_element_names [size_name ] = size_name_to_element_names [donor ]
223226
@@ -230,13 +233,6 @@ def _get_size_name_to_element_names(model):
230233 assert None not in mocap_body_names
231234 size_name_to_element_names ['nmocap' ] = mocap_body_names
232235
233- # Arrays with dimension `na` correspond to stateful actuators. MuJoCo's
234- # compiler requires that these are always defined after stateless actuators,
235- # so we only need the final `na` elements in the list of all actuator names.
236- if model .na :
237- all_actuator_names = size_name_to_element_names ['nu' ]
238- size_name_to_element_names ['na' ] = all_actuator_names [- model .na :]
239-
240236 return size_name_to_element_names
241237
242238
@@ -255,8 +251,11 @@ def _get_size_name_to_element_sizes(model):
255251
256252 for size_name , address_field_name in _RAGGED_ADDRS .items ():
257253 addresses = getattr (model , address_field_name ).ravel ()
258- total_length = getattr (model , size_name )
259- element_sizes = np .diff (np .r_ [addresses , total_length ])
254+ if size_name == 'na' :
255+ element_sizes = np .where (addresses == - 1 , 0 , 1 )
256+ else :
257+ total_length = getattr (model , size_name )
258+ element_sizes = np .diff (np .r_ [addresses , total_length ])
260259 size_name_to_element_sizes [size_name ] = element_sizes
261260
262261 return size_name_to_element_sizes
@@ -282,7 +281,9 @@ def make_axis_indexers(model):
282281 element_names = size_name_to_element_names [size_name ]
283282 if size_name in _RAGGED_ADDRS :
284283 element_sizes = size_name_to_element_sizes [size_name ]
285- indexer = RaggedNamedAxis (element_names , element_sizes )
284+ singleton = (size_name == 'na' )
285+ indexer = RaggedNamedAxis (element_names , element_sizes ,
286+ singleton = singleton )
286287 else :
287288 indexer = RegularNamedAxis (element_names )
288289 axis_indexers [size_name ] = indexer
@@ -377,12 +378,13 @@ def names(self):
377378class RaggedNamedAxis (Axis ):
378379 """Represents an axis where the named elements may vary in size."""
379380
380- def __init__ (self , element_names , element_sizes ):
381+ def __init__ (self , element_names , element_sizes , singleton = False ):
381382 """Initializes a new `RaggedNamedAxis` instance.
382383
383384 Args:
384385 element_names: A list or array containing the element names.
385386 element_sizes: A list or array containing the size of each element.
387+ singleton: Whether to reduce singleton slices to scalars.
386388 """
387389 names_to_slices = {}
388390 names_to_indices = {}
@@ -391,7 +393,10 @@ def __init__(self, element_names, element_sizes):
391393 for name , size in zip (element_names , element_sizes ):
392394 # Don't add unnamed elements to the dicts.
393395 if name :
394- names_to_slices [name ] = slice (offset , offset + size )
396+ if size == 1 and singleton :
397+ names_to_slices [name ] = offset
398+ else :
399+ names_to_slices [name ] = slice (offset , offset + size )
395400 names_to_indices [name ] = range (offset , offset + size )
396401 offset += size
397402
@@ -400,29 +405,29 @@ def __init__(self, element_names, element_sizes):
400405 self ._names_to_slices = names_to_slices
401406 self ._names_to_indices = names_to_indices
402407
403- def convert_key_item (self , key ):
408+ def convert_key_item (self , key_item ):
404409 """Converts a named indexing expression to a numpy-friendly index."""
405410
406- _validate_key_item (key )
411+ _validate_key_item (key_item )
407412
408- if isinstance (key , str ):
409- key = self ._names_to_slices [util .to_native_string (key )]
413+ if isinstance (key_item , str ):
414+ key_item = self ._names_to_slices [util .to_native_string (key_item )]
410415
411- elif isinstance (key , (list , np .ndarray )):
416+ elif isinstance (key_item , (list , np .ndarray )):
412417 # We assume that either all or none of the items in the sequence are
413418 # strings representing names. If there is a mix, we will let NumPy throw
414419 # an error when trying to index with the returned key.
415- if isinstance (key [0 ], str ):
420+ if isinstance (key_item [0 ], str ):
416421 new_key = []
417- for k in key :
422+ for k in key_item :
418423 idx = self ._names_to_indices [util .to_native_string (k )]
419424 if isinstance (idx , int ):
420425 new_key .append (idx )
421426 else :
422427 new_key .extend (idx )
423- key = new_key
428+ key_item = new_key
424429
425- return key
430+ return key_item
426431
427432 @property
428433 def names (self ):
0 commit comments