pytorch + visdom 处理简单分类问题的示例
作者:泛泛之素 发布时间:2022-08-17 12:56:58
标签:pytorch,visdom
环境
系统 : win 10
显卡:gtx965m
cpu :i7-6700HQ
python 3.61
pytorch 0.3
包引用
import torch
from torch.autograd import Variable
import torch.nn.functional as F
import numpy as np
import visdom
import time
from torch import nn,optim
数据准备
use_gpu = True
ones = np.ones((500,2))
x1 = torch.normal(6*torch.from_numpy(ones),2)
y1 = torch.zeros(500)
x2 = torch.normal(6*torch.from_numpy(ones*[-1,1]),2)
y2 = y1 +1
x3 = torch.normal(-6*torch.from_numpy(ones),2)
y3 = y1 +2
x4 = torch.normal(6*torch.from_numpy(ones*[1,-1]),2)
y4 = y1 +3
x = torch.cat((x1, x2, x3 ,x4), 0).float()
y = torch.cat((y1, y2, y3, y4), ).long()
可视化如下看一下:
visdom可视化准备
先建立需要观察的windows
viz = visdom.Visdom()
colors = np.random.randint(0,255,(4,3)) #颜色随机
#线图用来观察loss 和 accuracy
line = viz.line(X=np.arange(1,10,1), Y=np.arange(1,10,1))
#散点图用来观察分类变化
scatter = viz.scatter(
X=x,
Y=y+1,
opts=dict(
markercolor = colors,
marksize = 5,
legend=["0","1","2","3"]),)
#text 窗口用来显示loss 、accuracy 、时间
text = viz.text("FOR TEST")
#散点图做对比
viz.scatter(
X=x,
Y=y+1,
opts=dict(
markercolor = colors,
marksize = 5,
legend=["0","1","2","3"]
),
)
效果如下:
逻辑回归处理
输入2,输出4
logstic = nn.Sequential(
nn.Linear(2,4)
)
gpu还是cpu选择:
if use_gpu:
gpu_status = torch.cuda.is_available()
if gpu_status:
logstic = logstic.cuda()
# net = net.cuda()
print("###############使用gpu##############")
else : print("###############使用cpu##############")
else:
gpu_status = False
print("###############使用cpu##############")
优化器和loss函数:
loss_f = nn.CrossEntropyLoss()
optimizer_l = optim.SGD(logstic.parameters(), lr=0.001)
训练2000次:
start_time = time.time()
time_point, loss_point, accuracy_point = [], [], []
for t in range(2000):
if gpu_status:
train_x = Variable(x).cuda()
train_y = Variable(y).cuda()
else:
train_x = Variable(x)
train_y = Variable(y)
# out = net(train_x)
out_l = logstic(train_x)
loss = loss_f(out_l,train_y)
optimizer_l.zero_grad()
loss.backward()
optimizer_l.step()
训练过成观察及可视化:
if t % 10 == 0:
prediction = torch.max(F.softmax(out_l, 1), 1)[1]
pred_y = prediction.data
accuracy = sum(pred_y ==train_y.data)/float(2000.0)
loss_point.append(loss.data[0])
accuracy_point.append(accuracy)
time_point.append(time.time()-start_time)
print("[{}/{}] | accuracy : {:.3f} | loss : {:.3f} | time : {:.2f} ".format(t + 1, 2000, accuracy, loss.data[0],
time.time() - start_time))
viz.line(X=np.column_stack((np.array(time_point),np.array(time_point))),
Y=np.column_stack((np.array(loss_point),np.array(accuracy_point))),
win=line,
opts=dict(legend=["loss", "accuracy"]))
#这里的数据如果用gpu跑会出错,要把数据换成cpu的数据 .cpu()即可
viz.scatter(X=train_x.cpu().data, Y=pred_y.cpu()+1, win=scatter,name="add",
opts=dict(markercolor=colors,legend=["0", "1", "2", "3"]))
viz.text("<h3 align='center' style='color:blue'>accuracy : {}</h3><br><h3 align='center' style='color:pink'>"
"loss : {:.4f}</h3><br><h3 align ='center' style='color:green'>time : {:.1f}</h3>"
.format(accuracy,loss.data[0],time.time()-start_time),win =text)
我们先用cpu运行一次,结果如下:
然后用gpu运行一下,结果如下:
发现cpu的速度比gpu快很多,但是我听说机器学习应该是gpu更快啊,百度了一下,知乎上的答案是:
我的理解就是gpu在处理图片识别大量矩阵运算等方面运算能力远高于cpu,在处理一些输入和输出都很少的,还是cpu更具优势。
添加神经层:
net = nn.Sequential(
nn.Linear(2, 10),
nn.ReLU(), #激活函数
nn.Linear(10, 4)
)
添加一层10单元神经层,看看效果是否会有所提升:
使用cpu:
使用gpu:
比较观察,似乎并没有什么区别,看来处理简单分类问题(输入,输出少)的问题,神经层和gpu不会对机器学习加持。
来源:https://blog.csdn.net/tonydz0523/article/details/79032936


