网络编程
位置:首页>> 网络编程>> Python编程>> python机器学习之神经网络

python机器学习之神经网络

作者:q琦一  发布时间:2023-11-10 21:39:19 

标签:python,神经网络,机器学习

手写数字识别算法


import pandas as pd
import numpy as np
from sklearn.neural_network import MLPRegressor  #从sklearn的神经网络中引入多层感知器

data_tr = pd.read_csv('BPdata_tr.txt')  # 训练集样本
data_te = pd.read_csv('BPdata_te.txt')  # 测试集样本
X=np.array([[0.568928884039633],[0.379569493792951]]).reshape(1, -1)#预测单个样本

#参数:hidden_layer_sizes中间层的个数  activation激活函数默认relu  f(x)= max(0,x)负值全部舍去,信号相应正向传播效果好
#random_state随机种子,max_iter最大迭代次数,即结束,learning_rate_init学习率,学习速度,步长
model = MLPRegressor(hidden_layer_sizes=(10,), activation='relu',random_state=10, max_iter=8000, learning_rate_init=0.3)  # 构建模型,调用sklearn实现神经网络算法
model.fit(data_tr.iloc[:, :2], data_tr.iloc[:, 2])    # 模型训练(将输入数据x,结果y放入多层感知器拟合建立模型) .iloc是按位置取数据
pre = model.predict(data_te.iloc[:, :2])              # 模型预测(测试集数据预测,将实际结果与预测结果对比)

pre1 = model.predict(X)#预测单个样本,实际值0.467753075712819
err = np.abs(pre - data_te.iloc[:, 2]).mean()# 模型预测误差(|预测值-实际值|再求平均)

print("模型预测值:",pre,end='\n______________________________\n')
print('模型预测误差:',err,end='\n++++++++++++++++++++++++++++++++\n')
print("单个样本预测值:",pre1,end='\n++++++++++++++++++++++++++++++++\n')

#查看相关参数。
print('权重矩阵:','\n',model.coefs_) #list,length n_layers - 1,列表中的第i个元素表示对应于层i的权重矩阵。
print('偏置矩阵:','\n',model.intercepts_) #list,length n_layers - 1,列表中的第i个元素表示对应于层i + 1的偏置矢量。

python机器学习之神经网络

数字手写识别系统


#数字手写识别系统,DBRHD和MNIST是数字手写识别的数据集
import numpy as np  # 导入numpy工具包
from os import listdir  # 使用listdir模块,用于访问本地文件
from sklearn.neural_network import MLPClassifier #从sklearn的神经网络中引入多层感知器

#自定义函数,将图片转换成向量
def img2vector(fileName):
   retMat = np.zeros([1024], int)  # 定义返回的矩阵,大小为1*1024
   fr = open(fileName)  # 打开包含32*32大小的数字文件
   lines = fr.readlines()  # 读取文件的所有行
   for i in range(32):  # 遍历文件所有行
       for j in range(32):  # 并将01数字存放在retMat中
           retMat[i * 32 + j] = lines[i][j]
   return retMat

#自定义函数,获取数据集
def readDataSet(path):
   fileList = listdir(path)  # 获取文件夹下的所有文件
   numFiles = len(fileList)  # 统计需要读取的文件的数目
   dataSet = np.zeros([numFiles, 1024], int)  # 用于存放所有的数字文件juzheng
   hwLabels = np.zeros([numFiles, 10])  # 用于存放对应的one-hot标签(每个文件都对应一个10列的矩阵)
   for i in range(numFiles):  # 遍历所有的文件
       filePath = fileList[i]  # 获取文件名称/路径
       digit = int(filePath.split('_')[0])  # 通过文件名获取标签,split返回分割后的字符串列表
       hwLabels[i][digit] = 1.0  # 将对应的one-hot标签置1 .one-hot编码,又称独热编码、一位有效编码.one-hot向量将类别变量转换为机器学习算法易于利用的一种形式的过程,这个向量的表示为一项属性的特征向量,也就是同一时间只有一个激活点(不为0),这个向量只有一个特征是不为0的,其他都是0,特别稀疏。
       dataSet[i] = img2vector(path + '/' + filePath)  # 读取文件内容
   return dataSet, hwLabels

#读取训练数据,并训练模型
train_dataSet, train_hwLabels = readDataSet('trainingDigits')

#参数:hidden_layer_sizes中间层的个数,activation激活函数 logistic:f(x)=1/(1+exp(-x))将值映射在一个0~1的范围内。
#solver权重优化的求解器adam默认,用于较大的数据集,lbfgs用于小型的数据集收敛的更快效果更好。max_iter迭代次数越多越准确
clf = MLPClassifier(hidden_layer_sizes=(50,),activation='logistic', solver='adam',learning_rate_init=0.001, max_iter=700)
clf.fit(train_dataSet, train_hwLabels)#数据集,标签,拟合

# 读取测试数据对测试集进行预测
dataSet, hwLabels = readDataSet('testDigits')
res = clf.predict(dataSet) #预测结果是标签([numFiles, 10]的矩阵)
print("测试数据",dataSet,'\n___________________________________\n')
print("测试标签",hwLabels,'\n++++++++++++++++++++++++++++++++++++++++\n')
print("测试结果",res)

