Dynamic Routing Between Capsules 读书笔记

0 摘要(略)

1 导言

人类视觉通过使用仔细确定一系列关注点来忽略不相关的细节,来确保仅仅只有很小一部分的视觉的阵列以最高的解析度进行处理。内省是一个理解我们有多少关于场景的知识是来自于这样的关注点序列并且我们可以从一个关注点搜集多少知识的弱引导,但是在这篇文章里,我们假设一个关注点将会给出远多于单单一个可辨识的物体和它的属性。我们假设我们的多层视觉系统在每个关注点上都会建立一个类似分析树的结构,并且我们忽略掉这些单个关注点上建立的分析树在多个关注点上的协调问题。

分析树一般来说是通过动态内存分配来构造的。然而,我们应该假设,对于单个关注点,就像将一个石头雕刻成雕塑一样,在这里我们的分析树是从一个固定的神经网络中刻出来的(读者理解:这样的说法很形象,打个比喻,神经网络就像沙漠中的岩石一样,而数据和训练算法好比夹杂着沙粒的风一样蚀刻着岩石,这样看来也就可以理解,神经网络的实际容量不会超过其表示容量了)。每层中的神经元被分割成许多小的神经元组,我们称之为“胶囊”并且树中的每个节点对应着一个活跃的胶囊。使用一个迭代式的路由过程,每个活跃的胶囊将选择上一层中的一个胶囊作为其在分析树中的父节点。对于视觉系统中的更高的层,这样的过程将会解决指派部件到整体的问题。

一个活跃的胶囊中的神经元的活动情况表示着图片中存在的实体的多个属性。这些属性可能包括许多不同类型的实体化参数,比如姿态(位置、尺寸、方向)、形变、速度、反照率、纹理等等。此外还有又一个非常特殊的属性,那就是图像中该实例化的实体的存在性。一个非常显然的方法就是使用一个单独的逻辑斯蒂单元的输出来表示实体的存在概率的大小。但在本文中,我们探索了一个有意思的替代方法,我们使用实例化参数向量的长度来表示实体的存在性并且强制使向量的方向表示实体的属性。我们通过应用一个保向的非线性性来缩小胶囊的向量长度,使其不会超过1。

事实上,胶囊的输出是向量使得我们可以使用一个强大的动态路由机制来确保胶囊的输出被发送到下一层中合适的父节点。起初,输出被路由到所有可能的父节点中,但是会通过一组和为1的连接系数来响应地缩减。对于每个可能的父节点,胶囊通过将自己的输出与一个权重矩阵相乘来计算一个预测向量,如果这个预测向量和某个可能的父节点的标量积很大,那么就会有一个自顶向下的反馈来增大胶囊与该父节点的连接系数,并且减少与其它父节点的连接系数。这将会增加该胶囊对该父节点的贡献度,这样又进一步增加了标量积的大小。这种一种通过统一意见进行路由的方法要远远比非常原始的最大池化路由有效,最大池化路由简单地丢弃了局部区域内除了最明显特征以外的所有特征。我们证明了我们的动态路由机制是种有效的实现分割高度重叠的物体所需要的“解释远离”效应。

卷积神经网络使用特征探测器的转化的副本。我使得我们可以将在图片中某个位置学到的好的权重转化为其它位置上的知识。这已经被证明在图像解读中是非常有帮助的。尽管我们将特征探测器的标量输出替换为胶囊的向量输出,将最大池化替换为意见一致性路由,我们仍然还需要在不同的空间复制学到的知识。为了实现这个想法,我们使用卷机实现除了最后一层以外的所有层。就像CNN一样,我们使得更高层次的胶囊覆盖图片上更广的区域。然而并不想池化作用,我们并不会丢弃掉该区域内实体的具体的位置信息。对于低层次的胶囊,位置信息通过活跃的胶囊来进行“位置编码”。当我们提高层级时,越来越多的位置信息通过“比率编码”来编入胶囊输出的实值向量的各个分量中。这种从“位置编码”到“比率编码”的转变结合了这样的事实,就是更高层次的胶囊表示着有更多自由度的更加复杂实体,这暗示了随着层级的提高,胶囊的维度应该更大。

2 如何计算胶囊的向量输入输出

存在很多实现胶囊的这样一个通用的思想。本文的目的并不是探索该领域内的所有可能的想法,而仅仅展示了一个相当直接的实现就可能工作起来,并且动态路由的确可以训练胶囊层。

