diff --git a/configs/gemma3-27b-eagle3.json b/configs/gemma3-27b-eagle3.json new file mode 100644 index 000000000..3123e6ecc --- /dev/null +++ b/configs/gemma3-27b-eagle3.json @@ -0,0 +1,37 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 2, + "eos_token_id": 1, + "pad_token_id": 0, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 5376, + "initializer_range": 0.02, + "intermediate_size": 21504, + "max_position_embeddings": 4096, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 1, + "num_key_value_heads": 16, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000, + "sliding_window": 512, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.50.0", + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 262208, + "draft_vocab_size": 262208, + "target_model_type": "gemma3_text", + "eagle_config": { + "eagle_aux_hidden_state_layer_ids": [1, 29, 61] + }, + "use_aux_norm": true, + "reuse_target_lm_head": true +} diff --git a/configs/gemma4-26b-a4b-eagle3.json b/configs/gemma4-26b-a4b-eagle3.json new file mode 100644 index 000000000..f50e1345b --- /dev/null +++ b/configs/gemma4-26b-a4b-eagle3.json @@ -0,0 +1,35 @@ +{ + "architectures": [ + "LlamaForCausalLMEagle3" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "bos_token_id": 2, + "eos_token_id": 1, + "pad_token_id": 0, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 2816, + "initializer_range": 0.02, + "intermediate_size": 2112, + "max_position_embeddings": 4096, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 1, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000.0, + "sliding_window": null, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.50.0", + "use_cache": true, + "use_sliding_window": 
#!/usr/bin/env bash
# Train an EAGLE3 draft model for google/gemma-3-27b-it (online mode).
#
# Usage: run_gemma3_27b_eagle3_online.sh [NUM_GPUS] [TP_SIZE]
#   NUM_GPUS  number of processes torchrun starts on this node (default: 8)
#   TP_SIZE   value forwarded to --tp-size (default: 8)
#
# The script relies on BASH_SOURCE, so it must run under bash; the shebang
# makes that explicit when the file is executed directly.
set -euo pipefail

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname -- "$SCRIPT_DIR")
export TORCHINDUCTOR_CACHE_DIR="$ROOT_DIR/cache/compiled_kernels"

# train eagle3 for gemma3-27b
NUM_GPUS=${1:-8}
TP_SIZE=${2:-8}

# All paths are quoted so a checkout location containing spaces still works.
torchrun \
    --standalone \
    --nproc_per_node "$NUM_GPUS" \
    "$ROOT_DIR/scripts/train_eagle3.py" \
    --target-model-path google/gemma-3-27b-it \
    --draft-model-config "$ROOT_DIR/configs/gemma3-27b-eagle3.json" \
    --train-data-path "$ROOT_DIR/cache/dataset/ultrachat_train.jsonl" \
    --output-dir "$ROOT_DIR/outputs/gemma3-27b-eagle3" \
    --num-epochs 8 \
    --batch-size 2 \
    --draft-accumulation-steps 2 \
    --tp-size "$TP_SIZE" \
    --learning-rate 1e-4 \
    --max-length 2048 \
    --chat-template gemma \
    --cache-dir "$ROOT_DIR/cache" \
    --attention-backend sdpa \
    --target-model-backend hf \
    --log-interval 200 \
    --eval-interval 5000 \
    --save-interval 10000 \
    --build-dataset-num-proc 64 \
    --report-to tensorboard \
    --embedding-key=language_model.model.embed_tokens.weight
