解决pytorch 模型复制的一些问题
作者:冰菓(笑) 发布时间:2022-04-23 03:57:58
直接使用
model2=model1
会出现当更新model2时,model1的权重也会更新,这和自己的初始目的不同。
经评论指出可以使用:
model2=copy.deepcopy(model1)
来实现深拷贝,手上没有pytorch环境,具体还没测试过,谁测试过可以和我说下有没有用。
原方法:
所有要使用模型复制可以使用如下方法。
torch.save(model, "net_params.pkl")
model5=Cnn(3,10)
model5=torch.load('net_params.pkl')
这样编写不会影响原始模型的权重
补充:pytorch模型训练流程中遇到的一些坑(持续更新)
要训练一个模型,主要分成几个部分,如下。
数据预处理
入门的话肯定是拿 MNIST 手写数据集先练习。
pytorch 中有帮助我们制作数据生成器的模块,其中有 Dataset、TensorDataset、DataLoader 等类可以来创建数据入口。
之前在 tensorflow 中可以用 dataset.from_generator() 的形式,pytorch 中也类似,目前我了解到的有两种方法可以实现。
第一种就继承 pytorch 定义的 dataset,改写其中的方法即可。如下,就获得了一个 DataLoader 生成器。
class MyDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __getitem__(self, index):
return self.data[index], self.labels[index]
def __len__(self):
return len(self.labels)
train_dataset = MyDataset(train_data, train_label)
train_loader = DataLoader(dataset = train_dataset,
batch_size = 1,
shuffle = True)
第二种就是转换,先把我们准备好的数据转化成 pytorch 的变量(或者是 Tensor),然后传入 TensorDataset,再构造 DataLoader。
X = torch.from_numpy(train_data).float()
Y = torch.from_numpy(train_label).float()
train_dataset = TensorDataset(X, Y)
train_loader = DataLoader(dataset = train_dataset,
batch_size = 1,
shuffle = True)
#num_workers = 2)
模型定义
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 6, 3)
self.conv2 = nn.Conv2d(6 ,16, 3)
self.fc1 = nn.Linear(400, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
relu = F.relu(self.conv1(x))
x = F.max_pool2d(relu, (2, 2))
x = F.max_pool2d(F.relu(self.conv2(x)), 2)
x = x.view(-1, self.num_flat_features(x))
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
def num_flat_features(self, x):
size = x.size()[1:] #除了batch_size之外的维度
num_features = 1
for s in size:
num_features *= s
return num_features
训练模型那么肯定要先定义一个网络结构,如上定义一个前向传播网络。里面包含了卷积层、全连接层、最大池化层和 relu 非线性激活层(名字我自己取的)以及一个 view 展开,把一个多维的特征图平展成一维的。
其中nn.Conv2d(in_channels, out_channels, kernel_size),第一个参数是输入的深度,第二是输出的深度,第三是卷积核的尺寸。
F.max_pool2d(input, (pool_size, pool_size)),第二个参数是池话
nn.Linear(in_features, out_features)
x.view是平展的操作,不过实际上相当于 numpy 的 reshape,需要计算转换后的尺寸。
损失函数定义
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
模型定义完之后,意味着给出输入,就可以得到输出的结果。那么就来比较 outputs 和 targets 之间的区别,那么就需要用到损失函数来描述。
训练网络
for epoch in range(2): # loop over the dataset multiple times
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
# get the inputs; data is a list of [inputs, labels]
inputs, labels = data
# zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# print statistics
running_loss += loss.item()
if i % 2000 == 1999: # print every 2000 mini-batches
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0
print('Finished Training')
以上的代码是官方教程中给出来的,我们要做的就是学习他的思路。
1.首先是 epoch 的数量为 2,每个 epoch 都会历遍一次整个训练集。在每个 epoch 内累积统计 running_loss,每 2000 个 batch 数据计算一次损失的平均值,然后 print 再重新将 running_loss 置为 0。
2.然后分 mini-batch 进行训练,在每个计算每个 mini-batch 的损失之前,都会将优化器 optimizer 中的梯度清空,防止不同 mini-batch 的梯度被累加到一起。更新分成两步:第一步计算损失函数,然后把总的损失分配到各个层中,即 loss.backward(),然后就使用优化器更新权重,即 optimizer.step()。
保存模型
PATH = '...'
torch.save(net.state_dict(), PATH)
爬坑总结
总的来说流程就是上面那几步,但自己做的时候就遇到了挺多问题,最主要是对于其中张量传播过程中的要求不清楚,导致出了不少错误。
首先是输入的数据,pytorch 默认图片的 batch 数据的结构是(BATCH_SIZE, CHANNELS, IMG_H, IMG_W),所以要在生成数据时做一些调整,满足这种 BCHW 的规则。
会经常出现一些某个矩阵或者张量要求的数据,例如 “RuntimeError: Expected object of scalar type Double but got scalar type Float for argument #2 ‘mat2'” 等错误信息。
可以使用 x.double(),y.float(),z.long() 等方式转换成他要求的格式。
RuntimeError: multi-target not supported。这个错误出现在损失函数那个地方,对于分类问题肯定是优先考虑交叉熵。
criterion = nn.CrossEntropyLoss()
loss = criterion(outputs, labels.long())#报错的地方
当我batch-size=1时这个地方不会报错,但是当batch-size>1时就会报错。
查了别人的代码,大家基本都是和官方教程里面写的一样,使用官方的 mnist 数据接口,代码如下。一开始我是不愿意的,因为那样子意味着可能数据格式被封装起来看不见,但是自己折腾成本比较高,所以还是试了,真香!
train_dataset = datasets.MNIST(root='./data/',
train=True,
transform=transforms.ToTensor(),
download=True)
train_loader = DataLoader(dataset = train_dataset,
batch_size = 4,
shuffle = True)
打印了一下从生成器中获得数据,看一下 size,发现果然和我自己写的不同。当 batch_size=4 时,数据 data.size() 都是4*1*28*28,这个是相同的;但是 labels.size() 是不同的,我写的是 one_hot 向量所以是 4*10,但它的是 4。
直接打印 labels 看看,果然,是单个指,例如 tensor([3, 2, 6, 2]) 这样。
不过模型的 outputs 依然是 4*10,看来是 nn.CrossEntropyLoss() 这个函数自己会做计算,所以他才会报错说 multi-target not supported,因为 lables.size() 不对,原本只有一个数字,但现在是10个数字,相当于被分配了10个属性,自然就报错啦。
所以稍微修改了自己写的生成器之后,就没问题了。
不过,如果想要更自由的调用数据,还是需要对对象进行一些方法的重载,使用 pytoch 定义的 DataLoader,用 enumerate,就会把所有的数据历遍一次,如果使用 iter() 得到一个可迭代对象之后 next(),并不可以像 tensorflow 那样子生成训练数据。
例如说,如果使用如上的形式,DataLoader 得到的是一个生成器,python 中的生成器对象主要有 __next__ 和 __iter__ 等魔术方法决定。
__iter__ 方法使得实例可以如下调用,可以得到一个可迭代对象,iterable,但是如果不加也没关系,因为更重要的是 __next__ 类方法。
如下自己写了 __next__ 方法之后就可以看到,原本会出现越界的现象不见了,可以循环的历遍数据,当然也可以想被注释的那部分一样,抛出 StopIteration 来终止。
a = A()
a_iter = iter(a)
class A():
def __init__(self):
self.list = [1,2,3]
self.index = 0
#def __getitem__(self, index):
# return self.list[i]
#def __iter__(self):
# return self
def __next__(self):
#for i in range():
if self.index >= len(self.list):
#raise StopIteration
self.index = self.index%len(self.list)
result = self.list[self.index]
self.index += 1
return result
b = A()
for i in range(20):
print(next(b))
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。
来源:https://blog.csdn.net/a362682954/article/details/82693330
猜你喜欢
- 如你所见,功能很简单。只有基本的播放,停止,甚至只针对一首歌曲,仅供初学者参考学习用。代码from tkinter import *from
- 什么是机器学习 (Machine Learning) 机器学习是研究计算机怎样模
- 前言:文章利用Python pygame做一个贪吃蛇的小游戏而且讲清楚每一段代码是用来干嘛的。据说是贪吃蛇游戏是1976年,Gremlin公
- 以下保存成 App.xml , 与asp文件放在相同目录下! 代码如下: <?xml version="1.0"
- 用面向对象的思维解决问题的重点当遇到一个需求的时候不用自己去实现,如果自己一步步实现那就是面向过程;应该找一个专门做这个事的人来做。面向对象
- 刚接触 Go 语言时,就听说有一个叫rune的数据类型,即使查阅过一些资料,对它的理解依旧比较模糊,加之对陌生事物的天然排斥,在之后很长一段
- 在现在的项目里,不管是电商项目还是别的项目,在管理端都会有导出的功能,比方说订单表导出,用户表导出,业绩表导出。这些都需要提前生成excel
- 目录 一、环境配置 二、ASP对Excel的基本操作 三、ASP操作Excel生成数据表 四、ASP操作Excel生成Chart图 五、服务
- 本文实例为大家分享了python代码实现猜拳小游戏的具体代码,供大家参考,具体内容如下游戏实现具体功能原有的用户登录的信息均能保存在txt文
- 最近看Python看得都不用tab键了,哈哈。今天看了一个经典问题--八皇后问题,说实话,以前学C、C++的时候有这个问题,但是当时不爱学,
- 一、selenium实战这里我们只会用到很少的selenium语法,我这里就不补充别的用法了,以实战为目的二、打开艺龙网可以直接点击这里进入
- 查看安装的python版本号可以使用【python --version】命令。具体方法:首先按【win+r】组合键打开运行;然后输入cmd,
- 一、base64模块base64模块提供了在二进制数据和可打印ASCII字符间编解码的功能,包括 RFC3548中定义的Base16, Ba
- 平时经常看php的错误日志,很少有机会去自己动手写日志,看了王健的《最佳日志实践》觉得写一个清晰明了,结构分明的日志还是非常有必要的。在写日
- 在网页中,我们经常需要引用大量的javascript和css文件,在加上许多javascript库都包含debug版和经过压缩的releas
- 本文实例为大家分享了Django文件上传与下载的具体代码,供大家参考,具体内容如下文件上传1.新建django项目,创建应用stu: pyt
- Pygame的Draw绘图Pygame 中提供了一个draw模块用来绘制一些简单的图形状,比如矩形、多边形、圆形、直线、弧线等。pygame
- xhEditor简介xhEditor是一个基于jQuery开发的简单迷你并且高效的可视化HTML编辑器,基于网络访问并且兼容IE 6.0+,
- 用户管理是绝大部分Web网站都需要解决的问题。用户管理涉及到用户注册和登录。用户注册相对简单,我们可以先通过API把用户注册这个功能实现了:
- 最近写一个小爬虫,需要拿到邮箱信息,发现拿不到,也不是ajax接口。最后查资料发现是被Cloudflare加密起来了,有加密肯定有解密。通过