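Installing, Configuring, and Testing TensorFlow Faster RCNN on Windows 10

Original blog post: https://blog.csdn.net/kebi199312/article/details/88368904

I have recently been studying object detection in deep learning. After comparing YOLO V3, Faster R-CNN, SSD, and other object detection algorithms, I found that Faster RCNN performs well and is relatively easy to install, so I set it up myself on Windows 10.

This article has two parts:

  • Configuring the TensorFlow version of Faster RCNN on Windows 10
  • Running the Faster RCNN programs to test on images and video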

I. Environment Configuration

1. Environment

  • Windows 10, GeForce GTX 960M graphics card;
  • TensorFlow-gpu 1.13.0-rc2, CUDA 10.0, cuDNN 7.4.2;
  • Python 3.5.2

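For installing the GPU version of TensorFlow, see my other post: https://blog.csdn.net/kebi199312/article/details/86549637. The CUDA version there is different, but the installation steps are largely the same.

TensorFlow-gpu was installed with pip in Windows PowerShell, together with several required libraries: cython, easydict, matplotlib, OpenCV (pip package opencv-python), and so on. They can be installed directly with pip or from downloaded offline .whl files.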
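Before going further, it can help to confirm that the interpreter you will run train.py with can actually see these packages. A minimal sketch using only the standard library; the module list is an assumption based on the packages named above (note that opencv-python is imported as cv2 and tensorflow-gpu as tensorflow):

```python
import importlib.util

# Import names (not pip names) of the packages listed above.
# Assumption: "cython" is imported as "Cython", "opencv-python" as "cv2".
REQUIRED_MODULES = ("tensorflow", "Cython", "easydict", "matplotlib", "cv2")


def missing_modules(names=REQUIRED_MODULES):
    """Return the module names that the current interpreter cannot find."""
    # find_spec only locates the module; it does not import it.
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules are importable.")
```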

2. Downloading the Source Code

Faster RCNN can be downloaded from:

https://github.com/dBeker/Faster-RCNN-TensorFlow-Python3.5

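You can also download it with git: open a command prompt (cmd), cd to your target directory, and run:

```
git clone https://github.com/dBeker/Faster-RCNN-TensorFlow-Python3.5.git
```

After downloading, the project root directory is Faster-RCNN-TensorFlow-Python3.5-master. (This is the folder name produced by GitHub's ZIP download; git clone creates a folder named Faster-RCNN-TensorFlow-Python3.5 instead.)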

cd into Faster-RCNN-TensorFlow-Python3.5-master\data\coco\PythonAPI, open cmd there, and build the provided code:

```
python setup.py build_ext --inplace
python setup.py build_ext install
```
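If several Python installations exist on the machine, it is easy to build the extension for a different interpreter than the one TensorFlow is installed in. A minimal sketch that runs the same two commands with the current interpreter (`sys.executable`); the `build_coco_api` helper and its `dry_run` flag are hypothetical, not part of the repository:

```python
import subprocess
import sys

# The two setup.py invocations from the step above, as argv lists.
BUILD_STEPS = (
    ["setup.py", "build_ext", "--inplace"],
    ["setup.py", "build_ext", "install"],
)


def build_coco_api(pythonapi_dir, dry_run=False):
    """Run each setup.py step inside pythonapi_dir with the current interpreter.

    Hypothetical helper: with dry_run=True it only prints the commands.
    """
    commands = [[sys.executable] + step for step in BUILD_STEPS]
    for cmd in commands:
        print("running:", " ".join(cmd), "in", pythonapi_dir)
        if not dry_run:
            # check=True stops at the first failing step (e.g. no C compiler).
            subprocess.run(cmd, cwd=pythonapi_dir, check=True)
    return commands
```

Usage, from the project root: `build_coco_api(r"data\coco\PythonAPI")`.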

II. Downloading the VOC2007 Dataset

The dataset used is VOC2007. Download links:

  • http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
  • http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar
  • http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCdevkit_08-Jun-2007.tar

Because these sites may be blocked in mainland China, the dataset is also available on Baidu Netdisk: https://pan.baidu.com/s/1Y_RzqLvW4CAzTEq4ICFVUA, extraction code: m9dl

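Extract the three downloaded archives into the same folder: select all three at once and choose "Extract Here". This produces a VOCDevkit folder, as shown in Figure 1. Rename VOCDevkit to VOCDevkit2007, then copy this folder into the data directory. Resulting folder path:

Faster-RCNN-TensorFlow-Python3.5-master\data\VOCDevkit2007

Figure 1   Folder of the VOC2007 dataset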
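The unpacking and renaming steps can also be scripted. A minimal sketch, assuming the three .tar files sit in one folder and that each archive's top-level folder is named VOCdevkit (Windows treats VOCdevkit and VOCDevkit as the same name); `prepare_voc2007` is a hypothetical helper:

```python
import os
import shutil
import tarfile

# File names of the three archives listed above.
VOC_ARCHIVES = (
    "VOCtrainval_06-Nov-2007.tar",
    "VOCtest_06-Nov-2007.tar",
    "VOCdevkit_08-Jun-2007.tar",
)


def prepare_voc2007(archive_dir, data_dir):
    """Extract the VOC2007 archives and move VOCdevkit to data_dir/VOCDevkit2007."""
    extract_dir = os.path.join(archive_dir, "extracted")
    for name in VOC_ARCHIVES:
        with tarfile.open(os.path.join(archive_dir, name)) as tar:
            # Assumption: all three archives share the top-level VOCdevkit/
            # folder, so extracting them to one place merges their contents.
            tar.extractall(extract_dir)
    os.makedirs(data_dir, exist_ok=True)
    target = os.path.join(data_dir, "VOCDevkit2007")
    shutil.move(os.path.join(extract_dir, "VOCdevkit"), target)
    return target
```

Usage, from the project root: `prepare_voc2007(r"D:\downloads\voc", "data")`, where the download folder is only an example.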

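III. Downloading the VGG16 Model

The VGG16 model can be downloaded from http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz, or from Baidu Netdisk:

link: https://pan.baidu.com/s/11Ty10NJ-rgXkkvM92SVVKw, extraction code: d2jz

After downloading, extract the archive and rename the extracted file to vgg16.ckpt, as shown in Figure 2. Create a folder named imagenet_weights, put vgg16.ckpt into it, and copy the imagenet_weights folder into the data folder. Resulting path:

Faster-RCNN-TensorFlow-Python3.5-master\data\imagenet_weights\vgg16.ckpt

Figure 2   The renamed vgg16.ckpt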
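The extract, rename, and copy steps can be collapsed into one small script. A minimal sketch, assuming the archive contains a single file named vgg_16.ckpt at its top level (an assumption; check the archive if extraction fails); `install_vgg16_weights` is a hypothetical helper:

```python
import os
import shutil
import tarfile


def install_vgg16_weights(archive_path, data_dir, member="vgg_16.ckpt"):
    """Copy `member` out of the .tar.gz into data_dir/imagenet_weights/vgg16.ckpt."""
    weights_dir = os.path.join(data_dir, "imagenet_weights")
    os.makedirs(weights_dir, exist_ok=True)
    target = os.path.join(weights_dir, "vgg16.ckpt")
    with tarfile.open(archive_path, "r:gz") as tar:
        # extractfile raises KeyError if the member name is wrong.
        src = tar.extractfile(member)
        if src is None:
            raise ValueError("{} is not a regular file in the archive".format(member))
        # Stream the member straight to its final (renamed) location.
        with src, open(target, "wb") as dst:
            shutil.copyfileobj(src, dst)
    return target
```

Usage, from the project root: `install_vgg16_weights("vgg_16_2016_08_28.tar.gz", "data")`.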

IV. Training the Model

Training parameters can be changed in config.py in the Faster-RCNN-TensorFlow-Python3.5-master\lib\config folder, including the total number of training iterations, weight decay, learning rate, batch_size, and so on.

```python
# Excerpt from lib/config/config.py (TensorFlow 1.x tf.app.flags API)
import tensorflow as tf

tf.app.flags.DEFINE_float('weight_decay', 0.0005, "Weight decay, for regularization")
tf.app.flags.DEFINE_float('learning_rate', 0.001, "Learning rate")
tf.app.flags.DEFINE_float('momentum', 0.9, "Momentum")
tf.app.flags.DEFINE_float('gamma', 0.1, "Factor for reducing the learning rate")
tf.app.flags.DEFINE_integer('batch_size', 128, "Network batch size during training")
tf.app.flags.DEFINE_integer('max_iters', 40000, "Max iteration")
tf.app.flags.DEFINE_integer('step_size', 30000, "Step size for reducing the learning rate, currently only support one step")
tf.app.flags.DEFINE_integer('display', 20, "Iteration intervals for showing the loss during training, on command line interface")
tf.app.flags.DEFINE_string('initializer', "truncated", "Network initialization parameters")
tf.app.flags.DEFINE_string('pretrained_model', "./data/imagenet_weights/vgg16.ckpt", "Pretrained network weights")
tf.app.flags.DEFINE_boolean('bias_decay', False, "Whether to have weight decay on bias as well")
tf.app.flags.DEFINE_boolean('double_bias', True, "Whether to double the learning rate for bias")
tf.app.flags.DEFINE_boolean('use_all_gt', True, "Whether to use all ground truth bounding boxes for training, "
                            "For COCO, setting USE_ALL_GT to False will exclude boxes that are flagged as ''iscrowd''")
tf.app.flags.DEFINE_integer('max_size', 1000, "Max pixel size of the longest side of a scaled input image")
tf.app.flags.DEFINE_integer('test_max_size', 1000, "Max pixel size of the longest side of a scaled input image")
tf.app.flags.DEFINE_integer('ims_per_batch', 1, "Images to use per minibatch")
tf.app.flags.DEFINE_integer('snapshot_iterations', 5000, "Iteration to take snapshot")
```
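To see how learning_rate, gamma, and step_size interact: the help strings describe a single step decay ("currently only support one step"), i.e. the rate is multiplied by gamma once after step_size iterations. A minimal sketch of that schedule, plus the snapshot points implied by snapshot_iterations; the exact iteration at which train.py switches (for example step_size versus step_size + 1) is an assumption here, not taken from the repository:

```python
# Values copied from the flags above.
LEARNING_RATE = 0.001
GAMMA = 0.1
STEP_SIZE = 30000
MAX_ITERS = 40000
SNAPSHOT_ITERATIONS = 5000


def learning_rate_at(iteration, base_lr=LEARNING_RATE, gamma=GAMMA, step_size=STEP_SIZE):
    """Single-step decay: base_lr up to step_size, then base_lr * gamma."""
    return base_lr * gamma if iteration > step_size else base_lr


def snapshot_iterations(max_iters=MAX_ITERS, every=SNAPSHOT_ITERATIONS):
    """Iterations at which a snapshot would be saved, one every `every` steps."""
    return list(range(every, max_iters + 1, every))
```

With the default values this gives 0.001 for the first 30000 iterations, 0.0001 for the remaining 10000, and eight snapshots (5000, 10000, ..., 40000).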

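After adjusting the parameters, run python train.py in the Faster-RCNN-TensorFlow-Python3.5-master directory to train the model.

When training finishes, the trained model files are in Faster-RCNN-TensorFlow-Python3.5-master\default\voc_2007_trainval\default. Here the model was trained for 40000 iterations; the iteration count (max_iters) can be changed in config.py in Faster-RCNN-TensorFlow-Python3.5-master\lib\config.

In the project root, create the folder output\vgg16\voc_2007_trainval\default, copy the generated checkpoint files into it, and rename them so that they all share the prefix vgg16.ckpt (for example "vgg16.ckpt.meta"), as shown in Figure 4:

Figure 4   The renamed vgg16.ckpt.meta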
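A TensorFlow checkpoint is a set of files sharing one prefix (typically .meta, .index, and .data-00000-of-00001), and restoring it needs all of them, not just the .meta file. A minimal sketch that copies every file of one checkpoint and swaps the prefix; the source prefix vgg16_faster_rcnn_iter_40000.ckpt is an assumption based on the naming pattern in demo.py's original NETS entry, and `export_checkpoint` is a hypothetical helper:

```python
import glob
import os
import shutil


def export_checkpoint(train_dir, out_dir, src_prefix, dst_prefix="vgg16.ckpt"):
    """Copy all files of one checkpoint, replacing src_prefix with dst_prefix."""
    os.makedirs(out_dir, exist_ok=True)
    # Match e.g. <src_prefix>.meta, <src_prefix>.index, <src_prefix>.data-00000-of-00001
    pattern = os.path.join(train_dir, glob.escape(src_prefix) + ".*")
    copied = []
    for path in sorted(glob.glob(pattern)):
        suffix = os.path.basename(path)[len(src_prefix):]  # ".meta", ".index", ...
        target = os.path.join(out_dir, dst_prefix + suffix)
        shutil.copy2(path, target)
        copied.append(os.path.basename(target))
    if not copied:
        raise FileNotFoundError("no files match " + pattern)
    return copied
```

Usage, from the project root: `export_checkpoint(r"default\voc_2007_trainval\default", r"output\vgg16\voc_2007_trainval\default", "vgg16_faster_rcnn_iter_40000.ckpt")`.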

V. Testing the Model

Modify demo.py as follows.

1. In NETS, change "vgg16_faster_rcnn_iter_70000.ckpt" to "vgg16.ckpt":

```python
NETS = {'vgg16': ('vgg16.ckpt',), 'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}
```

2. In DATASETS, change "voc_2007_trainval+voc_2012_trainval" to "voc_2007_trainval":

```python
DATASETS = {'pascal_voc': ('voc_2007_trainval',), 'pascal_voc_0712': ('voc_2007_trainval',)}
```

3. In the parse_args() function, set the two defaults to vgg16 and pascal_voc:

```python
def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo')
    parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
                        choices=NETS.keys(), default='vgg16')
    parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]',
                        choices=DATASETS.keys(), default='pascal_voc')
    args = parser.parse_args()
    return args
```
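With these settings, the checkpoint that the demo loads is assembled from NETS and DATASETS with os.path.join (the video script in this post builds it the same way). A minimal sketch that reproduces that path and lists which checkpoint files are missing; checking .index in addition to .meta is an addition of this sketch, and both helper names are hypothetical:

```python
import os

NETS = {'vgg16': ('vgg16.ckpt',), 'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}
DATASETS = {'pascal_voc': ('voc_2007_trainval',), 'pascal_voc_0712': ('voc_2007_trainval',)}


def checkpoint_prefix(demonet='vgg16', dataset='pascal_voc'):
    """Build the checkpoint prefix the same way the demo scripts do."""
    return os.path.join('output', demonet, DATASETS[dataset][0], 'default', NETS[demonet][0])


def missing_checkpoint_files(prefix):
    """Return the expected checkpoint files that do not exist on disk."""
    return [prefix + ext for ext in ('.meta', '.index')
            if not os.path.isfile(prefix + ext)]
```

Run from the project root, `missing_checkpoint_files(checkpoint_prefix())` should return an empty list once the renamed files are in output\vgg16\voc_2007_trainval\default.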

4. After changing the parameters above, running demo.py raised the following error:

```
E:\Software\Python\python.exe E:/liukang/Faster-RCNN-TensorFlow-Python3.5-master/demo.py --net vgg16
Traceback (most recent call last):
  File "E:/liukang/Faster-RCNN-TensorFlow-Python3.5-master/demo.py", line 142, in <module>
    tag='default', anchor_scales=[8, 16, 32])
  File "E:\liukang\Faster-RCNN-TensorFlow-Python3.5-master\lib\nets\network.py", line 283, in create_architecture
    weights_regularizer = tf.contrib.layers.l2_regularizer(cfg.FLAGS.weight_decay)
  File "E:\Software\Python\lib\site-packages\tensorflow\python\platform\flags.py", line 84, in __getattr__
    wrapped(_sys.argv)
  File "E:\Software\Python\lib\site-packages\absl\flags\_flagvalues.py", line 633, in __call__
    name, value, suggestions=suggestions)
absl.flags._exceptions.UnrecognizedFlagError: Unknown command line flag 'net'. Did you mean: network ?
```

Solution:

Create a new .py file and copy the contents of demo.py into it. Here a new script, temp.py, was created this way and used to test several images; the result is shown in Figure 5. The traceback suggests why this helps: reading cfg.FLAGS makes tf.app.flags parse sys.argv, and it does not recognize the --net vgg16 argument passed by the original run configuration. A freshly created file is most likely run without that argument, so the argparse defaults are used instead.

Figure 5   Result of running demo.py
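An alternative that does not depend on the IDE's run configuration is to remove the demo's own arguments from sys.argv before any cfg.FLAGS value is read. A minimal sketch using argparse's parse_known_args; `parse_args_and_clean_argv` is a hypothetical drop-in for parse_args() and has not been tested against this repository:

```python
import argparse
import sys


def parse_args_and_clean_argv(argv=None):
    """Parse --net/--dataset, then drop them from sys.argv.

    Assumption: tf.app.flags (TF 1.x) parses sys.argv the first time a flag
    is read and rejects flags it does not define, such as --net.
    """
    if argv is None:
        argv = sys.argv
    parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo')
    parser.add_argument('--net', dest='demo_net', choices=['vgg16', 'res101'],
                        default='vgg16', help='Network to use [vgg16 res101]')
    parser.add_argument('--dataset', dest='dataset', choices=['pascal_voc', 'pascal_voc_0712'],
                        default='pascal_voc', help='Trained dataset [pascal_voc pascal_voc_0712]')
    args, remaining = parser.parse_known_args(argv[1:])
    # Keep the program name plus anything argparse did not consume,
    # so tf.app.flags never sees --net or --dataset.
    sys.argv = [argv[0]] + remaining
    return args
```

Calling it at the top of `if __name__ == '__main__':` in place of `parse_args()` should let `python demo.py --net vgg16` run, because the flag is gone by the time `create_architecture` reads `cfg.FLAGS.weight_decay`.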

To test on video, use the following code. A frame captured from the video is shown in Figure 6. (Compared with the first version of this script, frame reads are now checked so the loop ends cleanly at the end of the video, box coordinates are converted to int for OpenCV, and the session is passed to demo() explicitly.)

```python
# -*- coding: utf-8 -*-
# --------------------------------------------------------
# Faster R-CNN
# author lk
# --------------------------------------------------------
"""
Demo script showing detections in videos.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from lib.config import config as cfg  # defines the TF flags used by the network
from lib.utils.test import im_detect
from lib.utils.nms_wrapper import nms
from lib.utils.timer import Timer
from lib.nets.vgg16 import vgg16

CLASSES = ('__background__',
           'aeroplane', 'bicycle', 'bird', 'boat',
           'bottle', 'bus', 'car', 'cat', 'chair',
           'cow', 'diningtable', 'dog', 'horse',
           'motorbike', 'person', 'pottedplant',
           'sheep', 'sofa', 'train', 'tvmonitor')

NETS = {'vgg16': ('vgg16.ckpt',), 'res101': ('res101_faster_rcnn_iter_110000.ckpt',)}
DATASETS = {'pascal_voc': ('voc_2007_trainval',), 'pascal_voc_0712': ('voc_2007_trainval',)}


def vis_detections(im, class_name, dets, thresh=0.5):
    """Draw detected bounding boxes with matplotlib (used for still images)."""
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return
    im = im[:, :, (2, 1, 0)]  # BGR -> RGB
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(im, aspect='equal')
    for i in inds:
        bbox = dets[i, :4]
        score = dets[i, -1]
        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1], fill=False,
                          edgecolor='red', linewidth=3.5)
        )
        ax.text(bbox[0], bbox[1] - 2,
                '{:s} {:.3f}'.format(class_name, score),
                bbox=dict(facecolor='blue', alpha=0.5),
                fontsize=14, color='white')
    ax.set_title(('{} detections with '
                  'p({} | box) >= {:.1f}').format(class_name, class_name, thresh),
                 fontsize=14)
    plt.axis('off')
    plt.tight_layout()
    plt.draw()


def vis_detections_video(im, class_name, dets, thresh=0.5):
    """Draw detected bounding boxes directly on a video frame with OpenCV."""
    # np.where: indices of the detections whose score is >= thresh
    inds = np.where(dets[:, -1] >= thresh)[0]
    if len(inds) == 0:
        return im
    for i in inds:
        # OpenCV drawing functions expect integer pixel coordinates
        x1, y1, x2, y2 = [int(v) for v in dets[i, :4]]
        score = dets[i, -1]
        cv2.rectangle(im, (x1, y1), (x2, y2), (0, 0, 255), 2)
        cv2.rectangle(im, (x1, y1 - 20), (x1 + 200, y1), (10, 10, 10), -1)
        cv2.putText(im, '{:s} {:.3f}'.format(class_name, score), (x1, y1 - 2),
                    cv2.FONT_HERSHEY_SIMPLEX, .75, (0, 0, 255))
    return im


def demo(sess, net, im, window_name):
    """Detect object classes in one frame and show the annotated frame."""
    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    scores, boxes = im_detect(sess, net, im)
    timer.toc()
    print('Detection took {:.3f}s for '
          '{:d} object proposals'.format(timer.total_time, boxes.shape[0]))
    frame_rate = 1.0 / timer.total_time
    print('fps: ' + str(float(frame_rate)))

    # Visualize detections for each class
    CONF_THRESH = 0.8
    NMS_THRESH = 0.3
    for cls_ind, cls in enumerate(CLASSES[1:]):
        cls_ind += 1  # because we skipped background
        cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes,
                          cls_scores[:, np.newaxis])).astype(np.float32)
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]
        vis_detections_video(im, cls, dets, thresh=CONF_THRESH)
    text = '{:s} {:.2f}'.format("FPS:", frame_rate)
    cv2.putText(im, text, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255))
    cv2.imshow(window_name, im)


def parse_args():
    """Parse input arguments."""
    parser = argparse.ArgumentParser(description='Tensorflow Faster R-CNN demo')
    parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16 res101]',
                        choices=NETS.keys(), default='vgg16')
    parser.add_argument('--dataset', dest='dataset', help='Trained dataset [pascal_voc pascal_voc_0712]',
                        choices=DATASETS.keys(), default='pascal_voc')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()

    # model path
    demonet = args.demo_net
    dataset = args.dataset
    tfmodel = os.path.join('output', demonet, DATASETS[dataset][0], 'default', NETS[demonet][0])
    if not os.path.isfile(tfmodel + '.meta'):
        print(tfmodel)
        raise IOError(('{:s} not found.\nDid you download the proper networks from '
                       'our server and place them properly?').format(tfmodel + '.meta'))

    # set config
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True

    # load network
    if demonet == 'vgg16':
        net = vgg16(batch_size=1)
    else:
        raise NotImplementedError

    # init session
    sess = tf.Session(config=tfconfig)
    net.create_architecture(sess, "TEST", 21,
                            tag='default', anchor_scales=[8, 16, 32])
    saver = tf.train.Saver()
    saver.restore(sess, tfmodel)
    print('\n\nLoaded network {:s}'.format(tfmodel))

    # Warmup on a dummy image
    im = 128 * np.ones((300, 500, 3), dtype=np.uint8)
    for i in range(2):
        _, _ = im_detect(sess, net, im)

    videoFilePath = 'Camera Road 01.avi'  # change this to your own video path
    videoCapture = cv2.VideoCapture(videoFilePath)
    window_name = os.path.basename(videoFilePath)
    while True:
        success, im = videoCapture.read()
        if not success:  # end of the video, or the frame could not be read
            break
        demo(sess, net, im, window_name)
        if cv2.waitKey(10) & 0xFF == ord('q'):  # press q to quit
            break
    videoCapture.release()
    cv2.destroyAllWindows()
    sess.close()
```

When running this program, change videoFilePath to the path of your own video.

Figure 6   Testing on video


Reposted from blog.csdn.net/kellyroslyn/article/details/92159004