Merge pull request #308 from marqo-ai/bug-fix
[Bug fix] Model Cache Management for Reranker Models
wanliAlex authored Feb 9, 2023
2 parents e04d484 + e36719b commit ddec690
Showing 6 changed files with 40 additions and 16 deletions.
src/marqo/s2_inference/processing/custom_clip_utils.py (1 addition, 1 deletion)
@@ -6,7 +6,7 @@
import os
import urllib
from tqdm import tqdm
-from src.marqo.s2_inference.configs import ModelCache
+from marqo.s2_inference.configs import ModelCache
def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
src/marqo/s2_inference/processing/image.py (5 additions, 4 deletions)
@@ -5,7 +5,7 @@
import torch
import torchvision

-from marqo.s2_inference.s2_inference import available_models
+from marqo.s2_inference.s2_inference import available_models,_create_model_cache_key
from marqo.s2_inference.s2_inference import get_logger
from marqo.s2_inference.types import Dict, List, Union, ImageType, Tuple, ndarray, Literal
from marqo.s2_inference.clip_utils import format_and_load_CLIP_image
@@ -200,8 +200,9 @@ def _get_model_specific_parameters(self):

    def _load_and_cache_model(self):
        model_type = (self.model_name, self.device)
+        model_cache_key = _create_model_cache_key(self.model_name, self.device)

-        if model_type not in available_models:
+        if model_cache_key not in available_models:
            logger.info(f"loading model {model_type}")
            if model_type[0] in self.allowed_model_types:
                func = self.model_load_function
@@ -210,9 +211,9 @@ def _load_and_cache_model(self):

            self.model, self.preprocess = func(self.model_name, self.device)

-            available_models[model_type] = (self.model, self.preprocess)
+            available_models[model_cache_key] = self.model, self.preprocess
        else:
-            self.model, self.preprocess = available_models[model_type]
+            self.model, self.preprocess = available_models[model_cache_key]

    def _load_image(self, image):
        self.image, self.image_pt, self.original_size = load_rcnn_image(image, size=self.size)
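
For orientation, here is a minimal, self-contained sketch of the load-or-reuse pattern that the hunks above now key by string instead of by a (model_name, device) tuple. The names available_models, make_cache_key and load_and_cache below are stand-ins rather than marqo internals, and the key shape is an assumption:

available_models = {}

def make_cache_key(model_name: str, device: str) -> str:
    # Assumed shape only: the real key also embeds model_properties fields between "||" separators.
    return model_name + "||" + device

def load_and_cache(model_name: str, device: str, loader):
    key = make_cache_key(model_name, device)
    if key not in available_models:
        available_models[key] = loader(model_name, device)  # load once per (model, device)
    return available_models[key]                             # later calls reuse the cached object

model = load_and_cache("ViT-B/32", "cpu", lambda name, device: f"<{name} on {device}>")
print(model)
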
src/marqo/s2_inference/reranking/model_utils.py (1 addition, 1 deletion)
@@ -309,7 +309,7 @@ def load_owl_vit(model_name: str, device: str = 'cpu') -> Dict:
        Dict: _description_
    """

-    model_cache_key = (model_name, device)
+    model_cache_key = _create_model_cache_key(model_name, device)

    if model_cache_key in available_models:
        logger.info(f"loading {model_cache_key} from cache...")
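
This one-line change looks like the heart of the fix: the reranker model was previously cached under a (model_name, device) tuple, which the string-based cache management shown below cannot match. A tiny illustration of the mismatch, using hypothetical key values:

tuple_key = ("google/owlvit-base-patch32", "cpu")        # old-style key written by load_owl_vit
string_key = "google/owlvit-base-patch32||...||cpu"       # assumed shape of a new-style key (middle fields elided)

print(string_key.startswith("google/owlvit-base-patch32"), string_key.endswith("cpu"))  # True True
print(hasattr(tuple_key, "startswith"))  # False: a tuple key cannot be matched this way
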
src/marqo/s2_inference/s2_inference.py (7 additions, 4 deletions)
@@ -66,7 +66,7 @@ def _create_model_cache_key(model_name: str, device: str, model_properties: dict
    if model_properties is None:
        model_properties = dict()

-    model_cache_key = (model_name + "||" +
+    model_cache_key = (model_name + "||" +
                       model_properties.get('name', '') + "||" +
                       str(model_properties.get('dimensions', '')) + "||" +
                       model_properties.get('type', '') + "||" +
@@ -318,9 +318,12 @@ def eject_model(model_name:str, device:str):
    # we can't handle the situation where there are two models with the same name and device
    # but different properties.
    for key in model_cache_keys:
-        if key.startswith(model_name) and key.endswith(device):
-            model_cache_key = key
-            break
+        if isinstance(key, str):
+            if key.startswith(model_name) and key.endswith(device):
+                model_cache_key = key
+                break
+        else:
+            continue

    if model_cache_key is None:
        raise ModelNotInCacheError(f"The model_name `{model_name}` device `{device}` is not cached or found")
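
Taken together, the two hunks above indicate that a cache key is the model name, selected model_properties fields and, last, the device, joined by "||", and that eject_model now skips any cache entry whose key is not a string rather than raising on it. A hedged sketch of that matching logic; make_cache_key here only approximates the real field layout and is not the library function:

def make_cache_key(model_name: str, device: str, model_properties: dict = None) -> str:
    props = model_properties or {}
    return (model_name + "||" +
            props.get("name", "") + "||" +
            str(props.get("dimensions", "")) + "||" +
            props.get("type", "") + "||" +
            device)

cache = {
    make_cache_key("ViT-B/32", "cpu"): "model-a",
    ("legacy-model", "cpu"): "model-b",   # a leftover tuple key is simply skipped
}

def find_key(model_name: str, device: str):
    for key in cache:
        if isinstance(key, str):
            if key.startswith(model_name) and key.endswith(device):
                return key
    return None

print(find_key("ViT-B/32", "cpu"))      # the matching string key
print(find_key("legacy-model", "cpu"))  # None: tuple-keyed entries are never matched
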
src/marqo/tensor_search/tensor_search.py (6 additions, 5 deletions)
@@ -1368,13 +1368,14 @@ def _get_model_properties(index_info):

    return model_properties

+
def get_loaded_models() -> dict:
    available_models = s2_inference.get_available_models()
-    message = {
-        "models" : [
-            {"model_name": ix.split("||")[0], "model_device": ix.split("||")[-1]} for ix in available_models.keys()
-        ]
-    }
+    message = {"models":[]}
+
+    for ix in available_models:
+        if isinstance(ix, str):
+            message["models"].append({"model_name": ix.split("||")[0], "model_device": ix.split("||")[-1]})
    return message

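
get_loaded_models now builds the response incrementally and only parses string keys, so a stray non-string key can no longer break the listing. A short sketch of the same filtering over stand-in cache contents (the keys below are hypothetical, not values taken from marqo):

available_models = {
    "ViT-B/32||||||cpu": "model-a",                     # hypothetical string key
    ("google/owlvit-base-patch32", "cpu"): "model-b",   # legacy tuple key is ignored
}

message = {"models": []}
for ix in available_models:
    if isinstance(ix, str):
        message["models"].append({"model_name": ix.split("||")[0],
                                  "model_device": ix.split("||")[-1]})

print(message)  # {'models': [{'model_name': 'ViT-B/32', 'model_device': 'cpu'}]}
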
tests/tensor_search/test_model_cache_management.py (20 additions, 1 deletion)
@@ -4,7 +4,11 @@
_create_model_cache_key, _update_available_models, available_models, clear_loaded_models
from marqo.tensor_search.tensor_search import eject_model, get_cuda_info, get_loaded_models, get_cpu_info
from marqo.errors import ModelNotInCacheError, HardwareCompatabilityError
from marqo.s2_inference.reranking.cross_encoders import ReRankerText, ReRankerOwl
from marqo.s2_inference.reranking.model_utils import load_owl_vit
from marqo.s2_inference.reranking import rerank
import psutil
from marqo.tensor_search import tensor_search



@@ -15,7 +19,6 @@ def load_model(model_name: str, device: str, model_properteis: dict = None) -> N


class TestModelCacheManagement(MarqoTestCase):

    def setUp(self) -> None:
        # We pre-define 3 dummy models for testing purpose
        self.MODEL_1 = "ViT-B/32"
Expand Down Expand Up @@ -259,6 +262,22 @@ def test_overall_eject_and_load_model(self):
raise AssertionError


def test_model_cache_management_with_text_reranker(self):
model_name = 'google/owlvit-base-patch32'

_ = load_owl_vit('google/owlvit-base-patch32', "cpu")
model_cache_key = _create_model_cache_key(model_name, "cpu", model_properties=None)

if model_cache_key not in available_models:
raise AssertionError

eject_model(model_name, "cpu")
if model_cache_key in available_models:
raise AssertionError

