用pytorch的nn.Module构造简单全链接层实例
作者:AItitanic 发布时间:2022-01-04 00:00:17
标签:pytorch,nn.Module,全链接层
python版本3.7,用的是虚拟环境安装的pytorch,这样随便折腾,不怕影响其他的python框架
1、先定义一个类Linear,继承nn.Module
import torch as t
from torch import nn
from torch.autograd import Variable as V
class Linear(nn.Module):
'''因为Variable自动求导,所以不需要实现backward()'''
def __init__(self, in_features, out_features):
super().__init__()
self.w = nn.Parameter( t.randn( in_features, out_features ) ) #权重w 注意Parameter是一个特殊的Variable
self.b = nn.Parameter( t.randn( out_features ) ) #偏值b
def forward( self, x ): #参数 x 是一个Variable对象
x = x.mm( self.w )
return x + self.b.expand_as( x ) #让b的形状符合 输出的x的形状
2、验证一下
layer = Linear( 4,3 )
input = V ( t.randn( 2 ,4 ) )#包装一个Variable作为输入
out = layer( input )
out
#成功运行,结果如下:
tensor([[-2.1934, 2.5590, 4.0233], [ 1.1098, -3.8182, 0.1848]], grad_fn=<AddBackward0>)
下面利用Linear构造一个多层网络
class Perceptron( nn.Module ):
def __init__( self,in_features, hidden_features, out_features ):
super().__init__()
self.layer1 = Linear( in_features , hidden_features )
self.layer2 = Linear( hidden_features, out_features )
def forward ( self ,x ):
x = self.layer1( x )
x = t.sigmoid( x ) #用sigmoid()激活函数
return self.layer2( x )
测试一下
perceptron = Perceptron ( 5,3 ,1 )
for name,param in perceptron.named_parameters():
print( name, param.size() )
输出如预期:
layer1.w torch.Size([5, 3])
layer1.b torch.Size([3])
layer2.w torch.Size([3, 1])
layer2.b torch.Size([1])
来源:https://blog.csdn.net/AItitanic/article/details/97611356


猜你喜欢
- 具体代码如下所述:import sysfrom PySide2.QtGui import *from PySide2.QtCore impo
- 1.lxml库简介lxml 是 Python 常用的文档解析库,能够高效地解析 HTML/XML 文档,常用于 Python 爬虫。lxml
- 1. 真值测试所谓真值测试,是指当一种类型对象出现在if或者while条件语句中时,对象值表现为True或者False。弄清楚各种情况下的真
- 一、项目概述本次项目目标是实现对自动生成的带有各种噪声的车牌识别。在噪声干扰情况下,车牌字符分割较困难,此次车牌识别是将车牌7个字符同时训练
- 说明:我这里要把MySql数据库存放目录/var/lib/mysql下面的pw85数据库备份到/home/mysql_data里面,并且保存
- 本文实例为大家分享了python实现TCP文件接收发送的具体代码,供大家参考,具体内容如下下一篇分享:udp收发的实现先运行服务器端打开接收
- 【1】 以XML 返回 (1)未定义属性的 select logisticsId,logisticsName from LogisticsC
- 这篇文章主要介绍了简单了解为什么python函数后有多个括号,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需
- 第一种方法:采用git命令操作1、例如仓库中有下面的代码(版本1)2、现在继续编写代码,并且提交到远程仓库中(版本2)3、回退到版本1中gi
- cli2去掉eslint检查器报错eslint在编写过程中及其严格,甚至单引号和双引号或者空格注释都会引起报错,导致项目无法正常运行因此,只
- 前言在java中,反斜杠“\”转义是“\”,因此表示一个“\”要使用“\\”,如果是正则表达式,那么表示一个“\”需要用“\\\\”,在my
- 这里使用TensorFlow实现一个简单的卷积神经网络,使用的是MNIST数据集。网络结构为:数据输入层–卷积层1–池化层1–卷积层2–池化
- 1、序言  上一节快速搭建Express开发系统步骤,对如何使用express-generator创建一
- 为数据库配置比较大的内存,可以有效提高数据库性能。因为数据库在运行过程中,会在内存中划出一块区域来作为数据缓存。通常情况下,用户访问数据库时
- 概述很多人接触Python,都是从爬虫开始,其实很多语言都可以做爬虫,只是Python相对其他语言来说,更加简单而已。但是Python并不止
- 1.彻底弄懂CSS盒子模式一(DIV布局快速入门) 2.彻底弄懂CSS盒子模式二(导航栏实例) 3.彻底弄懂CSS盒子模式三(浮动的表演和清
- 父传子:1、 在父组件的子组件标签上通过 :传递到子组件的数据名="需要传递的数据"在这里为了大家区分我将父组件中的数据
- 前段时间看到letcode上的元音字母字符串反转的题目,今天来研究一下字符串反转的内容。主要有三种方法:1.切片法(最简洁的一种)#切片法d
- PhotoSwipe插件能实现手机端点击图片全屏放大 再双击图片放大等功能PhotoSwipe插件官方网站 http://www.photo
- Python传入参数的方法有:位置参数、默认参数、可变参数、关键字参数、和命名关键字参数、以及各种参数调用的组合写在前面Python唯一支持