基于星火大模型的群聊对话分角色要素提取挑战#AI夏令营 #Datawhale #夏令营

Koma Wong 2024-07-17 16:01:01 阅读 79

这次挑战基于 Spark AI 大模型,旨在从群聊对话中提取结构化信息,并以 JSON 格式输出,以满足特定的数据格式要求。

1、跑通baseline

准备工作与环境搭建

Datawhale官方有提供详细的速通文档:‬​​​‬​​​⁠​​‍‍‍‍‌⁠​‬⁠​​‬​​‬​​‍​​​​‍⁠‌​‬‍​‍​​零基础入门大模型技术竞赛 - 飞书云文档 (feishu.cn)

按照上述文档可以速通baseline。只要会点运行就可以!!!****领取API****

Step 1:下载相关库

为了使用 Spark AI 大模型,首先需要安装必要的 Python 库。在 Jupyter Notebook 或者命令行中执行以下命令:

<code>!pip uninstall websocket-client

!pip install --upgrade spark_ai_python websocket-client

Step 2:配置导入 

下面代码,我的理解是:这段代码用于调用讯飞开放平台的星火认知大模型(Spark3.5 Max),生成聊天机器人的响应。通过提供模型的URL、应用ID、API密钥等信息,建立与大模型的连接,并发送用户输入的消息,获取并返回模型生成的回复。

注意:SPARKAI_APP_ID 、SPARKAI_API_SECRET 、  SPARKAI_API_KEY使用自己的信息。连接文档中有,点击获取即可。

from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler

from sparkai.core.messages import ChatMessage

import json

# Spark AI 大模型的 URL、API 秘钥信息等配置

SPARKAI_URL = 'wss://spark-api.xf-yun.com/v3.5/chat'

SPARKAI_APP_ID = ''

SPARKAI_API_SECRET = ''

SPARKAI_API_KEY = ''

SPARKAI_DOMAIN = 'generalv3.5'

模型配置与测试

在准备工作完成后,我们需要确保 Spark AI 大模型的配置和测试能够顺利进行。

Step 3:模型测试

编写一个函数来测试模型配置是否正确,并能够正确生成响应。这里以简单的问候语作为示例。

def get_completions(text):

messages = [ChatMessage(

role="user",code>

content=text

)]

spark = ChatSparkLLM(

spark_api_url=SPARKAI_URL,

spark_app_id=SPARKAI_APP_ID,

spark_api_key=SPARKAI_API_KEY,

spark_api_secret=SPARKAI_API_SECRET,

spark_llm_domain=SPARKAI_DOMAIN,

streaming=False,

)

handler = ChunkPrintHandler()

response = spark.generate([messages], callbacks=[handler])

return response.generations[0][0].text

# 测试模型配置是否正确

text = "你好"

get_completions(text)

数据准备与处理

在完成模型配置测试后,我们需要准备用于角色要素提取挑战的数据集,并进行必要的数据处理。

Step 4:数据读取

编写函数来读取和写入 JSON 文件,这些文件包含了我们需要处理的训练集和测试集数据。

def read_json(json_file_path):

"""读取json文件"""

with open(json_file_path, 'r') as f:

data = json.load(f)

return data

def write_json(json_file_path, data):

"""写入json文件"""

with open(json_file_path, 'w') as f:

json.dump(data, f, ensure_ascii=False, indent=4)

# 读取数据集

train_data = read_json("dataset/train.json")

test_data = read_json("dataset/test_data.json")

角色要素提取与结果生成

最后,我们将利用前面配置好的模型和数据集,进行角色要素提取,并生成符合要求的输出结果。

Step 5:角色要素提取

在这一步中,我们将每条测试数据输入模型,提取出符合预期格式的角色要素信息,并将结果存储在输出列表中。

from tqdm import tqdm

retry_count = 5 # 设置重试次数

result = [] # 存储最终结果

error_data = [] # 存储处理失败的数据

for index, data in tqdm(enumerate(test_data)):

