生成对抗网络由两个网络:生成器(generator)和鉴别器(discriminator)组成。它的想法很简单:生成器不断生成样本出来,鉴别器被训练为去鉴别一张样本是生成器生成出来的,还是原始数据集中的样本。

  • 生成网络 GG:去学习生成数据的分布 pgp_g 通过数据集 xx。我们假设需要的生成的数据和受到一些变量 zz 的控制,模型将 zz 映为生成的数据 G(z;θg)G(z; \theta_g),其中 θg\theta_g 为生成模型的参数
  • 鉴别网络 DD:去辨别一个数据 xx 是来自原始数据还是 pgp_g,需要生成一个 [0,1][0, 1] 中的标量,表示一个数据属于原始数据的概率。

我们去训练鉴别网络去提高正确分类的概率,同时训练生成网络去最小化 log(1D(G(z)))\log(1 - D(G(z))),即让生成的数据更难以被鉴别网络分辨。

定义价值 V(D,G)V(D, G) 为:

V(D,G)=Expdata(x)[logD(x)]+Expz(z)[log(1D(G(z)))]V(D, G) = \mathbb E_{x\sim p_{data}(x)}[\log D(x)] + \mathbb E_{x\sim p_z(z)}[\log (1 - D(G(z)))]

接着类似两个人在博弈,DD 想要最大化 VVGG 想要最小化 VV

训练的过程

对于每个训练的 epoch:

  • 首先训练鉴别网络,用生成一些噪音 z(1),,z(m)z^{(1)}, \cdots, z^{(m)} 以及对数据集进行采样 x(1),,x(m)x^{(1)}, \cdots, x^{(m)},然后去最大化

    L1=1mi=1m[logD(x(i))+log(1D(G(z(i))))]L_1 = \frac{1}{m} \sum_{i=1}^m [\log D(x^{(i)}) + \log (1 - D(G(z^{(i)})))]

    对鉴别网络进行更新,重复这个过程 kk

  • 然后训练生成网络,随机生成一些 zz,对于每个小批次,利用损失

    L2=1mi=1mlog(1D(G(z(i))))L_2 = \frac{1}{m} \sum_{i = 1}^m \log (1 - D(G(z^{(i)})))

    对生成网络进行更新

主要代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
def train_epoch(bar, D, G, optimizer_D, optimizer_G, data_iter, G_input_size, device, current_loss) :
//......
for X in data_iter :

batch_size = X.shape[0]

# update the discriminator using the dataset
D.train()
D.requires_grad_(True)
optimizer_G.zero_grad()
optimizer_D.zero_grad()
X = X.to(device)
y = D(X)
y = torch.clamp(y, 1e-3, 1)
l = -torch.log(y).mean()
l.backward()
grad_clipping(D, 1)
optimizer_D.step()

//......

# update the discriminator using the generated data
G.eval()
optimizer_G.zero_grad()
optimizer_D.zero_grad()
G.requires_grad_(False)
Z = torch.randn(batch_size, G_input_size)

y = D(G(Z.to(device)))
y = torch.clamp(y, 0, 1 - 1e-3)
l = -torch.log(1 - y).mean()
l.backward()
grad_clipping(D, 1)
optimizer_D.step()

//......

# updated the generator
D.eval()
G.train()
l = 0

train_G_flg = False

# The expected loss_D is log(4).
while (tmp_loss < 1.75 and rest_K <= 0) and abs(float(l)) < 1e-1 :
train_G_flg = True
G.requires_grad_(True)
optimizer_G.zero_grad()
optimizer_D.zero_grad()
D.requires_grad_(False)
Z = torch.randn(batch_size, G_input_size)
y = D(G(Z.to(device)))
y = torch.clamp(y, 0, 1 - 1e-3)
l = torch.log(1 - y).mean()
l.backward()
grad_clipping(G, 1)
optimizer_G.step()

if train_G_flg :
# when l = -0.5, rest_K = 4
rest_K = int(1 + abs(float(l)) * 6)

//......

return sum_loss_D / num_tests_D, sum_loss_G / num_tests_G

复现

在 MNIST 数据集上训练 GAN 网络。其中辨别器由 2 个线性层和 1 个 LeakyReLU 激活层组成,生成器由 3 个线性层和 2 个ReLU 激活层组成。在 lr = 1e-4 下,训练 120 个周期后效果如下:

实际训练时做了一些略微的修改:

  • 当生成器的损失的绝对值小于 0.1 的时候继续训练生成器

踩坑

  • 计算损失的时候需要对 DD 的结果进行裁剪,防止取对数的时候挂掉。
  • GAN 比较容易炸梯度(梯度变成 inf),需要训练的时候梯度裁剪。
  • 如果 DD 训练较快,loss 变成 0 了,后面 loss 保持为 0,训练可能直接挂了。
  • lr 设置过小会导致 L1,L2L_1, L_2 的绝对值偏高,但实际效果比较差。需要结合模型具体的运行情况来确定 lr。
  • 注意 net.eval() 调用后仍然会去计算梯度,需要通过 net.requires_grad_(False) 去禁用计算梯度。如果没有禁用梯度计算,需要在每次计算开始时,清空两个网络的梯度。