版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/yangjf91/article/details/84097913
c_predict_api
_CreateExecutor 为模型预测句柄延迟绑定(bind)执行引擎,若句柄已绑定则直接返回:
// Lazily bind an executor onto a predictor handle.
// No-op when the handle already owns an executor; otherwise the symbol,
// context and argument/aux arrays stored on the handle are bound into a
// fresh Executor and its output arrays are cached on the handle.
inline void _CreateExecutor(PredictorHandle pred_hnd)
{
  MXAPIPredictor *predictor = static_cast<MXAPIPredictor*>(pred_hnd);
  if (predictor->exec != nullptr) return;  // already bound

  // Inference only: the gradient store is sized to match the arguments,
  // but every gradient request is kNullOp (no gradients computed).
  std::map<std::string, Context> ctx_map;
  std::vector<NDArray> grad_store(predictor->arg_arrays.size());
  std::vector<OpReqType> grad_req(predictor->arg_arrays.size(), kNullOp);

  // reset() releases any previous executor and takes ownership of the
  // newly bound one.
  predictor->exec.reset(Executor::Bind(predictor->sym,
                                       predictor->ctx,
                                       ctx_map,
                                       predictor->arg_arrays,
                                       grad_store,
                                       grad_req,
                                       predictor->aux_arrays));
  predictor->out_arrays = predictor->exec->outputs();
}
static_cast用法说明
static_cast是C++强制类型转换操作符,将void*的pred_hnd转换为MXAPIPredictor指针类型。
_CreatePartialOut创建预测句柄:
/*!
 * \brief Create one or more predictor handles from a symbol JSON string and a
 *        serialized parameter blob.
 * \param symbol_json_str network structure as a JSON string
 * \param param_bytes in-memory image of the serialized parameter file
 * \param param_size byte size of param_bytes
 * \param dev_type device type, dev_id device id (passed to Context::Create)
 * \param num_input_nodes, input_keys names of the input nodes
 * \param input_shape_indptr, input_shape_data CSR-style packed input shapes:
 *        shape i spans input_shape_data[indptr[i] .. indptr[i+1])
 * \param num_output_nodes, output_keys optional internal nodes to expose as
 *        outputs; 0 means use the symbol's own outputs
 * \param num_threads number of independent handles written to out
 * \param lazy when true, skip binding here; _CreateExecutor binds on first use
 * \param out receives num_threads predictor handles (caller owns them)
 * \return 0 on success, nonzero on error (API_BEGIN/API_END protocol)
 */