index += 1

is_success = False

for i in range(retry_count):

try:

res = get_completions(PROMPT_EXTRACT.format(content=data["chat_text"]))

infos = convert_all_json_in_text_to_dict(res)

infos = check_and_complete_json_format(infos)

result.append({

"infos": infos,

"index": index

})

is_success = True

break

except Exception as e:

print("index:", index, ", error:", e)

continue

if not is_success:

data["index"] = index

error_data.append(data)

Step 6:结果保存与提交

最后,将处理后的结果保存为 JSON 文件,并通过比赛平台提交生成的 output.json 文件。

write_json("output.json", result)

2、尝试个人idea

由于对这部分比较陌生,我们可以依葫芦画瓢。先跑通然后再根据自己的想法进行优化。下面的内容主要是对数据进行优化(参考与于官方给的baseline2)

数据集制作

这段代码用于预处理聊天文本数据,包括去除表情符号和超链接,提取名字、电话号码和电子邮件,并合并连续相同人的聊天记录。

jsonl_data = {"instruction":"假设你是一个智能交互助手,基于用户的输入文本,解析其中语义,抽取关键信息,以json格式生成结构化的语义内容。","input":"请调小空气净化器的湿度到1","output":"{\"intent\":\"CONTROL\",\"slots\":[{\"name\":\"device\",\"normValue\":\"airCleaner\",\"value\":\"空气净化器\"},{\"name\":\"insType\",\"normValue\":\"set\",\"value\":\"调小\"},{\"name\":\"attr\",\"normValue\":\"humidity\",\"value\":\"湿度\"},{\"name\":\"attrValue\",\"normValue\":\"1\",\"value\":\"1\"}],\"sample\":\"请调小空气净化器的湿度到1\"}"}

import json

# 打开并读取JSON文件

with open('train.json', 'r', encoding='utf-8') as file:code>

data = json.load(file)

from dataclasses import dataclass

from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler

from sparkai.core.messages import ChatMessage

import pandas as pd

import os

import json

import re

import matplotlib.pyplot as plt

from tqdm import tqdm

from math import ceil

import numpy as np

from copy import deepcopy

import random

tqdm.pandas()

plt.rcParams['font.family'] = ['STFangsong']

plt.rcParams['axes.unicode_minus'] = False

train_file = "train.json"

test_file = "test_data.json"

train_data = pd.read_json(os.path.join( train_file))

test_data = pd.read_json(os.path.join(test_file))

# 删除表情图片、超链接

train_data['chat_text'] = train_data['chat_text'].str.replace(r"\[[^\[\]]{2,10}\]", "", regex=True)

train_data['chat_text'] = train_data['chat_text'].str.replace("https?://\S+", "", regex=True)

test_data['chat_text'] = test_data['chat_text'].str.replace(r"\[[^\[\]]{2,10}\]", "", regex=True)

test_data['chat_text'] = test_data['chat_text'].str.replace("https?://\S+", "", regex=True)

def get_names_phones_and_emails(example):

names = re.findall(r"(?:\n)?([\u4e00-\u9fa5]+\d+):", example["chat_text"])

names += re.findall(r"@([\u4e00-\u9fa5]+)\s", example["chat_text"])

emails = re.findall(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}", example["chat_text"])

# phones = re.findall(r"1[356789]\d{9}", example["chat_text"]) # 文本中的手机号并不是标准手机号

phones = re.findall(r"\d{3}\s*\d{4}\s*\d{4}", example["chat_text"])

return pd.Series([set(names), set(phones), set(emails)], index=['names', 'phones', 'emails'])

def merge_chat(example):

for name in example['names']:

example["chat_text"] = example["chat_text"].replace(f"\n{name}:", f"<|sep|>{name}:")

chats = example["chat_text"].split("<|sep|>")

last_name = "UNKNOWN"

new_chats = []

for chat in chats:

if chat.startswith(last_name):

chat = chat.strip("\n")

