生成对抗网络由两个网络:生成器(generator)和鉴别器(discriminator)组成。它的想法很简单:生成器不断生成样本出来,鉴别器被训练为去鉴别一张样本是生成器生成出来的,还是原始数据集中的样本。
- 生成网络 G:去学习生成数据的分布 pg 通过数据集 x。我们假设需要的生成的数据和受到一些变量 z 的控制,模型将 z 映为生成的数据 G(z;θg),其中 θg 为生成模型的参数
- 鉴别网络 D:去辨别一个数据 x 是来自原始数据还是 pg,需要生成一个 [0,1] 中的标量,表示一个数据属于原始数据的概率。
我们去训练鉴别网络去提高正确分类的概率,同时训练生成网络去最小化 log(1−D(G(z))),即让生成的数据更难以被鉴别网络分辨。
定义价值 V(D,G) 为:
V(D,G)=Ex∼pdata(x)[logD(x)]+Ex∼pz(z)[log(1−D(G(z)))]
接着类似两个人在博弈,D 想要最大化 V,G 想要最小化 V。
训练的过程
对于每个训练的 epoch:
-
首先训练鉴别网络,用生成一些噪音 z(1),⋯,z(m) 以及对数据集进行采样 x(1),⋯,x(m),然后去最大化
L1=m1i=1∑m[logD(x(i))+log(1−D(G(z(i))))]
对鉴别网络进行更新,重复这个过程 k 轮
-
然后训练生成网络,随机生成一些 z,对于每个小批次,利用损失
L2=m1i=1∑mlog(1−D(G(z(i))))
对生成网络进行更新
主要代码如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
| def train_epoch(bar, D, G, optimizer_D, optimizer_G, data_iter, G_input_size, device, current_loss) : //...... for X in data_iter : batch_size = X.shape[0] # update the discriminator using the dataset D.train() D.requires_grad_(True) optimizer_G.zero_grad() optimizer_D.zero_grad() X = X.to(device) y = D(X) y = torch.clamp(y, 1e-3, 1) l = -torch.log(y).mean() l.backward() grad_clipping(D, 1) optimizer_D.step() # update the discriminator using the generated data G.eval() optimizer_G.zero_grad() optimizer_D.zero_grad() G.requires_grad_(False) Z = torch.randn(batch_size, G_input_size) y = D(G(Z.to(device))) y = torch.clamp(y, 0, 1 - 1e-3) l = -torch.log(1 - y).mean() l.backward() grad_clipping(D, 1) optimizer_D.step() # updated the generator D.eval() G.train() l = 0 train_G_flg = False # The expected loss_D is log(4). while (tmp_loss < 1.75 and rest_K <= 0) and abs(float(l)) < 1e-1 : train_G_flg = True G.requires_grad_(True) optimizer_G.zero_grad() optimizer_D.zero_grad() D.requires_grad_(False) Z = torch.randn(batch_size, G_input_size) y = D(G(Z.to(device))) y = torch.clamp(y, 0, 1 - 1e-3) l = torch.log(1 - y).mean() l.backward() grad_clipping(G, 1) optimizer_G.step() if train_G_flg : # when l = -0.5, rest_K = 4 rest_K = int(1 + abs(float(l)) * 6)
return sum_loss_D / num_tests_D, sum_loss_G / num_tests_G
|
复现
在 MNIST 数据集上训练 GAN 网络。其中辨别器由 2 个线性层和 1 个 LeakyReLU 激活层组成,生成器由 3 个线性层和 2 个ReLU 激活层组成。在 lr = 1e-4
下,训练 120 个周期后效果如下:
实际训练时做了一些略微的修改:
- 当生成器的损失的绝对值小于 0.1 的时候继续训练生成器
踩坑
- 计算损失的时候需要对 D 的结果进行裁剪,防止取对数的时候挂掉。
- GAN 比较容易炸梯度(梯度变成 inf),需要训练的时候梯度裁剪。
- 如果 D 训练较快,loss 变成 0 了,后面 loss 保持为 0,训练可能直接挂了。
- lr 设置过小会导致 L1,L2 的绝对值偏高,但实际效果比较差。需要结合模型具体的运行情况来确定 lr。
- 注意
net.eval()
调用后仍然会去计算梯度,需要通过 net.requires_grad_(False)
去禁用计算梯度。如果没有禁用梯度计算,需要在每次计算开始时,清空两个网络的梯度。