Qualcomm AI Hub-示例(三)模型推理

文章介绍

Qualcomm® AI Hub提供了部署在云端边缘物理设备执行模型推理的任务,让你能够快速的评估在真实硬件上模型推理的精度和性能。本文介绍了如何使用AI Hub提供的接口在云端设备执行推理,更多详情可以参阅 Running Inference

模型推理

      出于功耗和性能的考虑,部署到边缘物理设备是量化后的模型。比如,原始模型是PyTorch实现以float32精度运行推理时,单目标硬件可能使用float16甚至int8运行计算。这些差异可能导致结果数值差异,所以在部署之前需要在真实的环境中对比评估数值的差异是否能够满足到需求。

      AI Hub提供云端真实硬件去运行推理和并且输出报告。通过报告和原始模型的实现进行比较,您可以确定优化的模型是否按预期工作。需要注意,如果原始模型为PyTorch和ONNX的,必须使用submit_compile_job()接口完成编译模型

使用TensorFlow Lite模型运行推理

此示例使用TensorFlow Lite模型 SqueezeNet10.tflite来运行推理。

import numpy as np

import qai_hub as hub

sample = np.random.random((1, 224, 224, 3)).astype(np.float32)

inference_job = hub.submit_inference_job(

    model="SqueezeNet10.tflite",

    device=hub.Device("Samsung Galaxy S23 Ultra"),

    inputs=dict(x=[sample]),

)

assert isinstance(inference_job, hub.InferenceJob)

inference_job.download_output_data()

  • 推理的输入必须是dict,其中键是特征的名称,值是张量。张量可以是numpy数组的列表,如果是单个数据点,则可以是单个numpy数组。
  • inference_job是InferenceJob的一个实例。

通过向submit_inference_job() API提供device列表,可以同时启动多个推理任务。

使用QNN模型库运行推理

此示例将TorchScript模型(mobilenet_v2.pt)编译为QNN模型库格式。然后在具有编译的目标模型的设备上运行推理。

import numpy as np

import qai_hub as hub

sample = np.random.random((1, 3, 224, 224)).astype(np.float32)

compile_job = hub.submit_compile_job(

    model="mobilenet_v2.pt",

    device=hub.Device("Samsung Galaxy S23"),

    options="--target_runtime qnn_lib_aarch64_android",

    input_specs=dict(image=(1, 3, 224, 224)),

)

assert isinstance(compile_job, hub.CompileJob)

inference_job = hub.submit_inference_job(

    model=compile_job.get_target_model(),

    device=hub.Device("Samsung Galaxy S23 Ultra"),

    inputs=dict(image=[sample]),

)

assert isinstance(inference_job, hub.InferenceJob)

使用推理任务验证设备上的模型准确性

此示例演示如何在设备上验证QNN模型库模型的数值。

重用评测示例中的模型(mobilenet_v2.pt)

from typing import Dict, List

import torch

import qai_hub as hub

device_s23 = hub.Device(name="Samsung Galaxy S23 Ultra")

compile_job = hub.submit_compile_job(

    model="mobilenet_v2.pt",

    device=device_s23,

    input_specs={"x": (1, 3, 224, 224)},

    options="--target_runtime qnn_lib_aarch64_android",

)

assert isinstance(compile_job, hub.CompileJob)

on_device_model = compile_job.get_target_model()

我们可以使用这个优化的.so模型,并在特定设备上使用输入数据进行推理。本例中使用的输入图像可以下载- input_image1.jpg

import numpy as np

from PIL import Image

# Convert the image to numpy array of shape [1, 3, 224, 224]

image = Image.open("input_image1.jpg").resize((224, 224))

img_array = np.array(image, dtype=np.float32)

# Ensure correct layout (NCHW) and re-scale

input_array = np.expand_dims(np.transpose(img_array / 255.0, (2, 0, 1)), axis=0)

# Run inference using the on-device model on the input image

inference_job = hub.submit_inference_job(

    model=on_device_model,

    device=device_s23,

    inputs=dict(x=[input_array]),

)

assert isinstance(inference_job, hub.InferenceJob)

我们可以在设备上使用这个原始输出来生成类预测,并将其与参考实现进行比较。为此,您需要imagenet类-imagenet_classes.txt

# Get the on-device output

on_device_output: Dict[str, List[np.ndarray]] = inference_job.download_output_data()  # type: ignore

# Load the torch model and perform inference

torch_model = torch.jit.load("mobilenet_v2.pt")

