基于ChatMemory打造AI取名大师

极客Kimi 2024-07-29 14:01:03 阅读 84

当我们真正开发一款应用时,存储用户与大模型的历史对话是非常重要的,因为大模型需要利用到这些历史对话来理解用户最近一句话到底是什么意思。

比如你跟大模型说“换一个”,如果大模型不基于历史对话来分析,那么大模型根本就不知道你到底想换什么,而ChatMemory真是LangChain4j提供的用来存储历史对话的组件,并且还支持窗口限制、淘汰机制、持久化机制等等扩展功能。

ChatMemory取名大师

我们先回顾一下第一节实现历史对话功能的Demo:

<code>public class _01_HelloWorld {

public static void main(String[] args) {

ChatLanguageModel model = OpenAiChatModel.builder()

.baseUrl("http://langchain4j.dev/demo/openai/v1")

.apiKey("demo")

.build();

UserMessage userMessage1 = UserMessage.userMessage("你好,我是Timi");

Response<AiMessage> response1 = model.generate(userMessage1);

AiMessage aiMessage1 = response1.content(); // 大模型的第一次响应

System.out.println(aiMessage1.text());

System.out.println("----");

// 下面一行代码是重点

Response<AiMessage> response2 = model.generate(userMessage1, aiMessage1, UserMessage.userMessage("我叫什么"));

AiMessage aiMessage2 = response2.content(); // 大模型的第二次响应

System.out.println(aiMessage2.text());

}

}

这种实现方式太过麻烦了,我们用ChatMemory来优化,注意ChatMemory需要基于AiService来使用:

package com.timi;

import dev.langchain4j.data.message.AiMessage;

import dev.langchain4j.data.message.UserMessage;

import dev.langchain4j.memory.ChatMemory;

import dev.langchain4j.memory.chat.MessageWindowChatMemory;

import dev.langchain4j.model.chat.ChatLanguageModel;

import dev.langchain4j.model.openai.OpenAiChatModel;

import dev.langchain4j.model.output.Response;

import dev.langchain4j.service.AiServices;

import dev.langchain4j.service.SystemMessage;

public class _03_ChatMemory {

interface NamingMaster {

String talk(String desc);

}

public static void main(String[] args) {

ChatLanguageModel model = OpenAiChatModel.builder()

.baseUrl("http://langchain4j.dev/demo/openai/v1")

.apiKey("demo")

.build();

ChatMemory chatMemory = MessageWindowChatMemory.withMaxMessages(10);

NamingMaster namingMaster = AiServices.builder(NamingMaster.class)

.chatLanguageModel(model)

.chatMemory(chatMemory)

.build();

System.out.println(namingMaster.talk("帮我取一个很有中国文化内涵的男孩名字,给我一个你觉得最好的就行了"));

System.out.println("---");

System.out.println(namingMaster.talk("换一个"));

}

}

代码执行结果:

岳霖 (Yuè Lín)

---

岳华 (Yuè Huá)

首先定义一个NamingMaster表示取名大师,通过talk()方法来和大师进行交流,最终得到一个满意的名字。

在构造NamingMaster代理对象时,我们除开设置了ChatLanguageModel,还设置了一个ChatMemory对象,而这个ChatMemory对象就是用来存储历史对话记录的,比如我说的“换一个”时候,大模型是知道到底要换的是什么,从而给了我另外一个名字。

MessageWindowChatMemory

ChatMemory是一个接口,默认提供了两个实现类:

MessageWindowChatMemoryTokenWindowChatMemory

而这两个实现类内部都有一个ChatMemoryStore属性,ChatMemoryStore也是一个接口,默认有一个InMemoryChatMemoryStore实现类,该类的实现比较简单:

public class InMemoryChatMemoryStore implements ChatMemoryStore {

private final Map<Object, List<ChatMessage>> messagesByMemoryId = new ConcurrentHashMap<>();

public InMemoryChatMemoryStore() { }

@Override

public List<ChatMessage> getMessages(Object memoryId) {

return messagesByMemoryId.computeIfAbsent(memoryId, ignored -> new ArrayList<>());

}

@Override

public void updateMessages(Object memoryId, List<ChatMessage> messages) {

messagesByMemoryId.put(memoryId, messages);

}

@Override

public void deleteMessages(Object memoryId) {

messagesByMemoryId.remove(memoryId);

}

}

本质上就是一个ConcurrentHashMap,所以原理上我们可以自定义ChatMemoryStore的实现类来实现将ChatMessage持久化到磁盘,比如:

static class PersistentChatMemoryStore implements ChatMemoryStore {

private final DB db = DBMaker.fileDB("chat-memory.db").transactionEnable().make();

private final Map<String, String> map = db.hashMap("messages", STRING, STRING).createOrOpen();

@Override

public List<ChatMessage> getMessages(Object memoryId) {

String json = map.get((String) memoryId);

return messagesFromJson(json);

}

@Override

public void updateMessages(Object memoryId, List<ChatMessage> messages) {

String json = messagesToJson(messages);

map.put((String) memoryId, json);

db.commit();

}

@Override

public void deleteMessages(Object memoryId) {

map.remove((String) memoryId);

db.commit();

}

}