int _CreatePartialOut(const char* symbol_json_str,
                      const void* param_bytes,
                      int param_size,
                      int dev_type, int dev_id,
                      mx_uint num_input_nodes,
                      const char** input_keys,
                      const mx_uint* input_shape_indptr,
                      const mx_uint* input_shape_data,
                      mx_uint num_output_nodes,
                      const char** output_keys,
                      // This is used for parallel inference.
                      int num_threads,
                      bool lazy,
                      PredictorHandle* out)
{
  using nnvm::Symbol;
  API_BEGIN();  // begin guarded section: exceptions become status codes
  Symbol sym;
  {
    // Listing all op names forces registration of the legacy mxnet
    // operators with nnvm before the JSON graph is parsed.
    mx_uint outSize;
    const char **outArray;
    MXListAllOpNames(&outSize, &outArray);
  }
  {
    // Load the network structure defined in the JSON string.
    nnvm::Graph g;
    g.attrs["json"] = std::make_shared<nnvm::any>(std::string(symbol_json_str));
    sym.outputs = nnvm::ApplyPass(g, "LoadLegacyJSON").outputs;
  }
  // If the caller requested specific output nodes, regroup the symbol
  // around them; each requested name must match an internal output named
  // "<key>_output", otherwise CHECK_NE fires after scanning all candidates.
  if (num_output_nodes != 0) {
    Symbol internal = sym.GetInternals();
    std::vector<std::string> all_out = internal.ListOutputNames();
    std::vector<Symbol> out_syms(num_output_nodes);
    for (mx_uint i = 0; i < num_output_nodes; ++i) {
      std::string out_key(output_keys[i]);
      out_key += "_output";
      for (size_t j = 0; j < all_out.size(); ++j) {
        if (all_out[j] == out_key) {
          out_syms[i] = internal[j];
          break;
        }
        CHECK_NE(j, all_out.size() - 1) << "didn't find node name: " << out_key;
      }
    }
    sym = nnvm::Symbol::CreateGroup(out_syms);
  }
  // Deserialize the parameter blob and keep only entries whose "arg:"/"aux:"
  // prefixed names match input names declared by the symbol.
  std::unordered_map<std::string, NDArray> arg_params, aux_params;
  {
    std::unordered_set<std::string> arg_names, aux_names;
    std::vector<std::string> arg_names_vec = sym.ListInputNames(Symbol::kReadOnlyArgs);
    std::vector<std::string> aux_names_vec = sym.ListInputNames(Symbol::kAuxiliaryStates);
    for (size_t i = 0; i < arg_names_vec.size(); ++i) {
      arg_names.insert(arg_names_vec[i]);
    }
    for (size_t i = 0; i < aux_names_vec.size(); ++i) {
      aux_names.insert(aux_names_vec[i]);
    }
    std::vector<NDArray> data;
    std::vector<std::string> names;
    dmlc::MemoryFixedSizeStream fi((void*)param_bytes, param_size); // NOLINT(*)
    NDArray::Load(&fi, &data, &names);
    CHECK_EQ(names.size(), data.size())
        << "Invalid param file format";
    for (size_t i = 0; i < names.size(); ++i) {
      if (!strncmp(names[i].c_str(), "aux:", 4)) {
        std::string name(names[i].c_str() + 4);
        if (aux_names.count(name) != 0) {
          aux_params[name] = data[i];
        }
      }
      if (!strncmp(names[i].c_str(), "arg:", 4)) {
        std::string name(names[i].c_str() + 4);
        if (arg_names.count(name) != 0) {
          arg_params[name] = data[i];
        }
      }
    }
  }
  // Record the user-provided shape of each input node (unpack CSR layout).
  std::unordered_map<std::string, TShape> known_shape;
  for (mx_uint i = 0; i < num_input_nodes; ++i) {
    known_shape[std::string(input_keys[i])] =
        TShape(input_shape_data + input_shape_indptr[i],
               input_shape_data + input_shape_indptr[i + 1]);
  }
  std::vector<std::string> arg_names = sym.ListInputNames(Symbol::kReadOnlyArgs);
  std::vector<std::string> aux_names = sym.ListInputNames(Symbol::kAuxiliaryStates);
  std::vector<TShape> out_shapes(sym.ListOutputNames().size());
  std::vector<TShape> aux_shapes(aux_names.size());
  std::vector<TShape> arg_shapes;
  // Map each argument name to its index for later MXPredSetInput lookups.
  std::unordered_map<std::string, size_t> key2arg;
  for (size_t i = 0; i < arg_names.size(); ++i) {
    std::string key = arg_names[i];
    key2arg[key] = i;
  }
  try {
    // Run shape inference: inputs with known shapes seed the inference,
    // the rest start empty; all shapes must be resolved.
    std::vector<TShape> in_shapes;
    for (std::string key : sym.ListInputNames(Symbol::kAll)) {
      if (known_shape.count(key) != 0) {
        in_shapes.push_back(known_shape[key]);
      } else {
        in_shapes.emplace_back();
      }
    }
    nnvm::Graph g; g.outputs = sym.outputs;
    g = mxnet::exec::InferShape(std::move(g), std::move(in_shapes), "__shape__");
    bool infer_complete = (g.GetAttr<size_t>("shape_num_unknown_nodes") == 0);
    CHECK(infer_complete) << "The shape information of is not enough to get the shapes";
    CopyAttr(g.indexed_graph(),
             g.GetAttr<nnvm::ShapeVector>("shape"),
             &arg_shapes, &out_shapes, &aux_shapes);
  } catch (const mxnet::op::InferShapeError &err) {
    throw dmlc::Error(err.msg);
  }
  Context ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id);  // target device context
  // Allocate arg/aux NDArrays on the target device and copy in any loaded
  // parameter values (unmatched entries stay uninitialized, e.g. the data input).
  std::vector<NDArray> arg_arrays, aux_arrays;
  for (size_t i = 0; i < arg_shapes.size(); ++i) {
    NDArray nd = NDArray(arg_shapes[i], ctx);
    if (arg_params.count(arg_names[i]) != 0) {
      CopyFromTo(arg_params[arg_names[i]], &nd);
    }
    arg_arrays.push_back(nd);
  }
  for (size_t i = 0; i < aux_shapes.size(); ++i) {
    NDArray nd = NDArray(aux_shapes[i], ctx);
    if (aux_params.count(aux_names[i]) != 0) {
      CopyFromTo(aux_params[aux_names[i]], &nd);
    }
    aux_arrays.push_back(nd);
  }
  // Build one predictor per requested thread; all share the same symbol,
  // context and parameter arrays.
  for (int i = 0; i < num_threads; i++) {
    std::unique_ptr<MXAPIPredictor> ret(new MXAPIPredictor());
    ret->sym = sym;
    ret->ctx = ctx;
    ret->key2arg = key2arg;
    ret->arg_arrays = arg_arrays;
    ret->aux_arrays = aux_arrays;
    ret->out_shapes = out_shapes;
    if (!lazy) {
      // Eager path: bind the executor now (mirrors _CreateExecutor).
      std::map<std::string, Context> ctx_map;
      std::vector<NDArray> grad_store(arg_arrays.size());
      std::vector<OpReqType> grad_req(arg_arrays.size(), kNullOp);
      ret->exec.reset(Executor::Bind(sym, ctx, ctx_map,
                                     arg_arrays,
                                     grad_store, grad_req,
                                     aux_arrays));  // reset releases any prior executor
      ret->out_arrays = ret->exec->outputs();
    }
    out[i] = ret.release();  // transfer ownership to the caller
  }
  API_END_HANDLE_ERROR();  // end guarded section
}
通过std::move提高参数赋值效率,通过unique_ptr的release、reset管理内存。
MXPredCreatePartialOut为创建mxnet自己管理多线程的预测句柄:
/*!
 * \brief Public API: create a single predictor with an optional subset of
 *        output nodes. Thin wrapper over _CreatePartialOut that requests
 *        one handle and binds the executor eagerly at creation time.
 */
