Implementing VAE
Let’s implement a VAE using the MNIST dataset!
Load MNIST
Torchvision provides the MNIST dataset out of the box. Let’s use it.
import torch
import torch.nn as nn
from torchvision import datasets
import torchvision.transforms as transforms

mnist_train = datasets.MNIST(root='MNIST_data/',
                             train=True,
                             transform=transforms.ToTensor(),
                             download=True)
mnist_test = datasets.MNIST(root='MNIST_data/',
                            train=False,
                            transform=transforms.ToTensor(),
                            download=True)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz to MNIST_data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz to MNIST_data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw
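As a quick optional check, we can confirm both splits loaded as expected (MNIST ships with 60,000 training and 10,000 test images):

# Optional sanity check: MNIST has 60,000 train / 10,000 test images.
print(len(mnist_train), len(mnist_test))  # 60000 10000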
Let’s set up the dataloaders.
BATCH_SIZE = 64
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cuda
from torch.utils.data import DataLoader

dataloaders = {}
dataloaders["train"] = DataLoader(dataset=mnist_train,
                                  batch_size=BATCH_SIZE,
                                  shuffle=True,
                                  drop_last=True)
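As another optional sanity check, we can pull one batch to confirm the tensor shapes:

# Optional: inspect one batch. MNIST images are 1x28x28 grayscale tensors.
x, y = next(iter(dataloaders["train"]))
print(x.shape, y.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])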
Model Implementation
A VAE has an encoder and a decoder as its components, the same overall architecture as an AE. The difference lies in sampling: μ and σ!
The encoder has two outputs, mu ($\mu$) and log_var ($\log \sigma^2$). The reason for predicting $\log \sigma^2$ rather than $\sigma$ itself is that the encoder’s output may be any real number, so it is not guaranteed to be positive, while we need $\sigma$ to be greater than $0$. Exponentiating half of $\log \sigma^2$ recovers $\sigma$ and guarantees positivity (a tiny illustration follows below).
Using this technique, the encoder parameterizes $\mathcal{N}(\mu, \sigma^2)$.
The decoder receives a latent $z \sim \mathcal{N}(\mu, \sigma^2)$, and from this $z$ it reconstructs the image.
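Here is a minimal illustration (not part of the model) of why the log-variance parameterization works: no matter what real value the encoder emits, $\exp(0.5 \log \sigma^2)$ is strictly positive.

# Illustrative only: exp(0.5 * log_var) maps any real number to a valid
# (strictly positive) standard deviation.
log_var = torch.tensor([-2.0, 0.0, 3.0])  # encoder may output any real value
std = torch.exp(0.5 * log_var)            # always > 0
print(std)  # tensor([0.3679, 1.0000, 4.4817])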
class VAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 32),
            nn.ReLU(),
        )
        # Two heads on top of the encoder: one for mu, one for log(sigma^2).
        self.fc_mu = nn.Linear(32, 10)
        self.fc_log_var = nn.Linear(32, 10)
        self.decoder = nn.Sequential(
            nn.Linear(10, 32),
            nn.ReLU(),
            nn.Linear(32, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
        )

    def forward(self, x):
        x = x.reshape(-1, 1*28*28)            # flatten 1x28x28 image to 784
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        out = self.decoder(z)
        out = out.reshape(-1, 1, 28, 28)      # back to image shape
        return out, mu, log_var

    def encode(self, x):
        h = self.encoder(x)
        mu, log_var = self.fc_mu(h), self.fc_log_var(h)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        # z = mu + eps * sigma with eps ~ N(0, I); keeps sampling differentiable.
        eps = torch.randn_like(mu)
        std = torch.exp(0.5 * log_var)
        z = mu + eps * std
        return z

    def decode(self, z):
        return self.decoder(z)
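A quick optional shape check with a dummy batch confirms everything lines up:

# Optional shape check with a dummy batch (illustrative only).
dummy = torch.zeros(4, 1, 28, 28)
out, mu, log_var = VAE()(dummy)
print(out.shape, mu.shape, log_var.shape)
# torch.Size([4, 1, 28, 28]) torch.Size([4, 10]) torch.Size([4, 10])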
However, the model itself does not have the power to generate images; the loss term is what makes it generative. The loss has two parts: $KLD$ and $Reconstruction$.
- $KLD$ forces the encoder’s outputs to follow a certain distribution, namely $\mathcal{N}(0, I)$; its closed form is given below.
- $Reconstruction$ minimizes the difference between the input and its decoded output.
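For two Gaussians $\mathcal{N}(\mu, \sigma^2)$ and $\mathcal{N}(0, I)$, the KL divergence has a well-known closed form, and it is exactly what the code below computes per latent dimension $j$:

$$D_{KL}\left(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, I)\right) = \frac{1}{2} \sum_{j} \left( \mu_j^2 + \sigma_j^2 - 1 - \log \sigma_j^2 \right)$$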
from torch import optim
import torch.nn.functional as F

def loss_fn(x, recon_x, mu, log_var):
    # Reconstruction: summed squared error between input and output pixels.
    recon_loss = F.mse_loss(recon_x.view(-1, 784), x.view(-1, 784), reduction="sum")
    # KL divergence between N(mu, sigma^2) and N(0, I), in closed form.
    kld_loss = 0.5 * torch.sum(mu.pow(2) + log_var.exp() - 1 - log_var)
    return recon_loss + kld_loss

model = VAE()
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
Train!
from tqdm import tqdm

epochs = 10
for e in range(1, epochs + 1):
    train_loss = 0.0
    for x, _ in tqdm(dataloaders["train"]):
        x = x.to(device)
        x_recon, mu, log_var = model(x)

        optimizer.zero_grad()
        loss = loss_fn(x, x_recon, mu, log_var)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
    train_loss /= len(dataloaders["train"].dataset)
    print(f"EPOCH {e}: {train_loss}")
EPOCH 1: 46.222997973632815
EPOCH 2: 36.6455737508138
EPOCH 3: 34.71712950642904
EPOCH 4: 33.845694342041014
EPOCH 5: 33.33422303466797
EPOCH 6: 32.95128912760417
EPOCH 7: 32.67411953735352
EPOCH 8: 32.436136362711586
EPOCH 9: 32.26390340779622
EPOCH 10: 32.058827408854164
Generate!
Let’s try generating some images using our trained VAE.
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import numpy as np

def to_img(x):
    # Clamp pixel values into the displayable [0, 1] range.
    x = x.clamp(0, 1)
    return x

def show_image(img):
    img = to_img(img)
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # CHW -> HWC for matplotlib
To do so, let’s sample $z \sim \mathcal{N}(0, I)$ and feed it to the decoder as input. torch.randn draws exactly such samples.
with torch.no_grad():
    # Sample latent vectors from the standard normal distribution.
    latent = torch.randn(128, 10, device=device)

    # Generate images from the latent vectors.
    img_recon = model.decode(latent)
    img_recon = img_recon.cpu()

    fig, ax = plt.subplots(figsize=(10, 10))
    imgs = img_recon.reshape(-1, 1, 28, 28).data[:100]
    show_image(make_grid(imgs, 10, 5))
    plt.show()
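As a follow-up, the mnist_test split we loaded earlier can be used to check reconstructions rather than pure samples. A minimal sketch, reusing the helpers above:

# A minimal sketch: reconstruct a few held-out test images with the trained model.
with torch.no_grad():
    x_test = torch.stack([mnist_test[i][0] for i in range(10)]).to(device)
    x_recon, _, _ = model(x_test)
    # Top row: originals; bottom row: reconstructions.
    grid = make_grid(torch.cat([x_test, x_recon]).cpu(), nrow=10)
    show_image(grid)
    plt.show()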