From 11254820f3fe61d1180a45644ffbcc2c3449ae5d Mon Sep 17 00:00:00 2001 From: heiheiha798 <2300012738@stu.pku.edu.cn> Date: Thu, 18 Jun 2026 20:41:59 +0800 Subject: [PATCH] Expose flex_attention kernel options in DFlash and Domino training --- scripts/train_dflash.py | 19 +++++++++++++++++++ scripts/train_domino.py | 19 +++++++++++++++++++ specforge/core/dflash.py | 12 +++++++++++- specforge/core/domino.py | 10 ++++++++++ 4 files changed, 59 insertions(+), 1 deletion(-) diff --git a/scripts/train_dflash.py b/scripts/train_dflash.py index cc8531edc..e7fe94aac 100755 --- a/scripts/train_dflash.py +++ b/scripts/train_dflash.py @@ -4,6 +4,7 @@ import argparse import functools +import json import logging import math import os @@ -67,6 +68,12 @@ def parse_args(): choices=["eager", "sdpa", "flex_attention"], help="Attention backend for draft model.", ) + model_group.add_argument( + "--flex-kernel-options-json", + type=json.loads, + default=None, + help="JSON dict forwarded as kernel_options when attention-backend=flex_attention.", + ) model_group.add_argument( "--trust-remote-code", action="store_true", help="Trust remote code" ) @@ -375,6 +382,17 @@ def main(): ) args = parse_args() + flex_kernel_options = args.flex_kernel_options_json + if flex_kernel_options is not None: + if args.attention_backend != "flex_attention": + raise ValueError( + "--flex-kernel-options-json can only be used when " + "--attention-backend is 'flex_attention'." + ) + if not isinstance(flex_kernel_options, dict): + raise ValueError( + "--flex-kernel-options-json must decode to a JSON object." + ) set_seed(args.seed) init_distributed(timeout=args.dist_timeout, tp_size=args.tp_size) @@ -460,6 +478,7 @@ def main(): loss_decay_gamma=args.loss_decay_gamma, loss_type=args.loss_type, dpace_alpha=args.dpace_alpha, + flex_kernel_options=flex_kernel_options, ) # Wrap each transformer block as its own FSDP unit so that all-gather / diff --git a/scripts/train_domino.py b/scripts/train_domino.py index 98beae896..68eb50603 100755 --- a/scripts/train_domino.py +++ b/scripts/train_domino.py @@ -3,6 +3,7 @@ """Domino Training Script.""" import argparse +import json import logging import math import os @@ -64,6 +65,12 @@ def parse_args(): choices=["eager", "sdpa", "flex_attention"], help="Attention backend for draft model.", ) + model_group.add_argument( + "--flex-kernel-options-json", + type=json.loads, + default=None, + help="JSON dict forwarded as kernel_options when attention-backend=flex_attention.", + ) model_group.add_argument( "--trust-remote-code", action="store_true", help="Trust remote code" ) @@ -444,6 +451,17 @@ def main(): ) args = parse_args() + flex_kernel_options = args.flex_kernel_options_json + if flex_kernel_options is not None: + if args.attention_backend != "flex_attention": + raise ValueError( + "--flex-kernel-options-json can only be used when " + "--attention-backend is 'flex_attention'." + ) + if not isinstance(flex_kernel_options, dict): + raise ValueError( + "--flex-kernel-options-json must decode to a JSON object." + ) set_seed(args.seed) init_distributed(timeout=args.dist_timeout, tp_size=args.tp_size) @@ -528,6 +546,7 @@ def main(): num_anchors=args.num_anchors, loss_decay_gamma=args.loss_decay_gamma, shift_label=draft_model.shift_label, + flex_kernel_options=flex_kernel_options, ) domino_model = FSDP( diff --git a/specforge/core/dflash.py b/specforge/core/dflash.py index e97847917..301342723 100644 --- a/specforge/core/dflash.py +++ b/specforge/core/dflash.py @@ -1,7 +1,7 @@ # coding=utf-8 """DFlash Training Wrapper.""" -from typing import Optional, Tuple +from typing import Dict, Optional, Tuple import torch import torch.nn as nn @@ -117,6 +117,7 @@ def __init__( loss_decay_gamma: Optional[float] = None, loss_type: str = "dflash", dpace_alpha: float = 0.5, + flex_kernel_options: Optional[Dict] = None, ): super().__init__() if loss_type not in _VALID_LOSS_TYPES: @@ -136,6 +137,7 @@ def __init__( self.loss_decay_gamma = loss_decay_gamma self.loss_type = loss_type self.dpace_alpha = dpace_alpha + self.flex_kernel_options = flex_kernel_options self._cached_block_mask: Optional[BlockMask] = None self._cached_seq_len: Optional[int] = None @@ -304,11 +306,19 @@ def forward( device=device, ) + draft_forward_kwargs = {} + if ( + self.attention_backend == "flex_attention" + and self.flex_kernel_options is not None + ): + draft_forward_kwargs["kernel_options"] = self.flex_kernel_options + output_hidden = self.draft_model( position_ids=full_position_ids, noise_embedding=noise_embedding, target_hidden=hidden_states, attention_mask=dflash_attn_mask, + **draft_forward_kwargs, ) logits = self.lm_head(output_hidden) diff --git a/specforge/core/domino.py b/specforge/core/domino.py index 467a8fbef..1eb413cba 100644 --- a/specforge/core/domino.py +++ b/specforge/core/domino.py @@ -44,6 +44,7 @@ def __init__( num_anchors: int = 512, loss_decay_gamma: Optional[float] = None, shift_label: bool = False, + flex_kernel_options: Optional[Dict] = None, ): super().__init__() self.draft_model = draft_model @@ -55,6 +56,7 @@ def __init__( self.num_anchors = num_anchors self.loss_decay_gamma = loss_decay_gamma self.shift_label = shift_label + self.flex_kernel_options = flex_kernel_options self._cached_block_mask: Optional[BlockMask] = None self._cached_seq_len: Optional[int] = None @@ -332,11 +334,19 @@ def forward( device=device, ) + draft_forward_kwargs = {} + if ( + self.attention_backend == "flex_attention" + and self.flex_kernel_options is not None + ): + draft_forward_kwargs["kernel_options"] = self.flex_kernel_options + output_hidden = self.draft_model( position_ids=full_position_ids, noise_embedding=noise_embedding, target_hidden=hidden_states, attention_mask=dflash_attn_mask, + **draft_forward_kwargs, ) # --- Labels ---