diff --git a/airflow/providers/pinecone/hooks/pinecone.py b/airflow/providers/pinecone/hooks/pinecone.py index a04ae60ce8391..35aa66c3204b5 100644 --- a/airflow/providers/pinecone/hooks/pinecone.py +++ b/airflow/providers/pinecone/hooks/pinecone.py @@ -226,7 +226,7 @@ def create_index( :param dimension: The dimension of the vectors to be indexed. :param spec: Pass a `ServerlessSpec` object to create a serverless index or a `PodSpec` object to create a pod index. ``get_serverless_spec_obj`` and ``get_pod_spec_obj`` can be used to create the Spec objects. - :param metric: The metric to use. + :param metric: The metric to use. Defaults to cosine. :param timeout: The timeout to use. """ self.pinecone_client.create_index( diff --git a/airflow/providers/pinecone/operators/pinecone.py b/airflow/providers/pinecone/operators/pinecone.py index 8431276206e07..bb3d44214d42b 100644 --- a/airflow/providers/pinecone/operators/pinecone.py +++ b/airflow/providers/pinecone/operators/pinecone.py @@ -99,10 +99,10 @@ class CreatePodIndexOperator(BaseOperator): :param replicas: The number of replicas to use. :param shards: The number of shards to use. :param pods: The number of pods to use. - :param pod_type: The type of pod to use. + :param pod_type: The type of pod to use. Defaults to p1.x1 :param metadata_config: The metadata configuration to use. :param source_collection: The source collection to use. - :param metric: The metric to use. + :param metric: The metric to use. Defaults to cosine. :param timeout: The timeout to use. """ @@ -116,10 +116,10 @@ def __init__( replicas: int | None = None, shards: int | None = None, pods: int | None = None, - pod_type: str | None = None, + pod_type: str = "p1.x1", metadata_config: dict | None = None, source_collection: str | None = None, - metric: str | None = None, + metric: str = "cosine", timeout: int | None = None, **kwargs: Any, ): diff --git a/tests/system/providers/pinecone/example_pinecone_cohere.py b/tests/system/providers/pinecone/example_pinecone_cohere.py index fa9dccac330b0..c74a376f61406 100644 --- a/tests/system/providers/pinecone/example_pinecone_cohere.py +++ b/tests/system/providers/pinecone/example_pinecone_cohere.py @@ -44,7 +44,8 @@ def create_index(): from airflow.providers.pinecone.hooks.pinecone import PineconeHook hook = PineconeHook() - hook.create_index(index_name=index_name, dimension=768) + pod_spec = hook.get_pod_spec_obj() + hook.create_index(index_name=index_name, dimension=768, spec=pod_spec) time.sleep(60) embed_task = CohereEmbeddingOperator( diff --git a/tests/system/providers/pinecone/example_pinecone_openai.py b/tests/system/providers/pinecone/example_pinecone_openai.py index f68fd5d3e04a4..d338e25542ce0 100644 --- a/tests/system/providers/pinecone/example_pinecone_openai.py +++ b/tests/system/providers/pinecone/example_pinecone_openai.py @@ -17,13 +17,12 @@ from __future__ import annotations import os -import time from datetime import datetime from airflow import DAG -from airflow.decorators import setup, task, teardown +from airflow.decorators import task, teardown from airflow.providers.openai.operators.openai import OpenAIEmbeddingOperator -from airflow.providers.pinecone.operators.pinecone import PineconeIngestOperator +from airflow.providers.pinecone.operators.pinecone import CreatePodIndexOperator, PineconeIngestOperator index_name = os.getenv("INDEX_NAME", "example-pinecone-index") namespace = os.getenv("NAMESPACE", "example-pinecone-index") @@ -75,15 +74,11 @@ start_date=datetime(2023, 1, 1), catchup=False, ) as dag: - - @setup - @task - def create_index(): - from airflow.providers.pinecone.hooks.pinecone import PineconeHook - - hook = PineconeHook() - hook.create_index(index_name=index_name, dimension=1536) - time.sleep(60) + create_index = CreatePodIndexOperator( + task_id="create_index", + index_name=index_name, + dimension=1536, + ) embed_task = OpenAIEmbeddingOperator( task_id="embed_task", @@ -110,7 +105,7 @@ def delete_index(): hook = PineconeHook() hook.delete_index(index_name=index_name) - create_index() >> embed_task >> perform_ingestion >> delete_index() + create_index >> embed_task >> perform_ingestion >> delete_index() from tests.system.utils import get_test_run # noqa: E402