Skip to content

[Example] Add MemAgent long-context RL example (mem_agent)#291

Open
CalvinXKY wants to merge 1 commit into
vllm-project:mainfrom
CalvinXKY:examples/mem-agent
Open

[Example] Add MemAgent long-context RL example (mem_agent)#291
CalvinXKY wants to merge 1 commit into
vllm-project:mainfrom
CalvinXKY:examples/mem-agent

Conversation

@CalvinXKY

Copy link
Copy Markdown
Collaborator

Summary

Add `examples/mem_agent`, a vime port of MemAgent for long-context QA with RL:

  • Multi-turn rollout: chunk-wise memory update + final `\boxed{}` answer (HotpotQA reward)
  • Training: GRPO on Qwen3-4B via vLLM colocate (`run-qwen3-4b-train.sh`)
  • Evaluation: RULER-HQA offline eval via vLLM serve (`run-eval.sh`, `eval_ruler_hqa.py`)
  • Data prep: HuggingFace dataset `BytedTsinghua-SIA/hotpotqa` -> training JSONL + eval JSON download
    Also updates `examples/README.md` with a directory entry.

Files

File Description
`rollout.py` / `rollout_client.py` Multi-turn MemAgent rollout + vLLM router client
`custom_convert.py` Unroll trajectories for GRPO
`prepare_data.py` / `prepare-eval-data.sh` Training and eval data preparation
`eval_ruler_hqa.py` / `run-eval.sh` RULER-HQA evaluation
`run-qwen3-4b-train.sh` / `_common.sh` GRPO training launch scripts
`convert-to-hf.sh` Megatron checkpoint -> HuggingFace
See `examples/mem_agent/README.md` for quick start (data download, training, eval).

Test plan

  • Pre-commit passed (ruff, isort, black)
  • GRPO training smoke on H800 (Qwen3-4B, colocate vLLM)
  • RULER-HQA eval at n_docs = 50 / 200 / 800 (RL checkpoint vs base Qwen3-4B)
  • CI (await upstream checks)

Co-authored-by: Cursor
Signed-off-by: kaiyuan <kyxiezju@163.com>
@read-the-docs-community

Copy link
Copy Markdown

Documentation build overview

📚 vime | 🛠️ Build #33270407 | 📁 Comparing 49749f7 against latest (e62d44f)

  🔍 Preview build  

1 file changed
+ _examples_synced/mem_agent/README.html

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the MemAgent long-context RL workflow, adding scripts and utilities for data preparation, multi-turn rollout generation, GRPO training, and RULER-HQA evaluation. The code review highlighted several critical improvements: skipping failed or aborted samples in custom_convert.py to prevent crashes, ensuring log probabilities are always appended to avoid list misalignment, and checking the status of the background vLLM process in run-eval.sh to prevent infinite loops. Additionally, feedback was provided to optimize performance by moving numpy imports to the module level in prepare_data.py, preserving case-sensitivity when extracting boxed answers in rollout.py and eval_ruler_hqa.py, and double-quoting bash array expansions in _common.sh to prevent word splitting.

Comment on lines +46 to +47
for i, sample in enumerate(samples):
meta = sample.train_metadata

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Failed or aborted samples should be skipped during conversion. If a rollout fails (e.g., due to network timeouts or OOM), its tokens and other attributes may be None or empty, which will cause training crashes or corrupt the training data if appended.

    for i, sample in enumerate(samples):
        if sample.status in (sample.status.FAILED, sample.status.ABORTED):
            continue
        meta = sample.train_metadata

Comment on lines +59 to +61
if sample.rollout_log_probs is not None:
rollout_log_probs_list.append(sample.rollout_log_probs)
has_rollout_log_probs = True

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

If some samples do not have log probs, skipping the append operation will cause rollout_log_probs_list to become shorter than tokens_list, leading to silent misalignment of log probs with tokens or shape mismatch crashes during training. Always append to the list, and check for presence to set has_rollout_log_probs.

Suggested change
if sample.rollout_log_probs is not None:
rollout_log_probs_list.append(sample.rollout_log_probs)
has_rollout_log_probs = True
rollout_log_probs_list.append(sample.rollout_log_probs)
if sample.rollout_log_probs is not None:
has_rollout_log_probs = True

Comment on lines +79 to +81
if turn.get("rollout_log_probs"):
rollout_log_probs_list.append(turn["rollout_log_probs"])
has_rollout_log_probs = True

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Similarly to the sample-level log probs, always append the turn-level log probs to prevent list misalignment when some turns lack log probs.

Suggested change
if turn.get("rollout_log_probs"):
rollout_log_probs_list.append(turn["rollout_log_probs"])
has_rollout_log_probs = True
rollout_log_probs_list.append(turn.get("rollout_log_probs"))
if turn.get("rollout_log_probs"):
has_rollout_log_probs = True

Comment on lines +75 to +87
while true; do
resp=$(curl -sf --max-time 10 "${url}" 2>/dev/null || true)
if echo "${resp}" | grep -Fq "${MODEL_NAME}" 2>/dev/null; then
log "vLLM ready."
break
fi
attempts=$((attempts + 1))
if (( attempts % 6 == 0 )); then
found=$(echo "${resp}" | grep -o '"id":"[^"]*"' 2>/dev/null | head -3 || echo "(no response)")
log "Still waiting... models: ${found}"
fi
sleep 5
done

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

wait_for_server will loop infinitely if the background vllm serve process dies during startup (e.g., due to port conflicts or OOM). Checking if VLLM_PID is still running inside the loop prevents infinite hangs in automated pipelines.

