0%

花样繁多的GAN

在做对抗样本的时候,我发现对抗防御和 GAN 在某种程度上很像:

  • 对抗防御:内部生成对抗样本攻击分类器,分类器更新参数防御攻击;
  • GAN:生成器 $G$ 生成样本欺骗判别器,判别器更新防御 $G$ 的欺骗。

所以说,这俩在某种程度上真的很像,所以决定整理一下 GAN 的知识,开拓一下思路,视频内容来自这里。警告:里面有大量数学公式的推导,而我就不一样了,不仅有公式推导还有代码。

前情提要

我发现我在看论文或者别人的博客的时候,经常发现数据分布、概率分布等,然而时间长了竟然不知道所谓的分布是什么,这里来提前说一下:

  • 数据分布:可以理解为字面意思,数据的分布。比如一个数据集:1,2,3,只有三条数据,可能是 1,可能是 2,可能是 3,这就是数据的分布;
  • 概率分布:可以理解为字面意思,概率的分布。比如一枚硬币,正面朝上的概率是 50%,反面朝上的概率也是 50%,这就是概率分布,类似的,正态分布、泊松分布等都是概率分布。

朴素 GAN

以图像为例,GAN 的生成器 $G$ 输入为某种随机分布的数据,输出为一张图像;而判别器 $D$ 的输入是一张图像,输出是一个数字,数字越大表示当前图像的质量越高,我们设置 $G$ 生成的图片是假的。两者相互博弈,$G$ 企图去欺骗 $D$,$D$ 防止来自 $G$ 的欺骗。

  • $G$ 只适合生成,不适合做判别,因为 $G$ 的主要任务是生成图像。如果 $G$ 要判别图像的好坏,如何定义生成图像的好坏?生成图像的每个像素点的位置、偏移、颜色、旋转,以及像素点之间的联系将会很复杂,模型将会很复杂,所以由 $D$ 来做判别;
  • $D$ 只适合判别,不适合生成。$D$ 的判别是宏观角度的判断当前图像好不好,并不会纠结于某个像素的颜色、旋转、偏移的好坏。如果 $D$ 要做生成,那么需要遍历所有数据 $x$,来看看哪个数据使得网络的输出得分最高。这样虽然可以生成,但枚举所有数据是不可能的操作。

或者说,这是一种另类的小样本,网络不可能见过所有生成的数据,但是要求 $D$ 能判断没见过的数据是好是坏。此外,足够强的 $D$ 才能迫使 $G$ 生成的图片足够逼真,足够好的 $G$ 才会使 $D$ 的判别越来越准。 在接下来的实验中将见识到这一点。

如上图所示:横坐标为 $G$ 和 $D$ 的演化过程,纵坐标表示概率分布,红色曲线表示 $G$,红色的点表示 $G$ 生成的数据,绿色的线表示 $D$,绿色的点表示真实数据。

  1. 首先,图的最左侧,$G$ 生成的数据和真是数据不符,那么判别器抑制 $G$ 把数据生成到其他分布,并希望 $G$ 生成的数据分布和真实数据的分布较为接近;
  2. 其次,图的中间,因为 $G$ 的目的是无限度找 $D$ 的漏洞,虽然相比左图,生成数据的分布和真实数据的分布接近了,但是把更多的数据生成到了另外的分布,这也是不行的;
  3. 经过多次的博弈,最终达到右侧图的效果。

公式推导

$G$ 的目标

我们的目标可以用公式表述为:通过 $G$,寻找数据 $x$ 的分布 $P_G(x)$,但是目前只有数据 $x$,也就是知道近似的 $P(x)$。可以通过极大似然估计来求解问题,使 $P_G(x)$ 逐渐逼近 $P(x)$,即在 $P(x)$ 中采 $m$ 个样本,使得 $\prod_{i=1}^mP_G(x_i;\theta)$ 的概率最大,逐步采样与迭代求出模型的参数 $\theta$。

\begin{aligned}
\theta &= \arg \max_\theta \prod_{i=1}^mP_G(x_i;\theta) \\
&\Leftrightarrow \arg \max_\theta \log \prod_{i=1}^mP_G(x_i;\theta) \\
&= \arg \max_\theta \sum_{i=1}^m \log P_G(x_i;\theta) \\
&\Leftrightarrow \arg \max_\theta \mathbb{E}_{x\sim p(x)} \bigl[ \log P_G(x;\theta) \bigr] \\
&\Leftrightarrow \arg \max_\theta \int_x p(x) \log P_G(x;\theta) dx - \int_x p(x) \log P(x) dx \\
&= \arg \min_\theta KL(P(x) || P_G(x;\theta))
\end{aligned}

也就是说,生成模型的求出的 $P_G(x;\theta)$ 最好情况就是与真是数据的分布 $P(x)$ 的 $KL$ 散度距离最小。当然你也可以用其他距离,详情可以参考 f-GAN 1 这篇论文,不过究其到底,都可以描述为寻找一个 $P_G(x;\theta)$,使其生成的数据分布和 $P(x)$ 最为接近, $G^\star=\arg\min_G D(P_G(x;\theta), P(x))$。

当然在模型参数过多的情况下,极大似然估计求解困难,可以使用神经网络的反向传播,来求解出模型的参数。

$D$ 的目标

对于 $D$ 而言,其目标是期望真实数据的评分很高,$G$ 生成数据的评分很低,以此来鼓励 $G$ 生成更加逼真的数据达到内卷的目的。那么 $D$ 的目标可以描述为:

\begin{equation}\label{D}
\max \mathbb{E}_{x\sim P(x)} \bigl[ \log D(x) \bigr] + \mathbb{E}_{x\sim P_G(x;\theta)} \bigl[ \log(1-D(x)) \bigr]
\end{equation}

此时固定 $G$,训练 $D$,式 $\eqref{D}$ 继续推导:

\begin{aligned}
\eqref{D} &= \int_x P(x) \log (D(x)) dx+ \int_x P_G(x) \log (1 - D(x)) dx \\
&= \int_x P(x) \log (D(x)) + P_G(x) \log (1 - D(x)) dx\\
\end{aligned}

为了使积分最大,可以等价转换为使积分内部的元素取值最大,即对于以下公式求对 $D(x)$ 微分,微分等于0的时候,取得极值,且下述公式是极大值:

\begin{equation}
f(D(x)) = P(x) \log (D(x)) + P_G(x) \log (1 - D(x))
\end{equation}

此时求出的 $D(x)$:

\begin{equation}
D(x) = \frac{P(x)}{P(x)+P_G(x)}
\end{equation}

把 $D(x)$ 带回式 $\eqref{D}$ :

\begin{aligned}
\eqref{D} &= \mathbb{E}_{x\sim P(x)} \bigl[ \log \frac{P(x)}{P(x)+P_G(x)} \bigr] + \mathbb{E}_{x\sim P_G(x)} \bigl[ \log \frac{P_G(x)}{P(x)+P_G(x)} \bigr] \\
&= \int_x P(x)\log \frac{P(x)}{P(x)+P_G(x)} dx + \int_x P_G(x)\log \frac{P_G(x)}{P(x)+P_G(x)} dx \\
&= -2\log 2 + KL(P(x) || \frac{P(x) + P_G(x)}{2}) + KL(P_G(x) || \frac{P(x) + P_G(x)}{2}) \\
&= -2\log 2 + 2 JS(P(x) || P_G(x))
\end{aligned}

而 JS 散度的取值范围是 $[0, \log2]$,因此 式 $\eqref{D}$ 的最大值为 0。而这里的代码实现可以使用 torch.BCE,但要按照负梯度进行反向传播,所以损失那里直接加了负数,因此代码的损失是正数不要感到意外。

code

既然有了模型结构和公式推导,那么来看一下代码改怎么写,以 MNIST 数据集为例,全部代码在 github

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42

class Generator(nn.Module):
# g_input_dim = 100, g_output_dim=784
def __init__(self, g_input_dim, g_output_dim):
super(Generator, self).__init__()
self.fc1 = nn.Linear(g_input_dim, 256)
self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features*2)
self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features*2)
self.fc4 = nn.Linear(self.fc3.out_features, g_output_dim)

def forward(self, x):
x = F.leaky_relu(self.fc1(x), 0.2)
x = F.leaky_relu(self.fc2(x), 0.2)
x = F.leaky_relu(self.fc3(x), 0.2)
return torch.tanh(self.fc4(x))

