tensorflow实现KNN识别MNIST
作者:freedom098 发布时间:2021-08-10 06:41:50
标签:tensorflow,KNN,MNIST
KNN算法算是最简单的机器学习算法之一了,这个算法最大的特点是没有训练过程,是一种懒惰学习,这种结构也可以在tensorflow实现。
KNN的最核心就是距离度量方式,官方例程给出的是L1范数的例子,我这里改成了L2范数,也就是我们常说的欧几里得距离度量,另外,虽然是叫KNN,意思是选取k个最接近的元素来投票产生分类,但是这里只是用了最近的那个数据的标签作为预测值了。
__author__ = 'freedom'
import tensorflow as tf
import numpy as np
def loadMNIST():
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data',one_hot=True)
return mnist
def KNN(mnist):
train_x,train_y = mnist.train.next_batch(5000)
test_x,test_y = mnist.train.next_batch(200)
xtr = tf.placeholder(tf.float32,[None,784])
xte = tf.placeholder(tf.float32,[784])
distance = tf.sqrt(tf.reduce_sum(tf.pow(tf.add(xtr,tf.neg(xte)),2),reduction_indices=1))
pred = tf.argmin(distance,0)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
right = 0
for i in range(200):
ansIndex = sess.run(pred,{xtr:train_x,xte:test_x[i,:]})
print 'prediction is ',np.argmax(train_y[ansIndex])
print 'true value is ',np.argmax(test_y[i])
if np.argmax(test_y[i]) == np.argmax(train_y[ansIndex]):
right += 1.0
accracy = right/200.0
print accracy
if __name__ == "__main__":
mnist = loadMNIST()
KNN(mnist)
来源:http://blog.csdn.net/freedom098/article/details/52117330
0
投稿
猜你喜欢
- 上期我们介绍了函数式编程,这期内容就是关于递归的函数内容,本期还是按照老规矩,给大家进行核心整理,内容通俗易懂,搭配实际应用,以供大家理解。
- 数据分析师肯定每天都被各种各样的数据数据报表搞得焦头烂额,老板的,运营的、产品的等等。而且大部分报表都是重复性的工作,这篇文章就是帮助大家如
- 发现问题最近在打开项目的时候,发现我的默认路由没加载上linkActiveClass,网上一搜,发现很多同学也有这个问题,查了一些资料发现这
- MAC上的PyCharm中默认的python解释器是python2的,windows下的没用过不是很清楚,所以特来记录下设置python3解
- (1)、back_log:要求 MySQL 能有的连接数量。当主要MySQL线程在一个很短时间内得到非常多的连接请求,这就起作用,然后主线程
- Python not equal operator returns True if two variables are of same ty
- 1. 加载数据集这次我们搭建一个小小的多层线性网络对糖尿病的病例进行分类首先先导入需要的库文件先来看看我们的数据集观察可以发现,前八列是我们
- 其实相信每个和mysql打过交道的程序员都应该会尝试去封装一套mysql的接口,这一次的封装已经记不清是我第几次了,但是每一次我希望都能做的
- create procedure test_tran as set xact_abort on -----用@@error判断,对于严重的错
- 事情是这样的,博主初学python和机器学习,在跑一个代码的时候被提示出现以下错误:(能被提示出现这个错误,可见确实是初学了!)图1:跑代码
- 本文实例讲述了Python实现通过文件路径获取文件hash值的方法。分享给大家供大家参考,具体如下:import hashlibimport
- 本文给大家分享的是查看MySQL连接的root密码的方法,下面话不多说来来看正文:1.首先我们进到MySQL的bin目录下➜ cd /usr
- 最近遇到了一个下载静态html报表的需求,需要以提供压缩包的形式完成下载功能,实现的过程中发现相关文档非常杂,故总结一下自己的实现。开发环境
- web.config第一种方法:<?xml version="1.0" encoding="utf-8&
- 单继承时super()和__init__()实现的功能是类似的class Base(object):def __init__(self):p
- 在大三的时候,一直就想搭建属于自己的一个博客,但由于各种原因,最终都不了了之,恰好最近比较有空,于是就自己参照网上的教程,搭建了属于自己的博
- 到2019年初,Python3已经更新到了Python
- 本节我们首先来尝试识别最简单的一种验证码,图形验证码,这种验证码出现的最早,现在也很常见,一般是四位字母或者数字组成的,例如中国知网的注册页
- 这里我不想采用诸如ubuntu下的apt-get install方式进行python的安装,而是在linux下采用源码包的方式进行pytho
- 如果说亲密性原则是对元素的归类组合,是将元素之间逻辑理解上的差异在视觉上表现出来,是属于信息分类的话,那么对齐原则即是在视觉上串起这些差异化