tensorflow实现逻辑回归模型
作者:Missayaa 发布时间:2022-01-18 20:10:28
标签:tensorflow,逻辑回归
逻辑回归模型
逻辑回归是应用非常广泛的一个分类机器学习算法,它将数据拟合到一个logit函数(或者叫做logistic函数)中,从而能够完成对事件发生的概率进行预测。
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
#下载好的mnist数据集存在F:/mnist/data/中
mnist = input_data.read_data_sets('F:/mnist/data/',one_hot = True)
print(mnist.train.num_examples)
print(mnist.test.num_examples)
trainimg = mnist.train.images
trainlabel = mnist.train.labels
testimg = mnist.test.images
testlabel = mnist.test.labels
print(type(trainimg))
print(trainimg.shape,)
print(trainlabel.shape,)
print(testimg.shape,)
print(testlabel.shape,)
nsample = 5
randidx = np.random.randint(trainimg.shape[0],size = nsample)
for i in randidx:
curr_img = np.reshape(trainimg[i,:],(28,28))
curr_label = np.argmax(trainlabel[i,:])
plt.matshow(curr_img,cmap=plt.get_cmap('gray'))
plt.title(""+str(i)+"th Training Data"+"label is"+str(curr_label))
print(""+str(i)+"th Training Data"+"label is"+str(curr_label))
plt.show()
x = tf.placeholder("float",[None,784])
y = tf.placeholder("float",[None,10])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
#
actv = tf.nn.softmax(tf.matmul(x,W)+b)
#计算损失
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(actv),reduction_indices=1))
#学习率
learning_rate = 0.01
#随机梯度下降
optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
#求1位置索引值 对比预测值索引与label索引是否一样,一样返回True
pred = tf.equal(tf.argmax(actv,1),tf.argmax(y,1))
#tf.cast把True和false转换为float类型 0,1
#把所有预测结果加在一起求精度
accr = tf.reduce_mean(tf.cast(pred,"float"))
init = tf.global_variables_initializer()
"""
#测试代码
sess = tf.InteractiveSession()
arr = np.array([[31,23,4,24,27,34],[18,3,25,4,5,6],[4,3,2,1,5,67]])
#返回数组的维数 2
print(tf.rank(arr).eval())
#返回数组的行列数 [3 6]
print(tf.shape(arr).eval())
#返回数组中每一列中最大元素的索引[0 0 1 0 0 2]
print(tf.argmax(arr,0).eval())
#返回数组中每一行中最大元素的索引[5 2 5]
print(tf.argmax(arr,1).eval())
J"""
#把所有样本迭代50次
training_epochs = 50
#每次迭代选择多少样本
batch_size = 100
display_step = 5
sess = tf.Session()
sess.run(init)
#循环迭代
for epoch in range(training_epochs):
avg_cost = 0
num_batch = int(mnist.train.num_examples/batch_size)
for i in range(num_batch):
batch_xs,batch_ys = mnist.train.next_batch(batch_size)
sess.run(optm,feed_dict = {x:batch_xs,y:batch_ys})
feeds = {x:batch_xs,y:batch_ys}
avg_cost += sess.run(cost,feed_dict = feeds)/num_batch
if epoch % display_step ==0:
feeds_train = {x:batch_xs,y:batch_ys}
feeds_test = {x:mnist.test.images,y:mnist.test.labels}
train_acc = sess.run(accr,feed_dict = feeds_train)
test_acc = sess.run(accr,feed_dict = feeds_test)
#每五个epoch打印一次信息
print("Epoch:%03d/%03d cost:%.9f train_acc:%.3f test_acc: %.3f" %(epoch,training_epochs,avg_cost,train_acc,test_acc))
print("Done")
程序训练结果如下:
Epoch:000/050 cost:1.177228655 train_acc:0.800 test_acc: 0.855
Epoch:005/050 cost:0.440933891 train_acc:0.890 test_acc: 0.894
Epoch:010/050 cost:0.383387268 train_acc:0.930 test_acc: 0.905
Epoch:015/050 cost:0.357281335 train_acc:0.930 test_acc: 0.909
Epoch:020/050 cost:0.341473956 train_acc:0.890 test_acc: 0.913
Epoch:025/050 cost:0.330586549 train_acc:0.920 test_acc: 0.915
Epoch:030/050 cost:0.322370980 train_acc:0.870 test_acc: 0.916
Epoch:035/050 cost:0.315942993 train_acc:0.940 test_acc: 0.916
Epoch:040/050 cost:0.310728854 train_acc:0.890 test_acc: 0.917
Epoch:045/050 cost:0.306357428 train_acc:0.870 test_acc: 0.918
Done
来源:https://blog.csdn.net/Missayaaa/article/details/80063512
0
投稿
猜你喜欢
- 核心导出作业的 代码 和 作业备份是相似的 代码如下:alter PROC DumpJob (@job VARCHAR(100)
- 目的测试一个对象是否是字符串方法Python的字符串的基类是basestring,包括了str和unicode类型。一般可以采用以下方法:d
- 代码如下:---1.平均销售等待时间 ---有一张Sales表,其中有销售日期与顾客两列,现在要求使用一条SQL语句实现计算 --每个顾客的
- vi /etc/freetds/freetds.conf [global]# TDS protocol versiontds version
- read()方法读取文件size个字节大小。如果读取命中获得EOF大小字节之前,那么它只能读取可用的字节。语法以下是read()
- 前言最近在用python写一个项目,发现一个很恶心的bug,就是同由一个类生成的两个实例之间的数据竟然会相互影响,这让我非常不解。后来联想到
- 大家好,我是丁小杰。上次和大家分享了Python定时爬取微博热搜示例介绍,堪称摸鱼神器,一个热榜不够看?今天我们再来爬取一下抖音热搜榜,感兴
- slice(切片)是 go 里面非常常用的一种数据结构,它代表了一个变长的序列,序列中的每个元素都有相同的数据类型。 一个 slice 类型
- 概要在自然语言处理(NLP)领域,情感分析及分类是一项十分热门的任务。它的目标是从文本中提取出情感信息和意义,通常分为两类:正向情感和负向情
- Python int() 函数描述int() 函数用于将一个字符串或数字转换为整型。语法以下是 int() 方法的语法:class int(
- 当然这应该属于正常过滤手法,而还有一种过滤HTML标签的最终极手法,则是将一对尖括号及尖括号中的所有字符均替换不显示,该方法对于内容中必须描
- 本文实例讲述了php防止sql注入中过滤分页参数的方法。分享给大家供大家参考。具体分析如下:就网络安全而言,在网络上不要相信任何输入信息,对
- 如下所示:str='abcdef'print(str.endswith('f'))print(str.sta
- 阅读上一篇:javascript面向对象编程(三)继承是面向对象语言中的一个重要概念,现在我们来探讨一下继承。在网上搜一下javascrip
- 前言首先线程和线程池不管在哪个语言里面,理论都是通用的。对于开发来说,解决高并发问题离不开对多个线程处理。我们先从线程到线程池,从每个线程的
- 简要pyinstaller模块主要用于python代码打包成exe程序直接使用,这样在其它电脑上即使没有python环境也是可以运行的。用法
- 1. 首先 进入cmd, 输入python,看python是否安装成功说明python安装,没有问题2. 修改注册表第一步window +
- 目录前言什么是装饰器Python 函数的基本特性函数名的本质:将函数作为变量使用:进一步实现装饰器使用Python装饰器语句:总结前言在 p
- 实例如下所:import osimport xlrdimport xlwtfrom xlutils.copy import copydef
- 直接使用model2=model1会出现当更新model2时,model1的权重也会更新,这和自己的初始目的不同。经评论指出可以使用:mod