Pytorch 神经网络—自定义数据集上实现教程
作者:LZDCQU 发布时间:2022-11-30 20:05:04
标签:Pytorch,神经网络,自定义,数据集
第一步、导入需要的包
import os
import scipy.io as sio
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.autograd import Variable
batchSize = 128 # batchsize的大小
niter = 10 # epoch的最大值
第二步、构建神经网络
设神经网络为如上图所示,输入层4个神经元,两层隐含层各4个神经元,输出层一个神经。每一层网络所做的都是线性变换,即y=W×X+b;代码实现如下:
class Neuralnetwork(nn.Module):
def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
super(Neuralnetwork, self).__init__()
self.layer1 = nn.Linear(in_dim, n_hidden_1)
self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
self.layer3 = nn.Linear(n_hidden_2, out_dim)
def forward(self, x):
x = x.view(x.size(0), -1)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
return x
model = Neuralnetwork(1*3, 4, 4, 1)
print(model) # net architecture
Neuralnetwork(
(layer1): Linear(in_features=3, out_features=4, bias=True)
(layer2): Linear(in_features=4, out_features=4, bias=True)
(layer3): Linear(in_features=4, out_features=1, bias=True)
)
第三步、读取数据
自定义的数据为demo_SBPFea.mat,是MATLAB保存的数据格式,其存储的内容如下:包括fea(1000*3)和sbp(1000*1)两个数组;fea为特征向量,行为样本数,列为特征宽度;sbp为标签
class SBPEstimateDataset(Dataset):
def __init__(self, ext='demo'):
data = sio.loadmat(ext+'_SBPFea.mat')
self.fea = data['fea']
self.sbp = data['sbp']
def __len__(self):
return len(self.sbp)
def __getitem__(self, idx):
fea = self.fea[idx]
sbp = self.sbp[idx]
"""Convert ndarrays to Tensors."""
return {'fea': torch.from_numpy(fea).float(),
'sbp': torch.from_numpy(sbp).float()
}
train_dataset = SBPEstimateDataset(ext='demo')
train_loader = DataLoader(train_dataset, batch_size=batchSize, # 分批次训练
shuffle=True, num_workers=int(8))
整个数据样本为1000,以batchSize = 128划分,分为8份,前7份为104个样本,第8份则为104个样本。在网络训练过程中,是一份数据一份数据进行训练的
第四步、模型训练
# 优化器,Adam
optimizer = optim.Adam(list(model.parameters()), lr=0.0001, betas=(0.9, 0.999),weight_decay=0.004)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.997)
criterion = nn.MSELoss() # loss function
if torch.cuda.is_available(): # 有GPU,则用GPU计算
model.cuda()
criterion.cuda()
for epoch in range(niter):
losses = []
ERROR_Train = []
model.train()
for i, data in enumerate(train_loader, 0):
model.zero_grad()# 首先提取清零
real_cpu, label_cpu = data['fea'], data['sbp']
if torch.cuda.is_available():# CUDA可用情况下,将Tensor 在GPU上运行
real_cpu = real_cpu.cuda()
label_cpu = label_cpu.cuda()
input=real_cpu
label=label_cpu
inputv = Variable(input)
labelv = Variable(label)
output = model(inputv)
err = criterion(output, labelv)
err.backward()
optimizer.step()
losses.append(err.data[0])
error = output.data-label+ 1e-12
ERROR_Train.extend(error)
MAE = np.average(np.abs(np.array(ERROR_Train)))
ME = np.average(np.array(ERROR_Train))
STD = np.std(np.array(ERROR_Train))
print('[%d/%d] Loss: %.4f MAE: %.4f Mean Error: %.4f STD: %.4f' % (
epoch, niter, np.average(losses), MAE, ME, STD))
[0/10] Loss: 18384.6699 MAE: 135.3871 Mean Error: -135.3871 STD: 7.5580
[1/10] Loss: 17063.0215 MAE: 130.4145 Mean Error: -130.4145 STD: 7.8918
[2/10] Loss: 13689.1934 MAE: 116.6625 Mean Error: -116.6625 STD: 9.7946
[3/10] Loss: 8192.9053 MAE: 89.6611 Mean Error: -89.6611 STD: 12.9911
[4/10] Loss: 2979.1340 MAE: 52.5410 Mean Error: -52.5279 STD: 15.0930
[5/10] Loss: 599.7094 MAE: 22.2735 Mean Error: -19.9979 STD: 14.2069
[6/10] Loss: 207.2831 MAE: 11.2394 Mean Error: -4.8821 STD: 13.5528
[7/10] Loss: 189.8173 MAE: 9.8020 Mean Error: -1.2357 STD: 13.7095
[8/10] Loss: 188.3376 MAE: 9.6512 Mean Error: -0.6498 STD: 13.7075
[9/10] Loss: 186.8393 MAE: 9.6946 Mean Error: -1.0850 STD: 13.6332
来源:https://blog.csdn.net/qq_21905401/article/details/82627402
0
投稿
猜你喜欢
- 接着前面Django入门使用示例今天我们来看看Django是如何加载静态html的?我们首先来看一看什么是静态HTML,什么是动态的HTML
- xhtml+css页面制作过程中问题的解决方案,说是解决方案应该有点过了,充其量只不过是给刚刚开始学标准页面制作的朋友们的一些小建议,如果讲
- 由于javascript是一种无类型语言,所以一个数组的元素可以具有任意的数据类型,同一个数组的不同元素可以具有不同的类型,数组的元素设置可
- SQL Server是一个关系数据库管理系统,应用很广泛,在进行SQL Server数据库操作的过程中难免会出现误删或者别的原因引起的日志损
- PHP get_html_translation_table() 函数实例输出 htmlspecialchars 函数使用的翻译表:<
- 我搜集了国内10几个电影网站的数据,里面近几十W条记录,用文本没法存,mongodb学习成本非常低,安装、下载、运行起来不会花你5分钟时间。
- 下载:pip install apschedulerpip install django-apscheduler将 django-apsch
- 前几天要写一个东西里面有用到读文件的。 可是我不想用FSO,我怕有的空间不支持。 &nbs
- 由于工作关系,只能暂时放弃对mongodb的研究了 .开始研究PHPcms .目前为止我已经基本完成了模块的开发.趁着周末来这里做个总结.我
- 1、内容在一屏内显示的,采用了(内容框)上下左右居中的办法,里面的内容绝对于这个内容框定位.这样一来,在不同大小屏中,内容总是在中间,看起来
- 需求说明当用户申请售后,商家未在n小时内处理,系统自动进行退款。商家拒绝后,用户可申请客服介入,客服x天内超时未处理,系统自动退款。用户收到
- 我们知道,Diango 接收的 HTTP 请求信息里带有 Cookie 信息。Cookie的作用是为了识别当前用户的身份,通过以下例子来说明
- 这篇文章主要介绍了python 上下文管理器原理解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友
- MySQL根据不同条件联查不同表的数据项目开发中遇到类似的需求。Mybatis 中的< if >标签只能判断where部分,不能
- 如何实现质数求和生活中很多问题是需要用数学来解决的,比如说要是做一栋房子,各方面的数据都要计算,要用多少材料,长宽高多少等,简单地说,运算就
- php面试题的题目: $a = '/a/b/c/d/e.php'; $b = '/a/b/12/34/c.php
- 浅显地了解了一下 Go,发现 Go 语法的设计非常简洁,易于理解。正应了 Go 语言之父 Rob Pike 说的那句“Less is mor
- MongoDB是一个介于关系数据库和非关系数据库之间的产品,是非关系数据库当 * 能最丰富,最像关系数据库的。他支持的数据结构非常松散,是类似
- 起步Python 的成功一个原因是它的可读性,代码清晰易懂,更容易被人类所理解,但有时可读性会产生误解。假如要判断一个变量是不是 17,那可
- 前言之前看到 RunCat 一只可以在电脑上奔跑猫,其主要的功能是监控电脑的CPU、内存的使用情况,使用越多跑的越快。所以准备做一只在任务栏