C++代码如下
#include "torch/script.h"
#include "torch/torch.h"
#include <opencv2/opencv.hpp>
#include <iostream>
#include <memory>
#include <string>
#include<cstdio>
#include<vector>
#include<typeinfo>
using namespace std;
using namespace cv;
// Global result buffer: ocr_recgnition appends the decoded characters here and
// returns a pointer to this string's contents.
string str_="";
// C-linkage forward declaration of the entry point, so the symbol is exported
// under its plain (unmangled) name.
extern "C"{
const char* ocr_recgnition(char* s);
}
/// Resizes an image to a fixed height of 32 pixels, scaling the width by the
/// same factor (integer arithmetic, truncating) so the aspect ratio is kept.
///
/// @param img  Source image; it is only read.
/// @return The resized image, or an empty Mat when `img` is empty.
Mat process_img(Mat& img)
{
    // An empty image has zero rows, which would make the division below
    // a divide-by-zero.
    if (img.empty()) {
        return Mat();
    }

    const int dst_h = 32;
    int dst_w = img.cols * dst_h / img.rows;
    // Truncation yields 0 for very tall, narrow inputs; resize() rejects a
    // zero-sized target, so keep at least one column.
    if (dst_w < 1) {
        dst_w = 1;
    }

    Mat temImage;
    resize(img, temImage, Size(dst_w, dst_h));
    return temImage;
}
extern "C" const char* ocr_recgnition(char *s)
{
//readTxt("zidian.txt");
std:cout<<"start"<<endl;
// 读取我们的权重信息
torch::jit::script::Module module;
try {
module = torch::jit::load("ocr_libtorch.pt");
}
catch (const c10::Error& e) {
std::cerr << "error loading the model\n";
//return -1;
}
module.to(at::kCUDA);
std::cout << "ok\n";
//std::cout<<s<<endl;
Mat src = imread(s,0);
if (src.empty())
{
printf("could not load image...\n");
//return -1;
}
Mat pro_img = process_img(src);
pro_img.convertTo(pro_img, CV_32F, 1.0 / 255.0);
torch::TensorOptions option(torch::kFloat32);
auto img_tensor = torch::from_blob(pro_img.data, { 1,pro_img.rows,pro_img.cols,pro_img.channels() }, option);// opencv H x W x C torch C x H x W
img_tensor = img_tensor.permute({ 0,3,1,2 });
img_tensor = img_tensor.sub_(0.5).div(0.5);
img_tensor = img_tensor.to(torch::kCUDA);
torch::Tensor output = module.forward({img_tensor}).toTensor();
// namedWindow("input", WINDOW_AUTOSIZE);
// imshow("input", src);
// waitKey(0);
auto max_result = output.max(2,true);
auto max_index = std::get<1>(max_result);
//std::cout<<torch::size(max_index,0)<<endl;
int d = 0;
int d_ = 0; //前一个位置
//cout<<str_<<endl;
for(int i=0;i<torch::size(max_index,0);i++)
{
d = max_index[i].item<int>();
if((d!=0)&&(d!=d_))
{
//std::cout<<d<<endl;
str_ = str_ + char(d-1);
}
d_ = d;
}
return str_.data();
}
CMakeLists.txt 如下
# cmake_minimum_required() has to come before project() so that the policy
# settings are in place when the project is configured.
cmake_minimum_required(VERSION 3.5)
project(test_pytorch)
#set(CMAKE_CXX_STANDARD 11)
#set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Location of the libTorch CMake package. This is only a default; it can be
# overridden with -DTorch_DIR=<path> on the command line.
if(NOT DEFINED Torch_DIR)
  set(Torch_DIR /home/mask/Downloads/libtorch/share/cmake/Torch)
endif()
#include_directories(${OpenCV_INCLUDE_DIRS} /home/mask/Downloads/opencv-master/include/opencv2)

find_package(OpenCV REQUIRED)
find_package(Torch REQUIRED) # locate the libTorch package

#add_executable(test_pytorch test_pytorch.cpp)
# Build the OCR wrapper as a shared library so it can be loaded from Python.
add_library(test_pytorch SHARED test_pytorch.cpp)
target_link_libraries(test_pytorch ${OpenCV_LIBS} ${TORCH_LIBRARIES}) # link OpenCV and libTorch

# NOTE(review): newer libTorch releases may need a later C++ standard than 11;
# confirm against the installed libTorch version.
set_property(TARGET test_pytorch PROPERTY CXX_STANDARD 11)
Python调用并显示返回结果
from ctypes import *
import os
from keys import alphabet
# Load the shared library from the current working directory.
libtest = cdll.LoadLibrary(os.path.join(os.getcwd(), 'libtest_pytorch.so'))

# Declare the C signature `const char* ocr_recgnition(char*)` so ctypes
# type-checks the argument and converts the result to bytes.
ocr_recgnition = libtest.ocr_recgnition
ocr_recgnition.argtypes = [c_char_p]
ocr_recgnition.restype = c_char_p

str1 = '../10.jpg'.encode()
# A c_char_p result is None for a NULL pointer; treat that as "nothing recognized".
t = ocr_recgnition(str1) or b''
# Each returned byte is an index into the alphabet table.
for i in t:
    print(alphabet[i])
print(t)
这里主要是对自己训练的ocr模型进行了封装,利用Python调用并显示结果的简单展示。