Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import isaaclab_ovphysx.tensor_types as TT
from isaaclab_ovphysx.physics import OvPhysxManager as SimulationManager
from isaaclab_ovphysx.sim.views.ovphysx_view import OvPhysxView

from .frame_transformer_data import FrameTransformerData
from .kernels import frame_transformer_update_kernel, gather_body_pose_kernel
Expand Down Expand Up @@ -61,7 +62,7 @@ def __init__(self, cfg: FrameTransformerCfg):
super().__init__(cfg)
self._data: FrameTransformerData = FrameTransformerData()
self._physx_instance: Any = None
self._body_bindings: list[Any] = []
self._body_views: list[OvPhysxView] = []
self._body_read_bufs: list[wp.array] = []
self._body_dst_flat_indices: list[wp.array] = []
self._raw_transforms: wp.array | None = None
Expand Down Expand Up @@ -228,28 +229,30 @@ def has_rigid_body_api(prim) -> bool:
tracked_prim_paths = [body_names_to_frames[body_name]["prim_path"] for body_name in body_names_to_frames.keys()]
tracked_body_names = [body_name for body_name in body_names_to_frames.keys()]

# --- OVPhysX: create one TT.RIGID_BODY_POSE binding per unique tracked body ---
# --- OVPhysX: create one TT.RIGID_BODY_POSE view per unique tracked body ---
physx_instance = SimulationManager.get_physx_instance()
if physx_instance is None:
raise RuntimeError(
"OvPhysxManager has not been initialized yet."
" Reset the simulation context before adding the FrameTransformer."
)
self._physx_instance = physx_instance
self._body_bindings = []
self._body_read_bufs = [] # one (num_envs, 7) float32 buffer per body
self._body_views = []
self._body_read_bufs = [] # one (num_envs,) wp.transformf buffer per body
self._body_dst_flat_indices = [] # (num_envs,) int32 destination slots per body

num_unique_bodies = len(tracked_body_names)

for body_slot, tracked_path in enumerate(tracked_prim_paths):
pattern = self._env_wildcardify(tracked_path)
binding = physx_instance.create_tensor_binding(pattern=pattern, tensor_type=TT.RIGID_BODY_POSE)
if binding.count == 0:
view = OvPhysxView(physx_instance, pattern=pattern, device=self._device)
try:
binding = view.binding_for(TT.RIGID_BODY_POSE)
except OvPhysxView.AttributeUnavailable as exc:
raise RuntimeError(
f"FrameTransformer: TT.RIGID_BODY_POSE binding for pattern {pattern!r} matched zero bodies."
" Verify the prim has UsdPhysics.RigidBodyAPI."
)
) from exc
Comment on lines 252 to +255

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 The error message hard-codes "matched zero bodies" but OvPhysxView.AttributeUnavailable is also raised when create_tensor_binding itself throws (e.g. ABI / OOM / init failure), so the message misleads debugging in those cases. The chained exception exposes the real cause, but the primary message is incorrect. Consider making the message more general.

Suggested change
raise RuntimeError(
f"FrameTransformer: TT.RIGID_BODY_POSE binding for pattern {pattern!r} matched zero bodies."
" Verify the prim has UsdPhysics.RigidBodyAPI."
)
) from exc
raise RuntimeError(
f"FrameTransformer: TT.RIGID_BODY_POSE binding for pattern {pattern!r} is not available"
" (no matching prims or binding creation failed)."
" Verify the prim has UsdPhysics.RigidBodyAPI."
) from exc

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!


if binding.count != self._num_envs:
# OVPhysX's InteractiveScene defaults to clone_usd=True on develop, so this branch is
Expand All @@ -270,13 +273,13 @@ def has_rigid_body_api(prim) -> bool:
self._timestamp = wp.zeros(self._num_envs, dtype=wp.float32, device=self._device)
self._timestamp_last_update = wp.zeros_like(self._timestamp)

read_buf = wp.zeros((self._num_envs, 7), dtype=wp.float32, device=self._device)
read_buf = wp.zeros(self._num_envs, dtype=wp.transformf, device=self._device)
dst_torch = torch.tensor(
[env_id * num_unique_bodies + body_slot for env_id in range(self._num_envs)],
dtype=torch.int32,
device=self._device,
)
self._body_bindings.append(binding)
self._body_views.append(view)
self._body_read_bufs.append(read_buf)
self._body_dst_flat_indices.append(wp.from_torch(dst_torch.contiguous(), dtype=wp.int32))

Expand Down Expand Up @@ -423,16 +426,15 @@ def _update_buffers_impl(self, env_mask: wp.array | None = None) -> None:
" Access sensor data only after sim.reset() has been called."
)

# Step 1: refresh each per-body RIGID_BODY_POSE binding and gather rows into _raw_transforms.
for binding, read_buf, dst_indices in zip(
self._body_bindings, self._body_read_bufs, self._body_dst_flat_indices
):
binding.read(read_buf)
pose_buf_tf = read_buf.view(wp.transformf) # (num_envs, 7) float32 -> (num_envs,) transformf
# Step 1: refresh each per-body RIGID_BODY_POSE view and gather rows into _raw_transforms.
# read_into fills the structured (num_envs,) wp.transformf buffer directly via a cached
# float32 reinterpret, so no manual .view(wp.transformf) is needed here.
for view, read_buf, dst_indices in zip(self._body_views, self._body_read_bufs, self._body_dst_flat_indices):
view.read_into(TT.RIGID_BODY_POSE, read_buf)
wp.launch(
gather_body_pose_kernel,
dim=self._num_envs,
inputs=[env_mask, pose_buf_tf, dst_indices, self._raw_transforms],
inputs=[env_mask, read_buf, dst_indices, self._raw_transforms],
device=self._device,
)

Expand Down Expand Up @@ -519,7 +521,7 @@ def _invalidate_initialize_callback(self, event) -> None:
super()._invalidate_initialize_callback(event)
self._body_read_bufs = []
self._body_dst_flat_indices = []
self._body_bindings = []
self._body_views = []
self._physx_instance = None
self._raw_transforms = None
self._source_raw_indices = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,10 @@ def gather_body_pose_kernel(
For each env in the launch, copies ``pose_buffer[env]`` into
``raw_transforms[dst_flat_indices[env]]``. Skips envs whose ``env_mask`` is False.

The pose buffer is a view (``wp.array.view(wp.transformf)``) over a
``(num_envs, 7)`` ``float32`` array populated by
``binding.read(...)`` for a single ``RIGID_BODY_POSE`` tensor binding,
so it has shape ``(num_envs,)``. One launch per tracked body fills the
body's slot column in the flat ``raw_transforms`` buffer.
The pose buffer is a ``(num_envs,)`` ``wp.transformf`` array filled in place
from a single ``RIGID_BODY_POSE`` view via ``OvPhysxView.read_into``. One
launch per tracked body fills the body's slot column in the flat
``raw_transforms`` buffer.

Args:
env_mask: Active environment mask, shape ``(num_envs,)``.
Expand Down
Loading