Walking through score-based diffusion with SDE
0. Diffusion Model using Score Matching and SDE
Score matching is a technique used in score-based generative models, which regress on the score, \(\nabla_x\log p(x)\), instead of modeling the data distribution \(p(x)\) directly. The score is the gradient of the log data density, so if we can access this gradient at every point of the data space, we can follow it to reach points with large probability density. This can be done by gradient ascent or by Langevin dynamics, which is similar to stochastic gradient ascent.
\[x_{t+1} = x_t + \epsilon\nabla_x\log p(x_t) + \sqrt{2\epsilon} z\]where \(z \sim \mathcal{N}(0, I)\). The goal of training a score-based generative model is to approximate the score with a neural network, \(s_\theta(x) \approx \nabla_x\log p(x)\), everywhere on \(x\). If we have such a model, we can use the above Langevin dynamics for sampling, just replacing the score with the model estimate:
\[x_{t+1} = x_t + \epsilon s_\theta(x_t) + \sqrt{2\epsilon} z\]The score-matching objective (or loss function) is the Fisher divergence between \(s_\theta(x)\) and \(\nabla_x\log p(x)\), weighted by \(p(x)\):
\[\mathbb{E}_{p(x)}||s_\theta(x) - \nabla_x\log p(x)||_2^2 = \int p(x)||s_\theta(x) - \nabla_x\log p(x)||_2^2 dx\]The main issue with this simple score matching is that in regions where \(p(x) \approx 0\), the score estimate will be inaccurate due to the near-zero weighting there. So if we start a sample in such a region and try to follow the score (or gradient), we end up doing a random walk and never get closer to the modes of the data distribution. This is more severe in the high-dimensional case, where the data distribution looks like a bunch of spikes (delta functions) in the data space.
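As a toy illustration of Langevin sampling with a known score (a sketch; the target, step size, and iteration count are arbitrary choices): the standard normal has score \(\nabla_x\log p(x) = -x\), and iterating the update above drives arbitrary initial points toward samples from it.

```python
import numpy as np

# Toy Langevin dynamics with a known analytic score.
# Target: the standard normal, whose score is grad_x log p(x) = -x.
rng = np.random.default_rng(0)

def score(x):
    return -x  # exact score of N(0, 1)

eps = 0.01                          # Langevin step size
x = rng.uniform(-5, 5, size=10000)  # arbitrary initialization
for _ in range(2000):
    z = rng.standard_normal(x.shape)
    x = x + eps * score(x) + np.sqrt(2 * eps) * z

# After mixing, the samples follow the target distribution.
print(x.mean(), x.std())  # close to 0 and 1
```

Note the small bias from the finite step size: the stationary distribution of the discretized chain is slightly wider than the target, which is one motivation for the corrector step-size heuristic later in this post.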
The trick to solve this is to introduce noise into the data distribution, widening it to cover as much of the space as possible. The added noise yields a non-zero perturbation kernel \(p_{\sigma_i}(\tilde{x}\mid x)\), where \(\tilde{x}\) is the noised data controlled by variance \(\sigma_i^2\), and covers a larger portion of the data space for a more accurate score estimate.
Score Matching with Langevin Dynamics (SMLD)
One way is to add noise at different increasing variances \(\{\sigma_0^2, \sigma_1^2, \ldots, \sigma_N^2\}\), as proposed in the Noise Conditional Score Network (NCSN). The network now needs to be noise-conditioned, \(s_\theta(x, \sigma)\), in order to match the scores at different variances. The NCSN objective is
\[\sum_{i=1}^N \sigma_i^2 \mathbb{E}_{p(x)}\mathbb{E}_{p_{\sigma_i}(\tilde{x} \mid x)} || s(\tilde{x}, \sigma_i) - \nabla_{\tilde{x}}\log p_{\sigma_i}(\tilde{x}\mid x) ||_2^2\]
Denoising Diffusion Probabilistic Model (DDPM)
Another way is to add noise via a discrete Markov chain, \(p(x_i \mid x_{i-1}) = \mathcal{N}(x_i; \sqrt{1-\beta_i}x_{i-1}, \beta_i I)\). Notice that this Markov chain attenuates the signal and adds noise, instead of just adding noise to overwhelm the signal as done in SMLD. The DDPM objective is now
\[\sum_{i=1}^N (1-\alpha_i) \mathbb{E}_{p(x)}\mathbb{E}_{p_{\alpha_i}(\tilde{x}\mid x)} || s(\tilde{x}, i) - \nabla_{\tilde{x}}\log p_{\alpha_i}(\tilde{x}\mid x) ||_2^2\]Notice the similarity between the objectives of SMLD and DDPM.
1. General Continuous-Time Diffusion
The SMLD and DDPM formulations were discrete-time, with a pre-defined variance schedule: \(\sigma_i^2\) for SMLD and \(\beta_i\) for DDPM. The general case of adding noise to data is a stochastic differential equation (SDE), which adds Gaussian noise (Brownian motion) in continuous time.
\[dx = f(x, t)dt + g(t)dw\]where \(f(x, t)\) is the drift term that depends on the current \(x\) and time \(t\), \(g(t)\) is the diffusion term controlling how much noise is added at time \(t\), and \(dw\) is the infinitesimal Brownian motion. With this predefined stochastic process, the time reversal of the SDE has a closed form:
\[dx = [f(x, t) - g^2(t) \nabla_x\log p_t(x)]dt + g(t)dw\]Note that in the time reversal \(dt < 0\), so we are actually reversing the drift and following along the gradient (in the same direction as the score). We now just need a time-conditioned score model \(s_\theta(x, t)\) that matches \(\nabla_x\log p_t(x)\) everywhere, at every time.
For SDEs with affine drift and diffusion (as is the case here), the distribution of \(x(t)\) conditioned on \(x(0)\) is guaranteed to be normal, with mean \(m(t)\) and variance \(v(t)\). So we can write down the perturbation kernel (or transition kernel) from data \(x(0)\) to noised data \(x(t)\):
\[p_{0t}\left(x(t) \mid x(0) \right) = \mathcal{N}\left(x(t); m(t), v(t)I \right)\]We write this down now for the derivation of \(m(t)\) and \(v(t)\) later. Note that the data is actually multi-dimensional, but each dimension is treated as independent, so we can treat everything in scalar form and write \(I\) for the variance.
The denoising score matching objective is now
\[\mathcal{L}_{dsm} = \mathbb{E}_t \lambda(t) \mathbb{E}_{x(0)} \mathbb{E}_{x(t)\mid x(0)} || s_\theta(x(t), t) - \nabla_{x(t)}\log p_{0t}\left(x(t) \mid x(0) \right)||_2^2\]\(\lambda(t)\) is a positive time-dependent weighting, taken proportional to the variance \(v(t)\) as done in SMLD and DDPM. In the maximum likelihood training proposed later, \(\lambda(t)\) is instead proportional to the squared diffusion coefficient \(g^2(t)\).
The gradient term can be easily calculated in exact form since \(p_{0t}\) is a gaussian:
\[\nabla_{x(t)}\log p_{0t}\left(x(t) \mid x(0) \right) = \nabla_{x(t)}\log \mathcal{N}\left(x(t); m(t), v(t)I \right) = -\frac{x(t) - m(t)}{v(t)}\]How do we compute \(x(t)\) in practice? Again, using \(p_{0t}\):
\[x(t) = m(t) + \sqrt{v(t)}z; \quad z\sim \mathcal{N}(0, I)\]Plugging \(x(t) - m(t) = \sqrt{v(t)}z\) into the gradient above, \(\mathcal{L}_{dsm}\) becomes:
\[\mathcal{L}_{dsm} = \mathbb{E}_t \lambda(t) \mathbb{E}_{x(0)} \mathbb{E}_{x(t)\mid x(0)} \left|\left| s_\theta(x(t), t) + \frac{z}{\sqrt{v(t)}}\right|\right|_2^2\]The critical components are \(m(t)\) and \(v(t)\). Once we have them, we can compute the loss and train the score model.
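The Gaussian-score identity used in this derivation is easy to check numerically; here is a one-dimensional finite-difference sanity check with arbitrary illustrative values of \(m\) and \(v\):

```python
import numpy as np

# Finite-difference check of the Gaussian score identity used above:
# d/dx log N(x; m, v) = -(x - m) / v  (scalar case; m, v are arbitrary).
m, var = 1.3, 0.7

def log_pdf(x):
    return -0.5 * np.log(2 * np.pi * var) - (x - m) ** 2 / (2 * var)

x, h = 2.1, 1e-5
fd = (log_pdf(x + h) - log_pdf(x - h)) / (2 * h)  # central difference
print(fd, -(x - m) / var)  # the two values should agree
```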
SDE for SMLD (VE-SDE)
The discrete-time Markov chain for SMLD is \(x_i = x_{i-1} + \sqrt{\sigma_i^2 - \sigma_{i-1}^2} z_{i-1}\). In the continuous-time generalization:
\[x(t+\Delta t) = x(t) + \sqrt{\sigma^2(t+\Delta t) - \sigma^2(t)}z(t) \approx x(t) + \sqrt{\frac{d\sigma^2(t)}{dt}\Delta t}z(t)\]We identify \(\sqrt{\Delta t}\,z(t)\) with \(dw\). The continuous-time limit of SMLD is then
\[dx = \sqrt{\frac{d\sigma^2(t)}{dt}}dw\]which is also called variance-exploding SDE (VE-SDE).
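A quick way to see the connection: the discrete SMLD chain adds independent Gaussian increments, so the variances telescope to \(\sigma_N^2 - \sigma_0^2\), matching \(v(t) = \sigma^2(t)\) of the VE-SDE below. A small simulation (with an illustrative geometric noise schedule, not the NCSN one):

```python
import numpy as np

# The SMLD chain x_i = x_{i-1} + sqrt(sigma_i^2 - sigma_{i-1}^2) z telescopes:
# independent Gaussian increments add in variance, so
# Var[x_N] - Var[x_0] = sigma_N^2 - sigma_0^2.
rng = np.random.default_rng(0)
sigmas = np.geomspace(0.01, 50.0, 100)  # illustrative noise schedule

n = 200000
x = np.zeros(n)                         # point-mass data at 0
for i in range(1, len(sigmas)):
    x += np.sqrt(sigmas[i] ** 2 - sigmas[i - 1] ** 2) * rng.standard_normal(n)

print(x.std(), np.sqrt(sigmas[-1] ** 2 - sigmas[0] ** 2))  # should agree
```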
SDE for DDPM (VP-SDE)
The discrete-time Markov chain for DDPM is \(x_i = \sqrt{1-\beta_i}x_{i-1} + \sqrt{\beta_i}z_{i-1}\). In the continuous-time generalization:
\[x(t+\Delta t) = \sqrt{1-\beta(t+\Delta t)\Delta t}x(t) + \sqrt{\beta(t+\Delta t)\Delta t}z(t) \approx x(t) - \frac{1}{2}\beta(t+\Delta t)\Delta t x(t) + \sqrt{\beta(t)\Delta t}z(t)\]We identify \(\sqrt{\Delta t}\,z(t)\) with \(dw\). The continuous-time limit of DDPM is then
\[dx = -\frac{1}{2}\beta(t)x dt + \sqrt{\beta(t)}dw\]which is also called variance-preserving SDE (VP-SDE).
2. Marginal mean and variance from SDE
Now we will use the SDEs to derive the mean \(m(t)\) and variance \(v(t)\) of the perturbation kernel \(p_{0t}\).
Given an SDE with affine, continuous functions $f$ and $g$,
\[dx = f(x, t)dt + g(x, t)dw\]Let
\[f(x, t) = A(t)x(t) + a(t)\] \[g(x, t) = B(t)x(t) + b(t)\]and
\[\mathbb{E}[x(t)] = m(t)\] \[\mathrm{Var}[x(t)] = v(t)\]Then \(m(t)\) and \(v(t)\) satisfy the following ODEs with initial conditions:
\[m'(t) = A(t)m(t) + a(t); \quad m(0) = m_0\] \[v'(t) = 2A(t)v(t) + b^2(t); \quad v(0) = v_0\](the variance ODE assumes \(B(t) = 0\), which holds for all the SDEs considered here)
3. Solving variable coefficient ODEs
The above ODEs are variable-coefficient linear ODEs and have a general solution. The general ODE \(y'(t) = a(t)y(t) + b(t)\) has solution
\[y(t) = Ce^{A(t)} + e^{A(t)}\int e^{-A(t)}b(t)dt\]where
\[A(t) = \int a(t)dt\]
4. Deriving perturbation kernels from SDE
The perturbation kernel \(p_{0t}\) for the SDEs here is Gaussian. We are after the mean \(m(t)\) and variance \(v(t)\) of the distribution of \(x(t)\) given the initial data point \(x(0)\). Note that \(x(0)\) is our data point, so we treat its distribution as a delta function, or a very tight Gaussian with mean \(x(0)\) and variance \(v(0) = 0\).
VE-SDE
\[dx = \sqrt{\frac{d\sigma^2(t)}{dt}}dw\]Using the above notation, \(A(t) = a(t) = B(t) = 0\) and \(b(t) = \sqrt{d\sigma^2(t)/dt}\).
\[m'(t) = 0 \Rightarrow m(t) = c = x(0)\] \[v'(t) = b^2(t) = \frac{d\sigma^2(t)}{dt} \Rightarrow v(t) = \sigma^2(t) + c = \sigma^2(t)\](taking \(\sigma^2(0) = 0\) so that \(v(0) = 0\)). Therefore,
\[p_{0t}(x(t) \mid x(0)) = \mathcal{N}\left(x(t); x(0), \sigma^2(t)I \right)\]
VP-SDE
\[dx = -\frac{1}{2}\beta(t)x dt + \sqrt{\beta(t)}dw\]\(a(t) = B(t) = 0\), \(A(t) = -\beta(t)/2\) and \(b(t) = \sqrt{\beta(t)}\). Plug in the ODEs:
\[m'(t) = -\frac{1}{2} \beta(t)m(t) \Rightarrow m(t) = Ce^{\int_0^t-\frac{1}{2}\beta(s)ds} = x(0)e^{\int_0^t-\frac{1}{2}\beta(s)ds}\] \[v'(t) = -\beta(t)v(t)+\beta(t)\] \[v(t) = Ce^{-\int\beta(t)dt} + e^{-\int\beta(t)dt}\int e^{\int\beta(t)dt}\beta(t)dt = Ce^{-\int\beta(t)dt} + 1; v(0)=0 \Rightarrow C=-1\] \[v(t) = 1 - e^{-\int\beta(t)dt}\]Therefore,
\[p_{0t}(x(t) \mid x(0)) = \mathcal{N}\left(x(t); x(0)e^{\int_0^t-\frac{1}{2}\beta(s)ds}, (1 - e^{-\int_0^t\beta(s)ds})I \right)\]
Sub-VP SDE
In the score-based SDE paper, the authors introduce another SDE called the sub-VP SDE:
\[dx = -\frac{1}{2}\beta(t)x\,dt + \sqrt{\beta(t)\left(1-e^{-2\int_0^t\beta(s)ds} \right)}dw\]Only \(b(t)\) changes, so \(m(t)\) remains the same as for the VP-SDE.
\[v'(t) = -\beta(t)v(t) + \beta(t)\left(1-e^{-2\int_0^t\beta(s)ds} \right)\] \[v(t) = Ce^{-\int\beta(t)dt} + e^{-\int\beta(t)dt}\int e^{\int\beta(t)dt}\beta(t)dt - e^{-\int\beta(t)dt}\int e^{-\int\beta(t)dt}\beta(t)dt\]The first two terms are the same as before:
\[v(t) =Ce^{-\int_0^t\beta(s)ds} + 1 + e^{-\int_0^t\beta(s)ds}e^{-\int_0^t\beta(s)ds} = Ce^{-\int_0^t\beta(s)ds} + 1 + e^{-2\int_0^t\beta(s)ds}\]With $v(0) = 0 \Rightarrow C = -2$
\[v(t) = -2e^{-\int_0^t\beta(s)ds} + 1 + e^{-2\int_0^t\beta(s)ds} = \left(1-e^{-\int_0^t\beta(s)ds}\right)^2\]Therefore,
\[p_{0t}(x(t) \mid x(0)) = \mathcal{N}\left(x(t); x(0)e^{\int_0^t-\frac{1}{2}\beta(s)ds}, \left(1 - e^{-\int_0^t\beta(s)ds}\right)^2I \right)\]Note that the variance of the sub-VP SDE is bounded above by the variance of the VP-SDE:
\[\forall t > 0; \left(1 - e^{-\int_0^t\beta(s)ds}\right)^2 \le 1 - e^{-\int_0^t\beta(s)ds}\]These 3 SDEs appear in Eq.(29) of [1] and Table 1 of [2]. In the following code, we will use VE-SDE as an example.
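Before moving on to the code, the VP-SDE kernel and the sub-VP bound can be sanity-checked numerically. The sketch below uses a constant \(\beta(t)\) and other illustrative constants of my own choosing:

```python
import numpy as np

# Two quick numerical sanity checks (illustrative constants):
# (1) simulate the VP-SDE with constant beta and compare the empirical
#     mean/variance of x(T) with the derived perturbation kernel;
# (2) confirm the sub-VP variance never exceeds the VP variance.
rng = np.random.default_rng(0)

# (1) VP-SDE: dx = -0.5 * beta * x dt + sqrt(beta) dw, starting at x(0) = 1
beta, T, n_steps, n_paths = 2.0, 1.0, 1000, 20000
dt = T / n_steps
x = np.ones(n_paths)
for _ in range(n_steps):
    dw = np.sqrt(dt) * rng.standard_normal(n_paths)
    x = x - 0.5 * beta * x * dt + np.sqrt(beta) * dw

m_T = np.exp(-0.5 * beta * T)  # kernel mean for x(0) = 1
v_T = 1 - np.exp(-beta * T)    # kernel variance
print(x.mean(), m_T)           # should agree
print(x.var(), v_T)            # should agree

# (2) sub-VP bound: (1 - e^{-B})^2 <= 1 - e^{-B} for B >= 0
B = np.linspace(0, 10, 1001)
print(np.all((1 - np.exp(-B)) ** 2 <= 1 - np.exp(-B)))
```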
5. Notebook
The original notebook is provided by the author: Google Colab.
Most of the cells contain similar code to the DDPM, especially the time-conditional UNet model. I’ll just pick the cells I did not grasp during the first pass. If I have time after work, I might implement all the training and sampling for the above 3 SDEs.
Cell #5
Cell #5 sets up the VE-SDE scheduling for the diffusion coefficient \(g(t)\) and the standard deviation \(\sqrt{v(t)}\). Note that in the VE-SDE \(f(x, t) = 0\), so the mean stays at the data point: \(m(t) = x(0)\).
We set up the stochastic differential equation (VE-SDE) as follows:
\[dx = \sigma^t dw\]where \(\sigma > 1\) is a constant hyperparameter controlling how fast the noise grows, and \(dw\) is the Wiener process.
This setup is not unique; the marginal standard deviation can be customized. The marginal variance is
\[v(t) = \int_0^t g(s)^2ds\]One can try different types of SDEs. One can verify that if \(g(t) = \sigma^t\), then
\[v(t) = \frac{\sigma^{2t} - 1}{2\log{\sigma}}\]
Cell #6
Cell #6 sets up the loss function for the training objective. Recall that the score function \(s_\theta(x, t)\) has to match \(\nabla_{x(t)}\log p_t(x(t))\) at every time, for every training data point \(x(0)\). The DSM loss is then a regression loss.
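Here is a numpy sketch of the VE-SDE version of that loss. The notebook itself uses PyTorch; the names `marginal_prob_std` and `dsm_loss` and all constants are my own illustrative choices:

```python
import numpy as np

# A numpy sketch of the VE-SDE denoising score-matching loss.
# With g(t) = sigma**t, the marginal variance is (sigma^{2t} - 1) / (2 log sigma).
rng = np.random.default_rng(0)
sigma = 25.0

def marginal_prob_std(t):
    return np.sqrt((sigma ** (2 * t) - 1) / (2 * np.log(sigma)))

def dsm_loss(score_fn, x0):
    # Monte Carlo estimate of E_t E_z || sqrt(v(t)) s(x(t), t) + z ||^2
    t = rng.uniform(1e-5, 1.0, size=x0.shape)
    z = rng.standard_normal(x0.shape)
    std = marginal_prob_std(t)
    xt = x0 + std * z                  # VE-SDE: m(t) = x(0)
    return np.mean((std * score_fn(xt, t) + z) ** 2)

# Check 1: the closed-form variance matches numerical integration of g(s)^2.
s = np.linspace(0, 0.5, 10001)
mid = 0.5 * (s[:-1] + s[1:])
v_num = np.sum(sigma ** (2 * mid) * np.diff(s))  # midpoint rule
print(v_num, marginal_prob_std(0.5) ** 2)        # should agree

# Check 2: for point-mass data at 0 the exact score is -x / v(t),
# and the lambda(t) = v(t) weighted DSM loss vanishes.
x0 = np.zeros(4096)
exact_score = lambda x, t: -x / marginal_prob_std(t) ** 2
print(dsm_loss(exact_score, x0))  # ~ 0
```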
Cell #8
Cell #8 prepares sampling with 3 different methods: Euler-Maruyama, Predictor-Corrector, and probability flow ODE.
Euler-Maruyama
Recall that an SDE of the form
\[dx = f(x, t)dt + g(t)dw\]has the reverse-time SDE:
\[dx = \left[f(x, t) - g^2(t)\nabla_{x(t)}\log p_t(x(t)) \right]dt + g(t)dw\] \[dx = -\sigma^{2t} s_\theta(x, t) dt + \sigma^t dw; \quad dt < 0\] \[x_{t-\Delta t} = x_t + \sigma^{2t} s_\theta(x_t, t)\Delta t + \sigma^t\sqrt{\Delta t} z_t\]where \(z_t \sim \mathcal{N}(0, I)\).
Euler-Maruyama discretizes \(dt\) into finite steps \(\Delta t\).
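The sampler can be sketched with an analytic score in place of a trained network: for point-mass data at 0 under the VE-SDE, \(p_t = \mathcal{N}(0, v(t))\) and the exact score is \(s(x, t) = -x/v(t)\). All constants below are illustrative:

```python
import numpy as np

# Euler-Maruyama on the reverse VE-SDE, with an analytic score standing
# in for s_theta: for point-mass data at 0, s(x, t) = -x / v(t).
rng = np.random.default_rng(0)
sigma = 25.0

def v(t):
    return (sigma ** (2 * t) - 1) / (2 * np.log(sigma))

T, t_end, n_steps, n_samples = 1.0, 0.05, 950, 5000
dt = (T - t_end) / n_steps

x = np.sqrt(v(T)) * rng.standard_normal(n_samples)  # sample the prior p_T
t = T
for _ in range(n_steps):
    z = rng.standard_normal(n_samples)
    g2 = sigma ** (2 * t)      # g^2(t)
    score = -x / v(t)          # analytic stand-in for s_theta(x, t)
    x = x + g2 * score * dt + np.sqrt(g2 * dt) * z
    t -= dt

# The reverse SDE preserves the marginals, so Var[x(t_end)] should be v(t_end).
print(x.var(), v(t_end))
```

Stopping at a small \(t_{\mathrm{end}} > 0\) rather than exactly 0 avoids the stiffness of the score as \(v(t) \to 0\), mirroring the small cutoff used in practice.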
Predictor-Corrector
Recall that given the score function \(s_\theta(x, t)\), we can sample via Langevin dynamics:
\[x_{i+1} = x_i + \epsilon \nabla_{x_i} \log p(x_i) + \sqrt{2\epsilon} z_i\]The PC sampling combines an ODE/SDE solver step (predictor) with $N$ steps of local Langevin dynamics (corrector).
- Predictor: use the ODE/SDE solver to step to the next time step $x(t-dt)$ using $s(x(t), t)$
- Corrector: use the score $s(x(t-dt), t-dt)$ and Langevin dynamics to correct $x(t-dt)$ for $N$ steps
The step size \(\epsilon\) is determined from a predefined signal-to-noise ratio \(r\):
\[\epsilon = 2 \left(r \frac{\|z\|_2}{\|\nabla_{x} \log p(x)\|_2}\right)^2\]which depends on the norm of the score. The idea is similar to adaptive step sizes in stochastic gradient descent: take a smaller (more careful) step when the score/gradient (slope) is steep, to avoid overshooting.
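A minimal sketch of the corrector with this adaptive step size, again using an analytic score. The target is \(\mathcal{N}(0, I)\) in \(d\) dimensions (score \(-x\)); the ratio \(r\) and all other constants are illustrative choices, not taken from the notebook:

```python
import numpy as np

# Corrector of PC sampling with the adaptive Langevin step size.
# Target: N(0, I) in d dimensions, whose score is simply -x.
rng = np.random.default_rng(0)
n, d, r = 2000, 64, 0.16

x = 2.0 * rng.standard_normal((n, d))     # start at the wrong scale
for _ in range(500):
    g = -x                                # analytic score of N(0, I)
    z = rng.standard_normal((n, d))
    g_norm = np.mean(np.linalg.norm(g, axis=1))
    z_norm = np.mean(np.linalg.norm(z, axis=1))
    eps = 2 * (r * z_norm / g_norm) ** 2  # adaptive Langevin step size
    x = x + eps * g + np.sqrt(2 * eps) * z

# The corrector pulls the samples onto the target distribution.
print(x.std())  # close to 1
```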
Probability flow ODE
For the probability flow ODE, the reverse process is:
\[dx = \left[f(x, t) - \frac{1}{2}g^2(t)\nabla_{x(t)}\log p_t(x(t)) \right]dt\] \[dx = -\frac{1}{2}\sigma^{2t} s_\theta(x, t) dt\] \[\frac{dx}{dt} = -\frac{1}{2}\sigma^{2t} s_\theta(x, t)\]Now we need to integrate from \(t=T\) to \(t=0\)
\[x(t) = x(T) + \int_T^t\frac{dx}{du} du = x(T) + \int_T^t -\frac{1}{2}\sigma^{2u} s_\theta(x(u), u)\, du\] \[x(0) = x(T) + \int_T^0 -\frac{1}{2}\sigma^{2u} s_\theta(x(u), u)\, du\]The above can be solved using an existing ODE solver, such as Runge-Kutta.
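With the analytic point-mass score \(s(x, t) = -x/v(t)\), this integration can be sketched with a hand-rolled RK4 integrator standing in for a black-box solver (constants are illustrative):

```python
import numpy as np

# Probability flow ODE for the VE-SDE with the analytic point-mass score
# s(x, t) = -x / v(t), integrated from t = T down to t_end with RK4.
sigma = 25.0

def v(t):
    return (sigma ** (2 * t) - 1) / (2 * np.log(sigma))

def f(x, t):
    # dx/dt = -(1/2) g^2(t) s(x, t) with s(x, t) = -x / v(t)
    return 0.5 * sigma ** (2 * t) * x / v(t)

T, t_end, n_steps = 1.0, 0.05, 2000
dt = (t_end - T) / n_steps    # negative: integrating backward in time

x, t = 3.0, T                 # one scalar sample drawn from the prior
for _ in range(n_steps):
    k1 = f(x, t)
    k2 = f(x + 0.5 * dt * k1, t + 0.5 * dt)
    k3 = f(x + 0.5 * dt * k2, t + 0.5 * dt)
    k4 = f(x + dt * k3, t + dt)
    x += dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    t += dt

# Since v' = g^2, this ODE solves exactly to x(t) = x(T) * sqrt(v(t) / v(T)).
print(x, 3.0 * np.sqrt(v(t_end) / v(T)))
```

Unlike the stochastic samplers, this trajectory is fully deterministic: the same prior sample always maps to the same output, which is what makes likelihood computation possible in the next section.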
The results of the above sampling methods are shown below. From left to right are: Euler-Maruyama, Predictor-Corrector and Probability flow ODE.
6. Likelihood Computation
We can compute the likelihood \(\log p_0(x(0))\) using
\[\log p_0(x(0)) = \log p_T(x(T)) + \int_0^T \nabla \cdot \left[ -\frac{1}{2}\sigma^{2t} s_\theta(x, t) \right] dt\]We can use the above to report the likelihood in bits per dimension (lower is better).
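This formula can be verified in one dimension with an analytic score (an illustrative check, not the notebook's code): for data \(p_0 = \mathcal{N}(0, 1)\) under the VE-SDE, \(p_t = \mathcal{N}(0, 1 + v(t))\) and \(s(x, t) = -x/(1 + v(t))\), so the computed \(\log p_0(x(0))\) should recover the exact unit-Gaussian log-density.

```python
import numpy as np

# Instantaneous change of variables in 1D with an analytic score.
# Data: p_0 = N(0, 1); then p_t = N(0, 1 + v(t)) and s(x, t) = -x / (1 + v(t)).
sigma = 25.0

def v(t):
    return (sigma ** (2 * t) - 1) / (2 * np.log(sigma))

x0, T, n_steps = 0.7, 1.0, 100000
dt = T / n_steps

x, t, div_int = x0, 0.0, 0.0
for _ in range(n_steps):
    g2 = sigma ** (2 * t)
    drift = 0.5 * g2 * x / (1 + v(t))      # flow drift -g^2 s / 2
    div_int += 0.5 * g2 / (1 + v(t)) * dt  # its divergence d(drift)/dx
    x += drift * dt                        # follow the flow forward to T
    t += dt

log_pT = -0.5 * np.log(2 * np.pi * (1 + v(T))) - x ** 2 / (2 * (1 + v(T)))
log_p0 = log_pT + div_int
exact = -0.5 * np.log(2 * np.pi) - x0 ** 2 / 2
print(log_p0, exact)  # the two values should agree
```

In high dimensions the divergence is not available in closed form, which is why the notebook estimates it with the Skilling-Hutchinson trace estimator; the 1D case above sidesteps that.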
7. References
- [1] Song et al., Score-Based Generative Modeling through Stochastic Differential Equations (link)
- [2] Song et al., Maximum Likelihood Training of Score-Based Diffusion Models (link)
- Original PyTorch implementation: Google Colab
- Blog post on score-based SDE (link)