Bring-1535-to-point-release #1560

Closed
wants to merge 53 commits
53 commits
955473d
divide task and statistical metrics
davidslater May 7, 2022
43dc486
update refs
davidslater May 7, 2022
7d89905
load
davidslater May 7, 2022
88517ba
fix existing tests
davidslater May 8, 2022
6ae58ce
register metrics
davidslater May 8, 2022
3d2e04b
supported metrics references
davidslater May 8, 2022
ffbd029
move armory.utils.metrics to armory.metrics
davidslater May 9, 2022
23a3442
add loading
davidslater May 9, 2022
439c246
remove empty file
davidslater May 9, 2022
d474a1c
finish supported metrics
davidslater May 9, 2022
24a53ab
migrate poisoning fairness metrics to new module
davidslater May 9, 2022
97c5d33
moved Meter construction to scenario
davidslater May 10, 2022
04824ff
stub out metric tests
davidslater May 13, 2022
9a42d29
Merge branch 'develop' into task-metrics
davidslater May 13, 2022
d806aac
remove old metrics
davidslater May 13, 2022
6c89710
metric WIP
davidslater May 16, 2022
a2eb91f
poisoning metric refactor - work in progress
davidslater May 18, 2022
53445f6
update metrics refs
davidslater May 18, 2022
c76941f
update docs
davidslater May 18, 2022
f862925
update test stubs
davidslater May 18, 2022
07ddb06
refactor fairness metrics, untested
swsuggs May 18, 2022
ac3278c
update majority_mask
davidslater May 18, 2022
efefce7
minor update
davidslater May 18, 2022
f3e7678
finish task metric tests
davidslater May 20, 2022
ee4b75a
update docs
davidslater May 20, 2022
e5723d0
move to different namespace
davidslater May 20, 2022
a1f2f3a
update docs
davidslater May 20, 2022
2f77266
move subsection
davidslater May 20, 2022
47c3ca5
Merge branch 'develop' into task-metrics
davidslater May 20, 2022
53fa5ec
reference poisoning docs
davidslater May 20, 2022
5f75469
small bug fixes for statistical.py
swsuggs May 21, 2022
4e9e158
filled out tests in for statistical_metrics
swsuggs May 21, 2022
aca9cd3
numpy subdtype
davidslater May 23, 2022
a63bb6c
Merge pull request #2 from swsuggs/task-metrics-unit-tests
davidslater May 23, 2022
682308d
Merge pull request #1 from swsuggs/task-metrics-refactor
davidslater May 23, 2022
6324aee
add tag to test
davidslater May 23, 2022
511d763
poisoning metrics tests; will update after further poisoning refactors
swsuggs May 23, 2022
7eef6cc
potential update
davidslater May 23, 2022
89d0f3f
unneeded line
davidslater May 23, 2022
b836aec
Merge branch 'task-metrics' of https://github.com/davidslater/armory …
swsuggs May 24, 2022
b1437d8
remove tests for filter and model bias, obviated by poisoning updates
swsuggs May 24, 2022
9487977
Merge pull request #3 from swsuggs/task-metrics-unit-tests
davidslater May 24, 2022
3dbe39a
lint
davidslater May 24, 2022
04a239a
scenario updates
davidslater May 24, 2022
e84f3b2
remove explanatory stage
davidslater May 25, 2022
40c537e
remove comments
davidslater May 25, 2022
e62a24b
pip install armory for build
davidslater May 25, 2022
551fe18
update dockerfile build
davidslater May 26, 2022
6db9a79
merge fix
davidslater May 26, 2022
1efbf78
fix merge
davidslater May 26, 2022
765ad2c
Merge pull request #1495 from davidslater/task-metrics
davidslater May 26, 2022
b1a923a
Merge pull request #1535 from davidslater/release-fix
lcadalzo May 27, 2022
84325d6
Merge commit 'b1a923aa996494901cf29e89c18d99dd26b557da' into bring-15…
mwartell Jun 8, 2022
4 changes: 2 additions & 2 deletions armory/art_experimental/attacks/sweep.py

@@ -200,14 +200,14 @@ def _load_metric_fn(self, metric_dict):
         metric_module_name = metric_dict.get("module")
         if metric_module_name is None:
             # by default use categorical accuracy to measure attack success
-            from armory.utils.metrics import categorical_accuracy
+            from armory import metrics

             log.info(
                 "Using default categorical accuracy to measure attack success "
                 "since attack_config['sweep_params']['metric']['module'] is "
                 "unspecified."
             )
-            self.metric_fn = categorical_accuracy
+            self.metric_fn = metrics.get("categorical_accuracy")
             self.metric_threshold = (
                 0.5  # for binary metric, any x s.t. 0 < x < 1 suffices
             )
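
The new call resolves the same function through the metrics registry instead of a hard-coded import. A minimal sketch of the equivalence (assumes an environment with armory installed; the inputs and output are illustrative):

from armory import metrics

# metrics.get returns the registered batchwise metric for a supported name
fn = metrics.get("categorical_accuracy")
print(fn([1, 2], [1, 3]))  # per-element correctness, e.g. [1, 0]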
@@ -55,7 +55,7 @@ def forward(self, x: Tensor) -> Tensor:


 class get_model(nn.Module):
-    def __init__(self, weights_path: Optional[str], model_kwargs: dict):
+    def __init__(self, weights_path: Optional[str], **model_kwargs):
         super().__init__()
         self.inner_model = Micronnet(**model_kwargs)
         self.inner_model.to(DEVICE)
33 changes: 17 additions & 16 deletions armory/instrument/config.py

@@ -11,7 +11,8 @@
     get_hub,
 )
 from armory.logs import log
-from armory.utils import metrics
+
+from armory import metrics


 class MetricsLogger:

@@ -131,7 +132,7 @@ def construct_meters_for_perturbation_metrics(

     hub = get_hub()
     for name in names:
-        metric = metrics.get_supported_metric(name)
+        metric = metrics.get(name)
         hub.connect_meter(
             Meter(
                 f"perturbation_{name}",

@@ -192,12 +193,12 @@ def _write(self, name, batch, result):
         # E.g., if someone renames this from "benign_word_error_rate" to "benign_wer"
         if "word_error_rate" in name:
             if "total_word_error_rate" not in name:
-                result = metrics.get_supported_metric("total_wer")(result)
+                result = metrics.get("total_wer")(result)
                 total, (num, denom) = result
                 f_result = f"total={total:.2%}, {num}/{denom}"
         elif "entailment" in name:
             if "total_entailment" not in name:
-                result = metrics.get_supported_metric("total_entailment")(result)
+                result = metrics.get("total_entailment")(result)
                 total = sum(result.values())
                 f_result = (
                     f"contradiction: {result['contradiction']}/{total}, "

@@ -208,8 +209,8 @@ def _write(self, name, batch, result):
             if "input_to" in name:
                 for m in MEAN_AP_METRICS:
                     if m in name:
-                        metric = metrics.get_supported_metric(m)
-                        result = metrics.MeanAP(metric)(result)
+                        metric = metrics.get(m)
+                        result = metrics.task.MeanAP(metric)(result)
                         break
             f_result = f"{result}"
         elif any(m in name for m in QUANTITY_METRICS):

@@ -235,22 +236,22 @@ def _task_metric(
     Return list of meters generated for this specific task
     """
     meters = []
-    metric = metrics.get_supported_metric(name)
+    metric = metrics.get(name)
     final_kwargs = {}
     if name in MEAN_AP_METRICS:
         final_suffix = name
-        final = metrics.MeanAP(metric)
+        final = metrics.task.MeanAP(metric)
         final_kwargs = metric_kwargs

         name = f"input_to_{name}"
-        metric = metrics.get_supported_metric("identity_unzip")
+        metric = metrics.get("identity_unzip")
         metric_kwargs = None
         record_final_only = True
     elif name == "entailment":
-        final = metrics.get_supported_metric("total_entailment")
+        final = metrics.get("total_entailment")
         final_suffix = "total_entailment"
     elif name == "word_error_rate":
-        final = metrics.get_supported_metric("total_wer")
+        final = metrics.get("total_wer")
         final_suffix = "total_word_error_rate"
     elif use_mean:
         final = np.mean

@@ -366,22 +367,22 @@ def _task_metric_wrt_benign_predictions(
     Return the meter generated for this specific task
     Return list of meters generated for this specific task
     """
-    metric = metrics.get_supported_metric(name)
+    metric = metrics.get(name)
     final_kwargs = {}
     if name in MEAN_AP_METRICS:
         final_suffix = name
-        final = metrics.MeanAP(metric)
+        final = metrics.task.MeanAP(metric)
         final_kwargs = metric_kwargs

         name = f"input_to_{name}"
-        metric = metrics.get_supported_metric("identity_unzip")
+        metric = metrics.get("identity_unzip")
         metric_kwargs = None
         record_final_only = True
     elif name == "entailment":
-        final = metrics.get_supported_metric("total_entailment")
+        final = metrics.get("total_entailment")
         final_suffix = "total_entailment"
     elif name == "word_error_rate":
-        final = metrics.get_supported_metric("total_wer")
+        final = metrics.get("total_wer")
         final_suffix = "total_word_error_rate"
     elif use_mean:
         final = np.mean
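
The "total_wer" finalizer used in _write aggregates per-example word error rates into a single ratio. A small sketch of that formatting path, assuming each per-example result is an (edit_distance, reference_length) pair as the `total, (num, denom)` unpacking above implies (numbers are illustrative):

from armory import metrics

per_example = [(1, 5), (0, 4), (2, 6)]  # hypothetical (num, denom) pairs
total, (num, denom) = metrics.get("total_wer")(per_example)
print(f"total={total:.2%}, {num}/{denom}")  # total=20.00%, 3/15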
93 changes: 92 additions & 1 deletion armory/metrics/__init__.py

@@ -1 +1,92 @@
-from armory.metrics import compute, perturbation
+import importlib
+
+from armory.metrics import compute, perturbation, statistical, task
+
+SUPPORTED_METRICS = {}
+for namespace in (
+    perturbation.batch,
+    task.batch,
+    task.aggregate,
+    task.population,
+    statistical.registered,
+):
+    assert not any(k in namespace for k in SUPPORTED_METRICS)
+    SUPPORTED_METRICS.update(namespace)
+
+
+def _instantiate_validate(function, name, instantiate_if_class=True):
+    if isinstance(function, type) and issubclass(function, object):
+        if instantiate_if_class:
+            function = function()
+    if not callable(function):
+        raise ValueError(f"function {name} is not callable")
+    return function
+
+
+def supported(name):
+    """
+    Return whether given name is a supported metric
+    """
+    return name in SUPPORTED_METRICS
+
+
+def get_supported_metric(name, instantiate_if_class=True):
+    try:
+        function = SUPPORTED_METRICS[name]
+    except KeyError:
+        raise KeyError(f"{name} is not part of armory.metrics")
+    return _instantiate_validate(
+        function, name, instantiate_if_class=instantiate_if_class
+    )
+
+
+def load(string, instantiate_if_class=True):
+    """
+    Import and load a function from the given '.'-separated identifier string
+    """
+    tokens = string.split(".")
+    if not all(token.isidentifier() for token in tokens):
+        raise ValueError(f"{string} is not a valid '.'-separated set of identifiers")
+    if len(tokens) < 2:
+        raise ValueError(f"{string} not a valid module and function path")
+
+    errors = []
+    for i in range(len(tokens) - 1, 0, -1):
+        module_name = ".".join(tokens[:i])
+        metric_name = ".".join(tokens[i:])
+        try:
+            module = importlib.import_module(module_name)
+        except ImportError:
+            errors.append(f"ImportError: import {module_name}")
+            continue
+        try:
+            obj = module
+            for name in tokens[i:]:
+                obj = getattr(obj, name)
+            function = obj
+            break
+        except AttributeError:
+            errors.append(
+                f"AttributeError: module {module_name} has no attribute {metric_name}"
+            )
+    else:
+        error_string = "\n    ".join([""] + errors)
+        raise ValueError(
+            f"Could not import metric {string}. "
+            f"The following errors occurred: {error_string}"
+        )
+
+    return _instantiate_validate(
+        function, string, instantiate_if_class=instantiate_if_class
+    )
+
+
+def get(name, instantiate_if_class=True):
+    """
+    Get the given metric, first by looking for it in armory, then via import
+    instantiate_if_class - if a class is returned, instantiate it when True
+    """
+    try:
+        return get_supported_metric(name, instantiate_if_class=instantiate_if_class)
+    except KeyError:
+        return load(name, instantiate_if_class=instantiate_if_class)
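
Taken together, supported, get_supported_metric, and load give get a two-tier lookup: the SUPPORTED_METRICS registry first, then a dotted import path. A brief usage sketch (assumes armory and numpy are installed):

from armory import metrics

assert metrics.supported("categorical_accuracy")
acc = metrics.get("categorical_accuracy")  # resolved from SUPPORTED_METRICS

# An unregistered name falls through to load(), which walks the dotted path:
norm = metrics.get("numpy.linalg.norm")
print(norm([3.0, 4.0]))  # 5.0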
94 changes: 94 additions & 0 deletions armory/metrics/common.py

@@ -0,0 +1,94 @@
+"""
+Supporting tools for metrics
+"""
+
+import functools
+
+import numpy as np
+
+from armory.logs import log
+
+
+class MetricNameSpace:
+    """
+    Used to keep track of metrics and make them easily discoverable and enumerable
+    """
+
+    def __setattr__(self, name, function):
+        if name.startswith("_"):
+            raise ValueError(f"Function name '{name}' cannot start with '_'")
+        if hasattr(self, name):
+            raise ValueError(f"Cannot overwrite existing function {name}")
+        if not callable(function):
+            raise ValueError(f"{name} function {function} is not callable")
+        super().__setattr__(name, function)
+
+    def __delattr__(self, name):
+        raise ValueError("Deletion not allowed")
+
+    def _names(self):
+        return sorted(x for x in self.__dict__ if not x.startswith("_"))
+
+    def __contains__(self, name):
+        return name in self._names()
+
+    def __repr__(self):
+        """
+        Show the existing non-underscore names
+        """
+        return str(self._names())
+
+    def __iter__(self):
+        for name in self._names():
+            yield name, self[name]
+
+    def __getitem__(self, name):
+        if not hasattr(self, name):
+            raise KeyError(name)
+        return getattr(self, name)
+
+    def __setitem__(self, name, function):
+        setattr(self, name, function)
+
+
+def set_namespace(namespace, metric, name=None):
+    """
+    Set the namespace, getting the metric name if none given, and return the metric
+    """
+    if name is None:
+        name = metric.__name__
+    setattr(namespace, name, metric)
+    return metric
+
+
+def as_batch(element_metric):
+    """
+    Return a batchwise metric function from an elementwise metric function
+    """
+
+    @functools.wraps(element_metric)
+    def wrapper(x_batch, x_adv_batch, **kwargs):
+        x_batch = list(x_batch)
+        x_adv_batch = list(x_adv_batch)
+        if len(x_batch) != len(x_adv_batch):
+            raise ValueError(
+                f"len(a_batch) {len(x_batch)} != len(b_batch) {len(x_adv_batch)}"
+            )
+        y = []
+        for x, x_adv in zip(x_batch, x_adv_batch):
+            y.append(element_metric(x, x_adv, **kwargs))
+        try:
+            y = np.array(y)
+        except ValueError:
+            # Handle ragged arrays
+            y = np.array(y, dtype=object)
+        return y
+
+    if wrapper.__doc__ is None:
+        log.warning(f"{element_metric.__name__} has no doc string")
+        wrapper.__doc__ = ""
+    wrapper.__doc__ = "Batch version of:\n" + wrapper.__doc__
+    wrapper.__name__ = "batch_" + wrapper.__name__
+    # note: repr(wrapper) defaults to the element_metric, not __name__
+    # See: https://stackoverflow.com/questions/10875442/possible-to-change-a-functions-repr-in-python
+    return wrapper
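
A short end-to-end sketch of these helpers working together (the metric and values are illustrative, not part of this PR):

import numpy as np

from armory.metrics.common import MetricNameSpace, as_batch, set_namespace

element = MetricNameSpace()

def absolute_error(x, x_adv):
    """Elementwise absolute error"""
    return abs(x - x_adv)

set_namespace(element, absolute_error)  # registered under the function's name
assert "absolute_error" in element

batched = as_batch(element["absolute_error"])  # lift elementwise -> batchwise
print(batched(np.array([1.0, 2.0]), np.array([1.5, 1.0])))  # [0.5 1. ]
print(batched.__name__)  # batch_absolute_error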