Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose missing configurations for grpc_config and pool_threads #296

Merged
merged 4 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions pinecone/control/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,10 @@ def __init__(
else:
self.config = PineconeConfig.build(api_key=api_key, host=host, additional_headers=additional_headers, **kwargs)

self.pool_threads = pool_threads
if index_api:
self.index_api = index_api
else:
self.pool_threads = pool_threads
api_client = ApiClient(configuration=self.config.openapi_config, pool_threads=self.pool_threads)
api_client.user_agent = get_user_agent()
extra_headers = self.config.additional_headers or {}
Expand Down Expand Up @@ -446,7 +446,7 @@ def _get_status(self, name: str):
response = api_instance.describe_index(name)
return response["status"]

def Index(self, name: str = '', host: str = ''):
def Index(self, name: str = '', host: str = '', **kwargs):
"""
Target an index for data operations.

Expand Down Expand Up @@ -518,12 +518,14 @@ def Index(self, name: str = '', host: str = ''):
"""
if name == '' and host == '':
raise ValueError("Either name or host must be specified")

pt = kwargs.pop('pool_threads', None) or self.pool_threads

if host != '':
# Use host url if it is provided
return Index(api_key=self.config.api_key, host=normalize_host(host), pool_threads=self.pool_threads)
return Index(api_key=self.config.api_key, host=normalize_host(host), pool_threads=pt, **kwargs)

if name != '':
# Otherwise, get host url from describe_index using the index name
index_host = self.index_host_store.get_host(self.index_api, self.config, name)
return Index(api_key=self.config.api_key, host=index_host, pool_threads=self.pool_threads)
return Index(api_key=self.config.api_key, host=index_host, pool_threads=pt, **kwargs)
1 change: 1 addition & 0 deletions pinecone/grpc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@

from .index_grpc import GRPCIndex
from .pinecone import PineconeGRPC
from .config import GRPCClientConfig

from pinecone.core.grpc.protos.vector_service_pb2 import (
Vector as GRPCVector,
Expand Down
2 changes: 1 addition & 1 deletion pinecone/grpc/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class GRPCClientConfig(NamedTuple):
GRPC client configuration options.

:param secure: Whether to use encrypted protocol (SSL). defaults to True.
:type traceroute: bool, optional
:type secure: bool, optional
:param timeout: defaults to 2 seconds. Fail if gateway doesn't receive response within timeout.
:type timeout: int, optional
:param conn_timeout: defaults to 1. Timeout to retry connection if gRPC is unavailable. 0 is no retry.
Expand Down
6 changes: 3 additions & 3 deletions pinecone/grpc/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class PineconeGRPC(Pinecone):

"""

def Index(self, name: str = '', host: str = ''):
def Index(self, name: str = '', host: str = '', **kwargs):
"""
Target an index for data operations.

Expand Down Expand Up @@ -123,10 +123,10 @@ def Index(self, name: str = '', host: str = ''):
if host != '':
# Use host if it is provided
config = ConfigBuilder.build(api_key=self.config.api_key, host=host)
return GRPCIndex(index_name=name, config=config)
return GRPCIndex(index_name=name, config=config, **kwargs)

if name != '':
# Otherwise, get host url from describe_index using the index name
index_host = self.index_host_store.get_host(self.index_api, self.config, name)
config = ConfigBuilder.build(api_key=self.config.api_key, host=index_host)
return GRPCIndex(index_name=name, config=config)
return GRPCIndex(index_name=name, config=config, **kwargs)
26 changes: 25 additions & 1 deletion tests/unit/test_control.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from pinecone import Pinecone, PodSpec, ServerlessSpec
from pinecone.core.client.models import IndexList, IndexModel
from pinecone.core.client.api.manage_indexes_api import ManageIndexesApi
import time

@pytest.fixture
Expand Down Expand Up @@ -67,4 +68,27 @@ def test_list_indexes_returns_iterable(self, mocker, index_list_response):
mocker.patch.object(p.index_api, 'list_indexes', side_effect=[index_list_response])

response = p.list_indexes()
assert [i.name for i in response] == ["index1", "index2", "index3"]
assert [i.name for i in response] == ["index1", "index2", "index3"]


class TestIndexConfig:
def test_default_pool_threads(self):
pc = Pinecone(api_key="123-456-789")
index = pc.Index(host='my-host.svg.pinecone.io')
assert index._api_client.pool_threads == 1

def test_pool_threads_when_indexapi_passed(self):
pc = Pinecone(api_key="123-456-789", pool_threads=2, index_api=ManageIndexesApi())
index = pc.Index(host='my-host.svg.pinecone.io')
assert index._api_client.pool_threads == 2

def test_target_index_with_pool_threads_inherited(self):
pc = Pinecone(api_key="123-456-789", pool_threads=10, foo='bar')
index = pc.Index(host='my-host.svg.pinecone.io')
assert index._api_client.pool_threads == 10

def test_target_index_with_pool_threads_kwarg(self):
pc = Pinecone(api_key="123-456-789", pool_threads=10)
index = pc.Index(host='my-host.svg.pinecone.io', pool_threads=5)
assert index._api_client.pool_threads == 5

66 changes: 66 additions & 0 deletions tests/unit_grpc/test_grpc_index_initialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from pinecone.grpc import PineconeGRPC, GRPCClientConfig

class TestGRPCIndexInitialization:
def test_init_with_default_config(self):
pc = PineconeGRPC(api_key='YOUR_API_KEY')
index = pc.Index(name='my-index', host='host')

assert index.grpc_client_config.secure == True
assert index.grpc_client_config.timeout == 20
assert index.grpc_client_config.conn_timeout == 1
assert index.grpc_client_config.reuse_channel == True
assert index.grpc_client_config.retry_config == None
assert index.grpc_client_config.grpc_channel_options == None

def test_init_with_grpc_config_from_dict(self):
pc = PineconeGRPC(api_key='YOUR_API_KEY')
config = GRPCClientConfig._from_dict({'timeout': 10})
index = pc.Index(name='my-index', host='host', grpc_config=config)

assert index.grpc_client_config.timeout == 10

# Unset fields still get default values
assert index.grpc_client_config.reuse_channel == True
assert index.grpc_client_config.secure == True


def test_init_with_grpc_config_non_dict(self):
pc = PineconeGRPC(api_key='YOUR_API_KEY')
config = GRPCClientConfig(timeout=10, secure=False)
index = pc.Index(name='my-index', host='host', grpc_config=config)

assert index.grpc_client_config.timeout == 10
assert index.grpc_client_config.secure == False

# Unset fields still get default values
assert index.grpc_client_config.reuse_channel == True
assert index.grpc_client_config.conn_timeout == 1

def test_config_passed_when_target_by_name(self):
pc = PineconeGRPC(api_key='YOUR_API_KEY')

# Set this state in the host store to skip network call
# to find host for name
pc.index_host_store.set_host(pc.config, 'my-index', 'myhost')

config = GRPCClientConfig(timeout=10, secure=False)
index = pc.Index(name='my-index', grpc_config=config)

assert index.grpc_client_config.timeout == 10
assert index.grpc_client_config.secure == False

# Unset fields still get default values
assert index.grpc_client_config.reuse_channel == True
assert index.grpc_client_config.conn_timeout == 1

def test_config_passed_when_target_by_host(self):
pc = PineconeGRPC(api_key='YOUR_API_KEY')
config = GRPCClientConfig(timeout=5, secure=True)
index = pc.Index(host='myhost', grpc_config=config)

assert index.grpc_client_config.timeout == 5
assert index.grpc_client_config.secure == True

# Unset fields still get default values
assert index.grpc_client_config.reuse_channel == True
assert index.grpc_client_config.conn_timeout == 1