pytorch实现CNN卷积神经网络
作者:小山爱学习 发布时间:2023-07-04 20:23:06
标签:pytorch,神经网络
本文为大家讲解了pytorch实现CNN卷积神经网络,供大家参考,具体内容如下
我对卷积神经网络的一些认识
卷积神经网络是时下最为流行的一种深度学习网络,由于其具有局部感受野等特性,让其与人眼识别图像具有相似性,因此被广泛应用于图像识别中,本人是研究机械故障诊断方面的,一般利用旋转机械的振动信号作为数据。
对一维信号,通常采取的方法有两种,第一,直接对其做一维卷积,第二,反映到时频图像上,这就变成了图像识别,此前一直都在利用keras搭建网络,最近学了pytroch搭建cnn的方法,进行一下代码的尝试。所用数据为经典的minist手写字体数据集
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
`EPOCH = 1
BATCH_SIZE = 50
LR = 0.001
DOWNLOAD_MNIST = True
从网上下载数据集:
```python
train_data = torchvision.datasets.MNIST(
root="./mnist/",
train = True,
transform=torchvision.transforms.ToTensor(),
download = DOWNLOAD_MNIST,
)
print(train_data.train_data.size())
print(train_data.train_labels.size())
```plt.imshow(train_data.train_data[0].numpy(), cmap='autumn')
plt.title("%i" % train_data.train_labels[0])
plt.show()
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_data = torchvision.datasets.MNIST(root="./mnist/", train=False)
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255.
test_y = test_data.test_labels[:2000]
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(
in_channels=1,
out_channels=16,
kernel_size=5,
stride=1,
padding=2,
),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
)
self.conv2 = nn.Sequential(
nn.Conv2d(16, 32, 5, 1, 2),
nn.ReLU(),
nn.MaxPool2d(2),
)
self.out = nn.Linear(32*7*7, 10) # fully connected layer, output 10 classes
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1) # flatten the output of conv2 to (batch_size, 32*7*7)
output = self.out(x)
return output
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
from matplotlib import cm
try: from sklearn.manifold import TSNE; HAS_SK = True
except: HAS_SK = False; print('Please install sklearn for layer visualization')
def plot_with_labels(lowDWeights, labels):
plt.cla()
X, Y = lowDWeights[:, 0], lowDWeights[:, 1]
for x, y, s in zip(X, Y, labels):
c = cm.rainbow(int(255 * s / 9)); plt.text(x, y, s, backgroundcolor=c, fontsize=9)
plt.xlim(X.min(), X.max()); plt.ylim(Y.min(), Y.max()); plt.title('Visualize last layer'); plt.show(); plt.pause(0.01)
plt.ion()
for epoch in range(EPOCH):
for step, (b_x, b_y) in enumerate(train_loader):
output = cnn(b_x)
loss = loss_func(output, b_y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if step % 50 == 0:
test_output = cnn(test_x)
pred_y = torch.max(test_output, 1)[1].data.numpy()
accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
print("Epoch: ", epoch, "| train loss: %.4f" % loss.data.numpy(),
"| test accuracy: %.2f" % accuracy)
plt.ioff()
来源:https://blog.csdn.net/xiaoshan0609/article/details/104382920


猜你喜欢
- 1.前言在 Python 中,说到函数,大家都很容易想到用 def 关键字来声明一个函数:def Hello():
- 本文实例讲述了Django发送html邮件的方法。分享给大家供大家参考。具体如下:在Django中,发送邮件非常的方便,一直没有时间,今天来
- 今天用Python提取了Linux内核源代码的目录树结构,没有怎么写过脚本程序,我居然折腾了2个小时,先是如何枚举出给定目录下的所有文件和文
- 一般我是用<%@ include %>方式来包含这个文件,主要是这样能够被包含页面会跟包含页面在编译时被编译成一个文件,里面的变
- sql脚本是包含一到多个sql命令的sql语句,我们可以将这些sql脚本放在一个文本文件中(我们称之为“sql脚本文件”),然后通过相关的命
- 今天,本文向大家推荐20佳国外的脚本下载网站。1- Hot Scripts2- Code Canyon3- User Scripts4- S
- 在许多场合,你将不得不编写必须处理时间的代码。你可以写一个时钟程序,或者在你的代码中测量两点之间的时间差。无论是哪种方式,知道如何在Go中处
- 1.什么是ORMORM 全拼Object-Relation Mapping.中文意为 对象-关系映射.在MVC/MVT设
- 目录先明确几点赋值浅拷贝深拷贝总结先明确几点不可变类型:该数据类型对象所指定内存中的值不可以被改变。(1)、在改变某个对象的值时,由于其内存
- 完成asp语言对XML文档中指定节点文本的增加、删除、修改、查看 <% '-------------------
- 摘要:利用xlrd读取excel利用xlwt写excel利用xlutils修改excel利用xlrd读取excel先需要在命令行中pip i
- 今天我要讲如何远程调试openstack。首先我们使用的工具是Pycharm.1.首先介绍一下环境我的openstack是使用rdo一键安装
- var a = 0, b = 0;[0, 0].sort(function() {a = 1;return 0;});[0, 1].sort
- 废话不多说,我直接上代码吧!# 递归方法打印多重列表li = [1, [[2, [3]], [4], 5], 6, 7, [8], 9, 1
- SQLServer数据导出到excel有很多种方法,比如dts、ssis、还可以用sql语句调用openrowset。我们这里开拓思路,用C
- 最近一直在用Vs2013调试编译opencv,意外发现一个超级赞的图片查看的插件, 超级方便易用的一个插件,直接以图片形式可视化了openc
- 为何选Nuxt.js?在前后端分离出现之前,传统的web页面都是服务端渲染的,如JSP、PHP、Python Django,还有各种模板技术
- 如何用WSH获取机器的IP配置信息?我们用VBSCRIPT转换了: Option Explicit Dim&n
- 聚集索引,数据实际上是按顺序存储的,数据页就在索引页上。就好像参考手册将所有主题按顺序编排一样。一旦找到了所要搜索的数据,就完成了这次搜索,
- 本文实例为大家分享了OpenCV+face++实现实时人脸识别解锁功能的具体代码,供大家参考,具体内容如下1.背景最近做一个小东西,需要登录