From 06c615097cf30c2263d3de7b22bafccb8f732670 Mon Sep 17 00:00:00 2001 From: Adithya Balaji Date: Fri, 27 Mar 2020 00:16:22 -0400 Subject: [PATCH 1/9] Add yarl URL processing to endpoints file. --- poetry.lock | 60 ++++++++++++++++++++++++++++++++++++++- pyproject.toml | 3 +- pyrh/endpoints.py | 70 +++++++++++++++++++++++++--------------------- pyrh/exceptions.py | 16 ++++------- 4 files changed, 104 insertions(+), 45 deletions(-) diff --git a/poetry.lock b/poetry.lock index ce5623a2..53b8404c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -338,6 +338,14 @@ optional = false python-versions = ">=3.5" version = "8.2.0" +[[package]] +category = "main" +description = "multidict implementation" +name = "multidict" +optional = false +python-versions = ">=3.5" +version = "4.7.5" + [[package]] category = "dev" description = "Optional static typing for Python" @@ -870,6 +878,18 @@ all = ["six", "pytest", "pytest-cov", "codecov", "pygments", "colorama"] optional = ["pygments", "colorama"] tests = ["pytest", "pytest-cov", "codecov"] +[[package]] +category = "main" +description = "Yet another URL library" +name = "yarl" +optional = false +python-versions = ">=3.5" +version = "1.4.2" + +[package.dependencies] +idna = ">=2.0" +multidict = ">=4.0" + [[package]] category = "dev" description = "Backport of pathlib-compatible object wrapper for zip files" @@ -887,7 +907,7 @@ testing = ["jaraco.itertools", "func-timeout"] doc = ["sphinx", "sphinx-autodoc-typehints", "sphinx_rtd_theme", "autodocsumm"] [metadata] -content-hash = "d2d90bc3ec5835c477086c262ede51097246c76b0b5968e524bb53a696bfa003" +content-hash = "b2a59b49a59fd8e3bb516adb697945577de79229aed8083ce0550fdd20b743d3" python-versions = "^3.6" [metadata.files] @@ -1074,6 +1094,25 @@ more-itertools = [ {file = "more-itertools-8.2.0.tar.gz", hash = "sha256:b1ddb932186d8a6ac451e1d95844b382f55e12686d51ca0c68b6f61f2ab7a507"}, {file = "more_itertools-8.2.0-py3-none-any.whl", hash = "sha256:5dd8bcf33e5f9513ffa06d5ad33d78f31e1931ac9a18f33d37e77a180d393a7c"}, ] +multidict = [ + {file = "multidict-4.7.5-cp35-cp35m-macosx_10_13_x86_64.whl", hash = "sha256:fc3b4adc2ee8474cb3cd2a155305d5f8eda0a9c91320f83e55748e1fcb68f8e3"}, + {file = "multidict-4.7.5-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:42f56542166040b4474c0c608ed051732033cd821126493cf25b6c276df7dd35"}, + {file = "multidict-4.7.5-cp35-cp35m-win32.whl", hash = "sha256:7774e9f6c9af3f12f296131453f7b81dabb7ebdb948483362f5afcaac8a826f1"}, + {file = "multidict-4.7.5-cp35-cp35m-win_amd64.whl", hash = "sha256:c2c37185fb0af79d5c117b8d2764f4321eeb12ba8c141a95d0aa8c2c1d0a11dd"}, + {file = "multidict-4.7.5-cp36-cp36m-macosx_10_13_x86_64.whl", hash = "sha256:e439c9a10a95cb32abd708bb8be83b2134fa93790a4fb0535ca36db3dda94d20"}, + {file = "multidict-4.7.5-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:85cb26c38c96f76b7ff38b86c9d560dea10cf3459bb5f4caf72fc1bb932c7136"}, + {file = "multidict-4.7.5-cp36-cp36m-win32.whl", hash = "sha256:620b37c3fea181dab09267cd5a84b0f23fa043beb8bc50d8474dd9694de1fa6e"}, + {file = "multidict-4.7.5-cp36-cp36m-win_amd64.whl", hash = "sha256:6e6fef114741c4d7ca46da8449038ec8b1e880bbe68674c01ceeb1ac8a648e78"}, + {file = "multidict-4.7.5-cp37-cp37m-macosx_10_13_x86_64.whl", hash = "sha256:a326f4240123a2ac66bb163eeba99578e9d63a8654a59f4688a79198f9aa10f8"}, + {file = "multidict-4.7.5-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:dc561313279f9d05a3d0ffa89cd15ae477528ea37aa9795c4654588a3287a9ab"}, + {file = "multidict-4.7.5-cp37-cp37m-win32.whl", hash = "sha256:4b7df040fb5fe826d689204f9b544af469593fb3ff3a069a6ad3409f742f5928"}, + {file = "multidict-4.7.5-cp37-cp37m-win_amd64.whl", hash = "sha256:317f96bc0950d249e96d8d29ab556d01dd38888fbe68324f46fd834b430169f1"}, + {file = "multidict-4.7.5-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:b51249fdd2923739cd3efc95a3d6c363b67bbf779208e9f37fd5e68540d1a4d4"}, + {file = "multidict-4.7.5-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:ae402f43604e3b2bc41e8ea8b8526c7fa7139ed76b0d64fc48e28125925275b2"}, + {file = "multidict-4.7.5-cp38-cp38-win32.whl", hash = "sha256:bb519becc46275c594410c6c28a8a0adc66fe24fef154a9addea54c1adb006f5"}, + {file = "multidict-4.7.5-cp38-cp38-win_amd64.whl", hash = "sha256:544fae9261232a97102e27a926019100a9db75bec7b37feedd74b3aa82f29969"}, + {file = "multidict-4.7.5.tar.gz", hash = "sha256:aee283c49601fa4c13adc64c09c978838a7e812f85377ae130a24d7198c0331e"}, +] mypy = [ {file = "mypy-0.770-cp35-cp35m-macosx_10_6_x86_64.whl", hash = "sha256:a34b577cdf6313bf24755f7a0e3f3c326d5c1f4fe7422d1d06498eb25ad0c600"}, {file = "mypy-0.770-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:86c857510a9b7c3104cf4cde1568f4921762c8f9842e987bc03ed4f160925754"}, @@ -1311,6 +1350,25 @@ xdoctest = [ {file = "xdoctest-0.11.0-py2.py3-none-any.whl", hash = "sha256:292d8c0e8de9bcab6ce3fa5009da445278ce343c744aa2ae260710f924c5242e"}, {file = "xdoctest-0.11.0.tar.gz", hash = "sha256:71951c60bb8b15fdf3c368c5636a6bbbd658ef5c2e5949e784adf5e1a1275fbd"}, ] +yarl = [ + {file = "yarl-1.4.2-cp35-cp35m-macosx_10_13_x86_64.whl", hash = "sha256:3ce3d4f7c6b69c4e4f0704b32eca8123b9c58ae91af740481aa57d7857b5e41b"}, + {file = "yarl-1.4.2-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:a4844ebb2be14768f7994f2017f70aca39d658a96c786211be5ddbe1c68794c1"}, + {file = "yarl-1.4.2-cp35-cp35m-win32.whl", hash = "sha256:d8cdee92bc930d8b09d8bd2043cedd544d9c8bd7436a77678dd602467a993080"}, + {file = "yarl-1.4.2-cp35-cp35m-win_amd64.whl", hash = "sha256:c2b509ac3d4b988ae8769901c66345425e361d518aecbe4acbfc2567e416626a"}, + {file = "yarl-1.4.2-cp36-cp36m-macosx_10_13_x86_64.whl", hash = "sha256:308b98b0c8cd1dfef1a0311dc5e38ae8f9b58349226aa0533f15a16717ad702f"}, + {file = "yarl-1.4.2-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:944494be42fa630134bf907714d40207e646fd5a94423c90d5b514f7b0713fea"}, + {file = "yarl-1.4.2-cp36-cp36m-win32.whl", hash = "sha256:5b10eb0e7f044cf0b035112446b26a3a2946bca9d7d7edb5e54a2ad2f6652abb"}, + {file = "yarl-1.4.2-cp36-cp36m-win_amd64.whl", hash = "sha256:a161de7e50224e8e3de6e184707476b5a989037dcb24292b391a3d66ff158e70"}, + {file = "yarl-1.4.2-cp37-cp37m-macosx_10_13_x86_64.whl", hash = "sha256:26d7c90cb04dee1665282a5d1a998defc1a9e012fdca0f33396f81508f49696d"}, + {file = "yarl-1.4.2-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:0c2ab325d33f1b824734b3ef51d4d54a54e0e7a23d13b86974507602334c2cce"}, + {file = "yarl-1.4.2-cp37-cp37m-win32.whl", hash = "sha256:e15199cdb423316e15f108f51249e44eb156ae5dba232cb73be555324a1d49c2"}, + {file = "yarl-1.4.2-cp37-cp37m-win_amd64.whl", hash = "sha256:2098a4b4b9d75ee352807a95cdf5f10180db903bc5b7270715c6bbe2551f64ce"}, + {file = "yarl-1.4.2-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:c9959d49a77b0e07559e579f38b2f3711c2b8716b8410b320bf9713013215a1b"}, + {file = "yarl-1.4.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:25e66e5e2007c7a39541ca13b559cd8ebc2ad8fe00ea94a2aad28a9b1e44e5ae"}, + {file = "yarl-1.4.2-cp38-cp38-win32.whl", hash = "sha256:6faa19d3824c21bcbfdfce5171e193c8b4ddafdf0ac3f129ccf0cdfcb083e462"}, + {file = "yarl-1.4.2-cp38-cp38-win_amd64.whl", hash = "sha256:0ca2f395591bbd85ddd50a82eb1fde9c1066fafe888c5c7cc1d810cf03fd3cc6"}, + {file = "yarl-1.4.2.tar.gz", hash = "sha256:58cd9c469eced558cd81aa3f484b2924e8897049e06889e8ff2510435b7ef74b"}, +] zipp = [ {file = "zipp-3.1.0-py3-none-any.whl", hash = "sha256:aa36550ff0c0b7ef7fa639055d797116ee891440eac1a56f378e2d3179e0320b"}, {file = "zipp-3.1.0.tar.gz", hash = "sha256:c599e4d75c98f6798c509911d08a22e6c021d074469042177c8c86fb92eefd96"}, diff --git a/pyproject.toml b/pyproject.toml index d627d62b..9aa2ac8e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ sphinx-autodoc-typehints = { version = "^1.10.3", optional = true } sphinx_rtd_theme = { version = "^0.4.3", optional = true } autodocsumm = { version = "^0.1.13", optional = true } tomlkit = "^0.5.11" +yarl = "^1.4.2" [tool.poetry.dev-dependencies] # Linting @@ -76,7 +77,7 @@ exclude = ''' [tool.isort] known_first_party = 'robinhood' -known_third_party = ["dateutil", "pytest", "requests", "requests_mock"] +known_third_party = ["dateutil", "pytest", "requests", "requests_mock", "yarl"] multi_line_output = 3 lines_after_imports = 2 force_grid_wrap = 0 diff --git a/pyrh/endpoints.py b/pyrh/endpoints.py index da5fcd54..4a5c0fca 100755 --- a/pyrh/endpoints.py +++ b/pyrh/endpoints.py @@ -1,20 +1,25 @@ -BASE_API = "https://api.robinhood.com" +"""Define Robinhood endpoints.""" + +from yarl import URL + + +BASE = URL("https://api.robinhood.com") def login(): - return BASE_API + "/oauth2/token/" + return BASE.with_path("/oauth2/token/") def logout(): - return BASE_API + "/oauth2/revoke_token/" + return BASE.with_path("/oauth2/revoke_token/") def investment_profile(): - return BASE_API + "/user/investment_profile/" + return BASE.with_path("/user/investment_profile/") def accounts(): - return BASE_API + "/accounts/" + return BASE.with_path("/accounts/") def ach(option): @@ -25,20 +30,22 @@ def ach(option): * transfers """ return ( - BASE_API + "/ach/iav/auth/" if option == "iav" else BASE_API + f"/ach/{option}/" + BASE.with_path("/ach/iav/auth/") + if option == "iav" + else BASE.with_path(f"/ach/{option}/") ) def applications(): - return BASE_API + "/applications/" + return BASE.with_path("/applications/") def dividends(): - return BASE_API + "/dividends/" + return BASE.with_path("/dividends/") def edocuments(): - return BASE_API + "/documents/" + return BASE.with_path("/documents/") def instruments(instrument_id=None, option=None): @@ -46,7 +53,7 @@ def instruments(instrument_id=None, option=None): Return information about a specific instrument by providing its instrument id. Add extra options for additional information such as "popularity" """ - url = BASE_API + f"/instruments/" + url = BASE.with_path(f"/instruments/") if instrument_id is not None: url += f"{instrument_id}" if option is not None: @@ -56,83 +63,82 @@ def instruments(instrument_id=None, option=None): def margin_upgrades(): - return BASE_API + "/margin/upgrades/" + return BASE.with_path("/margin/upgrades/") def markets(): - return BASE_API + "/markets/" + return BASE.with_path("/markets/") def notifications(): - return BASE_API + "/notifications/" + return BASE.with_path("/notifications/") def orders(order_id=""): - return BASE_API + f"/orders/{order_id}" + return BASE.with_path(f"/orders/{order_id}/") def password_reset(): - return BASE_API + "/password_reset/request/" + return BASE.with_path("/password_reset/request/") def portfolios(): - return BASE_API + "/portfolios/" + return BASE.with_path("/portfolios/") def positions(): - return BASE_API + "/positions/" + return BASE.with_path("/positions/") def quotes(): - return BASE_API + "/quotes/" + return BASE.with_path("/quotes/") def historicals(): - return BASE_API + "/quotes/historicals/" + return BASE.with_path("/quotes/historicals/") def document_requests(): - return BASE_API + "/upload/document_requests/" + return BASE.with_path("/upload/document_requests/") def user(): - return BASE_API + "/user/" + return BASE.with_path("/user/") def watchlists(): - return BASE_API + "/watchlists/" + return BASE.with_path("/watchlists/") def news(stock): - return BASE_API + f"/midlands/news/{stock}/" + return BASE.with_path(f"/midlands/news/{stock}/") def fundamentals(stock): - return BASE_API + f"/fundamentals/{stock}/" + return BASE.with_path(f"/fundamentals/{stock}/") def tags(tag): """ Returns endpoint with tag concatenated. """ - return BASE_API + f"/midlands/tags/tag/{tag}/" + return BASE.with_path(f"/midlands/tags/tag/{tag}/") def chain(instrument_id): - return BASE_API + f"/options/chains/?equity_instrument_ids={instrument_id}" + return BASE.with_path(f"/options/chains/?equity_instrument_ids={instrument_id}/") def options(chain_id, dates, option_type): - return ( - BASE_API - + f"/options/instruments/?chain_id={chain_id}&expiration_dates={dates}" - + f"&state=active&tradability=tradable&type={option_type}" + return BASE.with_path( + f"/options/instruments/?chain_id={chain_id}&expiration_dates={dates}" + f"&state=active&tradability=tradable&type={option_type}" ) def market_data(option_id): - return BASE_API + f"/marketdata/options/{option_id}/" + return BASE.with_path(f"/marketdata/options/{option_id}/") def convert_token(): - return BASE_API + "/oauth2/migrate_token/" + return BASE.with_path("/oauth2/migrate_token/") diff --git a/pyrh/exceptions.py b/pyrh/exceptions.py index 6779a361..0163575e 100755 --- a/pyrh/exceptions.py +++ b/pyrh/exceptions.py @@ -1,37 +1,31 @@ """Exceptions: custom exceptions for library""" -class RobinhoodException(Exception): +class RHException(Exception): """Wrapper for custom robinhood library exceptions.""" pass -class AuthenticationError(RobinhoodException): +class AuthenticationError(RHException): """Error when trying to login to robinhood.""" pass -class LoginFailed(RobinhoodException): # TODO: Remove me - """Error when trying to login to robinhood.""" - - pass - - -class InvalidTickerSymbol(RobinhoodException): +class InvalidTickerSymbol(RHException): """When an invalid ticker (stock symbol) is given/""" pass -class InvalidInstrumentId(RobinhoodException): +class InvalidInstrumentId(RHException): """When an invalid instrument id is given/""" pass -class InvalidOptionId(RobinhoodException): +class InvalidOptionId(RHException): """When an invalid option id is given/""" pass From 64a61a5283a1cb8cb965ff6386c3ca00726e393a Mon Sep 17 00:00:00 2001 From: Adithya Balaji Date: Fri, 27 Mar 2020 10:03:03 -0400 Subject: [PATCH 2/9] Add major re-factor to use models in Authentication instead of hardcoded models. * Update example notebook * Add packages: * pytz (already dep) * marshmallow (serialization) * Move constant endpoints in session manager to endpoints.py * Add models file * Add models file with project models and schema * BaseModel + BaseSchema * OAuth + OAuthSchema * Re-factor tests --- notebooks/example.ipynb | 6 +- poetry.lock | 22 ++- pyproject.toml | 4 +- pyrh/__init__.py | 11 +- pyrh/endpoints.py | 7 +- pyrh/models.py | 107 +++++++++++ pyrh/sessionmanager.py | 362 +++++++++++++++++++---------------- tests/test_sessionmanager.py | 312 ++++++++++++++++++------------ 8 files changed, 537 insertions(+), 294 deletions(-) create mode 100644 pyrh/models.py diff --git a/notebooks/example.ipynb b/notebooks/example.ipynb index 6cebcc19..c94857dc 100644 --- a/notebooks/example.ipynb +++ b/notebooks/example.ipynb @@ -25,7 +25,7 @@ "%load_ext autoreload\n", "%autoreload 2\n", "\n", - "from pyrh import Robinhood" + "from pyrh import Robinhood, dump_session, load_session" ] }, { @@ -42,8 +42,8 @@ "# Log in to app (will prompt for two-factor)\n", "rh = Robinhood(username=\"USERNAME\", password=\"PASSWORD\")\n", "rh.login()\n", - "rh.to_json() # so you don't have to do mfa again\n", - "rh.from_json() # to load the above json cache data" + "dump_session(rh) # so you don't have to do mfa again\n", + "rh = load_session(rh) # to load the above json cache data\n" ] }, { diff --git a/poetry.lock b/poetry.lock index 53b8404c..fbbf21d9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -322,6 +322,20 @@ optional = false python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*" version = "1.1.1" +[[package]] +category = "main" +description = "A lightweight library for converting complex datatypes to and from native Python datatypes." +name = "marshmallow" +optional = false +python-versions = ">=3.5" +version = "3.5.1" + +[package.extras] +dev = ["pytest", "pytz", "simplejson", "mypy (0.761)", "flake8 (3.7.9)", "flake8-bugbear (20.1.4)", "pre-commit (>=1.20,<3.0)", "tox"] +docs = ["sphinx (2.4.3)", "sphinx-issues (1.2.0)", "alabaster (0.7.12)", "sphinx-version-warning (1.1.2)"] +lint = ["mypy (0.761)", "flake8 (3.7.9)", "flake8-bugbear (20.1.4)", "pre-commit (>=1.20,<3.0)"] +tests = ["pytest", "pytz", "simplejson"] + [[package]] category = "dev" description = "McCabe checker, plugin for flake8" @@ -567,7 +581,7 @@ six = ">=1.5" category = "main" description = "World timezone definitions, modern and historical" name = "pytz" -optional = true +optional = false python-versions = "*" version = "2019.3" @@ -907,7 +921,7 @@ testing = ["jaraco.itertools", "func-timeout"] doc = ["sphinx", "sphinx-autodoc-typehints", "sphinx_rtd_theme", "autodocsumm"] [metadata] -content-hash = "b2a59b49a59fd8e3bb516adb697945577de79229aed8083ce0550fdd20b743d3" +content-hash = "8537d5b5677e1b07cb4df18228632f57b38fb2a3acbf7bc8b69597ae62c707af" python-versions = "^3.6" [metadata.files] @@ -1086,6 +1100,10 @@ markupsafe = [ {file = "MarkupSafe-1.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:e8313f01ba26fbbe36c7be1966a7b7424942f670f38e666995b88d012765b9be"}, {file = "MarkupSafe-1.1.1.tar.gz", hash = "sha256:29872e92839765e546828bb7754a68c418d927cd064fd4708fab9fe9c8bb116b"}, ] +marshmallow = [ + {file = "marshmallow-3.5.1-py2.py3-none-any.whl", hash = "sha256:ac2e13b30165501b7d41fc0371b8df35944f5849769d136f20e2c5f6cdc6e665"}, + {file = "marshmallow-3.5.1.tar.gz", hash = "sha256:90854221bbb1498d003a0c3cc9d8390259137551917961c8b5258c64026b2f85"}, +] mccabe = [ {file = "mccabe-0.6.1-py2.py3-none-any.whl", hash = "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42"}, {file = "mccabe-0.6.1.tar.gz", hash = "sha256:dd8d182285a0fe56bace7f45b5e7d1a6ebcbf524e8f3bd87eb0f125271b8831f"}, diff --git a/pyproject.toml b/pyproject.toml index 9aa2ac8e..2b0269ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,8 @@ sphinx_rtd_theme = { version = "^0.4.3", optional = true } autodocsumm = { version = "^0.1.13", optional = true } tomlkit = "^0.5.11" yarl = "^1.4.2" +marshmallow = "^3.5.1" +pytz = "^2019.3" [tool.poetry.dev-dependencies] # Linting @@ -77,7 +79,7 @@ exclude = ''' [tool.isort] known_first_party = 'robinhood' -known_third_party = ["dateutil", "pytest", "requests", "requests_mock", "yarl"] +known_third_party = ["dateutil", "marshmallow", "pytest", "pytz", "requests", "requests_mock", "yarl"] multi_line_output = 3 lines_after_imports = 2 force_grid_wrap = 0 diff --git a/pyrh/__init__.py b/pyrh/__init__.py index d52e6a16..26bb7be3 100755 --- a/pyrh/__init__.py +++ b/pyrh/__init__.py @@ -2,7 +2,7 @@ from pyrh import exceptions from pyrh.robinhood import Robinhood -from pyrh.sessionmanager import SessionManager +from pyrh.sessionmanager import SessionManager, dump_session, load_session def _get_version() -> str: @@ -17,6 +17,13 @@ def _get_version() -> str: __version__ = _get_version() -__all__ = ["__version__", "Robinhood", "SessionManager", "exceptions"] +__all__ = [ + "__version__", + "Robinhood", + "SessionManager", + "load_session", + "dump_session", + "exceptions", +] del _get_version diff --git a/pyrh/endpoints.py b/pyrh/endpoints.py index 4a5c0fca..94656641 100755 --- a/pyrh/endpoints.py +++ b/pyrh/endpoints.py @@ -5,9 +5,10 @@ BASE = URL("https://api.robinhood.com") - -def login(): - return BASE.with_path("/oauth2/token/") +# OAuth +OAUTH: URL = BASE.with_path("/oauth2/token/") +OAUTH_REVOKE: URL = BASE.with_path("/oauth2/revoke_token/") +CHALLENGE: URL = lambda cid: BASE.with_path(f"/challenge/{cid}/respond") def logout(): diff --git a/pyrh/models.py b/pyrh/models.py new file mode 100644 index 00000000..49263640 --- /dev/null +++ b/pyrh/models.py @@ -0,0 +1,107 @@ +"""Define API models.""" + +from collections.abc import Mapping +from datetime import datetime +from types import SimpleNamespace +from typing import Any + +import pytz +from marshmallow import INCLUDE, Schema, fields, post_load, validate + + +CHALLENGE_TYPE_VAL = validate.OneOf(["email", "sms"]) +MAX_REPR_LEN = 50 + + +class BaseModel(SimpleNamespace): + """TODO.""" + + def __init__(self, **kwargs) -> None: + kwargs = { + k: UnknownModel(**v) if isinstance(v, Mapping) else v + for k, v in kwargs.items() + } + + self.__dict__.update(kwargs) + + def __repr__(self): + repr_ = super().__repr__() + if len(repr_) > MAX_REPR_LEN: + return repr_[:MAX_REPR_LEN] + " ...)" + else: + return repr_ + + def __len__(self): + return len(self.__dict__) + + +class UnknownModel(BaseModel): + """TODO.""" + + pass + + +class BaseSchema(Schema): + + __model__: Any = UnknownModel + + class Meta: + unknown = INCLUDE + + @post_load + def make_object(self, data, **kwargs): + return self.__model__(**data) + + +def lazy_model(class_name): + class_ = type(class_name, (BaseModel,), {}) + globals()[class_name] = class_ + + return class_ + + +class Challenge(BaseModel): + @property + def can_retry(self): + return self.remaining_attempts > 0 and ( + datetime.now(tz=pytz.utc) < self.expires_at + ) + + +class ChallengeSchema(BaseSchema): + __model__ = Challenge + + id = fields.UUID() + user = fields.UUID() + type = fields.Str(validate=CHALLENGE_TYPE_VAL) + alternate_type = fields.Str(validate=CHALLENGE_TYPE_VAL) + status = fields.Str(validate=validate.OneOf(["issued", "validated", "failed"])) + remaining_retries = fields.Int() + remaining_attempts = fields.Int() + expires_at = fields.AwareDateTime(default_timezone=pytz.UTC) + + +class OAuth(BaseModel): + @property + def is_challenge(self): + return hasattr(self, "challenge") + + @property + def is_mfa(self): + return hasattr(self, "mfa_required") + + @property + def is_valid(self): + return hasattr(self, "access_token") and hasattr(self, "refresh_token") + + +class OAuthSchema(BaseSchema): + __model__ = OAuth + + detail = fields.Str() + challenge = fields.Nested(ChallengeSchema) + mfa_required = fields.Boolean() + + access_token = fields.Str() + refresh_token = fields.Str() + expires_in = fields.Int() diff --git a/pyrh/sessionmanager.py b/pyrh/sessionmanager.py index d8888649..d1a2ba75 100644 --- a/pyrh/sessionmanager.py +++ b/pyrh/sessionmanager.py @@ -1,18 +1,21 @@ """Manage Robinhood Sessions.""" -import json import uuid -from copy import deepcopy from datetime import datetime, timedelta from pathlib import Path -from typing import Dict, Optional, Union +from typing import Dict, Optional, Tuple, Union from urllib.request import getproxies +import pytz import requests +from marshmallow import fields, post_load +from requests.exceptions import HTTPError from requests.structures import CaseInsensitiveDict +from pyrh import endpoints from pyrh.cache import CACHE_ROOT from pyrh.exceptions import AuthenticationError +from pyrh.models import CHALLENGE_TYPE_VAL, BaseSchema, OAuth, OAuthSchema CERTS_PATH: Path = Path(__file__).parent.joinpath("./ssl/certs.pem") @@ -25,13 +28,6 @@ """Path to login.json config file.""" CACHE_LOGIN.touch(exist_ok=True) -# TODO: put urls in an API module -OAUTH_TOKEN_URL: str = "https://api.robinhood.com/oauth2/token/" -"""Oauth token generation endpoint.""" - -OAUTH_REVOKE_URL: str = "https://api.robinhood.com/oauth2/revoke_token/" -"""Oauth revocation endpoint.""" - HEADERS: CaseInsensitiveDict = CaseInsensitiveDict( { "Accept": "*/*", @@ -45,11 +41,11 @@ ) """Headers used when performing requests with robinhood api.""" -EXPIRATION_TIME: int = 86400 +EXPIRATION_TIME: int = 10 """Default expiration time for requests.""" -class SessionManager(object): +class SessionManager: """Mange connectivity with Robinhood API. Once logged into the session, this class will manage automatic oauth token update @@ -57,8 +53,8 @@ class SessionManager(object): authentication on initialization. Example: - >>> sm = SessionManager() - >>> sm.login(username="USERNAME", password="PASSWORD") # xdoctest: +SKIP + >>> sm = SessionManager(username="USERNAME", password="PASSWORD") + >>> sm.login() # xdoctest: +SKIP >>> sm.logout() # xdoctest: +SKIP If you want to cache your session (you should) then you can use the following @@ -66,8 +62,8 @@ class SessionManager(object): manually re-enter multi-factor authentication codes. Example: - >>> sm.to_json() # xdoctest: +SKIP - >>> sm.from_json() # xdoctest: +SKIP + >>> dump_session(sm) # xdoctest: +SKIP + >>> load_session(sm) # xdoctest: +SKIP Args: username: The username to login to Robinhood. @@ -95,8 +91,8 @@ class SessionManager(object): def __init__( self, - username: Optional[str] = None, - password: Optional[str] = None, + username: str, + password: str, challenge_type: Optional[str] = "email", headers: Optional[CaseInsensitiveDict] = None, proxies: Optional[Dict] = None, @@ -104,63 +100,23 @@ def __init__( self.session: requests.Session = requests.session() self.session.headers = HEADERS if headers is None else headers self.session.proxies = getproxies() if proxies is None else proxies - self.expires_at = datetime.strptime("1970", "%Y") # some time in the past + self.expires_at = datetime.strptime("1970", "%Y").replace( + tzinfo=pytz.UTC + ) # some time in the past self.certs: Path = CERTS_PATH - self.username: Optional[str] = username - self.password: Optional[str] = password + self.username: str = username + self.password: str = password if challenge_type not in ["email", "sms"]: raise ValueError("challenge_type must be email or sms") self.challenge_type: str = challenge_type - self.device_token: str = str(uuid.uuid4()) - - self.access_token: Optional[str] = None - self.refresh_token: Optional[str] = None - def to_json(self, path: Optional[Union[Path, str]] = None) -> None: - """Save the current session parameters to a json file. - - Note: - This function defaults to caching this information to - ~/.robinhood/login.json - - Args: - path: The location to save the file and its name. - - """ - path = CACHE_LOGIN if path is None else path - data = deepcopy(self.__dict__) - data.pop("session") - data.pop("certs") - data["expires_at"] = self.expires_at.strftime("%Y-%m-%d %H:%M:%S") + self._gen_device_token: str = str(uuid.uuid4()) + self.oauth: OAuth = OAuth() - with open(path, "w+") as file: - file.write(json.dumps(data, indent=4, default=str)) - - def from_json(self, path: Optional[Union[Path, str]] = None) -> None: - """Load cached session parameters from a json file. - - Note: - This function defaults to caching this information to - ~/.robinhood/login.json - - Args: - path: The location and file name to load from. - - """ - path = path or CACHE_LOGIN - with open(path) as file: - data = json.load(file) - - for k, v in data.items(): - if k == "expires_at": - v = datetime.strptime(v, "%Y-%m-%d %H:%M:%S") - setattr(self, k, v) - - if self.access_token is not None: - self.session.headers.update( - {"Authorization": f"Bearer {self.access_token}"} - ) + @property + def token_expired(self): + return datetime.now(tz=pytz.UTC) > self.expires_at @property def login_set(self) -> bool: @@ -178,9 +134,7 @@ def authenticated(self) -> bool: Returns: Whether or not the session is logged in. """ - return ( - "Authorization" in self.session.headers and datetime.now() < self.expires_at - ) + return "Authorization" in self.session.headers and not self.token_expired def login(self, force_refresh: bool = False) -> None: """Login to the session. @@ -195,9 +149,7 @@ def login(self, force_refresh: bool = False) -> None: """ if "Authorization" not in self.session.headers: self._login_oauth2() - elif self.refresh_token is not None and ( - self.expires_at < datetime.now() or force_refresh - ): + elif self.oauth.is_valid and (self.token_expired or force_refresh): self._refresh_oauth2() def get( @@ -206,8 +158,9 @@ def get( params: dict = None, headers: Optional[CaseInsensitiveDict] = None, raise_errors: bool = True, + return_status: bool = False, auto_login: bool = True, - ) -> Dict: + ) -> Union[Tuple[Dict, int], Dict]: """Run a wrapped session HTTP GET request. Note: @@ -239,7 +192,7 @@ def get( if raise_errors: res.raise_for_status() - return res.json() + return (res.json(), res.status_code) if return_status else res.json() def post( self, @@ -247,8 +200,9 @@ def post( data: Optional[Dict] = None, headers: Optional[CaseInsensitiveDict] = None, raise_errors: bool = True, + return_status: bool = False, auto_login: bool = True, - ) -> Dict: + ) -> Union[Dict, Tuple[Dict, int]]: """Run a wrapped session HTTP POST request. Note: @@ -259,6 +213,7 @@ def post( url: The url to post to. data: The payload to POST to the endpoint. headers: A dict adding to and overriding the session headers. + return_status: Whether to include status in the response. raise_errors: Whether or not raise errors on POST request. auto_login: Whether or not to automatically login on restricted endpoint errors. @@ -286,11 +241,13 @@ def post( if raise_errors: res.raise_for_status() if res.headers.get("Content-Length", None) == "0": - return {} + ret = {} else: - return res.json() + ret = res.json() + + return (ret, res.status_code) if return_status else ret - def _process_auth_body(self, res: Dict) -> None: + def _configure_manager(self, oauth) -> None: """Process an authentication response dictionary. This method updates the internal state of the session based on a login or @@ -303,17 +260,74 @@ def _process_auth_body(self, res: Dict) -> None: AuthenticationError: If the input dictionary is malformed. """ - try: - self.access_token = res["access_token"] - self.refresh_token = res["refresh_token"] - self.expires_at = datetime.now() + timedelta(seconds=EXPIRATION_TIME) - self.session.headers.update( - {"Authorization": f"Bearer {self.access_token}"} - ) - except KeyError: - raise AuthenticationError( - "Authorization result body missing required responses." - ) + self.oauth = oauth + self.expires_at = datetime.now(tz=pytz.UTC) + timedelta( + seconds=self.oauth.expires_in + ) + self.session.headers.update( + {"Authorization": f"Bearer {self.oauth.access_token}"} + ) + + def _challenge_oauth2(self, oauth, oauth_payload) -> OAuth: + # login challenge + challenge_url = endpoints.CHALLENGE(oauth.challenge.id) + print( + f"Input challenge code from {oauth.challenge.type.capitalize()} " + f"({oauth.challenge.remaining_attempts}/" + f"{oauth.challenge.remaining_retries}):" + ) + challenge_code = input() + challenge_payload = {"response": str(challenge_code)} + challenge_header = CaseInsensitiveDict( + {"X-ROBINHOOD-CHALLENGE-RESPONSE-ID": str(oauth.challenge.id)} + ) + res, status = self.post( + challenge_url, + data=challenge_payload, + raise_errors=False, + headers=challenge_header, + auto_login=False, + return_status=True, + ) + oauth_inner = OAuthSchema().load(res) + if status == requests.codes.ok: + try: + res = self.post( + endpoints.OAUTH, + data=oauth_payload, + headers=challenge_header, + auto_login=False, + ) + except HTTPError: + raise AuthenticationError("Error in finalizing auth token") + else: + oauth = OAuthSchema().load(res) + return oauth + elif oauth_inner.is_challenge and oauth_inner.challenge.can_retry: + print("Invalid code entered") + return self._challenge_oauth2(oauth, oauth_payload) + else: + raise AuthenticationError("Exceeded available attempts or code expired") + + def _mfa_oauth2(self, oauth_payload, attempts=3) -> OAuth: + print(f"Input mfa code:") + mfa_code = input() + oauth_payload["mfa_code"] = mfa_code + res, status = self.post( + endpoints.OAUTH, + data=oauth_payload, + raise_errors=False, + auto_login=False, + return_status=True, + ) + attempts -= 1 + if (status != requests.codes.ok) and (attempts > 0): + print("Invalid mfa code") + return self._mfa_oauth2(oauth_payload, attempts) + elif status == requests.codes.ok: + return OAuthSchema().load(res) + else: + raise AuthenticationError("Too many incorrect mfa attempts") def _login_oauth2(self) -> None: """Create a new oauth2 token. @@ -323,11 +337,6 @@ def _login_oauth2(self) -> None: wasn't accepted, or if an mfa code is not accepted. """ - if not self.login_set: - raise AuthenticationError( - "Username and password must be passed to constructor or must be loaded " - "from json" - ) self.session.headers.pop("Authorization", None) oauth_payload = { @@ -337,59 +346,30 @@ def _login_oauth2(self) -> None: "client_id": CLIENT_ID, "expires_in": EXPIRATION_TIME, "scope": "internal", - "device_token": self.device_token, + "device_token": self._gen_device_token, "challenge_type": self.challenge_type, } - res = self.post( - OAUTH_TOKEN_URL, data=oauth_payload, raise_errors=False, auto_login=False + res, status = self.post( + endpoints.OAUTH, + data=oauth_payload, + raise_errors=False, + auto_login=False, + return_status=True, ) - if res is None or "error" in res: - raise AuthenticationError("Unknown login error") - elif "detail" in res and any(k in res["detail"] for k in ["Invalid", "Unable"]): - raise AuthenticationError(f"{res['detail']}") - elif "challenge" in res: - challenge_id = res["challenge"]["id"] - # TODO: use api module - challenge_url = ( - f"https://api.robinhood.com/challenge/{challenge_id}/respond/" - ) - print(f"Input challenge code from {self.challenge_type.capitalize()}:") - challenge_code = input() - challenge_payload = {"response": str(challenge_code)} - challenge_header = CaseInsensitiveDict( - {"X-ROBINHOOD-CHALLENGE-RESPONSE-ID": challenge_id} - ) - challenge = self.post( - challenge_url, - data=challenge_payload, - raise_errors=False, - headers=challenge_header, - auto_login=False, - ) - if challenge is None or challenge.get("status", "") != "validated": - raise AuthenticationError("Challenge response was not accepted") - res = self.post( - OAUTH_TOKEN_URL, - data=oauth_payload, - headers=challenge_header, - auto_login=False, - ) - elif "mfa_required" in res: - print(f"Input mfa code:") - mfa_code = input() - oauth_payload["mfa_code"] = mfa_code - res = self.post( - OAUTH_TOKEN_URL, - data=oauth_payload, - raise_errors=False, - auto_login=False, - ) - if res is None or (res.get("detail", "") == "Please enter a valid code."): - raise AuthenticationError("Mfa code was not accepted") + oauth = OAuthSchema().load(res) - self._process_auth_body(res) + if status == 401 and oauth.is_challenge: + oauth = self._challenge_oauth2(oauth, oauth_payload) + elif status == requests.codes.ok and oauth.is_mfa: + oauth = self._mfa_oauth2(oauth_payload) + + if not oauth.is_valid: + msg = f"{oauth.error}" if hasattr(oauth, "error") else "Unknown login error" + raise AuthenticationError(msg) + else: + self._configure_manager(oauth) def _refresh_oauth2(self) -> None: """Refresh an oauth2 token. @@ -399,22 +379,23 @@ def _refresh_oauth2(self) -> None: when trying to refresh a token. """ - if self.refresh_token is None: + if not self.oauth.is_valid: raise AuthenticationError("Cannot refresh login with unset refresh token") relogin_payload = { "grant_type": "refresh_token", - "refresh_token": self.refresh_token, + "refresh_token": self.oauth.refresh_token, "scope": "internal", "client_id": CLIENT_ID, "expires_in": EXPIRATION_TIME, } self.session.headers.pop("Authorization", None) - res = self.post( - OAUTH_TOKEN_URL, data=relogin_payload, auto_login=False, raise_errors=False - ) - if "error" in res: + try: + res = self.post(endpoints.OAUTH, data=relogin_payload, auto_login=False) + except HTTPError: raise AuthenticationError("Failed to refresh token") - self._process_auth_body(res) + + oauth = OAuthSchema().load(res) + self._configure_manager(oauth) def logout(self) -> None: """Logout from the session. @@ -423,17 +404,12 @@ def logout(self) -> None: AuthenticationError: If there is an error when logging out. """ - logout_payload = { - "client_id": CLIENT_ID, - "token": self.refresh_token, - } - res = self.post( - OAUTH_REVOKE_URL, data=logout_payload, raise_errors=False, auto_login=False - ) - if len(res) == 0: - self.access_token = None - self.refresh_token = None - else: + logout_payload = {"client_id": CLIENT_ID, "token": self.oauth.refresh_token} + try: + self.post(endpoints.OAUTH_REVOKE, data=logout_payload, auto_login=False) + self.oauth = OAuth() + self.session.headers.pop("Authorization", None) + except HTTPError: raise AuthenticationError("Could not log out") def __repr__(self) -> str: @@ -444,3 +420,67 @@ def __repr__(self) -> str: """ return f"SessionManager<{self.username}>" + + +class SessionManagerSchema(BaseSchema): + __model__ = SessionManager + + username = fields.Email() + password = fields.Str() + challenge_type = fields.Str(validate=CHALLENGE_TYPE_VAL) + oauth = fields.Nested(OAuthSchema) + expires_at = fields.AwareDateTime() + device_token = fields.Str() + headers = fields.Dict() + proxies = fields.Dict() + + @post_load + def make_object(self, data, **kwargs): + oauth = data.pop("oauth", None) + expires_at = data.pop("expires_at", None) + session_manager = self.__model__(**data) + + if oauth is not None and oauth.is_valid: + session_manager.oauth = oauth + session_manager.session.headers.update( + {"Authorization": f"Bearer {session_manager.oauth.access_token}"} + ) + if expires_at: + session_manager.expires_at = expires_at + + return session_manager + + +def dump_session(session_manager, path: Optional[Union[Path, str]] = None) -> None: + """Save the current session parameters to a json file. + + Note: + This function defaults to caching this information to + ~/.robinhood/login.json + + Args: + session_manager: A SessionManager instance. + path: The location to save the file and its name. + + """ + path = CACHE_LOGIN if path is None else path + json_str = SessionManagerSchema().dumps(session_manager, indent=4) + + with open(path, "w+") as file: + file.write(json_str) + + +def load_session(path: Optional[Union[Path, str]] = None) -> SessionManager: + """Load cached session parameters from a json file. + + Note: + This function defaults to caching this information to + ~/.robinhood/login.json + + Args: + path: The location and file name to load from. + + """ + path = path or CACHE_LOGIN + with open(path) as file: + return SessionManagerSchema().loads(file.read()) diff --git a/tests/test_sessionmanager.py b/tests/test_sessionmanager.py index 787765bc..b9ad0537 100644 --- a/tests/test_sessionmanager.py +++ b/tests/test_sessionmanager.py @@ -6,6 +6,9 @@ import requests_mock +MOCK_URL = "mock://test.com" + + @pytest.fixture def sm(): from pyrh.sessionmanager import SessionManager @@ -18,94 +21,97 @@ def sm(): return SessionManager(**sample_user) -def test_repr(sm): - assert str(sm) == "SessionManager" - - -def test_bad_challenge_type(): +@pytest.fixture +def sm_adap(monkeypatch): from pyrh.sessionmanager import SessionManager - with pytest.raises(ValueError) as e: - SessionManager(challenge_type="bad") - - assert "challenge_type must be" in str(e.value) - - -def test_no_user_pass_oauth2(monkeypatch): - from pyrh.sessionmanager import SessionManager - from pyrh.exceptions import AuthenticationError + sample_user = { + "username": "user@example.com", + "password": "some password", + } - def ret_false(): - return False + monkeypatch.setattr("pyrh.endpoints.OAUTH", MOCK_URL) + monkeypatch.setattr("pyrh.endpoints.OAUTH_REVOKE", MOCK_URL) + monkeypatch.setattr("pyrh.endpoints.CHALLENGE", lambda x: MOCK_URL) - with pytest.raises(AuthenticationError) as e: - sm = SessionManager() - monkeypatch.setattr(sm, "from_json", ret_false) - sm._login_oauth2() + session_manager = SessionManager(**sample_user) + adapter = requests_mock.Adapter() + session_manager.session.mount("mock", adapter) - assert "Username and password must be" in str(e.value) + return session_manager, adapter -def test_login_oauth2_err(monkeypatch, sm): - from pyrh.exceptions import AuthenticationError +def test_repr(sm): + assert str(sm) == "SessionManager" - def err_dict(*args, **kwargs): - return {"error": "Some error"} - def none_resp(*args, **kwargs): - return None +def test_bad_challenge_type(sm): + from pyrh.sessionmanager import SessionManager - monkeypatch.setattr(sm, "post", err_dict) - with pytest.raises(AuthenticationError) as e1: - sm._login_oauth2() + sample_user = { + "username": "user@example.com", + "password": "some password", + } - monkeypatch.setattr(sm, "post", none_resp) - with pytest.raises(AuthenticationError) as e2: - sm._login_oauth2() + with pytest.raises(ValueError) as e: + SessionManager(**sample_user, challenge_type="bad") - assert "Unknown login error" in str(e1.value) - assert "Unknown login error" in str(e2.value) + assert "challenge_type must be" in str(e.value) -def test_login_oauth2_detail(monkeypatch, sm): +def test_login_oauth2_errors(monkeypatch, sm_adap): from pyrh.exceptions import AuthenticationError - def invalid_jwt(*args, **kwargs): - return {"detail": "Invalid JWT. Signature has expired"} + sm, adapter = sm_adap + + # Note it is not possible to get invalid results to replace + # oauth from the mfa approaches as those individual functions will error + # out themselves - monkeypatch.setattr(sm, "post", invalid_jwt) + monkeypatch.setattr("pyrh.endpoints.OAUTH", MOCK_URL) + adapter.register_uri( + "POST", MOCK_URL, text='{"error": "Some error"}', status_code=400 + ) with pytest.raises(AuthenticationError) as e: sm._login_oauth2() - assert "Invalid JWT" in str(e.value) + assert "Some error" in str(e.value) -@mock.patch("pyrh.sessionmanager.SessionManager.post") -def test_login_oauth2_challenge_valid(post_mock, monkeypatch, sm): +@mock.patch("pyrh.sessionmanager.datetime") +def test_login_oauth2_challenge_valid(dt, monkeypatch, sm_adap): + import uuid + from datetime import datetime + import pytz + from pyrh.models import OAuthSchema + monkeypatch.setattr("builtins.input", lambda: "123456") - post_mock.side_effect = [ + now = datetime.strptime("2005", "%Y").replace(tzinfo=pytz.UTC) + expiry = datetime.strptime("2010", "%Y").replace(tzinfo=pytz.UTC) + dt.now = mock.Mock(return_value=now) + responses = [ { "detail": "Request blocked, challenge issued.", "challenge": { - "id": "some_id", - "user": "some_user", + "id": str(uuid.uuid4()), + "user": str(uuid.uuid4()), "type": "email", "alternate_type": "sms", "status": "issued", "remaining_retries": 3, "remaining_attempts": 3, - "expires_at": "some_datetime", + "expires_at": expiry, }, }, { - "id": "some_id", - "user": "some_user", + "id": str(uuid.uuid4()), + "user": str(uuid.uuid4()), "type": "email", "alternate_type": "sms", "status": "validated", "remaining_retries": 0, "remaining_attempts": 0, - "expires_at": "some_datetime", + "expires_at": expiry, }, { "access_token": "some_token", @@ -117,56 +123,104 @@ def test_login_oauth2_challenge_valid(post_mock, monkeypatch, sm): "backup_code": None, }, ] + expected = [ + {"text": OAuthSchema().dumps(responses[0]), "status_code": 401}, + {"text": OAuthSchema().dumps(responses[1]), "status_code": 200}, + {"text": OAuthSchema().dumps(responses[2]), "status_code": 200}, + ] + sm, adapter = sm_adap + adapter.register_uri("POST", MOCK_URL, expected) sm._login_oauth2() - assert post_mock.call_count == 3 + assert sm.oauth.is_valid -@mock.patch("pyrh.sessionmanager.SessionManager.post") -def test_login_oauth2_challenge_invalid(post_mock, monkeypatch, sm): +@mock.patch("pyrh.sessionmanager.datetime") +def test_login_oauth2_challenge_invalid(dt, monkeypatch, sm_adap): from pyrh.exceptions import AuthenticationError + from datetime import datetime + from pyrh.models import OAuthSchema + import pytz + import uuid monkeypatch.setattr("builtins.input", lambda: "123456") - post_mock.side_effect = [ + now = datetime.strptime("2005", "%Y").replace(tzinfo=pytz.UTC) + expiry = datetime.strptime("2010", "%Y").replace(tzinfo=pytz.UTC) + dt.now = mock.Mock(return_value=now) + responses = [ { "detail": "Request blocked, challenge issued.", "challenge": { - "id": "some_id", - "user": "some_user", + "id": str(uuid.uuid4()), + "user": str(uuid.uuid4()), "type": "email", "alternate_type": "sms", "status": "issued", "remaining_retries": 3, "remaining_attempts": 3, - "expires_at": "some_datetime", + "expires_at": expiry, }, }, { "detail": "Challenge response is invalid.", "challenge": { - "id": "some_id", - "user": "some_user", + "id": str(uuid.uuid4()), + "user": str(uuid.uuid4()), "type": "email", "alternate_type": "sms", "status": "issued", "remaining_retries": 3, "remaining_attempts": 2, - "expires_at": "some_datetime", + "expires_at": expiry, }, }, + { + "detail": "Challenge response is invalid.", + "challenge": { + "id": str(uuid.uuid4()), + "user": str(uuid.uuid4()), + "type": "email", + "alternate_type": "sms", + "status": "issued", + "remaining_retries": 3, + "remaining_attempts": 1, + "expires_at": expiry, + }, + }, + { + "detail": "Some message.", + "challenge": { + "id": str(uuid.uuid4()), + "user": str(uuid.uuid4()), + "type": "email", + "alternate_type": "sms", + "status": "failed", + "remaining_retries": 3, + "remaining_attempts": 0, + "expires_at": expiry, + }, + }, + ] + expected = [ + {"text": OAuthSchema().dumps(responses[0]), "status_code": 401}, + {"text": OAuthSchema().dumps(responses[1]), "status_code": 401}, + {"text": OAuthSchema().dumps(responses[2]), "status_code": 401}, + {"text": OAuthSchema().dumps(responses[3]), "status_code": 401}, ] + sm, adapter = sm_adap + adapter.register_uri("POST", MOCK_URL, expected) with pytest.raises(AuthenticationError) as e: sm._login_oauth2() - assert post_mock.call_count == 2 - assert "Challenge response was not accepted" in str(e.value) + assert "Exceeded available" in str(e.value) -@mock.patch("pyrh.sessionmanager.SessionManager.post") -def test_login_oauth2_mfa_valid(post_mock, monkeypatch, sm): +def test_login_oauth2_mfa_valid(monkeypatch, sm_adap): + from pyrh.models import OAuthSchema + mfa_code = "123456" monkeypatch.setattr("builtins.input", lambda: mfa_code) - post_mock.side_effect = [ + responses = [ {"mfa_required": True, "mfa_type": "app"}, { "access_token": "some_token", @@ -178,39 +232,46 @@ def test_login_oauth2_mfa_valid(post_mock, monkeypatch, sm): "backup_code": None, }, ] + expected = [ + {"text": OAuthSchema().dumps(responses[0]), "status_code": 200}, + {"text": OAuthSchema().dumps(responses[1]), "status_code": 200}, + ] + sm, adapter = sm_adap + adapter.register_uri("POST", MOCK_URL, expected) sm._login_oauth2() - assert post_mock.call_count == 2 + assert sm.oauth.is_valid -@mock.patch("pyrh.sessionmanager.SessionManager.post") -def test_login_oauth2_mfa_invalid(post_mock, monkeypatch, sm): +def test_login_oauth2_mfa_invalid(monkeypatch, sm_adap): from pyrh.exceptions import AuthenticationError + from pyrh.models import OAuthSchema monkeypatch.setattr("builtins.input", lambda: "123456") - post_mock.side_effect = [ + responses = [ {"mfa_required": True, "mfa_type": "app"}, - {"detail": "Please enter a valid code."}, + {"detail": "Please enter a valid code"}, + {"detail": "Please enter a valid code"}, + {"detail": "Please enter a valid code"}, + ] + expected = [ + {"text": OAuthSchema().dumps(responses[0]), "status_code": 200}, + {"text": OAuthSchema().dumps(responses[1]), "status_code": 401}, + {"text": OAuthSchema().dumps(responses[2]), "status_code": 401}, + {"text": OAuthSchema().dumps(responses[3]), "status_code": 401}, ] + sm, adapter = sm_adap + adapter.register_uri("POST", MOCK_URL, expected) with pytest.raises(AuthenticationError) as e: sm._login_oauth2() - assert post_mock.call_count == 2 - assert "Mfa code was not accepted" in str(e.value) - - -def test_process_auth_body_invalid(sm): - from pyrh.exceptions import AuthenticationError - - with pytest.raises(AuthenticationError) as e: - sm._process_auth_body({}) + assert "Too many incorrect" in str(e.value) - assert "missing required responses" in str(e.value) +def test_refresh_oauth2_success(sm_adap): + from pyrh.models import OAuthSchema -@mock.patch("pyrh.sessionmanager.SessionManager.post") -def test_refresh_oauth2_success(post_mock, sm): - post_mock.return_value = { + response = { "access_token": "some_token", "expires_in": 86400, "token_type": "Bearer", @@ -219,32 +280,38 @@ def test_refresh_oauth2_success(post_mock, sm): "mfa_code": None, "backup_code": None, } + sm, adapter = sm_adap + sm.oauth.access_token = "some_token" + sm.oauth.refresh_token = "some_refresh_token" + adapter.register_uri( + "POST", MOCK_URL, text=OAuthSchema().dumps(response), status_code=200 + ) sm.refresh_token = "some_token" sm.session.headers["Authorization"] = "Bearer some_token" sm._refresh_oauth2() - assert post_mock.call_count == 1 + assert sm.oauth.is_valid -@mock.patch("pyrh.sessionmanager.SessionManager.post") -def test_refresh_oauth2_failure(post_mock, sm): +def test_refresh_oauth2_failure(sm_adap): from pyrh.exceptions import AuthenticationError + from pyrh.models import OAuthSchema - post_mock.return_value = {"error": "some_error"} - - with pytest.raises(AuthenticationError) as e1: - sm._refresh_oauth2() - - assert "Cannot refresh login" in str(e1.value) + response = {"error": "some_error"} + sm, adapter = sm_adap + sm.oauth.access_token = "some_token" + sm.oauth.refresh_token = "some_refresh_token" + adapter.register_uri( + "POST", MOCK_URL, text=OAuthSchema().dumps(response), status_code=401 + ) sm.refresh_token = "some_token" sm.session.headers["Authorization"] = "Bearer some_token" - - with pytest.raises(AuthenticationError) as e2: + with pytest.raises(AuthenticationError) as e: sm._refresh_oauth2() - assert "Failed to refresh" in str(e2.value) + assert "Failed to refresh" in str(e.value) @mock.patch("pyrh.sessionmanager.SessionManager._login_oauth2") @@ -255,9 +322,10 @@ def test_login_init(login_mock, sm): @mock.patch("pyrh.sessionmanager.SessionManager._refresh_oauth2") -def test_login_refresh_default(refresh_mock, monkeypatch, sm): +def test_login_refresh_default(refresh_mock, sm): # default expires_at is 1970 - monkeypatch.setattr(sm, "refresh_token", "some_token") + sm.oauth.access_token = "some_token" + sm.oauth.refresh_token = "some_refresh_token" sm.session.headers["Authorization"] = "Bearer some_token" sm.login() @@ -265,8 +333,9 @@ def test_login_refresh_default(refresh_mock, monkeypatch, sm): @mock.patch("pyrh.sessionmanager.SessionManager._refresh_oauth2") -def test_login_refresh_force(refresh_mock, monkeypatch, sm): - monkeypatch.setattr(sm, "refresh_token", "some_token") +def test_login_refresh_force(refresh_mock, sm): + sm.oauth.access_token = "some_token" + sm.oauth.refresh_token = "some_refresh_token" sm.session.headers["Authorization"] = "Bearer some_token" sm.login(force_refresh=True) @@ -276,64 +345,63 @@ def test_login_refresh_force(refresh_mock, monkeypatch, sm): @mock.patch("pyrh.sessionmanager.SessionManager.post") def test_logout_success(post_mock, sm): post_mock.return_value = {} - sm.access_token = "some_token" - sm.refresh_token = "some_refresh_token" + sm.oauth.access_token = "some_token" + sm.oauth.refresh_token = "some_refresh_token" sm.logout() - assert sm.access_token is None - assert sm.refresh_token is None + assert len(sm.oauth) == 0 assert post_mock.call_count == 1 @mock.patch("pyrh.sessionmanager.SessionManager.post") def test_logout_failure(post_mock, sm): from pyrh.exceptions import AuthenticationError + from requests.exceptions import HTTPError + + def raise_error(*args, **kwargs): + raise HTTPError - post_mock.return_value = {"error": "some_error"} - sm.access_token = "some_token" - sm.refresh_token = "some_refresh_token" + post_mock.side_effect = raise_error + sm.oauth.access_token = "some_token" + sm.oauth.refresh_token = "some_refresh_token" with pytest.raises(AuthenticationError) as e: sm.logout() - assert sm.access_token == "some_token" - assert sm.refresh_token == "some_refresh_token" + assert sm.oauth.access_token == "some_token" + assert sm.oauth.refresh_token == "some_refresh_token" assert post_mock.call_count == 1 assert "Could not log out" == str(e.value) def test_jsonify(tmpdir, sm): - from copy import deepcopy import json - sm.access_token = "some_token" - sm.refresh_token = "some_refresh_token" + from pyrh.sessionmanager import dump_session, load_session + + sm.oauth.access_token = "some_token" + sm.oauth.refresh_token = "some_refresh_token" file = tmpdir.join("login.json") file.ensure(file=True) # this will likely migrate to pathlib at some point with pytest.raises(json.JSONDecodeError) as e: - sm.from_json(file) + load_session(file) assert "Expecting value" in str(e.value) - data = deepcopy(sm.__dict__) - data.pop("session") - data.pop("certs") - sm.to_json(file) - - sm.from_json(file) - data2 = deepcopy(sm.__dict__) - data2.pop("session") - data2.pop("certs") + dump_session(sm, file) + sm1 = load_session(file) - assert data == data2 + # TODO: make this test a bit more robust + assert sm.oauth == sm1.oauth @mock.patch("pyrh.sessionmanager.datetime") def test_authenticated(dt_mock, sm, monkeypatch): + import pytz from datetime import datetime as dt, timedelta from pyrh.sessionmanager import EXPIRATION_TIME - dt_mock.now.return_value = dt.strptime("2000", "%Y") + dt_mock.now.return_value = dt.strptime("2000", "%Y").replace(tzinfo=pytz.UTC) assert not sm.authenticated expires_at_time = dt_mock.now() + timedelta(seconds=EXPIRATION_TIME) From fcb339620d141595c0d49f0399333626e31705bc Mon Sep 17 00:00:00 2001 From: Adithya Balaji Date: Sun, 5 Apr 2020 13:40:13 -0400 Subject: [PATCH 3/9] Create models module and move models.py and sessionmanager.py * Move mypy to local execution and add poetry installation steps to the build process * Add freezegun package * Resolve typing issues in project --- .github/workflows/main.yml | 23 ++++ .pre-commit-config.yaml | 16 ++- poetry.lock | 37 +++++- pyproject.toml | 26 +++-- pyrh/__init__.py | 22 +--- pyrh/common.py | 6 + pyrh/endpoints.py | 6 +- pyrh/models.py | 107 ----------------- pyrh/models/__init__.py | 7 ++ pyrh/models/base.py | 52 +++++++++ pyrh/models/oauth.py | 61 ++++++++++ pyrh/{ => models}/sessionmanager.py | 172 +++++++++++++++++++--------- pyrh/robinhood.py | 2 +- setup.cfg | 12 +- tests/__init__.py | 0 tests/test_sessionmanager.py | 58 +++++----- 16 files changed, 376 insertions(+), 231 deletions(-) create mode 100644 pyrh/common.py delete mode 100644 pyrh/models.py create mode 100644 pyrh/models/__init__.py create mode 100644 pyrh/models/base.py create mode 100644 pyrh/models/oauth.py rename pyrh/{ => models}/sessionmanager.py (77%) create mode 100644 tests/__init__.py diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 62e2a903..dc671920 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -15,6 +15,29 @@ jobs: python-version: 3.7 - name: set PY run: echo "::set-env name=PY::$(python -c 'import hashlib, sys;print(hashlib.sha256(sys.version.encode()+sys.executable.encode()).hexdigest())')" + - name: Install Poetry + run: | + pip install --upgrade pip + curl -sSL https://raw.githubusercontent.com/sdispater/poetry/master/get-poetry.py | python + - name: Add Poetry to Path Unix + run: echo "::add-path::$HOME/.poetry/bin" + if: (${{ runner.os }} == "Linux") || (${{ runner.os }} == "macOS") + - name: Configure Poetry + run: | + poetry config virtualenvs.in-project false + poetry config virtualenvs.path ~/.virtualenvs + - name: Cache Poetry virtualenv + uses: actions/cache@v1 + id: cache-poetry + with: + path: ~/.virtualenvs + key: poetry|pre-commit|${{ matrix.os }}|${{ env.PY }}|${{ hashFiles('poetry.lock') }} + restore-keys: | + poetry-${{ hashFiles('poetry.lock') }} + - name: Install Project Dependencies (Poetry) + run: | + poetry install -vvv + if: steps.cache-poetry.outputs.cache-hit != 'true' - uses: actions/cache@v1 with: path: ~/.cache/pre-commit diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 889c8bba..8259dc39 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,9 +1,10 @@ default_language_version: - python: python3.6 + python: python3.7 repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v2.3.0 + rev: v2.5.0 hooks: + - id: check-toml - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace @@ -30,8 +31,13 @@ repos: rev: 3.7.9 hooks: - id: flake8 - additional_dependencies: ['flake8-docstrings'] -- repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.770 + additional_dependencies: ['flake8-docstrings==1.5.0', 'darglint==1.2.1'] +- repo: local hooks: - id: mypy + name: mypy + entry: mypy + language: system + types: [python] + args: ['-p=pyrh'] + pass_filenames: false diff --git a/poetry.lock b/poetry.lock index fbbf21d9..158cc1a1 100644 --- a/poetry.lock +++ b/poetry.lock @@ -146,6 +146,15 @@ version = "5.0.4" [package.extras] toml = ["toml"] +[[package]] +category = "dev" +description = "A utility for ensuring Google-style docstrings stay up to date with the source code." +marker = "python_version >= \"3.7\" and python_version < \"4.0\"" +name = "darglint" +optional = false +python-versions = ">=3.7,<4.0" +version = "1.2.1" + [[package]] category = "dev" description = "Decorators for Humans" @@ -196,6 +205,18 @@ version = "1.5.0" flake8 = ">=3" pydocstyle = ">=2.1" +[[package]] +category = "dev" +description = "Let your Python tests travel through time" +name = "freezegun" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +version = "0.3.15" + +[package.dependencies] +python-dateutil = ">=1.0,<2.0 || >2.0" +six = "*" + [[package]] category = "main" description = "Internationalized Domain Names in Applications (IDNA)" @@ -805,7 +826,7 @@ version = "0.10.0" category = "main" description = "Style preserving TOML library" name = "tomlkit" -optional = false +optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" version = "0.5.11" @@ -848,7 +869,7 @@ python-versions = "*" version = "1.4.1" [[package]] -category = "dev" +category = "main" description = "Backported and Experimental Type Hints for Python 3.5+" name = "typing-extensions" optional = false @@ -918,10 +939,10 @@ docs = ["sphinx", "jaraco.packaging (>=3.2)", "rst.linker (>=1.9)"] testing = ["jaraco.itertools", "func-timeout"] [extras] -doc = ["sphinx", "sphinx-autodoc-typehints", "sphinx_rtd_theme", "autodocsumm"] +doc = ["sphinx", "sphinx-autodoc-typehints", "sphinx_rtd_theme", "autodocsumm", "tomlkit"] [metadata] -content-hash = "8537d5b5677e1b07cb4df18228632f57b38fb2a3acbf7bc8b69597ae62c707af" +content-hash = "ea1d8069d64d9480fd2b40cd50107dafc0688606674d23ed1268dff0e1bc33c4" python-versions = "^3.6" [metadata.files] @@ -1009,6 +1030,10 @@ coverage = [ {file = "coverage-5.0.4-cp39-cp39-win_amd64.whl", hash = "sha256:4482f69e0701139d0f2c44f3c395d1d1d37abd81bfafbf9b6efbe2542679d892"}, {file = "coverage-5.0.4.tar.gz", hash = "sha256:1b60a95fc995649464e0cd48cecc8288bac5f4198f21d04b8229dc4097d76823"}, ] +darglint = [ + {file = "darglint-1.2.1-py3-none-any.whl", hash = "sha256:16ee69a67fc0f3a89917ba4028b9c50491d7cb4e569cb94eed2e013e2a574c77"}, + {file = "darglint-1.2.1.tar.gz", hash = "sha256:7fec9d38b545f49650e96c45f9c62d6ba6cc9c9f66d305a348813c1c0a6fcb02"}, +] decorator = [ {file = "decorator-4.4.2-py2.py3-none-any.whl", hash = "sha256:41fa54c2a0cc4ba648be4fd43cff00aedf5b9465c9bf18d64325bc225f08f760"}, {file = "decorator-4.4.2.tar.gz", hash = "sha256:e3a62f0520172440ca0dcc823749319382e377f37f140a0b99ef45fecb84bfe7"}, @@ -1029,6 +1054,10 @@ flake8-docstrings = [ {file = "flake8-docstrings-1.5.0.tar.gz", hash = "sha256:3d5a31c7ec6b7367ea6506a87ec293b94a0a46c0bce2bb4975b7f1d09b6f3717"}, {file = "flake8_docstrings-1.5.0-py2.py3-none-any.whl", hash = "sha256:a256ba91bc52307bef1de59e2a009c3cf61c3d0952dbe035d6ff7208940c2edc"}, ] +freezegun = [ + {file = "freezegun-0.3.15-py2.py3-none-any.whl", hash = "sha256:82c757a05b7c7ca3e176bfebd7d6779fd9139c7cb4ef969c38a28d74deef89b2"}, + {file = "freezegun-0.3.15.tar.gz", hash = "sha256:e2062f2c7f95cc276a834c22f1a17179467176b624cc6f936e8bc3be5535ad1b"}, +] idna = [ {file = "idna-2.9-py2.py3-none-any.whl", hash = "sha256:a068a21ceac8a4d63dbfd964670474107f541babbd2250d61922f029858365fa"}, {file = "idna-2.9.tar.gz", hash = "sha256:7588d1c14ae4c77d74036e8c22ff447b26d0fde8f007354fd48a7814db15b7cb"}, diff --git a/pyproject.toml b/pyproject.toml index 2b0269ea..2500fa70 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,42 +25,48 @@ include = [ [tool.poetry.dependencies] python = "^3.6" -requests = "^2.23" -python-dateutil = "^2.8" -# docs need to be distributed for readthedocs +# Doc Dependencies (need to be distributed for readthedocs) +autodocsumm = { version = "^0.1.13", optional = true } sphinx = { version = "^2.4.4", optional = true } sphinx-autodoc-typehints = { version = "^1.10.3", optional = true } sphinx_rtd_theme = { version = "^0.4.3", optional = true } -autodocsumm = { version = "^0.1.13", optional = true } -tomlkit = "^0.5.11" -yarl = "^1.4.2" +tomlkit = {version = "^0.5.11", optional = true } + +# Main Dependencies marshmallow = "^3.5.1" +python-dateutil = "^2.8" pytz = "^2019.3" +requests = "^2.23" +yarl = "^1.4.2" +typing-extensions = "^3.7.4" [tool.poetry.dev-dependencies] # Linting +# These are version locked so .pre-commit-config is even flake8 = "3.7.9" flake8-docstrings = "1.5.0" black = { version = "19.10b0", python = "^3.6" } +darglint = { version = "1.2.1", python = "^3.7"} isort = { version = "4.3.21", extras = ["pyproject"] } seed-isort-config = { version = "2.1.0", python = "^3.7" } mypy = "0.770" nbstripout = "^0.3.7" # Testing +coverage = "^5.0.4" +freezegun = "^0.3.15" pytest = "^5.4.1" +pytest-cov = "^2.8.1" pytest-mock = "^2.0.0" requests-mock = "^1.7.0" -coverage = "^5.0.4" -pytest-cov = "^2.8.1" xdoctest = "^0.11.0" # Automation towncrier = "^19.2.0" [tool.poetry.extras] -doc = ["sphinx", "sphinx-autodoc-typehints", "sphinx_rtd_theme", "autodocsumm", "toml"] +doc = ["sphinx", "sphinx-autodoc-typehints", "sphinx_rtd_theme", "autodocsumm", "tomlkit"] [tool.black] include = '\.pyi?$' @@ -79,7 +85,7 @@ exclude = ''' [tool.isort] known_first_party = 'robinhood' -known_third_party = ["dateutil", "marshmallow", "pytest", "pytz", "requests", "requests_mock", "yarl"] +known_third_party = ["dateutil", "freezegun", "marshmallow", "pytest", "pytz", "requests", "requests_mock", "typing_extensions", "yarl"] multi_line_output = 3 lines_after_imports = 2 force_grid_wrap = 0 diff --git a/pyrh/__init__.py b/pyrh/__init__.py index 26bb7be3..b4c1deee 100755 --- a/pyrh/__init__.py +++ b/pyrh/__init__.py @@ -1,29 +1,15 @@ """Export pyrh sub classes.""" -from pyrh import exceptions -from pyrh.robinhood import Robinhood -from pyrh.sessionmanager import SessionManager, dump_session, load_session +from . import exceptions +from .models import dump_session, load_session +from .robinhood import Robinhood -def _get_version() -> str: - from pathlib import Path - from tomlkit import parse - - pyproject_path = Path(__file__).resolve().parent.joinpath("../pyproject.toml") - with open(pyproject_path) as file: - pyproject = parse(file.read()) - - return str(pyproject["tool"]["poetry"]["version"]) - - -__version__ = _get_version() +__version__ = "2.0" __all__ = [ "__version__", "Robinhood", - "SessionManager", "load_session", "dump_session", "exceptions", ] - -del _get_version diff --git a/pyrh/common.py b/pyrh/common.py new file mode 100644 index 00000000..ae25e0b1 --- /dev/null +++ b/pyrh/common.py @@ -0,0 +1,6 @@ +"""Shared resources across pyrh.""" + +from typing import Any, Dict + + +JSON = Dict[str, Any] diff --git a/pyrh/endpoints.py b/pyrh/endpoints.py index 94656641..71e1902e 100755 --- a/pyrh/endpoints.py +++ b/pyrh/endpoints.py @@ -1,5 +1,7 @@ """Define Robinhood endpoints.""" +from typing import Callable + from yarl import URL @@ -8,7 +10,9 @@ # OAuth OAUTH: URL = BASE.with_path("/oauth2/token/") OAUTH_REVOKE: URL = BASE.with_path("/oauth2/revoke_token/") -CHALLENGE: URL = lambda cid: BASE.with_path(f"/challenge/{cid}/respond") +CHALLENGE: Callable[[str], URL] = lambda cid: BASE.with_path( + f"/challenge/{cid}/respond" +) def logout(): diff --git a/pyrh/models.py b/pyrh/models.py deleted file mode 100644 index 49263640..00000000 --- a/pyrh/models.py +++ /dev/null @@ -1,107 +0,0 @@ -"""Define API models.""" - -from collections.abc import Mapping -from datetime import datetime -from types import SimpleNamespace -from typing import Any - -import pytz -from marshmallow import INCLUDE, Schema, fields, post_load, validate - - -CHALLENGE_TYPE_VAL = validate.OneOf(["email", "sms"]) -MAX_REPR_LEN = 50 - - -class BaseModel(SimpleNamespace): - """TODO.""" - - def __init__(self, **kwargs) -> None: - kwargs = { - k: UnknownModel(**v) if isinstance(v, Mapping) else v - for k, v in kwargs.items() - } - - self.__dict__.update(kwargs) - - def __repr__(self): - repr_ = super().__repr__() - if len(repr_) > MAX_REPR_LEN: - return repr_[:MAX_REPR_LEN] + " ...)" - else: - return repr_ - - def __len__(self): - return len(self.__dict__) - - -class UnknownModel(BaseModel): - """TODO.""" - - pass - - -class BaseSchema(Schema): - - __model__: Any = UnknownModel - - class Meta: - unknown = INCLUDE - - @post_load - def make_object(self, data, **kwargs): - return self.__model__(**data) - - -def lazy_model(class_name): - class_ = type(class_name, (BaseModel,), {}) - globals()[class_name] = class_ - - return class_ - - -class Challenge(BaseModel): - @property - def can_retry(self): - return self.remaining_attempts > 0 and ( - datetime.now(tz=pytz.utc) < self.expires_at - ) - - -class ChallengeSchema(BaseSchema): - __model__ = Challenge - - id = fields.UUID() - user = fields.UUID() - type = fields.Str(validate=CHALLENGE_TYPE_VAL) - alternate_type = fields.Str(validate=CHALLENGE_TYPE_VAL) - status = fields.Str(validate=validate.OneOf(["issued", "validated", "failed"])) - remaining_retries = fields.Int() - remaining_attempts = fields.Int() - expires_at = fields.AwareDateTime(default_timezone=pytz.UTC) - - -class OAuth(BaseModel): - @property - def is_challenge(self): - return hasattr(self, "challenge") - - @property - def is_mfa(self): - return hasattr(self, "mfa_required") - - @property - def is_valid(self): - return hasattr(self, "access_token") and hasattr(self, "refresh_token") - - -class OAuthSchema(BaseSchema): - __model__ = OAuth - - detail = fields.Str() - challenge = fields.Nested(ChallengeSchema) - mfa_required = fields.Boolean() - - access_token = fields.Str() - refresh_token = fields.Str() - expires_in = fields.Int() diff --git a/pyrh/models/__init__.py b/pyrh/models/__init__.py new file mode 100644 index 00000000..2d3128f5 --- /dev/null +++ b/pyrh/models/__init__.py @@ -0,0 +1,7 @@ +"""pyrh models and schemas""" + +from .oauth import Challenge, OAuth +from .sessionmanager import SessionManager, dump_session, load_session + + +__all__ = ["OAuth", "Challenge", "SessionManager", "dump_session", "load_session"] diff --git a/pyrh/models/base.py b/pyrh/models/base.py new file mode 100644 index 00000000..1bf02d29 --- /dev/null +++ b/pyrh/models/base.py @@ -0,0 +1,52 @@ +"""Base Model.""" + +from types import SimpleNamespace +from typing import Any, Dict, Mapping, Tuple + +from marshmallow import INCLUDE, Schema, post_load + +from pyrh.common import JSON + + +MAX_REPR_LEN = 50 + + +class BaseModel(SimpleNamespace): + """TODO.""" + + def __init__(self, **kwargs: Any) -> None: + kwargs = { + k: UnknownModel(**v) if isinstance(v, Mapping) else v + for k, v in kwargs.items() + } + + self.__dict__.update(kwargs) + + def __repr__(self) -> str: + repr_ = super().__repr__() + if len(repr_) > MAX_REPR_LEN: + return repr_[:MAX_REPR_LEN] + " ...)" + else: + return repr_ + + def __len__(self) -> int: + return len(self.__dict__) + + +class UnknownModel(BaseModel): + """TODO.""" + + pass + + +class BaseSchema(Schema): + """TODO.""" + + __model__: Any = UnknownModel + + class Meta: + unknown = INCLUDE + + @post_load + def make_object(self, data: JSON, **kwargs: Any) -> "__model__": + return self.__model__(**data) diff --git a/pyrh/models/oauth.py b/pyrh/models/oauth.py new file mode 100644 index 00000000..30471d6f --- /dev/null +++ b/pyrh/models/oauth.py @@ -0,0 +1,61 @@ +"""Oauth models.""" + +from datetime import datetime + +import pytz +from marshmallow import fields, validate + +from .base import BaseModel, BaseSchema + + +CHALLENGE_TYPE_VAL = validate.OneOf(["email", "sms"]) + + +class Challenge(BaseModel): + remaining_attempts = 0 + + @property + def can_retry(self) -> bool: + """TODO.""" + return self.remaining_attempts > 0 and ( + datetime.now(tz=pytz.utc) < self.expires_at + ) + + +class ChallengeSchema(BaseSchema): + __model__ = Challenge + + id = fields.UUID() + user = fields.UUID() + type = fields.Str(validate=CHALLENGE_TYPE_VAL) + alternate_type = fields.Str(validate=CHALLENGE_TYPE_VAL) + status = fields.Str(validate=validate.OneOf(["issued", "validated", "failed"])) + remaining_retries = fields.Int() + remaining_attempts = fields.Int() + expires_at = fields.AwareDateTime(default_timezone=pytz.UTC) # type: ignore + + +class OAuth(BaseModel): + @property + def is_challenge(self) -> bool: + return hasattr(self, "challenge") + + @property + def is_mfa(self) -> bool: + return hasattr(self, "mfa_required") + + @property + def is_valid(self) -> bool: + return hasattr(self, "access_token") and hasattr(self, "refresh_token") + + +class OAuthSchema(BaseSchema): + __model__ = OAuth + + detail = fields.Str() + challenge = fields.Nested(ChallengeSchema) + mfa_required = fields.Boolean() + + access_token = fields.Str() + refresh_token = fields.Str() + expires_in = fields.Int() diff --git a/pyrh/sessionmanager.py b/pyrh/models/sessionmanager.py similarity index 77% rename from pyrh/sessionmanager.py rename to pyrh/models/sessionmanager.py index d1a2ba75..6ed16480 100644 --- a/pyrh/sessionmanager.py +++ b/pyrh/models/sessionmanager.py @@ -3,19 +3,25 @@ import uuid from datetime import datetime, timedelta from pathlib import Path -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Union, cast, overload from urllib.request import getproxies import pytz import requests from marshmallow import fields, post_load +from requests import Response from requests.exceptions import HTTPError from requests.structures import CaseInsensitiveDict +from typing_extensions import Literal +from yarl import URL from pyrh import endpoints from pyrh.cache import CACHE_ROOT +from pyrh.common import JSON from pyrh.exceptions import AuthenticationError -from pyrh.models import CHALLENGE_TYPE_VAL, BaseSchema, OAuth, OAuthSchema + +from .base import BaseSchema +from .oauth import CHALLENGE_TYPE_VAL, OAuth, OAuthSchema CERTS_PATH: Path = Path(__file__).parent.joinpath("./ssl/certs.pem") @@ -28,7 +34,9 @@ """Path to login.json config file.""" CACHE_LOGIN.touch(exist_ok=True) -HEADERS: CaseInsensitiveDict = CaseInsensitiveDict( +HTTPHeader = CaseInsensitiveDict[str] +Proxies = Dict[str, str] +HEADERS: HTTPHeader = HTTPHeader( { "Accept": "*/*", "Accept-Encoding": "gzip, deflate", @@ -94,8 +102,8 @@ def __init__( username: str, password: str, challenge_type: Optional[str] = "email", - headers: Optional[CaseInsensitiveDict] = None, - proxies: Optional[Dict] = None, + headers: Optional[HTTPHeader] = None, + proxies: Optional[Proxies] = None, ) -> None: self.session: requests.Session = requests.session() self.session.headers = HEADERS if headers is None else headers @@ -115,7 +123,12 @@ def __init__( self.oauth: OAuth = OAuth() @property - def token_expired(self): + def token_expired(self) -> bool: + """Check if the issued auth token has expired. + + Returns: + True if expired otherwise False + """ return datetime.now(tz=pytz.UTC) > self.expires_at @property @@ -152,15 +165,42 @@ def login(self, force_refresh: bool = False) -> None: elif self.oauth.is_valid and (self.token_expired or force_refresh): self._refresh_oauth2() + @overload + def get( + self, + url: Union[str, URL], + params: Optional[Dict[str, Any]] = None, + *, + headers: Optional[HTTPHeader] = None, + raise_errors: bool = True, + return_response: Literal[True], + auto_login: bool = True, + ) -> Response: + ... + + @overload + def get( + self, + url: Union[str, URL], + params: Optional[Dict[str, Any]] = None, + *, + headers: Optional[HTTPHeader] = None, + raise_errors: bool = True, + return_response: Literal[False] = ..., + auto_login: bool = True, + ) -> JSON: + ... + def get( self, - url: str, - params: dict = None, - headers: Optional[CaseInsensitiveDict] = None, + url: Union[str, URL], + params: Optional[Dict[str, Any]] = None, + *, + headers: Optional[HTTPHeader] = None, raise_errors: bool = True, - return_status: bool = False, + return_response: bool = False, auto_login: bool = True, - ) -> Union[Tuple[Dict, int], Dict]: + ) -> Union[Response, JSON]: """Run a wrapped session HTTP GET request. Note: @@ -179,30 +219,56 @@ def get( The POST response """ - if params is None: - params = {} + params = {} if params is None else params res = self.session.get( - url, params=params, headers={} if headers is None else headers + str(url), params=params, headers={} if headers is None else headers ) if res.status_code == 401 and auto_login: self.login(force_refresh=True) res = self.session.get( - url, params=params, headers={} if headers is None else headers + str(url), params=params, headers={} if headers is None else headers ) if raise_errors: res.raise_for_status() - return (res.json(), res.status_code) if return_status else res.json() + return res if return_response else res.json() + @overload def post( self, - url: str, - data: Optional[Dict] = None, - headers: Optional[CaseInsensitiveDict] = None, + url: Union[str, URL], + data: Optional[JSON] = None, + *, + headers: Optional[HTTPHeader] = None, raise_errors: bool = True, - return_status: bool = False, + return_response: Literal[True], auto_login: bool = True, - ) -> Union[Dict, Tuple[Dict, int]]: + ) -> Response: + ... + + @overload + def post( + self, + url: Union[str, URL], + data: Optional[JSON] = None, + *, + headers: Optional[HTTPHeader] = None, + raise_errors: bool = True, + return_response: Literal[False] = ..., + auto_login: bool = True, + ) -> JSON: + ... + + def post( + self, + url: Union[str, URL], + data: Optional[JSON] = None, + *, + headers: Optional[HTTPHeader] = None, + raise_errors: bool = True, + return_response: bool = False, + auto_login: bool = True, + ) -> Union[JSON, Response]: """Run a wrapped session HTTP POST request. Note: @@ -213,7 +279,7 @@ def post( url: The url to post to. data: The payload to POST to the endpoint. headers: A dict adding to and overriding the session headers. - return_status: Whether to include status in the response. + return_response: Whether to include status_code in the response. raise_errors: Whether or not raise errors on POST request. auto_login: Whether or not to automatically login on restricted endpoint errors. @@ -223,7 +289,7 @@ def post( """ res = self.session.post( - url, + str(url), data=data, timeout=15, verify=self.certs, @@ -232,7 +298,7 @@ def post( if (res.status_code == 401) and auto_login: self.login(force_refresh=True) res = self.session.post( - url, + str(url), data=data, timeout=15, verify=self.certs, @@ -240,21 +306,17 @@ def post( ) if raise_errors: res.raise_for_status() - if res.headers.get("Content-Length", None) == "0": - ret = {} - else: - ret = res.json() - return (ret, res.status_code) if return_status else ret + return res if return_response else res.json() - def _configure_manager(self, oauth) -> None: + def _configure_manager(self, oauth: OAuth) -> None: """Process an authentication response dictionary. This method updates the internal state of the session based on a login or token refresh request. Args: - res: A response dictionary from a login request. + oauth: An oauth response model from a login request. Raises: AuthenticationError: If the input dictionary is malformed. @@ -268,7 +330,8 @@ def _configure_manager(self, oauth) -> None: {"Authorization": f"Bearer {self.oauth.access_token}"} ) - def _challenge_oauth2(self, oauth, oauth_payload) -> OAuth: + def _challenge_oauth2(self, oauth: OAuth, oauth_payload: JSON) -> OAuth: + """TODO.""" # login challenge challenge_url = endpoints.CHALLENGE(oauth.challenge.id) print( @@ -281,18 +344,18 @@ def _challenge_oauth2(self, oauth, oauth_payload) -> OAuth: challenge_header = CaseInsensitiveDict( {"X-ROBINHOOD-CHALLENGE-RESPONSE-ID": str(oauth.challenge.id)} ) - res, status = self.post( + res = self.post( challenge_url, data=challenge_payload, raise_errors=False, headers=challenge_header, auto_login=False, - return_status=True, + return_response=True, ) - oauth_inner = OAuthSchema().load(res) - if status == requests.codes.ok: + oauth_inner = OAuthSchema().load(res.json()) + if res.status_code == requests.codes.ok: try: - res = self.post( + res2 = self.post( endpoints.OAUTH, data=oauth_payload, headers=challenge_header, @@ -301,7 +364,7 @@ def _challenge_oauth2(self, oauth, oauth_payload) -> OAuth: except HTTPError: raise AuthenticationError("Error in finalizing auth token") else: - oauth = OAuthSchema().load(res) + oauth = OAuthSchema().load(res2) return oauth elif oauth_inner.is_challenge and oauth_inner.challenge.can_retry: print("Invalid code entered") @@ -309,23 +372,24 @@ def _challenge_oauth2(self, oauth, oauth_payload) -> OAuth: else: raise AuthenticationError("Exceeded available attempts or code expired") - def _mfa_oauth2(self, oauth_payload, attempts=3) -> OAuth: + def _mfa_oauth2(self, oauth_payload: JSON, attempts: int = 3) -> OAuth: print(f"Input mfa code:") mfa_code = input() oauth_payload["mfa_code"] = mfa_code - res, status = self.post( + res = self.post( endpoints.OAUTH, data=oauth_payload, raise_errors=False, auto_login=False, - return_status=True, + return_response=True, ) attempts -= 1 - if (status != requests.codes.ok) and (attempts > 0): + if (res.status_code != requests.codes.ok) and (attempts > 0): print("Invalid mfa code") return self._mfa_oauth2(oauth_payload, attempts) - elif status == requests.codes.ok: - return OAuthSchema().load(res) + elif res.status_code == requests.codes.ok: + # TODO: Write mypy issue on why this needs to be casted? + return cast(OAuth, OAuthSchema().load(res.json())) else: raise AuthenticationError("Too many incorrect mfa attempts") @@ -350,19 +414,19 @@ def _login_oauth2(self) -> None: "challenge_type": self.challenge_type, } - res, status = self.post( + res = self.post( endpoints.OAUTH, data=oauth_payload, raise_errors=False, auto_login=False, - return_status=True, + return_response=True, ) - oauth = OAuthSchema().load(res) + oauth = OAuthSchema().load(res.json()) - if status == 401 and oauth.is_challenge: + if res.status_code == 401 and oauth.is_challenge: oauth = self._challenge_oauth2(oauth, oauth_payload) - elif status == requests.codes.ok and oauth.is_mfa: + elif res.status_code == requests.codes.ok and oauth.is_mfa: oauth = self._mfa_oauth2(oauth_payload) if not oauth.is_valid: @@ -425,7 +489,7 @@ def __repr__(self) -> str: class SessionManagerSchema(BaseSchema): __model__ = SessionManager - username = fields.Email() + username = fields.Email() # type: ignore # Call untyped "Email" in typed context password = fields.Str() challenge_type = fields.Str(validate=CHALLENGE_TYPE_VAL) oauth = fields.Nested(OAuthSchema) @@ -435,7 +499,7 @@ class SessionManagerSchema(BaseSchema): proxies = fields.Dict() @post_load - def make_object(self, data, **kwargs): + def make_object(self, data: JSON, **kwargs: Any) -> SessionManager: oauth = data.pop("oauth", None) expires_at = data.pop("expires_at", None) session_manager = self.__model__(**data) @@ -451,7 +515,9 @@ def make_object(self, data, **kwargs): return session_manager -def dump_session(session_manager, path: Optional[Union[Path, str]] = None) -> None: +def dump_session( + session_manager: SessionManager, path: Optional[Union[Path, str]] = None +) -> None: """Save the current session parameters to a json file. Note: @@ -483,4 +549,4 @@ def load_session(path: Optional[Union[Path, str]] = None) -> SessionManager: """ path = path or CACHE_LOGIN with open(path) as file: - return SessionManagerSchema().loads(file.read()) + return cast(SessionManager, SessionManagerSchema().loads(file.read())) diff --git a/pyrh/robinhood.py b/pyrh/robinhood.py index 6de2a6b6..cb0fb605 100644 --- a/pyrh/robinhood.py +++ b/pyrh/robinhood.py @@ -12,7 +12,7 @@ InvalidOptionId, InvalidTickerSymbol, ) -from pyrh.sessionmanager import SessionManager +from pyrh.models import SessionManager class Bounds(Enum): diff --git a/setup.cfg b/setup.cfg index c0593781..86112b42 100644 --- a/setup.cfg +++ b/setup.cfg @@ -16,8 +16,16 @@ exclude_lines = if __name__ == .__main__.: [mypy] -ignore_missing_imports = True -disallow_untyped_calls = True +strict = True +disallow_untyped_decorators = False + +# TODO: Remove ignored typing errors +[mypy-*.robinhood] +ignore_errors = True +[mypy-*.endpoints] +ignore_errors = True +[mypy-tests.*] +ignore_errors = True # flake8 [flake8] diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_sessionmanager.py b/tests/test_sessionmanager.py index b9ad0537..6fd64880 100644 --- a/tests/test_sessionmanager.py +++ b/tests/test_sessionmanager.py @@ -4,6 +4,7 @@ import pytest import requests_mock +from freezegun import freeze_time MOCK_URL = "mock://test.com" @@ -11,7 +12,7 @@ @pytest.fixture def sm(): - from pyrh.sessionmanager import SessionManager + from pyrh.models import SessionManager sample_user = { "username": "user@example.com", @@ -23,7 +24,7 @@ def sm(): @pytest.fixture def sm_adap(monkeypatch): - from pyrh.sessionmanager import SessionManager + from pyrh.models import SessionManager sample_user = { "username": "user@example.com", @@ -46,7 +47,7 @@ def test_repr(sm): def test_bad_challenge_type(sm): - from pyrh.sessionmanager import SessionManager + from pyrh.models import SessionManager sample_user = { "username": "user@example.com", @@ -78,17 +79,15 @@ def test_login_oauth2_errors(monkeypatch, sm_adap): assert "Some error" in str(e.value) -@mock.patch("pyrh.sessionmanager.datetime") -def test_login_oauth2_challenge_valid(dt, monkeypatch, sm_adap): +@freeze_time("2005-01-01") +def test_login_oauth2_challenge_valid(monkeypatch, sm_adap): import uuid from datetime import datetime import pytz - from pyrh.models import OAuthSchema + from pyrh.models.oauth import OAuthSchema monkeypatch.setattr("builtins.input", lambda: "123456") - now = datetime.strptime("2005", "%Y").replace(tzinfo=pytz.UTC) expiry = datetime.strptime("2010", "%Y").replace(tzinfo=pytz.UTC) - dt.now = mock.Mock(return_value=now) responses = [ { "detail": "Request blocked, challenge issued.", @@ -135,18 +134,16 @@ def test_login_oauth2_challenge_valid(dt, monkeypatch, sm_adap): assert sm.oauth.is_valid -@mock.patch("pyrh.sessionmanager.datetime") -def test_login_oauth2_challenge_invalid(dt, monkeypatch, sm_adap): +@freeze_time("2005-01-01") +def test_login_oauth2_challenge_invalid(monkeypatch, sm_adap): from pyrh.exceptions import AuthenticationError from datetime import datetime - from pyrh.models import OAuthSchema + from pyrh.models.oauth import OAuthSchema import pytz import uuid monkeypatch.setattr("builtins.input", lambda: "123456") - now = datetime.strptime("2005", "%Y").replace(tzinfo=pytz.UTC) expiry = datetime.strptime("2010", "%Y").replace(tzinfo=pytz.UTC) - dt.now = mock.Mock(return_value=now) responses = [ { "detail": "Request blocked, challenge issued.", @@ -216,7 +213,7 @@ def test_login_oauth2_challenge_invalid(dt, monkeypatch, sm_adap): def test_login_oauth2_mfa_valid(monkeypatch, sm_adap): - from pyrh.models import OAuthSchema + from pyrh.models.oauth import OAuthSchema mfa_code = "123456" monkeypatch.setattr("builtins.input", lambda: mfa_code) @@ -245,7 +242,7 @@ def test_login_oauth2_mfa_valid(monkeypatch, sm_adap): def test_login_oauth2_mfa_invalid(monkeypatch, sm_adap): from pyrh.exceptions import AuthenticationError - from pyrh.models import OAuthSchema + from pyrh.models.oauth import OAuthSchema monkeypatch.setattr("builtins.input", lambda: "123456") responses = [ @@ -269,7 +266,7 @@ def test_login_oauth2_mfa_invalid(monkeypatch, sm_adap): def test_refresh_oauth2_success(sm_adap): - from pyrh.models import OAuthSchema + from pyrh.models.oauth import OAuthSchema response = { "access_token": "some_token", @@ -296,7 +293,7 @@ def test_refresh_oauth2_success(sm_adap): def test_refresh_oauth2_failure(sm_adap): from pyrh.exceptions import AuthenticationError - from pyrh.models import OAuthSchema + from pyrh.models.oauth import OAuthSchema response = {"error": "some_error"} sm, adapter = sm_adap @@ -314,14 +311,14 @@ def test_refresh_oauth2_failure(sm_adap): assert "Failed to refresh" in str(e.value) -@mock.patch("pyrh.sessionmanager.SessionManager._login_oauth2") +@mock.patch("pyrh.models.SessionManager._login_oauth2") def test_login_init(login_mock, sm): sm.login() assert login_mock.call_count == 1 -@mock.patch("pyrh.sessionmanager.SessionManager._refresh_oauth2") +@mock.patch("pyrh.models.SessionManager._refresh_oauth2") def test_login_refresh_default(refresh_mock, sm): # default expires_at is 1970 sm.oauth.access_token = "some_token" @@ -332,7 +329,7 @@ def test_login_refresh_default(refresh_mock, sm): assert refresh_mock.call_count == 1 -@mock.patch("pyrh.sessionmanager.SessionManager._refresh_oauth2") +@mock.patch("pyrh.models.SessionManager._refresh_oauth2") def test_login_refresh_force(refresh_mock, sm): sm.oauth.access_token = "some_token" sm.oauth.refresh_token = "some_refresh_token" @@ -342,7 +339,7 @@ def test_login_refresh_force(refresh_mock, sm): assert refresh_mock.call_count == 1 -@mock.patch("pyrh.sessionmanager.SessionManager.post") +@mock.patch("pyrh.models.SessionManager.post") def test_logout_success(post_mock, sm): post_mock.return_value = {} sm.oauth.access_token = "some_token" @@ -352,7 +349,7 @@ def test_logout_success(post_mock, sm): assert post_mock.call_count == 1 -@mock.patch("pyrh.sessionmanager.SessionManager.post") +@mock.patch("pyrh.models.SessionManager.post") def test_logout_failure(post_mock, sm): from pyrh.exceptions import AuthenticationError from requests.exceptions import HTTPError @@ -375,7 +372,7 @@ def raise_error(*args, **kwargs): def test_jsonify(tmpdir, sm): import json - from pyrh.sessionmanager import dump_session, load_session + from pyrh.models import dump_session, load_session sm.oauth.access_token = "some_token" sm.oauth.refresh_token = "some_refresh_token" @@ -394,17 +391,18 @@ def test_jsonify(tmpdir, sm): assert sm.oauth == sm1.oauth -@mock.patch("pyrh.sessionmanager.datetime") -def test_authenticated(dt_mock, sm, monkeypatch): +@freeze_time("2000-01-01") +def test_authenticated(sm, monkeypatch): import pytz from datetime import datetime as dt, timedelta - from pyrh.sessionmanager import EXPIRATION_TIME + from pyrh.models.sessionmanager import EXPIRATION_TIME - dt_mock.now.return_value = dt.strptime("2000", "%Y").replace(tzinfo=pytz.UTC) assert not sm.authenticated - expires_at_time = dt_mock.now() + timedelta(seconds=EXPIRATION_TIME) + expires_at_time = dt.now().replace(tzinfo=pytz.UTC) + timedelta( + seconds=EXPIRATION_TIME + ) sm.session.headers["Authorization"] = "Bearer some_token" sm.expires_at = expires_at_time @@ -412,7 +410,7 @@ def test_authenticated(dt_mock, sm, monkeypatch): assert sm.authenticated -@mock.patch("pyrh.sessionmanager.SessionManager.login") +@mock.patch("pyrh.models.SessionManager.login") def test_get(mock_login, sm): import json @@ -441,7 +439,7 @@ def test_get(mock_login, sm): assert "404 Client Error" in str(e.value) -@mock.patch("pyrh.sessionmanager.SessionManager.login") +@mock.patch("pyrh.models.SessionManager.login") def test_post(mock_login, sm): import json From 9b6e880e43fec7c3ce0f88da0f0270cc3e730142 Mon Sep 17 00:00:00 2001 From: Adithya Balaji Date: Sat, 11 Apr 2020 21:43:04 -0400 Subject: [PATCH 4/9] Fix failing tests and address ABCMeta Issue. * ABCMeta typing issue with requests package * https://github.com/python/mypy/issues/5264#issuecomment-399407428 * Run pytest and remove broken test as functionality was removed. --- pyrh/models/sessionmanager.py | 24 ++++++++++++++---------- tests/test_sessionmanager.py | 5 +---- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/pyrh/models/sessionmanager.py b/pyrh/models/sessionmanager.py index 6ed16480..6baff338 100644 --- a/pyrh/models/sessionmanager.py +++ b/pyrh/models/sessionmanager.py @@ -3,7 +3,7 @@ import uuid from datetime import datetime, timedelta from pathlib import Path -from typing import Any, Dict, Optional, Union, cast, overload +from typing import TYPE_CHECKING, Any, Dict, Optional, Union, cast, overload from urllib.request import getproxies import pytz @@ -34,9 +34,13 @@ """Path to login.json config file.""" CACHE_LOGIN.touch(exist_ok=True) -HTTPHeader = CaseInsensitiveDict[str] +if TYPE_CHECKING: + CaseInsensitiveDictType = CaseInsensitiveDict[str] +else: + CaseInsensitiveDictType = CaseInsensitiveDict + Proxies = Dict[str, str] -HEADERS: HTTPHeader = HTTPHeader( +HEADERS: CaseInsensitiveDictType = CaseInsensitiveDict( { "Accept": "*/*", "Accept-Encoding": "gzip, deflate", @@ -102,7 +106,7 @@ def __init__( username: str, password: str, challenge_type: Optional[str] = "email", - headers: Optional[HTTPHeader] = None, + headers: Optional[CaseInsensitiveDictType] = None, proxies: Optional[Proxies] = None, ) -> None: self.session: requests.Session = requests.session() @@ -171,7 +175,7 @@ def get( url: Union[str, URL], params: Optional[Dict[str, Any]] = None, *, - headers: Optional[HTTPHeader] = None, + headers: Optional[CaseInsensitiveDictType] = None, raise_errors: bool = True, return_response: Literal[True], auto_login: bool = True, @@ -184,7 +188,7 @@ def get( url: Union[str, URL], params: Optional[Dict[str, Any]] = None, *, - headers: Optional[HTTPHeader] = None, + headers: Optional[CaseInsensitiveDictType] = None, raise_errors: bool = True, return_response: Literal[False] = ..., auto_login: bool = True, @@ -196,7 +200,7 @@ def get( url: Union[str, URL], params: Optional[Dict[str, Any]] = None, *, - headers: Optional[HTTPHeader] = None, + headers: Optional[CaseInsensitiveDictType] = None, raise_errors: bool = True, return_response: bool = False, auto_login: bool = True, @@ -239,7 +243,7 @@ def post( url: Union[str, URL], data: Optional[JSON] = None, *, - headers: Optional[HTTPHeader] = None, + headers: Optional[CaseInsensitiveDictType] = None, raise_errors: bool = True, return_response: Literal[True], auto_login: bool = True, @@ -252,7 +256,7 @@ def post( url: Union[str, URL], data: Optional[JSON] = None, *, - headers: Optional[HTTPHeader] = None, + headers: Optional[CaseInsensitiveDictType] = None, raise_errors: bool = True, return_response: Literal[False] = ..., auto_login: bool = True, @@ -264,7 +268,7 @@ def post( url: Union[str, URL], data: Optional[JSON] = None, *, - headers: Optional[HTTPHeader] = None, + headers: Optional[CaseInsensitiveDictType] = None, raise_errors: bool = True, return_response: bool = False, auto_login: bool = True, diff --git a/tests/test_sessionmanager.py b/tests/test_sessionmanager.py index 6fd64880..38e534d2 100644 --- a/tests/test_sessionmanager.py +++ b/tests/test_sessionmanager.py @@ -449,7 +449,6 @@ def test_post(mock_login, sm): sm.session.mount("mock", adapter) mock_url = "mock://test.com" expected = [ - {"text": "", "status_code": 200, "headers": {"Content-Length": "0"}}, {"text": '{"error": "login error"}', "status_code": 401}, {"text": '{"test": "321"}', "status_code": 200}, {"text": '{"error": "resource not found"}', "status_code": 404}, @@ -457,12 +456,10 @@ def test_post(mock_login, sm): adapter.register_uri("POST", mock_url, expected) resp1 = sm.post(mock_url) - resp2 = sm.post(mock_url) with pytest.raises(HTTPError) as e: sm.post(mock_url) - assert resp1 == {} - assert resp2 == json.loads(expected[2]["text"]) + assert resp1 == json.loads(expected[1]["text"]) assert mock_login.call_count == 1 assert "404 Client Error" in str(e.value) From a6cba7843c167f42ecb3158dcfc8a94b93a72bf0 Mon Sep 17 00:00:00 2001 From: Adithya Balaji Date: Sun, 12 Apr 2020 00:42:53 -0400 Subject: [PATCH 5/9] Add documentation and typing to code changes. * Ignore D106, this will be quite annoying with marshmallow's required Meta subclass definition. --- pyrh/models/__init__.py | 2 +- pyrh/models/base.py | 62 +++++++++++++++++++++---- pyrh/models/oauth.py | 36 ++++++++++++++- pyrh/models/sessionmanager.py | 86 +++++++++++++++++++++++++++++------ setup.cfg | 1 + 5 files changed, 162 insertions(+), 25 deletions(-) diff --git a/pyrh/models/__init__.py b/pyrh/models/__init__.py index 2d3128f5..d296e465 100644 --- a/pyrh/models/__init__.py +++ b/pyrh/models/__init__.py @@ -1,4 +1,4 @@ -"""pyrh models and schemas""" +"""pyrh models and schemas.""" from .oauth import Challenge, OAuth from .sessionmanager import SessionManager, dump_session, load_session diff --git a/pyrh/models/base.py b/pyrh/models/base.py index 1bf02d29..9fbd0ef6 100644 --- a/pyrh/models/base.py +++ b/pyrh/models/base.py @@ -1,7 +1,7 @@ """Base Model.""" from types import SimpleNamespace -from typing import Any, Dict, Mapping, Tuple +from typing import Any, Dict, List, Mapping, Union from marshmallow import INCLUDE, Schema, post_load @@ -11,18 +11,47 @@ MAX_REPR_LEN = 50 +def _process_dict_values( + value: Union[Dict[str, Any], List[Any]] +) -> Union["UnknownModel", List[Any]]: + """Process a returned from a JSON response. + + Args: + value: A dict or a list returned from a JSON response. + + Returns: + Either an UnknownModel or a List of processed values. + + """ + if isinstance(value, Mapping): + return UnknownModel(**value) + if isinstance(value, list): + return [_process_dict_values(v) for v in value] + + class BaseModel(SimpleNamespace): - """TODO.""" + """BaseModel that all models should inherit from. + + Note: + If a passed parameter is a nested dictionary, then it is created with the + `UnknownModel` class. If it is a list, then it is created with + + Args: + **kwargs: All passed parameters as converted to instance attributes. + """ def __init__(self, **kwargs: Any) -> None: - kwargs = { - k: UnknownModel(**v) if isinstance(v, Mapping) else v - for k, v in kwargs.items() - } + kwargs = {k: _process_dict_values(v) for k, v in kwargs.items()} self.__dict__.update(kwargs) def __repr__(self) -> str: + """Return a default repr of any Model. + + Returns: + The string model parameters up to a `MAX_REPR_LEN`. + + """ repr_ = super().__repr__() if len(repr_) > MAX_REPR_LEN: return repr_[:MAX_REPR_LEN] + " ...)" @@ -30,23 +59,40 @@ def __repr__(self) -> str: return repr_ def __len__(self) -> int: + """Return the length of the model. + + Returns: + The number of attributes a given model has. + + """ return len(self.__dict__) class UnknownModel(BaseModel): - """TODO.""" + """A convenience class that inherits from `BaseModel`.""" pass class BaseSchema(Schema): - """TODO.""" + """The default schema for all models.""" __model__: Any = UnknownModel + """Determines the object that is created when the load method is called.""" class Meta: unknown = INCLUDE @post_load def make_object(self, data: JSON, **kwargs: Any) -> "__model__": + """Build model for the given `__model__` class attribute. + + Args: + data: The JSON diction to use to build the model. + **kwargs: Unused but required to match signature of `Schema.make_object` + + Returns: + An instance of the `__model__` class. + + """ return self.__model__(**data) diff --git a/pyrh/models/oauth.py b/pyrh/models/oauth.py index 30471d6f..7cfa3c9e 100644 --- a/pyrh/models/oauth.py +++ b/pyrh/models/oauth.py @@ -12,17 +12,28 @@ class Challenge(BaseModel): + """The challenge response model.""" + remaining_attempts = 0 + """Default `remaining_attempts` attribute if it is not set on instance.""" @property def can_retry(self) -> bool: - """TODO.""" + """Determine if the challenge can be retried. + + Returns: + True if remaining_attempts is greater than zero and challenge is not \ + expired, False otherwise. + + """ return self.remaining_attempts > 0 and ( datetime.now(tz=pytz.utc) < self.expires_at ) class ChallengeSchema(BaseSchema): + """The challenge response schema.""" + __model__ = Challenge id = fields.UUID() @@ -36,20 +47,43 @@ class ChallengeSchema(BaseSchema): class OAuth(BaseModel): + """The OAuth response model.""" + @property def is_challenge(self) -> bool: + """Determine whether the oauth response is a challenge. + + Returns: + True response has the `challenge` key, False otherwise. + + """ return hasattr(self, "challenge") @property def is_mfa(self) -> bool: + """Determine whether the oauth response is a mfa challenge. + + Returns: + True response has the `mfa_required` key, False otherwise. + + """ return hasattr(self, "mfa_required") @property def is_valid(self) -> bool: + """Determine whether the oauth response is a valid response. + + Returns: + True if the response has both the `access_token` and `refresh_token` keys, \ + False otherwise. + + """ return hasattr(self, "access_token") and hasattr(self, "refresh_token") class OAuthSchema(BaseSchema): + """The OAuth response schema.""" + __model__ = OAuth detail = fields.Str() diff --git a/pyrh/models/sessionmanager.py b/pyrh/models/sessionmanager.py index 6baff338..db7f1b0b 100644 --- a/pyrh/models/sessionmanager.py +++ b/pyrh/models/sessionmanager.py @@ -20,7 +20,7 @@ from pyrh.common import JSON from pyrh.exceptions import AuthenticationError -from .base import BaseSchema +from .base import BaseModel, BaseSchema from .oauth import CHALLENGE_TYPE_VAL, OAuth, OAuthSchema @@ -56,8 +56,11 @@ EXPIRATION_TIME: int = 10 """Default expiration time for requests.""" +# TODO: Watch this issue and remove the F811 ignores when it is fixed +# https://gitlab.com/pycqa/flake8/-/merge_requests/417 (we need at least pyflakes 2.2.0) -class SessionManager: + +class SessionManager(BaseModel): """Mange connectivity with Robinhood API. Once logged into the session, this class will manage automatic oauth token update @@ -83,6 +86,7 @@ class SessionManager: challenge_type: Either sms or email. (only if not using mfa) headers: Any optional header dict modifications for the session. proxies: Any optional proxy dict modification for the session. + **kwargs: Any other passed parameters as converted to instance attributes. Attributes: session: A requests session instance. @@ -108,6 +112,7 @@ def __init__( challenge_type: Optional[str] = "email", headers: Optional[CaseInsensitiveDictType] = None, proxies: Optional[Proxies] = None, + **kwargs: Any, ) -> None: self.session: requests.Session = requests.session() self.session.headers = HEADERS if headers is None else headers @@ -126,6 +131,8 @@ def __init__( self._gen_device_token: str = str(uuid.uuid4()) self.oauth: OAuth = OAuth() + super().__init__(**kwargs) + @property def token_expired(self) -> bool: """Check if the issued auth token has expired. @@ -169,6 +176,10 @@ def login(self, force_refresh: bool = False) -> None: elif self.oauth.is_valid and (self.token_expired or force_refresh): self._refresh_oauth2() + # The following type hints helps mypy determine what the output type to assign based + # on the `return_response` parameter. The same "stub" method approach is used for + # the post method as well. + # https://github.com/python/mypy/issues/8634#issuecomment-609411104 @overload def get( self, @@ -179,10 +190,10 @@ def get( raise_errors: bool = True, return_response: Literal[True], auto_login: bool = True, - ) -> Response: + ) -> Response: # noqa: D102 ... - @overload + @overload # noqa: F811 def get( self, url: Union[str, URL], @@ -192,10 +203,10 @@ def get( raise_errors: bool = True, return_response: Literal[False] = ..., auto_login: bool = True, - ) -> JSON: + ) -> JSON: # noqa: D102 ... - def get( + def get( # noqa: F811 self, url: Union[str, URL], params: Optional[Dict[str, Any]] = None, @@ -216,6 +227,8 @@ def get( params: query string parameters headers: A dict adding to and overriding the session headers. raise_errors: Whether or not raise errors on GET request result. + return_response: Whether or not return a `requests.Response` object or the + JSON response from the request. auto_login: Whether or not to automatically login on restricted endpoint errors. @@ -247,10 +260,10 @@ def post( raise_errors: bool = True, return_response: Literal[True], auto_login: bool = True, - ) -> Response: + ) -> Response: # noqa: D102 ... - @overload + @overload # noqa: F811 def post( self, url: Union[str, URL], @@ -260,10 +273,10 @@ def post( raise_errors: bool = True, return_response: Literal[False] = ..., auto_login: bool = True, - ) -> JSON: + ) -> JSON: # noqa: D102 ... - def post( + def post( # noqa: F811 self, url: Union[str, URL], data: Optional[JSON] = None, @@ -283,7 +296,8 @@ def post( url: The url to post to. data: The payload to POST to the endpoint. headers: A dict adding to and overriding the session headers. - return_response: Whether to include status_code in the response. + return_response: Whether or not return a `requests.Response` object or the + JSON response from the request. raise_errors: Whether or not raise errors on POST request. auto_login: Whether or not to automatically login on restricted endpoint errors. @@ -322,9 +336,6 @@ def _configure_manager(self, oauth: OAuth) -> None: Args: oauth: An oauth response model from a login request. - Raises: - AuthenticationError: If the input dictionary is malformed. - """ self.oauth = oauth self.expires_at = datetime.now(tz=pytz.UTC) + timedelta( @@ -335,7 +346,22 @@ def _configure_manager(self, oauth: OAuth) -> None: ) def _challenge_oauth2(self, oauth: OAuth, oauth_payload: JSON) -> OAuth: - """TODO.""" + """Process the ouath challenge flow. + + Args: + oauth: An oauth response model from a login request. + oauth_payload: The payload to use once the challenge has been processed. + + Returns: + An OAuth response model from the login request. + + Raises: + AuthenticationError: If there is an error in the initial challenge response. + + .. # noqa: DAR202 + .. https://github.com/terrencepreilly/darglint/issues/81 + + """ # login challenge challenge_url = endpoints.CHALLENGE(oauth.challenge.id) print( @@ -377,6 +403,22 @@ def _challenge_oauth2(self, oauth: OAuth, oauth_payload: JSON) -> OAuth: raise AuthenticationError("Exceeded available attempts or code expired") def _mfa_oauth2(self, oauth_payload: JSON, attempts: int = 3) -> OAuth: + """Mfa auth flow. + + For people with 2fa. + + Args: + oauth_payload: JSON payload to send on mfa approval. + attempts: The number of attempts to allow for mfa approval. + + Returns: + An OAuth response model object. + + Raises: + AuthenticationError: If the mfa code is incorrect more than specified \ + number of attempts. + + """ print(f"Input mfa code:") mfa_code = input() oauth_payload["mfa_code"] = mfa_code @@ -491,6 +533,8 @@ def __repr__(self) -> str: class SessionManagerSchema(BaseSchema): + """Schema class for the SessionManager model.""" + __model__ = SessionManager username = fields.Email() # type: ignore # Call untyped "Email" in typed context @@ -504,6 +548,15 @@ class SessionManagerSchema(BaseSchema): @post_load def make_object(self, data: JSON, **kwargs: Any) -> SessionManager: + """Override default method to configure SessionManager object on load. + + Args: + data: The JSON dictionary to process + **kwargs: Not used but matches signature of `BaseSchema.make_object` + + Returns: + A configured instance of SessionManager. + """ oauth = data.pop("oauth", None) expires_at = data.pop("expires_at", None) session_manager = self.__model__(**data) @@ -550,6 +603,9 @@ def load_session(path: Optional[Union[Path, str]] = None) -> SessionManager: Args: path: The location and file name to load from. + Returns: + A configured instance of SessionManager. + """ path = path or CACHE_LOGIN with open(path) as file: diff --git a/setup.cfg b/setup.cfg index 86112b42..2c11d28c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -36,6 +36,7 @@ ignore = W503 # Line break occurred after a binary operator (opposite of W504) D107 # Missing docstring in __init__ D301 # Use r""" if any backslashes in a docstring + D106 # Nested class docs (marshmallow) max-complexity = 12 # TODO: remove docstring exemptions after refactor per-file-ignores = From 669dd3367e47270db587835bde4e68e7faef5ed0 Mon Sep 17 00:00:00 2001 From: Adithya Balaji Date: Sun, 12 Apr 2020 02:57:09 -0400 Subject: [PATCH 6/9] Add more unit tests for models and patch process_dict_values bug * _process_dict_values failed to account for passthrough case * Mark type checking code as no cover * Remove unnecessary status checking in login oauth code * Add test for token_expired and login_set in test_sessionamanger * Add tests for base model and schema * Add tests for challenge and oauth models --- pyrh/models/base.py | 15 ++++++------ pyrh/models/sessionmanager.py | 14 +++++------ tests/test_base.py | 44 +++++++++++++++++++++++++++++++++++ tests/test_oauth.py | 43 ++++++++++++++++++++++++++++++++++ tests/test_sessionmanager.py | 24 +++++++++++++++++++ 5 files changed, 126 insertions(+), 14 deletions(-) create mode 100644 tests/test_base.py create mode 100644 tests/test_oauth.py diff --git a/pyrh/models/base.py b/pyrh/models/base.py index 9fbd0ef6..24564ff4 100644 --- a/pyrh/models/base.py +++ b/pyrh/models/base.py @@ -1,7 +1,7 @@ """Base Model.""" from types import SimpleNamespace -from typing import Any, Dict, List, Mapping, Union +from typing import Any, Mapping from marshmallow import INCLUDE, Schema, post_load @@ -11,22 +11,23 @@ MAX_REPR_LEN = 50 -def _process_dict_values( - value: Union[Dict[str, Any], List[Any]] -) -> Union["UnknownModel", List[Any]]: +def _process_dict_values(value: Any) -> Any: """Process a returned from a JSON response. Args: - value: A dict or a list returned from a JSON response. + value: A dict, list, or value returned from a JSON response. Returns: - Either an UnknownModel or a List of processed values. + Either an UnknownModel, a List of processed values, or the original value \ + passed through. """ if isinstance(value, Mapping): return UnknownModel(**value) - if isinstance(value, list): + elif isinstance(value, list): return [_process_dict_values(v) for v in value] + else: + return value class BaseModel(SimpleNamespace): diff --git a/pyrh/models/sessionmanager.py b/pyrh/models/sessionmanager.py index db7f1b0b..777666e2 100644 --- a/pyrh/models/sessionmanager.py +++ b/pyrh/models/sessionmanager.py @@ -34,7 +34,7 @@ """Path to login.json config file.""" CACHE_LOGIN.touch(exist_ok=True) -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover CaseInsensitiveDictType = CaseInsensitiveDict[str] else: CaseInsensitiveDictType = CaseInsensitiveDict @@ -190,7 +190,7 @@ def get( raise_errors: bool = True, return_response: Literal[True], auto_login: bool = True, - ) -> Response: # noqa: D102 + ) -> Response: # noqa: D102 # pragma: no cover ... @overload # noqa: F811 @@ -203,7 +203,7 @@ def get( raise_errors: bool = True, return_response: Literal[False] = ..., auto_login: bool = True, - ) -> JSON: # noqa: D102 + ) -> JSON: # noqa: D102 # pragma: no cover ... def get( # noqa: F811 @@ -260,7 +260,7 @@ def post( raise_errors: bool = True, return_response: Literal[True], auto_login: bool = True, - ) -> Response: # noqa: D102 + ) -> Response: # noqa: D102 # pragma: no cover ... @overload # noqa: F811 @@ -273,7 +273,7 @@ def post( raise_errors: bool = True, return_response: Literal[False] = ..., auto_login: bool = True, - ) -> JSON: # noqa: D102 + ) -> JSON: # noqa: D102 # pragma: no cover ... def post( # noqa: F811 @@ -470,9 +470,9 @@ def _login_oauth2(self) -> None: oauth = OAuthSchema().load(res.json()) - if res.status_code == 401 and oauth.is_challenge: + if oauth.is_challenge: oauth = self._challenge_oauth2(oauth, oauth_payload) - elif res.status_code == requests.codes.ok and oauth.is_mfa: + elif oauth.is_mfa: oauth = self._mfa_oauth2(oauth_payload) if not oauth.is_valid: diff --git a/tests/test_base.py b/tests/test_base.py new file mode 100644 index 00000000..d4df403c --- /dev/null +++ b/tests/test_base.py @@ -0,0 +1,44 @@ +"""Test base model file.""" + + +def test_base_model_simplenamespace_simple(): + from pyrh.models.base import BaseModel, UnknownModel + + payload1 = {"a": 10, "b": 15, "c": 20} + payload2 = {"nested": {"b": 15, "c": 20}} + payload3 = {"list": [10, {"a": 5, "b": 10, "c": 20}]} + + bm1 = BaseModel(**payload1) + for k, v in payload1.items(): + assert getattr(bm1, k) == v + + bm2 = BaseModel(**payload2) + assert bm2.nested == UnknownModel(**payload2["nested"]) + + bm3 = BaseModel(**payload3) + assert bm3.list == [10, UnknownModel(**payload3["list"][1])] + + +def test_base_model_repr(): + from pyrh.models.base import BaseModel + + bm = BaseModel(a=10) + + assert "BaseModel(a=10)" == str(bm) + + +def test_base_model_len(): + from pyrh.models.base import BaseModel + + bm = BaseModel(a=10, b=20) + + assert len(bm) == 2 + + +def test_base_schema(): + from pyrh.models.base import UnknownModel, BaseSchema + + bm = UnknownModel(a=10) + load_bm = BaseSchema().load({"a": 10}) + assert bm == load_bm + assert type(bm) == type(load_bm) diff --git a/tests/test_oauth.py b/tests/test_oauth.py new file mode 100644 index 00000000..503e96b5 --- /dev/null +++ b/tests/test_oauth.py @@ -0,0 +1,43 @@ +"""Test the oauth classes.""" + +from freezegun import freeze_time + + +@freeze_time("2020-01-01") +def test_challenge_can_retry(): + from copy import copy + + from pyrh.models.oauth import Challenge + from datetime import datetime, timedelta + import pytz + + future = datetime.strptime("2020-01-02", "%Y-%m-%d").replace(tzinfo=pytz.UTC) + + data = {"expires_at": future} + + challenge = Challenge(**data) + + assert not challenge.can_retry + + challenge.remaining_attempts = 1 + assert challenge.can_retry + + challenge.expires_at = future - timedelta(days=3) + + assert not challenge.can_retry + + +def test_oauth_test_attrs(): + from pyrh.models.oauth import OAuth + from pyrh.models.base import UnknownModel + + oa = OAuth() + oa.challenge = UnknownModel(a="test") + assert oa.is_challenge + + oa.mfa_required = UnknownModel(a="test") + assert oa.is_mfa + + oa.access_token = "some-token" + oa.refresh_token = "other-token" + assert oa.is_valid diff --git a/tests/test_sessionmanager.py b/tests/test_sessionmanager.py index 38e534d2..9bae341c 100644 --- a/tests/test_sessionmanager.py +++ b/tests/test_sessionmanager.py @@ -9,6 +9,8 @@ MOCK_URL = "mock://test.com" +# TODO: refactor this to remove internal method testing and only test the public methods + @pytest.fixture def sm(): @@ -463,3 +465,25 @@ def test_post(mock_login, sm): assert resp1 == json.loads(expected[1]["text"]) assert mock_login.call_count == 1 assert "404 Client Error" in str(e.value) + + +@freeze_time("2020-01-01") +def test_token_expired(sm): + from datetime import datetime + import pytz + + # Assumes default token expired is set to 1970 + assert sm.token_expired + + sm.expires_at = datetime.strptime("2020-01-03", "%Y-%m-%d").replace(tzinfo=pytz.UTC) + + assert not sm.token_expired + + +def test_login_set(sm): + assert sm.login_set + + sm.username = None + sm.password = None + + assert not sm.login_set From 63a233a603c7d29d11a60fe0b87cb65b75eb424e Mon Sep 17 00:00:00 2001 From: Adithya Balaji Date: Sun, 12 Apr 2020 02:59:43 -0400 Subject: [PATCH 7/9] Add news snippet. --- newsfragments/223.feature | 1 + 1 file changed, 1 insertion(+) create mode 100644 newsfragments/223.feature diff --git a/newsfragments/223.feature b/newsfragments/223.feature new file mode 100644 index 00000000..f48a5e9a --- /dev/null +++ b/newsfragments/223.feature @@ -0,0 +1 @@ +Add marshmallow support for internal models. From ed845c170dcf396135f80a82e771fae8c78e0f42 Mon Sep 17 00:00:00 2001 From: Adithya Balaji Date: Sun, 12 Apr 2020 03:04:25 -0400 Subject: [PATCH 8/9] Activate poetry venv. --- .github/workflows/main.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index dc671920..be9b7d90 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -38,6 +38,9 @@ jobs: run: | poetry install -vvv if: steps.cache-poetry.outputs.cache-hit != 'true' + - name: Activate poetry env + run: | + source "$( poetry env info --path )/bin/activate" - uses: actions/cache@v1 with: path: ~/.cache/pre-commit From 334ca9c5e60971978c1ec0530d5a1080a872e512 Mon Sep 17 00:00:00 2001 From: Adithya Balaji Date: Sun, 12 Apr 2020 03:23:55 -0400 Subject: [PATCH 9/9] Edit linting process to install dependencies to system so that mypy can make use of them. --- .github/workflows/main.yml | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index be9b7d90..06824925 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -21,26 +21,21 @@ jobs: curl -sSL https://raw.githubusercontent.com/sdispater/poetry/master/get-poetry.py | python - name: Add Poetry to Path Unix run: echo "::add-path::$HOME/.poetry/bin" - if: (${{ runner.os }} == "Linux") || (${{ runner.os }} == "macOS") - name: Configure Poetry run: | - poetry config virtualenvs.in-project false - poetry config virtualenvs.path ~/.virtualenvs - - name: Cache Poetry virtualenv + poetry config virtualenvs.create false + - name: Cache pip uses: actions/cache@v1 id: cache-poetry with: - path: ~/.virtualenvs - key: poetry|pre-commit|${{ matrix.os }}|${{ env.PY }}|${{ hashFiles('poetry.lock') }} + path: ~/.cache/pip + key: ${{ matrix.os }}|${{ env.PY }}|poetry|pre-commit|${{ hashFiles('poetry.lock') }} restore-keys: | - poetry-${{ hashFiles('poetry.lock') }} + ${{ matrix.os }}|${{ env.PY }}|poetry|pre-commit| - name: Install Project Dependencies (Poetry) run: | poetry install -vvv if: steps.cache-poetry.outputs.cache-hit != 'true' - - name: Activate poetry env - run: | - source "$( poetry env info --path )/bin/activate" - uses: actions/cache@v1 with: path: ~/.cache/pre-commit @@ -81,7 +76,7 @@ jobs: path: ~/.virtualenvs key: poetry|v2|${{ matrix.os }}|${{ env.PY }}|${{ hashFiles('poetry.lock') }} restore-keys: | - poetry-${{ hashFiles('poetry.lock') }} + poetry|v2|${{ matrix.os }}|${{ env.PY }}| - name: Install Project Dependencies (Poetry) run: | poetry install -vvv