pytorch教程网络和损失函数的可视化代码示例
作者:xz1308579340 发布时间:2023-11-26 16:13:51
标签:pytorch,可视化,损失函数
1.效果
2.环境
1.pytorch
2.visdom
3.python3.5
3.用到的代码
# coding:utf8
import torch
from torch import nn, optim # nn 神经网络模块 optim优化函数模块
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchvision import transforms, datasets
from visdom import Visdom # 可视化处理模块
import time
import numpy as np
# 可视化app
viz = Visdom()
# 超参数
BATCH_SIZE = 40
LR = 1e-3
EPOCH = 2
# 判断是否使用gpu
USE_GPU = True
if USE_GPU:
gpu_status = torch.cuda.is_available()
else:
gpu_status = False
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
# 数据引入
train_dataset = datasets.MNIST('../data', True, transform, download=False)
test_dataset = datasets.MNIST('../data', False, transform)
train_loader = DataLoader(train_dataset, BATCH_SIZE, True)
# 为加快测试,把测试数据从10000缩小到2000
test_data = torch.unsqueeze(test_dataset.test_data, 1)[:1500]
test_label = test_dataset.test_labels[:1500]
# visdom可视化部分数据
viz.images(test_data[:100], nrow=10)
#viz.images(test_data[:100], nrow=10)
# 为防止可视化视窗重叠现象,停顿0.5秒
time.sleep(0.5)
if gpu_status:
test_data = test_data.cuda()
test_data = Variable(test_data, volatile=True).float()
# 创建线图可视化窗口
line = viz.line(np.arange(10))
# 创建cnn神经网络
class CNN(nn.Module):
def __init__(self, in_dim, n_class):
super(CNN, self).__init__()
self.conv = nn.Sequential(
# channel 为信息高度 padding为图片留白 kernel_size 扫描模块size(5x5)
nn.Conv2d(in_channels=in_dim, out_channels=16,kernel_size=5,stride=1, padding=2),
nn.ReLU(),
# 平面缩减 28x28 >> 14*14
nn.MaxPool2d(kernel_size=2),
nn.Conv2d(16, 32, 3, 1, 1),
nn.ReLU(),
# 14x14 >> 7x7
nn.MaxPool2d(2)
)
self.fc = nn.Sequential(
nn.Linear(32*7*7, 120),
nn.Linear(120, n_class)
)
def forward(self, x):
out = self.conv(x)
out = out.view(out.size(0), -1)
out = self.fc(out)
return out
net = CNN(1,10)
if gpu_status :
net = net.cuda()
#print("#"*26, "使用gpu", "#"*26)
else:
#print("#" * 26, "使用cpu", "#" * 26)
pass
# loss、optimizer 函数设置
loss_f = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=LR)
# 起始时间设置
start_time = time.time()
# 可视化所需数据点
time_p, tr_acc, ts_acc, loss_p = [], [], [], []
# 创建可视化数据视窗
text = viz.text("<h1>convolution Nueral Network</h1>")
for epoch in range(EPOCH):
# 由于分批次学习,输出loss为一批平均,需要累积or平均每个batch的loss,acc
sum_loss, sum_acc, sum_step = 0., 0., 0.
for i, (tx, ty) in enumerate(train_loader, 1):
if gpu_status:
tx, ty = tx.cuda(), ty.cuda()
tx = Variable(tx)
ty = Variable(ty)
out = net(tx)
loss = loss_f(out, ty)
#print(tx.size())
#print(ty.size())
#print(out.size())
sum_loss += loss.item()*len(ty)
#print(sum_loss)
pred_tr = torch.max(out,1)[1]
sum_acc += sum(pred_tr==ty).item()
sum_step += ty.size(0)
# 学习反馈
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 每40个batch可视化一下数据
if i % 40 == 0:
if gpu_status:
test_data = test_data.cuda()
test_out = net(test_data)
print(test_out.size())
# 如果用gpu运行out数据为cuda格式需要.cpu()转化为cpu数据 在进行比较
pred_ts = torch.max(test_out, 1)[1].cpu().data.squeeze()
print(pred_ts.size())
rightnum = pred_ts.eq(test_label.view_as(pred_ts)).sum().item()
#rightnum =sum(pred_tr==ty).item()
# sum_acc += sum(pred_tr==ty).item()
acc = rightnum/float(test_label.size(0))
print("epoch: [{}/{}] | Loss: {:.4f} | TR_acc: {:.4f} | TS_acc: {:.4f} | Time: {:.1f}".format(epoch+1, EPOCH,
sum_loss/(sum_step), sum_acc/(sum_step), acc, time.time()-start_time))
# 可视化部分
time_p.append(time.time()-start_time)
tr_acc.append(sum_acc/sum_step)
ts_acc.append(acc)
loss_p.append(sum_loss/sum_step)
viz.line(X=np.column_stack((np.array(time_p), np.array(time_p), np.array(time_p))),
Y=np.column_stack((np.array(loss_p), np.array(tr_acc), np.array(ts_acc))),
win=line,
opts=dict(legend=["Loss", "TRAIN_acc", "TEST_acc"]))
# visdom text 支持html语句
viz.text("<p style='color:red'>epoch:{}</p><br><p style='color:blue'>Loss:{:.4f}</p><br>"
"<p style='color:BlueViolet'>TRAIN_acc:{:.4f}</p><br><p style='color:orange'>TEST_acc:{:.4f}</p><br>"
"<p style='color:green'>Time:{:.2f}</p>".format(epoch, sum_loss/sum_step, sum_acc/sum_step, acc,
time.time()-start_time),
win=text)
sum_loss, sum_acc, sum_step = 0., 0., 0.
来源:https://blog.csdn.net/xz1308579340/article/details/85015343


