Skip to content

Commit

Permalink
Make neighbor's dataset path as optional
Browse files Browse the repository at this point in the history
If neighbor's dataset path or corpus is not set,
use dataset path as neighbor's dataset path.

Signed-off-by: Vijayan Balasubramanian <balasvij@amazon.com>
  • Loading branch information
VijayanB committed Feb 17, 2024
1 parent 822f31d commit d549424
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 14 deletions.
9 changes: 3 additions & 6 deletions osbenchmark/workload/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,7 +1052,7 @@ def __init__(self, workloads, params, query_params, **kwargs):
self.PARAMS_NAME_NEIGHBORS_DATA_SET_FORMAT, params, self.data_set_format)
self.neighbors_data_set_path = params.get(self.PARAMS_NAME_NEIGHBORS_DATA_SET_PATH)
self.neighbors_data_set_corpus = params.get(self.PARAMS_NAME_NEIGHBORS_DATA_SET_CORPUS)
self._validate_data_set(self.neighbors_data_set_path, self.neighbors_data_set_corpus)
self._validate_neighbors_data_set(self.neighbors_data_set_path, self.neighbors_data_set_corpus)
self.neighbors_data_set = None
operation_type = parse_string_parameter(self.PARAMS_NAME_OPERATION_TYPE, params,
self.PARAMS_VALUE_VECTOR_SEARCH)
Expand All @@ -1072,11 +1072,8 @@ def __init__(self, workloads, params, query_params, **kwargs):
neighbors_corpora = self.extract_corpora(self.neighbors_data_set_corpus, self.neighbors_data_set_format)
self.corpora.extend(corpora for corpora in neighbors_corpora if corpora not in self.corpora)

def _validate_neighbors_data_set(self):
if not self.data_set_path and not self.data_set_corpus:
raise exceptions.ConfigurationError(
"Dataset is missing. Provide either dataset file path or valid corpus.")
if self.data_set_path and self.data_set_corpus:
def _validate_neighbors_data_set(self, file_path, corpus):
if file_path and corpus:
raise exceptions.ConfigurationError(
"Provide either dataset file path '%s' or corpus '%s'." % (self.data_set_path, self.data_set_corpus))

Expand Down
14 changes: 9 additions & 5 deletions tests/utils/dataset_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,12 +188,16 @@ def create_data_set(
dimension: int,
extension: str,
data_set_context: Context,
data_set_dir
data_set_dir,
file_path: str = None
) -> str:
file_name_base = ''.join(random.choice(string.ascii_letters) for _ in
range(DEFAULT_RANDOM_STRING_LENGTH))
data_set_file_name = "{}.{}".format(file_name_base, extension)
data_set_path = os.path.join(data_set_dir, data_set_file_name)
if file_path:
data_set_path = file_path
else:
file_name_base = ''.join(random.choice(string.ascii_letters) for _ in
range(DEFAULT_RANDOM_STRING_LENGTH))
data_set_file_name = "{}.{}".format(file_name_base, extension)
data_set_path = os.path.join(data_set_dir, data_set_file_name)
context = DataSetBuildContext(
data_set_context,
create_random_2d_array(num_vectors, dimension),
Expand Down
6 changes: 3 additions & 3 deletions tests/workload/params_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2862,20 +2862,20 @@ def test_params_default(self):
Context.QUERY,
self.data_set_dir
)
neighbors_data_set_path = create_data_set(
create_data_set(
self.DEFAULT_NUM_VECTORS,
self.DEFAULT_DIMENSION,
self.DEFAULT_TYPE,
Context.NEIGHBORS,
self.data_set_dir
self.data_set_dir,
data_set_path
)

# Create a QueryVectorsFromDataSetParamSource with relevant params
test_param_source_params = {
"field": self.DEFAULT_FIELD_NAME,
"data_set_format": self.DEFAULT_TYPE,
"data_set_path": data_set_path,
"neighbors_data_set_path": neighbors_data_set_path,
"k": k
}
query_param_source = VectorSearchPartitionParamSource(
Expand Down

0 comments on commit d549424

Please sign in to comment.