Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Fixed
^^^^^

* Fixed the JIT and ONNX export of image-only :class:`~isaaclab_rl.rsl_rl.models.CNNModel`
policies. The exported models no longer require feeding a zero-width 1D observation
input (``obs``); they now only take the 2D observation groups as inputs.
104 changes: 104 additions & 0 deletions source/isaaclab_rl/isaaclab_rl/rsl_rl/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,27 @@

from __future__ import annotations

import copy

import torch
from rsl_rl.models.cnn_model import CNNModel as _CNNModel
from rsl_rl.models.mlp_model import MLPModel
from rsl_rl.modules import HiddenState
from tensordict import TensorDict
from torch import nn


class CNNModel(_CNNModel):
"""CNN model that supports pure image-only observations.

The rsl_rl CNN model does not support image-only observations as it calls
:meth:`get_latent` without checking whether the observation groups are empty.

The same applies to the export wrappers returned by :meth:`as_jit` and :meth:`as_onnx`:
the rsl_rl wrappers always expect a 1D observation input, which for image-only models
becomes a mandatory zero-width ``obs`` tensor that deployment runtimes have to feed.
For image-only models, this class instead returns export wrappers that only take the
2D observation groups as inputs.
"""

def get_latent(
Expand All @@ -29,3 +38,98 @@ def get_latent(
return latent_cnn
latent_1d = MLPModel.get_latent(self, obs, masks, hidden_state)
return torch.cat([latent_1d, latent_cnn], dim=-1)

def as_jit(self) -> nn.Module:
    """Build the module that is handed to Torch JIT export.

    When ``self.obs_groups`` is empty (the image-only case described in the class
    docstring), an image-only wrapper is returned whose forward pass takes only the
    2D observation groups. Otherwise the export wrapper of the parent class is used.

    Returns:
        The exportable module.
    """
    if self.obs_groups:
        return super().as_jit()
    return _TorchImageOnlyCNNModel(self)

def as_onnx(self, verbose: bool = False) -> nn.Module:
    """Build the module that is handed to ONNX export.

    When ``self.obs_groups`` is empty (the image-only case described in the class
    docstring), an image-only wrapper is returned whose forward pass takes only the
    2D observation groups. Otherwise the export wrapper of the parent class is used.

    Args:
        verbose: Forwarded unchanged to whichever export wrapper is created.

    Returns:
        The exportable module.
    """
    if self.obs_groups:
        return super().as_onnx(verbose)
    return _OnnxImageOnlyCNNModel(self, verbose)


class _TorchImageOnlyCNNModel(nn.Module):
    """JIT-exportable wrapper for a CNN model that has only image observations.

    In contrast to the exportable CNN model shipped with ``rsl_rl``, :meth:`forward`
    receives just the 2D observation groups; there is no placeholder argument for the
    (empty) 1D observations.
    """

    def __init__(self, model: CNNModel):
        """Copy the encoders, MLP and deterministic output head out of ``model``.

        Args:
            model: The image-only CNN model to wrap. Its sub-modules are deep-copied,
                so the wrapper does not share the encoder or MLP modules with it.
        """
        super().__init__()
        # A ModuleList (instead of the ModuleDict on the model) fixes the iteration
        # order to that of ``model.obs_groups_2d``.
        encoders = [copy.deepcopy(model.cnns[group]) for group in model.obs_groups_2d]
        self.cnns = nn.ModuleList(encoders)
        self.mlp = copy.deepcopy(model.mlp)
        distribution = model.distribution
        if distribution is None:
            # No distribution: the MLP output is used as-is.
            self.deterministic_output = nn.Identity()
        else:
            self.deterministic_output = distribution.as_deterministic_output_module()

    def forward(self, obs_2d: list[torch.Tensor]) -> torch.Tensor:
        """Compute the deterministic output from the 2D observation groups.

        Args:
            obs_2d: One tensor per CNN encoder. The list is assumed to follow the
                order of ``obs_groups_2d`` of the wrapped model.

        Returns:
            The output of the deterministic output module applied to the MLP output.
        """
        features = []
        for idx, encoder in enumerate(self.cnns):
            features.append(encoder(obs_2d[idx]))
        # Encoder outputs are joined along the last dimension before the MLP.
        mlp_out = self.mlp(torch.cat(features, dim=-1))
        return self.deterministic_output(mlp_out)

    @torch.jit.export
    def reset(self) -> None:
        """Reset recurrent export state (nothing to reset for CNN exports)."""
        return None


class _OnnxImageOnlyCNNModel(nn.Module):
"""Exportable image-only CNN model for ONNX.

Unlike ``rsl_rl``'s exportable CNN model, the forward pass only takes the 2D observation
groups as input, without a placeholder for the (empty) 1D observations.
"""

def __init__(self, model: CNNModel, verbose: bool):
super().__init__()
self.verbose = verbose

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Dead verbose attribute

self.verbose is assigned in __init__ but is never read anywhere in _OnnxImageOnlyCNNModel. The upstream rsl_rl ONNX wrapper presumably prints tracing info when verbose=True, but this class omits that behavior without surfacing the omission. If verbosity is intentionally dropped here, the field should not be stored; if it should be forwarded (e.g. to torch.onnx.export), the call site in as_onnx has the parameter ready.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kept deliberately for structural parity with the upstream wrappers: rsl_rl's _OnnxCNNModel stores self.verbose the same way without reading it in forward (the runner passes verbose to torch.onnx.export separately). Keeping the same shape means these classes stay directly diff-able against the upstream ones they specialize, which also makes them easy to upstream into rsl_rl later.

# Convert ModuleDict to ModuleList for ordered iteration
self.cnns = nn.ModuleList([copy.deepcopy(model.cnns[g]) for g in model.obs_groups_2d])
self.mlp = copy.deepcopy(model.mlp)
if model.distribution is not None:
self.deterministic_output = model.distribution.as_deterministic_output_module()
else:
self.deterministic_output = nn.Identity()

self.obs_groups_2d = model.obs_groups_2d
self.obs_dims_2d = model.obs_dims_2d
self.obs_channels_2d = model.obs_channels_2d

def forward(self, *obs_2d: torch.Tensor) -> torch.Tensor:
    """Compute the deterministic output for ONNX export.

    Args:
        *obs_2d: One positional tensor per CNN encoder, consumed by index in the
            order of ``self.cnns``.

    Returns:
        The output of the deterministic output module applied to the MLP output.
    """
    encoded = [encoder(obs_2d[position]) for position, encoder in enumerate(self.cnns)]
    # Encoder outputs are joined along the last dimension before the MLP.
    mlp_out = self.mlp(torch.cat(encoded, dim=-1))
    return self.deterministic_output(mlp_out)

def get_dummy_inputs(self) -> tuple[torch.Tensor, ...]:
    """Create zero-filled placeholder inputs for ONNX tracing.

    Returns:
        One tensor of shape ``(1, channels, height, width)`` per entry of
        ``self.obs_groups_2d``, in that order, with the sizes read from
        ``self.obs_channels_2d`` and ``self.obs_dims_2d``.
    """

    def _blank_image(index: int) -> torch.Tensor:
        # Each entry of obs_dims_2d is unpacked as a (height, width) pair.
        height, width = self.obs_dims_2d[index]
        channels = self.obs_channels_2d[index]
        return torch.zeros(1, channels, height, width)

    return tuple(_blank_image(index) for index in range(len(self.obs_groups_2d)))

@property
def input_names(self) -> list[str]:
    """Names of the ONNX input tensors: a fresh list copy of ``self.obs_groups_2d``."""
    return [*self.obs_groups_2d]

@property
def output_names(self) -> list[str]:
    """Names of the ONNX output tensors (a single output called ``actions``)."""
    names = ["actions"]
    return names
Comment on lines +55 to +135

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Duplicated initialization logic across wrapper classes

_TorchImageOnlyCNNModel and _OnnxImageOnlyCNNModel share identical __init__ logic: deep-copying cnns, mlp, and computing deterministic_output. The forward body is also nearly identical. A common base class (or a single wrapper that can serve both export paths after construction) would eliminate the redundancy and ensure future changes to one path are automatically reflected in the other.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same rationale as the other thread: the duplication mirrors rsl_rl's own _TorchCNNModel / _OnnxCNNModel, which duplicate the identical init/forward logic between themselves. The two classes can't share a forward (TorchScript can't script *args, while the ONNX path wants positional tensor inputs for named graph inputs), and I'd rather keep these recognizable as the image-only counterparts of the upstream pair than introduce a base-class hierarchy upstream doesn't have. If maintainers prefer a shared mixin I'm happy to refactor.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a solid rationale — the *args vs list[Tensor] distinction is a real TorchScript constraint, and keeping the image-only wrappers structurally parallel to the upstream pair makes the intent immediately obvious to anyone familiar with rsl_rl. No objection to leaving it as-is. Thanks for the context!

Tip: You can customize Greptile's behavior for this repo with .greptile/rules.md and .greptile/config.json.

161 changes: 161 additions & 0 deletions source/isaaclab_rl/test/test_rsl_rl_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

"""Tests for the RSL-RL neural models customized for Isaac Lab.

These tests run on CPU and do not require Isaac Sim. They cover the image-only observation
support of :class:`~isaaclab_rl.rsl_rl.models.CNNModel` for training (``get_latent``) and for
deployment (``as_jit`` / ``as_onnx`` export wrappers), as well as parity of the exported
models with the original model.
"""

import io

import pytest
import torch

pytest.importorskip("rsl_rl", reason="rsl-rl-lib is not installed")
pytest.importorskip("tensordict", reason="tensordict is not installed")

from tensordict import TensorDict # noqa: E402

from isaaclab_rl.rsl_rl.models import CNNModel # noqa: E402

# CNN encoder settings; a copy is passed as ``cnn_cfg`` to each CNNModel built by the helpers in this file.
CNN_CFG = {"output_channels": [16, 32], "kernel_size": [8, 4], "stride": [4, 2], "activation": "relu"}
# Distribution settings; a copy is passed as ``distribution_cfg`` to each CNNModel built by the helpers in this file.
DIST_CFG = {"class_name": "GaussianDistribution", "init_std": 1.0}


def _make_image_only_model() -> tuple[CNNModel, TensorDict]:
    """Build an eval-mode CNN model whose only observation group is an image.

    The single ``policy`` group holds a ``(1, 3, 64, 64)`` tensor (e.g. a cartpole
    camera task) and is used for both the actor and the critic observation sets.

    Returns:
        The model and the observation it was constructed from.
    """
    sample_obs = TensorDict({"policy": torch.rand(1, 3, 64, 64)}, batch_size=[1])
    group_map = {"actor": ["policy"], "critic": ["policy"]}
    # Fresh dict copies keep the module-level configs untouched by the model.
    model_kwargs = {
        "hidden_dims": [64],
        "activation": "elu",
        "obs_normalization": False,
        "distribution_cfg": dict(DIST_CFG),
        "cnn_cfg": dict(CNN_CFG),
    }
    policy = CNNModel(sample_obs, group_map, "actor", 2, **model_kwargs)
    policy.eval()
    return policy, sample_obs


def _make_mixed_model() -> tuple[CNNModel, TensorDict]:
    """Build an eval-mode CNN model with one proprioceptive and one image group.

    The ``proprio`` group holds a ``(1, 12)`` tensor and the ``camera`` group a
    ``(1, 3, 64, 64)`` tensor; both are used for the actor and the critic.

    Returns:
        The model and the observation it was constructed from.
    """
    sample_obs = TensorDict(
        {"proprio": torch.randn(1, 12), "camera": torch.rand(1, 3, 64, 64)},
        batch_size=[1],
    )
    group_map = {"actor": ["proprio", "camera"], "critic": ["proprio", "camera"]}
    # Fresh dict copies keep the module-level configs untouched by the model.
    model_kwargs = {
        "hidden_dims": [64],
        "activation": "elu",
        "obs_normalization": False,
        "distribution_cfg": dict(DIST_CFG),
        "cnn_cfg": dict(CNN_CFG),
    }
    policy = CNNModel(sample_obs, group_map, "actor", 2, **model_kwargs)
    policy.eval()
    return policy, sample_obs


def _script_roundtrip(jit_model: torch.nn.Module) -> torch.jit.ScriptModule:
    """Script a module, serialize it to memory and load it back in eval mode.

    This mirrors what the runner export followed by a deployment-side load does.

    Args:
        jit_model: The module to script.

    Returns:
        The reloaded scripted module, switched to eval mode.
    """
    stream = io.BytesIO()
    torch.jit.save(torch.jit.script(jit_model), stream)
    # Rewind so that loading starts at the beginning of the serialized data.
    stream.seek(0)
    restored = torch.jit.load(stream)
    restored.eval()
    return restored


def test_cnn_model_image_only_forward():
    """An image-only model runs a forward pass and yields one 2-dim output row."""
    policy, sample = _make_image_only_model()
    with torch.inference_mode():
        actions = policy(sample)
    assert tuple(actions.shape) == (1, 2)


def test_cnn_model_image_only_jit_export_takes_only_images():
    """The scripted image-only export is called with just the list of 2D observations."""
    policy, sample = _make_image_only_model()
    exported = _script_roundtrip(policy.as_jit())
    with torch.inference_mode():
        # Only the image tensor is passed: no placeholder 1D observation.
        actual = exported([sample["policy"]])
        expected = policy(sample)
    torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-6)


def test_cnn_model_image_only_onnx_export_takes_only_images(tmp_path):
    """The ONNX graph of an image-only model lists only the 2D observations as inputs."""
    onnx = pytest.importorskip("onnx", reason="onnx is not installed")

    policy, sample = _make_image_only_model()
    exportable = policy.as_onnx(verbose=False)
    exportable.eval()
    assert exportable.input_names == ["policy"]

    export_path = str(tmp_path / "policy.onnx")
    torch.onnx.export(
        exportable,
        exportable.get_dummy_inputs(),
        export_path,
        export_params=True,
        opset_version=18,
        input_names=exportable.input_names,
        output_names=exportable.output_names,
    )
    # The serialized graph must declare exactly one input: the image group.
    declared_inputs = [entry.name for entry in onnx.load(export_path).graph.input]
    assert declared_inputs == ["policy"]

    # Numerical parity is only checked when onnxruntime is available.
    ort = pytest.importorskip("onnxruntime", reason="onnxruntime is not installed")
    session = ort.InferenceSession(export_path)
    feeds = {"policy": sample["policy"].numpy()}
    actual = torch.from_numpy(session.run(None, feeds)[0])
    with torch.inference_mode():
        expected = policy(sample)
    torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-5)


def test_cnn_model_mixed_obs_jit_export_parity():
    """A model with 1D and 2D groups still exports with the upstream call signature."""
    policy, sample = _make_mixed_model()
    exported = _script_roundtrip(policy.as_jit())
    with torch.inference_mode():
        # Upstream interface: the 1D observation first, then the list of images.
        actual = exported(sample["proprio"], [sample["camera"]])
        expected = policy(sample)
    torch.testing.assert_close(actual, expected, rtol=1e-5, atol=1e-6)


def test_cnn_model_mixed_obs_onnx_export_parity(tmp_path):
    """A model with 1D and 2D groups still exports with the upstream ONNX inputs."""
    pytest.importorskip("onnx", reason="onnx is not installed")
    ort = pytest.importorskip("onnxruntime", reason="onnxruntime is not installed")

    policy, sample = _make_mixed_model()
    exportable = policy.as_onnx(verbose=False)
    exportable.eval()
    assert exportable.input_names == ["obs", "camera"]

    export_path = str(tmp_path / "policy.onnx")
    torch.onnx.export(
        exportable,
        exportable.get_dummy_inputs(),
        export_path,
        export_params=True,
        opset_version=18,
        input_names=exportable.input_names,
        output_names=exportable.output_names,
    )
    session = ort.InferenceSession(export_path)
    # The 1D observation is fed under the "obs" input name, the image under "camera".
    feeds = {"obs": sample["proprio"].numpy(), "camera": sample["camera"].numpy()}
    actual = torch.from_numpy(session.run(None, feeds)[0])
    with torch.inference_mode():
        expected = policy(sample)
    torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-5)