Skip to content

Commit 701f770

Browse files
saran-tcopybara-github
authored andcommitted
Add a floating point precision option to to_xml(_string).
PiperOrigin-RevId: 467664520 Change-Id: Ibac426918ae8466dc09e65635afbf961157cb62a
1 parent 42527bc commit 701f770

8 files changed

Lines changed: 133 additions & 47 deletions

File tree

dm_control/mjcf/attribute.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def _before_clear(self):
105105
def _assign_from_string(self, string):
106106
self._assign(string)
107107

108-
def to_xml_string(self, prefix_root): # pylint: disable=unused-argument
108+
def to_xml_string(self, prefix_root, **kwargs): # pylint: disable=unused-argument
109109
if self._value is None:
110110
return None
111111
else:
@@ -157,6 +157,15 @@ def _assign(self, value):
157157
raise ValueError('Expect a float value: got {}'.format(value)) from None
158158
self._value = float_value
159159

160+
def to_xml_string(self, prefix_root=None,
161+
*, precision=constants.XML_DEFAULT_PRECISION, **kwargs):
162+
if self._value is None:
163+
return None
164+
else:
165+
out = io.BytesIO()
166+
np.savetxt(out, [self._value], fmt=f'%.{precision:d}g', newline=' ')
167+
return util.to_native_string(out.getvalue())[:-1] # Strip trailing space.
168+
160169

161170
class Keyword(_Attribute):
162171
"""A keyword MJCF attribute."""
@@ -199,15 +208,13 @@ def _assign(self, value):
199208
def _assign_from_string(self, string):
200209
self._assign(np.fromstring(string, dtype=self._dtype, sep=' '))
201210

202-
def to_xml_string(self, prefix_root=None): # pylint: disable=unused-argument
211+
def to_xml_string(self, prefix_root=None,
212+
*, precision=constants.XML_DEFAULT_PRECISION, **kwargs):
203213
if self._value is None:
204214
return None
205215
else:
206216
out = io.BytesIO()
207-
# 17 decimal digits is sufficient to represent a double float without loss
208-
# of precision.
209-
# https://en.wikipedia.org/wiki/IEEE_754#Character_representation
210-
np.savetxt(out, self._value, fmt='%.17g', newline=' ')
217+
np.savetxt(out, self._value, fmt=f'%.{precision:d}g', newline=' ')
211218
return util.to_native_string(out.getvalue())[:-1] # Strip trailing space.
212219

213220
def _check_shape(self, array):
@@ -255,7 +262,7 @@ def _defaults_string(self, prefix_root):
255262
prefix.append(self._value or '')
256263
return constants.PREFIX_SEPARATOR.join(prefix) or constants.PREFIX_SEPARATOR
257264

258-
def to_xml_string(self, prefix_root=None):
265+
def to_xml_string(self, prefix_root=None, **kwargs):
259266
if self._parent.tag == constants.DEFAULT:
260267
return self._defaults_string(prefix_root)
261268
elif self._value:
@@ -352,7 +359,7 @@ def _defaults_string(self, prefix_root):
352359
out_string = prefix + self._value
353360
return out_string
354361

355-
def to_xml_string(self, prefix_root):
362+
def to_xml_string(self, prefix_root, **kwargs):
356363
self._check_dead_reference()
357364
if isinstance(self._value, base.Element):
358365
return self._value.prefixed_identifier(prefix_root)
@@ -388,7 +395,7 @@ def _before_clear(self):
388395
if self._value:
389396
self._parent.namescope.remove(constants.BASEPATH, self._path_namespace)
390397

391-
def to_xml_string(self, prefix_root=None):
398+
def to_xml_string(self, prefix_root=None, **kwargs):
392399
return None
393400

394401

@@ -527,7 +534,7 @@ def get_contents(self):
527534
'querying the contents.')
528535
return self._value.contents
529536

