LoRA Fine-Tuning, Merging, and Deployment of the Qwen-VL Large Model
天道酬勤者 2024-09-13 16:31:01
1. Renting a Server
On the AutoDL platform, rent a container instance with 24 GB of GPU memory, for example an RTX 4090.
2. Environment Setup
conda create -n qwenvl python=3.11 -y
source activate qwenvl
conda install -y -c "nvidia/label/cuda-12.1.0" cuda-runtime  # install cuda-runtime
3. Downloading the Models
cd ~/autodl-tmp/  # create the following directories under ~/autodl-tmp/
mkdir model  # directory for the model files
cd model
git lfs install  # make sure Git LFS is installed
sudo apt-get update  # if it is not, update the package list
sudo apt-get install git-lfs  # and install Git LFS
git clone https://github.com/QwenLM/Qwen-VL.git  # clone the GitHub repo that contains the required dependency lists and scripts
git clone https://www.modelscope.cn/qwen/Qwen-VL-Chat.git  # download the model
git clone https://www.modelscope.cn/qwen/Qwen-VL-Chat-Int4.git  # download the quantized model
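If cloning over HTTP is slow or unreliable, the same weights can also be fetched with ModelScope's snapshot_download once the modelscope package from step 4 is installed. A minimal sketch; the cache_dir path is an assumption, adjust it to your own layout:

# Alternative download via the ModelScope SDK (requires `pip install modelscope`)
from modelscope import snapshot_download

# Download Qwen-VL-Chat into ~/autodl-tmp/model (cache_dir is an assumed path)
model_dir = snapshot_download('qwen/Qwen-VL-Chat',
                              cache_dir='/root/autodl-tmp/model')
print(model_dir)  # prints the local directory that now contains the weights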
4. Installing Dependencies
cd Qwen-VL  # enter the cloned repo directory that contains the dependency lists
pip3 install -r requirements.txt
pip3 install -r requirements_openai_api.txt
pip3 install -r requirements_web_demo.txt
pip3 install deepspeed
pip3 install peft
pip3 install optimum
pip3 install auto-gptq
pip3 install modelscope -U
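Before moving on, it is worth a quick check that the GPU is visible to PyTorch and that the key packages import cleanly. A minimal sketch, assuming torch is already present in the environment (the AutoDL base image usually ships it; otherwise install it first):

import torch
import transformers, peft, deepspeed, auto_gptq  # fail fast if any dependency is missing

# Report the CUDA status and the versions that ended up installed
print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("GPU:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "none")
print("transformers:", transformers.__version__, "| peft:", peft.__version__)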
5. Testing
Using the web UI:
python /root/autodl-tmp/model/Qwen-VL/web_demo_mm.py --checkpoint-path /root/autodl-tmp/model/Qwen-VL-Chat
Using Python code:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from PIL import Image
torch.manual_seed(1234)
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("/root/autodl-tmp/model/Qwen-VL-Chat", trust_remote_code=True)
# Load the model with GPU device map
model = AutoModelForCausalLM.from_pretrained(
    "/root/autodl-tmp/model/Qwen-VL-Chat",
    device_map="auto",
    trust_remote_code=True
).eval()
# 1st dialogue turn
query = tokenizer.from_list_format([
    {'image': 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg'},
    {'text': '这是什么'},
])
response, history = model.chat(tokenizer, query=query, history=None)
print(response)
# Expected output: 图中是一名女子在沙滩上和狗玩耍,旁边的狗是一只拉布拉多犬,它们处于沙滩上。
# 2nd dialogue turn
response, history = model.chat(tokenizer, '输出"击掌"的检测框', history=history)
print(response)
# Expected output: <ref>击掌</ref><box>(536,509),(588,602)</box>
# Debugging: Print history and response before drawing bbox
print("History:", history)
print("Response before drawing bbox:", response)
# Check the return type and save the image
result_image = tokenizer.draw_bbox_on_latest_picture(response, history)
if result_image:
    # Save the rendered image with the detection box drawn on it
    result_image.save('output_image.jpg')
    print("Generated image saved as output_image.jpg")
else:
    print("No box detected or no image generated")
6. Preparing the Dataset
All samples need to be collected into a single list and saved to a JSON file. Each sample is a dictionary with an id and a conversations field, where the latter is a list of turns. An example is shown below (a short snippet for writing this file to disk follows the example):
[
{
"id": "identity_0",
"conversations": [
{
"from": "user",
"value": "你好"
},
{
"from": "assistant",
"value": "我是Qwen-VL,一个支持视觉输入的大模型。"
}
]
},
{
"id": "identity_1",
"conversations": [
{
"from": "user",
"value": "Picture 1: <img>https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg</img>\n图中的狗是什么品种?"
},
{
"from": "assistant",
"value": "图中是一只拉布拉多犬。"
},
{
"from": "user",
"value": "框出图中的格子衬衫"
},
{
"from": "assistant",
"value": "<ref>格子衬衫</ref><box>(588,499),(725,789)</box>"
}
]
},
{
"id": "identity_2",
"conversations": [
{
"from": "user",
"value": "Picture 1: <img>assets/mm_tutorial/Chongqing.jpeg</img>\nPicture 2: <img>assets/mm_tutorial/Beijing.jpeg</img>\n图中都是哪"
},
{
"from": "assistant",
"value": "第一张图片是重庆的城市天际线,第二张图片是北京的天际线。"
}
]
}
]
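The list above can be written to the path expected by the fine-tuning command in the next section with the standard json module. A minimal sketch; only the first sample is repeated here for brevity:

import json

# The full list of samples from the example above; one entry shown here for brevity
samples = [
    {
        "id": "identity_0",
        "conversations": [
            {"from": "user", "value": "你好"},
            {"from": "assistant", "value": "我是Qwen-VL,一个支持视觉输入的大模型。"},
        ],
    },
]

with open("/root/autodl-tmp/data/data.json", "w", encoding="utf-8") as f:
    json.dump(samples, f, ensure_ascii=False, indent=2)  # keep Chinese text readable in the file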
Notes on the data format:
1. To cover a variety of VL tasks, the following special tokens are added: <img> </img> <ref> </ref> <box> </box>.
2. Content with image input is written as Picture id: <img>img_path</img>\n{your prompt}, where id is the index of the image within the conversation. "img_path" can be a local image file or a URL.
3. A detection box in the conversation is written as <box>(x1,y1),(x2,y2)</box>, where (x1, y1) and (x2, y2) are the top-left and bottom-right corner coordinates, normalized to the range [0, 1000). The text caption belonging to a box can be given with <ref>text_caption</ref> (see the conversion sketch below).
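Because the coordinates are normalized rather than raw pixels, pixel-space boxes from an annotation tool have to be rescaled before they are written into the dataset. A minimal sketch of the conversion, assuming x1, y1, x2, y2 are pixel coordinates and the image width and height are known:

def to_qwen_box(x1, y1, x2, y2, img_w, img_h):
    """Convert pixel coordinates to Qwen-VL's <box> string, normalized to [0, 1000)."""
    nx1, ny1 = int(x1 / img_w * 1000), int(y1 / img_h * 1000)
    nx2, ny2 = int(x2 / img_w * 1000), int(y2 / img_h * 1000)
    return f"<box>({nx1},{ny1}),({nx2},{ny2})</box>"

# Example: a box at pixels (150, 200)-(450, 600) in a 1280x960 image
print(to_qwen_box(150, 200, 450, 600, 1280, 960))
# -> <box>(117,208),(351,625)</box>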
7. Fine-Tuning
python3 /root/autodl-tmp/model/Qwen-VL/finetune.py \
--model_name_or_path /root/autodl-tmp/model/Qwen-VL-Chat \
--data_path /root/autodl-tmp/data/data.json \
--bf16 True \
--fix_vit True \
--output_dir output_qwen \
--num_train_epochs 5 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 10 \
--learning_rate 1e-5 \
--weight_decay 0.1 \
--adam_beta2 0.95 \
--warmup_ratio 0.01 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--report_to "none" \
--model_max_length 600 \
--lazy_preprocess True \
--gradient_checkpointing true \
--use_lora
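Because --model_max_length is set to 600, samples whose conversation grows beyond that length will be truncated during training. A rough, hedged pre-check of the dataset is sketched below; it only tokenizes the raw text values with the special tags stripped, whereas the training script builds the full prompt, so treat the counts as an approximation:

import json, re
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "/root/autodl-tmp/model/Qwen-VL-Chat", trust_remote_code=True)

with open("/root/autodl-tmp/data/data.json", encoding="utf-8") as f:
    samples = json.load(f)

for sample in samples:
    # Join all turns and drop the <img>/<ref>/<box> tags so only plain text is counted
    text = "".join(turn["value"] for turn in sample["conversations"])
    text = re.sub(r"</?(img|ref|box)>", " ", text)
    n_tokens = len(tokenizer.encode(text))
    if n_tokens > 600:
        print(f"{sample['id']}: roughly {n_tokens} tokens, will be truncated at model_max_length=600")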
8. Merging the Model and Running Inference
Unlike full-parameter fine-tuning, LoRA training only stores the adapter parameters. The adapter therefore has to be merged into the base model and the result saved before the model can be used on its own.
from peft import AutoPeftModelForCausalLM  # loads the base model together with a PEFT adapter
from modelscope import AutoTokenizer

path_to_adapter = "/root/autodl-tmp/model/output_qwen"

# Load the fine-tuned adapter on top of the base model
model = AutoPeftModelForCausalLM.from_pretrained(
    path_to_adapter,         # path to the adapter
    device_map="auto",       # map layers to devices automatically
    trust_remote_code=True   # trust the model's custom code
).eval()                     # switch to evaluation mode

new_model_directory = "/root/autodl-tmp/model/New-Model"

tokenizer = AutoTokenizer.from_pretrained(
    path_to_adapter, trust_remote_code=True,
)
tokenizer.save_pretrained(new_model_directory)

# Merge the adapter weights into the base model and drop the adapter
merged_model = model.merge_and_unload()

# Save the merged model
merged_model.save_pretrained(new_model_directory, max_shard_size="2048MB", safe_serialization=True)
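Depending on the transformers version, save_pretrained may not copy every auxiliary file that Qwen-VL loads via trust_remote_code (for example the custom modeling code or the qwen.tiktoken vocabulary). As a hedged precaution, the sketch below copies any such files from the original Qwen-VL-Chat directory into the merged-model directory, skipping files that already exist; the file patterns are an assumption, adjust them as needed:

import glob, os, shutil

base_model_dir = "/root/autodl-tmp/model/Qwen-VL-Chat"  # original base model directory

# Copy custom code files, the tiktoken vocabulary, and the font used for box drawing
for pattern in ("*.py", "*.tiktoken", "*.ttf"):
    for src in glob.glob(os.path.join(base_model_dir, pattern)):
        dst = os.path.join(new_model_directory, os.path.basename(src))
        if not os.path.exists(dst):  # never overwrite files save_pretrained already wrote
            shutil.copy(src, dst)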
9. Deployment
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from PIL import Image
torch.manual_seed(1234)
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("/root/autodl-tmp/model/New-Model", trust_remote_code=True)
# Load the model with GPU device map
model = AutoModelForCausalLM.from_pretrained(
    "/root/autodl-tmp/model/New-Model",
    device_map="auto",
    trust_remote_code=True
).eval()
query = tokenizer.from_list_format([
    {'image': 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg'},
    {'text': '这是什么'},
])
response, history = model.chat(tokenizer, query, history=None)
print("回答如下:\n", response)
10. Saving the Dependency List
pip freeze > requirements_qwen_vl_sy.txt
Contents of the dependency file:
absl-py==2.1.0
accelerate==0.32.1
aiofiles==23.2.1
aiohttp==3.9.5
aiosignal==1.3.1
altair==5.3.0
annotated-types==0.7.0
anyio==4.4.0
attrs==23.2.0
auto_gptq==0.7.1
certifi==2024.7.4
charset-normalizer==3.3.2
click==8.1.7
coloredlogs==15.0.1
contourpy==1.2.1
cycler==0.12.1
datasets==2.20.0
deepspeed==0.14.4
dill==0.3.8
distro==1.9.0
dnspython==2.6.1
einops==0.8.0
email_validator==2.2.0
fastapi==0.111.1
fastapi-cli==0.0.4
ffmpy==0.3.2
filelock==3.13.1
fonttools==4.53.1
frozenlist==1.4.1
fsspec==2024.2.0
gekko==1.2.1
gradio==4.38.1
gradio_client==1.1.0
grpcio==1.65.1
h11==0.14.0
hjson==3.1.0
httpcore==1.0.5
httptools==0.6.1
httpx==0.27.0
huggingface-hub==0.24.0
humanfriendly==10.0
idna==3.7
importlib_resources==6.4.0
Jinja2==3.1.3
jsonschema==4.23.0
jsonschema-specifications==2023.12.1
kiwisolver==1.4.5
Markdown==3.6
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.1
mdurl==0.1.2
modelscope==1.16.1
mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
networkx==3.2.1
ninja==1.11.1.1
numpy==1.26.3
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-ml-py==12.555.43
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.1.105
nvidia-nvtx-cu12==12.1.105
openai==1.35.15
optimum==1.21.2
orjson==3.10.6
packaging==24.1
pandas==2.2.2
peft==0.11.1
pillow==10.2.0
protobuf==4.25.3
psutil==6.0.0
py-cpuinfo==9.0.0
pyarrow==17.0.0
pyarrow-hotfix==0.6
pydantic==2.8.2
pydantic_core==2.20.1
pydub==0.25.1
Pygments==2.18.0
pyparsing==3.1.2
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-multipart==0.0.9
pytz==2024.1
PyYAML==6.0.1
referencing==0.35.1
regex==2024.5.15
requests==2.32.3
rich==13.7.1
rouge==1.0.1
rpds-py==0.19.0
ruff==0.5.3
safetensors==0.4.3
scipy==1.14.0
semantic-version==2.10.0
sentencepiece==0.2.0
shellingham==1.5.4
six==1.16.0
sniffio==1.3.1
sse-starlette==2.1.2
starlette==0.37.2
sympy==1.12
tensorboard==2.17.0
tensorboard-data-server==0.7.2
tiktoken==0.7.0
tokenizers==0.13.3
tomlkit==0.12.0
toolz==0.12.1
torch==2.3.1+cu121
torchaudio==2.3.1+cu121
torchvision==0.18.1+cu121
tqdm==4.66.4
transformers==4.32.0
transformers-stream-generator==0.0.4
triton==2.3.1
typer==0.12.3
typing_extensions==4.9.0
tzdata==2024.1
urllib3==2.2.2
uvicorn==0.30.1
uvloop==0.19.0
watchfiles==0.22.0
websockets==11.0.3
Werkzeug==3.0.3
xxhash==3.4.1
yarl==1.9.4