java接入AI大模型个人实践(一)

kimloner 2024-10-03 08:01:01 阅读 74

一、前言

大家好,你的月亮我的心,我是博主小阿金,欢迎各位工友。 近来工作比较清闲、当然这也得益于AI技术的日益成熟、由于一直使用的是发小公司的AI大模型产品、博主也没有跟上潮流去研究如何接入个人项目、本着知其然知其所以然,近几天就来浅浅的探究一下如何接入个人项目。

二、模型选择、demo测试

由于我个人比较懒,又比较抠,本着能省则省的原则,最终选择国产大模型通义千问作为本次接入的大模型之一。接入模型的方法千篇一律,一通百通,最终还是需要落脚在业务的设计上。下面是Java项目接入千问的流程。

参考文章:通义千问模型服务

接入步骤如下:1. 开通DashScope服务并创建API-KEY;2. 按照Java SDK引入流程添加依赖;3. 选择适合的交互方式(单轮对话、多轮对话、流式输出),我选择的是市面上常见的"多轮对话+流式输出"模式。通过demo中的代码可以看出,要实现多轮对话,需要将历史对话的上下文一并传递给大模型,才能让对话持续下去;观察流式输出样例中的buildGenerationParam方法可知,我们只需要将连续的对话以Message集合的形式传入即可。

<code>private static GenerationParam buildGenerationParam(Message userMsg) {

return GenerationParam.builder()

.model("qwen-turbo")

.messages(Arrays.asList(userMsg))

.resultFormat(GenerationParam.ResultFormat.MESSAGE)

.topP(0.8)

.incrementalOutput(true)

.build();

}

进一步查看Message类的属性,以便据此设计相关的表结构:

public class Message {

String role;

String content;

@SerializedName("tool_calls")

List<ToolCallBase> toolCalls;

@SerializedName("tool_call_id")

String toolCallId;

@SerializedName("name")

String name;

其中暂时只需要关注 role 和 content 两个属性:role 表示消息发送者的身份,content 是消息的内容。role 有以下几种取值:

USER("user"),

ASSISTANT("assistant"),

BOT("bot"),

SYSTEM("system"),

ATTACHMENT("attachment"),

TOOL("tool");

最常用的是 user 和 assistant:前者表示用户发送的消息,后者表示大模型的回复。

三、流程及表结构设计思考

设计思考

基于demo代码设计可知

为了实现连续对话,我们必须将对话历史存储在本地,并配合一定的会话控制逻辑:一是控制连续对话所携带的上下文token数量,二是控制一次会话的持续时间。

基于以上两点,我在数据库中建立了两张数据表:一张为问题表,一张为回复表,用以保存历史对话;同时新增了会话id(即sessionId)属性,用来界定一段连续时间内的对话。以下代码实现仅供参考,如有高见请私信或在评论中指出,不胜感激。

表结构

表结构设计:两表通过问题id关联,保证一问一答属于同一轮对话;通过 user_id 与 session_id 区分不同用户、不同会话的历史记录。

-- Answer (assistant message) history: one row per finished model answer.
-- Each row belongs to one question row through question_id; user_id + session_id
-- identify the conversation. The kim_ table prefix matches the names that the
-- conversationHistory/historyList mapper queries read from.
CREATE TABLE `kim_chat_answer_history` (
  `id` varchar(50) NOT NULL COMMENT '主键id',
  `user_id` bigint NOT NULL COMMENT '用户id',
  `session_id` varchar(30) NOT NULL COMMENT '会话id',
  `question_id` bigint NOT NULL COMMENT '问题id',
  `answer` longtext CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NOT NULL COMMENT '答案',
  `message_role` varchar(30) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT NULL COMMENT '角色',
  -- tool_calls is a list of call objects; 30 characters cannot hold even one.
  `message_tool_calls` text CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci,
  `message_tool_call_id` varchar(64) DEFAULT '',
  `message_name` varchar(64) DEFAULT '',
  `create_time` datetime DEFAULT NULL COMMENT '创建时间',
  `create_by` varchar(255) DEFAULT NULL COMMENT '创建者',
  PRIMARY KEY (`id`),
  KEY `session_id_index` (`session_id`) USING BTREE,
  KEY `question_id_index` (`question_id`) USING BTREE,
  KEY `user_id_index` (`user_id`) USING BTREE
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='答案历史对话表';

-- Question (user message) history: one row per question asked.
-- id is auto-increment and is referenced by kim_chat_answer_history.question_id.
-- The kim_ table prefix matches the names that the mapper queries read from.
-- The leftover AUTO_INCREMENT=215 counter is dropped so a fresh table starts at 1.
CREATE TABLE `kim_chat_question_history` (
  `id` bigint NOT NULL AUTO_INCREMENT COMMENT '主键id',
  `user_id` bigint NOT NULL COMMENT '用户id',
  `session_id` varchar(30) NOT NULL COMMENT '会话id',
  `question` longtext CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci NOT NULL COMMENT '问题',
  -- 30 characters is too short for real file names and URLs.
  `question_file_name` varchar(255) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT NULL COMMENT '文件名称',
  `question_file_url` varchar(500) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT NULL COMMENT '文件地址',
  `model` varchar(30) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT NULL COMMENT '模型类型',
  `assistant` varchar(30) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT NULL COMMENT '定义模型方向',
  `create_time` datetime DEFAULT NULL COMMENT '创建时间',
  `create_by` varchar(255) DEFAULT NULL COMMENT '创建者',
  `message_role` varchar(10) CHARACTER SET utf8mb4 COLLATE utf8mb4_0900_ai_ci DEFAULT NULL COMMENT '角色',
  PRIMARY KEY (`id`),
  KEY `session_id_index` (`session_id`) USING BTREE,
  KEY `user_id_index` (`user_id`) USING BTREE
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='提问历史对话表';

四、多轮对话+流式输出代码实现

此处给出基本实现逻辑

1、请求接口

import com.alibaba.dashscope.aigc.generation.GenerationResult;

import com.kingoffice.common.core.web.controller.BaseController;

import com.kingoffice.common.core.web.page.TableDataInfo;

import com.kingoffice.common.security.utils.SecurityUtils;

import com.kingoffice.system.domain.vo.ChatHistoryVo;

import com.kingoffice.system.domain.vo.ChatQuestionVo;

import com.kingoffice.system.service.aigc.IChatQuestionHistoryService;

import org.springframework.beans.factory.annotation.Autowired;

import org.springframework.data.annotation.Transient;

import org.springframework.http.MediaType;

import org.springframework.web.bind.annotation.*;

import reactor.core.publisher.Flux;

import java.util.Comparator;

import java.util.List;

/**
 * HTTP endpoints for the AI chat: a streaming question/answer call and the history list.
 */
@RestController
@RequestMapping("/chat")
public class ChatController extends BaseController {

    @Autowired
    private IChatQuestionHistoryService questionHistoryService;

    /**
     * Streams the model's answer to a question.
     *
     * <p>{@code produces = text/event-stream} makes Spring write each {@link GenerationResult}
     * as a server-sent event as soon as the service emits it. Without it, unless the client
     * happens to send {@code Accept: text/event-stream}, the Flux is negotiated as JSON and
     * aggregated into a single array, so nothing reaches the client until the whole answer
     * has been generated.
     *
     * @param chatQuestionVo the question plus optional session id, model and assistant hints
     * @return stream of incremental generation results; each chunk's message carries the
     *         session id in its toolCallId field
     */
    @PostMapping(path = "/getChat", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<GenerationResult> gptChat(@RequestBody ChatQuestionVo chatQuestionVo) {
        return questionHistoryService.gptChat(chatQuestionVo);
    }

    /**
     * Returns the current user's question/answer history.
     *
     * <p>startPage() (from BaseController) presumably starts pagination for the next query.
     * The mapper returns the newest pairs first; the page is then re-sorted oldest-first by
     * creation time for display.
     *
     * @return the paged history table
     */
    @GetMapping("/historyList")
    public TableDataInfo historyList() {
        startPage();
        List<ChatHistoryVo> list = questionHistoryService.historyList(SecurityUtils.getUserId());
        list.sort(Comparator.comparing(ChatHistoryVo::getCreateTime));
        return getDataTable(list);
    }
}

// Request body (input parameters) of POST /chat/getChat.
@Data
public class ChatQuestionVo {

    /**
     * The question text; gptChat sends it to the model as the new user message.
     */
    private String question;

    /**
     * Session id. When empty, gptChat starts a new session; otherwise gptChat parses it as the
     * session's creation time in epoch milliseconds and replaces it once that is more than
     * 5 minutes old. gptChat writes the effective id back into this field.
     */
    private String sessionId;

    /**
     * Requested model type.
     * NOTE(review): gptChat does not read this to pick the model (buildGenerationParam uses a
     * fixed constant); presumably it is only copied into the question history by
     * BeanUtils.copyProperties.
     */
    private String model;

    /**
     * Assistant type (which direction the model should take).
     * NOTE(review): presumably only stored with the question history; confirm.
     */
    private String assistant;

    // Name of a file attached to the question; presumably stored with the question history.
    private String questionFileName;

    // Location of the attached file; presumably stored with the question history.
    private String questionFileUrl;

}

2、接口实现

import com.alibaba.dashscope.aigc.generation.Generation;

import com.alibaba.dashscope.aigc.generation.GenerationParam;

import com.alibaba.dashscope.aigc.generation.GenerationResult;

import com.alibaba.dashscope.aigc.generation.models.QwenParam;

import com.alibaba.dashscope.common.Message;

import com.alibaba.dashscope.common.ResultCallback;

import com.alibaba.dashscope.common.Role;

import com.alibaba.dashscope.utils.Constants;

import com.kingoffice.common.core.constant.ChatModelConstants;

import com.kingoffice.common.core.utils.DateUtils;

import com.kingoffice.common.security.utils.SecurityUtils;

import com.kingoffice.system.domain.aigc.ChatAnswerHistory;

import com.kingoffice.system.domain.aigc.ChatQuestionHistory;

import com.kingoffice.system.domain.vo.ChatHistoryVo;

import com.kingoffice.system.domain.vo.ChatQuestionVo;

import com.kingoffice.system.mapper.aigc.ChatAnswerHistoryMapper;

import com.kingoffice.system.mapper.aigc.ChatQuestionHistoryMapper;

import com.kingoffice.system.service.aigc.IChatQuestionHistoryService;

import org.apache.commons.lang3.StringUtils;

import org.springframework.beans.BeanUtils;

import org.springframework.beans.factory.annotation.Autowired;

import org.springframework.beans.factory.annotation.Value;

import org.springframework.stereotype.Service;

import reactor.core.publisher.Flux;

import java.util.ArrayList;

import java.util.Date;

import java.util.List;

/**

* 提问历史对话Service业务层处理

*

* @author kim

* @date 2024-07-19

*/

@Service

public class ChatQuestionHistoryServiceImpl implements IChatQuestionHistoryService {

@Autowired

private ChatQuestionHistoryMapper chatQuestionHistoryMapper;

@Autowired

private ChatAnswerHistoryMapper chatAnswerHistoryMapper;

@Value("${chat.accessKeyAli}")

private String accessKeyAli;

@Value("${chat.accessKeyMoon}")

private String accessKeyMoon;

@Override

public Flux<GenerationResult> gptChat(ChatQuestionVo question) {

StringBuilder contentBuilder = new StringBuilder();

Constants.apiKey = accessKeyAli; // 密钥

Generation gen = new Generation(); // 创建流

// 获取用户信息

Long userId = SecurityUtils.getUserId();

String username = SecurityUtils.getUsername();

if (StringUtils.isEmpty(question.getSessionId())) {

//创建新的会话

question.setSessionId(String.valueOf(System.currentTimeMillis()));

} else {

long sessionTime = Long.parseLong(question.getSessionId());

if (System.currentTimeMillis() - sessionTime > 5 * 60 * 1000) {

//会话超时 自动更新会话

question.setSessionId(String.valueOf(System.currentTimeMillis()));

}

}

// 加载历史message,并创建新的message

List<Message> messages = conversationHistory(question.getSessionId());

messages.add(Message.builder().role(Role.USER.getValue()).content(question.getQuestion()).build());

// 保存用户问题历史记录

GenerationParam param = buildGenerationParam(messages);

// GenerationParam param = buildQwenParam(messages);

return Flux.create(sink -> {

try {

gen.streamCall(param, new ResultCallback<GenerationResult>() {

@Override

public void onEvent(GenerationResult message) {

String newContent = message.getOutput().getChoices().get(0).getMessage().getContent();

//改造message属性 为此属性赋值sessionId

message.getOutput().getChoices().get(0).getMessage().setToolCallId(question.getSessionId());

contentBuilder.append(newContent);

if ("stop".equals(message.getOutput().getChoices().get(0).getFinishReason())) {

ChatQuestionHistory history = new ChatQuestionHistory();

BeanUtils.copyProperties(question, history);

history.setUserId(userId);

history.setCreateBy(username);

history.setMessageRole(Role.USER.getValue());

history.setCreateTime(new Date());

chatQuestionHistoryMapper.insertChatQuestionHistory(history);

ChatAnswerHistory answerHistory = new ChatAnswerHistory();

answerHistory.setId(message.getRequestId());

answerHistory.setAnswer(contentBuilder.toString());

answerHistory.setUserId(userId);

answerHistory.setSessionId(question.getSessionId());

answerHistory.setQuestionId(history.getId());

answerHistory.setMessageRole(message.getOutput().getChoices().get(0).getMessage().getRole());

answerHistory.setCreateBy(username);

answerHistory.setCreateTime(new Date());

chatAnswerHistoryMapper.insertChatAnswerHistory(answerHistory);

contentBuilder.setLength(0); // 清空StringBuilder

}

sink.next(message);

}

@Override

public void onError(Exception err) {

sink.error(err);

}

@Override

public void onComplete() {

sink.complete();

}

});

} catch (Exception e) {

sink.error(e);

}

});

}

@Override

public List<ChatHistoryVo> historyList(Long userId) {

return chatQuestionHistoryMapper.historyList(userId);

}

private GenerationParam buildGenerationParam(List<Message> messages) {

return GenerationParam.builder()

.model(ChatModelConstants.Models.QWEN_MAX)

.messages(messages)

.resultFormat(GenerationParam.ResultFormat.MESSAGE)

.topP(0.8)

.incrementalOutput(true)

.build();

}

//接入开源模型 chatGLM

private QwenParam buildQwenParam(List<Message> messages) {

return QwenParam.builder().model("chatglm3-6b").messages(messages)

.resultFormat(QwenParam.ResultFormat.MESSAGE)

.topP(0.8)

.incrementalOutput(true) // get streaming output incrementally

.build();

}

private List<Message> conversationHistory(String sessionId) {

//根据用户id 会话id 查询10条记录

List<Message> messages = new ArrayList<>();

if (StringUtils.isEmpty(sessionId)) {

return messages;

}

messages = chatAnswerHistoryMapper.conversationHistory(SecurityUtils.getUserId(), sessionId);

return messages;

}

}

3、查询方法

此处给出sql

<!--
  Returns the most recent 10 messages of one user's session, oldest first, mapped to
  DashScope Message objects (content + role).
  The inner query takes the newest 10 rows and the outer query restores chronological
  order. The previous form sorted ascending before LIMIT and so kept the OLDEST messages,
  dropping recent context once a session grew.
  Rows are ordered by question id (auto-increment, shared by a question and its answer),
  then question before answer. create_time alone is not enough: both rows of a pair are
  written within the same second into second-precision DATETIME columns.
-->
<select id="conversationHistory" resultType="com.alibaba.dashscope.common.Message">
    SELECT
        content,
        role
    FROM
        (
            SELECT
                content,
                role,
                question_id,
                seq
            FROM
                (
                    SELECT
                        question AS content,
                        message_role AS role,
                        id AS question_id,
                        0 AS seq
                    FROM
                        kim_chat_question_history
                    WHERE
                        user_id = #{userId}
                        AND session_id = #{sessionId}
                    UNION ALL
                    SELECT
                        answer AS content,
                        message_role AS role,
                        question_id,
                        1 AS seq
                    FROM
                        kim_chat_answer_history
                    WHERE
                        user_id = #{userId}
                        AND session_id = #{sessionId}
                ) AS combined_data
            ORDER BY
                question_id DESC,
                seq DESC
            LIMIT 10
        ) AS latest_data
    ORDER BY
        question_id,
        seq
</select>

<!--
  Lists one user's question/answer pairs, newest first.
  t1.id is a tiebreaker for rows with equal create_time, so the order (and therefore
  page boundaries when the caller paginates) is deterministic.
-->
<select id="historyList" resultType="com.kingoffice.system.domain.vo.ChatHistoryVo">
    SELECT
        t1.question,
        t1.id AS questionId,
        t1.create_time AS createTime,
        t2.answer,
        t2.id AS answerId
    FROM
        kim_chat_question_history t1
        INNER JOIN kim_chat_answer_history t2 ON t1.id = t2.question_id
    WHERE
        t1.user_id = #{userId}
    ORDER BY
        t1.create_time DESC,
        t1.id DESC
</select>

四、调用展示

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

五、总结

至此,后端已实现多轮对话以及流式输出。代码仍有待完善的地方,比如token的控制、参数校验、会话时长等都是需要考虑的问题,后续再给出具体的优化建议。另外,Spring新推出了Spring AI组件,可以快速实现众多大模型的调用,阿里也推出了类似的微服务组件,感兴趣的小伙伴可以去看下,此处不再赘述。下一篇我们将展示前端调用后端流式接口的demo。



声明

本文内容仅代表作者观点,或转载于其他网站,本站不以此文作为商业用途
如有涉及侵权,请联系本站进行删除
转载本站原创文章,请注明来源及作者。