-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdiffusion_model.py
More file actions
172 lines (139 loc) · 7.3 KB
/
Copy pathdiffusion_model.py
File metadata and controls
172 lines (139 loc) · 7.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
import torch
from torch import Tensor

from utils import *
class DiffusionProcess:
    """
    Forward and inverse diffusion process based on the DDPM paper
    (https://arxiv.org/pdf/2006.11239.pdf), with a forward step that mixes a
    second tensor ``y`` into the diffused sample.

    Attributes:
        variance_schedule (torch.Tensor): variance value at each timestep.
        alpha (torch.Tensor): "complement" of the variance, i.e.
            ``1 - variance_schedule``.
        alpha_bar (torch.Tensor): cumulative product of ``alpha`` along the
            time axis. It is derived directly from the variance schedule.
        alpha_, beta_, alpha_cum_, sigmas_, T, c1, c2, c3, delta, delta_bar:
            the ten values returned by ``inference_schedule()``. Only ``c1``,
            ``c2`` and ``c3`` are read in this class (by
            :meth:`proposed_inverse`).
    """

    def __init__(
        self,
        variance_schedule: Optional[List[float]] = None,
        device=None,
    ) -> None:
        """
        Args:
            variance_schedule (list): variance value at each timestep. If left
                None, it defaults to 200 linearly increasing values from 1e-4
                to 0.01.
            device: device on which the schedule tensors are stored. ``None``
                keeps them on the default device.
        """
        if variance_schedule is None:
            variance_schedule = torch.linspace(1e-4, 0.01, steps=200)

        # ``as_tensor`` accepts both plain lists and existing tensors.
        self.variance_schedule = torch.as_tensor(
            variance_schedule, dtype=torch.float32, device=device
        )
        self.alpha = 1 - self.variance_schedule
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)

        # NOTE(review): ``inference_schedule`` is not defined in this file; it
        # presumably comes from ``utils`` -- confirm the order of its outputs.
        (
            self.alpha_,
            self.beta_,
            self.alpha_cum_,
            self.sigmas_,
            self.T,
            self.c1,
            self.c2,
            self.c3,
            self.delta,
            self.delta_bar,
        ) = inference_schedule()

    def forward(
        self, x_0: Tensor, y: Tensor, time_step: Tensor, noise: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Diffuse ``x_0`` to ``time_step`` while interpolating towards ``y``.

        With ``ab = alpha_bar[time_step]`` the coefficients are:

            m_t = sqrt((1 - ab) / sqrt(ab))
            k1  = (1 - m_t) * sqrt(ab)
            k2  = m_t * sqrt(ab)
            k3  = sqrt(1 - (1 + m_t^2) * ab)
            k4  = sqrt(1 - ab)

        and the outputs are:

            x_t        = k1 * x_0 + k2 * y + k3 * noise
            comb_noise = (k2 * (y - x_0) + k3 * noise) / k4

        so that ``x_t == sqrt(ab) * x_0 + k4 * comb_noise``.

        Args:
            x_0 (torch.Tensor): original sample, batch first. Any number of
                trailing dimensions is supported.
            y (torch.Tensor): second tensor mixed into the diffused sample; it
                must be broadcastable with ``x_0``.
            time_step (torch.Tensor): integer step for every batch element,
                with values in ``[0, len(self.alpha_bar))`` and the same batch
                size as ``x_0``.
            noise (torch.Tensor): noise added at this ``time_step``, expected
                to be sampled from a zero-mean unit-variance normal
                distribution.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: the diffused sample ``x_t`` and
            the combined noise ``comb_noise`` defined above.
        """
        # Checking for validity of input
        assert torch.all(time_step >= 0).item(), "time_step must be >= 0"
        assert torch.all(time_step < len(self.alpha_bar)).item(), (
            "time_step must be smaller than the schedule length"
        )
        assert time_step.shape[0] == x_0.shape[0], (
            "time_step and x_0 must have the same batch size"
        )

        # Index on the schedule's own device so a time_step living elsewhere
        # does not trigger a device-mismatch error.
        alpha_bar_t = self.alpha_bar[time_step.to(self.alpha_bar.device)]
        sqrt_alpha_bar_t = torch.sqrt(alpha_bar_t)

        m_t = torch.sqrt((1 - alpha_bar_t) / sqrt_alpha_bar_t)
        k1 = (1 - m_t) * sqrt_alpha_bar_t
        k2 = m_t * sqrt_alpha_bar_t
        # The exact value is (1 - ab) * (1 - sqrt(ab)) >= 0, but float32
        # rounding can push it slightly below zero for small t; clamp so the
        # square root never yields NaN.
        k3 = torch.sqrt(
            torch.clamp(1 - (1 + torch.square(m_t)) * alpha_bar_t, min=0)
        )
        k4 = torch.sqrt(1 - alpha_bar_t)

        # One coefficient per batch element, broadcast over every trailing
        # dimension of x_0 (identical to [:, None, None, None] for 4-D input).
        coeff_shape = (x_0.shape[0],) + (1,) * (x_0.dim() - 1)
        k1, k2, k3, k4 = (
            k.reshape(coeff_shape).to(x_0.device) for k in (k1, k2, k3, k4)
        )

        diffused_images = k1 * x_0 + k2 * y + k3 * noise
        comb_noise = (k2 * (y - x_0) + k3 * noise) / k4
        return diffused_images, comb_noise

    def proposed_inverse(
        self, xt: Tensor, nac: Tensor, et: Tensor, t: int
    ) -> Tensor:
        """
        Apply one reverse step using the ``c1``/``c2``/``c3`` coefficients:

            output = c1[t] * xt + c2[t] * nac - c3[t] * et

        Args:
            xt (torch.Tensor): current sample at time ``t``.
            nac (torch.Tensor): tensor weighted by ``c2[t]``; moved to the
                device of ``xt``. NOTE(review): presumably the conditioning
                signal used as ``y`` in :meth:`forward` -- confirm with caller.
            et (torch.Tensor): tensor weighted by ``c3[t]`` (presumably the
                predicted noise); moved to the device of ``xt``.
            t (int): index into the coefficient schedules.

        Returns:
            torch.Tensor: the combined result on the device of ``xt``.
        """
        target = xt.device
        return (
            self.c1[t] * xt
            + self.c2[t] * nac.to(target)
            - self.c3[t] * et.to(target)
        )

    def inverse(self, xt: Tensor, et: Tensor, t: int) -> Tensor:
        """
        This applies the unconditional sampling of the diffusion step. It uses
        the equation as follows:

            p(x_{t-1} | x_t) = mu_t + std_dev_t * N(0, I)
            mu_t = (1 / sqrt(alpha_t)) * (xt - noise_scale * et)
            noise_scale = (1 - alpha_t) / sqrt(1 - alpha_bar_t)
            std_dev_t = sqrt(variance_schedule_t)

        This is from the DDPM paper.

        Args:
            xt (torch.Tensor): noisy image at time ``t``.
            et (torch.Tensor): predicted error from diffusion model, which is
                usually the output of the trained UNet architecture.
            t (int): the time ``t`` of the diffusion process.

        Returns:
            torch.Tensor: the result of the sampling x_{t-1}.
        """
        scale = 1 / torch.sqrt(self.alpha[t])
        noise_scale = (1 - self.alpha[t]) / torch.sqrt(1 - self.alpha_bar[t])
        std_dev = torch.sqrt(self.variance_schedule[t])
        mu_t = scale * (xt - noise_scale * et)

        # Sample the noise on the same device/dtype as ``xt`` so the step also
        # works when ``xt`` is not on the CPU.
        # NOTE(review): noise is skipped for t <= 1, whereas inverse_DDIM
        # treats t == 0 as the final step -- confirm the intended indexing.
        z = torch.randn_like(xt) if t > 1 else torch.zeros_like(xt)
        return mu_t + std_dev * z  # remove noise from image

    def inverse_DDIM(self, xt: Tensor, et: Tensor, t: int) -> Tensor:
        """
        This applies the unconditional sampling of the diffusion step using the
        DDIM method: https://arxiv.org/abs/2010.02502

        ``f_theta`` acts as an approximation for x_0, and the rest follows
        equation (7) in the paper. For DDIM, we have that std_dev = 0, so the
        step is deterministic.

        Args:
            xt (torch.Tensor): noisy image at time ``t``.
            et (torch.Tensor): predicted error from the diffusion model.
            t (int): the time ``t`` of the diffusion process. ``t == 0``
                returns the x_0 estimate directly.

        Returns:
            torch.Tensor: the sample at time ``t - 1`` (or the x_0 estimate
            when ``t == 0``).
        """
        alpha_bar_t = self.alpha_bar[t]
        f_theta = (xt - torch.sqrt(1 - alpha_bar_t) * et) * (
            1 / torch.sqrt(alpha_bar_t)
        )
        if t <= 0:
            return f_theta

        alpha_bar_prev = self.alpha_bar[t - 1]
        # Direction pointing to x_t, re-derived from the x_0 estimate.
        direction = (xt - torch.sqrt(alpha_bar_t) * f_theta) * (
            1 / torch.sqrt(1 - alpha_bar_t)
        )
        return (
            torch.sqrt(alpha_bar_prev) * f_theta
            + torch.sqrt(1 - alpha_bar_prev) * direction
        )