Skip to content

Commit 5a2a5c2

Browse files
saran-tcopybara-github
authored andcommitted
Add a zero threshold option to to_xml(_string).
PiperOrigin-RevId: 467678061 Change-Id: If14e6e7025535e8de47c17a8302b45e2bb2d3633
1 parent 701f770 commit 5a2a5c2

6 files changed

Lines changed: 94 additions & 31 deletions

File tree

dm_control/mjcf/attribute.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,12 +158,18 @@ def _assign(self, value):
158158
self._value = float_value
159159

160160
def to_xml_string(self, prefix_root=None,
161-
*, precision=constants.XML_DEFAULT_PRECISION, **kwargs):
161+
*,
162+
precision=constants.XML_DEFAULT_PRECISION,
163+
zero_threshold=0,
164+
**kwargs):
162165
if self._value is None:
163166
return None
164167
else:
165168
out = io.BytesIO()
166-
np.savetxt(out, [self._value], fmt=f'%.{precision:d}g', newline=' ')
169+
value = self._value
170+
if abs(value) < zero_threshold:
171+
value = 0.0
172+
np.savetxt(out, [value], fmt=f'%.{precision:d}g', newline=' ')
167173
return util.to_native_string(out.getvalue())[:-1] # Strip trailing space.
168174

169175

@@ -209,12 +215,19 @@ def _assign_from_string(self, string):
209215
self._assign(np.fromstring(string, dtype=self._dtype, sep=' '))
210216

211217
def to_xml_string(self, prefix_root=None,
212-
*, precision=constants.XML_DEFAULT_PRECISION, **kwargs):
218+
*,
219+
precision=constants.XML_DEFAULT_PRECISION,
220+
zero_threshold=0,
221+
**kwargs):
213222
if self._value is None:
214223
return None
215224
else:
216225
out = io.BytesIO()
217-
np.savetxt(out, self._value, fmt=f'%.{precision:d}g', newline=' ')
226+
value = self._value
227+
if zero_threshold:
228+
value = np.copy(value)
229+
value[np.abs(value) < zero_threshold] = 0
230+
np.savetxt(out, value, fmt=f'%.{precision:d}g', newline=' ')
218231
return util.to_native_string(out.getvalue())[:-1] # Strip trailing space.
219232

220233
def _check_shape(self, array):

dm_control/mjcf/attribute_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,10 @@ def testFloatArray(self):
203203
self.assertEqual(
204204
mujoco.optional.get_attribute_xml_string('float_array', precision=5),
205205
'3.1416 2 1e-16')
206+
self.assertEqual(
207+
mujoco.optional.get_attribute_xml_string(
208+
'float_array', precision=5, zero_threshold=1e-10),
209+
'3.1416 2 0')
206210
self.assertCanBeCleared(mujoco.optional, 'float_array')
207211

208212
def testFormatVeryLargeArray(self):

dm_control/mjcf/base.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,9 @@ def all_children(self):
223223

224224
@abc.abstractmethod
225225
def to_xml(self, prefix_root=None, debug_context=None,
226-
*, precision=constants.XML_DEFAULT_PRECISION):
226+
*,
227+
precision=constants.XML_DEFAULT_PRECISION,
228+
zero_threshold=0):
227229
"""Generates an etree._Element corresponding to this MJCF element.
228230
229231
Args:
@@ -236,6 +238,8 @@ def to_xml(self, prefix_root=None, debug_context=None,
236238
manually pass this argument.
237239
precision: (optional) Number of digits to output for floating point
238240
quantities.
241+
zero_threshold: (optional) When outputting XML, floating point quantities
242+
whose absolute value falls below this threshold will be treated as zero.
239243
240244
Returns:
241245
An etree._Element object.
@@ -244,7 +248,9 @@ def to_xml(self, prefix_root=None, debug_context=None,
244248
@abc.abstractmethod
245249
def to_xml_string(self, prefix_root=None,
246250
self_only=False, pretty_print=True, debug_context=None,
247-
*, precision=constants.XML_DEFAULT_PRECISION):
251+
*,
252+
precision=constants.XML_DEFAULT_PRECISION,
253+
zero_threshold=0):
248254
"""Generates an XML string corresponding to this MJCF element.
249255
250256
Args:
@@ -261,6 +267,8 @@ def to_xml_string(self, prefix_root=None,
261267
manually pass this argument.
262268
precision: (optional) Number of digits to output for floating point
263269
quantities.
270+
zero_threshold: (optional) When outputting XML, floating point quantities
271+
whose absolute value falls below this threshold will be treated as zero.
264272
265273
Returns:
266274
A string.

dm_control/mjcf/element.py

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -511,10 +511,11 @@ def get_attribute_xml_string(self,
511511
attribute_name,
512512
prefix_root=None,
513513
*,
514-
precision=constants.XML_DEFAULT_PRECISION):
514+
precision=constants.XML_DEFAULT_PRECISION,
515+
zero_threshold=0):
515516
self._check_valid_attribute(attribute_name)
516517
return self._attributes[attribute_name].to_xml_string(
517-
prefix_root, precision=precision)
518+
prefix_root, precision=precision, zero_threshold=zero_threshold)
518519

