pytorch网络模型构建场景的问题介绍
作者:mingqian_chu 发布时间:2022-07-24 22:38:42
记录使用pytorch构建网络模型过程遇到的点
1. 网络模型构建中的问题
1.1 输入变量是Tensor张量
各个模块和网络模型的输入,一定要是tensor
张量;
可以用一个列表存放多个张量。
如果是张量维度不够,需要升维度,
可以先使用 torch.unsqueeze(dim = expected)
然后再使用torch.cat(dim )
进行拼接;
需要传递梯度的数据,禁止使用numpy
, 也禁止先使用numpy,然后再转换成张量的这种情况出现;
这是因为pytorch的机制是只有是 Tensor
张量的类型,才会有梯度等属性值,如果是numpy这些类别,这些变量并会丢失其梯度值。
1.2 __init__()方法使用
class ex:
def __init__(self):
pass
__init__
方法必须接受至少一个参数即self,
Python中,self是指向该对象本身的一个引用,
通过在类的内部使用self变量,
类中的方法可以访问自己的成员变量,简单来说,self.varname的意义为”访问该对象的varname属性“
当然,__init__()
中可以封装任意的程序逻辑,这是允许的,init()方法还接受任意多个其他参数,允许在初始化时提供一些数据,例如,对于刚刚的worker类,可以这样写:
class worker:
def __init__(self,name,pay):
self.name=name
self.pay=pay
这样,在创建worker类的对象时,必须提供name和pay两个参数:
b=worker('Jim',5000)
Python会自动调用worker.init()方法,并传递参数。
细节参考这里init方法
1.3 内置函数setattr()
此时,可以使用python自带的内置函数 setattr()
,和对应的getattr()
setattr(object, name, value)
object – 对象。
name – 字符串,对象属性。
value – 属性值。
对已存在的属性进行赋值:
>>>class A(object):
... bar = 1
...
>>> a = A()
>>> getattr(a, 'bar') # 获取属性 bar 值
1
>>> setattr(a, 'bar', 5) # 设置属性 bar 值
>>> a.bar
5
如果属性不存在会创建一个新的对象属性,并对属性赋值:>>>class A():
... name = "runoob"
...
>>> a = A()
>>> setattr(a, "age", 28)
>>> print(a.age)
28
>>>
setattr() 语法
setattr(object, name, value)
object – 对象。
name – 字符串,对象属性。
value – 属性值。
1.4 网络模型的构建
注意到,在python的 __init__()
函数中,self
本身就是该类的对象的一个引用,即self是指向该对象本身的一个引用,
利用上述这一点,当在神经网络中,
需要给多个属性进行实例化时,
且这多个属性使用的是同一个类进行实例化.
则使用 setattr(self, string, object1)
添加属性;
class Temporal_GroupTrans(nn.Module):
def __init__(self, num_classes=10,num_groups=35, drop_prob=0.5, pretrained= True):
super(Temporal_GroupTrans, self).__init__()
conv_block = Basic_slide_conv()
for i in range( num_groups):
setattr(self, "group" + str(i), conv_block)
# 自定义transformer模型的初始化, CustomTransformerModel() 在该类中传入初始化模型的参数,
# nip:512输入序列中,每个列向量的编码维度,16:注意力头的个数
# 600:中间mlp 隐藏层的维数, 6: 堆叠transforEncode编码模块的个数;
self.trans_model = CustomTransformerModel(512,16,600, 6,droupout=0.5,nclass=4)
则使用 getattr(self, string, object1)
获取属性;
trans_input_sequence = []
for i in range(0, num_groups, ):
# 每组语谱图的大小是一个 (bt, ch,96,12)的矩阵,组与组之间没有重叠;
cur_group = x[:, :, :, 12 * i:12 * (i + 1)]
# VARIABLE_fun = "self.group" # 每一组,与之对应的卷积模块;
# cur_fun = eval(VARIABLE_fun + str(i ))
cur_fun = getattr(self, 'group'+str(i))
cur_group_out = cur_fun(cur_group).unsqueeze(dim=1) # [bt,1, 512]
trans_input_sequence.append(cur_group_out)
来源:https://blog.csdn.net/chumingqian/article/details/129417691


猜你喜欢
- Git 代码管理工具,类似 SVN 客户端。安装步骤:1、官网下载Git:https://gitforwindows.org/2、双击运行,
- 写入:1:把gif图像文件读入内存(一个变量strTemp)。2:写入数据库。Dim binTmp() As
- 夹角余弦(Cosine)也可以叫余弦相似度。 几何中夹角余弦可用来衡量两个向量方向的差异,机器学习中借用这一概念来衡量样本向量之间的差异。(
- 目录一.准备工作二.预览1.启动2.运行3.结果三.设计思路四.源代码4.1 GUI.py4.2 Search_Apps.py五.总结一.准
- 本文所用环境:Python 3.6.5 |Anaconda custom (64-bit)|引言由于某些原因,需要用python读取二进制文
- 安装laravel框架命令行cd进入指定目录下,执行composer create-project --prefer-dist larave
- 一、Pythont如何打开 txt 格式的文件?1.首先我使用pycharm创建一个项目,然后在这个项目里面再创建一个python的包,然后
-   MySQL行转列,对经常处理数据的同学们来说,一定是不陌生的,甚至是印象深刻,因为它大概率困扰过你,
- 简述生活中经常要用到各种要求的证件照电子版,红底,蓝底,白底等,大部分情况我们只有其中一种,所以通过技术手段进行合成,用ps处理证件照,由于
- 本文实例讲述了python装饰器原理与用法。分享给大家供大家参考,具体如下:你会Python嘛?我会!那你给我讲下Python装饰器吧!Py
- 工作中我们经常需要判断某个变量/属性是否为undefined。通常有两种写法// 方式1 typeof age === 'undef
- 网上存在这么一个例子 obj = pd.Series([7,-5,7,4,2,0,4])obj.rank()输出为:0 6.51
- 问题我们使用anoconda创建envs环境下的Tensorflow-gpu版的,但是当我们在Pycharm设置里的工程中安装Keras后,
- 1、yield,将函数变为 generator (生成器)例如:斐波那契数列def fib(num): a, b, c = 1,
- ALTER TABLE 表名字 ADD CONSTRAINT pk_表名字 PRIMARY KEY( SNumber, SDate ); S
- 实现功能excel表格中有4列数,分别为RMF计算得到的 β,γ,势能面及组态,需要挑选出相同 β 值下势能面最低时的组态。为了减小数据量,
- 前言今天在开发时发现一个奇怪的问题,我手动改完数据库竟然不生效,反复确认环境无误后猜测是缓存的问题,因为是新接手的项目,代码还不熟悉,仔细一
- XML的未来 现在你已经知道XML。确实,结构有点复杂,而且DTD有各种可以定义文档可以包含的内容的选项。但还不只这些。考虑一个数据交换对其
- 这篇文章主要介绍了Python3打包exe代码2种方法实例解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,
- 【问】使用FCKeditor添加文章时,在文章最后多了逗号。【答】此情况发生在asp环境中。在asp里对于 提交的表单信息中如果有相同nam