Skip to content

Commit

Permalink
feat: add support for Gemini 1.5 Pro
Browse files Browse the repository at this point in the history
  • Loading branch information
mondaychen committed Sep 17, 2024
1 parent 1cd3600 commit d08f402
Show file tree
Hide file tree
Showing 9 changed files with 135 additions and 22 deletions.
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"@chakra-ui/react": "^2.8.2",
"@emotion/react": "^11.11.4",
"@emotion/styled": "^11.11.5",
"@google/generative-ai": "^0.19.0",
"accname": "^1.1.0",
"construct-style-sheets-polyfill": "3.1.0",
"formik": "^2.4.5",
Expand Down
8 changes: 8 additions & 0 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions src/common/Settings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ const Settings = ({ setInSettingsView }: SettingsProps) => {
voiceMode: state.settings.voiceMode,
openAIKey: state.settings.openAIKey,
anthropicKey: state.settings.anthropicKey,
geminiKey: state.settings.geminiKey,
}));
const toast = useToast();

Expand Down Expand Up @@ -182,6 +183,7 @@ const Settings = ({ setInSettingsView }: SettingsProps) => {
state.agentMode,
state.openAIKey,
state.anthropicKey,
state.geminiKey,
) ? (
<Alert status="error">
<AlertIcon />
Expand Down
11 changes: 9 additions & 2 deletions src/common/settings/ModelDropdown.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ const ModelDropdown = () => {
updateSettings: state.settings.actions.update,
}));

