Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,51 @@ configuration are needed.
NumPy operations cannot be traced by LEAPP.


Observation-Term Input Boundaries
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

By default, the manager-based export path discovers policy inputs by tracing observation terms
through Isaac Lab proxies. This works well for standard observations that read annotated simulator
state, such as ``robot.data.joint_pos`` or ``robot.data.joint_vel``. In those cases, the exported
LEAPP input is the lower-level simulator state and any PyTorch preprocessing in the observation
term is included in the exported graph.

Some deployment workflows should expose a higher-level observation term instead of the simulator
or task-internal tensors used to compute it. This is useful when the deployment system already
produces the policy-facing observation value directly, or when the term depends on bookkeeping that
should not become part of the exported runtime interface. In that case, mark the observation term
with ``leapp_observation_input``:

.. code-block:: python

from isaaclab.utils.leapp import XYZ_ELEMENT_NAMES, leapp_observation_input


@leapp_observation_input(kind="state/body/position", element_names=[XYZ_ELEMENT_NAMES])
class object_position_w(ManagerTermBase):
...

During export, Isaac Lab computes the configured observation term, applies deterministic
post-processing such as modifiers, clipping, and scaling, disables observation noise, and registers
that term value as the LEAPP input boundary. The generated metadata uses an
``observation:{group}:{term}`` connection, for example:

.. code-block:: text

observation:policy:gear_shaft_pos
observation:policy:gear_shaft_quat

When running the exported policy with :class:`~envs.LeappDeploymentEnv`, Isaac Lab constructs an
``ObservationManager`` only when the LEAPP YAML declares ``observation:*`` inputs and computes the
named terms from the task configuration. External runtimes should provide the same term values with
the same ordering, units, frames, and shape used by the policy observation group.

For example, a gear assembly deployment may provide shaft pose from perception or calibration,
while the training environment computes it from gear type state, fixture offsets, and environment
origins. Marking ``gear_shaft_pos_w`` and ``gear_shaft_quat_w`` with
``leapp_observation_input`` exports those final shaft-pose terms as the policy inputs.


Verifying an Export
-------------------

Expand Down
107 changes: 95 additions & 12 deletions source/isaaclab/isaaclab/envs/leapp_deployment_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

"""Deployment environment that runs LEAPP-exported policies in simulation.

This environment bypasses all Isaac Lab managers (observation, action, reward, etc.)
and instead wires scene entity data properties and ``CommandManager`` outputs directly
to a LEAPP ``InferenceManager``, then writes the model outputs back to the
corresponding scene entities. All I/O resolution is driven by the
This environment bypasses most Isaac Lab managers (action, reward, etc.) and instead
wires scene entity data properties, selected observation terms, and ``CommandManager``
outputs directly to a LEAPP ``InferenceManager``, then writes the model outputs back
to the corresponding scene entities. All I/O resolution is driven by the
``isaaclab_connection`` field in the LEAPP YAML.
"""

Expand All @@ -27,7 +27,7 @@
except ImportError as e:
raise ImportError("LEAPP package is required for policy deployment testing. Install with: pip install leapp") from e

from isaaclab.managers import CommandManager, EventManager
from isaaclab.managers import CommandManager, EventManager, ObservationManager
from isaaclab.scene import InteractiveScene
from isaaclab.sim import SimulationContext
from isaaclab.sim.utils.stage import use_stage
Expand Down Expand Up @@ -59,6 +59,14 @@ class CommandInputSpec:
command_term_name: str


@dataclass
class ObservationInputSpec:
"""Read a named observation term from ``ObservationManager``."""

group_name: str
term_name: str


@dataclass
class WriteOutputSpec:
"""Write a tensor to a scene entity method, optionally indexed by joint."""
Expand Down Expand Up @@ -134,6 +142,18 @@ def _first_param_name(method: Any) -> str:
return params[0].name


def _leapp_desc_has_connection(leapp_desc: dict[str, Any], connection_type: str) -> bool:
"""Return whether a LEAPP description contains an I/O connection of *connection_type*."""
prefix = f"{connection_type}:"
for model_desc in leapp_desc.get("models", {}).values():
for io_section in ("inputs", "outputs"):
for tensor_desc in model_desc.get(io_section, []):
connection = tensor_desc.get("isaaclab_connection")
if isinstance(connection, str) and connection.startswith(prefix):
return True
return False


