Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(UI) Refactor Add Models for Specific Teams #8592

Merged
merged 21 commits into from
Feb 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions ui/litellm-dashboard/src/components/add_model/add_model_tab.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import React from "react";
import { Card, Form, Button, Tooltip, Typography, Select as AntdSelect } from "antd";
import type { FormInstance } from "antd";
import type { UploadProps } from "antd/es/upload";
import LiteLLMModelNameField from "./litellm_model_name";
import ConditionalPublicModelName from "./conditional_public_model_name";
import ProviderSpecificFields from "./provider_specific_fields";
import AdvancedSettings from "./advanced_settings";
import { Providers, providerLogoMap, getPlaceholder } from "../provider_info_helpers";
import type { Team } from "../key_team_helpers/key_list";

interface AddModelTabProps {
form: FormInstance;
handleOk: () => void;
selectedProvider: Providers;
setSelectedProvider: (provider: Providers) => void;
providerModels: string[];
setProviderModelsFn: (provider: Providers) => void;
getPlaceholder: (provider: Providers) => string;
uploadProps: UploadProps;
showAdvancedSettings: boolean;
setShowAdvancedSettings: (show: boolean) => void;
teams: Team[] | null;
}

const { Title, Link } = Typography;

const AddModelTab: React.FC<AddModelTabProps> = ({
form,
handleOk,
selectedProvider,
setSelectedProvider,
providerModels,
setProviderModelsFn,
getPlaceholder,
uploadProps,
showAdvancedSettings,
setShowAdvancedSettings,
teams,
}) => {
return (
<>
<Title level={2}>Add new model</Title>
<Card>
<Form
form={form}
onFinish={handleOk}
labelCol={{ span: 10 }}
wrapperCol={{ span: 16 }}
labelAlign="left"
>
<>
{/* Provider Selection */}
<Form.Item
rules={[{ required: true, message: "Required" }]}
label="Provider:"
name="custom_llm_provider"
tooltip="E.g. OpenAI, Azure OpenAI, Anthropic, Bedrock, etc."
labelCol={{ span: 10 }}
labelAlign="left"
>
<AntdSelect
showSearch={true}
value={selectedProvider}
onChange={(value) => {
setSelectedProvider(value);
setProviderModelsFn(value);
form.setFieldsValue({
model: [],
model_name: undefined
});
}}
>
{Object.entries(Providers).map(([providerEnum, providerDisplayName]) => (
<AntdSelect.Option
key={providerEnum}
value={providerEnum}
>
<div className="flex items-center space-x-2">
<img
src={providerLogoMap[providerDisplayName]}
alt={`${providerEnum} logo`}
className="w-5 h-5"
onError={(e) => {
// Create a div with provider initial as fallback
const target = e.target as HTMLImageElement;
const parent = target.parentElement;
if (parent) {
const fallbackDiv = document.createElement('div');
fallbackDiv.className = 'w-5 h-5 rounded-full bg-gray-200 flex items-center justify-center text-xs';
fallbackDiv.textContent = providerDisplayName.charAt(0);
parent.replaceChild(fallbackDiv, target);
}
}}
/>
<span>{providerDisplayName}</span>
</div>
</AntdSelect.Option>
))}
</AntdSelect>
</Form.Item>
<LiteLLMModelNameField
selectedProvider={selectedProvider}
providerModels={providerModels}
getPlaceholder={getPlaceholder}
/>

{/* Conditionally Render "Public Model Name" */}
<ConditionalPublicModelName />

<ProviderSpecificFields
selectedProvider={selectedProvider}
uploadProps={uploadProps}
/>
<AdvancedSettings
showAdvancedSettings={showAdvancedSettings}
setShowAdvancedSettings={setShowAdvancedSettings}
/>


<div className="flex justify-between items-center mb-4">
<Tooltip title="Get help on our github">
<Typography.Link href="https://github.com/BerriAI/litellm/issues">
Need Help?
</Typography.Link>
</Tooltip>
<Button htmlType="submit">Add Model</Button>
</div>
</>
</Form>
</Card>


</>
);
};

export default AddModelTab;
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import React from "react";
import { Form } from "antd";
import { TextInput, Text } from "@tremor/react";
import React, { useEffect } from "react";
import { Form, Table, Input } from "antd";
import { Text, TextInput } from "@tremor/react";
import { Row, Col } from "antd";

