Decent2512/robomotion

Robomotion

A JAX-based reinforcement learning framework for training a Unitree G1 humanoid to play table tennis through motion imitation. The agent learns to track reference motion-capture trajectories using Proximal Policy Optimization (PPO), trained entirely in MuJoCo MJX for GPU-accelerated parallel simulation. Policies are exported to ONNX for deployment on physical hardware.


Architecture

mocap data (.npz)
      │
      ▼
┌─────────────────────┐
│  Quality Filter     │  score_trajectory() → keep only low-jerk, in-limit,
│  (common/data/)     │  plausible-height trajectories before training
└────────┬────────────┘
         │ clean trajectories
         ▼
┌─────────────────────┐
│  Training Env       │  G1TrackingTennisEnv — MJX-based, vmapped,
│  (environments/)    │  domain-randomized, episode-wrapped
└────────┬────────────┘
         │ (obs, reward, done)
         ▼
┌─────────────────────┐
│  PPO Trainer        │  Brax PPO with custom rollout, metrics logger,
│  (algorithms/)      │  early stopping, and W&B progress reporting
└────────┬────────────┘
         │ JAX policy params
         ▼
┌─────────────────────┐
│  ONNX Exporter      │  convert_jax2onnx() traces the inference function
│  (evaluation/)      │  and exports a portable .onnx policy file
└────────┬────────────┘
         │ policy.onnx
         ▼
┌─────────────────────┐
│  MuJoCo Renderer    │  Loads ONNX policy, steps through reference
│  (evaluation/)      │  trajectory, renders video or launches viewer
└─────────────────────┘

Module overview

| Module | Purpose |
| --- | --- |
| robomotion/algorithms/ | PPO trainer, rollout, metrics logger, early stopping |
| robomotion/environments/ | G1 humanoid training & inference envs, domain-randomized wrappers |
| robomotion/common/ | Geometry utils, registry, stateful objects, config validator, trajectory data |
| robomotion/randomization/ | Physics domain randomization (friction, mass, COM, armature) |
| robomotion/evaluation/ | JAX→ONNX exporter, MuJoCo ONNX renderer |
| robomotion/cli/ | export_policy, setup_workspace entry-points |
| robomotion/scripts/ | Offline motion pre-processing |

Installation

1. Prerequisites

# Python 3.10+ recommended
pip install "jax[cuda12]" mujoco brax mujoco-playground \
            wandb onnxruntime tyro ml-collections flax \
            scipy tqdm absl-py

2. Environment variables

# Root of the project data/storage tree
export GLI_PATH=/path/to/gli_root

# Weights & Biases (required for training)
export WANDB_PROJECT=my-project
export WANDB_ENTITY=my-entity

GLI_PATH must contain:

gli_root/
  storage/
    assets/          # XML models, meshes, mujoco_menagerie
    data/
      mocap/         # .npz trajectory files
    logs/            # training runs written here
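
Before launching a run, it can help to confirm that the tree above is in place. The following is a minimal sketch, not part of robomotion: the helper name missing_workspace_dirs is hypothetical, and the directory list is taken directly from the layout shown above.

```python
from pathlib import Path

# Sub-directories from the layout above, relative to GLI_PATH.
REQUIRED_DIRS = ("storage/assets", "storage/data/mocap", "storage/logs")


def missing_workspace_dirs(root):
    """Return the required sub-directories that do not exist under root.

    Hypothetical helper for illustration; robomotion may perform its own checks.
    """
    root = Path(root)
    return [d for d in REQUIRED_DIRS if not (root / d).is_dir()]
```

Calling missing_workspace_dirs(os.environ["GLI_PATH"]) returns an empty list when every expected directory exists, and the names of the absent ones otherwise.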

3. Link assets

python -m robomotion.cli.setup_workspace

This symlinks mujoco_menagerie into the location expected by mujoco_playground.


Quickstart

Train

python -m robomotion.algorithms.runners.tennis_ppo \
    --task G1TrackingTennis \
    --exp_name my_run

For domain-randomized training:

python -m robomotion.algorithms.runners.tennis_ppo \
    --task G1TrackingTennisDR \
    --exp_name my_run_dr

Debug mode (small batch, fast iteration, no W&B):

python -m robomotion.algorithms.runners.tennis_ppo \
    --task G1TrackingTennis \
    --exp_name debug

Export to ONNX

python -m robomotion.cli.export_policy \
    --task G1TrackingTennis \
    --exp_name 0318120000_G1TrackingTennis_my_run
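
The export command takes the full run directory name rather than the short name passed at training time. Judging only from the example above, that name looks like a month-day-hour-minute-second prefix followed by the task and the experiment name. The sketch below rebuilds it under that assumption; make_run_name is a hypothetical helper, and the trainer's actual naming scheme may differ, so check the directories under storage/logs/.

```python
from datetime import datetime


def make_run_name(task, exp_name, now=None):
    """Build '<MMDDHHMMSS>_<task>_<exp_name>' (format inferred from the example, not confirmed)."""
    now = now or datetime.now()
    return f"{now.strftime('%m%d%H%M%S')}_{task}_{exp_name}"


# The year is arbitrary; only month, day, and time appear in the prefix.
print(make_run_name("G1TrackingTennis", "my_run", datetime(2025, 3, 18, 12, 0, 0)))
# → 0318120000_G1TrackingTennis_my_run
```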

Run inference / render

python -m robomotion.evaluation.motion.renderer \
    --exp_name 0318120000_G1TrackingTennis_my_run \
    --use_renderer

Configuration

Configs live in robomotion/environments/humanoid/tennis_config.py and are accessed via the registry:

import robomotion as lmj
task_cfg = lmj.registry.get("G1TrackingTennis", "tracking_config")
env_cfg  = task_cfg.env_config      # environment parameters
pol_cfg  = task_cfg.policy_config   # PPO hyperparameters

Key parameters to tune:

| Parameter | Location | Description |
| --- | --- | --- |
| num_timesteps | policy_config | Total environment steps |
| episode_length | policy_config | Steps per episode |
| batch_size | policy_config | Parallel environments |
| ctrl_dt / sim_dt | env_config | Control and simulation timesteps |
| soft_joint_pos_limit_factor | env_config | Joint limit safety margin |
| with_racket | CLI --with_racket | Include racket in the model |
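
The two timesteps are coupled: each control step usually spans a whole number of simulation steps, so ctrl_dt is normally chosen as an integer multiple of sim_dt. The sketch below shows that arithmetic with hypothetical values (not the project's defaults); the helper names are illustrative and do not come from robomotion.

```python
def physics_substeps(ctrl_dt, sim_dt, tol=1e-9):
    """Simulation steps per control step; ctrl_dt must be a positive multiple of sim_dt."""
    ratio = ctrl_dt / sim_dt
    n = round(ratio)
    if n < 1 or abs(ratio - n) > tol:
        raise ValueError(f"ctrl_dt={ctrl_dt} is not a positive multiple of sim_dt={sim_dt}")
    return n


def episode_seconds(episode_length, ctrl_dt):
    """Simulated duration of one episode, in seconds."""
    return episode_length * ctrl_dt


# Hypothetical values for illustration only.
print(physics_substeps(ctrl_dt=0.02, sim_dt=0.005))  # → 4
print(episode_seconds(episode_length=1000, ctrl_dt=0.02))
```

Halving sim_dt at a fixed ctrl_dt doubles the physics cost per control step, which is worth keeping in mind when budgeting num_timesteps.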

New Features

Config Validator (robomotion/common/validator.py)

Validates training configs before any computation begins, giving clear error messages instead of silent training failures.

from robomotion.common.validator import ConfigValidator

report = ConfigValidator.from_defaults().validate(task_cfg.to_dict())
report.raise_if_invalid()   # raises ValueError listing all problems

Add custom rules:

from robomotion.common.validator import RangeRule

validator = ConfigValidator.from_defaults()
validator.add_rule(RangeRule(key="action_scale", min_val=0.1, max_val=2.0))
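
To make the idea concrete, here is a self-contained sketch of how a range rule and an "report all problems at once" validator could work. It is not the code in validator.py: SketchRangeRule, its check method, and the validate function are hypothetical stand-ins, and only the key/min_val/max_val parameters mirror the RangeRule call shown above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SketchRangeRule:
    """Hypothetical stand-in for a range rule; not robomotion's RangeRule."""
    key: str
    min_val: float
    max_val: float

    def check(self, config):
        """Return an error message, or None if the value is acceptable."""
        if self.key not in config:
            return f"{self.key}: missing"
        value = config[self.key]
        if not (self.min_val <= value <= self.max_val):
            return f"{self.key}={value} outside [{self.min_val}, {self.max_val}]"
        return None


def validate(config, rules):
    """Run every rule and collect all failures instead of stopping at the first."""
    return [msg for rule in rules if (msg := rule.check(config)) is not None]


rules = [SketchRangeRule("action_scale", 0.1, 2.0)]
print(validate({"action_scale": 0.5}, rules))  # → []
print(validate({"action_scale": 5.0}, rules))  # → ['action_scale=5.0 outside [0.1, 2.0]']
```

Collecting every failure before raising is what lets a single error list all misconfigured keys in one pass.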

Early Stopping (robomotion/algorithms/policy/ppo/early_stopping.py)

Monitors a training metric and terminates the run when improvement stalls, saving GPU hours on runs that have converged or diverged.

from robomotion.algorithms.policy.ppo.early_stopping import EarlyStopping

stopper = EarlyStopping(
    metric_key="episode/reward",
    patience=10,       # stop after 10 non-improving iterations
    min_delta=1e-4,    # minimum improvement to count as progress
    mode="max",
)

signal = stopper.update(metrics, step=num_steps)
if signal.stop:
    print(f"Stopped: {signal.reason} (best at step {signal.best_step})")

Integrated automatically into tennis_ppo.py — no manual setup required during training.
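
The interaction of patience, min_delta, and mode can be illustrated with a small standalone sketch. This is an assumed reading of those parameters, not the implementation in early_stopping.py; SketchStopper is a hypothetical class that returns a plain boolean rather than the signal object shown above.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchStopper:
    """Hypothetical patience-based stopper; not robomotion's EarlyStopping."""
    patience: int = 10
    min_delta: float = 1e-4
    mode: str = "max"
    best: Optional[float] = None
    best_step: int = 0
    bad_updates: int = 0

    def update(self, value, step):
        """Record one metric value; return True when training should stop."""
        if self.best is None:
            improved = True
        elif self.mode == "max":
            improved = value > self.best + self.min_delta
        else:
            improved = value < self.best - self.min_delta
        if improved:
            # Any real improvement resets the patience counter.
            self.best, self.best_step, self.bad_updates = value, step, 0
        else:
            self.bad_updates += 1
        return self.bad_updates >= self.patience


stopper = SketchStopper(patience=3)
for step, reward in enumerate([1.0, 2.0, 2.0, 2.0, 2.0]):
    if stopper.update(reward, step):
        print(f"stop at step {step}, best {stopper.best} at step {stopper.best_step}")
```

With patience=3, the loop stops at step 4: the reward last improved at step 1 and then stayed flat for three consecutive updates.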


Trajectory Quality Filter (robomotion/common/data/quality_filter.py)

Scores motion-capture trajectories on three criteria before training begins, removing noisy or physically implausible data that would hurt policy learning.

| Criterion | Weight | What it measures |
| --- | --- | --- |
| Jerk score | 0.4 | Smoothness of joint velocity (low jerk = smooth) |
| Joint limit score | 0.3 | Fraction of frames within joint limits |
| Root height score | 0.3 | Fraction of frames with plausible root height |

from robomotion.common.data.quality_filter import TrajectoryQualityFilter

qf = TrajectoryQualityFilter(min_quality_score=0.5)
good_trajs = qf.filter_trajectories(all_trajectories, joint_limits)

# Inspect a single trajectory
report = qf.score_trajectory(traj, joint_limits)
print(report)
# TrajectoryQualityReport(overall=0.819, passes=True, [jerk=0.923, joint_limits=0.800, root_height=0.700])

Applied automatically inside prepare_trajectory() — trajectories below min_quality_score=0.5 are excluded before training.
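
The sketch below shows one way the three criteria could be combined, assuming the overall score is a plain weighted sum using the weights in the table. That combination rule, and the helper names overall_score and fraction_within_limits, are assumptions for illustration; quality_filter.py may compute its scores differently.

```python
# Weights from the table above; combining them as a weighted sum is an assumption.
WEIGHTS = {"jerk": 0.4, "joint_limits": 0.3, "root_height": 0.3}


def overall_score(scores):
    """Weighted sum of per-criterion scores, each expected in [0, 1]."""
    return sum(WEIGHTS[name] * scores[name] for name in WEIGHTS)


def fraction_within_limits(frames, lower, upper):
    """Fraction of frames in which every joint lies inside [lower, upper]."""
    ok = sum(
        all(lo <= q <= hi for q, lo, hi in zip(frame, lower, upper))
        for frame in frames
    )
    return ok / len(frames)


# Toy data: four frames of a two-joint trajectory; the second frame violates a limit.
frames = [[0.0, 0.5], [1.5, 0.0], [0.2, -0.2], [0.1, 0.1]]
joint_score = fraction_within_limits(frames, lower=[-1.0, -1.0], upper=[1.0, 1.0])

scores = {"jerk": 0.9, "joint_limits": joint_score, "root_height": 0.7}
total = overall_score(scores)
print(round(total, 3), total >= 0.5)
```

Because jerk carries the largest weight, a noisy but in-limit trajectory can still fall below the threshold, while a smooth trajectory with a few out-of-limit frames may pass.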


Project Structure

robomotion/
├── __init__.py
├── config.py                          # Path configuration (GLI_PATH)
├── asset_setup.py                     # Asset symlinking utilities
│
├── algorithms/
│   ├── policy/ppo/
│   │   ├── trainer.py                 # Brax PPO training loop
│   │   ├── rollout.py                 # Policy rollout logic
│   │   ├── metrics.py                 # Episode metrics aggregation & logging
│   │   └── early_stopping.py          # [NEW] Early stopping controller
│   └── runners/
│       └── tennis_ppo.py              # Top-level training entry-point
│
├── environments/
│   └── humanoid/
│       ├── tennis_config.py           # Task constants and default configs
│       ├── training/
│       │   ├── base.py                # G1Env abstract base class
│       │   ├── tennis.py              # G1TrackingTennisEnv
│       │   └── tennis_dr.py           # Domain-randomized variant
│       ├── inference/
│       │   └── tennis.py              # PlayG1TrackingTennisEnv (for rendering)
│       └── wrappers/
│           └── vectorization.py       # VmapWrapper, EpisodeWrapper, DR wrapper
│
├── common/
│   ├── registry.py                    # Task/category registry
│   ├── stateful.py                    # StatefulObject base class
│   ├── geometry.py                    # JAX geometry utilities
│   ├── geometry_np.py                 # NumPy geometry utilities
│   ├── physics_utils.py               # MuJoCo model/data utilities
│   ├── validator.py                   # [NEW] Config validation framework
│   └── data/
│       ├── trajectory.py              # Trajectory dataclass
│       ├── loader.py                  # TrajectoryHandler
│       ├── processor.py               # ExtendTrajData replay callback
│       └── quality_filter.py          # [NEW] Trajectory quality scoring
│
├── randomization/
│   └── physics.py                     # Domain randomization functions
│
├── evaluation/
│   └── motion/
│       ├── exporter.py                # JAX → ONNX conversion
│       └── renderer.py                # ONNX policy renderer / video writer
│
├── cli/
│   ├── export_policy.py               # CLI: export checkpoint to ONNX
│   └── setup_workspace.py             # CLI: workspace initialization
│
└── scripts/
    └── preprocessing/
        └── motion.py                  # Offline motion pre-processing
