扩散模型研究一:去噪扩散概率模型DDPM

作者: 引线小白-本文永久链接:httpss://www.limoncc.com/post/1c60669bbe56769f/
知识共享许可协议: 本博客采用署名-非商业-禁止演绎4.0国际许可证

一、基本介绍

扩散模型大放异彩,从原理上搞清楚运作机制非常关键。下面我们定义一些符号:对于观测变量 x,与VAE对应一个隐变量 z不同,扩散模型对应一组隐变量 DTz={zt}t=1T, 而且我们假设DTz有马尔可夫性质。这样我们的模型就不是 p(x)=p(xz)p(z)dz而是

(1)p(x)=p(xDTz)dDTz(2)=p(xz1)p(z1z2)p(zt1zT)p(zT)dDTz(3)=p(xz1)p(zT)t=2Tp(zt1zt)dDTz

下面我们来解释一下这么做的动机:我们认为上帝是这么来生成一张图片的。先掷骰子(对,没错我们又假设上帝投骰子),然后慢慢画出图像(先色块,再形状,再细节,就像美术的厚涂法一样)。这对应着解码过程(decode),反过来就是编码过程(encode)。

(4)encode:xz1z2zT1zT=z(5)decode:z=zTzT1zT2z1x

也就是说有三个关键分布
1、编码分布 p(zx)
2、生成分布(解码分布) p(xz)
3、先验分布 p(z)

二、损失函数

考虑到编码分布的复杂性,我们一般使用易于计算的分布 q(zx)。我们使假设的模型分布尽可能与真实数据靠拢

(6)KL[q(x,DTz)|p(x,DTz)]=q(x,DTz)logq(x,DTz)p(x,DTz)dDTzdx(7)=q(DTzx)q(x)logq(DTzx)q(x)p(x,DTz)dDTzdx(8)=Exq(x)[q(DTzx)logq(x)dDTzq(DTzx)logp(x,DTz)q(DTzx)dDTz](9)=Hq(x)[x]+Exq(x)[q(DTzx)logp(x,DTz)q(DTzx)dDTz](10)Exq(x)[q(DTzx)logp(x,DTz)q(DTzx)dDTz]

f(θ)=q(DTzx)logp(x,DTz)q(DTzx)dDTz 我们有

(11)f(θ)=q(DTzx)logp(x,DTz)q(DTzx)dDTz(12)=q(DTzx)logp(xz1)p(zT)t=2Tp(zt1zt)q(z1x)t=2Tq(ztzt1)dDTz(13)=Eq(DTzx)[logp(xz1)p(zT)t=2Tp(zt1zt)q(z1x)t=2Tq(ztzt1)]

考虑到
(14)q(ztzt1)=q(ztzt1,x)=q(zt1zt,x)q(ztx)q(zt1x)

于是有
(15)f(θ)=Eq(DTzx)[logp(xz1)p(zT)t=2Tp(zt1zt)q(z1x)t=2Tq(ztzt1)](16)=Eq(DTzx)[logp(xz1)p(zT)t=2Tp(zt1zt)q(zt1x)q(z1x)t=2Tq(zt1zt,x)q(ztx)](17)=Eq(DTzx)[logp(xz1)+logp(zT)t=2Tq(zt1x)q(zTx)t=1T1q(ztx)+t=2Tlogp(zt1zt)q(zt1zt,x)](18)=Eq(DTzx)[logp(xz1)+logp(zT)q(zTx)+t=2Tlogp(zt1zt)q(zt1zt,x)](19)=KL[q(zTx)|p(zT))]LT+t=2TKL[q(zt1zt,x)|p(zt1zt)]L1:T1+Eq(DTzx)[logp(xz1)]L0

所以有损失函数:
(20)L=Exp(x)[KL[q(zTx)|p(zT))]+t=2TKL[q(zt1zt,x)|p(zt1zt)]+Eq(DTzx)[logp(xz1)]]

这里需要注意一点,我们是根据马尔可夫假设,推导出了损失函数。如果我们仔细观察损失函数的构成,其实我们很容易发现,马尔可夫假设不是必须的。这也为之后的模型提供了优化空间。

三、模型构建的细节

3.1、编码模型的实现

我们首先来考察一个损失函数 LT。所谓编码其实就是一个破坏的过程。我们希望通过往正常图片中逐步的加入噪声,来得到噪声。
(21)zt=αtzt1+βtϵt其中 ϵN(0,I), zt{x}{zt}t=1T={zt}t=0T
最终得到噪声
(22)p(zT)N(0,I)

这样有:

(23)q(ztx)=q(Dtzx)dDt=N(αtzt1,βt2I)
对于 q(zTx), 我们展开到 x就有

