失眠网,内容丰富有趣,生活中的好帮手!
失眠网 > MoCo 动量对比学习——一种维护超大负样本训练的框架

MoCo 动量对比学习——一种维护超大负样本训练的框架

时间:2018-12-15 07:55:43

相关推荐

MoCo 动量对比学习——一种维护超大负样本训练的框架

MoCo 动量对比学习——一种维护超大负样本训练的框架 FesianXu 0803 at Baidu Search Team

前言

在拥有着海量数据的大型互联网公司中,对比学习变得逐渐流行起来,大家都拿它进行表征学习的探索。本文对MoCo这篇论文进行笔记,希望对读者有所帮助。

如有谬误请联系指出,本文遵守 CC 4.0 BY-SA 版权协议,转载请联系作者并注明出处,谢谢

∇\nabla∇ 联系方式:

e-mail: FesianXu@

github: /FesianXu

知乎专栏: 计算机视觉/计算机图形理论与应用

微信公众号:

对比学习(Contrastive Learning)是最近在表征学习(Representation Learning)中非常流行的学习范式。对比学习不同于一般的分类模型,考虑对数据进行多个类别的分类,在对比学习中,我们通常需要考虑的只有两类:正样本和负样本。举个实际应用系统的例子,在搜索系统中的QU相关性中[2],我们通常需要判断一个Query和一个Doc之间的相关性,如果相关的我们认为是正样本,而不相关的我们认为是负样本。目前很多模型都可以基于对比学习设计预训练任务,并且得到非常可观的性能提升,比如CLIP模型[3]。对于存有着海量用户行为数据(比如用户搜索的点击数据,浏览数据,电商的购买数据,停留时间等)的大型互联网公司而言,在对比学习这个方向有着更大的探索需求。

“正样本很多是正得相似,但是负样本却负得各有不同”,这句话虽然不太准确,但是描述了对比学习中的本质困难:如何去寻找合适的,更多的负样本去进行训练,从而让模型学习出足够能区分正负样本的特征?在Kaiming大佬的MoCo[1]这篇工作中,大佬给出了他的解决方案。

对比学习中的海量负样本需求

在对比学习中,对负样本有着海量的需求。在搜索场景中,我们会想尽办法去构造尽可能多的负样本。比如将展现了但是并没有被用户点击的样本当作是负样本,比如将同为点击了的样本混搭组成batch negative负样本。如Fig 1.1所示,通过对Query的特征和Doc的特征进行组合,如式子(1.1)所示,可以构建出打分矩阵S∈RN×N\mathbf{S} \in \mathbb{R}^{N \times N}S∈RN×N,其中NNN为batch size,DDD为特征大小。

S=Q×U∈RN×NQ∈RN×DU∈RN×D(1.1)\mathbf{S} = \mathbf{Q} \times \mathbf{U} \in \mathbb{R}^{N \times N} \\ \mathbf{Q} \in \mathbb{R}^{N \times D} \\ \mathbf{U} \in \mathbb{R}^{N \times D} \tag{1.1} S=Q×U∈RN×NQ∈RN×DU∈RN×D(1.1)

那么可知道只有对角线的打分才是正样本,而其他绿色部分的打分Qi⋅Uj,i≠jQ_i \cdot U_j, i \neq jQi​⋅Uj​,i​=j则是负样本打分,如果负样本的打分越高,那么就可以认为是难负样本。通过batch negative从一个batch内构造的负样本,因为这些Doc都是被用户点击过的,而且Query都是被搜索过的,通过对他们进行混搭构成的负样本足够“真实”。而通过无点数据构成的负样本则可能因为是展现的Doc过多,已经满足了用户的需求而没有被点击,因此可能出现将正样本误分为负样本的情况。当然,这个和具体的业务有关。

Fig 1.1 Batch Negative的方式从一个batch中构造负样本。

