深度学习之生成对抗网络(2)GAN原理
1. 网络结构生成网络G(z)\text{G}(\boldsymbol z)G(z)判别网络D(x)\text{D}(\boldsymbol x)D(x)2. 网络训练3. 统一目标函数现在我们来正式介绍生成对抗网络的网络结构和训练方法。
1. 网络结构
&esmp;生成对抗网络包含了两个子网络:生成网络(Generator,简称G)和判别网络(Discriminator,简称D),其中生成网络G负责学习样本的真实分布,判别网络D负责将生成网络采样的样本与真实样本区分开来。
生成网络G(z)\text{G}(\boldsymbol z)G(z)
生成网络G和自编码器的Decoder功能类似,从先验分布pz(⋅)p_\boldsymbol z (\cdot)pz(⋅)中采样隐藏变量z∼pz(⋅)\boldsymbol z\sim p_\boldsymbol z (\cdot)z∼pz(⋅),通过生成网络G参数化的pg(x∣z)p_g (\boldsymbol x|\boldsymbol z)pg(x∣z)分布,获得生成样本x∼pg(x∣z)\boldsymbol x\sim p_g (\boldsymbol x|\boldsymbol z)x∼pg(x∣z),如下图所示。其中隐藏变量z\boldsymbol zz的先验分布pz(⋅)p_\boldsymbol z (\cdot)pz(⋅)可以假设为某种已知的分布,比如多源均匀分布z∼Uniform(−1,1)\boldsymbol z\sim\text{Uniform}(-1,1)z∼Uniform(−1,1)。
生成网络G
pg(x∣z)p_g (\boldsymbol x|\boldsymbol z)pg(x∣z)可以用深度神经网络来参数化,如下图所示,从均匀分布pz(⋅)p_\boldsymbol z (\cdot)pz(⋅)中采样出隐藏变量z\boldsymbol zz,经过多层转置卷积层为了参数化的pg(x∣z)p_g (\boldsymbol x|\boldsymbol z)pg(x∣z)分布中采样出样本xf\boldsymbol x_fxf。从输入输出层面来看,生成器G的功能是将隐向量z\boldsymbol zz通过神经网络转换为样本向量xf\boldsymbol x_fxf,下标fff代表假采样(Fake samples)。
转置卷积构成的生成网络
判别网络D(x)\text{D}(\boldsymbol x)D(x)
判别网络和普通的二分类网络功能类似,它接受输入样本x\boldsymbol xx的数据集,包含了采样自真实数据分布pr(⋅)p_r (\cdot)pr(⋅)的样本xr∼pr(⋅)\boldsymbol x_r\sim p_r (\cdot)xr∼pr(⋅),也包含了采样自生成网络的假样本xf∼pg(x∣z)\boldsymbol x_f\sim p_g (\boldsymbol x|\boldsymbol z)xf∼pg(x∣z),xr\boldsymbol x_rxr和xf\boldsymbol x_fxf共同组成了判别网络的训练数据集。判别网络输出为x\boldsymbol xx属于真实样本的概率P(x为真∣x)P(\boldsymbol x为真|\boldsymbol x)P(x为真∣x),我们把所有真实样本xr\boldsymbol x_rxr的标签标注为真(1),所有生成网络产生的样本xf\boldsymbol x_fxf标注为假(0),通过最小化判别网络D的预测值与标签之间的误差来优化判别网络参数,如下图所示:
生成网络和判别网络
2. 网络训练
GAN博弈学习的思想体现在它的训练方式上,由于生成器G和判别器D的优化目标不一样,不能和之前的网络模型的训练一样,只采用一个损失函数。下面我们来分别介绍如何训练生成器G和判别器D。
对于判别网络D,它的目标是能够很好地分辨出真样本xr\boldsymbol x_rxr和假样本xf\boldsymbol x_fxf。以图片生成为例,它的目标是最小化图片的预测值和真实值之间的交叉熵损失函数:
minθL=CE(Dθ(xr),yr,Dθ(xf),yf)\underset{θ}{\text{min}}\mathcal L=\text{CE}(D_θ (\boldsymbol x_r ),y_r,D_θ (\boldsymbol x_f ),y_f )θminL=CE(Dθ(xr),yr,Dθ(xf),yf)
其中Dθ(xr)D_θ (\boldsymbol x_r )Dθ(xr)代表真实样本xr\boldsymbol x_rxr在判别网络DθD_θDθ的输出,θθθ为判别网络的参数集,Dθ(xf)D_θ (\boldsymbol x_f )Dθ(xf)为生成样本xf\boldsymbol x_fxf在判别网络的输出,yry_ryr为xr\boldsymbol x_rxr的标签,由于真实样本标注为真,故yr=1y_r=1yr=1,yfy_fyf为生成样本xf\boldsymbol x_fxf的标签,由于生成样本标注为假,故yf=0y_f=0yf=0。CE\text{CE}CE函数代表交叉熵损失函数CrossEntropy。二分类问题的交叉熵损失函数定义为:
L=−∑xr∼pr(⋅)logDθ(xr)−∑xf∼pg(⋅)log(1−Dθ(xf))\mathcal L=-\sum_{\boldsymbol x_r\sim p_r (\cdot)}\text{log}D_θ (\boldsymbol x_r ) -\sum_{\boldsymbol x_f\sim p_g (\cdot)}\text{log}(1-D_θ (\boldsymbol x_f ))L=−xr∼pr(⋅)∑logDθ(xr)−xf∼pg(⋅)∑log(1−Dθ(xf))
因此判别网络D的优化目标是:
θ∗=argmaxθ−∑xr∼pr(⋅)logDθ(xr)−∑xf∼pg(⋅)log(1−Dθ(xf))θ^*=\underset{θ}{\text{argmax}}-\sum_{\boldsymbol x_r\sim p_r (\cdot)}\text{log}D_θ (\boldsymbol x_r ) -\sum_{\boldsymbol x_f\sim p_g (\cdot)}\text{log}(1-D_θ (\boldsymbol x_f ))θ∗=θargmax−xr∼pr(⋅)∑logDθ(xr)−xf∼pg(⋅)∑log(1−Dθ(xf))
把minθL\underset{θ}{\text{min}}\mathcal LθminL问题转换为maxθ−L\underset{θ}{\text{max}}-\mathcal Lθmax−L,并写成期望形式:
θ∗=argmaxθExr∼pr(⋅)logDθ(xr)+Exf∼pg(⋅)log(1−Dθ(xf))θ^*=\underset{θ}{\text{argmax}} \mathbb E_{\boldsymbol x_r\sim p_r (\cdot) } \text{log}D_θ (\boldsymbol x_r )+\mathbb E_{\boldsymbol x_f\sim p_g (\cdot) }\text{log}(1-D_θ (\boldsymbol x_f ))θ∗=θargmaxExr∼pr(⋅)logDθ(xr)+Exf∼pg(⋅)log(1−Dθ(xf))
对于生成网络G(z)\text{G}(\boldsymbol z)G(z),我们希望xf=G(z)\boldsymbol x_f=\text{G}(\boldsymbol z)xf=G(z)能够很好地骗过判别网络D,假样本xf\boldsymbol x_fxf在判别网络的输出越接近真实的标签越好。也就是说,在训练生成网络时,希望判别网络的输出D(G(z))\text{D}(\text{G}(\boldsymbol z))D(G(z))越逼近1越好,最小化D(G(z))\text{D}(\text{G}(\boldsymbol z))D(G(z))与111之间的交叉熵损失函数:
minϕL=CE(D(G(z)),1)=−logGϕ(z)\underset{ϕ}{\text{min}}\mathcal L=\text{CE}(\text{D}(\text{G}(\boldsymbol z)),1)=-\text{log}\text{G}_ϕ (\boldsymbol z)ϕminL=CE(D(G(z)),1)=−logGϕ(z)
把minϕL\underset{ϕ}{\text{min}}\mathcal LϕminL问题转换成maxϕ−L\underset{ϕ}{\text{max}}-\mathcal Lϕmax−L,并写成期望形式:
ϕ∗=argmaxϕEz∼pz(⋅)logD(Gϕ(z))ϕ^*=\underset{ϕ}{\text{argmax}}\mathbb E_{\boldsymbol z\sim p_z (\cdot)} \text{log}\text{D}(\text{G}_ϕ (\boldsymbol z))ϕ∗=ϕargmaxEz∼pz(⋅)logD(Gϕ(z))
再次等价转化为:
ϕ∗=argminϕL=argmaxϕEz∼pz(⋅)log[1−D(Gϕ(z))]ϕ^*=\underset{ϕ}{\text{argmin}}\mathcal L=\underset{ϕ}{\text{argmax}}\mathbb E_{\boldsymbol z\sim p_z (\cdot)} \text{log}[1-\text{D}(\text{G}_ϕ (\boldsymbol z))]ϕ∗=ϕargminL=ϕargmaxEz∼pz(⋅)log[1−D(Gϕ(z))]
其中ϕϕϕ为生成网络G的参数集,可以利用梯度下降算法来优化参数ϕϕϕ。
3. 统一目标函数
我们把判别网络的目标和生成网络的目标合并,写成min−max\text{min}-\text{max}min−max博弈形式:
minϕmaxϕL(D,G)=Exr∼pr(⋅)logDθ(xr)+Exf∼pg(⋅)log(1−Dθ(xf))=Ex∼pr(⋅)logDθ(x)+Ez∼pz(⋅)log(1−Dθ(Gϕ(z)))\begin{aligned}\underset{ϕ}{\text{min}} \ \underset{ϕ}{\text{max}}\mathcal L(\text{D},\text{G})&=\mathbb E_{\boldsymbol x_r\sim p_r (\cdot) } \text{log}D_θ (\boldsymbol x_r )+\mathbb E_{\boldsymbol x_f\sim p_g (\cdot) } \text{log}(1-D_θ (\boldsymbol x_f ))\\ &=\mathbb E_{\boldsymbol x\sim p_r (\cdot) } \text{log}D_θ (\boldsymbol x)+\mathbb E_{\boldsymbol z\sim p_z (\cdot)} \text{log}(1-D_θ (G_ϕ (\boldsymbol z)))\end{aligned}ϕminϕmaxL(D,G)=Exr∼pr(⋅)logDθ(xr)+Exf∼pg(⋅)log(1−Dθ(xf))=Ex∼pr(⋅)logDθ(x)+Ez∼pz(⋅)log(1−Dθ(Gϕ(z)))
算法流程如下:
如果觉得《深度学习之生成对抗网络(2)GAN原理》对你有帮助,请点赞、收藏,并留下你的观点哦!