519520
def get_attributes(self):
520521
fix_attribute_name = (
@@ -696,7 +697,9 @@ def all_children(self):
696697
return all_children
697698

698699
def to_xml(self, prefix_root=None, debug_context=None,
699-
*, precision=constants.XML_DEFAULT_PRECISION):
700+
*,
701+
precision=constants.XML_DEFAULT_PRECISION,
702+
zero_threshold=0):
700703
"""Generates an etree._Element corresponding to this MJCF element.
701704
702705
Args:
@@ -709,24 +712,27 @@ def to_xml(self, prefix_root=None, debug_context=None,
709712
manually pass this argument.
710713
precision: (optional) Number of digits to output for floating point
711714
quantities.
715+
zero_threshold: (optional) When outputting XML, floating point quantities
716+
whose absolute value falls below this threshold will be treated as zero.
712717
713718
Returns:
714719
An etree._Element object.
715720
"""
716721
prefix_root = prefix_root or self.namescope
717722
xml_element = etree.Element(self._spec.name)
718723
self._attributes_to_xml(xml_element, prefix_root, debug_context,
719-
precision=precision)
724+
precision=precision, zero_threshold=zero_threshold)
720725
self._children_to_xml(xml_element, prefix_root, debug_context,
721-
precision=precision)
726+
precision=precision, zero_threshold=zero_threshold)
722727
return xml_element
723728

