Skip to content

Commit

Permalink
fix: allow isVisionModel function read runtime env var VISION_MODELS
Browse files Browse the repository at this point in the history
  • Loading branch information
JiangYingjin committed Dec 25, 2024
1 parent 0c3d446 commit fb5e9e5
Show file tree
Hide file tree
Showing 10 changed files with 54 additions and 25 deletions.
1 change: 1 addition & 0 deletions app/api/config/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ const DANGER_CONFIG = {
hideBalanceQuery: serverConfig.hideBalanceQuery,
disableFastLink: serverConfig.disableFastLink,
customModels: serverConfig.customModels,
visionModels: serverConfig.visionModels,
defaultModel: serverConfig.defaultModel,
};

Expand Down
7 changes: 5 additions & 2 deletions app/client/platforms/anthropic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,13 @@ export class ClaudeApi implements LLMApi {
return res?.content?.[0]?.text;
}
async chat(options: ChatOptions): Promise<void> {
const visionModel = isVisionModel(options.config.model);

const accessStore = useAccessStore.getState();

const visionModel = isVisionModel(
options.config.model,
accessStore.visionModels,
);

const shouldStream = !!options.config.stream;

const modelConfig = {
Expand Down
2 changes: 1 addition & 1 deletion app/client/platforms/google.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ export class GeminiProApi implements LLMApi {
}
const messages = _messages.map((v) => {
let parts: any[] = [{ text: getMessageTextContent(v) }];
if (isVisionModel(options.config.model)) {
if (isVisionModel(options.config.model, accessStore.visionModels)) {
const images = getMessageImages(v);
if (images.length > 0) {
multimodal = true;
Expand Down
7 changes: 6 additions & 1 deletion app/client/platforms/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ export class ChatGPTApi implements LLMApi {

let requestPayload: RequestPayload | DalleRequestPayload;

const accessStore = useAccessStore.getState();

const isDalle3 = _isDalle3(options.config.model);
const isO1 = options.config.model.startsWith("o1");
if (isDalle3) {
Expand All @@ -211,7 +213,10 @@ export class ChatGPTApi implements LLMApi {
style: options.config?.style ?? "vivid",
};
} else {
const visionModel = isVisionModel(options.config.model);
const visionModel = isVisionModel(
options.config.model,
accessStore.visionModels,
);
const messages: ChatOptions["messages"] = [];
for (const v of options.messages) {
const content = visionModel
Expand Down
6 changes: 5 additions & 1 deletion app/client/platforms/tencent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,11 @@ export class HunyuanApi implements LLMApi {
}

async chat(options: ChatOptions) {
const visionModel = isVisionModel(options.config.model);
const accessStore = useAccessStore.getState();
const visionModel = isVisionModel(
options.config.model,
accessStore.visionModels,
);
const messages = options.messages.map((v, index) => ({
// "Messages 中 system 角色必须位于列表的最开始"
role: index !== 0 && v.role === "system" ? "user" : v.role,
Expand Down
11 changes: 7 additions & 4 deletions app/components/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,7 @@ export function ChatActions(props: {
const currentProviderName =
session.mask.modelConfig?.providerName || ServiceProvider.OpenAI;
const allModels = useAllModels();
const customVisionModels = useAccessStore().visionModels;
const models = useMemo(() => {
const filteredModels = allModels.filter((m) => m.available);
const defaultModel = filteredModels.find((m) => m.isDefault);
Expand Down Expand Up @@ -529,7 +530,7 @@ export function ChatActions(props: {
const isMobileScreen = useMobileScreen();

useEffect(() => {
const show = isVisionModel(currentModel);
const show = isVisionModel(currentModel, customVisionModels);
setShowUploadImage(show);
if (!show) {
props.setAttachImages([]);
Expand Down Expand Up @@ -1457,10 +1458,12 @@ function _Chat() {
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);

const customVisionModels = useAccessStore().visionModels;

const handlePaste = useCallback(
async (event: React.ClipboardEvent<HTMLTextAreaElement>) => {
const currentModel = chatStore.currentSession().mask.modelConfig.model;
if (!isVisionModel(currentModel)) {
if (!isVisionModel(currentModel, customVisionModels)) {
return;
}
const items = (event.clipboardData || window.clipboardData).items;
Expand Down Expand Up @@ -1497,7 +1500,7 @@ function _Chat() {
}
}
},
[attachImages, chatStore],
[attachImages, chatStore, customVisionModels],
);

async function uploadImage() {
Expand Down Expand Up @@ -1545,7 +1548,7 @@ function _Chat() {
setAttachImages(images);
}

// 快捷键 shortcut keys
// 捷键 shortcut keys
const [showShortcutKeyModal, setShowShortcutKeyModal] = useState(false);

useEffect(() => {
Expand Down
7 changes: 6 additions & 1 deletion app/config/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ declare global {
ENABLE_BALANCE_QUERY?: string; // allow user to query balance or not
DISABLE_FAST_LINK?: string; // disallow parse settings from url or not
CUSTOM_MODELS?: string; // to control custom models
VISION_MODELS?: string; // to control vision models
DEFAULT_MODEL?: string; // to control default model in every new chat window

// stability only
Expand Down Expand Up @@ -123,13 +124,16 @@ export const getServerSideConfig = () => {

const disableGPT4 = !!process.env.DISABLE_GPT4;
let customModels = process.env.CUSTOM_MODELS ?? "";
let visionModels = process.env.VISION_MODELS ?? "";
let defaultModel = process.env.DEFAULT_MODEL ?? "";

if (disableGPT4) {
if (customModels) customModels += ",";
customModels += DEFAULT_MODELS.filter(
(m) =>
(m.name.startsWith("gpt-4") || m.name.startsWith("chatgpt-4o") || m.name.startsWith("o1")) &&
(m.name.startsWith("gpt-4") ||
m.name.startsWith("chatgpt-4o") ||
m.name.startsWith("o1")) &&
!m.name.startsWith("gpt-4o-mini"),
)
.map((m) => "-" + m.name)
Expand Down Expand Up @@ -247,6 +251,7 @@ export const getServerSideConfig = () => {
hideBalanceQuery: !process.env.ENABLE_BALANCE_QUERY,
disableFastLink: !!process.env.DISABLE_FAST_LINK,
customModels,
visionModels,
defaultModel,
allowedWebDavEndpoints,
};
Expand Down
1 change: 1 addition & 0 deletions app/store/access.ts
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ const DEFAULT_ACCESS_STATE = {
disableGPT4: false,
disableFastLink: false,
customModels: "",
visionModels: "",
defaultModel: "",

// tts config
Expand Down
14 changes: 9 additions & 5 deletions app/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { ServiceProvider } from "./constant";
import { fetch as tauriStreamFetch } from "./utils/stream";
import { VISION_MODEL_REGEXES, EXCLUDE_VISION_MODEL_REGEXES } from "./constant";
import { getClientConfig } from "./config/client";
import { getModelProvider } from "./utils/model";

export function trimTopic(topic: string) {
// Fix an issue where double quotes still show in the Indonesian language
Expand Down Expand Up @@ -253,12 +254,15 @@ export function getMessageImages(message: RequestMessage): string[] {
return urls;
}

export function isVisionModel(model: string) {
export function isVisionModel(model: string, customVisionModels: string) {
const clientConfig = getClientConfig();
const envVisionModels = clientConfig?.visionModels
?.split(",")
.map((m) => m.trim());
if (envVisionModels?.includes(model)) {
const allVisionModelsList = [customVisionModels, clientConfig?.visionModels]
?.join(",")
.split(",")
.map((m) => m.trim())
.filter(Boolean)
.map((m) => getModelProvider(m)[0]);
if (allVisionModelsList?.includes(model)) {
return true;
}
return (
Expand Down
23 changes: 13 additions & 10 deletions test/vision-model-checker.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { isVisionModel } from "../app/utils";

describe("isVisionModel", () => {
const originalEnv = process.env;
const customVisionModels = "custom-vlm,another-vlm";

beforeEach(() => {
jest.resetModules();
Expand All @@ -27,12 +28,12 @@ describe("isVisionModel", () => {
];

visionModels.forEach((model) => {
expect(isVisionModel(model)).toBe(true);
expect(isVisionModel(model, customVisionModels)).toBe(true);
});
});

test("should exclude specific models", () => {
expect(isVisionModel("claude-3-5-haiku-20241022")).toBe(false);
expect(isVisionModel("claude-3-5-haiku-20241022", customVisionModels)).toBe(false);
});

test("should not identify non-vision models", () => {
Expand All @@ -44,24 +45,26 @@ describe("isVisionModel", () => {
];

nonVisionModels.forEach((model) => {
expect(isVisionModel(model)).toBe(false);
expect(isVisionModel(model, customVisionModels)).toBe(false);
});
});

test("should identify models from VISION_MODELS env var", () => {
process.env.VISION_MODELS = "custom-vision-model,another-vision-model";

expect(isVisionModel("custom-vision-model")).toBe(true);
expect(isVisionModel("another-vision-model")).toBe(true);
expect(isVisionModel("unrelated-model")).toBe(false);

expect(isVisionModel("custom-vision-model", customVisionModels)).toBe(true);
expect(isVisionModel("another-vision-model", customVisionModels)).toBe(true);
expect(isVisionModel("custom-vlm", customVisionModels)).toBe(true);
expect(isVisionModel("another-vlm", customVisionModels)).toBe(true);
expect(isVisionModel("unrelated-model", customVisionModels)).toBe(false);
});

test("should handle empty or missing VISION_MODELS", () => {
process.env.VISION_MODELS = "";
expect(isVisionModel("unrelated-model")).toBe(false);
expect(isVisionModel("unrelated-model", customVisionModels)).toBe(false);

delete process.env.VISION_MODELS;
expect(isVisionModel("unrelated-model")).toBe(false);
expect(isVisionModel("gpt-4-vision")).toBe(true);
expect(isVisionModel("unrelated-model", customVisionModels)).toBe(false);
expect(isVisionModel("gpt-4-vision", customVisionModels)).toBe(true);
});
});

0 comments on commit fb5e9e5

Please sign in to comment.