int MXPredCreatePartialOut(const char* symbol_json_str,
                           const void* param_bytes,
                           int param_size,
                           int dev_type, int dev_id,
                           mx_uint num_input_nodes,
                           const char** input_keys,
                           const mx_uint* input_shape_indptr,
                           const mx_uint* input_shape_data,
                           mx_uint num_output_nodes,
                           const char** output_keys,
                           PredictorHandle* out) {
  return _CreatePartialOut(symbol_json_str, param_bytes, param_size,
                           dev_type, dev_id,
                           num_input_nodes, input_keys,
                           input_shape_indptr, input_shape_data,
                           num_output_nodes, output_keys,
                           /*num_threads=*/1,
                           /*lazy=*/false,  // bind at creation
                           out);
}
与其相对的是 MXPredCreateMultiThread:它一次创建多个句柄,每个句柄本身只供单线程使用,由调用方在自己的各个线程中分别持有一个来实现并行推理(mxnet 不负责线程管理)。
/*!
 * \brief Public API: create num_threads independent predictor handles for
 *        parallel inference, one per caller-managed thread. Thin wrapper
 *        over _CreatePartialOut with lazy binding: the executor is bound
 *        on first use (via _CreateExecutor), not at creation time.
 */
int MXPredCreateMultiThread(const char* symbol_json_str,
                            const void* param_bytes,
                            int param_size,
                            int dev_type, int dev_id,
                            mx_uint num_input_nodes,
                            const char** input_keys,
                            const mx_uint* input_shape_indptr,
                            const mx_uint* input_shape_data,
                            mx_uint num_output_nodes,
                            const char** output_keys,
                            // This is used for parallel inference.
                            int num_threads,
                            PredictorHandle* out) {
  return _CreatePartialOut(symbol_json_str, param_bytes, param_size,
                           dev_type, dev_id,
                           num_input_nodes, input_keys,
                           input_shape_indptr, input_shape_data,
                           num_output_nodes, output_keys,
                           num_threads,
                           /*lazy=*/true,  // bind on first forward
                           out);
}
MXPredReshape是句柄尺寸调整:
/*!
 * \brief Create a new predictor with changed input shapes, sharing the
 *        parameter arrays (and executor memory) of an existing predictor.
 * \param num_input_nodes, input_keys names of the inputs being reshaped
 * \param input_shape_indptr, input_shape_data CSR-style packed new shapes
 * \param handle the existing predictor to reshape from
 * \param out receives the new predictor handle (caller owns it)
 * \return 0 on success, nonzero on error (API_BEGIN/API_END protocol)
 *
 * Only input-data arguments may change total size; all other args and every
 * auxiliary state must keep their element count, enforced by CHECK_EQ below.
 */