猜你喜欢
- 前言大家应该都知道,我们在mysql运维中出现过不少因为update/delete条件错误导致数据被误更新或者删除的case,为避免类似问题
- 当使用桌面应用程序的时候,有没有那么一瞬间,想学习一下桌面应用程序开发?建议此次课程大家稍作了解不要浪费太多时间,因为没有哪家公司会招聘以为
- 目录1.一般的模型构造、训练、测试流程2.自定义损失和指标3.使用tf.data构造数据4.样本权重和类权重5.多输入多输出模型6.使用回
- 本文实例讲述了Golang算法问题之数组按指定规则排序的方法。分享给大家供大家参考,具体如下:给出一个二维数组,请将这个二维数组按第i列(i
- 逐步回归的基本思想是将变量逐个引入模型,每引入一个解释变量后都要进行F检验,并对已经选入的解释变量逐个进行t检验,当原来引入的解释变量由于后
- 本文实例讲述了Python实现PS滤镜中马赛克效果。分享给大家供大家参考,具体如下:这里利用 Python 实现PS 滤镜中的马赛克效果,具
- 如今,基本每个网站都会需要到Tab切换展示内容的滑动门效果应用,这种效果可以在更少的页面空间内,展示更多的网站内容,节约空间,方便用户集中操
- 安装MySQL5.1过程中,我把以前MySQL5.0的GUI工具和服务器全部删掉,安装目录全部删掉,数据文件目录名字改掉,注册表用完美卸载清
- 不能将 SQL Server 2000 日志传送配置升级到 SQL Server 2008。数据库维护计划向导是 SQL Server 20
- 前言只有你想不到,没有我找不到写不了的好游戏!哈喽。我是你们的栗子同学啦~今天小编去了我朋友家里玩儿,看到了一个敲可爱的小狗狗,是我朋友养的
- 1. 概述JSON (JavaScript Object Notation)是一种使用广泛的轻量数据格式. Python标准库中的json模
- 本文实例讲述了Python图像滤波处理操作。分享给大家供大家参考,具体如下:在图像处理中,经常需要对图像进行平滑、锐化、边界增强等滤波处理。
- 详细解读Jquery各Ajax函数:$.get(),$.post(),$.ajax(),$.getJSON()一,$.get(url,[da
- 该脚本是为了结合之前的编写的脚本,来实现数据的比对模块,实现数据的自动化!由于数据格式是定死的,该代码只做参考,有什么问题可以私信我!CSV
- BULK INSERT以用户指定的格式复制一个数据文件至数据库表或视图中。 语法:BULK INSERT [ [ 'database
- 1、设置web.config文件。以下为引用的内容:<system.web> ...... <globalization
- 如果你的PHP网站换了空间,必定要对Mysql数据库进行转移,一般的转移的方法,是备份再还原,有点繁琐,而且由于数据库版本的不一样会导致数据
- 切片与数组数组数组是具有相同 唯一类型 的一组以编号且长度固定的数据项序列数组声明var identifier [len]type切片切片(
- 1、更新NVIDIA驱动 选对应自己显卡的驱动,(选studio版本,不要game版本)驱动链接 2、添加Anacond
- 图像在计算机中的存储图像其实就是一个像素值组成的矩阵。1、黑白或灰度图像如何存储在计算机中在这里,我们已经采取了黑白图像,也被称为一个灰度图