Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ dependencies = [
"torch==2.9.1",
"torchaudio==2.9.1",
"torchvision==0.24.1",
"transformers==4.57.1",
"transformers==5.3.0",
"qwen-vl-utils==0.0.11",
"datasets",
"setuptools",
Expand All @@ -25,7 +25,7 @@ dependencies = [
"numpy",
"accelerate",
"pydantic",
"sglang==0.5.9",
"sglang==0.5.10",
"openai-harmony",
"ninja",
"packaging",
Expand Down
4 changes: 2 additions & 2 deletions requirements-rocm.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ pre-commit
torch==2.8.0+rocm6.3
torchaudio==2.8.0+rocm6.3
torchvision==0.23.0+rocm6.3
transformers==4.57.1
transformers==5.3.0
qwen-vl-utils==0.0.11
datasets
setuptools
Expand All @@ -15,6 +15,6 @@ psutil
numpy
accelerate
pydantic
sglang[all]==0.5.4
sglang[all]==0.5.10
openai-harmony
tensorboard
10 changes: 5 additions & 5 deletions specforge/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class SGLangBackendArgs:
sglang_enable_torch_compile: bool = True
sglang_enable_dp_attention: bool = False
sglang_enable_dp_lm_head: bool = False
sglang_enable_piecewise_cuda_graph: bool = False
sglang_enforce_piecewise_cuda_graph: bool = False
sglang_piecewise_cuda_graph_max_tokens: int = 4096
sglang_piecewise_cuda_graph_tokens: List[int] = None
sglang_ep_size: int = 1
Expand Down Expand Up @@ -151,9 +151,9 @@ def add_args(parser: argparse.ArgumentParser) -> None:
help="Enable piecewise CUDA graph for SGLang backend",

Copilot AI Apr 13, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The CLI help text for --sglang-enable-dp-lm-head appears incorrect (it mentions piecewise CUDA graph). This is user-facing and may confuse users; update the help string to describe DP LM head behavior instead.

Suggested change
help="Enable piecewise CUDA graph for SGLang backend",
help="Enable DP LM head for SGLang backend",

Copilot uses AI. Check for mistakes.
)
parser.add_argument(
"--sglang-enable-piecewise-cuda-graph",
"--sglang-enforce-piecewise-cuda-graph",
action="store_true",
help="Enable piecewise CUDA graph for SGLang backend's prefill",
help="Enforce piecewise CUDA graph for SGLang backend's prefill",

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The help text for the renamed argument --sglang-enforce-piecewise-cuda-graph is correct, but it highlights a significant issue in the preceding argument's help text (line 151), which incorrectly describes --sglang-enable-dp-lm-head as enabling piecewise CUDA graphs. While line 151 is not directly modified in this diff, the rename here makes the duplication and inaccuracy more apparent to users. Consider fixing the help text for --sglang-enable-dp-lm-head in a follow-up or including it here if possible.

)
Comment on lines 153 to 157

Copilot AI Apr 13, 2026

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renaming --sglang-enable-piecewise-cuda-graph to --sglang-enforce-piecewise-cuda-graph is a breaking CLI change. Consider keeping the old flag as a deprecated alias (same dest) or clearly documenting the change so existing scripts don’t fail on upgrade.

Copilot uses AI. Check for mistakes.
parser.add_argument(
"--sglang-piecewise-cuda-graph-max-tokens",
Expand Down Expand Up @@ -186,7 +186,7 @@ def from_args(args: argparse.Namespace) -> "SGLangBackendArgs":
sglang_enable_torch_compile=args.sglang_enable_torch_compile,
sglang_enable_dp_attention=args.sglang_enable_dp_attention,
sglang_enable_dp_lm_head=args.sglang_enable_dp_lm_head,
sglang_enable_piecewise_cuda_graph=args.sglang_enable_piecewise_cuda_graph,
sglang_enforce_piecewise_cuda_graph=args.sglang_enforce_piecewise_cuda_graph,
sglang_piecewise_cuda_graph_max_tokens=args.sglang_piecewise_cuda_graph_max_tokens,
sglang_piecewise_cuda_graph_tokens=args.sglang_piecewise_cuda_graph_tokens,
sglang_ep_size=args.sglang_ep_size,
Expand All @@ -210,7 +210,7 @@ def to_kwargs(self) -> Dict[str, Any]:
enable_torch_compile=self.sglang_enable_torch_compile,
enable_dp_attention=self.sglang_enable_dp_attention,
enable_dp_lm_head=self.sglang_enable_dp_lm_head,
enable_piecewise_cuda_graph=self.sglang_enable_piecewise_cuda_graph,
enforce_piecewise_cuda_graph=self.sglang_enforce_piecewise_cuda_graph,
piecewise_cuda_graph_max_tokens=self.sglang_piecewise_cuda_graph_max_tokens,
piecewise_cuda_graph_tokens=self.sglang_piecewise_cuda_graph_tokens,
ep_size=self.sglang_ep_size,
Expand Down
24 changes: 24 additions & 0 deletions specforge/modeling/draft/llama3_eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,18 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
"sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False
)

def rebuild_buffers(self, device):
"""Rebuild non-persistent RoPE buffers corrupted by transformers 5.x meta-device init.

Recomputes ``inv_freq`` from ``self.base`` and ``self.dim`` on ``device`` and then
calls ``_set_cos_sin_cache`` so the cached cos/sin tables are regenerated as well.

Args:
device: Device on which the rebuilt tensors are created.
"""
# Inverse frequencies: 1 / base ** (j / dim) for j = 0, 2, 4, ... < dim.
self.inv_freq = 1.0 / (
self.base
** (torch.arange(0, self.dim, 2, device=device).float() / self.dim)
)
# NOTE(review): the "+ 20" margin presumably mirrors the cache length used in
# the constructor -- confirm against __init__, which is not visible here.
self._set_cos_sin_cache(
seq_len=self.max_position_embeddings + 20,
device=device,
dtype=torch.get_default_dtype(),
)

@torch.compile(dynamic=True)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
Expand Down Expand Up @@ -1314,6 +1326,16 @@ class LlamaForCausalLMEagle3(Eagle3DraftModel):

config_class = LlamaConfig

def _init_weights(self, module):
"""Leave ``module``'s weights untouched; only repair rotary-embedding buffers.

Args:
module: Submodule being initialized. Only ``LlamaRotaryEmbedding``
instances are acted on; every other module type is a no-op.
"""
# Override the transformers 5.x default _init_weights which would
# re-randomize all Linear/Embedding weights with normal_(0, 0.02).
# Draft model weights come from checkpoint, not random init.
#
# For RotaryEmbedding: rebuild non-persistent buffers (inv_freq,
# cos_cached, sin_cached) corrupted by meta-device materialization.
if isinstance(module, LlamaRotaryEmbedding):
# Buffers are rebuilt on whatever device inv_freq currently lives on.
module.rebuild_buffers(module.inv_freq.device)

def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None:
super().__init__(config)
self.config = config
Expand Down Expand Up @@ -1346,6 +1368,8 @@ def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None:
self.register_buffer("t2d", t2d)
self.register_buffer("d2t", d2t)

self.post_init()

def forward(
self,
hidden_states: torch.Tensor,
Expand Down
8 changes: 5 additions & 3 deletions specforge/modeling/target/custom_backend/gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
from transformers.utils.generic import check_model_inputs
from transformers.utils.generic import merge_with_config_defaults
from transformers.utils.output_capturing import capture_outputs

from specforge.distributed import get_tp_group, shard_tensor
from specforge.layers import (
Expand Down Expand Up @@ -585,7 +586,8 @@ def __init__(self, config: GptOssConfig):
# Initialize weights and apply final processing
self.post_init()

@check_model_inputs
@merge_with_config_defaults
@capture_outputs
@auto_docstring
def forward(
self,
Expand Down Expand Up @@ -759,7 +761,7 @@ def load_balancing_loss_func(

@auto_docstring
class GptOssForCausalLM(GptOssPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

Expand Down
8 changes: 5 additions & 3 deletions specforge/modeling/target/custom_backend/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@
)
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, logging
from transformers.utils.generic import check_model_inputs
from transformers.utils.generic import merge_with_config_defaults
from transformers.utils.output_capturing import capture_outputs

from specforge.distributed import get_tp_group
from specforge.layers import (
Expand Down Expand Up @@ -275,7 +276,8 @@ def __init__(self, config: LlamaConfig):
# Initialize weights and apply final processing
self.post_init()

@check_model_inputs
@merge_with_config_defaults
@capture_outputs
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
Expand Down Expand Up @@ -353,7 +355,7 @@ def forward(


class LlamaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

Expand Down
8 changes: 5 additions & 3 deletions specforge/modeling/target/custom_backend/llama4.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@
logging,
)
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import check_model_inputs
from transformers.utils.generic import merge_with_config_defaults
from transformers.utils.output_capturing import capture_outputs

# [MODIFIED] Import from transformers library
from specforge.distributed import get_tp_group, shard_tensor
Expand Down Expand Up @@ -431,7 +432,8 @@ def __init__(self, config: Llama4TextConfig):
self.post_init()

@can_return_tuple
@check_model_inputs
@merge_with_config_defaults
@capture_outputs
@auto_docstring
def forward(
self,
Expand Down Expand Up @@ -526,7 +528,7 @@ def forward(
class Llama4ForCausalLM(Llama4PreTrainedModel, GenerationMixin):
_no_split_modules = ["Llama4TextDecoderLayer"]
base_model_prefix = "language_model"
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_tp_plan = {"lm_head": "colwise_rep"}
config: Llama4TextConfig

Expand Down
8 changes: 5 additions & 3 deletions specforge/modeling/target/custom_backend/phi3.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import check_model_inputs
from transformers.utils.generic import merge_with_config_defaults
from transformers.utils.output_capturing import capture_outputs

from specforge.distributed import get_tp_group
from specforge.layers import (
Expand Down Expand Up @@ -284,7 +285,8 @@ def __init__(self, config: Phi3Config):
# Initialize weights and apply final processing
self.post_init()

@check_model_inputs
@merge_with_config_defaults
@capture_outputs
@auto_docstring
def forward(
self,
Expand Down Expand Up @@ -371,7 +373,7 @@ def forward(

@auto_docstring
class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

Expand Down
2 changes: 1 addition & 1 deletion specforge/modeling/target/custom_backend/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ def forward(

@auto_docstring
class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

Expand Down
7 changes: 5 additions & 2 deletions specforge/modeling/target/custom_backend/qwen3.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.models.qwen3.modeling_qwen3 import (
Qwen3RMSNorm,
Qwen3RotaryEmbedding as _OrigQwen3RotaryEmbedding,
apply_rotary_pos_emb,
eager_attention_forward,
)
Expand Down Expand Up @@ -261,7 +262,9 @@ def __init__(self, config: Qwen3Config, device=None):
self.original_max_seq_len = config.max_position_embeddings

self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
self.rope_init_fn = _OrigQwen3RotaryEmbedding.compute_default_rope_parameters
if self.rope_type != "default":
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
Expand Down Expand Up @@ -483,7 +486,7 @@ def forward(

@auto_docstring
class Qwen3ForCausalLM(Qwen3PreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

Expand Down
7 changes: 5 additions & 2 deletions specforge/modeling/target/custom_backend/qwen3_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.models.qwen3_moe.modeling_qwen3_moe import (
Qwen3MoeRotaryEmbedding as _OrigQwen3MoeRotaryEmbedding,
apply_rotary_pos_emb,
eager_attention_forward,
)
Expand Down Expand Up @@ -430,7 +431,9 @@ def __init__(self, config: Qwen3MoeConfig, device=None):
self.original_max_seq_len = config.max_position_embeddings

self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
self.rope_init_fn = _OrigQwen3MoeRotaryEmbedding.compute_default_rope_parameters
if self.rope_type != "default":
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
Expand Down Expand Up @@ -742,7 +745,7 @@ def load_balancing_loss_func(

@auto_docstring
class Qwen3MoeForCausalLM(Qwen3MoePreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
_tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
_tp_plan = {"lm_head": "colwise_rep"}
_pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

Expand Down
2 changes: 0 additions & 2 deletions specforge/modeling/target/sglang_backend/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ def initialize_model_parallel(
"SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
),
group_name="tp",
pynccl_use_current_stream=duplicate_tp_group,
)

if duplicate_tp_group:
Expand All @@ -156,7 +155,6 @@ def initialize_model_parallel(
"SGLANG_USE_MESSAGE_QUEUE_BROADCASTER", "true"
),
group_name="pdmux_prefill_tp",
pynccl_use_current_stream=True,
)
# NOTE: Check pynccl_comm exists before accessing it (may be None in sglang 0.5.9)
if parallel_state._TP.pynccl_comm is not None:
Expand Down
Loading