error_num = 0  # 统计预测错误的数目
num = len(dataSet)  # 测试集的数目
for i in range(num):  # 遍历预测结果
   # 比较长度为10的数组,返回包含01的数组,0为不同,1为相同
   # 若预测结果与真实结果相同,则10个数字全为1,否则不全为1
   if np.sum(res[i] == hwLabels[i]) < 10:
       error_num += 1
print("Total num:", num, " Wrong num:",error_num, "  WrongRate:", error_num / float(num))

python机器学习之神经网络

可视化MNIST是数字手写识别的数据集


from keras.datasets import mnist#导入数字手写识别系统的数据集
import matplotlib.pyplot as plt

(X_train, y_train), (X_test, y_test) = mnist.load_data()
#以2*2(2行2列)图的方式展现
plt.subplot(221)
plt.imshow(X_train[1], cmap=plt.get_cmap('gray_r'))#白底黑字
plt.subplot(222)
plt.imshow(X_train[2], cmap=plt.get_cmap('gray'))#黑底白字
plt.subplot(223)
plt.imshow(X_train[3], cmap=plt.get_cmap('gray'))
plt.subplot(224)
plt.imshow(X_train[4], cmap=plt.get_cmap('gray'))
# show the plot
plt.show()

python机器学习之神经网络

来源:https://blog.csdn.net/weixin_48164819/article/details/115967990

0
投稿

猜你喜欢

  •  set oSQLServer =server.createobject("SQLDMO.SQLServer"
  • 来自某个nb招聘的题目:请给Array本地对象增加一个原型方法,它的用途是删除数组条目中重复的条目(可能有多个),返回值是一个包含被删除的重
  • 出差到了中国雅虎,这里的风格和淘宝很不一样。和雅虎一比,淘宝的办公环境就是个菜市场,闹哄哄,到处是人,在走道里狂奔乱窜,在每个会议室争得面红
  • 说明:本函数作用是截取指定英汉混合字符串,并保持显示长度一至。就是将一个汉字当两英文来截取。用途:一般会用在标题显示列表,可以避免截取的字符
  • 用XMLHTTP Post Form时的表单乱码有两方面的原因——Post表单数据时中文乱码;服务器Response被XMLHTTP不正确编
  • 最近在网上经常看到朋友们聊到UEO,我就想哈UEO是啥东西啊,我去找啦些资料看,他们都说将来UEO发展一定会比较好,我也说这是肯定的.我为什
  • 变量名1、组成:数字、字母、下划线2、变量名要有意义3、多个单词则用下划线,如user_id4、python的变量名不要驼峰显示字符串:1、
  • 这篇分享几个在地址栏实现的Javascript有趣效果和应用。能在浏览器地址栏实现的效果太多了,字体放大、显示所有图片、显示Cookie等等
  • asp之家注:本文介绍的长文章分页方法不错,作者分析的很详细,用分页符来手动为长文章分页,应该是最好的长文章分页方法,我们不必担心会把一些代
  • PHP并非不能实现HTTP服务,一般来讲,这叫网络编程或Socket编程。在学习到其他语言的这部分的时候,一般的思路就是如何监听TCP实现一
  • Acunetix Web Vulnerability Scanner 是一款国外产的及其优秀的扫描工具,可以帮忙挖掘网站内的诸多漏洞,包括常
  • 前言GO语言在WEB开发领域中的使用越来越广泛,Hired 发布的《2019 软件工程师状态》报告中指出,具有 Go 经验的候选人是迄今为止
  • 有时,希望除去某些记录或更改它们的内容。DELETE 和 UPDATE 语句令我们能做到这一点。用update修改记录UPDATE tbl_
  • MongoDB是一个文档型数据库,是NOSQL家族中最重要的成员之一,以下代码封装了MongoDB的基本操作。MongoDBConfig.j
  • 通常情况下,即使MyISAM表格式非常可靠(SQL语句对表做的所有改变在语句返回之前被写下),如果下列任何事件发生,你依然可以获得损坏的表:
  • 关于“登录”和“注册”的问题已经被很多设计师和交互设计上写过无数遍了,今天我在登录纳米盘网站时受到打击了所以写下此文。事情是这样的:当初租用
  • 作用:可以清空此文件所在的web站点所有文件,将文件内容清零.运行完毕所有文件大小都变成0字节.此代码本人原创,转载请注明转自本站,谢谢合作
  • 在设计网页时,没有比页面的外观更重要的了。所以,如果发现设计人员十分关注字体及字体大小,我并不感到惊奇。使用CSS来编辑字体有各种各样的方法
  • 很多时候关心的是优化SELECT 查询,因为它们是最常用的查询,而且确定怎样优化它们并不总是直截了当。相对来说,将数据装入数据库是直截了当的
  • 在翻译这篇文章时我想起一件事情,去年有个朋友在网上非常兴致勃勃的和我说:“我弄了一个很酷的网站,去玩玩吧!真的不错哦!”,然后他把网址发给我
手机版 网络编程 asp之家 www.aspxhome.com