@Team 2019-05-02T13:33:00.000000Z · 7,173 characters · 1,426 reads

# A cycleGAN with Some Soul in It

Hi, my name is Chen Yang. I am a sophomore at Ocean University of China, and I do some scientific research in my spare time. Given how hot artificial intelligence currently is, I would like to share my research progress on Generative Adversarial Networks.

## 前言

ODOG: as the name suggests, I hope to set aside an hour every day to talk about the frontier developments in GAN research so far. Looking at the many applications of deep learning, especially in imaging, GANs are playing an increasingly important role. We constantly see Jensen Huang's NVIDIA shipping all kinds of applications, many of which rely heavily on GAN theory and its implementation. Moreover, I feel there is a shortage of Chinese-language tutorials on implementing GANs in mainstream frameworks such as PyTorch, Keras, and TensorFlow.

## cycleGAN

Emmm, unfortunately I ended up writing about cycleGAN first. Sadly, I am not Song Dongye, and even if I were, D would not be Miss Dong; even if D were Miss Dong, I would not be Song Dongye; and even if I were Song Dongye and D were Miss Dong, there is no grassland at my home, which fills me with despair, Miss Dong.

The discriminator loss for the X→Y direction, with the letters swapped, is the same as in the one-directional GAN above:
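Written out explicitly (a reconstruction, assuming the least-squares LSGAN form that matches the `MSELoss` criterion used in the training code below):

$$
\mathcal{L}_{D_Y} = \mathbb{E}_{y \sim p_{\text{data}}(y)}\big[(D_Y(y) - 1)^2\big] + \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[D_Y(G_{XY}(x))^2\big]
$$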

### Code implementation

#### Generator G:

```python
import torch.nn as nn

##############################
#           RESNET
##############################

class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()

        self.block = nn.Sequential(
            # reflection padding keeps the residual branch the same size as x,
            # so the skip connection x + self.block(x) is shape-compatible
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
        )

    def forward(self, x):
        return x + self.block(x)


class GeneratorResNet(nn.Module):
    def __init__(self, input_shape, num_residual_blocks):
        super(GeneratorResNet, self).__init__()

        channels = input_shape[0]

        # Initial convolution block
        out_features = 64
        model = [
            nn.ReflectionPad2d(channels),
            nn.Conv2d(channels, out_features, 7),
            nn.InstanceNorm2d(out_features),
            nn.ReLU(inplace=True),
        ]
        in_features = out_features

        # Downsampling
        for _ in range(2):
            out_features *= 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Residual blocks
        for _ in range(num_residual_blocks):
            model += [ResidualBlock(out_features)]

        # Upsampling
        for _ in range(2):
            out_features //= 2
            model += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(in_features, out_features, 3, stride=1, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Output layer
        model += [nn.ReflectionPad2d(channels), nn.Conv2d(out_features, channels, 7), nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)
```

Emmm, I think the code speaks for itself: it is just an ordinary ResNet-style generator.
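One property worth convincing yourself of: the generator maps an image back to its original spatial size, because the two stride-2 downsampling convolutions are exactly undone by the two 2x upsamples. A quick sanity check of that arithmetic (pure Python, no torch needed; the function name is just for illustration):

```python
def generator_output_size(h, w):
    """Trace spatial size through the generator's down/up path."""
    for _ in range(2):
        # Conv2d with kernel 3, stride 2, padding 1: floor((n + 2 - 3) / 2) + 1
        h, w = (h + 2 - 3) // 2 + 1, (w + 2 - 3) // 2 + 1
    for _ in range(2):
        # nn.Upsample(scale_factor=2) doubles each spatial dimension
        h, w = h * 2, w * 2
    return h, w

print(generator_output_size(256, 256))  # (256, 256) -- size preserved
```

The residual blocks and the 7x7 convolutions with reflection padding are size-preserving, so they do not appear in the calculation.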

#### Discriminator:

```python
##############################
#        Discriminator
##############################

class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()

        channels, height, width = input_shape

        # Calculate output shape of image discriminator (PatchGAN)
        self.output_shape = (1, height // 2 ** 4, width // 2 ** 4)

        def discriminator_block(in_filters, out_filters, normalize=True):
            """Returns downsampling layers of each discriminator block"""
            layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]
            if normalize:
                layers.append(nn.InstanceNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *discriminator_block(channels, 64, normalize=False),
            *discriminator_block(64, 128),
            *discriminator_block(128, 256),
            *discriminator_block(256, 512),
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(512, 1, 4, padding=1),  # collapse to a 1-channel patch map
        )

    def forward(self, img):
        return self.model(img)
```
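This is a PatchGAN: instead of a single real/fake scalar, the discriminator emits a grid of per-patch scores. `output_shape` follows directly from the four stride-2 blocks, each of which halves the spatial size. A one-line check of that arithmetic (plain Python, function name is mine for illustration):

```python
def patchgan_output_shape(channels, height, width):
    # four stride-2 downsampling blocks halve height and width four times
    return (1, height // 2 ** 4, width // 2 ** 4)

print(patchgan_output_shape(3, 256, 256))  # (1, 16, 16): a 16x16 grid of patch scores
```

Each of the 16x16 outputs judges one overlapping receptive-field patch of the input image, which is why the ground-truth labels below are whole tensors of ones/zeros rather than single scalars.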

### Now for training

```python
# Losses
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()

# Set model input
real_A = Variable(batch["A"].type(Tensor))
real_B = Variable(batch["B"].type(Tensor))

# Adversarial ground truths
valid = Variable(Tensor(np.ones((real_A.size(0), *D_A.output_shape))), requires_grad=False)
fake = Variable(Tensor(np.zeros((real_A.size(0), *D_A.output_shape))), requires_grad=False)
```
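Because `criterion_GAN` is mean-squared error against all-ones (`valid`) or all-zeros (`fake`) targets, this is the least-squares GAN objective. A toy demonstration with plain lists instead of tensors (values are made up for illustration):

```python
def mse(pred, target):
    """Mean-squared error, the same quantity torch.nn.MSELoss computes."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

# Hypothetical patch scores from a discriminator looking at mostly-real-looking input
d_out = [0.9, 0.8, 0.95, 0.1]

print(mse(d_out, [1.0] * 4))  # small: D's scores are close to the "valid" target
print(mse(d_out, [0.0] * 4))  # larger: far from the "fake" target
```

The discriminator pushes its scores toward 1 on real images and 0 on fakes; the generator minimizes the distance between `D(fake)` and 1.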

#### Training the generators

```python
# ------------------
#  Train Generators
# ------------------

G_AB.train()
G_BA.train()

optimizer_G.zero_grad()

# Identity loss
loss_id_A = criterion_identity(G_BA(real_A), real_A)
loss_id_B = criterion_identity(G_AB(real_B), real_B)
loss_identity = (loss_id_A + loss_id_B) / 2

# GAN loss
fake_B = G_AB(real_A)
loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
fake_A = G_BA(real_B)
loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)
loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

# Cycle loss
recov_A = G_BA(fake_B)
loss_cycle_A = criterion_cycle(recov_A, real_A)
recov_B = G_AB(fake_A)
loss_cycle_B = criterion_cycle(recov_B, real_B)
loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

# Total loss
loss_G = loss_GAN + opt.lambda_cyc * loss_cycle + opt.lambda_id * loss_identity

loss_G.backward()
optimizer_G.step()
```

1. identity_loss: an image from set A passed through generator $G_{BA}$ should come back unchanged; A stays A and B stays B (this is like having to prove your mom is your mom, no offense intended).
2. GAN_loss: the classic fool-the-discriminator loss; we want the discriminator to mistake generated images for ground truth.
3. Cycle_loss: the cycle-consistency loss. If a zebra is turned into a horse and then back into a zebra, it should still be the same zebra as before. (From an information-theoretic point of view this is not strictly possible: information is lost as it flows, just as a couple who break up and later reconcile can never return to that "if only life were as when we first met" state.)
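The three losses are combined into one weighted sum. A minimal sketch with plain floats (the default weights here are assumptions, taken from common CycleGAN setups where the identity weight is half the cycle weight):

```python
def total_generator_loss(loss_gan, loss_cycle, loss_identity,
                         lambda_cyc=10.0, lambda_id=5.0):
    """loss_G = loss_GAN + lambda_cyc * loss_cycle + lambda_id * loss_identity"""
    return loss_gan + lambda_cyc * loss_cycle + lambda_id * loss_identity

# e.g. a plausible mid-training snapshot (made-up numbers)
print(total_generator_loss(1.0, 0.5, 0.2))  # 1.0 + 10*0.5 + 5*0.2 = 7.0
```

Note how heavily the cycle term is weighted relative to the adversarial term: without it, the generator could map every input to any plausible-looking output and still fool the discriminator.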

```python
# -----------------------
#  Train Discriminator A
# -----------------------

optimizer_D_A.zero_grad()

# Real loss
loss_real = criterion_GAN(D_A(real_A), valid)

# Fake loss (on batch of previously generated samples)
fake_A_ = fake_A_buffer.push_and_pop(fake_A)
loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)

# Total loss
loss_D_A = (loss_real + loss_fake) / 2

loss_D_A.backward()
optimizer_D_A.step()

# -----------------------
#  Train Discriminator B
# -----------------------

optimizer_D_B.zero_grad()

# Real loss
loss_real = criterion_GAN(D_B(real_B), valid)

# Fake loss (on batch of previously generated samples)
fake_B_ = fake_B_buffer.push_and_pop(fake_B)
loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)

# Total loss
loss_D_B = (loss_real + loss_fake) / 2

loss_D_B.backward()
optimizer_D_B.step()

loss_D = (loss_D_A + loss_D_B) / 2
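`fake_A_buffer` and `fake_B_buffer` are image-history buffers that are not defined in the snippet above: the discriminator is trained on a mix of fresh fakes and older ones, which stabilizes training. A simplified sketch of the idea (my own simplification, operating on whole items rather than per-image within a batch of tensors):

```python
import random

class ReplayBuffer:
    """Keep up to max_size previously generated samples; with probability
    0.5 hand back an old sample (storing the new one in its place),
    otherwise pass the new sample straight through."""

    def __init__(self, max_size=50):
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, item):
        if len(self.data) < self.max_size:
            # Buffer not full yet: store and return the fresh sample
            self.data.append(item)
            return item
        if random.random() > 0.5:
            # Swap the fresh sample in for a random old one, return the old one
            i = random.randint(0, self.max_size - 1)
            old, self.data[i] = self.data[i], item
            return old
        return item
```

Without the history buffer, the discriminator only ever sees the generator's latest output and the two networks can oscillate; replaying older fakes damps that.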

```shell
cd data/
# download and unpack the dataset here (e.g. monet2photo)
cd ..
python cyclegan.py
# ...after a long, long time...
python test.py --checkpoint saved_models/monet2photo/G_BA_100.pth --image_path path/to/your/image
```
