@Team 2018-12-16T12:50:46.000000Z, word count 3649, views 2421

SRGAN: Super-Resolution Image Restoration

陈扬

GitHub:

https://github.com/OUCMachineLearning/OUCML/blob/master/GAN/srgan_celebA/srgan.py

Preface: GAN

CelebA

A popular component of computer vision and deep learning revolves around identifying faces for various applications, from logging into your phone with your face to searching through surveillance images for a particular suspect. This dataset is great for training and testing models for face detection, particularly for recognising facial attributes such as finding people who have brown hair, are smiling, or are wearing glasses. The images span large pose variations, background clutter and a diverse range of people, and the dataset provides a large number of images with rich annotations. This data was originally collected by researchers at MMLAB, The Chinese University of Hong Kong (specific reference in Acknowledgment section).

Overall

202,599 face images of various celebrities
10,177 unique identities (identity names are not given)
40 binary attribute annotations per image
5 landmark locations per image

Defining an objective function

Our ultimate goal is to train a generating function $G$ that estimates, for a given LR input image, its corresponding HR counterpart. To achieve this, we train a generator network as a feed-forward CNN $G_{\theta_G}$ parametrized by $\theta_G$. Here $\theta_G = \{W_{1:L}; b_{1:L}\}$ denotes the weights and biases of an $L$-layer deep network and is obtained by optimizing an SR-specific loss function $l^{SR}$. For training images $I^{HR}_n$, $n = 1, \dots, N$, with corresponding $I^{LR}_n$, $n = 1, \dots, N$, we solve:

$$\hat{\theta}_G = \arg\min_{\theta_G} \frac{1}{N} \sum_{n=1}^{N} l^{SR}\left(G_{\theta_G}(I^{LR}_n), I^{HR}_n\right)$$
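The averaged SR loss over the $N$ training pairs described above can be sketched numerically. This is a minimal illustration, not the repository's code, and it rests on several assumptions: LR images are made by averaging non-overlapping 4×4 blocks (the SRGAN paper uses bicubic downsampling; block averaging is swapped in to stay dependency-free), nearest-neighbour upsampling is a hypothetical stand-in for a trained $G_{\theta_G}$, and pixel-wise MSE plays the role of $l^{SR}$. The helper names `downsample`, `upsample_nearest`, `mse` and `empirical_sr_loss` are illustrative only.

```python
import numpy as np

def downsample(hr, factor=4):
    """Hypothetical LR generator: average non-overlapping factor x factor blocks.

    Stands in for the bicubic downsampling used in the SRGAN paper.
    Height and width must be divisible by `factor`.
    """
    h, w, c = hr.shape
    return hr.reshape(h // factor, factor, w // factor, factor, c).mean(axis=(1, 3))

def upsample_nearest(lr, factor=4):
    """Toy stand-in for G_theta_G: nearest-neighbour upsampling, no learned weights."""
    return lr.repeat(factor, axis=0).repeat(factor, axis=1)

def mse(sr, hr):
    """Pixel-wise MSE, used here as the SR-specific loss l_SR."""
    return float(np.mean((sr - hr) ** 2))

def empirical_sr_loss(generator, hr_images, loss_fn=mse, factor=4):
    """(1/N) * sum_n l_SR(G(I_LR_n), I_HR_n) over the N training pairs."""
    losses = [loss_fn(generator(downsample(hr, factor)), hr) for hr in hr_images]
    return float(np.mean(losses))

# Usage: two 4x4 single-channel "HR images"
hr_batch = [np.arange(16, dtype=float).reshape(4, 4, 1), np.full((4, 4, 1), 3.0)]
print(empirical_sr_loss(upsample_nearest, hr_batch))
```

Training $G_{\theta_G}$ means searching for parameters that push this average down; the toy generator has no parameters, so it only shows how the objective is evaluated.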

Proposing a perceptual loss
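For reference, the SRGAN paper (Ledig et al., 2017) defines the perceptual loss as a weighted sum of a content loss and an adversarial loss. In the VGG-based content loss, $\phi_{i,j}$ is the feature map obtained by the $j$-th convolution (after activation) before the $i$-th max-pooling layer of a pre-trained VGG19, and $W_{i,j}, H_{i,j}$ are its dimensions:

$$l^{SR} = \underbrace{l^{SR}_{X}}_{\text{content loss}} + 10^{-3}\,\underbrace{l^{SR}_{Gen}}_{\text{adversarial loss}}$$

$$l^{SR}_{VGG/i,j} = \frac{1}{W_{i,j}H_{i,j}} \sum_{x=1}^{W_{i,j}} \sum_{y=1}^{H_{i,j}} \left(\phi_{i,j}(I^{HR})_{x,y} - \phi_{i,j}\left(G_{\theta_G}(I^{LR})\right)_{x,y}\right)^2$$

$$l^{SR}_{Gen} = \sum_{n=1}^{N} -\log D_{\theta_D}\left(G_{\theta_G}(I^{LR})\right)$$

The weights $10^{-3}$ and $1$ correspond to the `loss_weights=[1e-3, 1]` passed when compiling the `combined` model in this post's code.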

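Core formula:

$$\min_{\theta_G}\max_{\theta_D}\; \mathbb{E}_{I^{HR}\sim p_{\text{train}}(I^{HR})}\left[\log D_{\theta_D}(I^{HR})\right] + \mathbb{E}_{I^{LR}\sim p_G(I^{LR})}\left[\log\left(1 - D_{\theta_D}\left(G_{\theta_G}(I^{LR})\right)\right)\right]$$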

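We want $D$ to maximize this objective, so $\log D_{\theta_D}(I^{HR})$ should be as large as possible, which means the discriminator correctly identifies real high-resolution images as "true". Now look at the term after the plus sign, $\mathbb{E}_{I^{LR} \sim p_G(I^{LR})}$: making $\log\left(1 - D_{\theta_D}(G_{\theta_G}(I^{LR}))\right)$ as large as possible requires $D_{\theta_D}(G_{\theta_G}(I^{LR}))$ to be as small as possible, which means the discriminator labels the generated images as "false". The generator $G_{\theta_G}$ is trained in the opposite direction (it minimizes the objective) so that its outputs are judged real.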
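A minimal sketch, assuming the discriminator outputs probabilities in (0, 1): the two expectation terms discussed above can be estimated on a batch by averaging the log terms. `gan_value` is a hypothetical helper, not part of the repository; probabilities are clipped to avoid `log(0)`.

```python
import numpy as np

def gan_value(d_real, d_fake, eps=1e-12):
    """Batch estimate of E[log D(I_HR)] + E[log(1 - D(G(I_LR)))].

    d_real: discriminator probabilities for real HR images.
    d_fake: discriminator probabilities for generated (super-resolved) images.
    """
    d_real = np.clip(np.asarray(d_real, dtype=float), eps, 1 - eps)
    d_fake = np.clip(np.asarray(d_fake, dtype=float), eps, 1 - eps)
    return float(np.mean(np.log(d_real)) + np.mean(np.log1p(-d_fake)))

# A discriminator that separates real from fake scores higher than one that guesses 0.5.
confident = gan_value([0.9, 0.95], [0.1, 0.05])
guessing = gan_value([0.5, 0.5], [0.5, 0.5])
print(confident, guessing)
```

The discriminator's update tries to raise this value; the generator's update tries to lower the second term by pushing $D_{\theta_D}(G_{\theta_G}(I^{LR}))$ towards 1.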

Loss functions

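    # VGG is used to extract image features
    self.vgg.compile(loss='mse',
                     optimizer=optimizer,
                     metrics=['accuracy'])

    # Generator (trained through the combined model):
    # adversarial loss (binary cross-entropy) weighted 1e-3,
    # VGG feature loss (MSE) weighted 1
    self.combined.compile(loss=['binary_crossentropy', 'mse'],
                          loss_weights=[1e-3, 1],
                          optimizer=optimizer)

    # Discriminator
    self.discriminator.compile(loss='mse',
                               optimizer=optimizer,
                               metrics=['accuracy'])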
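To make the `combined` compile call concrete, here is a NumPy sketch of the scalar it minimizes, assuming Keras's usual behaviour of summing each output's loss multiplied by its `loss_weights` entry. The names `bce`, `mse` and `combined_generator_loss`, and the feature shapes in the usage line, are hypothetical; real features would come from `self.vgg.predict`.

```python
import numpy as np

def bce(y_true, y_pred, eps=1e-7):
    """Binary cross-entropy averaged over elements, clipped to avoid log(0)."""
    y_pred = np.clip(np.asarray(y_pred, dtype=float), eps, 1 - eps)
    return float(np.mean(-(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))))

def mse(a, b):
    """Mean squared error between two arrays of the same shape."""
    return float(np.mean((np.asarray(a, dtype=float) - np.asarray(b, dtype=float)) ** 2))

def combined_generator_loss(d_on_fake, hr_features, sr_features, weights=(1e-3, 1.0)):
    """Weighted sum mirroring loss=['binary_crossentropy', 'mse'], loss_weights=[1e-3, 1].

    d_on_fake: discriminator outputs for generated images; the target is `valid` (ones).
    hr_features, sr_features: VGG features of the real HR and generated SR images.
    """
    valid = np.ones_like(np.asarray(d_on_fake, dtype=float))
    adversarial = bce(valid, d_on_fake)
    content = mse(hr_features, sr_features)
    return weights[0] * adversarial + weights[1] * content

# Usage with made-up numbers: D is unsure (0.5) and every feature differs by 1.
loss = combined_generator_loss(np.full(2, 0.5), np.zeros((2, 4)), np.ones((2, 4)))
print(loss)
```

With the $10^{-3}$ weight, the adversarial term only nudges the generator, while the VGG feature term dominates the total.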

Training

Training the discriminator

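    # Train the discriminator: real HR images are labelled `valid`, generated ones `fake`
    fake_hr = self.generator.predict(imgs_lr)  # super-resolve the LR batch
    d_loss_real = self.discriminator.train_on_batch(imgs_hr, valid)
    d_loss_fake = self.discriminator.train_on_batch(fake_hr, fake)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)  # average [loss, accuracy] over both batches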
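The averaging in `d_loss` works because, with `metrics=['accuracy']`, `train_on_batch` returns a `[loss, accuracy]` pair, so `0.5 * np.add(...)` averages both entries. The sketch below evaluates the same quantities without updating any weights. It assumes `valid` holds ones and `fake` holds zeros, uses the `mse` loss the discriminator was compiled with, and assumes a 0.5 threshold for accuracy; `loss_and_accuracy` and `discriminator_step` are hypothetical helpers.

```python
import numpy as np

def loss_and_accuracy(pred, target):
    """[mse, accuracy] for one batch; accuracy uses a 0.5 threshold (an assumption)."""
    pred = np.asarray(pred, dtype=float)
    loss = np.mean((pred - target) ** 2)
    acc = np.mean((pred > 0.5) == (target > 0.5))
    return np.array([loss, acc])

def discriminator_step(d_on_real, d_on_fake):
    """Evaluate, without training, what one discriminator step reports."""
    d_loss_real = loss_and_accuracy(d_on_real, 1.0)  # real HR images, target valid = 1
    d_loss_fake = loss_and_accuracy(d_on_fake, 0.0)  # generated images, target fake = 0
    return 0.5 * np.add(d_loss_real, d_loss_fake)    # element-wise mean of [loss, acc]

print(discriminator_step([0.9, 0.8], [0.2, 0.6]))
```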

Training the generator

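    # VGG features of the real HR images serve as targets for the content loss
    image_features = self.vgg.predict(imgs_hr)
    # Train the generator (via the combined model): it is rewarded when the
    # discriminator labels its output `valid`
    g_loss = self.combined.train_on_batch([imgs_lr, imgs_hr], [valid, image_features])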
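Training the generator against `valid` labels with binary cross-entropy means minimizing $-\log D_{\theta_D}(G_{\theta_G}(I^{LR}))$ rather than $\log\left(1 - D_{\theta_D}(G_{\theta_G}(I^{LR}))\right)$; the SRGAN paper makes this substitution for better gradient behaviour. A small sketch of the derivatives of both forms with respect to the discriminator output $D$ illustrates the point: when the discriminator confidently rejects a fake ($D$ near 0), the original term's gradient stays near $-1$, while the substituted one is about $-1/D$. The function names are illustrative.

```python
import numpy as np

def saturating_grad(d):
    """Derivative of log(1 - D) w.r.t. D: the original minimax generator term."""
    return -1.0 / (1.0 - d)

def non_saturating_grad(d):
    """Derivative of -log(D) w.r.t. D: BCE against `valid` (target 1) labels."""
    return -1.0 / d

for d in np.array([0.01, 0.5, 0.99]):
    print(f"D={d:.2f}  log(1-D): {saturating_grad(d):9.3f}  -log(D): {non_saturating_grad(d):9.3f}")
```

Early in training the discriminator easily spots fakes, so the $-\log D$ form is the one that still gives the generator a strong learning signal.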

After 5000 batches of training:
