Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor refactors - cleaning models #524

Merged
merged 14 commits into from
Jan 19, 2021
6 changes: 3 additions & 3 deletions pl_bolts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@
except NameError:
__LIGHTNING_BOLT_SETUP__: bool = False

if __LIGHTNING_BOLT_SETUP__:
import sys # pragma: no-cover
if __LIGHTNING_BOLT_SETUP__: # pragma: no cover
import sys

sys.stdout.write(f'Partial import of `{__name__}` during the build process.\n') # pragma: no-cover
sys.stdout.write(f'Partial import of `{__name__}` during the build process.\n')
# We are not importing the rest of the lightning during the build process, as it may not be compiled yet
else:
from pl_bolts import callbacks, datamodules, datasets, losses, metrics, models, optimizers, transforms, utils
Expand Down
2 changes: 1 addition & 1 deletion pl_bolts/callbacks/knn_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
warn_missing_pkg("sklearn", pypi_name="scikit-learn")


class KNNOnlineEvaluator(Callback): # pragma: no-cover
class KNNOnlineEvaluator(Callback): # pragma: no cover
"""
Evaluates self-supervised K nearest neighbors.

Expand Down
6 changes: 3 additions & 3 deletions pl_bolts/callbacks/vision/confused_logit.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
Figure = object


class ConfusedLogitCallback(Callback): # pragma: no-cover
class ConfusedLogitCallback(Callback): # pragma: no cover
"""
Takes the logit predictions of a model and when the probabilities of two classes are very close, the model
doesn't have high certainty that it should pick one vs the other class.
Expand Down Expand Up @@ -122,8 +122,8 @@ def _plot(
model: LightningModule,
mask_idxs: Tensor,
) -> None:
if not _MATPLOTLIB_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _MATPLOTLIB_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use `matplotlib` which is not installed yet, install it with `pip install matplotlib`.'
)

Expand Down
10 changes: 7 additions & 3 deletions pl_bolts/callbacks/vision/image_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import torch
from pytorch_lightning import Callback, LightningModule, Trainer

from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

try:
if _TORCHVISION_AVAILABLE:
import torchvision
except ModuleNotFoundError:
warn_missing_pkg("torchvision") # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg("torchvision")


