Pytorch基本变量类型FloatTensor与Variable用法
作者:jingxian 发布时间:2022-10-10 21:14:45
标签:Pytorch,FloatTensor,Variable
pytorch中基本的变量类型当属FloatTensor(以下都用floattensor),而Variable(以下都用variable)是floattensor的封装,除了包含floattensor还包含有梯度信息
pytorch中的dochi给出一些对于floattensor的基本的操作,比如四则运算以及平方等(链接),这些操作对于floattensor是十分的不友好,有时候需要写一个正则化的项需要写很长的一串,比如两个floattensor之间的相加需要用torch.add()来实现
然而正确的打开方式并不是这样
韩国一位大神写了一个pytorch的turorial,其中包含style transfer的一个代码实现
for step in range(config.total_step):
# Extract multiple(5) conv feature vectors
target_features = vgg(target) # 每一次输入到网络中的是同样一张图片,反传优化的目标是输入的target
content_features = vgg(Variable(content))
style_features = vgg(Variable(style))
style_loss = 0
content_loss = 0
for f1, f2, f3 in zip(target_features, content_features, style_features):
# Compute content loss (target and content image)
content_loss += torch.mean((f1 - f2)**2) # square 可以进行直接加-操作?可以,并且mean对所有的元素进行均值化造作
# Reshape conv features
_, c, h, w = f1.size() # channel height width
f1 = f1.view(c, h * w) # reshape a vector
f3 = f3.view(c, h * w) # reshape a vector
# Compute gram matrix
f1 = torch.mm(f1, f1.t())
f3 = torch.mm(f3, f3.t())
# Compute style loss (target and style image)
style_loss += torch.mean((f1 - f3)**2) / (c * h * w) # 总共元素的数目?
其中f1与f2,f3的变量类型是Variable,作者对其直接用四则运算符进行加减,并且用python内置的**进行平方操作,然后
# -*-coding: utf-8 -*-
import torch
from torch.autograd import Variable
# dtype = torch.FloatTensor
dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU
# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10
# Randomly initialize weights
w1 = torch.randn(D_in, H).type(dtype) # 两个权重矩阵
w2 = torch.randn(D_in, H).type(dtype)
# operate with +-*/ and **
w3 = w1-2*w2
w4 = w3**2
w5 = w4/w1
# operate the Variable with +-*/ and **
w6 = Variable(torch.randn(N, D_in).type(dtype))
w7 = Variable(torch.randn(N, D_in).type(dtype))
w8 = w6 + w7
w9 = w6*w7
w10 = w9**2
print(1)
基本上调试的结果与预期相符
所以,对于floattensor以及variable进行普通的+-×/以及**没毛病
来源:https://blog.csdn.net/u013517182/article/details/93051322


猜你喜欢
- 前言在上一章中,我们通过基础的搭建,成功的渲染了列表页面.但是,其中的问题是很多的.这一章,我们来解决这些问题.使用 v-bind 绑定数据
- mysqladmin是MySQL官方提供的shell命令行工具,它的参数都需要在shell
- 动态页面的模拟点击:以斗鱼直播为例:http://www.douyu.com/directory/all爬取每页的房间名、直播类型、主播名称
- 前言结构体是包含多个字段的集合类型,用于将数据组合为记录。这样可以将与同一实体相关联的数据利落地封装到一个轻量的类型定义中,然后通过对该结构
- 本文实例讲述了Python二叉树定义与遍历方法。分享给大家供大家参考,具体如下:二叉树基本概述:二叉树是有限个元素的几个,如果为空则为空二叉
- 如何用ASP输出HTML文件?<!--#include file="top.inc"--><
- 如下所示:l = [1, 2, 3, 5]l_one = [2, 8, 6, 10]print set(l) & set(l_one
- 前言最近学习了Fiddler抓包工具的简单使用,通过抓包,我们可以抓取到HTTP请求,并对其进行分析。现在我准备尝试着结合Python来模拟
- python中的集合什么是集合?集合是一个无序的不重复元素序列常用来对两个列表进行交并差的处理集合与列表一样,支持所有数据类型集合与列表的区
- request post 列表的方法今天拿着已经写好的服务接口, 尝试传送一些列表, 发现传送的结果跟实际传送的数据并不一致,然后又开始了漫
- 本文实例讲述了Python3实现的反转单链表算法。分享给大家供大家参考,具体如下:反转一个单链表。方案一:迭代# Definition fo
- 设计是一个输入-输出的过程,因为首先有用户的需求,客户的项目才有设计的产生,设计是带有目的性和市场行为的,当然也有一部分的创造性设计,仅仅为
- 续上一篇文章:vue2.0 开发实践总结之入门篇 ,如果没有看过的可以移步看一下。 本篇文章目录如下:1. vue 组
- 基于Bootstrap jQuery.validate Form表单验证实践项目结构 :github 上源码地址:https://githu
- 第一次写ASP类,实现功能:分段统计程序执行时间,输出统计表等.程序代码:Class ccClsProcessTimeRecord
- 问题在Django中使用mysql偶尔会出现数据库连接丢失的情况,错误通常有如下两种OperationalError: (2006,
- 共享标签默认情况下,git push 命令并不会传送标签到远程仓库服务器上。在创建完标签后,你必须显式地(手动)推送标签到远程服务
- 微信小程序canvas写字板效果及实例写字板效果:书写文字,画板重置,导出图片,导出图片前判断是否书写内容app.json:添加一个路由:&
- 今天研究了个开源项目,数据库是mysql的,其中的脚本数据需要备份,由于本人的机器时mac pro,而且mac下的数据库连接工具都不怎么好用
- 在使用python爬虫技术采集数据信息时,经常会遇到在返回的网页信息中,无法抓取动态加载的可用数据。例如,获取某网页中,商品价格时就会出现此