From dabb1e83359404f299463fb6bbc322349dcc6779 Mon Sep 17 00:00:00 2001 From: Mr-Neutr0n <64578610+Mr-Neutr0n@users.noreply.github.com> Date: Thu, 12 Feb 2026 00:14:54 +0530 Subject: [PATCH] Fix NotImplemented errors, xformers attention shape, and missing text conditioning - Replace `raise NotImplemented` with `raise NotImplementedError` in both latte.py and latte_img.py. `NotImplemented` is not an exception class and will raise a TypeError instead of the intended error. - Transpose q, k, v from (B, heads, N, dim) to (B, N, heads, dim) before calling xformers memory_efficient_attention in latte_img.py, matching the correct implementation in latte.py. xformers expects the (B, N, heads, dim) layout. - Add missing `elif self.extras == 78` branch before the final layer in latte.py so that text_embedding_spatial conditioning is applied during the final adaptive layer norm, consistent with the temporal blocks above. --- models/latte.py | 4 +++- models/latte_img.py | 8 ++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/models/latte.py b/models/latte.py index f7ee920..9680d7d 100644 --- a/models/latte.py +++ b/models/latte.py @@ -70,7 +70,7 @@ def forward(self, x): x = (attn @ v).transpose(1, 2).reshape(B, N, C) else: - raise NotImplemented + raise NotImplementedError x = self.proj(x) x = self.proj_drop(x) @@ -369,6 +369,8 @@ def forward(self, if self.extras == 2: c = timestep_spatial + y_spatial + elif self.extras == 78: + c = timestep_spatial + text_embedding_spatial else: c = timestep_spatial x = self.final_layer(x, c) diff --git a/models/latte_img.py b/models/latte_img.py index c468c63..4fc507d 100644 --- a/models/latte_img.py +++ b/models/latte_img.py @@ -58,7 +58,11 @@ def forward(self, x): q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) if self.attention_mode == 'xformers': # cause loss nan while using with amp - x = xformers.ops.memory_efficient_attention(q, k, v).reshape(B, N, C) + # https://github.com/facebookresearch/xformers/blob/e8bd8f932c2f48e3a3171d06749eecbbf1de420c/xformers/ops/fmha/__init__.py#L135 + q_xf = q.transpose(1,2).contiguous() + k_xf = k.transpose(1,2).contiguous() + v_xf = v.transpose(1,2).contiguous() + x = xformers.ops.memory_efficient_attention(q_xf, k_xf, v_xf).reshape(B, N, C) elif self.attention_mode == 'flash': # cause loss nan while using with amp @@ -73,7 +77,7 @@ def forward(self, x): x = (attn @ v).transpose(1, 2).reshape(B, N, C) else: - raise NotImplemented + raise NotImplementedError x = self.proj(x) x = self.proj_drop(x)