Original
Tutorial: Getting the Shapes Object-Detection Example Running in the Keras Version of Mask R-CNN
2018年06月27日 15:12:21
yangdashi888
<div class="tags-box space">
<span class="label">Category:</span>
<a class="tag-link" href="https://blog.csdn.net/yangdashi888/article/category/6896214" target="_blank">Deep Learning </a>
</div>
</div>
<div class="operating">
</div>
</div>
</div>
</div>
<article>
<div id="article_content" class="article_content clearfix csdn-tracking-statistics" data-pid="blog" data-mod="popu_307" data-dsm="post">
<div class="article-copyright">
Copyright notice: This is an original article by the blogger and may not be reproduced without the blogger's permission. https://blog.csdn.net/yangdashi888/article/details/80827513 </div>
<link rel="stylesheet" href="https://csdnimg.cn/release/phoenix/template/css/ck_htmledit_views-e4c7a3727d.css">
<div class="htmledit_views">
<p>The official source repository is: <a href="https://github.com/matterport/Mask_RCNN" rel="nofollow" target="_blank">https://github.com/matterport/Mask_RCNN</a></p><p>Download link for my modified code: <a href="https://mp.csdn.net/postedit/train_shapes2.py" rel="nofollow" target="_blank">click to open the link</a>. I converted the official ipynb notebook into a py script using the [Download as] menu of Jupyter Notebook, as shown here:</p><p><img src="https://img-blog.csdn.net/20180627151147134?watermark/2/text/aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3lhbmdkYXNoaTg4OA==/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70" alt=""><br></p><p>The official repository is optimized and upgraded frequently, but its example programs are not always updated to match the new interfaces, so some of the example code no longer runs as-is. One of these is the shape-recognition example, which needs a few small changes.</p><p>The shape-recognition example uses OpenCV to build its dataset, as follows:</p><p style="font-family:'-apple-system', 'SF UI Text', Arial, 'PingFang SC', 'Hiragino Sans GB', 'Microsoft YaHei', 'WenQuanYi Micro Hei', sans-serif, SimHei, SimSun;background-color:rgb(255,255,255);">The dataset class must inherit from <code style="font-size:14px;line-height:22px;">utils.Dataset</code>, expose a <code style="font-size:14px;line-height:22px;">load_shapes()</code> method for loading the data, and override the following methods:</p><ul style="list-style:none;color:rgb(51,51,51);font-family:'-apple-system', 'SF UI Text', Arial, 'PingFang SC', 'Hiragino Sans GB', 'Microsoft YaHei', 'WenQuanYi Micro Hei', sans-serif, SimHei, SimSun;font-size:14px;background-color:rgb(255,255,255);"><li>load_image()</li><li>load_mask()</li><li>image_reference()</li></ul><p style="font-family:'-apple-system', 'SF UI Text', Arial, 'PingFang SC', 'Hiragino Sans GB', 'Microsoft YaHei', 'WenQuanYi Micro Hei', sans-serif, SimHei, SimSun;background-color:rgb(255,255,255);">Code for building the dataset:</p><pre class="prettyprint" style="font-size:14px;line-height:22px;" name="code" onclick="hljs.copyCode(event)"><code class="hljs python has-numbering"><ol class="hljs-ln hundred"><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="1"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-class"><span class="hljs-keyword"><span class="hljs-class"><span 
class="hljs-keyword"><span class="hljs-class"><span class="hljs-keyword">class</span></span></span></span></span><span class="hljs-class"><span class="hljs-class"> </span></span><span class="hljs-title"><span class="hljs-class"><span class="hljs-title"><span class="hljs-class"><span class="hljs-title">ShapesDataset</span></span></span></span></span><span class="hljs-params"><span class="hljs-class"><span class="hljs-params"><span class="hljs-class"><span class="hljs-params">(utils.Dataset)</span></span></span></span></span><span class="hljs-class"><span class="hljs-class">:</span></span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="2"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-string"><span class="hljs-string"><span class="hljs-string">"""</span></span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="3"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-string"> Generates a dataset of images, each made of simple shapes (triangles, squares, circles) placed on a blank canvas.</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="4"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-string"> """</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="5"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="6"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-function"><span class="hljs-keyword"><span class="hljs-function"><span class="hljs-keyword"><span class="hljs-function"><span class="hljs-keyword">def</span></span></span></span></span><span class="hljs-function"><span class="hljs-function"> </span></span><span class="hljs-title"><span 
class="hljs-function"><span class="hljs-title"><span class="hljs-function"><span class="hljs-title">load_shapes</span></span></span></span></span><span class="hljs-params"><span class="hljs-function"><span class="hljs-params"><span class="hljs-function"><span class="hljs-params">(self, count, height, width)</span></span></span></span></span><span class="hljs-function"><span class="hljs-function">:</span></span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="7"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-string"><span class="hljs-string"><span class="hljs-string">"""</span></span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="8"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-string"> Generate the requested number of fixed-size images.</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="9"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-string"> count: number of images to generate.</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="10"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-string"> height, width: size of the generated images.</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="11"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-string"> """</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="12"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Add the classes</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="13"></div></div><div class="hljs-ln-code"><div 
class="hljs-ln-line"> self.add_class(<span class="hljs-string"><span class="hljs-string">"shapes"</span></span>, <span class="hljs-number"><span class="hljs-number">1</span></span>, <span class="hljs-string"><span class="hljs-string">"square"</span></span>)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="14"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> self.add_class(<span class="hljs-string"><span class="hljs-string">"shapes"</span></span>, <span class="hljs-number"><span class="hljs-number">2</span></span>, <span class="hljs-string"><span class="hljs-string">"circle"</span></span>)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="15"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> self.add_class(<span class="hljs-string"><span class="hljs-string">"shapes"</span></span>, <span class="hljs-number"><span class="hljs-number">3</span></span>, <span class="hljs-string"><span class="hljs-string">"triangle"</span></span>)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="16"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="17"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Generate random shape specifications; each image is identified by its image_id</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="18"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-keyword"><span class="hljs-keyword">for</span></span> i <span class="hljs-keyword"><span class="hljs-keyword">in</span></span> range(count):</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="19"></div></div><div 
class="hljs-ln-code"><div class="hljs-ln-line"> bg_color, shapes = self.random_image(height, width)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="20"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> self.add_image(<span class="hljs-string"><span class="hljs-string">"shapes"</span></span>, image_id=i, path=<span class="hljs-keyword"><span class="hljs-keyword">None</span></span>,</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="21"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> width=width, height=height,</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="22"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> bg_color=bg_color, shapes=shapes)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="23"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="24"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-function"><span class="hljs-keyword"><span class="hljs-function"><span class="hljs-keyword"><span class="hljs-function"><span class="hljs-keyword">def</span></span></span></span></span><span class="hljs-function"><span class="hljs-function"> </span></span><span class="hljs-title"><span class="hljs-function"><span class="hljs-title"><span class="hljs-function"><span class="hljs-title">load_image</span></span></span></span></span><span class="hljs-params"><span class="hljs-function"><span class="hljs-params"><span class="hljs-function"><span class="hljs-params">(self, image_id)</span></span></span></span></span><span class="hljs-function"><span class="hljs-function">:</span></span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line 
hljs-ln-n" data-line-number="25"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-string"><span class="hljs-string"><span class="hljs-string">"""</span></span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="26"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-string"> Generate the image for the given image_id.</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="27"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-string"> Typically this function reads the image from a file; here we look up the specs in image_info by image_id and generate the image on the fly.</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="28"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-string"> """</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="29"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> info = self.image_info[image_id]</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="30"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> bg_color = np.array(info[<span class="hljs-string"><span class="hljs-string">'bg_color'</span></span>]).reshape([<span class="hljs-number"><span class="hljs-number">1</span></span>, <span class="hljs-number"><span class="hljs-number">1</span></span>, <span class="hljs-number"><span class="hljs-number">3</span></span>])</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="31"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> image = np.ones([info[<span class="hljs-string"><span class="hljs-string">'height'</span></span>], info[<span class="hljs-string"><span class="hljs-string">'width'</span></span>], <span class="hljs-number"><span 
class="hljs-number">3</span></span>], dtype=np.uint8)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="32"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> image = image * bg_color.astype(np.uint8)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="33"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-keyword"><span class="hljs-keyword">for</span></span> shape, color, dims <span class="hljs-keyword"><span class="hljs-keyword">in</span></span> info[<span class="hljs-string"><span class="hljs-string">'shapes'</span></span>]:</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="34"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> image = self.draw_shape(image, shape, dims, color)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="35"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-keyword"><span class="hljs-keyword">return</span></span> image</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="36"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="37"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-function"><span class="hljs-keyword"><span class="hljs-function"><span class="hljs-keyword"><span class="hljs-function"><span class="hljs-keyword">def</span></span></span></span></span><span class="hljs-function"><span class="hljs-function"> </span></span><span class="hljs-title"><span class="hljs-function"><span class="hljs-title"><span class="hljs-function"><span class="hljs-title">image_reference</span></span></span></span></span><span class="hljs-params"><span 
class="hljs-function"><span class="hljs-params"><span class="hljs-function"><span class="hljs-params">(self, image_id)</span></span></span></span></span><span class="hljs-function"><span class="hljs-function">:</span></span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="38"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-string"><span class="hljs-string">"""Return the shapes data of the image."""</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="39"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> info = self.image_info[image_id]</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="40"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-keyword"><span class="hljs-keyword">if</span></span> info[<span class="hljs-string"><span class="hljs-string">"source"</span></span>] == <span class="hljs-string"><span class="hljs-string">"shapes"</span></span>:</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="41"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-keyword"><span class="hljs-keyword">return</span></span> info[<span class="hljs-string"><span class="hljs-string">"shapes"</span></span>]</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="42"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-keyword"><span class="hljs-keyword">else</span></span>:</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="43"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-keyword"><span class="hljs-keyword">return</span></span> super(ShapesDataset, self).image_reference(image_id)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line 
hljs-ln-n" data-line-number="44"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="45"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-function"><span class="hljs-keyword"><span class="hljs-function"><span class="hljs-keyword"><span class="hljs-function"><span class="hljs-keyword">def</span></span></span></span></span><span class="hljs-function"><span class="hljs-function"> </span></span><span class="hljs-title"><span class="hljs-function"><span class="hljs-title"><span class="hljs-function"><span class="hljs-title">load_mask</span></span></span></span></span><span class="hljs-params"><span class="hljs-function"><span class="hljs-params"><span class="hljs-function"><span class="hljs-params">(self, image_id)</span></span></span></span></span><span class="hljs-function"><span class="hljs-function">:</span></span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="46"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-string"><span class="hljs-string">"""Generate instance masks for the shapes of the given image_id."""</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="47"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> info = self.image_info[image_id]</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="48"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> shapes = info[<span class="hljs-string"><span class="hljs-string">'shapes'</span></span>]</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="49"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> count = len(shapes)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" 
data-line-number="50"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> mask = np.zeros([info[<span class="hljs-string"><span class="hljs-string">'height'</span></span>], info[<span class="hljs-string"><span class="hljs-string">'width'</span></span>], count], dtype=np.uint8)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="51"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-keyword"><span class="hljs-keyword">for</span></span> i, (shape, _, dims) <span class="hljs-keyword"><span class="hljs-keyword">in</span></span> enumerate(info[<span class="hljs-string"><span class="hljs-string">'shapes'</span></span>]):</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="52"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> mask[:, :, i:i+<span class="hljs-number"><span class="hljs-number">1</span></span>] = self.draw_shape(mask[:, :, i:i+<span class="hljs-number"><span class="hljs-number">1</span></span>].copy(),</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="53"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> shape, dims, <span class="hljs-number"><span class="hljs-number">1</span></span>)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="54"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Handle occlusions</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="55"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> occlusion = np.logical_not(mask[:, :, <span class="hljs-number">-</span><span class="hljs-number"><span class="hljs-number">1</span></span>]).astype(np.uint8)</div></div></li><li><div 
class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="56"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-keyword"><span class="hljs-keyword">for</span></span> i <span class="hljs-keyword"><span class="hljs-keyword">in</span></span> range(count<span class="hljs-number">-</span><span class="hljs-number"><span class="hljs-number">2</span></span>, <span class="hljs-number">-</span><span class="hljs-number"><span class="hljs-number">1</span></span>, <span class="hljs-number">-</span><span class="hljs-number"><span class="hljs-number">1</span></span>):</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="57"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> mask[:, :, i] = mask[:, :, i] * occlusion</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="58"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> occlusion = np.logical_and(occlusion, np.logical_not(mask[:, :, i]))</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="59"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Map class names to class IDs.</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="60"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> class_ids = np.array([self.class_names.index(s[<span class="hljs-number"><span class="hljs-number">0</span></span>]) <span class="hljs-keyword"><span class="hljs-keyword">for</span></span> s <span class="hljs-keyword"><span class="hljs-keyword">in</span></span> shapes])</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="61"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span 
class="hljs-keyword"><span class="hljs-keyword">return</span></span> mask, class_ids.astype(np.int32)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="62"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="63"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-function"><span class="hljs-keyword"><span class="hljs-function"><span class="hljs-keyword"><span class="hljs-function"><span class="hljs-keyword">def</span></span></span></span></span><span class="hljs-function"><span class="hljs-function"> </span></span><span class="hljs-title"><span class="hljs-function"><span class="hljs-title"><span class="hljs-function"><span class="hljs-title">draw_shape</span></span></span></span></span><span class="hljs-params"><span class="hljs-function"><span class="hljs-params"><span class="hljs-function"><span class="hljs-params">(self, image, shape, dims, color)</span></span></span></span></span><span class="hljs-function"><span class="hljs-function">:</span></span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="64"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-string"><span class="hljs-string">"""Draw the given shape onto the image."""</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="65"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Get the center x, y and the size s</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="66"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> x, y, s = dims</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line 
hljs-ln-n" data-line-number="67"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-keyword"><span class="hljs-keyword">if</span></span> shape == <span class="hljs-string"><span class="hljs-string">'square'</span></span>:</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="68"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> image = cv2.rectangle(image, (x-s, y-s), (x+s, y+s), color, <span class="hljs-number">-</span><span class="hljs-number"><span class="hljs-number">1</span></span>)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="69"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-keyword"><span class="hljs-keyword">elif</span></span> shape == <span class="hljs-string"><span class="hljs-string">"circle"</span></span>:</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="70"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> image = cv2.circle(image, (x, y), s, color, <span class="hljs-number">-</span><span class="hljs-number"><span class="hljs-number">1</span></span>)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="71"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-keyword"><span class="hljs-keyword">elif</span></span> shape == <span class="hljs-string"><span class="hljs-string">"triangle"</span></span>:</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="72"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> points = np.array([[(x, y-s),</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="73"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> (x-s/math.sin(math.radians(<span 
class="hljs-number"><span class="hljs-number">60</span></span>)), y+s),</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="74"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> (x+s/math.sin(math.radians(<span class="hljs-number"><span class="hljs-number">60</span></span>)), y+s),</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="75"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> ]], dtype=np.int32)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="76"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> image = cv2.fillPoly(image, points, color)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="77"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-keyword"><span class="hljs-keyword">return</span></span> image</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="78"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="79"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-function"><span class="hljs-keyword"><span class="hljs-function"><span class="hljs-keyword"><span class="hljs-function"><span class="hljs-keyword">def</span></span></span></span></span><span class="hljs-function"><span class="hljs-function"> </span></span><span class="hljs-title"><span class="hljs-function"><span class="hljs-title"><span class="hljs-function"><span class="hljs-title">random_shape</span></span></span></span></span><span class="hljs-params"><span class="hljs-function"><span class="hljs-params"><span class="hljs-function"><span class="hljs-params">(self, height, 
width)</span></span></span></span></span><span class="hljs-function"><span class="hljs-function">:</span></span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="80"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-string"><span class="hljs-string"><span class="hljs-string">"""</span></span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="81"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-string"> Generate the specification of a random shape that lies within the given height and width boundaries.</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="82"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-string"></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="83"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-string"> Returns a tuple of three values:</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="84"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-string"> * shape: name of the shape (square, circle, ...)</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="85"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-string"> * color: color of the shape (a tuple of 3 values, RGB)</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="86"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-string"> * dimensions: center position and size of the shape (center_x, center_y, size)</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="87"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-string"> 
"""</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="88"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Shape</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="89"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> shape = random.choice([<span class="hljs-string"><span class="hljs-string">"square"</span></span>, <span class="hljs-string"><span class="hljs-string">"circle"</span></span>, <span class="hljs-string"><span class="hljs-string">"triangle"</span></span>])</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="90"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Color</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="91"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> color = tuple([random.randint(<span class="hljs-number"><span class="hljs-number">0</span></span>, <span class="hljs-number"><span class="hljs-number">255</span></span>) <span class="hljs-keyword"><span class="hljs-keyword">for</span></span> _ <span class="hljs-keyword"><span class="hljs-keyword">in</span></span> range(<span class="hljs-number"><span class="hljs-number">3</span></span>)])</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="92"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Center x, y</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="93"></div></div><div class="hljs-ln-code"><div 
class="hljs-ln-line"> buffer = <span class="hljs-number"><span class="hljs-number">20</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="94"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> y = random.randint(buffer, height - buffer - <span class="hljs-number"><span class="hljs-number">1</span></span>)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="95"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> x = random.randint(buffer, width - buffer - <span class="hljs-number"><span class="hljs-number">1</span></span>)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="96"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Size</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="97"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> s = random.randint(buffer, height//<span class="hljs-number"><span class="hljs-number">4</span></span>)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="98"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-keyword"><span class="hljs-keyword">return</span></span> shape, color, (x, y, s)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="99"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="100"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-function"><span class="hljs-keyword"><span class="hljs-function"><span class="hljs-keyword"><span class="hljs-function"><span 
class="hljs-keyword">def</span></span></span></span></span><span class="hljs-function"><span class="hljs-function"> </span></span><span class="hljs-title"><span class="hljs-function"><span class="hljs-title"><span class="hljs-function"><span class="hljs-title">random_image</span></span></span></span></span><span class="hljs-params"><span class="hljs-function"><span class="hljs-params"><span class="hljs-function"><span class="hljs-params">(self, height, width)</span></span></span></span></span><span class="hljs-function"><span class="hljs-function">:</span></span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="101"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-string"><span class="hljs-string"><span class="hljs-string">"""</span></span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="102"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-string"> Create the random specification of an image that contains several shapes.</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="103"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-string"> Returns the background color and a list of shape specifications that can be used to draw the image.</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="104"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-string"> """</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="105"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Pick a random background color (one value per RGB channel)</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="106"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> bg_color = 
np.array([random.randint(<span class="hljs-number"><span class="hljs-number">0</span></span>, <span class="hljs-number"><span class="hljs-number">255</span></span>) <span class="hljs-keyword"><span class="hljs-keyword">for</span></span> _ <span class="hljs-keyword"><span class="hljs-keyword">in</span></span> range(<span class="hljs-number"><span class="hljs-number">3</span></span>)])</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="107"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Generate a few random shapes and record their bounding boxes</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="108"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> shapes = []</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="109"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> boxes = []</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="110"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> N = random.randint(<span class="hljs-number"><span class="hljs-number">1</span></span>, <span class="hljs-number"><span class="hljs-number">4</span></span>)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="111"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-keyword"><span class="hljs-keyword">for</span></span> _ <span class="hljs-keyword"><span class="hljs-keyword">in</span></span> range(N):</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="112"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> shape, color, dims = self.random_shape(height, width)</div></div></li><li><div class="hljs-ln-numbers"><div 
class="hljs-ln-line hljs-ln-n" data-line-number="113"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> shapes.append((shape, color, dims))</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="114"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> x, y, s = dims</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="115"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> boxes.append([y-s, x-s, y+s, x+s])</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="116"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Apply non-max suppression with a threshold of 0.3 so that shapes do not cover each other</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="117"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> keep_ixs = utils.non_max_suppression(np.array(boxes), np.arange(N), <span class="hljs-number"><span class="hljs-number">0.3</span></span>)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="118"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> shapes = [s <span class="hljs-keyword"><span class="hljs-keyword">for</span></span> i, s <span class="hljs-keyword"><span class="hljs-keyword">in</span></span> enumerate(shapes) <span class="hljs-keyword"><span class="hljs-keyword">if</span></span> i <span class="hljs-keyword"><span class="hljs-keyword">in</span></span> keep_ixs]</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="119"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-keyword"><span class="hljs-keyword">return</span></span> bg_color, shapes</div></div></li></ol></code><div class="hljs-button" 
data-title="复制"></div></pre><p style="font-family:'-apple-system', 'SF UI Text', Arial, 'PingFang SC', 'Hiragino Sans GB', 'Microsoft YaHei', 'WenQuanYi Micro Hei', sans-serif, SimHei, SimSun;background-color:rgb(255,255,255);">Build a dataset with the class above and take a look at it:</p><pre class="prettyprint" style="font-size:14px;line-height:22px;" name="code" onclick="hljs.copyCode(event)"><code class="language-python hljs has-numbering"><ol class="hljs-ln"><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="1"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Build the training set: 500 images</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="2"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">dataset_train = ShapesDataset()</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="3"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">dataset_train.load_shapes(<span class="hljs-number"><span class="hljs-number">500</span></span>, config.IMAGE_SHAPE[<span class="hljs-number"><span class="hljs-number">0</span></span>], config.IMAGE_SHAPE[<span class="hljs-number"><span class="hljs-number">1</span></span>])</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="4"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">dataset_train.prepare()</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="5"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="6"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># 
Build the validation set: 50 images</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="7"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">dataset_val = ShapesDataset()</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="8"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">dataset_val.load_shapes(<span class="hljs-number"><span class="hljs-number">50</span></span>, config.IMAGE_SHAPE[<span class="hljs-number"><span class="hljs-number">0</span></span>], config.IMAGE_SHAPE[<span class="hljs-number"><span class="hljs-number">1</span></span>])</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="9"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">dataset_val.prepare()</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="10"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="11"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Pick 4 samples at random</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="12"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">image_ids = np.random.choice(dataset_train.image_ids, <span class="hljs-number"><span class="hljs-number">4</span></span>) </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="13"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="14"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-keyword"><span 
class="hljs-keyword">for</span></span> image_id <span class="hljs-keyword"><span class="hljs-keyword">in</span></span> image_ids:</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="15"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> image = dataset_train.load_image(image_id)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="16"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> mask, class_ids = dataset_train.load_mask(image_id)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="17"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> visualize.display_top_masks(image, mask, class_ids, dataset_train.class_names)</div></div></li></ol></code><div class="hljs-button" data-title="复制"></div></pre><img src="https://img-blog.csdn.net/20180627142110493?watermark/2/text/aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3lhbmdkYXNoaTg4OA==/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70" alt=""><p style="font-family:'-apple-system', 'SF UI Text', Arial, 'PingFang SC', 'Hiragino Sans GB', 'Microsoft YaHei', 'WenQuanYi Micro Hei', sans-serif, SimHei, SimSun;background-color:rgb(255,255,255);"><strong>Note:</strong> The figure shows the original image and the mask images of one training sample. Masks belonging to different classes take different values. Because the figure opens in a modal dialog, you have to close that dialog by hand after viewing the randomly chosen samples; only then does the program continue.</p><p style="font-family:'-apple-system', 'SF UI Text', Arial, 'PingFang SC', 'Hiragino Sans GB', 'Microsoft YaHei', 'WenQuanYi Micro Hei', sans-serif, SimHei, SimSun;background-color:rgb(255,255,255);">Next, define a matching <code style="font-size:14px;line-height:22px;">ShapesConfig</code> class for the dataset built above. It gathers the model's configuration parameters in one place and must inherit from the <code style="font-size:14px;line-height:22px;">Config</code> class:</p><pre class="prettyprint" style="font-size:14px;line-height:22px;" name="code" onclick="hljs.copyCode(event)"><code class="hljs python has-numbering"><ol class="hljs-ln"><li><div class="hljs-ln-numbers"><div class="hljs-ln-line 
hljs-ln-n" data-line-number="1"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-class"><span class="hljs-keyword"><span class="hljs-class"><span class="hljs-keyword"><span class="hljs-class"><span class="hljs-keyword">class</span></span></span></span></span><span class="hljs-class"><span class="hljs-class"> </span></span><span class="hljs-title"><span class="hljs-class"><span class="hljs-title"><span class="hljs-class"><span class="hljs-title">ShapesConfig</span></span></span></span></span><span class="hljs-params"><span class="hljs-class"><span class="hljs-params"><span class="hljs-class"><span class="hljs-params">(Config)</span></span></span></span></span><span class="hljs-class"><span class="hljs-class">:</span></span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="2"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-string"><span class="hljs-string"><span class="hljs-string">"""</span></span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="3"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-string"> Training configuration for the shapes dataset.</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="4"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-string"> Derives from the base Config class.</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="5"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-string"> """</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="6"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> NAME = <span class="hljs-string"><span class="hljs-string">"shapes"</span></span> <span class="hljs-comment" style="font-style:italic;"><span 
class="hljs-comment"># Recognizable name for this configuration</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="7"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="8"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Batch size is 8 (GPUs * images/GPU).</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="9"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> GPU_COUNT = <span class="hljs-number"><span class="hljs-number">1</span></span> <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Number of GPUs</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="10"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> IMAGES_PER_GPU = <span class="hljs-number"><span class="hljs-number">8</span></span> <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Images processed per GPU (the generated images are small, so several fit at once) </span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="11"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="12"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Number of classes (including background)</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="13"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> NUM_CLASSES = <span class="hljs-number"><span class="hljs-number">1</span></span> + <span 
class="hljs-number"><span class="hljs-number">3</span></span> <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># background + 3 shapes</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="14"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="15"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Use small images for faster training</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="16"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> IMAGE_MIN_DIM = <span class="hljs-number"><span class="hljs-number">128</span></span> <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Shorter side of the image</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="17"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> IMAGE_MAX_DIM = <span class="hljs-number"><span class="hljs-number">128</span></span> <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Longer side of the image</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="18"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="19"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Use small anchors because both the images and the objects are small</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="20"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> 
RPN_ANCHOR_SCALES = (<span class="hljs-number"><span class="hljs-number">8</span></span>, <span class="hljs-number"><span class="hljs-number">16</span></span>, <span class="hljs-number"><span class="hljs-number">32</span></span>, <span class="hljs-number"><span class="hljs-number">64</span></span>, <span class="hljs-number"><span class="hljs-number">128</span></span>) <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># anchor side in pixels</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="21"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="22"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Use fewer training ROIs per image because the images are small and contain few objects.</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="23"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Aim to allow ROI sampling to pick 33% positive ROIs.</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="24"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> TRAIN_ROIS_PER_IMAGE = <span class="hljs-number"><span class="hljs-number">32</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="25"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="26"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> STEPS_PER_EPOCH = <span class="hljs-number"><span class="hljs-number">100</span></span> <span 
class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Training steps per epoch; kept small because the data is simple</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="27"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="28"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> VALIDATION_STEPS = <span class="hljs-number"><span class="hljs-number">5</span></span> <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Few validation steps because each epoch is short (the output below was recorded with the old spelling VALIDATION_STPES)</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="29"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="30"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">config = ShapesConfig()</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="31"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">config.display()</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="32"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="33"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">>>></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="34"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">>>></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="35"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line 
hljs-ln-n" data-line-number="36"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">Configurations:</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="37"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">BACKBONE_SHAPES [[<span class="hljs-number"><span class="hljs-number">32</span></span> <span class="hljs-number"><span class="hljs-number">32</span></span>]</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="38"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> [<span class="hljs-number"><span class="hljs-number">16</span></span> <span class="hljs-number"><span class="hljs-number">16</span></span>]</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="39"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> [ <span class="hljs-number"><span class="hljs-number">8</span></span> <span class="hljs-number"><span class="hljs-number">8</span></span>]</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="40"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> [ <span class="hljs-number"><span class="hljs-number">4</span></span> <span class="hljs-number"><span class="hljs-number">4</span></span>]</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="41"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> [ <span class="hljs-number"><span class="hljs-number">2</span></span> <span class="hljs-number"><span class="hljs-number">2</span></span>]]</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="42"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">BACKBONE_STRIDES [<span class="hljs-number"><span class="hljs-number">4</span></span>, <span class="hljs-number"><span 
class="hljs-number">8</span></span>, <span class="hljs-number"><span class="hljs-number">16</span></span>, <span class="hljs-number"><span class="hljs-number">32</span></span>, <span class="hljs-number"><span class="hljs-number">64</span></span>]</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="43"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">BATCH_SIZE <span class="hljs-number"><span class="hljs-number">8</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="44"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">BBOX_STD_DEV [ <span class="hljs-number"><span class="hljs-number">0.1</span></span> <span class="hljs-number"><span class="hljs-number">0.1</span></span> <span class="hljs-number"><span class="hljs-number">0.2</span></span> <span class="hljs-number"><span class="hljs-number">0.2</span></span>]</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="45"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">DETECTION_MAX_INSTANCES <span class="hljs-number"><span class="hljs-number">100</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="46"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">DETECTION_MIN_CONFIDENCE <span class="hljs-number"><span class="hljs-number">0.7</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="47"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">DETECTION_NMS_THRESHOLD <span class="hljs-number"><span class="hljs-number">0.3</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="48"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">GPU_COUNT <span class="hljs-number"><span 
class="hljs-number">1</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="49"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">IMAGES_PER_GPU <span class="hljs-number"><span class="hljs-number">8</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="50"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">IMAGE_MAX_DIM <span class="hljs-number"><span class="hljs-number">128</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="51"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">IMAGE_MIN_DIM <span class="hljs-number"><span class="hljs-number">128</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="52"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">IMAGE_PADDING <span class="hljs-keyword"><span class="hljs-keyword">True</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="53"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">IMAGE_SHAPE [<span class="hljs-number"><span class="hljs-number">128</span></span> <span class="hljs-number"><span class="hljs-number">128</span></span> <span class="hljs-number"><span class="hljs-number">3</span></span>]</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="54"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">LEARNING_MOMENTUM <span class="hljs-number"><span class="hljs-number">0.9</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="55"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">LEARNING_RATE <span class="hljs-number"><span class="hljs-number">0.002</span></span></div></div></li><li><div 
class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="56"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">MASK_POOL_SIZE <span class="hljs-number"><span class="hljs-number">14</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="57"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">MASK_SHAPE [<span class="hljs-number"><span class="hljs-number">28</span></span>, <span class="hljs-number"><span class="hljs-number">28</span></span>]</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="58"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">MAX_GT_INSTANCES <span class="hljs-number"><span class="hljs-number">100</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="59"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">MEAN_PIXEL [ <span class="hljs-number"><span class="hljs-number">123.7</span></span> <span class="hljs-number"><span class="hljs-number">116.8</span></span> <span class="hljs-number"><span class="hljs-number">103.9</span></span>]</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="60"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">MINI_MASK_SHAPE (<span class="hljs-number"><span class="hljs-number">56</span></span>, <span class="hljs-number"><span class="hljs-number">56</span></span>)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="61"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">NAME shapes</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="62"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">NUM_CLASSES <span class="hljs-number"><span class="hljs-number">4</span></span></div></div></li><li><div 
class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="63"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">POOL_SIZE <span class="hljs-number"><span class="hljs-number">7</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="64"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">POST_NMS_ROIS_INFERENCE <span class="hljs-number"><span class="hljs-number">1000</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="65"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">POST_NMS_ROIS_TRAINING <span class="hljs-number"><span class="hljs-number">2000</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="66"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">ROI_POSITIVE_RATIO <span class="hljs-number"><span class="hljs-number">0.33</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="67"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">RPN_ANCHOR_RATIOS [<span class="hljs-number"><span class="hljs-number">0.5</span></span>, <span class="hljs-number"><span class="hljs-number">1</span></span>, <span class="hljs-number"><span class="hljs-number">2</span></span>]</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="68"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">RPN_ANCHOR_SCALES (<span class="hljs-number"><span class="hljs-number">8</span></span>, <span class="hljs-number"><span class="hljs-number">16</span></span>, <span class="hljs-number"><span class="hljs-number">32</span></span>, <span class="hljs-number"><span class="hljs-number">64</span></span>, <span class="hljs-number"><span class="hljs-number">128</span></span>)</div></div></li><li><div 
class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="69"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">RPN_ANCHOR_STRIDE <span class="hljs-number"><span class="hljs-number">2</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="70"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">RPN_BBOX_STD_DEV [ <span class="hljs-number"><span class="hljs-number">0.1</span></span> <span class="hljs-number"><span class="hljs-number">0.1</span></span> <span class="hljs-number"><span class="hljs-number">0.2</span></span> <span class="hljs-number"><span class="hljs-number">0.2</span></span>]</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="71"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">RPN_TRAIN_ANCHORS_PER_IMAGE <span class="hljs-number"><span class="hljs-number">256</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="72"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">STEPS_PER_EPOCH <span class="hljs-number"><span class="hljs-number">100</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="73"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">TRAIN_ROIS_PER_IMAGE <span class="hljs-number"><span class="hljs-number">32</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="74"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">USE_MINI_MASK <span class="hljs-keyword"><span class="hljs-keyword">True</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="75"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">USE_RPN_ROIS <span class="hljs-keyword"><span 
class="hljs-keyword">True</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="76"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">VALIDATION_STPES <span class="hljs-number"><span class="hljs-number">5</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="77"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">WEIGHT_DECAY <span class="hljs-number"><span class="hljs-number">0.0001</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="78"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li></ol></code></pre><p style="font-family:'-apple-system', 'SF UI Text', Arial, 'PingFang SC', 'Hiragino Sans GB', 'Microsoft YaHei', 'WenQuanYi Micro Hei', sans-serif, SimHei, SimSun;background-color:rgb(255,255,255);"><strong>2. Load the model for training:</strong></p><p style="font-family:'-apple-system', 'SF UI Text', Arial, 'PingFang SC', 'Hiragino Sans GB', 'Microsoft YaHei', 'WenQuanYi Micro Hei', sans-serif, SimHei, SimSun;background-color:rgb(255,255,255);">With the custom dataset and its Config set up above, the next step is to create the model and load the pretrained weights:</p><pre class="prettyprint" style="font-size:14px;line-height:22px;" name="code" onclick="hljs.copyCode(event)"><code class="hljs python has-numbering"><ol class="hljs-ln"><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="1"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># The model has two modes: training and inference</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="2"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-comment" 
style="font-style:italic;"><span class="hljs-comment"># Create the model in training mode</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="3"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">model = modellib.MaskRCNN(mode=<span class="hljs-string"><span class="hljs-string">"training"</span></span>, config=config,</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="4"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">                          model_dir=MODEL_DIR)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="5"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="6"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Choose which weights to start from; here we use the COCO pretrained weights</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="7"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">init_with = <span class="hljs-string"><span class="hljs-string">"coco"</span></span> <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># imagenet, coco, or last</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="8"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="9"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-keyword"><span class="hljs-keyword">if</span></span> init_with == <span class="hljs-string"><span class="hljs-string">"imagenet"</span></span>:</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line 
hljs-ln-n" data-line-number="10"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">    model.load_weights(model.get_imagenet_weights(), by_name=<span class="hljs-keyword"><span class="hljs-keyword">True</span></span>)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="11"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-keyword"><span class="hljs-keyword">elif</span></span> init_with == <span class="hljs-string"><span class="hljs-string">"coco"</span></span>:</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="12"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">    <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Load weights pretrained on MS COCO, skipping the layers whose shape depends on the number of classes</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="13"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">    model.load_weights(COCO_MODEL_PATH, by_name=<span class="hljs-keyword"><span class="hljs-keyword">True</span></span>,</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="14"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">                       exclude=[<span class="hljs-string"><span class="hljs-string">"mrcnn_class_logits"</span></span>, <span class="hljs-string"><span class="hljs-string">"mrcnn_bbox_fc"</span></span>, </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="15"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">                                <span class="hljs-string"><span class="hljs-string">"mrcnn_bbox"</span></span>, <span class="hljs-string"><span class="hljs-string">"mrcnn_mask"</span></span>])</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="16"></div></div><div class="hljs-ln-code"><div 
class="hljs-ln-line"><span class="hljs-keyword"><span class="hljs-keyword">elif</span></span> init_with == <span class="hljs-string"><span class="hljs-string">"last"</span></span>:</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="17"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">    <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Load the last model you trained and continue training</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="18"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">    model.load_weights(model.find_last()[<span class="hljs-number"><span class="hljs-number">1</span></span>], by_name=<span class="hljs-keyword"><span class="hljs-keyword">True</span></span>)</div></div></li></ol></code></pre><p>3. Train the model:</p><p>Once the weights are loaded, training can begin. There are two ways to train:</p><p>A. Fine-tune only the last few (head) layers of the network by passing layers='heads' to train(). This suits datasets whose object classes are similar to those in COCO, and it saves time and effort.</p><p>B. Fine-tune all layers, which is closer to retraining from scratch. Use this when the object classes differ greatly from those the original weights were trained to detect. In this case pass layers='all'.</p><p>In the original example project neither training call is commented out, so both run. Simply comment out the one you do not need.</p><p>For example, I fine-tune only the head layers, so the code looks like this:</p><pre onclick="hljs.copyCode(event)"><code class="language-python hljs"><ol class="hljs-ln"><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="1"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-comment"># Train the head branches</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="2"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-comment"># Passing layers="heads" freezes all layers except the head</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="3"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-comment"># layers. 
You can also pass a regular expression to select</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="4"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-comment"># which layers to train by name pattern.</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="5"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="6"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">model.train(dataset_train, dataset_val,</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="7"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> learning_rate=config.LEARNING_RATE,</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="8"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> epochs=<span class="hljs-number">2</span>,</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="9"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> layers=<span class="hljs-string">'heads'</span>)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="10"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-comment"># Fine tune all layers</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="11"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-comment"># Passing layers="all" trains all layers. 
You can also </span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="12"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-comment"># pass a regular expression to select which layers to</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="13"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-comment"># train by name pattern.</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="14"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="15"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-comment"># model.train(dataset_train, dataset_val,</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="16"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-comment"># learning_rate=config.LEARNING_RATE / 10,</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="17"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-comment"># epochs=2,</span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="18"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-comment"># layers="all")</span></div></div></li></ol></code></pre><p>Note: this code was converted from the official ipynb notebook to a py script. The conversion is straightforward, and online tutorials describe how to do it.</p><p>3. Model prediction:</p><p>Prediction only requires defining an InferenceConfig class, although a few places in the example need to be changed.</p><pre class="prettyprint" style="font-size:14px;line-height:22px;" name="code" onclick="hljs.copyCode(event)"><code class="language-python hljs has-numbering"><ol 
class="hljs-ln"><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="1"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-class"><span class="hljs-keyword"><span class="hljs-class"><span class="hljs-keyword"><span class="hljs-class"><span class="hljs-keyword">class</span></span></span></span></span><span class="hljs-class"><span class="hljs-class"> </span></span><span class="hljs-title"><span class="hljs-class"><span class="hljs-title"><span class="hljs-class"><span class="hljs-title">InferenceConfig</span></span></span></span></span><span class="hljs-params"><span class="hljs-class"><span class="hljs-params"><span class="hljs-class"><span class="hljs-params">(ShapesConfig)</span></span></span></span></span><span class="hljs-class"><span class="hljs-class">:</span></span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="2"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> GPU_COUNT = <span class="hljs-number"><span class="hljs-number">1</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="3"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> IMAGES_PER_GPU = <span class="hljs-number"><span class="hljs-number">1</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="4"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="5"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">inference_config = InferenceConfig()</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="6"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" 
data-line-number="7"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Recreate the model in inference mode</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="8"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">model = modellib.MaskRCNN(mode=<span class="hljs-string"><span class="hljs-string">"inference"</span></span>, </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="9"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">                          config=inference_config,</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="10"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">                          model_dir=MODEL_DIR)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="11"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="12"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Set the trained weights file name by hand. Prediction and evaluation could also be moved into a separate script</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="13"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># model_path = os.path.join(ROOT_DIR, ".h5 file name here")</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="14"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">model_path = os.path.join(<span class="hljs-string">'./'</span>, <span 
class="hljs-string">"mask_rcnn_shapes_0002.h5"</span>)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="15"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="16"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># Load the trained weights</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="17"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-keyword"><span class="hljs-keyword">assert</span></span> model_path != <span class="hljs-string"><span class="hljs-string">""</span></span>, <span class="hljs-string"><span class="hljs-string">"Provide path to trained weights"</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="18"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">print(<span class="hljs-string"><span class="hljs-string">"Loading weights from "</span></span>, model_path)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="19"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">model.load_weights(model_path, by_name=<span class="hljs-keyword"><span class="hljs-keyword">True</span></span>)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="20"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="21"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># 
Test on a random image. The official example omits the gt_class_id return value, so it has to be added here.</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="22"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">image_id = random.choice(dataset_val.image_ids)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="23"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">original_image, image_meta, <span style="color:rgb(255,0,0);">gt_class_id</span>, gt_bbox, gt_mask = \</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="24"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">    modellib.load_image_gt(dataset_val, inference_config, </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="25"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">                           image_id, use_mini_mask=<span class="hljs-keyword"><span class="hljs-keyword">False</span></span>)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="26"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="27"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">log(<span class="hljs-string"><span class="hljs-string">"original_image"</span></span>, original_image)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="28"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">log(<span class="hljs-string"><span class="hljs-string">"image_meta"</span></span>, image_meta)</div></div></li></ol></code><code class="language-python hljs has-numbering"><ol class="hljs-ln"><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="1"></div></div><div class="hljs-ln-code"><div 
class="hljs-ln-line">log(<span class="hljs-string">"gt_class_id"</span>,gt_class_id)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="2"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">log(<span class="hljs-string"><span class="hljs-string">"gt_bbox"</span></span>, gt_bbox)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="3"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">log(<span class="hljs-string"><span class="hljs-string">"gt_mask"</span></span>, gt_mask)</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="4"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="5"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">visualize.display_instances(original_image, <span style="color:rgb(255,0,0);">gt_bbox, gt_mask, gt_class_id, </span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="6"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> dataset_train.class_names, figsize=(<span class="hljs-number"><span class="hljs-number">8</span></span>, <span class="hljs-number"><span class="hljs-number">8</span></span>))</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="7"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="8"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">>>></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="9"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">>>></div></div></li><li><div class="hljs-ln-numbers"><div 
class="hljs-ln-line hljs-ln-n" data-line-number="10"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">original_image shape: (<span class="hljs-number"><span class="hljs-number">128</span></span>, <span class="hljs-number"><span class="hljs-number">128</span></span>, <span class="hljs-number"><span class="hljs-number">3</span></span>) min: <span class="hljs-number"><span class="hljs-number">18.00000</span></span> max: <span class="hljs-number"><span class="hljs-number">231.00000</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="11"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">image_meta shape: (<span class="hljs-number"><span class="hljs-number">12</span></span>,) min: <span class="hljs-number"><span class="hljs-number">0.00000</span></span> max: <span class="hljs-number"><span class="hljs-number">128.00000</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="12"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">gt_bbox shape: (<span class="hljs-number"><span class="hljs-number">2</span></span>, <span class="hljs-number"><span class="hljs-number">5</span></span>) min: <span class="hljs-number"><span class="hljs-number">1.00000</span></span> max: <span class="hljs-number"><span class="hljs-number">115.00000</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="13"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">gt_mask shape: (<span class="hljs-number"><span class="hljs-number">128</span></span>, <span class="hljs-number"><span class="hljs-number">128</span></span>, <span class="hljs-number"><span class="hljs-number">2</span></span>) min: <span class="hljs-number"><span class="hljs-number">0.00000</span></span> max: <span class="hljs-number"><span 
class="hljs-number">1.00000</span></span></div></div></li></ol></code></pre><p style="font-family:'-apple-system', 'SF UI Text', Arial, 'PingFang SC', 'Hiragino Sans GB', 'Microsoft YaHei', 'WenQuanYi Micro Hei', sans-serif, SimHei, SimSun;background-color:rgb(255,255,255);">Display a random image from the validation set:</p><p style="font-family:'-apple-system', 'SF UI Text', Arial, 'PingFang SC', 'Hiragino Sans GB', 'Microsoft YaHei', 'WenQuanYi Micro Hei', sans-serif, SimHei, SimSun;background-color:rgb(255,255,255);"><img src="http://owv7la1di.bkt.clouddn.com/blog/171109/0Dk6GfajF2.png" alt="mark" title=""></p><p style="font-family:'-apple-system', 'SF UI Text', Arial, 'PingFang SC', 'Hiragino Sans GB', 'Microsoft YaHei', 'WenQuanYi Micro Hei', sans-serif, SimHei, SimSun;background-color:rgb(255,255,255);">Note: this is one sample image from the validation set. The figure window must be closed before the script continues executing.</p><p style="font-family:'-apple-system', 'SF UI Text', Arial, 'PingFang SC', 'Hiragino Sans GB', 'Microsoft YaHei', 'WenQuanYi Micro Hei', sans-serif, SimHei, SimSun;background-color:rgb(255,255,255);">Run prediction with the model:</p><pre class="prettyprint" style="font-size:14px;line-height:22px;" name="code"><code class="hljs python has-numbering"><ol class="hljs-ln"><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="1"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"><span class="hljs-function"><span class="hljs-keyword"><span class="hljs-function"><span class="hljs-keyword"><span class="hljs-function"><span class="hljs-keyword">def</span></span></span></span></span><span class="hljs-function"><span class="hljs-function"> </span></span><span class="hljs-title"><span class="hljs-function"><span class="hljs-title"><span class="hljs-function"><span class="hljs-title">get_ax</span></span></span></span></span><span class="hljs-params"><span class="hljs-function"><span class="hljs-params"><span 
class="hljs-function"><span class="hljs-params">(rows=</span></span></span></span><span class="hljs-number"><span class="hljs-function"><span class="hljs-params"><span class="hljs-number"><span class="hljs-function"><span class="hljs-params"><span class="hljs-number">1</span></span></span></span></span></span></span><span class="hljs-function"><span class="hljs-params"><span class="hljs-function"><span class="hljs-params">, cols=</span></span></span></span><span class="hljs-number"><span class="hljs-function"><span class="hljs-params"><span class="hljs-number"><span class="hljs-function"><span class="hljs-params"><span class="hljs-number">1</span></span></span></span></span></span></span><span class="hljs-function"><span class="hljs-params"><span class="hljs-function"><span class="hljs-params">, size=</span></span></span></span><span class="hljs-number"><span class="hljs-function"><span class="hljs-params"><span class="hljs-number"><span class="hljs-function"><span class="hljs-params"><span class="hljs-number">8</span></span></span></span></span></span></span><span class="hljs-function"><span class="hljs-params"><span class="hljs-function"><span class="hljs-params">)</span></span></span></span></span><span class="hljs-function"><span class="hljs-function">:</span></span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="2"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">    <span class="hljs-string"><span class="hljs-string">"""Return a Matplotlib Axes array for visualization; a single place to control figure size."""</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="3"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">    _, ax = plt.subplots(rows, cols, figsize=(size*cols, size*rows))</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="4"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">    <span 
class="hljs-keyword"><span class="hljs-keyword">return</span></span> ax</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="5"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="6"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">results = model.detect([original_image], verbose=<span class="hljs-number"><span class="hljs-number">1</span></span>) <span class="hljs-comment" style="font-style:italic;"><span class="hljs-comment"># run prediction</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="7"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="8"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">r = results[<span class="hljs-number"><span class="hljs-number">0</span></span>]</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="9"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">visualize.display_instances(original_image, r[<span class="hljs-string"><span class="hljs-string">'rois'</span></span>], r[<span class="hljs-string"><span class="hljs-string">'masks'</span></span>], r[<span class="hljs-string"><span class="hljs-string">'class_ids'</span></span>], </div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="10"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">                            dataset_val.class_names, r[<span class="hljs-string"><span class="hljs-string">'scores'</span></span>], ax=get_ax())</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="11"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line"> </div></div></li><li><div 
class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="12"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">>>></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="13"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">>>></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="14"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">Processing <span class="hljs-number"><span class="hljs-number">1</span></span> images</div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="15"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">image shape: (<span class="hljs-number"><span class="hljs-number">128</span></span>, <span class="hljs-number"><span class="hljs-number">128</span></span>, <span class="hljs-number"><span class="hljs-number">3</span></span>) min: <span class="hljs-number"><span class="hljs-number">18.00000</span></span> max: <span class="hljs-number"><span class="hljs-number">231.00000</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="16"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">molded_images shape: (<span class="hljs-number"><span class="hljs-number">1</span></span>, <span class="hljs-number"><span class="hljs-number">128</span></span>, <span class="hljs-number"><span class="hljs-number">128</span></span>, <span class="hljs-number"><span class="hljs-number">3</span></span>) min: <span class="hljs-number">-</span><span class="hljs-number"><span class="hljs-number">98.80000</span></span> max: <span class="hljs-number"><span class="hljs-number">127.10000</span></span></div></div></li><li><div class="hljs-ln-numbers"><div class="hljs-ln-line hljs-ln-n" data-line-number="17"></div></div><div class="hljs-ln-code"><div class="hljs-ln-line">image_metas shape: 
(<span class="hljs-number"><span class="hljs-number">1</span></span>, <span class="hljs-number"><span class="hljs-number">12</span></span>) min: <span class="hljs-number"><span class="hljs-number">0.00000</span></span> max: <span class="hljs-number"><span class="hljs-number">128.00000</span></span></div></div></li></ol></code></pre><p style="font-family:'-apple-system', 'SF UI Text', Arial, 'PingFang SC', 'Hiragino Sans GB', 'Microsoft YaHei', 'WenQuanYi Micro Hei', sans-serif, SimHei, SimSun;background-color:rgb(255,255,255);"><img src="https://img-blog.csdn.net/20180627145758454?watermark/2/text/aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3lhbmdkYXNoaTg4OA==/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70" alt=""><br></p><p style="font-family:'-apple-system', 'SF UI Text', Arial, 'PingFang SC', 'Hiragino Sans GB', 'Microsoft YaHei', 'WenQuanYi Micro Hei', sans-serif, SimHei, SimSun;background-color:rgb(255,255,255);">Note: comment out the ax=get_ax() argument in the call to <span style="color:rgb(0,0,0);font-family:Consolas, Inconsolata, Courier, monospace;font-size:14px;background-color:rgb(246,248,250);">display_instances</span>. If it is left in, the prediction result image (shown above) is not displayed when the code runs as a plain script. This appears to be because display_instances only calls plt.show() itself when it creates its own axes. Close the window once you have finished looking at the result.</p><p style="font-family:'-apple-system', 'SF UI Text', Arial, 'PingFang SC', 'Hiragino Sans GB', 'Microsoft YaHei', 'WenQuanYi Micro Hei', sans-serif, SimHei, SimSun;background-color:rgb(255,255,255);">4. Compute the AP of the network's output:</p><pre><code class="language-python">image_ids = np.random.choice(dataset_val.image_ids, 10)
APs = []
for image_id in image_ids:
    # Load the image and its ground-truth data
    image, image_meta, class_ids, gt_bbox, gt_mask = \
        modellib.load_image_gt(dataset_val, inference_config,
                               image_id, use_mini_mask=False)
    molded_images = np.expand_dims(modellib.mold_image(image, inference_config), 0)
    # Run object detection
    results = model.detect([image], verbose=0)
    r = results[0]
    hh = gt_bbox[:, :4]  # not used below
    # hh3 = gt_bbox[:, 4]
    # Compute AP
    AP, precisions, recalls, overlaps = \
        utils.compute_ap(gt_bbox, <span style="color:rgb(255,0,0);">class_ids, gt_mask,</span>
                         r["rois"], r["class_ids"], r["scores"], r["masks"])
    visualize.plot_precision_recall(AP, precisions, recalls)
    APs.append(AP)

print("mAP: ", np.mean(APs))</code></pre><p>Note: the part marked in red is where the original example code has to be changed; the version shown here is already fixed. Running it prints the mean AP (mAP) over the 10 randomly sampled validation images.</p><div><br></div> </div>
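<p>To make the AP value printed above less of a black box, here is a minimal pure-Python sketch of a VOC-style average precision for a single image. It is an illustration only, not the code of utils.compute_ap: the function names average_precision and mean_average_precision are hypothetical, and the sketch assumes the matching of predictions to ground-truth instances (by IoU) has already been done, so each prediction is just a true/false flag ordered by descending score.</p>

```python
def average_precision(is_true_positive, num_gt):
    """AP for one image: area under the interpolated precision-recall curve.

    is_true_positive: one boolean per prediction, ordered by descending score
                      (hypothetical input; IoU matching is assumed done).
    num_gt: number of ground-truth instances, must be greater than 0.
    """
    precisions, recalls = [], []
    true_positives = 0
    for rank, hit in enumerate(is_true_positive, start=1):
        true_positives += int(hit)
        precisions.append(true_positives / rank)
        recalls.append(true_positives / num_gt)

    # Pad with sentinel values so the curve spans recall 0 to 1.
    precisions = [0.0] + precisions + [0.0]
    recalls = [0.0] + recalls + [1.0]

    # Make precision non-increasing from left to right.
    for i in range(len(precisions) - 2, -1, -1):
        precisions[i] = max(precisions[i], precisions[i + 1])

    # Sum rectangle areas; steps where recall does not change contribute 0.
    return sum((recalls[i] - recalls[i - 1]) * precisions[i]
               for i in range(1, len(recalls)))


def mean_average_precision(per_image_aps):
    """Mean of the per-image AP values, as in the final print above."""
    return sum(per_image_aps) / len(per_image_aps)


# Two made-up images: one with a false positive in the middle, one perfect.
aps = [average_precision([True, False, True], 2),
       average_precision([True, True], 2)]
print("mAP:", mean_average_precision(aps))
```

<p>A false positive ranked between two correct detections lowers the AP of the first image, while the second image, where every prediction matches, reaches an AP of 1.0.</p>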
</div>
</article>