-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathencode_prompt.py
More file actions
47 lines (36 loc) · 1.24 KB
/
Copy pathencode_prompt.py
File metadata and controls
47 lines (36 loc) · 1.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import hydra
import torch
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration, DistributedDataParallelKwargs
from hydra.utils import instantiate
from omegaconf import OmegaConf
from torch.utils.data import Dataset
from diffusers_tuner.tuner import TuneConfigs, Tuner
from pipelines.pipeline_utils import PipelineConfigs, TunePipeline
def _weight_dtype_for(mixed_precision) -> torch.dtype:
    """Return the torch dtype to load weights in for an Accelerate precision mode.

    "bf16" maps to ``torch.bfloat16`` and "fp16" to ``torch.float16``; every
    other value falls back to full ``torch.float32`` precision.
    """
    if mixed_precision == "bf16":
        return torch.bfloat16
    if mixed_precision == "fp16":
        return torch.float16
    return torch.float32


@hydra.main(config_path="configs", config_name="infer", version_base="v1.2")
def tune(cfgs: OmegaConf):
    """Build the pipeline and tuner from config and prepare prompt embeddings.

    Hydra composes the ``infer`` config from the ``configs`` directory and
    passes it in as ``cfgs``. Despite its name, this entry point runs no
    tuning step: it only calls ``Tuner.prepare_prompt_embeds``. That method
    presumably encodes the dataset's prompts and writes the embeddings under
    ``cfgs.prompt_embeds_save_dir``. The behavior lives in ``Tuner``, so
    confirm it there.

    The config must provide:
      * ``dataset``  -- instantiated into the ``Dataset`` to encode,
      * ``pipeline`` -- instantiated into the ``PipelineConfigs`` for
        ``TunePipeline``,
      * ``tune``     -- instantiated into the ``TuneConfigs`` for ``Tuner``,
      * ``prompt_embeds_save_dir`` -- forwarded to ``prepare_prompt_embeds``.

    Args:
        cfgs: The composed Hydra configuration.
            NOTE(review): Hydra actually passes an ``omegaconf.DictConfig``;
            the ``OmegaConf`` annotation is kept only for interface stability.
    """
    accelerator = Accelerator()
    # Load the weights in the same precision Accelerate was launched with.
    weight_dtype = _weight_dtype_for(accelerator.mixed_precision)

    dataset: Dataset = instantiate(cfgs.dataset)

    pipeline_configs: PipelineConfigs = instantiate(cfgs.pipeline)
    pipeline = TunePipeline(
        pipeline_configs,
        weight_dtype=weight_dtype,
        device=accelerator.device,
    )

    tuner = Tuner(instantiate(cfgs.tune))
    tuner.prepare_prompt_embeds(
        accelerator,
        pipeline,
        dataset,
        cfgs.prompt_embeds_save_dir,
        device=accelerator.device,
        weight_dtype=weight_dtype,
    )
if __name__ == "__main__":
    # `tune` is wrapped by `@hydra.main`, which builds `cfgs` from the config
    # files and any command-line overrides, so it is called with no arguments.
    tune()