pytorch从头开始搭建UNet++的过程详解
作者:楚楚小甜心 发布时间:2023-03-11 09:19:24
Unet是一个最近比较火的网络结构。它的理论已经有很多大佬在讨论了。本文主要从实际操作的层面,讲解pytorch从头开始搭建UNet++的过程。
Unet++代码
网络架构
黑色部分是Backbone,是原先的UNet。
绿色箭头为上采样,蓝色箭头为密集跳跃连接。
绿色的模块为密集连接块,是经过左边两个部分拼接操作后组成的
Backbone
2个3x3的卷积,padding=1。
class VGGBlock(nn.Module):
def __init__(self, in_channels, middle_channels, out_channels):
super().__init__()
self.relu = nn.ReLU(inplace=True)
self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
self.bn1 = nn.BatchNorm2d(middle_channels)
self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)
def forward(self, x):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
return out
上采样
图中的绿色箭头,上采样使用双线性插值。
双线性插值就是有两个变量的插值函数的线性插值扩展,其核心思想是在两个方向分别进行一次线性插值
torch.nn.Upsample(size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None)
参数说明:
①size:可以用来指定输出空间的大小,默认是None;
②scale_factor:比例因子,比如scale_factor=2意味着将输入图像上采样2倍,默认是None;
③mode:用来指定上采样算法,有’nearest’、 ‘linear’、‘bilinear’、‘bicubic’、‘trilinear’,默认是’nearest’。上采样算法在本文中会有详细理论进行讲解;
④align_corners:如果True,输入和输出张量的角像素对齐,从而保留这些像素的值,默认是False。此处True和False的区别本文中会有详细的理论讲解;
⑤recompute_scale_factor:如果recompute_scale_factor是True,则必须传入scale_factor并且scale_factor用于计算输出大小。计算出的输出大小将用于推断插值的新比例。请注意,当scale_factor为浮点数时,由于舍入和精度问题,它可能与重新计算的scale_factor不同。如果recompute_scale_factor是False,那么size或scale_factor将直接用于插值。
class Up(nn.Module):
def __init__(self):
super().__init__()
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW
diffY = torch.tensor([x2.size()[2] - x1.size()[2]])
diffX = torch.tensor([x2.size()[3] - x1.size()[3]])
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
x = torch.cat([x2, x1], dim=1)
return x
下采样
图中的黑色箭头,采用的是最大池化。
self.pool = nn.MaxPool2d(2, 2)
深度监督
所示,该结构下有4个分支,可以分为两种模式。
精确模式:4个分支取平均值结果
快速模式:只选择一个分支,其余被剪枝
if self.deep_supervision:
output1 = self.final1(x0_1)
output2 = self.final2(x0_2)
output3 = self.final3(x0_3)
output4 = self.final4(x0_4)
return [output1, output2, output3, output4]
else:
output = self.final(x0_4)
return output
网络架构代码
class NestedUNet(nn.Module):
def __init__(self, num_classes=1, input_channels=1, deep_supervision=False, **kwargs):
super().__init__()
nb_filter = [32, 64, 128, 256, 512]
self.deep_supervision = deep_supervision
self.pool = nn.MaxPool2d(2, 2)
self.up = Up()
self.conv0_0 = VGGBlock(input_channels, nb_filter[0], nb_filter[0])
self.conv1_0 = VGGBlock(nb_filter[0], nb_filter[1], nb_filter[1])
self.conv2_0 = VGGBlock(nb_filter[1], nb_filter[2], nb_filter[2])
self.conv3_0 = VGGBlock(nb_filter[2], nb_filter[3], nb_filter[3])
self.conv4_0 = VGGBlock(nb_filter[3], nb_filter[4], nb_filter[4])
self.conv0_1 = VGGBlock(nb_filter[0]+nb_filter[1], nb_filter[0], nb_filter[0])
self.conv1_1 = VGGBlock(nb_filter[1]+nb_filter[2], nb_filter[1], nb_filter[1])
self.conv2_1 = VGGBlock(nb_filter[2]+nb_filter[3], nb_filter[2], nb_filter[2])
self.conv3_1 = VGGBlock(nb_filter[3]+nb_filter[4], nb_filter[3], nb_filter[3])
self.conv0_2 = VGGBlock(nb_filter[0]*2+nb_filter[1], nb_filter[0], nb_filter[0])
self.conv1_2 = VGGBlock(nb_filter[1]*2+nb_filter[2], nb_filter[1], nb_filter[1])
self.conv2_2 = VGGBlock(nb_filter[2]*2+nb_filter[3], nb_filter[2], nb_filter[2])
self.conv0_3 = VGGBlock(nb_filter[0]*3+nb_filter[1], nb_filter[0], nb_filter[0])
self.conv1_3 = VGGBlock(nb_filter[1]*3+nb_filter[2], nb_filter[1], nb_filter[1])
self.conv0_4 = VGGBlock(nb_filter[0]*4+nb_filter[1], nb_filter[0], nb_filter[0])
if self.deep_supervision:
self.final1 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
self.final2 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
self.final3 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
self.final4 = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
else:
self.final = nn.Conv2d(nb_filter[0], num_classes, kernel_size=1)
def forward(self, input):
x0_0 = self.conv0_0(input)
x1_0 = self.conv1_0(self.pool(x0_0))
x0_1 = self.conv0_1(self.up(x1_0, x0_0))
x2_0 = self.conv2_0(self.pool(x1_0))
x1_1 = self.conv1_1(self.up(x2_0, x1_0))
x0_2 = self.conv0_2(self.up(x1_1, torch.cat([x0_0, x0_1], 1)))
x3_0 = self.conv3_0(self.pool(x2_0))
x2_1 = self.conv2_1(self.up(x3_0, x2_0))
x1_2 = self.conv1_2(self.up(x2_1, torch.cat([x1_0, x1_1], 1)))
x0_3 = self.conv0_3(self.up(x1_2, torch.cat([x0_0, x0_1, x0_2], 1)))
x4_0 = self.conv4_0(self.pool(x3_0))
x3_1 = self.conv3_1(self.up(x4_0, x3_0))
x2_2 = self.conv2_2(self.up(x3_1, torch.cat([x2_0, x2_1], 1)))
x1_3 = self.conv1_3(self.up(x2_2, torch.cat([x1_0, x1_1, x1_2], 1)))
x0_4 = self.conv0_4(self.up(x1_3, torch.cat([x0_0, x0_1, x0_2, x0_3], 1)))
if self.deep_supervision:
output1 = self.final1(x0_1)
output2 = self.final2(x0_2)
output3 = self.final3(x0_3)
output4 = self.final4(x0_4)
return [output1, output2, output3, output4]
else:
output = self.final(x0_4)
return output
来源:https://blog.csdn.net/qq128252/article/details/127610581


