diff --git a/notebooks/gym_tuto.ipynb b/notebooks/gym_tuto.ipynb index fc92e8dc22..acee69b2c9 100644 --- a/notebooks/gym_tuto.ipynb +++ b/notebooks/gym_tuto.ipynb @@ -33,8 +33,13 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 38, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-21T15:16:02.538864Z", + "start_time": "2021-10-21T15:16:02.533765Z" + } + }, "outputs": [], "source": [ "from typing import Optional, Callable\n", @@ -43,7 +48,7 @@ "\n", "from IPython.display import clear_output\n", "import matplotlib.pyplot as plt\n", - "from stable_baselines3 import PPO\n", + "from stable_baselines3 import PPO, SAC\n", "import gym\n", "\n", "from skdecide.hub.solver.stable_baselines import StableBaseline\n", @@ -223,12 +228,17 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 39, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-21T15:16:13.432840Z", + "start_time": "2021-10-21T15:16:13.427673Z" + } + }, "outputs": [], "source": [ "solver = StableBaseline(\n", - " PPO, \"MlpPolicy\", learn_config={\"total_timesteps\": 50000}, verbose=True\n", + " SAC, \"MlpPolicy\", learn_config={\"total_timesteps\": 50000}, verbose=True\n", ")" ] }, @@ -236,17 +246,215 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Training solver on domain\n", - "The solver will try to find an appropriate policy to solve the maze. " + "### Training solver on domain" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 40, "metadata": { + "ExecuteTime": { + "end_time": "2021-10-21T15:22:56.379775Z", + "start_time": "2021-10-21T15:16:16.042887Z" + }, "scrolled": true }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using cuda device\n", + "------------------------------------------\n", + "| time/ | |\n", + "| episodes | 4 |\n", + "| fps | 134 |\n", + "| time_elapsed | 29 |\n", + "| total timesteps | 3996 |\n", + "| train/ | |\n", + "| actor_loss | -6.26 |\n", + "| approx_kl | 0.0028341115 |\n", + "| clip_fraction | 0.0246 |\n", + "| clip_range | 0.2 |\n", + "| critic_loss | 0.00218 |\n", + "| ent_coef | 0.311 |\n", + "| ent_coef_loss | -1.97 |\n", + "| entropy_loss | 0.11 |\n", + "| explained_variance | 0.017 |\n", + "| learning_rate | 0.0003 |\n", + "| loss | -0.0109 |\n", + "| n_updates | 3895 |\n", + "| policy_gradient_loss | -0.00649 |\n", + "| std | 0.211 |\n", + "| value_loss | 0.00705 |\n", + "------------------------------------------\n", + "---------------------------------\n", + "| time/ | |\n", + "| episodes | 8 |\n", + "| fps | 131 |\n", + "| time_elapsed | 60 |\n", + "| total timesteps | 7992 |\n", + "| train/ | |\n", + "| actor_loss | -6.3 |\n", + "| critic_loss | 0.00104 |\n", + "| ent_coef | 0.0942 |\n", + "| ent_coef_loss | -3.87 |\n", + "| learning_rate | 0.0003 |\n", + "| n_updates | 7891 |\n", + "---------------------------------\n", + "---------------------------------\n", + "| time/ | |\n", + "| episodes | 12 |\n", + "| fps | 130 |\n", + "| time_elapsed | 91 |\n", + "| total timesteps | 11988 |\n", + "| train/ | |\n", + "| actor_loss | -5.12 |\n", + "| critic_loss | 0.000366 |\n", + "| ent_coef | 0.0295 |\n", + "| ent_coef_loss | -4.91 |\n", + "| learning_rate | 0.0003 |\n", + "| n_updates | 11887 |\n", + "---------------------------------\n", + "---------------------------------\n", + "| time/ | |\n", + "| episodes | 16 |\n", + "| fps | 130 |\n", + "| time_elapsed | 122 |\n", + "| total timesteps | 15984 |\n", 
+ "| train/ | |\n", + "| actor_loss | -3.96 |\n", + "| critic_loss | 0.000292 |\n", + "| ent_coef | 0.0101 |\n", + "| ent_coef_loss | -4.43 |\n", + "| learning_rate | 0.0003 |\n", + "| n_updates | 15883 |\n", + "---------------------------------\n", + "---------------------------------\n", + "| time/ | |\n", + "| episodes | 20 |\n", + "| fps | 130 |\n", + "| time_elapsed | 153 |\n", + "| total timesteps | 19980 |\n", + "| train/ | |\n", + "| actor_loss | -3.09 |\n", + "| critic_loss | 0.000286 |\n", + "| ent_coef | 0.00486 |\n", + "| ent_coef_loss | -0.972 |\n", + "| learning_rate | 0.0003 |\n", + "| n_updates | 19879 |\n", + "---------------------------------\n", + "---------------------------------\n", + "| time/ | |\n", + "| episodes | 24 |\n", + "| fps | 129 |\n", + "| time_elapsed | 184 |\n", + "| total timesteps | 23976 |\n", + "| train/ | |\n", + "| actor_loss | -2.4 |\n", + "| critic_loss | 0.000111 |\n", + "| ent_coef | 0.00371 |\n", + "| ent_coef_loss | -0.492 |\n", + "| learning_rate | 0.0003 |\n", + "| n_updates | 23875 |\n", + "---------------------------------\n", + "---------------------------------\n", + "| time/ | |\n", + "| episodes | 28 |\n", + "| fps | 128 |\n", + "| time_elapsed | 217 |\n", + "| total timesteps | 27972 |\n", + "| train/ | |\n", + "| actor_loss | -1.83 |\n", + "| critic_loss | 0.000186 |\n", + "| ent_coef | 0.00304 |\n", + "| ent_coef_loss | -0.772 |\n", + "| learning_rate | 0.0003 |\n", + "| n_updates | 27871 |\n", + "---------------------------------\n", + "---------------------------------\n", + "| time/ | |\n", + "| episodes | 32 |\n", + "| fps | 127 |\n", + "| time_elapsed | 250 |\n", + "| total timesteps | 31968 |\n", + "| train/ | |\n", + "| actor_loss | -1.39 |\n", + "| critic_loss | 5.11e-05 |\n", + "| ent_coef | 0.00263 |\n", + "| ent_coef_loss | 0.666 |\n", + "| learning_rate | 0.0003 |\n", + "| n_updates | 31867 |\n", + "---------------------------------\n", + "---------------------------------\n", + "| time/ | |\n", + "| episodes | 36 |\n", + "| fps | 127 |\n", + "| time_elapsed | 283 |\n", + "| total timesteps | 35964 |\n", + "| train/ | |\n", + "| actor_loss | -1.05 |\n", + "| critic_loss | 1.75e-05 |\n", + "| ent_coef | 0.00225 |\n", + "| ent_coef_loss | -0.104 |\n", + "| learning_rate | 0.0003 |\n", + "| n_updates | 35863 |\n", + "---------------------------------\n", + "---------------------------------\n", + "| time/ | |\n", + "| episodes | 40 |\n", + "| fps | 126 |\n", + "| time_elapsed | 314 |\n", + "| total timesteps | 39960 |\n", + "| train/ | |\n", + "| actor_loss | -0.782 |\n", + "| critic_loss | 1.45e-05 |\n", + "| ent_coef | 0.00203 |\n", + "| ent_coef_loss | -0.0396 |\n", + "| learning_rate | 0.0003 |\n", + "| n_updates | 39859 |\n", + "---------------------------------\n", + "---------------------------------\n", + "| time/ | |\n", + "| episodes | 44 |\n", + "| fps | 125 |\n", + "| time_elapsed | 349 |\n", + "| total timesteps | 43956 |\n", + "| train/ | |\n", + "| actor_loss | -0.574 |\n", + "| critic_loss | 8.13e-06 |\n", + "| ent_coef | 0.0018 |\n", + "| ent_coef_loss | -0.302 |\n", + "| learning_rate | 0.0003 |\n", + "| n_updates | 43855 |\n", + "---------------------------------\n", + "---------------------------------\n", + "| time/ | |\n", + "| episodes | 48 |\n", + "| fps | 125 |\n", + "| time_elapsed | 382 |\n", + "| total timesteps | 47952 |\n", + "| train/ | |\n", + "| actor_loss | -0.406 |\n", + "| critic_loss | 0.000652 |\n", + "| ent_coef | 0.00178 |\n", + "| ent_coef_loss | -0.25 |\n", + "| learning_rate | 
0.0003 |\n", + "| n_updates | 47851 |\n", + "---------------------------------\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "GymDomain.solve_with(solver, domain_factory)" ] @@ -259,13 +467,20 @@ "\n", "We can use the trained solver to roll out an episode to see if this is actually solving the problem at hand.\n", "\n", - "For educative purpose, we define here our own rollout (which will probably be needed if you want to actually use the solver in a real case). If you want to take a look at the (more complex) one already implemented in the library, see the `rollout()` function in [utils.py](https://github.com/airbus/scikit-decide/blob/master/skdecide/utils.py) module.\n" + "For educative purposes, we define our own rollout here (which will probably be needed if you want to actually use the solver in a real case). If you want to take a look at the (more complex) one already implemented in the library, see the `rollout()` function in the [utils.py](https://github.com/airbus/scikit-decide/blob/master/skdecide/utils.py) module.\n", + "\n", + "By default, we display the solution in a matplotlib figure. If you only need to check whether the goal is reached, you can specify `render=False`. In this case, the rollout is greatly sped up, and a message is still printed at the end of the process stating whether the goal was reached and in how many steps." ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 41, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-21T15:22:56.385962Z", + "start_time": "2021-10-21T15:22:56.381089Z" + } + }, "outputs": [], "source": [ "def rollout(\n", @@ -273,6 +488,7 @@ " solver: Solver,\n", " max_steps: int,\n", " pause_between_steps: Optional[float] = 0.01,\n", + " render: bool = True,\n", "):\n", " \"\"\"Roll out one episode in a domain according to the policy of a trained solver.\n", "\n", @@ -282,6 +498,8 @@ " max_steps: maximum number of steps allowed to reach the goal\n", " pause_between_steps: time (s) paused between agent movements.\n", " No pause if None.\n", + " render: if True, the rollout is rendered in a matplotlib figure as an animation;\n", + " if False, the rollout runs much faster.\n", "\n", " \"\"\"\n", " # Initialize episode\n", @@ -289,12 +507,13 @@ " observation = domain.reset()\n", "\n", " # Initialize image\n", - " plt.ioff()\n", - " fig, ax = plt.subplots(1)\n", - " ax.axis(\"off\")\n", - " plt.ion()\n", - " img = ax.imshow(domain.render(mode=\"rgb_array\"))\n", - " display(fig)\n", + " if render:\n", + " plt.ioff()\n", + " fig, ax = plt.subplots(1)\n", + " ax.axis(\"off\")\n", + " plt.ion()\n", + " img = ax.imshow(domain.render(mode=\"rgb_array\"))\n", + " display(fig)\n", "\n", " # loop until max_steps or goal is reached\n", " for i_step in range(1, max_steps + 1):\n", @@ -308,17 +527,19 @@ " observation = outcome.observation\n", "\n", " # update image\n", - " img.set_data(domain.render(mode=\"rgb_array\"))\n", - " fig.canvas.draw()\n", - " clear_output(wait=True)\n", - " display(fig)\n", + " if render:\n", + " img.set_data(domain.render(mode=\"rgb_array\"))\n", + " fig.canvas.draw()\n", + " clear_output(wait=True)\n", + " display(fig)\n", "\n", " # final state reached?\n", " if outcome.termination:\n", " break\n", "\n", " # close the figure to avoid jupyter duplicating the last image\n", - " plt.close(fig)\n", + " if render:\n", + " plt.close(fig)\n", "\n", " # goal 
reached?\n", " is_goal_reached = observation[0] >= 0.45\n", @@ -347,7 +568,7 @@ "source": [ "domain = domain_factory()\n", "try:\n", - " rollout(domain=domain, solver=solver, max_steps=500, pause_between_steps=None)\n", + " rollout(domain=domain, solver=solver, max_steps=999, pause_between_steps=None, render=False)\n", "finally:\n", " domain.close()" ] }, @@ -363,6 +584,53 @@ "We will see in the next sections that non-RL methods can overcome this issue." ] }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": { + "ExecuteTime": { + "end_time": "2021-10-21T15:23:01.907505Z", + "start_time": "2021-10-21T15:22:56.387361Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Episode #0\n", + "Goal not reached after 999 steps!\n", + "Episode #1\n", + "Goal not reached after 999 steps!\n", + "Episode #2\n", + "Goal not reached after 999 steps!\n", + "Episode #3\n", + "Goal not reached after 999 steps!\n", + "Episode #4\n", + "Goal not reached after 999 steps!\n", + "Episode #5\n", + "Goal not reached after 999 steps!\n", + "Episode #6\n", + "Goal not reached after 999 steps!\n", + "Episode #7\n", + "Goal not reached after 999 steps!\n", + "Episode #8\n", + "Goal not reached after 999 steps!\n", + "Episode #9\n", + "Goal not reached after 999 steps!\n" + ] + } + ], + "source": [ + "for i_episode in range(10):\n", + " print(f\"Episode #{i_episode}\")\n", + " domain = domain_factory()\n", + " try:\n", + " rollout(domain=domain, solver=solver, max_steps=999, pause_between_steps=None, render=False)\n", + " finally:\n", + " domain.close()" + ] + }, { "cell_type": "markdown", "metadata": {}, "source": [ @@ -457,8 +725,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Training solver on domain\n", - "The solver will try to find an appropriate policy to solve the maze. " + "### Training solver on domain" ] }, { @@ -496,7 +763,7 @@ "source": [ "domain = domain_factory()\n", "try:\n", - " rollout(domain=domain, solver=solver, max_steps=500, pause_between_steps=None)\n", + " rollout(domain=domain, solver=solver, max_steps=999, pause_between_steps=None, render=True)\n", "finally:\n", " domain.close()" ] }, @@ -505,9 +772,29 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "CGP seems doing well on this problem. Indeed the presence of periodic functions ($asin$, $acos$, and $atan$) in its base set of atomic functions makes it suitable for modelling this kind of pendular motion.\n", - "\n", - "***Warning***: on some cases, it happens that CGP does not actually find a solution. It may be due to a different random seeding." + "CGP seems to do well on this problem. Indeed, the presence of trigonometric functions ($asin$, $acos$, and $atan$) in its base set of atomic functions makes it well suited to modelling this kind of pendular motion." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "***Warning***: In some cases, CGP does not actually find a solution. Since randomness is involved, this cannot be completely ruled out. Running multiple episodes is sometimes enough to get a successful one; with bad luck, you may even have to train the solver again." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for i_episode in range(10):\n", + " print(f\"Episode #{i_episode}\")\n", + " domain = domain_factory()\n", + " try:\n", + " rollout(domain=domain, solver=solver, max_steps=999, pause_between_steps=None, render=False)\n", + " finally:\n", + " domain.close()" ] }, { @@ -662,8 +949,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Training solver on domain\n", - "The solver will try to find an appropriate policy to solve the maze. " + "### Training solver on domain" ] }, { @@ -704,6 +990,7 @@ " solver: Solver,\n", " max_steps: int,\n", " pause_between_steps: Optional[float] = 0.01,\n", + " render: bool = False,\n", "):\n", " \"\"\"Roll out one episode in a domain according to the policy of a trained solver.\n", "\n", @@ -713,6 +1000,8 @@ " max_steps: maximum number of steps allowed to reach the goal\n", " pause_between_steps: time (s) paused between agent movements.\n", " No pause if None.\n", + " render: if True, the rollout is rendered in a matplotlib figure as an animation;\n", + " if False, the rollout runs much faster.\n", "\n", " \"\"\"\n", " # Initialize episode\n", @@ -720,12 +1009,13 @@ " observation = domain.reset()\n", "\n", " # Initialize image\n", - " plt.ioff()\n", - " fig, ax = plt.subplots(1)\n", - " ax.axis(\"off\")\n", - " plt.ion()\n", - " img = ax.imshow(domain.render(mode=\"rgb_array\"))\n", - " display(fig)\n", + " if render:\n", + " plt.ioff()\n", + " fig, ax = plt.subplots(1)\n", + " ax.axis(\"off\")\n", + " plt.ion()\n", + " img = ax.imshow(domain.render(mode=\"rgb_array\"))\n", + " display(fig)\n", "\n", " # loop until max_steps or goal is reached\n", " for i_step in range(1, max_steps + 1):\n", @@ -739,17 +1029,19 @@ " observation = outcome.observation\n", "\n", " # update image\n", - " img.set_data(domain.render(mode=\"rgb_array\"))\n", - " fig.canvas.draw()\n", - " clear_output(wait=True)\n", - " display(fig)\n", + " if render:\n", + " img.set_data(domain.render(mode=\"rgb_array\"))\n", + " fig.canvas.draw()\n", + " clear_output(wait=True)\n", + " display(fig)\n", "\n", " # final state reached?\n", " if outcome.termination:\n", " break\n", "\n", " # close the figure to avoid jupyter duplicating the last image\n", - " plt.close(fig)\n", + " if render:\n", + " plt.close(fig)\n", "\n", " # goal reached?\n", " is_goal_reached = observation._state[0] >= 0.45\n", @@ -765,13 +1057,13 @@ "cell_type": "code", "execution_count": null, "metadata": { - "scrolled": false + "scrolled": true }, "outputs": [], "source": [ "domain = domain4width_factory()\n", "try:\n", - " rollout_iw(domain=domain, solver=solver, max_steps=500, pause_between_steps=None)\n", + " rollout_iw(domain=domain, solver=solver, max_steps=999, pause_between_steps=None, render=True)\n", "finally:\n", " domain.close()" ]
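A minimal sketch, not part of the patch above, of how these per-episode goal checks could be aggregated into a single success rate. It reuses the notebook's `domain_factory` and `solver` objects and the same scikit-decide step loop as `rollout()` (`sample_action`, `step`, `outcome.observation`, `outcome.termination`); the `success_rate` helper and its defaults are illustrative assumptions, and 0.45 is the goal position already used by the notebook's goal test.

    # Hypothetical helper (not in the notebook): replay the rollout() step loop
    # without rendering and count how often the goal position is reached.
    def success_rate(domain_factory, solver, n_episodes: int = 10, max_steps: int = 999) -> float:
        n_success = 0
        for _ in range(n_episodes):
            domain = domain_factory()
            try:
                observation = domain.reset()
                for _ in range(max_steps):
                    action = solver.sample_action(observation)  # trained policy
                    outcome = domain.step(action)
                    observation = outcome.observation
                    if outcome.termination:  # episode ended (goal or time limit)
                        break
                if observation[0] >= 0.45:  # car position past the flag, as in rollout()
                    n_success += 1
            finally:
                domain.close()  # always release the environment
        return n_success / n_episodes

    print(f"Success rate: {success_rate(domain_factory, solver):.0%}")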