使用 pytorch 创建神经网络拟合sin函数的实现
作者:假装很坏的谦谦君 发布时间:2023-02-04 03:31:40
我们知道深度神经网络的本质是输入端数据和输出端数据的一种高维非线性拟合,如何更好的理解它,下面尝试拟合一个正弦函数,本文可以通过简单设置节点数,实现任意隐藏层数的拟合。
基于pytorch的深度神经网络实战,无论任务多么复杂,都可以将其拆分成必要的几个模块来进行理解。
1)构建数据集,包括输入,对应的标签y
2) 构建神经网络模型,一般基于nn.Module继承一个net类,必须的是__init__函数和forward函数。__init__构造函数包括创建该类是必须的参数,比如输入节点数,隐藏层节点数,输出节点数。forward函数则定义了整个网络的前向传播过程,类似于一个Sequential。
3)实例化上步创建的类。
4)定义损失函数(判别准则),比如均方误差,交叉熵等
5)定义优化器(optim:SGD,adam,adadelta等),设置学习率
6)开始训练。开始训练是一个从0到设定的epoch的循环,循环期间,根据loss,不断迭代和更新网络权重参数。
无论多么复杂的网络,基于pytorch的深度神经网络都包括6个模块,训练阶段包括5个步骤,本文只通过拟合一个正弦函数来说明加深理解。
废话少说,直接上代码:
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import torch.nn as nn
import numpy as np
import torch
# 准备数据
x=np.linspace(-2*np.pi,2*np.pi,400)
y=np.sin(x)
# 将数据做成数据集的模样
X=np.expand_dims(x,axis=1)
Y=y.reshape(400,-1)
# 使用批训练方式
dataset=TensorDataset(torch.tensor(X,dtype=torch.float),torch.tensor(Y,dtype=torch.float))
dataloader=DataLoader(dataset,batch_size=100,shuffle=True)
# 神经网络主要结构,这里就是一个简单的线性结构
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.net=nn.Sequential(
nn.Linear(in_features=1,out_features=10),nn.ReLU(),
nn.Linear(10,100),nn.ReLU(),
nn.Linear(100,10),nn.ReLU(),
nn.Linear(10,1)
)
def forward(self, input:torch.FloatTensor):
return self.net(input)
net=Net()
# 定义优化器和损失函数
optim=torch.optim.Adam(Net.parameters(net),lr=0.001)
Loss=nn.MSELoss()
# 下面开始训练:
# 一共训练 1000次
for epoch in range(1000):
loss=None
for batch_x,batch_y in dataloader:
y_predict=net(batch_x)
loss=Loss(y_predict,batch_y)
optim.zero_grad()
loss.backward()
optim.step()
# 每100次 的时候打印一次日志
if (epoch+1)%100==0:
print("step: {0} , loss: {1}".format(epoch+1,loss.item()))
# 使用训练好的模型进行预测
predict=net(torch.tensor(X,dtype=torch.float))
# 绘图展示预测的和真实数据之间的差异
import matplotlib.pyplot as plt
plt.plot(x,y,label="fact")
plt.plot(x,predict.detach().numpy(),label="predict")
plt.title("sin function")
plt.xlabel("x")
plt.ylabel("sin(x)")
plt.legend()
plt.savefig(fname="result.png",figsize=[10,10])
plt.show()
输出结果:
step: 100 , loss: 0.06755948066711426
step: 200 , loss: 0.003788222325965762
step: 300 , loss: 0.0004728269996121526
step: 400 , loss: 0.0001810075482353568
step: 500 , loss: 0.0001108720971387811
step: 600 , loss: 6.29749265499413e-05
step: 700 , loss: 3.707894938997924e-05
step: 800 , loss: 0.0001250380591955036
step: 900 , loss: 3.0654005968244746e-05
step: 1000 , loss: 4.349641676526517e-05
输出图像:
来源:https://blog.csdn.net/qq_38863413/article/details/104437824


猜你喜欢
- 并发是一个很酷的话题,一旦你掌握了它,就会成为一笔巨大的财富。说实话,我一开始很害怕写这篇文章,因为我自己直到最近才对并发性不太适应。我已经
- 本文研究的主要是numpy使用技巧之数组过滤的相关内容,具体如下。当使用布尔数组b作为下标存取数组x中的元素时,将收集数组x中所有在数组b中
- 这篇文章为大家提供了Mysql的安装包,详细的安装步骤,以及安装过程中出现的问题的解决方案,希望对大家有所帮助......工具:Mysql
- 首先感谢朋友们对第一篇文章的鼎力支持,感动中....... 今天说的是选择排序,包括“直接选择排序”和“堆排序”。话说
- docs = [‘icassp improved human face identification using frequency dom
- 本文实例讲述了Python面向对象编程基础。分享给大家供大家参考,具体如下:1、类的定义Python中类的定义与对象的初始化如下,pytho
- declare @tt varchar(20) set @tt = 'monisubbouns' declare @int
- 本文实例讲述了Python装饰器基础概念与用法。分享给大家供大家参考,具体如下:装饰器基础前面快速介绍了装饰器的语法,在这里,我们将深入装饰
- --按日 select sum(consume),day([date]) from consume_record where year([d
- 前言登录跳转:不同的用户在登录成功之后跳转到不同的网页当中例如:网站管理员登录成功后跳转到网站后台,vip用户登录成功后跳转到vip页面准备
- 我在初学时查阅过大量相关资料,发现其中提供的很多方法实际操作起来并不是那么回事。对于简单的应用,这些资料也许是有帮助的,但仅限于此,因为它们
- 什么是pyc文件pyc是一种二进制文件,是由py文件经过编译后,生成的文件,是一种byte code,py文件变成pyc文件后,加载的速度有
- 本文实例为大家分享了Python 12306抢火车票的具体代码,供大家参考,具体内容如下# -*- coding: utf-8 -*-fro
- 为了实现挖掘,我们需要开发一个挖掘功能.挖掘功能需要在给定的消息字符串上生成摘要并提供工作证明.让我们在本章讨论这个.消息摘要函数我们将编写
- 最近老板叫做一个数据查重的小练习,涉及从一个包含中文字段的文件中提取出其中的中文字段并存储,使用php开发。中间涉及到php正则表达式中文匹
- 程序员的时间很宝贵,Python这门语言虽然足够简单、优雅,但并不是说你使用Python编程,效率就一定会高。要想节省时间、提高效率,还是需
- 下边我就简单说一下过程和原理。第一步:实现一个匿名函数并能自己执行。(function(){ })() 这个函数在一样编的好的J
- __getitem__ 来看个简单的例子就明白:def __getitem__(self, key): return self.data[k
- 在web2.0的站中用户互动性是很强的,例如用户留言我们可能放开img标签,允许用户外链其他站点的图片,那么我们就需要解决图片尺寸过大所带来
- 不知道大家有没有一种感觉,每次当使用numpy数组的时候坐标轴总是傻傻分不清楚,然后就会十分的困惑,每次运算都需要去尝试好久才能得出想要的结