pytorch 移动端部署之helloworld的使用
作者:东方佑 发布时间:2022-03-30 18:56:18
开始
安装Androidstudio 4.1
克隆此项目
git clone https://github.com/pytorch/android-demo-app.git
使用androidstudio 打开 android-demo-app 中的HelloWordApp
打开之后androidstudio 会自动创建依赖 只需要等待即可
这个代码已经是官方写好的故而
开一下官方教程中的代码都在什么位置
这句
repositories {
jcenter()
}
dependencies {
implementation 'org.pytorch:pytorch_android:1.4.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'
}
位置
HelloWorldApp\app\build.gradle
里面的全部代码
apply plugin: 'com.android.application'
repositories {
jcenter()
}
android {
compileSdkVersion 28
buildToolsVersion "29.0.2"
defaultConfig {
applicationId "org.pytorch.helloworld"
minSdkVersion 21
targetSdkVersion 28
versionCode 1
versionName "1.0"
}
buildTypes {
release {
minifyEnabled false
}
}
}
dependencies {
implementation 'androidx.appcompat:appcompat:1.1.0'
implementation 'org.pytorch:pytorch_android:1.4.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'
}
这句
Bitmap bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
Module module = Module.load(assetFilePath(this, "model.pt"));
Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
float[] scores = outputTensor.getDataAsFloatArray();
float maxScore = -Float.MAX_VALUE;
int maxScoreIdx = -1;
for (int i = 0; i < scores.length; i++) {
if (scores[i] > maxScore) {
maxScore = scores[i];
maxScoreIdx = i;
}
}
String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
都在这里
HelloWorldApp\app\src\main\java\org\pytorch\helloworld\MainActivity.java
全部代码
package org.pytorch.helloworld;
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Bundle;
import android.util.Log;
import android.widget.ImageView;
import android.widget.TextView;
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import androidx.appcompat.app.AppCompatActivity;
public class MainActivity extends AppCompatActivity {
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
Bitmap bitmap = null;
Module module = null;
try {
// creating bitmap from packaged into app android asset 'image.jpg',
// app/src/main/assets/image.jpg
bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
// loading serialized torchscript module from packaged into app android asset model.pt,
// app/src/model/assets/model.pt
module = Module.load(assetFilePath(this, "model.pt"));
} catch (IOException e) {
Log.e("PytorchHelloWorld", "Error reading assets", e);
finish();
}
// showing image on UI
ImageView imageView = findViewById(R.id.image);
imageView.setImageBitmap(bitmap);
// preparing input tensor
final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
// running the model
final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
// getting tensor content as java array of floats
final float[] scores = outputTensor.getDataAsFloatArray();
// searching for the index with maximum score
float maxScore = -Float.MAX_VALUE;
int maxScoreIdx = -1;
for (int i = 0; i < scores.length; i++) {
if (scores[i] > maxScore) {
maxScore = scores[i];
maxScoreIdx = i;
}
}
String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
// showing className on UI
TextView textView = findViewById(R.id.text);
textView.setText(className);
}
/**
* Copies specified asset to the file in /files app directory and returns this file absolute path.
*
* @return absolute file path
*/
public static String assetFilePath(Context context, String assetName) throws IOException {
File file = new File(context.getFilesDir(), assetName);
if (file.exists() && file.length() > 0) {
return file.getAbsolutePath();
}
try (InputStream is = context.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
}
}
}
在Build 中选择Build Bundile APK 的 Build APK 就可以了
生成的apk 在
HelloWorldApp\app\build\outputs\apk\debug
中 这个是可以直接安装的
安装后是一个固定的照片 就是检测了一个固定的照片
这是一个例子如果你只是想测试自己的模型调用能不能成功这个项目改改模型和模型加载即可
这个项目模型是一个resnet18 接着我们将其替换为resnet50
模型转换代码如下
import torch
import torchvision.models as models
from PIL import Image
import numpy as np
image = Image.open("test.jpg") #图片发在了build文件夹下
image = image.resize((224, 224),Image.ANTIALIAS)
image = np.asarray(image)
image = image / 255
image = torch.Tensor(image).unsqueeze_(dim=0)
image = image.permute((0, 3, 1, 2)).float()
model = models.resnet50(pretrained=True)
model = model.eval()
resnet = torch.jit.trace(model, torch.rand(1,3,224,224))
# output=resnet(torch.ones(1,3,224,224))
output = resnet(image)
max_index = torch.max(output, 1)[1].item()
print(max_index) # ImageNet1000类的类别序
resnet.save('model.pt')
if __name__ == '__main__':
pass
将这个保存的模型 覆盖掉下面路径中的模型
(在覆盖之前最好备份一个原来的模型,这里我们选择修改原来模型的名字为model_1.pt)
HelloWorldApp\app\src\main\assets\model.pt
成功覆盖后再一次执行打包操作(在Build 中选择Build Bundile APK 的 Build APK 就可以了
生成的apk 在HelloWorldApp\app\build\outputs\apk\debug)
而后打开文件发现一个123M的apk 之前的apk是73M
安装 并且测试
完美打开也就是说一切resnet 系列的 都可以通过这个 项目进行演化出来
来源:https://blog.csdn.net/weixin_32759777/article/details/109380404
猜你喜欢
- 环境:win10+phpstorm2022+phpstudy8+lnmp1、phpinfo(); 查看是否安装xdebug,没有
- 锁定数据库的一个表 SELECT * FROM table WITH (HOLDLOCK) 注意: 锁定数据库的一个表的区别 SELECT
- 本文实例讲述了Python定义二叉树及4种遍历方法。分享给大家供大家参考,具体如下:Python & BinaryTree1. Bi
- 前言容器数据类型包括数组list,字典dict以及元组tuple等。本篇,将详细介绍ChainMap字典序列的使用。ChainMapChai
- 现实生活中,有很多场景中的事情是同时进行的,比如开车的时候,手和脚共同来驾驶汽车,再比如唱歌跳舞也是同时进行的。以上这些可以理解为多任务。那
- #!/usr/bin/env python## Copyright 2009 Facebook## Licensed under the A
- 什么是固件Fixture 翻译成中文即是固件的意思。它其实就是一些函数,会在执行测试方法/测试函数之前(或之后)加载运行它们,常见的如接口用
- http://pyhdfs.readthedocs.io/en/latest/1:安装由于是windows环境(linux其实也一样),只要
- 1. 问题描述对右图进行修改:请更换图形的风格请将 x 轴的数据改为-10 到 10请自行构造一个 y 值的函数将直方图上的数字,位置改到柱
- 如下所示,代码为:array也可直接使用上面代码。测试如下:来源:https://blog.csdn.net/u011624019/arti
- 如果我们的web应用有大量的异步请求,而这些异步请求是在web服务器认证的情况下,那当我们请求发生在服务器认证失效下,服务器自动302到登录
- 这是一个系列文章,主要分享python的使用建议和技巧,每次分享3点,希望你能有所收获。1 如何去掉list中重复元素my_list = [
- 本地环境设置在这里我们介绍设置Go编程语言环境,需要在你的计算机上的准备以下两个软件,(A)文本编辑器和(B)Go编译器。文本编辑器这将用来
- 使用Python的人都知道range()函数和list很方便,今天再用到他的时候发现了很多以前看到过但是忘记的细节。这里记录一下range(
- 这篇文章主要介绍了Python线程条件变量Condition原理解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习
- 目录简介时间分类TimestampDatetimeIndexdate_range 和 bdate_rangeorigin格式化PeriodD
- CSS选择器目前,除了官方文档之外,市面上及网络详细介绍BeautifulSoup使用的技术书籍和博客软文并不多,而在这仅有的资料中介绍CS
- 在我的职业生涯中,我写过、用过和看到过很多随意的脚本。一些人需要半自动化完成任务,于是它们诞生了。一段时间后,它们变得越来越大。它们在一生中
- 本文实例讲述了Python实现队列的方法。分享给大家供大家参考,具体如下:Python实现队列队列(FIFO),添加元素在队列尾,删除元素在
- 主要作用与拷贝文件用的。1.shutil.copyfileobj(文件1,文件2):将文件1的数据覆盖copy给文件2。import shu