int MXPredReshape(mx_uint num_input_nodes,
                  const char** input_keys,
                  const mx_uint* input_shape_indptr,
                  const mx_uint* input_shape_data,
                  PredictorHandle handle,
                  PredictorHandle* out) {
  _CreateExecutor(handle);  // ensure the source handle has a bound executor
  MXAPIPredictor* p = static_cast<MXAPIPredictor*>(handle);
  std::unique_ptr<MXAPIPredictor> ret(new MXAPIPredictor());
  API_BEGIN();
  // Collect the requested new shape for every named input (unpack CSR layout).
  std::unordered_map<std::string, TShape> new_shape;
  for (mx_uint i = 0; i < num_input_nodes; ++i) {
    new_shape[std::string(input_keys[i])] =
        TShape(input_shape_data + input_shape_indptr[i],
               input_shape_data + input_shape_indptr[i + 1]);
  }
  // Reuse the source predictor's symbol and name->index mapping.
  ret->sym = p->sym;
  std::vector<std::string> arg_names = ret->sym.ListInputNames(Symbol::kReadOnlyArgs);
  std::vector<std::string> aux_names = ret->sym.ListInputNames(Symbol::kAuxiliaryStates);
  std::vector<TShape> out_shapes(ret->sym.ListOutputNames().size());
  std::vector<TShape> aux_shapes(aux_names.size());
  std::vector<TShape> arg_shapes;
  ret->key2arg = p->key2arg;
  try {
    // Re-run shape inference with the new input shapes; inputs without a
    // requested shape start empty and are inferred from the known ones.
    std::vector<TShape> in_shapes;
    in_shapes.reserve(arg_names.size());
    for (std::string key : ret->sym.ListInputNames(Symbol::kAll)) {
      if (new_shape.count(key) != 0) {
        in_shapes.push_back(new_shape[key]);
      } else {
        in_shapes.emplace_back();
      }
    }
    nnvm::Graph g; g.outputs = ret->sym.outputs;
    g = mxnet::exec::InferShape(std::move(g), std::move(in_shapes), "__shape__");
    bool infer_complete = (g.GetAttr<size_t>("shape_num_unknown_nodes") == 0);
    CHECK(infer_complete) << "The shape information of is not enough to get the shapes";
    CopyAttr(g.indexed_graph(),
             g.GetAttr<nnvm::ShapeVector>("shape"),
             &arg_shapes, &out_shapes, &aux_shapes);
  } catch (const mxnet::op::InferShapeError &err) {
    throw dmlc::Error(err.msg);
  }
  // Share the argument arrays; only reshaped inputs get fresh storage.
  ret->arg_arrays = p->arg_arrays;
  ret->ctx = p->ctx;
  for (size_t i = 0; i < arg_names.size(); ++i) {
    TShape newShape = arg_shapes[i];
    NDArray &arr = p->arg_arrays[i];
    if (new_shape.count(arg_names[i]) != 0) {
      ret->arg_arrays[i].ReshapeAndAlloc(newShape);
    } else {
      // Non-input arguments may not change element count.
      CHECK_EQ(newShape.Size(), arr.shape().Size())
          << "arg " << arg_names[i]
          << " shape has been changed, only allow to change the shape of input data.";
    }
  }
  // Auxiliary states may never change element count.
  for (size_t i = 0; i < aux_names.size(); ++i) {
    TShape newShape = aux_shapes[i];
    NDArray &arr = p->aux_arrays[i];
    CHECK_EQ(newShape.Size(), arr.shape().Size())
        << "aux " << aux_names[i]
        << " shape has been changed, only allow to change the shape of input data.";
  }
  ret->aux_arrays = p->aux_arrays;
  // Bind a new executor, sharing memory with the source executor.
  {
    std::map<std::string, Context> ctx_map;
    // FIX: grad_store must contain one placeholder NDArray per argument so
    // it lines up element-for-element with grad_req, matching the sized
    // vectors used by _CreateExecutor and the eager-bind path; the original
    // only reserve()d capacity, leaving the vector empty.
    std::vector<NDArray> grad_store(ret->arg_arrays.size());
    std::vector<OpReqType> grad_req(ret->arg_arrays.size(), kNullOp);
    ret->exec.reset(Executor::Bind(ret->sym, ret->ctx, ctx_map,
                                   ret->arg_arrays,
                                   grad_store, grad_req,
                                   ret->aux_arrays,
                                   p->exec.get()));  // shared_exec for memory reuse
    ret->out_shapes = out_shapes;
    ret->out_arrays = ret->exec->outputs();
  }
  *out = ret.release();  // transfer ownership to the caller
  API_END();
}
MXPredGetOutputShape获取输出的数据规模:
/*!
 * \brief Query the shape of one output of a predictor.
 *        The dims are copied into a buffer owned by the handle, so the
 *        returned pointer stays valid until the next call on this handle.
 */
