pytorch 自定义参数不更新方式
作者:ShellCollector 发布时间:2021-11-11 01:55:55
标签:pytorch,自定义,参数,更新
nn.Module中定义参数:不需要加cuda,可以求导,反向传播
class BiFPN(nn.Module):
def __init__(self, fpn_sizes):
self.w1 = nn.Parameter(torch.rand(1))
print("no---------------------------------------------------",self.w1.data, self.w1.grad)
下面这个例子说明中间变量可能没有梯度,但是最终变量有梯度:
cy1 cd都有梯度
import torch
xP=torch.Tensor([[ 3233.8557, 3239.0657, 3243.4355, 3234.4507, 3241.7087,
3243.7292, 3234.6826, 3237.6609, 3249.7937, 3244.8623,
3239.5349, 3241.4626, 3251.3457, 3247.4263, 3236.4924,
3251.5735, 3246.4731, 3242.4692, 3239.4958, 3247.7283,
3251.7134, 3249.0237, 3247.5637],
[ 1619.9011, 1619.7140, 1620.4883, 1620.0642, 1620.2191,
1619.9796, 1617.6597, 1621.1522, 1621.0869, 1620.9725,
1620.7130, 1620.6071, 1620.7437, 1621.4825, 1620.5107,
1621.1519, 1620.8462, 1620.5944, 1619.8038, 1621.3364,
1620.7399, 1621.1178, 1618.7080],
[ 1619.9330, 1619.8542, 1620.5176, 1620.1167, 1620.1577,
1620.0579, 1617.7155, 1621.1718, 1621.1338, 1620.9572,
1620.6288, 1620.6621, 1620.7074, 1621.5305, 1620.5656,
1621.2281, 1620.8346, 1620.6021, 1619.8228, 1621.3936,
1620.7616, 1621.1954, 1618.7983],
[ 1922.6078, 1922.5680, 1923.1331, 1922.6604, 1922.9589,
1922.8818, 1920.4602, 1923.8107, 1924.0142, 1923.6907,
1923.4465, 1923.2820, 1923.5728, 1924.4071, 1922.8853,
1924.1107, 1923.5465, 1923.5121, 1922.4673, 1924.1871,
1923.6248, 1923.9086, 1921.9496],
[ 1922.5948, 1922.5311, 1923.2850, 1922.6613, 1922.9734,
1922.9271, 1920.5950, 1923.8757, 1924.0422, 1923.7318,
1923.4889, 1923.3296, 1923.5752, 1924.4948, 1922.9866,
1924.1642, 1923.6427, 1923.6067, 1922.5214, 1924.2761,
1923.6636, 1923.9481, 1921.9005]])
yP=torch.Tensor([[ 2577.7729, 2590.9868, 2600.9712, 2579.0195, 2596.3684,
2602.2771, 2584.0305, 2584.7749, 2615.4897, 2603.3164,
2589.8406, 2595.3486, 2621.9116, 2608.2820, 2582.9534,
2619.2073, 2607.1233, 2597.7888, 2591.5735, 2608.9060,
2620.8992, 2613.3511, 2614.2195],
[ 673.7830, 693.8904, 709.2661, 675.4254, 702.4049,
711.2085, 683.1571, 684.6160, 731.3878, 712.7546,
692.3011, 701.0069, 740.6815, 720.4229, 681.8199,
736.9869, 718.5508, 704.3666, 695.0511, 721.5912,
739.6672, 728.0584, 729.3143],
[ 673.8367, 693.9529, 709.3196, 675.5266, 702.3820,
711.2159, 683.2151, 684.6421, 731.5291, 712.6366,
692.1913, 701.0057, 740.6229, 720.4082, 681.8656,
737.0168, 718.4943, 704.2719, 695.0775, 721.5616,
739.7233, 728.1235, 729.3387],
[ 872.9419, 891.7061, 905.8004, 874.6565, 899.2053,
907.5082, 881.5528, 883.0028, 926.3083, 908.9742,
890.0403, 897.8606, 934.6913, 916.0902, 880.4689,
931.3562, 914.4233, 901.2154, 892.5759, 916.9590,
933.9291, 923.0745, 924.4461],
[ 872.9661, 891.7683, 905.8128, 874.6301, 899.2887,
907.5155, 881.6916, 883.0234, 926.3242, 908.9561,
890.0731, 897.9221, 934.7324, 916.0806, 880.4300,
931.3933, 914.5662, 901.2715, 892.5501, 916.9894,
933.9813, 923.0823, 924.3654]])
shape=[4000, 6000]
cx,cy1=torch.rand(1,requires_grad=True),torch.rand(1,requires_grad=True)
cd=torch.rand(1,requires_grad=True)
ox,oy=cx,cy1
print('cx:{},cy:{}'.format(id(cx),id(cy1)))
print('ox:{},oy:{}'.format(id(ox),id(oy)))
cx,cy=cx*shape[1],cy1*shape[0]
print('cx:{},cy:{}'.format(id(cx),id(cy)))
print('ox:{},oy:{}'.format(id(ox),id(oy)))
distance=torch.sqrt(torch.pow((xP-cx),2)+torch.pow((yP-cy),2))
mean=torch.mean(distance,1)
starsFC=cd*torch.pow((distance-mean[...,None]),2)
loss=torch.sum(torch.mean(starsFC,1).squeeze(),0)
loss.backward()
print(loss)
print(cx)
print(cy1)
print("cx",cx.grad)
print("cy",cy1.grad)
print("cd",cd.grad)
print(ox.grad)
print(oy.grad)
print('cx:{},cy:{}'.format(id(cx),id(cy)))
print('ox:{},oy:{}'.format(id(ox),id(oy)))
来源:https://blog.csdn.net/jacke121/article/details/103672674
0
投稿
猜你喜欢
- 这篇文章主要介绍了如何通过Django使用本地css/js文件,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,
- 1.配置环境安装python3安装python3-pip通过pip安装Django**如果需要使用Jinja模板,需要通过pip安装djan
- 本文实例为大家分享了java连接mysql底层封装代码,供大家参考,具体内容如下连接数据库package com.dao.db;import
- 方案有很多种,我这里简单说一下:1. into outfileSELECT * FROM mytable  
- 我想大家在用Sql2005一般都是.NET2005自带的SQL Server 2005是SQL Server2005 Express版本的,
- 本文实例讲述了Python树莓派学习笔记之UDP传输视频帧操作。分享给大家供大家参考,具体如下:因为我在自己笔记本电脑上没能成功安装Open
- 1、首先简述数据挖掘的过程第一步:数据选择可以通过业务原始数据、公开的数据集、也可通过爬虫的方式获取。第二步: 数据预处理数据极可能有噪音,
- 页面加载loading效果, 这个挺好玩的!用setTimeout实现的!可以和服务端整合弄一些生成HTML或者上传文件等应用!
- W3C终于发布了第一个HTML5草案,大家还沉溺在HTML2XHTML转换的快乐和痛苦中时,却又突然发现,HTML5和XHTML2,到底谁是
- 背景测试工具箱写到一半,今天遇到了一个前后端数据交互的问题,就一起做一下整理。环境-----------------------------
- 今天写的代码片段:X = Y = []..X.append(x)Y.append(y)其中x和y是读取的每一个数据的xy值,打算将其归入列表
- 模块导入的规范模块是类或函数的集合,用于实现某个功能。模块的导入和Java 中包的导入的概念很相似都使用import语句。在Python中,
- 人丑就要多读书,颜值不够知识来凑,至少你可以用书籍来武装你的大脑,拯救你的人生。TIOBE编程语言排行榜前20的语言入门书籍推荐
- 日期是许多 JavaScript 应用程序的基本组成部分,无论是在网页上显示当前日期还是处理用户输入以安排事件。但以清晰一致的格式显示日期对
- 本文为大家分享了Windows下mysql5.7.18安装配置教程,供大家参考,具体内容如下准备:操作系统:win7下64位的zip版本的M
- 独立 fmt Log输出重定向golang的fmt包的输出函数 Println、Printf、PrintStack等,默认将打印输出到os.
- 本文为大家分享了解决Mysql存储引擎MyISAM常见问题的方法,供大家参考,具体内容如下一、处理MyISAM存储引擎的表损坏在使用MySQ
- 百度AI提供了一天50000次的免费文字识别额度,可以愉快的免费使用!下面直接上方法:首先在百度AI创建一个应用,按照下图创建即可,创建后会
- 1、出现错误train_df = pd.read_csv( 'C:\Users\lenovo\Desktop\train.csv
- 在SQL Server数据库中如何查看一个登录名(login)的具体权限呢,如果使用SSMS的UI界面查看登录名的具体权限的话,用户数据库非