Autoencoders¶
A type of feedforward neural network in which the output is trained to match the input. Autoencoders compress the input into a lower-dimensional latent representation and then reconstruct the input from that representation. An autoencoder consists of three components: an encoder, a latent representation, and a decoder. The encoder compresses the input and produces the latent representation; the decoder then reconstructs the input using only this representation.
Install required packages¶
!pip install torch torchvision matplotlib
Import¶
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import math
Load the MNIST DataSet¶
# Keep pixel values in [0, 1] so they match the decoder's Sigmoid output;
# the commented-out Normalize transform would rescale them to [-1, 1] instead.
# transform = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize((0.5,), (0.5,))
# ])
transform = transforms.ToTensor()
mnist_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
data_loader = torch.utils.data.DataLoader(dataset=mnist_data,
batch_size=64,
shuffle=True)
dataiter = iter(data_loader)
images, labels = next(dataiter)
print(torch.min(images), torch.max(images))
tensor(0.) tensor(1.)
Define the autoencoder as an MLP¶
The input has 28x28 = 784 dimensions. The encoder's linear layers map it to 128, then 64, then 12, and finally 3 dimensions; the decoder mirrors these sizes back up to 784.
# repeatedly reduce the size
class Autoencoder_Linear(nn.Module):
def __init__(self):
super().__init__()
self.encoder = nn.Sequential(
nn.Linear(28 * 28, 128), # (N, 784) -> (N, 128)
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, 12),
nn.ReLU(),
nn.Linear(12, 3) # -> N, 3
)
self.decoder = nn.Sequential(
nn.Linear(3, 12),
nn.ReLU(),
nn.Linear(12, 64),
nn.ReLU(),
nn.Linear(64, 128),
nn.ReLU(),
nn.Linear(128, 28 * 28),
nn.Sigmoid()
)
def forward(self, x):
encoded = self.encoder(x)
decoded = self.decoder(encoded)
return decoded
model = Autoencoder_Linear()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(),
lr=1e-3,
weight_decay=1e-5)
Let's inspect the output of the encoder¶
We'll see that it maps a 784-dimensional vector to a 3-dimensional vector.
model.encoder(images[0].reshape(-1, 28*28))
tensor([[ 0.1642, -0.2323, -0.1464]], grad_fn=<AddmmBackward0>)
Train the MLP (encoder & decoder) for 1 epoch¶
num_epochs = 1
outputs = []
for epoch in range(num_epochs):
for (img, _) in data_loader:
img = img.reshape(-1, 28*28) # -> use for Autoencoder_Linear
recon = model(img)
loss = criterion(recon, img)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f'Epoch:{epoch+1}, Loss:{loss.item():.4f}')
outputs.append((epoch, img, recon))
Epoch:1, Loss:0.0506
for k in range(0, num_epochs, 4):
plt.figure(figsize=(9, 2))
plt.gray()
imgs = outputs[k][1].detach().numpy()
recon = outputs[k][2].detach().numpy()
for i, item in enumerate(imgs):
if i >= 9: break
plt.subplot(2, 9, i+1)
item = item.reshape(-1, 28,28) # -> use for Autoencoder_Linear
# item: 1, 28, 28
plt.imshow(item[0])
for i, item in enumerate(recon):
if i >= 9: break
plt.subplot(2, 9, 9+i+1) # row_length + i + 1
item = item.reshape(-1, 28,28) # -> use for Autoencoder_Linear
# item: 1, 28, 28
plt.imshow(item[0])
After a single epoch the loss is 0.0506 and the reconstructions do not look great.
Now let's train it for 10 epochs¶
num_epochs = 10
outputs = []
for epoch in range(num_epochs):
for (img, _) in data_loader:
img = img.reshape(-1, 28*28) # -> use for Autoencoder_Linear
recon = model(img)
loss = criterion(recon, img)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f'Epoch:{epoch+1}, Loss:{loss.item():.4f}')
outputs.append((epoch, img, recon))
Epoch:1, Loss:0.0387
Epoch:2, Loss:0.0325
Epoch:3, Loss:0.0299
Epoch:4, Loss:0.0372
Epoch:5, Loss:0.0359
Epoch:6, Loss:0.0352
Epoch:7, Loss:0.0386
Epoch:8, Loss:0.0314
Epoch:9, Loss:0.0323
Epoch:10, Loss:0.0299
for k in range(0, num_epochs, 4):
plt.figure(figsize=(9, 2))
plt.gray()
imgs = outputs[k][1].detach().numpy()
recon = outputs[k][2].detach().numpy()
for i, item in enumerate(imgs):
if i >= 9: break
plt.subplot(2, 9, i+1)
item = item.reshape(-1, 28,28) # -> use for Autoencoder_Linear
# item: 1, 28, 28
plt.imshow(item[0])
for i, item in enumerate(recon):
if i >= 9: break
plt.subplot(2, 9, 9+i+1) # row_length + i + 1
item = item.reshape(-1, 28,28) # -> use for Autoencoder_Linear
# item: 1, 28, 28
plt.imshow(item[0])
After 10 epochs the loss is down to 0.0299 and the reconstructions are quite accurate.
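Since the MLP compresses each image to just 3 latent dimensions, a quick optional check (not part of the original training code; the plotting choices here are illustrative) is to encode a batch and scatter two of the latent coordinates, colored by digit label, to see whether classes start to separate:
with torch.no_grad():
    latent = model.encoder(images.reshape(-1, 28 * 28))  # (64, 3)
plt.figure(figsize=(5, 4))
plt.scatter(latent[:, 0].numpy(), latent[:, 1].numpy(), c=labels.numpy(), cmap='tab10', s=15)
plt.colorbar(label='digit')
plt.xlabel('latent dim 0')
plt.ylabel('latent dim 1')
plt.show()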
Let's repeat the experiment with a CNN¶
class Autoencoder(nn.Module):
def __init__(self):
super().__init__()
# N, 1, 28, 28
self.encoder = nn.Sequential(
nn.Conv2d(1, 16, 3, stride=2, padding=1), # -> N, 16, 14, 14
nn.ReLU(),
nn.Conv2d(16, 32, 3, stride=2, padding=1), # -> N, 32, 7, 7
nn.ReLU(),
nn.Conv2d(32, 64, 7) # -> N, 64, 1, 1
)
# N , 64, 1, 1
self.decoder = nn.Sequential(
nn.ConvTranspose2d(64, 32, 7), # -> N, 32, 7, 7
nn.ReLU(),
nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1), # N, 16, 14, 14 (N,16,13,13 without output_padding)
nn.ReLU(),
nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1), # N, 1, 28, 28 (N,1,27,27)
nn.Sigmoid()
)
def forward(self, x):
encoded = self.encoder(x)
decoded = self.decoder(encoded)
return decoded
model = Autoencoder()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(),
lr=1e-3,
weight_decay=1e-5)
num_epochs = 1
outputs = []
for epoch in range(num_epochs):
for (img, _) in data_loader:
recon = model(img)
loss = criterion(recon, img)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f'Epoch:{epoch+1}, Loss:{loss.item():.4f}')
outputs.append((epoch, img, recon))
Epoch:1, Loss:0.0051
for k in range(0, num_epochs, 4):
plt.figure(figsize=(9, 2))
plt.gray()
imgs = outputs[k][1].detach().numpy()
recon = outputs[k][2].detach().numpy()
for i, item in enumerate(imgs):
if i >= 9: break
plt.subplot(2, 9, i+1)
# item = item.reshape(-1, 28,28) # -> use for Autoencoder_Linear
# item: 1, 28, 28
plt.imshow(item[0])
for i, item in enumerate(recon):
if i >= 9: break
plt.subplot(2, 9, 9+i+1) # row_length + i + 1
# item = item.reshape(-1, 28,28) # -> use for Autoencoder_Linear
# item: 1, 28, 28
plt.imshow(item[0])
Let's inspect the output of the encoder¶
We'll see that it maps a 28x28 image to a latent tensor of shape (64, 1, 1), i.e. a 64-dimensional representation.
model.encoder(images[0]).shape
torch.Size([64, 1, 1])
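If a flat 64-dimensional vector is needed downstream (an optional step, not in the original notebook), the trailing 1x1 spatial dimensions can simply be flattened:
# Add a batch dimension, then flatten (1, 64, 1, 1) to (1, 64).
latent = model.encoder(images[:1])
print(latent.flatten(1).shape)  # torch.Size([1, 64])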
VAE - Variational Autoencoders¶
Variational Autoencoders (VAEs) are a type of generative model that learns to encode data into a probability distribution rather than a fixed vector. Unlike regular autoencoders, VAEs add randomness to the latent space and force it to follow a normal distribution.
The key differences from regular autoencoders are:
- The encoder outputs parameters (mean and variance) of a distribution instead of a single point
- A random sample is drawn from this distribution to create the latent vector
- An additional loss term (KL divergence) ensures the latent distribution matches a standard normal
- The implementation below uses the SiLU activation function instead of ReLU (an implementation choice rather than a defining property of VAEs)
This probabilistic approach allows VAEs to:
- Generate new, realistic samples by sampling from the latent space
- Create smooth interpolations between data points
- Learn more robust and meaningful representations
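Before the full implementation, here is a minimal, self-contained sketch of the first three points (the layer sizes and the TinyVAE/vae_loss names are illustrative choices, not part of the code below): the encoder outputs a mean and a log-variance, a latent vector is sampled with the reparameterization trick, and the loss adds a KL divergence term to the reconstruction error.
class TinyVAE(nn.Module):
    def __init__(self, latent_dim=3):
        super().__init__()
        self.enc = nn.Linear(28 * 28, 128)
        self.fc_mu = nn.Linear(128, latent_dim)      # mean of q(z|x)
        self.fc_logvar = nn.Linear(128, latent_dim)  # log-variance of q(z|x)
        self.dec = nn.Sequential(nn.Linear(latent_dim, 128), nn.SiLU(),
                                 nn.Linear(128, 28 * 28), nn.Sigmoid())

    def forward(self, x):
        h = F.silu(self.enc(x))
        mu, logvar = self.fc_mu(h), self.fc_logvar(h)
        std = torch.exp(0.5 * logvar)
        z = mu + std * torch.randn_like(std)         # reparameterization trick
        return self.dec(z), mu, logvar

def vae_loss(recon, x, mu, logvar):
    recon_loss = F.mse_loss(recon, x, reduction='sum')
    # KL divergence between N(mu, sigma^2) and the standard normal N(0, 1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl
The classes below build a much larger convolutional VAE out of self-attention, cross-attention, and residual blocks; the sampling and clamping logic in VAE_Encoder.forward follows the same pattern.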
class SelfAttention(nn.Module):
def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
super().__init__()
# This combines the Wq, Wk and Wv matrices into one matrix
self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
# This one represents the Wo matrix
self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
self.n_heads = n_heads
self.d_head = d_embed // n_heads
def forward(self, x, causal_mask=False):
# x: # (Batch_Size, Seq_Len, Dim)
# (Batch_Size, Seq_Len, Dim)
input_shape = x.shape
# (Batch_Size, Seq_Len, Dim)
batch_size, sequence_length, d_embed = input_shape
# (Batch_Size, Seq_Len, H, Dim / H)
interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head)
# (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim * 3) -> 3 tensor of shape (Batch_Size, Seq_Len, Dim)
q, k, v = self.in_proj(x).chunk(3, dim=-1)
# (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, H, Dim / H) -> (Batch_Size, H, Seq_Len, Dim / H)
q = q.view(interim_shape).transpose(1, 2)
k = k.view(interim_shape).transpose(1, 2)
v = v.view(interim_shape).transpose(1, 2)
# (Batch_Size, H, Seq_Len, Dim / H) @ (Batch_Size, H, Dim / H, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
weight = q @ k.transpose(-1, -2)
if causal_mask:
# Mask where the upper triangle (above the principal diagonal) is 1
mask = torch.ones_like(weight, dtype=torch.bool).triu(1)
# Fill the upper triangle with -inf
weight.masked_fill_(mask, -torch.inf)
# Divide by d_k (Dim / H).
# (Batch_Size, H, Seq_Len, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
weight /= math.sqrt(self.d_head)
# (Batch_Size, H, Seq_Len, Seq_Len) -> (Batch_Size, H, Seq_Len, Seq_Len)
weight = F.softmax(weight, dim=-1)
# (Batch_Size, H, Seq_Len, Seq_Len) @ (Batch_Size, H, Seq_Len, Dim / H) -> (Batch_Size, H, Seq_Len, Dim / H)
output = weight @ v
# (Batch_Size, H, Seq_Len, Dim / H) -> (Batch_Size, Seq_Len, H, Dim / H)
output = output.transpose(1, 2)
# (Batch_Size, Seq_Len, H, Dim / H) -> (Batch_Size, Seq_Len, Dim)
output = output.reshape(input_shape)
# (Batch_Size, Seq_Len, Dim) -> (Batch_Size, Seq_Len, Dim)
output = self.out_proj(output)
# (Batch_Size, Seq_Len, Dim)
return output
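As a quick shape check (an illustrative usage, not part of the original notebook), SelfAttention maps a (Batch_Size, Seq_Len, Dim) tensor to a tensor of the same shape:
attn = SelfAttention(n_heads=4, d_embed=64)
x = torch.randn(2, 49, 64)               # e.g. a 7x7 feature map flattened to 49 tokens
print(attn(x).shape)                     # torch.Size([2, 49, 64])
print(attn(x, causal_mask=True).shape)   # same shape, with the upper triangle masked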
class CrossAttention(nn.Module):
def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
super().__init__()
self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
self.n_heads = n_heads
self.d_head = d_embed // n_heads
def forward(self, x, y):
# x (latent): # (Batch_Size, Seq_Len_Q, Dim_Q)
# y (context): # (Batch_Size, Seq_Len_KV, Dim_KV) = (Batch_Size, 77, 768)
input_shape = x.shape
batch_size, sequence_length, d_embed = input_shape
# Divide each embedding of Q into multiple heads such that d_heads * n_heads = Dim_Q
interim_shape = (batch_size, -1, self.n_heads, self.d_head)
# (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, Dim_Q)
q = self.q_proj(x)
# (Batch_Size, Seq_Len_KV, Dim_KV) -> (Batch_Size, Seq_Len_KV, Dim_Q)
k = self.k_proj(y)
# (Batch_Size, Seq_Len_KV, Dim_KV) -> (Batch_Size, Seq_Len_KV, Dim_Q)
v = self.v_proj(y)
# (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, H, Dim_Q / H) -> (Batch_Size, H, Seq_Len_Q, Dim_Q / H)
q = q.view(interim_shape).transpose(1, 2)
# (Batch_Size, Seq_Len_KV, Dim_Q) -> (Batch_Size, Seq_Len_KV, H, Dim_Q / H) -> (Batch_Size, H, Seq_Len_KV, Dim_Q / H)
k = k.view(interim_shape).transpose(1, 2)
# (Batch_Size, Seq_Len_KV, Dim_Q) -> (Batch_Size, Seq_Len_KV, H, Dim_Q / H) -> (Batch_Size, H, Seq_Len_KV, Dim_Q / H)
v = v.view(interim_shape).transpose(1, 2)
# (Batch_Size, H, Seq_Len_Q, Dim_Q / H) @ (Batch_Size, H, Dim_Q / H, Seq_Len_KV) -> (Batch_Size, H, Seq_Len_Q, Seq_Len_KV)
weight = q @ k.transpose(-1, -2)
# (Batch_Size, H, Seq_Len_Q, Seq_Len_KV)
weight /= math.sqrt(self.d_head)
# (Batch_Size, H, Seq_Len_Q, Seq_Len_KV)
weight = F.softmax(weight, dim=-1)
# (Batch_Size, H, Seq_Len_Q, Seq_Len_KV) @ (Batch_Size, H, Seq_Len_KV, Dim_Q / H) -> (Batch_Size, H, Seq_Len_Q, Dim_Q / H)
output = weight @ v
# (Batch_Size, H, Seq_Len_Q, Dim_Q / H) -> (Batch_Size, Seq_Len_Q, H, Dim_Q / H)
output = output.transpose(1, 2).contiguous()
# (Batch_Size, Seq_Len_Q, H, Dim_Q / H) -> (Batch_Size, Seq_Len_Q, Dim_Q)
output = output.view(input_shape)
# (Batch_Size, Seq_Len_Q, Dim_Q) -> (Batch_Size, Seq_Len_Q, Dim_Q)
output = self.out_proj(output)
# (Batch_Size, Seq_Len_Q, Dim_Q)
return output
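A similar shape check for CrossAttention (again an illustrative usage, not from the original notebook): queries come from x, while keys and values come from the context y, e.g. a text-encoder output of shape (Batch_Size, 77, 768):
xattn = CrossAttention(n_heads=4, d_embed=64, d_cross=768)
x = torch.randn(2, 49, 64)    # latent tokens
y = torch.randn(2, 77, 768)   # context tokens
print(xattn(x, y).shape)      # torch.Size([2, 49, 64])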
class VAE_AttentionBlock(nn.Module):
def __init__(self, channels):
super().__init__()
self.groupnorm = nn.GroupNorm(32, channels)
self.attention = SelfAttention(1, channels)
def forward(self, x):
# x: (Batch_Size, Features, Height, Width)
residue = x
# (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height, Width)
x = self.groupnorm(x)
n, c, h, w = x.shape
# (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height * Width)
x = x.view((n, c, h * w))
# (Batch_Size, Features, Height * Width) -> (Batch_Size, Height * Width, Features). Each pixel becomes a feature of size "Features", the sequence length is "Height * Width".
x = x.transpose(-1, -2)
# Perform self-attention WITHOUT mask
# (Batch_Size, Height * Width, Features) -> (Batch_Size, Height * Width, Features)
x = self.attention(x)
# (Batch_Size, Height * Width, Features) -> (Batch_Size, Features, Height * Width)
x = x.transpose(-1, -2)
# (Batch_Size, Features, Height * Width) -> (Batch_Size, Features, Height, Width)
x = x.view((n, c, h, w))
# (Batch_Size, Features, Height, Width) + (Batch_Size, Features, Height, Width) -> (Batch_Size, Features, Height, Width)
x += residue
# (Batch_Size, Features, Height, Width)
return x
class VAE_ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.groupnorm_1 = nn.GroupNorm(32, in_channels)
self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
self.groupnorm_2 = nn.GroupNorm(32, out_channels)
self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
if in_channels == out_channels:
self.residual_layer = nn.Identity()
else:
self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
def forward(self, x):
# x: (Batch_Size, In_Channels, Height, Width)
residue = x
# (Batch_Size, In_Channels, Height, Width) -> (Batch_Size, In_Channels, Height, Width)
x = self.groupnorm_1(x)
# (Batch_Size, In_Channels, Height, Width) -> (Batch_Size, In_Channels, Height, Width)
x = F.silu(x)
# (Batch_Size, In_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
x = self.conv_1(x)
# (Batch_Size, Out_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
x = self.groupnorm_2(x)
# (Batch_Size, Out_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
x = F.silu(x)
# (Batch_Size, Out_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
x = self.conv_2(x)
# (Batch_Size, Out_Channels, Height, Width) -> (Batch_Size, Out_Channels, Height, Width)
return x + self.residual_layer(residue)
class VAE_Decoder(nn.Sequential):
def __init__(self):
super().__init__(
# (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 4, Height / 8, Width / 8)
nn.Conv2d(4, 4, kernel_size=1, padding=0),
# (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
nn.Conv2d(4, 512, kernel_size=3, padding=1),
# (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
VAE_ResidualBlock(512, 512),
# (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
VAE_AttentionBlock(512),
# (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
VAE_ResidualBlock(512, 512),
# (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
VAE_ResidualBlock(512, 512),
# (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
VAE_ResidualBlock(512, 512),
# (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
VAE_ResidualBlock(512, 512),
# Repeats the rows and columns of the data by scale_factor (like when you resize an image by doubling its size).
# (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 4, Width / 4)
nn.Upsample(scale_factor=2),
# (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 4, Width / 4)
nn.Conv2d(512, 512, kernel_size=3, padding=1),
# (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 4, Width / 4)
VAE_ResidualBlock(512, 512),
# (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 4, Width / 4)
VAE_ResidualBlock(512, 512),
# (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 4, Width / 4)
VAE_ResidualBlock(512, 512),
# (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 2, Width / 2)
nn.Upsample(scale_factor=2),
# (Batch_Size, 512, Height / 2, Width / 2) -> (Batch_Size, 512, Height / 2, Width / 2)
nn.Conv2d(512, 512, kernel_size=3, padding=1),
# (Batch_Size, 512, Height / 2, Width / 2) -> (Batch_Size, 256, Height / 2, Width / 2)
VAE_ResidualBlock(512, 256),
# (Batch_Size, 256, Height / 2, Width / 2) -> (Batch_Size, 256, Height / 2, Width / 2)
VAE_ResidualBlock(256, 256),
# (Batch_Size, 256, Height / 2, Width / 2) -> (Batch_Size, 256, Height / 2, Width / 2)
VAE_ResidualBlock(256, 256),
# (Batch_Size, 256, Height / 2, Width / 2) -> (Batch_Size, 256, Height, Width)
nn.Upsample(scale_factor=2),
# (Batch_Size, 256, Height, Width) -> (Batch_Size, 256, Height, Width)
nn.Conv2d(256, 256, kernel_size=3, padding=1),
# (Batch_Size, 256, Height, Width) -> (Batch_Size, 128, Height, Width)
VAE_ResidualBlock(256, 128),
# (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height, Width)
VAE_ResidualBlock(128, 128),
# (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height, Width)
VAE_ResidualBlock(128, 128),
# (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height, Width)
nn.GroupNorm(32, 128),
# (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height, Width)
nn.SiLU(),
# (Batch_Size, 128, Height, Width) -> (Batch_Size, 3, Height, Width)
nn.Conv2d(128, 3, kernel_size=3, padding=1),
)
def forward(self, x):
# x: (Batch_Size, 4, Height / 8, Width / 8)
# Remove the scaling added by the Encoder.
x /= 0.18215
for module in self:
x = module(x)
# (Batch_Size, 3, Height, Width)
return x
class VAE_Encoder(nn.Sequential):
def __init__(self):
super().__init__(
# (Batch_Size, Channel, Height, Width) -> (Batch_Size, 128, Height, Width)
nn.Conv2d(3, 128, kernel_size=3, padding=1),
# (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height, Width)
VAE_ResidualBlock(128, 128),
# (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height, Width)
VAE_ResidualBlock(128, 128),
# (Batch_Size, 128, Height, Width) -> (Batch_Size, 128, Height / 2, Width / 2)
nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),
# (Batch_Size, 128, Height / 2, Width / 2) -> (Batch_Size, 256, Height / 2, Width / 2)
VAE_ResidualBlock(128, 256),
# (Batch_Size, 256, Height / 2, Width / 2) -> (Batch_Size, 256, Height / 2, Width / 2)
VAE_ResidualBlock(256, 256),
# (Batch_Size, 256, Height / 2, Width / 2) -> (Batch_Size, 256, Height / 4, Width / 4)
nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),
# (Batch_Size, 256, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 4, Width / 4)
VAE_ResidualBlock(256, 512),
# (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 4, Width / 4)
VAE_ResidualBlock(512, 512),
# (Batch_Size, 512, Height / 4, Width / 4) -> (Batch_Size, 512, Height / 8, Width / 8)
nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=0),
# (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
VAE_ResidualBlock(512, 512),
# (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
VAE_ResidualBlock(512, 512),
# (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
VAE_ResidualBlock(512, 512),
# (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
VAE_AttentionBlock(512),
# (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
VAE_ResidualBlock(512, 512),
# (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
nn.GroupNorm(32, 512),
# (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 512, Height / 8, Width / 8)
nn.SiLU(),
# With padding=1, the input height and width each grow by 2 (one pixel on each side):
# Padded_Height = In_Height + Padding_Top + Padding_Bottom = In_Height + 2
# Padded_Width = In_Width + Padding_Left + Padding_Right = In_Width + 2
# This compensates for the kernel size of 3, so the output keeps the input's spatial size.
# (Batch_Size, 512, Height / 8, Width / 8) -> (Batch_Size, 8, Height / 8, Width / 8).
nn.Conv2d(512, 8, kernel_size=3, padding=1),
# (Batch_Size, 8, Height / 8, Width / 8) -> (Batch_Size, 8, Height / 8, Width / 8)
nn.Conv2d(8, 8, kernel_size=1, padding=0),
)
def forward(self, x, noise):
# x: (Batch_Size, Channel, Height, Width)
# noise: (Batch_Size, 4, Height / 8, Width / 8)
for module in self:
if getattr(module, 'stride', None) == (2, 2): # Padding at downsampling should be asymmetric (see #8)
# Pad: (Padding_Left, Padding_Right, Padding_Top, Padding_Bottom).
# Pad with zeros on the right and bottom.
# (Batch_Size, Channel, Height, Width) -> (Batch_Size, Channel, Height + Padding_Top + Padding_Bottom, Width + Padding_Left + Padding_Right) = (Batch_Size, Channel, Height + 1, Width + 1)
x = F.pad(x, (0, 1, 0, 1))
x = module(x)
# (Batch_Size, 8, Height / 8, Width / 8) -> two tensors of shape (Batch_Size, 4, Height / 8, Width / 8)
mean, log_variance = torch.chunk(x, 2, dim=1)
# Clamp the log variance between -30 and 20, so that the variance is between (circa) 1e-14 and 1e8.
# (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 4, Height / 8, Width / 8)
log_variance = torch.clamp(log_variance, -30, 20)
# (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 4, Height / 8, Width / 8)
variance = log_variance.exp()
# (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 4, Height / 8, Width / 8)
stdev = variance.sqrt()
# Transform N(0, 1) -> N(mean, stdev)
# (Batch_Size, 4, Height / 8, Width / 8) -> (Batch_Size, 4, Height / 8, Width / 8)
x = mean + stdev * noise
# Scale by a constant
# Constant taken from: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/configs/stable-diffusion/v1-inference.yaml#L17C1-L17C1
x *= 0.18215
return x
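As a final hedged sketch (usage assumed, not part of the original notebook), the encoder and decoder can be round-tripped on a dummy RGB image; the noise tensor must match the latent shape (Batch_Size, 4, Height / 8, Width / 8):
encoder = VAE_Encoder()
decoder = VAE_Decoder()
img = torch.randn(1, 3, 64, 64)              # dummy RGB image
noise = torch.randn(1, 4, 64 // 8, 64 // 8)  # noise for the reparameterization step
with torch.no_grad():
    latent = encoder(img, noise)             # (1, 4, 8, 8)
    recon = decoder(latent)                  # (1, 3, 64, 64)
print(latent.shape, recon.shape)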