const { openAIKey, anthropicKey } = useAppState((state) => ({
const { openAIKey, anthropicKey, geminiKey } = useAppState((state) => ({
openAIKey: state.settings.openAIKey,
anthropicKey: state.settings.anthropicKey,
geminiKey: state.settings.geminiKey,
}));

return (
Expand All @@ -32,7 +33,13 @@ const ModelDropdown = () => {
key={model}
value={model}
disabled={
!isValidModelSettings(model, agentMode, openAIKey, anthropicKey)
!isValidModelSettings(
model,
agentMode,
openAIKey,
anthropicKey,
geminiKey,
)
}
>
{DisplayName[model]}
Expand Down
33 changes: 32 additions & 1 deletion src/common/settings/SetAPIKey.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ type SetAPIKeyProps = {
asInitializerView?: boolean;
initialOpenAIKey?: string;
initialAnthropicKey?: string;
initialGeminiKey?: string;
onClose?: () => void;
};

const SetAPIKey = ({
asInitializerView = false,
initialOpenAIKey = "",
initialAnthropicKey = "",
initialGeminiKey = "",
onClose,
}: SetAPIKeyProps) => {
const { updateSettings, initialOpenAIBaseUrl, initialAnthropicBaseUrl } =
Expand All @@ -38,6 +40,7 @@ const SetAPIKey = ({
const [anthropicKey, setAnthropicKey] = React.useState(
initialAnthropicKey || "",
);
const [geminiKey, setGeminiKey] = React.useState(initialGeminiKey || "");
const [openAIBaseUrl, setOpenAIBaseUrl] = React.useState(
initialOpenAIBaseUrl || "",
);
Expand All @@ -53,6 +56,7 @@ const SetAPIKey = ({
openAIBaseUrl,
anthropicKey,
anthropicBaseUrl,
geminiKey,
});
onClose && onClose();
};
Expand Down Expand Up @@ -157,10 +161,37 @@ const SetAPIKey = ({
/>
</FormControl>
)}

<Box position="relative" py={2} w="full">
<Divider />
<AbsoluteCenter bg="white" px="4">
Gemini (Google)
</AbsoluteCenter>
</Box>
<FormControl>
<FormLabel>Gemini API Key</FormLabel>
<HStack w="full">
<Input
placeholder="Enter Gemini API Key"
value={geminiKey}
onChange={(event) => setGeminiKey(event.target.value)}
type={showPassword ? "text" : "password"}
/>
{asInitializerView && (
<Button
onClick={() => setShowPassword(!showPassword)}
variant="outline"
>
{showPassword ? "Hide" : "Show"}
</Button>
)}
</HStack>
</FormControl>

<Button
onClick={onSave}
w="full"
isDisabled={!openAIKey && !anthropicKey}
isDisabled={!openAIKey && !anthropicKey && !geminiKey}
colorScheme="blue"
>
Save
Expand Down
96 changes: 77 additions & 19 deletions src/helpers/aiSdkUtils.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import Anthropic from "@anthropic-ai/sdk";
import { GoogleGenerativeAI, type Part } from "@google/generative-ai";
import OpenAI from "openai";
import { useAppState } from "../state/store";
import { enumValues } from "./utils";
Expand All @@ -22,6 +23,7 @@ export enum SupportedModels {
Claude3Sonnet = "claude-3-sonnet-20240229",
Claude3Opus = "claude-3-opus-20240229",
Claude35Sonnet = "claude-3-5-sonnet-20240620",
Gemini15Pro = "gemini-1.5-pro",
}

function isSupportedModel(value: string): value is SupportedModels {
Expand All @@ -43,6 +45,7 @@ export const DisplayName = {
[SupportedModels.Claude3Sonnet]: "Claude 3 Sonnet",
[SupportedModels.Claude3Opus]: "Claude 3 Opus",
[SupportedModels.Claude35Sonnet]: "Claude 3.5 Sonnet",
[SupportedModels.Gemini15Pro]: "Gemini 1.5 Pro",
};

export function hasVisionSupport(model: SupportedModels) {
Expand All @@ -53,16 +56,20 @@ export function hasVisionSupport(model: SupportedModels) {
model === SupportedModels.Gpt4OMini ||
model === SupportedModels.Claude3Sonnet ||
model === SupportedModels.Claude3Opus ||
model === SupportedModels.Claude35Sonnet
model === SupportedModels.Claude35Sonnet ||
model === SupportedModels.Gemini15Pro
);
}

export type SDKChoice = "OpenAI" | "Anthropic";
export type SDKChoice = "OpenAI" | "Anthropic" | "Google";

/**
 * Picks the SDK used to reach the provider of `model`.
 *
 * The decision is made purely on the identifier's prefix: identifiers starting
 * with "claude" map to "Anthropic", identifiers starting with "gemini" map to
 * "Google", and every other identifier falls through to "OpenAI".
 */
function chooseSDK(model: SupportedModels): SDKChoice {
  // Checked in order; the first matching prefix wins.
  const prefixToSdk: ReadonlyArray<readonly [string, SDKChoice]> = [
    ["claude", "Anthropic"],
    ["gemini", "Google"],
  ];
  for (const [prefix, sdk] of prefixToSdk) {
    if (model.startsWith(prefix)) {
      return sdk;
    }
  }
  return "OpenAI";
}

Expand All @@ -72,12 +79,16 @@ export function isOpenAIModel(model: SupportedModels) {
/** Reports whether `chooseSDK` routes `model` to the "Anthropic" SDK. */
export function isAnthropicModel(model: SupportedModels) {
  const sdk = chooseSDK(model);
  return sdk === "Anthropic";
}
/** Reports whether `chooseSDK` routes `model` to the "Google" SDK. */
export function isGoogleModel(model: SupportedModels) {
  const sdk = chooseSDK(model);
  return sdk === "Google";
}

export function isValidModelSettings(
selectedModel: string,
agentMode: AgentMode,
openAIKey: string | undefined,
anthropicKey: string | undefined,
geminiKey: string | undefined,
): boolean {
if (!isSupportedModel(selectedModel)) {
return false;
Expand All @@ -88,10 +99,13 @@ export function isValidModelSettings(
) {
return false;
}
if (openAIKey && !anthropicKey && !isOpenAIModel(selectedModel)) {
if (isOpenAIModel(selectedModel) && !openAIKey) {
return false;
}
if (isAnthropicModel(selectedModel) && !anthropicKey) {
return false;
}
if (!openAIKey && anthropicKey && !isAnthropicModel(selectedModel)) {
if (isGoogleModel(selectedModel) && !geminiKey) {
return false;
}
return true;
Expand All @@ -102,24 +116,29 @@ export function findBestMatchingModel(
agentMode: AgentMode,
openAIKey: string | undefined,
anthropicKey: string | undefined,
geminiKey: string | undefined,
): SupportedModels {
let result: SupportedModels = DEFAULT_MODEL;
// verify the string value is a supported model
// this is to handle the case when we drop support for a model
if (isSupportedModel(selectedModel)) {
result = selectedModel;
if (
isValidModelSettings(
selectedModel,
agentMode,
openAIKey,
anthropicKey,
geminiKey,
)
) {
return selectedModel as SupportedModels;
}
if (openAIKey) {
return SupportedModels.Gpt4Turbo;
}
// if agent mode is vision-enhanced, we need to ensure the model supports vision
if (agentMode === AgentMode.VisionEnhanced && !hasVisionSupport(result)) {
result = SupportedModels.Gpt4Turbo;
if (anthropicKey) {
return SupportedModels.Claude35Sonnet;
}
// ensure the provider's API key is available
if (!openAIKey && anthropicKey && !isAnthropicModel(result)) {
result = SupportedModels.Claude35Sonnet;
} else if (openAIKey && !anthropicKey && !isOpenAIModel(result)) {
result = SupportedModels.Gpt4O;
if (geminiKey) {
return SupportedModels.Gemini15Pro;
}
return result;
return DEFAULT_MODEL;
}

export type CommonMessageCreateParams = {
Expand Down Expand Up @@ -275,14 +294,53 @@ export async function fetchResponseFromModelAnthropic(
};
}

/**
 * Converts an image string into the `{ data, mimeType }` pair used for a
 * Gemini `inlineData` part.
 *
 * - For a data URL (`data:<mime>;base64,<payload>`), returns the payload after
 *   the `base64,` marker and the MIME type declared in the URL header.
 * - If the header declares no MIME type, "image/webp" is used.
 * - If the string contains no `base64,` marker, the whole string is returned
 *   as the payload (with "image/webp") instead of `undefined`.
 */
function toGeminiInlineData(imageData: string): {
  data: string;
  mimeType: string;
} {
  const marker = "base64,";
  const fallbackMimeType = "image/webp";
  const markerIndex = imageData.indexOf(marker);
  if (markerIndex === -1) {
    // No data-URL header present; treat the input as the payload itself.
    return { data: imageData, mimeType: fallbackMimeType };
  }
  const header = imageData.slice(0, markerIndex);
  const declaredMimeType = /^data:([^;,]+)/.exec(header)?.[1];
  return {
    data: imageData.slice(markerIndex + marker.length),
    mimeType: declaredMimeType ?? fallbackMimeType,
  };
}

/**
 * Sends a single prompt (plus an optional image) to a Gemini model and
 * returns the generated text together with token usage.
 *
 * @param model - Model identifier passed through to `getGenerativeModel`.
 * @param params - Request parameters; `systemMessage` becomes the system
 *   instruction, `prompt` is the first request part, and `imageData` (when not
 *   null/undefined) is attached as an inline image part.
 * @returns `rawResponse` holding the response text and `usage` holding the
 *   prompt/completion/total token counts (0 when usage metadata is absent).
 * @throws Error("No Google Gemini key found") when no `geminiKey` is stored in
 *   the app settings. Errors raised by the SDK call itself are not caught here.
 */
export async function fetchResponseFromModelGoogle(
  model: SupportedModels,
  params: CommonMessageCreateParams,
): Promise<Response> {
  const key = useAppState.getState().settings.geminiKey;
  if (!key) {
    throw new Error("No Google Gemini key found");
  }
  const genAI = new GoogleGenerativeAI(key);
  const client = genAI.getGenerativeModel({
    model,
    systemInstruction: params.systemMessage,
  });
  const requestInput: Array<string | Part> = [params.prompt];
  if (params.imageData != null) {
    requestInput.push({ inlineData: toGeminiInlineData(params.imageData) });
  }
  const result = await client.generateContent(requestInput);
  const usageMetadata = result.response.usageMetadata;
  return {
    usage: {
      completion_tokens: usageMetadata?.candidatesTokenCount ?? 0,
      prompt_tokens: usageMetadata?.promptTokenCount ?? 0,
      total_tokens: usageMetadata?.totalTokenCount ?? 0,
    },
    rawResponse: result.response.text(),
  };
}

/**
 * Routes a completion request to the provider-specific fetcher selected by
 * `chooseSDK(model)`: `fetchResponseFromModelOpenAI` for "OpenAI",
 * `fetchResponseFromModelAnthropic` for "Anthropic", and
 * `fetchResponseFromModelGoogle` for "Google".
 *
 * @throws Error("Unsupported model") when the selected SDK matches none of the
 *   three branches.
 *
 * NOTE(review): this text is a rendered diff without +/- markers. The bare
 * `} else {` line right after the OpenAI branch appears to be the pre-commit
 * line that `} else if (sdk === "Anthropic") {` replaces — confirm against the
 * actual file before relying on this block as-is.
 */
export async function fetchResponseFromModel(
model: SupportedModels,
params: CommonMessageCreateParams,
): Promise<Response> {
// The SDK choice is derived from the model identifier alone.
const sdk = chooseSDK(model);
if (sdk === "OpenAI") {
return await fetchResponseFromModelOpenAI(model, params);
} else {
} else if (sdk === "Anthropic") {
return await fetchResponseFromModelAnthropic(model, params);
} else if (sdk === "Google") {
return await fetchResponseFromModelGoogle(model, params);
} else {
throw new Error("Unsupported model");
}
}
1 change: 1 addition & 0 deletions src/state/currentTask.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ export const createCurrentTaskSlice: MyStateCreator<CurrentTaskSlice> = (
get().settings.agentMode,
get().settings.openAIKey,
get().settings.anthropicKey,
get().settings.geminiKey,
)
) {
onError(
Expand Down
3 changes: 3 additions & 0 deletions src/state/settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ export type SettingsSlice = {
anthropicKey: string | undefined;
openAIBaseUrl: string | undefined;
anthropicBaseUrl: string | undefined;
geminiKey: string | undefined;
selectedModel: SupportedModels;
agentMode: AgentMode;
voiceMode: boolean;
Expand All @@ -24,6 +25,7 @@ export const createSettingsSlice: MyStateCreator<SettingsSlice> = (set) => ({
anthropicKey: undefined,
openAIBaseUrl: undefined,
anthropicBaseUrl: undefined,
geminiKey: undefined,
agentMode: AgentMode.VisionEnhanced,
selectedModel: SupportedModels.Gpt4Turbo,
voiceMode: false,
Expand All @@ -37,6 +39,7 @@ export const createSettingsSlice: MyStateCreator<SettingsSlice> = (set) => ({
newSettings.agentMode,
newSettings.openAIKey,
newSettings.anthropicKey,
newSettings.geminiKey,
);
// voice model currently relies on the OpenAI API key
if (!newSettings.openAIKey) {
Expand Down
2 changes: 2 additions & 0 deletions src/state/store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ export const useAppState = create<StoreType>()(
settings: {
openAIKey: state.settings.openAIKey,
anthropicKey: state.settings.anthropicKey,
geminiKey: state.settings.geminiKey,
openAIBaseUrl: state.settings.openAIBaseUrl,
anthropicBaseUrl: state.settings.anthropicBaseUrl,
agentMode: state.settings.agentMode,
Expand All @@ -55,6 +56,7 @@ export const useAppState = create<StoreType>()(
result.settings.agentMode,
result.settings.openAIKey,
result.settings.anthropicKey,
result.settings.geminiKey,
);
return result;
},
Expand Down

0 comments on commit d08f402

Please sign in to comment.