pytorch中使用cuda扩展的实现示例
作者:outthinker 发布时间:2021-02-17 23:46:55
标签:pytorch,cuda
以下面这个例子作为教程,实现功能是element-wise add;
(pytorch中想调用cuda模块,还是用另外使用C编写接口脚本)
第一步:cuda编程的源文件和头文件
// mathutil_cuda_kernel.cu
// 头文件,最后一个是cuda特有的
#include <curand.h>
#include <stdio.h>
#include <math.h>
#include <float.h>
#include "mathutil_cuda_kernel.h"
// 获取GPU线程通道信息
dim3 cuda_gridsize(int n)
{
int k = (n - 1) / BLOCK + 1;
int x = k;
int y = 1;
if(x > 65535) {
x = ceil(sqrt(k));
y = (n - 1) / (x * BLOCK) + 1;
}
dim3 d(x, y, 1);
return d;
}
// 这个函数是cuda执行函数,可以看到细化到了每一个元素
__global__ void broadcast_sum_kernel(float *a, float *b, int x, int y, int size)
{
int i = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
if(i >= size) return;
int j = i % x; i = i / x;
int k = i % y;
a[IDX2D(j, k, y)] += b[k];
}
// 这个函数是与c语言函数链接的接口函数
void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream)
{
int size = x * y;
cudaError_t err;
// 上面定义的函数
broadcast_sum_kernel<<<cuda_gridsize(size), BLOCK, 0, stream>>>(a, b, x, y, size);
err = cudaGetLastError();
if (cudaSuccess != err)
{
fprintf(stderr, "CUDA kernel failed : %s\n", cudaGetErrorString(err));
exit(-1);
}
}
#ifndef _MATHUTIL_CUDA_KERNEL
#define _MATHUTIL_CUDA_KERNEL
#define IDX2D(i, j, dj) (dj * i + j)
#define IDX3D(i, j, k, dj, dk) (IDX2D(IDX2D(i, j, dj), k, dk))
#define BLOCK 512
#define MAX_STREAMS 512
#ifdef __cplusplus
extern "C" {
#endif
void broadcast_sum_cuda(float *a, float *b, int x, int y, cudaStream_t stream);
#ifdef __cplusplus
}
#endif
#endif
第二步:C编程的源文件和头文件(接口函数)
// mathutil_cuda.c
// THC是pytorch底层GPU库
#include <THC/THC.h>
#include "mathutil_cuda_kernel.h"
extern THCState *state;
int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y)
{
float *a = THCudaTensor_data(state, a_tensor);
float *b = THCudaTensor_data(state, b_tensor);
cudaStream_t stream = THCState_getCurrentStream(state);
// 这里调用之前在cuda中编写的接口函数
broadcast_sum_cuda(a, b, x, y, stream);
return 1;
}
int broadcast_sum(THCudaTensor *a_tensor, THCudaTensor *b_tensor, int x, int y);
第三步:编译,先编译cuda模块,再编译接口函数模块(不能放在一起同时编译)
nvcc -c -o mathutil_cuda_kernel.cu.o mathutil_cuda_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52
import os
import torch
from torch.utils.ffi import create_extension
this_file = os.path.dirname(__file__)
sources = []
headers = []
defines = []
with_cuda = False
if torch.cuda.is_available():
print('Including CUDA code.')
sources += ['src/mathutil_cuda.c']
headers += ['src/mathutil_cuda.h']
defines += [('WITH_CUDA', None)]
with_cuda = True
this_file = os.path.dirname(os.path.realpath(__file__))
extra_objects = ['src/mathutil_cuda_kernel.cu.o'] # 这里是编译好后的.o文件位置
extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]
ffi = create_extension(
'_ext.cuda_util',
headers=headers,
sources=sources,
define_macros=defines,
relative_to=__file__,
with_cuda=with_cuda,
extra_objects=extra_objects
)
if __name__ == '__main__':
ffi.build()
第四步:调用cuda模块
from _ext import cuda_util #从对应路径中调用编译好的模块
a = torch.randn(3, 5).cuda()
b = torch.randn(3, 1).cuda()
mathutil.broadcast_sum(a, b, *map(int, a.size()))
# 上面等价于下面的效果:
a = torch.randn(3, 5)
b = torch.randn(3, 1)
a += b
来源:https://www.cnblogs.com/zf-blog/p/11883166.html


猜你喜欢
- 一个例子让你彻底明白原型对象和原型链开篇之前对js中的原型链和原型对象有所了解,每当别人问我什么是原型链和原型对象时,我总是用很官方(其实自
- 本文为大家分享了python实现外卖信息管理系统的具体代码,供大家参考,具体内容如下一、需求分析 需求分析包含如下:1、问题描述 以外卖信息
- Django里面集成了SQLite的数据库,对于初期研究来说,可以用这个学习。第一步,创建数据库就涉及到建表等一系列的工作,在此之前,要先在
- 导入库和数据首先,我们需要导入PyTorch和PyG库,然后准备好我们的数据。例如,我们可以使用以下方式生成一个简单的随机数据集:from
- 项目地址:https://github.com/chen0495/pythonCrawlerForJSU环境python 3.5即以上req
- 我们需要将【小组销量排名表.xlsx】通过邮件发送给【组长邮箱.xlsx】中的各个组长。这里会学一个新的知识点—&
- 1、re.match()的用法re.match()方法是从起始位置开始匹配一个模式,匹配成功返回一个对象,未匹配成功返回None。语法:re
- 关于Python的文件遍历,大概有两种方法,一种是较为便利的os.walk(),还有一种是利用os.listdir()递归遍历。方法一:利用
- 类型转换和类型断言类型转换语法:Type(expression)类型断言语法为:expression.(Type)1.类型转换示例代码pac
- 前言密码安全是非常重要的,因此我们在代码中往往需要对密码进行加密,以此保证密码的安全加依赖<!-- jasypt --><
- 本文实例讲述了Python wxpython模块响应鼠标拖动事件操作。分享给大家供大家参考,具体如下:wxpython鼠标拖动事件小案例:#
- //有1-22个文件夹,各文件夹下有Detect_0文件夹,此文件夹下有source与mask文件夹,目的是将需要获取图片的文件夹下的图片复
- 前言用python编程绘图,其实非常简单。中学生、大学生、研究生都能通过这10篇教程从入门到精通!快速绘制几种简单的柱状图。1垂直柱图(普通
- 在本篇的开始之前,我必须阐明,我们对数组无论是索引还是切片,我是通过编号(或称为序列号)来进行操作,请记住:无论是 0轴(行)还是 1轴(列
- 本教程为大家分享了Ubuntu手动安装mysql5.7.10的过程,供大家参考,具体内容如下1、下载安装包MySQL官网下载地址选择系统版本
- 会用到的库的1、selenium的webdriver2、tesserocr或者pytesseract进行图像识别3、pillow的Image
- 使用Django的ORM操作的时候,想要获取本条,上一条,下一条。初步的想法是写3个ORM,3个ORM如下:本条:models.Obj.ob
- 一、前言xlwt模块是python中专门用于写入Excel的拓展模块,可以实现创建表单、写入指定单元格、指定单元格样式等人工实现的功能,一句
- 查询mysql的操作信息show status -- 显示全部mysql操作信息show status like "com_ins
- 本文实例为大家分享了javascript自定义加载loading效果的具体代码,供大家参考,具体内容如下加载中图片,底色为白色(看不到)效果