Skip to content

Commit

Permalink
Fixed code review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Vijayan Balasubramanian <balasvij@amazon.com>
  • Loading branch information
VijayanB committed Feb 8, 2024
1 parent 844168f commit c6767e1
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions osbenchmark/workload/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,7 @@ def __init__(self, workload, params, context: Context, **kwargs):
self.data_set_format = parse_string_parameter("data_set_format", params)
self.data_set_path = parse_string_parameter("data_set_path", params, "")
self.data_set_corpus = parse_string_parameter("data_set_corpus", params, "")
self._validate_data_set(self.data_set_path, self.data_set_corpus)
self.total_num_vectors: int = parse_int_parameter("num_vectors", params, -1)
self.num_vectors = 0
self.total = 1
Expand All @@ -914,10 +915,14 @@ def infinite(self):
def _is_last_partition(partition_index, total_partitions):
return partition_index == total_partitions - 1

def _validate_data_set(self):
if not self.data_set_path and not self.data_set_corpus:
@staticmethod
def _validate_data_set(file_path, corpus):
if not file_path and not corpus:
raise exceptions.ConfigurationError(
"Dataset is missing. Provide either dataset file path or valid corpus.")
if file_path and corpus:
raise exceptions.ConfigurationError(
"Provide either dataset file path '%s' or corpus '%s'." % (file_path, corpus))

@staticmethod
def _validate_data_set_corpus(data_set_path_list):
Expand All @@ -939,7 +944,6 @@ def partition(self, partition_index, total_partitions):
Returns:
The parameter source for this particular partition
"""
self._validate_data_set()
if self.data_set_corpus and not self.data_set_path:
data_set_path = self._get_corpora_file_paths(self.data_set_corpus, self.data_set_format)
self._validate_data_set_corpus(data_set_path)
Expand Down Expand Up @@ -1045,6 +1049,7 @@ def __init__(self, workloads, params, query_params, **kwargs):
self.PARAMS_NAME_NEIGHBORS_DATA_SET_FORMAT, params, self.data_set_format)
self.neighbors_data_set_path = params.get(self.PARAMS_NAME_NEIGHBORS_DATA_SET_PATH)
self.neighbors_data_set_corpus = params.get(self.PARAMS_NAME_NEIGHBORS_DATA_SET_CORPUS)
self._validate_data_set(self.neighbors_data_set_path, self.neighbors_data_set_corpus)
self.neighbors_data_set = None
operation_type = parse_string_parameter(self.PARAMS_NAME_OPERATION_TYPE, params,
self.PARAMS_VALUE_VECTOR_SEARCH)
Expand All @@ -1060,6 +1065,14 @@ def __init__(self, workloads, params, query_params, **kwargs):
neighbors_corpora = self.extract_corpora(self.neighbors_data_set_corpus, self.neighbors_data_set_format)
self.corpora.extend(corpora for corpora in neighbors_corpora if corpora not in self.corpora)

def _validate_neighbors_data_set(self):
if not self.data_set_path and not self.data_set_corpus:
raise exceptions.ConfigurationError(
"Dataset is missing. Provide either dataset file path or valid corpus.")
if self.data_set_path and self.data_set_corpus:
raise exceptions.ConfigurationError(
"Provide either dataset file path '%s' or corpus '%s'." % (self.data_set_path, self.data_set_corpus))

def _update_request_params(self):
request_params = self.query_params.get(self.PARAMS_NAME_REQUEST_PARAMS, {})
request_params[self.PARAMS_NAME_SOURCE] = request_params.get(
Expand Down

0 comments on commit c6767e1

Please sign in to comment.