版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/boom_man/article/details/86223341
官方:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/tensor.proto
原作者在构建TensorProto对象时放入的是List<Float>
而官方推荐的做法是把数据写入 tensor_content 字段(Java 中对应 setTensorContent),两种方式的速度相差约 2 倍
官方描述是这样的
Serialized raw tensor content from either Tensor::AsProtoTensorContent or
memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation
can be used for all tensor types. The purpose of this representation is to
reduce serialization overhead during RPC call by avoiding serialization of
many repeated small items.
代码:
// Build the tensor shape: [batch = 1, height, width, channels = 3].
// Example: a 224 x 224 image carries 224 * 224 * 3 = 150528 values.
TensorShapeProto.Builder tensorShapeBuilder = TensorShapeProto.newBuilder();
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(1));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(mat.height()));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(mat.width()));
tensorShapeBuilder.addDim(TensorShapeProto.Dim.newBuilder().setSize(3));

TensorProto.Builder tensorProtoBuilder = TensorProto.newBuilder();
tensorProtoBuilder.setDtype(DataType.DT_UINT8);
tensorProtoBuilder.setTensorShape(tensorShapeBuilder.build());
// All pixel values of the image go into tensor_content as one raw byte string,
// instead of many repeated per-element values.
tensorProtoBuilder.setTensorContent(ByteString.copyFrom(OpenCVUtils.mat2Content(mat)));

// The original chained useTransportSecurity() and usePlaintext(), which are
// contradictory; only the plaintext setting is kept.
ManagedChannel channel = ManagedChannelBuilder.forAddress(TENSOR_FLOW_URL, TENSOR_FLOW_PORT).usePlaintext().build();
try {
    PredictionServiceGrpc.PredictionServiceFutureStub predictionServiceFutureStub = PredictionServiceGrpc.newFutureStub(channel);

    // Create the request.
    Predict.PredictRequest.Builder request = Predict.PredictRequest.newBuilder();

    // Model name and signature name.
    Model.ModelSpec.Builder modelSpace = Model.ModelSpec.newBuilder();
    modelSpace.setName("ssd_hand");
    modelSpace.setSignatureName("serving_default");
    // No version is set here, so the server chooses which model version to use.
    // NOTE(review): to pin a version, set it on the ModelSpec builder;
    // TensorProto.version_number describes the tensor encoding, not the model version.
    request.setModelSpec(modelSpace);

    // Attach the input tensor to the request.
    request.putInputs("input", tensorProtoBuilder.build());

    // Start the clock before the call is dispatched so the printed time
    // covers the whole round trip, not just the wait on the future.
    long t = System.currentTimeMillis();
    ListenableFuture<Predict.PredictResponse> predict = predictionServiceFutureStub.predict(request.build());
    Predict.PredictResponse response = predict.get(50000, TimeUnit.MILLISECONDS);
    System.out.println("cost time: " + (System.currentTimeMillis() - t));
} catch (InterruptedException e) {
    // Restore the interrupt flag so callers further up can observe it.
    Thread.currentThread().interrupt();
    e.printStackTrace();
} catch (ExecutionException | TimeoutException e) {
    e.printStackTrace();
} finally {
    // Release the channel's resources; the original never shut it down.
    channel.shutdown();
}
/**
 * Copies the pixel data of a Mat into a flat byte array.
 *
 * <p>The buffer holds {@code rows * cols * channels} bytes and is filled by
 * {@code mat.get(0, 0, ...)}, i.e. starting at the first row and column.
 * NOTE(review): presumably the Mat must hold 8-bit elements for the byte[]
 * overload of {@code get} to succeed; confirm against the OpenCV version in use.
 *
 * @param mat the image matrix to read
 * @return a newly allocated array containing the matrix data
 */
