相信很多小伙伴对tensorflow的python接口使用都很熟悉了,前期,LZ已经将训练的模型通过tensorflow 的C++接口移植成功,现在出现的问题是tensorflow的session run只能从CPU上进行inference,并且inference后得到的结果也在CPU上,那么正常数据在CPU上,当然是没有问题的,重点是LZ获取的是GPU的buffer,如果使用cuda的接口进行memory的copy,那么花在IO上的时间损耗会特别大,所以后期还在琢磨怎么把整个程序移植到GPU端,从而加速整个算法的时间,这样就要了解CUDA的一些函数和GPU上的基本知识,所以LZ又要开始新的学习了(≧▽≦)/,不说了,程序猿不就是一直都在学习嘛,可能又要写点关于CUDA的学习博客了。
回到正题,我们先来看下tensorflow到底支持什么数据类型,有些小伙伴会问,为啥要了解数据类型呢?因为GPU上的操作大多数是依赖于指针,所以数据类型是直接关系到buferr的大小的,不然你的cudamalloc都没做对,后续还怎么开始呢?
在<tensorflow_path>/bazel-genfiles/tensorflow/core/framework/types.pb.h中,有具体的定义:
namespace tensorflow {
enum DataType {
DT_INVALID = 0,
DT_FLOAT = 1,
DT_DOUBLE = 2,
DT_INT32 = 3,
DT_UINT8 = 4,
DT_INT16 = 5,
DT_INT8 = 6,
DT_STRING = 7,
DT_COMPLEX64 = 8,
DT_INT64 = 9,
DT_BOOL = 10,
DT_QINT8 = 11,
DT_QUINT8 = 12,
DT_QINT32 = 13,
DT_BFLOAT16 = 14,
DT_QINT16 = 15,
DT_QUINT16 = 16,
DT_UINT16 = 17,
DT_COMPLEX128 = 18,
DT_HALF = 19,
DT_RESOURCE = 20,
DT_VARIANT = 21,
DT_UINT32 = 22,
DT_UINT64 = 23,
DT_FLOAT_REF = 101,
DT_DOUBLE_REF = 102,
DT_INT32_REF = 103,
DT_UINT8_REF = 104,
DT_INT16_REF = 105,
DT_INT8_REF = 106,
DT_STRING_REF = 107,
DT_COMPLEX64_REF = 108,
DT_INT64_REF = 109,
DT_BOOL_REF = 110,
DT_QINT8_REF = 111,
DT_QUINT8_REF = 112,
DT_QINT32_REF = 113,
DT_BFLOAT16_REF = 114,
DT_QINT16_REF = 115,
DT_QUINT16_REF = 116,
DT_UINT16_REF = 117,
DT_COMPLEX128_REF = 118,
DT_HALF_REF = 119,
DT_RESOURCE_REF = 120,
DT_VARIANT_REF = 121,
DT_UINT32_REF = 122,
DT_UINT64_REF = 123,
DataType_INT_MIN_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::google::protobuf::int32>::min(),
DataType_INT_MAX_SENTINEL_DO_NOT_USE_ = std::numeric_limits<::google::protobuf::int32>::max()
};
基本上涵盖了各种数据类型了,甚至有些LZ都不是很熟悉。
我们使用tensorflow的接口来定义对应的Tensor,然后对tensor函数的基本使用:
// tensor的初始化方式
// cpu上初始化方式
Tensor(DataType type, const TensorShape& shape);
// device上初始化方式,如果buffer在GPU上,就是使用GPU上初始化方式
Tensor(Allocator* a, DataType type, const TensorShape& shape);
// 这个简直是调试最常用的方式,会输出对应tensor的类型,shape,以及部分数据,但是这个只限制在cpu上的tensor,gpu上的tensor需要先拷贝到cpu上才能使用
std::cout << tensor_name.DebugString() << std::endl;
// tensor的元素数量
std::cout << tensor_name.NumElements() << std::endl;
// tensor的元素类型,如果是DF_FLOAT,输出就是1
std::cout << tensor_name.dtype() << std::endl;
上面就是tensorflow支持的数据类型和对于tensor的初始化和常用函数的使用。话说,好像快过年了呢(⊙o⊙)?