Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions configs/qwen3-8b-dta.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
{
"architectures": [
"DFlashDraftModel"
],
"attention_bias": false,
"attention_dropout": 0.0,
"auto_map": {
"AutoModel": "dflash.DFlashDraftModel"
},
"block_size": 16,
"bos_token_id": 151643,
"dflash_config": {
"mask_token_id": 151669,
"target_layer_ids": [1, 9, 17, 25, 33],
"training_mode": "vp_drafter",
"prefix_weight_base": 0.9
},
"dtype": "bfloat16",
"eos_token_id": 151645,
"head_dim": 128,
"hidden_act": "silu",
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 12288,
"layer_types": [
"full_attention",
"full_attention",
"full_attention",
"full_attention",
"full_attention"
],
"max_position_embeddings": 40960,
"max_window_layers": 5,
"model_type": "qwen3",
"num_attention_heads": 32,
"num_hidden_layers": 5,
"num_key_value_heads": 8,
"num_target_layers": 36,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 1000000,
"sliding_window": null,
"tie_word_embeddings": false,
"use_cache": true,
"use_sliding_window": false,
"vocab_size": 151936
}
35 changes: 35 additions & 0 deletions examples/run_qwen3_8b_dta_online.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#!/bin/bash

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname $SCRIPT_DIR)
export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels
export SPECFORGE_DATA_NUM_PROC=32
NUM_GPUS=${1:-8}

ATTENTION_BACKEND=${2:-flex_attention}