我们希望一个胶囊的输出向量的长度表示在当前输入中由这个胶囊表示的实体的存在性概率。我们因此使用一个非线性的“挤压”函数来确保段向量的长度几乎收缩为零,并且长的向量收缩到略小于1。我们把它留给判别式学习来好好利用这样的非线性性。

V j = s j 2 1 + s j 2 s j s j ( 1 )

这里 v j 是胶囊的输出, s j 是它的整个输入。
除了第一层的胶囊,胶囊 s j 的整个输入是上一层中所有胶囊的“预测向量” u ^ j | i 的加权和,而 u ^ j | i 是通过将上一层的胶囊输出 u i 乘以权重矩阵 W i j
s j = i c i j u ^ i , u ^ i = W i j u i ( 2 )

这里的 c i j 是由动态路由过程确定的连接系数。
胶囊 i 与下一层中所有的胶囊的连接系数的和为1,并且使用“路由softmax”来确定,它的初始未归一化概率 b i j 代表了胶囊 i 应与胶囊 j 相连的对数先验概率。
c i j = exp ( b i j ) k exp ( b i k ) ( 3 )

对数先验可以和其它参数一样通过判别式地学习到。它们取决于两个胶囊的位置和类型,而不是当前图片。初始的连接系数通过迭代式的度量下一层胶囊每个胶囊 j 的输出 v j 和胶囊 i 的预测向量 u ^ j | i 的意见一致度来逐步精炼调优。
那么意见一致性度量可以简单地使用标量积 a i j = v j T u ^ j | i 来计算。我们可以把这样的意见一致度看作对数似然,并在计算胶囊 i 与下一层的胶囊新的连接系数之前把它加到初始的未归一化概率 b i j 中去。

在卷积胶囊层中,每个胶囊使用不同的变换矩阵对下一层中的每类胶囊输出一个局部向量网格,而这样的矩阵是对每种不同类型和网格中每个成员都是不同的。
这里写图片描述

3 数字存在性的边缘损失

我们使用实例化向量长度来表示胶囊表示的实体的存在概率。我们希望当且仅当在图像中存在数字 k 的时候,顶层的代表数字 k 的胶囊有着长的实例化向量。为了允许多个数字,我们对每个数字胶囊 k 使用一个单独的边缘损失 L k

L k = T k max ( 0 , m + v k ) 2 + λ ( 1 T k ) max ( 0 , v k m ) 2 ( 4 )

当且仅当数字 k 出现在输入中 T k 才为1,这里 m + = 0.9 , m = 0.1 。这里 λ ,我们使用 λ = 0.5 ,总体损失仅仅就是所有数字胶囊边缘损失的总和。

4 胶囊网路的架构

一个简单的胶囊网络结构如下图所示,这个浅的网络只有两层卷基层和一个全联接层,conv1由256个尺寸为9x9、步长为1带有relu单元的卷积核。这一层将原始图像像素强度转化为局部特征探测器的活动,并且作为基础胶囊层的输入。

基础胶囊层多维实体的最低层次,从逆图形透视法来看,激活基础胶囊对应于逆向渲染过程。这种计算方式非常不同与将实例化部分拼凑成熟悉的整体,这就是胶囊被设计成擅长的地方。

第二层(基础胶囊层)是卷积胶囊层,有32个通道,每个通道由8维胶囊构成的。每个基础胶囊的输出都能看见所有256x81个conv1单元的输出,这些单元以胶囊的中心位置互相重叠。总之,基础胶囊有32x6x6个胶囊输出(每个输出是个8维向量)并且6x6网格中的每个胶囊都共享相同的权重。我们可以把基础胶囊看作使用公式(1)作为其非线性激活函数的卷积层。最后一层(数字胶囊)对每个数字有一个16维的胶囊,并且每个数字胶囊接受来自上一层所有胶囊的输出。

图1:一个简单的3层胶囊网络。这个网络给出了和深度卷积网络具有可比性的结果。数字胶囊层的每个胶囊的活动向量长度表示着每个类别的实体的存在性,并且被用来计算分类损失。 W i j 是一个和每个基础胶囊 u i i ( 1 , 32 × 6 × 6 ) 和每个 v j , j ( 1 , 10 ) 之间的权重矩阵。
图1
图2:用于从数字胶囊层的表示中重构数字的解码器结构。在训练阶段,我们最小化输入图片和sigmoid层输出之间的欧几里得距离。我们在训练时使用真正的标签作为重构目标。
这里写图片描述
我们仅仅在两个相邻的胶囊层之间作路由过程(比如这里的基础胶囊层和数字胶囊层)。既然conv1层的输出是1维的,而在一维空间中是没有什么方向需要达成一致的,因此在conv1层和基础胶囊层之间就不需要路由过程。所有的路由未归一化先验概率( b i j )全部初始化维0。所以,初始情况下,一个胶囊的输出(u_i)被以同等的概率( c i j )发送到所有的父胶囊( v 0 . . . v 9 )中。

