Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: configure dynamic providers via .env #1108

Merged
merged 8 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 32 additions & 30 deletions app/components/chat/BaseChat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import { Menu } from '~/components/sidebar/Menu.client';
import { IconButton } from '~/components/ui/IconButton';
import { Workbench } from '~/components/workbench/Workbench.client';
import { classNames } from '~/utils/classNames';
import { MODEL_LIST, PROVIDER_LIST, initializeModelList } from '~/utils/constants';
import { PROVIDER_LIST } from '~/utils/constants';
import { Messages } from './Messages.client';
import { SendButton } from './SendButton.client';
import { APIKeyManager, getApiKeysFromCookies } from './APIKeyManager';
Expand All @@ -31,7 +31,7 @@ import { toast } from 'react-toastify';
import StarterTemplates from './StarterTemplates';
import type { ActionAlert } from '~/types/actions';
import ChatAlert from './ChatAlert';
import { LLMManager } from '~/lib/modules/llm/manager';
import type { ModelInfo } from '~/lib/modules/llm/types';

const TEXTAREA_MIN_HEIGHT = 76;

Expand Down Expand Up @@ -102,7 +102,7 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
) => {
const TEXTAREA_MAX_HEIGHT = chatStarted ? 400 : 200;
const [apiKeys, setApiKeys] = useState<Record<string, string>>(getApiKeysFromCookies());
const [modelList, setModelList] = useState(MODEL_LIST);
const [modelList, setModelList] = useState<ModelInfo[]>([]);
const [isModelSettingsCollapsed, setIsModelSettingsCollapsed] = useState(false);
const [isListening, setIsListening] = useState(false);
const [recognition, setRecognition] = useState<SpeechRecognition | null>(null);
Expand Down Expand Up @@ -131,6 +131,7 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(

return providerSettings;
}, []);