torchrun \
--standalone \
--nproc_per_node $NUM_GPUS \
$ROOT_DIR/scripts/train_dflash.py \
--target-model-path Qwen/Qwen3-8B \
--target-model-backend sglang \
--draft-config-path $ROOT_DIR/configs/qwen3-8b-dta.json \
--train-data-path $ROOT_DIR/cache/dataset/perfectblend_qwen3-8b_regen.jsonl \
--output-dir $ROOT_DIR/outputs/qwen3-8b-dta-perfectblend \
--num-epochs 6 \
--batch-size 4 \
--learning-rate 6e-4 \
--warmup-ratio 0.04 \
--max-grad-norm 1.0 \
--max-length 3072 \
--chat-template qwen \
--attention-backend $ATTENTION_BACKEND \
--loss-decay-gamma 7.0 \
--log-interval 50 \
--save-interval 1000 \
--report-to wandb \
--wandb-project specforge-qwen3-8b-dta \
--block-size 16 \
--num-anchors 512 \
--wandb-name qwen3-8b-dta-perfectblend
32 changes: 30 additions & 2 deletions scripts/train_dflash.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,21 +92,34 @@ def parse_args():
model_group.add_argument(
"--loss-type",
type=str,
default="dflash",
default=None,
choices=[
"dflash",
"vp_drafter",
"dpace",
"dpace-cumulative-confidence-only",
"dpace-continuation-value-only",
],
help=("Loss variant. Use dpace for Dynamic Position-Aware Cross-Entropy."),
help=(
"Training objective. If omitted, reads dflash_config.training_mode or "
"dflash_config.loss_type from the draft config, defaulting to dflash."
),
)
model_group.add_argument(
"--dpace-alpha",
type=float,
default=0.5,
help="Smoothing alpha for D-PACE position weights.",
)
model_group.add_argument(
"--prefix-weight-base",
type=float,
default=None,
help=(
"VP-Drafter prefix length sampling base. Values below 1 prefer shorter "
"visible prefixes; defaults to dflash_config.prefix_weight_base or 0.9."
),
)
model_group.add_argument(
"--embedding-key",
type=str,
Expand Down Expand Up @@ -218,8 +231,20 @@ def build_models(args) -> Tuple[DFlashTargetModel, DFlashDraftModel]:
if not hasattr(draft_config, "dflash_config") or draft_config.dflash_config is None:
draft_config.dflash_config = {}

args.loss_type = (
args.loss_type
or draft_config.dflash_config.get("training_mode")
or draft_config.dflash_config.get("loss_type")
or "dflash"
)
if args.prefix_weight_base is None:
args.prefix_weight_base = draft_config.dflash_config.get(
"prefix_weight_base", 0.9
)

draft_config._attn_implementation = args.attention_backend
print_on_rank0(f"Using attention backend: {args.attention_backend}")
print_on_rank0(f"Using DFlash training loss_type: {args.loss_type}")

draft_model = DFlashDraftModel(draft_config).to(device=device, dtype=torch.bfloat16)

Expand Down Expand Up @@ -449,6 +474,8 @@ def main():
print_on_rank0(f"Total training steps: {total_steps}")

print_on_rank0("Loading target embeddings and head...")
device = get_local_device()
device_type = device.type
target_components = TargetEmbeddingsAndHead.from_pretrained(
args.target_model_path,
embed_key=args.embedding_key,
Expand All @@ -468,6 +495,7 @@ def main():
loss_decay_gamma=args.loss_decay_gamma,
loss_type=args.loss_type,
dpace_alpha=args.dpace_alpha,
prefix_weight_base=args.prefix_weight_base,
)

# Wrap each transformer block as its own FSDP unit so that all-gather /
Expand Down
103 changes: 98 additions & 5 deletions specforge/core/dflash.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@

_VALID_LOSS_TYPES = {
"dflash",
"vp_drafter",
"dpace",
"dpace-cumulative-confidence-only",
"dpace-continuation-value-only",
}
_DPACE_LOSS_TYPES = _VALID_LOSS_TYPES - {"dflash"}
_DPACE_LOSS_TYPES = _VALID_LOSS_TYPES - {"dflash", "vp_drafter"}


def create_dflash_sdpa_mask(anchor_positions, block_keep_mask, S, block_size, device):
Expand Down Expand Up @@ -121,6 +122,7 @@ def __init__(
loss_decay_gamma: Optional[float] = None,
loss_type: str = "dflash",
dpace_alpha: float = 0.5,
prefix_weight_base: float = 0.9,
):
super().__init__()
if loss_type not in _VALID_LOSS_TYPES:
Expand All @@ -129,6 +131,12 @@ def __init__(
)
if not 0.0 <= dpace_alpha <= 1.0:
raise ValueError(f"dpace_alpha must be in [0, 1], got {dpace_alpha}")
if prefix_weight_base is None:
prefix_weight_base = 0.9
if prefix_weight_base <= 0.0:
raise ValueError(
f"prefix_weight_base must be positive, got {prefix_weight_base}"
)
Comment thread
catnanami marked this conversation as resolved.

self.draft_model = draft_model
self.lm_head = target_lm_head
Expand All @@ -140,6 +148,7 @@ def __init__(
self.loss_decay_gamma = loss_decay_gamma
self.loss_type = loss_type
self.dpace_alpha = dpace_alpha
self.prefix_weight_base = prefix_weight_base

self._cached_block_mask: Optional[BlockMask] = None
self._cached_seq_len: Optional[int] = None
Expand Down Expand Up @@ -183,6 +192,33 @@ def _sample_anchor_positions(

return anchors, keep_mask

def _sample_prefix_lengths(
self, bsz: int, n_blocks: int, device: torch.device
) -> torch.Tensor:
"""Sample visible prefix lengths for VP-Drafter training.

A prefix length i means block positions [0, i) are visible real tokens and
positions [i, block_size) are masked prediction targets. The sampled
range follows D2SD's variable-prefix recipe while avoiding the degenerate
fixed-anchor DFlash case.
"""
min_prefix = min(2, self.block_size - 1)
max_prefix = self.block_size - 1
if max_prefix <= min_prefix:
return torch.full(
(bsz, n_blocks), min_prefix, dtype=torch.long, device=device
)

prefix_ids = torch.arange(min_prefix, max_prefix + 1, device=device)
weights = torch.pow(
torch.full_like(prefix_ids, self.prefix_weight_base, dtype=torch.float32),
prefix_ids.float(),
)
samples = torch.multinomial(
weights, num_samples=bsz * n_blocks, replacement=True
).reshape(bsz, n_blocks)
return samples + min_prefix
Comment thread
catnanami marked this conversation as resolved.

def prepare_noise_input(
self, input_ids: torch.Tensor, block_ids: Optional[torch.Tensor] = None
) -> torch.Tensor:
Expand Down Expand Up @@ -235,6 +271,36 @@ def _create_noise_embed(self, input_ids, anchor_positions, block_keep_mask):

return self.embed_tokens(noise_ids)

def _create_vp_noise_embed(
self,
input_ids: torch.Tensor,
anchor_positions: torch.Tensor,
block_keep_mask: torch.Tensor,
prefix_lengths: torch.Tensor,
) -> torch.Tensor:
"""Prepare VP-Drafter inputs with variable visible prefixes."""
bsz, seq_len = input_ids.shape
n = anchor_positions.shape[1]
bs = self.block_size
device = input_ids.device

offsets = torch.arange(bs, device=device).view(1, 1, -1)
token_positions = anchor_positions.unsqueeze(-1) + offsets
safe_positions = token_positions.clamp(0, seq_len - 1)

real_tokens = torch.gather(
input_ids.unsqueeze(1).expand(-1, n, -1),
2,
safe_positions,
)
visible_prefix = offsets < prefix_lengths.unsqueeze(-1)
valid_positions = token_positions < seq_len
fill_mask = visible_prefix & block_keep_mask.unsqueeze(-1) & valid_positions

mask_tokens = torch.full_like(real_tokens, self.mask_token_id)
noise_ids = torch.where(fill_mask, real_tokens, mask_tokens)
return self.embed_tokens(noise_ids.reshape(bsz, n * bs))

def _dpace_weight(
self,
prob: torch.Tensor,
Expand Down Expand Up @@ -285,9 +351,18 @@ def forward(
seq_len, loss_mask, device
)

noise_embedding = self._create_noise_embed(
input_ids, anchor_positions, block_keep_mask
)
prefix_lengths = None
if self.loss_type == "vp_drafter":
prefix_lengths = self._sample_prefix_lengths(
bsz, anchor_positions.shape[1], device
)
noise_embedding = self._create_vp_noise_embed(
input_ids, anchor_positions, block_keep_mask, prefix_lengths
)
else:
noise_embedding = self._create_noise_embed(
input_ids, anchor_positions, block_keep_mask
)

context_position_ids = (
torch.arange(seq_len, device=device).unsqueeze(0).expand(bsz, -1)
Expand Down Expand Up @@ -340,7 +415,12 @@ def forward(
weight_mask = weight_mask * valid_label_mask.float()

pos_in_block = torch.arange(self.block_size, device=device).view(1, 1, -1)
weight_mask = weight_mask * (pos_in_block > 0).float()
if self.loss_type == "vp_drafter":
weight_mask = (
weight_mask * (pos_in_block >= prefix_lengths.unsqueeze(-1)).float()
)
else:
weight_mask = weight_mask * (pos_in_block > 0).float()

original_loss_mask_gathered = torch.gather(
loss_mask.unsqueeze(1).expand(-1, anchor_positions.size(1), -1),
Expand All @@ -367,6 +447,19 @@ def forward(
)
loss_weights = loss_weights * decay_weights

flat_weights = loss_weights.view(-1)
valid_token_count = flat_weights.sum() + 1e-6
loss = (loss_per_token * flat_weights).sum() / valid_token_count
elif self.loss_type == "vp_drafter":
loss_weights = weight_mask
if self.loss_decay_gamma is not None and self.loss_decay_gamma > 0:
k = torch.arange(self.block_size, device=device).view(1, 1, -1)
effective_pos = (
k.float() - prefix_lengths.unsqueeze(-1).float()
).clamp(min=0)
decay_weights = torch.exp(-effective_pos / self.loss_decay_gamma)
loss_weights = loss_weights * decay_weights

flat_weights = loss_weights.view(-1)
valid_token_count = flat_weights.sum() + 1e-6
loss = (loss_per_token * flat_weights).sum() / valid_token_count
Expand Down
Loading
Loading