Skip to content

Commit

Permalink
Deprecate VertexAI PaLM text generative model (apache#44719)
Browse files Browse the repository at this point in the history
  • Loading branch information
MaksYermak authored and Lefteris Gilmaz committed Jan 5, 2025
1 parent 3cd4912 commit 30d1a17
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 72 deletions.
10 changes: 0 additions & 10 deletions docs/apache-airflow-providers-google/operators/cloud/vertex_ai.rst
Original file line number Diff line number Diff line change
Expand Up @@ -573,16 +573,6 @@ To get a pipeline job list you can use
Interacting with Generative AI
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

To generate a prediction via language model you can use
:class:`~airflow.providers.google.cloud.operators.vertex_ai.generative_model.TextGenerationModelPredictOperator`.
The operator returns the model's response in :ref:`XCom <concepts:xcom>` under ``model_response`` key.

.. exampleinclude:: /../../providers/tests/system/google/cloud/vertex_ai/example_vertex_ai_generative_model.py
:language: python
:dedent: 4
:start-after: [START how_to_cloud_vertex_ai_text_generation_model_predict_operator]
:end-before: [END how_to_cloud_vertex_ai_text_generation_model_predict_operator]

To generate text embeddings you can use
:class:`~airflow.providers.google.cloud.operators.vertex_ai.generative_model.TextEmbeddingModelGetEmbeddingsOperator`.
The operator returns the model's response in :ref:`XCom <concepts:xcom>` under ``model_response`` key.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@
class GenerativeModelHook(GoogleBaseHook):
"""Hook for Google Cloud Vertex AI Generative Model APIs."""

@deprecated(
planned_removal_date="April 09, 2025",
use_instead="GenerativeModelHook.get_generative_model",
category=AirflowProviderDeprecationWarning,
)
def get_text_generation_model(self, pretrained_model: str):
"""Return a Model Garden Model object based on Text Generation."""
model = TextGenerationModel.from_pretrained(pretrained_model)
Expand Down Expand Up @@ -275,6 +280,11 @@ def prompt_multimodal_model_with_media(

return response.text

@deprecated(
planned_removal_date="April 09, 2025",
use_instead="GenerativeModelHook.generative_model_generate_content",
category=AirflowProviderDeprecationWarning,
)
@GoogleBaseHook.fallback_to_default_project_id
def text_generation_model_predict(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,11 @@ def execute(self, context: Context):
return response


@deprecated(
planned_removal_date="April 09, 2025",
use_instead="GenerativeModelGenerateContentOperator",
category=AirflowProviderDeprecationWarning,
)
class TextGenerationModelPredictOperator(GoogleCloudBaseOperator):
"""
Uses the Vertex AI PaLM API to generate natural language text.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,24 +205,18 @@ def test_prompt_multimodal_model_with_media(self, mock_model, mock_part) -> None

@mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_text_generation_model"))
def test_text_generation_model_predict(self, mock_model) -> None:
self.hook.text_generation_model_predict(
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=TEST_PROMPT,
pretrained_model=TEST_LANGUAGE_PRETRAINED_MODEL,
temperature=TEST_TEMPERATURE,
max_output_tokens=TEST_MAX_OUTPUT_TOKENS,
top_p=TEST_TOP_P,
top_k=TEST_TOP_K,
)
mock_model.assert_called_once_with(TEST_LANGUAGE_PRETRAINED_MODEL)
mock_model.return_value.predict.assert_called_once_with(
prompt=TEST_PROMPT,
temperature=TEST_TEMPERATURE,
max_output_tokens=TEST_MAX_OUTPUT_TOKENS,
top_p=TEST_TOP_P,
top_k=TEST_TOP_K,
)
with pytest.warns(AirflowProviderDeprecationWarning) as warnings:
self.hook.text_generation_model_predict(
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=TEST_PROMPT,
pretrained_model=TEST_LANGUAGE_PRETRAINED_MODEL,
temperature=TEST_TEMPERATURE,
max_output_tokens=TEST_MAX_OUTPUT_TOKENS,
top_p=TEST_TOP_P,
top_k=TEST_TOP_K,
)
assert_warning("generative_model_generate_content", warnings)

@mock.patch(GENERATIVE_MODEL_STRING.format("GenerativeModelHook.get_text_embedding_model"))
def test_text_embedding_model_get_embeddings(self, mock_model) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -278,28 +278,46 @@ def test_execute(self, mock_hook):


class TestVertexAITextGenerationModelPredictOperator:
prompt = "In 10 words or less, what is Apache Airflow?"
pretrained_model = "text-bison"
temperature = 0.0
max_output_tokens = 256
top_p = 0.8
top_k = 40

def test_deprecation_warning(self):
with pytest.warns(AirflowProviderDeprecationWarning) as warnings:
TextGenerationModelPredictOperator(
task_id=TASK_ID,
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=self.prompt,
pretrained_model=self.pretrained_model,
temperature=self.temperature,
max_output_tokens=self.max_output_tokens,
top_p=self.top_p,
top_k=self.top_k,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
assert_warning("GenerativeModelGenerateContentOperator", warnings)

@mock.patch(VERTEX_AI_PATH.format("generative_model.GenerativeModelHook"))
def test_execute(self, mock_hook):
prompt = "In 10 words or less, what is Apache Airflow?"
pretrained_model = "text-bison"
temperature = 0.0
max_output_tokens = 256
top_p = 0.8
top_k = 40

op = TextGenerationModelPredictOperator(
task_id=TASK_ID,
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=prompt,
pretrained_model=pretrained_model,
temperature=temperature,
max_output_tokens=max_output_tokens,
top_p=top_p,
top_k=top_k,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
with pytest.warns(AirflowProviderDeprecationWarning):
op = TextGenerationModelPredictOperator(
task_id=TASK_ID,
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=self.prompt,
pretrained_model=self.pretrained_model,
temperature=self.temperature,
max_output_tokens=self.max_output_tokens,
top_p=self.top_p,
top_k=self.top_k,
gcp_conn_id=GCP_CONN_ID,
impersonation_chain=IMPERSONATION_CHAIN,
)
op.execute(context={"ti": mock.MagicMock()})
mock_hook.assert_called_once_with(
gcp_conn_id=GCP_CONN_ID,
Expand All @@ -308,12 +326,12 @@ def test_execute(self, mock_hook):
mock_hook.return_value.text_generation_model_predict.assert_called_once_with(
project_id=GCP_PROJECT,
location=GCP_LOCATION,
prompt=prompt,
pretrained_model=pretrained_model,
temperature=temperature,
max_output_tokens=max_output_tokens,
top_p=top_p,
top_k=top_k,
prompt=self.prompt,
pretrained_model=self.pretrained_model,
temperature=self.temperature,
max_output_tokens=self.max_output_tokens,
top_p=self.top_p,
top_k=self.top_k,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,13 @@
GenerativeModelGenerateContentOperator,
RunEvaluationOperator,
TextEmbeddingModelGetEmbeddingsOperator,
TextGenerationModelPredictOperator,
)

PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
DAG_ID = "vertex_ai_generative_model_dag"
REGION = "us-central1"
PROMPT = "In 10 words or less, why is Apache Airflow amazing?"
CONTENTS = [PROMPT]
LANGUAGE_MODEL = "text-bison"
TEXT_EMBEDDING_MODEL = "textembedding-gecko"
MULTIMODAL_MODEL = "gemini-pro"
MULTIMODAL_VISION_MODEL = "gemini-pro-vision"
Expand Down Expand Up @@ -117,16 +115,6 @@
catchup=False,
tags=["example", "vertex_ai", "generative_model"],
) as dag:
# [START how_to_cloud_vertex_ai_text_generation_model_predict_operator]
predict_task = TextGenerationModelPredictOperator(
task_id="predict_task",
project_id=PROJECT_ID,
location=REGION,
prompt=PROMPT,
pretrained_model=LANGUAGE_MODEL,
)
# [END how_to_cloud_vertex_ai_text_generation_model_predict_operator]

# [START how_to_cloud_vertex_ai_text_embedding_model_get_embeddings_operator]
generate_embeddings_task = TextEmbeddingModelGetEmbeddingsOperator(
task_id="generate_embeddings_task",
Expand Down
13 changes: 7 additions & 6 deletions tests/always/test_project_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest
"airflow.providers.google.cloud.operators.automl.AutoMLTablesListTableSpecsOperator",
"airflow.providers.google.cloud.operators.automl.AutoMLTablesUpdateDatasetOperator",
"airflow.providers.google.cloud.operators.automl.AutoMLDeployModelOperator",
"airflow.providers.google.cloud.operators.automl.AutoMLBatchPredictOperator",
"airflow.providers.google.cloud.operators.datapipeline.CreateDataPipelineOperator",
"airflow.providers.google.cloud.operators.datapipeline.RunDataPipelineOperator",
"airflow.providers.google.cloud.operators.dataproc.DataprocScaleClusterOperator",
Expand All @@ -367,6 +368,12 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest
"airflow.providers.google.cloud.operators.mlengine.MLEngineSetDefaultVersionOperator",
"airflow.providers.google.cloud.operators.mlengine.MLEngineStartBatchPredictionJobOperator",
"airflow.providers.google.cloud.operators.mlengine.MLEngineStartTrainingJobOperator",
"airflow.providers.google.cloud.operators.mlengine.MLEngineTrainingCancelJobOperator",
"airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptLanguageModelOperator",
"airflow.providers.google.cloud.operators.vertex_ai.generative_model.GenerateTextEmbeddingsOperator",
"airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptMultimodalModelOperator",
"airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptMultimodalModelWithMediaOperator",
"airflow.providers.google.cloud.operators.vertex_ai.generative_model.TextGenerationModelPredictOperator",
"airflow.providers.google.marketing_platform.operators.GoogleDisplayVideo360CreateQueryOperator",
"airflow.providers.google.marketing_platform.operators.GoogleDisplayVideo360RunQueryOperator",
"airflow.providers.google.marketing_platform.operators.GoogleDisplayVideo360DownloadReportV2Operator",
Expand All @@ -385,7 +392,6 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest
}

MISSING_EXAMPLES_FOR_CLASSES = {
"airflow.providers.google.cloud.operators.mlengine.MLEngineTrainingCancelJobOperator",
"airflow.providers.google.cloud.operators.dlp.CloudDLPRedactImageOperator",
"airflow.providers.google.cloud.transfers.cassandra_to_gcs.CassandraToGCSOperator",
"airflow.providers.google.cloud.transfers.adls_to_gcs.ADLSToGCSOperator",
Expand All @@ -394,11 +400,6 @@ class TestGoogleProviderProjectStructure(ExampleCoverageTest, AssetsCoverageTest
"airflow.providers.google.cloud.operators.vertex_ai.auto_ml.AutoMLTrainingJobBaseOperator",
"airflow.providers.google.cloud.operators.vertex_ai.endpoint_service.UpdateEndpointOperator",
"airflow.providers.google.cloud.operators.vertex_ai.batch_prediction_job.GetBatchPredictionJobOperator",
"airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptLanguageModelOperator",
"airflow.providers.google.cloud.operators.vertex_ai.generative_model.GenerateTextEmbeddingsOperator",
"airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptMultimodalModelOperator",
"airflow.providers.google.cloud.operators.vertex_ai.generative_model.PromptMultimodalModelWithMediaOperator",
"airflow.providers.google.cloud.operators.automl.AutoMLBatchPredictOperator",
}

ASSETS_NOT_REQUIRED = {
Expand Down

0 comments on commit 30d1a17

Please sign in to comment.