int MXPredGetOutputShape(PredictorHandle handle,
                         mx_uint out_index,
                         mx_uint** shape_data,
                         mx_uint* shape_ndim) {
  MXAPIPredictor* pred = static_cast<MXAPIPredictor*>(handle);
  API_BEGIN();
  CHECK_LT(out_index, pred->out_arrays.size()) << "Index exceed number of outputs";
  // Cast the TShape dims into the handle-owned mx_uint buffer for the caller.
  const TShape& shape = pred->out_shapes[out_index];
  pred->out_shapes_buffer.resize(shape.ndim());
  nnvm::ShapeTypeCast(shape.begin(), shape.end(), pred->out_shapes_buffer.data());
  *shape_data = pred->out_shapes_buffer.data();
  *shape_ndim = shape.ndim();
  API_END();
}
MXPredSetInput载入输入数据:
/*!
 * \brief Copy caller-provided CPU data into the named input argument.
 *        Fails fatally if the key is not a known input of the predictor.
 */
int MXPredSetInput(PredictorHandle handle,
                   const char* key,
                   const mx_float* data,
                   mx_uint size) {
  MXAPIPredictor* pred = static_cast<MXAPIPredictor*>(handle);
  API_BEGIN();
  const auto entry = pred->key2arg.find(key);
  if (entry == pred->key2arg.end()) {
    LOG(FATAL) << "cannot find input key " << key;
  }
  // SyncCopyFromCPU synchronizes (WaitToWrite) before copying the data in.
  NDArray& dst = pred->arg_arrays[entry->second];
  dst.SyncCopyFromCPU(data, size);
  API_END();
}
MXPredForward为前向运算,MXPredPartialForward为指定步骤运行,重点用MXPredForward:
int MXPredPartialForward(PredictorHandle handle, int step, int* step_left) {
_CreateExecutor(handle);
MXAPIPredictor* p = static_cast<MXAPIPredictor*>(handle);
API_BEGIN();
p->exec->PartialForward(false, step, step_left);//执行graph_executor中的RunOps,待详细分析
API_END();
}
MXPredGetOutput获取输出结果
/*!
 * \brief Copy one output of the predictor into a caller-provided CPU buffer.
 */
int MXPredGetOutput(PredictorHandle handle,
                    mx_uint index,
                    mx_float* data,
                    mx_uint size) {
  MXAPIPredictor* pred = static_cast<MXAPIPredictor*>(handle);
  API_BEGIN();
  CHECK_LT(index, pred->out_arrays.size()) << "Output index out of range";
  // SyncCopyToCPU synchronizes (WaitToRead) so the result is ready first.
  const NDArray& src = pred->out_arrays[index];
  src.SyncCopyToCPU(data, size);
  API_END();
}
MXPredFree在模型使用完毕后回收内存空间
/*!
 * \brief Destroy a predictor handle and release everything it owns.
 *        Member smart pointers/containers free the executor and arrays.
 */
int MXPredFree(PredictorHandle handle) {
  API_BEGIN();
  MXAPIPredictor* pred = static_cast<MXAPIPredictor*>(handle);
  delete pred;
  API_END();
}