From c74d8159f15c23b6ad255bf9cac40791b96f44d2 Mon Sep 17 00:00:00 2001 From: bluecoffee8 Date: Wed, 27 May 2026 15:58:13 -0700 Subject: [PATCH 1/5] eagle3.1 initial code --- specforge/core/eagle3.py | 8 +++++++ specforge/modeling/draft/base.py | 9 +++++++ specforge/modeling/draft/llama3_eagle.py | 30 +++++++++++++++++------- 3 files changed, 39 insertions(+), 8 deletions(-) diff --git a/specforge/core/eagle3.py b/specforge/core/eagle3.py index 1e2f04e7e..2c8f7028b 100644 --- a/specforge/core/eagle3.py +++ b/specforge/core/eagle3.py @@ -249,6 +249,10 @@ def forward( # update hidden states for next step hidden_states = hidden_states_out + # Apply output norm for EAGLE 3.1 post-norm architecture + if self.draft_model.norm_output: + hidden_states = self.draft_model.norm(hidden_states) + # Step 5.4: get logits logits = self.draft_model.compute_logits(hidden_states) @@ -538,6 +542,10 @@ def forward( # update hidden states for next step hidden_states = hidden_states_out + # Apply output norm for EAGLE 3.1 post-norm architecture + if self.draft_model.norm_output: + hidden_states = self.draft_model.norm(hidden_states) + # Step 5.4: get logits logits = self.draft_model.compute_logits(hidden_states) diff --git a/specforge/modeling/draft/base.py b/specforge/modeling/draft/base.py index b5584a759..b0979e819 100644 --- a/specforge/modeling/draft/base.py +++ b/specforge/modeling/draft/base.py @@ -41,6 +41,15 @@ class Eagle3DraftModel(PreTrainedModel, ABC): the abstract methods to support training with TTT. """ + def __init__(self, config): + super().__init__(config) + + self.num_aux_hidden_states = getattr(config, "num_aux_hidden_states", None) + if self.num_aux_hidden_states is None: + eagle_config = getattr(config, "eagle_config", None) or {} + layer_ids = eagle_config.get("eagle_aux_hidden_state_layer_ids") + self.num_aux_hidden_states = len(layer_ids) if layer_ids else 3 + @abstractmethod def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: """ diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py index 268142c0c..3713d9744 100644 --- a/specforge/modeling/draft/llama3_eagle.py +++ b/specforge/modeling/draft/llama3_eagle.py @@ -1326,16 +1326,26 @@ def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None: ) self.midlayer = LlamaDecoderLayer(config, attention_backend=attention_backend) - if hasattr(config, "target_hidden_size"): - self.fc = torch.nn.Linear( - config.target_hidden_size * 3, config.hidden_size, bias=False + target_hidden_size = getattr(config, "target_hidden_size", config.hidden_size) + + self.fc = torch.nn.Linear( + target_hidden_size * self.num_aux_hidden_states, + config.hidden_size, + bias=False, + ) + use_fc_norm = getattr(config, "fc_norm", None) + if use_fc_norm: + self.fc_norm = nn.ModuleList( + [ + LlamaRMSNorm(target_hidden_size, eps=config.rms_norm_eps) + for _ in range(self.num_aux_hidden_states) + ] ) else: - self.fc = torch.nn.Linear( - config.hidden_size * 3, config.hidden_size, bias=False - ) + self.fc_norm = None self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.norm_output = getattr(config, "norm_output", False) self.lm_head = nn.Linear( config.hidden_size, config.draft_vocab_size, bias=False ) @@ -1406,8 +1416,12 @@ def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def project_hidden_states(self, hidden_states: torch.Tensor) -> torch.Tensor: - # eagle 3 requires hidden states from 3 layers - assert hidden_states.size(-1) == self.config.hidden_size * 3 + if self.fc_norm is not None: + chunks = hidden_states.chunk(self.num_aux_hidden_states, dim=-1) + hidden_states = torch.cat( + [norm(chunk) for norm, chunk in zip(self.fc_norm, chunks)], + dim=-1, + ) return self.fc(hidden_states) def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: From 581664ac7ae4a034c5d0327f321d47cecb0051f0 Mon Sep 17 00:00:00 2001 From: bluecoffee8 Date: Wed, 27 May 2026 17:24:58 -0700 Subject: [PATCH 2/5] add eagle3.1 example --- configs/qwen3-30B-A3B-eagle3.1.json | 33 +++++++++++++++++++ examples/run_qwen3_30b_a3b_eagle3.1_online.sh | 29 ++++++++++++++++ 2 files changed, 62 insertions(+) create mode 100644 configs/qwen3-30B-A3B-eagle3.1.json create mode 100644 examples/run_qwen3_30b_a3b_eagle3.1_online.sh diff --git a/configs/qwen3-30B-A3B-eagle3.1.json b/configs/qwen3-30B-A3B-eagle3.1.json new file mode 100644 index 000000000..987354542 --- /dev/null +++ b/configs/qwen3-30B-A3B-eagle3.1.json @@ -0,0 +1,33 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "fc_norm": true, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 12288, + "max_position_embeddings": 2048, + "max_window_layers": 48, + "model_type": "llama", + "norm_output": true, + "num_attention_heads": 32, + "num_hidden_layers": 1, + "num_key_value_heads":4, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000.0, + "sliding_window": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.53.2", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936, + "draft_vocab_size": 32000 +} diff --git a/examples/run_qwen3_30b_a3b_eagle3.1_online.sh b/examples/run_qwen3_30b_a3b_eagle3.1_online.sh new file mode 100644 index 000000000..8c298a1e7 --- /dev/null +++ b/examples/run_qwen3_30b_a3b_eagle3.1_online.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +# support tp4/tp8 train eagle3 for Qwen3-30B-A3B +NUM_GPUS=${1:-4} +TP_SIZE=${2:-4} +BUILD_DATASET_NUM_PROC=${BUILD_DATASET_NUM_PROC:-64} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path Qwen/Qwen3-30B-A3B-Instruct-2507 \ + --draft-model-config $ROOT_DIR/configs/qwen3-30B-A3B-eagle3.1.json \ + --train-data-path $ROOT_DIR/cache/dataset/sharegpt_train.jsonl \ + --build-dataset-num-proc $BUILD_DATASET_NUM_PROC \ + --output-dir $ROOT_DIR/outputs/qwen3-30b-a3b-instruct-eagle3-sharegpt \ + --num-epochs 10 \ + --batch-size 1 \ + --learning-rate 1e-4 \ + --max-length 4096 \ + --chat-template qwen \ + --cache-dir $ROOT_DIR/cache \ + --embedding-key model.embed_tokens.weight \ + --tp-size $TP_SIZE \ + --target-model-backend sglang From fd310888f4d5d801ad85ff0a127fc66eddf2a3aa Mon Sep 17 00:00:00 2001 From: bluecoffee8 Date: Thu, 28 May 2026 14:07:11 -0700 Subject: [PATCH 3/5] fix linter --- examples/run_qwen3_30b_a3b_eagle3.1_online.sh | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 examples/run_qwen3_30b_a3b_eagle3.1_online.sh diff --git a/examples/run_qwen3_30b_a3b_eagle3.1_online.sh b/examples/run_qwen3_30b_a3b_eagle3.1_online.sh old mode 100644 new mode 100755 From 2508b38654a50ca93ce9fb0757da5c1c01ebd9db Mon Sep 17 00:00:00 2001 From: bluecoffee8 Date: Thu, 28 May 2026 20:23:22 -0700 Subject: [PATCH 4/5] address comments --- specforge/core/eagle3.py | 8 -------- specforge/modeling/draft/llama3_eagle.py | 12 +++++++----- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/specforge/core/eagle3.py b/specforge/core/eagle3.py index 2c8f7028b..1e2f04e7e 100644 --- a/specforge/core/eagle3.py +++ b/specforge/core/eagle3.py @@ -249,10 +249,6 @@ def forward( # update hidden states for next step hidden_states = hidden_states_out - # Apply output norm for EAGLE 3.1 post-norm architecture - if self.draft_model.norm_output: - hidden_states = self.draft_model.norm(hidden_states) - # Step 5.4: get logits logits = self.draft_model.compute_logits(hidden_states) @@ -542,10 +538,6 @@ def forward( # update hidden states for next step hidden_states = hidden_states_out - # Apply output norm for EAGLE 3.1 post-norm architecture - if self.draft_model.norm_output: - hidden_states = self.draft_model.norm(hidden_states) - # Step 5.4: get logits logits = self.draft_model.compute_logits(hidden_states) diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py index 3713d9744..1d2d3b7e3 100644 --- a/specforge/modeling/draft/llama3_eagle.py +++ b/specforge/modeling/draft/llama3_eagle.py @@ -1326,10 +1326,10 @@ def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None: ) self.midlayer = LlamaDecoderLayer(config, attention_backend=attention_backend) - target_hidden_size = getattr(config, "target_hidden_size", config.hidden_size) + self.target_hidden_size = getattr(config, "target_hidden_size", config.hidden_size) self.fc = torch.nn.Linear( - target_hidden_size * self.num_aux_hidden_states, + self.target_hidden_size * self.num_aux_hidden_states, config.hidden_size, bias=False, ) @@ -1337,7 +1337,7 @@ def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None: if use_fc_norm: self.fc_norm = nn.ModuleList( [ - LlamaRMSNorm(target_hidden_size, eps=config.rms_norm_eps) + LlamaRMSNorm(self.target_hidden_size, eps=config.rms_norm_eps) for _ in range(self.num_aux_hidden_states) ] ) @@ -1345,7 +1345,7 @@ def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None: self.fc_norm = None self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.norm_output = getattr(config, "norm_output", False) + self.norm_output = getattr(config, "norm_output", True) self.lm_head = nn.Linear( config.hidden_size, config.draft_vocab_size, bias=False ) @@ -1416,6 +1416,7 @@ def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def project_hidden_states(self, hidden_states: torch.Tensor) -> torch.Tensor: + assert hidden_states.size(-1) == self.target_hidden_size * self.num_aux_hidden_states if self.fc_norm is not None: chunks = hidden_states.chunk(self.num_aux_hidden_states, dim=-1) hidden_states = torch.cat( @@ -1425,7 +1426,8 @@ def project_hidden_states(self, hidden_states: torch.Tensor) -> torch.Tensor: return self.fc(hidden_states) def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: - norm_hidden_states = self.norm(hidden_states) + if self.norm_output: + norm_hidden_states = self.norm(hidden_states) return self.lm_head(norm_hidden_states) def backbone( From 3e629721364f642b2ac7b7f67638a105f470f97d Mon Sep 17 00:00:00 2001 From: bluecoffee8 Date: Thu, 28 May 2026 20:24:51 -0700 Subject: [PATCH 5/5] fix --- specforge/modeling/draft/llama3_eagle.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py index 1d2d3b7e3..4cbd3acf6 100644 --- a/specforge/modeling/draft/llama3_eagle.py +++ b/specforge/modeling/draft/llama3_eagle.py @@ -1428,6 +1428,8 @@ def project_hidden_states(self, hidden_states: torch.Tensor) -> torch.Tensor: def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: if self.norm_output: norm_hidden_states = self.norm(hidden_states) + else: + norm_hidden_states = hidden_states return self.lm_head(norm_hidden_states) def backbone(