Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions configs/gemma3-27b-eagle3.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
{
"architectures": [
"LlamaForCausalLMEagle3"
],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 2,
"eos_token_id": 1,
"pad_token_id": 0,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 5376,
"initializer_range": 0.02,
"intermediate_size": 21504,
"max_position_embeddings": 4096,
"model_type": "llama",
"num_attention_heads": 32,
"num_hidden_layers": 1,
"num_key_value_heads": 16,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 1000000,
"sliding_window": 512,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.50.0",
"use_cache": true,
"use_sliding_window": false,
"vocab_size": 262208,
"draft_vocab_size": 262208,
"target_model_type": "gemma3_text",
"eagle_config": {
"eagle_aux_hidden_state_layer_ids": [1, 29, 61]
},
"use_aux_norm": true,
"reuse_target_lm_head": true
}
35 changes: 35 additions & 0 deletions configs/gemma4-26b-a4b-eagle3.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
{
"architectures": [
"LlamaForCausalLMEagle3"
],
"attention_bias": false,
"attention_dropout": 0.0,
"bos_token_id": 2,
"eos_token_id": 1,
"pad_token_id": 0,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 2816,
"initializer_range": 0.02,
"intermediate_size": 2112,
"max_position_embeddings": 4096,
"model_type": "llama",
"num_attention_heads": 32,
"num_hidden_layers": 1,
"num_key_value_heads": 8,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 1000000.0,
"sliding_window": null,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.50.0",
"use_cache": true,
"use_sliding_window": false,
"vocab_size": 262144,
"draft_vocab_size": 262144,
"target_model_type": "gemma4_text",
"eagle_config": {
"eagle_aux_hidden_state_layer_ids": [5, 17, 29]
}
}
32 changes: 32 additions & 0 deletions examples/run_gemma3_27b_eagle3_online.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Launch online EAGLE3 draft-model training for google/gemma-3-27b-it.
#
# Usage: run_gemma3_27b_eagle3_online.sh [NUM_GPUS] [TP_SIZE]
#   NUM_GPUS  number of processes torchrun starts on this node (default: 8)
#   TP_SIZE   value forwarded to --tp-size (default: 8)

# Resolve the repository root from this script's own location so the script
# can be invoked from any working directory.
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname "$SCRIPT_DIR")
export TORCHINDUCTOR_CACHE_DIR="$ROOT_DIR/cache/compiled_kernels"

# train eagle3 for gemma3-27b
NUM_GPUS=${1:-8}
TP_SIZE=${2:-8}

# Every expansion below is quoted so a checkout path containing spaces or glob
# characters is passed to torchrun as a single argument.
# NOTE(review): configs/gemma3-27b-eagle3.json sets "reuse_target_lm_head", but
# train_eagle3.py only loads and freezes the target lm_head when the
# --reuse-target-lm-head flag is given - confirm whether it should be added.
torchrun \
    --standalone \
    --nproc_per_node "$NUM_GPUS" \
    "$ROOT_DIR/scripts/train_eagle3.py" \
    --target-model-path google/gemma-3-27b-it \
    --draft-model-config "$ROOT_DIR/configs/gemma3-27b-eagle3.json" \
    --train-data-path "$ROOT_DIR/cache/dataset/ultrachat_train.jsonl" \
    --output-dir "$ROOT_DIR/outputs/gemma3-27b-eagle3" \
    --num-epochs 8 \
    --batch-size 2 \
    --draft-accumulation-steps 2 \
    --tp-size "$TP_SIZE" \
    --learning-rate 1e-4 \
    --max-length 2048 \
    --chat-template gemma \
    --cache-dir "$ROOT_DIR/cache" \
    --attention-backend sdpa \
    --target-model-backend hf \
    --log-interval 200 \
    --eval-interval 5000 \
    --save-interval 10000 \
    --build-dataset-num-proc 64 \
    --report-to tensorboard \
    --embedding-key=language_model.model.embed_tokens.weight
33 changes: 33 additions & 0 deletions examples/run_gemma4_26b_eagle3_online.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Launch online EAGLE3 draft-model training for google/gemma-4-26b-a4b-it
# on a preformatted training set (--is-preformatted).
#
# Usage: run_gemma4_26b_eagle3_online.sh [NUM_GPUS] [TP_SIZE]
#   NUM_GPUS  number of processes torchrun starts on this node (default: 8)
#   TP_SIZE   value forwarded to --tp-size (default: 2)

# Resolve the repository root from this script's own location so the script
# can be invoked from any working directory.
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname "$SCRIPT_DIR")
export TORCHINDUCTOR_CACHE_DIR="$ROOT_DIR/cache/compiled_kernels"

# train eagle3 for gemma4-26b-a4b
NUM_GPUS=${1:-8}
TP_SIZE=${2:-2}

