Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions isaaclab_arena/environments/__init__.py

This file was deleted.

224 changes: 224 additions & 0 deletions isaaclab_arena/environments/arena_env_graph_spec.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any

import yaml

from isaaclab_arena.assets.object_base import ObjectType
from isaaclab_arena.environments.utils import (
as_dict,
assert_env_graph_references_exist,
assert_env_graph_universal_ids,
optional_dict,
optional_enum,
optional_str,
parse_list,
required_enum,
required_number_sequence,
required_str,
)


class ArenaEnvGraphNodeType(Enum):
EMBODIMENT = "embodiment"
BACKGROUND = "background"
OBJECT = "object"
OBJECT_REFERENCE = "objectReference"
LIGHTING = "lighting"


class ArenaEnvGraphSpatialConstraintType(Enum):
IS_ANCHOR = "is_anchor"
NEXT_TO = "next_to"
ON = "on"
AT_POSE = "at_pose" # through set_initial_pose()
AT_POSITION = "at_position" # through object relation solver: AtPosition
POSITION_LIMITS = "position_limits"
RANDOM_AROUND_SOLUTION = "random_around_solution"
ROTATE_AROUND_SOLUTION = "rotate_around_solution"
# TODO(xinjieyao, 2026-05-21): Support "in" in solver
IN = "in"


@dataclass
class ArenaEnvGraphNodeSpec:
"""Node in an environment graph.

Could be an object, an object reference, an embodiment, a background, etc.
"""

id: str
name: str # Name registered in the asset registry
type: ArenaEnvGraphNodeType
parent: str | None = None # Optional, only need for object references
prim_path: str | None = None # Optional, only need for object references
object_type: ObjectType | None = None # Optional, only need for type=object
params: dict[str, Any] = field(default_factory=dict)


@dataclass
class ArenaEnvGraphSpatialConstraintSpec:
"""Spatial constraint edge in an environment graph state spec.

It defines a relation between two nodes.
"""

id: str
type: ArenaEnvGraphSpatialConstraintType
parent: str
child: str | None = None # Optional, e.g. is_anchor constraint does not have a child
params: dict[str, Any] = field(default_factory=dict)


@dataclass
class ArenaEnvGraphTaskConstraintSpec:
"""Task-dependent constraint edge in an environment graph state spec."""

id: str
type: str
parent: str
child: str | None = None # Optional, could be a robot keeps gripper open or closed, or a single object
Comment on lines +83 to +86
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 ArenaEnvGraphTaskConstraintSpec.parent is typed as str (required, no default) but _parse_task_constraint assigns it via optional_str(), which returns str | None. The validation function confirms the design intent — assert_env_graph_references_exist guards if constraint.parent is not None: for task constraints, meaning None is a valid runtime value. Any downstream code (or static type checker) treating parent as a guaranteed str will fail at runtime or produce false positives. The annotation and default should match the actual parser contract.

Suggested change
id: str
type: str
parent: str
child: str | None = None # Optional, could be a robot keeps gripper open or closed, or a single object
id: str
type: str
parent: str | None = None # Optional, e.g. a gripper-state or single-object constraint has no parent
child: str | None = None # Optional, could be a robot keeps gripper open or closed, or a single object

params: dict[str, Any] = field(default_factory=dict)


@dataclass
class ArenaEnvGraphStateSpec:
"""Snapshot of the environment state in the graph.

Could be an initial, intermediate, or final state.
"""

id: str
name: str
spatial_constraints: list[ArenaEnvGraphSpatialConstraintSpec] = field(default_factory=list)
task_constraints: list[ArenaEnvGraphTaskConstraintSpec] = field(default_factory=list)


@dataclass
class ArenaEnvGraphTaskSpec:
"""Task entry in an environment graph."""

id: str
name: str
type: str # Task class name, could be a custom task class or a built-in task class
initial_state_spec_id: str
success_state_spec_id: str
task_args: dict[str, Any] = field(default_factory=dict)


@dataclass
class ArenaEnvGraphSpec:
"""Typed representation of an environment graph YAML file.
It defines the nodes, tasks, and state specs of the environment graph.
"""

env_name: str
nodes: list[ArenaEnvGraphNodeSpec] = field(default_factory=list)
tasks: list[ArenaEnvGraphTaskSpec] = field(default_factory=list)
state_specs: list[ArenaEnvGraphStateSpec] = field(default_factory=list)

