Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: allow isVisionModel function read runtime env var VISION_MODELS #5983

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions app/api/config/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ const DANGER_CONFIG = {
hideBalanceQuery: serverConfig.hideBalanceQuery,
disableFastLink: serverConfig.disableFastLink,
customModels: serverConfig.customModels,
visionModels: serverConfig.visionModels,
defaultModel: serverConfig.defaultModel,
};

Expand Down
7 changes: 5 additions & 2 deletions app/client/platforms/anthropic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,13 @@ export class ClaudeApi implements LLMApi {
return res?.content?.[0]?.text;
}
async chat(options: ChatOptions): Promise<void> {
const visionModel = isVisionModel(options.config.model);

const accessStore = useAccessStore.getState();

const visionModel = isVisionModel(
options.config.model,
accessStore.visionModels,
);

const shouldStream = !!options.config.stream;

const modelConfig = {
Expand Down
2 changes: 1 addition & 1 deletion app/client/platforms/google.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ export class GeminiProApi implements LLMApi {
}
const messages = _messages.map((v) => {
let parts: any[] = [{ text: getMessageTextContent(v) }];
if (isVisionModel(options.config.model)) {
if (isVisionModel(options.config.model, accessStore.visionModels)) {
const images = getMessageImages(v);
if (images.length > 0) {
multimodal = true;
Expand Down
7 changes: 6 additions & 1 deletion app/client/platforms/openai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ export class ChatGPTApi implements LLMApi {

let requestPayload: RequestPayload | DalleRequestPayload;

const accessStore = useAccessStore.getState();

const isDalle3 = _isDalle3(options.config.model);
const isO1 = options.config.model.startsWith("o1");
if (isDalle3) {
Expand All @@ -211,7 +213,10 @@ export class ChatGPTApi implements LLMApi {
style: options.config?.style ?? "vivid",
};
} else {
const visionModel = isVisionModel(options.config.model);
const visionModel = isVisionModel(
options.config.model,
accessStore.visionModels,
);
const messages: ChatOptions["messages"] = [];
for (const v of options.messages) {
const content = visionModel
Expand Down
6 changes: 5 additions & 1 deletion app/client/platforms/tencent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,11 @@ export class HunyuanApi implements LLMApi {
}

async chat(options: ChatOptions) {
const visionModel = isVisionModel(options.config.model);
const accessStore = useAccessStore.getState();
const visionModel = isVisionModel(
options.config.model,
accessStore.visionModels,
);
const messages = options.messages.map((v, index) => ({
// "Messages 中 system 角色必须位于列表的最开始"
role: index !== 0 && v.role === "system" ? "user" : v.role,
Expand Down
11 changes: 7 additions & 4 deletions app/components/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,7 @@ export function ChatActions(props: {
const currentProviderName =
session.mask.modelConfig?.providerName || ServiceProvider.OpenAI;
const allModels = useAllModels();
const customVisionModels = useAccessStore().visionModels;
const models = useMemo(() => {
const filteredModels = allModels.filter((m) => m.available);
const defaultModel = filteredModels.find((m) => m.isDefault);
Expand Down Expand Up @@ -529,7 +530,7 @@ export function ChatActions(props: {
const isMobileScreen = useMobileScreen();

useEffect(() => {
const show = isVisionModel(currentModel);
const show = isVisionModel(currentModel, customVisionModels);
setShowUploadImage(show);
if (!show) {
props.setAttachImages([]);
Expand Down Expand Up @@ -1457,10 +1458,12 @@ function _Chat() {
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);

const customVisionModels = useAccessStore().visionModels;

const handlePaste = useCallback(
async (event: React.ClipboardEvent<HTMLTextAreaElement>) => {
const currentModel = chatStore.currentSession().mask.modelConfig.model;
if (!isVisionModel(currentModel)) {
if (!isVisionModel(currentModel, customVisionModels)) {
return;
}
const items = (event.clipboardData || window.clipboardData).items;
Expand Down Expand Up @@ -1497,7 +1500,7 @@ function _Chat() {
}
}
},
[attachImages, chatStore],
[attachImages, chatStore, customVisionModels],
);

async function uploadImage() {
Expand Down Expand Up @@ -1545,7 +1548,7 @@ function _Chat() {
setAttachImages(images);
}

// 快捷键 shortcut keys
// 捷键 shortcut keys
const [showShortcutKeyModal, setShowShortcutKeyModal] = useState(false);

useEffect(() => {
Expand Down
7 changes: 6 additions & 1 deletion app/config/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ declare global {
ENABLE_BALANCE_QUERY?: string; // allow user to query balance or not
DISABLE_FAST_LINK?: string; // disallow parse settings from url or not
CUSTOM_MODELS?: string; // to control custom models
VISION_MODELS?: string; // to control vision models
DEFAULT_MODEL?: string; // to control default model in every new chat window

// stability only
Expand Down Expand Up @@ -123,13 +124,16 @@ export const getServerSideConfig = () => {

const disableGPT4 = !!process.env.DISABLE_GPT4;
let customModels = process.env.CUSTOM_MODELS ?? "";
let visionModels = process.env.VISION_MODELS ?? "";
let defaultModel = process.env.DEFAULT_MODEL ?? "";

if (disableGPT4) {
if (customModels) customModels += ",";
customModels += DEFAULT_MODELS.filter(
(m) =>
(m.name.startsWith("gpt-4") || m.name.startsWith("chatgpt-4o") || m.name.startsWith("o1")) &&
(m.name.startsWith("gpt-4") ||
m.name.startsWith("chatgpt-4o") ||
m.name.startsWith("o1")) &&
!m.name.startsWith("gpt-4o-mini"),
)
.map((m) => "-" + m.name)
Expand Down Expand Up @@ -247,6 +251,7 @@ export const getServerSideConfig = () => {
hideBalanceQuery: !process.env.ENABLE_BALANCE_QUERY,
disableFastLink: !!process.env.DISABLE_FAST_LINK,
customModels,
visionModels,
defaultModel,
allowedWebDavEndpoints,
};
Expand Down
1 change: 1 addition & 0 deletions app/store/access.ts
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ const DEFAULT_ACCESS_STATE = {
disableGPT4: false,
disableFastLink: false,
customModels: "",
visionModels: "",
defaultModel: "",

// tts config
Expand Down
14 changes: 9 additions & 5 deletions app/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { ServiceProvider } from "./constant";
import { fetch as tauriStreamFetch } from "./utils/stream";
import { VISION_MODEL_REGEXES, EXCLUDE_VISION_MODEL_REGEXES } from "./constant";
import { getClientConfig } from "./config/client";
import { getModelProvider } from "./utils/model";

export function trimTopic(topic: string) {
// Fix an issue where double quotes still show in the Indonesian language
Expand Down Expand Up @@ -253,12 +254,15 @@ export function getMessageImages(message: RequestMessage): string[] {
return urls;
}

export function isVisionModel(model: string) {
export function isVisionModel(model: string, customVisionModels: string) {
const clientConfig = getClientConfig();
const envVisionModels = clientConfig?.visionModels
?.split(",")
.map((m) => m.trim());
if (envVisionModels?.includes(model)) {
const allVisionModelsList = [customVisionModels, clientConfig?.visionModels]
?.join(",")
.split(",")
.map((m) => m.trim())
.filter(Boolean)
.map((m) => getModelProvider(m)[0]);
if (allVisionModelsList?.includes(model)) {
return true;
}
return (
Expand Down
23 changes: 13 additions & 10 deletions test/vision-model-checker.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { isVisionModel } from "../app/utils";

describe("isVisionModel", () => {
const originalEnv = process.env;
const customVisionModels = "custom-vlm,another-vlm";

beforeEach(() => {
jest.resetModules();
Expand All @@ -27,12 +28,12 @@ describe("isVisionModel", () => {
];

visionModels.forEach((model) => {
expect(isVisionModel(model)).toBe(true);
expect(isVisionModel(model, customVisionModels)).toBe(true);
});
});

test("should exclude specific models", () => {
expect(isVisionModel("claude-3-5-haiku-20241022")).toBe(false);
expect(isVisionModel("claude-3-5-haiku-20241022", customVisionModels)).toBe(false);
});

test("should not identify non-vision models", () => {
Expand All @@ -44,24 +45,26 @@ describe("isVisionModel", () => {
];

nonVisionModels.forEach((model) => {
expect(isVisionModel(model)).toBe(false);
expect(isVisionModel(model, customVisionModels)).toBe(false);
});
});

test("should identify models from VISION_MODELS env var", () => {
process.env.VISION_MODELS = "custom-vision-model,another-vision-model";

expect(isVisionModel("custom-vision-model")).toBe(true);
expect(isVisionModel("another-vision-model")).toBe(true);
expect(isVisionModel("unrelated-model")).toBe(false);

expect(isVisionModel("custom-vision-model", customVisionModels)).toBe(true);
expect(isVisionModel("another-vision-model", customVisionModels)).toBe(true);
expect(isVisionModel("custom-vlm", customVisionModels)).toBe(true);
expect(isVisionModel("another-vlm", customVisionModels)).toBe(true);
expect(isVisionModel("unrelated-model", customVisionModels)).toBe(false);
});

test("should handle empty or missing VISION_MODELS", () => {
process.env.VISION_MODELS = "";
expect(isVisionModel("unrelated-model")).toBe(false);
expect(isVisionModel("unrelated-model", customVisionModels)).toBe(false);

delete process.env.VISION_MODELS;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Avoid using the delete operator.

Deleting environment variables can have performance overhead and may lead to unpredictable effects. Consider assigning process.env.VISION_MODELS = undefined instead of using delete.

-    delete process.env.VISION_MODELS;
+    process.env.VISION_MODELS = undefined;
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
delete process.env.VISION_MODELS;
process.env.VISION_MODELS = undefined;
🧰 Tools
🪛 Biome (1.9.4)

[error] 66-66: Avoid the delete operator which can impact performance.

Unsafe fix: Use an undefined assignment instead.

(lint/performance/noDelete)

expect(isVisionModel("unrelated-model")).toBe(false);
expect(isVisionModel("gpt-4-vision")).toBe(true);
expect(isVisionModel("unrelated-model", customVisionModels)).toBe(false);
expect(isVisionModel("gpt-4-vision", customVisionModels)).toBe(true);
});
});