python 使用Tensorflow训练BP神经网络实现鸢尾花分类
作者:你,好 发布时间:2023-04-15 13:29:00
目录
使用软件
问题描述
搭建神经网络
训练参数
损失函数
参数优化
代码
数据集
参数
训练
测试
结语
Hello,兄弟们,开始搞深度学习了,今天出第一篇博客,小白一枚,如果发现错误请及时指正,万分感谢。
使用软件
Python 3.8,Tensorflow2.0
问题描述
鸢尾花主要分为狗尾草鸢尾(0)、杂色鸢尾(1)、弗吉尼亚鸢尾(2)。
人们发现通过计算鸢尾花的花萼长、花萼宽、花瓣长、花瓣宽可以将鸢尾花分类。
所以只要给出足够多的鸢尾花花萼、花瓣数据,以及对应种类,使用合适的神经网络训练,就可以实现鸢尾花分类。
搭建神经网络
输入数据是花萼长、花萼宽、花瓣长、花瓣宽,是n行四列的矩阵。
而输出的是每个种类的概率,是n行三列的矩阵。
我们采用BP神经网络,设X为输入数据,Y为输出数据,W为权重,B偏置。有
y=x∗w+b
因为x为n行四列的矩阵,y为n行三列的矩阵,所以w必须为四行三列的矩阵,每个神经元对应一个b,所以b为一行三列的的矩阵。
神经网络如下图。
所以,只要找到合适的w和b,就能准确判断鸢尾花的种类。
下面就开始对这两个参数进行训练。
训练参数
损失函数
损失函数表达的是预测值(y*)和真实值(y)的差距,我们采用均方误差公式作为损失函数。
损失函数值越小,说明预测值和真实值越接近,w和b就越合适。
如果人来一组一组试,那肯定是不行的。所以我们采用梯度下降算法来找到损失函数最小值。
梯度:对函数求偏导的向量。梯度下降的方向就是函数减少的方向。
其中a为学习率,即梯度下降的步长,如果a太大,就可能错过最优值,如果a太小,则就需要更多步才能找到最优值。所以选择合适的学习率很关键。
参数优化
通过反向传播来优化参数。
反向传播:从后向前,逐层求损失函数对每层神经元参数的偏导数,迭代更新所有参数。
比如
可以看到w会逐渐趋向于loss的最小值0。
以上就是我们训练的全部关键点。
代码
数据集
我们使用sklearn包提供的鸢尾花数据集。共150组数据。
打乱保证数据的随机性,取前120个为训练集,后30个为测试集。
# 导入数据,分别为输入特征和标签
x_data = datasets.load_iris().data ## 存花萼、花瓣特征数据
y_data = datasets.load_iris().target # 存对应种类
# 随机打乱数据(因为原始数据是顺序的,顺序不打乱会影响准确率)
# seed: 随机数种子,是一个整数,当设置之后,每次生成的随机数都一样(为方便教学,以保每位同学结果一致)
np.random.seed(116) # 使用相同的seed,保证输入特征和标签一一对应
np.random.shuffle(x_data)
np.random.seed(116)
np.random.shuffle(y_data)
tf.random.set_seed(116)
# 将打乱后的数据集分割为训练集和测试集,训练集为前120行,测试集为后30行
x_train = x_data[:-30]
y_train = y_data[:-30]
x_test = x_data[-30:]
y_test = y_data[-30:]
# 转换x的数据类型,否则后面矩阵相乘时会因数据类型不一致报错
x_train = tf.cast(x_train, tf.float32)
x_test = tf.cast(x_test, tf.float32)
# from_tensor_slices函数使输入特征和标签值一一对应。(把数据集分批次,每个批次batch组数据)
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
参数
# 生成神经网络的参数,4个输入特征故,输入层为4个输入节点;因为3分类,故输出层为3个神经元
# 用tf.Variable()标记参数可训练
w1 = tf.Variable(tf.random.truncated_normal([4, 3], stddev=0.1)) # 四行三列,方差为0.1
b1 = tf.Variable(tf.random.truncated_normal([3], stddev=0.1)) # 一行三列,方差为0.1
训练
a = 0.1 # 学习率为0.1
epoch = 500 # 循环500轮
# 训练部分
for epoch in range(epoch): # 数据集级别的循环,每个epoch循环一次数据集
for step, (x_train, y_train) in enumerate(train_db): # batch级别的循环 ,每个step循环一个batch
with tf.GradientTape() as tape: # with结构记录梯度信息
y = tf.matmul(x_train, w1) + b1 # 神经网络乘加运算
y = tf.nn.softmax(y) # 使输出y符合概率分布
y_ = tf.one_hot(y_train, depth=3) # 将标签值转换为独热码格式,方便计算loss
loss = tf.reduce_mean(tf.square(y_ - y)) # 采用均方误差损失函数mse = mean(sum(y-y*)^2)
# 计算loss对w, b的梯度
grads = tape.gradient(loss, [w1, b1])
# 实现梯度更新 w1 = w1 - lr * w1_grad b = b - lr * b_grad
w1.assign_sub(a * grads[0]) # 参数w1自更新
b1.assign_sub(a * grads[1]) # 参数b自更新
测试
# 测试部分
total_correct, total_number = 0, 0
for x_test, y_test in test_db:
# 前向传播求概率
y = tf.matmul(x_test, w1) + b1
y = tf.nn.softmax(y)
predict = tf.argmax(y, axis=1) # 返回y中最大值的索引,即预测的分类
# 将predict转换为y_test的数据类型
predict = tf.cast(predict, dtype=y_test.dtype)
# 若分类正确,则correct=1,否则为0,将bool型的结果转换为int型
correct = tf.cast(tf.equal(predict, y_test), dtype=tf.int32)
# 将每个batch的correct数加起来
correct = tf.reduce_sum(correct)
# 将所有batch中的correct数加起来
total_correct += int(correct)
# total_number为测试的总样本数,也就是x_test的行数,shape[0]返回变量的行数
total_number += x_test.shape[0]
# 总的准确率等于total_correct/total_number
acc = total_correct / total_number
print("测试准确率 = %.2f %%" % (acc * 100.0))
my_test = np.array([[5.9, 3.0, 5.1, 1.8]])
print("输入 5.9 3.0 5.1 1.8")
my_test = tf.convert_to_tensor(my_test)
my_test = tf.cast(my_test, tf.float32)
y = tf.matmul(my_test, w1) + b1
y = tf.nn.softmax(y)
species = {0: "狗尾鸢尾", 1: "杂色鸢尾", 2: "弗吉尼亚鸢尾"}
predict = np.array(tf.argmax(y, axis=1))[0] # 返回y中最大值的索引,即预测的分类
print("该鸢尾花为:" + species.get(predict))
结果:
结语
来源:https://blog.csdn.net/weixin_44653914/article/details/116517033