(24)zT=αTzT1+βTϵT=αT[αT1zT2+βT1ϵT1]+βTϵT(25)=αTαT1[αT2zT3+βT2ϵT2]+αTβT1ϵT1+βTϵT(26)=t=1Tαtx+[t=1T1βtτ=tTατ+1ϵt+βTϵT]
考察均值与方差
(27)E[zT]=t=1Tαtx(28)Cov[zT]=[t=1T1βt2τ=tTατ+12+βT2]I

我们希望最终的 zTN(0,I), 若 αt2+βt2=1 则有
(29)t=1T1βt2τ=tTατ+12+βT2=t=1T1(1αt2)τ=tTατ+12+βT2(30)=t=1T1[τ=tTατ+12τ=tTατ2]+βT2(31)=t=1T1[τ=tTατ+12τ=tTατ2]+βT2(32)=t=1Tαt2+αT2+βT2(33)=1t=1Tαt2

α¯T=t=1Tαt, β¯T=1t=1Tαt2, 这样就有:

(34)q(zTx)N(α¯Tx,β¯T2)=N(t=1Tαtx,[1t=1Tαt2]I)

我们只要设计合理的 αt使得 limT+t=1Tαt=0即可。这样无需参数,只通过一些必要的高斯假设就实现了我们的目的 KL[q(zTx)|p(zT))]=0。故我们损失的第一项是零。

3.2、生成模型的实现
3.2.1、L1:T1

考虑逆过程 q(zt1zt,x),根据贝叶斯定理有:
(35)q(zt1zt,x)=q(ztzt1,x)q(zt1x)q(ztx)(36)=N(αtzt1,βt2I)N(α¯t1x,β¯t12I)N(α¯tx,β¯t2I)(37)expi=1k[12[(ztαtzt1)2βt2+(zt1α¯t1x)2β¯t12(ztα¯tx)2β¯t2]](38)expi=1k[12[(αt2βt2+1β¯t12)zt122(αtβt2zt+α¯t1β¯t12x)zt1]]
根据高斯分布,我们配平方 ax22bx=a(xba)2+C可以得到:

均值有 (39)E[zt1zt,x]=(αtβt2zt+α¯t1β¯t12x)/(αt2βt2+1β¯t12)=β¯t12β¯t2αtzt+βt2β¯t2α¯t1x

方差有
(40)Cov[zt1zt,x]=1/(αt2βt2+1β¯t12)I=1/(αt2(1α¯t12)+βt2βt2(1α¯t12))I=1α¯t121α¯t2βt2=β¯t12β¯t2βt2I

也就是说:
(41)q(zt1zt,x)=N(β¯t12β¯t2αtzt+βt2β¯t2α¯t1x,β¯t12β¯t2βt2I)

在编码过程(变成噪声)中,我们有 q(ztx)N(α¯tx,β¯t2)=N(τ=1tατx,[1τ=1tατ2]I) 我们改变一下表现形式

(42)zt=α¯tx+β¯tϵx=1α¯t(ztβ¯tϵ)我们带入上式,进一步考察均值
(43)E[zt1zt,x]=β¯t12β¯t2αtzt+βt2β¯t2α¯t1x=β¯t12β¯t2αtzt+βt2β¯t2α¯t11α¯t(ztβ¯tϵ)(44)=[β¯t12β¯t2αt+βt2β¯t2α¯t1α¯t]ztβt2β¯t2α¯t1β¯tα¯tϵ(45)=[1α¯t12β¯t2αt+βt2β¯t21αt]ztβt2β¯t1αtϵ(46)=[αt2+βt2α¯t12αt2β¯t2αt]ztβt2β¯t1αtϵ(47)=[1α¯t2β¯t2αt]ztβt2β¯t1αtϵ(48)=1αtztβt2β¯t1αtϵ

我们知道 KL[N(μ1,σ12)|N(μ2,σ22)]=(μ1μ2)2+σ122σ22+logσ2σ112,同时 p(zt1zt)=N(μθ[zt,t],Σθ[zt,t]) 是我们真实的生成分布, 其中 μθ[zt,t] 是我们要学习的神经网络,用神经网络原因很简单图像数据太复杂。那么还剩下 Σθ[zt,t],我们注意到:

(49)q(zt1zt,x)=N(β¯t12β¯t2αtzt+βt2β¯t2α¯t1x,β¯t12β¯t2βt2I)

注意到损失函数 KL[q(zt1zt,x)|p(zt1zt)],使得最小化的最简单的方法就是令 Σθ[zt,t]=σt2I, σt2=β¯t12β¯t2βt2(论文作者是这么干的)。 σt2 其实可以取其他的值,在这之后的模型中提出的各种方案。那么有

(50)t=2TEx[KL[q(zt1zt,x)|p(zt1zt)]]=t=2TEt,x,ϵ[12σt2|1αtztβt2β¯t1αtϵμθ[zt,t]|2]+Const

