Open-Source Model Applications in Practice - Business Integration - Calling the AI Service in Multiple Ways (Part 1)

开源技术探险家 (Open-Source Tech Explorer), 2024-09-11 08:01:05

1. Preface

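    Having deployed qwen-7b-chat and integrated it with vLLM for inference acceleration, we have built a high-performance, reliable and secure AI service. The next step is to integrate it with concrete business scenarios and deliver complete, production-ready features.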

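    Upstream (caller) teams usually connect to downstream services in the most common way. To call our AI service from Java, we show three clients in turn: the JDK's built-in HttpURLConnection, OkHttp, and an asynchronous HttpClient-style client (AsyncHttpClient), so that callers can choose whichever fits their existing stack.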


2. Terminology

2.1. OkHttp

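    OkHttp is an open-source HTTP client library for Java and Kotlin. It supports HTTP/1.1 and HTTP/2 and provides connection pooling, automatic retries, response caching and interceptors. It supports both synchronous and asynchronous requests, integrates easily with other platforms and frameworks, and is one of the most widely used networking libraries in Android development. With OkHttp, developers can send HTTP requests, process responses and manage connections with little code, and connection reuse reduces the overhead of repeated requests.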

2.2. HttpClient

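    HttpClient (usually Apache HttpClient) is an open-source library for sending HTTP requests and receiving HTTP responses. It offers a convenient way to communicate with web servers and perform HTTP operations such as GET and POST requests, and is typically used in client applications or services that talk to web servers or web APIs. Its features include connection management, authentication, request and response interception, and cookie management.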
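Since Java 11 the JDK also ships its own client, java.net.http.HttpClient, which covers much of the same ground without a third-party dependency. It is not one of the clients used in this article; the following is a swapped-in sketch that assumes Java 11+ and reuses the headers from the examples in section 4 (userId, secret, Accept: text/event-stream). The class name JdkHttpClientSketch and the method postAndReadLines are hypothetical. The body is consumed line by line as it arrives, which suits a streamed response.

```java
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;

// Hypothetical helper: streams a chat response with the JDK's built-in HttpClient (Java 11+)
public class JdkHttpClientSketch {

    // Share one client; it manages its own connection pool.
    // HTTP/1.1 is pinned because the service URL is plain http://
    private static final HttpClient CLIENT = HttpClient.newBuilder()
            .version(HttpClient.Version.HTTP_1_1)
            .connectTimeout(Duration.ofSeconds(3))
            .build();

    // POSTs a JSON body and prints each response line as soon as it is received
    public static List<String> postAndReadLines(String url, String json, String userId, String secret)
            throws Exception {
        HttpRequest request = HttpRequest.newBuilder(URI.create(url))
                .header("Content-Type", "application/json; charset=utf-8")
                .header("Accept", "text/event-stream")
                .header("userId", userId)
                .header("secret", secret)
                .POST(HttpRequest.BodyPublishers.ofString(json, StandardCharsets.UTF_8))
                .build();
        // send() may return before the body is fully received; the Stream is filled lazily
        // and decoded with the charset from Content-Type (UTF-8 if none is given)
        HttpResponse<Stream<String>> response = CLIENT.send(request, HttpResponse.BodyHandlers.ofLines());
        System.out.println("Response Code: " + response.statusCode());
        List<String> lines = new ArrayList<>();
        try (Stream<String> body = response.body()) {
            body.forEach(line -> {
                System.out.println(line);
                lines.add(line);
            });
        }
        return lines;
    }
}
```

Closing the Stream in try-with-resources releases the underlying connection even if the caller stops reading early.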

2.3. HttpURLConnection

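    HttpURLConnection (java.net.HttpURLConnection) is the Java standard library class for sending HTTP requests and receiving HTTP responses. It lets you open an HTTP connection, set the request method (GET, POST, etc.), request headers and request body, and send the request to a given URL. It also provides methods for reading the response status code, headers and body.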

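    OkHttp and HttpClient offer richer features (interceptors, asynchronous calls, configurable connection pools) and generally better performance, so they are the better choice in most cases. HttpURLConnection ships with the Java standard library and offers only basic HTTP functionality, but it requires no third-party dependency.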


3. Prerequisites

3.1. Deploy the Qwen-7b-Chat (or Qwen-1_8B-Chat) model locally or on a server

    See the article series "开源模型应用落地-qwen-7b-chat与vllm实现推理加速的正确姿势" (the right way to accelerate qwen-7b-chat inference with vLLM).

3.2. Wrap the model in an external service API that hides the differences between calling different models

    See the article series "开源模型应用落地-qwen-7b-chat与vllm实现推理加速的正确姿势" (the right way to accelerate qwen-7b-chat inference with vLLM).
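The examples in section 4 send this interface a JSON body with the fields prompt, history (a list of {"from", "value"} turns), top_p, temperature, repetition_penalty and max_new_tokens, assembled with String.replace. That breaks as soon as the question contains a double quote, a backslash or a newline, and a question containing the text %x would also be overwritten by the history. A minimal JDK-only sketch that escapes user text before building the body follows; ChatRequestJson and Turn are hypothetical names, and in a real project a JSON library such as Jackson or fastjson would normally do this work.

```java
import java.util.List;

// Hypothetical helper that builds the chat request body with proper JSON string escaping
public class ChatRequestJson {

    // One conversation turn, mirroring the {"from": ..., "value": ...} objects in the history array
    public static final class Turn {
        public final String from;
        public final String value;

        public Turn(String from, String value) {
            this.from = from;
            this.value = value;
        }
    }

    // Escapes a Java string for use inside a JSON string literal (RFC 8259)
    public static String escape(String s) {
        StringBuilder sb = new StringBuilder(s.length() + 16);
        for (int i = 0; i < s.length(); i++) {
            char c = s.charAt(i);
            switch (c) {
                case '"': sb.append("\\\""); break;
                case '\\': sb.append("\\\\"); break;
                case '\n': sb.append("\\n"); break;
                case '\r': sb.append("\\r"); break;
                case '\t': sb.append("\\t"); break;
                default:
                    if (c < 0x20) {
                        // Other control characters must use the \\uXXXX form
                        sb.append(String.format("\\u%04x", (int) c));
                    } else {
                        sb.append(c);
                    }
            }
        }
        return sb.toString();
    }

    // Builds the body with the same sampling parameters as the HttpURLConnection example
    public static String build(String prompt, List<Turn> history) {
        StringBuilder sb = new StringBuilder("{\"prompt\":\"").append(escape(prompt)).append("\",\"history\":[");
        for (int i = 0; i < history.size(); i++) {
            Turn t = history.get(i);
            if (i > 0) {
                sb.append(',');
            }
            sb.append("{\"from\":\"").append(escape(t.from))
              .append("\",\"value\":\"").append(escape(t.value)).append("\"}");
        }
        sb.append("],\"top_p\":0.9,\"temperature\":0.45,\"repetition_penalty\":1.1,\"max_new_tokens\":8192}");
        return sb.toString();
    }
}
```

Because each field is escaped independently, no placeholder in one value can be mistaken for another.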


4. Implementation

4.1. HttpURLConnection

```java
import lombok.extern.slf4j.Slf4j;

import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.io.Reader;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.charset.StandardCharsets;

@Slf4j
public class QWenCallTest {

    private static final String url = "http://127.0.0.1:9999/api/chat";

    private static final String DEFAULT_TEMPLATE = "{\"prompt\":\"%s\",\"history\":%x,\"top_p\":0.9, \"temperature\":0.45,\"repetition_penalty\":1.1, \"max_new_tokens\":8192}";

    private static final String DEFAULT_USERID = "xxxxx";

    private static final String DEFAULT_SECRET = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";

    private static final int DEFAULT_CONNECTION_TIMEOUT = 3 * 1000;

    private static final int DEFAULT_READ_TIMEOUT = 30 * 1000;

    public static void main(String[] args) {
        String question = "Are there any good places to eat or have fun near my home?";
        // Previous turns; each element has "from" (user/assistant) and "value"
        String history = "[{\n" +
                "\"from\": \"user\",\n" +
                "\"value\": \"Hello\"\n" +
                "},\n" +
                "{\n" +
                "\"from\": \"assistant\",\n" +
                "\"value\": \"Hello! How can I help you?\"\n" +
                "},\n" +
                "{\n" +
                "\"from\": \"user\",\n" +
                "\"value\": \"I live in Guangzhou. What about you?\"\n" +
                "},\n" +
                "{\n" +
                "\"from\": \"assistant\",\n" +
                "\"value\": \"I am an AI assistant and don't have a home of my own.\"\n" +
                "}]";

        // Fill the placeholders (note: question is inserted without JSON escaping)
        String prompt = DEFAULT_TEMPLATE.replace("%s", question).replace("%x", history);
        log.info("prompt: {}", prompt);

        HttpURLConnection conn = null;
        try {
            // 1. Create the URL and open the connection
            conn = (HttpURLConnection) new URL(url).openConnection();
            // 2. Set the request method and headers
            conn.setRequestMethod("POST");
            conn.setRequestProperty("Content-Type", "application/json;charset=utf-8");
            conn.setRequestProperty("Accept", "text/event-stream");
            conn.setRequestProperty("userId", DEFAULT_USERID);
            conn.setRequestProperty("secret", DEFAULT_SECRET);
            conn.setDoOutput(true);
            conn.setDoInput(true);
            // Connect timeout: 3 seconds
            conn.setConnectTimeout(DEFAULT_CONNECTION_TIMEOUT);
            // Read timeout: 30 seconds, applied to each blocking read rather than the whole response
            conn.setReadTimeout(DEFAULT_READ_TIMEOUT);

            // 3. Write the JSON request body; closing the stream flushes it
            try (OutputStream os = conn.getOutputStream()) {
                os.write(prompt.getBytes(StandardCharsets.UTF_8));
            }

            // 4. Read the response; for non-200 status codes the body is in the error stream
            int responseCode = conn.getResponseCode();
            log.info("Response Code: {}", responseCode);
            InputStream is = responseCode == HttpURLConnection.HTTP_OK ? conn.getInputStream() : conn.getErrorStream();
            if (is == null) {
                // getErrorStream() returns null when the server sent no body
                return;
            }
            // InputStreamReader decodes incrementally, so a multi-byte UTF-8 character split across
            // two network reads is not garbled (decoding each byte[] chunk on its own would break it)
            try (Reader reader = new InputStreamReader(is, StandardCharsets.UTF_8)) {
                char[] buffer = new char[1024];
                int len;
                while ((len = reader.read(buffer)) != -1) {
                    log.info(new String(buffer, 0, len));
                }
            }
        } catch (Exception e) {
            log.error("Error calling the model API", e);
        } finally {
            if (conn != null) {
                conn.disconnect();
            }
        }
    }
}
```

4.2. OkHttp

```java
import lombok.extern.slf4j.Slf4j;
import okhttp3.*;

import java.io.Reader;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

@Slf4j
public class QWenCallTest {

    private static final String url = "http://127.0.0.1:9999/api/chat";

    private static final String DEFAULT_TEMPLATE = "{\"prompt\":\"%s\",\"history\":%x,\"top_p\":0.9, \"temperature\":0.45,\"repetition_penalty\":1.2, \"max_new_tokens\":8192}";

    private static final long DEFAULT_CONNECTION_TIMEOUT = 3 * 1000;

    private static final long DEFAULT_WRITE_TIMEOUT = 15 * 1000;

    // Applied to each read, not to the whole streamed response
    private static final long DEFAULT_READ_TIMEOUT = 15 * 1000;

    // OkHttpClient holds a connection pool and thread pools; share one instance instead of building one per call
    private static final OkHttpClient okHttpClient = new OkHttpClient.Builder()
            .connectTimeout(DEFAULT_CONNECTION_TIMEOUT, TimeUnit.MILLISECONDS)
            .writeTimeout(DEFAULT_WRITE_TIMEOUT, TimeUnit.MILLISECONDS)
            .readTimeout(DEFAULT_READ_TIMEOUT, TimeUnit.MILLISECONDS)
            .build();

    private static Request.Builder buildHeader(Request.Builder builder) {
        return builder
                .addHeader("Content-Type", "application/json; charset=utf-8")
                .addHeader("Accept", "text/event-stream")
                .addHeader("userId", "xxxxx")
                .addHeader("secret", "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx");
    }

    private static Request buildRequest(String prompt) {
        // Request body; without an explicit charset OkHttp 3.x encodes the string as UTF-8
        MediaType mediaType = MediaType.parse("application/json");
        RequestBody requestBody = RequestBody.create(mediaType, prompt);
        return buildHeader(new Request.Builder().post(requestBody)).url(url).build();
    }

    public static void chat(String question, String history, CountDownLatch countDownLatch) {
        // Fill the request parameters into the template
        String prompt = DEFAULT_TEMPLATE.replace("%s", question).replace("%x", history);
        log.info("prompt: {}", prompt);
        Request request = buildRequest(prompt);
        // execute() is synchronous; closing the Response (try-with-resources) releases the connection
        try (Response response = okHttpClient.newCall(request).execute()) {
            ResponseBody body = response.body();
            if (body == null) {
                return;
            }
            if (response.code() == 200) {
                // charStream() decodes incrementally (UTF-8 by default), so multi-byte characters
                // split across network chunks are not garbled
                Reader reader = body.charStream();
                char[] buffer = new char[1024];
                int len;
                while ((len = reader.read(buffer)) != -1) {
                    log.info(new String(buffer, 0, len));
                }
            } else {
                // Error responses: log the raw body, which is not guaranteed to be JSON
                log.info("Response Code: {}, body: {}", response.code(), body.string());
            }
        } catch (Throwable e) {
            log.error("Execution error", e);
        } finally {
            countDownLatch.countDown();
        }
    }

    public static void main(String[] args) {
        CountDownLatch countDownLatch = new CountDownLatch(1);
        String question = "Are there any good places to eat or have fun near my home?";
        String history = "[{\n" +
                "\"from\": \"user\",\n" +
                "\"value\": \"Hello\"\n" +
                "},\n" +
                "{\n" +
                "\"from\": \"assistant\",\n" +
                "\"value\": \"Hello! How can I help you?\"\n" +
                "},\n" +
                "{\n" +
                "\"from\": \"user\",\n" +
                "\"value\": \"I live in Guangzhou. What about you?\"\n" +
                "},\n" +
                "{\n" +
                "\"from\": \"assistant\",\n" +
                "\"value\": \"I am an AI assistant and don't have a home of my own.\"\n" +
                "}]";
        // Streamed output; chat() is synchronous, so the elapsed time covers the whole response
        long starttime = System.currentTimeMillis();
        chat(question, history, countDownLatch);
        long endtime = System.currentTimeMillis();
        log.info("Elapsed: {} ms", endtime - starttime);
        try {
            countDownLatch.await();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }
}
```
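The HttpURLConnection and AsyncHttpClient examples send Accept: text/event-stream, yet all three examples simply log raw chunks. If the service actually frames its output as Server-Sent Events (an assumption; the wire format of /api/chat is not shown in this article), each event is one or more `data:` lines terminated by a blank line. A minimal JDK-only parser that turns a character stream into event payloads might look like this; SseLineParser is a hypothetical name, and it deliberately ignores comment lines and the event, id and retry fields.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.util.function.Consumer;

// Hypothetical minimal Server-Sent Events parser: only the "data" field is handled
public class SseLineParser {

    // Reads SSE lines and calls onEvent with each event's data (multiple data lines joined by '\n')
    public static void parse(Reader reader, Consumer<String> onEvent) throws IOException {
        BufferedReader in = new BufferedReader(reader); // readLine() accepts \n, \r and \r\n
        StringBuilder data = new StringBuilder();
        boolean hasData = false;
        String line;
        while ((line = in.readLine()) != null) {
            if (line.isEmpty()) {
                // Blank line: dispatch the buffered event, if any
                if (hasData) {
                    onEvent.accept(data.toString());
                }
                data.setLength(0);
                hasData = false;
            } else if (line.startsWith("data:")) {
                String value = line.substring(5);
                // The SSE format strips a single leading space after the colon
                if (value.startsWith(" ")) {
                    value = value.substring(1);
                }
                if (hasData) {
                    data.append('\n');
                }
                data.append(value);
                hasData = true;
            }
            // Comment lines (":...") and other fields are ignored in this sketch
        }
        // An event not terminated by a blank line before end of stream is discarded, as in the SSE spec
    }
}
```

With OkHttp it could be fed directly from the response, for example `SseLineParser.parse(response.body().charStream(), System.out::println)`.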

Maven dependency:

```xml
<dependency>
    <groupId>com.squareup.okhttp3</groupId>
    <artifactId>okhttp</artifactId>
    <version>3.14.9</version>
</dependency>
```

4.3. HttpClient (asynchronous, using AsyncHttpClient)

```java
import lombok.extern.slf4j.Slf4j;
import org.asynchttpclient.AsyncHttpClient;
import org.asynchttpclient.AsyncHttpClientConfig;
import org.asynchttpclient.DefaultAsyncHttpClient;
import org.asynchttpclient.DefaultAsyncHttpClientConfig;
import org.asynchttpclient.channel.DefaultKeepAliveStrategy;

import java.io.IOException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

@Slf4j
public class QwenCallTest {

    private static final String url = "http://127.0.0.1:9999/api/chat";

    private static final String DEFAULT_TEMPLATE = "{\"prompt\":\"%s\",\"history\":%x,\"top_p\":0.9, \"temperature\":0.45,\"repetition_penalty\":1.1, \"max_new_tokens\":8192}";

    private static final int DEFAULT_CONNECTION_TIMEOUT = 3 * 1000;

    // Upper bound for the whole request including the streamed body: longer generations are cut off
    private static final int DEFAULT_REQUEST_TIMEOUT = 15 * 1000;

    // Maximum idle time between two reads
    private static final int DEFAULT_READ_TIMEOUT = 15 * 1000;

    private static final String DEFAULT_USERID = "xxxxx";

    private static final String DEFAULT_SECRET = "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx";

    private static final AsyncHttpClientConfig asyncHttpClientConfig = new DefaultAsyncHttpClientConfig.Builder()
            .setConnectTimeout(DEFAULT_CONNECTION_TIMEOUT)
            .setReadTimeout(DEFAULT_READ_TIMEOUT)
            .setRequestTimeout(DEFAULT_REQUEST_TIMEOUT)
            .setTcpNoDelay(true)
            .setMaxConnections(1_000_000)
            .setMaxConnectionsPerHost(100_000)
            .setMaxRequestRetry(0)
            .setSoReuseAddress(true)
            .setKeepAlive(true)
            .setKeepAliveStrategy(new DefaultKeepAliveStrategy())
            .build();

    // One shared client for the whole application; it must be closed on shutdown
    public static final AsyncHttpClient ugc_asyncHttpClient = new DefaultAsyncHttpClient(asyncHttpClientConfig);

    public static void chat(String question, String history, CountDownLatch countDownLatch) {
        String prompt = DEFAULT_TEMPLATE.replace("%s", question).replace("%x", history);
        log.info("prompt: {}", prompt);
        try {
            ugc_asyncHttpClient.preparePost(url)
                    .addHeader("Content-Type", "application/json; charset=utf-8")
                    .addHeader("userId", DEFAULT_USERID)
                    .addHeader("secret", DEFAULT_SECRET)
                    .addHeader("Accept", "text/event-stream")
                    .setBody(prompt)
                    .execute(new QwenStreamHandler(countDownLatch))
                    // get() blocks until the response completes or 30 s pass
                    .get(30, TimeUnit.SECONDS);
        } catch (Exception e) {
            log.error("Request failed, prompt: {}", prompt, e);
            // Release the waiting caller even if the handler never completes
            countDownLatch.countDown();
        }
    }

    public static void main(String[] args) {
        CountDownLatch countDownLatch = new CountDownLatch(1);
        String question = "Are there any good places to eat or have fun near my home?";
        String history = "[{\n" +
                "\"from\": \"user\",\n" +
                "\"value\": \"Hello\"\n" +
                "},\n" +
                "{\n" +
                "\"from\": \"assistant\",\n" +
                "\"value\": \"Hello! How can I help you?\"\n" +
                "},\n" +
                "{\n" +
                "\"from\": \"user\",\n" +
                "\"value\": \"I live in Guangzhou. What about you?\"\n" +
                "},\n" +
                "{\n" +
                "\"from\": \"assistant\",\n" +
                "\"value\": \"I am an AI assistant and don't have a home of my own.\"\n" +
                "}]";
        chat(question, history, countDownLatch);
        try {
            // Wait for the stream to finish; the timeout guards against a handler that never counts down
            countDownLatch.await(60, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        try {
            ugc_asyncHttpClient.close();
        } catch (IOException e) {
            log.error("Failed to close the client", e);
        }
    }
}
```

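```java
import io.netty.handler.codec.http.HttpHeaders;
import lombok.extern.slf4j.Slf4j;
import org.asynchttpclient.HttpResponseBodyPart;
import org.asynchttpclient.HttpResponseStatus;
import org.asynchttpclient.handler.StreamedAsyncHandler;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;

import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.charset.CharsetDecoder;
import java.nio.charset.CodingErrorAction;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.CountDownLatch;

@Slf4j
public class QwenStreamHandler implements StreamedAsyncHandler<String> {

    private final CountDownLatch countDownLatch;

    // Stateful UTF-8 decoder: bytes of a character split across two body parts are kept until the rest arrives
    private final CharsetDecoder decoder = StandardCharsets.UTF_8.newDecoder()
            .onMalformedInput(CodingErrorAction.REPLACE)
            .onUnmappableCharacter(CodingErrorAction.REPLACE);

    private ByteBuffer pending = ByteBuffer.allocate(0);

    public QwenStreamHandler(CountDownLatch countDownLatch) {
        this.countDownLatch = countDownLatch;
    }

    // Decodes all complete characters in pending + bytes and keeps an incomplete trailing sequence
    private String decode(byte[] bytes) {
        ByteBuffer in = ByteBuffer.allocate(pending.remaining() + bytes.length);
        in.put(pending).put(bytes).flip();
        // UTF-8 never yields more chars than bytes, so this buffer cannot overflow
        CharBuffer out = CharBuffer.allocate(in.remaining());
        decoder.decode(in, out, false);
        pending = in.slice();
        out.flip();
        return out.toString();
    }

    @Override
    public State onStream(Publisher<HttpResponseBodyPart> publisher) {
        publisher.subscribe(new Subscriber<HttpResponseBodyPart>() {
            @Override
            public void onSubscribe(Subscription subscription) {
                // Request all parts; the model output is consumed as fast as it arrives
                subscription.request(Long.MAX_VALUE);
            }

            @Override
            public void onNext(HttpResponseBodyPart part) {
                try {
                    log.info(decode(part.getBodyPartBytes()));
                } catch (Throwable e) {
                    log.error("System error", e);
                }
            }

            @Override
            public void onError(Throwable throwable) {
                log.error("System error", throwable);
                countDownLatch.countDown();
            }

            @Override
            public void onComplete() {
                countDownLatch.countDown();
            }
        });
        return State.CONTINUE;
    }

    @Override
    public State onStatusReceived(HttpResponseStatus responseStatus) throws Exception {
        log.info("onStatusReceived: {}", responseStatus.getStatusCode());
        if (responseStatus.getStatusCode() != 200) {
            // Aborting means onStream is never reached, so release the waiting caller here
            countDownLatch.countDown();
            return State.ABORT;
        }
        return State.CONTINUE;
    }

    @Override
    public State onHeadersReceived(HttpHeaders headers) throws Exception {
        return State.CONTINUE;
    }

    @Override
    public State onBodyPartReceived(HttpResponseBodyPart bodyPart) throws Exception {
        return State.CONTINUE;
    }

    @Override
    public void onThrowable(Throwable t) {
        log.error("onThrowable", t);
        countDownLatch.countDown();
    }

    @Override
    public String onCompleted() throws Exception {
        // The streamed text was already logged in onNext; nothing is aggregated here
        countDownLatch.countDown();
        return null;
    }
}
```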

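Maven dependencies (the classes above only use async-http-client, which brings in Netty and Reactive Streams transitively; the Apache httpclient artifact is not referenced by this code):

```xml
<dependency>
    <groupId>org.apache.httpcomponents</groupId>
    <artifactId>httpclient</artifactId>
    <version>4.5.12</version>
</dependency>

<dependency>
    <groupId>org.asynchttpclient</groupId>
    <artifactId>async-http-client</artifactId>
    <version>2.12.3</version>
</dependency>
```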


5. Additional Notes

5.1. Change url, userId and secret to match your environment.
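Instead of editing these constants in every class, they can be read once at startup from the environment. A minimal sketch follows; the class QwenClientConfig and the variable names QWEN_API_URL, QWEN_USER_ID and QWEN_SECRET are hypothetical, and the fallback URL is the one used in the examples above.

```java
import java.util.Map;
import java.util.Objects;

// Hypothetical holder for the endpoint and credentials used by the call examples
public final class QwenClientConfig {

    public final String url;
    public final String userId;
    public final String secret;

    private QwenClientConfig(String url, String userId, String secret) {
        this.url = url;
        this.userId = userId;
        this.secret = secret;
    }

    // Reads settings from an environment map; the variable names are hypothetical
    public static QwenClientConfig fromEnv(Map<String, String> env) {
        String url = env.getOrDefault("QWEN_API_URL", "http://127.0.0.1:9999/api/chat");
        // Credentials have no sensible default, so fail fast when they are missing
        String userId = Objects.requireNonNull(env.get("QWEN_USER_ID"), "QWEN_USER_ID is not set");
        String secret = Objects.requireNonNull(env.get("QWEN_SECRET"), "QWEN_SECRET is not set");
        return new QwenClientConfig(url, userId, secret);
    }

    // Convenience overload that reads the real process environment
    public static QwenClientConfig fromEnv() {
        return fromEnv(System.getenv());
    }
}
```

Taking a Map rather than calling System.getenv() directly keeps the lookup testable, and keeps the secret out of source control.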

5.2. If the AI service is already load balanced, point the URL at the address of the load balancer (SLB) rather than at an individual service node.