@classmethod
def from_yaml(cls, path: str | Path) -> "ArenaEnvGraphSpec":
with Path(path).open("r", encoding="utf-8") as f:
return cls.from_dict(yaml.safe_load(f))

@classmethod
def from_dict(cls, data: dict[str, Any]) -> "ArenaEnvGraphSpec":
data = as_dict(data, "Env graph spec")
nodes = parse_list(data, "nodes", _parse_node)
tasks = parse_list(data, "tasks", _parse_task)
state_specs = parse_list(data, "state_specs", _parse_state_spec)

assert_env_graph_universal_ids(nodes, tasks, state_specs)
assert_env_graph_references_exist(nodes, tasks, state_specs)

return cls(
env_name=required_str(data, "env_name"),
nodes=nodes,
tasks=tasks,
state_specs=state_specs,
)

@property
def nodes_by_id(self) -> dict[str, ArenaEnvGraphNodeSpec]:
return {node.id: node for node in self.nodes}

@property
def tasks_by_id(self) -> dict[str, ArenaEnvGraphTaskSpec]:
return {task.id: task for task in self.tasks}

@property
def state_specs_by_id(self) -> dict[str, ArenaEnvGraphStateSpec]:
return {state_spec.id: state_spec for state_spec in self.state_specs}


def _parse_node(data: Any) -> ArenaEnvGraphNodeSpec:
data = as_dict(data, "Node spec")
return ArenaEnvGraphNodeSpec(
id=required_str(data, "id"),
name=required_str(data, "name"),
type=required_enum(data, "type", ArenaEnvGraphNodeType),
parent=optional_str(data, "parent"),
prim_path=optional_str(data, "prim_path"),
object_type=optional_enum(data, "object_type", ObjectType),
params=optional_dict(data, "params"),
)


def _parse_spatial_constraint(data: Any) -> ArenaEnvGraphSpatialConstraintSpec:
data = as_dict(data, "Spatial constraint spec")
constraint_type = required_enum(data, "type", ArenaEnvGraphSpatialConstraintType)
params = optional_dict(data, "params")
if constraint_type == ArenaEnvGraphSpatialConstraintType.AT_POSE:
params["position_xyz"] = required_number_sequence(params, "position_xyz", 3)
params["rotation_xyzw"] = required_number_sequence(params, "rotation_xyzw", 4)

return ArenaEnvGraphSpatialConstraintSpec(
id=required_str(data, "id"),
type=constraint_type,
parent=required_str(data, "parent"),
child=optional_str(data, "child"),
params=params,
)


def _parse_task_constraint(data: Any) -> ArenaEnvGraphTaskConstraintSpec:
data = as_dict(data, "Task constraint spec")
return ArenaEnvGraphTaskConstraintSpec(
id=required_str(data, "id"),
type=required_str(data, "type"),
parent=optional_str(data, "parent"),
child=optional_str(data, "child"),
params=optional_dict(data, "params"),
)


def _parse_state_spec(data: Any) -> ArenaEnvGraphStateSpec:
data = as_dict(data, "State spec")
assert "edges" not in data, "State spec must define spatial_constraints and task_constraints directly"
return ArenaEnvGraphStateSpec(
id=required_str(data, "id"),
name=required_str(data, "name"),
spatial_constraints=parse_list(data, "spatial_constraints", _parse_spatial_constraint),
task_constraints=parse_list(data, "task_constraints", _parse_task_constraint),
)


def _parse_task(data: Any) -> ArenaEnvGraphTaskSpec:
data = as_dict(data, "Task spec")
for old_key in ("state_specs", "initial_state_spec", "success_state_spec"):
assert old_key not in data, "Task spec must use initial_state_spec_id and success_state_spec_id"
return ArenaEnvGraphTaskSpec(
id=required_str(data, "id"),
name=required_str(data, "name"),
type=required_str(data, "type"),
initial_state_spec_id=required_str(data, "initial_state_spec_id"),
success_state_spec_id=required_str(data, "success_state_spec_id"),
task_args=optional_dict(data, "task_args"),
)
130 changes: 130 additions & 0 deletions isaaclab_arena/environments/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright (c) 2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Callable
from enum import Enum
from numbers import Real
from typing import Any


def as_dict(data: Any, spec_name: str) -> dict[str, Any]:
assert isinstance(data, dict), f"{spec_name} must be a dict, got {type(data).__name__}"
return data


def parse_list(data: dict[str, Any], key: str, parser: Callable[[Any], Any]) -> list[Any]:
values = data.get(key, [])
assert isinstance(values, list), f"Field '{key}' must be a list"
return [parser(value) for value in values]