class TensorboardGenerativeModelImageSampler(Callback):
Expand Down Expand Up @@ -57,6 +58,9 @@ def __init__(
images separately rather than the (min, max) over all images. Default: ``False``.
pad_value: Value for the padded pixels. Default: ``0``.
"""
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.")

super().__init__()
self.num_samples = num_samples
self.nrow = nrow
Expand Down
13 changes: 13 additions & 0 deletions pl_bolts/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,16 @@
from pl_bolts.models.regression import LinearRegression, LogisticRegression # noqa: F401
from pl_bolts.models.vision import PixelCNN, SemSegment, UNet # noqa: F401
from pl_bolts.models.vision.image_gpt.igpt_module import GPT2, ImageGPT # noqa: F401

__all__ = [
"AE",
"VAE",
"LitMNIST",
"LinearRegression",
"LogisticRegression",
"PixelCNN",
"SemSegment",
"UNet",
"GPT2",
"ImageGPT",
]
9 changes: 9 additions & 0 deletions pl_bolts/models/autoencoders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,12 @@
resnet50_decoder,
resnet50_encoder,
)

__all__ = [
"AE",
"VAE",
"resnet18_decoder",
"resnet18_encoder",
"resnet50_decoder",
"resnet50_encoder",
]
12 changes: 7 additions & 5 deletions pl_bolts/models/detection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
try:
from pl_bolts.models.detection import components # noqa: F401
from pl_bolts.models.detection.faster_rcnn import FasterRCNN # noqa: F401
except ModuleNotFoundError: # pragma: no-cover
pass # pragma: no-cover
from pl_bolts.models.detection import components # noqa: F401
from pl_bolts.models.detection.faster_rcnn import FasterRCNN # noqa: F401

__all__ = [
"components",
"FasterRCNN",
]
9 changes: 7 additions & 2 deletions pl_bolts/models/detection/faster_rcnn/faster_rcnn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@
import pytorch_lightning as pl
import torch

from pl_bolts.models.detection.faster_rcnn import create_fasterrcnn_backbone
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
from torchvision.models.detection.faster_rcnn import FasterRCNN as torchvision_FasterRCNN
from torchvision.models.detection.faster_rcnn import fasterrcnn_resnet50_fpn, FastRCNNPredictor
from torchvision.ops import box_iou

from pl_bolts.models.detection.faster_rcnn import create_fasterrcnn_backbone
else: # pragma: no cover
warn_missing_pkg("torchvision")

Expand All @@ -22,6 +21,9 @@ def _evaluate_iou(target, pred):
Evaluate intersection over union (IOU) for target from dataset and output prediction
from model
"""
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('You want to use `torchvision` which is not installed yet.')

if pred["boxes"].shape[0] == 0:
# no box detected, 0 IOU
return torch.tensor(0.0, device=pred["boxes"].device)
Expand Down Expand Up @@ -69,6 +71,9 @@ def __init__(
pretrained_backbone: if true, returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers: number of trainable resnet layers starting from final block
"""
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('You want to use `torchvision` which is not installed yet.')

super().__init__()

self.learning_rate = learning_rate
Expand Down
5 changes: 5 additions & 0 deletions pl_bolts/models/gans/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
from pl_bolts.models.gans.basic.basic_gan_module import GAN # noqa: F401
from pl_bolts.models.gans.dcgan.dcgan_module import DCGAN # noqa: F401

__all__ = [
"GAN",
"DCGAN",
]
2 changes: 1 addition & 1 deletion pl_bolts/models/gans/dcgan/dcgan_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
if _TORCHVISION_AVAILABLE:
from torchvision import transforms as transform_lib
from torchvision.datasets import LSUN, MNIST
else: # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg("torchvision")


Expand Down
10 changes: 7 additions & 3 deletions pl_bolts/models/mnist_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,22 @@
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split

from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

try:
if _TORCHVISION_AVAILABLE:
from torchvision import transforms
from torchvision.datasets import MNIST
except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('torchvision')


class LitMNIST(LightningModule):

def __init__(self, hidden_dim=128, learning_rate=1e-3, batch_size=32, num_workers=4, data_dir='', **kwargs):
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('You want to use `torchvision` which is not installed yet.')

super().__init__()
self.save_hyperparameters()

Expand Down
5 changes: 5 additions & 0 deletions pl_bolts/models/regression/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,7 @@
from pl_bolts.models.regression.linear_regression import LinearRegression # noqa: F401
from pl_bolts.models.regression.logistic_regression import LogisticRegression # noqa: F401

__all__ = [
"LinearRegression",
"LogisticRegression",
]
9 changes: 5 additions & 4 deletions pl_bolts/models/regression/linear_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,16 +110,17 @@ def add_model_specific_args(parent_parser):

def cli_main():
from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule
from pl_bolts.utils import _SKLEARN_AVAILABLE

pl.seed_everything(1234)

# create dataset
try:
if _SKLEARN_AVAILABLE:
from sklearn.datasets import load_boston
except ModuleNotFoundError as err:
raise ModuleNotFoundError( # pragma: no-cover
else: # pragma: no cover
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there an advantage or using "no cover" compare to "no-cover" I was thinking that the joined form is more like a tag :]

raise ModuleNotFoundError(
'You want to use `sklearn` which is not installed yet, install it with `pip install sklearn`.'
) from err
)

# args
parser = ArgumentParser()
Expand Down
9 changes: 5 additions & 4 deletions pl_bolts/models/regression/logistic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,16 +116,17 @@ def add_model_specific_args(parent_parser):

def cli_main():
from pl_bolts.datamodules.sklearn_datamodule import SklearnDataModule
from pl_bolts.utils import _SKLEARN_AVAILABLE

pl.seed_everything(1234)

# Example: Iris dataset in Sklearn (4 features, 3 class labels)
try:
if _SKLEARN_AVAILABLE:
from sklearn.datasets import load_iris
except ModuleNotFoundError as err:
raise ModuleNotFoundError( # pragma: no-cover
else: # pragma: no cover
raise ModuleNotFoundError(
'You want to use `sklearn` which is not installed yet, install it with `pip install sklearn`.'
) from err
)

# args
parser = ArgumentParser()
Expand Down
27 changes: 17 additions & 10 deletions pl_bolts/models/rl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
try:
from pl_bolts.models.rl.double_dqn_model import DoubleDQN # noqa: F401
from pl_bolts.models.rl.dqn_model import DQN # noqa: F401
from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN # noqa: F401
from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN # noqa: F401
from pl_bolts.models.rl.per_dqn_model import PERDQN # noqa: F401
from pl_bolts.models.rl.reinforce_model import Reinforce # noqa: F401
from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient # noqa: F401
except ModuleNotFoundError:
pass
from pl_bolts.models.rl.double_dqn_model import DoubleDQN # noqa: F401
from pl_bolts.models.rl.dqn_model import DQN # noqa: F401
from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN # noqa: F401
from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN # noqa: F401
from pl_bolts.models.rl.per_dqn_model import PERDQN # noqa: F401
from pl_bolts.models.rl.reinforce_model import Reinforce # noqa: F401
from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient # noqa: F401

__all__ = [
"DoubleDQN",
"DQN",
"DuelingDQN",
"NoisyDQN",
"PERDQN",
"Reinforce",
"VanillaPolicyGradient",
]
24 changes: 17 additions & 7 deletions pl_bolts/models/rl/common/gym_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,21 +14,24 @@
import gym.spaces
from gym import make as gym_make
from gym import ObservationWrapper, Wrapper
else: # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('gym')
Wrapper = object
ObservationWrapper = object

if _OPENCV_AVAILABLE:
import cv2
else:
warn_missing_pkg('cv2', pypi_name='opencv-python') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('cv2', pypi_name='opencv-python')


class ToTensor(Wrapper):
"""For environments where the user need to press FIRE for the game to start."""

def __init__(self, env=None):
if not _GYM_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('You want to use `gym` which is not installed yet.')

super(ToTensor, self).__init__(env)

def step(self, action):
Expand All @@ -45,6 +48,9 @@ class FireResetEnv(Wrapper):
"""For environments where the user need to press FIRE for the game to start."""

def __init__(self, env=None):
if not _GYM_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('You want to use `gym` which is not installed yet.')

super(FireResetEnv, self).__init__(env)
assert env.unwrapped.get_action_meanings()[1] == "FIRE"
assert len(env.unwrapped.get_action_meanings()) >= 3
Expand All @@ -69,6 +75,9 @@ class MaxAndSkipEnv(Wrapper):
"""Return only every `skip`-th frame"""

def __init__(self, env=None, skip=4):
if not _GYM_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('You want to use `gym` which is not installed yet.')

super(MaxAndSkipEnv, self).__init__(env)
# most recent raw observations (for max pooling across time steps)
self._obs_buffer = collections.deque(maxlen=2)
Expand Down Expand Up @@ -99,8 +108,7 @@ class ProcessFrame84(ObservationWrapper):
"""preprocessing images from env"""

def __init__(self, env=None):

if not _OPENCV_AVAILABLE:
if not _OPENCV_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('This class uses OpenCV which it is not installed yet.')

super(ProcessFrame84, self).__init__(env)
Expand Down Expand Up @@ -130,8 +138,7 @@ class ImageToPyTorch(ObservationWrapper):
"""converts image to pytorch format"""

def __init__(self, env):

if not _OPENCV_AVAILABLE:
if not _OPENCV_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('This class uses OpenCV which it is not installed yet.')

super(ImageToPyTorch, self).__init__(env)
Expand Down Expand Up @@ -188,6 +195,9 @@ class DataAugmentation(ObservationWrapper):
"""

def __init__(self, env=None):
if not _GYM_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('You want to use `gym` which is not installed yet.')

super().__init__(env)
self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

Expand Down
4 changes: 2 additions & 2 deletions pl_bolts/models/rl/dqn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@

if _GYM_AVAILABLE:
from gym import Env
else:
warn_missing_pkg('gym') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('gym')
Env = object


Expand Down
6 changes: 3 additions & 3 deletions pl_bolts/models/rl/reinforce_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

if _GYM_AVAILABLE:
import gym
else:
warn_missing_pkg('gym') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('gym')


class Reinforce(pl.LightningModule):
Expand Down Expand Up @@ -80,7 +80,7 @@ def __init__(
"""
super().__init__()

if not _GYM_AVAILABLE:
if not _GYM_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('This Module requires gym environment which is not installed yet.')

# Hyperparameters
Expand Down
Loading