我们使用tensorflow来实现这个网络,并且我们使用带有tensorflow默认参数指数衰退学习率的Adam优化器来最小化所有边缘损失的总和。

4.1 使用重构作为正则化方法

我们使用一个重构损失来激励数字胶囊编码数字的实例化参数。在训练时,我们掩盖了除了正确的数字胶囊以外的所有的数字胶囊的活跃向量。然后我们使用这个活跃向量来重构输入图像。数字胶囊的输出被馈入到一个用于建模像素强度的三层全连接的解码器网络中(如图2)。我们最小化逻辑斯蒂单元的输出和图像强度之间的误差平方和,我们将重构误差乘以0.0005来缩小它在误差总和中的分量,以至于它不会支配边缘损失。如图3所示,当仅保留重要的细节,从胶囊网络输出的16维表示中重构的图像是健壮的。

图3:从3次路由迭代的胶囊网络的测试重构中抽样的MNIST样本。(l,p,r)分别表示标签,预测,重构。最右边的两列表示两个重构失败的例子,它说明了模型是如何难以区分5和3的。其它的正确分类的列表示模型保存了许多细节的同时平滑了噪声。
这里写图片描述

5 用于分类手写数字的胶囊网络

我们在MNIST数据集的28 × 28图像上训练网络,我们仅仅在图像的每个方向上平移2个像素并使用0进行填充来扩充数据集,而并没有用到其它数据增强和形变方法。数据集有分别为60k和10k的训练集和测试集。

我们仅仅使用单个模型而不用任何模型平均方法来测试。 Wan et al. [2013]使用组合方法并使用旋转、伸缩的数据增强方法达到了0.21%的测试误差。而在放弃这些增强方法时测试误差为0.39%。我们在这样的三层网络中就得到了一个很低的测试误差(0.25%),而在这之前仅仅通过更深的网络才能达到这样的性能。表.1 展示了不同设置的胶囊网络的测试误差率,并且显示了路由和重构正则化的重要性。增加了的重构正则化器通过加强胶囊向量中的姿态编码提升了路由的性能。

基准是一个使用三层分别是256、256、128的卷积层标准CNN。每层使用5 × 5的核,步长为1。最后的卷积层后面跟着两个全连接层,尺寸分别为328、192。最后一个全连接层和10分类的softmax层之间使用了一个dropout层。基准模型同样也在平移2个像素的MNIST上使用Adam优化器训练,基准模型被设计为使用和胶囊网络同样的计算代价而能达到它的最优的性能。就参数数量而言,基准模型有高达35.4M的参数,而胶囊网络只有8.2M,除去重构子网络的话只有6.8M的参数数量。

5.1 胶囊的每个维度分别表示了什么?

既然我们仅仅允许通过一个数字的编码通过而掩盖其他的数字,那么一个数字胶囊的各个维度应该学会如何用这类数字是如何被实例化的方法来张成由数字所有的变种所组成的空间。这些变异包括笔画的粗细、偏斜度和宽度。同样也包括具体数字的变异,例如数字2的尾部的长度。我们可以利用解码网络来一窥每个维度表示的是什么特征。在计算出正确的数字胶囊的活动向量之后,我们可以将这个向量进行轻微扰动,之后再馈入解码网络并且观察这样的扰动对重构有着什么样的影响。这些扰动的影响通过图.4来展示。经过观察我们发现这16个维度中的某一维度几乎总是表示着数字笔画的宽度。而其他某一些维度代表着一些全局变异的组合,并且有些维度表示着一个数字局部区域的变异。例如数字6的上伸部分的长度和下面的圆圈的尺寸是分别使用不同的维度来表示的。

5.2 对仿射变换的健壮性

