python中的Pytorch建模流程汇总
作者:Python学习与数据挖掘 发布时间:2022-04-26 19:02:14
标签:Pytorch,建模,流程,python
本节内容学习帮助大家梳理神经网络训练的架构。
一般我们训练神经网络有以下步骤:
导入库
设置训练参数的初始值
导入数据集并制作数据集
定义神经网络架构
定义训练流程
训练模型
推荐文章:
python实现可视化大屏
分享4款 Python 自动数据分析神器
以下,我就将上述步骤使用代码进行注释讲解:
1 导入库
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim
from torch.utils.data import DataLoader, DataLoader
import torchvision
import torchvision.transforms as transforms
2 设置初始值
# 学习率
lr = 0.15
# 优化算法参数
gamma = 0.8
# 每次小批次训练个数
bs = 128
# 整体数据循环次数
epochs = 10
3 导入并制作数据集
本次我们使用FashionMNIST
图像数据集,每个图像是一个28*28的像素数组,共有10个衣物类别,比如连衣裙、运动鞋、包等。
注:初次运行下载需要等待较长时间。
# 导入数据集
mnist = torchvision.datasets.FashionMNIST(
root = './Datastes'
, train = True
, download = True
, transform = transforms.ToTensor())
# 制作数据集
batchdata = DataLoader(mnist
, batch_size = bs
, shuffle = True
, drop_last = False)
我们可以对数据进行检查:
for x, y in batchdata:
print(x.shape)
print(y.shape)
break
# torch.Size([128, 1, 28, 28])
# torch.Size([128])
可以看到一个batch
中有128个样本,每个样本的维度是1*28*28。
之后我们确定模型的输入维度与输出维度:
# 输入的维度
input_ = mnist.data[0].numel()
# 784
# 输出的维度
output_ = len(mnist.targets.unique())
# 10
4 定义神经网络架构
先使用一个128个神经元的全连接层,然后用relu激活函数,再将其结果映射到标签的维度,并使用softmax
进行激活。
# 定义神经网络架构
class Model(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.linear1 = nn.Linear(in_features, 128, bias = True)
self.output = nn.Linear(128, out_features, bias = True)
def forward(self, x):
x = x.view(-1, 28*28)
sigma1 = torch.relu(self.linear1(x))
sigma2 = F.log_softmax(self.output(sigma1), dim = -1)
return sigma2
5 定义训练流程
在实际应用中,我们一般会将训练模型部分封装成一个函数,而这个函数可以继续细分为以下几步:
定义损失函数与优化器
完成向前传播
计算损失
反向传播
梯度更新
梯度清零
在此六步核心操作的基础上,我们通常还需要对模型的训练进度、损失值与准确度进行监视。
注释代码如下:
# 封装训练模型的函数
def fit(net, batchdata, lr, gamma, epochs):
# 参数:模型架构、数据、学习率、优化算法参数、遍历数据次数
# 5.1 定义损失函数
criterion = nn.NLLLoss()
# 5.1 定义优化算法
opt = optim.SGD(net.parameters(), lr = lr, momentum = gamma)
# 监视进度:循环之前,一个样本都没有看过
samples = 0
# 监视准确度:循环之前,预测正确的个数为0
corrects = 0
# 全数据训练几次
for epoch in range(epochs):
# 对每个batch进行训练
for batch_idx, (x, y) in enumerate(batchdata):
# 保险起见,将标签转为1维,与样本对齐
y = y.view(x.shape[0])
# 5.2 正向传播
sigma = net.forward(x)
# 5.3 计算损失
loss = criterion(sigma, y)
# 5.4 反向传播
loss.backward()
# 5.5 更新梯度
opt.step()
# 5.6 梯度清零
opt.zero_grad()
# 监视进度:每训练一个batch,模型见过的数据就会增加x.shape[0]
samples += x.shape[0]
# 求解准确度:全部判断正确的样本量/已经看过的总样本量
# 得到预测标签
yhat = torch.max(sigma, -1)[1]
# 将正确的加起来
corrects += torch.sum(yhat == y)
# 每200个batch和最后结束时,打印模型的进度
if (batch_idx + 1) % 200 == 0 or batch_idx == (len(batchdata) - 1):
# 监督模型进度
print("Epoch{}:[{}/{} {: .0f}%], Loss:{:.6f}, Accuracy:{:.6f}".format(
epoch + 1
, samples
, epochs*len(batchdata.dataset)
, 100*samples/(epochs*len(batchdata.dataset))
, loss.data.item()
, float(100.0*corrects/samples)))
6 训练模型
# 设置随机种子
torch.manual_seed(51)
# 实例化模型
net = Model(input_, output_)
# 训练模型
fit(net, batchdata, lr, gamma, epochs)
# Epoch1:[25600/600000 4%], Loss:0.524430, Accuracy:69.570312
# Epoch1:[51200/600000 9%], Loss:0.363422, Accuracy:74.984375
# ......
# Epoch10:[600000/600000 100%], Loss:0.284664, Accuracy:85.771835
现在我们已经用Pytorch
训练了最基础的神经网络,并且可以查看其训练成果。大家可以将代码复制进行运行!
虽然没有用到复杂的模型,但是我们在每次建模时的基本思想都是一致的
来源:https://blog.csdn.net/weixin_38037405/article/details/123157702


猜你喜欢
- 我的目标是写一个非常详细的关于diff的干货,所以本文有点长。也会用到大量的图片以及代码举例,目的让看这篇文章的朋友一定弄明白diff的边边
- 这段时间看了关于在SQL server 中通过日志和时间点来恢复数据。也看了一些网上的例子,看如何通过日志来恢复数据。 前提条件:数据库的故
- 1.锦短情长为什么选择这个标题,借鉴了一封情书里面的情长纸短,还吻你万千。锦短情长都只谓人走茶凉,怎感觉锦短情长?一提起眼泪汪汪,是明月人心
- SQL Server 阻止了对组件 'Ad Hoc Distributed&nbs
- 在我们日常上网浏览网页的时候,经常会看到一些好看的图片,我们就希望把这些图片保存下载,或者用户用来做桌面壁纸,或者用来做设计的素材。我们最常
- 下面的request.servervariables例子都是服务器探针采用的asp代码本机ip:<%=request.serverva
- 一、弹窗事件是什么?弹窗事件就是在我们执行某操作的时候,弹出信息框给出提示。或收集数据的时候,弹出窗口收集信息,不想收集可以取消隐藏。二、简
- 本文实例讲述了Javascript与PHP验证用户输入URL地址是否正确的方法,分享给大家供大家参考。具体方法如下:1.javascript
- 听歌识曲,顾名思义,用设备“听”歌曲,然后它要告诉你这是首什么歌。而且十之八九它还得把这首歌给你播放出来。这样的功能在QQ音乐等应用上早就出
- 为了方便例子讲解,现有数组和json对象如下var demoArr = ['Javascript', 'Gulp
- vue-element-admin导入组件封装模板和样式首先封装一个类似的组件,首先需要注意的是,类似功能,vue-element-admi
- 问题:SQL Server 2000中设计表时如何得到自动编号字段?解答:具体步骤如下:①像Access中的自动编号字段右键你的表-->
- 本文目标:使用selenium3.0+python3操纵浏览器,打开百度网站。(相当于selenium的hello world)环境基础:p
- 方法一:使用临时表。首先创建一个与sp_who相同字段的临时,然后用insert into 方法赋值,这样就可以select这个临时表了。具
- QThread是Qt的线程类中最核心的底层类。由于PyQt的的跨平台特性,QThread要隐藏所有与平台相关的代码要使用的QThread开始
- 首页url与视图函数的映射是通过@app.route()装饰器实现的。只有一个斜杠代表的是根目录——
- 如提取第1行,第2列的值:df.iloc[[0],[1]]则会返回一个df,即有字段名和行号。如果用values属性取值:df.iloc[[
- 目录现象根因分析getLastPacketReceivedTimeMs()方法调用时机解决方案现象应用升级MySQL驱动8.0后,在并发量较
- 很多人在使用AJAX调用别人站点内容的时候,JS会提示"没有权限"错误,这是XMLHTTP组件的限制-安全起见禁止访问非
- 我是从去年初开始学习web标准的,两年下来也有些心得。最近跳槽了正好闲在家里,写一些出来和大家交流一下。1对于web标准和W3C XHTML