Skip to content

Commit

Permalink
Publish recall as a kpi metric (#581)
Browse files Browse the repository at this point in the history
Signed-off-by: Finn Roblin <finnrobl@amazon.com>
  • Loading branch information
finnroblin authored Jul 24, 2024
1 parent 43d43b2 commit 5886866
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 6 deletions.
43 changes: 41 additions & 2 deletions osbenchmark/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1440,7 +1440,6 @@ def as_dict(self):
if self.plugin_params:
d["plugin-params"] = self.plugin_params
return d

def to_result_dicts(self):
"""
:return: a list of dicts, suitable for persisting the results of this test execution in a format that is Kibana-friendly.
Expand Down Expand Up @@ -1784,6 +1783,7 @@ def __call__(self):
op_type = task.operation.type
error_rate = self.error_rate(t, op_type)
duration = self.duration(t)

if task.operation.include_in_results_publishing or error_rate > 0:
self.logger.debug("Gathering request metrics for [%s].", t)
result.add_op_metrics(
Expand All @@ -1800,8 +1800,19 @@ def __call__(self):
self.workload.meta_data,
self.test_procedure.meta_data,
task.operation.meta_data,
task.meta_data)
task.meta_data,
),
)

result.add_correctness_metrics(
t,
task.operation.name,
self.single_latency(t, op_type, metric_name="recall@k"),
self.single_latency(t, op_type, metric_name="recall@1"),
error_rate,
duration,
)

self.logger.debug("Gathering indexing metrics.")
result.total_time = self.sum("indexing_total_time")
result.total_time_per_shard = self.shard_stats("indexing_total_time")
Expand Down Expand Up @@ -1996,6 +2007,7 @@ def single_latency(self, task, operation_type, metric_name="latency"):
class GlobalStats:
def __init__(self, d=None):
self.op_metrics = self.v(d, "op_metrics", default=[])
self.correctness_metrics = self.v(d, "correctness_metrics", default=[])
self.total_time = self.v(d, "total_time")
self.total_time_per_shard = self.v(d, "total_time_per_shard", default={})
self.indexing_throttle_time = self.v(d, "indexing_throttle_time")
Expand Down Expand Up @@ -2081,6 +2093,22 @@ def op_metrics(op_item, key, single_value=False):
"max": item["max"]
}
})
elif metric == "correctness_metrics":
for item in value:
if "recall@k" in item:
all_results.append({
"task": item["task"],
"operation": item["operation"],
"name": "recall@k",
"value": item["recall@k"]
})
if "recall@1" in item:
all_results.append({
"task": item["task"],
"operation": item["operation"],
"name": "recall@1",
"value": item["recall@1"]
})
elif metric.startswith("total_transform_") and value is not None:
for item in value:
all_results.append({
Expand Down Expand Up @@ -2124,6 +2152,17 @@ def add_op_metrics(self, task, operation, throughput, latency, service_time, cli
doc["meta"] = meta
self.op_metrics.append(doc)

def add_correctness_metrics(self, task, operation, recall_at_k_stats, recall_at_1_stats, error_rate, duration):
    """
    Appends one correctness record for a task to ``self.correctness_metrics``.

    :param task: value stored under the "task" key.
    :param operation: value stored under the "operation" key.
    :param recall_at_k_stats: value stored under the "recall@k" key.
    :param recall_at_1_stats: value stored under the "recall@1" key.
    :param error_rate: value stored under the "error_rate" key.
    :param duration: value stored under the "duration" key.
    """
    self.correctness_metrics.append({
        "task": task,
        "operation": operation,
        "recall@k": recall_at_k_stats,
        "recall@1": recall_at_1_stats,
        "error_rate": error_rate,
        "duration": duration,
    })

def tasks(self):
    """
    :return: a list with one name per entry of ``self.op_metrics``, in order: the entry's "task" value, or its
             "operation" value when the entry has no "task" key.
    """
    # ensure we can read test_execution.json files before Benchmark 0.8.0
    # The fallback is looked up only when "task" is absent, so an entry that carries "task" but no
    # "operation" does not raise KeyError.
    return [v["task"] if "task" in v else v["operation"] for v in self.op_metrics]
Expand Down
21 changes: 19 additions & 2 deletions osbenchmark/results_publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,14 @@ def publish(self):
metrics_table.extend(self._publish_error_rate(record, task))
self.add_warnings(warnings, record, task)

for record in stats.correctness_metrics:
task = record["task"]

keys = record.keys()
recall_keys_in_task_dict = "recall@1" in keys and "recall@k" in keys
if recall_keys_in_task_dict and "mean" in record["recall@1"] and "mean" in record["recall@k"]:
metrics_table.extend(self._publish_recall(record, task))

self.write_results(metrics_table)

if warnings:
Expand Down Expand Up @@ -200,14 +208,23 @@ def _publish_service_time(self, values, task):
def _publish_processing_time(self, values, task):
return self._publish_percentiles("processing time", task, values["processing_time"])

def _publish_percentiles(self, name, task, value):
def _publish_recall(self, values, task):
recall_k_mean = values["recall@k"]["mean"]
recall_1_mean = values["recall@1"]["mean"]

return self._join(
self._line("Mean recall@k", task, recall_k_mean, "", lambda v: "%.2f" % v),
self._line("Mean recall@1", task, recall_1_mean, "", lambda v: "%.2f" % v)
)

def _publish_percentiles(self, name, task, value, unit="ms"):
    """
    Builds one result line per percentile of a metric.

    :param name: metric name; used to look up the percentiles in ``self.display_percentiles`` and as part of each
                 line's label.
    :param task: task identifier passed to every generated line.
    :param value: mapping from encoded percentile keys to values; when falsy no lines are produced.
    :param unit: unit passed to every generated line. Defaults to "ms".
    :return: a list with the non-empty lines that were generated (possibly empty).
    """
    lines = []
    # Metrics without an explicit display configuration fall back to the calculator's default percentiles.
    percentiles = self.display_percentiles.get(name, metrics.GlobalStatsCalculator.OTHER_PERCENTILES)

    if value:
        for percentile in metrics.percentiles_for_sample_size(sys.maxsize, percentiles_list=percentiles):
            percentile_value = value.get(metrics.encode_float_key(percentile))
            # Use the caller-supplied unit rather than a hard-coded "ms".
            a_line = self._line("%sth percentile %s" % (percentile, name), task, percentile_value, unit,
                                force=self.publish_all_percentile_values)
            self._append_non_empty(lines, a_line)
    return lines
Expand Down
40 changes: 38 additions & 2 deletions osbenchmark/worker_coordinator/worker_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
Expand Down Expand Up @@ -46,7 +46,6 @@
from osbenchmark.workload import WorkloadProcessorRegistry, load_workload, load_workload_plugins
from osbenchmark.utils import convert, console, net
from osbenchmark.worker_coordinator.errors import parse_error

##################################
#
# Messages sent between worker_coordinators
Expand Down Expand Up @@ -847,6 +846,43 @@ def __call__(self, raw_samples):
start = total_start
final_sample_count = 0
for idx, sample in enumerate(raw_samples):
self.logger.debug(
"All sample meta data: [%s],[%s],[%s],[%s],[%s]",
self.workload_meta_data,
self.test_procedure_meta_data,
sample.operation_meta_data,
sample.task.meta_data,
sample.request_meta_data,
)

# if request_meta_data exists then it will have {"success": true/false} as a parameter.
if sample.request_meta_data and len(sample.request_meta_data) > 1:
self.logger.debug("Found: %s", sample.request_meta_data)
recall_metric_names = ["recall@k", "recall@1"]

for recall_metric_name in recall_metric_names:
if recall_metric_name in sample.request_meta_data:
meta_data = self.merge(
self.workload_meta_data,
self.test_procedure_meta_data,
sample.operation_meta_data,
sample.task.meta_data,
sample.request_meta_data,
)

self.metrics_store.put_value_cluster_level(
name=recall_metric_name,
value=sample.request_meta_data[recall_metric_name],
unit="",
task=sample.task.name,
operation=sample.operation_name,
operation_type=sample.operation_type,
sample_type=sample.sample_type,
absolute_time=sample.absolute_time,
relative_time=sample.relative_time,
meta_data=meta_data,
)

if idx % self.downsample_factor == 0:
final_sample_count += 1
meta_data = self.merge(
Expand Down

0 comments on commit 5886866

Please sign in to comment.