TensorFlow系列:第五讲:移动端部署模型

朽木成才 2024-07-16 13:37:02 阅读 92

项目地址:https://github.com/LionJackson/imageClassification

Flutter项目地址:https://github.com/LionJackson/flutter_image

一. 模型转换

编写tflite模型工具类:

<code>import os

import PIL

import tensorflow as tf

import keras

import numpy as np

from PIL.Image import Image

from matplotlib import pyplot as plt

from utils.dataset_loader import DatasetLoader

from utils.utils import Utils

"""

tflite模型工具类

"""

class TFLiteUtil:
    """Convert a trained image classifier to TFLite, inspect it, and spot-check it.

    All output files are derived from ``saved_model_dir`` by appending a suffix:
    ``<saved_model_dir>.tflite`` for the converted model and
    ``<saved_model_dir>.label`` for the class-label list (one label per line).
    """

    # (printed label, key in interpreter.get_tensor_details() entries)
    _TENSOR_DETAIL_FIELDS = (
        ("Index", "index"),
        ("Name", "name"),
        ("Shape", "shape"),
        ("Shape Signature", "shape_signature"),
        ("dtype", "dtype"),
        ("Quantization", "quantization"),
        ("Quantization Parameters", "quantization_parameters"),
        ("Sparsity Parameters", "sparsity_parameters"),
    )

    def __init__(self, saved_model_dir, path_url):
        """
        Args:
            saved_model_dir: path of the SavedModel directory; also the prefix
                for the ``.keras``, ``.tflite`` and ``.label`` files.
            path_url: dataset root; its ``train`` subdirectory holds one folder
                per class.
        """
        self.save_model_dir = saved_model_dir
        self.path_url = path_url

    def _tflite_path(self):
        """Return the path of the converted ``.tflite`` file."""
        return self.save_model_dir + '.tflite'

    def _save_tflite(self, tflite_model, message):
        """Write serialized TFLite bytes to the ``.tflite`` path and print ``message``."""
        with open(self._tflite_path(), 'wb') as f:
            f.write(tflite_model)
        print(message)

    def get_folder_names(self):
        """Write the class labels to ``<saved_model_dir>.label`` and return them.

        Labels are the immediate subdirectories of ``<path_url>/train``, sorted
        alphanumerically. Keras directory loaders (``image_dataset_from_directory``
        / ``flow_from_directory``) use that same order to assign class indices.
        Line ``i`` of the file therefore names model output ``i``. This assumes
        ``DatasetLoader`` uses one of those loaders; confirm.

        The file is UTF-8 with ``\\n`` line endings, so non-ASCII class names and
        the Flutter reader (``rootBundle.loadString`` + ``split('\\n')``) work
        regardless of the OS the conversion runs on.

        Returns:
            The sorted list of label names.

        Raises:
            FileNotFoundError: if ``<path_url>/train`` does not exist.
        """
        train_dir = os.path.join(self.path_url, 'train')
        # Only top-level folders are classes; nested folders must not become labels.
        with os.scandir(train_dir) as entries:
            folder_names = sorted(entry.name for entry in entries if entry.is_dir())

        with open(self.save_model_dir + '.label', 'w', encoding='utf-8', newline='\n') as file:
            file.writelines(name + '\n' for name in folder_names)

        return folder_names

    def convert_tflite(self):
        """Convert the SavedModel at ``saved_model_dir`` to TFLite (no optimizations).

        Also regenerates the ``.label`` file.
        """
        self.get_folder_names()
        converter = tf.lite.TFLiteConverter.from_saved_model(self.save_model_dir)
        self._save_tflite(converter.convert(), "转换成功,已保存为 tflite")

    def convert_model_tflite(self):
        """Load ``<saved_model_dir>.keras`` and convert it with float16 quantization.

        Also regenerates the ``.label`` file.
        """
        self.get_folder_names()
        model = keras.models.load_model(self.save_model_dir + ".keras")
        converter = tf.lite.TFLiteConverter.from_keras_model(model)
        # DEFAULT optimizations + float16 supported type = post-training float16
        # quantization (weights stored as float16; see TF Lite quantization docs).
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.target_spec.supported_types = [tf.float16]
        self._save_tflite(converter.convert(), "转换成功(model),已保存为 tflite")

    def batch_evaluation(self, class_mode='categorical', image_size=(224, 224), num_images=25):
        """Run the TFLite model on the first test batch and plot the predictions.

        Args:
            class_mode: forwarded to ``DatasetLoader``.
            image_size: forwarded to ``DatasetLoader``; should match the model input.
            num_images: how many images of the batch to show. It is clamped to the
                batch size, and the subplot grid is sized to fit (5x5 for 25).
        """
        dataset_loader = DatasetLoader(self.path_url, image_size=image_size, class_mode=class_mode)
        _, _, test_ds, class_names = dataset_loader.load_data()

        interpreter = tf.lite.Interpreter(model_path=self._tflite_path())
        interpreter.allocate_tensors()
        input_detail = interpreter.get_input_details()[0]
        output_index = interpreter.get_output_details()[0]['index']

        plt.figure(figsize=(10, 10))
        for images, _labels in test_ds.take(1):
            # Avoid IndexError when the test batch is smaller than num_images.
            count = min(num_images, len(images))
            cols = int(np.ceil(np.sqrt(count)))
            rows = int(np.ceil(count / cols)) if cols else 0

            for i in range(count):
                img = images[i]
                # The interpreter requires the exact input dtype it was built with.
                batch = np.expand_dims(img, axis=0).astype(input_detail['dtype'])
                interpreter.set_tensor(input_detail['index'], batch)
                interpreter.invoke()
                scores = interpreter.get_tensor(output_index)[0]
                index = int(np.argmax(scores))

                plt.subplot(rows, cols, i + 1)
                plt.imshow(np.array(img).astype("uint8"))
                plt.title(f"{class_names[index]}: {scores[index] * 100:.2f}%")
                plt.axis("off")
                plt.subplots_adjust(hspace=0.5, wspace=0.5)
        plt.show()

    def tflite_analyzer(self):
        """Print the input, output and per-tensor details of the ``.tflite`` model."""
        interpreter = tf.lite.Interpreter(model_path=self._tflite_path())
        interpreter.allocate_tensors()

        print("Input Details:")
        for detail in interpreter.get_input_details():
            print(detail)

        print("\nOutput Details:")
        for detail in interpreter.get_output_details():
            print(detail)

        # Every tensor in the graph (these are tensors, not operators).
        print("\nTensor Details:")
        for tensor_detail in interpreter.get_tensor_details():
            for label, key in self._TENSOR_DETAIL_FIELDS:
                print(f"{label}:", tensor_detail[key])
            print()

调用工具类(需事先定义模型保存路径 SAVED_MODEL_DIR 与数据集根目录 PATH_URL):

if __name__ == '__main__':
    # Earlier pipeline steps (train, ModelUtil) are not defined in this
    # snippet; uncomment them if they exist in your script:
    #   train()
    #   ModelUtil(SAVED_MODEL_DIR, PATH_URL).batch_evaluation()

    # Convert to TFLite, dump the model's tensor metadata, then visually
    # spot-check predictions on one test batch.
    tflite_util = TFLiteUtil(SAVED_MODEL_DIR, PATH_URL)
    for step in (
        tflite_util.convert_tflite,
        tflite_util.tflite_analyzer,
        tflite_util.batch_evaluation,
    ):
        step()

运行后会生成 tflite 模型文件(SAVED_MODEL_DIR + '.tflite')以及标签文件(SAVED_MODEL_DIR + '.label'):

在这里插入图片描述

二. 使用模型

创建 Flutter 项目,在 pubspec.yaml 的 dependencies 中引入以下库:

image: ^4.0.17

path: ^1.8.3

path_provider: ^2.0.15

image_picker: ^0.8.8

tflite_flutter: ^0.10.4

camera: ^0.10.5+2

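把生成的 fruits.tflite 和 fruits.label 拷贝到 Flutter 项目的 assets/models/ 目录下,并在 pubspec.yaml 的 flutter: assets: 中声明该目录(代码通过 Interpreter.fromAsset 和 rootBundle 读取这两个文件):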

在这里插入图片描述

核心代码:

<code>

import 'dart:developer';

import 'dart:io';

import 'dart:isolate';

import 'package:camera/camera.dart';

import 'package:flutter/services.dart';

import 'package:image/image.dart';

import 'package:tflite_flutter/tflite_flutter.dart';

import 'isolate_inference.dart';

/// Loads the fruit-classification TFLite model and its label list from the
/// app assets, and runs inference on a background isolate via
/// [IsolateInference].
///
/// Call [initHelper] (and await it) before any inference method, and [close]
/// when done.
class ImageClassificationHelper {
  static const modelPath = 'assets/models/fruits.tflite';

  static const labelsPath = 'assets/models/fruits.label';

  late final Interpreter interpreter;

  late final List<String> labels;

  late final IsolateInference isolateInference;

  late Tensor inputTensor;

  late Tensor outputTensor;

  /// Creates [interpreter] from [modelPath] with a platform-specific delegate
  /// and caches its first input/output tensors.
  Future<void> _loadModel() async {
    final options = InterpreterOptions();

    // Use XNNPACK Delegate on Android.
    if (Platform.isAndroid) {
      options.addDelegate(XNNPackDelegate());
    }

    // GPU Delegate (doesn't work on emulator):
    // if (Platform.isAndroid) {
    //   options.addDelegate(GpuDelegateV2());
    // }

    // Use Metal (GPU) Delegate on iOS.
    if (Platform.isIOS) {
      options.addDelegate(GpuDelegate());
    }

    interpreter = await Interpreter.fromAsset(modelPath, options: options);

    // Presumably [1, 224, 224, 3], matching the Python side's default
    // image_size=(224, 224); confirm against the exported model.
    inputTensor = interpreter.getInputTensors().first;

    // Presumably [1, labels.length]: one score per line of the label file.
    outputTensor = interpreter.getOutputTensors().first;

    log('Interpreter loaded successfully');
  }

  /// Reads [labelsPath] into [labels], one label per non-empty line.
  Future<void> _loadLabels() async {
    final labelTxt = await rootBundle.loadString(labelsPath);
    // trim() removes a '\r' left by CRLF files; dropping empty lines removes
    // the entry produced by the trailing newline, so labels.length equals
    // the number of classes.
    labels = labelTxt
        .split('\n')
        .map((line) => line.trim())
        .where((line) => line.isNotEmpty)
        .toList();
  }

  /// Loads labels and model, then starts the inference isolate.
  ///
  /// Both loads are awaited. Otherwise this future could complete before the
  /// late fields [labels]/[interpreter] are assigned, and load errors would
  /// go unhandled.
  Future<void> initHelper() async {
    await _loadLabels();
    await _loadModel();
    isolateInference = IsolateInference();
    await isolateInference.start();
  }

  /// Sends [inferenceModel] to the inference isolate and waits for its reply.
  Future<Map<String, double>> _inference(InferenceModel inferenceModel) async {
    ReceivePort responsePort = ReceivePort();
    isolateInference.sendPort
        .send(inferenceModel..responsePort = responsePort.sendPort);
    // Wait for the single inference result.
    var results = await responsePort.first;
    return results;
  }

  /// Classifies a live camera frame.
  Future<Map<String, double>> inferenceCameraFrame(
      CameraImage cameraImage) async {
    var isolateModel = InferenceModel(cameraImage, null, interpreter.address,
        labels, inputTensor.shape, outputTensor.shape);
    return _inference(isolateModel);
  }

  /// Classifies a decoded still image.
  Future<Map<String, double>> inferenceImage(Image image) async {
    var isolateModel = InferenceModel(null, image, interpreter.address, labels,
        inputTensor.shape, outputTensor.shape);
    return _inference(isolateModel);
  }

  /// Stops the inference isolate and releases the native interpreter.
  ///
  /// The isolate uses the interpreter through `interpreter.address`, so this
  /// presumably must not be called while an inference is still in flight.
  Future<void> close() async {
    isolateInference.close();
    interpreter.close();
  }
}

页面部分:

在这里插入图片描述



声明

本文内容仅代表作者观点,或转载于其他网站,本站不以此文作为商业用途
如有涉及侵权,请联系本站进行删除
转载本站原创文章,请注明来源及作者。