ONNX模型简介
OnnxRuntime是一个跨平台的高性能推理引擎,可用于运行由各种深度学习框架(如 PyTorch、TensorFlow 等)导出的 ONNX 格式模型,这一特性极大地促进了不同框架之间的模型共享和迁移,打破了框架之间的壁垒,为深度学习技术的广泛应用和快速发展创造了有利条件。同时,OnnxRuntime可以使用多种语言对ONNX模型进行解析,方便快速部署自己训练模型。
C#
1. 创建 C# 项目并安装依赖
首先,创建一个新的 C# 控制台应用程序项目,然后通过 NuGet 包管理器安装Microsoft.ML.OnnxRuntime库。在 Visual Studio 中,可以通过以下步骤完成:
- 右键单击项目,选择 “管理 NuGet 程序包”。
- 在 “浏览” 选项卡中搜索Microsoft.ML.OnnxRuntime,然后安装该包。
2. 编写代码解析和运行 ONNX 模型
以下是一个完整的示例代码,展示了如何加载 ONNX 模型、准备输入数据、运行推理并获取输出结果:
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
namespace OnnxRuntimeExample
{
class Program
{
static void Main()
{
// 定义ONNX模型文件的路径
string modelPath = "your_model.onnx";
// 创建一个会话选项对象
var sessionOptions = new SessionOptions();
// 创建一个InferenceSession实例,加载ONNX模型
using (var session = new InferenceSession(modelPath, sessionOptions))
{
// 获取模型的输入和输出信息
var inputMeta = session.InputMetadata;
var outputMeta = session.OutputMetadata;
// 打印输入和输出信息
Console.WriteLine("Input information:");
foreach (var meta in inputMeta)
{
Console.WriteLine($"Name: {meta.Key}, Shape: {string.Join(", ", meta.Value.Dimensions)}");
}
Console.WriteLine("\nOutput information:");
foreach (var meta in outputMeta)
{
Console.WriteLine($"Name: {meta.Key}, Shape: {string.Join(", ", meta.Value.Dimensions)}");
}
// 准备输入数据
// 这里假设输入是一个形状为 [1, 3, 224, 224] 的张量,具体形状需根据模型而定
var inputTensor = new DenseTensor(new[] { 1, 3, 224, 224 });
// 可以在这里填充输入数据,例如:
// for (int i = 0; i < inputTensor.Length; i++)
// {
// inputTensor[i] = 0.5f;
// }
// 创建输入字典,将输入张量与模型的输入名称关联起来
var inputs = new List
{
NamedOnnxValue.CreateFromTensor(inputMeta.Keys.First(), inputTensor)
};
// 运行推理
using (var results = session.Run(inputs))
{
// 获取输出结果
foreach (var output in results)
{
var outputTensor = output.AsTensor();
// 打印输出张量的形状和部分值
Console.WriteLine($"\nOutput tensor shape: {string.Join(", ", outputTensor.Dimensions)}");
Console.WriteLine("First few values of the output tensor:");
for (int i = 0; i < Math.Min(10, outputTensor.Length); i++)
{
Console.Write($"{outputTensor[i]:F4} ");
}
Console.WriteLine();
}
}
}
}
}
}
- 加载模型:SessionOptions用于配置推理会话的选项,例如线程数、内存分配策略等。InferenceSession用于加载和管理 ONNX 模型。
- 获取输入和输出信息:session.InputMetadata和session.OutputMetadata分别返回模型的输入和输出元数据,包括名称和形状。
- 准备输入数据:创建一个DenseTensor对象,用于存储输入数据。根据模型的输入要求填充张量数据。
- 运行推理:创建一个NamedOnnxValue列表,将输入张量与模型的输入名称关联起来。调用session.Run方法进行推理,返回一个IDisposableReadOnlyCollection
对象,包含模型的输出结果。 - 处理输出结果:遍历输出结果,将其转换为Tensor对象,并打印张量的形状和部分值。
注意事项
- 请将"your_model.onnx"替换为实际的 ONNX 模型文件路径。
- 输入数据的形状和类型必须与模型的输入要求一致,否则会导致推理失败。
- 在使用完InferenceSession和IDisposableReadOnlyCollection
对象后,需要调用Dispose方法释放资源,这里使用using语句可以自动完成资源释放。
C++
C++ 向来以其卓越的高性能和强大的系统级编程能力在编程语言的领域中声名远扬,OnnxRuntime 的 C++ API 更是为开发者打开了一扇通往高效能计算的大门。它赋予了开发者在那些对性能有着极致追求的场景中灵活运用 ONNX 模型的能力。
就拿嵌入式系统来说,其资源往往受限,对计算效率和内存占用有着极为苛刻的要求。而 OnnxRuntime 的 C++ API 恰好能够满足这一需求,使得在有限的硬件条件下,依然能够实现高效的模型推理,从而为诸如智能家居设备、工业自动化控制等领域提供强大的支持。
#include
#include
int main() {
Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test");
Ort::SessionOptions session_options;
// 加载模型
const char* model_path = "your_model.onnx";
Ort::Session session(env, model_path, session_options);
// 获取输入信息
Ort::AllocatorWithDefaultOptions allocator;
auto input_name = session.GetInputName(0, allocator);
auto input_type_info = session.GetInputTypeInfo(0);
auto input_tensor_info = input_type_info.GetTensorTypeAndShapeInfo();
// 准备输入数据
std::vector input_dims = input_tensor_info.GetShape();
size_t input_tensor_size = Ort::GetTensorElementSize(input_tensor_info.GetElementType());
for (auto dim : input_dims) {
input_tensor_size *= dim;
}
std::vector input_data(input_tensor_size);
// 填充输入数据
// 创建输入张量
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
Ort::Value input_tensor = Ort::Value::CreateTensor(memory_info, input_data.data(), input_data.size(), input_dims.data(), input_dims.size());
// 运行推理
std::vector input_names = {input_name};
std::vector output_names = {session.GetOutputName(0, allocator)};
auto output_tensors = session.Run(Ort::RunOptions{nullptr}, input_names.data(), &input_tensor, 1, output_names.data(), output_names.size());
// 处理输出结果
float* output_data = output_tensors[0].GetTensorMutableData();
// 处理输出数据
return 0;
}
Python
OnnxRuntime 专为 Python 精心打造了丰富且易用的 API 。这些 API 犹如一座坚固的桥梁,将 Python 强大的编程能力与高效的 ONNX 模型紧密连接起来。
借助这些精心设计的 API ,开发者能够轻松自如地加载各种复杂的 ONNX 模型。无论是大规模的图像识别模型,还是精准的自然语言处理模型,都能在瞬间完成加载。
import onnxruntime as ort
import numpy as np
# 加载模型
session = ort.InferenceSession('your_model.onnx')
# 获取输入名称
input_name = session.get_inputs()[0].name
# 准备输入数据
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
# 运行推理
outputs = session.run(None, {input_name: input_data})
# 输出结果
print(outputs)
Java
Java 具有良好的跨平台性和强大的企业级开发能力,OnnxRuntime 的 Java API 使得 Java 开发者可以在 Java 应用程序中集成 ONNX 模型进行推理,如在 Android 应用、企业级服务等场景中使用。
import ai.onnxruntime.*;
import java.util.HashMap;
import java.util.Map;
public class OnnxRuntimeJavaExample {
public static void main(String[] args) throws OrtException {
// 加载模型
OrtEnvironment env = OrtEnvironment.getEnvironment();
OrtSession.SessionOptions options = new OrtSession.SessionOptions();
OrtSession session = env.createSession("your_model.onnx", options);
// 获取输入信息
OrtSession.InputInfo inputInfo = session.getInputInfo("input_name");
long[] inputShape = inputInfo.getInfo().getShape();
// 准备输入数据
float[] inputData = new float[(int) (inputShape[0] * inputShape[1] * inputShape[2] * inputShape[3])];
// 填充输入数据
// 创建输入张量
OnnxTensor inputTensor = OnnxTensor.createTensor(env, inputData, inputShape);
// 运行推理
Map inputs = new HashMap<>();
inputs.put("input_name", inputTensor);
OrtSession.Result results = session.run(inputs);
// 处理输出结果
OnnxTensor outputTensor = (OnnxTensor) results.get(0);
float[] outputData = (float[]) outputTensor.getValue();
// 处理输出数据
// 释放资源
results.close();
session.close();
env.close();
}
}
JavaScript
在 Web 开发和 Node.js 环境中,JavaScript 非常流行。OnnxRuntime 提供了 JavaScript 版本的库,允许开发者在浏览器或服务器端使用 ONNX 模型进行推理,例如实现基于浏览器的图像识别、自然语言处理等应用。
const ort = require('onnxruntime-node');
const fs = require('fs');
async function main() {
// 加载模型
const session = await ort.InferenceSession.create('your_model.onnx');
// 准备输入数据
const inputTensor = new ort.Tensor('float32', new Float32Array([1, 2, 3, 4]), [2, 2]);
// 运行推理
const feeds = { input_name: inputTensor };
const results = await session.run(feeds);
// 输出结果
console.log(results);
}
main().catch((err) => {
console.error(err);
});
总之,上述所提及的这些语言,均能够凭借 OnnxRuntime 库所提供的强大功能,极为便捷地运用 ONNX 模型来开展推理工作。在实际应用中,您完全可以依据具体的应用场景以及独特的开发需求,去审慎地挑选最为适宜的语言。例如,如果您所面对的是对计算效率和内存管理要求极高的实时应用场景,或许像 C++ 这样性能卓越的语言会是更理想的选择;而若是侧重于快速开发和原型验证,Python 则可能更能满足您的需求。你可以根据具体的应用场景、开发需求以及自己的熟悉程度选择合适的语言。