python神经网络Pytorch中Tensorboard函数使用
作者:Bubbliiiing 发布时间:2021-03-30 04:27:01
标签:python,神经网络,Pytorch,Tensorboard,函数
所需库的安装
很多人问Pytorch要怎么可视化,于是决定搞一篇。
tensorboardX==2.0
tensorflow==1.13.2
由于tensorboard原本是在tensorflow里面用的,所以需要装一个tensorflow。会自带一个tensorboard。
也可以不装tensorboardX,直接使用pytorch当中的自带的Tensorboard。导入方式如下:
from torch.utils.tensorboard import SummaryWriter
不过由于我使用pytorch当中的自带的Tensorboard的时候有一些bug。所以还是使用tensorboardX来写这篇博客。
常用函数功能
1、SummaryWriter()
这个函数用于创建一个tensorboard文件,其中常用参数有:
log_dir:tensorboard文件的存放路径flush_secs:表示写入tensorboard文件的时间间隔
调用方式如下:
writer = SummaryWriter(log_dir='logs',flush_secs=60)
2、writer.add_graph()
这个函数用于在tensorboard中创建Graphs,Graphs中存放了网络结构,其中常用参数有:
model:pytorch模型
input_to_model:pytorch模型的输入
如下所示为graphs:
调用方式如下:
if Cuda:
graph_inputs = torch.from_numpy(np.random.rand(1,3,input_shape[0],input_shape[1])).type(torch.FloatTensor).cuda()
else:
graph_inputs = torch.from_numpy(np.random.rand(1,3,input_shape[0],input_shape[1])).type(torch.FloatTensor)
writer.add_graph(model, (graph_inputs,))
3、writer.add_scalar()
这个函数用于在tensorboard中加入loss,其中常用参数有:
tag:标签,如下图所示的Train_loss
scalar_value:标签的值
global_step:标签的x轴坐标
调用方式如下:
writer.add_scalar('Train_loss', loss, (epoch*epoch_size + iteration))
4、tensorboard --logdir=
在完成tensorboard文件的生成后,可在命令行调用该文件,tensorboard网址。具体代码如下:
tensorboard --logdir=D:\Study\Collection\Tensorboard-pytorch\logs
示例代码
import torch
from torch.autograd import Variable
import torch.nn.functional as functional
from tensorboardX import SummaryWriter
import matplotlib.pyplot as plt
import numpy as np
# x的shape为(100,1)
x = torch.from_numpy(np.linspace(-1,1,100).reshape([100,1])).type(torch.FloatTensor)
# y的shape为(100,1)
y = torch.sin(x) + 0.2*torch.rand(x.size())
class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
# Applies a linear transformation to the incoming data: :math:y = xA^T + b
# 全连接层,公式为y = xA^T + b
self.hidden = torch.nn.Linear(n_feature, n_hidden)
self.predict = torch.nn.Linear(n_hidden, n_output)
def forward(self, x):
# 隐含层的输出
hidden_layer = functional.relu(self.hidden(x))
output_layer = self.predict(hidden_layer)
return output_layer
# 类的建立
net = Net(n_feature=1, n_hidden=10, n_output=1)
writer = SummaryWriter('logs')
graph_inputs = torch.from_numpy(np.random.rand(2,1)).type(torch.FloatTensor)
writer.add_graph(net, (graph_inputs,))
# torch.optim是优化器模块
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
# 均方差loss
loss_func = torch.nn.MSELoss()
for t in range(1000):
prediction = net(x)
loss = loss_func(prediction, y)
# 反向传递步骤
# 1、初始化梯度
optimizer.zero_grad()
# 2、计算梯度
loss.backward()
# 3、进行optimizer优化
optimizer.step()
writer.add_scalar('loss',loss, t)
writer.close()
效果如下:
来源:https://blog.csdn.net/weixin_44791964/article/details/106701052
0
投稿
猜你喜欢
- 一:最近,经常碰到有网友问,如何使vbscript和javascipt传递变量。不知道为什么要这么做。因为每一种脚本语言几乎都可以完成所需要
- 矩阵相乘需要前面矩阵的行数与后面矩阵的列数相同方可相乘。第一步,先将前面矩阵的每一行分别与后面矩阵的列相乘,作为结果矩阵的行列;第二步算出结
- 原文地址:30 Days of Mootools 1.2 Tutorials - Day 20 - A Few Mootools Tabs项
- 获取评论贴的请求头与表单数据下一篇在这里这里,我们随便选取一个网站,获取该贴评论后的请求头,表单数据以及评论贴链接。(因为涉及敏感信息,自己
- 我们进行CSS网页布局的时候,都知道它需要符合XHTML1.0规范。如果我们在进行CSS网页布局的时候,还在使用被W3C废弃的元素,那就失去
- 作者:Scott Gerber原标题:Mobile App Development: 10 Tips for Small Business
- 最近在做一个游戏数据统计后台,最基础的功能是通过分析注册登录日志来展示用户数据。在公司内部测试,用户量很少,所以就没有发现什么性能问题。但是
- 目录一、比较汽车性能二、比较不同城市近期天气状况雷达图是以从同一点开始的轴上表示的三个或更多个定量变量的二维图表的形式显示多变量数据的图形方
- 将有安全问题的SQL过程删除.比较全面.一切为了安全!删除了调用shell,注册表,COM组件的破坏权限MS SQL SERVER2000使
- 基本简介dot函数为numpy库下的一个函数,主要用于矩阵的乘法运算,其中包括:向量内积、多维矩阵乘法和矩阵与向量的乘法。1. 向量内积向量
- 一、弹窗事件是什么?弹窗事件就是在我们执行某操作的时候,弹出信息框给出提示。或收集数据的时候,弹出窗口收集信息,不想收集可以取消隐藏。二、简
- php输出文字乱码的解决办法:在php文件最开头写上:<?phpheader('Content-type: text/html
- 今天闲逛在网上时,看到一个11px大小的字体,显示却很清晰,赶紧查看站点的CSS,这字体称叫做:PMingLiu。效果相当不错,相比于我们使
- 在团队意见PK中,运用对方的知识背景说服对方,这就是技术性击倒。这样通常能把对方驳得哑口无言,我经常被这样击倒,甚至觉得怎么那么多牛逼的设计
- 在cssrain整理的一个 试题集 中有这么一道题:<SCRIPT LANGUAGE="JavaScript"&g
- 知识点: Array方法: sort:降序 reverse:反序 效果: 代码: <style> *{ margin
- <% Function XMLEncode(byVal sText) sText = Replace(sText, "&am
- 背景大家好,我是J哥。我们常常面临着大量的重复性工作,通过人工方式处理往往耗时耗力易出错。而Python在办公自动化方面具有天然优势,分分钟
- 本文实例为大家分享了python opencv识别图像轮廓的具体代码,供大家参考,具体内容如下要求:用矩形或者圆形框住图片中的云朵(不要求全
- 本文实例讲述了php+mysqli实现批量替换数据库表前缀的方法。分享给大家供大家参考。具体分析如下:在php中有时我们要替换数据库中表前缀