教你利用PyTorch实现sin函数模拟
作者:hhh江月 发布时间:2021-06-23 18:17:25
一、简介
本文旨在使用两种方法来实现sin函数的模拟,具体的模拟方法是使用机器学习来实现的,我们使用Python的torch模块进行机器学习,从而为sin确定多项式的系数。
二、第一种方法
# 这个案例相当于是使用torch来模拟sin函数进行计算啦。
# 通过3次函数来模拟sin函数,实现类似于机器学习的操作。
import torch
import math
dtype = torch.float
# 数据的类型
device = torch.device("cpu")
# 设备的类型
# device = torch.device("cuda:0") # Uncomment this to run on GPU
# Create random input and output data
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
# 与numpy的linspace是类似的
y = torch.sin(x)
# tensor->张量
# Randomly initialize weights
# 标准的高斯函数分布。
# 随机产生一个参数,然后通过学习来进行改进参数。
a = torch.randn((), device=device, dtype=dtype)
# a
b = torch.randn((), device=device, dtype=dtype)
# b
c = torch.randn((), device=device, dtype=dtype)
# c
d = torch.randn((), device=device, dtype=dtype)
# d
learning_rate = 1e-6
for t in range(2000):
# Forward pass: compute predicted y
y_pred = a + b * x + c * x ** 2 + d * x ** 3
# 这个也是一个张量。
# 3次函数来进行模拟。
# Compute and print loss
loss = (y_pred - y).pow(2).sum().item()
if t % 100 == 99:
print(t, loss)
# 计算误差
# Backprop to compute gradients of a, b, c, d with respect to loss
grad_y_pred = 2.0 * (y_pred - y)
grad_a = grad_y_pred.sum()
grad_b = (grad_y_pred * x).sum()
grad_c = (grad_y_pred * x ** 2).sum()
grad_d = (grad_y_pred * x ** 3).sum()
# 计算误差。
# Update weights using gradient descent
# 更新参数,每一次都要更新。
a -= learning_rate * grad_a
b -= learning_rate * grad_b
c -= learning_rate * grad_c
d -= learning_rate * grad_d
# reward
# 最终的结果
print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')
运行结果:
99 676.0404663085938
199 478.38140869140625
299 339.39117431640625
399 241.61537170410156
499 172.80801391601562
599 124.37007904052734
699 90.26084899902344
799 66.23435974121094
899 49.30537033081055
999 37.37403106689453
1099 28.96288299560547
1199 23.031932830810547
1299 18.848905563354492
1399 15.898048400878906
1499 13.81600570678711
1599 12.34669017791748
1699 11.309612274169922
1799 10.57749080657959
1899 10.060576438903809
1999 9.695555686950684
Result: y = -0.03098311647772789 + 0.852223813533783 x + 0.005345103796571493 x^2 + -0.09268788248300552 x^3
三、第二种方法
import torch
import math
dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU
# Create Tensors to hold input and outputs.
# By default, requires_grad=False, which indicates that we do not need to
# compute gradients with respect to these Tensors during the backward pass.
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)
# Create random Tensors for weights. For a third order polynomial, we need
# 4 weights: y = a + b x + c x^2 + d x^3
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
a = torch.randn((), device=device, dtype=dtype, requires_grad=True)
b = torch.randn((), device=device, dtype=dtype, requires_grad=True)
c = torch.randn((), device=device, dtype=dtype, requires_grad=True)
d = torch.randn((), device=device, dtype=dtype, requires_grad=True)
learning_rate = 1e-6
for t in range(2000):
# Forward pass: compute predicted y using operations on Tensors.
y_pred = a + b * x + c * x ** 2 + d * x ** 3
# Compute and print loss using operations on Tensors.
# Now loss is a Tensor of shape (1,)
# loss.item() gets the scalar value held in the loss.
loss = (y_pred - y).pow(2).sum()
if t % 100 == 99:
print(t, loss.item())
# Use autograd to compute the backward pass. This call will compute the
# gradient of loss with respect to all Tensors with requires_grad=True.
# After this call a.grad, b.grad. c.grad and d.grad will be Tensors holding
# the gradient of the loss with respect to a, b, c, d respectively.
loss.backward()
# Manually update weights using gradient descent. Wrap in torch.no_grad()
# because weights have requires_grad=True, but we don't need to track this
# in autograd.
with torch.no_grad():
a -= learning_rate * a.grad
b -= learning_rate * b.grad
c -= learning_rate * c.grad
d -= learning_rate * d.grad
# Manually zero the gradients after updating weights
a.grad = None
b.grad = None
c.grad = None
d.grad = None
print(f'Result: y = {a.item()} + {b.item()} x + {c.item()} x^2 + {d.item()} x^3')
运行结果:
99 1702.320556640625
199 1140.3609619140625
299 765.3402709960938
399 514.934326171875
499 347.6383972167969
599 235.80038452148438
699 160.98876953125
799 110.91152954101562
899 77.36819458007812
999 54.883243560791016
1099 39.79965591430664
1199 29.673206329345703
1299 22.869291305541992
1399 18.293842315673828
1499 15.214327812194824
1599 13.1397705078125
1699 11.740955352783203
1799 10.796865463256836
1899 10.159022331237793
1999 9.727652549743652
Result: y = 0.019909318536520004 + 0.8338049650192261 x + -0.0034346890170127153 x^2 + -0.09006795287132263 x^3
四、总结
以上的两种方法都只是模拟到了3次方,所以仅仅只是在x比较小的时候才比较合理,此外,由于系数是随机产生的,因此,每次运行的结果可能会有一定的差别的。
来源:https://blog.csdn.net/m0_54218263/article/details/122391053
猜你喜欢
- 1.find函数find() 方法检测字符串中是否包含子字符串 str ,如果指定 beg(开始) 和 end(结束) 范围,则检查是否包含
- (下面的代码原来我想用折叠的代码的,但是在google里面老是添加不了折叠的代码,所以就整屏的贴出来了,望大家不要见外。) 朋友的比较好的存
- 有时候在使用 Python 的时候,想要对一个数字或者字符串进行补零操作,即把「1」变为一个八位数的「00000001」,这个时候可以使用一
- React 是 Facebook 里一群牛 X 的码农折腾出的牛X的框架。 实现了一个虚拟 DOM,用 DOM 的方式将需要的组
- 1 介绍在设计到数据库的开发中,难免要将图片或音频文件插入到数据库中的情况。一般来说,我们可以同过插入图片文件相应的存储位置,而不是文件本身
- 本文实例讲述了Python向Excel中插入图片的简单实现方法。分享给大家供大家参考,具体如下:使用Python向Excel文件中插入图片,
- 我就废话不多说了,直接上代码吧!import numpy as npimport matplotlib.pyplot as pltx = n
- 前言:前几天上课闲着没事写了一个python敲击木鱼积累功德的小项目,当时纯粹就是写着玩,回顾一下鼠标事件的东西还记不记得,发现这个博客的点
- 题目描述1260. 二维网格迁移 - 力扣(LeetCode)给你一个 m 行 n 列的二维网格 grid 和
- 目前手边的一些工作,需要实现声音播放功能,而且仅支持wav声音格式。现在,一些网站上支持文字转语音功能,但是生成的都是MP3文件,这样还需要
- 检查图片是否损坏日常工作中,时常会需要用到图片,有时候图片在下载、解压过程中会损坏,而如果一张一张点击来检查就太不Cool了,因此我想大家都
- 本篇文章用到 element官网 和 七牛云官网element-ui 官网:https://element.eleme.io/#/zh-CN
- /** * Ajax分页功能 * 在需要分页的地方添加<ul class="pagination"><
- 1. 解压ZIP包和配置首先,将mysql-5.5.25-winx64.zip 解压缩到D:/mysql-5.5.25 目录下,然后根据网上
- 视图函数中加上认证功能,流程见下图import hashlibimport timedef get_random(name):
- el-table格式化el-table-column内容遇到一个需求,一个循环展示的table中的某项,或者某几项需要格式化。对于格式化的方
- 简单的2048小游戏不多说,直接上图,这里并未实现GUI之类的,需要的话,可自行实现:接下来就是代码模块,其中的2048游戏原来网络上有很多
- 1 concatconcat函数是在pandas底下的方法,可以将数据根据不同的轴作简单的融合pd.concat(objs, axis=0,
- 通过购物车的一个案列,把vuex学习了一篇。vuex概念浅谈Vuex 是一个专为 Vue.js 应用程序开发的状态管理模式。它采用集中式存储
- Q&AQ: .js和.min.js文件分别是什么?A: .js是JavaScript 源码文件, .min.js是压缩版的js文件。