Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for Gemini 1.5 Pro #219

Merged
merged 1 commit into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"@chakra-ui/react": "^2.8.2",
"@emotion/react": "^11.11.4",
"@emotion/styled": "^11.11.5",
"@google/generative-ai": "^0.19.0",
"accname": "^1.1.0",
"construct-style-sheets-polyfill": "3.1.0",
"formik": "^2.4.5",
Expand Down
8 changes: 8 additions & 0 deletions pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions src/common/Settings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ const Settings = ({ setInSettingsView }: SettingsProps) => {
voiceMode: state.settings.voiceMode,
openAIKey: state.settings.openAIKey,
anthropicKey: state.settings.anthropicKey,
geminiKey: state.settings.geminiKey,
}));
const toast = useToast();

Expand Down Expand Up @@ -182,6 +183,7 @@ const Settings = ({ setInSettingsView }: SettingsProps) => {
state.agentMode,
state.openAIKey,
state.anthropicKey,
state.geminiKey,
) ? (
<Alert status="error">
<AlertIcon />
Expand Down
11 changes: 9 additions & 2 deletions src/common/settings/ModelDropdown.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ const ModelDropdown = () => {
updateSettings: state.settings.actions.update,
}));

const { openAIKey, anthropicKey } = useAppState((state) => ({
const { openAIKey, anthropicKey, geminiKey } = useAppState((state) => ({
openAIKey: state.settings.openAIKey,
anthropicKey: state.settings.anthropicKey,
geminiKey: state.settings.geminiKey,
}));

return (
Expand All @@ -32,7 +33,13 @@ const ModelDropdown = () => {
key={model}
value={model}
disabled={
!isValidModelSettings(model, agentMode, openAIKey, anthropicKey)
!isValidModelSettings(
model,
agentMode,
openAIKey,
anthropicKey,
geminiKey,
)
}
>
{DisplayName[model]}
Expand Down
33 changes: 32 additions & 1 deletion src/common/settings/SetAPIKey.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ type SetAPIKeyProps = {
asInitializerView?: boolean;
initialOpenAIKey?: string;
initialAnthropicKey?: string;
initialGeminiKey?: string;
onClose?: () => void;
};

const SetAPIKey = ({
asInitializerView = false,
initialOpenAIKey = "",
initialAnthropicKey = "",
initialGeminiKey = "",
onClose,
}: SetAPIKeyProps) => {
const { updateSettings, initialOpenAIBaseUrl, initialAnthropicBaseUrl } =
Expand All @@ -38,6 +40,7 @@ const SetAPIKey = ({
const [anthropicKey, setAnthropicKey] = React.useState(
initialAnthropicKey || "",
);
const [geminiKey, setGeminiKey] = React.useState(initialGeminiKey || "");
const [openAIBaseUrl, setOpenAIBaseUrl] = React.useState(
initialOpenAIBaseUrl || "",
);
Expand All @@ -53,6 +56,7 @@ const SetAPIKey = ({
openAIBaseUrl,
anthropicKey,
anthropicBaseUrl,
geminiKey,
});
onClose && onClose();
};
Expand Down Expand Up @@ -157,10 +161,37 @@ const SetAPIKey = ({
/>
</FormControl>
)}

<Box position="relative" py={2} w="full">
<Divider />
<AbsoluteCenter bg="white" px="4">
Gemini (Google)
</AbsoluteCenter>
</Box>
<FormControl>
<FormLabel>Gemini API Key</FormLabel>
<HStack w="full">
<Input
placeholder="Enter Gemini API Key"
value={geminiKey}
onChange={(event) => setGeminiKey(event.target.value)}
type={showPassword ? "text" : "password"}
/>
{asInitializerView && (
<Button
onClick={() => setShowPassword(!showPassword)}
variant="outline"
>
{showPassword ? "Hide" : "Show"}
</Button>
)}
</HStack>
</FormControl>

<Button
onClick={onSave}
w="full"
isDisabled={!openAIKey && !anthropicKey}
isDisabled={!openAIKey && !anthropicKey && !geminiKey}
colorScheme="blue"
>
Save
Expand Down
96 changes: 77 additions & 19 deletions src/helpers/aiSdkUtils.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import Anthropic from "@anthropic-ai/sdk";
import { GoogleGenerativeAI, type Part } from "@google/generative-ai";
import OpenAI from "openai";
import { useAppState } from "../state/store";
import { enumValues } from "./utils";
Expand All @@ -22,6 +23,7 @@ export enum SupportedModels {
Claude3Sonnet = "claude-3-sonnet-20240229",
Claude3Opus = "claude-3-opus-20240229",
Claude35Sonnet = "claude-3-5-sonnet-20240620",
Gemini15Pro = "gemini-1.5-pro",
}

function isSupportedModel(value: string): value is SupportedModels {
Expand All @@ -43,6 +45,7 @@ export const DisplayName = {
[SupportedModels.Claude3Sonnet]: "Claude 3 Sonnet",
[SupportedModels.Claude3Opus]: "Claude 3 Opus",
[SupportedModels.Claude35Sonnet]: "Claude 3.5 Sonnet",
[SupportedModels.Gemini15Pro]: "Gemini 1.5 Pro",
};

export function hasVisionSupport(model: SupportedModels) {
Expand All @@ -53,16 +56,20 @@ export function hasVisionSupport(model: SupportedModels) {
model === SupportedModels.Gpt4OMini ||
model === SupportedModels.Claude3Sonnet ||
model === SupportedModels.Claude3Opus ||
model === SupportedModels.Claude35Sonnet
model === SupportedModels.Claude35Sonnet ||
model === SupportedModels.Gemini15Pro
);
}

export type SDKChoice = "OpenAI" | "Anthropic";
export type SDKChoice = "OpenAI" | "Anthropic" | "Google";

/**
 * Picks the provider SDK for a model by looking at the prefix of its id.
 *
 * Ids starting with "claude" go to Anthropic, ids starting with "gemini" go
 * to Google, and every other id falls back to OpenAI.
 *
 * @param model - the model identifier to route
 * @returns the SDK that should serve the model
 */
function chooseSDK(model: SupportedModels): SDKChoice {
  // Checked in order; the first matching prefix wins.
  const prefixRoutes: ReadonlyArray<readonly [string, SDKChoice]> = [
    ["claude", "Anthropic"],
    ["gemini", "Google"],
  ];
  const matched = prefixRoutes.find(([prefix]) => model.startsWith(prefix));
  return matched === undefined ? "OpenAI" : matched[1];
}

Expand All @@ -72,12 +79,16 @@ export function isOpenAIModel(model: SupportedModels) {
/**
 * Tells whether the given model is routed to the Anthropic SDK.
 *
 * @param model - the model identifier to check
 * @returns true when `chooseSDK` selects "Anthropic" for the model
 */
export function isAnthropicModel(model: SupportedModels) {
  const sdk = chooseSDK(model);
  return sdk === "Anthropic";
}
/**
 * Tells whether the given model is routed to the Google SDK.
 *
 * @param model - the model identifier to check
 * @returns true when `chooseSDK` selects "Google" for the model
 */
export function isGoogleModel(model: SupportedModels) {
  const sdk = chooseSDK(model);
  return sdk === "Google";
}

export function isValidModelSettings(
selectedModel: string,
agentMode: AgentMode,
openAIKey: string | undefined,
anthropicKey: string | undefined,
geminiKey: string | undefined,
): boolean {
if (!isSupportedModel(selectedModel)) {
return false;
Expand All @@ -88,10 +99,13 @@ export function isValidModelSettings(
) {
return false;
}
if (openAIKey && !anthropicKey && !isOpenAIModel(selectedModel)) {
if (isOpenAIModel(selectedModel) && !openAIKey) {
return false;
}
if (isAnthropicModel(selectedModel) && !anthropicKey) {
return false;
}
if (!openAIKey && anthropicKey && !isAnthropicModel(selectedModel)) {
if (isGoogleModel(selectedModel) && !geminiKey) {
return false;
}
return true;
Expand All @@ -102,24 +116,29 @@ export function findBestMatchingModel(
agentMode: AgentMode,
openAIKey: string | undefined,
anthropicKey: string | undefined,
geminiKey: string | undefined,
): SupportedModels {
let result: SupportedModels = DEFAULT_MODEL;
// verify the string value is a supported model
// this is to handle the case when we drop support for a model
if (isSupportedModel(selectedModel)) {
result = selectedModel;
if (
isValidModelSettings(
selectedModel,
agentMode,
openAIKey,
anthropicKey,
geminiKey,
)
) {
return selectedModel as SupportedModels;
}
if (openAIKey) {
return SupportedModels.Gpt4Turbo;
}
// if agent mode is vision-enhanced, we need to ensure the model supports vision
if (agentMode === AgentMode.VisionEnhanced && !hasVisionSupport(result)) {
result = SupportedModels.Gpt4Turbo;
if (anthropicKey) {
return SupportedModels.Claude35Sonnet;
}
// ensure the provider's API key is available
if (!openAIKey && anthropicKey && !isAnthropicModel(result)) {
result = SupportedModels.Claude35Sonnet;
} else if (openAIKey && !anthropicKey && !isOpenAIModel(result)) {
result = SupportedModels.Gpt4O;
if (geminiKey) {
return SupportedModels.Gemini15Pro;
}
return result;
return DEFAULT_MODEL;
}

export type CommonMessageCreateParams = {
Expand Down Expand Up @@ -275,14 +294,53 @@ export async function fetchResponseFromModelAnthropic(
};
}

/**
 * Splits a screenshot string into the base64 payload and MIME type used for
 * a Gemini `inlineData` part.
 *
 * A data URL such as `data:image/png;base64,AAAA` yields the payload after
 * the `base64,` marker and the MIME type declared in its header. When the
 * marker is absent the whole string is treated as the payload; when no MIME
 * type can be read, "image/webp" is used (the value this code previously
 * hard-coded).
 */
function toGeminiInlineData(imageData: string): {
  data: string;
  mimeType: string;
} {
  const fallbackMimeType = "image/webp";
  const marker = "base64,";
  const markerIndex = imageData.indexOf(marker);
  if (markerIndex === -1) {
    // No data-URL prefix: avoid sending `undefined` as the payload.
    return { data: imageData, mimeType: fallbackMimeType };
  }
  const header = imageData.slice(0, markerIndex);
  const declaredMimeType = /^data:([^;,]+)/.exec(header)?.[1];
  return {
    data: imageData.slice(markerIndex + marker.length),
    mimeType: declaredMimeType ?? fallbackMimeType,
  };
}

/**
 * Sends a prompt (and an optional screenshot) to a Gemini model and returns
 * the text answer together with token usage.
 *
 * The API key is read from `settings.geminiKey` in the app state. The system
 * message is passed as the model's `systemInstruction`; the prompt and, when
 * present, the image are sent as the content parts of a single request.
 *
 * @param model - the Gemini model identifier to query
 * @param params - system message, prompt and optional image data
 * @returns the response text as `rawResponse` plus token counts, each
 *   defaulting to 0 when the API response carries no usage metadata
 * @throws Error when no Gemini API key is configured
 */
export async function fetchResponseFromModelGoogle(
  model: SupportedModels,
  params: CommonMessageCreateParams,
): Promise<Response> {
  const key = useAppState.getState().settings.geminiKey;
  if (!key) {
    throw new Error("No Google Gemini key found");
  }
  const genAI = new GoogleGenerativeAI(key);
  const client = genAI.getGenerativeModel({
    model: model,
    systemInstruction: params.systemMessage,
  });
  const requestInput: Array<string | Part> = [params.prompt];
  if (params.imageData != null) {
    requestInput.push({ inlineData: toGeminiInlineData(params.imageData) });
  }
  const result = await client.generateContent(requestInput);
  const usageMetadata = result.response.usageMetadata;
  return {
    usage: {
      completion_tokens: usageMetadata?.candidatesTokenCount ?? 0,
      prompt_tokens: usageMetadata?.promptTokenCount ?? 0,
      total_tokens: usageMetadata?.totalTokenCount ?? 0,
    },
    rawResponse: result.response.text(),
  };
}

export async function fetchResponseFromModel(
model: SupportedModels,
params: CommonMessageCreateParams,
): Promise<Response> {
const sdk = chooseSDK(model);
if (sdk === "OpenAI") {
return await fetchResponseFromModelOpenAI(model, params);
} else {
} else if (sdk === "Anthropic") {
return await fetchResponseFromModelAnthropic(model, params);
} else if (sdk === "Google") {
return await fetchResponseFromModelGoogle(model, params);
} else {
throw new Error("Unsupported model");
}
}
1 change: 1 addition & 0 deletions src/state/currentTask.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ export const createCurrentTaskSlice: MyStateCreator<CurrentTaskSlice> = (
get().settings.agentMode,
get().settings.openAIKey,
get().settings.anthropicKey,
get().settings.geminiKey,
)
) {
onError(
Expand Down
3 changes: 3 additions & 0 deletions src/state/settings.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ export type SettingsSlice = {
anthropicKey: string | undefined;
openAIBaseUrl: string | undefined;
anthropicBaseUrl: string | undefined;
geminiKey: string | undefined;
selectedModel: SupportedModels;
agentMode: AgentMode;
voiceMode: boolean;
Expand All @@ -24,6 +25,7 @@ export const createSettingsSlice: MyStateCreator<SettingsSlice> = (set) => ({
anthropicKey: undefined,
openAIBaseUrl: undefined,
anthropicBaseUrl: undefined,
geminiKey: undefined,
agentMode: AgentMode.VisionEnhanced,
selectedModel: SupportedModels.Gpt4Turbo,
voiceMode: false,
Expand All @@ -37,6 +39,7 @@ export const createSettingsSlice: MyStateCreator<SettingsSlice> = (set) => ({
newSettings.agentMode,
newSettings.openAIKey,
newSettings.anthropicKey,
newSettings.geminiKey,
);
// voice model currently relies on OpenAI API key
if (!newSettings.openAIKey) {
Expand Down
2 changes: 2 additions & 0 deletions src/state/store.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ export const useAppState = create<StoreType>()(
settings: {
openAIKey: state.settings.openAIKey,
anthropicKey: state.settings.anthropicKey,
geminiKey: state.settings.geminiKey,
openAIBaseUrl: state.settings.openAIBaseUrl,
anthropicBaseUrl: state.settings.anthropicBaseUrl,
agentMode: state.settings.agentMode,
Expand All @@ -55,6 +56,7 @@ export const useAppState = create<StoreType>()(
result.settings.agentMode,
result.settings.openAIKey,
result.settings.anthropicKey,
result.settings.geminiKey,
);
return result;
},
Expand Down
Loading