Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions isaaclab_arena/environments/arena_env_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def compose_manager_cfg(self) -> IsaacLabArenaManagerBasedRLEnvCfg:
task.get_events_cfg(),
placement_event_cfg,
variations_event_cfg,
task.get_fine_grained_progress_objective_events_cfg(),
)
termination_cfg = combine_configclass_instances(
"TerminationCfg",
Expand Down Expand Up @@ -211,6 +212,7 @@ def compose_manager_cfg(self) -> IsaacLabArenaManagerBasedRLEnvCfg:
metrics_recorder_manager_cfg,
task.get_recorder_term_cfg(),
embodiment.get_recorder_term_cfg(),
task.get_fine_grained_progress_objective_recorder_cfg(),
bases=(RecorderManagerBaseCfg,),
)
recorder_manager_cfg = self._modify_recorder_cfg_dataset_filename(recorder_manager_cfg)
Expand Down
27 changes: 27 additions & 0 deletions isaaclab_arena/tasks/composite_task_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# SPDX-License-Identifier: Apache-2.0

import copy
import dataclasses
import numpy as np
import torch
import warnings
Expand All @@ -20,6 +21,7 @@
from isaaclab_arena.metrics.metric_base import MetricBase
from isaaclab_arena.metrics.metric_term_cfg import MetricTermCfg
from isaaclab_arena.tasks.common.mimic_default_params import MIMIC_DATAGEN_CONFIG_DEFAULTS
from isaaclab_arena.tasks.fine_grained_progress_objective import FineGrainedProgressObjective
from isaaclab_arena.tasks.task_base import TaskBase
from isaaclab_arena.utils.configclass import (
check_configclass_field_duplicates,
Expand Down Expand Up @@ -360,6 +362,31 @@ def get_metrics(self) -> list[MetricBase]:

return subtask_metrics

def get_own_fine_grained_progress_objectives(self) -> list[FineGrainedProgressObjective]:
"""Composite-level FineGrainedProgressObjectives.

These are added on top of whatever FGPOs the child subtasks declare and are not gated.
"""
return []

def get_fine_grained_progress_objectives(self) -> list[FineGrainedProgressObjective]:
"""Concatenate child subtasks's FineGrainedProgressObjectives with namespace prefixes.

Each child's FGPO gets a new name (subtask_{i}/{original_name}) and a parent_subtask_idx = i tag.
"""
fgpo_list: list[FineGrainedProgressObjective] = []
for i, child in enumerate(self.subtasks):
for fgpo in child.get_fine_grained_progress_objectives():
fgpo_list.append(
dataclasses.replace(
fgpo,
name=f"subtask_{i}/{fgpo.name}",
parent_subtask_idx=i,
)
)
fgpo_list.extend(self.get_own_fine_grained_progress_objectives())
return fgpo_list

def _validate_consistent_mimic_eef_names(self, arm_mode: ArmMode) -> set[str]:
"Check that all subtasks have the same Mimic eef_names."
mimic_eef_names = set(self.subtasks[0].get_mimic_env_cfg(arm_mode).subtask_configs.keys())
Expand Down
177 changes: 177 additions & 0 deletions isaaclab_arena/tasks/fine_grained_progress_objective.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Literal, Union

PredicateGroups = Union[
Callable,
list[Callable],
list[tuple[Callable, float]],
dict[str, Callable],
dict[str, list[Callable]],
dict[str, list[tuple[Callable, float]]],
]


DEFAULT_GROUP_NAME = "default_group"


def format_predicate_groups(predicate_groups: PredicateGroups) -> dict[str, list[tuple[Callable, float]]]:
"""Format predicate_groups into the canonical form.

Canonical form: ``dict[group_name: list[(callable, score)]]``.

Accepted input shapes:
1. func (single callable) one group with one predicate
2. [func, func, ...] one group, sequential chain
3. [(func, score), ...] one group, sequential chain, weighted
4. {group: func} multiple groups, one predicate each
5. {group: [func, ...]} multiple groups, sequential chains
6. {group: [(func, score), ...]} multiple groups, sequential chains, weighted
"""

if callable(predicate_groups):
return {DEFAULT_GROUP_NAME: [(predicate_groups, 1.0)]}

if isinstance(predicate_groups, list):
if len(predicate_groups) == 0:
raise ValueError("FineGrainedProgressObjective.predicate_groups list cannot be empty")
return {DEFAULT_GROUP_NAME: _format_group_chain(predicate_groups, group_name=DEFAULT_GROUP_NAME)}

if isinstance(predicate_groups, dict):
if len(predicate_groups) == 0:
raise ValueError("FineGrainedProgressObjective.predicate_groups dict cannot be empty")
return {
group_name: _format_group_chain(value, group_name=group_name)
for group_name, value in predicate_groups.items()
}

raise TypeError(
"FineGrainedProgressObjective.predicate_groups must be a callable, list, or dict; got"
f" {type(predicate_groups).__name__}"
)


def _format_group_chain(value, group_name: str) -> list[tuple[Callable, float]]:
if callable(value):
return [(value, 1.0)]
if not isinstance(value, list):
raise TypeError(
f"Predicate chain for group '{group_name}' must be a callable or a list; got {type(value).__name__}"
)
if len(value) == 0:
raise ValueError(f"Predicate chain for group '{group_name}' cannot be empty")

first = value[0]
if isinstance(first, tuple):
chain = []
for i, item in enumerate(value):
if not (isinstance(item, tuple) and len(item) == 2):
raise TypeError(f"Group '{group_name}' index {i}: expected (callable, score) tuple, got {item!r}")
fn, score = item
if not callable(fn):
raise TypeError(f"Group '{group_name}' index {i}: first tuple element must be callable")
if not isinstance(score, (int, float)):
raise TypeError(f"Group '{group_name}' index {i}: score must be a number")
chain.append((fn, float(score)))
return chain

if callable(first):
equal = 1.0 / len(value)
chain = []
for i, fn in enumerate(value):
if not callable(fn):
raise TypeError(f"Group '{group_name}' index {i}: expected callable, got {type(fn).__name__}")
chain.append((fn, equal))
return chain

raise TypeError(
f"Group '{group_name}' elements must be callables or (callable, score) tuples; got {type(first).__name__}"
)


def normalize_scores(
predicate_groups: dict[str, list[tuple[Callable, float]]],
) -> dict[str, list[tuple[Callable, float]]]:
"""Scale each group's scores to sum to 1.0. Zero and negative-sum groups are left untouched."""

out: dict[str, list[tuple[Callable, float]]] = {}
for group, chain in predicate_groups.items():
total = sum(score for _, score in chain)
if total <= 0:
out[group] = list(chain)
continue
out[group] = [(fn, score / total) for fn, score in chain]
return out


@dataclass
class FineGrainedProgressObjective:
"""Configuration object that defines a scored predicate sequence to track progress within a task.

A FineGrainedProgressObjective specifies what the predicate state machine should track.
Each FineGrainedProgressObjective holds one or more sequential predicate chains (groups).
Within a group, predicates run in order. Across groups, predicates run in parallel.

Args:
name: Identifies the FineGrainedProgressObjective within the TaskBase.
predicate_groups: The sequential predicate chains that define the FineGrainedProgressObjective.
score: Weight of the FineGrainedProgressObjective in the TaskBase-level overall_score.
logical: How completed groups combine to determine if the FineGrainedProgressObjective is complete.
Can be "all", "any", or "choose"
K: Required when logical == "choose". Specifies the number of groups that must be completed
to consider the FineGrainedProgressObjective complete.
description: An optional description of the FineGrainedProgressObjective.
"""

name: str
predicate_groups: PredicateGroups
score: float = 1.0
logical: Literal["all", "any", "choose"] = "all"
K: int | None = None
description: str | None = None

canonical_predicate_groups: dict[str, list[tuple[Callable, float]]] = field(init=False, repr=False)

# Index of the parent TaskBase this recipe belongs to. Set automatically by
# CompositeTaskBase.get_fine_grained_progress_objectives() when used with composite tasks.
parent_subtask_idx: int | None = None

def __post_init__(self):
if not (0.0 <= self.score <= 1.0):
raise ValueError(f"FineGrainedProgressObjective '{self.name}': score must be in [0, 1], got {self.score}")
if self.logical not in ("all", "any", "choose"):
raise ValueError(
f"FineGrainedProgressObjective '{self.name}': logical must be in ['all', 'any', 'choose'], got"
f" {self.logical}"
)

# Format the predicate groups into the canonical form and normalize the scores.
formatted = format_predicate_groups(self.predicate_groups)
normalized = normalize_scores(formatted)
self.canonical_predicate_groups = normalized

# Validate the logical and K parameters.
num_groups = len(self.canonical_predicate_groups)
if self.logical == "choose":
if self.K is None:
raise ValueError(f"FineGrainedProgressObjective '{self.name}': K is required when logical='choose'")
if not (1 <= self.K <= num_groups):
raise ValueError(
f"FineGrainedProgressObjective '{self.name}': K={self.K} but must be in [1, {num_groups}]"
)

@property
def group_names(self) -> list[str]:
"""Returns the names of the groups in the FineGrainedProgressObjective."""
return list(self.canonical_predicate_groups.keys())

def get_chain(self, group_name: str) -> list[tuple[Callable, float]]:
"""Returns the chain of predicates for a given group."""
return self.canonical_predicate_groups[group_name]
Loading
Loading