Skip to content

Commit

Permalink
mypy type fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
jhamon committed Dec 16, 2024
1 parent 6ddc079 commit c58cba0
Show file tree
Hide file tree
Showing 14 changed files with 244 additions and 59 deletions.
4 changes: 2 additions & 2 deletions pinecone/control/pinecone_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,10 @@ def create_index(
name: str,
spec: Union[Dict, ServerlessSpec, PodSpec],
dimension: Optional[int],
metric: Optional[str] = "cosine",
metric: Optional[Literal["cosine", "euclidean", "dotproduct"]] = "cosine",
timeout: Optional[int] = None,
deletion_protection: Optional[Literal["enabled", "disabled"]] = "disabled",
vector_type: Optional[str] = "dense",
vector_type: Optional[Literal["dense", "sparse"]] = "dense",
):
"""Creates a Pinecone index.
Expand Down
7 changes: 4 additions & 3 deletions pinecone/data/dataclasses/sparse_values.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
from dataclasses import dataclass

from typing import Any, Dict, List
from typing import List
from .utils import DictLike
from ..types import SparseVectorTypedDict


@dataclass
class SparseValues(DictLike):
indices: List[int]
values: List[float]

def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> SparseVectorTypedDict:
return {"indices": self.indices, "values": self.values}

@staticmethod
def from_dict(sparse_values_dict: Dict[str, Any]) -> "SparseValues":
def from_dict(sparse_values_dict: SparseVectorTypedDict) -> "SparseValues":
return SparseValues(
indices=sparse_values_dict["indices"], values=sparse_values_dict["values"]
)
19 changes: 13 additions & 6 deletions pinecone/data/dataclasses/vector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Dict, List, Optional, Union
from typing import List, Optional
from .sparse_values import SparseValues
from .utils import DictLike
from ..types import VectorTypedDict, VectorMetadataTypedDict

from dataclasses import dataclass, field

Expand All @@ -9,26 +10,32 @@
class Vector(DictLike):
id: str
values: List[float] = field(default_factory=list)
metadata: Optional[Dict[str, Union[str, List[str]]]] = None
metadata: Optional[VectorMetadataTypedDict] = None
sparse_values: Optional[SparseValues] = None

def __post_init__(self):
if self.sparse_values is None and len(self.values) == 0:
raise ValueError("The values and sparse_values fields cannot both be empty")

def to_dict(self) -> Dict[str, Any]:
vector_dict = {"id": self.id, "values": self.values}
def to_dict(self) -> VectorTypedDict:
vector_dict: VectorTypedDict = {"id": self.id, "values": self.values}
if self.metadata is not None:
vector_dict["metadata"] = self.metadata
if self.sparse_values is not None:
vector_dict["sparse_values"] = self.sparse_values.to_dict()
return vector_dict

@staticmethod
def from_dict(vector_dict: Dict[str, Any]) -> "Vector":
def from_dict(vector_dict: VectorTypedDict) -> "Vector":
passed_sparse_values = vector_dict.get("sparse_values")
if passed_sparse_values is not None:
parsed_sparse_values = SparseValues.from_dict(passed_sparse_values)
else:
parsed_sparse_values = None

return Vector(
id=vector_dict["id"],
values=vector_dict["values"],
metadata=vector_dict.get("metadata"),
sparse_values=SparseValues.from_dict(vector_dict.get("sparse_values")),
sparse_values=parsed_sparse_values,
)
28 changes: 11 additions & 17 deletions pinecone/data/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .interfaces import IndexInterface
from .request_factory import IndexRequestFactory
from .features.bulk_import import ImportFeatureMixin
from .types import SparseVectorTypedDict, VectorTypedDict, VectorMetadataTypedDict, VectorTuple
from ..utils import (
setup_openapi_client,
parse_non_empty_args,
Expand Down Expand Up @@ -137,7 +138,7 @@ def __exit__(self, exc_type, exc_value, traceback):
@validate_and_convert_errors
def upsert(
self,
vectors: Union[List[Vector], List[tuple], List[dict]],
vectors: Union[List[Vector], List[VectorTuple], List[VectorTypedDict]],
namespace: Optional[str] = None,
batch_size: Optional[int] = None,
show_progress: bool = True,
Expand Down Expand Up @@ -172,7 +173,7 @@ def upsert(

def _upsert_batch(
self,
vectors: Union[List[Vector], List[tuple], List[dict]],
vectors: Union[List[Vector], List[VectorTuple], List[VectorTypedDict]],
namespace: Optional[str],
_check_type: bool,
**kwargs,
Expand Down Expand Up @@ -235,7 +236,7 @@ def fetch(self, ids: List[str], namespace: Optional[str] = None, **kwargs) -> Fe
args_dict = parse_non_empty_args([("namespace", namespace)])
result = self._vector_api.fetch_vectors(ids=ids, **args_dict, **kwargs)
return FetchResponse(
namespace=namespace,
namespace=result.namespace,
vectors={k: Vector.from_dict(v) for k, v in result.vectors},
usage=result.usage,
)
Expand All @@ -251,9 +252,7 @@ def query(
filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
include_values: Optional[bool] = None,
include_metadata: Optional[bool] = None,
sparse_vector: Optional[
Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
] = None,
sparse_vector: Optional[Union[SparseValues, SparseVectorTypedDict]] = None,
**kwargs,
) -> Union[QueryResponse, ApplyResult]:
response = self._query(
Expand Down Expand Up @@ -284,9 +283,7 @@ def _query(
filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
include_values: Optional[bool] = None,
include_metadata: Optional[bool] = None,
sparse_vector: Optional[
Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
] = None,
sparse_vector: Optional[Union[SparseValues, SparseVectorTypedDict]] = None,
**kwargs,
) -> QueryResponse:
if len(args) > 0:
Expand All @@ -310,7 +307,7 @@ def _query(
@validate_and_convert_errors
def query_namespaces(
self,
vector: List[float],
vector: Optional[List[float]],
namespaces: List[str],
metric: Literal["cosine", "euclidean", "dotproduct"],
top_k: Optional[int] = None,
Expand All @@ -324,7 +321,8 @@ def query_namespaces(
) -> QueryNamespacesResults:
if namespaces is None or len(namespaces) == 0:
raise ValueError("At least one namespace must be specified")
if len(vector) == 0:
if sparse_vector is None and vector is not None and len(vector) == 0:
# If querying with a vector, it must not be empty
raise ValueError("Query vector must not be empty")

overall_topk = top_k if top_k is not None else 10
Expand Down Expand Up @@ -360,13 +358,9 @@ def update(
self,
id: str,
values: Optional[List[float]] = None,
set_metadata: Optional[
Dict[str, Union[str, float, int, bool, List[int], List[float], List[str]]]
] = None,
set_metadata: Optional[VectorMetadataTypedDict] = None,
namespace: Optional[str] = None,
sparse_values: Optional[
Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
] = None,
sparse_values: Optional[Union[SparseValues, SparseVectorTypedDict]] = None,
**kwargs,
) -> Dict[str, Any]:
return self._vector_api.update_vector(
Expand Down
15 changes: 5 additions & 10 deletions pinecone/data/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
)
from .query_results_aggregator import QueryNamespacesResults
from multiprocessing.pool import ApplyResult
from .types import VectorTypedDict, SparseVectorTypedDict, VectorMetadataTypedDict


class IndexInterface(ABC):
@abstractmethod
def upsert(
self,
vectors: Union[List[Vector], List[tuple], List[dict]],
vectors: Union[List[Vector], List[tuple], List[VectorTypedDict]],
namespace: Optional[str] = None,
batch_size: Optional[int] = None,
show_progress: bool = True,
Expand Down Expand Up @@ -229,9 +230,7 @@ def query_namespaces(
filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
include_values: Optional[bool] = None,
include_metadata: Optional[bool] = None,
sparse_vector: Optional[
Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
] = None,
sparse_vector: Optional[Union[SparseValues, SparseVectorTypedDict]] = None,
**kwargs,
) -> QueryNamespacesResults:
"""The query_namespaces() method is used to make a query to multiple namespaces in parallel and combine the results into one result set.
Expand Down Expand Up @@ -285,13 +284,9 @@ def update(
self,
id: str,
values: Optional[List[float]] = None,
set_metadata: Optional[
Dict[str, Union[str, float, int, bool, List[int], List[float], List[str]]]
] = None,
set_metadata: Optional[VectorMetadataTypedDict] = None,
namespace: Optional[str] = None,
sparse_values: Optional[
Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
] = None,
sparse_values: Optional[Union[SparseValues, SparseVectorTypedDict]] = None,
**kwargs,
) -> Dict[str, Any]:
"""
Expand Down
20 changes: 7 additions & 13 deletions pinecone/data/request_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,21 @@
from pinecone.core.openapi.db_data.models import (
QueryRequest,
UpsertRequest,
Vector,
DeleteRequest,
UpdateRequest,
DescribeIndexStatsRequest,
SparseValues,
)
from ..utils import parse_non_empty_args
from .vector_factory import VectorFactory
from pinecone.openapi_support import OPENAPI_ENDPOINT_PARAMS
from .types import VectorTypedDict, SparseVectorTypedDict, VectorMetadataTypedDict, VectorTuple
from .dataclasses import Vector, SparseValues

logger = logging.getLogger(__name__)


def parse_sparse_values_arg(
sparse_values: Optional[Union[SparseValues, Dict[str, Union[List[float], List[int]]]]],
sparse_values: Optional[Union[SparseValues, SparseVectorTypedDict]],
) -> Optional[SparseValues]:
if sparse_values is None:
return None
Expand Down Expand Up @@ -53,9 +53,7 @@ def query_request(
filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None,
include_values: Optional[bool] = None,
include_metadata: Optional[bool] = None,
sparse_vector: Optional[
Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
] = None,
sparse_vector: Optional[Union[SparseValues, SparseVectorTypedDict]] = None,
**kwargs,
) -> QueryRequest:
if vector is not None and id is not None:
Expand All @@ -82,7 +80,7 @@ def query_request(

@staticmethod
def upsert_request(
vectors: Union[List[Vector], List[tuple], List[dict]],
vectors: Union[List[Vector], List[VectorTuple], List[VectorTypedDict]],
namespace: Optional[str],
_check_type: bool,
**kwargs,
Expand Down Expand Up @@ -117,13 +115,9 @@ def delete_request(
def update_request(
id: str,
values: Optional[List[float]] = None,
set_metadata: Optional[
Dict[str, Union[str, float, int, bool, List[int], List[float], List[str]]]
] = None,
set_metadata: Optional[VectorMetadataTypedDict] = None,
namespace: Optional[str] = None,
sparse_values: Optional[
Union[SparseValues, Dict[str, Union[List[float], List[int]]]]
] = None,
sparse_values: Optional[Union[SparseValues, SparseVectorTypedDict]] = None,
**kwargs,
) -> UpdateRequest:
_check_type = kwargs.pop("_check_type", False)
Expand Down
6 changes: 4 additions & 2 deletions pinecone/data/sparse_values_factory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Mapping
from typing import Union, Dict
from typing import Union, Dict, Optional

from ..utils import convert_to_list

Expand All @@ -17,7 +17,9 @@ class SparseValuesFactory:
"""SparseValuesFactory is used to convert various types of user input into SparseValues objects used in generated request code."""

@staticmethod
def build(input: Union[Dict, SparseValues]) -> OpenApiSparseValues:
def build(
input: Union[Dict, Optional[SparseValues], OpenApiSparseValues],
) -> Optional[OpenApiSparseValues]:
if input is None:
return input
if isinstance(input, OpenApiSparseValues):
Expand Down
4 changes: 4 additions & 0 deletions pinecone/data/types/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .sparse_vector_typed_dict import SparseVectorTypedDict
from .vector_typed_dict import VectorTypedDict
from .vector_metadata_dict import VectorMetadataTypedDict
from .vector_tuple import VectorTuple
6 changes: 6 additions & 0 deletions pinecone/data/types/sparse_vector_typed_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from typing import TypedDict, List


class SparseVectorTypedDict(TypedDict):
    """Dictionary form of a sparse vector.

    Both keys are required (the TypedDict is total). It carries the same two
    fields as the ``SparseValues`` dataclass, whose ``to_dict`` returns this
    type and whose ``from_dict`` accepts it.
    """

    # Integer positions; presumably the dimensions holding non-zero entries.
    indices: List[int]
    # Float values; presumably values[i] belongs to indices[i] -- the type
    # itself does not enforce matching lengths.
    values: List[float]
4 changes: 4 additions & 0 deletions pinecone/data/types/vector_metadata_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from typing import Dict, List, Union

# Value types allowed for a single metadata entry: a scalar string or number,
# or a list whose elements are all strings, all ints, or all floats.
VectorDictMetadataValue = Union[
    str,
    int,
    float,
    List[str],
    List[int],
    List[float],
]

# Metadata mapping from string keys to the value types above. Despite its name
# this is a plain ``Dict`` alias, not a ``TypedDict``.
VectorMetadataTypedDict = Dict[str, VectorDictMetadataValue]
3 changes: 3 additions & 0 deletions pinecone/data/types/vector_tuple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from typing import Tuple, List

# Tuple form of a vector accepted by ``upsert``: a string followed by a list
# of floats -- presumably (id, values), mirroring the first two fields of the
# ``Vector`` dataclass. NOTE(review): a longer tuple (e.g. one that also
# carries metadata) does not match this alias -- confirm whether that is
# intended.
VectorTuple = Tuple[str, List[float]]
9 changes: 9 additions & 0 deletions pinecone/data/types/vector_typed_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .sparse_vector_typed_dict import SparseVectorTypedDict
from typing import TypedDict, List


class VectorTypedDict(TypedDict, total=False):
    """Dictionary form of a vector record.

    Declared with ``total=False``, so a type checker treats every key as
    optional. ``Vector.from_dict`` nevertheless reads ``id`` and ``values``
    by subscript, so those two keys must be present for that conversion;
    ``metadata`` and ``sparse_values`` are read with ``.get`` and may be
    omitted.
    """

    values: List[float]
    # NOTE(review): a plain ``dict`` here, while ``Vector.metadata`` is
    # annotated with ``VectorMetadataTypedDict`` -- confirm whether the looser
    # type is intentional.
    metadata: dict
    sparse_values: SparseVectorTypedDict
    id: str
8 changes: 2 additions & 6 deletions pinecone/grpc/index_grpc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Optional, Dict, Union, List, Tuple, Any, TypedDict, Iterable, cast, Literal
from typing import Optional, Dict, Union, List, Tuple, Any, Iterable, cast, Literal

from google.protobuf import json_format

Expand Down Expand Up @@ -42,18 +42,14 @@
from pinecone.core.grpc.protos.db_data_2025_01_pb2_grpc import VectorServiceStub
from .base import GRPCIndexBase
from .future import PineconeGrpcFuture
from ..data.types import SparseVectorTypedDict


__all__ = ["GRPCIndex", "GRPCVector", "GRPCQueryVector", "GRPCSparseValues"]

_logger = logging.getLogger(__name__)


class SparseVectorTypedDict(TypedDict):
indices: List[int]
values: List[float]


class GRPCIndex(GRPCIndexBase):
"""A client for interacting with a Pinecone index via GRPC API."""

Expand Down
Loading

0 comments on commit c58cba0

Please sign in to comment.