# Every expansion below is quoted so a checkout path containing spaces or glob
# characters is passed to torchrun as a single argument.
torchrun \
    --standalone \
    --nproc_per_node "$NUM_GPUS" \
    "$ROOT_DIR/scripts/train_eagle3.py" \
    --target-model-path google/gemma-4-26b-a4b-it \
    --draft-model-config "$ROOT_DIR/configs/gemma4-26b-a4b-eagle3.json" \
    --train-data-path \
    "$ROOT_DIR/outputs/dataset/ultrachat_regen_gemma4_preformatted.jsonl" \
    --is-preformatted \
    --output-dir "$ROOT_DIR/outputs/gemma4-26b-a4b-eagle3" \
    --num-epochs 8 \
    --batch-size 4 \
    --tp-size "$TP_SIZE" \
    --learning-rate 1e-4 \
    --max-length 2048 \
    --chat-template gemma-4 \
    --cache-dir "$ROOT_DIR/cache" \
    --attention-backend sdpa \
    --target-model-backend hf \
    --log-interval 200 \
    --eval-interval 5000 \
    --save-interval 10000 \
    --build-dataset-num-proc 64 \
    --report-to tensorboard \
    --embedding-key=model.language_model.embed_tokens.weight
38 changes: 35 additions & 3 deletions scripts/train_eagle3.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import hashlib
import json
import math
import os
import time
Expand Down Expand Up @@ -86,6 +87,13 @@ def parse_args() -> Tuple[ArgumentParser, Namespace]:
default="lm_head.weight",
help="The key of the lm head weight to load from the target model, this is only required for offline training",
)
model_group.add_argument(
"--reuse-target-lm-head",
action="store_true",
help="Load the target model's lm_head weights into the draft model's lm_head "
"and freeze it. Supports both tied and untied target models. "
"Requires draft_vocab_size == vocab_size.",
)
model_group.add_argument(
"--is-vlm", action="store_true", help="Whether the target model is a VLM"
)
Expand Down Expand Up @@ -339,6 +347,7 @@ def sanity_check(args: Namespace) -> None:
"""
args.dp_size = dist.get_world_size() // args.tp_size
args.target_batch_size = args.tp_size * args.batch_size

if args.attention_backend == "usp":
sp_sanity_check(args)

Expand Down Expand Up @@ -433,6 +442,14 @@ def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module]

draft_model.load_embedding(args.target_model_path, embedding_key=args.embedding_key)
draft_model.freeze_embedding()
if args.reuse_target_lm_head:
draft_model.load_lm_head(
args.target_model_path,
lm_head_key=args.lm_head_key,
embedding_key=args.embedding_key,
)
draft_model.freeze_lm_head()
print_on_rank0("Loaded and froze lm_head from target model")
return draft_model_config, draft_model, ckpt_info, resume_state


Expand Down Expand Up @@ -587,6 +604,16 @@ def save_checkpoints(
epoch_output_dir,
state_dict=draft_model_state_dict,
)
# Overwrite config.json with the original training config to avoid
# transformers v5 mutating rope_scaling/rope_parameters and other
# fields in model.config during save_pretrained.
if getattr(args, "draft_model_config", None):
config_path = os.path.join(epoch_output_dir, "config.json")
with open(args.draft_model_config) as f:
original_config = json.load(f)
with open(config_path, "w") as f:
json.dump(original_config, f, indent=2)
print_on_rank0(f"Overwrote config.json with original training config")
print_on_rank0(f"Saved model configuration to {epoch_output_dir}")
dist.barrier()

Expand Down Expand Up @@ -758,9 +785,14 @@ def main():
args, draft_model_config, processor
)

# we load the vocab mapping then
draft_model.load_vocab_mapping(vocab_mapping_path)
print_with_rank("Loaded vocab mapping")
# Load the vocab mapping, unless draft_vocab_size == target_vocab_size (no mapping file exists then)
if vocab_mapping_path is not None:
draft_model.load_vocab_mapping(vocab_mapping_path)
print_with_rank("Loaded vocab mapping")
else:
print_with_rank(
"Skipped vocab mapping loading (draft_vocab_size == target_vocab_size)"
)

# Calculate total steps if not provided
if args.total_steps is None:
Expand Down
39 changes: 32 additions & 7 deletions specforge/core/eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,13 @@ def forward(
position_ids: (batch, seq_len)
"""
# Step 1: handle vocab size
# When draft_vocab_size == vocab_size, skip the d2t/t2d mapping entirely.
use_vocab_mapping = (
self.draft_model.draft_vocab_size != self.draft_model.vocab_size
)
target_p_padded, position_mask = _compute_target_p_padded(
target=target,
t2d=self.draft_model.t2d,
t2d=self.draft_model.t2d if use_vocab_mapping else None,
loss_mask=loss_mask,
length=self.length,
)
Expand Down Expand Up @@ -439,9 +443,13 @@ def forward(
)

