Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions specforge/core/eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from transformers.cache_utils import DynamicCache

from specforge.core.eagle3_adapters import BackendAdapter, SdpaLikeAdapter, UspAdapter
from specforge.core.loss import LogSoftmaxLoss
from specforge.core.loss import LogSoftmaxLoss, _compute_loss
from specforge.modeling.draft import Eagle3DraftModel
from specforge.utils import padding

Expand Down Expand Up @@ -92,7 +92,12 @@ def _acc_and_loss(
)
acc = local_correct / local_denom

loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
try:
loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
except RuntimeError:
# Fused Triton kernel has a block-size ceiling (131072); fall back
# to the @torch.compile reference for large-vocab models.
loss = _compute_loss(logits, target_p, position_mask)
Comment on lines +95 to +100

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Catching a generic `RuntimeError` to handle the Triton block-size limit is risky: it can swallow unrelated errors (e.g., device mismatches or other CUDA errors) and silently route them to the fallback path. Since the limit is a known constant (131072), it is better to check the vocabulary size (`logits.shape[-1]`) explicitly to decide whether to use the fused kernel or the fallback implementation.

Suggested change
try:
loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
except RuntimeError:
# Fused Triton kernel has a block-size ceiling (131072); fall back
# to the @torch.compile reference for large-vocab models.
loss = _compute_loss(logits, target_p, position_mask)
if logits.shape[-1] <= 131072:
loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
else:
# Fused Triton kernel has a block-size ceiling (131072); fall back
# to the @torch.compile reference for large-vocab models.
loss = _compute_loss(logits, target_p, position_mask)

loss = adapter.reduce_loss(loss)
return acc, loss

Expand Down Expand Up @@ -553,7 +558,12 @@ def forward(
)

# Step 5.6: calculate loss, in-place modifies logits!
loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
try:
loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
except RuntimeError:
# Fused Triton kernel has a block-size ceiling (131072); fall
# back to the @torch.compile reference for large-vocab models.
loss = _compute_loss(logits, target_p, position_mask)
Comment on lines +561 to +566

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

As with the `_acc_and_loss` change earlier in this file, an explicit check on the vocabulary size (`logits.shape[-1] <= 131072`) is preferred over a `try`/`except RuntimeError` block for handling the Triton kernel's block-size constraint.

Suggested change
try:
loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
except RuntimeError:
# Fused Triton kernel has a block-size ceiling (131072); fall
# back to the @torch.compile reference for large-vocab models.
loss = _compute_loss(logits, target_p, position_mask)
if logits.shape[-1] <= 131072:
loss = LogSoftmaxLoss.apply(logits, target_p, position_mask)
else:
# Fused Triton kernel has a block-size ceiling (131072); fall
# back to the @torch.compile reference for large-vocab models.
loss = _compute_loss(logits, target_p, position_mask)

plosses.append(loss)

if not is_last:
Expand Down
9 changes: 8 additions & 1 deletion specforge/modeling/target/eagle3_target_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,14 @@ def from_pretrained(
def set_aux_hidden_states_layers(
self, aux_hidden_states_layers: Optional[List[int]] = None
) -> None:
self.model_runner.model.set_eagle3_layers_to_capture(aux_hidden_states_layers)
# Some target models (e.g., Kimi-K2.5) load via a multimodal wrapper
# that delegates to a text backbone at .language_model. The EAGLE-3
# helper set_eagle3_layers_to_capture is defined on the text backbone,
# not the outer wrapper.
inner = getattr(
self.model_runner.model, "language_model", self.model_runner.model
)
inner.set_eagle3_layers_to_capture(aux_hidden_states_layers)

@torch.no_grad
def _extend(
Expand Down
29 changes: 23 additions & 6 deletions specforge/modeling/target/sglang_backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,27 @@ def wrap_eagle3_logits_processors_in_module(
module: nn.Module, return_full_logits: bool = False
):
"""
This function will wrap the SGLang's original logits processor with the modified one for EAGLE3.
Wrap SGLang's original logits processors with the EAGLE3 variant.

Fixes:
1. Iterate over a materialized list so mutations to _modules do not
corrupt the iterator returned by named_modules().
2. Use module.set_submodule(dotted_name, wrapped) so nested
LogitsProcessors (e.g. language_model.logits_processor) are actually
replaced in their parent module, instead of creating a literal
dotted-name attribute on the root module.
"""
for name, submodule in module.named_modules():
if isinstance(submodule, LogitsProcessor):
wrapped = LogitsProcessorForEAGLE3(submodule, return_full_logits)
setattr(module, name, wrapped)
print(f"wrapped {name} with LogitsProcessorForEAGLE3")
to_wrap = [
(name, submodule)
for name, submodule in list(module.named_modules())
if isinstance(submodule, LogitsProcessor)
]
for name, submodule in to_wrap:
wrapped = LogitsProcessorForEAGLE3(submodule, return_full_logits)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

`return_full_logits` is passed positionally to `LogitsProcessorForEAGLE3`, so it binds to `return_last_hidden_states` (the second parameter) instead of `return_logits` (the third parameter). Given the parameter names, this appears to be a logic error. Passing it as a keyword argument would make the call safer and clearer.

Suggested change
wrapped = LogitsProcessorForEAGLE3(submodule, return_full_logits)
wrapped = LogitsProcessorForEAGLE3(submodule, return_logits=return_full_logits)

if name == "":
print(
"warning: root module is a LogitsProcessor; cannot replace in-place"
)
continue
module.set_submodule(name, wrapped)
print(f"wrapped {name} with LogitsProcessorForEAGLE3")