From fb5e9e5aed796129714bd4ceeef581d2bc9bd1f3 Mon Sep 17 00:00:00 2001 From: JiangYinjin Date: Thu, 26 Dec 2024 03:33:24 +0800 Subject: [PATCH] fix: allow isVisionModel function read runtime env var VISION_MODELS --- app/api/config/route.ts | 1 + app/client/platforms/anthropic.ts | 7 +++++-- app/client/platforms/google.ts | 2 +- app/client/platforms/openai.ts | 7 ++++++- app/client/platforms/tencent.ts | 6 +++++- app/components/chat.tsx | 11 +++++++---- app/config/server.ts | 7 ++++++- app/store/access.ts | 1 + app/utils.ts | 14 +++++++++----- test/vision-model-checker.test.ts | 23 +++++++++++++---------- 10 files changed, 54 insertions(+), 25 deletions(-) diff --git a/app/api/config/route.ts b/app/api/config/route.ts index b0d9da03103..445dae30c1f 100644 --- a/app/api/config/route.ts +++ b/app/api/config/route.ts @@ -13,6 +13,7 @@ const DANGER_CONFIG = { hideBalanceQuery: serverConfig.hideBalanceQuery, disableFastLink: serverConfig.disableFastLink, customModels: serverConfig.customModels, + visionModels: serverConfig.visionModels, defaultModel: serverConfig.defaultModel, }; diff --git a/app/client/platforms/anthropic.ts b/app/client/platforms/anthropic.ts index 6747221a861..6e619a4d9c4 100644 --- a/app/client/platforms/anthropic.ts +++ b/app/client/platforms/anthropic.ts @@ -84,10 +84,13 @@ export class ClaudeApi implements LLMApi { return res?.content?.[0]?.text; } async chat(options: ChatOptions): Promise { - const visionModel = isVisionModel(options.config.model); - const accessStore = useAccessStore.getState(); + const visionModel = isVisionModel( + options.config.model, + accessStore.visionModels, + ); + const shouldStream = !!options.config.stream; const modelConfig = { diff --git a/app/client/platforms/google.ts b/app/client/platforms/google.ts index a7bce4fc2d0..c10e4969d44 100644 --- a/app/client/platforms/google.ts +++ b/app/client/platforms/google.ts @@ -83,7 +83,7 @@ export class GeminiProApi implements LLMApi { } const messages = _messages.map((v) => { let parts: any[] = [{ text: getMessageTextContent(v) }]; - if (isVisionModel(options.config.model)) { + if (isVisionModel(options.config.model, accessStore.visionModels)) { const images = getMessageImages(v); if (images.length > 0) { multimodal = true; diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts index 15cfb7ca602..6d154251eb8 100644 --- a/app/client/platforms/openai.ts +++ b/app/client/platforms/openai.ts @@ -194,6 +194,8 @@ export class ChatGPTApi implements LLMApi { let requestPayload: RequestPayload | DalleRequestPayload; + const accessStore = useAccessStore.getState(); + const isDalle3 = _isDalle3(options.config.model); const isO1 = options.config.model.startsWith("o1"); if (isDalle3) { @@ -211,7 +213,10 @@ export class ChatGPTApi implements LLMApi { style: options.config?.style ?? "vivid", }; } else { - const visionModel = isVisionModel(options.config.model); + const visionModel = isVisionModel( + options.config.model, + accessStore.visionModels, + ); const messages: ChatOptions["messages"] = []; for (const v of options.messages) { const content = visionModel diff --git a/app/client/platforms/tencent.ts b/app/client/platforms/tencent.ts index 580844a5b31..5a1f39b392a 100644 --- a/app/client/platforms/tencent.ts +++ b/app/client/platforms/tencent.ts @@ -94,7 +94,11 @@ export class HunyuanApi implements LLMApi { } async chat(options: ChatOptions) { - const visionModel = isVisionModel(options.config.model); + const accessStore = useAccessStore.getState(); + const visionModel = isVisionModel( + options.config.model, + accessStore.visionModels, + ); const messages = options.messages.map((v, index) => ({ // "Messages 中 system 角色必须位于列表的最开始" role: index !== 0 && v.role === "system" ? "user" : v.role, diff --git a/app/components/chat.tsx b/app/components/chat.tsx index 51fe74fe7be..7bb3b9586b9 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -490,6 +490,7 @@ export function ChatActions(props: { const currentProviderName = session.mask.modelConfig?.providerName || ServiceProvider.OpenAI; const allModels = useAllModels(); + const customVisionModels = useAccessStore().visionModels; const models = useMemo(() => { const filteredModels = allModels.filter((m) => m.available); const defaultModel = filteredModels.find((m) => m.isDefault); @@ -529,7 +530,7 @@ export function ChatActions(props: { const isMobileScreen = useMobileScreen(); useEffect(() => { - const show = isVisionModel(currentModel); + const show = isVisionModel(currentModel, customVisionModels); setShowUploadImage(show); if (!show) { props.setAttachImages([]); @@ -1457,10 +1458,12 @@ function _Chat() { // eslint-disable-next-line react-hooks/exhaustive-deps }, []); + const customVisionModels = useAccessStore().visionModels; + const handlePaste = useCallback( async (event: React.ClipboardEvent) => { const currentModel = chatStore.currentSession().mask.modelConfig.model; - if (!isVisionModel(currentModel)) { + if (!isVisionModel(currentModel, customVisionModels)) { return; } const items = (event.clipboardData || window.clipboardData).items; @@ -1497,7 +1500,7 @@ function _Chat() { } } }, - [attachImages, chatStore], + [attachImages, chatStore, customVisionModels], ); async function uploadImage() { @@ -1545,7 +1548,7 @@ function _Chat() { setAttachImages(images); } - // 快捷键 shortcut keys + // 捷键 shortcut keys const [showShortcutKeyModal, setShowShortcutKeyModal] = useState(false); useEffect(() => { diff --git a/app/config/server.ts b/app/config/server.ts index 9d6b3c2b8da..7f93822d5b5 100644 --- a/app/config/server.ts +++ b/app/config/server.ts @@ -21,6 +21,7 @@ declare global { ENABLE_BALANCE_QUERY?: string; // allow user to query balance or not DISABLE_FAST_LINK?: string; // disallow parse settings from url or not CUSTOM_MODELS?: string; // to control custom models + VISION_MODELS?: string; // to control vision models DEFAULT_MODEL?: string; // to control default model in every new chat window // stability only @@ -123,13 +124,16 @@ export const getServerSideConfig = () => { const disableGPT4 = !!process.env.DISABLE_GPT4; let customModels = process.env.CUSTOM_MODELS ?? ""; + let visionModels = process.env.VISION_MODELS ?? ""; let defaultModel = process.env.DEFAULT_MODEL ?? ""; if (disableGPT4) { if (customModels) customModels += ","; customModels += DEFAULT_MODELS.filter( (m) => - (m.name.startsWith("gpt-4") || m.name.startsWith("chatgpt-4o") || m.name.startsWith("o1")) && + (m.name.startsWith("gpt-4") || + m.name.startsWith("chatgpt-4o") || + m.name.startsWith("o1")) && !m.name.startsWith("gpt-4o-mini"), ) .map((m) => "-" + m.name) @@ -247,6 +251,7 @@ export const getServerSideConfig = () => { hideBalanceQuery: !process.env.ENABLE_BALANCE_QUERY, disableFastLink: !!process.env.DISABLE_FAST_LINK, customModels, + visionModels, defaultModel, allowedWebDavEndpoints, }; diff --git a/app/store/access.ts b/app/store/access.ts index 4796b2fe84e..82cea5236ec 100644 --- a/app/store/access.ts +++ b/app/store/access.ts @@ -123,6 +123,7 @@ const DEFAULT_ACCESS_STATE = { disableGPT4: false, disableFastLink: false, customModels: "", + visionModels: "", defaultModel: "", // tts config diff --git a/app/utils.ts b/app/utils.ts index 962e68a101c..30df0a5999c 100644 --- a/app/utils.ts +++ b/app/utils.ts @@ -7,6 +7,7 @@ import { ServiceProvider } from "./constant"; import { fetch as tauriStreamFetch } from "./utils/stream"; import { VISION_MODEL_REGEXES, EXCLUDE_VISION_MODEL_REGEXES } from "./constant"; import { getClientConfig } from "./config/client"; +import { getModelProvider } from "./utils/model"; export function trimTopic(topic: string) { // Fix an issue where double quotes still show in the Indonesian language @@ -253,12 +254,15 @@ export function getMessageImages(message: RequestMessage): string[] { return urls; } -export function isVisionModel(model: string) { +export function isVisionModel(model: string, customVisionModels: string) { const clientConfig = getClientConfig(); - const envVisionModels = clientConfig?.visionModels - ?.split(",") - .map((m) => m.trim()); - if (envVisionModels?.includes(model)) { + const allVisionModelsList = [customVisionModels, clientConfig?.visionModels] + ?.join(",") + .split(",") + .map((m) => m.trim()) + .filter(Boolean) + .map((m) => getModelProvider(m)[0]); + if (allVisionModelsList?.includes(model)) { return true; } return ( diff --git a/test/vision-model-checker.test.ts b/test/vision-model-checker.test.ts index 734e992d829..5e5ffe56700 100644 --- a/test/vision-model-checker.test.ts +++ b/test/vision-model-checker.test.ts @@ -2,6 +2,7 @@ import { isVisionModel } from "../app/utils"; describe("isVisionModel", () => { const originalEnv = process.env; + const customVisionModels = "custom-vlm,another-vlm"; beforeEach(() => { jest.resetModules(); @@ -27,12 +28,12 @@ describe("isVisionModel", () => { ]; visionModels.forEach((model) => { - expect(isVisionModel(model)).toBe(true); + expect(isVisionModel(model, customVisionModels)).toBe(true); }); }); test("should exclude specific models", () => { - expect(isVisionModel("claude-3-5-haiku-20241022")).toBe(false); + expect(isVisionModel("claude-3-5-haiku-20241022", customVisionModels)).toBe(false); }); test("should not identify non-vision models", () => { @@ -44,24 +45,26 @@ describe("isVisionModel", () => { ]; nonVisionModels.forEach((model) => { - expect(isVisionModel(model)).toBe(false); + expect(isVisionModel(model, customVisionModels)).toBe(false); }); }); test("should identify models from VISION_MODELS env var", () => { process.env.VISION_MODELS = "custom-vision-model,another-vision-model"; - - expect(isVisionModel("custom-vision-model")).toBe(true); - expect(isVisionModel("another-vision-model")).toBe(true); - expect(isVisionModel("unrelated-model")).toBe(false); + + expect(isVisionModel("custom-vision-model", customVisionModels)).toBe(true); + expect(isVisionModel("another-vision-model", customVisionModels)).toBe(true); + expect(isVisionModel("custom-vlm", customVisionModels)).toBe(true); + expect(isVisionModel("another-vlm", customVisionModels)).toBe(true); + expect(isVisionModel("unrelated-model", customVisionModels)).toBe(false); }); test("should handle empty or missing VISION_MODELS", () => { process.env.VISION_MODELS = ""; - expect(isVisionModel("unrelated-model")).toBe(false); + expect(isVisionModel("unrelated-model", customVisionModels)).toBe(false); delete process.env.VISION_MODELS; - expect(isVisionModel("unrelated-model")).toBe(false); - expect(isVisionModel("gpt-4-vision")).toBe(true); + expect(isVisionModel("unrelated-model", customVisionModels)).toBe(false); + expect(isVisionModel("gpt-4-vision", customVisionModels)).toBe(true); }); }); \ No newline at end of file