# diffusion_models.py
import torch
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
# from keras.datasets.fashion_mnist import load_data
from keras.datasets.cifar10 import load_data
from unet import UNet
(trainX, _), (testX, _) = load_data()
trainX = np.float32(trainX) / 255.
testX = np.float32(testX) / 255.
# Sanity-check the data shape: (50000, 32, 32, 3) for CIFAR-10
print(trainX.shape)

# Ensure the shape is (N, C, H, W)
trainX = trainX.transpose(0, 3, 1, 2)  # (num_samples, channels, height, width)
testX = testX.transpose(0, 3, 1, 2)    # (num_samples, channels, height, width)
def sample_batch(batch_size, device):
    indices = torch.randperm(trainX.shape[0])[:batch_size]
    data = torch.from_numpy(trainX[indices]).to(device)  # Shape: [batch_size, 3, 32, 32]
    return data
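
# A small illustrative check (an addition, not part of the original script):
# verifies that sample_batch returns a (batch, 3, 32, 32) tensor on CPU.
assert sample_batch(4, 'cpu').shape == (4, 3, 32, 32)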
class DiffusionModel:
    def __init__(self, T: int, model: nn.Module, device: str):
        self.T = T
        self.function_approximator = model.to(device)
        self.device = device

        # Linear noise schedule from the DDPM paper: beta_1 = 1e-4, beta_T = 0.02
        self.beta = torch.linspace(1e-4, 0.02, T).to(device)
        self.alpha = 1. - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
    def training(self, batch_size, optimizer):
        """
        Algorithm 1 in Denoising Diffusion Probabilistic Models
        """
        x0 = sample_batch(batch_size, self.device)

        # Sample t uniformly from {1, ..., T}; schedule tensors are indexed with t - 1
        t = torch.randint(1, self.T + 1, (batch_size,), device=self.device,
                          dtype=torch.long)
        eps = torch.randn_like(x0)

        # Forward process in closed form: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps
        alpha_bar_t = self.alpha_bar[t - 1].reshape(-1, 1, 1, 1)
        x_t = torch.sqrt(alpha_bar_t) * x0 + torch.sqrt(1 - alpha_bar_t) * eps
        eps_predicted = self.function_approximator(x_t, t - 1)

        # One gradient descent step on the noise-prediction MSE
        loss = nn.functional.mse_loss(eps, eps_predicted)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()
    @torch.no_grad()
    def sampling(self, n_samples=1, image_channels=3, img_size=(32, 32),
                 use_tqdm=True):
        """
        Algorithm 2 in Denoising Diffusion Probabilistic Models
        """
        # Start from pure Gaussian noise x_T
        x = torch.randn((n_samples, image_channels, img_size[0], img_size[1]),
                        device=self.device)

        progress_bar = tqdm if use_tqdm else lambda iterable: iterable
        for t in progress_bar(range(self.T, 0, -1)):
            # No noise is added at the final step (t == 1)
            z = torch.randn_like(x) if t > 1 else torch.zeros_like(x)

            ts = torch.full((n_samples,), t, dtype=torch.long, device=self.device)
            beta_t = self.beta[ts - 1].reshape(-1, 1, 1, 1)
            alpha_t = self.alpha[ts - 1].reshape(-1, 1, 1, 1)
            alpha_bar_t = self.alpha_bar[ts - 1].reshape(-1, 1, 1, 1)

            # Posterior mean: (1 / sqrt(alpha_t)) * (x - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * eps_theta)
            mean = 1 / torch.sqrt(alpha_t) * (x - ((1 - alpha_t) / torch.sqrt(
                1 - alpha_bar_t)) * self.function_approximator(x, ts - 1))
            sigma = torch.sqrt(beta_t)
            x = mean + sigma * z
        return x
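
# A minimal visualization sketch (an addition, not part of the original script):
# `show_samples` is a hypothetical helper that arranges generated images in a
# grid using the already-imported matplotlib. It assumes `samples` is an
# (N, C, H, W) tensor in roughly [0, 1], as produced by `sampling` above.
def show_samples(samples, n_rows=4, n_cols=4, filename='samples.png'):
    samples = samples.clamp(0, 1).cpu().numpy()
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols, n_rows))
    for ax, img in zip(axes.flat, samples):
        ax.imshow(img.transpose(1, 2, 0))  # (C, H, W) -> (H, W, C) for imshow
        ax.axis('off')
    plt.tight_layout()
    plt.savefig(filename)
    plt.close(fig)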
if __name__ == "__main__":
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    batch_size = 64
    model = UNet()
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
    diffusion_model = DiffusionModel(1000, model, device)

    # Training: each iteration is one gradient step on a random mini-batch
    # (40,000 steps total), not a full pass over the dataset
    for step in tqdm(range(40_000)):
        loss = diffusion_model.training(batch_size, optimizer)
        if step % 1000 == 0:
            print(f'step {step}: loss = {loss:.4f}')

    # Save model
    torch.save(model.state_dict(), 'model.pth')
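
    # A hedged usage sketch: draw a few samples from the trained model and save
    # them with the `show_samples` helper defined above. Sampling 16 images
    # takes T = 1000 network evaluations, so expect it to be slow on CPU.
    samples = diffusion_model.sampling(n_samples=16)
    show_samples(samples, n_rows=4, n_cols=4)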