diff --git a/configs/qwen3-8b-dta.json b/configs/qwen3-8b-dta.json new file mode 100644 index 000000000..d226ba250 --- /dev/null +++ b/configs/qwen3-8b-dta.json @@ -0,0 +1,47 @@ +{ + "architectures": [ + "DFlashDraftModel" + ], + "attention_bias": false, + "attention_dropout": 0.0, + "auto_map": { + "AutoModel": "dflash.DFlashDraftModel" + }, + "block_size": 16, + "bos_token_id": 151643, + "dflash_config": { + "mask_token_id": 151669, + "target_layer_ids": [1, 9, 17, 25, 33], + "training_mode": "vp_drafter", + "prefix_weight_base": 0.9 + }, + "dtype": "bfloat16", + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 12288, + "layer_types": [ + "full_attention", + "full_attention", + "full_attention", + "full_attention", + "full_attention" + ], + "max_position_embeddings": 40960, + "max_window_layers": 5, + "model_type": "qwen3", + "num_attention_heads": 32, + "num_hidden_layers": 5, + "num_key_value_heads": 8, + "num_target_layers": 36, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000, + "sliding_window": null, + "tie_word_embeddings": false, + "use_cache": true, + "use_sliding_window": false, + "vocab_size": 151936 +} diff --git a/examples/run_qwen3_8b_dta_online.sh b/examples/run_qwen3_8b_dta_online.sh new file mode 100755 index 000000000..4339a8ba8 --- /dev/null +++ b/examples/run_qwen3_8b_dta_online.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +ROOT_DIR=$(dirname $SCRIPT_DIR) +export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels +export SPECFORGE_DATA_NUM_PROC=32 +NUM_GPUS=${1:-8} + +ATTENTION_BACKEND=${2:-flex_attention} + +torchrun \ + --standalone \ + --nproc_per_node $NUM_GPUS \ + $ROOT_DIR/scripts/train_dflash.py \ + --target-model-path Qwen/Qwen3-8B \ + --target-model-backend sglang \ + --draft-config-path $ROOT_DIR/configs/qwen3-8b-dta.json \ + --train-data-path $ROOT_DIR/cache/dataset/perfectblend_qwen3-8b_regen.jsonl \ + --output-dir $ROOT_DIR/outputs/qwen3-8b-dta-perfectblend \ + --num-epochs 6 \ + --batch-size 4 \ + --learning-rate 6e-4 \ + --warmup-ratio 0.04 \ + --max-grad-norm 1.0 \ + --max-length 3072 \ + --chat-template qwen \ + --attention-backend $ATTENTION_BACKEND \ + --loss-decay-gamma 7.0 \ + --log-interval 50 \ + --save-interval 1000 \ + --report-to wandb \ + --wandb-project specforge-qwen3-8b-dta \ + --block-size 16 \ + --num-anchors 512 \ + --wandb-name qwen3-8b-dta-perfectblend diff --git a/scripts/train_dflash.py b/scripts/train_dflash.py index d899d95a8..fc3ba89a5 100755 --- a/scripts/train_dflash.py +++ b/scripts/train_dflash.py @@ -92,14 +92,18 @@ def parse_args(): model_group.add_argument( "--loss-type", type=str, - default="dflash", + default=None, choices=[ "dflash", + "vp_drafter", "dpace", "dpace-cumulative-confidence-only", "dpace-continuation-value-only", ], - help=("Loss variant. Use dpace for Dynamic Position-Aware Cross-Entropy."), + help=( + "Training objective. If omitted, reads dflash_config.training_mode or " + "dflash_config.loss_type from the draft config, defaulting to dflash." + ), ) model_group.add_argument( "--dpace-alpha", @@ -107,6 +111,15 @@ def parse_args(): default=0.5, help="Smoothing alpha for D-PACE position weights.", ) + model_group.add_argument( + "--prefix-weight-base", + type=float, + default=None, + help=( + "VP-Drafter prefix length sampling base. Values below 1 prefer shorter " + "visible prefixes; defaults to dflash_config.prefix_weight_base or 0.9." + ), + ) model_group.add_argument( "--embedding-key", type=str, @@ -218,8 +231,20 @@ def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]: if not hasattr(draft_config, "dflash_config") or draft_config.dflash_config is None: draft_config.dflash_config = {} + args.loss_type = ( + args.loss_type + or draft_config.dflash_config.get("training_mode") + or draft_config.dflash_config.get("loss_type") + or "dflash" + ) + if args.prefix_weight_base is None: + args.prefix_weight_base = draft_config.dflash_config.get( + "prefix_weight_base", 0.9 + ) + draft_config._attn_implementation = args.attention_backend print_on_rank0(f"Using attention backend: {args.attention_backend}") + print_on_rank0(f"Using DFlash training loss_type: {args.loss_type}") draft_model = DFlashDraftModel(draft_config).to(device=device, dtype=torch.bfloat16) @@ -449,6 +474,8 @@ def main(): print_on_rank0(f"Total training steps: {total_steps}") print_on_rank0("Loading target embeddings and head...") + device = get_local_device() + device_type = device.type target_components = TargetEmbeddingsAndHead.from_pretrained( args.target_model_path, embed_key=args.embedding_key, @@ -468,6 +495,7 @@ def main(): loss_decay_gamma=args.loss_decay_gamma, loss_type=args.loss_type, dpace_alpha=args.dpace_alpha, + prefix_weight_base=args.prefix_weight_base, ) # Wrap each transformer block as its own FSDP unit so that all-gather / diff --git a/specforge/core/dflash.py b/specforge/core/dflash.py index 2b6e5ec56..bf1724e38 100644 --- a/specforge/core/dflash.py +++ b/specforge/core/dflash.py @@ -25,11 +25,12 @@ _VALID_LOSS_TYPES = { "dflash", + "vp_drafter", "dpace", "dpace-cumulative-confidence-only", "dpace-continuation-value-only", } -_DPACE_LOSS_TYPES = _VALID_LOSS_TYPES - {"dflash"} +_DPACE_LOSS_TYPES = _VALID_LOSS_TYPES - {"dflash", "vp_drafter"} def create_dflash_sdpa_mask(anchor_positions, block_keep_mask, S, block_size, device): @@ -121,6 +122,7 @@ def __init__( loss_decay_gamma: Optional[float] = None, loss_type: str = "dflash", dpace_alpha: float = 0.5, + prefix_weight_base: float = 0.9, ): super().__init__() if loss_type not in _VALID_LOSS_TYPES: @@ -129,6 +131,12 @@ def __init__( ) if not 0.0 <= dpace_alpha <= 1.0: raise ValueError(f"dpace_alpha must be in [0, 1], got {dpace_alpha}") + if prefix_weight_base is None: + prefix_weight_base = 0.9 + if prefix_weight_base <= 0.0: + raise ValueError( + f"prefix_weight_base must be positive, got {prefix_weight_base}" + ) self.draft_model = draft_model self.lm_head = target_lm_head @@ -140,6 +148,7 @@ def __init__( self.loss_decay_gamma = loss_decay_gamma self.loss_type = loss_type self.dpace_alpha = dpace_alpha + self.prefix_weight_base = prefix_weight_base self._cached_block_mask: Optional[BlockMask] = None self._cached_seq_len: Optional[int] = None @@ -183,6 +192,33 @@ def _sample_anchor_positions( return anchors, keep_mask + def _sample_prefix_lengths( + self, bsz: int, n_blocks: int, device: torch.device + ) -> torch.Tensor: + """Sample visible prefix lengths for VP-Drafter training. + + A prefix length i means block positions [0, i) are visible real tokens and + positions [i, block_size) are masked prediction targets. The sampled + range follows D2SD's variable-prefix recipe while avoiding the degenerate + fixed-anchor DFlash case. + """ + min_prefix = min(2, self.block_size - 1) + max_prefix = self.block_size - 1 + if max_prefix <= min_prefix: + return torch.full( + (bsz, n_blocks), min_prefix, dtype=torch.long, device=device + ) + + prefix_ids = torch.arange(min_prefix, max_prefix + 1, device=device) + weights = torch.pow( + torch.full_like(prefix_ids, self.prefix_weight_base, dtype=torch.float32), + prefix_ids.float(), + ) + samples = torch.multinomial( + weights, num_samples=bsz * n_blocks, replacement=True + ).reshape(bsz, n_blocks) + return samples + min_prefix + def prepare_noise_input( self, input_ids: torch.Tensor, block_ids: Optional[torch.Tensor] = None ) -> torch.Tensor: @@ -235,6 +271,36 @@ def _create_noise_embed(self, input_ids, anchor_positions, block_keep_mask): return self.embed_tokens(noise_ids) + def _create_vp_noise_embed( + self, + input_ids: torch.Tensor, + anchor_positions: torch.Tensor, + block_keep_mask: torch.Tensor, + prefix_lengths: torch.Tensor, + ) -> torch.Tensor: + """Prepare VP-Drafter inputs with variable visible prefixes.""" + bsz, seq_len = input_ids.shape + n = anchor_positions.shape[1] + bs = self.block_size + device = input_ids.device + + offsets = torch.arange(bs, device=device).view(1, 1, -1) + token_positions = anchor_positions.unsqueeze(-1) + offsets + safe_positions = token_positions.clamp(0, seq_len - 1) + + real_tokens = torch.gather( + input_ids.unsqueeze(1).expand(-1, n, -1), + 2, + safe_positions, + ) + visible_prefix = offsets < prefix_lengths.unsqueeze(-1) + valid_positions = token_positions < seq_len + fill_mask = visible_prefix & block_keep_mask.unsqueeze(-1) & valid_positions + + mask_tokens = torch.full_like(real_tokens, self.mask_token_id) + noise_ids = torch.where(fill_mask, real_tokens, mask_tokens) + return self.embed_tokens(noise_ids.reshape(bsz, n * bs)) + def _dpace_weight( self, prob: torch.Tensor, @@ -285,9 +351,18 @@ def forward( seq_len, loss_mask, device ) - noise_embedding = self._create_noise_embed( - input_ids, anchor_positions, block_keep_mask - ) + prefix_lengths = None + if self.loss_type == "vp_drafter": + prefix_lengths = self._sample_prefix_lengths( + bsz, anchor_positions.shape[1], device + ) + noise_embedding = self._create_vp_noise_embed( + input_ids, anchor_positions, block_keep_mask, prefix_lengths + ) + else: + noise_embedding = self._create_noise_embed( + input_ids, anchor_positions, block_keep_mask + ) context_position_ids = ( torch.arange(seq_len, device=device).unsqueeze(0).expand(bsz, -1) @@ -340,7 +415,12 @@ def forward( weight_mask = weight_mask * valid_label_mask.float() pos_in_block = torch.arange(self.block_size, device=device).view(1, 1, -1) - weight_mask = weight_mask * (pos_in_block > 0).float() + if self.loss_type == "vp_drafter": + weight_mask = ( + weight_mask * (pos_in_block >= prefix_lengths.unsqueeze(-1)).float() + ) + else: + weight_mask = weight_mask * (pos_in_block > 0).float() original_loss_mask_gathered = torch.gather( loss_mask.unsqueeze(1).expand(-1, anchor_positions.size(1), -1), @@ -367,6 +447,19 @@ def forward( ) loss_weights = loss_weights * decay_weights + flat_weights = loss_weights.view(-1) + valid_token_count = flat_weights.sum() + 1e-6 + loss = (loss_per_token * flat_weights).sum() / valid_token_count + elif self.loss_type == "vp_drafter": + loss_weights = weight_mask + if self.loss_decay_gamma is not None and self.loss_decay_gamma > 0: + k = torch.arange(self.block_size, device=device).view(1, 1, -1) + effective_pos = ( + k.float() - prefix_lengths.unsqueeze(-1).float() + ).clamp(min=0) + decay_weights = torch.exp(-effective_pos / self.loss_decay_gamma) + loss_weights = loss_weights * decay_weights + flat_weights = loss_weights.view(-1) valid_token_count = flat_weights.sum() + 1e-6 loss = (loss_per_token * flat_weights).sum() / valid_token_count diff --git a/tests/test_utils/test_dflash_losses.py b/tests/test_utils/test_dflash_losses.py index 43ed2919e..e4c3b8be9 100644 --- a/tests/test_utils/test_dflash_losses.py +++ b/tests/test_utils/test_dflash_losses.py @@ -94,6 +94,12 @@ def _fixed_noise_embed(self, input_ids, anchor_positions, block_keep_mask): ) +def _fixed_vp_noise_embed( + self, input_ids, anchor_positions, block_keep_mask, prefix_lengths +): + return _fixed_noise_embed(self, input_ids, anchor_positions, block_keep_mask) + + def _fixed_anchor_sampler(anchors, keep_mask): def _sample(self, seq_len, loss_mask, device): return anchors.to(device), keep_mask.to(device) @@ -101,7 +107,14 @@ def _sample(self, seq_len, loss_mask, device): return _sample -def _make_model(logits, anchors, keep_mask, **kwargs): +def _fixed_prefix_sampler(prefix_lengths): + def _sample(self, bsz, n_blocks, device): + return prefix_lengths.to(device) + + return _sample + + +def _make_model(logits, anchors, keep_mask, prefix_lengths=None, **kwargs): bsz, n_blocks, block_size, vocab_size = logits.shape model = OnlineDFlashModel( draft_model=_FixedDraft(hidden_size=4), @@ -119,6 +132,11 @@ def _make_model(logits, anchors, keep_mask, **kwargs): _fixed_anchor_sampler(anchors, keep_mask), model ) model._create_noise_embed = types.MethodType(_fixed_noise_embed, model) + model._create_vp_noise_embed = types.MethodType(_fixed_vp_noise_embed, model) + if prefix_lengths is not None: + model._sample_prefix_lengths = types.MethodType( + _fixed_prefix_sampler(prefix_lengths), model + ) return model @@ -166,6 +184,31 @@ def _targets_and_mask(input_ids, loss_mask, anchors, keep_mask, block_size): return targets, binary_mask +def _vp_targets_and_mask( + input_ids, loss_mask, anchors, keep_mask, prefix_lengths, block_size +): + bsz, seq_len = input_ids.shape + n_blocks = anchors.shape[1] + offsets = torch.arange(block_size).view(1, 1, -1) + label_indices = anchors.unsqueeze(-1) + offsets + safe_indices = label_indices.clamp(max=seq_len - 1) + targets = torch.gather( + input_ids.unsqueeze(1).expand(-1, n_blocks, -1), + 2, + safe_indices, + ) + binary_mask = keep_mask.unsqueeze(-1).expand(-1, -1, block_size).double() + binary_mask = binary_mask * (label_indices < seq_len).double() + binary_mask = binary_mask * (offsets >= prefix_lengths.unsqueeze(-1)).double() + gathered_loss_mask = torch.gather( + loss_mask.unsqueeze(1).expand(-1, n_blocks, -1), + 2, + safe_indices, + ) + binary_mask = binary_mask * gathered_loss_mask + return targets, binary_mask + + def _neg_log_q(logits, targets): return F.cross_entropy( logits.reshape(-1, logits.size(-1)), @@ -201,6 +244,17 @@ def _naive_dflash_loss(neg_log_q, binary_mask, gamma): return (neg_log_q * weight).sum() / (weight.sum() + 1e-6) +def _naive_vp_drafter_loss(neg_log_q, binary_mask, prefix_lengths, gamma): + weight = binary_mask + if gamma is not None and gamma > 0: + block_size = neg_log_q.shape[-1] + positions = torch.arange(block_size, dtype=neg_log_q.dtype).view(1, 1, -1) + prefix = prefix_lengths.unsqueeze(-1).to(dtype=neg_log_q.dtype) + decay = torch.exp(-(positions - prefix).clamp(min=0) / gamma) + weight = weight * decay + return (neg_log_q * weight).sum() / (weight.sum() + 1e-6) + + class TestDFlashLosses(unittest.TestCase): def setUp(self): ( @@ -243,6 +297,35 @@ def test_dflash_decay_gamma_is_preserved(self): want = _naive_dflash_loss(self.neg_log_q, self.binary_mask, gamma=gamma) torch.testing.assert_close(got, want, rtol=0, atol=1e-8) + def test_vp_drafter_masks_visible_prefix_and_decays_from_first_mask(self): + gamma = 7.0 + prefix_lengths = torch.tensor([[2, 3], [2, 4]], dtype=torch.long) + targets, binary_mask = _vp_targets_and_mask( + self.input_ids, + self.loss_mask, + self.anchors, + self.keep_mask, + prefix_lengths, + self.logits.shape[2], + ) + neg_log_q = _neg_log_q(self.logits, targets) + model = _make_model( + self.logits, + self.anchors, + self.keep_mask, + prefix_lengths=prefix_lengths, + loss_type="vp_drafter", + loss_decay_gamma=gamma, + ) + got, accuracy = model( + input_ids=self.input_ids, + hidden_states=self.hidden_states, + loss_mask=self.loss_mask, + ) + want = _naive_vp_drafter_loss(neg_log_q, binary_mask, prefix_lengths, gamma) + self.assertTrue(torch.isfinite(accuracy)) + torch.testing.assert_close(got, want, rtol=0, atol=1e-8) + def test_dpace_full_matches_naive_reference(self): alpha = 0.5 got = self._forward_loss(loss_type="dpace", dpace_alpha=alpha)