Skip to content

Commit

Permalink
(UI) Refactor Add Models for Specific Teams (BerriAI#8592)
Browse files Browse the repository at this point in the history
* ui - use common team dropdown component

* re-use team component

* rename org field on add model

* handle add model submit

* working view model_id and team_id on root models page

* cleaner

* show all fields

* working model info view

* working team info selector

* clean up team id

* new component for model dashboard

* ui show table with dropdown

* make public model names like email

* revert changes to litellm model name

* fix litellm model name

* ui fix public model

* fix mappings

* fix conditional text input

* fix message

* ui fix bulk add models
  • Loading branch information
ishaan-jaff authored and abhijitherekar committed Feb 20, 2025
1 parent 1298be2 commit b63dd33
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 149 deletions.
138 changes: 138 additions & 0 deletions ui/litellm-dashboard/src/components/add_model/add_model_tab.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import React from "react";
import { Card, Form, Button, Tooltip, Typography, Select as AntdSelect } from "antd";
import type { FormInstance } from "antd";
import type { UploadProps } from "antd/es/upload";
import LiteLLMModelNameField from "./litellm_model_name";
import ConditionalPublicModelName from "./conditional_public_model_name";
import ProviderSpecificFields from "./provider_specific_fields";
import AdvancedSettings from "./advanced_settings";
import { Providers, providerLogoMap, getPlaceholder } from "../provider_info_helpers";
import type { Team } from "../key_team_helpers/key_list";

interface AddModelTabProps {
form: FormInstance;
handleOk: () => void;
selectedProvider: Providers;
setSelectedProvider: (provider: Providers) => void;
providerModels: string[];
setProviderModelsFn: (provider: Providers) => void;
getPlaceholder: (provider: Providers) => string;
uploadProps: UploadProps;
showAdvancedSettings: boolean;
setShowAdvancedSettings: (show: boolean) => void;
teams: Team[] | null;
}

const { Title, Link } = Typography;

const AddModelTab: React.FC<AddModelTabProps> = ({
form,
handleOk,
selectedProvider,
setSelectedProvider,
providerModels,
setProviderModelsFn,
getPlaceholder,
uploadProps,
showAdvancedSettings,
setShowAdvancedSettings,
teams,
}) => {
return (
<>
<Title level={2}>Add new model</Title>
<Card>
<Form
form={form}
onFinish={handleOk}
labelCol={{ span: 10 }}
wrapperCol={{ span: 16 }}
labelAlign="left"
>
<>
{/* Provider Selection */}
<Form.Item
rules={[{ required: true, message: "Required" }]}
label="Provider:"
name="custom_llm_provider"
tooltip="E.g. OpenAI, Azure OpenAI, Anthropic, Bedrock, etc."
labelCol={{ span: 10 }}
labelAlign="left"
>
<AntdSelect
showSearch={true}
value={selectedProvider}
onChange={(value) => {
setSelectedProvider(value);
setProviderModelsFn(value);
form.setFieldsValue({
model: [],
model_name: undefined
});
}}
>
{Object.entries(Providers).map(([providerEnum, providerDisplayName]) => (
<AntdSelect.Option
key={providerEnum}
value={providerEnum}
>
<div className="flex items-center space-x-2">
<img
src={providerLogoMap[providerDisplayName]}
alt={`${providerEnum} logo`}
className="w-5 h-5"
onError={(e) => {
// Create a div with provider initial as fallback
const target = e.target as HTMLImageElement;
const parent = target.parentElement;
if (parent) {
const fallbackDiv = document.createElement('div');
fallbackDiv.className = 'w-5 h-5 rounded-full bg-gray-200 flex items-center justify-center text-xs';
fallbackDiv.textContent = providerDisplayName.charAt(0);
parent.replaceChild(fallbackDiv, target);
}
}}
/>
<span>{providerDisplayName}</span>
</div>
</AntdSelect.Option>
))}
</AntdSelect>
</Form.Item>
<LiteLLMModelNameField
selectedProvider={selectedProvider}
providerModels={providerModels}
getPlaceholder={getPlaceholder}
/>

{/* Conditionally Render "Public Model Name" */}
<ConditionalPublicModelName />

<ProviderSpecificFields
selectedProvider={selectedProvider}
uploadProps={uploadProps}
/>
<AdvancedSettings
showAdvancedSettings={showAdvancedSettings}
setShowAdvancedSettings={setShowAdvancedSettings}
/>


<div className="flex justify-between items-center mb-4">
<Tooltip title="Get help on our github">
<Typography.Link href="https://github.com/BerriAI/litellm/issues">
Need Help?
</Typography.Link>
</Tooltip>
<Button htmlType="submit">Add Model</Button>
</div>
</>
</Form>
</Card>


</>
);
};

export default AddModelTab;
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import React from "react";
import { Form } from "antd";
import { TextInput, Text } from "@tremor/react";
import React, { useEffect } from "react";
import { Form, Table, Input } from "antd";
import { Text, TextInput } from "@tremor/react";
import { Row, Col } from "antd";

const ConditionalPublicModelName: React.FC = () => {
Expand All @@ -11,32 +11,61 @@ const ConditionalPublicModelName: React.FC = () => {
const selectedModels = Form.useWatch('model', form) || [];
const showPublicModelName = !selectedModels.includes('all-wildcard');

// Auto-populate model mappings when selected models change
useEffect(() => {
if (selectedModels.length > 0 && !selectedModels.includes('all-wildcard')) {
const mappings = selectedModels.map(model => ({
public_name: model,
litellm_model: model
}));
form.setFieldValue('model_mappings', mappings);
}
}, [selectedModels, form]);

if (!showPublicModelName) return null;

const columns = [
{
title: 'Public Name',
dataIndex: 'public_name',
key: 'public_name',
render: (text: string, record: any, index: number) => {
return (
<TextInput
defaultValue={text}
onChange={(e) => {
const newMappings = [...form.getFieldValue('model_mappings')];
newMappings[index].public_name = e.target.value;
form.setFieldValue('model_mappings', newMappings);
}}
/>
);
}
},
{
title: 'LiteLLM Model',
dataIndex: 'litellm_model',
key: 'litellm_model',
}
];

return (
<>
<Form.Item
label="Public Model Name"
name="model_name"
tooltip="Model name your users will pass in. Also used for load-balancing, LiteLLM will load balance between all models with this public name."
label="Model Mappings"
name="model_mappings"
tooltip="Map public model names to LiteLLM model names for load balancing"
labelCol={{ span: 10 }}
wrapperCol={{ span: 16 }}
labelAlign="left"
required={false}
className="mb-0"
rules={[
({ getFieldValue }) => ({
validator(_, value) {
const selectedModels = getFieldValue('model') || [];
if (!selectedModels.includes('all-wildcard') || value) {
return Promise.resolve();
}
return Promise.reject(new Error('Public Model Name is required unless "All Models" is selected.'));
},
}),
]}
required={true}
>
<TextInput placeholder="my-gpt-4" />
<Table
dataSource={form.getFieldValue('model_mappings')}
columns={columns}
pagination={false}
size="small"
/>
</Form.Item>
<Row>
<Col span={10}></Col>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,34 +11,28 @@ export const handleAddModelSubmit = async (
) => {
try {
console.log("handling submit for formValues:", formValues);
// If model_name is not provided, use provider.toLowerCase() + "/*"

// Handle wildcard case
if (formValues["model"] && formValues["model"].includes("all-wildcard")) {
const customProvider: Providers = formValues["custom_llm_provider"];
const litellm_custom_provider = provider_map[customProvider as keyof typeof Providers];
const wildcardModel = litellm_custom_provider + "/*";
formValues["model_name"] = wildcardModel;
formValues["model"] = wildcardModel;
}
/**
* For multiple litellm model names - create a separate deployment for each
* - get the list
* - iterate through it
* - create a new deployment for each
*
* For single model name -> make it a 1 item list
*/

// get the list of deployments
let deployments: Array<string> = Array.isArray(formValues["model"])
? formValues["model"]
: [formValues["model"]];
console.log(`received deployments: ${deployments}`);
console.log(`received type of deployments: ${typeof deployments}`);
deployments.forEach(async (litellm_model) => {
console.log(`litellm_model: ${litellm_model}`);

// Get model mappings
const modelMappings = formValues["model_mappings"] || [];

// Create a deployment for each mapping
for (const mapping of modelMappings) {
const litellmParamsObj: Record<string, any> = {};
const modelInfoObj: Record<string, any> = {};

// Set the model name and litellm model from the mapping
const modelName = mapping.public_name;
litellmParamsObj["model"] = mapping.litellm_model;

// Handle pricing conversion before processing other fields
if (formValues.input_cost_per_token) {
formValues.input_cost_per_token = Number(formValues.input_cost_per_token) / 1000000;
Expand All @@ -49,8 +43,7 @@ export const handleAddModelSubmit = async (
// Keep input_cost_per_second as is, no conversion needed

// Iterate through the key-value pairs in formValues
litellmParamsObj["model"] = litellm_model;
let modelName: string = "";
litellmParamsObj["model"] = mapping.litellm_model;
console.log("formValues add deployment:", formValues);
for (const [key, value] of Object.entries(formValues)) {
if (value === "") {
Expand All @@ -61,7 +54,7 @@ export const handleAddModelSubmit = async (
continue;
}
if (key == "model_name") {
modelName = modelName + value;
litellmParamsObj["model"] = value;
} else if (key == "custom_llm_provider") {
console.log("custom_llm_provider:", value);
const mappingResult = provider_map[value]; // Get the corresponding value from the mapping
Expand Down Expand Up @@ -141,11 +134,10 @@ export const handleAddModelSubmit = async (
};

const response: any = await modelCreateCall(accessToken, new_model);
callback && callback()

console.log(`response for model create call: ${response["data"]}`);
});

}

callback && callback()
form.resetFields();
} catch (error) {
message.error("Failed to create model: " + error, 10);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,33 @@ import { Row, Col } from "antd";
import { Providers } from "../provider_info_helpers";

interface LiteLLMModelNameFieldProps {
selectedProvider: string;
selectedProvider: Providers;
providerModels: string[];
getPlaceholder: (provider: string) => string;
getPlaceholder: (provider: Providers) => string;
}

const LiteLLMModelNameField: React.FC<LiteLLMModelNameFieldProps> = ({
selectedProvider,
providerModels,
providerModels,
getPlaceholder,
}) => {
const form = Form.useFormInstance();

const handleModelChange = (value: string[]) => {
const handleModelChange = (value: string | string[]) => {
// Ensure value is always treated as an array
const values = Array.isArray(value) ? value : [value];

// If "all-wildcard" is selected, clear the model_name field
if (value.includes("all-wildcard")) {
form.setFieldsValue({ model_name: undefined });
if (values.includes("all-wildcard")) {
form.setFieldsValue({ model_name: undefined, model_mappings: [] });
} else {
// Update model mappings immediately for each selected model
const mappings = values
.map(model => ({
public_name: model,
litellm_model: model
}));
form.setFieldsValue({ model_mappings: mappings });
}
};

Expand All @@ -39,9 +50,10 @@ const LiteLLMModelNameField: React.FC<LiteLLMModelNameFieldProps> = ({
{(selectedProvider === Providers.Azure) ||
(selectedProvider === Providers.OpenAI_Compatible) ||
(selectedProvider === Providers.Ollama) ? (
<TextInput placeholder={getPlaceholder(selectedProvider.toString())} />
<TextInput placeholder={getPlaceholder(selectedProvider)} />
) : providerModels.length > 0 ? (
<AntSelect
mode="multiple"
allowClear
showSearch
placeholder="Select models"
Expand All @@ -67,7 +79,7 @@ const LiteLLMModelNameField: React.FC<LiteLLMModelNameFieldProps> = ({
style={{ width: '100%' }}
/>
) : (
<TextInput placeholder={getPlaceholder(selectedProvider.toString())} />
<TextInput placeholder={getPlaceholder(selectedProvider)} />
)}
</Form.Item>

Expand Down
Loading

0 comments on commit b63dd33

Please sign in to comment.