Diffusion Models 的想法是通过缓慢地将噪声加入数据,然后学习逆向扩散的过程,从而能从噪声中构建出所需要的样本。

扩散模型

x0x_0 为原始的样本,xTx_T 为最终的噪声。设 qq 表示真实的概率分布,pp 则表示生成的概率分布。

  • 前向扩散,逐步向图像中添加高斯噪声,设在 t1t -1 时刻向 tt 时刻的图像扩散的方差为 βt\beta_t,那么

    xt=1βtxt1+βtϵ (ϵN(0,I))x_t = \sqrt{1 - \beta_t} x_{t-1} +\sqrt{\beta_t}\epsilon\ (\epsilon \sim\mathcal N(\mathbf 0, \mathbf I))

  • 逆向扩散,从某个存在噪声的时刻 xt0x_{t_0} 逐步还原出 x0x_0 的过程。

算法介绍

前向扩散

我们对 xtx_t 进行一些推导,我们令 αt=1βt\alpha_t = 1 - \beta_t

xt=1βtxt1+βtϵ=αtxt1+1αtϵ=αtαt1xt2+αt(1αt1)ϵ1+1αtϵ2\begin{aligned} x_t &= \sqrt{1 - \beta_t} x_{t - 1} + \sqrt{\beta_t} \epsilon \\ &= \sqrt{\alpha_t} x_{t - 1} + \sqrt{1 - \alpha_t} \epsilon \\ &= \sqrt{\alpha_t\alpha_{t-1}} x_{t-2} + \sqrt{\alpha_t(1 - \alpha_{t-1})}\epsilon_1 + \sqrt{1 - \alpha_t} \epsilon_2 \end{aligned}

由于 αt(1αt1)ϵ1N(0,αt(1αt1)I)\sqrt{\alpha_t(1 - \alpha_{t-1})}\epsilon_1 \sim \mathcal N(\mathbf 0, \alpha_t(1 - \alpha_{t-1})\mathbf I)1αtϵ2N(0,(1αt)I)\sqrt{1 - \alpha_t} \epsilon_2 \sim \mathcal N(\mathbf 0, (1 - \alpha_t)\mathbf I) ,所以有 αt(1αt1)ϵ1+1αtϵ2N(0,(1αtαt1)I)\sqrt{\alpha_t(1 - \alpha_{t-1})}\epsilon_1 + \sqrt{1 - \alpha_t} \epsilon_2\sim \mathcal N(\mathbf 0, (1 - \alpha_t\alpha_{t-1}) \mathbf I)

所以有 xt=αtαt1xt2+1αtαt1ϵx_t = \sqrt{\alpha_t\alpha_{t - 1}} x_{t - 2} + \sqrt{1 - \alpha_t\alpha_{t-1}} \epsilon。容易由此归纳得到 xt=i=1tαix0+1i=1tαiϵx_t = \sqrt{\prod_{i=1}^t \alpha_i} x_0 +\sqrt{1 - \prod_{i=1}^t\alpha_i} \epsilon

αtˉ=i=1tαi\bar{\alpha_t} = \prod_{i = 1}^t \alpha_i,那么有

xt=αtˉx0+1αˉtϵ\begin{aligned} x_t = \sqrt{\bar{\alpha_t}} x_0 + \sqrt{1 - \bar \alpha_t} \epsilon \end{aligned}

随着时间不断推移,αtˉ0\bar{\alpha_t} \to 0,所以最终图像会变为随机高斯噪声 xN(0,I)x_{\infty} \sim N(\mathbf 0, \mathbf I)

逆向过程

现在我们希望计算从 xtx_{t } 中还原出 xt1x_{t-1}。我们先来计算 q(xt1xt)q(x_{t - 1} \mid x_t),这个没有办法直接进行计算,我们在训练的时候是知道 x0x_0 的,我们可以借助 x0x_0 和 Bayes 公式来计算:

q(xt1xt,x0)=q(xtxt1,x0)q(xt1x0)q(xtx0)=12π(1αt)exp(xtαtxt122(1αt))12π(1αˉt1)exp(xt1αˉt1x022(1αˉt1))12π(1αtˉ)exp(xtαtˉx022(1αtˉ))exp(xtαtxt122(1αt)xt1αˉt1x022(1αˉt1)+xtαtˉx022(1αtˉ))=exp(12[(αtβt+11αˉt1)xt12(2αtβtxt+2αˉt11αˉt1x0)xt1+C(xt,x0)])\begin{aligned} q(x_{t - 1} \mid x_t, x_0) &= \frac{q(x_{t}\mid x_{t-1}, x_0) q(x_{t-1}\mid x_0)}{q(x_{t} \mid x_0)} \\ &= \frac { \frac{1}{\sqrt{2\pi(1 - \alpha_t)}}\exp(-\frac{\|x_t - \sqrt{\alpha_t}x_{t-1}\|^2}{2(1 - \alpha_t)}) \cdot \frac{1}{\sqrt{2\pi(1 - \bar{\alpha}_{t-1})}}\exp(-\frac{\|x_{t-1} - \sqrt{\bar{\alpha}_{t-1}}x_{0}\|^2}{2(1 - \bar{\alpha}_{t-1})}) } {\frac{1}{\sqrt{2\pi(1 - \bar{\alpha_{t}})}}\exp(-\frac{\|x_{t} - \sqrt{\bar{\alpha_{t}}}x_{0}\|^2}{2(1 - \bar{\alpha_{t}})})} \\ &\propto \exp\left( -\frac{\|x_t - \sqrt{\alpha_t}x_{t-1}\|^2}{2(1 - \alpha_t)} -\frac{\|x_{t-1} - \sqrt{\bar{\alpha}_{t-1}}x_{0}\|^2}{2(1 - \bar{\alpha}_{t-1})} +\frac{\|x_{t} - \sqrt{\bar{\alpha_{t}}}x_{0}\|^2}{2(1 - \bar{\alpha_{t}})} \right ) \\ &= \exp\left ( -\frac{1}{2}\left [ \left (\frac{\alpha_t}{\beta_t} + \frac{1}{1-\bar{\alpha}_{t-1}}\right )\|x_{t-1}\|^2 - \left (\frac{2\sqrt{\alpha_t}}{\beta_t}x_t + \frac{2\sqrt{\bar\alpha_{t-1}}}{1 - \bar \alpha_{t-1}} x_0 \right )x_{t-1} + C(x_t, x_0) \right] \right ) \end{aligned}

其中 x=xx\|x \| = \sqrt{xx} 表示二范数,xtxt1x_t x_{t - 1} 表示它们的点积,C(xt,x0)C(x_t, x_0) 是一个与 xt,x0x_t, x_0 有关的常量,我们考虑对该式子进行配方,我们期望得到一个类似于 exp((xμ)22σ2)\exp(-\frac{(x -\mu)^2}{2\sigma^2}) 即正态分布的形式。

可以得到

1σ2=(αtβt+11αˉt1)σ2=(αtβt+11αˉt1)1=1αˉt11αˉtβtμ=12(2αtβtxt+2αˉt11αˉt1x0)σ2=αt(1αˉt1)1αˉtxt+αˉt1βt1αˉtx0\begin{aligned} \frac{1}{\sigma^2} &= \left (\frac{\alpha_t}{\beta_t} + \frac{1}{1-\bar{\alpha}_{t-1}}\right ) \\ \sigma^2 &= \left (\frac{\alpha_t}{\beta_t} + \frac{1}{1-\bar{\alpha}_{t-1}}\right )^{-1} = \frac{1 - \bar\alpha_{t-1}}{1 - \bar \alpha_t}\beta_t \\ \mu &= \frac{1}{2} \left (\frac{2\sqrt{\alpha_t}}{\beta_t}x_t + \frac{2\sqrt{\bar\alpha_{t-1}}}{1 - \bar \alpha_{t-1}} x_0 \right )\sigma^2 \\ &= \frac{\sqrt{\alpha_t}(1-\bar\alpha_{t-1})}{1-\bar \alpha_t}x_t + \frac{\sqrt{\bar\alpha_{t-1}}\beta_t}{1 - \bar \alpha_t}x_0 \end{aligned}

我们希望最后能够不利用 x0x_0 这个信息,这样我们就能做生成了。注意到 xt=αtˉx0+1αˉtϵx_t = \sqrt{\bar{\alpha_t}} x_0 + \sqrt{1 - \bar \alpha_t} \epsilon,我们可以得到 x0=1αˉtxt1αˉtαˉtϵx_0 = \frac{1}{\sqrt{\bar \alpha_t}} x_t - \frac{\sqrt{1 - \bar \alpha_t}}{\sqrt{\bar \alpha_t}}\epsilon。代入再化简可以得到:

μ=1αt(xtβt1αˉtϵ)\mu = \frac{1}{\sqrt{\alpha_t}}(x_t - \frac{\beta_t}{\sqrt{1 - \bar \alpha_t}}\epsilon)

所以

xt1N(μ,σ2)x_{t - 1} \sim\mathcal N(\mu, \sigma^2)

由于高斯噪声 ϵ\epsilon,我们无法得到准确的均值 μ\mu,因此我们希望去构建一个模型根据 xt,tx_t, t 去对这个高斯噪声 ϵ\epsilon 进行预测,然后再在这个分布中采样得到 xt1x_{t - 1}。由此完成逆向扩散的过程。

损失函数

我们希望最终生成出来的概率分布 pθp_{\theta} 能够和目标的概率 qq 分布尽可能像,我们计算它们的交叉熵 H(p,q)=Ex0q[logpθ(x]Ex0q[logq(x0)]=H(p)H(p, q) = \mathbb E_{x_0 \sim q}[-\log p_{\theta}(x_] \geqslant \mathbb E_{x_0\sim q}[-\log q(x_0)] = H(p),取等当且仅当两个分布相同。但遗憾的是它们的交叉熵没有办法直接进行计算,我们对它进行一些变形:

H(p,q)=Ex0q[logpθ(x0:T)dx1:T]\begin{aligned} H(p, q) &= \mathbb E_{x_0\sim q} \left [-\log \int p_{\theta}(x_{0:T})\text dx_{1:T} \right ] \end{aligned}

这里  dx1:T\int \cdot \ \text d x_{1:T} 相当于是枚举所有可能的从 TT 逆向扩散到 11 的过程,pθ(x0:T)p_{\theta}(x_{0:T}) 则是计算像这样从 xTx_T 逆向扩散到 x1x_1 的概率。稍微再进行变形可以得到:

H(p,q)=Ex0q[logq(x1:Tx0)pθ(x0:T)q(x1:Tx0)dx1:T]=Ex0q{logEx1:Tqx0[pθ(x0:T)q(x1:Tx0)]}\begin{aligned} H(p, q) &= \mathbb E_{x_0\sim q} \left [-\log \int q(x_{1:T}\mid x_0)\cdot \frac{p_{\theta}(x_{0:T})}{q(x_{1:T}\mid x_0)}\text dx_{1:T} \right ] \\ &=\mathbb E_{x_0\sim q} \left \{-\log \mathbb E_{x_{1:T}\sim q\mid x_0} \left [\frac{p_{\theta}(x_{0:T})}{q(x_{1:T}\mid x_0)}\right ] \right \} \end{aligned}

这里 x1:Tqx0x_{1:T} \sim q\mid x_0 表示在已知 x0x_0 下,x1,,xTx_{1}, \cdots, x_{T} 的分布。由于 $-\log $ 是一个下凸函数,利用 Jensen 不等式有

log(E(X))E(logX)-\log (\mathbb E(X)) \leqslant \mathbb E(-\log X)

所以有:

H(p,q)Ex0:Tq[logpθ(x0:T)q(x1:Tx0)]=LVLB\begin{aligned} H(p, q) &\leqslant \mathbb E_{x_{0:T}\sim q}\left [-\log \frac{p_{\theta}(x_{0:T})}{q(x_{1:T}\mid x_0)}\right ] = L_{VLB} \end{aligned}

得到的这一项为其变分上界(数学是学不了一点的,这辈子都学不了一点的,明天再说.jpg),对其最小化近似于对原函数的最小化。继续化简有:(由于 hexo 博客解析下面这段 latex 有点问题,所以这里用的图片)

其中由于 pθ(xT)p_{\theta}(x_T) 为高斯分布,与 θ\theta 无关,所以 LTL_T 显然为常量。

LtL_t 项损失

我们对 LtL_t 项的损失继续进行化简:

Lt=E(x0,xt)q[DKL(q(xt1xt,x0)pθ(xt1xt))]\begin{aligned} L_t = \mathbb E_{(x_0, x_t)\sim q}\left [ D_{KL}(q(x_{t-1} \mid x_t, x_0) \| p_{\theta}(x_{t-1}\mid x_t)) \right ] \end{aligned}

由逆向过程的推导我们可以知道 q(xt1xt,x0)q(x_{t-1} \mid x_t, x_0) 为一个正态分布,设其均值为 μ~(xt,x0)\tilde \mu(x_t, x_0),方差为 β~t\tilde \beta _tpθ(xt1xt)p_{\theta}(x_{t-1}\mid x_t) 也为一个正态分布,均值为 μθ(xt,t)\mu_{\theta}(x_t, t),方差为 σt2\sigma_t^2。由于两个正态分布的 KL 散度为:

DKL(N(μ1,σ12)N(μ2,σ22))=logσ2σ1+σ12+(μ1μ2)22σ2212D_{K L}\left(\mathcal{N}\left(\mu_1, \sigma_1^2\right)|| \mathcal{N}\left(\mu_2, \sigma_2^2\right)\right)=\log \frac{\sigma_2}{\sigma_1}+\frac{\sigma_1^2+\left(\mu_1-\mu_2\right)^2}{2 \sigma_2^2}-\frac{1}{2}

所以有

Lt=E(x0,xt)q[12σt2μθ(xt,t)μ~(xt,x0)2+C]\begin{aligned} L_t = \mathbb E_{(x_0, x_t)\sim q}\left [ \frac{1}{2\sigma_t^2} \|\mu_{\theta}(x_t, t) - \tilde\mu (x_t, x_0) \|^2 + C \right ] \end{aligned}

其中 CC 为一项只与 tt 有关的常数,代入

μθ(xt,t)=1αt(xtβt1αˉtϵθ(xt,t))μ~(xt,x0)=1αt(xtβt1αˉtϵt)\begin{aligned} \mu_{\theta}(x_t, t) &= \frac{1}{\sqrt{\alpha_t}}(x_t - \frac{\beta_t}{\sqrt{1 - \bar \alpha_t}}\epsilon_{\theta}(x_t, t)) \\ \tilde \mu(x_t, x_0) &= \frac{1}{\sqrt{\alpha_t}}(x_t - \frac{\beta_t}{\sqrt{1 - \bar \alpha_t}}\epsilon_{t}) \end{aligned}

化简得

Lt=E(x0,xt)q[βt22σt2αt(1αˉt)ϵθ(xt,t)ϵt2+C]L_t = \mathbb E_{(x_0, x_t)\sim q}\left [ \frac{\beta_t^2}{2\sigma_t^2\alpha_t(1-\bar\alpha_t)} \|\epsilon_\theta(x_t, t) - \epsilon_t\|^2 + C \right ]

L0L_0 项损失

由于原图像每一维的范围是 [0,255][0, 255] 中的整数,我们假设将其线性映射到了 [1,1][-1, 1] 中的实数。由于我们还原图像的时候会将实数映射为整数,并且进行 [0,255][0, 255] 范围内的裁剪,所以我们计算概率的时候用预测结果落在 [x0i1255,x0i+1255][x_0^{i}-\frac{1}{255}, x_{0}^i + \frac{1}{255}] 中的概率作为 pθ(x0ix1)p_{\theta}(x_0^i\mid x_1),其中 x0ix_0^i 表示 x0x_0 的第 ii 个像素的结果。我们对 x0i=1x_0^i = -1x0i=1x_0^i = 1 的时候进行特殊处理,因为还原的时候会被截断。具体计算方式如下所示:

pθ(x0x1)=i=1Dδ(x0i)δ+(x0i)N(x;μθi(x1,1),σ12)dxδ+(x)={ if x=1x+1255 if x<1δ(x)={ if x=1x1255 if x>1\begin{aligned} p_\theta\left(x_0 \mid x_1\right) & =\prod_{i=1}^D \int_{\delta_{-}\left(x_0^i\right)}^{\delta_{+}\left(x_0^i\right)} \mathcal{N}\left(x ; \mu_\theta^i\left(x_1, 1\right), \sigma_1^2\right) \text d x \\ \delta_{+}(x) & =\left\{\begin{array}{ll} \infty & \text { if } x=1 \\ x+\frac{1}{255} & \text { if } x<1 \end{array} \quad \delta_{-}(x)= \begin{cases}-\infty & \text { if } x=-1 \\ x-\frac{1}{255} & \text { if } x>-1\end{cases} \right. \end{aligned}

其中 DD 表示图像数据的维度。

训练过程

尽管通过上面两部分,已经可以进行训练, 但是为了实现更方便,使用均方差损失来替代原先的两种损失:

Lsimple=E(x0,xt)q[ϵ(xt,t)ϵt2]\begin{aligned} L_{simple} = \mathbb E_{(x_0, x_t)\sim q} \left [ \|\epsilon(x_t, t) - \epsilon_t\|^2 \right ] \end{aligned}

  • 对于 LtL_t 项损失相当于忽略了前面的系数
  • 对于 L0L_0 项损失,可以看作用 ϵ0iϵθi(x1,1)\epsilon_0^i - \epsilon _\theta^i(x_1,1) 代入正态分布的概率密度函数的值再乘上某个常数来近似那段积分。由于计算的是 logpθ(x0x1)-\log p_{\theta} (x_0 \mid x_1) ,取了对数后就有了 ϵ0iϵθi(x1,1)2\|\epsilon_0^i - \epsilon _\theta^i(x_1,1)\|^2 这一项。

训练的时候,我们从 q(x0)q(x_0) 中采样 x0x_0,然后在 1,,T1, \cdots, T 中随机选取一个时间步 TT,接着生成噪声 ϵN(0,I)\epsilon \sim \mathcal N (0, \mathbf I)。接着计算出 xtx_t,然后让模型预测噪声 ϵθ(xt,t)\epsilon_{\theta}(x_t, t),接着计算它们的均方差损失,然后反向传播。

采样过程

首先我们从 N(0,I)\mathcal N(\mathbf 0, \mathbf I) 中进行采样 xTx_T,然后枚举 t=T,T1,,1t = T, T-1, \cdots, 1 对每个时间步进行去噪和采样:

  • 通过 UNet 预测 tt 时刻的噪声 ϵ\epsilon
  • 接着在 N(0,I)\mathcal N(\mathbf 0, \mathbf I) 中采样得到 zz (如果 t=1t = 1,那么直接令 z=0z = 0,因此此时我们不会对 xt1x_{t - 1} 的分布进行采样,直接用均值作为最终的结果)
  • 然后取 xt1=1αt(xtβt1αˉtϵ)+1αˉt11αˉtβtzx_{t - 1} = \frac{1}{\sqrt{\alpha_t}}(x_t - \frac{\beta_t}{\sqrt{1 - \bar \alpha_t}}\epsilon) + \frac{1 - \bar\alpha_{t-1}}{1 - \bar \alpha_t}\beta_t z

复现

参考了一下源代码的实现,可以看一下这个有注解的版本,使用类似 Transformers 的位置编码来对时间步进行编码,然后使用一种复杂度更低的线性注意力来为 UNet 添加注意力机制。由于没有太理解代码作者把时间编码通过线性层变换后然后均分为两半得到 s,ts, t,然后通过 (1+s)x+t(1 + s)x + t 的方式来为隐状态 xx 加入时间信息的原理,这里实现的时候在注意力层加入一对有关时间的键值对,由此来利用时间信息。

线性注意力

线性注意力优化复杂度的想法是,传统的缩放点积注意力可以看作是计算 ρ(QKT)V\rho(QK^T)V,其中 ρ\rho 是归一化函数,比如逐行的 softmax。但是由于 QKTQK^T 的计算结果是一个 n×nn\times n 的矩阵,复杂度非常地高,假设没有 ρ\rho,那么我们计算矩阵乘法的时候会去考虑用 Q(KTV)Q(K^TV) 的顺序来计算,这样的复杂度仅为 O(dkdvn)O(d_kd_vn),其中 dkd_kdvd_v 分别表示键向量和值向量的维度,通常这个维度都是小于 nn 的,故这样计算的复杂度会比直接通过 (QKT)V(QK^T)V 的顺序计算的复杂度低很多。但是这里有个归一化操作,但是考虑用 ρ(QKT)ρq(Q)ρk(KT)\rho(QK^T)\approx \rho_q(Q)\rho_k(K^T) 来对其进行做近似。当 ρ\rho 是逐行 softmax 时,ρq\rho_qρk\rho_k 分别也取逐行 softmax,虽然不能保证结果相近,但是仍然能够保持 ρq(Q)ρk(KT)\rho_q(Q)\rho_k(K^T) 每行和为 1 的性质。

缩放点积注意力和线性注意力

保持每行和为 1 的性质证明如下:

A=ρq(Q)ρk(KT),Q~=ρq(Q),K~=ρk(KT)A = \rho_q(Q)\rho_k(K^T), \tilde Q = \rho_q(Q), \tilde K = \rho_k(K^T),那么有

jAi,j=jkQ~i,kK~k,j=kQ~i,kjK~k,j=KQ~i,k=1\begin{aligned} \sum_j A_{i,j} &= \sum_j \sum_k \tilde Q_{i,k} \tilde K_{k, j} \\ &= \sum_k \tilde Q_{i, k} \sum_j \tilde K_{k, j} \\ &= \sum_K \tilde Q_{i, k} \\ &= 1 \end{aligned}

网络中的注意力层实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
class CrossAttention(nn.Module) :
def __init__(self, num_input, num_heads=4, dim_heads=32, *, bias=False) :
super().__init__()
self.scale = math.pow(dim_heads, -0.5)
self.num_heads = num_heads
num_hiddens = num_heads * dim_heads
self.C_qkv = nn.Conv2d(num_input, num_hiddens * 3, 1, bias=bias)
self.W_vh = nn.Linear(num_input, num_hiddens, bias=bias)
self.C_hy = nn.Conv2d(num_hiddens, num_input, 1)

def forward(self, X, Vt) :
'''
- X: Images, with a shape of (batch_size, num_channels=num_input, x, y)
- V1: Time encoding, with a shape of (batch_size, num_hiddens=num_input)
'''
x = X.shape[-2]

# project
QKV = self.C_qkv(X).chunk(3, dim=1)
Vt = self.W_vh(Vt)

# reshape
Q, K, V = tuple(rearrange(t, 'b (h d) x y -> b h d (x y)', h=self.num_heads) for t in QKV)
Vt = rearrange(Vt, 'b (h d) -> b h d 1', h=self.num_heads)
K = torch.concat([K, Vt], dim=-1)
V = torch.concat([V, Vt], dim=-1)

# scaled dot-product attention
Q = Q * self.scale
sim = torch.einsum('b h d i, b h d j -> b h i j', Q, K)

# Equivalent to scaling, where both the numerator and the denominator are divided by e^{mx}.
sim = sim - sim.amax(dim=-1, keepdim=True).detach()

# shape = (b, h, q, k)
attn = sim.softmax(dim=-1)
# self.last_att = attn.detach()

Y = torch.einsum('b h q k, b h d k -> b h q d', attn, V)
Y = rearrange(Y, 'b h (x y) d -> b (h d) x y', x=x)
return self.C_hy(Y)

class CrossLinearAttention(nn.Module) :
def __init__(self, num_input, num_heads=4, dim_heads=32, *, bias=False) :
super().__init__()
self.scale = math.pow(dim_heads, -0.5)
self.num_heads = num_heads
num_hiddens = num_heads * dim_heads
self.C_qkv = nn.Conv2d(num_input, num_hiddens * 3, 1, bias=bias)
self.W_vh = nn.Linear(num_input, num_hiddens, bias=bias)
self.C_hy = nn.Conv2d(num_hiddens, num_input, 1)

def forward(self, X, Vt) :
x = X.shape[-2]

# project
QKV = self.C_qkv(X).chunk(3, dim=1)
Vt = self.W_vh(Vt)

# reshape
Q, K, V = tuple(rearrange(t, 'b (h d) x y -> b h d (x y)', h=self.num_heads) for t in QKV)
Vt = rearrange(Vt, 'b (h d) -> b h d 1', h=self.num_heads)
K = torch.concat([K, Vt], dim=-1)
V = torch.concat([V, Vt], dim=-1)

# f(QK^T)V -> f(Q) (f(K^T)V)
Q = Q.softmax(dim=-2)
K = K.softmax(dim=-1)

context = torch.einsum('b h d n, b h e n -> b h d e', K, V)

Y = torch.einsum('b h d e, b h d n -> b h e n', context, Q)
Y = rearrange(Y, 'b h c (x y) -> b (h c) x y', x=x)
return self.C_hy(Y)

UNet + Attention

接着参考了原作者的实现,对 UNet 的架构进行了一些简单的修改:

  • 将卷积层替换为 2 个 3x3 卷积层组成的残差块
  • 在上采样和下采样之前添加注意力层,注意力层除了包含图像每个点的键值和查询外,额外添加一对时间编码的键值
  • 在 UNet 的最下层采用缩放点积注意力,其余部分使用线性注意力

上采样层和下采样层,以及 UNet 的实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
class CADown(nn.Module) :
def __init__(self, in_channels, out_channels, time_dim, *, num_groups=8, bias=False) :
super().__init__()
self.res1 = ResnetBlock(in_channels, out_channels, num_groups)
self.res2 = ResnetBlock(out_channels, out_channels, num_groups)

self.mlp1 = nn.Sequential(
nn.SiLU(),
nn.Linear(time_dim, out_channels)
)
# self.mlp2 = nn.Sequential(
# nn.SiLU(),
# nn.Linear(out_channels, out_channels)
# )

self.norm = nn.GroupNorm(1, out_channels)
self.attn = Residual(CrossLinearAttention(out_channels, bias=bias))

self.down = nn.MaxPool2d(kernel_size=2, stride=2)

def forward(self, X, T, stk) :
X = self.res1(X)
stk.append(X)
X = self.res2(X)
stk.append(X)

T = self.mlp1(T)

X = self.norm(X)
X = self.attn(X, T)

return self.down(X)

class CAUp(nn.Module) :
def __init__(self, in_channels, out_channels, time_dim, *, num_groups=8, bias=False) :
super().__init__()
self.up = nn.ConvTranspose2d(in_channels, out_channels, 2, 2)

self.res1 = ResnetBlock(in_channels, out_channels, num_groups)
self.res2 = ResnetBlock(in_channels, out_channels, num_groups)

self.mlp1 = nn.Sequential(
nn.SiLU(),
nn.Linear(time_dim, out_channels)
)

self.norm = nn.GroupNorm(1, out_channels)
self.attn = Residual(CrossLinearAttention(out_channels, bias=bias))

def forward(self, X, T, stk) :
X = self.up(X)

X = torch.concat((X, stk.pop()), dim=1)
X = self.res1(X)
X = torch.concat((X, stk.pop()), dim=1)
X = self.res2(X)

T = self.mlp1(T)

X = self.norm(X)
return self.attn(X, T)

class ConditionalUNet(nn.Module) :
def __init__(self, in_channels, init_channels, out_channels, time_emb_dim, num_layers=4, *, num_groups=8, attn_bias=False) :
super().__init__()

down_block = partial(CADown, num_groups=num_groups, bias=attn_bias)
up_block = partial(CAUp, num_groups=num_groups, bias=attn_bias)

dim = in_channels

self.emb_t = SinusoidalPositionEmbedding(time_emb_dim)
self.mlp_t = nn.Sequential(
nn.Linear(time_emb_dim, time_emb_dim),
nn.GELU(),
nn.Linear(time_emb_dim, time_emb_dim)
)

self.downs = nn.ModuleList([])
for i in range(num_layers) :
out_dim = dim * 2 if i > 0 else init_channels
self.downs.append(down_block(dim, out_dim, time_emb_dim))
dim = out_dim

self.mid_res1 = ResnetBlock(dim, dim * 2, num_groups)
dim *= 2

self.mid_norm = nn.GroupNorm(1, dim)
self.mid_mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(time_emb_dim, dim)
)
self.mid_attn = Residual(CrossAttention(dim))

self.mid_res2 = ResnetBlock(dim, dim, num_groups)

self.ups = nn.ModuleList([])
for i in range(num_layers) :
self.ups.append(up_block(dim, dim // 2, time_emb_dim))
dim = dim // 2

self.final_conv = nn.Conv2d(dim, out_channels, 1)

def forward(self, X, time) :
T = self.emb_t(time)
T = self.mlp_t(T)

stk = []
for blk in self.downs :
X = blk(X, T, stk)

X = self.mid_res1(X)
X = self.mid_attn(self.mid_norm(X), self.mid_mlp(T))
X = self.mid_res2(X)

for blk in self.ups :
X = blk(X, T, stk)
return self.final_conv(X)

运行结果

使用 CelebA 数据集,UNet 使用 4 层(进行 4 次下采样),初始通道数 64,扩散的过程取 T=1000T=1000β\beta10410^4 线性增加到 0.020.02,在 RTX 4090 上训练约 5h。

分别对 t=200,500,900t = 200, 500, 900 去噪的结果如下:

去噪结果

采样结果如下

生成结果

虽然效果不佳,但是已经初具人形了

踩坑

时间步较小时去噪的均方差损失会较大,当时以为需要去平衡一下这里的损失,于是训练时根据每个时间步上次的损失再归一化的方式对时间步进行采样,反而导致从时间步较大进行还原时的效果非常差,以至于最终生成的结果所有地方几乎一个颜色:

后来改成训练时对每个时间步均匀采样即可得到前面提到的初具人形的结果。