-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy path scene.py
More file actions
83 lines (73 loc) · 2.64 KB
/
scene.py
File metadata and controls
83 lines (73 loc) · 2.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from typing import Dict, Optional

import torch

from .rendering import NVDiffRastRenderer
from .robot import Robot
from .structs import VirtualCamera
class RobotScene:
    """A robot mesh observed by a set of named virtual cameras through one renderer.

    The robot, every camera and the renderer must report the same ``device``;
    this is checked once at construction time.
    """

    __slots__ = [
        "_cameras",
        "_robot",
        "_renderer",
    ]

    def __init__(
        self,
        cameras: Dict[str, VirtualCamera],
        robot: Robot,  # TODO: ideally this would be any combinations of TorchMeshContainers, i.e. Dict[str, TorchMeshContainer] (future work: RobotScene -> Scene, robot -> objects)
        renderer: NVDiffRastRenderer,
    ) -> None:
        """Store the scene components and verify they share a device.

        Args:
            cameras: Mapping from camera name to virtual camera.
            robot: Robot whose ``configured_vertices`` / ``faces`` are rendered.
            renderer: Renderer used by :meth:`observe_from`.

        Raises:
            ValueError: If the robot, the renderer, or any camera is on a
                different device than the others.
        """
        self._cameras = cameras
        self._robot = robot
        self._renderer = renderer
        self._verify_devices()

    def _verify_devices(self) -> None:
        """Raise ``ValueError`` unless robot, renderer and all cameras share a device."""
        robot_device = self._robot.device
        renderer_device = self._renderer.device
        # Check robot vs. renderer independently of the cameras so that a
        # mismatch is still caught when the scene has no cameras at all.
        if robot_device != renderer_device:
            raise ValueError(
                "All devices must be the same. Got:\n"
                f"Robot on: {robot_device}\n"
                f"Renderer on: {renderer_device}"
            )
        for camera_name, camera in self._cameras.items():
            if camera.device != robot_device:
                raise ValueError(
                    "All devices must be the same. Got:\n"
                    f"Camera '{camera_name}' on: {camera.device}\n"
                    f"Robot on: {robot_device}\n"
                    f"Renderer on: {renderer_device}"
                )

    def _get_camera(self, camera_name: str) -> VirtualCamera:
        """Return the camera called *camera_name*, with a helpful ``KeyError`` if absent."""
        try:
            return self._cameras[camera_name]
        except KeyError:
            raise KeyError(
                f"Unknown camera '{camera_name}'. Available cameras: {list(self._cameras)}"
            ) from None

    def observe_from(
        self,
        camera_name: str,
        reference_transform: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Render the robot as seen from the camera *camera_name*.

        The camera pose is ``reference_transform @ extrinsics @ ht_optical``.
        Its inverse, followed by the camera's perspective projection, is
        applied to the robot's configured vertices, and the result is passed
        to the renderer's ``constant_color``.

        Args:
            camera_name: Key of the camera in :attr:`cameras`.
            reference_transform: Transform prepended to the camera extrinsics.
                Defaults to a 4x4 identity with the extrinsics' dtype/device.

        Returns:
            Whatever ``renderer.constant_color`` returns for the projected
            vertices, the robot faces and the camera resolution.

        Raises:
            KeyError: If no camera called *camera_name* exists.
        """
        camera = self._get_camera(camera_name)
        if reference_transform is None:
            reference_transform = torch.eye(
                4,
                dtype=camera.extrinsics.dtype,
                device=camera.extrinsics.device,
            )

        camera_pose = reference_transform @ (camera.extrinsics @ camera.ht_optical)
        # Vertices are multiplied on the left of the transposed matrices, i.e.
        # v @ (P @ inv(pose))^T, which treats each vertex as a row vector.
        # NOTE(review): assumes configured_vertices is homogeneous (..., N, 4) -- confirm.
        view_projection = torch.matmul(
            torch.linalg.inv(camera_pose).transpose(-1, -2),
            camera.perspective_projection.transpose(-1, -2),
        )
        # matmul allocates a fresh output, so the robot's vertices need no defensive clone.
        observed_vertices = torch.matmul(self._robot.configured_vertices, view_projection)

        return self._renderer.constant_color(
            observed_vertices,
            self._robot.faces,
            camera.resolution,
        )

    @property
    def cameras(self) -> Dict[str, VirtualCamera]:
        """Mapping from camera name to virtual camera."""
        return self._cameras

    @property
    def robot(self) -> Robot:
        """The robot rendered by this scene."""
        return self._robot

    @property
    def renderer(self) -> NVDiffRastRenderer:
        """The renderer used by :meth:`observe_from`."""
        return self._renderer