实验表明相比传统的卷积网络,每个数字胶囊对每个类别学习一个更加健壮的表示。因为在手写数字的偏曲度、旋转、风格等等方面存在着很自然的变化,训练好的胶囊网络对训练数据轻微的放射变换有着适度的健壮性。为了测试胶囊网络对仿射变换的健壮度如何,我们在平移和补零过的MNIST训练集上训练一个胶囊网络和一个传统的卷积网络(使用最大池化和dropout),在这里每个MNIST数字被随机的防盗一个黑色背景的40 × 40的像素中。我们又在affNIST4数据集中测试这个网络,这里的样例是MNIST数字经过随机的小的仿射变换处理过的。我们的模型从未在除了平移和任何可以在标准MNIST数据集中能看到的自然的变换以外的放射变换数据集中进行训练。一个使用过早停止的未经过充分训练的胶囊网络在扩展的MNIST测试集上达到了99.23%的测试准确率,在affNIST测试集上达到了79%的准确率。一个有着差不多数量参数的传统卷积模型在扩展的MNIST测试数据上达到了类似的99.23%准确率,但在affNIST测试集上仅达到66%的准确率。

6 分割高度重叠的数字

7 其它的数据集

我们在cifar10数据集上测试我们的胶囊模型,在使用7个模型组成的组合模型在24x24图片块上使用3次路由迭代,获得了10.6%的误差。在这里,除了输入是3通道彩色图,还有我们使用64个不同类型的基础胶囊以外,我们的模型和上面MNIST的简单模型是一样的。我们同样发现为路由softmax引入一个哑类别是很有帮助的,因为我们并不期望最后一层的十个胶囊来解释图片中的所有信息。10.6%测试误差率其实是当年标准的卷积网络被首次应用到cifar10数据集上时所达到的精度。

胶囊网络和生成式模型有着一个相同的缺点,为了获得更好的效果,它会倾向于记录图像中的所有信息,它会建模记录一堆杂乱的信息,而不是直接在动态路由中添加一个“孤儿”类别。在CIFAR-10中,背景的变化太多以至于无法使用一个合理尺寸的网络来建模,而这就导致了更差的性能。

我们同样在smallNORB数据上测试了和上面完全一样架构的网络,并且获得了2.7%的测试误差,这和之前人们工作的最好的结果打了个平手。smallNORB数据集由96x96的立体的灰度图组成。我们将其尺寸改成48x48,并且在训练过程中,随机地从中剪切出32x32的图片块。而在测试时,我们使用的是中心的32x32图片块。我们同样在只有73257个图片的小的数据集SVHN上训练了更小的网络。我们将第一个卷积层的通道数量缩减到64,基础胶囊层缩减到16个6维的胶囊,最后一个胶囊层有着8维的胶囊向量,并且获得了4.3%的测试误差。

一些探讨和先前的工作

前30年中,最好的语音识别算法是使用混合高斯输出分布的隐马尔可夫模型。这种算法很容易在小型计算机上训练,但是它有表示致命的缺点,即表示上的局限性:它的一位有效编码(one-of-n representations)表示和其它使用分布式表示(distributed representation)的模型,例如循环神经网络相比较是指数级别的低效的。我们需要将隐藏节点增加到2次方级别的数量,才能使得HMM记忆两倍于它到目前所生成的序列的信息量。而对于循环网络,我们仅仅需要双倍的隐藏神经元即可。

如今卷积网络已经成为支配图像识别领域的算法,那么我们有理由来想想看在这个领域是否也存在可能导致消亡的指数级别的低消率。一个好的候选就是卷积网络泛化到新的视角点的困难度,对于卷积网络来说,可处理平移是内建的,但是对于处理其它维度仿射变换,我们要么就在网格中增加维度的指数级别的特征探测器的副本,要么就增加指数级别的带标签数据集。胶囊通过转化像素强度为识别到的零部件的实例化参数向量,再通过应用变换矩阵到这些零部件来预测更大的零部件的实例化参数这样的方法来避免指数级别的低效性。这些变换矩阵学习到将零部件和整体之间的本质的关系编码,构成了空间视觉变换不变知识,这些知识可以自动的泛化到新的视角。 Hinton et al. [2011]提出了变换自编码器来生成基础胶囊层的实例化参数,和由外部提供的系统所需的变换矩阵。在这里,我们提出了一个完备的系统并且回答了“更大更复杂的实体是如何通过使用由低层次活跃的胶囊预测的多个物体姿态的一致性来识别的”

胶囊作出了几个非常强的表示上的假设:在图像的一个位置,至多只存在一个胶囊表示的实体类型的实例。这个假设受一种称为“群集”的感知现象启发,消除了绑定问题,并且允许胶囊使用一个分布式表示(它的活动向量)来编码在给定位置的这种类型的一个实体的实例化参数。这种分布式的表示比较用激活一个高维度网格中的某一点,是具有指数级别的高效率,并且伴随着正确的分布式表示,胶囊可以充分地利用空间关系可以通过乘以矩阵来建模这样的事实。

