[关闭]
@Team 2019-03-01T08:08:04.000000Z 字数 3748 阅读 1569

SRGAN-超分辨率图像复原

陈扬

github:

https://github.com/OUCMachineLearning/OUCML/blob/master/GAN/srgan_celebA/srgan.py

arxiv:https://arxiv.org/abs/1609.04802

带你手把手来个有趣的SRGAN-超分辨率复原(有代码哦)

结合我的研究方向:GAN
大家好,我是中国海洋大学的陈扬。在遥远的九月份,我开始做了keras的系列教程,现在我主要的研究方向转到了生成对抗网络,生成对抗网络的代码实现和训练机制比分类模型都要复杂和难入门.之前一段时间时间一直在帮璇姐跑cvpr的实验代码,做了蛮多的对比实验,其中我就发现了,keras的代码实现和可阅读性很好,搭生成对抗网络网络GAN就好像搭乐高积木一样有趣哦。不只是demo哦,我还会在接下来的一系列 keras教程中教你搭建Alexnet,Vggnet,Resnet,DCGAN,ACGAN,CGAN,SRGAN,等等实际的模型并且教你如何在GPU服务器上运行。

前言:GAN

上个星期发了一篇有关GAN入门的文章,同学们都觉得挺有趣的,上一次我写了如何理解最基础的GAN的原理,今天我给大家带来的是如何运用强大的GAN做一些好玩的应用.

超分辨率复原一直是计算机视觉领域一个十分热门的研究方向,在商业上也有着很大的用武之地,随着2014年goodflew那篇惊世骇俗的GAN发表出来,GAN伴随着CNN一起,可谓是乘风破浪,衍生出来琳琅满目的各种应用.

image_1ct4sdm7c171la761q1rf881aak9.png-75.2kB
简单的来说,就给定一个低分辨率图片作为噪声z的输入,通过生成器的变换把噪声的概率分布空间尽可能的去拟合真实数据的分布空间.

基本框架:


在这里,我们把生成器看的目标看成是要以次充好,判别器的目标是要明辨真假.
我们可以的看到,在生成器的前6层网络中,我们运用了残差块,为什么要用残差块呢?
因为我们可以从上图看出来,当损失函数从判别器开始反向传播会生成器的时候,实际上进过来很多层,我们知道越深的网络隐藏参数越多,在反向传播的过程中也越容易梯度弥散.而且残差连接的方法,就有效的保证了我们梯度信息能够有效的传递而增强生成对抗网络的鲁棒性.(事实上沃瑟斯坦loss也可以增强GAN训练的鲁棒性,以后会写)

celebA

再来聊聊今天用的数据集,这是Celeb-A,里面有大量的带标注信息的明星人脸.在目前很多的GAN的应用中,都是用CelebA作为基础的数据集,这个数据集大概在1.2G左右,可以在kaggle上下载.
https://www.kaggle.com/jessicali9530/celeba-dataset
浏览数据集
A popular component of computer vision and deep learning revolves around identifying faces for various applications from logging into your phone with your face or searching through surveillance images for a particular suspect. This dataset is great for training and testing models for face detection, particularly for recognising facial attributes such as finding people with brown hair, are smiling, or wearing glasses. Images cover large pose variations, background clutter, diverse people, supported by a large quantity of images and rich annotations. This data was originally collected by researchers at MMLAB, The Chinese University of Hong Kong (specific reference in Acknowledgment section).

Overall

202,599 number of face images of various celebrities
10,177 unique identities, but names of identities are not given
40 binary attribute annotations per image
5 landmark locations

Super-Resolution IMAGE

简单点说,就是给你一张模糊的图片,让你复原一张高清的图片.
image_1ctqd73lo1vkka11cgu1kpt1qd316.png-907kB

我们如何用生成对抗网络来做呢?

这个时候,我们可以把LRimg看成是一个噪声z的输入,G生成的是一个FAKE-HRimg,我们让D分辨fake-HRimg and original HRimg.

定义一个目标函数

Our ultimate goal is to train a generating function G that estimates for a given LR input image its corresponding HR counterpart. To achieve this, we train a generator network as a feed-forward CNN GθG parametrized by θG. Here θG = {W1:L ; b1:L } denotes the weights and biases of a L-layer deep network and is obtained by optimizing a SR-specific
loss function lSR. For training images IHR , n = 1, . . . , N n
withcorrespondingILR,n=1,...,N,wesolve:
image_1ctqhfu351lqh1sns1kfh196liq878.png-39.4kB

提出perceptual loss

作者认为这更接近人的主观感受,因为使用pixel-wise的MSE使得图像变得平滑,而如果先用VGG来抓取到高级特征(feature)表示,再对feature使用MSE,可以更好的抓取不变特征。
image_1ctqe4ifno4g16hkiuo15114f37.png-170kB
lovemusicge_1505531635.8098323.JPG-27.9kB
lovemusicge_1505531647.0589468.JPG-31.4kB
lovemusicge_1505531655.8268762.JPG-25.3kB

核心公式:


这个公式我们要分成两个部分来看:
先看前半部分:

这个公式的意思是,先看加号前面的
,我们希望D最大,所以应该最大,意味着我的判别器可以很好的识别出,真实的高分辨率图像是"true",在看加号后面的,要让log尽可能的大,需要的是ΘD(ΘG(z))尽可能的小,意味着我们
生成模型复原的图片应该尽可能的被判别模型视为"FALSE".

再看后半部分部分
我们应该让G尽可能的小,加号前面的式子并没有G,所以无关,在看加号后面的式子,要让ΘG尽可能地小,就要ΘD(ΘG(Z))尽可能的大,也就是说本来就一张低分辨率生成的图片,判别器却被迷惑了,以为是一张原始的高分辨率图片.这就是所谓的以次充好.

网络设计

image_1ctqe9ihsrc7cbv75k16nl1oco44.png-173.9kB

loss函数

  1. ###vgg用于提取特征
  2. self.vgg.compile(loss='mse',
  3. optimizer=optimizer,
  4. metrics=['accuracy'])
  5. ###生成器
  6. self.combined.compile(loss=['binary_crossentropy', 'mse'],
  7. loss_weights=[1e-3, 1],
  8. optimizer=optimizer)
  9. ###判别器
  10. self.discriminator.compile(loss='mse',
  11. optimizer=optimizer,
  12. metrics=['accuracy'])

train

训练判别器

  1. d_loss_real = self.discriminator.train_on_batch(imgs_hr, valid)
  2. d_loss_fake = self.discriminator.train_on_batch(fake_hr, fake)
  3. d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

训练生成器

  1. image_features = self.vgg.predict(imgs_hr)
  2. # Train the generators
  3. g_loss = self.combined.train_on_batch([imgs_lr, imgs_hr], [valid, image_features])

实际结果

5000.png-193.8kB
5000 batchsize

对比实验结果

image_1ctqikdcj13h21l1u151c1rab1pra8r.png-517.7kB

谢谢大家

添加新批注
在作者公开此批注前,只有你和作者可见。
回复批注