
Parameter of MulticlassStatScores and MultilabelStatScores to control which classes/labels to include in the averages #1723

Open
plonerma opened this issue Apr 21, 2023 · 1 comment
Comments


plonerma commented Apr 21, 2023

🚀 Feature

Add a parameter to MulticlassStatScores and MultilabelStatScores to control which classes/labels to include in averaging.

Motivation

Sklearn's precision_recall_fscore_support allows users to define the labels used in averaging the computed metrics (as well as their order if the metrics are not averaged). This makes it possible to calculate "a multiclass average ignoring a majority negative class". E.g., in my use case (sequence tagging), I do want to consider data points which carry an out-tag ("O", meaning they are not tagged), as they might contribute to the false positives of other classes. Hence, ignore_index is not sufficient, as those data points would be excluded completely.
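
For reference, a minimal sketch of the corresponding sklearn behaviour (tag names are made up for illustration): the "O" data points still influence the statistics of the other classes, but "O" itself is excluded from the macro average via the labels argument.

from sklearn.metrics import precision_recall_fscore_support

# Toy sequence-tagging example: the "O" prediction at index 4 and the
# "PER" prediction at index 2 still affect the scores of "PER", even
# though "O" itself is not part of the macro average.
y_true = ["O", "PER", "O", "LOC", "PER", "O"]
y_pred = ["O", "PER", "PER", "LOC", "O", "O"]

precision, recall, f1, _ = precision_recall_fscore_support(
    y_true, y_pred, labels=["PER", "LOC"], average="macro", zero_division=0
)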

Pitch

Add a parameter classes to MulticlassStatScores and labels to MultilabelStatScores to limit the calculation of true positives, false positives, false negatives, and true negatives to these classes/labels. The resulting averages (e.g. F1 score, accuracy) would then be computed over the selected classes/labels only.

If the classes / labels parameter is specified, num_classes / num_labels would not need to be set (or, if both are set and do not agree with the passed classes/labels, an exception would need to be raised).
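
A minimal sketch of how the proposed parameter could look in use (the classes argument is what is being requested here; it does not exist in torchmetrics today):

from torchmetrics.classification import MulticlassStatScores

# Hypothetical usage of the proposed `classes` argument: only classes 1
# and 2 are kept and averaged, while predictions of class 0 (e.g. an "O"
# tag) still count towards their false positives/negatives.
stat_scores = MulticlassStatScores(num_classes=3, average="macro", classes=[1, 2])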

Alternatives

Currently, I am using a very hacky solution:

from typing import Literal, Optional, Type, Union

import torch
from torchmetrics.classification import MulticlassStatScores, MultilabelStatScores


def metric_with_certain_labels_only(
    metric_type: Union[Type[MulticlassStatScores], Type[MultilabelStatScores]],
    included_labels: torch.Tensor,
    average: Optional[Literal["micro", "macro", "weighted", "none"]] = "macro",
    **kwargs,
):
    metric_average = average

    # For "micro", the per-class stats must stay unreduced until the
    # selected classes have been picked out, so the metric is built with
    # average="none" and switched back to "micro" only inside compute().
    if average == "micro":
        metric_average = "none"

    metric = metric_type(average=metric_average, **kwargs)

    _final_state_inner = metric._final_state

    def _final_state_wrapper():
        state = _final_state_inner()

        # keep only the selected classes/labels in every state tensor
        new_state = tuple(s[torch.as_tensor(included_labels)] for s in state)

        return new_state

    metric._final_state = _final_state_wrapper  # type: ignore

    if average == "micro":
        compute_inner = metric.compute

        def compute_wrapper():
            # temporarily restore "micro" so compute() sums the reduced stats
            metric.average = "micro"
            result = compute_inner()
            metric.average = "none"
            return result

        metric.compute = compute_wrapper  # type: ignore

    return metric
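
Used, for instance, like this (illustrative; the class indices and num_classes are made up):

# Macro-average over classes 1 and 2 only; class 0 is still tracked
# internally but dropped on read-out of the state.
metric = metric_with_certain_labels_only(
    MulticlassStatScores,
    included_labels=torch.tensor([1, 2]),
    average="macro",
    num_classes=3,
)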

I am not happy with this solution for two reasons:

  1. It would be nicer if included_labels were part of the metric's init signature directly, or at least provided through a wrapper. It is possible to rewrite this helper function as a wrapper; however, that would require changing the averaging on an already created metric (and recreating a new state), and, more importantly,
  2. since the classes/labels are selected on read-out of the state (which requires fewer changes), the relevant stats need to be tracked for all classes (even the ones which are not included). This is especially inefficient when average is set to "micro".

Ideally, the stats would already be reduced to the selected classes in _multiclass_stat_scores_update / _multilabel_stat_scores_update, roughly along the lines of the sketch below.
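
Illustrative sketch only (not the actual torchmetrics internals): once the per-class counts are available, they could be restricted to the selected classes before any further tracking or averaging.

import torch

# Keep only the rows of the per-class counts that belong to the
# selected classes, so nothing else has to be accumulated.
def reduce_to_selected(tp, fp, tn, fn, included_classes):
    idx = torch.as_tensor(included_classes)
    return tp[idx], fp[idx], tn[idx], fn[idx]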

Additional context

I had already opened a discussion. However, I believe this cannot be solved without a new feature.

@plonerma added the enhancement label Apr 21, 2023
@github-actions

Hi! Thanks for your contribution, great first issue!

@SkafteNicki added this to the future milestone Apr 24, 2023