forked from daydreamlive/DEMON
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathstream.py
More file actions
1637 lines (1447 loc) · 71.3 KB
/
stream.py
File metadata and controls
1637 lines (1447 loc) · 71.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""StreamDiffusion-style pipeline for interactive ACE-Step generation.
Maintains a ring buffer of in-flight generations at different denoising
stages. Each tick(), one batched forward pass advances all slots. After
warmup, every tick produces a finished generation.
Supports per-slot denoise and source_latents for cover workflows where
the user adjusts the denoise knob in real time. When a TRT engine is
loaded on the DiffusionEngine, tick() routes through TensorRT
automatically.
"""
from __future__ import annotations
import time
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Callable, Dict, NamedTuple, Optional, List, Tuple, TYPE_CHECKING
from loguru import logger
import torch
from .diffusion import DiffusionConfig, DiffusionEngine
from .model_adapter import ACEAdapter, ModelAdapter
from . import ode_steps
from .dcw import DCWAdvanced, DCWCorrector
if TYPE_CHECKING:
from .masking import LatentNoiseMask
class _SteeringApply(NamedTuple):
    """One pre-resolved activation-steering shift bound to a layer.

    Instances are grouped per layer index in
    ``StreamPipeline._steering_by_layer``. Field order is the positional
    constructor signature of this NamedTuple, so keep it stable.
    """
    vector: torch.Tensor  # 1-D [hidden_dim]
    scale: float  # alpha * magnitude
    step: int  # gate: only rows at this denoise step receive it
@dataclass
class SlotCondition:
    """A single conditioning entry used for per-frame multi-condition blending.

    Every ``SlotRequest`` owns one primary condition, assembled from its
    ``encoder_hidden_states`` / ``encoder_attention_mask`` /
    ``primary_temporal_weight`` / ``primary_step_range`` fields, and may
    add more through ``extra_conditions``. On each step the decoder runs
    once per active condition, and the resulting velocities are mixed
    frame-by-frame according to ``temporal_weight`` (see
    ``ode_steps.blend_velocities``).
    """

    encoder_hidden_states: torch.Tensor  # [1, L, D]
    encoder_attention_mask: torch.Tensor  # [1, L]
    temporal_weight: Optional[torch.Tensor] = None  # [T], [1,T], or [1,T,1]
    step_range: Optional[Tuple[float, float]] = None  # (start_frac, end_frac)

    def is_active_at_step(self, step_idx: int, total_steps: int) -> bool:
        """Report whether this condition participates at ``step_idx``.

        With no ``step_range``, or a non-positive ``total_steps``, the
        condition is always active. Otherwise the progress fraction
        ``step_idx / total_steps`` must fall inside the half-open window
        ``[start_frac, end_frac)``.
        """
        window = self.step_range
        if window is None or total_steps <= 0:
            return True
        lower, upper = window[0], window[1]
        return lower <= step_idx / total_steps < upper
@dataclass
class SlotRequest:
    """One generation job handed to the pipeline.

    Carries the conditioning tensors plus the noise seed. Every request
    submitted to a given pipeline must use the same sequence length T
    (duration).

    The optional fields toward the end (``x0_target_curve``,
    ``x0_target_gate``, ``initial_noise_curve``, ``latent_mask``,
    ``extra_conditions``, ``primary_temporal_weight``,
    ``primary_step_range``) let ``StreamPipeline`` act as the single
    diffusion primitive for both streaming and one-shot generation
    (Phase 1 of the diffusion-primitive unification).

    The three ACE-shaped conditioning tensors are Optional because of the
    ModelAdapter seam: non-ACE families put their conditioning into the
    opaque ``aux_cond`` bundle and state T through ``latent_frames``
    (see :mod:`acestep.engine.model_adapter`). ACE requests still fill
    in all three as before.
    """

    encoder_hidden_states: Optional[torch.Tensor] = None  # [1, L, D]
    encoder_attention_mask: Optional[torch.Tensor] = None  # [1, L]
    context_latents: Optional[torch.Tensor] = None  # [1, T, D_ctx]
    # A single int (one noise draw shared by the request) or a list of
    # ints (one seed per row, the old _prepare_noise_cpu contract used by
    # DiffusionEngine.generate()).
    seed: Optional["int | List[int]"] = None
    source_latents: Optional[torch.Tensor] = None  # [1, T, D] for cover
    denoise: float = 1.0  # per-request denoise strength
    sde_denoise_curve: Optional[torch.Tensor] = None  # [1, T, 1] per-frame denoise
    velocity_scale: Optional[torch.Tensor] = None  # [1, T, 1] per-frame velocity scaling
    ode_noise_curve: Optional[torch.Tensor] = None  # [1, T, 1] per-step noise injection
    x0_target: Optional[torch.Tensor] = None  # [1, T, D] target latent for blending
    # Blend strength toward x0_target: scalar or per-frame curve. Both are
    # passed through normalize_curve where they are read, so the engine
    # always sees a [B, T, 1] tensor. Hot-mutable via
    # set_shared_curve("x0_target_strength", value).
    x0_target_strength: "float | torch.Tensor" = 0.0
    # --- Phase 1: features absorbed from one-shot generate() ---
    x0_target_curve: Optional[torch.Tensor] = None  # per-frame blend curve [T], [1,T], or [1,T,1]
    x0_target_gate: float = 0.0  # gate-start fraction (matches DiffusionConfig default)
    initial_noise_curve: Optional[torch.Tensor] = None  # per-frame noise/source init mix
    latent_mask: Optional["LatentNoiseMask"] = None  # inpainting (two-sided x0 blend)
    extra_conditions: List[SlotCondition] = field(default_factory=list)
    primary_temporal_weight: Optional[torch.Tensor] = None
    primary_step_range: Optional[Tuple[float, float]] = None
    # --- CFG (Phase 2) ---
    # Negative conditions for classifier-free guidance. Together with
    # ``guidance_curve`` they trigger a second, negative forward pass per
    # step, and APG (:func:`ode_steps.apg_forward`) blends the two
    # velocities per frame. An empty list turns CFG off.
    neg_conditions: List[SlotCondition] = field(default_factory=list)
    guidance_curve: Optional[torch.Tensor] = None  # [T], [1,T], or [1,T,1]
    # APG momentum coefficient, scalar or per-frame curve. Both go through
    # normalize_curve at the apg_forward boundary so the MomentumBuffer
    # update gets a uniform tensor. Hot-mutable via
    # set_shared_curve("apg_momentum", value) on the pipeline.
    apg_momentum: "float | torch.Tensor" = -0.75
    # --- RCFG (Residual CFG, after StreamDiffusion §3.2) ---
    # Avoids the per-step uncond forward that standard CFG needs.
    #   None / "full" : classic two-pass CFG, with a negative forward on
    #                   every step.
    #   "initialize"  : run the uncond pass once, at step 0 of each slot,
    #                   cache its velocity and reuse it as the negative
    #                   for the slot's remaining steps.
    #   "self"        : no uncond forward at all. ``v_uncond`` is
    #                   approximated by the slot's initial noise (flow
    #                   matching gives ``v = noise - x0``, and with the
    #                   prior x0_uncond ~ 0, v_uncond ~ noise).
    # ``guidance_curve`` is still required (it sets the APG scale).
    # "self" ignores ``neg_conditions``; "initialize" uses them for its
    # single uncond pass exactly as full CFG would.
    rcfg_mode: Optional[str] = None
    # Pulls the per-frame norm of the APG output back toward the norm of
    # ``vt_pos`` (Lin et al., "Common Diffusion Noise Schedules and Sample
    # Steps are Flawed"). ``None`` disables it. Otherwise a scalar or a
    # ``[1, T, 1]`` curve in [0, 1]: 0 keeps the raw APG output, 1 fully
    # restores the ``vt_pos`` magnitude.
    cfg_rescale_curve: "Optional[float | torch.Tensor]" = None
    # --- Tier 2 (ModelAdapter seam) ---
    # Opaque conditioning bundle for non-ACE families. Only the family's
    # adapter reads it, never the pipeline (one bundle per request,
    # single-condition v1, per the canonical SA3 plan §5). ACE requests
    # leave it as None.
    aux_cond: Optional[dict] = None
    # Latent frame count T for families that lack ``context_latents``
    # (the historical source of T). Ignored when ``context_latents`` is
    # set; the adapter's ``request_frames`` makes the call.
    latent_frames: Optional[int] = None

    def all_conditions(self) -> List[SlotCondition]:
        """Return the primary condition followed by any extra conditions.

        The primary entry is assembled on each call from the request's
        primary fields. The result is always a fresh list.
        """
        ordered = [
            SlotCondition(
                encoder_hidden_states=self.encoder_hidden_states,
                encoder_attention_mask=self.encoder_attention_mask,
                temporal_weight=self.primary_temporal_weight,
                step_range=self.primary_step_range,
            )
        ]
        if self.extra_conditions:
            ordered.extend(self.extra_conditions)
        return ordered

    @property
    def has_cfg(self) -> bool:
        """Whether APG guidance should be applied on every step.

        Guidance always needs ``guidance_curve``. Beyond that, RCFG
        ``"self"`` mode needs nothing else (its virtual negative is the
        slot's initial noise). Every other mode, standard CFG or RCFG
        ``"initialize"``, needs a non-empty ``neg_conditions``.
        """
        if self.guidance_curve is None:
            return False
        return self.rcfg_mode == "self" or bool(self.neg_conditions)

    def needs_neg_forward(self, step_idx: int) -> bool:
        """Whether ``step_idx`` must run the uncond (negative) forward pass.

        The pass is never needed without CFG, or in RCFG ``"self"`` mode
        (virtual negative). RCFG ``"initialize"`` needs it only at step 0,
        and later steps reuse the cached velocity. ``None`` / ``"full"``
        needs it on every step.
        """
        if not self.has_cfg or self.rcfg_mode == "self":
            return False
        if self.rcfg_mode == "initialize":
            return step_idx == 0
        return True
@dataclass
class _Slot:
    """Internal state for one pipeline slot.

    Created by ``StreamPipeline._init_slot`` with ``step_idx == 0``.
    Field order is the dataclass's positional constructor; keep it stable.
    """
    request: SlotRequest  # the request this slot is generating
    # [1, T, D] current noisy latent (a list ``seed`` in the request yields
    # one row per seed, see ``StreamPipeline._make_noise``)
    xt: torch.Tensor
    t_schedule: torch.Tensor  # per-slot timestep schedule (on CPU)
    step_idx: int = 0  # which denoising step we're on (0-indexed)
    # APG momentum accumulator, one per slot with CFG. None for slots
    # without CFG (cheaper than allocating an unused buffer).
    momentum_buffer: Optional[ode_steps.MomentumBuffer] = None
    # RCFG state. ``initial_noise`` is captured at slot init and used as
    # the virtual ``v_uncond`` for ``rcfg_mode == 'self'``. ``vt_neg_cached``
    # holds the cached uncond velocity for ``rcfg_mode == 'initialize'`` —
    # populated on step 0, reused on every subsequent step of this slot.
    initial_noise: Optional[torch.Tensor] = None
    vt_neg_cached: Optional[torch.Tensor] = None
class StreamPipeline:
"""StreamDiffusion-style batched denoising pipeline.
Pipeline depth = number of denoising steps. After warmup (depth
ticks), every tick() returns a finished latent.
Each slot carries its own timestep schedule derived from its
denoise value, so the user can change denoise between submissions
and each in-flight generation uses the schedule it was born with.
When the DiffusionEngine has a TRT engine loaded, tick() uses
TensorRT for the batched forward pass automatically.
This is a low-level primitive. The ``StreamDenoise`` node in
``acestep.nodes.diffusion_nodes`` is the canonical way to drive
it — there should be exactly one call site constructing this
class anywhere in ``acestep/``.
"""
def __init__(
    self,
    engine: Optional[DiffusionEngine],
    config: DiffusionConfig,
    pipeline_depth: Optional[int] = None,
    adapter: Optional[ModelAdapter] = None,
):
    """Build the ring buffer, the TRT snapshot and the hot-path state.

    ``adapter`` selects the model family behind the Tier-2 seam
    (:mod:`acestep.engine.model_adapter`). ``None`` (every existing
    call site) builds the default :class:`ACEAdapter` over
    ``engine``. ``engine`` may be ``None`` only when a non-default
    adapter owns its model end-to-end (e.g. SA3); the TRT dispatch
    state below is ACE machinery and stays engine-bound.

    Args:
        engine: Engine providing ``decoder``/``model`` and the TRT
            handles, or ``None`` for engine-less adapter families.
        config: Diffusion settings. ``infer_steps`` is the default
            depth; the ``dcw_*`` fields configure the DCW corrector.
        pipeline_depth: Ring-buffer depth. Defaults to
            ``config.infer_steps`` (classic StreamDiffusion).
        adapter: Model-family adapter; ``None`` means ``ACEAdapter``.

    Raises:
        ValueError: If both ``engine`` and ``adapter`` are ``None``.
    """
    # Validate before touching any state. An engine-less, adapter-less
    # call must fail with this ValueError rather than first constructing
    # the default ACEAdapter over a pipeline that has no engine.
    if engine is None and adapter is None:
        raise ValueError(
            "StreamPipeline needs a DiffusionEngine for the default "
            "ACE adapter; pass adapter= for engine-less families"
        )
    self.engine = engine
    self.decoder = engine.decoder if engine is not None else None
    self.model = engine.model if engine is not None else None
    self.config = config
    self.adapter: ModelAdapter = (
        adapter if adapter is not None else ACEAdapter(self)
    )
    # Decouple ring buffer depth from denoising step count.
    # Default: depth = infer_steps (classic StreamDiffusion).
    self._depth: int = pipeline_depth if pipeline_depth is not None else config.infer_steps
    # Pipeline state
    self._slots: List[Optional[_Slot]] = [None] * self._depth
    self._queue: List[SlotRequest] = []
    # Cached device/dtype (set on first submit)
    self._device: Optional[torch.device] = None
    self._dtype: Optional[torch.dtype] = None
    # Schedule cache: denoise -> cpu tensor
    self._schedule_cache: dict[float, torch.Tensor] = {}
    # TRT state (mirrors DiffusionEngine pattern). Snapshotted from
    # the engine here, refreshed on profile swaps via the
    # engine-swap listener registered below. Engine-less pipelines
    # (non-ACE adapters) carry the null snapshot — their adapter
    # never dispatches TRT through the pipeline.
    self._trt_ctx = engine._trt_ctx if engine is not None else None
    self._trt_stream = engine._trt_stream if engine is not None else None
    self._trt_engine = engine._trt_engine if engine is not None else None
    self._trt_io_dtype = getattr(engine, '_trt_io_dtype', torch.float32)
    self._trt_input_dtypes = getattr(engine, "_trt_input_dtypes", {}) or {}
    self._trt_output_dtype = getattr(engine, "_trt_output_dtype", self._trt_io_dtype)
    # Steering shape (constants per engine); snapshotted so the
    # per-tick buffer fill doesn't re-query TRT.
    self._steering_num_layers = getattr(engine, "_steering_num_layers", 0)
    self._steering_hidden_size = getattr(engine, "_steering_hidden_size", 0)
    # Currently-bound TRT I/O buffers (set by _ensure_trt_bufs to one
    # entry of _trt_bufs_cache). _trt_forward reads these directly.
    self._trt_bufs: Optional[dict] = None
    self._trt_out_buf: Optional[torch.Tensor] = None
    # LRU cache of (B, eff_T, max_L) -> {bufs..., "_out_buf": tensor}.
    # CFG/RCFG passes alternate between pos and neg encoder lengths
    # (e.g. L=83 pos, L=66 empty-prompt neg), so a single-shape cache
    # thrashes on every forward. 4 entries comfortably covers
    # {pos, neg} × {two T values} during T transitions.
    self._trt_bufs_cache: "OrderedDict[tuple, dict]" = OrderedDict()
    self._trt_bufs_cache_max = 4
    # Re-pick up the snapshot after each profile swap. The new
    # engine has different I/O profile bounds and its execution
    # context owns different tensor addresses, so the previous
    # ``_trt_bufs`` cache is invalidated and rebuilt on the next
    # forward pass via :meth:`_ensure_trt_bufs`.
    if hasattr(engine, "add_engine_swap_listener"):
        engine.add_engine_swap_listener(self._on_engine_swapped)
    # Shared mutable curves: when a name is present, the corresponding
    # per-slot field on every in-flight SlotRequest is overridden for
    # that slot's next tick. Bypasses the ring-buffer drain, giving
    # 1-tick latency. Invariant: every value is a normalized
    # ``[1, T, 1]`` tensor (scalar-or-per-frame multiplier broadcast
    # against ``[B, T, *]`` operands). Floats auto-lift to ``[1, 1, 1]``
    # at the setter so callers can pass scalars without thinking
    # about shape.
    self._shared_curves: dict[str, torch.Tensor] = {}
    # Channel guidance: a ``[1, T, 64]`` per-channel gain applied to
    # ``xt`` before each forward pass. Lives in its own field rather
    # than ``_shared_curves`` because its shape (per-channel) breaks
    # the dict's per-frame invariant. Updated via
    # :meth:`set_channel_guidance` / :meth:`set_channel_gain_tensor`,
    # pre-cast to the pipeline's device/dtype so the hot-path
    # ``.to(...)`` is a no-op.
    self._channel_gain: Optional[torch.Tensor] = None
    # Activation steering. Per-DiT-layer additive shift on the
    # post-block residual, gated per-row by denoise step.
    # ``_current_step_per_row`` is populated by _tick_complex_pt
    # around each forward; empty means "skip injection" so a
    # forward issued outside the rendezvous can't fire steering.
    self._steering_by_layer: Dict[int, List[_SteeringApply]] = {}
    self._steering_hooks_installed: bool = False
    self._current_step_per_row: List[int] = []
    # Sentinel tensors for the "always-on multiply" idiom in the step
    # helpers. Built lazily once the first slot's device/dtype is known.
    # ``_ones_3d`` stands in for absent ``velocity_scale`` (vt * 1 = vt).
    # ``_zeros_3d`` stands in for absent ``ode_noise_curve`` (noise * 0 = 0).
    # The sentinels keep the compiled ODE/SDE step graphs branch-free.
    self._ones_3d: Optional[torch.Tensor] = None
    self._zeros_3d: Optional[torch.Tensor] = None
    # Compiled per-step helper cache. Populated lazily by
    # ``_get_compiled`` on first use so we don't pay the
    # ``torch.compile`` warmup on pipelines that never run their
    # PyTorch path (TRT-only streams). Gated on engine.compile_loops —
    # when the engine was constructed with ``compile_loops=False``,
    # primitives run eagerly (still branch-free, just not compiled).
    self._compile_loops: bool = (
        getattr(engine, "_compile_loops", True) if engine is not None
        else False  # engine-less families opt into compile later
    )
    self._compiled_cache: dict[Callable, Callable] = {}
    # DCW (Differential Correction in Wavelet domain) — post-step
    # sampler correction. Always constructed; ``is_active`` short-
    # circuits in the hot path when disabled. Hot-updatable via
    # :meth:`set_dcw` without rebuilding the pipeline.
    self._dcw_corrector: DCWCorrector = DCWCorrector(
        enabled=config.dcw_enabled,
        mode=config.dcw_mode,
        scaler=config.dcw_scaler,
        high_scaler=config.dcw_high_scaler,
        wavelet=config.dcw_wavelet,
        advanced=config.dcw_advanced,
    )
    # Stats
    self.ticks: int = 0
    self._last_tick_ms: float = 0.0
@property
def depth(self) -> int:
    """Ring-buffer depth: the number of slots that can hold in-flight work."""
    slot_count: int = self._depth
    return slot_count
@property
def active_slots(self) -> int:
    """How many ring-buffer slots currently hold a generation (non-None)."""
    occupied = [slot for slot in self._slots if slot is not None]
    return len(occupied)
@property
def is_warmed_up(self) -> bool:
    """Steady-state check: no ring-buffer slot is empty.

    Vacuously true for a zero-depth pipeline.
    """
    return not any(slot is None for slot in self._slots)
@property
def has_trt(self) -> bool:
    """Whether a TensorRT engine is currently bound (snapshot on this pipeline)."""
    bound_engine = self._trt_engine
    return bound_engine is not None
def _on_engine_swapped(self) -> None:
    """Re-snapshot TRT refs and drop the stale buffer cache.

    Wired up in __init__ via ``engine.add_engine_swap_listener``.
    Fires on the runner thread because the profile manager runs the
    swap inside the streaming pipeline's ``before_tick`` rendezvous,
    so no concurrent ``tick()`` can be reading these fields.

    ``_trt_bufs`` is invalidated rather than reallocated: the next
    forward pass calls :meth:`_ensure_trt_bufs` with the live
    ``(B, T, max_L)`` shape, which binds against the new engine's
    profile and only allocates if the shape actually differs from
    the now-discarded one.
    """
    engine = self.engine
    self._trt_engine = engine._trt_engine
    self._trt_ctx = engine._trt_ctx
    self._trt_stream = engine._trt_stream
    self._trt_io_dtype = getattr(engine, "_trt_io_dtype", torch.float32)
    # ``or {}`` matches the snapshot taken in __init__: an engine that
    # reports ``None`` must still leave a dict here after a swap.
    self._trt_input_dtypes = getattr(engine, "_trt_input_dtypes", {}) or {}
    self._trt_output_dtype = getattr(
        engine, "_trt_output_dtype", self._trt_io_dtype
    )
    self._steering_num_layers = getattr(engine, "_steering_num_layers", 0)
    self._steering_hidden_size = getattr(engine, "_steering_hidden_size", 0)
    self._trt_bufs = None
    self._trt_out_buf = None
    self._trt_bufs_cache.clear()
    # Hooks live on the old decoder.layers; the new decoder needs a
    # fresh install on the next non-empty set_steering call.
    self._steering_hooks_installed = False
def submit(self, request: SlotRequest) -> None:
    """Enqueue a generation request.

    The queue is capped at ``_depth`` items. When the caller submits
    faster than the pipeline consumes (always the case when
    depth < infer_steps), the oldest queued requests are dropped so
    that fresh parameters reach the ring buffer promptly instead of
    sitting behind an ever-growing backlog of stale requests.

    Args:
        request: The request to append to the back of the queue.
    """
    queue = self._queue
    # Evict from the front until there is room for one more entry.
    # ``while`` (not ``if``) restores the cap even if the queue was
    # already over it. The emptiness check stops a degenerate
    # ``_depth <= 0`` pipeline from calling ``pop(0)`` on an empty list
    # (IndexError); such a pipeline just keeps the newest request.
    while queue and len(queue) >= self._depth:
        queue.pop(0)
    queue.append(request)
def _get_schedule(self, denoise: float, *, max_entries: int = 256) -> torch.Tensor:
    """Get (cached) timestep schedule for a given denoise value.

    Schedule construction is family knowledge (ACE flow-matching
    ``shift`` warp vs SA3 LogSNR warp), so it lives on the adapter;
    the cache stays here.

    The cache is bounded. Denoise is a real-time knob, so an unbounded
    float-keyed cache would keep one tensor per distinct knob position
    forever. Once ``max_entries`` is reached, the oldest-inserted entry
    is evicted (FIFO, using dict insertion order). Evicted schedules are
    simply rebuilt on the next miss. Slots keep their own reference to
    the tensor they were created with, so eviction never affects
    in-flight generations.

    Args:
        denoise: Denoise strength used as the cache key.
        max_entries: Upper bound on cached schedules (keyword-only).

    Returns:
        The schedule tensor, moved to CPU.
    """
    cache = self._schedule_cache
    schedule = cache.get(denoise)
    if schedule is None:
        schedule = self.adapter.build_schedule(
            self.config, denoise, self._device, self._dtype
        ).cpu()
        while cache and len(cache) >= max_entries:
            del cache[next(iter(cache))]
        cache[denoise] = schedule
    return schedule
def _ensure_device(self, device: torch.device, dtype: torch.dtype):
    """Record the pipeline's device/dtype the first time they become known.

    Only the first call has an effect. Once ``_device`` is set, later
    calls leave both ``_device`` and ``_dtype`` untouched.
    """
    if self._device is not None:
        return
    self._device, self._dtype = device, dtype
def _make_noise(self, request: SlotRequest) -> torch.Tensor:
    """Draw the initial noise tensor for ``request``.

    ``request.seed`` follows the old ``DiffusionEngine._prepare_noise_cpu``
    contract:

    - ``None``      → draw from the current global RNG state
    - ``int``       → ``torch.manual_seed`` once, then a single draw
    - ``List[int]`` → one draw per entry, stacked along dim 0. Entries
                      that are ``None`` or negative skip reseeding and
                      continue from the current RNG state.

    With ``config.noise_on_cpu`` the draw mirrors ComfyUI's RandomNoise
    node: float32 on CPU in ``[B, D, T]``, then transposed to
    ``[B, T, D]``. In every case the result ends on the pipeline's
    device/dtype. T and D come from the adapter (``request_frames`` /
    ``latent_channels``), the Tier-2 replacement for the historical
    ``context_latents.shape[-1] // 2`` ACE convention.
    """
    frames = self.adapter.request_frames(request)
    channels = self.adapter.latent_channels
    seed = request.seed
    on_cpu = self.config.noise_on_cpu

    draw_shape = (1, channels, frames) if on_cpu else (1, frames, channels)
    draw_device = "cpu" if on_cpu else self._device
    draw_dtype = torch.float32 if on_cpu else self._dtype

    def _draw() -> torch.Tensor:
        return torch.randn(*draw_shape, device=draw_device, dtype=draw_dtype)

    if isinstance(seed, list):
        pieces = []
        for row_seed in seed:
            # Non-negative seeds reseed; None/negative keep the RNG rolling.
            if row_seed is not None and row_seed >= 0:
                torch.manual_seed(int(row_seed))
            pieces.append(_draw())
        noise = torch.cat(pieces, dim=0)
    else:
        if seed is not None:
            torch.manual_seed(int(seed))
        noise = _draw()

    if on_cpu:
        # [B, D, T] -> [B, T, D]
        noise = noise.movedim(-1, -2)
    return noise.to(device=self._device, dtype=self._dtype)
def _init_slot(self, request: SlotRequest) -> _Slot:
    """Materialize a fresh ``_Slot`` for ``request`` at step 0.

    Latches the pipeline device/dtype, fetches the request's schedule,
    draws noise and builds the starting latent ``xt``:

    - explicit ``initial_noise_curve`` + a clean latent: per-frame mix
      ``curve * noise + (1 - curve) * clean`` (1.0 = pure noise, 0.0 =
      pure source), ignoring the schedule's starting sigma;
    - clean latent with ``denoise < 1.0``: mix at the schedule's first
      timestep;
    - otherwise: a copy of the noise.

    The clean latent is ``source_latents``, or, for inpainting flows,
    the ``latent_mask.original_latents`` fallback the old generate()
    used when no explicit source was supplied.
    """
    self._ensure_device(*self.adapter.request_device_dtype(request))
    schedule = self._get_schedule(request.denoise)
    noise = self._make_noise(request)
    sigma_start = schedule[0].item()

    clean = request.source_latents
    if clean is None and request.latent_mask is not None:
        clean = request.latent_mask.original_latents

    if clean is None:
        start_latent = noise.clone()
    elif request.initial_noise_curve is not None:
        mix = ode_steps.normalize_curve(request.initial_noise_curve).to(
            device=self._device, dtype=self._dtype,
        )
        start_latent = mix * noise + (1.0 - mix) * clean
    elif request.denoise < 1.0:
        start_latent = sigma_start * noise + (1.0 - sigma_start) * clean
    else:
        start_latent = noise.clone()

    # APG momentum is only worth allocating when the request uses CFG.
    momentum = ode_steps.MomentumBuffer() if request.has_cfg else None
    # RCFG-self treats the slot's initial noise as the virtual v_uncond
    # for its whole schedule; no other mode reads this field.
    virtual_uncond = noise.clone() if request.rcfg_mode == "self" else None

    return _Slot(
        request=request,
        xt=start_latent,
        t_schedule=schedule,
        step_idx=0,
        momentum_buffer=momentum,
        initial_noise=virtual_uncond,
    )
# ------------------------------------------------------------------
# Sentinel tensors + compiled step helpers (PyTorch backend)
# ------------------------------------------------------------------
def _ensure_sentinels(self) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the ``(ones, zeros)`` broadcast sentinels, creating them once.

    Both are ``[1, 1, 1]`` tensors on the pipeline's device/dtype. That
    state is known after the first ``submit``/``generate``. The tensors
    broadcast against ``[1, T, D]`` operands, so the step helpers can
    multiply unconditionally without per-step allocations. They are
    built on the first call and the same objects are returned afterwards.
    """
    if self._ones_3d is None:
        placement = {"device": self._device, "dtype": self._dtype}
        self._ones_3d = torch.ones(1, 1, 1, **placement)
        self._zeros_3d = torch.zeros(1, 1, 1, **placement)
    return self._ones_3d, self._zeros_3d
# Compiled wrappers over the pure step primitives in ``ode_steps``.
# The primitives themselves are branch-free and reference no ``self``,
# so ``torch.compile`` can trace each into a single fused graph. The
# sentinel-tensor idiom (``ones_3d`` for absent velocity_scale,
# ``zeros_3d`` for absent ode_noise_curve) keeps the ODE graph flat
# without ``is None`` branches — the multiply is a byte-identical
# no-op but lets the compiler specialize one straight-line kernel.
def _get_compiled(self, fn: Callable) -> Callable:
    """Return a (possibly compiled) wrapper around a pure step primitive.

    Compilation is lazy — we only pay the inductor cost on PT
    pipelines that actually exercise the primitive. ``dynamic=True``
    lets the graph accept varying T / dtype without re-tracing on
    every shape change. Results are memoized per primitive on
    ``self._compiled_cache``. When ``self._compile_loops`` is false
    the primitive itself is cached and returned (eager execution).

    NOTE(review): ``torch.compile`` defers real compilation to the first
    call of the returned wrapper. The ``except`` below therefore only
    catches errors raised while the wrapper is being set up, not
    failures during tracing. Confirm whether a call-time fallback is
    wanted.

    Args:
        fn: Step primitive to wrap; used as the cache key.

    Returns:
        The cached wrapper, a freshly compiled wrapper, or ``fn`` itself.
    """
    cache = self._compiled_cache
    cached = cache.get(fn)
    if cached is not None:
        return cached
    compiled = fn
    if self._compile_loops:
        try:
            compiled = torch.compile(fn, backend="inductor", dynamic=True)
        except Exception as e:  # pragma: no cover - fallback path
            # ``fn`` may be a functools.partial or a callable object with
            # no ``__name__``. Don't let the log call itself raise and
            # mask the compile error we're reporting.
            logger.warning(
                "torch.compile({}) failed ({}); falling back to eager",
                getattr(fn, "__name__", repr(fn)), e,
            )
    cache[fn] = compiled
    return compiled
# ------------------------------------------------------------------
# Per-slot curve resolution (reads shared overrides first, then slot)
# ------------------------------------------------------------------
def _resolve_slot_curves(
    self,
    slot: "_Slot",
    vt: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
    """Look up a slot's per-frame curves, honoring shared overrides.

    Each curve is fetched through ``self._eff_shared`` (queried in the
    order velocity_scale, sde_denoise_curve, ode_noise_curve) and then
    placed on the right device/dtype:

    - ``vs``: velocity_scale cast to ``vt``'s device/dtype. Falls back
      to the ``ones_3d`` sentinel so step helpers can always compute
      ``vt * vs``.
    - ``sdc``: sde_denoise_curve cast to ``slot.xt``'s device/dtype, or
      ``None``. There is no sentinel because callers branch on its
      presence (together with ``source_latents``).
    - ``onc``: ode_noise_curve cast to ``slot.xt``'s device/dtype.
      Falls back to the ``zeros_3d`` sentinel so ``randn * onc`` is a
      harmless no-op.
    """
    ones_3d, zeros_3d = self._ensure_sentinels()
    found = {
        name: self._eff_shared(slot, name)
        for name in ("velocity_scale", "sde_denoise_curve", "ode_noise_curve")
    }

    raw_vs = found["velocity_scale"]
    if raw_vs is None:
        vs = ones_3d
    else:
        vs = raw_vs.to(device=vt.device, dtype=vt.dtype)

    xt_device, xt_dtype = slot.xt.device, slot.xt.dtype

    raw_sdc = found["sde_denoise_curve"]
    sdc = None if raw_sdc is None else raw_sdc.to(device=xt_device, dtype=xt_dtype)

    raw_onc = found["ode_noise_curve"]
    if raw_onc is None:
        onc = zeros_3d
    else:
        onc = raw_onc.to(device=xt_device, dtype=xt_dtype)

    return vs, sdc, onc
# The batched decoder forward lives on the family's ModelAdapter
# (``self.adapter.batched_forward``) — moved there verbatim from
# the old ``_decoder_forward`` so conditioning batching is family
# knowledge (Tier-2 seam). The TRT dispatch machinery below stays
# pipeline-owned: it is ACE engine state, reached only via the
# ACEAdapter (relocating it is Phase-4 acceleration-contract work).
def _trt_forward(
self,
xt_batch: torch.Tensor,
timestep_list: List[float],
enc_batch: torch.Tensor,
ctx_batch: torch.Tensor,
) -> torch.Tensor:
"""Run one batched TRT forward pass on pre-built tensors.
Uses the shape-keyed buffer cache via :meth:`_ensure_trt_bufs`
so repeated ticks with the same ``(B, T, max_L)`` shape reuse
allocations.
The active ring-buffer rows are submitted as one TensorRT
execution. If ``B`` exceeds the loaded engine profile, TensorRT
rejects the shape instead of silently falling back to sequential
per-row dispatch.
"""
B, T, _ = xt_batch.shape
if len(timestep_list) != B:
raise RuntimeError(
"TRT stream batch mismatch: "
f"hidden_states batch={B}, timesteps={len(timestep_list)}"
)
if enc_batch.shape[0] != B or ctx_batch.shape[0] != B:
raise RuntimeError(
"TRT stream batch mismatch: "
f"hidden_states={tuple(xt_batch.shape)}, "
f"encoder={tuple(enc_batch.shape)}, "
f"context={tuple(ctx_batch.shape)}"
)
max_L = enc_batch.shape[1]
self._ensure_trt_bufs(B, T, max_L)
bufs = self._trt_bufs
pad = T % 2 == 1
# hidden_states: cast to engine I/O dtype; pad odd T with zeros.
xt_io = xt_batch.to(bufs["hidden_states"].dtype)
if pad:
bufs["hidden_states"][:, :T, :].copy_(xt_io)
bufs["hidden_states"][:, T:, :].zero_()
else:
bufs["hidden_states"].copy_(xt_io)
# timestep: one scalar per row.
for i, t in enumerate(timestep_list):
bufs["timestep"][i] = t
# encoder_hidden_states: already padded to max_L + catted by
# the caller. The engine has no ``encoder_attention_mask``
# input; padding is handled by zero-value convention.
bufs["encoder_hidden_states"].copy_(
enc_batch.to(bufs["encoder_hidden_states"].dtype)
)
# context_latents: pad odd T with zeros.
ctx_io = ctx_batch.to(bufs["context_latents"].dtype)
if pad:
bufs["context_latents"][:, :T, :].copy_(ctx_io)
bufs["context_latents"][:, T:, :].zero_()
else:
bufs["context_latents"].copy_(ctx_io)
# Steering: absent on non-spectral engines.
if "steering" in bufs:
self._fill_trt_steering_buffer(bufs["steering"], B)
# Rebind and execute.
ctx = self._trt_ctx
for name, buf in bufs.items():
if name.startswith("_"):
continue
if not ctx.set_tensor_address(name, buf.data_ptr()):
raise RuntimeError(f"TRT decoder rejected input address for {name}")
if not ctx.set_tensor_address("velocity", self._trt_out_buf.data_ptr()):
raise RuntimeError("TRT decoder rejected output address for velocity")
if not ctx.execute_async_v3(self._trt_stream.ptr):
raise RuntimeError("TRT decoder execution failed")
self._trt_stream.synchronize()
out = self._trt_out_buf
if pad:
return out[:, :T, :].to(self._dtype)
return out.to(self._dtype)
def _steering_row_mask(self, target_step: int, B: int) -> Optional[List[int]]:
"""Rows whose slot is at ``target_step``, or None to skip.
Returns None when no per-row step mapping exists or it
disagrees with ``B`` — the eager hook and TRT buffer fill both
skip injection in that case rather than firing blindly.
"""
row_steps = self._current_step_per_row
if not row_steps or len(row_steps) != B:
return None
mask = [i for i, s in enumerate(row_steps) if s == target_step]
return mask or None
def _fill_trt_steering_buffer(self, buf: torch.Tensor, B: int) -> None:
"""Populate the TRT steering buffer for one forward.
``buf`` is ``[B, num_layers, hidden_size]``; zeroed first so
previous-tick content doesn't leak. Rows with no matching shift
stay zero, which the engine adds as a no-op per layer.
"""
buf.zero_()
if not self._steering_by_layer:
return
for layer_idx, applies in self._steering_by_layer.items():
if layer_idx < 0 or layer_idx >= self._steering_num_layers:
continue
for apply in applies:
mask_rows = self._steering_row_mask(apply.step, B)
if mask_rows is None:
continue
v = apply.vector.to(device=buf.device, dtype=buf.dtype)
buf[mask_rows, layer_idx, :] += apply.scale * v
# ------------------------------------------------------------------
# TRT buffer management
# ------------------------------------------------------------------
def _ensure_trt_bufs(self, B: int, T: int, max_L: int):
"""Bind TRT I/O buffers for ``(B, T, max_L)`` via the LRU cache.
Reuses an existing cache entry when the shape has been seen
recently; allocates a new entry (evicting the oldest if the
cache is full) otherwise. The TRT execution context still has
to be re-bound to whichever entry we use, because addresses
change as we swap entries — but allocations only happen on a
true miss. Uses the engine's native I/O dtype.
"""
eff_T = T + 1 if T % 2 == 1 else T
key = (B, eff_T, max_L)
ctx = self._trt_ctx
cached = self._trt_bufs_cache.get(key)
if cached is not None:
self._trt_bufs_cache.move_to_end(key)
for name, buf in cached.items():
if name.startswith("_"):
continue
ctx.set_input_shape(name, tuple(buf.shape))
ctx.set_tensor_address(name, buf.data_ptr())
ctx.set_tensor_address("velocity", cached["_out_buf"].data_ptr())
self._trt_bufs = cached
self._trt_out_buf = cached["_out_buf"]
return
device = self._device
io_dtype = self._trt_io_dtype
in_dtypes = self._trt_input_dtypes
bufs = {
"hidden_states": torch.empty(
B, eff_T, 64,
dtype=in_dtypes.get("hidden_states", io_dtype),
device=device,
),
"timestep": torch.empty(
B,
dtype=in_dtypes.get("timestep", torch.float32),
device=device,
),
"encoder_hidden_states": torch.empty(
B, max_L, 2048,
dtype=in_dtypes.get("encoder_hidden_states", io_dtype),
device=device,
),
"context_latents": torch.empty(
B, eff_T, 128,
dtype=in_dtypes.get("context_latents", io_dtype),
device=device,
),
}
if self._steering_num_layers > 0:
# Zeroed so a tick with no active configs is a true no-op
# (engine still adds zeros per layer); repopulated each
# forward by _trt_forward.
bufs["steering"] = torch.zeros(
B, self._steering_num_layers, self._steering_hidden_size,
dtype=in_dtypes.get("steering", io_dtype),
device=device,
)
for name, buf in bufs.items():
if not ctx.set_input_shape(name, tuple(buf.shape)):
raise RuntimeError(
f"TRT decoder rejected input shape for {name}: "
f"{tuple(buf.shape)}"
)
if not ctx.set_tensor_address(name, buf.data_ptr()):
raise RuntimeError(f"TRT decoder rejected input address for {name}")
missing = ctx.infer_shapes()
if missing:
raise RuntimeError(
f"TRT decoder shapes are insufficiently specified: {missing}"
)
out_shape = tuple(ctx.get_tensor_shape("velocity"))
if any(d < 0 for d in out_shape):
raise RuntimeError(
f"TRT output shape unresolved: {out_shape}. "
f"B={B}, eff_T={eff_T}, L={max_L}"
)
out_buf = torch.empty(out_shape, dtype=self._trt_output_dtype, device=device)
if not ctx.set_tensor_address("velocity", out_buf.data_ptr()):
raise RuntimeError("TRT decoder rejected output address for velocity")
bufs["_key"] = key
bufs["_eff_T"] = eff_T
bufs["_T"] = T
bufs["_out_buf"] = out_buf
self._trt_bufs_cache[key] = bufs
while len(self._trt_bufs_cache) > self._trt_bufs_cache_max:
self._trt_bufs_cache.popitem(last=False)
self._trt_bufs = bufs
self._trt_out_buf = out_buf
logger.debug(
"Stream TRT bufs allocated: B={} eff_T={} L={}", B, eff_T, max_L
)
# ------------------------------------------------------------------
# Main tick
# ------------------------------------------------------------------
@torch.no_grad()
def tick(self) -> Optional[torch.Tensor]:
    """Run one batched forward pass, advancing all active slots.

    Each call performs these steps:

    1. T-coherence. When requests are queued, the newest request's frame
       count (``self.adapter.request_frames``) becomes the target ``T``.
       Queued requests with a different ``T`` are dropped. Every
       in-flight slot whose ``xt.shape[1]`` differs is dropped too,
       because :meth:`_tick_pt` concatenates ``xt`` across slots in
       dim 0 and needs a uniform time dimension.
    2. At most one slot that has reached the final step of its
       ``t_schedule`` is released, and its ``xt`` is returned.
    3. Empty slots are refilled from the front of the queue.
    4. All slots still short of their final step are advanced together
       by :meth:`_tick_pt`. That is the single unified per-tick path,
       which composes the step primitives from :mod:`ode_steps`; the TRT
       backend is selected inside the family adapter's forward.

    ``self._last_tick_ms`` and ``self.ticks`` are updated on every call.

    Returns:
        Finished latent ``[1, T, D]`` if a slot completed, else ``None``.
    """
    tick_start = time.perf_counter()

    # T-coherence: after a source swap to different-length audio, fresh
    # submits carry the new T while in-flight slots hold xt of the old T.
    if self._queue:
        frames_of = self.adapter.request_frames
        target_T = frames_of(self._queue[-1])
        if any(frames_of(r) != target_T for r in self._queue):
            self._queue = [r for r in self._queue if frames_of(r) == target_T]
        # Evict stale slots even when the queue is already uniform. A
        # queue holding only new-T requests (e.g. one fresh submit) would
        # otherwise put a new-T slot beside old-T ones and break the cat.
        for i, slot in enumerate(self._slots):
            if slot is not None and slot.xt.shape[1] != target_T:
                self._slots[i] = None

    # Release at most one slot sitting at the final step of its schedule.
    finished = None
    for i, slot in enumerate(self._slots):
        if slot is not None and slot.step_idx >= len(slot.t_schedule) - 1:
            finished = slot.xt
            self._slots[i] = None
            break

    # Fill empty slots from the front of the queue.
    for i, slot in enumerate(self._slots):
        if slot is None and self._queue:
            self._slots[i] = self._init_slot(self._queue.pop(0))

    # Advance every slot that still has steps left (completed slots are excluded).
    active = [
        (i, s) for i, s in enumerate(self._slots)
        if s is not None and s.step_idx < len(s.t_schedule) - 1
    ]
    if active:
        indices, slots = zip(*active)
        self._tick_pt(slots, indices)

    # perf_counter is monotonic, unlike time.time(), so wall-clock
    # adjustments cannot skew the measured tick duration.
    self._last_tick_ms = (time.perf_counter() - tick_start) * 1000
    self.ticks += 1
    return finished
# ------------------------------------------------------------------
# Unified tick — all features compose via pure ode_steps bricks
# ------------------------------------------------------------------
def _active_conditions(self, slot) -> List[SlotCondition]:
    """Positive conditions whose ``step_range`` covers the slot's current step.

    Each condition from ``slot.request.all_conditions()`` is asked
    ``is_active_at_step(slot.step_idx, total_steps)``, where
    ``total_steps = len(slot.t_schedule) - 1``. If no condition
    qualifies, the primary condition (the first one) is returned on its
    own, so the decoder always runs at least one forward per slot.
    """
    step_count = len(slot.t_schedule) - 1
    candidates = slot.request.all_conditions()
    selected = [
        cond for cond in candidates
        if cond.is_active_at_step(slot.step_idx, step_count)
    ]
    if not selected:
        selected = [candidates[0]]
    return selected
def _active_neg_conditions(self, slot) -> List[SlotCondition]:
"""Return negative conditions active at this slot's current step.
Falls back to ``neg_conditions[0]`` when none match — mirrors the
pre-refactor behavior of ``negative_condition_set``. Returns an
empty list when CFG is not enabled on the slot.
"""
if not slot.request.has_cfg:
return []