Python1.0用C++实现残差网络图像分类

友情提示:

  • 阅读本文需要您已经掌握Pytorch的Python用法,并掌握C++语言。
  • 推荐使用Ubuntu/Mac系统实验(cmake可以自动找到已安装的opencv)。
  • 本实验需要已安装好opencv和pytorch 1.0,C++编译环境(Ubuntu需要g++,Mac需要XCode)和cmake。

Pytorch 1.0已经于近日推出,其中一个亮点功能是支持将python训练的模型导出到C++进行推理。相比于目前流行的caffe训练模型+opencv dnn模块推理,pytorch从Python训练到C++部署提供了一体化的方案,可谓攻城狮的福音。

Python导出模型

根据Pytorch官网教程,我们先导出残差网络Resnet18的模型和预训练权重:

# coding=utf-8
import torch
import torchvision
from torchvision import transforms
from PIL import Image
import json
import cv2

# 初始化模型
model = torchvision.models.resnet18(pretrained=True)
model.eval() #将模型置为推理状态

# 随机生成一个输入张量
example = torch.rand(1, 3, 224, 224)

# 利用跟踪数据流的方法生成导出模型
traced_script_module = torch.jit.trace(model, example)
output = traced_script_module(torch.ones(1, 3, 224, 224))
print output.shape
print output[0, :5]
traced_script_module.save("model.pt")

这样在当前目录下会生成一个model.pt文件,包含了模型定义和权重。

继续阅读