A JAX-based reinforcement learning framework for training a Unitree G1 humanoid to play table tennis through motion imitation. The agent learns to track reference motion-capture trajectories using Proximal Policy Optimization (PPO), trained entirely in MuJoCo MJX for GPU-accelerated parallel simulation. Policies are exported to ONNX for deployment on physical hardware.
  mocap data (.npz)
         │
         ▼
┌─────────────────────┐
│  Quality Filter     │  score_trajectory() → keep low-jerk, in-limit,
│  (common/data/)     │  plausible-height trajectories before training
└────────┬────────────┘
         │  clean trajectories
         ▼
┌─────────────────────┐
│  Training Env       │  G1TrackingTennisEnv: MJX-based, vmapped,
│  (environments/)    │  domain-randomized, episode-wrapped
└────────┬────────────┘
         │  (obs, reward, done)
         ▼
┌─────────────────────┐
│  PPO Trainer        │  Brax PPO with custom rollout, metrics logger,
│  (algorithms/)      │  early stopping, and W&B progress reporting
└────────┬────────────┘
         │  JAX policy params
         ▼
┌─────────────────────┐
│  ONNX Exporter      │  convert_jax2onnx() traces the inference function
│  (evaluation/)      │  and exports a portable .onnx policy file
└────────┬────────────┘
         │  policy.onnx
         ▼
┌─────────────────────┐
│  MuJoCo Renderer    │  Loads ONNX policy, steps through reference
│  (evaluation/)      │  trajectory, renders video or launches viewer
└─────────────────────┘

| Module | Purpose |
|---|---|
| robomotion/algorithms/ | PPO trainer, rollout, metrics logger, early stopping |
| robomotion/environments/ | G1 humanoid training & inference envs, domain-randomized wrappers |
| robomotion/common/ | Geometry utils, registry, stateful objects, config validator, trajectory data |
| robomotion/randomization/ | Physics domain randomization (friction, mass, COM, armature) |
| robomotion/evaluation/ | JAX→ONNX exporter, MuJoCo ONNX renderer |
| robomotion/cli/ | export_policy, setup_workspace entry-points |
| robomotion/scripts/ | Offline motion pre-processing |

# Python 3.10+ recommended
pip install "jax[cuda12]" mujoco brax mujoco-playground \
wandb onnxruntime tyro ml-collections flax \
scipy tqdm absl-py

# Root of the project data/storage tree
export GLI_PATH=/path/to/gli_root
# Weights & Biases (required for training)
export WANDB_PROJECT=my-project
export WANDB_ENTITY=my-entity

GLI_PATH must contain:
gli_root/
storage/
assets/ # XML models, meshes, mujoco_menagerie
data/
mocap/ # .npz trajectory files
logs/ # training runs written here
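
Before running the setup step it can help to confirm that this tree exists. The sketch below is a hypothetical helper (`missing_workspace_dirs` is not part of robomotion), and the nesting it checks (`storage/assets`, `storage/data/mocap`, `storage/logs`) is an assumption read off the listing above; adjust the paths if your layout differs.

```python
import os
from pathlib import Path

# Assumed nesting of the listing above; edit to match your actual tree.
ASSUMED_DIRS = ("storage/assets", "storage/data/mocap", "storage/logs")


def missing_workspace_dirs(root, required=ASSUMED_DIRS):
    """Return the required sub-directories that do not exist under root."""
    root = Path(root)
    return [rel for rel in required if not (root / rel).is_dir()]


# Falls back to the current directory when GLI_PATH is unset.
gli_path = os.environ.get("GLI_PATH", ".")
missing = missing_workspace_dirs(gli_path)
print("workspace OK" if not missing else f"missing: {missing}")
```

An empty result means every assumed directory is present; anything else lists what still has to be created.
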
python -m robomotion.cli.setup_workspace

This symlinks mujoco_menagerie into the location expected by mujoco_playground.

python -m robomotion.algorithms.runners.tennis_ppo \
--task G1TrackingTennis \
--exp_name my_run

For domain-randomized training:
python -m robomotion.algorithms.runners.tennis_ppo \
--task G1TrackingTennisDR \
--exp_name my_run_dr

Debug mode (small batch, fast iteration, no W&B):
python -m robomotion.algorithms.runners.tennis_ppo \
--task G1TrackingTennis \
--exp_name debug

Export a trained checkpoint to ONNX:
python -m robomotion.cli.export_policy \
--task G1TrackingTennis \
--exp_name 0318120000_G1TrackingTennis_my_run

Render the exported ONNX policy in MuJoCo:
python -m robomotion.evaluation.motion.renderer \
--exp_name 0318120000_G1TrackingTennis_my_run \
--use_renderer

Configs live in robomotion/environments/humanoid/tennis_config.py and are accessed via the registry:
import robomotion as lmj
task_cfg = lmj.registry.get("G1TrackingTennis", "tracking_config")
env_cfg = task_cfg.env_config # environment parameters
pol_cfg = task_cfg.policy_config # PPO hyperparameters

Key parameters to tune:
| Parameter | Location | Description |
|---|---|---|
| num_timesteps | policy_config | Total environment steps |
| episode_length | policy_config | Steps per episode |
| batch_size | policy_config | Parallel environments |
| ctrl_dt / sim_dt | env_config | Control and simulation timesteps |
| soft_joint_pos_limit_factor | env_config | Joint limit safety margin |
| with_racket | CLI --with_racket | Include racket in the model |
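
The timing parameters interact, so it can be useful to work out what a given combination implies before launching a run. The sketch below is a stand-alone illustration, not project code: `derived_timing` is a hypothetical helper, the numbers are made-up examples rather than the defaults in tennis_config.py, and it assumes the common MuJoCo convention that one environment step is one control step of ctrl_dt seconds made up of a whole number of sim_dt physics substeps.

```python
def derived_timing(ctrl_dt, sim_dt, episode_length, num_timesteps):
    """Derive quantities implied by the timing parameters.

    Assumes one environment step equals one control step of ctrl_dt seconds.
    """
    ratio = ctrl_dt / sim_dt
    substeps = round(ratio)
    if substeps < 1 or abs(ratio - substeps) > 1e-6:
        raise ValueError("ctrl_dt should be a whole multiple of sim_dt")
    return {
        "physics_substeps_per_action": substeps,
        "episode_seconds": episode_length * ctrl_dt,
        "simulated_hours": num_timesteps * ctrl_dt / 3600.0,
    }


# Illustrative values only, not the project's defaults.
info = derived_timing(
    ctrl_dt=0.02, sim_dt=0.005, episode_length=1000, num_timesteps=100_000_000
)
print(info)
```

With these example values each action spans 4 physics substeps and an episode covers 20 seconds of simulated time.
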
The config validator checks training configs before any computation begins, giving clear error messages instead of silent training failures.
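
To make the "collect every problem, then fail once" idea concrete, here is a minimal stand-alone sketch. It is not the robomotion implementation: `SimpleRangeRule` and `validate` are hypothetical names, and the ranges are illustrative.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SimpleRangeRule:
    """Hypothetical rule: config[key] must lie in [min_val, max_val]."""

    key: str
    min_val: float
    max_val: float

    def check(self, config):
        """Return a problem description, or None when the value is fine."""
        if self.key not in config:
            return f"{self.key}: missing"
        value = config[self.key]
        if not (self.min_val <= value <= self.max_val):
            return f"{self.key}={value} outside [{self.min_val}, {self.max_val}]"
        return None


def validate(config, rules):
    """Run every rule and raise one ValueError that lists all problems."""
    problems = [msg for rule in rules if (msg := rule.check(config)) is not None]
    if problems:
        raise ValueError("invalid config:\n  " + "\n  ".join(problems))


# Illustrative ranges, not the project's defaults.
rules = [SimpleRangeRule("action_scale", 0.1, 2.0), SimpleRangeRule("ctrl_dt", 0.001, 0.1)]
validate({"action_scale": 0.5, "ctrl_dt": 0.02}, rules)  # passes silently
```

Reporting all violations in one exception avoids the fix-one-rerun-find-the-next cycle that a fail-fast check would cause.
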
from robomotion.common.validator import ConfigValidator
report = ConfigValidator.from_defaults().validate(task_cfg.to_dict())
report.raise_if_invalid() # raises ValueError listing all problems

Add custom rules:
from robomotion.common.validator import RangeRule
validator = ConfigValidator.from_defaults()
validator.add_rule(RangeRule(key="action_scale", min_val=0.1, max_val=2.0))

Early stopping monitors a training metric and terminates the run when improvement stalls, saving GPU hours on runs that have converged or diverged.
from robomotion.algorithms.policy.ppo.early_stopping import EarlyStopping
stopper = EarlyStopping(
metric_key="episode/reward",
patience=10, # stop after 10 non-improving iterations
min_delta=1e-4, # minimum improvement to count as progress
mode="max",
)
signal = stopper.update(metrics, step=num_steps)
if signal.stop:
    print(f"Stopped: {signal.reason} (best at step {signal.best_step})")

Early stopping is integrated automatically into tennis_ppo.py; no manual setup is required during training.
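
The patience and min_delta semantics can be illustrated with a small stand-alone sketch. `PatienceStopper` is a hypothetical stand-in written for this example, not the EarlyStopping class from robomotion, and its exact tie-breaking rules are assumptions.

```python
class PatienceStopper:
    """Hypothetical stand-in showing patience-based early stopping."""

    def __init__(self, patience, min_delta=0.0, mode="max"):
        if mode not in ("max", "min"):
            raise ValueError("mode must be 'max' or 'min'")
        self.patience = patience
        self.min_delta = min_delta
        self.sign = 1.0 if mode == "max" else -1.0
        self.best = None
        self.best_step = None
        self.bad_updates = 0

    def update(self, value, step):
        """Record one metric value; return True when training should stop."""
        score = self.sign * value  # after the sign flip, larger is always better
        if self.best is None or score > self.best + self.min_delta:
            self.best, self.best_step, self.bad_updates = score, step, 0
        else:
            self.bad_updates += 1
        return self.bad_updates >= self.patience


stopper = PatienceStopper(patience=3, min_delta=1e-4, mode="max")
rewards = [1.0, 1.5, 1.5, 1.49, 1.50005, 2.0]
for step, reward in enumerate(rewards):
    if stopper.update(reward, step):
        print(f"stop at step {step}, best was at step {stopper.best_step}")
        break
```

In this run the gain at step 4 (1.50005 versus 1.5) is smaller than min_delta, so it counts as a third non-improving update and the loop stops before ever seeing the 2.0 reward. That is the trade-off of a small patience value.
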
The trajectory quality filter scores motion-capture trajectories on three criteria before training begins, removing noisy or physically implausible data that would hurt policy learning.
| Criterion | Weight | What it measures |
|---|---|---|
| Jerk score | 0.4 | Smoothness of joint velocity (low jerk = smooth) |
| Joint limit score | 0.3 | Fraction of frames within joint limits |
| Root height score | 0.3 | Fraction of frames with plausible root height |
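
As a worked example of how the weights in the table could combine, the sketch below assumes the overall score is a plain weighted sum of the three per-criterion scores and that a trajectory passes when that sum reaches the threshold. The helper names, the sample numbers, and the 0.4 to 1.0 m height band are all illustrative assumptions, not values taken from quality_filter.py.

```python
WEIGHTS = {"jerk": 0.4, "joint_limits": 0.3, "root_height": 0.3}


def fraction_within(values, low, high):
    """Fraction of values v with low <= v <= high (0.0 for empty input)."""
    values = list(values)
    if not values:
        return 0.0
    return sum(low <= v <= high for v in values) / len(values)


def overall_score(scores, weights=WEIGHTS):
    """Weighted sum of per-criterion scores, each expected in [0, 1]."""
    return sum(weights[name] * scores[name] for name in weights)


# Root height per frame in metres; the band below is an illustrative guess.
heights = [0.75, 0.74, 0.20, 0.76]
scores = {
    "jerk": 0.9,
    "joint_limits": 1.0,
    "root_height": fraction_within(heights, 0.4, 1.0),  # 3 of 4 frames in band
}
overall = overall_score(scores)
passes = overall >= 0.5  # threshold playing the role of min_quality_score
print(round(overall, 3), passes)
```

Here the sum is 0.4 × 0.9 + 0.3 × 1.0 + 0.3 × 0.75 = 0.885, so the example trajectory would be kept.
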
from robomotion.common.data.quality_filter import TrajectoryQualityFilter
qf = TrajectoryQualityFilter(min_quality_score=0.5)
good_trajs = qf.filter_trajectories(all_trajectories, joint_limits)
# Inspect a single trajectory
report = qf.score_trajectory(traj, joint_limits)
print(report)
# TrajectoryQualityReport(overall=0.819, passes=True, [jerk=0.923, joint_limits=0.800, root_height=0.700])

The filter is applied automatically inside prepare_trajectory(): trajectories scoring below min_quality_score=0.5 are excluded before training.

robomotion/
├── __init__.py
├── config.py                     # Path configuration (GLI_PATH)
├── asset_setup.py                # Asset symlinking utilities
│
├── algorithms/
│   ├── policy/ppo/
│   │   ├── trainer.py            # Brax PPO training loop
│   │   ├── rollout.py            # Policy rollout logic
│   │   ├── metrics.py            # Episode metrics aggregation & logging
│   │   └── early_stopping.py     # [NEW] Early stopping controller
│   └── runners/
│       └── tennis_ppo.py         # Top-level training entry-point
│
├── environments/
│   └── humanoid/
│       ├── tennis_config.py      # Task constants and default configs
│       ├── training/
│       │   ├── base.py           # G1Env abstract base class
│       │   ├── tennis.py         # G1TrackingTennisEnv
│       │   └── tennis_dr.py      # Domain-randomized variant
│       ├── inference/
│       │   └── tennis.py         # PlayG1TrackingTennisEnv (for rendering)
│       └── wrappers/
│           └── vectorization.py  # VmapWrapper, EpisodeWrapper, DR wrapper
│
├── common/
│   ├── registry.py               # Task/category registry
│   ├── stateful.py               # StatefulObject base class
│   ├── geometry.py               # JAX geometry utilities
│   ├── geometry_np.py            # NumPy geometry utilities
│   ├── physics_utils.py          # MuJoCo model/data utilities
│   ├── validator.py              # [NEW] Config validation framework
│   └── data/
│       ├── trajectory.py         # Trajectory dataclass
│       ├── loader.py             # TrajectoryHandler
│       ├── processor.py          # ExtendTrajData replay callback
│       └── quality_filter.py     # [NEW] Trajectory quality scoring
│
├── randomization/
│   └── physics.py                # Domain randomization functions
│
├── evaluation/
│   └── motion/
│       ├── exporter.py           # JAX → ONNX conversion
│       └── renderer.py           # ONNX policy renderer / video writer
│
├── cli/
│   ├── export_policy.py          # CLI: export checkpoint to ONNX
│   └── setup_workspace.py        # CLI: workspace initialization
│
└── scripts/
    └── preprocessing/
        └── motion.py             # Offline motion pre-processing