
[NPU] colocate training with TMS, bridge sync, and IPC weight loading #285

Open
CalvinXKY wants to merge 3 commits into vllm-project:ascend from CalvinXKY:feat/npu-colocate

CalvinXKY (Collaborator) commented Jun 23, 2026

Summary

  • NPU colocate end-to-end: TMS torch mode with NPU/GPU env split, explicit PYTORCH_NPU_ALLOC_CONF=expandable_segments:False, CANN PYTHONPATH, safe empty_cache, and IPC weight sync via UpdateWeightFromTensor + bridge mode.
  • Docker patches: vllm-ascend colocate init, vllm flash_attn NPU guard, vllm-ascend NPUIPCWeightTransferEngine API for vLLM 0.22.
  • Trim vllm-ascend.patch worker changes to colocate only; remap HF packed modules (QKV, gate_up) in vime IPC path.

Motivation

New NPU images set PYTORCH_NPU_ALLOC_CONF=expandable_segments:True in the Dockerfile, which breaks TMS/CaMem colocate. Validated on 8×910B (Qwen3-4B, TP=2, bridge mode, smoke + long run). The code must explicitly set expandable_segments:False; popping the variable or setting it to "" is insufficient.
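
The difference between dropping the variable and overriding it can be sketched as below. This is only an illustration, not the PR's code: `build_colocate_env` and `image_env` are hypothetical names, and the dictionary stands in for whatever environment a Ray worker or vLLM subprocess would inherit from the image.

```python
ALLOC_CONF_KEY = "PYTORCH_NPU_ALLOC_CONF"
COLOCATE_ALLOC_CONF = "expandable_segments:False"


def build_colocate_env(base_env: dict) -> dict:
    """Return a copy of base_env with the allocator config set explicitly.

    Hypothetical helper: it assigns the value instead of removing the key,
    so an image-level default of expandable_segments:True cannot resurface
    in a child process that inherits this environment.
    """
    env = dict(base_env)
    env[ALLOC_CONF_KEY] = COLOCATE_ALLOC_CONF
    return env


# Simulated image default, as set by the Dockerfile described above.
image_env = {ALLOC_CONF_KEY: "expandable_segments:True", "PATH": "/usr/bin"}
worker_env = build_colocate_env(image_env)
print(worker_env[ALLOC_CONF_KEY])
```

The helper copies its input, so the caller's environment mapping is left untouched and the override only applies to the process that receives `worker_env`.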

Key changes

| Area | What |
| --- | --- |
| actor_group.py | NPU: TMS torch only (no LD_PRELOAD/TMS_INIT_ENABLE); colocate alloc conf override |
| rollout.py, vllm_engine.py | Colocate expandable_segments:False for Ray workers and vLLM subprocess |
| memory_utils.py, actor.py | Skip/patch empty_cache under TMS custom allocator |
| update_weight_from_tensor.py | NPU colocate IPC direct param writes |
| vllm.patch | Skip flash_attn import on NPU |
| vllm-ascend.patch | Colocate init + npu_ipc_engine model arg |

TMS .so note

Dockerfile.npu builds torch_memory_saver from sgl-kernel-npu@2026.6.0. The stock wheel .so may still be incompatible with NPU TMS torch mode until upstream fixes TMS_INIT_ENABLE handling. Validated runs used known-good .so MD5 (03adebff… / 83181484…). Track upstream or pin binaries if needed.
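
One way to check a built .so against the validated binaries is to compare its MD5 digest with the known-good values. The sketch below is an assumption-laden illustration: the note above only gives truncated digests, so it compares against the visible 8-character prefixes, and the helper names and any file path passed in are hypothetical.

```python
import hashlib
from pathlib import Path

# Only the prefixes quoted in the note are known here; the full digests are elided.
KNOWN_GOOD_MD5_PREFIXES = ("03adebff", "83181484")


def md5_of_file(path, chunk_size: int = 1 << 20) -> str:
    """Return the hex MD5 digest of a file, read in chunks."""
    digest = hashlib.md5()
    with Path(path).open("rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def looks_known_good(path) -> bool:
    """True if the file's MD5 starts with one of the quoted prefixes."""
    return md5_of_file(path).startswith(KNOWN_GOOD_MD5_PREFIXES)
```

A prefix match is weaker than a full-digest comparison, so pinning the complete hashes would be preferable once they are recorded somewhere.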

Test plan

  • pre-commit run on changed files
  • Smoke colocate on vime_pr285_test (8×910B, num-rollout=1)
  • Rebuild image from docker/Dockerfile.npu and verify colocate without manual .so swap
  • 100+ step diff vs pr266 baseline

Move packed-module remap into vime colocate worker extension and keep vllm-ascend patch limited to colocate init (free_memory check and profiling assert). NPU colocate uses is_checkpoint_format=False, skips layerwise_reload, and loads weights via direct param writes with QKV/gate_up remap.

Signed-off-by: kaiyuan <kyxiezju@163.com>

read-the-docs-community Bot commented Jun 23, 2026

gemini-code-assist Bot left a comment

Code Review

This pull request introduces NPU support for colocated weight updates, including custom weight loading, NPU synchronization, and patching for the NPU worker. The review feedback highlights several critical issues: first, torch_npu.get_device_properties lacks a uuid attribute on Ascend NPU, which will cause an AttributeError and crash the process; second, the colocate worker bypasses general NPU patches (such as rotary embeddings and MoE weight loaders) which must be explicitly applied; and third, _load_colocate_weights_direct lacks a torch.no_grad() context and needs a shape length guard to prevent potential errors during dimension-sharding.

Comment on lines 61 to 68
 def _current_gpu_uuid() -> str:
-    device_index = torch.cuda.current_device()
-    props = torch.cuda.get_device_properties(device_index)
+    if is_npu():
+        device_index = torch.npu.current_device()
+        props = torch.npu.get_device_properties(device_index)
+    else:
+        device_index = torch.cuda.current_device()
+        props = torch.cuda.get_device_properties(device_index)
     return str(props.uuid)

critical

The properties object returned by torch_npu.get_device_properties does not have a uuid attribute on Ascend NPU. Accessing props.uuid directly raises an AttributeError and crashes the weight update process. We should fall back to a unique identifier such as f"npu:{device_index}" if uuid is not present.

Suggested change
-def _current_gpu_uuid() -> str:
-    if is_npu():
-        device_index = torch.npu.current_device()
-        props = torch.npu.get_device_properties(device_index)
-    else:
-        device_index = torch.cuda.current_device()
-        props = torch.cuda.get_device_properties(device_index)
-    return str(props.uuid)
+def _current_gpu_uuid() -> str:
+    if is_npu():
+        device_index = torch.npu.current_device()
+        props = torch.npu.get_device_properties(device_index)
+        return getattr(props, "uuid", f"npu:{device_index}")
+    else:
+        device_index = torch.cuda.current_device()
+        props = torch.cuda.get_device_properties(device_index)
+        return str(props.uuid)
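
The fallback idea can be exercised without any accelerator by substituting a stand-in properties object. In this sketch `device_identifier` is a hypothetical helper and `SimpleNamespace` merely imitates a properties object with or without a `uuid` field; it is not the torch or torch_npu API.

```python
from types import SimpleNamespace


def device_identifier(props, device_index: int, backend: str = "npu") -> str:
    """Return str(props.uuid) when present, else a '<backend>:<index>' fallback."""
    uuid = getattr(props, "uuid", None)
    if uuid is None:
        return f"{backend}:{device_index}"
    return str(uuid)


# Stand-ins for device property objects (assumed shapes, for illustration only).
props_without_uuid = SimpleNamespace(name="fake-npu")
props_with_uuid = SimpleNamespace(name="fake-gpu", uuid="GPU-1234")

print(device_identifier(props_without_uuid, 3))
print(device_identifier(props_with_uuid, 0, backend="cuda"))
```

Unlike the suggestion above, this variant always returns a `str`. Note that an index-based fallback is only as unique as the device numbering visible to the process, so both sides of the weight sync would need to agree on that numbering.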

Comment on lines +635 to +640
    if is_npu():
        device_index = torch.npu.current_device()
        physical_gpu_id = str(torch.npu.get_device_properties(device_index).uuid)
    else:
        device_index = torch.cuda.current_device()
        physical_gpu_id = str(torch.cuda.get_device_properties(device_index).uuid)

critical

The properties object returned by torch_npu.get_device_properties does not have a uuid attribute on Ascend NPU. Accessing props.uuid directly raises an AttributeError and crashes the weight update process. We should fall back to a unique identifier such as f"npu:{device_index}" if uuid is not present.

Suggested change
-    if is_npu():
-        device_index = torch.npu.current_device()
-        physical_gpu_id = str(torch.npu.get_device_properties(device_index).uuid)
-    else:
-        device_index = torch.cuda.current_device()
-        physical_gpu_id = str(torch.cuda.get_device_properties(device_index).uuid)
+    if is_npu():
+        device_index = torch.npu.current_device()
+        props = torch.npu.get_device_properties(device_index)
+        physical_gpu_id = getattr(props, "uuid", f"npu:{device_index}")
+    else:
+        device_index = torch.cuda.current_device()
+        physical_gpu_id = str(torch.cuda.get_device_properties(device_index).uuid)

Comment on lines +39 to +58
def _patch_npu_colocate_worker() -> None:
    """Skip layerwise_reload in NPUWorker weight-update hooks (colocate OOM fix)."""
    try:
        from vllm_ascend.worker.worker import NPUWorker
    except ImportError:
        return

    if getattr(NPUWorker, "_vime_colocate_patched", False):
        return

    def _patched_start_weight_update(self, is_checkpoint_format: bool = True) -> None:
        self._weight_update_active = True
        self._is_checkpoint_format = is_checkpoint_format

    def _patched_finish_weight_update(self) -> None:
        self._weight_update_active = False

    NPUWorker.start_weight_update = _patched_start_weight_update  # type: ignore[method-assign]
    NPUWorker.finish_weight_update = _patched_finish_weight_update  # type: ignore[method-assign]
    NPUWorker._vime_colocate_patched = True  # type: ignore[attr-defined]

high

When using NPU colocate mode, the vLLMColocateWorkerExtension is loaded instead of vLLMWorkerExtension. This bypasses the general NPU patches applied in _VLLMHijack._patch_npu_worker() and _VLLMHijack._patch_npu_rotary_emb(). As a result, MoE models will fail due to the missing weight_loader attribute on EP, and rotary embeddings will fail to initialize correctly if flash_attn is not available.

To fix this, we should apply the rotary embedding patch and ensure patch_moe_weight_loader is called during _patched_start_weight_update.

def _patch_npu_colocate_worker() -> None:
    """Skip layerwise_reload in NPUWorker weight-update hooks (colocate OOM fix)."""
    try:
        from vllm_ascend.worker.worker import NPUWorker
    except ImportError:
        return

    if getattr(NPUWorker, "_vime_colocate_patched", False):
        return

    _VLLMHijack._patch_npu_rotary_emb()

    def _patched_start_weight_update(self, is_checkpoint_format: bool = True) -> None:
        self._weight_update_active = True
        self._is_checkpoint_format = is_checkpoint_format
        _VLLMHijack.patch_moe_weight_loader(self.model_runner.model)

    def _patched_finish_weight_update(self) -> None:
        self._weight_update_active = False

    NPUWorker.start_weight_update = _patched_start_weight_update  # type: ignore[method-assign]
    NPUWorker.finish_weight_update = _patched_finish_weight_update  # type: ignore[method-assign]
    NPUWorker._vime_colocate_patched = True  # type: ignore[attr-defined]
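
The sentinel-guarded monkey-patch used here can be illustrated in isolation. The sketch below uses a hypothetical `Worker` class and `apply_patch_once` helper rather than `NPUWorker`; it only demonstrates that a class-level flag makes repeated patch calls harmless.

```python
class Worker:
    """Stand-in for a worker class whose hook is replaced at import time."""

    def start_weight_update(self) -> str:
        return "original"


def apply_patch_once(cls) -> bool:
    """Replace cls.start_weight_update once; return True only when the patch is applied."""
    if getattr(cls, "_patched", False):
        return False  # Already patched: do nothing on repeated calls.

    def _patched_start_weight_update(self) -> str:
        return "patched"

    cls.start_weight_update = _patched_start_weight_update
    cls._patched = True
    return True


first = apply_patch_once(Worker)
second = apply_patch_once(Worker)
print(first, second, Worker().start_weight_update())
```

Because the flag lives on the class, any extra setup placed before the flag is set (such as the rotary embedding patch suggested above) also runs only once per process.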

Comment on lines +71 to +116
def _load_colocate_weights_direct(
    model: torch.nn.Module,
    weights: list[tuple[str, torch.Tensor]],
    *,
    tp_rank: int,
) -> None:
    """Load HF-named weights into vLLM packed params (qkv_proj, gate_up_proj, etc.)."""
    from vllm.model_executor.utils import get_packed_modules_mapping

    packed_map = get_packed_modules_mapping(model)
    sub_to_packed = {}
    for packed_name, sub_names in packed_map.items():
        for idx, sub_name in enumerate(sub_names):
            sub_to_packed[sub_name] = (packed_name, idx)
    qkv_shard_ids = {"q_proj": "q", "k_proj": "k", "v_proj": "v"}

    def _remap_and_get_shard_id(hf_name: str) -> tuple[str, object | None]:
        parts = hf_name.split(".")
        shard_id = None
        for i in range(len(parts)):
            if parts[i] in sub_to_packed:
                packed_name, idx = sub_to_packed[parts[i]]
                shard_id = qkv_shard_ids.get(parts[i], idx)
                parts[i] = packed_name
                break
        return ".".join(parts), shard_id

    for name, weight in weights:
        remapped, shard_id = _remap_and_get_shard_id(name)
        try:
            param = model.get_parameter(remapped)
        except AttributeError:
            continue
        weight_loader = getattr(param, "weight_loader", None)
        if weight_loader is not None and shard_id is not None:
            weight_loader(param, weight, loaded_shard_id=shard_id)
        elif weight_loader is not None:
            weight_loader(param, weight)
        elif param.shape == weight.shape:
            param.data.copy_(weight)
        else:
            for dim in range(len(param.shape)):
                if param.shape[dim] != weight.shape[dim]:
                    shard_size = param.shape[dim]
                    param.data.copy_(weight.narrow(dim, tp_rank * shard_size, shard_size))
                    break

medium

There are two issues in _load_colocate_weights_direct:

  1. The weight loading and copying operations are not wrapped in a torch.no_grad() context. Since this runs on the worker side via RPC, there is no active torch.no_grad() decorator, which can lead to unnecessary autograd graph construction and memory overhead.
  2. The dimension-sharding fallback assumes len(param.shape) == len(weight.shape). If they differ, it could raise an IndexError or perform incorrect narrowing.

We should wrap the loading loop in with torch.no_grad(): and add a guard for shape length equality.
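
The rank guard from point 2 can be checked on plain shape tuples, without torch. In this sketch `shard_slice` is a hypothetical helper that returns the `(dim, start, length)` arguments one would pass to `narrow`, or `None` when slicing would be unsafe; the out-of-range check is an extra assumption that goes beyond the reviewer's suggestion.

```python
from typing import Optional, Sequence, Tuple


def shard_slice(
    param_shape: Sequence[int],
    weight_shape: Sequence[int],
    tp_rank: int,
) -> Optional[Tuple[int, int, int]]:
    """Return (dim, start, length) for this rank's shard, or None if unsafe."""
    if len(param_shape) != len(weight_shape):
        return None  # Different ranks: indexing weight_shape[dim] could fail.
    for dim, (p, w) in enumerate(zip(param_shape, weight_shape)):
        if p != w:
            start = tp_rank * p
            if start + p > w:
                return None  # Extra guard (assumption): the shard would overrun.
            return dim, start, p
    return None  # Shapes are identical, so no sharding is needed.


print(shard_slice((1024, 4096), (8192, 4096), tp_rank=3))
print(shard_slice((1024,), (8192, 4096), tp_rank=0))
```

Like the loop in the PR, this only slices along the first mismatching dimension, so it does not cover parameters that are sharded along more than one axis.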

def _load_colocate_weights_direct(
    model: torch.nn.Module,
    weights: list[tuple[str, torch.Tensor]],
    *,
    tp_rank: int,
) -> None:
    """Load HF-named weights into vLLM packed params (qkv_proj, gate_up_proj, etc.)."""
    from vllm.model_executor.utils import get_packed_modules_mapping

    packed_map = get_packed_modules_mapping(model)
    sub_to_packed = {}
    for packed_name, sub_names in packed_map.items():
        for idx, sub_name in enumerate(sub_names):
            sub_to_packed[sub_name] = (packed_name, idx)
    qkv_shard_ids = {"q_proj": "q", "k_proj": "k", "v_proj": "v"}

    def _remap_and_get_shard_id(hf_name: str) -> tuple[str, object | None]:
        parts = hf_name.split(".")
        shard_id = None
        for i in range(len(parts)):
            if parts[i] in sub_to_packed:
                packed_name, idx = sub_to_packed[parts[i]]
                shard_id = qkv_shard_ids.get(parts[i], idx)
                parts[i] = packed_name
                break
        return ".".join(parts), shard_id

    with torch.no_grad():
        for name, weight in weights:
            remapped, shard_id = _remap_and_get_shard_id(name)
            try:
                param = model.get_parameter(remapped)
            except AttributeError:
                continue
            weight_loader = getattr(param, "weight_loader", None)
            if weight_loader is not None and shard_id is not None:
                weight_loader(param, weight, loaded_shard_id=shard_id)
            elif weight_loader is not None:
                weight_loader(param, weight)
            elif param.shape == weight.shape:
                param.data.copy_(weight)
            else:
                if len(param.shape) == len(weight.shape):
                    for dim in range(len(param.shape)):
                        if param.shape[dim] != weight.shape[dim]:
                            shard_size = param.shape[dim]
                            param.data.copy_(weight.narrow(dim, tp_rank * shard_size, shard_size))
                            break
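
The name remapping step can be tried on its own with plain strings. The sketch below mirrors the logic of `_remap_and_get_shard_id`, but the `packed_map` contents are an assumed example of what `get_packed_modules_mapping` might return for a Qwen-style model, and `build_sub_to_packed` / `remap_name` are illustrative names.

```python
QKV_SHARD_IDS = {"q_proj": "q", "k_proj": "k", "v_proj": "v"}


def build_sub_to_packed(packed_map: dict) -> dict:
    """Invert {packed: [subs]} into {sub: (packed, index_within_packed)}."""
    return {
        sub: (packed, idx)
        for packed, subs in packed_map.items()
        for idx, sub in enumerate(subs)
    }


def remap_name(hf_name: str, sub_to_packed: dict):
    """Map an HF parameter name to its packed name and shard id (None if unpacked)."""
    parts = hf_name.split(".")
    for i, part in enumerate(parts):
        if part in sub_to_packed:
            packed, idx = sub_to_packed[part]
            parts[i] = packed
            # QKV shards use letter ids; other packed modules use their index.
            return ".".join(parts), QKV_SHARD_IDS.get(part, idx)
    return hf_name, None


# Assumed mapping, for illustration only.
packed_map = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}
lookup = build_sub_to_packed(packed_map)
print(remap_name("model.layers.0.self_attn.k_proj.weight", lookup))
print(remap_name("model.layers.0.mlp.up_proj.weight", lookup))
print(remap_name("model.norm.weight", lookup))
```

With this mapping, `k_proj` resolves to the packed `qkv_proj` parameter with shard id `"k"`, `up_proj` resolves to `gate_up_proj` with shard id `1`, and unpacked names pass through with shard id `None`.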

Add TMS torch mode, training region, colocate-only PYTORCH_NPU_ALLOC_CONF override, CANN PYTHONPATH for Ray workers, and safe NPU empty_cache. Complements IPC weight sync for bridge colocate on 8x910B.

Signed-off-by: kaiyuan <kyxiezju@163.com>

CalvinXKY changed the title from "[NPU] colocate IPC weight sync on Ascend" to "[NPU] colocate training with TMS, bridge sync, and IPC weight loading" on Jun 24, 2026

CalvinXKY (Collaborator, Author) commented Jun 24, 2026

TEST

vLLM 0.21

script

#!/bin/bash
set -e

MODE="${1:-smoke}"
case "$MODE" in
  smoke) NUM_ROLLOUT=1 ;;
  test)  NUM_ROLLOUT=100 ;;
  long)  NUM_ROLLOUT=500 ;;
  *)
    echo "Usage: $0 [smoke|test|long]" >&2
    exit 1
    ;;
esac

source /root/.proxy_env.sh 2>/dev/null || true
source /usr/local/Ascend/driver/bin/setenv.bash
source "${ASCEND_TOOLKIT_HOME:-/usr/local/Ascend/ascend-toolkit}/set_env.sh"
source /usr/local/Ascend/nnal/atb/set_env.sh

VIME_ROOT="${VIME_ROOT:-/root/vime}"
RAY_TMPDIR="${RAY_TMPDIR:-/tmp/ray_vime_colocate_aligned}"
RAY_PORT="${RAY_PORT:-6394}"
RAY_DASHBOARD_PORT="${RAY_DASHBOARD_PORT:-8280}"
LOG="${LOG:-/root/vime/colocate_aligned_${MODE}.log}"
PYTHON="${PYTHON:-/usr/local/python3.12.13/bin/python3}"

HF_CKPT="${HF_CKPT:-/data/nfs_87/model/Qwen3-4B}"
PROMPT_DATA="${PROMPT_DATA:-/data/nfs_87/wx/data/slime/gsm8k/train.parquet}"

CANN_PYTHON="${ASCEND_TOOLKIT_HOME}/python/site-packages"
if [ ! -d "${CANN_PYTHON}/acl" ]; then
  CANN_PYTHON="/usr/local/Ascend/ascend-toolkit/latest/python/site-packages"
fi

echo "=== PR #285 colocate ALIGNED (${MODE}, num-rollout=${NUM_ROLLOUT}) ==="
echo "VIME_ROOT=$VIME_ROOT HF_CKPT=$HF_CKPT LOG=$LOG"

export PYTHONUNBUFFERED=1
export VLLM_ASCEND_ENABLE_NZ=0
export ASCEND_COREDUMP_SIGNAL=none
export HCCL_DETERMINISTIC=true
export HCCL_CONNECT_TIMEOUT=7200
export HCCL_HOST_SOCKET_PORT_RANGE=60000-60050
export HCCL_NPU_SOCKET_PORT_RANGE=61000-61050
export ASCEND_RT_VISIBLE_DEVICES="${ASCEND_RT_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}"
export CUDA_DEVICE_MAX_CONNECTIONS=1
export RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES=1
export HYDRA_FULL_ERROR=1
export VLLM_SERVER_DEV_MODE=1
export VLLM_USE_AOT_COMPILE=0
export DISABLE_L2_CACHE=1

unset TORCH_DEVICE_BACKEND_AUTOLOAD
# pr285 colocate + Ray: must disable expandable_segments (TMS/captures_underway)
export PYTORCH_NPU_ALLOC_CONF=expandable_segments:False

export RAY_DISABLE_SIGINT_OVERRIDE=1
export ATB_MATMUL_SHUFFLE_K_ENABLE=0
export ATB_LLM_LCOC_ENABLE=0
export TRITON_NPU_COMPILER_PATH=/usr/local/Ascend/cann-9.0.0/aarch64-linux/bin
export LD_LIBRARY_PATH=/usr/local/Ascend/driver/lib64:/usr/local/Ascend/ascend-toolkit/latest/lib64:/usr/local/Ascend/nnal/atb/latest/atb/cxx_abi_1/lib:/usr/local/Ascend/cann-9.0.0/lib64:${LD_LIBRARY_PATH}

export PATH="/usr/local/python3.12.13/bin:${PATH}"
export PYTHONPATH="/root/Megatron-LM:/root/vllm:/root/vllm-ascend:/root/vime:/root/Megatron-Bridge:/root/mbridge:/root/MindSpeed:${CANN_PYTHON}:/usr/local/Ascend/ascend-toolkit/latest/tools/ms_fmk_transplt/torch_npu_bridge:${PYTHONPATH:-}"

cd "$VIME_ROOT"
source "${VIME_ROOT}/scripts/models/qwen3-4B.sh"

echo "=== clean Ray / vLLM / train ==="
ray stop --force 2>/dev/null || true
pkill -9 -f 'ray::' 2>/dev/null || true
pkill -9 -f 'gcs_server' 2>/dev/null || true
pkill -9 -f 'raylet' 2>/dev/null || true
pkill -9 -f 'dashboard' 2>/dev/null || true
pkill -9 -f 'train.py' 2>/dev/null || true
pkill -9 -f 'vllm serve' 2>/dev/null || true
pkill -9 -f 'VLLMWorker' 2>/dev/null || true
pkill -9 -f 'VLLMEngine' 2>/dev/null || true
pkill -9 -f 'EngineCore' 2>/dev/null || true
pkill -9 -f 'RolloutManager' 2>/dev/null || true
pkill -9 -f 'MegatronTrainRayActor' 2>/dev/null || true
sleep 5
rm -rf /tmp/ray "${RAY_TMPDIR}"/session_* "${RAY_TMPDIR}"/ray_current_cluster "$RAY_TMPDIR"
rm -f "$LOG"

ray start --head \
  --temp-dir="$RAY_TMPDIR" \
  --port="$RAY_PORT" \
  --num-gpus=0 \
  --resources='{"NPU": 8}' \
  --disable-usage-stats \
  --dashboard-host=0.0.0.0 \
  --dashboard-port="$RAY_DASHBOARD_PORT"

sleep 3
ray status | head -15

RUNTIME_ENV_JSON=$("$PYTHON" - <<PY
import json, os
print(json.dumps({
    "env_vars": {
        "PATH": os.environ.get("PATH", ""),
        "PYTHONPATH": os.environ.get("PYTHONPATH", ""),
        "PYTHONUNBUFFERED": "1",
        "PYTORCH_NPU_ALLOC_CONF": "expandable_segments:False",
        "VLLM_ASCEND_ENABLE_NZ": "0",
        "VLLM_SERVER_DEV_MODE": "1",
        "VLLM_USE_AOT_COMPILE": "0",
        "HCCL_DETERMINISTIC": "true",
        "HCCL_CONNECT_TIMEOUT": "7200",
        "HCCL_HOST_SOCKET_PORT_RANGE": "60000-60050",
        "HCCL_NPU_SOCKET_PORT_RANGE": "61000-61050",
        "ASCEND_RT_VISIBLE_DEVICES": os.environ.get("ASCEND_RT_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7"),
        "CUDA_DEVICE_MAX_CONNECTIONS": "1",
        "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1",
        "RAY_DISABLE_SIGINT_OVERRIDE": "1",
        "ATB_MATMUL_SHUFFLE_K_ENABLE": "0",
        "ATB_LLM_LCOC_ENABLE": "0",
        "TRITON_NPU_COMPILER_PATH": "/usr/local/Ascend/cann-9.0.0/aarch64-linux/bin",
        "ASCEND_COREDUMP_SIGNAL": "none",
        "DISABLE_L2_CACHE": "1",
        "LD_LIBRARY_PATH": os.environ.get("LD_LIBRARY_PATH", ""),
    }
}))
PY
)

RAY_ADDR="http://127.0.0.1:${RAY_DASHBOARD_PORT}"
echo "=== ray job submit -> $LOG ==="

ray job submit --address="$RAY_ADDR" \
  --working-dir="$VIME_ROOT" \
  --runtime-env-json="$RUNTIME_ENV_JSON" \
  --no-wait \
  -- python3 train.py \
  --train-backend megatron \
  --actor-num-nodes 1 \
  --actor-num-gpus-per-node 8 \
  --rollout-num-gpus 8 \
  --colocate \
  --train-memory-margin-bytes 2147483648 \
  "${MODEL_ARGS[@]}" \
  --hf-checkpoint "$HF_CKPT" \
  --load "$HF_CKPT" \
  --megatron-to-hf-mode bridge \
  --vllm-weight-sync-mode native \
  --prompt-data "$PROMPT_DATA" \
  --input-key question \
  --label-key label \
  --apply-chat-template \
  --rollout-shuffle \
  --rm-type math \
  --num-rollout "$NUM_ROLLOUT" \
  --rollout-batch-size 16 \
  --n-samples-per-prompt 4 \
  --rollout-max-response-len 1024 \
  --vllm-max-model-len 4096 \
  --rollout-temperature 1.0 \
  --global-batch-size 64 \
  --balance-data \
  --advantage-estimator grpo \
  --kl-loss-coef 0.0 \
  --kl-loss-type low_var_kl \
  --kl-coef 0.00 \
  --eps-clip 0.2 \
  --optimizer adam \
  --lr 1e-6 \
  --lr-decay-style constant \
  --weight-decay 0.1 \
  --tensor-model-parallel-size 8 \
  --pipeline-model-parallel-size 1 \
  --context-parallel-size 1 \
  --expert-model-parallel-size 1 \
  --expert-tensor-parallel-size 1 \
  --recompute-granularity full \
  --recompute-method uniform \
  --recompute-num-layers 1 \
  --use-dynamic-batch-size \
  --max-tokens-per-gpu 8192 \
  --micro-batch-size 1 \
  --rollout-num-gpus-per-engine 8 \
  --vllm-gpu-memory-utilization 0.6 \
  --vllm-enable-sleep-mode \
  --vllm-enforce-eager \
  --transformer-impl local \
  --attention-dropout 0.0 \
  --hidden-dropout 0.0 \
  --accumulate-allreduce-grads-in-fp32 \
  --attention-softmax-in-fp32 \
  --attention-backend flash \
  --use-flash-attn \
  > "$LOG" 2>&1 &

SUBMIT_PID=$!
sleep 15
JOB_ID=$(grep -o 'raysubmit_[A-Za-z0-9]*' "$LOG" | head -1 || true)
echo "submit_pid=$SUBMIT_PID job_id=${JOB_ID:-unknown} log=$LOG"
if [ -n "$JOB_ID" ]; then
  nohup ray job logs "$JOB_ID" --address="$RAY_ADDR" -f >> "$LOG" 2>&1 &
fi
echo "tail -f $LOG"

result

image

CalvinXKY force-pushed the feat/npu-colocate branch from 8c9e13f to 916af60 on June 25, 2026 12:06
Add npu_attention_patch before mindspeed import so Megatron train-side attention uses npu_fusion_attention on Ascend, matching pr266 colocate.

Signed-off-by: kaiyuan <kyxiezju@163.com>
CalvinXKY force-pushed the feat/npu-colocate branch from 916af60 to 3eb5436 on June 26, 2026 07:42

CalvinXKY (Collaborator, Author) commented

Test on new vLLM version

| PKG | Version |
| --- | --- |
| vime | 0.3.0 |
| vLLM | 0.22.1rc1.dev263+g967c5c3bc.d20260623.empty |
| vllm-ascend | 0.1.dev1+gbad0caf65.d20260623 |
| PyTorch | 2.10.0+cpu |
| Python | 3.12.13 |
| CANN | 9.0.0 |

#!/bin/bash
set -e

MODE="${1:-smoke}"
case "$MODE" in
  smoke) NUM_ROLLOUT=4 ;;
  test)  NUM_ROLLOUT=20 ;;
  long)  NUM_ROLLOUT=500 ;;
  *)
    echo "Usage: $0 [smoke|test|long]" >&2
    exit 1
    ;;
esac

VERIFY_LOG="${VERIFY_LOG:-/root/vime/PR285_verify_step.log}"
LOG_FILE="${LOG_FILE:-/root/vime/colocate_pr285_verify_${MODE}.log}"
MODEL_ROOT="${MODEL_ROOT:-/data/nfs_87}"
SCRIPT_DIR="/root/vime/scripts"

log() { echo "[$(date -Is)] $*" | tee -a "$VERIFY_LOG"; }

log "========== PR285 verify RUN (${MODE}) START =========="
log "LOG_FILE=$LOG_FILE"

export PATH="/usr/local/python3.12.13/bin:/usr/bin:/bin:/usr/sbin:/sbin:${PATH:-}"
source /usr/local/Ascend/driver/bin/setenv.bash
source "${ASCEND_TOOLKIT_HOME:-/usr/local/Ascend/ascend-toolkit}/set_env.sh"
( set +e +u; source /usr/local/Ascend/nnal/atb/set_env.sh 2>/dev/null ) || true

export SLIME_SCRIPT_TRAIN_BACKEND=megatron
export PYTHONPATH="/root/Megatron-Bridge/src:/root/Megatron-LM:/root/vllm:/root/vllm-ascend:/root/vime:/root/mbridge:/root/MindSpeed:/usr/local/Ascend/cann-9.0.0/python/site-packages:${PYTHONPATH:-}"
export ASCEND_RT_VISIBLE_DEVICES="${ASCEND_RT_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}"
export CUDA_DEVICE_MAX_CONNECTIONS=1
export RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES=1
export HCCL_HOST_SOCKET_PORT_RANGE=60000-60050
export HCCL_NPU_SOCKET_PORT_RANGE=61000-61050
export HYDRA_FULL_ERROR=1
export MASTER_PORT="${MASTER_PORT:-29500}"
export DISABLE_L2_CACHE=1
export VLLM_ASCEND_ENABLE_NZ=0
export VLLM_USE_AOT_COMPILE=0
export ASCEND_COREDUMP_SIGNAL=none
export HCCL_DETERMINISTIC=true
export HCCL_CONNECT_TIMEOUT=7200
export VLLM_SERVER_DEV_MODE=1
export ATB_MATMUL_SHUFFLE_K_ENABLE=0
export ATB_LLM_LCOC_ENABLE=0
export TRITON_NPU_COMPILER_PATH=/usr/local/Ascend/cann-9.0.0/aarch64-linux/bin
export LD_LIBRARY_PATH="/usr/local/Ascend/driver/lib64:/usr/local/Ascend/ascend-toolkit/latest/lib64:/usr/local/Ascend/nnal/atb/latest/atb/cxx_abi_1/lib:/usr/local/Ascend/cann-9.0.0/lib64:${LD_LIBRARY_PATH:-}"
export PYTHONUNBUFFERED=1
export RAY_DISABLE_SIGINT_OVERRIDE=1

unset PYTORCH_NPU_ALLOC_CONF
unset TORCH_DEVICE_BACKEND_AUTOLOAD

log "--- [R1] pre-train cleanup ---"
bash /root/clean_before_colocate_train.sh 2>&1 | tee -a "$VERIFY_LOG"

cd /root/vime
source "${SCRIPT_DIR}/models/qwen3-4B.sh"
: >"$LOG_FILE"

log "--- [R2] python train.py direct (no Ray job submit) ---"
python /root/vime/train.py \
  --train-backend megatron \
  --actor-num-nodes 1 \
  --actor-num-gpus-per-node 8 \
  --rollout-num-gpus 8 \
  --rollout-num-gpus-per-engine 8 \
  "${MODEL_ARGS[@]}" \
  --hf-checkpoint "${MODEL_ROOT}/model/Qwen3-4B/" \
  --prompt-data "${MODEL_ROOT}/wx/data/slime/gsm8k/train.parquet" \
  --input-key question \
  --label-key label \
  --apply-chat-template \
  --rollout-shuffle \
  --rm-type math \
  --rollout-backend vllm \
  --vllm-weight-sync-mode native \
  --vllm-gpu-memory-utilization 0.6 \
  --vllm-enable-sleep-mode \
  --vllm-max-model-len 4096 \
  --colocate \
  --num-rollout "$NUM_ROLLOUT" \
  --rollout-batch-size 16 \
  --n-samples-per-prompt 4 \
  --rollout-max-response-len 1024 \
  --rollout-temperature 1.0 \
  --global-batch-size 64 \
  --balance-data \
  --advantage-estimator grpo \
  --kl-loss-coef 0.0 \
  --kl-loss-type low_var_kl \
  --kl-coef 0.00 \
  --eps-clip 0.2 \
  --optimizer adam \
  --lr 1e-6 \
  --lr-decay-style constant \
  --weight-decay 0.1 \
  --tensor-model-parallel-size 8 \
  --pipeline-model-parallel-size 1 \
  --context-parallel-size 1 \
  --expert-model-parallel-size 1 \
  --expert-tensor-parallel-size 1 \
  --recompute-granularity full \
  --recompute-method uniform \
  --recompute-num-layers 1 \
  --use-dynamic-batch-size \
  --max-tokens-per-gpu 8192 \
  --load "${MODEL_ROOT}/model/Qwen3-4B" \
  --megatron-to-hf-mode bridge \
  --attention-dropout 0.0 \
  --hidden-dropout 0.0 \
  --accumulate-allreduce-grads-in-fp32 \
  --attention-softmax-in-fp32 \
  --attention-backend flash \
  --micro-batch-size 1 \
  --use-flash-attn \
  --train-memory-margin-bytes 2147483648 \
  2>&1 | tee -a "$LOG_FILE"

log "--- [R3] result summary ---"
grep -o "train_rollout_logprob_abs_diff': [0-9.eE+-]*" "$LOG_FILE" | tail -5 | tee -a "$VERIFY_LOG" || true
log "========== PR285 verify RUN (${MODE}) END log=$LOG_FILE =========="

Result

| Step | logprob_abs_diff |
| --- | --- |
| 0 | 0.0114 |
| 1 | 0.0123 |
| 2 | 0.0122 |
| 3 | 0.0123 |
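
The summary step of the script (the `grep -o` on `train_rollout_logprob_abs_diff`) could also be done in Python when the values are needed as numbers. This is a rough sketch: the regular expression mirrors the grep pattern, while `extract_logprob_diffs` and the sample log text are hypothetical and only imitate the dict-style lines the pattern expects.

```python
import re

# Same key and character class as the grep pattern in the verify script.
DIFF_PATTERN = re.compile(r"train_rollout_logprob_abs_diff': ([0-9.eE+-]+)")


def extract_logprob_diffs(log_text: str) -> list:
    """Return every train_rollout_logprob_abs_diff value found, in log order."""
    return [float(value) for value in DIFF_PATTERN.findall(log_text)]


# Hypothetical log excerpt, shaped like a printed Python dict per step.
sample_log = (
    "step 0: {'train_rollout_logprob_abs_diff': 0.0114, 'other': 1}\n"
    "step 1: {'train_rollout_logprob_abs_diff': 0.0123, 'other': 2}\n"
)
diffs = extract_logprob_diffs(sample_log)
print(diffs, max(diffs))
```

Having the values as floats makes it straightforward to compare a run against a baseline or to flag steps above a chosen threshold.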
