Skip to content

Commit

Permalink
docs(generative_ai): Update Chat Completions API samples
Browse files Browse the repository at this point in the history
- Fix imports (tests were failing)
- Add authentication sample
- Combine credentials refresher region tags
- Add samples for self-hosted models
  • Loading branch information
holtskinner committed Jan 22, 2025
1 parent 70df78b commit f17f755
Show file tree
Hide file tree
Showing 11 changed files with 209 additions and 43 deletions.
50 changes: 50 additions & 0 deletions generative_ai/chat_completions/chat_completions_authentication.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def generate_text(project_id: str, location: str = "us-central1") -> object:
# [START generativeaionvertexai_gemini_chat_completions_authentication]
import openai

from google.auth import default
import google.auth.transport.requests

# TODO(developer): Update and un-comment below lines
# project_id = "PROJECT_ID"
# location = "us-central1"

# Programmatically get an access token
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
credentials.refresh(google.auth.transport.requests.Request())
# Note: the credential lives for 1 hour by default (https://cloud.google.com/docs/authentication/token-types#at-lifetime); after expiration, it must be refreshed.

##############################
# Choose one of the following:
##############################

# If you are calling a Gemini model, set the ENDPOINT_ID variable to use openapi.
ENDPOINT_ID = "openapi"

# If you are calling a self-deployed model from Model Garden, set the
# ENDPOINT_ID variable and set the client's base URL to use your endpoint.
ENDPOINT_ID = "YOUR_ENDPOINT_ID"

# OpenAI Client
client = openai.OpenAI(
base_url=f"https://{location}-aiplatform.googleapis.com/v1beta1/projects/{project_id}/locations/{location}/endpoints/{ENDPOINT_ID}",
api_key=credentials.token,
)
# [END generativeaionvertexai_gemini_chat_completions_authentication]

return client
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# Disable linting on `Any` type annotations (needed for OpenAI kwargs and attributes).
# flake8: noqa ANN401

# [START generativeaionvertexai_credentials_refresher_class]
# [START generativeaionvertexai_credentials_refresher]
from typing import Any

import google.auth
Expand All @@ -25,16 +25,15 @@

class OpenAICredentialsRefresher:
def __init__(self, **kwargs: Any) -> None:
# Set a dummy key here
self.client = openai.OpenAI(**kwargs, api_key="DUMMY")
# Set a placeholder key here
self.client = openai.OpenAI(**kwargs, api_key="PLACEHOLDER")
self.creds, self.project = google.auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"]
)

def __getattr__(self, name: str) -> Any:
if not self.creds.valid:
auth_req = google.auth.transport.requests.Request()
self.creds.refresh(auth_req)
self.creds.refresh(google.auth.transport.requests.Request())

if not self.creds.valid:
raise RuntimeError("Unable to refresh auth")
Expand All @@ -43,11 +42,9 @@ def __getattr__(self, name: str) -> Any:
return getattr(self.client, name)


# [END generativeaionvertexai_credentials_refresher_class]


# [END generativeaionvertexai_credentials_refresher]
def generate_text(project_id: str, location: str = "us-central1") -> object:
# [START generativeaionvertexai_credentials_refresher_usage]
# [START generativeaionvertexai_credentials_refresher]

# TODO(developer): Update and un-comment below lines
# project_id = "PROJECT_ID"
Expand All @@ -63,6 +60,6 @@ def generate_text(project_id: str, location: str = "us-central1") -> object:
)

print(response)
# [END generativeaionvertexai_credentials_refresher_usage]
# [END generativeaionvertexai_credentials_refresher]

return response
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,19 @@

def generate_text(project_id: str, location: str = "us-central1") -> object:
# [START generativeaionvertexai_gemini_chat_completions_non_streaming_image]
import vertexai
import openai

from google.auth import default, transport
from google.auth import default
import google.auth.transport.requests

import openai

# TODO(developer): Update and un-comment below lines
# project_id = "PROJECT_ID"
# location = "us-central1"

vertexai.init(project=project_id, location=location)

# Programmatically get an access token
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
auth_request = transport.requests.Request()
credentials.refresh(auth_request)
credentials.refresh(google.auth.transport.requests.Request())

# OpenAI Client
client = openai.OpenAI(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,20 @@

def generate_text(project_id: str, location: str = "us-central1") -> object:
# [START generativeaionvertexai_gemini_chat_completions_non_streaming]
import vertexai
import openai
from google.auth import default
import google.auth.transport.requests

from google.auth import default, transport
import openai

# TODO(developer): Update and un-comment below lines
# project_id = "PROJECT_ID"
# location = "us-central1"

vertexai.init(project=project_id, location=location)

# Programmatically get an access token
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
auth_request = transport.requests.Request()
credentials.refresh(auth_request)
credentials.refresh(google.auth.transport.requests.Request())

# # OpenAI Client
# OpenAI Client
client = openai.OpenAI(
base_url=f"https://{location}-aiplatform.googleapis.com/v1beta1/projects/{project_id}/locations/{location}/endpoints/openapi",
api_key=credentials.token,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def generate_text(
project_id: str,
location: str = "us-central1",
model_id: str = "gemma-2-9b-it",
endpoint_id: str = "YOUR_ENDPOINT_ID",
) -> object:
# [START generativeaionvertexai_gemini_chat_completions_non_streaming_self_deployed]
from google.auth import default
import google.auth.transport.requests

import openai

# TODO(developer): Update and un-comment below lines
# project_id = "PROJECT_ID"
# location = "us-central1"
# model_id = "gemma-2-9b-it"
# endpoint_id = "YOUR_ENDPOINT_ID"

# Programmatically get an access token
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
credentials.refresh(google.auth.transport.requests.Request())

# OpenAI Client
client = openai.OpenAI(
base_url=f"https://{location}-aiplatform.googleapis.com/v1beta1/projects/{project_id}/locations/{location}/endpoints/{endpoint_id}",
api_key=credentials.token,
)

response = client.chat.completions.create(
model=model_id,
messages=[{"role": "user", "content": "Why is the sky blue?"}],
)
print(response)

# [END generativeaionvertexai_gemini_chat_completions_non_streaming_self_deployed]

return response
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,18 @@

def generate_text(project_id: str, location: str = "us-central1") -> object:
# [START generativeaionvertexai_gemini_chat_completions_streaming_image]
import vertexai
import openai
from google.auth import default
import google.auth.transport.requests

from google.auth import default, transport
import openai

# TODO(developer): Update and un-comment below lines
# project_id = "PROJECT_ID"
# location = "us-central1"

vertexai.init(project=project_id, location=location)

# Programmatically get an access token
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
auth_request = transport.requests.Request()
credentials.refresh(auth_request)
credentials.refresh(google.auth.transport.requests.Request())

# OpenAI Client
client = openai.OpenAI(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,18 @@

def generate_text(project_id: str, location: str = "us-central1") -> object:
# [START generativeaionvertexai_gemini_chat_completions_streaming]
import vertexai
import openai
from google.auth import default
import google.auth.transport.requests

from google.auth import default, transport
import openai

# TODO(developer): Update and un-comment below lines
# project_id = "PROJECT_ID"
# location = "us-central1"

vertexai.init(project=project_id, location=location)

# Programmatically get an access token
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
auth_request = transport.requests.Request()
credentials.refresh(auth_request)
credentials.refresh(google.auth.transport.requests.Request())

# OpenAI Client
client = openai.OpenAI(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def generate_text(
project_id: str,
location: str = "us-central1",
model_id: str = "gemma-2-9b-it",
endpoint_id: str = "YOUR_ENDPOINT_ID",
) -> object:
# [START generativeaionvertexai_gemini_chat_completions_streaming_self_deployed]
from google.auth import default
import google.auth.transport.requests

import openai

# TODO(developer): Update and un-comment below lines
# project_id = "PROJECT_ID"
# location = "us-central1"
# model_id = "gemma-2-9b-it"
# endpoint_id = "YOUR_ENDPOINT_ID"

# Programmatically get an access token
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
credentials.refresh(google.auth.transport.requests.Request())

# OpenAI Client
client = openai.OpenAI(
base_url=f"https://{location}-aiplatform.googleapis.com/v1beta1/projects/{project_id}/locations/{location}/endpoints/{endpoint_id}",
api_key=credentials.token,
)

response = client.chat.completions.create(
model=model_id,
messages=[{"role": "user", "content": "Why is the sky blue?"}],
stream=True,
)
for chunk in response:
print(chunk)

# [END generativeaionvertexai_gemini_chat_completions_streaming_self_deployed]

return response
24 changes: 24 additions & 0 deletions generative_ai/chat_completions/chat_completions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,25 @@

import os

import chat_completions_authentication
import chat_completions_credentials_refresher
import chat_completions_non_streaming_image
import chat_completions_non_streaming_text
import chat_completions_streaming_image
import chat_completions_streaming_text
import chat_completions_streaming_text_self_deployed
import chat_completions_non_streaming_text_self_deployed


PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
LOCATION = "us-central1"
SELF_HOSTED_MODEL_ID = "gemma2-9b-it-mg-one-click-deploy"
ENDPOINT_ID = "6443623023395209216"


def test_authentication() -> None:
response = chat_completions_authentication.generate_text(PROJECT_ID, LOCATION)
assert response


def test_streaming_text() -> None:
Expand Down Expand Up @@ -50,3 +60,17 @@ def test_credentials_refresher() -> None:
PROJECT_ID, LOCATION
)
assert response


def test_streaming_text_self_deployed() -> None:
response = chat_completions_streaming_text_self_deployed.generate_text(
PROJECT_ID, LOCATION, SELF_HOSTED_MODEL_ID, ENDPOINT_ID
)
assert response


def test_non_streaming_text_self_deployed() -> None:
response = chat_completions_non_streaming_text_self_deployed.generate_text(
PROJECT_ID, LOCATION, SELF_HOSTED_MODEL_ID, ENDPOINT_ID
)
assert response
2 changes: 1 addition & 1 deletion generative_ai/chat_completions/requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
backoff==2.2.1
google-api-core==2.19.0
google-api-core==2.24.0
pytest==8.2.0
pytest-asyncio==0.23.6
6 changes: 3 additions & 3 deletions generative_ai/chat_completions/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ pandas==2.0.3; python_version == '3.8'
pandas==2.1.4; python_version > '3.8'
pillow==10.3.0; python_version < '3.8'
pillow==10.3.0; python_version >= '3.8'
google-cloud-aiplatform[all]==1.69.0
google-cloud-aiplatform[all]==1.78.0
sentencepiece==0.2.0
google-auth==2.29.0
google-auth==2.37.0
anthropic[vertex]==0.28.0
langchain-core==0.2.11
langchain-google-vertexai==1.0.6
numpy<2
openai==1.30.5
openai==1.60.0
immutabledict==4.2.0

0 comments on commit f17f755

Please sign in to comment.