diff --git a/docs/source/policy_deployment/05_leapp/exporting_policies_with_leapp.rst b/docs/source/policy_deployment/05_leapp/exporting_policies_with_leapp.rst index 8f645757db00..e514bec2e390 100644 --- a/docs/source/policy_deployment/05_leapp/exporting_policies_with_leapp.rst +++ b/docs/source/policy_deployment/05_leapp/exporting_policies_with_leapp.rst @@ -188,6 +188,51 @@ configuration are needed. NumPy operations cannot be traced by LEAPP. +Observation-Term Input Boundaries +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +By default, the manager-based export path discovers policy inputs by tracing observation terms +through Isaac Lab proxies. This works well for standard observations that read annotated simulator +state, such as ``robot.data.joint_pos`` or ``robot.data.joint_vel``. In those cases, the exported +LEAPP input is the lower-level simulator state and any PyTorch preprocessing in the observation +term is included in the exported graph. + +Some deployment workflows should expose a higher-level observation term instead of the simulator +or task-internal tensors used to compute it. This is useful when the deployment system already +produces the policy-facing observation value directly, or when the term depends on bookkeeping that +should not become part of the exported runtime interface. In that case, mark the observation term +with ``leapp_observation_input``: + +.. code-block:: python + + from isaaclab.utils.leapp import XYZ_ELEMENT_NAMES, leapp_observation_input + + + @leapp_observation_input(kind="state/body/position", element_names=[XYZ_ELEMENT_NAMES]) + class object_position_w(ManagerTermBase): + ... + +During export, Isaac Lab computes the configured observation term, applies deterministic +post-processing such as modifiers, clipping, and scaling, disables observation noise, and registers +that term value as the LEAPP input boundary. The generated metadata uses an +``observation:{group}:{term}`` connection, for example: + +.. code-block:: text + + observation:policy:gear_shaft_pos + observation:policy:gear_shaft_quat + +When running the exported policy with :class:`~envs.LeappDeploymentEnv`, Isaac Lab constructs an +``ObservationManager`` only when the LEAPP YAML declares ``observation:*`` inputs and computes the +named terms from the task configuration. External runtimes should provide the same term values with +the same ordering, units, frames, and shape used by the policy observation group. + +For example, a gear assembly deployment may provide shaft pose from perception or calibration, +while the training environment computes it from gear type state, fixture offsets, and environment +origins. Marking ``gear_shaft_pos_w`` and ``gear_shaft_quat_w`` with +``leapp_observation_input`` exports those final shaft-pose terms as the policy inputs. + + Verifying an Export ------------------- diff --git a/source/isaaclab/isaaclab/envs/leapp_deployment_env.py b/source/isaaclab/isaaclab/envs/leapp_deployment_env.py index a2ad5b897647..d5b6409e1553 100644 --- a/source/isaaclab/isaaclab/envs/leapp_deployment_env.py +++ b/source/isaaclab/isaaclab/envs/leapp_deployment_env.py @@ -5,10 +5,10 @@ """Deployment environment that runs LEAPP-exported policies in simulation. -This environment bypasses all Isaac Lab managers (observation, action, reward, etc.) -and instead wires scene entity data properties and ``CommandManager`` outputs directly -to a LEAPP ``InferenceManager``, then writes the model outputs back to the -corresponding scene entities. All I/O resolution is driven by the +This environment bypasses most Isaac Lab managers (action, reward, etc.) and instead +wires scene entity data properties, selected observation terms, and ``CommandManager`` +outputs directly to a LEAPP ``InferenceManager``, then writes the model outputs back +to the corresponding scene entities. All I/O resolution is driven by the ``isaaclab_connection`` field in the LEAPP YAML. """ @@ -27,7 +27,7 @@ except ImportError as e: raise ImportError("LEAPP package is required for policy deployment testing. Install with: pip install leapp") from e -from isaaclab.managers import CommandManager, EventManager +from isaaclab.managers import CommandManager, EventManager, ObservationManager from isaaclab.scene import InteractiveScene from isaaclab.sim import SimulationContext from isaaclab.sim.utils.stage import use_stage @@ -59,6 +59,14 @@ class CommandInputSpec: command_term_name: str +@dataclass +class ObservationInputSpec: + """Read a named observation term from ``ObservationManager``.""" + + group_name: str + term_name: str + + @dataclass class WriteOutputSpec: """Write a tensor to a scene entity method, optionally indexed by joint.""" @@ -134,6 +142,18 @@ def _first_param_name(method: Any) -> str: return params[0].name +def _leapp_desc_has_connection(leapp_desc: dict[str, Any], connection_type: str) -> bool: + """Return whether a LEAPP description contains an I/O connection of *connection_type*.""" + prefix = f"{connection_type}:" + for model_desc in leapp_desc.get("models", {}).values(): + for io_section in ("inputs", "outputs"): + for tensor_desc in model_desc.get(io_section, []): + connection = tensor_desc.get("isaaclab_connection") + if isinstance(connection, str) and connection.startswith(prefix): + return True + return False + + # ══════════════════════════════════════════════════════════════════ # LeappDeploymentEnv # ══════════════════════════════════════════════════════════════════ @@ -153,10 +173,11 @@ class LeappDeploymentEnv: - ``state:{entity}:{property}`` -- read ``scene[entity].data.{property}`` - ``command:{name}`` -- read ``command_manager.get_command(name)`` + - ``observation:{group}:{term}`` -- compute a named observation term - ``write:{entity}:{method}`` -- call ``scene[entity].{method}(tensor, ...)`` - No observation, action, reward, termination, or curriculum managers are used. - The LEAPP model already contains all pre/post-processing. + No action, reward, termination, or curriculum managers are used. The LEAPP model + already contains all action pre/post-processing. """ def __init__(self, cfg: Any, leapp_yaml_path: str): @@ -197,6 +218,10 @@ def __init__(self, cfg: Any, leapp_yaml_path: str): else: self.viewport_camera_controller = None + # ── Parse YAML once before constructing optional managers ─ + with open(leapp_yaml_path) as f: + self._leapp_desc = yaml.safe_load(f) + # ── EventManager (optional, for resets) ─────────────────── self.event_manager: EventManager | None = None if hasattr(cfg, "events") and cfg.events is not None: @@ -207,13 +232,21 @@ def __init__(self, cfg: Any, leapp_yaml_path: str): if hasattr(cfg, "commands") and cfg.commands is not None: self.command_manager = CommandManager(cfg.commands, cast(Any, self)) + # ── ObservationManager (optional, for observation/* inputs) ─ + self.observation_manager: ObservationManager | None = None + if _leapp_desc_has_connection(self._leapp_desc, "observation"): + if not hasattr(cfg, "observations") or cfg.observations is None: + raise RuntimeError( + "LEAPP YAML declares observation inputs but no ObservationManager configuration is available " + "(cfg.observations is None)." + ) + self.observation_manager = ObservationManager(cfg.observations, cast(Any, self)) + # ── LEAPP InferenceManager ──────────────────────────────── self.inference = InferenceManager(leapp_yaml_path) - # ── Parse YAML and resolve I/O mappings ─────────────────── - with open(leapp_yaml_path) as f: - self._leapp_desc = yaml.safe_load(f) - self._input_mapping: dict[str, StateInputSpec | CommandInputSpec] = {} + # ── Resolve I/O mappings ───────────────────────────────── + self._input_mapping: dict[str, StateInputSpec | CommandInputSpec | ObservationInputSpec] = {} self._output_mapping: dict[str, WriteOutputSpec] = {} self._resolve_io() @@ -286,6 +319,10 @@ def _resolve_io(self): "CommandManager is available (cfg.commands is None)." ) self._input_mapping[key] = CommandInputSpec(command_term_name=command_name) + elif conn_type == "observation": + group_name, term_name = parts[1], parts[2] + self._validate_observation_term(key, group_name, term_name) + self._input_mapping[key] = ObservationInputSpec(group_name=group_name, term_name=term_name) else: logger.warning("Unknown connection type '%s' for input '%s'", conn_type, key) @@ -315,6 +352,27 @@ def _resolve_io(self): else: logger.warning("Unknown connection type '%s' for output '%s'", conn_type, key) + def _validate_observation_term(self, key: str, group_name: str, term_name: str): + """Raise a useful error if a LEAPP observation input cannot be resolved.""" + observation_manager = self.observation_manager + if observation_manager is None: + raise RuntimeError( + f"LEAPP input '{key}' requires observation term '{group_name}:{term_name}' but no " + "ObservationManager is available." + ) + + group_term_names = observation_manager.active_terms.get(group_name) + if group_term_names is None: + raise ValueError( + f"LEAPP input '{key}' references unknown observation group '{group_name}'. " + f"Available groups are: {list(observation_manager.active_terms.keys())}" + ) + if term_name not in group_term_names: + raise ValueError( + f"LEAPP input '{key}' references unknown observation term '{group_name}:{term_name}'. " + f"Available terms in group '{group_name}' are: {group_term_names}" + ) + # ── Read / Write ────────────────────────────────────────────── def _read_inputs(self) -> dict[str, torch.Tensor]: @@ -336,8 +394,29 @@ def _read_inputs(self) -> dict[str, torch.Tensor]: command_manager = self.command_manager assert command_manager is not None inputs[key] = command_manager.get_command(spec.command_term_name) + elif isinstance(spec, ObservationInputSpec): + inputs[key] = self._compute_observation_input(spec) return inputs + def _compute_observation_input(self, spec: ObservationInputSpec) -> torch.Tensor: + """Compute one named observation term without injecting observation noise.""" + observation_manager = self.observation_manager + assert observation_manager is not None + + group_term_names = observation_manager._group_obs_term_names[spec.group_name] + term_index = group_term_names.index(spec.term_name) + term_cfg = observation_manager._group_obs_term_cfgs[spec.group_name][term_index] + + obs: torch.Tensor = term_cfg.func(cast(Any, self), **term_cfg.params).clone() + if term_cfg.modifiers is not None: + for modifier in term_cfg.modifiers: + obs = modifier.func(obs, **modifier.params) + if term_cfg.clip: + obs = obs.clip_(min=term_cfg.clip[0], max=term_cfg.clip[1]) + if term_cfg.scale is not None: + obs = obs.mul_(term_cfg.scale) + return obs + def _write_outputs(self, outputs: dict[str, torch.Tensor]): """Write model outputs to scene entities. @@ -364,12 +443,14 @@ def reset(self) -> dict[str, torch.Tensor]: Returns: The initial input tensors (for logging / debugging). """ - env_ids = [0] + env_ids = torch.arange(self.num_envs, dtype=torch.int32, device=self.device) self.scene.reset(env_ids) if self.event_manager is not None and "reset" in self.event_manager.available_modes: self.event_manager.apply(mode="reset", env_ids=env_ids, global_env_step_count=self._step_count) + if self.observation_manager is not None: + self.observation_manager.reset(env_ids) if self.command_manager is not None: self.command_manager.reset(env_ids) @@ -440,6 +521,8 @@ def close(self): self.sim.stop() if self.command_manager is not None: del self.command_manager + if self.observation_manager is not None: + del self.observation_manager if self.event_manager is not None: del self.event_manager del self.scene diff --git a/source/isaaclab/isaaclab/utils/leapp/__init__.pyi b/source/isaaclab/isaaclab/utils/leapp/__init__.pyi index e2b8f497b5f3..061d28ab07d3 100644 --- a/source/isaaclab/isaaclab/utils/leapp/__init__.pyi +++ b/source/isaaclab/isaaclab/utils/leapp/__init__.pyi @@ -20,12 +20,15 @@ __all__ = [ "body_wrench_resolver", "body_xyz_resolver", "build_command_connection", + "build_observation_connection", "build_state_connection", "build_write_connection", "joint_names_resolver", + "leapp_observation_input", "leapp_tensor_semantics", "patch_env_for_export", "resolve_leapp_element_names", + "resolve_leapp_observation_input_semantics", "target_frame_pose_resolver", "target_frame_quat_resolver", "target_frame_xyz_resolver", @@ -48,14 +51,17 @@ from .leapp_semantics import ( body_wrench_resolver, body_xyz_resolver, joint_names_resolver, + leapp_observation_input, leapp_tensor_semantics, resolve_leapp_element_names, + resolve_leapp_observation_input_semantics, target_frame_pose_resolver, target_frame_quat_resolver, target_frame_xyz_resolver, ) from .utils import ( build_command_connection, + build_observation_connection, build_state_connection, build_write_connection, ) diff --git a/source/isaaclab/isaaclab/utils/leapp/export_annotator.py b/source/isaaclab/isaaclab/utils/leapp/export_annotator.py index 3ca82a66c082..a9d562e4960d 100644 --- a/source/isaaclab/isaaclab/utils/leapp/export_annotator.py +++ b/source/isaaclab/isaaclab/utils/leapp/export_annotator.py @@ -45,11 +45,16 @@ from isaaclab.assets.articulation.base_articulation import BaseArticulation from isaaclab.managers import ManagerTermBase -from .leapp_semantics import select_element_names +from .leapp_semantics import ( + resolve_leapp_element_names, + resolve_leapp_observation_input_semantics, + select_element_names, +) from .proxy import _ArticulationWriteProxy, _DataProxy, _EnvProxy, _ManagerTermProxy from .utils import ( TracedProxyArray, build_command_connection, + build_observation_connection, build_write_connection, ) @@ -129,7 +134,7 @@ def setup(self, env): ) self._disable_training_managers(unwrapped) - self._patch_observation_manager(unwrapped.observation_manager, proxy_env) + self._patch_observation_manager(unwrapped.observation_manager, proxy_env, unwrapped) self._patch_history_buffers(unwrapped.observation_manager) self._patch_action_manager( unwrapped.action_manager, @@ -272,21 +277,38 @@ def patched_append(data: torch.Tensor): circular_buffer._leapp_original_append = original_append circular_buffer._append = patched_append - def _patch_observation_manager(self, obs_manager, proxy_env): + def _patch_observation_manager(self, obs_manager, proxy_env, real_env): """Patch observation terms to use annotating proxies and disable noise. Args: obs_manager: Observation manager instance to patch. proxy_env: Proxy environment routed into observation terms. + real_env: Unwrapped environment used for explicit observation inputs. """ + term_names_by_group = getattr(obs_manager, "_group_obs_term_names", {}) for group_name, term_cfgs in obs_manager._group_obs_term_cfgs.items(): if self.required_obs_groups is not None and group_name not in self.required_obs_groups: continue - for term_cfg in term_cfgs: + group_term_names = term_names_by_group.get(group_name, []) + for index, term_cfg in enumerate(term_cfgs): original_func = term_cfg.func func_name = getattr(original_func, "__name__", None) - - if func_name == "last_action": + term_name = group_term_names[index] if index < len(group_term_names) else func_name + observation_input_semantics = resolve_leapp_observation_input_semantics(original_func) + + if observation_input_semantics is not None: + term_cfg.func = self._wrap_observation_input( + original_func, + real_env, + group_name, + term_name, + observation_input_semantics, + term_cfg, + ) + term_cfg.modifiers = None + term_cfg.clip = None + term_cfg.scale = None + elif func_name == "last_action": self._uses_last_action_state = True term_cfg.func = self._wrap_last_action(original_func) elif func_name == "generated_commands": @@ -308,6 +330,52 @@ def patched_compute(*args, **kwargs): obs_manager.compute = patched_compute + @staticmethod + def _apply_observation_post_processing(obs: torch.Tensor, term_cfg) -> torch.Tensor: + """Apply deterministic observation post-processing configured on a term.""" + if term_cfg.modifiers is not None: + for modifier in term_cfg.modifiers: + obs = modifier.func(obs, **modifier.params) + if term_cfg.clip: + obs = obs.clip_(min=term_cfg.clip[0], max=term_cfg.clip[1]) + if term_cfg.scale is not None: + obs = obs.mul_(term_cfg.scale) + return obs + + def _wrap_observation_input(self, original_func, real_env, group_name: str, term_name: str, semantics, term_cfg): + """Wrap a full observation term as an explicit LEAPP input tensor. + + Some policy inputs are task-level observation terms, not single raw + scene state properties. Calling the original term with the real env and + applying deterministic post-processing keeps the configured observation + boundary intact while making the term a live LEAPP deployment input. + """ + task_name = self.task_name + element_names = resolve_leapp_element_names(semantics, original_func) + + def wrapped(*args, **kwargs): + if args: + args = (real_env, *args[1:]) + else: + args = (real_env,) + result = original_func(*args, **kwargs) + result = self._apply_observation_post_processing(result, term_cfg) + sem = TensorSemantics( + name=term_name, + ref=result, + kind=semantics.kind, + element_names=element_names, + extra=build_observation_connection(group_name, term_name), + ) + return annotate.input_tensors(task_name, sem) + + wrapped.__name__ = getattr(original_func, "__name__", term_name) + if hasattr(original_func, "reset"): + wrapped.reset = original_func.reset + if hasattr(original_func, "serialize"): + wrapped.serialize = original_func.serialize + return wrapped + # ── Action manager patches ──────────────────────────────────── def _patch_action_manager(self, action_manager, cache): diff --git a/source/isaaclab/isaaclab/utils/leapp/leapp_semantics.py b/source/isaaclab/isaaclab/utils/leapp/leapp_semantics.py index 340291de16ad..3dce0f8cb076 100644 --- a/source/isaaclab/isaaclab/utils/leapp/leapp_semantics.py +++ b/source/isaaclab/isaaclab/utils/leapp/leapp_semantics.py @@ -90,6 +90,40 @@ def _apply(func: Callable) -> Callable: return _apply +def leapp_observation_input( + *, + kind: Any = None, + element_names: list[str] | list[list[str]] | None = None, + element_names_resolver: Callable | None = None, +) -> Callable: + """Mark an observation term's returned tensor as a named LEAPP input boundary. + + This is intended for observation terms whose deployment interface should be + the final observation value rather than the lower-level scene-state tensors + read inside the term. + """ + + semantics = LeappTensorSemantics( + kind=kind, + element_names=element_names, + element_names_resolver=element_names_resolver, + ) + + def _apply(term: Callable) -> Callable: + term._leapp_observation_input_semantics = semantics + return term + + return _apply + + +def resolve_leapp_observation_input_semantics(term: Any) -> LeappTensorSemantics | None: + """Return LEAPP input-boundary metadata attached to an observation term.""" + semantics = getattr(term, "_leapp_observation_input_semantics", None) + if semantics is not None: + return semantics + return getattr(type(term), "_leapp_observation_input_semantics", None) + + def resolve_leapp_element_names(semantics: LeappTensorSemantics | None, data_self) -> list | None: """Resolve element names from attached semantics and a tensor-producing object.""" if semantics is None: diff --git a/source/isaaclab/isaaclab/utils/leapp/utils.py b/source/isaaclab/isaaclab/utils/leapp/utils.py index 2308f662ab80..7dcff894fc14 100644 --- a/source/isaaclab/isaaclab/utils/leapp/utils.py +++ b/source/isaaclab/isaaclab/utils/leapp/utils.py @@ -82,6 +82,11 @@ def build_command_connection(command_name: str) -> dict[str, str]: return {"isaaclab_connection": f"command:{command_name}"} +def build_observation_connection(group_name: str, term_name: str) -> dict[str, str]: + """Return a compact deployment connection string for an observation term.""" + return {"isaaclab_connection": f"observation:{group_name}:{term_name}"} + + def build_write_connection(entity_name: str, method_name: str) -> dict[str, str]: """Return a compact deployment connection string for an articulation write target.""" return {"isaaclab_connection": f"write:{entity_name}:{method_name}"} diff --git a/source/isaaclab/test/utils/test_leapp_observation_input.py b/source/isaaclab/test/utils/test_leapp_observation_input.py new file mode 100644 index 000000000000..ab54148f6ae6 --- /dev/null +++ b/source/isaaclab/test/utils/test_leapp_observation_input.py @@ -0,0 +1,118 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Unit tests for LEAPP observation-term input metadata.""" + +from types import SimpleNamespace + +import torch + +import isaaclab.utils.leapp.export_annotator as export_annotator +from isaaclab.utils.leapp import ( + ExportPatcher, + XYZ_ELEMENT_NAMES, + leapp_observation_input, + resolve_leapp_element_names, + resolve_leapp_observation_input_semantics, +) + + +def test_leapp_observation_input_marks_function(): + """Observation input metadata can be attached to function-based terms.""" + + @leapp_observation_input(kind="state/body/position", element_names=[XYZ_ELEMENT_NAMES]) + def observation_term(env): + return None + + semantics = resolve_leapp_observation_input_semantics(observation_term) + + assert semantics is not None + assert semantics.kind == "state/body/position" + assert semantics.element_names == [XYZ_ELEMENT_NAMES] + assert resolve_leapp_element_names(semantics, observation_term) == [XYZ_ELEMENT_NAMES] + + +def test_leapp_observation_input_marks_class_instances(): + """Observation input metadata can be attached to class-based manager terms.""" + + @leapp_observation_input(kind="state/gear/type") + class ObservationTerm: + def __call__(self, env): + return None + + semantics_from_class = resolve_leapp_observation_input_semantics(ObservationTerm) + semantics_from_instance = resolve_leapp_observation_input_semantics(ObservationTerm()) + + assert semantics_from_class is not None + assert semantics_from_class.kind == "state/gear/type" + assert semantics_from_instance is semantics_from_class + + +def test_leapp_observation_input_supports_element_name_resolver(): + """Observation input metadata can resolve element names lazily.""" + + def element_names_resolver(term): + return [term.element_names] + + @leapp_observation_input(kind="state/custom", element_names_resolver=element_names_resolver) + class ObservationTerm: + element_names = ["a", "b", "c"] + + def __call__(self, env): + return None + + term = ObservationTerm() + semantics = resolve_leapp_observation_input_semantics(term) + + assert semantics is not None + assert resolve_leapp_element_names(semantics, term) == [["a", "b", "c"]] + + +def test_leapp_observation_input_wrapper_annotates_processed_term(monkeypatch): + """The export wrapper registers the configured observation value as the LEAPP input.""" + + def add_offset(obs, offset: float): + return obs + offset + + @leapp_observation_input(kind="state/body/position", element_names=[XYZ_ELEMENT_NAMES]) + def observation_term(env, multiplier: float): + return env.value * multiplier + + real_env = SimpleNamespace(value=torch.tensor([[-1.0, 0.25, 2.0]])) + term_cfg = SimpleNamespace( + modifiers=[SimpleNamespace(func=add_offset, params={"offset": 1.0})], + clip=(0.0, 2.0), + scale=3.0, + ) + semantics = resolve_leapp_observation_input_semantics(observation_term) + patcher = ExportPatcher(export_method="onnx") + patcher.task_name = "TestTask-v0" + captured = {} + + def fake_input_tensors(task_name, tensor_semantics): + captured["task_name"] = task_name + captured["tensor_semantics"] = tensor_semantics + return tensor_semantics.ref + + monkeypatch.setattr(export_annotator.annotate, "input_tensors", fake_input_tensors) + + wrapped = patcher._wrap_observation_input( + observation_term, + real_env, + "policy", + "shaft_pos", + semantics, + term_cfg, + ) + returned = wrapped(SimpleNamespace(value=torch.full((1, 3), 100.0)), multiplier=2.0) + + expected = torch.tensor([[0.0, 4.5, 6.0]]) + torch.testing.assert_close(returned, expected) + assert captured["task_name"] == "TestTask-v0" + assert captured["tensor_semantics"].name == "shaft_pos" + assert captured["tensor_semantics"].ref is returned + assert captured["tensor_semantics"].kind == "state/body/position" + assert captured["tensor_semantics"].element_names == [XYZ_ELEMENT_NAMES] + assert captured["tensor_semantics"].extra == {"isaaclab_connection": "observation:policy:shaft_pos"} diff --git a/source/isaaclab_tasks/isaaclab_tasks/contrib/deploy/mdp/observations.py b/source/isaaclab_tasks/isaaclab_tasks/contrib/deploy/mdp/observations.py index ac12d8b22f7f..3e0252ad6ee0 100644 --- a/source/isaaclab_tasks/isaaclab_tasks/contrib/deploy/mdp/observations.py +++ b/source/isaaclab_tasks/isaaclab_tasks/contrib/deploy/mdp/observations.py @@ -12,6 +12,11 @@ import torch from isaaclab.managers import ManagerTermBase, ObservationTermCfg, SceneEntityCfg +from isaaclab.utils.leapp.leapp_semantics import ( + QUAT_XYZW_ELEMENT_NAMES, + XYZ_ELEMENT_NAMES, + leapp_observation_input, +) from isaaclab.utils.math import combine_frame_transforms if TYPE_CHECKING: @@ -21,6 +26,7 @@ from .events import randomize_gear_type +@leapp_observation_input(kind="state/body/position", element_names=[XYZ_ELEMENT_NAMES]) class gear_shaft_pos_w(ManagerTermBase): """Gear shaft position in world frame with offset applied. @@ -139,6 +145,7 @@ def __call__( return shaft_pos - env.scene.env_origins +@leapp_observation_input(kind="state/body/rotation", element_names=[QUAT_XYZW_ELEMENT_NAMES]) class gear_shaft_quat_w(ManagerTermBase): """Gear shaft orientation in world frame.