需要做以下几个步骤:
1. 将代码中的分布式训练相关代码删除或注释掉。这些代码通常包括调用 init_process_group 函数以及使用 DistributedDataParallel 包装模型和优化器的代码。
2. 将代码中的 num_tasks 和 rank 变量的值设置为1。这些变量通常用于指定分布式训练的进程数量和当前进程的排名。在单卡运行时,这些变量的值应该都为1。
3. 将代码中的 start 和 end 变量的值设置为0和数据集的大小。这些变量通常用于指定当前进程处理的数据的起始和结束索引。在单卡运行时,你只需要处理整个数据集,因此这些变量的值应该分别为0和数据集的大小。
4. 将代码中的 metric_logger.log_every 函数调用删除或注释掉。这个函数通常用于在分布式训练中记录日志,但在单卡运行时不需要使用。
下面是修改后的代码示例:
texts = data_loader.dataset.text
num_text = len(texts)
text_bs = 256
text_ids = []
text_embeds = []
text_atts = []
for i in range(0, num_text, text_bs):
text = texts[i: min(num_text, i+text_bs)]
text_input = model.tokenizer(text, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device)
text_output = model.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text')
text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:,0,:]))
text_embeds.append(text_embed)
text_ids.append(text_input.input_ids)
text_atts.append(text_input.attention_mask)
text_embeds = torch.cat(text_embeds,dim=0)
text_ids = torch.cat(text_ids,dim=0)
text_atts = torch.cat(text_atts,dim=0)
text_ids[:,0] = model.tokenizer.enc_token_id
# 计算图像特征
image_feats = []
image_embeds = []
for image, img_id in data_loader:
image = image.to(device)
image_feat = model.visual_encoder(image)
image_embed = model.vision_proj(image_feat[:,0,:])
image_embed = F.normalize(image_embed,dim=-1)
image_feats.append(image_feat.cpu())
image_embeds.append(image_embed)
image_feats = torch.cat(image_feats,dim=0)
image_embeds = torch.cat(image_embeds,dim=0)
# 计算相似度矩阵
sims_matrix = image_embeds @ text_embeds.t()
score_matrix_i2t = torch.full((len(data_loader.dataset.image),len(texts)),-100.0).to(device)
num_tasks = 1
rank = 1
step = sims_matrix.size(0)//num_tasks + 1
start = 0
end = sims_matrix.size(0)
# 计算图像到文本的分数矩阵
for i,sims in enumerate(sims_matrix[start:end]):
topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
encoder_output = image_feats[start+i].repeat(config['k_test'],1,1).to(device)
encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device)
output = model.text_encoder(text_ids[topk_idx],
attention_mask = text_atts[topk_idx],
encoder_hidden_states = encoder_output,
encoder_attention_mask = encoder_att,
return_dict = True,
)
score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
score_matrix_i2t[start+i,topk_idx] = score + topk_sim
sims_matrix = sims_matrix.t()
score_matrix_t2i = torch.full((len(texts),len(data_loader
如果代码本身写的好:
只需要:将--distributed 参数设置为False 就可以了