支持PyTorch的einops张量操作神器用法示例详解
作者:木盏 发布时间:2023-10-17 23:13:06
标签:PyTorch,einops,张量操作
今天做visual transformer研究的时候,发现了einops这么个神兵利器,决定大肆安利一波。
先看链接:https://github.com/arogozhnikov/einops
安装:
pip install einops
基础用法
einops的强项是把张量的维度操作具象化,让开发者“想出即写出”。举个例子:
from einops import rearrange
# rearrange elements according to the pattern
output_tensor = rearrange(input_tensor, 'h w c -> c h w')
用'h w c -> c h w'就完成了维度调换,这个功能与pytorch中的permute相似。但是,einops的rearrange玩法可以更高级:
from einops import rearrange
import torch
a = torch.randn(3, 9, 9) # [3, 9, 9]
output = rearrange(a, 'c (r p) w -> c r p w', p=3)
print(output.shape) # [3, 3, 3, 9]
这就是高级用法了,把中间维度看作r×p,然后给出p的数值,这样系统会自动把中间那个维度拆解成3×3。这样就完成了[3, 9, 9] -> [3, 3, 3, 9]的维度转换。
这个功能就不是pytorch的内置功能可比的。
除此之外,还有reduce和repeat,也是很好用。
from einops import repeat
import torch
a = torch.randn(9, 9) # [9, 9]
output_tensor = repeat(a, 'h w -> c h w', c=3) # [3, 9, 9]
指定c,就可以指定复制的层数了。
再看reduce:
from einops import reduce
import torch
a = torch.randn(9, 9) # [9, 9]
output_tensor = reduce(a, 'b c (h h2) (w w2) -> b h w c', 'mean', h2=2, w2=2)
这里的'mean'指定池化方式。 相信你看得懂,不懂可留言提问~
高级用法
einops也可以嵌套在pytorch的layer里,请看:
# example given for pytorch, but code in other frameworks is almost identical
from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, ReLU
from einops.layers.torch import Rearrange
model = Sequential(
Conv2d(3, 6, kernel_size=5),
MaxPool2d(kernel_size=2),
Conv2d(6, 16, kernel_size=5),
MaxPool2d(kernel_size=2),
# flattening
Rearrange('b c h w -> b (c h w)'),
Linear(16*5*5, 120),
ReLU(),
Linear(120, 10),
)
这里的Rearrange是nn.module的子类,直接可以当作网络层放到模型里~
一个字,绝。
来源:https://blog.csdn.net/leviopku/article/details/116204922


猜你喜欢
- 1.在Home(你取的项目名)的config.php中添加如下配置<?phpreturn array( &nbs
- vue3.0 beta 版本已经发布有一阵子了,是时候上手体验一波了~注意,本文所有演示都是基于 vue3.0 beta 版本,不保证后续正
- 本文实例为大家分享了python七夕浪漫表白的具体代码,供大家参考,具体内容如下from turtle import *from time
- 内置函数Built-in Functionsabs()dict()help()min()setattr()all()dir()hex()ne
- 大家好,我是丁小杰!今天和大家分享Pandas中四种有关数据透视的通用函数,在数据处理中遇到这类需求时,能够很好地应对。pandas.mel
- PS:下面是转过来的,用于记录下,这个不是正则的初衷,只是用了REGEXP而已,正则的更灵活更方便 将comment表中的author_ur
- 有时需要获取远程网站的某些信息,而服务器又限制了GET方式,只能通过POST数据提交,这个时候我们可以通过asp来实现模拟提交post数据,
- 在windows下安装配置Ulipad今天推荐一款轻便的文本编辑器Ulipad,用来写一些小的Python脚本非常方便。Ulipad下载地址
- 上一篇介绍了如何在 Oracle 生成随机数字、字符串、日期、验证码以及 UUID,今天我们继续讨论在 MySQL 中生成各种随机数据的方法
- 在Python中用matplotlib画图的时候,为了区分曲线的类型,给曲线上面加一些标识或者颜色。以下是颜色和标识的汇总。颜色(color
- myisam_max_[extra]_sort_file_size足够大delay_key_write减少io,提高写入性能bulk_ins
- 在查询凭证、审核凭证时出现“列前缀tempdb.无效: 未指定表名”的错误提示,怎么解决?原因:是因为SQL2000无法识别计算机名称中”-
- Python2.7: 使用Pyhook模块监听鼠标键盘事件-获取坐标。因该模块对Python3 有兼容性问题,故采用python2.7解释器
- Python 正则表达式是什么学习 Python 正则表达式离不开 re 模块,所以本篇博客会配合 re 模块进行编写。re 库是 Pyth
- 通过查看书籍,自己总结了一下,怎样用python代码实现调用笔记本摄像头的功能。这主要是通过opencv中cv2模块来实现这个功能。其中是调
- python逆序的三位数程序每次读入一个正3位数,然后输出按位逆序的数字。注意:当输入的数字含有结尾的0时,输出不应带有前导的0。比如输入7
- 新建图像文件后选Channels面板,新建Alpha1通道; 做压
- 代码如下#!/bin/python#coding=utf-8#python-version=2.75  
- 01.简介当我们使用的鱼眼镜头视角大于160°时,OpenCV中用于校准镜头“经典”方法的效果可能就不是和理想了。即使我们仔细遵循OpenC
- webpack.config.js文件通常放在项目的根目录中,它本身也是一个标准的Commonjs规范的模块。var webpack = r