diff --git a/.flake8 b/.flake8 index daef398..b0c1e83 100644 --- a/.flake8 +++ b/.flake8 @@ -1,4 +1,5 @@ [flake8] max-line-length = 120 -ignore = F401, E402, E265, F403, W503, W504, F821, W605 +ignore = E402, E265, F403, W503, W504, E731 exclude = .github, .git, venv*, docs, build +per-file-ignores = **/__init__.py:F401 diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md deleted file mode 100644 index 436ad1c..0000000 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ /dev/null @@ -1,45 +0,0 @@ ---- -name: Bug report -about: Create a report to help us improve -title: '' -labels: bug -assignees: frgfm - ---- - -## 🐛 Bug - - - -## To Reproduce - -Steps to reproduce the behavior: - -1. -2. -3. - - - -## Expected behavior - - - -## Environment - -Please describe your environement so that the bug can be easily reproduced: - - - Holocron Version (e.g., 0.1.1): - - PyTorch Version (e.g., 1.7): - - Torchvision Version (e.g., 0.8): - - OS (e.g., Linux): - - How you installed Holocron (`conda`, `pip`, source): - - Python version: - - CUDA/cuDNN version: - - GPU models and configuration: - - Any other relevant information: - - -## Additional context - - diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml new file mode 100644 index 0000000..8c26978 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -0,0 +1,64 @@ +name: 🐛 Bug report +description: Create a report to help us improve the library +labels: 'type: bug' +assignees: frgfm + +body: +- type: markdown + attributes: + value: > + #### Before reporting a bug, please check that the issue hasn't already been addressed in [the existing and past issues](https://github.com/frgfm/torch-scan/issues?q=is%3Aissue). +- type: textarea + attributes: + label: Bug description + description: | + A clear and concise description of what the bug is. + + Please explain the result you observed and the behavior you were expecting. + placeholder: | + A clear and concise description of what the bug is. + validations: + required: true + +- type: textarea + attributes: + label: Code snippet to reproduce the bug + description: | + Sample code to reproduce the problem. + + Please wrap your code snippet with ```` ```triple quotes blocks``` ```` for readability. + placeholder: | + ```python + Sample code to reproduce the problem + ``` + validations: + required: true +- type: textarea + attributes: + label: Error traceback + description: | + The error message you received running the code snippet, with the full traceback. + + Please wrap your error message with ```` ```triple quotes blocks``` ```` for readability. + placeholder: | + ``` + The error message you got, with the full traceback. + ``` + validations: + required: true +- type: textarea + attributes: + label: Environment + description: | + Please run the following command and paste the output below. + ```sh + wget https://raw.githubusercontent.com/frgfm/torch-scan/master/scripts/collect_env.py + # For security purposes, please check the contents of collect_env.py before running it. + python collect_env.py + ``` + validations: + required: true +- type: markdown + attributes: + value: > + Thanks for helping us improve the library! 
\ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000..af4b2ed --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,5 @@ +blank_issues_enabled: true +contact_links: + - name: Usage questions + url: https://github.com/frgfm/torch-scan/discussions + about: Ask questions and discuss with other TorchScan community members diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md deleted file mode 100644 index bf37197..0000000 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ /dev/null @@ -1,27 +0,0 @@ ---- -name: Feature request -about: Suggest an idea for this project -title: '' -labels: enhancement -assignees: frgfm - ---- - -## 🚀 Feature - - -## Motivation - - - -## Pitch - - - -## Alternatives - - - -## Additional context - - diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml new file mode 100644 index 0000000..00b9a38 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -0,0 +1,34 @@ +name: 🚀 Feature request +description: Submit a proposal/request for a new feature +labels: 'type: enhancement' +assignees: frgfm + +body: +- type: textarea + attributes: + label: 🚀 Feature + description: > + A clear and concise description of the feature proposal + validations: + required: true +- type: textarea + attributes: + label: Motivation & pitch + description: > + Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too. + validations: + required: true +- type: textarea + attributes: + label: Alternatives + description: > + A description of any alternative solutions or features you've considered, if any. +- type: textarea + attributes: + label: Additional context + description: > + Add any other context or screenshots about the feature request. +- type: markdown + attributes: + value: > + Thanks for contributing 🎉 \ No newline at end of file diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..eda7575 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,20 @@ +# What does this PR do? + + + + + +Closes # (issue) + + +## Before submitting +- [ ] Was this discussed/approved in a GitHub [issue](https://github.com/frgfm/torch-scan/issues?q=is%3Aissue) or a [discussion](https://github.com/frgfm/torch-scan/discussions)? Please add a link to it if that's the case. +- [ ] You have read the [contribution guidelines](https://github.com/frgfm/torch-scan/blob/master/CONTRIBUTING.md#submitting-a-pull-request) and followed them in this PR. +- [ ] Did you make sure to update the documentation with your changes? Here are the + [documentation guidelines](https://github.com/frgfm/torch-scan/tree/master/docs). +- [ ] Did you write any new necessary tests? 
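The `.flake8` change at the top of this diff stops ignoring `F401` (unused imports) globally and instead scopes it to `**/__init__.py` through `per-file-ignores`, so only re-export modules are exempt while every other file is checked. A minimal sketch of the pattern this keeps lint-clean, with a hypothetical package layout (the exact re-exports are not shown in this diff):

```python
# torchscan/__init__.py (hypothetical contents, for illustration only)
# Re-exports like these are normally reported by flake8 as F401
# ("imported but unused"); the per-file-ignores entry allows them in
# __init__.py files while regular modules are still checked.
from .crawler import crawl_module, summary
from .version import __version__
```

A stale import left behind in a regular module (e.g. `torchscan/crawler.py`) now fails the `flake8 ./` check run in CI.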
diff --git a/.github/release.yml b/.github/release.yml new file mode 100644 index 0000000..8f7f47c --- /dev/null +++ b/.github/release.yml @@ -0,0 +1,24 @@ +changelog: + exclude: + labels: + - ignore-for-release + categories: + - title: Breaking Changes 🛠 + labels: + - "type: breaking change" + # NEW FEATURES + - title: New Features 🚀 + labels: + - "type: new feature" + # BUG FIXES + - title: Bug Fixes 🐛 + labels: + - "type: bug" + # IMPROVEMENTS + - title: Improvements + labels: + - "type: enhancement" + # MISC + - title: Miscellaneous + labels: + - "type: misc" diff --git a/.github/validate_deps.py b/.github/validate_deps.py new file mode 100644 index 0000000..cd5a736 --- /dev/null +++ b/.github/validate_deps.py @@ -0,0 +1,64 @@ +from pathlib import Path + +import requirements +from requirements.requirement import Requirement + +# Deps that won't have a specific requirements.txt +IGNORE = ["flake8", "isort", "mypy", "pydocstyle"] +# All req files to check +REQ_FILES = ["requirements.txt", "tests/requirements.txt", "docs/requirements.txt"] + + +def main(): + + # Collect the deps from all requirements.txt + folder = Path(__file__).parent.parent.absolute() + req_deps = {} + for file in REQ_FILES: + with open(folder.joinpath(file), 'r') as f: + _deps = [(req.name, req.specs) for req in requirements.parse(f)] + + for _dep in _deps: + lib, specs = _dep + assert req_deps.get(lib, specs) == specs, f"conflicting deps for {lib}" + req_deps[lib] = specs + + # Collect the one from setup.py + setup_deps = {} + with open(folder.joinpath("setup.py"), 'r') as f: + setup = f.readlines() + lines = setup[setup.index("_deps = [\n") + 1:] + lines = [_dep.strip() for _dep in lines[:lines.index("]\n")]] + lines = [_dep.split('"')[1] for _dep in lines if _dep.startswith('"')] + _reqs = [Requirement.parse(_line) for _line in lines] + _deps = [(req.name, req.specs) for req in _reqs] + for _dep in _deps: + lib, specs = _dep + assert setup_deps.get(lib) is None, f"conflicting deps for {lib}" + setup_deps[lib] = specs + + # Remove ignores + for k in IGNORE: + if isinstance(req_deps.get(k), list): + del req_deps[k] + if isinstance(setup_deps.get(k), list): + del setup_deps[k] + + # Compare them + assert len(req_deps) == len(setup_deps) + mismatches = [] + for k, v in setup_deps.items(): + assert isinstance(req_deps.get(k), list) + if req_deps[k] != v: + mismatches.append((k, v, req_deps[k])) + + if len(mismatches) > 0: + mismatch_str = "version specifiers mismatches:\n" + mismatch_str += '\n'.join( + f"- {lib}: {setup} (from setup.py) | {reqs} (from requirements)" + for lib, setup, reqs in mismatches + ) + raise AssertionError(mismatch_str) + +if __name__ == "__main__": + main() diff --git a/.github/validate_headers.py b/.github/validate_headers.py new file mode 100644 index 0000000..8ef5d9d --- /dev/null +++ b/.github/validate_headers.py @@ -0,0 +1,64 @@ +from datetime import datetime +from pathlib import Path + +shebang = ["#!usr/bin/python\n"] +blank_line = "\n" + +# Possible years +starting_year = 2020 +current_year = datetime.now().year + +year_options = [f"{current_year}"] + [f"{year}-{current_year}" for year in range(starting_year, current_year)] +copyright_notices = [ + [f"# Copyright (C) {year_str}, François-Guillaume Fernandez.\n"] + for year_str in year_options +] +license_notice = [ + "# This program is licensed under the Apache License version 2.\n", + "# See LICENSE or go to for full license details.\n" +] + +# Define all header options +HEADERS = [ + shebang + [blank_line] + copyright_notice + 
[blank_line] + license_notice + for copyright_notice in copyright_notices +] + [ + copyright_notice + [blank_line] + license_notice + for copyright_notice in copyright_notices +] + + +IGNORED_FILES = ["version.py", "__init__.py"] +FOLDERS = ["torchscan", "scripts"] + + +def main(): + + invalid_files = [] + + # For every python file in the repository + for folder in FOLDERS: + for source_path in Path(__file__).parent.parent.joinpath(folder).rglob('**/*.py'): + if source_path.name not in IGNORED_FILES: + # Parse header + header_length = max(len(option) for option in HEADERS) + current_header = [] + with open(source_path) as f: + for idx, line in enumerate(f): + current_header.append(line) + if idx == header_length - 1: + break + # Validate it + if not any( + "".join(current_header[:min(len(option), len(current_header))]) == "".join(option) + for option in HEADERS + ): + invalid_files.append(source_path) + + if len(invalid_files) > 0: + invalid_str = "\n- " + "\n- ".join(map(str, invalid_files)) + raise AssertionError(f"Invalid header in the following files:{invalid_str}") + + +if __name__ == "__main__": + main() diff --git a/.github/verify_labels.py b/.github/verify_labels.py new file mode 100644 index 0000000..5d109cb --- /dev/null +++ b/.github/verify_labels.py @@ -0,0 +1,75 @@ +""" +Borrowed & adapted from https://github.com/pytorch/vision/blob/main/.github/process_commit.py +This script finds the merger responsible for labeling a PR by a commit SHA. It is used by the workflow in +'.github/workflows/pr-labels.yml'. If there exists no PR associated with the commit or the PR is properly labeled, +this script is a no-op. +Note: we ping the merger only, not the reviewers, as the reviewers can sometimes be external to torchvision +with no labeling responsibility, so we don't want to bother them. 
+""" + +from typing import Any, Set, Tuple + +import requests + +# For a PR to be properly labeled it should have one primary label and one secondary label + +# Should specify the type of change +PRIMARY_LABELS = { + "type: new feature", + "type: bug", + "type: enhancement", + "type: misc", +} + +# Should specify what has been modified +SECONDARY_LABELS = { + "topic: documentation", + "module: modules", + "module: process", + "module: crawler", + "module: utils", + "ext: docs", + "ext: scripts", + "ext: tests", + "topic: build", + "topic: ci", +} + +GH_ORG = 'frgfm' +GH_REPO = 'torch-scan' + + +def query_repo(cmd: str, *, accept) -> Any: + response = requests.get(f"https://api.github.com/repos/{GH_ORG}/{GH_REPO}/{cmd}", headers=dict(Accept=accept)) + return response.json() + + +def get_pr_merger_and_labels(pr_number: int) -> Tuple[str, Set[str]]: + # See https://docs.github.com/en/rest/reference/pulls#get-a-pull-request + data = query_repo(f"pulls/{pr_number}", accept="application/vnd.github.v3+json") + merger = data.get("merged_by", {}).get("login") + labels = {label["name"] for label in data["labels"]} + return merger, labels + + +def main(args): + merger, labels = get_pr_merger_and_labels(args.pr) + is_properly_labeled = bool(PRIMARY_LABELS.intersection(labels) and SECONDARY_LABELS.intersection(labels)) + if isinstance(merger, str) and not is_properly_labeled: + print(f"@{merger}") + + +def parse_args(): + import argparse + parser = argparse.ArgumentParser(description='PR label checker', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument('pr', type=int, help='PR number') + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/.github/workflows/doc-status.yaml b/.github/workflows/doc-status.yml similarity index 88% rename from .github/workflows/doc-status.yaml rename to .github/workflows/doc-status.yml index e655182..96e3534 100644 --- a/.github/workflows/doc-status.yaml +++ b/.github/workflows/doc-status.yml @@ -1,4 +1,4 @@ -name: doc-status +name: GH-Pages Status on: page_build @@ -19,4 +19,4 @@ jobs: shell: python env: STATUS: ${{ github.event.build.status }} - ERROR: ${{ github.event.build.error.message }} + ERROR: ${{ github.event.build.error.message }} \ No newline at end of file diff --git a/.github/workflows/docs.yaml b/.github/workflows/docs.yml similarity index 92% rename from .github/workflows/docs.yaml rename to .github/workflows/docs.yml index 83c2c7c..9b3c84a 100644 --- a/.github/workflows/docs.yaml +++ b/.github/workflows/docs.yml @@ -33,12 +33,14 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -e . - pip install -r docs/requirements.txt + pip install -e ".[docs]" - name: Build documentation run: cd docs && bash build.sh + - name: Documentation sanity check + run: test -e docs/build/index.html || exit + - name: Install SSH Client 🔑 uses: webfactory/ssh-agent@v0.4.1 with: diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 9abc52d..daa5c09 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -15,6 +15,8 @@ jobs: python: [3.7] steps: - uses: actions/checkout@v2 + with: + persist-credentials: false - name: Set up Python uses: actions/setup-python@v1 with: @@ -35,7 +37,7 @@ jobs: python -m pip install --upgrade pip pip install -e . 
--upgrade - unittest: + pytest: needs: install runs-on: ${{ matrix.os }} strategy: @@ -44,6 +46,8 @@ jobs: python: [3.7] steps: - uses: actions/checkout@v2 + with: + persist-credentials: false - name: Set up Python uses: actions/setup-python@v1 with: @@ -53,8 +57,9 @@ jobs: uses: actions/cache@v2 with: path: ~/.cache/pip - key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('**/*.py') }} + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('tests/requirements.txt') }}-${{ hashFiles('**/*.py') }} restore-keys: | + ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('tests/requirements.txt') }}- ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}- ${{ runner.os }}-pkg-deps-${{ matrix.python }}- ${{ runner.os }}-pkg-deps- @@ -62,11 +67,10 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -e . --upgrade - pip install -r test/requirements.txt + pip install -e ".[testing]" --upgrade - name: Run unittests run: | - coverage run -m unittest discover test/ + coverage run -m pytest tests/ coverage xml - uses: actions/upload-artifact@v2 with: @@ -75,7 +79,7 @@ jobs: codecov-upload: runs-on: ubuntu-latest - needs: unittest + needs: pytest steps: - uses: actions/checkout@v2 - uses: actions/download-artifact@v2 @@ -86,8 +90,8 @@ jobs: fail_ci_if_error: true docs-build: - needs: install runs-on: ${{ matrix.os }} + needs: install strategy: matrix: os: [ubuntu-latest] @@ -115,8 +119,50 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -e . --upgrade - pip install -r docs/requirements.txt + pip install -e ".[docs]" --upgrade - name: Build documentation run: cd docs && bash build.sh + + - name: Documentation sanity check + run: test -e docs/build/index.html || exit + + headers: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python: [3.7] + steps: + - uses: actions/checkout@v2 + with: + persist-credentials: false + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Run unittests + run: python .github/validate_headers.py + + dependencies: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python: [3.7] + steps: + - uses: actions/checkout@v2 + with: + persist-credentials: false + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install requirements-parser==0.2.0 + - name: Run unittests + run: python .github/validate_deps.py diff --git a/.github/workflows/pr-labels.yml b/.github/workflows/pr-labels.yml new file mode 100644 index 0000000..d7a6671 --- /dev/null +++ b/.github/workflows/pr-labels.yml @@ -0,0 +1,29 @@ +name: pr-labels + +on: + pull_request: + branches: master + types: closed + +jobs: + is-properly-labeled: + if: github.event.pull_request.merged == true + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v2 + - name: Set up python + uses: actions/setup-python@v2 + - name: Install requests + run: pip install requests + - name: Process commit and find merger responsible for labeling + id: commit + run: echo "::set-output name=merger::$(python .github/verify_labels.py ${{ github.event.pull_request.number }})" + - name: 
Comment PR + uses: actions/github-script@0.3.0 + if: ${{ steps.commit.outputs.merger != '' }} + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const { issue: { number: issue_number }, repo: { owner, repo } } = context; + github.issues.createComment({ issue_number, owner, repo, body: 'Hey ${{ steps.commit.outputs.merger }} 👋\nYou merged this PR, but it is not correctly labeled. The list of valid labels is available at https://github.com/frgfm/torch-cam/blob/master/.github/verify_labels.py' }); diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yml similarity index 90% rename from .github/workflows/release.yaml rename to .github/workflows/release.yml index 9ef69a5..34da599 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yml @@ -1,11 +1,10 @@ -name: pypi-publish +name: release on: release: types: [published] jobs: - pypi-publish: runs-on: ubuntu-latest steps: @@ -65,10 +64,11 @@ jobs: steps: - uses: actions/checkout@v2 - name: Miniconda setup - uses: goanpeca/setup-miniconda@v1 + uses: conda-incubator/setup-miniconda@v2 with: auto-update-conda: true python-version: 3.7 + auto-activate-base: true - name: Install dependencies run: | conda install -y conda-build conda-verify anaconda-client @@ -76,14 +76,12 @@ jobs: id: release_tag run: | echo ::set-output name=VERSION::${GITHUB_REF/refs\/tags\//} - - name: Anaconda login - run: | - conda install -y conda-build conda-verify anaconda-client - name: Build and publish env: ANACONDA_API_TOKEN: ${{ secrets.ANACONDA_TOKEN }} - BUILD_VERSION: ${{ steps.release_tag.outputs.VERSION }} + VERSION: ${{ steps.release_tag.outputs.VERSION }} run: | + export BUILD_VERSION="${VERSION:1}" python setup.py sdist mkdir conda-dist conda-build .conda/ -c pytorch --output-folder conda-dist @@ -95,14 +93,13 @@ jobs: runs-on: ubuntu-latest needs: conda-publish steps: - - uses: actions/checkout@v2 - name: Miniconda setup - uses: goanpeca/setup-miniconda@v1 + uses: conda-incubator/setup-miniconda@v2 with: auto-update-conda: true python-version: 3.7 + auto-activate-base: true - name: Install package run: | conda install -c frgfm torchscan python -c "import torchscan; print(torchscan.__version__)" - diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 830a581..d70729f 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -26,6 +26,26 @@ jobs: flake8 --version flake8 ./ + isort-py3: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python: [3.7] + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Run isort + run: | + pip install isort + isort --version + isort . 
+ if [ -n "$(git status --porcelain --untracked-files=no)" ]; then exit 1; else echo "All clear"; fi + mypy-py3: runs-on: ${{ matrix.os }} strategy: @@ -58,3 +78,23 @@ jobs: run: | mypy --version mypy --config-file mypy.ini torchscan/ + + pydocstyle-py3: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python: [3.7] + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Run pydocstyle + run: | + pip install pydocstyle + pydocstyle --version + pydocstyle torchscan/ + diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 0000000..248f68d --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,5 @@ +[settings] +line_length = 120 +src_paths = torchscan,tests +skip_glob=**/__init__.py +known_third_party=torch,torchvision diff --git a/.pydocstyle b/.pydocstyle new file mode 100644 index 0000000..f81d27e --- /dev/null +++ b/.pydocstyle @@ -0,0 +1,3 @@ +[pydocstyle] +select = D300,D301,D417 +match = .*\.py diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..c29a0db --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,128 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, religion, or sexual identity +and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. 
+Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +fg-feedback@protonmail.com. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or +permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within +the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.0, available at +https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by [Mozilla's code of conduct +enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1fe8ff3..1711c8a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -2,12 +2,13 @@ Everything you need to know to contribute efficiently to the project. +Whatever the way you wish to contribute to the project, please respect the [code of conduct](CODE_OF_CONDUCT.md). 
## Codebase structure - [torchscan](https://github.com/frgfm/torch-scan/blob/master/torchscan) - The actual torchscan library -- [test](https://github.com/frgfm/torch-scan/blob/master/test) - Python unit tests +- [tests](https://github.com/frgfm/torch-scan/blob/master/tests) - Python unit tests - [docs](https://github.com/frgfm/torch-scan/blob/master/docs) - Sphinx documentation building - [scripts](https://github.com/frgfm/torch-scan/blob/master/scripts) - Example and utilities scripts @@ -24,33 +25,67 @@ This project uses the following integrations to ensure proper codebase maintenan As a contributor, you will only have to ensure coverage of your code by adding appropriate unit testing of your code. +## Feedback -## Issues +### Feature requests & bug report -Use Github [issues](https://github.com/frgfm/torch-scan/issues) for feature requests, or bug reporting. When doing so, use issue templates whenever possible and provide enough information for other contributors to jump in. +Whether you encountered a problem, or you have a feature suggestion, your input has value and can be used by contributors to reference it in their developments. For this purpose, we advise you to use Github [issues](https://github.com/frgfm/torch-scan/issues). +First, check whether the topic wasn't already covered in an open / closed issue. If not, feel free to open a new one! When doing so, use issue templates whenever possible and provide enough information for other contributors to jump in. +### Questions -## Developping Torchscan +If you are wondering how to do something with TorchScan, or a more general question, you should consider checking out Github [discussions](https://github.com/frgfm/torch-scan/discussions). See it as a Q&A forum, or the TorchScan-specific StackOverflow! -### Commits -- **Code**: ensure to provide docstrings to your Python code. In doing so, please follow [Google-style](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) so it can ease the process of documentation later. -- **Commit message**: please follow [Udacity guide](http://udacity.github.io/git-styleguide/) +## Submitting a Pull Request + +### Preparing your local branch + +1 - Fork this [repository](https://github.com/frgfm/torch-scan) by clicking on the "Fork" button at the top right of the page. This will create a copy of the project under your GitHub account (cf. [Fork a repo](https://docs.github.com/en/get-started/quickstart/fork-a-repo)). + +2 - [Clone your fork](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository) to your local disk and set the upstream to this repo +```shell +git clone git@github.com:/torch-scan.git +cd torch-scan +git remote add upstream https://github.com/frgfm/torch-scan.git +``` + +3 - You should not work on the `master` branch, so let's create a new one +```shell +git checkout -b a-short-description +``` + +4 - You only have to set your development environment now. First uninstall any existing installation of the library with `pip uninstall torch-scan`, then: +```shell +pip install -e ".[dev]" +``` + +### Developing your feature +#### Commits -### Running CI verifications locally +- **Code**: ensure to provide docstrings to your Python code. In doing so, please follow [Google-style](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) so it can ease the process of documentation later. 
+- **Commit message**: please follow [Udacity guide](http://udacity.github.io/git-styleguide/) #### Unit tests In order to run the same unit tests as the CI workflows, you can run unittests locally: ```shell -coverage run -m unittest discover test/ +make test ``` -#### Lint verification +#### Code quality + +To run all quality checks together + +```shell +make quality +``` + +##### Lint verification To ensure that your incoming PR complies with the lint settings, you need to install [flake8](https://flake8.pycqa.org/en/latest/) and run the following command from the repository's root folder: @@ -59,11 +94,29 @@ flake8 ./ ``` This will read the `.flake8` setting file and let you know whether your commits need some adjustments. -#### Annotation typing +##### Import order + +In order to ensure there is a common import order convention, run [isort](https://github.com/PyCQA/isort) as follows: + +```shell +isort **/*.py +``` +This will reorder the imports of your local files. + +##### Annotation typing Additionally, to catch type-related issues and have a cleaner codebase, annotation typing are expected. After installing [mypy](https://github.com/python/mypy), you can run the verifications as follows: ```shell -mypy --config-file mypy.ini torchscan/ +mypy --config-file mypy.ini ``` The `mypy.ini` file will be read to check your typing. + +### Submit your modifications + +Push your last modifications to your remote branch +```shell +git push -u origin a-short-description +``` + +Then [open a Pull Request](https://docs.github.com/en/github/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request) from your fork's branch. Follow the instructions of the Pull Request template and then click on "Create a pull request". diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..a87e974 --- /dev/null +++ b/Makefile @@ -0,0 +1,18 @@ +# this target runs checks on all files +quality: + isort . -c -v + flake8 ./ + mypy torchscan/ + pydocstyle torchscan/ + +# this target runs checks on all files and potentially modifies some of them +style: + isort . + +# Run tests for the library +test: + coverage run -m pytest tests/ + +# Check that docs can build +docs: + cd docs && bash build.sh diff --git a/docs/requirements.txt b/docs/requirements.txt index d5f1c66..41db998 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,2 +1,6 @@ -sphinx -sphinx-rtd-theme==0.4.3 \ No newline at end of file +sphinx<=3.4.3 +sphinx-rtd-theme==0.4.3 +sphinxemoji>=0.1.8 +sphinx-copybutton>=0.3.1 +docutils<0.18 +Jinja2<3.1 diff --git a/docs/source/changelog.rst b/docs/source/changelog.rst new file mode 100644 index 0000000..8598f30 --- /dev/null +++ b/docs/source/changelog.rst @@ -0,0 +1,11 @@ +Changelog +========= + + +v0.1.1 (2020-08-04) +------------------- +Release note: `v0.1.1 `_ + +v0.1.0 (2020-05-21) +------------------- +Release note: `v0.1.0 `_ diff --git a/docs/source/conf.py b/docs/source/conf.py index 3632550..d34b05a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -11,17 +11,20 @@ # -- Path setup -------------------------------------------------------------- -import sphinx_rtd_theme # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. 
# import os import sys + +import sphinx_rtd_theme + sys.path.insert(0, os.path.abspath('../..')) -import torchscan from datetime import datetime +import torchscan + # -- Project information ----------------------------------------------------- master_doc = 'index' @@ -43,7 +46,9 @@ 'sphinx.ext.autodoc', 'sphinx.ext.napoleon', 'sphinx.ext.viewcode', - 'sphinx.ext.mathjax' + 'sphinx.ext.mathjax', + 'sphinxemoji.sphinxemoji', # cf. https://sphinxemojicodes.readthedocs.io/en/stable/ + 'sphinx_copybutton', ] napoleon_use_ivar = True diff --git a/docs/source/index.rst b/docs/source/index.rst index a2bcf45..300ac47 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,25 +1,36 @@ -Torchscan documentation -======================= +TorchScan: inspect your PyTorch models +====================================== The :mod:`torchscan` package provides tools for analyzing your PyTorch modules and models. Additionally to performance benchmarks, a comprehensive architecture comparison require some insights in the model complexity, its usage of computational and memory resources. +This project is meant for: + +* |:zap:| **exploration**: easily assess the influence of your architecture on resource consumption +* |:woman_scientist:| **research**: quickly implement your own ideas to mitigate latency + + +.. toctree:: + :maxdepth: 2 + :caption: Getting Started + :hidden: + + installing + + .. toctree:: :maxdepth: 1 :caption: Package Reference + :hidden: torchscan modules process utils +.. toctree:: + :maxdepth: 2 + :caption: Notes + :hidden: -.. automodule:: torchscan - :members: - - -Indices and tables -================== - -* :ref:`genindex` -* :ref:`modindex` + changelog diff --git a/docs/source/installing.rst b/docs/source/installing.rst new file mode 100644 index 0000000..bac566f --- /dev/null +++ b/docs/source/installing.rst @@ -0,0 +1,36 @@ + +************ +Installation +************ + +This library requires `Python `_ 3.6 or higher. + +Via Python Package +================== + +Install the last stable release of the package using `pip `_: + +.. code:: bash + + pip install torchscan + + +Via Conda +========= + +Install the last stable release of the package using `conda `_: + +.. code:: bash + + conda install -c frgfm torchscan + + +Via Git +======= + +Install the library in developer mode: + +.. code:: bash + + git clone https://github.com/frgfm/torch-scan.git + pip install -e torch-scan/. diff --git a/scripts/benchmark.py b/scripts/benchmark.py index 8e5fa53..f6388d3 100644 --- a/scripts/benchmark.py +++ b/scripts/benchmark.py @@ -1,4 +1,4 @@ -# Copyright (C) 2020-2021, François-Guillaume Fernandez. +# Copyright (C) 2020-2022, François-Guillaume Fernandez. # This program is licensed under the Apache License version 2. # See LICENSE or go to for full license details. diff --git a/scripts/collect_env.py b/scripts/collect_env.py new file mode 100644 index 0000000..16bab7c --- /dev/null +++ b/scripts/collect_env.py @@ -0,0 +1,318 @@ +# Copyright (C) 2020-2022, François-Guillaume Fernandez. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +""" +Based on https://github.com/pytorch/pytorch/blob/master/torch/utils/collect_env.py +This script outputs relevant system environment info +Run it with `python collect_env.py`. 
+""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +import locale +import os +import re +import subprocess +import sys +from collections import namedtuple + +try: + import torchscan + TORCHSCAN_AVAILABLE = True +except (ImportError, NameError, AttributeError, OSError): + TORCHSCAN_AVAILABLE = False + +try: + import torch + TORCH_AVAILABLE = True +except (ImportError, NameError, AttributeError, OSError): + TORCH_AVAILABLE = False + +PY3 = sys.version_info >= (3, 0) + + +# System Environment Information +SystemEnv = namedtuple('SystemEnv', [ + 'torchscan_version', + 'torch_version', + 'os', + 'python_version', + 'is_cuda_available', + 'cuda_runtime_version', + 'nvidia_driver_version', + 'nvidia_gpu_models', + 'cudnn_version', +]) + + +def run(command): + """Returns (return-code, stdout, stderr)""" + p = subprocess.Popen(command, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, shell=True) + output, err = p.communicate() + rc = p.returncode + if PY3: + enc = locale.getpreferredencoding() + output = output.decode(enc) + err = err.decode(enc) + return rc, output.strip(), err.strip() + + +def run_and_read_all(run_lambda, command): + """Runs command using run_lambda; reads and returns entire output if rc is 0""" + rc, out, _ = run_lambda(command) + if rc != 0: + return None + return out + + +def run_and_parse_first_match(run_lambda, command, regex): + """Runs command using run_lambda, returns the first regex match if it exists""" + rc, out, _ = run_lambda(command) + if rc != 0: + return None + match = re.search(regex, out) + if match is None: + return None + return match.group(1) + + +def get_nvidia_driver_version(run_lambda): + if get_platform() == 'darwin': + cmd = 'kextstat | grep -i cuda' + return run_and_parse_first_match(run_lambda, cmd, + r'com[.]nvidia[.]CUDA [(](.*?)[)]') + smi = get_nvidia_smi() + return run_and_parse_first_match(run_lambda, smi, r'Driver Version: (.*?) ') + + +def get_gpu_info(run_lambda): + if get_platform() == 'darwin': + if TORCH_AVAILABLE and torch.cuda.is_available(): + return torch.cuda.get_device_name(None) + return None + smi = get_nvidia_smi() + uuid_regex = re.compile(r' \(UUID: .+?\)') + rc, out, _ = run_lambda(smi + ' -L') + if rc != 0: + return None + # Anonymize GPUs by removing their UUID + return re.sub(uuid_regex, '', out) + + +def get_running_cuda_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'nvcc --version', r'release .+ V(.*)') + + +def get_cudnn_version(run_lambda): + """This will return a list of libcudnn.so; it's hard to tell which one is being used""" + if get_platform() == 'win32': + cudnn_cmd = 'where /R "%CUDA_PATH%\\bin" cudnn*.dll' + elif get_platform() == 'darwin': + # CUDA libraries and drivers can be found in /usr/local/cuda/. See + # https://docs.nvidia.com/cuda/cuda-installation-guide-mac-os-x/index.html#install + # https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#installmac + # Use CUDNN_LIBRARY when cudnn library is installed elsewhere. 
+ cudnn_cmd = 'ls /usr/local/cuda/lib/libcudnn*' + else: + cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev' + rc, out, _ = run_lambda(cudnn_cmd) + # find will return 1 if there are permission errors or if not found + if len(out) == 0 or rc not in (1, 0): + lib = os.environ.get('CUDNN_LIBRARY') + if lib is not None and os.path.isfile(lib): + return os.path.realpath(lib) + return None + files = set() + for fn in out.split('\n'): + fn = os.path.realpath(fn) # eliminate symbolic links + if os.path.isfile(fn): + files.add(fn) + if not files: + return None + # Alphabetize the result because the order is non-deterministic otherwise + files = list(sorted(files)) + if len(files) == 1: + return files[0] + result = '\n'.join(files) + return 'Probably one of the following:\n{}'.format(result) + + +def get_nvidia_smi(): + # Note: nvidia-smi is currently available only on Windows and Linux + smi = 'nvidia-smi' + if get_platform() == 'win32': + smi = '"C:\\Program Files\\NVIDIA Corporation\\NVSMI\\%s"' % smi + return smi + + +def get_platform(): + if sys.platform.startswith('linux'): + return 'linux' + elif sys.platform.startswith('win32'): + return 'win32' + elif sys.platform.startswith('cygwin'): + return 'cygwin' + elif sys.platform.startswith('darwin'): + return 'darwin' + else: + return sys.platform + + +def get_mac_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion', r'(.*)') + + +def get_windows_version(run_lambda): + return run_and_read_all(run_lambda, 'wmic os get Caption | findstr /v Caption') + + +def get_lsb_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'lsb_release -a', r'Description:\t(.*)') + + +def check_release_file(run_lambda): + return run_and_parse_first_match(run_lambda, 'cat /etc/*-release', + r'PRETTY_NAME="(.*)"') + + +def get_os(run_lambda): + platform = get_platform() + + if platform in ('win32', 'cygwin'): + return get_windows_version(run_lambda) + + if platform == 'darwin': + version = get_mac_version(run_lambda) + if version is None: + return None + return 'Mac OSX {}'.format(version) + + if platform == 'linux': + # Ubuntu/Debian based + desc = get_lsb_version(run_lambda) + if desc is not None: + return desc + + # Try reading /etc/*-release + desc = check_release_file(run_lambda) + if desc is not None: + return desc + + return platform + + # Unknown platform + return platform + + +def get_env_info(): + run_lambda = run + + if TORCHSCAN_AVAILABLE: + torchscan_str = torchscan.__version__ + else: + torchscan_str = 'N/A' + + if TORCH_AVAILABLE: + torch_str = torch.__version__ + cuda_available_str = torch.cuda.is_available() + else: + torch_str = cuda_available_str = 'N/A' + + return SystemEnv( + torchscan_version=torchscan_str, + torch_version=torch_str, + python_version=".".join(map(str, sys.version_info[:3])), + is_cuda_available=cuda_available_str, + cuda_runtime_version=get_running_cuda_version(run_lambda), + nvidia_gpu_models=get_gpu_info(run_lambda), + nvidia_driver_version=get_nvidia_driver_version(run_lambda), + cudnn_version=get_cudnn_version(run_lambda), + os=get_os(run_lambda), + ) + + +env_info_fmt = """ +TorchScan version: {torchscan_version} +PyTorch version: {torch_version} + +OS: {os} + +Python version: {python_version} +Is CUDA available: {is_cuda_available} +CUDA runtime version: {cuda_runtime_version} +GPU models and configuration: {nvidia_gpu_models} +Nvidia driver version: {nvidia_driver_version} +cuDNN version: {cudnn_version} +""".strip() + + +def pretty_str(envinfo): + def 
replace_nones(dct, replacement='Could not collect'): + for key in dct.keys(): + if dct[key] is not None: + continue + dct[key] = replacement + return dct + + def replace_bools(dct, true='Yes', false='No'): + for key in dct.keys(): + if dct[key] is True: + dct[key] = true + elif dct[key] is False: + dct[key] = false + return dct + + def maybe_start_on_next_line(string): + # If `string` is multiline, prepend a \n to it. + if string is not None and len(string.split('\n')) > 1: + return '\n{}\n'.format(string) + return string + + mutable_dict = envinfo._asdict() + + # If nvidia_gpu_models is multiline, start on the next line + mutable_dict['nvidia_gpu_models'] = \ + maybe_start_on_next_line(envinfo.nvidia_gpu_models) + + # If the machine doesn't have CUDA, report some fields as 'No CUDA' + dynamic_cuda_fields = [ + 'cuda_runtime_version', + 'nvidia_gpu_models', + 'nvidia_driver_version', + ] + all_cuda_fields = dynamic_cuda_fields + ['cudnn_version'] + all_dynamic_cuda_fields_missing = all( + mutable_dict[field] is None for field in dynamic_cuda_fields) + if TORCH_AVAILABLE and not torch.cuda.is_available() and all_dynamic_cuda_fields_missing: + for field in all_cuda_fields: + mutable_dict[field] = 'No CUDA' + + # Replace True with Yes, False with No + mutable_dict = replace_bools(mutable_dict) + + # Replace all None objects with 'Could not collect' + mutable_dict = replace_nones(mutable_dict) + + return env_info_fmt.format(**mutable_dict) + + +def get_pretty_env_info(): + """Collects environment information for debugging purposes + + Returns: + str: environment information + """ + return pretty_str(get_env_info()) + + +def main(): + print("Collecting environment information...") + output = get_pretty_env_info() + print(output) + + +if __name__ == '__main__': + main() diff --git a/setup.py b/setup.py index e7591dd..b44c93d 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -# Copyright (C) 2020-2021, François-Guillaume Fernandez. +# Copyright (C) 2020-2022, François-Guillaume Fernandez. # This program is licensed under the Apache License version 2. # See LICENSE or go to for full license details. @@ -8,78 +8,131 @@ """ import os +import re import subprocess -from setuptools import find_packages, setup +from pathlib import Path +from setuptools import find_packages, setup -version = '0.1.1' +version = '0.1.2.dev0' sha = 'Unknown' package_name = 'torchscan' -cwd = os.path.dirname(os.path.abspath(__file__)) - -try: - sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=cwd).decode('ascii').strip() -except Exception: - pass +cwd = Path(__file__).parent.absolute() if os.getenv('BUILD_VERSION'): version = os.getenv('BUILD_VERSION') -elif sha != 'Unknown': - version += '+' + sha[:7] -print("Building wheel {}-{}".format(package_name, version)) +else: + try: + sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=cwd).decode('ascii').strip() + except Exception: + pass + if sha != 'Unknown': + version += '+' + sha[:7] +print(f"Building wheel {package_name}-{version}") + +with open(cwd.joinpath('torchscan', 'version.py'), 'w') as f: + f.write(f"__version__ = '{version}'\n") + +with open('README.md', 'r') as f: + readme = f.read() +_deps = [ + "torch>=1.5.0", + # Testing + "pytest>=5.3.2", + "coverage>=4.5.4", + # Quality + "flake8>=3.9.0", + "isort>=5.7.0", + "mypy>=0.812", + "pydocstyle>=6.0.0", + # Docs + "sphinx<=3.4.3", + "sphinx-rtd-theme==0.4.3", + "sphinxemoji>=0.1.8", + "sphinx-copybutton>=0.3.1", + "docutils<0.18", + "Jinja2<3.1", # cf. 
https://github.com/readthedocs/readthedocs.org/issues/9038 +] -def write_version_file(): - version_path = os.path.join(cwd, 'torchscan', 'version.py') - with open(version_path, 'w') as f: - f.write("__version__ = '{}'\n".format(version)) +# Borrowed from https://github.com/huggingface/transformers/blob/master/setup.py +deps = {b: a for a, b in (re.findall(r"^(([^!=<>]+)(?:[!=<>].*)?$)", x)[0] for x in _deps)} -write_version_file() +def deps_list(*pkgs): + return [deps[pkg] for pkg in pkgs] -with open('README.md') as f: - readme = f.read() -requirements = [ - 'torch>=1.5.0' +install_requires = [ + deps["torch"], ] +extras = {} + +extras["testing"] = deps_list( + "pytest", + "coverage", +) + +extras["quality"] = deps_list( + "flake8", + "isort", + "mypy", + "pydocstyle", +) + +extras["docs"] = deps_list( + "sphinx", + "sphinx-rtd-theme", + "sphinxemoji", + "sphinx-copybutton", + "docutils", + "Jinja2", +) + +extras["dev"] = ( + extras["testing"] + + extras["quality"] + + extras["docs"] +) + + setup( # Metadata name=package_name, version=version, author='François-Guillaume Fernandez', + author_email='fg-feedback@protonmail.com', description='Useful information about your Pytorch module', long_description=readme, long_description_content_type="text/markdown", url='https://github.com/frgfm/torch-scan', download_url='https://github.com/frgfm/torch-scan/tags', - license='MIT', + license='Apache', classifiers=[ 'Development Status :: 4 - Beta', 'Intended Audience :: Developers', 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: MIT License', + 'License :: OSI Approved :: Apache Software License', 'Natural Language :: English', 'Operating System :: OS Independent', 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', 'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering :: Mathematics', 'Topic :: Scientific/Engineering :: Artificial Intelligence', - 'Topic :: Software Development', - 'Topic :: Software Development :: Libraries', - 'Topic :: Software Development :: Libraries :: Python Modules', ], keywords=['pytorch', 'deep learning', 'summary', 'memory', 'ram'], # Package info - packages=find_packages(exclude=('test',)), + packages=find_packages(exclude=('tests',)), zip_safe=True, python_requires='>=3.6.0', include_package_data=True, - install_requires=requirements, + install_requires=install_requires, + extras_require=extras, package_data={'': ['LICENSE']} ) diff --git a/test/test_crawler.py b/test/test_crawler.py deleted file mode 100644 index 6a9b665..0000000 --- a/test/test_crawler.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright (C) 2020-2021, François-Guillaume Fernandez. - -# This program is licensed under the Apache License version 2. -# See LICENSE or go to for full license details. 
- -import io -import sys -import unittest - -import torch.nn as nn - -from torchscan import crawler - - -class UtilsTester(unittest.TestCase): - def test_apply(self): - multi_convs = nn.Sequential(nn.Conv2d(16, 32, 3), nn.Conv2d(32, 64, 3)) - mod = nn.Sequential(nn.Conv2d(3, 16, 3), multi_convs) - - # Tag module attributes - def tag_name(mod, name): - mod.__depth__ = len(name.split('.')) - 1 - mod.__name__ = name.rpartition('.')[-1] - - crawler.apply(mod, tag_name) - - self.assertEqual(mod[1][1].__depth__, 2) - self.assertEqual(mod[1][1].__name__, '1') - - def test_crawl_module(self): - - mod = nn.Conv2d(3, 8, 3) - - res = crawler.crawl_module(mod, (3, 32, 32)) - self.assertIsInstance(res, dict) - self.assertEqual(res['overall']['grad_params'], 224) - self.assertEqual(res['layers'][0]['output_shape'], (-1, 8, 30, 30)) - - def test_summary(self): - - mod = nn.Conv2d(3, 8, 3) - - # Redirect stdout with StringIO object - captured_output = io.StringIO() - sys.stdout = captured_output - crawler.summary(mod, (3, 32, 32)) - # Reset redirect. - sys.stdout = sys.__stdout__ - self.assertEqual(captured_output.getvalue().split('\n')[7], 'Total params: 224') - - # Check receptive field - captured_output = io.StringIO() - sys.stdout = captured_output - crawler.summary(mod, (3, 32, 32), receptive_field=True) - # Reset redirect. - sys.stdout = sys.__stdout__ - self.assertEqual(captured_output.getvalue().split('\n')[1].rpartition(' ')[-1], 'Receptive field') - self.assertEqual(captured_output.getvalue().split('\n')[3].split()[-1], '3') - # Check effective stats - captured_output = io.StringIO() - sys.stdout = captured_output - crawler.summary(mod, (3, 32, 32), receptive_field=True, effective_rf_stats=True) - # Reset redirect. - sys.stdout = sys.__stdout__ - self.assertEqual(captured_output.getvalue().split('\n')[1].rpartition(' ')[-1], 'Effective padding') - self.assertEqual(captured_output.getvalue().split('\n')[3].split()[-1], '0') - - -if __name__ == '__main__': - unittest.main() diff --git a/test/test_modules.py b/test/test_modules.py deleted file mode 100644 index b2a3cc9..0000000 --- a/test/test_modules.py +++ /dev/null @@ -1,229 +0,0 @@ -# Copyright (C) 2020-2021, François-Guillaume Fernandez. - -# This program is licensed under the Apache License version 2. -# See LICENSE or go to for full license details. 
- -import unittest - -import torch -from torch import nn - -from torchscan import modules - - -class MyModule(nn.Module): - def __init__(self): - super().__init__() - - -class Tester(unittest.TestCase): - @torch.no_grad() - def test_module_flops(self): - - # Check for unknown module that it returns 0 and throws a warning - self.assertEqual(modules.module_flops(MyModule(), None, None), 0) - self.assertWarns(UserWarning, modules.module_flops, MyModule(), None, None) - - # Common unit tests - self.assertEqual(modules.module_flops(nn.Linear(8, 4), (torch.zeros((1, 8)),), torch.zeros((1, 4))), - 4 * (2 * 8 - 1) + 4) - self.assertEqual(modules.module_flops(nn.Linear(8, 4, bias=False), (torch.zeros((1, 8)),), torch.zeros((1, 4))), - 4 * (2 * 8 - 1)) - self.assertEqual(modules.module_flops(nn.Linear(8, 4), (torch.zeros((1, 2, 8)),), torch.zeros((1, 2, 4))), - 2 * (4 * (2 * 8 - 1) + 4)) - # Activations - self.assertEqual(modules.module_flops(nn.Identity(), (torch.zeros((1, 8)),), torch.zeros((1, 8))), 0) - self.assertEqual(modules.module_flops(nn.Flatten(), (torch.zeros((1, 8)),), torch.zeros((1, 8))), 0) - self.assertEqual(modules.module_flops(nn.ReLU(), (torch.zeros((1, 8)),), torch.zeros((1, 8))), 8) - self.assertEqual(modules.module_flops(nn.ELU(), (torch.zeros((1, 8)),), torch.zeros((1, 8))), 48) - self.assertEqual(modules.module_flops(nn.LeakyReLU(), (torch.zeros((1, 8)),), torch.zeros((1, 8))), 32) - self.assertEqual(modules.module_flops(nn.ReLU6(), (torch.zeros((1, 8)),), torch.zeros((1, 8))), 16) - self.assertEqual(modules.module_flops(nn.Tanh(), (torch.zeros((1, 8)),), torch.zeros((1, 8))), 48) - self.assertEqual(modules.module_flops(nn.Sigmoid(), (torch.zeros((1, 8)),), torch.zeros((1, 8))), 32) - - # BN - self.assertEqual(modules.module_flops(nn.BatchNorm1d(8), (torch.zeros((1, 8, 4)),), torch.zeros((1, 8, 4))), - 144 + 32 + 32 * 3 + 48) - - # Pooling - self.assertEqual(modules.module_flops(nn.MaxPool2d((2, 2)), - (torch.zeros((1, 8, 4, 4)),), torch.zeros((1, 8, 2, 2))), - 3 * 32) - self.assertEqual(modules.module_flops(nn.AvgPool2d((2, 2)), - (torch.zeros((1, 8, 4, 4)),), torch.zeros((1, 8, 2, 2))), - 5 * 32) - self.assertEqual(modules.module_flops(nn.AdaptiveMaxPool2d((2, 2)), - (torch.zeros((1, 8, 4, 4)),), torch.zeros((1, 8, 2, 2))), - 3 * 32) - # Check that single integer output size is supported - self.assertEqual(modules.module_flops(nn.AdaptiveMaxPool2d(2), - (torch.zeros((1, 8, 4, 4)),), torch.zeros((1, 8, 2, 2))), - 3 * 32) - self.assertEqual(modules.module_flops(nn.AdaptiveAvgPool2d((2, 2)), - (torch.zeros((1, 8, 4, 4)),), torch.zeros((1, 8, 2, 2))), - 5 * 32) - # Check that single integer output size is supported - self.assertEqual(modules.module_flops(nn.AdaptiveAvgPool2d(2), - (torch.zeros((1, 8, 4, 4)),), torch.zeros((1, 8, 2, 2))), - 5 * 32) - - # Dropout - self.assertEqual(modules.module_flops(nn.Dropout(), (torch.zeros((1, 8)),), torch.zeros((1, 8))), 8) - self.assertEqual(modules.module_flops(nn.Dropout(p=0), (torch.zeros((1, 8)),), torch.zeros((1, 8))), 0) - - # Conv - input_t = torch.rand((1, 3, 32, 32)) - mod = nn.Conv2d(3, 8, 3) - self.assertEqual(modules.module_flops(mod, (input_t,), mod(input_t)), 388800) - # ConvTranspose - mod = nn.ConvTranspose2d(3, 8, 3) - self.assertEqual(modules.module_flops(mod, (input_t,), mod(input_t)), 499408) - # Transformer - mod = nn.Transformer(d_model=64, nhead=4, num_encoder_layers=3) - src = torch.rand((10, 16, 64)) - tgt = torch.rand((20, 16, 64)) - self.assertEqual(modules.module_flops(mod, (src, tgt), mod(src, tgt)), 
774952841) - - @torch.no_grad() - def test_module_macs(self): - - # Check for unknown module that it returns 0 and throws a warning - self.assertEqual(modules.module_macs(MyModule(), None, None), 0) - self.assertWarns(UserWarning, modules.module_macs, MyModule(), None, None) - - # Linear - self.assertEqual(modules.module_macs(nn.Linear(8, 4), torch.zeros((1, 8)), torch.zeros((1, 4))), - 8 * 4) - self.assertEqual(modules.module_macs(nn.Linear(8, 4), torch.zeros((1, 2, 8)), torch.zeros((1, 2, 4))), - 8 * 4 * 2) - # Activations - self.assertEqual(modules.module_macs(nn.ReLU(), None, None), 0) - # Conv - input_t = torch.rand((1, 3, 32, 32)) - mod = nn.Conv2d(3, 8, 3) - self.assertEqual(modules.module_macs(mod, input_t, mod(input_t)), 194400) - # ConvTranspose - mod = nn.ConvTranspose2d(3, 8, 3) - self.assertEqual(modules.module_macs(mod, input_t, mod(input_t)), 249704) - # BN - self.assertEqual(modules.module_macs(nn.BatchNorm1d(8), torch.zeros((1, 8, 4)), torch.zeros((1, 8, 4))), - 64 + 24 + 56 + 32) - - # Pooling - self.assertEqual(modules.module_macs(nn.MaxPool2d((2, 2)), - torch.zeros((1, 8, 4, 4)), torch.zeros((1, 8, 2, 2))), - 3 * 32) - self.assertEqual(modules.module_macs(nn.AvgPool2d((2, 2)), - torch.zeros((1, 8, 4, 4)), torch.zeros((1, 8, 2, 2))), - 5 * 32) - self.assertEqual(modules.module_macs(nn.AdaptiveMaxPool2d((2, 2)), - torch.zeros((1, 8, 4, 4)), torch.zeros((1, 8, 2, 2))), - 3 * 32) - self.assertEqual(modules.module_macs(nn.AdaptiveAvgPool2d((2, 2)), - torch.zeros((1, 8, 4, 4)), torch.zeros((1, 8, 2, 2))), - 5 * 32) - # Test support integer output-size support - self.assertEqual(modules.module_macs(nn.AdaptiveMaxPool2d(2), - torch.zeros((1, 8, 4, 4)), torch.zeros((1, 8, 2, 2))), - 3 * 32) - self.assertEqual(modules.module_macs(nn.AdaptiveAvgPool2d(2), - torch.zeros((1, 8, 4, 4)), torch.zeros((1, 8, 2, 2))), - 5 * 32) - - # Dropout - self.assertEqual(modules.module_macs(nn.Dropout(), torch.zeros((1, 8)), torch.zeros((1, 8))), 0) - - @torch.no_grad() - def test_module_dmas(self): - - # Check for unknown module that it returns 0 and throws a warning - self.assertEqual(modules.module_dmas(MyModule(), None, None), 0) - self.assertWarns(UserWarning, modules.module_dmas, MyModule(), None, None) - - # Common unit tests - # Linear - self.assertEqual(modules.module_dmas(nn.Linear(8, 4), torch.zeros((1, 8)), torch.zeros((1, 4))), - 4 * (8 + 1) + 8 + 4) - # Activation - self.assertEqual(modules.module_dmas(nn.Identity(), torch.zeros((1, 8)), torch.zeros((1, 8))), 8) - self.assertEqual(modules.module_dmas(nn.Flatten(), torch.zeros((1, 8)), torch.zeros((1, 8))), 16) - self.assertEqual(modules.module_dmas(nn.ReLU(), torch.zeros((1, 8)), torch.zeros((1, 8))), 8 * 2) - self.assertEqual(modules.module_dmas(nn.ReLU(inplace=True), torch.zeros((1, 8)), None), 8) - self.assertEqual(modules.module_dmas(nn.ELU(), torch.zeros((1, 8)), torch.zeros((1, 8))), 17) - self.assertEqual(modules.module_dmas(nn.Sigmoid(), torch.zeros((1, 8)), torch.zeros((1, 8))), 16) - self.assertEqual(modules.module_dmas(nn.Tanh(), torch.zeros((1, 8)), torch.zeros((1, 8))), 24) - # Conv - input_t = torch.rand((1, 3, 32, 32)) - mod = nn.Conv2d(3, 8, 3) - self.assertEqual(modules.module_dmas(mod, input_t, mod(input_t)), 201824) - # ConvTranspose - mod = nn.ConvTranspose2d(3, 8, 3) - self.assertEqual(modules.module_dmas(mod, input_t, mod(input_t)), 259178) - # BN - self.assertEqual(modules.module_dmas(nn.BatchNorm1d(8), torch.zeros((1, 8, 4)), torch.zeros((1, 8, 4))), - 32 + 17 + 1 + 16 + 17 + 32) - - # Pooling - 
self.assertEqual(modules.module_dmas(nn.MaxPool2d((2, 2)), - torch.zeros((1, 8, 4, 4)), torch.zeros((1, 8, 2, 2))), - 4 * 32 + 32) - self.assertEqual(modules.module_dmas(nn.AdaptiveMaxPool2d((2, 2)), - torch.zeros((1, 8, 4, 4)), torch.zeros((1, 8, 2, 2))), - 4 * 32 + 32) - # Integer output size support - self.assertEqual(modules.module_dmas(nn.MaxPool2d(2), - torch.zeros((1, 8, 4, 4)), torch.zeros((1, 8, 2, 2))), - 4 * 32 + 32) - self.assertEqual(modules.module_dmas(nn.AdaptiveMaxPool2d(2), - torch.zeros((1, 8, 4, 4)), torch.zeros((1, 8, 2, 2))), - 4 * 32 + 32) - - # Dropout - self.assertEqual(modules.module_dmas(nn.Dropout(), torch.zeros((1, 8)), torch.zeros((1, 8))), 17) - - @torch.no_grad() - def test_module_rf(self): - - # Check for unknown module that it returns 0 and throws a warning - self.assertEqual(modules.module_rf(MyModule(), None, None), (1, 1, 0)) - self.assertWarns(UserWarning, modules.module_rf, MyModule(), None, None) - - # Common unit tests - # Linear - self.assertEqual(modules.module_rf(nn.Linear(8, 4), torch.zeros((1, 8)), torch.zeros((1, 4))), - (1, 1, 0)) - # Activation - self.assertEqual(modules.module_rf(nn.Identity(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0)) - self.assertEqual(modules.module_rf(nn.Flatten(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0)) - self.assertEqual(modules.module_rf(nn.ReLU(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0)) - self.assertEqual(modules.module_rf(nn.ELU(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0)) - self.assertEqual(modules.module_rf(nn.Sigmoid(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0)) - self.assertEqual(modules.module_rf(nn.Tanh(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0)) - # Conv - input_t = torch.rand((1, 3, 32, 32)) - mod = nn.Conv2d(3, 8, 3) - self.assertEqual(modules.module_rf(mod, input_t, mod(input_t)), (3, 1, 0)) - # Check for dilation support - mod = nn.Conv2d(3, 8, 3, dilation=2) - self.assertEqual(modules.module_rf(mod, input_t, mod(input_t)), (5, 1, 0)) - # ConvTranspose - mod = nn.ConvTranspose2d(3, 8, 3) - self.assertEqual(modules.module_rf(mod, input_t, mod(input_t)), (-3, 1, 0)) - # BN - self.assertEqual(modules.module_rf(nn.BatchNorm1d(8), torch.zeros((1, 8, 4)), torch.zeros((1, 8, 4))), - (1, 1, 0)) - - # Pooling - self.assertEqual(modules.module_rf(nn.MaxPool2d((2, 2)), - torch.zeros((1, 8, 4, 4)), torch.zeros((1, 8, 2, 2))), - (2, 2, 0)) - self.assertEqual(modules.module_rf(nn.AdaptiveMaxPool2d((2, 2)), - torch.zeros((1, 8, 4, 4)), torch.zeros((1, 8, 2, 2))), - (2, 2, 0)) - - # Dropout - self.assertEqual(modules.module_rf(nn.Dropout(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0)) - - -if __name__ == '__main__': - unittest.main() diff --git a/test/test_process.py b/test/test_process.py deleted file mode 100644 index 12e9cb5..0000000 --- a/test/test_process.py +++ /dev/null @@ -1,24 +0,0 @@ -# Copyright (C) 2020-2021, François-Guillaume Fernandez. - -# This program is licensed under the Apache License version 2. -# See LICENSE or go to for full license details. 
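For context on the constants asserted for `nn.Conv2d(3, 8, 3)` throughout these tests (224 parameters, 194400 MACs, 388800 FLOPs), the values follow from the layer geometry and the (1, 3, 32, 32) input the tests feed it, under the same convention the Linear cases use (multiplies and adds counted separately, plus one bias addition per output value). A back-of-the-envelope check — a sketch, not library code:

```python
import torch.nn as nn

conv = nn.Conv2d(3, 8, 3)
# Trainable parameters: 8 filters * (3 channels * 3 * 3 weights) + 8 biases = 224
n_params = sum(p.numel() for p in conv.parameters())

# With a 32x32 input, no padding and stride 1, the output is 8 x 30 x 30
out_elems = 8 * 30 * 30                      # 7200 output values
macs = out_elems * (3 * 3 * 3)               # one multiply-accumulate per kernel weight -> 194400
flops = out_elems * (2 * 3 * 3 * 3 - 1 + 1)  # 27 mults + 26 adds + 1 bias add -> 388800

print(n_params, macs, flops)  # 224 194400 388800
```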
- -import os -import unittest - -import torch - -from torchscan import process - - -class Tester(unittest.TestCase): - def test_get_process_gpu_ram(self): - - if torch.cuda.is_initialized: - self.assertGreaterEqual(process.get_process_gpu_ram(os.getpid()), 0) - else: - self.assertEqual(process.get_process_gpu_ram(os.getpid()), 0) - - -if __name__ == '__main__': - unittest.main() diff --git a/test/test_utils.py b/test/test_utils.py deleted file mode 100644 index 1189283..0000000 --- a/test/test_utils.py +++ /dev/null @@ -1,41 +0,0 @@ -# Copyright (C) 2020-2021, François-Guillaume Fernandez. - -# This program is licensed under the Apache License version 2. -# See LICENSE or go to for full license details. - -import unittest - -from torchscan import utils - - -class UtilsTester(unittest.TestCase): - def test_format_name(self): - name = 'mymodule' - self.assertEqual(utils.format_name(name), name) - self.assertEqual(utils.format_name(name, depth=1), f"├─{name}") - self.assertEqual(utils.format_name(name, depth=3), f"| | └─{name}") - - def test_wrap_string(self): - - example = '.'.join(['a' for _ in range(10)]) - max_len = 10 - wrap = '[...]' - - self.assertEqual(utils.wrap_string(example, max_len, mode='end'), - example[:max_len - len(wrap)] + wrap) - self.assertEqual(utils.wrap_string(example, max_len, mode='mid'), - f"{example[:max_len - 2 - len(wrap)]}{wrap}.a") - self.assertEqual(utils.wrap_string(example, len(example), mode='end'), example) - self.assertRaises(ValueError, utils.wrap_string, example, max_len, mode='test') - - def test_unit_scale(self): - - self.assertEqual(utils.unit_scale(3e14), (300, 'T')) - self.assertEqual(utils.unit_scale(3e10), (30, 'G')) - self.assertEqual(utils.unit_scale(3e7), (30, 'M')) - self.assertEqual(utils.unit_scale(15e3), (15, 'k')) - self.assertEqual(utils.unit_scale(500), (500, '')) - - -if __name__ == '__main__': - unittest.main() diff --git a/test/requirements.txt b/tests/requirements.txt similarity index 53% rename from test/requirements.txt rename to tests/requirements.txt index 7dd0fc7..f067cbf 100644 --- a/test/requirements.txt +++ b/tests/requirements.txt @@ -1 +1,2 @@ +pytest>=5.3.2 coverage>=4.5.4 diff --git a/tests/test_crawler.py b/tests/test_crawler.py new file mode 100644 index 0000000..8194f5e --- /dev/null +++ b/tests/test_crawler.py @@ -0,0 +1,61 @@ +import io +import sys + +import torch.nn as nn + +from torchscan import crawler + + +def test_apply(): + multi_convs = nn.Sequential(nn.Conv2d(16, 32, 3), nn.Conv2d(32, 64, 3)) + mod = nn.Sequential(nn.Conv2d(3, 16, 3), multi_convs) + + # Tag module attributes + def tag_name(mod, name): + mod.__depth__ = len(name.split('.')) - 1 + mod.__name__ = name.rpartition('.')[-1] + + crawler.apply(mod, tag_name) + + assert mod[1][1].__depth__ == 2 + assert mod[1][1].__name__ == '1' + + +def test_crawl_module(): + + mod = nn.Conv2d(3, 8, 3) + + res = crawler.crawl_module(mod, (3, 32, 32)) + assert isinstance(res, dict) + assert res['overall']['grad_params'] == 224 + assert res['layers'][0]['output_shape'] == (-1, 8, 30, 30) + + +def test_summary(): + + mod = nn.Conv2d(3, 8, 3) + + # Redirect stdout with StringIO object + captured_output = io.StringIO() + sys.stdout = captured_output + crawler.summary(mod, (3, 32, 32)) + # Reset redirect. 
+ sys.stdout = sys.__stdout__ + assert captured_output.getvalue().split('\n')[7] == 'Total params: 224' + + # Check receptive field + captured_output = io.StringIO() + sys.stdout = captured_output + crawler.summary(mod, (3, 32, 32), receptive_field=True) + # Reset redirect. + sys.stdout = sys.__stdout__ + assert captured_output.getvalue().split('\n')[1].rpartition(' ')[-1] == 'Receptive field' + assert captured_output.getvalue().split('\n')[3].split()[-1] == '3' + # Check effective stats + captured_output = io.StringIO() + sys.stdout = captured_output + crawler.summary(mod, (3, 32, 32), receptive_field=True, effective_rf_stats=True) + # Reset redirect. + sys.stdout = sys.__stdout__ + assert captured_output.getvalue().split('\n')[1].rpartition(' ')[-1] == 'Effective padding' + assert captured_output.getvalue().split('\n')[3].split()[-1] == '0' diff --git a/tests/test_modules.py b/tests/test_modules.py new file mode 100644 index 0000000..40fccb5 --- /dev/null +++ b/tests/test_modules.py @@ -0,0 +1,181 @@ +import pytest +import torch +from torch import nn + +from torchscan import modules + + +class MyModule(nn.Module): + def __init__(self): + super().__init__() + + +def test_module_flops_warning(): + with pytest.warns(UserWarning): + modules.module_flops(MyModule(), None, None) + + +@pytest.mark.parametrize( + "mod, input_shape, output_shape, expected_val", + [ + # Check for unknown module that it returns 0 and throws a warning + [MyModule(), (1,), (1,), 0], + # Fully-connected + [nn.Linear(8, 4), (1, 8), (1, 4), 4 * (2 * 8 - 1) + 4], + [nn.Linear(8, 4, bias=False), (1, 8), (1, 4), 4 * (2 * 8 - 1)], + [nn.Linear(8, 4), (1, 2, 8), (1, 2, 4), 2 * (4 * (2 * 8 - 1) + 4)], + # Activations + [nn.Identity(), (1, 8), (1, 8), 0], + [nn.Flatten(), (1, 8), (1, 8), 0], + [nn.ReLU(), (1, 8), (1, 8), 8], + [nn.ELU(), (1, 8), (1, 8), 48], + [nn.LeakyReLU(), (1, 8), (1, 8), 32], + [nn.ReLU6(), (1, 8), (1, 8), 16], + [nn.Tanh(), (1, 8), (1, 8), 48], + [nn.Sigmoid(), (1, 8), (1, 8), 32], + # BN + [nn.BatchNorm1d(8), (1, 8, 4), (1, 8, 4), 144 + 32 + 32 * 3 + 48], + # Pooling + [nn.MaxPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 3 * 32], + [nn.AvgPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 5 * 32], + [nn.AdaptiveMaxPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 3 * 32], + [nn.AdaptiveMaxPool2d(2), (1, 8, 4, 4), (1, 8, 2, 2), 3 * 32], + [nn.AdaptiveAvgPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 5 * 32], + [nn.AdaptiveAvgPool2d(2), (1, 8, 4, 4), (1, 8, 2, 2), 5 * 32], + # Dropout + [nn.Dropout(), (1, 8), (1, 8), 8], + [nn.Dropout(p=0), (1, 8), (1, 8), 0], + # Conv + [nn.Conv2d(3, 8, 3), (1, 3, 32, 32), (1, 8, 30, 30), 388800], + [nn.ConvTranspose2d(3, 8, 3), (1, 3, 32, 32), (1, 8, 34, 34), 499408], + ], +) +def test_module_flops(mod, input_shape, output_shape, expected_val): + assert modules.module_flops(mod, (torch.zeros(input_shape),), torch.zeros(output_shape)) == expected_val + + +def test_transformer_flops(): + mod = nn.Transformer(d_model=64, nhead=4, num_encoder_layers=3) + src = torch.rand((10, 16, 64)) + tgt = torch.rand((20, 16, 64)) + assert modules.module_flops(mod, (src, tgt), mod(src, tgt)) == 774952841 + + +def test_module_macs_warning(): + with pytest.warns(UserWarning): + modules.module_macs(MyModule(), None, None) + + +@pytest.mark.parametrize( + "mod, input_shape, output_shape, expected_val", + [ + # Check for unknown module that it returns 0 and throws a warning + [MyModule(), (1,), (1,), 0], + # Fully-connected + [nn.Linear(8, 4), (1, 8), (1, 4), 8 * 4], + [nn.Linear(8, 4), (1, 2, 8), (1, 2, 4), 8 
* 4 * 2], + # Activations + [nn.ReLU(), (1, 8), (1, 8), 0], + # BN + [nn.BatchNorm1d(8), (1, 8, 4), (1, 8, 4), 64 + 24 + 56 + 32], + # Pooling + [nn.MaxPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 3 * 32], + [nn.AvgPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 5 * 32], + [nn.AdaptiveMaxPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 3 * 32], + [nn.AdaptiveMaxPool2d(2), (1, 8, 4, 4), (1, 8, 2, 2), 3 * 32], + [nn.AdaptiveAvgPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 5 * 32], + [nn.AdaptiveAvgPool2d(2), (1, 8, 4, 4), (1, 8, 2, 2), 5 * 32], + # Dropout + [nn.Dropout(), (1, 8), (1, 8), 0], + # Conv + [nn.Conv2d(3, 8, 3), (1, 3, 32, 32), (1, 8, 30, 30), 194400], + [nn.ConvTranspose2d(3, 8, 3), (1, 3, 32, 32), (1, 8, 34, 34), 249704], + ], +) +def test_module_macs(mod, input_shape, output_shape, expected_val): + + assert modules.module_macs(mod, torch.zeros(input_shape), torch.zeros(output_shape)) == expected_val + + +def test_module_dmas_warning(): + with pytest.warns(UserWarning): + modules.module_dmas(MyModule(), None, None) + + +@pytest.mark.parametrize( + "mod, input_shape, output_shape, expected_val", + [ + # Check for unknown module that it returns 0 and throws a warning + [MyModule(), (1,), (1,), 0], + # Fully-connected + [nn.Linear(8, 4), (1, 8), (1, 4), 4 * (8 + 1) + 8 + 4], + [nn.Linear(8, 4), (1, 2, 8), (1, 2, 4), 4 * (8 + 1) + 2 * (8 + 4)], + # Activations + [nn.Identity(), (1, 8), (1, 8), 8], + [nn.Flatten(), (1, 8), (1, 8), 16], + [nn.ReLU(), (1, 8), (1, 8), 8 * 2], + [nn.ReLU(inplace=True), (1, 8), (1, 8), 8], + [nn.ELU(), (1, 8), (1, 8), 17], + [nn.Tanh(), (1, 8), (1, 8), 24], + [nn.Sigmoid(), (1, 8), (1, 8), 16], + # BN + [nn.BatchNorm1d(8), (1, 8, 4), (1, 8, 4), 32 + 17 + 16 + 1 + 17 + 32], + # Pooling + [nn.MaxPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 4 * 32 + 32], + [nn.MaxPool2d(2), (1, 8, 4, 4), (1, 8, 2, 2), 4 * 32 + 32], + [nn.AdaptiveMaxPool2d((2, 2)), (1, 8, 4, 4), (1, 8, 2, 2), 4 * 32 + 32], + [nn.AdaptiveMaxPool2d(2), (1, 8, 4, 4), (1, 8, 2, 2), 4 * 32 + 32], + # Dropout + [nn.Dropout(), (1, 8), (1, 8), 17], + # Conv + [nn.Conv2d(3, 8, 3), (1, 3, 32, 32), (1, 8, 30, 30), 201824], + [nn.ConvTranspose2d(3, 8, 3), (1, 3, 32, 32), (1, 8, 34, 34), 259178], + ], +) +def test_module_dmas(mod, input_shape, output_shape, expected_val): + + assert modules.module_dmas(mod, torch.zeros(input_shape), torch.zeros(output_shape)) == expected_val + + +# @torch.no_grad() +# def test_module_rf(self): + +# # Check for unknown module that it returns 0 and throws a warning +# self.assertEqual(modules.module_rf(MyModule(), None, None), (1, 1, 0)) +# self.assertWarns(UserWarning, modules.module_rf, MyModule(), None, None) + +# # Common unit tests +# # Linear +# self.assertEqual(modules.module_rf(nn.Linear(8, 4), torch.zeros((1, 8)), torch.zeros((1, 4))), +# (1, 1, 0)) +# # Activation +# self.assertEqual(modules.module_rf(nn.Identity(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0)) +# self.assertEqual(modules.module_rf(nn.Flatten(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0)) +# self.assertEqual(modules.module_rf(nn.ReLU(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0)) +# self.assertEqual(modules.module_rf(nn.ELU(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0)) +# self.assertEqual(modules.module_rf(nn.Sigmoid(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0)) +# self.assertEqual(modules.module_rf(nn.Tanh(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0)) +# # Conv +# input_t = torch.rand((1, 3, 32, 32)) +# mod = nn.Conv2d(3, 8, 3) +# 
self.assertEqual(modules.module_rf(mod, input_t, mod(input_t)), (3, 1, 0))
+# # Check for dilation support
+# mod = nn.Conv2d(3, 8, 3, dilation=2)
+# self.assertEqual(modules.module_rf(mod, input_t, mod(input_t)), (5, 1, 0))
+# # ConvTranspose
+# mod = nn.ConvTranspose2d(3, 8, 3)
+# self.assertEqual(modules.module_rf(mod, input_t, mod(input_t)), (-3, 1, 0))
+# # BN
+# self.assertEqual(modules.module_rf(nn.BatchNorm1d(8), torch.zeros((1, 8, 4)), torch.zeros((1, 8, 4))),
+# (1, 1, 0))
+
+# # Pooling
+# self.assertEqual(modules.module_rf(nn.MaxPool2d((2, 2)),
+# torch.zeros((1, 8, 4, 4)), torch.zeros((1, 8, 2, 2))),
+# (2, 2, 0))
+# self.assertEqual(modules.module_rf(nn.AdaptiveMaxPool2d((2, 2)),
+# torch.zeros((1, 8, 4, 4)), torch.zeros((1, 8, 2, 2))),
+# (2, 2, 0))
+
+# # Dropout
+# self.assertEqual(modules.module_rf(nn.Dropout(), torch.zeros((1, 8)), torch.zeros((1, 8))), (1, 1, 0))
diff --git a/tests/test_process.py b/tests/test_process.py
new file mode 100644
index 0000000..b40371a
--- /dev/null
+++ b/tests/test_process.py
@@ -0,0 +1,13 @@
+import os
+
+import torch
+
+from torchscan import process
+
+
+def test_get_process_gpu_ram():
+
+    if torch.cuda.is_initialized():
+        assert process.get_process_gpu_ram(os.getpid()) >= 0
+    else:
+        assert process.get_process_gpu_ram(os.getpid()) == 0
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 0000000..e29a63a
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,37 @@
+import pytest
+
+from torchscan import utils
+
+
+def test_format_name():
+    name = 'mymodule'
+    assert utils.format_name(name) == name
+    assert utils.format_name(name, depth=1) == f"├─{name}"
+    assert utils.format_name(name, depth=3) == f"| | └─{name}"
+
+
+def test_wrap_string():
+
+    example = '.'.join(['a' for _ in range(10)])
+    max_len = 10
+    wrap = '[...]'
+
+    assert utils.wrap_string(example, max_len, mode='end') == example[:max_len - len(wrap)] + wrap
+    assert utils.wrap_string(example, max_len, mode='mid') == f"{example[:max_len - 2 - len(wrap)]}{wrap}.a"
+    assert utils.wrap_string(example, len(example), mode='end') == example
+    with pytest.raises(ValueError):
+        _ = utils.wrap_string(example, max_len, mode='test')
+
+
+@pytest.mark.parametrize(
+    "input_val, num_val, unit",
+    [
+        [3e14, 300, "T"],
+        [3e10, 30, "G"],
+        [3e7, 30, "M"],
+        [15e3, 15, "k"],
+        [500, 500, ""],
+    ],
+)
+def test_unit_scale(input_val, num_val, unit):
+    assert utils.unit_scale(input_val) == (num_val, unit)
diff --git a/torchscan/crawler.py b/torchscan/crawler.py
index ea808e7..876bfe2 100644
--- a/torchscan/crawler.py
+++ b/torchscan/crawler.py
@@ -1,13 +1,13 @@
-# Copyright (C) 2020-2021, François-Guillaume Fernandez.
+# Copyright (C) 2020-2022, François-Guillaume Fernandez.

 # This program is licensed under the Apache License version 2.
 # See LICENSE or go to for full license details.

 import os
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

 import torch
 from torch.nn import Module
-from typing import Callable, Optional, Dict, Any, Tuple, List, Union, Iterable

 from .modules import module_dmas, module_flops, module_macs, module_rf
 from .process import get_process_gpu_ram
diff --git a/torchscan/modules/flops.py b/torchscan/modules/flops.py
index 0433d17..b6782a0 100644
--- a/torchscan/modules/flops.py
+++ b/torchscan/modules/flops.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020-2021, François-Guillaume Fernandez.
+# Copyright (C) 2020-2022, François-Guillaume Fernandez.

 # This program is licensed under the Apache License version 2.
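A worked instance of the `wrap_string` expectations in the new tests/test_utils.py may make them easier to scan: for the 19-character input, both modes are expected to return exactly `max_len` characters. Plain string operations mirroring the assertions — not the library implementation:

```python
# The input is "a.a.a.a.a.a.a.a.a.a" (10 'a' joined by '.', 19 characters)
example = '.'.join('a' for _ in range(10))
max_len, wrap = 10, '[...]'

end_mode = example[:max_len - len(wrap)] + wrap            # "a.a.a" + "[...]" -> "a.a.a[...]"
mid_mode = f"{example[:max_len - 2 - len(wrap)]}{wrap}.a"  # "a.a" + "[...]" + ".a" -> "a.a[...].a"

print(end_mode, len(end_mode))  # a.a.a[...] 10
print(mid_mode, len(mid_mode))  # a.a[...].a 10
```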
# See LICENSE or go to for full license details. @@ -9,13 +9,11 @@ from typing import Tuple import torch -from torch import nn -from torch import Tensor +from torch import Tensor, nn from torch.nn import Module from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.modules.conv import _ConvNd, _ConvTransposeNd -from torch.nn.modules.pooling import _MaxPoolNd, _AvgPoolNd, _AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd - +from torch.nn.modules.pooling import _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd __all__ = ['module_flops'] diff --git a/torchscan/modules/macs.py b/torchscan/modules/macs.py index 7ebb7bc..f1b66f7 100644 --- a/torchscan/modules/macs.py +++ b/torchscan/modules/macs.py @@ -1,4 +1,4 @@ -# Copyright (C) 2020-2021, François-Guillaume Fernandez. +# Copyright (C) 2020-2022, François-Guillaume Fernandez. # This program is licensed under the Apache License version 2. # See LICENSE or go to for full license details. @@ -7,13 +7,11 @@ from functools import reduce from operator import mul -from torch import nn -from torch import Tensor +from torch import Tensor, nn from torch.nn import Module -from torch.nn.modules.conv import _ConvNd, _ConvTransposeNd from torch.nn.modules.batchnorm import _BatchNorm -from torch.nn.modules.pooling import _MaxPoolNd, _AvgPoolNd, _AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd - +from torch.nn.modules.conv import _ConvNd, _ConvTransposeNd +from torch.nn.modules.pooling import _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd __all__ = ['module_macs'] diff --git a/torchscan/modules/memory.py b/torchscan/modules/memory.py index 0f53cc4..81073e9 100644 --- a/torchscan/modules/memory.py +++ b/torchscan/modules/memory.py @@ -1,4 +1,4 @@ -# Copyright (C) 2020-2021, François-Guillaume Fernandez. +# Copyright (C) 2020-2022, François-Guillaume Fernandez. # This program is licensed under the Apache License version 2. # See LICENSE or go to for full license details. @@ -6,22 +6,20 @@ import warnings from functools import reduce from operator import mul +from typing import Union -from torch import nn -from torch import Tensor +from torch import Tensor, nn from torch.nn import Module from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.modules.conv import _ConvNd, _ConvTransposeNd -from torch.nn.modules.pooling import _MaxPoolNd, _AvgPoolNd, _AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd -from typing import Union - +from torch.nn.modules.pooling import _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd __all__ = ['module_dmas'] def module_dmas(module: Module, input: Tensor, output: Tensor) -> int: """Estimate the number of direct memory accesses by the module. - The implementation overhead is neglected + The implementation overhead is neglected. 
Args: module (torch.nn.Module): PyTorch module @@ -184,7 +182,7 @@ def dmas_bn(module: _BatchNorm, input: Tensor, output: Tensor) -> int: input_dma = input.numel() # Access running_mean, running_var and eps - ops_dma = module.running_mean.numel() + module.running_var.numel() + 1 # type: ignore[operator] + ops_dma = module.running_mean.numel() + module.running_var.numel() + 1 # type: ignore[union-attr] # Access to weight and bias if module.affine: ops_dma += module.weight.data.numel() + module.bias.data.numel() @@ -195,7 +193,7 @@ def dmas_bn(module: _BatchNorm, input: Tensor, output: Tensor) -> int: if module.training and module.track_running_stats: # Current mean and std computation only requires access to input, already counted in input_dma # Update num of batches and running stats - ops_dma += 1 + module.running_mean.numel() + module.running_var.numel() # type: ignore[operator] + ops_dma += 1 + module.running_mean.numel() + module.running_var.numel() # type: ignore[union-attr] output_dma = output.numel() diff --git a/torchscan/modules/receptive.py b/torchscan/modules/receptive.py index 3fa5f97..1fe6e1b 100644 --- a/torchscan/modules/receptive.py +++ b/torchscan/modules/receptive.py @@ -1,18 +1,17 @@ -# Copyright (C) 2020-2021, François-Guillaume Fernandez. +# Copyright (C) 2020-2022, François-Guillaume Fernandez. # This program is licensed under the Apache License version 2. # See LICENSE or go to for full license details. import math import warnings -from torch import nn -from torch import Tensor +from typing import Tuple, Union + +from torch import Tensor, nn from torch.nn import Module from torch.nn.modules.batchnorm import _BatchNorm from torch.nn.modules.conv import _ConvNd, _ConvTransposeNd -from torch.nn.modules.pooling import _MaxPoolNd, _AvgPoolNd, _AdaptiveMaxPoolNd, _AdaptiveAvgPoolNd -from typing import Tuple, Union - +from torch.nn.modules.pooling import _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd, _AvgPoolNd, _MaxPoolNd __all__ = ['module_rf'] diff --git a/torchscan/process/memory.py b/torchscan/process/memory.py index 5c9974b..e96b9ca 100644 --- a/torchscan/process/memory.py +++ b/torchscan/process/memory.py @@ -1,4 +1,4 @@ -# Copyright (C) 2020-2021, François-Guillaume Fernandez. +# Copyright (C) 2020-2022, François-Guillaume Fernandez. # This program is licensed under the Apache License version 2. # See LICENSE or go to for full license details. @@ -23,10 +23,10 @@ def get_process_gpu_ram(pid: int) -> float: try: res = subprocess.run(["nvidia-smi", "-q", "-d", "PIDS"], capture_output=True).stdout.decode() # Try to locate the process - pids = re.findall("Process ID\s+:\s([^\D]*)", res) + pids = re.findall(r"Process ID\s+:\s([^\D]*)", res) for idx, _pid in enumerate(pids): if int(_pid) == pid: - return float(re.findall("Used GPU Memory\s+:\s([^\D]*)", res)[idx]) + return float(re.findall(r"Used GPU Memory\s+:\s([^\D]*)", res)[idx]) except Exception as e: warnings.warn(f"raised: {e}. Assuming no GPU is available.") diff --git a/torchscan/utils.py b/torchscan/utils.py index c1bc3ab..93d6f3c 100644 --- a/torchscan/utils.py +++ b/torchscan/utils.py @@ -1,9 +1,9 @@ -# Copyright (C) 2020-2021, François-Guillaume Fernandez. +# Copyright (C) 2020-2022, François-Guillaume Fernandez. # This program is licensed under the Apache License version 2. # See LICENSE or go to for full license details. -from typing import Tuple, Dict, Any, Optional, List +from typing import Any, Dict, List, Optional, Tuple def format_name(name: str, depth: int = 0) -> str:
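The torchscan/process/memory.py change above only adds the `r` prefix to the two patterns: `\s` inside a regular (non-raw) string literal is an invalid escape sequence, which Python has warned about since 3.6 (and reports as a SyntaxWarning from 3.12), while the compiled regex is unchanged. A quick illustration of the same patterns against a synthetic, nvidia-smi-like snippet — the sample text is made up here purely for the example, not real tool output:

```python
import re

# Synthetic snippet shaped like the "PIDS" query section the code parses
res = """
    Process ID                        : 4242
        Used GPU Memory               : 1337 MiB
"""

pids = re.findall(r"Process ID\s+:\s([^\D]*)", res)
mems = re.findall(r"Used GPU Memory\s+:\s([^\D]*)", res)
print(pids, mems)  # ['4242'] ['1337']
```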