class Discriminator(nn.Module):
# d_input_dim = 784
def __init__(self, d_input_dim):
super(Discriminator, self).__init__()
self.fc1 = nn.Linear(d_input_dim, 1024)
self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
self.fc4 = nn.Linear(self.fc3.out_features, 1)

def forward(self, x):
x = F.leaky_relu(self.fc1(x), 0.2)
x = F.dropout(x, 0.3)
x = F.leaky_relu(self.fc2(x), 0.2)
x = F.dropout(x, 0.3)
x = F.leaky_relu(self.fc3(x), 0.2)
x = F.dropout(x, 0.3)
return torch.sigmoid(self.fc4(x))

n_epoch = 200
for epoch in range(1, n_epoch+1):
D_losses, G_losses = [], []
for batch_idx, (x, _) in enumerate(train_loader):
for i in range(2):
D_train(x)
D_losses.append(D_train(x))
G_losses.append(G_train(x))

$D$ 为什么要比 $G$ 多训练两次?因为只有去够好的 $D$,此案能使得 $G$ 也好,$G$ 好了 $D$ 也会好;否则 $D$ 很差,$G$ 也懒得更新,生成的结果会很差。此外,如果两者都只训练一次,我发现 $D$ 的损失是逐步上升的。如下图所示,右侧是 $D$ 训练三次的,左侧是 $D$ 训练一次的效果,明显发现右侧的效果要好一些。

WGAN-GP

前言

通过前文我们可以了解到,如果 $D$ 训练得太弱,指示作用不显著,则 $G$ 不能进行有效的学习;但是,如果 $D$ 训练得太好, $G$ 就无法得到足够的梯度继续优化,这样一来, $D$ 的训练火候就非常难把控,这就是 GAN 训练难的根源。如下图所示:

  • 左侧的 $D$ 很弱,$G$ 也不知道哪边好哪边弱;
  • 右侧的 $D$ 很强,存在梯度消失和梯度爆炸问题,即在区域 $G$ 获取到的梯度近似为0而无法获得有效的更新。

说点通俗的话,你在学习,你的老师太弱,你学不出来;你的老师太强,无论你怎么学老师都说你是错的,即使进步是有效的,但还是被你的老师说没用,那么之前的进步也可能会回退,你还是学不出来。

朴素 GAN 的缺陷

虽然有着完备的理论推导,但是朴素 GAN 仍有以下的缺陷:

  • 对于 $D$ 的目标,经过推导会发现竟然要最大化 $P(x)$ 和 $P_G(x;\theta)$ 的距离,显然违背了 $D$ 的初衷,或者说,不是这样判别的;关于 $D$ 计算距离的缺陷,这里 3 有详细公式推导。
  • 对于 $G$ 的目标,如果一味的最小化 $P_G(x;\theta)$ 和 $P(x)$ 的距离,会使生成的数据只有安全性,没有多样性,这显然也是不好的。

Wasserstein 距离

于是,如何巧妙的衡量生成分布与真实分布之间的距离,WGAN 定义了 Wasserstein 距离。即:

\begin{equation}
W(P(x), P_G(x;\theta))=\inf_{\gamma\sim\prod(P(x), P_G(x;\theta))} \mathbb{E}_{(x,y)\sim \gamma} \bigl[ \Vert| x-y \Vert \bigr]
\end{equation}

$\prod(P(x), P_G(x;\theta))$ 是 $P(x)$ 和 $P_G(x;\theta)$ 联合分布的集合,对于每一个联合分布 $\gamma$ 从中取样,得到一个真实样本 $x$ 和一个虚假样本 $y$,并计算联合分布 $\gamma$ 中两两样本的距离 $\Vert x-y \Vert$,这也就是 Wasserstein 距离。意思是,从 $y$ 移动到 $x$ 需要多远,即便两个分布没有重叠,Wasserstein 距离仍然能够反映它们的远近。

既然如此,那么就把 Wasserstein 距离用到 GAN 中,由于 $\inf$ 是最大下确界,所以:

\begin{aligned}
W(P(x), P_G(x;\theta)) & \leq \int_{P(x), P_G(x;\theta)} \Vert x-y \Vert d\gamma \\
{ }& = \mathbb{E}_{(x,y)\sim \gamma} \bigl[ \Vert x-y \Vert \bigr] \\
&= \mathbb{E}_{x\sim P(x)} D(x) - \mathbb{E}_{x\sim P_G(x;\theta)} D(G(x))
\end{aligned}

至此,真实分布与生成分布之间的 Wasserstein 距离融入到了 GAN 中。朴素 GAN 的 $D$ 做的是真假二分类任务,所以最后一层是 sigmoid,但是现在 WGAN 中的 $D$ 做的是近似拟合 Wasserstein 距离,属于回归任务,所以要把 $D$ 最后一层的 sigmoid 拿掉。

接下来 $G$ 要最小化 Wasserstein 距离,$D$ 要最大化 Wasserstein 距离的优良性质,也不需要担心 $D$ 导致的梯度消失的问题,这样就得到了WGAN的两个loss,$G$ 的 loss 是 $- \mathbb{E}_{x\sim P_G(x;\theta)} D(G(x))$,$D$ 的 loss 是 $-\mathbb{E}_{x\sim P(x)} D(x) + \mathbb{E}_{x\sim P_G(x;\theta)} D(G(x))$ (距离最小,取负号)。

为了使距离不在梯度消失问题和梯度爆炸,后续有论文进行了改进 2,并不是直接对 $D$ 的权重进行裁剪限制在某个范围内,而是加入了惩罚项,因此 $D$ 的损失函数变为:

\begin{equation}
-\mathbb{E}_{x\sim P(x)} D(x) + \mathbb{E}_{x\sim P_G(x;\theta)} D(G(x)) + \lambda \mathbb{E}_{\hat{x}\sim P(\hat{x})} \bigl( \Vert \nabla_\hat{x} D(\hat{x}) \Vert_2 - 1 \bigr)^2
\end{equation}

$\hat{x}$ 是真实数据和虚假数据的随机采样,也就是说,希望 $D$ 对 $\hat{x}$ 数据的梯度保持在都是 1 左右,这样就不用再考虑梯度消失问题和梯度爆炸问题,虚假数据也好往真实数据去移动。最终的算法如下所示:

WGAN 有以下优势:

  • 不再需要纠结如何平衡 $G$ 和 $D$ 的训练程度,大大提高了GAN 训练的稳定性,$D$ 训练得越好,对提升 $G$ 就越有利。
  • 即使网络结构设计得比较简陋,WGAN 也能展现出良好的性能,包括避免了样本不够多样性的现象,体现了出色的鲁棒性。
  • $D$ 的 loss 很准确地反映了 $G$ 生成样本的质量,因此可以作为展现 GAN 训练进度的定性指标。

code

完整代码在 github,这里只展示计算惩罚项的部分:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
def compute_gradient_penalty(D, real_samples, fake_samples):

# [0, 1] 之间的随机数
alpha = torch.rand(1).to(device)

# 生成 \hat{x}
x = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True).to(device)

# D 对生成数据的梯度控制为 1 左右,所以计算梯度
interpolates = D(x)

fake = torch.ones(bs, 1).to(device)

# 计算 D(x) 对 interpolates 的梯度
gradients = torch.autograd.grad(
outputs=interpolates, # 用来求导的
inputs=x, # 被求导的梯度值
grad_outputs=fake, # 求梯度时对输出的权重
create_graph=True, # 创建计算图
)[0]

gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
return gradient_penalty

可能是我网络设置的比较简单,效果并不是很好:

After 100 epochs
After 200 epochs
After 300 epochs
After 400 epochs
After 500 epochs

结语

本文涉及了大量的公式推导,从朴素的 GAN 推到了 WGAN-GP。此外,还有啥 conditional GAN,infoGAN,BigGAN,cycleGAN,XGAN 等等等等,不过那些东西脱离了我要做的内容,所以,等哪天空闲了,会简单的整理这些模型的框架和 idea,如果不忙会尝试复现一个 photo shop。

reference

感谢上学期间打赏我的朋友们。赛博乞讨:我,秦始皇,打钱。

欢迎订阅我的文章