Skip to content

Commit

Permalink
Deletion optimisation (#436)
Browse files Browse the repository at this point in the history
* Refactored delete documents into its own file

Ran tensor search unit tests - passed (besides a randomly failing one)

* Refactored a deletion interface. At parity, in terms of delete tests

* swapped delete by query with bulk delete. untested

* combined components together

* used existing tensor_search entrypoint function for minimal interface disruption

* added label for data-layer agnostic logic

* added more tests

* Overwrite files from mainline

* Overwrite files from mainline

* added tests for config.backend

* added env var for delete docs request

* added tests for read_env_vars_and_defaults_ints

* fixed read_env_vars_and_defaults, added mock environ test

* standardised the call from api.delete_docs to tensor_search.delete_documents
  • Loading branch information
pandu-k authored Apr 20, 2023
1 parent 988ad9e commit 7d571c2
Show file tree
Hide file tree
Showing 13 changed files with 728 additions and 68 deletions.
4 changes: 3 additions & 1 deletion src/marqo/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ def __init__(
url: str,
timeout: Optional[int] = None,
indexing_device: Optional[Union[enums.Device, str]] = None,
search_device: Optional[Union[enums.Device, str]] = None
search_device: Optional[Union[enums.Device, str]] = None,
backend: Optional[Union[enums.SearchDb, str]] = None,
) -> None:
"""
Parameters
Expand All @@ -23,6 +24,7 @@ def __init__(

self.indexing_device = indexing_device if indexing_device is not None else default_device
self.search_device = search_device if search_device is not None else default_device
self.backend = backend if backend is not None else enums.SearchDb.opensearch

def set_url(self, url):
"""Set the URL, and infers whether that url is remote"""
Expand Down
6 changes: 3 additions & 3 deletions src/marqo/tensor_search/api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
"""The API entrypoint for Tensor Search"""
import typing
from fastapi.responses import JSONResponse
from fastapi import FastAPI, Request, Depends, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi import Request, Depends
import marqo.tensor_search.delete_docs
import marqo.tensor_search.tensor_search
from marqo.errors import InvalidArgError, MarqoWebError, MarqoError
from fastapi import FastAPI, Query
import json
Expand All @@ -12,7 +13,6 @@
import os
from marqo.tensor_search.models.api_models import BulkSearchQuery, SearchQuery
from marqo.tensor_search.web import api_validation, api_utils
from marqo.tensor_search import utils
from marqo.tensor_search.on_start_script import on_start
from marqo import version
from marqo.tensor_search.enums import RequestType
Expand Down
8 changes: 3 additions & 5 deletions src/marqo/tensor_search/configs.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
import sys

from torch import multiprocessing as mp

from marqo.tensor_search import enums as ns_enums
from marqo.tensor_search.enums import IndexSettingsField as NsFields, EnvVars

Expand Down Expand Up @@ -62,5 +58,7 @@ def default_env_vars() -> dict:
EnvVars.MARQO_MAX_CUDA_MODEL_MEMORY: 4, # For multi-GPU, this is the max memory for each GPU.
EnvVars.MARQO_EF_CONSTRUCTION_MAX_VALUE: 4096,
EnvVars.MARQO_MAX_VECTORISE_BATCH_SIZE: 16,
EnvVars.MARQO_MAX_DELETE_DOCS_COUNT: 10000,
EnvVars.MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES: None
}
}

77 changes: 77 additions & 0 deletions src/marqo/tensor_search/delete_docs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""
This module handles the delete documents endpoint
"""
import datetime
import json
from marqo._httprequests import HttpRequests
from marqo.config import Config
from marqo.tensor_search import validation, utils, enums
from marqo.tensor_search.models.delete_docs_objects import MqDeleteDocsResponse, MqDeleteDocsRequest

# -- Marqo delete endpoint interface: --


def format_delete_docs_response(marqo_response: MqDeleteDocsResponse) -> dict:
"""This formats the delete response for users """
return {
"index_name": marqo_response.index_name, "status": marqo_response.status_string,
"type": "documentDeletion", "details": {
"receivedDocumentIds": len(marqo_response.document_ids),
"deletedDocuments": marqo_response.deleted_docments_count,
},
"duration": utils.create_duration_string(marqo_response.deletion_end - marqo_response.deletion_start),
"startedAt": utils.format_timestamp(marqo_response.deletion_start),
"finishedAt": utils.format_timestamp(marqo_response.deletion_end),
}


# -- Data-layer agnostic logic --


def delete_documents(config: Config, del_request: MqDeleteDocsRequest) -> dict:
"""entrypoint function for deleting documents"""

validation.validate_delete_docs_request(
delete_request=del_request,
max_delete_docs_count=utils.read_env_vars_and_defaults_ints(enums.EnvVars.MARQO_MAX_DELETE_DOCS_COUNT)
)

if config.backend == enums.SearchDb.opensearch:
del_response: MqDeleteDocsResponse = delete_documents_marqo_os(config=config, deletion_instruction=del_request)
else:
raise RuntimeError(f"Config set to use unknown backend `{config.backend}`. "
f"See tensor_search.enums.SearchDB for allowed backends")

return format_delete_docs_response(del_response)


# -- Marqo-OS-specific deletion implementation: --


def delete_documents_marqo_os(config: Config, deletion_instruction: MqDeleteDocsRequest) -> MqDeleteDocsResponse:
"""Deletes documents """

# Prepare bulk delete request body
bulk_request_body = ""
for doc_id in deletion_instruction.document_ids:
bulk_request_body += json.dumps({"delete": {"_index": deletion_instruction.index_name, "_id": doc_id}}) + "\n"

# Send bulk delete request
t0 = datetime.datetime.utcnow()
delete_res_backend = HttpRequests(config=config).post(
path="_bulk",
body=bulk_request_body,
)

if deletion_instruction.auto_refresh:
refresh_response = HttpRequests(config).post(path=f"{deletion_instruction.index_name}/_refresh")

t1 = datetime.datetime.utcnow()
deleted_documents_count = sum(1 for item in delete_res_backend["items"] if "delete" in item and item["delete"]["status"] == 200)

mq_delete_res = MqDeleteDocsResponse(
index_name=deletion_instruction.index_name, status_string='succeeded', document_ids=deletion_instruction.document_ids,
deleted_docments_count=deleted_documents_count, deletion_start=t0,
deletion_end=t1
)
return mq_delete_res
6 changes: 6 additions & 0 deletions src/marqo/tensor_search/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ class EnvVars:
MARQO_EF_CONSTRUCTION_MAX_VALUE = "MARQO_EF_CONSTRUCTION_MAX_VALUE"
MARQO_MAX_VECTORISE_BATCH_SIZE = "MARQO_MAX_VECTORISE_BATCH_SIZE"
MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES = "MARQO_MAX_SEARCHABLE_TENSOR_ATTRIBUTES"
MARQO_MAX_DELETE_DOCS_COUNT = "MARQO_MAX_DELETE_DOCS_COUNT"


class RequestType:
INDEX = "INDEX"
Expand All @@ -118,6 +120,10 @@ class MappingsObjectType:
multimodal_combination = "multimodal_combination"


class SearchDb:
opensearch = 'opensearch'


class AvailableModelsKey:
model = "model"
most_recently_used_time = "most_recently_used_time"
Expand Down
24 changes: 24 additions & 0 deletions src/marqo/tensor_search/models/delete_docs_objects.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
This module holds the classes which define the interface the delete documents
endpoint.
"""

import datetime
from typing import NamedTuple, Literal, List


class MqDeleteDocsResponse(NamedTuple):
"""An object that holds the data we send back to users"""
index_name: str
status_string: Literal["succeeded"]
document_ids: List[str]
deleted_docments_count: int
deletion_start: datetime.datetime
deletion_end: datetime.datetime


class MqDeleteDocsRequest(NamedTuple):
"""An object that holds the data from users for a delete request"""
index_name: str
document_ids: List[str]
auto_refresh: bool
56 changes: 16 additions & 40 deletions src/marqo/tensor_search/tensor_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,17 @@
"""
import copy
import json
import datetime
from collections import defaultdict
from timeit import default_timer as timer
import functools
import pprint
import typing
import uuid
from typing import List, Optional, Union, Iterable, Sequence, Dict, Any, Tuple, Set
from typing import List, Optional, Union, Iterable, Sequence, Dict, Any, Tuple
import numpy as np
from PIL import Image
import marqo.config as config
from marqo.tensor_search.models.delete_docs_objects import MqDeleteDocsRequest
from marqo.tensor_search.enums import (
MediaType, MlModel, TensorField, SearchMethod, OpenSearchDataType,
EnvVars
Expand All @@ -56,6 +56,7 @@
from marqo.tensor_search.models.search import VectorisedJobs, VectorisedJobPointer, Qidx, JHash
from marqo.tensor_search.models.index_info import IndexInfo
from marqo.tensor_search.utils import add_timing
from marqo.tensor_search import delete_docs
from marqo.s2_inference.processing import text as text_processor
from marqo.s2_inference.processing import image as image_processor
from marqo.s2_inference.clip_utils import _is_image
Expand Down Expand Up @@ -959,41 +960,6 @@ def _get_documents_for_upsert(
return res


def delete_documents(config: Config, index_name: str, doc_ids: List[str], auto_refresh):
"""Deletes documents """
if not doc_ids:
raise errors.InvalidDocumentIdError("doc_ids can't be empty!")

for _id in doc_ids:
validation.validate_id(_id)

# TODO: change to timer()
t0 = datetime.datetime.utcnow()
delete_res_backend = HttpRequests(config=config).post(
path=f"{index_name}/_delete_by_query", body={
"query": {
"terms": {
"_id": doc_ids
}
}
}
)
if auto_refresh:
refresh_response = HttpRequests(config).post(path=F"{index_name}/_refresh")
t1 = datetime.datetime.utcnow()
delete_res = {
"index_name": index_name, "status": "succeeded",
"type": "documentDeletion", "details": {
"receivedDocumentIds": len(doc_ids),
"deletedDocuments": delete_res_backend["deleted"],
},
"duration": utils.create_duration_string(t1 - t0),
"startedAt": utils.format_timestamp(t0),
"finishedAt": utils.format_timestamp(t1),
}
return delete_res


def refresh_index(config: Config, index_name: str):
return HttpRequests(config).post(path=F"{index_name}/_refresh")

Expand All @@ -1017,7 +983,7 @@ def bulk_search(query: BulkSearchQuery, marqo_config: config.Config, verbose: bo
refresh_indexes_in_background(marqo_config, [q.index for q in query.queries])

# TODO: Let non-errored docs to propagate.
errs = [ validation.validate_bulk_query_input(q) for q in query.queries ]
errs = [validation.validate_bulk_query_input(q) for q in query.queries]
if any(errs):
err = next(e for e in errs if e is not None)
raise err
Expand Down Expand Up @@ -1174,7 +1140,6 @@ def search(config: Config, index_name: str, text: Union[str, dict],
args=(config, index_name, REFRESH_INTERVAL_SECONDS))
cache_update_thread.start()


if search_method.upper() == SearchMethod.TENSOR:
search_result = _vector_text_search(
config=config, index_name=index_name, query=text, result_count=result_count, offset=offset,
Expand Down Expand Up @@ -2597,6 +2562,18 @@ def _create_score_modifiers_tensor_search_query(result_count, offset, vector_fie
}
return search_query


def delete_documents(config: Config, index_name: str, doc_ids: List[str], auto_refresh):
"""Delete documents from the Marqo index with the given doc_ids """
return delete_docs.delete_documents(
config=config,
del_request=MqDeleteDocsRequest(
index_name=index_name,
document_ids=doc_ids,
auto_refresh=auto_refresh)
)


def get_settings(index_name: str, marqo_config: Config):
"""Get the settings for a specific index."""
shards = backend.get_num_shards(config=marqo_config, index_name=index_name)
Expand All @@ -2605,4 +2582,3 @@ def get_settings(index_name: str, marqo_config: Config):
index_info.index_settings["number_of_shards"] = shards

return index_info.index_settings

43 changes: 38 additions & 5 deletions src/marqo/tensor_search/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import (
List, Optional, Union, Callable, Iterable, Sequence, Dict, Tuple
)
from marqo.marqo_logging import logger
import copy
import datetime
import pathlib
Expand Down Expand Up @@ -175,20 +176,52 @@ def read_env_vars_and_defaults(var: str) -> Optional[str]:
"""Attempts to read an environment variable.
If none is found, it will attempt to retrieve it from
configs.default_env_vars(). If still unsuccessful, None is returned.
If it's an empty string, None is returned.
"""
try:
var = os.environ[var]
if var is not None and len(var) == 0:

def none_if_empty(value: Optional[str]) -> Optional[str]:
"""Returns None if value is an empty string"""
if value is not None and len(value) == 0:
return None
else:
return var
return value

try:
return none_if_empty(os.environ[var])
except KeyError:
try:
return configs.default_env_vars()[var]
default_val = configs.default_env_vars()[var]
if isinstance(default_val, str):
return none_if_empty(default_val)
else:
return default_val
except KeyError:
return None


def read_env_vars_and_defaults_ints(var: str) -> Optional[int]:
"""Gets env var from read_env_vars_and_defaults() and attempts to coerce it to an int
Returns
the coerced int value, or None if the key is not found.
"""
str_val = read_env_vars_and_defaults(var)

if str_val is None:
return None

validation_error_msg = (
f"Could not properly read env var `{var}`. `{var}` must be able to be parsed as an int."
)
try:
as_int = int(str_val)
except (ValueError, TypeError) as e:
value_error_msg = f"`{validation_error_msg} Current value: `{str_val}`. Reason: {e}"
logger.error(value_error_msg)
raise errors.ConfigurationError(value_error_msg)
return as_int


def parse_lexical_query(text: str) -> Tuple[List[str], str]:
"""Find required terms enclosed within double quotes.
Expand Down
Loading

0 comments on commit 7d571c2

Please sign in to comment.