530-
def to_xml_string(self, prefix_root=None):
537+
def to_xml_string(self, prefix_root=None, **kwargs):
531538
"""Returns the asset filename as it will appear in the generated XML."""
532539
del prefix_root # Unused
533540
if self._value is not None:

dm_control/mjcf/attribute_test.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,15 +131,37 @@ def assertCanNotBeCleared(self, mjcf_element, attribute_name):
131131

132132
def testFloatScalar(self):
133133
mujoco = self._mujoco
134-
mujoco.optional.float = 5
135-
self.assertEqual(mujoco.optional.float, 5)
134+
mujoco.optional.float = 0.357357
135+
self.assertEqual(mujoco.optional.float, 0.357357)
136136
self.assertEqual(type(mujoco.optional.float), float)
137137
with self.assertRaisesRegex(ValueError, 'Expect a float value'):
138138
mujoco.optional.float = 'five'
139139
# failed assignment should not change the value
140-
self.assertEqual(mujoco.optional.float, 5)
141-
self.assertXMLStringEqual(mujoco.optional, 'float', '5.0')
142-
self.assertCanBeCleared(mujoco.optional, 'float')
140+
self.assertEqual(mujoco.optional.float, 0.357357)
141+
self.assertEqual(
142+
mujoco.optional.get_attribute_xml_string('float', precision=1),
143+
'0.4')
144+
self.assertEqual(
145+
mujoco.optional.get_attribute_xml_string('float', precision=2),
146+
'0.36')
147+
self.assertEqual(
148+
mujoco.optional.get_attribute_xml_string('float', precision=3),
149+
'0.357')
150+
self.assertEqual(
151+
mujoco.optional.get_attribute_xml_string('float', precision=4),
152+
'0.3574')
153+
self.assertEqual(
154+
mujoco.optional.get_attribute_xml_string('float', precision=5),
155+
'0.35736')
156+
self.assertEqual(
157+
mujoco.optional.get_attribute_xml_string('float', precision=6),
158+
'0.357357')
159+
self.assertEqual(
160+
mujoco.optional.get_attribute_xml_string('float', precision=7),
161+
'0.357357')
162+
self.assertEqual(
163+
mujoco.optional.get_attribute_xml_string('float', precision=8),
164+
'0.357357')
143165

144166
def testIntScalar(self):
145167
mujoco = self._mujoco
@@ -178,6 +200,9 @@ def testFloatArray(self):
178200
mujoco.optional.float_array = [np.pi, 2, 1e-16]
179201
self.assertXMLStringEqual(mujoco.optional, 'float_array',
180202
'3.1415926535897931 2 9.9999999999999998e-17')
203+
self.assertEqual(
204+
mujoco.optional.get_attribute_xml_string('float_array', precision=5),
205+
'3.1416 2 1e-16')
181206
self.assertCanBeCleared(mujoco.optional, 'float_array')
182207

183208
def testFormatVeryLargeArray(self):

dm_control/mjcf/base.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
import abc
1919

20+
from dm_control.mjcf import constants
21+
2022