useEffect(() => {
console.log(transcript);
}, [transcript]);
Expand Down Expand Up @@ -169,25 +170,31 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(

useEffect(() => {
if (typeof window !== 'undefined') {
const providerSettings = getProviderSettings();
let parsedApiKeys: Record<string, string> | undefined = {};
const providerSettings = getProviderSettings();

try {
parsedApiKeys = getApiKeysFromCookies();
setApiKeys(parsedApiKeys);
} catch (error) {
console.error('Error loading API keys from cookies:', error);

// Clear invalid cookie data
Cookies.remove('apiKeys');
}

setIsModelLoading('all');
initializeModelList({ apiKeys: parsedApiKeys, providerSettings })
.then((modelList) => {
setModelList(modelList);
fetch('/api/models', {
headers: {
'x-client-api-keys': JSON.stringify(parsedApiKeys),
'x-client-provider-settings': JSON.stringify(providerSettings),
},
})
.then((response) => response.json())
.then((data) => {
const typedData = data as { modelList: ModelInfo[] };
setModelList(typedData.modelList);
})
.catch((error) => {
console.error('Error initializing model list:', error);
console.error('Error fetching model list:', error);
})
.finally(() => {
setIsModelLoading(undefined);
Expand All @@ -200,29 +207,24 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
setApiKeys(newApiKeys);
Cookies.set('apiKeys', JSON.stringify(newApiKeys));

const provider = LLMManager.getInstance(import.meta.env || process.env || {}).getProvider(providerName);
setIsModelLoading(providerName);

if (provider && provider.getDynamicModels) {
setIsModelLoading(providerName);
try {
const providerSettings = getProviderSettings();

try {
const providerSettings = getProviderSettings();
const staticModels = provider.staticModels;
const dynamicModels = await provider.getDynamicModels(
newApiKeys,
providerSettings,
import.meta.env || process.env || {},
);

setModelList((preModels) => {
const filteredOutPreModels = preModels.filter((x) => x.provider !== providerName);
return [...filteredOutPreModels, ...staticModels, ...dynamicModels];
});
} catch (error) {
console.error('Error loading dynamic models:', error);
}
setIsModelLoading(undefined);
const response = await fetch('/api/models', {
headers: {
'x-client-api-keys': JSON.stringify(newApiKeys),
'x-client-provider-settings': JSON.stringify(providerSettings),
},
});
const data = await response.json();
const typedData = data as { modelList: ModelInfo[] };
setModelList(typedData.modelList);
} catch (error) {
console.error('Error loading dynamic models:', error);
}
setIsModelLoading(undefined);
};

const startListening = () => {
Expand Down
2 changes: 1 addition & 1 deletion app/lib/modules/llm/manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ export class LLMManager {

let enabledProviders = Array.from(this._providers.values()).map((p) => p.name);

if (providerSettings) {
if (providerSettings && Object.keys(providerSettings).length > 0) {
enabledProviders = enabledProviders.filter((p) => providerSettings[p].enabled);
}

Expand Down
17 changes: 14 additions & 3 deletions app/routes/api.llmcall.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import { type ActionFunctionArgs } from '@remix-run/cloudflare';
import { streamText } from '~/lib/.server/llm/stream-text';
import type { IProviderSetting, ProviderInfo } from '~/types/model';
import { generateText } from 'ai';
import { getModelList, PROVIDER_LIST } from '~/utils/constants';
import { PROVIDER_LIST } from '~/utils/constants';
import { MAX_TOKENS } from '~/lib/.server/llm/constants';
import { LLMManager } from '~/lib/modules/llm/manager';
import type { ModelInfo } from '~/lib/modules/llm/types';

export async function action(args: ActionFunctionArgs) {
return llmCallAction(args);
Expand All @@ -31,6 +33,15 @@ function parseCookies(cookieHeader: string) {
return cookies;
}

async function getModelList(options: {
apiKeys?: Record<string, string>;
providerSettings?: Record<string, IProviderSetting>;
serverEnv?: Record<string, string>;
}) {
const llmManager = LLMManager.getInstance(import.meta.env);
return llmManager.updateModelList(options);
}

async function llmCallAction({ context, request }: ActionFunctionArgs) {
const { system, message, model, provider, streamOutput } = await request.json<{
system: string;
Expand Down Expand Up @@ -105,8 +116,8 @@ async function llmCallAction({ context, request }: ActionFunctionArgs) {
}
} else {
try {
const MODEL_LIST = await getModelList({ apiKeys, providerSettings, serverEnv: context.cloudflare.env as any });
const modelDetails = MODEL_LIST.find((m) => m.name === model);
const models = await getModelList({ apiKeys, providerSettings, serverEnv: context.cloudflare.env as any });
const modelDetails = models.find((m: ModelInfo) => m.name === model);

if (!modelDetails) {
throw new Error('Model not found');
Expand Down
63 changes: 60 additions & 3 deletions app/routes/api.models.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,63 @@
import { json } from '@remix-run/cloudflare';
import { MODEL_LIST } from '~/utils/constants';
import { LLMManager } from '~/lib/modules/llm/manager';
import type { ModelInfo } from '~/lib/modules/llm/types';
import type { ProviderInfo } from '~/types/model';

export async function loader() {
return json(MODEL_LIST);
interface ModelsResponse {
modelList: ModelInfo[];
providers: ProviderInfo[];
defaultProvider: ProviderInfo;
}

let cachedProviders: ProviderInfo[] | null = null;
let cachedDefaultProvider: ProviderInfo | null = null;

/**
 * Returns the client-facing provider metadata, computing it once and then
 * serving it from the module-level caches (`cachedProviders`,
 * `cachedDefaultProvider`) on subsequent calls.
 *
 * Only a serializable subset of each provider is exposed: name, static models,
 * API-key link/label, and icon.
 */
function getProviderInfo(llmManager: LLMManager) {
  // Populate the provider cache on first use.
  cachedProviders ??= llmManager
    .getAllProviders()
    .map(({ name, staticModels, getApiKeyLink, labelForGetApiKey, icon }) => ({
      name,
      staticModels,
      getApiKeyLink,
      labelForGetApiKey,
      icon,
    }));

  // Populate the default-provider cache on first use.
  if (cachedDefaultProvider === null) {
    const { name, staticModels, getApiKeyLink, labelForGetApiKey, icon } = llmManager.getDefaultProvider();
    cachedDefaultProvider = { name, staticModels, getApiKeyLink, labelForGetApiKey, icon };
  }

  return { providers: cachedProviders, defaultProvider: cachedDefaultProvider };
}

/**
 * GET /api/models — returns the full model list plus provider metadata.
 *
 * Client-side overrides (API keys, provider settings) are forwarded by the
 * browser as JSON-encoded request headers and merged into the server-side
 * model lookup.
 */
export async function loader({ request }: { request: Request }): Promise<Response> {
  const llmManager = LLMManager.getInstance(import.meta.env);

  const rawApiKeys = request.headers.get('x-client-api-keys');
  const rawProviderSettings = request.headers.get('x-client-provider-settings');

  // Defensive parsing: these headers are client-controlled, so malformed JSON
  // must not be allowed to throw and turn the route into a 500.
  let apiKeys: Record<string, string> = {};
  let providerSettings: Record<string, unknown> = {};

  try {
    apiKeys = rawApiKeys ? JSON.parse(rawApiKeys) : {};
  } catch (error) {
    console.error('Invalid x-client-api-keys header, ignoring:', error);
  }

  try {
    // Fix: the missing-header fallback was previously `[]`; provider settings
    // are consumed as a keyed record, so default to an empty object instead.
    providerSettings = rawProviderSettings ? JSON.parse(rawProviderSettings) : {};
  } catch (error) {
    console.error('Invalid x-client-provider-settings header, ignoring:', error);
  }

  const { providers, defaultProvider } = getProviderInfo(llmManager);

  const modelList = await llmManager.updateModelList({
    apiKeys,
    providerSettings,
    serverEnv: import.meta.env,
  });

  return json<ModelsResponse>({
    modelList,
    providers,
    defaultProvider,
  });
}
35 changes: 1 addition & 34 deletions app/utils/constants.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
import type { IProviderSetting } from '~/types/model';

import { LLMManager } from '~/lib/modules/llm/manager';
import type { ModelInfo } from '~/lib/modules/llm/types';
import type { Template } from '~/types/template';

export const WORK_DIR_NAME = 'project';
Expand All @@ -17,44 +14,14 @@ const llmManager = LLMManager.getInstance(import.meta.env);
export const PROVIDER_LIST = llmManager.getAllProviders();
export const DEFAULT_PROVIDER = llmManager.getDefaultProvider();

let MODEL_LIST = llmManager.getModelList();

const providerBaseUrlEnvKeys: Record<string, { baseUrlKey?: string; apiTokenKey?: string }> = {};
export const providerBaseUrlEnvKeys: Record<string, { baseUrlKey?: string; apiTokenKey?: string }> = {};
PROVIDER_LIST.forEach((provider) => {
providerBaseUrlEnvKeys[provider.name] = {
baseUrlKey: provider.config.baseUrlKey,
apiTokenKey: provider.config.apiTokenKey,
};
});

// Export the getModelList function using the manager
export async function getModelList(options: {
apiKeys?: Record<string, string>;
providerSettings?: Record<string, IProviderSetting>;
serverEnv?: Record<string, string>;
}) {
return await llmManager.updateModelList(options);
}

async function initializeModelList(options: {
env?: Record<string, string>;
providerSettings?: Record<string, IProviderSetting>;
apiKeys?: Record<string, string>;
}): Promise<ModelInfo[]> {
const { providerSettings, apiKeys, env } = options;
const list = await getModelList({
apiKeys,
providerSettings,
serverEnv: env,
});
MODEL_LIST = list || MODEL_LIST;

return list;
}

// initializeModelList({})
export { initializeModelList, providerBaseUrlEnvKeys, MODEL_LIST };

// starter Templates

export const STARTER_TEMPLATES: Template[] = [
Expand Down
Loading