From 77ed92e3f1c66df740c1dcd6259fd1341bf54753 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Wed, 16 Dec 2020 23:30:47 +0900 Subject: [PATCH] Add missing optional packages to `requirements/*.txt` (#450) * Import matplotlib at the top * Add missing optional packages * Update wandb * Add mypy to requirements --- pl_bolts/callbacks/vision/confused_logit.py | 15 ++++++++++++++- requirements/loggers.txt | 5 ++++- requirements/test.txt | 1 + 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/pl_bolts/callbacks/vision/confused_logit.py b/pl_bolts/callbacks/vision/confused_logit.py index df54d48e02..235dc95f1f 100644 --- a/pl_bolts/callbacks/vision/confused_logit.py +++ b/pl_bolts/callbacks/vision/confused_logit.py @@ -1,7 +1,17 @@ +import importlib + import torch from pytorch_lightning import Callback from torch import nn +from pl_bolts.utils.warnings import warn_missing_pkg + +_MATPLOTLIB_AVAILABLE = importlib.util.find_spec("matplotlib") is not None +if _MATPLOTLIB_AVAILABLE: + from matplotlib import pyplot as plt +else: + warn_missing_pkg("matplotlib") # pragma: no-cover + class ConfusedLogitCallback(Callback): # pragma: no-cover """ @@ -93,7 +103,10 @@ def training_step(...): pl_module.train() def _plot(self, confusing_x, confusing_y, trainer, model, mask_idxs): - from matplotlib import pyplot as plt + if not _MATPLOTLIB_AVAILABLE: + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `matplotlib` which is not installed yet, install it with `pip install matplotlib`.' + ) confusing_x = confusing_x[:self.top_k] confusing_y = confusing_y[:self.top_k] diff --git a/requirements/loggers.txt b/requirements/loggers.txt index e594af1651..faa232f2f6 100644 --- a/requirements/loggers.txt +++ b/requirements/loggers.txt @@ -1,2 +1,5 @@ # test_tube>=0.7.5 -# trains>=0.14.1 \ No newline at end of file +# trains>=0.14.1 +matplotlib +wandb +scipy diff --git a/requirements/test.txt b/requirements/test.txt index 6c40d41e00..44e3cf1965 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -9,5 +9,6 @@ check-manifest twine==1.13.0 isort>=5.6.4 pre-commit>=1.0 +mypy atari-py==0.2.6 # needed for RL