聊聊GLM-4-9B开源模型的微调loss计算

cnblogs 2024-06-12 10:43:00 阅读 64

概述

Github官方地址:GLM-4

网上已经有很多关于微调的文章,介绍各种方式下的使用,这里不会赘述。我个人比较关心的是微调时的loss计算逻辑,这点在很多的文章都不会有相关的描述,因为大多数人都是关心如何使用之类的应用层,而不是其具体的底层逻辑,当然咱也说不清太底层的计算。

可了解其它loss计算的文章:

再聊多轮对话微调训练格式与长序列训练

聊聊ChatGLM2与ChatGLM3微调多轮对话的设计逻辑及源码分析

聊聊大模型多轮对话的训练及优化

微调

微调格式:

[

{

"messages": [

{

"role": "system",

"content": "<system prompt text>",

"tools": [

{

"name": "<tool name>",

"args": {

"<arg name>": "<arg value>"

}

}

]

},

{

"role": "user",

"content": "<user prompt text>"

},

{

"role": "assistant",

"content": "<assistant response text>"

},

{

"role": "user",

"content": "<user prompt text>"

},

{

"role": "assistant",

"content": "<assistant response text>"

},

{

"role": "observation",

"content": "<observation prompt text>"

},

{

"role": "assistant",

"content": "<assistant response observation>"

},

{

"role": "user",

"content": "<user prompt text>"

},

{

"role": "assistant",

"content": "<assistant response text>"

}

]

}

]

微调源码地址:finetune.py

Loss计算代码:

def process_batch(

batch: Mapping[str, Sequence],

tokenizer: PreTrainedTokenizer,

max_input_length: int,

max_output_length: int,

) -> dict[str, list]:

batched_conv = batch['messages']

batched_input_ids = []

batched_labels = []

# batched_conv 是一个数组

# conv 是数组内的单个 message

for conv in batched_conv:

input_ids = [151331, 151333]

loss_masks = [False, False]

# conv 是数组内的单个 message

# message 是 单个role json对象

for message in conv:

message = process_message(message)

# 设置 mask 掩码,只有system,user,observation不参与mask计算,其余的角色参与计算

loss_mask_val = False if message['role'] in ('system', 'user', 'observation') else True

# 获取 input 文本的数字表示(ids)

new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[0][2:]

# 计算整句的 mask

new_loss_masks = [loss_mask_val] * len(new_input_ids)

# 拼接message中的每段json

input_ids += new_input_ids

# 拼接message中每段json对应的mask

loss_masks += new_loss_masks

# 追加结尾的 token id

input_ids.append(tokenizer.eos_token_id)

loss_masks = [False, *loss_masks]

labels = []

for input_id, mask in zip(input_ids, loss_masks):

if mask:

# 添加到label,计算loss

labels.append(input_id)

else:

# -100 不处理,即ignore_index

labels.append(-100)

max_length = max_input_length + max_output_length + 1

# 截断

batched_input_ids.append(input_ids[:max_length])

batched_labels.append(labels[:max_length])

return {'input_ids': batched_input_ids, 'labels': batched_labels}

注释在代码中已经写明。process_batch方法用于将输入转换为ids,并计算mask(用于Loss计算)。而该方法的调用是在数据集的遍历处理中,即如下所示:

tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)

data_manager = DataManager(data_dir, ft_config.data_config)

# 数据集拆分遍历

train_dataset = data_manager.get_dataset(

Split.TRAIN,

functools.partial(

process_batch,

tokenizer=tokenizer,

max_input_length=ft_config.max_input_length,

max_output_length=ft_config.max_output_length,

),

batched=True,

)

print('train_dataset:', train_dataset)

Loss计算如下图所示:

总结

相比较于之前的ChatGLM版本,GLM4开源版本的多轮对话loss计算更恰当且效率也会更高;在其它的开源模型/微调框架中早已支持该种loss计算,如InternLM、XTuner、Firefly等。对于loss格式的类别,可参考XTuner的官方文档说明:dataset_format.md。

原文链接:https://mp.weixin.qq.com/s/0mLCQfpaZr7eEonG4a4Etg

更多大模型相关的文章,请上个人公众号查阅:

image



声明

本文内容仅代表作者观点,或转载于其他网站,本站不以此文作为商业用途
如有涉及侵权,请联系本站进行删除
转载本站原创文章,请注明来源及作者。