Suggested change
while true; do
resp=$(curl -sf --max-time 10 "${url}" 2>/dev/null || true)
if echo "${resp}" | grep -Fq "${MODEL_NAME}" 2>/dev/null; then
log "vLLM ready."
break
fi
attempts=$((attempts + 1))
if (( attempts % 6 == 0 )); then
found=$(echo "${resp}" | grep -o '"id":"[^"]*"' 2>/dev/null | head -3 || echo "(no response)")
log "Still waiting... models: ${found}"
fi
sleep 5
done
while true; do
if ! kill -0 "${VLLM_PID}" 2>/dev/null; then
log "ERROR: vLLM process (PID ${VLLM_PID}) died."
exit 1
fi
resp=$(curl -sf --max-time 10 "${url}" 2>/dev/null || true)
if echo "${resp}" | grep -Fq "${MODEL_NAME}" 2>/dev/null; then
log "vLLM ready."
break
fi
attempts=$((attempts + 1))
if (( attempts % 6 == 0 )); then
found=$(echo "${resp}" | grep -o '"id":"[^"]*"' 2>/dev/null | head -3 || echo "(no response)")
log "Still waiting... models: ${found}"
fi
sleep 5
done

Comment on lines +42 to +57
def _to_json_safe(obj):
try:
import numpy as np
except ImportError:
np = None

if np is not None:
if isinstance(obj, np.ndarray):
return obj.tolist()
if isinstance(obj, (np.integer, np.floating, np.bool_)):
return obj.item()
if isinstance(obj, dict):
return {k: _to_json_safe(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_to_json_safe(v) for v in obj]
return obj

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Importing numpy inside _to_json_safe is highly inefficient because this function is called for every single row in the dataset. Moving the import and check to the module level avoids this overhead.

Suggested change
def _to_json_safe(obj):
try:
import numpy as np
except ImportError:
np = None
if np is not None:
if isinstance(obj, np.ndarray):
return obj.tolist()
if isinstance(obj, (np.integer, np.floating, np.bool_)):
return obj.item()
if isinstance(obj, dict):
return {k: _to_json_safe(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_to_json_safe(v) for v in obj]
return obj
try:
import numpy as np
except ImportError:
np = None
def _to_json_safe(obj):
if np is not None:
if isinstance(obj, np.ndarray):
return obj.tolist()
if isinstance(obj, (np.integer, np.floating, np.bool_)):
return obj.item()
if isinstance(obj, dict):
return {k: _to_json_safe(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_to_json_safe(v) for v in obj]
return obj

Comment on lines +247 to +259
solution_str = final_output[-300:].lower()
pred = _extract_boxed(solution_str) or ""

score = 0.0
for gt in ground_truth:
gt_lower = gt.lower()
try:
boxed = _last_boxed_only_string(solution_str)
if boxed is not None:
answer = _remove_boxed(boxed)
if _is_equiv(answer, gt_lower):
score = 1.0
break

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Lowercasing the entire solution_str before extracting the boxed answer can corrupt case-sensitive LaTeX commands or case-sensitive answers. It is safer to extract the boxed answer from the original case string, and only lowercase it during the final comparison.

Suggested change
solution_str = final_output[-300:].lower()
pred = _extract_boxed(solution_str) or ""
score = 0.0
for gt in ground_truth:
gt_lower = gt.lower()
try:
boxed = _last_boxed_only_string(solution_str)
if boxed is not None:
answer = _remove_boxed(boxed)
if _is_equiv(answer, gt_lower):
score = 1.0
break
solution_str = final_output[-300:]
pred = _extract_boxed(solution_str) or ""
score = 0.0
for gt in ground_truth:
gt_lower = gt.lower()
try:
boxed = _last_boxed_only_string(solution_str)
if boxed is not None:
answer = _remove_boxed(boxed)
if _is_equiv(answer.lower(), gt_lower):
score = 1.0
break

return None

gold = item["answers"][0] if item.get("answers") else ""
pred = _extract_boxed(response[-300:].lower()) or ""

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Lowercasing the response before calling _extract_boxed is redundant and potentially harmful for case-sensitive patterns, especially since the metric functions (_normalize_answer) handle lowercasing themselves.

Suggested change
pred = _extract_boxed(response[-300:].lower()) or ""
pred = _extract_boxed(response[-300:]) or ""

if [[ "${RUN_TRAIN_DIRECT:-0}" == "1" && -n "${RUN_TRAIN_DIRECT_PY:-}" && -f "${RUN_TRAIN_DIRECT_PY}" ]]; then
export VIME_ROOT NCCL_NVLS_ENABLE MEM_CHUNK_TOKENS MEM_MAX_MEMORY MEM_MAX_FINAL MEM_MAX_CHUNKS
python3 "${RUN_TRAIN_DIRECT_PY}" \
${MODEL_ARGS[@]} \

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In bash, expanding arrays without double quotes can cause word splitting if any element contains spaces. It is a standard best practice to always double-quote array expansions: "${MODEL_ARGS[@]}".

Suggested change
${MODEL_ARGS[@]} \
"${MODEL_ARGS[@]}" \

--runtime-env-json="${runtime_env}" \
-- python3 train.py \
--train-backend megatron \
${MODEL_ARGS[@]} \

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Double-quote the array expansion to prevent word splitting.

Suggested change
${MODEL_ARGS[@]} \
"${MODEL_ARGS[@]}" \

@CalvinXKY

Copy link
Copy Markdown
Collaborator Author

Test result:

image image

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant