adding finish reason mapping for aleph alpha and baseten
krrishdholakia committed Sep 14, 2023
1 parent aaa57ab · commit fef2a39
Showing 4 changed files with 33 additions and 28 deletions.
1 change: 1 addition & 0 deletions litellm/llms/aleph_alpha.py
@@ -100,6 +100,7 @@ def completion(
         else:
             try:
                 model_response["choices"][0]["message"]["content"] = completion_response["completions"][0]["completion"]
+                model_response.choices[0].finish_reason = completion_response["completions"][0]["finish_reason"]
             except:
                 raise AlephAlphaError(message=json.dumps(completion_response), status_code=response.status_code)

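A minimal usage sketch of what this hunk surfaces (the model name, prompt, and printed value are illustrative assumptions; the response fields come from the diff above):

from litellm import completion

# Hypothetical call; assumes an Aleph Alpha model and API key are configured.
response = completion(
    model="luminous-base",  # assumed Aleph Alpha model name
    messages=[{"role": "user", "content": "Hello"}],
    max_tokens=10,
)
# With this commit, the provider's raw finish reason is passed through
# instead of being left unset, e.g. "maximum_tokens" (assumed value).
print(response.choices[0].finish_reason)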
1 change: 1 addition & 0 deletions litellm/llms/baseten.py
@@ -117,6 +117,7 @@ def completion(
         model_response["choices"][0]["message"]["content"] = completion_response[0]["generated_text"]
         ## GETTING LOGPROBS
         if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
+            model_response.choices[0].finish_reason = completion_response[0]["details"]["finish_reason"]
             sum_logprob = 0
             for token in completion_response[0]["details"]["tokens"]:
                 sum_logprob += token["logprob"]
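For reference, a sketch of the Baseten payload shape this hunk parses (field names are taken from the code above; the values are made up):

# Hypothetical raw Baseten response, shaped the way the code above expects.
completion_response = [
    {
        "generated_text": "Hello there!",
        "details": {
            "finish_reason": "length",  # assumed value
            "tokens": [
                {"text": "Hello", "logprob": -0.12},
                {"text": " there", "logprob": -0.34},
                {"text": "!", "logprob": -0.05},
            ],
        },
    }
]

# Same guard and extraction logic as the diff: the new line copies
# details["finish_reason"] onto the standard finish_reason field,
# alongside the existing logprob summation.
if "details" in completion_response[0] and "tokens" in completion_response[0]["details"]:
    finish_reason = completion_response[0]["details"]["finish_reason"]
    sum_logprob = sum(t["logprob"] for t in completion_response[0]["details"]["tokens"])
    print(finish_reason, sum_logprob)  # length -0.51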
57 changes: 30 additions & 27 deletions litellm/tests/test_completion.py
@@ -487,25 +487,25 @@ def test_completion_azure_deployment_id():
 
 # Replicate API endpoints are unstable -> throw random CUDA errors -> this means our tests can fail even if our tests weren't incorrect.
 
-def test_completion_replicate_llama_2():
-    model_name = "replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf"
-    try:
-        response = completion(
-            model=model_name,
-            messages=messages,
-            max_tokens=20,
-            custom_llm_provider="replicate"
-        )
-        print(response)
-        cost = completion_cost(completion_response=response)
-        print("Cost for completion call with llama-2: ", f"${float(cost):.10f}")
-        # Add any assertions here to check the response
-        response_str = response["choices"][0]["message"]["content"]
-        print(response_str)
-        if type(response_str) != str:
-            pytest.fail(f"Error occurred: {e}")
-    except Exception as e:
-        pytest.fail(f"Error occurred: {e}")
+# def test_completion_replicate_llama_2():
+#     model_name = "replicate/llama-2-70b-chat:2796ee9483c3fd7aa2e171d38f4ca12251a30609463dcfd4cd76703f22e96cdf"
+#     try:
+#         response = completion(
+#             model=model_name,
+#             messages=messages,
+#             max_tokens=20,
+#             custom_llm_provider="replicate"
+#         )
+#         print(response)
+#         cost = completion_cost(completion_response=response)
+#         print("Cost for completion call with llama-2: ", f"${float(cost):.10f}")
+#         # Add any assertions here to check the response
+#         response_str = response["choices"][0]["message"]["content"]
+#         print(response_str)
+#         if type(response_str) != str:
+#             pytest.fail(f"Error occurred: {e}")
+#     except Exception as e:
+#         pytest.fail(f"Error occurred: {e}")
 # test_completion_replicate_llama_2()
 
 def test_completion_replicate_vicuna():
@@ -601,6 +601,7 @@ def test_completion_sagemaker():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
 
+# test_completion_sagemaker()
 ######## Test VLLM ########
 # def test_completion_vllm():
 #     try:
@@ -658,14 +658,15 @@ def test_completion_sagemaker():
 
 # test_completion_custom_api_base()
 
-# def test_vertex_ai():
-#     model_name = "chat-bison"
-#     try:
-#         response = completion(model=model_name, messages=messages, logger_fn=logger_fn)
-#         print(response)
-#     except Exception as e:
-#         pytest.fail(f"Error occurred: {e}")
+def test_vertex_ai():
+    model_name = "chat-bison"
+    try:
+        response = completion(model=model_name, messages=messages, logger_fn=logger_fn)
+        print(response)
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
 
+test_vertex_ai()
 # def test_petals():
 #     model_name = "stabilityai/StableBeluga2"
 #     try:
@@ -696,12 +698,13 @@ def test_completion_with_fallbacks():
 # def test_baseten():
 #     try:
 
-#         response = completion(model="baseten/RqgAEn0", messages=messages, logger_fn=logger_fn)
+#         response = completion(model="baseten/7qQNLDB", messages=messages, logger_fn=logger_fn)
 #         # Add any assertions here to check the response
 #         print(response)
 #     except Exception as e:
 #         pytest.fail(f"Error occurred: {e}")
 
+# test_baseten()
 # def test_baseten_falcon_7bcompletion():
 #     model_name = "qvv0xeq"
 #     try:
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.621"
+version = "0.1.622"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