$ROOT_DIR/scripts/train_eagle3.py \ + --target-model-path google/gemma-4-26b-a4b-it \ + --draft-model-config $ROOT_DIR/configs/gemma4-26b-a4b-eagle3.json \ + --train-data-path \ + $ROOT_DIR/outputs/dataset/ultrachat_regen_gemma4_preformatted.jsonl \ + --is-preformatted \ + --output-dir $ROOT_DIR/outputs/gemma4-26b-a4b-eagle3 \ + --num-epochs 8 \ + --batch-size 4 \ + --tp-size $TP_SIZE \ + --learning-rate 1e-4 \ + --max-length 2048 \ + --chat-template gemma-4 \ + --cache-dir $ROOT_DIR/cache \ + --attention-backend sdpa \ + --target-model-backend hf \ + --log-interval 200 \ + --eval-interval 5000 \ + --save-interval 10000 \ + --build-dataset-num-proc 64 \ + --report-to tensorboard \ + --embedding-key=model.language_model.embed_tokens.weight diff --git a/scripts/train_eagle3.py b/scripts/train_eagle3.py index 0bd157b39..3c668e64b 100644 --- a/scripts/train_eagle3.py +++ b/scripts/train_eagle3.py @@ -1,5 +1,6 @@ import argparse import hashlib +import json import math import os import time @@ -86,6 +87,13 @@ def parse_args() -> Tuple[ArgumentParser, Namespace]: default="lm_head.weight", help="The key of the lm head weight to load from the target model, this is only required for offline training", ) + model_group.add_argument( + "--reuse-target-lm-head", + action="store_true", + help="Load the target model's lm_head weights into the draft model's lm_head " + "and freeze it. Supports both tied and untied target models. 
" + "Requires draft_vocab_size == vocab_size.", + ) model_group.add_argument( "--is-vlm", action="store_true", help="Whether the target model is a VLM" ) @@ -339,6 +347,7 @@ def sanity_check(args: Namespace) -> None: """ args.dp_size = dist.get_world_size() // args.tp_size args.target_batch_size = args.tp_size * args.batch_size + if args.attention_backend == "usp": sp_sanity_check(args) @@ -433,6 +442,14 @@ def build_draft_model(args: Namespace) -> Tuple[AutoDraftModelConfig, nn.Module] draft_model.load_embedding(args.target_model_path, embedding_key=args.embedding_key) draft_model.freeze_embedding() + if args.reuse_target_lm_head: + draft_model.load_lm_head( + args.target_model_path, + lm_head_key=args.lm_head_key, + embedding_key=args.embedding_key, + ) + draft_model.freeze_lm_head() + print_on_rank0("Loaded and froze lm_head from target model") return draft_model_config, draft_model, ckpt_info, resume_state @@ -587,6 +604,16 @@ def save_checkpoints( epoch_output_dir, state_dict=draft_model_state_dict, ) + # Overwrite config.json with the original training config to avoid + # transformers v5 mutating rope_scaling/rope_parameters and other + # fields in model.config during save_pretrained. 
+ if getattr(args, "draft_model_config", None): + config_path = os.path.join(epoch_output_dir, "config.json") + with open(args.draft_model_config) as f: + original_config = json.load(f) + with open(config_path, "w") as f: + json.dump(original_config, f, indent=2) + print_on_rank0(f"Overwrote config.json with original training config") print_on_rank0(f"Saved model configuration to {epoch_output_dir}") dist.barrier() @@ -758,9 +785,14 @@ def main(): args, draft_model_config, processor ) - # we load the vocab mapping then - draft_model.load_vocab_mapping(vocab_mapping_path) - print_with_rank("Loaded vocab mapping") + # we load the vocab mapping then (skip when draft_vocab_size == target_vocab_size) + if vocab_mapping_path is not None: + draft_model.load_vocab_mapping(vocab_mapping_path) + print_with_rank("Loaded vocab mapping") + else: + print_with_rank( + "Skipped vocab mapping loading (draft_vocab_size == target_vocab_size)" + ) # Calculate total steps if not provided if args.total_steps is None: diff --git a/specforge/core/eagle3.py b/specforge/core/eagle3.py index 1e2f04e7e..893af4fe6 100644 --- a/specforge/core/eagle3.py +++ b/specforge/core/eagle3.py @@ -153,9 +153,13 @@ def forward( position_ids: (batch, seq_len) """ # Step 1: handle vocab size + # When draft_vocab_size == vocab_size, skip the d2t/t2d mapping entirely. + use_vocab_mapping = ( + self.draft_model.draft_vocab_size != self.draft_model.vocab_size + ) target_p_padded, position_mask = _compute_target_p_padded( target=target, - t2d=self.draft_model.t2d, + t2d=self.draft_model.t2d if use_vocab_mapping else None, loss_mask=loss_mask, length=self.length, ) @@ -439,9 +443,13 @@ def forward( ) # Step 1: handle vocab size + # When draft_vocab_size == vocab_size, skip the d2t/t2d mapping entirely. 
@torch.compile(dynamic=None)
def _compute_target_p_full_vocab(target, loss_mask):
    """Fast path when draft_vocab_size == target_vocab_size (no vocab subsetting).

    Args:
        target: Target-model logits. The softmax is taken over dim 2, so a
            3-D tensor is expected (the caller asserts the result is 3-D).
        loss_mask: Mask returned unchanged as the position mask.

    Returns:
        A ``(target_p, position_mask)`` tuple: ``target_p`` is the float32
        softmax of ``target`` over dim 2, detached from the autograd graph;
        ``position_mask`` is ``loss_mask`` itself.
    """
    # Functional softmax: same result as nn.Softmax(dim=2)(x) without
    # constructing a throwaway module on every call.
    target_p = torch.softmax(target.float(), dim=2).detach()
    # All target tokens are in the draft vocab, so position_mask == loss_mask.
    return target_p, loss_mask
