diff --git a/.github/workflows/python-main.yml b/.github/workflows/python-main.yml index 8543156..ad52e2f 100644 --- a/.github/workflows/python-main.yml +++ b/.github/workflows/python-main.yml @@ -13,7 +13,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ['3.8'] + python-version: ['3.9'] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} @@ -21,7 +21,7 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Install tox and any other packages - run: pip install tox + run: pip install tox tox-uv - name: Run tox # Run tox using tox.ini run: tox -c tox.ini diff --git a/.github/workflows/python-pull-request.yml b/.github/workflows/python-pull-request.yml index d9ffc6a..c9a046e 100644 --- a/.github/workflows/python-pull-request.yml +++ b/.github/workflows/python-pull-request.yml @@ -14,7 +14,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ['3.8'] + python-version: ['3.9'] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} @@ -22,7 +22,7 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Install tox and any other packages - run: pip install tox + run: pip install tox tox-uv - name: Run tox # Run tox using tox.ini run: tox -c tox.ini diff --git a/Makefile b/Makefile index 6cec10d..08ed07f 100644 --- a/Makefile +++ b/Makefile @@ -1,13 +1,13 @@ -pip-compile: requirements.in test-requirements.in nb-requirements.in dev-requirements.in ray-env-requirements.in rendering-requirements.in - pip-compile --no-emit-index-url --no-emit-options --no-emit-find-links requirements.in - pip-compile --no-emit-index-url --no-emit-options --no-emit-find-links test-requirements.in - pip-compile --no-emit-index-url --no-emit-options --no-emit-find-links nb-requirements.in - pip-compile --no-emit-index-url --no-emit-options --no-emit-find-links dev-requirements.in - pip-compile --no-emit-index-url --no-emit-options --no-emit-find-links ray-env-requirements.in 
--unsafe-package ray - pip-compile --no-emit-index-url --no-emit-options --no-emit-find-links rendering-requirements.in +pip-compile: requirements.in test-requirements.in nb-requirements.in dev-requirements.in ray-env-requirements.in rendering-requirements.in + uv pip compile --no-emit-index-url --no-emit-find-links requirements.in -o requirements.txt + uv pip compile --no-emit-index-url --no-emit-find-links test-requirements.in -o test-requirements.txt + uv pip compile --no-emit-index-url --no-emit-find-links nb-requirements.in -o nb-requirements.txt + uv pip compile --no-emit-index-url --no-emit-find-links dev-requirements.in -o dev-requirements.txt + uv pip compile --no-emit-index-url --no-emit-find-links ray-env-requirements.in -o ray-env-requirements.txt --unsafe-package ray + uv pip compile --no-emit-index-url --no-emit-find-links rendering-requirements.in -o rendering-requirements.txt pip-install: pip-compile - pip install -r dev-requirements.txt -e . + uv pip install -r dev-requirements.txt -e . format: ruff format src tests --line-length 88 diff --git a/README.md b/README.md index f77cc69..85ed403 100644 --- a/README.md +++ b/README.md @@ -32,8 +32,8 @@ For specific snapshots of code submitted to conferences: ## Dev Env First, make sure the following python development tools are installed: - - pip - - pip-tools(==7.3.0) + - [uv](https://docs.astral.sh/uv/getting-started/installation/) + - [ruff](https://docs.astral.sh/ruff/installation/) Then, in a virtual environment, run pip-compile and install: @@ -44,6 +44,12 @@ $ make pip-install These should install all the requirements dependencies for development. +For building, install tox and tox-uv + +``` +$ pip install tox tox-uv +``` + ## Dependencies The dependecy files map to a purpose as follows: @@ -54,4 +60,4 @@ The dependecy files map to a purpose as follows: - [rendering-requirements.in](rendering-requirements.in): for environments can be rendered in a graphical interface, with OpenGL. 
- [ray-env-requirements.in](ray-env-requirements.in): for ray in a cluster environment. During compilation with `pip-compile`, it's best to exclude the version of ray (see [Makefile](Makefile)). -All requirements files are compiled using `pip-compile`. \ No newline at end of file +All requirements files are compiled using `uv`. \ No newline at end of file diff --git a/dev-requirements.txt b/dev-requirements.txt index ac69c5a..490e08c 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,9 +1,5 @@ -# -# This file is autogenerated by pip-compile with Python 3.9 -# by the following command: -# -# pip-compile --no-emit-find-links --no-emit-index-url --no-emit-options dev-requirements.in -# +# This file was autogenerated by uv via the following command: +# uv pip compile --no-emit-index-url --no-emit-find-links dev-requirements.in -o dev-requirements.txt absl-py==2.0.0 # via # -r nb-requirements.txt @@ -17,13 +13,11 @@ aiohttp==3.8.4 # -r requirements.txt # -r test-requirements.txt # aiohttp-cors - # ray aiohttp-cors==0.7.0 # via # -r nb-requirements.txt # -r requirements.txt # -r test-requirements.txt - # ray aiosignal==1.3.1 # via # -r nb-requirements.txt @@ -31,16 +25,14 @@ aiosignal==1.3.1 # -r test-requirements.txt # aiohttp # ray -anyio==3.6.2 +anyio==4.4.0 # via # -r nb-requirements.txt # jupyter-server -argon2-cffi==21.3.0 +argon2-cffi==23.1.0 # via # -r nb-requirements.txt # jupyter-server - # nbclassic - # notebook argon2-cffi-bindings==21.2.0 # via # -r nb-requirements.txt @@ -49,7 +41,7 @@ arrow==1.3.0 # via # -r nb-requirements.txt # isoduration -asttokens==2.2.1 +asttokens==2.4.1 # via # -r nb-requirements.txt # stack-data @@ -77,28 +69,18 @@ attrs==22.2.0 # aiohttp # hypothesis # jsonschema -babel==2.15.0 +babel==2.16.0 # via # -r nb-requirements.txt # jupyterlab-server -backcall==0.2.0 - # via - # -r nb-requirements.txt - # ipython -beautifulsoup4==4.11.2 +beautifulsoup4==4.12.3 # via # -r nb-requirements.txt # nbconvert -bleach==6.0.0 
+bleach==6.1.0 # via # -r nb-requirements.txt # nbconvert -blessed==1.20.0 - # via - # -r nb-requirements.txt - # -r requirements.txt - # -r test-requirements.txt - # gpustat bumpversion==0.5.3 # via -r dev-requirements.in cachetools==5.3.0 @@ -113,7 +95,7 @@ certifi==2022.12.7 # -r requirements.txt # -r test-requirements.txt # requests -cffi==1.15.1 +cffi==1.17.1 # via # -r nb-requirements.txt # argon2-cffi-bindings @@ -124,6 +106,12 @@ charset-normalizer==3.1.0 # -r test-requirements.txt # aiohttp # requests +clarabel==0.9.0 + # via + # -r nb-requirements.txt + # -r requirements.txt + # -r test-requirements.txt + # cvxpy click==8.0.4 # via # -r nb-requirements.txt @@ -141,24 +129,29 @@ colorful==0.5.5 # -r nb-requirements.txt # -r requirements.txt # -r test-requirements.txt - # ray -comm==0.1.2 +comm==0.2.2 # via # -r nb-requirements.txt # ipykernel -contourpy==1.0.7 + # ipywidgets +contourpy==1.3.0 # via # -r nb-requirements.txt # matplotlib -coverage[toml]==7.3.2 +coverage==7.3.2 # via # -r test-requirements.txt # pytest-cov -cycler==0.11.0 +cvxpy==1.5.3 + # via + # -r nb-requirements.txt + # -r requirements.txt + # -r test-requirements.txt +cycler==0.12.1 # via # -r nb-requirements.txt # matplotlib -debugpy==1.6.6 +debugpy==1.8.5 # via # -r nb-requirements.txt # ipykernel @@ -176,12 +169,21 @@ distlib==0.3.6 # -r requirements.txt # -r test-requirements.txt # virtualenv +ecos==2.0.14 + # via + # -r nb-requirements.txt + # -r requirements.txt + # -r test-requirements.txt + # cvxpy exceptiongroup==1.2.0 # via + # -r nb-requirements.txt # -r test-requirements.txt + # anyio # hypothesis + # ipython # pytest -executing==1.2.0 +executing==2.1.0 # via # -r nb-requirements.txt # stack-data @@ -191,7 +193,7 @@ farama-notifications==0.0.4 # -r requirements.txt # -r test-requirements.txt # gymnasium -fastjsonschema==2.16.3 +fastjsonschema==2.20.0 # via # -r nb-requirements.txt # nbformat @@ -208,7 +210,7 @@ flatbuffers==23.5.26 # -r requirements.txt # -r 
test-requirements.txt # tensorflow -fonttools==4.39.0 +fonttools==4.53.1 # via # -r nb-requirements.txt # matplotlib @@ -229,7 +231,6 @@ fsspec==2024.2.0 # -r nb-requirements.txt # -r requirements.txt # -r test-requirements.txt - # ray gast==0.4.0 # via # -r nb-requirements.txt @@ -268,18 +269,11 @@ googleapis-common-protos==1.58.0 # -r requirements.txt # -r test-requirements.txt # google-api-core -gpustat==1.0.0 - # via - # -r nb-requirements.txt - # -r requirements.txt - # -r test-requirements.txt - # ray grpcio==1.51.3 # via # -r nb-requirements.txt # -r requirements.txt # -r test-requirements.txt - # ray # tensorboard # tensorflow gymnasium==0.28.1 @@ -317,7 +311,7 @@ importlib-metadata==6.0.0 # jupyterlab-server # markdown # nbconvert -importlib-resources==5.12.0 +importlib-resources==6.4.5 # via # -r nb-requirements.txt # matplotlib @@ -325,21 +319,17 @@ iniconfig==2.0.0 # via # -r test-requirements.txt # pytest -ipykernel==6.21.3 +ipykernel==6.29.5 # via # -r nb-requirements.txt # jupyterlab - # nbclassic - # notebook -ipython==8.11.0 +ipython==8.18.1 # via # -r nb-requirements.txt # ipykernel -ipython-genutils==0.2.0 - # via - # -r nb-requirements.txt - # nbclassic - # notebook + # ipywidgets +ipywidgets==8.1.5 + # via -r nb-requirements.txt isoduration==20.11.0 # via # -r nb-requirements.txt @@ -350,28 +340,29 @@ jax-jumpy==1.0.0 # -r requirements.txt # -r test-requirements.txt # gymnasium -jedi==0.18.2 +jedi==0.19.1 # via # -r nb-requirements.txt # ipython -jinja2==3.1.2 +jinja2==3.1.4 # via # -r nb-requirements.txt + # -r requirements.txt + # -r test-requirements.txt # jupyter-server # jupyterlab # jupyterlab-server - # nbclassic + # memray # nbconvert - # notebook json5==0.9.25 # via # -r nb-requirements.txt # jupyterlab-server -jsonpointer==2.4 +jsonpointer==3.0.0 # via # -r nb-requirements.txt # jsonschema -jsonschema[format-nongpl]==4.17.3 +jsonschema==4.17.3 # via # -r nb-requirements.txt # -r requirements.txt @@ -380,19 +371,17 @@ 
jsonschema[format-nongpl]==4.17.3 # jupyterlab-server # nbformat # ray -jupyter-client==8.0.3 +jupyter-client==8.6.2 # via # -r nb-requirements.txt # ipykernel # jupyter-server - # nbclassic # nbclient - # notebook jupyter-contrib-core==0.4.2 # via # -r nb-requirements.txt # jupyter-nbextensions-configurator -jupyter-core==5.2.0 +jupyter-core==5.7.2 # via # -r nb-requirements.txt # ipykernel @@ -401,11 +390,9 @@ jupyter-core==5.2.0 # jupyter-nbextensions-configurator # jupyter-server # jupyterlab - # nbclassic # nbclient # nbconvert # nbformat - # notebook jupyter-events==0.6.3 # via # -r nb-requirements.txt @@ -418,22 +405,24 @@ jupyter-nbextensions-configurator==0.6.3 # via -r nb-requirements.txt jupyter-resource-usage==0.7.2 # via -r nb-requirements.txt -jupyter-server==2.7.0 +jupyter-server==2.10.0 # via # -r nb-requirements.txt # jupyter-lsp # jupyter-resource-usage # jupyterlab # jupyterlab-server - # nbclassic + # notebook # notebook-shim jupyter-server-terminals==0.5.3 # via # -r nb-requirements.txt # jupyter-server jupyterlab==4.0.13 - # via -r nb-requirements.txt -jupyterlab-pygments==0.2.2 + # via + # -r nb-requirements.txt + # notebook +jupyterlab-pygments==0.3.0 # via # -r nb-requirements.txt # nbconvert @@ -441,13 +430,18 @@ jupyterlab-server==2.24.0 # via # -r nb-requirements.txt # jupyterlab + # notebook +jupyterlab-widgets==3.0.13 + # via + # -r nb-requirements.txt + # ipywidgets keras==2.13.1 # via # -r nb-requirements.txt # -r requirements.txt # -r test-requirements.txt # tensorflow -kiwisolver==1.4.4 +kiwisolver==1.4.7 # via # -r nb-requirements.txt # matplotlib @@ -457,12 +451,26 @@ libclang==16.0.6 # -r requirements.txt # -r test-requirements.txt # tensorflow +linkify-it-py==2.0.3 + # via + # -r nb-requirements.txt + # -r requirements.txt + # -r test-requirements.txt + # markdown-it-py markdown==3.5.1 # via # -r nb-requirements.txt # -r requirements.txt # -r test-requirements.txt # tensorboard +markdown-it-py==3.0.0 + # via + # -r 
nb-requirements.txt + # -r requirements.txt + # -r test-requirements.txt + # mdit-py-plugins + # rich + # textual markupsafe==2.1.3 # via # -r nb-requirements.txt @@ -471,16 +479,33 @@ markupsafe==2.1.3 # jinja2 # nbconvert # werkzeug -matplotlib==3.7.1 +matplotlib==3.9.2 # via # -r nb-requirements.txt # seaborn -matplotlib-inline==0.1.6 +matplotlib-inline==0.1.7 # via # -r nb-requirements.txt # ipykernel # ipython -mistune==2.0.5 +mdit-py-plugins==0.4.2 + # via + # -r nb-requirements.txt + # -r requirements.txt + # -r test-requirements.txt + # markdown-it-py +mdurl==0.1.2 + # via + # -r nb-requirements.txt + # -r requirements.txt + # -r test-requirements.txt + # markdown-it-py +memray==1.14.0 + # via + # -r nb-requirements.txt + # -r requirements.txt + # -r test-requirements.txt +mistune==3.0.2 # via # -r nb-requirements.txt # nbconvert @@ -497,71 +522,60 @@ multidict==6.0.4 # -r test-requirements.txt # aiohttp # yarl -nbclassic==0.5.3 - # via - # -r nb-requirements.txt - # notebook -nbclient==0.7.2 +nbclient==0.10.0 # via # -r nb-requirements.txt # nbconvert -nbconvert==7.2.10 +nbconvert==7.16.4 # via # -r nb-requirements.txt # jupyter-server - # nbclassic - # notebook -nbformat==5.7.3 +nbformat==5.10.4 # via # -r nb-requirements.txt # jupyter-server - # nbclassic # nbclient # nbconvert - # notebook -nest-asyncio==1.5.6 +nest-asyncio==1.6.0 # via # -r nb-requirements.txt # ipykernel - # nbclassic - # notebook -notebook==6.5.4 +notebook==7.0.8 # via # -r nb-requirements.txt # jupyter-contrib-core # jupyter-nbextensions-configurator -notebook-shim==0.2.2 +notebook-shim==0.2.4 # via # -r nb-requirements.txt # jupyterlab - # nbclassic + # notebook numpy==1.23.5 # via # -r nb-requirements.txt # -r requirements.txt # -r test-requirements.txt + # clarabel # contourpy + # cvxpy + # ecos # gymnasium # h5py # jax-jumpy # matplotlib # opt-einsum + # osqp # pandas # patsy # pyarrow - # ray + # qdldl # rlplg # scipy + # scs # seaborn # statsmodels # tensorboard # tensorflow 
-nvidia-ml-py==11.495.46 - # via - # -r nb-requirements.txt - # -r requirements.txt - # -r test-requirements.txt - # gpustat oauthlib==3.2.2 # via # -r nb-requirements.txt @@ -573,7 +587,6 @@ opencensus==0.11.2 # -r nb-requirements.txt # -r requirements.txt # -r test-requirements.txt - # ray opencensus-context==0.1.3 # via # -r nb-requirements.txt @@ -586,6 +599,12 @@ opt-einsum==3.3.0 # -r requirements.txt # -r test-requirements.txt # tensorflow +osqp==0.6.7.post1 + # via + # -r nb-requirements.txt + # -r requirements.txt + # -r test-requirements.txt + # cvxpy overrides==7.7.0 # via # -r nb-requirements.txt @@ -610,14 +629,13 @@ pandas==2.0.3 # -r nb-requirements.txt # -r requirements.txt # -r test-requirements.txt - # ray # seaborn # statsmodels -pandocfilters==1.5.0 +pandocfilters==1.5.1 # via # -r nb-requirements.txt # nbconvert -parso==0.8.3 +parso==0.8.4 # via # -r nb-requirements.txt # jedi @@ -625,11 +643,7 @@ patsy==0.5.6 # via # -r nb-requirements.txt # statsmodels -pexpect==4.8.0 - # via - # -r nb-requirements.txt - # ipython -pickleshare==0.7.5 +pexpect==4.9.0 # via # -r nb-requirements.txt # ipython @@ -658,10 +672,7 @@ prometheus-client==0.13.1 # -r test-requirements.txt # jupyter-resource-usage # jupyter-server - # nbclassic - # notebook - # ray -prompt-toolkit==3.0.38 +prompt-toolkit==3.0.47 # via # -r nb-requirements.txt # ipython @@ -675,12 +686,9 @@ protobuf==4.25.1 # ray # tensorboard # tensorflow -psutil==5.9.4 +psutil==5.9.8 # via # -r nb-requirements.txt - # -r requirements.txt - # -r test-requirements.txt - # gpustat # ipykernel # jupyter-resource-usage ptyprocess==0.7.0 @@ -688,7 +696,7 @@ ptyprocess==0.7.0 # -r nb-requirements.txt # pexpect # terminado -pure-eval==0.2.2 +pure-eval==0.2.3 # via # -r nb-requirements.txt # stack-data @@ -697,13 +705,11 @@ py-spy==0.3.14 # -r nb-requirements.txt # -r requirements.txt # -r test-requirements.txt - # ray pyarrow==15.0.0 # via # -r nb-requirements.txt # -r requirements.txt # -r 
test-requirements.txt - # ray pyasn1==0.4.8 # via # -r nb-requirements.txt @@ -717,7 +723,7 @@ pyasn1-modules==0.2.8 # -r requirements.txt # -r test-requirements.txt # google-auth -pycparser==2.21 +pycparser==2.22 # via # -r nb-requirements.txt # cffi @@ -726,13 +732,15 @@ pydantic==1.10.6 # -r nb-requirements.txt # -r requirements.txt # -r test-requirements.txt - # ray -pygments==2.14.0 +pygments==2.18.0 # via # -r nb-requirements.txt + # -r requirements.txt + # -r test-requirements.txt # ipython # nbconvert -pyparsing==3.0.9 + # rich +pyparsing==3.1.4 # via # -r nb-requirements.txt # matplotlib @@ -775,16 +783,20 @@ pyyaml==6.0 # jupyter-events # jupyter-nbextensions-configurator # ray -pyzmq==25.0.1 +pyzmq==26.2.0 # via # -r nb-requirements.txt # ipykernel # jupyter-client # jupyter-resource-usage # jupyter-server - # nbclassic - # notebook -ray[data,default]==2.9.3 +qdldl==0.1.7.post4 + # via + # -r nb-requirements.txt + # -r requirements.txt + # -r test-requirements.txt + # osqp +ray==2.23.0 # via # -r nb-requirements.txt # -r requirements.txt @@ -815,7 +827,14 @@ rfc3986-validator==0.1.1 # -r nb-requirements.txt # jsonschema # jupyter-events -rlplg @ git+https://github.com/guidj/rlplg.git@v0.19.10 +rich==13.8.1 + # via + # -r nb-requirements.txt + # -r requirements.txt + # -r test-requirements.txt + # memray + # textual +rlplg @ git+https://github.com/guidj/rlplg.git@93d819cab2eeefdd9bbd0ec7acb5d1dd0e5d1bed # via # -r nb-requirements.txt # -r requirements.txt @@ -833,16 +852,34 @@ scipy==1.10.1 # -r nb-requirements.txt # -r requirements.txt # -r test-requirements.txt + # clarabel + # cvxpy + # ecos + # osqp + # qdldl # rlplg + # scs # statsmodels +scs==3.2.7 + # via + # -r nb-requirements.txt + # -r requirements.txt + # -r test-requirements.txt + # cvxpy seaborn==0.13.1 # via -r nb-requirements.txt -send2trash==1.8.0 +send2trash==1.8.3 # via # -r nb-requirements.txt # jupyter-server - # nbclassic - # notebook +setuptools==74.1.2 + # via + # -r 
nb-requirements.txt + # -r requirements.txt + # -r test-requirements.txt + # jupyter-contrib-core + # tensorboard + # tensorflow six==1.16.0 # via # -r nb-requirements.txt @@ -851,10 +888,8 @@ six==1.16.0 # asttokens # astunparse # bleach - # blessed # google-auth # google-pasta - # gpustat # patsy # python-dateutil # rfc3339-validator @@ -864,8 +899,7 @@ smart-open==6.3.0 # -r nb-requirements.txt # -r requirements.txt # -r test-requirements.txt - # ray -sniffio==1.3.0 +sniffio==1.3.1 # via # -r nb-requirements.txt # anyio @@ -873,11 +907,11 @@ sortedcontainers==2.4.0 # via # -r test-requirements.txt # hypothesis -soupsieve==2.4 +soupsieve==2.6 # via # -r nb-requirements.txt # beautifulsoup4 -stack-data==0.6.2 +stack-data==0.6.3 # via # -r nb-requirements.txt # ipython @@ -918,14 +952,18 @@ termcolor==2.3.0 # -r requirements.txt # -r test-requirements.txt # tensorflow -terminado==0.17.1 +terminado==0.18.1 # via # -r nb-requirements.txt # jupyter-server # jupyter-server-terminals - # nbclassic - # notebook -tinycss2==1.2.1 +textual==0.78.0 + # via + # -r nb-requirements.txt + # -r requirements.txt + # -r test-requirements.txt + # memray +tinycss2==1.3.0 # via # -r nb-requirements.txt # nbconvert @@ -936,7 +974,7 @@ tomli==2.0.1 # coverage # jupyterlab # pytest -tornado==6.2 +tornado==6.4.1 # via # -r nb-requirements.txt # ipykernel @@ -945,15 +983,17 @@ tornado==6.2 # jupyter-nbextensions-configurator # jupyter-server # jupyterlab - # nbclassic # notebook # terminado -traitlets==5.9.0 +tqdm==4.66.5 + # via -r nb-requirements.txt +traitlets==5.14.3 # via # -r nb-requirements.txt # comm # ipykernel # ipython + # ipywidgets # jupyter-client # jupyter-contrib-core # jupyter-core @@ -962,12 +1002,10 @@ traitlets==5.9.0 # jupyter-server # jupyterlab # matplotlib-inline - # nbclassic # nbclient # nbconvert # nbformat - # notebook -types-python-dateutil==2.8.19.14 +types-python-dateutil==2.9.0.20240906 # via # -r nb-requirements.txt # arrow @@ -976,16 +1014,25 @@ 
typing-extensions==4.5.0 # -r nb-requirements.txt # -r requirements.txt # -r test-requirements.txt + # anyio # async-lru # gymnasium + # ipython # pydantic # tensorflow + # textual tzdata==2024.1 # via # -r nb-requirements.txt # -r requirements.txt # -r test-requirements.txt # pandas +uc-micro-py==1.0.3 + # via + # -r nb-requirements.txt + # -r requirements.txt + # -r test-requirements.txt + # linkify-it-py uri-template==1.3.0 # via # -r nb-requirements.txt @@ -1001,15 +1048,11 @@ virtualenv==20.21.0 # -r nb-requirements.txt # -r requirements.txt # -r test-requirements.txt - # ray -wcwidth==0.2.6 +wcwidth==0.2.13 # via # -r nb-requirements.txt - # -r requirements.txt - # -r test-requirements.txt - # blessed # prompt-toolkit -webcolors==1.13 +webcolors==24.8.0 # via # -r nb-requirements.txt # jsonschema @@ -1018,7 +1061,7 @@ webencodings==0.5.1 # -r nb-requirements.txt # bleach # tinycss2 -websocket-client==1.5.1 +websocket-client==1.8.0 # via # -r nb-requirements.txt # jupyter-server @@ -1035,6 +1078,10 @@ wheel==0.42.0 # -r test-requirements.txt # astunparse # tensorboard +widgetsnbextension==4.0.13 + # via + # -r nb-requirements.txt + # ipywidgets wrapt==1.16.0 # via # -r nb-requirements.txt @@ -1054,6 +1101,3 @@ zipp==3.15.0 # -r test-requirements.txt # importlib-metadata # importlib-resources - -# The following packages are considered to be unsafe in a requirements file: -# setuptools diff --git a/experiments/policycontrol/experiments.csv b/experiments/policycontrol/experiments.csv index 2357d97..da87ddc 100644 --- a/experiments/policycontrol/experiments.csv +++ b/experiments/policycontrol/experiments.csv @@ -1,159 +1,159 @@ "algorithm","policy_type","traj_mapping_method","reward_period","drop_truncated_feedback_episodes","discount_factor","learning_rate","epsilon","algorithm_args" -"sarsa","single-step","identity-mapper","1",false,1.0,0.1,0.2,"{}" -"sarsa","single-step","identity-mapper","1",false,0.99,0.1,0.2,"{}" 
-"sarsa","single-step","daaf-trajectory-mapper","2",false,1.0,0.1,0.2,"{}" -"sarsa","single-step","daaf-trajectory-mapper","4",false,1.0,0.1,0.2,"{}" -"sarsa","single-step","daaf-trajectory-mapper","6",false,1.0,0.1,0.2,"{}" -"sarsa","single-step","daaf-trajectory-mapper","8",false,1.0,0.1,0.2,"{}" -"sarsa","single-step","daaf-trajectory-mapper","16",false,1.0,0.1,0.2,"{}" -"sarsa","single-step","daaf-trajectory-mapper","2",false,0.99,0.1,0.2,"{}" -"sarsa","single-step","daaf-trajectory-mapper","4",false,0.99,0.1,0.2,"{}" -"sarsa","single-step","daaf-trajectory-mapper","6",false,0.99,0.1,0.2,"{}" -"sarsa","single-step","daaf-trajectory-mapper","8",false,0.99,0.1,0.2,"{}" -"sarsa","single-step","daaf-trajectory-mapper","16",false,0.99,0.1,0.2,"{}" -"sarsa","single-step","daaf-impute-missing-reward-mapper","2",false,1.0,0.1,0.2,"{}" -"sarsa","single-step","daaf-impute-missing-reward-mapper","2",false,0.99,0.1,0.2,"{}" -"sarsa","single-step","daaf-impute-missing-reward-mapper","4",false,1.0,0.1,0.2,"{}" -"sarsa","single-step","daaf-impute-missing-reward-mapper","4",false,0.99,0.1,0.2,"{}" -"sarsa","single-step","daaf-impute-missing-reward-mapper","6",false,1.0,0.1,0.2,"{}" -"sarsa","single-step","daaf-impute-missing-reward-mapper","6",false,0.99,0.1,0.2,"{}" -"sarsa","single-step","daaf-impute-missing-reward-mapper","8",false,1.0,0.1,0.2,"{}" -"sarsa","single-step","daaf-impute-missing-reward-mapper","8",false,0.99,0.1,0.2,"{}" -"sarsa","single-step","daaf-impute-missing-reward-mapper","16",false,1.0,0.1,0.2,"{}" -"sarsa","single-step","daaf-impute-missing-reward-mapper","16",false,0.99,0.1,0.2,"{}" -"sarsa","options","daaf-mdp-with-options-mapper","2",false,1.0,0.1,0.2,"{}" -"sarsa","options","daaf-mdp-with-options-mapper","2",false,0.99,0.1,0.2,"{}" -"sarsa","options","daaf-mdp-with-options-mapper","4",false,1.0,0.1,0.2,"{}" -"sarsa","options","daaf-mdp-with-options-mapper","4",false,0.99,0.1,0.2,"{}" 
-"sarsa","options","daaf-mdp-with-options-mapper","6",false,1.0,0.1,0.2,"{}" -"sarsa","options","daaf-mdp-with-options-mapper","6",false,0.99,0.1,0.2,"{}" -"sarsa","single-step","daaf-lsq-reward-attribution-mapper","2",false,1.0,0.1,0.2,"{}" -"sarsa","single-step","daaf-lsq-reward-attribution-mapper","2",false,0.99,0.1,0.2,"{}" -"sarsa","single-step","daaf-lsq-reward-attribution-mapper","4",false,1.0,0.1,0.2,"{}" -"sarsa","single-step","daaf-lsq-reward-attribution-mapper","4",false,0.99,0.1,0.2,"{}" -"sarsa","single-step","daaf-lsq-reward-attribution-mapper","6",false,1.0,0.1,0.2,"{}" -"sarsa","single-step","daaf-lsq-reward-attribution-mapper","6",false,0.99,0.1,0.2,"{}" -"sarsa","single-step","daaf-lsq-reward-attribution-mapper","8",false,1.0,0.1,0.2,"{}" -"sarsa","single-step","daaf-lsq-reward-attribution-mapper","8",false,0.99,0.1,0.2,"{}" -"sarsa","single-step","daaf-lsq-reward-attribution-mapper","16",false,1.0,0.1,0.2,"{}" -"sarsa","single-step","daaf-lsq-reward-attribution-mapper","16",false,0.99,0.1,0.2,"{}" -"q-learning","single-step","identity-mapper","1",false,1.0,0.1,0.2,"{}" -"q-learning","single-step","identity-mapper","1",false,0.99,0.1,0.2,"{}" -"q-learning","single-step","daaf-trajectory-mapper","2",false,1.0,0.1,0.2,"{}" -"q-learning","single-step","daaf-trajectory-mapper","4",false,1.0,0.1,0.2,"{}" -"q-learning","single-step","daaf-trajectory-mapper","6",false,1.0,0.1,0.2,"{}" -"q-learning","single-step","daaf-trajectory-mapper","8",false,1.0,0.1,0.2,"{}" -"q-learning","single-step","daaf-trajectory-mapper","16",false,1.0,0.1,0.2,"{}" -"q-learning","single-step","daaf-trajectory-mapper","2",false,0.99,0.1,0.2,"{}" -"q-learning","single-step","daaf-trajectory-mapper","4",false,0.99,0.1,0.2,"{}" -"q-learning","single-step","daaf-trajectory-mapper","6",false,0.99,0.1,0.2,"{}" -"q-learning","single-step","daaf-trajectory-mapper","8",false,0.99,0.1,0.2,"{}" -"q-learning","single-step","daaf-trajectory-mapper","16",false,0.99,0.1,0.2,"{}" 
-"q-learning","single-step","daaf-impute-missing-reward-mapper","2",false,1.0,0.1,0.2,"{}" -"q-learning","single-step","daaf-impute-missing-reward-mapper","2",false,0.99,0.1,0.2,"{}" -"q-learning","single-step","daaf-impute-missing-reward-mapper","4",false,1.0,0.1,0.2,"{}" -"q-learning","single-step","daaf-impute-missing-reward-mapper","4",false,0.99,0.1,0.2,"{}" -"q-learning","single-step","daaf-impute-missing-reward-mapper","6",false,1.0,0.1,0.2,"{}" -"q-learning","single-step","daaf-impute-missing-reward-mapper","6",false,0.99,0.1,0.2,"{}" -"q-learning","single-step","daaf-impute-missing-reward-mapper","8",false,1.0,0.1,0.2,"{}" -"q-learning","single-step","daaf-impute-missing-reward-mapper","8",false,0.99,0.1,0.2,"{}" -"q-learning","single-step","daaf-impute-missing-reward-mapper","16",false,1.0,0.1,0.2,"{}" -"q-learning","single-step","daaf-impute-missing-reward-mapper","16",false,0.99,0.1,0.2,"{}" -"q-learning","options","daaf-mdp-with-options-mapper","2",false,1.0,0.1,0.2,"{}" -"q-learning","options","daaf-mdp-with-options-mapper","2",false,0.99,0.1,0.2,"{}" -"q-learning","options","daaf-mdp-with-options-mapper","4",false,1.0,0.1,0.2,"{}" -"q-learning","options","daaf-mdp-with-options-mapper","4",false,0.99,0.1,0.2,"{}" -"q-learning","options","daaf-mdp-with-options-mapper","6",false,1.0,0.1,0.2,"{}" -"q-learning","options","daaf-mdp-with-options-mapper","6",false,0.99,0.1,0.2,"{}" -"q-learning","single-step","daaf-lsq-reward-attribution-mapper","2",false,1.0,0.1,0.2,"{}" -"q-learning","single-step","daaf-lsq-reward-attribution-mapper","2",false,0.99,0.1,0.2,"{}" -"q-learning","single-step","daaf-lsq-reward-attribution-mapper","4",false,1.0,0.1,0.2,"{}" -"q-learning","single-step","daaf-lsq-reward-attribution-mapper","4",false,0.99,0.1,0.2,"{}" -"q-learning","single-step","daaf-lsq-reward-attribution-mapper","6",false,1.0,0.1,0.2,"{}" -"q-learning","single-step","daaf-lsq-reward-attribution-mapper","6",false,0.99,0.1,0.2,"{}" 
-"q-learning","single-step","daaf-lsq-reward-attribution-mapper","8",false,1.0,0.1,0.2,"{}" -"q-learning","single-step","daaf-lsq-reward-attribution-mapper","8",false,0.99,0.1,0.2,"{}" -"q-learning","single-step","daaf-lsq-reward-attribution-mapper","16",false,1.0,0.1,0.2,"{}" -"q-learning","single-step","daaf-lsq-reward-attribution-mapper","16",false,0.99,0.1,0.2,"{}" -"nstep-sarsa","single-step","identity-mapper","1",false,1.0,0.1,0.2,"{""nstep"":1}" -"nstep-sarsa","single-step","identity-mapper","1",false,0.99,0.1,0.2,"{""nstep"":1}" -"nstep-sarsa","single-step","identity-mapper","1",false,1.0,0.1,0.2,"{""nstep"":2}" -"nstep-sarsa","single-step","identity-mapper","1",false,0.99,0.1,0.2,"{""nstep"":2}" -"nstep-sarsa","single-step","identity-mapper","1",false,1.0,0.1,0.2,"{""nstep"":4}" -"nstep-sarsa","single-step","identity-mapper","1",false,0.99,0.1,0.2,"{""nstep"":4}" -"nstep-sarsa","single-step","identity-mapper","1",false,1.0,0.1,0.2,"{""nstep"":6}" -"nstep-sarsa","single-step","identity-mapper","1",false,0.99,0.1,0.2,"{""nstep"":6}" -"nstep-sarsa","single-step","identity-mapper","1",false,1.0,0.1,0.2,"{""nstep"":8}" -"nstep-sarsa","single-step","identity-mapper","1",false,0.99,0.1,0.2,"{""nstep"":8}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","2",false,1.0,0.1,0.2,"{""nstep"":2}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","2",false,0.99,0.1,0.2,"{""nstep"":2}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","2",false,1.0,0.1,0.2,"{""nstep"":4}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","2",false,0.99,0.1,0.2,"{""nstep"":4}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","2",false,1.0,0.1,0.2,"{""nstep"":6}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","2",false,0.99,0.1,0.2,"{""nstep"":6}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","2",false,1.0,0.1,0.2,"{""nstep"":8}" 
-"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","2",false,0.99,0.1,0.2,"{""nstep"":8}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","4",false,1.0,0.1,0.2,"{""nstep"":2}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","4",false,0.99,0.1,0.2,"{""nstep"":2}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","4",false,1.0,0.1,0.2,"{""nstep"":4}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","4",false,0.99,0.1,0.2,"{""nstep"":4}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","4",false,1.0,0.1,0.2,"{""nstep"":6}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","4",false,0.99,0.1,0.2,"{""nstep"":6}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","4",false,1.0,0.1,0.2,"{""nstep"":8}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","4",false,0.99,0.1,0.2,"{""nstep"":8}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","6",false,1.0,0.1,0.2,"{""nstep"":2}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","6",false,0.99,0.1,0.2,"{""nstep"":2}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","6",false,1.0,0.1,0.2,"{""nstep"":4}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","6",false,0.99,0.1,0.2,"{""nstep"":4}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","6",false,1.0,0.1,0.2,"{""nstep"":6}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","6",false,0.99,0.1,0.2,"{""nstep"":6}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","6",false,1.0,0.1,0.2,"{""nstep"":8}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","6",false,0.99,0.1,0.2,"{""nstep"":8}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","8",false,1.0,0.1,0.2,"{""nstep"":2}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","8",false,0.99,0.1,0.2,"{""nstep"":2}" 
-"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","8",false,1.0,0.1,0.2,"{""nstep"":4}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","8",false,0.99,0.1,0.2,"{""nstep"":4}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","8",false,1.0,0.1,0.2,"{""nstep"":6}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","8",false,0.99,0.1,0.2,"{""nstep"":6}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","8",false,1.0,0.1,0.2,"{""nstep"":8}" -"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","8",false,0.99,0.1,0.2,"{""nstep"":8}" -"nstep-sarsa","single-step","daaf-nstep-td-update-mark-mapper","2",false,1.0,0.1,0.2,"{""nstep"":2}" -"nstep-sarsa","single-step","daaf-nstep-td-update-mark-mapper","2",false,0.99,0.1,0.2,"{""nstep"":2}" -"nstep-sarsa","single-step","daaf-nstep-td-update-mark-mapper","4",false,1.0,0.1,0.2,"{""nstep"":4}" -"nstep-sarsa","single-step","daaf-nstep-td-update-mark-mapper","4",false,0.99,0.1,0.2,"{""nstep"":4}" -"nstep-sarsa","single-step","daaf-nstep-td-update-mark-mapper","6",false,1.0,0.1,0.2,"{""nstep"":6}" -"nstep-sarsa","single-step","daaf-nstep-td-update-mark-mapper","6",false,0.99,0.1,0.2,"{""nstep"":6}" -"nstep-sarsa","single-step","daaf-nstep-td-update-mark-mapper","8",false,1.0,0.1,0.2,"{""nstep"":8}" -"nstep-sarsa","single-step","daaf-nstep-td-update-mark-mapper","8",false,0.99,0.1,0.2,"{""nstep"":8}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","2",false,1.0,0.1,0.2,"{""nstep"":2}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","2",false,0.99,0.1,0.2,"{""nstep"":2}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","2",false,1.0,0.1,0.2,"{""nstep"":4}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","2",false,0.99,0.1,0.2,"{""nstep"":4}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","2",false,1.0,0.1,0.2,"{""nstep"":6}" 
-"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","2",false,0.99,0.1,0.2,"{""nstep"":6}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","2",false,1.0,0.1,0.2,"{""nstep"":8}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","2",false,0.99,0.1,0.2,"{""nstep"":8}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","4",false,1.0,0.1,0.2,"{""nstep"":2}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","4",false,0.99,0.1,0.2,"{""nstep"":2}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","4",false,1.0,0.1,0.2,"{""nstep"":4}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","4",false,0.99,0.1,0.2,"{""nstep"":4}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","4",false,1.0,0.1,0.2,"{""nstep"":6}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","4",false,0.99,0.1,0.2,"{""nstep"":6}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","4",false,1.0,0.1,0.2,"{""nstep"":8}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","4",false,0.99,0.1,0.2,"{""nstep"":8}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","6",false,1.0,0.1,0.2,"{""nstep"":2}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","6",false,0.99,0.1,0.2,"{""nstep"":2}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","6",false,1.0,0.1,0.2,"{""nstep"":4}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","6",false,0.99,0.1,0.2,"{""nstep"":4}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","6",false,1.0,0.1,0.2,"{""nstep"":6}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","6",false,0.99,0.1,0.2,"{""nstep"":6}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","6",false,1.0,0.1,0.2,"{""nstep"":8}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","6",false,0.99,0.1,0.2,"{""nstep"":8}" 
-"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","8",false,1.0,0.1,0.2,"{""nstep"":2}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","8",false,0.99,0.1,0.2,"{""nstep"":2}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","8",false,1.0,0.1,0.2,"{""nstep"":4}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","8",false,0.99,0.1,0.2,"{""nstep"":4}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","8",false,1.0,0.1,0.2,"{""nstep"":6}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","8",false,0.99,0.1,0.2,"{""nstep"":6}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","8",false,1.0,0.1,0.2,"{""nstep"":8}" -"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","8",false,0.99,0.1,0.2,"{""nstep"":8}" +"sarsa","single-step","identity-mapper","1",false,1.0,0.01,0.2,"{}" +"sarsa","single-step","identity-mapper","1",false,0.99,0.01,0.2,"{}" +"sarsa","single-step","daaf-trajectory-mapper","2",false,1.0,0.01,0.2,"{}" +"sarsa","single-step","daaf-trajectory-mapper","4",false,1.0,0.01,0.2,"{}" +"sarsa","single-step","daaf-trajectory-mapper","6",false,1.0,0.01,0.2,"{}" +"sarsa","single-step","daaf-trajectory-mapper","8",false,1.0,0.01,0.2,"{}" +"sarsa","single-step","daaf-trajectory-mapper","16",false,1.0,0.01,0.2,"{}" +"sarsa","single-step","daaf-trajectory-mapper","2",false,0.99,0.01,0.2,"{}" +"sarsa","single-step","daaf-trajectory-mapper","4",false,0.99,0.01,0.2,"{}" +"sarsa","single-step","daaf-trajectory-mapper","6",false,0.99,0.01,0.2,"{}" +"sarsa","single-step","daaf-trajectory-mapper","8",false,0.99,0.01,0.2,"{}" +"sarsa","single-step","daaf-trajectory-mapper","16",false,0.99,0.01,0.2,"{}" +"sarsa","single-step","daaf-impute-missing-reward-mapper","2",false,1.0,0.01,0.2,"{}" +"sarsa","single-step","daaf-impute-missing-reward-mapper","2",false,0.99,0.01,0.2,"{}" 
+"sarsa","single-step","daaf-impute-missing-reward-mapper","4",false,1.0,0.01,0.2,"{}" +"sarsa","single-step","daaf-impute-missing-reward-mapper","4",false,0.99,0.01,0.2,"{}" +"sarsa","single-step","daaf-impute-missing-reward-mapper","6",false,1.0,0.01,0.2,"{}" +"sarsa","single-step","daaf-impute-missing-reward-mapper","6",false,0.99,0.01,0.2,"{}" +"sarsa","single-step","daaf-impute-missing-reward-mapper","8",false,1.0,0.01,0.2,"{}" +"sarsa","single-step","daaf-impute-missing-reward-mapper","8",false,0.99,0.01,0.2,"{}" +"sarsa","single-step","daaf-impute-missing-reward-mapper","16",false,1.0,0.01,0.2,"{}" +"sarsa","single-step","daaf-impute-missing-reward-mapper","16",false,0.99,0.01,0.2,"{}" +"sarsa","options","daaf-mdp-with-options-mapper","2",false,1.0,0.01,0.2,"{}" +"sarsa","options","daaf-mdp-with-options-mapper","2",false,0.99,0.01,0.2,"{}" +"sarsa","options","daaf-mdp-with-options-mapper","4",false,1.0,0.01,0.2,"{}" +"sarsa","options","daaf-mdp-with-options-mapper","4",false,0.99,0.01,0.2,"{}" +"sarsa","options","daaf-mdp-with-options-mapper","6",false,1.0,0.01,0.2,"{}" +"sarsa","options","daaf-mdp-with-options-mapper","6",false,0.99,0.01,0.2,"{}" +"sarsa","single-step","daaf-lsq-reward-attribution-mapper","2",false,1.0,0.01,0.2,"{}" +"sarsa","single-step","daaf-lsq-reward-attribution-mapper","2",false,0.99,0.01,0.2,"{}" +"sarsa","single-step","daaf-lsq-reward-attribution-mapper","4",false,1.0,0.01,0.2,"{}" +"sarsa","single-step","daaf-lsq-reward-attribution-mapper","4",false,0.99,0.01,0.2,"{}" +"sarsa","single-step","daaf-lsq-reward-attribution-mapper","6",false,1.0,0.01,0.2,"{}" +"sarsa","single-step","daaf-lsq-reward-attribution-mapper","6",false,0.99,0.01,0.2,"{}" +"sarsa","single-step","daaf-lsq-reward-attribution-mapper","8",false,1.0,0.01,0.2,"{}" +"sarsa","single-step","daaf-lsq-reward-attribution-mapper","8",false,0.99,0.01,0.2,"{}" +"sarsa","single-step","daaf-lsq-reward-attribution-mapper","16",false,1.0,0.01,0.2,"{}" 
+"sarsa","single-step","daaf-lsq-reward-attribution-mapper","16",false,0.99,0.01,0.2,"{}" +"q-learning","single-step","identity-mapper","1",false,1.0,0.01,0.2,"{}" +"q-learning","single-step","identity-mapper","1",false,0.99,0.01,0.2,"{}" +"q-learning","single-step","daaf-trajectory-mapper","2",false,1.0,0.01,0.2,"{}" +"q-learning","single-step","daaf-trajectory-mapper","4",false,1.0,0.01,0.2,"{}" +"q-learning","single-step","daaf-trajectory-mapper","6",false,1.0,0.01,0.2,"{}" +"q-learning","single-step","daaf-trajectory-mapper","8",false,1.0,0.01,0.2,"{}" +"q-learning","single-step","daaf-trajectory-mapper","16",false,1.0,0.01,0.2,"{}" +"q-learning","single-step","daaf-trajectory-mapper","2",false,0.99,0.01,0.2,"{}" +"q-learning","single-step","daaf-trajectory-mapper","4",false,0.99,0.01,0.2,"{}" +"q-learning","single-step","daaf-trajectory-mapper","6",false,0.99,0.01,0.2,"{}" +"q-learning","single-step","daaf-trajectory-mapper","8",false,0.99,0.01,0.2,"{}" +"q-learning","single-step","daaf-trajectory-mapper","16",false,0.99,0.01,0.2,"{}" +"q-learning","single-step","daaf-impute-missing-reward-mapper","2",false,1.0,0.01,0.2,"{}" +"q-learning","single-step","daaf-impute-missing-reward-mapper","2",false,0.99,0.01,0.2,"{}" +"q-learning","single-step","daaf-impute-missing-reward-mapper","4",false,1.0,0.01,0.2,"{}" +"q-learning","single-step","daaf-impute-missing-reward-mapper","4",false,0.99,0.01,0.2,"{}" +"q-learning","single-step","daaf-impute-missing-reward-mapper","6",false,1.0,0.01,0.2,"{}" +"q-learning","single-step","daaf-impute-missing-reward-mapper","6",false,0.99,0.01,0.2,"{}" +"q-learning","single-step","daaf-impute-missing-reward-mapper","8",false,1.0,0.01,0.2,"{}" +"q-learning","single-step","daaf-impute-missing-reward-mapper","8",false,0.99,0.01,0.2,"{}" +"q-learning","single-step","daaf-impute-missing-reward-mapper","16",false,1.0,0.01,0.2,"{}" +"q-learning","single-step","daaf-impute-missing-reward-mapper","16",false,0.99,0.01,0.2,"{}" 
+"q-learning","options","daaf-mdp-with-options-mapper","2",false,1.0,0.01,0.2,"{}" +"q-learning","options","daaf-mdp-with-options-mapper","2",false,0.99,0.01,0.2,"{}" +"q-learning","options","daaf-mdp-with-options-mapper","4",false,1.0,0.01,0.2,"{}" +"q-learning","options","daaf-mdp-with-options-mapper","4",false,0.99,0.01,0.2,"{}" +"q-learning","options","daaf-mdp-with-options-mapper","6",false,1.0,0.01,0.2,"{}" +"q-learning","options","daaf-mdp-with-options-mapper","6",false,0.99,0.01,0.2,"{}" +"q-learning","single-step","daaf-lsq-reward-attribution-mapper","2",false,1.0,0.01,0.2,"{}" +"q-learning","single-step","daaf-lsq-reward-attribution-mapper","2",false,0.99,0.01,0.2,"{}" +"q-learning","single-step","daaf-lsq-reward-attribution-mapper","4",false,1.0,0.01,0.2,"{}" +"q-learning","single-step","daaf-lsq-reward-attribution-mapper","4",false,0.99,0.01,0.2,"{}" +"q-learning","single-step","daaf-lsq-reward-attribution-mapper","6",false,1.0,0.01,0.2,"{}" +"q-learning","single-step","daaf-lsq-reward-attribution-mapper","6",false,0.99,0.01,0.2,"{}" +"q-learning","single-step","daaf-lsq-reward-attribution-mapper","8",false,1.0,0.01,0.2,"{}" +"q-learning","single-step","daaf-lsq-reward-attribution-mapper","8",false,0.99,0.01,0.2,"{}" +"q-learning","single-step","daaf-lsq-reward-attribution-mapper","16",false,1.0,0.01,0.2,"{}" +"q-learning","single-step","daaf-lsq-reward-attribution-mapper","16",false,0.99,0.01,0.2,"{}" +"nstep-sarsa","single-step","identity-mapper","1",false,1.0,0.01,0.2,"{""nstep"":1}" +"nstep-sarsa","single-step","identity-mapper","1",false,0.99,0.01,0.2,"{""nstep"":1}" +"nstep-sarsa","single-step","identity-mapper","1",false,1.0,0.01,0.2,"{""nstep"":2}" +"nstep-sarsa","single-step","identity-mapper","1",false,0.99,0.01,0.2,"{""nstep"":2}" +"nstep-sarsa","single-step","identity-mapper","1",false,1.0,0.01,0.2,"{""nstep"":4}" +"nstep-sarsa","single-step","identity-mapper","1",false,0.99,0.01,0.2,"{""nstep"":4}" 
+"nstep-sarsa","single-step","identity-mapper","1",false,1.0,0.01,0.2,"{""nstep"":6}" +"nstep-sarsa","single-step","identity-mapper","1",false,0.99,0.01,0.2,"{""nstep"":6}" +"nstep-sarsa","single-step","identity-mapper","1",false,1.0,0.01,0.2,"{""nstep"":8}" +"nstep-sarsa","single-step","identity-mapper","1",false,0.99,0.01,0.2,"{""nstep"":8}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","2",false,1.0,0.01,0.2,"{""nstep"":2}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","2",false,0.99,0.01,0.2,"{""nstep"":2}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","2",false,1.0,0.01,0.2,"{""nstep"":4}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","2",false,0.99,0.01,0.2,"{""nstep"":4}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","2",false,1.0,0.01,0.2,"{""nstep"":6}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","2",false,0.99,0.01,0.2,"{""nstep"":6}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","2",false,1.0,0.01,0.2,"{""nstep"":8}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","2",false,0.99,0.01,0.2,"{""nstep"":8}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","4",false,1.0,0.01,0.2,"{""nstep"":2}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","4",false,0.99,0.01,0.2,"{""nstep"":2}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","4",false,1.0,0.01,0.2,"{""nstep"":4}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","4",false,0.99,0.01,0.2,"{""nstep"":4}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","4",false,1.0,0.01,0.2,"{""nstep"":6}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","4",false,0.99,0.01,0.2,"{""nstep"":6}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","4",false,1.0,0.01,0.2,"{""nstep"":8}" 
+"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","4",false,0.99,0.01,0.2,"{""nstep"":8}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","6",false,1.0,0.01,0.2,"{""nstep"":2}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","6",false,0.99,0.01,0.2,"{""nstep"":2}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","6",false,1.0,0.01,0.2,"{""nstep"":4}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","6",false,0.99,0.01,0.2,"{""nstep"":4}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","6",false,1.0,0.01,0.2,"{""nstep"":6}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","6",false,0.99,0.01,0.2,"{""nstep"":6}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","6",false,1.0,0.01,0.2,"{""nstep"":8}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","6",false,0.99,0.01,0.2,"{""nstep"":8}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","8",false,1.0,0.01,0.2,"{""nstep"":2}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","8",false,0.99,0.01,0.2,"{""nstep"":2}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","8",false,1.0,0.01,0.2,"{""nstep"":4}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","8",false,0.99,0.01,0.2,"{""nstep"":4}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","8",false,1.0,0.01,0.2,"{""nstep"":6}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","8",false,0.99,0.01,0.2,"{""nstep"":6}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","8",false,1.0,0.01,0.2,"{""nstep"":8}" +"nstep-sarsa","single-step","daaf-impute-missing-reward-mapper","8",false,0.99,0.01,0.2,"{""nstep"":8}" +"nstep-sarsa","single-step","daaf-nstep-td-update-mark-mapper","2",false,1.0,0.01,0.2,"{""nstep"":2}" +"nstep-sarsa","single-step","daaf-nstep-td-update-mark-mapper","2",false,0.99,0.01,0.2,"{""nstep"":2}" 
+"nstep-sarsa","single-step","daaf-nstep-td-update-mark-mapper","4",false,1.0,0.01,0.2,"{""nstep"":4}" +"nstep-sarsa","single-step","daaf-nstep-td-update-mark-mapper","4",false,0.99,0.01,0.2,"{""nstep"":4}" +"nstep-sarsa","single-step","daaf-nstep-td-update-mark-mapper","6",false,1.0,0.01,0.2,"{""nstep"":6}" +"nstep-sarsa","single-step","daaf-nstep-td-update-mark-mapper","6",false,0.99,0.01,0.2,"{""nstep"":6}" +"nstep-sarsa","single-step","daaf-nstep-td-update-mark-mapper","8",false,1.0,0.01,0.2,"{""nstep"":8}" +"nstep-sarsa","single-step","daaf-nstep-td-update-mark-mapper","8",false,0.99,0.01,0.2,"{""nstep"":8}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","2",false,1.0,0.01,0.2,"{""nstep"":2}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","2",false,0.99,0.01,0.2,"{""nstep"":2}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","2",false,1.0,0.01,0.2,"{""nstep"":4}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","2",false,0.99,0.01,0.2,"{""nstep"":4}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","2",false,1.0,0.01,0.2,"{""nstep"":6}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","2",false,0.99,0.01,0.2,"{""nstep"":6}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","2",false,1.0,0.01,0.2,"{""nstep"":8}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","2",false,0.99,0.01,0.2,"{""nstep"":8}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","4",false,1.0,0.01,0.2,"{""nstep"":2}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","4",false,0.99,0.01,0.2,"{""nstep"":2}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","4",false,1.0,0.01,0.2,"{""nstep"":4}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","4",false,0.99,0.01,0.2,"{""nstep"":4}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","4",false,1.0,0.01,0.2,"{""nstep"":6}" 
+"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","4",false,0.99,0.01,0.2,"{""nstep"":6}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","4",false,1.0,0.01,0.2,"{""nstep"":8}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","4",false,0.99,0.01,0.2,"{""nstep"":8}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","6",false,1.0,0.01,0.2,"{""nstep"":2}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","6",false,0.99,0.01,0.2,"{""nstep"":2}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","6",false,1.0,0.01,0.2,"{""nstep"":4}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","6",false,0.99,0.01,0.2,"{""nstep"":4}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","6",false,1.0,0.01,0.2,"{""nstep"":6}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","6",false,0.99,0.01,0.2,"{""nstep"":6}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","6",false,1.0,0.01,0.2,"{""nstep"":8}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","6",false,0.99,0.01,0.2,"{""nstep"":8}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","8",false,1.0,0.01,0.2,"{""nstep"":2}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","8",false,0.99,0.01,0.2,"{""nstep"":2}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","8",false,1.0,0.01,0.2,"{""nstep"":4}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","8",false,0.99,0.01,0.2,"{""nstep"":4}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","8",false,1.0,0.01,0.2,"{""nstep"":6}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","8",false,0.99,0.01,0.2,"{""nstep"":6}" +"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","8",false,1.0,0.01,0.2,"{""nstep"":8}" 
+"nstep-sarsa","single-step","daaf-lsq-reward-attribution-mapper","8",false,0.99,0.01,0.2,"{""nstep"":8}" diff --git a/experiments/policyeval/experiments.csv b/experiments/policyeval/experiments.csv index d660dd6..1b5ef5a 100644 --- a/experiments/policyeval/experiments.csv +++ b/experiments/policyeval/experiments.csv @@ -1,89 +1,97 @@ "algorithm","policy_type","traj_mapping_method","reward_period","drop_truncated_feedback_episodes","discount_factor","learning_rate","epsilon","algorithm_args" -"one-step-td","single-step","identity-mapper","1",false,1.0,0.1,0.0,"{}" -"one-step-td","single-step","identity-mapper","1",false,0.99,0.1,0.0,"{}" -"one-step-td","single-step","daaf-trajectory-mapper","2",false,1.0,0.1,0.0,"{}" -"one-step-td","single-step","daaf-trajectory-mapper","2",false,0.99,0.1,0.0,"{}" -"one-step-td","single-step","daaf-trajectory-mapper","4",false,1.0,0.1,0.0,"{}" -"one-step-td","single-step","daaf-trajectory-mapper","4",false,0.99,0.1,0.0,"{}" -"one-step-td","single-step","daaf-trajectory-mapper","6",false,1.0,0.1,0.0,"{}" -"one-step-td","single-step","daaf-trajectory-mapper","6",false,0.99,0.1,0.0,"{}" -"one-step-td","single-step","daaf-trajectory-mapper","8",false,1.0,0.1,0.0,"{}" -"one-step-td","single-step","daaf-trajectory-mapper","8",false,0.99,0.1,0.0,"{}" -"one-step-td","single-step","daaf-impute-missing-reward-mapper","2",false,1.0,0.1,0.0,"{}" -"one-step-td","single-step","daaf-impute-missing-reward-mapper","2",false,0.99,0.1,0.0,"{}" -"one-step-td","single-step","daaf-impute-missing-reward-mapper","4",false,1.0,0.1,0.0,"{}" -"one-step-td","single-step","daaf-impute-missing-reward-mapper","4",false,0.99,0.1,0.0,"{}" -"one-step-td","single-step","daaf-impute-missing-reward-mapper","6",false,1.0,0.1,0.0,"{}" -"one-step-td","single-step","daaf-impute-missing-reward-mapper","6",false,0.99,0.1,0.0,"{}" -"one-step-td","single-step","daaf-impute-missing-reward-mapper","8",false,1.0,0.1,0.0,"{}" 
-"one-step-td","single-step","daaf-impute-missing-reward-mapper","8",false,0.99,0.1,0.0,"{}" -"one-step-td","options","daaf-mdp-with-options-mapper","2",false,1.0,0.1,0.0,"{}" -"one-step-td","options","daaf-mdp-with-options-mapper","2",false,0.99,0.1,0.0,"{}" -"one-step-td","options","daaf-mdp-with-options-mapper","4",false,1.0,0.1,0.0,"{}" -"one-step-td","options","daaf-mdp-with-options-mapper","4",false,0.99,0.1,0.0,"{}" -"one-step-td","options","daaf-mdp-with-options-mapper","6",false,1.0,0.1,0.0,"{}" -"one-step-td","options","daaf-mdp-with-options-mapper","6",false,0.99,0.1,0.0,"{}" -"one-step-td","single-step","daaf-lsq-reward-attribution-mapper","2",false,1.0,0.1,0.0,"{}" -"one-step-td","single-step","daaf-lsq-reward-attribution-mapper","2",false,0.99,0.1,0.0,"{}" -"one-step-td","single-step","daaf-lsq-reward-attribution-mapper","4",false,1.0,0.1,0.0,"{}" -"one-step-td","single-step","daaf-lsq-reward-attribution-mapper","4",false,0.99,0.1,0.0,"{}" -"one-step-td","single-step","daaf-lsq-reward-attribution-mapper","6",false,1.0,0.1,0.0,"{}" -"one-step-td","single-step","daaf-lsq-reward-attribution-mapper","6",false,0.99,0.1,0.0,"{}" -"one-step-td","single-step","daaf-lsq-reward-attribution-mapper","8",false,1.0,0.1,0.0,"{}" -"one-step-td","single-step","daaf-lsq-reward-attribution-mapper","8",false,0.99,0.1,0.0,"{}" -"nstep-td","single-step","identity-mapper","1",false,1.0,0.1,0.0,"{}" -"nstep-td","single-step","identity-mapper","1",false,0.99,0.1,0.0,"{}" -"nstep-td","single-step","daaf-nstep-td-update-mark-mapper","2",false,1.0,0.1,0.0,"{}" -"nstep-td","single-step","daaf-nstep-td-update-mark-mapper","2",false,0.99,0.1,0.0,"{}" -"nstep-td","single-step","daaf-nstep-td-update-mark-mapper","4",false,1.0,0.1,0.0,"{}" -"nstep-td","single-step","daaf-nstep-td-update-mark-mapper","4",false,0.99,0.1,0.0,"{}" -"nstep-td","single-step","daaf-nstep-td-update-mark-mapper","6",false,1.0,0.1,0.0,"{}" 
-"nstep-td","single-step","daaf-nstep-td-update-mark-mapper","6",false,0.99,0.1,0.0,"{}" -"nstep-td","single-step","daaf-impute-missing-reward-mapper","2",false,1.0,0.1,0.0,"{}" -"nstep-td","single-step","daaf-impute-missing-reward-mapper","2",false,0.99,0.1,0.0,"{}" -"nstep-td","single-step","daaf-impute-missing-reward-mapper","4",false,1.0,0.1,0.0,"{}" -"nstep-td","single-step","daaf-impute-missing-reward-mapper","4",false,0.99,0.1,0.0,"{}" -"nstep-td","single-step","daaf-impute-missing-reward-mapper","6",false,1.0,0.1,0.0,"{}" -"nstep-td","single-step","daaf-impute-missing-reward-mapper","6",false,0.99,0.1,0.0,"{}" -"nstep-td","single-step","daaf-impute-missing-reward-mapper","8",false,1.0,0.1,0.0,"{}" -"nstep-td","single-step","daaf-impute-missing-reward-mapper","8",false,0.99,0.1,0.0,"{}" -"nstep-td","single-step","daaf-lsq-reward-attribution-mapper","2",false,1.0,0.1,0.0,"{}" -"nstep-td","single-step","daaf-lsq-reward-attribution-mapper","2",false,0.99,0.1,0.0,"{}" -"nstep-td","single-step","daaf-lsq-reward-attribution-mapper","4",false,1.0,0.1,0.0,"{}" -"nstep-td","single-step","daaf-lsq-reward-attribution-mapper","4",false,0.99,0.1,0.0,"{}" -"nstep-td","single-step","daaf-lsq-reward-attribution-mapper","6",false,1.0,0.1,0.0,"{}" -"nstep-td","single-step","daaf-lsq-reward-attribution-mapper","6",false,0.99,0.1,0.0,"{}" -"nstep-td","single-step","daaf-lsq-reward-attribution-mapper","8",false,1.0,0.1,0.0,"{}" -"nstep-td","single-step","daaf-lsq-reward-attribution-mapper","8",false,0.99,0.1,0.0,"{}" -"first-visit-mc","single-step","identity-mapper","1",false,1.0,0.1,0.0,"{}" -"first-visit-mc","single-step","identity-mapper","1",false,0.99,0.1,0.0,"{}" -"first-visit-mc","single-step","daaf-trajectory-mapper","2",false,1.0,0.1,0.0,"{}" -"first-visit-mc","single-step","daaf-trajectory-mapper","2",false,0.99,0.1,0.0,"{}" -"first-visit-mc","single-step","daaf-trajectory-mapper","4",false,1.0,0.1,0.0,"{}" 
-"first-visit-mc","single-step","daaf-trajectory-mapper","4",false,0.99,0.1,0.0,"{}" -"first-visit-mc","single-step","daaf-trajectory-mapper","6",false,1.0,0.1,0.0,"{}" -"first-visit-mc","single-step","daaf-trajectory-mapper","6",false,0.99,0.1,0.0,"{}" -"first-visit-mc","single-step","daaf-trajectory-mapper","8",false,1.0,0.1,0.0,"{}" -"first-visit-mc","single-step","daaf-trajectory-mapper","8",false,0.99,0.1,0.0,"{}" -"first-visit-mc","single-step","daaf-impute-missing-reward-mapper","2",false,1.0,0.1,0.0,"{}" -"first-visit-mc","single-step","daaf-impute-missing-reward-mapper","2",false,0.99,0.1,0.0,"{}" -"first-visit-mc","single-step","daaf-impute-missing-reward-mapper","4",false,1.0,0.1,0.0,"{}" -"first-visit-mc","single-step","daaf-impute-missing-reward-mapper","4",false,0.99,0.1,0.0,"{}" -"first-visit-mc","single-step","daaf-impute-missing-reward-mapper","6",false,1.0,0.1,0.0,"{}" -"first-visit-mc","single-step","daaf-impute-missing-reward-mapper","6",false,0.99,0.1,0.0,"{}" -"first-visit-mc","single-step","daaf-impute-missing-reward-mapper","8",false,1.0,0.1,0.0,"{}" -"first-visit-mc","single-step","daaf-impute-missing-reward-mapper","8",false,0.99,0.1,0.0,"{}" -"first-visit-mc","options","daaf-mdp-with-options-mapper","2",false,1.0,0.1,0.0,"{}" -"first-visit-mc","options","daaf-mdp-with-options-mapper","2",false,0.99,0.1,0.0,"{}" -"first-visit-mc","options","daaf-mdp-with-options-mapper","4",false,1.0,0.1,0.0,"{}" -"first-visit-mc","options","daaf-mdp-with-options-mapper","4",false,0.99,0.1,0.0,"{}" -"first-visit-mc","options","daaf-mdp-with-options-mapper","6",false,1.0,0.1,0.0,"{}" -"first-visit-mc","options","daaf-mdp-with-options-mapper","6",false,0.99,0.1,0.0,"{}" -"first-visit-mc","single-step","daaf-lsq-reward-attribution-mapper","2",false,1.0,0.1,0.0,"{}" -"first-visit-mc","single-step","daaf-lsq-reward-attribution-mapper","2",false,0.99,0.1,0.0,"{}" -"first-visit-mc","single-step","daaf-lsq-reward-attribution-mapper","4",false,1.0,0.1,0.0,"{}" 
-"first-visit-mc","single-step","daaf-lsq-reward-attribution-mapper","4",false,0.99,0.1,0.0,"{}" -"first-visit-mc","single-step","daaf-lsq-reward-attribution-mapper","6",false,1.0,0.1,0.0,"{}" -"first-visit-mc","single-step","daaf-lsq-reward-attribution-mapper","6",false,0.99,0.1,0.0,"{}" -"first-visit-mc","single-step","daaf-lsq-reward-attribution-mapper","8",false,1.0,0.1,0.0,"{}" -"first-visit-mc","single-step","daaf-lsq-reward-attribution-mapper","8",false,0.99,0.1,0.0,"{}" +"one-step-td","single-step","identity-mapper","1",false,1.0,0.01,0.0,"{}" +"one-step-td","single-step","identity-mapper","1",false,0.99,0.01,0.0,"{}" +"one-step-td","single-step","daaf-trajectory-mapper","2",false,1.0,0.01,0.0,"{}" +"one-step-td","single-step","daaf-trajectory-mapper","2",false,0.99,0.01,0.0,"{}" +"one-step-td","single-step","daaf-trajectory-mapper","4",false,1.0,0.01,0.0,"{}" +"one-step-td","single-step","daaf-trajectory-mapper","4",false,0.99,0.01,0.0,"{}" +"one-step-td","single-step","daaf-trajectory-mapper","6",false,1.0,0.01,0.0,"{}" +"one-step-td","single-step","daaf-trajectory-mapper","6",false,0.99,0.01,0.0,"{}" +"one-step-td","single-step","daaf-trajectory-mapper","8",false,1.0,0.01,0.0,"{}" +"one-step-td","single-step","daaf-trajectory-mapper","8",false,0.99,0.01,0.0,"{}" +"one-step-td","single-step","daaf-impute-missing-reward-mapper","2",false,1.0,0.01,0.0,"{}" +"one-step-td","single-step","daaf-impute-missing-reward-mapper","2",false,0.99,0.01,0.0,"{}" +"one-step-td","single-step","daaf-impute-missing-reward-mapper","4",false,1.0,0.01,0.0,"{}" +"one-step-td","single-step","daaf-impute-missing-reward-mapper","4",false,0.99,0.01,0.0,"{}" +"one-step-td","single-step","daaf-impute-missing-reward-mapper","6",false,1.0,0.01,0.0,"{}" +"one-step-td","single-step","daaf-impute-missing-reward-mapper","6",false,0.99,0.01,0.0,"{}" +"one-step-td","single-step","daaf-impute-missing-reward-mapper","8",false,1.0,0.01,0.0,"{}" 
+"one-step-td","single-step","daaf-impute-missing-reward-mapper","8",false,0.99,0.01,0.0,"{}" +"one-step-td","options","daaf-mdp-with-options-mapper","2",false,1.0,0.01,0.0,"{}" +"one-step-td","options","daaf-mdp-with-options-mapper","2",false,0.99,0.01,0.0,"{}" +"one-step-td","options","daaf-mdp-with-options-mapper","4",false,1.0,0.01,0.0,"{}" +"one-step-td","options","daaf-mdp-with-options-mapper","4",false,0.99,0.01,0.0,"{}" +"one-step-td","options","daaf-mdp-with-options-mapper","6",false,1.0,0.01,0.0,"{}" +"one-step-td","options","daaf-mdp-with-options-mapper","6",false,0.99,0.01,0.0,"{}" +"one-step-td","single-step","daaf-lsq-reward-attribution-mapper","2",false,1.0,0.01,0.0,"{}" +"one-step-td","single-step","daaf-lsq-reward-attribution-mapper","2",false,0.99,0.01,0.0,"{}" +"one-step-td","single-step","daaf-lsq-reward-attribution-mapper","4",false,1.0,0.01,0.0,"{}" +"one-step-td","single-step","daaf-lsq-reward-attribution-mapper","4",false,0.99,0.01,0.0,"{}" +"one-step-td","single-step","daaf-lsq-reward-attribution-mapper","6",false,1.0,0.01,0.0,"{}" +"one-step-td","single-step","daaf-lsq-reward-attribution-mapper","6",false,0.99,0.01,0.0,"{}" +"one-step-td","single-step","daaf-lsq-reward-attribution-mapper","8",false,1.0,0.01,0.0,"{}" +"one-step-td","single-step","daaf-lsq-reward-attribution-mapper","8",false,0.99,0.01,0.0,"{}" +"nstep-td","single-step","identity-mapper","1",false,1.0,0.01,0.0,"{""nstep"":1}" +"nstep-td","single-step","identity-mapper","1",false,0.99,0.01,0.0,"{""nstep"":1}" +"nstep-td","single-step","identity-mapper","1",false,1.0,0.01,0.0,"{""nstep"":2}" +"nstep-td","single-step","identity-mapper","1",false,0.99,0.01,0.0,"{""nstep"":2}" +"nstep-td","single-step","identity-mapper","1",false,1.0,0.01,0.0,"{""nstep"":4}" +"nstep-td","single-step","identity-mapper","1",false,0.99,0.01,0.0,"{""nstep"":4}" +"nstep-td","single-step","identity-mapper","1",false,1.0,0.01,0.0,"{""nstep"":6}" 
+"nstep-td","single-step","identity-mapper","1",false,0.99,0.01,0.0,"{""nstep"":6}" +"nstep-td","single-step","identity-mapper","1",false,1.0,0.01,0.0,"{""nstep"":8}" +"nstep-td","single-step","identity-mapper","1",false,0.99,0.01,0.0,"{""nstep"":8}" +"nstep-td","single-step","daaf-nstep-td-update-mark-mapper","2",false,1.0,0.01,0.0,"{""nstep"":2}" +"nstep-td","single-step","daaf-nstep-td-update-mark-mapper","2",false,0.99,0.01,0.0,"{""nstep"":2}" +"nstep-td","single-step","daaf-nstep-td-update-mark-mapper","4",false,1.0,0.01,0.0,"{""nstep"":4}" +"nstep-td","single-step","daaf-nstep-td-update-mark-mapper","4",false,0.99,0.01,0.0,"{""nstep"":4}" +"nstep-td","single-step","daaf-nstep-td-update-mark-mapper","6",false,1.0,0.01,0.0,"{""nstep"":6}" +"nstep-td","single-step","daaf-nstep-td-update-mark-mapper","6",false,0.99,0.01,0.0,"{""nstep"":6}" +"nstep-td","single-step","daaf-impute-missing-reward-mapper","2",false,1.0,0.01,0.0,"{""nstep"":2}" +"nstep-td","single-step","daaf-impute-missing-reward-mapper","2",false,0.99,0.01,0.0,"{""nstep"":2}" +"nstep-td","single-step","daaf-impute-missing-reward-mapper","4",false,1.0,0.01,0.0,"{""nstep"":4}" +"nstep-td","single-step","daaf-impute-missing-reward-mapper","4",false,0.99,0.01,0.0,"{""nstep"":4}" +"nstep-td","single-step","daaf-impute-missing-reward-mapper","6",false,1.0,0.01,0.0,"{""nstep"":6}" +"nstep-td","single-step","daaf-impute-missing-reward-mapper","6",false,0.99,0.01,0.0,"{""nstep"":6}" +"nstep-td","single-step","daaf-impute-missing-reward-mapper","8",false,1.0,0.01,0.0,"{""nstep"":8}" +"nstep-td","single-step","daaf-impute-missing-reward-mapper","8",false,0.99,0.01,0.0,"{""nstep"":8}" +"nstep-td","single-step","daaf-lsq-reward-attribution-mapper","2",false,1.0,0.01,0.0,"{""nstep"":2}" +"nstep-td","single-step","daaf-lsq-reward-attribution-mapper","2",false,0.99,0.01,0.0,"{""nstep"":2}" +"nstep-td","single-step","daaf-lsq-reward-attribution-mapper","4",false,1.0,0.01,0.0,"{""nstep"":4}" 
+"nstep-td","single-step","daaf-lsq-reward-attribution-mapper","4",false,0.99,0.01,0.0,"{""nstep"":4}" +"nstep-td","single-step","daaf-lsq-reward-attribution-mapper","6",false,1.0,0.01,0.0,"{""nstep"":6}" +"nstep-td","single-step","daaf-lsq-reward-attribution-mapper","6",false,0.99,0.01,0.0,"{""nstep"":6}" +"nstep-td","single-step","daaf-lsq-reward-attribution-mapper","8",false,1.0,0.01,0.0,"{""nstep"":8}" +"nstep-td","single-step","daaf-lsq-reward-attribution-mapper","8",false,0.99,0.01,0.0,"{""nstep"":8}" +"first-visit-mc","single-step","identity-mapper","1",false,1.0,0.01,0.0,"{}" +"first-visit-mc","single-step","identity-mapper","1",false,0.99,0.01,0.0,"{}" +"first-visit-mc","single-step","daaf-trajectory-mapper","2",false,1.0,0.01,0.0,"{}" +"first-visit-mc","single-step","daaf-trajectory-mapper","2",false,0.99,0.01,0.0,"{}" +"first-visit-mc","single-step","daaf-trajectory-mapper","4",false,1.0,0.01,0.0,"{}" +"first-visit-mc","single-step","daaf-trajectory-mapper","4",false,0.99,0.01,0.0,"{}" +"first-visit-mc","single-step","daaf-trajectory-mapper","6",false,1.0,0.01,0.0,"{}" +"first-visit-mc","single-step","daaf-trajectory-mapper","6",false,0.99,0.01,0.0,"{}" +"first-visit-mc","single-step","daaf-trajectory-mapper","8",false,1.0,0.01,0.0,"{}" +"first-visit-mc","single-step","daaf-trajectory-mapper","8",false,0.99,0.01,0.0,"{}" +"first-visit-mc","single-step","daaf-impute-missing-reward-mapper","2",false,1.0,0.01,0.0,"{}" +"first-visit-mc","single-step","daaf-impute-missing-reward-mapper","2",false,0.99,0.01,0.0,"{}" +"first-visit-mc","single-step","daaf-impute-missing-reward-mapper","4",false,1.0,0.01,0.0,"{}" +"first-visit-mc","single-step","daaf-impute-missing-reward-mapper","4",false,0.99,0.01,0.0,"{}" +"first-visit-mc","single-step","daaf-impute-missing-reward-mapper","6",false,1.0,0.01,0.0,"{}" +"first-visit-mc","single-step","daaf-impute-missing-reward-mapper","6",false,0.99,0.01,0.0,"{}" 
+"first-visit-mc","single-step","daaf-impute-missing-reward-mapper","8",false,1.0,0.01,0.0,"{}" +"first-visit-mc","single-step","daaf-impute-missing-reward-mapper","8",false,0.99,0.01,0.0,"{}" +"first-visit-mc","options","daaf-mdp-with-options-mapper","2",false,1.0,0.01,0.0,"{}" +"first-visit-mc","options","daaf-mdp-with-options-mapper","2",false,0.99,0.01,0.0,"{}" +"first-visit-mc","options","daaf-mdp-with-options-mapper","4",false,1.0,0.01,0.0,"{}" +"first-visit-mc","options","daaf-mdp-with-options-mapper","4",false,0.99,0.01,0.0,"{}" +"first-visit-mc","options","daaf-mdp-with-options-mapper","6",false,1.0,0.01,0.0,"{}" +"first-visit-mc","options","daaf-mdp-with-options-mapper","6",false,0.99,0.01,0.0,"{}" +"first-visit-mc","single-step","daaf-lsq-reward-attribution-mapper","2",false,1.0,0.01,0.0,"{}" +"first-visit-mc","single-step","daaf-lsq-reward-attribution-mapper","2",false,0.99,0.01,0.0,"{}" +"first-visit-mc","single-step","daaf-lsq-reward-attribution-mapper","4",false,1.0,0.01,0.0,"{}" +"first-visit-mc","single-step","daaf-lsq-reward-attribution-mapper","4",false,0.99,0.01,0.0,"{}" +"first-visit-mc","single-step","daaf-lsq-reward-attribution-mapper","6",false,1.0,0.01,0.0,"{}" +"first-visit-mc","single-step","daaf-lsq-reward-attribution-mapper","6",false,0.99,0.01,0.0,"{}" +"first-visit-mc","single-step","daaf-lsq-reward-attribution-mapper","8",false,1.0,0.01,0.0,"{}" +"first-visit-mc","single-step","daaf-lsq-reward-attribution-mapper","8",false,0.99,0.01,0.0,"{}" diff --git a/nb-requirements.in b/nb-requirements.in index f4df666..a776013 100644 --- a/nb-requirements.in +++ b/nb-requirements.in @@ -6,5 +6,9 @@ jupyter_nbextensions_configurator==0.6.3 jupyter-resource-usage==0.7.2 statsmodels==0.14.1 seaborn==0.13.1 -matplotlib +matplotlib>=3.8.1 +tqdm>=4.66.2 +ipywidgets>=8.0.7 +# pin +exceptiongroup==1.2.0 diff --git a/nb-requirements.txt b/nb-requirements.txt index 42c3de2..e7142cb 100644 --- a/nb-requirements.txt +++ b/nb-requirements.txt @@ -1,9 +1,5 
@@ -# -# This file is autogenerated by pip-compile with Python 3.9 -# by the following command: -# -# pip-compile --no-emit-find-links --no-emit-index-url --no-emit-options nb-requirements.in -# +# This file was autogenerated by uv via the following command: +# uv pip compile --no-emit-index-url --no-emit-find-links nb-requirements.in -o nb-requirements.txt absl-py==2.0.0 # via # -r requirements.txt @@ -13,28 +9,22 @@ aiohttp==3.8.4 # via # -r requirements.txt # aiohttp-cors - # ray aiohttp-cors==0.7.0 - # via - # -r requirements.txt - # ray + # via -r requirements.txt aiosignal==1.3.1 # via # -r requirements.txt # aiohttp # ray -anyio==3.6.2 +anyio==4.4.0 + # via jupyter-server +argon2-cffi==23.1.0 # via jupyter-server -argon2-cffi==21.3.0 - # via - # jupyter-server - # nbclassic - # notebook argon2-cffi-bindings==21.2.0 # via argon2-cffi arrow==1.3.0 # via isoduration -asttokens==2.2.1 +asttokens==2.4.1 # via stack-data astunparse==1.6.3 # via @@ -51,18 +41,12 @@ attrs==22.2.0 # -r requirements.txt # aiohttp # jsonschema -babel==2.15.0 +babel==2.16.0 # via jupyterlab-server -backcall==0.2.0 - # via ipython -beautifulsoup4==4.11.2 +beautifulsoup4==4.12.3 # via nbconvert -bleach==6.0.0 +bleach==6.1.0 # via nbconvert -blessed==1.20.0 - # via - # -r requirements.txt - # gpustat cachetools==5.3.0 # via # -r requirements.txt @@ -71,13 +55,17 @@ certifi==2022.12.7 # via # -r requirements.txt # requests -cffi==1.15.1 +cffi==1.17.1 # via argon2-cffi-bindings charset-normalizer==3.1.0 # via # -r requirements.txt # aiohttp # requests +clarabel==0.9.0 + # via + # -r requirements.txt + # cvxpy click==8.0.4 # via # -r requirements.txt @@ -87,16 +75,18 @@ cloudpickle==2.2.1 # -r requirements.txt # gymnasium colorful==0.5.5 + # via -r requirements.txt +comm==0.2.2 # via - # -r requirements.txt - # ray -comm==0.1.2 - # via ipykernel -contourpy==1.0.7 + # ipykernel + # ipywidgets +contourpy==1.3.0 # via matplotlib -cycler==0.11.0 +cvxpy==1.5.3 + # via -r requirements.txt 
+cycler==0.12.1 # via matplotlib -debugpy==1.6.6 +debugpy==1.8.5 # via ipykernel decorator==5.1.1 # via ipython @@ -106,13 +96,22 @@ distlib==0.3.6 # via # -r requirements.txt # virtualenv -executing==1.2.0 +ecos==2.0.14 + # via + # -r requirements.txt + # cvxpy +exceptiongroup==1.2.0 + # via + # -r nb-requirements.in + # anyio + # ipython +executing==2.1.0 # via stack-data farama-notifications==0.0.4 # via # -r requirements.txt # gymnasium -fastjsonschema==2.16.3 +fastjsonschema==2.20.0 # via nbformat filelock==3.9.1 # via @@ -123,7 +122,7 @@ flatbuffers==23.5.26 # via # -r requirements.txt # tensorflow -fonttools==4.39.0 +fonttools==4.53.1 # via matplotlib fqdn==1.5.1 # via jsonschema @@ -134,9 +133,7 @@ frozenlist==1.3.3 # aiosignal # ray fsspec==2024.2.0 - # via - # -r requirements.txt - # ray + # via -r requirements.txt gast==0.4.0 # via # -r requirements.txt @@ -163,14 +160,9 @@ googleapis-common-protos==1.58.0 # via # -r requirements.txt # google-api-core -gpustat==1.0.0 - # via - # -r requirements.txt - # ray grpcio==1.51.3 # via # -r requirements.txt - # ray # tensorboard # tensorflow gymnasium==0.28.1 @@ -198,56 +190,51 @@ importlib-metadata==6.0.0 # jupyterlab-server # markdown # nbconvert -importlib-resources==5.12.0 +importlib-resources==6.4.5 # via matplotlib -ipykernel==6.21.3 - # via - # jupyterlab - # nbclassic - # notebook -ipython==8.11.0 - # via ipykernel -ipython-genutils==0.2.0 +ipykernel==6.29.5 + # via jupyterlab +ipython==8.18.1 # via - # nbclassic - # notebook + # ipykernel + # ipywidgets +ipywidgets==8.1.5 + # via -r nb-requirements.in isoduration==20.11.0 # via jsonschema jax-jumpy==1.0.0 # via # -r requirements.txt # gymnasium -jedi==0.18.2 +jedi==0.19.1 # via ipython -jinja2==3.1.2 +jinja2==3.1.4 # via + # -r requirements.txt # jupyter-server # jupyterlab # jupyterlab-server - # nbclassic + # memray # nbconvert - # notebook json5==0.9.25 # via jupyterlab-server -jsonpointer==2.4 +jsonpointer==3.0.0 # via jsonschema 
-jsonschema[format-nongpl]==4.17.3 +jsonschema==4.17.3 # via # -r requirements.txt # jupyter-events # jupyterlab-server # nbformat # ray -jupyter-client==8.0.3 +jupyter-client==8.6.2 # via # ipykernel # jupyter-server - # nbclassic # nbclient - # notebook jupyter-contrib-core==0.4.2 # via jupyter-nbextensions-configurator -jupyter-core==5.2.0 +jupyter-core==5.7.2 # via # ipykernel # jupyter-client @@ -255,11 +242,9 @@ jupyter-core==5.2.0 # jupyter-nbextensions-configurator # jupyter-server # jupyterlab - # nbclassic # nbclient # nbconvert # nbformat - # notebook jupyter-events==0.6.3 # via jupyter-server jupyter-lsp==2.2.5 @@ -268,52 +253,78 @@ jupyter-nbextensions-configurator==0.6.3 # via -r nb-requirements.in jupyter-resource-usage==0.7.2 # via -r nb-requirements.in -jupyter-server==2.7.0 +jupyter-server==2.10.0 # via # -r nb-requirements.in # jupyter-lsp # jupyter-resource-usage # jupyterlab # jupyterlab-server - # nbclassic + # notebook # notebook-shim jupyter-server-terminals==0.5.3 # via jupyter-server jupyterlab==4.0.13 - # via -r nb-requirements.in -jupyterlab-pygments==0.2.2 + # via + # -r nb-requirements.in + # notebook +jupyterlab-pygments==0.3.0 # via nbconvert jupyterlab-server==2.24.0 - # via jupyterlab + # via + # jupyterlab + # notebook +jupyterlab-widgets==3.0.13 + # via ipywidgets keras==2.13.1 # via # -r requirements.txt # tensorflow -kiwisolver==1.4.4 +kiwisolver==1.4.7 # via matplotlib libclang==16.0.6 # via # -r requirements.txt # tensorflow +linkify-it-py==2.0.3 + # via + # -r requirements.txt + # markdown-it-py markdown==3.5.1 # via # -r requirements.txt # tensorboard +markdown-it-py==3.0.0 + # via + # -r requirements.txt + # mdit-py-plugins + # rich + # textual markupsafe==2.1.3 # via # -r requirements.txt # jinja2 # nbconvert # werkzeug -matplotlib==3.7.1 +matplotlib==3.9.2 # via # -r nb-requirements.in # seaborn -matplotlib-inline==0.1.6 +matplotlib-inline==0.1.7 # via # ipykernel # ipython -mistune==2.0.5 +mdit-py-plugins==0.4.2 + # via 
+ # -r requirements.txt + # markdown-it-py +mdurl==0.1.2 + # via + # -r requirements.txt + # markdown-it-py +memray==1.14.0 + # via -r requirements.txt +mistune==3.0.2 # via nbconvert msgpack==1.0.5 # via @@ -324,66 +335,55 @@ multidict==6.0.4 # -r requirements.txt # aiohttp # yarl -nbclassic==0.5.3 - # via notebook -nbclient==0.7.2 +nbclient==0.10.0 # via nbconvert -nbconvert==7.2.10 - # via - # jupyter-server - # nbclassic - # notebook -nbformat==5.7.3 +nbconvert==7.16.4 + # via jupyter-server +nbformat==5.10.4 # via # jupyter-server - # nbclassic # nbclient # nbconvert - # notebook -nest-asyncio==1.5.6 - # via - # ipykernel - # nbclassic - # notebook -notebook==6.5.4 +nest-asyncio==1.6.0 + # via ipykernel +notebook==7.0.8 # via # jupyter-contrib-core # jupyter-nbextensions-configurator -notebook-shim==0.2.2 +notebook-shim==0.2.4 # via # jupyterlab - # nbclassic + # notebook numpy==1.23.5 # via # -r requirements.txt + # clarabel # contourpy + # cvxpy + # ecos # gymnasium # h5py # jax-jumpy # matplotlib # opt-einsum + # osqp # pandas # patsy # pyarrow - # ray + # qdldl # rlplg # scipy + # scs # seaborn # statsmodels # tensorboard # tensorflow -nvidia-ml-py==11.495.46 - # via - # -r requirements.txt - # gpustat oauthlib==3.2.2 # via # -r requirements.txt # requests-oauthlib opencensus==0.11.2 - # via - # -r requirements.txt - # ray + # via -r requirements.txt opencensus-context==0.1.3 # via # -r requirements.txt @@ -392,6 +392,10 @@ opt-einsum==3.3.0 # via # -r requirements.txt # tensorflow +osqp==0.6.7.post1 + # via + # -r requirements.txt + # cvxpy overrides==7.7.0 # via jupyter-server packaging==23.2 @@ -409,18 +413,15 @@ packaging==23.2 pandas==2.0.3 # via # -r requirements.txt - # ray # seaborn # statsmodels -pandocfilters==1.5.0 +pandocfilters==1.5.1 # via nbconvert -parso==0.8.3 +parso==0.8.4 # via jedi patsy==0.5.6 # via statsmodels -pexpect==4.8.0 - # via ipython -pickleshare==0.7.5 +pexpect==4.9.0 # via ipython pillow==9.4.0 # via @@ -437,10 +438,7 @@ 
prometheus-client==0.13.1 # -r requirements.txt # jupyter-resource-usage # jupyter-server - # nbclassic - # notebook - # ray -prompt-toolkit==3.0.38 +prompt-toolkit==3.0.47 # via ipython protobuf==4.25.1 # via @@ -450,26 +448,20 @@ protobuf==4.25.1 # ray # tensorboard # tensorflow -psutil==5.9.4 +psutil==5.9.8 # via - # -r requirements.txt - # gpustat # ipykernel # jupyter-resource-usage ptyprocess==0.7.0 # via # pexpect # terminado -pure-eval==0.2.2 +pure-eval==0.2.3 # via stack-data py-spy==0.3.14 - # via - # -r requirements.txt - # ray + # via -r requirements.txt pyarrow==15.0.0 - # via - # -r requirements.txt - # ray + # via -r requirements.txt pyasn1==0.4.8 # via # -r requirements.txt @@ -479,17 +471,17 @@ pyasn1-modules==0.2.8 # via # -r requirements.txt # google-auth -pycparser==2.21 +pycparser==2.22 # via cffi pydantic==1.10.6 + # via -r requirements.txt +pygments==2.18.0 # via # -r requirements.txt - # ray -pygments==2.14.0 - # via # ipython # nbconvert -pyparsing==3.0.9 + # rich +pyparsing==3.1.4 # via matplotlib pyrsistent==0.19.3 # via @@ -514,15 +506,17 @@ pyyaml==6.0 # jupyter-events # jupyter-nbextensions-configurator # ray -pyzmq==25.0.1 +pyzmq==26.2.0 # via # ipykernel # jupyter-client # jupyter-resource-usage # jupyter-server - # nbclassic - # notebook -ray[data,default]==2.9.3 +qdldl==0.1.7.post4 + # via + # -r requirements.txt + # osqp +ray==2.23.0 # via -r requirements.txt requests==2.28.2 # via @@ -544,7 +538,12 @@ rfc3986-validator==0.1.1 # via # jsonschema # jupyter-events -rlplg @ git+https://github.com/guidj/rlplg.git@v0.19.10 +rich==13.8.1 + # via + # -r requirements.txt + # memray + # textual +rlplg @ git+https://github.com/guidj/rlplg.git@93d819cab2eeefdd9bbd0ec7acb5d1dd0e5d1bed # via -r requirements.txt rsa==4.9 # via @@ -553,38 +552,47 @@ rsa==4.9 scipy==1.10.1 # via # -r requirements.txt + # clarabel + # cvxpy + # ecos + # osqp + # qdldl # rlplg + # scs # statsmodels +scs==3.2.7 + # via + # -r requirements.txt + # cvxpy 
seaborn==0.13.1 # via -r nb-requirements.in -send2trash==1.8.0 +send2trash==1.8.3 + # via jupyter-server +setuptools==74.1.2 # via - # jupyter-server - # nbclassic - # notebook + # -r requirements.txt + # jupyter-contrib-core + # tensorboard + # tensorflow six==1.16.0 # via # -r requirements.txt # asttokens # astunparse # bleach - # blessed # google-auth # google-pasta - # gpustat # patsy # python-dateutil # rfc3339-validator # tensorflow smart-open==6.3.0 - # via - # -r requirements.txt - # ray -sniffio==1.3.0 + # via -r requirements.txt +sniffio==1.3.1 # via anyio -soupsieve==2.4 +soupsieve==2.6 # via beautifulsoup4 -stack-data==0.6.2 +stack-data==0.6.3 # via ipython statsmodels==0.14.1 # via -r nb-requirements.in @@ -610,17 +618,19 @@ termcolor==2.3.0 # via # -r requirements.txt # tensorflow -terminado==0.17.1 +terminado==0.18.1 # via # jupyter-server # jupyter-server-terminals - # nbclassic - # notebook -tinycss2==1.2.1 +textual==0.78.0 + # via + # -r requirements.txt + # memray +tinycss2==1.3.0 # via nbconvert tomli==2.0.1 # via jupyterlab -tornado==6.2 +tornado==6.4.1 # via # ipykernel # jupyter-client @@ -628,14 +638,16 @@ tornado==6.2 # jupyter-nbextensions-configurator # jupyter-server # jupyterlab - # nbclassic # notebook # terminado -traitlets==5.9.0 +tqdm==4.66.5 + # via -r nb-requirements.in +traitlets==5.14.3 # via # comm # ipykernel # ipython + # ipywidgets # jupyter-client # jupyter-contrib-core # jupyter-core @@ -644,24 +656,29 @@ traitlets==5.9.0 # jupyter-server # jupyterlab # matplotlib-inline - # nbclassic # nbclient # nbconvert # nbformat - # notebook -types-python-dateutil==2.8.19.14 +types-python-dateutil==2.9.0.20240906 # via arrow typing-extensions==4.5.0 # via # -r requirements.txt + # anyio # async-lru # gymnasium + # ipython # pydantic # tensorflow + # textual tzdata==2024.1 # via # -r requirements.txt # pandas +uc-micro-py==1.0.3 + # via + # -r requirements.txt + # linkify-it-py uri-template==1.3.0 # via jsonschema urllib3==1.26.15 @@ 
-669,21 +686,16 @@ urllib3==1.26.15 # -r requirements.txt # requests virtualenv==20.21.0 - # via - # -r requirements.txt - # ray -wcwidth==0.2.6 - # via - # -r requirements.txt - # blessed - # prompt-toolkit -webcolors==1.13 + # via -r requirements.txt +wcwidth==0.2.13 + # via prompt-toolkit +webcolors==24.8.0 # via jsonschema webencodings==0.5.1 # via # bleach # tinycss2 -websocket-client==1.5.1 +websocket-client==1.8.0 # via jupyter-server werkzeug==3.0.1 # via @@ -694,6 +706,8 @@ wheel==0.42.0 # -r requirements.txt # astunparse # tensorboard +widgetsnbextension==4.0.13 + # via ipywidgets wrapt==1.16.0 # via # -r requirements.txt @@ -707,6 +721,3 @@ zipp==3.15.0 # -r requirements.txt # importlib-metadata # importlib-resources - -# The following packages are considered to be unsafe in a requirements file: -# setuptools diff --git a/notebooks/daaf_analyses/onpolicy_eval/exp-policyeval-v3-s0-debug.ipynb b/notebooks/daaf_analyses/onpolicy_eval/exp-policyeval-v3-s0-debug.ipynb index 8221a63..b48be63 100644 --- a/notebooks/daaf_analyses/onpolicy_eval/exp-policyeval-v3-s0-debug.ipynb +++ b/notebooks/daaf_analyses/onpolicy_eval/exp-policyeval-v3-s0-debug.ipynb @@ -44,11 +44,11 @@ "name": "stderr", "output_type": "stream", "text": [ - "2024-05-09 14:25:53.518099: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n", - "2024-05-09 14:25:53.601199: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n", - "2024-05-09 14:25:53.602788: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "2024-10-22 20:10:43.322322: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n", + "2024-10-22 20:10:43.570276: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not 
be used.\n", + "2024-10-22 20:10:43.574880: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", - "2024-05-09 14:25:55.277181: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" + "2024-10-22 20:10:47.471641: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" ] } ], @@ -102,8 +102,7 @@ }, "outputs": [], "source": [ - "# files = tf.io.gfile.glob(f\"{pathlib.Path.home()}/fs/daaf/exp/evaljob/agg/1713705374/1713706816/logs/*.parquet\")\n", - "files = tf.io.gfile.glob(f\"{pathlib.Path.home()}/fs/daaf/exp/evaljob/agg/1715263005/1715264399/logs/*.parquet\")" + "files = tf.io.gfile.glob(f\"{pathlib.Path.home()}/fs/daaf/exp/evaljob/agg/1726821587/1726836104/logs/*.parquet\")" ] }, { @@ -113,45 +112,23 @@ "metadata": { "tags": [] }, - "outputs": [], - "source": [ - "def read_data(files):\n", - " def parse_state_values(sv):\n", - " return np.reshape(\n", - " sv[\"data\"], newshape=sv[\"shape\"]\n", - " )\n", - "\n", - " with ray.init() as context:\n", - " ds_logs = ray.data.read_parquet(files)\n", - " df_logs = ds_logs.to_pandas()\n", - " df_logs[\"state_values\"] = df_logs[\"state_values\"].apply(parse_state_values)\n", - " return df_logs" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "0a64b5b3-7f77-4651-ad49-fb47ad711ab6", - "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "2024-05-09 14:26:01,096\tINFO worker.py:1715 -- Started a local Ray instance. 
View the dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8265 \u001b[39m\u001b[22m\n", - "/home/guilherme/.pyenv/versions/3.8.15/envs/rl_daaf_py38/lib/python3.8/site-packages/ray/data/datasource/parquet_datasource.py:242: FutureWarning: Passing 'use_legacy_dataset' is deprecated as of pyarrow 15.0.0 and will be removed in a future version.\n", - " pq_ds = pq.ParquetDataset(\n" + "2024-10-22 20:10:54,958\tINFO worker.py:1740 -- Started a local Ray instance. View the dashboard at \u001b[1m\u001b[32m127.0.0.1:8265 \u001b[39m\u001b[22m\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "", + "model_id": "1765603bc1134e0ea4f3f7292f02c3b3", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "Metadata Fetch Progress 0: 0%| | 0/8 [00:00exp_id\n", " meta\n", " state_values\n", + " method\n", " \n", " \n", " \n", " \n", " 0\n", - " 190\n", - " 1715263005-63036c13-RedGreenSeq\n", - " {'algorithm': 'nstep-td', 'algorithm_args': '{...\n", - " [[-19.80916204739628, -17.15846900836637, -15....\n", + " 990\n", + " 1726821587-21b840ef-ABCSeq\n", + " {'algorithm': 'one-step-td', 'algorithm_args':...\n", + " [[-62.38473431795828, -57.893917238232035, -48...\n", + " OP/OT\n", " \n", " \n", " 1\n", - " 190\n", - " 1715263005-62c63001-RedGreenSeq\n", - " {'algorithm': 'one-step-td', 'algorithm_args':...\n", - " [[-43.92046183252249, -50.5326225834207, -46.3...\n", + " 990\n", + " 1726821587-2254cc65-RedGreenSeq\n", + " {'algorithm': 'nstep-td', 'algorithm_args': '{...\n", + " [[-25.86085290433757, -20.250514343849584, -16...\n", + " PP/nTD-SU\n", " \n", " \n", " 2\n", - " 190\n", - " 1715263005-610000df-IceWorld\n", - " {'algorithm': 'first-visit-mc', 'algorithm_arg...\n", - " [[-14.827999747368427, -11.283347368421053, -7...\n", + " 990\n", + " 1726821587-04f7bf9a-ABCSeq\n", + " {'algorithm': 'nstep-td', 'algorithm_args': '{...\n", + " [[-53.60551642990765, -51.4935817493364, -41.9...\n", + " PP/LEAST\n", " \n", " \n", " 3\n", - " 190\n", - " 
1715263005-620823ad-IceWorld\n", - " {'algorithm': 'one-step-td', 'algorithm_args':...\n", - " [[-3.4963652196459276, -3.329828339930383, -3....\n", + " 990\n", + " 1726821587-0507c0a7-IceWorld\n", + " {'algorithm': 'nstep-td', 'algorithm_args': '{...\n", + " [[-35.681611557425725, -35.04424679205454, -35...\n", + " PP/FR\n", " \n", " \n", " 4\n", - " 190\n", - " 1715263005-621fa01e-IceWorld\n", - " {'algorithm': 'nstep-td', 'algorithm_args': '{...\n", - " [[-24.19032434155747, -20.160954699992576, -13...\n", + " 990\n", + " 1726821587-15400276-RedGreenSeq\n", + " {'algorithm': 'one-step-td', 'algorithm_args':...\n", + " [[-19.99905073604775, -17.97932123818952, -15....\n", + " PP/LEAST\n", " \n", " \n", " ...\n", @@ -273,91 +248,97 @@ " ...\n", " ...\n", " ...\n", + " ...\n", " \n", " \n", - " 35195\n", - " 320\n", - " 1715263005-fe765c36-ABCSeq\n", - " {'algorithm': 'one-step-td', 'algorithm_args':...\n", - " [[-60.63560866795348, -54.81887338278657, -49....\n", + " 191995\n", + " 2070\n", + " 1726821587-e6fee995-IceWorld\n", + " {'algorithm': 'first-visit-mc', 'algorithm_arg...\n", + " [[-74.13333333333331, -74.73787599754469, -72....\n", + " PP/LEAST\n", " \n", " \n", - " 35196\n", - " 320\n", - " 1715263005-fe035497-TowerOfHanoi\n", - " {'algorithm': 'nstep-td', 'algorithm_args': '{...\n", - " [[-585.2524318287763, -447.4177079856222, -392...\n", + " 191996\n", + " 2070\n", + " 1726821587-e6466ca2-IceWorld\n", + " {'algorithm': 'one-step-td', 'algorithm_args':...\n", + " [[-139.34779864785156, -136.62189048544238, -1...\n", + " PP/DMR\n", " \n", " \n", - " 35197\n", - " 320\n", - " 1715263005-fdfb0ab0-FrozenLake-v1\n", - " {'algorithm': 'one-step-td', 'algorithm_args':...\n", - " [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...\n", + " 191997\n", + " 2070\n", + " 1726821587-e608c03d-TowerOfHanoi\n", + " {'algorithm': 'first-visit-mc', 'algorithm_arg...\n", + " [[-97.2945799153621, -93.4193206879249, -89.01...\n", + " PP/LEAST\n", " \n", " \n", - " 35198\n", - 
" 320\n", - " 1715263005-fe4244b9-RedGreenSeq\n", - " {'algorithm': 'nstep-td', 'algorithm_args': '{...\n", - " [[-22.268445336471185, -20.09385533420391, -15...\n", + " 191998\n", + " 2070\n", + " 1726821587-fd16e07c-TowerOfHanoi\n", + " {'algorithm': 'one-step-td', 'algorithm_args':...\n", + " [[-2093.2026275281783, -1795.0428249931097, -1...\n", + " OP/OT\n", " \n", " \n", - " 35199\n", - " 320\n", - " 1715263005-ff4a7b3a-GridWorld\n", - " {'algorithm': 'nstep-td', 'algorithm_args': '{...\n", - " [[-112626.75179711408, -112630.30079229665, -1...\n", + " 191999\n", + " 2070\n", + " 1726821587-fd001225-TowerOfHanoi\n", + " {'algorithm': 'one-step-td', 'algorithm_args':...\n", + " [[-590.0846928476416, -558.3777259512275, -516...\n", + " PP/DMR\n", " \n", " \n", "\n", - "

35200 rows × 4 columns

\n", + "

192000 rows × 5 columns

\n", "" ], "text/plain": [ - " episode exp_id \\\n", - "0 190 1715263005-63036c13-RedGreenSeq \n", - "1 190 1715263005-62c63001-RedGreenSeq \n", - "2 190 1715263005-610000df-IceWorld \n", - "3 190 1715263005-620823ad-IceWorld \n", - "4 190 1715263005-621fa01e-IceWorld \n", - "... ... ... \n", - "35195 320 1715263005-fe765c36-ABCSeq \n", - "35196 320 1715263005-fe035497-TowerOfHanoi \n", - "35197 320 1715263005-fdfb0ab0-FrozenLake-v1 \n", - "35198 320 1715263005-fe4244b9-RedGreenSeq \n", - "35199 320 1715263005-ff4a7b3a-GridWorld \n", + " episode exp_id \\\n", + "0 990 1726821587-21b840ef-ABCSeq \n", + "1 990 1726821587-2254cc65-RedGreenSeq \n", + "2 990 1726821587-04f7bf9a-ABCSeq \n", + "3 990 1726821587-0507c0a7-IceWorld \n", + "4 990 1726821587-15400276-RedGreenSeq \n", + "... ... ... \n", + "191995 2070 1726821587-e6fee995-IceWorld \n", + "191996 2070 1726821587-e6466ca2-IceWorld \n", + "191997 2070 1726821587-e608c03d-TowerOfHanoi \n", + "191998 2070 1726821587-fd16e07c-TowerOfHanoi \n", + "191999 2070 1726821587-fd001225-TowerOfHanoi \n", "\n", - " meta \\\n", - "0 {'algorithm': 'nstep-td', 'algorithm_args': '{... \n", - "1 {'algorithm': 'one-step-td', 'algorithm_args':... \n", - "2 {'algorithm': 'first-visit-mc', 'algorithm_arg... \n", - "3 {'algorithm': 'one-step-td', 'algorithm_args':... \n", - "4 {'algorithm': 'nstep-td', 'algorithm_args': '{... \n", - "... ... \n", - "35195 {'algorithm': 'one-step-td', 'algorithm_args':... \n", - "35196 {'algorithm': 'nstep-td', 'algorithm_args': '{... \n", - "35197 {'algorithm': 'one-step-td', 'algorithm_args':... \n", - "35198 {'algorithm': 'nstep-td', 'algorithm_args': '{... \n", - "35199 {'algorithm': 'nstep-td', 'algorithm_args': '{... \n", + " meta \\\n", + "0 {'algorithm': 'one-step-td', 'algorithm_args':... \n", + "1 {'algorithm': 'nstep-td', 'algorithm_args': '{... \n", + "2 {'algorithm': 'nstep-td', 'algorithm_args': '{... \n", + "3 {'algorithm': 'nstep-td', 'algorithm_args': '{... 
\n", + "4 {'algorithm': 'one-step-td', 'algorithm_args':... \n", + "... ... \n", + "191995 {'algorithm': 'first-visit-mc', 'algorithm_arg... \n", + "191996 {'algorithm': 'one-step-td', 'algorithm_args':... \n", + "191997 {'algorithm': 'first-visit-mc', 'algorithm_arg... \n", + "191998 {'algorithm': 'one-step-td', 'algorithm_args':... \n", + "191999 {'algorithm': 'one-step-td', 'algorithm_args':... \n", "\n", - " state_values \n", - "0 [[-19.80916204739628, -17.15846900836637, -15.... \n", - "1 [[-43.92046183252249, -50.5326225834207, -46.3... \n", - "2 [[-14.827999747368427, -11.283347368421053, -7... \n", - "3 [[-3.4963652196459276, -3.329828339930383, -3.... \n", - "4 [[-24.19032434155747, -20.160954699992576, -13... \n", - "... ... \n", - "35195 [[-60.63560866795348, -54.81887338278657, -49.... \n", - "35196 [[-585.2524318287763, -447.4177079856222, -392... \n", - "35197 [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,... \n", - "35198 [[-22.268445336471185, -20.09385533420391, -15... \n", - "35199 [[-112626.75179711408, -112630.30079229665, -1... \n", + " state_values method \n", + "0 [[-62.38473431795828, -57.893917238232035, -48... OP/OT \n", + "1 [[-25.86085290433757, -20.250514343849584, -16... PP/nTD-SU \n", + "2 [[-53.60551642990765, -51.4935817493364, -41.9... PP/LEAST \n", + "3 [[-35.681611557425725, -35.04424679205454, -35... PP/FR \n", + "4 [[-19.99905073604775, -17.97932123818952, -15.... PP/LEAST \n", + "... ... ... \n", + "191995 [[-74.13333333333331, -74.73787599754469, -72.... PP/LEAST \n", + "191996 [[-139.34779864785156, -136.62189048544238, -1... PP/DMR \n", + "191997 [[-97.2945799153621, -93.4193206879249, -89.01... PP/LEAST \n", + "191998 [[-2093.2026275281783, -1795.0428249931097, -1... OP/OT \n", + "191999 [[-590.0846928476416, -558.3777259512275, -516... 
PP/DMR \n", "\n", - "[35200 rows x 4 columns]" + "[192000 rows x 5 columns]" ] }, - "execution_count": 10, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -368,33 +349,33 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "id": "4f0a9fd6-4f63-491c-96bc-8b61db780c15", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'algorithm': 'nstep-td',\n", + "{'algorithm': 'one-step-td',\n", " 'algorithm_args': '{}',\n", - " 'discount_factor': 0.99,\n", + " 'discount_factor': 1.0,\n", " 'drop_truncated_feedback_episodes': False,\n", - " 'dyna_prog_state_values': array([-21.2446288 , -18.85810242, -16.39925704, -13.8659012 ,\n", - " -11.25577699, -8.56655811, -5.79584775, -2.94117647,\n", - " 0. ]),\n", - " 'env': {'args': '{\"cure\": [\"red\", \"green\", \"wait\", \"green\", \"red\", \"red\", \"green\", \"wait\"]}',\n", - " 'level': 'NNLHYJFTC5ENMMDZWRNQ37B6VVDXQ7WHB5EJOPXYZFLMJEZOYLTSLB4ID4WHQG57XQPNUHGZCFDCWHYGXWSBW7FBWYRZGAGBW4J7MEQ=',\n", - " 'name': 'RedGreenSeq'},\n", + " 'dyna_prog_state_values': array([-48.99999993, -41.99999999, -35. , -28. ,\n", + " -21. , -14. , -7. , 0. 
]),\n", + " 'env': {'args': '{\"length\": 7, \"distance_penalty\": false}',\n", + " 'level': 'n=7',\n", + " 'name': 'ABCSeq',\n", + " '_level': '7'},\n", " 'epsilon': 0.0,\n", - " 'exp_id': '1715263005-63036c13-RedGreenSeq',\n", + " 'exp_id': '1726821587-21b840ef-ABCSeq',\n", " 'learning_rate': 0.1,\n", - " 'policy_type': 'single-step',\n", - " 'reward_period': 1,\n", - " 'run_id': 0,\n", - " 'traj_mapping_method': 'identity-mapper'}" + " 'policy_type': 'OP',\n", + " 'reward_period': 4,\n", + " 'run_id': 1,\n", + " 'traj_mapping_method': 'OT'}" ] }, - "execution_count": 11, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -405,31 +386,57 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "id": "88a3d394-6efe-4078-bb42-7968518f84e3", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "array([[-19.80916205, -17.15846901, -15.07822161, -13.10162062,\n", - " -11.11065543, -7.89008745, -5.64204705, -3.77681442,\n", - " 0. ],\n", - " [-20.80130241, -18.15388695, -15.95066797, -12.58330711,\n", - " -11.19311473, -7.55809078, -5.55615529, -2.39728152,\n", - " 0. ],\n", - " [-20.09659903, -18.01323126, -15.60486008, -13.13494021,\n", - " -9.75580219, -7.6048606 , -5.73190235, -2.44412537,\n", - " 0. ],\n", - " [-18.75516043, -15.96808788, -14.15271232, -12.3543663 ,\n", - " -9.64362078, -7.67929014, -4.77736684, -2.09508469,\n", - " 0. ],\n", - " [-19.13063067, -17.23350318, -14.73423066, -12.40770957,\n", - " -9.70660169, -7.28868384, -5.00166949, -2.04525222,\n", - " 0. ]])" + "array([array([-62.38473432, -57.89391724, -48.40994031, -41.93988829,\n", + " -36.82012224, -33.00358336, -33.35413741, 0. ]),\n", + " array([-70.21600194, -64.05622633, -59.23307837, -54.43573262,\n", + " -48.10336585, -37.65105584, -39.33444675, 0. ]),\n", + " array([-59.31433658, -53.6538727 , -50.5297967 , -44.29168706,\n", + " -39.00348638, -33.22616021, -34.11424269, 0. 
]),\n", + " array([-64.4761011 , -55.82896246, -46.97910587, -39.72116391,\n", + " -32.28467785, -25.51095978, -20.26832619, 0. ]),\n", + " array([-69.59847322, -61.45655243, -54.78226728, -44.16563635,\n", + " -36.41240245, -27.19048432, -18.00215805, 0. ]),\n", + " array([-61.54400074, -53.34351252, -46.98489461, -43.8664305 ,\n", + " -38.36243059, -28.31403682, -21.84773079, 0. ]),\n", + " array([-59.78711475, -51.53655656, -46.78929235, -41.40451792,\n", + " -37.24107973, -33.87089113, -26.02184843, 0. ]),\n", + " array([-59.53803928, -53.6398948 , -48.93794001, -40.44610007,\n", + " -35.91181892, -28.89896686, -22.22427463, 0. ]),\n", + " array([-65.42240892, -59.45936588, -52.5658377 , -45.07079332,\n", + " -41.3818037 , -37.27629655, -33.92328829, 0. ]),\n", + " array([-68.31657327, -64.55343529, -60.79172138, -52.26920948,\n", + " -46.65858095, -34.59051215, -27.12679202, 0. ]),\n", + " array([-66.04777687, -59.77373016, -52.67904634, -43.42434074,\n", + " -38.21979578, -33.00999479, -28.59073861, 0. ]),\n", + " array([-68.34315818, -61.06738001, -58.05289359, -50.87478519,\n", + " -40.25280686, -36.79997692, -23.27329403, 0. ]),\n", + " array([-65.46381065, -59.60573407, -52.65617267, -50.40200058,\n", + " -39.1041481 , -36.25872015, -30.14125955, 0. ]),\n", + " array([-69.43674113, -66.91938263, -58.59263317, -48.75698361,\n", + " -42.91883335, -31.77018949, -28.96906258, 0. ]),\n", + " array([-68.81135842, -63.57972422, -54.29431291, -45.76332261,\n", + " -41.23253474, -31.83075389, -24.66734394, 0. ]),\n", + " array([-73.97143748, -65.80372165, -57.73843701, -49.10862825,\n", + " -42.7424336 , -31.97280061, -27.78977434, 0. ]),\n", + " array([-64.02734503, -59.36834589, -52.88777528, -47.8488008 ,\n", + " -39.12020938, -30.00429947, -19.03085421, 0. ]),\n", + " array([-61.35066975, -53.28254387, -46.41961134, -42.342734 ,\n", + " -37.6979335 , -30.06854272, -23.55119613, 0. 
]),\n", + " array([-59.31065636, -52.71621682, -48.55534776, -40.9161892 ,\n", + " -35.05792463, -29.54701469, -27.17912862, 0. ]),\n", + " array([-61.9275178 , -56.78906942, -50.84300074, -42.76652845,\n", + " -36.13238298, -30.5997187 , -20.79458891, 0. ])],\n", + " dtype=object)" ] }, - "execution_count": 12, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -440,7 +447,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "id": "c5097387-244d-40b2-89b6-f631867eb280", "metadata": {}, "outputs": [], @@ -490,7 +497,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "id": "ec0f89b6-3cc5-49ed-a2dd-55792bbbc895", "metadata": {}, "outputs": [], @@ -498,32 +505,39 @@ "def plots_value_changes(df_problem: pd.DataFrame):\n", " # plot per state and traj mapper\n", " rows = []\n", - " for row in df_problem.to_dict(\"records\"):\n", - " state_values = row[\"state_values\"]\n", + " df_ref = copy.deepcopy(df_problem)\n", + " \n", + " for row in df_ref.to_dict(\"records\"):\n", + " state_values = np.stack(row[\"state_values\"])\n", " # shape: (runs x states)\n", - " for run in range(state_values.shape[0]):\n", - " for state in range(state_values.shape[1]):\n", - " new_row = copy.deepcopy(row)\n", - " del new_row[\"state_values\"]\n", - " new_row[\"run\"] = run\n", - " new_row[\"state\"] = state\n", - " new_row[\"state_value\"] = row[\"state_values\"][run, state]\n", - " rows.append(new_row)\n", - " df_methods = pd.DataFrame(rows)\n", + " num_runs, _ = state_values.shape\n", + " # new_row = copy.deepcopy(row)\n", + " for run in range(num_runs):\n", + " rows.extend(\n", + " [\n", + " {**row, \"run\": run, \"state\": state, \"state_value\": value}\n", + " for state, value in enumerate(state_values[run])\n", + " ]\n", + " )\n", "\n", + " df_methods = pd.DataFrame(rows)\n", " ref_row = copy.deepcopy(df_methods.iloc[0]).to_dict()\n", " episodes = df_methods[\"episode\"].unique()\n", " runs = 
df_methods[\"run\"].unique()\n", + " \n", " for episode in episodes:\n", " for run in runs:\n", - " for state in range(len(ref_row[\"dyna_prog_state_values\"])):\n", - " ref_row[\"episode\"] = episode\n", - " ref_row[\"run\"] = run\n", - " ref_row[\"state\"] = state\n", - " ref_row[\"state_value\"] = ref_row[\"dyna_prog_state_values\"][state]\n", - " ref_row[\"traj_mapping_method\"] = \"dynamic-programming\"\n", - " rows.append(copy.deepcopy(ref_row))\n", + " rows.extend(\n", + " [\n", + " {**ref_row, \"episode\": episode, \"run\": run, \n", + " \"state\": state, \"state_value\": value,\n", + " \"traj_mapping_method\": \"dynamic-programming\"}\n", + " for state, value in enumerate(ref_row[\"dyna_prog_state_values\"])\n", + " ]\n", + " )\n", + " \n", " df_plot = pd.DataFrame(rows)\n", + "\n", " palette = sns.color_palette(\"tab10\")\n", " return sns.relplot(\n", " data=df_plot,\n", @@ -544,13 +558,13 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "id": "a6134de5-95c9-435d-a57b-5fb7ef7bab8a", "metadata": {}, "outputs": [], "source": [ "df_abc_td = isolate_problem(\n", - " df_logs, problem=\"ABCSeq\", level=\"7\",\n", + " df_logs, problem=\"ABCSeq\", level=\"n=7\",\n", " algo=\"one-step-td\",\n", " reward_period=4,\n", " gamma=1.0\n", @@ -559,29 +573,8 @@ }, { "cell_type": "code", - "execution_count": 16, - "id": "58542734-fa6b-4c2b-a3e9-2a51a9ccc2e0", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(250, 14)" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df_abc_td.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "b3c3d094-13f1-4dc7-8019-87aac0736529", + "execution_count": 15, + "id": "cc8a59d2-4c01-421b-86b5-9da84b3f354f", "metadata": {}, "outputs": [ { @@ -608,6 +601,7 @@ " episode\n", " exp_id\n", " state_values\n", + " method\n", " algorithm\n", " algorithm_args\n", " discount_factor\n", @@ -623,10 +617,11 @@ " \n", " 
\n", " \n", - " 85\n", + " 126\n", " 0\n", - " 1715263005-0fceea2e-ABCSeq\n", - " [[0.0, -0.4, -2.4, 0.0, -0.4, -2.2, 0.0, 0.0],...\n", + " 1726821587-1ea78f43-ABCSeq\n", + " [[-0.36, -0.36, -1.08, -1.08, 0.0, -0.4, -0.36...\n", + " PP/LEAST\n", " one-step-td\n", " {}\n", " 1.0\n", @@ -634,16 +629,17 @@ " [-48.99999993362666, -41.99999998952205, -34.9...\n", " 0.0\n", " 0.1\n", - " single-step\n", + " PP\n", " 4\n", - " 3\n", - " daaf-trajectory-mapper\n", + " 0\n", + " LEAST\n", " \n", " \n", - " 84\n", + " 124\n", " 0\n", - " 1715263005-21b6c410-ABCSeq\n", - " [[-1.44, -0.36, -1.08, -1.84, -0.4, -0.72, -0....\n", + " 1726821587-21b840ef-ABCSeq\n", + " [[-0.4, 0.0, -0.4, -0.4, -1.12, -0.4, -1.6, 0....\n", + " OP/OT\n", " one-step-td\n", " {}\n", " 1.0\n", @@ -651,16 +647,17 @@ " [-48.99999993362666, -41.99999998952205, -34.9...\n", " 0.0\n", " 0.1\n", - " single-step\n", - " 4\n", + " OP\n", " 4\n", - " daaf-lsq-reward-attribution-mapper\n", + " 0\n", + " OT\n", " \n", " \n", - " 86\n", + " 125\n", " 0\n", - " 1715263005-a0875722-ABCSeq\n", - " [[-0.4, -1.12, 0.0, -1.12, -1.12, 0.0, -0.8, 0...\n", + " 1726821587-2c44d7d4-ABCSeq\n", + " [[0.0, -0.76, -1.08, 0.0, -1.44, -0.36, -1.480...\n", + " PP/IMR\n", " one-step-td\n", " {}\n", " 1.0\n", @@ -668,16 +665,17 @@ " [-48.99999993362666, -41.99999998952205, -34.9...\n", " 0.0\n", " 0.1\n", - " options\n", + " PP\n", " 4\n", - " 0\n", - " daaf-mdp-with-options-mapper\n", + " 7\n", + " IMR\n", " \n", " \n", - " 221\n", + " 127\n", " 0\n", - " 1715263005-bc83cde9-ABCSeq\n", - " [[-0.19, -0.9999999999999999, -0.82, -0.1, -0....\n", + " 1726821587-73d209a3-ABCSeq\n", + " [[0.0, -0.76, 0.0, -0.8, -0.4, -0.4, -0.4, 0.0...\n", + " PP/DMR\n", " one-step-td\n", " {}\n", " 1.0\n", @@ -685,16 +683,17 @@ " [-48.99999993362666, -41.99999998952205, -34.9...\n", " 0.0\n", " 0.1\n", - " single-step\n", - " 1\n", - " 2\n", - " identity-mapper\n", + " PP\n", + " 4\n", + " 7\n", + " DMR\n", " \n", " \n", - " 87\n", - " 0\n", - " 
1715263005-deb09c2d-ABCSeq\n", - " [[-0.36, -0.72, -2.52, -1.44, -1.12, -1.08, -1...\n", + " 630\n", + " 10\n", + " 1726821587-1ea78f43-ABCSeq\n", + " [[-5.32313996368, -6.18825221716, -6.524410501...\n", + " PP/LEAST\n", " one-step-td\n", " {}\n", " 1.0\n", @@ -702,10 +701,10 @@ " [-48.99999993362666, -41.99999998952205, -34.9...\n", " 0.0\n", " 0.1\n", - " single-step\n", + " PP\n", " 4\n", - " 4\n", - " daaf-impute-missing-reward-mapper\n", + " 5\n", + " LEAST\n", " \n", " \n", " ...\n", @@ -723,12 +722,14 @@ " ...\n", " ...\n", " ...\n", + " ...\n", " \n", " \n", - " 73\n", - " 490\n", - " 1715263005-0fceea2e-ABCSeq\n", - " [[-184.2275547627881, -168.00976950574454, -14...\n", + " 227\n", + " 2480\n", + " 1726821587-73d209a3-ABCSeq\n", + " [[-182.56377883692443, -155.14354656423944, -1...\n", + " PP/DMR\n", " one-step-td\n", " {}\n", " 1.0\n", @@ -736,16 +737,17 @@ " [-48.99999993362666, -41.99999998952205, -34.9...\n", " 0.0\n", " 0.1\n", - " single-step\n", + " PP\n", " 4\n", - " 0\n", - " daaf-trajectory-mapper\n", + " 13\n", + " DMR\n", " \n", " \n", - " 72\n", - " 490\n", - " 1715263005-21b6c410-ABCSeq\n", - " [[-44.32532817580987, -36.053428897770765, -28...\n", + " 922\n", + " 2490\n", + " 1726821587-1ea78f43-ABCSeq\n", + " [[-45.99478030246791, -38.75632803206125, -31....\n", + " PP/LEAST\n", " one-step-td\n", " {}\n", " 1.0\n", @@ -753,16 +755,17 @@ " [-48.99999993362666, -41.99999998952205, -34.9...\n", " 0.0\n", " 0.1\n", - " single-step\n", + " PP\n", " 4\n", - " 0\n", - " daaf-lsq-reward-attribution-mapper\n", + " 9\n", + " LEAST\n", " \n", " \n", - " 74\n", - " 490\n", - " 1715263005-a0875722-ABCSeq\n", - " [[-68.52090593819865, -64.48037932133991, -54....\n", + " 920\n", + " 2490\n", + " 1726821587-21b840ef-ABCSeq\n", + " [[-66.92386938335389, -60.87563408620904, -57....\n", + " OP/OT\n", " one-step-td\n", " {}\n", " 1.0\n", @@ -770,16 +773,17 @@ " [-48.99999993362666, -41.99999998952205, -34.9...\n", " 0.0\n", " 0.1\n", - " options\n", + " 
OP\n", " 4\n", - " 4\n", - " daaf-mdp-with-options-mapper\n", + " 9\n", + " OT\n", " \n", " \n", - " 218\n", - " 490\n", - " 1715263005-bc83cde9-ABCSeq\n", - " [[-44.05348390668802, -39.67905217710317, -34....\n", + " 921\n", + " 2490\n", + " 1726821587-2c44d7d4-ABCSeq\n", + " [[-45.47058454401428, -40.75236931873213, -30....\n", + " PP/IMR\n", " one-step-td\n", " {}\n", " 1.0\n", @@ -787,16 +791,17 @@ " [-48.99999993362666, -41.99999998952205, -34.9...\n", " 0.0\n", " 0.1\n", - " single-step\n", - " 1\n", + " PP\n", " 4\n", - " identity-mapper\n", + " 7\n", + " IMR\n", " \n", " \n", - " 75\n", - " 490\n", - " 1715263005-deb09c2d-ABCSeq\n", - " [[-41.98748866631166, -39.082314732788504, -30...\n", + " 923\n", + " 2490\n", + " 1726821587-73d209a3-ABCSeq\n", + " [[-198.9943684873033, -194.72431378660377, -12...\n", + " PP/DMR\n", " one-step-td\n", " {}\n", " 1.0\n", @@ -804,99 +809,86 @@ " [-48.99999993362666, -41.99999998952205, -34.9...\n", " 0.0\n", " 0.1\n", - " single-step\n", + " PP\n", " 4\n", - " 3\n", - " daaf-impute-missing-reward-mapper\n", + " 7\n", + " DMR\n", " \n", " \n", "\n", - "

250 rows × 14 columns

\n", + "

1000 rows × 15 columns

\n", "" ], "text/plain": [ " episode exp_id \\\n", - "85 0 1715263005-0fceea2e-ABCSeq \n", - "84 0 1715263005-21b6c410-ABCSeq \n", - "86 0 1715263005-a0875722-ABCSeq \n", - "221 0 1715263005-bc83cde9-ABCSeq \n", - "87 0 1715263005-deb09c2d-ABCSeq \n", + "126 0 1726821587-1ea78f43-ABCSeq \n", + "124 0 1726821587-21b840ef-ABCSeq \n", + "125 0 1726821587-2c44d7d4-ABCSeq \n", + "127 0 1726821587-73d209a3-ABCSeq \n", + "630 10 1726821587-1ea78f43-ABCSeq \n", ".. ... ... \n", - "73 490 1715263005-0fceea2e-ABCSeq \n", - "72 490 1715263005-21b6c410-ABCSeq \n", - "74 490 1715263005-a0875722-ABCSeq \n", - "218 490 1715263005-bc83cde9-ABCSeq \n", - "75 490 1715263005-deb09c2d-ABCSeq \n", + "227 2480 1726821587-73d209a3-ABCSeq \n", + "922 2490 1726821587-1ea78f43-ABCSeq \n", + "920 2490 1726821587-21b840ef-ABCSeq \n", + "921 2490 1726821587-2c44d7d4-ABCSeq \n", + "923 2490 1726821587-73d209a3-ABCSeq \n", "\n", - " state_values algorithm \\\n", - "85 [[0.0, -0.4, -2.4, 0.0, -0.4, -2.2, 0.0, 0.0],... one-step-td \n", - "84 [[-1.44, -0.36, -1.08, -1.84, -0.4, -0.72, -0.... one-step-td \n", - "86 [[-0.4, -1.12, 0.0, -1.12, -1.12, 0.0, -0.8, 0... one-step-td \n", - "221 [[-0.19, -0.9999999999999999, -0.82, -0.1, -0.... one-step-td \n", - "87 [[-0.36, -0.72, -2.52, -1.44, -1.12, -1.08, -1... one-step-td \n", - ".. ... ... \n", - "73 [[-184.2275547627881, -168.00976950574454, -14... one-step-td \n", - "72 [[-44.32532817580987, -36.053428897770765, -28... one-step-td \n", - "74 [[-68.52090593819865, -64.48037932133991, -54.... one-step-td \n", - "218 [[-44.05348390668802, -39.67905217710317, -34.... one-step-td \n", - "75 [[-41.98748866631166, -39.082314732788504, -30... one-step-td \n", + " state_values method algorithm \\\n", + "126 [[-0.36, -0.36, -1.08, -1.08, 0.0, -0.4, -0.36... PP/LEAST one-step-td \n", + "124 [[-0.4, 0.0, -0.4, -0.4, -1.12, -0.4, -1.6, 0.... OP/OT one-step-td \n", + "125 [[0.0, -0.76, -1.08, 0.0, -1.44, -0.36, -1.480... 
PP/IMR one-step-td \n", + "127 [[0.0, -0.76, 0.0, -0.8, -0.4, -0.4, -0.4, 0.0... PP/DMR one-step-td \n", + "630 [[-5.32313996368, -6.18825221716, -6.524410501... PP/LEAST one-step-td \n", + ".. ... ... ... \n", + "227 [[-182.56377883692443, -155.14354656423944, -1... PP/DMR one-step-td \n", + "922 [[-45.99478030246791, -38.75632803206125, -31.... PP/LEAST one-step-td \n", + "920 [[-66.92386938335389, -60.87563408620904, -57.... OP/OT one-step-td \n", + "921 [[-45.47058454401428, -40.75236931873213, -30.... PP/IMR one-step-td \n", + "923 [[-198.9943684873033, -194.72431378660377, -12... PP/DMR one-step-td \n", "\n", " algorithm_args discount_factor drop_truncated_feedback_episodes \\\n", - "85 {} 1.0 False \n", - "84 {} 1.0 False \n", - "86 {} 1.0 False \n", - "221 {} 1.0 False \n", - "87 {} 1.0 False \n", + "126 {} 1.0 False \n", + "124 {} 1.0 False \n", + "125 {} 1.0 False \n", + "127 {} 1.0 False \n", + "630 {} 1.0 False \n", ".. ... ... ... \n", - "73 {} 1.0 False \n", - "72 {} 1.0 False \n", - "74 {} 1.0 False \n", - "218 {} 1.0 False \n", - "75 {} 1.0 False \n", + "227 {} 1.0 False \n", + "922 {} 1.0 False \n", + "920 {} 1.0 False \n", + "921 {} 1.0 False \n", + "923 {} 1.0 False \n", "\n", " dyna_prog_state_values epsilon \\\n", - "85 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", - "84 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", - "86 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", - "221 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", - "87 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", + "126 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", + "124 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", + "125 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", + "127 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", + "630 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", ".. ... ... \n", - "73 [-48.99999993362666, -41.99999998952205, -34.9... 
0.0 \n", - "72 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", - "74 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", - "218 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", - "75 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", - "\n", - " learning_rate policy_type reward_period run_id \\\n", - "85 0.1 single-step 4 3 \n", - "84 0.1 single-step 4 4 \n", - "86 0.1 options 4 0 \n", - "221 0.1 single-step 1 2 \n", - "87 0.1 single-step 4 4 \n", - ".. ... ... ... ... \n", - "73 0.1 single-step 4 0 \n", - "72 0.1 single-step 4 0 \n", - "74 0.1 options 4 4 \n", - "218 0.1 single-step 1 4 \n", - "75 0.1 single-step 4 3 \n", + "227 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", + "922 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", + "920 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", + "921 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", + "923 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", "\n", - " traj_mapping_method \n", - "85 daaf-trajectory-mapper \n", - "84 daaf-lsq-reward-attribution-mapper \n", - "86 daaf-mdp-with-options-mapper \n", - "221 identity-mapper \n", - "87 daaf-impute-missing-reward-mapper \n", - ".. ... \n", - "73 daaf-trajectory-mapper \n", - "72 daaf-lsq-reward-attribution-mapper \n", - "74 daaf-mdp-with-options-mapper \n", - "218 identity-mapper \n", - "75 daaf-impute-missing-reward-mapper \n", + " learning_rate policy_type reward_period run_id traj_mapping_method \n", + "126 0.1 PP 4 0 LEAST \n", + "124 0.1 OP 4 0 OT \n", + "125 0.1 PP 4 7 IMR \n", + "127 0.1 PP 4 7 DMR \n", + "630 0.1 PP 4 5 LEAST \n", + ".. ... ... ... ... ... 
\n", + "227 0.1 PP 4 13 DMR \n", + "922 0.1 PP 4 9 LEAST \n", + "920 0.1 OP 4 9 OT \n", + "921 0.1 PP 4 7 IMR \n", + "923 0.1 PP 4 7 DMR \n", "\n", - "[250 rows x 14 columns]" + "[1000 rows x 15 columns]" ] }, - "execution_count": 17, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -907,170 +899,942 @@ }, { "cell_type": "code", - "execution_count": 18, - "id": "f44eb0de-3f33-4202-93b9-5e98bb5fca12", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array(['daaf-trajectory-mapper', 'daaf-lsq-reward-attribution-mapper',\n", - " 'daaf-mdp-with-options-mapper', 'identity-mapper',\n", - " 'daaf-impute-missing-reward-mapper'], dtype=object)" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df_abc_td[\"traj_mapping_method\"].unique()" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "9bafca96-a251-436e-8b77-024eba96fc53", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(5, 8)" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "df_abc_td[\"state_values\"].iloc[0].shape" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "6564c1d0-2573-4c2e-afce-b9a75827ec45", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/tmp/ipykernel_3266377/3725396169.py:31: UserWarning: The palette list has more values (10) than needed (6), which may not be intended.\n", - " return sns.relplot(\n" - ] - }, - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "plots_value_changes(df_abc_td)" - ] - }, - { - "cell_type": "markdown", - "id": "a49f0cbf-c120-4d08-96e1-ed8b2c668769", - "metadata": {}, - "source": [ - "## Monte Carlo" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "4958f05d-67c9-441f-9eff-6cd37717fb99", - "metadata": {}, - "outputs": [], - "source": [ - "df_abc_mc = isolate_problem(\n", - " df_logs, problem=\"ABCSeq\", level=\"7\",\n", - " algo=\"first-visit-mc\",\n", - " reward_period=4,\n", - " gamma=1.0\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "6317b523-ccac-4435-a21e-23cad48ac96a", + "execution_count": 16, + "id": "58542734-fa6b-4c2b-a3e9-2a51a9ccc2e0", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "array(['daaf-impute-missing-reward-mapper', 'daaf-trajectory-mapper',\n", - " 'identity-mapper', 'daaf-mdp-with-options-mapper',\n", - " 'daaf-lsq-reward-attribution-mapper'], dtype=object)" + "(1000, 15)" ] }, - "execution_count": 22, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "df_abc_mc[\"traj_mapping_method\"].unique()" + "df_abc_td.shape" ] }, { "cell_type": "code", - "execution_count": 23, - "id": "67ff8052-7972-4bcb-808b-7c48b0138b4e", + "execution_count": 17, + "id": "b3c3d094-13f1-4dc7-8019-87aac0736529", "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/tmp/ipykernel_3266377/3725396169.py:31: UserWarning: The palette list has more values (10) than needed (6), which may not be intended.\n", - " return sns.relplot(\n" - ] - }, - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - }, { "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "plots_value_changes(df_abc_mc)" + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + "
episodeexp_idstate_valuesmethodalgorithmalgorithm_argsdiscount_factordrop_truncated_feedback_episodesdyna_prog_state_valuesepsilonlearning_ratepolicy_typereward_periodrun_idtraj_mapping_method
12601726821587-1ea78f43-ABCSeq[[-0.36, -0.36, -1.08, -1.08, 0.0, -0.4, -0.36...PP/LEASTone-step-td{}1.0False[-48.99999993362666, -41.99999998952205, -34.9...0.00.1PP40LEAST
12401726821587-21b840ef-ABCSeq[[-0.4, 0.0, -0.4, -0.4, -1.12, -0.4, -1.6, 0....OP/OTone-step-td{}1.0False[-48.99999993362666, -41.99999998952205, -34.9...0.00.1OP40OT
12501726821587-2c44d7d4-ABCSeq[[0.0, -0.76, -1.08, 0.0, -1.44, -0.36, -1.480...PP/IMRone-step-td{}1.0False[-48.99999993362666, -41.99999998952205, -34.9...0.00.1PP47IMR
12701726821587-73d209a3-ABCSeq[[0.0, -0.76, 0.0, -0.8, -0.4, -0.4, -0.4, 0.0...PP/DMRone-step-td{}1.0False[-48.99999993362666, -41.99999998952205, -34.9...0.00.1PP47DMR
630101726821587-1ea78f43-ABCSeq[[-5.32313996368, -6.18825221716, -6.524410501...PP/LEASTone-step-td{}1.0False[-48.99999993362666, -41.99999998952205, -34.9...0.00.1PP45LEAST
................................................
22724801726821587-73d209a3-ABCSeq[[-182.56377883692443, -155.14354656423944, -1...PP/DMRone-step-td{}1.0False[-48.99999993362666, -41.99999998952205, -34.9...0.00.1PP413DMR
92224901726821587-1ea78f43-ABCSeq[[-45.99478030246791, -38.75632803206125, -31....PP/LEASTone-step-td{}1.0False[-48.99999993362666, -41.99999998952205, -34.9...0.00.1PP49LEAST
92024901726821587-21b840ef-ABCSeq[[-66.92386938335389, -60.87563408620904, -57....OP/OTone-step-td{}1.0False[-48.99999993362666, -41.99999998952205, -34.9...0.00.1OP49OT
92124901726821587-2c44d7d4-ABCSeq[[-45.47058454401428, -40.75236931873213, -30....PP/IMRone-step-td{}1.0False[-48.99999993362666, -41.99999998952205, -34.9...0.00.1PP47IMR
92324901726821587-73d209a3-ABCSeq[[-198.9943684873033, -194.72431378660377, -12...PP/DMRone-step-td{}1.0False[-48.99999993362666, -41.99999998952205, -34.9...0.00.1PP47DMR
\n", + "

1000 rows × 15 columns

\n", + "
" + ], + "text/plain": [ + " episode exp_id \\\n", + "126 0 1726821587-1ea78f43-ABCSeq \n", + "124 0 1726821587-21b840ef-ABCSeq \n", + "125 0 1726821587-2c44d7d4-ABCSeq \n", + "127 0 1726821587-73d209a3-ABCSeq \n", + "630 10 1726821587-1ea78f43-ABCSeq \n", + ".. ... ... \n", + "227 2480 1726821587-73d209a3-ABCSeq \n", + "922 2490 1726821587-1ea78f43-ABCSeq \n", + "920 2490 1726821587-21b840ef-ABCSeq \n", + "921 2490 1726821587-2c44d7d4-ABCSeq \n", + "923 2490 1726821587-73d209a3-ABCSeq \n", + "\n", + " state_values method algorithm \\\n", + "126 [[-0.36, -0.36, -1.08, -1.08, 0.0, -0.4, -0.36... PP/LEAST one-step-td \n", + "124 [[-0.4, 0.0, -0.4, -0.4, -1.12, -0.4, -1.6, 0.... OP/OT one-step-td \n", + "125 [[0.0, -0.76, -1.08, 0.0, -1.44, -0.36, -1.480... PP/IMR one-step-td \n", + "127 [[0.0, -0.76, 0.0, -0.8, -0.4, -0.4, -0.4, 0.0... PP/DMR one-step-td \n", + "630 [[-5.32313996368, -6.18825221716, -6.524410501... PP/LEAST one-step-td \n", + ".. ... ... ... \n", + "227 [[-182.56377883692443, -155.14354656423944, -1... PP/DMR one-step-td \n", + "922 [[-45.99478030246791, -38.75632803206125, -31.... PP/LEAST one-step-td \n", + "920 [[-66.92386938335389, -60.87563408620904, -57.... OP/OT one-step-td \n", + "921 [[-45.47058454401428, -40.75236931873213, -30.... PP/IMR one-step-td \n", + "923 [[-198.9943684873033, -194.72431378660377, -12... PP/DMR one-step-td \n", + "\n", + " algorithm_args discount_factor drop_truncated_feedback_episodes \\\n", + "126 {} 1.0 False \n", + "124 {} 1.0 False \n", + "125 {} 1.0 False \n", + "127 {} 1.0 False \n", + "630 {} 1.0 False \n", + ".. ... ... ... \n", + "227 {} 1.0 False \n", + "922 {} 1.0 False \n", + "920 {} 1.0 False \n", + "921 {} 1.0 False \n", + "923 {} 1.0 False \n", + "\n", + " dyna_prog_state_values epsilon \\\n", + "126 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", + "124 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", + "125 [-48.99999993362666, -41.99999998952205, -34.9... 
0.0 \n", + "127 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", + "630 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", + ".. ... ... \n", + "227 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", + "922 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", + "920 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", + "921 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", + "923 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", + "\n", + " learning_rate policy_type reward_period run_id traj_mapping_method \n", + "126 0.1 PP 4 0 LEAST \n", + "124 0.1 OP 4 0 OT \n", + "125 0.1 PP 4 7 IMR \n", + "127 0.1 PP 4 7 DMR \n", + "630 0.1 PP 4 5 LEAST \n", + ".. ... ... ... ... ... \n", + "227 0.1 PP 4 13 DMR \n", + "922 0.1 PP 4 9 LEAST \n", + "920 0.1 OP 4 9 OT \n", + "921 0.1 PP 4 7 IMR \n", + "923 0.1 PP 4 7 DMR \n", + "\n", + "[1000 rows x 15 columns]" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_abc_td" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "f44eb0de-3f33-4202-93b9-5e98bb5fca12", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array(['LEAST', 'OT', 'IMR', 'DMR'], dtype=object)" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_abc_td[\"traj_mapping_method\"].unique()" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "9bafca96-a251-436e-8b77-024eba96fc53", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(20,)" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_abc_td[\"state_values\"].iloc[0].shape" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "6564c1d0-2573-4c2e-afce-b9a75827ec45", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + 
"/tmp/ipykernel_3114566/2457586796.py:38: UserWarning: The palette list has more values (10) than needed (5), which may not be intended.\n", + " return sns.relplot(\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plots_value_changes(df_abc_td)" + ] + }, + { + "cell_type": "markdown", + "id": "a49f0cbf-c120-4d08-96e1-ed8b2c668769", + "metadata": {}, + "source": [ + "## Monte Carlo" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "4958f05d-67c9-441f-9eff-6cd37717fb99", + "metadata": {}, + "outputs": [], + "source": [ + "df_abc_mc = isolate_problem(\n", + " df_logs, problem=\"ABCSeq\", level=\"n=7\",\n", + " algo=\"first-visit-mc\",\n", + " reward_period=4,\n", + " gamma=1.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "6317b523-ccac-4435-a21e-23cad48ac96a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array(['DMR', 'LEAST', 'IMR', 'OT'], dtype=object)" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_abc_mc[\"traj_mapping_method\"].unique()" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "67ff8052-7972-4bcb-808b-7c48b0138b4e", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_3114566/2457586796.py:38: UserWarning: The palette list has more values (10) than needed (5), which may not be intended.\n", + " return sns.relplot(\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plots_value_changes(df_abc_mc)" + ] + }, + { + "cell_type": "markdown", + "id": "3a07d7a5-1e00-4873-928f-74f20b66b0f5", + "metadata": {}, + "source": [ + "## Full Rewards" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "5737ec07-f812-4d6b-bbde-60dd72bf2408", + "metadata": {}, + "outputs": [], + "source": [ + "df_fr = isolate_problem(\n", + " df_logs, problem=\"ABCSeq\", level=\"n=7\",\n", + " algo=\"one-step-td\",\n", + " reward_period=1,\n", + " gamma=1.0\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "5fa7b13c-1c92-49fc-a2a1-0ee4ab28416d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " 
\n", + "
episodeexp_idstate_valuesmethodalgorithmalgorithm_argsdiscount_factordrop_truncated_feedback_episodesdyna_prog_state_valuesepsilonlearning_ratepolicy_typereward_periodrun_idtraj_mapping_method
3101726821587-74818508-ABCSeq[[-0.37000000000000005, -0.1, -0.3700000000000...PP/FRone-step-td{}1.0False[-48.99999993362666, -41.99999998952205, -34.9...0.00.1PP10FR
157101726821587-74818508-ABCSeq[[-8.10481282220499, -8.441783053979991, -7.75...PP/FRone-step-td{}1.0False[-48.99999993362666, -41.99999998952205, -34.9...0.00.1PP15FR
243201726821587-74818508-ABCSeq[[-13.263838468341735, -15.149045449201608, -1...PP/FRone-step-td{}1.0False[-48.99999993362666, -41.99999998952205, -34.9...0.00.1PP13FR
172301726821587-74818508-ABCSeq[[-21.37771271226704, -20.82920026250384, -19....PP/FRone-step-td{}1.0False[-48.99999993362666, -41.99999998952205, -34.9...0.00.1PP113FR
91401726821587-74818508-ABCSeq[[-27.699128593575917, -25.629846225256244, -2...PP/FRone-step-td{}1.0False[-48.99999993362666, -41.99999998952205, -34.9...0.00.1PP15FR
................................................
20224501726821587-74818508-ABCSeq[[-53.88400977444208, -44.773783195429914, -34...PP/FRone-step-td{}1.0False[-48.99999993362666, -41.99999998952205, -34.9...0.00.1PP114FR
22124601726821587-74818508-ABCSeq[[-43.45472354381776, -34.5372137401518, -30.6...PP/FRone-step-td{}1.0False[-48.99999993362666, -41.99999998952205, -34.9...0.00.1PP17FR
10124701726821587-74818508-ABCSeq[[-45.73286567267412, -39.93073384516813, -30....PP/FRone-step-td{}1.0False[-48.99999993362666, -41.99999998952205, -34.9...0.00.1PP15FR
5624801726821587-74818508-ABCSeq[[-45.24511475426608, -41.35581741781331, -31....PP/FRone-step-td{}1.0False[-48.99999993362666, -41.99999998952205, -34.9...0.00.1PP11FR
23024901726821587-74818508-ABCSeq[[-43.208181578307695, -34.8156996409463, -27....PP/FRone-step-td{}1.0False[-48.99999993362666, -41.99999998952205, -34.9...0.00.1PP19FR
\n", + "

250 rows × 15 columns

\n", + "
" + ], + "text/plain": [ + " episode exp_id \\\n", + "31 0 1726821587-74818508-ABCSeq \n", + "157 10 1726821587-74818508-ABCSeq \n", + "243 20 1726821587-74818508-ABCSeq \n", + "172 30 1726821587-74818508-ABCSeq \n", + "91 40 1726821587-74818508-ABCSeq \n", + ".. ... ... \n", + "202 2450 1726821587-74818508-ABCSeq \n", + "221 2460 1726821587-74818508-ABCSeq \n", + "101 2470 1726821587-74818508-ABCSeq \n", + "56 2480 1726821587-74818508-ABCSeq \n", + "230 2490 1726821587-74818508-ABCSeq \n", + "\n", + " state_values method algorithm \\\n", + "31 [[-0.37000000000000005, -0.1, -0.3700000000000... PP/FR one-step-td \n", + "157 [[-8.10481282220499, -8.441783053979991, -7.75... PP/FR one-step-td \n", + "243 [[-13.263838468341735, -15.149045449201608, -1... PP/FR one-step-td \n", + "172 [[-21.37771271226704, -20.82920026250384, -19.... PP/FR one-step-td \n", + "91 [[-27.699128593575917, -25.629846225256244, -2... PP/FR one-step-td \n", + ".. ... ... ... \n", + "202 [[-53.88400977444208, -44.773783195429914, -34... PP/FR one-step-td \n", + "221 [[-43.45472354381776, -34.5372137401518, -30.6... PP/FR one-step-td \n", + "101 [[-45.73286567267412, -39.93073384516813, -30.... PP/FR one-step-td \n", + "56 [[-45.24511475426608, -41.35581741781331, -31.... PP/FR one-step-td \n", + "230 [[-43.208181578307695, -34.8156996409463, -27.... PP/FR one-step-td \n", + "\n", + " algorithm_args discount_factor drop_truncated_feedback_episodes \\\n", + "31 {} 1.0 False \n", + "157 {} 1.0 False \n", + "243 {} 1.0 False \n", + "172 {} 1.0 False \n", + "91 {} 1.0 False \n", + ".. ... ... ... \n", + "202 {} 1.0 False \n", + "221 {} 1.0 False \n", + "101 {} 1.0 False \n", + "56 {} 1.0 False \n", + "230 {} 1.0 False \n", + "\n", + " dyna_prog_state_values epsilon \\\n", + "31 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", + "157 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", + "243 [-48.99999993362666, -41.99999998952205, -34.9... 
0.0 \n", + "172 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", + "91 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", + ".. ... ... \n", + "202 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", + "221 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", + "101 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", + "56 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", + "230 [-48.99999993362666, -41.99999998952205, -34.9... 0.0 \n", + "\n", + " learning_rate policy_type reward_period run_id traj_mapping_method \n", + "31 0.1 PP 1 0 FR \n", + "157 0.1 PP 1 5 FR \n", + "243 0.1 PP 1 3 FR \n", + "172 0.1 PP 1 13 FR \n", + "91 0.1 PP 1 5 FR \n", + ".. ... ... ... ... ... \n", + "202 0.1 PP 1 14 FR \n", + "221 0.1 PP 1 7 FR \n", + "101 0.1 PP 1 5 FR \n", + "56 0.1 PP 1 1 FR \n", + "230 0.1 PP 1 9 FR \n", + "\n", + "[250 rows x 15 columns]" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_fr" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "86e63173-c58b-4bff-b330-c79f12c2a752", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([array([-0.37, -0.1 , -0.37, -0.37, -0.1 , -0.46, -0.37, 0. ]),\n", + " array([-0.73, -0.46, -0.28, -0.19, -0.73, -0.46, -0.19, 0. ]),\n", + " array([-0.46, -0.19, -1.54, -0.73, -0.1 , -0.46, -0.91, 0. ]),\n", + " array([-0.64, -0.46, -0.37, -0.28, -0.1 , -0.28, -0.37, 0. ]),\n", + " array([-0.1 , -0.19, -0.37, -0.37, -1.09, -1.18, -0.28, 0. ]),\n", + " array([-1.18, -0.46, -0.19, -0.28, -0.28, -1.18, -0.46, 0. ]),\n", + " array([-0.28, -0.1 , -0.19, -1.18, -0.19, -0.19, -0.1 , 0. ]),\n", + " array([-0.55, -0.37, -0.19, -0.91, -0.73, -0.19, -1.54, 0. ]),\n", + " array([-0.1 , -0.64, -0.64, -0.1 , -0.19, -0.19, -0.91, 0. ]),\n", + " array([-1.54, -0.19, -0.55, -0.1 , -0.28, -0.1 , -0.1 , 0. ]),\n", + " array([-0.19, -0.1 , -1.09, -1. , -0.19, -0.91, -0.37, 0. 
]),\n", + " array([-0.28, -0.37, -0.19, -0.73, -0.37, -0.19, -0.46, 0. ]),\n", + " array([-0.28, -0.1 , -0.28, -0.37, -0.91, -1.54, -0.55, 0. ]),\n", + " array([-0.37, -0.1 , -0.19, -0.1 , -0.37, -1. , -0.1 , 0. ]),\n", + " array([-1.36, -0.46, -0.1 , -0.55, -0.19, -1.99, -0.19, 0. ]),\n", + " array([-0.19, -1.36, -0.1 , -0.46, -0.19, -0.37, -0.19, 0. ]),\n", + " array([-0.28, -0.19, -0.19, -1.09, -0.1 , -0.37, -0.28, 0. ]),\n", + " array([-0.1 , -1. , -0.55, -0.1 , -0.1 , -0.64, -1.45, 0. ]),\n", + " array([-0.37, -0.1 , -0.1 , -0.28, -0.55, -0.37, -0.28, 0. ]),\n", + " array([-0.28, -0.46, -0.73, -0.28, -1.99, -0.1 , -0.28, 0. ])],\n", + " dtype=object)" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_fr.iloc[0][\"state_values\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "8a44abaf-010e-4857-9f54-779f2d2f12e1", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_3114566/2457586796.py:38: UserWarning: The palette list has more values (10) than needed (2), which may not be intended.\n", + " return sns.relplot(\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plots_value_changes(df_fr)" ] } ], @@ -1090,7 +1854,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.15" + "version": "3.9.16" } }, "nbformat": 4, diff --git a/notebooks/daaf_analyses/reward_estimation/reward-recovery-lskv-factor-ts.ipynb b/notebooks/daaf_analyses/reward_estimation/reward-recovery-lskv-factor-ts.ipynb new file mode 100644 index 0000000..0d5ae46 --- /dev/null +++ b/notebooks/daaf_analyses/reward_estimation/reward-recovery-lskv-factor-ts.ipynb @@ -0,0 +1,582 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import copy\n", + "import pathlib\n", + "import os.path\n", + "import json\n", + "import hashlib\n", + "import itertools\n", + "import collections\n", + "from typing import Any, Sequence, Mapping, Set" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-08-08 12:40:42.277525: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n", + "2024-08-08 12:40:42.352900: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n", + "2024-08-08 12:40:42.355094: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", + "2024-08-08 12:40:44.074258: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + 
"import tensorflow as tf" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib as mpl\n", + "from matplotlib import pyplot as plt\n", + "import seaborn as sns" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from rlplg import envsuite, core\n", + "from daaf import estimator_metrics" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "ENVS_MAPPING = {\n", + " (\n", + " \"IceWorld\",\n", + " \"4KE3ASUFQGGUPERSDDRQAZAMA46CI2CMCJHGWJ7MRNI64JMEBETNDXFFPYWTQJF46S5BJ4NXXCHNMJSLII3ROYXI76DFOC3VAABGNVA=\",\n", + " ): {\"args\": '{\"map_name\": \"4x4\"}', \"name\": \"4x4\"},\n", + " (\"ABCSeq\", \"2\"): {\n", + " \"args\": '{\"length\": 3, \"distance_penalty\": false}',\n", + " \"name\": \"n=3\",\n", + " }, \n", + " (\"ABCSeq\", \"3\"): {\n", + " \"args\": '{\"length\": 3, \"distance_penalty\": false}',\n", + " \"name\": \"n=3\",\n", + " }, \n", + " (\"ABCSeq\", \"10\"): {\n", + " \"args\": '{\"length\": 10, \"distance_penalty\": false}',\n", + " \"name\": \"n=10\",\n", + " },\n", + " (\n", + " \"RedGreenSeq\",\n", + " \"NNLHYJFTC5ENMMDZWRNQ37B6VVDXQ7WHB5EJOPXYZFLMJEZOYLTSLB4ID4WHQG57XQPNUHGZCFDCWHYGXWSBW7FBWYRZGAGBW4J7MEQ=\",\n", + " ): {\n", + " \"args\": '{\"cure\": [\"red\", \"green\", \"wait\", \"green\", \"red\", \"red\", \"green\", \"wait\"]}',\n", + " \"name\": \"n=9\",\n", + " },\n", + " (\n", + " \"FrozenLake-v1\",\n", + " \"U75ZLQLLXYRFQE5KOJJGNVQZGQ65U5RVVN3ZV5F4UNYQVK6NGTAAU62O2DKMOEGACNNUQOSWGYYOV7LQHK7GAWG2CL3U3RZJFIEIB5I=\",\n", + " ): {\"args\": '{\"is_slippery\": false, \"map_name\": \"4x4\"}', \"name\": \"4x4\"},\n", + " (\"TowerOfHanoi\", \"4\"): {\"args\": '{\"num_disks\": 4}', \"name\": \"disks=4\"},\n", + " (\"ABCSeq\", \"7\"): {\n", + " \"args\": '{\"length\": 7, \"distance_penalty\": false}',\n", + " \"name\": \"n=7\",\n", + " },\n", + " (\n", + " 
\"IceWorld\",\n", + " \"JKNDNWGM45FELU53ZLLVJEPY2SFZBCX54PSACOQOFMTDUAK5VNQ4KE45QZINGYFU5GR6D7F3GJMW7EC4TAY5PHCYRN5GPGP7YNACHEI=\",\n", + " ): {\"args\": '{\"map_name\": \"8x8\"}', \"name\": \"8x8\"},\n", + " (\n", + " \"GridWorld\",\n", + " \"P3VJZBIJ7PNUOFG2SCF532NH5AQ6NOBZEZ6UZNZ7D3AU3GQZSLKURMS2SRPEUF6O65F3ETJXEFNTR3UYS73TUCIIU3YIONXHAR6WE5A=\",\n", + " ): {\n", + " \"args\": '{\"grid\": \"oooooooooooo\\\\noooooooooooo\\\\noooooooooooo\\\\nsxxxxxxxxxxg\"}',\n", + " \"name\": \"4x12\",\n", + " },\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "EXCLUDED_ENVS = set([\"FrozenLake-v1\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Data" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "PATH = str(pathlib.Path.home() / \"fs/daaf/exp/reward-recovery/1723120236-report.json\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "def read_data(path: str) -> pd.DataFrame:\n", + " return pd.read_json(path)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "df_raw = read_data(PATH)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
specmethodoutput
0{'name': 'ABCSeq', 'args': {'length': 3, 'dist...factor-ts{'least': [[-1.1666666667, -0.75, -1.083333333...
1{'name': 'ABCSeq', 'args': {'length': 2, 'dist...factor-ts{'least': [[-1.0, -1.0], [-1.0, -1.0]], 'ols-e...
2{'name': 'RedGreenSeq', 'args': {'cure': ['red...factor-ts{'least': [[-1.0, -1.0, -1.0], [-1.0, -1.0, -1...
3{'name': 'RedGreenSeq', 'args': {'cure': ['red...plain{'least': [[-1.0, -1.0, -1.0], [-1.0, -1.0, -1...
4{'name': 'ABCSeq', 'args': {'length': 3, 'dist...plain{'least': [[-1.0, -1.0, -1.0], [-1.0, -1.0, -1...
5{'name': 'ABCSeq', 'args': {'length': 2, 'dist...plain{'least': [[-1.0, -1.0], [-1.0, -1.0], [0.0, 0...
\n", + "
" + ], + "text/plain": [ + " spec method \\\n", + "0 {'name': 'ABCSeq', 'args': {'length': 3, 'dist... factor-ts \n", + "1 {'name': 'ABCSeq', 'args': {'length': 2, 'dist... factor-ts \n", + "2 {'name': 'RedGreenSeq', 'args': {'cure': ['red... factor-ts \n", + "3 {'name': 'RedGreenSeq', 'args': {'cure': ['red... plain \n", + "4 {'name': 'ABCSeq', 'args': {'length': 3, 'dist... plain \n", + "5 {'name': 'ABCSeq', 'args': {'length': 2, 'dist... plain \n", + "\n", + " output \n", + "0 {'least': [[-1.1666666667, -0.75, -1.083333333... \n", + "1 {'least': [[-1.0, -1.0], [-1.0, -1.0]], 'ols-e... \n", + "2 {'least': [[-1.0, -1.0, -1.0], [-1.0, -1.0, -1... \n", + "3 {'least': [[-1.0, -1.0, -1.0], [-1.0, -1.0, -1... \n", + "4 {'least': [[-1.0, -1.0, -1.0], [-1.0, -1.0, -1... \n", + "5 {'least': [[-1.0, -1.0], [-1.0, -1.0], [0.0, 0... " + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_raw" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'name': 'ABCSeq', 'args': {'length': 3, 'distance_penalty': False}}" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_raw.iloc[0][\"spec\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "def get_env_level(env_spec):\n", + " loaded_env_spec = envsuite.load(env_spec[\"name\"], **env_spec[\"args\"])\n", + " return loaded_env_spec.level" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'3'" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "get_env_level(df_raw.iloc[0][\"spec\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "def proc_data(df_raw: pd.DataFrame) -> 
pd.DataFrame:\n", + " rows = []\n", + " for row in df_raw.to_dict(\"records\"):\n", + " new_row = copy.deepcopy(row)\n", + " \n", + " # add env id;\n", + " new_row[\"env_name\"] = new_row[\"spec\"][\"name\"]\n", + " new_row[\"env_level\"] = get_env_level(new_row[\"spec\"])\n", + " eid = (new_row[\"spec\"][\"name\"], new_row[\"env_level\"])\n", + " new_row[\"env_label\"] = ENVS_MAPPING[eid][\"name\"]\n", + " rows.append(new_row)\n", + " return pd.DataFrame(rows)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "df_results = proc_data(df_raw)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
specmethodoutputenv_nameenv_levelenv_label
0{'name': 'ABCSeq', 'args': {'length': 3, 'dist...factor-ts{'least': [[-1.1666666667, -0.75, -1.083333333...ABCSeq3n=3
1{'name': 'ABCSeq', 'args': {'length': 2, 'dist...factor-ts{'least': [[-1.0, -1.0], [-1.0, -1.0]], 'ols-e...ABCSeq2n=3
2{'name': 'RedGreenSeq', 'args': {'cure': ['red...factor-ts{'least': [[-1.0, -1.0, -1.0], [-1.0, -1.0, -1...RedGreenSeqNNLHYJFTC5ENMMDZWRNQ37B6VVDXQ7WHB5EJOPXYZFLMJE...n=9
3{'name': 'RedGreenSeq', 'args': {'cure': ['red...plain{'least': [[-1.0, -1.0, -1.0], [-1.0, -1.0, -1...RedGreenSeqNNLHYJFTC5ENMMDZWRNQ37B6VVDXQ7WHB5EJOPXYZFLMJE...n=9
4{'name': 'ABCSeq', 'args': {'length': 3, 'dist...plain{'least': [[-1.0, -1.0, -1.0], [-1.0, -1.0, -1...ABCSeq3n=3
5{'name': 'ABCSeq', 'args': {'length': 2, 'dist...plain{'least': [[-1.0, -1.0], [-1.0, -1.0], [0.0, 0...ABCSeq2n=3
\n", + "
" + ], + "text/plain": [ + " spec method \\\n", + "0 {'name': 'ABCSeq', 'args': {'length': 3, 'dist... factor-ts \n", + "1 {'name': 'ABCSeq', 'args': {'length': 2, 'dist... factor-ts \n", + "2 {'name': 'RedGreenSeq', 'args': {'cure': ['red... factor-ts \n", + "3 {'name': 'RedGreenSeq', 'args': {'cure': ['red... plain \n", + "4 {'name': 'ABCSeq', 'args': {'length': 3, 'dist... plain \n", + "5 {'name': 'ABCSeq', 'args': {'length': 2, 'dist... plain \n", + "\n", + " output env_name \\\n", + "0 {'least': [[-1.1666666667, -0.75, -1.083333333... ABCSeq \n", + "1 {'least': [[-1.0, -1.0], [-1.0, -1.0]], 'ols-e... ABCSeq \n", + "2 {'least': [[-1.0, -1.0, -1.0], [-1.0, -1.0, -1... RedGreenSeq \n", + "3 {'least': [[-1.0, -1.0, -1.0], [-1.0, -1.0, -1... RedGreenSeq \n", + "4 {'least': [[-1.0, -1.0, -1.0], [-1.0, -1.0, -1... ABCSeq \n", + "5 {'least': [[-1.0, -1.0], [-1.0, -1.0], [0.0, 0... ABCSeq \n", + "\n", + " env_level env_label \n", + "0 3 n=3 \n", + "1 2 n=3 \n", + "2 NNLHYJFTC5ENMMDZWRNQ37B6VVDXQ7WHB5EJOPXYZFLMJE... n=9 \n", + "3 NNLHYJFTC5ENMMDZWRNQ37B6VVDXQ7WHB5EJOPXYZFLMJE... 
n=9 \n", + "4 3 n=3 \n", + "5 2 n=3 " + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_results" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "('factor-ts',\n", + " [[1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", + " [1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0],\n", + " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 3.0, 0.0],\n", + " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.0, 0.0, 1.0],\n", + " [1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0],\n", + " [1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],\n", + " [1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0],\n", + " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 4.0, 0.0],\n", + " [0.0, 1.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],\n", + " [-4.0, -4.0, -4.0, -3.0, -4.0, -4.0, -4.0, -4.0, -4.0],\n", + " [[-1.1666666667, -0.75, -1.0833333333],\n", + " [-1.0, -0.75, -1.0833333333],\n", + " [-1.0, -1.0, -1.0]])" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_results.iloc[0][\"method\"], df_results.iloc[0][\"output\"][\"matrix\"], df_results.iloc[0][\"output\"][\"rhs\"], df_results.iloc[0][\"output\"][\"least\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The factor-ts matrix above excludes terminal states." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "('plain',\n", + " [[1.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", + " [0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],\n", + " [1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", + " [0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],\n", + " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 1.0, 0.0, 0.0, 0.0],\n", + " [0.0, 1.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", + " [1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", + " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0],\n", + " [1.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", + " [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 1.0, 0.0],\n", + " [0.0, 0.0, 0.0, 2.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n", + " [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0]],\n", + " [-4.0, -4.0, -4.0, -4.0, -4.0, -4.0, -4.0, -3.0, -4.0, -3.0, -4.0, -3.0],\n", + " [[-1.0, -1.0, -1.0],\n", + " [-1.0, -1.0, -1.0],\n", + " [-1.0, -1.0, -1.0],\n", + " [3.7021965659999997e-16, 2.908664673e-16, -0.0]])" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_results.iloc[4][\"method\"], df_results.iloc[4][\"output\"][\"matrix\"], df_results.iloc[4][\"output\"][\"rhs\"], df_results.iloc[4][\"output\"][\"least\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `factor-ts` estimate above is an example of cases where reward recovery has multiple possible solutions. 
Without any compensating strategies, the rewards learned in such cases can be incorrect.\n", + "\n", + "The `plain` estimate has correct values for most state-action pairs - this is because visitations to the terminal state anchor the values of non-terminal states visited before it.\n", + "Still, there are some minor errors in the values estimates of rewards for the terminal states, which can be manually corrected since we know them." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/daaf_analyses/reward_estimation/reward_recovery_example.py b/notebooks/daaf_analyses/reward_estimation/reward_recovery_example.py index 4303f15..b23365c 100644 --- a/notebooks/daaf_analyses/reward_estimation/reward_recovery_example.py +++ b/notebooks/daaf_analyses/reward_estimation/reward_recovery_example.py @@ -7,10 +7,9 @@ import numpy as np import pandas as pd -from rlplg import envplay, envsuite -from rlplg.learning.tabular import policies - from daaf import math_ops, replay_mapper +from rlplg import core, envplay, envsuite +from rlplg.learning.tabular import policies ENV_SPECS = [ {"name": "ABCSeq", "args": {"length": 7}}, @@ -107,6 +106,7 @@ def estimate_reward( buffer_size=env_spec.mdp.env_desc.num_states * env_spec.mdp.env_desc.num_actions * BUFFER_MULT, + terminal_states=core.infer_env_terminal_states(env_spec.mdp.transition), ) policy = policies.PyRandomPolicy(num_actions=env_spec.mdp.env_desc.num_actions) # collect data @@ -188,5 +188,8 @@ def estimation_experiment(env_specs: Sequence[Mapping[str, Any]]): now = int(time.time()) df_results = 
estimation_experiment(env_specs=ENV_SPECS) df_results.to_json( - os.path.join(str(pathlib.Path.home()), f"fs/daaf/exp/reward-recovery/{now}-report.json"), orient="records" + os.path.join( + str(pathlib.Path.home()), f"fs/daaf/exp/reward-recovery/{now}-report.json" + ), + orient="records", ) diff --git a/notebooks/daaf_analyses/reward_estimation/reward_recovery_lskv.py b/notebooks/daaf_analyses/reward_estimation/reward_recovery_lskv.py new file mode 100644 index 0000000..2283fd1 --- /dev/null +++ b/notebooks/daaf_analyses/reward_estimation/reward_recovery_lskv.py @@ -0,0 +1,241 @@ +import concurrent.futures +import logging +import os.path +import pathlib +import time +from typing import Any, Mapping, Optional, Sequence, Tuple + +import numpy as np +import pandas as pd +from daaf import math_ops, replay_mapper +from rlplg import core, envplay, envsuite +from rlplg.learning.tabular import policies + +ENV_SPECS = [ + {"name": "ABCSeq", "args": {"length": 2, "distance_penalty": False}}, + {"name": "ABCSeq", "args": {"length": 3, "distance_penalty": False}}, + {"name": "ABCSeq", "args": {"length": 7, "distance_penalty": False}}, + {"name": "FrozenLake-v1", "args": {"is_slippery": False, "map_name": "4x4"}}, + { + "name": "GridWorld", + "args": {"grid": "oooooooooooo\noooooooooooo\noooooooooooo\nsxxxxxxxxxxg"}, + }, + { + "name": "RedGreenSeq", + "args": { + "cure": ["red", "green", "wait", "green", "red", "red", "green", "wait"] + }, + }, + {"name": "IceWorld", "args": {"map_name": "4x4"}}, + {"name": "TowerOfHanoi", "args": {"num_disks": 4}}, +] + +BUFFER_MULT = 2**10 +EST_PLAIN = "plain" +EST_FACTOR_TS = "factor-ts" +EST_PREFILL_BUFFER = "prefill-buffer" + + +def estimation_experiment(env_specs: Sequence[Mapping[str, Any]]): + rows = [] + with concurrent.futures.ProcessPoolExecutor() as executor: + future_to_spec = {} + for spec in env_specs: + for method in (EST_PLAIN, EST_FACTOR_TS): + future_to_spec[executor.submit(run_fn, spec, method)] = (spec, method) + for 
future in concurrent.futures.as_completed(future_to_spec): + spec, method = future_to_spec[future] + output = future.result() + rows.append({"spec": spec, "method": method, "output": output}) + return pd.DataFrame(rows) + + +def run_fn(spec: Mapping[str, Any], method: str): + if method == EST_PLAIN: + factor_terminal_states = False + prefill_buffer = False + elif method == EST_FACTOR_TS: + factor_terminal_states = True + prefill_buffer = False + # elif method == EST_PREFILL_BUFFER: + # factor_terminal_states = False + # prefill_buffer = True + else: + raise ValueError(f"Unsupported method: {method}") + + del prefill_buffer + return estimate_reward(spec=spec, factor_terminal_states=factor_terminal_states) + + +def estimate_reward( + spec: Mapping[str, Any], + accuracy: float = 1e-8, + max_episodes: int = 7500, + logging_steps: int = 100, + factor_terminal_states: bool = False, +) -> Mapping[str, np.ndarray]: + def reshape_rr(array: np.ndarray, nrows: int, ncols: int) -> np.ndarray: + return np.reshape( + array, + newshape=( + nrows, + ncols, + ), + ) + + env_spec = envsuite.load(spec["name"], **spec["args"]) + # logging.info("Env: %s, %s", env_spec.name, env_spec.level) + init_rtable = np.zeros( + shape=(env_spec.mdp.env_desc.num_states, env_spec.mdp.env_desc.num_actions), + dtype=np.float64, + ) + terminal_states = core.infer_env_terminal_states(env_spec.mdp.transition) + mapper = replay_mapper.DaafLsqRewardAttributionMapper( + num_states=env_spec.mdp.env_desc.num_states, + num_actions=env_spec.mdp.env_desc.num_actions, + reward_period=4, + state_id_fn=env_spec.discretizer.state, + action_id_fn=env_spec.discretizer.action, + init_rtable=init_rtable, + buffer_size=env_spec.mdp.env_desc.num_states + * env_spec.mdp.env_desc.num_actions + * BUFFER_MULT, + terminal_states=terminal_states, + factor_terminal_states=factor_terminal_states, + ) + policy = policies.PyRandomPolicy(num_actions=env_spec.mdp.env_desc.num_actions) + # collect data + logging.info("Collecting data 
for %s", spec["name"]) + episode = 1 + while True: + traj = envplay.generate_episodes( + env_spec.environment, policy=policy, num_episodes=1 + ) + for _ in mapper.apply(traj): + pass + + if ( + not mapper._estimation_buffer.is_empty + and mapper._estimation_buffer.is_full_rank + ): + break + + if episode % logging_steps == 0: + logging.info("Data collection for %s at %d episodes", spec["name"], episode) + if episode >= max_episodes: + break + episode += 1 + + # estimate rewards + yhat_lstsq: Optional[np.ndarray] = None + yhat_ols_em: Optional[np.ndarray] = None + if mapper._estimation_buffer.is_full_rank: + logging.info( + "Estimating rewards for %s, after %d episodes. Matrix shape: %s", + spec["name"], + episode, + mapper._estimation_buffer.matrix.shape, + ) + yhat_ols_em, iters = ols_em_reward_estimation( + obs_matrix=mapper._estimation_buffer.matrix, + agg_rewards=mapper._estimation_buffer.rhs, + accuracy=accuracy, + ) + yhat_ols_em = reshape_rr( + yhat_ols_em, + env_spec.mdp.env_desc.num_states + - (len(terminal_states) if factor_terminal_states else 0), + env_spec.mdp.env_desc.num_actions, + ) + logging.info("OLS ran in %d iterations for %s", iters, spec["name"]) + yhat_lstsq = lstsq_reward_estimation( + obs_matrix=mapper._estimation_buffer.matrix, + agg_rewards=mapper._estimation_buffer.rhs, + ) + yhat_lstsq = reshape_rr( + yhat_lstsq, + env_spec.mdp.env_desc.num_states + - (len(terminal_states) if factor_terminal_states else 0), + env_spec.mdp.env_desc.num_actions, + ) + else: + logging.info( + "Matrix is ill defined. 
Skipping reward estimation for %s: %s", + spec["name"], + spec["args"], + ) + return { + "least": yhat_lstsq, + "ols-em": yhat_ols_em, + "matrix": mapper._estimation_buffer.matrix, + "rhs": mapper._estimation_buffer.rhs, + } + + +def lstsq_reward_estimation( + obs_matrix: np.ndarray, agg_rewards: np.ndarray +) -> np.ndarray: + return math_ops.solve_least_squares( + matrix=obs_matrix, + rhs=agg_rewards, + ) + + +def ols_em_reward_estimation( + obs_matrix: np.ndarray, + agg_rewards: np.ndarray, + accuracy: float = 1e-8, + max_iters: int = 1_000_000, + stop_check_interval: int = 1000, +) -> Tuple[np.ndarray, int]: + iteration = 1 + yhat_rewards = np.random.rand(obs_matrix.shape[1]) + # multiply the cumulative reward by visits of each state action + # dim: (num obs, num states x num actions) + nomin = np.expand_dims(agg_rewards, axis=-1) * obs_matrix + qs = np.sum(obs_matrix, axis=0) + while True: + delta = np.zeros_like(yhat_rewards) + # multiply reward guess by row and sum each row's entry + # dim: num obs + denom = np.sum(yhat_rewards * obs_matrix, axis=1) + factor = np.sum(nomin / np.expand_dims(denom, 1), axis=0) + new_yhat_rewards = yhat_rewards * (factor / qs) + delta = np.maximum(delta, np.abs(yhat_rewards - new_yhat_rewards)) + if ( + iteration % stop_check_interval == 0 + and np.sum(np.isnan(new_yhat_rewards)) > 0 + ): + logging.info( + "Stopping at iteration %d/%d. `nan` values: %s", + iteration, + max_iters, + new_yhat_rewards, + ) + break + if np.alltrue(delta < accuracy) or iteration >= max_iters: + logging.info( + "Stopping at iteration %d/%d. 
Max error: %f", + iteration, + max_iters, + np.max(delta), + ) + break + yhat_rewards = new_yhat_rewards + iteration += 1 + return yhat_rewards, iteration + + +def main(): + now = int(time.time()) + df_results = estimation_experiment(env_specs=ENV_SPECS) + df_results.to_json( + os.path.join( + str(pathlib.Path.home()), f"fs/daaf/exp/reward-recovery/{now}-report.json" + ), + orient="records", + ) + + +if __name__ == "__main__": + main() diff --git a/ray-env-requirements.txt b/ray-env-requirements.txt index 8a90333..be70f31 100644 --- a/ray-env-requirements.txt +++ b/ray-env-requirements.txt @@ -1,9 +1,5 @@ -# -# This file is autogenerated by pip-compile with Python 3.9 -# by the following command: -# -# pip-compile --no-emit-find-links --no-emit-index-url --no-emit-options --unsafe-package=ray ray-env-requirements.in -# +# This file was autogenerated by uv via the following command: +# uv pip compile --no-emit-index-url --no-emit-find-links ray-env-requirements.in -o ray-env-requirements.txt --unsafe-package ray absl-py==2.0.0 # via # -r requirements.txt @@ -13,11 +9,8 @@ aiohttp==3.8.4 # via # -r requirements.txt # aiohttp-cors - # ray aiohttp-cors==0.7.0 - # via - # -r requirements.txt - # ray + # via -r requirements.txt aiosignal==1.3.1 # via # -r requirements.txt @@ -36,10 +29,6 @@ attrs==22.2.0 # -r requirements.txt # aiohttp # jsonschema -blessed==1.20.0 - # via - # -r requirements.txt - # gpustat cachetools==5.3.0 # via # -r requirements.txt @@ -53,6 +42,10 @@ charset-normalizer==3.1.0 # -r requirements.txt # aiohttp # requests +clarabel==0.9.0 + # via + # -r requirements.txt + # cvxpy click==8.0.4 # via # -r requirements.txt @@ -62,13 +55,17 @@ cloudpickle==2.2.1 # -r requirements.txt # gymnasium colorful==0.5.5 - # via - # -r requirements.txt - # ray + # via -r requirements.txt +cvxpy==1.5.3 + # via -r requirements.txt distlib==0.3.6 # via # -r requirements.txt # virtualenv +ecos==2.0.14 + # via + # -r requirements.txt + # cvxpy farama-notifications==0.0.4 
# via # -r requirements.txt @@ -89,9 +86,7 @@ frozenlist==1.3.3 # aiosignal # ray fsspec==2024.2.0 - # via - # -r requirements.txt - # ray + # via -r requirements.txt gast==0.4.0 # via # -r requirements.txt @@ -118,14 +113,9 @@ googleapis-common-protos==1.58.0 # via # -r requirements.txt # google-api-core -gpustat==1.0.0 - # via - # -r requirements.txt - # ray grpcio==1.51.3 # via # -r requirements.txt - # ray # tensorboard # tensorflow gymnasium==0.28.1 @@ -150,6 +140,10 @@ jax-jumpy==1.0.0 # via # -r requirements.txt # gymnasium +jinja2==3.1.4 + # via + # -r requirements.txt + # memray jsonschema==4.17.3 # via # -r requirements.txt @@ -162,14 +156,35 @@ libclang==16.0.6 # via # -r requirements.txt # tensorflow +linkify-it-py==2.0.3 + # via + # -r requirements.txt + # markdown-it-py markdown==3.5.1 # via # -r requirements.txt # tensorboard +markdown-it-py==3.0.0 + # via + # -r requirements.txt + # mdit-py-plugins + # rich + # textual markupsafe==2.1.3 # via # -r requirements.txt + # jinja2 # werkzeug +mdit-py-plugins==0.4.2 + # via + # -r requirements.txt + # markdown-it-py +mdurl==0.1.2 + # via + # -r requirements.txt + # markdown-it-py +memray==1.14.0 + # via -r requirements.txt msgpack==1.0.5 # via # -r requirements.txt @@ -182,29 +197,28 @@ multidict==6.0.4 numpy==1.23.5 # via # -r requirements.txt + # clarabel + # cvxpy + # ecos # gymnasium # h5py # jax-jumpy # opt-einsum + # osqp # pandas # pyarrow - # ray + # qdldl # rlplg # scipy + # scs # tensorboard # tensorflow -nvidia-ml-py==11.495.46 - # via - # -r requirements.txt - # gpustat oauthlib==3.2.2 # via # -r requirements.txt # requests-oauthlib opencensus==0.11.2 - # via - # -r requirements.txt - # ray + # via -r requirements.txt opencensus-context==0.1.3 # via # -r requirements.txt @@ -213,15 +227,17 @@ opt-einsum==3.3.0 # via # -r requirements.txt # tensorflow +osqp==0.6.7.post1 + # via + # -r requirements.txt + # cvxpy packaging==23.2 # via # -r requirements.txt # ray # tensorflow pandas==2.0.3 - # via 
- # -r requirements.txt - # ray + # via -r requirements.txt pillow==9.4.0 # via # -r requirements.txt @@ -231,9 +247,7 @@ platformdirs==3.1.1 # -r requirements.txt # virtualenv prometheus-client==0.13.1 - # via - # -r requirements.txt - # ray + # via -r requirements.txt protobuf==4.25.1 # via # -r requirements.txt @@ -242,18 +256,10 @@ protobuf==4.25.1 # ray # tensorboard # tensorflow -psutil==5.9.4 - # via - # -r requirements.txt - # gpustat py-spy==0.3.14 - # via - # -r requirements.txt - # ray + # via -r requirements.txt pyarrow==15.0.0 - # via - # -r requirements.txt - # ray + # via -r requirements.txt pyasn1==0.4.8 # via # -r requirements.txt @@ -264,9 +270,11 @@ pyasn1-modules==0.2.8 # -r requirements.txt # google-auth pydantic==1.10.6 + # via -r requirements.txt +pygments==2.18.0 # via # -r requirements.txt - # ray + # rich pyrsistent==0.19.3 # via # -r requirements.txt @@ -283,6 +291,10 @@ pyyaml==6.0 # via # -r requirements.txt # ray +qdldl==0.1.7.post4 + # via + # -r requirements.txt + # osqp requests==2.28.2 # via # -r requirements.txt @@ -294,7 +306,12 @@ requests-oauthlib==1.3.1 # via # -r requirements.txt # google-auth-oauthlib -rlplg @ git+https://github.com/guidj/rlplg.git@v0.19.10 +rich==13.8.1 + # via + # -r requirements.txt + # memray + # textual +rlplg @ git+https://github.com/guidj/rlplg.git@93d819cab2eeefdd9bbd0ec7acb5d1dd0e5d1bed # via -r requirements.txt rsa==4.9 # via @@ -303,25 +320,32 @@ rsa==4.9 scipy==1.10.1 # via # -r requirements.txt + # clarabel + # cvxpy + # ecos + # osqp + # qdldl # rlplg -setuptools==69.0.3 + # scs +scs==3.2.7 + # via + # -r requirements.txt + # cvxpy +setuptools==74.1.2 # via + # -r requirements.txt # tensorboard # tensorflow six==1.16.0 # via # -r requirements.txt # astunparse - # blessed # google-auth # google-pasta - # gpustat # python-dateutil # tensorflow smart-open==6.3.0 - # via - # -r requirements.txt - # ray + # via -r requirements.txt tblib==1.7.0 # via -r ray-env-requirements.in tensorboard==2.13.0 @@ 
-346,28 +370,31 @@ termcolor==2.3.0 # via # -r requirements.txt # tensorflow +textual==0.78.0 + # via + # -r requirements.txt + # memray typing-extensions==4.5.0 # via # -r requirements.txt # gymnasium # pydantic # tensorflow + # textual tzdata==2024.1 # via # -r requirements.txt # pandas +uc-micro-py==1.0.3 + # via + # -r requirements.txt + # linkify-it-py urllib3==1.26.15 # via # -r requirements.txt # requests virtualenv==20.21.0 - # via - # -r requirements.txt - # ray -wcwidth==0.2.6 - # via - # -r requirements.txt - # blessed + # via -r requirements.txt werkzeug==3.0.1 # via # -r requirements.txt @@ -390,5 +417,5 @@ zipp==3.15.0 # -r requirements.txt # importlib-metadata -# The following packages are considered to be unsafe in a requirements file: +# The following packages were excluded from the output: # ray diff --git a/rendering-requirements.txt b/rendering-requirements.txt index 44f2771..60aedeb 100644 --- a/rendering-requirements.txt +++ b/rendering-requirements.txt @@ -1,8 +1,4 @@ -# -# This file is autogenerated by pip-compile with Python 3.9 -# by the following command: -# -# pip-compile --no-emit-find-links --no-emit-index-url --no-emit-options rendering-requirements.in -# +# This file was autogenerated by uv via the following command: +# uv pip compile --no-emit-index-url --no-emit-find-links rendering-requirements.in -o rendering-requirements.txt pyglet==1.5.21 # via -r rendering-requirements.in diff --git a/requirements.in b/requirements.in index 0623678..343e5fb 100644 --- a/requirements.in +++ b/requirements.in @@ -1,6 +1,7 @@ rlplg @ git+https://github.com/guidj/rlplg.git@v0.19.10 # dashboard + cluster -ray[default]==2.9.3 -ray[data]==2.9.3 +ray[default]==2.23.0 +ray[data]==2.23.0 numpy==1.23.5 tensorflow>=2.11.1 +cvxpy==1.5.3 diff --git a/requirements.txt b/requirements.txt index 9a66d8a..88d8735 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,5 @@ -# -# This file is autogenerated by pip-compile with Python 3.9 -# by the 
following command: -# -# pip-compile --no-emit-find-links --no-emit-index-url --no-emit-options requirements.in -# +# This file was autogenerated by uv via the following command: +# uv pip compile --no-emit-index-url --no-emit-find-links requirements.in -o requirements.txt absl-py==2.0.0 # via # tensorboard @@ -26,8 +22,6 @@ attrs==22.2.0 # via # aiohttp # jsonschema -blessed==1.20.0 - # via gpustat cachetools==5.3.0 # via google-auth certifi==2022.12.7 @@ -36,14 +30,20 @@ charset-normalizer==3.1.0 # via # aiohttp # requests +clarabel==0.9.0 + # via cvxpy click==8.0.4 # via ray cloudpickle==2.2.1 # via gymnasium colorful==0.5.5 # via ray +cvxpy==1.5.3 + # via -r requirements.in distlib==0.3.6 # via virtualenv +ecos==2.0.14 + # via cvxpy farama-notifications==0.0.4 # via gymnasium filelock==3.9.1 @@ -74,8 +74,6 @@ google-pasta==0.2.0 # via tensorflow googleapis-common-protos==1.58.0 # via google-api-core -gpustat==1.0.0 - # via ray grpcio==1.51.3 # via # ray @@ -95,16 +93,33 @@ importlib-metadata==6.0.0 # markdown jax-jumpy==1.0.0 # via gymnasium +jinja2==3.1.4 + # via memray jsonschema==4.17.3 # via ray keras==2.13.1 # via tensorflow libclang==16.0.6 # via tensorflow +linkify-it-py==2.0.3 + # via markdown-it-py markdown==3.5.1 # via tensorboard +markdown-it-py==3.0.0 + # via + # mdit-py-plugins + # rich + # textual markupsafe==2.1.3 - # via werkzeug + # via + # jinja2 + # werkzeug +mdit-py-plugins==0.4.2 + # via markdown-it-py +mdurl==0.1.2 + # via markdown-it-py +memray==1.14.0 + # via ray msgpack==1.0.5 # via ray multidict==6.0.4 @@ -114,19 +129,23 @@ multidict==6.0.4 numpy==1.23.5 # via # -r requirements.in + # clarabel + # cvxpy + # ecos # gymnasium # h5py # jax-jumpy # opt-einsum + # osqp # pandas # pyarrow + # qdldl # ray # rlplg # scipy + # scs # tensorboard # tensorflow -nvidia-ml-py==11.495.46 - # via gpustat oauthlib==3.2.2 # via requests-oauthlib opencensus==0.11.2 @@ -135,6 +154,8 @@ opencensus-context==0.1.3 # via opencensus opt-einsum==3.3.0 # via 
tensorflow +osqp==0.6.7.post1 + # via cvxpy packaging==23.2 # via # ray @@ -154,8 +175,6 @@ protobuf==4.25.1 # ray # tensorboard # tensorflow -psutil==5.9.4 - # via gpustat py-spy==0.3.14 # via ray pyarrow==15.0.0 @@ -168,6 +187,8 @@ pyasn1-modules==0.2.8 # via google-auth pydantic==1.10.6 # via ray +pygments==2.18.0 + # via rich pyrsistent==0.19.3 # via jsonschema python-dateutil==2.8.2 @@ -176,7 +197,9 @@ pytz==2024.1 # via pandas pyyaml==6.0 # via ray -ray[data,default]==2.9.3 +qdldl==0.1.7.post4 + # via osqp +ray==2.23.0 # via -r requirements.in requests==2.28.2 # via @@ -186,19 +209,34 @@ requests==2.28.2 # tensorboard requests-oauthlib==1.3.1 # via google-auth-oauthlib -rlplg @ git+https://github.com/guidj/rlplg.git@v0.19.10 +rich==13.8.1 + # via + # memray + # textual +rlplg @ git+https://github.com/guidj/rlplg.git@93d819cab2eeefdd9bbd0ec7acb5d1dd0e5d1bed # via -r requirements.in rsa==4.9 # via google-auth scipy==1.10.1 - # via rlplg + # via + # clarabel + # cvxpy + # ecos + # osqp + # qdldl + # rlplg + # scs +scs==3.2.7 + # via cvxpy +setuptools==74.1.2 + # via + # tensorboard + # tensorflow six==1.16.0 # via # astunparse - # blessed # google-auth # google-pasta - # gpustat # python-dateutil # tensorflow smart-open==6.3.0 @@ -215,19 +253,22 @@ tensorflow-io-gcs-filesystem==0.34.0 # via tensorflow termcolor==2.3.0 # via tensorflow +textual==0.78.0 + # via memray typing-extensions==4.5.0 # via # gymnasium # pydantic # tensorflow + # textual tzdata==2024.1 # via pandas +uc-micro-py==1.0.3 + # via linkify-it-py urllib3==1.26.15 # via requests virtualenv==20.21.0 # via ray -wcwidth==0.2.6 - # via blessed werkzeug==3.0.1 # via tensorboard wheel==0.42.0 @@ -240,6 +281,3 @@ yarl==1.8.2 # via aiohttp zipp==3.15.0 # via importlib-metadata - -# The following packages are considered to be unsafe in a requirements file: -# setuptools diff --git a/sbin/local/reward-estjob.sh b/sbin/local/reward-estjob.sh index 3544710..6639f97 100755 --- a/sbin/local/reward-estjob.sh +++ 
b/sbin/local/reward-estjob.sh @@ -7,6 +7,6 @@ PARENT_DIR=$DIR/../.. TIMESTAMP=`date +%s` python $PARENT_DIR/src/daaf/rewardest/estjob.py \ --num-runs=10 \ ---max-episodes=7500 \ +--max-episodes=2500 \ --output-dir=$HOME/fs/daaf/exp/reward-estjob/logs/$TIMESTAMP \ --log-episode-frequency=10 diff --git a/setup.py b/setup.py index c70668a..8508070 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,7 @@ def read_requirements(filename): "Operating System :: POSIX :: Linux", "Operating System :: MacOS :: MacOS X", "Programming Language :: Python", - "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", "Programming Language :: Python :: Implementation :: CPython", ] META_FILE = read(META_PATH) diff --git a/src/daaf/controlexps/control.py b/src/daaf/controlexps/control.py index 14d8dd0..b58945e 100644 --- a/src/daaf/controlexps/control.py +++ b/src/daaf/controlexps/control.py @@ -20,7 +20,7 @@ from daaf.controlexps import methods -def run_fn(experiment_task: expconfig.ExperimentTask): +def run_fn(experiment_run: expconfig.ExperimentRun): """ Entry point running on-policy evaluation for DAAF. 
@@ -29,15 +29,15 @@ def run_fn(experiment_task: expconfig.ExperimentTask): """ # init env and agent env_spec = task.create_env_spec( - problem=experiment_task.experiment.env_config.name, - env_args=experiment_task.experiment.env_config.args, + problem=experiment_run.experiment.env_config.name, + env_args=experiment_run.experiment.env_config.args, ) traj_mappers = task.create_trajectory_mappers( env_spec=env_spec, - reward_period=experiment_task.experiment.daaf_config.reward_period, - traj_mapping_method=experiment_task.experiment.daaf_config.traj_mapping_method, + reward_period=experiment_run.experiment.daaf_config.reward_period, + traj_mapping_method=experiment_run.experiment.daaf_config.traj_mapping_method, buffer_size_or_multiplier=(None, None), - drop_truncated_feedback_episodes=experiment_task.experiment.daaf_config.drop_truncated_feedback_episodes, + drop_truncated_feedback_episodes=experiment_run.experiment.daaf_config.drop_truncated_feedback_episodes, ) # Collect returns on underlying MDP # before other mappers change it. 
@@ -46,30 +46,30 @@ def run_fn(experiment_task: expconfig.ExperimentTask): logging.debug("Starting DAAF Control Experiments") results = policy_control( env_spec=env_spec, - daaf_config=experiment_task.experiment.daaf_config, - num_episodes=experiment_task.run_config.num_episodes, - learnign_args=experiment_task.experiment.learning_args, + daaf_config=experiment_run.experiment.daaf_config, + num_episodes=experiment_run.run_config.num_episodes, + learnign_args=experiment_run.experiment.learning_args, generate_steps_fn=task.create_generate_episode_fn(mappers=traj_mappers), ) env_info: Mapping[str, Any] = { "env": { "name": env_spec.name, "level": env_spec.level, - "args": json.dumps(experiment_task.experiment.env_config.args), + "args": json.dumps(experiment_run.experiment.env_config.args), }, } with utils.ExperimentLogger( - log_dir=experiment_task.run_config.output_dir, - exp_id=experiment_task.exp_id, - run_id=experiment_task.run_id, + log_dir=experiment_run.run_config.output_dir, + exp_id=experiment_run.exp_id, + run_id=experiment_run.run_id, params={ **env_info, **utils.json_from_dict( - dataclasses.asdict(experiment_task.experiment.daaf_config), + dataclasses.asdict(experiment_run.experiment.daaf_config), dict_encode_level=0, ), - **dataclasses.asdict(experiment_task.experiment.learning_args), - **experiment_task.context, + **dataclasses.asdict(experiment_run.experiment.learning_args), + **experiment_run.context, }, ) as exp_logger: state_values: Optional[np.ndarray] = None @@ -78,7 +78,7 @@ def run_fn(experiment_task: expconfig.ExperimentTask): for episode, snapshot in enumerate(results): state_values = np.max(snapshot.action_values, axis=1) state_actions = np.argmax(snapshot.action_values, axis=1) - if episode % experiment_task.run_config.log_episode_frequency == 0: + if episode % experiment_run.run_config.log_episode_frequency == 0: mean_returns = np.mean(returns_collector.traj_returns) exp_logger.log( episode=episode, @@ -95,13 +95,13 @@ def 
run_fn(experiment_task: expconfig.ExperimentTask): logging.debug( "\nEstimated values run %d of %s:\n%s", - experiment_task.run_id, - experiment_task.exp_id, + experiment_run.run_id, + experiment_run.exp_id, state_values, ) except Exception as err: raise RuntimeError( - f"Task {experiment_task.exp_id}, run {experiment_task.run_id} failed" + f"Task {experiment_run.exp_id}, run {experiment_run.run_id} failed" ) from err env_spec.environment.close() diff --git a/src/daaf/controlexps/controljob.py b/src/daaf/controlexps/controljob.py index 290cab5..2981ad8 100644 --- a/src/daaf/controlexps/controljob.py +++ b/src/daaf/controlexps/controljob.py @@ -82,7 +82,7 @@ def create_tasks( output_dir: str, task_prefix: str, log_episode_frequency: int, -) -> Sequence[Tuple[ray.ObjectRef, expconfig.ExperimentTask]]: +) -> Sequence[Tuple[ray.ObjectRef, expconfig.ExperimentRun]]: """ Runs numerical experiments on policy evaluation. """ @@ -97,7 +97,7 @@ def create_tasks( ) ) experiments_and_context = add_experiment_context(experiments, assets_dir=assets_dir) - experiment_tasks = tuple( + experiment_runs = tuple( expconfig.generate_tasks_from_experiments_context_and_run_config( run_config=expconfig.RunConfig( num_episodes=num_episodes, @@ -110,12 +110,12 @@ def create_tasks( ) ) # shuffle tasks to balance workload - experiment_tasks = random.sample( - experiment_tasks, - len(experiment_tasks), # type: ignore + experiment_runs = random.sample( + experiment_runs, + len(experiment_runs), # type: ignore ) experiment_batches = utils.bundle( - experiment_tasks, bundle_size=constants.DEFAULT_BATCH_SIZE + experiment_runs, bundle_size=constants.DEFAULT_BATCH_SIZE ) logging.info( "Parsed %d DAAF configs and %d environments into %d tasks", @@ -171,7 +171,7 @@ def add_experiment_context( @ray.remote def run_experiments( - experiments_batch: Sequence[expconfig.ExperimentTask], + experiments_batch: Sequence[expconfig.ExperimentRun], ) -> Sequence[str]: """ Run experiments. 
diff --git a/src/daaf/controlexps/results_agg_pipeline.py b/src/daaf/controlexps/results_agg_pipeline.py index bc377a0..03b1036 100644 --- a/src/daaf/controlexps/results_agg_pipeline.py +++ b/src/daaf/controlexps/results_agg_pipeline.py @@ -139,11 +139,7 @@ def parse_experiment_logs(paths: Sequence[str]) -> ray.data.Dataset: """ logs_files = [os.path.join(path, "experiment-logs.jsonl") for path in paths] ds_logs = ray.data.read_json( - logs_files, - include_paths=True, - partition_filter=ray.data.datasource.FileExtensionFilter( - file_extensions=["jsonl"] - ), + logs_files, include_paths=True, file_extensions=["jsonl"] ) return ds_logs diff --git a/src/daaf/estimator_metrics.py b/src/daaf/estimator_metrics.py index 009645d..8212b79 100644 --- a/src/daaf/estimator_metrics.py +++ b/src/daaf/estimator_metrics.py @@ -12,12 +12,12 @@ def rmse(v_pred: np.ndarray, v_true: np.ndarray, axis: int): v_pred: An array of shape [b, k] v_true: An array of shape [b, k] """ - if v_pred.shape != v_true.shape: + if np.shape(v_pred) != np.shape(v_true): raise ValueError( - f"Tensors have different shapes: {v_pred.shape} != {v_true.shape}" + f"Tensors have different shapes: {np.shape(v_pred)} != {np.shape(v_true)}" ) return np.sqrt( - np.sum(np.power(v_pred - v_true, 2.0), axis=axis) / v_pred.shape[axis] + np.sum(np.power(v_pred - v_true, 2.0), axis=axis) / np.shape(v_pred)[axis] ) @@ -27,9 +27,9 @@ def mean_absolute_error(v_pred: np.ndarray, v_true: np.ndarray, axis: int): v_pred: An array of shape [b, k] v_true: An array of shape [b, k] """ - if v_pred.shape != v_true.shape: + if np.shape(v_pred) != np.shape(v_true): raise ValueError( - f"Tensors have different shapes: {v_pred.shape} != {v_true.shape}" + f"Tensors have different shapes: {np.shape(v_pred)} != {np.shape(v_true)}" ) delta = np.abs(v_pred - v_true) return np.mean(delta, axis=axis) @@ -53,16 +53,18 @@ def cosine_distance(v_pred: np.ndarray, v_true: np.ndarray): v_pred: An array of shape [b, k] v_true: An array of shape 
[b, k] """ - if v_pred.shape != v_true.shape: + if np.shape(v_pred) != np.shape(v_true): raise ValueError( - f"Tensors have different shapes: {v_pred.shape} != {v_true.shape}" + f"Tensors have different shapes: {np.shape(v_pred)} != {np.shape(v_true)}" ) - if len(v_pred.shape) != 2: - raise ValueError(f"Tensors are not 2-dim: {v_pred.shape}, {v_true.shape}") + if len(np.shape(v_pred)) != 2: + raise ValueError( + f"Tensors are not 2-dim: {np.shape(v_pred)}, {np.shape(v_true)}" + ) cosines = [] - for row in range(v_pred.shape[0]): + for row in range(np.shape(v_pred)[0]): cosines.append(distance.cosine(v_pred[row], v_true[row])) return np.array(cosines, dtype=v_true.dtype) @@ -73,15 +75,17 @@ def dotproduct(v_pred: np.ndarray, v_true: np.ndarray): v_pred: An array of shape [b, k] v_true: An array of shape [b, k] """ - if v_pred.shape != v_true.shape: + if np.shape(v_pred) != np.shape(v_true): raise ValueError( - f"Tensors have different shapes: {v_pred.shape} != {v_true.shape}" + f"Tensors have different shapes: {np.shape(v_pred)} != {np.shape(v_true)}" ) - if len(v_pred.shape) != 2: - raise ValueError(f"Tensors are not 2-dim: {v_pred.shape}, {v_true.shape}") + if len(np.shape(v_pred)) != 2: + raise ValueError( + f"Tensors are not 2-dim: {np.shape(v_pred)}, {np.shape(v_true)}" + ) dps = [] - for row in range(v_pred.shape[0]): + for row in range(np.shape(v_pred)[0]): dps.append(np.dot(v_pred[row], v_true[row])) return np.array(dps, dtype=v_true.dtype) diff --git a/src/daaf/evalexps/evaljob.py b/src/daaf/evalexps/evaljob.py index 5d246d7..27c413e 100644 --- a/src/daaf/evalexps/evaljob.py +++ b/src/daaf/evalexps/evaljob.py @@ -84,7 +84,7 @@ def create_tasks( output_dir: str, task_prefix: str, log_episode_frequency: int, -) -> Sequence[Tuple[ray.ObjectRef, expconfig.ExperimentTask]]: +) -> Sequence[Tuple[ray.ObjectRef, expconfig.ExperimentRun]]: """ Runs numerical experiments on policy evaluation. 
""" @@ -99,7 +99,7 @@ def create_tasks( ) ) experiments_and_context = add_experiment_context(experiments, assets_dir=assets_dir) - experiment_tasks = tuple( + experiment_runs = tuple( expconfig.generate_tasks_from_experiments_context_and_run_config( run_config=expconfig.RunConfig( num_episodes=num_episodes, @@ -112,10 +112,10 @@ def create_tasks( ) ) # shuffle tasks to balance workload - experiment_tasks = random.sample(experiment_tasks, len(experiment_tasks)) # type: ignore + experiment_runs = random.sample(experiment_runs, len(experiment_runs)) # type: ignore # bundle tasks experiment_batches = utils.bundle( - experiment_tasks, bundle_size=constants.DEFAULT_BATCH_SIZE + experiment_runs, bundle_size=constants.DEFAULT_BATCH_SIZE ) logging.info( "Parsed %d DAAF configs and %d environments into %d tasks", @@ -171,20 +171,20 @@ def add_experiment_context( @ray.remote def run_experiments( - experiments_batch: Sequence[expconfig.ExperimentTask], + experiments_batch: Sequence[expconfig.ExperimentRun], ) -> Sequence[str]: """ Runs experiments. """ ids: List[str] = [] - for experiment_task in experiments_batch: - task_id = f"{experiment_task.exp_id}/{experiment_task.run_id}" + for experiment_run in experiments_batch: + task_id = f"{experiment_run.exp_id}/{experiment_run.run_id}" logging.debug( "Experiment %s starting: %s", task_id, - experiment_task, + experiment_run, ) - evaluation.run_fn(experiment_task) + evaluation.run_fn(experiment_run) ids.append(task_id) logging.debug("Experiment %s finished", task_id) return ids diff --git a/src/daaf/evalexps/evaluation.py b/src/daaf/evalexps/evaluation.py index 161d99d..fe4cce9 100644 --- a/src/daaf/evalexps/evaluation.py +++ b/src/daaf/evalexps/evaluation.py @@ -17,7 +17,7 @@ from daaf.evalexps import methods -def run_fn(experiment_task: expconfig.ExperimentTask): +def run_fn(experiment_run: expconfig.ExperimentRun): """ Entry point running on-policy evaluation for DAAF. 
@@ -26,50 +26,50 @@ def run_fn(experiment_task: expconfig.ExperimentTask): """ # init env and agent env_spec = task.create_env_spec( - problem=experiment_task.experiment.env_config.name, - env_args=experiment_task.experiment.env_config.args, + problem=experiment_run.experiment.env_config.name, + env_args=experiment_run.experiment.env_config.args, ) traj_mappers = task.create_trajectory_mappers( env_spec=env_spec, - reward_period=experiment_task.experiment.daaf_config.reward_period, - traj_mapping_method=experiment_task.experiment.daaf_config.traj_mapping_method, + reward_period=experiment_run.experiment.daaf_config.reward_period, + traj_mapping_method=experiment_run.experiment.daaf_config.traj_mapping_method, buffer_size_or_multiplier=(None, None), - drop_truncated_feedback_episodes=experiment_task.experiment.daaf_config.drop_truncated_feedback_episodes, + drop_truncated_feedback_episodes=experiment_run.experiment.daaf_config.drop_truncated_feedback_episodes, ) returns_collector = task.returns_collection_mapper() traj_mappers = tuple([returns_collector] + list(traj_mappers)) # Policy Eval with DAAF logging.info("Starting DAAF Evaluation Experiments") policy = create_eval_policy( - env_spec=env_spec, daaf_config=experiment_task.experiment.daaf_config + env_spec=env_spec, daaf_config=experiment_run.experiment.daaf_config ) results = evaluate_policy( policy=policy, env_spec=env_spec, - daaf_config=experiment_task.experiment.daaf_config, - num_episodes=experiment_task.run_config.num_episodes, - algorithm=experiment_task.experiment.daaf_config.algorithm, - learnign_args=experiment_task.experiment.learning_args, + daaf_config=experiment_run.experiment.daaf_config, + num_episodes=experiment_run.run_config.num_episodes, + algorithm=experiment_run.experiment.daaf_config.algorithm, + learnign_args=experiment_run.experiment.learning_args, generate_steps_fn=task.create_generate_episode_fn(mappers=traj_mappers), ) env_info: Mapping[str, Any] = { "env": { "name": env_spec.name, 
"level": env_spec.level, - "args": json.dumps(experiment_task.experiment.env_config.args), + "args": json.dumps(experiment_run.experiment.env_config.args), }, } with utils.ExperimentLogger( - log_dir=experiment_task.run_config.output_dir, - exp_id=experiment_task.exp_id, - run_id=experiment_task.run_id, + log_dir=experiment_run.run_config.output_dir, + exp_id=experiment_run.exp_id, + run_id=experiment_run.run_id, params={ **utils.json_from_dict( - dataclasses.asdict(experiment_task.experiment.daaf_config), + dataclasses.asdict(experiment_run.experiment.daaf_config), dict_encode_level=0, ), - **dataclasses.asdict(experiment_task.experiment.learning_args), - **experiment_task.context, + **dataclasses.asdict(experiment_run.experiment.learning_args), + **experiment_run.context, **env_info, }, ) as exp_logger: @@ -77,7 +77,7 @@ def run_fn(experiment_task: expconfig.ExperimentTask): try: for episode, snapshot in enumerate(results): state_values = snapshot.values - if episode % experiment_task.run_config.log_episode_frequency == 0: + if episode % experiment_run.run_config.log_episode_frequency == 0: mean_returns = np.mean(returns_collector.traj_returns) exp_logger.log( episode=episode, @@ -90,13 +90,13 @@ def run_fn(experiment_task: expconfig.ExperimentTask): logging.debug( "\nEstimated values run %d of %s:\n%s", - experiment_task.run_id, - experiment_task.exp_id, + experiment_run.run_id, + experiment_run.exp_id, state_values, ) except Exception as err: raise RuntimeError( - f"Task {experiment_task.exp_id}, run {experiment_task.run_id} failed" + f"Task {experiment_run.exp_id}, run {experiment_run.run_id} failed" ) from err env_spec.environment.close() @@ -161,7 +161,7 @@ def evaluate_policy( schedule=task.constant_learning_rate, ), gamma=learnign_args.discount_factor, - nstep=daaf_config.reward_period, + nstep=daaf_config.algorithm_args["nstep"], state_id_fn=env_spec.discretizer.state, initial_values=initial_state_values, generate_episode=generate_steps_fn, @@ -176,7 
+176,7 @@ def evaluate_policy( schedule=task.constant_learning_rate, ), gamma=learnign_args.discount_factor, - nstep=daaf_config.reward_period, + nstep=daaf_config.algorithm_args["nstep"], state_id_fn=env_spec.discretizer.state, initial_values=initial_state_values, generate_episode=generate_steps_fn, diff --git a/src/daaf/evalexps/results_agg_pipeline.py b/src/daaf/evalexps/results_agg_pipeline.py index b05489d..28991f9 100644 --- a/src/daaf/evalexps/results_agg_pipeline.py +++ b/src/daaf/evalexps/results_agg_pipeline.py @@ -141,11 +141,7 @@ def parse_experiment_logs(paths: Sequence[str]) -> ray.data.Dataset: """ logs_files = [os.path.join(path, "experiment-logs.jsonl") for path in paths] ds_logs = ray.data.read_json( - logs_files, - include_paths=True, - partition_filter=ray.data.datasource.FileExtensionFilter( - file_extensions=["jsonl"] - ), + logs_files, include_paths=True, file_extensions=["jsonl"] ) return ds_logs @@ -240,7 +236,7 @@ def calc_state_metrics(y_preds, y_true, axis): } def apply(row): - y_preds = row["state_values"] + y_preds = np.array(row["state_values"]) y_true = np.tile(row["meta"]["dyna_prog_state_values"], reps=(len(y_preds), 1)) over_runs_then_states = calc_state_metrics( y_preds=y_preds, y_true=y_true, axis=0 diff --git a/src/daaf/expconfig.py b/src/daaf/expconfig.py index 6e39f4f..5046880 100644 --- a/src/daaf/expconfig.py +++ b/src/daaf/expconfig.py @@ -72,7 +72,7 @@ class Experiment: @dataclasses.dataclass(frozen=True) -class ExperimentTask: +class ExperimentRun: """ A single experiment task. """ @@ -156,7 +156,7 @@ def generate_tasks_from_experiments_context_and_run_config( experiments_and_context: Sequence[Tuple[Experiment, Mapping[str, Any]]], num_runs: int, task_prefix: str, -) -> Iterator[ExperimentTask]: +) -> Iterator[ExperimentRun]: """ Given a sequence of experiments, expands them to tasks. 
@@ -179,7 +179,7 @@ def generate_tasks_from_experiments_context_and_run_config( ] ) for idx in range(num_runs): - yield ExperimentTask( + yield ExperimentRun( exp_id=exp_id, run_id=idx, experiment=experiment, diff --git a/src/daaf/replay_mapper.py b/src/daaf/replay_mapper.py index d2830a8..4dd2e7e 100644 --- a/src/daaf/replay_mapper.py +++ b/src/daaf/replay_mapper.py @@ -9,16 +9,7 @@ import copy import dataclasses import logging -from typing import ( - Any, - Callable, - Dict, - FrozenSet, - Iterator, - Optional, - Sequence, - Set, -) +from typing import Any, Callable, Dict, FrozenSet, Iterator, Optional, Sequence, Set import numpy as np from rlplg import combinatorics, core @@ -181,6 +172,12 @@ def __init__( buffer_size: The maximum number of trajectories to keep in the buffer - each one should contain `reward_period` steps. impute_value: Value to use when rewards are missing. + terminal_states: A set of terminal states. When provided, + estimates for these states are ignored. + factor_terminal_states: Solves LEAST for non-terminal states + only when `True`. + prefill_buffer: Examples for terminal states are added + before data collection when `True`. Note: decay isn't used when summing up the rewards for K steps. 
""" @@ -208,6 +205,7 @@ def __init__( self.state_id_fn = state_id_fn self.action_id_fn = action_id_fn self.buffer_size = buffer_size + self.terminal_states = terminal_states self.factor_terminal_states = factor_terminal_states self.prefill_buffer = prefill_buffer self.num_updates = 0 @@ -272,32 +270,20 @@ def apply( reward_sum = 0.0 # Run estimation at the first possible moment, if self._estimation_buffer.is_full_rank: - logging.debug("Estimating rewards with Least-Squares.") try: - estimated_rewards = math_ops.solve_least_squares( + estimated_rewards = self.__estimate_rewards( matrix=self._estimation_buffer.matrix, rhs=self._estimation_buffer.rhs, + num_states=self.num_states, + num_actions=self.num_actions, + terminal_state_action_mask=self._terminal_state_action_mask, ) - # we only solved for non-terminal states - if estimated_rewards.size < ( - self.num_states * self.num_actions - ): - pos = 0 - est_rewards_ext = np.zeros( - self.num_states * self.num_actions, dtype=np.float64 - ) - ignore_factors_mask = np.reshape( - self._terminal_state_action_mask, newshape=[-1] - ) - for idx in range(len(ignore_factors_mask)): - if ignore_factors_mask[idx] == 1: - est_rewards_ext[idx] = 0.0 - else: - est_rewards_ext[idx] = estimated_rewards[pos] - pos += 1 - estimated_rewards = est_rewards_ext new_rtable = np.reshape( - estimated_rewards, + self.__zero_terminal_states( + estimated_rewards, + terminal_states=self.terminal_states, + num_actions=self.num_actions, + ), newshape=(self.num_states, self.num_actions), ) # update the reward estimates by a fraction of the delta @@ -317,6 +303,60 @@ def apply( yield dataclasses.replace(traj_step, reward=reward) + @staticmethod + def __estimate_rewards( + matrix: np.ndarray, + rhs: np.ndarray, + num_states: int, + num_actions: int, + terminal_state_action_mask: np.ndarray, + ): + logging.debug("Estimating rewards with Least-Squares.") + estimated_rewards = math_ops.solve_least_squares( + matrix=matrix, + rhs=rhs, + ) + # we only solved 
for non-terminal states + if estimated_rewards.size < (num_states * num_actions): + pos = 0 + est_rewards_ext = np.zeros(num_states * num_actions, dtype=np.float64) + ignore_factors_mask = np.reshape(terminal_state_action_mask, newshape=[-1]) + for idx in range(len(ignore_factors_mask)): + if ignore_factors_mask[idx] == 1: + est_rewards_ext[idx] = 0.0 + else: + est_rewards_ext[idx] = estimated_rewards[pos] + pos += 1 + estimated_rewards = est_rewards_ext + new_rtable = np.reshape( + estimated_rewards, + newshape=(num_states, num_actions), + ) + return new_rtable + + @staticmethod + def __zero_terminal_states( + rtable: np.ndarray, terminal_states: FrozenSet[int], num_actions: int + ): + """ + Assumes states are zero-indexed. + """ + new_rtable = copy.deepcopy(rtable) + ndim = np.ndim(new_rtable) + if ndim > 2: + raise ValueError(f"`rtable` must be 1D or 2D tensor. Got: {ndim}.") + + for tstate in terminal_states: + if ndim == 1: + # Examples + # nA = 2 + # 0 -> 0*2, (0+1)*2-1 -> 0, 1 + # 2 -> 2*2, (2+1)*2-1 -> 4, 5 + new_rtable[tstate * num_actions : (tstate + 1) * num_actions - 1] = 0.0 + elif ndim == 2: + new_rtable[tstate, :] = 0.0 + return new_rtable + class DaafMdpWithOptionsMapper(TrajMapper): """ @@ -527,17 +567,22 @@ def __init__( if ignore_factors_mask is not None else np.zeros(num_factors, dtype=np.int64) ) + self.nkeep_factors = self.num_factors - ( + 0 + if self.ignore_factors_mask is None + else np.sum(self.ignore_factors_mask).item() + ) self._keep_factors_mask = (self.ignore_factors_mask - 1) * -1 # `np.where` returns a tuple per dim; # Keep the first dim - self._col_mask = np.where(self._keep_factors_mask == 1)[0] + self._cols_mask = np.where(self._keep_factors_mask == 1)[0] # pre-allocate arrays - self._rows = np.zeros(shape=(buffer_size, num_factors), dtype=np.float64) + self._rows = np.zeros(shape=(buffer_size, self.nkeep_factors), dtype=np.float64) self._b = np.zeros(shape=(buffer_size,), dtype=np.float64) self._next_pos = 0 self._additions = 0 
self._factors_tracker: Set[int] = set() - self._rank_flag = np.zeros(shape=self.num_factors, dtype=np.float64) + self._rank_flag = np.zeros(shape=self.nkeep_factors, dtype=np.float64) def add(self, row: np.ndarray, rhs: np.ndarray) -> None: """ @@ -553,24 +598,28 @@ def add(self, row: np.ndarray, rhs: np.ndarray) -> None: ) # Add rows for factors of interest (`keep_factors_mask`). - if np.sum(row * self._keep_factors_mask) > 0: - mask = (row > 0).astype(np.int64) - row_key = combinatorics.sequence_to_integer(space_size=2, sequence=mask) + candidate_row = row[self._cols_mask] + if np.sum(candidate_row) > 0: + mask = (candidate_row > 0).astype(np.int64) + row_factors_key = combinatorics.sequence_to_integer( + space_size=2, sequence=mask + ) # Only add distict rows - based on their mask - if row_key not in self._factors_tracker: + if row_factors_key not in self._factors_tracker: + current_row_mask = (self._rows[self._next_pos] > 0).astype(np.int64) current_row_key = combinatorics.sequence_to_integer( space_size=2, - sequence=(self._rows[self._next_pos] > 0).astype(np.int64), + sequence=current_row_mask, ) if current_row_key in self._factors_tracker: # Every row is unique, thus removing it # removes it's marker self._factors_tracker.remove(current_row_key) self._rank_flag -= self._rows[self._next_pos] - self._factors_tracker.add(row_key) + self._factors_tracker.add(row_factors_key) self._rank_flag += mask - self._rows[self._next_pos] = row + self._rows[self._next_pos] = candidate_row self._b[self._next_pos] = rhs # cycle least recent self._next_pos = (self._next_pos + 1) % self.buffer_size @@ -584,8 +633,8 @@ def matrix(self) -> np.ndarray: it returns the values available - which can be an empty array. 
""" if self._additions >= self.buffer_size: - return self._rows[:, self._col_mask] - return self._rows[: self._next_pos, self._col_mask] + return self._rows + return self._rows[: self._next_pos, :] @property def rhs(self) -> np.ndarray: @@ -608,9 +657,11 @@ def is_empty(self) -> bool: @property def is_full_rank(self) -> bool: - return self._additions >= self.num_factors and np.sum( - (self._rank_flag * self._keep_factors_mask) > 0 - ) == np.sum(self._keep_factors_mask) + square_or_tall = self._additions >= self.nkeep_factors + factors_rank = ( + np.sum((self._rank_flag > 0).astype(np.int64)) == self.nkeep_factors + ).item() + return square_or_tall and factors_rank class Counter: diff --git a/src/daaf/rewardest/estimation.py b/src/daaf/rewardest/estimation.py index 31a2140..2a604b1 100644 --- a/src/daaf/rewardest/estimation.py +++ b/src/daaf/rewardest/estimation.py @@ -20,14 +20,9 @@ def estimate_reward( logging_steps: int = 100, factor_terminal_states: bool = False, prefill_buffer: bool = False, - export_path: Optional[str] = None, ) -> Mapping[str, np.ndarray]: env_spec = envsuite.load(spec["name"], **spec["args"]) - terminal_states = ( - core.infer_env_terminal_states(env_spec.mdp.transition) - if factor_terminal_states - else frozenset() - ) + terminal_states = core.infer_env_terminal_states(env_spec.mdp.transition) init_rtable = np.zeros( shape=(env_spec.mdp.env_desc.num_states, env_spec.mdp.env_desc.num_actions), dtype=np.float64, @@ -48,13 +43,17 @@ def estimate_reward( ) policy = policies.PyRandomPolicy(num_actions=env_spec.mdp.env_desc.num_actions) # collect data - logging.info("Collecting data for %s", spec["name"]) + logging.info("Collecting data for %s/%s", spec["name"], spec["args"]) episode = 1 steps = 0 yhat_lstsq: Optional[np.ndarray] = None yhat_ols_em: Optional[np.ndarray] = None - meta: Dict[str, Any] = {"max_episodes": max_episodes, "est_accuracy": accuracy} - visited_states: Dict[int, int] = collections.defaultdict(int) + meta: Dict[str, Any] = 
{ + "max_episodes": max_episodes, + "est_accuracy": accuracy, + "ols_iters": None, + } + num_visited_states_dist: Dict[int, int] = collections.defaultdict(int) while True: traj = envplay.generate_episode(env_spec.environment, policy=policy) @@ -63,7 +62,7 @@ def estimate_reward( episode_visited_states.add( env_spec.discretizer.state(traj_step.observation) ) - visited_states[len(episode_visited_states)] += 1 + num_visited_states_dist[len(episode_visited_states)] += 1 if ( not mapper._estimation_buffer.is_empty @@ -72,7 +71,12 @@ def estimate_reward( break if episode % logging_steps == 0: - logging.info("Data collection for %s at %d episodes", spec["name"], episode) + logging.info( + "Data collection for %s/%s at %d episodes", + spec["name"], + spec["args"], + episode, + ) if episode >= max_episodes: break episode += 1 @@ -81,8 +85,9 @@ def estimate_reward( # estimate rewards if mapper._estimation_buffer.is_full_rank: logging.info( - "Estimating rewards for %s, after %d episodes (%d steps). Matrix shape: %s", + "Estimating rewards for %s/%s, after %d episodes (%d steps). 
Matrix shape: %s", spec["name"], + spec["args"], episode, steps, mapper._estimation_buffer.matrix.shape, @@ -92,7 +97,9 @@ def estimate_reward( agg_rewards=mapper._estimation_buffer.rhs, accuracy=accuracy, ) - logging.info("OLS ran in %d iterations for %s", iters, spec["name"]) + logging.info( + "OLS ran in %d iterations for %s/%s", iters, spec["name"], spec["args"] + ) yhat_lstsq = lstsq_reward_estimation( obs_matrix=mapper._estimation_buffer.matrix, agg_rewards=mapper._estimation_buffer.rhs, @@ -118,19 +125,6 @@ def estimate_reward( spec["name"], spec["args"], ) - - if export_path: - import os.path - - for name, array in zip( - ["lhs", "rhs"], - [mapper._estimation_buffer.matrix, mapper._estimation_buffer.rhs], - ): - if not os.path.exists(export_path): - os.makedirs(export_path) - with open(os.path.join(export_path, name), "wb") as writable: - np.save(writable, array) - return { "least": yhat_lstsq, "ols_em": yhat_ols_em, @@ -138,8 +132,15 @@ def estimate_reward( "steps": steps, "full_rank": mapper._estimation_buffer.is_full_rank, "samples": mapper._estimation_buffer.matrix.shape[0], + "data": { + "lhs": mapper._estimation_buffer.matrix, + "rhs": mapper._estimation_buffer.rhs, + }, "buffer_size": mapper._estimation_buffer.buffer_size, - "episode_visited_states_count": dict(visited_states), + "episode_visited_states_count": { + "num_unique_states": list(num_visited_states_dist.keys()), + "num_episodes": list(num_visited_states_dist.values()), + }, "meta": meta, } diff --git a/src/daaf/rewardest/estjob.py b/src/daaf/rewardest/estjob.py index 5e414b1..bd9a7ee 100644 --- a/src/daaf/rewardest/estjob.py +++ b/src/daaf/rewardest/estjob.py @@ -3,26 +3,39 @@ """ import argparse +import copy import dataclasses +import itertools import json import logging +import pathlib import random +import time import uuid from typing import Any, Mapping, Optional, Sequence, Tuple +import numpy as np import ray import ray.data +from daaf import constants, utils from daaf.rewardest 
import estimation ENV_SPECS = [ + {"name": "ABCSeq", "args": {"length": 2, "distance_penalty": False}}, + {"name": "ABCSeq", "args": {"length": 3, "distance_penalty": False}}, {"name": "ABCSeq", "args": {"length": 7, "distance_penalty": False}}, {"name": "ABCSeq", "args": {"length": 10, "distance_penalty": False}}, {"name": "FrozenLake-v1", "args": {"is_slippery": False, "map_name": "4x4"}}, + { + "name": "GridWorld", + "args": {"grid": "ooooo\nooxoo\noxooo\nsxxxg"}, + }, { "name": "GridWorld", "args": {"grid": "oooooooooooo\noooooooooooo\noooooooooooo\nsxxxxxxxxxxg"}, }, + {"name": "RedGreenSeq", "args": {"cure": ["red", "green"]}}, { "name": "RedGreenSeq", "args": { @@ -37,8 +50,7 @@ EST_PLAIN = "plain" EST_FACTOR_TS = "factor-ts" EST_PREFILL_BUFFER = "prefill-buffer" - -AGG_REWARD_PERIODS = [2, 3, 4, 5, 6, 7, 8] +AGG_REWARD_PERIODS = [2, 3, 4, 5, 6, 7, 8, 15] EST_ACCURACY = 1e-8 @@ -59,7 +71,7 @@ class EstimationPipelineArgs: @dataclasses.dataclass(frozen=True) -class EstimationTask: +class EstimationRun: uid: str env_spec: Mapping[str, Any] run_id: int @@ -81,7 +93,7 @@ def main(args: EstimationPipelineArgs): logging.info("Ray Context: %s", context) logging.info("Ray Nodes: %s", ray.nodes()) - tasks_futures = create_tasks( + tasks_and_result_refs = create_tasks( env_specs=ENV_SPECS, agg_reward_periods=AGG_REWARD_PERIODS, num_runs=args.num_runs, @@ -92,29 +104,33 @@ def main(args: EstimationPipelineArgs): # since ray tracks objectref items # we swap the key:value - task_ref_to_spec = {future: task for task, future in tasks_futures} - results = [] - unfinished_tasks = list(task_ref_to_spec.keys()) + task_ref_to_tasks = { + result_ref: tasks for tasks, result_ref in tasks_and_result_refs + } + datasets = [] + unfinished_task_ref = list(task_ref_to_tasks.keys()) while True: - finished_tasks, unfinished_tasks = ray.wait(unfinished_tasks) - for finished_task in finished_tasks: - task = task_ref_to_spec[finished_task] - result = {"result": ray.get(finished_task)} - 
task_dict = dataclasses.asdict(task) - task_dict["env_spec"] = json.dumps(task_dict["env_spec"]) - entry = {**result, **task_dict} - results.append(entry) + finished_task_ref, unfinished_task_ref = ray.wait(unfinished_task_ref) + for finished_task_ref in finished_task_ref: + datasets.append(ray.get(finished_task_ref)) logging.info( "Tasks left: %d out of %d.", - len(unfinished_tasks), - len(task_ref_to_spec), + len(unfinished_task_ref), + len(task_ref_to_tasks), ) - if len(unfinished_tasks) == 0: + if len(unfinished_task_ref) == 0: break - ray.data.from_items(results).write_json(args.output_dir) + if len(datasets) > 0: + if len(datasets) > 1: + ds_head, ds_tail = datasets[0], datasets[1:] + ds_result: ray.data.Dataset = ds_head.union(*ds_tail) + else: + ds_result: ray.data.Dataset = datasets[0] + ds_output = ds_result.map(serialize) + ds_output.write_parquet(args.output_dir) def create_tasks( @@ -124,36 +140,56 @@ def create_tasks( max_episodes: int, log_episode_frequency: int, accuracy: float, -) -> Sequence[Tuple[EstimationTask, ray.ObjectRef]]: - tasks = [] +) -> Sequence[Tuple[ray.ObjectRef]]: + estimation_runs = [] futures = [] - for env_spec in env_specs: - for reward_period in agg_reward_periods: - for method in (EST_PLAIN, EST_FACTOR_TS, EST_PREFILL_BUFFER): - uid = str(uuid.uuid4()) - for run_id in range(num_runs): - task = EstimationTask( - uid=uid, - env_spec=env_spec, - reward_period=reward_period, - run_id=run_id, - accuracy=accuracy, - max_episodes=max_episodes, - log_episode_frequency=log_episode_frequency, - method=method, - ) - tasks.append(task) + methods = (EST_PLAIN, EST_FACTOR_TS, EST_PREFILL_BUFFER) + for env_spec, reward_period, method in itertools.product( + env_specs, agg_reward_periods, methods + ): + uid = str(uuid.uuid4()) + estimation_runs.extend( + [ + EstimationRun( + uid=uid, + env_spec=env_spec, + reward_period=reward_period, + run_id=run_id, + accuracy=accuracy, + max_episodes=max_episodes, + 
log_episode_frequency=log_episode_frequency, + method=method, + ) + for run_id in range(num_runs) + ] + ) + # shuffle to workload - random.shuffle(tasks) - for task in tasks: - futures.append((task, estimate.remote(task))) + random.shuffle(estimation_runs) + # batch tasks + estimation_run_batches = utils.bundle( + estimation_runs, bundle_size=constants.DEFAULT_BATCH_SIZE + ) + for batch in estimation_run_batches: + futures.append((batch, run_fn.remote(batch))) return futures @ray.remote -def estimate(task: EstimationTask) -> Mapping[str, Any]: +def run_fn(estimation_runs: Sequence[EstimationRun]) -> ray.data.Dataset: + results = [] + for experiment_run in estimation_runs: + estimation_run_dict = dataclasses.asdict(experiment_run) + result = estimate(experiment_run) + result = {"result": result} + entry = {**result, **estimation_run_dict} + results.append(entry) + return ray.data.from_items(results) + + +def estimate(task: EstimationRun) -> Mapping[str, Any]: """ - Runs evaluation. + Reward estimation. """ logging.info( "Task %s for %s/%d (%s) starting", @@ -191,15 +227,39 @@ def estimate(task: EstimationTask) -> Mapping[str, Any]: return result +def serialize(example: Mapping[str, Any]) -> Mapping[str, Any]: + def go(key: str, element: Any): + if key == "args" and isinstance(element, Mapping): + return json.dumps(element) + elif isinstance(element, Mapping): + return {skey: go(skey, svalue) for skey, svalue in element.items()} + elif isinstance(element, np.ndarray): + return element.flatten() + return copy.deepcopy(element) + + return {key: go(key, value) for key, value in example.items()} + + def parse_args() -> EstimationPipelineArgs: """ Parses program arguments. 
""" arg_parser = argparse.ArgumentParser() - arg_parser.add_argument("--num-runs", type=int, required=True) - arg_parser.add_argument("--max-episodes", type=int, required=True) - arg_parser.add_argument("--output-dir", type=str, required=True) - arg_parser.add_argument("--log-episode-frequency", type=int, required=True) + arg_parser.add_argument("--num-runs", type=int, default=3) + arg_parser.add_argument("--max-episodes", type=int, default=2500) + arg_parser.add_argument( + "--output-dir", + type=str, + default=pathlib.Path.home() + / "fs/daaf/exp/reward-estjob/logs" + / str(int(time.time())), + ) + arg_parser.add_argument("--log-episode-frequency", type=int, default=1) + + # arg_parser.add_argument("--num-runs", type=int, required=True) + # arg_parser.add_argument("--max-episodes", type=int, required=True) + # arg_parser.add_argument("--output-dir", type=str, required=True) + # arg_parser.add_argument("--log-episode-frequency", type=int, required=True) arg_parser.add_argument("--cluster-uri", type=str, default=None) known_args, unknown_args = arg_parser.parse_known_args() logging.info("Unknown args: %s", unknown_args) diff --git a/src/daaf/task.py b/src/daaf/task.py index dc2b9a3..19ee216 100644 --- a/src/daaf/task.py +++ b/src/daaf/task.py @@ -86,6 +86,7 @@ def create_trajectory_mappers( num_states=env_spec.mdp.env_desc.num_states, num_actions=env_spec.mdp.env_desc.num_actions, ), + terminal_states=core.infer_env_terminal_states(env_spec.mdp.transition), ) ) elif traj_mapping_method == constants.MDP_WITH_OPTIONS_MAPPER: diff --git a/test-requirements.txt b/test-requirements.txt index f624d30..5187225 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,9 +1,5 @@ -# -# This file is autogenerated by pip-compile with Python 3.9 -# by the following command: -# -# pip-compile --no-emit-find-links --no-emit-index-url --no-emit-options test-requirements.in -# +# This file was autogenerated by uv via the following command: +# uv pip compile 
--no-emit-index-url --no-emit-find-links test-requirements.in -o test-requirements.txt absl-py==2.0.0 # via # -r requirements.txt @@ -13,11 +9,8 @@ aiohttp==3.8.4 # via # -r requirements.txt # aiohttp-cors - # ray aiohttp-cors==0.7.0 - # via - # -r requirements.txt - # ray + # via -r requirements.txt aiosignal==1.3.1 # via # -r requirements.txt @@ -37,10 +30,6 @@ attrs==22.2.0 # aiohttp # hypothesis # jsonschema -blessed==1.20.0 - # via - # -r requirements.txt - # gpustat cachetools==5.3.0 # via # -r requirements.txt @@ -54,6 +43,10 @@ charset-normalizer==3.1.0 # -r requirements.txt # aiohttp # requests +clarabel==0.9.0 + # via + # -r requirements.txt + # cvxpy click==8.0.4 # via # -r requirements.txt @@ -63,17 +56,21 @@ cloudpickle==2.2.1 # -r requirements.txt # gymnasium colorful==0.5.5 - # via - # -r requirements.txt - # ray -coverage[toml]==7.3.2 + # via -r requirements.txt +coverage==7.3.2 # via # -r test-requirements.in # pytest-cov +cvxpy==1.5.3 + # via -r requirements.txt distlib==0.3.6 # via # -r requirements.txt # virtualenv +ecos==2.0.14 + # via + # -r requirements.txt + # cvxpy exceptiongroup==1.2.0 # via # hypothesis @@ -98,9 +95,7 @@ frozenlist==1.3.3 # aiosignal # ray fsspec==2024.2.0 - # via - # -r requirements.txt - # ray + # via -r requirements.txt gast==0.4.0 # via # -r requirements.txt @@ -127,14 +122,9 @@ googleapis-common-protos==1.58.0 # via # -r requirements.txt # google-api-core -gpustat==1.0.0 - # via - # -r requirements.txt - # ray grpcio==1.51.3 # via # -r requirements.txt - # ray # tensorboard # tensorflow gymnasium==0.28.1 @@ -163,6 +153,10 @@ jax-jumpy==1.0.0 # via # -r requirements.txt # gymnasium +jinja2==3.1.4 + # via + # -r requirements.txt + # memray jsonschema==4.17.3 # via # -r requirements.txt @@ -175,14 +169,35 @@ libclang==16.0.6 # via # -r requirements.txt # tensorflow +linkify-it-py==2.0.3 + # via + # -r requirements.txt + # markdown-it-py markdown==3.5.1 # via # -r requirements.txt # tensorboard +markdown-it-py==3.0.0 + # 
via + # -r requirements.txt + # mdit-py-plugins + # rich + # textual markupsafe==2.1.3 # via # -r requirements.txt + # jinja2 # werkzeug +mdit-py-plugins==0.4.2 + # via + # -r requirements.txt + # markdown-it-py +mdurl==0.1.2 + # via + # -r requirements.txt + # markdown-it-py +memray==1.14.0 + # via -r requirements.txt msgpack==1.0.5 # via # -r requirements.txt @@ -195,29 +210,28 @@ multidict==6.0.4 numpy==1.23.5 # via # -r requirements.txt + # clarabel + # cvxpy + # ecos # gymnasium # h5py # jax-jumpy # opt-einsum + # osqp # pandas # pyarrow - # ray + # qdldl # rlplg # scipy + # scs # tensorboard # tensorflow -nvidia-ml-py==11.495.46 - # via - # -r requirements.txt - # gpustat oauthlib==3.2.2 # via # -r requirements.txt # requests-oauthlib opencensus==0.11.2 - # via - # -r requirements.txt - # ray + # via -r requirements.txt opencensus-context==0.1.3 # via # -r requirements.txt @@ -226,6 +240,10 @@ opt-einsum==3.3.0 # via # -r requirements.txt # tensorflow +osqp==0.6.7.post1 + # via + # -r requirements.txt + # cvxpy packaging==23.2 # via # -r requirements.txt @@ -233,9 +251,7 @@ packaging==23.2 # ray # tensorflow pandas==2.0.3 - # via - # -r requirements.txt - # ray + # via -r requirements.txt pillow==9.4.0 # via # -r requirements.txt @@ -247,9 +263,7 @@ platformdirs==3.1.1 pluggy==1.0.0 # via pytest prometheus-client==0.13.1 - # via - # -r requirements.txt - # ray + # via -r requirements.txt protobuf==4.25.1 # via # -r requirements.txt @@ -258,18 +272,10 @@ protobuf==4.25.1 # ray # tensorboard # tensorflow -psutil==5.9.4 - # via - # -r requirements.txt - # gpustat py-spy==0.3.14 - # via - # -r requirements.txt - # ray + # via -r requirements.txt pyarrow==15.0.0 - # via - # -r requirements.txt - # ray + # via -r requirements.txt pyasn1==0.4.8 # via # -r requirements.txt @@ -280,9 +286,11 @@ pyasn1-modules==0.2.8 # -r requirements.txt # google-auth pydantic==1.10.6 + # via -r requirements.txt +pygments==2.18.0 # via # -r requirements.txt - # ray + # rich 
pyrsistent==0.19.3 # via # -r requirements.txt @@ -305,7 +313,11 @@ pyyaml==6.0 # via # -r requirements.txt # ray -ray[data,default]==2.9.3 +qdldl==0.1.7.post4 + # via + # -r requirements.txt + # osqp +ray==2.23.0 # via -r requirements.txt requests==2.28.2 # via @@ -318,7 +330,12 @@ requests-oauthlib==1.3.1 # via # -r requirements.txt # google-auth-oauthlib -rlplg @ git+https://github.com/guidj/rlplg.git@v0.19.10 +rich==13.8.1 + # via + # -r requirements.txt + # memray + # textual +rlplg @ git+https://github.com/guidj/rlplg.git@93d819cab2eeefdd9bbd0ec7acb5d1dd0e5d1bed # via -r requirements.txt rsa==4.9 # via @@ -327,21 +344,32 @@ rsa==4.9 scipy==1.10.1 # via # -r requirements.txt + # clarabel + # cvxpy + # ecos + # osqp + # qdldl # rlplg + # scs +scs==3.2.7 + # via + # -r requirements.txt + # cvxpy +setuptools==74.1.2 + # via + # -r requirements.txt + # tensorboard + # tensorflow six==1.16.0 # via # -r requirements.txt # astunparse - # blessed # google-auth # google-pasta - # gpustat # python-dateutil # tensorflow smart-open==6.3.0 - # via - # -r requirements.txt - # ray + # via -r requirements.txt sortedcontainers==2.4.0 # via hypothesis tensorboard==2.13.0 @@ -366,6 +394,10 @@ termcolor==2.3.0 # via # -r requirements.txt # tensorflow +textual==0.78.0 + # via + # -r requirements.txt + # memray tomli==2.0.1 # via # coverage @@ -376,22 +408,21 @@ typing-extensions==4.5.0 # gymnasium # pydantic # tensorflow + # textual tzdata==2024.1 # via # -r requirements.txt # pandas +uc-micro-py==1.0.3 + # via + # -r requirements.txt + # linkify-it-py urllib3==1.26.15 # via # -r requirements.txt # requests virtualenv==20.21.0 - # via - # -r requirements.txt - # ray -wcwidth==0.2.6 - # via - # -r requirements.txt - # blessed + # via -r requirements.txt werkzeug==3.0.1 # via # -r requirements.txt @@ -413,6 +444,3 @@ zipp==3.15.0 # via # -r requirements.txt # importlib-metadata - -# The following packages are considered to be unsafe in a requirements file: -# setuptools diff --git 
a/tests/daaf/test_replay_mapper.py b/tests/daaf/test_replay_mapper.py index d82a819..349824a 100644 --- a/tests/daaf/test_replay_mapper.py +++ b/tests/daaf/test_replay_mapper.py @@ -244,7 +244,6 @@ def test_daaf_lsq_reward_attribution_mapper_apply(): """ Initial events will have reward values from rtable. Once there are enough samples, Least Square Estimates are used instead. - The estimates are updated at `update_steps` intervals. Problem: Two states (A, B), two actions (left, right) Table: @@ -335,6 +334,119 @@ def test_daaf_lsq_reward_attribution_mapper_apply(): np.testing.assert_array_equal(output.truncated, expected.truncated) +def test_daaf_lsq_reward_attribution_mapper_apply_with_terminal_states(): + """ + Initial events will have reward values from rtable. + Once there are enough samples, Least Square Estimates are used instead. + + Problem: Three states (A, B, C), two actions (left, right) + Table: + Actions + States Left Right + A 0 1 + B 0 1 + C 0 0 + + events: (A, left, A, right) -> (0, 0), (0, 1) -> 0 + 1 = 1 + (B, left, B, right) -> (1, 0), (1, 1) -> 0 + 1 = 1 + (A, right, B, left) -> (0, 1), (1, 0) -> 1 + 0 = 1 + (A, right, B, right) -> (0, 1), (1, 1) -> 1 + 1 = 2 + (C, left, C, right) -> (2, 0), (2, 1) -> 0 + 0 = 0 + + matrix: (A, left), (A, right), (B, left), (B, right) (C,left) (C, right) + 1 1 0 0 0 0 + 0 0 1 1 0 0 + 0 1 1 0 0 0 + 0 1 0 1 0 0 + 0 0 0 0 1 1 + rhs: 1, 1, 1, 2, 0 + """ + + mapper = replay_mapper.DaafLsqRewardAttributionMapper( + num_states=3, + num_actions=2, + reward_period=2, + state_id_fn=item, + action_id_fn=item, + buffer_size=8, + init_rtable=defaults.array([-1.0, -1.0], [-1.0, -1.0], [-1.0, -1.0]), + impute_value=88, + terminal_states={ + 2, + }, + ) + + # We are simulating cumulative rewards. + # So we supply the actual rewards to the simulator to aggregate (sum). 
+ inputs = [ + traj_step(state=0, action=0, reward=0.0, prob=0.0), + traj_step(state=0, action=1, reward=1.0, prob=1.0), + traj_step(state=1, action=0, reward=0.0, prob=0.0), + traj_step(state=1, action=1, reward=1.0, prob=1.0), + traj_step(state=2, action=0, reward=-1.0, prob=0.0), + traj_step(state=2, action=1, reward=-1.0, prob=1.0), + # after the event above, all factors are present, but we still lack rows + # to satisfy the condition m >= n + traj_step(state=0, action=1, reward=1.0, prob=1.0), + traj_step(state=1, action=0, reward=0.0, prob=0.0), + traj_step(state=0, action=1, reward=1.0, prob=1.0), + traj_step(state=1, action=1, reward=1.0, prob=1.0), + traj_step(state=1, action=0, reward=0.0, prob=0.0), + traj_step(state=2, action=1, reward=-1.0, prob=1.0), + # after the event above, m >= n + # the events will below will be emitted with estimated rewards + traj_step(state=0, action=0, reward=-7.0, prob=0.0), + traj_step(state=0, action=1, reward=-7.0, prob=1.0), + traj_step(state=1, action=0, reward=-7.0, prob=0.0), + traj_step(state=1, action=1, reward=-7.0, prob=1.0), + traj_step(state=0, action=1, reward=-7.0, prob=1.0), + traj_step(state=1, action=0, reward=-7.0, prob=0.0), + traj_step(state=0, action=1, reward=-7.0, prob=1.0), + traj_step(state=1, action=1, reward=-7.0, prob=1.0), + traj_step(state=2, action=0, reward=0.0, prob=0.0), + traj_step(state=2, action=1, reward=0.0, prob=1.0), + ] + expectactions = [ + # the events below are emitted with the impute value + # or the aggregate feedback + traj_step(state=0, action=0, reward=88, prob=0.0), + traj_step(state=0, action=1, reward=1.0, prob=1.0), + traj_step(state=1, action=0, reward=88, prob=0.0), + traj_step(state=1, action=1, reward=1.0, prob=1.0), + traj_step(state=2, action=0, reward=88, prob=0.0), + traj_step(state=2, action=1, reward=-2.0, prob=1.0), + traj_step(state=0, action=1, reward=88, prob=1.0), + traj_step(state=1, action=0, reward=1.0, prob=0.0), + traj_step(state=0, action=1, 
reward=88, prob=1.0), + traj_step(state=1, action=1, reward=2.0, prob=1.0), + traj_step(state=1, action=0, reward=88, prob=0.0), + traj_step(state=2, action=1, reward=-1.0, prob=1.0), + # the events below are emitted with estimated rewards + traj_step(state=0, action=0, reward=0.0, prob=0.0), + traj_step(state=0, action=1, reward=1.0, prob=1.0), + traj_step(state=1, action=0, reward=0.0, prob=0.0), + traj_step(state=1, action=1, reward=1.0, prob=1.0), + traj_step(state=0, action=1, reward=1.0, prob=1.0), + traj_step(state=1, action=0, reward=0.0, prob=0.0), + traj_step(state=0, action=1, reward=1.0, prob=1.0), + traj_step(state=1, action=1, reward=1.0, prob=1.0), + # zero'd out because 2 is passed as a terminal state + traj_step(state=2, action=0, reward=0.0, prob=0.0), + traj_step(state=2, action=1, reward=0.0, prob=1.0), + ] + + outputs = tuple(mapper.apply(inputs)) + assert len(outputs) == 22 + for output, expected in zip(outputs, expectactions): + # reward can only be approximately equal + np.testing.assert_array_equal(output.observation, expected.observation) + np.testing.assert_array_equal(output.action, expected.action) + np.testing.assert_array_equal(output.policy_info, expected.policy_info) + np.testing.assert_array_almost_equal(output.reward, expected.reward) + np.testing.assert_array_equal(output.terminated, expected.terminated) + np.testing.assert_array_equal(output.truncated, expected.truncated) + + def test_daaf_mdp_with_options_mapper_apply_given_truncated_options(): mapper = replay_mapper.DaafMdpWithOptionsMapper() inputs = [ @@ -637,36 +749,100 @@ def test_abqueuebuffer_init(): def test_abqueuebuffer(): buffer = replay_mapper.AbQueueBuffer(buffer_size=4, num_factors=3) + # First entry, added. 
buffer.add(np.array([1, 0, 0]), 1) assert getattr(buffer, "_factors_tracker") == set([4]) np.testing.assert_allclose(getattr(buffer, "_rank_flag"), np.array([1, 0, 0])) + assert buffer.is_empty is False + assert buffer.is_full_rank is False + np.testing.assert_allclose(buffer.matrix, np.array([[1, 0, 0]])) + # Second entry, independent, added. buffer.add(np.array([1, 0, 1]), 2) assert getattr(buffer, "_factors_tracker") == set([4, 5]) np.testing.assert_allclose(getattr(buffer, "_rank_flag"), np.array([2, 0, 1])) + assert buffer.is_empty is False + assert buffer.is_full_rank is False + np.testing.assert_allclose(buffer.matrix, np.array([[1, 0, 0], [1, 0, 1]])) + # Third entry, independent, added. + # Matrix is now full rank. buffer.add(np.array([1, 1, 1]), 3) assert getattr(buffer, "_factors_tracker") == set([4, 5, 7]) np.testing.assert_allclose(getattr(buffer, "_rank_flag"), np.array([3, 1, 2])) + assert buffer.is_empty is False + assert buffer.is_full_rank is True + np.testing.assert_allclose( + buffer.matrix, np.array([[1, 0, 0], [1, 0, 1], [1, 1, 1]]) + ) - # duplicate entry; no change + # Fourth entry, non-independent; no change buffer.add(np.array([1, 0, 1]), 4) assert getattr(buffer, "_factors_tracker") == set([4, 5, 7]) np.testing.assert_allclose(getattr(buffer, "_rank_flag"), np.array([3, 1, 2])) - + assert buffer.is_empty is False + assert buffer.is_full_rank is True np.testing.assert_allclose( buffer.matrix, np.array([[1, 0, 0], [1, 0, 1], [1, 1, 1]]) ) + # Fifth entry, indenpedent, added. 
buffer.add(np.array([1, 1, 0]), 5) assert getattr(buffer, "_factors_tracker") == set([4, 5, 7, 6]) np.testing.assert_allclose(getattr(buffer, "_rank_flag"), np.array([4, 2, 2])) - + assert buffer.is_empty is False + assert buffer.is_full_rank is True np.testing.assert_allclose( buffer.matrix, np.array([[1, 0, 0], [1, 0, 1], [1, 1, 1], [1, 1, 0]]) ) +def test_abqueuebuffer_ignore_factors_mask(): + buffer = replay_mapper.AbQueueBuffer( + buffer_size=4, num_factors=3, ignore_factors_mask=np.array([0, 1, 0]) + ) + + # First entry, added. + buffer.add(np.array([1, 0, 0]), 1) + assert getattr(buffer, "_factors_tracker") == set([2]) + np.testing.assert_allclose(getattr(buffer, "_rank_flag"), np.array([1, 0])) + assert buffer.is_empty is False + assert buffer.is_full_rank is False + np.testing.assert_allclose(buffer.matrix, np.array([[1, 0]])) + + # Second entry, added. + buffer.add(np.array([1, 0, 1]), 2) + assert getattr(buffer, "_factors_tracker") == set([2, 3]) + np.testing.assert_allclose(getattr(buffer, "_rank_flag"), np.array([2, 1])) + assert buffer.is_empty is False + assert buffer.is_full_rank is True + np.testing.assert_allclose(buffer.matrix, np.array([[1, 0], [1, 1]])) + + # Third entry, non-independent, ignored. + buffer.add(np.array([1, 1, 1]), 3) + assert getattr(buffer, "_factors_tracker") == set([2, 3]) + np.testing.assert_allclose(getattr(buffer, "_rank_flag"), np.array([2, 1])) + assert buffer.is_empty is False + assert buffer.is_full_rank is True + np.testing.assert_allclose(buffer.matrix, np.array([[1, 0], [1, 1]])) + + # Fourth entry, duplicate, ignored. + buffer.add(np.array([0, 1, 0]), 4) + assert getattr(buffer, "_factors_tracker") == set([2, 3]) + np.testing.assert_allclose(getattr(buffer, "_rank_flag"), np.array([2, 1])) + assert buffer.is_empty is False + assert buffer.is_full_rank is True + np.testing.assert_allclose(buffer.matrix, np.array([[1, 0], [1, 1]])) + + # Fifth entry, independent, added. 
+ buffer.add(np.array([0, 0, 1]), 5) + assert getattr(buffer, "_factors_tracker") == set([2, 3, 1]) + np.testing.assert_allclose(getattr(buffer, "_rank_flag"), np.array([2, 2])) + assert buffer.is_empty is False + assert buffer.is_full_rank is True + np.testing.assert_allclose(buffer.matrix, np.array([[1, 0], [1, 1], [0, 1]])) + + def test_counter_init(): counter = replay_mapper.Counter() assert counter.value is None diff --git a/tox.ini b/tox.ini index a4b4a1f..03b750b 100644 --- a/tox.ini +++ b/tox.ini @@ -1,10 +1,10 @@ [tox] -envlist = py38,docs,manifest,check-formatting,lint +envlist = py39,docs,manifest,check-formatting,lint skipsdist = True usedevelop = True [testenv] -basepython = python3.8 +basepython = python3.9 deps = -r{toxinidir}/dev-requirements.txt -e . @@ -13,20 +13,20 @@ commands = [testenv:manifest] ; a safety check for source distributions -basepython = python3.8 +basepython = python3.9 deps = check-manifest skip_install = true commands = check-manifest [testenv:check-formatting] -basepython = python3.8 +basepython = python3.9 deps = ruff==0.2.2 skip_install = true commands = ruff format src/daaf tests --line-length 88 --diff --check [testenv:format] -basepython = python3.8 +basepython = python3.9 deps = ruff==0.2.2 skip_install = true commands = @@ -35,7 +35,7 @@ commands = [testenv:lint] ; feel free to add flake8 plugins (i.e. flake8-import-order), add isort, ; or use another linter of your choice -basepython = python3.8 +basepython = python3.9 deps = ruff==0.2.2 skip_install = true commands = ruff check @@ -45,7 +45,7 @@ commands = ruff check ; i.e. `(env) $ pytest`, it will still pick up this configuration. [testenv:lint-types] -basepython = python3.8 +basepython = python3.9 deps = mypy skip_install = true setenv =