pytorch简单实现神经网络功能
作者:那小子真混蛋 发布时间:2022-02-21 23:29:49
标签:pytorch,神经网络
一、基本
(1)利用pytorch建好的层进行搭建
import torch
from torch import nn
from torch.nn import functional as F
#定义一个MLP网络
class MLP(nn.Module):
'''
网络里面主要是要定义__init__()、forward()
'''
def __init__(self):
'''
这里定义网络有哪些层(比如nn.Linear,Conv2d……)[可不含激活函数]
'''
super().__init__()#调用Module(父)初始化
self.hidden = nn.Linear(5,10)
self.out = nn.Linear(10,2)
def forward(self,x):
'''
这里定义前向传播的顺序,即__init__()中定义的层是按怎样的顺序进行连接以及传播的[在这里加上激活函数,以构造复杂函数,提高拟合能力]
'''
return self.out(F.relu(self.hidden(x)))
上面的3层感知器可以用于解决一个简单的现实问题:给定5个特征,输出0-1类别概率值,是一个简单的2分类解决方案。
搭建一些简单的网络时,可以用nn.Sequence(层1,层2,……,层n)一步到位:
import torch
from torch import nn
from torch.nn import functional as F
net = nn.Sequential(nn.Linear(5,10),nn.ReLU(),nn.Linear(10,2))
但是nn.Sequence仅局限于简单的网络搭建,而自定义网络可以实现复杂网络结构。
(1)中定义的MLP大致如上(5个输入->全连接->ReLU()->输出)
(2)使用网络
import torch
from torch import nn
from torch.nn import functional as F
net = MLP()
x = torch.randn((15,5))#15个samples,5个输入属性
out = net(x)
#也可调用forward->"out = net.forward(x)"
print(out)
#print(out.shape)
tensor([[-0.0760, -0.1026],
[-0.3277, -0.2332],
[-0.0314, -0.1921],
[ 0.0131, -0.1473],
[-0.0650, -0.2310],
[ 0.3009, -0.5510],
[ 0.1491, -0.0928],
[-0.1438, -0.1304],
[-0.1945, -0.1944],
[ 0.1088, -0.2249],
[ 0.0016, -0.2334],
[ 0.1401, -0.3709],
[-0.1864, -0.1764],
[ 0.0775, -0.0160],
[ 0.0150, -0.3198]], grad_fn=<AddmmBackward>)
二、进阶
(1)构建较复杂的网络结构
a. Sequence、net套娃
import torch
from torch import nn
from torch.nn import functional as F
class MLP2(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Sequential(nn.Linear(5,10),nn.ReLU(),nn.Linear(10,5))
self.out = nn.Linear(5,4)
def forward(self,x):
return self.out(F.relu(self.net(x)))
net2 = nn.Sequential(MLP2(),nn.ReLU(),nn.Linear(4,2))
net2.eval()
# eval()等价print(net2)
Sequential(
(0): MLP2(
(net): Sequential(
(0): Linear(in_features=5, out_features=10, bias=True)
(1): ReLU()
(2): Linear(in_features=10, out_features=5, bias=True)
)
(out): Linear(in_features=5, out_features=4, bias=True)
)
(1): ReLU()
(2): Linear(in_features=4, out_features=2, bias=True)
)
(2) 参数
a. 权重、偏差的访问
#访问权重和偏差
print(net2[2].weight)#注意weight是parameter类型,.data访问数值
print(net2[2].bias.data)
#输出所有权重、偏差
print(*[(name,param) for name,param in net2[2].parameters()])
b. 不同网络之间共享参数
shared = nn.Linear(8,8)
net = nn.Sequential(nn.Linear(5,8),nn.ReLU(),shared,nn.ReLU(),shared)
print(net[2].weight.data[0])
net[2].weight.data[0][0] = 100
print(net[2].weight.data[0][0])
print(net[2].weight.data[0] == net[4].weight.data[0])
net.eval()
c. 参数初始化
def init_Linear(m):
if type(m) == nn.Linear:
nn.init.normal_(m.weight,mean = 0,std = 0.01) #将权重按照均值为0,标准差为0.01的正态分布进行初始化
nn.init.zeros_(m.bias) #将偏差置为0
def init_const(m):
if type(m) == nn.Linear:
nn.init.constant_(m.weight,42) #将权重全部置为42
def my_init(m):
if type(m) == nn.Linear:
'''
对weight和bias自定义初始化
'''
pass
#如何调用?
net2.apply(init_const) #在net2中进行遍历,对每个Linear执行初始化
(3)自定义层(__init__()中可含输入输出层)
a. 不带输入输出的自定义层(输入输出一致,x数进,x数出,对每个值进行相同的操作,类似激活函数)
b. 带输入输出的自定义层
import torch
from torch import nn
from torch.nn import functional as F
#a
class decentralized(nn.Module):
def __init__(self):
super().__init__()
def forward(self,x):
return x-x.mean()
#b
class my_Linear(nn.Module):
def __init__(self,dim_in,dim_out):
super().__init__()
self.weight = nn.Parameter(torch.ones(dim_in,dim_out)) #由于x行数为dim_out,列数为dim_in,要做乘法,权重行列互换
self.bias = nn.Parameter(torch.randn(dim_out))
def forward(self,x):
return F.relu(torch.matmul(x,self.weight.data)+self.bias.data)
tmp = my_Linear(5,3)
print(tmp.weight)
(4)读写
#存取任意torch类型变量
x = torch.randn((20,20))
torch.save(x,'X') #存
y = torch.load('X') #取
#存储网络
torch.save(net2.state_dict(),'Past_parameters') #把所有参数全部存储
clone = nn.Sequential(MLP2(),nn.ReLU(),nn.Linear(4,2)) #存储时同时存储网络定义(网络结构)
clone.load_state_dict(torch.load('Past_parameters'))
clone.eval()
来源:https://www.cnblogs.com/BGM-WCJ/p/16695133.html
0
投稿
猜你喜欢
- 问题的起源早些时候使用with实现了一版全局进程锁,希望实现以下效果:with CacheLock("test_lock"
- 目录1. 简介2. 示例代码13. 示例代码24. 启动异常1. 简介Gunicorn(Green Unicorn)是给Unix用的WSGI
- JavaScript Length 字符长度函数,在很多时间我们会用length函数了,因为你得前台判断一个用户输入
- 看代码吧~package mainimport ("fmt""io""net/http&q
- 总说由于pytorch 0.4版本更新实在太大了, 以前版本的代码必须有一定程度的更新. 主要的更新在于 Variable和Tensor的合
- 1.新建四个层,放入相应图片,模特层的z-index值设为0。2.把第一个层移到模特身上,找出衣服刚好穿上时层的top和left值,记下来,
- 这篇文章将介绍在Python中使用 "frozenset "函数的指南,该函数返回一个新的frozenset类型的Pyt
- 穿过云朵升一级是要花6个金币的,有的时候金币真的很重要前言嗨喽,大家好呀!这里是魔王~一天晚上,天空中掉下一颗神奇的豌豆种子,正好落在了梦之
- 前言在pytorch中, 想删除tensor中的指定行列,原本以为有个函数或者直接把某一行赋值为[]就可以,结果发现没这么简单,因此用了一个
- 可视性的问题几乎在每次不同产品的用户测试中都会出现:用户总是对页面的某些元素、功能视若无睹,或根本无视。基于此,对这个问题进行了一番小小的研
- 前言最近天气好像有了点小脾气,总是在万分晴朗得时候耍点小性子~阴会天,下上一会的雨~提醒我们时刻记得带伞哦,不然会被雨淋或者被太阳公公晒到
- 最近老板叫做一个数据查重的小练习,涉及从一个包含中文字段的文件中提取出其中的中文字段并存储,使用php开发。中间涉及到php正则表达式中文匹
- 需求背景女朋友的论文需要爬取YouTube视频热评,但爬下来的都是外文。主要设计 读取一个表格文件,获取需要翻译的文本
- getatter()通过方法名字符串调用方法,这个方法最主要的作用就是实现反射机制,也就是说可以通过字符串获取方法实例,这样就可以把一个类可
- 曾经有许多创造性的logo设计案例,logo设计资源和logo设计指导张贴在互联网的各个角落。这些帮助会为你的logo设计创造一个功能强大的
- <?php //本功能主要是利用文件修改时间函数filemtime与现在时间作减法判断是否更新内容。 $cahetime=2;//设置
- 作为设计主管,Peter Stern 已经领导 microsoft.com 重新设计了主页并且开发了五个不同的交互工具,这些工具被用于下载中
- 一、Socketserver实现FTP,文件上传、下载目录结构1、socketserver实现ftp文件上传下载,可以同时多用户登录、上传、
- 六、XML展望 任何一项新技术的产生都是有其需求背景的,XML的诞生是在HTML遇到不可克服的困难之后。近年来HTML在许多复杂的Web应用
- 一、简单介绍正则表达式是一种小型的、高度专业化的编程语言,并不是python * 有的,是许多编程语言中基础而又重要的一部分。在python中