def freeze_lm_head(self) -> None:
    """
    Freeze the lm_head of the draft model so that it is not updated during training.
    """
    self.lm_head.weight.requires_grad = False


@torch.no_grad()
def _load_tensor_from_checkpoint(
    self, model_path: str, tensor_key: str
) -> torch.Tensor:
    """
    Load a single tensor from a target model checkpoint.

    Args:
        model_path (str): Path to the target model. Can be either a Hugging Face
            repository ID or a local directory path containing the model files.
        tensor_key (str): The key of the tensor to load from the checkpoint.

    Returns:
        torch.Tensor: The loaded tensor.

    Raises:
        KeyError: If ``tensor_key`` is not present in the checkpoint. This is
            raised uniformly for sharded and single-file checkpoints in both
            safetensors and pickle format, so callers can rely on it to
            detect a missing weight (e.g. tied lm_head).
        FileNotFoundError: If no usable checkpoint file is found, or if more
            than one ``*.index.json`` file exists in ``model_path``.
    """

    def _read(ckpt_path: str) -> torch.Tensor:
        # Read ``tensor_key`` from one checkpoint file, whatever its format.
        if ckpt_path.endswith(".safetensors"):
            with safe_open(ckpt_path, framework="pt") as f:
                # Check membership first so a missing key surfaces as KeyError
                # instead of a safetensors-specific error type.
                if tensor_key not in f.keys():
                    raise KeyError(tensor_key)
                return f.get_tensor(tensor_key)
        # NOTE(review): torch.load unpickles the file; only point this at
        # checkpoints from a trusted source.
        # map_location="cpu" keeps the load off the accelerator and works for
        # checkpoints that were saved from a device this process lacks.
        state_dict = torch.load(ckpt_path, map_location="cpu")
        return state_dict[tensor_key]

    if not os.path.exists(model_path):
        # model_path is a huggingface repository, download first
        model_path = snapshot_download(repo_id=model_path)

    index_json_paths = glob.glob(os.path.join(model_path, "*.index.json"))

    if len(index_json_paths) == 0:
        # No index.json found, look for single model file
        for file_name in ("model.safetensors", "pytorch_model.bin"):
            candidate = os.path.join(model_path, file_name)
            if os.path.exists(candidate):
                return _read(candidate)

        raise FileNotFoundError(
            f"No index.json, model.safetensors or pytorch_model.bin found in {model_path}"
        )

    if len(index_json_paths) > 1:
        raise FileNotFoundError(f"Multiple index.json files found in {model_path}")

    with open(index_json_paths[0], "r") as f:
        index_json = json.load(f)
    # Raises KeyError when the tensor is not listed in the shard index.
    ckpt_file = index_json["weight_map"][tensor_key]
    return _read(os.path.join(model_path, ckpt_file))


@torch.no_grad()
def load_embedding(
    self, model_path: str, embedding_key: str = "model.embed_tokens.weight"
) -> None:
    """
    Load the embedding of the draft model from the target model checkpoint.

    Args:
        model_path (str): Path to the target model. Can be either a Hugging Face
            repository ID or a local directory path containing the model files.
        embedding_key (str): The key of the embedding weight in the checkpoint.
    """
    tensor = self._load_tensor_from_checkpoint(model_path, embedding_key)
    self.embed_tokens.weight.copy_(tensor)