胶囊利用随着视角变化而变化的神经元活动,而不是尝试在神经元活动中消除这样的视角变化。这给某些网络在“规范化”方法上的有着优势,例如空间变换网络:它可以同时处理不同物体或不同部件的不同的放射变换。

胶囊同样也很善于处理机器视觉中另一个最顽固的难题,即图像分割。因为实例化参数向量使得它们可以通过意见一致路由法,如本文所述。动态路由过程的重要性是以视觉脑皮层中不变模式识别这样的生物学上貌似合理的模型为支撑的。Hinton [1981b] 提出了动态链接和生成用于物体识别的形状描述的规范的基于物体的框架。Olshausen et al. [1993] 基于 Hinton [1981b]改善了动态连接和表示一个神物学上的合理性,位置和尺度不变物体表示模型。

如今关于胶囊的研究就好比本世纪初关于循环神经网络在语音识别方面的研究状况。这里有着一些基本的表示上的原因使得我们相信使用胶囊是一种更好的途径,但是它在打败当今高度发展好的技术之前可能需要许多更加细致的探索。事实上一个简单的胶囊系统已经在分割重叠的数字任务上给出了无与伦比的性能,这个初步的事实表明胶囊是一个值得探索的方向。

附录

A 需要使用多少次路由迭代?

为了通过实验来验证路由算法的收敛性,我们在每个路由迭代步骤中绘制未规范化概率(logits)的平均变化。图A.1显示了每个路由步中平均的 b | i j 的变化。通过实现观察,我们发现从开始训练到仅仅5次路由迭代后,logits的变化就已经小的可忽略不计了。在训练500回合后,由第二次路由调整的logits的平均变化已经小于0.007了,而由第五次路由引起的logits平均变化仅为1e-5。

图A.1:每次迭代中每个路由logit( b i j )的平均变化。在MNIST数据集上训练500个回合后的平均变化已经趋向稳定,并且如右图所示:平均变化在对数空间中随着更多的路由迭代次数几乎呈线性变化。
这里写图片描述

我们发现一般情况下,更多的迭代次数会增加模型的容量,并且在训练数据集上趋向于过拟合。图.A.2展示了在cifar10数据集上使用1次路由迭代和3次路由迭代的训练损失值的对比图。受图.A.2和图.A.1诱导,我们建议在所有的实验中都使用3次迭代。

图 A.2:在cifar10数据集上的训练损失。批量大小为128。使用3次路由迭代的胶囊网络优化损失函数的时候更快速,并且最终收敛到一个更低的损失值。
这里写图片描述


收藏一个不错的教学篇:链接


这是一个大神的实现:源码
我从中摘取了一下代码主干,试着结合论文一步步理解:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import sys
import numpy as np

def squash(s):
  s_norm = tf.norm(s, axis=-2, keep_dims=True)
  v = s*s_norm/(1+s_norm**2)
  return v

MAX_ROUTING_ITER = 3
M_PLUS = 0.9
M_MINUS = 0.1
LAMBDA = 0.5
SCALE_COEF = 0.003921568
BATCH_SIZE = 20
MAX_TRAIN_STEP = 25000
REGULAR_COEF = 0.0005
mnist = input_data.read_data_sets('./MNIST_data', one_hot=True)

input_X = tf.placeholder(tf.float32, shape=[BATCH_SIZE, 784])
input_Y = tf.placeholder(tf.float32, shape=[BATCH_SIZE, 10])

X = tf.multiply(input_X, SCALE_COEF)
X = tf.reshape(X, [-1, 28, 28, 1])

# conv1 形状为 20,20,256
Conv1 = tf.contrib.layers.conv2d(inputs=X,
                                 num_outputs=256,
                                 kernel_size=[9,9],
                                 stride=1, 
                                 padding='VALID',
                                 activation_fn=tf.nn.relu)


# 基础胶囊层 PrimaryCaps (20-9+1)/2形状 batch,6,6,256
PrimaryCaps = tf.contrib.layers.conv2d(inputs=Conv1,
                                       num_outputs=256,
                                       kernel_size=[9,9],
                                       stride=2, 
                                       padding='VALID',
                                       activation_fn=tf.nn.relu)

