The previous notes used an already-trained model to do recognition, either on locally loaded images or on a live USB camera feed (the IP camera was temporarily unusable). But a trained model file only exists once training has actually been done, and the training workflow is fairly involved: it calls several separate scripts and is awkward to get started with. Writing a training tool of my own both makes the work easier and is good practice for me. The finished interface is shown in the screenshot below. It can display frames from an IP or USB camera in real time and lets you draw annotations directly on them (much like the labelImg tool). When annotation is finished, it saves standard XML files and the original images under the directory containing the script: the XML goes into the annotations folder and the images into the img folder. It also automatically splits the data into training and validation sets and generates the corresponding TFRecord files, all of which are stored in the data folder. A label map file is generated automatically during annotation as well.
For now the tool has to start from a pretrained model (the retrain-from-scratch button does nothing yet), using the pretrained ssd_mobilenet_v2_coco model (which comes with the API's original code). It covers the whole pipeline: image capture -----> image annotation -----> data generation -----> training -----> exporting a .pb model file.
One outstanding problem: I would like the tool to modify the config file automatically, so that the training-step count and batch size the user enters in the UI actually take effect, but I have not managed to implement that yet.
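(For what it's worth, the API's own config_util module looks like it could cover this. Below is a minimal, untested sketch of how the UI values might be pushed into an existing pipeline config; the helper name override_pipeline_config is mine, and it assumes the three config_util functions used here exist in the installed API version.)

# Hypothetical helper, not part of the tool below: rewrite num_steps / batch_size
# in an existing pipeline config via object_detection.utils.config_util.
from object_detection.utils import config_util

def override_pipeline_config(config_path, num_steps, batch_size, out_dir='.'):
    configs = config_util.get_configs_from_pipeline_file(config_path)
    configs['train_config'].num_steps = int(num_steps)    # step count entered in the UI
    configs['train_config'].batch_size = int(batch_size)  # batch size entered in the UI
    pipeline_proto = config_util.create_pipeline_proto_from_configs(configs)
    config_util.save_pipeline_config(pipeline_proto, out_dir)  # writes out_dir/pipeline.config

# e.g. override_pipeline_config('ssd_mobilenet_v2.config', numsteps, batchsize, './trainckpt')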
The tool still needs further polishing. The complete source code follows; the program is essentially just the individual Object Detection API scripts wrapped together.
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import hashlib
import io
import logging
import sys
import cv2
import os
import random
import PIL.Image
import tensorflow as tf
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import functools
import json
from lxml import etree
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtWidgets import *
from PyQt5.QtCore import *
from PyQt5.QtGui import *
from xml.dom.minidom import Document
from object_detection.utils import dataset_util
from object_detection.utils import label_map_util
from object_detection.utils import config_util
from object_detection.builders import dataset_builder
from object_detection.builders import graph_rewriter_builder
from object_detection.builders import model_builder
from object_detection.legacy import trainer
from google.protobuf import text_format
from object_detection import exporter
from object_detection.protos import pipeline_pb2
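# MyLabel: QLabel subclass used as the video canvas. It records the corners of a
# left-button drag and paints a green rectangle over the displayed frame, giving a
# labelImg-style rubber-band box for annotation.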
class MyLabel(QLabel):
x0 = 0
y0 = 0
x1 = 0
y1 = 0
flag = False
# Mouse press event: record the start corner of the box
def mousePressEvent(self,event):
global x_start
global y_start
if event.buttons () == QtCore.Qt.LeftButton:
self.flag = True
self.x0 = event.x()
self.y0 = event.y()
x_start = self.x0
y_start = self.y0
# Mouse release event: stop drawing
def mouseReleaseEvent(self,event):
self.flag = False
# Mouse move event: track the opposite corner while dragging
def mouseMoveEvent(self,event):
global x_end
global y_end
if event.buttons () == QtCore.Qt.LeftButton:
if self.flag:
self.x1 = event.x()
self.y1 = event.y()
self.update()
x_end = self.x1
y_end = self.y1
# Paint event: draw the current selection rectangle
def paintEvent(self, event):
super().paintEvent(event)
rect =QRect(self.x0, self.y0, abs(self.x1-self.x0), abs(self.y1-self.y0))
painter = QPainter(self)
painter.setPen(QPen(Qt.green,2,Qt.SolidLine))
painter.drawRect(rect)
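# Ui_train_window: main window. The top row holds the IP-camera credentials and the
# connect button, the centre shows the live frame (a MyLabel), the right-hand column
# holds the annotation controls, and the bottom row holds the training parameters
# and the training buttons.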
class Ui_train_window(QtWidgets.QWidget):
def setupUi(self):
self.setObjectName("train_window")
self.resize(690, 600)
self.setMinimumSize(QtCore.QSize(690, 600))
self.setMaximumSize(QtCore.QSize(690, 600))
self.horizontalLayoutWidget = QtWidgets.QWidget(self)
self.horizontalLayoutWidget.setGeometry(QtCore.QRect(0, 10, 681, 80))
self.horizontalLayoutWidget.setObjectName("horizontalLayoutWidget")
self.horizontalLayout = QtWidgets.QHBoxLayout(self.horizontalLayoutWidget)
self.horizontalLayout.setContentsMargins(0, 0, 0, 0)
self.horizontalLayout.setObjectName("horizontalLayout")
self.lab_enter_name = QtWidgets.QLabel(self.horizontalLayoutWidget)
self.lab_enter_name.setObjectName("lab_enter_name")
self.horizontalLayout.addWidget(self.lab_enter_name)
self.le_username = QtWidgets.QLineEdit(self.horizontalLayoutWidget)
self.le_username.setObjectName("le_username")
self.horizontalLayout.addWidget(self.le_username)
self.lab_enter_pw = QtWidgets.QLabel(self.horizontalLayoutWidget)
self.lab_enter_pw.setObjectName("lab_enter_pw")
self.horizontalLayout.addWidget(self.lab_enter_pw)
self.le_userpw = QtWidgets.QLineEdit(self.horizontalLayoutWidget)
self.le_userpw.setObjectName("le_userpw")
self.le_userpw.setEchoMode(QLineEdit.Password)
self.horizontalLayout.addWidget(self.le_userpw)
self.label = QtWidgets.QLabel(self.horizontalLayoutWidget)
self.label.setObjectName("label")
self.horizontalLayout.addWidget(self.label)
self.le_ipadr = QtWidgets.QLineEdit(self.horizontalLayoutWidget)
self.le_ipadr.setObjectName("le_ipadr")
self.horizontalLayout.addWidget(self.le_ipadr)
self.btn_openIPcam = QtWidgets.QPushButton(self.horizontalLayoutWidget)
self.btn_openIPcam.setMinimumSize(QtCore.QSize(50, 10))
self.btn_openIPcam.setObjectName("btn_openIPcam")
self.horizontalLayout.addWidget(self.btn_openIPcam)
self.showpic= MyLabel(self)
self.showpic.setGeometry(QtCore.QRect(15, 110, 500, 400))
self.showpic.setMinimumSize(QtCore.QSize(500, 400))
self.showpic.setMaximumSize(QtCore.QSize(500, 400))
self.showpic.setObjectName("show")
self.showpic.setStyleSheet(("border:2px solid lightgray"))
self.verticalLayoutWidget = QtWidgets.QWidget(self)
self.verticalLayoutWidget.setGeometry(QtCore.QRect(560, 100, 111, 381))
self.verticalLayoutWidget.setObjectName("verticalLayoutWidget")
self.verticalLayout = QtWidgets.QVBoxLayout(self.verticalLayoutWidget)
self.verticalLayout.setContentsMargins(0, 0, 0, 0)
self.verticalLayout.setObjectName("verticalLayout")
self.btn_start = QtWidgets.QPushButton(self.verticalLayoutWidget)
self.btn_start.setObjectName("btn_start")
self.verticalLayout.addWidget(self.btn_start)
self.lab_herelable = QtWidgets.QLabel(self.verticalLayoutWidget)
self.lab_herelable.setObjectName("lab_herelable")
self.lab_herelable.setMaximumSize(QtCore.QSize(16777215, 15))
self.verticalLayout.addWidget(self.lab_herelable)
self.combo_label = QtWidgets.QComboBox(self.verticalLayoutWidget)
self.combo_label.setObjectName("combo_label")
self.combo_label.addItem("")
self.combo_label.addItem("")
self.combo_label.addItem("")
self.combo_label.addItem("")
self.combo_label.addItem("")
self.combo_label.addItem("")
self.verticalLayout.addWidget(self.combo_label)
self.btn_save = QtWidgets.QPushButton(self.verticalLayoutWidget)
self.btn_save.setObjectName("btn_save")
self.verticalLayout.addWidget(self.btn_save)
self.btn_finish = QtWidgets.QPushButton(self.verticalLayoutWidget)
self.btn_finish.setObjectName("btn_finish")
self.verticalLayout.addWidget(self.btn_finish)
self.horizontalLayoutWidget_2 = QtWidgets.QWidget(self)
self.horizontalLayoutWidget_2.setGeometry(QtCore.QRect(0, 500, 681, 71))
self.horizontalLayoutWidget_2.setObjectName("horizontalLayoutWidget_2")
self.horizontalLayout_2 = QtWidgets.QHBoxLayout(self.horizontalLayoutWidget_2)
self.horizontalLayout_2.setContentsMargins(0, 0, 0, 0)
self.horizontalLayout_2.setObjectName("horizontalLayout_2")
self.lab_stepnum = QtWidgets.QLabel(self.horizontalLayoutWidget_2)
self.lab_stepnum.setObjectName("lab_stepnum")
self.horizontalLayout_2.addWidget(self.lab_stepnum)
self.le_stepnm = QtWidgets.QLineEdit(self.horizontalLayoutWidget_2)
self.le_stepnm.setObjectName("le_stepnm")
self.horizontalLayout_2.addWidget(self.le_stepnm)
self.lab_batchsize = QtWidgets.QLabel(self.horizontalLayoutWidget_2)
self.lab_batchsize.setObjectName("lab_batchsize")
self.horizontalLayout_2.addWidget(self.lab_batchsize)
self.le_batchsize = QtWidgets.QLineEdit(self.horizontalLayoutWidget_2)
self.le_batchsize.setObjectName("le_batchsize")
self.horizontalLayout_2.addWidget(self.le_batchsize)
self.btn_pretrain = QtWidgets.QPushButton(self.horizontalLayoutWidget_2)
self.btn_pretrain.setObjectName("btn_pretrain")
self.horizontalLayout_2.addWidget(self.btn_pretrain)
self.btn_retrain = QtWidgets.QPushButton(self.horizontalLayoutWidget_2)
self.btn_retrain.setObjectName("btn_retrain")
self.horizontalLayout_2.addWidget(self.btn_retrain)
self.showpic.raise_()
self.horizontalLayoutWidget.raise_()
self.verticalLayoutWidget.raise_()
self.horizontalLayoutWidget_2.raise_()
self.i = 0
self.labelname = []
self.retranslateUi(self)
self.btn_openIPcam.clicked.connect(self.IPinfo)
self.btn_start.clicked.connect(self.labledata)
self.btn_save.clicked.connect(self.save_lab)
self.btn_finish.clicked.connect(self.label_end)
self.btn_pretrain.clicked.connect(self.pretrain)
QtCore.QMetaObject.connectSlotsByName(self)
def retranslateUi(self, train_window):
_translate = QtCore.QCoreApplication.translate
self.setWindowTitle(_translate("train_window", "训练工具"))
self.lab_enter_name.setText(_translate("train_window", "输入IP摄像头用户名:"))
self.lab_enter_pw.setText(_translate("train_window", "输入IP摄像头密码:"))
self.label.setText(_translate("train_window", "IP地址"))
self.btn_openIPcam.setText(_translate("train_window", "连接IP摄像头"))
self.showpic.setText(_translate("train_window", ""))
self.btn_start.setText(_translate("train_window", "开始标记"))
self.lab_herelable.setText(_translate("train_window", " 请选择标签"))
self.combo_label.setItemText(0, _translate("train_window", "a"))
self.combo_label.setItemText(1, _translate("train_window", "b"))
self.combo_label.setItemText(2, _translate("train_window", "c"))
self.combo_label.setItemText(3, _translate("train_window", "d"))
self.combo_label.setItemText(4, _translate("train_window", "e"))
self.combo_label.setItemText(5, _translate("train_window", "f"))
self.btn_save.setText(_translate("train_window", "保存"))
self.btn_finish.setText(_translate("train_window", "标记结束"))
self.lab_stepnum.setText(_translate("train_window", "迭代次数:"))
self.lab_batchsize.setText(_translate("train_window", "batchsize:"))
self.btn_pretrain.setText(_translate("train_window", "预训练"))
self.btn_retrain.setText(_translate("train_window", "重新训练"))
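# IPinfo: read the camera credentials from the UI and open the video source. The
# RTSP URL for an IP camera is left commented out below, so as written the code
# falls back to the local USB camera (device 0). A QTimer then calls camshow()
# periodically to grab a frame, convert it from BGR to RGB and wrap it in a QImage.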
def IPinfo(self):
username = self.le_username.text()
password = self.le_userpw.text()
ipaddress = self.le_ipadr.text()
#cam_rtsp_addr = "rtsp://" + username + ":" + password + "@" + ipaddress + "/h264/ch33/main/av_stream"
#self.camcapture = cv2.VideoCapture(cam_rtsp_addr)
self.camcapture = cv2.VideoCapture(0)
self.timer = QtCore.QTimer()
self.timer.start()
self.timer.setInterval(30)  # QTimer intervals are integer milliseconds; 0.3 is not a valid interval
self.timer.timeout.connect(self.camshow)
print('Camera started')
def camshow(self):
global showImage
global camimg
global imghight
global imgwidth
global imgdepth
_ , camimg = self.camcapture.read()
camimg2 = cv2.cvtColor(camimg, cv2.COLOR_BGR2RGB)
imghight = camimg.shape[0]
imgwidth = camimg.shape[1]
imgdepth = camimg.shape[2]
showImage = QtGui.QImage(camimg2.data, camimg2.shape[1], camimg2.shape[0], QtGui.QImage.Format_RGB888)
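# labledata: freeze the current camera frame in the display label so a bounding box
# can be drawn on this exact image.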
def labledata(self):
self.showpic.setPixmap(QtGui.QPixmap.fromImage(showImage))
def save_lab(self):
global lb_names
global num_class
classname = self.combo_label.currentText()
classnum = self.combo_label.count()
self.labelname.append(classname)
lb_names = list(set(self.labelname))
print(lb_names)
num_class = len(lb_names)
# Save the original frame that was annotated
imgpath = os.path.join("./img/",classname + "000" + str(self.i) + ".jpg")
imgname = classname + "000" + str(self.i) + ".jpg"
cv2.imwrite(imgpath, camimg)
# Write the Pascal VOC-style XML annotation for this image
doc = Document()
annotation = doc.createElement('annotation')
doc.appendChild(annotation)
folder = doc.createElement('folder')
annotation.appendChild(folder)
folder_text = doc.createTextNode('data')
folder.appendChild(folder_text)
filename = doc.createElement('filename')
annotation.appendChild(filename)
filename_text = doc.createTextNode(imgname)
filename.appendChild(filename_text)
path = doc.createElement('path')
annotation.appendChild(path)
path_text = doc.createTextNode(imgpath)
path.appendChild(path_text)
source = doc.createElement('source')
annotation.appendChild(source)
database = doc.createElement('database')
source.appendChild(database)
database_text = doc.createTextNode('Unknown')
database.appendChild(database_text)
size = doc.createElement('size')
annotation.appendChild(size)
width = doc.createElement('width')
height = doc.createElement('height')
depth = doc.createElement('depth')
width_text = doc.createTextNode(str(imgwidth))
height_text = doc.createTextNode(str(imghight))
depth_text = doc.createTextNode(str(imgdepth))
size.appendChild(width)
size.appendChild(height)
size.appendChild(depth)
width.appendChild(width_text)
height.appendChild(height_text)
depth.appendChild(depth_text)
segmented = doc.createElement('segmented')
annotation.appendChild(segmented)
segmented_text = doc.createTextNode('0')
segmented.appendChild(segmented_text)
objects = doc.createElement('object')
annotation.appendChild(objects)
name = doc.createElement('name')
pose = doc.createElement('pose')
truncated = doc.createElement('truncated')
difficult = doc.createElement('difficult')
name_text = doc.createTextNode(classname)
pose_text = doc.createTextNode('Unspecified')
truncated_text = doc.createTextNode('0')
difficult_text = doc.createTextNode('0')
objects.appendChild(name)
objects.appendChild(pose)
objects.appendChild(truncated)
objects.appendChild(difficult)
name.appendChild(name_text)
pose.appendChild(pose_text)
truncated.appendChild(truncated_text)
difficult.appendChild(difficult_text)
bndbox = doc.createElement('bndbox')
objects.appendChild(bndbox)
xmin = doc.createElement('xmin')
ymin = doc.createElement('ymin')
xmax = doc.createElement('xmax')
ymax = doc.createElement('ymax')
xmin_text = doc.createTextNode(str(x_start))
ymin_text = doc.createTextNode(str(y_start))
xmax_text = doc.createTextNode(str(x_end))
ymax_text = doc.createTextNode(str(y_end))
bndbox.appendChild(xmin)
bndbox.appendChild(ymin)
bndbox.appendChild(xmax)
bndbox.appendChild(ymax)
xmin.appendChild(xmin_text)
ymin.appendChild(ymin_text)
xmax.appendChild(xmax_text)
ymax.appendChild(ymax_text)
f = open(os.path.join("./annotations/",classname + "000" + str(self.i) + ".xml"),'w')
doc.writexml(f,indent = '\t',newl = '\n', addindent = '\t',encoding='utf-8')
f.close()
self.i = self.i+1
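# label_end: finish annotation. Splits the XML files into train/val/test lists,
# writes every distinct class into label_map.pbtxt, and converts the train and val
# splits into TFRecord files under ./data/.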
def label_end(self):
# Split the annotations into train/val/test lists (txt files)
trainval_percent = 0.8
train_percent = 0.7
xmlfilepath = './annotations/'
txtsavepath = './data/'
total_xml = os.listdir(xmlfilepath)
num=len(total_xml)
tv = int(num * trainval_percent)  # train + val
tr = int(tv * train_percent)  # train only
trainval= random.sample(range(num),tv)
train=random.sample(trainval,tr)
ftrainval = open(os.path.join(txtsavepath,'trainval.txt'), 'w')
ftest = open(os.path.join(txtsavepath,'test.txt'), 'w')
ftrain = open(os.path.join(txtsavepath,'train.txt'), 'w')
fval = open(os.path.join(txtsavepath,'val.txt'), 'w')
for i in range(num):
name=total_xml[i][:-4]+'\n'
if i in trainval:
ftrainval.write(name)
if i in train:
ftrain.write(name)
else:
fval.write(name)
else:
ftest.write(name)
ftrainval.close()
ftrain.close()
fval.close()
ftest.close()
# Auto-generate the label_map.pbtxt file.
# (One item used to go missing each run because the file was never closed; closing it after the loop fixes that.)
j = 1
for lb_name in lb_names:
print(lb_name)
pbtxtfile = open('./label_map.pbtxt',mode = 'a')
pbtxtfile.write('item{' + '\n')
pbtxtfile.write(' id: ' + str(j) + '\n')
pbtxtfile.write(' name: ' + "'" + lb_name + "'" + '\n')
pbtxtfile.write('}' + '\n')
j += 1
pbtxtfile.close()
# Convert the annotated data into TFRecord files
def dict_to_tf_example(data,
dataset_directory,
label_map_dict,
ignore_difficult_instances=False,
image_subdirectory='JPEGImages'):
img_path = os.path.join(dataset_directory, data['filename'])
full_path = img_path
with tf.gfile.GFile(full_path, 'rb') as fid:
encoded_jpg = fid.read()
encoded_jpg_io = io.BytesIO(encoded_jpg)
image = PIL.Image.open(encoded_jpg_io)
if image.format != 'JPEG':
raise ValueError('Image format not JPEG')
key = hashlib.sha256(encoded_jpg).hexdigest()
width = int(data['size']['width'])
height = int(data['size']['height'])
xmin = []
ymin = []
xmax = []
ymax = []
classes = []
classes_text = []
truncated = []
poses = []
difficult_obj = []
for obj in data['object']:
difficult = bool(int(obj['difficult']))
if ignore_difficult_instances and difficult:
continue
difficult_obj.append(int(difficult))
xmin.append(float(obj['bndbox']['xmin']) / width)
ymin.append(float(obj['bndbox']['ymin']) / height)
xmax.append(float(obj['bndbox']['xmax']) / width)
ymax.append(float(obj['bndbox']['ymax']) / height)
classes_text.append(obj['name'].encode('utf8'))
classes.append(label_map_dict[obj['name']])
truncated.append(int(obj['truncated']))
poses.append(obj['pose'].encode('utf8'))
example = tf.train.Example(features=tf.train.Features(feature={
'image/height': dataset_util.int64_feature(height),
'image/width': dataset_util.int64_feature(width),
'image/filename': dataset_util.bytes_feature(
data['filename'].encode('utf8')),
'image/source_id': dataset_util.bytes_feature(
data['filename'].encode('utf8')),
'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
'image/encoded': dataset_util.bytes_feature(encoded_jpg),
'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
'image/object/class/label': dataset_util.int64_list_feature(classes),
'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
'image/object/truncated': dataset_util.int64_list_feature(truncated),
'image/object/view': dataset_util.bytes_list_feature(poses),
}))
return example
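# trainrc / valrc: iterate over train.txt / val.txt, parse each Pascal VOC XML with
# lxml, and write train.record / val.record (adapted from the API's
# create_pascal_tf_record.py).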
def trainrc():
data_dir = './img'
years = ['VOC2007']
writer = tf.python_io.TFRecordWriter(os.path.join('./data/','train.record'))
label_map_dict = label_map_util.get_label_map_dict('./label_map.pbtxt')
for year in years:
logging.info('Reading from PASCAL %s dataset.', year)
examples_path = './data/train.txt'
annotations_dir = './annotations/'
examples_list = dataset_util.read_examples_list(examples_path)
for idx, example in enumerate(examples_list):
if idx % 100 == 0:
logging.info('On image %d of %d', idx, len(examples_list))
path = os.path.join(annotations_dir, example + '.xml')
with tf.gfile.GFile(path, 'r') as fid:
xml_str = fid.read()
xml = etree.fromstring(xml_str.encode('utf-8'))
data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']
tf_example = dict_to_tf_example(data, data_dir, label_map_dict,False)
writer.write(tf_example.SerializeToString())
writer.close()
trainrc()
def valrc():
data_dir = './img'
years = ['VOC2007']
writer = tf.python_io.TFRecordWriter(os.path.join('./data/','val.record'))
label_map_dict = label_map_util.get_label_map_dict('./label_map.pbtxt')
for year in years:
logging.info('Reading from PASCAL %s dataset.', year)
examples_path = './data/val.txt'
annotations_dir = './annotations/'
examples_list = dataset_util.read_examples_list(examples_path)
for idx, example in enumerate(examples_list):
if idx % 100 == 0:
logging.info('On image %d of %d', idx, len(examples_list))
path = os.path.join(annotations_dir, example + '.xml')
with tf.gfile.GFile(path, 'r') as fid:
xml_str = fid.read()
xml = etree.fromstring(xml_str.encode('utf-8'))
data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']
tf_example = dict_to_tf_example(data, data_dir, label_map_dict,False)
writer.write(tf_example.SerializeToString())
writer.close()
valrc()
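# pretrain: write a complete ssd_mobilenet_v2 pipeline config using the step count,
# batch size and class count gathered from the UI, then run training against the
# ssd_mobilenet_v2_coco checkpoint (essentially legacy/train.py and
# export_inference_graph.py folded into one method).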
def pretrain(self):
batchsize = self.le_batchsize.text()
numsteps = self.le_stepnm.text()
file_config = open('ssd_mobilenet_v2.config' ,mode = 'w')
file_config.write('model {\n')
file_config.write(' ssd {\n')
file_config.write(' num_classes:' + str(num_class) + '\n')
file_config.write(' box_coder {\n')
file_config.write(' faster_rcnn_box_coder {\n')
file_config.write(' y_scale: 10.0\n')
file_config.write(' x_scale: 10.0\n')
file_config.write('        height_scale: 5.0\n')  # height_scale was missing; the stock ssd_mobilenet_v2 config uses 5.0
file_config.write('        width_scale: 5.0\n')
file_config.write(' }\n')
file_config.write(' }\n')
file_config.write(' matcher {\n')
file_config.write(' argmax_matcher {\n')
file_config.write(' matched_threshold: 0.5\n')
file_config.write(' unmatched_threshold: 0.5\n')
file_config.write(' ignore_thresholds: false\n')
file_config.write(' negatives_lower_than_unmatched: true\n')
file_config.write(' force_match_for_each_row: true\n')
file_config.write(' }\n')
file_config.write(' }\n')
file_config.write(' similarity_calculator {\n')
file_config.write(' iou_similarity {\n')
file_config.write(' }\n')
file_config.write(' }\n')
file_config.write(' anchor_generator {\n')
file_config.write(' ssd_anchor_generator {\n')
file_config.write(' num_layers: 6\n')
file_config.write(' min_scale: 0.2\n')
file_config.write(' max_scale: 0.95\n')
file_config.write(' aspect_ratios: 1.0\n')
file_config.write(' aspect_ratios: 2.0\n')
file_config.write(' aspect_ratios: 0.5\n')
file_config.write(' aspect_ratios: 3.0\n')
file_config.write(' aspect_ratios: 0.3333\n')
file_config.write(' }\n')
file_config.write(' }\n')
file_config.write(' image_resizer {\n')
file_config.write(' fixed_shape_resizer {\n')
file_config.write(' height: 300\n')
file_config.write(' width: 300\n')
file_config.write(' }\n')
file_config.write(' }\n')
file_config.write(' box_predictor {\n')
file_config.write(' convolutional_box_predictor {\n')
file_config.write(' min_depth: 0\n')
file_config.write(' max_depth: 0\n')
file_config.write(' num_layers_before_predictor: 0\n')
file_config.write(' use_dropout: false\n')
file_config.write(' dropout_keep_probability: 0.8\n')
file_config.write(' kernel_size: 1\n')
file_config.write(' box_code_size: 4\n')
file_config.write(' apply_sigmoid_to_scores: false\n')
file_config.write(' conv_hyperparams {\n')
file_config.write(' activation: RELU_6,\n')
file_config.write(' regularizer {\n')
file_config.write(' l2_regularizer {\n')
file_config.write(' weight: 0.00004\n')
file_config.write(' }\n')
file_config.write(' }\n')
file_config.write(' initializer {\n')
file_config.write(' truncated_normal_initializer {\n')
file_config.write(' stddev: 0.03\n')
file_config.write(' mean: 0.0\n')
file_config.write(' }\n')
file_config.write(' }\n')
file_config.write(' batch_norm {\n')
file_config.write(' train: true,\n')
file_config.write(' scale: true,\n')
file_config.write(' center: true,\n')
file_config.write(' decay: 0.9997,\n')
file_config.write(' epsilon: 0.001,\n')
file_config.write(' }\n')
file_config.write(' }\n')
file_config.write(' }\n')
file_config.write(' }\n')
file_config.write(' feature_extractor {\n')
file_config.write(" type: 'ssd_mobilenet_v2'\n")
file_config.write(' min_depth: 16\n')
file_config.write(' depth_multiplier: 1.0\n')
file_config.write(' conv_hyperparams {\n')
file_config.write(' activation: RELU_6,\n')
file_config.write(' regularizer {\n')
file_config.write(' l2_regularizer {\n')
file_config.write(' weight: 0.00004\n')
file_config.write(' }\n')
file_config.write(' }\n')
file_config.write(' initializer {\n')
file_config.write(' truncated_normal_initializer {\n')
file_config.write(' stddev: 0.03\n')
file_config.write(' mean: 0.0\n')
file_config.write(' }\n')
file_config.write(' }\n')
file_config.write(' batch_norm {\n')
file_config.write(' train: true,\n')
file_config.write(' scale: true,\n')
file_config.write(' center: true,\n')
file_config.write(' decay: 0.9997,\n')
file_config.write(' epsilon: 0.001,\n')
file_config.write(' }\n')
file_config.write(' }\n')
file_config.write(' }\n')
file_config.write(' loss {\n')
file_config.write(' classification_loss {\n')
file_config.write(' weighted_sigmoid {\n')
file_config.write(' }\n')
file_config.write(' }\n')
file_config.write(' localization_loss {\n')
file_config.write(' weighted_smooth_l1 {\n')
file_config.write(' }\n')
file_config.write(' }\n')
file_config.write(' hard_example_miner {\n')
file_config.write(' num_hard_examples: 3000\n')
file_config.write(' iou_threshold: 0.99\n')
file_config.write(' loss_type: CLASSIFICATION\n')
file_config.write(' max_negatives_per_positive: 3\n')
file_config.write(' min_negatives_per_image: 3\n')
file_config.write(' }\n')
file_config.write(' classification_weight: 1.0\n')
file_config.write(' localization_weight: 1.0\n')
file_config.write(' }\n')
file_config.write(' normalize_loss_by_num_matches: true\n')
file_config.write(' post_processing {\n')
file_config.write(' batch_non_max_suppression {\n')
file_config.write(' score_threshold: 1e-8\n')
file_config.write(' iou_threshold: 0.6\n')
file_config.write(' max_detections_per_class: 100\n')
file_config.write(' max_total_detections: 100\n')
file_config.write(' }\n')
file_config.write(' score_converter: SIGMOID\n')
file_config.write(' }\n')
file_config.write(' }\n')
file_config.write('}\n')
file_config.write('train_config: {\n')
file_config.write(' batch_size: '+ str(batchsize) + '\n')
file_config.write(' optimizer {\n')
file_config.write(' rms_prop_optimizer: {\n')
file_config.write(' learning_rate: {\n')
file_config.write(' exponential_decay_learning_rate {\n')
file_config.write(' initial_learning_rate: 0.004\n')
file_config.write(' decay_steps: 800720\n')
file_config.write(' decay_factor: 0.95\n')
file_config.write(' }\n')
file_config.write(' }\n')
file_config.write(' momentum_optimizer_value: 0.9\n')
file_config.write(' decay: 0.9\n')
file_config.write(' epsilon: 1.0\n')
file_config.write(' }\n')
file_config.write(' }\n')
file_config.write(' fine_tune_checkpoint: "./ssd_mobilenet_v2_coco/model.ckpt"\n')
file_config.write(' fine_tune_checkpoint_type: "detection"\n')
file_config.write(' num_steps: ' + str(numsteps) + '\n')
file_config.write(' data_augmentation_options {\n')
file_config.write(' random_horizontal_flip {\n')
file_config.write(' }\n')
file_config.write(' }\n')
file_config.write(' data_augmentation_options {\n')
file_config.write(' ssd_random_crop {\n')
file_config.write(' }\n')
file_config.write(' }\n')
file_config.write('}\n')
file_config.write('train_input_reader: {\n')
file_config.write(' tf_record_input_reader {\n')
file_config.write(' input_path: "./data/train.record"\n')
file_config.write(' }\n')
file_config.write(' label_map_path: "./label_map.pbtxt"\n')
file_config.write('}\n')
file_config.write('eval_config: {\n')
file_config.write(' num_examples: 2000\n')
file_config.write(' max_evals: 10\n')
file_config.write('}\n')
file_config.write('eval_input_reader: {\n')
file_config.write(' tf_record_input_reader {\n')
file_config.write(' input_path: "./data/val.record"\n')
file_config.write(' }\n')
file_config.write('  label_map_path: "./label_map.pbtxt"\n')
file_config.write(' shuffle: false\n')
file_config.write(' num_readers: 1\n')
file_config.write('}\n')
file_config.close()
# The code below runs the actual training (a repackaged object_detection/legacy/train.py)
@tf.contrib.framework.deprecated(None, 'Use object_detection/model_main.py.')
def trainckpt():
train_dir = './trainckpt'
pipeline_config_path = './ssd_mobilenet_v2.config'
task = 0
assert train_dir, '`train_dir` is missing.'
if task == 0: tf.gfile.MakeDirs(train_dir)
if pipeline_config_path:
configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
if task == 0:
tf.gfile.Copy(pipeline_config_path,os.path.join(train_dir, 'pipeline.config'),overwrite=True)
else:
configs = config_util.get_configs_from_multiple_files(model_config_path='',train_config_path='',train_input_config_path='')
if task == 0:
for name, config in [('model.config', ''),('train.config', ''),('input.config', '')]:
tf.gfile.Copy(config, os.path.join(train_dir, name),overwrite=True)
model_config = configs['model']
train_config = configs['train_config']
input_config = configs['train_input_config']
model_fn = functools.partial(model_builder.build,model_config=model_config,is_training=True)
def get_next(config):
return dataset_builder.make_initializable_iterator(dataset_builder.build(config)).get_next()
create_input_dict_fn = functools.partial(get_next, input_config)
env = json.loads(os.environ.get('TF_CONFIG', '{}'))
cluster_data = env.get('cluster', None)
cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None
task_data = env.get('task', None) or {'type': 'master', 'index': 0}
task_info = type('TaskSpec', (object,), task_data)
ps_tasks = 0
worker_replicas = 1
worker_job_name = 'lonely_worker'
task = 0
is_chief = True
master = ''
if cluster_data and 'worker' in cluster_data:
worker_replicas = len(cluster_data['worker']) + 1
if cluster_data and 'ps' in cluster_data:
ps_tasks = len(cluster_data['ps'])
if worker_replicas > 1 and ps_tasks < 1:
raise ValueError('At least 1 ps task is needed for distributed training.')
if worker_replicas >= 1 and ps_tasks > 0:
server = tf.train.Server(tf.train.ClusterSpec(cluster), protocol='grpc',job_name=task_info.type, task_index=task_info.index)
if task_info.type == 'ps':
server.join()
return
worker_job_name = '%s/task:%d' % (task_info.type, task_info.index)
task = task_info.index
is_chief = (task_info.type == 'master')
master = server.target
graph_rewriter_fn = None
if 'graph_rewriter_config' in configs:
graph_rewriter_fn = graph_rewriter_builder.build(configs['graph_rewriter_config'], is_training=True)
trainer.train(
create_input_dict_fn,
model_fn,
train_config,
master,
task,
1,
worker_replicas,
False,
ps_tasks,
worker_job_name,
is_chief,
train_dir,
graph_hook_fn=graph_rewriter_fn)
trainckpt()
# Export the trained checkpoint to a frozen .pb inference graph (repackaged export_inference_graph.py)
def genpb(_):
pipeline_config_path = './ssd_mobilenet_v2.config'
pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
with tf.gfile.GFile(pipeline_config_path, 'r') as f:
text_format.Merge(f.read(), pipeline_config)
# No config override and no fixed input shape are needed here, so the
# corresponding flag-handling branches from export_inference_graph.py are dropped.
input_shape = None
exporter.export_inference_graph(
'image_tensor', pipeline_config,'./trainckpt/model.ckpt-' + str(numsteps),
'./trainckpt/save', input_shape=input_shape,
write_inference_graph=False)
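# Note: genpb is only defined above and is never invoked; it would have to be called
# (e.g. genpb(None)) after training finishes for the frozen .pb graph to actually be
# exported to ./trainckpt/save.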
if __name__ == "__main__":
app = QtWidgets.QApplication(sys.argv)
ui = Ui_train_window()
ui.setupUi()
ui.show()
sys.exit(app.exec_())
The prerequisite is that TensorFlow, the Object Detection API, Python 3.x and the rest of the environment are already set up.
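For reference, the working directory the script expects (inferred from the paths hard-coded in the source) looks roughly like this; the img, annotations and data folders are never created by the script, so they have to exist beforehand:

your_script.py              # this tool (any file name)
label_map.pbtxt             # generated while annotating
ssd_mobilenet_v2.config     # generated by the pretrain button
img/                        # captured frames (.jpg)
annotations/                # Pascal VOC XML files
data/                       # train/val/test txt lists, train.record, val.record
trainckpt/                  # training checkpoints and the exported model (trainckpt/save)
ssd_mobilenet_v2_coco/      # pretrained checkpoint (model.ckpt.*)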
Best enjoyed together with the recognition tool from the earlier notes~