pytorch中的自定义反向传播,求导实例
作者:xuxiaoyuxuxiaoyu 发布时间:2021-08-07 06:57:53
pytorch中自定义backward()函数。在图像处理过程中,我们有时候会使用自己定义的算法处理图像,这些算法多是基于numpy或者scipy等包。
那么如何将自定义算法的梯度加入到pytorch的计算图中,能使用Loss.backward()操作自动求导并优化呢。下面的代码展示了这个功能`
import torch
import numpy as np
from PIL import Image
from torch.autograd import gradcheck
class Bicubic(torch.autograd.Function):
def basis_function(self, x, a=-1):
x_abs = np.abs(x)
if x_abs < 1 and x_abs >= 0:
y = (a + 2) * np.power(x_abs, 3) - (a + 3) * np.power(x_abs, 2) + 1
elif x_abs > 1 and x_abs < 2:
y = a * np.power(x_abs, 3) - 5 * a * np.power(x_abs, 2) + 8 * a * x_abs - 4 * a
else:
y = 0
return y
def bicubic_interpolate(self,data_in, scale=1 / 4, mode='edge'):
# data_in = data_in.detach().numpy()
self.grad = np.zeros(data_in.shape,dtype=np.float32)
obj_shape = (int(data_in.shape[0] * scale), int(data_in.shape[1] * scale), data_in.shape[2])
data_tmp = data_in.copy()
data_obj = np.zeros(shape=obj_shape, dtype=np.float32)
data_in = np.pad(data_in, pad_width=((2, 2), (2, 2), (0, 0)), mode=mode)
print(data_tmp.shape)
for axis0 in range(obj_shape[0]):
f_0 = float(axis0) / scale - np.floor(axis0 / scale)
int_0 = int(axis0 / scale) + 2
axis0_weight = np.array(
[[self.basis_function(1 + f_0), self.basis_function(f_0), self.basis_function(1 - f_0), self.basis_function(2 - f_0)]])
for axis1 in range(obj_shape[1]):
f_1 = float(axis1) / scale - np.floor(axis1 / scale)
int_1 = int(axis1 / scale) + 2
axis1_weight = np.array(
[[self.basis_function(1 + f_1), self.basis_function(f_1), self.basis_function(1 - f_1), self.basis_function(2 - f_1)]])
nbr_pixel = np.zeros(shape=(obj_shape[2], 4, 4), dtype=np.float32)
grad_point = np.matmul(np.transpose(axis0_weight, (1, 0)), axis1_weight)
for i in range(4):
for j in range(4):
nbr_pixel[:, i, j] = data_in[int_0 + i - 1, int_1 + j - 1, :]
for ii in range(data_in.shape[2]):
self.grad[int_0 - 2 + i - 1, int_1 - 2 + j - 1, ii] = grad_point[i,j]
tmp = np.matmul(axis0_weight, nbr_pixel)
data_obj[axis0, axis1, :] = np.matmul(tmp, np.transpose(axis1_weight, (1, 0)))[:, 0, 0]
# img = np.transpose(img[0, :, :, :], [1, 2, 0])
return data_obj
def forward(self,input):
print(type(input))
input_ = input.detach().numpy()
output = self.bicubic_interpolate(input_)
# return input.new(output)
return torch.Tensor(output)
def backward(self,grad_output):
print(self.grad.shape,grad_output.shape)
grad_output.detach().numpy()
grad_output_tmp = np.zeros(self.grad.shape,dtype=np.float32)
for i in range(self.grad.shape[0]):
for j in range(self.grad.shape[1]):
grad_output_tmp[i,j,:] = grad_output[int(i/4),int(j/4),:]
grad_input = grad_output_tmp*self.grad
print(type(grad_input))
# return grad_output.new(grad_input)
return torch.Tensor(grad_input)
def bicubic(input):
return Bicubic()(input)
def main():
hr = Image.open('./baboon/baboon_hr.png').convert('L')
hr = torch.Tensor(np.expand_dims(np.array(hr), axis=2))
hr.requires_grad = True
lr = bicubic(hr)
print(lr.is_leaf)
loss=torch.mean(lr)
loss.backward()
if __name__ =='__main__':
main()
要想实现自动求导,必须同时实现forward(),backward()两个函数。
1、从代码中可以看出来,forward()函数是针对numpy数据操作,返回值再重新指定为torch.Tensor类型。因此就有这个问题出现了:forward输入input被转换为numpy类型,输出转换为tensor类型,那么输出output的grad_fn参数是如何指定的呢。调试发现,当main()中hr的requires_grad被指定为True,即hr被指定为需要求导的叶子节点。只要Bicubic类继承自torch.autograd.Function,那么output也就是代码中的lr的grad_fn就会被指定为<main.Bicubic object at 0x000001DD5A280D68>,即Bicubic这个类。
2、backward()为求导的函数,gard_output是链式求导法则的上一级的梯度,grad_input即为我们想要得到的梯度。只需要在输入指定grad_output,在调用loss.backward()过程中的某一步会执行到Bicubic的backwward()函数
来源:https://blog.csdn.net/xuxiaoyuxuxiaoyu/article/details/86737492


猜你喜欢
- 序列解包(Sequence Unpacking)是Python中非常重要和常用的一个功能,可以使用非常简洁的形式完成复杂的功能,大幅度提高了
- python2和python3实现在图片上加汉字,最主要的区别还是内部编码方式不一样导致的,在代码上表现为些许的差别。理解了内部编码原理也就
- 很久以前写过如何成为优秀的设计师,近半年来经常做设计评审,有很多感触,顺便写一点下来,我们的Blog也应该有更高的更新频率。言归正传,我认为
- ini文件是windows中经常使用的配置文件,主要的格式为:[Section1]option1 : value1option2 : val
- 本文实例讲述了Python爬虫爬取电影票房数据及图表展示操作。分享给大家供大家参考,具体如下:爬虫电影历史票房排行榜 http://www.
- 前提条件:本地已经安装好oracle单实例,能使用plsql developer连接,或者能使用TNS连接串远程连接到oracle集群读取e
- 前言不要在用手敲生成Excel数据报表了,用Python自动生成Excel数据报表!废话不多说让我们愉快地开始吧~开发工具Python版本:
- 微信小程序全称微信公众平台·小程序,原名微信公众平台·应用号(简称微信应用号)声明•微信小程序开发工具类似于一个轻量级的IDE集成开发环境,
- 约定:import pandas as pdimport numpy as npReIndex重新索引reindex()是pandas对象的
- 楔子由于之前电脑上安装的MySQL版本是比较老的了,大概是5.1的版本,不支持JSON字段功能。而最新开发部门开发的的编辑器产品,使用到了J
- 什么是ODBCODBC是open database connect的缩写,意思是开放式数据库连接利用ODBC进行数据库连接首先要下载数据库!
- 一、使用loadVariables 一个例子简单的描述了如何通过GET方法向服务器端的ASP发送请求: _root. pushAc
- 最近在重新看vue3的rfcs,发现一个细节,原话如下:props that start with on are handled as v-
- 本文实例讲述了Python实现的排列组合计算操作。分享给大家供大家参考,具体如下:1. 调用 scipy 计算排列组合的具体数值>&g
- 1.函数array() 功能:创建一个数组变量 格式:array(list) 参数:list为数组变量中的每个数值列,中间用逗号间隔 例子:
- windows下python的安装教程,供大家参考,具体内容如下—–因为我是个真小白,网上的大多入门教程并不适合我这种超级超级小白,有时候还
- python在进行字符串的拼接时,一般有两种方法,一种是使用+直接相加,另一种是使用joina = "tests"b =
- 本文实例讲述了js表格排序的方法。分享给大家供大家参考。具体如下:<html><head><title>
- 1.创建一个项目django-admin.py startproject HelloWorld2.进入HelloWorld项目,在manag
- 1.C++ 代码Demo.h#pragma oncevoid GeneratorGaussKernel(int ksize, float s