[Example] Add MemAgent long-context RL example (mem_agent) #291
CalvinXKY wants to merge 1 commit into
Co-authored-by: Cursor
Signed-off-by: kaiyuan <kyxiezju@163.com>
Code Review
This pull request introduces the MemAgent long-context RL workflow, adding scripts and utilities for data preparation, multi-turn rollout generation, GRPO training, and RULER-HQA evaluation. The code review highlighted several critical improvements: skipping failed or aborted samples in custom_convert.py to prevent crashes, ensuring log probabilities are always appended to avoid list misalignment, and checking the status of the background vLLM process in run-eval.sh to prevent infinite loops. Additionally, feedback was provided to optimize performance by moving numpy imports to the module level in prepare_data.py, preserving case-sensitivity when extracting boxed answers in rollout.py and eval_ruler_hqa.py, and double-quoting bash array expansions in _common.sh to prevent word splitting.

    for i, sample in enumerate(samples):
        meta = sample.train_metadata
Failed or aborted samples should be skipped during conversion. If a rollout fails (e.g., due to network timeouts or OOM), its tokens and other attributes may be None or empty, which will cause training crashes or corrupt the training data if appended.
      for i, sample in enumerate(samples):
    +     if sample.status in (sample.status.FAILED, sample.status.ABORTED):
    +         continue
          meta = sample.train_metadata

    if sample.rollout_log_probs is not None:
        rollout_log_probs_list.append(sample.rollout_log_probs)
        has_rollout_log_probs = True
If some samples do not have log probs, skipping the append operation will cause rollout_log_probs_list to become shorter than tokens_list, leading to silent misalignment of log probs with tokens or shape mismatch crashes during training. Always append to the list, and check for presence to set has_rollout_log_probs.
    - if sample.rollout_log_probs is not None:
    -     rollout_log_probs_list.append(sample.rollout_log_probs)
    -     has_rollout_log_probs = True
    + rollout_log_probs_list.append(sample.rollout_log_probs)
    + if sample.rollout_log_probs is not None:
    +     has_rollout_log_probs = True

    if turn.get("rollout_log_probs"):
        rollout_log_probs_list.append(turn["rollout_log_probs"])
        has_rollout_log_probs = True
As with the sample-level log probs, always append the turn-level log probs so the list stays aligned when some turns lack them.
    - if turn.get("rollout_log_probs"):
    -     rollout_log_probs_list.append(turn["rollout_log_probs"])
    -     has_rollout_log_probs = True
    + rollout_log_probs_list.append(turn.get("rollout_log_probs"))
    + if turn.get("rollout_log_probs"):
    +     has_rollout_log_probs = True

    while true; do
      resp=$(curl -sf --max-time 10 "${url}" 2>/dev/null || true)
      if echo "${resp}" | grep -Fq "${MODEL_NAME}" 2>/dev/null; then
        log "vLLM ready."
        break
      fi
      attempts=$((attempts + 1))
      if (( attempts % 6 == 0 )); then
        found=$(echo "${resp}" | grep -o '"id":"[^"]*"' 2>/dev/null | head -3 || echo "(no response)")
        log "Still waiting... models: ${found}"
      fi
      sleep 5
    done
wait_for_server will loop infinitely if the background vllm serve process dies during startup (e.g., due to port conflicts or OOM). Checking if VLLM_PID is still running inside the loop prevents infinite hangs in automated pipelines.
      while true; do
    +   if ! kill -0 "${VLLM_PID}" 2>/dev/null; then
    +     log "ERROR: vLLM process (PID ${VLLM_PID}) died."
    +     exit 1
    +   fi
        resp=$(curl -sf --max-time 10 "${url}" 2>/dev/null || true)
        if echo "${resp}" | grep -Fq "${MODEL_NAME}" 2>/dev/null; then
          log "vLLM ready."
          break
        fi
        attempts=$((attempts + 1))
        if (( attempts % 6 == 0 )); then
          found=$(echo "${resp}" | grep -o '"id":"[^"]*"' 2>/dev/null | head -3 || echo "(no response)")
          log "Still waiting... models: ${found}"
        fi
        sleep 5
      done

    def _to_json_safe(obj):
        try:
            import numpy as np
        except ImportError:
            np = None

        if np is not None:
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            if isinstance(obj, (np.integer, np.floating, np.bool_)):
                return obj.item()
        if isinstance(obj, dict):
            return {k: _to_json_safe(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [_to_json_safe(v) for v in obj]
        return obj
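Importing numpy inside _to_json_safe adds avoidable overhead: the function is called for every row in the dataset and recurses into nested values, and each call re-executes the import statement, which still costs a module-cache lookup once numpy is loaded. Moving the import and the availability check to the module level avoids this.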
    + try:
    +     import numpy as np
    + except ImportError:
    +     np = None
    +
      def _to_json_safe(obj):
    -     try:
    -         import numpy as np
    -     except ImportError:
    -         np = None
          if np is not None:
              if isinstance(obj, np.ndarray):
                  return obj.tolist()
              if isinstance(obj, (np.integer, np.floating, np.bool_)):
                  return obj.item()
          if isinstance(obj, dict):
              return {k: _to_json_safe(v) for k, v in obj.items()}
          if isinstance(obj, list):
              return [_to_json_safe(v) for v in obj]
          return obj

    solution_str = final_output[-300:].lower()
    pred = _extract_boxed(solution_str) or ""

    score = 0.0
    for gt in ground_truth:
        gt_lower = gt.lower()
        try:
            boxed = _last_boxed_only_string(solution_str)
            if boxed is not None:
                answer = _remove_boxed(boxed)
                if _is_equiv(answer, gt_lower):
                    score = 1.0
                    break
Lowercasing the entire solution_str before extracting the boxed answer can corrupt case-sensitive LaTeX commands or case-sensitive answers. It is safer to extract the boxed answer from the original case string, and only lowercase it during the final comparison.
    - solution_str = final_output[-300:].lower()
    + solution_str = final_output[-300:]
      pred = _extract_boxed(solution_str) or ""
      score = 0.0
      for gt in ground_truth:
          gt_lower = gt.lower()
          try:
              boxed = _last_boxed_only_string(solution_str)
              if boxed is not None:
                  answer = _remove_boxed(boxed)
    -             if _is_equiv(answer, gt_lower):
    +             if _is_equiv(answer.lower(), gt_lower):
                      score = 1.0
                      break
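A small Python sketch of the "extract first, lowercase only when comparing" order. `extract_last_boxed` and `is_match` are hypothetical helpers written for this illustration; they are not the PR's `_extract_boxed`, `_last_boxed_only_string`, or `_is_equiv`, whose implementations are not shown here.

```python
from typing import Optional


def extract_last_boxed(text: str) -> Optional[str]:
    """Return the contents of the last \\boxed{...} in text, or None if absent or unbalanced."""
    marker = "\\boxed{"
    start = text.rfind(marker)
    if start == -1:
        return None
    begin = start + len(marker)
    depth = 1
    for i in range(begin, len(text)):
        if text[i] == "{":
            depth += 1
        elif text[i] == "}":
            depth -= 1
            if depth == 0:
                return text[begin:i]
    return None


def is_match(output: str, ground_truth: str) -> bool:
    # Extract from the original-case tail, then normalize case only for comparison.
    answer = extract_last_boxed(output[-300:])
    return answer is not None and answer.strip().lower() == ground_truth.strip().lower()


output = "The memory says the capital is \\boxed{Paris}"
original_case = extract_last_boxed(output)
lowered_first = extract_last_boxed(output.lower())
```

Extracting from the untouched string keeps the prediction as the model wrote it (`Paris`), which matters if `pred` is logged or passed to a case-sensitive metric, while the comparison itself remains case-insensitive.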

    return None

    gold = item["answers"][0] if item.get("answers") else ""
    pred = _extract_boxed(response[-300:].lower()) or ""
Lowercasing the response before calling _extract_boxed is redundant and potentially harmful for case-sensitive patterns, especially since the metric functions (_normalize_answer) handle lowercasing themselves.
    - pred = _extract_boxed(response[-300:].lower()) or ""
    + pred = _extract_boxed(response[-300:]) or ""

    if [[ "${RUN_TRAIN_DIRECT:-0}" == "1" && -n "${RUN_TRAIN_DIRECT_PY:-}" && -f "${RUN_TRAIN_DIRECT_PY}" ]]; then
      export VIME_ROOT NCCL_NVLS_ENABLE MEM_CHUNK_TOKENS MEM_MAX_MEMORY MEM_MAX_FINAL MEM_MAX_CHUNKS
      python3 "${RUN_TRAIN_DIRECT_PY}" \
        ${MODEL_ARGS[@]} \

    --runtime-env-json="${runtime_env}" \
    -- python3 train.py \
      --train-backend megatron \
      ${MODEL_ARGS[@]} \


Summary
Add `examples/mem_agent`, a vime port of MemAgent for long-context QA with RL:
Also updates `examples/README.md` with a directory entry.
Files
Test plan