[关闭]
@Team 2019-05-02T13:33:00.000000Z 字数 7173 阅读 769

一篇注入了灵魂的 cycleGAN

陈扬


Hi ,my name is Chen Yang ,I am a sophomore in ocean university of China .I do some scientific research in my spare time. Based on the current hot direction of artificial intelligence, I hope to share my research progress with Generative adversarial network

嗨,我的名字是陈扬,我是中国海洋大学的二年级学生。我在业余时间做了一些科学研究。基于当前人工智能的热点方向,我希望与生成对抗网络分享我的研究进展

前言

ODOG,顾名思义就我我希望能每天抽出一个小时的时间来讲讲到目前为止,GAN的前沿发展和研究,笔者观察了很多深度学习的应用,特别是在图像这一方面,GAN已经在扮演着越来越重要的角色,我们经常可以看到老黄的NVIDIA做了各种各样的application,而且其中涉及到了大量GAN的理论及其实现,再者笔者个人也觉得目前国内缺少GAN在pytorch,keras,tensorflow等主流的框架下的实现教学.

我的老师曾经对我说过:"深度学习是一块未知的新大陆,它是一个大的黑箱系统,而GAN则是黑箱中的黑箱,谁要是能打开这个盒子,将会引领一个新的时代"

cycleGAN

大名鼎鼎的 cycleGAN 就不用我多介绍了吧,那个大家每一次一讲到深度学习,总有人会放出来的那张图片,马变斑马的那个,就是出自于大名鼎鼎的 cycleGAN 了
200% center

68747470733a2f2f6a756e79616e7a2e6769746875622e696f2f4379636c6547414e2f696d616765732f7465617365725f686967685f7265732e6a7067.jpg-553.8kB

大概的思想我简要的说一下吧

img

就是我输入的图片是宋冬野的《斑马》,他经过一个简单的卷积神经网络 后,从 集合 A 映射到了集合 B,也是就《董小姐》里面的宋冬野爱上的那匹野马.

你以为这就结束了吗?

如果就此结束,

岂不是变成了讲 Pix2Pix 了

emmmmmmm,可惜我先写 cycleGAN 了,可惜我不是宋冬野啊,就算我是宋冬野,D 也不是董小姐啊,就算 D 是董小姐,我也不是宋冬野啊,就算我是宋冬野,D 是董小姐,可是我家里没有草原,这让我感到绝望,董小姐.

所以我为什么即兴写了这么多乱七八糟的东西呢,一是我确实遇见过一个很像董小姐的她,二是在图像风格迁移的实际训练中,我们很难找到能够一一对应的数据集,比如这样的:
image_1d9sbu2t1g431mlp1qdbpd5kcs16.png-201.9kB

我们大部分情况下做的有实际意义的数据集,都是 Unpaid 的.

所以我们在把斑马变成野马的过程中,我们希望能够尽可能的减少信息的损失.

所以,马卡斯·扬要把野马用变回斑马,再将变回去的斑马和宋冬野原来的斑马用 做对比,因为 相比保边缘(基于高斯先验,基于拉普拉斯先验)https://blog.csdn.net/m0_38045485/article/details/82147817

最后,判别器还是和寻常的一样,分别判别 A 和 B 是真是假

好了,正经的来看一下损失函数:

img

X→Y的判别器损失为,字母换了一下,和上面的单向GAN是一样的:


同理Y→X的判别器损失为 :

再循环回来和自己比一轮就是:

全部合在一起就是:

那么我们的目标就是:

代码实现

我选了eriklindernoren大佬的代码来讲,因为简单…...

先看看

生成器 G :

  1. ##############################
  2. # RESNET
  3. ##############################
  4. class ResidualBlock(nn.Module):
  5. def __init__(self, in_features):
  6. super(ResidualBlock, self).__init__()
  7. self.block = nn.Sequential(
  8. nn.ReflectionPad2d(1),
  9. nn.Conv2d(in_features, in_features, 3),
  10. nn.InstanceNorm2d(in_features),
  11. nn.ReLU(inplace=True),
  12. nn.ReflectionPad2d(1),
  13. nn.Conv2d(in_features, in_features, 3),
  14. nn.InstanceNorm2d(in_features),
  15. )
  16. def forward(self, x):
  17. return x + self.block(x)
  18. class GeneratorResNet(nn.Module):
  19. def __init__(self, input_shape, num_residual_blocks):
  20. super(GeneratorResNet, self).__init__()
  21. channels = input_shape[0]
  22. # Initial convolution block
  23. out_features = 64
  24. model = [
  25. nn.ReflectionPad2d(channels),
  26. nn.Conv2d(channels, out_features, 7),
  27. nn.InstanceNorm2d(out_features),
  28. nn.ReLU(inplace=True),
  29. ]
  30. in_features = out_features
  31. # Downsampling
  32. for _ in range(2):
  33. out_features *= 2
  34. model += [
  35. nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
  36. nn.InstanceNorm2d(out_features),
  37. nn.ReLU(inplace=True),
  38. ]
  39. in_features = out_features
  40. # Residual blocks
  41. for _ in range(num_residual_blocks):
  42. model += [ResidualBlock(out_features)]
  43. # Upsampling
  44. for _ in range(2):
  45. out_features //= 2
  46. model += [
  47. nn.Upsample(scale_factor=2),
  48. nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
  49. nn.InstanceNorm2d(out_features),
  50. nn.ReLU(inplace=True),
  51. ]
  52. in_features = out_features
  53. # Output layer
  54. model += [nn.ReflectionPad2d(channels), nn.Conv2d(out_features, channels, 7), nn.Tanh()]
  55. self.model = nn.Sequential(*model)
  56. def forward(self, x):
  57. return self.model(x)

