The variance schedule can be improved from linear to cosine
The model can be improved by learning the variances of \(p_\theta(x_{t-1} | x_t)\) instead of fixing it
The loss can be improved to the hybrid objective \(L = L_{simple} + \lambda L_{vlb}\)
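To make the linear-versus-cosine comparison concrete, here is a minimal sketch of both schedules in plain Python. The function names and default values are assumptions for illustration: the linear defaults (`1e-4` to `0.02`) follow the values commonly used with Ho et al. 2020, and the cosine form (offset `s=0.008`, clipping at `0.999`) follows the schedule proposed by Nichol and Dhariwal (2021). Nothing here is taken from the code behind this post.

```python
import math

def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Betas spaced linearly between beta_start and beta_end (assumed defaults)."""
    step = (beta_end - beta_start) / (timesteps - 1)
    return [beta_start + i * step for i in range(timesteps)]

def cosine_beta_schedule(timesteps, s=0.008, max_beta=0.999):
    """Betas derived from a squared-cosine curve for the cumulative alpha product."""
    def alpha_bar(t):
        return math.cos((t / timesteps + s) / (1 + s) * math.pi / 2) ** 2

    betas = []
    for i in range(timesteps):
        beta = 1.0 - alpha_bar(i + 1) / alpha_bar(i)
        betas.append(min(beta, max_beta))  # clip to avoid a singular last step
    return betas

def cumulative_alphas(betas):
    """Running product of (1 - beta_t), i.e. alpha_bar_t for every step."""
    out, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        out.append(prod)
    return out
```

Printing `cumulative_alphas(...)` for both schedules at a few steps is a quick way to see how fast each one destroys the signal.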
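The forward (noising) process is \(q(x_t | x_{t-1})\) and the reverse (denoising) process is \(p_\theta(x_{t-1} | x_t)\). A useful closed-form property is that we don’t need to apply \(q\) iteratively to sample \(x_t\). With \(\alpha_t = 1 - \beta_t\) and \(\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s\), we have
\[ q(x_t | x_0) = \mathcal{N}\left(x_t; \sqrt{\bar{\alpha}_t}\, x_0, (1 - \bar{\alpha}_t) I\right) \]
This allows us to optimize random terms of the loss \(L\), i.e. to randomly sample \(t\) during training and optimize \(L_t\)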
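The one-shot jump from \(x_0\) to \(x_t\) can be sketched in NumPy as follows. This is an illustrative sketch, not the code used in this post: `make_alpha_bars` and `q_sample` are hypothetical names, the linear schedule defaults are assumed, and `t` is treated as a 0-based index into the schedule.

```python
import numpy as np

def make_alpha_bars(timesteps, beta_start=1e-4, beta_end=0.02):
    """Cumulative product of (1 - beta_t) for an assumed linear schedule."""
    betas = np.linspace(beta_start, beta_end, timesteps)
    return np.cumprod(1.0 - betas)

def q_sample(x0, t, alpha_bars, rng):
    """Draw x_t ~ q(x_t | x_0) directly, without iterating over steps.

    x0: array of shape (batch, ...); t: integer array of shape (batch,).
    Uses x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise.
    """
    noise = rng.standard_normal(x0.shape)
    # Reshape alpha_bar_t so it broadcasts over the non-batch dimensions.
    a = alpha_bars[t].reshape(-1, *([1] * (x0.ndim - 1)))
    x_t = np.sqrt(a) * x0 + np.sqrt(1.0 - a) * noise
    return x_t, noise
```

Returning the noise alongside \(x_t\) is convenient because the noise is exactly the regression target for the network.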
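As shown in Ho et al. (2020), one can reparametrize the mean so that the network learns the added noise \(\epsilon_\theta(x_t, t)\) for noise level \(t\) in the KL terms that constitute the loss. This means that our NN becomes a noise predictor rather than a direct mean predictor. With \(\alpha_t = 1 - \beta_t\) and \(\bar{\alpha}_t\) the cumulative product of \(\alpha_s\) over steps \(1, \dots, t\), the mean can be computed as
\[ \mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\, \epsilon_\theta(x_t, t) \right) \]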
We take a random sample \(x_0\) from the real unknown and possibly complex distribution \(q(x_0)\)
We sample a noise level \(t\) uniformly (or with a more elaborate sampling scheme) between \(1\) and \(T\) (i.e. a random time step)
We sample some noise from a Gaussian distribution and corrupt the input to noise level \(t\) using the equations above
The neural network is trained to predict this noise from the corrupted image \(x_t\), i.e. the noise applied to \(x_0\) according to the known schedule \(\beta_t\)
In practice, this is done on batches of data as one uses SGD to optimize the neural network
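The steps above can be sketched as a single loss computation on one batch. This is a minimal NumPy sketch under stated assumptions: `training_loss` is a hypothetical name, `model` is any callable that maps `(x_t, t)` to a noise prediction of the same shape as `x_t`, images are 4D `(batch, channels, height, width)`, `t` is 0-based, and a plain L2 loss is used. The gradient step itself is left to whatever autograd framework trains the real network.

```python
import numpy as np

def training_loss(model, x0_batch, alpha_bars, rng):
    """Loss for one batch: sample t, add noise, predict the noise, compare."""
    batch = x0_batch.shape[0]
    num_steps = len(alpha_bars)

    # Random noise level per image (uniform over the schedule).
    t = rng.integers(0, num_steps, size=batch)

    # Gaussian noise, then the closed-form corruption to level t.
    eps = rng.standard_normal(x0_batch.shape)
    a = alpha_bars[t].reshape(-1, 1, 1, 1)
    x_t = np.sqrt(a) * x0_batch + np.sqrt(1.0 - a) * eps

    # The network predicts the noise from the corrupted image and t.
    eps_pred = model(x_t, t)

    # Mean squared error between true and predicted noise.
    return np.mean((eps - eps_pred) ** 2)
```

A model that always predicts zero gives a loss close to 1 (the variance of the noise), which is a handy baseline when checking that training is making progress.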
0. Dependencies
1. Model Architecture
The model’s input and output have the same dimensions. It uses the noisy image at step \(t\) to predict the corresponding noise, then removes that noise according to the variance schedule to obtain a slightly denoised image at \(t-1\), and so on until the image is fully denoised at \(t=0\). The trick that makes the model aware of what kind of filter (lowpass, highpass, or something in between) it should behave as lies in the time step embedding.
The model consists of 3 parts: UNet, time embedding, and pixel-to-pixel attention.
1.1 Blocks for UNet
The UNet blocks using ResNet or ConvNeXT convolutional models:
1.2 Time Embeddings
The time embeddings reuse the sinusoidal positional encoding from the Transformer paper (Vaswani et al., 2017).
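A minimal NumPy sketch of such a sinusoidal embedding is shown below. The function name is hypothetical, and the layout (all sines followed by all cosines, frequencies spaced geometrically down to 1/10000) is one common convention in diffusion code; it differs from the interleaved layout in the Transformer paper and is an assumption here.

```python
import math
import numpy as np

def sinusoidal_time_embedding(t, dim):
    """Map integer time steps of shape (batch,) to vectors of shape (batch, dim).

    Assumes dim is even and at least 4.
    """
    half = dim // 2
    # Geometric progression of frequencies from 1 down to 1/10000.
    freqs = np.exp(-math.log(10000.0) * np.arange(half) / (half - 1))
    args = t[:, None].astype(np.float64) * freqs[None, :]
    # Concatenate sine and cosine features along the last axis.
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)
```

Each time step gets a distinct vector with values in [-1, 1], which the network can then project and inject into its blocks.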
1.3 Attention
2 variants of attention:
regular multi-head self-attention (as used in the Transformer),
linear attention variant (Shen et al., 2018), whose time and memory requirements scale linearly in the sequence length, as opposed to quadratically for regular attention.
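The difference between the two variants can be sketched for a single head in NumPy. This is an illustrative sketch in the spirit of Shen et al., not the attention module used in this post: function names are hypothetical, and `q`, `k`, `v` are assumed to be already projected arrays of shape `(n, d)`, where `n` is the number of pixels.

```python
import numpy as np

def softmax(x, axis):
    """Numerically stable softmax along the given axis."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def full_attention(q, k, v):
    """Regular attention: builds an (n, n) score matrix, quadratic in n."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    return softmax(scores, axis=-1) @ v

def linear_attention(q, k, v):
    """Linear variant: builds a (d, d) context matrix, linear in n."""
    q = softmax(q, axis=-1)   # normalize each query over features
    k = softmax(k, axis=0)    # normalize each feature over positions
    context = k.T @ v         # (d, d), independent of sequence length
    return q @ context        # (n, d)
```

The two functions do not produce the same numbers; the linear variant is an approximation that trades exact pairwise weighting for a much smaller intermediate matrix.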
1.4 The UNet
The job of the network $\epsilon_\theta(x_t, t)$ is to take in a batch of noisy images together with their noise levels and output the noise added to the input. That is, the network takes a batch of noisy images of shape (batch_size, num_channels, height, width) and a batch of noise levels of shape (batch_size, 1) as input. Note that the noise levels are just batch_size random integers. The network returns a tensor of the same shape (batch_size, num_channels, height, width) containing the predicted noise.
The network is built up as follows:
First, a conv layer is applied to the batch of noisy images, and positional embeddings are computed from the noise level (an integer).
Second, a sequence of downsampling stages is applied. Each downsampling stage consists of
2 ResNet/ConvNeXT blocks
Group norm
Attention
Residual connection
Down sampling operation
At the middle of the network, ResNet/ConvNeXT blocks are applied, interleaved with attention
Next, a sequence of upsampling stages is applied. Each upsampling stage consists of
2 ResNet/ConvNeXT blocks
Group norm
Attention
Residual connection
Up sampling operation
Finally, a ResNet/ConvNeXT block followed by a conv layer is applied
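To visualize how the stages above change the tensor as it flows through the network, here is a small shape-tracing sketch in plain Python. The function name, the base width of 64 channels, and the multipliers `(1, 2, 4, 8)` are all assumptions for illustration; the actual widths used in this post are not stated. The sketch also assumes each downsampling halves the resolution and that the last stage on each path does not resample.

```python
def unet_shape_trace(image_size=128, base_dim=64, dim_mults=(1, 2, 4, 8)):
    """Return (stage, channels, spatial_size) tuples for a UNet-like layout."""
    dims = [base_dim * m for m in dim_mults]
    trace = []
    size = image_size

    # Down path: every stage but the last halves the resolution.
    for i, d in enumerate(dims):
        trace.append(("down", d, size))
        if i < len(dims) - 1:
            size //= 2

    # Bottleneck operates at the coarsest resolution.
    trace.append(("mid", dims[-1], size))

    # Up path mirrors the down path and doubles the resolution again.
    for i, d in enumerate(reversed(dims)):
        trace.append(("up", d, size))
        if i < len(dims) - 1:
            size *= 2
    return trace
```

With these assumed defaults, a 128 x 128 input reaches the bottleneck at 16 x 16 and returns to 128 x 128 at the output, which is what lets the input and output share the same shape.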
2. Forward Diffusion (Fixed)
Gradually add noise to the input data (images) over a number (\(T\)) of steps. Note that this forward process is designed and scheduled in advance, but its result is stochastic because of the added noise.
To standardize the images, we resize them all to the same size: 128 x 128
The forward diffusion process gradually adds noise to the image until the signal is destroyed.
3. The Loss Function
The loss function is the L2 or L1 loss between the added noise and the predicted noise. People have found that Huber loss gives better results because it is less sensitive to outliers in the pixels that Gaussian noise occasionally introduces.
Note that we train the noise predictor $\epsilon_\theta(x_t, t)$ to achieve this noise-2-noise prediction task
And here, $\epsilon_\theta$ is the UNet.
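The three candidate losses can be compared with a short NumPy sketch. The function names are hypothetical, and the Huber threshold `delta=1.0` is an assumed default; the Huber form is quadratic for small errors and linear for large ones, which is what limits the influence of outliers.

```python
import numpy as np

def l1_loss(noise, pred):
    """Mean absolute error between true and predicted noise."""
    return np.mean(np.abs(noise - pred))

def l2_loss(noise, pred):
    """Mean squared error between true and predicted noise."""
    return np.mean((noise - pred) ** 2)

def huber_loss(noise, pred, delta=1.0):
    """Quadratic below delta, linear above it."""
    err = np.abs(noise - pred)
    quadratic = 0.5 * err ** 2
    linear = delta * (err - 0.5 * delta)
    return np.mean(np.where(err <= delta, quadratic, linear))
```

For example, with errors of 0.5, 0, 0 and 10, the L2 loss is dominated by the single large error (about 25.06), while the Huber loss stays near 2.41.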
Before Going Too Far..
So now we have a forward diffusion process, a noise-2-noise model and a loss to train the model.
If we have 10 images and T=200 diffusion steps, covering every image at every noise level means going through 10 x 200 = 2000 noisy images, and this number scales with the number of images.
In training, typically this is a batch b:
t0, img3 | t0
t1, img1 | t1
t2, img9 | t2
t3, img1 | t3
…
tb, img8 | tb
So, unlike supervised label prediction, the model needs to be trained for a large number of epochs.
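A small simulation illustrates why many epochs are needed. It is a sketch under an assumption that is not stated in the post: each epoch visits every image once and pairs it with a single uniformly sampled noise level. The function name is hypothetical.

```python
import random

def coverage_after_epochs(num_images, timesteps, epochs, seed=0):
    """Fraction of distinct (image, noise level) pairs seen after some epochs."""
    rng = random.Random(seed)
    seen = set()
    for _ in range(epochs):
        for img in range(num_images):      # each image is visited once per epoch
            t = rng.randrange(timesteps)   # paired with one random noise level
            seen.add((img, t))
    return len(seen) / (num_images * timesteps)
```

With 10 images and 200 steps, one epoch covers only 10 of the 2000 pairs (0.5%), and even 200 epochs are expected to cover only about 63% of them, since the same pair can be drawn more than once.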
4. Data and Dataloader for diffusion
I used the flower dataset to play with the DDPM model.
Still, they are of different sizes and need to be standardized:
5. Sampling (Generating)
Each sampling step takes a batch of images and a batch of time steps. Sampling can be run under no_grad() to speed it up, since no gradients are needed.
The function sample starts from pure Gaussian noise at \(t=T\) and then denoises it step by step back to a signal, in a generative way.
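The reverse loop can be sketched in NumPy as follows. This is not the `sample` function from this post: `p_sample_loop` is a hypothetical name, `model` is any callable mapping `(x, t)` to predicted noise, `t` is 0-based, and the variance is fixed to the posterior variance, which is one of the fixed choices discussed in Ho et al. 2020.

```python
import numpy as np

def p_sample_loop(model, shape, betas, rng):
    """Start from Gaussian noise and denoise step by step down to step 0."""
    alphas = 1.0 - betas
    alpha_bars = np.cumprod(alphas)
    alpha_bars_prev = np.append(1.0, alpha_bars[:-1])
    # Fixed posterior variance for each step (zero at the final step).
    posterior_var = betas * (1.0 - alpha_bars_prev) / (1.0 - alpha_bars)

    x = rng.standard_normal(shape)               # pure noise at the last step
    for i in reversed(range(len(betas))):
        t = np.full(shape[0], i)                 # same step for the whole batch
        eps = model(x, t)                        # predicted noise
        # Mean of the reverse step, computed from the predicted noise.
        mean = (x - betas[i] / np.sqrt(1.0 - alpha_bars[i]) * eps) / np.sqrt(alphas[i])
        if i > 0:
            x = mean + np.sqrt(posterior_var[i]) * rng.standard_normal(shape)
        else:
            x = mean                             # no noise is added at the end
    return x
```

In a PyTorch implementation the same loop would typically be wrapped in a no-gradient context, since only forward passes are needed.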
6. Training
The DDPM model can also be quite memory-hungry because of IMAGE_SIZE and the attention mechanism in the UNet (on top of the ResNet parameters).
batch size > 24: memory explodes for 1 GPU
current speed = 240s / epoch @ batch_size = 24 on 1 GPU
current speed = 110s / epoch @ batch_size = 64 on 4 GPUs
current speed = 110s / epoch @ batch_size = 96 on 4 GPUs
current speed = 116s / epoch @ batch_size = 192 on 8 GPUs
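A back-of-the-envelope calculation shows why attention is sensitive to image size. The helper below is hypothetical and counts only the attention score matrix for regular (quadratic) attention in 32-bit floats; whether the UNet in this post applies full attention at full resolution is not stated, so the numbers are purely illustrative.

```python
def attention_matrix_bytes(height, width, batch_size=1, heads=1, bytes_per_value=4):
    """Bytes needed for the (tokens x tokens) score matrix of full attention."""
    tokens = height * width                      # one token per pixel
    return batch_size * heads * tokens * tokens * bytes_per_value
```

At 128 x 128 there are 16384 tokens, so a single score matrix takes 2**30 bytes (1 GiB) per head and per sample, whereas at 16 x 16 it takes only 256 KiB. This is the motivation for using linear attention or restricting full attention to the coarsest resolution.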
7. Generating
After the model is trained, it is interesting to take a look at some of the generated images:
8. Notes
The generative power of DDPM is really strong, and the model formulation is cool, with physics intuition behind it.
The time embedding steers the UNet to act like a lowpass filter at the later time points (\(t \sim T\)), generating contours or global shapes from the noise. At the earlier time points (\(t \ll T\)), it acts like a highpass filter, removing high-frequency noise and producing high-resolution images that approximate the image distribution of the training data set. One can attenuate the denoising power, that is, make it less stochastic, by averaging 10 or more denoised images together at one step. This makes the image less hallucinated but also blurrier, since it averages out high-resolution signals that correspond to different end results. It is also the tradeoff between FID and IS in generative models.
Calvin Luo, Understanding Diffusion Models: A Unified Perspective, https://arxiv.org/abs/2208.11970, 2022
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin, Attention Is All You Need, https://arxiv.org/abs/1706.03762, 2017