torch显存分析——对生成模型清除显存

torch显存分析——对生成模型清除显存

1. 问题介绍

本文主要针对生成场景下,如何方便快捷地清除当前进程占用的显存。文章的重点不止是对显存的管理,还包括怎样灵活的使用自定义组件来控制生成过程。

在之前的文章torch显存分析——如何在不关闭进程的情况下释放显存中,通过一个实验,分析了torch的显存占用情况,以及如何在不关闭进程的前提下,利用代码将显存释放掉。然而,在近期的实验中,却发现之前所介绍的显存释放方法对生成模型并不好用。

在前文中,所使用的方法是:

real_inputs = inputs['input_ids'][..., : 2, ...].to(model.device)
with torch.no_grad():
    logits = model(real_inputs, tail)
del real_inputs
del logits
torch.cuda.empty_cache()

然而,如果对生成模型,直接将model的forward替换成generate的话,即如下的替换方法,则会遇到问题。

with torch.no_grad():
	logits = model.generate(real_inputs)
del real_inputs
del logits
torch.cuda.empty_cache()

因为生成过程中,会有新的token生成,model.generate很可能不止一次在调用forward,所以这种方法就不灵了。

2. 应对方法

既然是模拟一边模型的forward方法,那就想办法让forward方法只被调用一次。或许直接还是使用model.forward就可以解决这个问题。但是这里我采用了另一种方法——使用Stopping Criteria。

既然只希望它生成执行一次,那就可以直接使用一个默认的criteria:

from transformers.generation.stopping_criteria import MaxNewTokensCriteria, StoppingCriteriaList

empty_cache_helper = StoppingCriteriaList()
empty_cache_helper.append(MaxNewTokensCriteria(start_length=0, max_new_tokens=1))

这个东西的作用就是,最多只生成一个新的token,然后立即停止生成。

那么在清除显存时,只需要将它加上就好了:

扫描二维码关注公众号,回复: 16144241 查看本文章
with torch.no_grad():
	logits = model.generate(real_inputs, stopping_criteria=self.empty_cache_helper)
del real_inputs
del logits
torch.cuda.empty_cache()

如果不了解stopping criteria的话,可以去回顾之前的两篇文章:

以beam search为例,详解transformers中generate方法(上)
以beam search为例,详解transformers中generate方法(下)

今后的博客中,可能会结合一些例子,对自定义的logits processor和stopping criteria的使用进行介绍,感兴趣的同学可以关注一下。

猜你喜欢

转载自blog.csdn.net/weixin_44826203/article/details/132067916