724729
def _attributes_to_xml(self, xml_element, prefix_root, debug_context=None,
725-
*, precision):
730+
*, precision, zero_threshold):
726731
del debug_context # Unused.
727732
for attribute_name, attribute in self._attributes.items():
728733
attribute_value = attribute.to_xml_string(prefix_root,
729-
precision=precision)
734+
precision=precision,
735+
zero_threshold=zero_threshold)
730736
if attribute_name == self._spec.identifier and attribute_value is None:
731737
xml_element.set(attribute_name, self.full_identifier)
732738
elif attribute_value is None:
@@ -735,9 +741,11 @@ def _attributes_to_xml(self, xml_element, prefix_root, debug_context=None,
735741
xml_element.set(attribute_name, attribute_value)
736742

737743
def _children_to_xml(self, xml_element, prefix_root, debug_context=None,
738-
*, precision):
744+
*, precision, zero_threshold):
739745
for child in self.all_children():
740-
child_xml = child.to_xml(prefix_root, debug_context, precision=precision)
746+
child_xml = child.to_xml(prefix_root, debug_context,
747+
precision=precision,
748+
zero_threshold=zero_threshold)
741749
if (child_xml.attrib or len(child_xml) # pylint: disable=g-explicit-length-test
742750
or child.spec.repeated or child.spec.on_demand):
743751
xml_element.append(child_xml)
@@ -749,7 +757,9 @@ def _children_to_xml(self, xml_element, prefix_root, debug_context=None,
749757

750758
def to_xml_string(self, prefix_root=None,
751759
self_only=False, pretty_print=True, debug_context=None,
752-
*, precision=constants.XML_DEFAULT_PRECISION):
760+
*,
761+
precision=constants.XML_DEFAULT_PRECISION,
762+
zero_threshold=0):
753763
"""Generates an XML string corresponding to this MJCF element.
754764
755765
Args:
@@ -766,17 +776,21 @@ def to_xml_string(self, prefix_root=None,
766776
manually pass this argument.
767777
precision: (optional) Number of digits to output for floating point
768778
quantities.
779+
zero_threshold: (optional) When outputting XML, floating point quantities
780+
whose absolute value falls below this threshold will be treated as zero.
769781
770782
Returns:
771783
A string.
772784
"""
773-
xml_element = self.to_xml(prefix_root, debug_context, precision=precision)
785+
xml_element = self.to_xml(prefix_root, debug_context,
786+
precision=precision,
787+
zero_threshold=zero_threshold)
774788
if self_only and len(xml_element) > 0: # pylint: disable=g-explicit-length-test
775789
etree.strip_elements(xml_element, '*')
776790
xml_element.text = '...'
777791
if (self_only and self._spec.identifier and
778792
not self._attributes[self._spec.identifier].to_xml_string(
779-
prefix_root, precision=precision)):
793+
prefix_root, precision=precision, zero_threshold=zero_threshold)):
780794
del xml_element.attrib[self._spec.identifier]
781795
xml_string = util.to_native_string(
782796
etree.tostring(xml_element, pretty_print=pretty_print))
@@ -1005,9 +1019,12 @@ def prefixed_identifier(self, prefix_root=None):
10051019
return prefix + self._attachment.namescope.name + constants.PREFIX_SEPARATOR
10061020

10071021
def to_xml(self, prefix_root=None, debug_context=None,
1008-
*, precision=constants.XML_DEFAULT_PRECISION):
1022+
*,
1023+
precision=constants.XML_DEFAULT_PRECISION,
1024+
zero_threshold=0):
10091025
xml_element = (super().to_xml(prefix_root, debug_context,
1010-
precision=precision))
1026+
precision=precision,
1027+
zero_threshold=zero_threshold))
10111028
xml_element.set('name', self.prefixed_identifier(prefix_root))
10121029
return xml_element
10131030

@@ -1033,9 +1050,12 @@ class _AttachmentFrameChild(_ElementImpl):
10331050
__slots__ = []
10341051

10351052
def to_xml(self, prefix_root=None, debug_context=None,
1036-
*, precision=constants.XML_DEFAULT_PRECISION):
1053+
*,
1054+
precision=constants.XML_DEFAULT_PRECISION,
1055+
zero_threshold=0):
10371056
xml_element = (super().to_xml(prefix_root, debug_context,
1038-
precision=precision))
1057+
precision=precision,
1058+
zero_threshold=zero_threshold))
10391059
if self.spec.namespace is not None:
10401060
if self.name:
10411061
name = (self._parent.prefixed_identifier(prefix_root) +
@@ -1073,16 +1093,20 @@ def all_children(self):
10731093
return [child for child in self._children]
10741094

10751095
def to_xml(self, prefix_root=None, debug_context=None,
1076-
*, precision=constants.XML_DEFAULT_PRECISION):
1096+
*,
1097+
precision=constants.XML_DEFAULT_PRECISION,
1098+
zero_threshold=0):
10771099
prefix_root = prefix_root or self.namescope
10781100
xml_element = (super().to_xml(prefix_root, debug_context,
1079-
precision=precision))
1101+
precision=precision,
1102+
zero_threshold=zero_threshold))
10801103
if isinstance(self._parent, RootElement):
10811104
root_default = etree.Element(self._spec.name)
10821105
root_default.append(xml_element)
10831106
for attachment in self._attachments.values():
10841107
attachment_xml = attachment.to_xml(prefix_root, debug_context,
1085-
precision=precision)
1108+
precision=precision,
1109+
zero_threshold=zero_threshold)
10861110
for attachment_child_xml in attachment_xml:
10871111
root_default.append(attachment_child_xml)
10881112
xml_element = root_default
@@ -1107,12 +1131,16 @@ def _is_third_order_actuator(self, child):
11071131
return False # No other actuator shortcuts have internal dynamics.
11081132

