Pytorch 神经网络—自定义数据集上实现教程
作者:LZDCQU 发布时间:2022-11-30 20:05:04
标签:Pytorch,神经网络,自定义,数据集
第一步、导入需要的包
import os
import scipy.io as sio
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torch.autograd import Variable
batchSize = 128 # batchsize的大小
niter = 10 # epoch的最大值
第二步、构建神经网络
设神经网络为如上图所示,输入层4个神经元,两层隐含层各4个神经元,输出层一个神经。每一层网络所做的都是线性变换,即y=W×X+b;代码实现如下:
class Neuralnetwork(nn.Module):
def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
super(Neuralnetwork, self).__init__()
self.layer1 = nn.Linear(in_dim, n_hidden_1)
self.layer2 = nn.Linear(n_hidden_1, n_hidden_2)
self.layer3 = nn.Linear(n_hidden_2, out_dim)
def forward(self, x):
x = x.view(x.size(0), -1)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
return x
model = Neuralnetwork(1*3, 4, 4, 1)
print(model) # net architecture
Neuralnetwork(
(layer1): Linear(in_features=3, out_features=4, bias=True)
(layer2): Linear(in_features=4, out_features=4, bias=True)
(layer3): Linear(in_features=4, out_features=1, bias=True)
)
第三步、读取数据
自定义的数据为demo_SBPFea.mat,是MATLAB保存的数据格式,其存储的内容如下:包括fea(1000*3)和sbp(1000*1)两个数组;fea为特征向量,行为样本数,列为特征宽度;sbp为标签
class SBPEstimateDataset(Dataset):
def __init__(self, ext='demo'):
data = sio.loadmat(ext+'_SBPFea.mat')
self.fea = data['fea']
self.sbp = data['sbp']
def __len__(self):
return len(self.sbp)
def __getitem__(self, idx):
fea = self.fea[idx]
sbp = self.sbp[idx]
"""Convert ndarrays to Tensors."""
return {'fea': torch.from_numpy(fea).float(),
'sbp': torch.from_numpy(sbp).float()
}
train_dataset = SBPEstimateDataset(ext='demo')
train_loader = DataLoader(train_dataset, batch_size=batchSize, # 分批次训练
shuffle=True, num_workers=int(8))
整个数据样本为1000,以batchSize = 128划分,分为8份,前7份为104个样本,第8份则为104个样本。在网络训练过程中,是一份数据一份数据进行训练的
第四步、模型训练
# 优化器,Adam
optimizer = optim.Adam(list(model.parameters()), lr=0.0001, betas=(0.9, 0.999),weight_decay=0.004)
scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.997)
criterion = nn.MSELoss() # loss function
if torch.cuda.is_available(): # 有GPU,则用GPU计算
model.cuda()
criterion.cuda()
for epoch in range(niter):
losses = []
ERROR_Train = []
model.train()
for i, data in enumerate(train_loader, 0):
model.zero_grad()# 首先提取清零
real_cpu, label_cpu = data['fea'], data['sbp']
if torch.cuda.is_available():# CUDA可用情况下,将Tensor 在GPU上运行
real_cpu = real_cpu.cuda()
label_cpu = label_cpu.cuda()
input=real_cpu
label=label_cpu
inputv = Variable(input)
labelv = Variable(label)
output = model(inputv)
err = criterion(output, labelv)
err.backward()
optimizer.step()
losses.append(err.data[0])
error = output.data-label+ 1e-12
ERROR_Train.extend(error)
MAE = np.average(np.abs(np.array(ERROR_Train)))
ME = np.average(np.array(ERROR_Train))
STD = np.std(np.array(ERROR_Train))
print('[%d/%d] Loss: %.4f MAE: %.4f Mean Error: %.4f STD: %.4f' % (
epoch, niter, np.average(losses), MAE, ME, STD))
[0/10] Loss: 18384.6699 MAE: 135.3871 Mean Error: -135.3871 STD: 7.5580
[1/10] Loss: 17063.0215 MAE: 130.4145 Mean Error: -130.4145 STD: 7.8918
[2/10] Loss: 13689.1934 MAE: 116.6625 Mean Error: -116.6625 STD: 9.7946
[3/10] Loss: 8192.9053 MAE: 89.6611 Mean Error: -89.6611 STD: 12.9911
[4/10] Loss: 2979.1340 MAE: 52.5410 Mean Error: -52.5279 STD: 15.0930
[5/10] Loss: 599.7094 MAE: 22.2735 Mean Error: -19.9979 STD: 14.2069
[6/10] Loss: 207.2831 MAE: 11.2394 Mean Error: -4.8821 STD: 13.5528
[7/10] Loss: 189.8173 MAE: 9.8020 Mean Error: -1.2357 STD: 13.7095
[8/10] Loss: 188.3376 MAE: 9.6512 Mean Error: -0.6498 STD: 13.7075
[9/10] Loss: 186.8393 MAE: 9.6946 Mean Error: -1.0850 STD: 13.6332
来源:https://blog.csdn.net/qq_21905401/article/details/82627402
0
投稿
猜你喜欢
- copy()chutil.copy(source, destination)shutil.copy() 函数实现文件复制功能,将 sourc
- 本文为大家分享了python银行管理系统的具体代码,供大家参考,具体内容如下自己写的练手小程序,练习面向对象的概念,代码中都有注释,刚学的同
- 介绍在本文中,你将学习如何使用 Python 构建人脸识别系统。人脸识别比人脸检测更进一步。在人脸检测中,我们只检测人脸在图像中的位置,但在
- Linux RedHat下安装Python2.7、pip、ipython环境、eclipse和PyDev环境准备工作,源Python2.6备
- 在程序中,变量就是一个名称,让我们更加方便记忆。cars = 100 space_in_a_car = 4.0 drivers = 30 p
- 毫无疑问,JavaScript 是一种非常灵活的脚本语言,有时候它像一只难以驯服的野马——你受益于它的灵活性的同时,也要时刻提防它变得失去控
- 多线程-共享全局变量#coding=utf-8from threading import Threadimport timeg_num =
- 一、TCP1、tcp服务器创建#创建服务器from socket import *from time import ctime #导入cti
- 列表(List)是你使用Python过程中接触最为频繁的数据结构,也是功能最为强大的几种数据结构之一。Python列表非常的万能且蕴含着许多
- 直接调用系统的颜色显示在网页上本来是件很好玩滴事,但是,也有个缺点,就是可用的色太少 比如Bindows在它的启动画面一点点应用。=。= 上
- Python内置函数isdigit()使用今天简单介绍一下Python中的isdigit()函数的用法:判断单个字符是否为数字判断字符串中是
- 本文实例讲述了Python实现的文本简单可逆加密算法。分享给大家供大家参考,具体如下:其实很简单,就是把一段文本每个字符都通过某种方式改变(
- 一、项目工程目录:二、具体工程文件代码:1、新建一个包名:common(用于存放基本函数封装)(1)在common包下新建一个base.py
- 小的本身是一个平面设计人员,前一阵儿有一些空闲的时间,便在各个站长网上发布了贴子,大意是免费制作logo,以换取网站连接(相信很多人都看过)
- 判断访问是否来自搜索引擎的函数,有兴趣的可以试试! <% '检查当前用户是否是蜘蛛人 Function check(
- 本文实例讲述了Python下载指定页面上图片的方法。分享给大家供大家参考,具体如下:#!/usr/bin/python #coding:ut
- 这篇文章主要介绍了python文字和unicode/ascll相互转换函数及简单加密解密实现代码,下面我们来了解一下。import reim
- 前言Python环境的搭建这里就不赘述了,有需要的小伙伴可以在网上搜罗出很多教程,注意安装PyChom编辑工具。这次我们主要讲一下几点内容:
- 1.TCP是一种面向连接的可靠地协议,在一方发送数据之前,必须在双方之间建立一个连接,建立的过程需要经过三次握手,通信完成后要拆除连接,需要
- 如图:Oracle 11g安装到42%挂了。上度娘查了一下,原来是Oracle安装包的问题,1,2两个包都要下载下来,而且需要解压到相同(同