当前应该是tensorrt对python3.6支持不完善,本机环境:
- tensorflow1.10 gpu
- cuda9.0
- cudnn7.3
在 python3.6.7 环境下通过 tar 包安装 tensorrt 后,导入时提示错误(../anaconda3/lib/python3.6/site-packages/tensorrt/tensorrt.so: undefined symbol: _Py_ZeroStruct),而在 python3.5 环境下导入正常。
使用tensorflow训练模型
# This file contains functions for training a TensorFlow model
import tensorflow as tf
import numpy as np
def process_dataset():
    """Load the MNIST dataset and prepare it for the Keras model.

    Pixel values are divided by 255.0 and every image gains a trailing
    channel axis so each array has shape ``(num_samples, 28, 28, 1)``.

    Returns:
        A tuple ``(x_train, y_train, x_test, y_test)``. The label arrays are
        returned exactly as ``load_data()`` provides them.
    """
    # Import the data
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0

    # Reshape the data. The sample count is read from each array rather than
    # hard-coded, so the reshape stays valid if the split sizes ever differ.
    x_train = np.reshape(x_train, (x_train.shape[0], 28, 28, 1))
    x_test = np.reshape(x_test, (x_test.shape[0], 28, 28, 1))
    return x_train, y_train, x_test, y_test
def create_model():
    """Build and compile the MNIST classifier.

    The network is: input of shape ``[28, 28, 1]`` -> flatten -> dense layer
    with 512 ReLU units -> dense layer with 10 softmax units. It is compiled
    with the Adam optimizer, sparse categorical cross-entropy loss and the
    accuracy metric.

    Returns:
        The compiled ``tf.keras`` Sequential model.
    """
    # NOTE(review): the inference script refers to graph nodes by name
    # ("input_1", "dense_1/Softmax"); those names presumably come from Keras'
    # automatic layer naming, so keep the layer creation order unchanged.
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.InputLayer(input_shape=[28, 28, 1]))
    model.add(tf.keras.layers.Flatten())
    model.add(tf.keras.layers.Dense(512, activation=tf.nn.relu))
    model.add(tf.keras.layers.Dense(10, activation=tf.nn.softmax))
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model
def save(model, filename):
    """Freeze the model's graph and write it to ``filename``.

    Variables in the current Keras session graph are converted to constants,
    training-only nodes are removed, and the resulting graph definition is
    serialized to disk.

    Args:
        model: Keras model; the op name of ``model.output`` is used as the
            single output node to keep when freezing.
        filename: Destination path of the frozen graph. Missing parent
            directories are created.
    """
    # Local import: the training script's header only imports tensorflow/numpy.
    import os

    # First freeze the graph and remove training nodes.
    output_name = model.output.op.name
    sess = tf.keras.backend.get_session()
    frozen_graph = tf.graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), [output_name])
    frozen_graph = tf.graph_util.remove_training_nodes(frozen_graph)

    # Create the target directory up front so that a missing "models/" folder
    # does not throw away a finished training run at the very last step.
    out_dir = os.path.dirname(filename)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # Save the model
    with open(filename, "wb") as ofile:
        ofile.write(frozen_graph.SerializeToString())
def main():
    """Train the MNIST model, evaluate it and save the frozen graph.

    The frozen graph is written to ``models/lenet5.pb``.
    """
    x_train, y_train, x_test, y_test = process_dataset()
    model = create_model()
    # Train the model on the data
    model.fit(x_train, y_train, epochs=5, verbose=1)
    # Evaluate the model on test data
    model.evaluate(x_test, y_test)
    save(model, filename="models/lenet5.pb")


if __name__ == '__main__':
    main()