# Step 1: handle vocab size
# When draft_vocab_size == vocab_size, skip the d2t/t2d mapping entirely.
use_vocab_mapping = (
self.draft_model.draft_vocab_size != self.draft_model.vocab_size
)
target_p_padded, position_mask = _compute_target_p_padded(
target=target,
t2d=self.draft_model.t2d,
t2d=self.draft_model.t2d if use_vocab_mapping else None,
loss_mask=loss_mask,
length=self.length,
)
Expand Down Expand Up @@ -567,11 +575,18 @@ def forward(

def _compute_target_p_padded(target, t2d, loss_mask, length):
with torch.no_grad():
target_p, position_mask = _compute_target_p(
target=target,
t2d=t2d,
loss_mask=loss_mask,
)
if t2d is None:
# draft_vocab_size == target_vocab_size: skip d2t/t2d mapping
target_p, position_mask = _compute_target_p_full_vocab(
target=target,
loss_mask=loss_mask,
)
else:
target_p, position_mask = _compute_target_p(
target=target,
t2d=t2d,
loss_mask=loss_mask,
)

assert len(target_p.shape) == 3
target_p_padded = F.pad(
Expand All @@ -585,6 +600,16 @@ def _compute_target_p_padded(target, t2d, loss_mask, length):
return target_p_padded, position_mask


@torch.compile(dynamic=None)
def _compute_target_p_full_vocab(target, loss_mask):
    """Convert target logits to probabilities when no vocab mapping is needed.

    Used when draft_vocab_size == target_vocab_size, so the logits are not
    subset through a t2d mask before the softmax.

    Args:
        target: Target-model logits; the softmax is taken over dim 2.
        loss_mask: Mask that is handed back unchanged as the position mask.

    Returns:
        A ``(target_p, position_mask)`` tuple: float32 probabilities detached
        from the autograd graph, and ``loss_mask`` itself.
    """
    # Upcast first so the softmax is evaluated in float32.
    probabilities = torch.softmax(target.float(), dim=2).detach()
    # Every target token exists in the draft vocab, so nothing extra is masked.
    return probabilities, loss_mask


@torch.compile(dynamic=None)
def _compute_target_p(target, t2d, loss_mask):
target_head = target
Expand Down
2 changes: 1 addition & 1 deletion specforge/core/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def _compute_loss(logits, target_p, position_mask):
def _calculate_settings(n):
# reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43

MAX_FUSED_SIZE = 131072
MAX_FUSED_SIZE = 262208
BLOCK_SIZE = triton.next_power_of_2(n)
if BLOCK_SIZE > MAX_FUSED_SIZE:
raise RuntimeError(
Expand Down
13 changes: 10 additions & 3 deletions specforge/data/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,6 @@ def build_offline_eagle3_dataset(
ttt_length: int = 1,
use_usp_preprocess: bool = False,
) -> torch.utils.data.Dataset:

return OfflineEagle3Dataset(
list_local_files(hidden_states_path),
max_len=max_len,
Expand All @@ -683,7 +682,7 @@ def generate_vocab_mapping_file(
draft_vocab_size: int,
cache_dir: str = "./cache/vocab_mapping",
cache_key: str = "vocab_mapping",
) -> str:
) -> Optional[str]:
"""
Generate a vocab mapping file for the dataset.

Expand All @@ -695,8 +694,16 @@ def generate_vocab_mapping_file(
cache_key: The key to use for caching the vocab mapping file.

Returns:
The path to the vocab mapping file.
The path to the vocab mapping file, or None if draft_vocab_size
equals target_vocab_size (no mapping needed).
"""
if draft_vocab_size == target_vocab_size:
print(
f"draft_vocab_size ({draft_vocab_size}) == target_vocab_size "
f"({target_vocab_size}), skipping vocab mapping generation."
)
return None

# prepare cache directory
os.makedirs(cache_dir, exist_ok=True)
vocab_mapping_path = os.path.join(cache_dir, f"{cache_key}.pt")
Expand Down
14 changes: 13 additions & 1 deletion specforge/data/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def get_all_template_names(self) -> List[str]:
template=ChatTemplate(
assistant_header="<start_of_turn>model\n",
user_header="<start_of_turn>user\n",
system_prompt="You are a helpful assistant.",
system_prompt=None,
end_of_turn_token="<end_of_turn>\n",
),
)
Expand Down Expand Up @@ -324,3 +324,15 @@ def get_all_template_names(self) -> List[str]:
enable_thinking=True,
),
)

# Chat template registered under the name "gemma-4": each turn opens with
# "<|turn>" followed by the role ("user" / "model") and closes with "<turn|>".
# NOTE(review): system_prompt is "" here, while the template in this file that
# uses "<start_of_turn>" headers sets system_prompt=None - confirm ChatTemplate
# treats the two the same way.
TEMPLATE_REGISTRY.register(
name="gemma-4",
template=ChatTemplate(
assistant_header="<|turn>model\n",
user_header="<|turn>user\n",
system_prompt="",
end_of_turn_token="<turn|>\n",
parser_type="thinking",
enable_thinking=True,
),
)
Loading
Loading