emmmmm,我觉得已经说的很清楚了吧,就是一个普普通通的 resnet 啦

就不可视化出来了,免得说我水字数…..

判别器:

  1. ##############################
  2. # Discriminator
  3. ##############################
  4. class Discriminator(nn.Module):
  5. def __init__(self, input_shape):
  6. super(Discriminator, self).__init__()
  7. channels, height, width = input_shape
  8. # Calculate output shape of image discriminator (PatchGAN)
  9. self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)
  10. def discriminator_block(in_filters, out_filters, normalize=True):
  11. """Returns downsampling layers of each discriminator block"""
  12. layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
  13. if normalize:
  14. layers.append(nn.InstanceNorm2d(out_filters))
  15. layers.append(nn.LeakyReLU(0.2, inplace=True))
  16. return layers
  17. self.model = nn.Sequential(
  18. *discriminator_block(channels, 64, normalize=False),
  19. *discriminator_block(64, 128),
  20. *discriminator_block(128, 256),
  21. *discriminator_block(256, 512),
  22. nn.ZeroPad2d((1, 0, 1, 0)),
  23. nn.Conv2d(512, 1, 4, padding=1)
  24. )
  25. def forward(self, img):
  26. return self.model(img)

这个其实是有点意思的,这个叫 PatchGAN,就是说判别器 D 的输出不是简单的 0/1,而是输出一张比原来输入size小 16 倍的通道数为 1 的特征图,特征图的每个点,1 表示real,0 表示 fake.他的感受野计算你们就别算了,我告诉你最后一层是 70.

再来看看训练

首先有三个损失函数:

  1. # Losses
  2. criterion_GAN = torch.nn.MSELoss()
  3. criterion_cycle = torch.nn.L1Loss()
  4. criterion_identity = torch.nn.L1Loss()
  1. # Set model input
  2. real_A = Variable(batch["A"].type(Tensor))
  3. real_B = Variable(batch["B"].type(Tensor))
  4. # Adversarial ground truths
  5. valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False)
  6. fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))), requires_grad=False)

这个 label 就是对应了 PatchGAN 的输出,有点意思

训练生成器

  1. # ------------------
  2. # Train Generators
  3. # ------------------
  4. G_AB.train()
  5. G_BA.train()
  6. optimizer_G.zero_grad()
  7. # Identity loss
  8. loss_id_A = criterion_identity(G_BA(real_A), real_A)
  9. loss_id_B = criterion_identity(G_AB(real_B), real_B)
  10. loss_identity = (loss_id_A + loss_id_B) / 2
  11. # GAN loss
  12. fake_B = G_AB(real_A)
  13. loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
  14. fake_A = G_BA(real_B)
  15. loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)
  16. loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2
  17. # Cycle loss
  18. recov_A = G_BA(fake_B)
  19. loss_cycle_A = criterion_cycle(recov_A, real_A)
  20. recov_B = G_AB(fake_A)
  21. loss_cycle_B = criterion_cycle(recov_B, real_B)
  22. loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
  23. # Total loss
  24. loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity
  25. loss_G.backward()
  26. optimizer_G.step()

这个训练生成器比较复杂,我们得先把他拆解成 3 个部分:

  1. identity_loss:是指我把 A 集合中的图片过生成器 后还是 A,意思是 A 就是 A,B 就是 B(相当于证明你妈是你妈,没有那个意思哈)
  2. GAN_loss:经典的以假乱真 loss了,就是希望判别器把生成的图片以为是ground truth.
  3. Cycle_loss:循环一致性损失,就是说,我希望斑马变成野马后,再从野马变成斑马还是原来那样的斑马(其实从信息论的角度来看这是不可能的,毕竟信息流动了就会损失,就好比你曾经爱上了她,后来你俩分手了,即便是再后来你们和好了,也不可能再回到当初那个"人生若只如初见,何事西风悲画扇"的状态了)

好了,生成器也就这么回事吧.

再来看看判别器:

  1. # -----------------------
  2. # Train Discriminator A
  3. # -----------------------
  4. optimizer_D_A.zero_grad()
  5. # Real loss
  6. loss_real = criterion_GAN(D_A(real_A), valid)
  7. # Fake loss (on batch of previously generated samples)
  8. fake_A_ = fake_A_buffer.push_and_pop(fake_A)
  9. loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
  10. # Total loss
  11. loss_D_A = (loss_real + loss_fake) / 2
  12. loss_D_A.backward()
  13. optimizer_D_A.step()
  14. # -----------------------
  15. # Train Discriminator B
  16. # -----------------------
  17. optimizer_D_B.zero_grad()
  18. # Real loss
  19. loss_real = criterion_GAN(D_B(real_B), valid)
  20. # Fake loss (on batch of previously generated samples)
  21. fake_B_ = fake_B_buffer.push_and_pop(fake_B)
  22. loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
  23. # Total loss
  24. loss_D_B = (loss_real + loss_fake) / 2
  25. loss_D_B.backward()
  26. optimizer_D_B.step()
  27. loss_D = (loss_D_A + loss_D_B) / 2

也是有两个判别器,分别判别 A,B 是真是假.

如何我再聊聊代码这么跑起来吧

你先 clone 我上传上去的代码,我改了一些,

额外帮你们写了一个 test.py

  1. cd data/
  2. sh download_cyclegan_dataset.sh monet2photo
  3. cd ..
  4. python cyclegan.py
  5. after long long time............
  6. python test --checkpiont saved_models/monet2photo/G_BA_100.pth --image_path 你的图片路径

这是一篇很有灵魂的文章,在文章最后我想说,图像翻译的坑很深,训练也是真的很难,不过真的挺有意思的,很难得写这么有灵魂的文章了.

来一首董小姐

添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注