[Bug] Incorrect ABC type signature of BaseModel.forward #1027

Closed
3 tasks done
Rocamonde opened this issue Aug 23, 2022 · 4 comments · Fixed by #1061
Labels
bug Something isn't working

@Rocamonde
Contributor

Rocamonde commented Aug 23, 2022

🐛 Bug

The type signature of the abstract .forward() method in sb3.common.policies.BaseModel is too general: no subclass can implement it with a more specific signature in a type-safe way.

The method is defined as:

```python
@abstractmethod
def forward(self, *args, **kwargs):
    pass
```

This means that every subclass of BaseModel (including BasePolicy) must accept arbitrary positional and keyword arguments in its forward signature to be type-compatible with the base class.

One therefore has to either

  1. violate the Liskov Substitution Principle and silence the type checker (e.g. with `# type: ignore`), or
  2. make the implementation type-unsafe by accepting arbitrary arguments, optionally checking `*args`/`**kwargs` by hand and raising `TypeError` when the wrong types are passed.
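Option 2 might look like the following minimal sketch. The class names mirror the minimal example in this issue; the validation logic is illustrative only and is not SB3 code.

```python
import abc
from typing import Any


class BaseClass(abc.ABC):
    @abc.abstractmethod
    def func(self, *args: Any, **kwargs: Any) -> None:
        ...


class BlueClass(BaseClass):
    # Keep the permissive signature so the override matches the base class,
    # and validate the arguments by hand at runtime instead.
    def func(self, *args: Any, **kwargs: Any) -> None:
        if kwargs:
            raise TypeError(f"unexpected keyword arguments: {sorted(kwargs)}")
        if len(args) != 2 or not all(isinstance(arg, str) for arg in args):
            raise TypeError("func() expects exactly two str arguments")
        arg_1, arg_2 = args
        print(arg_1, arg_2)


BlueClass().func("a", "b")  # prints: a b
```

The cost is that the signature no longer documents the expected arguments, and a type checker cannot flag a call like `BlueClass().func("a", 2)`; that mistake only surfaces as a `TypeError` when the call runs.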

To Reproduce

  1. Create a subclass of BaseModel whose forward method has a more restrictive signature than *args, **kwargs
  2. Run mypy on that file
  3. Observe an error of the form: Signature of "func" incompatible with supertype "BaseClass" (names taken from the minimal example)

Minimal example; mypy rejects the override in BlueClass:

```python
import abc


class BaseClass(abc.ABC):
    @abc.abstractmethod
    def func(self, *args, **kwargs) -> None:
        pass


class BlueClass(BaseClass):
    def func(self, arg_1: str, arg_2: str) -> None:
        print(arg_1, arg_2)
```
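A minimal runtime sketch of why the checker objects, reusing the BaseClass/BlueClass names from the minimal example (the helper `call_through_base` is hypothetical): code typed against the base class may legally pass any arguments, so substituting the narrower subclass can fail at runtime, which is exactly what the Liskov Substitution Principle rules out.

```python
import abc
from typing import Any


class BaseClass(abc.ABC):
    @abc.abstractmethod
    def func(self, *args: Any, **kwargs: Any) -> None:
        ...


class BlueClass(BaseClass):
    # Narrower signature than the base class: the kind of override
    # that mypy reported as incompatible in this issue.
    def func(self, arg_1: str, arg_2: str) -> None:  # type: ignore[override]
        print(arg_1, arg_2)


def call_through_base(obj: BaseClass) -> None:
    # Valid against the base signature, which accepts any arguments.
    obj.func(1, 2, 3, extra=True)


try:
    call_through_base(BlueClass())
    substitution_failed = False
except TypeError:
    # BlueClass.func only accepts two positional arguments.
    substitution_failed = True

print(substitution_failed)  # True
```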

Expected behavior

I would like to be allowed to implement this abstract method without having to hack the type checker.

System Info

  • Installed from pip
  • GPU models and configuration: N/A
  • Python version: 3.10
  • PyTorch version: N/A
  • Gym version: N/A
  • Versions of any other relevant libraries: N/A

Output of sb3.get_system_info():

>>> sb3.get_system_info()
OS: macOS-12.4-arm64-arm-64bit Darwin Kernel Version 21.5.0: Tue Apr 26 21:08:37 PDT 2022; root:xnu-8020.121.3~4/RELEASE_ARM64_T6000
Python: 3.10.5
Stable-Baselines3: 1.6.0
PyTorch: 1.12.1
GPU Enabled: False
Numpy: 1.23.2
Gym: 0.21.0

Potential solution

This has been discussed in mypy (python/mypy#5876), and the recommended approach is to avoid violating LSP where possible, especially in ABCs. However, there can be a good reason to require subclasses to implement a forward method whose specific signature depends on the class. A possible workaround (again using the minimal example) is:

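```python
import abc
from typing import Any, Callable


class BaseClass(abc.ABC):
    # Intent: type checkers treat `func` as Callable[..., None]...
    func: Callable[..., None]

    # ...while the decorated definition keeps `func` abstract at runtime.
    @abc.abstractmethod
    def func(self, *args: Any, **kwargs: Any) -> None:  # type: ignore
        ...


class BlueClass(BaseClass):
    def func(self, arg_1: str, arg_2: str) -> None:
        print(arg_1, arg_2)
```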

So, in the case of SB3 (inside the BaseModel class body):

```python
forward: Callable[..., Any]

@abstractmethod
def forward(self, *args, **kwargs):  # type: ignore
    pass
```
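One property of this workaround is that the @abstractmethod definition still exists at runtime, so ABCMeta keeps refusing to instantiate subclasses that forget forward. A minimal sketch, assuming a plain abc.ABC stand-in for SB3's BaseModel (the real class also derives from torch.nn.Module) and hypothetical subclasses MyPolicy and IncompletePolicy:

```python
import abc
from typing import Any, Callable


class BaseModel(abc.ABC):
    # Hypothetical stand-in for SB3's BaseModel, without torch.
    # The annotation is meant to be what type checkers see for `forward`...
    forward: Callable[..., Any]

    # ...while this definition still makes `forward` abstract at runtime.
    @abc.abstractmethod
    def forward(self, *args, **kwargs):  # type: ignore
        pass


class MyPolicy(BaseModel):
    # Narrower, class-specific signature.
    def forward(self, obs: float) -> float:
        return 2.0 * obs


class IncompletePolicy(BaseModel):
    # Forgets to implement `forward`.
    pass


print(MyPolicy().forward(1.5))  # 3.0

try:
    IncompletePolicy()
except TypeError as exc:
    # ABCMeta refuses to instantiate a class that leaves `forward` abstract.
    print("TypeError:", exc)
```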

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
Rocamonde added the bug label on Aug 23, 2022
@qgallouedec
Collaborator

Your recommendation seems legitimate.
Have you found any other such problems in the SB3 code?
Could simply removing these lines solve the problem?

```python
@abstractmethod
def forward(self, *args, **kwargs):
    pass
```

@AdamGleave
Collaborator

Deleting the definition of forward in BaseModel does seem like it should work, as long as nn.Module already defines it generically enough, which appears to be the case from https://github.com/pytorch/pytorch/blob/ce7a9f92e30b93ab6efff4135be005c9afd0533a/torch/nn/modules/module.py#L230-L244

I like PRs that just delete code :)
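The deletion approach relies on subclasses inheriting forward from nn.Module which, as far as we can tell from the linked source, declares it as a Callable[..., Any] attribute whose default raises NotImplementedError. The sketch below mimics that pattern with a plain stand-in Module class so it runs without PyTorch; Module, MyPolicy and IncompletePolicy are hypothetical names, and the error message is illustrative rather than PyTorch's exact text.

```python
from typing import Any, Callable


def _forward_unimplemented(self, *input: Any) -> None:
    # Default used when a subclass never defines `forward`.
    raise NotImplementedError(f"{type(self).__name__} does not implement forward")


class Module:
    # Hypothetical stand-in for torch.nn.Module: `forward` is declared as a
    # Callable[..., Any] attribute, so subclasses may narrow the signature.
    forward: Callable[..., Any] = _forward_unimplemented


class BaseModel(Module):
    # With the abstract `forward` deleted, BaseModel simply inherits
    # the generic declaration from Module.
    pass


class MyPolicy(BaseModel):
    def forward(self, obs: float) -> float:
        return 2.0 * obs


class IncompletePolicy(BaseModel):
    pass


print(MyPolicy().forward(1.5))  # 3.0
```

Note the trade-off: IncompletePolicy() can now be instantiated, and the missing forward is only reported when it is called, whereas the abstract-method version rejects such a class at instantiation.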

@Rocamonde
Contributor Author

Rocamonde commented Sep 9, 2022 via email