# ══════════════════════════════════════════════════════════════════
# LeappDeploymentEnv
# ══════════════════════════════════════════════════════════════════
Expand All @@ -153,10 +173,11 @@ class LeappDeploymentEnv:

- ``state:{entity}:{property}`` -- read ``scene[entity].data.{property}``
- ``command:{name}`` -- read ``command_manager.get_command(name)``
- ``observation:{group}:{term}`` -- compute a named observation term
- ``write:{entity}:{method}`` -- call ``scene[entity].{method}(tensor, ...)``

No observation, action, reward, termination, or curriculum managers are used.
The LEAPP model already contains all pre/post-processing.
No action, reward, termination, or curriculum managers are used. The LEAPP model
already contains all action pre/post-processing.
"""

def __init__(self, cfg: Any, leapp_yaml_path: str):
Expand Down Expand Up @@ -197,6 +218,10 @@ def __init__(self, cfg: Any, leapp_yaml_path: str):
else:
self.viewport_camera_controller = None

# ── Parse YAML once before constructing optional managers ─
with open(leapp_yaml_path) as f:
self._leapp_desc = yaml.safe_load(f)

# ── EventManager (optional, for resets) ───────────────────
self.event_manager: EventManager | None = None
if hasattr(cfg, "events") and cfg.events is not None:
Expand All @@ -207,13 +232,21 @@ def __init__(self, cfg: Any, leapp_yaml_path: str):
if hasattr(cfg, "commands") and cfg.commands is not None:
self.command_manager = CommandManager(cfg.commands, cast(Any, self))

# ── ObservationManager (optional, for observation/* inputs) ─
self.observation_manager: ObservationManager | None = None
if _leapp_desc_has_connection(self._leapp_desc, "observation"):
if not hasattr(cfg, "observations") or cfg.observations is None:
raise RuntimeError(
"LEAPP YAML declares observation inputs but no ObservationManager configuration is available "
"(cfg.observations is None)."
)
self.observation_manager = ObservationManager(cfg.observations, cast(Any, self))

# ── LEAPP InferenceManager ────────────────────────────────
self.inference = InferenceManager(leapp_yaml_path)

# ── Parse YAML and resolve I/O mappings ───────────────────
with open(leapp_yaml_path) as f:
self._leapp_desc = yaml.safe_load(f)
self._input_mapping: dict[str, StateInputSpec | CommandInputSpec] = {}
# ── Resolve I/O mappings ─────────────────────────────────
self._input_mapping: dict[str, StateInputSpec | CommandInputSpec | ObservationInputSpec] = {}
self._output_mapping: dict[str, WriteOutputSpec] = {}
self._resolve_io()

Expand Down Expand Up @@ -286,6 +319,10 @@ def _resolve_io(self):
"CommandManager is available (cfg.commands is None)."
)
self._input_mapping[key] = CommandInputSpec(command_term_name=command_name)
elif conn_type == "observation":

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Warning: If a LEAPP YAML contains a malformed connection like "observation:foo" (missing the term component), parts[2] raises an unguarded IndexError. Consider adding validation:

if len(parts) < 3:
    raise ValueError(
        f"Malformed observation connection '{connection}' for input '{key}': "
        "expected format 'observation:<group>:<term>'."
    )

group_name, term_name = parts[1], parts[2]
Comment on lines +322 to +323

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 IndexError on malformed observation: connection strings

parts[1] and parts[2] are accessed without a length guard. A YAML connection value of "observation:policy" (only two colon-separated segments) would raise IndexError at this line. The state: and write: branches above share the same pattern, but adding an explicit len(parts) != 3 check here with a descriptive ValueError would make misconfigured YAMLs much easier to diagnose.

self._validate_observation_term(key, group_name, term_name)
self._input_mapping[key] = ObservationInputSpec(group_name=group_name, term_name=term_name)
else:
logger.warning("Unknown connection type '%s' for input '%s'", conn_type, key)

Expand Down Expand Up @@ -315,6 +352,27 @@ def _resolve_io(self):
else:
logger.warning("Unknown connection type '%s' for output '%s'", conn_type, key)

def _validate_observation_term(self, key: str, group_name: str, term_name: str):
"""Raise a useful error if a LEAPP observation input cannot be resolved."""
observation_manager = self.observation_manager
if observation_manager is None:
raise RuntimeError(
f"LEAPP input '{key}' requires observation term '{group_name}:{term_name}' but no "
"ObservationManager is available."
)

group_term_names = observation_manager.active_terms.get(group_name)
if group_term_names is None:
raise ValueError(
f"LEAPP input '{key}' references unknown observation group '{group_name}'. "
f"Available groups are: {list(observation_manager.active_terms.keys())}"
)
if term_name not in group_term_names:
raise ValueError(
f"LEAPP input '{key}' references unknown observation term '{group_name}:{term_name}'. "
f"Available terms in group '{group_name}' are: {group_term_names}"
)

# ── Read / Write ──────────────────────────────────────────────

def _read_inputs(self) -> dict[str, torch.Tensor]:
Expand All @@ -336,8 +394,29 @@ def _read_inputs(self) -> dict[str, torch.Tensor]:
command_manager = self.command_manager
assert command_manager is not None
inputs[key] = command_manager.get_command(spec.command_term_name)
elif isinstance(spec, ObservationInputSpec):
inputs[key] = self._compute_observation_input(spec)
return inputs

def _compute_observation_input(self, spec: ObservationInputSpec) -> torch.Tensor:
"""Compute one named observation term without injecting observation noise."""
observation_manager = self.observation_manager
assert observation_manager is not None

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 Warning: This accesses private attributes _group_obs_term_names and _group_obs_term_cfgs. The validation above uses the public .active_terms API, but the actual computation bypasses it. If ObservationManager internals change, this will break without a clear error. Consider whether a targeted public method (e.g., compute_term(group, name, noise=False)) would be more maintainable long-term.


group_term_names = observation_manager._group_obs_term_names[spec.group_name]
term_index = group_term_names.index(spec.term_name)
term_cfg = observation_manager._group_obs_term_cfgs[spec.group_name][term_index]

obs: torch.Tensor = term_cfg.func(cast(Any, self), **term_cfg.params).clone()
if term_cfg.modifiers is not None:
for modifier in term_cfg.modifiers:
obs = modifier.func(obs, **modifier.params)
if term_cfg.clip:
obs = obs.clip_(min=term_cfg.clip[0], max=term_cfg.clip[1])
if term_cfg.scale is not None:
obs = obs.mul_(term_cfg.scale)
return obs

def _write_outputs(self, outputs: dict[str, torch.Tensor]):
"""Write model outputs to scene entities.

