Skip to content

Commit

Permalink
🔨 refactor(model): 更改原先的实现方法,在 collect table 函数后面增加额外的 sort 处理
Browse files Browse the repository at this point in the history
  • Loading branch information
frostime committed Aug 5, 2024
1 parent 8a4b8a8 commit b023a00
Showing 1 changed file with 39 additions and 11 deletions.
50 changes: 39 additions & 11 deletions app/utils/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,29 @@ const customProvider = (providerName: string) => ({
providerType: "custom",
});

const sortModelTable = (
models: ReturnType<typeof collectModels>,
rule: "custom-first" | "default-first",
) =>
models.sort((a, b) => {
if (a.provider === undefined && b.provider === undefined) {
return 0;
}

let aIsCustom = a.provider?.providerType === "custom";
let bIsCustom = b.provider?.providerType === "custom";

if (aIsCustom === bIsCustom) {
return 0;
}

if (aIsCustom) {
return rule === "custom-first" ? -1 : 1;
} else {
return rule === "custom-first" ? 1 : -1;
}
});

export function collectModelTable(
models: readonly LLMModel[],
customModels: string,
Expand All @@ -22,6 +45,15 @@ export function collectModelTable(
}
> = {};

// default models
models.forEach((m) => {
// using <modelName>@<providerId> as fullName
modelTable[`${m.name}@${m?.provider?.id}`] = {
...m,
displayName: m.name, // 'provider' is copied over if it exists
};
});

// server custom models
customModels
.split(",")
Expand Down Expand Up @@ -80,15 +112,6 @@ export function collectModelTable(
}
});

// default models
models.forEach((m) => {
// using <modelName>@<providerId> as fullName
modelTable[`${m.name}@${m?.provider?.id}`] = {
...m,
displayName: m.name, // 'provider' is copied over if it exists
};
});

return modelTable;
}

Expand Down Expand Up @@ -126,7 +149,9 @@ export function collectModels(
customModels: string,
) {
const modelTable = collectModelTable(models, customModels);
const allModels = Object.values(modelTable);
let allModels = Object.values(modelTable);

allModels = sortModelTable(allModels, "custom-first");

return allModels;
}
Expand All @@ -141,7 +166,10 @@ export function collectModelsWithDefaultModel(
customModels,
defaultModel,
);
const allModels = Object.values(modelTable);
let allModels = Object.values(modelTable);

allModels = sortModelTable(allModels, "custom-first");

return allModels;
}

Expand Down

0 comments on commit b023a00

Please sign in to comment.