chat = "".join(chat.split(":")[1:])

new_chats[-1] += " " + chat

else:

new_chats.append(chat)

last_name = chat.split(":")[0]

return pd.Series(["\n".join(new_chats), new_chats], index=["chats", "chat_list"])

# 使用正则表达式获得'names', 'phones', 'emails'

train_data[['names', 'phones', 'emails']] = train_data.apply(get_names_phones_and_emails, axis=1)

test_data[['names', 'phones', 'emails']] = test_data.apply(get_names_phones_and_emails, axis=1)

# 分割聊天记录, 合并连续相同人的聊天

train_data[["chats", "chat_list"]] = train_data.apply(merge_chat, axis=1)

test_data[["chats", "chat_list"]] = test_data.apply(merge_chat, axis=1)

def process(excemple):

chat_list = excemple["chat_text"].split("\n")

res = []

s = 0

while s < len(chat_list):

i, j = s, s+1

start_j = j

while i < len(chat_list) and j < len(chat_list):

if chat_list[i] == chat_list[j]:

i += 1

else:

if i != s:

if j - start_j >10:

res += list(range(start_j, j))

i = s

start_j = j

j += 1

s += 1

texts = []

for i in range(len(chat_list)):

if i not in res:

texts.append(chat_list[i])

return "\n".join(texts)

train_data["chat_text"] = train_data.apply(process, axis = 1)

test_data["chat_text"] = test_data.apply(process, axis = 1)

 这段代码用于清理并合并聊天记录中的连续重复行,通过函数 `process` 分割聊天记录、检测重复行并移除它们,然后将处理后的聊天记录应用于训练和测试数据集中的每一行。

def process(excemple):

chat_list = excemple["chat_text"].split("\n")

res = []

s = 0

while s < len(chat_list):

i, j = s, s+1

start_j = j

while i < len(chat_list) and j < len(chat_list):

if chat_list[i] == chat_list[j]:

i += 1

else:

if i != s:

if j - start_j >10:

res += list(range(start_j, j))

i = s

start_j = j

j += 1

s += 1

texts = []

for i in range(len(chat_list)):

if i not in res:

texts.append(chat_list[i])

return "\n".join(texts)

train_data["chat_text"] = train_data.apply(process, axis = 1)

test_data["chat_text"] = test_data.apply(process, axis = 1)

这段代码用于生成训练和测试数据集的提示词和目标,指导信息提取模型从聊天文本中提取客户信息。 

def process(x):

# 提示词,我们交代清楚大模型的角色、目标、注意事项,然后提供背景信息,输出格式就可以了

prompt = f"""Instruction:

你是一个信息要素提取工作人员,你需要从给定的`ChatText`中提取出**客户**的`Infos`中相关信息,将提取的信息填到`Infos`中,

注意事项:

1. 没有的信息无需填写

2. 保持`Infos`的JSON格式不变,没有的信息项也要保留!!!

4. 姓名可以是聊天昵称

5. 注意是客户的信息,不是客服的信息

6. 可以有多个客户信息

ChatText:

{x["chat_text"]}

"""

# 要求的输出格式

infos = """"

Infos:

infos": [{

"基本信息-姓名": "",

"基本信息-手机号码": "",

"基本信息-邮箱": "",

"基本信息-地区": "",

"基本信息-详细地址": "",

"基本信息-性别": "",

"基本信息-年龄": "",

"基本信息-生日": "",

"咨询类型": [],

"意向产品": [],

"购买异议点": [],

"客户预算-预算是否充足": "",

"客户预算-总体预算金额": "",

"客户预算-预算明细": "",

"竞品信息": "",

"客户是否有意向": "",

"客户是否有卡点": "",

"客户购买阶段": "",

"下一步跟进计划-参与人": [],

"下一步跟进计划-时间点": "",

"下一步跟进计划-具体事项": ""

}]

"""

# prompt+infos是文件中的input,answer是文件中的target

answer = f"""{x["infos"]}""" #target

