java接入AI大模型个人实践(一)
kimloner 2024-10-03 08:01:01 阅读 74
一、前言
大家好,你的月亮我的心,我是博主小阿金,欢迎各位工友。 近来工作比较清闲、当然这也得益于AI技术的日益成熟、由于一直使用的是发小公司的AI大模型产品、博主也没有跟上潮流去研究如何接入个人项目、本着知其然知其所以然,近几天就来浅浅的探究一下如何接入个人项目。
二、模型选择、demo测试
由于我个人比较懒,又比较抠,本着能省则省的原则,最终选择国产大模型通义千问作为本次接入的大模型之一。接入模型的方法千篇一律,一通百通,最终还是需要落脚在业务的设计上。下面是Java项目接入千问的流程。
参考文章:通义千问模型服务
已开通服务并获得API-KEY:开通DashScope并创建API-KEY。引入Java SDK:java sdk引入流程。选择适合的交互方式:单轮对话、多轮对话、流式输出。我选择的是市面上常见的多轮对话+流式输出模式。通过demo中的代码可以看出,要实现多轮对话,需要将历史对话的上下文传递给大模型,以实现对话的持续输出;观察流式输出样例中构建GenerationParam的方法可知,我们只需要将连续的对话以Message集合的形式传入即可。
<code>private static GenerationParam buildGenerationParam(Message userMsg) {
return GenerationParam.builder()
.model("qwen-turbo")
.messages(Arrays.asList(userMsg))
.resultFormat(GenerationParam.ResultFormat.MESSAGE)
.topP(0.8)
.incrementalOutput(true)
.build();
}
进一步得到Message的属性,这便于我们进行相关表结构的设计
public class Message {
String role;
String content;
@SerializedName("tool_calls")
List<ToolCallBase> toolCalls;
@SerializedName("tool_call_id")
String toolCallId;
@SerializedName("name")
String name;
其中需要关注的暂时只有 role content,role对应的为身份标识,content是输出的内容,其中role有以下几种类型
USER("user"),
ASSISTANT("assistant"),
BOT("bot"),
SYSTEM("system"),
ATTACHMENT("attachment"),
TOOL("tool");
常用的就是user和assistant一个是用户发送,一个是系统回复类型。
三、流程及表结构设计思考
设计思考
基于demo代码设计可知
为了实现连续对话,我们必须将对话历史存储在本地,并且需要一定的会话控制逻辑:一是对连续对话上下文token的控制,二是对持续会话时间的控制。
基于以上两点,我在数据库建立了两张数据表,一张为问题表,一张为回复表,用以解决历史对话问题;并新增了会话id(即sessionId属性),用来控制连续一段时间内的对话。以下是代码实现,仅供参考,如有高见请私信或者评论指出,不胜感激。
表结构
表结构设计,两表以问题id关联,确保为一个对话,以user_id、与session_id进行区分不同用户的会话历史。
-- Answer history: one row per model reply, linked to its question through question_id.
-- The table name carries the kim_ prefix because the mapper SQL reads from
-- kim_chat_answer_history.
CREATE TABLE `kim_chat_answer_history` (
  -- The service stores the model request id here, hence a string key.
  `id` varchar(50) NOT NULL COMMENT '主键id',
  `user_id` bigint NOT NULL COMMENT '用户id',
  `session_id` varchar(30) NOT NULL COMMENT '会话id',
  `question_id` bigint NOT NULL COMMENT '问题id',
  `answer` longtext CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NOT NULL COMMENT '答案',
  `message_role` varchar(30) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT NULL COMMENT '角色',
  `message_tool_calls` varchar(30) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT NULL,
  `message_tool_call_id` varchar(64) DEFAULT '',
  `message_name` varchar(64) DEFAULT '',
  `create_time` datetime DEFAULT NULL COMMENT '创建时间',
  `create_by` varchar(255) DEFAULT NULL COMMENT '创建者',
  PRIMARY KEY (`id`),
  KEY `session_id_index` (`session_id`) USING BTREE,
  KEY `question_id_index` (`question_id`) USING BTREE,
  KEY `user_id_index` (`user_id`) USING BTREE
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='答案历史对话表';
-- Question history: one row per user question; answers reference it by id.
-- The table name carries the kim_ prefix because the mapper SQL reads from
-- kim_chat_question_history.
CREATE TABLE `kim_chat_question_history` (
  `id` bigint NOT NULL AUTO_INCREMENT COMMENT '主键id',
  `user_id` bigint NOT NULL COMMENT '用户id',
  `session_id` varchar(30) NOT NULL COMMENT '会话id',
  `question` longtext CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NOT NULL COMMENT '问题',
  -- File name and URL are widened from varchar(30), which is too short for real values.
  `question_file_name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT NULL COMMENT '文件名称',
  `question_file_url` varchar(512) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT NULL COMMENT '文件地址',
  `model` varchar(30) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT NULL COMMENT '模型类型',
  `assistant` varchar(30) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT NULL COMMENT '定义模型方向',
  `create_time` datetime DEFAULT NULL COMMENT '创建时间',
  `create_by` varchar(255) DEFAULT NULL COMMENT '创建者',
  `message_role` varchar(10) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT NULL COMMENT '角色',
  PRIMARY KEY (`id`),
  KEY `session_id_index` (`session_id`) USING BTREE,
  KEY `user_id_index` (`user_id`) USING BTREE
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='提问历史对话表';
四、多轮对话+流式输出代码实现
此处给出基本实现逻辑
1、请求接口
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.kingoffice.common.core.web.controller.BaseController;
import com.kingoffice.common.core.web.page.TableDataInfo;
import com.kingoffice.common.security.utils.SecurityUtils;
import com.kingoffice.system.domain.vo.ChatHistoryVo;
import com.kingoffice.system.domain.vo.ChatQuestionVo;
import com.kingoffice.system.service.aigc.IChatQuestionHistoryService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.annotation.Transient;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.*;
import reactor.core.publisher.Flux;
import java.util.Comparator;
import java.util.List;
/**
 * REST endpoints of the chat feature: the streaming completion call and the
 * paged chat history of the current user.
 */
@RestController
@RequestMapping("/chat")
public class ChatController extends BaseController {

    @Autowired
    private IChatQuestionHistoryService questionHistoryService;

    /**
     * Streams the model's reply to the submitted question.
     *
     * <p>The endpoint declares {@code text/event-stream} so that each emitted
     * {@link GenerationResult} is written to the client as a server-sent event
     * instead of depending on the client's {@code Accept} header to get streaming.
     *
     * @param chatQuestionVo the question together with its session id and model options
     * @return the stream of partial results produced by the service
     */
    @PostMapping(path = "/getChat", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<GenerationResult> gptChat(@RequestBody ChatQuestionVo chatQuestionVo) {
        return questionHistoryService.gptChat(chatQuestionVo);
    }

    /**
     * Returns the chat history of the currently authenticated user.
     *
     * @return the history rows wrapped by {@code getDataTable}
     */
    @GetMapping("/historyList")
    public TableDataInfo historyList() {
        // startPage()/getDataTable() are inherited from BaseController (not shown here);
        // presumably they apply and report pagination - confirm against BaseController.
        startPage();
        List<ChatHistoryVo> list = questionHistoryService.historyList(SecurityUtils.getUserId());
        // The mapper returns newest-first; re-sort the fetched rows oldest-first for display.
        list.sort(Comparator.comparing(ChatHistoryVo::getCreateTime));
        return getDataTable(list);
    }
}
// Request body of the chat endpoint.
@Data
public class ChatQuestionVo {
/**
 * The question text entered by the user.
 */
private String question;
/**
 * Session id; the service creates one when it is empty.
 */
private String sessionId;
/**
 * Model to call.
 */
private String model;
/**
 * Assistant type.
 */
private String assistant;
/**
 * Name of a file attached to the question.
 */
private String questionFileName;
/**
 * URL of a file attached to the question.
 */
private String questionFileUrl;
}
2、接口实现
import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.generation.GenerationParam;
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.aigc.generation.models.QwenParam;
import com.alibaba.dashscope.common.Message;
import com.alibaba.dashscope.common.ResultCallback;
import com.alibaba.dashscope.common.Role;
import com.alibaba.dashscope.utils.Constants;
import com.kingoffice.common.core.constant.ChatModelConstants;
import com.kingoffice.common.core.utils.DateUtils;
import com.kingoffice.common.security.utils.SecurityUtils;
import com.kingoffice.system.domain.aigc.ChatAnswerHistory;
import com.kingoffice.system.domain.aigc.ChatQuestionHistory;
import com.kingoffice.system.domain.vo.ChatHistoryVo;
import com.kingoffice.system.domain.vo.ChatQuestionVo;
import com.kingoffice.system.mapper.aigc.ChatAnswerHistoryMapper;
import com.kingoffice.system.mapper.aigc.ChatQuestionHistoryMapper;
import com.kingoffice.system.service.aigc.IChatQuestionHistoryService;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
/**
 * Service implementation for multi-turn, streaming chat. It loads the stored
 * conversation of a session, sends it with the new question to the model, relays
 * the streamed chunks, and stores the question/answer pair once the reply is finished.
 *
 * @author kim
 * @date 2024-07-19
 */
@Service
public class ChatQuestionHistoryServiceImpl implements IChatQuestionHistoryService {

    /** A session id older than this many milliseconds is replaced by a fresh one. */
    private static final long SESSION_TIMEOUT_MILLIS = 5 * 60 * 1000L;

    /** Finish reason that marks the last chunk of a reply. */
    private static final String FINISH_REASON_STOP = "stop";

    @Autowired
    private ChatQuestionHistoryMapper chatQuestionHistoryMapper;

    @Autowired
    private ChatAnswerHistoryMapper chatAnswerHistoryMapper;

    @Value("${chat.accessKeyAli}")
    private String accessKeyAli;

    @Value("${chat.accessKeyMoon}")
    private String accessKeyMoon;

    /**
     * Sends the question, preceded by the stored history of its session, to the
     * model and returns the reply as a stream of partial results.
     *
     * <p>The session id on {@code question} is replaced when it is empty, not a
     * number, or older than {@link #SESSION_TIMEOUT_MILLIS}. The effective session id
     * is written into the {@code toolCallId} of every streamed message so the client
     * can send it back with the next question.
     *
     * @param question the question and its options; its session id may be rewritten
     * @return a stream of results; errors from the model call are signalled on the stream
     */
    @Override
    public Flux<GenerationResult> gptChat(ChatQuestionVo question) {
        Constants.apiKey = accessKeyAli; // API key used by the SDK
        Generation gen = new Generation();

        // Captured here, before the asynchronous callback that needs them runs.
        Long userId = SecurityUtils.getUserId();
        String username = SecurityUtils.getUsername();

        String sessionId = resolveSessionId(question.getSessionId());
        question.setSessionId(sessionId);

        // Stored history of the session followed by the new user message.
        List<Message> messages = conversationHistory(userId, sessionId);
        messages.add(Message.builder().role(Role.USER.getValue()).content(question.getQuestion()).build());

        GenerationParam param = buildGenerationParam(messages);
        // GenerationParam param = buildQwenParam(messages);

        return Flux.create(sink -> {
            // One buffer per subscription, so concurrent or repeated subscriptions
            // never mix their chunks.
            StringBuilder contentBuilder = new StringBuilder();
            try {
                gen.streamCall(param, new ResultCallback<GenerationResult>() {
                    @Override
                    public void onEvent(GenerationResult message) {
                        Message reply = message.getOutput().getChoices().get(0).getMessage();
                        String finishReason = message.getOutput().getChoices().get(0).getFinishReason();

                        // toolCallId is reused to carry the session id back to the client.
                        reply.setToolCallId(sessionId);

                        String newContent = reply.getContent();
                        if (newContent != null) {
                            // A chunk without text must not be stored as the literal "null".
                            contentBuilder.append(newContent);
                        }

                        if (FINISH_REASON_STOP.equals(finishReason)) {
                            saveExchange(question, userId, username, message.getRequestId(),
                                    contentBuilder.toString(), reply.getRole());
                            contentBuilder.setLength(0);
                        }
                        sink.next(message);
                    }

                    @Override
                    public void onError(Exception err) {
                        sink.error(err);
                    }

                    @Override
                    public void onComplete() {
                        sink.complete();
                    }
                });
            } catch (Exception e) {
                sink.error(e);
            }
        });
    }

    /**
     * Returns the question/answer history of a user.
     *
     * @param userId id of the user whose history is read
     * @return the rows returned by the mapper
     */
    @Override
    public List<ChatHistoryVo> historyList(Long userId) {
        return chatQuestionHistoryMapper.historyList(userId);
    }

    /**
     * Decides which session id a request runs under.
     *
     * @param sessionId the id sent by the client, possibly empty
     * @return the given id when it is a timestamp no older than the timeout,
     *         otherwise a new id made from the current time in milliseconds
     */
    private static String resolveSessionId(String sessionId) {
        long now = System.currentTimeMillis();
        if (StringUtils.isEmpty(sessionId)) {
            return String.valueOf(now);
        }
        try {
            long sessionTime = Long.parseLong(sessionId);
            return now - sessionTime > SESSION_TIMEOUT_MILLIS ? String.valueOf(now) : sessionId;
        } catch (NumberFormatException ignored) {
            // Not an id issued by this service; start a new session instead of failing.
            return String.valueOf(now);
        }
    }

    /**
     * Stores the user question and the complete model answer as a linked pair.
     */
    private void saveExchange(ChatQuestionVo question, Long userId, String username,
                              String requestId, String answer, String answerRole) {
        ChatQuestionHistory history = new ChatQuestionHistory();
        BeanUtils.copyProperties(question, history);
        history.setUserId(userId);
        history.setCreateBy(username);
        history.setMessageRole(Role.USER.getValue());
        history.setCreateTime(new Date());
        chatQuestionHistoryMapper.insertChatQuestionHistory(history);

        ChatAnswerHistory answerHistory = new ChatAnswerHistory();
        answerHistory.setId(requestId);
        answerHistory.setAnswer(answer);
        answerHistory.setUserId(userId);
        answerHistory.setSessionId(question.getSessionId());
        // Assumes the insert above populated the generated key on `history` - confirm in the mapper.
        answerHistory.setQuestionId(history.getId());
        answerHistory.setMessageRole(answerRole);
        answerHistory.setCreateBy(username);
        answerHistory.setCreateTime(new Date());
        chatAnswerHistoryMapper.insertChatAnswerHistory(answerHistory);
    }

    /**
     * Builds the streaming request parameters for the Qwen model.
     */
    private GenerationParam buildGenerationParam(List<Message> messages) {
        return GenerationParam.builder()
                .model(ChatModelConstants.Models.QWEN_MAX)
                .messages(messages)
                .resultFormat(GenerationParam.ResultFormat.MESSAGE)
                .topP(0.8)
                .incrementalOutput(true)
                .build();
    }

    /**
     * Builds the streaming request parameters for the open-source chatGLM model.
     */
    private QwenParam buildQwenParam(List<Message> messages) {
        return QwenParam.builder().model("chatglm3-6b").messages(messages)
                .resultFormat(QwenParam.ResultFormat.MESSAGE)
                .topP(0.8)
                .incrementalOutput(true) // get streaming output incrementally
                .build();
    }

    /**
     * Loads the stored messages of a session for the given user.
     *
     * @return the messages from the mapper, or an empty mutable list when no session id is given
     */
    private List<Message> conversationHistory(Long userId, String sessionId) {
        if (StringUtils.isEmpty(sessionId)) {
            return new ArrayList<>();
        }
        return chatAnswerHistoryMapper.conversationHistory(userId, sessionId);
    }
}
3、查询方法
此处给出sql
<!--
  Returns the LATEST 10 messages of one user's session, oldest first.
  Questions and answers are merged and ordered by question id, with the question
  (seq 0) before its answer (seq 1); the two rows of a round have nearly the same
  create_time, so ordering by time alone would not be deterministic.
  The inner query keeps the newest 10 rows, the outer query restores chronological order.
-->
<select id="conversationHistory" resultType="com.alibaba.dashscope.common.Message">
    SELECT
        content,
        role
    FROM
        (
            SELECT
                content,
                role,
                question_id,
                seq
            FROM
                (
                    SELECT
                        question AS content,
                        message_role AS role,
                        id AS question_id,
                        0 AS seq
                    FROM
                        kim_chat_question_history
                    WHERE
                        user_id = #{userId}
                        AND session_id = #{sessionId}
                    UNION ALL
                    SELECT
                        answer AS content,
                        message_role AS role,
                        question_id,
                        1 AS seq
                    FROM
                        kim_chat_answer_history
                    WHERE
                        user_id = #{userId}
                        AND session_id = #{sessionId}
                ) AS combined_data
            ORDER BY
                question_id DESC,
                seq DESC
            LIMIT 10
        ) AS latest_messages
    ORDER BY
        question_id ASC,
        seq ASC
</select>
<!-- Returns every question of a user joined with its answer, newest question first. -->
<select id="historyList" resultType="com.kingoffice.system.domain.vo.ChatHistoryVo">
    SELECT
        t1.question,
        t1.id AS questionId,
        t1.create_time AS createTime,
        t2.answer,
        t2.id AS answerId
    FROM
        kim_chat_question_history t1
        INNER JOIN kim_chat_answer_history t2 ON t1.id = t2.question_id
    WHERE
        t1.user_id = #{userId}
    ORDER BY
        t1.create_time DESC
</select>
五、调用展示
六、总结
至此,后端已实现多轮对话以及流式输出。代码仍有待完善的地方,比如对于token的控制、参数校验、会话时长等等,都是需要考虑的问题,后续再给出具体的优化建议。另外关注到Spring新出了SpringAI组件,可以快速地实现众多大模型的调用,阿里也推出了同样的微服务组件,感兴趣的小伙伴可以去看下,我们此处不再赘述。下一篇我们进行前端调用后端流式接口的demo展示。
声明
本文内容仅代表作者观点,或转载于其他网站,本站不以此文作为商业用途
如有涉及侵权,请联系本站进行删除
转载本站原创文章,请注明来源及作者。