猜你喜欢
- 前言:之前的文章我们已经开启了爬虫程序的exe之旅,但是我们最终实现的程序存在一个非常大的问题,当进行网络请求的时候,程序卡死,直到数据请求
- nn.RNN(input_size, hidden_size, num_layers=1, nonlinearity=tanh, bias=
- 一:使用Python中的urllib类中的urlretrieve()函数,直接从网上下载资源到本地,具体代码:import os,stati
- 在SQL中,很多威力都来自于将几个表或查询中的信息联接起来,并将结果显示为单个逻辑记录集的能力。在这种联接中包括INNER、LEFT、RIG
- 0 前言安装:pip install pypiwin32 1 Excel的APIimport win32com.client as win3
- cv2库在opencv库内,因此需要下载opencv-python1、打开windows命令行:win+Rcmd2、更新pip版本(不一定要
- 一、判断类型的函数is_bool() //判断是否为布尔型is_float() //判断是否为浮点型
- 上篇介绍的使用python自带tkinter包,来写带界面的工具。此篇介绍使用pyqt来开发测试工具。tkinter的好处是python官方
- 目录前言1. 效果图2. 原理3. 源码3.1 Numpy实现傅里叶变换3.2 OpenCV实现傅里叶变换3.3 HPF or LPF?参考
- 好东西找起来很麻烦,好用的又不太容易找到,之前看到很多用JS写的,固定漂浮这种效果拖动时难免会产生抖动,自己对CSS还是蛮有好感的,找来找去
- 一、前言MySQL启动后,BufferPool就会被初始化,在你没有执行任何查询操作之前,BufferPool中的缓存页都是一块块空的内存,
- 在金融领域中,我们的y值和预测得到的违约概率刚好是两个分布未知的两个分布。好的信用风控模型一般从准确性、稳定性和可解释性来评估模型。一般来说
- 前言最近接到个任务是抽取mysql和Oracle的元数据,大致就是在库里把库、schema、表、字段、分区、索引、主键等信息抽取出来,然后导
- 占位符,顾名思义就是插在输出里站位的符号。占位符是绝大部分编程语言都存在的语法, 而且大部分都是相通的, 它是一种非常常用的字符串格式化的方
- 1.普通的输出:print(str)#str是任意一个字符串,数字···2.格式化输出: print('1,2,%s,%d'
- MYSQL中批量替换某个字段的部分数据,具体介绍如下所示:1.修改字段里的所有含有指定字符串的文字UPDATE 表A SET 字段B = r
- Cookie简介首先,我们对Cookie做一个简单的介绍,说明如何利用ASP来维护cookie。Cookie是存储在客户端计算机中的一个小文
- 学习关键语句:vue连接mysql数据库vue项目连接后台数据库配置vue通过node连接MySQL数据库写在前面为了快速学习nodejs制
- 网页得来,原网页广告太多,影响心情 <html> <head> <title>检查是否为URL</
- ChatGPT近期以强大的对话和信息整合能力风靡全网,可以写代码、改论文、讲故事,几乎无所不能,这让人不禁有个大胆的想法,能否用他的对话模型