diff --git a/notebooks/1_maze_tuto.ipynb b/notebooks/1_maze_tuto.ipynb new file mode 100644 index 0000000000..e5a929e484 --- /dev/null +++ b/notebooks/1_maze_tuto.ipynb @@ -0,0 +1,727 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Maze tutorial\n", + "\n", + "In this tutorial, we tackle the maze problem.\n", + "We use this classical game to demonstrate how to\n", + "- easily create a new scikit-decide domain\n", + "- find solvers from the scikit-decide hub matching its characteristics\n", + "- apply a scikit-decide solver to a domain\n", + "- write your own rollout function to play a trained solver on a domain\n", + "\n", + "\n", + "Notes:\n", + "- In order to keep the focus on scikit-decide, code not directly related to the library (like maze generation and display) is kept in a [separate module](./maze_utils.py).\n", + "- A similar maze domain is already defined in the [scikit-decide hub](https://github.com/airbus/scikit-decide/blob/master/skdecide/hub/domain/maze/maze.py) but we do not use it here, for the sake of this tutorial.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from enum import Enum\n", + "from math import sqrt\n", + "from time import sleep\n", + "from typing import Any, NamedTuple, Optional, Union\n", + "\n", + "import ipywidgets as widgets\n", + "from IPython.display import clear_output, display\n", + "\n", + "# import Maze class from utility file for maze generation and display\n", + "from maze_utils import Maze\n", + "from stable_baselines3 import PPO\n", + "\n", + "from skdecide import DeterministicPlanningDomain, Solver, Space, Value\n", + "from skdecide.builders.domain import Renderable, UnrestrictedActions\n", + "from skdecide.hub.solver.astar import Astar\n", + "from skdecide.hub.solver.stable_baselines import StableBaseline\n", + "from skdecide.hub.space.gym import EnumSpace, ListSpace, MultiDiscreteSpace\n", + "from skdecide.utils import load_registered_solver, match_solvers\n", + "\n", + "# choose standard matplotlib inline backend to render plots\n", + "%matplotlib inline" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## About the maze problem\n", + "The maze problem consists in making an agent find the goal in a maze by going up, down, left, or right without going through walls. \n", + "\n", + "We show such a maze below, using the Maze class defined in the [maze module](./maze_utils.py). Here the agent starts at the top-left corner and the goal is at the bottom-right corner of the maze. The following colour convention is used:\n", + "- dark purple: walls\n", + "- yellow: empty cells\n", + "- light green: goal\n", + "- blue: current position" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# size of maze\n", + "width = 25\n", + "height = 19\n", + "# generate the maze\n", + "maze = Maze.generate_random_maze(width=width, height=height)\n", + "# starting position\n", + "entrance = 1, 1\n", + "# goal position\n", + "goal = height - 2, width - 2\n", + "# render the maze\n", + "ax, image = maze.render(current_position=entrance, goal=goal)\n", + "display(image.figure)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## MazeDomain definition\n", + "\n", + "In this section, we will wrap the Maze utility class so that it is recognized as a scikit-decide domain. Several steps are needed."
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### States and actions\n", + "We begin by defining the state space (agent positions) and the action space (agent movements)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class State(NamedTuple):\n", + "    x: int\n", + "    y: int\n", + "\n", + "\n", + "class Action(Enum):\n", + "    up = 0\n", + "    down = 1\n", + "    left = 2\n", + "    right = 3" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Domain type\n", + "Then we define the domain type from a base template (`DeterministicPlanningDomain`) with optional refinements (`UnrestrictedActions` and `Renderable`). This corresponds to the following characteristics:\n", + "- `DeterministicPlanningDomain`:\n", + "  - only one agent\n", + "  - deterministic starting state\n", + "  - handles only actions\n", + "  - actions are sequential\n", + "  - deterministic transitions\n", + "  - white-box transition model\n", + "  - goal states are defined\n", + "  - positive costs (i.e. negative rewards)\n", + "  - fully observable\n", + "- `UnrestrictedActions`: all actions are available at each step\n", + "- `Renderable`: can be displayed\n", + "\n", + "We also specify the types of states, observations, events, transition values, ... \n", + "\n", + "This is needed so that solvers know how to work properly with this domain, and it also helps IDEs and Jupyter propose intelligent code completion." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class D(DeterministicPlanningDomain, UnrestrictedActions, Renderable):\n", + "    T_state = State  # Type of states\n", + "    T_observation = State  # Type of observations\n", + "    T_event = Action  # Type of events\n", + "    T_value = float  # Type of transition values (rewards or costs)\n", + "    T_predicate = bool  # Type of logical checks\n", + "    T_info = None  # Type of additional information in environment outcome\n", + "    T_agent = Union  # Inherited from SingleAgent" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Actual domain class\n", + "We can now implement the maze domain by \n", + "- deriving from the above domain type\n", + "- implementing all the missing methods \n", + "- adding a constructor to define the maze and the start/end positions.\n", + "\n", + "We also define, to help solvers that can make use of it,\n", + "- a heuristic for search algorithms\n", + "\n", + "\n", + "*NB: To find out which methods are not yet implemented, one can either use an IDE (which can detect them automatically) or the [code generators](https://airbus.github.io/scikit-decide/guide/codegen.html) page in the online documentation, which generates the corresponding boilerplate code.*" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class MazeDomain(D):\n", + "    \"\"\"Maze scikit-decide domain\n", + "\n", + "    Attributes:\n", + "        start: the starting position\n", + "        end: the goal to reach\n", + "        maze: underlying Maze object\n", + "\n", + "    \"\"\"\n", + "\n", + "    def __init__(self, start: State, end: State, maze: Maze):\n", + "        self.start = start\n", + "        self.end = end\n", + "        self.maze = maze\n", + "        # display\n", + "        self._image = None  # image to update when rendering the maze\n", + "        self._ax = None  # subplot in which the maze is rendered\n", + "\n", + "    def _get_next_state(self, memory: D.T_state, 
action: D.T_event) -> D.T_state:\n", + "        \"\"\"Get the next state given a memory and action.\n", + "\n", + "        Move agent according to action (except if bumping into a wall).\n", + "\n", + "        \"\"\"\n", + "\n", + "        next_x, next_y = memory.x, memory.y\n", + "        if action == Action.up:\n", + "            next_x -= 1\n", + "        if action == Action.down:\n", + "            next_x += 1\n", + "        if action == Action.left:\n", + "            next_y -= 1\n", + "        if action == Action.right:\n", + "            next_y += 1\n", + "        return (\n", + "            State(next_x, next_y)\n", + "            if self.maze.is_an_empty_cell(next_x, next_y)\n", + "            else memory\n", + "        )\n", + "\n", + "    def _get_transition_value(\n", + "        self,\n", + "        memory: D.T_state,\n", + "        action: D.T_event,\n", + "        next_state: Optional[D.T_state] = None,\n", + "    ) -> Value[D.T_value]:\n", + "        \"\"\"Get the value (reward or cost) of a transition.\n", + "\n", + "        Set cost to 1 when moving (energy cost)\n", + "        and to 2 when bumping into a wall (damage cost).\n", + "\n", + "        \"\"\"\n", + "        # cost 2 if bumping into a wall (no move), else 1\n", + "        return Value(cost=1 if next_state != memory else 2)\n", + "\n", + "    def _get_initial_state_(self) -> D.T_state:\n", + "        \"\"\"Get the initial state.\n", + "\n", + "        Set the start position as initial state.\n", + "\n", + "        \"\"\"\n", + "        return self.start\n", + "\n", + "    def _get_goals_(self) -> Space[D.T_observation]:\n", + "        \"\"\"Get the domain goals space (finite or infinite set).\n", + "\n", + "        Set the end position as goal.\n", + "\n", + "        \"\"\"\n", + "        return ListSpace([self.end])\n", + "\n", + "    def _is_terminal(self, state: State) -> D.T_predicate:\n", + "        \"\"\"Indicate whether a state is terminal.\n", + "\n", + "        Stop an episode only when the goal is reached.\n", + "\n", + "        \"\"\"\n", + "        return self._is_goal(state)\n", + "\n", + "    def _get_action_space_(self) -> Space[D.T_event]:\n", + "        \"\"\"Define the action space.\"\"\"\n", + "        return EnumSpace(Action)\n", + "\n", + "    def _get_observation_space_(self) -> Space[D.T_observation]:\n", + "        \"\"\"Define the observation space.\"\"\"\n", + "        return MultiDiscreteSpace([self.maze.height, self.maze.width])\n", + "\n", + "    def _render_from(self, memory: State, **kwargs: Any) -> Any:\n", + "        \"\"\"Render the maze visually.\n", + "\n", + "        Returns:\n", + "            matplotlib figure\n", + "\n", + "        \"\"\"\n", + "        # store the matplotlib subplot and image used, to only update them afterwards\n", + "        self._ax, self._image = self.maze.render(\n", + "            current_position=memory,\n", + "            goal=self.end,\n", + "            ax=self._ax,\n", + "            image=self._image,\n", + "        )\n", + "        return self._image.figure\n", + "\n", + "    def heuristic(self, s: D.T_state) -> Value[D.T_value]:\n", + "        \"\"\"Heuristic to be used by search algorithms.\n", + "\n", + "        Here: the Euclidean distance to the goal.\n", + "\n", + "        \"\"\"\n", + "        return Value(cost=sqrt((self.end.x - s.x) ** 2 + (self.end.y - s.y) ** 2))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Domain factory" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To use scikit-decide solvers on the maze problem, we need a domain factory recreating the domain at will. \n", + "\n", + "Indeed, the method `solve_with()` used [later](#Training-solver-on-the-domain) needs such a domain factory so that parallel solvers can create identical domains on separate processes. \n", + "(Even though we do not use parallel solvers in this particular notebook.)\n", + "\n", + "Here is such a domain factory reusing the maze created in the [first section](#About-the-maze-problem). 
We render the maze again using the `render` method of the wrapping domain." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# define start and end states from the tuples defined above\n", + "start = State(*entrance)\n", + "end = State(*goal)\n", + "# domain factory\n", + "domain_factory = lambda: MazeDomain(maze=maze, start=start, end=end)\n", + "# instantiate the domain\n", + "domain = domain_factory()\n", + "# init the start position\n", + "domain.reset()\n", + "# display the corresponding maze\n", + "display(domain.render())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Solvers" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Finding suitable solvers\n", + "The library hub includes a lot of solvers. We can use the `match_solvers` function to list the available solvers that fit the characteristics of the defined domain, according to the mixin classes used to define the [domain type](#Domain-type). " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "match_solvers(domain=domain)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In the following, we restrict ourselves to two solvers:\n", + "\n", + "- `StableBaseline`, quite generic, giving access to reinforcement learning (RL) algorithms by wrapping solvers from [stable_baselines3](https://github.com/DLR-RM/stable-baselines3) (a maintained fork of OpenAI Baselines)\n", + "- `Astar` (A*), more specific, coming from path planning." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### PPO solver" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We first try a solver coming from the Reinforcement Learning community that makes use of [stable_baselines3](https://github.com/DLR-RM/stable-baselines3), giving access to a lot of RL algorithms.\n", + "\n", + "Here we choose the [Proximal Policy Optimization (PPO)](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) solver. It directly optimizes the weights of the policy network using stochastic gradient ascent. See more details in the stable-baselines3 [documentation](https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html) and the [original paper](https://arxiv.org/abs/1707.06347)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Solver instantiation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "solver = StableBaseline(\n", + "    PPO, \"MlpPolicy\", learn_config={\"total_timesteps\": 10000}, verbose=True\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Training solver on the domain\n", + "The solver will try to find an appropriate policy to solve the maze. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "MazeDomain.solve_with(solver, domain_factory)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The chosen syntax applies scikit-decide's core *autocast* mechanism to the solver, so that generic solvers can be used to solve more specific domains. For instance, a solver that normally applies to multi-agent domains can also be applied to a single-agent domain thanks to this *autocast* mechanism."
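As a quick sanity check before rolling out a full episode, one can already query the trained policy for a single action, reusing the same `reset()` and `sample_action()` methods that also appear in the rollout function defined below (a minimal sketch, not part of the original notebook cells):

```python
# quick sanity check of the trained policy on the initial observation
observation = domain.reset()
print(solver.sample_action(observation))
```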
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Rolling out the solution (found by PPO)\n", + "\n", + "We can use the trained solver to roll out an episode and see whether it actually solves the maze.\n", + "\n", + "For educational purposes, we define here our own rollout function (which will probably be needed if you want to actually use the solver in a real case). If you want to take a look at the (more complex) one already implemented in the library, see the [utils.py](https://github.com/airbus/scikit-decide/blob/master/skdecide/utils.py) module.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def rollout(\n", + "    domain: MazeDomain,\n", + "    solver: Solver,\n", + "    max_steps: int,\n", + "    pause_between_steps: Optional[float] = 0.01,\n", + "):\n", + "    \"\"\"Roll out one episode in a domain according to the policy of a trained solver.\n", + "\n", + "    Args:\n", + "        domain: the maze domain to solve\n", + "        solver: a trained solver\n", + "        max_steps: maximum number of steps allowed to reach the goal\n", + "        pause_between_steps: time (s) paused between agent movements.\n", + "            No pause if None.\n", + "\n", + "    \"\"\"\n", + "    # Initialize episode\n", + "    solver.reset()\n", + "    observation = domain.reset()\n", + "\n", + "    # Initialize image\n", + "    figure = domain.render(observation)\n", + "    display(figure)\n", + "\n", + "    # loop until max_steps or goal is reached\n", + "    for i_step in range(1, max_steps + 1):\n", + "        if pause_between_steps is not None:\n", + "            sleep(pause_between_steps)\n", + "\n", + "        # choose action according to solver\n", + "        action = solver.sample_action(observation)\n", + "        # get the corresponding outcome by applying the action to the domain\n", + "        outcome = domain.step(action)\n", + "        observation = outcome.observation\n", + "\n", + "        # update image\n", + "        figure = domain.render(observation)\n", + "        clear_output(wait=True)\n", + "        display(figure)\n", + "\n", + "        # final state reached?\n", + "        if domain.is_terminal(observation):\n", + "            break\n", + "\n", + "    # goal reached?\n", + "    is_goal_reached = domain.is_goal(observation)\n", + "    if is_goal_reached:\n", + "        print(f\"Goal reached in {i_step} steps!\")\n", + "    else:\n", + "        print(f\"Goal not reached after {i_step} steps!\")\n", + "\n", + "    return is_goal_reached, i_step" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We set a maximum number of steps to reach the goal, based on the maze size, in order to decide whether the proposed solution is working or not." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "max_steps = maze.width * maze.height\n", + "print(f\"Rolling out a solution with max_steps={max_steps}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "rollout(domain=domain, solver=solver, max_steps=max_steps, pause_between_steps=None)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As you can see, the goal is not reached at the end of the episode. Although PPO is a generic algorithm that applies to a lot of problems, it does not manage to solve this maze. This is due to the fact that the reward is sparse (you get rewarded only when you reach the goal), and it is nearly impossible for this kind of RL algorithm to stumble upon the goal just by chance without shaping the reward."
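One classical workaround for such sparse-reward problems is reward shaping. The sketch below is purely illustrative and is not run in this notebook: the class name `ShapedMazeDomain` and the 0.01 weight are arbitrary choices, and careless shaping can bias the policy that is ultimately learned.

```python
class ShapedMazeDomain(MazeDomain):
    """Hypothetical maze domain variant with a denser (shaped) cost signal."""

    def _get_transition_value(
        self,
        memory: D.T_state,
        action: D.T_event,
        next_state: Optional[D.T_state] = None,
    ) -> Value[D.T_value]:
        # original cost: 1 for a move, 2 when bumping into a wall
        base_cost = 1 if next_state != memory else 2
        # small extra cost proportional to the Manhattan distance to the goal,
        # so that getting closer is "rewarded" even before the goal is reached
        # (the 0.01 weight is an arbitrary illustrative choice)
        reached = next_state if next_state is not None else memory
        distance_to_goal = abs(self.end.x - reached.x) + abs(self.end.y - reached.y)
        return Value(cost=base_cost + 0.01 * distance_to_goal)
```

Training would then proceed exactly as before, e.g. `ShapedMazeDomain.solve_with(solver, lambda: ShapedMazeDomain(maze=maze, start=start, end=end))`.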
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Cleaning up the solver\n", + "\n", + "Some solvers need proper cleaning before being deleted." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "solver._cleanup()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that this is automatically done if you use the solver within a `with` statement. The syntax would look something like:\n", + "\n", + "```python\n", + "with solver_factory() as solver:\n", + "    MyDomain.solve_with(solver, domain_factory)\n", + "    rollout(domain=domain, solver=solver)\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### A* solver" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now use [A*](https://en.wikipedia.org/wiki/A*_search_algorithm), which is well known to be suited to this kind of problem because it exploits knowledge of the goal and a heuristic estimate of the remaining distance to it (e.g. the Euclidean or Manhattan distance).\n", + "\n", + "A* (pronounced \"A-star\") is a graph traversal and path search algorithm, often used in many fields of computer science due to its completeness, optimality, and optimal efficiency.\n", + "One major practical drawback is its $O(b^d)$ space complexity (where $b$ is the branching factor and $d$ the depth of the solution), as it stores all generated nodes in memory.\n", + "\n", + "See more details in the [original paper](https://ieeexplore.ieee.org/document/4082128): P. E. Hart, N. J. Nilsson and B. Raphael, \"A Formal Basis for the Heuristic Determination of Minimum Cost Paths,\" in IEEE Transactions on Systems Science and Cybernetics, vol. 4, no. 2, pp. 100-107, July 1968.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Solver instantiation\n", + "\n", + "We use the heuristic previously defined in the MazeDomain class." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "solver = Astar(heuristic=lambda d, s: d.heuristic(s))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Training solver on the domain\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "MazeDomain.solve_with(solver, domain_factory)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Rolling out the solution (found by A*)\n", + "\n", + "We use the same rollout function and maximum number of steps as for the PPO solver." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "rollout(domain=domain, solver=solver, max_steps=max_steps, pause_between_steps=None)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This time, the goal is reached!"
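As a side note, the Euclidean heuristic defined in `MazeDomain` is admissible but not the tightest possible here: since moves are only up/down/left/right with a cost of at least 1, the Manhattan distance to the goal is also admissible and never smaller, so it could guide the search at least as well. A minimal sketch of how it could be plugged in (illustrative, not run in this notebook):

```python
# hypothetical alternative heuristic: Manhattan distance to the goal
solver = Astar(
    heuristic=lambda domain, state: Value(
        cost=abs(domain.end.x - state.x) + abs(domain.end.y - state.y)
    )
)
```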
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The fact that A* (which was designed for path planning problems) can do better than deep RL here is due to:\n", + "- mainly, the fact that this algorithm uses more information from the domain to solve it efficiently, namely that all rewards are negative here (\"positive costs\") and that the list of next states is given exhaustively (which makes it possible to explore a structured graph instead of randomly looking for a sparse reward)\n", + "- the possible use of an admissible heuristic (distance to goal), which speeds up the search even more (while keeping the optimality guarantee)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Cleaning up the solver" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "solver._cleanup()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Conclusion" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We saw how to define a scikit-decide domain from scratch by specifying its characteristics at the finest level possible, and how to find the existing solvers matching those characteristics.\n", + "\n", + "We also managed to apply a quite classical solver from the RL community (PPO) as well as a more specific solver (A*) to the maze problem. Some important lessons:\n", + "- Even though it is for many the go-to method for decision making, PPO was not able to solve the \"simple\" maze problem;\n", + "- More precisely, PPO seems not well suited to structured domains with sparse rewards (e.g. a goal state to reach);\n", + "- Solvers that take advantage of more of the available characteristics are generally better suited, as A* demonstrated.\n", + "\n", + "That is why it is important to define the domain at the finest granularity possible and to use solvers that exploit as many of the known characteristics of the domain as possible.\n" + ] + } + ], + "metadata": { + "anaconda-cloud": {}, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.10" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": true, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": { + "height": "calc(100% - 180px)", + "left": "10px", + "top": "150px", + "width": "165px" + }, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/notebooks/maze_utils.py b/notebooks/maze_utils.py new file mode 100644 index 0000000000..8594be73db --- /dev/null +++ b/notebooks/maze_utils.py @@ -0,0 +1,157 @@ +"""Utility code for maze notebook.""" + +import random +from collections import deque +from typing import Any, Optional, Tuple + +import matplotlib.pyplot as plt +import numpy as np + +_deltas_neighbour = [ + (0, 2), + (2, 0), + (0, -2), + (-2, 0), +] + + +class Maze: + _odd_row_pattern = (0, 0) + _even_row_pattern = (0, 1) + _row_end_cell = 0 + _empty_cell = 1 + + def __init__(self, maze_array: np.ndarray): + self.maze_array = maze_array + + @property + def height(self): + return self.maze_array.shape[0] + + @property + def width(self): + return 
self.maze_array.shape[1] + + @staticmethod + def generate_random_maze(width: int, height: int) -> "Maze": + """Generate a random maze with given width and height. + + Width and height are assumed to be odd so that the maze is surrounded by a wall + and 1 pixel out of 2 is a cell, alternating with a connection or a wall. + + We use here the "recursive backtracker" algorithm, which is a randomized depth-first search algorithm. + The chosen implementation is actually an iterative one to avoid max recursion stack issues. + See for instance https://en.wikipedia.org/wiki/Maze_generation_algorithm for more details. + + """ + + # initialization of the maze grid: 1 pixel out of 2 is an empty cell, the remaining ones are walls. + semiwidth = width // 2 + semiheight = height // 2 + odd_row = list(Maze._odd_row_pattern) * semiwidth + [Maze._row_end_cell] + even_row = list(Maze._even_row_pattern) * semiwidth + [Maze._row_end_cell] + maze = [list(row) for _ in range(semiheight) for row in (odd_row, even_row)] + [ + list(odd_row) + ] + + # recursive backtracker algorithm: + # 1. start from a cell + # 2. randomly open a wall towards an unvisited neighbouring cell and go to that cell + # 3. go on until the new cell has no more unvisited neighbours + # 4. go back to a cell with unvisited neighbours and go back to step 2 + first_cell = (1, 1) + stack = deque([first_cell]) + visited = {first_cell} + while len(stack) > 0: + # pick a cell + current_cell = stack.pop() + i, j = current_cell + # find unvisited neighbours + neighbours = [(i + di, j + dj) for (di, dj) in _deltas_neighbour] + neighbours = [ + (i, j) + for (i, j) in neighbours + if (i > 0) and (i < height) and (j > 0) and (j < width) + ] + unvisited_neighbours = [cell for cell in neighbours if cell not in visited] + # remove a wall towards an unvisited cell + if len(unvisited_neighbours) > 0: + stack.append(current_cell) + next_cell = random.choice(unvisited_neighbours) + ii, jj = next_cell + maze[(i + ii) // 2][(j + jj) // 2] = Maze._empty_cell + visited.add(next_cell) + stack.append(next_cell) + + # corresponding Maze object + return Maze(maze_array=np.array(maze)) + + def get_image_data( + self, + current_position: Optional[Tuple[int, int]] = None, + goal: Optional[Tuple[int, int]] = None, + ) -> np.ndarray: + """Return a numpy array to be displayed with `matplotlib.pyplot.imshow()`. + + Args: + current_position: the current position. Will be highlighted if not None. + goal: the goal to reach position. Will be highlighted if not None. + + Returns: + image_data: array of floats encoding walls, empty cells, goal, and current position. + + """ + image_data = np.array(self.maze_array, dtype=float) + if goal is not None: + image_data[goal] = 0.7 + if current_position is not None: + image_data[current_position] = 0.3 + return image_data + + def render( + self, + current_position: Optional[Tuple[int, int]] = None, + goal: Optional[Tuple[int, int]] = None, + ax: Optional[Any] = None, + image: Optional[Any] = None, + ) -> Tuple[Any, Any]: + """Render the maze in a matplotlib figure. + + Args: + current_position: the current position. Will be highlighted if not None. + goal: the goal to reach position. Will be highlighted if not None. 
+ ax: the `matplotlib.axes._subplots.AxesSubplot` in which to render + image: the `matplotlib.image.AxesImage` to update + + Returns: + ax, image: the matplotlib subplot in which the maze is rendered + and the actual rendered matplotlib image + + """ + image_data = self.get_image_data( + current_position=current_position, + goal=goal, + ) + if ax is None: + plt.ioff() + fig, ax = plt.subplots(1) + ax.set_aspect("equal")  # set the x and y axes to the same scale + plt.xticks([])  # remove the tick marks by setting to an empty list + plt.yticks([])  # remove the tick marks by setting to an empty list + ax.invert_yaxis()  # invert the y-axis so the first row of data is at the top + plt.ion() + fig.canvas.header_visible = False + fig.canvas.footer_visible = False + fig.canvas.resizable = False + fig.set_dpi(1) + fig.set_figwidth(600) + fig.set_figheight(600) + if image is None: + image = ax.imshow(image_data) + else: + image.set_data(image_data) + image.figure.canvas.draw() + return ax, image + + def is_an_empty_cell(self, i, j): + """Return True if the cell (i, j) is an empty cell (i.e. not a wall).""" + return self.maze_array[i, j] == Maze._empty_cell
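+
+
+if __name__ == "__main__":
+    # Minimal standalone usage example (illustrative only, not used by the notebook):
+    # generate a small maze and display it, highlighting the entrance and the goal.
+    demo_maze = Maze.generate_random_maze(width=11, height=11)
+    demo_maze.render(current_position=(1, 1), goal=(9, 9))
+    plt.show()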