Pytorch搭建SRGAN平台提升图片超分辨率
作者:Bubbliiiing 发布时间:2022-10-03 14:02:01
源码下载地址
网络构建
一、什么是SRGAN
SRGAN出自论文Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network。
如果将SRGAN看作一个黑匣子,其主要的功能就是输入一张低分辨率图片,生成高分辨率图片。
该文章提到,普通的超分辨率模型训练网络时只用到了均方差作为损失函数,虽然能够获得很高的峰值信噪比,但是恢复出来的图像通常会丢失高频细节。
SRGAN利用感知损失(perceptual loss)和对抗损失(adversarial loss)来提升恢复出的图片的真实感。
二、生成网络的构建
生成网络的构成如上图所示,生成网络的作用是输入一张低分辨率图片,生成高分辨率图片。:
SRGAN的生成网络由三个部分组成。
1、低分辨率图像进入后会经过一个卷积+RELU函数。
2、然后经过B个残差网络结构,每个残差结构都包含两个卷积+标准化+RELU,还有一个残差边。
3、然后进入上采样部分,在经过两次上采样后,原图的高宽变为原来的4倍,实现分辨率的提升。
前两个部分用于特征提取,第三部分用于提高分辨率。
import math
import torch
from torch import nn
class ResidualBlock(nn.Module):
def __init__(self, channels):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(channels)
self.prelu = nn.PReLU(channels)
self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(channels)
def forward(self, x):
short_cut = x
x = self.conv1(x)
x = self.bn1(x)
x = self.prelu(x)
x = self.conv2(x)
x = self.bn2(x)
return x + short_cut
class UpsampleBLock(nn.Module):
def __init__(self, in_channels, up_scale):
super(UpsampleBLock, self).__init__()
self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)
self.pixel_shuffle = nn.PixelShuffle(up_scale)
self.prelu = nn.PReLU(in_channels)
def forward(self, x):
x = self.conv(x)
x = self.pixel_shuffle(x)
x = self.prelu(x)
return x
class Generator(nn.Module):
def __init__(self, scale_factor, num_residual=16):
upsample_block_num = int(math.log(scale_factor, 2))
super(Generator, self).__init__()
self.block_in = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=9, padding=4),
nn.PReLU(64)
)
self.blocks = []
for _ in range(num_residual):
self.blocks.append(ResidualBlock(64))
self.blocks = nn.Sequential(*self.blocks)
self.block_out = nn.Sequential(
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64)
)
self.upsample = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]
self.upsample.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
self.upsample = nn.Sequential(*self.upsample)
def forward(self, x):
x = self.block_in(x)
short_cut = x
x = self.blocks(x)
x = self.block_out(x)
upsample = self.upsample(x + short_cut)
return torch.tanh(upsample)
三、判别网络的构建
判别网络的构成如上图所示:
SRGAN的判别网络由不断重复的 卷积+LeakyRELU和标准化 组成。
对于判断网络来讲,它的目的是判断输入图片的真假,它的输入是图片,输出是判断结果。
判断结果处于0-1之间,利用接近1代表判断为真图片,接近0代表判断为假图片。
判断网络的构建和普通卷积网络差距不大,都是不断的卷积对图片进行下采用,在多次卷积后,最终接一次全连接判断结果。
实现代码如下:
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.net = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3, padding=1),
nn.LeakyReLU(0.2),
nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2),
nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2),
nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2),
nn.Conv2d(128, 256, kernel_size=3, padding=1),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2),
nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.2),
nn.Conv2d(256, 512, kernel_size=3, padding=1),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2),
nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.2),
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(512, 1024, kernel_size=1),
nn.LeakyReLU(0.2),
nn.Conv2d(1024, 1, kernel_size=1)
)
def forward(self, x):
batch_size = x.size(0)
return torch.sigmoid(self.net(x).view(batch_size))
训练思路
SRGAN的训练可以分为生成器训练和判别器训练:
每一个step中一般先训练判别器,然后训练生成器。
一、判别器的训练
在训练判别器的时候我们希望判别器可以判断输入图片的真伪,因此我们的输入就是真图片、假图片和它们对应的标签。
因此判别器的训练步骤如下:
1、随机选取batch_size个真实高分辨率图片。
2、利用resize后的低分辨率图片,传入到Generator中生成batch_size个虚假高分辨率图片。
3、真实图片的label为1,虚假图片的label为0,将真实图片和虚假图片当作训练集传入到Discriminator中进行训练。
二、生成器的训练
在训练生成器的时候我们希望生成器可以生成极为真实的假图片。因此我们在训练生成器需要知道判别器认为什么图片是真图片。
因此生成器的训练步骤如下:
1、将低分辨率图像传入生成模型,得到虚假高分辨率图像,将虚假高分辨率图像获得判别结果与1进行对比得到loss。(与1对比的意思是,让生成器根据判别器判别的结果进行训练)。
2、将真实高分辨率图像和虚假高分辨率图像传入VGG网络,获得两个图像的特征,通过这两个图像的特征进行比较获得loss
利用SRGAN生成图片
SRGAN的库整体结构如下:
一、数据集的准备
在训练前需要准备好数据集,数据集保存在datasets文件夹里面。
二、数据集的处理
打开txt_annotation.py,默认指向根目录下的datasets。运行txt_annotation.py。
此时生成根目录下面的train_lines.txt。
三、模型训练
在完成数据集处理后,运行train.py即可开始训练。
训练过程中,可在results文件夹内查看训练效果:
来源:https://blog.csdn.net/weixin_44791964/article/details/121628982
猜你喜欢
- 问题: pydev使用wx库开发的过程中,import时碰到wx可以识别,但是其它很多函数和变量上面全部是红叉,即无法识别。 解决方法: 1
- 怎样解决MySQL 5 0 16的乱码问题? 本文给出了解决方法:问:怎样解决MySQL 5.0.16的乱码问题?答:MySQL 5.0.1
- 讲解1、库:os,shutil.copy2、代码效果:对指定文件夹内文件等量分配到新的文件夹3、代码原理:用os.listdir()遍历文件
- 在settings.py里,配置如下logging:LOGGING = { 'version': 1, 'disab
- MVVM模式不但可用于Form表单,在复杂的管理页面中也能大显身手。例如,分页显示Blog的功能,我们先把后端代码写出来:在apis.py中
- ActiveMQ是java开发的消息中间件服务。可以支持多种协议(AMQP,MQTT,OpenWire,Stomp),默认的是OpenWir
- 今天研究了一下JS的用setAttribute方法实现一个页面两份样式表的效果,具体方法如下:第一步:在连接样式表的元素里定义一个id,例如
- Python中会遇到很多关于排序的问题,今天小编就带给大家实现插入排序的方法。在Python中插入排序的基本原理类似于摸牌,将摸起来的牌插入
- 方法一、尽量使用复杂的SQL来代替简单的一堆 SQL.同样的事务,一个复杂的SQL完成的效率高于一堆简单SQL完成的效率。有多个查询时,要善
- 字符串多级目录取值:比如说:你response接收到的数据是这样的。你现在只需要取到itemstring 这个字段下的值。其他的都不要!思路
- 十要:第一:要认真规划和分析。这是网页设计灵魂工作。创建站点之前,要明确你的网站主要针对哪些访问者,为哪些用户服务,要把握准主页题材第二:网
- 什么是F型浏览?2006年4月,美国长期研究网站可用性的著名网站设计师杰柯柏·尼尔森(Jakob Nielsen)发表了一项《眼球轨迹的研究
- 前言随着人工智能的日益火热,计算机视觉领域发展迅速,尤其在人脸识别或物体检测方向更为广泛,今天就为大家带来最基础的人脸识别基础,从一个个函数
- 前言最近参加了datawhale的组队学习活动,在组队学习动员下,开始通过强迫自己输出来实现更好的输入与处理,6-15开始自己的第一次文章发
- 程序设计是困难的,其核心是管理的复杂性。计算机程序是人类做出的最复杂的东西。质量是不可靠的且隐蔽的。好的体系架构是必需给程序足够的结构使其健
- function checkPhoto(fnUpload) { var filename = fnUpload.value; alert(f
- 废话少说,直接上SQL代码(有兴趣的测试验证一下),下面这个查询语句为什么将2008-11-27的记录查询出来了呢?这个是同事遇到的一个问题
- 本文介绍了用ASP的AdoDb.Stream读取/写入UTF-8编码格式的文件的方法:函数名称:ReadTextFile 作用:利用AdoD
- 一、Lambda表达式Lambda表达式又被称之为匿名函数格式lambda 参数列表:函数体def add(x,y): return x+y
- 本节我们来介绍一下新浪微博宫格验证码的识别,此验证码是一种新型交互式验证码,每个宫格之间会有一条指示连线,指示了我们应该的滑动轨迹,我们需要