From 9ef1b8686703c69202cacff076d2c25543f8f28e Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Fri, 17 May 2024 21:02:17 +0545 Subject: [PATCH 1/2] Fix the argument type of input_vectors --- airflow/providers/pinecone/hooks/pinecone.py | 4 ++-- airflow/providers/pinecone/operators/pinecone.py | 4 ++-- .../providers/pinecone/example_pinecone_cohere.py | 15 +++++++++------ 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/airflow/providers/pinecone/hooks/pinecone.py b/airflow/providers/pinecone/hooks/pinecone.py index 35aa66c3204b5..3a52bbcac04f9 100644 --- a/airflow/providers/pinecone/hooks/pinecone.py +++ b/airflow/providers/pinecone/hooks/pinecone.py @@ -24,7 +24,7 @@ from functools import cached_property from typing import TYPE_CHECKING, Any -from pinecone import Pinecone, PodSpec, ServerlessSpec +from pinecone import Pinecone, PodSpec, ServerlessSpec, Vector from airflow.hooks.base import BaseHook @@ -137,7 +137,7 @@ def list_indexes(self) -> Any: def upsert( self, index_name: str, - vectors: list[Any], + vectors: list[Vector] | list[tuple] | list[dict], namespace: str = "", batch_size: int | None = None, show_progress: bool = True, diff --git a/airflow/providers/pinecone/operators/pinecone.py b/airflow/providers/pinecone/operators/pinecone.py index bb3d44214d42b..ec206405f71ec 100644 --- a/airflow/providers/pinecone/operators/pinecone.py +++ b/airflow/providers/pinecone/operators/pinecone.py @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING, Any, Sequence from airflow.models import BaseOperator -from airflow.providers.pinecone.hooks.pinecone import PineconeHook +from airflow.providers.pinecone.hooks.pinecone import PineconeHook, Vector from airflow.utils.context import Context if TYPE_CHECKING: @@ -52,7 +52,7 @@ def __init__( *, conn_id: str = PineconeHook.default_conn_name, index_name: str, - input_vectors: list[tuple], + input_vectors: list[Vector] | list[tuple] | list[dict], namespace: str = "", batch_size: int | None = None, upsert_kwargs: dict | None = None, diff --git a/tests/system/providers/pinecone/example_pinecone_cohere.py b/tests/system/providers/pinecone/example_pinecone_cohere.py index c74a376f61406..80e6766484d6b 100644 --- a/tests/system/providers/pinecone/example_pinecone_cohere.py +++ b/tests/system/providers/pinecone/example_pinecone_cohere.py @@ -17,7 +17,6 @@ from __future__ import annotations import os -import time from datetime import datetime from airflow import DAG @@ -46,19 +45,23 @@ def create_index(): hook = PineconeHook() pod_spec = hook.get_pod_spec_obj() hook.create_index(index_name=index_name, dimension=768, spec=pod_spec) - time.sleep(60) embed_task = CohereEmbeddingOperator( task_id="embed_task", input_text=data, ) + @task + def transform_output(embedding_output) -> list[dict]: + # Convert each embedding to a map with an ID and the embedding vector + return [dict(id=str(i), values=embedding) for i, embedding in enumerate(embedding_output)] + + transformed_output = transform_output(embed_task.output) + perform_ingestion = PineconeIngestOperator( task_id="perform_ingestion", index_name=index_name, - input_vectors=[ - ("id1", embed_task.output), - ], + input_vectors=transformed_output, namespace=namespace, batch_size=1, ) @@ -71,7 +74,7 @@ def delete_index(): hook = PineconeHook() hook.delete_index(index_name=index_name) - create_index() >> embed_task >> perform_ingestion >> delete_index() + create_index() >> embed_task >> transformed_output >> perform_ingestion >> delete_index() from tests.system.utils import get_test_run # noqa: E402 From 2c56b8e57c759a3f72aef478b9d450d876d169a6 Mon Sep 17 00:00:00 2001 From: Ankit Chaurasia <8670962+sunank200@users.noreply.github.com> Date: Fri, 17 May 2024 22:11:30 +0545 Subject: [PATCH 2/2] Fix typing and docstring --- airflow/providers/pinecone/hooks/pinecone.py | 3 ++- airflow/providers/pinecone/operators/pinecone.py | 8 +++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/airflow/providers/pinecone/hooks/pinecone.py b/airflow/providers/pinecone/hooks/pinecone.py index 3a52bbcac04f9..b5e73ae4c69a0 100644 --- a/airflow/providers/pinecone/hooks/pinecone.py +++ b/airflow/providers/pinecone/hooks/pinecone.py @@ -24,11 +24,12 @@ from functools import cached_property from typing import TYPE_CHECKING, Any -from pinecone import Pinecone, PodSpec, ServerlessSpec, Vector +from pinecone import Pinecone, PodSpec, ServerlessSpec from airflow.hooks.base import BaseHook if TYPE_CHECKING: + from pinecone import Vector from pinecone.core.client.model.sparse_values import SparseValues from pinecone.core.client.models import DescribeIndexStatsResponse, QueryResponse, UpsertResponse diff --git a/airflow/providers/pinecone/operators/pinecone.py b/airflow/providers/pinecone/operators/pinecone.py index ec206405f71ec..70711e062308d 100644 --- a/airflow/providers/pinecone/operators/pinecone.py +++ b/airflow/providers/pinecone/operators/pinecone.py @@ -21,10 +21,12 @@ from typing import TYPE_CHECKING, Any, Sequence from airflow.models import BaseOperator -from airflow.providers.pinecone.hooks.pinecone import PineconeHook, Vector +from airflow.providers.pinecone.hooks.pinecone import PineconeHook from airflow.utils.context import Context if TYPE_CHECKING: + from pinecone import Vector + from airflow.utils.context import Context @@ -38,8 +40,8 @@ class PineconeIngestOperator(BaseOperator): :param conn_id: The connection id to use when connecting to Pinecone. :param index_name: Name of the Pinecone index. - :param input_vectors: Data to be ingested, in the form of a list of tuples where each tuple - contains (id, vector_embedding, metadata). + :param input_vectors: Data to be ingested, in the form of a list of vectors, list of tuples, + or list of dictionaries. :param namespace: The namespace to write to. If not specified, the default namespace is used. :param batch_size: The number of vectors to upsert in each batch. :param upsert_kwargs: .. seealso:: https://docs.pinecone.io/reference/upsert