diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 62e2a903..06824925 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -15,6 +15,27 @@ jobs: python-version: 3.7 - name: set PY run: echo "::set-env name=PY::$(python -c 'import hashlib, sys;print(hashlib.sha256(sys.version.encode()+sys.executable.encode()).hexdigest())')" + - name: Install Poetry + run: | + pip install --upgrade pip + curl -sSL https://raw.githubusercontent.com/sdispater/poetry/master/get-poetry.py | python + - name: Add Poetry to Path Unix + run: echo "::add-path::$HOME/.poetry/bin" + - name: Configure Poetry + run: | + poetry config virtualenvs.create false + - name: Cache pip + uses: actions/cache@v1 + id: cache-poetry + with: + path: ~/.cache/pip + key: ${{ matrix.os }}|${{ env.PY }}|poetry|pre-commit|${{ hashFiles('poetry.lock') }} + restore-keys: | + ${{ matrix.os }}|${{ env.PY }}|poetry|pre-commit| + - name: Install Project Dependencies (Poetry) + run: | + poetry install -vvv + if: steps.cache-poetry.outputs.cache-hit != 'true' - uses: actions/cache@v1 with: path: ~/.cache/pre-commit @@ -55,7 +76,7 @@ jobs: path: ~/.virtualenvs key: poetry|v2|${{ matrix.os }}|${{ env.PY }}|${{ hashFiles('poetry.lock') }} restore-keys: | - poetry-${{ hashFiles('poetry.lock') }} + poetry|v2|${{ matrix.os }}|${{ env.PY }}| - name: Install Project Dependencies (Poetry) run: | poetry install -vvv diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 889c8bba..8259dc39 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,9 +1,10 @@ default_language_version: - python: python3.6 + python: python3.7 repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v2.3.0 + rev: v2.5.0 hooks: + - id: check-toml - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace @@ -30,8 +31,13 @@ repos: rev: 3.7.9 hooks: - id: flake8 - additional_dependencies: ['flake8-docstrings'] -- repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.770 + additional_dependencies: ['flake8-docstrings==1.5.0', 'darglint==1.2.1'] +- repo: local hooks: - id: mypy + name: mypy + entry: mypy + language: system + types: [python] + args: ['-p=pyrh'] + pass_filenames: false diff --git a/newsfragments/223.feature b/newsfragments/223.feature new file mode 100644 index 00000000..f48a5e9a --- /dev/null +++ b/newsfragments/223.feature @@ -0,0 +1 @@ +Add marshmallow support for internal models. diff --git a/notebooks/example.ipynb b/notebooks/example.ipynb index 6cebcc19..c94857dc 100644 --- a/notebooks/example.ipynb +++ b/notebooks/example.ipynb @@ -25,7 +25,7 @@ "%load_ext autoreload\n", "%autoreload 2\n", "\n", - "from pyrh import Robinhood" + "from pyrh import Robinhood, dump_session, load_session" ] }, { @@ -42,8 +42,8 @@ "# Log in to app (will prompt for two-factor)\n", "rh = Robinhood(username=\"USERNAME\", password=\"PASSWORD\")\n", "rh.login()\n", - "rh.to_json() # so you don't have to do mfa again\n", - "rh.from_json() # to load the above json cache data" + "dump_session(rh) # so you don't have to do mfa again\n", + "rh = load_session(rh) # to load the above json cache data\n" ] }, { diff --git a/poetry.lock b/poetry.lock index ce5623a2..158cc1a1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -146,6 +146,15 @@ version = "5.0.4" [package.extras] toml = ["toml"] +[[package]] +category = "dev" +description = "A utility for ensuring Google-style docstrings stay up to date with the source code." +marker = "python_version >= \"3.7\" and python_version < \"4.0\"" +name = "darglint" +optional = false +python-versions = ">=3.7,<4.0" +version = "1.2.1" + [[package]] category = "dev" description = "Decorators for Humans" @@ -196,6 +205,18 @@ version = "1.5.0" flake8 = ">=3" pydocstyle = ">=2.1" +[[package]] +category = "dev" +description = "Let your Python tests travel through time" +name = "freezegun" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +version = "0.3.15" + +[package.dependencies] +python-dateutil = ">=1.0,<2.0 || >2.0" +six = "*" + [[package]] category = "main" description = "Internationalized Domain Names in Applications (IDNA)" @@ -322,6 +343,20 @@ optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*" version = "1.1.1" +[[package]] +category = "main" +description = "A lightweight library for converting complex datatypes to and from native Python datatypes." +name = "marshmallow" +optional = false +python-versions = ">=3.5" +version = "3.5.1" + +[package.extras] +dev = ["pytest", "pytz", "simplejson", "mypy (0.761)", "flake8 (3.7.9)", "flake8-bugbear (20.1.4)", "pre-commit (>=1.20,<3.0)", "tox"] +docs = ["sphinx (2.4.3)", "sphinx-issues (1.2.0)", "alabaster (0.7.12)", "sphinx-version-warning (1.1.2)"] +lint = ["mypy (0.761)", "flake8 (3.7.9)", "flake8-bugbear (20.1.4)", "pre-commit (>=1.20,<3.0)"] +tests = ["pytest", "pytz", "simplejson"] + [[package]] category = "dev" description = "McCabe checker, plugin for flake8" @@ -338,6 +373,14 @@ optional = false python-versions = ">=3.5" version = "8.2.0" +[[package]] +category = "main" +description = "multidict implementation" +name = "multidict" +optional = false +python-versions = ">=3.5" +version = "4.7.5" + [[package]] category = "dev" description = "Optional static typing for Python" @@ -559,7 +602,7 @@ six = ">=1.5" category = "main" description = "World timezone definitions, modern and historical" name = "pytz" -optional = true +optional = false python-versions = "*" version = "2019.3" @@ -783,7 +826,7 @@ version = "0.10.0" category = "main" description = "Style preserving TOML library" name = "tomlkit" -optional = false +optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" version = "0.5.11" @@ -826,7 +869,7 @@ python-versions = "*" version = "1.4.1" [[package]] -category = "dev" +category = "main" description = "Backported and Experimental Type Hints for Python 3.5+" name = "typing-extensions" optional = false @@ -870,6 +913,18 @@ all = ["six", "pytest", "pytest-cov", "codecov", "pygments", "colorama"] optional = ["pygments", "colorama"] tests = ["pytest", "pytest-cov", "codecov"] +[[package]] +category = "main" +description = "Yet another URL library" +name = "yarl" +optional = false +python-versions = ">=3.5" +version = "1.4.2" + +[package.dependencies] +idna = ">=2.0" +multidict = ">=4.0" + [[package]] category = "dev" description = "Backport of pathlib-compatible object wrapper for zip files" @@ -884,10 +939,10 @@ docs = ["sphinx", "jaraco.packaging (>=3.2)", "rst.linker (>=1.9)"] testing = ["jaraco.itertools", "func-timeout"] [extras] -doc = ["sphinx", "sphinx-autodoc-typehints", "sphinx_rtd_theme", "autodocsumm"] +doc = ["sphinx", "sphinx-autodoc-typehints", "sphinx_rtd_theme", "autodocsumm", "tomlkit"] [metadata] -content-hash = "d2d90bc3ec5835c477086c262ede51097246c76b0b5968e524bb53a696bfa003" +content-hash = "ea1d8069d64d9480fd2b40cd50107dafc0688606674d23ed1268dff0e1bc33c4" python-versions = "^3.6" [metadata.files] @@ -975,6 +1030,10 @@ coverage = [ {file = "coverage-5.0.4-cp39-cp39-win_amd64.whl", hash = "sha256:4482f69e0701139d0f2c44f3c395d1d1d37abd81bfafbf9b6efbe2542679d892"}, {file = "coverage-5.0.4.tar.gz", hash = "sha256:1b60a95fc995649464e0cd48cecc8288bac5f4198f21d04b8229dc4097d76823"}, ] +darglint = [ + {file = "darglint-1.2.1-py3-none-any.whl", hash = "sha256:16ee69a67fc0f3a89917ba4028b9c50491d7cb4e569cb94eed2e013e2a574c77"}, + {file = "darglint-1.2.1.tar.gz", hash = "sha256:7fec9d38b545f49650e96c45f9c62d6ba6cc9c9f66d305a348813c1c0a6fcb02"}, +] decorator = [ {file = "decorator-4.4.2-py2.py3-none-any.whl", hash = "sha256:41fa54c2a0cc4ba648be4fd43cff00aedf5b9465c9bf18d64325bc225f08f760"}, {file = "decorator-4.4.2.tar.gz", hash = "sha256:e3a62f0520172440ca0dcc823749319382e377f37f140a0b99ef45fecb84bfe7"}, @@ -995,6 +1054,10 @@ flake8-docstrings = [ {file = "flake8-docstrings-1.5.0.tar.gz", hash = "sha256:3d5a31c7ec6b7367ea6506a87ec293b94a0a46c0bce2bb4975b7f1d09b6f3717"}, {file = "flake8_docstrings-1.5.0-py2.py3-none-any.whl", hash = "sha256:a256ba91bc52307bef1de59e2a009c3cf61c3d0952dbe035d6ff7208940c2edc"}, ] +freezegun = [ + {file = "freezegun-0.3.15-py2.py3-none-any.whl", hash = "sha256:82c757a05b7c7ca3e176bfebd7d6779fd9139c7cb4ef969c38a28d74deef89b2"}, + {file = "freezegun-0.3.15.tar.gz", hash = "sha256:e2062f2c7f95cc276a834c22f1a17179467176b624cc6f936e8bc3be5535ad1b"}, +] idna = [ {file = "idna-2.9-py2.py3-none-any.whl", hash = "sha256:a068a21ceac8a4d63dbfd964670474107f541babbd2250d61922f029858365fa"}, {file = "idna-2.9.tar.gz", hash = "sha256:7588d1c14ae4c77d74036e8c22ff447b26d0fde8f007354fd48a7814db15b7cb"}, @@ -1066,6 +1129,10 @@ markupsafe = [ {file = "MarkupSafe-1.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:e8313f01ba26fbbe36c7be1966a7b7424942f670f38e666995b88d012765b9be"}, {file = "MarkupSafe-1.1.1.tar.gz", hash = "sha256:29872e92839765e546828bb7754a68c418d927cd064fd4708fab9fe9c8bb116b"}, ] +marshmallow = [ + {file = "marshmallow-3.5.1-py2.py3-none-any.whl", hash = "sha256:ac2e13b30165501b7d41fc0371b8df35944f5849769d136f20e2c5f6cdc6e665"}, + {file = "marshmallow-3.5.1.tar.gz", hash = "sha256:90854221bbb1498d003a0c3cc9d8390259137551917961c8b5258c64026b2f85"}, +] mccabe = [ {file = "mccabe-0.6.1-py2.py3-none-any.whl", hash = "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42"}, {file = "mccabe-0.6.1.tar.gz", hash = "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f"}, @@ -1074,6 +1141,25 @@ more-itertools = [ {file = "more-itertools-8.2.0.tar.gz", hash = "sha256:b1ddb932186d8a6ac451e1d95844b382f55e12686d51ca0c68b6f61f2ab7a507"}, {file = "more_itertools-8.2.0-py3-none-any.whl", hash = "sha256:5dd8bcf33e5f9513ffa06d5ad33d78f31e1931ac9a18f33d37e77a180d393a7c"}, ] +multidict = [ + {file = "multidict-4.7.5-cp35-cp35m-macosx_10_13_x86_64.whl", hash = "sha256:fc3b4adc2ee8474cb3cd2a155305d5f8eda0a9c91320f83e55748e1fcb68f8e3"}, + {file = "multidict-4.7.5-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:42f56542166040b4474c0c608ed051732033cd821126493cf25b6c276df7dd35"}, + {file = "multidict-4.7.5-cp35-cp35m-win32.whl", hash = "sha256:7774e9f6c9af3f12f296131453f7b81dabb7ebdb948483362f5afcaac8a826f1"}, + {file = "multidict-4.7.5-cp35-cp35m-win_amd64.whl", hash = "sha256:c2c37185fb0af79d5c117b8d2764f4321eeb12ba8c141a95d0aa8c2c1d0a11dd"}, + {file = "multidict-4.7.5-cp36-cp36m-macosx_10_13_x86_64.whl", hash = "sha256:e439c9a10a95cb32abd708bb8be83b2134fa93790a4fb0535ca36db3dda94d20"}, + {file = "multidict-4.7.5-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:85cb26c38c96f76b7ff38b86c9d560dea10cf3459bb5f4caf72fc1bb932c7136"}, + {file = "multidict-4.7.5-cp36-cp36m-win32.whl", hash = "sha256:620b37c3fea181dab09267cd5a84b0f23fa043beb8bc50d8474dd9694de1fa6e"}, + {file = "multidict-4.7.5-cp36-cp36m-win_amd64.whl", hash = "sha256:6e6fef114741c4d7ca46da8449038ec8b1e880bbe68674c01ceeb1ac8a648e78"}, + {file = "multidict-4.7.5-cp37-cp37m-macosx_10_13_x86_64.whl", hash = "sha256:a326f4240123a2ac66bb163eeba99578e9d63a8654a59f4688a79198f9aa10f8"}, + {file = "multidict-4.7.5-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:dc561313279f9d05a3d0ffa89cd15ae477528ea37aa9795c4654588a3287a9ab"}, + {file = "multidict-4.7.5-cp37-cp37m-win32.whl", hash = "sha256:4b7df040fb5fe826d689204f9b544af469593fb3ff3a069a6ad3409f742f5928"}, + {file = "multidict-4.7.5-cp37-cp37m-win_amd64.whl", hash = "sha256:317f96bc0950d249e96d8d29ab556d01dd38888fbe68324f46fd834b430169f1"}, + {file = "multidict-4.7.5-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:b51249fdd2923739cd3efc95a3d6c363b67bbf779208e9f37fd5e68540d1a4d4"}, + {file = "multidict-4.7.5-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:ae402f43604e3b2bc41e8ea8b8526c7fa7139ed76b0d64fc48e28125925275b2"}, + {file = "multidict-4.7.5-cp38-cp38-win32.whl", hash = "sha256:bb519becc46275c594410c6c28a8a0adc66fe24fef154a9addea54c1adb006f5"}, + {file = "multidict-4.7.5-cp38-cp38-win_amd64.whl", hash = "sha256:544fae9261232a97102e27a926019100a9db75bec7b37feedd74b3aa82f29969"}, + {file = "multidict-4.7.5.tar.gz", hash = "sha256:aee283c49601fa4c13adc64c09c978838a7e812f85377ae130a24d7198c0331e"}, +] mypy = [ {file = "mypy-0.770-cp35-cp35m-macosx_10_6_x86_64.whl", hash = "sha256:a34b577cdf6313bf24755f7a0e3f3c326d5c1f4fe7422d1d06498eb25ad0c600"}, {file = "mypy-0.770-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:86c857510a9b7c3104cf4cde1568f4921762c8f9842e987bc03ed4f160925754"}, @@ -1311,6 +1397,25 @@ xdoctest = [ {file = "xdoctest-0.11.0-py2.py3-none-any.whl", hash = "sha256:292d8c0e8de9bcab6ce3fa5009da445278ce343c744aa2ae260710f924c5242e"}, {file = "xdoctest-0.11.0.tar.gz", hash = "sha256:71951c60bb8b15fdf3c368c5636a6bbbd658ef5c2e5949e784adf5e1a1275fbd"}, ] +yarl = [ + {file = "yarl-1.4.2-cp35-cp35m-macosx_10_13_x86_64.whl", hash = "sha256:3ce3d4f7c6b69c4e4f0704b32eca8123b9c58ae91af740481aa57d7857b5e41b"}, + {file = "yarl-1.4.2-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:a4844ebb2be14768f7994f2017f70aca39d658a96c786211be5ddbe1c68794c1"}, + {file = "yarl-1.4.2-cp35-cp35m-win32.whl", hash = "sha256:d8cdee92bc930d8b09d8bd2043cedd544d9c8bd7436a77678dd602467a993080"}, + {file = "yarl-1.4.2-cp35-cp35m-win_amd64.whl", hash = "sha256:c2b509ac3d4b988ae8769901c66345425e361d518aecbe4acbfc2567e416626a"}, + {file = "yarl-1.4.2-cp36-cp36m-macosx_10_13_x86_64.whl", hash = "sha256:308b98b0c8cd1dfef1a0311dc5e38ae8f9b58349226aa0533f15a16717ad702f"}, + {file = "yarl-1.4.2-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:944494be42fa630134bf907714d40207e646fd5a94423c90d5b514f7b0713fea"}, + {file = "yarl-1.4.2-cp36-cp36m-win32.whl", hash = "sha256:5b10eb0e7f044cf0b035112446b26a3a2946bca9d7d7edb5e54a2ad2f6652abb"}, + {file = "yarl-1.4.2-cp36-cp36m-win_amd64.whl", hash = "sha256:a161de7e50224e8e3de6e184707476b5a989037dcb24292b391a3d66ff158e70"}, + {file = "yarl-1.4.2-cp37-cp37m-macosx_10_13_x86_64.whl", hash = "sha256:26d7c90cb04dee1665282a5d1a998defc1a9e012fdca0f33396f81508f49696d"}, + {file = "yarl-1.4.2-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:0c2ab325d33f1b824734b3ef51d4d54a54e0e7a23d13b86974507602334c2cce"}, + {file = "yarl-1.4.2-cp37-cp37m-win32.whl", hash = "sha256:e15199cdb423316e15f108f51249e44eb156ae5dba232cb73be555324a1d49c2"}, + {file = "yarl-1.4.2-cp37-cp37m-win_amd64.whl", hash = "sha256:2098a4b4b9d75ee352807a95cdf5f10180db903bc5b7270715c6bbe2551f64ce"}, + {file = "yarl-1.4.2-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:c9959d49a77b0e07559e579f38b2f3711c2b8716b8410b320bf9713013215a1b"}, + {file = "yarl-1.4.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:25e66e5e2007c7a39541ca13b559cd8ebc2ad8fe00ea94a2aad28a9b1e44e5ae"}, + {file = "yarl-1.4.2-cp38-cp38-win32.whl", hash = "sha256:6faa19d3824c21bcbfdfce5171e193c8b4ddafdf0ac3f129ccf0cdfcb083e462"}, + {file = "yarl-1.4.2-cp38-cp38-win_amd64.whl", hash = "sha256:0ca2f395591bbd85ddd50a82eb1fde9c1066fafe888c5c7cc1d810cf03fd3cc6"}, + {file = "yarl-1.4.2.tar.gz", hash = "sha256:58cd9c469eced558cd81aa3f484b2924e8897049e06889e8ff2510435b7ef74b"}, +] zipp = [ {file = "zipp-3.1.0-py3-none-any.whl", hash = "sha256:aa36550ff0c0b7ef7fa639055d797116ee891440eac1a56f378e2d3179e0320b"}, {file = "zipp-3.1.0.tar.gz", hash = "sha256:c599e4d75c98f6798c509911d08a22e6c021d074469042177c8c86fb92eefd96"}, diff --git a/pyproject.toml b/pyproject.toml index d627d62b..2500fa70 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,39 +25,48 @@ include = [ [tool.poetry.dependencies] python = "^3.6" -requests = "^2.23" -python-dateutil = "^2.8" -# docs need to be distributed for readthedocs +# Doc Dependencies (need to be distributed for readthedocs) +autodocsumm = { version = "^0.1.13", optional = true } sphinx = { version = "^2.4.4", optional = true } sphinx-autodoc-typehints = { version = "^1.10.3", optional = true } sphinx_rtd_theme = { version = "^0.4.3", optional = true } -autodocsumm = { version = "^0.1.13", optional = true } -tomlkit = "^0.5.11" +tomlkit = {version = "^0.5.11", optional = true } + +# Main Dependencies +marshmallow = "^3.5.1" +python-dateutil = "^2.8" +pytz = "^2019.3" +requests = "^2.23" +yarl = "^1.4.2" +typing-extensions = "^3.7.4" [tool.poetry.dev-dependencies] # Linting +# These are version locked so .pre-commit-config is even flake8 = "3.7.9" flake8-docstrings = "1.5.0" black = { version = "19.10b0", python = "^3.6" } +darglint = { version = "1.2.1", python = "^3.7"} isort = { version = "4.3.21", extras = ["pyproject"] } seed-isort-config = { version = "2.1.0", python = "^3.7" } mypy = "0.770" nbstripout = "^0.3.7" # Testing +coverage = "^5.0.4" +freezegun = "^0.3.15" pytest = "^5.4.1" +pytest-cov = "^2.8.1" pytest-mock = "^2.0.0" requests-mock = "^1.7.0" -coverage = "^5.0.4" -pytest-cov = "^2.8.1" xdoctest = "^0.11.0" # Automation towncrier = "^19.2.0" [tool.poetry.extras] -doc = ["sphinx", "sphinx-autodoc-typehints", "sphinx_rtd_theme", "autodocsumm", "toml"] +doc = ["sphinx", "sphinx-autodoc-typehints", "sphinx_rtd_theme", "autodocsumm", "tomlkit"] [tool.black] include = '\.pyi?$' @@ -76,7 +85,7 @@ exclude = ''' [tool.isort] known_first_party = 'robinhood' -known_third_party = ["dateutil", "pytest", "requests", "requests_mock"] +known_third_party = ["dateutil", "freezegun", "marshmallow", "pytest", "pytz", "requests", "requests_mock", "typing_extensions", "yarl"] multi_line_output = 3 lines_after_imports = 2 force_grid_wrap = 0 diff --git a/pyrh/__init__.py b/pyrh/__init__.py index d52e6a16..b4c1deee 100755 --- a/pyrh/__init__.py +++ b/pyrh/__init__.py @@ -1,22 +1,15 @@ """Export pyrh sub classes.""" -from pyrh import exceptions -from pyrh.robinhood import Robinhood -from pyrh.sessionmanager import SessionManager - - -def _get_version() -> str: - from pathlib import Path - from tomlkit import parse - - pyproject_path = Path(__file__).resolve().parent.joinpath("../pyproject.toml") - with open(pyproject_path) as file: - pyproject = parse(file.read()) - - return str(pyproject["tool"]["poetry"]["version"]) - - -__version__ = _get_version() -__all__ = ["__version__", "Robinhood", "SessionManager", "exceptions"] - -del _get_version +from . import exceptions +from .models import dump_session, load_session +from .robinhood import Robinhood + + +__version__ = "2.0" +__all__ = [ + "__version__", + "Robinhood", + "load_session", + "dump_session", + "exceptions", +] diff --git a/pyrh/common.py b/pyrh/common.py new file mode 100644 index 00000000..ae25e0b1 --- /dev/null +++ b/pyrh/common.py @@ -0,0 +1,6 @@ +"""Shared resources across pyrh.""" + +from typing import Any, Dict + + +JSON = Dict[str, Any] diff --git a/pyrh/endpoints.py b/pyrh/endpoints.py index da5fcd54..71e1902e 100755 --- a/pyrh/endpoints.py +++ b/pyrh/endpoints.py @@ -1,20 +1,30 @@ -BASE_API = "https://api.robinhood.com" +"""Define Robinhood endpoints.""" +from typing import Callable -def login(): - return BASE_API + "/oauth2/token/" +from yarl import URL + + +BASE = URL("https://api.robinhood.com") + +# OAuth +OAUTH: URL = BASE.with_path("/oauth2/token/") +OAUTH_REVOKE: URL = BASE.with_path("/oauth2/revoke_token/") +CHALLENGE: Callable[[str], URL] = lambda cid: BASE.with_path( + f"/challenge/{cid}/respond" +) def logout(): - return BASE_API + "/oauth2/revoke_token/" + return BASE.with_path("/oauth2/revoke_token/") def investment_profile(): - return BASE_API + "/user/investment_profile/" + return BASE.with_path("/user/investment_profile/") def accounts(): - return BASE_API + "/accounts/" + return BASE.with_path("/accounts/") def ach(option): @@ -25,20 +35,22 @@ def ach(option): * transfers """ return ( - BASE_API + "/ach/iav/auth/" if option == "iav" else BASE_API + f"/ach/{option}/" + BASE.with_path("/ach/iav/auth/") + if option == "iav" + else BASE.with_path(f"/ach/{option}/") ) def applications(): - return BASE_API + "/applications/" + return BASE.with_path("/applications/") def dividends(): - return BASE_API + "/dividends/" + return BASE.with_path("/dividends/") def edocuments(): - return BASE_API + "/documents/" + return BASE.with_path("/documents/") def instruments(instrument_id=None, option=None): @@ -46,7 +58,7 @@ def instruments(instrument_id=None, option=None): Return information about a specific instrument by providing its instrument id. Add extra options for additional information such as "popularity" """ - url = BASE_API + f"/instruments/" + url = BASE.with_path(f"/instruments/") if instrument_id is not None: url += f"{instrument_id}" if option is not None: @@ -56,83 +68,82 @@ def instruments(instrument_id=None, option=None): def margin_upgrades(): - return BASE_API + "/margin/upgrades/" + return BASE.with_path("/margin/upgrades/") def markets(): - return BASE_API + "/markets/" + return BASE.with_path("/markets/") def notifications(): - return BASE_API + "/notifications/" + return BASE.with_path("/notifications/") def orders(order_id=""): - return BASE_API + f"/orders/{order_id}" + return BASE.with_path(f"/orders/{order_id}/") def password_reset(): - return BASE_API + "/password_reset/request/" + return BASE.with_path("/password_reset/request/") def portfolios(): - return BASE_API + "/portfolios/" + return BASE.with_path("/portfolios/") def positions(): - return BASE_API + "/positions/" + return BASE.with_path("/positions/") def quotes(): - return BASE_API + "/quotes/" + return BASE.with_path("/quotes/") def historicals(): - return BASE_API + "/quotes/historicals/" + return BASE.with_path("/quotes/historicals/") def document_requests(): - return BASE_API + "/upload/document_requests/" + return BASE.with_path("/upload/document_requests/") def user(): - return BASE_API + "/user/" + return BASE.with_path("/user/") def watchlists(): - return BASE_API + "/watchlists/" + return BASE.with_path("/watchlists/") def news(stock): - return BASE_API + f"/midlands/news/{stock}/" + return BASE.with_path(f"/midlands/news/{stock}/") def fundamentals(stock): - return BASE_API + f"/fundamentals/{stock}/" + return BASE.with_path(f"/fundamentals/{stock}/") def tags(tag): """ Returns endpoint with tag concatenated. """ - return BASE_API + f"/midlands/tags/tag/{tag}/" + return BASE.with_path(f"/midlands/tags/tag/{tag}/") def chain(instrument_id): - return BASE_API + f"/options/chains/?equity_instrument_ids={instrument_id}" + return BASE.with_path(f"/options/chains/?equity_instrument_ids={instrument_id}/") def options(chain_id, dates, option_type): - return ( - BASE_API - + f"/options/instruments/?chain_id={chain_id}&expiration_dates={dates}" - + f"&state=active&tradability=tradable&type={option_type}" + return BASE.with_path( + f"/options/instruments/?chain_id={chain_id}&expiration_dates={dates}" + f"&state=active&tradability=tradable&type={option_type}" ) def market_data(option_id): - return BASE_API + f"/marketdata/options/{option_id}/" + return BASE.with_path(f"/marketdata/options/{option_id}/") def convert_token(): - return BASE_API + "/oauth2/migrate_token/" + return BASE.with_path("/oauth2/migrate_token/") diff --git a/pyrh/exceptions.py b/pyrh/exceptions.py index 6779a361..0163575e 100755 --- a/pyrh/exceptions.py +++ b/pyrh/exceptions.py @@ -1,37 +1,31 @@ """Exceptions: custom exceptions for library""" -class RobinhoodException(Exception): +class RHException(Exception): """Wrapper for custom robinhood library exceptions.""" pass -class AuthenticationError(RobinhoodException): +class AuthenticationError(RHException): """Error when trying to login to robinhood.""" pass -class LoginFailed(RobinhoodException): # TODO: Remove me - """Error when trying to login to robinhood.""" - - pass - - -class InvalidTickerSymbol(RobinhoodException): +class InvalidTickerSymbol(RHException): """When an invalid ticker (stock symbol) is given/""" pass -class InvalidInstrumentId(RobinhoodException): +class InvalidInstrumentId(RHException): """When an invalid instrument id is given/""" pass -class InvalidOptionId(RobinhoodException): +class InvalidOptionId(RHException): """When an invalid option id is given/""" pass diff --git a/pyrh/models/__init__.py b/pyrh/models/__init__.py new file mode 100644 index 00000000..d296e465 --- /dev/null +++ b/pyrh/models/__init__.py @@ -0,0 +1,7 @@ +"""pyrh models and schemas.""" + +from .oauth import Challenge, OAuth +from .sessionmanager import SessionManager, dump_session, load_session + + +__all__ = ["OAuth", "Challenge", "SessionManager", "dump_session", "load_session"] diff --git a/pyrh/models/base.py b/pyrh/models/base.py new file mode 100644 index 00000000..24564ff4 --- /dev/null +++ b/pyrh/models/base.py @@ -0,0 +1,99 @@ +"""Base Model.""" + +from types import SimpleNamespace +from typing import Any, Mapping + +from marshmallow import INCLUDE, Schema, post_load + +from pyrh.common import JSON + + +MAX_REPR_LEN = 50 + + +def _process_dict_values(value: Any) -> Any: + """Process a returned from a JSON response. + + Args: + value: A dict, list, or value returned from a JSON response. + + Returns: + Either an UnknownModel, a List of processed values, or the original value \ + passed through. + + """ + if isinstance(value, Mapping): + return UnknownModel(**value) + elif isinstance(value, list): + return [_process_dict_values(v) for v in value] + else: + return value + + +class BaseModel(SimpleNamespace): + """BaseModel that all models should inherit from. + + Note: + If a passed parameter is a nested dictionary, then it is created with the + `UnknownModel` class. If it is a list, then it is created with + + Args: + **kwargs: All passed parameters as converted to instance attributes. + """ + + def __init__(self, **kwargs: Any) -> None: + kwargs = {k: _process_dict_values(v) for k, v in kwargs.items()} + + self.__dict__.update(kwargs) + + def __repr__(self) -> str: + """Return a default repr of any Model. + + Returns: + The string model parameters up to a `MAX_REPR_LEN`. + + """ + repr_ = super().__repr__() + if len(repr_) > MAX_REPR_LEN: + return repr_[:MAX_REPR_LEN] + " ...)" + else: + return repr_ + + def __len__(self) -> int: + """Return the length of the model. + + Returns: + The number of attributes a given model has. + + """ + return len(self.__dict__) + + +class UnknownModel(BaseModel): + """A convenience class that inherits from `BaseModel`.""" + + pass + + +class BaseSchema(Schema): + """The default schema for all models.""" + + __model__: Any = UnknownModel + """Determines the object that is created when the load method is called.""" + + class Meta: + unknown = INCLUDE + + @post_load + def make_object(self, data: JSON, **kwargs: Any) -> "__model__": + """Build model for the given `__model__` class attribute. + + Args: + data: The JSON diction to use to build the model. + **kwargs: Unused but required to match signature of `Schema.make_object` + + Returns: + An instance of the `__model__` class. + + """ + return self.__model__(**data) diff --git a/pyrh/models/oauth.py b/pyrh/models/oauth.py new file mode 100644 index 00000000..7cfa3c9e --- /dev/null +++ b/pyrh/models/oauth.py @@ -0,0 +1,95 @@ +"""Oauth models.""" + +from datetime import datetime + +import pytz +from marshmallow import fields, validate + +from .base import BaseModel, BaseSchema + + +CHALLENGE_TYPE_VAL = validate.OneOf(["email", "sms"]) + + +class Challenge(BaseModel): + """The challenge response model.""" + + remaining_attempts = 0 + """Default `remaining_attempts` attribute if it is not set on instance.""" + + @property + def can_retry(self) -> bool: + """Determine if the challenge can be retried. + + Returns: + True if remaining_attempts is greater than zero and challenge is not \ + expired, False otherwise. + + """ + return self.remaining_attempts > 0 and ( + datetime.now(tz=pytz.utc) < self.expires_at + ) + + +class ChallengeSchema(BaseSchema): + """The challenge response schema.""" + + __model__ = Challenge + + id = fields.UUID() + user = fields.UUID() + type = fields.Str(validate=CHALLENGE_TYPE_VAL) + alternate_type = fields.Str(validate=CHALLENGE_TYPE_VAL) + status = fields.Str(validate=validate.OneOf(["issued", "validated", "failed"])) + remaining_retries = fields.Int() + remaining_attempts = fields.Int() + expires_at = fields.AwareDateTime(default_timezone=pytz.UTC) # type: ignore + + +class OAuth(BaseModel): + """The OAuth response model.""" + + @property + def is_challenge(self) -> bool: + """Determine whether the oauth response is a challenge. + + Returns: + True response has the `challenge` key, False otherwise. + + """ + return hasattr(self, "challenge") + + @property + def is_mfa(self) -> bool: + """Determine whether the oauth response is a mfa challenge. + + Returns: + True response has the `mfa_required` key, False otherwise. + + """ + return hasattr(self, "mfa_required") + + @property + def is_valid(self) -> bool: + """Determine whether the oauth response is a valid response. + + Returns: + True if the response has both the `access_token` and `refresh_token` keys, \ + False otherwise. + + """ + return hasattr(self, "access_token") and hasattr(self, "refresh_token") + + +class OAuthSchema(BaseSchema): + """The OAuth response schema.""" + + __model__ = OAuth + + detail = fields.Str() + challenge = fields.Nested(ChallengeSchema) + mfa_required = fields.Boolean() + + access_token = fields.Str() + refresh_token = fields.Str() + expires_in = fields.Int() diff --git a/pyrh/models/sessionmanager.py b/pyrh/models/sessionmanager.py new file mode 100644 index 00000000..777666e2 --- /dev/null +++ b/pyrh/models/sessionmanager.py @@ -0,0 +1,612 @@ +"""Manage Robinhood Sessions.""" + +import uuid +from datetime import datetime, timedelta +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, Optional, Union, cast, overload +from urllib.request import getproxies + +import pytz +import requests +from marshmallow import fields, post_load +from requests import Response +from requests.exceptions import HTTPError +from requests.structures import CaseInsensitiveDict +from typing_extensions import Literal +from yarl import URL + +from pyrh import endpoints +from pyrh.cache import CACHE_ROOT +from pyrh.common import JSON +from pyrh.exceptions import AuthenticationError + +from .base import BaseModel, BaseSchema +from .oauth import CHALLENGE_TYPE_VAL, OAuth, OAuthSchema + + +CERTS_PATH: Path = Path(__file__).parent.joinpath("./ssl/certs.pem") +"""Path to ssl files used when running post requests.""" + +CLIENT_ID: str = "c82SH0WZOsabOXGP2sxqcj34FxkvfnWRZBKlBjFS" +"""Robinhood client id.""" + +CACHE_LOGIN: Path = CACHE_ROOT.joinpath("login.json") +"""Path to login.json config file.""" +CACHE_LOGIN.touch(exist_ok=True) + +if TYPE_CHECKING: # pragma: no cover + CaseInsensitiveDictType = CaseInsensitiveDict[str] +else: + CaseInsensitiveDictType = CaseInsensitiveDict + +Proxies = Dict[str, str] +HEADERS: CaseInsensitiveDictType = CaseInsensitiveDict( + { + "Accept": "*/*", + "Accept-Encoding": "gzip, deflate", + "Accept-Language": "en;q=1, fr;q=0.9, de;q=0.8, ja;q=0.7, nl;q=0.6, it;q=0.5", + "Content-Type": "application/x-www-form-urlencoded; charset=utf-8", + "X-Robinhood-API-Version": "1.0.0", + "Connection": "keep-alive", + "User-Agent": "Robinhood/823 (iPhone; iOS 7.1.2; Scale/2.00)", + } +) +"""Headers used when performing requests with robinhood api.""" + +EXPIRATION_TIME: int = 10 +"""Default expiration time for requests.""" + +# TODO: Watch this issue and remove the F811 ignores when it is fixed +# https://gitlab.com/pycqa/flake8/-/merge_requests/417 (we need at least pyflakes 2.2.0) + + +class SessionManager(BaseModel): + """Mange connectivity with Robinhood API. + + Once logged into the session, this class will manage automatic oauth token update + requests allowing for the automation systems to only require multi-factor + authentication on initialization. + + Example: + >>> sm = SessionManager(username="USERNAME", password="PASSWORD") + >>> sm.login() # xdoctest: +SKIP + >>> sm.logout() # xdoctest: +SKIP + + If you want to cache your session (you should) then you can use the following + functions. This will allow you to re-cover from a script crash without having to + manually re-enter multi-factor authentication codes. + + Example: + >>> dump_session(sm) # xdoctest: +SKIP + >>> load_session(sm) # xdoctest: +SKIP + + Args: + username: The username to login to Robinhood. + password: The password to login to Robinhood. + challenge_type: Either sms or email. (only if not using mfa) + headers: Any optional header dict modifications for the session. + proxies: Any optional proxy dict modification for the session. + **kwargs: Any other passed parameters as converted to instance attributes. + + Attributes: + session: A requests session instance. + expires_at: The time the oauth token will expire at, default is + 1970-01-01 00:00:00. + certs: The path to the desired certs to check against. + device_token: A random guid representing the current device. + access_token: An oauth2 token to connect to the Robinhood API. + refresh_token: An oauth2 refresh token to refresh the access_token when + required. + username: The username to login to Robinhood. + password: The password to login to Robinhood. + challenge_type: Either sms or email. (only if not using mfa) + headers: Any optional header dict modifications for the session. + proxies: Any optional proxy dict modification for the session. + + """ + + def __init__( + self, + username: str, + password: str, + challenge_type: Optional[str] = "email", + headers: Optional[CaseInsensitiveDictType] = None, + proxies: Optional[Proxies] = None, + **kwargs: Any, + ) -> None: + self.session: requests.Session = requests.session() + self.session.headers = HEADERS if headers is None else headers + self.session.proxies = getproxies() if proxies is None else proxies + self.expires_at = datetime.strptime("1970", "%Y").replace( + tzinfo=pytz.UTC + ) # some time in the past + self.certs: Path = CERTS_PATH + + self.username: str = username + self.password: str = password + if challenge_type not in ["email", "sms"]: + raise ValueError("challenge_type must be email or sms") + self.challenge_type: str = challenge_type + + self._gen_device_token: str = str(uuid.uuid4()) + self.oauth: OAuth = OAuth() + + super().__init__(**kwargs) + + @property + def token_expired(self) -> bool: + """Check if the issued auth token has expired. + + Returns: + True if expired otherwise False + """ + return datetime.now(tz=pytz.UTC) > self.expires_at + + @property + def login_set(self) -> bool: + """Check if login info is properly configured. + + Returns: + Whether or not username and password are set. + """ + return self.password is not None and self.username is not None + + @property + def authenticated(self) -> bool: + """Check if the session is authenticated. + + Returns: + Whether or not the session is logged in. + """ + return "Authorization" in self.session.headers and not self.token_expired + + def login(self, force_refresh: bool = False) -> None: + """Login to the session. + + This method logs the user in if they are not already and otherwise refreshes + the oauth token if it is expired. + + Args: + force_refresh: If already logged in, whether or not to force a oauth token + refresh. + + """ + if "Authorization" not in self.session.headers: + self._login_oauth2() + elif self.oauth.is_valid and (self.token_expired or force_refresh): + self._refresh_oauth2() + + # The following type hints helps mypy determine what the output type to assign based + # on the `return_response` parameter. The same "stub" method approach is used for + # the post method as well. + # https://github.com/python/mypy/issues/8634#issuecomment-609411104 + @overload + def get( + self, + url: Union[str, URL], + params: Optional[Dict[str, Any]] = None, + *, + headers: Optional[CaseInsensitiveDictType] = None, + raise_errors: bool = True, + return_response: Literal[True], + auto_login: bool = True, + ) -> Response: # noqa: D102 # pragma: no cover + ... + + @overload # noqa: F811 + def get( + self, + url: Union[str, URL], + params: Optional[Dict[str, Any]] = None, + *, + headers: Optional[CaseInsensitiveDictType] = None, + raise_errors: bool = True, + return_response: Literal[False] = ..., + auto_login: bool = True, + ) -> JSON: # noqa: D102 # pragma: no cover + ... + + def get( # noqa: F811 + self, + url: Union[str, URL], + params: Optional[Dict[str, Any]] = None, + *, + headers: Optional[CaseInsensitiveDictType] = None, + raise_errors: bool = True, + return_response: bool = False, + auto_login: bool = True, + ) -> Union[Response, JSON]: + """Run a wrapped session HTTP GET request. + + Note: + This method automatically prompts the user to log in if not already logged + in. + + Args: + url: The url to get from. + params: query string parameters + headers: A dict adding to and overriding the session headers. + raise_errors: Whether or not raise errors on GET request result. + return_response: Whether or not return a `requests.Response` object or the + JSON response from the request. + auto_login: Whether or not to automatically login on restricted endpoint + errors. + + Returns: + The POST response + + """ + params = {} if params is None else params + res = self.session.get( + str(url), params=params, headers={} if headers is None else headers + ) + if res.status_code == 401 and auto_login: + self.login(force_refresh=True) + res = self.session.get( + str(url), params=params, headers={} if headers is None else headers + ) + if raise_errors: + res.raise_for_status() + + return res if return_response else res.json() + + @overload + def post( + self, + url: Union[str, URL], + data: Optional[JSON] = None, + *, + headers: Optional[CaseInsensitiveDictType] = None, + raise_errors: bool = True, + return_response: Literal[True], + auto_login: bool = True, + ) -> Response: # noqa: D102 # pragma: no cover + ... + + @overload # noqa: F811 + def post( + self, + url: Union[str, URL], + data: Optional[JSON] = None, + *, + headers: Optional[CaseInsensitiveDictType] = None, + raise_errors: bool = True, + return_response: Literal[False] = ..., + auto_login: bool = True, + ) -> JSON: # noqa: D102 # pragma: no cover + ... + + def post( # noqa: F811 + self, + url: Union[str, URL], + data: Optional[JSON] = None, + *, + headers: Optional[CaseInsensitiveDictType] = None, + raise_errors: bool = True, + return_response: bool = False, + auto_login: bool = True, + ) -> Union[JSON, Response]: + """Run a wrapped session HTTP POST request. + + Note: + This method automatically prompts the user to log in if not already logged + in. + + Args: + url: The url to post to. + data: The payload to POST to the endpoint. + headers: A dict adding to and overriding the session headers. + return_response: Whether or not return a `requests.Response` object or the + JSON response from the request. + raise_errors: Whether or not raise errors on POST request. + auto_login: Whether or not to automatically login on restricted endpoint + errors. + + Returns: + The response or an empty dict if an empty response is returned. + + """ + res = self.session.post( + str(url), + data=data, + timeout=15, + verify=self.certs, + headers={} if headers is None else headers, + ) + if (res.status_code == 401) and auto_login: + self.login(force_refresh=True) + res = self.session.post( + str(url), + data=data, + timeout=15, + verify=self.certs, + headers={} if headers is None else headers, + ) + if raise_errors: + res.raise_for_status() + + return res if return_response else res.json() + + def _configure_manager(self, oauth: OAuth) -> None: + """Process an authentication response dictionary. + + This method updates the internal state of the session based on a login or + token refresh request. + + Args: + oauth: An oauth response model from a login request. + + """ + self.oauth = oauth + self.expires_at = datetime.now(tz=pytz.UTC) + timedelta( + seconds=self.oauth.expires_in + ) + self.session.headers.update( + {"Authorization": f"Bearer {self.oauth.access_token}"} + ) + + def _challenge_oauth2(self, oauth: OAuth, oauth_payload: JSON) -> OAuth: + """Process the ouath challenge flow. + + Args: + oauth: An oauth response model from a login request. + oauth_payload: The payload to use once the challenge has been processed. + + Returns: + An OAuth response model from the login request. + + Raises: + AuthenticationError: If there is an error in the initial challenge response. + + .. # noqa: DAR202 + .. https://github.com/terrencepreilly/darglint/issues/81 + + """ + # login challenge + challenge_url = endpoints.CHALLENGE(oauth.challenge.id) + print( + f"Input challenge code from {oauth.challenge.type.capitalize()} " + f"({oauth.challenge.remaining_attempts}/" + f"{oauth.challenge.remaining_retries}):" + ) + challenge_code = input() + challenge_payload = {"response": str(challenge_code)} + challenge_header = CaseInsensitiveDict( + {"X-ROBINHOOD-CHALLENGE-RESPONSE-ID": str(oauth.challenge.id)} + ) + res = self.post( + challenge_url, + data=challenge_payload, + raise_errors=False, + headers=challenge_header, + auto_login=False, + return_response=True, + ) + oauth_inner = OAuthSchema().load(res.json()) + if res.status_code == requests.codes.ok: + try: + res2 = self.post( + endpoints.OAUTH, + data=oauth_payload, + headers=challenge_header, + auto_login=False, + ) + except HTTPError: + raise AuthenticationError("Error in finalizing auth token") + else: + oauth = OAuthSchema().load(res2) + return oauth + elif oauth_inner.is_challenge and oauth_inner.challenge.can_retry: + print("Invalid code entered") + return self._challenge_oauth2(oauth, oauth_payload) + else: + raise AuthenticationError("Exceeded available attempts or code expired") + + def _mfa_oauth2(self, oauth_payload: JSON, attempts: int = 3) -> OAuth: + """Mfa auth flow. + + For people with 2fa. + + Args: + oauth_payload: JSON payload to send on mfa approval. + attempts: The number of attempts to allow for mfa approval. + + Returns: + An OAuth response model object. + + Raises: + AuthenticationError: If the mfa code is incorrect more than specified \ + number of attempts. + + """ + print(f"Input mfa code:") + mfa_code = input() + oauth_payload["mfa_code"] = mfa_code + res = self.post( + endpoints.OAUTH, + data=oauth_payload, + raise_errors=False, + auto_login=False, + return_response=True, + ) + attempts -= 1 + if (res.status_code != requests.codes.ok) and (attempts > 0): + print("Invalid mfa code") + return self._mfa_oauth2(oauth_payload, attempts) + elif res.status_code == requests.codes.ok: + # TODO: Write mypy issue on why this needs to be casted? + return cast(OAuth, OAuthSchema().load(res.json())) + else: + raise AuthenticationError("Too many incorrect mfa attempts") + + def _login_oauth2(self) -> None: + """Create a new oauth2 token. + + Raises: + AuthenticationError: If the login credentials are not set, if a challenge + wasn't accepted, or if an mfa code is not accepted. + + """ + self.session.headers.pop("Authorization", None) + + oauth_payload = { + "password": self.password, + "username": self.username, + "grant_type": "password", + "client_id": CLIENT_ID, + "expires_in": EXPIRATION_TIME, + "scope": "internal", + "device_token": self._gen_device_token, + "challenge_type": self.challenge_type, + } + + res = self.post( + endpoints.OAUTH, + data=oauth_payload, + raise_errors=False, + auto_login=False, + return_response=True, + ) + + oauth = OAuthSchema().load(res.json()) + + if oauth.is_challenge: + oauth = self._challenge_oauth2(oauth, oauth_payload) + elif oauth.is_mfa: + oauth = self._mfa_oauth2(oauth_payload) + + if not oauth.is_valid: + msg = f"{oauth.error}" if hasattr(oauth, "error") else "Unknown login error" + raise AuthenticationError(msg) + else: + self._configure_manager(oauth) + + def _refresh_oauth2(self) -> None: + """Refresh an oauth2 token. + + Raises: + AuthenticationError: If refresh_token is missing or if there is an error + when trying to refresh a token. + + """ + if not self.oauth.is_valid: + raise AuthenticationError("Cannot refresh login with unset refresh token") + relogin_payload = { + "grant_type": "refresh_token", + "refresh_token": self.oauth.refresh_token, + "scope": "internal", + "client_id": CLIENT_ID, + "expires_in": EXPIRATION_TIME, + } + self.session.headers.pop("Authorization", None) + try: + res = self.post(endpoints.OAUTH, data=relogin_payload, auto_login=False) + except HTTPError: + raise AuthenticationError("Failed to refresh token") + + oauth = OAuthSchema().load(res) + self._configure_manager(oauth) + + def logout(self) -> None: + """Logout from the session. + + Raises: + AuthenticationError: If there is an error when logging out. + + """ + logout_payload = {"client_id": CLIENT_ID, "token": self.oauth.refresh_token} + try: + self.post(endpoints.OAUTH_REVOKE, data=logout_payload, auto_login=False) + self.oauth = OAuth() + self.session.headers.pop("Authorization", None) + except HTTPError: + raise AuthenticationError("Could not log out") + + def __repr__(self) -> str: + """Return the object as a string. + + Returns: + The string representation of the object. + + """ + return f"SessionManager<{self.username}>" + + +class SessionManagerSchema(BaseSchema): + """Schema class for the SessionManager model.""" + + __model__ = SessionManager + + username = fields.Email() # type: ignore # Call untyped "Email" in typed context + password = fields.Str() + challenge_type = fields.Str(validate=CHALLENGE_TYPE_VAL) + oauth = fields.Nested(OAuthSchema) + expires_at = fields.AwareDateTime() + device_token = fields.Str() + headers = fields.Dict() + proxies = fields.Dict() + + @post_load + def make_object(self, data: JSON, **kwargs: Any) -> SessionManager: + """Override default method to configure SessionManager object on load. + + Args: + data: The JSON dictionary to process + **kwargs: Not used but matches signature of `BaseSchema.make_object` + + Returns: + A configured instance of SessionManager. + """ + oauth = data.pop("oauth", None) + expires_at = data.pop("expires_at", None) + session_manager = self.__model__(**data) + + if oauth is not None and oauth.is_valid: + session_manager.oauth = oauth + session_manager.session.headers.update( + {"Authorization": f"Bearer {session_manager.oauth.access_token}"} + ) + if expires_at: + session_manager.expires_at = expires_at + + return session_manager + + +def dump_session( + session_manager: SessionManager, path: Optional[Union[Path, str]] = None +) -> None: + """Save the current session parameters to a json file. + + Note: + This function defaults to caching this information to + ~/.robinhood/login.json + + Args: + session_manager: A SessionManager instance. + path: The location to save the file and its name. + + """ + path = CACHE_LOGIN if path is None else path + json_str = SessionManagerSchema().dumps(session_manager, indent=4) + + with open(path, "w+") as file: + file.write(json_str) + + +def load_session(path: Optional[Union[Path, str]] = None) -> SessionManager: + """Load cached session parameters from a json file. + + Note: + This function defaults to caching this information to + ~/.robinhood/login.json + + Args: + path: The location and file name to load from. + + Returns: + A configured instance of SessionManager. + + """ + path = path or CACHE_LOGIN + with open(path) as file: + return cast(SessionManager, SessionManagerSchema().loads(file.read())) diff --git a/pyrh/robinhood.py b/pyrh/robinhood.py index 6de2a6b6..cb0fb605 100644 --- a/pyrh/robinhood.py +++ b/pyrh/robinhood.py @@ -12,7 +12,7 @@ InvalidOptionId, InvalidTickerSymbol, ) -from pyrh.sessionmanager import SessionManager +from pyrh.models import SessionManager class Bounds(Enum): diff --git a/pyrh/sessionmanager.py b/pyrh/sessionmanager.py deleted file mode 100644 index d8888649..00000000 --- a/pyrh/sessionmanager.py +++ /dev/null @@ -1,446 +0,0 @@ -"""Manage Robinhood Sessions.""" - -import json -import uuid -from copy import deepcopy -from datetime import datetime, timedelta -from pathlib import Path -from typing import Dict, Optional, Union -from urllib.request import getproxies - -import requests -from requests.structures import CaseInsensitiveDict - -from pyrh.cache import CACHE_ROOT -from pyrh.exceptions import AuthenticationError - - -CERTS_PATH: Path = Path(__file__).parent.joinpath("./ssl/certs.pem") -"""Path to ssl files used when running post requests.""" - -CLIENT_ID: str = "c82SH0WZOsabOXGP2sxqcj34FxkvfnWRZBKlBjFS" -"""Robinhood client id.""" - -CACHE_LOGIN: Path = CACHE_ROOT.joinpath("login.json") -"""Path to login.json config file.""" -CACHE_LOGIN.touch(exist_ok=True) - -# TODO: put urls in an API module -OAUTH_TOKEN_URL: str = "https://api.robinhood.com/oauth2/token/" -"""Oauth token generation endpoint.""" - -OAUTH_REVOKE_URL: str = "https://api.robinhood.com/oauth2/revoke_token/" -"""Oauth revocation endpoint.""" - -HEADERS: CaseInsensitiveDict = CaseInsensitiveDict( - { - "Accept": "*/*", - "Accept-Encoding": "gzip, deflate", - "Accept-Language": "en;q=1, fr;q=0.9, de;q=0.8, ja;q=0.7, nl;q=0.6, it;q=0.5", - "Content-Type": "application/x-www-form-urlencoded; charset=utf-8", - "X-Robinhood-API-Version": "1.0.0", - "Connection": "keep-alive", - "User-Agent": "Robinhood/823 (iPhone; iOS 7.1.2; Scale/2.00)", - } -) -"""Headers used when performing requests with robinhood api.""" - -EXPIRATION_TIME: int = 86400 -"""Default expiration time for requests.""" - - -class SessionManager(object): - """Mange connectivity with Robinhood API. - - Once logged into the session, this class will manage automatic oauth token update - requests allowing for the automation systems to only require multi-factor - authentication on initialization. - - Example: - >>> sm = SessionManager() - >>> sm.login(username="USERNAME", password="PASSWORD") # xdoctest: +SKIP - >>> sm.logout() # xdoctest: +SKIP - - If you want to cache your session (you should) then you can use the following - functions. This will allow you to re-cover from a script crash without having to - manually re-enter multi-factor authentication codes. - - Example: - >>> sm.to_json() # xdoctest: +SKIP - >>> sm.from_json() # xdoctest: +SKIP - - Args: - username: The username to login to Robinhood. - password: The password to login to Robinhood. - challenge_type: Either sms or email. (only if not using mfa) - headers: Any optional header dict modifications for the session. - proxies: Any optional proxy dict modification for the session. - - Attributes: - session: A requests session instance. - expires_at: The time the oauth token will expire at, default is - 1970-01-01 00:00:00. - certs: The path to the desired certs to check against. - device_token: A random guid representing the current device. - access_token: An oauth2 token to connect to the Robinhood API. - refresh_token: An oauth2 refresh token to refresh the access_token when - required. - username: The username to login to Robinhood. - password: The password to login to Robinhood. - challenge_type: Either sms or email. (only if not using mfa) - headers: Any optional header dict modifications for the session. - proxies: Any optional proxy dict modification for the session. - - """ - - def __init__( - self, - username: Optional[str] = None, - password: Optional[str] = None, - challenge_type: Optional[str] = "email", - headers: Optional[CaseInsensitiveDict] = None, - proxies: Optional[Dict] = None, - ) -> None: - self.session: requests.Session = requests.session() - self.session.headers = HEADERS if headers is None else headers - self.session.proxies = getproxies() if proxies is None else proxies - self.expires_at = datetime.strptime("1970", "%Y") # some time in the past - self.certs: Path = CERTS_PATH - - self.username: Optional[str] = username - self.password: Optional[str] = password - if challenge_type not in ["email", "sms"]: - raise ValueError("challenge_type must be email or sms") - self.challenge_type: str = challenge_type - self.device_token: str = str(uuid.uuid4()) - - self.access_token: Optional[str] = None - self.refresh_token: Optional[str] = None - - def to_json(self, path: Optional[Union[Path, str]] = None) -> None: - """Save the current session parameters to a json file. - - Note: - This function defaults to caching this information to - ~/.robinhood/login.json - - Args: - path: The location to save the file and its name. - - """ - path = CACHE_LOGIN if path is None else path - data = deepcopy(self.__dict__) - data.pop("session") - data.pop("certs") - data["expires_at"] = self.expires_at.strftime("%Y-%m-%d %H:%M:%S") - - with open(path, "w+") as file: - file.write(json.dumps(data, indent=4, default=str)) - - def from_json(self, path: Optional[Union[Path, str]] = None) -> None: - """Load cached session parameters from a json file. - - Note: - This function defaults to caching this information to - ~/.robinhood/login.json - - Args: - path: The location and file name to load from. - - """ - path = path or CACHE_LOGIN - with open(path) as file: - data = json.load(file) - - for k, v in data.items(): - if k == "expires_at": - v = datetime.strptime(v, "%Y-%m-%d %H:%M:%S") - setattr(self, k, v) - - if self.access_token is not None: - self.session.headers.update( - {"Authorization": f"Bearer {self.access_token}"} - ) - - @property - def login_set(self) -> bool: - """Check if login info is properly configured. - - Returns: - Whether or not username and password are set. - """ - return self.password is not None and self.username is not None - - @property - def authenticated(self) -> bool: - """Check if the session is authenticated. - - Returns: - Whether or not the session is logged in. - """ - return ( - "Authorization" in self.session.headers and datetime.now() < self.expires_at - ) - - def login(self, force_refresh: bool = False) -> None: - """Login to the session. - - This method logs the user in if they are not already and otherwise refreshes - the oauth token if it is expired. - - Args: - force_refresh: If already logged in, whether or not to force a oauth token - refresh. - - """ - if "Authorization" not in self.session.headers: - self._login_oauth2() - elif self.refresh_token is not None and ( - self.expires_at < datetime.now() or force_refresh - ): - self._refresh_oauth2() - - def get( - self, - url: str, - params: dict = None, - headers: Optional[CaseInsensitiveDict] = None, - raise_errors: bool = True, - auto_login: bool = True, - ) -> Dict: - """Run a wrapped session HTTP GET request. - - Note: - This method automatically prompts the user to log in if not already logged - in. - - Args: - url: The url to get from. - params: query string parameters - headers: A dict adding to and overriding the session headers. - raise_errors: Whether or not raise errors on GET request result. - auto_login: Whether or not to automatically login on restricted endpoint - errors. - - Returns: - The POST response - - """ - if params is None: - params = {} - res = self.session.get( - url, params=params, headers={} if headers is None else headers - ) - if res.status_code == 401 and auto_login: - self.login(force_refresh=True) - res = self.session.get( - url, params=params, headers={} if headers is None else headers - ) - if raise_errors: - res.raise_for_status() - - return res.json() - - def post( - self, - url: str, - data: Optional[Dict] = None, - headers: Optional[CaseInsensitiveDict] = None, - raise_errors: bool = True, - auto_login: bool = True, - ) -> Dict: - """Run a wrapped session HTTP POST request. - - Note: - This method automatically prompts the user to log in if not already logged - in. - - Args: - url: The url to post to. - data: The payload to POST to the endpoint. - headers: A dict adding to and overriding the session headers. - raise_errors: Whether or not raise errors on POST request. - auto_login: Whether or not to automatically login on restricted endpoint - errors. - - Returns: - The response or an empty dict if an empty response is returned. - - """ - res = self.session.post( - url, - data=data, - timeout=15, - verify=self.certs, - headers={} if headers is None else headers, - ) - if (res.status_code == 401) and auto_login: - self.login(force_refresh=True) - res = self.session.post( - url, - data=data, - timeout=15, - verify=self.certs, - headers={} if headers is None else headers, - ) - if raise_errors: - res.raise_for_status() - if res.headers.get("Content-Length", None) == "0": - return {} - else: - return res.json() - - def _process_auth_body(self, res: Dict) -> None: - """Process an authentication response dictionary. - - This method updates the internal state of the session based on a login or - token refresh request. - - Args: - res: A response dictionary from a login request. - - Raises: - AuthenticationError: If the input dictionary is malformed. - - """ - try: - self.access_token = res["access_token"] - self.refresh_token = res["refresh_token"] - self.expires_at = datetime.now() + timedelta(seconds=EXPIRATION_TIME) - self.session.headers.update( - {"Authorization": f"Bearer {self.access_token}"} - ) - except KeyError: - raise AuthenticationError( - "Authorization result body missing required responses." - ) - - def _login_oauth2(self) -> None: - """Create a new oauth2 token. - - Raises: - AuthenticationError: If the login credentials are not set, if a challenge - wasn't accepted, or if an mfa code is not accepted. - - """ - if not self.login_set: - raise AuthenticationError( - "Username and password must be passed to constructor or must be loaded " - "from json" - ) - self.session.headers.pop("Authorization", None) - - oauth_payload = { - "password": self.password, - "username": self.username, - "grant_type": "password", - "client_id": CLIENT_ID, - "expires_in": EXPIRATION_TIME, - "scope": "internal", - "device_token": self.device_token, - "challenge_type": self.challenge_type, - } - - res = self.post( - OAUTH_TOKEN_URL, data=oauth_payload, raise_errors=False, auto_login=False - ) - - if res is None or "error" in res: - raise AuthenticationError("Unknown login error") - elif "detail" in res and any(k in res["detail"] for k in ["Invalid", "Unable"]): - raise AuthenticationError(f"{res['detail']}") - elif "challenge" in res: - challenge_id = res["challenge"]["id"] - # TODO: use api module - challenge_url = ( - f"https://api.robinhood.com/challenge/{challenge_id}/respond/" - ) - print(f"Input challenge code from {self.challenge_type.capitalize()}:") - challenge_code = input() - challenge_payload = {"response": str(challenge_code)} - challenge_header = CaseInsensitiveDict( - {"X-ROBINHOOD-CHALLENGE-RESPONSE-ID": challenge_id} - ) - challenge = self.post( - challenge_url, - data=challenge_payload, - raise_errors=False, - headers=challenge_header, - auto_login=False, - ) - if challenge is None or challenge.get("status", "") != "validated": - raise AuthenticationError("Challenge response was not accepted") - res = self.post( - OAUTH_TOKEN_URL, - data=oauth_payload, - headers=challenge_header, - auto_login=False, - ) - elif "mfa_required" in res: - print(f"Input mfa code:") - mfa_code = input() - oauth_payload["mfa_code"] = mfa_code - res = self.post( - OAUTH_TOKEN_URL, - data=oauth_payload, - raise_errors=False, - auto_login=False, - ) - if res is None or (res.get("detail", "") == "Please enter a valid code."): - raise AuthenticationError("Mfa code was not accepted") - - self._process_auth_body(res) - - def _refresh_oauth2(self) -> None: - """Refresh an oauth2 token. - - Raises: - AuthenticationError: If refresh_token is missing or if there is an error - when trying to refresh a token. - - """ - if self.refresh_token is None: - raise AuthenticationError("Cannot refresh login with unset refresh token") - relogin_payload = { - "grant_type": "refresh_token", - "refresh_token": self.refresh_token, - "scope": "internal", - "client_id": CLIENT_ID, - "expires_in": EXPIRATION_TIME, - } - self.session.headers.pop("Authorization", None) - res = self.post( - OAUTH_TOKEN_URL, data=relogin_payload, auto_login=False, raise_errors=False - ) - if "error" in res: - raise AuthenticationError("Failed to refresh token") - self._process_auth_body(res) - - def logout(self) -> None: - """Logout from the session. - - Raises: - AuthenticationError: If there is an error when logging out. - - """ - logout_payload = { - "client_id": CLIENT_ID, - "token": self.refresh_token, - } - res = self.post( - OAUTH_REVOKE_URL, data=logout_payload, raise_errors=False, auto_login=False - ) - if len(res) == 0: - self.access_token = None - self.refresh_token = None - else: - raise AuthenticationError("Could not log out") - - def __repr__(self) -> str: - """Return the object as a string. - - Returns: - The string representation of the object. - - """ - return f"SessionManager<{self.username}>" diff --git a/setup.cfg b/setup.cfg index c0593781..2c11d28c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,8 +16,16 @@ exclude_lines = if __name__ == .__main__.: [mypy] -ignore_missing_imports = True -disallow_untyped_calls = True +strict = True +disallow_untyped_decorators = False + +# TODO: Remove ignored typing errors +[mypy-*.robinhood] +ignore_errors = True +[mypy-*.endpoints] +ignore_errors = True +[mypy-tests.*] +ignore_errors = True # flake8 [flake8] @@ -28,6 +36,7 @@ ignore = W503 # Line break occurred after a binary operator (opposite of W504) D107 # Missing docstring in __init__ D301 # Use r""" if any backslashes in a docstring + D106 # Nested class docs (marshmallow) max-complexity = 12 # TODO: remove docstring exemptions after refactor per-file-ignores = diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_base.py b/tests/test_base.py new file mode 100644 index 00000000..d4df403c --- /dev/null +++ b/tests/test_base.py @@ -0,0 +1,44 @@ +"""Test base model file.""" + + +def test_base_model_simplenamespace_simple(): + from pyrh.models.base import BaseModel, UnknownModel + + payload1 = {"a": 10, "b": 15, "c": 20} + payload2 = {"nested": {"b": 15, "c": 20}} + payload3 = {"list": [10, {"a": 5, "b": 10, "c": 20}]} + + bm1 = BaseModel(**payload1) + for k, v in payload1.items(): + assert getattr(bm1, k) == v + + bm2 = BaseModel(**payload2) + assert bm2.nested == UnknownModel(**payload2["nested"]) + + bm3 = BaseModel(**payload3) + assert bm3.list == [10, UnknownModel(**payload3["list"][1])] + + +def test_base_model_repr(): + from pyrh.models.base import BaseModel + + bm = BaseModel(a=10) + + assert "BaseModel(a=10)" == str(bm) + + +def test_base_model_len(): + from pyrh.models.base import BaseModel + + bm = BaseModel(a=10, b=20) + + assert len(bm) == 2 + + +def test_base_schema(): + from pyrh.models.base import UnknownModel, BaseSchema + + bm = UnknownModel(a=10) + load_bm = BaseSchema().load({"a": 10}) + assert bm == load_bm + assert type(bm) == type(load_bm) diff --git a/tests/test_oauth.py b/tests/test_oauth.py new file mode 100644 index 00000000..503e96b5 --- /dev/null +++ b/tests/test_oauth.py @@ -0,0 +1,43 @@ +"""Test the oauth classes.""" + +from freezegun import freeze_time + + +@freeze_time("2020-01-01") +def test_challenge_can_retry(): + from copy import copy + + from pyrh.models.oauth import Challenge + from datetime import datetime, timedelta + import pytz + + future = datetime.strptime("2020-01-02", "%Y-%m-%d").replace(tzinfo=pytz.UTC) + + data = {"expires_at": future} + + challenge = Challenge(**data) + + assert not challenge.can_retry + + challenge.remaining_attempts = 1 + assert challenge.can_retry + + challenge.expires_at = future - timedelta(days=3) + + assert not challenge.can_retry + + +def test_oauth_test_attrs(): + from pyrh.models.oauth import OAuth + from pyrh.models.base import UnknownModel + + oa = OAuth() + oa.challenge = UnknownModel(a="test") + assert oa.is_challenge + + oa.mfa_required = UnknownModel(a="test") + assert oa.is_mfa + + oa.access_token = "some-token" + oa.refresh_token = "other-token" + assert oa.is_valid diff --git a/tests/test_sessionmanager.py b/tests/test_sessionmanager.py index 787765bc..9bae341c 100644 --- a/tests/test_sessionmanager.py +++ b/tests/test_sessionmanager.py @@ -4,11 +4,17 @@ import pytest import requests_mock +from freezegun import freeze_time + + +MOCK_URL = "mock://test.com" + +# TODO: refactor this to remove internal method testing and only test the public methods @pytest.fixture def sm(): - from pyrh.sessionmanager import SessionManager + from pyrh.models import SessionManager sample_user = { "username": "user@example.com", @@ -18,94 +24,95 @@ def sm(): return SessionManager(**sample_user) -def test_repr(sm): - assert str(sm) == "SessionManager" - - -def test_bad_challenge_type(): - from pyrh.sessionmanager import SessionManager - - with pytest.raises(ValueError) as e: - SessionManager(challenge_type="bad") - - assert "challenge_type must be" in str(e.value) - +@pytest.fixture +def sm_adap(monkeypatch): + from pyrh.models import SessionManager -def test_no_user_pass_oauth2(monkeypatch): - from pyrh.sessionmanager import SessionManager - from pyrh.exceptions import AuthenticationError + sample_user = { + "username": "user@example.com", + "password": "some password", + } - def ret_false(): - return False + monkeypatch.setattr("pyrh.endpoints.OAUTH", MOCK_URL) + monkeypatch.setattr("pyrh.endpoints.OAUTH_REVOKE", MOCK_URL) + monkeypatch.setattr("pyrh.endpoints.CHALLENGE", lambda x: MOCK_URL) - with pytest.raises(AuthenticationError) as e: - sm = SessionManager() - monkeypatch.setattr(sm, "from_json", ret_false) - sm._login_oauth2() + session_manager = SessionManager(**sample_user) + adapter = requests_mock.Adapter() + session_manager.session.mount("mock", adapter) - assert "Username and password must be" in str(e.value) + return session_manager, adapter -def test_login_oauth2_err(monkeypatch, sm): - from pyrh.exceptions import AuthenticationError +def test_repr(sm): + assert str(sm) == "SessionManager" - def err_dict(*args, **kwargs): - return {"error": "Some error"} - def none_resp(*args, **kwargs): - return None +def test_bad_challenge_type(sm): + from pyrh.models import SessionManager - monkeypatch.setattr(sm, "post", err_dict) - with pytest.raises(AuthenticationError) as e1: - sm._login_oauth2() + sample_user = { + "username": "user@example.com", + "password": "some password", + } - monkeypatch.setattr(sm, "post", none_resp) - with pytest.raises(AuthenticationError) as e2: - sm._login_oauth2() + with pytest.raises(ValueError) as e: + SessionManager(**sample_user, challenge_type="bad") - assert "Unknown login error" in str(e1.value) - assert "Unknown login error" in str(e2.value) + assert "challenge_type must be" in str(e.value) -def test_login_oauth2_detail(monkeypatch, sm): +def test_login_oauth2_errors(monkeypatch, sm_adap): from pyrh.exceptions import AuthenticationError - def invalid_jwt(*args, **kwargs): - return {"detail": "Invalid JWT. Signature has expired"} + sm, adapter = sm_adap - monkeypatch.setattr(sm, "post", invalid_jwt) + # Note it is not possible to get invalid results to replace + # oauth from the mfa approaches as those individual functions will error + # out themselves + + monkeypatch.setattr("pyrh.endpoints.OAUTH", MOCK_URL) + adapter.register_uri( + "POST", MOCK_URL, text='{"error": "Some error"}', status_code=400 + ) with pytest.raises(AuthenticationError) as e: sm._login_oauth2() - assert "Invalid JWT" in str(e.value) + assert "Some error" in str(e.value) + +@freeze_time("2005-01-01") +def test_login_oauth2_challenge_valid(monkeypatch, sm_adap): + import uuid + from datetime import datetime + import pytz + from pyrh.models.oauth import OAuthSchema -@mock.patch("pyrh.sessionmanager.SessionManager.post") -def test_login_oauth2_challenge_valid(post_mock, monkeypatch, sm): monkeypatch.setattr("builtins.input", lambda: "123456") - post_mock.side_effect = [ + expiry = datetime.strptime("2010", "%Y").replace(tzinfo=pytz.UTC) + responses = [ { "detail": "Request blocked, challenge issued.", "challenge": { - "id": "some_id", - "user": "some_user", + "id": str(uuid.uuid4()), + "user": str(uuid.uuid4()), "type": "email", "alternate_type": "sms", "status": "issued", "remaining_retries": 3, "remaining_attempts": 3, - "expires_at": "some_datetime", + "expires_at": expiry, }, }, { - "id": "some_id", - "user": "some_user", + "id": str(uuid.uuid4()), + "user": str(uuid.uuid4()), "type": "email", "alternate_type": "sms", "status": "validated", "remaining_retries": 0, "remaining_attempts": 0, - "expires_at": "some_datetime", + "expires_at": expiry, }, { "access_token": "some_token", @@ -117,56 +124,102 @@ def test_login_oauth2_challenge_valid(post_mock, monkeypatch, sm): "backup_code": None, }, ] + expected = [ + {"text": OAuthSchema().dumps(responses[0]), "status_code": 401}, + {"text": OAuthSchema().dumps(responses[1]), "status_code": 200}, + {"text": OAuthSchema().dumps(responses[2]), "status_code": 200}, + ] + sm, adapter = sm_adap + adapter.register_uri("POST", MOCK_URL, expected) sm._login_oauth2() - assert post_mock.call_count == 3 + assert sm.oauth.is_valid -@mock.patch("pyrh.sessionmanager.SessionManager.post") -def test_login_oauth2_challenge_invalid(post_mock, monkeypatch, sm): +@freeze_time("2005-01-01") +def test_login_oauth2_challenge_invalid(monkeypatch, sm_adap): from pyrh.exceptions import AuthenticationError + from datetime import datetime + from pyrh.models.oauth import OAuthSchema + import pytz + import uuid monkeypatch.setattr("builtins.input", lambda: "123456") - post_mock.side_effect = [ + expiry = datetime.strptime("2010", "%Y").replace(tzinfo=pytz.UTC) + responses = [ { "detail": "Request blocked, challenge issued.", "challenge": { - "id": "some_id", - "user": "some_user", + "id": str(uuid.uuid4()), + "user": str(uuid.uuid4()), "type": "email", "alternate_type": "sms", "status": "issued", "remaining_retries": 3, "remaining_attempts": 3, - "expires_at": "some_datetime", + "expires_at": expiry, }, }, { "detail": "Challenge response is invalid.", "challenge": { - "id": "some_id", - "user": "some_user", + "id": str(uuid.uuid4()), + "user": str(uuid.uuid4()), "type": "email", "alternate_type": "sms", "status": "issued", "remaining_retries": 3, "remaining_attempts": 2, - "expires_at": "some_datetime", + "expires_at": expiry, + }, + }, + { + "detail": "Challenge response is invalid.", + "challenge": { + "id": str(uuid.uuid4()), + "user": str(uuid.uuid4()), + "type": "email", + "alternate_type": "sms", + "status": "issued", + "remaining_retries": 3, + "remaining_attempts": 1, + "expires_at": expiry, + }, + }, + { + "detail": "Some message.", + "challenge": { + "id": str(uuid.uuid4()), + "user": str(uuid.uuid4()), + "type": "email", + "alternate_type": "sms", + "status": "failed", + "remaining_retries": 3, + "remaining_attempts": 0, + "expires_at": expiry, }, }, ] + expected = [ + {"text": OAuthSchema().dumps(responses[0]), "status_code": 401}, + {"text": OAuthSchema().dumps(responses[1]), "status_code": 401}, + {"text": OAuthSchema().dumps(responses[2]), "status_code": 401}, + {"text": OAuthSchema().dumps(responses[3]), "status_code": 401}, + ] + sm, adapter = sm_adap + adapter.register_uri("POST", MOCK_URL, expected) with pytest.raises(AuthenticationError) as e: sm._login_oauth2() - assert post_mock.call_count == 2 - assert "Challenge response was not accepted" in str(e.value) + assert "Exceeded available" in str(e.value) + +def test_login_oauth2_mfa_valid(monkeypatch, sm_adap): + from pyrh.models.oauth import OAuthSchema -@mock.patch("pyrh.sessionmanager.SessionManager.post") -def test_login_oauth2_mfa_valid(post_mock, monkeypatch, sm): mfa_code = "123456" monkeypatch.setattr("builtins.input", lambda: mfa_code) - post_mock.side_effect = [ + responses = [ {"mfa_required": True, "mfa_type": "app"}, { "access_token": "some_token", @@ -178,39 +231,46 @@ def test_login_oauth2_mfa_valid(post_mock, monkeypatch, sm): "backup_code": None, }, ] + expected = [ + {"text": OAuthSchema().dumps(responses[0]), "status_code": 200}, + {"text": OAuthSchema().dumps(responses[1]), "status_code": 200}, + ] + sm, adapter = sm_adap + adapter.register_uri("POST", MOCK_URL, expected) sm._login_oauth2() - assert post_mock.call_count == 2 + assert sm.oauth.is_valid -@mock.patch("pyrh.sessionmanager.SessionManager.post") -def test_login_oauth2_mfa_invalid(post_mock, monkeypatch, sm): +def test_login_oauth2_mfa_invalid(monkeypatch, sm_adap): from pyrh.exceptions import AuthenticationError + from pyrh.models.oauth import OAuthSchema monkeypatch.setattr("builtins.input", lambda: "123456") - post_mock.side_effect = [ + responses = [ {"mfa_required": True, "mfa_type": "app"}, - {"detail": "Please enter a valid code."}, + {"detail": "Please enter a valid code"}, + {"detail": "Please enter a valid code"}, + {"detail": "Please enter a valid code"}, + ] + expected = [ + {"text": OAuthSchema().dumps(responses[0]), "status_code": 200}, + {"text": OAuthSchema().dumps(responses[1]), "status_code": 401}, + {"text": OAuthSchema().dumps(responses[2]), "status_code": 401}, + {"text": OAuthSchema().dumps(responses[3]), "status_code": 401}, ] + sm, adapter = sm_adap + adapter.register_uri("POST", MOCK_URL, expected) with pytest.raises(AuthenticationError) as e: sm._login_oauth2() - assert post_mock.call_count == 2 - assert "Mfa code was not accepted" in str(e.value) - - -def test_process_auth_body_invalid(sm): - from pyrh.exceptions import AuthenticationError - - with pytest.raises(AuthenticationError) as e: - sm._process_auth_body({}) + assert "Too many incorrect" in str(e.value) - assert "missing required responses" in str(e.value) +def test_refresh_oauth2_success(sm_adap): + from pyrh.models.oauth import OAuthSchema -@mock.patch("pyrh.sessionmanager.SessionManager.post") -def test_refresh_oauth2_success(post_mock, sm): - post_mock.return_value = { + response = { "access_token": "some_token", "expires_in": 86400, "token_type": "Bearer", @@ -219,124 +279,132 @@ def test_refresh_oauth2_success(post_mock, sm): "mfa_code": None, "backup_code": None, } + sm, adapter = sm_adap + sm.oauth.access_token = "some_token" + sm.oauth.refresh_token = "some_refresh_token" + adapter.register_uri( + "POST", MOCK_URL, text=OAuthSchema().dumps(response), status_code=200 + ) sm.refresh_token = "some_token" sm.session.headers["Authorization"] = "Bearer some_token" sm._refresh_oauth2() - assert post_mock.call_count == 1 + assert sm.oauth.is_valid -@mock.patch("pyrh.sessionmanager.SessionManager.post") -def test_refresh_oauth2_failure(post_mock, sm): +def test_refresh_oauth2_failure(sm_adap): from pyrh.exceptions import AuthenticationError + from pyrh.models.oauth import OAuthSchema - post_mock.return_value = {"error": "some_error"} - - with pytest.raises(AuthenticationError) as e1: - sm._refresh_oauth2() - - assert "Cannot refresh login" in str(e1.value) + response = {"error": "some_error"} + sm, adapter = sm_adap + sm.oauth.access_token = "some_token" + sm.oauth.refresh_token = "some_refresh_token" + adapter.register_uri( + "POST", MOCK_URL, text=OAuthSchema().dumps(response), status_code=401 + ) sm.refresh_token = "some_token" sm.session.headers["Authorization"] = "Bearer some_token" - - with pytest.raises(AuthenticationError) as e2: + with pytest.raises(AuthenticationError) as e: sm._refresh_oauth2() - assert "Failed to refresh" in str(e2.value) + assert "Failed to refresh" in str(e.value) -@mock.patch("pyrh.sessionmanager.SessionManager._login_oauth2") +@mock.patch("pyrh.models.SessionManager._login_oauth2") def test_login_init(login_mock, sm): sm.login() assert login_mock.call_count == 1 -@mock.patch("pyrh.sessionmanager.SessionManager._refresh_oauth2") -def test_login_refresh_default(refresh_mock, monkeypatch, sm): +@mock.patch("pyrh.models.SessionManager._refresh_oauth2") +def test_login_refresh_default(refresh_mock, sm): # default expires_at is 1970 - monkeypatch.setattr(sm, "refresh_token", "some_token") + sm.oauth.access_token = "some_token" + sm.oauth.refresh_token = "some_refresh_token" sm.session.headers["Authorization"] = "Bearer some_token" sm.login() assert refresh_mock.call_count == 1 -@mock.patch("pyrh.sessionmanager.SessionManager._refresh_oauth2") -def test_login_refresh_force(refresh_mock, monkeypatch, sm): - monkeypatch.setattr(sm, "refresh_token", "some_token") +@mock.patch("pyrh.models.SessionManager._refresh_oauth2") +def test_login_refresh_force(refresh_mock, sm): + sm.oauth.access_token = "some_token" + sm.oauth.refresh_token = "some_refresh_token" sm.session.headers["Authorization"] = "Bearer some_token" sm.login(force_refresh=True) assert refresh_mock.call_count == 1 -@mock.patch("pyrh.sessionmanager.SessionManager.post") +@mock.patch("pyrh.models.SessionManager.post") def test_logout_success(post_mock, sm): post_mock.return_value = {} - sm.access_token = "some_token" - sm.refresh_token = "some_refresh_token" + sm.oauth.access_token = "some_token" + sm.oauth.refresh_token = "some_refresh_token" sm.logout() - assert sm.access_token is None - assert sm.refresh_token is None + assert len(sm.oauth) == 0 assert post_mock.call_count == 1 -@mock.patch("pyrh.sessionmanager.SessionManager.post") +@mock.patch("pyrh.models.SessionManager.post") def test_logout_failure(post_mock, sm): from pyrh.exceptions import AuthenticationError + from requests.exceptions import HTTPError + + def raise_error(*args, **kwargs): + raise HTTPError - post_mock.return_value = {"error": "some_error"} - sm.access_token = "some_token" - sm.refresh_token = "some_refresh_token" + post_mock.side_effect = raise_error + sm.oauth.access_token = "some_token" + sm.oauth.refresh_token = "some_refresh_token" with pytest.raises(AuthenticationError) as e: sm.logout() - assert sm.access_token == "some_token" - assert sm.refresh_token == "some_refresh_token" + assert sm.oauth.access_token == "some_token" + assert sm.oauth.refresh_token == "some_refresh_token" assert post_mock.call_count == 1 assert "Could not log out" == str(e.value) def test_jsonify(tmpdir, sm): - from copy import deepcopy import json - sm.access_token = "some_token" - sm.refresh_token = "some_refresh_token" + from pyrh.models import dump_session, load_session + + sm.oauth.access_token = "some_token" + sm.oauth.refresh_token = "some_refresh_token" file = tmpdir.join("login.json") file.ensure(file=True) # this will likely migrate to pathlib at some point with pytest.raises(json.JSONDecodeError) as e: - sm.from_json(file) + load_session(file) assert "Expecting value" in str(e.value) - data = deepcopy(sm.__dict__) - data.pop("session") - data.pop("certs") - sm.to_json(file) - - sm.from_json(file) - data2 = deepcopy(sm.__dict__) - data2.pop("session") - data2.pop("certs") + dump_session(sm, file) + sm1 = load_session(file) - assert data == data2 + # TODO: make this test a bit more robust + assert sm.oauth == sm1.oauth -@mock.patch("pyrh.sessionmanager.datetime") -def test_authenticated(dt_mock, sm, monkeypatch): +@freeze_time("2000-01-01") +def test_authenticated(sm, monkeypatch): + import pytz from datetime import datetime as dt, timedelta - from pyrh.sessionmanager import EXPIRATION_TIME + from pyrh.models.sessionmanager import EXPIRATION_TIME - dt_mock.now.return_value = dt.strptime("2000", "%Y") assert not sm.authenticated - expires_at_time = dt_mock.now() + timedelta(seconds=EXPIRATION_TIME) + expires_at_time = dt.now().replace(tzinfo=pytz.UTC) + timedelta( + seconds=EXPIRATION_TIME + ) sm.session.headers["Authorization"] = "Bearer some_token" sm.expires_at = expires_at_time @@ -344,7 +412,7 @@ def test_authenticated(dt_mock, sm, monkeypatch): assert sm.authenticated -@mock.patch("pyrh.sessionmanager.SessionManager.login") +@mock.patch("pyrh.models.SessionManager.login") def test_get(mock_login, sm): import json @@ -373,7 +441,7 @@ def test_get(mock_login, sm): assert "404 Client Error" in str(e.value) -@mock.patch("pyrh.sessionmanager.SessionManager.login") +@mock.patch("pyrh.models.SessionManager.login") def test_post(mock_login, sm): import json @@ -383,7 +451,6 @@ def test_post(mock_login, sm): sm.session.mount("mock", adapter) mock_url = "mock://test.com" expected = [ - {"text": "", "status_code": 200, "headers": {"Content-Length": "0"}}, {"text": '{"error": "login error"}', "status_code": 401}, {"text": '{"test": "321"}', "status_code": 200}, {"text": '{"error": "resource not found"}', "status_code": 404}, @@ -391,12 +458,32 @@ def test_post(mock_login, sm): adapter.register_uri("POST", mock_url, expected) resp1 = sm.post(mock_url) - resp2 = sm.post(mock_url) with pytest.raises(HTTPError) as e: sm.post(mock_url) - assert resp1 == {} - assert resp2 == json.loads(expected[2]["text"]) + assert resp1 == json.loads(expected[1]["text"]) assert mock_login.call_count == 1 assert "404 Client Error" in str(e.value) + + +@freeze_time("2020-01-01") +def test_token_expired(sm): + from datetime import datetime + import pytz + + # Assumes default token expired is set to 1970 + assert sm.token_expired + + sm.expires_at = datetime.strptime("2020-01-03", "%Y-%m-%d").replace(tzinfo=pytz.UTC) + + assert not sm.token_expired + + +def test_login_set(sm): + assert sm.login_set + + sm.username = None + sm.password = None + + assert not sm.login_set