[Study Notes] PyQt5 Learning Notes (5): A Training Tool for the Google Object Detection API

The previous notes covered running inference with an already-trained model, both on locally loaded images and on a live USB camera feed (IP cameras were not usable yet). But training has to happen first before there is a trained model file to use, and the training workflow is fairly involved: it chains several scripts together and is awkward to get started with. Building a dedicated training tool is therefore partly for my own convenience and partly good practice. The finished tool streams frames from an IP or USB camera into the window and lets you draw bounding boxes directly on the live image (much like labelImg). When labeling ends, it saves the standard XML files and the original images under the script's directory: the XML goes into the annotations folder and the images into the img folder. It then automatically splits the data into training and validation sets and generates the corresponding TFRecord files, all stored under the data folder. A label map file is also generated automatically during labeling.
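For orientation, this is the directory layout the tool reads and writes, taken from the paths hard-coded in the source below; img, annotations and data are assumed to exist beforehand (only trainckpt is created by the script itself):

./
├── img/                    captured frames (JPEG)
├── annotations/            one Pascal VOC XML file per frame
├── data/                   train/val/test lists plus train.record and val.record
├── ssd_mobilenet_v2_coco/  pretrained checkpoint (model.ckpt.*)
├── trainckpt/              training checkpoints; exported model under trainckpt/save
└── label_map.pbtxt         auto-generated label map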

For now the tool has to load a pretrained model to train from (the Retrain button does nothing yet); it uses the ssd_mobilenet_v2_coco pretrained checkpoint (the one the stock API configs reference). The tool covers the whole pipeline: image capture -> labeling -> data generation -> training -> exporting the .pb model file.
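If you do not have that checkpoint yet, a sketch along these lines fetches and unpacks it. The URL assumes the commonly used 2018_03_29 snapshot from the TensorFlow detection model zoo; rename the extracted folder (or repoint fine_tune_checkpoint in the config) to match:

import tarfile
import urllib.request

# Assumed snapshot name; other ssd_mobilenet_v2_coco releases work the same way.
URL = ('http://download.tensorflow.org/models/object_detection/'
       'ssd_mobilenet_v2_coco_2018_03_29.tar.gz')

fname, _ = urllib.request.urlretrieve(URL, 'ssd_mobilenet_v2_coco.tar.gz')
with tarfile.open(fname) as tar:
	tar.extractall('.')  # yields ssd_mobilenet_v2_coco_2018_03_29/model.ckpt.*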

One open issue: I want the tool to edit the pipeline config file automatically, with the user setting the number of training steps and the batch size in the UI, but that is not working yet; the current pretrain() below sidesteps it by rewriting the whole config from scratch.
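For the record, a minimal sketch of that feature (assuming the generated ssd_mobilenet_v2.config from pretrain() below serves as the template) would parse the pipeline proto, patch the two fields, and write it back:

from google.protobuf import text_format
from object_detection.protos import pipeline_pb2

def patch_pipeline_config(path, batch_size, num_steps):
	# Load the existing config, override the two UI-controlled fields, save.
	config = pipeline_pb2.TrainEvalPipelineConfig()
	with open(path, 'r') as f:
		text_format.Merge(f.read(), config)
	config.train_config.batch_size = batch_size
	config.train_config.num_steps = num_steps
	with open(path, 'w') as f:
		f.write(text_format.MessageToString(config))

# e.g. patch_pipeline_config('ssd_mobilenet_v2.config', 24, 20000)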

The tool still needs further polishing. The full source is below; it is really just the individual Object Detection API scripts wrapped into a single program.

# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import hashlib
import io
import logging
import sys
import cv2
import os
import random
import PIL.Image
import tensorflow as tf
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import functools
import json
from lxml import etree
from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtWidgets import *
from PyQt5.QtCore import *
from PyQt5.QtGui import *
from xml.dom.minidom import Document
from object_detection.utils import dataset_util
from object_detection.utils import label_map_util
from object_detection.utils import config_util
from object_detection.builders import dataset_builder
from object_detection.builders import graph_rewriter_builder
from object_detection.builders import model_builder
from object_detection.legacy import trainer
from google.protobuf import text_format
from object_detection import exporter
from object_detection.protos import pipeline_pb2

class MyLabel(QLabel):
	x0 = 0
	y0 = 0
	x1 = 0
	y1 = 0
	flag = False
	# mouse press event
	def mousePressEvent(self,event):
		global x_start
		global y_start
		if event.buttons () == QtCore.Qt.LeftButton:
			self.flag = True
			self.x0 = event.x()
			self.y0 = event.y()
			x_start = self.x0
			y_start = self.y0

	# mouse release event
	def mouseReleaseEvent(self, event):
		self.flag = False

	# mouse move event
	def mouseMoveEvent(self,event):
		global x_end
		global y_end
		if event.buttons () == QtCore.Qt.LeftButton:
			if self.flag:
				self.x1 = event.x()
				self.y1 = event.y()
				self.update()
				x_end = self.x1
				y_end = self.y1
	# paint event: draw the rubber-band rectangle over the frame
	def paintEvent(self, event):
		super().paintEvent(event)
		# normalized() keeps the rectangle correct when dragging in any
		# direction, not only towards the bottom-right
		rect = QRect(QPoint(self.x0, self.y0), QPoint(self.x1, self.y1)).normalized()
		painter = QPainter(self)
		painter.setPen(QPen(Qt.green, 2, Qt.SolidLine))
		painter.drawRect(rect)


class Ui_train_window(QtWidgets.QWidget):
	def setupUi(self):
		self.setObjectName("train_window")
		self.resize(690, 600)
		self.setMinimumSize(QtCore.QSize(690, 600))
		self.setMaximumSize(QtCore.QSize(690, 600))
		self.horizontalLayoutWidget = QtWidgets.QWidget(self)
		self.horizontalLayoutWidget.setGeometry(QtCore.QRect(0, 10, 681, 80))
		self.horizontalLayoutWidget.setObjectName("horizontalLayoutWidget")
		self.horizontalLayout = QtWidgets.QHBoxLayout(self.horizontalLayoutWidget)
		self.horizontalLayout.setContentsMargins(0, 0, 0, 0)
		self.horizontalLayout.setObjectName("horizontalLayout")
		self.lab_enter_name = QtWidgets.QLabel(self.horizontalLayoutWidget)
		self.lab_enter_name.setObjectName("lab_enter_name")
		self.horizontalLayout.addWidget(self.lab_enter_name)
		self.le_username = QtWidgets.QLineEdit(self.horizontalLayoutWidget)
		self.le_username.setObjectName("le_username")
		self.horizontalLayout.addWidget(self.le_username)
		self.lab_enter_pw = QtWidgets.QLabel(self.horizontalLayoutWidget)
		self.lab_enter_pw.setObjectName("lab_enter_pw")
		self.horizontalLayout.addWidget(self.lab_enter_pw)
		self.le_userpw = QtWidgets.QLineEdit(self.horizontalLayoutWidget)
		self.le_userpw.setObjectName("le_userpw")
		self.le_userpw.setEchoMode(QLineEdit.Password)
		self.horizontalLayout.addWidget(self.le_userpw)
		self.label = QtWidgets.QLabel(self.horizontalLayoutWidget)
		self.label.setObjectName("label")
		self.horizontalLayout.addWidget(self.label)
		self.le_ipadr = QtWidgets.QLineEdit(self.horizontalLayoutWidget)
		self.le_ipadr.setObjectName("le_ipadr")
		self.horizontalLayout.addWidget(self.le_ipadr)
		self.btn_openIPcam = QtWidgets.QPushButton(self.horizontalLayoutWidget)
		self.btn_openIPcam.setMinimumSize(QtCore.QSize(50, 10))
		self.btn_openIPcam.setObjectName("btn_openIPcam")
		self.horizontalLayout.addWidget(self.btn_openIPcam)
		
		self.showpic= MyLabel(self)
		self.showpic.setGeometry(QtCore.QRect(15, 110, 500, 400))
		self.showpic.setMinimumSize(QtCore.QSize(500, 400))
		self.showpic.setMaximumSize(QtCore.QSize(500, 400))
		self.showpic.setObjectName("show")
		self.showpic.setStyleSheet(("border:2px solid lightgray"))
		
		self.verticalLayoutWidget = QtWidgets.QWidget(self)
		self.verticalLayoutWidget.setGeometry(QtCore.QRect(560, 100, 111, 381))
		self.verticalLayoutWidget.setObjectName("verticalLayoutWidget")
		self.verticalLayout = QtWidgets.QVBoxLayout(self.verticalLayoutWidget)
		self.verticalLayout.setContentsMargins(0, 0, 0, 0)
		self.verticalLayout.setObjectName("verticalLayout")
		
		self.btn_start = QtWidgets.QPushButton(self.verticalLayoutWidget)
		self.btn_start.setObjectName("btn_start")
		self.verticalLayout.addWidget(self.btn_start)
		
		self.lab_herelable = QtWidgets.QLabel(self.verticalLayoutWidget)
		self.lab_herelable.setObjectName("lab_herelable")
		self.lab_herelable.setMaximumSize(QtCore.QSize(16777215, 15))
		self.verticalLayout.addWidget(self.lab_herelable)
		
		self.combo_label = QtWidgets.QComboBox(self.verticalLayoutWidget)
		self.combo_label.setObjectName("combo_label")
		self.combo_label.addItem("")
		self.combo_label.addItem("")
		self.combo_label.addItem("")
		self.combo_label.addItem("")
		self.combo_label.addItem("")
		self.combo_label.addItem("")
		self.verticalLayout.addWidget(self.combo_label)
		self.btn_save = QtWidgets.QPushButton(self.verticalLayoutWidget)
		self.btn_save.setObjectName("btn_save")
		self.verticalLayout.addWidget(self.btn_save)
		self.btn_finish = QtWidgets.QPushButton(self.verticalLayoutWidget)
		self.btn_finish.setObjectName("btn_finish")
		self.verticalLayout.addWidget(self.btn_finish)
		self.horizontalLayoutWidget_2 = QtWidgets.QWidget(self)
		self.horizontalLayoutWidget_2.setGeometry(QtCore.QRect(0, 500, 681, 71))
		self.horizontalLayoutWidget_2.setObjectName("horizontalLayoutWidget_2")
		self.horizontalLayout_2 = QtWidgets.QHBoxLayout(self.horizontalLayoutWidget_2)
		self.horizontalLayout_2.setContentsMargins(0, 0, 0, 0)
		self.horizontalLayout_2.setObjectName("horizontalLayout_2")
		self.lab_stepnum = QtWidgets.QLabel(self.horizontalLayoutWidget_2)
		self.lab_stepnum.setObjectName("lab_stepnum")
		self.horizontalLayout_2.addWidget(self.lab_stepnum)
		self.le_stepnm = QtWidgets.QLineEdit(self.horizontalLayoutWidget_2)
		self.le_stepnm.setObjectName("le_stepnm")
		self.horizontalLayout_2.addWidget(self.le_stepnm)
		self.lab_batchsize = QtWidgets.QLabel(self.horizontalLayoutWidget_2)
		self.lab_batchsize.setObjectName("lab_batchsize")
		self.horizontalLayout_2.addWidget(self.lab_batchsize)
		self.le_batchsize = QtWidgets.QLineEdit(self.horizontalLayoutWidget_2)
		self.le_batchsize.setObjectName("le_batchsize")
		self.horizontalLayout_2.addWidget(self.le_batchsize)
		self.btn_pretrain = QtWidgets.QPushButton(self.horizontalLayoutWidget_2)
		self.btn_pretrain.setObjectName("btn_pretrain")
		self.horizontalLayout_2.addWidget(self.btn_pretrain)
		self.btn_retrain = QtWidgets.QPushButton(self.horizontalLayoutWidget_2)
		self.btn_retrain.setObjectName("btn_retrain")
		self.horizontalLayout_2.addWidget(self.btn_retrain)
		self.showpic.raise_()
		self.horizontalLayoutWidget.raise_()
		self.verticalLayoutWidget.raise_()
		self.horizontalLayoutWidget_2.raise_()

		self.i = 0
		self.labelname = []

		self.retranslateUi(self)
		self.btn_openIPcam.clicked.connect(self.IPinfo)
		self.btn_start.clicked.connect(self.label_data)
		self.btn_save.clicked.connect(self.save_lab)
		self.btn_finish.clicked.connect(self.label_end)
		self.btn_pretrain.clicked.connect(self.pretrain)
		QtCore.QMetaObject.connectSlotsByName(self)

	def retranslateUi(self, train_window):
		_translate = QtCore.QCoreApplication.translate
		self.setWindowTitle(_translate("train_window", "Training Tool"))
		self.lab_enter_name.setText(_translate("train_window", "IP camera username:"))
		self.lab_enter_pw.setText(_translate("train_window", "IP camera password:"))
		self.label.setText(_translate("train_window", "IP address"))
		self.btn_openIPcam.setText(_translate("train_window", "Connect IP camera"))
		self.showpic.setText(_translate("train_window", ""))
		self.btn_start.setText(_translate("train_window", "Start labeling"))
		self.lab_herelable.setText(_translate("train_window", "    Select a label"))
		self.combo_label.setItemText(0, _translate("train_window", "a"))
		self.combo_label.setItemText(1, _translate("train_window", "b"))
		self.combo_label.setItemText(2, _translate("train_window", "c"))
		self.combo_label.setItemText(3, _translate("train_window", "d"))
		self.combo_label.setItemText(4, _translate("train_window", "e"))
		self.combo_label.setItemText(5, _translate("train_window", "f"))
		self.btn_save.setText(_translate("train_window", "Save"))
		self.btn_finish.setText(_translate("train_window", "Finish labeling"))
		self.lab_stepnum.setText(_translate("train_window", "Training steps:"))
		self.lab_batchsize.setText(_translate("train_window", "batch size:"))
		self.btn_pretrain.setText(_translate("train_window", "Pretrain"))
		self.btn_retrain.setText(_translate("train_window", "Retrain"))

	def IPinfo(self):
		username = self.le_username.text()
		password = self.le_userpw.text()
		ipaddress = self.le_ipadr.text()
		# The RTSP URL below is for an IP camera stream; it is disabled for
		# now, so the local USB camera (index 0) is used instead.
		#cam_rtsp_addr = "rtsp://" + username + ":" + password + "@" + ipaddress + "/h264/ch33/main/av_stream"
		#self.camcapture = cv2.VideoCapture(cam_rtsp_addr)
		self.camcapture = cv2.VideoCapture(0)
		self.timer = QtCore.QTimer()
		self.timer.setInterval(30)  # QTimer intervals are in milliseconds (~33 fps)
		self.timer.timeout.connect(self.camshow)
		self.timer.start()
		print('Camera started')

	def camshow(self):
		global showImage
		global camimg
		global imghight
		global imgwidth
		global imgdepth
		ret, camimg = self.camcapture.read()
		if not ret:
			# drop the frame if the grab failed (e.g. camera not ready)
			return
		camimg2 = cv2.cvtColor(camimg, cv2.COLOR_BGR2RGB)
		imghight = camimg.shape[0]
		imgwidth = camimg.shape[1]
		imgdepth = camimg.shape[2]
		showImage = QtGui.QImage(camimg2.data, camimg2.shape[1], camimg2.shape[0], QtGui.QImage.Format_RGB888)

	def label_data(self):
		# freeze the current camera frame in the label widget for marking
		self.showpic.setPixmap(QtGui.QPixmap.fromImage(showImage))

	def save_lab(self):
		global lb_names
		global num_class
		classname = self.combo_label.currentText()
		self.labelname.append(classname)
		lb_names = list(set(self.labelname))   # unique class names seen so far
		print(lb_names)
		num_class = len(lb_names)

		# save the raw frame that was just annotated
		imgname = classname + "000" + str(self.i) + ".jpg"
		imgpath = os.path.join("./img/", imgname)
		cv2.imwrite(imgpath, camimg)
		
		# write the Pascal VOC style XML annotation for this frame
		doc = Document()
		annotation = doc.createElement('annotation')
		doc.appendChild(annotation)

		folder = doc.createElement('folder')
		annotation.appendChild(folder)
		folder_text = doc.createTextNode('data')
		folder.appendChild(folder_text)

		filename = doc.createElement('filename')
		annotation.appendChild(filename)
		filename_text = doc.createTextNode(imgname)
		filename.appendChild(filename_text)

		path = doc.createElement('path')
		annotation.appendChild(path)
		path_text = doc.createTextNode(imgpath)
		path.appendChild(path_text)

		source = doc.createElement('source')
		annotation.appendChild(source)

		database = doc.createElement('database')
		source.appendChild(database)
		database_text = doc.createTextNode('Unknown')
		database.appendChild(database_text)

		size = doc.createElement('size')
		annotation.appendChild(size)

		width = doc.createElement('width')
		height = doc.createElement('height')
		depth = doc.createElement('depth')
		width_text = doc.createTextNode(str(imgwidth))
		height_text = doc.createTextNode(str(imghight))
		depth_text = doc.createTextNode(str(imgdepth))
		size.appendChild(width)
		size.appendChild(height)
		size.appendChild(depth)
		width.appendChild(width_text)
		height.appendChild(height_text)
		depth.appendChild(depth_text)

		segmented = doc.createElement('segmented')
		annotation.appendChild(segmented)
		segmented_text = doc.createTextNode('0')
		segmented.appendChild(segmented_text)

		objects = doc.createElement('object')
		annotation.appendChild(objects)

		name = doc.createElement('name')
		pose = doc.createElement('pose')
		truncated = doc.createElement('truncated')
		difficult = doc.createElement('difficult')
		name_text = doc.createTextNode(classname)
		pose_text = doc.createTextNode('Unspecified')
		truncated_text = doc.createTextNode('0')
		difficult_text = doc.createTextNode('0')
		objects.appendChild(name)
		objects.appendChild(pose)
		objects.appendChild(truncated)
		objects.appendChild(difficult)
		name.appendChild(name_text)
		pose.appendChild(pose_text)
		truncated.appendChild(truncated_text)
		difficult.appendChild(difficult_text)

		bndbox = doc.createElement('bndbox')
		objects.appendChild(bndbox)

		xmin = doc.createElement('xmin')
		ymin = doc.createElement('ymin')
		xmax = doc.createElement('xmax')
		ymax = doc.createElement('ymax')
		xmin_text = doc.createTextNode(str(x_start))
		ymin_text = doc.createTextNode(str(y_start))
		xmax_text = doc.createTextNode(str(x_end))
		ymax_text = doc.createTextNode(str(y_end))
		bndbox.appendChild(xmin)
		bndbox.appendChild(ymin)
		bndbox.appendChild(xmax)
		bndbox.appendChild(ymax)
		xmin.appendChild(xmin_text)
		ymin.appendChild(ymin_text)
		xmax.appendChild(xmax_text)
		ymax.appendChild(ymax_text)

		f = open(os.path.join("./annotations/", classname + "000" + str(self.i) + ".xml"), 'w', encoding='utf-8')
		doc.writexml(f, indent='\t', newl='\n', addindent='\t', encoding='utf-8')
		f.close()
		self.i = self.i+1

	def label_end(self):
		# split the annotated files into train/val/test lists:
		# 80 % of all samples go to train+val and 20 % to test;
		# within train+val, 70 % go to train and 30 % to val
		trainval_percent = 0.8
		train_percent = 0.7
		xmlfilepath = './annotations/'
		txtsavepath = './data/'
		total_xml = os.listdir(xmlfilepath)

		num = len(total_xml)
		tv = int(num * trainval_percent)   # size of train + val
		tr = int(tv * train_percent)       # size of train
		trainval = random.sample(range(num), tv)
		train = random.sample(trainval, tr)

		ftrainval = open(os.path.join(txtsavepath,'trainval.txt'), 'w')
		ftest = open(os.path.join(txtsavepath,'test.txt'), 'w')
		ftrain = open(os.path.join(txtsavepath,'train.txt'), 'w')
		fval = open(os.path.join(txtsavepath,'val.txt'), 'w')

		for i in range(num):
			name=total_xml[i][:-4]+'\n'
			if i in trainval:
				ftrainval.write(name)
				if i in train:
					ftrain.write(name)
				else:
					fval.write(name)
			else:
				ftest.write(name)

		ftrainval.close()
		ftrain.close()
		fval.close()
		ftest.close()

		# auto-generate the label map file (an earlier version lost the last
		# item because the file was never closed); writing in 'w' mode also
		# keeps repeated runs from appending duplicate items
		with open('./label_map.pbtxt', mode='w') as pbtxtfile:
			for j, lb_name in enumerate(lb_names, start=1):
				print(lb_name)
				pbtxtfile.write('item {' + '\n')
				pbtxtfile.write('  id: ' + str(j) + '\n')
				pbtxtfile.write('  name: ' + "'" + lb_name + "'" + '\n')
				pbtxtfile.write('}' + '\n')

		# generate the TFRecord files
		def dict_to_tf_example(data,
							   dataset_directory,
							   label_map_dict,
							   ignore_difficult_instances=False,
							   image_subdirectory='JPEGImages'):

			img_path = os.path.join(dataset_directory, data['filename'])
			full_path = img_path
			with tf.gfile.GFile(full_path, 'rb') as fid:
				encoded_jpg = fid.read()
			encoded_jpg_io = io.BytesIO(encoded_jpg)
			image = PIL.Image.open(encoded_jpg_io)
			if image.format != 'JPEG':
				raise ValueError('Image format not JPEG')
			key = hashlib.sha256(encoded_jpg).hexdigest()

			width = int(data['size']['width'])
			height = int(data['size']['height'])

			xmin = []
			ymin = []
			xmax = []
			ymax = []
			classes = []
			classes_text = []
			truncated = []
			poses = []
			difficult_obj = []
			for obj in data['object']:
				difficult = bool(int(obj['difficult']))
				if ignore_difficult_instances and difficult:
					continue

				difficult_obj.append(int(difficult))

				# VOC XML stores absolute pixel coordinates; the TFRecord
				# features expect corners normalized to [0, 1]
				xmin.append(float(obj['bndbox']['xmin']) / width)
				ymin.append(float(obj['bndbox']['ymin']) / height)
				xmax.append(float(obj['bndbox']['xmax']) / width)
				ymax.append(float(obj['bndbox']['ymax']) / height)
				classes_text.append(obj['name'].encode('utf8'))
				classes.append(label_map_dict[obj['name']])
				truncated.append(int(obj['truncated']))
				poses.append(obj['pose'].encode('utf8'))

			example = tf.train.Example(features=tf.train.Features(feature={
				'image/height': dataset_util.int64_feature(height),
				'image/width': dataset_util.int64_feature(width),
				'image/filename': dataset_util.bytes_feature(
					data['filename'].encode('utf8')),
				'image/source_id': dataset_util.bytes_feature(
					data['filename'].encode('utf8')),
				'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
				'image/encoded': dataset_util.bytes_feature(encoded_jpg),
				'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
				'image/object/bbox/xmin': dataset_util.float_list_feature(xmin),
				'image/object/bbox/xmax': dataset_util.float_list_feature(xmax),
				'image/object/bbox/ymin': dataset_util.float_list_feature(ymin),
				'image/object/bbox/ymax': dataset_util.float_list_feature(ymax),
				'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
				'image/object/class/label': dataset_util.int64_list_feature(classes),
				'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
				'image/object/truncated': dataset_util.int64_list_feature(truncated),
				'image/object/view': dataset_util.bytes_list_feature(poses),
			}))
			return example

		def trainrc():
			data_dir = './img'
			years = ['VOC2007']  # vestigial: kept from the upstream create_pascal_tf_record.py
			writer = tf.python_io.TFRecordWriter(os.path.join('./data/','train.record'))
			label_map_dict = label_map_util.get_label_map_dict('./label_map.pbtxt')
			for year in years:
				logging.info('Reading from PASCAL %s dataset.', year)
				examples_path = './data/train.txt'
				annotations_dir = './annotations/'
				examples_list = dataset_util.read_examples_list(examples_path)
				for idx, example in enumerate(examples_list):
					if idx % 100 == 0:
						logging.info('On image %d of %d', idx, len(examples_list))
					path = os.path.join(annotations_dir, example + '.xml')
					with tf.gfile.GFile(path, 'r') as fid:
						xml_str = fid.read()
					xml = etree.fromstring(xml_str.encode('utf-8'))
					data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']
					tf_example = dict_to_tf_example(data, data_dir, label_map_dict,False)
					writer.write(tf_example.SerializeToString())
			writer.close()
		trainrc()
		
		def valrc():
			data_dir = './img'
			years = ['VOC2007']  # vestigial: kept from the upstream create_pascal_tf_record.py
			writer = tf.python_io.TFRecordWriter(os.path.join('./data/','val.record'))
			label_map_dict = label_map_util.get_label_map_dict('./label_map.pbtxt')
			for year in years:
				logging.info('Reading from PASCAL %s dataset.', year)
				examples_path = './data/val.txt'
				annotations_dir = './annotations/'
				examples_list = dataset_util.read_examples_list(examples_path)
				for idx, example in enumerate(examples_list):
					if idx % 100 == 0:
						logging.info('On image %d of %d', idx, len(examples_list))
					path = os.path.join(annotations_dir, example + '.xml')
					with tf.gfile.GFile(path, 'r') as fid:
						xml_str = fid.read()
					xml = etree.fromstring(xml_str.encode('utf-8'))
					data = dataset_util.recursive_parse_xml_to_dict(xml)['annotation']
					tf_example = dict_to_tf_example(data, data_dir, label_map_dict,False)
					writer.write(tf_example.SerializeToString())
			writer.close()
		valrc()

	def pretrain(self):
		batchsize = self.le_batchsize.text()
		numsteps = self.le_stepnm.text()
		
		file_config = open('ssd_mobilenet_v2.config' ,mode = 'w')
		file_config.write('model {\n')
		file_config.write('  ssd {\n')
		file_config.write('    num_classes: ' + str(num_class) + '\n')
		file_config.write('    box_coder {\n')
		file_config.write('      faster_rcnn_box_coder {\n')
		file_config.write('        y_scale: 10.0\n')
		file_config.write('        x_scale: 10.0\n')
		file_config.write('        height_scale: 5.0\n')
		file_config.write('        width_scale: 5.0\n')
		file_config.write('      }\n')
		file_config.write('    }\n')
		file_config.write('    matcher {\n')
		file_config.write('      argmax_matcher {\n')
		file_config.write('        matched_threshold: 0.5\n')
		file_config.write('        unmatched_threshold: 0.5\n')
		file_config.write('        ignore_thresholds: false\n')
		file_config.write('        negatives_lower_than_unmatched: true\n')
		file_config.write('        force_match_for_each_row: true\n')
		file_config.write('      }\n')
		file_config.write('    }\n')
		file_config.write('    similarity_calculator {\n')
		file_config.write('      iou_similarity {\n')
		file_config.write('      }\n')
		file_config.write('    }\n')
		file_config.write('    anchor_generator {\n')
		file_config.write('      ssd_anchor_generator {\n')
		file_config.write('        num_layers: 6\n')
		file_config.write('        min_scale: 0.2\n')
		file_config.write('        max_scale: 0.95\n')
		file_config.write('        aspect_ratios: 1.0\n')
		file_config.write('        aspect_ratios: 2.0\n')
		file_config.write('        aspect_ratios: 0.5\n')
		file_config.write('        aspect_ratios: 3.0\n')
		file_config.write('        aspect_ratios: 0.3333\n')
		file_config.write('      }\n')
		file_config.write('    }\n')
		file_config.write('    image_resizer {\n')
		file_config.write('      fixed_shape_resizer {\n')
		file_config.write('        height: 300\n')
		file_config.write('        width: 300\n')
		file_config.write('      }\n')
		file_config.write('    }\n')
		file_config.write('    box_predictor {\n')
		file_config.write('      convolutional_box_predictor {\n')
		file_config.write('        min_depth: 0\n')
		file_config.write('        max_depth: 0\n')
		file_config.write('        num_layers_before_predictor: 0\n')
		file_config.write('        use_dropout: false\n')
		file_config.write('        dropout_keep_probability: 0.8\n')
		file_config.write('        kernel_size: 1\n')
		file_config.write('        box_code_size: 4\n')
		file_config.write('        apply_sigmoid_to_scores: false\n')
		file_config.write('        conv_hyperparams {\n')
		file_config.write('          activation: RELU_6,\n')
		file_config.write('          regularizer {\n')
		file_config.write('            l2_regularizer {\n')
		file_config.write('              weight: 0.00004\n')
		file_config.write('            }\n')
		file_config.write('          }\n')
		file_config.write('          initializer {\n')
		file_config.write('            truncated_normal_initializer {\n')
		file_config.write('              stddev: 0.03\n')
		file_config.write('              mean: 0.0\n')
		file_config.write('            }\n')
		file_config.write('          }\n')
		file_config.write('          batch_norm {\n')
		file_config.write('            train: true,\n')
		file_config.write('            scale: true,\n')
		file_config.write('            center: true,\n')
		file_config.write('            decay: 0.9997,\n')
		file_config.write('            epsilon: 0.001,\n')
		file_config.write('          }\n')
		file_config.write('        }\n')
		file_config.write('      }\n')
		file_config.write('    }\n')
		file_config.write('    feature_extractor {\n')
		file_config.write("      type: 'ssd_mobilenet_v2'\n")
		file_config.write('      min_depth: 16\n')
		file_config.write('      depth_multiplier: 1.0\n')
		file_config.write('      conv_hyperparams {\n')
		file_config.write('        activation: RELU_6,\n')
		file_config.write('        regularizer {\n')
		file_config.write('          l2_regularizer {\n')
		file_config.write('            weight: 0.00004\n')
		file_config.write('          }\n')
		file_config.write('        }\n')
		file_config.write('        initializer {\n')
		file_config.write('          truncated_normal_initializer {\n')
		file_config.write('            stddev: 0.03\n')
		file_config.write('            mean: 0.0\n')
		file_config.write('          }\n')
		file_config.write('        }\n')
		file_config.write('        batch_norm {\n')
		file_config.write('          train: true,\n')
		file_config.write('          scale: true,\n')
		file_config.write('          center: true,\n')
		file_config.write('          decay: 0.9997,\n')
		file_config.write('          epsilon: 0.001,\n')
		file_config.write('        }\n')
		file_config.write('      }\n')
		file_config.write('    }\n')
		file_config.write('    loss {\n')
		file_config.write('      classification_loss {\n')
		file_config.write('        weighted_sigmoid {\n')
		file_config.write('        }\n')
		file_config.write('      }\n')
		file_config.write('      localization_loss {\n')
		file_config.write('        weighted_smooth_l1 {\n')
		file_config.write('        }\n')
		file_config.write('      }\n')
		file_config.write('      hard_example_miner {\n')
		file_config.write('        num_hard_examples: 3000\n')
		file_config.write('        iou_threshold: 0.99\n')
		file_config.write('        loss_type: CLASSIFICATION\n')
		file_config.write('        max_negatives_per_positive: 3\n')
		file_config.write('        min_negatives_per_image: 3\n')
		file_config.write('      }\n')
		file_config.write('      classification_weight: 1.0\n')
		file_config.write('      localization_weight: 1.0\n')
		file_config.write('    }\n')
		file_config.write('    normalize_loss_by_num_matches: true\n')
		file_config.write('    post_processing {\n')
		file_config.write('      batch_non_max_suppression {\n')
		file_config.write('        score_threshold: 1e-8\n')
		file_config.write('        iou_threshold: 0.6\n')
		file_config.write('        max_detections_per_class: 100\n')
		file_config.write('        max_total_detections: 100\n')
		file_config.write('      }\n')
		file_config.write('      score_converter: SIGMOID\n')
		file_config.write('    }\n')
		file_config.write('  }\n')
		file_config.write('}\n')
		file_config.write('train_config: {\n')
		file_config.write('  batch_size: '+ str(batchsize) + '\n')
		file_config.write('  optimizer {\n')
		file_config.write('    rms_prop_optimizer: {\n')
		file_config.write('      learning_rate: {\n')
		file_config.write('        exponential_decay_learning_rate {\n')
		file_config.write('          initial_learning_rate: 0.004\n')
		file_config.write('          decay_steps: 800720\n')
		file_config.write('          decay_factor: 0.95\n')
		file_config.write('        }\n')
		file_config.write('      }\n')
		file_config.write('      momentum_optimizer_value: 0.9\n')
		file_config.write('      decay: 0.9\n')
		file_config.write('      epsilon: 1.0\n')
		file_config.write('    }\n')
		file_config.write('  }\n')
		file_config.write('  fine_tune_checkpoint: "./ssd_mobilenet_v2_coco/model.ckpt"\n')
		file_config.write('  fine_tune_checkpoint_type:  "detection"\n')
		file_config.write('  num_steps: ' + str(numsteps) + '\n')
		file_config.write('  data_augmentation_options {\n')
		file_config.write('    random_horizontal_flip {\n')
		file_config.write('    }\n')
		file_config.write('  }\n')
		file_config.write('  data_augmentation_options {\n')
		file_config.write('    ssd_random_crop {\n')
		file_config.write('    }\n')
		file_config.write('  }\n')
		file_config.write('}\n')
		file_config.write('train_input_reader: {\n')
		file_config.write('  tf_record_input_reader {\n')
		file_config.write('    input_path: "./data/train.record"\n')
		file_config.write('  }\n')
		file_config.write('  label_map_path: "./label_map.pbtxt"\n')
		file_config.write('}\n')
		file_config.write('eval_config: {\n')
		file_config.write('  num_examples: 2000\n')
		file_config.write('  max_evals: 10\n')
		file_config.write('}\n')
		file_config.write('eval_input_reader: {\n')
		file_config.write('  tf_record_input_reader {\n')
		file_config.write('    input_path: "./data/val.record"\n')
		file_config.write('  }\n')
		file_config.write('  label_map_path: "./label_map.pbtx"\n')
		file_config.write('  shuffle: false\n')
		file_config.write('  num_readers: 1\n')
		file_config.write('}\n')
		file_config.close()
		# the training step below (essentially the main flow of object_detection's legacy train.py)
		@tf.contrib.framework.deprecated(None, 'Use object_detection/model_main.py.')
		def trainckpt():
			train_dir = './trainckpt'
			pipeline_config_path = './ssd_mobilenet_v2.config'
			task = 0
			assert train_dir, '`train_dir` is missing.'
			if task == 0: tf.gfile.MakeDirs(train_dir)
			if pipeline_config_path:
				configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
				if task == 0:
					tf.gfile.Copy(pipeline_config_path,os.path.join(train_dir, 'pipeline.config'),overwrite=True)
			else:
				configs = config_util.get_configs_from_multiple_files(model_config_path='',train_config_path='',train_input_config_path='')
				if task == 0:
					for name, config in [('model.config', ''),('train.config', ''),('input.config', '')]:
						tf.gfile.Copy(config, os.path.join(train_dir, name),overwrite=True)
			
			model_config = configs['model']
			train_config = configs['train_config']
			input_config = configs['train_input_config']
			model_fn = functools.partial(model_builder.build,model_config=model_config,is_training=True)
			
			def get_next(config):
				return dataset_builder.make_initializable_iterator(dataset_builder.build(config)).get_next()
			
			create_input_dict_fn = functools.partial(get_next, input_config)

			env = json.loads(os.environ.get('TF_CONFIG', '{}'))
			cluster_data = env.get('cluster', None)
			cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None
			task_data = env.get('task', None) or {'type': 'master', 'index': 0}
			task_info = type('TaskSpec', (object,), task_data)
			ps_tasks = 0
			worker_replicas = 1
			worker_job_name = 'lonely_worker'
			task = 0
			is_chief = True
			master = ''

			if cluster_data and 'worker' in cluster_data:
				worker_replicas = len(cluster_data['worker']) + 1
			if cluster_data and 'ps' in cluster_data:
				ps_tasks = len(cluster_data['ps'])

			if worker_replicas > 1 and ps_tasks < 1:
				raise ValueError('At least 1 ps task is needed for distributed training.')

			if worker_replicas >= 1 and ps_tasks > 0:
				server = tf.train.Server(tf.train.ClusterSpec(cluster), protocol='grpc',job_name=task_info.type, task_index=task_info.index)
				
				if task_info.type == 'ps':
					server.join()
					return

				worker_job_name = '%s/task:%d' % (task_info.type, task_info.index)
				task = task_info.index
				is_chief = (task_info.type == 'master')
				master = server.target

			graph_rewriter_fn = None
			if 'graph_rewriter_config' in configs:
				graph_rewriter_fn = graph_rewriter_builder.build(configs['graph_rewriter_config'], is_training=True)

			trainer.train(
				create_input_dict_fn,
				model_fn,
				train_config,
				master,
				task,
				1,      # num_clones
				worker_replicas,
				False,  # clone_on_cpu
				ps_tasks,
				worker_job_name,
				is_chief,
				train_dir,
				graph_hook_fn=graph_rewriter_fn)
		trainckpt()

		# export the frozen .pb inference graph from the final checkpoint
		# (adapted from object_detection/export_inference_graph.py)
		def genpb():
			pipeline_config_path = './ssd_mobilenet_v2.config'
			pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
			with tf.gfile.GFile(pipeline_config_path, 'r') as f:
				text_format.Merge(f.read(), pipeline_config)
			exporter.export_inference_graph(
				'image_tensor', pipeline_config,
				'./trainckpt/model.ckpt-' + str(numsteps),
				'./trainckpt/save', input_shape=None,
				write_inference_graph=False)
		genpb()

if __name__ == "__main__":
	app = QtWidgets.QApplication(sys.argv)
	ui = Ui_train_window()
	ui.setupUi()
	ui.show()
	sys.exit(app.exec_())

Prerequisites: TensorFlow, the Object Detection API, and a Python 3.x environment must already be set up.
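A quick way to confirm the environment is wired up is the sketch below; all three imports must succeed for the tool above to run (the legacy trainer it uses requires TensorFlow 1.x):

import tensorflow as tf
from object_detection.protos import pipeline_pb2   # fails if the protos were not compiled with protoc
from object_detection.legacy import trainer        # fails if the API is not on PYTHONPATH

print('TensorFlow', tf.__version__)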

It goes down even better when paired with the recognition applet from the earlier notes~

Reprinted from blog.csdn.net/yourgreatfather/article/details/85063787