-
Notifications
You must be signed in to change notification settings - Fork 198
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[features] bring your own vectors (#381)
* updated * update mock * update mock * change sum to mean. * change sum to mean. * updated * updated * add query type check logic * add maximum number of vectors limit * add tests * updated * catch mainline * catch mainline * catch mainline * catch mainline * add test for vectors * add test for vectors * add test for vectors
- Loading branch information
Showing 7 changed files with 327 additions and 15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# JSON Schema (draft-04) describing a search "context" object.
#
# A valid context is an object with a required "tensor" key holding a list
# of 1 to 64 entries. Each entry must provide:
#   - "vector": an array of numbers (its length is NOT constrained here, so
#     any dimensionality check has to happen outside this schema), and
#   - "weight": a number (negative values and zero are allowed).
context_schema = {
    "$schema": "http://json-schema.org/draft-04/schema#",
    "type": "object",
    "properties": {
        "tensor": {
            "type": "array",
            "minItems": 1,
            "maxItems": 64,
            "items": {
                "type": "object",
                "properties": {
                    "vector": {
                        "type": "array",
                        "items": {"type": "number"},
                    },
                    "weight": {
                        "type": "number",
                    },
                },
                "required": [
                    "vector",
                    "weight",
                ],
            },
        }
    },
    "required": [
        "tensor",
    ],
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
import unittest.mock | ||
import pprint | ||
|
||
import torch | ||
|
||
import marqo.tensor_search.backend | ||
from marqo.errors import IndexNotFoundError, InvalidArgError | ||
from marqo.tensor_search import tensor_search | ||
from marqo.tensor_search.enums import TensorField, IndexSettingsField, SearchMethod | ||
from tests.marqo_test import MarqoTestCase | ||
from unittest.mock import patch | ||
import numpy as np | ||
|
||
|
||
class TestMultimodalTensorCombination(MarqoTestCase):
    """Tests for searching with caller-supplied context vectors.

    Each test runs against a freshly created index containing one text
    document, and passes a ``context`` of the form
    ``{"tensor": [{"vector": [...], "weight": w}, ...]}`` to
    ``tensor_search.search`` alongside a weighted dictionary query.
    """

    def setUp(self):
        """Recreate the test index and add a single document to it."""
        self.index_name_1 = "my-test-index-1"
        self.endpoint = self.authorized_url

        # Start from a clean slate; a missing index is the expected case.
        try:
            tensor_search.delete_index(config=self.config, index_name=self.index_name_1)
        except IndexNotFoundError:
            pass

        tensor_search.create_vector_index(
            index_name=self.index_name_1, config=self.config, index_settings={
                IndexSettingsField.index_defaults: {
                    IndexSettingsField.model: "ViT-B/32",
                    IndexSettingsField.treat_urls_and_pointers_as_images: True,
                    IndexSettingsField.normalize_embeddings: True
                }
            })
        tensor_search.add_documents(config=self.config, index_name=self.index_name_1, docs=[
            {
                "Title": "Horse rider",
                "text_field": "A rider is riding a horse jumping over the barrier.",
                "_id": "1"
            }], auto_refresh=True)

    def tearDown(self) -> None:
        """Delete the test index, tolerating the case where it is already gone."""
        try:
            tensor_search.delete_index(config=self.config, index_name=self.index_name_1)
        except IndexNotFoundError:
            # Same handling as setUp: only "index does not exist" is ignored,
            # so unrelated cleanup failures are no longer silently swallowed.
            pass

    def test_search(self):
        """A dictionary query with well-formed context vectors searches without raising."""
        query = {
            "A rider is riding a horse jumping over the barrier": 1,
        }
        # Vectors have 512 entries, presumably matching the embedding size of
        # the index's model -- TODO confirm.
        tensor_search.search(
            config=self.config, index_name=self.index_name_1, text=query,
            context={"tensor": [{"vector": [1, ] * 512, "weight": 2},
                                {"vector": [2, ] * 512, "weight": -1}], })

    def test_search_with_incorrect_tensor_dimension(self):
        """Context vectors of differing lengths are rejected with InvalidArgError."""
        query = {
            "A rider is riding a horse jumping over the barrier": 1,
        }
        try:
            tensor_search.search(
                config=self.config, index_name=self.index_name_1, text=query,
                context={"tensor": [{"vector": [1, ] * 3, "weight": 0},
                                    {"vector": [2, ] * 512, "weight": 0}], })
        except InvalidArgError as e:
            assert "This causes the error when we do `numpy.mean()` over" in e.message
        else:
            raise AssertionError(
                "Expected InvalidArgError for context vectors of mismatched dimension")

    def test_search_with_incorrect_query_format(self):
        """Supplying a context together with a plain string query is rejected."""
        query = "A rider is riding a horse jumping over the barrier"
        try:
            tensor_search.search(
                config=self.config, index_name=self.index_name_1, text=query,
                context={"tensor": [{"vector": [1, ] * 512, "weight": 0},
                                    {"vector": [2, ] * 512, "weight": 0}], })
        except InvalidArgError as e:
            assert "This is not supported as the context only works when the query is a dictionary." in e.message
        else:
            raise AssertionError(
                "Expected InvalidArgError when a context is given with a string query")

    def test_search_score(self):
        """Context vectors that contribute nothing leave the top hit's score unchanged."""
        query = {
            "A rider is riding a horse jumping over the barrier": 1,
        }

        baseline = tensor_search.search(config=self.config, index_name=self.index_name_1, text=query)
        # Both context vectors carry zero weight.
        zero_weighted = tensor_search.search(
            config=self.config, index_name=self.index_name_1, text=query,
            context={"tensor": [{"vector": [1, ] * 512, "weight": 0},
                                {"vector": [2, ] * 512, "weight": 0}], })
        # The same vector is supplied with weights -1 and +1.
        cancelling = tensor_search.search(
            config=self.config, index_name=self.index_name_1, text=query,
            context={"tensor": [{"vector": [1, ] * 512, "weight": -1},
                                {"vector": [1, ] * 512, "weight": 1}], })

        baseline_score = baseline["hits"][0]["_score"]
        # Scores are floats, so compare with a tolerance rather than `==`.
        assert np.isclose(baseline_score, zero_weighted["hits"][0]["_score"])
        assert np.isclose(baseline_score, cancelling["hits"][0]["_score"])

    def test_search_vectors(self):
        """Query and context vectors are each scaled by their weight before ``numpy.mean``."""
        # `wraps` keeps numpy.mean functional while recording every call.
        with patch("numpy.mean", wraps=np.mean) as mock_mean:
            query = {
                "A rider is riding a horse jumping over the barrier": 1,
            }
            # First search (weight 1, no context): the recorded numpy.mean
            # input is used below as the un-weighted query vector.
            tensor_search.search(config=self.config, index_name=self.index_name_1, text=query)

            weight_1, weight_2, weight_3 = 2.5, 3.4, -1.334
            vector_2 = [-1, ] * 512
            vector_3 = [1.3, ] * 512
            query = {
                "A rider is riding a horse jumping over the barrier": weight_1,
            }

            tensor_search.search(
                config=self.config, index_name=self.index_name_1, text=query,
                context={"tensor": [{"vector": vector_2, "weight": weight_2},
                                    {"vector": vector_3, "weight": weight_3}], })

            # Positional-argument tuples of each recorded numpy.mean call.
            # NOTE(review): the indexing below relies on each search producing
            # exactly one numpy.mean call, in order -- confirm against the
            # search implementation.
            args_list = [args[0] for args in mock_mean.call_args_list]
            vectorised_string = args_list[0][0][0]
            weighted_vectors = args_list[1][0]

            assert np.allclose(vectorised_string * weight_1, weighted_vectors[0], atol=1e-9)
            assert np.allclose(np.array(vector_2) * weight_2, weighted_vectors[1], atol=1e-9)
            assert np.allclose(np.array(vector_3) * weight_3, weighted_vectors[2], atol=1e-9)
Oops, something went wrong.