GitHub - MalongTech/research-xbm (now hosted at msight-tech/research-xbm)

XBM 实现代码如下:

在 PyTorch 训练中的构建与应用如下:

https://github.com/msight-tech/research-xbm/blob/master/ret_benchmark/modeling/xbm.py

import torch

class XBM:
    """Cross-Batch Memory: a fixed-capacity bank of past features and labels.

    Features from earlier mini-batches are stored so that a loss can compare
    the current batch against many more samples than fit in a single batch.

    Args:
        K: Capacity of the bank (number of stored samples).
        feat_dim: Dimension of each stored feature vector. Defaults to 128,
            the value that was previously hard-coded.
        device: Device the bank tensors are allocated on. Defaults to
            ``"cuda"``, matching the previous ``.cuda()`` behaviour.
    """

    def __init__(self, K, feat_dim=128, device="cuda"):
        self.K = K  # queue capacity
        self.feats = torch.zeros(self.K, feat_dim, device=device)  # stored features
        self.targets = torch.zeros(self.K, dtype=torch.long, device=device)  # stored labels
        self.ptr = 0  # next write position
        # Explicit fill flag. Inferring "full" from ``targets[-1] != 0`` is wrong
        # whenever the last slot legitimately holds class label 0.
        self._full = False

    @property
    def is_full(self):
        """True once every slot of the bank has been written at least once."""
        return self._full

    def get(self):
        """Return ``(feats, targets)`` for all valid entries in the bank.

        Before the bank is full only the first ``ptr`` slots have been written,
        so only that prefix is returned; afterwards the whole bank is returned.
        The returned tensors are views into the bank, not copies.
        """
        if self.is_full:
            return self.feats, self.targets
        return self.feats[:self.ptr], self.targets[:self.ptr]

    def enqueue_dequeue(self, feats, targets):
        """Write a batch into the bank, overwriting earlier entries when needed.

        Args:
            feats: Batch of ``B`` feature rows, each of length ``feat_dim``.
                Callers are expected to pass ``feats.detach()`` so the bank
                does not keep an autograd graph alive.
            targets: ``B`` integer labels matching ``feats``.

        If the batch does not fit in the space remaining after ``ptr``, it is
        written into the last ``B`` slots and the pointer wraps to 0. At that
        point every slot holds data, so the bank is marked full. A batch larger
        than ``K`` keeps only its last ``K`` samples.
        """
        q_size = len(targets)
        if q_size == 0:
            # Nothing to store; also avoids the ``[-0:]`` == whole-tensor pitfall.
            return
        if q_size > self.K:
            # Only the most recent K samples can be retained.
            feats = feats[-self.K:]
            targets = targets[-self.K:]
            q_size = self.K

        if self.ptr + q_size > self.K:
            # Not enough room at the end: overwrite the tail and wrap around.
            self.feats[-q_size:] = feats
            self.targets[-q_size:] = targets
            self.ptr = 0
            self._full = True
        else:
            self.feats[self.ptr:self.ptr + q_size] = feats
            self.targets[self.ptr:self.ptr + q_size] = targets
            self.ptr += q_size
            if self.ptr == self.K:
                # Exactly filled to capacity.
                self._full = True
#
# Usage in the training code
print("[INFO]>>> use XBM")
# Initialise the cross-batch memory with room for 1000 samples.
xbm = XBM(K=1000)


def xbm_train_step(memory, feats, targets, criterion, optimizer, log_info, weight=1.0):
    """Run one optimisation step combining the in-batch loss with the XBM loss.

    Args:
        memory: The ``XBM`` bank to update and compare against.
        feats: Features produced by the model for the current batch; the
            combined loss is back-propagated through them.
        targets: Labels of the current batch.
        criterion: Loss called as ``criterion(feats, targets, ref_feats, ref_targets)``.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called here.
        log_info: Dict that receives the ``"batch_loss"`` and ``"xbm_loss"`` values.
        weight: Multiplier for the XBM loss term (e.g. 1.0). This replaces the
            non-existent ``XBM.WEIGHT`` attribute.

    Returns:
        The combined loss tensor (already back-propagated and stepped).
    """
    # Enqueue: .detach() stores values only, with no gradient computation
    # through the memory bank.
    memory.enqueue_dequeue(feats.detach(), targets.detach())

    # Loss within the current mini-batch.
    loss = criterion(feats, targets, feats, targets)
    log_info["batch_loss"] = loss.item()

    # Dequeue: loss of the current batch against everything in the memory.
    xbm_feats, xbm_targets = memory.get()
    xbm_loss = criterion(feats, targets, xbm_feats, xbm_targets)
    log_info["xbm_loss"] = xbm_loss.item()
    loss = loss + weight * xbm_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss
Last modification: April 29th, 2021 at 10:44 am