# 基础胶囊层 PrimaryCaps [batch,32*6*6,10,8,1]
PrimaryCaps = tf.reshape(PrimaryCaps, [-1,32*6*6,1,8,1])
PrimaryCaps = squash(PrimaryCaps)
u = tf.tile(PrimaryCaps, [1,1,10,1,1])

# 数字胶囊层 DigitCaps [batchSize, 10]
# W shape [batchSize, 32*6*6, 10, 16, 8]
W = tf.get_variable(name='W',
                    shape=[1,32*6*6,10,16,8],
                    dtype=tf.float32,
                    initializer=tf.truncated_normal_initializer(stddev=0.1))
W = tf.tile(W, [BATCH_SIZE,1,1,1,1])

# u_hat shape [batchSize, 32*6*6, 10, 16, 1]
u_hat = tf.matmul(W, u)

b = tf.constant(np.zeros([BATCH_SIZE, 1152, 10, 1, 1], dtype=np.float32))

u_hat_stopped = tf.stop_gradient(u_hat, name='stop_gradient')

for iter_r in range(MAX_ROUTING_ITER):
  c = tf.nn.softmax(b, dim=2)
  if iter_r == (MAX_ROUTING_ITER - 1):
    s = c * u_hat
    s = tf.reduce_sum(s, axis=1, keep_dims=True)
    # v shape [batchSize,1, 10, 16, 1 ]
    v = squash(s)
  else:
    s = c * u_hat_stopped
    s = tf.reduce_sum(s, axis=1, keep_dims=True)
    # v shape [batchSize,1, 10, 16, 1 ]
    v = squash(s)

    # 计算预测一致性度量 ,v_tile [batchSize, 32*6*6,10,16,1]; agreements [batch,32*6*6,10,1,1]
    v_tile = tf.tile(v, [1,32*6*6,1,1,1])
    agreements = tf.matmul(u_hat_stopped, v_tile, transpose_a=True)

    # 更新路由系数
    b += agreements

# v shape [batchSize, 10, 16, 1 ]
v = tf.squeeze(v, axis=1)
# v_nrom shape [batchSize, 10, 1, 1 ]
v_norm = tf.norm(v, axis=2)

max_index = tf.to_int64(tf.argmax(v_norm, axis=1))
# v_nrom shape [batchSize, 1, 1, 1 ]
max_index = tf.reshape(max_index, shape=(BATCH_SIZE, ))

# max_plus shape [batch, 1, 1, 1]
max_plus = tf.reshape(tf.square(tf.maximum(0., M_PLUS-v_norm)), [BATCH_SIZE, -1])
max_minus = tf.reshape(tf.square(tf.maximum(0., v_norm-M_MINUS)), [BATCH_SIZE, -1])

L = input_Y * max_plus + 0.5 * (1 - input_Y) * max_minus
margin_loss = tf.reduce_mean(tf.reduce_sum(L, axis=1))

mask_v = tf.reshape(input_Y, [BATCH_SIZE, 10, 1]) * tf.squeeze(v)
flat_v = tf.reshape(mask_v, [BATCH_SIZE, -1])
fc1 = tf.contrib.layers.fully_connected(flat_v, num_outputs=512)
fc2 = tf.contrib.layers.fully_connected(fc1, num_outputs=1024)
decoded = tf.contrib.layers.fully_connected(fc2, num_outputs=784, activation_fn=tf.sigmoid)

recons_loss = tf.reduce_mean(tf.reduce_sum(tf.square(decoded-input_X), axis=1))

total_loss = margin_loss + REGULAR_COEF * recons_loss

correct_prediction = tf.equal(tf.argmax(input_Y, axis=1), max_index)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

optimizer = tf.train.AdamOptimizer()
global_step = tf.Variable(0, name='global_step', trainable=False)
train_op = optimizer.minimize(total_loss)
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for i in range(1000):
    batch_X, batch_y = mnist.train.next_batch(BATCH_SIZE)
    _, train_acc = sess.run([train_op, accuracy], feed_dict={input_X: batch_X, input_Y: batch_y})
    sys.stdout.write('\rTrain step: %d train acc:%.2f' % (i, train_acc))

    if (i+1) % 500 == 0:
      valid_X, valid_y = mnist.validation.next_batch(BATCH_SIZE)
      valid_acc = sess.run(accuracy, feed_dict={input_X: valid_X, input_Y: valid_y})
      sys.stdout.write('\nTrain step: %d, valid accuracy: %.3f\n' % (i, valid_acc))

猜你喜欢

转载自blog.csdn.net/mask_fade/article/details/80147101