tensorflow 2.1.0 安装与实战教程(CASIA FACE v5)
作者:博二兔 发布时间:2022-06-12 08:41:23
1.0tensorflow的安装
1.1安装python
python下载 需要python3.x<=3.7
https://www.python.org/ftp/python/3.7.7/python-3.7.7-amd64.exe
安装时勾选Add Python 3.7 to PATH,把python添加到环境变量。
1.2安装tensorflow
打开命令行,执行
pip install tensorflow==2.1.0
pip 会安装tensorflow和一些其他的依赖
1.3安装vc++2015-2019redist…
tensorflow的另一个依赖(很多tensorflow安装失败的原因就是这个没安装)
https://support.microsoft.com/en-us/help/2977003/the-latest-supported-visual-c-downloads
1.4安装CUDA和CUDNN
cuda: https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=10&target_type=exelocal
cudnn: https://developer.nvidia.com/rdp/cudnn-download(需要注册nvidia账号)
cudnn下载后是个压缩文件,要把他解压出来放在CUDA里,如下图
高版本CUDA缺失cudart64_101.dll,下载后放在C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v10.2\bin里
https://cn.dll-files.com/cudart64_101.dll.html
2.0CASIA实战
2.1CASIA数据集
可以从网上下载casia数据集,
这里以casia数据集为例,现实中可以使用自己需要的数据集。
2.2数据集的处理
建立data和test两个文件夹,把casia复制到里面
目录是这样的./data/000/000_0.bmp
data.py处理数据,其实就是遍历,匹配,删除
import os
data = './data'
dirs = os.listdir(data)
for dir in dirs:
for file in os.listdir(data + '/' + dir):
if file.endswith("4.bmp"):
os.remove(data + '/' + dir + '/' + file)
test = './test'
tdirs = os.listdir(test)
for dir in tdirs:
for file in os.listdir(test + '/' + dir):
if file.endswith("0.bmp"):
os.remove(test + '/' + dir + '/' + file)
if file.endswith("1.bmp"):
os.remove(test + '/' + dir + '/' + file)
if file.endswith("2.bmp"):
os.remove(test + '/' + dir + '/' + file)
if file.endswith("3.bmp"):
os.remove(test + '/' + dir + '/' + file)
2.3训练代码
casia.py
import os
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
/*我直接建立了个0000,1111,...这样的数组作为标签*/
#data标签
arr = []
for i in range(100):
for j in range(4):
arr.append(i)
arr = np.array(arr)
#test标签
tarr = []
for i in range(100):
tarr.append(i)
tarr = np.array(tarr)
#训练集
pwd='./data'
dirs = os.listdir(pwd)
imgs = []
for dir in dirs:
for file in os.listdir(pwd + '/' + dir):
image = tf.io.read_file(pwd + '/' + dir + '/' + file)
img = tf.image.decode_bmp(image,channels=3)
imgs.append(img)
print("[*]训练集加载完毕")
print(imgs[0].shape)
#验证集(测试集)
tpwd='./test'
tdirs = os.listdir(tpwd)
timgs = []
for tdir in tdirs:
for tfile in os.listdir(tpwd + '/' + tdir):
timage = tf.io.read_file(tpwd + '/' + tdir + '/' + tfile)
timg = tf.image.decode_bmp(timage,channels=3)
timgs.append(timg)
print("[*]验证集加载完毕")
print(timgs[0].shape)
#神经网络模型
model = Sequential([
Conv2D(16, (3,3), padding='same', activation='relu',input_shape=(480,640,3)),
MaxPooling2D(),
Conv2D(64, (3,3), padding='same', activation='relu'),
MaxPooling2D(),
Conv2D(128, (3,3), padding='same', activation='relu'),
MaxPooling2D(),
Flatten(),
Dense(128, activation='relu'),
Dense(100, activation='softmax'),
])
model.summary()//打印神经网络模型
#优化器
model.compile(optimizer=tf.keras.optimizers.Adam(),
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
#训练
ds = tf.data.Dataset.from_tensor_slices((imgs,arr))
ds = ds.batch(16)
ds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
model.fit(ds,epochs=20)
tds = tf.data.Dataset.from_tensor_slices((timgs,tarr))
tds = ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
model.evaluate(tds, verbose=2)
#保存
tf.saved_model.save(model, "./tmp/")
2.4训练与验证
在命令行运行 python casia.py进行训练
predict.py
import os
import tensorflow as tf
import numpy as np
/*这里显卡内存不够了*/
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
/*显卡内存*/
model_path = './tmp' //加载模型
test_path = "./test/002/002_4.bmp"//这里就是个栗子
model = tf.keras.models.load_model(model_path, custom_objects=None, compile=True)
image = tf.io.read_file(test_path)
img = tf.image.decode_bmp(image,channels=3)
img = img[tf.newaxis, ...]
res = model.predict(
img, batch_size=None, verbose=0, steps=None, callbacks=None, max_queue_size=10,
workers=1, use_multiprocessing=False
)
pred = tf.argmax(res, axis=1)
print (pred[0])
print (res[0,pred[0]])
来源:https://blog.csdn.net/weixin_44753738/article/details/107032785


猜你喜欢
- 一直有耳闻MySQL5.5的性能非常NB,所以近期打算测试一下,方便的时候就把bbs.kaoyan.com升级到这个版本的数据库。今天正好看
- 本文将详细介绍一下如何搭建深度学习所需要的实验环境.这个框架分为以下六个模块显卡简单理解这个就是我们常说的GPU,显卡的功能是一个专门做矩阵
- 级联查询在ORACLE 数据库中有一种方法可以实现级联查询select * //要查询的字段from table
- pandas提供了一个灵活高效的groupby功能,它使你能以一种自然的方式对数据集进行切片、切块、摘要等操作。根据一个或多个键(可以是函数
- 静态方法:将下面的代码复制到<body>~</body>内 程序代码 <table cellpadd
- 学Python中,自我感觉学的还不错的亚子~想做点什么来练练手,然后我疯狂的找各种小游戏的教程源码什么的,于是我就疯狂的找呀找呀,就找到了一
- 前言thinkphp3.1.2 需要使用cli方法运行脚本折腾了一天才搞定3.1.2的版本真的很古老解决增加cli.php入口文件defin
- 在同一个 Apache 实例中运行多个 Django 程序是完全可能的。 当你是一个独立的 Web 开发人员并有多个不同的客户时,你可能会想
- Javascript 正常取来源网页的URL只要用: document.referrer就可以了!但,如果来源页是Jav
- 一、前言该测试功能是Linux产测软件的一个子功能,主要涉及:140行代码PySide2的Event、信号和槽、QLabel,QWidget
- python tkinter按钮Button的使用创建和设置窗口from tkinter import *#创建窗口对象root = Tk(
- 这片文章只对本地存储方法做介绍,若要查看本地存储组件使用方法的介绍请稍等。本地数据持久化(或者也叫做浏览器本地存储)是一种在浏览器中长久保存
- 先判断是jquery对象还是html对象, 如果是jquery对象, 可以直接用 jquery对象.attr("
- 后续代码更新和功能添加会提交到个人github主页,有兴趣可以一起来完善!如果只是拿过去运行看结果,请注意平台相关性以及python版本号,
- 首先是对一元函数求积分,使用Scipy下的integrate函数:from scipy import integratedef g(x):
- 本文通过将同一个数据集在三种不同的简便项窗口部件中显示。三个窗口的数据得到实时的同步,数据和视图分离。当添加或删除数据行,三个不同的视图均保
- 文章介绍内容:操作MySQL数据库:创建MySQL数据表;向表中插入记录;其他数据库操作。面试题:如何创建MySQL数据表?如何向MySQL
- 许多人也许会注意到一个现象,那就是在一些现代编程语言(当然,并不是指“最近出现”的编程语言)中,自增
- 什么是下载?首先客户端会问服务器,有没有一个xxx的文件啊?服务器开始寻找,找到后对客户端说有,然后客户端在本地新建一个文件,客户端从服务器
- 方法一使用以下流式代码,无论下载文件的大小如何,Python 内存占用都不会增加:def download_file(url):