转换生成uff格式文件
mkdir models
convert-to-uff models/lenet5.pb
将需要预测的图像放入models目录下。
预测
sample代码:
from PIL import Image
import numpy as np
import pycuda.driver as cuda
# This import causes pycuda to automatically manage CUDA context creation and cleanup.
import pycuda.autoinit
import tensorrt as trt
import sys, os
sys.path.insert(1, os.path.join(sys.path[0], ".."))
import common
# Module-wide TensorRT logger, passed to trt.Builder in build_engine().
# You can set the logger severity higher to suppress messages (or lower to display more messages).
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
class ModelData(object):
    """Constants describing the UFF model consumed by this sample."""

    # Path of the UFF file, resolved relative to this script's directory.
    MODEL_FILE = os.path.join(os.path.dirname(__file__), "models/lenet5.uff")
    # Name of the graph node registered as the parser input.
    INPUT_NAME = "input_1"
    # Shape registered for the input node.
    # NOTE(review): presumably (channels, height, width) - confirm.
    INPUT_SHAPE = (1, 28, 28)
    # Name of the graph node registered as the parser output.
    OUTPUT_NAME = "dense_1/Softmax"
def build_engine(model_file):
    """Parse a UFF model file and build a TensorRT engine from it.

    The input and output nodes registered with the parser are taken from
    ``ModelData``; the builder workspace is limited to ``common.GiB(1)``.

    Args:
        model_file: Path of the ``.uff`` file to parse.

    Returns:
        The result of ``builder.build_cuda_engine`` for the parsed network.

    Raises:
        RuntimeError: If the UFF parser reports that parsing failed.
    """
    # For more information on TRT basics, refer to the introductory samples.
    with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as network, trt.UffParser() as parser:
        builder.max_workspace_size = common.GiB(1)
        # Parse the Uff Network
        parser.register_input(ModelData.INPUT_NAME, ModelData.INPUT_SHAPE)
        parser.register_output(ModelData.OUTPUT_NAME)
        # parse() signals failure through its return value; stop here with a
        # clear message instead of building an engine from an unparsed network.
        if not parser.parse(model_file, network):
            raise RuntimeError("Failed to parse UFF model file: " + str(model_file))
        # Build and return an engine.
        return builder.build_cuda_engine(network)
# Loads a test case into the provided pagelocked_buffer.
def load_normalized_test_case(data_path, pagelocked_buffer, case_num=2):
    """Read ``<case_num>.pgm`` from ``data_path`` into ``pagelocked_buffer``.

    The image is flattened to 1D, mapped to ``1.0 - pixel / 255.0`` and copied
    into the buffer in place.

    Args:
        data_path: Directory containing the ``.pgm`` test images.
        pagelocked_buffer: Destination array that receives the flattened,
            normalized image.
        case_num: Number used as the file name stem of the image to load.

    Returns:
        ``case_num``, unchanged.
    """
    test_case_path = os.path.join(data_path, str(case_num) + ".pgm")
    # Flatten the image into a 1D array; the context manager closes the
    # underlying file as soon as the pixel data has been copied out.
    with Image.open(test_case_path) as image:
        img = np.array(image).ravel()
    # Normalize and copy to pagelocked memory.
    # NOTE(review): the "1.0 -" inverts intensities, presumably because the
    # sample images are dark-on-light unlike the training data - confirm.
    np.copyto(pagelocked_buffer, 1.0 - img / 255.0)
    return case_num
def main():
    """Build the engine, run one test image through it and print the result.

    Prints the test case number and the index of the largest output value.
    """
    data_path = common.find_sample_data(description="Runs an MNIST network using a UFF model file", subfolder="mnist")
    model_file = ModelData.MODEL_FILE

    with build_engine(model_file) as engine:
        # Build an engine, allocate buffers and create a stream.
        # For more information on buffer allocation, refer to the introductory samples.
        inputs, outputs, bindings, stream = common.allocate_buffers(engine)
        with engine.create_execution_context() as context:
            case_num = load_normalized_test_case(data_path, pagelocked_buffer=inputs[0].host)
            # For more information on performing inference, refer to the introductory samples.
            # The common.do_inference function will return a list of outputs - we only have one in this case.
            [output] = common.do_inference(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
            pred = np.argmax(output)
            print("Test Case: " + str(case_num))
            print("Prediction: " + str(pred))


if __name__ == '__main__':
    main()
python sample_alter.py -d models
WARNING: models/mnist does not exist. Using models instead.
Test Case: 2
Prediction: 2