@torch.no_grad()
def load_lm_head(
    self, model_path: str, lm_head_key: str, embedding_key: str
) -> None:
    """
    Load the lm_head of the draft model from the target model checkpoint.

    For models with tied weights (embed_tokens == lm_head), the lm_head key
    may not exist in the checkpoint. In that case, falls back to loading
    from the embedding key.

    Args:
        model_path (str): Path to the target model. Can be either a Hugging Face
            repository ID or a local directory path containing the model files.
        lm_head_key (str): The key of the lm_head weight in the checkpoint.
        embedding_key (str): Fallback key if lm_head_key is not found (for
            models with tie_word_embeddings=True).

    Raises:
        KeyError: If neither ``lm_head_key`` nor ``embedding_key`` is found.
        ValueError: If the loaded tensor does not have the same shape as the
            draft lm_head weight (the draft and target vocab sizes differ).
    """
    try:
        tensor = self._load_tensor_from_checkpoint(model_path, lm_head_key)
    except KeyError:
        tensor = None

    if tensor is None:
        # Target model ties weights -- lm_head key doesn't exist in checkpoint,
        # fall back to embedding key. Warn so that a mistyped lm_head key on
        # an untied model does not go unnoticed.
        import warnings

        warnings.warn(
            f"lm_head key '{lm_head_key}' not found in {model_path}; "
            f"falling back to embedding key '{embedding_key}' "
            "(expected for models with tied word embeddings)."
        )
        tensor = self._load_tensor_from_checkpoint(model_path, embedding_key)

    expected_shape = tuple(self.lm_head.weight.shape)
    if tuple(tensor.shape) != expected_shape:
        # Fail with an actionable message rather than an opaque copy_ error.
        raise ValueError(
            f"Target lm_head weight has shape {tuple(tensor.shape)} but the draft "
            f"lm_head expects {expected_shape}; reusing the target lm_head "
            "requires draft_vocab_size == vocab_size and matching hidden_size."
        )
    self.lm_head.weight.copy_(tensor)
