Skip to content

Commit

Permalink
Lint gym notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
nhuet committed Nov 12, 2021
1 parent 6596fad commit acfb358
Showing 1 changed file with 47 additions and 41 deletions.
88 changes: 47 additions & 41 deletions notebooks/2_gym_tuto.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -37,25 +37,25 @@
"metadata": {},
"outputs": [],
"source": [
"from typing import Optional, Callable\n",
"from time import sleep\n",
"import os\n",
"from time import sleep\n",
"from typing import Callable, Optional\n",
"\n",
"from IPython.display import clear_output\n",
"import gym\n",
"import matplotlib.pyplot as plt\n",
"from IPython.display import clear_output\n",
"from stable_baselines3 import PPO\n",
"import gym\n",
"\n",
"from skdecide.hub.solver.stable_baselines import StableBaseline\n",
"from skdecide import Solver\n",
"from skdecide.hub.domain.gym import (\n",
" GymDomain,\n",
" GymWidthDomain,\n",
" GymDiscreteActionDomain,\n",
" GymDomain,\n",
" GymPlanningDomain,\n",
" GymWidthDomain,\n",
")\n",
"from skdecide.hub.solver.iw import IW\n",
"from skdecide.hub.solver.cgp import CGP\n",
"from skdecide.hub.solver.iw import IW\n",
"from skdecide.hub.solver.stable_baselines import StableBaseline\n",
"\n",
"# choose standard matplolib inline backend to render plots\n",
"%matplotlib inline"
Expand Down Expand Up @@ -204,9 +204,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"metadata": {},
"outputs": [],
"source": [
"domain = domain_factory()\n",
Expand Down Expand Up @@ -242,9 +240,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"metadata": {},
"outputs": [],
"source": [
"GymDomain.solve_with(solver, domain_factory)"
Expand Down Expand Up @@ -313,7 +309,7 @@
" observation = outcome.observation\n",
"\n",
" # update image\n",
" if render: \n",
" if render:\n",
" img.set_data(domain.render(mode=\"rgb_array\"))\n",
" fig.canvas.draw()\n",
" clear_output(wait=True)\n",
Expand Down Expand Up @@ -347,14 +343,18 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"metadata": {},
"outputs": [],
"source": [
"domain = domain_factory()\n",
"try:\n",
" rollout(domain=domain, solver=solver, max_steps=999, pause_between_steps=None, render=True)\n",
" rollout(\n",
" domain=domain,\n",
" solver=solver,\n",
" max_steps=999,\n",
" pause_between_steps=None,\n",
" render=True,\n",
" )\n",
"finally:\n",
" domain.close()"
]
Expand Down Expand Up @@ -435,9 +435,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"metadata": {},
"outputs": [],
"source": [
"domain = domain_factory()\n",
Expand Down Expand Up @@ -471,9 +469,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"metadata": {},
"outputs": [],
"source": [
"GymDomain.solve_with(solver, domain_factory)"
Expand All @@ -496,14 +492,18 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"metadata": {},
"outputs": [],
"source": [
"domain = domain_factory()\n",
"try:\n",
" rollout(domain=domain, solver=solver, max_steps=999, pause_between_steps=None, render=True)\n",
" rollout(\n",
" domain=domain,\n",
" solver=solver,\n",
" max_steps=999,\n",
" pause_between_steps=None,\n",
" render=True,\n",
" )\n",
"finally:\n",
" domain.close()"
]
Expand Down Expand Up @@ -532,7 +532,13 @@
" print(f\"Episode #{i_episode}\")\n",
" domain = domain_factory()\n",
" try:\n",
" rollout(domain=domain, solver=solver, max_steps=999, pause_between_steps=None, render=False)\n",
" rollout(\n",
" domain=domain,\n",
" solver=solver,\n",
" max_steps=999,\n",
" pause_between_steps=None,\n",
" render=False,\n",
" )\n",
" finally:\n",
" domain.close()"
]
Expand Down Expand Up @@ -615,7 +621,7 @@
" GymWidthDomain.__init__(\n",
" self, continuous_feature_fidelity=continuous_feature_fidelity\n",
" )\n",
" gym_env._max_episode_steps = max_depth\n"
" gym_env._max_episode_steps = max_depth"
]
},
{
Expand Down Expand Up @@ -645,9 +651,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"metadata": {},
"outputs": [],
"source": [
"domain = domain4width_factory()\n",
Expand Down Expand Up @@ -695,9 +699,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"metadata": {},
"outputs": [],
"source": [
"GymDomainForWidthSolvers.solve_with(solver, domain4width_factory)"
Expand Down Expand Up @@ -796,14 +798,18 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"metadata": {},
"outputs": [],
"source": [
"domain = domain4width_factory()\n",
"try:\n",
" rollout_iw(domain=domain, solver=solver, max_steps=999, pause_between_steps=None, render=True)\n",
" rollout_iw(\n",
" domain=domain,\n",
" solver=solver,\n",
" max_steps=999,\n",
" pause_between_steps=None,\n",
" render=True,\n",
" )\n",
"finally:\n",
" domain.close()"
]
Expand Down

0 comments on commit acfb358

Please sign in to comment.