-
Notifications
You must be signed in to change notification settings - Fork 49
Add EnvGraphSpec class & its yaml parser #690
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
xyao-nv
wants to merge
7
commits into
main
Choose a base branch
from
xyao/dev/env_yaml
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,224 @@ | ||
| # Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). | ||
| # All rights reserved. | ||
| # | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from dataclasses import dataclass, field | ||
| from enum import Enum | ||
| from pathlib import Path | ||
| from typing import Any | ||
|
|
||
| import yaml | ||
|
|
||
| from isaaclab_arena.assets.object_base import ObjectType | ||
| from isaaclab_arena.environments.utils import ( | ||
| as_dict, | ||
| assert_env_graph_references_exist, | ||
| assert_env_graph_universal_ids, | ||
| optional_dict, | ||
| optional_enum, | ||
| optional_str, | ||
| parse_list, | ||
| required_enum, | ||
| required_number_sequence, | ||
| required_str, | ||
| ) | ||
|
|
||
|
|
||
| class ArenaEnvGraphNodeType(Enum): | ||
| EMBODIMENT = "embodiment" | ||
| BACKGROUND = "background" | ||
| OBJECT = "object" | ||
| OBJECT_REFERENCE = "objectReference" | ||
| LIGHTING = "lighting" | ||
|
|
||
|
|
||
| class ArenaEnvGraphSpatialConstraintType(Enum): | ||
| IS_ANCHOR = "is_anchor" | ||
| NEXT_TO = "next_to" | ||
| ON = "on" | ||
| AT_POSE = "at_pose" # through set_initial_pose() | ||
| AT_POSITION = "at_position" # through object relation solver: AtPosition | ||
| POSITION_LIMITS = "position_limits" | ||
| RANDOM_AROUND_SOLUTION = "random_around_solution" | ||
| ROTATE_AROUND_SOLUTION = "rotate_around_solution" | ||
| # TODO(xinjieyao, 2026-05-21): Support "in" in solver | ||
| IN = "in" | ||
|
|
||
|
|
||
| @dataclass | ||
| class ArenaEnvGraphNodeSpec: | ||
| """Node in an environment graph. | ||
|
|
||
| Could be an object, an object reference, an embodiment, a background, etc. | ||
| """ | ||
|
|
||
| id: str | ||
| name: str # Name registered in the asset registry | ||
| type: ArenaEnvGraphNodeType | ||
| parent: str | None = None # Optional, only need for object references | ||
| prim_path: str | None = None # Optional, only need for object references | ||
| object_type: ObjectType | None = None # Optional, only need for type=object | ||
| params: dict[str, Any] = field(default_factory=dict) | ||
|
|
||
|
|
||
| @dataclass | ||
| class ArenaEnvGraphSpatialConstraintSpec: | ||
| """Spatial constraint edge in an environment graph state spec. | ||
|
|
||
| It defines a relation between two nodes. | ||
| """ | ||
|
|
||
| id: str | ||
| type: ArenaEnvGraphSpatialConstraintType | ||
| parent: str | ||
| child: str | None = None # Optional, e.g. is_anchor constraint does not have a child | ||
| params: dict[str, Any] = field(default_factory=dict) | ||
|
|
||
|
|
||
| @dataclass | ||
| class ArenaEnvGraphTaskConstraintSpec: | ||
| """Task-dependent constraint edge in an environment graph state spec.""" | ||
|
|
||
| id: str | ||
| type: str | ||
| parent: str | ||
| child: str | None = None # Optional, could be a robot keeps gripper open or closed, or a single object | ||
| params: dict[str, Any] = field(default_factory=dict) | ||
|
|
||
|
|
||
| @dataclass | ||
| class ArenaEnvGraphStateSpec: | ||
| """Snapshot of the environment state in the graph. | ||
|
|
||
| Could be an initial, intermediate, or final state. | ||
| """ | ||
|
|
||
| id: str | ||
| name: str | ||
| spatial_constraints: list[ArenaEnvGraphSpatialConstraintSpec] = field(default_factory=list) | ||
| task_constraints: list[ArenaEnvGraphTaskConstraintSpec] = field(default_factory=list) | ||
|
|
||
|
|
||
| @dataclass | ||
| class ArenaEnvGraphTaskSpec: | ||
| """Task entry in an environment graph.""" | ||
|
|
||
| id: str | ||
| name: str | ||
| type: str # Task class name, could be a custom task class or a built-in task class | ||
| initial_state_spec_id: str | ||
| success_state_spec_id: str | ||
| task_args: dict[str, Any] = field(default_factory=dict) | ||
|
|
||
|
|
||
| @dataclass | ||
| class ArenaEnvGraphSpec: | ||
| """Typed representation of an environment graph YAML file. | ||
| It defines the nodes, tasks, and state specs of the environment graph. | ||
| """ | ||
|
|
||
| env_name: str | ||
| nodes: list[ArenaEnvGraphNodeSpec] = field(default_factory=list) | ||
| tasks: list[ArenaEnvGraphTaskSpec] = field(default_factory=list) | ||
| state_specs: list[ArenaEnvGraphStateSpec] = field(default_factory=list) | ||
|
|
||
| @classmethod | ||
| def from_yaml(cls, path: str | Path) -> "ArenaEnvGraphSpec": | ||
| with Path(path).open("r", encoding="utf-8") as f: | ||
| return cls.from_dict(yaml.safe_load(f)) | ||
|
|
||
| @classmethod | ||
| def from_dict(cls, data: dict[str, Any]) -> "ArenaEnvGraphSpec": | ||
| data = as_dict(data, "Env graph spec") | ||
| nodes = parse_list(data, "nodes", _parse_node) | ||
| tasks = parse_list(data, "tasks", _parse_task) | ||
| state_specs = parse_list(data, "state_specs", _parse_state_spec) | ||
|
|
||
| assert_env_graph_universal_ids(nodes, tasks, state_specs) | ||
| assert_env_graph_references_exist(nodes, tasks, state_specs) | ||
|
|
||
| return cls( | ||
| env_name=required_str(data, "env_name"), | ||
| nodes=nodes, | ||
| tasks=tasks, | ||
| state_specs=state_specs, | ||
| ) | ||
|
|
||
| @property | ||
| def nodes_by_id(self) -> dict[str, ArenaEnvGraphNodeSpec]: | ||
| return {node.id: node for node in self.nodes} | ||
|
|
||
| @property | ||
| def tasks_by_id(self) -> dict[str, ArenaEnvGraphTaskSpec]: | ||
| return {task.id: task for task in self.tasks} | ||
|
|
||
| @property | ||
| def state_specs_by_id(self) -> dict[str, ArenaEnvGraphStateSpec]: | ||
| return {state_spec.id: state_spec for state_spec in self.state_specs} | ||
|
|
||
|
|
||
| def _parse_node(data: Any) -> ArenaEnvGraphNodeSpec: | ||
| data = as_dict(data, "Node spec") | ||
| return ArenaEnvGraphNodeSpec( | ||
| id=required_str(data, "id"), | ||
| name=required_str(data, "name"), | ||
| type=required_enum(data, "type", ArenaEnvGraphNodeType), | ||
| parent=optional_str(data, "parent"), | ||
| prim_path=optional_str(data, "prim_path"), | ||
| object_type=optional_enum(data, "object_type", ObjectType), | ||
| params=optional_dict(data, "params"), | ||
| ) | ||
|
|
||
|
|
||
| def _parse_spatial_constraint(data: Any) -> ArenaEnvGraphSpatialConstraintSpec: | ||
| data = as_dict(data, "Spatial constraint spec") | ||
| constraint_type = required_enum(data, "type", ArenaEnvGraphSpatialConstraintType) | ||
| params = optional_dict(data, "params") | ||
| if constraint_type == ArenaEnvGraphSpatialConstraintType.AT_POSE: | ||
| params["position_xyz"] = required_number_sequence(params, "position_xyz", 3) | ||
| params["rotation_xyzw"] = required_number_sequence(params, "rotation_xyzw", 4) | ||
|
|
||
| return ArenaEnvGraphSpatialConstraintSpec( | ||
| id=required_str(data, "id"), | ||
| type=constraint_type, | ||
| parent=required_str(data, "parent"), | ||
| child=optional_str(data, "child"), | ||
| params=params, | ||
| ) | ||
|
|
||
|
|
||
| def _parse_task_constraint(data: Any) -> ArenaEnvGraphTaskConstraintSpec: | ||
| data = as_dict(data, "Task constraint spec") | ||
| return ArenaEnvGraphTaskConstraintSpec( | ||
| id=required_str(data, "id"), | ||
| type=required_str(data, "type"), | ||
| parent=optional_str(data, "parent"), | ||
| child=optional_str(data, "child"), | ||
| params=optional_dict(data, "params"), | ||
| ) | ||
|
|
||
|
|
||
| def _parse_state_spec(data: Any) -> ArenaEnvGraphStateSpec: | ||
| data = as_dict(data, "State spec") | ||
| assert "edges" not in data, "State spec must define spatial_constraints and task_constraints directly" | ||
| return ArenaEnvGraphStateSpec( | ||
| id=required_str(data, "id"), | ||
| name=required_str(data, "name"), | ||
| spatial_constraints=parse_list(data, "spatial_constraints", _parse_spatial_constraint), | ||
| task_constraints=parse_list(data, "task_constraints", _parse_task_constraint), | ||
| ) | ||
|
|
||
|
|
||
| def _parse_task(data: Any) -> ArenaEnvGraphTaskSpec: | ||
| data = as_dict(data, "Task spec") | ||
| for old_key in ("state_specs", "initial_state_spec", "success_state_spec"): | ||
| assert old_key not in data, "Task spec must use initial_state_spec_id and success_state_spec_id" | ||
| return ArenaEnvGraphTaskSpec( | ||
| id=required_str(data, "id"), | ||
| name=required_str(data, "name"), | ||
| type=required_str(data, "type"), | ||
| initial_state_spec_id=required_str(data, "initial_state_spec_id"), | ||
| success_state_spec_id=required_str(data, "success_state_spec_id"), | ||
| task_args=optional_dict(data, "task_args"), | ||
| ) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,130 @@ | ||
| # Copyright (c) 2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). | ||
| # All rights reserved. | ||
| # | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from collections.abc import Callable | ||
| from enum import Enum | ||
| from numbers import Real | ||
| from typing import Any | ||
|
|
||
|
|
||
| def as_dict(data: Any, spec_name: str) -> dict[str, Any]: | ||
| assert isinstance(data, dict), f"{spec_name} must be a dict, got {type(data).__name__}" | ||
| return data | ||
|
|
||
|
|
||
| def parse_list(data: dict[str, Any], key: str, parser: Callable[[Any], Any]) -> list[Any]: | ||
| values = data.get(key, []) | ||
| assert isinstance(values, list), f"Field '{key}' must be a list" | ||
| return [parser(value) for value in values] | ||
|
|
||
|
|
||
| def required_str(data: dict[str, Any], key: str) -> str: | ||
| value = data.get(key) | ||
| assert isinstance(value, str) and value, f"Missing required string field '{key}'" | ||
| return value | ||
|
|
||
|
|
||
| def optional_str(data: dict[str, Any], key: str) -> str | None: | ||
| value = data.get(key) | ||
| assert value is None or isinstance(value, str), f"Optional field '{key}' must be a string when set" | ||
| return value | ||
|
|
||
|
|
||
| def optional_dict(data: dict[str, Any], key: str) -> dict[str, Any]: | ||
| value = data.get(key, {}) | ||
| assert value is None or isinstance(value, dict), f"Optional field '{key}' must be a dict when set" | ||
| return dict(value or {}) | ||
|
|
||
|
|
||
| def required_number_sequence(data: dict[str, Any], key: str, length: int) -> tuple[float, ...]: | ||
| value = data.get(key) | ||
| assert isinstance(value, (list, tuple)), f"Missing required numeric sequence field '{key}'" | ||
| assert len(value) == length, f"Field '{key}' must contain {length} numbers" | ||
| assert all(isinstance(item, Real) and not isinstance(item, bool) for item in value), ( | ||
| f"Field '{key}' must contain only numbers" | ||
| ) | ||
| return tuple(float(item) for item in value) | ||
|
|
||
|
|
||
| def required_enum(data: dict[str, Any], key: str, enum_type: type[Enum]) -> Enum: | ||
| value = data.get(key) | ||
| assert value is not None, f"Missing required field '{key}'" | ||
| parsed = parse_enum(value, key, enum_type) | ||
| assert parsed is not None | ||
| return parsed | ||
|
|
||
|
|
||
| def optional_enum(data: dict[str, Any], key: str, enum_type: type[Enum]) -> Enum | None: | ||
| return parse_enum(data.get(key), key, enum_type) | ||
|
|
||
|
|
||
| def parse_enum(value: Any, key: str, enum_type: type[Enum]) -> Enum | None: | ||
| if value is None or isinstance(value, enum_type): | ||
| return value | ||
| assert isinstance(value, str), f"Field '{key}' must be a string when set" | ||
| try: | ||
| return enum_type(value) | ||
| except ValueError: | ||
| valid_values = [enum_value.value for enum_value in enum_type] | ||
| raise AssertionError(f"Unknown {key} '{value}'. Expected one of {valid_values}") from None | ||
|
|
||
|
|
||
| def assert_env_graph_universal_ids(nodes: list[Any], tasks: list[Any], state_specs: list[Any]) -> None: | ||
| id_locations: dict[str, list[str]] = {} | ||
| for node in nodes: | ||
| _add_id_location(id_locations, node.id, f"node '{node.id}'") | ||
| for task in tasks: | ||
| _add_id_location(id_locations, task.id, f"task '{task.id}'") | ||
| for state_spec in state_specs: | ||
| _add_id_location(id_locations, state_spec.id, f"state spec '{state_spec.id}'") | ||
| for constraint in state_spec.spatial_constraints: | ||
| _add_id_location(id_locations, constraint.id, f"spatial constraint '{constraint.id}'") | ||
| for constraint in state_spec.task_constraints: | ||
| _add_id_location(id_locations, constraint.id, f"task constraint '{constraint.id}'") | ||
|
|
||
| duplicates = {spec_id: locations for spec_id, locations in id_locations.items() if len(locations) > 1} | ||
| assert not duplicates, f"Duplicate env graph ids found: {duplicates}" | ||
|
|
||
|
|
||
| def assert_env_graph_references_exist(nodes: list[Any], tasks: list[Any], state_specs: list[Any]) -> None: | ||
| node_ids = {node.id for node in nodes} | ||
| state_spec_ids = {state_spec.id for state_spec in state_specs} | ||
|
|
||
| for node in nodes: | ||
| if node.parent is not None: | ||
| assert node.parent in node_ids, f"Node '{node.id}' references unknown parent '{node.parent}'" | ||
|
|
||
| for task in tasks: | ||
| for label, state_spec_id in ( | ||
| ("initial_state_spec_id", task.initial_state_spec_id), | ||
| ("success_state_spec_id", task.success_state_spec_id), | ||
| ): | ||
| assert ( | ||
| state_spec_id in state_spec_ids | ||
| ), f"Task '{task.id}' references unknown state spec '{state_spec_id}' for '{label}'" | ||
|
|
||
| for state_spec in state_specs: | ||
| for constraint in state_spec.spatial_constraints: | ||
| assert ( | ||
| constraint.parent in node_ids | ||
| ), f"Constraint '{constraint.id}' references unknown parent node '{constraint.parent}'" | ||
| if constraint.child is not None: | ||
| assert ( | ||
| constraint.child in node_ids | ||
| ), f"Constraint '{constraint.id}' references unknown child node '{constraint.child}'" | ||
|
|
||
| for constraint in state_spec.task_constraints: | ||
| if constraint.parent is not None: | ||
| assert ( | ||
| constraint.parent in node_ids | ||
| ), f"Constraint '{constraint.id}' references unknown parent node '{constraint.parent}'" | ||
| if constraint.child is not None: | ||
| assert ( | ||
| constraint.child in node_ids | ||
| ), f"Constraint '{constraint.id}' references unknown child node '{constraint.child}'" | ||
|
|
||
|
|
||
| def _add_id_location(id_locations: dict[str, list[str]], spec_id: str, location: str) -> None: | ||
| id_locations.setdefault(spec_id, []).append(location) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ArenaEnvGraphTaskConstraintSpec.parentis typed asstr(required, no default) but_parse_task_constraintassigns it viaoptional_str(), which returnsstr | None. The validation function confirms the design intent —assert_env_graph_references_existguardsif constraint.parent is not None:for task constraints, meaningNoneis a valid runtime value. Any downstream code (or static type checker) treatingparentas a guaranteedstrwill fail at runtime or produce false positives. The annotation and default should match the actual parser contract.