【大模型】Transformers库单机多卡推理之device_map
酒酿小圆子~ 2024-08-27 13:01:02 阅读 72
文章目录
device_map参数解析device_map="auto" 代码示例手动配置参考资料
Hugging Face的<code>transformers库支持自动模型(AutoModel)的模型实例化方法,来自动载入并使用GPT、ChatGLM等模型。在AutoModel.from_pretrained()
方法中的 device_map
参数,可实现单机多卡推理。
device_map参数解析
device_map
是AutoModel.from_pretrained()
方法中的一个重要参数,它用于指定模型的各个部件应加载到哪个具体的计算设备上,以实现资源的有效分配和利用。这个参数在进行模型并行或分布式训练时特别有用。
device_map
参数有 auto, balanced, balanced_low_0, sequential
几种选项,具体如下:
“auto” 和 “balanced”
:将会在所有的GPU上平衡切分模型。主要是有可能发现更高效的分配策略。“balanced” 参数的功能则保持稳定。(可按需使用)“balanced_low_0”
:会在除了第一个GPU上的其它GPU上平衡划分模型,并且在第一个 GPU 上占据较少资源。这个选项符合需要在第一个 GPU 上进行额外操作的需求,例如需要在第一个 GPU 执行 generate 函数(迭代过程)。(推荐使用)“sequential”
:按照GPU的顺序分配模型分片,从 GPU 0 开始,直到最后的 GPU(那么最后的 GPU 往往不会被占满,和 - “balanced_low_0” 的区别就是第一个还是最后一个,以及非均衡填充),但是我在实际使用当中GPU 0 会直接爆显存了。(不推荐使用)
device_map=“auto” 代码示例
这里我们的环境为单机两张显卡,使用 device_map="auto"code> 来加载
ChatGLM-6B
模型,观察显卡占用情况。
示例代码如下:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0, 1'
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# 加载模型
model_path = "./model/chatglm2-6b"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map="auto")code>
text = '什么是机器学习?'
inputs = tokenizer(text, return_tensors="pt")code>
print(inputs)
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
程序运行前显卡占用情况如下:
可以看到0号显卡本身被其他程序占用了约13G的显存。
使用 <code>device_map="auto"code> 后的显卡占用情况如下:
使用auto策略后,显卡0和1分别多占用了约6~7G的显存。
手动配置
配置为单卡
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map=device)
配置为多卡
假设想要模型的某些部分在第一张显卡,另一部分在第二张显卡,需要知道模型的层名或者按照模型的组件大小进行合理分配。不过,具体层名需要根据实际模型来确定,这里提供一个概念性的示例:
device = {
"transformer.h.0": "cuda:0", # 第一部分放在GPU 0
"transformer.h.1": "cuda:1", # 第二部分放在GPU 1
# ... 根据模型结构继续分配
}
model = AutoModelForCausalLM.from_pretrained(model_dir, device_map=device)
参考资料
【AI大模型】Transformers大模型库(七):单机多卡推理之device_map【大模型运行漫长的开始】 关于多GPU使用 device_map
声明
本文内容仅代表作者观点,或转载于其他网站,本站不以此文作为商业用途
如有涉及侵权,请联系本站进行删除
转载本站原创文章,请注明来源及作者。