Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Adopt torchmetrics #4290

Merged
merged 5 commits into from
Nov 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dependencies/recommended.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ torch == 1.9.0+cpu ; sys_platform != "darwin"
torch == 1.9.0 ; sys_platform == "darwin"
torchvision == 0.10.0+cpu ; sys_platform != "darwin"
torchvision == 0.10.0 ; sys_platform == "darwin"
pytorch-lightning >= 1.4.2
pytorch-lightning >= 1.5
torchmetrics
onnx
peewee
graphviz
Expand Down
1 change: 1 addition & 0 deletions dependencies/recommended_legacy.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ torchvision == 0.7.0+cpu
# It will install pytorch-lightning 0.8.x and unit tests won't work.
# Latest version has conflict with tensorboard and tensorflow 1.x.
pytorch-lightning
torchmetrics

keras == 2.1.6
onnx
Expand Down
23 changes: 12 additions & 11 deletions nni/retiarii/evaluator/pytorch/cgo/accelerator.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import Any, Union, Optional, List
import torch
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Any, List, Optional, Union

import torch
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.plugins.environments import ClusterEnvironment
from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer import Trainer

from pytorch_lightning.plugins import Plugin
from pytorch_lightning.plugins.environments import ClusterEnvironment
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector

from ....serializer import serialize_cls

Expand Down Expand Up @@ -69,9 +70,8 @@ def model_to_device(self) -> None:
# bypass device placement from pytorch lightning
pass

def setup(self, model: torch.nn.Module) -> torch.nn.Module:
self.model_to_device()
return self.model
def setup(self) -> None:
pass

@property
def is_global_zero(self) -> bool:
Expand Down Expand Up @@ -100,8 +100,9 @@ def get_accelerator_connector(
deterministic: bool = False,
precision: int = 32,
amp_backend: str = 'native',
amp_level: str = 'O2',
plugins: Optional[Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None,
amp_level: Optional[str] = None,
plugins: Optional[Union[List[Union[TrainingTypePlugin, ClusterEnvironment, str]],
TrainingTypePlugin, ClusterEnvironment, str]] = None,
**other_trainier_kwargs) -> AcceleratorConnector:
gpu_ids = Trainer()._parse_devices(gpus, auto_select_gpus, tpu_cores)
return AcceleratorConnector(
Expand Down
8 changes: 4 additions & 4 deletions nni/retiarii/evaluator/pytorch/cgo/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
import torchmetrics
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add to requirements?

from torch.utils.data import DataLoader

import nni
Expand All @@ -19,7 +19,7 @@

@serialize_cls
class _MultiModelSupervisedLearningModule(LightningModule):
def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
n_models: int = 0,
learning_rate: float = 0.001,
weight_decay: float = 0.,
Expand Down Expand Up @@ -119,7 +119,7 @@ class MultiModelSupervisedLearningModule(_MultiModelSupervisedLearningModule):
Class for optimizer (not an instance). default: ``Adam``
"""

def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
Expand Down Expand Up @@ -180,7 +180,7 @@ def __init__(self, criterion: nn.Module = nn.MSELoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
super().__init__(criterion, {'mse': pl.metrics.MeanSquaredError},
super().__init__(criterion, {'mse': torchmetrics.MeanSquaredError},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)


Expand Down
7 changes: 4 additions & 3 deletions nni/retiarii/evaluator/pytorch/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytorch_lightning as pl
import torch.nn as nn
import torch.optim as optim
import torchmetrics
from torch.utils.data import DataLoader

import nni
Expand Down Expand Up @@ -140,7 +141,7 @@ def _check_dataloader(dataloader):
### The following are some commonly used Lightning modules ###

class _SupervisedLearningModule(LightningModule):
def __init__(self, criterion: nn.Module, metrics: Dict[str, pl.metrics.Metric],
def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam,
Expand Down Expand Up @@ -213,7 +214,7 @@ def _get_validation_metrics(self):
return {name: self.trainer.callback_metrics['val_' + name].item() for name in self.metrics}


class _AccuracyWithLogits(pl.metrics.Accuracy):
class _AccuracyWithLogits(torchmetrics.Accuracy):
def update(self, pred, target):
return super().update(nn.functional.softmax(pred), target)

Expand Down Expand Up @@ -278,7 +279,7 @@ def __init__(self, criterion: nn.Module = nn.MSELoss,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam,
export_onnx: bool = True):
super().__init__(criterion, {'mse': pl.metrics.MeanSquaredError},
super().__init__(criterion, {'mse': torchmetrics.MeanSquaredError},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
export_onnx=export_onnx)

Expand Down