-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Fixes JIT and ONNX export of image-only CNN policies #6159
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| Fixed | ||
| ^^^^^ | ||
|
|
||
| * Fixed the JIT and ONNX export of image-only :class:`~isaaclab_rl.rsl_rl.models.CNNModel` | ||
| policies. The exported models no longer require feeding a zero-width 1D observation | ||
| input (``obs``); they now only take the 2D observation groups as inputs. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,18 +7,27 @@ | |
|
|
||
| from __future__ import annotations | ||
|
|
||
| import copy | ||
|
|
||
| import torch | ||
| from rsl_rl.models.cnn_model import CNNModel as _CNNModel | ||
| from rsl_rl.models.mlp_model import MLPModel | ||
| from rsl_rl.modules import HiddenState | ||
| from tensordict import TensorDict | ||
| from torch import nn | ||
|
|
||
|
|
||
| class CNNModel(_CNNModel): | ||
| """CNN model that supports pure image-only observations. | ||
|
|
||
| The rsl_rl CNN model does not support image-only observations as it calls | ||
| :meth:`get_latent` without checking whether the observation groups are empty. | ||
|
|
||
| The same applies to the export wrappers returned by :meth:`as_jit` and :meth:`as_onnx`: | ||
| the rsl_rl wrappers always expect a 1D observation input, which for image-only models | ||
| becomes a mandatory zero-width ``obs`` tensor that deployment runtimes have to feed. | ||
| For image-only models, this class instead returns export wrappers that only take the | ||
| 2D observation groups as inputs. | ||
| """ | ||
|
|
||
| def get_latent( | ||
|
|
@@ -29,3 +38,98 @@ def get_latent( | |
| return latent_cnn | ||
| latent_1d = MLPModel.get_latent(self, obs, masks, hidden_state) | ||
| return torch.cat([latent_1d, latent_cnn], dim=-1) | ||
|
|
||
| def as_jit(self) -> nn.Module: | ||
| """Return a version of the model compatible with Torch JIT export.""" | ||
| if not self.obs_groups: | ||
| return _TorchImageOnlyCNNModel(self) | ||
| return super().as_jit() | ||
|
|
||
| def as_onnx(self, verbose: bool = False) -> nn.Module: | ||
| """Return a version of the model compatible with ONNX export.""" | ||
| if not self.obs_groups: | ||
| return _OnnxImageOnlyCNNModel(self, verbose) | ||
| return super().as_onnx(verbose) | ||
|
|
||
|
|
||
| class _TorchImageOnlyCNNModel(nn.Module): | ||
| """Exportable image-only CNN model for JIT. | ||
|
|
||
| Unlike ``rsl_rl``'s exportable CNN model, the forward pass only takes the 2D observation | ||
| groups as input, without a placeholder for the (empty) 1D observations. | ||
| """ | ||
|
|
||
| def __init__(self, model: CNNModel): | ||
| super().__init__() | ||
| # Convert ModuleDict to ModuleList for ordered iteration | ||
| self.cnns = nn.ModuleList([copy.deepcopy(model.cnns[g]) for g in model.obs_groups_2d]) | ||
| self.mlp = copy.deepcopy(model.mlp) | ||
| if model.distribution is not None: | ||
| self.deterministic_output = model.distribution.as_deterministic_output_module() | ||
| else: | ||
| self.deterministic_output = nn.Identity() | ||
|
|
||
| def forward(self, obs_2d: list[torch.Tensor]) -> torch.Tensor: | ||
| """Run deterministic inference from the 2D observation groups.""" | ||
| latent_cnn_list = [] | ||
| for i, cnn in enumerate(self.cnns): # We assume obs_2d list matches the order of obs_groups_2d | ||
| latent_cnn_list.append(cnn(obs_2d[i])) | ||
| latent = torch.cat(latent_cnn_list, dim=-1) | ||
| out = self.mlp(latent) | ||
| return self.deterministic_output(out) | ||
|
|
||
| @torch.jit.export | ||
| def reset(self) -> None: | ||
| """Reset recurrent export state (no-op for CNN exports).""" | ||
| pass | ||
|
|
||
|
|
||
| class _OnnxImageOnlyCNNModel(nn.Module): | ||
| """Exportable image-only CNN model for ONNX. | ||
|
|
||
| Unlike ``rsl_rl``'s exportable CNN model, the forward pass only takes the 2D observation | ||
| groups as input, without a placeholder for the (empty) 1D observations. | ||
| """ | ||
|
|
||
| def __init__(self, model: CNNModel, verbose: bool): | ||
| super().__init__() | ||
| self.verbose = verbose | ||
| # Convert ModuleDict to ModuleList for ordered iteration | ||
| self.cnns = nn.ModuleList([copy.deepcopy(model.cnns[g]) for g in model.obs_groups_2d]) | ||
| self.mlp = copy.deepcopy(model.mlp) | ||
| if model.distribution is not None: | ||
| self.deterministic_output = model.distribution.as_deterministic_output_module() | ||
| else: | ||
| self.deterministic_output = nn.Identity() | ||
|
|
||
| self.obs_groups_2d = model.obs_groups_2d | ||
| self.obs_dims_2d = model.obs_dims_2d | ||
| self.obs_channels_2d = model.obs_channels_2d | ||
|
|
||
| def forward(self, *obs_2d: torch.Tensor) -> torch.Tensor: | ||
| """Run deterministic inference for ONNX export.""" | ||
| latent_cnn_list = [] | ||
| for i, cnn in enumerate(self.cnns): | ||
| latent_cnn_list.append(cnn(obs_2d[i])) | ||
| latent = torch.cat(latent_cnn_list, dim=-1) | ||
| out = self.mlp(latent) | ||
| return self.deterministic_output(out) | ||
|
|
||
| def get_dummy_inputs(self) -> tuple[torch.Tensor, ...]: | ||
| """Return representative dummy inputs for ONNX tracing.""" | ||
| dummy_2d = [] | ||
| for i in range(len(self.obs_groups_2d)): | ||
| h, w = self.obs_dims_2d[i] | ||
| c = self.obs_channels_2d[i] | ||
| dummy_2d.append(torch.zeros(1, c, h, w)) | ||
| return tuple(dummy_2d) | ||
|
|
||
| @property | ||
| def input_names(self) -> list[str]: | ||
| """Return ONNX input tensor names.""" | ||
| return list(self.obs_groups_2d) | ||
|
|
||
| @property | ||
| def output_names(self) -> list[str]: | ||
| """Return ONNX output tensor names.""" | ||
| return ["actions"] | ||
|
Comment on lines
+55
to
+135
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same rationale as the other thread: the duplication mirrors rsl_rl's own
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That's a solid rationale — the Tip: You can customize Greptile's behavior for this repo with |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,161 @@ | ||
| # Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). | ||
| # All rights reserved. | ||
| # | ||
| # SPDX-License-Identifier: BSD-3-Clause | ||
|
|
||
| """Tests for the RSL-RL neural models customized for Isaac Lab. | ||
|
|
||
| These tests run on CPU and do not require Isaac Sim. They cover the image-only observation | ||
| support of :class:`~isaaclab_rl.rsl_rl.models.CNNModel` for training (``get_latent``) and for | ||
| deployment (``as_jit`` / ``as_onnx`` export wrappers), as well as parity of the exported | ||
| models with the original model. | ||
| """ | ||
|
|
||
| import io | ||
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| pytest.importorskip("rsl_rl", reason="rsl-rl-lib is not installed") | ||
| pytest.importorskip("tensordict", reason="tensordict is not installed") | ||
|
|
||
| from tensordict import TensorDict # noqa: E402 | ||
|
|
||
| from isaaclab_rl.rsl_rl.models import CNNModel # noqa: E402 | ||
|
|
||
| CNN_CFG = {"output_channels": [16, 32], "kernel_size": [8, 4], "stride": [4, 2], "activation": "relu"} | ||
| DIST_CFG = {"class_name": "GaussianDistribution", "init_std": 1.0} | ||
|
|
||
|
|
||
| def _make_image_only_model() -> tuple[CNNModel, TensorDict]: | ||
| """Create a CNN model with a single image observation group (e.g. cartpole camera task).""" | ||
| obs = TensorDict({"policy": torch.rand(1, 3, 64, 64)}, batch_size=[1]) | ||
| obs_groups = {"actor": ["policy"], "critic": ["policy"]} | ||
| model = CNNModel( | ||
| obs, | ||
| obs_groups, | ||
| "actor", | ||
| 2, | ||
| hidden_dims=[64], | ||
| activation="elu", | ||
| obs_normalization=False, | ||
| distribution_cfg=dict(DIST_CFG), | ||
| cnn_cfg=dict(CNN_CFG), | ||
| ) | ||
| model.eval() | ||
| return model, obs | ||
|
|
||
|
|
||
| def _make_mixed_model() -> tuple[CNNModel, TensorDict]: | ||
| """Create a CNN model with image and proprioceptive observation groups.""" | ||
| obs = TensorDict({"proprio": torch.randn(1, 12), "camera": torch.rand(1, 3, 64, 64)}, batch_size=[1]) | ||
| obs_groups = {"actor": ["proprio", "camera"], "critic": ["proprio", "camera"]} | ||
| model = CNNModel( | ||
| obs, | ||
| obs_groups, | ||
| "actor", | ||
| 2, | ||
| hidden_dims=[64], | ||
| activation="elu", | ||
| obs_normalization=False, | ||
| distribution_cfg=dict(DIST_CFG), | ||
| cnn_cfg=dict(CNN_CFG), | ||
| ) | ||
| model.eval() | ||
| return model, obs | ||
|
|
||
|
|
||
| def _script_roundtrip(jit_model: torch.nn.Module) -> torch.jit.ScriptModule: | ||
| """Script, save and reload a module the way the runner export + deployment does.""" | ||
| scripted = torch.jit.script(jit_model) | ||
| buffer = io.BytesIO() | ||
| torch.jit.save(scripted, buffer) | ||
| buffer.seek(0) | ||
| loaded = torch.jit.load(buffer) | ||
| loaded.eval() | ||
| return loaded | ||
|
|
||
|
|
||
| def test_cnn_model_image_only_forward(): | ||
| """Image-only models compute a latent purely from the CNN encoders.""" | ||
| model, obs = _make_image_only_model() | ||
| with torch.inference_mode(): | ||
| out = model(obs) | ||
| assert out.shape == (1, 2) | ||
|
|
||
|
|
||
| def test_cnn_model_image_only_jit_export_takes_only_images(): | ||
| """The JIT export of an image-only model takes only the 2D observations as input.""" | ||
| model, obs = _make_image_only_model() | ||
| loaded = _script_roundtrip(model.as_jit()) | ||
| with torch.inference_mode(): | ||
| out = loaded([obs["policy"]]) | ||
| ref = model(obs) | ||
| torch.testing.assert_close(out, ref, rtol=1e-5, atol=1e-6) | ||
|
|
||
|
|
||
| def test_cnn_model_image_only_onnx_export_takes_only_images(tmp_path): | ||
| """The ONNX export of an image-only model declares only the 2D observations as inputs.""" | ||
| onnx = pytest.importorskip("onnx", reason="onnx is not installed") | ||
|
|
||
| model, obs = _make_image_only_model() | ||
| onnx_model = model.as_onnx(verbose=False) | ||
| onnx_model.eval() | ||
| assert onnx_model.input_names == ["policy"] | ||
|
|
||
| path = str(tmp_path / "policy.onnx") | ||
| torch.onnx.export( | ||
| onnx_model, | ||
| onnx_model.get_dummy_inputs(), | ||
| path, | ||
| export_params=True, | ||
| opset_version=18, | ||
| input_names=onnx_model.input_names, | ||
| output_names=onnx_model.output_names, | ||
| ) | ||
| graph_inputs = [graph_input.name for graph_input in onnx.load(path).graph.input] | ||
| assert graph_inputs == ["policy"] | ||
|
|
||
| ort = pytest.importorskip("onnxruntime", reason="onnxruntime is not installed") | ||
| session = ort.InferenceSession(path) | ||
| out = torch.from_numpy(session.run(None, {"policy": obs["policy"].numpy()})[0]) | ||
| with torch.inference_mode(): | ||
| ref = model(obs) | ||
| torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-5) | ||
|
|
||
|
|
||
| def test_cnn_model_mixed_obs_jit_export_parity(): | ||
| """Models with 1D and 2D observation groups keep the upstream export interface.""" | ||
| model, obs = _make_mixed_model() | ||
| loaded = _script_roundtrip(model.as_jit()) | ||
| with torch.inference_mode(): | ||
| out = loaded(obs["proprio"], [obs["camera"]]) | ||
| ref = model(obs) | ||
| torch.testing.assert_close(out, ref, rtol=1e-5, atol=1e-6) | ||
|
|
||
|
|
||
| def test_cnn_model_mixed_obs_onnx_export_parity(tmp_path): | ||
| """Models with 1D and 2D observation groups keep the upstream ONNX interface.""" | ||
| pytest.importorskip("onnx", reason="onnx is not installed") | ||
| ort = pytest.importorskip("onnxruntime", reason="onnxruntime is not installed") | ||
|
|
||
| model, obs = _make_mixed_model() | ||
| onnx_model = model.as_onnx(verbose=False) | ||
| onnx_model.eval() | ||
| assert onnx_model.input_names == ["obs", "camera"] | ||
|
|
||
| path = str(tmp_path / "policy.onnx") | ||
| torch.onnx.export( | ||
| onnx_model, | ||
| onnx_model.get_dummy_inputs(), | ||
| path, | ||
| export_params=True, | ||
| opset_version=18, | ||
| input_names=onnx_model.input_names, | ||
| output_names=onnx_model.output_names, | ||
| ) | ||
| session = ort.InferenceSession(path) | ||
| out = torch.from_numpy(session.run(None, {"obs": obs["proprio"].numpy(), "camera": obs["camera"].numpy()})[0]) | ||
| with torch.inference_mode(): | ||
| ref = model(obs) | ||
| torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-5) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`verbose` attribute: `self.verbose` is assigned in `__init__` but is never read anywhere in `_OnnxImageOnlyCNNModel`. The upstream `rsl_rl` ONNX wrapper presumably prints tracing info when `verbose=True`, but this class omits that behavior without surfacing the omission. If verbosity is intentionally dropped here, the field should not be stored; if it should be forwarded (e.g. to `torch.onnx.export`), the call site in `as_onnx` has the parameter ready.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Kept deliberately for structural parity with the upstream wrappers: rsl_rl's `_OnnxCNNModel` stores `self.verbose` the same way without reading it in `forward` (the runner passes `verbose` to `torch.onnx.export` separately). Keeping the same shape means these classes stay directly diff-able against the upstream ones they specialize, which also makes them easy to upstream into rsl_rl later.