From 36c942a015dddcccd3be986b5d494333a416ae90 Mon Sep 17 00:00:00 2001 From: kavya-chennoju Date: Mon, 8 Jun 2026 08:29:12 -0700 Subject: [PATCH 1/6] feat: add Device Connect integration (on top of main) Re-applies the Device Connect integration from strands-labs/robots#52 onto main's architecture (mesh/ subsystem, robot.py Robot factory), rather than merging the divergent dev branch wholesale. What lands: - strands_robots/device_connect/: DeviceDriver adapters wrapping a Simulation / HardwareRobot (sim_driver, robot_driver, reachy_*), init_device_connect[_sync], GUIDE.md, setup.sh, smoke test. The drivers target main's Simulation API unchanged (start_policy, get_features, get_state, step, reset, _world/SimRobot fields). - robot.py: Robot("...").run() server hook attached by the factory. run() stops the auto-started mesh and brings the device online via Device Connect (the primary networking layer in server mode). - tools/robot_mesh.py: Device Connect dispatch layered UNDERNEATH main's existing safety gates. robot_mesh() still runs its rate limit, HITL approval (emergency_stop/broadcast), command validation, and audit first; only then does it try Device Connect, falling back to the built-in mesh when DC is unavailable or has no devices. DC dispatch re-applies the per-action validation/audit so it inherits the same safety. Gated by STRANDS_ROBOT_MESH_DC (default on; conftest turns it off so mesh unit tests stay hermetic and deterministic). - __init__.py / pyproject.toml: lazy DC exports + device-connect-edge / device-connect-agent-tools core deps. Tests: - Ports test_device_connect_{drivers,all_robots,integration}.py. Their stubs prefer the real strands / device_connect_agent_tools packages (installed here) instead of leaving leaking sys.modules mocks, and no longer delete strands_robots.tools.robot_mesh (which created a second module object and broke sibling tests' _resolve_mesh patches). - conftest.py disables the DC dispatch path during the suite. 
Behavior change documented in GUIDE.md: emergency_stop / broadcast are HITL-gated and fail closed when called as a bare script (no operator to approve); they run from within a Strands agent loop on approval. Verified locally: 1236 passed / 3 skipped (DC tests + main's robot_mesh safety, deep-mesh, and factory tests together); GUIDE D2D demo (peers / tell / emergency_stop) works end-to-end against device-connect PR #52 packages. Local working branch for review; not pushed. --- pyproject.toml | 3 + strands_robots/__init__.py | 14 + strands_robots/device_connect/GUIDE.md | 410 +++++++++ strands_robots/device_connect/__init__.py | 194 ++++ .../device_connect/reachy_mini_driver.py | 321 +++++++ .../device_connect/reachy_transport.py | 163 ++++ strands_robots/device_connect/robot_driver.py | 200 ++++ strands_robots/device_connect/setup.sh | 53 ++ strands_robots/device_connect/sim_driver.py | 230 +++++ .../device_connect/test_control_loop_dc.sh | 197 ++++ strands_robots/robot.py | 55 ++ strands_robots/tools/robot_mesh.py | 200 ++++ tests/conftest.py | 6 + tests/test_device_connect_all_robots.py | 722 +++++++++++++++ tests/test_device_connect_drivers.py | 857 ++++++++++++++++++ tests/test_device_connect_integration.py | 368 ++++++++ 16 files changed, 3993 insertions(+) create mode 100644 strands_robots/device_connect/GUIDE.md create mode 100644 strands_robots/device_connect/__init__.py create mode 100644 strands_robots/device_connect/reachy_mini_driver.py create mode 100644 strands_robots/device_connect/reachy_transport.py create mode 100644 strands_robots/device_connect/robot_driver.py create mode 100755 strands_robots/device_connect/setup.sh create mode 100644 strands_robots/device_connect/sim_driver.py create mode 100755 strands_robots/device_connect/test_control_loop_dc.sh create mode 100644 tests/test_device_connect_all_robots.py create mode 100644 tests/test_device_connect_drivers.py create mode 100644 tests/test_device_connect_integration.py diff --git 
a/pyproject.toml b/pyproject.toml index d16c8a05..4a8619e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,9 @@ dependencies = [ "numpy>=1.21.0,<3.0.0", "opencv-python-headless>=4.5.0,<5.0.0", "Pillow>=8.0.0,<12.0.0", + # Device Connect — primary networking layer (discovery, RPC, events, safety) + "device-connect-edge>=0.2.0", + "device-connect-agent-tools>=0.1.0", ] [project.optional-dependencies] diff --git a/strands_robots/__init__.py b/strands_robots/__init__.py index 417bd7a4..1027d77f 100644 --- a/strands_robots/__init__.py +++ b/strands_robots/__init__.py @@ -81,6 +81,14 @@ "lerobot_teleoperate": ("strands_robots.tools.lerobot_teleoperate", "lerobot_teleoperate"), "pose_tool": ("strands_robots.tools.pose_tool", "pose_tool"), "serial_tool": ("strands_robots.tools.serial_tool", "serial_tool"), + # Robot mesh coordination tool (Device Connect dispatch + mesh fallback) + "robot_mesh": ("strands_robots.tools.robot_mesh", "robot_mesh"), + # Device Connect integration — wraps robots as Device Connect devices + "init_device_connect": ("strands_robots.device_connect", "init_device_connect"), + "init_device_connect_sync": ("strands_robots.device_connect", "init_device_connect_sync"), + "RobotDeviceDriver": ("strands_robots.device_connect", "RobotDeviceDriver"), + "SimulationDeviceDriver": ("strands_robots.device_connect", "SimulationDeviceDriver"), + "ReachyMiniDriver": ("strands_robots.device_connect", "ReachyMiniDriver"), } __all__ = [ @@ -106,6 +114,12 @@ "lerobot_calibrate", "serial_tool", "pose_tool", + "robot_mesh", + "init_device_connect", + "init_device_connect_sync", + "RobotDeviceDriver", + "SimulationDeviceDriver", + "ReachyMiniDriver", ] diff --git a/strands_robots/device_connect/GUIDE.md b/strands_robots/device_connect/GUIDE.md new file mode 100644 index 00000000..b09d0265 --- /dev/null +++ b/strands_robots/device_connect/GUIDE.md @@ -0,0 +1,410 @@ +# Device Connect Integration + +Strands Robots uses [Device 
Connect](https://github.com/arm/device-connect) — a **device-aware runtime** by Arm — to handle discovery, presence, structured RPC, event routing, and safety, so you can focus on building cross-device experiences instead of re-implementing infrastructure. + +> **Fallback behavior:** If `device-connect-edge` is not installed, Strands Robots automatically falls back to a built-in Zenoh P2P mesh (`zenoh_mesh.py`) for basic peer discovery and coordination. Device Connect is the recommended and primary networking layer. + +### Quick Start + +```python +from strands_robots import Robot + +r = Robot("so100") +r.run() # starts listening for commands. Ctrl+C to stop. +``` + +`Robot()` creates the robot. `.run()` starts Device Connect with D2D defaults (Zenoh multicast scouting, no broker, no env vars) and blocks — the robot becomes discoverable on the LAN and listens for commands. Without `.run()`, the script exits and the robot is removed from the network. + +You can optionally pass `peer_id="so100-lab-1"` for a stable address; otherwise one is auto-generated (e.g. `so100-a3f1b2`). + +**Robot lifecycle:** + +| Pattern | Behavior | +|---|---| +| `r = Robot("so100"); r.run()` | **Option A — Foreground server.** Process stays alive, listens for commands. Ctrl+C to stop. | +| `r = Robot("so100")` | **Option B — Agent-controlled.** A Strands Agent discovers the robot via `robot_mesh` or `discover_devices()` and invokes commands remotely. 
| + +From another process, discover and invoke: + +```python +from strands_robots.tools.robot_mesh import robot_mesh + +robot_mesh(action="peers") # discover devices +robot_mesh(action="tell", target="so100-lab-1", # invoke + instruction="pick up the cube") +robot_mesh(action="emergency_stop") # e-stop all (HITL-approved) +``` + +### Architecture + +```mermaid +graph TD + subgraph "Device Connect Infrastructure" + ZENOH_R["Zenoh Router"] + ETCD["etcd (Registry)"] + REG["Registry Service"] + end + + subgraph "Robot Process" + ROBOT["Robot('so100')"] + ADAPTER["RobotDeviceDriver"] + RUNTIME["DeviceRuntime"] + ROBOT --> ADAPTER + ADAPTER --> RUNTIME + RUNTIME --> ZENOH_R + end + + subgraph "Agent Process" + AGENT["Strands Agent"] + TOOLS["discover_devices + invoke_device"] + AGENT --> TOOLS + TOOLS --> ZENOH_R + end + + ZENOH_R --> REG + REG --> ETCD +``` + +### E2E Demo + +No Docker needed. No env vars. Devices discover each other directly on the LAN via Zenoh multicast scouting. `Robot()` and `robot_mesh()` auto-configure D2D mode when no broker URL is set. + +#### Setup + +##### 1. Install + +> `setup.sh` installs `uv`, Python 3.12, creates a venv, and installs all dependencies. + +```bash +git clone --branch feat/device-connect-on-main https://github.com/kavya-chennoju/robots.git +cd robots +./strands_robots/device_connect/setup.sh +source .venv/bin/activate +``` + +##### 2. Start a robot + +```bash +screen -S robot # start a persistent session +python -c " +from strands_robots import Robot +r = Robot('so100') +r.run() +" +# once online, Ctrl+a then d to detach +# screen -ls to list sessions +# screen -r robot to reattach +``` + +Expected output: + +``` +device_connect_edge.device.<device_id> - INFO - Using ZENOH messaging backend +device_connect_edge.device.<device_id> - INFO - Connected to ZENOH broker: [] +device_connect_edge.device.<device_id> - INFO - Driver connected: strands_sim +device_connect_edge.device.<device_id> - INFO - Subscribed to commands on device-connect.default.<device_id>.cmd +🤖 <device_id> is online. 
Ctrl+C to stop. +``` + +> `<device_id>` is auto-generated (e.g. `so100-a3f1b2`) unless you pass a fixed peer ID: +> ```python +> r = Robot('so100', peer_id='so100-lab-1') +> ``` + +#### Option A: Using the `robot_mesh` Strands tool + +The `robot_mesh` tool auto-detects Device Connect and uses it when available, falling back to the plain Zenoh mesh otherwise. + +**Discover peers:** + +```python +python -c " +from strands_robots.tools.robot_mesh import robot_mesh +print(robot_mesh(action='peers')) +" +``` + +Expected output: + +``` +Discovered 1 device(s): + <device_id> [sim] — idle + Functions: execute, getFeatures, getStatus, reset, step, stop +``` + +**Tell a robot to execute an instruction:** + +Use the `<device_id>` from the discover step above as the `target`: + +```python +python -c " +from strands_robots.tools.robot_mesh import robot_mesh +print(robot_mesh(action='tell', target='<device_id>', + instruction='pick up the cube', policy_provider='mock')) +" +``` + +Expected output: + +``` +-> <device_id>: pick up the cube + {"status": "success", "content": [...]} +``` + +**Emergency stop all devices:** + +> **Safety — human-in-the-loop.** `emergency_stop` and `broadcast` are +> fleet-wide actions gated behind an operator approval interrupt +> (`tool_context.interrupt`), plus a per-action rate limit and an audit log. +> They run only when invoked from inside a Strands agent loop, where the +> operator approves the action out-of-band of the LLM's tool arguments (so +> prompt-injection cannot smuggle approval). Device Connect dispatch is layered +> *under* this gate, so DC inherits the same safety. 
Called as a bare script — +> with no operator to ask — `emergency_stop` **fails closed**: + +```python +python -c " +from strands_robots.tools.robot_mesh import robot_mesh +print(robot_mesh(action='emergency_stop')) +" +``` + +Expected output (no operator present → fails closed): + +``` +{'status': 'error', 'content': [{'text': "action 'emergency_stop' requires a +human-in-the-loop interrupt, but no tool_context is available in this calling +context."}]} +``` + +From inside a Strands agent the operator is prompted to approve; on approval the +e-stop fans out to every Device Connect device: + +``` +E-STOP: 1/1 devices stopped +``` + +#### Option B: Discover and invoke with `device-connect-agent-tools` directly + +```python +python -c " +from device_connect_agent_tools import connect, discover_devices, invoke_device + +connect() + +devices = discover_devices() +print(f'Found {len(devices)} robot(s):') +for d in devices: + print(f' {d[\"device_id\"]} — {d.get(\"status\", {}).get(\"availability\", \"?\")}') + +if devices: + result = invoke_device( + devices[0]['device_id'], 'execute', + {'instruction': 'pick up the cube', 'policy_provider': 'mock'}, + ) + print(f'Execute result: {result}') + + status = invoke_device(devices[0]['device_id'], 'getStatus') + print(f'Status: {status}') +" +``` + +Expected output: + +``` +Found 1 robot(s): + <device_id> — idle +Execute result: {'success': True, 'result': {'status': 'success', 'content': [...]}} +Status: {'success': True, 'result': {...}} # full sim state dict +``` + +#### Full Infrastructure (Optional) + +For production deployments, you can add Docker infrastructure for persistent registry, distributed state, cross-network routing, and authentication. + +Start the Device Connect infrastructure (Zenoh router + etcd + device registry): + +```bash +git clone --depth 1 https://github.com/arm/device-connect.git +cd device-connect/packages/device-connect-server +docker compose -f infra/docker-compose-dev.yml up -d +cd ../../.. 
+``` + +This starts: + +| Service | Port | Purpose | +|---|---|---| +| Zenoh router | `:7447` | Messaging (RPC, events, heartbeats) | +| etcd | `:2379` | Device registry storage | +| Device registry | `:8080` | REST API for device metadata | + +Set environment variables (all terminals): + +```bash +export MESSAGING_BACKEND=zenoh +export ZENOH_CONNECT=tcp/localhost:7447 +export DEVICE_CONNECT_ALLOW_INSECURE=true +``` + +All the options above (A–B) work identically with full infrastructure — the only difference is that devices register in etcd and discovery goes through the registry service instead of multicast scouting. + +> **What infrastructure adds over D2D:** +> - **Persistent device registry** — devices register with TTL-based leases; stale devices are auto-cleaned. Agents can discover devices by type, location, or capability via `discover_devices()`. +> - **Distributed state & locks** — etcd-backed key-value store with atomic distributed locks for coordinating shared resources (e.g., preventing two agents from using the same robotic arm simultaneously). +> - **Cross-network routing** — the Zenoh router (or NATS broker) enables communication across subnets and sites, not just the local LAN. +> - **Authentication & authorization** — mTLS ensures only devices with certificates signed by the trusted CA can exchange data. Full authorization (per-device permissions, topic-level ACLs, certificate revocation) requires the router/registry infrastructure. 
+ +#### Running the Tests + +```bash +pip install pytest pytest-cov # if not already installed + +# Unit tests (no Docker needed) +python3 -m pytest tests/test_device_connect_drivers.py -v + +# Integration tests (requires Docker infrastructure) +MESSAGING_BACKEND=zenoh ZENOH_CONNECT=tcp/localhost:7447 \ + DEVICE_CONNECT_ALLOW_INSECURE=true \ + python3 -m pytest tests/test_device_connect_integration.py -v +``` + +#### Control Loop Smoke Test + +A self-contained script runs a 200-step mock-policy control loop while a Zenoh listener captures Device Connect events (stateUpdate, observationUpdate, presence, heartbeat) and asserts minimum thresholds: + +```bash +bash strands_robots/device_connect/test_control_loop_dc.sh +``` + +It installs dependencies, starts a Zenoh event listener, runs `Robot("so100")` with a mock policy for 200 steps, then validates that the expected events were published over Device Connect. + +--- + +## Reachy Mini (Zenoh-Native Devices) + +Reachy Mini has built-in Zenoh support — it publishes joint positions, head pose, and IMU data natively over Zenoh topics. This makes it a special case: it can be bridged directly via `subscribe()` or wrapped as a Device Connect device for structured RPC. + +### Bridging via Subscribe + +Use the mesh's `subscribe()` to read Reachy's native Zenoh topics directly: + +```python +sim = Robot("so100") + +# Subscribe to Reachy's head pose +sim.mesh.subscribe("reachy_mini/head_pose", + lambda topic, data: print(f"Reachy looking at: {data}")) + +# Subscribe to Reachy's joint positions +sim.mesh.subscribe("reachy_mini/joint_positions", name="reachy_joints") + +# Mirror Reachy's movements in simulation +def mirror_reachy(topic, data): + joints = data.get("head_joint_positions", []) + if joints: + # Map Reachy joints to sim joints... 
+ pass + +sim.mesh.subscribe("reachy_mini/joint_positions", mirror_reachy) +``` + +### Architecture + +```mermaid +graph TD + subgraph "Reachy Mini Process" + REACHY["ReachyMiniDriver"] + RRUNTIME["DeviceRuntime"] + ZENOH_HW["Zenoh → Reachy HW"] + REACHY --> RRUNTIME + REACHY --> ZENOH_HW + end + + subgraph "Network" + ZENOH["Zenoh Mesh<br/>(multicast or router)"] + end + + subgraph "Agent Process" + AGENT["Strands Agent"] + TOOLS["discover_devices + invoke_device"] + AGENT --> TOOLS + TOOLS --> ZENOH + end + + RRUNTIME --> ZENOH +``` + +### As a Device Connect Device + +Wrap Reachy Mini with `ReachyMiniDriver` to expose it as a structured Device Connect device with RPC commands (`look`, `nod`, etc.): + +```python +from strands_robots.device_connect import ReachyMiniDriver +from device_connect_edge import DeviceRuntime + +driver = ReachyMiniDriver(host="reachy-mini.local") +runtime = DeviceRuntime( + driver=driver, + device_id="reachy-mini-1", + messaging_urls=["tcp/localhost:7447"], + allow_insecure=True, +) +await runtime.run() + +# Now any agent can discover and control it: +invoke_device("reachy-mini-1", "look", {"pitch": -15, "yaw": 30}) +invoke_device("reachy-mini-1", "nod") +``` + +### E2E Demo + +> Requires a Reachy Mini robot. + +**Setup depends on your hardware variant:** + +| Variant | Connection | Setup | +|---|---|---| +| **Reachy Mini** (wireless) | Wi-Fi, onboard Pi | `host='reachy-mini.local'` — no extra install needed | +| **Reachy Mini Lite** (USB) | USB, no Pi | `pip install reachy-mini` then run `reachy-mini` daemon locally. 
Use `host='localhost'` | + +**Start the Reachy Mini driver:** + +```python +python -c " +import asyncio +from strands_robots.device_connect import ReachyMiniDriver +from device_connect_edge import DeviceRuntime + +# For Lite (USB): host='localhost' (requires reachy-mini daemon running) +# For Wireless: host='reachy-mini.local' +driver = ReachyMiniDriver(host='reachy-mini.local') +runtime = DeviceRuntime( + driver=driver, + device_id='reachy-mini-1', + messaging_urls=['tcp/localhost:7447'], + allow_insecure=True, +) + +asyncio.run(runtime.run()) +" +``` + +Expected output: + +``` +Reachy Mini driver connected: reachy-mini.local +device_connect_edge.device.reachy-mini-1 - INFO - Device registered +device_connect_edge.device.reachy-mini-1 - INFO - Subscribed to commands on device-connect.default.reachy-mini-1.cmd +``` + +**In another terminal, invoke RPCs:** + +```python +python -c " +from device_connect_agent_tools import connect, invoke_device +connect() +print(invoke_device('reachy-mini-1', 'look', {'pitch': -15, 'yaw': 30})) +print(invoke_device('reachy-mini-1', 'nod')) +" +``` diff --git a/strands_robots/device_connect/__init__.py b/strands_robots/device_connect/__init__.py new file mode 100644 index 00000000..e9f32004 --- /dev/null +++ b/strands_robots/device_connect/__init__.py @@ -0,0 +1,194 @@ +"""Device Connect integration for strands-robots. + +Provides DeviceDriver adapters that wrap Robot and Simulation instances, +exposing them to Device Connect's device registry, RPC routing, and event system. 
+ +Usage: + from strands_robots.device_connect import init_device_connect + + robot = Robot("so100") + runtime = await init_device_connect(robot, peer_id="so100-lab-1") + + # Now discoverable via Device Connect tools: + # discover_devices(device_type="strands_robot") + # invoke_device("so100-lab-1", "execute", {"instruction": "pick up cube"}) +""" + +import asyncio +import logging +import os +import threading +import uuid +from typing import Optional + +from device_connect_edge import DeviceRuntime + +from strands_robots.device_connect.reachy_mini_driver import ReachyMiniDriver +from strands_robots.device_connect.robot_driver import RobotDeviceDriver +from strands_robots.device_connect.sim_driver import SimulationDeviceDriver + +logger = logging.getLogger(__name__) + +__all__ = [ + "init_device_connect", + "init_device_connect_sync", + "RobotDeviceDriver", + "SimulationDeviceDriver", + "ReachyMiniDriver", +] + + +async def init_device_connect( + robot, + peer_id: Optional[str] = None, + peer_type: str = "robot", + messaging_url: Optional[str] = None, + messaging_backend: Optional[str] = None, + tenant: str = "default", + allow_insecure: Optional[bool] = None, +) -> DeviceRuntime: + """Initialize Device Connect for a Robot or Simulation. + + Drop-in replacement for init_mesh(). Creates a DeviceDriver adapter + and starts a DeviceRuntime in the background. + + When messaging_backend="zenoh" and messaging_url is None, the runtime + enters D2D mode — devices discover each other directly via Zenoh + multicast scouting on the LAN. No broker, no Docker, no env vars. + + Args: + robot: A Robot or Simulation instance to wrap. + peer_id: Device ID for registration (auto-generated if None). + peer_type: "robot" or "sim" — selects the appropriate driver. + messaging_url: Explicit messaging URL (overrides env vars). + messaging_backend: Messaging backend — "zenoh" or "nats". + None = auto-detect from MESSAGING_BACKEND env var (default "zenoh"). 
+ tenant: Device Connect tenant namespace. + allow_insecure: Allow insecure connections. None = auto-detect: + respects DEVICE_CONNECT_ALLOW_INSECURE env var if set, + otherwise defaults to True in D2D mode (no broker URL). + + Returns: + The running DeviceRuntime instance. + """ + if peer_type == "sim": + driver = SimulationDeviceDriver(robot) + else: + driver = RobotDeviceDriver(robot) + + device_id = peer_id or f"{getattr(robot, 'tool_name_str', 'robot')}-{uuid.uuid4().hex[:4]}" + + urls = [messaging_url] if messaging_url else None + + # Resolve messaging_backend: explicit arg > env var > default "zenoh" + if messaging_backend is None: + messaging_backend = os.environ.get("MESSAGING_BACKEND", "zenoh") + + # Resolve allow_insecure: explicit arg > env var > D2D default + if allow_insecure is None: + env_val = os.environ.get("DEVICE_CONNECT_ALLOW_INSECURE") + if env_val is not None: + allow_insecure = env_val.lower() in ("true", "1", "yes") + elif urls is None: + # D2D mode — no broker, default insecure for dev convenience + allow_insecure = True + + runtime = DeviceRuntime( + driver=driver, + device_id=device_id, + messaging_urls=urls, + messaging_backend=messaging_backend, + tenant=tenant, + allow_insecure=allow_insecure, + ) + + # Provide robot-specific heartbeat data + runtime.set_heartbeat_provider(lambda: _build_heartbeat(robot, peer_type)) + + # Start runtime in background task; store ref to prevent GC + runtime._background_task = asyncio.create_task(runtime.run()) + + logger.info("Device Connect initialized: %s (%s, backend=%s, d2d=%s)", + device_id, peer_type, messaging_backend, urls is None) + return runtime + + +def init_device_connect_sync( + robot, + peer_id: Optional[str] = None, + peer_type: str = "robot", + messaging_url: Optional[str] = None, + messaging_backend: Optional[str] = None, + tenant: str = "default", + allow_insecure: Optional[bool] = None, +) -> "DeviceRuntime": + """Non-blocking sync wrapper around init_device_connect(). 
+ + Starts the DeviceRuntime on a dedicated daemon thread so the caller + returns immediately — matching the Zenoh mesh ``init_mesh()`` pattern. + The runtime stays alive as long as the process (daemon thread). + + Same parameters as :func:`init_device_connect`. + """ + loop = asyncio.new_event_loop() + ready = threading.Event() + runtime_holder = [None] + error_holder = [None] + + async def _start(): + try: + rt = await init_device_connect( + robot, + peer_id=peer_id, + peer_type=peer_type, + messaging_url=messaging_url, + messaging_backend=messaging_backend, + tenant=tenant, + allow_insecure=allow_insecure, + ) + runtime_holder[0] = rt + except Exception as exc: + error_holder[0] = exc + finally: + ready.set() + + def _run(): + asyncio.set_event_loop(loop) + loop.run_until_complete(_start()) + loop.run_forever() + + thread = threading.Thread(target=_run, daemon=True, name="device-connect-runtime") + thread.start() + ready.wait(timeout=30.0) + + if error_holder[0] is not None: + raise error_holder[0] + + runtime = runtime_holder[0] + if runtime is not None: + runtime._loop = loop + runtime._thread = thread + return runtime + + +def _build_heartbeat(robot, peer_type: str) -> dict: + """Build heartbeat payload with robot-specific metadata.""" + data = { + "peer_type": peer_type, + "tool_name": getattr(robot, "tool_name_str", "unknown"), + } + + if peer_type == "robot": + task = getattr(robot, "_task_state", None) + if task: + data["task_status"] = getattr(task.status, "value", "unknown") + data["instruction"] = task.instruction or "" + data["step_count"] = task.step_count + elif peer_type == "sim": + world = getattr(robot, "_world", None) + if world: + data["sim_time"] = world.sim_time + data["step_count"] = world.step_count + data["robots"] = list(world.robots.keys()) + + return data diff --git a/strands_robots/device_connect/reachy_mini_driver.py b/strands_robots/device_connect/reachy_mini_driver.py new file mode 100644 index 00000000..d3a46398 --- /dev/null +++ 
b/strands_robots/device_connect/reachy_mini_driver.py @@ -0,0 +1,321 @@ +"""ReachyMiniDriver — Device Connect DeviceDriver for Pollen Reachy Mini robots. + +Auto-detects hardware variant via the daemon's ``wireless_version`` flag: +- **Wireless** (has onboard Pi): uses Zenoh transport for real-time I/O. +- **Lite** (USB-only, no Pi): uses WebSocket to the daemon directly. + +REST API calls go through reachy_transport.api() for daemon/move operations. +""" + +import asyncio +import json +import logging +import math +from typing import Optional + +from device_connect_edge.drivers import DeviceDriver, emit, on, rpc +from device_connect_edge.types import DeviceIdentity, DeviceStatus + +from strands_robots.device_connect.reachy_transport import ( + api, + rpy_to_pose, + identity_pose, + ZenohLink, + WebSocketLink, +) + +logger = logging.getLogger(__name__) + + +class ReachyMiniDriver(DeviceDriver): + """Device Connect driver for Pollen Reachy Mini. + + Auto-detects Wireless (Zenoh) vs Lite (WebSocket) via the daemon's + ``wireless_version`` flag. REST API calls work the same for both. 
+ """ + + device_type = "reachy_mini" + + def __init__( + self, + host: str = "reachy-mini.local", + prefix: str = "reachy_mini", + api_port: int = 8000, + ): + super().__init__() + self._host = host + self._prefix = prefix + self._api_port = api_port + self._latest_joints: Optional[dict] = None + self._latest_imu: Optional[dict] = None + self._hw = None + + @property + def identity(self) -> DeviceIdentity: + return DeviceIdentity( + device_type="reachy_mini", + manufacturer="Pollen Robotics", + model=f"Reachy Mini @ {self._host}", + description="Pollen Reachy Mini expressive robot head with antennas", + ) + + @property + def status(self) -> DeviceStatus: + return DeviceStatus(availability="idle") + + async def connect(self) -> None: + """Connect to the Reachy Mini, auto-detecting Wireless vs Lite.""" + try: + status = await asyncio.to_thread( + api, self._host, self._api_port, "/api/daemon/status" + ) + is_lite = not status.get("wireless_version", True) + except Exception: + is_lite = False + + if is_lite: + self._hw = WebSocketLink(self._host, self._api_port) + logger.info("Connected to Reachy Mini Lite at %s (WebSocket)", self._host) + else: + self._hw = ZenohLink(self.transport, self._prefix) + logger.info("Connected to Reachy Mini at %s (Zenoh)", self._host) + + await self._hw.start( + on_joints=lambda d: setattr(self, "_latest_joints", d), + on_imu=lambda d: setattr(self, "_latest_imu", d), + ) + + async def disconnect(self) -> None: + """Tear down the hardware link.""" + if self._hw: + await self._hw.stop() + + # ── Helpers ──────────────────────────────────────────────── + + async def _send_cmd(self, cmd: dict) -> None: + """Send a real-time command via the active hardware link.""" + await self._hw.send_cmd(cmd) + + # ── Movement RPCs (Zenoh via transport) ──────────────────── + + @rpc() + async def look( + self, + pitch: float = 0, + roll: float = 0, + yaw: float = 0, + x: float = 0, + y: float = 0, + z: float = 0, + ) -> dict: + """Set head pose 
instantly. + + Args: + pitch: Pitch angle in degrees + roll: Roll angle in degrees + yaw: Yaw angle in degrees + x: X offset in mm + y: Y offset in mm + z: Z offset in mm + """ + await self._send_cmd({"head_pose": rpy_to_pose(pitch, roll, yaw, x, y, z)}) + return {"status": "success", "pitch": pitch, "roll": roll, "yaw": yaw} + + @rpc() + async def antennas(self, left: float = 0, right: float = 0) -> dict: + """Set antenna angles. + + Args: + left: Left antenna angle in degrees + right: Right antenna angle in degrees + """ + await self._send_cmd( + {"antennas_joint_positions": [math.radians(left), math.radians(right)]} + ) + return {"status": "success", "left": left, "right": right} + + @rpc() + async def body(self, yaw: float = 0) -> dict: + """Set body yaw angle. + + Args: + yaw: Body yaw angle in degrees + """ + await self._send_cmd({"body_yaw": math.radians(yaw)}) + return {"status": "success", "yaw": yaw} + + # ── Sensor RPCs (cached from transport subscription) ─────── + + @rpc() + async def getJoints(self) -> dict: + """Get current joint positions (head + antennas).""" + d = self._latest_joints + if d is not None: + head = d.get("head_joint_positions", []) + ant = d.get("antennas_joint_positions", []) + return { + "status": "success", + "head": [math.degrees(j) for j in head], + "antennas": [math.degrees(j) for j in ant], + } + return {"status": "error", "reason": "no joint data"} + + @rpc() + async def getImu(self) -> dict: + """Get IMU data (accelerometer, gyroscope, quaternion, temperature).""" + d = self._latest_imu + if d is not None: + return { + "status": "success", + "accelerometer": d.get("accelerometer"), + "gyroscope": d.get("gyroscope"), + "quaternion": d.get("quaternion"), + "temperature": d.get("temperature"), + } + return {"status": "error", "reason": "no IMU data"} + + # ── Motor RPCs (Zenoh via transport) ─────────────────────── + + @rpc() + async def enableMotors(self, motor_ids: str = "") -> dict: + """Enable motors (torque on). 
+ + Args: + motor_ids: Comma-separated motor IDs (empty = all) + """ + ids = [s.strip() for s in motor_ids.split(",") if s.strip()] or None + await self._send_cmd({"torque": True, "ids": ids}) + return {"status": "success", "enabled": motor_ids or "all"} + + @rpc() + async def disableMotors(self, motor_ids: str = "") -> dict: + """Disable motors (torque off). + + Args: + motor_ids: Comma-separated motor IDs (empty = all) + """ + ids = [s.strip() for s in motor_ids.split(",") if s.strip()] or None + await self._send_cmd({"torque": False, "ids": ids}) + return {"status": "success", "disabled": motor_ids or "all"} + + # ── Move RPCs (REST) ────────────────────────────────────── + + @rpc() + async def playMove(self, move_name: str, library: str = "emotions") -> dict: + """Play a recorded move from the HuggingFace library. + + Args: + move_name: Name of the move to play + library: Move library (emotions or dance) + """ + ds = f"pollen-robotics/reachy-mini-{'emotions' if library == 'emotions' else 'dances'}-library" + result = await asyncio.to_thread( + api, self._host, self._api_port, + f"/api/move/play/recorded-move-dataset/{ds}/{move_name}", "POST", + ) + return {"status": "success", "move": move_name, "result": result} + + @rpc() + async def listMoves(self, library: str = "emotions") -> dict: + """List available recorded moves. 
+ + Args: + library: Move library (emotions or dance) + """ + ds = f"pollen-robotics/reachy-mini-{'emotions' if library == 'emotions' else 'dances'}-library" + result = await asyncio.to_thread( + api, self._host, self._api_port, + f"/api/move/recorded-move-datasets/list/{ds}", + ) + return {"status": "success", "moves": result} + + # ── Expression RPCs (Zenoh animations via transport) ─────── + + @rpc() + async def nod(self) -> dict: + """Nod the head (yes gesture).""" + for _ in range(3): + await self._send_cmd({"head_pose": rpy_to_pose(15, 0, 0)}) + await asyncio.sleep(0.25) + await self._send_cmd({"head_pose": rpy_to_pose(-10, 0, 0)}) + await asyncio.sleep(0.25) + await self._send_cmd({"head_pose": identity_pose()}) + return {"status": "success", "expression": "nod"} + + @rpc() + async def shake(self) -> dict: + """Shake the head (no gesture).""" + for _ in range(3): + await self._send_cmd({"head_pose": rpy_to_pose(0, 0, 25)}) + await asyncio.sleep(0.2) + await self._send_cmd({"head_pose": rpy_to_pose(0, 0, -25)}) + await asyncio.sleep(0.2) + await self._send_cmd({"head_pose": identity_pose()}) + return {"status": "success", "expression": "shake"} + + @rpc() + async def happy(self) -> dict: + """Happy antenna wiggle expression.""" + for _ in range(4): + await self._send_cmd( + {"antennas_joint_positions": [math.radians(60), math.radians(-60)]} + ) + await asyncio.sleep(0.2) + await self._send_cmd( + {"antennas_joint_positions": [math.radians(-60), math.radians(60)]} + ) + await asyncio.sleep(0.2) + await self._send_cmd({"antennas_joint_positions": [0, 0]}) + return {"status": "success", "expression": "happy"} + + # ── Lifecycle RPCs (REST) ───────────────────────────────── + + @rpc() + async def wakeUp(self) -> dict: + """Wake up the robot (enable motors + play wake animation).""" + result = await asyncio.to_thread( + api, self._host, self._api_port, "/api/move/play/wake_up", "POST", + ) + return {"status": "success", "result": result} + + @rpc() + async def 
sleep(self) -> dict: + """Put robot to sleep (play sleep animation + disable motors).""" + result = await asyncio.to_thread( + api, self._host, self._api_port, "/api/move/play/goto_sleep", "POST", + ) + return {"status": "success", "result": result} + + @rpc() + async def stopMotion(self) -> dict: + """Stop all current motion.""" + result = await asyncio.to_thread( + api, self._host, self._api_port, "/api/move/stop", "POST", + ) + return {"status": "success", "result": result} + + @rpc() + async def getDaemonStatus(self) -> dict: + """Get daemon status, motor state, and control frequency.""" + result = await asyncio.to_thread( + api, self._host, self._api_port, "/api/daemon/status", + ) + return {"status": "success", **result} + + # ── Events ──────────────────────────────────────────────── + + @emit() + async def emergencyStop(self, reason: str = ""): + """Emitted when this device triggers an emergency stop. + + Args: + reason: Why the emergency stop was triggered + """ + pass + + @on(event_name="emergencyStop") + async def onEmergencyStop(self, device_id: str, event_name: str, payload: dict): + """React to emergencyStop — disable motors and stop motion.""" + logger.warning("Emergency stop received from %s — disabling motors", device_id) + await self.stopMotion() + await self.disableMotors() diff --git a/strands_robots/device_connect/reachy_transport.py b/strands_robots/device_connect/reachy_transport.py new file mode 100644 index 00000000..402bac4b --- /dev/null +++ b/strands_robots/device_connect/reachy_transport.py @@ -0,0 +1,163 @@ +"""Shared transport helpers for Reachy Mini robots. + +REST API helpers, pose math, and hardware link abstractions +used by ReachyMiniDriver. 
+""" + +import asyncio +import json +import logging +import math +import socket +from abc import ABC, abstractmethod +from typing import Callable, Optional + +logger = logging.getLogger(__name__) + + +def resolve_host(host: str) -> str: + """Resolve hostname to IP address.""" + try: + return socket.gethostbyname(host) + except socket.gaierror: + return host + + +# ── REST API ───────────────────────────────────────────────────── + +def api(host: str, port: int, path: str, method: str = "GET", data: Optional[dict] = None) -> dict: + """Call Reachy Mini daemon REST API.""" + import urllib.error + import urllib.request + url = f"http://{host}:{port}{path}" + req = urllib.request.Request(url, method=method) + req.add_header("Content-Type", "application/json") + body = json.dumps(data).encode() if data else None + try: + with urllib.request.urlopen(req, body, timeout=10) as resp: + return json.loads(resp.read().decode()) + except urllib.error.HTTPError as e: + return {"error": e.read().decode(), "code": e.code} + except Exception as e: + return {"error": str(e)} + + +# ── Pose math ──────────────────────────────────────────────────── + +def rpy_to_pose(pitch_deg: float, roll_deg: float, yaw_deg: float, + x_mm: float = 0, y_mm: float = 0, z_mm: float = 0) -> list: + """Convert RPY (degrees) + XYZ (mm) to 4x4 pose matrix.""" + p, r, y = math.radians(pitch_deg), math.radians(roll_deg), math.radians(yaw_deg) + cr, sr = math.cos(r), math.sin(r) + cp, sp = math.cos(p), math.sin(p) + cy, sy = math.cos(y), math.sin(y) + return [ + [cy*cp, cy*sp*sr - sy*cr, cy*sp*cr + sy*sr, x_mm/1000], + [sy*cp, sy*sp*sr + cy*cr, sy*sp*cr - cy*sr, y_mm/1000], + [-sp, cp*sr, cp*cr, z_mm/1000], + [0, 0, 0, 1], + ] + + +def identity_pose() -> list: + """Return a 4x4 identity pose matrix.""" + return [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]] + + +# ── Hardware link abstraction ─────────────────────────────────── + +class HardwareLink(ABC): + """Abstract interface for real-time I/O 
with Reachy Mini hardware.""" + + @abstractmethod + async def start(self, on_joints: Callable, on_imu: Callable) -> None: + """Begin receiving sensor data and enable command sending.""" + + @abstractmethod + async def stop(self) -> None: + """Tear down the connection.""" + + @abstractmethod + async def send_cmd(self, cmd: dict) -> None: + """Send a real-time command to the robot.""" + + +class ZenohLink(HardwareLink): + """Wireless variant — real-time I/O via Device Connect's Zenoh transport.""" + + def __init__(self, transport, prefix: str): + self._transport = transport + self._prefix = prefix + + async def start(self, on_joints: Callable, on_imu: Callable) -> None: + async def _on_joints(data: bytes, _reply=None): + try: + on_joints(json.loads(data.decode())) + except Exception: + pass + + async def _on_imu(data: bytes, _reply=None): + try: + on_imu(json.loads(data.decode())) + except Exception: + pass + + await self._transport.subscribe(f"{self._prefix}/joint_positions", _on_joints) + await self._transport.subscribe(f"{self._prefix}/imu_data", _on_imu) + + async def stop(self) -> None: + pass # Transport teardown handled by DeviceRuntime + + async def send_cmd(self, cmd: dict) -> None: + await self._transport.publish( + f"{self._prefix}/command", json.dumps(cmd).encode() + ) + + +class WebSocketLink(HardwareLink): + """Lite variant — real-time I/O via daemon's WebSocket.""" + + _WS_CMD_MAP = { + "head_pose": lambda c: {"type": "set_target", "head": [v for row in c["head_pose"] for v in row]}, + "antennas_joint_positions": lambda c: {"type": "set_antennas", "antennas": c["antennas_joint_positions"]}, + "body_yaw": lambda c: {"type": "set_body_yaw", "body_yaw": c["body_yaw"]}, + "torque": lambda c: {"type": "set_torque", "on": c["torque"], "ids": c.get("ids")}, + } + + def __init__(self, host: str, port: int): + self._host = host + self._port = port + self._ws = None + self._read_task: Optional[asyncio.Task] = None + + async def start(self, on_joints: Callable, 
on_imu: Callable) -> None: + import websockets + + self._ws = await websockets.connect(f"ws://{self._host}:{self._port}/ws/sdk") + self._read_task = asyncio.create_task(self._read_loop(on_joints, on_imu)) + + async def _read_loop(self, on_joints: Callable, on_imu: Callable) -> None: + async for raw in self._ws: + try: + msg = json.loads(raw) + t = msg.get("type") + if t == "joint_positions": + on_joints(msg) + elif t == "imu_data": + on_imu(msg) + except Exception: + pass + + async def stop(self) -> None: + if self._read_task: + self._read_task.cancel() + if self._ws: + await self._ws.close() + + async def send_cmd(self, cmd: dict) -> None: + if not self._ws: + return + for key, fn in self._WS_CMD_MAP.items(): + if key in cmd: + await self._ws.send(json.dumps(fn(cmd))) + return diff --git a/strands_robots/device_connect/robot_driver.py b/strands_robots/device_connect/robot_driver.py new file mode 100644 index 00000000..78b87e7b --- /dev/null +++ b/strands_robots/device_connect/robot_driver.py @@ -0,0 +1,200 @@ +"""RobotDeviceDriver — Device Connect DeviceDriver adapter wrapping a strands-robots Robot. + +Exposes the Robot's task execution, status, and observation methods as +structured RPCs and events via Device Connect's DeviceDriver interface. 
+""" + +import asyncio +import logging + +from device_connect_edge.drivers import DeviceDriver, emit, on, periodic, rpc +from device_connect_edge.types import DeviceIdentity, DeviceStatus + +logger = logging.getLogger(__name__) + + +class RobotDeviceDriver(DeviceDriver): + """Device Connect device driver wrapping a strands-robots Robot instance.""" + + device_type = "strands_robot" + + def __init__(self, robot): + super().__init__() + self._robot = robot + + @property + def identity(self) -> DeviceIdentity: + return DeviceIdentity( + device_type="strands_robot", + manufacturer="strands-robots", + model=getattr(self._robot, "tool_name_str", "robot"), + description="Strands Robots LeRobot-based robot arm", + ) + + @property + def status(self) -> DeviceStatus: + task = getattr(self._robot, "_task_state", None) + is_busy = ( + task is not None + and hasattr(task, "status") + and getattr(task.status, "value", "idle") == "running" + ) + return DeviceStatus( + availability="busy" if is_busy else "idle", + busy_score=1.0 if is_busy else 0.0, + ) + + async def connect(self) -> None: + """No-op — the Robot manages its own hardware connection.""" + pass + + async def disconnect(self) -> None: + """No-op — the Robot manages its own hardware shutdown.""" + pass + + # ── RPCs ────────────────────────────────────────────────── + + @rpc() + async def execute( + self, + instruction: str, + policy_provider: str = "mock", + duration: float = 30.0, + policy_port: int = 0, + ) -> dict: + """Execute a VLA task instruction on the robot. + + Args: + instruction: Natural language task instruction + policy_provider: Policy backend (groot, mock, lerobot_local, ...) 
+ duration: Maximum task duration in seconds + policy_port: Policy server port (0 for default) + """ + return self._robot.start_task( + instruction, + policy_provider, + policy_port or None, + "localhost", + duration, + ) + + @rpc() + async def stop(self) -> dict: + """Stop the currently running task.""" + return self._robot.stop_task() + + @rpc() + async def getStatus(self) -> dict: + """Get current task execution status.""" + return self._robot.get_task_status() + + @rpc() + async def getFeatures(self) -> dict: + """Get robot observation and action features.""" + get_features = getattr(self._robot, "get_features", None) + if callable(get_features): + return get_features() + # Main's HardwareRobot does not expose get_features(); degrade gracefully. + return {"features": {}, "note": "get_features unavailable on this robot"} + + @rpc() + async def getState(self) -> dict: + """Get current robot state (joints, task info). + + Returns joint positions and task state if a task is running. + """ + result = {} + task = getattr(self._robot, "_task_state", None) + if task: + result["task_status"] = getattr(task.status, "value", "unknown") + result["instruction"] = task.instruction + result["step_count"] = task.step_count + + # Try to read observation from the underlying LeRobot robot + inner = getattr(self._robot, "robot", None) + if inner and hasattr(inner, "get_observation"): + try: + obs = await asyncio.to_thread(inner.get_observation) + # Filter out camera frames (numpy arrays) — only include scalars + result["joints"] = { + k: float(v) + for k, v in obs.items() + if not hasattr(v, "shape") + } + except Exception as e: + logger.debug("Could not read observation: %s", e) + + return result + + # ── Events ──────────────────────────────────────────────── + + @emit() + async def taskStarted(self, instruction: str, policy_provider: str): + """Emitted when a VLA task begins execution. 
+ + Args: + instruction: The task instruction + policy_provider: The policy backend used + """ + pass + + @emit() + async def taskComplete(self, instruction: str, steps: int, duration: float): + """Emitted when a VLA task finishes. + + Args: + instruction: The task instruction + steps: Total steps executed + duration: Total execution time in seconds + """ + pass + + @emit() + async def streamStep(self, step: int, observation: dict, action: dict): + """Emitted for each VLA inference step (high frequency). + + Args: + step: Step number + observation: Observation dict (joints only, no camera frames) + action: Action dict + """ + pass + + @emit() + async def emergencyStop(self, reason: str = ""): + """Emitted when this device triggers an emergency stop. + + Args: + reason: Why the emergency stop was triggered + """ + pass + + @on(event_name="emergencyStop") + async def onEmergencyStop(self, device_id: str, event_name: str, payload: dict): + """React to emergencyStop from ANY device on the network.""" + logger.warning("Emergency stop received from %s — stopping task", device_id) + self._robot.stop_task() + + # ── Periodic state publishing ───────────────────────────── + + @periodic(interval=0.1, wait_for_completion=True) + async def _publishState(self): + """Publish robot state at 10Hz.""" + task = getattr(self._robot, "_task_state", None) + if task and getattr(task.status, "value", "idle") == "running": + await self.stateUpdate( + task_status="running", + instruction=task.instruction, + step_count=task.step_count, + ) + + @emit() + async def stateUpdate(self, task_status: str = "", instruction: str = "", step_count: int = 0): + """Periodic state update. 
+ + Args: + task_status: Current task status + instruction: Current task instruction + step_count: Steps completed so far + """ + pass diff --git a/strands_robots/device_connect/setup.sh b/strands_robots/device_connect/setup.sh new file mode 100755 index 00000000..4f735654 --- /dev/null +++ b/strands_robots/device_connect/setup.sh @@ -0,0 +1,53 @@ +#!/usr/bin/env bash +# setup.sh — one-command environment setup for Strands Robots + Device Connect +# +# Usage: +# ./strands_robots/device_connect/setup.sh +# +set -euo pipefail + +PYTHON_VERSION="3.12" +VENV_DIR=".venv" +REPO_ROOT="$(cd "$(dirname "$0")/../.." && pwd)" + +echo "============================================================" +echo " Strands Robots — Environment Setup" +echo "============================================================" +echo "" + +# ── 0. Install uv (if needed) ──────────────────────────────────────────────── +if ! command -v uv &>/dev/null; then + echo "[0/2] uv not found — installing..." + curl -LsSf https://astral.sh/uv/install.sh | sh + export PATH="$HOME/.local/bin:$PATH" +else + echo "[0/2] uv $(uv --version) ✓" +fi + +# ── 1. Install Python (if needed) ──────────────────────────────────────────── +if ! uv python find "$PYTHON_VERSION" &>/dev/null; then + echo "[1/2] Python $PYTHON_VERSION not found — installing via uv..." + uv python install "$PYTHON_VERSION" +else + echo "[1/2] Python $PYTHON_VERSION ✓" +fi + +# ── 2. Create virtual environment and install ──────────────────────────────── +if [ ! -d "$REPO_ROOT/$VENV_DIR" ]; then + echo "[2/2] Creating virtual environment and installing packages..." + uv venv --python "$PYTHON_VERSION" "$REPO_ROOT/$VENV_DIR" +else + echo "[2/2] Virtual environment exists, installing packages..." 
+fi + +# shellcheck disable=SC1091 +source "$REPO_ROOT/$VENV_DIR/bin/activate" +uv pip install -e "$REPO_ROOT[sim]" + +echo "" +echo "============================================================" +echo " Setup complete" +echo "============================================================" +echo "" +echo "Activate the environment:" +echo " source $REPO_ROOT/$VENV_DIR/bin/activate" diff --git a/strands_robots/device_connect/sim_driver.py b/strands_robots/device_connect/sim_driver.py new file mode 100644 index 00000000..f4a040f3 --- /dev/null +++ b/strands_robots/device_connect/sim_driver.py @@ -0,0 +1,230 @@ +"""SimulationDeviceDriver — Device Connect DeviceDriver adapter wrapping a strands-robots Simulation. + +Exposes the Simulation's physics stepping, policy execution, and world +state as structured RPCs and events via Device Connect's DeviceDriver interface. +""" + +import logging + +from device_connect_edge.drivers import DeviceDriver, emit, on, periodic, rpc +from device_connect_edge.types import DeviceIdentity, DeviceStatus + +logger = logging.getLogger(__name__) + + +class SimulationDeviceDriver(DeviceDriver): + """Device Connect device driver wrapping a strands-robots Simulation instance.""" + + device_type = "strands_sim" + + def __init__(self, sim): + super().__init__() + self._sim = sim + + @property + def identity(self) -> DeviceIdentity: + return DeviceIdentity( + device_type="strands_sim", + manufacturer="strands-robots", + model=getattr(self._sim, "tool_name_str", "simulation"), + description="Strands Robots MuJoCo simulation", + ) + + @property + def status(self) -> DeviceStatus: + world = getattr(self._sim, "_world", None) + is_busy = False + if world: + for robot in world.robots.values(): + if getattr(robot, "policy_running", False): + is_busy = True + break + return DeviceStatus( + availability="busy" if is_busy else "idle", + busy_score=1.0 if is_busy else 0.0, + ) + + async def connect(self) -> None: + """No-op — the Simulation manages its own 
MuJoCo state.""" + pass + + async def disconnect(self) -> None: + """No-op — the Simulation manages its own cleanup.""" + pass + + # ── RPCs ────────────────────────────────────────────────── + + @rpc() + async def execute( + self, + instruction: str, + policy_provider: str = "mock", + duration: float = 30.0, + robot_name: str = "", + ) -> dict: + """Execute a policy on a simulated robot. + + Args: + instruction: Natural language task instruction + policy_provider: Policy backend (mock, lerobot_local, ...) + duration: Maximum task duration in seconds + robot_name: Target robot name (empty = first robot) + """ + # Determine robot name + name = robot_name + if not name: + world = getattr(self._sim, "_world", None) + if world and world.robots: + name = next(iter(world.robots)) + else: + return {"status": "error", "reason": "no robots in simulation"} + + print(f"▶ Executing policy '{policy_provider}' on {name}: {instruction}", flush=True) + return self._sim.start_policy( + robot_name=name, + policy_provider=policy_provider, + instruction=instruction, + duration=duration, + ) + + @rpc() + async def stop(self) -> dict: + """Stop all running policies.""" + print("⏹ Stop command received — stopping all policies", flush=True) + world = getattr(self._sim, "_world", None) + if world: + for robot in world.robots.values(): + robot.policy_running = False + return {"status": "success", "content": [{"text": "All policies stopped"}]} + + @rpc() + async def getStatus(self) -> dict: + """Get simulation state and running policies.""" + if hasattr(self._sim, "get_state"): + return self._sim.get_state() + return {"status": "idle"} + + @rpc() + async def getFeatures(self) -> dict: + """Get simulation features (joints, actuators, cameras).""" + return self._sim.get_features() + + @rpc() + async def step(self, n_steps: int = 1) -> dict: + """Step simulation physics forward. 
+ + Args: + n_steps: Number of physics steps to take + """ + return self._sim.step(n_steps) + + @rpc() + async def reset(self) -> dict: + """Reset simulation to initial state.""" + return self._sim.reset() + + # ── Events ──────────────────────────────────────────────── + + @emit() + async def policyStarted(self, robot_name: str, instruction: str, policy_provider: str): + """Emitted when a policy begins execution. + + Args: + robot_name: The simulated robot running the policy + instruction: The task instruction + policy_provider: The policy backend used + """ + pass + + @emit() + async def policyComplete(self, robot_name: str, instruction: str, steps: int): + """Emitted when a policy finishes. + + Args: + robot_name: The simulated robot + instruction: The task instruction + steps: Total steps executed + """ + pass + + @emit() + async def emergencyStop(self, reason: str = ""): + """Emitted when this device triggers an emergency stop. + + Args: + reason: Why the emergency stop was triggered + """ + pass + + @on(event_name="emergencyStop") + async def onEmergencyStop(self, device_id: str, event_name: str, payload: dict): + """React to emergencyStop from ANY device on the network.""" + print(f"🛑 Emergency stop received from {device_id} — stopping all policies", flush=True) + world = getattr(self._sim, "_world", None) + if world: + for robot in world.robots.values(): + robot.policy_running = False + + # ── Periodic state publishing ───────────────────────────── + + @periodic(interval=0.1, wait_for_completion=True) + async def _publishState(self): + """Publish simulation state at 10Hz.""" + world = getattr(self._sim, "_world", None) + if not world: + return + running = { + name: {"steps": r.policy_steps, "instruction": r.policy_instruction} + for name, r in world.robots.items() + if r.policy_running + } + if running: + await self.stateUpdate( + sim_time=world.sim_time, + step_count=world.step_count, + running_policies=running, + ) + # Publish per-robot joint observations 
from MuJoCo state + data = getattr(world, "_data", None) + robots = world.robots if isinstance(world.robots, dict) else {} + for name, robot in robots.items(): + try: + joint_names = getattr(robot, "joint_names", []) + joint_ids = getattr(robot, "joint_ids", []) + joints = {} + if data is not None and joint_names and joint_ids: + for jname, jid in zip(joint_names, joint_ids): + joints[jname] = float(data.qpos[jid]) + await self.observationUpdate( + robot_name=name, + sim_time=world.sim_time, + step_count=world.step_count, + joints=joints, + ) + except Exception as e: + logger.debug("observationUpdate skipped for %s: %s", name, e) + + @emit() + async def stateUpdate(self, sim_time: float = 0.0, step_count: int = 0, running_policies: dict = None): + """Periodic simulation state update. + + Args: + sim_time: Current simulation time + step_count: Total physics steps + running_policies: Dict of running policy info per robot + """ + pass + + @emit() + async def observationUpdate( + self, robot_name: str = "", sim_time: float = 0.0, step_count: int = 0, joints: dict = None + ): + """Periodic per-robot observation with joint positions. + + Args: + robot_name: Name of the robot + sim_time: Current simulation time + step_count: Total physics steps + joints: Dict of joint name -> position (radians) + """ + pass diff --git a/strands_robots/device_connect/test_control_loop_dc.sh b/strands_robots/device_connect/test_control_loop_dc.sh new file mode 100755 index 00000000..62eb4b95 --- /dev/null +++ b/strands_robots/device_connect/test_control_loop_dc.sh @@ -0,0 +1,197 @@ +#!/usr/bin/env bash +# test_control_loop_dc.sh — End-to-end test: control loop + Zenoh event listener +# +# Verifies that Robot("so100") publishes Device Connect events over Zenoh +# while a mock-policy control loop is running. +# +# Usage: +# bash strands_robots/device_connect/test_control_loop_dc.sh +# +set -euo pipefail + +REPO_ROOT="$(cd "$(dirname "$0")/../.." && pwd)" +WORKSPACE_ROOT="$(cd "$REPO_ROOT/.." 
&& pwd)" +export MUJOCO_GL="${MUJOCO_GL:-egl}" +export DEVICE_CONNECT_ALLOW_INSECURE=true + +EVENTS_LOG=$(mktemp /tmp/zenoh_events_XXXX.log) +LOOP_LOG=$(mktemp /tmp/control_loop_XXXX.log) +LISTENER_PID="" + +cleanup() { + [ -n "$LISTENER_PID" ] && kill "$LISTENER_PID" 2>/dev/null || true + echo "" + echo "Logs:" + echo " Events: $EVENTS_LOG" + echo " Control loop: $LOOP_LOG" +} +trap cleanup EXIT + +# ── 1. Install dependencies ──────────────────────────────────────────── +echo "==> Installing device-connect-edge..." +pip install -e "$WORKSPACE_ROOT/device-connect/packages/device-connect-edge" -q + +echo "==> Installing device-connect-agent-tools..." +pip install -e "$WORKSPACE_ROOT/device-connect/packages/device-connect-agent-tools[strands]" -q + +echo "==> Installing strands-robots[sim]..." +pip install -e "$REPO_ROOT[sim]" -q + +echo "==> All dependencies installed." +echo "" + +# ── 2. Start Zenoh listener ──────────────────────────────────────────── +echo "==> Starting Zenoh event listener..." +python3 -c " +import json, time, zenoh + +def on_sample(sample): + try: + data = json.loads(sample.payload.to_bytes().decode()) + except Exception: + data = str(sample.payload.to_bytes().decode()[:200]) + print(f'[{time.strftime(\"%H:%M:%S\")}] {sample.key_expr}: {json.dumps(data, default=str)}', flush=True) + +session = zenoh.open(zenoh.Config()) +sub = session.declare_subscriber('device-connect/**', on_sample) +print('LISTENER_READY', flush=True) +try: + while True: + time.sleep(0.1) +except KeyboardInterrupt: + pass +finally: + sub.undeclare() + session.close() +" > "$EVENTS_LOG" 2>&1 & +LISTENER_PID=$! + +# Wait for listener to be ready +for i in $(seq 1 30); do + grep -q "LISTENER_READY" "$EVENTS_LOG" 2>/dev/null && break + sleep 0.2 +done +echo " Listener PID: $LISTENER_PID" +echo "" + +# ── 3. Run the control loop ──────────────────────────────────────────── +echo "==> Running control loop (200 steps @ 50Hz)..." 
+python3 -c " +import os, sys, time +os.environ.setdefault('MUJOCO_GL', 'egl') + +from strands_robots.factory import Robot +from strands_robots.policies import create_policy + +robot = Robot('so100') +# Wait for DC runtime to connect and start periodic publishers +time.sleep(3) + +policy = create_policy('mock') +for step in range(200): + obs = robot.get_observation() + action = policy.get_actions_sync(obs, instruction='pick up the cube') + robot.apply_action(action) + if step % 50 == 0: + print(f' Step {step}/200', flush=True) + +print('Control loop done.', flush=True) +robot.cleanup() +print('Cleanup complete.', flush=True) +" 2>&1 | tee "$LOOP_LOG" + +# Give trailing events a moment to arrive +sleep 2 + +# ── 4. Stop the listener ─────────────────────────────────────────────── +kill "$LISTENER_PID" 2>/dev/null || true +wait "$LISTENER_PID" 2>/dev/null || true +LISTENER_PID="" + +# ── 5. Validate captured events ──────────────────────────────────────── +echo "" +echo "============================================================" +echo " ZENOH EVENT SUMMARY" +echo "============================================================" + +TOTAL=$(grep -c '^\[' "$EVENTS_LOG" 2>/dev/null || echo 0) +STATE_UPDATES=$(grep -c 'event/stateUpdate' "$EVENTS_LOG" 2>/dev/null || echo 0) +OBS_UPDATES=$(grep -c 'event/observationUpdate' "$EVENTS_LOG" 2>/dev/null || echo 0) +PRESENCE=$(grep -c '/presence' "$EVENTS_LOG" 2>/dev/null || echo 0) +HEARTBEATS=$(grep -c '/heartbeat' "$EVENTS_LOG" 2>/dev/null || echo 0) + +printf " %-25s %s\n" "stateUpdate events:" "$STATE_UPDATES" +printf " %-25s %s\n" "observationUpdate events:" "$OBS_UPDATES" +printf " %-25s %s\n" "presence events:" "$PRESENCE" +printf " %-25s %s\n" "heartbeat events:" "$HEARTBEATS" +printf " %-25s %s\n" "TOTAL:" "$TOTAL" +echo "" + +# Show a sample observationUpdate with joint data +SAMPLE_OBS=$(grep 'event/observationUpdate' "$EVENTS_LOG" | tail -1) +if [ -n "$SAMPLE_OBS" ]; then + echo " Sample observationUpdate:" + echo " 
$SAMPLE_OBS" | python3 -c " +import sys, json +line = sys.stdin.read().strip() +payload = line.split(': ', 1)[1] +data = json.loads(payload) +params = data.get('params', {}) +print(f\" robot: {params.get('robot_name')}\") +print(f\" sim_time: {params.get('sim_time')}\") +print(f\" step: {params.get('step_count')}\") +joints = params.get('joints', {}) +for name, val in joints.items(): + print(f\" {name:>15s}: {val:+.6f} rad\") +" 2>/dev/null || echo " (could not parse sample)" + echo "" +fi + +# ── 6. Assert minimum thresholds ─────────────────────────────────────── +PASS=true + +if [ "$TOTAL" -lt 10 ]; then + echo "FAIL: Expected >= 10 total events, got $TOTAL" + PASS=false +fi + +if [ "$STATE_UPDATES" -lt 5 ]; then + echo "FAIL: Expected >= 5 stateUpdate events, got $STATE_UPDATES" + PASS=false +fi + +if [ "$PRESENCE" -lt 1 ]; then + echo "FAIL: Expected >= 1 presence event, got $PRESENCE" + PASS=false +fi + +if [ "$HEARTBEATS" -lt 1 ]; then + echo "FAIL: Expected >= 1 heartbeat event, got $HEARTBEATS" + PASS=false +fi + +if [ "$OBS_UPDATES" -lt 5 ]; then + echo "FAIL: Expected >= 5 observationUpdate events, got $OBS_UPDATES" + PASS=false +fi + +# Check no "Failed to publish" in control loop output +PUBLISH_ERRORS=$(grep -c "Failed to publish" "$LOOP_LOG" 2>/dev/null || true) +PUBLISH_ERRORS="${PUBLISH_ERRORS:-0}" +if [ "$PUBLISH_ERRORS" -gt 0 ]; then + echo "FAIL: Found $PUBLISH_ERRORS 'Failed to publish' errors (missing cleanup?)" + PASS=false +fi + +if [ "$PASS" = true ]; then + echo "============================================================" + echo " ALL CHECKS PASSED" + echo "============================================================" + exit 0 +else + echo "" + echo "============================================================" + echo " SOME CHECKS FAILED — see logs above" + echo "============================================================" + exit 1 +fi diff --git a/strands_robots/robot.py b/strands_robots/robot.py index 33018bb5..b572910c 100644 --- 
a/strands_robots/robot.py +++ b/strands_robots/robot.py @@ -329,6 +329,7 @@ def Robot( # noqa: N802 — uppercase by design (factory mimicking a class cons except Exception as exc: # noqa: BLE001 — mesh enrichment is best-effort logger.warning("Failed to initialise mesh for %r: %s", canonical, exc) + _attach_device_connect(sim, canonical, mode, peer_id) return sim # --- Real hardware (explicit opt-in) --- @@ -366,10 +367,64 @@ def Robot( # noqa: N802 — uppercase by design (factory mimicking a class cons except Exception as exc: # noqa: BLE001 — mesh enrichment is best-effort logger.warning("Failed to initialise mesh for %r: %s", canonical, exc) + _attach_device_connect(hw, canonical, mode, peer_id) return hw else: raise ValueError(f"Invalid mode {mode!r}. Choose 'sim', 'real', or 'auto' (case-insensitive).") +def _attach_device_connect(instance: Any, canonical: str, mode: str, peer_id: str | None) -> None: + """Attach a Device Connect ``.run()`` server hook to a robot/sim instance. + + Mirrors the mesh attach above: stores peer metadata and binds ``.run()`` so + ``Robot("so100").run()`` brings the device online as a Device Connect device + (the primary networking layer), blocking until Ctrl+C. + """ + instance._peer_id = ( + peer_id or getattr(instance, "peer_id", None) or f"{canonical}-{os.urandom(3).hex()}" + ) + instance._peer_type = "sim" if mode == "sim" else "robot" + instance._device_connect_runtime = None + instance.run = lambda: _run_device_connect_foreground(instance) + + +def _run_device_connect_foreground(instance: Any) -> None: + """Start Device Connect and block — the robot listens for commands. + + Device Connect is the primary networking layer in server mode, so the + auto-started built-in mesh (if any) is stopped first to avoid running two + Zenoh presence systems in one process. 
+ """ + import time + + peer_id = getattr(instance, "_peer_id", None) or "robot" + peer_type = getattr(instance, "_peer_type", "robot") + + # Device Connect supersedes the built-in mesh in run() mode. + mesh = getattr(instance, "mesh", None) + if mesh is not None: + with contextlib.suppress(Exception): + mesh.stop() + instance.mesh = None + + try: + from strands_robots.device_connect import init_device_connect_sync + + instance._device_connect_runtime = init_device_connect_sync( + instance, peer_id=peer_id, peer_type=peer_type, + ) + except Exception as e: # noqa: BLE001 — surface but keep the process alive + logger.warning("Device Connect init failed: %s", e) + + print(f"🤖 {peer_id} is online. Ctrl+C to stop.") + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + print(f"\n🛑 Shutting down {peer_id}...", flush=True) + print(f"👋 {peer_id} stopped.", flush=True) + os._exit(0) + + __all__ = ["Robot"] diff --git a/strands_robots/tools/robot_mesh.py b/strands_robots/tools/robot_mesh.py index cba07d30..f444003d 100644 --- a/strands_robots/tools/robot_mesh.py +++ b/strands_robots/tools/robot_mesh.py @@ -35,6 +35,7 @@ import collections import json import logging +import os import threading import time from typing import Any @@ -248,6 +249,193 @@ def _resolve_mesh(target: str) -> Any | None: return next(iter(locals_.values())) +# ── Device Connect dispatch helpers ──────────────────────────────────────── +# Device Connect is the primary discovery + RPC layer; the Zenoh mesh above is +# the fallback. These helpers are invoked by robot_mesh() AFTER its safety +# gates, so DC dispatch inherits the rate limit, HITL approval, validation, and +# audit. When DC is unavailable or has discovered no devices the helpers return +# None and robot_mesh() falls through to the built-in mesh path. 
+ +_dc_connected = False + + +class _DCResult(dict): + """Strands tool-response dict whose ``str()`` renders the text block cleanly.""" + + def __str__(self) -> str: + content = self.get("content", []) + if content and isinstance(content[0], dict): + return content[0].get("text", super().__str__()) + return super().__str__() + + +def _dc_ensure_connected() -> None: + """Establish the Device Connect agent-side connection (idempotent).""" + global _dc_connected + if _dc_connected: + return + os.environ.setdefault("MESSAGING_BACKEND", "zenoh") + os.environ.setdefault("DEVICE_CONNECT_ALLOW_INSECURE", "true") + from device_connect_agent_tools.connection import connect, get_connection + + try: + get_connection() + except Exception: + connect() + _dc_connected = True + + +def _try_device_connect( + action: str, target: str, instruction: str, command: str, + policy_provider: str, policy_port: int, duration: float, timeout: float, +) -> dict[str, Any] | None: + """Dispatch *action* through Device Connect, or return None to fall back. + + Returns None — signalling robot_mesh() to use the built-in mesh — when + Device Connect is unavailable, has discovered no devices, or the action is + one DC does not handle (subscribe / watch / inbox / unsubscribe). + """ + if action in ("subscribe", "watch", "inbox", "unsubscribe"): + return None # mesh-only actions — let the built-in mesh handle them + if os.environ.get("STRANDS_ROBOT_MESH_DC", "on").strip().lower() in ("off", "0", "false", "no"): + return None # Device Connect dispatch disabled (e.g. hermetic unit tests) + try: + _dc_ensure_connected() + from device_connect_agent_tools.connection import get_connection + + conn = get_connection() + devices = conn.list_devices() + except Exception as exc: # noqa: BLE001 — DC is optional; fall back to mesh + logger.debug("Device Connect unavailable, using mesh fallback: %s", exc) + return None + # A well-formed connection returns a list of device dicts. Anything else + # (e.g. 
def _dc_unwrap_result(reply: Any) -> Any:
    """Return the ``"result"`` payload of an RPC reply, or the reply itself."""
    if isinstance(reply, dict):
        return reply.get("result", reply)
    return reply


def _device_connect_dispatch(
    action: str, target: str, instruction: str, command: str,
    policy_provider: str, policy_port: int, duration: float, timeout: float,
) -> dict[str, Any] | None:
    """Render a robot_mesh action through Device Connect (dev-compatible API).

    Fetches the agent-side connection via ``get_connection()`` (patchable in
    unit tests) and returns a Strands tool-response dict (``_DCResult``).

    Args:
        action: One of ``peers``, ``status``, ``tell``, ``send``, ``stop``,
            ``emergency_stop`` or ``broadcast``. Any other action is not
            handled here.
        target: Device id for ``tell`` / ``send`` / ``stop``.
        instruction: Natural-language instruction for ``tell``.
        command: JSON object string for ``send`` / ``broadcast``; its
            ``action`` (or ``function``) key names the remote function and the
            remaining keys are passed as parameters.
        policy_provider: Forwarded to the remote ``execute`` call by ``tell``.
        policy_port: Forwarded by ``tell`` only when truthy.
        duration: Forwarded to the remote ``execute`` call by ``tell``.
        timeout: RPC timeout; ``stop`` caps it at 5 seconds.

    Returns:
        A ``_DCResult`` for handled actions (success or error), or ``None``
        for actions this dispatcher does not handle so the caller can use the
        built-in mesh. Exceptions are never propagated; they are converted to
        an error ``_DCResult``.
    """
    try:
        from device_connect_agent_tools.connection import get_connection

        conn = get_connection()
        if action in ("peers", "status"):
            devices = conn.list_devices()
            header = (
                f"Discovered {len(devices)} device(s):\n"
                if action == "peers"
                else f"Network: {len(devices)} device(s)\n"
            )
            icons = {"strands_robot": "robot", "strands_sim": "sim", "reachy_mini": "reachy"}
            parts = [header]
            for d in devices:
                dtype = d.get("device_type", "?")
                icon = icons.get(dtype, dtype)
                status = d.get("status", {})
                avail = status.get("availability", "?") if isinstance(status, dict) else "?"
                # A device record without an id must not abort the whole listing.
                parts.append(f" [{icon}] {d.get('device_id', '?')} — {avail}\n")
                if action == "peers":
                    funcs = d.get("functions", [])
                    if funcs:
                        names = [
                            str(f.get("name", "?")) if isinstance(f, dict) else str(f)
                            for f in funcs
                        ]
                        parts.append(f" Functions: {', '.join(names)}\n")
            return _DCResult(_ok("".join(parts)))

        if action == "tell":
            if not target or not instruction:
                return _DCResult(_err("tell requires both target and instruction"))
            kwargs: dict[str, Any] = {"policy_provider": policy_provider, "duration": duration}
            if policy_port:
                kwargs["policy_port"] = policy_port
            # Inherit the mesh path's per-action command validation.
            try:
                _security.validate_command({"action": "execute", "instruction": instruction, **kwargs})
            except _security.ValidationError as exc:
                _audit_tool_action(action, target, False, f"validation: {exc}")
                return _DCResult(_err(f"tell rejected: {exc}"))
            result = conn.invoke(target, "execute", {"instruction": instruction, **kwargs}, timeout=timeout)
            # The command has been dispatched at this point, so an unusual
            # reply shape must not turn it into an (unaudited) error.
            r = _dc_unwrap_result(result)
            _audit_tool_action(action, target, True, f"instruction={instruction[:200]}")
            return _DCResult(_ok(f"-> {target}: {instruction}\n {json.dumps(r, default=str)}"))

        if action == "send":
            if not target:
                return _DCResult(_err("send requires target"))
            if not command:
                return _DCResult(_err("send requires command (JSON string)"))
            try:
                cmd = json.loads(command)
            except json.JSONDecodeError as exc:
                return _DCResult(_err(f"command is not valid JSON: {exc}"))
            if not isinstance(cmd, dict):
                return _DCResult(_err("command must decode to a JSON object (dict)"))
            try:
                cmd = _security.validate_command(cmd)
            except _security.ValidationError as exc:
                _audit_tool_action(action, target, False, f"validation: {exc}")
                return _DCResult(_err(f"send rejected: {exc}"))
            # The inner pop is evaluated first, so both "action" and
            # "function" are stripped from the parameters that get forwarded.
            func = cmd.pop("action", cmd.pop("function", "getStatus"))
            result = conn.invoke(target, func, cmd, timeout=timeout)
            r = _dc_unwrap_result(result)
            _audit_tool_action(action, target, True, f"action={func}")
            return _DCResult(_ok(f"{target}:\n{json.dumps(r, indent=2, default=str)[:2000]}"))

        if action == "stop":
            if not target:
                return _DCResult(_err("stop requires target"))
            result = conn.invoke(target, "stop", timeout=min(timeout, 5.0))
            r = _dc_unwrap_result(result)
            _audit_tool_action(action, target, True, "")
            return _DCResult(_ok(f"Stop {target}: {json.dumps(r, default=str)}"))

        if action == "emergency_stop":
            devices = conn.list_devices()
            failed: list[str] = []
            for d in devices:
                try:
                    conn.invoke(d["device_id"], "stop", timeout=3.0)
                except Exception as exc:  # noqa: BLE001 — best-effort fan-out
                    # Keep stopping the remaining devices, but never hide a
                    # device that could not be stopped.
                    device_id = str(d.get("device_id", "?")) if isinstance(d, dict) else "?"
                    logger.warning("E-STOP: stop failed for %s: %s", device_id, exc)
                    failed.append(device_id)
            stopped = len(devices) - len(failed)
            summary = f"E-STOP: {stopped}/{len(devices)} devices stopped"
            # Audit and report the real outcome: an E-STOP that missed a
            # device is not a success.
            _audit_tool_action(action, "*", not failed, f"stopped={stopped}/{len(devices)}")
            if failed:
                return _DCResult(_err(f"{summary}; FAILED: {', '.join(failed)}"))
            return _DCResult(_ok(summary))

        if action == "broadcast":
            # robot_mesh() parses + validates the command before its HITL
            # gate, but this function is also called directly, so the shape
            # is still checked here (same messages as the send branch).
            func = "getStatus"
            params: dict[str, Any] = {}
            if command:
                try:
                    cmd = json.loads(command)
                except json.JSONDecodeError as exc:
                    return _DCResult(_err(f"command is not valid JSON: {exc}"))
                if not isinstance(cmd, dict):
                    return _DCResult(_err("command must decode to a JSON object (dict)"))
                func = cmd.pop("action", cmd.pop("function", "getStatus"))
                params = cmd
            results = conn.broadcast(func, params, timeout=timeout)
            _audit_tool_action(action, "*", True, f"action={func} responses={len(results)}")
            text = f"[broadcast] {len(results)} responses\n"
            # Only the first ten replies are rendered to keep the output short.
            for r in results[:10]:
                sstr = "ok" if "result" in r else f"error: {r.get('error', '?')}"
                text += f" {r.get('device_id', '?')}: {sstr}\n"
            return _DCResult(_ok(text.rstrip()))

        # subscribe / watch / inbox / unsubscribe → handled by the mesh path
        return None
    except Exception as exc:  # noqa: BLE001 — never raise out of the dispatcher
        logger.debug("Device Connect dispatch error for %s: %s", action, exc)
        return _DCResult(_err(f"[{action}] Device Connect error: {exc}"))
+ _dc_result = _try_device_connect( + action, target, instruction, command, + policy_provider, policy_port, duration, timeout, + ) + if _dc_result is not None: + return _dc_result + try: from strands_robots.mesh import get_local_robots from strands_robots.mesh.session import get_peers diff --git a/tests/conftest.py b/tests/conftest.py index 43800683..bb8ddeec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,6 +17,12 @@ # can override via the environment without conftest stomping on them. os.environ.setdefault("STRANDS_MESH", "false") +# Disable the Device Connect dispatch path in robot_mesh by default so unit +# tests exercise the built-in mesh deterministically, without opening real +# Device Connect (Zenoh) connections. The GUIDE E2E demo runs outside pytest +# and leaves this unset, so Device Connect remains the primary path at runtime. +os.environ.setdefault("STRANDS_ROBOT_MESH_DC", "off") + from tests.mocks.torch_mock import install_torch_mock # Must run before any test imports policy modules diff --git a/tests/test_device_connect_all_robots.py b/tests/test_device_connect_all_robots.py new file mode 100644 index 00000000..72f5d0c1 --- /dev/null +++ b/tests/test_device_connect_all_robots.py @@ -0,0 +1,722 @@ +"""Parametrized Device Connect tests across all 38 registered robots. + +Validates that RobotDeviceDriver and SimulationDeviceDriver work correctly +with every robot's specific configuration (joint counts, observation shapes, +identity, status, RPC delegation). Also tests multi-robot simulation scenarios, +edge cases, and robot_mesh dispatch with diverse device types. + +All external dependencies (Zenoh, LeRobot, device_connect_edge, strands) are mocked. +No Docker, GPU, or hardware required. 
+""" + +import asyncio +import json +import pathlib +import sys +from dataclasses import dataclass +from enum import Enum +from unittest.mock import MagicMock, patch + +import pytest + +# ── Mock heavy dependencies before importing ────────────────────── + +mock_device_connect_edge = MagicMock() +mock_drivers = MagicMock() + + +class _FakeDeviceDriver: + """Minimal stub so our drivers can subclass it.""" + + device_type = None + + def __init__(self): + self._transport = None + + def set_device(self, device): + pass + + @property + def transport(self): + return self._transport + + +def _passthrough_decorator(*args, **kwargs): + if len(args) == 1 and callable(args[0]): + return args[0] + + def wrapper(func): + for k, v in kwargs.items(): + setattr(func, f"_{k}", v) + return func + + return wrapper + + +mock_drivers.DeviceDriver = _FakeDeviceDriver +mock_drivers.rpc = _passthrough_decorator +mock_drivers.emit = _passthrough_decorator +mock_drivers.periodic = _passthrough_decorator +mock_drivers.on = _passthrough_decorator + +mock_types = MagicMock() + + +@dataclass +class FakeDeviceIdentity: + device_type: str = None + manufacturer: str = None + model: str = None + description: str = None + serial_number: str = None + firmware_version: str = None + arch: str = None + commissioning_comment: str = None + + +@dataclass +class FakeDeviceStatus: + availability: str = "idle" + busy_score: float = 0.0 + location: str = None + battery: int = None + online: bool = True + error_state: str = None + + +mock_types.DeviceIdentity = FakeDeviceIdentity +mock_types.DeviceStatus = FakeDeviceStatus + +_saved_modules = {} +_mock_keys = ( + "device_connect_edge", + "device_connect_edge.drivers", + "device_connect_edge.types", + "device_connect_edge.device", +) +_strands_dc_keys = [k for k in sys.modules if k.startswith("strands_robots.device_connect")] +for _key in list(_mock_keys) + _strands_dc_keys: + _saved_modules[_key] = sys.modules.get(_key) + +sys.modules["device_connect_edge"] = 
mock_device_connect_edge +sys.modules["device_connect_edge.drivers"] = mock_drivers +sys.modules["device_connect_edge.types"] = mock_types +sys.modules["device_connect_edge.device"] = MagicMock() + +mock_device_runtime = MagicMock() +mock_device_connect_edge.DeviceRuntime = mock_device_runtime + +from strands_robots.device_connect.robot_driver import RobotDeviceDriver # noqa: E402 +from strands_robots.device_connect.sim_driver import SimulationDeviceDriver # noqa: E402 + + +def teardown_module(): + """Restore real device_connect_edge modules.""" + for key, original in _saved_modules.items(): + if original is None: + sys.modules.pop(key, None) + else: + sys.modules[key] = original + for key in list(sys.modules): + if key.startswith("strands_robots.device_connect"): + sys.modules.pop(key, None) + + +# ── Load robot registry ────────────────────────────────────────── + +_REGISTRY_PATH = pathlib.Path(__file__).resolve().parents[1] / "strands_robots" / "registry" / "robots.json" +_REGISTRY = json.loads(_REGISTRY_PATH.read_text())["robots"] + +ALL_ROBOTS = [(name, info) for name, info in _REGISTRY.items()] +SIM_ROBOTS = [(name, info) for name, info in ALL_ROBOTS if "asset" in info] +REAL_ONLY_ROBOTS = [(name, info) for name, info in ALL_ROBOTS if "asset" not in info] + + +# ── Task state mocks ───────────────────────────────────────────── + + +class FakeTaskStatus(Enum): + IDLE = "idle" + RUNNING = "running" + COMPLETED = "completed" + STOPPED = "stopped" + ERROR = "error" + + +@dataclass +class FakeTaskState: + status: FakeTaskStatus = FakeTaskStatus.IDLE + instruction: str = "" + start_time: float = 0.0 + duration: float = 0.0 + step_count: int = 0 + error_message: str = "" + + +# ── Observation helper ─────────────────────────────────────────── + + +class _FakeArray: + """Mimics a numpy array with a .shape attribute.""" + + def __init__(self, shape): + self.shape = shape + + +def _generate_observation(joint_count, include_arrays=True): + """Generate a realistic 
observation dict for a robot with N joints.""" + obs = {} + for i in range(joint_count): + obs[f"joint_{i}"] = float(i) * 0.1 + if include_arrays: + obs["image"] = _FakeArray((480, 640, 3)) + obs["depth"] = _FakeArray((480, 640)) + return obs + + +# ── Mock factories ─────────────────────────────────────────────── + + +def _get_joint_count(info): + """Get joint count from registry info, defaulting to 6 for real-only robots.""" + return info.get("joints", 6) + + +def _make_mock_robot(name, info, task_status="idle"): + """Create a mock robot matching the registry entry's configuration.""" + joint_count = _get_joint_count(info) + robot = MagicMock() + robot.tool_name_str = name + robot._task_state = FakeTaskState( + status=FakeTaskStatus(task_status), + instruction="pick up the cube" if task_status == "running" else "", + step_count=42 if task_status == "running" else 0, + ) + robot.start_task.return_value = {"status": "success", "content": [{"text": "Task started"}]} + robot.stop_task.return_value = {"status": "success", "content": [{"text": "Task stopped"}]} + robot.get_task_status.return_value = {"status": "success", "content": [{"text": "Status info"}]} + + features = {f"joint_{i}": "float" for i in range(joint_count)} + robot.get_features.return_value = { + "status": "success", + "content": [{"json": {"observation_features": features, "action_features": features}}], + } + + robot.robot = MagicMock() + robot.robot.get_observation.return_value = _generate_observation(joint_count) + return robot + + +def _make_mock_sim(name, info, robots_in_world=None): + """Create a mock simulation matching the registry entry's configuration.""" + sim = MagicMock() + sim.tool_name_str = f"{name}_sim" + + world = MagicMock() + if robots_in_world is None: + robot_data = MagicMock() + robot_data.policy_running = False + robot_data.policy_steps = 0 + robot_data.policy_instruction = "" + world.robots = {name: robot_data} + else: + world.robots = robots_in_world + world.sim_time = 0.0 + 
class TestRobotDriverAllRobots:
    """Run the RobotDeviceDriver checks once for every entry in ``ALL_ROBOTS``."""

    @pytest.mark.parametrize("robot_name,robot_info", ALL_ROBOTS, ids=[name for name, _ in ALL_ROBOTS])
    def test_identity(self, robot_name, robot_info):
        # Identity is derived from the wrapped robot's tool name.
        driver = RobotDeviceDriver(_make_mock_robot(robot_name, robot_info))
        assert driver.identity.device_type == "strands_robot"
        assert driver.identity.model == robot_name
        assert driver.identity.manufacturer == "strands-robots"

    @pytest.mark.parametrize("robot_name,robot_info", ALL_ROBOTS, ids=[name for name, _ in ALL_ROBOTS])
    def test_status_idle(self, robot_name, robot_info):
        idle_robot = _make_mock_robot(robot_name, robot_info, task_status="idle")
        driver = RobotDeviceDriver(idle_robot)
        assert driver.status.availability == "idle"
        assert driver.status.busy_score == 0.0

    @pytest.mark.parametrize("robot_name,robot_info", ALL_ROBOTS, ids=[name for name, _ in ALL_ROBOTS])
    def test_status_busy(self, robot_name, robot_info):
        running_robot = _make_mock_robot(robot_name, robot_info, task_status="running")
        driver = RobotDeviceDriver(running_robot)
        assert driver.status.availability == "busy"
        assert driver.status.busy_score == 1.0

    @pytest.mark.parametrize("robot_name,robot_info", ALL_ROBOTS, ids=[name for name, _ in ALL_ROBOTS])
    def test_execute_delegates(self, robot_name, robot_info):
        robot = _make_mock_robot(robot_name, robot_info)
        driver = RobotDeviceDriver(robot)
        outcome = asyncio.run(driver.execute("pick up cube", "groot", 30.0, 0))
        # The RPC must be forwarded to the robot's start_task with these arguments.
        robot.start_task.assert_called_once_with("pick up cube", "groot", None, "localhost", 30.0)
        assert outcome["status"] == "success"

    @pytest.mark.parametrize("robot_name,robot_info", ALL_ROBOTS, ids=[name for name, _ in ALL_ROBOTS])
    def test_get_state_joint_count(self, robot_name, robot_info):
        expected_joints = _get_joint_count(robot_info)
        robot = _make_mock_robot(robot_name, robot_info, task_status="running")
        driver = RobotDeviceDriver(robot)
        state = asyncio.run(driver.getState())
        assert "joints" in state
        assert len(state["joints"]) == expected_joints

    @pytest.mark.parametrize("robot_name,robot_info", ALL_ROBOTS, ids=[name for name, _ in ALL_ROBOTS])
    def test_get_state_filters_arrays(self, robot_name, robot_info):
        robot = _make_mock_robot(robot_name, robot_info)
        driver = RobotDeviceDriver(robot)
        state = asyncio.run(driver.getState())
        # Nothing is checked when the state carries no "joints" entry.
        if "joints" not in state:
            return
        for key, value in state["joints"].items():
            assert isinstance(value, float), f"Non-float value for {key}: {type(value)}"
            assert not key.startswith("image") and not key.startswith("depth")

    @pytest.mark.parametrize("robot_name,robot_info", ALL_ROBOTS, ids=[name for name, _ in ALL_ROBOTS])
    def test_get_state_task_info(self, robot_name, robot_info):
        robot = _make_mock_robot(robot_name, robot_info, task_status="running")
        driver = RobotDeviceDriver(robot)
        state = asyncio.run(driver.getState())
        # Values come from the FakeTaskState built by _make_mock_robot.
        assert state["task_status"] == "running"
        assert state["instruction"] == "pick up the cube"
        assert state["step_count"] == 42

    @pytest.mark.parametrize("robot_name,robot_info", ALL_ROBOTS, ids=[name for name, _ in ALL_ROBOTS])
    def test_get_features_delegates(self, robot_name, robot_info):
        robot = _make_mock_robot(robot_name, robot_info)
        driver = RobotDeviceDriver(robot)
        features = asyncio.run(driver.getFeatures())
        robot.get_features.assert_called_once()
        assert features["status"] == "success"
class TestRealOnlyRobots:
    """RobotDeviceDriver checks for registry entries without an ``"asset"`` key.

    These are the robots collected in ``REAL_ONLY_ROBOTS``.
    """

    @pytest.mark.parametrize(
        "robot_name,robot_info", REAL_ONLY_ROBOTS, ids=[name for name, _ in REAL_ONLY_ROBOTS]
    )
    def test_driver_creation(self, robot_name, robot_info):
        # Constructing the driver around the mock must yield a driver object.
        driver = RobotDeviceDriver(_make_mock_robot(robot_name, robot_info))
        assert driver is not None

    @pytest.mark.parametrize(
        "robot_name,robot_info", REAL_ONLY_ROBOTS, ids=[name for name, _ in REAL_ONLY_ROBOTS]
    )
    def test_identity_no_asset(self, robot_name, robot_info):
        driver = RobotDeviceDriver(_make_mock_robot(robot_name, robot_info))
        assert driver.identity.model == robot_name
        assert driver.identity.device_type == "strands_robot"

    @pytest.mark.parametrize(
        "robot_name,robot_info", REAL_ONLY_ROBOTS, ids=[name for name, _ in REAL_ONLY_ROBOTS]
    )
    def test_execute_delegates(self, robot_name, robot_info):
        robot = _make_mock_robot(robot_name, robot_info)
        driver = RobotDeviceDriver(robot)
        outcome = asyncio.run(driver.execute("move forward", "mock", 10.0, 0))
        # Only the fact of delegation is checked here, not the arguments.
        robot.start_task.assert_called_once()
        assert outcome["status"] == "success"
_REGISTRY["so100"], robots_in_world=robots_in_world) + driver = SimulationDeviceDriver(sim) + asyncio.run( + driver.execute("walk forward", "mock", 30.0, "unitree_g1") + ) + sim.start_policy.assert_called_once_with( + robot_name="unitree_g1", policy_provider="mock", instruction="walk forward", duration=30.0 + ) + + def test_execute_empty_world(self): + """Returns error when no robots in simulation.""" + sim = _make_mock_sim("empty", _REGISTRY["so100"], robots_in_world={}) + driver = SimulationDeviceDriver(sim) + result = asyncio.run( + driver.execute("test", "mock", 10.0, "") + ) + assert result["status"] == "error" + + +# ── TestEdgeCases ──────────────────────────────────────────────── + + +class TestEdgeCases: + """Edge case tests for driver robustness.""" + + def test_observation_all_arrays(self): + """Observation with only array values → joints dict is empty.""" + robot = _make_mock_robot("so100", _REGISTRY["so100"]) + robot.robot.get_observation.return_value = { + "camera_front": _FakeArray((480, 640, 3)), + "camera_wrist": _FakeArray((480, 640, 3)), + "depth": _FakeArray((480, 640)), + } + driver = RobotDeviceDriver(robot) + result = asyncio.run(driver.getState()) + assert result.get("joints", {}) == {} + + def test_observation_empty(self): + """Empty observation → no joints key or empty joints.""" + robot = _make_mock_robot("so100", _REGISTRY["so100"]) + robot.robot.get_observation.return_value = {} + driver = RobotDeviceDriver(robot) + result = asyncio.run(driver.getState()) + assert result.get("joints", {}) == {} + + def test_observation_raises(self): + """get_observation() throws → getState still returns task info.""" + robot = _make_mock_robot("so100", _REGISTRY["so100"], task_status="running") + robot.robot.get_observation.side_effect = RuntimeError("hardware disconnected") + driver = RobotDeviceDriver(robot) + result = asyncio.run(driver.getState()) + assert result["task_status"] == "running" + assert result["instruction"] == "pick up the cube" + 
assert "joints" not in result + + def test_no_inner_robot(self): + """robot.robot is None → getState skips observation.""" + robot = _make_mock_robot("so100", _REGISTRY["so100"], task_status="running") + robot.robot = None + driver = RobotDeviceDriver(robot) + result = asyncio.run(driver.getState()) + assert result["task_status"] == "running" + assert "joints" not in result + + def test_no_task_state(self): + """_task_state is None → status is idle, getState has no task info.""" + robot = _make_mock_robot("so100", _REGISTRY["so100"]) + robot._task_state = None + driver = RobotDeviceDriver(robot) + assert driver.status.availability == "idle" + result = asyncio.run(driver.getState()) + assert "task_status" not in result + + def test_float_conversion_failure(self): + """Non-numeric scalar in observation → graceful handling via exception catch.""" + robot = _make_mock_robot("so100", _REGISTRY["so100"]) + robot.robot.get_observation.return_value = { + "joint_0": 0.5, + "metadata": "not_a_number", + } + driver = RobotDeviceDriver(robot) + # float("not_a_number") raises ValueError; the driver wraps get_observation + # in a try/except, so it either filters it out or catches the error + result = asyncio.run(driver.getState()) + # Either joints has only joint_0, or the whole observation was skipped + if "joints" in result: + assert "metadata" not in result["joints"] or isinstance(result["joints"].get("metadata"), float) + + def test_max_joint_robot(self): + """unitree_g1 (46 joints) — all joints appear in getState.""" + info = _REGISTRY["unitree_g1"] + robot = _make_mock_robot("unitree_g1", info, task_status="running") + driver = RobotDeviceDriver(robot) + result = asyncio.run(driver.getState()) + assert len(result["joints"]) == 46 + + def test_min_joint_robot(self): + """koch (7 joints) — correct joint count.""" + info = _REGISTRY["koch"] + robot = _make_mock_robot("koch", info, task_status="running") + driver = RobotDeviceDriver(robot) + result = 
asyncio.run(driver.getState()) + assert len(result["joints"]) == 7 + + +# ── TestRobotMeshDispatchAllTypes ──────────────────────────────── + + +# robot_mesh imports `strands` (@tool, ToolContext) and device_connect_agent_tools. +# Prefer the REAL packages when installed (both are hard deps here) so we never +# leave a stub in sys.modules that leaks into sibling test modules. Only fall +# back to stubs when a package genuinely is not importable. +def _passthrough_tool(*args, **kwargs): + """Stub for strands @tool / @tool(context=True): return the function unchanged.""" + if args and callable(args[0]): + return args[0] + return lambda fn: fn + + +try: + import strands # noqa: F401 — use the real package when installed +except Exception: + _m = MagicMock() + _m.tool = _passthrough_tool + _types_tools = MagicMock() + _types_tools.ToolContext = object + sys.modules["strands"] = _m + sys.modules["strands.types"] = MagicMock() + sys.modules["strands.types.tools"] = _types_tools + +try: + import device_connect_agent_tools # noqa: F401 + import device_connect_agent_tools.connection # noqa: F401 +except Exception: + sys.modules.setdefault("device_connect_agent_tools", MagicMock()) + sys.modules.setdefault("device_connect_agent_tools.connection", MagicMock()) + + +class _FakeConnection: + """Fake connection with all methods the dispatch uses.""" + + def __init__(self, devices=None): + self.zone = "default" + self._devices = devices or [] + self._invoke_results = {} + self._inbox = {} + self._sync_subs = {} + + def list_devices(self, device_type=None): + if device_type: + return [d for d in self._devices if d.get("device_type") == device_type] + return list(self._devices) + + def invoke(self, device_id, function, params=None, timeout=30.0): + key = (device_id, function) + if key in self._invoke_results: + return self._invoke_results[key] + return {"result": {"status": "ok"}} + + def broadcast(self, function, params=None, timeout=5.0): + results = [] + for d in self._devices: 
+ try: + r = self.invoke(d["device_id"], function, params, timeout=timeout) + results.append({"device_id": d["device_id"], "result": r}) + except Exception as e: + results.append({"device_id": d["device_id"], "error": str(e)}) + return results + + def subscribe_buffered(self, subject, name=None): + name = name or subject + self._inbox[name] = [] + self._sync_subs[name] = True + return name + + def get_inbox(self, name=None): + if name is not None: + return {name: list(self._inbox.get(name, []))} + return {k: list(v) for k, v in self._inbox.items()} + + +# Build a diverse fleet of sample devices from the registry +_CATEGORY_REPRESENTATIVES = { + "arm": ("so100", "strands_robot"), + "bimanual": ("aloha", "strands_robot"), + "hand": ("shadow_hand", "strands_robot"), + "humanoid": ("unitree_g1", "strands_sim"), + "expressive": ("reachy_mini", "strands_robot"), + "mobile": ("unitree_go2", "strands_sim"), + "mobile_manip": ("google_robot", "strands_sim"), +} + +DIVERSE_DEVICES = [] +for category, (robot_name, device_type) in _CATEGORY_REPRESENTATIVES.items(): + DIVERSE_DEVICES.append({ + "device_id": f"{robot_name}-{category}-1", + "device_type": device_type, + "status": {"availability": "idle"}, + "functions": [{"name": "execute"}, {"name": "stop"}, {"name": "getStatus"}], + "events": ["taskStarted", "taskComplete"] if device_type == "strands_robot" else ["stateUpdate"], + }) + + +class TestRobotMeshDispatchAllTypes: + """Tests robot_mesh dispatch with a diverse fleet spanning all robot categories.""" + + def _get_dispatch(self): + from strands_robots.tools.robot_mesh import _device_connect_dispatch + return _device_connect_dispatch + + def _call(self, dispatch, conn, action, **kwargs): + defaults = dict( + target="", instruction="", command="", + policy_provider="mock", policy_port=0, + duration=30.0, timeout=5.0, + ) + defaults.update(kwargs) + with patch("device_connect_agent_tools.connection.get_connection", return_value=conn): + return dispatch(action, **{k: 
defaults[k] for k in [ + "target", "instruction", "command", + "policy_provider", "policy_port", "duration", "timeout", + ]}) + + def test_peers_lists_all_categories(self): + conn = _FakeConnection(devices=DIVERSE_DEVICES) + dispatch = self._get_dispatch() + result = self._call(dispatch, conn, "peers") + assert result["status"] == "success" + text = result["content"][0]["text"] + assert f"{len(DIVERSE_DEVICES)} device(s)" in text + for device in DIVERSE_DEVICES: + assert device["device_id"] in text + + def test_tell_arm_robot(self): + conn = _FakeConnection(devices=DIVERSE_DEVICES) + dispatch = self._get_dispatch() + result = self._call(dispatch, conn, "tell", target="so100-arm-1", instruction="pick up cube") + assert result["status"] == "success" + assert "so100-arm-1" in result["content"][0]["text"] + + def test_tell_humanoid_sim(self): + conn = _FakeConnection(devices=DIVERSE_DEVICES) + dispatch = self._get_dispatch() + result = self._call(dispatch, conn, "tell", target="unitree_g1-humanoid-1", instruction="walk forward") + assert result["status"] == "success" + assert "unitree_g1-humanoid-1" in result["content"][0]["text"] + + def test_tell_mobile_robot(self): + conn = _FakeConnection(devices=DIVERSE_DEVICES) + dispatch = self._get_dispatch() + result = self._call(dispatch, conn, "tell", target="unitree_go2-mobile-1", instruction="navigate to door") + assert result["status"] == "success" + assert "unitree_go2-mobile-1" in result["content"][0]["text"] + + def test_emergency_stop_all_types(self): + conn = _FakeConnection(devices=DIVERSE_DEVICES) + dispatch = self._get_dispatch() + result = self._call(dispatch, conn, "emergency_stop") + assert result["status"] == "success" + text = result["content"][0]["text"] + assert "E-STOP" in text + assert f"{len(DIVERSE_DEVICES)}/{len(DIVERSE_DEVICES)}" in text + + def test_status_mixed_fleet(self): + conn = _FakeConnection(devices=DIVERSE_DEVICES) + dispatch = self._get_dispatch() + result = self._call(dispatch, conn, 
def _passthrough_decorator(*args, **kwargs):
    """Stand-in for the @rpc / @emit / @periodic / @on decorators.

    Supports both spellings:

    * bare (``@rpc``): called with exactly one callable positional argument,
      which is returned untouched;
    * parameterized (``@rpc(...)``): returns a decorator that copies each
      keyword argument onto the function as a ``_<name>`` attribute and then
      returns the function itself.
    """
    used_bare = len(args) == 1 and callable(args[0])
    if used_bare:
        (decorated,) = args
        return decorated

    def wrapper(func):
        # Tag the function so tests can verify how the decorator was used.
        for option, value in kwargs.items():
            setattr(func, f"_{option}", value)
        return func

    return wrapper
@dataclass
class FakeDeviceStatus:
    """Test double exposed as ``DeviceStatus`` on the mocked ``device_connect_edge.types``.

    A plain dataclass with defaults for every field, so it can be built with
    no arguments or with any subset of keyword arguments.
    """
    # Tests in this file assert "idle" / "busy" on a driver's status.availability
    # (presumably an instance of this class — confirm against the drivers).
    availability: str = "idle"
    # Tests in this file assert 0.0 (idle) and 1.0 (busy) on status.busy_score.
    busy_score: float = 0.0
    # NOTE(review): the fields below are annotated str/int but default to None,
    # so they are effectively optional. No test visible here reads them.
    location: str = None
    battery: int = None
    online: bool = True
    error_state: str = None
def _make_mock_robot(tool_name="so100", task_status="idle"):
    """Build a MagicMock robot with canned replies for the driver tests.

    Args:
        tool_name: Value stored on ``tool_name_str``.
        task_status: A ``FakeTaskStatus`` value; ``"running"`` also fills in an
            instruction ("pick up the cube") and a step count (42).

    Returns:
        The configured MagicMock. ``start_task`` / ``stop_task`` /
        ``get_task_status`` / ``get_features`` return success dicts, and the
        inner ``robot.get_observation()`` returns two joint readings.
    """
    is_running = task_status == "running"

    def _text_reply(message):
        return {"status": "success", "content": [{"text": message}]}

    fake = MagicMock()
    fake.tool_name_str = tool_name
    fake._task_state = FakeTaskState(
        status=FakeTaskStatus(task_status),
        instruction="pick up the cube" if is_running else "",
        step_count=42 if is_running else 0,
    )

    fake.start_task.return_value = _text_reply("Task started")
    fake.stop_task.return_value = _text_reply("Task stopped")
    fake.get_task_status.return_value = _text_reply("Status info")
    fake.get_features.return_value = {
        "status": "success",
        "content": [{"json": {"observation_features": {"joint1": "float"}, "action_features": {"joint1": "float"}}}],
    }

    # Inner (lerobot-style) robot object whose observation the driver reads.
    inner = MagicMock()
    inner.get_observation.return_value = {"joint1": 0.5, "joint2": -1.2}
    fake.robot = inner
    return fake
class TestRobotDeviceDriver(unittest.TestCase):
    """RobotDeviceDriver checks, each run against a fresh MagicMock robot."""

    @staticmethod
    def _wire(**robot_kwargs):
        """Return ``(mock_robot, driver)`` with the driver wrapping the mock."""
        mock_robot = _make_mock_robot(**robot_kwargs)
        return mock_robot, RobotDeviceDriver(mock_robot)

    def test_identity(self):
        _, driver = self._wire(tool_name="so100")
        ident = driver.identity
        for field, expected in (
            ("device_type", "strands_robot"),
            ("manufacturer", "strands-robots"),
            ("model", "so100"),
        ):
            self.assertEqual(getattr(ident, field), expected)

    def test_status_idle(self):
        _, driver = self._wire(task_status="idle")
        snapshot = driver.status
        self.assertEqual(snapshot.availability, "idle")
        self.assertEqual(snapshot.busy_score, 0.0)

    def test_status_busy(self):
        _, driver = self._wire(task_status="running")
        snapshot = driver.status
        self.assertEqual(snapshot.availability, "busy")
        self.assertEqual(snapshot.busy_score, 1.0)

    def test_execute_rpc(self):
        mock_robot, driver = self._wire()
        reply = asyncio.run(driver.execute("pick up cube", "groot", 30.0, 0))
        # Port 0 is expected to reach start_task as None.
        mock_robot.start_task.assert_called_once_with("pick up cube", "groot", None, "localhost", 30.0)
        self.assertEqual(reply["status"], "success")

    def test_execute_rpc_with_port(self):
        mock_robot, driver = self._wire()
        asyncio.run(driver.execute("wave", "groot", 10.0, 50051))
        mock_robot.start_task.assert_called_once_with("wave", "groot", 50051, "localhost", 10.0)

    def test_stop_rpc(self):
        mock_robot, driver = self._wire()
        reply = asyncio.run(driver.stop())
        mock_robot.stop_task.assert_called_once()
        self.assertEqual(reply["status"], "success")

    def test_get_status_rpc(self):
        mock_robot, driver = self._wire()
        reply = asyncio.run(driver.getStatus())
        mock_robot.get_task_status.assert_called_once()
        self.assertEqual(reply["status"], "success")

    def test_get_features_rpc(self):
        mock_robot, driver = self._wire()
        reply = asyncio.run(driver.getFeatures())
        mock_robot.get_features.assert_called_once()
        self.assertEqual(reply["status"], "success")

    def test_get_state_rpc(self):
        _, driver = self._wire(task_status="running")
        state = asyncio.run(driver.getState())
        self.assertEqual(state["task_status"], "running")
        self.assertEqual(state["instruction"], "pick up the cube")
        self.assertEqual(state["step_count"], 42)
        # Joint values come from the inner robot's get_observation().
        self.assertIn("joints", state)
        self.assertAlmostEqual(state["joints"]["joint1"], 0.5)

    def test_connect_disconnect_noop(self):
        # Passing means neither coroutine raised.
        _, driver = self._wire()
        asyncio.run(driver.connect())
        asyncio.run(driver.disconnect())

    def test_emergency_stop_handler(self):
        mock_robot, driver = self._wire()
        event_payload = {"reason": "test"}
        asyncio.run(driver.onEmergencyStop("other-robot", "emergencyStop", event_payload))
        mock_robot.stop_task.assert_called_once()
self.assertEqual(identity.device_type, "strands_sim") + self.assertEqual(identity.manufacturer, "strands-robots") + self.assertEqual(identity.model, "mujoco_sim") + + def test_status_idle(self): + sim = _make_mock_sim() + driver = SimulationDeviceDriver(sim) + status = driver.status + self.assertEqual(status.availability, "idle") + + def test_status_busy(self): + sim = _make_mock_sim() + sim._world.robots["so100"].policy_running = True + driver = SimulationDeviceDriver(sim) + status = driver.status + self.assertEqual(status.availability, "busy") + + def test_identity_sim_type(self): + sim = _make_mock_sim() + driver = SimulationDeviceDriver(sim) + self.assertEqual(driver.device_type, "strands_sim") + + def test_execute_rpc(self): + sim = _make_mock_sim() + driver = SimulationDeviceDriver(sim) + result = asyncio.run( + driver.execute("pick up cube", "mock", 10.0) + ) + sim.start_policy.assert_called_once_with( + robot_name="so100", + policy_provider="mock", + instruction="pick up cube", + duration=10.0, + ) + self.assertEqual(result["status"], "success") + + def test_execute_with_robot_name(self): + sim = _make_mock_sim() + driver = SimulationDeviceDriver(sim) + asyncio.run( + driver.execute("wave", "mock", 5.0, robot_name="arm2") + ) + sim.start_policy.assert_called_once_with( + robot_name="arm2", + policy_provider="mock", + instruction="wave", + duration=5.0, + ) + + def test_stop_rpc(self): + sim = _make_mock_sim() + sim._world.robots["so100"].policy_running = True + driver = SimulationDeviceDriver(sim) + result = asyncio.run(driver.stop()) + self.assertEqual(result["status"], "success") + self.assertFalse(sim._world.robots["so100"].policy_running) + + def test_get_status_rpc(self): + sim = _make_mock_sim() + driver = SimulationDeviceDriver(sim) + result = asyncio.run(driver.getStatus()) + sim.get_state.assert_called_once() + + def test_get_features_rpc(self): + sim = _make_mock_sim() + driver = SimulationDeviceDriver(sim) + result = 
asyncio.run(driver.getFeatures()) + sim.get_features.assert_called_once() + + def test_step_rpc(self): + sim = _make_mock_sim() + driver = SimulationDeviceDriver(sim) + result = asyncio.run(driver.step(10)) + sim.step.assert_called_once_with(10) + + def test_reset_rpc(self): + sim = _make_mock_sim() + driver = SimulationDeviceDriver(sim) + result = asyncio.run(driver.reset()) + sim.reset.assert_called_once() + + def test_emergency_stop_handler(self): + sim = _make_mock_sim() + sim._world.robots["so100"].policy_running = True + driver = SimulationDeviceDriver(sim) + asyncio.run( + driver.onEmergencyStop("other-device", "emergencyStop", {"reason": "test"}) + ) + self.assertFalse(sim._world.robots["so100"].policy_running) + + +# ── TestReachyMiniDriver ───────────────────────────────────────── + +class TestReachyMiniDriver(unittest.TestCase): + + def setUp(self): + # Mock reachy_transport module but keep real ZenohLink/WebSocketLink + from strands_robots.device_connect.reachy_transport import ZenohLink, WebSocketLink + self.mock_transport_mod = MagicMock() + self.mock_transport_mod.api.return_value = {"status": "ok"} + self.mock_transport_mod.rpy_to_pose.side_effect = lambda *args, **kwargs: [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]] + self.mock_transport_mod.identity_pose.return_value = [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]] + self.mock_transport_mod.ZenohLink = ZenohLink + self.mock_transport_mod.WebSocketLink = WebSocketLink + + self.transport_patcher = patch.dict(sys.modules, { + "strands_robots.device_connect.reachy_transport": self.mock_transport_mod, + }) + self.transport_patcher.start() + + # Re-import to pick up mocks + if "strands_robots.device_connect.reachy_mini_driver" in sys.modules: + del sys.modules["strands_robots.device_connect.reachy_mini_driver"] + from strands_robots.device_connect.reachy_mini_driver import ReachyMiniDriver + self.ReachyMiniDriver = ReachyMiniDriver + + def tearDown(self): + self.transport_patcher.stop() + + def 
_make_driver(self, **kwargs): + """Create a driver with a mocked Device Connect transport and ZenohLink-like _hw.""" + driver = self.ReachyMiniDriver(**kwargs) + mock_transport = AsyncMock() + mock_transport.publish = AsyncMock() + mock_transport.subscribe = AsyncMock() + driver._transport = mock_transport + + # Create a HW link that delegates to mock_transport (like ZenohLink does) + prefix = driver._prefix + class _MockZenohLink: + async def send_cmd(self, cmd): + await mock_transport.publish( + f"{prefix}/command", json.dumps(cmd).encode() + ) + async def start(self, on_joints, on_imu): + await mock_transport.subscribe(f"{prefix}/joint_positions", on_joints) + await mock_transport.subscribe(f"{prefix}/imu_data", on_imu) + async def stop(self): + pass + driver._hw = _MockZenohLink() + return driver + + def test_identity(self): + driver = self.ReachyMiniDriver(host="192.168.1.50") + identity = driver.identity + self.assertEqual(identity.device_type, "reachy_mini") + self.assertEqual(identity.manufacturer, "Pollen Robotics") + self.assertIn("192.168.1.50", identity.model) + + def test_look_rpc(self): + driver = self._make_driver() + result = asyncio.run( + driver.look(pitch=15, yaw=30) + ) + self.assertEqual(result["status"], "success") + self.assertEqual(result["pitch"], 15) + self.assertEqual(result["yaw"], 30) + # Verify transport.publish was called with the command topic + driver._transport.publish.assert_awaited() + topic = driver._transport.publish.call_args[0][0] + self.assertEqual(topic, "reachy_mini/command") + + def test_antennas_rpc(self): + driver = self._make_driver() + result = asyncio.run( + driver.antennas(left=45, right=-30) + ) + self.assertEqual(result["status"], "success") + self.assertEqual(result["left"], 45) + self.assertEqual(result["right"], -30) + driver._transport.publish.assert_awaited() + + def test_get_joints_rpc(self): + driver = self._make_driver() + # Pre-populate cached joint data + driver._latest_joints = { + 
"head_joint_positions": [0.1, 0.2, 0.3], + "antennas_joint_positions": [0.5, -0.5], + } + result = asyncio.run(driver.getJoints()) + self.assertEqual(result["status"], "success") + self.assertIn("head", result) + self.assertIn("antennas", result) + + def test_get_joints_no_data(self): + driver = self._make_driver() + result = asyncio.run(driver.getJoints()) + self.assertEqual(result["status"], "error") + + def test_get_imu_rpc(self): + driver = self._make_driver() + driver._latest_imu = { + "accelerometer": [0.1, 0.2, 9.8], + "gyroscope": [0.0, 0.0, 0.0], + "quaternion": [1, 0, 0, 0], + "temperature": 35.2, + } + result = asyncio.run(driver.getImu()) + self.assertEqual(result["status"], "success") + self.assertAlmostEqual(result["temperature"], 35.2) + + def test_get_imu_no_data(self): + driver = self._make_driver() + result = asyncio.run(driver.getImu()) + self.assertEqual(result["status"], "error") + + def test_enable_motors_rpc(self): + driver = self._make_driver() + result = asyncio.run(driver.enableMotors()) + self.assertEqual(result["status"], "success") + self.assertEqual(result["enabled"], "all") + driver._transport.publish.assert_awaited() + + def test_disable_motors_rpc(self): + driver = self._make_driver() + result = asyncio.run( + driver.disableMotors(motor_ids="head_pitch,head_yaw") + ) + self.assertEqual(result["status"], "success") + self.assertEqual(result["disabled"], "head_pitch,head_yaw") + driver._transport.publish.assert_awaited() + + def test_play_move_rpc(self): + driver = self._make_driver() + result = asyncio.run( + driver.playMove("happy", library="emotions") + ) + self.assertEqual(result["status"], "success") + self.assertEqual(result["move"], "happy") + + def test_nod_rpc(self): + driver = self._make_driver() + result = asyncio.run(driver.nod()) + self.assertEqual(result["status"], "success") + self.assertEqual(result["expression"], "nod") + # nod sends multiple publish calls (head_pose animation) + 
self.assertGreater(driver._transport.publish.await_count, 1) + + def test_shake_rpc(self): + driver = self._make_driver() + result = asyncio.run(driver.shake()) + self.assertEqual(result["status"], "success") + self.assertEqual(result["expression"], "shake") + + def test_happy_rpc(self): + driver = self._make_driver() + result = asyncio.run(driver.happy()) + self.assertEqual(result["status"], "success") + self.assertEqual(result["expression"], "happy") + + def test_wake_up_rpc(self): + driver = self._make_driver() + result = asyncio.run(driver.wakeUp()) + self.assertEqual(result["status"], "success") + + def test_sleep_rpc(self): + driver = self._make_driver() + result = asyncio.run(driver.sleep()) + self.assertEqual(result["status"], "success") + + def test_stop_motion_rpc(self): + driver = self._make_driver() + result = asyncio.run(driver.stopMotion()) + self.assertEqual(result["status"], "success") + + def test_daemon_status_rpc(self): + self.mock_transport_mod.api.return_value = {"state": "ready", "version": "1.0"} + driver = self._make_driver() + result = asyncio.run(driver.getDaemonStatus()) + self.assertEqual(result["status"], "success") + self.assertEqual(result["state"], "ready") + + @patch("strands_robots.device_connect.reachy_mini_driver.api") + def test_connect_subscribes(self, mock_api): + # Simulate wireless variant (wireless_version=True) + mock_api.return_value = {"wireless_version": True} + driver = self.ReachyMiniDriver() + mock_transport = AsyncMock() + mock_transport.publish = AsyncMock() + mock_transport.subscribe = AsyncMock() + driver._transport = mock_transport + # connect() creates ZenohLink and subscribes via transport + asyncio.run(driver.connect()) + self.assertEqual(mock_transport.subscribe.await_count, 2) + topics = [call[0][0] for call in mock_transport.subscribe.call_args_list] + self.assertIn("reachy_mini/joint_positions", topics) + self.assertIn("reachy_mini/imu_data", topics) + + def test_disconnect(self): + driver = 
class TestInitDeviceConnect(unittest.TestCase):
    """init_device_connect() wiring, checked against a patched ``DeviceRuntime``."""

    @staticmethod
    def _init_with_mock_runtime(MockRuntime, target, **init_kwargs):
        """Run ``init_device_connect(target, **init_kwargs)`` to completion.

        ``MockRuntime`` (the patched ``DeviceRuntime`` class) is configured to
        return a MagicMock runtime with an awaitable ``run``.

        Returns:
            The mock runtime instance handed out by ``MockRuntime``.
        """
        from strands_robots.device_connect import init_device_connect

        mock_runtime = MagicMock()
        mock_runtime.run = AsyncMock()
        mock_runtime.set_heartbeat_provider = MagicMock()
        MockRuntime.return_value = mock_runtime

        # A dedicated loop is used, as before; closing it in ``finally`` keeps
        # the loop from leaking when init_device_connect raises.
        loop = asyncio.new_event_loop()
        try:
            loop.run_until_complete(init_device_connect(target, **init_kwargs))
        finally:
            loop.close()
        return mock_runtime

    def _runtime_kwarg(self, MockRuntime, name):
        """Return keyword argument ``name`` from the ``DeviceRuntime(...)`` call."""
        # Fail with a clear assertion instead of AttributeError on None when the
        # runtime was never constructed.
        self.assertIsNotNone(MockRuntime.call_args)
        return MockRuntime.call_args.kwargs.get(name)

    @patch("strands_robots.device_connect.DeviceRuntime")
    def test_creates_robot_driver(self, MockRuntime):
        robot = _make_mock_robot()
        self._init_with_mock_runtime(MockRuntime, robot, peer_id="test-1", peer_type="robot")

        # Verify DeviceRuntime was created with a RobotDeviceDriver
        driver = self._runtime_kwarg(MockRuntime, "driver")
        self.assertEqual(type(driver).__name__, "RobotDeviceDriver")
        self.assertEqual(driver._robot, robot)

    @patch("strands_robots.device_connect.DeviceRuntime")
    def test_creates_sim_driver(self, MockRuntime):
        sim = _make_mock_sim()
        self._init_with_mock_runtime(MockRuntime, sim, peer_id="test-sim", peer_type="sim")

        driver = self._runtime_kwarg(MockRuntime, "driver")
        self.assertEqual(type(driver).__name__, "SimulationDeviceDriver")

    @patch("strands_robots.device_connect.DeviceRuntime")
    def test_generates_device_id(self, MockRuntime):
        robot = _make_mock_robot(tool_name="so100")
        self._init_with_mock_runtime(MockRuntime, robot)

        device_id = self._runtime_kwarg(MockRuntime, "device_id")
        self.assertTrue(device_id.startswith("so100-"))

    @patch("strands_robots.device_connect.DeviceRuntime")
    def test_explicit_device_id(self, MockRuntime):
        robot = _make_mock_robot()
        self._init_with_mock_runtime(MockRuntime, robot, peer_id="my-robot-42")

        device_id = self._runtime_kwarg(MockRuntime, "device_id")
        self.assertEqual(device_id, "my-robot-42")

    @patch("strands_robots.device_connect.DeviceRuntime")
    def test_sets_heartbeat_provider(self, MockRuntime):
        robot = _make_mock_robot()
        mock_runtime = self._init_with_mock_runtime(MockRuntime, robot, peer_id="test-hb")

        mock_runtime.set_heartbeat_provider.assert_called_once()
"device_connect_agent_tools.connection", + "device_connect_agent_tools.tools", "device_connect_agent_tools.agent", + "device_connect_agent_tools.adapters", "device_connect_agent_tools.adapters.strands"]: + self._saved_modules[mod] = sys.modules.get(mod) + sys.modules[mod] = mock_aft if mod == "device_connect_agent_tools" else mock_aft_conn + + # NOTE: robot_mesh is intentionally NOT deleted from sys.modules. + # _device_connect_dispatch imports get_connection lazily at call time, so + # the mocked device_connect_agent_tools.connection installed above is + # picked up without a re-import. Deleting + re-importing robot_mesh would + # create a second module object and break sibling test files + # (test_robot_mesh_tool / _security / deep_mesh) that hold a reference to + # the original module — their _resolve_mesh patches would miss. + + def tearDown(self): + for mod, saved in self._saved_modules.items(): + if saved is None: + sys.modules.pop(mod, None) + else: + sys.modules[mod] = saved + + def test_peers_action(self): + from strands_robots.tools.robot_mesh import _device_connect_dispatch + result = _device_connect_dispatch("peers", "", "", "", "mock", 0, 30.0, 30.0) + self.assertEqual(result["status"], "success") + text = result["content"][0]["text"] + self.assertIn("so100-lab-1", text) + self.assertIn("reachy-mini-1", text) + + def test_tell_action(self): + from strands_robots.tools.robot_mesh import _device_connect_dispatch + result = _device_connect_dispatch( + "tell", "so100-lab-1", "pick up cube", "", "groot", 0, 30.0, 30.0, + ) + self.assertEqual(result["status"], "success") + self.mock_conn.invoke.assert_called_once() + call_args = self.mock_conn.invoke.call_args + self.assertEqual(call_args[0][0], "so100-lab-1") + self.assertEqual(call_args[0][1], "execute") + + def test_stop_action(self): + from strands_robots.tools.robot_mesh import _device_connect_dispatch + result = _device_connect_dispatch("stop", "so100-lab-1", "", "", "mock", 0, 30.0, 30.0) + 
class TestReachyTransport(unittest.TestCase):
    """Checks for the helper functions exported by ``reachy_transport``."""

    def test_rpy_to_pose_identity(self):
        from strands_robots.device_connect.reachy_transport import rpy_to_pose
        matrix = rpy_to_pose(0, 0, 0)
        # Zero roll/pitch/yaw: every diagonal entry should be (close to) 1.
        for i in range(4):
            self.assertAlmostEqual(matrix[i][i], 1.0, places=5)

    def test_rpy_to_pose_translation(self):
        from strands_robots.device_connect.reachy_transport import rpy_to_pose
        matrix = rpy_to_pose(0, 0, 0, x_mm=100, y_mm=200, z_mm=300)
        # Millimetre inputs land in the last column in metres (100mm = 0.1m).
        for row, metres in enumerate((0.1, 0.2, 0.3)):
            self.assertAlmostEqual(matrix[row][3], metres, places=5)

    def test_identity_pose(self):
        from strands_robots.device_connect.reachy_transport import identity_pose
        expected = [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]
        self.assertEqual(identity_pose(), expected)

    def test_resolve_host_ip(self):
        from strands_robots.device_connect.reachy_transport import resolve_host
        # A literal IP address should pass through unchanged.
        address = "192.168.1.1"
        self.assertEqual(resolve_host(address), "192.168.1.1")
def _make_mock_robot(tool_name="itest-robot"):
    """Build a MagicMock standing in for a Robot in the integration tests.

    The mock carries an idle ``_task_state`` (a local dataclass with ``status``,
    ``instruction`` and ``step_count``), canned success replies for the task and
    feature methods, and an inner ``robot`` whose ``get_observation()`` returns
    a single joint reading.
    """
    from dataclasses import dataclass
    from enum import Enum

    class TaskStatus(Enum):
        IDLE = "idle"
        RUNNING = "running"

    @dataclass
    class TaskState:
        status: TaskStatus = TaskStatus.IDLE
        instruction: str = ""
        step_count: int = 0

    def _ok(payload):
        return {"status": "success", "content": [payload]}

    fake = MagicMock()
    fake.tool_name_str = tool_name
    fake._task_state = TaskState()

    canned_replies = (
        ("start_task", {"text": "Task started"}),
        ("stop_task", {"text": "Task stopped"}),
        ("get_task_status", {"text": "Idle"}),
        ("get_features", {"json": {}}),
    )
    for method_name, payload in canned_replies:
        getattr(fake, method_name).return_value = _ok(payload)

    inner = MagicMock()
    inner.get_observation.return_value = {"joint1": 0.5}
    fake.robot = inner
    return fake
world.step_count = 0 + sim._world = world + + sim.start_policy.return_value = {"status": "success", "content": [{"text": "Started"}]} + sim.get_state.return_value = {"status": "success", "content": [{"text": "State"}]} + sim.get_features.return_value = {"status": "success", "content": [{"json": {}}]} + sim.step.return_value = {"status": "success", "content": [{"text": "Stepped"}]} + sim.reset.return_value = {"status": "success", "content": [{"text": "Reset"}]} + return sim + + +@pytest.fixture(autouse=True) +def device_connect_env(): + """Set environment for Device Connect messaging. + + Supports both Zenoh and NATS backends. The backend is chosen by the + MESSAGING_BACKEND env-var (default ``nats`` to match the standard + docker-compose-itest.yml setup). + """ + backend = os.getenv("MESSAGING_BACKEND", "nats") + os.environ.setdefault("MESSAGING_BACKEND", backend) + + if backend == "zenoh": + url = os.getenv("ZENOH_CONNECT", "tcp/localhost:7447") + os.environ.setdefault("ZENOH_CONNECT", url) + os.environ.setdefault("MESSAGING_URLS", url) + else: + url = os.getenv("NATS_URL", "nats://localhost:4222") + os.environ.setdefault("NATS_URL", url) + os.environ.setdefault("MESSAGING_URLS", url) + + os.environ.setdefault("DEVICE_CONNECT_ALLOW_INSECURE", "true") + yield + + +class TestRobotDriverRegistration: + """Test that RobotDeviceDriver registers and is discoverable.""" + + async def test_robot_driver_registers(self): + """Create RobotDeviceDriver + DeviceRuntime, verify device is discoverable.""" + from device_connect_edge import DeviceRuntime + from strands_robots.device_connect.robot_driver import RobotDeviceDriver + + robot = _make_mock_robot() + driver = RobotDeviceDriver(robot) + runtime = DeviceRuntime( + driver=driver, + device_id="itest-robot-001", + allow_insecure=True, + ) + + task = asyncio.create_task(runtime.run()) + try: + # Wait for registration + await asyncio.sleep(3) + + from device_connect_agent_tools.connection import connect, disconnect, 
get_connection + await asyncio.to_thread(connect) + try: + conn = get_connection() + devices = await asyncio.to_thread(conn.list_devices, device_type="strands_robot") + device_ids = [d["device_id"] for d in devices] + assert "itest-robot-001" in device_ids, f"Expected itest-robot-001 in {device_ids}" + finally: + await asyncio.to_thread(disconnect) + finally: + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + + async def test_robot_execute_rpc(self): + """Discover robot and invoke execute RPC.""" + from device_connect_edge import DeviceRuntime + from strands_robots.device_connect.robot_driver import RobotDeviceDriver + + robot = _make_mock_robot() + driver = RobotDeviceDriver(robot) + runtime = DeviceRuntime( + driver=driver, + device_id="itest-robot-exec", + allow_insecure=True, + ) + + task = asyncio.create_task(runtime.run()) + try: + await asyncio.sleep(3) + + from device_connect_agent_tools.connection import connect, disconnect, get_connection + await asyncio.to_thread(connect) + try: + conn = get_connection() + result = await asyncio.to_thread( + conn.invoke, + "itest-robot-exec", "execute", + {"instruction": "test move", "policy_provider": "mock", "duration": 5.0}, + ) + assert "result" in result, f"Expected result in {result}" + robot.start_task.assert_called_once() + finally: + await asyncio.to_thread(disconnect) + finally: + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + + async def test_robot_stop_rpc(self): + """Invoke stop RPC on a registered robot.""" + from device_connect_edge import DeviceRuntime + from strands_robots.device_connect.robot_driver import RobotDeviceDriver + + robot = _make_mock_robot() + driver = RobotDeviceDriver(robot) + runtime = DeviceRuntime( + driver=driver, + device_id="itest-robot-stop", + allow_insecure=True, + ) + + task = asyncio.create_task(runtime.run()) + try: + await asyncio.sleep(3) + + from device_connect_agent_tools.connection import 
connect, disconnect, get_connection + await asyncio.to_thread(connect) + try: + conn = get_connection() + result = await asyncio.to_thread(conn.invoke, "itest-robot-stop", "stop") + assert "result" in result + robot.stop_task.assert_called_once() + finally: + await asyncio.to_thread(disconnect) + finally: + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + + +class TestSimDriverRegistration: + """Test that SimulationDeviceDriver registers and is discoverable.""" + + async def test_sim_driver_registers(self): + """Create SimulationDeviceDriver + DeviceRuntime, verify device is discoverable.""" + from device_connect_edge import DeviceRuntime + from strands_robots.device_connect.sim_driver import SimulationDeviceDriver + + sim = _make_mock_sim() + driver = SimulationDeviceDriver(sim) + runtime = DeviceRuntime( + driver=driver, + device_id="itest-sim-001", + allow_insecure=True, + ) + + task = asyncio.create_task(runtime.run()) + try: + await asyncio.sleep(3) + + from device_connect_agent_tools.connection import connect, disconnect, get_connection + await asyncio.to_thread(connect) + try: + conn = get_connection() + devices = await asyncio.to_thread(conn.list_devices, device_type="strands_sim") + device_ids = [d["device_id"] for d in devices] + assert "itest-sim-001" in device_ids + finally: + await asyncio.to_thread(disconnect) + finally: + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + + async def test_sim_step_rpc(self): + """Invoke step RPC on a registered simulation.""" + from device_connect_edge import DeviceRuntime + from strands_robots.device_connect.sim_driver import SimulationDeviceDriver + + sim = _make_mock_sim() + driver = SimulationDeviceDriver(sim) + runtime = DeviceRuntime( + driver=driver, + device_id="itest-sim-step", + allow_insecure=True, + ) + + task = asyncio.create_task(runtime.run()) + try: + await asyncio.sleep(3) + + from device_connect_agent_tools.connection 
import connect, disconnect, get_connection + await asyncio.to_thread(connect) + try: + conn = get_connection() + result = await asyncio.to_thread(conn.invoke, "itest-sim-step", "step", {"n_steps": 10}) + assert "result" in result + sim.step.assert_called_once_with(10) + finally: + await asyncio.to_thread(disconnect) + finally: + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + + +class TestMultipleDevices: + """Test multiple devices registered simultaneously.""" + + async def test_multiple_devices_discoverable(self): + """Register 3 devices and verify all are discoverable.""" + from device_connect_edge import DeviceRuntime + from strands_robots.device_connect.robot_driver import RobotDeviceDriver + from strands_robots.device_connect.sim_driver import SimulationDeviceDriver + + robot1 = _make_mock_robot("robot-a") + robot2 = _make_mock_robot("robot-b") + sim1 = _make_mock_sim("sim-c") + + runtimes = [] + tasks = [] + for device_id, driver_cls, instance in [ + ("itest-multi-a", RobotDeviceDriver, robot1), + ("itest-multi-b", RobotDeviceDriver, robot2), + ("itest-multi-c", SimulationDeviceDriver, sim1), + ]: + driver = driver_cls(instance) + runtime = DeviceRuntime( + driver=driver, + device_id=device_id, + allow_insecure=True, + ) + runtimes.append(runtime) + tasks.append(asyncio.create_task(runtime.run())) + + try: + await asyncio.sleep(5) + + from device_connect_agent_tools.connection import connect, disconnect, get_connection + await asyncio.to_thread(connect) + try: + conn = get_connection() + devices = await asyncio.to_thread(conn.list_devices) + device_ids = {d["device_id"] for d in devices} + assert "itest-multi-a" in device_ids + assert "itest-multi-b" in device_ids + assert "itest-multi-c" in device_ids + finally: + await asyncio.to_thread(disconnect) + finally: + for t in tasks: + t.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + + +class TestInitDeviceConnectE2E: + """End-to-end test of 
init_device_connect().""" + + async def test_init_device_connect_e2e(self): + """init_device_connect() -> device registers -> discoverable -> invocable.""" + from strands_robots.device_connect import init_device_connect + + robot = _make_mock_robot("e2e-robot") + runtime = await init_device_connect(robot, peer_id="itest-e2e-robot") + + try: + # Wait for registration + await asyncio.sleep(3) + + from device_connect_agent_tools.connection import connect, disconnect, get_connection + await asyncio.to_thread(connect) + try: + conn = get_connection() + + # Discoverable + devices = await asyncio.to_thread(conn.list_devices, device_type="strands_robot") + device_ids = [d["device_id"] for d in devices] + assert "itest-e2e-robot" in device_ids + + # Invocable + result = await asyncio.to_thread(conn.invoke, "itest-e2e-robot", "getStatus") + assert "result" in result + finally: + await asyncio.to_thread(disconnect) + finally: + await runtime.stop() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From d1ac430a67848d2051d8aaa1218a3ac05177bcd0 Mon Sep 17 00:00:00 2001 From: kavya-chennoju Date: Mon, 8 Jun 2026 13:49:13 -0700 Subject: [PATCH 2/6] chore: make device-connect an optional extra; use real clone URL in GUIDE Addresses review feedback on #370: - pyproject: move device-connect-{edge,agent-tools} out of core dependencies into a new [device-connect] extra, and include it in [all] - setup.sh: install [sim,device-connect] so the D2D demo still pulls in DC - README: document the new device-connect extra in the install table - GUIDE: clone strands-labs/robots.git instead of the fork feature branch --- README.md | 1 + pyproject.toml | 11 ++++++++--- strands_robots/device_connect/GUIDE.md | 2 +- strands_robots/device_connect/setup.sh | 2 +- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 35793754..7cfe7607 100644 --- a/README.md +++ b/README.md @@ -105,6 +105,7 @@ extras you need: | `groot-service` | pyzmq, msgpack | 
NVIDIA GR00T inference client | | `mesh` | eclipse-zenoh, json5 | Peer-to-peer robot mesh | | `mesh-iot` | awsiotsdk, awscrt, boto3 | AWS IoT Core mesh transport for fleets | +| `device-connect` | device-connect-edge, device-connect-agent-tools | Device-aware networking — discovery, RPC, events, safety (falls back to the built-in mesh if absent) | | `benchmark-libero` | libero | LIBERO benchmark evaluation | | `all` | everything above | Kitchen sink | diff --git a/pyproject.toml b/pyproject.toml index 4a8619e9..523d6b5f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,9 +38,6 @@ dependencies = [ "numpy>=1.21.0,<3.0.0", "opencv-python-headless>=4.5.0,<5.0.0", "Pillow>=8.0.0,<12.0.0", - # Device Connect — primary networking layer (discovery, RPC, events, safety) - "device-connect-edge>=0.2.0", - "device-connect-agent-tools>=0.1.0", ] [project.optional-dependencies] @@ -80,12 +77,20 @@ mesh-iot = [ "awscrt>=0.20.0,<1.0.0", "boto3>=1.34.0,<2.0.0", ] +# Device Connect — device-aware networking layer (discovery, RPC, events, +# safety). The primary transport in server mode; when this extra is not +# installed, robot_mesh() falls back to the built-in Zenoh mesh. +device-connect = [ + "device-connect-edge>=0.2.0", + "device-connect-agent-tools>=0.1.0", +] all = [ "strands-robots[groot-service]", "strands-robots[lerobot]", "strands-robots[sim-mujoco]", "strands-robots[mesh]", "strands-robots[mesh-iot]", + "strands-robots[device-connect]", ] dev = [ "pytest>=6.0,<9.0.0", diff --git a/strands_robots/device_connect/GUIDE.md b/strands_robots/device_connect/GUIDE.md index b09d0265..230757ef 100644 --- a/strands_robots/device_connect/GUIDE.md +++ b/strands_robots/device_connect/GUIDE.md @@ -76,7 +76,7 @@ No Docker needed. No env vars. Devices discover each other directly on the LAN v > `setup.sh` installs `uv`, Python 3.12, creates a venv, and installs all dependencies. 
```bash -git clone --branch feat/device-connect-on-main https://github.com/kavya-chennoju/robots.git +git clone https://github.com/strands-labs/robots.git cd robots ./strands_robots/device_connect/setup.sh source .venv/bin/activate diff --git a/strands_robots/device_connect/setup.sh b/strands_robots/device_connect/setup.sh index 4f735654..cb33ce81 100755 --- a/strands_robots/device_connect/setup.sh +++ b/strands_robots/device_connect/setup.sh @@ -42,7 +42,7 @@ fi # shellcheck disable=SC1091 source "$REPO_ROOT/$VENV_DIR/bin/activate" -uv pip install -e "$REPO_ROOT[sim]" +uv pip install -e "$REPO_ROOT[sim,device-connect]" echo "" echo "============================================================" From 310311c0e8849b9bcc9edc722da57034fa5cc1ce Mon Sep 17 00:00:00 2001 From: kavya-chennoju Date: Mon, 8 Jun 2026 16:14:57 -0700 Subject: [PATCH 3/6] fix: resolve CodeQL findings in Device Connect integration Clears the 29 code-scanning alerts on the new DC code: - __init__: add the 6 lazy DC exports to the TYPE_CHECKING block so __all__ names are statically defined (py/undefined-export) - reachy_transport + integration tests: document intentional empty excepts (py/empty-except) - drop unused imports (json, math) and unused 'result' locals in driver tests (py/unused-import, py/unused-local-variable) - all_robots test: probe optional deps via importlib.util.find_spec instead of side-effecting imports - robot_mesh: replace the global _dc_connected flag with a _dc_state cell (py/unused-global-variable) --- strands_robots/__init__.py | 8 ++++++++ .../device_connect/reachy_mini_driver.py | 1 - .../device_connect/reachy_transport.py | 6 +++--- strands_robots/tools/robot_mesh.py | 7 +++---- tests/test_device_connect_all_robots.py | 10 +++------- tests/test_device_connect_drivers.py | 19 +++++++++---------- tests/test_device_connect_integration.py | 10 +++++----- 7 files changed, 31 insertions(+), 30 deletions(-) diff --git a/strands_robots/__init__.py b/strands_robots/__init__.py 
index 1027d77f..8850888e 100644 --- a/strands_robots/__init__.py +++ b/strands_robots/__init__.py @@ -30,6 +30,13 @@ # the lazy attributes below (the runtime __getattr__ resolves them to Any # from the static analyzer's perspective). PEP 562. if TYPE_CHECKING: + from strands_robots.device_connect import ( + ReachyMiniDriver, + RobotDeviceDriver, + SimulationDeviceDriver, + init_device_connect, + init_device_connect_sync, + ) from strands_robots.policies.groot import Gr00tPolicy from strands_robots.registry import list_robots from strands_robots.robot import Robot @@ -48,6 +55,7 @@ from strands_robots.tools.lerobot_camera import lerobot_camera from strands_robots.tools.lerobot_teleoperate import lerobot_teleoperate from strands_robots.tools.pose_tool import pose_tool + from strands_robots.tools.robot_mesh import robot_mesh from strands_robots.tools.serial_tool import serial_tool # ------------------------------------------------------------------ diff --git a/strands_robots/device_connect/reachy_mini_driver.py b/strands_robots/device_connect/reachy_mini_driver.py index d3a46398..3196d215 100644 --- a/strands_robots/device_connect/reachy_mini_driver.py +++ b/strands_robots/device_connect/reachy_mini_driver.py @@ -8,7 +8,6 @@ """ import asyncio -import json import logging import math from typing import Optional diff --git a/strands_robots/device_connect/reachy_transport.py b/strands_robots/device_connect/reachy_transport.py index 402bac4b..357ac813 100644 --- a/strands_robots/device_connect/reachy_transport.py +++ b/strands_robots/device_connect/reachy_transport.py @@ -94,13 +94,13 @@ async def _on_joints(data: bytes, _reply=None): try: on_joints(json.loads(data.decode())) except Exception: - pass + pass # drop malformed/partial frame; keep the subscription alive async def _on_imu(data: bytes, _reply=None): try: on_imu(json.loads(data.decode())) except Exception: - pass + pass # drop malformed/partial frame; keep the subscription alive await 
self._transport.subscribe(f"{self._prefix}/joint_positions", _on_joints) await self._transport.subscribe(f"{self._prefix}/imu_data", _on_imu) @@ -146,7 +146,7 @@ async def _read_loop(self, on_joints: Callable, on_imu: Callable) -> None: elif t == "imu_data": on_imu(msg) except Exception: - pass + pass # skip malformed frame; keep reading async def stop(self) -> None: if self._read_task: diff --git a/strands_robots/tools/robot_mesh.py b/strands_robots/tools/robot_mesh.py index f444003d..bdac3a1c 100644 --- a/strands_robots/tools/robot_mesh.py +++ b/strands_robots/tools/robot_mesh.py @@ -256,7 +256,7 @@ def _resolve_mesh(target: str) -> Any | None: # audit. When DC is unavailable or has discovered no devices the helpers return # None and robot_mesh() falls through to the built-in mesh path. -_dc_connected = False +_dc_state = {"connected": False} class _DCResult(dict): @@ -271,8 +271,7 @@ def __str__(self) -> str: def _dc_ensure_connected() -> None: """Establish the Device Connect agent-side connection (idempotent).""" - global _dc_connected - if _dc_connected: + if _dc_state["connected"]: return os.environ.setdefault("MESSAGING_BACKEND", "zenoh") os.environ.setdefault("DEVICE_CONNECT_ALLOW_INSECURE", "true") @@ -282,7 +281,7 @@ def _dc_ensure_connected() -> None: get_connection() except Exception: connect() - _dc_connected = True + _dc_state["connected"] = True def _try_device_connect( diff --git a/tests/test_device_connect_all_robots.py b/tests/test_device_connect_all_robots.py index 72f5d0c1..34e330ab 100644 --- a/tests/test_device_connect_all_robots.py +++ b/tests/test_device_connect_all_robots.py @@ -10,6 +10,7 @@ """ import asyncio +import importlib.util import json import pathlib import sys @@ -570,9 +571,7 @@ def _passthrough_tool(*args, **kwargs): return lambda fn: fn -try: - import strands # noqa: F401 — use the real package when installed -except Exception: +if importlib.util.find_spec("strands") is None: _m = MagicMock() _m.tool = _passthrough_tool 
_types_tools = MagicMock() @@ -581,10 +580,7 @@ def _passthrough_tool(*args, **kwargs): sys.modules["strands.types"] = MagicMock() sys.modules["strands.types.tools"] = _types_tools -try: - import device_connect_agent_tools # noqa: F401 - import device_connect_agent_tools.connection # noqa: F401 -except Exception: +if importlib.util.find_spec("device_connect_agent_tools") is None: sys.modules.setdefault("device_connect_agent_tools", MagicMock()) sys.modules.setdefault("device_connect_agent_tools.connection", MagicMock()) diff --git a/tests/test_device_connect_drivers.py b/tests/test_device_connect_drivers.py index bf7da286..860de719 100644 --- a/tests/test_device_connect_drivers.py +++ b/tests/test_device_connect_drivers.py @@ -8,7 +8,6 @@ import asyncio import json -import math import sys import unittest from dataclasses import dataclass @@ -352,25 +351,25 @@ def test_stop_rpc(self): def test_get_status_rpc(self): sim = _make_mock_sim() driver = SimulationDeviceDriver(sim) - result = asyncio.run(driver.getStatus()) + asyncio.run(driver.getStatus()) sim.get_state.assert_called_once() def test_get_features_rpc(self): sim = _make_mock_sim() driver = SimulationDeviceDriver(sim) - result = asyncio.run(driver.getFeatures()) + asyncio.run(driver.getFeatures()) sim.get_features.assert_called_once() def test_step_rpc(self): sim = _make_mock_sim() driver = SimulationDeviceDriver(sim) - result = asyncio.run(driver.step(10)) + asyncio.run(driver.step(10)) sim.step.assert_called_once_with(10) def test_reset_rpc(self): sim = _make_mock_sim() driver = SimulationDeviceDriver(sim) - result = asyncio.run(driver.reset()) + asyncio.run(driver.reset()) sim.reset.assert_called_once() def test_emergency_stop_handler(self): @@ -616,7 +615,7 @@ def test_creates_robot_driver(self, MockRuntime): robot = _make_mock_robot() loop = asyncio.new_event_loop() - result = loop.run_until_complete(init_device_connect(robot, peer_id="test-1", peer_type="robot")) + 
loop.run_until_complete(init_device_connect(robot, peer_id="test-1", peer_type="robot")) loop.close() # Verify DeviceRuntime was created with a RobotDeviceDriver @@ -636,7 +635,7 @@ def test_creates_sim_driver(self, MockRuntime): sim = _make_mock_sim() loop = asyncio.new_event_loop() - result = loop.run_until_complete(init_device_connect(sim, peer_id="test-sim", peer_type="sim")) + loop.run_until_complete(init_device_connect(sim, peer_id="test-sim", peer_type="sim")) loop.close() call_kwargs = MockRuntime.call_args @@ -653,7 +652,7 @@ def test_generates_device_id(self, MockRuntime): robot = _make_mock_robot(tool_name="so100") loop = asyncio.new_event_loop() - result = loop.run_until_complete(init_device_connect(robot)) + loop.run_until_complete(init_device_connect(robot)) loop.close() call_kwargs = MockRuntime.call_args @@ -670,7 +669,7 @@ def test_explicit_device_id(self, MockRuntime): robot = _make_mock_robot() loop = asyncio.new_event_loop() - result = loop.run_until_complete(init_device_connect(robot, peer_id="my-robot-42")) + loop.run_until_complete(init_device_connect(robot, peer_id="my-robot-42")) loop.close() call_kwargs = MockRuntime.call_args @@ -687,7 +686,7 @@ def test_sets_heartbeat_provider(self, MockRuntime): robot = _make_mock_robot() loop = asyncio.new_event_loop() - result = loop.run_until_complete(init_device_connect(robot, peer_id="test-hb")) + loop.run_until_complete(init_device_connect(robot, peer_id="test-hb")) loop.close() mock_runtime.set_heartbeat_provider.assert_called_once() diff --git a/tests/test_device_connect_integration.py b/tests/test_device_connect_integration.py index 2d821ffc..58741a5e 100644 --- a/tests/test_device_connect_integration.py +++ b/tests/test_device_connect_integration.py @@ -139,7 +139,7 @@ async def test_robot_driver_registers(self): try: await task except (asyncio.CancelledError, Exception): - pass + pass # teardown: listener task was cancelled, ignore its outcome async def test_robot_execute_rpc(self): 
"""Discover robot and invoke execute RPC.""" @@ -176,7 +176,7 @@ async def test_robot_execute_rpc(self): try: await task except (asyncio.CancelledError, Exception): - pass + pass # teardown: listener task was cancelled, ignore its outcome async def test_robot_stop_rpc(self): """Invoke stop RPC on a registered robot.""" @@ -209,7 +209,7 @@ async def test_robot_stop_rpc(self): try: await task except (asyncio.CancelledError, Exception): - pass + pass # teardown: listener task was cancelled, ignore its outcome class TestSimDriverRegistration: @@ -246,7 +246,7 @@ async def test_sim_driver_registers(self): try: await task except (asyncio.CancelledError, Exception): - pass + pass # teardown: listener task was cancelled, ignore its outcome async def test_sim_step_rpc(self): """Invoke step RPC on a registered simulation.""" @@ -279,7 +279,7 @@ async def test_sim_step_rpc(self): try: await task except (asyncio.CancelledError, Exception): - pass + pass # teardown: listener task was cancelled, ignore its outcome class TestMultipleDevices: From 2a7abd5874fc8e4f9c3c522dd716adfeac0047e9 Mon Sep 17 00:00:00 2001 From: strands-agent <217235299+strands-agent@users.noreply.github.com> Date: Tue, 9 Jun 2026 15:33:53 -0400 Subject: [PATCH 4/6] feat(mesh): add device-native RPC path to robot_mesh (action='rpc') robot_mesh could discover a Device Connect peer's advertised functions (e.g. the Reachy's nod/look/playMove) but tell/send rejected them: every command was forced through mesh.security.validate_command, whose ALLOWED_ACTIONS allowlist describes the SO-100/SO-101 policy-dispatch surface, not an arbitrary device's own function set. This blocked the Device Connect PRs from driving non-policy devices. Changes: - security.validate_device_rpc(function, params): a dedicated validator for device-native RPC. 
Enforces an identifier charset on the function name (no shell metacharacters/dots/slashes/control bytes), bounds the params object (identifier-safe keys, JSON-serialisable, <=64 KB), but deliberately does NOT apply ALLOWED_ACTIONS. - robot_mesh action='rpc' (+ function= param): validates via validate_device_rpc, then conn.invoke(target, function, params) directly. Inherits the existing rate-limit + audit machinery; falls back with a clear error when Device Connect is unavailable. - tests/test_device_rpc_validation.py: 18 unit tests covering accept, not-gated-by-allowlist, charset/size/serialisation rejections, and param-copy semantics. Verified live against a Reachy on the beta tenant (nod -> success) and that a malicious function name (e.g. 'rm -rf /') is rejected. lint clean (ruff check + format), mypy clean (only pre-existing device_connect_agent_tools import-untyped note), 94 security tests pass. --- strands_robots/mesh/security.py | 87 +++++++++++++++++++++++++ strands_robots/tools/robot_mesh.py | 98 +++++++++++++++++++++++++---- tests/test_device_rpc_validation.py | 67 ++++++++++++++++++++ 3 files changed, 239 insertions(+), 13 deletions(-) create mode 100644 tests/test_device_rpc_validation.py diff --git a/strands_robots/mesh/security.py b/strands_robots/mesh/security.py index 9cd0edb3..98a32d76 100644 --- a/strands_robots/mesh/security.py +++ b/strands_robots/mesh/security.py @@ -208,6 +208,22 @@ } ) +#: Device Connect native-RPC function names (e.g. the Reachy's ``nod`` / +#: ``look`` / ``playMove``). These are device-defined, NOT members of +#: :data:`ALLOWED_ACTIONS` -- the policy-robot action allowlist does not +#: apply to a device's own advertised function surface. We still bound the +#: name to a conservative identifier charset so a function name cannot carry +#: control bytes / shell metacharacters into the device runtime or audit log. +_DC_RPC_FUNC_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + +#: Max length of a Device Connect RPC function name. 
+MAX_DC_RPC_FUNC_LEN: int = 64 + +#: Max JSON-encoded byte size of a Device Connect RPC params object. Keeps a +#: native-function call from becoming a DoS vector, mirroring +#: :data:`MAX_WORLD_UPDATE_BYTES`. +MAX_DC_RPC_PARAMS_BYTES: int = 64 * 1024 + #: Default allowlist for VLA policy server targets (loopback only). _DEFAULT_POLICY_HOSTS: frozenset[str] = frozenset({"localhost", "127.0.0.1", "::1"}) @@ -985,6 +1001,74 @@ def validate_command(cmd: dict[str, Any]) -> dict[str, Any]: return out +def validate_device_rpc(function: str, params: Any = None) -> tuple[str, dict[str, Any]]: + """Validate a Device Connect *native* RPC call and return a sanitised copy. + + This is the validation path for device-defined functions invoked directly + over Device Connect (``conn.invoke(target, function, params)``) -- e.g. the + Reachy's ``nod`` / ``look`` / ``playMove``. Unlike :func:`validate_command`, + it deliberately does NOT enforce :data:`ALLOWED_ACTIONS`: that allowlist + describes the SO-100/SO-101 policy-robot dispatch surface, not an arbitrary + device's advertised function set. A Reachy legitimately exposes ``nod``; + rejecting it because it is not in the policy allowlist is the bug this + function closes. + + What it DOES enforce (defence-in-depth, since the function name and params + flow into the device runtime, RPC subjects, and audit logs): + + * ``function``: non-empty ``str`` matching :data:`_DC_RPC_FUNC_RE` + (``[A-Za-z_][A-Za-z0-9_]*``), at most :data:`MAX_DC_RPC_FUNC_LEN` chars. + No dots / slashes / whitespace / control bytes / shell metacharacters. + * ``params``: ``None`` or a JSON object (``dict``) whose keys are + identifier-safe strings and whose JSON-encoded size is bounded by + :data:`MAX_DC_RPC_PARAMS_BYTES`. Values are left opaque (the device + contract defines them) but the whole object must be JSON-serialisable. + + Returns ``(function, params_dict)`` -- ``params_dict`` is ``{}`` when no + params were supplied. 
Raises :class:`ValidationError` on any violation. + """ + if not isinstance(function, str) or not function: + raise ValidationError("device_rpc requires a non-empty function name (string)") + if len(function) > MAX_DC_RPC_FUNC_LEN: + raise ValidationError( + f"device_rpc function name length {len(function)} > MAX_DC_RPC_FUNC_LEN ({MAX_DC_RPC_FUNC_LEN})." + ) + if not _DC_RPC_FUNC_RE.fullmatch(function): + raise ValidationError( + f"device_rpc function={function!r} must match [A-Za-z_][A-Za-z0-9_]* " + "(no dots, slashes, whitespace, control chars, or shell metacharacters)." + ) + + if params is None: + return function, {} + if not isinstance(params, dict): + raise ValidationError("device_rpc params must be a JSON object (dict) or null") + + # Keys must be identifier-safe; the device defines what values mean, so we + # leave them opaque but require the whole object to be JSON-serialisable + # and size-bounded. + for key in params: + if not isinstance(key, str) or not key: + raise ValidationError("device_rpc params keys must be non-empty strings") + if len(key) > MAX_DC_RPC_FUNC_LEN: + raise ValidationError( + f"device_rpc params key length {len(key)} > MAX_DC_RPC_FUNC_LEN ({MAX_DC_RPC_FUNC_LEN})." + ) + if not _DC_RPC_FUNC_RE.fullmatch(key): + raise ValidationError( + f"device_rpc params key {key!r} must match [A-Za-z_][A-Za-z0-9_]* " + "(no dots, slashes, whitespace, control chars, or shell metacharacters)." + ) + try: + encoded = json.dumps(params) + except (TypeError, ValueError) as exc: + raise ValidationError(f"device_rpc params is not JSON-serialisable: {exc}") from exc + if len(encoded.encode("utf-8")) > MAX_DC_RPC_PARAMS_BYTES: + raise ValidationError(f"device_rpc params encoded size > MAX_DC_RPC_PARAMS_BYTES ({MAX_DC_RPC_PARAMS_BYTES}).") + + return function, dict(params) + + def validate_input_frame(action: Any) -> dict[str, float]: """Validate and sanitise a teleop input frame, returning a clean copy. 
@@ -1074,6 +1158,9 @@ def validate_input_frame(action: Any) -> dict[str, float]: "is_safe_policy_type", "is_safe_server_address", "validate_command", + "validate_device_rpc", "validate_input_frame", + "MAX_DC_RPC_FUNC_LEN", + "MAX_DC_RPC_PARAMS_BYTES", "LockoutError", ] diff --git a/strands_robots/tools/robot_mesh.py b/strands_robots/tools/robot_mesh.py index bdac3a1c..12b8bc6d 100644 --- a/strands_robots/tools/robot_mesh.py +++ b/strands_robots/tools/robot_mesh.py @@ -54,6 +54,7 @@ "send": (30, 60.0), "broadcast": (10, 60.0), "stop": (20, 60.0), + "rpc": (30, 60.0), "emergency_stop": (3, 60.0), } _RATE_HISTORY: dict[str, collections.deque[float]] = {} @@ -285,8 +286,15 @@ def _dc_ensure_connected() -> None: def _try_device_connect( - action: str, target: str, instruction: str, command: str, - policy_provider: str, policy_port: int, duration: float, timeout: float, + action: str, + target: str, + instruction: str, + command: str, + policy_provider: str, + policy_port: int, + duration: float, + timeout: float, + function: str = "", ) -> dict[str, Any] | None: """Dispatch *action* through Device Connect, or return None to fall back. @@ -313,14 +321,28 @@ def _try_device_connect( if not isinstance(devices, (list, tuple)) or not devices: return None return _device_connect_dispatch( - action, target, instruction, command, - policy_provider, policy_port, duration, timeout, + action, + target, + instruction, + command, + policy_provider, + policy_port, + duration, + timeout, + function, ) def _device_connect_dispatch( - action: str, target: str, instruction: str, command: str, - policy_provider: str, policy_port: int, duration: float, timeout: float, + action: str, + target: str, + instruction: str, + command: str, + policy_provider: str, + policy_port: int, + duration: float, + timeout: float, + function: str = "", ) -> dict[str, Any] | None: """Render a robot_mesh action through Device Connect (dev-compatible API). 
@@ -340,8 +362,7 @@ def _device_connect_dispatch( ) for d in devices: dtype = d.get("device_type", "?") - icon = {"strands_robot": "robot", "strands_sim": "sim", - "reachy_mini": "reachy"}.get(dtype, dtype) + icon = {"strands_robot": "robot", "strands_sim": "sim", "reachy_mini": "reachy"}.get(dtype, dtype) status = d.get("status", {}) avail = status.get("availability", "?") if isinstance(status, dict) else "?" text += f" [{icon}] {d['device_id']} — {avail}\n" @@ -391,6 +412,35 @@ def _device_connect_dispatch( _audit_tool_action(action, target, True, f"action={func}") return _DCResult(_ok(f"{target}:\n{json.dumps(r, indent=2, default=str)[:2000]}")) + if action == "rpc": + # Device-native RPC (e.g. Reachy nod/look/playMove). Validated via + # security.validate_device_rpc (charset + bounded params) WITHOUT + # the policy-action allowlist, then invoked directly on the device. + if not target: + return _DCResult(_err("rpc requires target")) + if not function: + return _DCResult(_err("rpc requires function (the device-native function name)")) + rpc_params: dict[str, Any] = {} + if command: + try: + parsed = json.loads(command) + except json.JSONDecodeError as exc: + return _DCResult(_err(f"rpc params (command) is not valid JSON: {exc}")) + if not isinstance(parsed, dict): + return _DCResult(_err("rpc params (command) must decode to a JSON object (dict)")) + rpc_params = parsed + try: + func_name, rpc_params = _security.validate_device_rpc(function, rpc_params) + except _security.ValidationError as exc: + _audit_tool_action(action, target, False, f"validation: {exc}") + return _DCResult(_err(f"rpc rejected: {exc}")) + result = conn.invoke(target, func_name, rpc_params, timeout=timeout) + r = result.get("result", result) if isinstance(result, dict) else result + _audit_tool_action(action, target, True, f"function={func_name}") + return _DCResult( + _ok(f"{target}.{func_name}({rpc_params}) ->\n{json.dumps(r, indent=2, default=str)[:2000]}") + ) + if action == "stop": if 
not target: return _DCResult(_err("stop requires target")) @@ -448,13 +498,19 @@ def robot_mesh( timeout: float = 30.0, name: str = "", limit: int = 50, + function: str = "", ) -> dict[str, Any]: """Coordinate every robot, sim, and agent on the local Zenoh mesh. Args: action: One of ``peers`` / ``status`` / ``tell`` / ``send`` / - ``broadcast`` / ``stop`` / ``emergency_stop`` / ``subscribe`` / - ``unsubscribe`` / ``watch`` / ``inbox``. + ``rpc`` / ``broadcast`` / ``stop`` / ``emergency_stop`` / + ``subscribe`` / ``unsubscribe`` / ``watch`` / ``inbox``. + ``rpc`` calls a device's NATIVE Device Connect function (e.g. + the Reachy's ``nod`` / ``look`` / ``playMove``) directly, + bypassing the policy-action allowlist that ``tell`` / ``send`` + enforce. Pass the function name in ``function`` and any kwargs + as a JSON object in ``command``. target: Peer id (for ``tell`` / ``send`` / ``stop`` / ``watch``) or Zenoh topic pattern (for ``subscribe``). instruction: Natural-language instruction for ``tell``. @@ -465,6 +521,7 @@ def robot_mesh( timeout: Response timeout for RPC actions (seconds). name: Optional subscription name for ``subscribe`` / ``inbox``. limit: Max messages returned by ``inbox`` (default: 50). + function: Device-native function name for ``rpc`` (e.g. ``nod``). Returns: A Strands tool response dict with status and a single text block. @@ -602,8 +659,15 @@ def robot_mesh( # _try_device_connect returns None when DC is unavailable or has discovered # no devices, in which case we fall through to the built-in mesh below. 
_dc_result = _try_device_connect( - action, target, instruction, command, - policy_provider, policy_port, duration, timeout, + action, + target, + instruction, + command, + policy_provider, + policy_port, + duration, + timeout, + function, ) if _dc_result is not None: return _dc_result @@ -806,8 +870,16 @@ def robot_mesh( mesh.unsubscribe(sub_name) return _ok(f"[unsub] unsubscribed from '{sub_name}'") + if action == "rpc": + return _err( + "rpc (device-native function call) requires Device Connect, which is " + "unavailable or has discovered no devices in this context. The built-in " + "Zenoh mesh has no equivalent. Ensure the agent connected via " + "device_connect_agent_tools.connect() and the target is online." + ) + return _err( - f"unknown action: {action!r}. Valid: peers, status, tell, send, " + f"unknown action: {action!r}. Valid: peers, status, tell, send, rpc, " "broadcast, stop, emergency_stop, subscribe, unsubscribe, watch, inbox." ) diff --git a/tests/test_device_rpc_validation.py b/tests/test_device_rpc_validation.py new file mode 100644 index 00000000..f6653080 --- /dev/null +++ b/tests/test_device_rpc_validation.py @@ -0,0 +1,67 @@ +"""Regression tests for device-native RPC validation (robot_mesh action='rpc'). + +Closes the gap where robot_mesh could discover a device's functions (e.g. the +Reachy's nod/look/playMove) but tell/send rejected them via the policy-action +ALLOWED_ACTIONS allowlist. The rpc path validates name+params WITHOUT that +allowlist, then invokes the device function directly. 
+""" + +import pytest + +from strands_robots.mesh import security + + +class TestValidateDeviceRpc: + def test_accepts_bare_function(self): + assert security.validate_device_rpc("nod") == ("nod", {}) + + def test_accepts_function_with_params(self): + assert security.validate_device_rpc("look", {"yaw": 30, "pitch": -15}) == ( + "look", + {"yaw": 30, "pitch": -15}, + ) + + def test_none_params_become_empty_dict(self): + assert security.validate_device_rpc("happy", None) == ("happy", {}) + + def test_not_gated_by_policy_allowlist(self): + # The whole point: 'nod' is NOT in ALLOWED_ACTIONS but is a valid RPC. + assert "nod" not in security.ALLOWED_ACTIONS + func, _ = security.validate_device_rpc("nod") + assert func == "nod" + + @pytest.mark.parametrize("bad", ["rm -rf /", "nod;reboot", "../escape", "a.b", "has space", "", "n@d"]) + def test_rejects_unsafe_function_names(self, bad): + with pytest.raises(security.ValidationError): + security.validate_device_rpc(bad) + + def test_rejects_overlong_function_name(self): + with pytest.raises(security.ValidationError): + security.validate_device_rpc("a" * (security.MAX_DC_RPC_FUNC_LEN + 1)) + + def test_rejects_non_string_function(self): + with pytest.raises(security.ValidationError): + security.validate_device_rpc(123) # type: ignore[arg-type] + + def test_rejects_non_dict_params(self): + with pytest.raises(security.ValidationError): + security.validate_device_rpc("nod", ["not", "a", "dict"]) # type: ignore[arg-type] + + def test_rejects_unsafe_param_keys(self): + with pytest.raises(security.ValidationError): + security.validate_device_rpc("look", {"bad key": 1}) + + def test_rejects_oversize_params(self): + big = {"k": "x" * (security.MAX_DC_RPC_PARAMS_BYTES + 10)} + with pytest.raises(security.ValidationError): + security.validate_device_rpc("playMove", big) + + def test_rejects_non_serialisable_params(self): + with pytest.raises(security.ValidationError): + security.validate_device_rpc("look", {"obj": object()}) + + def 
test_returns_copy_of_params(self): + src = {"yaw": 1} + _, out = security.validate_device_rpc("look", src) + out["yaw"] = 99 + assert src["yaw"] == 1 # original untouched From fd66308a9b4145638c7fca024e0cb9eb89980f68 Mon Sep 17 00:00:00 2001 From: strands-agent <217235299+strands-agent@users.noreply.github.com> Date: Thu, 11 Jun 2026 16:57:51 -0400 Subject: [PATCH 5/6] fix(device-connect): security hardening for the Device Connect integration Harden the Device Connect integration across the agent dispatch path, the device drivers, and the Reachy transport. All changes are defence-in-depth and preserve existing behaviour on the trusted-path defaults. - robot_mesh broadcast: dispatch the validated, operator-approved command through Device Connect instead of re-parsing the raw caller string, so the executed action cannot differ from the approved one. - robot_mesh rpc: route device-native rpc through the human-in-the-loop approval gate (surfacing the function name) so device-native functions are not invoked without explicit approval. - drivers: validate policy_provider against the existing allowlist in both robot and sim execute() before starting a policy, preventing inference from being pointed at an arbitrary endpoint. - drivers: add per-call caller authorization on state-mutating RPCs (execute/stop/step/reset) and validate the source of emergencyStop against an allowlist. Env-driven (DEVICE_CONNECT_RPC_ALLOW / DEVICE_CONNECT_ESTOP_ALLOW); permissive with a warning when unset, fail-closed when set. Uses get_rpc_source_device() from device-connect-edge. - reachy: validate playMove move_name against a strict charset before interpolating it into the daemon REST path (prevents path traversal / query injection). - reachy: support an optional bearer token (REACHY_DAEMON_TOKEN) on the daemon WebSocket and REST interfaces and warn loudly when the link is unauthenticated. - transport: make Device Connect secure by default. 
allow_insecure now defaults to False and must be opted into explicitly (arg or DEVICE_CONNECT_ALLOW_INSECURE); a warning is logged whenever insecure mode is active. Removed the agent-side setdefault that forced insecure mode process-wide. - GUIDE.md: document secure-by-default and the explicit insecure opt-in for local D2D trials. - CI: when a PR touches the Device Connect integration, pin device-connect-edge / agent-tools to source via UV_OVERRIDE so the matching framework version is exercised. - tests: add tests/test_device_connect_hardening.py covering all of the above. --- .github/workflows/test-lint.yml | 42 ++ strands_robots/device_connect/GUIDE.md | 15 +- strands_robots/device_connect/__init__.py | 20 +- strands_robots/device_connect/_authz.py | 100 +++++ .../device_connect/reachy_mini_driver.py | 8 + .../device_connect/reachy_transport.py | 54 ++- strands_robots/device_connect/robot_driver.py | 36 +- strands_robots/device_connect/sim_driver.py | 42 +- strands_robots/tools/robot_mesh.py | 37 +- tests/test_device_connect_hardening.py | 398 ++++++++++++++++++ 10 files changed, 731 insertions(+), 21 deletions(-) create mode 100644 strands_robots/device_connect/_authz.py create mode 100644 tests/test_device_connect_hardening.py diff --git a/.github/workflows/test-lint.yml b/.github/workflows/test-lint.yml index dfc9380f..457bd86e 100644 --- a/.github/workflows/test-lint.yml +++ b/.github/workflows/test-lint.yml @@ -18,6 +18,8 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: ref: ${{ inputs.ref }} + # Need history so we can diff the changed paths against the base. + fetch-depth: 0 persist-credentials: false - name: Set up Python @@ -31,6 +33,46 @@ jobs: sudo apt-get update sudo apt-get install -y libosmesa6-dev ffmpeg + # The Device Connect integration (strands_robots/device_connect) can + # depend on changes in the device-connect-edge / agent-tools packages + # that are not yet on PyPI. 
When this PR touches the integration, install + # those packages from source (DEVICE_CONNECT_SOURCE_REF) so CI exercises + # the matching version instead of the published one. + - name: Detect Device Connect integration changes + id: dc_changed + run: | + BASE="origin/${{ github.base_ref || 'main' }}" + git fetch origin "${{ github.base_ref || 'main' }}" --depth=1 || true + if git diff --name-only "$BASE"...HEAD 2>/dev/null | grep -q '^strands_robots/device_connect/'; then + echo "changed=true" >> "$GITHUB_OUTPUT" + echo "Device Connect integration changed -> will install device-connect from source" + else + echo "changed=false" >> "$GITHUB_OUTPUT" + echo "No Device Connect integration changes -> using published device-connect-edge" + fi + + - name: Pin device-connect to source (integration changed) + if: steps.dc_changed.outputs.changed == 'true' + env: + # Source of truth for the device-connect packages while the + # corresponding device-connect PR is unreleased. Update the ref (or + # repo) here to point at the branch/tag that carries the matching + # changes; revert to the published wheel once it ships to PyPI. + DEVICE_CONNECT_REPO: ${{ vars.DEVICE_CONNECT_REPO || 'arm/device-connect' }} + DEVICE_CONNECT_SOURCE_REF: ${{ vars.DEVICE_CONNECT_SOURCE_REF || 'main' }} + run: | + # The hatch test env is an ISOLATED uv venv, so an outer `pip install` + # would not reach it. Instead we write a uv override file: uv MERGES + # UV_OVERRIDE entries with the pyproject [tool.uv] override-dependencies + # (torch pins), so this redirects ONLY the device-connect packages to + # source without disturbing anything else. 
+ cat > "${GITHUB_WORKSPACE}/dc-source-override.txt" <> "$GITHUB_ENV" + echo "Pinned device-connect packages to ${DEVICE_CONNECT_REPO}@${DEVICE_CONNECT_SOURCE_REF}" + - name: Install dependencies env: UV_TORCH_BACKEND: cpu diff --git a/strands_robots/device_connect/GUIDE.md b/strands_robots/device_connect/GUIDE.md index 230757ef..1253d33a 100644 --- a/strands_robots/device_connect/GUIDE.md +++ b/strands_robots/device_connect/GUIDE.md @@ -13,7 +13,13 @@ r = Robot("so100") r.run() # starts listening for commands. Ctrl+C to stop. ``` -`Robot()` creates the robot. `.run()` starts Device Connect with D2D defaults (Zenoh multicast scouting, no broker, no env vars) and blocks — the robot becomes discoverable on the LAN and listens for commands. Without `.run()`, the script exits and the robot is removed from the network. +`Robot()` creates the robot. `.run()` starts Device Connect with D2D defaults (Zenoh multicast scouting, no broker) and blocks — the robot becomes discoverable on the LAN and listens for commands. Without `.run()`, the script exits and the robot is removed from the network. + +> **Secure by default:** Device Connect no longer enables unencrypted/unauthenticated transport implicitly. For a local, trusted-network D2D trial without a broker, explicitly opt in to insecure transport: +> ```bash +> export DEVICE_CONNECT_ALLOW_INSECURE=true # ONLY on a trusted, isolated LAN +> ``` +> A prominent warning is logged whenever insecure mode is active. For anything beyond a local trial, run the brokered/registry setup with mTLS (see *Full Infrastructure* below). You can optionally pass `peer_id="so100-lab-1"` for a stable address; otherwise one is auto-generated (e.g. `so100-a3f1b2`). @@ -67,7 +73,12 @@ graph TD ### E2E Demo -No Docker needed. No env vars. Devices discover each other directly on the LAN via Zenoh multicast scouting. `Robot()` and `robot_mesh()` auto-configure D2D mode when no broker URL is set. +No Docker needed. 
Devices discover each other directly on the LAN via Zenoh multicast scouting. `Robot()` and `robot_mesh()` auto-configure D2D mode when no broker URL is set. + +> **Secure by default:** for this local D2D trial, opt into insecure transport on every terminal first (trusted/isolated LAN only): +> ```bash +> export DEVICE_CONNECT_ALLOW_INSECURE=true +> ``` #### Setup diff --git a/strands_robots/device_connect/__init__.py b/strands_robots/device_connect/__init__.py index e9f32004..a71e5f94 100644 --- a/strands_robots/device_connect/__init__.py +++ b/strands_robots/device_connect/__init__.py @@ -84,14 +84,26 @@ async def init_device_connect( if messaging_backend is None: messaging_backend = os.environ.get("MESSAGING_BACKEND", "zenoh") - # Resolve allow_insecure: env var > explicit arg > D2D default + # Resolve allow_insecure: explicit arg > env var > secure default. + # Security hardening: insecure (unencrypted, unauthenticated) transport is + # NO LONGER the default. It must be explicitly opted into — via the + # ``allow_insecure=True`` argument or ``DEVICE_CONNECT_ALLOW_INSECURE`` env + # var — and we log a prominent warning whenever it is active so an insecure + # deployment is never silent. if allow_insecure is None: env_val = os.environ.get("DEVICE_CONNECT_ALLOW_INSECURE") if env_val is not None: allow_insecure = env_val.lower() in ("true", "1", "yes") - elif urls is None: - # D2D mode — no broker, default insecure for dev convenience - allow_insecure = True + else: + allow_insecure = False + + if allow_insecure: + logger.warning( + "Device Connect is running in INSECURE mode (unencrypted, " + "unauthenticated transport). Robot commands and state are exposed " + "to the local network. Only use this on a trusted, isolated " + "network; configure a broker / secure transport for production." 
+ ) runtime = DeviceRuntime( driver=driver, diff --git a/strands_robots/device_connect/_authz.py b/strands_robots/device_connect/_authz.py new file mode 100644 index 00000000..8c26cb3a --- /dev/null +++ b/strands_robots/device_connect/_authz.py @@ -0,0 +1,100 @@ +"""Caller-authorization helpers for Device Connect robot/sim drivers. + +Security hardening: Device Connect RPC handlers run on the device side with no +built-in per-call authorization. State-mutating RPCs (execute / stop / step / +reset) and lifecycle events (emergencyStop) must therefore verify the calling +device against an operator-controlled allowlist before acting on physical (or +simulated) hardware. + +Allowlists are sourced from environment variables so deployments opt in without +code changes: + +* ``DEVICE_CONNECT_RPC_ALLOW`` — comma-separated device ids permitted to call + state-mutating RPCs. ``*`` (or unset) means "allow all" but logs a warning so + the permissive posture is visible. An explicit empty value (``""`` after + stripping) is treated as unset. +* ``DEVICE_CONNECT_ESTOP_ALLOW`` — comma-separated device ids permitted to + trigger emergency-stop handling. Falls back to ``DEVICE_CONNECT_RPC_ALLOW`` + when unset. + +Matching supports trailing ``*`` glob prefixes (e.g. ``safety-*``). +""" + +from __future__ import annotations + +import fnmatch +import logging +import os +from typing import Optional + +logger = logging.getLogger(__name__) + +_RPC_ALLOW_ENV = "DEVICE_CONNECT_RPC_ALLOW" +_ESTOP_ALLOW_ENV = "DEVICE_CONNECT_ESTOP_ALLOW" + +_warned_permissive: set[str] = set() + + +def _parse_allowlist(raw: Optional[str]) -> Optional[list[str]]: + """Parse a comma-separated allowlist. 
Returns None when unset/empty.""" + if raw is None: + return None + entries = [e.strip() for e in raw.split(",") if e.strip()] + return entries or None + + +def _matches(caller: str, patterns: list[str]) -> bool: + for pat in patterns: + if pat == "*" or fnmatch.fnmatchcase(caller, pat): + return True + return False + + +def _warn_permissive_once(scope: str) -> None: + if scope not in _warned_permissive: + _warned_permissive.add(scope) + logger.warning( + "Device Connect %s authorization is permissive (no %s allowlist set). " + "Any device that can reach the network may invoke state-mutating " + "operations. Set the allowlist to restrict callers.", + scope, + _RPC_ALLOW_ENV if scope == "rpc" else _ESTOP_ALLOW_ENV, + ) + + +def is_authorized_caller(caller: Optional[str], *, scope: str = "rpc") -> bool: + """Return True iff *caller* is authorized for the given *scope*. + + scope="rpc" -> state-mutating RPCs (execute/stop/step/reset) + scope="estop" -> emergency-stop event handling + """ + if scope == "estop": + raw = os.environ.get(_ESTOP_ALLOW_ENV) or os.environ.get(_RPC_ALLOW_ENV) + env_scope = "estop" + else: + raw = os.environ.get(_RPC_ALLOW_ENV) + env_scope = "rpc" + + patterns = _parse_allowlist(raw) + if patterns is None: + # No allowlist configured — preserve out-of-the-box dev usability but + # make the permissive posture loud so operators notice. + _warn_permissive_once(env_scope) + return True + + # Allowlist configured: a missing caller identity cannot be authorized. 
+ if not caller: + return False + return _matches(caller, patterns) + + +def authz_error(caller: Optional[str], function: str) -> dict: + """Standard structured rejection for an unauthorized RPC call.""" + logger.warning( + "Rejected unauthorized Device Connect RPC %s from caller=%r", function, caller + ) + return { + "status": "error", + "reason": f"caller not authorized for {function!r}", + "caller": caller or "unknown", + } diff --git a/strands_robots/device_connect/reachy_mini_driver.py b/strands_robots/device_connect/reachy_mini_driver.py index 3196d215..c3688ff1 100644 --- a/strands_robots/device_connect/reachy_mini_driver.py +++ b/strands_robots/device_connect/reachy_mini_driver.py @@ -10,6 +10,7 @@ import asyncio import logging import math +import re from typing import Optional from device_connect_edge.drivers import DeviceDriver, emit, on, rpc @@ -25,6 +26,11 @@ logger = logging.getLogger(__name__) +# Security hardening: recorded-move names are interpolated into a REST URL +# path, so restrict them to a safe charset to prevent path traversal and +# query/parameter injection into the daemon API. +_MOVE_NAME_RE = re.compile(r"^[A-Za-z0-9._-]{1,128}$") + class ReachyMiniDriver(DeviceDriver): """Device Connect driver for Pollen Reachy Mini. 
@@ -207,6 +213,8 @@ async def playMove(self, move_name: str, library: str = "emotions") -> dict: move_name: Name of the move to play library: Move library (emotions or dance) """ + if not _MOVE_NAME_RE.fullmatch(move_name or ""): + return {"status": "error", "reason": f"invalid move_name: {move_name!r}"} ds = f"pollen-robotics/reachy-mini-{'emotions' if library == 'emotions' else 'dances'}-library" result = await asyncio.to_thread( api, self._host, self._api_port, diff --git a/strands_robots/device_connect/reachy_transport.py b/strands_robots/device_connect/reachy_transport.py index 357ac813..18dac66b 100644 --- a/strands_robots/device_connect/reachy_transport.py +++ b/strands_robots/device_connect/reachy_transport.py @@ -8,6 +8,7 @@ import json import logging import math +import os import socket from abc import ABC, abstractmethod from typing import Callable, Optional @@ -23,6 +24,34 @@ def resolve_host(host: str) -> str: return host +def _daemon_auth_token() -> Optional[str]: + """Return the Reachy daemon auth token from the environment, if configured. + + Security hardening: the daemon WebSocket/REST interfaces accept commands + that directly actuate the robot. When ``REACHY_DAEMON_TOKEN`` is set we + present it as a bearer credential so the daemon can authenticate the + caller. When it is absent we emit a one-time warning so operators are + aware the link is unauthenticated (and should be confined to a trusted + network or fronted by WSS/HTTPS with mutual TLS). + """ + return os.environ.get("REACHY_DAEMON_TOKEN") or None + + +_warned_no_auth = False + + +def _warn_unauthenticated_once(kind: str) -> None: + global _warned_no_auth + if not _warned_no_auth and not _daemon_auth_token(): + _warned_no_auth = True + logger.warning( + "Reachy daemon %s is unauthenticated (no REACHY_DAEMON_TOKEN set). " + "Anyone on the same network segment can issue robot commands. 
" + "Set REACHY_DAEMON_TOKEN and prefer WSS/HTTPS with mutual TLS.", + kind, + ) + + # ── REST API ───────────────────────────────────────────────────── def api(host: str, port: int, path: str, method: str = "GET", data: Optional[dict] = None) -> dict: @@ -32,6 +61,11 @@ def api(host: str, port: int, path: str, method: str = "GET", data: Optional[dic url = f"http://{host}:{port}{path}" req = urllib.request.Request(url, method=method) req.add_header("Content-Type", "application/json") + _token = _daemon_auth_token() + if _token: + req.add_header("Authorization", f"Bearer {_token}") + else: + _warn_unauthenticated_once("REST API") body = json.dumps(data).encode() if data else None try: with urllib.request.urlopen(req, body, timeout=10) as resp: @@ -133,7 +167,25 @@ def __init__(self, host: str, port: int): async def start(self, on_joints: Callable, on_imu: Callable) -> None: import websockets - self._ws = await websockets.connect(f"ws://{self._host}:{self._port}/ws/sdk") + # Security hardening: authenticate to the daemon when a token is + # configured; otherwise warn that the link is unauthenticated. + _token = _daemon_auth_token() + _extra_headers = {"Authorization": f"Bearer {_token}"} if _token else None + if not _token: + _warn_unauthenticated_once("WebSocket") + _connect_kwargs = {} + if _extra_headers: + # websockets >=12 uses additional_headers; older uses extra_headers. 
+ try: + import inspect as _inspect + _sig = _inspect.signature(websockets.connect) + _hdr_kw = "additional_headers" if "additional_headers" in _sig.parameters else "extra_headers" + _connect_kwargs[_hdr_kw] = _extra_headers + except (ValueError, TypeError): + _connect_kwargs["extra_headers"] = _extra_headers + self._ws = await websockets.connect( + f"ws://{self._host}:{self._port}/ws/sdk", **_connect_kwargs + ) self._read_task = asyncio.create_task(self._read_loop(on_joints, on_imu)) async def _read_loop(self, on_joints: Callable, on_imu: Callable) -> None: diff --git a/strands_robots/device_connect/robot_driver.py b/strands_robots/device_connect/robot_driver.py index 78b87e7b..f72ec43c 100644 --- a/strands_robots/device_connect/robot_driver.py +++ b/strands_robots/device_connect/robot_driver.py @@ -7,9 +7,19 @@ import asyncio import logging -from device_connect_edge.drivers import DeviceDriver, emit, on, periodic, rpc +from device_connect_edge.drivers import ( + DeviceDriver, + emit, + get_rpc_source_device, + on, + periodic, + rpc, +) from device_connect_edge.types import DeviceIdentity, DeviceStatus +from strands_robots.device_connect._authz import authz_error, is_authorized_caller +from strands_robots.mesh.security import is_safe_policy_provider + logger = logging.getLogger(__name__) @@ -70,6 +80,17 @@ async def execute( duration: Maximum task duration in seconds policy_port: Policy server port (0 for default) """ + # Security hardening: authorize the calling device before mutating + # physical robot state. + caller = get_rpc_source_device() + if not is_authorized_caller(caller, scope="rpc"): + return authz_error(caller, "execute") + + # Security hardening: restrict policy_provider to the vetted allowlist + # so a caller cannot steer inference to an arbitrary network endpoint. 
+ if not is_safe_policy_provider(policy_provider): + return {"status": "error", "reason": f"policy_provider not allowed: {policy_provider!r}"} + return self._robot.start_task( instruction, policy_provider, @@ -81,6 +102,9 @@ async def execute( @rpc() async def stop(self) -> dict: """Stop the currently running task.""" + caller = get_rpc_source_device() + if not is_authorized_caller(caller, scope="rpc"): + return authz_error(caller, "stop") return self._robot.stop_task() @rpc() @@ -171,7 +195,15 @@ async def emergencyStop(self, reason: str = ""): @on(event_name="emergencyStop") async def onEmergencyStop(self, device_id: str, event_name: str, payload: dict): - """React to emergencyStop from ANY device on the network.""" + """React to emergencyStop from an authorized safety controller. + + Security hardening: only act on emergency-stop events whose source is + in the emergency-stop allowlist, so a spoofed event from an arbitrary + device cannot interrupt operations. + """ + if not is_authorized_caller(device_id, scope="estop"): + logger.warning("Ignoring emergencyStop from unauthorized source %s", device_id) + return logger.warning("Emergency stop received from %s — stopping task", device_id) self._robot.stop_task() diff --git a/strands_robots/device_connect/sim_driver.py b/strands_robots/device_connect/sim_driver.py index f4a040f3..bd0ed2f7 100644 --- a/strands_robots/device_connect/sim_driver.py +++ b/strands_robots/device_connect/sim_driver.py @@ -6,9 +6,19 @@ import logging -from device_connect_edge.drivers import DeviceDriver, emit, on, periodic, rpc +from device_connect_edge.drivers import ( + DeviceDriver, + emit, + get_rpc_source_device, + on, + periodic, + rpc, +) from device_connect_edge.types import DeviceIdentity, DeviceStatus +from strands_robots.device_connect._authz import authz_error, is_authorized_caller +from strands_robots.mesh.security import is_safe_policy_provider + logger = logging.getLogger(__name__) @@ -70,6 +80,12 @@ async def execute( 
duration: Maximum task duration in seconds robot_name: Target robot name (empty = first robot) """ + # Security hardening: authorize the calling device before mutating + # simulation state. + caller = get_rpc_source_device() + if not is_authorized_caller(caller, scope="rpc"): + return authz_error(caller, "execute") + # Determine robot name name = robot_name if not name: @@ -79,6 +95,11 @@ async def execute( else: return {"status": "error", "reason": "no robots in simulation"} + # Security hardening: restrict policy_provider to the vetted allowlist + # so a caller cannot steer inference to an arbitrary network endpoint. + if not is_safe_policy_provider(policy_provider): + return {"status": "error", "reason": f"policy_provider not allowed: {policy_provider!r}"} + print(f"▶ Executing policy '{policy_provider}' on {name}: {instruction}", flush=True) return self._sim.start_policy( robot_name=name, @@ -90,6 +111,9 @@ async def execute( @rpc() async def stop(self) -> dict: """Stop all running policies.""" + caller = get_rpc_source_device() + if not is_authorized_caller(caller, scope="rpc"): + return authz_error(caller, "stop") print("⏹ Stop command received — stopping all policies", flush=True) world = getattr(self._sim, "_world", None) if world: @@ -116,11 +140,17 @@ async def step(self, n_steps: int = 1) -> dict: Args: n_steps: Number of physics steps to take """ + caller = get_rpc_source_device() + if not is_authorized_caller(caller, scope="rpc"): + return authz_error(caller, "step") return self._sim.step(n_steps) @rpc() async def reset(self) -> dict: """Reset simulation to initial state.""" + caller = get_rpc_source_device() + if not is_authorized_caller(caller, scope="rpc"): + return authz_error(caller, "reset") return self._sim.reset() # ── Events ──────────────────────────────────────────────── @@ -158,7 +188,15 @@ async def emergencyStop(self, reason: str = ""): @on(event_name="emergencyStop") async def onEmergencyStop(self, device_id: str, event_name: str, 
payload: dict): - """React to emergencyStop from ANY device on the network.""" + """React to emergencyStop from an authorized safety controller. + + Security hardening: only act on emergency-stop events whose source is + in the emergency-stop allowlist, so a spoofed event from an arbitrary + device cannot interrupt operations. + """ + if not is_authorized_caller(device_id, scope="estop"): + logger.warning("Ignoring emergencyStop from unauthorized source %s", device_id) + return print(f"🛑 Emergency stop received from {device_id} — stopping all policies", flush=True) world = getattr(self._sim, "_world", None) if world: diff --git a/strands_robots/tools/robot_mesh.py b/strands_robots/tools/robot_mesh.py index 12b8bc6d..55db4e8d 100644 --- a/strands_robots/tools/robot_mesh.py +++ b/strands_robots/tools/robot_mesh.py @@ -66,7 +66,7 @@ # parameter the interrupt response is delivered by the framework # out-of-band of the LLM's tool-argument flow, so an injected prompt # cannot smuggle approval. -_INTERRUPT_REQUIRED: frozenset[str] = frozenset({"emergency_stop", "broadcast"}) +_INTERRUPT_REQUIRED: frozenset[str] = frozenset({"emergency_stop", "broadcast", "rpc"}) # Affirmative responses accepted from the interrupt prompt. Anything else # (empty string, "n", "no", "cancel", whitespace) is treated as decline. @@ -275,7 +275,16 @@ def _dc_ensure_connected() -> None: if _dc_state["connected"]: return os.environ.setdefault("MESSAGING_BACKEND", "zenoh") - os.environ.setdefault("DEVICE_CONNECT_ALLOW_INSECURE", "true") + # Security hardening: do NOT force insecure transport here. Previously this + # set DEVICE_CONNECT_ALLOW_INSECURE=true process-wide, silently downgrading + # every connection in the process. Insecure mode is now strictly opt-in by + # the operator. If they have opted in, surface a warning so it is visible. 
+ if os.environ.get("DEVICE_CONNECT_ALLOW_INSECURE", "").lower() in ("true", "1", "yes"): + logger.warning( + "DEVICE_CONNECT_ALLOW_INSECURE is enabled — agent-side Device " + "Connect traffic is unencrypted and unauthenticated. Use only on " + "a trusted, isolated network." + ) from device_connect_agent_tools.connection import connect, get_connection try: @@ -295,6 +304,7 @@ def _try_device_connect( duration: float, timeout: float, function: str = "", + validated_command: dict[str, Any] | None = None, ) -> dict[str, Any] | None: """Dispatch *action* through Device Connect, or return None to fall back. @@ -330,6 +340,7 @@ def _try_device_connect( duration, timeout, function, + validated_command, ) @@ -343,6 +354,7 @@ def _device_connect_dispatch( duration: float, timeout: float, function: str = "", + validated_command: dict[str, Any] | None = None, ) -> dict[str, Any] | None: """Render a robot_mesh action through Device Connect (dev-compatible API). @@ -462,14 +474,15 @@ def _device_connect_dispatch( return _DCResult(_ok(f"E-STOP: {stopped}/{len(devices)} devices stopped")) if action == "broadcast": - # Command was already parsed + validated by robot_mesh() before the - # HITL gate fired, so re-parsing here is safe. - func = "getStatus" - params: dict[str, Any] = {} - if command: - cmd = json.loads(command) - func = cmd.pop("action", cmd.pop("function", "getStatus")) - params = cmd + # Security hardening: dispatch the *validated* command that the + # operator approved at the HITL gate — never re-parse the raw + # caller-supplied string here (that would allow a payload whose + # validated form differs from what actually executes). 
+ if validated_command is None: + return _DCResult(_err("broadcast reached Device Connect dispatch without a validated command")) + cmd = dict(validated_command) + func = cmd.pop("action", cmd.pop("function", "getStatus")) + params = cmd results = conn.broadcast(func, params, timeout=timeout) _audit_tool_action(action, "*", True, f"action={func} responses={len(results)}") text = f"[broadcast] {len(results)} responses\n" @@ -605,6 +618,9 @@ def robot_mesh( reason={ "action": action, "target": target if target else "*ALL_PEERS*", + # Surface the device-native function name for rpc so the + # operator approves the specific function being invoked. + "function": function if action == "rpc" else "", # R8-7: surface the validated command so the operator # approves the post-validation form, not the raw LLM # string. emergency_stop has no command body so we @@ -668,6 +684,7 @@ def robot_mesh( duration, timeout, function, + validated_broadcast_cmd, ) if _dc_result is not None: return _dc_result diff --git a/tests/test_device_connect_hardening.py b/tests/test_device_connect_hardening.py new file mode 100644 index 00000000..ad0e8e40 --- /dev/null +++ b/tests/test_device_connect_hardening.py @@ -0,0 +1,398 @@ +"""Security-hardening regression tests for the Device Connect integration. + +Covers seven hardening improvements: + - broadcast dispatches the validated command (no raw re-parse) + - policy_provider restricted to the vetted allowlist (anti-SSRF) + - device-native rpc action is HITL-gated + - Reachy playMove move_name is path-traversal safe + - Reachy daemon transport supports auth + warns when absent + - state-mutating RPCs + emergencyStop are caller-authorized + - transport is secure-by-default (insecure is explicit opt-in) + +These use the REAL device_connect_edge package (editable install) so the +@rpc caller-identity contextvar hook is exercised end to end. 
+""" + +import asyncio +import importlib +import os +import sys + +import pytest + + +def _force_real_device_connect_edge(): + """Restore the REAL device_connect_edge submodules and purge our + integration modules so they re-bind to the real @rpc / DeviceDriver. + + Sibling test files (e.g. test_device_connect_drivers.py) replace + device_connect_edge.drivers/types/device with MagicMocks at import time. + To run order-independently we reload the genuine modules from disk and + drop any strands_robots.device_connect.* cached against the mocks. + """ + import importlib + for key in ("device_connect_edge.drivers", "device_connect_edge.types", + "device_connect_edge.device", "device_connect_edge"): + mod = sys.modules.get(key) + # A real module has __file__; a MagicMock stand-in does not. + if mod is not None and not hasattr(mod, "__file__"): + sys.modules.pop(key, None) + # Re-import genuine modules from disk. + importlib.import_module("device_connect_edge") + importlib.import_module("device_connect_edge.drivers") + importlib.import_module("device_connect_edge.types") + # Purge our integration so it re-imports against the real base classes. 
+ for key in list(sys.modules): + if key.startswith("strands_robots.device_connect"): + sys.modules.pop(key, None) + + +def _run(coro): + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(coro) + finally: + loop.close() + + +# ── Fakes ───────────────────────────────────────────────────────── + +class _FakeRobot: + tool_name_str = "so100" + + def __init__(self): + self.started = None + self.stopped = False + + def start_task(self, instruction, policy_provider, policy_port, host, duration): + self.started = dict( + instruction=instruction, policy_provider=policy_provider, + policy_port=policy_port, host=host, duration=duration, + ) + return {"status": "success", "instruction": instruction} + + def stop_task(self): + self.stopped = True + return {"status": "success"} + + def get_task_status(self): + return {"status": "idle"} + + +class _FakeWorldRobot: + def __init__(self): + self.policy_running = True + + +class _FakeWorld: + def __init__(self): + self.robots = {"r1": _FakeWorldRobot()} + self.sim_time = 0.0 + self.step_count = 0 + + +class _FakeSim: + tool_name_str = "so100_sim" + + def __init__(self): + self._world = _FakeWorld() + self.started = None + + def start_policy(self, robot_name, policy_provider, instruction, duration): + self.started = dict( + robot_name=robot_name, policy_provider=policy_provider, + instruction=instruction, duration=duration, + ) + return {"status": "success"} + + def step(self, n): + return {"status": "success", "stepped": n} + + def reset(self): + return {"status": "success", "reset": True} + + +@pytest.fixture(autouse=True) +def _clean_env(monkeypatch): + _force_real_device_connect_edge() + for var in ("DEVICE_CONNECT_RPC_ALLOW", "DEVICE_CONNECT_ESTOP_ALLOW", + "DEVICE_CONNECT_ALLOW_INSECURE", "REACHY_DAEMON_TOKEN"): + monkeypatch.delenv(var, raising=False) + # reset the one-time permissive-warning memo + import strands_robots.device_connect._authz as az + az._warned_permissive.clear() + yield + + +# ── 
policy_provider allowlist (anti-SSRF) ───────────────────── + +def test_robot_execute_rejects_ssrf_policy_provider(): + from strands_robots.device_connect.robot_driver import RobotDeviceDriver + d = RobotDeviceDriver(_FakeRobot()) + res = _run(d.execute("test", policy_provider="grpc://attacker.evil:9000", + source_device="op-1")) + assert res["status"] == "error" + assert "policy_provider" in res["reason"] + + +def test_robot_execute_allows_vetted_provider(): + from strands_robots.device_connect.robot_driver import RobotDeviceDriver + robot = _FakeRobot() + d = RobotDeviceDriver(robot) + res = _run(d.execute("pick cube", policy_provider="mock", source_device="op-1")) + assert res["status"] == "success" + assert robot.started["policy_provider"] == "mock" + + +def test_sim_execute_rejects_ssrf_policy_provider(): + from strands_robots.device_connect.sim_driver import SimulationDeviceDriver + d = SimulationDeviceDriver(_FakeSim()) + res = _run(d.execute("test", policy_provider="ws://attacker", source_device="op-1")) + assert res["status"] == "error" + assert "policy_provider" in res["reason"] + + +# ── caller authorization ────────────────────────────────────── + +def test_execute_denied_when_allowlist_set_and_caller_not_listed(monkeypatch): + from strands_robots.device_connect.robot_driver import RobotDeviceDriver + monkeypatch.setenv("DEVICE_CONNECT_RPC_ALLOW", "trusted-controller") + d = RobotDeviceDriver(_FakeRobot()) + res = _run(d.execute("go", policy_provider="mock", source_device="rogue-sensor")) + assert res["status"] == "error" + assert "not authorized" in res["reason"] + + +def test_execute_allowed_for_listed_caller(monkeypatch): + from strands_robots.device_connect.robot_driver import RobotDeviceDriver + monkeypatch.setenv("DEVICE_CONNECT_RPC_ALLOW", "trusted-controller,safety-*") + robot = _FakeRobot() + d = RobotDeviceDriver(robot) + res = _run(d.execute("go", policy_provider="mock", source_device="trusted-controller")) + assert res["status"] == "success" 
+ # glob match + res2 = _run(d.execute("go", policy_provider="mock", source_device="safety-007")) + assert res2["status"] == "success" + + +def test_stop_denied_for_unlisted_caller(monkeypatch): + from strands_robots.device_connect.robot_driver import RobotDeviceDriver + monkeypatch.setenv("DEVICE_CONNECT_RPC_ALLOW", "ctrl") + robot = _FakeRobot() + d = RobotDeviceDriver(robot) + res = _run(d.stop(source_device="rogue")) + assert res["status"] == "error" + assert robot.stopped is False + + +def test_sim_step_reset_denied_for_unlisted_caller(monkeypatch): + from strands_robots.device_connect.sim_driver import SimulationDeviceDriver + monkeypatch.setenv("DEVICE_CONNECT_RPC_ALLOW", "ctrl") + d = SimulationDeviceDriver(_FakeSim()) + assert _run(d.step(n_steps=3, source_device="rogue"))["status"] == "error" + assert _run(d.reset(source_device="rogue"))["status"] == "error" + + +def test_permissive_when_no_allowlist(monkeypatch): + # Out-of-the-box: no allowlist => allowed (with a logged warning). 
+ from strands_robots.device_connect.robot_driver import RobotDeviceDriver + robot = _FakeRobot() + d = RobotDeviceDriver(robot) + res = _run(d.execute("go", policy_provider="mock", source_device="anyone")) + assert res["status"] == "success" + + +def test_emergencystop_ignores_unauthorized_source(monkeypatch): + from strands_robots.device_connect.robot_driver import RobotDeviceDriver + monkeypatch.setenv("DEVICE_CONNECT_ESTOP_ALLOW", "safety-controller") + robot = _FakeRobot() + d = RobotDeviceDriver(robot) + _run(d.onEmergencyStop("rogue-device", "emergencyStop", {})) + assert robot.stopped is False + _run(d.onEmergencyStop("safety-controller", "emergencyStop", {})) + assert robot.stopped is True + + +# ── playMove path traversal ─────────────────────────────────── + +def _make_reachy(): + from strands_robots.device_connect import reachy_mini_driver as rmd + drv = rmd.ReachyMiniDriver.__new__(rmd.ReachyMiniDriver) + drv._host = "localhost" + drv._api_port = 8000 + return drv, rmd + + +def test_playmove_rejects_path_traversal(): + drv, rmd = _make_reachy() + captured = {} + + def fake_api(host, port, path, method="GET", data=None): + captured["path"] = path + return {"ok": True} + + rmd.api = fake_api # patch module-level api used via asyncio.to_thread + res = _run(drv.playMove("../../daemon/shutdown")) + assert res["status"] == "error" + assert "path" not in captured # api() never called + + +def test_playmove_rejects_query_injection(): + drv, rmd = _make_reachy() + res = _run(drv.playMove("x?admin=true&reset=1")) + assert res["status"] == "error" + + +def test_playmove_allows_clean_name(): + drv, rmd = _make_reachy() + captured = {} + + def fake_api(host, port, path, method="GET", data=None): + captured["path"] = path + return {"ok": True} + + rmd.api = fake_api + res = _run(drv.playMove("happy_wiggle")) + assert res["status"] == "success" + assert captured["path"].endswith("/happy_wiggle") + + +# ── Reachy daemon auth ──────────────────────────────────────── + 
+def test_rest_api_adds_auth_header_when_token_set(monkeypatch): + monkeypatch.setenv("REACHY_DAEMON_TOKEN", "s3cret") + from strands_robots.device_connect import reachy_transport as rt + importlib.reload(rt) + captured = {} + + class _Resp: + def __enter__(self): return self + def __exit__(self, *a): return False + def read(self): return b"{}" + + def fake_urlopen(req, body, timeout): + captured["auth"] = req.get_header("Authorization") + return _Resp() + + monkeypatch.setattr("urllib.request.urlopen", fake_urlopen) + rt.api("localhost", 8000, "/api/x") + assert captured["auth"] == "Bearer s3cret" + # cleanup: reload to restore module-level memo without token + monkeypatch.delenv("REACHY_DAEMON_TOKEN", raising=False) + importlib.reload(rt) + + +def test_token_helper_reads_env(monkeypatch): + from strands_robots.device_connect import reachy_transport as rt + importlib.reload(rt) + assert rt._daemon_auth_token() is None + monkeypatch.setenv("REACHY_DAEMON_TOKEN", "abc") + assert rt._daemon_auth_token() == "abc" + monkeypatch.delenv("REACHY_DAEMON_TOKEN", raising=False) + + +# ── secure-by-default resolution ────────────────────────────── + +def test_allow_insecure_defaults_false(monkeypatch): + # Verify the resolution logic: unset env + no explicit arg => secure. + monkeypatch.delenv("DEVICE_CONNECT_ALLOW_INSECURE", raising=False) + allow_insecure = None + if allow_insecure is None: + env_val = os.environ.get("DEVICE_CONNECT_ALLOW_INSECURE") + if env_val is not None: + allow_insecure = env_val.lower() in ("true", "1", "yes") + else: + allow_insecure = False + assert allow_insecure is False + + +def test_no_forced_insecure_setdefault_in_source(): + # The agent-side connector must NOT force insecure mode process-wide. 
+ import strands_robots.tools.robot_mesh as rm + src = __import__("inspect").getsource(rm._dc_ensure_connected) + assert 'setdefault("DEVICE_CONNECT_ALLOW_INSECURE"' not in src + assert 'setdefault(\'DEVICE_CONNECT_ALLOW_INSECURE\'' not in src + + +# ── broadcast dispatches the validated command (no raw re-parse) ── + +def test_broadcast_dispatch_uses_validated_command(monkeypatch): + """The DC broadcast branch must use the validated command, never re-parse + the raw caller string (which could differ from what was approved).""" + import strands_robots.tools.robot_mesh as rm + from unittest.mock import MagicMock + + conn = MagicMock(name="conn") + conn.broadcast.return_value = [{"device_id": "d1", "result": {}}] + monkeypatch.setattr( + "device_connect_agent_tools.connection.get_connection", + lambda: conn, + raising=False, + ) + + # Raw string says factoryReset, but the validated command (what the + # operator approved) is a benign status. Dispatch MUST use the validated one. + raw = '{"function": "factoryReset", "confirm": true}' + validated = {"action": "status"} + res = rm._device_connect_dispatch( + "broadcast", "", "", raw, "mock", 0, 30.0, 30.0, "", validated + ) + assert res is not None + # broadcast called with the validated action, not factoryReset + called_func = conn.broadcast.call_args[0][0] + assert called_func == "status" + assert called_func != "factoryReset" + + +def test_broadcast_dispatch_without_validated_command_is_rejected(monkeypatch): + import strands_robots.tools.robot_mesh as rm + from unittest.mock import MagicMock + + conn = MagicMock(name="conn") + monkeypatch.setattr( + "device_connect_agent_tools.connection.get_connection", + lambda: conn, + raising=False, + ) + res = rm._device_connect_dispatch( + "broadcast", "", "", '{"function":"factoryReset"}', "mock", 0, 30.0, 30.0, "", None + ) + assert res["status"] == "error" + conn.broadcast.assert_not_called() + + +# ── device-native rpc action is HITL-gated ──────────────────── + +def 
test_rpc_is_interrupt_required(): + from strands_robots.tools.robot_mesh import _INTERRUPT_REQUIRED + assert "rpc" in _INTERRUPT_REQUIRED + + +def test_rpc_declined_by_operator_is_rejected(monkeypatch): + """With DC disabled, an rpc action must still raise the HITL interrupt and + fail closed when the operator declines.""" + import strands_robots.tools.robot_mesh as rm + from unittest.mock import MagicMock + + monkeypatch.setenv("STRANDS_ROBOT_MESH_DC", "off") + ctx = MagicMock(name="ToolContext") + ctx.interrupt.return_value = "n" # operator declines + + fn = getattr(rm.robot_mesh, "original", rm.robot_mesh) + res = fn(action="rpc", tool_context=ctx, target="device-1", + function="updateFirmware", command='{"url":"http://evil/x.bin"}') + assert res["status"] == "error" + assert ctx.interrupt.called + + +def test_rpc_surfaces_function_in_interrupt(monkeypatch): + import strands_robots.tools.robot_mesh as rm + from unittest.mock import MagicMock + + monkeypatch.setenv("STRANDS_ROBOT_MESH_DC", "off") + ctx = MagicMock(name="ToolContext") + ctx.interrupt.return_value = "n" + fn = getattr(rm.robot_mesh, "original", rm.robot_mesh) + fn(action="rpc", tool_context=ctx, target="d1", function="nod") + reason = ctx.interrupt.call_args.kwargs.get("reason", {}) + assert reason.get("function") == "nod" From f76d21d7da1c715c8f29bf57933ac177c99aa79f Mon Sep 17 00:00:00 2001 From: Kavya Chennoju Date: Thu, 11 Jun 2026 23:53:46 -0700 Subject: [PATCH 6/6] fix(device-connect): make RPC caller-allowlist usable + honest MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The per-call caller ACL (DEVICE_CONNECT_RPC_ALLOW) was effectively all-or-nothing on the agent path: device_connect_agent_tools clients are anonymous, so the device always saw caller=None. Setting an allowlist therefore denied every agent call (no allow-path was reachable), and the behaviour wasn't documented. 
Layer 1 — caller identity propagation (robot_mesh): - _agent_identity()/_with_identity() stamp _dc_meta.source_device into the DC command envelope from STRANDS_ROBOT_MESH_AGENT_ID (or DEVICE_CONNECT_CLIENT_ID). Wired into tell/send/rpc/stop/broadcast/ emergency_stop. Anonymous callers unchanged (still fail-closed). Layer 2 — honest authz semantics (_authz): - Document that the caller id is only a hard boundary under authenticated transport; under DEVICE_CONNECT_ALLOW_INSECURE it is self-asserted. - Log a one-time advisory when an allowlist is enforced over insecure transport. Layer 3 — docs + test honesty: - Extract resolve_allow_insecure() and fix the stale "defaults to True in D2D" docstring (it defaults to False / secure). - GUIDE.md: drop the contradictory DEVICE_CONNECT_ALLOW_INSECURE=true from the mTLS Full-Infrastructure section; document the allowlist's agent-id requirement and advisory-under-insecure semantics. - Replace the tautological test_allow_insecure_defaults_false (it re-implemented the logic) with tests that call the real resolver and init_device_connect. Add coverage for the reachable anonymous-deny, the insecure advisory, and _dc_meta identity propagation. 21 -> 29 tests. --- strands_robots/device_connect/GUIDE.md | 22 +++- strands_robots/device_connect/__init__.py | 40 ++++-- strands_robots/device_connect/_authz.py | 43 +++++++ strands_robots/tools/robot_mesh.py | 49 +++++++- tests/test_device_connect_hardening.py | 144 ++++++++++++++++++++-- 5 files changed, 270 insertions(+), 28 deletions(-) diff --git a/strands_robots/device_connect/GUIDE.md b/strands_robots/device_connect/GUIDE.md index 1253d33a..0430f920 100644 --- a/strands_robots/device_connect/GUIDE.md +++ b/strands_robots/device_connect/GUIDE.md @@ -255,7 +255,10 @@ Set environment variables (all terminals): ```bash export MESSAGING_BACKEND=zenoh export ZENOH_CONNECT=tcp/localhost:7447 -export DEVICE_CONNECT_ALLOW_INSECURE=true +# NOTE: do NOT set DEVICE_CONNECT_ALLOW_INSECURE here. 
The whole point of the +# brokered setup is authenticated, encrypted transport (mTLS). Enabling +# insecure mode would disable that and contradict the security model below. +# Insecure mode is only for the local, broker-less D2D trial. ``` All the options above (A–B) work identically with full infrastructure — the only difference is that devices register in etcd and discovery goes through the registry service instead of multicast scouting. @@ -266,6 +269,23 @@ All the options above (A–B) work identically with full infrastructure — the > - **Cross-network routing** — the Zenoh router (or NATS broker) enables communication across subnets and sites, not just the local LAN. > - **Authentication & authorization** — mTLS ensures only devices with certificates signed by the trusted CA can exchange data. Full authorization (per-device permissions, topic-level ACLs, certificate revocation) requires the router/registry infrastructure. +> **Per-caller RPC allowlist (`DEVICE_CONNECT_RPC_ALLOW` / `DEVICE_CONNECT_ESTOP_ALLOW`).** +> Devices can restrict which callers may invoke state-mutating RPCs (`execute` / +> `stop` / `step` / `reset`) and `emergencyStop`, matched against the RPC's +> authenticated `source_device` (trailing-`*` globs allowed, e.g. `safety-*`). +> Two things to know: +> - **Caller identity must be supplied.** A device-to-device caller carries its +> id automatically; an **agent** driving the robot via `robot_mesh` / +> `device-connect-agent-tools` is anonymous by default. Set +> `STRANDS_ROBOT_MESH_AGENT_ID=<agent-id>` on the agent and add `<agent-id>` to the +> device's allowlist — otherwise, with an allowlist set, the device +> (correctly) **denies the anonymous agent** (fail-closed). +> - **It is only a hard boundary under authenticated transport.** Over insecure +> D2D (`DEVICE_CONNECT_ALLOW_INSECURE=true`) the `source_device` is +> self-asserted, so the allowlist is *advisory* (a one-time warning is logged). 
+> For a real authorization boundary, run the brokered setup with mTLS so the +> caller id is bound to the sender's certificate. + #### Running the Tests ```bash diff --git a/strands_robots/device_connect/__init__.py b/strands_robots/device_connect/__init__.py index a71e5f94..102329bc 100644 --- a/strands_robots/device_connect/__init__.py +++ b/strands_robots/device_connect/__init__.py @@ -32,11 +32,34 @@ __all__ = [ "init_device_connect", "init_device_connect_sync", + "resolve_allow_insecure", "RobotDeviceDriver", "SimulationDeviceDriver", "ReachyMiniDriver", ] +_INSECURE_TRUE = ("true", "1", "yes") + + +def resolve_allow_insecure( + explicit: Optional[bool] = None, + env_value: Optional[str] = None, +) -> bool: + """Resolve the effective ``allow_insecure`` setting (secure by default). + + Precedence: explicit arg > ``DEVICE_CONNECT_ALLOW_INSECURE`` env var > + secure default (``False``). Insecure transport is never implicit — it + must be opted into via the argument or the env var. + + Extracted as a pure function so the secure-by-default posture is unit + testable without standing up a DeviceRuntime. + """ + if explicit is not None: + return explicit + if env_value is not None: + return env_value.lower() in _INSECURE_TRUE + return False + async def init_device_connect( robot, @@ -64,9 +87,11 @@ async def init_device_connect( messaging_backend: Messaging backend — "zenoh" or "nats". None = auto-detect from MESSAGING_BACKEND env var (default "zenoh"). tenant: Device Connect tenant namespace. - allow_insecure: Allow insecure connections. None = auto-detect: - respects DEVICE_CONNECT_ALLOW_INSECURE env var if set, - otherwise defaults to True in D2D mode (no broker URL). + allow_insecure: Allow insecure (unencrypted, unauthenticated) + transport. None = auto-detect: respects the + DEVICE_CONNECT_ALLOW_INSECURE env var if set, otherwise defaults + to False (secure). 
Insecure transport must be explicitly opted + into; a prominent warning is logged whenever it is active. Returns: The running DeviceRuntime instance. @@ -90,12 +115,9 @@ async def init_device_connect( # ``allow_insecure=True`` argument or ``DEVICE_CONNECT_ALLOW_INSECURE`` env # var — and we log a prominent warning whenever it is active so an insecure # deployment is never silent. - if allow_insecure is None: - env_val = os.environ.get("DEVICE_CONNECT_ALLOW_INSECURE") - if env_val is not None: - allow_insecure = env_val.lower() in ("true", "1", "yes") - else: - allow_insecure = False + allow_insecure = resolve_allow_insecure( + allow_insecure, os.environ.get("DEVICE_CONNECT_ALLOW_INSECURE") + ) if allow_insecure: logger.warning( diff --git a/strands_robots/device_connect/_authz.py b/strands_robots/device_connect/_authz.py index 8c26cb3a..b019f14a 100644 --- a/strands_robots/device_connect/_authz.py +++ b/strands_robots/device_connect/_authz.py @@ -18,6 +18,21 @@ when unset. Matching supports trailing ``*`` glob prefixes (e.g. ``safety-*``). + +Caller-identity semantics (READ THIS before relying on the allowlist): + +* The caller id is whatever the messaging layer reported as the RPC's + ``source_device``. A device-to-device caller (another ``DeviceRuntime``) and + an agent that sets ``STRANDS_ROBOT_MESH_AGENT_ID`` both carry an id; an + anonymous client carries **none** (``caller=None``). +* When an allowlist IS set, a missing/None caller cannot be authorized and is + denied (fail-closed). So setting ``DEVICE_CONNECT_RPC_ALLOW`` will reject + every anonymous caller — configure an id on the caller side to allow it. +* The id is only as trustworthy as the transport. Under authenticated + transport (mTLS) it is bound to the sender's certificate. Under insecure + transport (``DEVICE_CONNECT_ALLOW_INSECURE``) it is **self-asserted** — any + peer can claim any id — so the allowlist is advisory there, not a + cryptographic boundary. 
A one-time warning is logged in that case. """ from __future__ import annotations @@ -33,6 +48,29 @@ _ESTOP_ALLOW_ENV = "DEVICE_CONNECT_ESTOP_ALLOW" _warned_permissive: set[str] = set() +_warned_insecure_acl: set[str] = set() + +_INSECURE_ENV = "DEVICE_CONNECT_ALLOW_INSECURE" + + +def _insecure_transport_active() -> bool: + return os.environ.get(_INSECURE_ENV, "").lower() in ("true", "1", "yes") + + +def _warn_insecure_acl_once(scope: str) -> None: + """Warn (once per scope) that an allowlist is being enforced against a + self-asserted caller id because the transport is insecure.""" + if scope in _warned_insecure_acl: + return + _warned_insecure_acl.add(scope) + logger.warning( + "Device Connect %s allowlist is enforced against a SELF-ASSERTED caller " + "identity: %s is set, so any peer can claim an allowed id. Treat the " + "allowlist as advisory here; use authenticated transport (mTLS) for a " + "cryptographic authorization boundary.", + scope, + _INSECURE_ENV, + ) def _parse_allowlist(raw: Optional[str]) -> Optional[list[str]]: @@ -82,6 +120,11 @@ def is_authorized_caller(caller: Optional[str], *, scope: str = "rpc") -> bool: _warn_permissive_once(env_scope) return True + # An allowlist is configured. If the transport is insecure the caller id is + # self-asserted, so the allowlist is advisory — say so once, loudly. + if _insecure_transport_active(): + _warn_insecure_acl_once(env_scope) + # Allowlist configured: a missing caller identity cannot be authorized. if not caller: return False diff --git a/strands_robots/tools/robot_mesh.py b/strands_robots/tools/robot_mesh.py index 55db4e8d..d7f6ad40 100644 --- a/strands_robots/tools/robot_mesh.py +++ b/strands_robots/tools/robot_mesh.py @@ -260,6 +260,43 @@ def _resolve_mesh(target: str) -> Any | None: _dc_state = {"connected": False} +def _agent_identity() -> str: + """Return this agent's caller identity for Device Connect RPCs. 
+ + Sourced from ``STRANDS_ROBOT_MESH_AGENT_ID`` (falling back to the generic + ``DEVICE_CONNECT_CLIENT_ID``). Empty string when unset — in which case the + agent is an anonymous caller and a device with ``DEVICE_CONNECT_RPC_ALLOW`` + set will (correctly) reject it. + + SECURITY: this identity is *self-asserted*. It lets an operator who has + locked a device's RPC allowlist authorize this agent by id, but it is only + a trustworthy control when the transport authenticates the sender (mTLS). + On an insecure/trusted-LAN D2D link it is advisory — any peer can claim any + id, so do not rely on it as the sole authorization boundary there. + """ + return ( + os.environ.get("STRANDS_ROBOT_MESH_AGENT_ID") + or os.environ.get("DEVICE_CONNECT_CLIENT_ID") + or "" + ) + + +def _with_identity(params: dict[str, Any]) -> dict[str, Any]: + """Stamp the agent's caller identity into the DC command envelope. + + The device side reads ``params["_dc_meta"]["source_device"]`` and exposes it + via ``get_rpc_source_device()`` for the driver's caller-authorization check. + Without this the device sees an anonymous caller (``None``). No-op when no + identity is configured, preserving the anonymous-caller behaviour. 
+ """ + identity = _agent_identity() + if not identity: + return params + meta = dict(params.get("_dc_meta", {})) + meta.setdefault("source_device", identity) + return {**params, "_dc_meta": meta} + + class _DCResult(dict): """Strands tool-response dict whose ``str()`` renders the text block cleanly.""" @@ -397,7 +434,7 @@ def _device_connect_dispatch( except _security.ValidationError as exc: _audit_tool_action(action, target, False, f"validation: {exc}") return _DCResult(_err(f"tell rejected: {exc}")) - result = conn.invoke(target, "execute", {"instruction": instruction, **kwargs}, timeout=timeout) + result = conn.invoke(target, "execute", _with_identity({"instruction": instruction, **kwargs}), timeout=timeout) r = result.get("result", result) _audit_tool_action(action, target, True, f"instruction={instruction[:200]}") return _DCResult(_ok(f"-> {target}: {instruction}\n {json.dumps(r, default=str)}")) @@ -419,7 +456,7 @@ def _device_connect_dispatch( _audit_tool_action(action, target, False, f"validation: {exc}") return _DCResult(_err(f"send rejected: {exc}")) func = cmd.pop("action", cmd.pop("function", "getStatus")) - result = conn.invoke(target, func, cmd, timeout=timeout) + result = conn.invoke(target, func, _with_identity(cmd), timeout=timeout) r = result.get("result", result) _audit_tool_action(action, target, True, f"action={func}") return _DCResult(_ok(f"{target}:\n{json.dumps(r, indent=2, default=str)[:2000]}")) @@ -446,7 +483,7 @@ def _device_connect_dispatch( except _security.ValidationError as exc: _audit_tool_action(action, target, False, f"validation: {exc}") return _DCResult(_err(f"rpc rejected: {exc}")) - result = conn.invoke(target, func_name, rpc_params, timeout=timeout) + result = conn.invoke(target, func_name, _with_identity(rpc_params), timeout=timeout) r = result.get("result", result) if isinstance(result, dict) else result _audit_tool_action(action, target, True, f"function={func_name}") return _DCResult( @@ -456,7 +493,7 @@ def 
_device_connect_dispatch( if action == "stop": if not target: return _DCResult(_err("stop requires target")) - result = conn.invoke(target, "stop", timeout=min(timeout, 5.0)) + result = conn.invoke(target, "stop", _with_identity({}), timeout=min(timeout, 5.0)) r = result.get("result", result) _audit_tool_action(action, target, True, "") return _DCResult(_ok(f"Stop {target}: {json.dumps(r, default=str)}")) @@ -466,7 +503,7 @@ def _device_connect_dispatch( stopped = 0 for d in devices: try: - conn.invoke(d["device_id"], "stop", timeout=3.0) + conn.invoke(d["device_id"], "stop", _with_identity({}), timeout=3.0) stopped += 1 except Exception: # noqa: BLE001 — best-effort fan-out pass @@ -482,7 +519,7 @@ def _device_connect_dispatch( return _DCResult(_err("broadcast reached Device Connect dispatch without a validated command")) cmd = dict(validated_command) func = cmd.pop("action", cmd.pop("function", "getStatus")) - params = cmd + params = _with_identity(cmd) results = conn.broadcast(func, params, timeout=timeout) _audit_tool_action(action, "*", True, f"action={func} responses={len(results)}") text = f"[broadcast] {len(results)} responses\n" diff --git a/tests/test_device_connect_hardening.py b/tests/test_device_connect_hardening.py index ad0e8e40..c0503f41 100644 --- a/tests/test_device_connect_hardening.py +++ b/tests/test_device_connect_hardening.py @@ -118,9 +118,10 @@ def _clean_env(monkeypatch): for var in ("DEVICE_CONNECT_RPC_ALLOW", "DEVICE_CONNECT_ESTOP_ALLOW", "DEVICE_CONNECT_ALLOW_INSECURE", "REACHY_DAEMON_TOKEN"): monkeypatch.delenv(var, raising=False) - # reset the one-time permissive-warning memo + # reset the one-time warning memos import strands_robots.device_connect._authz as az az._warned_permissive.clear() + az._warned_insecure_acl.clear() yield @@ -193,6 +194,52 @@ def test_sim_step_reset_denied_for_unlisted_caller(monkeypatch): assert _run(d.reset(source_device="rogue"))["status"] == "error" +def 
test_anonymous_caller_denied_when_allowlist_set(monkeypatch): + """The reachable agent-path state: a caller with NO source identity + (anonymous device-connect-agent-tools client => get_rpc_source_device() + returns None) must be denied once an allowlist is configured. This is the + end-to-end behaviour an operator sees when they lock the allowlist without + giving the agent an id.""" + from strands_robots.device_connect.robot_driver import RobotDeviceDriver + monkeypatch.setenv("DEVICE_CONNECT_RPC_ALLOW", "trusted-controller") + robot = _FakeRobot() + d = RobotDeviceDriver(robot) + # No source_device kwarg => contextvar stays None, exactly like an + # anonymous agent invoking over D2D. + res = _run(d.execute("go", policy_provider="mock")) + assert res["status"] == "error" + assert "not authorized" in res["reason"] + assert res["caller"] == "unknown" + assert robot.started is None + + +def test_insecure_acl_logs_advisory_once(monkeypatch, caplog): + """Under insecure transport, enforcing an allowlist against a self-asserted + id must log a one-time advisory.""" + import logging + import strands_robots.device_connect._authz as az + monkeypatch.setenv("DEVICE_CONNECT_RPC_ALLOW", "ctrl") + monkeypatch.setenv("DEVICE_CONNECT_ALLOW_INSECURE", "true") + az._warned_insecure_acl.clear() + with caplog.at_level(logging.WARNING, logger="strands_robots.device_connect._authz"): + az.is_authorized_caller("ctrl", scope="rpc") + az.is_authorized_caller("ctrl", scope="rpc") # second call must not re-warn + advisories = [r for r in caplog.records if "SELF-ASSERTED" in r.getMessage()] + assert len(advisories) == 1 + + +def test_secure_acl_no_insecure_advisory(monkeypatch, caplog): + """With secure transport the self-asserted advisory must NOT fire.""" + import logging + import strands_robots.device_connect._authz as az + monkeypatch.setenv("DEVICE_CONNECT_RPC_ALLOW", "ctrl") + monkeypatch.delenv("DEVICE_CONNECT_ALLOW_INSECURE", raising=False) + az._warned_insecure_acl.clear() + with 
caplog.at_level(logging.WARNING, logger="strands_robots.device_connect._authz"): + az.is_authorized_caller("ctrl", scope="rpc") + assert not [r for r in caplog.records if "SELF-ASSERTED" in r.getMessage()] + + def test_permissive_when_no_allowlist(monkeypatch): # Out-of-the-box: no allowlist => allowed (with a logged warning). from strands_robots.device_connect.robot_driver import RobotDeviceDriver @@ -293,17 +340,49 @@ def test_token_helper_reads_env(monkeypatch): # ── secure-by-default resolution ────────────────────────────── -def test_allow_insecure_defaults_false(monkeypatch): - # Verify the resolution logic: unset env + no explicit arg => secure. - monkeypatch.delenv("DEVICE_CONNECT_ALLOW_INSECURE", raising=False) - allow_insecure = None - if allow_insecure is None: - env_val = os.environ.get("DEVICE_CONNECT_ALLOW_INSECURE") - if env_val is not None: - allow_insecure = env_val.lower() in ("true", "1", "yes") - else: - allow_insecure = False - assert allow_insecure is False +def test_allow_insecure_defaults_false(): + # Exercise the REAL resolver (not a re-implementation): unset env + no + # explicit arg => secure. 
+ from strands_robots.device_connect import resolve_allow_insecure + assert resolve_allow_insecure(None, None) is False + + +def test_allow_insecure_resolution_precedence(): + from strands_robots.device_connect import resolve_allow_insecure + # explicit arg wins over everything + assert resolve_allow_insecure(True, "false") is True + assert resolve_allow_insecure(False, "true") is False + # env var honoured when no explicit arg (truthy spellings) + assert resolve_allow_insecure(None, "true") is True + assert resolve_allow_insecure(None, "1") is True + assert resolve_allow_insecure(None, "yes") is True + # anything else is secure + assert resolve_allow_insecure(None, "false") is False + assert resolve_allow_insecure(None, "") is False + + +def test_init_device_connect_uses_secure_default(): + """The production entrypoint constructs the runtime secure-by-default when + neither the arg nor the env var opt into insecure transport.""" + import strands_robots.device_connect as dc + from unittest.mock import patch + + captured = {} + + class _FakeRuntime: + def __init__(self, **kw): + captured.update(kw) + def set_heartbeat_provider(self, *_a, **_k): + pass + async def run(self): + return None + + async def _go(): + with patch.object(dc, "DeviceRuntime", _FakeRuntime): + await dc.init_device_connect(_FakeRobot(), peer_id="p1") + + _run(_go()) + assert captured["allow_insecure"] is False def test_no_forced_insecure_setdefault_in_source(): @@ -361,6 +440,47 @@ def test_broadcast_dispatch_without_validated_command_is_rejected(monkeypatch): conn.broadcast.assert_not_called() +# ── agent caller-identity propagation (Layer 1) ─────────────── + +def test_with_identity_noop_when_unset(monkeypatch): + import strands_robots.tools.robot_mesh as rm + monkeypatch.delenv("STRANDS_ROBOT_MESH_AGENT_ID", raising=False) + monkeypatch.delenv("DEVICE_CONNECT_CLIENT_ID", raising=False) + params = {"instruction": "go"} + out = rm._with_identity(params) + assert "_dc_meta" not in out # 
anonymous caller, unchanged + + +def test_with_identity_stamps_source_device(monkeypatch): + import strands_robots.tools.robot_mesh as rm + monkeypatch.setenv("STRANDS_ROBOT_MESH_AGENT_ID", "trusted-controller") + out = rm._with_identity({"instruction": "go"}) + assert out["_dc_meta"]["source_device"] == "trusted-controller" + # does not clobber a caller-supplied _dc_meta source_device + out2 = rm._with_identity({"_dc_meta": {"source_device": "explicit"}}) + assert out2["_dc_meta"]["source_device"] == "explicit" + + +def test_tell_invoke_carries_identity(monkeypatch): + """End-to-end at the dispatch layer: a configured agent id rides along in + the DC command envelope so the device's allowlist can match it.""" + import strands_robots.tools.robot_mesh as rm + from unittest.mock import MagicMock + monkeypatch.setenv("STRANDS_ROBOT_MESH_AGENT_ID", "trusted-controller") + conn = MagicMock(name="conn") + conn.invoke.return_value = {"result": {"status": "success"}} + monkeypatch.setattr( + "device_connect_agent_tools.connection.get_connection", + lambda: conn, raising=False, + ) + rm._device_connect_dispatch( + "tell", "dev-1", "pick up the cube", "", "mock", 0, 30.0, 30.0, "", None + ) + params = conn.invoke.call_args[0][2] + assert params["_dc_meta"]["source_device"] == "trusted-controller" + assert params["instruction"] == "pick up the cube" + + # ── device-native rpc action is HITL-gated ──────────────────── def test_rpc_is_interrupt_required():