SimMatch 论文分享

SimMatch: Semi-supervised Learning with Similarity Matching

现有的半监督学习策略:

  1. 在一个大规模数据集上进行预训练,并用少量标签数据微调
    (缺点在于,完全未利用到标签信息)
  2. 使用有标签数据训练一个语义分类器,并通过这个分类器为无标签数据生成伪标签(pseudo label)。伪标签通常由弱视图或多个增强视图的平均预测产生。最后目标通过多个强增强视图和伪标签之间的交叉熵来构建
    (缺点在于,当标签数据十分有限时,所训练的语义分类器并不可靠,由此生成的伪标签将会出现“overconfidence”问题,即模型会去拟合那些置信度很好但是错误的伪标签,由此导致性能下降)

本文方法

  1. 首先,希望强增强视图和弱增强视图具有相同的语义相似性(预测的标签)
  2. 强增强视图与弱增强视图具有相同的实例特征(即实例之间的相似性),以便于进行更多的内在特征匹配

在这里插入图片描述

方法

语义相似性

对于有标签样本,

  1. 对batch中所有样本随机应用一种弱增强 T w ( ⋅ ) T_w(\cdot) Tw()(例如旋转或裁剪)
  2. 用一个encoder F ( ⋅ ) F(\cdot) F()提取特征信息,即 h = F ( T ( x ) ) h=\mathcal{F}(T(x)) h=F(T(x))
  3. 采用一个全连接类别预测头 ϕ ( ⋅ ) \phi(\cdot) ϕ() h b \mathbf{h}_b hb映射为语义相似度 ,即 p = ϕ ( h ) p=\phi{(\mathbf{h})} p=ϕ(h)

其中,有标签样本直接使用交叉熵损失进行优化:
L s = 1 B ∑ H ( y , p ) \mathcal{L}_s=\frac{1}{B}\sum{\mathrm{H}(y,p)} Ls=B1H(y,p)
对于无标签样本,

  1. 随机应用弱增强或强增强的一种,并使用和有标签样本相同的处理方式,得到语义相似度 p w p^w pw p s p^s ps

  2. 计算两种标签之间的无监督损失:
    L u = 1 μ B ∑ 1 ( max ⁡ D A ( p w ) > τ ) H ( D A ( p w ) , p s ) \mathcal{L}_{u}=\frac{1}{\mu B} \sum \mathbb{1}\left(\max D A\left(p^{w}\right)>\tau\right) \mathrm{H}\left(D A\left(p^{w}\right), p^{s}\right) Lu=μB11(maxDA(pw)>τ)H(DA(pw),ps)
    这里 τ \tau τ作为置信度阈值,且仅保留在伪标签中最大类别概率大于 τ \tau τ无标签样本 D A ( ⋅ ) DA(\cdot) DA()表示分布对齐策略,用于平衡伪标签的分布。

实例相似性