total= len(prompt + infos + answer)

if total > 8000:

prompt = prompt[:8000-len(infos + answer)]

return pd.Series([prompt, answer], index=["input", "target"])

train_data = train_data.apply(process, axis=1)

# 测试集中的target并没有用可以忽略

data = test_data.apply(process, axis=1)

这段代码将处理后的聊天数据写入名为 `traindata.jsonl` 的文件中,为了满足训练需求,数据集被重复12次,确保达到足够的训练数据量。 

with open('traindata.jsonl', 'w', encoding='utf-8') as file:code>

# 训练集行数(130)不符合要求,范围:1500~90000000

# 遍历数据列表,并将每一行写入文件

# 这里为了满足微调需求我们重复12次数据集 130*12=1560

for line_data in tqdm(data):

print(line_data)

line_input = line_data["chat_text"]

line_output = line_data["infos"]

content = line_input

prompt = f'''

你是一个数据分析大师,你需要从群聊对话中进行分析,里面对话的角色中大部分是客服角色,你需要从中区分出有需求的客户,并得到以下四类数据。

****群聊对话****

{content}

****分析数据****

客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日

客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细

客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段

跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动

****注意****

1.只输出客户基本信息、客户意向与预算信息、客户购买准备情况、跟进计划信息对应的信息,不要输出无关内容

2.不要输出分析内容

3.输出内容格式为md格式

'''

res = chatbot(prompt=prompt)

# print(res)

line_write = {

"instruction":jsonl_data["instruction"],

"input":json.dumps(res, ensure_ascii=False),

"output":json.dumps(line_output, ensure_ascii=False)

}

# 因为数据共有130行,为了能满足训练需要的1500条及以上,我们将正常训练数据扩充12倍。

for time in range(12):

file.write(json.dumps(line_write, ensure_ascii=False) + '\n')

测试集也是一样的。 

import csv

# 打开一个文件用于写入CSV数据

with open('test.csv', 'w', newline='', encoding='utf-8') as csvfile:code>

# 创建一个csv writer对象

csvwriter = csv.writer(csvfile)

csvwriter.writerow(["input","target"])

# 遍历数据列表,并将每一行写入CSV文件

for line_data in tqdm(data_test):

content = line_data["chat_text"]

prompt = f'''

你是一个数据分析大师,你需要从群聊对话中进行分析,里面对话的角色中大部分是客服角色,你需要从中区分出有需求的客户,并得到以下四类数据。

****群聊对话****

{content}

****分析数据****

客户基本信息:需要从中区分出客户角色,并得到客户基本信息,其中包括姓名、手机号码、邮箱、地区、详细地址、性别、年龄和生日

客户意向与预算信息: 客户意向与预算信息包括咨询类型、意向产品、购买异议点、预算是否充足、总体预算金额以及预算明细

客户购买准备情况:户购买准备情况包括竞品信息、客户是否有意向、客户是否有卡点以及客户购买阶段

跟进计划信息: 跟进计划信息包括参与人、时间点和具体事项,这些信息用于指导销售团队在未来的跟进工作中与客户互动

****注意****

1.只输出客户基本信息、客户意向与预算信息、客户购买准备情况、跟进计划信息对应的信息,不要输出无关内容

2.不要输出分析内容

3.输出内容格式为md格式

'''

res = chatbot(prompt=prompt)

# print(line_data["chat_text"])

## 文件内容校验失败: test.jsonl(不含表头起算)第1行的内容不符合规则,限制每组input和target字符数量总和上限为8000,当前行字符数量:10721

line_list = [res, "-"]

csvwriter.writerow(line_list)

 微调需要排队,建议大家,可以多开几个训练,增大被选中的概率。

3、尝试进阶baseline

后续继续学习,加油!!!



声明

本文内容仅代表作者观点,或转载于其他网站,本站不以此文作为商业用途
如有涉及侵权,请联系本站进行删除
转载本站原创文章,请注明来源及作者。