getattr(self.config, attr, None) + if isinstance(params, dict) and "rope_theta" in params: + return params["rope_theta"] + raise RuntimeError("rope theta is not set.") + def _init_rope(self): if self.config.rope_scaling is None: self.rotary_emb = LlamaRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, - base=getattr(self.config, "rope_theta", 10000), + base=self._get_rope_theta(), ) else: rope_scaling = self.config.rope_scaling @@ -560,7 +571,7 @@ def rope_get(key, default=None): self.rotary_emb = LlamaRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, - base=getattr(self.config, "rope_theta", 10000), + base=self._get_rope_theta(), ) return elif scaling_type == "linear": @@ -1008,7 +1019,6 @@ def forward( output_attentions: bool = False, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - bsz, q_len, _ = hidden_states.size() local_q_len = q_len @@ -1246,10 +1256,8 @@ def __init__(self, config, attention_backend: str = "sdpa"): self.attention_backend = attention_backend self.mlp = LlamaMLP(config) - # self.fc = nn.Linear(config.hidden_size * 2, config.hidden_size) self.hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - # if self.index!=0: self.post_attention_layernorm = LlamaRMSNorm( config.hidden_size, eps=config.rms_norm_eps @@ -1306,12 +1314,10 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - # outputs = (hidden_states, return_hidden) return hidden_states class LlamaForCausalLMEagle3(Eagle3DraftModel): - config_class = LlamaConfig def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None: @@ -1326,26 +1332,86 @@ def __init__(self, config, quant_config=None, attention_backend="sdpa") -> None: ) self.midlayer = LlamaDecoderLayer(config, 
attention_backend=attention_backend) - if hasattr(config, "target_hidden_size"): - self.fc = torch.nn.Linear( - config.target_hidden_size * 3, config.hidden_size, bias=False + target_hidden_size = getattr(config, "target_hidden_size", config.hidden_size) + + # Optional per-layer RMSNorm applied to each aux hidden state before + # concatenation, so that all three layers contribute equally regardless + # of their raw scale. Enabled via config "use_aux_norm": true. + self.use_aux_norm = getattr(config, "use_aux_norm", False) + if self.use_aux_norm: + self.aux_norm_low = LlamaRMSNorm( + target_hidden_size, eps=config.rms_norm_eps ) - else: - self.fc = torch.nn.Linear( - config.hidden_size * 3, config.hidden_size, bias=False + self.aux_norm_mid = LlamaRMSNorm( + target_hidden_size, eps=config.rms_norm_eps ) + self.aux_norm_high = LlamaRMSNorm( + target_hidden_size, eps=config.rms_norm_eps + ) + + self.fc = torch.nn.Linear( + target_hidden_size * 3, config.hidden_size, bias=False + ) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.lm_head = nn.Linear( config.hidden_size, config.draft_vocab_size, bias=False ) + # Embedding scale factor for target models that use scaled embeddings + # (e.g., Gemma3/Gemma4 multiply by hidden_size**0.5). Set via config + # field ``embed_scale`` or auto-detected from ``target_model_type``. + target_type = getattr(config, "target_model_type", None) or "" + if getattr(config, "embed_scale", None) is not None: + self.embed_scale = config.embed_scale + elif "gemma" in target_type: + self.embed_scale = config.hidden_size**0.5 + else: + self.embed_scale = 1.0 + # create vocab buffers t2d = torch.ones(self.vocab_size, dtype=torch.bool) d2t = torch.zeros(self.draft_vocab_size, dtype=torch.int64) self.register_buffer("t2d", t2d) self.register_buffer("d2t", d2t) + # Apply improved initialization for stable training with mixed + # pretrained (frozen) and randomly initialized (trainable) parameters. 
def _init_weights(self, module=None) -> None:
    """
    Initialize weights for stable training when mixing pretrained frozen
    components (embed_tokens, lm_head) with randomly initialized trainable
    layers.

    Strategy:
    - Zero-init residual projections (o_proj, down_proj) so the decoder
      layer starts as near-identity through the residual stream.
    - Use small normal init (std=initializer_range, default 0.02) for other
      projections.
    - Parameters whose name contains "norm" are left untouched.
    - embed_tokens and lm_head are left untouched.

    Args:
        module: Optional sub-module. When ``None`` (the call made from
            ``__init__``) every parameter of the model is visited. When a
            module is given, only the parameters owned directly by that
            module are visited, using their fully qualified names.
            NOTE(review): this class sets ``config_class``, so it presumably
            derives from transformers' ``PreTrainedModel``, whose framework
            code calls ``self._init_weights(module)``; the optional argument
            keeps that call from raising ``TypeError``. Confirm against the
            base class.
    """
    std = getattr(self.config, "initializer_range", 0.02)

    named_params = list(self.named_parameters())
    if module is not None:
        # Keep the qualified names (the policy below is name-based) but limit
        # the work to tensors that belong to ``module`` itself.
        own_ids = {id(p) for p in module.parameters(recurse=False)}
        named_params = [(n, p) for n, p in named_params if id(p) in own_ids]

    for name, param in named_params:
        if "embed_tokens" in name or "lm_head" in name:
            # Not initialized here: embed_tokens is loaded from the target
            # model; lm_head is loaded only when the target head is reused
            # and otherwise keeps its constructor default.
            continue
        if "norm" in name:
            # Normalization weights keep their constructor values
            # ("layernorm" names are covered by this substring as well).
            continue
        if param.dim() < 2:
            # Biases and 1-d params: zero init
            nn.init.zeros_(param)
            continue

        # Zero-init residual-path projections for near-identity at start
        if "o_proj" in name or "down_proj" in name:
            nn.init.zeros_(param)
        else:
            nn.init.normal_(param, mean=0.0, std=std)
+ h_low, h_mid, h_high = hidden_states.split(target_h, dim=-1) + h_low = self.aux_norm_low(h_low) + h_mid = self.aux_norm_mid(h_mid) + h_high = self.aux_norm_high(h_high) + hidden_states = torch.cat((h_low, h_mid, h_high), dim=-1) + return self.fc(hidden_states) def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: diff --git a/specforge/modeling/target/custom_backend/gpt_oss.py b/specforge/modeling/target/custom_backend/gpt_oss.py index b3b4a7972..9910633c8 100644 --- a/specforge/modeling/target/custom_backend/gpt_oss.py +++ b/specforge/modeling/target/custom_backend/gpt_oss.py @@ -36,7 +36,6 @@ from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm from transformers.processing_utils import Unpack from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple -from transformers.utils.generic import check_model_inputs from specforge.distributed import get_tp_group, shard_tensor from specforge.layers import ( @@ -585,7 +584,6 @@ def __init__(self, config: GptOssConfig): # Initialize weights and apply final processing self.post_init() - @check_model_inputs @auto_docstring def forward( self, diff --git a/specforge/modeling/target/custom_backend/llama.py b/specforge/modeling/target/custom_backend/llama.py index 04a3f6c9b..02a1c16c4 100644 --- a/specforge/modeling/target/custom_backend/llama.py +++ b/specforge/modeling/target/custom_backend/llama.py @@ -41,7 +41,6 @@ ) from transformers.processing_utils import Unpack from transformers.utils import TransformersKwargs, logging -from transformers.utils.generic import check_model_inputs from specforge.distributed import get_tp_group from specforge.layers import ( @@ -275,7 +274,6 @@ def __init__(self, config: LlamaConfig): # Initialize weights and apply final processing self.post_init() - @check_model_inputs def forward( self, input_ids: Optional[torch.LongTensor] = None, diff --git a/specforge/modeling/target/custom_backend/llama4.py 
b/specforge/modeling/target/custom_backend/llama4.py index 22f807dae..bccbb19ea 100644 --- a/specforge/modeling/target/custom_backend/llama4.py +++ b/specforge/modeling/target/custom_backend/llama4.py @@ -52,7 +52,6 @@ logging, ) from transformers.utils.deprecation import deprecate_kwarg -from transformers.utils.generic import check_model_inputs # [MODIFIED] Import from transformers library from specforge.distributed import get_tp_group, shard_tensor @@ -431,7 +430,6 @@ def __init__(self, config: Llama4TextConfig): self.post_init() @can_return_tuple - @check_model_inputs @auto_docstring def forward( self, diff --git a/specforge/modeling/target/custom_backend/phi3.py b/specforge/modeling/target/custom_backend/phi3.py index 2515701f9..c3ec1adcc 100644 --- a/specforge/modeling/target/custom_backend/phi3.py +++ b/specforge/modeling/target/custom_backend/phi3.py @@ -43,7 +43,6 @@ from transformers.processing_utils import Unpack from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple from transformers.utils.deprecation import deprecate_kwarg -from transformers.utils.generic import check_model_inputs from specforge.distributed import get_tp_group from specforge.layers import ( @@ -284,7 +283,6 @@ def __init__(self, config: Phi3Config): # Initialize weights and apply final processing self.post_init() - @check_model_inputs @auto_docstring def forward( self, diff --git a/specforge/modeling/target/eagle3_target_model.py b/specforge/modeling/target/eagle3_target_model.py index 2acf50ba5..a505194d5 100644 --- a/specforge/modeling/target/eagle3_target_model.py +++ b/specforge/modeling/target/eagle3_target_model.py @@ -94,6 +94,11 @@ def set_aux_hidden_states_layers( if aux_hidden_states_layers is None: if hasattr(self.model.config, "num_hidden_layers"): num_layers = self.model.config.num_hidden_layers + + elif hasattr(self.model.config, "text_config") and hasattr( + self.model.config.text_config, "num_hidden_layers" + ): + num_layers = 
self.model.config.text_config.num_hidden_layers else: raise ValueError( f"Failed to set aux hidden states layers as model config {self.model.config} does not have num_hidden_layers" @@ -154,18 +159,20 @@ def _get_transformer_layers(self): Helper to find the module list containing the transformer layers. Adapts to common architectures (Llama, Qwen, Mistral, OPT, etc.) """ - if hasattr(self.model, "model") and hasattr(self.model.model, "layers"): - return self.model.model.layers + if hasattr(self.model, "model"): + if hasattr(self.model.model, "layers"): + return self.model.model.layers + elif hasattr(self.model.model, "language_model"): + return self.model.model.language_model.layers elif hasattr(self.model, "layers"): return self.model.layers elif hasattr(self.model, "transformer") and hasattr( self.model.transformer, "h" ): return self.model.transformer.h - else: - raise ValueError( - "Could not locate transformer layers in the model architecture to register hooks." - ) + raise ValueError( + "Could not locate transformer layers in the model architecture to register hooks." + ) @torch.no_grad() def generate_eagle3_data( diff --git a/specforge/utils.py b/specforge/utils.py index af4d627c8..543908643 100644 --- a/specforge/utils.py +++ b/specforge/utils.py @@ -379,6 +379,10 @@ def safe_conversations_generator(file_path): # Build result with conversations result = {"conversations": cleaned_convs} + # Preserve 'text' field if present (for preformatted data) + if "text" in row: + result["text"] = row["text"] + # Preserve 'tools' field if present if "tools" in row: tools = row["tools"]