目的是希望强增强视图与弱增强视图具有类似的相似性分布。

  1. 这里引入一个非线性映射头 g ( ⋅ ) g(\cdot) g(),能将特征表示 h \mathbf{h} h映射为一个低维嵌入,即 z b = g ( h b ) \mathbf{z}_b=g(\mathbf{h}_b) zb=g(hb)

  2. 遵循基于anchoring的方法,这里将 z b w \mathbf{z}^w_b zbw z b s \mathbf{z}^s_b zbs分别表示为来自弱增强和强增强的嵌入

  3. 现在,假设对于一簇不同的样本 z k : k ∈ ( 1 , … , K ) {\mathbf{z}_k:k \in(1, \dots,K)} zk:k(1,,K),具有K个弱增强嵌入,使用相似度函数 s i m ( ⋅ ) sim(\cdot) sim()计算 z w \mathbf{z}^w zw和第i个实例 z i \mathbf{z}_i zi之间的相似度:
    sim ⁡ ( u , v ) = u T ^ v ∥ u ∥ ∥ v ∥ \operatorname{sim}(\mathbf{u}, \mathbf{v})=\frac{\mathbf{u}^{\hat{T}} \mathbf{v}}{ \|\mathbf{u}\|\|\mathbf{v}\|} sim(u,v)=uvuT^v
    使用softmax函数处理相似度计算结果,得到相似度分布:
    q i w = exp ⁡ ( sim ⁡ ( z b w , z i ) / t ) ∑ k = 1 K exp ⁡ ( sim ⁡ ( z b w , z k ) / t ) q_{i}^{w}=\frac{\exp \left(\operatorname{sim}\left(\mathbf{z}_{b}^{w}, \mathbf{z}_{i}\right) / t\right)}{\sum_{k=1}^{K} \exp \left(\operatorname{sim}\left(\mathbf{z}_{b}^{w}, \mathbf{z}_{k}\right) / t\right)} qiw=k=1Kexp(sim(zbw,zk)/t)exp(sim(zbw,zi)/t)
    其中 t t t为温度系数,用于控制分布的平滑程度。

  4. 同样计算 z s \mathbf{z}^s zs z i \mathbf{z}_i zi之间的相似度分布:
    q i s = exp ⁡ ( sim ⁡ ( z b s , z i ) / t ) ∑ k = 1 K exp ⁡ ( sim ⁡ ( z b s , z k ) / t ) q_{i}^{s}=\frac{\exp \left(\operatorname{sim}\left(\mathbf{z}_{b}^{s}, \mathbf{z}_{i}\right) / t\right)}{\sum_{k=1}^{K} \exp \left(\operatorname{sim}\left(\mathbf{z}_{b}^{s}, \mathbf{z}_{k}\right) / t\right)} qis=k=1Kexp(sim(zbs,zk)/t)exp(sim(zbs,zi)/t)

  5. 最后,可以通过最小化 q s q^s qs q w q^w qw之间的差异实现一致性正则化(consistency regularization),这里采用交叉熵损失实现:
    L i n = 1 μ B ∑ H ( q w , q s ) \mathcal{L}_{i n}=\frac{1}{\mu B} \sum \mathrm{H}\left(q^{w}, q^{s}\right) Lin=μB1H(qw,qs)
    这里需要注意的是,这种实例的一致性正则化只应用于无标签样本。

最终的损失函数为:
L overall  = L s + λ u L u + λ i n L i n \mathcal{L}_{\text {overall }}=\mathcal{L}_{s}+\lambda_{u} \mathcal{L}_{u}+\lambda_{i n} \mathcal{L}_{i n} Loverall =Ls+λuLu+λinLin
其中, λ u \lambda_u λu λ i n \lambda_{in} λin是控制两种损失权重的平衡因子。

通过SimMatch 进行标签传播

由于上述过程完全未利用到标签信息,这里进一步介绍一种能利用到标签信息的方法,并允许语义相似性和实例相似性相互交互。

在这里插入图片描述

