pytorch网络模型构建场景的问题介绍
作者:mingqian_chu 发布时间:2022-07-24 22:38:42
记录使用pytorch构建网络模型过程遇到的点
1. 网络模型构建中的问题
1.1 输入变量是Tensor张量
各个模块和网络模型的输入,一定要是tensor
张量;
可以用一个列表存放多个张量。
如果是张量维度不够,需要升维度,
可以先使用 torch.unsqueeze(dim = expected)
然后再使用torch.cat(dim )
进行拼接;
需要传递梯度的数据,禁止使用numpy
, 也禁止先使用numpy,然后再转换成张量的这种情况出现;
这是因为pytorch的机制是只有是 Tensor
张量的类型,才会有梯度等属性值,如果是numpy这些类别,这些变量并会丢失其梯度值。
1.2 __init__()方法使用
class ex:
def __init__(self):
pass
__init__
方法必须接受至少一个参数即self,
Python中,self是指向该对象本身的一个引用,
通过在类的内部使用self变量,
类中的方法可以访问自己的成员变量,简单来说,self.varname的意义为”访问该对象的varname属性“
当然,__init__()
中可以封装任意的程序逻辑,这是允许的,init()方法还接受任意多个其他参数,允许在初始化时提供一些数据,例如,对于刚刚的worker类,可以这样写:
class worker:
def __init__(self,name,pay):
self.name=name
self.pay=pay
这样,在创建worker类的对象时,必须提供name和pay两个参数:
b=worker('Jim',5000)
Python会自动调用worker.init()方法,并传递参数。
细节参考这里init方法
1.3 内置函数setattr()
此时,可以使用python自带的内置函数 setattr()
,和对应的getattr()
setattr(object, name, value)
object – 对象。
name – 字符串,对象属性。
value – 属性值。
对已存在的属性进行赋值:
>>>class A(object):
... bar = 1
...
>>> a = A()
>>> getattr(a, 'bar') # 获取属性 bar 值
1
>>> setattr(a, 'bar', 5) # 设置属性 bar 值
>>> a.bar
5
如果属性不存在会创建一个新的对象属性,并对属性赋值:>>>class A():
... name = "runoob"
...
>>> a = A()
>>> setattr(a, "age", 28)
>>> print(a.age)
28
>>>
setattr() 语法
setattr(object, name, value)
object – 对象。
name – 字符串,对象属性。
value – 属性值。
1.4 网络模型的构建
注意到,在python的 __init__()
函数中,self
本身就是该类的对象的一个引用,即self是指向该对象本身的一个引用,
利用上述这一点,当在神经网络中,
需要给多个属性进行实例化时,
且这多个属性使用的是同一个类进行实例化.
则使用 setattr(self, string, object1)
添加属性;
class Temporal_GroupTrans(nn.Module):
def __init__(self, num_classes=10,num_groups=35, drop_prob=0.5, pretrained= True):
super(Temporal_GroupTrans, self).__init__()
conv_block = Basic_slide_conv()
for i in range( num_groups):
setattr(self, "group" + str(i), conv_block)
# 自定义transformer模型的初始化, CustomTransformerModel() 在该类中传入初始化模型的参数,
# nip:512输入序列中,每个列向量的编码维度,16:注意力头的个数
# 600:中间mlp 隐藏层的维数, 6: 堆叠transforEncode编码模块的个数;
self.trans_model = CustomTransformerModel(512,16,600, 6,droupout=0.5,nclass=4)
则使用 getattr(self, string, object1)
获取属性;
trans_input_sequence = []
for i in range(0, num_groups, ):
# 每组语谱图的大小是一个 (bt, ch,96,12)的矩阵,组与组之间没有重叠;
cur_group = x[:, :, :, 12 * i:12 * (i + 1)]
# VARIABLE_fun = "self.group" # 每一组,与之对应的卷积模块;
# cur_fun = eval(VARIABLE_fun + str(i ))
cur_fun = getattr(self, 'group'+str(i))
cur_group_out = cur_fun(cur_group).unsqueeze(dim=1) # [bt,1, 512]
trans_input_sequence.append(cur_group_out)
来源:https://blog.csdn.net/chumingqian/article/details/129417691
猜你喜欢
- 一、前言1.1.环境python版本:3.6Django版本:1.11.61.2.预览效果最终搭建的blog的样子,基本上满足需求了。框架搭
- 实现了宽度、高度、透明度的渐变,还能以高度宽度中点为中心,还扩展成以任意点为中心渐变(实例中以点击点为中心)。<!DOCTYPE ht
- 在VBScript中有Filter这个函数可以用来对数组进行过滤,并返回原数组的一个子集数组。语法说明: 引用内容Filter 函
- 本文实例讲述了Python3.4类型判断,异常处理,终止程序操作。分享给大家供大家参考,具体如下:python3.4学习笔记 类型判断,异常
- 由于最近测试需要录制系统界面的操作过程,因为都是全屏的操作,所以用python做一个简单的录屏小工具。实现过程也是比较简单,就是通过对屏幕操
- 最近为了熟悉一下 js 用有道翻译练了一下手,写一篇博客记录一下,也希望能对大家有所启迪,不过这些网站更新太快,可能大家尝试的时候会有所不同
- 有一个群友在群里问个如何快速搭建一个搜索引擎,在搜索之后我看到了这个代码所在Git:https://github.com/asciimoo/
- 目标是拷贝微信的飞机大战,当然拷贝完以后大家就具备自己添加不同内容的能力了。首先是要拿到一些图片素材,熟悉使用图像处理软件和绘画的人可以自己
- 本文实例为大家分享了使用RNN进行文本分类,python代码实现,供大家参考,具体内容如下1、本博客项目由来是oxford 的nlp 深度学
- 神经网络只是由两个或多个线性网络层叠加,并不能学到新的东西,简单地堆叠网络层,不经过非线性激活函数激活,学到的仍然是线性关系。但是加入激活函
- 函数是有组织的,可重复使用的代码,用于执行一个单一的,相关的动作的块。函数为应用程序和代码重用的高度提供了更好的模块。正如我们知
- 研究(2)中讨论了栅格系统的基础知识。这一篇将集中探讨栅格系统的粒度问题。(注:如非特别指明,栅格系统均指24列960栅格系统)淘宝的首页(
- 古巴比伦王颁布了汉摩拉比法典,刻在黑色的玄武岩,距今已经三千七百多年,你在橱窗前…熟悉吧?没错,这就是周董的爱在西元前歌词。前不久工作不是很
- 最近遇到一个问题,就是获取表单中的日期往后台通过json方式传的时候,遇到Date.parse(str)函数在ff下报错: NAN 找了些资
- 详解Python import方法引入模块的实例在Python用import或者from…import或者from…import…as…来导
- 阅读上一节:无序列表信息有时候是无序归纳的,有的却有着明确的顺序,在上一篇也提到了。那么简单的来想一下身边有哪些事物是有先后顺序的:操作步骤
- 1、字符串拼接通过+运算符现有字符串码农飞哥好,,要求将字符串码农飞哥牛逼拼接到其后面,生成新的字符串码农飞哥好,码农飞哥牛逼举个例子:st
- 自动签到的python脚本源码新建一个python文件,checkin.py,保存到电脑上某个位置,我这里保存到的是E:\pyproject
- 一.windows系统的解决方法1.首先以系统管理员身份登陆系统。2.停止MySQL的服务。3.进入命令窗口,然后进入MySQL的安装目录,
- 我就废话不多说了,大家还是直接看代码吧!import socketimport sysimport timeimport structHOS