猜你喜欢
- 项目使用Pyqt作为UI框架,使用相机线程捕捉image,并在QGraphicsView中显示,遇到以下问题:1、采集的数据为nparray
- 1、模拟退火算法退火是金属从熔融状态缓慢冷却、最终达到能量最低的平衡态的过程。模拟退火算法基于优化问题求解过程与金属退火过程的相似性,以优化
- 很久之前,分享过一次Python代码实现验证码识别的办法。当时采用的是pillow+pytesseract,优点是免费,较为易用。但其识别精
- Runtime包GOMAXPROCS()用来设置可以并行计算的CPU核数最大值,并返回之前的值,具体使用方法上一篇有些,这里不再赘述Gosc
- 一、在 VS Code 中配置调试使用 Vue CLI 2搭建项目时:更新 config/index.js 内的 devtool prope
- 使用perl连接mysql,这个网上有很多案例了,一般大家都是DBI下的DBD::MySQL这个模块进行.这里做一个mask弄一个TIPS:
- 创建用户定义函数,它是返回值的已保存的 Transact-SQL 例程。用户定义函数不能用于执行一组修改全局数据库状态的操作。与系统函数一样
- 本文介绍Python3使用PyMySQL连接数据库,并实现简单的增删改查。什么是PyMySQL?PyMySQL是Python3.x版本中用于
- 分页查询是经常能够遇到的问题,我们首先看看分页查询存在的理由:方便用户:用户不可能一次察看所有数据,所以一页一页的翻看比较好。提高性能:一次
- 一、简介从Python2.6开始,新增了str.format(),它增强了字符串格式化的功能。基本语法是通过 {} 和 : 来代替以前的 %
- 看代码吧~// Strval 获取变量的字符串值// 浮点型 3.0将会转换成字符串3, "3"// 非数值或字符类型的
- 本来想把之前对artTemplate源码解析的注释放上来分享下,不过隔了一年,找不到了,只好把当时分析模板引擎原理后,自己尝试写下的模板引擎
- Python读取配置文件-ConfigParser二次封装直接上上代码test.conf[database]connect = mysqls
- 1、简介MySQL是关系型数据库,我们在使用的时候往往会将对象的属性映射成列存储在表中,因此查询的到的结果在不做任何处理的情况下,也是一个个
- 由于工作关系,只能暂时放弃对mongodb的研究了 .开始研究PHPcms .目前为止我已经基本完成了模块的开发.趁着周末来这里做个总结.我
- 前提搭建钉钉应答机器人,需要先准备或拥有以下权限:钉钉企业的管理员或子管理员(如果不是企业管理员,可以自己创建一个企业,很方便的)有公网通信
- 一、环境要求windows系统,python3.6+安装模块pip install pyqt5pip install pygame二、游戏介
- 实现效果:input未输入值,按钮禁用jquery操作代码:html<input type="text" name
- Ubuntu18.04安装mysql5.7,供大家参考,具体内容如下1.1安装首先执行下面三条命令:# 安装mysql服务sudo apt-
- 安装pip install requests发送网络请求import requestsr=requests.get('http://