From 1fd15591facb3c69f61811bdb89a1293e3d61119 Mon Sep 17 00:00:00 2001 From: Advait Jayant Date: Fri, 12 Jun 2026 00:53:18 +0100 Subject: [PATCH] Fix JIT and ONNX export of image-only CNN policies The CNNModel subclass in isaaclab_rl supports image-only observations during training by overriding get_latent, but the export wrappers inherited from rsl_rl always declare a 1D observation input. For image-only policies (e.g. the cartpole camera tasks), the exported TorchScript and ONNX graphs therefore require feeding a zero-width placeholder tensor, which is awkward to deploy and rejected by some downstream toolchains. Override as_jit and as_onnx to return image-only export wrappers when no 1D observation groups are active. The exported models now only take the 2D observation groups as inputs. Models with mixed 1D and 2D observations keep the upstream export interface. Verified on CPU by exporting and reloading both variants and checking output parity against the original model (torch.jit and onnxruntime). Related to #4592 Signed-off-by: Advait Jayant --- .../0xadvait-fix-cnn-image-only-export.rst | 6 + .../isaaclab_rl/isaaclab_rl/rsl_rl/models.py | 104 +++++++++++ source/isaaclab_rl/test/test_rsl_rl_models.py | 161 ++++++++++++++++++ 3 files changed, 271 insertions(+) create mode 100644 source/isaaclab_rl/changelog.d/0xadvait-fix-cnn-image-only-export.rst create mode 100644 source/isaaclab_rl/test/test_rsl_rl_models.py diff --git a/source/isaaclab_rl/changelog.d/0xadvait-fix-cnn-image-only-export.rst b/source/isaaclab_rl/changelog.d/0xadvait-fix-cnn-image-only-export.rst new file mode 100644 index 000000000000..5b98c27ccbf4 --- /dev/null +++ b/source/isaaclab_rl/changelog.d/0xadvait-fix-cnn-image-only-export.rst @@ -0,0 +1,6 @@ +Fixed +^^^^^ + +* Fixed the JIT and ONNX export of image-only :class:`~isaaclab_rl.rsl_rl.models.CNNModel` + policies. The exported models no longer require feeding a zero-width 1D observation + input (``obs``); they now only take the 2D observation groups as inputs. diff --git a/source/isaaclab_rl/isaaclab_rl/rsl_rl/models.py b/source/isaaclab_rl/isaaclab_rl/rsl_rl/models.py index 89907518a814..584618e71102 100644 --- a/source/isaaclab_rl/isaaclab_rl/rsl_rl/models.py +++ b/source/isaaclab_rl/isaaclab_rl/rsl_rl/models.py @@ -7,11 +7,14 @@ from __future__ import annotations +import copy + import torch from rsl_rl.models.cnn_model import CNNModel as _CNNModel from rsl_rl.models.mlp_model import MLPModel from rsl_rl.modules import HiddenState from tensordict import TensorDict +from torch import nn class CNNModel(_CNNModel): @@ -19,6 +22,12 @@ class CNNModel(_CNNModel): The rsl_rl CNN model does not support image-only observations as it calls :meth:`get_latent` without checking whether the observation groups are empty. + + The same applies to the export wrappers returned by :meth:`as_jit` and :meth:`as_onnx`: + the rsl_rl wrappers always expect a 1D observation input, which for image-only models + becomes a mandatory zero-width ``obs`` tensor that deployment runtimes have to feed. + For image-only models, this class instead returns export wrappers that only take the + 2D observation groups as inputs. """ def get_latent( @@ -29,3 +38,98 @@ def get_latent( return latent_cnn latent_1d = MLPModel.get_latent(self, obs, masks, hidden_state) return torch.cat([latent_1d, latent_cnn], dim=-1) + + def as_jit(self) -> nn.Module: + """Return a version of the model compatible with Torch JIT export.""" + if not self.obs_groups: + return _TorchImageOnlyCNNModel(self) + return super().as_jit() + + def as_onnx(self, verbose: bool = False) -> nn.Module: + """Return a version of the model compatible with ONNX export.""" + if not self.obs_groups: + return _OnnxImageOnlyCNNModel(self, verbose) + return super().as_onnx(verbose) + + +class _TorchImageOnlyCNNModel(nn.Module): + """Exportable image-only CNN model for JIT. + + Unlike ``rsl_rl``'s exportable CNN model, the forward pass only takes the 2D observation + groups as input, without a placeholder for the (empty) 1D observations. + """ + + def __init__(self, model: CNNModel): + super().__init__() + # Convert ModuleDict to ModuleList for ordered iteration + self.cnns = nn.ModuleList([copy.deepcopy(model.cnns[g]) for g in model.obs_groups_2d]) + self.mlp = copy.deepcopy(model.mlp) + if model.distribution is not None: + self.deterministic_output = model.distribution.as_deterministic_output_module() + else: + self.deterministic_output = nn.Identity() + + def forward(self, obs_2d: list[torch.Tensor]) -> torch.Tensor: + """Run deterministic inference from the 2D observation groups.""" + latent_cnn_list = [] + for i, cnn in enumerate(self.cnns): # We assume obs_2d list matches the order of obs_groups_2d + latent_cnn_list.append(cnn(obs_2d[i])) + latent = torch.cat(latent_cnn_list, dim=-1) + out = self.mlp(latent) + return self.deterministic_output(out) + + @torch.jit.export + def reset(self) -> None: + """Reset recurrent export state (no-op for CNN exports).""" + pass + + +class _OnnxImageOnlyCNNModel(nn.Module): + """Exportable image-only CNN model for ONNX. + + Unlike ``rsl_rl``'s exportable CNN model, the forward pass only takes the 2D observation + groups as input, without a placeholder for the (empty) 1D observations. + """ + + def __init__(self, model: CNNModel, verbose: bool): + super().__init__() + self.verbose = verbose + # Convert ModuleDict to ModuleList for ordered iteration + self.cnns = nn.ModuleList([copy.deepcopy(model.cnns[g]) for g in model.obs_groups_2d]) + self.mlp = copy.deepcopy(model.mlp) + if model.distribution is not None: + self.deterministic_output = model.distribution.as_deterministic_output_module() + else: + self.deterministic_output = nn.Identity() + + self.obs_groups_2d = model.obs_groups_2d + self.obs_dims_2d = model.obs_dims_2d + self.obs_channels_2d = model.obs_channels_2d + + def forward(self, *obs_2d: torch.Tensor) -> torch.Tensor: + """Run deterministic inference for ONNX export.""" + latent_cnn_list = [] + for i, cnn in enumerate(self.cnns): + latent_cnn_list.append(cnn(obs_2d[i])) + latent = torch.cat(latent_cnn_list, dim=-1) + out = self.mlp(latent) + return self.deterministic_output(out) + + def get_dummy_inputs(self) -> tuple[torch.Tensor, ...]: + """Return representative dummy inputs for ONNX tracing.""" + dummy_2d = [] + for i in range(len(self.obs_groups_2d)): + h, w = self.obs_dims_2d[i] + c = self.obs_channels_2d[i] + dummy_2d.append(torch.zeros(1, c, h, w)) + return tuple(dummy_2d) + + @property + def input_names(self) -> list[str]: + """Return ONNX input tensor names.""" + return list(self.obs_groups_2d) + + @property + def output_names(self) -> list[str]: + """Return ONNX output tensor names.""" + return ["actions"] diff --git a/source/isaaclab_rl/test/test_rsl_rl_models.py b/source/isaaclab_rl/test/test_rsl_rl_models.py new file mode 100644 index 000000000000..7b91ed7d41c7 --- /dev/null +++ b/source/isaaclab_rl/test/test_rsl_rl_models.py @@ -0,0 +1,161 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Tests for the RSL-RL neural models customized for Isaac Lab. + +These tests run on CPU and do not require Isaac Sim. They cover the image-only observation +support of :class:`~isaaclab_rl.rsl_rl.models.CNNModel` for training (``get_latent``) and for +deployment (``as_jit`` / ``as_onnx`` export wrappers), as well as parity of the exported +models with the original model. +""" + +import io + +import pytest +import torch + +pytest.importorskip("rsl_rl", reason="rsl-rl-lib is not installed") +pytest.importorskip("tensordict", reason="tensordict is not installed") + +from tensordict import TensorDict # noqa: E402 + +from isaaclab_rl.rsl_rl.models import CNNModel # noqa: E402 + +CNN_CFG = {"output_channels": [16, 32], "kernel_size": [8, 4], "stride": [4, 2], "activation": "relu"} +DIST_CFG = {"class_name": "GaussianDistribution", "init_std": 1.0} + + +def _make_image_only_model() -> tuple[CNNModel, TensorDict]: + """Create a CNN model with a single image observation group (e.g. cartpole camera task).""" + obs = TensorDict({"policy": torch.rand(1, 3, 64, 64)}, batch_size=[1]) + obs_groups = {"actor": ["policy"], "critic": ["policy"]} + model = CNNModel( + obs, + obs_groups, + "actor", + 2, + hidden_dims=[64], + activation="elu", + obs_normalization=False, + distribution_cfg=dict(DIST_CFG), + cnn_cfg=dict(CNN_CFG), + ) + model.eval() + return model, obs + + +def _make_mixed_model() -> tuple[CNNModel, TensorDict]: + """Create a CNN model with image and proprioceptive observation groups.""" + obs = TensorDict({"proprio": torch.randn(1, 12), "camera": torch.rand(1, 3, 64, 64)}, batch_size=[1]) + obs_groups = {"actor": ["proprio", "camera"], "critic": ["proprio", "camera"]} + model = CNNModel( + obs, + obs_groups, + "actor", + 2, + hidden_dims=[64], + activation="elu", + obs_normalization=False, + distribution_cfg=dict(DIST_CFG), + cnn_cfg=dict(CNN_CFG), + ) + model.eval() + return model, obs + + +def _script_roundtrip(jit_model: torch.nn.Module) -> torch.jit.ScriptModule: + """Script, save and reload a module the way the runner export + deployment does.""" + scripted = torch.jit.script(jit_model) + buffer = io.BytesIO() + torch.jit.save(scripted, buffer) + buffer.seek(0) + loaded = torch.jit.load(buffer) + loaded.eval() + return loaded + + +def test_cnn_model_image_only_forward(): + """Image-only models compute a latent purely from the CNN encoders.""" + model, obs = _make_image_only_model() + with torch.inference_mode(): + out = model(obs) + assert out.shape == (1, 2) + + +def test_cnn_model_image_only_jit_export_takes_only_images(): + """The JIT export of an image-only model takes only the 2D observations as input.""" + model, obs = _make_image_only_model() + loaded = _script_roundtrip(model.as_jit()) + with torch.inference_mode(): + out = loaded([obs["policy"]]) + ref = model(obs) + torch.testing.assert_close(out, ref, rtol=1e-5, atol=1e-6) + + +def test_cnn_model_image_only_onnx_export_takes_only_images(tmp_path): + """The ONNX export of an image-only model declares only the 2D observations as inputs.""" + onnx = pytest.importorskip("onnx", reason="onnx is not installed") + + model, obs = _make_image_only_model() + onnx_model = model.as_onnx(verbose=False) + onnx_model.eval() + assert onnx_model.input_names == ["policy"] + + path = str(tmp_path / "policy.onnx") + torch.onnx.export( + onnx_model, + onnx_model.get_dummy_inputs(), + path, + export_params=True, + opset_version=18, + input_names=onnx_model.input_names, + output_names=onnx_model.output_names, + ) + graph_inputs = [graph_input.name for graph_input in onnx.load(path).graph.input] + assert graph_inputs == ["policy"] + + ort = pytest.importorskip("onnxruntime", reason="onnxruntime is not installed") + session = ort.InferenceSession(path) + out = torch.from_numpy(session.run(None, {"policy": obs["policy"].numpy()})[0]) + with torch.inference_mode(): + ref = model(obs) + torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-5) + + +def test_cnn_model_mixed_obs_jit_export_parity(): + """Models with 1D and 2D observation groups keep the upstream export interface.""" + model, obs = _make_mixed_model() + loaded = _script_roundtrip(model.as_jit()) + with torch.inference_mode(): + out = loaded(obs["proprio"], [obs["camera"]]) + ref = model(obs) + torch.testing.assert_close(out, ref, rtol=1e-5, atol=1e-6) + + +def test_cnn_model_mixed_obs_onnx_export_parity(tmp_path): + """Models with 1D and 2D observation groups keep the upstream ONNX interface.""" + pytest.importorskip("onnx", reason="onnx is not installed") + ort = pytest.importorskip("onnxruntime", reason="onnxruntime is not installed") + + model, obs = _make_mixed_model() + onnx_model = model.as_onnx(verbose=False) + onnx_model.eval() + assert onnx_model.input_names == ["obs", "camera"] + + path = str(tmp_path / "policy.onnx") + torch.onnx.export( + onnx_model, + onnx_model.get_dummy_inputs(), + path, + export_params=True, + opset_version=18, + input_names=onnx_model.input_names, + output_names=onnx_model.output_names, + ) + session = ort.InferenceSession(path) + out = torch.from_numpy(session.run(None, {"obs": obs["proprio"].numpy(), "camera": obs["camera"].numpy()})[0]) + with torch.inference_mode(): + ref = model(obs) + torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-5)