
Generative Adversarial Networks (GANs)

Overview

Generative adversarial networks (GANs) are a family of model architectures that use an adversarial training approach to produce a generative model capable of synthesizing realistic data.

The generated data can be of any type, but GANs are most commonly used to generate synthetic images, which is the type of data we will be working with in this post. GANs involve two key architectural components: the generator and the discriminator. During training, these two separate models work in tandem. The generator is tasked with producing synthetic data samples - images, in our case - and the discriminator is tasked with determining whether a given sample is a real sample from a ground truth dataset or a synthetic sample produced by the generator.

As the generator learns to produce increasingly convincing synthetic samples, the discriminator also becomes increasingly effective at differentiating synthetic samples from real samples. Because the generator's output is used as input to the discriminator, backpropagation computes gradients for both components from a single forward pass. In practice, the generator and discriminator are updated in alternating steps, each with its own optimizer, as we will see in the training loop below.

GANs suffer from instability during training, and true convergence is rare. There are formulations of the loss functions that improve training stability, but all GANs involve some degree of instability due to their inherently adversarial nature.

GAN Objectives

Vanilla GANs

GANs were originally introduced in 2014 by Goodfellow et al. in the seminal paper Generative Adversarial Networks. The paper presents GANs in a couple of different ways, with the original articulation taking the form:

\[\min_{G} \max_{D} \mathbb{E}_{x \sim p_{data}} \left[ \log(D(x)) \right] + \mathbb{E}_{z \sim p(z)} \left[ \log(1 - D(G(z))) \right]\]

Where:

  • $z \sim p(z)$ represents data sampled from uniform random noise
  • $x \sim p_{data}$ represents data sampled from the ground truth data set: images, in our case
  • $G$ represents the generator neural network
    • $G(z)$ represents an image/data sample produced by the generator from noise $z$
  • $D$ represents the discriminator neural network
    • $D(x)$ represents the binary judgment from the discriminator as to whether $x$ is a real (i.e. $D(x) = 1$) or synthetic (i.e. $D(x) = 0$) sample

Goodfellow et al. analyze this particular articulation, establishing a relationship between this task and minimizing the Jensen-Shannon divergence between the distribution of the training data and the distribution of the synthetic data produced by the generator $G$.

Intuitively, this equation can be understood as updating the generator $G$ to minimize the probability of the discriminator making the correct choice, and updating the discriminator $D$ to maximize the probability of the discriminator making the correct choice.

Implementations of GANs, however, do not use this articulation, since it tends to perform poorly in practice: the generator's gradients vanish when the discriminator becomes highly confident.

We instead articulate the objective of the generator in a slightly different way to yield better performance: maximize the probability of the discriminator making the incorrect choice. This approach was also introduced by Goodfellow et al. and is used by most GAN implementations.

The slightly adjusted objective can then be broken up into two sub-objectives for the generator and discriminator updates, respectively:

\[\max_{G} \mathbb{E}_{z \sim p(z)} \left[ \log(D(G(z))) \right]\]

and

\[\max_{D} \mathbb{E}_{x \sim p_{data}} \left[ \log(D(x)) \right] + \mathbb{E}_{z \sim p(z)} \left[ \log(1 - D(G(z))) \right]\]

These two objectives correspond to the following loss functions for the generator $\ell_G$ and the discriminator $\ell_D$, respectively:

\[\ell_G = - \mathbb{E}_{z \sim p(z)} \left[ \log(D(G(z))) \right]\]

and

\[\ell_D = - \mathbb{E}_{x \sim p_{data}} \left[ \log(D(x)) \right] - \mathbb{E}_{z \sim p(z)} \left[ \log(1 - D(G(z))) \right]\]

We negate these loss functions relative to the objectives stated above since our goal during training is generally to minimize the loss function. Also note that the expectation $\mathbb{E}$ is computed as the average over each minibatch.
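
As a concrete illustration (separate from the implementation later in this post), here is a minimal sketch of how these two loss functions might be computed in PyTorch, assuming the discriminator outputs raw, unnormalized logits; binary cross-entropy with logits recovers the $\log$ terms above and averages over the minibatch:

import torch
import torch.nn.functional as F

def vanilla_generator_loss(scores_fake: torch.Tensor) -> torch.Tensor:
    # -E[log(D(G(z)))]: push the discriminator's output on fake
    # samples toward the "real" label (1)
    targets = torch.ones_like(scores_fake)
    return F.binary_cross_entropy_with_logits(scores_fake, targets)

def vanilla_discriminator_loss(
    scores_real: torch.Tensor,
    scores_fake: torch.Tensor,
) -> torch.Tensor:
    # -E[log(D(x))] - E[log(1 - D(G(z)))]: label real samples as 1
    # and fake samples as 0; the expectations are minibatch means,
    # which binary_cross_entropy_with_logits computes by default
    real_targets = torch.ones_like(scores_real)
    fake_targets = torch.zeros_like(scores_fake)
    return (
        F.binary_cross_entropy_with_logits(scores_real, real_targets)
        + F.binary_cross_entropy_with_logits(scores_fake, fake_targets)
    )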

This approach yielded groundbreaking results when Goodfellow et al. published their paper in 2014. We know now, however, that there are further improvements we can make to these objectives to improve the stability of the network during training and, therefore, the quality of both the generator and the discriminator.

A commonly-used improvement over the preceding objective articulation is that of the Least Squares GAN objectives.

Least Squares GANs

The GAN objective above can be adjusted to use a least squares approach. Empirically, GANs trained with these objectives are more stable to train and produce better results.

The least squares loss functions for the generator $\ell_G$ and the discriminator $\ell_D$ are very similar to the loss functions articulated in the previous section.

We can articulate the loss function for the generator $\ell_G$ as:

\[\ell_G = \frac{1}{2} \mathbb{E}_{z \sim p(z)} \left[ (D(G(z)) - 1)^2 \right]\]

$\ell_G$ penalizes the generator when the discriminator correctly identifies a sample $G(z)$ as synthetic. When $D(G(z)) = 1$, the discriminator has identified the synthetic sample as a ground truth sample rather than as a synthetic sample originating from the generator, in which case the loss for the generator is zeroed out. When $D(G(z)) = 0$, the discriminator has identified the sample as a synthetic sample originating from the generator, in which case the loss for the generator becomes $(0-1)^2 = (-1)^2 = 1$.

And we can articulate the loss function for the discriminator $\ell_D$ as:

\[\ell_D = \frac{1}{2} \mathbb{E}_{x \sim p_{data}} \left[ (D(x) - 1)^2 \right] + \frac{1}{2} \mathbb{E}_{z \sim p(z)} \left[ D(G(z))^2 \right]\]

$\ell_D$ penalizes the discriminator in two ways: the first term penalizes the discriminator when $D(x) = 0$, denoting a synthetic sample, when $x$ was actually sampled from the ground truth data (i.e. $x \sim p_{data}$); the second term penalizes the discriminator when $D(G(z)) = 1$, indicating the discriminator identified $G(z)$ as an authentic sample despite the sample having originated from the generator and being based on uniform random noise $z \sim p(z)$.

In both $\ell_G$ and $\ell_D$, the $\frac{1}{2}$ coefficients exist solely to simplify the expressions when they are differentiated during gradient computation.

In each loss function, the expectation $\mathbb{E}$ is computed by averaging over each minibatch.
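
To make these penalties concrete, suppose that on a single pair of samples the discriminator outputs $D(x) = 0.9$ for a real image and $D(G(z)) = 0.2$ for a synthetic image. Then:

\[\ell_G = \frac{1}{2}(0.2 - 1)^2 = 0.32 \qquad \ell_D = \frac{1}{2}(0.9 - 1)^2 + \frac{1}{2}(0.2)^2 = 0.025\]

The discriminator is doing well on both samples, so its loss is small, while the generator, whose sample was largely caught, incurs a larger penalty.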

Code

In this section, we provide the core logic for instantiating the dataset and dataloader objects, some helper functions, the loss functions, the generator and discriminator models, and the core training loop.

Note that we use the MNIST dataset provided through the torchvision package for this project.

The code in this post has been adapted from a personal project for this blog post. If you attempt to piece everything together into a single, functional script, you may find that you need to make some changes to get it working. However, the core logic is unchanged from my original implementation, so it should work with minimal effort.

Dataset Plumbing

Before we jump into the GAN-specific code, we will need to set up some plumbing.

We begin by importing some essential libraries:

import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
import torchvision.datasets as dset
from torch.utils.data import DataLoader, Sampler

BATCH_SIZE = 128
NOISE_DIM = 96
N_TRAIN_SAMPLES = 50000
N_VAL_SAMPLES = 5000

Let’s also define a helper function to generate uniform random noise, corresponding to $z \sim p(z)$ in the mathematical definitions, with which we will seed the generator model during training:

def sample_noise(batch_size: int, dim: int) -> torch.Tensor:
    # Generate uniform random noise in the range [-1.0, 1.0)
    rng = np.random.default_rng()
    noise = rng.uniform(
        low=-1.0,
        high=1.0,
        size=(batch_size, dim),
    )
    # Convert to a torch tensor so the result can be moved to the
    # training device and fed directly to the generator
    return torch.from_numpy(noise).to(torch.float32)
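
For example, a single call produces a noise tensor shaped for a full batch (a quick sanity check using the constants defined above):

z = sample_noise(BATCH_SIZE, NOISE_DIM)
print(z.shape)                          # torch.Size([128, 96])
print(z.min() >= -1.0, z.max() < 1.0)   # tensor(True) tensor(True)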

We also define an OffsetSampler that will make it easy for us to train and validate on different subsets of the MNIST training corpus:

class OffsetSampler(Sampler):
    def __init__(self, n_samples, start=0):
        self.n_samples = n_samples
        self.start = start

    def __iter__(self):
        # Yield dataset indices [start, start + n_samples)
        return iter(range(self.start, self.start + self.n_samples))

    def __len__(self):
        return self.n_samples

We now instantiate our dataset and dataloader objects:

# T.ToTensor transposes dimensions from (H, W, C) to (C, H, W)
# and rescales pixels from [0, 255] to [0.0, 1.0].
train_mnist = dset.MNIST(
    './mnist-data',
    train=True,
    download=True,
    transform=T.ToTensor(),
)
train_loader = DataLoader(
    train_mnist,
    batch_size=BATCH_SIZE,
    sampler=OffsetSampler(N_TRAIN_SAMPLES, 0),
)

val_mnist = dset.MNIST(
    './mnist-data',
    train=True,
    download=True,
    transform=T.ToTensor(),
)
val_loader = DataLoader(
    val_mnist,
    batch_size=BATCH_SIZE,
    sampler=OffsetSampler(N_VAL_SAMPLES, N_TRAIN_SAMPLES),
)
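
As a quick sanity check (not part of the training script), we can pull one batch from the training loader and confirm its shape and pixel range:

images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([128, 1, 28, 28])
print(labels.shape)  # torch.Size([128])
# T.ToTensor rescales pixels, so values fall within [0.0, 1.0]
print(images.min().item(), images.max().item())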

Loss Functions

Next, we implement our least squares loss functions for the generator and the discriminator, respectively:

class LSGeneratorLoss(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, scores_fake: torch.Tensor) -> torch.Tensor:
        return (1/2) * torch.square(scores_fake - 1).mean()


class LSDiscriminatorLoss(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, scores_fake: torch.Tensor, scores_real: torch.Tensor) -> torch.Tensor:
        return (
            (1/2) * torch.square(scores_real - 1).mean() +
            (1/2) * torch.square(scores_fake).mean()
        )
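
As a quick check of these modules, a discriminator that scores every real sample as 1 and every synthetic sample as 0 should incur zero loss, while the generator's loss in that case reduces to $\frac{1}{2}$:

gen_loss_fn = LSGeneratorLoss()
disc_loss_fn = LSDiscriminatorLoss()

scores_real = torch.ones(4, 1)   # discriminator scores real samples as 1
scores_fake = torch.zeros(4, 1)  # discriminator scores fake samples as 0

print(disc_loss_fn(scores_fake, scores_real))  # tensor(0.)
print(gen_loss_fn(scores_fake))                # tensor(0.5000)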

Models

We could theoretically implement the generator and discriminator models using any arbitrary architecture, so long as the input/output shapes match our needs.

However, we are going to implement models based on techniques introduced in the deep convolutional GAN (DCGAN) paper by Radford et al. These deep convolutional models are dramatically better suited to spatial reasoning because of the translation equivariance of the learned convolutional kernels.

We define our models as follows:

class DCGenerator(nn.Module):
    def __init__(self, noise_dim: int) -> None:
        super().__init__()

        self.transforms = nn.Sequential(
            nn.Linear(noise_dim, 1024),
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, 128 * 7 * 7),
            nn.ReLU(),
            nn.BatchNorm1d(128 * 7 * 7),
            # Unflatten dim 1 since we want to keep
            # dim 0, the batch dimension, unchanged
            nn.Unflatten(1, (128, 7, 7)),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 1, 4, 2, 1),
            nn.Tanh(),
            nn.Flatten(),
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return self.transforms(X)


class DCDiscriminator(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.classifier = nn.Sequential(
            # Unflatten dim 1 since we want to keep
            # dim 0, the batch dimension, unchanged
            nn.Unflatten(1, (1, 28, 28)),
            nn.Conv2d(1, 32, 5, 1),
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 5, 1),
            nn.LeakyReLU(0.01),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64 * 4 * 4),
            nn.LeakyReLU(0.01),
            nn.Linear(64 * 4 * 4, 1),
        )

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        return self.classifier(X)
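
Before training, it is worth confirming that the tensor shapes line up: the generator maps noise of shape (batch, NOISE_DIM) to flattened images of shape (batch, 784), and the discriminator consumes those flattened images and produces one raw score per sample. A quick check (eval mode is used here only so BatchNorm falls back to its default running statistics):

G = DCGenerator(NOISE_DIM)
D = DCDiscriminator()
G.eval()
D.eval()

with torch.no_grad():
    z = sample_noise(8, NOISE_DIM)
    fake = G(z)       # flattened 28x28 images
    scores = D(fake)  # one raw score per sample

print(fake.shape)    # torch.Size([8, 784])
print(scores.shape)  # torch.Size([8, 1])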

Core Training Loop

Finally, we can tie everything together with the following core training logic:

device: torch.device = torch.device('cpu')
if torch.backends.mps.is_available():
    # Apple silicon GPU
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')

generator: DCGenerator = DCGenerator(NOISE_DIM).to(device)
gen_optim: torch.optim.Adam = torch.optim.Adam(
    generator.parameters(),
    lr=1e-3,
    betas=(0.5, 0.999),
)
gen_loss: LSGeneratorLoss = LSGeneratorLoss()

discriminator: DCDiscriminator = DCDiscriminator().to(device)
disc_optim: torch.optim.Adam = torch.optim.Adam(
    discriminator.parameters(),
    lr=1e-3,
    betas=(0.5, 0.999),
)
disc_loss: LSDiscriminatorLoss = LSDiscriminatorLoss()

for epoch in range(10):
    for x, _ in train_loader:
        x = x.to(device)
        # The last batch of an epoch may be smaller than BATCH_SIZE,
        # so derive the batch size from the data itself
        batch_size = x.shape[0]

        # --- Discriminator update ---
        disc_optim.zero_grad()
        # The discriminator expects flattened images of shape (batch, 784)
        real_data = x.to(torch.float32).view(batch_size, -1)
        # `real_data` is preprocessed by T.ToTensor to have values in
        # the [0.0, 1.0] interval, but the uniform random noise that
        # our generator consumes and the data it produces are in the
        # [-1.0, 1.0) interval.
        #
        # We therefore subtract 0.5 from `real_data`, yielding data in
        # the [-0.5, 0.5] interval, and then multiply by 2, yielding
        # data in the [-1.0, 1.0] interval so that the input to our
        # discriminator matches the output of our generator.
        logits_real = discriminator(2 * (real_data - 0.5))

        g_fake_seed = sample_noise(batch_size, NOISE_DIM).to(device)
        # Detach so this discriminator update does not backpropagate
        # into the generator
        fake_images = generator(g_fake_seed).detach()
        logits_fake = discriminator(fake_images)

        disc_loss(logits_fake, logits_real).backward()
        disc_optim.step()

        # --- Generator update ---
        gen_optim.zero_grad()
        g_fake_seed = sample_noise(batch_size, NOISE_DIM).to(device)
        fake_images = generator(g_fake_seed)

        gen_logits_fake = discriminator(fake_images)
        gen_loss(gen_logits_fake).backward()
        gen_optim.step()
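
Once training has finished (or periodically during it), we can inspect what the generator produces. Below is a minimal sketch using matplotlib (an extra dependency not used elsewhere in this post), assuming the trained generator from the loop above:

import matplotlib.pyplot as plt

generator.eval()
with torch.no_grad():
    samples = generator(sample_noise(16, NOISE_DIM).to(device)).cpu()

# Map the generator's tanh output from [-1.0, 1.0] back to [0.0, 1.0]
samples = (samples + 1.0) / 2.0

fig, axes = plt.subplots(4, 4, figsize=(4, 4))
for image, ax in zip(samples, axes.flatten()):
    ax.imshow(image.view(28, 28).numpy(), cmap='gray')
    ax.axis('off')
plt.show()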

Conclusion

After approximately 1750 training batches, the generator is capable of producing the following digits:

[Figure: synthetic MNIST digits produced by the generator after approximately 1750 training batches]

These are compelling results from a relatively simple model, trained with relatively simple objectives and loss functions on consumer hardware!

The MNIST dataset is, of course, about as simple as it gets when it comes to computer vision tasks: the images are small, grayscale, and require little fine detail. But the architecture and concepts scale up very nicely - as long as you have the compute!
