-
Notifications
You must be signed in to change notification settings - Fork 272
fix: EAGLE-3 training compatibility with multimodal-wrapped targets and large vocabs #535
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -28,7 +28,7 @@ | |||||||||||||||||||||||||
| from transformers.cache_utils import DynamicCache | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| from specforge.core.eagle3_adapters import BackendAdapter, SdpaLikeAdapter, UspAdapter | ||||||||||||||||||||||||||
| from specforge.core.loss import LogSoftmaxLoss | ||||||||||||||||||||||||||
| from specforge.core.loss import LogSoftmaxLoss, _compute_loss | ||||||||||||||||||||||||||
| from specforge.modeling.draft import Eagle3DraftModel | ||||||||||||||||||||||||||
| from specforge.utils import padding | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
@@ -92,7 +92,12 @@ def _acc_and_loss( | |||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
| acc = local_correct / local_denom | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| loss = LogSoftmaxLoss.apply(logits, target_p, position_mask) | ||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||
| loss = LogSoftmaxLoss.apply(logits, target_p, position_mask) | ||||||||||||||||||||||||||
| except RuntimeError: | ||||||||||||||||||||||||||
| # Fused Triton kernel has a block-size ceiling (131072); fall back | ||||||||||||||||||||||||||
| # to the @torch.compile reference for large-vocab models. | ||||||||||||||||||||||||||
| loss = _compute_loss(logits, target_p, position_mask) | ||||||||||||||||||||||||||
| loss = adapter.reduce_loss(loss) | ||||||||||||||||||||||||||
| return acc, loss | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
|
|
@@ -553,7 +558,12 @@ def forward( | |||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| # Step 5.6: calculate loss, in-place modifies logits! | ||||||||||||||||||||||||||
| loss = LogSoftmaxLoss.apply(logits, target_p, position_mask) | ||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||
| loss = LogSoftmaxLoss.apply(logits, target_p, position_mask) | ||||||||||||||||||||||||||
| except RuntimeError: | ||||||||||||||||||||||||||
| # Fused Triton kernel has a block-size ceiling (131072); fall | ||||||||||||||||||||||||||
| # back to the @torch.compile reference for large-vocab models. | ||||||||||||||||||||||||||
| loss = _compute_loss(logits, target_p, position_mask) | ||||||||||||||||||||||||||
|
Comment on lines
+561
to
+566
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. As noted in the previous comment, using an explicit check for the vocabulary size is preferred over a `try`/`except RuntimeError` block.
Suggested change
|
||||||||||||||||||||||||||
| plosses.append(loss) | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| if not is_last: | ||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -164,10 +164,27 @@ def wrap_eagle3_logits_processors_in_module( | |||||
| module: nn.Module, return_full_logits: bool = False | ||||||
| ): | ||||||
| """ | ||||||
| This function will wrap the SGLang's original logits processor with the modified one for EAGLE3. | ||||||
| Wrap SGLang's original logits processors with the EAGLE3 variant. | ||||||
|
|
||||||
| Fixes: | ||||||
| 1. Iterate over a materialized list so mutations to _modules do not | ||||||
| corrupt the iterator returned by named_modules(). | ||||||
| 2. Use module.set_submodule(dotted_name, wrapped) so nested | ||||||
| LogitsProcessors (e.g. language_model.logits_processor) are actually | ||||||
| replaced in their parent module, instead of creating a literal | ||||||
| dotted-name attribute on the root module. | ||||||
| """ | ||||||
| for name, submodule in module.named_modules(): | ||||||
| if isinstance(submodule, LogitsProcessor): | ||||||
| wrapped = LogitsProcessorForEAGLE3(submodule, return_full_logits) | ||||||
| setattr(module, name, wrapped) | ||||||
| print(f"wrapped {name} with LogitsProcessorForEAGLE3") | ||||||
| to_wrap = [ | ||||||
| (name, submodule) | ||||||
| for name, submodule in list(module.named_modules()) | ||||||
| if isinstance(submodule, LogitsProcessor) | ||||||
| ] | ||||||
| for name, submodule in to_wrap: | ||||||
| wrapped = LogitsProcessorForEAGLE3(submodule, return_full_logits) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Suggested change
|
||||||
| if name == "": | ||||||
| print( | ||||||
| "warning: root module is a LogitsProcessor; cannot replace in-place" | ||||||
| ) | ||||||
| continue | ||||||
| module.set_submodule(name, wrapped) | ||||||
| print(f"wrapped {name} with LogitsProcessorForEAGLE3") | ||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Catching a generic `RuntimeError` to handle the Triton block size limit is suboptimal and potentially risky, as it might swallow unrelated errors (e.g., device mismatches or other CUDA errors). Since the limit is a known constant (131072), it is better to perform an explicit check on the vocabulary size (`logits.shape[-1]`) to decide whether to use the fused kernel or the fallback implementation.