diff --git a/source/isaaclab_rl/changelog.d/0xadvait-fix-cnn-image-only-export.rst b/source/isaaclab_rl/changelog.d/0xadvait-fix-cnn-image-only-export.rst new file mode 100644 index 000000000000..5b98c27ccbf4 --- /dev/null +++ b/source/isaaclab_rl/changelog.d/0xadvait-fix-cnn-image-only-export.rst @@ -0,0 +1,6 @@ +Fixed +^^^^^ + +* Fixed the JIT and ONNX export of image-only :class:`~isaaclab_rl.rsl_rl.models.CNNModel` + policies. The exported models no longer require feeding a zero-width 1D observation + input (``obs``); they now only take the 2D observation groups as inputs. diff --git a/source/isaaclab_rl/isaaclab_rl/rsl_rl/models.py b/source/isaaclab_rl/isaaclab_rl/rsl_rl/models.py index 89907518a814..584618e71102 100644 --- a/source/isaaclab_rl/isaaclab_rl/rsl_rl/models.py +++ b/source/isaaclab_rl/isaaclab_rl/rsl_rl/models.py @@ -7,11 +7,14 @@ from __future__ import annotations +import copy + import torch from rsl_rl.models.cnn_model import CNNModel as _CNNModel from rsl_rl.models.mlp_model import MLPModel from rsl_rl.modules import HiddenState from tensordict import TensorDict +from torch import nn class CNNModel(_CNNModel): @@ -19,6 +22,12 @@ class CNNModel(_CNNModel): The rsl_rl CNN model does not support image-only observations as it calls :meth:`get_latent` without checking whether the observation groups are empty. + + The same applies to the export wrappers returned by :meth:`as_jit` and :meth:`as_onnx`: + the rsl_rl wrappers always expect a 1D observation input, which for image-only models + becomes a mandatory zero-width ``obs`` tensor that deployment runtimes have to feed. + For image-only models, this class instead returns export wrappers that only take the + 2D observation groups as inputs. """ def get_latent( @@ -29,3 +38,98 @@ def get_latent( return latent_cnn latent_1d = MLPModel.get_latent(self, obs, masks, hidden_state) return torch.cat([latent_1d, latent_cnn], dim=-1) + + def as_jit(self) -> nn.Module: + """Return a version of the model compatible with Torch JIT export.""" + if not self.obs_groups: + return _TorchImageOnlyCNNModel(self) + return super().as_jit() + + def as_onnx(self, verbose: bool = False) -> nn.Module: + """Return a version of the model compatible with ONNX export.""" + if not self.obs_groups: + return _OnnxImageOnlyCNNModel(self, verbose) + return super().as_onnx(verbose) + + +class _TorchImageOnlyCNNModel(nn.Module): + """Exportable image-only CNN model for JIT. + + Unlike ``rsl_rl``'s exportable CNN model, the forward pass only takes the 2D observation + groups as input, without a placeholder for the (empty) 1D observations. + """ + + def __init__(self, model: CNNModel): + super().__init__() + # Convert ModuleDict to ModuleList for ordered iteration + self.cnns = nn.ModuleList([copy.deepcopy(model.cnns[g]) for g in model.obs_groups_2d]) + self.mlp = copy.deepcopy(model.mlp) + if model.distribution is not None: + self.deterministic_output = model.distribution.as_deterministic_output_module() + else: + self.deterministic_output = nn.Identity() + + def forward(self, obs_2d: list[torch.Tensor]) -> torch.Tensor: + """Run deterministic inference from the 2D observation groups.""" + latent_cnn_list = [] + for i, cnn in enumerate(self.cnns): # We assume obs_2d list matches the order of obs_groups_2d + latent_cnn_list.append(cnn(obs_2d[i])) + latent = torch.cat(latent_cnn_list, dim=-1) + out = self.mlp(latent) + return self.deterministic_output(out) + + @torch.jit.export + def reset(self) -> None: + """Reset recurrent export state (no-op for CNN exports).""" + pass + + +class _OnnxImageOnlyCNNModel(nn.Module): + """Exportable image-only CNN model for ONNX. + + Unlike ``rsl_rl``'s exportable CNN model, the forward pass only takes the 2D observation + groups as input, without a placeholder for the (empty) 1D observations. + """ + + def __init__(self, model: CNNModel, verbose: bool): + super().__init__() + self.verbose = verbose + # Convert ModuleDict to ModuleList for ordered iteration + self.cnns = nn.ModuleList([copy.deepcopy(model.cnns[g]) for g in model.obs_groups_2d]) + self.mlp = copy.deepcopy(model.mlp) + if model.distribution is not None: + self.deterministic_output = model.distribution.as_deterministic_output_module() + else: + self.deterministic_output = nn.Identity() + + self.obs_groups_2d = model.obs_groups_2d + self.obs_dims_2d = model.obs_dims_2d + self.obs_channels_2d = model.obs_channels_2d + + def forward(self, *obs_2d: torch.Tensor) -> torch.Tensor: + """Run deterministic inference for ONNX export.""" + latent_cnn_list = [] + for i, cnn in enumerate(self.cnns): + latent_cnn_list.append(cnn(obs_2d[i])) + latent = torch.cat(latent_cnn_list, dim=-1) + out = self.mlp(latent) + return self.deterministic_output(out) + + def get_dummy_inputs(self) -> tuple[torch.Tensor, ...]: + """Return representative dummy inputs for ONNX tracing.""" + dummy_2d = [] + for i in range(len(self.obs_groups_2d)): + h, w = self.obs_dims_2d[i] + c = self.obs_channels_2d[i] + dummy_2d.append(torch.zeros(1, c, h, w)) + return tuple(dummy_2d) + + @property + def input_names(self) -> list[str]: + """Return ONNX input tensor names.""" + return list(self.obs_groups_2d) + + @property + def output_names(self) -> list[str]: + """Return ONNX output tensor names.""" + return ["actions"] diff --git a/source/isaaclab_rl/test/test_rsl_rl_models.py b/source/isaaclab_rl/test/test_rsl_rl_models.py new file mode 100644 index 000000000000..7b91ed7d41c7 --- /dev/null +++ b/source/isaaclab_rl/test/test_rsl_rl_models.py @@ -0,0 +1,161 @@ +# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +"""Tests for the RSL-RL neural models customized for Isaac Lab. + +These tests run on CPU and do not require Isaac Sim. They cover the image-only observation +support of :class:`~isaaclab_rl.rsl_rl.models.CNNModel` for training (``get_latent``) and for +deployment (``as_jit`` / ``as_onnx`` export wrappers), as well as parity of the exported +models with the original model. +""" + +import io + +import pytest +import torch + +pytest.importorskip("rsl_rl", reason="rsl-rl-lib is not installed") +pytest.importorskip("tensordict", reason="tensordict is not installed") + +from tensordict import TensorDict # noqa: E402 + +from isaaclab_rl.rsl_rl.models import CNNModel # noqa: E402 + +CNN_CFG = {"output_channels": [16, 32], "kernel_size": [8, 4], "stride": [4, 2], "activation": "relu"} +DIST_CFG = {"class_name": "GaussianDistribution", "init_std": 1.0} + + +def _make_image_only_model() -> tuple[CNNModel, TensorDict]: + """Create a CNN model with a single image observation group (e.g. cartpole camera task).""" + obs = TensorDict({"policy": torch.rand(1, 3, 64, 64)}, batch_size=[1]) + obs_groups = {"actor": ["policy"], "critic": ["policy"]} + model = CNNModel( + obs, + obs_groups, + "actor", + 2, + hidden_dims=[64], + activation="elu", + obs_normalization=False, + distribution_cfg=dict(DIST_CFG), + cnn_cfg=dict(CNN_CFG), + ) + model.eval() + return model, obs + + +def _make_mixed_model() -> tuple[CNNModel, TensorDict]: + """Create a CNN model with image and proprioceptive observation groups.""" + obs = TensorDict({"proprio": torch.randn(1, 12), "camera": torch.rand(1, 3, 64, 64)}, batch_size=[1]) + obs_groups = {"actor": ["proprio", "camera"], "critic": ["proprio", "camera"]} + model = CNNModel( + obs, + obs_groups, + "actor", + 2, + hidden_dims=[64], + activation="elu", + obs_normalization=False, + distribution_cfg=dict(DIST_CFG), + cnn_cfg=dict(CNN_CFG), + ) + model.eval() + return model, obs + + +def _script_roundtrip(jit_model: torch.nn.Module) -> torch.jit.ScriptModule: + """Script, save and reload a module the way the runner export + deployment does.""" + scripted = torch.jit.script(jit_model) + buffer = io.BytesIO() + torch.jit.save(scripted, buffer) + buffer.seek(0) + loaded = torch.jit.load(buffer) + loaded.eval() + return loaded + + +def test_cnn_model_image_only_forward(): + """Image-only models compute a latent purely from the CNN encoders.""" + model, obs = _make_image_only_model() + with torch.inference_mode(): + out = model(obs) + assert out.shape == (1, 2) + + +def test_cnn_model_image_only_jit_export_takes_only_images(): + """The JIT export of an image-only model takes only the 2D observations as input.""" + model, obs = _make_image_only_model() + loaded = _script_roundtrip(model.as_jit()) + with torch.inference_mode(): + out = loaded([obs["policy"]]) + ref = model(obs) + torch.testing.assert_close(out, ref, rtol=1e-5, atol=1e-6) + + +def test_cnn_model_image_only_onnx_export_takes_only_images(tmp_path): + """The ONNX export of an image-only model declares only the 2D observations as inputs.""" + onnx = pytest.importorskip("onnx", reason="onnx is not installed") + + model, obs = _make_image_only_model() + onnx_model = model.as_onnx(verbose=False) + onnx_model.eval() + assert onnx_model.input_names == ["policy"] + + path = str(tmp_path / "policy.onnx") + torch.onnx.export( + onnx_model, + onnx_model.get_dummy_inputs(), + path, + export_params=True, + opset_version=18, + input_names=onnx_model.input_names, + output_names=onnx_model.output_names, + ) + graph_inputs = [graph_input.name for graph_input in onnx.load(path).graph.input] + assert graph_inputs == ["policy"] + + ort = pytest.importorskip("onnxruntime", reason="onnxruntime is not installed") + session = ort.InferenceSession(path) + out = torch.from_numpy(session.run(None, {"policy": obs["policy"].numpy()})[0]) + with torch.inference_mode(): + ref = model(obs) + torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-5) + + +def test_cnn_model_mixed_obs_jit_export_parity(): + """Models with 1D and 2D observation groups keep the upstream export interface.""" + model, obs = _make_mixed_model() + loaded = _script_roundtrip(model.as_jit()) + with torch.inference_mode(): + out = loaded(obs["proprio"], [obs["camera"]]) + ref = model(obs) + torch.testing.assert_close(out, ref, rtol=1e-5, atol=1e-6) + + +def test_cnn_model_mixed_obs_onnx_export_parity(tmp_path): + """Models with 1D and 2D observation groups keep the upstream ONNX interface.""" + pytest.importorskip("onnx", reason="onnx is not installed") + ort = pytest.importorskip("onnxruntime", reason="onnxruntime is not installed") + + model, obs = _make_mixed_model() + onnx_model = model.as_onnx(verbose=False) + onnx_model.eval() + assert onnx_model.input_names == ["obs", "camera"] + + path = str(tmp_path / "policy.onnx") + torch.onnx.export( + onnx_model, + onnx_model.get_dummy_inputs(), + path, + export_params=True, + opset_version=18, + input_names=onnx_model.input_names, + output_names=onnx_model.output_names, + ) + session = ort.InferenceSession(path) + out = torch.from_numpy(session.run(None, {"obs": obs["proprio"].numpy(), "camera": obs["camera"].numpy()})[0]) + with torch.inference_mode(): + ref = model(obs) + torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-5)