Skip to content

Commit ae85a28

Browse files
sbohezcopybara-github
authored andcommitted
Add math module to variation package and make sequence variation evaluate items.
PiperOrigin-RevId: 495758194 Change-Id: I8fae6c2709e4a0d37c985863cde47ee0dea61057
1 parent 73f1607 commit ae85a28

3 files changed

Lines changed: 91 additions & 5 deletions

File tree

dm_control/composer/variation/deterministic.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818

1919
from dm_control.composer.variation import base
20+
from dm_control.composer.variation.variation_values import evaluate
2021

2122

2223
class Constant(base.Variation):
@@ -42,10 +43,12 @@ def __init__(self, values):
4243

4344
def __call__(self, initial_value=None, current_value=None, random_state=None):
4445
try:
45-
return next(self._iterator)
46+
return evaluate(next(self._iterator), initial_value=initial_value,
47+
current_value=current_value, random_state=random_state)
4648
except StopIteration:
4749
self._iterator = iter(self._values)
48-
return next(self._iterator)
50+
return evaluate(next(self._iterator), initial_value=initial_value,
51+
current_value=current_value, random_state=random_state)
4952

5053

5154
class Identity(base.Variation):

dm_control/composer/variation/distributions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717
import abc
1818
import functools
1919

20-
from dm_control.composer import variation
2120
from dm_control.composer.variation import base
21+
from dm_control.composer.variation.variation_values import evaluate
2222
import numpy as np
2323

2424

@@ -45,12 +45,12 @@ def __call__(self, initial_value=None, current_value=None, random_state=None):
4545
size = (
4646
None if self._single_sample or initial_value is None # pylint: disable=g-long-ternary
4747
else np.shape(initial_value))
48-
local_args = variation.evaluate(
48+
local_args = evaluate(
4949
self._args,
5050
initial_value=initial_value,
5151
current_value=current_value,
5252
random_state=random_state)
53-
local_kwargs = variation.evaluate(
53+
local_kwargs = evaluate(
5454
self._kwargs,
5555
initial_value=initial_value,
5656
current_value=current_value,
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright 2018 The dm_control Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ============================================================================
15+
16+
"""Math operations on variation objects."""
17+
18+
import abc
19+
20+
from dm_control.composer.variation import base
21+
from dm_control.composer.variation.variation_values import evaluate
22+
23+
import numpy as np
24+
25+
26+
class MathOp(base.Variation):
27+
"""Base MathOp class for applying math operations on variation objects.
28+
29+
Subclasses need to implement `_op`, which takes in a single value and applies
30+
the desired math operation. This operation gets applied to the result of the
31+
evaluated base variation object passed at construction. Structured variation
32+
objects are automatically traversed.
33+
"""
34+
35+
def __init__(self, *args, **kwargs):
36+
self._args = args
37+
self._kwargs = kwargs
38+
39+
def __call__(self, initial_value=None, current_value=None, random_state=None):
40+
local_args = evaluate(
41+
self._args,
42+
initial_value=initial_value,
43+
current_value=current_value,
44+
random_state=random_state)
45+
local_kwargs = evaluate(
46+
self._kwargs,
47+
initial_value=initial_value,
48+
current_value=current_value,
49+
random_state=random_state)
50+
return self._callable(*local_args, **local_kwargs)
51+
52+
@property
53+
@abc.abstractmethod
54+
def _callable(self):
55+
pass
56+
57+
58+
class Log(MathOp):
59+
60+
@property
61+
def _callable(self):
62+
return np.log
63+
64+
65+
class Max(MathOp):
66+
67+
@property
68+
def _callable(self):
69+
return np.max
70+
71+
72+
class Min(MathOp):
73+
74+
@property
75+
def _callable(self):
76+
return np.min
77+
78+
79+
class Norm(MathOp):
80+
81+
@property
82+
def _callable(self):
83+
return np.linalg.norm

0 commit comments

Comments
 (0)