def required_str(data: dict[str, Any], key: str) -> str:
value = data.get(key)
assert isinstance(value, str) and value, f"Missing required string field '{key}'"
return value


def optional_str(data: dict[str, Any], key: str) -> str | None:
value = data.get(key)
assert value is None or isinstance(value, str), f"Optional field '{key}' must be a string when set"
return value


def optional_dict(data: dict[str, Any], key: str) -> dict[str, Any]:
value = data.get(key, {})
assert value is None or isinstance(value, dict), f"Optional field '{key}' must be a dict when set"
return dict(value or {})


def required_number_sequence(data: dict[str, Any], key: str, length: int) -> tuple[float, ...]:
value = data.get(key)
assert isinstance(value, (list, tuple)), f"Missing required numeric sequence field '{key}'"
assert len(value) == length, f"Field '{key}' must contain {length} numbers"
assert all(isinstance(item, Real) and not isinstance(item, bool) for item in value), (
f"Field '{key}' must contain only numbers"
)
return tuple(float(item) for item in value)


def required_enum(data: dict[str, Any], key: str, enum_type: type[Enum]) -> Enum:
value = data.get(key)
assert value is not None, f"Missing required field '{key}'"
parsed = parse_enum(value, key, enum_type)
assert parsed is not None
return parsed


def optional_enum(data: dict[str, Any], key: str, enum_type: type[Enum]) -> Enum | None:
return parse_enum(data.get(key), key, enum_type)


def parse_enum(value: Any, key: str, enum_type: type[Enum]) -> Enum | None:
if value is None or isinstance(value, enum_type):
return value
assert isinstance(value, str), f"Field '{key}' must be a string when set"
try:
return enum_type(value)
except ValueError:
valid_values = [enum_value.value for enum_value in enum_type]
raise AssertionError(f"Unknown {key} '{value}'. Expected one of {valid_values}") from None


def assert_env_graph_universal_ids(nodes: list[Any], tasks: list[Any], state_specs: list[Any]) -> None:
id_locations: dict[str, list[str]] = {}
for node in nodes:
_add_id_location(id_locations, node.id, f"node '{node.id}'")
for task in tasks:
_add_id_location(id_locations, task.id, f"task '{task.id}'")
for state_spec in state_specs:
_add_id_location(id_locations, state_spec.id, f"state spec '{state_spec.id}'")
for constraint in state_spec.spatial_constraints:
_add_id_location(id_locations, constraint.id, f"spatial constraint '{constraint.id}'")
for constraint in state_spec.task_constraints:
_add_id_location(id_locations, constraint.id, f"task constraint '{constraint.id}'")

duplicates = {spec_id: locations for spec_id, locations in id_locations.items() if len(locations) > 1}
assert not duplicates, f"Duplicate env graph ids found: {duplicates}"


def assert_env_graph_references_exist(nodes: list[Any], tasks: list[Any], state_specs: list[Any]) -> None:
node_ids = {node.id for node in nodes}
state_spec_ids = {state_spec.id for state_spec in state_specs}

for node in nodes:
if node.parent is not None:
assert node.parent in node_ids, f"Node '{node.id}' references unknown parent '{node.parent}'"

for task in tasks:
for label, state_spec_id in (
("initial_state_spec_id", task.initial_state_spec_id),
("success_state_spec_id", task.success_state_spec_id),
):
assert (
state_spec_id in state_spec_ids
), f"Task '{task.id}' references unknown state spec '{state_spec_id}' for '{label}'"

for state_spec in state_specs:
for constraint in state_spec.spatial_constraints:
assert (
constraint.parent in node_ids
), f"Constraint '{constraint.id}' references unknown parent node '{constraint.parent}'"
if constraint.child is not None:
assert (
constraint.child in node_ids
), f"Constraint '{constraint.id}' references unknown child node '{constraint.child}'"

for constraint in state_spec.task_constraints:
if constraint.parent is not None:
assert (
constraint.parent in node_ids
), f"Constraint '{constraint.id}' references unknown parent node '{constraint.parent}'"
if constraint.child is not None:
assert (
constraint.child in node_ids
), f"Constraint '{constraint.id}' references unknown child node '{constraint.child}'"


def _add_id_location(id_locations: dict[str, list[str]], spec_id: str, location: str) -> None:
id_locations.setdefault(spec_id, []).append(location)
Loading
Loading