不管怎么说,通过以上构造负样本的方式,我们有能力构造足够多的负样本。特别是通过batch negative的方式,在batch size为NNN的情况下我们能构造出N2−NN^2-NN2−N的负样本,其中很可能就会出现足够“好”的负样本让模型学习的好的表征。但是问题在于,我们受限于显存的大小,我们无法无限增大batch size,即便通过多机多卡分布式的方法,每张卡上的batch size大小也是有极限的,即便通过某些机制去汇总多卡之间的负样本打分,在汇总阶段也会遇到瓶颈。因此目前能看到的公开工作中,据笔者了解,batch size最大的就是CLIP[3],开到了32,768。然而如果继续增大呢?可能就会遇到无法逾越的硬件瓶颈了。

Kaiming大佬在MoCo中,认为对比学习可以从某种形式看成是“字典查询”(Dictionary look-up),这一点不难理解,我们的正样本可以认为是分布在诸多的负样本之中,而输入一个Query,那么我们期望能匹配到正确的Doc(也是字典中的Key)。一次理想的匹配应该是:query尽可能相似于正样本的Key,而和其他负样本的Key尽可能不相似。从这个角度上看,我们能发现对比学习的一个根本难题在于:

如何提高这个字典的词表大小。如何提高字典的连续性(consistency)。

在NLP任务中,特别是word2vec这种在所有有限词表中选取负样本的方法,因为语言的词表显然是个离散有限集合,可能也就两万来个Key,因此对比学习在NLP中可能更好地进行。而对于视觉任务而言,因为视觉信息是连续的,很难将其划分为离散的有限Key集合将其覆盖。因此视觉任务的对比学习会比NLP更难。

让我们再总结下,对比学习的两个难题,第一个其实就可以任务是如何去增大负样本的数量;第二个其实可以认为是如何稳定Query和Key的编码器的一致性 (为了表述的通用性,我们将会用Key表示匹配任务中的Doc,Item端)。要知道,Query端和Key端的编码器可以是一样的,也可以是不一样的;可以是共享参数的,也可以是独立参数的。因此在参数更新过程中,要保证两端编码器状态的一致性。 (本文考虑的是双塔模型,因此交互式模型暂时不考虑。)

我们下面对已有的一些解决负样本数量的方法进行概述,以便于了解提出MoCo的历史渊源。

端到端方式

最为传统的提高负样本(也就是词典的词表)的方式就是提高Batch size大小,然后在更新参数阶段同时对Query和Key的编码器进行更新。这种方式可以最大程度上保证Query和Key编码器的一致性,但是显然会遇到我们之前讨论过的硬件瓶颈。

Fig 1.2 通过端到端的方式进行对比学习。

Memory Bank

一种直接的想法是将词典词表大小和batch大小进行解耦,从而使得词表不受限于batch大小。在[4]中,作者提出Memory Bank框架,该框架的主要特点在于维护一个队列作为负样本集合。在Memory Bank框架中,首先需要对数据集所有数据进行刷特征向量,然后在每个mini batch中从中随机挑选负样本作为词典,但是并不进行梯度回传,也就是不会对Key编码器进行参数更新。当然,为了优化特征,会对Query编码器进行参数更新,然后将当前Query编码器的特征入到负样本队列,并且将负样本队列中的旧特征出队,从而实现对负样本队列的更新。这个负样本队列称之为Memory Bank。

然而这个过程中,虽然这个方法满足了“词典词表足够大”的条件,但是“QK编码器一致性”没有保证,因为Key编码总是落后于Query编码的,这个过程中无法保证Query-Key匹配的一致性。可以认为Memory Bank提出的编码器更新方式是“异步”逐渐更新的。

Fig 1.3 Memory Bank框架示意图,主要通过维护一个负样本队列实现大尺度范围的负样本训练。

MoCo

kaiming大佬提出MoCo解决了以上的问题。首先进行符号化表达,考虑一个Query qqq和一系列编码后的样本{k0,k1,⋯,kn}\{k_0, k_1, \cdots,k_n\}{k0​,k1​,⋯,kn​},这些样本是词典的key,假设这个词典中只有唯一的k+k_{+}k+​和qqq匹配,那么在对比学习中,只有qqq和k+k_{+}k+​匹配的时候,损失函数才是最低的。利用infoNCE (info Noise Contrastive Estimate)损失,我们可以对这个过程建模为(1.2)所示。

