pytorch-神经网络拟合曲线实例
作者:马飞飞 发布时间:2022-03-17 18:17:30
标签:pytorch,神经网络,拟合曲线
代码已经调通,跑出来的效果如下:
# coding=gbk
import torch
import matplotlib.pyplot as plt
from torch.autograd import Variable
import torch.nn.functional as F
'''
Pytorch是一个拥有强力GPU加速的张量和动态构建网络的库,其主要构建是张量,所以可以把PyTorch当做Numpy
来用,Pytorch的很多操作好比Numpy都是类似的,但是其能够在GPU上运行,所以有着比Numpy快很多倍的速度。
训练完了,发现隐层越大,拟合的速度越是快,拟合的效果越是好
'''
def train():
print('------ 构建数据集 ------')
# torch.linspace是为了生成连续间断的数据,第一个参数表示起点,第二个参数表示终点,第三个参数表示将这个区间分成平均几份,即生成几个数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
#torch.rand返回的是[0,1]之间的均匀分布 这里是使用一个计算式子来构造出一个关联结果,当然后期要学的也就是这个式子
y = x.pow(2) + 0.2 * torch.rand(x.size())
# Variable是将tensor封装了下,用于自动求导使用
x, y = Variable(x), Variable(y)
#绘图展示
plt.scatter(x.data.numpy(), y.data.numpy())
#plt.show()
print('------ 搭建网络 ------')
#使用固定的方式继承并重写 init和forword两个类
class Net(torch.nn.Module):
def __init__(self,n_feature,n_hidden,n_output):
#初始网络的内部结构
super(Net,self).__init__()
self.hidden=torch.nn.Linear(n_feature,n_hidden)
self.predict=torch.nn.Linear(n_hidden,n_output)
def forward(self, x):
#一次正向行走过程
x=F.relu(self.hidden(x))
x=self.predict(x)
return x
net=Net(n_feature=1,n_hidden=1000,n_output=1)
print('网络结构为:',net)
print('------ 启动训练 ------')
loss_func=F.mse_loss
optimizer=torch.optim.SGD(net.parameters(),lr=0.001)
#使用数据 进行正向训练,并对Variable变量进行反向梯度传播 启动100次训练
for t in range(10000):
#使用全量数据 进行正向行走
prediction=net(x)
loss=loss_func(prediction,y)
optimizer.zero_grad() #清除上一梯度
loss.backward() #反向传播计算梯度
optimizer.step() #应用梯度
#间隔一段,对训练过程进行可视化展示
if t%5==0:
plt.cla()
plt.scatter(x.data.numpy(),y.data.numpy()) #绘制真是曲线
plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)
plt.text(0.5,0,'Loss='+str(loss.data[0]),fontdict={'size':20,'color':'red'})
plt.pause(0.1)
plt.ioff()
plt.show()
print('------ 预测和可视化 ------')
if __name__=='__main__':
train()
来源:https://blog.csdn.net/maqunfi/article/details/84504622


猜你喜欢
- 1.使用前先要安装 yagmailpip install yagmail -i https://pypi.douban.com/simple
- 最近一直在用python写点监控oracle的程序,一直没有用到异常处理这一块,然后日常监控中一些错误笼统的抛出数据库连接异常,导致后续处理
- Pycharm是大多数程序员都会使用的一款编程软件,可是对于新手小白对说,英文界面十分头晕。Pycharm最新版本2020.3汉化、解除汉化
- 闲的无聊。。。网上一堆,正好练手(主要是新手)# coding=utf-8 import requests from bs4 import
- 近来,随着XHTML(可扩展HTML)标准的出现,<script/>标签也经历了一些改变。该标签不再用language特性,而用
- 近日,因公司业务需要,需将原两个公众号合并为一个,即要将其中一个公众号(主要是粉丝)迁移到另一个公众号。按微信规范,同一用户在不同公众号内的
- DDPDistributed Data Parallel 简称 DDP,是 PyTorch 框架下一种适用于单机多卡、多机多卡任务的数据并行
- 其实有一个疑惑一直在小编心中,每一个代码段编写里,总会出现好多个函数,也许有人和小编有一样的认同感,后来,小编明白,每一个函数本身都是都有各
- 我的SQL Server2005 一直正常使用但昨天出现了错误,如图。经过上网查,网上说的办法试了好多都没有解决这个问题。在经过多次的摸索后
- PHP天然就对MySQL有良好的支持,但是想要用PHP对SQL Server进行操作,则需要花点时间了。今天刚好团队里的一个项目需要用PHP
- 1. 首先安装node,推荐偶数版本;好了之后检查一下: node -v;出现版本好即为安装成功;win10家庭版本的msi版本的时候出现无
- 三重相等运算符 === 严格检查2个值是否相同:1 === 1; // => true1 === '1';
- 一、 简单查询简单的Transact-SQL查询只包括选择列表、FROM子句和WHERE子句。它们分别说明所查询列、查询的表或视图、以及搜索
- SQLServer分页方式附带50万数据分页时间[本机访问|已重启SQL服务|无其他程序干扰][非索引排序]环境 WIN7 SQL服务12.
- BOF 指示当前记录位置位于 Recordset 对象的第一个记录之前。EOF 指示当前记录位置位于 Recordset 对象的最后一个记录
- 本文实例讲述了python使用正则表达式提取网页URL的方法。分享给大家供大家参考。具体实现方法如下:import reimport url
- 我们编写程序最终目的还是来解决实际问题,所以必然会遇到输入输出的交互问题,python中提供了input函数用来获取用户的输入,我们可以用以
- JMeter可以通过os命令调用Python脚本,Python同样可以通过系统命令调用JMeter执行压测Python调用JMeter首先要
- 配置可能会随官方改变,本文仅供参考。1.下载安装GO的包到https://code.google.com/p/go/downloads/li
- 要在用户浏览器上安装cookie,HTTP服务器向HTTP响应添加类似以下内容的HTTP报头:Set-Cookie:session=8345