torch_model.eval()

# Calculate probabilities for torch model

torch_input = torch.from_numpy(input_array)

torch_output = torch_model(torch_input)

torch_probabilities = torch.nn.functional.softmax(torch_output[0], dim=0)

# Calculate probabilities for the on-device output

output_name = list(on_device_output.keys())[0]

out = on_device_output[output_name][0]

on_device_probabilities = np.exp(out) / np.sum(np.exp(out), axis=1)

# Read the class labels for imagenet

with open("imagenet_classes.txt", "r") as f:

categories = [s.strip() for s in f.readlines()]

# Print top five predictions for the on-device model

print("Top-5 On-Device predictions:")

top5_classes = np.argsort(on_device_probabilities[0], axis=0)[-5:]

for c in reversed(top5_classes):

    print(f"{c} {categories[c]:20s} {on_device_probabilities[0][c]:>6.1%}")

# Print top five prediction for torch model

print("Top-5 PyTorch predictions:")

top5_prob, top5_catid = torch.topk(torch_probabilities, 5)

for i in range(top5_prob.size(0)):

    print(

        f"{top5_catid[i]:4d} {categories[top5_catid[i]]:20s} {top5_prob[i].item():>6.1%}"

)

上面的代码生成的结果如下所示:

Top-5 On-Device predictions:

968 cup                   71.3%

504 coffee mug            16.4%

967 espresso               7.8%

809 soup bowl              1.3%

659 mixing bowl            1.2%

Top-5 PyTorch predictions:

968 cup                   71.4%

504 coffee mug            16.1%

967 espresso               8.0%

809 soup bowl              1.4%

659 mixing bowl            1.2%

设备上的结果几乎等同于参考实现。这告诉我们,该模型精度没有倒退,并使我们相信,一旦部署,它将按预期运行。

为了增强这种信心,可以考虑将其扩展到几个图像,并使用定量总结,例如测量KL散度或比较准确性(如果标签已知)。这也使您更容易在目标设备上进行验证。

使用先前上传的数据集和模型运行推理

与模型类似,AI Hub公开了一个API,允许用户上传可重复使用的数据。

import numpy as np

import qai_hub as hub

data = dict(

    x=[

        np.random.random((1, 224, 224, 3)).astype(np.float32),

        np.random.random((1, 224, 224, 3)).astype(np.float32),

    ]

)

hub_dataset = hub.upload_dataset(data)

现在,您可以使用上载的数据集运行推理作业。此示例使用SqueezeNet10.tflite

# Submit job

job = hub.submit_inference_job(

    model="SqueezeNet10.tflite",

    device=hub.Device("Samsung Galaxy S23 Ultra"),

    inputs=hub_dataset,

)

作者:高通工程师,戴忠忠(Zhongzhong Dai)

相关推荐

  1. 机器学习_XGBoost模型_用C++推理示例Demo

    2024-03-20 20:48:04       189 阅读
  2. 隐马尔可夫模型系列——(模型推断

    2024-03-20 20:48:04       50 阅读
  3. vue用法示例

    2024-03-20 20:48:04       35 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-03-20 20:48:04       94 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-03-20 20:48:04       100 阅读
  3. 在Django里面运行非项目文件

    2024-03-20 20:48:04       82 阅读
  4. Python语言-面向对象

    2024-03-20 20:48:04       91 阅读

热门阅读

  1. 动态加载CSS文件

    2024-03-20 20:48:04       46 阅读
  2. 如何从零开始拆解uni-app开发的vue项目(二)

    2024-03-20 20:48:04       40 阅读
  3. Python 中可以用来生成 SVG 图的库

    2024-03-20 20:48:04       43 阅读
  4. linux系统中的PS命令详解

    2024-03-20 20:48:04       47 阅读
  5. 主流开发语言和开发环境介绍

    2024-03-20 20:48:04       39 阅读
  6. DNS劫持怎么预防?

    2024-03-20 20:48:04       45 阅读
  7. 去除项目git的控制 端口号的关闭

    2024-03-20 20:48:04       39 阅读
  8. 对建造者模式的理解

    2024-03-20 20:48:04       35 阅读
  9. 《建造者模式(极简c++)》

    2024-03-20 20:48:04       47 阅读
  10. Doris案例篇—美团外卖数仓中的应用实践

    2024-03-20 20:48:04       43 阅读
  11. 前端面试小节

    2024-03-20 20:48:04       40 阅读