Lq=−log⁡exp⁡(q⋅k+/τ)∑i=0Kexp⁡(q⋅ki/τ)(1.2)\mathcal{L}_{q} = -\log{\dfrac{\exp(q \cdot k_{+} / \tau)}{\sum_{i=0}^K \exp(q \cdot k_i / \tau)}} \tag{1.2} Lq​=−log∑i=0K​exp(q⋅ki​/τ)exp(q⋅k+​/τ)​(1.2)

其中的τ\tauτ是温度系数,而KKK是负样本个数。我们用fq(⋅)f_q(\cdot)fq​(⋅)表示query编码器,fk(⋅)f_{k}(\cdot)fk​(⋅)表示key编码器,那么有q=fq(xq),k=fk(xk)q=f_{q}(x^q), k=f_k(x^k)q=fq​(xq),k=fk​(xk),用θq,θk\theta_q, \theta_kθq​,θk​分别表示这两个编码器的参数。

Fig 1.4 MoCo框架示意图。

为了保持大规模的词典词表大小,和Memory Bank一样需要保持一个负样本队列,不同点在于对Key编码器的更新方式不再采用“异步”的方式,而是采用了“同步”的方式。一种最为朴素的同步更新方式是:直接将当前step更新好的Query编码器的参数拷贝到Key编码器中(当然前提是QK编码器结构一致),如式子(1.3)所示。然而这种方式将会导致Key编码器训练状态的完全紊乱,因此无法训练将会一直震荡,那么通过引入动量进行QK状态协调就是一个自然而然的方案,如式子(1.4)所示。

θk←θq(1.3)\theta_k \leftarrow \theta_q \tag{1.3} θk​←θq​(1.3)

θk←mθk+(1−m)θq(1.4)\theta_k \leftarrow m\theta_k +(1-m)\theta_q \tag{1.4} θk​←mθk​+(1−m)θq​(1.4)

引入动量,将Key编码器的参数更新,以一种合适地方式引入了Query编码器的状态。此处的mmm通常是一个比较大的值,比如0.99。从Fig 1.5中,我们可以看到不同大小的动量对结果的影响,其中0.999是最为合适的。这也告诉我们,保持QK一致性很重要,同时保持Key编码器训练的一致性也很重要。

Fig 1.5 不同大小的动量大小对结果的影响。

作者在文章中还以伪代码的形式举了一个应用例子,如Alg 1所示。

Alg 1. 以PyTorch的方式写的MoCo流程伪代码。

# f_q, f_k: encoder networks for query and key# queue: dictionary as a queue of K keys (CxK)# m: momentum# t: temperaturef_k.params = f_q.params # initializefor x in loader: # load a minibatch x with N samplesx_q = aug(x) # a randomly augmented versionx_k = aug(x) # another randomly augmented versionq = f_q.forward(x_q) # queries: NxCk = f_k.forward(x_k) # keys: NxCk = k.detach() # no gradient to keys# positive logits: Nx1l_pos = bmm(q.view(N,1,C), k.view(N,C,1))# negative logits: NxKl_neg = mm(q.view(N,C), queue.view(C,K))# logits: Nx(1+K)logits = cat([l_pos, l_neg], dim=1)# contrastive loss, Eqn.(1)labels = zeros(N) # positives are the 0-thloss = CrossEntropyLoss(logits/t, labels)# SGD update: query networkloss.backward()update(f_q.params)# momentum update: key networkf_k.params = m*f_k.params+(1-m)*f_q.params# update dictionaryenqueue(queue, k) # enqueue the current minibatchdequeue(queue) # dequeue the earliest minibatch

在这个例子中,作者将同一个图片的不同数据增广样本当成是正样本,而其他所有的图片都是负样本。在代码最后,保持了对队列的更新(通过将新的minibatch入队列和将旧的出队列实现),并且通过动量更新实现了对QK编码器的更新。

