作者:京东零售 刘岩
扩散模型讲解
前沿
人工智能生成内容(AI Generated Content,AIGC)近年来成为了非常前沿的一个研究方向,生成模型目前有四个流派,分别是生成对抗网络(Generative Adversarial Models,GAN),变分自编码器(Variance Auto-Encoder,VAE),标准化流模型(Normalization Flow, NF)以及这里要介绍的扩散模型(Diffusion Models,DM)。扩散模型是受到热力学中的一个分支,它的思想来源是非平衡热力学(Non-equilibrium thermodynamics)。扩散模型的算法理论基础是通过变分推断(Variational Inference)训练参数化的马尔可夫链(Markov Chain),它在许多任务上展现了超过GAN等其它生成模型的效果,例如最近非常火热的OpenAI的DALL-E 2,Stability.ai的Stable Diffusion等。这些效果惊艳的模型扩散模型的理论基础便是我们这里要介绍的提出扩散模型的文章[1]和非常重要的DDPM[2],扩散模型的实现并不复杂,但其背后的数学原理却非常丰富。在这里我会介绍这些重要的数学原理,但省去了这些公式的推导计算,如果你对这些推导感兴趣,可以学习参考文献[4,5,11]的相关内容。我在这里主要以一个相对简单的角度来讲解扩散模型,帮助你快速入门这个非常重要的生成算法。
1. 背景知识: 生成模型
生成模型的本质是通过一个已知的概率模型来拟合所给的数据样本,也就是说,我们往往需要通过模型得到一个带参数的分布。即如果训练数据的分布是\(p_\text{data}(x)\),生成样本的分布为\(p_\text{model}(x)\),我们希望得到的分布和训练数据的分布尽可能相似。目前生成模型主要有图1所示的四类。其中GAN的原理是通过判别器和生成器的互相博弈来让生成器生成足以以假乱真的图像。VAE的原理是通过一个编码器将输入图像编码成特征向量,它用来学习高斯分布的均值和方差,而解码器则可以将特征向量转化为生成图像,它侧重于学习生成能力。流模型是从一个简单的分布开始,通过一系列可逆的转换函数将分布转化成目标分布。扩散模型先通过正向过程将噪声逐渐加入到数据中,然后通过反向过程预测每一步加入的噪声,通过将噪声去掉的方式逐渐还原得到无噪声的图像,扩散模型本质上是一个马尔可夫架构,只是其中训练过程用到了深度学习的BP,但它更属于数学层面的创新。这也就是为什么很多计算机的同学看扩散模型相关的论文会如此费力。
图1:生成模型的四种类型 [4]
扩散模型中最重要的思想根基是马尔可夫链,它的一个关键性质是平稳性。即如果一个概率随时间变化,那么再马尔可夫链的作用下,它会趋向于某种平稳分布,时间越长,分布越平稳。如图2所示,当你向一滴水中滴入一滴颜料时,无论你滴在什么位置,只要时间足够长,最终颜料都会均匀的分布在水溶液中。这也就是扩散模型的前向过程。
图2:颜料分子在水溶液中的扩散过程
如果我们能够在扩散的过程颜料分子的位置、移动速度、方向等移动属性。那么也可以根据正向过程的保存的移动属性从一杯被溶解了颜料的水中反推颜料的滴入位置。这边是扩散模型的反向过程。记录移动属性的快照便是我们要训练的模型。
2. 扩散模型
在这一部分我们将集中介绍扩散模型的数学原理以及推导的几个重要性质,因为推导过程涉及大量的数学知识但是对理解扩散模型本身思想并无太大帮助,所以这里我会省去推导的过程而直接给出结论。但是我也会给出推导过程的出处,对其中的推导过程比较感兴趣的请自行查看。
2.1 计算原理
扩散模型简单的讲就是通过神经网络学习从纯噪声数据逐渐对数据进行去噪的过程,它包含两个步骤,如图3:
- 固定的前向过程\(q\):在这一步我们逐渐将高斯噪声添加到图像中,直到得到一个纯噪声的图像;
- 可学习的反向去噪过程\(p_\theta\):在这一步我们从纯噪声图像中逐渐对其进行去噪,直到得到真实的图像。
图3:DDPM的前向加噪和后向去噪过程
更具体些,对于一个\(T\)步的扩散模型,每一步的索引为\(t\)。在前向过程中,我们从一个真实图像\(\boldsymbol x_0\)开始,在每一步我们随机生成一些高斯噪声,然后将生成的噪声逐步加入到输入图像中,当\(T\)足够大时,我们得到的加噪后的图像便接近一个高斯噪声图像,例如DDPM中 \(T=1000\)。在后向过程中,我们从噪声图像\(\boldsymbol x_T\)开始(训练时是真实图像加噪的结果,采样时是随机噪声),通过一个神经网络学习\(\boldsymbol x_{t-1}\)到\(\boldsymbol x_t\)添加的噪声,然后通过逐渐去噪的方式得到最后要生成的图像。
2.1.1 前向过程
前向过程即扩散过程指的是向数据中逐渐添加高斯噪声直到数据完全变成噪声的过程。假设\(q(\boldsymbol x_0)\)是真实图像的分布,我们可以通过从训练集的真实图像中随机采样一张图像,表示为\(\boldsymbol x \sim q(\boldsymbol x_0)\)。那么前向过程\(q(\boldsymbol x_t | \boldsymbol x_{t-1})\)指的是在前向的每一步通过向图像\(\boldsymbol x_{t-1}\)中添加高斯噪声得到\(\boldsymbol x_t\)。我们知道,一个高斯分布由均值\(\mu\)和方差\(\sigma^2\)定义。那么在每一步向图像中添加高斯噪声的过程表示为式(1),它是一个均值\(\mu_t = \sqrt{1-\beta_t} \boldsymbol x_{t-1}\),方差\(\sigma^2_t = \beta_t\)的高斯分布。
\[q(\boldsymbol x_t | \boldsymbol x_{t-1}) = \mathcal N(\boldsymbol x_t | \sqrt{1 – \beta_t} \boldsymbol x_{t-1}, \beta_t \mathbf I) \tag1 \]
具体到每一步的计算时,我们先采样一个二维标准高斯分布\(\epsilon \sim \mathcal N(\boldsymbol 0, \mathbf I)\),然后通过参数\(\beta_t\)由\(\boldsymbol x_{t-1}\)得到\(\boldsymbol x_t\),表示为\(\boldsymbol x_t = \sqrt{1-\beta_t} \boldsymbol x_{t-1} + \sqrt{\beta_t} \epsilon\)。注意这里\(\beta_t\)并不是一个常数值,而是一个随时间变化的变量。\(\beta_t\)的变化情况在扩散模型中被叫做差异时间表(Variance Schedule),常见的时间表策略有线性时间表(DDPM),平方时间表,cosine时间表[6]等。在DDPM扩散模型的前项过程中,我们需要保证\(T\)足够大并且\(\beta_t\)的时间表配置合理,才能保证我们最终得到的\(\boldsymbol x_t\)也将是一个高斯噪声图像。扩散模型的全部前向过程可以表示为从\(t=1\)到\(t=T\)的时刻的马尔可夫链,如式(2)。
\[q(\boldsymbol x_{0:T}) = q(\boldsymbol x_0) \prod_{t=1}^T q\left( \boldsymbol x_t | \boldsymbol x_{t-1} \right) \tag2 \]
扩散过程一个隐藏的重要特征是我们可以直接基于原始数据\(\boldsymbol x_0\)来对任意\(t\)步的\(\boldsymbol x_t\)进行采样。这里我们定义\(\bar{a}_t = \prod_{s=1}^t \alpha_t\)以及\(\alpha_t = 1 – \beta_t\),通过重参数(Reparamazation)技巧,我们可以得到服从分布\(q(\boldsymbol x_t| \boldsymbol x_0)\)的任意一个样本\(\boldsymbol x_t\),我们可以通过反重参数化得到式(3),推导过程见参考文献[11]的式(61)到式(70)。
\[\begin{aligned} \boldsymbol x_t & \sim q(\boldsymbol x_t | \boldsymbol x_0) \\ & = \sqrt{\bar{\alpha_t}} \boldsymbol x_0 + \sqrt{1 – \bar{\alpha_t}}\epsilon \\ & = \mathcal N(\boldsymbol x_t; \sqrt{\bar{\alpha}_t} \boldsymbol x_0, (1 – \bar{\alpha}_t) \mathbf I) \end{aligned} \tag3 \]
上面推理反应了一个重要的性质,即\(\boldsymbol x_t\)可以看做原始数据\(\boldsymbol x_0\)和随机噪声\(\boldsymbol \epsilon\)的线性组合,其中\(\sqrt{\bar{\alpha}_t}\)和\(\sqrt{1 – \bar{\alpha}_t}\)是组合系数,它们的平方和为\(1\)。公式(3)在论文中被叫做“Nice Property”。进一步的,我们可以使用\(\bar \alpha\)来定义差异时间表,例如[6]中的cosine时间表便是这么做的。\(\bar{\alpha}_t\)要比\(\beta\)更直接,例如我们可以通过将\(\bar{\alpha}_T\)设置为一个接近\(0\)的值,来使得最终得到的噪声更倾向于是一个高斯噪声。
2.1.2 后向过程
前向过程是将数据噪声化的过程,那么扩散模型的后向过程\(p(\boldsymbol x_{t-1} | \boldsymbol x_t)\)则是一个去噪过程。即我们先在\(T\)时刻随机采样一个二维高斯噪声,然后逐步进行去噪,最终得到一个和真实图像分布一致的生成图像\(\boldsymbol x_0\)。
所以扩散模型的核心是如何进行这个去噪过程,因为我们并不知道\(p(\boldsymbol x_{t-1} | \boldsymbol x_t)\)的具体形式是什么。扩散模型指出,我们可以使用一个神经网络学习这个去噪过程。因为第\(t\)时刻的分布\(\boldsymbol x_t\)是已知的,因此我们这个神经网络的目标是根据\(\boldsymbol x_t\)去学习\(\boldsymbol x_{t-1}\)的概率分布。综上,扩散模型的后向过程表示为\(p_\theta(\boldsymbol x_{t-1} | \boldsymbol x_t)\),其中\(\theta\)是神经网络的参数,我们可以用SGD等策略对该网络进行优化。
因为前向过程我们添加的噪声是高斯噪声,为了简化模型的训练难度,我们假设反向的去噪过程去掉的噪声也是高斯噪声。因为一个高斯分布是通过均值\(\mu_\theta\)和方差\(\sum_\theta\)决定的,那么\(p_\theta(\boldsymbol x_{t-1} | \boldsymbol x_t)\)可以表示为式(4)的形式。
\[p_\theta(\boldsymbol x_{t-1} | \boldsymbol x_t) = \mathcal N(\boldsymbol x_{t-1}; \mu_\theta(\boldsymbol x_t,t), \sum_\theta(\boldsymbol x_t, t)) \tag4 \]
其中均值和方差均是根据模型计算得到的。综合所有时间步,我们也可以通过马尔可夫链得到扩散模型的后向过程,如式(5)。
\[p_\theta(\boldsymbol x_{0:T}) = p(x_T) \prod_{t=1}^T p_\theta (\boldsymbol x_{t-1} | \boldsymbol x_t) \tag5 \]
其中\(p(\boldsymbol x_t) = \mathcal N(\boldsymbol x_T; \boldsymbol 0, \mathbf I)\)是随机采样的高斯噪声,\(p_\theta (\boldsymbol x_{t-1} | \boldsymbol x_t)\)是一个均值和方差需要计算的高斯分布。
2.1.3 目标函数
那么问题来了,我们究竟使用什么样的优化目标才能比较好的预测高斯噪声的分布呢?一个比较复杂的方式是使用变分自编码器的最大化证据下界(Evidence Lower Bound, ELBO)的思想来推导,如式(6),推导详细过程见论文[11]的式(47)到式(58),这里主要用到了贝叶斯定理和琴生不等式。
\[\begin{aligned} \mathcal L & = – \log p(\boldsymbol x) \\ & = – \log \int \frac{p_\theta(\boldsymbol x_{0:T})q(\boldsymbol x_{1:T} | \boldsymbol x_0)}{q(\boldsymbol x_{1:T} | \boldsymbol x_0)} d \boldsymbol x_{1:T} \\ & \leq – \mathbb E_{q(\boldsymbol x_{1:T} | \boldsymbol x_0)} \left[ \frac{p_\theta(\boldsymbol x_{0:T})}{q(\boldsymbol x_{1:T} | \boldsymbol x_0)}\right] \\ & = – \underbrace{\mathbb{E}_{q\left(\boldsymbol{x}_1 \mid \boldsymbol{x}_0\right)}\left[\log p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right)\right]}_{\text {重构项}} + \underbrace{D_{\mathrm{KL}}\left(q\left(\boldsymbol{x}_T \mid \boldsymbol{x}_0\right) \| p\left(\boldsymbol{x}_T\right)\right)}_{\text {先验匹配项}} + \sum_{t=2}^T \underbrace{\mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}\left[D_{\mathrm{KL}}\left(q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0\right) \| p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)\right)\right]}_{\text {去噪匹配项}} \end{aligned} \tag6 \]
式(6)的推导细节并不重要,我们需要重点关注的是它的最终等式的三个组成部分,下面我们分别介绍它们:
(1) 重构项 \(\mathcal L_0 = \mathbb{E}_{q\left(\boldsymbol{x}_1 \mid \boldsymbol{x}_0\right)}\left[\log p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right)\right]\),它的作用是对原始数据进行重构,优化的是负log似然。DDPM提供的计算方式是:它首先将离散的图像像素从\([0, 255]\)归一化到了\([-1,1]\)的范围,然后用式(3)中估计的$\mathcal N(\boldsymbol x_t; \sqrt{\bar{\alpha}_t} \boldsymbol x_0, (1 – \bar{\alpha}_t) \boldsymbol I) $构建一个离散的解码器来计算,如式(7)。它计算的是高斯分布落在以Ground Truth为中心,且范围大小为\(2/255\)时的概率积分,即累积分布函数(Cumulative Distribution Function,CDF)。
\[\begin{aligned} p_\theta\left(\boldsymbol{x}_0 \mid \boldsymbol{x}_1\right) & =\prod_{i=1}^D \int_{\delta_{-}\left(x_0^i\right)}^{\delta_{+}\left(x_0^i\right)} \mathcal{N}\left(x ; \mu_\theta^i\left(\boldsymbol{x}_1, 1\right), \sigma_1^2\right) d x \\ \delta_{+}(x) & =\left\{\begin{array}{ll} \infty & \text { if } x=1 \\ x+\frac{1}{255} & \text { if } x<1 \end{array} \quad \delta_{-}(x)= \begin{cases}-\infty & \text { if } x=-1 \\ x-\frac{1}{255} & \text { if } x>-1\end{cases} \right. \end{aligned} \tag7 \]
其中\(D\)是整张图像,\(i\)是图像上的像素点坐标。
(2) 先验重构项 \(\mathcal L_T=D_{\mathrm{KL}}\left(q\left(\boldsymbol{x}_T \mid \boldsymbol{x}_0\right) \| p\left(\boldsymbol{x}_T\right)\right)\):它使用KL散度计算了最后的噪声输入和标准的高斯先验的接近程度,因为这一部分没有可以训练的参数,我们可以将它视作常数0。
(3) 去噪匹配项 \(\mathcal L_{t-1} = \mathbb{E}_{q\left(\boldsymbol{x}_t \mid \boldsymbol{x}_0\right)}\left[D_{\mathrm{KL}}\left(q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0\right) \| p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)\right)\right]\):它计算的是真实后验分布\(q\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t, \boldsymbol{x}_0\right)\)和预测的分布\(p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right)\)之间的KL散度。因为我们希望真实的去噪过程和模型预测的去噪过程完全一致,如图4。
图4:扩散模型的去噪匹配项在每一步都要拟合噪音的真实后验分布和估计分布
真实后验分布可以使用贝叶斯定理进行推导,最终结果如式(8),推导过程见论文[11]的式(71)到式(84)。
\[\begin{aligned} q(\boldsymbol x_{t-1} | \boldsymbol x_t, \boldsymbol x_0) & = \frac{q(\boldsymbol x_{t} | \boldsymbol x_{t-1}, \boldsymbol x_0) q(\boldsymbol x_{t-1} | \boldsymbol x_0)}{q(\boldsymbol x_{t} | \boldsymbol x_0)} \\ & \propto \mathcal N \left( \boldsymbol x_{t-1}; \frac{\sqrt{\alpha_t} (1 – \bar{\alpha}_{t-1}) \boldsymbol x_t + \sqrt{\bar{\alpha}_{t-1}}(1 – \alpha_t) \boldsymbol x_0}{ 1- \bar{\alpha}_t}, \frac{(1 – \alpha_t)(1 – \bar{\alpha}_{t-1})}{1 – \bar{\alpha}_t} \mathbf I \right) \\ & = \mathcal N(\boldsymbol x_{t-1}; \mu_q(\boldsymbol x_t, \boldsymbol x_0), \Sigma_q(t)) \end{aligned} \tag8 \]
其中\(\mu_q(\boldsymbol x_t, \boldsymbol x_0) = \frac{\sqrt{\alpha_t} (1 – \bar{\alpha}_{t-1}) \boldsymbol x_t + \sqrt{\bar{\alpha}_{t-1}}(1 – \alpha_t) \boldsymbol x_0}{ 1- \bar{\alpha}_t}\)是均值,\(\Sigma_q(t) = \frac{(1 – \alpha_t)(1 – \bar{\alpha}_{t-1})}{1 – \bar{\alpha}_t} \mathbf I\)是方差。可以看出均值是和\(\boldsymbol x_t\)和\(\boldsymbol x_0\)相关的,但是方差与数据无关。假设预测的分布也服从正态分布,表示为式(9)。
\[p_{\boldsymbol{\theta}}\left(\boldsymbol{x}_{t-1} \mid \boldsymbol{x}_t\right) = \mathcal N(\boldsymbol x_{t-1}; \mu_\theta(\boldsymbol x_t, t), \Sigma_q(t)) \tag9 \]
为了进一步化简\(\mathcal L_{t-1}\)我们需要用到式(10)的两个高斯分布的KL散度计算公式。
\[\begin{aligned} & D_\text{KL}(\mathcal N(\boldsymbol x; \boldsymbol \mu_x, \boldsymbol \Sigma_x), \mathcal N(\boldsymbol y; \boldsymbol \mu_y, \boldsymbol \Sigma_y) \\ = & \frac{1}{2}\left[ \log \frac{|\boldsymbol \Sigma_x|}{|\boldsymbol \Sigma_y|} – d + \text{tr}(\boldsymbol \Sigma_y^{-1} \boldsymbol \Sigma_x) + (\boldsymbol \mu_y – \boldsymbol \mu_x)^\intercal \boldsymbol \sigma_y^{-1}(\boldsymbol \mu_y – \boldsymbol \mu_x)\right]) \end{aligned} \tag{10} \]
在两个分布均是高斯分布的前提下,我们可以使用公式(10)继续对\(\mathcal L_{t-1}\)进行进行计算,这一部分完整的推导流程参考论文[11]的式(87)到式(92)。从这里我们可以看出,当两个高斯分布方差相同时,求它们之间的KL散度既是求两个分布的均值的l2距离。
\[\begin{aligned} \mathop{\arg\min}_\theta D_\text{KL}(q(\boldsymbol x_{t-1} | \boldsymbol x_t, \boldsymbol x_0) || p_\theta(\boldsymbol x_{t-1} | \boldsymbol x_t)) = \mathop{\arg\min}_\theta \frac{1}{2\sigma_q^2(t)} \left[\|\boldsymbol \mu_\theta (\boldsymbol x_t, \boldsymbol x_0) – \boldsymbol \mu_q(\boldsymbol x_t, t) \|^2_2\right] \\ \end{aligned} \tag{11} \]
通过式(3)我们可以得到\(\boldsymbol x_0 = \frac{\boldsymbol x_t – \sqrt{1 – \bar\alpha_t}\boldsymbol \epsilon_0}{\sqrt {\bar \alpha_t}}\),将它代入到\(\mu_q(\boldsymbol x_t, \boldsymbol x_0) = \frac{\sqrt{\alpha_t} (1 – \bar{\alpha}_{t-1}) \boldsymbol x_t + \sqrt{\bar{\alpha}_{t-1}}(1 – \alpha_t) \boldsymbol x_0}{ 1- \bar{\alpha}_t}\)中,我们可以得到式(12)的最终结果,推导过程见参考文献(11)的式(116)到式(124)。
\[\begin{aligned} \boldsymbol \mu_q(\boldsymbol x_t, \boldsymbol x_0) & = \frac{\sqrt{\alpha_t} (1 – \bar{\alpha}_{t-1}) \boldsymbol x_t + \sqrt{\bar{\alpha}_{t-1}}(1 – \alpha_t) \boldsymbol x_0}{ 1- \bar{\alpha}_t} \\ & = \frac{\sqrt{\alpha_t} (1 – \bar{\alpha}_{t-1}) \boldsymbol x_t + \sqrt{\bar{\alpha}_{t-1}}(1 – \alpha_t) \frac{\boldsymbol x_t – \sqrt{1 – \bar\alpha_t}\boldsymbol \epsilon_0}{\sqrt {\bar \alpha_t}}}{ 1- \bar{\alpha}_t} \\ & = \frac{1}{\sqrt{\alpha_t}} \boldsymbol x_t – \frac{1-\alpha_t}{\sqrt{1 – \bar{\alpha_t}}{\sqrt{\alpha_t}}} \boldsymbol \epsilon_0 \end{aligned} \tag{12} \]
同理,我们也可以用这种方式计算$ \mu_\theta (\boldsymbol x_t, \boldsymbol x_0)$,如式(13):
\[\mu_\theta (\boldsymbol x_t, t) = \frac{1}{\sqrt{\alpha_t}} \boldsymbol x_t – \frac{1-\alpha_t}{\sqrt{1 – \bar{\alpha_t}}{\sqrt{\alpha_t}}} \hat {\boldsymbol \epsilon }_\theta(\boldsymbol x_t, t) \tag{13} \]
将式(12)和式(13)带入到式(11)中,我们可以得到:
\[\begin{aligned} & \mathop{\arg\min}_\theta D_\text{KL}(q(\boldsymbol x_{t-1} | \boldsymbol x_t, \boldsymbol x_0) || p_\theta(\boldsymbol x_{t-1} | \boldsymbol x_t)) \\ = \; & \mathop{\arg\min}_\theta \frac{1}{2\sigma^2_q(t)} \frac{(1-\alpha_t)^2}{(1-\bar{\alpha}_t)\alpha_t} \left[ \| \boldsymbol \epsilon_0 – \hat{\boldsymbol \epsilon}_\theta (\boldsymbol x_t, t)\|^2_2 \right] \\ = \; & \mathop{\arg\min}_\theta \frac{1}{2\sigma^2_q(t)} \frac{(1-\alpha_t)^2}{(1-\bar{\alpha}_t)\alpha_t} \left[ \| \boldsymbol \epsilon_0 – \hat{\boldsymbol \epsilon}_\theta (\sqrt{\bar{\alpha}_t} \boldsymbol x_0 + \sqrt{ 1- \bar{\alpha}_t} \boldsymbol \epsilon , t)\|^2_2 \right] \end{aligned} \tag{14} \]
最终我们可以将扩散模型的损失函数简化为式(15)的形式。其中\(\boldsymbol \epsilon_t\)是添加的高斯噪声,\(\hat{\boldsymbol \epsilon}_\theta(\boldsymbol x_t, t)\)是一个神经网络,用于预测从\(\boldsymbol x_0\)到\(\boldsymbol x_t\)时刻添加的噪声。
\[\mathcal L_\text{simple} = \mathbb E_{t, \boldsymbol x_0, \epsilon} \| \boldsymbol \epsilon_t – \hat{\boldsymbol \epsilon}_\theta (\sqrt{\bar{\alpha}_t} \boldsymbol x_0 + \sqrt{ 1- \bar{\alpha}_t} \boldsymbol \epsilon , t)\|^2 \tag{15} \]
2.1.4 模型训练
在第2.1.1节我们讲到我们可以直接基于原始数据\(\boldsymbol x_0\)来对任意\(t\)步的\(\boldsymbol x_t\)进行采样。那么在实际训练过程中,我们不必将所有的时间片都拿来训练。而采取直接采样到时刻\(t\),然后得到该时刻的\(\boldsymbol x_t\)并使用神经网络预测添加的噪声即可。因为扩散模型的\(T\)是一个非常大的值,使用这种方式将大幅提升训练速度。它的训练过程为:
- 从分布为\(q(\boldsymbol x_0)\)的数据集随机采样一个样本\(\boldsymbol x_0 \sim q(\boldsymbol x_0)\);
- 从\(1\)到\(T\)中随机采样一个值\(t\),用于表示添加噪声的水平;
- 随机采样一个二维高斯噪音\(\epsilon\),然后使用上面介绍的“Nice Property”对\(\boldsymbol x_0\)施加\(t\)级别的噪声;
- 训练神经网络根据加噪之后的\(\boldsymbol x_t\)预测作用到\(\boldsymbol x_0\)之上的噪声。
虽然上面我们介绍了很多内容,并给出了大量公式,但得益于推导出的几个重要性质,扩散模型的训练并不复杂,它的训练伪代码见算法1。
2.1.5 样本生成
正如我们所介绍的,扩散模型的生成过程是一个反向去噪的过程,它的伪代码见算法2。具体的讲,我们从\(T\)时刻开始,首先随机采样一个高斯噪声。然后我们使用神经网络预测的噪声逐渐对其去噪,直到\(0\)时刻停止。通过式(9)我们得到了\(\boldsymbol x_{t-1} \sim p_\theta(\boldsymbol x_{t-1}|\boldsymbol x_t)\),那么我们可以进一步得到从\(\boldsymbol x_t\)到\(\boldsymbol x_{t-1}\)的计算公式,如式(16)。
\[\boldsymbol x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left( \boldsymbol x_t – \frac{1 – \alpha_t}{\sqrt{1 – \bar{\alpha}_t}}\epsilon(\boldsymbol x_t, t)\right) + \sigma_t \boldsymbol z \tag{16} \]
其中\(\boldsymbol z\sim \mathcal N(\mathbf 0, \mathbf I)\)是一个二维标准高斯分布。
2.2 算法实现
2.2.1模型结构
DDPM在预测施加的噪声时,它的输入是施加噪声之后的图像,预测内容是和输入图像相同尺寸的噪声,所以它可以看做一个Img2Img的任务。DDPM选择了U-Net[9]作为噪声预测的模型结构。U-Net是一个U形的网络结构,它由编码器,解码器以及编码器和解码器之间的跨层连接(残差连接)组成。其中编码器将图像降采样成一个特征,解码器将这个特征上采样为目标噪声,跨层连接用于拼接编码器和解码器之间的特征。
图5:U-Net的网络结构
下面我们介绍DDPM的模型结构的重要组件。首先在U-Net的卷积部分,DDPM使用了宽残差网络(Wide Residual Network,WRN)[12]作为核心结构,WRN是一个比标准残差网络层数更少,但是通道数更多的网络结构。也有作者复现发现ConvNeXT作为基础结构会取得非常显著的效果提升[13,14]。这里我们可以根据训练资源灵活的调整卷积结构以及具体的层数等超参。因为我们在扩散过程的整个流程中都共享同一套参数,为了区分不同的时间片,作者借鉴了Transformer [15]的位置编码的思想,采用了正弦位置嵌入对时间\(t\)进行了编码,这使得模型在预测噪声时知道它预测的是批次中分别是哪个时间片添加的噪声。在卷积层之间,DDPM添加了一个注意力层。这里我们可以使用Transformer中提出的自注意力机制或是多头自注意力机制。[13]则提出了一个线性注意力机制的模块,它的特点是消耗的时间以及占用的内存和序列长度是线性相关的,对比传统注意力机制的平方相关要高效很多。在进行归一化时,DDPM选择了组归一化(Group Normalization,GN)[16]。最后,对于U-Net中的降采样和上采样操作,DDPM分别选择了步长为2的卷积以及反卷积。
确定了这些组件,我们便可以搭建用于DDPM的U-Net的模型了。从第2.1节的介绍我们知道,模型的输入为形状为(batch_size, num_channels, height, width)的噪声图像和形状为(batch_size,1)的噪声水平,返回的是形状为(batch_size, num_channels, height, width)的预测噪声,我们搭建的用于噪声预测的模型结构如下:
- 首先在噪声图像\(\boldsymbol x_0\)上应用卷积层,并为噪声水平\(t\)计算时间嵌入;
- 接下来是降采样阶段。采用的模型结构依次是两个卷积(WRNS或是ConvNeXT)+GN+Attention+降采样层;
- 在网络的最中间,依次是卷积层+Attention+卷积层;
- 接下来是上采样阶段。它首先会使用Short-cut拼接来自降采样中同样尺寸的卷积,再之后是两个卷积+GN+Attention+上采样层。
- 最后是使用WRNS或是ConvNeXT作为输出层的卷积。
U-Net类的forword函数如下面代码片段所示,完整的实现代码参照[3]。
def forward(self, x, time):
x = self.init_conv(x)
t = self.time_mlp(time) if exists(self.time_mlp) else None
h = []
# downsample
for block1, block2, attn, downsample in self.downs:
x = block1(x, t)
x = block2(x, t)
x = attn(x)
h.append(x)
x = downsample(x)
# bottleneck
x = self.mid_block1(x, t)
x = self.mid_attn(x)
x = self.mid_block2(x, t)
# upsample
for block1, block2, attn, upsample in self.ups:
x = torch.cat((x, h.pop()), dim=1)
x = block1(x, t)
x = block2(x, t)
x = attn(x)
x = upsample(x)
return self.final_conv(x)
2.2.2 前向加噪
扩散模型的前向过程是逐渐向图像中添加噪声的过程,这个从时刻\(0\)到时刻\(T\)的\(t\)的变化情况叫做差异时间表。DDPM使用的是线性时间表,即我们首先对\(T\)个时间步做均匀拆分,得到\(\beta_t\)。接下来我们根据\(\beta_t\)计算我们需要的其它变量,例如\(\alpha\),\(\bar{\alpha}\)等,它们最好被存储起来以避免重复计算。接下来我们需要准备输入图像\(\boldsymbol x_0\)以及随机噪声\(\epsilon\),其中图像的处理主要包括resize到模型输入大小以及进行归一化,DDPM的策略是将它们线性归一化到\([-1,1]\)之间,随机噪声即随机生成一个和输入图像相同尺寸的二维高斯噪声。最后我们根据式(3)将它们合成一张图像即可。图6是参考文献[3]中给出的输入图像依次经过0次,50次,100次,150次以及199次加噪后的效果图,可以看出随着逐渐添加噪声,图像越来越难以区分,直到彻底变成一个二维高斯噪声。
图6:一张图依次经过0次,50次,100次,150次以及199次加噪后的效果图
根据式(14)我们知道,扩散模型的损失函数计算的是两张图像的相似性,因此我们可以选择使用回归算法的所有损失函数,以MSE为例,前向过程的核心代码如下面代码片段。
def p_losses(denoise_model, x_start, t, noise=None, loss_type="l1"):
# 1. 根据时刻t计算随机噪声分布,并对图像x_start进行加噪
x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
# 2. 根据噪声图像以及时刻t,预测添加的噪声
predicted_noise = denoise_model(x_noisy, t)
# 3. 对比添加的噪声和预测的噪声的相似性
loss = F.mse_loss(noise, predicted_noise)
return loss
2.2.3 样本生成
根据2.1.5节介绍的样本生成流程,它的核心代码片段所示,关于这段代码的讲解我通过注释添加到了代码片段中。
@torch.no_grad()
def p_sample(model, x, t, t_index):
betas_t = extract(betas, t, x.shape)
sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
# 使用式(13)计算模型的均值
model_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t)
if t_index == 0:
return model_mean
else:
# 获取保存的方差
posterior_variance_t = extract(posterior_variance, t, x.shape)
noise = torch.randn_like(x)
# 算法2的第4行
return model_mean + torch.sqrt(posterior_variance_t) * noise
# 算法2的流程,但是我们保存了所有中间样本
@torch.no_grad()
def p_sample_loop(model, shape):
device = next(model.parameters()).device
b = shape[0]
# start from pure noise (for each example in the batch)
img = torch.randn(shape, device=device)
imgs = []
for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
img = p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long), i)
imgs.append(img.cpu().numpy())
return imgs
最后我们看下在人脸图像数据集下训练的模型,一批随机噪声经过逐渐去噪变成人脸图像的示例。
图7:扩散模型由随机噪声通过去噪逐渐生成人脸图像
3. 总结
这里我们以DDPM为例介绍了另一个派系的生成算法:扩散模型。扩散模型是一个基于马尔可夫链的数学模型,它通过预测每个时间片添加的噪声来进行模型的训练。作为近日来引发热烈讨论的ControlNet, Stable Diffusion等模型的底层算法,我们十分有必要对其有所了解。DDPM的实现并不复杂,这得益于大量数学界大佬通过大量的数学推导将整个扩散过程和反向去噪过程进行了精彩的化简,这才有了DDPM的大道至简的实现。DDPM作为一个扩散模型的基石算法,它有着很多早期算法的共同问题:
- 采样速度慢:DDPM的去噪是从时刻\(T\)到时刻\(1\)的一个完整的马尔可夫链的计算,尤其是DDPM还需要一个比较大的\(T\)才能保证比较好的效果,这就导致了DDPM的采样过程注定是非常慢的;
- 生成效果差:DDPM的效果并不能说是非常好,尤其是对于高分辨率图像的生成。这一方面是因为它的计算速度限制了它扩展到更大的模型;另一方面它的设计还有一些问题,例如逐像素的计算损失并使用相同权值而忽略图像中的主体并不是非常好的策略。
- 内容不可控:我们可以看出,DDPM生成的内容完全还是取决于它的训练集。它并没有引入一些先验条件,因此并不能通过控制图像中的细节来生成我们制定的内容。
我们现在已经知道,DDPM的这些问题已大幅得到改善,现在基于扩散模型生成的图像已经达到甚至超过人类多数的画师的效果,我也会在之后逐渐给出这些优化方案的讲解。
Reference
[1] Sohl-Dickstein, Jascha, et al. “Deep unsupervised learning using nonequilibrium thermodynamics.” International Conference on Machine Learning. PMLR, 2015.
[2] Ho, Jonathan, Ajay Jain, and Pieter Abbeel. “Denoising diffusion probabilistic models.” Advances in Neural Information Processing Systems 33 (2020): 6840-6851.
[3] https://huggingface.co/blog/annotated-diffusion
[4] https://lilianweng.github.io/posts/2021-07-11-diffusion-models/#simplification
[5] https://openai.com/blog/generative-models/
[6] Nichol, Alexander Quinn, and Prafulla Dhariwal. “Improved denoising diffusion probabilistic models.” International Conference on Machine Learning. PMLR, 2021.
[7] Kingma, Diederik P., and Max Welling. “Auto-encoding variational bayes.” arXiv preprint arXiv:1312.6114 (2013).
[8] Hinton, Geoffrey E., and Ruslan R. Salakhutdinov. “Reducing the dimensionality of data with neural networks.” science 313.5786 (2006): 504-507.
[9] Ronneberger O, Fischer P, Brox T. U-net: Convolutional networks for biomedical image segmentation[C]//International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015: 234-241.
[10] Long, Jonathan, Evan Shelhamer, and Trevor Darrell. “Fully convolutional networks for semantic segmentation.” Proceedings of the IEEE conference on computer vision and pattern recognition. 2015.
[11] Luo, Calvin. “Understanding diffusion models: A unified perspective.” arXiv preprint arXiv:2208.11970 (2022).
[12] Zagoruyko, Sergey, and Nikos Komodakis. “Wide residual networks.” arXiv preprint arXiv:1605.07146 (2016).
[13] https://github.com/lucidrains/denoising-diffusion-pytorch
[14] Liu, Zhuang, et al. “A convnet for the 2020s.” Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022.
[15] Vaswani, Ashish, et al. “Attention is all you need.” Advances in neural information processing systems 30 (2017).
[16] Wu, Yuxin, and Kaiming He. “Group normalization.” Proceedings of the European conference on computer vision (ECCV). 2018.