11091133
def _children_to_xml(self, xml_element, prefix_root, debug_context=None,
1110-
*, precision=constants.XML_DEFAULT_PRECISION):
1134+
*,
1135+
precision=constants.XML_DEFAULT_PRECISION,
1136+
zero_threshold=0):
11111137
second_order = []
11121138
third_order = []
11131139
debug_comments = {}
11141140
for child in self.all_children():
1115-
child_xml = child.to_xml(prefix_root, debug_context, precision=precision)
1141+
child_xml = child.to_xml(prefix_root, debug_context,
1142+
precision=precision,
1143+
zero_threshold=zero_threshold)
11161144
if (child_xml.attrib or len(child_xml) # pylint: disable=g-explicit-length-test
11171145
or child.spec.repeated or child.spec.on_demand):
11181146
if self._is_third_order_actuator(child):

dm_control/mjcf/export_with_assets.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222

2323

2424
def export_with_assets(mjcf_model, out_dir, out_file_name=None,
25-
*, precision=constants.XML_DEFAULT_PRECISION):
25+
*,
26+
precision=constants.XML_DEFAULT_PRECISION,
27+
zero_threshold=0):
2628
"""Saves mjcf.model in the given directory in MJCF (XML) format.
2729
2830
Creates an MJCF XML file named `out_file_name` in the specified `out_dir`, and
@@ -36,6 +38,8 @@ def export_with_assets(mjcf_model, out_dir, out_file_name=None,
3638
model name (`mjcf_model.model`) suffixed with '.xml'.
3739
precision: (optional) Number of digits to output for floating point
3840
quantities.
41+
zero_threshold: (optional) When outputting XML, floating point quantities
42+
whose absolute value falls below this threshold will be treated as zero.
3943
4044
Raises:
4145
ValueError: If `out_file_name` is a string that does not end with '.xml'.
@@ -48,7 +52,8 @@ def export_with_assets(mjcf_model, out_dir, out_file_name=None,
4852
assets = mjcf_model.get_assets()
4953
# This should never happen because `mjcf` does not support `.xml` assets.
5054
assert out_file_name not in assets
51-
assets[out_file_name] = mjcf_model.to_xml_string(precision=precision)
55+
assets[out_file_name] = mjcf_model.to_xml_string(
56+
precision=precision, zero_threshold=zero_threshold)
5257
if not os.path.exists(out_dir):
5358
os.makedirs(out_dir)
5459
for filename, contents in assets.items():

dm_control/mjcf/export_with_assets_as_zip.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121

2222

2323
def export_with_assets_as_zip(mjcf_model, out_dir, model_name=None,
24-
*, precision=constants.XML_DEFAULT_PRECISION):
24+
*,
25+
precision=constants.XML_DEFAULT_PRECISION,
26+
zero_threshold=0):
2527
"""Saves mjcf_model and all its assets as a .zip file in the given directory.
2628
2729
Creates a .zip file named `model_name`.zip in the specified `out_dir`, and a
@@ -39,6 +41,8 @@ def export_with_assets_as_zip(mjcf_model, out_dir, model_name=None,
3941
(`mjcf_model.model`).
4042
precision: (optional) Number of digits to output for floating point
4143
quantities.
44+
zero_threshold: (optional) When outputting XML, floating point quantities
45+
whose absolute value falls below this threshold will be treated as zero.
4246
"""
4347

4448
if model_name is None:
@@ -48,7 +52,8 @@ def export_with_assets_as_zip(mjcf_model, out_dir, model_name=None,
4852
zip_name = model_name + '.zip'
4953

5054
files_to_zip = mjcf_model.get_assets()
51-
files_to_zip[xml_name] = mjcf_model.to_xml_string(precision=precision)
55+
files_to_zip[xml_name] = mjcf_model.to_xml_string(
56+
precision=precision, zero_threshold=zero_threshold)
5257

5358
if not os.path.exists(out_dir):
5459
os.makedirs(out_dir)

0 commit comments

Comments
 (0)