-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Ashwinvk/leapp observation term inputs #6081
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,10 +5,10 @@ | |
|
|
||
| """Deployment environment that runs LEAPP-exported policies in simulation. | ||
|
|
||
| This environment bypasses all Isaac Lab managers (observation, action, reward, etc.) | ||
| and instead wires scene entity data properties and ``CommandManager`` outputs directly | ||
| to a LEAPP ``InferenceManager``, then writes the model outputs back to the | ||
| corresponding scene entities. All I/O resolution is driven by the | ||
| This environment bypasses most Isaac Lab managers (action, reward, etc.) and instead | ||
| wires scene entity data properties, selected observation terms, and ``CommandManager`` | ||
| outputs directly to a LEAPP ``InferenceManager``, then writes the model outputs back | ||
| to the corresponding scene entities. All I/O resolution is driven by the | ||
| ``isaaclab_connection`` field in the LEAPP YAML. | ||
| """ | ||
|
|
||
|
|
@@ -27,7 +27,7 @@ | |
| except ImportError as e: | ||
| raise ImportError("LEAPP package is required for policy deployment testing. Install with: pip install leapp") from e | ||
|
|
||
| from isaaclab.managers import CommandManager, EventManager | ||
| from isaaclab.managers import CommandManager, EventManager, ObservationManager | ||
| from isaaclab.scene import InteractiveScene | ||
| from isaaclab.sim import SimulationContext | ||
| from isaaclab.sim.utils.stage import use_stage | ||
|
|
@@ -59,6 +59,14 @@ class CommandInputSpec: | |
| command_term_name: str | ||
|
|
||
|
|
||
| @dataclass | ||
| class ObservationInputSpec: | ||
| """Read a named observation term from ``ObservationManager``.""" | ||
|
|
||
| group_name: str | ||
| term_name: str | ||
|
|
||
|
|
||
| @dataclass | ||
| class WriteOutputSpec: | ||
| """Write a tensor to a scene entity method, optionally indexed by joint.""" | ||
|
|
@@ -134,6 +142,18 @@ def _first_param_name(method: Any) -> str: | |
| return params[0].name | ||
|
|
||
|
|
||
| def _leapp_desc_has_connection(leapp_desc: dict[str, Any], connection_type: str) -> bool: | ||
| """Return whether a LEAPP description contains an I/O connection of *connection_type*.""" | ||
| prefix = f"{connection_type}:" | ||
| for model_desc in leapp_desc.get("models", {}).values(): | ||
| for io_section in ("inputs", "outputs"): | ||
| for tensor_desc in model_desc.get(io_section, []): | ||
| connection = tensor_desc.get("isaaclab_connection") | ||
| if isinstance(connection, str) and connection.startswith(prefix): | ||
| return True | ||
| return False | ||
|
|
||
|
|
||
| # ══════════════════════════════════════════════════════════════════ | ||
| # LeappDeploymentEnv | ||
| # ══════════════════════════════════════════════════════════════════ | ||
|
|
@@ -153,10 +173,11 @@ class LeappDeploymentEnv: | |
|
|
||
| - ``state:{entity}:{property}`` -- read ``scene[entity].data.{property}`` | ||
| - ``command:{name}`` -- read ``command_manager.get_command(name)`` | ||
| - ``observation:{group}:{term}`` -- compute a named observation term | ||
| - ``write:{entity}:{method}`` -- call ``scene[entity].{method}(tensor, ...)`` | ||
|
|
||
| No observation, action, reward, termination, or curriculum managers are used. | ||
| The LEAPP model already contains all pre/post-processing. | ||
| No action, reward, termination, or curriculum managers are used. The LEAPP model | ||
| already contains all action pre/post-processing. | ||
| """ | ||
|
|
||
| def __init__(self, cfg: Any, leapp_yaml_path: str): | ||
|
|
@@ -197,6 +218,10 @@ def __init__(self, cfg: Any, leapp_yaml_path: str): | |
| else: | ||
| self.viewport_camera_controller = None | ||
|
|
||
| # ── Parse YAML once before constructing optional managers ─ | ||
| with open(leapp_yaml_path) as f: | ||
| self._leapp_desc = yaml.safe_load(f) | ||
|
|
||
| # ── EventManager (optional, for resets) ─────────────────── | ||
| self.event_manager: EventManager | None = None | ||
| if hasattr(cfg, "events") and cfg.events is not None: | ||
|
|
@@ -207,13 +232,21 @@ def __init__(self, cfg: Any, leapp_yaml_path: str): | |
| if hasattr(cfg, "commands") and cfg.commands is not None: | ||
| self.command_manager = CommandManager(cfg.commands, cast(Any, self)) | ||
|
|
||
| # ── ObservationManager (optional, for observation/* inputs) ─ | ||
| self.observation_manager: ObservationManager | None = None | ||
| if _leapp_desc_has_connection(self._leapp_desc, "observation"): | ||
| if not hasattr(cfg, "observations") or cfg.observations is None: | ||
| raise RuntimeError( | ||
| "LEAPP YAML declares observation inputs but no ObservationManager configuration is available " | ||
| "(cfg.observations is None)." | ||
| ) | ||
| self.observation_manager = ObservationManager(cfg.observations, cast(Any, self)) | ||
|
|
||
| # ── LEAPP InferenceManager ──────────────────────────────── | ||
| self.inference = InferenceManager(leapp_yaml_path) | ||
|
|
||
| # ── Parse YAML and resolve I/O mappings ─────────────────── | ||
| with open(leapp_yaml_path) as f: | ||
| self._leapp_desc = yaml.safe_load(f) | ||
| self._input_mapping: dict[str, StateInputSpec | CommandInputSpec] = {} | ||
| # ── Resolve I/O mappings ───────────────────────────────── | ||
| self._input_mapping: dict[str, StateInputSpec | CommandInputSpec | ObservationInputSpec] = {} | ||
| self._output_mapping: dict[str, WriteOutputSpec] = {} | ||
| self._resolve_io() | ||
|
|
||
|
|
@@ -286,6 +319,10 @@ def _resolve_io(self): | |
| "CommandManager is available (cfg.commands is None)." | ||
| ) | ||
| self._input_mapping[key] = CommandInputSpec(command_term_name=command_name) | ||
| elif conn_type == "observation": | ||
| group_name, term_name = parts[1], parts[2] | ||
|
Comment on lines
+322
to
+323
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| self._validate_observation_term(key, group_name, term_name) | ||
| self._input_mapping[key] = ObservationInputSpec(group_name=group_name, term_name=term_name) | ||
| else: | ||
| logger.warning("Unknown connection type '%s' for input '%s'", conn_type, key) | ||
|
|
||
|
|
@@ -315,6 +352,27 @@ def _resolve_io(self): | |
| else: | ||
| logger.warning("Unknown connection type '%s' for output '%s'", conn_type, key) | ||
|
|
||
| def _validate_observation_term(self, key: str, group_name: str, term_name: str): | ||
| """Raise a useful error if a LEAPP observation input cannot be resolved.""" | ||
| observation_manager = self.observation_manager | ||
| if observation_manager is None: | ||
| raise RuntimeError( | ||
| f"LEAPP input '{key}' requires observation term '{group_name}:{term_name}' but no " | ||
| "ObservationManager is available." | ||
| ) | ||
|
|
||
| group_term_names = observation_manager.active_terms.get(group_name) | ||
| if group_term_names is None: | ||
| raise ValueError( | ||
| f"LEAPP input '{key}' references unknown observation group '{group_name}'. " | ||
| f"Available groups are: {list(observation_manager.active_terms.keys())}" | ||
| ) | ||
| if term_name not in group_term_names: | ||
| raise ValueError( | ||
| f"LEAPP input '{key}' references unknown observation term '{group_name}:{term_name}'. " | ||
| f"Available terms in group '{group_name}' are: {group_term_names}" | ||
| ) | ||
|
|
||
| # ── Read / Write ────────────────────────────────────────────── | ||
|
|
||
| def _read_inputs(self) -> dict[str, torch.Tensor]: | ||
|
|
@@ -336,8 +394,29 @@ def _read_inputs(self) -> dict[str, torch.Tensor]: | |
| command_manager = self.command_manager | ||
| assert command_manager is not None | ||
| inputs[key] = command_manager.get_command(spec.command_term_name) | ||
| elif isinstance(spec, ObservationInputSpec): | ||
| inputs[key] = self._compute_observation_input(spec) | ||
| return inputs | ||
|
|
||
| def _compute_observation_input(self, spec: ObservationInputSpec) -> torch.Tensor: | ||
| """Compute one named observation term without injecting observation noise.""" | ||
| observation_manager = self.observation_manager | ||
| assert observation_manager is not None | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🟡 Warning: This accesses private attributes |
||
|
|
||
| group_term_names = observation_manager._group_obs_term_names[spec.group_name] | ||
| term_index = group_term_names.index(spec.term_name) | ||
| term_cfg = observation_manager._group_obs_term_cfgs[spec.group_name][term_index] | ||
|
|
||
| obs: torch.Tensor = term_cfg.func(cast(Any, self), **term_cfg.params).clone() | ||
| if term_cfg.modifiers is not None: | ||
| for modifier in term_cfg.modifiers: | ||
| obs = modifier.func(obs, **modifier.params) | ||
| if term_cfg.clip: | ||
| obs = obs.clip_(min=term_cfg.clip[0], max=term_cfg.clip[1]) | ||
| if term_cfg.scale is not None: | ||
| obs = obs.mul_(term_cfg.scale) | ||
| return obs | ||
|
|
||
| def _write_outputs(self, outputs: dict[str, torch.Tensor]): | ||
| """Write model outputs to scene entities. | ||
|
|
||
|
|
@@ -364,12 +443,14 @@ def reset(self) -> dict[str, torch.Tensor]: | |
| Returns: | ||
| The initial input tensors (for logging / debugging). | ||
| """ | ||
| env_ids = [0] | ||
| env_ids = torch.arange(self.num_envs, dtype=torch.int32, device=self.device) | ||
|
|
||
| self.scene.reset(env_ids) | ||
|
|
||
| if self.event_manager is not None and "reset" in self.event_manager.available_modes: | ||
| self.event_manager.apply(mode="reset", env_ids=env_ids, global_env_step_count=self._step_count) | ||
| if self.observation_manager is not None: | ||
| self.observation_manager.reset(env_ids) | ||
| if self.command_manager is not None: | ||
| self.command_manager.reset(env_ids) | ||
|
|
||
|
|
@@ -440,6 +521,8 @@ def close(self): | |
| self.sim.stop() | ||
| if self.command_manager is not None: | ||
| del self.command_manager | ||
| if self.observation_manager is not None: | ||
| del self.observation_manager | ||
| if self.event_manager is not None: | ||
| del self.event_manager | ||
| del self.scene | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🟡 Warning: If a LEAPP YAML contains a malformed connection like
"observation:foo"(missing the term component),parts[2]raises an unguardedIndexError. Consider adding validation: