# CompositeTaskBase methods (the class header lies outside this view).
def get_own_fine_grained_progress_objectives(self) -> "list[FineGrainedProgressObjective]":
    """Return the FineGrainedProgressObjectives declared by the composite task itself.

    The base implementation declares none. Whatever this returns is appended after the
    child subtasks' objectives by get_fine_grained_progress_objectives(); these objectives
    are not renamed and no parent_subtask_idx is assigned to them there.
    """
    return []


def get_fine_grained_progress_objectives(self) -> "list[FineGrainedProgressObjective]":
    """Collect the child subtasks' FineGrainedProgressObjectives, then the composite's own.

    Every objective reported by child ``i`` is copied with ``dataclasses.replace`` so that
    its name becomes ``subtask_{i}/{original_name}`` and its ``parent_subtask_idx`` is ``i``.
    The child's original objective objects are left unmodified.

    Returns:
        The namespaced child objectives in subtask order, followed by the objectives from
        get_own_fine_grained_progress_objectives().
    """
    namespaced = [
        dataclasses.replace(
            objective,
            name=f"subtask_{subtask_idx}/{objective.name}",
            parent_subtask_idx=subtask_idx,
        )
        for subtask_idx, child in enumerate(self.subtasks)
        for objective in child.get_fine_grained_progress_objectives()
    ]
    return namespaced + list(self.get_own_fine_grained_progress_objectives())
+# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Literal, Union + +PredicateGroups = Union[ + Callable, + list[Callable], + list[tuple[Callable, float]], + dict[str, Callable], + dict[str, list[Callable]], + dict[str, list[tuple[Callable, float]]], +] + + +DEFAULT_GROUP_NAME = "default_group" + + +def format_predicate_groups(predicate_groups: PredicateGroups) -> dict[str, list[tuple[Callable, float]]]: + """Format predicate_groups into the canonical form. + + Canonical form: ``dict[group_name: list[(callable, score)]]``. + + Accepted input shapes: + 1. func (single callable) one group with one predicate + 2. [func, func, ...] one group, sequential chain + 3. [(func, score), ...] one group, sequential chain, weighted + 4. {group: func} multiple groups, one predicate each + 5. {group: [func, ...]} multiple groups, sequential chains + 6. {group: [(func, score), ...]} multiple groups, sequential chains, weighted + """ + + if callable(predicate_groups): + return {DEFAULT_GROUP_NAME: [(predicate_groups, 1.0)]} + + if isinstance(predicate_groups, list): + if len(predicate_groups) == 0: + raise ValueError("FineGrainedProgressObjective.predicate_groups list cannot be empty") + return {DEFAULT_GROUP_NAME: _format_group_chain(predicate_groups, group_name=DEFAULT_GROUP_NAME)} + + if isinstance(predicate_groups, dict): + if len(predicate_groups) == 0: + raise ValueError("FineGrainedProgressObjective.predicate_groups dict cannot be empty") + return { + group_name: _format_group_chain(value, group_name=group_name) + for group_name, value in predicate_groups.items() + } + + raise TypeError( + "FineGrainedProgressObjective.predicate_groups must be a callable, list, or dict; got" + f" {type(predicate_groups).__name__}" + ) + + +def _format_group_chain(value, group_name: str) -> list[tuple[Callable, float]]: + if callable(value): + return [(value, 1.0)] 
+ if not isinstance(value, list): + raise TypeError( + f"Predicate chain for group '{group_name}' must be a callable or a list; got {type(value).__name__}" + ) + if len(value) == 0: + raise ValueError(f"Predicate chain for group '{group_name}' cannot be empty") + + first = value[0] + if isinstance(first, tuple): + chain = [] + for i, item in enumerate(value): + if not (isinstance(item, tuple) and len(item) == 2): + raise TypeError(f"Group '{group_name}' index {i}: expected (callable, score) tuple, got {item!r}") + fn, score = item + if not callable(fn): + raise TypeError(f"Group '{group_name}' index {i}: first tuple element must be callable") + if not isinstance(score, (int, float)): + raise TypeError(f"Group '{group_name}' index {i}: score must be a number") + chain.append((fn, float(score))) + return chain + + if callable(first): + equal = 1.0 / len(value) + chain = [] + for i, fn in enumerate(value): + if not callable(fn): + raise TypeError(f"Group '{group_name}' index {i}: expected callable, got {type(fn).__name__}") + chain.append((fn, equal)) + return chain + + raise TypeError( + f"Group '{group_name}' elements must be callables or (callable, score) tuples; got {type(first).__name__}" + ) + + +def normalize_scores( + predicate_groups: dict[str, list[tuple[Callable, float]]], +) -> dict[str, list[tuple[Callable, float]]]: + """Scale each group's scores to sum to 1.0. Zero and negative-sum groups are left untouched.""" + + out: dict[str, list[tuple[Callable, float]]] = {} + for group, chain in predicate_groups.items(): + total = sum(score for _, score in chain) + if total <= 0: + out[group] = list(chain) + continue + out[group] = [(fn, score / total) for fn, score in chain] + return out + + +@dataclass +class FineGrainedProgressObjective: + """Configuration object that defines a scored predicate sequence to track progress within a task. + + A FineGrainedProgressObjective specifies what the predicate state machine should track. 
+ Each FineGrainedProgressObjective holds one or more sequential predicate chains (groups). + Within a group, predicates run in order. Across groups, predicates run in parallel. + + Args: + name: Identifies the FineGrainedProgressObjective within the TaskBase. + predicate_groups: The sequential predicate chains that define the FineGrainedProgressObjective. + score: Weight of the FineGrainedProgressObjective in the TaskBase-level overall_score. + logical: How completed groups combine to determine if the FineGrainedProgressObjective is complete. + Can be "all", "any", or "choose" + K: Required when logical == "choose". Specifies the number of groups that must be completed + to consider the FineGrainedProgressObjective complete. + description: An optional description of the FineGrainedProgressObjective. + """ + + name: str + predicate_groups: PredicateGroups + score: float = 1.0 + logical: Literal["all", "any", "choose"] = "all" + K: int | None = None + description: str | None = None + + canonical_predicate_groups: dict[str, list[tuple[Callable, float]]] = field(init=False, repr=False) + + # Index of the parent TaskBase this recipe belongs to. Set automatically by + # CompositeTaskBase.get_fine_grained_progress_objectives() when used with composite tasks. + parent_subtask_idx: int | None = None + + def __post_init__(self): + if not (0.0 <= self.score <= 1.0): + raise ValueError(f"FineGrainedProgressObjective '{self.name}': score must be in [0, 1], got {self.score}") + if self.logical not in ("all", "any", "choose"): + raise ValueError( + f"FineGrainedProgressObjective '{self.name}': logical must be in ['all', 'any', 'choose'], got" + f" {self.logical}" + ) + + # Format the predicate groups into the canonical form and normalize the scores. + formatted = format_predicate_groups(self.predicate_groups) + normalized = normalize_scores(formatted) + self.canonical_predicate_groups = normalized + + # Validate the logical and K parameters. 
+ num_groups = len(self.canonical_predicate_groups) + if self.logical == "choose": + if self.K is None: + raise ValueError(f"FineGrainedProgressObjective '{self.name}': K is required when logical='choose'") + if not (1 <= self.K <= num_groups): + raise ValueError( + f"FineGrainedProgressObjective '{self.name}': K={self.K} but must be in [1, {num_groups}]" + ) + + @property + def group_names(self) -> list[str]: + """Returns the names of the groups in the FineGrainedProgressObjective.""" + return list(self.canonical_predicate_groups.keys()) + + def get_chain(self, group_name: str) -> list[tuple[Callable, float]]: + """Returns the chain of predicates for a given group.""" + return self.canonical_predicate_groups[group_name] diff --git a/isaaclab_arena/tasks/fine_grained_progress_tracker.py b/isaaclab_arena/tasks/fine_grained_progress_tracker.py new file mode 100644 index 000000000..5d30e91aa --- /dev/null +++ b/isaaclab_arena/tasks/fine_grained_progress_tracker.py @@ -0,0 +1,419 @@ +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. 
# Attribute name under which the lazily created tracker is cached on the env object.
_STATE_MACHINE_ATTR = "_fine_grained_progress_tracker"


def _predicate_repr(pred) -> str:
    """Generate human-readable string representation for a predicate.

    Reads ``func`` / ``args`` / ``keywords`` when the predicate exposes them (as
    functools.partial objects do); otherwise falls back to ``__name__`` or ``repr``.
    Only str/int/float/bool keyword values are included in the output.
    """

    fn = getattr(pred, "func", pred)
    name = getattr(fn, "__name__", repr(fn))
    kwargs = getattr(pred, "keywords", None) or {}
    args = getattr(pred, "args", ()) or ()
    parts = [repr(a) for a in args]
    for key, value in kwargs.items():
        if isinstance(value, (str, int, float, bool)):
            parts.append(f"{key}={value!r}")
    return f"{name}({', '.join(parts)})" if parts else name


def _normalize_env_ids(env_ids, num_envs: int) -> list[int]:
    """Turn any supported env_ids value into a plain list of int indices.

    Supported values: None (all envs), a slice over range(num_envs), a tensor, or any
    iterable of integer-like values.
    """
    if env_ids is None:
        return list(range(num_envs))
    if isinstance(env_ids, slice):
        # NOTE(review): Isaac Lab's event manager appears to substitute slice(None) for
        # "all envs" on some reset paths - confirm against the installed version.
        return list(range(num_envs))[env_ids]
    if torch.is_tensor(env_ids):
        return env_ids.reshape(-1).tolist()
    return [int(eid) for eid in env_ids]


class FineGrainedProgressObjectiveRunner:
    """State machine runner for a single FineGrainedProgressObjective object.

    Each runner is responsible for tracking the progress of all predicate_groups
    within a FineGrainedProgressObjective object across all parallel environments.
    """

    def __init__(self, fine_grained_progress_objective: FineGrainedProgressObjective, num_envs: int, device):
        self.fine_grained_progress_objective = fine_grained_progress_objective
        self.num_envs = num_envs
        self.device = device

        # Per-group state, each a (num_envs,) tensor:
        #   current_index  - position of the next predicate to satisfy in the group's chain
        #   group_score    - accumulated score of the predicates satisfied so far
        #   group_complete - True once current_index has passed the end of the chain
        self.current_index: dict[str, torch.Tensor] = {}
        self.group_score: dict[str, torch.Tensor] = {}
        self.group_complete: dict[str, torch.Tensor] = {}

        for group_name in fine_grained_progress_objective.group_names:
            self.current_index[group_name] = torch.zeros(num_envs, dtype=torch.long, device=device)
            self.group_score[group_name] = torch.zeros(num_envs, dtype=torch.float32, device=device)
            self.group_complete[group_name] = torch.zeros(num_envs, dtype=torch.bool, device=device)

    def _compute_composite_task_gating_mask(self, env) -> torch.Tensor:
        """Per-env mask of whether the FineGrainedProgressObjective is active.

        The gating is used to determine when tracking of predicates should
        be active for composite tasks.
        """

        # If no parent_subtask_idx -> always active (returns all True).
        if self.fine_grained_progress_objective.parent_subtask_idx is None:
            return torch.ones(self.num_envs, dtype=torch.bool, device=self.device)

        # If no env._current_subtask_idx -> composite task is not sequential (returns all True).
        current_idx = getattr(env, "_current_subtask_idx", None)
        if current_idx is None:
            return torch.ones(self.num_envs, dtype=torch.bool, device=self.device)

        # Otherwise return True only for envs whose current
        # parent-subtask index matches this FineGrainedProgressObjective's parent_subtask_idx.
        if torch.is_tensor(current_idx):
            ci = current_idx.to(self.device)
        else:
            ci = torch.as_tensor(current_idx, device=self.device)
        return ci == int(self.fine_grained_progress_objective.parent_subtask_idx)

    def step(self, env, step_index: torch.Tensor | None) -> list[dict]:
        """Step the state machine runner for a single env.step.

        Advance each group's predicate chain by at most one position per env and return a
        transition event for every env/group that advanced this step.

        Args:
            env: Environment passed to every predicate.
            step_index: Per-env step counter used to stamp events, or None (events then
                carry step -1).
        """

        # If the FineGrainedProgressObjective is not active for the composite task, there is
        # nothing to advance for any env.
        gating_mask = self._compute_composite_task_gating_mask(env)
        if not bool(gating_mask.any().item()):
            return []

        events: list[dict] = []
        for group_name, predicate_chain in self.fine_grained_progress_objective.canonical_predicate_groups.items():
            events += self._step_group(env, group_name, predicate_chain, gating_mask, step_index)
        return events

    def _step_group(
        self,
        env,
        group_name: str,
        predicate_chain: list[tuple],
        gating_mask: torch.Tensor,
        step_index: torch.Tensor | None,
    ) -> list[dict]:
        """Advance a single group's predicate chain by at most one position per env.

        Evaluates the current predicate for the envs sitting at each chain position, advances
        those whose predicate is satisfied, updates the group's score and completion mask, and
        returns one transition event per env that advanced.

        Raises:
            RuntimeError: If a predicate's flattened result does not have num_envs elements.
        """

        # List of state transition events (events are emitted for an env when a predicate flips True)
        events: list[dict] = []
        chain_length = len(predicate_chain)
        # Mask for which envs have advanced this step (at most one advance per env per group).
        advanced = torch.zeros(self.num_envs, dtype=torch.bool, device=self.device)

        for chain_idx, (predicate, score_weight) in enumerate(predicate_chain):
            # Compute mask for which envs that should evaluate the predicate.
            # Envs should only be evaluated if:
            # 1) They are at the current predicate position
            # 2) They have not yet advanced this step
            # 3) The FineGrainedProgressObjective is active for the composite task
            at_position = (self.current_index[group_name] == chain_idx) & ~advanced & gating_mask
            if not bool(at_position.any().item()):
                continue

            # Evaluate the predicate for all envs, reshaped to a flat (num_envs,) bool tensor.
            result = torch.as_tensor(predicate(env), dtype=torch.bool, device=self.device).reshape(-1)
            if result.shape[0] != self.num_envs:
                raise RuntimeError(
                    f"Predicate {_predicate_repr(predicate)} returned shape {tuple(result.shape)};"
                    f" expected ({self.num_envs},)"
                )

            # Compute mask for which envs need to be advanced to the next predicate.
            advance_mask = at_position & result
            if not bool(advance_mask.any().item()):
                continue

            # Advance the state machine to the next predicates.
            self.current_index[group_name] = torch.where(
                advance_mask,
                self.current_index[group_name] + 1,
                self.current_index[group_name],
            )
            # Update the group score for the envs that were advanced.
            self.group_score[group_name] = self.group_score[group_name] + advance_mask.float() * float(score_weight)
            # Update the advanced mask for the envs that were advanced.
            advanced = advanced | advance_mask

            # Emit an event for each env where a predicate was advanced.
            pred_name = _predicate_repr(predicate)
            for eid in torch.nonzero(advance_mask, as_tuple=False).flatten().tolist():
                events.append({
                    "env_idx": int(eid),
                    "step": int(step_index[eid].item()) if step_index is not None else -1,
                    "fine_grained_progress_objective": self.fine_grained_progress_objective.name,
                    "group": group_name,
                    "predicate_index": chain_idx,
                    "predicate_name": pred_name,
                    "score_delta": float(score_weight),
                })

        # Update the group complete mask for the envs that have completed the group.
        self.group_complete[group_name] = self.current_index[group_name] >= chain_length
        return events

    def reset(self, env_ids) -> None:
        """Reset the state machine runner for the provided envs.

        Args:
            env_ids: Env indices as a tensor or sequence; None or a slice is also accepted
                and resolved against range(num_envs).
        """

        # torch.as_tensor cannot convert None or a slice, so expand those first.
        if env_ids is None or isinstance(env_ids, slice):
            env_ids = _normalize_env_ids(env_ids, self.num_envs)
        env_ids = torch.as_tensor(env_ids, dtype=torch.long, device=self.device)
        for group_name in self.fine_grained_progress_objective.group_names:
            self.current_index[group_name][env_ids] = 0
            self.group_score[group_name][env_ids] = 0.0
            self.group_complete[group_name][env_ids] = False

    def is_complete(self) -> torch.Tensor:
        """Return a (num_envs,) bool tensor: is the FineGrainedProgressObjective complete per env.

        Groups are combined according to the objective's ``logical`` setting: "all", "any",
        or (otherwise) at least K completed groups.
        """

        groups = self.fine_grained_progress_objective.group_names
        stacked = torch.stack([self.group_complete[g] for g in groups], dim=1)
        if self.fine_grained_progress_objective.logical == "all":
            return stacked.all(dim=1)
        if self.fine_grained_progress_objective.logical == "any":
            return stacked.any(dim=1)
        return stacked.sum(dim=1) >= int(self.fine_grained_progress_objective.K or 1)

    def overall_score_per_env(self) -> torch.Tensor:
        """Compute mean group score within this FineGrainedProgressObjective (in [0, 1])."""

        groups = self.fine_grained_progress_objective.group_names
        stacked = torch.stack([self.group_score[g] for g in groups], dim=1)
        return stacked.mean(dim=1)

    def get_state_for_env(self, env_idx: int, is_complete, score) -> dict:
        """Per-env view of this objective's progress.

        is_complete and score are passed in (rather than recomputed here) so the full
        (num_envs,) tensor reductions run once per runner in
        FineGrainedProgressTracker, instead of once per env.
        """

        objective = self.fine_grained_progress_objective
        completed_groups = 0
        active_predicates: dict[str, str | None] = {}
        # The active predicate for a group is the one at its current chain position. Any group
        # whose pointer has run off the end of the chain is complete (no active predicate).
        for group_name in objective.group_names:
            predicate_chain = objective.canonical_predicate_groups[group_name]
            cur_group_index = int(self.current_index[group_name][env_idx].item())
            if cur_group_index >= len(predicate_chain):
                active_predicates[group_name] = None
                completed_groups += 1
            else:
                active_predicates[group_name] = _predicate_repr(predicate_chain[cur_group_index][0])

        return {
            "completed_groups": completed_groups,
            "total_groups": len(objective.group_names),
            "score": float(score),
            "is_complete": bool(is_complete),
            "active_predicates": active_predicates,
        }


class FineGrainedProgressTracker:
    """State machine that manages runners for all FineGrainedProgressObjectives.

    Attributes:
        fine_grained_progress_objectives: List of FineGrainedProgressObjectives to manage.
        num_envs: Number of parallel environments.
        device: Device to manage the state machine on.
        runners: List of runners for each FineGrainedProgressObjective.
        _events: Per-env list of transition events accumulated since that env's last reset.
    """

    def __init__(self, fine_grained_progress_objectives: list[FineGrainedProgressObjective], num_envs: int, device):
        # get_state() keys its per-objective entries by name, so a repeated name would
        # silently overwrite another objective's state while both still add to overall_score.
        seen: set[str] = set()
        duplicates: list[str] = []
        for objective in fine_grained_progress_objectives:
            if objective.name in seen and objective.name not in duplicates:
                duplicates.append(objective.name)
            seen.add(objective.name)
        if duplicates:
            raise ValueError(f"FineGrainedProgressObjective names must be unique; duplicated: {duplicates}")

        self.fine_grained_progress_objectives = fine_grained_progress_objectives
        self.num_envs = num_envs
        self.device = device
        self.runners = [
            FineGrainedProgressObjectiveRunner(s, num_envs, device) for s in fine_grained_progress_objectives
        ]
        self._events: list[list[dict]] = [[] for _ in range(num_envs)]

    def step(self, env, step_index: torch.Tensor | None) -> None:
        """Step each runner for a single env.step and file the emitted events per env."""

        for runner in self.runners:
            for event in runner.step(env, step_index):
                # env_idx is only needed for routing; it is not stored inside the event.
                eid = event.pop("env_idx")
                self._events[eid].append(event)

    def reset(self, env_ids) -> None:
        """Reset the runners and clear the stored events for the provided envs.

        Args:
            env_ids: Env indices as a sequence or tensor; None or a slice selects from
                range(num_envs).
        """

        ids = _normalize_env_ids(env_ids, self.num_envs)
        for runner in self.runners:
            runner.reset(ids)
        for eid in ids:
            self._events[eid] = []

    def get_state(self) -> list[dict]:
        """Get the state of each FineGrainedProgressObjective for all envs."""

        # Compute the per-runner (num_envs,) tensors once
        completeness = [runner.is_complete() for runner in self.runners]
        scores = [runner.overall_score_per_env() for runner in self.runners]

        output: list[dict] = []
        for env_idx in range(self.num_envs):
            # Build a per-env dict from each runner's state.
            progress_objective_states: dict[str, dict] = {}
            overall_score = 0.0
            all_complete = True
            for i, runner in enumerate(self.runners):
                objective = runner.fine_grained_progress_objective
                state = runner.get_state_for_env(env_idx, completeness[i][env_idx], scores[i][env_idx])
                progress_objective_states[objective.name] = state
                overall_score += objective.score * state["score"]
                all_complete = all_complete and state["is_complete"]

            # Add the per-env state dict to the output.
            output.append({
                "fine_grained_progress_objectives": progress_objective_states,
                "overall_score": overall_score,
                "all_complete": all_complete,
            })
        return output

    def get_events(self) -> list[list[dict]]:
        """Get all events for all envs (each per-env list is a shallow copy)."""

        return [list(e) for e in self._events]


def _ensure_progress_tracker(
    env, fine_grained_progress_objectives: list[FineGrainedProgressObjective]
) -> FineGrainedProgressTracker:
    """Return the env's FineGrainedProgressTracker, lazily creating and caching it on first call.

    Once a tracker is cached on the env, the fine_grained_progress_objectives argument of
    later calls is ignored.
    """

    sm: FineGrainedProgressTracker | None = getattr(env, _STATE_MACHINE_ATTR, None)
    if sm is None:
        sm = FineGrainedProgressTracker(
            fine_grained_progress_objectives=fine_grained_progress_objectives, num_envs=env.num_envs, device=env.device
        )
        setattr(env, _STATE_MACHINE_ATTR, sm)
    return sm


class FineGrainedProgressRecorder(RecorderTerm):
    """Per-step hook that ticks the FineGrainedProgressTracker. Records nothing.

    Registered as a recorder term so it runs once per env.step via
    record_post_step. It advances the state machine and publishes the per-step state/events to
    env.extras["fine_grained_progress"], then returns
    (None, None) so nothing is written to the recorded episode data.

    env.extras["fine_grained_progress"] format:

        {
            "states": [  # one entry per env
                {
                    "fine_grained_progress_objectives": {
                        "<objective name>": {
                            "completed_groups": int,
                            "total_groups": int,
                            "score": float,  # 0..1, normalized within objective
                            "is_complete": bool,
                            "active_predicates": {group: str | None},
                        },
                        ...
                    },
                    "overall_score": float,  # weighted by FineGrainedProgressObjective.score
                    "all_complete": bool,
                },
                ...
            ],
            "events": [  # one list per env
                [{"step": int, "fine_grained_progress_objective": str, "group": str,
                  "predicate_index": int, "predicate_name": str,
                  "score_delta": float}, ...],
                ...
            ],
        }
    """

    def __init__(self, cfg: "FineGrainedProgressObjectiveRecorderCfg", env):
        super().__init__(cfg, env)
        self._fine_grained_progress_objectives = cfg.fine_grained_progress_objectives

    def record_post_step(self):
        """Ticks the state machine, writes events and states to env.extras["fine_grained_progress"]"""

        sm = _ensure_progress_tracker(self._env, self._fine_grained_progress_objectives)
        # Envs without an episode_length_buf yield events stamped with step -1.
        step_index = getattr(self._env, "episode_length_buf", None)
        sm.step(self._env, step_index=step_index)
        self._env.extras["fine_grained_progress"] = {
            "states": sm.get_state(),
            "events": sm.get_events(),
        }
        # This term is a per-step hook only — record nothing.
        return None, None


def fine_grained_progress_reset_func(
    env, env_ids, fine_grained_progress_objectives: list[FineGrainedProgressObjective]
) -> None:
    """Reset-event entry point.

    Resets the state machine whenever the Lab env is reset.

    Args:
        env: Environment whose tracker is reset (created on first use).
        env_ids: Envs to reset: None (all), a slice, a tensor, or a sequence of indices.
        fine_grained_progress_objectives: Objectives used if the tracker must be created.
    """

    sm = _ensure_progress_tracker(env, fine_grained_progress_objectives)
    sm.reset(_normalize_env_ids(env_ids, env.num_envs))


@configclass
class FineGrainedProgressObjectiveEventsCfg:
    # Reset-mode event term that calls fine_grained_progress_reset_func.
    reset_fine_grained_progress_objectives: EventTermCfg = MISSING


@configclass
class FineGrainedProgressObjectiveRecorderCfg(RecorderTermCfg):
    class_type: type[RecorderTerm] = FineGrainedProgressRecorder
    # Objectives handed to the FineGrainedProgressRecorder instance.
    fine_grained_progress_objectives: list[FineGrainedProgressObjective] = MISSING


@configclass
class FineGrainedProgressObjectiveRecorderManagerCfg(RecorderManagerBaseCfg):
    fine_grained_progress: FineGrainedProgressObjectiveRecorderCfg = MISSING


def make_fine_grained_progress_objective_events_cfg(
    fine_grained_progress_objectives: list[FineGrainedProgressObjective],
) -> Any:
    """Build the events cfg whose reset term resets the tracker for the given objectives."""
    return FineGrainedProgressObjectiveEventsCfg(
        reset_fine_grained_progress_objectives=EventTermCfg(
            func=fine_grained_progress_reset_func,
            mode="reset",
            params={"fine_grained_progress_objectives": fine_grained_progress_objectives},
        )
    )


def make_fine_grained_progress_objective_recorder_cfg(
    fine_grained_progress_objectives: list[FineGrainedProgressObjective],
) -> Any:
    """Build the recorder-manager cfg holding the per-step progress recorder term."""
    return FineGrainedProgressObjectiveRecorderManagerCfg(
        fine_grained_progress=FineGrainedProgressObjectiveRecorderCfg(
            fine_grained_progress_objectives=fine_grained_progress_objectives,
        )
    )


# TaskBase methods from isaaclab_arena/tasks/task_base.py (the class header lies outside this view).
def get_fine_grained_progress_objectives(self) -> list[FineGrainedProgressObjective]:
    """Return the task's FineGrainedProgressObjectives; the base task declares none."""
    return []


def _resolve_fine_grained_progress_objectives(self) -> list[FineGrainedProgressObjective]:
    """Return get_fine_grained_progress_objectives(), computed once and cached on the instance."""
    # Resolve once and cache so the reset (events) and step (recorder) cfgs share the same objective objects.
    if not hasattr(self, "_resolved_fine_grained_progress_objectives"):
        self._resolved_fine_grained_progress_objectives = self.get_fine_grained_progress_objectives()
    return self._resolved_fine_grained_progress_objectives


def get_fine_grained_progress_objective_events_cfg(self) -> Any:
    """Return the reset-events cfg for the task's objectives, or None if it declares none."""
    fine_grained_progress_objectives = self._resolve_fine_grained_progress_objectives()
    if not fine_grained_progress_objectives:
        return None
    return make_fine_grained_progress_objective_events_cfg(fine_grained_progress_objectives)


def get_fine_grained_progress_objective_recorder_cfg(self) -> Any:
    """Return the recorder cfg for the task's objectives, or None if it declares none."""
    fine_grained_progress_objectives = self._resolve_fine_grained_progress_objectives()
    if not fine_grained_progress_objectives:
        return None
    return make_fine_grained_progress_objective_recorder_cfg(fine_grained_progress_objectives)
+SCORE_TOL = 1e-6 + + +class _MockPredicate: + """Callable predicate that returns a controlled per-env bool tensor.""" + + def __init__(self, num_envs: int, name: str = "mock_predicate"): + import torch + + self.num_envs = num_envs + self.return_value = torch.tensor([False] * num_envs) + self.__name__ = name + + def set(self, values: list[bool]): + import torch + + assert len(values) == self.num_envs + self.return_value = torch.tensor(values) + + def __call__(self, env, **kwargs): + return self.return_value + + +class _MockEnv: + def __init__(self, num_envs: int = 1, device: str = "cpu"): + import torch + + self.num_envs = num_envs + self.device = device + self.extras = {} + self.episode_length_buf = torch.zeros(num_envs, dtype=torch.long) + + +def _advance_step(env, n: int = 1): + env.episode_length_buf = env.episode_length_buf + n + + +def _test_predicate_groups_single_callable(simulation_app) -> bool: + """A bare predicate becomes a default-named group with weight 1.0.""" + from isaaclab_arena.tasks.fine_grained_progress_objective import DEFAULT_GROUP_NAME, FineGrainedProgressObjective + + try: + pred = _MockPredicate(num_envs=1) + fgpo = FineGrainedProgressObjective(name="t", predicate_groups=pred) + assert fgpo.group_names == [DEFAULT_GROUP_NAME] + chain = fgpo.get_chain(DEFAULT_GROUP_NAME) + assert len(chain) == 1 + assert chain[0][0] is pred + assert abs(chain[0][1] - 1.0) < SCORE_TOL + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + return False + return True + + +def _test_predicate_groups_list_of_callables(simulation_app) -> bool: + """A list of callables becomes a single group with normalized equal scores.""" + from isaaclab_arena.tasks.fine_grained_progress_objective import DEFAULT_GROUP_NAME, FineGrainedProgressObjective + + try: + preds = [_MockPredicate(num_envs=1, name=f"p{i}") for i in range(3)] + fgpo = FineGrainedProgressObjective(name="t", predicate_groups=preds) + chain = fgpo.get_chain(DEFAULT_GROUP_NAME) + assert [c[0] 
for c in chain] == preds + # Equal scores normalize to 0.33 each, summing to 1.0. + for _, score in chain: + assert abs(score - 1.0 / 3.0) < SCORE_TOL + assert abs(sum(s for _, s in chain) - 1.0) < SCORE_TOL + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + return False + return True + + +def _test_predicate_groups_weighted_tuples(simulation_app) -> bool: + """Explicit (callable, score) tuples are normalized to sum to 1.0 within a group.""" + from isaaclab_arena.tasks.fine_grained_progress_objective import DEFAULT_GROUP_NAME, FineGrainedProgressObjective + + try: + p1 = _MockPredicate(num_envs=1, name="p1") + p2 = _MockPredicate(num_envs=1, name="p2") + fgpo = FineGrainedProgressObjective(name="t", predicate_groups=[(p1, 1.0), (p2, 3.0)]) + chain = fgpo.get_chain(DEFAULT_GROUP_NAME) + # 1.0/4.0 = 0.25, 3.0/4.0 = 0.75 + assert abs(chain[0][1] - 0.25) < SCORE_TOL + assert abs(chain[1][1] - 0.75) < SCORE_TOL + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + return False + return True + + +def _test_predicate_groups_dict_groups(simulation_app) -> bool: + """Dict input gives one group per key and each group's scores are normalized independently.""" + from isaaclab_arena.tasks.fine_grained_progress_objective import FineGrainedProgressObjective + + try: + p_a1 = _MockPredicate(num_envs=1, name="a1") + p_a2 = _MockPredicate(num_envs=1, name="a2") + p_b = _MockPredicate(num_envs=1, name="b") + fgpo = FineGrainedProgressObjective( + name="t", + predicate_groups={ + "obj_a": [p_a1, p_a2], + "obj_b": p_b, + }, + logical="all", + ) + assert set(fgpo.group_names) == {"obj_a", "obj_b"} + a_chain = fgpo.get_chain("obj_a") + b_chain = fgpo.get_chain("obj_b") + assert len(a_chain) == 2 + assert len(b_chain) == 1 + # obj_a's equal scores sum to 1.0. + assert abs(sum(s for _, s in a_chain) - 1.0) < SCORE_TOL + # obj_b's single-element group sums to 1.0. 
+ assert abs(b_chain[0][1] - 1.0) < SCORE_TOL + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + return False + return True + + +def _test_predicate_groups_rejects_invalid_inputs(simulation_app) -> bool: + """Empty containers and non-callable entries should raise error.""" + from isaaclab_arena.tasks.fine_grained_progress_objective import FineGrainedProgressObjective + + try: + for bad in ([], {}, 42, "string"): + try: + FineGrainedProgressObjective(name="t", predicate_groups=bad) + except (ValueError, TypeError): + continue + print(f"Expected error for input {bad!r}") + return False + # logical=choose without K should raise error. + try: + FineGrainedProgressObjective( + name="t", + predicate_groups=_MockPredicate(num_envs=1), + logical="choose", + ) + except ValueError: + pass + else: + print("Expected ValueError for logical='choose' without K") + return False + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + return False + return True + + +def _test_state_machine_advances_sequentially(simulation_app) -> bool: + """A single FineGrainedProgressObjective with a 3 predicate chain advances one step per satisfied predicate.""" + from isaaclab_arena.tasks.fine_grained_progress_objective import FineGrainedProgressObjective + from isaaclab_arena.tasks.fine_grained_progress_tracker import FineGrainedProgressTracker + + try: + env = _MockEnv(num_envs=1) + preds = [_MockPredicate(num_envs=1, name=f"p{i}") for i in range(3)] + fgpo = FineGrainedProgressObjective(name="lift", predicate_groups=preds) + sm = FineGrainedProgressTracker(fine_grained_progress_objectives=[fgpo], num_envs=1, device="cpu") + sm.reset([0]) + + # Step 1: p0 True while p1, p2 still False. Advance to index 1. 
+ preds[0].set([True]) + _advance_step(env) + sm.step(env, step_index=env.episode_length_buf) + state = sm.get_state()[0]["fine_grained_progress_objectives"]["lift"] + assert state["completed_groups"] == 0 # 3-predicate chain not done until all 3 + assert not state["is_complete"] + events = sm.get_events()[0] + assert len(events) == 1 and events[0]["predicate_index"] == 0 + + # Step 2: p0 reverts False, p1 True. + preds[0].set([False]) + preds[1].set([True]) + _advance_step(env) + sm.step(env, step_index=env.episode_length_buf) + events = sm.get_events()[0] + assert len(events) == 2 and events[-1]["predicate_index"] == 1 + + # Step 3: p2 True, objective complete. + preds[2].set([True]) + _advance_step(env) + sm.step(env, step_index=env.episode_length_buf) + state = sm.get_state()[0]["fine_grained_progress_objectives"]["lift"] + assert state["is_complete"] + assert state["completed_groups"] == 1 + assert abs(state["score"] - 1.0) < SCORE_TOL + events = sm.get_events()[0] + assert len(events) == 3 and events[-1]["predicate_index"] == 2 + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + return False + return True + + +def _test_state_machine_ignores_out_of_order_success(simulation_app) -> bool: + """If a later predicate fires first, it's ignored until preceding ones have advanced.""" + from isaaclab_arena.tasks.fine_grained_progress_objective import FineGrainedProgressObjective + from isaaclab_arena.tasks.fine_grained_progress_tracker import FineGrainedProgressTracker + + try: + env = _MockEnv(num_envs=1) + preds = [_MockPredicate(num_envs=1, name=f"p{i}") for i in range(3)] + fgpo = FineGrainedProgressObjective(name="lift", predicate_groups=preds) + sm = FineGrainedProgressTracker(fine_grained_progress_objectives=[fgpo], num_envs=1, device="cpu") + sm.reset([0]) + + # p0 stays False and p1, p2 True. No progress should be made. 
+ preds[0].set([False]) + preds[1].set([True]) + preds[2].set([True]) + _advance_step(env) + sm.step(env, step_index=env.episode_length_buf) + state = sm.get_state()[0]["fine_grained_progress_objectives"]["lift"] + assert state["completed_groups"] == 0 + assert not state["is_complete"] + assert state["score"] == 0.0 + assert len(sm.get_events()[0]) == 0 + + # Now p0 True, p1, p2 should advance over subsequent steps. + preds[0].set([True]) + for _ in range(3): + _advance_step(env) + sm.step(env, step_index=env.episode_length_buf) + state = sm.get_state()[0]["fine_grained_progress_objectives"]["lift"] + assert state["is_complete"] + assert state["completed_groups"] == 1 + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + return False + return True + + +def _test_state_machine_logical_any(simulation_app) -> bool: + """Two parallel groups with logical=any complete as soon as either one finishes.""" + from isaaclab_arena.tasks.fine_grained_progress_objective import FineGrainedProgressObjective + from isaaclab_arena.tasks.fine_grained_progress_tracker import FineGrainedProgressTracker + + try: + env = _MockEnv(num_envs=1) + p_a = _MockPredicate(num_envs=1, name="a") + p_b = _MockPredicate(num_envs=1, name="b") + fgpo = FineGrainedProgressObjective( + name="either", + predicate_groups={"a": p_a, "b": p_b}, + logical="any", + ) + sm = FineGrainedProgressTracker(fine_grained_progress_objectives=[fgpo], num_envs=1, device="cpu") + sm.reset([0]) + + # Neither group complete -> not done. + _advance_step(env) + sm.step(env, step_index=env.episode_length_buf) + assert not sm.get_state()[0]["fine_grained_progress_objectives"]["either"]["is_complete"] + + # Group p_a completes -> done. 
+ p_a.set([True]) + _advance_step(env) + sm.step(env, step_index=env.episode_length_buf) + assert sm.get_state()[0]["fine_grained_progress_objectives"]["either"]["is_complete"] + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + return False + return True + + +def _test_state_machine_logical_all(simulation_app) -> bool: + """Two groups with logical=all complete once all groups are complete.""" + from isaaclab_arena.tasks.fine_grained_progress_objective import FineGrainedProgressObjective + from isaaclab_arena.tasks.fine_grained_progress_tracker import FineGrainedProgressTracker + + try: + env = _MockEnv(num_envs=1) + p_a = _MockPredicate(num_envs=1, name="a") + p_b = _MockPredicate(num_envs=1, name="b") + fgpo = FineGrainedProgressObjective( + name="both", + predicate_groups={"a": p_a, "b": p_b}, + logical="all", + ) + sm = FineGrainedProgressTracker(fine_grained_progress_objectives=[fgpo], num_envs=1, device="cpu") + sm.reset([0]) + + # Only p_a completes -> still not done. + p_a.set([True]) + _advance_step(env) + sm.step(env, step_index=env.episode_length_buf) + assert not sm.get_state()[0]["fine_grained_progress_objectives"]["both"]["is_complete"] + + # p_b also completes -> done. 
+ p_b.set([True]) + _advance_step(env) + sm.step(env, step_index=env.episode_length_buf) + assert sm.get_state()[0]["fine_grained_progress_objectives"]["both"]["is_complete"] + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + return False + return True + + +def _test_state_machine_logical_choose(simulation_app) -> bool: + """Three groups with logical=choose and K=2 complete once any two groups are complete.""" + from isaaclab_arena.tasks.fine_grained_progress_objective import FineGrainedProgressObjective + from isaaclab_arena.tasks.fine_grained_progress_tracker import FineGrainedProgressTracker + + try: + env = _MockEnv(num_envs=1) + p_a = _MockPredicate(num_envs=1, name="a") + p_b = _MockPredicate(num_envs=1, name="b") + p_c = _MockPredicate(num_envs=1, name="c") + fgpo = FineGrainedProgressObjective( + name="any_two", + predicate_groups={"a": p_a, "b": p_b, "c": p_c}, + logical="choose", + K=2, + ) + sm = FineGrainedProgressTracker(fine_grained_progress_objectives=[fgpo], num_envs=1, device="cpu") + sm.reset([0]) + + # Only p_a group complete -> not done. + p_a.set([True]) + _advance_step(env) + sm.step(env, step_index=env.episode_length_buf) + assert not sm.get_state()[0]["fine_grained_progress_objectives"]["any_two"]["is_complete"] + + # p_b also complete -> done. 
+ p_b.set([True]) + _advance_step(env) + sm.step(env, step_index=env.episode_length_buf) + assert sm.get_state()[0]["fine_grained_progress_objectives"]["any_two"]["is_complete"] + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + return False + return True + + +def _test_state_machine_reset_clears_state(simulation_app) -> bool: + """Resetting an env_id zeroes its progress and event log, but leaves other envs alone.""" + from isaaclab_arena.tasks.fine_grained_progress_objective import FineGrainedProgressObjective + from isaaclab_arena.tasks.fine_grained_progress_tracker import FineGrainedProgressTracker + + try: + env = _MockEnv(num_envs=2) + preds = [_MockPredicate(num_envs=2, name=f"p{i}") for i in range(2)] + fgpo = FineGrainedProgressObjective(name="t", predicate_groups=preds) + sm = FineGrainedProgressTracker(fine_grained_progress_objectives=[fgpo], num_envs=2, device="cpu") + sm.reset([0, 1]) + + # Set env 0 to fully complete. + preds[0].set([True, True]) + preds[1].set([True, False]) + _advance_step(env) + sm.step(env, step_index=env.episode_length_buf) + _advance_step(env) + sm.step(env, step_index=env.episode_length_buf) + + state = sm.get_state() + assert state[0]["fine_grained_progress_objectives"]["t"]["is_complete"] + assert not state[1]["fine_grained_progress_objectives"]["t"]["is_complete"] + assert len(sm.get_events()[0]) >= 2 + assert len(sm.get_events()[1]) >= 1 + + # Reset only env 0. + sm.reset([0]) + state = sm.get_state() + assert not state[0]["fine_grained_progress_objectives"]["t"]["is_complete"] + assert state[0]["fine_grained_progress_objectives"]["t"]["score"] == 0.0 + assert sm.get_events()[0] == [] + # env 1 untouched. 
+ assert len(sm.get_events()[1]) >= 1 + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + return False + return True + + +def _test_gating_advance_when_parent_subtask_idx_matches(simulation_app) -> bool: + """A FineGrainedProgressObjective with parent_subtask_idx=N advances when the env's _current_subtask_idx=N.""" + from isaaclab_arena.tasks.fine_grained_progress_objective import FineGrainedProgressObjective + from isaaclab_arena.tasks.fine_grained_progress_tracker import FineGrainedProgressTracker + + try: + env = _MockEnv(num_envs=1) + env._current_subtask_idx = [1] + + pred = _MockPredicate(num_envs=1, name="p") + fgpo = FineGrainedProgressObjective(name="t", predicate_groups=pred, parent_subtask_idx=1) + sm = FineGrainedProgressTracker(fine_grained_progress_objectives=[fgpo], num_envs=1, device="cpu") + sm.reset([0]) + + pred.set([True]) + _advance_step(env) + sm.step(env, step_index=env.episode_length_buf) + assert sm.get_state()[0]["fine_grained_progress_objectives"]["t"]["is_complete"] + assert len(sm.get_events()[0]) == 1 + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + return False + return True + + +def _test_gating_blocked_when_parent_subtask_idx_mismatches(simulation_app) -> bool: + """A FineGrainedProgressObjective with parent_subtask_idx=N doesn't advance when the env's _current_subtask_idx!=N.""" + from isaaclab_arena.tasks.fine_grained_progress_objective import FineGrainedProgressObjective + from isaaclab_arena.tasks.fine_grained_progress_tracker import FineGrainedProgressTracker + + try: + env = _MockEnv(num_envs=1) + env._current_subtask_idx = [0] + + pred = _MockPredicate(num_envs=1, name="p") + fgpo = FineGrainedProgressObjective(name="t", predicate_groups=pred, parent_subtask_idx=1) + sm = FineGrainedProgressTracker(fine_grained_progress_objectives=[fgpo], num_envs=1, device="cpu") + sm.reset([0]) + + # Predicate True, but the parent isn't at this FGPO's index yet. 
+ pred.set([True]) + _advance_step(env) + sm.step(env, step_index=env.episode_length_buf) + assert not sm.get_state()[0]["fine_grained_progress_objectives"]["t"]["is_complete"] + assert sm.get_state()[0]["fine_grained_progress_objectives"]["t"]["score"] == 0.0 + assert len(sm.get_events()[0]) == 0 + + # Parent advances to this FGPO's index, state machine advances. + env._current_subtask_idx = [1] + _advance_step(env) + sm.step(env, step_index=env.episode_length_buf) + assert sm.get_state()[0]["fine_grained_progress_objectives"]["t"]["is_complete"] + assert len(sm.get_events()[0]) == 1 + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + return False + return True + + +def _test_gating_sequential_task_end_to_end(simulation_app) -> bool: + """Two FGPOs with different parent subtask indices. The parent's + _current_subtask_idx advances over time. Each FGPO only progresses + during its active window.""" + from isaaclab_arena.tasks.fine_grained_progress_objective import FineGrainedProgressObjective + from isaaclab_arena.tasks.fine_grained_progress_tracker import FineGrainedProgressTracker + + try: + env = _MockEnv(num_envs=1) + env._current_subtask_idx = [0] + + pred_a = _MockPredicate(num_envs=1, name="a") + pred_b = _MockPredicate(num_envs=1, name="b") + fgpo_a = FineGrainedProgressObjective(name="a", predicate_groups=pred_a, parent_subtask_idx=0) + fgpo_b = FineGrainedProgressObjective(name="b", predicate_groups=pred_b, parent_subtask_idx=1) + sm = FineGrainedProgressTracker(fine_grained_progress_objectives=[fgpo_a, fgpo_b], num_envs=1, device="cpu") + sm.reset([0]) + + # Both predicates True, but only pred_a is active. 
+ pred_a.set([True]) + pred_b.set([True]) + _advance_step(env) + sm.step(env, step_index=env.episode_length_buf) + assert sm.get_state()[0]["fine_grained_progress_objectives"]["a"]["is_complete"] + assert not sm.get_state()[0]["fine_grained_progress_objectives"]["b"]["is_complete"] + + # Advances to subtask 1 so pred_b is now active. + env._current_subtask_idx = [1] + _advance_step(env) + sm.step(env, step_index=env.episode_length_buf) + assert sm.get_state()[0]["fine_grained_progress_objectives"]["b"]["is_complete"] + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + return False + return True + + +def _test_gating_noop_when_env_has_no_current_subtask_idx(simulation_app) -> bool: + """For unordered composite tasks gating is a no-op and all FGPOs advance whenever their predicates are True.""" + from isaaclab_arena.tasks.fine_grained_progress_objective import FineGrainedProgressObjective + from isaaclab_arena.tasks.fine_grained_progress_tracker import FineGrainedProgressTracker + + try: + env = _MockEnv(num_envs=1) + + pred = _MockPredicate(num_envs=1, name="p") + fgpo = FineGrainedProgressObjective(name="t", predicate_groups=pred, parent_subtask_idx=1) + sm = FineGrainedProgressTracker(fine_grained_progress_objectives=[fgpo], num_envs=1, device="cpu") + sm.reset([0]) + + pred.set([True]) + _advance_step(env) + sm.step(env, step_index=env.episode_length_buf) + assert sm.get_state()[0]["fine_grained_progress_objectives"]["t"]["is_complete"] + assert len(sm.get_events()[0]) == 1 + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + return False + return True + + +def _test_recorder_publishes_to_extras_and_records_nothing(simulation_app) -> bool: + """FineGrainedProgressRecorder.record_post_step writes env.extras and records nothing. 
+ + ``record_post_step`` returns ``(None, None)`` (so nothing is added to the recorded + episode data) while still ticking the tracker and publishing the per-step state to + ``env.extras["fine_grained_progress"]``. + """ + from isaaclab_arena.tasks.fine_grained_progress_objective import FineGrainedProgressObjective + from isaaclab_arena.tasks.fine_grained_progress_tracker import ( + FineGrainedProgressObjectiveRecorderCfg, + fine_grained_progress_reset_func, + ) + + try: + env = _MockEnv(num_envs=2) + pred = _MockPredicate(num_envs=2, name="p") + fgpo = FineGrainedProgressObjective(name="t", predicate_groups=pred) + + recorder_cfg = FineGrainedProgressObjectiveRecorderCfg(fine_grained_progress_objectives=[fgpo]) + recorder = recorder_cfg.class_type(recorder_cfg, env) + + fine_grained_progress_reset_func(env, env_ids=[0, 1], fine_grained_progress_objectives=[fgpo]) + + # Step with predicate=False, state machine ticks but no transitions. Records nothing. + assert recorder.record_post_step() == (None, None) + assert "fine_grained_progress" in env.extras + assert len(env.extras["fine_grained_progress"]["states"]) == 2 + assert env.extras["fine_grained_progress"]["events"] == [[], []] + assert not env.extras["fine_grained_progress"]["states"][0]["fine_grained_progress_objectives"]["t"][ + "is_complete" + ] + + # Step with env 0 predicate True, env 0 completes, env 1 does not. + pred.set([True, False]) + _advance_step(env) + assert recorder.record_post_step() == (None, None) + states = env.extras["fine_grained_progress"]["states"] + events = env.extras["fine_grained_progress"]["events"] + assert states[0]["fine_grained_progress_objectives"]["t"]["is_complete"] + assert not states[1]["fine_grained_progress_objectives"]["t"]["is_complete"] + assert len(events[0]) == 1 + assert len(events[1]) == 0 + + # Reset env 0, env 1 untouched. 
+ pred.set([False, False]) + fine_grained_progress_reset_func(env, env_ids=[0], fine_grained_progress_objectives=[fgpo]) + assert recorder.record_post_step() == (None, None) + states = env.extras["fine_grained_progress"]["states"] + assert not states[0]["fine_grained_progress_objectives"]["t"]["is_complete"] + assert states[0]["fine_grained_progress_objectives"]["t"]["score"] == 0.0 + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + return False + return True + + +def _test_task_base_fine_grained_progress_objective_hooks(simulation_app) -> bool: + """Test TaskBase's fine-grained-progress-objective hooks. Default is empty/None. Overriding + ``get_fine_grained_progress_objectives`` causes the events/recorder helpers to + return real cfgs that the env builder picks up automatically. + """ + from isaaclab_arena.tasks.fine_grained_progress_objective import FineGrainedProgressObjective + from isaaclab_arena.tasks.task_base import TaskBase + + try: + + class _Base(TaskBase): + def get_scene_cfg(self): + return None + + def get_termination_cfg(self): + return None + + def get_events_cfg(self): + return None + + def get_mimic_env_cfg(self, arm_mode): + return None + + def get_metrics(self): + return [] + + default_task = _Base() + assert default_task.get_fine_grained_progress_objectives() == [] + assert default_task.get_fine_grained_progress_objective_events_cfg() is None + assert default_task.get_fine_grained_progress_objective_recorder_cfg() is None + + class _OptIn(_Base): + def get_fine_grained_progress_objectives(self): + pred = _MockPredicate(num_envs=1, name="p") + return [FineGrainedProgressObjective(name="lift", predicate_groups=pred)] + + opt_in = _OptIn() + assert len(opt_in.get_fine_grained_progress_objectives()) == 1 + assert opt_in.get_fine_grained_progress_objective_events_cfg() is not None + assert opt_in.get_fine_grained_progress_objective_recorder_cfg() is not None + + from isaaclab_arena.tasks.composite_task_base import 
CompositeTaskBase + + class _ChildA(_Base): + def get_fine_grained_progress_objectives(self): + return [FineGrainedProgressObjective(name="open", predicate_groups=_MockPredicate(1, name="pa"))] + + class _ChildB(_Base): + def get_fine_grained_progress_objectives(self): + return [FineGrainedProgressObjective(name="close", predicate_groups=_MockPredicate(1, name="pb"))] + + composite = CompositeTaskBase(subtasks=[_ChildA(), _ChildB()]) + recipes = composite.get_fine_grained_progress_objectives() + assert len(recipes) == 2 + assert recipes[0].name == "subtask_0/open" + assert recipes[0].parent_subtask_idx == 0 + assert recipes[1].name == "subtask_1/close" + assert recipes[1].parent_subtask_idx == 1 + + class _CompositeWithOwn(CompositeTaskBase): + def get_own_fine_grained_progress_objectives(self): + return [FineGrainedProgressObjective(name="both_done", predicate_groups=_MockPredicate(1, name="own"))] + + composite2 = _CompositeWithOwn(subtasks=[_ChildA(), _ChildB()]) + recipes2 = composite2.get_fine_grained_progress_objectives() + assert len(recipes2) == 3 + assert recipes2[2].name == "both_done" + assert recipes2[2].parent_subtask_idx is None + except Exception as e: + print(f"Error: {e}") + traceback.print_exc() + return False + return True + + +def test_predicate_groups_single_callable(): + assert run_simulation_app_function(_test_predicate_groups_single_callable, headless=HEADLESS) + + +def test_predicate_groups_list_of_callables(): + assert run_simulation_app_function(_test_predicate_groups_list_of_callables, headless=HEADLESS) + + +def test_predicate_groups_weighted_tuples(): + assert run_simulation_app_function(_test_predicate_groups_weighted_tuples, headless=HEADLESS) + + +def test_predicate_groups_dict_groups(): + assert run_simulation_app_function(_test_predicate_groups_dict_groups, headless=HEADLESS) + + +def test_predicate_groups_rejects_invalid_inputs(): + assert run_simulation_app_function(_test_predicate_groups_rejects_invalid_inputs, 
headless=HEADLESS) + + +def test_state_machine_advances_sequentially(): + assert run_simulation_app_function(_test_state_machine_advances_sequentially, headless=HEADLESS) + + +def test_state_machine_ignores_out_of_order_success(): + assert run_simulation_app_function(_test_state_machine_ignores_out_of_order_success, headless=HEADLESS) + + +def test_state_machine_logical_any(): + assert run_simulation_app_function(_test_state_machine_logical_any, headless=HEADLESS) + + +def test_state_machine_logical_all(): + assert run_simulation_app_function(_test_state_machine_logical_all, headless=HEADLESS) + + +def test_state_machine_logical_choose(): + assert run_simulation_app_function(_test_state_machine_logical_choose, headless=HEADLESS) + + +def test_state_machine_reset_clears_state(): + assert run_simulation_app_function(_test_state_machine_reset_clears_state, headless=HEADLESS) + + +def test_gating_advance_when_parent_subtask_idx_matches(): + assert run_simulation_app_function(_test_gating_advance_when_parent_subtask_idx_matches, headless=HEADLESS) + + +def test_gating_blocked_when_parent_subtask_idx_mismatches(): + assert run_simulation_app_function(_test_gating_blocked_when_parent_subtask_idx_mismatches, headless=HEADLESS) + + +def test_gating_noop_when_env_has_no_current_subtask_idx(): + assert run_simulation_app_function(_test_gating_noop_when_env_has_no_current_subtask_idx, headless=HEADLESS) + + +def test_gating_sequential_task_end_to_end(): + assert run_simulation_app_function(_test_gating_sequential_task_end_to_end, headless=HEADLESS) + + +def test_recorder_publishes_to_extras_and_records_nothing(): + assert run_simulation_app_function(_test_recorder_publishes_to_extras_and_records_nothing, headless=HEADLESS) + + +def test_task_base_fine_grained_progress_objective_hooks(): + assert run_simulation_app_function(_test_task_base_fine_grained_progress_objective_hooks, headless=HEADLESS) + + +if __name__ == "__main__": + test_predicate_groups_single_callable() + 
test_predicate_groups_list_of_callables() + test_predicate_groups_weighted_tuples() + test_predicate_groups_dict_groups() + test_predicate_groups_rejects_invalid_inputs() + test_state_machine_advances_sequentially() + test_state_machine_ignores_out_of_order_success() + test_state_machine_logical_any() + test_state_machine_logical_all() + test_state_machine_logical_choose() + test_state_machine_reset_clears_state() + test_gating_advance_when_parent_subtask_idx_matches() + test_gating_blocked_when_parent_subtask_idx_mismatches() + test_gating_noop_when_env_has_no_current_subtask_idx() + test_gating_sequential_task_end_to_end() + test_recorder_publishes_to_extras_and_records_nothing() + test_task_base_fine_grained_progress_objective_hooks()