Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Unpin pytorch_lightning<1.2 #3598

Merged
merged 1 commit into from
May 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dependencies/recommended.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ torch == 1.6.0+cpu ; sys_platform != "darwin"
torch == 1.6.0 ; sys_platform == "darwin"
torchvision == 0.7.0+cpu ; sys_platform != "darwin"
torchvision == 0.7.0 ; sys_platform == "darwin"
pytorch-lightning >= 1.1.1, < 1.2
pytorch-lightning >= 1.1.1
onnx
peewee
graphviz
7 changes: 6 additions & 1 deletion nni/retiarii/evaluator/pytorch/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,18 @@ def _get_validation_metrics(self):
return {name: self.trainer.callback_metrics['val_' + name].item() for name in self.metrics}


class _AccuracyWithLogits(pl.metrics.Accuracy):
def update(self, pred, target):
return super().update(nn.functional.softmax(pred), target)


@serialize_cls
class _ClassificationModule(_SupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam):
super().__init__(criterion, {'acc': pl.metrics.Accuracy},
super().__init__(criterion, {'acc': _AccuracyWithLogits},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer)


Expand Down
2 changes: 1 addition & 1 deletion pipelines/full-test-linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
python3 -m pip install scikit-learn==0.24.1
python3 -m pip install torchvision==0.7.0
python3 -m pip install torch==1.6.0
python3 -m pip install 'pytorch-lightning>=1.1.1,<1.2'
python3 -m pip install 'pytorch-lightning>=1.1.1'
python3 -m pip install keras==2.1.6
python3 -m pip install tensorflow==2.3.1 tensorflow-estimator==2.3.0
python3 -m pip install thop
Expand Down
2 changes: 1 addition & 1 deletion pipelines/full-test-windows.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
python -m pip install scikit-learn==0.24.1
python -m pip install keras==2.1.6
python -m pip install torch==1.6.0 torchvision==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
python -m pip install 'pytorch-lightning>=1.1.1,<1.2'
python -m pip install 'pytorch-lightning>=1.1.1'
python -m pip install tensorflow==2.3.1 tensorflow-estimator==2.3.0
displayName: Install extra dependencies

Expand Down
1 change: 1 addition & 0 deletions test/ut/retiarii/test_lightning_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import nni
import nni.retiarii.evaluator.pytorch.lightning as pl
import nni.runtime.platform.test
import pytorch_lightning
import torch
import torch.nn as nn
Expand Down