python神经网络Pytorch中Tensorboard函数使用
作者:Bubbliiiing 发布时间:2021-03-30 04:27:01
标签:python,神经网络,Pytorch,Tensorboard,函数
所需库的安装
很多人问Pytorch要怎么可视化,于是决定搞一篇。
tensorboardX==2.0
tensorflow==1.13.2
由于tensorboard原本是在tensorflow里面用的,所以需要装一个tensorflow。会自带一个tensorboard。
也可以不装tensorboardX,直接使用pytorch当中的自带的Tensorboard。导入方式如下:
from torch.utils.tensorboard import SummaryWriter
不过由于我使用pytorch当中的自带的Tensorboard的时候有一些bug。所以还是使用tensorboardX来写这篇博客。
常用函数功能
1、SummaryWriter()
这个函数用于创建一个tensorboard文件,其中常用参数有:
log_dir:tensorboard文件的存放路径flush_secs:表示写入tensorboard文件的时间间隔
调用方式如下:
writer = SummaryWriter(log_dir='logs',flush_secs=60)
2、writer.add_graph()
这个函数用于在tensorboard中创建Graphs,Graphs中存放了网络结构,其中常用参数有:
model:pytorch模型
input_to_model:pytorch模型的输入
如下所示为graphs:
调用方式如下:
if Cuda:
graph_inputs = torch.from_numpy(np.random.rand(1,3,input_shape[0],input_shape[1])).type(torch.FloatTensor).cuda()
else:
graph_inputs = torch.from_numpy(np.random.rand(1,3,input_shape[0],input_shape[1])).type(torch.FloatTensor)
writer.add_graph(model, (graph_inputs,))
3、writer.add_scalar()
这个函数用于在tensorboard中加入loss,其中常用参数有:
tag:标签,如下图所示的Train_loss
scalar_value:标签的值
global_step:标签的x轴坐标
调用方式如下:
writer.add_scalar('Train_loss', loss, (epoch*epoch_size + iteration))
4、tensorboard --logdir=
在完成tensorboard文件的生成后,可在命令行调用该文件,tensorboard网址。具体代码如下:
tensorboard --logdir=D:\Study\Collection\Tensorboard-pytorch\logs
示例代码
import torch
from torch.autograd import Variable
import torch.nn.functional as functional
from tensorboardX import SummaryWriter
import matplotlib.pyplot as plt
import numpy as np
# x的shape为(100,1)
x = torch.from_numpy(np.linspace(-1,1,100).reshape([100,1])).type(torch.FloatTensor)
# y的shape为(100,1)
y = torch.sin(x) + 0.2*torch.rand(x.size())
class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
# Applies a linear transformation to the incoming data: :math:y = xA^T + b
# 全连接层,公式为y = xA^T + b
self.hidden = torch.nn.Linear(n_feature, n_hidden)
self.predict = torch.nn.Linear(n_hidden, n_output)
def forward(self, x):
# 隐含层的输出
hidden_layer = functional.relu(self.hidden(x))
output_layer = self.predict(hidden_layer)
return output_layer
# 类的建立
net = Net(n_feature=1, n_hidden=10, n_output=1)
writer = SummaryWriter('logs')
graph_inputs = torch.from_numpy(np.random.rand(2,1)).type(torch.FloatTensor)
writer.add_graph(net, (graph_inputs,))
# torch.optim是优化器模块
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
# 均方差loss
loss_func = torch.nn.MSELoss()
for t in range(1000):
prediction = net(x)
loss = loss_func(prediction, y)
# 反向传递步骤
# 1、初始化梯度
optimizer.zero_grad()
# 2、计算梯度
loss.backward()
# 3、进行optimizer优化
optimizer.step()
writer.add_scalar('loss',loss, t)
writer.close()
效果如下:
来源:https://blog.csdn.net/weixin_44791964/article/details/106701052
0
投稿
猜你喜欢
- 在网上的一些资料的基础上自己又添了些新内容,算是Python socket编程练手吧。#coding=utf-8import socketi
- kNN算法是k-近邻算法的简称,主要用来进行分类实践,主要思路如下:1.存在一个训练数据集,每个数据都有对应的标签,也就是说,我们知道样本集
- 实现搜索历史-[即时自动补全&联想搜索]无论是新闻、内容、还是电商平台,联想输入已经成为搜索功能的标配,早已不是什么新鲜事物。我们随
- 一般情况下,局域网里的终端比如本地服务器设置静态IP的好处是可以有效减少网络连接时间,原因是过程中省略了每次联网后从DHCP服务器获取IP地
- 1.折线图 plt.plot()常用的一些参数:颜色(color):‘c’ 青红(cyan)&
- vue中为何方法要写在methods里面1.methods是什么?首先先来段代码,我们在template中设定一个按钮,在点击按钮的时候打印
- 因为即将开始淘宝的项目,在前端方面必然要深入了解taobao ued规范,规范还是比较全的,只是对taobao.com的编码和字符集的选择有
- js对文字进行编码涉及3个函数:escape,encodeURI,encodeURIComponent,相应3个解码函数:unescape,
- IEBlog公布了开发中的Internet Explorer 8 Beta2版本的最新功能.IE8 Beta2在第一个版本的基础上做出了很大
- 利用over(),将统计信息计算出来,然后直接筛选结果集declare @t table(ProductID int,ProductName
- 问题描述使用pandas库的read_excel()方法读取外部excel文件报错, 截图如下好像是缺少了什么方法的样子问题分析分析个啥,
- 连接MySQL时出现1449与1045异常解决办法mysql 1449 : The user specified as a definer
- 具体代码如下所示:<?php//在子类或类内部用“::”调用本类或父类时,不是静态调用方法,而是范围解析操作符。class Paren
- 谈到用户界面交互总少不了事件,前面一系列文章介绍的鼠标光标、坐标、弹出式提示框等实现的底层其实都是事件处理,只不过matplotlib或其他
- 保存文件名太长OSError: [Errno 36] File name too lon问题描述安装pip install python-d
- 这个东西算是我被这个shuffle坑了的一个总结吧!首先我得告诉你一件事,那就是pytorch中的tensor,如果直接使用random.s
- /* * Date Format 1.2.3 * (c) 2007-2009 Steven Levithan * MIT license *
- 如下所示:# 创建一个空的 DataFramedf_empty = pd.DataFrame()#或者df_empty = pd.DataF
- 疫情还没结束,小编只能宅在家里,哪哪也去不了,今天突发奇想给大家分享一篇教程关于Python paramiko 模块浅谈与SSH主要功能模拟
- 循环导入是指两个文件相互导入对方,形成一个导入循环。这会导致Python无法确定哪个模块应该先导入,进而出现错误。举个Flask中的例子:在