需要添加依赖:

<dependency>

<groupId>org.mapdb</groupId>

<artifactId>mapdb</artifactId>

<version>3.0.9</version>

<exclusions>

<exclusion>

<groupId>org.jetbrains.kotlin</groupId>

<artifactId>kotlin-stdlib</artifactId>

</exclusion>

</exclusions>

</dependency>

这样我们就可以自己定义ChatMemory从而实现持久化了:

ChatMemory chatMemory = MessageWindowChatMemory.builder()

.chatMemoryStore(new PersistentChatMemoryStore())

.maxMessages(10)

.build();

这里我们仍然利用的是MessageWindowChatMemory,只是修改了chatMemoryStore属性,同样我们也可以修改TokenWindowChatMemory,这里就不再重复演示了。

那么MessageWindowChatMemory除开可以存储ChatMessage之外,还有什么特殊的吗?

我们直接看它的add()方法实现:

@Override

public void add(ChatMessage message) {

// 从ChatMemoryStore获取当前所存储的ChatMessage

List<ChatMessage> messages = messages();

// 如果待添加的是SystemMessage

if (message instanceof SystemMessage) {

Optional<SystemMessage> systemMessage = findSystemMessage(messages);

if (systemMessage.isPresent()) {

// 如果存在相同的SystemMessage,则什么都不做,直接返回

if (systemMessage.get().equals(message)) {

return; // do not add the same system message

} else {

messages.remove(systemMessage.get()); // need to replace existing system message

}

}

}

// 添加

messages.add(message);

// 如果超过了maxMessages限制,则会淘汰List最前面的,也就是最旧的ChatMessage

// 注意,SystemMessage不会被淘汰

ensureCapacity(messages, maxMessages);

// 将改变了的List更新到ChatMemoryStore中

store.updateMessages(id, messages);

}

从以上源码可以看出MessageWindowChatMemory有淘汰机制,可以设置maxMessages,超过maxMessages会淘汰最旧的ChatMessage,SystemMessage不会被淘汰。

TokenWindowChatMemory

TokenWindowChatMemory和MessageWindowChatMemory类似,区别在于计算容量的方式不一样,MessageWindowChatMemory直接取的是List的大小,而TokenWindowChatMemory会利用指定的Tokenizer对List对应的Token数进行估算,然后和设置的maxTokens进行比较,超过maxTokens也会进行淘汰,也是淘汰最旧的ChatMessage。

Tokenizer是一个接口,默认提供了OpenAiTokenizer实现类,是用来估算一条ChatMessage对应多少个Token的,很多大模型的API都是按使用的Token数来收费的,所以在对成本比较敏感时,建议使用TokenWindowChatMemory来对一个会话使用的总Token数进行控制。

独立ChatMemory

我们再看一眼之前的代码:

public static void main(String[] args) {

ChatLanguageModel model = OpenAiChatModel.builder()

.baseUrl("http://langchain4j.dev/demo/openai/v1")

.apiKey("demo")

.build();

ChatMemory chatMemory = MessageWindowChatMemory.builder()

.chatMemoryStore(new PersistentChatMemoryStore())

.maxMessages(10)

.build();

NamingMaster namingMaster = AiServices.builder(NamingMaster.class)

.chatLanguageModel(model)

.chatMemory(chatMemory)

.build();

System.out.println(namingMaster.talk("帮我取一个很有中国文化内涵的男孩名字,给我一个你觉得最好的就行了"));

System.out.println("---");

System.out.println(namingMaster.talk("换一个"));

}

以上代码有什么问题吗?如果只有一个用户用是没问题的,那如果有多个用户用呢?

比如NamingMaster代理对象被多个用户同时使用,那么这多个用户使用的是同一个ChatMemory,那就会出现这多个用户的对话记录混杂在了一起,这肯定是有问题的,所以需要有一种机制能够使得每个用户对应一个ChatMemory。

所以MessageWindowChatMemory和TokenWindowChatMemory其实都还有一个id属性,而具体的id值则有用于使用时动态传入。

我们改造一下AiServices中设置ChatMemory的方式:

NamingMaster namingMaster = AiServices.builder(NamingMaster.class)

.chatLanguageModel(model)

.chatMemoryProvider(userId -> MessageWindowChatMemory.withMaxMessages(10))

.build();

以上代码表示,NamingMaster代理对象对应的ChatMemory并不是固定的,会根据设置的ChatMemoryProvider来提供,而ChatMemoryProvider是一个Lambda表达式,意思是每个不同的userId对应不同的ChatMemory对象。

