Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions scripts/train_dflash.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import argparse
import functools
import json
import logging
import math
import os
Expand Down Expand Up @@ -67,6 +68,12 @@ def parse_args():
choices=["eager", "sdpa", "flex_attention"],
help="Attention backend for draft model.",
)
model_group.add_argument(
"--flex-kernel-options-json",
type=json.loads,
default=None,
help="JSON dict forwarded as kernel_options when attention-backend=flex_attention.",
)
model_group.add_argument(
"--trust-remote-code", action="store_true", help="Trust remote code"
)
Expand Down Expand Up @@ -375,6 +382,17 @@ def main():
)

args = parse_args()
flex_kernel_options = args.flex_kernel_options_json
if flex_kernel_options is not None:
if args.attention_backend != "flex_attention":
raise ValueError(
"--flex-kernel-options-json can only be used when "
"--attention-backend is 'flex_attention'."
)
if not isinstance(flex_kernel_options, dict):
raise ValueError(
"--flex-kernel-options-json must decode to a JSON object."
)
set_seed(args.seed)

init_distributed(timeout=args.dist_timeout, tp_size=args.tp_size)
Expand Down Expand Up @@ -460,6 +478,7 @@ def main():
loss_decay_gamma=args.loss_decay_gamma,
loss_type=args.loss_type,
dpace_alpha=args.dpace_alpha,
flex_kernel_options=flex_kernel_options,
)

# Wrap each transformer block as its own FSDP unit so that all-gather /
Expand Down
19 changes: 19 additions & 0 deletions scripts/train_domino.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""Domino Training Script."""

import argparse
import json
import logging
import math
import os
Expand Down Expand Up @@ -64,6 +65,12 @@ def parse_args():
choices=["eager", "sdpa", "flex_attention"],
help="Attention backend for draft model.",
)
model_group.add_argument(
"--flex-kernel-options-json",
type=json.loads,
default=None,
help="JSON dict forwarded as kernel_options when attention-backend=flex_attention.",
)
model_group.add_argument(
"--trust-remote-code", action="store_true", help="Trust remote code"
)
Expand Down Expand Up @@ -444,6 +451,17 @@ def main():
)

args = parse_args()
flex_kernel_options = args.flex_kernel_options_json
if flex_kernel_options is not None:
if args.attention_backend != "flex_attention":
raise ValueError(
"--flex-kernel-options-json can only be used when "
"--attention-backend is 'flex_attention'."
)
if not isinstance(flex_kernel_options, dict):
raise ValueError(
"--flex-kernel-options-json must decode to a JSON object."
)
set_seed(args.seed)

init_distributed(timeout=args.dist_timeout, tp_size=args.tp_size)
Expand Down Expand Up @@ -528,6 +546,7 @@ def main():
num_anchors=args.num_anchors,
loss_decay_gamma=args.loss_decay_gamma,
shift_label=draft_model.shift_label,
flex_kernel_options=flex_kernel_options,
)

domino_model = FSDP(
Expand Down
12 changes: 11 additions & 1 deletion specforge/core/dflash.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# coding=utf-8
"""DFlash Training Wrapper."""

from typing import Optional, Tuple
from typing import Dict, Optional, Tuple

import torch
import torch.nn as nn
Expand Down Expand Up @@ -117,6 +117,7 @@ def __init__(
loss_decay_gamma: Optional[float] = None,
loss_type: str = "dflash",
dpace_alpha: float = 0.5,
flex_kernel_options: Optional[Dict] = None,
):
super().__init__()
if loss_type not in _VALID_LOSS_TYPES:
Expand All @@ -136,6 +137,7 @@ def __init__(
self.loss_decay_gamma = loss_decay_gamma
self.loss_type = loss_type
self.dpace_alpha = dpace_alpha
self.flex_kernel_options = flex_kernel_options

self._cached_block_mask: Optional[BlockMask] = None
self._cached_seq_len: Optional[int] = None
Expand Down Expand Up @@ -304,11 +306,19 @@ def forward(
device=device,
)

draft_forward_kwargs = {}
if (
self.attention_backend == "flex_attention"
and self.flex_kernel_options is not None
):
draft_forward_kwargs["kernel_options"] = self.flex_kernel_options

output_hidden = self.draft_model(
position_ids=full_position_ids,
noise_embedding=noise_embedding,
target_hidden=hidden_states,
attention_mask=dflash_attn_mask,
**draft_forward_kwargs,
)
Comment on lines 318 to 322

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Passing kernel_options directly to self.draft_model will cause a TypeError when using other attention backends like sdpa or eager (which are valid choices in the CLI), because their underlying attention functions do not accept kernel_options as a parameter. To prevent this, we should only pass kernel_options when the attention backend is set to flex_attention and flex_kernel_options is provided.

Suggested change
noise_embedding=noise_embedding,
target_hidden=hidden_states,
attention_mask=dflash_attn_mask,
kernel_options=self.flex_kernel_options,
)
noise_embedding=noise_embedding,
target_hidden=hidden_states,
attention_mask=dflash_attn_mask,
**({'kernel_options': self.flex_kernel_options} if (self.attention_backend == 'flex_attention' and self.flex_kernel_options is not None) else {}),
)


logits = self.lm_head(output_hidden)
Expand Down
10 changes: 10 additions & 0 deletions specforge/core/domino.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
num_anchors: int = 512,
loss_decay_gamma: Optional[float] = None,
shift_label: bool = False,
flex_kernel_options: Optional[Dict] = None,
):
super().__init__()
self.draft_model = draft_model
Expand All @@ -55,6 +56,7 @@ def __init__(
self.num_anchors = num_anchors
self.loss_decay_gamma = loss_decay_gamma
self.shift_label = shift_label
self.flex_kernel_options = flex_kernel_options

self._cached_block_mask: Optional[BlockMask] = None
self._cached_seq_len: Optional[int] = None
Expand Down Expand Up @@ -332,11 +334,19 @@ def forward(
device=device,
)

draft_forward_kwargs = {}
if (
self.attention_backend == "flex_attention"
and self.flex_kernel_options is not None
):
draft_forward_kwargs["kernel_options"] = self.flex_kernel_options

output_hidden = self.draft_model(
position_ids=full_position_ids,
noise_embedding=noise_embedding,
target_hidden=hidden_states,
attention_mask=dflash_attn_mask,
**draft_forward_kwargs,
)
Comment on lines 346 to 350

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Passing kernel_options directly to self.draft_model will cause a TypeError when using other attention backends like sdpa or eager (which are valid choices in the CLI), because their underlying attention functions do not accept kernel_options as a parameter. To prevent this, we should only pass kernel_options when the attention backend is set to flex_attention and flex_kernel_options is provided.

Suggested change
noise_embedding=noise_embedding,
target_hidden=hidden_states,
attention_mask=dflash_attn_mask,
kernel_options=self.flex_kernel_options,
)
noise_embedding=noise_embedding,
target_hidden=hidden_states,
attention_mask=dflash_attn_mask,
**({'kernel_options': self.flex_kernel_options} if (self.attention_backend == 'flex_attention' and self.flex_kernel_options is not None) else {}),
)


# --- Labels ---
Expand Down