猜你喜欢
- 下载地址下载地址: https://dev.mysql.com/downloads/mysql/解压安装将下载好的zip压缩包解压到你的安装
- 首先 跳过权限表模式启动MySQL:mysqld --skip-grant-tables &从现在开始,你将踏入第一个坑
- 从Python字符串中删除最后一个分号或者逗号第一种方法使用 str.rstrip() 方法从字符串中删除最后一个逗号,例如 new_str
- 一、条件简化我们编写的查询语句的搜索条件本质上是一个表达式,这些表达式可能比较繁杂,或者不能高效的执行,MySQL的查询优化器会为我们简化这
- Windows下的分隔符默认的是逗号,而MAC的分隔符是分号。拿到一份用分号分割的CSV文件,在Win下是无法正确读取的,因为CSV模块默认
- 京东商品详细的请求处理,是先显示html,然后再ajax请求处理显示价格。1.可以运行js,并解析之后得到的html2.模拟js请求,得到价
- Lambda函数,即Lambda 表达式(lambda expression),是一个匿名函数(不存在函数名的函数),Lambda表达式基于
- 最近使用工作需要,使用了Navicat8.2版本,发现备份数据都是默认存储在C盘,这个就比较郁闷了。重做系统忘记转移了。那不就死定了?找了一
- 说明: 本文的自动更新功能使用的项目为 electron-vue 脚手架搭建一个默认项目。 参考的文章如下:electron-vue 中文文
- 个人想到的解决方法有两种,一种是 .replace(' old ',' new ')
- MySQL数据库是一款非常好用的数据库管理系统,但是相对来说卸载起来麻烦一些这里给大家分享下MySQL数据库如何卸载干净~1 停止MySQL
- 本文介绍基于Python语言gdal模块,实现多波段HDF栅格图像文件的读取、处理与像元值可视化(直方图绘制)等操作。另外,基于gdal等模
- 在我们python中输入输出函数在程序中运用较为广泛,运算符常用于if判断的条件中,今天我来给大家讲解这两项概念.input输入和print
- tf.keras.layers模块中的函数from __future__ import print_function as _print_f
- Usage example (libtiff wrapper)from libtiff import TIFF# to open a tif
- Python3中二叉树前序遍历的迭代解决方案A Binary Tree二叉树是分层数据结构,其中每个父节点最多有 2 个子节点。在今天的文章
- 之前希望在手机端使用深度模型做OCR,于是尝试在手机端部署tensorflow模型,用于图像分类。思路主要是想使用tflite部署到安卓端,
- 微博上讨论MySQL在删除大表engine=innodb(30G+)时,如何减少MySQL hang的时间,现做一下简单总结: 当buffe
- 正则表达式中的符号例子 | 是或的关系,只要存在就会被捕获匹配到的数据只按字符串顺序返回,而不是按照匹配规则返回In [18]:
- 本文实例讲述了JS上传图片前实现图片预览效果的方法。分享给大家供大家参考。具体实现方法如下:<!doctype html public