PyTorch小功能之TensorDataset解读
作者:菜鸟向前冲fighting 发布时间:2023-02-26 06:06:27
PyTorch之TensorDataset
TensorDataset 可以用来对 tensor 进行打包,就好像 python 中的 zip 功能。
该类通过每一个 tensor 的第一个维度进行索引。
因此,该类中的 tensor 第一维度必须相等。
from torch.utils.data import TensorDataset
import torch
from torch.utils.data import DataLoader
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([44, 55, 66, 44, 55, 66, 44, 55, 66, 44, 55, 66])
train_ids = TensorDataset(a, b)
# 切片输出
print(train_ids[0:2])
print('=' * 80)
# 循环取数据
for x_train, y_label in train_ids:
print(x_train, y_label)
# DataLoader进行数据封装
print('=' * 80)
train_loader = DataLoader(dataset=train_ids, batch_size=4, shuffle=True)
for i, data in enumerate(train_loader, 1): # 注意enumerate返回值有两个,一个是序号,一个是数据(包含训练数据和标签)
x_data, label = data
print(' batch:{0} x_data:{1} label: {2}'.format(i, x_data, label))
运行结果:
(tensor([[1, 2, 3],
[4, 5, 6]]), tensor([44, 55]))
================================================================================
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
tensor([1, 2, 3]) tensor(44)
tensor([4, 5, 6]) tensor(55)
tensor([7, 8, 9]) tensor(66)
================================================================================
batch:1 x_data:tensor([[1, 2, 3],
[1, 2, 3],
[4, 5, 6],
[4, 5, 6]]) label: tensor([44, 44, 55, 55])
batch:2 x_data:tensor([[4, 5, 6],
[7, 8, 9],
[7, 8, 9],
[7, 8, 9]]) label: tensor([55, 66, 66, 66])
batch:3 x_data:tensor([[1, 2, 3],
[1, 2, 3],
[7, 8, 9],
[4, 5, 6]]) label: tensor([44, 44, 66, 55])
注意:TensorDataset 中的参数必须是 tensor
Pytorch中TensorDataset的快速使用
Pytorch中,TensorDataset()可以快速构建训练所用的数据,不用使用自建的Mydataset(),如果没有熟悉适用的dataset可以使用TensorDataset()作为暂时替代。
只需要把data和label作为参数输入,就可以快速构建,之后便可以用Dataloader处理。
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
data = np.loadtxt('x.txt')
label = np.loadtxt('y.txt')
data = torch.tensor(data)
label = torch.tensor(label)
train_data = TensorDataset(data, label)
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
来源:https://blog.csdn.net/qq_40211493/article/details/107529148


猜你喜欢
- 指定的代码页特性无效。 codepage属性:是指出网页的代码页 如果制作的网页脚本与WEB服务端的默认代码页不同,则必须指明代码页: 代码
- var YX = { //得到JS内置数据类型的类型,返回值包括[Date,RegExp,Number,String,Array,Boole
- IE8主页http://www.microsoft.com/windows/products/winfamily/ie/ie8/defaul
- Default.aspx<%@ Page Language="C#" AutoEventWireup="
- 本文实例讲述了JS实现简单的二元方程计算器功能。分享给大家供大家参考,具体如下:<!DOCTYPE HTML PUBLIC "
- 前言索引的本质是存储引擎用于快速查询记录的一种数据结构。特别是数据表中数据特别多的时候,索引对于数据库的性能就愈发重要。在数据量比较大的时候
- 在 Python 中,我们可以使用基本的索引操作来获取数组中的元素。然而,有时候我们需要获取一个数组的子数组,也就是只获取数组中的一部分元素
- get_template()中使用子目录把所有的模板都存放在一个目录下可能会让事情变得难以掌控。 你可能会考虑把模板存放在你模板目录的子目录
- 文档介绍利用python写“猜数字”,“猜词语”,“谁是卧底”这三个游戏,从而快速掌握python编程的入门知识,包括python语法/列表
- 我也一一试过,结果是:中文乱码问题没解决,mysql服务却不能启动了, 汗颜了,还是自己动手解决吧,我这里也截图了,方便参观。我用的是app
- 使用Django中遇到这样一个需求,对一个表的几个字段做 联合唯一索引,例如学生表中 姓名和班级 2个字段在一起表示一个唯一记录。Djang
- 1 非贪婪flag>>> re.findall(r"a(\d+?)", "a23b"
- 环境:adobe flash CS4,VS2008 , Access2003 实现步骤: 1、创建ASP.net页面 testCommuni
- 目录问题一:默认的 HTTP Client问题二:默认的 Http Transport总结HTTP(超文本传输协议)是一种用于客户端和服务器
- 模块:包含定义函数和变量的python文件,可以被别的程序引入。os模块是操作系统接口模块,提供了一些方便使用操作系统相关功能函数,这里介绍
- 使用rpm安装方式安装完MySQL数据库后,数据文件的默认路径为/var/lib/mysql,然而根目录并不适合用于存储数据文件。原路径:/
- 问一下谁知道如何用 javascript 获取硬盘信息1.获得硬盘当前有几个盘符.2.每个盘符的 大小,已经使用的大小,和没有使用的大小原理
- 在《python深度学习》这本书中。一、21页mnist十分类导入数据集from keras.datasets import mnist(t
- 文章所罗列的问题虽然看似简单,但是每个背后都涵盖了一个或几个大家容易忽视的基础知识点,希望能够帮助到你的面试和平时工作。Q1第一个问题关于弱
- 进入查询窗口后,输入下面的语句: CREATE INDEX mycolumn_index ON mytable (myclumn) 这个语句