diff --git a/pinecone/control/pinecone.py b/pinecone/control/pinecone.py index ea14a7cc..1d315494 100644 --- a/pinecone/control/pinecone.py +++ b/pinecone/control/pinecone.py @@ -94,10 +94,10 @@ def __init__( else: self.config = PineconeConfig.build(api_key=api_key, host=host, additional_headers=additional_headers, **kwargs) + self.pool_threads = pool_threads if index_api: self.index_api = index_api else: - self.pool_threads = pool_threads api_client = ApiClient(configuration=self.config.openapi_config, pool_threads=self.pool_threads) api_client.user_agent = get_user_agent() extra_headers = self.config.additional_headers or {} @@ -446,7 +446,7 @@ def _get_status(self, name: str): response = api_instance.describe_index(name) return response["status"] - def Index(self, name: str = '', host: str = ''): + def Index(self, name: str = '', host: str = '', **kwargs): """ Target an index for data operations. @@ -518,12 +518,14 @@ def Index(self, name: str = '', host: str = ''): """ if name == '' and host == '': raise ValueError("Either name or host must be specified") + + pt = kwargs.pop('pool_threads', None) or self.pool_threads if host != '': # Use host url if it is provided - return Index(api_key=self.config.api_key, host=normalize_host(host), pool_threads=self.pool_threads) + return Index(api_key=self.config.api_key, host=normalize_host(host), pool_threads=pt, **kwargs) if name != '': # Otherwise, get host url from describe_index using the index name index_host = self.index_host_store.get_host(self.index_api, self.config, name) - return Index(api_key=self.config.api_key, host=index_host, pool_threads=self.pool_threads) + return Index(api_key=self.config.api_key, host=index_host, pool_threads=pt, **kwargs) diff --git a/pinecone/grpc/__init__.py b/pinecone/grpc/__init__.py index 67338779..0a7670ef 100644 --- a/pinecone/grpc/__init__.py +++ b/pinecone/grpc/__init__.py @@ -46,6 +46,7 @@ from .index_grpc import GRPCIndex from .pinecone import PineconeGRPC +from .config import GRPCClientConfig from pinecone.core.grpc.protos.vector_service_pb2 import ( Vector as GRPCVector, diff --git a/pinecone/grpc/config.py b/pinecone/grpc/config.py index 1a58cd6a..b34c885e 100644 --- a/pinecone/grpc/config.py +++ b/pinecone/grpc/config.py @@ -7,7 +7,7 @@ class GRPCClientConfig(NamedTuple): GRPC client configuration options. :param secure: Whether to use encrypted protocol (SSL). defaults to True. - :type traceroute: bool, optional + :type secure: bool, optional :param timeout: defaults to 2 seconds. Fail if gateway doesn't receive response within timeout. :type timeout: int, optional :param conn_timeout: defaults to 1. Timeout to retry connection if gRPC is unavailable. 0 is no retry. diff --git a/pinecone/grpc/pinecone.py b/pinecone/grpc/pinecone.py index bd68360b..c7141d79 100644 --- a/pinecone/grpc/pinecone.py +++ b/pinecone/grpc/pinecone.py @@ -46,7 +46,7 @@ class PineconeGRPC(Pinecone): """ - def Index(self, name: str = '', host: str = ''): + def Index(self, name: str = '', host: str = '', **kwargs): """ Target an index for data operations. @@ -123,10 +123,10 @@ def Index(self, name: str = '', host: str = ''): if host != '': # Use host if it is provided config = ConfigBuilder.build(api_key=self.config.api_key, host=host) - return GRPCIndex(index_name=name, config=config) + return GRPCIndex(index_name=name, config=config, **kwargs) if name != '': # Otherwise, get host url from describe_index using the index name index_host = self.index_host_store.get_host(self.index_api, self.config, name) config = ConfigBuilder.build(api_key=self.config.api_key, host=index_host) - return GRPCIndex(index_name=name, config=config) \ No newline at end of file + return GRPCIndex(index_name=name, config=config, **kwargs) \ No newline at end of file diff --git a/tests/unit/test_control.py b/tests/unit/test_control.py index 650f6d9d..2c3ed4da 100644 --- a/tests/unit/test_control.py +++ b/tests/unit/test_control.py @@ -1,6 +1,7 @@ import pytest from pinecone import Pinecone, PodSpec, ServerlessSpec from pinecone.core.client.models import IndexList, IndexModel +from pinecone.core.client.api.manage_indexes_api import ManageIndexesApi import time @pytest.fixture @@ -67,4 +68,27 @@ def test_list_indexes_returns_iterable(self, mocker, index_list_response): mocker.patch.object(p.index_api, 'list_indexes', side_effect=[index_list_response]) response = p.list_indexes() - assert [i.name for i in response] == ["index1", "index2", "index3"] \ No newline at end of file + assert [i.name for i in response] == ["index1", "index2", "index3"] + + +class TestIndexConfig: + def test_default_pool_threads(self): + pc = Pinecone(api_key="123-456-789") + index = pc.Index(host='my-host.svg.pinecone.io') + assert index._api_client.pool_threads == 1 + + def test_pool_threads_when_indexapi_passed(self): + pc = Pinecone(api_key="123-456-789", pool_threads=2, index_api=ManageIndexesApi()) + index = pc.Index(host='my-host.svg.pinecone.io') + assert index._api_client.pool_threads == 2 + + def test_target_index_with_pool_threads_inherited(self): + pc = Pinecone(api_key="123-456-789", pool_threads=10, foo='bar') + index = pc.Index(host='my-host.svg.pinecone.io') + assert index._api_client.pool_threads == 10 + + def test_target_index_with_pool_threads_kwarg(self): + pc = Pinecone(api_key="123-456-789", pool_threads=10) + index = pc.Index(host='my-host.svg.pinecone.io', pool_threads=5) + assert index._api_client.pool_threads == 5 + diff --git a/tests/unit_grpc/test_grpc_index_initialization.py b/tests/unit_grpc/test_grpc_index_initialization.py new file mode 100644 index 00000000..66226ea4 --- /dev/null +++ b/tests/unit_grpc/test_grpc_index_initialization.py @@ -0,0 +1,66 @@ +from pinecone.grpc import PineconeGRPC, GRPCClientConfig + +class TestGRPCIndexInitialization: + def test_init_with_default_config(self): + pc = PineconeGRPC(api_key='YOUR_API_KEY') + index = pc.Index(name='my-index', host='host') + + assert index.grpc_client_config.secure == True + assert index.grpc_client_config.timeout == 20 + assert index.grpc_client_config.conn_timeout == 1 + assert index.grpc_client_config.reuse_channel == True + assert index.grpc_client_config.retry_config == None + assert index.grpc_client_config.grpc_channel_options == None + + def test_init_with_grpc_config_from_dict(self): + pc = PineconeGRPC(api_key='YOUR_API_KEY') + config = GRPCClientConfig._from_dict({'timeout': 10}) + index = pc.Index(name='my-index', host='host', grpc_config=config) + + assert index.grpc_client_config.timeout == 10 + + # Unset fields still get default values + assert index.grpc_client_config.reuse_channel == True + assert index.grpc_client_config.secure == True + + + def test_init_with_grpc_config_non_dict(self): + pc = PineconeGRPC(api_key='YOUR_API_KEY') + config = GRPCClientConfig(timeout=10, secure=False) + index = pc.Index(name='my-index', host='host', grpc_config=config) + + assert index.grpc_client_config.timeout == 10 + assert index.grpc_client_config.secure == False + + # Unset fields still get default values + assert index.grpc_client_config.reuse_channel == True + assert index.grpc_client_config.conn_timeout == 1 + + def test_config_passed_when_target_by_name(self): + pc = PineconeGRPC(api_key='YOUR_API_KEY') + + # Set this state in the host store to skip network call + # to find host for name + pc.index_host_store.set_host(pc.config, 'my-index', 'myhost') + + config = GRPCClientConfig(timeout=10, secure=False) + index = pc.Index(name='my-index', grpc_config=config) + + assert index.grpc_client_config.timeout == 10 + assert index.grpc_client_config.secure == False + + # Unset fields still get default values + assert index.grpc_client_config.reuse_channel == True + assert index.grpc_client_config.conn_timeout == 1 + + def test_config_passed_when_target_by_host(self): + pc = PineconeGRPC(api_key='YOUR_API_KEY') + config = GRPCClientConfig(timeout=5, secure=True) + index = pc.Index(host='myhost', grpc_config=config) + + assert index.grpc_client_config.timeout == 5 + assert index.grpc_client_config.secure == True + + # Unset fields still get default values + assert index.grpc_client_config.reuse_channel == True + assert index.grpc_client_config.conn_timeout == 1