Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: update docs and examples for Azure OpenAI v1 #761

Merged
merged 2 commits into from
Nov 19, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/LLMs/llms.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,14 @@ from pandasai.llm import AzureOpenAI

llm = AzureOpenAI(
api_token="my-azure-openai-api-key",
api_base="my-azure-openai-api-endpoint",
azure_endpoint="my-azure-openai-api-endpoint",
api_version="2023-05-15",
deployment_name="my-deployment-name"
)
df = SmartDataframe("data.csv", config={"llm": llm})
```

As an alternative, you can set the `OPENAI_API_KEY`, `OPENAI_API_VERSION`, and `OPENAI_API_BASE` environment variables and instantiate the Azure OpenAI object without passing them:
As an alternative, you can set the `AZURE_OPENAI_API_KEY`, `OPENAI_API_VERSION`, and `AZURE_OPENAI_ENDPOINT` environment variables and instantiate the Azure OpenAI object without passing them:
mspronesti marked this conversation as resolved.
Show resolved Hide resolved

```python
from pandasai import SmartDataframe
Expand Down
6 changes: 3 additions & 3 deletions examples/with_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

df = pd.DataFrame(dataframe)

# export OPENAI_API_BASE=https://your-resource-name.openai.azure.com
# export OPENAI_API_KEY=<your Azure OpenAI API key>
# export AZURE_OPENAI_ENDPOINT=https://your-resource-name.openai.azure.com/
# export AZURE_OPENAI_API_KEY=<your Azure OpenAI API key>

# The name of your deployed model
# This will correspond to the custom name you chose for your
Expand All @@ -19,7 +19,7 @@
llm = AzureOpenAI(
deployment_name=deployment_name,
api_version="2023-05-15",
# is_chat_model=True, # Comment in if you deployed a chat model
# is_chat_model=False, # Comment in if you deployed a completion model
)

df = SmartDataframe(df, config={"llm": llm})
Expand Down
43 changes: 25 additions & 18 deletions pandasai/llm/azure_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,16 @@ class AzureOpenAI(BaseOpenAI):
api_type: str = "azure"

def __init__(
self,
api_token: Optional[str] = None,
azure_endpoint: Union[str, None] = None,
azure_ad_token: Union[str, None] = None,
azure_ad_token_provider: Union[str, None] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
deployment_name: str = None,
is_chat_model: bool = True,
**kwargs,
self,
api_token: Optional[str] = None,
azure_endpoint: Union[str, None] = None,
azure_ad_token: Union[str, None] = None,
azure_ad_token_provider: Union[str, None] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
deployment_name: str = None,
is_chat_model: bool = True,
**kwargs,
):
mspronesti marked this conversation as resolved.
Show resolved Hide resolved
"""
__init__ method of AzureOpenAI Class.
Expand All @@ -74,17 +74,18 @@ def __init__(
Will be invoked on every request.
api_version (str): Version of the Azure OpenAI API.
Be aware the API version may change.
api_base (str): Legacy, kept for backward compatibility with openai < 1.0
api_base (str): Legacy, kept for backward compatibility with openai < 1.0.
Ignored for openai >= 1.0.
deployment_name (str): Custom name of the deployed model
is_chat_model (bool): Whether ``deployment_name`` corresponds to a Chat
or a Completion model.
**kwargs: Inference Parameters.
"""

self.api_token = (
api_token
or os.getenv("OPENAI_API_KEY")
or os.getenv("AZURE_OPENAI_API_KEY")
api_token
or os.getenv("OPENAI_API_KEY")
or os.getenv("AZURE_OPENAI_API_KEY")
)
self.azure_endpoint = azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
self.api_base = api_base or os.getenv("OPENAI_API_BASE")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The environment variable fallbacks have been updated to reflect the new Azure-specific names. However, the fallback for api_token still checks for OPENAI_API_KEY before AZURE_OPENAI_API_KEY. This should be reversed to prioritize the Azure-specific environment variable, as the context suggests a move towards Azure-specific naming conventions.

85        self.api_token = (
86            api_token
-           or os.getenv("OPENAI_API_KEY")
+           or os.getenv("AZURE_OPENAI_API_KEY")
+           or os.getenv("OPENAI_API_KEY")
88        )

Commitable suggestion

[!IMPORTANT]
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
self.api_token = (
api_token
or os.getenv("OPENAI_API_KEY")
or os.getenv("AZURE_OPENAI_API_KEY")
api_token
or os.getenv("OPENAI_API_KEY")
or os.getenv("AZURE_OPENAI_API_KEY")
)
self.azure_endpoint = azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
self.api_base = api_base or os.getenv("OPENAI_API_BASE")
self.api_token = (
api_token
or os.getenv("AZURE_OPENAI_API_KEY")
or os.getenv("OPENAI_API_KEY")
)
self.azure_endpoint = azure_endpoint or os.getenv("AZURE_OPENAI_ENDPOINT")
self.api_base = api_base or os.getenv("OPENAI_API_BASE")

Expand All @@ -105,6 +106,7 @@ def __init__(
"Azure OpenAI base is required. Please add an environment variable "
"`OPENAI_API_BASE` or pass `api_base` as a named parameter"
)

if self.api_version is None:
raise APIKeyNotFoundError(
"Azure OpenAI version is required. Please add an environment variable "
mspronesti marked this conversation as resolved.
Show resolved Hide resolved
Expand Down Expand Up @@ -133,10 +135,12 @@ def __init__(
if is_openai_v1()
else openai.ChatCompletion
)
elif is_openai_v1():
self.client = openai.AzureOpenAI(**self._client_params).completions
else:
self.client = openai.Completion
self.client = (
openai.AzureOpenAI(**self._client_params).completions
if is_openai_v1()
else openai.Completion
)

@property
def _default_params(self) -> Dict[str, Any]:
Expand All @@ -147,7 +151,10 @@ def _default_params(self) -> Dict[str, Any]:
dict: A dictionary containing Default Params.

"""
return {**super()._default_params, "model" if is_openai_v1() else "engine": self.deployment_name}
return {
**super()._default_params,
"model" if is_openai_v1() else "engine": self.deployment_name,
}

@property
def _invocation_params(self) -> Dict[str, Any]:
Expand Down
8 changes: 4 additions & 4 deletions pandasai/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,9 @@ def _extract_tag_text(self, response: str, tag: str) -> str:
"""

if match := re.search(
f"(<{tag}>)(.*)(</{tag}>)",
response,
re.DOTALL | re.MULTILINE,
f"(<{tag}>)(.*)(</{tag}>)",
response,
re.DOTALL | re.MULTILINE,
):
return match[2]
return None
Expand Down Expand Up @@ -414,7 +414,7 @@ def _setup(self, **kwargs):

"""
self.api_token = (
kwargs.get("api_token") or os.getenv("HUGGINGFACE_API_KEY") or None
kwargs.get("api_token") or os.getenv("HUGGINGFACE_API_KEY") or None
)
if self.api_token is None:
raise APIKeyNotFoundError("HuggingFace Hub API key is required")
Expand Down
24 changes: 13 additions & 11 deletions pandasai/llm/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ class OpenAI(BaseOpenAI):
model: str = "gpt-3.5-turbo"

def __init__(
self,
api_token: Optional[str] = None,
**kwargs,
self,
api_token: Optional[str] = None,
**kwargs,
):
mspronesti marked this conversation as resolved.
Show resolved Hide resolved
"""
__init__ method of OpenAI Class
Expand All @@ -76,16 +76,18 @@ def __init__(
model_name = self.model.split(":")[1] if "ft:" in self.model else self.model
if model_name in self._supported_chat_models:
self._is_chat_model = True
if is_openai_v1():
self.client = openai.OpenAI(**self._client_params).chat.completions
else:
self.client = openai.ChatCompletion
self.client = (
openai.OpenAI(**self._client_params).chat.completions
if is_openai_v1()
else openai.ChatCompletion
)
elif model_name in self._supported_completion_models:
self._is_chat_model = False
if is_openai_v1():
self.client = openai.OpenAI(**self._client_params).completions
else:
self.client = openai.Completion
self.client = (
openai.OpenAI(**self._client_params).completions
if is_openai_v1()
else openai.Completion
)
else:
raise UnsupportedModelError(self.model)

Expand Down