diff --git a/src/marqo/s2_inference/processing/custom_clip_utils.py b/src/marqo/s2_inference/processing/custom_clip_utils.py index 32723f520..10f57a89e 100644 --- a/src/marqo/s2_inference/processing/custom_clip_utils.py +++ b/src/marqo/s2_inference/processing/custom_clip_utils.py @@ -6,7 +6,7 @@ import os import urllib from tqdm import tqdm -from src.marqo.s2_inference.configs import ModelCache +from marqo.s2_inference.configs import ModelCache def whitespace_clean(text): text = re.sub(r'\s+', ' ', text) text = text.strip() diff --git a/src/marqo/s2_inference/processing/image.py b/src/marqo/s2_inference/processing/image.py index dd0b1050e..95bd9b157 100644 --- a/src/marqo/s2_inference/processing/image.py +++ b/src/marqo/s2_inference/processing/image.py @@ -5,7 +5,7 @@ import torch import torchvision -from marqo.s2_inference.s2_inference import available_models +from marqo.s2_inference.s2_inference import available_models,_create_model_cache_key from marqo.s2_inference.s2_inference import get_logger from marqo.s2_inference.types import Dict, List, Union, ImageType, Tuple, ndarray, Literal from marqo.s2_inference.clip_utils import format_and_load_CLIP_image @@ -200,8 +200,9 @@ def _get_model_specific_parameters(self): def _load_and_cache_model(self): model_type = (self.model_name, self.device) + model_cache_key = _create_model_cache_key(self.model_name, self.device) - if model_type not in available_models: + if model_cache_key not in available_models: logger.info(f"loading model {model_type}") if model_type[0] in self.allowed_model_types: func = self.model_load_function @@ -210,9 +211,9 @@ def _load_and_cache_model(self): self.model, self.preprocess = func(self.model_name, self.device) - available_models[model_type] = (self.model, self.preprocess) + available_models[model_cache_key] = self.model, self.preprocess else: - self.model, self.preprocess = available_models[model_type] + self.model, self.preprocess = available_models[model_cache_key] def _load_image(self, image): self.image, self.image_pt, self.original_size = load_rcnn_image(image, size=self.size) diff --git a/src/marqo/s2_inference/reranking/model_utils.py b/src/marqo/s2_inference/reranking/model_utils.py index 9c2890073..6ff972bde 100644 --- a/src/marqo/s2_inference/reranking/model_utils.py +++ b/src/marqo/s2_inference/reranking/model_utils.py @@ -309,7 +309,7 @@ def load_owl_vit(model_name: str, device: str = 'cpu') -> Dict: Dict: _description_ """ - model_cache_key = (model_name, device) + model_cache_key = _create_model_cache_key(model_name, device) if model_cache_key in available_models: logger.info(f"loading {model_cache_key} from cache...") diff --git a/src/marqo/s2_inference/s2_inference.py b/src/marqo/s2_inference/s2_inference.py index 5d51645f2..e2b462110 100644 --- a/src/marqo/s2_inference/s2_inference.py +++ b/src/marqo/s2_inference/s2_inference.py @@ -66,7 +66,7 @@ def _create_model_cache_key(model_name: str, device: str, model_properties: dict if model_properties is None: model_properties = dict() - model_cache_key = (model_name + "||" + + model_cache_key = (model_name + "||" + model_properties.get('name', '') + "||" + str(model_properties.get('dimensions', '')) + "||" + model_properties.get('type', '') + "||" + @@ -318,9 +318,12 @@ def eject_model(model_name:str, device:str): # we can't handle the situation where there are two models with the same name and device # but different properties. for key in model_cache_keys: - if key.startswith(model_name) and key.endswith(device): - model_cache_key = key - break + if isinstance(key, str): + if key.startswith(model_name) and key.endswith(device): + model_cache_key = key + break + else: + continue if model_cache_key is None: raise ModelNotInCacheError(f"The model_name `{model_name}` device `{device}` is not cached or found") diff --git a/src/marqo/tensor_search/tensor_search.py b/src/marqo/tensor_search/tensor_search.py index 1f0371d7b..111dad7a6 100644 --- a/src/marqo/tensor_search/tensor_search.py +++ b/src/marqo/tensor_search/tensor_search.py @@ -1368,13 +1368,14 @@ def _get_model_properties(index_info): return model_properties + def get_loaded_models() -> dict: available_models = s2_inference.get_available_models() - message = { - "models" : [ - {"model_name": ix.split("||")[0], "model_device": ix.split("||")[-1]} for ix in available_models.keys() - ] - } + message = {"models":[]} + + for ix in available_models: + if isinstance(ix, str): + message["models"].append({"model_name": ix.split("||")[0], "model_device": ix.split("||")[-1]}) return message diff --git a/tests/tensor_search/test_model_cache_management.py b/tests/tensor_search/test_model_cache_management.py index dbe9f7494..01566f765 100644 --- a/tests/tensor_search/test_model_cache_management.py +++ b/tests/tensor_search/test_model_cache_management.py @@ -4,7 +4,11 @@ _create_model_cache_key, _update_available_models, available_models, clear_loaded_models from marqo.tensor_search.tensor_search import eject_model, get_cuda_info, get_loaded_models, get_cpu_info from marqo.errors import ModelNotInCacheError, HardwareCompatabilityError +from marqo.s2_inference.reranking.cross_encoders import ReRankerText, ReRankerOwl +from marqo.s2_inference.reranking.model_utils import load_owl_vit +from marqo.s2_inference.reranking import rerank import psutil +from marqo.tensor_search import tensor_search @@ -15,7 +19,6 @@ def load_model(model_name: str, device: str, model_properteis: dict = None) -> N class TestModelCacheManagement(MarqoTestCase): - def setUp(self) -> None: # We pre-define 3 dummy models for testing purpose self.MODEL_1 = "ViT-B/32" @@ -259,6 +262,22 @@ def test_overall_eject_and_load_model(self): raise AssertionError + def test_model_cache_management_with_text_reranker(self): + model_name = 'google/owlvit-base-patch32' + + _ = load_owl_vit('google/owlvit-base-patch32', "cpu") + model_cache_key = _create_model_cache_key(model_name, "cpu", model_properties=None) + + if model_cache_key not in available_models: + raise AssertionError + + eject_model(model_name, "cpu") + if model_cache_key in available_models: + raise AssertionError + + + +