支持PyTorch的einops张量操作神器用法示例详解
作者:木盏 发布时间:2023-10-17 23:13:06
标签:PyTorch,einops,张量操作
今天做visual transformer研究的时候,发现了einops这么个神兵利器,决定大肆安利一波。
先看链接:https://github.com/arogozhnikov/einops
安装:
pip install einops
基础用法
einops的强项是把张量的维度操作具象化,让开发者“想出即写出”。举个例子:
from einops import rearrange
# rearrange elements according to the pattern
output_tensor = rearrange(input_tensor, 'h w c -> c h w')
用'h w c -> c h w'就完成了维度调换,这个功能与pytorch中的permute相似。但是,einops的rearrange玩法可以更高级:
from einops import rearrange
import torch
a = torch.randn(3, 9, 9) # [3, 9, 9]
output = rearrange(a, 'c (r p) w -> c r p w', p=3)
print(output.shape) # [3, 3, 3, 9]
这就是高级用法了,把中间维度看作r×p,然后给出p的数值,这样系统会自动把中间那个维度拆解成3×3。这样就完成了[3, 9, 9] -> [3, 3, 3, 9]的维度转换。
这个功能就不是pytorch的内置功能可比的。
除此之外,还有reduce和repeat,也是很好用。
from einops import repeat
import torch
a = torch.randn(9, 9) # [9, 9]
output_tensor = repeat(a, 'h w -> c h w', c=3) # [3, 9, 9]
指定c,就可以指定复制的层数了。
再看reduce:
from einops import reduce
import torch
a = torch.randn(9, 9) # [9, 9]
output_tensor = reduce(a, 'b c (h h2) (w w2) -> b h w c', 'mean', h2=2, w2=2)
这里的'mean'指定池化方式。 相信你看得懂,不懂可留言提问~
高级用法
einops也可以嵌套在pytorch的layer里,请看:
# example given for pytorch, but code in other frameworks is almost identical
from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, ReLU
from einops.layers.torch import Rearrange
model = Sequential(
Conv2d(3, 6, kernel_size=5),
MaxPool2d(kernel_size=2),
Conv2d(6, 16, kernel_size=5),
MaxPool2d(kernel_size=2),
# flattening
Rearrange('b c h w -> b (c h w)'),
Linear(16*5*5, 120),
ReLU(),
Linear(120, 10),
)
这里的Rearrange是nn.module的子类,直接可以当作网络层放到模型里~
一个字,绝。
来源:https://blog.csdn.net/leviopku/article/details/116204922
0
投稿
猜你喜欢
- 前言有一天朋友A向我抱怨,他的老板要求他把几百份word填好的word表格简历信息整理到excel中,看着他一个个将姓名,年龄……从word
- 一、前情提要为什么要使用Scrapy 框架?前两篇深造篇介绍了多线程这个概念和实战多线程网页爬取多线程爬取网页项目实战经过之前的学习,我们基
- 最近一周每天早上起来第一件事,就是打开新闻软件看疫情相关的新闻。了解下自己和亲友所在城市的确诊人数,但纯数字还是缺乏一个直观的概念。那我们来
- 前面最近,看到不少小伙伴问pytorch如何保存和加载模型,其实这部分pytorch官网介绍的也是很清楚的,感兴趣的点击了解详情🥁🥁🥁但是肯
- 从Android 3.0开始除了我们重点讲解的Fragment外,Action Bar也是一个重要的内容,Action Bar主要是用于代替
- 一.__eq__方法在我们定义一个类的时候,常常想对一个类所实例化出来的两个对象进行判断这两个对象是否是完全相同的。一般情况下,我们认为如果
- 这一版,对虹软的功能进行了一些封装,添加了人脸特征比对,比对结果保存到文件,和从文件提取特征进行比对,大体功能基本都已经实现,可以进行下一步
- 目录一、🌕月亮二、🌕雪花月饼一、🌕月亮导入库matplotlib和numpy,作为工具直接用。from mpl_toolkits.mplot
- 直方图处理直方图从图像内部灰度级的角度对图像进行表述从直方图的角度对图像进行处理,可以达到增强图像显示效果的目的。直方图的含义直方图是图像内
- 这篇文章主要介绍了python StringIO如何在内存中读写str,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学
- 本文实例为大家分享了python实现图像拼接的具体代码,供大家参考,具体内容如下1.待拼接的图像2. 基于SIFT特征点和RANSAC方法得
- 图片外框特征参数: ①dashed:虚线②dotted:点虚线③solid:实线④double:双线⑤groove:沟
- 在python中,通过内嵌集成re模块,程序媛们可以直接调用来实现正则匹配。本文重点给大家介绍python中正则表达式 re.findall
- 一、首先从SQLServer中Error讲起,SQL中错误处理有些怪辟 错误级别同是16但结果都不同。select *
- 如何做一个只搜索本网站的引擎? 用下面两个文件即可实现:searchfiles.html &l
- 学习前言最近在学目标检测……SSD的源码好复杂……看
- 首先我们放出tf2.0关于tf.keras.layers.Conv2D()函数的官方文档,然后逐一对每个参数的含义和用法进行解释:tf.ke
- python openvc 裁剪图片下面是4个坐标代码:import cv2#裁剪图片路径input_path,四个裁剪坐标为:y1,y2,
- asp防止用户同时登陆的方法,实现这个功能可有两种方式:1.使用application用application对象:如果做的是大型社区,可能
- 我的PJBlog在从2.7升级的3.0的时候,犹豫了很久。升级到PJBlog3.0就是看中了新增的静态页面功能,但是同时又担心造成博客出现大