pytorch 移动端部署之helloworld的使用
作者:东方佑 发布时间:2022-03-30 18:56:18
开始
安装Androidstudio 4.1
克隆此项目
git clone https://github.com/pytorch/android-demo-app.git
使用androidstudio 打开 android-demo-app 中的HelloWordApp
打开之后androidstudio 会自动创建依赖 只需要等待即可
这个代码已经是官方写好的故而
开一下官方教程中的代码都在什么位置
这句
repositories {
jcenter()
}
dependencies {
implementation 'org.pytorch:pytorch_android:1.4.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'
}
位置
HelloWorldApp\app\build.gradle
里面的全部代码
apply plugin: 'com.android.application'
repositories {
jcenter()
}
android {
compileSdkVersion 28
buildToolsVersion "29.0.2"
defaultConfig {
applicationId "org.pytorch.helloworld"
minSdkVersion 21
targetSdkVersion 28
versionCode 1
versionName "1.0"
}
buildTypes {
release {
minifyEnabled false
}
}
}
dependencies {
implementation 'androidx.appcompat:appcompat:1.1.0'
implementation 'org.pytorch:pytorch_android:1.4.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'
}
这句
Bitmap bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
Module module = Module.load(assetFilePath(this, "model.pt"));
Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
float[] scores = outputTensor.getDataAsFloatArray();
float maxScore = -Float.MAX_VALUE;
int maxScoreIdx = -1;
for (int i = 0; i < scores.length; i++) {
if (scores[i] > maxScore) {
maxScore = scores[i];
maxScoreIdx = i;
}
}
String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
都在这里
HelloWorldApp\app\src\main\java\org\pytorch\helloworld\MainActivity.java
全部代码
package org.pytorch.helloworld;
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Bundle;
import android.util.Log;
import android.widget.ImageView;
import android.widget.TextView;
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import androidx.appcompat.app.AppCompatActivity;
public class MainActivity extends AppCompatActivity {
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
Bitmap bitmap = null;
Module module = null;
try {
// creating bitmap from packaged into app android asset 'image.jpg',
// app/src/main/assets/image.jpg
bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
// loading serialized torchscript module from packaged into app android asset model.pt,
// app/src/model/assets/model.pt
module = Module.load(assetFilePath(this, "model.pt"));
} catch (IOException e) {
Log.e("PytorchHelloWorld", "Error reading assets", e);
finish();
}
// showing image on UI
ImageView imageView = findViewById(R.id.image);
imageView.setImageBitmap(bitmap);
// preparing input tensor
final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
// running the model
final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
// getting tensor content as java array of floats
final float[] scores = outputTensor.getDataAsFloatArray();
// searching for the index with maximum score
float maxScore = -Float.MAX_VALUE;
int maxScoreIdx = -1;
for (int i = 0; i < scores.length; i++) {
if (scores[i] > maxScore) {
maxScore = scores[i];
maxScoreIdx = i;
}
}
String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
// showing className on UI
TextView textView = findViewById(R.id.text);
textView.setText(className);
}
/**
* Copies specified asset to the file in /files app directory and returns this file absolute path.
*
* @return absolute file path
*/
public static String assetFilePath(Context context, String assetName) throws IOException {
File file = new File(context.getFilesDir(), assetName);
if (file.exists() && file.length() > 0) {
return file.getAbsolutePath();
}
try (InputStream is = context.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
}
}
}
在Build 中选择Build Bundile APK 的 Build APK 就可以了
生成的apk 在
HelloWorldApp\app\build\outputs\apk\debug
中 这个是可以直接安装的
安装后是一个固定的照片 就是检测了一个固定的照片
这是一个例子如果你只是想测试自己的模型调用能不能成功这个项目改改模型和模型加载即可
这个项目模型是一个resnet18 接着我们将其替换为resnet50
模型转换代码如下
import torch
import torchvision.models as models
from PIL import Image
import numpy as np
image = Image.open("test.jpg") #图片发在了build文件夹下
image = image.resize((224, 224),Image.ANTIALIAS)
image = np.asarray(image)
image = image / 255
image = torch.Tensor(image).unsqueeze_(dim=0)
image = image.permute((0, 3, 1, 2)).float()
model = models.resnet50(pretrained=True)
model = model.eval()
resnet = torch.jit.trace(model, torch.rand(1,3,224,224))
# output=resnet(torch.ones(1,3,224,224))
output = resnet(image)
max_index = torch.max(output, 1)[1].item()
print(max_index) # ImageNet1000类的类别序
resnet.save('model.pt')
if __name__ == '__main__':
pass
将这个保存的模型 覆盖掉下面路径中的模型
(在覆盖之前最好备份一个原来的模型,这里我们选择修改原来模型的名字为model_1.pt)
HelloWorldApp\app\src\main\assets\model.pt
成功覆盖后再一次执行打包操作(在Build 中选择Build Bundile APK 的 Build APK 就可以了
生成的apk 在HelloWorldApp\app\build\outputs\apk\debug)
而后打开文件发现一个123M的apk 之前的apk是73M
安装 并且测试
完美打开也就是说一切resnet 系列的 都可以通过这个 项目进行演化出来
来源:https://blog.csdn.net/weixin_32759777/article/details/109380404