Expand All @@ -364,12 +443,14 @@ def reset(self) -> dict[str, torch.Tensor]:
Returns:
The initial input tensors (for logging / debugging).
"""
env_ids = [0]
env_ids = torch.arange(self.num_envs, dtype=torch.int32, device=self.device)

self.scene.reset(env_ids)

if self.event_manager is not None and "reset" in self.event_manager.available_modes:
self.event_manager.apply(mode="reset", env_ids=env_ids, global_env_step_count=self._step_count)
if self.observation_manager is not None:
self.observation_manager.reset(env_ids)
if self.command_manager is not None:
self.command_manager.reset(env_ids)

Expand Down Expand Up @@ -440,6 +521,8 @@ def close(self):
self.sim.stop()
if self.command_manager is not None:
del self.command_manager
if self.observation_manager is not None:
del self.observation_manager
if self.event_manager is not None:
del self.event_manager
del self.scene
Expand Down
6 changes: 6 additions & 0 deletions source/isaaclab/isaaclab/utils/leapp/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@ __all__ = [
"body_wrench_resolver",
"body_xyz_resolver",
"build_command_connection",
"build_observation_connection",
"build_state_connection",
"build_write_connection",
"joint_names_resolver",
"leapp_observation_input",
"leapp_tensor_semantics",
"patch_env_for_export",
"resolve_leapp_element_names",
"resolve_leapp_observation_input_semantics",
"target_frame_pose_resolver",
"target_frame_quat_resolver",
"target_frame_xyz_resolver",
Expand All @@ -48,14 +51,17 @@ from .leapp_semantics import (
body_wrench_resolver,
body_xyz_resolver,
joint_names_resolver,
leapp_observation_input,
leapp_tensor_semantics,
resolve_leapp_element_names,
resolve_leapp_observation_input_semantics,
target_frame_pose_resolver,
target_frame_quat_resolver,
target_frame_xyz_resolver,
)
from .utils import (
build_command_connection,
build_observation_connection,
build_state_connection,
build_write_connection,
)
Loading
Loading