const ConditionalPublicModelName: React.FC = () => {
Expand All @@ -11,32 +11,61 @@ const ConditionalPublicModelName: React.FC = () => {
const selectedModels = Form.useWatch('model', form) || [];
const showPublicModelName = !selectedModels.includes('all-wildcard');

// Auto-populate model mappings when selected models change
useEffect(() => {
if (selectedModels.length > 0 && !selectedModels.includes('all-wildcard')) {
const mappings = selectedModels.map(model => ({
public_name: model,
litellm_model: model
}));
form.setFieldValue('model_mappings', mappings);
}
}, [selectedModels, form]);

if (!showPublicModelName) return null;

const columns = [
{
title: 'Public Name',
dataIndex: 'public_name',
key: 'public_name',
render: (text: string, record: any, index: number) => {
return (
<TextInput
defaultValue={text}
onChange={(e) => {
const newMappings = [...form.getFieldValue('model_mappings')];
newMappings[index].public_name = e.target.value;
form.setFieldValue('model_mappings', newMappings);
}}
/>
);
}
},
{
title: 'LiteLLM Model',
dataIndex: 'litellm_model',
key: 'litellm_model',
}
];

return (
<>
<Form.Item
label="Public Model Name"
name="model_name"
tooltip="Model name your users will pass in. Also used for load-balancing, LiteLLM will load balance between all models with this public name."
label="Model Mappings"
name="model_mappings"
tooltip="Map public model names to LiteLLM model names for load balancing"
labelCol={{ span: 10 }}
wrapperCol={{ span: 16 }}
labelAlign="left"
required={false}
className="mb-0"
rules={[
({ getFieldValue }) => ({
validator(_, value) {
const selectedModels = getFieldValue('model') || [];
if (!selectedModels.includes('all-wildcard') || value) {
return Promise.resolve();
}
return Promise.reject(new Error('Public Model Name is required unless "All Models" is selected.'));
},
}),
]}
required={true}
>
<TextInput placeholder="my-gpt-4" />
<Table
dataSource={form.getFieldValue('model_mappings')}
columns={columns}
pagination={false}
size="small"
/>
</Form.Item>
<Row>
<Col span={10}></Col>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,34 +11,28 @@ export const handleAddModelSubmit = async (
) => {
try {
console.log("handling submit for formValues:", formValues);
// If model_name is not provided, use provider.toLowerCase() + "/*"

// Handle wildcard case
if (formValues["model"] && formValues["model"].includes("all-wildcard")) {
const customProvider: Providers = formValues["custom_llm_provider"];
const litellm_custom_provider = provider_map[customProvider as keyof typeof Providers];
const wildcardModel = litellm_custom_provider + "/*";
formValues["model_name"] = wildcardModel;
formValues["model"] = wildcardModel;
}
/**
* For multiple litellm model names - create a separate deployment for each
* - get the list
* - iterate through it
* - create a new deployment for each
*
* For single model name -> make it a 1 item list
*/

// get the list of deployments
let deployments: Array<string> = Array.isArray(formValues["model"])
? formValues["model"]
: [formValues["model"]];
console.log(`received deployments: ${deployments}`);
console.log(`received type of deployments: ${typeof deployments}`);
deployments.forEach(async (litellm_model) => {
console.log(`litellm_model: ${litellm_model}`);

// Get model mappings
const modelMappings = formValues["model_mappings"] || [];

// Create a deployment for each mapping
for (const mapping of modelMappings) {
const litellmParamsObj: Record<string, any> = {};
const modelInfoObj: Record<string, any> = {};

// Set the model name and litellm model from the mapping
const modelName = mapping.public_name;
litellmParamsObj["model"] = mapping.litellm_model;

// Handle pricing conversion before processing other fields
if (formValues.input_cost_per_token) {
formValues.input_cost_per_token = Number(formValues.input_cost_per_token) / 1000000;
Expand All @@ -49,8 +43,7 @@ export const handleAddModelSubmit = async (
// Keep input_cost_per_second as is, no conversion needed

// Iterate through the key-value pairs in formValues
litellmParamsObj["model"] = litellm_model;
let modelName: string = "";
litellmParamsObj["model"] = mapping.litellm_model;
console.log("formValues add deployment:", formValues);
for (const [key, value] of Object.entries(formValues)) {
if (value === "") {
Expand All @@ -61,7 +54,7 @@ export const handleAddModelSubmit = async (
continue;
}
if (key == "model_name") {
modelName = modelName + value;
litellmParamsObj["model"] = value;
} else if (key == "custom_llm_provider") {
console.log("custom_llm_provider:", value);
const mappingResult = provider_map[value]; // Get the corresponding value from the mapping
Expand Down Expand Up @@ -141,11 +134,10 @@ export const handleAddModelSubmit = async (
};

const response: any = await modelCreateCall(accessToken, new_model);
callback && callback()

console.log(`response for model create call: ${response["data"]}`);
});

}

callback && callback()
form.resetFields();
} catch (error) {
message.error("Failed to create model: " + error, 10);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,33 @@ import { Row, Col } from "antd";
import { Providers } from "../provider_info_helpers";

interface LiteLLMModelNameFieldProps {
selectedProvider: string;
selectedProvider: Providers;
providerModels: string[];
getPlaceholder: (provider: string) => string;
getPlaceholder: (provider: Providers) => string;
}

const LiteLLMModelNameField: React.FC<LiteLLMModelNameFieldProps> = ({
selectedProvider,
providerModels,
providerModels,
getPlaceholder,
}) => {
const form = Form.useFormInstance();

const handleModelChange = (value: string[]) => {
const handleModelChange = (value: string | string[]) => {
// Ensure value is always treated as an array
const values = Array.isArray(value) ? value : [value];

// If "all-wildcard" is selected, clear the model_name field
if (value.includes("all-wildcard")) {
form.setFieldsValue({ model_name: undefined });
if (values.includes("all-wildcard")) {
form.setFieldsValue({ model_name: undefined, model_mappings: [] });
} else {
// Update model mappings immediately for each selected model
const mappings = values
.map(model => ({
public_name: model,
litellm_model: model
}));
form.setFieldsValue({ model_mappings: mappings });
}
};

Expand All @@ -39,9 +50,10 @@ const LiteLLMModelNameField: React.FC<LiteLLMModelNameFieldProps> = ({
{(selectedProvider === Providers.Azure) ||
(selectedProvider === Providers.OpenAI_Compatible) ||
(selectedProvider === Providers.Ollama) ? (
<TextInput placeholder={getPlaceholder(selectedProvider.toString())} />
<TextInput placeholder={getPlaceholder(selectedProvider)} />
) : providerModels.length > 0 ? (
<AntSelect
mode="multiple"
allowClear
showSearch
placeholder="Select models"
Expand All @@ -67,7 +79,7 @@ const LiteLLMModelNameField: React.FC<LiteLLMModelNameFieldProps> = ({
style={{ width: '100%' }}
/>
) : (
<TextInput placeholder={getPlaceholder(selectedProvider.toString())} />
<TextInput placeholder={getPlaceholder(selectedProvider)} />
)}
</Form.Item>

Expand Down
Loading