猜你喜欢
- Python操作Mysql最近在学习python,这种脚本语言毫无疑问的会跟数据库产生关联,因此这里介绍一下如何使用python操作mysq
- 1.最大值max(3,4) ##运行结果为42.最小值min(3,4) ##运行结果为33.求和sum(range
- 第一种:获取不带后缀的文件名,直接上代码:就是直接用basename()函数就可以返回路径中的文件名部分,其语法是“basename(pat
- 本文实例讲述了Python使用装饰器模拟用户登陆验证功能。分享给大家供大家参考,具体如下:# -*- coding:utf-8 -*-#!p
- pandas loc的指定条件索引(布尔索引)pandas中的loc不仅仅可以用于直接的标签的索引,也可以用于指定条件的索引。1.准备数据首
- 开发工具Python版本:3.6.4相关模块:cv2模块;以及一些Python自带的模块。环境搭建安装Python并添加到环境变量,pip安
- 本文实例讲述了Python运维自动化之nginx配置文件对比操作。分享给大家供大家参考,具体如下:文件差异对比diff.py#!/usr/b
- 写过PHP+MySQL的程序员都知道有时间差,UNIX时间戳和格式化日期是我们常打交道的两个时间表示形式,Unix时间戳存储、处理方便,但是
- http协议是无状态的。下一次去访问一个页面时并不知道上一次对这个页面做了什么。无状态的应用层面的原因是:浏览器和服务器之间的通信都遵守HT
- 本文实例为大家分享了php微信公众号开发之快递查询的具体代码,供大家参考,具体内容如下快递查询数组用法foreach查询接口是:爱快递:ht
- 一般来说,造成MySQL出现中文乱码的因素主要有下列几点:1.server本身字符集设定的问题,例如还停留在latin12.table的语系
- 闪回区爆满问题也是经常会遇到的问题,最关键的是闪回设置大小以及归档被默认存放在了闪回目录,恰巧今天又遇到了这个问题,就记录下处理步骤,仅供遇
- 本文实例讲述了微信小程序开发之animation循环动画实现的让云朵飘效果。分享给大家供大家参考,具体如下:微信小程序提供了实现动画的api
- 网页开发人员常常希望能够了解并掌握多种语言,结果是,学习一门语言的所有内容是棘手的,但是却很容易发现你并没有完全利用那些比较特殊却很有用的标
- 分组取TOP数据是T-SQL中的常用查询, 如学生信息管理系统中取出每个学科前3名的学生。这种查询在SQL Server 2005之前,写起
- 前言不管是哪门语言,千变万化不离其宗,深入理解其本质,方能应用自如。对应到js,闭包,原型,函数,对象等是需要花费大功夫思考、理解的。本文穿
- 用Python+ChatGPT批量生成论文概述做算法研究离不开阅读大量论文。从海量论文中找到需要的论文往往耗费算法团队不少的精力。ChatG
- 处理中文在进行写文件时,必须采用以下方式:tree.write(nxmlpath, "UTF-8")如果写成:tree.
- 自定义事件也可以用来创建自定义的表单输入组件,使用 v-model 来进行数据双向绑定。所以要让组件的 v-model 生效,它必须:接受一
- 一个线上问题的引发的思考 今天上班的时候,开发的同事拿过来一个.zip的压缩包文件,说是要把里面的数据