From 078d71c22d3d4f46524ba8009a31313970a9e34b Mon Sep 17 00:00:00 2001 From: Pengyu Chen Date: Fri, 24 Apr 2026 23:40:01 +0000 Subject: [PATCH 1/4] gemma3 eagle3 training --- configs/gemma3-27b-eagle3.json | 37 +++++ configs/gemma4-26b-a4b-eagle3.json | 35 ++++ examples/run_gemma3_27b_eagle3_online.sh | 34 ++++ examples/run_gemma4_26b_eagle3_online.sh | 35 ++++ scripts/train_eagle3.py | 149 ++++++++++++++++-- specforge/core/eagle3.py | 39 ++++- specforge/data/preprocessing.py | 24 ++- specforge/data/template.py | 20 ++- specforge/modeling/draft/base.py | 148 ++++++++++------- specforge/modeling/draft/llama3_eagle.py | 118 +++++++++++--- .../modeling/target/custom_backend/gpt_oss.py | 2 - .../modeling/target/custom_backend/llama.py | 2 - .../modeling/target/custom_backend/llama4.py | 2 - .../modeling/target/custom_backend/phi3.py | 2 - .../modeling/target/eagle3_target_model.py | 17 +- specforge/utils.py | 4 + 16 files changed, 542 insertions(+), 126 deletions(-) create mode 100644 configs/gemma3-27b-eagle3.json create mode 100644 configs/gemma4-26b-a4b-eagle3.json create mode 100755 examples/run_gemma3_27b_eagle3_online.sh create mode 100755 examples/run_gemma4_26b_eagle3_online.sh diff --git a/configs/gemma3-27b-eagle3.json b/configs/gemma3-27b-eagle3.json new file mode 100644 index 000000000..9c8a6ff58 --- /dev/null +++ b/configs/gemma3-27b-eagle3.json @@ -0,0 +1,37 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 2, + "eos_token_id": 1, + "pad_token_id": 0, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 5376, + "initializer_range": 0.02, + "intermediate_size": 21504, + "max_position_embeddings": 4096, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 1, + "num_key_value_heads": 16, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000, + "sliding_window": 512, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.50.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 262208, + "draft_vocab_size": 262208, + "target_model_type": "gemma3_text", + "eagle_config": { + "eagle_aux_hidden_state_layer_ids": [1, 29, 61] + }, + "additional_fc": false, + "reuse_target_lm_head": true +} diff --git a/configs/gemma4-26b-a4b-eagle3.json b/configs/gemma4-26b-a4b-eagle3.json new file mode 100644 index 000000000..f50e1345b --- /dev/null +++ b/configs/gemma4-26b-a4b-eagle3.json @@ -0,0 +1,35 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 2, + "eos_token_id": 1, + "pad_token_id": 0, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 2816, + "initializer_range": 0.02, + "intermediate_size": 2112, + "max_position_embeddings": 4096, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 1, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000.0, + "sliding_window": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.50.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 262144, + "draft_vocab_size": 262144, + "target_model_type": "gemma4_text", + "eagle_config": { + "eagle_aux_hidden_state_layer_ids": [5, 17, 29] + } +} diff --git a/examples/run_gemma3_27b_eagle3_online.sh b/examples/run_gemma3_27b_eagle3_online.sh new file mode 100755 index 000000000..e3709e32e --- /dev/null +++ b/examples/run_gemma3_27b_eagle3_online.sh @@ -0,0 +1,34 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +# train eagle3 for gemma3-27b +NUM_GPUS=${1:-8} +TP_SIZE=${2:-8} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path google/gemma-3-27b-it \ + --draft-model-config $ROOT_DIR/configs/gemma3-27b-eagle3.json \ + --train-data-path $ROOT_DIR/cache/dataset/ultrachat_train.jsonl \ + $ROOT_DIR/cache/dataset/translate_bp_regen_gemma3_train.jsonl \ + --output-dir $ROOT_DIR/outputs/gemma3-27b-eagle3-ucml-mix-l-aq \ + --num-epochs 8 \ + --batch-size 2 \ + --draft-accumulation-steps 2 \ + --tp-size $TP_SIZE \ + --learning-rate 1e-4 \ + --max-length 2048 \ + --chat-template gemma \ + --cache-dir $ROOT_DIR/cache \ + --attention-backend sdpa \ + --target-model-backend hf \ + --log-interval 200 \ + --eval-interval 5000 \ + --save-interval 10000 \ + --build-dataset-num-proc 64 \ + --report-to tensorboard \ + --eval-holdout-ratio 0.005 \ + --embedding-key=language_model.model.embed_tokens.weight diff --git a/examples/run_gemma4_26b_eagle3_online.sh b/examples/run_gemma4_26b_eagle3_online.sh new file mode 100755 index 000000000..5af333c73 --- /dev/null +++ b/examples/run_gemma4_26b_eagle3_online.sh @@ -0,0 +1,35 @@ +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels + +# train eagle3 for gemma4-26b-a4b +NUM_GPUS=${1:-8} +TP_SIZE=${2:-2} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path google/gemma-4-26b-a4b-it \ + --draft-model-config $ROOT_DIR/configs/gemma4-26b-a4b-eagle3.json \ + --train-data-path \ + $ROOT_DIR/outputs/dataset/ultrachat_regen_gemma4_preformatted.jsonl \ + $ROOT_DIR/outputs/dataset/translate_bp_regen_gemma4_preformatted.jsonl \ + --is-preformatted \ + --output-dir $ROOT_DIR/outputs/gemma4-26b-a4b-eagle3-ucml \ + --num-epochs 8 \ + --batch-size 4 \ + --tp-size $TP_SIZE \ + --learning-rate 1e-4 \ + --max-length 2048 \ + --chat-template gemma-4 \ + --cache-dir $ROOT_DIR/cache \ + --attention-backend sdpa \ + --target-model-backend hf \ + --log-interval 200 \ + --eval-interval 5000 \ + --save-interval 10000 \ + --build-dataset-num-proc 64 \ + --report-to tensorboard \ + --embedding-key=model.language_model.embed_tokens.weight \ + --eval-holdout-ratio 0.005 diff --git a/scripts/train_eagle3.py b/scripts/train_eagle3.py index 0bd157b39..44374d79c 100644 --- a/scripts/train_eagle3.py +++ b/scripts/train_eagle3.py @@ -1,8 +1,10 @@ import argparse +import glob import hashlib import math import os import time +from datetime import datetime from argparse import ArgumentParser, Namespace from typing import List, Optional, Tuple, Union @@ -86,6 +88,13 @@ def parse_args() -> Tuple[ArgumentParser, Namespace]: default="lm_head.weight", help="The key of the lm head weight to load from the target model, this is only required for offline training", ) + model_group.add_argument( + "--reuse-target-lm-head", + action="store_true", + help="Load the target model's lm_head weights into the draft model's lm_head " + "and freeze it. Supports both tied and untied target models. " + "Requires draft_vocab_size == vocab_size.", + ) model_group.add_argument( "--is-vlm", action="store_true", help="Whether the target model is a VLM" ) @@ -99,10 +108,17 @@ def parse_args() -> Tuple[ArgumentParser, Namespace]: # dataset arguments dataset_group = parser.add_argument_group("dataset") - dataset_group.add_argument("--train-data-path", type=str, required=True) + dataset_group.add_argument("--train-data-path", type=str, nargs="+", required=True) dataset_group.add_argument("--train-hidden-states-path", type=str, default=None) dataset_group.add_argument("--eval-hidden-states-path", type=str, default=None) dataset_group.add_argument("--eval-data-path", type=str, default=None) + dataset_group.add_argument( + "--eval-holdout-ratio", + type=float, + default=None, + help="Fraction of the training dataset to hold out for evaluation (0 to 1). " + "Mutually exclusive with --eval-data-path and --eval-hidden-states-path.", + ) dataset_group.add_argument("--chat-template", type=str, default="llama3") dataset_group.add_argument( "--is-preformatted", @@ -339,6 +355,19 @@ def sanity_check(args: Namespace) -> None: """ args.dp_size = dist.get_world_size() // args.tp_size args.target_batch_size = args.tp_size * args.batch_size + + if args.eval_holdout_ratio is not None: + if not (0 < args.eval_holdout_ratio < 1): + raise ValueError( + f"--eval-holdout-ratio must be between 0 and 1 (exclusive), " + f"got {args.eval_holdout_ratio}" + ) + if args.eval_data_path is not None or args.eval_hidden_states_path is not None: + raise ValueError( + "--eval-holdout-ratio is mutually exclusive with " + "--eval-data-path and --eval-hidden-states-path" + ) + if args.attention_backend == "usp": sp_sanity_check(args) @@ -347,9 +376,9 @@ def sp_sanity_check(args: Namespace) -> None: args.draft_accumulation_steps = ( args.draft_accumulation_steps * args.sp_ulysses_size * args.sp_ring_size ) - assert ( - args.batch_size == 1 - ), f"USP only supports batch_size=1, got batch_size={args.batch_size}" + assert args.batch_size == 1, ( + f"USP only supports batch_size=1, got batch_size={args.batch_size}" + ) assert args.sp_ring_size * args.sp_ulysses_size > 1, ( f"USP requires sp_ring_size * sp_ulysses_size > 1. " @@ -433,6 +462,14 @@ def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module] draft_model.load_embedding(args.target_model_path, embedding_key=args.embedding_key) draft_model.freeze_embedding() + if args.reuse_target_lm_head: + draft_model.load_lm_head( + args.target_model_path, + lm_head_key=args.lm_head_key, + embedding_key=args.embedding_key, + ) + draft_model.freeze_lm_head() + print_on_rank0("Loaded and froze lm_head from target model") return draft_model_config, draft_model, ckpt_info, resume_state @@ -440,24 +477,55 @@ def build_dataloaders( args: Namespace, draft_model_config: AutoDraftModelConfig, processor: Optional[AutoProcessor] = None, -) -> Tuple[DataLoader, str, Optional[DataLoader]]: +) -> Tuple[DataLoader, Optional[str], Optional[DataLoader]]: # build dataloaders tokenizer = AutoTokenizer.from_pretrained( args.target_model_path, trust_remote_code=args.trust_remote_code ) + # Resolve all training data paths: expand directories to their .jsonl files + resolved_train_files = [] + for path in args.train_data_path: + if os.path.isdir(path): + jsonl_files = sorted(glob.glob(os.path.join(path, "*.jsonl"))) + if not jsonl_files: + raise ValueError(f"No .jsonl files found in directory: {path}") + resolved_train_files.extend(jsonl_files) + elif os.path.isfile(path): + resolved_train_files.append(path) + else: + raise ValueError(f"Training data path does not exist: {path}") + print_on_rank0( + f"Resolved {len(resolved_train_files)} training file(s) from " + f"{len(args.train_data_path)} path(s)" + ) + # convert to dataloader cache_params_string = ( - f"{args.train_data_path}-" + f"{','.join(sorted(resolved_train_files))}-" f"{args.max_length}-" f"{args.chat_template}-" f"{args.target_model_path}" # Tokenizer may also different ) cache_key = hashlib.md5(cache_params_string.encode()).hexdigest() - train_dataset = Dataset.from_generator( - generator=safe_conversations_generator, - gen_kwargs={"file_path": args.train_data_path}, + + # Build datasets from all resolved files and concatenate + from datasets import concatenate_datasets + + train_datasets = [] + for file_path in resolved_train_files: + ds = Dataset.from_generator( + generator=safe_conversations_generator, + gen_kwargs={"file_path": file_path}, + ) + train_datasets.append(ds) + train_dataset = ( + concatenate_datasets(train_datasets) + if len(train_datasets) > 1 + else train_datasets[0] ) + print_on_rank0(f"Combined training dataset: {len(train_dataset)} examples") + is_online = ( args.train_data_path is not None and args.train_hidden_states_path is None ) @@ -491,6 +559,21 @@ def build_dataloaders( use_usp_preprocess=(args.attention_backend == "usp"), ) + # Split a holdout portion from the training set if requested. + eval_eagle3_dataset_from_holdout = None + if args.eval_holdout_ratio is not None and args.eval_holdout_ratio > 0: + split = train_eagle3_dataset.train_test_split( + test_size=args.eval_holdout_ratio, + seed=args.seed, + ) + train_eagle3_dataset = split["train"] + eval_eagle3_dataset_from_holdout = split["test"] + print_on_rank0( + f"Holdout split: {len(train_eagle3_dataset)} train, " + f"{len(eval_eagle3_dataset_from_holdout)} eval " + f"(ratio={args.eval_holdout_ratio})" + ) + train_dataloader = prepare_dp_dataloaders( train_eagle3_dataset, args.target_batch_size, @@ -503,7 +586,13 @@ def build_dataloaders( ), is_vlm=args.is_vlm, ) - if args.eval_data_path is not None or args.eval_hidden_states_path is not None: + + has_eval = ( + args.eval_data_path is not None + or args.eval_hidden_states_path is not None + or eval_eagle3_dataset_from_holdout is not None + ) + if has_eval: if args.eval_data_path is not None: eval_dataset = Dataset.from_generator( generator=safe_conversations_generator, @@ -527,6 +616,8 @@ def build_dataloaders( ttt_length=args.ttt_length, use_usp_preprocess=(args.attention_backend == "usp"), ) + else: + eval_eagle3_dataset = eval_eagle3_dataset_from_holdout eval_dataloader = prepare_dp_dataloaders( eval_eagle3_dataset, args.target_batch_size, @@ -587,6 +678,18 @@ def save_checkpoints( epoch_output_dir, state_dict=draft_model_state_dict, ) + # Overwrite config.json with the original training config to avoid + # transformers v5 mutating rope_scaling/rope_parameters and other + # fields in model.config during save_pretrained. + if getattr(args, "draft_model_config", None): + import json + + config_path = os.path.join(epoch_output_dir, "config.json") + with open(args.draft_model_config) as f: + original_config = json.load(f) + with open(config_path, "w") as f: + json.dump(original_config, f, indent=2) + print_on_rank0(f"Overwrote config.json with original training config") print_on_rank0(f"Saved model configuration to {epoch_output_dir}") dist.barrier() @@ -742,6 +845,16 @@ def main(): ) sanity_check(args) + + # Create a datetime subfolder for this run (skip when resuming into an + # existing output directory so that checkpoints stay in the same place). + if not args.resume: + run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + args.output_dir = os.path.join(args.output_dir, run_timestamp) + if dist.get_rank() == 0: + os.makedirs(args.output_dir, exist_ok=True) + dist.barrier() + print_args_with_dots(args) print_with_rank("Initialized distributed environment") @@ -758,9 +871,14 @@ def main(): args, draft_model_config, processor ) - # we load the vocab mapping then - draft_model.load_vocab_mapping(vocab_mapping_path) - print_with_rank("Loaded vocab mapping") + # we load the vocab mapping then (skip when draft_vocab_size == target_vocab_size) + if vocab_mapping_path is not None: + draft_model.load_vocab_mapping(vocab_mapping_path) + print_with_rank("Loaded vocab mapping") + else: + print_with_rank( + "Skipped vocab mapping loading (draft_vocab_size == target_vocab_size)" + ) # Calculate total steps if not provided if args.total_steps is None: @@ -946,10 +1064,7 @@ def main(): # ================================================ # 7.2 Evaluation Step # ================================================ - should_evaluate = ( - args.eval_data_path is not None - or args.eval_hidden_states_path is not None - ) + should_evaluate = eval_dataloader is not None if ( should_evaluate and global_step % (args.eval_interval * args.draft_accumulation_steps) diff --git a/specforge/core/eagle3.py b/specforge/core/eagle3.py index 1e2f04e7e..893af4fe6 100644 --- a/specforge/core/eagle3.py +++ b/specforge/core/eagle3.py @@ -153,9 +153,13 @@ def forward( position_ids: (batch, seq_len) """ # Step 1: handle vocab size + # When draft_vocab_size == vocab_size, skip the d2t/t2d mapping entirely. + use_vocab_mapping = ( + self.draft_model.draft_vocab_size != self.draft_model.vocab_size + ) target_p_padded, position_mask = _compute_target_p_padded( target=target, - t2d=self.draft_model.t2d, + t2d=self.draft_model.t2d if use_vocab_mapping else None, loss_mask=loss_mask, length=self.length, ) @@ -439,9 +443,13 @@ def forward( ) # Step 1: handle vocab size + # When draft_vocab_size == vocab_size, skip the d2t/t2d mapping entirely. + use_vocab_mapping = ( + self.draft_model.draft_vocab_size != self.draft_model.vocab_size + ) target_p_padded, position_mask = _compute_target_p_padded( target=target, - t2d=self.draft_model.t2d, + t2d=self.draft_model.t2d if use_vocab_mapping else None, loss_mask=loss_mask, length=self.length, ) @@ -567,11 +575,18 @@ def forward( def _compute_target_p_padded(target, t2d, loss_mask, length): with torch.no_grad(): - target_p, position_mask = _compute_target_p( - target=target, - t2d=t2d, - loss_mask=loss_mask, - ) + if t2d is None: + # draft_vocab_size == target_vocab_size: skip d2t/t2d mapping + target_p, position_mask = _compute_target_p_full_vocab( + target=target, + loss_mask=loss_mask, + ) + else: + target_p, position_mask = _compute_target_p( + target=target, + t2d=t2d, + loss_mask=loss_mask, + ) assert len(target_p.shape) == 3 target_p_padded = F.pad( @@ -585,6 +600,16 @@ def _compute_target_p_padded(target, t2d, loss_mask, length): return target_p_padded, position_mask +@torch.compile(dynamic=None) +def _compute_target_p_full_vocab(target, loss_mask): + """Fast path when draft_vocab_size == target_vocab_size (no vocab subsetting).""" + target_head = target.float() + target_p = nn.Softmax(dim=2)(target_head) + target_p = target_p.detach() + # All target tokens are in the draft vocab, so position_mask == loss_mask. + return target_p, loss_mask + + @torch.compile(dynamic=None) def _compute_target_p(target, t2d, loss_mask): target_head = target diff --git a/specforge/data/preprocessing.py b/specforge/data/preprocessing.py index d5af9479e..fd3445c46 100644 --- a/specforge/data/preprocessing.py +++ b/specforge/data/preprocessing.py @@ -123,7 +123,7 @@ def preprocess_conversations( max_length: int = 2048, is_preformatted: bool = False, train_only_last_turn: bool = False, - tools: Optional[List[List[Dict]]] = [[]], + tools: Optional[List[List[Dict]]] = None, **kwargs, ) -> Dict[str, List[torch.Tensor]]: """ @@ -155,6 +155,9 @@ def preprocess_conversations( parser = HarmonyParser(tokenizer, chat_template) else: raise ValueError(f"Invalid parser type: {chat_template.parser_type}") + # Ensure tools list matches conversations length + if tools is None or len(tools) != len(conversations): + tools = [[] for _ in range(len(conversations))] kwargs_list = [{} for _ in range(len(conversations))] for key, value_list in kwargs.items(): for i, value in enumerate(value_list): @@ -344,9 +347,9 @@ def build_eagle3_dataset( if chat_template is None: raise ValueError("chat_template must be provided for all dataset types") - assert ( - chat_template in TEMPLATE_REGISTRY.get_all_template_names() - ), f"Chat template {chat_template} not found in TEMPLATE_REGISTRY, you may need to register it first" + assert chat_template in TEMPLATE_REGISTRY.get_all_template_names(), ( + f"Chat template {chat_template} not found in TEMPLATE_REGISTRY, you may need to register it first" + ) template: ChatTemplate = TEMPLATE_REGISTRY.get(chat_template) @@ -665,7 +668,6 @@ def build_offline_eagle3_dataset( ttt_length: int = 1, use_usp_preprocess: bool = False, ) -> torch.utils.data.Dataset: - return OfflineEagle3Dataset( list_local_files(hidden_states_path), max_len=max_len, @@ -683,7 +685,7 @@ def generate_vocab_mapping_file( draft_vocab_size: int, cache_dir: str = "./cache/vocab_mapping", cache_key: str = "vocab_mapping", -) -> str: +) -> Optional[str]: """ Generate a vocab mapping file for the dataset. @@ -695,8 +697,16 @@ def generate_vocab_mapping_file( cache_key: The key to use for caching the vocab mapping file. Returns: - The path to the vocab mapping file. + The path to the vocab mapping file, or None if draft_vocab_size + equals target_vocab_size (no mapping needed). """ + if draft_vocab_size == target_vocab_size: + print( + f"draft_vocab_size ({draft_vocab_size}) == target_vocab_size " + f"({target_vocab_size}), skipping vocab mapping generation." + ) + return None + # prepare cache directory os.makedirs(cache_dir, exist_ok=True) vocab_mapping_path = os.path.join(cache_dir, f"{cache_key}.pt") diff --git a/specforge/data/template.py b/specforge/data/template.py index 4dde000fd..36381e69f 100644 --- a/specforge/data/template.py +++ b/specforge/data/template.py @@ -58,9 +58,9 @@ def register(self, name: str, template: ChatTemplate, override: bool = False): template(ChatTemplate): The chat template. override(bool): Whether to override the existing template, default to False """ - assert ( - not override and name not in self.templates - ), f"Chat template for the model type {name} has already been registered" + assert not override and name not in self.templates, ( + f"Chat template for the model type {name} has already been registered" + ) self.templates[name] = template def get(self, name: str) -> ChatTemplate: @@ -286,7 +286,7 @@ def get_all_template_names(self) -> List[str]: template=ChatTemplate( assistant_header="model\n", user_header="user\n", - system_prompt="You are a helpful assistant.", + system_prompt=None, end_of_turn_token="\n", ), ) @@ -324,3 +324,15 @@ def get_all_template_names(self) -> List[str]: enable_thinking=True, ), ) + +TEMPLATE_REGISTRY.register( + name="gemma-4", + template=ChatTemplate( + assistant_header="<|turn>model\n", + user_header="<|turn>user\n", + system_prompt="", + end_of_turn_token="\n", + parser_type="thinking", + enable_thinking=True, + ), +) diff --git a/specforge/modeling/draft/base.py b/specforge/modeling/draft/base.py index b5584a759..74e87a14e 100644 --- a/specforge/modeling/draft/base.py +++ b/specforge/modeling/draft/base.py @@ -114,64 +114,104 @@ def freeze_embedding(self) -> None: """ self.embed_tokens.weight.requires_grad = False + def freeze_lm_head(self) -> None: + """ + Freeze the lm_head of the draft model so that it is not updated during training. + """ + self.lm_head.weight.requires_grad = False + + @torch.no_grad() + def _load_tensor_from_checkpoint( + self, model_path: str, tensor_key: str + ) -> torch.Tensor: + """ + Load a single tensor from a target model checkpoint. + + Args: + model_path (str): Path to the target model. Can be either a Hugging Face + repository ID or a local directory path containing the model files. + tensor_key (str): The key of the tensor to load from the checkpoint. + + Returns: + torch.Tensor: The loaded tensor. + """ + if not os.path.exists(model_path): + # model_path is a huggingface repository, download first + model_path = snapshot_download(repo_id=model_path) + + glob_path = os.path.join(model_path, "*.index.json") + index_json_paths = glob.glob(glob_path) + + if len(index_json_paths) == 0: + # No index.json found, look for single model file + safetensors_path = os.path.join(model_path, "model.safetensors") + if os.path.exists(safetensors_path): + with safe_open(safetensors_path, framework="pt") as f: + return f.get_tensor(tensor_key) + + pytorch_model_path = os.path.join(model_path, "pytorch_model.bin") + if os.path.exists(pytorch_model_path): + state_dict = torch.load(pytorch_model_path, map_location="cpu") + return state_dict[tensor_key] + + raise FileNotFoundError( + f"No index.json, model.safetensors or pytorch_model.bin found in {model_path}" + ) + + if len(index_json_paths) > 1: + raise FileNotFoundError(f"Multiple index.json files found in {model_path}") + + with open(index_json_paths[0], "r") as f: + index_json = json.load(f) + ckpt_file = index_json["weight_map"][tensor_key] + + if ckpt_file.endswith(".safetensors"): + with safe_open(os.path.join(model_path, ckpt_file), framework="pt") as f: + return f.get_tensor(tensor_key) + else: + state_dict = torch.load(os.path.join(model_path, ckpt_file)) + return state_dict[tensor_key] + @torch.no_grad() def load_embedding( - self, model_path: str, embedding_key: str = "model.embed_tokens.weight" + self, model_path: str, embedding_key: str = "language_model.embed_tokens.weight" ) -> None: """ - Load the embedding of the draft model. + Load the embedding of the draft model from the target model checkpoint. Args: model_path (str): Path to the target model. Can be either a Hugging Face - repository ID or a local directory path containing the model files. - """ - if os.path.exists(model_path): - # model_path is a local directory - # check if there is file ending with index.json - glob_path = os.path.join(model_path, "*.index.json") - index_json_path = glob.glob(glob_path) - - if len(index_json_path) == 0: - # No index.json found, look for single model file - safetensors_path = os.path.join(model_path, "model.safetensors") - if os.path.exists(safetensors_path): - with safe_open(safetensors_path, framework="pt") as f: - self.embed_tokens.weight.copy_(f.get_tensor(embedding_key)) - return - - pytorch_model_path = os.path.join(model_path, "pytorch_model.bin") - if os.path.exists(pytorch_model_path): - state_dict = torch.load(pytorch_model_path, map_location="cpu") - self.embed_tokens.weight.copy_(state_dict[embedding_key]) - return - - raise FileNotFoundError( - f"No index.json, model.safetensors or pytorch_model.bin found in {model_path}" - ) - if len(index_json_path) > 1: - raise FileNotFoundError( - f"Multiple index.json files found in {model_path}" - ) - index_json_path = index_json_path[0] - - with open(index_json_path, "r") as f: - index_json = json.load(f) - ckpt_file = index_json["weight_map"][embedding_key] - - if ckpt_file.endswith(".safetensors"): - with safe_open( - os.path.join(model_path, ckpt_file), framework="pt" - ) as f: - emb_tokens = f.get_tensor(embedding_key) - else: - state_dict = torch.load(os.path.join(model_path, ckpt_file)) - emb_tokens = state_dict[embedding_key] - self.embed_tokens.weight.copy_(emb_tokens) - else: - # this is the case where model_path is a huggingface repository - # we first need to locate its local cache - local_cache_path = snapshot_download(repo_id=model_path) - self.load_embedding(local_cache_path, embedding_key) + repository ID or a local directory path containing the model files. + embedding_key (str): The key of the embedding weight in the checkpoint. + """ + tensor = self._load_tensor_from_checkpoint(model_path, embedding_key) + self.embed_tokens.weight.copy_(tensor) + + @torch.no_grad() + def load_lm_head( + self, model_path: str, lm_head_key: str, embedding_key: str + ) -> None: + """ + Load the lm_head of the draft model from the target model checkpoint. + + For models with tied weights (embed_tokens == lm_head), the lm_head key + may not exist in the checkpoint. In that case, falls back to loading + from the embedding key. + + Args: + model_path (str): Path to the target model. Can be either a Hugging Face + repository ID or a local directory path containing the model files. + lm_head_key (str): The key of the lm_head weight in the checkpoint. + embedding_key (str): Fallback key if lm_head_key is not found (for + models with tie_word_embeddings=True). + """ + try: + tensor = self._load_tensor_from_checkpoint(model_path, lm_head_key) + except KeyError: + # Target model ties weights -- lm_head key doesn't exist in checkpoint, + # fall back to embedding key + tensor = self._load_tensor_from_checkpoint(model_path, embedding_key) + self.lm_head.weight.copy_(tensor) def load_vocab_mapping(self, file_path: str) -> None: """ @@ -180,9 +220,9 @@ def load_vocab_mapping(self, file_path: str) -> None: Args: file_path (str): The path to the vocab mapping file. """ - assert hasattr(self, "t2d") and hasattr( - self, "d2t" - ), "t2d and d2t buffersare not found in the draft model, please check your draft model implementation" + assert hasattr(self, "t2d") and hasattr(self, "d2t"), ( + "t2d and d2t buffersare not found in the draft model, please check your draft model implementation" + ) vocab_mapping = torch.load(file_path) self.t2d.copy_(vocab_mapping["t2d"]) self.d2t.copy_(vocab_mapping["d2t"]) diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py index 268142c0c..7ca7e80e0 100644 --- a/specforge/modeling/draft/llama3_eagle.py +++ b/specforge/modeling/draft/llama3_eagle.py @@ -538,12 +538,24 @@ def __init__(self, config): ) self._init_rope() + def _get_rope_theta(self): + """Extract rope_theta from config, handling transformers v5 which moves + rope_theta into rope_scaling/rope_parameters instead of a top-level attr.""" + rope_theta = getattr(self.config, "rope_theta", None) + if rope_theta is not None: + return rope_theta + for attr in ("rope_parameters", "rope_scaling"): + params = getattr(self.config, attr, None) + if isinstance(params, dict) and "rope_theta" in params: + return params["rope_theta"] + raise RuntimeError("rope theta is not set.") + def _init_rope(self): if self.config.rope_scaling is None: self.rotary_emb = LlamaRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, - base=getattr(self.config, "rope_theta", 10000), + base=self._get_rope_theta(), ) else: rope_scaling = self.config.rope_scaling @@ -560,7 +572,7 @@ def rope_get(key, default=None): self.rotary_emb = LlamaRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, - base=getattr(self.config, "rope_theta", 10000), + base=self._get_rope_theta(), ) return elif scaling_type == "linear": @@ -925,9 +937,9 @@ def forward( k0 = cache_k[0] v0 = cache_v[0] - assert ( - flash_attn_func is not None - ), "flash_attn is not installed, please install flash_attn if you want to use the flash attention backend" + assert flash_attn_func is not None, ( + "flash_attn is not installed, please install flash_attn if you want to use the flash attention backend" + ) attn_output, lse, _ = flash_attn_func( query_states, k0, @@ -1008,7 +1020,6 @@ def forward( output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() local_q_len = q_len @@ -1099,9 +1110,9 @@ def forward( else: acc_lse = lse_ring - assert ( - acc_lse.shape[1] == current_q_len - ), f"LSE seq_len {acc_lse.shape[1]} mismatch with Query seq_len {current_q_len}" + assert acc_lse.shape[1] == current_q_len, ( + f"LSE seq_len {acc_lse.shape[1]} mismatch with Query seq_len {current_q_len}" + ) acc_out = out_ring @@ -1246,10 +1257,8 @@ def __init__(self, config, attention_backend: str = "sdpa"): self.attention_backend = attention_backend self.mlp = LlamaMLP(config) - # self.fc = nn.Linear(config.hidden_size * 2, config.hidden_size) self.hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - # if self.index!=0: self.post_attention_layernorm = LlamaRMSNorm( config.hidden_size, eps=config.rms_norm_eps @@ -1306,12 +1315,10 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - # outputs = (hidden_states, return_hidden) return hidden_states class LlamaForCausalLMEagle3(Eagle3DraftModel): - config_class = LlamaConfig def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None: @@ -1326,26 +1333,78 @@ def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None: ) self.midlayer = LlamaDecoderLayer(config, attention_backend=attention_backend) - if hasattr(config, "target_hidden_size"): - self.fc = torch.nn.Linear( - config.target_hidden_size * 3, config.hidden_size, bias=False - ) - else: - self.fc = torch.nn.Linear( - config.hidden_size * 3, config.hidden_size, bias=False - ) + + target_hidden_size = getattr(config, "target_hidden_size", config.hidden_size) + + # Per-layer RMSNorm applied to each aux hidden state before concatenation, + # so that all three layers contribute equally regardless of their raw scale. + self.aux_norm_low = LlamaRMSNorm(target_hidden_size, eps=config.rms_norm_eps) + self.aux_norm_mid = LlamaRMSNorm(target_hidden_size, eps=config.rms_norm_eps) + self.aux_norm_high = LlamaRMSNorm(target_hidden_size, eps=config.rms_norm_eps) + + self.fc = torch.nn.Linear( + target_hidden_size * 3, config.hidden_size, bias=False + ) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.lm_head = nn.Linear( config.hidden_size, config.draft_vocab_size, bias=False ) + # Embedding scale factor for target models that use scaled embeddings + # (e.g., Gemma3/Gemma4 multiply by hidden_size**0.5). Set via config + # field ``embed_scale`` or auto-detected from ``target_model_type``. + target_type = getattr(config, "target_model_type", None) or "" + if getattr(config, "embed_scale", None) is not None: + self.embed_scale = config.embed_scale + elif "gemma" in target_type: + self.embed_scale = config.hidden_size**0.5 + else: + self.embed_scale = 1.0 + # create vocab buffers t2d = torch.ones(self.vocab_size, dtype=torch.bool) d2t = torch.zeros(self.draft_vocab_size, dtype=torch.int64) self.register_buffer("t2d", t2d) self.register_buffer("d2t", d2t) + # Apply improved initialization for stable training with mixed + # pretrained (frozen) and randomly initialized (trainable) parameters. + self._init_weights() + + def _init_weights(self) -> None: + """ + Initialize weights for stable training when mixing pretrained frozen + components (embed_tokens, lm_head) with randomly initialized trainable + layers. + + Strategy: + - Zero-init residual projections (o_proj, down_proj) so the decoder + layer starts as near-identity through the residual stream. + - Use small normal init (std=0.02) for other projections instead of + Kaiming uniform, matching the config's initializer_range. + - RMSNorm weights stay at ones (PyTorch default). + """ + std = getattr(self.config, "initializer_range", 0.02) + + for name, param in self.named_parameters(): + if "embed_tokens" in name or "lm_head" in name: + # These will be overwritten by load_embedding / load_lm_head + continue + if "norm" in name or "layernorm" in name: + # RMSNorm weights: keep at ones (default) + continue + if param.dim() < 2: + # Biases and 1-d params: zero init + nn.init.zeros_(param) + continue + + # Zero-init residual-path projections for near-identity at start + if "o_proj" in name or "down_proj" in name: + nn.init.zeros_(param) + else: + nn.init.normal_(param, mean=0.0, std=std) + def forward( self, hidden_states: torch.Tensor, @@ -1403,11 +1462,24 @@ def forward( return hidden_states def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) + embeds = self.embed_tokens(input_ids) + if self.embed_scale != 1.0: + embeds = embeds * self.embed_scale + return embeds def project_hidden_states(self, hidden_states: torch.Tensor) -> torch.Tensor: # eagle 3 requires hidden states from 3 layers - assert hidden_states.size(-1) == self.config.hidden_size * 3 + target_h = getattr(self.config, "target_hidden_size", self.config.hidden_size) + assert hidden_states.size(-1) == target_h * 3 + + # Normalize each aux layer independently before fc projection, + # so all three contribute equally regardless of their raw scale. + h_low, h_mid, h_high = hidden_states.split(target_h, dim=-1) + h_low = self.aux_norm_low(h_low) + h_mid = self.aux_norm_mid(h_mid) + h_high = self.aux_norm_high(h_high) + hidden_states = torch.cat((h_low, h_mid, h_high), dim=-1) + return self.fc(hidden_states) def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: diff --git a/specforge/modeling/target/custom_backend/gpt_oss.py b/specforge/modeling/target/custom_backend/gpt_oss.py index b3b4a7972..9910633c8 100644 --- a/specforge/modeling/target/custom_backend/gpt_oss.py +++ b/specforge/modeling/target/custom_backend/gpt_oss.py @@ -36,7 +36,6 @@ from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm from transformers.processing_utils import Unpack from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple -from transformers.utils.generic import check_model_inputs from specforge.distributed import get_tp_group, shard_tensor from specforge.layers import ( @@ -585,7 +584,6 @@ def __init__(self, config: GptOssConfig): # Initialize weights and apply final processing self.post_init() - @check_model_inputs @auto_docstring def forward( self, diff --git a/specforge/modeling/target/custom_backend/llama.py b/specforge/modeling/target/custom_backend/llama.py index 04a3f6c9b..02a1c16c4 100644 --- a/specforge/modeling/target/custom_backend/llama.py +++ b/specforge/modeling/target/custom_backend/llama.py @@ -41,7 +41,6 @@ ) from transformers.processing_utils import Unpack from transformers.utils import TransformersKwargs, logging -from transformers.utils.generic import check_model_inputs from specforge.distributed import get_tp_group from specforge.layers import ( @@ -275,7 +274,6 @@ def __init__(self, config: LlamaConfig): # Initialize weights and apply final processing self.post_init() - @check_model_inputs def forward( self, input_ids: Optional[torch.LongTensor] = None, diff --git a/specforge/modeling/target/custom_backend/llama4.py b/specforge/modeling/target/custom_backend/llama4.py index 22f807dae..bccbb19ea 100644 --- a/specforge/modeling/target/custom_backend/llama4.py +++ b/specforge/modeling/target/custom_backend/llama4.py @@ -52,7 +52,6 @@ logging, ) from transformers.utils.deprecation import deprecate_kwarg -from transformers.utils.generic import check_model_inputs # [MODIFIED] Import from transformers library from specforge.distributed import get_tp_group, shard_tensor @@ -431,7 +430,6 @@ def __init__(self, config: Llama4TextConfig): self.post_init() @can_return_tuple - @check_model_inputs @auto_docstring def forward( self, diff --git a/specforge/modeling/target/custom_backend/phi3.py b/specforge/modeling/target/custom_backend/phi3.py index 2515701f9..c3ec1adcc 100644 --- a/specforge/modeling/target/custom_backend/phi3.py +++ b/specforge/modeling/target/custom_backend/phi3.py @@ -43,7 +43,6 @@ from transformers.processing_utils import Unpack from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple from transformers.utils.deprecation import deprecate_kwarg -from transformers.utils.generic import check_model_inputs from specforge.distributed import get_tp_group from specforge.layers import ( @@ -284,7 +283,6 @@ def __init__(self, config: Phi3Config): # Initialize weights and apply final processing self.post_init() - @check_model_inputs @auto_docstring def forward( self, diff --git a/specforge/modeling/target/eagle3_target_model.py b/specforge/modeling/target/eagle3_target_model.py index 2acf50ba5..c7d4991f2 100644 --- a/specforge/modeling/target/eagle3_target_model.py +++ b/specforge/modeling/target/eagle3_target_model.py @@ -94,6 +94,9 @@ def set_aux_hidden_states_layers( if aux_hidden_states_layers is None: if hasattr(self.model.config, "num_hidden_layers"): num_layers = self.model.config.num_hidden_layers + + elif hasattr(self.model.config, "text_config") and hasattr(self.model.config.text_config, "num_hidden_layers"): + num_layers = self.model.config.text_config.num_hidden_layers else: raise ValueError( f"Failed to set aux hidden states layers as model config {self.model.config} does not have num_hidden_layers" @@ -154,18 +157,20 @@ def _get_transformer_layers(self): Helper to find the module list containing the transformer layers. Adapts to common architectures (Llama, Qwen, Mistral, OPT, etc.) """ - if hasattr(self.model, "model") and hasattr(self.model.model, "layers"): - return self.model.model.layers + if hasattr(self.model, "model"): + if hasattr(self.model.model, "layers"): + return self.model.model.layers + elif hasattr(self.model.model, "language_model"): + return self.model.model.language_model.layers elif hasattr(self.model, "layers"): return self.model.layers elif hasattr(self.model, "transformer") and hasattr( self.model.transformer, "h" ): return self.model.transformer.h - else: - raise ValueError( - "Could not locate transformer layers in the model architecture to register hooks." - ) + raise ValueError( + "Could not locate transformer layers in the model architecture to register hooks." + ) @torch.no_grad() def generate_eagle3_data( diff --git a/specforge/utils.py b/specforge/utils.py index af4d627c8..543908643 100644 --- a/specforge/utils.py +++ b/specforge/utils.py @@ -379,6 +379,10 @@ def safe_conversations_generator(file_path): # Build result with conversations result = {"conversations": cleaned_convs} + # Preserve 'text' field if present (for preformatted data) + if "text" in row: + result["text"] = row["text"] + # Preserve 'tools' field if present if "tools" in row: tools = row["tools"] From ceefdebd937cb3362d463f4dbc584d9f9546e60e Mon Sep 17 00:00:00 2001 From: Pengyu Chen Date: Fri, 1 May 2026 21:11:18 +0000 Subject: [PATCH 2/4] support use_aux_norm --- configs/gemma3-27b-eagle3.json | 2 +- specforge/modeling/draft/llama3_eagle.py | 42 ++++++++++++++---------- 2 files changed, 26 insertions(+), 18 deletions(-) diff --git a/configs/gemma3-27b-eagle3.json b/configs/gemma3-27b-eagle3.json index 9c8a6ff58..3123e6ecc 100644 --- a/configs/gemma3-27b-eagle3.json +++ b/configs/gemma3-27b-eagle3.json @@ -32,6 +32,6 @@ "eagle_config": { "eagle_aux_hidden_state_layer_ids": [1, 29, 61] }, - "additional_fc": false, + "use_aux_norm": true, "reuse_target_lm_head": true } diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py index 7ca7e80e0..fbffe8ef6 100644 --- a/specforge/modeling/draft/llama3_eagle.py +++ b/specforge/modeling/draft/llama3_eagle.py @@ -437,7 +437,6 @@ def yarn_linear_ramp_mask(min_val, max_val, dim): class LlamaYarnRotaryEmbedding(LlamaRotaryEmbedding): - def __init__( self, dim, @@ -993,9 +992,9 @@ class LlamaUSPFlashAttention(LlamaAttention): def __init__(self, config): super().__init__(config) - assert ( - dist.is_initialized() - ), f"LlamaUSPAttention requires torch.distributed; call init_distributed first." + assert dist.is_initialized(), ( + f"LlamaUSPAttention requires torch.distributed; call init_distributed first." + ) if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding): raise NotImplementedError( f"LlamaMutiRotaryEmbedding is currently not supported for LlamaUSPFlashAttention." @@ -1333,14 +1332,22 @@ def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None: ) self.midlayer = LlamaDecoderLayer(config, attention_backend=attention_backend) - target_hidden_size = getattr(config, "target_hidden_size", config.hidden_size) - # Per-layer RMSNorm applied to each aux hidden state before concatenation, - # so that all three layers contribute equally regardless of their raw scale. - self.aux_norm_low = LlamaRMSNorm(target_hidden_size, eps=config.rms_norm_eps) - self.aux_norm_mid = LlamaRMSNorm(target_hidden_size, eps=config.rms_norm_eps) - self.aux_norm_high = LlamaRMSNorm(target_hidden_size, eps=config.rms_norm_eps) + # Optional per-layer RMSNorm applied to each aux hidden state before + # concatenation, so that all three layers contribute equally regardless + # of their raw scale. Enabled via config "use_aux_norm": true. + self.use_aux_norm = getattr(config, "use_aux_norm", False) + if self.use_aux_norm: + self.aux_norm_low = LlamaRMSNorm( + target_hidden_size, eps=config.rms_norm_eps + ) + self.aux_norm_mid = LlamaRMSNorm( + target_hidden_size, eps=config.rms_norm_eps + ) + self.aux_norm_high = LlamaRMSNorm( + target_hidden_size, eps=config.rms_norm_eps + ) self.fc = torch.nn.Linear( target_hidden_size * 3, config.hidden_size, bias=False @@ -1472,13 +1479,14 @@ def project_hidden_states(self, hidden_states: torch.Tensor) -> torch.Tensor: target_h = getattr(self.config, "target_hidden_size", self.config.hidden_size) assert hidden_states.size(-1) == target_h * 3 - # Normalize each aux layer independently before fc projection, - # so all three contribute equally regardless of their raw scale. - h_low, h_mid, h_high = hidden_states.split(target_h, dim=-1) - h_low = self.aux_norm_low(h_low) - h_mid = self.aux_norm_mid(h_mid) - h_high = self.aux_norm_high(h_high) - hidden_states = torch.cat((h_low, h_mid, h_high), dim=-1) + if self.use_aux_norm: + # Normalize each aux layer independently before fc projection, + # so all three contribute equally regardless of their raw scale. + h_low, h_mid, h_high = hidden_states.split(target_h, dim=-1) + h_low = self.aux_norm_low(h_low) + h_mid = self.aux_norm_mid(h_mid) + h_high = self.aux_norm_high(h_high) + hidden_states = torch.cat((h_low, h_mid, h_high), dim=-1) return self.fc(hidden_states) From a79a5f59211bfb2be4c51ca26e6f234c6d473232 Mon Sep 17 00:00:00 2001 From: Pengyu Chen Date: Fri, 1 May 2026 21:34:23 +0000 Subject: [PATCH 3/4] cleanup --- examples/run_gemma3_27b_eagle3_online.sh | 4 +- examples/run_gemma4_26b_eagle3_online.sh | 6 +- scripts/train_eagle3.py | 113 +++--------------- specforge/data/preprocessing.py | 11 +- specforge/data/template.py | 6 +- specforge/modeling/draft/base.py | 8 +- specforge/modeling/draft/llama3_eagle.py | 18 +-- .../modeling/target/eagle3_target_model.py | 4 +- 8 files changed, 41 insertions(+), 129 deletions(-) diff --git a/examples/run_gemma3_27b_eagle3_online.sh b/examples/run_gemma3_27b_eagle3_online.sh index e3709e32e..d98416cbc 100755 --- a/examples/run_gemma3_27b_eagle3_online.sh +++ b/examples/run_gemma3_27b_eagle3_online.sh @@ -13,8 +13,7 @@ torchrun \ --target-model-path google/gemma-3-27b-it \ --draft-model-config $ROOT_DIR/configs/gemma3-27b-eagle3.json \ --train-data-path $ROOT_DIR/cache/dataset/ultrachat_train.jsonl \ - $ROOT_DIR/cache/dataset/translate_bp_regen_gemma3_train.jsonl \ - --output-dir $ROOT_DIR/outputs/gemma3-27b-eagle3-ucml-mix-l-aq \ + --output-dir $ROOT_DIR/outputs/gemma3-27b-eagle3 \ --num-epochs 8 \ --batch-size 2 \ --draft-accumulation-steps 2 \ @@ -30,5 +29,4 @@ torchrun \ --save-interval 10000 \ --build-dataset-num-proc 64 \ --report-to tensorboard \ - --eval-holdout-ratio 0.005 \ --embedding-key=language_model.model.embed_tokens.weight diff --git a/examples/run_gemma4_26b_eagle3_online.sh b/examples/run_gemma4_26b_eagle3_online.sh index 5af333c73..a58021b1e 100755 --- a/examples/run_gemma4_26b_eagle3_online.sh +++ b/examples/run_gemma4_26b_eagle3_online.sh @@ -14,9 +14,8 @@ torchrun \ --draft-model-config $ROOT_DIR/configs/gemma4-26b-a4b-eagle3.json \ --train-data-path \ $ROOT_DIR/outputs/dataset/ultrachat_regen_gemma4_preformatted.jsonl \ - $ROOT_DIR/outputs/dataset/translate_bp_regen_gemma4_preformatted.jsonl \ --is-preformatted \ - --output-dir $ROOT_DIR/outputs/gemma4-26b-a4b-eagle3-ucml \ + --output-dir $ROOT_DIR/outputs/gemma4-26b-a4b-eagle3 \ --num-epochs 8 \ --batch-size 4 \ --tp-size $TP_SIZE \ @@ -31,5 +30,4 @@ torchrun \ --save-interval 10000 \ --build-dataset-num-proc 64 \ --report-to tensorboard \ - --embedding-key=model.language_model.embed_tokens.weight \ - --eval-holdout-ratio 0.005 + --embedding-key=model.language_model.embed_tokens.weight diff --git a/scripts/train_eagle3.py b/scripts/train_eagle3.py index 44374d79c..3c668e64b 100644 --- a/scripts/train_eagle3.py +++ b/scripts/train_eagle3.py @@ -1,10 +1,9 @@ import argparse -import glob import hashlib +import json import math import os import time -from datetime import datetime from argparse import ArgumentParser, Namespace from typing import List, Optional, Tuple, Union @@ -108,17 +107,10 @@ def parse_args() -> Tuple[ArgumentParser, Namespace]: # dataset arguments dataset_group = parser.add_argument_group("dataset") - dataset_group.add_argument("--train-data-path", type=str, nargs="+", required=True) + dataset_group.add_argument("--train-data-path", type=str, required=True) dataset_group.add_argument("--train-hidden-states-path", type=str, default=None) dataset_group.add_argument("--eval-hidden-states-path", type=str, default=None) dataset_group.add_argument("--eval-data-path", type=str, default=None) - dataset_group.add_argument( - "--eval-holdout-ratio", - type=float, - default=None, - help="Fraction of the training dataset to hold out for evaluation (0 to 1). " - "Mutually exclusive with --eval-data-path and --eval-hidden-states-path.", - ) dataset_group.add_argument("--chat-template", type=str, default="llama3") dataset_group.add_argument( "--is-preformatted", @@ -356,18 +348,6 @@ def sanity_check(args: Namespace) -> None: args.dp_size = dist.get_world_size() // args.tp_size args.target_batch_size = args.tp_size * args.batch_size - if args.eval_holdout_ratio is not None: - if not (0 < args.eval_holdout_ratio < 1): - raise ValueError( - f"--eval-holdout-ratio must be between 0 and 1 (exclusive), " - f"got {args.eval_holdout_ratio}" - ) - if args.eval_data_path is not None or args.eval_hidden_states_path is not None: - raise ValueError( - "--eval-holdout-ratio is mutually exclusive with " - "--eval-data-path and --eval-hidden-states-path" - ) - if args.attention_backend == "usp": sp_sanity_check(args) @@ -376,9 +356,9 @@ def sp_sanity_check(args: Namespace) -> None: args.draft_accumulation_steps = ( args.draft_accumulation_steps * args.sp_ulysses_size * args.sp_ring_size ) - assert args.batch_size == 1, ( - f"USP only supports batch_size=1, got batch_size={args.batch_size}" - ) + assert ( + args.batch_size == 1 + ), f"USP only supports batch_size=1, got batch_size={args.batch_size}" assert args.sp_ring_size * args.sp_ulysses_size > 1, ( f"USP requires sp_ring_size * sp_ulysses_size > 1. " @@ -477,55 +457,24 @@ def build_dataloaders( args: Namespace, draft_model_config: AutoDraftModelConfig, processor: Optional[AutoProcessor] = None, -) -> Tuple[DataLoader, Optional[str], Optional[DataLoader]]: +) -> Tuple[DataLoader, str, Optional[DataLoader]]: # build dataloaders tokenizer = AutoTokenizer.from_pretrained( args.target_model_path, trust_remote_code=args.trust_remote_code ) - # Resolve all training data paths: expand directories to their .jsonl files - resolved_train_files = [] - for path in args.train_data_path: - if os.path.isdir(path): - jsonl_files = sorted(glob.glob(os.path.join(path, "*.jsonl"))) - if not jsonl_files: - raise ValueError(f"No .jsonl files found in directory: {path}") - resolved_train_files.extend(jsonl_files) - elif os.path.isfile(path): - resolved_train_files.append(path) - else: - raise ValueError(f"Training data path does not exist: {path}") - print_on_rank0( - f"Resolved {len(resolved_train_files)} training file(s) from " - f"{len(args.train_data_path)} path(s)" - ) - # convert to dataloader cache_params_string = ( - f"{','.join(sorted(resolved_train_files))}-" + f"{args.train_data_path}-" f"{args.max_length}-" f"{args.chat_template}-" f"{args.target_model_path}" # Tokenizer may also different ) cache_key = hashlib.md5(cache_params_string.encode()).hexdigest() - - # Build datasets from all resolved files and concatenate - from datasets import concatenate_datasets - - train_datasets = [] - for file_path in resolved_train_files: - ds = Dataset.from_generator( - generator=safe_conversations_generator, - gen_kwargs={"file_path": file_path}, - ) - train_datasets.append(ds) - train_dataset = ( - concatenate_datasets(train_datasets) - if len(train_datasets) > 1 - else train_datasets[0] + train_dataset = Dataset.from_generator( + generator=safe_conversations_generator, + gen_kwargs={"file_path": args.train_data_path}, ) - print_on_rank0(f"Combined training dataset: {len(train_dataset)} examples") - is_online = ( args.train_data_path is not None and args.train_hidden_states_path is None ) @@ -559,21 +508,6 @@ def build_dataloaders( use_usp_preprocess=(args.attention_backend == "usp"), ) - # Split a holdout portion from the training set if requested. - eval_eagle3_dataset_from_holdout = None - if args.eval_holdout_ratio is not None and args.eval_holdout_ratio > 0: - split = train_eagle3_dataset.train_test_split( - test_size=args.eval_holdout_ratio, - seed=args.seed, - ) - train_eagle3_dataset = split["train"] - eval_eagle3_dataset_from_holdout = split["test"] - print_on_rank0( - f"Holdout split: {len(train_eagle3_dataset)} train, " - f"{len(eval_eagle3_dataset_from_holdout)} eval " - f"(ratio={args.eval_holdout_ratio})" - ) - train_dataloader = prepare_dp_dataloaders( train_eagle3_dataset, args.target_batch_size, @@ -586,13 +520,7 @@ def build_dataloaders( ), is_vlm=args.is_vlm, ) - - has_eval = ( - args.eval_data_path is not None - or args.eval_hidden_states_path is not None - or eval_eagle3_dataset_from_holdout is not None - ) - if has_eval: + if args.eval_data_path is not None or args.eval_hidden_states_path is not None: if args.eval_data_path is not None: eval_dataset = Dataset.from_generator( generator=safe_conversations_generator, @@ -616,8 +544,6 @@ def build_dataloaders( ttt_length=args.ttt_length, use_usp_preprocess=(args.attention_backend == "usp"), ) - else: - eval_eagle3_dataset = eval_eagle3_dataset_from_holdout eval_dataloader = prepare_dp_dataloaders( eval_eagle3_dataset, args.target_batch_size, @@ -682,8 +608,6 @@ def save_checkpoints( # transformers v5 mutating rope_scaling/rope_parameters and other # fields in model.config during save_pretrained. if getattr(args, "draft_model_config", None): - import json - config_path = os.path.join(epoch_output_dir, "config.json") with open(args.draft_model_config) as f: original_config = json.load(f) @@ -845,16 +769,6 @@ def main(): ) sanity_check(args) - - # Create a datetime subfolder for this run (skip when resuming into an - # existing output directory so that checkpoints stay in the same place). - if not args.resume: - run_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - args.output_dir = os.path.join(args.output_dir, run_timestamp) - if dist.get_rank() == 0: - os.makedirs(args.output_dir, exist_ok=True) - dist.barrier() - print_args_with_dots(args) print_with_rank("Initialized distributed environment") @@ -1064,7 +978,10 @@ def main(): # ================================================ # 7.2 Evaluation Step # ================================================ - should_evaluate = eval_dataloader is not None + should_evaluate = ( + args.eval_data_path is not None + or args.eval_hidden_states_path is not None + ) if ( should_evaluate and global_step % (args.eval_interval * args.draft_accumulation_steps) diff --git a/specforge/data/preprocessing.py b/specforge/data/preprocessing.py index fd3445c46..3944918be 100644 --- a/specforge/data/preprocessing.py +++ b/specforge/data/preprocessing.py @@ -123,7 +123,7 @@ def preprocess_conversations( max_length: int = 2048, is_preformatted: bool = False, train_only_last_turn: bool = False, - tools: Optional[List[List[Dict]]] = None, + tools: Optional[List[List[Dict]]] = [[]], **kwargs, ) -> Dict[str, List[torch.Tensor]]: """ @@ -155,9 +155,6 @@ def preprocess_conversations( parser = HarmonyParser(tokenizer, chat_template) else: raise ValueError(f"Invalid parser type: {chat_template.parser_type}") - # Ensure tools list matches conversations length - if tools is None or len(tools) != len(conversations): - tools = [[] for _ in range(len(conversations))] kwargs_list = [{} for _ in range(len(conversations))] for key, value_list in kwargs.items(): for i, value in enumerate(value_list): @@ -347,9 +344,9 @@ def build_eagle3_dataset( if chat_template is None: raise ValueError("chat_template must be provided for all dataset types") - assert chat_template in TEMPLATE_REGISTRY.get_all_template_names(), ( - f"Chat template {chat_template} not found in TEMPLATE_REGISTRY, you may need to register it first" - ) + assert ( + chat_template in TEMPLATE_REGISTRY.get_all_template_names() + ), f"Chat template {chat_template} not found in TEMPLATE_REGISTRY, you may need to register it first" template: ChatTemplate = TEMPLATE_REGISTRY.get(chat_template) diff --git a/specforge/data/template.py b/specforge/data/template.py index 36381e69f..352d057e1 100644 --- a/specforge/data/template.py +++ b/specforge/data/template.py @@ -58,9 +58,9 @@ def register(self, name: str, template: ChatTemplate, override: bool = False): template(ChatTemplate): The chat template. override(bool): Whether to override the existing template, default to False """ - assert not override and name not in self.templates, ( - f"Chat template for the model type {name} has already been registered" - ) + assert ( + not override and name not in self.templates + ), f"Chat template for the model type {name} has already been registered" self.templates[name] = template def get(self, name: str) -> ChatTemplate: diff --git a/specforge/modeling/draft/base.py b/specforge/modeling/draft/base.py index 74e87a14e..0c0c353de 100644 --- a/specforge/modeling/draft/base.py +++ b/specforge/modeling/draft/base.py @@ -174,7 +174,7 @@ def _load_tensor_from_checkpoint( @torch.no_grad() def load_embedding( - self, model_path: str, embedding_key: str = "language_model.embed_tokens.weight" + self, model_path: str, embedding_key: str = "model.embed_tokens.weight" ) -> None: """ Load the embedding of the draft model from the target model checkpoint. @@ -220,9 +220,9 @@ def load_vocab_mapping(self, file_path: str) -> None: Args: file_path (str): The path to the vocab mapping file. """ - assert hasattr(self, "t2d") and hasattr(self, "d2t"), ( - "t2d and d2t buffersare not found in the draft model, please check your draft model implementation" - ) + assert hasattr(self, "t2d") and hasattr( + self, "d2t" + ), "t2d and d2t buffersare not found in the draft model, please check your draft model implementation" vocab_mapping = torch.load(file_path) self.t2d.copy_(vocab_mapping["t2d"]) self.d2t.copy_(vocab_mapping["d2t"]) diff --git a/specforge/modeling/draft/llama3_eagle.py b/specforge/modeling/draft/llama3_eagle.py index fbffe8ef6..6da987c5f 100644 --- a/specforge/modeling/draft/llama3_eagle.py +++ b/specforge/modeling/draft/llama3_eagle.py @@ -936,9 +936,9 @@ def forward( k0 = cache_k[0] v0 = cache_v[0] - assert flash_attn_func is not None, ( - "flash_attn is not installed, please install flash_attn if you want to use the flash attention backend" - ) + assert ( + flash_attn_func is not None + ), "flash_attn is not installed, please install flash_attn if you want to use the flash attention backend" attn_output, lse, _ = flash_attn_func( query_states, k0, @@ -992,9 +992,9 @@ class LlamaUSPFlashAttention(LlamaAttention): def __init__(self, config): super().__init__(config) - assert dist.is_initialized(), ( - f"LlamaUSPAttention requires torch.distributed; call init_distributed first." - ) + assert ( + dist.is_initialized() + ), f"LlamaUSPAttention requires torch.distributed; call init_distributed first." if isinstance(self.rotary_emb, LlamaMutiRotaryEmbedding): raise NotImplementedError( f"LlamaMutiRotaryEmbedding is currently not supported for LlamaUSPFlashAttention." @@ -1109,9 +1109,9 @@ def forward( else: acc_lse = lse_ring - assert acc_lse.shape[1] == current_q_len, ( - f"LSE seq_len {acc_lse.shape[1]} mismatch with Query seq_len {current_q_len}" - ) + assert ( + acc_lse.shape[1] == current_q_len + ), f"LSE seq_len {acc_lse.shape[1]} mismatch with Query seq_len {current_q_len}" acc_out = out_ring diff --git a/specforge/modeling/target/eagle3_target_model.py b/specforge/modeling/target/eagle3_target_model.py index c7d4991f2..a505194d5 100644 --- a/specforge/modeling/target/eagle3_target_model.py +++ b/specforge/modeling/target/eagle3_target_model.py @@ -95,7 +95,9 @@ def set_aux_hidden_states_layers( if hasattr(self.model.config, "num_hidden_layers"): num_layers = self.model.config.num_hidden_layers - elif hasattr(self.model.config, "text_config") and hasattr(self.model.config.text_config, "num_hidden_layers"): + elif hasattr(self.model.config, "text_config") and hasattr( + self.model.config.text_config, "num_hidden_layers" + ): num_layers = self.model.config.text_config.num_hidden_layers else: raise ValueError( From 82a0c48cdbb2723e109df87765903ab2edd3240e Mon Sep 17 00:00:00 2001 From: kathyligg Date: Sat, 2 May 2026 00:31:51 +0000 Subject: [PATCH 4/4] Bump MAX_FUSED_SIZE to 262208 to fit Gemma3/4 vocab Gemma3 27B and Gemma4 26B have a vocabulary size of 262144, which makes triton.next_power_of_2 round up to 262144 (==2^18). The previous limit of 131072 caused _calculate_settings() to raise RuntimeError before the log-softmax loss kernel could launch, preventing Eagle3 training on these targets. --- specforge/core/loss.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/specforge/core/loss.py b/specforge/core/loss.py index 30e7fba7d..2aa337692 100644 --- a/specforge/core/loss.py +++ b/specforge/core/loss.py @@ -24,7 +24,7 @@ def _compute_loss(logits, target_p, position_mask): def _calculate_settings(n): # reference: https://github.com/unslothai/unsloth/blob/fd753fed99ed5f10ef8a9b7139588d9de9ddecfb/unsloth/kernels/utils.py#L43 - MAX_FUSED_SIZE = 131072 + MAX_FUSED_SIZE = 262208 BLOCK_SIZE = triton.next_power_of_2(n) if BLOCK_SIZE > MAX_FUSED_SIZE: raise RuntimeError(