From fb90e98a4e13c18871a3882fcd8fb82ca100da26 Mon Sep 17 00:00:00 2001 From: zhaoyuhang <1045078399@qq.com> Date: Thu, 29 Jun 2023 23:52:48 +0800 Subject: [PATCH 1/3] openAI --- .../common/common/constant/RedisKey.java | 5 + .../common/common/utils/DateUtils.java | 16 ++ .../src/main/resources/application.yml | 13 +- .../chat/service/impl/ChatServiceImpl.java | 2 + .../service/strategy/msg/TextMsgHandler.java | 2 +- .../custom/openai/enums/OpenAIModelEnums.java | 92 ++++++++++ .../custom/openai/event/OpenAIEvent.java | 14 ++ .../openai/event/listener/OpenAIListener.java | 63 +++++++ .../custom/openai/service/IOpenAIService.java | 11 ++ .../service/impl/OpenAIServiceImpl.java | 171 ++++++++++++++++++ .../custom/openai/utils/OpenAIUtils.java | 158 ++++++++++++++++ 11 files changed, 542 insertions(+), 5 deletions(-) create mode 100644 mallchat-common/src/main/java/com/abin/mallchat/common/common/utils/DateUtils.java create mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/enums/OpenAIModelEnums.java create mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/event/OpenAIEvent.java create mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/event/listener/OpenAIListener.java create mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/service/IOpenAIService.java create mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/service/impl/OpenAIServiceImpl.java create mode 100644 mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/utils/OpenAIUtils.java diff --git a/mallchat-common/src/main/java/com/abin/mallchat/common/common/constant/RedisKey.java b/mallchat-common/src/main/java/com/abin/mallchat/common/common/constant/RedisKey.java index faa6bd2f..a86d1c43 100644 --- a/mallchat-common/src/main/java/com/abin/mallchat/common/common/constant/RedisKey.java +++ b/mallchat-common/src/main/java/com/abin/mallchat/common/common/constant/RedisKey.java @@ -36,6 +36,11 @@ public class RedisKey { */ public static final String USER_SUMMARY_STRING = "userSummary:uid_%d"; + /** + * 用户AI聊天次数 + */ + public static final String USER_CHAT_NUM = "userAIChatNum:uid_%d"; + public static String getKey(String key, Object... objects) { return BASE_KEY + String.format(key, objects); } diff --git a/mallchat-common/src/main/java/com/abin/mallchat/common/common/utils/DateUtils.java b/mallchat-common/src/main/java/com/abin/mallchat/common/common/utils/DateUtils.java new file mode 100644 index 00000000..ecc2a599 --- /dev/null +++ b/mallchat-common/src/main/java/com/abin/mallchat/common/common/utils/DateUtils.java @@ -0,0 +1,16 @@ +package com.abin.mallchat.common.common.utils; + +import java.util.Calendar; +import java.util.Date; + +public class DateUtils extends org.apache.commons.lang3.time.DateUtils { + public static Long getEndTimeByToday() { + Calendar instance = Calendar.getInstance(); + Date now = new Date(); + instance.setTime(now); + instance.set(Calendar.HOUR_OF_DAY, 23); + instance.set(Calendar.MINUTE, 59); + instance.set(Calendar.SECOND, 59); + return instance.getTime().getTime() - now.getTime(); + } +} diff --git a/mallchat-common/src/main/resources/application.yml b/mallchat-common/src/main/resources/application.yml index 53b6b9b8..0293bb69 100644 --- a/mallchat-common/src/main/resources/application.yml +++ b/mallchat-common/src/main/resources/application.yml @@ -12,7 +12,7 @@ mybatis-plus: spring: profiles: #运行的环境 - active: my-prod + active: test application: name: mallchat datasource: @@ -37,8 +37,8 @@ spring: database: 0 # 连接超时时间 timeout: 1800000 - # 设置密码 - password: ${mallchat.redis.password} +# # 设置密码 +# password: ${mallchat.redis.password} lettuce: pool: # 最大阻塞等待时间,负数表示没有限制 @@ -62,4 +62,9 @@ wx: - appId: ${mallchat.wx.appId} # 第一个公众号的appid secret: ${mallchat.wx.secret} # 公众号的appsecret token: ${mallchat.wx.token} # 接口配置里的Token值 - aesKey: ${mallchat.wx.aesKey} # 接口配置里的EncodingAESKey值 \ No newline at end of file + aesKey: ${mallchat.wx.aesKey} # 接口配置里的EncodingAESKey值 +openai: + use-openai: true + ai-user-id: xxxxx + key: xxxxxxx + proxy-url: https://xxxxxxx \ No newline at end of file diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chat/service/impl/ChatServiceImpl.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chat/service/impl/ChatServiceImpl.java index adbb1d31..177dbfa6 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chat/service/impl/ChatServiceImpl.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chat/service/impl/ChatServiceImpl.java @@ -38,6 +38,7 @@ import com.abin.mallchat.custom.chat.service.strategy.msg.AbstractMsgHandler; import com.abin.mallchat.custom.chat.service.strategy.msg.MsgHandlerFactory; import com.abin.mallchat.custom.chat.service.strategy.msg.RecallMsgHandler; +import com.abin.mallchat.custom.openai.event.OpenAIEvent; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.annotation.Autowired; @@ -95,6 +96,7 @@ public Long sendMsg(ChatMessageReq request, Long uid) { msgHandler.saveMsg(insert, request); //发布消息发送事件 applicationEventPublisher.publishEvent(new MessageSendEvent(this, insert.getId())); + applicationEventPublisher.publishEvent(new OpenAIEvent(this, insert.getId())); return insert.getId(); } diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chat/service/strategy/msg/TextMsgHandler.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chat/service/strategy/msg/TextMsgHandler.java index 461ed883..5bf8b67a 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chat/service/strategy/msg/TextMsgHandler.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chat/service/strategy/msg/TextMsgHandler.java @@ -66,7 +66,7 @@ public void checkMsg(ChatMessageReq request, Long uid) { AssertUtil.equal(replyMsg.getRoomId(), request.getRoomId(), "只能回复相同会话内的消息"); } if (CollectionUtil.isNotEmpty(body.getAtUidList())) { - AssertUtil.isTrue(body.getAtUidList().size() > 10, "一次别艾特这么多人"); + AssertUtil.isFalse(body.getAtUidList().size() > 10, "一次别艾特这么多人"); List atUidList = body.getAtUidList(); Map batch = userInfoCache.getBatch(atUidList); AssertUtil.equal(atUidList.size(), batch.values().size(), "@用户不存在"); diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/enums/OpenAIModelEnums.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/enums/OpenAIModelEnums.java new file mode 100644 index 00000000..194c3df6 --- /dev/null +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/enums/OpenAIModelEnums.java @@ -0,0 +1,92 @@ +package com.abin.mallchat.custom.openai.enums; + +import lombok.AllArgsConstructor; +import lombok.Getter; + +import java.util.Arrays; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; + +@AllArgsConstructor +@Getter +public enum OpenAIModelEnums { + // chat + GPT_35_TURBO("gpt-3.5-turbo", 3, 40000), + GPT_35_TURBO_0301("gpt-3.5-turbo-0301", 3, 40000), + GPT_35_TURBO_0613("gpt-3.5-turbo-0613", 3, 40000), + GPT_35_TURBO_16K("gpt-3.5-turbo-16k", 3, 40000), + GPT_35_TURBO_16K_0613("gpt-3.5-turbo-16k-0613", 3, 40000), + // text + ADA("ada", 60, 150000), + ADA_CODE_SEARCH_CODE("ada-code-search-code", 60, 150000), + ADA_CODE_SEARCH_TEXT("ada-code-search-text", 60, 150000), + ADA_SEARCH_DOCUMENT("ada-search-document", 60, 150000), + ADA_SEARCH_QUERY("ada-search-query", 60, 150000), + ADA_SIMILARITY("ada-similarity", 60, 150000), + BABBAGE("babbage", 60, 150000), + BABBAGE_CODE_SEARCH_CODE("babbage-code-search-code", 60, 150000), + BABBAGE_CODE_SEARCH_TEXT("babbage-code-search-text", 60, 150000), + BABBAGE_SEARCH_DOCUMENT("babbage-search-document", 60, 150000), + BABBAGE_SEARCH_QUERY("babbage-search-query", 60, 150000), + BABBAGE_SIMILARITY("babbage-similarity", 60, 150000), + CODE_DAVINCI_EDIT_001("code-davinci-edit-001", 20, 150000), + CODE_SEARCH_ADA_CODE_001("code-search-ada-code-001", 60, 150000), + CODE_SEARCH_ADA_TEXT_001("code-search-ada-text-001", 60, 150000), + CODE_SEARCH_BABBAGE_CODE_001("code-search-babbage-code-001", 60, 150000), + CODE_SEARCH_BABBAGE_TEXT_001("code-search-babbage-text-001", 60, 150000), + CURIE("curie", 60, 150000), + CURIE_INSTRUCT_BETA("curie-instruct-beta", 60, 150000), + CURIE_SEARCH_DOCUMENT("curie-search-document", 60, 150000), + CURIE_SEARCH_QUERY("curie-search-query", 60, 150000), + CURIE_SIMILARITY("curie-similarity", 60, 150000), + DAVINCI("davinci", 60, 150000), + DAVINCI_INSTRUCT_BETA("davinci-instruct-beta", 60, 150000), + DAVINCI_SEARCH_DOCUMENT("davinci-search-document", 60, 150000), + DAVINCI_SEARCH_QUERY("davinci-search-query", 60, 150000), + DAVINCI_SIMILARITY("davinci-similarity", 60, 150000), + TEXT_ADA_001("text-ada-001", 60, 150000), + TEXT_BABBAGE_001("text-babbage-001", 60, 150000), + TEXT_CURIE_001("text-curie-001", 60, 150000), + TEXT_DAVINCI_001("text-davinci-001", 60, 150000), + TEXT_DAVINCI_002("text-davinci-002", 60, 150000), + TEXT_DAVINCI_003("text-davinci-003", 60, 150000), + TEXT_DAVINCI_EDIT_001("text-davinci-edit-001", 20, 150000), + TEXT_EMBEDDING_ADA_002("text-embedding-ada-002", 60, 150000), + TEXT_SEARCH_ADA_DOC_001("text-search-ada-doc-001", 60, 150000), + TEXT_SEARCH_ADA_QUERY_001("text-search-ada-query-001", 60, 150000), + TEXT_SEARCH_BABBAGE_DOC_001("text-search-babbage-doc-001", 60, 150000), + TEXT_SEARCH_BABBAGE_QUERY_001("text-search-babbage-query-001", 60, 150000), + TEXT_SEARCH_CURIE_DOC_001("text-search-curie-doc-001", 60, 150000), + TEXT_SEARCH_CURIE_QUERY_001("text-search-curie-query-001", 60, 150000), + TEXT_SEARCH_DAVINCI_DOC_001("text-search-davinci-doc-001", 60, 150000), + TEXT_SEARCH_DAVINCI_QUERY_001("text-search-davinci-query-001", 60, 150000), + TEXT_SIMILARITY_ADA_001("text-similarity-ada-001", 60, 150000), + TEXT_SIMILARITY_BABBAGE_001("text-similarity-babbage-001", 60, 150000), + TEXT_SIMILARITY_CURIE_001("text-similarity-curie-001", 60, 150000), + TEXT_SIMILARITY_DAVINCI_001("text-similarity-davinci-001", 60, 150000); + + /** + * 名字 + */ + private final String name; + /** + * 每分钟请求数 + */ + private final Integer RPM; + /** + * 每分钟令牌数 + */ + private final Integer TPM; + + private static final Map cache; + + static { + cache = Arrays.stream(OpenAIModelEnums.values()).collect(Collectors.toMap(OpenAIModelEnums::getName, Function.identity())); + } + + public static OpenAIModelEnums of(String name) { + return cache.get(name); + } + +} diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/event/OpenAIEvent.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/event/OpenAIEvent.java new file mode 100644 index 00000000..493d54e3 --- /dev/null +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/event/OpenAIEvent.java @@ -0,0 +1,14 @@ +package com.abin.mallchat.custom.openai.event; + +import lombok.Getter; +import org.springframework.context.ApplicationEvent; + +@Getter +public class OpenAIEvent extends ApplicationEvent { + private Long msgId; + + public OpenAIEvent(Object source, Long msgId) { + super(source); + this.msgId = msgId; + } +} \ No newline at end of file diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/event/listener/OpenAIListener.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/event/listener/OpenAIListener.java new file mode 100644 index 00000000..160a1f15 --- /dev/null +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/event/listener/OpenAIListener.java @@ -0,0 +1,63 @@ +package com.abin.mallchat.custom.openai.event.listener; + +import com.abin.mallchat.common.chat.dao.MessageDao; +import com.abin.mallchat.common.chat.domain.entity.Message; +import com.abin.mallchat.custom.openai.event.OpenAIEvent; +import com.abin.mallchat.custom.openai.service.IOpenAIService; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.jetbrains.annotations.NotNull; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; +import org.springframework.transaction.event.TransactionalEventListener; + +import static com.abin.mallchat.custom.openai.service.impl.OpenAIServiceImpl.MALL_CHAT_AI_NAME; + +/** + * 是否AI回复监听器 + * + * @author zhaoyuhang + * @date 2023/06/29 + */ +@Slf4j +@Component +public class OpenAIListener { + @Autowired + private IOpenAIService openAIService; + @Autowired + private MessageDao messageDao; + + @TransactionalEventListener(classes = OpenAIEvent.class, fallbackExecution = true) + public void notifyAllOnline(@NotNull OpenAIEvent event) { + Message message = messageDao.getById(event.getMsgId()); + if (ATedAI(message)) { + openAIService.chat(message); + } + } + + /** + * @return boolean + * @了AI + */ + private boolean ATedAI(Message message) { + /* 前端传@信息后取消注释 */ + +// MessageExtra extra = message.getExtra(); +// if (extra == null) { +// return false; +// } +// if (CollectionUtils.isEmpty(extra.getAtUidList())) { +// return false; +// } +// if (!extra.getAtUidList().contains(OpenAIServiceImpl.AI_USER_ID)) { +// return false; +// } + + if (StringUtils.isBlank(message.getContent())) { + return false; + } + return StringUtils.contains(message.getContent(), "@" + MALL_CHAT_AI_NAME) + && StringUtils.isNotBlank(message.getContent().replace(MALL_CHAT_AI_NAME, "").trim()); + } + +} diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/service/IOpenAIService.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/service/IOpenAIService.java new file mode 100644 index 00000000..0216a0bb --- /dev/null +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/service/IOpenAIService.java @@ -0,0 +1,11 @@ +package com.abin.mallchat.custom.openai.service; + +import com.abin.mallchat.common.chat.domain.entity.Message; +import com.abin.mallchat.custom.chat.domain.vo.request.ChatMessageReq; + +public interface IOpenAIService { + + + void chat(ChatMessageReq chatMessageReq, Long uid); + void chat(Message message); +} diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/service/impl/OpenAIServiceImpl.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/service/impl/OpenAIServiceImpl.java new file mode 100644 index 00000000..49249635 --- /dev/null +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/service/impl/OpenAIServiceImpl.java @@ -0,0 +1,171 @@ +package com.abin.mallchat.custom.openai.service.impl; + +import cn.hutool.core.bean.BeanUtil; +import cn.hutool.core.thread.NamedThreadFactory; +import cn.hutool.http.HttpResponse; +import cn.hutool.http.HttpUtil; +import com.abin.mallchat.common.chat.domain.entity.Message; +import com.abin.mallchat.common.chat.domain.enums.MessageTypeEnum; +import com.abin.mallchat.common.common.constant.RedisKey; +import com.abin.mallchat.common.common.exception.BusinessException; +import com.abin.mallchat.common.common.handler.GlobalUncaughtExceptionHandler; +import com.abin.mallchat.common.common.utils.DateUtils; +import com.abin.mallchat.common.common.utils.RedisUtils; +import com.abin.mallchat.custom.chat.domain.vo.request.ChatMessageReq; +import com.abin.mallchat.custom.chat.domain.vo.request.msg.TextMsgReq; +import com.abin.mallchat.custom.chat.service.ChatService; +import com.abin.mallchat.custom.openai.enums.OpenAIModelEnums; +import com.abin.mallchat.custom.openai.service.IOpenAIService; +import com.abin.mallchat.custom.openai.utils.OpenAIUtils; +import com.abin.mallchat.custom.user.domain.vo.response.user.UserInfoResp; +import com.abin.mallchat.custom.user.service.UserService; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.StringUtils; +import org.springframework.beans.factory.DisposableBean; +import org.springframework.beans.factory.InitializingBean; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.annotation.Description; +import org.springframework.context.annotation.Lazy; +import org.springframework.stereotype.Service; + +import java.util.Collections; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +@Slf4j +@Service +public class OpenAIServiceImpl implements IOpenAIService, DisposableBean, InitializingBean { + private static ExecutorService EXECUTOR; + + @Value("${openai.use-openai:false}") + private boolean USE_OPENAI; + @Value("${openai.ai-user-id}") + public Long AI_USER_ID; + + @Value("${openai.model.name:text-davinci-003}") + private String modelName; + @Value("${openai.key}") + private String key; + @Value("${openai.proxy-url:}") + private String proxyUrl; + + @Value("${openai.limit:5}") + private Integer limit; + + @Autowired + private UserService userService; + @Lazy + @Autowired + private ChatService chatService; + + public static String MALL_CHAT_AI_NAME; + + /** + * 聊天 + * + * @param chatMessageReq 提示词 + * @param uid 用户id + */ + @Deprecated + @Override + public void chat(ChatMessageReq chatMessageReq, Long uid) { + TextMsgReq body = BeanUtil.toBean(chatMessageReq.getBody(), TextMsgReq.class); + String content = body.getContent().replace(MALL_CHAT_AI_NAME, "").trim(); + EXECUTOR.execute(() -> { + Long chatNum; + if ((chatNum = userChatNumInrc(uid)) > limit) { + answerMsg("你今天已经和我聊了" + chatNum + "次了,我累了,明天再聊吧", chatMessageReq.getRoomId(), uid); + } else { + chat(content, chatMessageReq.getRoomId(), uid); + } + }); + + } + + @Override + public void chat(Message message) { + String content = message.getContent().replace(MALL_CHAT_AI_NAME, "").trim(); + Long roomId = message.getRoomId(); + Long uid = message.getFromUid(); + EXECUTOR.execute(() -> { + Long chatNum; + if ((chatNum = userChatNumInrc(uid)) > limit) { + answerMsg("你今天已经和我聊了" + chatNum + "次了,我累了,明天再聊吧", roomId, uid); + } else { + chat(content, roomId, uid); + } + }); + + } + + private Long userChatNumInrc(Long uid) { + //todo:白名单 + return RedisUtils.inc(RedisKey.getKey(RedisKey.USER_CHAT_NUM, uid), DateUtils.getEndTimeByToday().intValue(), TimeUnit.MILLISECONDS); + } + + private void chat(String content, Long roomId, Long uid) { + HttpResponse response = OpenAIUtils.create(key) + .proxyUrl(proxyUrl) + .model(modelName) + .prompt(content) + .send(); + String text = OpenAIUtils.parseText(response); + answerMsg(text, roomId, uid); + } + + private void answerMsg(String text, Long roomId, Long uid) { + ChatMessageReq answerReq = new ChatMessageReq(); + answerReq.setRoomId(roomId); + answerReq.setMsgType(MessageTypeEnum.TEXT.getType()); + UserInfoResp userInfo = userService.getUserInfo(uid); + TextMsgReq textMsgReq = new TextMsgReq(); + textMsgReq.setContent("@" + userInfo.getName() + " " + text); + textMsgReq.setAtUidList(Collections.singletonList(uid)); + answerReq.setBody(textMsgReq); + chatService.sendMsg(answerReq, AI_USER_ID); + } + + + @Override + public void afterPropertiesSet() { + if (!USE_OPENAI) { + return; + } + if (StringUtils.isNotBlank(proxyUrl) && !HttpUtil.isHttp(proxyUrl) && !HttpUtil.isHttps(proxyUrl)) { + throw new BusinessException("openai.proxy-url 配置错误"); + } + OpenAIModelEnums modelEnum = OpenAIModelEnums.of(modelName); + if (modelEnum == null) { + throw new BusinessException("openai.model.name 配置错误"); + } + Integer rpm = modelEnum.getRPM(); + EXECUTOR = new ThreadPoolExecutor(10, 10, + 0L, TimeUnit.MILLISECONDS, + new LinkedBlockingQueue<>(rpm), + new NamedThreadFactory("openAI-chat-gpt", + null, + false, + new GlobalUncaughtExceptionHandler()), + (r, executor) -> { + throw new BusinessException("别问的太快了,我的脑子不够用了"); + }); + UserInfoResp userInfo = userService.getUserInfo(AI_USER_ID); + if (userInfo == null) { + throw new BusinessException("openai.ai-user-id 配置错误"); + } + MALL_CHAT_AI_NAME = userInfo.getName(); + } + + @Override + public void destroy() throws Exception { + EXECUTOR.shutdown(); + if (!EXECUTOR.awaitTermination(30, TimeUnit.SECONDS)) { //最多等30秒,处理不完就拉倒 + if (log.isErrorEnabled()) { + log.error("Timed out while waiting for executor [{}] to terminate", EXECUTOR); + } + } + } +} \ No newline at end of file diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/utils/OpenAIUtils.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/utils/OpenAIUtils.java new file mode 100644 index 00000000..dee05703 --- /dev/null +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/utils/OpenAIUtils.java @@ -0,0 +1,158 @@ +package com.abin.mallchat.custom.openai.utils; + +import cn.hutool.http.HttpResponse; +import cn.hutool.http.HttpUtil; +import cn.hutool.json.JSONArray; +import cn.hutool.json.JSONObject; +import com.abin.mallchat.common.common.exception.BusinessException; +import org.apache.commons.lang3.StringUtils; + +import java.util.HashMap; +import java.util.Map; + +public class OpenAIUtils { + + private static final String URL = "https://api.openai.com/v1/completions"; + + private String model = "text-davinci-003"; + + private final Map headers; + /** + * 超时30秒 + */ + private Integer timeout = 30 * 1000; + /** + * 参数用于指定生成文本的最大长度。 + * 它表示生成的文本中最多包含多少个 token。一个 token 可以是一个单词、一个标点符号或一个空格。 + */ + private int maxTokens = 2048; + /** + * 用于控制生成文本的多样性。 + * 较高的温度会导致更多的随机性和多样性,但可能会降低生成文本的质量。默认值为 1,建议在 0.7 到 1.3 之间调整。 + */ + private Object temperature = 1; + /** + * 用于控制生成文本的多样性。 + * 它会根据概率选择最高的几个单词,而不是选择概率最高的单词。默认值为 1,建议在 0.7 到 0.9 之间调整。 + */ + private Object topP = 0.9; + /** + * 用于控制生成文本中重复单词的数量。 + * 较高的惩罚值会导致更少的重复单词,但可能会降低生成文本的流畅性。默认值为 0,建议在 0 到 2 之间调整。 + */ + private Object frequencyPenalty = 0.0; + /** + * 用于控制生成文本中出现特定单词的数量。 + * 较高的惩罚值会导致更少的特定单词,但可能会降低生成文本的流畅性。默认值为 0,建议在 0 到 2 之间调整。 + */ + private Object presencePenalty = 0.6; + + /** + * 提示词 + */ + private String prompt; + + private String proxyUrl; + + public OpenAIUtils(String key) { + HashMap _headers_ = new HashMap<>(); + _headers_.put("Content-Type", "application/json"); + if (StringUtils.isBlank(key)) { + throw new BusinessException("openAi key is blank"); + } + _headers_.put("Authorization", "Bearer " + key); + this.headers = _headers_; + } + + public static OpenAIUtils create(String key) { + return new OpenAIUtils(key); + } + + public static String parseText(HttpResponse response) { + return parseText(response.body()); + } + + public static String parseText(String body) { + JSONObject jsonObj = new JSONObject(body); + JSONArray choicesArr = jsonObj.getJSONArray("choices"); + JSONObject choiceObj = choicesArr.getJSONObject(0); + return choiceObj.getStr("text"); + } + + public OpenAIUtils model(String model) { + this.model = model; + return this; + } + + public OpenAIUtils timeout(int timeout) { + this.timeout = timeout; + return this; + } + + public OpenAIUtils maxTokens(int maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + public OpenAIUtils temperature(int temperature) { + this.temperature = temperature; + return this; + } + + public OpenAIUtils topP(int topP) { + this.topP = topP; + return this; + } + + public OpenAIUtils frequencyPenalty(int frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + return this; + } + + public OpenAIUtils presencePenalty(int presencePenalty) { + this.presencePenalty = presencePenalty; + return this; + } + + public OpenAIUtils prompt(String prompt) { + this.prompt = prompt; + return this; + } + + public OpenAIUtils proxyUrl(String proxyUrl) { + this.proxyUrl = proxyUrl; + return this; + } + + public HttpResponse send() { + JSONObject param = new JSONObject(); + param.set("model", model); + param.set("prompt", prompt); + param.set("max_tokens", maxTokens); + param.set("temperature", temperature); + param.set("top_p", topP); + param.set("frequency_penalty", frequencyPenalty); + param.set("presence_penalty", presencePenalty); + return HttpUtil.createPost(StringUtils.isNotBlank(proxyUrl) ? proxyUrl : URL) + .addHeaders(headers) + .body(param.toString()) + .timeout(timeout) + .execute(); + } + + public static void main(String[] args) { + HttpResponse send = OpenAIUtils.create("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") + .timeout(30 * 1000) + .prompt("Spring的启动流程是什么") + .send(); + System.out.println("send = " + send); + // JSON 数据 + // JSON 数据 + JSONObject jsonObj = new JSONObject(send.body()); + JSONArray choicesArr = jsonObj.getJSONArray("choices"); + JSONObject choiceObj = choicesArr.getJSONObject(0); + String text = choiceObj.getStr("text"); + System.out.println("text = " + text); + + } +} \ No newline at end of file From c06934bc896aadfdce5d32d2e9386092001d7772 Mon Sep 17 00:00:00 2001 From: zhaoyuhang <1045078399@qq.com> Date: Fri, 30 Jun 2023 00:03:21 +0800 Subject: [PATCH 2/3] openAI --- .../abin/mallchat/custom/openai/utils/OpenAIUtils.java | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/utils/OpenAIUtils.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/utils/OpenAIUtils.java index dee05703..f01c8a83 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/utils/OpenAIUtils.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/openai/utils/OpenAIUtils.java @@ -5,11 +5,13 @@ import cn.hutool.json.JSONArray; import cn.hutool.json.JSONObject; import com.abin.mallchat.common.common.exception.BusinessException; +import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import java.util.HashMap; import java.util.Map; +@Slf4j public class OpenAIUtils { private static final String URL = "https://api.openai.com/v1/completions"; @@ -73,7 +75,13 @@ public static String parseText(HttpResponse response) { } public static String parseText(String body) { + log.info("body >>> " + body); JSONObject jsonObj = new JSONObject(body); + JSONObject error = jsonObj.getJSONObject("error"); + if (error != null) { + log.error("error >>> " + error); + return "闹脾气了,等会再试试吧~"; + } JSONArray choicesArr = jsonObj.getJSONArray("choices"); JSONObject choiceObj = choicesArr.getJSONObject(0); return choiceObj.getStr("text"); From 169db03d253509e8d84e408e4433c9aa7e8e6aa5 Mon Sep 17 00:00:00 2001 From: zhaoyuhang <1045078399@qq.com> Date: Sat, 1 Jul 2023 13:48:59 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E5=9B=9E=E6=BB=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../custom/chat/service/strategy/msg/TextMsgHandler.java | 1 + 1 file changed, 1 insertion(+) diff --git a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chat/service/strategy/msg/TextMsgHandler.java b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chat/service/strategy/msg/TextMsgHandler.java index 461ed883..a925f533 100644 --- a/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chat/service/strategy/msg/TextMsgHandler.java +++ b/mallchat-custom-server/src/main/java/com/abin/mallchat/custom/chat/service/strategy/msg/TextMsgHandler.java @@ -86,6 +86,7 @@ public void saveMsg(Message msg, ChatMessageReq request) {//插入文本内容 update.setContent(SensitiveWordUtils.filter(body.getContent())); update.setExtra(extra); //如果有回复消息 + if (Objects.nonNull(body.getReplyMsgId())) { Integer gapCount = messageDao.getGapCount(request.getRoomId(), body.getReplyMsgId(), msg.getId()); update.setGapCount(gapCount);