-
Notifications
You must be signed in to change notification settings - Fork 419
Expand file tree
/
Copy pathvae_patcher.py
More file actions
38 lines (31 loc) · 854 Bytes
/
vae_patcher.py
File metadata and controls
38 lines (31 loc) · 854 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from comfy.model_management import dtype_size
from .nodes_registry import comfy_node
def _ltxv_vae_decode_memory(shape, dtype):
    """Estimate VAE decode memory for a latent of the given ``shape`` and ``dtype``.

    The estimate multiplies dims 1-4 of ``shape`` by a fixed 8 * 8 * 8 * 3
    factor and by ``dtype_size(dtype)`` (presumably bytes per element). Dim 0
    of ``shape`` is not read.

    NOTE(review): assumes ``shape`` is (batch, channels, frames, height, width)
    for a video latent; confirm against the caller. Presumably the caller
    treats this as a per-sample figure when sizing decode batches; verify.
    """
    channels = shape[1]
    frames = shape[2]
    height = shape[3]
    width = shape[4]
    # NOTE(review): the 8 * 8 * 8 and 3 factors are an undocumented heuristic
    # (possibly upsampling ratios and output channels); confirm with the authors.
    return channels * frames * height * width * 8 * 8 * 8 * 3 * dtype_size(dtype)


@comfy_node(name="LTXVPatcherVAE")
class LTXVPatcherVAE:
    """Node that patches a VAE's model via q8_kernels and returns the same VAE."""

    RETURN_TYPES = ("VAE",)
    FUNCTION = "patch"
    CATEGORY = "lightricks/LTXV"
    TITLE = "LTXV VAE Patcher"

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's inputs: one required ``vae`` of type "VAE"."""
        required_inputs = {"vae": ("VAE",)}
        return {"required": required_inputs}

    def patch(self, vae):
        """Patch ``vae`` in place and return it as a 1-tuple.

        Two things happen to the passed-in object:
        - its ``memory_used_decode`` attribute is replaced with
          ``_ltxv_vae_decode_memory``;
        - its ``first_stage_model`` is handed to q8_kernels' ``patch_vae``
          with ``patch_block=4``. What that patch does is not visible here.

        Returns:
            ``(vae,)``: the same object that was passed in, not a copy.
        """
        # Imported inside the method, so q8_kernels is only required once the
        # node actually executes, not when this module is loaded.
        from q8_kernels.integration.patch_vae import patch_vae

        # Read the wrapped model before touching ``vae`` so that a missing
        # attribute fails without any partial mutation.
        inner_model = vae.first_stage_model
        vae.memory_used_decode = _ltxv_vae_decode_memory
        patch_vae(inner_model, patch_block=4)
        return (vae,)