同时,我们也需要改造talk()方法来支持动态传入userId:

interface NamingMaster {

String talk(@MemoryId String userId, @UserMessage String desc);

}

完整代码:

package com.timi;

import dev.langchain4j.agent.tool.P;

import dev.langchain4j.data.message.AiMessage;

import dev.langchain4j.data.message.ChatMessage;

import dev.langchain4j.memory.ChatMemory;

import dev.langchain4j.memory.chat.MessageWindowChatMemory;

import dev.langchain4j.model.chat.ChatLanguageModel;

import dev.langchain4j.model.openai.OpenAiChatModel;

import dev.langchain4j.model.output.Response;

import dev.langchain4j.service.AiServices;

import dev.langchain4j.service.MemoryId;

import dev.langchain4j.service.SystemMessage;

import dev.langchain4j.service.UserMessage;

import dev.langchain4j.store.memory.chat.ChatMemoryStore;

import org.mapdb.DB;

import org.mapdb.DBMaker;

import java.util.List;

import java.util.Map;

import static dev.langchain4j.data.message.ChatMessageDeserializer.messagesFromJson;

import static dev.langchain4j.data.message.ChatMessageSerializer.messagesToJson;

import static org.mapdb.Serializer.STRING;

public class _03_ChatMemory {

interface NamingMaster {

String talk(@MemoryId String userId, @UserMessage String desc);

}

public static void main(String[] args) {

ChatLanguageModel model = OpenAiChatModel.builder()

.baseUrl("http://langchain4j.dev/demo/openai/v1")

.apiKey("demo")

.build();

NamingMaster namingMaster = AiServices.builder(NamingMaster.class)

.chatLanguageModel(model)

.chatMemoryProvider(userId -> MessageWindowChatMemory.withMaxMessages(10))

.build();

System.out.println(namingMaster.talk("1", "帮我取一个很有中国文化内涵的男孩名字,给我一个你觉得最好的就行了"));

System.out.println("---");

System.out.println(namingMaster.talk("2", "换一个"));

}

static class PersistentChatMemoryStore implements ChatMemoryStore {

private final DB db = DBMaker.fileDB("chat-memory.db").transactionEnable().make();

private final Map<String, String> map = db.hashMap("messages", STRING, STRING).createOrOpen();

@Override

public List<ChatMessage> getMessages(Object memoryId) {

String json = map.get((String) memoryId);

return messagesFromJson(json);

}

@Override

public void updateMessages(Object memoryId, List<ChatMessage> messages) {

String json = messagesToJson(messages);

map.put((String) memoryId, json);

db.commit();

}

@Override

public void deleteMessages(Object memoryId) {

map.remove((String) memoryId);

db.commit();

}

}

}

由于以上代码传入的userId不同,所以代码执行结果为:

玉山 (Yushan)

---

好的,请问您想要换成什么样的内容呢?

这就表示,两个不同的用户使用的是独立的ChatMemory。

AiServices整合ChatMemory源码分析

最后,我们再来看看AiServices中是如何利用ChatMemory来实现对话历史记录的。

视线转移到第二节提到的DefaultAiServices中的代理对象中的invoke()方法中,在第二节我们解析了invoke()方法源码中会根据当前调用的方法信息和参数解析出SystemMessage和UserMessage,然后就会执行以下代码:

Object memoryId = memoryId(method, args).orElse(DEFAULT);

memoryId()方法其实就是解析方法参数中加了@MemoryId注解的参数值,我们的案例就是传入的userId,仅接着就会执行:

if (context.hasChatMemory()) {

// 根据memoryId获取或创建ChatMemory

ChatMemory chatMemory = context.chatMemory(memoryId);

// 将SystemMessage、UserMessage添加到ChatMemory中

systemMessage.ifPresent(chatMemory::add);

chatMemory.add(userMessage);

}

这里的context为AiServiceContext,它内部有一个chatMemories属性,类型为Map<Object, ChatMemory> ,就是专门用来存储memoryId和ChatMemory对象之间的映射关系的。

以上代码只是新增一条UserMessage,而传入给大模型的得是所有的对话历史,所以后续会执行:

List<ChatMessage> messages;

if (context.hasChatMemory()) {

messages = context.chatMemory(memoryId).messages();

} else {

messages = new ArrayList<>();

systemMessage.ifPresent(messages::add);

messages.add(userMessage);

}

根据memoryId把对应的ChatMemory中存储的所有ChatMessage获取出来,然后传入给大模型就可以了。

本节总结

以上就是关于ChatMemory的作用和实现原理,在实际应用开发中,ChatMemory的作用是重要的,下一节将介绍LangChain4j的工具机制时,其中也离不开ChatMemory的应用的,敬请期待。



声明

本文内容仅代表作者观点,或转载于其他网站,本站不以此文作为商业用途
如有涉及侵权,请联系本站进行删除
转载本站原创文章,请注明来源及作者。