Skip to content

Commit

Permalink
feat: add support for NVIDIA inference for ElizaOS (#2512)
Browse files Browse the repository at this point in the history
* Adding support for NVIDIA inference for ElizaOS

* removed wrong image generation

* Fixed wrong indentation

---------

Co-authored-by: Sayo <hi@sayo.wtf>
  • Loading branch information
AIFlowML and wtfsayo authored Jan 19, 2025
1 parent 3206ef4 commit a5dccdb
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 4 deletions.
10 changes: 8 additions & 2 deletions agent/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -546,15 +546,20 @@ export function getTokenForProvider(
character.settings?.secrets?.HYPERBOLIC_API_KEY ||
settings.HYPERBOLIC_API_KEY
);

case ModelProviderName.VENICE:
return (
character.settings?.secrets?.VENICE_API_KEY ||
settings.VENICE_API_KEY
);
// Atoma authenticates with a bearer token rather than a plain API key.
case ModelProviderName.ATOMA:
    return (
        character.settings?.secrets?.ATOMASDK_BEARER_AUTH ||
        settings.ATOMASDK_BEARER_AUTH
    );
// NVIDIA Integrate API key; character-level secret takes precedence
// over the globally configured setting.
case ModelProviderName.NVIDIA:
    return (
        character.settings?.secrets?.NVIDIA_API_KEY ||
        settings.NVIDIA_API_KEY
    );
case ModelProviderName.AKASH_CHAT_API:
return (
Expand Down Expand Up @@ -918,6 +923,7 @@ export async function createAgent(
getSecret(character, "FAL_API_KEY") ||
getSecret(character, "OPENAI_API_KEY") ||
getSecret(character, "VENICE_API_KEY") ||
getSecret(character, "NVIDIA_API_KEY") ||
getSecret(character, "NINETEEN_AI_API_KEY") ||
getSecret(character, "HEURIST_API_KEY") ||
getSecret(character, "LIVEPEER_GATEWAY_URL")
Expand Down
30 changes: 28 additions & 2 deletions packages/core/src/generation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1157,6 +1157,32 @@ export async function generateText({
break;
}

case ModelProviderName.NVIDIA: {
    // NVIDIA's Integrate API is OpenAI-compatible, so we reuse the
    // OpenAI client factory pointed at the NVIDIA endpoint.
    elizaLogger.debug("Initializing NVIDIA model.");
    const nvidiaClient = createOpenAI({
        apiKey,
        baseURL: endpoint,
    });

    const result = await aiGenerateText({
        model: nvidiaClient.languageModel(model),
        prompt: context,
        // Character-specific system prompt wins over the global setting.
        system:
            runtime.character.system ??
            settings.SYSTEM_PROMPT ??
            undefined,
        tools,
        onStepFinish,
        temperature,
        maxSteps,
        maxTokens: max_response_length,
    });

    response = result.text;
    elizaLogger.debug("Received response from NVIDIA model.");
    break;
}

case ModelProviderName.DEEPSEEK: {
elizaLogger.debug("Initializing Deepseek model.");
const serverUrl = models[provider].endpoint;
Expand Down Expand Up @@ -1615,8 +1641,8 @@ export const generateImage = async (
return runtime.getSetting("FAL_API_KEY");
case ModelProviderName.OPENAI:
return runtime.getSetting("OPENAI_API_KEY");
// Single VENICE clause — the scraped diff showed this case duplicated
// (removed + re-indented added lines fused); the second copy was dead code.
case ModelProviderName.VENICE:
    return runtime.getSetting("VENICE_API_KEY");
case ModelProviderName.LIVEPEER:
return runtime.getSetting("LIVEPEER_GATEWAY_URL");
default:
Expand Down
26 changes: 26 additions & 0 deletions packages/core/src/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,32 @@ export const models: Models = {
},
},
},
// NVIDIA Integrate API (OpenAI-compatible endpoint). The model name for
// each size tier can be overridden via the corresponding
// SMALL_/MEDIUM_/LARGE_NVIDIA_MODEL setting; otherwise a Llama default
// is used. All tiers share the same token limits and temperature.
[ModelProviderName.NVIDIA]: {
    endpoint: "https://integrate.api.nvidia.com/v1",
    model: {
        [ModelClass.SMALL]: {
            name: settings.SMALL_NVIDIA_MODEL || "meta/llama-3.2-3b-instruct",
            stop: [],
            maxInputTokens: 128000,
            maxOutputTokens: 8192,
            temperature: 0.6,
        },
        [ModelClass.MEDIUM]: {
            name: settings.MEDIUM_NVIDIA_MODEL || "meta/llama-3.3-70b-instruct",
            stop: [],
            maxInputTokens: 128000,
            maxOutputTokens: 8192,
            temperature: 0.6,
        },
        [ModelClass.LARGE]: {
            name: settings.LARGE_NVIDIA_MODEL || "meta/llama-3.1-405b-instruct",
            stop: [],
            maxInputTokens: 128000,
            maxOutputTokens: 8192,
            temperature: 0.6,
        },
    },
},
[ModelProviderName.NINETEEN_AI]: {
endpoint: "https://api.nineteen.ai/v1",
model: {
Expand Down
2 changes: 2 additions & 0 deletions packages/core/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ export type Models = {
[ModelProviderName.NANOGPT]: Model;
[ModelProviderName.HYPERBOLIC]: Model;
[ModelProviderName.VENICE]: Model;
[ModelProviderName.NVIDIA]: Model;
[ModelProviderName.NINETEEN_AI]: Model;
[ModelProviderName.AKASH_CHAT_API]: Model;
[ModelProviderName.LIVEPEER]: Model;
Expand Down Expand Up @@ -259,6 +260,7 @@ export enum ModelProviderName {
NANOGPT = "nanogpt",
HYPERBOLIC = "hyperbolic",
VENICE = "venice",
NVIDIA = "nvidia",
NINETEEN_AI = "nineteen_ai",
AKASH_CHAT_API = "akash_chat_api",
LIVEPEER = "livepeer",
Expand Down

0 comments on commit a5dccdb

Please sign in to comment.