TensorFlow系列:第五讲:移动端部署模型

项目地址:https://github.com/LionJackson/imageClassification
Flutter项目地址:https://github.com/LionJackson/flutter_image

一. 模型转换

编写tflite模型工具类:

import os

import PIL
import tensorflow as tf
import keras
import numpy as np
from PIL.Image import Image
from matplotlib import pyplot as plt

from utils.dataset_loader import DatasetLoader
from utils.utils import Utils

"""
tflite模型工具类
"""


class TFLiteUtil:
    """Convert a trained classifier to TFLite, export its labels and inspect it.

    Args:
        saved_model_dir: Path prefix of the trained model. ``convert_tflite``
            reads it as a SavedModel directory; the other artifacts are derived
            by appending ``.keras``, ``.tflite`` and ``.label``.
        path_url: Dataset root directory; class sub-folders are read from
            ``<path_url>/train``.
    """

    def __init__(self, saved_model_dir, path_url):
        self.save_model_dir = saved_model_dir
        self.path_url = path_url

    def get_folder_names(self):
        """Write ``<save_model_dir>.label`` (one class name per line) and return the names.

        Only the *immediate* sub-directories of ``<path_url>/train`` are class
        folders. They are sorted alphanumerically so that line ``i`` of the label
        file matches output index ``i`` of a model trained with labels inferred
        from directory names (Keras sorts class folders alphanumerically;
        presumably DatasetLoader relies on that — confirm).

        Returns:
            list[str]: the sorted class names.

        Raises:
            FileNotFoundError: if ``<path_url>/train`` does not exist (instead of
                silently writing an empty label file).
        """
        train_dir = os.path.join(self.path_url, 'train')
        # os.walk would also descend into nested folders and yields entries in
        # filesystem order; both would break the index -> label mapping.
        with os.scandir(train_dir) as entries:
            folder_names = sorted(entry.name for entry in entries if entry.is_dir())

        # Fixed UTF-8 / LF so the file is identical on every OS: class names may be
        # non-ASCII, and a platform default (e.g. GBK + CRLF on Chinese Windows)
        # would not round-trip on the mobile side.
        with open(self.save_model_dir + '.label', 'w', encoding='utf-8', newline='\n') as file:
            file.writelines(name + '\n' for name in folder_names)
        return folder_names

    def _write_tflite(self, tflite_model):
        """Save serialized TFLite model bytes to ``<save_model_dir>.tflite``."""
        with open(self.save_model_dir + '.tflite', 'wb') as f:
            f.write(tflite_model)

    def convert_tflite(self):
        """Convert the SavedModel directory ``save_model_dir`` to ``.tflite`` (plus ``.label``)."""
        self.get_folder_names()
        converter = tf.lite.TFLiteConverter.from_saved_model(self.save_model_dir)
        self._write_tflite(converter.convert())
        print("转换成功,已保存为 tflite")

    def convert_model_tflite(self):
        """Load ``<save_model_dir>.keras`` and convert it to a float16-optimized ``.tflite``.

        Also writes the ``.label`` file. With ``Optimize.DEFAULT`` and
        ``supported_types=[tf.float16]`` the converter applies float16 weight
        quantization.
        """
        self.get_folder_names()
        model = keras.models.load_model(self.save_model_dir + ".keras")
        converter = tf.lite.TFLiteConverter.from_keras_model(model)
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.target_spec.supported_types = [tf.float16]
        self._write_tflite(converter.convert())
        print("转换成功(model),已保存为 tflite")

    def batch_evaluation(self, class_mode='categorical', image_size=(224, 224), num_images=25):
        """Run the ``.tflite`` model on one test batch and plot predictions in a grid.

        Args:
            class_mode: forwarded to DatasetLoader.
            image_size: forwarded to DatasetLoader.
            num_images: maximum number of images to show; clamped to the batch
                size. The grid is sized to fit (5x5 for the default 25).
        """
        dataset_loader = DatasetLoader(self.path_url, image_size=image_size, class_mode=class_mode)
        _train_ds, _val_ds, test_ds, class_names = dataset_loader.load_data()

        interpreter = tf.lite.Interpreter(model_path=self.save_model_dir + '.tflite')
        interpreter.allocate_tensors()
        input_detail = interpreter.get_input_details()[0]
        output_detail = interpreter.get_output_details()[0]

        plt.figure(figsize=(10, 10))
        for images, _labels in test_ds.take(1):
            # The batch may hold fewer images than requested; never index past it.
            count = min(num_images, len(images))
            cols = max(1, int(np.ceil(np.sqrt(count))))
            rows = max(1, int(np.ceil(count / cols)))
            for i in range(count):
                img = np.asarray(images[i])
                # set_tensor requires the exact input dtype and a leading batch axis.
                batch = np.expand_dims(img, axis=0).astype(input_detail['dtype'])
                interpreter.set_tensor(input_detail['index'], batch)
                interpreter.invoke()
                scores = interpreter.get_tensor(output_detail['index'])[0]
                index = int(np.argmax(scores))

                plt.subplot(rows, cols, i + 1)
                plt.imshow(img.astype("uint8"))
                plt.title(f"{class_names[index]}: {scores[index] * 100:.2f}%")
                plt.axis("off")
        plt.subplots_adjust(hspace=0.5, wspace=0.5)
        plt.show()

    def tflite_analyzer(self):
        """Print the input/output details and every tensor of ``<save_model_dir>.tflite``."""
        interpreter = tf.lite.Interpreter(model_path=self.save_model_dir + '.tflite')
        interpreter.allocate_tensors()

        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()

        print("Input Details:")
        for detail in input_details:
            print(detail)

        print("\nOutput Details:")
        for detail in output_details:
            print(detail)

        # Dump every tensor in the graph (tensors, not operators).
        tensor_details = interpreter.get_tensor_details()

        print("\nTensor Details:")
        for tensor_detail in tensor_details:
            print("Index:", tensor_detail['index'])
            print("Name:", tensor_detail['name'])
            print("Shape:", tensor_detail['shape'])
            print("Shape Signature:", tensor_detail['shape_signature'])
            print("dtype:", tensor_detail['dtype'])
            print("Quantization:", tensor_detail['quantization'])
            print("Quantization Parameters:", tensor_detail['quantization_parameters'])
            print("Sparsity Parameters:", tensor_detail['sparsity_parameters'])
            print()

调用工具类(SAVED_MODEL_DIR 为模型保存路径、PATH_URL 为数据集根目录,需在项目中自行定义):

if __name__ == '__main__':
    # Export pipeline: convert the SavedModel at SAVED_MODEL_DIR to TFLite
    # (which also writes the .label file), print the model's tensor metadata,
    # then plot predictions for one test batch.
    # NOTE(review): SAVED_MODEL_DIR and PATH_URL are not defined in this snippet;
    # presumably they come from the project's config module — confirm.
    # Alternatives: call train() first, or evaluate the Keras model with
    # ModelUtil(SAVED_MODEL_DIR, PATH_URL).batch_evaluation().
    pipeline = TFLiteUtil(SAVED_MODEL_DIR, PATH_URL)
    for step in (pipeline.convert_tflite,
                 pipeline.tflite_analyzer,
                 pipeline.batch_evaluation):
        step()

此时会生成tflite模型文件:

(原文此处为截图,展示生成的 tflite 模型文件)

二. 使用模型

创建 Flutter 项目,在 pubspec.yaml 的 dependencies 中引入以下库:

  image: ^4.0.17
  path: ^1.8.3
  path_provider: ^2.0.15
  image_picker: ^0.8.8
  tflite_flutter: ^0.10.4
  camera: ^0.10.5+2

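把生成的 fruits.tflite 模型文件和 fruits.label 标签文件拷贝到 Flutter 项目的 assets/models/ 目录下,并在 pubspec.yaml 的 flutter: assets: 中声明该目录(否则 Interpreter.fromAsset 和 rootBundle 无法加载):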

(原文此处为截图,展示模型文件在项目中的位置)
核心代码:



import 'dart:developer';
import 'dart:io';
import 'dart:isolate';

import 'package:camera/camera.dart';
import 'package:flutter/services.dart';
import 'package:image/image.dart';
import 'package:tflite_flutter/tflite_flutter.dart';

import 'isolate_inference.dart';

/// Loads the fruit-classification TFLite model and its label file from the
/// app assets, and forwards inference requests to a background isolate.
///
/// `await initHelper()` before calling any `inference*` method, and call
/// [close] when done to release the isolate and the native interpreter.
class ImageClassificationHelper {
  static const modelPath = 'assets/models/fruits.tflite';
  static const labelsPath = 'assets/models/fruits.label';

  late final Interpreter interpreter;
  late final List<String> labels;
  late final IsolateInference isolateInference;
  late Tensor inputTensor;
  late Tensor outputTensor;

  /// Creates [interpreter] from [modelPath] with a platform-specific delegate
  /// and caches its first input/output tensors.
  Future<void> _loadModel() async {
    final options = InterpreterOptions();

    // Use the XNNPACK (CPU) delegate on Android.
    if (Platform.isAndroid) {
      options.addDelegate(XNNPackDelegate());
    }

    // GPU delegate on Android: doesn't work on the emulator, so it is disabled.
    // if (Platform.isAndroid) {
    //   options.addDelegate(GpuDelegateV2());
    // }

    // Use the Metal (GPU) delegate on iOS.
    if (Platform.isIOS) {
      options.addDelegate(GpuDelegate());
    }

    interpreter = await Interpreter.fromAsset(modelPath, options: options);
    // Presumably [1, 224, 224, 3] (the Python export uses 224x224) — confirm.
    inputTensor = interpreter.getInputTensors().first;
    // Presumably [1, numClasses], one score per line of the label file — confirm.
    outputTensor = interpreter.getOutputTensors().first;

    log('Interpreter loaded successfully');
  }

  /// Reads [labelsPath]; line i is the class name for output index i.
  Future<void> _loadLabels() async {
    final labelTxt = await rootBundle.loadString(labelsPath);
    // The file ends with a newline and may have CRLF endings if it was
    // generated on Windows: strip '\r'/whitespace and drop blank lines so no
    // empty or '\r'-suffixed label appears.
    labels = labelTxt
        .split('\n')
        .map((line) => line.trim())
        .where((line) => line.isNotEmpty)
        .toList();
  }

  /// Loads labels and model, then starts the inference isolate.
  Future<void> initHelper() async {
    // Must be awaited: `labels` and `interpreter` are late final fields read by
    // the inference methods; without the await they may still be uninitialized.
    await _loadLabels();
    await _loadModel();
    isolateInference = IsolateInference();
    await isolateInference.start();
  }

  /// Sends [inferenceModel] to the isolate and waits for its single reply.
  Future<Map<String, double>> _inference(InferenceModel inferenceModel) async {
    ReceivePort responsePort = ReceivePort();
    isolateInference.sendPort
        .send(inferenceModel..responsePort = responsePort.sendPort);
    // get inference result.
    var results = await responsePort.first;
    return results;
  }

  /// Classifies a live camera frame.
  Future<Map<String, double>> inferenceCameraFrame(
      CameraImage cameraImage) async {
    var isolateModel = InferenceModel(cameraImage, null, interpreter.address,
        labels, inputTensor.shape, outputTensor.shape);
    return _inference(isolateModel);
  }

  /// Classifies a still image.
  Future<Map<String, double>> inferenceImage(Image image) async {
    var isolateModel = InferenceModel(null, image, interpreter.address, labels,
        inputTensor.shape, outputTensor.shape);
    return _inference(isolateModel);
  }

  /// Stops the isolate and frees the native interpreter.
  Future<void> close() async {
    // Stop the isolate first: it is handed the interpreter's native address,
    // so the interpreter must outlive it.
    isolateInference.close();
    interpreter.close();
  }
}

页面部分:

(原文此处为截图,展示页面运行效果)

相关推荐

  1. TensorFlow系列:MobileNetV2使用介绍

    2024-07-13 09:24:02       23 阅读
  2. _css元素显示模式

    2024-07-13 09:24:02       58 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-13 09:24:02       66 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-13 09:24:02       70 阅读
  3. 在Django里面运行非项目文件

    2024-07-13 09:24:02       57 阅读
  4. Python语言-面向对象

    2024-07-13 09:24:02       68 阅读

热门阅读

  1. STL内建仿函数

    2024-07-13 09:24:02       23 阅读
  2. 开源 Wiki 系统 InfoSphere 2024.01.1 发布

    2024-07-13 09:24:02       29 阅读
  3. macOS 的电源适配器设置

    2024-07-13 09:24:02       25 阅读
  4. PTA 7-14 畅通工程之局部最小花费问题

    2024-07-13 09:24:02       27 阅读
  5. Vue单路由的独享守卫怎么设置

    2024-07-13 09:24:02       26 阅读
  6. 代码随想录算法训练营第33天

    2024-07-13 09:24:02       25 阅读
  7. 总结:Hadoop高可用

    2024-07-13 09:24:02       26 阅读
  8. 使用Python进行音频处理:掌握音频世界的魔法

    2024-07-13 09:24:02       28 阅读
  9. ssh:(xshell)远程连接失败

    2024-07-13 09:24:02       24 阅读
  10. Hadoop 面试题(十一)

    2024-07-13 09:24:02       26 阅读
  11. 深入理解外观模式(Facade Pattern)及其实际应用

    2024-07-13 09:24:02       21 阅读