如何将pytorch模型部署到安卓上的方法示例
作者:AI浩 发布时间:2023-03-15 15:12:12
标签:pytorch,模型,部署,安卓
这篇文章演示如何将训练好的pytorch模型部署到安卓设备上。我也是刚开始学安卓,代码写的简单。
环境:
pytorch版本:1.10.0
模型转化
pytorch_android支持的模型是.pt模型,我们训练出来的模型是.pth。所以需要转化才可以用。先看官网上给的转化方式:
import torch
import torchvision
from torch.utils.mobile_optimizer import optimize_for_mobile
model = torchvision.models.mobilenet_v3_small(pretrained=True)
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
optimized_traced_model = optimize_for_mobile(traced_script_module)
optimized_traced_model._save_for_lite_interpreter("app/src/main/assets/model.ptl")
这个模型在安卓对应的包:
repositories {
jcenter()
}
dependencies {
implementation 'org.pytorch:pytorch_android_lite:1.9.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'
}
注:pytorch_android_lite版本和转化模型用的版本要一致,不一致就会报各种错误。
目前用这种方法有点问题,我采用的另一种方法。
转化代码如下:
import torch
import torch.utils.data.distributed
# pytorch环境中
model_pth = 'model_31_0.96.pth' #模型的参数文件
mobile_pt ='model.pt' # 将模型保存为Android可以调用的文件
model = torch.load(model_pth)
model.eval() # 模型设为评估模式
device = torch.device('cpu')
model.to(device)
# 1张3通道224*224的图片
input_tensor = torch.rand(1, 3, 224, 224) # 设定输入数据格式
mobile = torch.jit.trace(model, input_tensor) # 模型转化
mobile.save(mobile_pt) # 保存文件
对应的包:
//pytorch
implementation 'org.pytorch:pytorch_android:1.10.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'
定义模型文件和转化后的文件路径。
load模型。这里要注意,如果保存模型
torch.save(model,'models.pth')
加载模型则是
model=torch.load('models.pth')
如果保存模型是
torch.save(model.state_dict(),"models.pth")
加载模型则是
model.load_state_dict(torch.load('models.pth'))
定义输入数据格式。
模型转化,然后再保存模型。
安卓部署
新建项目
新建安卓项目,选择Empy Activity,然后选择Next
然后,填写项目信息,选择安卓版本,我用的4.4,点击完成
导入包
导入pytorch_android的包
//pytorch
implementation 'org.pytorch:pytorch_android:1.10.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'
如果有参数报错请参照我的完整的配置,代码如下:
plugins {
id 'com.android.application'
}
android {
compileSdk 32
defaultConfig {
applicationId "com.example.myapplication"
minSdk 21
targetSdk 32
versionCode 1
versionName "1.0"
testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner"
}
buildTypes {
release {
minifyEnabled false
proguardFiles getDefaultProguardFile('proguard-android-optimize.txt'), 'proguard-rules.pro'
}
}
compileOptions {
sourceCompatibility JavaVersion.VERSION_1_8
targetCompatibility JavaVersion.VERSION_1_8
}
}
dependencies {
implementation 'androidx.appcompat:appcompat:1.3.0'
implementation 'com.google.android.material:material:1.4.0'
implementation 'androidx.constraintlayout:constraintlayout:2.0.4'
testImplementation 'junit:junit:4.13.2'
androidTestImplementation 'androidx.test.ext:junit:1.1.3'
androidTestImplementation 'androidx.test.espresso:espresso-core:3.4.0'
//pytorch
implementation 'org.pytorch:pytorch_android:1.10.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.10.0'
}
页面文件
页面的配置如下:
<?xml version="1.0" encoding="utf-8"?>
<FrameLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".MainActivity">
<ImageView
android:id="@+id/image"
android:layout_width="match_parent"
android:layout_height="match_parent"
android:scaleType="fitCenter" />
<TextView
android:id="@+id/text"
android:layout_width="match_parent"
android:layout_height="wrap_content"
android:layout_gravity="top"
android:textSize="24sp"
android:background="#80000000"
android:textColor="@android:color/holo_red_light" />
</FrameLayout>
这个页面只有两个空间,一个展示图片,一个显示文字。
模型推理
新增assets文件夹,然后将转化的模型和待测试的图片放进去。
新增ImageNetClasses类,这个类存放类别名字。
代码如下:
package com.example.myapplication;
public class ImageNetClasses {
public static String[] IMAGENET_CLASSES = new String[]{
"Black-grass",
"Charlock",
"Cleavers",
"Common Chickweed",
"Common wheat",
"Fat Hen",
"Loose Silky-bent",
"Maize",
"Scentless Mayweed",
"Shepherds Purse",
"Small-flowered Cranesbill",
"Sugar beet",
};
}
在MainActivity类中,增加模型推理的逻辑。完成代码如下:
package com.example.myapplication;
import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Bundle;
import android.util.Log;
import android.widget.ImageView;
import android.widget.TextView;
import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;
import org.pytorch.MemoryFormat;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import androidx.appcompat.app.AppCompatActivity;
public class MainActivity extends AppCompatActivity {
@Override
protected void onCreate(Bundle savedInstanceState) {
super.onCreate(savedInstanceState);
setContentView(R.layout.activity_main);
Bitmap bitmap = null;
Module module = null;
try {
// creating bitmap from packaged into app android asset 'image.jpg',
// app/src/main/assets/image.jpg
bitmap = BitmapFactory.decodeStream(getAssets().open("1.png"));
// loading serialized torchscript module from packaged into app android asset model.pt,
// app/src/model/assets/model.pt
module = Module.load(assetFilePath(this, "models.pt"));
} catch (IOException e) {
Log.e("PytorchHelloWorld", "Error reading assets", e);
finish();
}
// showing image on UI
ImageView imageView = findViewById(R.id.image);
imageView.setImageBitmap(bitmap);
// preparing input tensor
final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB, MemoryFormat.CHANNELS_LAST);
// running the model
final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
// getting tensor content as java array of floats
final float[] scores = outputTensor.getDataAsFloatArray();
// searching for the index with maximum score
float maxScore = -Float.MAX_VALUE;
int maxScoreIdx = -1;
for (int i = 0; i < scores.length; i++) {
if (scores[i] > maxScore) {
maxScore = scores[i];
maxScoreIdx = i;
}
}
System.out.println(maxScoreIdx);
String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];
// showing className on UI
TextView textView = findViewById(R.id.text);
textView.setText(className);
}
/**
* Copies specified asset to the file in /files app directory and returns this file absolute path.
*
* @return absolute file path
*/
public static String assetFilePath(Context context, String assetName) throws IOException {
File file = new File(context.getFilesDir(), assetName);
if (file.exists() && file.length() > 0) {
return file.getAbsolutePath();
}
try (InputStream is = context.getAssets().open(assetName)) {
try (OutputStream os = new FileOutputStream(file)) {
byte[] buffer = new byte[4 * 1024];
int read;
while ((read = is.read(buffer)) != -1) {
os.write(buffer, 0, read);
}
os.flush();
}
return file.getAbsolutePath();
}
}
}
然后运行。
来源:https://blog.csdn.net/hhhhhhhhhhwwwwwwwwww/article/details/122860445
0
投稿
猜你喜欢
- 用比较笨的方法来做abc ="AlkjA;lkjlkjAlkAkjAlkjAAAA" if instr(abc,&quo
- Python 多进程和数据传递的理解python不仅线程用的是系统原生线程,进程也是用的原生进程进程的用法和线程大同小异import mul
- 阅读上一篇:W3C优质网页小贴士(三)明智地选择 URI没有什么比走到你最喜欢的商店门口,却发现店门紧闭,而且没有看见店面搬迁告示这种事情还
- HTML5 是近十年来 Web 标准最巨大的飞跃。和以前的版本不同,HTML 5 并非仅仅用来表示 Web 内容,它的使命是将 W
- 垃圾评论,垃圾留言,人见人憎,用了验证码,效果也好不到哪里去,还影响用户体验。有的网站甚至不惜牺牲用户体验,而构造强悍的惨不忍睹的超级验证码
- 最近做一个的GUI,因为调用了os模块里的system方法,使用pyinstaller打包的时候选择不输出系统命令弹框,程序无法运行,要求要
- 今天在GOOGLE上查图片资料,这一幕真让我纠结啊:使用【向前】【向后】这种说法,就默认了有一个对比坐标,那就是当前显示的4张缩略图。点击【
- 秦歌这篇文章总结得很不错,俺挑刺来啦:1. 优先级的描述不严谨,有 !important 时,网页样式可以覆盖用户自定义样式。用户!impo
- 通过百度云API接口抽取得到产品评论的观点,也掠去了很多评论中无用的内容以及符号,为后续进行文本主题挖掘或者规则的提取提供基础。工具 1、百
- 搭建FTP,或者是搭建网络文件系统,这些方法都能够实现Linux的目录共享。但是FTP和网络文件系统的功能都过于强大,因此它们都有一些不够方
- 大家好,为了进行调试和错误跟踪,人们在整个代码库中广泛使用日志,今天来看看如何在代码中定义日志,并探讨日志的权限。一、日志层级在开始之前,需
- 记录一些pandas选择数据的内容,此前首先说行列名的获取和更改,以方便获取数据。此文作为学习巩固。这篇博的内容顺序大概就是: 行列名的获取
- 一、简述最近接到一个新需求,让做一个动效进度条。由于我们的产品比较大,在软件启动的时候会消耗比较长的时间,原生的进度条已经不能满足我们的需求
- 无论使用int还是varchar,对于Status的多选查询都是不易应对的。举例,常规思维下对CustomerStatus的Enum设置如下
- 前言:Redhat下安装Python2.7rhel6.4自带的是2.6, 发现有的机器是python2.4。 到python网站下载源代码,
- 限制访问可以基于某种权限,某些检查或者为login视图提供不同的位置,这些实现方式大致相同。一般的方法是直接在视图的 request.use
- 第一步:python中安装selenium库和其他所有Python库一样,selenium库需要安装pip install selenium
- 这篇文章主要介绍了python列表推导式入门学习解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友
- 新云4.0模版标签是全新改的了,加了前缀。如果你怀旧,请查看新云CMS3.1常用模板标签。下面的标签说明,后台就有,为了方便查看转到这里。{
- 使用MySQL,安全问题不能不注意。以下是MySQL提示的23个注意事项:1.如果客户端和服务器端的连接需要跨越并通过不可信任的网络,那么就