(feat) NVIDIA connectors update (#17672)
raspawar authored Feb 14, 2025
1 parent 93ba0d5 commit ddf5a48
Showing 22 changed files with 6,369 additions and 425 deletions.
@@ -2,48 +2,35 @@

from typing import Any, List, Literal, Optional
import warnings
import os

from llama_index.core.base.embeddings.base import (
DEFAULT_EMBED_BATCH_SIZE,
BaseEmbedding,
)
from llama_index.core.bridge.pydantic import Field, PrivateAttr, BaseModel
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.base.llms.generic_utils import get_from_param_or_env

from openai import OpenAI, AsyncOpenAI
from urllib.parse import urlparse

# integrate.api.nvidia.com is the default url for most models; any
# bespoke endpoints will need to be added to the MODEL_ENDPOINT_MAP
BASE_URL = "https://integrate.api.nvidia.com/v1/"
DEFAULT_MODEL = "nvidia/nv-embedqa-e5-v5"

# because MODEL_ENDPOINT_MAP is used to construct KNOWN_URLS, we need to
# include at least one model w/ https://integrate.api.nvidia.com/v1/
MODEL_ENDPOINT_MAP = {
"NV-Embed-QA": "https://ai.api.nvidia.com/v1/retrieval/nvidia/",
"snowflake/arctic-embed-l": "https://integrate.api.nvidia.com/v1/",
"nvidia/nv-embed-v1": "https://integrate.api.nvidia.com/v1/",
"nvidia/nv-embedqa-mistral-7b-v2": "https://integrate.api.nvidia.com/v1/",
"nvidia/nv-embedqa-e5-v5": "https://integrate.api.nvidia.com/v1/",
"baai/bge-m3": "https://integrate.api.nvidia.com/v1/",
"nvidia/llama-3.2-nv-embedqa-1b-v1": "https://integrate.api.nvidia.com/v1/",
"nvidia/llama-3.2-nv-embedqa-1b-v2": "https://integrate.api.nvidia.com/v1/",
}

KNOWN_URLS = list(MODEL_ENDPOINT_MAP.values())
KNOWN_URLS.append("https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l")


class Model(BaseModel):
id: str
base_model: Optional[str] = None
from urllib.parse import urlparse, urlunparse
from .utils import (
EMBEDDING_MODEL_TABLE,
BASE_URL,
KNOWN_URLS,
DEFAULT_MODEL,
Model,
determine_model,
)


class NVIDIAEmbedding(BaseEmbedding):
"""NVIDIA embeddings."""

base_url: str = Field(
default_factory=lambda: os.getenv("NVIDIA_BASE_URL", BASE_URL),
description="Base url for model listing an invocation",
)
model: Optional[str] = Field(
description="Name of the NVIDIA embedding model to use.\n"
)
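
The base_url field now defaults from the NVIDIA_BASE_URL environment variable before falling back to the catalog default. A minimal sketch of pointing the connector at a self-hosted NIM this way; the endpoint is illustrative and assumes a NIM is actually reachable at that address:

import os
from llama_index.embeddings.nvidia import NVIDIAEmbedding

# Illustrative endpoint; the default_factory reads the variable at construction time.
os.environ["NVIDIA_BASE_URL"] = "http://localhost:8000/v1"

embedder = NVIDIAEmbedding(model="nvidia/nv-embedqa-e5-v5")
print(embedder.base_url)  # http://localhost:8000/v1
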
@@ -86,7 +73,6 @@ def __init__(
dimensions: Optional[int] = 0,
nvidia_api_key: Optional[str] = None,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE, # This could default to 50
callback_manager: Optional[CallbackManager] = None,
**kwargs: Any,
@@ -133,27 +119,27 @@ def __init__(
"NO_API_KEY_PROVIDED",
)

base_url = base_url or BASE_URL
self._is_hosted = base_url in KNOWN_URLS
self._is_hosted = self.base_url in KNOWN_URLS
if not self._is_hosted:
self.base_url = self._validate_url(self.base_url)

if self._is_hosted: # hosted on API Catalog (build.nvidia.com)
if api_key == "NO_API_KEY_PROVIDED":
raise ValueError("An API key is required for hosted NIM.")
# TODO: we should not assume unknown models are at the base url
base_url = MODEL_ENDPOINT_MAP.get(model, BASE_URL)
else: # not hosted
base_url = self._validate_url(base_url)
self.base_url = self._validate_url(self.base_url)

self._client = OpenAI(
api_key=api_key,
base_url=base_url,
base_url=self.base_url,
timeout=timeout,
max_retries=max_retries,
)
self._client._custom_headers = {"User-Agent": "llama-index-embeddings-nvidia"}

self._aclient = AsyncOpenAI(
api_key=api_key,
base_url=base_url,
base_url=self.base_url,
timeout=timeout,
max_retries=max_retries,
)
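
Both OpenAI clients are now constructed against self.base_url, so hosted and self-hosted modes differ only in how that field resolves. A hedged sketch of the two modes, assuming the pydantic field accepts a constructor override; key and endpoint values are placeholders:

from llama_index.embeddings.nvidia import NVIDIAEmbedding

# Hosted on the API catalog: the default base_url is in KNOWN_URLS,
# so an API key is mandatory.
hosted = NVIDIAEmbedding(
    model="nvidia/nv-embedqa-e5-v5",
    api_key="nvapi-...",  # placeholder
)

# Self-hosted NIM: an unknown base_url is validated and normalized to /v1,
# and no API key is required.
local = NVIDIAEmbedding(
    model="nvidia/nv-embedqa-e5-v5",
    base_url="http://localhost:8080/v1",  # placeholder
)
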
@@ -192,44 +178,57 @@ def __get_default_model(self) -> None:

def _validate_url(self, base_url):
"""
Base URL validation.
ValueError: URLs that do not have a valid scheme and netloc.
Warning: v1/embeddings routes.
ValueError: any routes other than the above.
Validate the base_url.
If the base_url is not a URL, raise an error.
If the base_url does not end in /v1 (e.g., it ends in /embeddings),
emit a warning: old documentation told users to pass in the full
inference URL, which is incorrect and prevents model listing from working.
Normalize the base_url to end in /v1.
"""
expected_format = "Expected format is 'http://host:port'."
result = urlparse(base_url)
if not (result.scheme and result.netloc):
raise ValueError(f"Invalid base_url, {expected_format}")
if base_url.endswith("embeddings"):
warnings.warn(f"{expected_format} Rest is ignored")
return base_url.strip("/")
if base_url is not None:
parsed = urlparse(base_url)

# Ensure scheme and netloc (domain name) are present
if not (parsed.scheme and parsed.netloc):
expected_format = "Expected format is: http://host:port"
raise ValueError(
f"Invalid base_url format. {expected_format} Got: {base_url}"
)

normalized_path = parsed.path.rstrip("/")
if not normalized_path.endswith("/v1"):
warnings.warn(
f"{base_url} does not end in /v1, you may "
"have inference and listing issues"
)
normalized_path += "/v1"

base_url = urlunparse(
(parsed.scheme, parsed.netloc, normalized_path, None, None, None)
)
return base_url
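
The normalization above can be exercised in isolation; a standalone mirror of the logic, for illustration only (normalize_base_url is not part of the package):

import warnings
from urllib.parse import urlparse, urlunparse

def normalize_base_url(base_url: str) -> str:
    """Standalone mirror of _validate_url's normalization, for illustration."""
    parsed = urlparse(base_url)
    if not (parsed.scheme and parsed.netloc):
        raise ValueError(
            f"Invalid base_url format. Expected format is: http://host:port Got: {base_url}"
        )
    path = parsed.path.rstrip("/")
    if not path.endswith("/v1"):
        warnings.warn(
            f"{base_url} does not end in /v1, you may have inference and listing issues"
        )
        path += "/v1"
    return urlunparse((parsed.scheme, parsed.netloc, path, None, None, None))

assert normalize_base_url("http://localhost:8080") == "http://localhost:8080/v1"    # warns
assert normalize_base_url("http://localhost:8080/v1/") == "http://localhost:8080/v1"
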

def _validate_model(self, model_name: str) -> None:
"""
Validates compatibility of the hosted model with the client.
Client validation is skipped for non-catalog requests.
Args:
model_name (str): The name of the model.
Raises:
ValueError: If the model is incompatible with the client.
"""
model = determine_model(model_name)
if self._is_hosted:
if model_name not in MODEL_ENDPOINT_MAP:
if model_name in [model.id for model in self._client.models.list()]:
warnings.warn(f"Unable to determine validity of {model_name}")
else:
raise ValueError(
f"Model {model_name} is incompatible with client {self.class_name()}. "
f"Please check `{self.class_name()}.available_models()`."
)
else:
if model_name not in [model.id for model in self.available_models]:
raise ValueError(f"No locally hosted {model_name} was found.")
if not model:
warnings.warn(f"Unable to determine validity of {model_name}")
if model and model.endpoint:
self.base_url = model.endpoint
# TODO: handle locally hosted models
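
Validation now keys off the static model table: unknown names only warn, and entries that carry a bespoke endpoint override base_url. An illustrative sketch, assuming the constructor still invokes _validate_model for hosted models as before; the key is a placeholder:

emb = NVIDIAEmbedding(model="NV-Embed-QA", api_key="nvapi-...")  # placeholder key
# The table entry for NV-Embed-QA carries its own endpoint, so:
# emb.base_url == "https://ai.api.nvidia.com/v1/retrieval/nvidia"
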

@property
def available_models(self) -> List[Model]:
def available_models(self) -> List[str]:
"""Get available models."""
# TODO: hosted now has a model listing, need to merge known and listed models
if not self._is_hosted:
@@ -241,7 +240,7 @@ def available_models(self) -> List[Model]:
for model in self._client.models.list()
]
else:
return [Model(id=id) for id in MODEL_ENDPOINT_MAP]
return [Model(id=id) for id in EMBEDDING_MODEL_TABLE]
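
A quick way to enumerate accepted models in hosted mode; the key is a placeholder, and both branches yield Model objects:

emb = NVIDIAEmbedding(api_key="nvapi-...")  # placeholder key
for m in emb.available_models:
    print(m.id)
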

@classmethod
def class_name(cls) -> str:
@@ -0,0 +1,128 @@
from dataclasses import dataclass
from typing import Optional
import warnings

# integrate.api.nvidia.com is the default url for most models; any
# bespoke endpoints will need to be added to the EMBEDDING_MODEL_TABLE
BASE_URL = "https://integrate.api.nvidia.com/v1"
DEFAULT_MODEL = "nvidia/nv-embedqa-e5-v5"


@dataclass
class Model:
"""
Model information.
id: unique identifier for the model, passed as model parameter for requests
model_type: API type (chat, vlm, embedding, ranking, completions)
client: client name, e.g. NvidiaGenerator, NVIDIAEmbeddings,
NVIDIARerank, NvidiaTextEmbedder, NvidiaDocumentEmbedder
endpoint: custom endpoint for the model
aliases: list of aliases for the model
All aliases are deprecated and will trigger a warning when used.
"""

id: str
model_type: Optional[str] = "embedding"
client: str = "NVIDIAEmbedding"
endpoint: Optional[str] = None
aliases: Optional[list] = None
base_model: Optional[str] = None
supports_tools: Optional[bool] = False
supports_structured_output: Optional[bool] = False

def __hash__(self) -> int:
return hash(self.id)

def validate(self):
if self.client:
supported = {"NVIDIAEmbedding": ("embedding",)}
model_type = self.model_type
if model_type not in supported[self.client]:
err_msg = (
f"Model type '{model_type}' not supported by client '{self.client}'"
)
raise ValueError(err_msg)

return hash(self.id)
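
Model hashes on its id, and validate() enforces the client/model_type pairing; a minimal sketch:

m = Model(id="nvidia/nv-embedqa-e5-v5")
m.validate()  # passes: "embedding" is supported by the default "NVIDIAEmbedding" client

bad = Model(id="some/chat-model", model_type="chat")  # hypothetical entry
try:
    bad.validate()
except ValueError as e:
    print(e)  # Model type 'chat' not supported by client 'NVIDIAEmbedding'
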


# because EMBEDDING_MODEL_TABLE is used to construct KNOWN_URLS, we need to
# include at least one model w/ https://integrate.api.nvidia.com/v1
EMBEDDING_MODEL_TABLE = {
"snowflake/arctic-embed-l": Model(
id="snowflake/arctic-embed-l",
model_type="embedding",
aliases=["ai-arctic-embed-l"],
),
"NV-Embed-QA": Model(
id="NV-Embed-QA",
model_type="embedding",
endpoint="https://ai.api.nvidia.com/v1/retrieval/nvidia",
aliases=[
"ai-embed-qa-4",
"playground_nvolveqa_40k",
"nvolveqa_40k",
],
),
"nvidia/nv-embed-v1": Model(
id="nvidia/nv-embed-v1",
model_type="embedding",
aliases=["ai-nv-embed-v1"],
),
"nvidia/nv-embedqa-mistral-7b-v2": Model(
id="nvidia/nv-embedqa-mistral-7b-v2",
model_type="embedding",
),
"nvidia/nv-embedqa-e5-v5": Model(
id="nvidia/nv-embedqa-e5-v5",
model_type="embedding",
),
"baai/bge-m3": Model(
id="baai/bge-m3",
model_type="embedding",
),
}
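
Because this table drives validation and endpoint resolution, a bespoke deployment could be registered with one more entry; the model id and endpoint below are hypothetical:

EMBEDDING_MODEL_TABLE["my-org/custom-embed"] = Model(  # hypothetical entry
    id="my-org/custom-embed",
    model_type="embedding",
    endpoint="https://example.com/v1/retrieval/my-org",  # hypothetical endpoint
)
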


def lookup_model(name: str) -> Optional[Model]:
"""
Lookup a model by name, using only the table of known models.
The name is either:
- directly in the table
- an alias in the table
- not found (None)
Callers can check to see if the name was an alias by
comparing the result's id field to the name they provided.
"""
if not (model := EMBEDDING_MODEL_TABLE.get(name)):
for mdl in EMBEDDING_MODEL_TABLE.values():
if mdl.aliases and name in mdl.aliases:
model = mdl
break
return model
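
Direct ids, aliases, and unknown names all resolve as the docstring describes; for example:

assert lookup_model("snowflake/arctic-embed-l").id == "snowflake/arctic-embed-l"
assert lookup_model("ai-arctic-embed-l").id == "snowflake/arctic-embed-l"  # alias
assert lookup_model("no-such-model") is None
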


def determine_model(name: str) -> Optional[Model]:
"""
Determine the model to use based on a name, using
only the table of known models.
Raise a warning if the model is found to be
an alias of a known model.
If the model is not found, return None.
"""
if model := lookup_model(name):
# all aliases are deprecated
if model.id != name:
warn_msg = f"Model {name} is deprecated. Using {model.id} instead."
warnings.warn(warn_msg, UserWarning, stacklevel=1)
return model
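
The deprecation path is observable through the warnings machinery; a small check:

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model = determine_model("nvolveqa_40k")  # deprecated alias of NV-Embed-QA

assert model is not None and model.id == "NV-Embed-QA"
assert any("deprecated" in str(w.message) for w in caught)
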


KNOWN_URLS = [
BASE_URL,
"https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l",
]
@@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-embeddings-nvidia"
readme = "README.md"
version = "0.3.1"
version = "0.3.2"

[tool.poetry.dependencies]
python = ">=3.9,<4.0"
@@ -22,6 +22,11 @@ def masked_env_var() -> Generator[str, None, None]:
os.environ[var] = val


@pytest.fixture(params=[Interface])
def public_class(request: pytest.FixtureRequest) -> type:
return request.param
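
A hypothetical test consuming the fixture; public_class is parametrized over Interface, and the assertion leans only on the class_name() classmethod from the embeddings class:

def test_has_class_name(public_class: type) -> None:
    # Hypothetical check: the public interface exposes a string class_name().
    assert isinstance(public_class.class_name(), str)
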


def pytest_collection_modifyitems(config, items):
if "NVIDIA_API_KEY" not in os.environ:
skip_marker = pytest.mark.skip(