网络编程
位置:首页>> 网络编程>> 网络编程>> pytorch 移动端部署之helloworld的使用

pytorch 移动端部署之helloworld的使用

作者:东方佑  发布时间:2022-03-30 18:56:18 

标签:pytorch,移动端部署

开始

安装Androidstudio 4.1

克隆此项目


git clone https://github.com/pytorch/android-demo-app.git

使用androidstudio 打开 android-demo-app 中的HelloWordApp

打开之后androidstudio 会自动创建依赖 只需要等待即可

这个代码已经是官方写好的故而

开一下官方教程中的代码都在什么位置

这句


repositories {
 jcenter()
}

dependencies {
 implementation 'org.pytorch:pytorch_android:1.4.0'
 implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'
}

位置

HelloWorldApp\app\build.gradle

里面的全部代码


apply plugin: 'com.android.application'
repositories {
 jcenter()
}

android {
 compileSdkVersion 28
 buildToolsVersion "29.0.2"
 defaultConfig {
   applicationId "org.pytorch.helloworld"
   minSdkVersion 21
   targetSdkVersion 28
   versionCode 1
   versionName "1.0"
 }
 buildTypes {
   release {
     minifyEnabled false
   }
 }
}

dependencies {
 implementation 'androidx.appcompat:appcompat:1.1.0'
 implementation 'org.pytorch:pytorch_android:1.4.0'
 implementation 'org.pytorch:pytorch_android_torchvision:1.4.0'
}

这句


Bitmap bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
Module module = Module.load(assetFilePath(this, "model.pt"));
Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
 TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);
 Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();
float[] scores = outputTensor.getDataAsFloatArray();
float maxScore = -Float.MAX_VALUE;
int maxScoreIdx = -1;
for (int i = 0; i < scores.length; i++) {
if (scores[i] > maxScore) {
 maxScore = scores[i];
 maxScoreIdx = i;
}
}
String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];

都在这里

HelloWorldApp\app\src\main\java\org\pytorch\helloworld\MainActivity.java

全部代码


package org.pytorch.helloworld;

import android.content.Context;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.os.Bundle;
import android.util.Log;
import android.widget.ImageView;
import android.widget.TextView;

import org.pytorch.IValue;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;

import androidx.appcompat.app.AppCompatActivity;

public class MainActivity extends AppCompatActivity {

@Override
protected void onCreate(Bundle savedInstanceState) {
 super.onCreate(savedInstanceState);
 setContentView(R.layout.activity_main);

Bitmap bitmap = null;
 Module module = null;
 try {
  // creating bitmap from packaged into app android asset 'image.jpg',
  // app/src/main/assets/image.jpg
  bitmap = BitmapFactory.decodeStream(getAssets().open("image.jpg"));
  // loading serialized torchscript module from packaged into app android asset model.pt,
  // app/src/model/assets/model.pt
  module = Module.load(assetFilePath(this, "model.pt"));
 } catch (IOException e) {
  Log.e("PytorchHelloWorld", "Error reading assets", e);
  finish();
 }

// showing image on UI
 ImageView imageView = findViewById(R.id.image);
 imageView.setImageBitmap(bitmap);

// preparing input tensor
 final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(bitmap,
   TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB);

// running the model
 final Tensor outputTensor = module.forward(IValue.from(inputTensor)).toTensor();

// getting tensor content as java array of floats
 final float[] scores = outputTensor.getDataAsFloatArray();

// searching for the index with maximum score
 float maxScore = -Float.MAX_VALUE;
 int maxScoreIdx = -1;
 for (int i = 0; i < scores.length; i++) {
  if (scores[i] > maxScore) {
   maxScore = scores[i];
   maxScoreIdx = i;
  }
 }

String className = ImageNetClasses.IMAGENET_CLASSES[maxScoreIdx];

// showing className on UI
 TextView textView = findViewById(R.id.text);
 textView.setText(className);
}

/**
 * Copies specified asset to the file in /files app directory and returns this file absolute path.
 *
 * @return absolute file path
 */
public static String assetFilePath(Context context, String assetName) throws IOException {
 File file = new File(context.getFilesDir(), assetName);
 if (file.exists() && file.length() > 0) {
  return file.getAbsolutePath();
 }

try (InputStream is = context.getAssets().open(assetName)) {
  try (OutputStream os = new FileOutputStream(file)) {
   byte[] buffer = new byte[4 * 1024];
   int read;
   while ((read = is.read(buffer)) != -1) {
    os.write(buffer, 0, read);
   }
   os.flush();
  }
  return file.getAbsolutePath();
 }
}
}

在Build 中选择Build Bundile APK 的 Build APK 就可以了

生成的apk 在

HelloWorldApp\app\build\outputs\apk\debug

中 这个是可以直接安装的

安装后是一个固定的照片 就是检测了一个固定的照片

这是一个例子如果你只是想测试自己的模型调用能不能成功这个项目改改模型和模型加载即可

这个项目模型是一个resnet18 接着我们将其替换为resnet50

模型转换代码如下


import torch
import torchvision.models as models
from PIL import Image
import numpy as np
image = Image.open("test.jpg") #图片发在了build文件夹下
image = image.resize((224, 224),Image.ANTIALIAS)
image = np.asarray(image)
image = image / 255
image = torch.Tensor(image).unsqueeze_(dim=0)
image = image.permute((0, 3, 1, 2)).float()

model = models.resnet50(pretrained=True)
model = model.eval()
resnet = torch.jit.trace(model, torch.rand(1,3,224,224))
# output=resnet(torch.ones(1,3,224,224))
output = resnet(image)
max_index = torch.max(output, 1)[1].item()
print(max_index) # ImageNet1000类的类别序
resnet.save('model.pt')
if __name__ == '__main__':
 pass

将这个保存的模型 覆盖掉下面路径中的模型
 (在覆盖之前最好备份一个原来的模型,这里我们选择修改原来模型的名字为model_1.pt)

HelloWorldApp\app\src\main\assets\model.pt

成功覆盖后再一次执行打包操作(在Build 中选择Build Bundile APK 的 Build APK 就可以了
生成的apk 在
HelloWorldApp\app\build\outputs\apk\debug)
而后打开文件发现一个123M的apk 之前的apk是73M

安装 并且测试

完美打开也就是说一切resnet 系列的 都可以通过这个 项目进行演化出来

来源:https://blog.csdn.net/weixin_32759777/article/details/109380404

0
投稿

猜你喜欢

手机版 网络编程 asp之家 www.aspxhome.com