轻松实现TensorFlow微信跳一跳的AI
作者:zhanys_7 发布时间:2021-11-24 10:35:40
作为python和机器学习的初学者,目睹了AI玩游戏的各种风骚操作,心里也是跃跃欲试。
然后发现微信跳一跳很符合需求,因为它不需要处理连续画面(截屏太慢了)和复杂的操作,很适合拿来练手。于是…这个东西诞生了,目前它一般都可以跳到100多分,发挥好了能上200。
1.需要设备:
Android手机,数据线
ADB环境
Python环境(本例使用3.6.1)
TensorFlow(本例使用1.0.0)
2.大致原理
使用adb模拟点击和截屏,使用两层卷积神经网络作为训练模型,截屏图片作为输入,按压毫秒数直接作为为输出。
3.训练过程
最开始想的用强化学习,然后发现让它自己去玩成功率太!低!了!,加上每次截屏需要大量时间,就放弃了这个方法,于是考虑用自己玩的数据作为样本喂给它,这样就需要知道每次按压的时间。
我是这样做的,找一个手机写个app监听按压屏幕时间,另一个手机玩游戏,然后两个手指同时按两个手机o(╯□╰)o
4.上代码
首先,搭建模型:
第一层卷积:5*5的卷积核,12个featuremap,此时形状为96*96*12
池化层:4*4 max pooling,此时形状为24*24*12
第二层卷积:5*5的卷积核,24个featuremap,此时形状为20*20*24
池化层:4*4 max pooling,此时形状为5*5*24
全连接层:5*5*24连接到32个节点,使用relu激活函数和0.4的dropout率
输出:32个节点连接到1个节点,此节点就代表按压的时间(单位s)
# 输入:100*100的灰度图片,前面的None是batch size,这里都为1
x = tf.placeholder(tf.float32, shape=[None, 100, 100, 1])
# 输出:一个浮点数,就是按压时间,单位s
y_ = tf.placeholder(tf.float32, shape=[None, 1])
# 第一层卷积 12个feature map
W_conv1 = weight_variable([5, 5, 1, 12], 0.1)
b_conv1 = bias_variable([12], 0.1)
# 卷积后为96*96*12
h_conv1 = tf.nn.relu(conv2d(x, W_conv1) + b_conv1)
h_pool1 = max_pool_4x4(h_conv1)
# 池化后为24*24*12
# 第二层卷积 24个feature map
W_conv2 = weight_variable([5, 5, 12, 24], 0.1)
b_conv2 = bias_variable([24], 0.1)
# 卷积后为20*20*24
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_4x4(h_conv2)
# 池化后为5*5*24
# 全连接层5*5*24 --> 32
W_fc1 = weight_variable([5 * 5 * 24, 32], 0.1)
b_fc1 = bias_variable([32], 0.1)
h_pool2_flat = tf.reshape(h_pool2, [-1, 5 * 5 * 24])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
# drapout,play时为1训练时为0.6
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
# 学习率
learn_rate = tf.placeholder(tf.float32)
# 32 --> 1
W_fc2 = weight_variable([32, 1], 0.1)
b_fc2 = bias_variable([1], 0.1)
y_fc2 = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
# 因输出直接是时间值,而不是分类概率,所以用平方损失
cross_entropy = tf.reduce_mean(tf.square(y_fc2 - y_))
train_step = tf.train.AdamOptimizer(learn_rate).minimize(cross_entropy)
其次,获取屏幕截图并转换为模型输入:
# 获取屏幕截图并转换为模型的输入
def get_screen_shot():
# 使用adb命令截图并获取图片,这里如果把后缀改成jpg会导致TensorFlow读不出来
os.system('adb shell screencap -p /sdcard/jump_temp.png')
os.system('adb pull /sdcard/jump_temp.png .')
# 使用PIL处理图片,并转为jpg
im = Image.open(r"./jump_temp.png")
w, h = im.size
# 将图片压缩,并截取中间部分,截取后为100*100
im = im.resize((108, 192), Image.ANTIALIAS)
region = (4, 50, 104, 150)
im = im.crop(region)
# 转换为jpg
bg = Image.new("RGB", im.size, (255, 255, 255))
bg.paste(im, im)
bg.save(r"./jump_temp.jpg")
img_data = tf.image.decode_jpeg(tf.gfile.FastGFile('./jump_temp.jpg', 'rb').read())
# 使用TensorFlow转为只有1通道的灰度图
img_data_gray = tf.image.rgb_to_grayscale(img_data)
x_in = np.asarray(img_data_gray.eval(), dtype='float32')
# [0,255]转为[0,1]浮点
for i in range(len(x_in)):
for j in range(len(x_in[i])):
x_in[i][j][0] /= 255
# 因为输入shape有batch维度,所以还要套一层
return [x_in]
以上代码过程大概是这样:
最后,开始训练:
while True:
…………
# 每训练100个保存一次
if train_count % 100 == 0:
saver_init.save(sess, "./save/mode.mod")
…………
sess.run(train_step, feed_dict={x: x_in, y_: y_out, keep_prob: 0.6, learn_rate: 0.00005})
训练所用数据是直接从采集好的文件中读取的,由于样本有限(目前采集了800张图和对应800个按压时间,在github上train_data文件夹里),并且学习率太大又会震荡,只能用较小学习率反复学习这些图片。
5.总结
1.样本的按压时间大都分布在300ms到900ms之间,刚开始训练的时候发现不论什么输入,输出都一直很谨慎的停留在600左右,还以为这种方法不可行。不过半个小时后再看发现已经有效果了,对于不同的输入,输出值差距开始变大了。所以…相信卷积网络的威力,多给它点耐心。
2.由于我自己最多玩到100多分,后面的数据没法采集到,所以当后面物体变得越来越小时,这个AI也会变得容易挂掉。理论上说让它自己探索不会有这个瓶颈,只是截屏时间实在难以忍受。
3.目前还是初级的版本,有很多可以优化的地方,比如说识别左上角的分数,如果某次跳跃得分较高,那么可以把这次的学习率增大;检测特殊物体,比如超市音乐盒,就停留几秒再进行下一次跳跃,等等。
下面是github地址,源码加注释总共不到300行:
https://github.com/zhanyongsheng/LetsJump
更多内容大家可以参考专题《微信跳一跳》进行学习。
来源:http://blog.csdn.net/zhanys_7/article/details/78940763
猜你喜欢
- 发一个数字拼图游戏,有点小疑问前几天写得,其中一段代码还要感谢“簡簡單單愛妳”的提示,不过我还是不太明白, ,有点笨。 $(&qu
- min()方法返回它的参数最小值:最接近负无穷大的值。语法以下是min()方法的语法:min( x, y, z, .... )参
- HTML是万维网上发布超文本的通用语言[1]。从1982年Tim Berners-Lee简化SGML建立HTML的原始定义到2001年发布X
- 本文实例讲述了python使用mailbox打印电子邮件的方法。分享给大家供大家参考。具体如下:该范例在linux下使用import mai
- 隐患一:如果客户端机器的cookie一旦因病毒而失效了,那么session也就相当于没有了。 隐患二:session在php中默认的是以文件
- 高能预警本文包含演示部分,请读者自行copy代码编译体验。参考资料:sync.WaitGroup / signal.Notify / con
- 引子Matlab中有一个函数叫做find,可以很方便地寻找数组内特定元素的下标,即:Find indices and values of n
- 外联接。外联接可以是左向外联接、右向外联接或完整外部联接。 在 FROM 子句中指定外联接时,可以由下列几组关键字中的一组指定:LEFT J
- 单线程+多任务异步协程协程在函数(特殊函数)定义的时候,使用async修饰,函数调用后,内部语句不会立即执行,而是会返回一个协程对象任务对象
- 引言 性能是一个特征。您必须预先设计性能,否则您以后就得重写应用程序。就是说,有哪些好的策略可使 Active Server Pages (
- 并行查询其优势就是可以通过多个线程来处理查询作业,从而提高查询的效率。SQL Server数据库为具有多个CPU的数据库服务器提供并行查询的
- 本来想控制鼠标自动移动防止公司电脑自动休眠的策略,然而,实现了并没什么卵用,还是会休眠。但还是分享出来吧。win10的系统。首先要安装几个第
- SQLServer中建立与服务器的连接时出错的解决方案如下:步骤1:在SQLServer 实例上启用远程连接1.指向“开始->程序-&
- Float(浮动)概念也许是CSS中最让人迷惑的一个概念吧。Float经常被错误理解,而且因为将上下文元素全部浮动导致的可读性、
- 最近尝试把项目迁移到Python环境下,特别新装了一台干净的Debian系统,准备重新配置环境,上网找了一些运行Python Web的环境方
- 人的大脑通过双眼来辨别视觉图形获取信息。大脑根据储存的经验,将所看到的视觉图形建立起优先级。由此可见,一个良好的视觉设计可以帮助大脑迅速有效
- 前言: 上一篇讲了Python排序问题中比较经典的三个方法,(链接:关于Python排
- 问题一:安装模块时出现报错 Microsoft Visual C++ 14.0 is required,也下载安装了运行库依然还是
- 本文实例讲述了Python利用matplotlib绘制约数个数统计图。分享给大家供大家参考,具体如下:利用Python计算1000以内自然数
- CSS中最常用的布局类属性,一个是Float(CSS浮动属性Float详解),另一个就是CSS定位属性Position。1. positio