2123
class Element(metaclass=abc.ABCMeta):
2224
"""Abstract base class for an MJCF element.
@@ -220,7 +222,8 @@ def all_children(self):
220222
pass
221223

222224
@abc.abstractmethod
223-
def to_xml(self, prefix_root=None, debug_context=None):
225+
def to_xml(self, prefix_root=None, debug_context=None,
226+
*, precision=constants.XML_DEFAULT_PRECISION):
224227
"""Generates an etree._Element corresponding to this MJCF element.
225228
226229
Args:
@@ -231,14 +234,17 @@ def to_xml(self, prefix_root=None, debug_context=None):
231234
the debugging information associated with the generated XML is written.
232235
This is intended for internal use within PyMJCF; users should never need
233236
manually pass this argument.
237+
precision: (optional) Number of digits to output for floating point
238+
quantities.
234239
235240
Returns:
236241
An etree._Element object.
237242
"""
238243

239244
@abc.abstractmethod
240245
def to_xml_string(self, prefix_root=None,
241-
self_only=False, pretty_print=True, debug_context=None):
246+
self_only=False, pretty_print=True, debug_context=None,
247+
*, precision=constants.XML_DEFAULT_PRECISION):
242248
"""Generates an XML string corresponding to this MJCF element.
243249
244250
Args:
@@ -253,6 +259,8 @@ def to_xml_string(self, prefix_root=None,
253259
the debugging information associated with the generated XML is written.
254260
This is intended for internal use within PyMJCF; users should never need
255261
manually pass this argument.
262+
precision: (optional) Number of digits to output for floating point
263+
quantities.
256264
257265
Returns:
258266
A string.

dm_control/mjcf/constants.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,8 @@
7272
INDIRECT_REFERENCE_ATTRIB = {
7373
'xbody': 'body',
7474
}
75+
76+
# 17 decimal digits is sufficient to represent a double float without loss
77+
# of precision.
78+
# https://en.wikipedia.org/wiki/IEEE_754#Character_representation
79+
XML_DEFAULT_PRECISION = 17

dm_control/mjcf/element.py

Lines changed: 51 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def _mjcf_property(self):
7575
err_with_next_tb = err.with_traceback(tb.tb_next)
7676
if isinstance(err, AttributeError):
7777
self._last_attribute_error = err_with_next_tb # pylint: disable=protected-access
78-
raise err_with_next_tb
78+
raise err_with_next_tb # pylint: disable=raise-missing-from
7979
return _raw_property(_mjcf_property)
8080

8181

@@ -182,7 +182,7 @@ def __init__(self, spec, parent, attributes=None):
182182
attribute_obj._force_clear() # pylint: disable=protected-access
183183
# Then raise a meaningful error
184184
err_type, err, tb = sys.exc_info()
185-
raise err_type(
185+
raise err_type( # pylint: disable=raise-missing-from
186186
f'during initialization of attribute {attribute_spec.name!r} of '
187187
f'element <{self._spec.name}>: {err}').with_traceback(tb)
188188

@@ -507,9 +507,14 @@ def _get_attribute(self, attribute_name):
507507
self._check_valid_attribute(attribute_name)
508508
return self._attributes[attribute_name].value
509509

510-
def get_attribute_xml_string(self, attribute_name, prefix_root=None):
510+
def get_attribute_xml_string(self,
511+
attribute_name,
512+
prefix_root=None,
513+
*,
514+
precision=constants.XML_DEFAULT_PRECISION):
511515
self._check_valid_attribute(attribute_name)
512-
return self._attributes[attribute_name].to_xml_string(prefix_root)
516+
return self._attributes[attribute_name].to_xml_string(
517+
prefix_root, precision=precision)
513518

514519
def get_attributes(self):
515520
fix_attribute_name = (
@@ -541,7 +546,7 @@ def set_attributes(self, **kwargs):
541546
self._set_attribute(name, old_value)
542547
# Then raise a meaningful error.
543548
err_type, err, tb = sys.exc_info()
544-
raise err_type(
549+
raise err_type( # pylint: disable=raise-missing-from
545550
f'during assignment to attribute {attribute_name!r} of '
546551
f'element <{self._spec.name}>: {err}').with_traceback(tb)
547552

@@ -554,7 +559,7 @@ def _check_valid_child(self, element_name):
554559
try:
555560
return self._spec.children[element_name]
556561
except KeyError:
557-
raise AttributeError(
562+
raise AttributeError( # pylint: disable=raise-missing-from
558563
'<{}> is not a valid child of <{}>'
559564
.format(element_name, self._spec.name))
560565

@@ -690,7 +695,8 @@ def all_children(self):
690695
if child.spec.repeated]
691696
return all_children
692697

693-
def to_xml(self, prefix_root=None, debug_context=None):
698+
def to_xml(self, prefix_root=None, debug_context=None,
699+
*, precision=constants.XML_DEFAULT_PRECISION):
694700
"""Generates an etree._Element corresponding to this MJCF element.
695701
696702
Args:
@@ -701,30 +707,37 @@ def to_xml(self, prefix_root=None, debug_context=None):
701707
the debugging information associated with the generated XML is written.
702708
This is intended for internal use within PyMJCF; users should never need
703709
manually pass this argument.
710+
precision: (optional) Number of digits to output for floating point
711+
quantities.
704712
705713
Returns:
706714
An etree._Element object.
707715
"""
708716
prefix_root = prefix_root or self.namescope
709717
xml_element = etree.Element(self._spec.name)
710-
self._attributes_to_xml(xml_element, prefix_root, debug_context)
711-
self._children_to_xml(xml_element, prefix_root, debug_context)
718+
self._attributes_to_xml(xml_element, prefix_root, debug_context,
719+
precision=precision)
720+
self._children_to_xml(xml_element, prefix_root, debug_context,
721+
precision=precision)
712722
return xml_element
713723

714-
def _attributes_to_xml(self, xml_element, prefix_root, debug_context=None):
724+
def _attributes_to_xml(self, xml_element, prefix_root, debug_context=None,
725+
*, precision):
715726
del debug_context # Unused.
716727
for attribute_name, attribute in self._attributes.items():
717-
attribute_value = attribute.to_xml_string(prefix_root)
728+
attribute_value = attribute.to_xml_string(prefix_root,
729+
precision=precision)
718730
if attribute_name == self._spec.identifier and attribute_value is None:
719731
xml_element.set(attribute_name, self.full_identifier)
720732
elif attribute_value is None:
721733
continue
722734
else:
723735
xml_element.set(attribute_name, attribute_value)
724736

725-
def _children_to_xml(self, xml_element, prefix_root, debug_context=None):
737+
def _children_to_xml(self, xml_element, prefix_root, debug_context=None,
738+
*, precision):
726739
for child in self.all_children():
727-
child_xml = child.to_xml(prefix_root, debug_context)
740+
child_xml = child.to_xml(prefix_root, debug_context, precision=precision)
728741
if (child_xml.attrib or len(child_xml) # pylint: disable=g-explicit-length-test
729742
or child.spec.repeated or child.spec.on_demand):
730743
xml_element.append(child_xml)
@@ -735,7 +748,8 @@ def _children_to_xml(self, xml_element, prefix_root, debug_context=None):
735748
child_xml.insert(0, copy.deepcopy(debug_comment))
736749

737750
def to_xml_string(self, prefix_root=None,
738-
self_only=False, pretty_print=True, debug_context=None):
751+
self_only=False, pretty_print=True, debug_context=None,
752+
*, precision=constants.XML_DEFAULT_PRECISION):
739753
"""Generates an XML string corresponding to this MJCF element.
740754
741755
Args:
@@ -750,16 +764,19 @@ def to_xml_string(self, prefix_root=None,
750764
the debugging information associated with the generated XML is written.
751765
This is intended for internal use within PyMJCF; users should never need
752766
manually pass this argument.
767+
precision: (optional) Number of digits to output for floating point
768+
quantities.
753769
754770
Returns:
755771
A string.
756772
"""
757-
xml_element = self.to_xml(prefix_root, debug_context)
773+
xml_element = self.to_xml(prefix_root, debug_context, precision=precision)
758774
if self_only and len(xml_element) > 0: # pylint: disable=g-explicit-length-test
759775
etree.strip_elements(xml_element, '*')
760776
xml_element.text = '...'
761777
if (self_only and self._spec.identifier and
762-
not self._attributes[self._spec.identifier].to_xml_string(prefix_root)):
778+
not self._attributes[self._spec.identifier].to_xml_string(
779+
prefix_root, precision=precision)):
763780
del xml_element.attrib[self._spec.identifier]
764781
xml_string = util.to_native_string(
765782
etree.tostring(xml_element, pretty_print=pretty_print))
@@ -987,8 +1004,10 @@ def prefixed_identifier(self, prefix_root=None):
9871004
prefix = self.namescope.full_prefix(prefix_root)
9881005
return prefix + self._attachment.namescope.name + constants.PREFIX_SEPARATOR
9891006

990-
def to_xml(self, prefix_root=None, debug_context=None):
991-
xml_element = (super().to_xml(prefix_root, debug_context))
1007+
def to_xml(self, prefix_root=None, debug_context=None,
1008+
*, precision=constants.XML_DEFAULT_PRECISION):
1009+
xml_element = (super().to_xml(prefix_root, debug_context,
1010+
precision=precision))
9921011
xml_element.set('name', self.prefixed_identifier(prefix_root))
9931012
return xml_element
9941013

@@ -1013,8 +1032,10 @@ class _AttachmentFrameChild(_ElementImpl):
10131032
"""
10141033
__slots__ = []
10151034

1016-
def to_xml(self, prefix_root=None, debug_context=None):
1017-
xml_element = (super().to_xml(prefix_root, debug_context))
1035+
def to_xml(self, prefix_root=None, debug_context=None,
1036+
*, precision=constants.XML_DEFAULT_PRECISION):
1037+
xml_element = (super().to_xml(prefix_root, debug_context,
1038+
precision=precision))
10181039
if self.spec.namespace is not None:
10191040
if self.name:
10201041
name = (self._parent.prefixed_identifier(prefix_root) +
@@ -1051,14 +1072,17 @@ def _attach(self, other, exclude_worldbody=False, dry_run=False):
10511072
def all_children(self):
10521073
return [child for child in self._children]
10531074

1054-
def to_xml(self, prefix_root=None, debug_context=None):
1075+
def to_xml(self, prefix_root=None, debug_context=None,
1076+
*, precision=constants.XML_DEFAULT_PRECISION):
10551077
prefix_root = prefix_root or self.namescope
1056-
xml_element = (super().to_xml(prefix_root, debug_context))
1078+
xml_element = (super().to_xml(prefix_root, debug_context,
1079+
precision=precision))
10571080
if isinstance(self._parent, RootElement):
10581081
root_default = etree.Element(self._spec.name)
10591082
root_default.append(xml_element)
10601083
for attachment in self._attachments.values():
1061-
attachment_xml = attachment.to_xml(prefix_root, debug_context)
1084+
attachment_xml = attachment.to_xml(prefix_root, debug_context,
1085+
precision=precision)
10621086
for attachment_child_xml in attachment_xml:
10631087
root_default.append(attachment_child_xml)
10641088
xml_element = root_default
@@ -1082,12 +1106,13 @@ def _is_third_order_actuator(self, child):
10821106
else:
10831107
return False # No other actuator shortcuts have internal dynamics.
10841108

1085-
def _children_to_xml(self, xml_element, prefix_root, debug_context=None):
1109+
def _children_to_xml(self, xml_element, prefix_root, debug_context=None,
1110+
*, precision=constants.XML_DEFAULT_PRECISION):
10861111
second_order = []
10871112
third_order = []
10881113
debug_comments = {}
10891114
for child in self.all_children():
1090-
child_xml = child.to_xml(prefix_root, debug_context)
1115+
child_xml = child.to_xml(prefix_root, debug_context, precision=precision)
10911116
if (child_xml.attrib or len(child_xml) # pylint: disable=g-explicit-length-test
10921117
or child.spec.repeated or child.spec.on_demand):
10931118
if self._is_third_order_actuator(child):
@@ -1297,7 +1322,7 @@ def __getitem__(self, index):
12971322
return scoped_elements[index[(len(scope_name) + 1):]]
12981323
except KeyError:
12991324
# Re-raise so that the error shows the full, un-stripped index string
1300-
raise self._identifier_not_found_error(index)
1325+
raise self._identifier_not_found_error(index) # pylint: disable=raise-missing-from
13011326
elif isinstance(index, slice) or (isinstance(index, int) and index < 0):
13021327
return self._full_list()[index]
13031328
else:

0 commit comments

Comments
 (0)