public static byte[] mat2Content(Mat mat) {
    // Size the buffer from the Mat's own channel count instead of assuming 3,
    // so 1- and 4-channel images are neither padded nor truncated.
    byte[] content = new byte[mat.rows() * mat.cols() * mat.channels()];
    mat.get(0, 0, content);
    return content;
}
Maven 依赖:
<!-- TensorFlow client library. Three of its transitive dependencies are excluded: slf4j-log4j12, grpc-protobuf and grpc-stub. -->
<!-- NOTE(review): presumably the gRPC exclusions exist so that the grpc-all version declared in this file is used instead; confirm. -->
<dependency>
<groupId>com.yesup.oss</groupId>
<artifactId>tensorflow-client</artifactId>
<version>1.4-2</version>
<exclusions>
<exclusion>
<artifactId>slf4j-log4j12</artifactId>
<groupId>org.slf4j</groupId>
</exclusion>
<exclusion>
<groupId>io.grpc</groupId>
<artifactId>grpc-protobuf</artifactId>
</exclusion>
<exclusion>
<artifactId>grpc-stub</artifactId>
<groupId>io.grpc</groupId>
</exclusion>
</exclusions>
</dependency>
<!-- This library is used for image processing. -->
<dependency>
<groupId>net.coobird</groupId>
<artifactId>thumbnailator</artifactId>
<version>0.4.8</version>
</dependency>
<!-- gRPC aggregate artifact, with protobuf-java excluded from its transitive dependencies. -->
<!-- NOTE(review): presumably excluded so that only one protobuf-java version ends up on the classpath; confirm which dependency supplies it. -->
<dependency>
<groupId>io.grpc</groupId>
<artifactId>grpc-all</artifactId>
<version>1.17.1</version>
<exclusions>
<exclusion>
<artifactId>protobuf-java</artifactId>
<groupId>com.google.protobuf</groupId>
</exclusion>
</exclusions>
</dependency>
<!-- Netty native TLS bindings (netty-tcnative, statically linked BoringSSL variant). -->
<!-- NOTE(review): presumably only needed when the gRPC channel uses TLS; confirm whether a plaintext-only client requires it. -->
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-tcnative-boringssl-static</artifactId>
<version>2.0.7.Final</version>
</dependency>
csdn:https://blog.csdn.net/shin627077/article/details/78592729
官方:https://github.com/tensorflow/serving
参考:小米云 http://docs.api.xiaomi.com/cloud-ml/modelservice/0903_use_java_client.html
总结:
gRPC 调用的核心是发送 TensorProto 消息,其定义(tensor.proto)如下:
// File-level settings for the tensor message definitions.
syntax = "proto3";
// Proto package; messages in this file are referenced as tensorflow.<Name>.
package tensorflow;
// C++ code generation: enable arena allocation for the generated classes.
option cc_enable_arenas = true;
// Java code generation: classes go into org.tensorflow.framework, with one
// source file per top-level message and TensorProtos as the outer class name.
option java_outer_classname = "TensorProtos";
option java_multiple_files = true;
option java_package = "org.tensorflow.framework";
// Definitions of ResourceHandleProto, TensorShapeProto and DataType, which are
// used by the messages in this file, are expected to come from these imports.
import "tensorflow/core/framework/resource_handle.proto";
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";
// Protocol buffer representing a tensor.
message TensorProto {
// Element type of the tensor. It determines which of the type-specific
// "xxx_val" representations below may be set.
DataType dtype = 1;
// Shape of the tensor. TODO(touts): sort out the 0-rank issues.
TensorShapeProto tensor_shape = 2;
// Only one of the representations below is set: either "tensor_content" or
// one of the "xxx_val" attributes. We are not using oneof because oneofs
// cannot contain repeated fields, so it would require an extra set of messages.
//
// Version number.
//
// In version 0, if the "repeated xxx" representations contain only one
// element, that element is repeated to fill the shape. This makes it easy
// to represent a constant Tensor with a single value.
int32 version_number = 3;
// Serialized raw tensor content from either Tensor::AsProtoTensorContent or
// memcpy in tensorflow::grpc::EncodeTensorToByteBuffer. This representation
// can be used for all tensor types. The purpose of this representation is to
// reduce serialization overhead during RPC call by avoiding serialization of
// many repeated small items.
bytes tensor_content = 4;
// Type specific representations that make it easy to create tensor protos in
// all languages. Only the representation corresponding to "dtype" can
// be set. The values hold the flattened representation of the tensor in
// row major order.
//
// DT_HALF. Note that since protobuf has no int16 type, we'll have some
// pointless zero padding for each value here.
repeated int32 half_val = 13 [packed = true];
// DT_FLOAT.
repeated float float_val = 5 [packed = true];
// DT_DOUBLE.
repeated double double_val = 6 [packed = true];
// DT_INT32, DT_INT16, DT_INT8, DT_UINT8.
repeated int32 int_val = 7 [packed = true];
// DT_STRING
repeated bytes string_val = 8;
// DT_COMPLEX64. scomplex_val(2*i) and scomplex_val(2*i+1) are real
// and imaginary parts of i-th single precision complex.
repeated float scomplex_val = 9 [packed = true];
// DT_INT64
repeated int64 int64_val = 10 [packed = true];
// DT_BOOL
repeated bool bool_val = 11 [packed = true];
// DT_COMPLEX128. dcomplex_val(2*i) and dcomplex_val(2*i+1) are real
// and imaginary parts of i-th double precision complex.
repeated double dcomplex_val = 12 [packed = true];
// DT_RESOURCE
repeated ResourceHandleProto resource_handle_val = 14;
// DT_VARIANT
repeated VariantTensorDataProto variant_val = 15;
};
// Protocol buffer representing the serialization format of DT_VARIANT tensors.
// This message and TensorProto reference each other: TensorProto.variant_val
// holds VariantTensorDataProto values, and "tensors" below holds TensorProto
// values.
message VariantTensorDataProto {
// Name of the type of objects being serialized.
string type_name = 1;
// Portions of the object that are not Tensors.
bytes metadata = 2;
// Tensors contained within objects being serialized.
repeated TensorProto tensors = 3;
}