Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
63 commits
Select commit Hold shift + click to select a range
b79ac14
feat: Added run_server.sh.
qrico64 Mar 28, 2025
080725b
feat: Added cube reach environment (although it should really be call…
qrico64 Apr 2, 2025
0b9825e
log intervention rate
pranavnt Apr 3, 2025
43f72a9
feat: Added episode duration to wandb.
qrico64 Apr 4, 2025
e442a3a
feat: Added episode / success rate to wandb.
qrico64 Apr 4, 2025
8f0c1fa
feat: Support identifying immediate failures by pressing f.
qrico64 Apr 4, 2025
6b36aa7
feat: Setup script and environment.
qrico64 Apr 5, 2025
c26e47a
feat: Added second cube_reach task for fixed cube.
qrico64 Apr 5, 2025
7d69de7
Merge branch 'main' of github.com:WEIRDLabUW/jax-hitl-hil-serl into main
qrico64 Apr 5, 2025
8c6fcec
feat: Support RLIF experiments.
qrico64 Apr 6, 2025
d41a096
feat: Added CL, not sure if it works on robot.
qrico64 Apr 7, 2025
acf5c83
feat: Supported pretraining for RLIF.
qrico64 Apr 8, 2025
edfcb4f
fix: Fixed open-ended checkpoint replay buffer issue.
qrico64 Apr 9, 2025
a6a048c
feat: Added BC pretraining and breakpoint debugging.
qrico64 Apr 9, 2025
091b00c
fix: Fixed previous commit.
qrico64 Apr 9, 2025
88efbe0
feat: Added last n actions wrapper.
qrico64 Apr 9, 2025
661d623
feat: Supporting ZED Cameras (not done, in progress).
qrico64 Apr 9, 2025
6dc478a
fix: Quick Fix.
qrico64 Apr 9, 2025
616bb76
feat: Supporting ZED Cameras (part 2, need testing).
qrico64 Apr 10, 2025
1e5c850
feat: Fixed franka server launch script.
qrico64 Apr 11, 2025
f80f3fc
fix: Added instructions to fix HID not connecting to spacemouse.
qrico64 Apr 11, 2025
12f9ac5
feat: Supporting ZED Cameras (part 3, works but very laggy).
qrico64 Apr 11, 2025
43ec28c
feat: Supporting ZED Cameras Done, added cube_reach3 environment.
qrico64 Apr 11, 2025
9c7901c
oops.
qrico64 Apr 11, 2025
b14bed5
fix: Fixed action wrapper, fixed zed camera exposure issue.
qrico64 Apr 11, 2025
a29cb47
preference buffer
pranavnt Apr 17, 2025
90a9ef6
preference buffer data store
pranavnt Apr 17, 2025
253ee9a
feat: Various infrastructural improvements, and fixing the middle-of-…
qrico64 Apr 17, 2025
9f951c9
Merge branch 'main' of github.com:WEIRDLabUW/jax-hitl-hil-serl into main
qrico64 Apr 17, 2025
512a920
impl progress
pranavnt Apr 17, 2025
66cd991
cl implementation
pranavnt Apr 17, 2025
6336897
done
pranavnt Apr 17, 2025
f45525d
feat: Various fixes.
qrico64 Apr 18, 2025
3270b93
fix: Fixed a bunch of jax==0.6.0 issues.
qrico64 Apr 19, 2025
ae2f5f3
feat: By the previous commit, I meant as in, now this code should wor…
qrico64 Apr 19, 2025
f955167
feat: Fixed CL to include another alpha network for the gripper.
qrico64 Apr 19, 2025
8f4e230
fix
pranavnt Apr 20, 2025
604d389
fix: Fixed saving preferences during training CL.
qrico64 Apr 21, 2025
dd9bd55
feat: Added optimism to RLIF.
qrico64 Apr 23, 2025
a8d821d
feat: Essential modifications for robotiq gripper for our franka.
qrico64 Apr 23, 2025
6288c00
fix: Fixed RLIF.py for RLIF.
qrico64 Apr 23, 2025
45f995c
style: Removed all __pycache__ files.
qrico64 Apr 23, 2025
697d8e8
feat: Added soft CL to real-world codebase.
qrico64 Apr 23, 2025
caedff7
style: Added utility function to inspect environment.
qrico64 Apr 25, 2025
68cf3e6
fix: Fixed soft cl.
qrico64 Apr 25, 2025
b211924
feat: Added utility data inspection script.
qrico64 Apr 25, 2025
165a740
fix: Fixed optimism.
qrico64 Apr 25, 2025
6b6b4d3
style: Fixed gitignore.
qrico64 Apr 26, 2025
a37d620
feat: Added script to deploy demos to see if environment changed.
qrico64 Apr 26, 2025
6557a25
feat: Added script to inspect success/failure data.
qrico64 Apr 26, 2025
9dfc09d
fix: Fixed bug where first observation of a trajectory is incorrect.
qrico64 Apr 26, 2025
2ac0b0a
fix: Fixed optimism.
qrico64 Apr 27, 2025
fbfe935
fix: Add intervene_action and policy_action in info.
qrico64 Apr 27, 2025
f21147d
style: Improved style.
qrico64 Apr 27, 2025
8a8216f
feat: Script to align environment with past demos.
qrico64 Apr 29, 2025
e996a90
soft cl code updated
sriyash421 Apr 30, 2025
b8f13c4
action constraints added
sriyash421 Apr 30, 2025
d988d62
Merged branch 'sim' into 'main'.
qrico64 May 8, 2025
ed62665
feat: bc decay hardcoded (sim branch).
qrico64 May 12, 2025
ffbe898
feat: Log bc_coeff.
qrico64 May 12, 2025
de6d67c
fix: Modifications to rlif.py.
qrico64 May 12, 2025
4cc8dc0
Merge branch 'main-merge-test' of github.com:WEIRDLabUW/jax-hitl-hil-…
qrico64 May 12, 2025
3b7ea4b
feat: Extended inspect_data.py.
qrico64 May 12, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
__pycache__/
*.py[cod]

examples/experiments/**/debug_*/
examples/experiments/**/*.pkl
examples/experiments/**/checkpoint_*
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# Distribution / packaging
.eggs/
dist/
build/
sdist/
*.egg-info/
*.whl

# Environments
.venv/
venv/
env/
ENV/
*.env

# PyInstaller
# Usually these files are written by a python script; excluding them is not always
# appropriate.
# https://docs.pyinstaller.org/en/stable/when-things-go-wrong.html
# _MEIPASS*
# _MEI*

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
nosetests.xml
coverage.xml
*.log
*.rpt
*.db

# Translations
*.mo
*.pot

# Django stuff:
*.sqlite3
*.sqlitedb
local_settings.py
/static/
/media/

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx stuff:
docs/_build

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# mkdocs
site/

# mypy
.mypy_cache/
.dmypy.json
dmypy.sock

# VS Code
.vscode/

# PyCharm
.idea/
demo_data/

examples/experiments/**/debug_*/
examples/experiments/**/*.pkl
examples/experiments/**/checkpoint_*
**/*.out
21 changes: 21 additions & 0 deletions check_server.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
echo "Getting position..."
curl -X POST localhost:5000/getpos_euler

echo ""
echo "Activating gripper..."
curl -X POST localhost:5000/activate_gripper
echo ""
echo ""

echo "Closing gripper in 1s..."
sleep 1

curl -X POST localhost:5000/close_gripper
echo ""
echo ""

echo "Opening gripper in 3s..."
sleep 3

curl -X POST localhost:5000/reset_gripper
echo ""
94 changes: 94 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
name: serl2
channels:
- defaults
- https://repo.anaconda.com/pkgs/main
- https://repo.anaconda.com/pkgs/r
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- bzip2=1.0.8=h5eee18b_6
- ca-certificates=2025.2.25=h06a4308_0
- ld_impl_linux-64=2.40=h12ee557_0
- libffi=3.4.4=h6a678d5_1
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libstdcxx-ng=11.2.0=h1234567_1
- libuuid=1.41.5=h5eee18b_0
- ncurses=6.4=h6a678d5_0
- openssl=3.0.16=h5eee18b_0
- pip=25.0=py310h06a4308_0
- python=3.10.16=he870216_1
- readline=8.2=h5eee18b_0
- setuptools=75.8.0=py310h06a4308_0
- sqlite=3.45.3=h5eee18b_0
- tk=8.6.14=h39e8969_0
- tzdata=2025a=h04d1e81_0
- wheel=0.45.1=py310h06a4308_0
- xz=5.6.4=h5eee18b_1
- zlib=1.2.13=h5eee18b_1
- pip:
- absl-py==2.2.2
- blinker==1.9.0
- catkin-pkg==1.0.0
- certifi==2025.1.31
- cffi==1.17.1
- charset-normalizer==3.4.1
- click==8.1.8
- cloudpickle==3.1.1
- cython==3.0.12
- defusedxml==0.7.1
- distro==1.9.0
- docutils==0.21.2
- easyhid==0.0.10
- empy==4.2
- evdev==1.9.1
- flask==3.1.0
- gym==0.26.2
- gym-notices==0.0.8
- hidapi==0.14.0.post4
- idna==3.10
- itsdangerous==2.2.0
- jax==0.4.35
- jax-cuda12-pjrt==0.4.35
- jax-cuda12-plugin==0.4.35
- jaxlib==0.4.34
- jinja2==3.1.6
- lz4==4.4.4
- markupsafe==3.0.2
- ml-dtypes==0.5.1
- numpy==2.2.4
- nvidia-cublas-cu12==12.8.4.1
- nvidia-cuda-cupti-cu12==12.8.90
- nvidia-cuda-nvcc-cu12==12.8.93
- nvidia-cuda-runtime-cu12==12.8.90
- nvidia-cudnn-cu12==9.8.0.87
- nvidia-cufft-cu12==11.3.3.83
- nvidia-cusolver-cu12==11.7.3.90
- nvidia-cusparse-cu12==12.5.8.93
- nvidia-nccl-cu12==2.26.2
- nvidia-nvjitlink-cu12==12.8.93
- opencv-python==4.11.0.86
- opt-einsum==3.4.0
- pycparser==2.22
- pymodbus==2.5.3
- pynput==1.8.1
- pyopengl==3.1.9
- pyparsing==3.2.3
- pyquaternion==0.9.9
- pyrealsense2==2.55.1.6486
- pyserial==3.5
- pyspacemouse==1.1.4
- python-dateutil==2.9.0.post0
- python-xlib==0.33
- pyyaml==6.0.2
- pyzmq==26.3.0
- requests==2.32.3
- rospkg==1.6.0
- scipy==1.15.2
- six==1.17.0
- typing==3.7.4.3
- typing-extensions==4.13.1
- urllib3==2.3.0
- werkzeug==3.1.3
- zmq==0.0.0
prefix: /home/robot/miniconda3/envs/serl
157 changes: 157 additions & 0 deletions examples/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@

from absl import app, flags
import time
import numpy as np
import os
import pickle
import imageio
import cv2
import queue
from pynput import keyboard
import threading
from flax.training import checkpoints
import jax
import jax.numpy as jnp

# from experiments.mappings import CONFIG_MAPPING
from serl_launcher.agents.continuous.sac import SACAgent
from serl_launcher.agents.continuous.sac_hybrid_single import SACAgentHybridSingleArm
from serl_launcher.agents.continuous.sac_hybrid_dual import SACAgentHybridDualArm

from serl_launcher.utils.launcher import (
make_sac_pixel_agent,
make_sac_pixel_agent_hybrid_single_arm,
make_sac_pixel_agent_hybrid_dual_arm,
make_trainer_config,
make_wandb_logger,
)

checkpoint_path = "/home/qirico/Desktop/All-Weird/Human-Interventions/jax-hitl-hil-serl/examples/experiments/franka_sim/debug_rlif_2"

from experiments.config import DefaultTrainingConfig
class TrainConfig(DefaultTrainingConfig):
image_keys = ["front", "wrist"]
classifier_keys = ["front", "wrist"]
proprio_keys = ['panda/tcp_pos', 'panda/tcp_vel', 'panda/gripper_pos']
# buffer_period = 1000
# checkpoint_period = 5000
# steps_per_update = 50
pretraining_steps = 0 # How many steps to pre-train the model for using RLPD on offline data only.
reward_scale = 1 # How much to scale actual rewards (not RLIF penalties) for RLIF.
rlif_minus_one = False
checkpoint_period = 2000
cta_ratio = 2
random_steps = 0
discount = 0.98
buffer_period = 1000
batch_size = 64
encoder_type = "resnet-pretrained"
setup_mode = "single-arm-learned-gripper"

exp_name = "franka_sim"
config = TrainConfig()
# env = config.get_environment(fake_env=True,save_video=False,classifier=True)

intervene_steps = 0 # Default number of steps between pre and post intervention states
constraint_eps = 0.1 # Default constraint epsilon

# obs_key_shapes = [('front', (1, 128, 128, 3)), ('state', (1, 7)), ('wrist', (1, 128, 128, 3))]
obs_sample = {
'front': np.zeros((1, 128, 128, 3), dtype=np.uint8),
'state': np.zeros((1, 7), dtype=np.float32),
'wrist': np.zeros((1, 128, 128, 3), dtype=np.uint8),
}
action_sample = np.zeros(7, dtype=np.float32)

agent: SACAgentHybridSingleArm = make_sac_pixel_agent_hybrid_single_arm(
seed=0,
sample_obs=obs_sample,
sample_action=action_sample,
image_keys=config.image_keys,
encoder_type=config.encoder_type,
discount=config.discount,
enable_cl=False,
soft_cl = False,
intervene_steps=intervene_steps,
constraint_eps=constraint_eps,
)

ckpt = checkpoints.restore_checkpoint(
os.path.abspath(checkpoint_path),
agent.state,
step='40000',
)
agent = agent.replace(state=ckpt)


preference_buffer_base_path="experiments/franka_sim/debug_rlif_2/interventions/transitions"
preference_buffer_paths = [f"{preference_buffer_base_path}_{i}.pkl" for i in range(1000, 14000, 1000)]

preference_buffer = []

for preference_buffer_path in preference_buffer_paths:
if not os.path.exists(preference_buffer_path):
print(f"Preference buffer path {preference_buffer_path} does not exist.")
continue

# Load the preference buffer
with open(preference_buffer_path, 'rb') as f:
preference_buffer_part = pickle.load(f)
preference_buffer.extend(preference_buffer_part)

rng = jax.random.PRNGKey(0)

def get_action(obs, rng):
rng, key = jax.random.split(rng)
actions = agent.sample_actions(
observations=jax.device_put(obs),
argmax=True,
seed=key
)
return actions, rng

pre_intervention_obs = [p['observations'][0] for p in preference_buffer]
intervene_actions = [p['actions'][0] for p in preference_buffer]
policy_actions = [p['policy_actions'][0] for p in preference_buffer]
post_intervention_obs = [p['observations'][-1] for p in preference_buffer]

pre_intervention_obs = {
'front': np.array([obs['front'] for obs in pre_intervention_obs]),
'state': np.array([obs['state'] for obs in pre_intervention_obs]),
'wrist': np.array([obs['wrist'] for obs in pre_intervention_obs]),
}
pre_intervention_expert_action, rng = get_action(pre_intervention_obs, rng)


post_intervention_obs = {
'front': np.array([obs['front'] for obs in post_intervention_obs]),
'state': np.array([obs['state'] for obs in post_intervention_obs]),
'wrist': np.array([obs['wrist'] for obs in post_intervention_obs]),
}
post_intervention_expert_action, rng = get_action(post_intervention_obs, rng)

policy_actions = np.array(policy_actions)
intervene_actions = np.array(intervene_actions)

key, rng = jax.random.split(rng)
q_pre_expert = agent.forward_critic(pre_intervention_obs, pre_intervention_expert_action[:, :6], key)
key, rng = jax.random.split(rng)
q_post_expert = agent.forward_critic(post_intervention_obs, post_intervention_expert_action[:, :6], key)
key, rng = jax.random.split(rng)
q_pre_policy = agent.forward_critic(pre_intervention_obs, policy_actions[:, :6], key)
key, rng = jax.random.split(rng)
q_pre_intervene = agent.forward_critic(pre_intervention_obs, intervene_actions[:, :6], key)
# q_post_expert = agent.q_network.apply(agent.state.params, post_intervention_obs, post_intervention_expert_action)

# q_pre_policy = agent.q_network.apply(agent.state.params, pre_intervention_obs, policy_actions)
# q_post_policy = agent.q_network.apply(agent.state.params, pre_intervention_obs, policy_actions)
constraint1_acc = ((q_pre_expert - q_post_expert) < 0).mean()
qvalue_based_learning_intervene = ((q_pre_policy - q_pre_intervene) < 0).mean()
qvalue_based_learning_expert = ((q_pre_policy - q_pre_expert) < 0).mean()
constraint2_acc = ((q_pre_intervene - q_post_expert) < 0).mean()

print(f"Constraint 1 accuracy: {constraint1_acc}")
print(f"Q-value based learning intervene accuracy: {qvalue_based_learning_intervene}")
print(f"Q-value based learning expert accuracy: {qvalue_based_learning_expert}")
print(f"Constraint 2 accuracy: {constraint2_acc}")
breakpoint()
Loading