-
-
Notifications
You must be signed in to change notification settings - Fork 5.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Model][Misc] Add e5-mistral-7b-instruct and Embedding API (#3734)
- Loading branch information
1 parent
4e12131
commit e254497
Showing 38 changed files with 1,627 additions and 160 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from vllm import LLM

# Prompts to embed.
SAMPLE_PROMPTS = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Build the embedding model (enforce_eager=True, as in the original example).
llm = LLM(model="intfloat/e5-mistral-7b-instruct", enforce_eager=True)

# Encode every prompt; each result is an EmbeddingRequestOutput.
for request_output in llm.encode(SAMPLE_PROMPTS):
    # For this model the embedding is a list of 4096 floats.
    print(request_output.outputs.embedding)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from openai import OpenAI

# Point the OpenAI client at the local vLLM API server instead of api.openai.com.
API_KEY = "EMPTY"
API_BASE = "http://localhost:8000/v1"

client = OpenAI(
    # api_key defaults to os.environ.get("OPENAI_API_KEY"); vLLM ignores it.
    api_key=API_KEY,
    base_url=API_BASE,
)

# Use whichever model the server is currently serving.
served_model = client.models.list().data[0].id

embedding_response = client.embeddings.create(
    input=[
        "Hello my name is",
        "The best thing about vLLM is that it supports many different models",
    ],
    model=served_model,
)

for item in embedding_response.data:
    # Each embedding is a list of floats (length 4096 for e5-mistral).
    print(item.embedding)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling. | ||
Run `pytest tests/models/test_llama_embedding.py`. | ||
""" | ||
import pytest | ||
import torch | ||
import torch.nn.functional as F | ||
|
||
# Embedding model checkpoints exercised by test_models below.
MODELS = [
    "intfloat/e5-mistral-7b-instruct",
]
|
||
|
||
def compare_embeddings(embeddings1, embeddings2):
    """Return per-pair cosine similarities between two lists of embeddings.

    Pairs are matched positionally via zip; each element of the result is a
    0-dim torch tensor holding the cosine similarity of one pair.
    """
    scores = []
    for left, right in zip(embeddings1, embeddings2):
        left_t = torch.tensor(left)
        right_t = torch.tensor(right)
        scores.append(F.cosine_similarity(left_t, right_t, dim=0))
    return scores
|
||
|
||
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
) -> None:
    """Check vLLM embeddings against HF: every pairwise cosine similarity
    between the two sets of prompt embeddings must be within 1e-2 of 1.0.
    """
    hf_model = hf_runner(model, dtype=dtype)
    hf_outputs = hf_model.encode(example_prompts)
    # Drop the HF model before loading the vLLM one — presumably to keep
    # only one model resident at a time (TODO confirm memory rationale).
    del hf_model

    vllm_model = vllm_runner(model, dtype=dtype)
    vllm_outputs = vllm_model.encode(example_prompts)
    del vllm_model

    similarities = compare_embeddings(hf_outputs, vllm_outputs)
    all_similarities = torch.stack(similarities)
    tolerance = 1e-2
    # Equivalent to the manual (x <= 1+tol) & (x >= 1-tol) bound check:
    # |x - 1.0| <= atol with rtol=0, bounds inclusive on both sides.
    assert torch.allclose(
        all_similarities,
        torch.ones_like(all_similarities),
        atol=tolerance,
        rtol=0,
    ), f"Not all values are within {tolerance} of 1.0"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.