变形神经网络形式如下:μθ[zt,t]=[1αtztβt2β¯t1αtϵθ[zt,t]],那么可以简化得到:

(51)L1:T1=t=2TEt,x,ϵ[KL[q(zt1zt,x)|p(zt1zt)]]=t=2TEt,x,ϵ[βt42σt2αt2β¯t2|ϵϵθ[zt,t]|2]+Const

3.2.2、L0

接下来我们考察最后一项 L0=Eq(DTzx)[logp(xz1)]

(52)L0=Ex[Eq(DTzx)[logp(xz1)]]=Ex[q(DTzx)logp(xz1)dDTz](53)=Ex[q(z1x)logp(xz1)dz1]

由于 z1=α1x+β1ϵ1p(xz1)=N(μθ[z1,1],Σθ[z1,1]=σ12),有 μθ[z1,1]=[1α1z1β12β¯11α1ϵθ[z1,1]]于是有:

(54)L0=Et,x,ϵ1[12σ12|xμθ[z1,1]|2]+Const(55)=Et,x,ϵ1[12σ12|x1α1[α1x+β1ϵ1]+β12β¯11α1ϵθ[z1,1]|2]+Const(56)=Et,x,ϵ1[12σ12|β1α1ϵ1+β12β¯11α1ϵθ[z1,1]|2]+Const(57)=Et,x,ϵ1[β142σ12α12β¯12|ϵ1ϵθ[α¯1x+β¯1ϵ1,1]|2]+Const

3.2.3、L0:T1

对于 L0:T1,令 Lt1=Et,x,ϵ[βt42σt2αt2β¯t2|ϵϵθ[zt,t]|2]+Const, 我们去掉与参数无关的权重,简化得到L0:T1的统一表达式:

(58)Lt1=Et,x,ϵ[|ϵϵθ[α¯tx+β¯tϵ,t]|2]+Const

于是有损失函数:

(59)L=t=1TEt,x,ϵ[βt42σt2αt2β¯t2|ϵϵθ[zt,t]|2]+Const

这样,我们有算法:
(60)Algorithm 1 Training1:repeat2:x0q(x0)3:tUniform({1,,T})4:ϵN(0,I)5:Take gradient descent step onθ|ϵϵθ[α¯tx+β¯tϵ,t]|26:until converged

对于采样生成有:
(61)Algorithm 2 Sampling1:zTN(0,I)2:fort=T,,1do3:ξN(0,I) if t>1, else ξ=04:zt1=1αtztβt2β¯t1αtϵθ[zt,t]+σtξ5:end for6:return x=z0

四、评述

1、推理中的一些关键动机,视乎没有加以说明。一些想法视乎是从天而降。 αt2+βt2=1似乎还能说的过去。σt2=β¯t12β¯t2βt2似乎就动机不足了,毕竟有那么多可以取的值,不一定非要这个。

2、在逆向生成中E[zt1zt,x], 做变形zt=α¯tx+β¯tϵx=1α¯t(ztβ¯tϵ)带入。就是为了消除 x,从而使得 ztzt1不依赖 x。而这一步推导DDPM依赖了马尔可夫假设,显然限制了分布空间。如果我们去掉马尔可夫假设,我们可以得到更大自由度。

3、最终的结果是简单,一定还有更加显然的推理方式,论文DDPM的推理总觉得还是太过迂回,与不显然。接下我们将使用宋大神SDE的思想来解读这个模型。

4、DDPM以后有一系列的改进,也将一一解读,敬请关注。

Ho, J., Jain, A., & Abbeel, P. (2020, December 16). Denoising Diffusion Probabilistic Models. arXiv. http://arxiv.org/abs/2006.11239. Accessed 22 March 2023
Kingma, D. P., & Welling, M. (2022, December 10). Auto-Encoding Variational Bayes. arXiv. http://arxiv.org/abs/1312.6114. Accessed 18 April 2023


版权声明
引线小白创作并维护的柠檬CC博客采用署名-非商业-禁止演绎4.0国际许可证。
本文首发于柠檬CC [ https://www.limoncc.com ] , 版权所有、侵权必究。
本文永久链接httpss://www.limoncc.com/post/1c60669bbe56769f/
如果您需要引用本文,请参考:
引线小白. (May. 16, 2023). 《扩散模型研究一:去噪扩散概率模型DDPM》[Blog post]. Retrieved from https://www.limoncc.com/post/1c60669bbe56769f
@online{limoncc-1c60669bbe56769f,
title={扩散模型研究一:去噪扩散概率模型DDPM},
author={引线小白},
year={2023},
month={May},
date={16},
url={\url{https://www.limoncc.com/post/1c60669bbe56769f}},
}

来发评论吧~
Powered By Valine
v1.5.2
'