网络编程
位置:首页>> 网络编程>> Python编程>> python 使用Tensorflow训练BP神经网络实现鸢尾花分类

python 使用Tensorflow训练BP神经网络实现鸢尾花分类

作者:你,好  发布时间:2023-04-15 13:29:00 

标签:python,Tensorflow,鸢尾花分类,BP神经网络
目录
  • 使用软件

  • 问题描述

  • 搭建神经网络

  • 训练参数

    • 损失函数

  • 参数优化

    • 代码

      • 数据集

      • 参数

      • 训练

      • 测试

    • 结语

      Hello,兄弟们,开始搞深度学习了,今天出第一篇博客,小白一枚,如果发现错误请及时指正,万分感谢。

      使用软件

      Python 3.8,Tensorflow2.0

      问题描述

      鸢尾花主要分为狗尾草鸢尾(0)、杂色鸢尾(1)、弗吉尼亚鸢尾(2)。
      人们发现通过计算鸢尾花的花萼长、花萼宽、花瓣长、花瓣宽可以将鸢尾花分类。
      所以只要给出足够多的鸢尾花花萼、花瓣数据,以及对应种类,使用合适的神经网络训练,就可以实现鸢尾花分类。

      搭建神经网络

      输入数据是花萼长、花萼宽、花瓣长、花瓣宽,是n行四列的矩阵。
      而输出的是每个种类的概率,是n行三列的矩阵。
      我们采用BP神经网络,设X为输入数据,Y为输出数据,W为权重,B偏置。有

      y=x∗w+b

      因为x为n行四列的矩阵,y为n行三列的矩阵,所以w必须为四行三列的矩阵,每个神经元对应一个b,所以b为一行三列的的矩阵。
      神经网络如下图。

      python 使用Tensorflow训练BP神经网络实现鸢尾花分类

      所以,只要找到合适的w和b,就能准确判断鸢尾花的种类。
      下面就开始对这两个参数进行训练。

      训练参数

      损失函数

      损失函数表达的是预测值(y*)和真实值(y)的差距,我们采用均方误差公式作为损失函数。

      python 使用Tensorflow训练BP神经网络实现鸢尾花分类

      损失函数值越小,说明预测值和真实值越接近,w和b就越合适。
      如果人来一组一组试,那肯定是不行的。所以我们采用梯度下降算法来找到损失函数最小值。
      梯度:对函数求偏导的向量。梯度下降的方向就是函数减少的方向。

      python 使用Tensorflow训练BP神经网络实现鸢尾花分类

      其中a为学习率,即梯度下降的步长,如果a太大,就可能错过最优值,如果a太小,则就需要更多步才能找到最优值。所以选择合适的学习率很关键。

      python 使用Tensorflow训练BP神经网络实现鸢尾花分类

      参数优化

      通过反向传播来优化参数。
      反向传播:从后向前,逐层求损失函数对每层神经元参数的偏导数,迭代更新所有参数。
      比如

      python 使用Tensorflow训练BP神经网络实现鸢尾花分类

      python 使用Tensorflow训练BP神经网络实现鸢尾花分类

      可以看到w会逐渐趋向于loss的最小值0。
      以上就是我们训练的全部关键点。

      代码

      数据集

      我们使用sklearn包提供的鸢尾花数据集。共150组数据。
      打乱保证数据的随机性,取前120个为训练集,后30个为测试集。


      # 导入数据,分别为输入特征和标签
      x_data = datasets.load_iris().data ## 存花萼、花瓣特征数据
      y_data = datasets.load_iris().target # 存对应种类
      # 随机打乱数据(因为原始数据是顺序的,顺序不打乱会影响准确率)
      # seed: 随机数种子,是一个整数,当设置之后,每次生成的随机数都一样(为方便教学,以保每位同学结果一致)
      np.random.seed(116)  # 使用相同的seed,保证输入特征和标签一一对应
      np.random.shuffle(x_data)
      np.random.seed(116)
      np.random.shuffle(y_data)
      tf.random.set_seed(116)
      # 将打乱后的数据集分割为训练集和测试集,训练集为前120行,测试集为后30行
      x_train = x_data[:-30]
      y_train = y_data[:-30]
      x_test = x_data[-30:]
      y_test = y_data[-30:]
      # 转换x的数据类型,否则后面矩阵相乘时会因数据类型不一致报错
      x_train = tf.cast(x_train, tf.float32)
      x_test = tf.cast(x_test, tf.float32)
      # from_tensor_slices函数使输入特征和标签值一一对应。(把数据集分批次,每个批次batch组数据)
      train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(32)
      test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)

      参数


      # 生成神经网络的参数,4个输入特征故,输入层为4个输入节点;因为3分类,故输出层为3个神经元
      # 用tf.Variable()标记参数可训练
      w1 = tf.Variable(tf.random.truncated_normal([4, 3], stddev=0.1)) # 四行三列,方差为0.1
      b1 = tf.Variable(tf.random.truncated_normal([3], stddev=0.1)) # 一行三列,方差为0.1

      训练


      a = 0.1  # 学习率为0.1
      epoch = 500  # 循环500轮
      # 训练部分
      for epoch in range(epoch):  # 数据集级别的循环,每个epoch循环一次数据集
         for step, (x_train, y_train) in enumerate(train_db):  # batch级别的循环 ,每个step循环一个batch
             with tf.GradientTape() as tape:  # with结构记录梯度信息
                 y = tf.matmul(x_train, w1) + b1  # 神经网络乘加运算
                 y = tf.nn.softmax(y)  # 使输出y符合概率分布
                 y_ = tf.one_hot(y_train, depth=3)  # 将标签值转换为独热码格式,方便计算loss
                 loss = tf.reduce_mean(tf.square(y_ - y))  # 采用均方误差损失函数mse = mean(sum(y-y*)^2)
             # 计算loss对w, b的梯度
             grads = tape.gradient(loss, [w1, b1])
             # 实现梯度更新 w1 = w1 - lr * w1_grad    b = b - lr * b_grad
             w1.assign_sub(a * grads[0])  # 参数w1自更新
             b1.assign_sub(a * grads[1])  # 参数b自更新

      测试


      # 测试部分
      total_correct, total_number = 0, 0
      for x_test, y_test in test_db:
         # 前向传播求概率
         y = tf.matmul(x_test, w1) + b1
         y = tf.nn.softmax(y)
         predict = tf.argmax(y, axis=1)  # 返回y中最大值的索引,即预测的分类
         # 将predict转换为y_test的数据类型
         predict = tf.cast(predict, dtype=y_test.dtype)
         # 若分类正确,则correct=1,否则为0,将bool型的结果转换为int型
         correct = tf.cast(tf.equal(predict, y_test), dtype=tf.int32)
         # 将每个batch的correct数加起来
         correct = tf.reduce_sum(correct)
         # 将所有batch中的correct数加起来
         total_correct += int(correct)
         # total_number为测试的总样本数,也就是x_test的行数,shape[0]返回变量的行数
         total_number += x_test.shape[0]
      # 总的准确率等于total_correct/total_number
      acc = total_correct / total_number
      print("测试准确率 = %.2f %%" % (acc * 100.0))
      my_test = np.array([[5.9, 3.0, 5.1, 1.8]])
      print("输入 5.9  3.0  5.1  1.8")
      my_test = tf.convert_to_tensor(my_test)
      my_test = tf.cast(my_test, tf.float32)
      y = tf.matmul(my_test, w1) + b1
      y = tf.nn.softmax(y)
      species = {0: "狗尾鸢尾", 1: "杂色鸢尾", 2: "弗吉尼亚鸢尾"}
      predict = np.array(tf.argmax(y, axis=1))[0]  # 返回y中最大值的索引,即预测的分类
      print("该鸢尾花为:" + species.get(predict))

      结果:

      python 使用Tensorflow训练BP神经网络实现鸢尾花分类

      结语

      来源:https://blog.csdn.net/weixin_44653914/article/details/116517033

      0
      投稿

      猜你喜欢

      • Asp(Active Server Pages)是Web服务器端脚本编写环境,可以使用Vbscript/Jscript两种脚本来编写.作为我
      • 前言当我们需要对列表(list)、元组(tuple)、字典(dictionary)和集合(set)的元素进行遍历时,其实Python内部都是
      • 今天介绍一下 go语言的并发机制以及它所使用的CSP并发模型CSP并发模型CSP模型是上个世纪七十年代提出的,用于描述两个独立的并发实体通过
      • [数据恢复故障描述]一台重要的MYSQL数据库服务器,146GB*2,RAID1,约130GB DATA卷,存储了大约200~300个数据库
      • 需求:启动程序后,让用户输入工资,然后打印商品列表允许用户根据商品编号购买商品用户选择商品后,检测余额是否够,够就直接扣款,不够就提醒可随时
      • 开发工具python版本:3.6.4相关模块:pygame;以及一些python自带的模块。环境搭建安装python并添加到环境变量,pip
      • 在python里面,读取或写入csv文件时,首先要import csv这个库,然后利用这个库提供的方法进行对文件的读写。典型的数据集stoc
      • 1. 用户输入内容与打印输入:input()输出:print()例1,输入字符串,并原样输出a = input('请输入一些字符&#
      • 前言相信每位家长都有所体会,因为要在孩子出生后两周内起个名字(需要办理出生证明了),估计很多人都像我一样,刚开始是很慌乱的,虽然感觉汉字非常
      • 注:所有文字,除注明网站类型外,其他均针对企业站点.请随时注意留言,若修改则会在首页提示文字里标注.若牵扯到业务方面的问题,我可能不会做过多
      • 也许已经有人发现可以这样写...CSS代码部分a.info {     position:
      • MySQL的本地备份和双机相互备份脚本:首先,我们需要修改脚本进行必要的配置,然后以root用户执行。◆1. 第一执行远程备份时先用 fir
      • 第一步一般是建立一个关键字替换表 如 id keyword url 等字段第二步是文章显示时把【文章】内容和【关键字替换表】对应的关键字替换
      • 我的世界小游戏使用方法:移动前进:W,后退:S,向左:A,向右:D,环顾四周:鼠标,跳起:空格键,切换飞行模式:Tab;选择建筑材料砖:1,
      • ASP编写完整的一个IP所在地搜索类的修正文稿修正了查询方法,查询的方法和追捕的一致;只是追捕会自动更正IP。还有个函数的书写错误,也已经修
      • 著名的老掉牙的IE6.0在我这里已经有六年工龄了,前几天朋友拿到个IE8.0新的Beta版本,我的Sever2003装不上,大为扫兴。Chr
      • 做沙盒的时候遇到一个小问题——在IE9里面竟然抓不到事件的keyCode:element.addEventListener('key
      • jsp登陆验证,网页登陆验证带验证码校验,登录功能之添加验证码part_1:专门用于生成一个验证码图片的类:VerificationCode
      • 我们有时候希望回车键敲在文本框(input element)里来提交表单(form),但有时候又不希望如此。比如搜索行为,希望输入完关键词之
      • 许多网站缺乏针对性和友好的导航设计,难以找到连接到相关网页的路径,也没有提供有助于让访客/用户找到所需信息的帮助,用户体验非常糟糕。本期薯片
      手机版 网络编程 asp之家 www.aspxhome.com