具体做法:

  1. 实例化一个带标签的内存缓冲区,用于存放所有标注的样本( q i w q_i^w qiw q i s q_i^s qis),这样使得每个用到的样本都被指定一个特定的类别。

    如果我们将 ϕ ( ⋅ ) \phi(\cdot) ϕ()中的向量(由有标签样本生成)解释为"中心化"的类引用,那么我们标记的内存缓冲区中的embedding(由无标签样本生成)可以看作是实例个体引用的集合。

  2. 给定一个弱增强样本,文中首先计算它的语义相似度(可以认为是类别标签) p w ∈ R 1 × L p^w \in \mathbb{R}^{1 \times L} pwR1×L和实例相似度 q w ∈ R 1 × K q^w \in \mathbb{R}^{1 \times K} qwR1×K(这里L一定是远小于K的,因为文中希望每个类别至少具有一个样本(L即为类别数))

    使用语义相似度来校准实例相似度

  3. 为了使用 p w p^w pw 校准 q w q^w qw,我们需要将 p w p^w pw展开到 K K K维空间,文中将其表示为 p u n f o l d p^{unfold} punfold。文中通过为每个已标记的嵌入匹配相应的语义相似性来实现这一点,即:
    p i u n f o l d = p j w ,  where  class ⁡ ( q j w ) = class ⁡ ( p i w ) p_{i}^{u n f o l d}=p_{j}^{w}, \text { where } \operatorname{class}\left(q_{j}^{w}\right)=\operatorname{class}\left(p_{i}^{w}\right) piunfold=pjw, where class(qjw)=class(piw)
    其中, c l a s s ( ⋅ ) class(\cdot) class()是返回ground truth类别的函数。

    具体来说, c l a s s ( q j w ) class(q^w_j) class(qjw)表示内存缓冲区中第j个元素的标签, c l a s s ( p i w ) class(p^w_i) class(piw)表示第i个类。

  4. 接下来,通过使用 p u n f o l d p^{unfold} punfold q w q^w qw进行缩放来重新生成校准后的实例的伪标签,可以表示为如下形式:
    q ^ i = q i w p i u n f o l d ∑ k = 1 K q k w p k unfold  \widehat{q}_{i}=\frac{q_{i}^{w} p_{i}^{u n f o l d}}{\sum_{k=1}^{K} q_{k}^{w} p_{k}^{\text {unfold }}} q i=k=1Kqkwpkunfold qiwpiunfold

  5. 将校准后的伪标签 q ^ \hat{q} q^作为新的目标并替代之前计算损失 L i n \mathcal{L}_{in} Lin中的 q w q^w qw

    使用实例相似度调整语义相似度

  6. 首先将 q q q汇聚到 L L L维空间,记为 q a g g q^{agg} qagg,通过对具有相同ground-truth的实例求和进行实现:
    q i a g g = ∑ j = 0 K 1 ( class ⁡ ( p i w ) = class ⁡ ( q j w ) ) q j w q_{i}^{a g g}=\sum_{j=0}^{K} \mathbb{1}\left(\operatorname{class}\left(p_{i}^{w}\right)=\operatorname{class}\left(q_{j}^{w}\right)\right) q_{j}^{w} qiagg=j=0K1(class(piw)=class(qjw))qjw

  7. 通过使用 q a g g q^{agg} qagg平滑 p w p^w pw重新生成调整过的语义伪标签,:
    p ^ i = α p i w + ( 1 − α ) q i a g g \widehat{p}_{i}=\alpha p_{i}^{w}+(1-\alpha) q_{i}^{a g g} p i=αpiw+(1α)qiagg
    其中 α \alpha α作为超参数控制语义信息和实例信息的权重

  8. 同样地,将校准后的伪标签 p ^ \hat{p} p^作为新的目标并替代之前计算损失 L u \mathcal{L}_{u} Lu中的 p i w p^w_i piw

  9. 此时,伪标签 p ^ \hat{p} p^ q ^ \hat{q} q^便都具有语义和实例级别的信息

在这里插入图片描述
其意义在于:当语义相似度和实力相似度接近时,意味着两个分布与彼此的预测一致,由此生成的伪标签将具有更高的置信度,从而更加可靠

在这里插入图片描述
整个训练过程如图所示。

细节实现

unfold操作

batch_u = 1
num_class = 10
K = 256 

prob_ku_orig = torch.zeros((batch_u, num_class)) #(1, 10)
labels = torch.zeros(K, dtype=torch.long) #(256, )
index = labels.expand([batch_u, -1]) #(1, 256)

factor = prob_ku_orig.gather(1, index) # p^{unfold}
print(prob_ku_orig.shape, factor.shape)
# torch.Size([1, 10]) torch.Size([1, 256])

# e.g.
prob = torch.tensor([[0.15, 0.8, 0.05]])
print(prob)
labels = torch.tensor([1,0,0,2,1])
index = labels.expand([1, -1])
prob.gather(1, index)
# tensor([[0.8000, 0.1500, 0.1500, 0.0500, 0.8000]])

aggregate操作

bs = teacher_prob_orig.size(0) # batch_u
aggregated_prob = torch.zeros([bs, self.num_classes], device=teacher_prob_orig.device)
aggregated_prob = aggregated_prob.scatter_add(1, self.labels.expand([bs,-1]), teacher_prob_orig) #q^{agg}        

猜你喜欢

转载自blog.csdn.net/qq_45802280/article/details/127897761