作者对比了端到端,Memory Bank和MoCo框架的效果,如Fig 1.6所示,我们可以看到几个值得注意的现象:

端到端的方式无法一直提高负样本大小,可以发现MoCo和端到端在256-1024的负样本大小阶段的结果可以匹配(track)上的。MoCo和Memory Bank都可以将负样本大小提高到很大。但是MoCo的效果一直都比Memory Bank高约2%点。

这个现象也和之前的分析保持一致。

Fig 1.6 三种框架的实验结果对比。

Shuffling BN

作者在文中提到了一嘴“Shuffling BN”,而这似乎是在本文才引出来的概念,我们在这儿讨论一下。在实践中,研究者发现在对比学习中的编码器使用Batch Normalization可能会使得模型学习不到好的表征[5],之前的研究会采取不采用BN层的方法。在本工作中,作者认为编码器中的传统BN层可能会影响预训练任务,进而让预训练很容易就找到了类似于平凡解的“低损失解”,而这将会使得预训练失效。

这个过程在“同步”更新的对比学习框架中容易出现,比如如Fig 1.7所示,其中BN层将会在N-F面上进行标准化,我们知道一个batch中的Query和Key是一一匹配的,都是正样本,也即是qi,kj,i=jq_i, k_j, i=jqi​,kj​,i=j是匹配的。但是qi,kj,i≠jq_i, k_j, i \neq jqi​,kj​,i​=j时却不是匹配的,因此通过BN可能会将本是负样本的信息“泄漏”到正样本中去。而在Memory Bank中不会出现这个问题,因为Memory Bank的正样本Key是来自于过去的mini batch的,因为是“异步”过程所以不会影响当前结果。

Fig 1.7 在对比学习中,BN层导致预训练任务失效的原因猜想。

为了解决这个问题,作者提出了Shuffling BN的概念,而这个层必须在多GPU卡上才能进行。首先作者将不同GPU的sub mini batch进行汇聚,然后对汇聚后的batch进行打乱,并且记下打乱后的索引(便于恢复shuffle结果),在提取了Key编码特征后,通过之前记下的索引,恢复shuffle结果。注意到这个过程只对Key编码器进行,而不考虑Query编码器。正如作者提供的代码code 1.1所示。

# compute key featureswith torch.no_grad(): # no gradient to keysself._momentum_update_key_encoder() # update the key encoder# shuffle for making use of BNim_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)k = self.encoder_k(im_k) # keys: NxCk = nn.functional.normalize(k, dim=1)# undo shufflek = self._batch_unshuffle_ddp(k, idx_unshuffle)

总结

本文简单对MoCo进行了讨论,从实验上看,的确MoCo能有很大程度的提高,但是也正如作者文章最后讨论的,并没有完全利用到数据。如Fig 1.8所示,MoCo在100万的数据和在十亿级别的数据上性能差别似乎不是很大。而且,目前的MoCo动量更新框架似乎要求QK编码器的结构需要一致?这一点不知道是否笔者理解正确了。

Fig 1.8 MoCo在IN-1M和IG-1B上的结果并没有差别很大。

Reference

[1]. He, K., Fan, H., Wu, Y., Xie, S., & Girshick, R. (). Momentum contrast for unsupervised visual representation learning. In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 9729-9738).

[2]. https://fesian./article/details/116377189

[3]. Radford, A., Kim, J. W., Hallacy, C., Ramesh, A., Goh, G., Agarwal, S., … & Sutskever, I. (). Learning transferable visual models from natural language supervision.arXiv preprint arXiv:2103.00020.

[4]. Wu, Z., Xiong, Y., Yu, S. X., & Lin, D. (). Unsupervised feature learning via non-parametric instance discrimination. InProceedings of the IEEE conference on computer vision and pattern recognition(pp. 3733-3742).

[5]. Henaff, O. (, November). Data-efficient image recognition with contrastive predictive coding. InInternational Conference on Machine Learning(pp. 4182-4192). PMLR.

如果觉得《MoCo 动量对比学习——一种维护超大负样本训练的框架》对你有帮助,请点赞、收藏,并留下你的观点哦!

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。