Skip to content

Commit

Permalink
FEA Print Decision Trees in ASCII format (scikit-learn#9424)
Browse files Browse the repository at this point in the history
  • Loading branch information
JustGlowing authored and jnothman committed Feb 11, 2019
1 parent af842d3 commit a061ada
Show file tree
Hide file tree
Showing 6 changed files with 296 additions and 3 deletions.
1 change: 1 addition & 0 deletions doc/modules/classes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1399,6 +1399,7 @@ Low-level methods

tree.export_graphviz
tree.plot_tree
tree.export_text


.. _utils_ref:
Expand Down
23 changes: 23 additions & 0 deletions doc/modules/tree.rst
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,29 @@ render these plots inline automatically::
:align: center
:scale: 75

Alternatively, the tree can also be exported in textual format with the
function :func:`export_text`. This function doesn't require the installation
of external libraries and is more compact:

>>> from sklearn.datasets import load_iris
>>> from sklearn.tree import DecisionTreeClassifier
>>> from sklearn.tree.export import export_text
>>> iris = load_iris()
>>> X = iris['data']
>>> y = iris['target']
>>> decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
>>> decision_tree = decision_tree.fit(X, y)
>>> r = export_text(decision_tree, feature_names=iris['feature_names'])
>>> print(r)
|--- petal width (cm) <= 0.80
|   |--- class: 0
|--- petal width (cm) > 0.80
|   |--- petal width (cm) <= 1.75
|   |   |--- class: 1
|   |--- petal width (cm) > 1.75
|   |   |--- class: 2
<BLANKLINE>

.. topic:: Examples:

* :ref:`sphx_glr_auto_examples_tree_plot_iris_dtc.py`
Expand Down
4 changes: 4 additions & 0 deletions doc/whats_new/v0.21.rst
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,10 @@ Support for Python 3.4 and below has been officially dropped.
:func:`tree.plot_tree` without relying on the ``dot`` library,
removing a hard-to-install dependency. :issue:`8508` by `Andreas Müller`_.

- |Feature| Decision Trees can now be exported in a human readable
textual format using :func:`tree.export.export_text`.
:issue:`6261` by `Giuseppe Vettigli <JustGlowing>`.

- |Feature| ``get_n_leaves()`` and ``get_depth()`` have been added to
:class:`tree.BaseDecisionTree` and consequently all estimators based
on it, including :class:`tree.DecisionTreeClassifier`,
Expand Down
4 changes: 2 additions & 2 deletions sklearn/tree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from .tree import DecisionTreeRegressor
from .tree import ExtraTreeClassifier
from .tree import ExtraTreeRegressor
from .export import export_graphviz, plot_tree
from .export import export_graphviz, plot_tree, export_text

__all__ = ["DecisionTreeClassifier", "DecisionTreeRegressor",
"ExtraTreeClassifier", "ExtraTreeRegressor", "export_graphviz",
"plot_tree"]
"plot_tree", "export_text"]
177 changes: 177 additions & 0 deletions sklearn/tree/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# Satrajit Gosh <satrajit.ghosh@gmail.com>
# Trevor Stephens <trev.stephens@gmail.com>
# Li Li <aiki.nogard@gmail.com>
# Giuseppe Vettigli <vettigli@gmail.com>
# License: BSD 3 clause
import warnings
from io import StringIO
Expand All @@ -22,6 +23,7 @@
from . import _criterion
from . import _tree
from ._reingold_tilford import buchheim, Tree
from . import DecisionTreeClassifier


def _color_brew(n):
Expand Down Expand Up @@ -778,3 +780,178 @@ def export_graphviz(decision_tree, out_file=None, max_depth=None,
finally:
if own_file:
out_file.close()


def _compute_depth(tree, node):
    """
    Returns the depth of the subtree rooted in node.

    The subtree root itself counts as depth 1, so a leaf has depth 1 and
    a node with two leaf children has depth 2.

    Parameters
    ----------
    tree : Tree
        A fitted tree object exposing the parallel ``children_left`` /
        ``children_right`` index arrays, where -1 marks a missing child
        (both are -1 on a leaf).
    node : int
        Index of the subtree root within the tree arrays.
    """
    # Iterative depth-first traversal with an explicit stack: unlike a
    # naive recursive walk this cannot hit Python's recursion limit on
    # very deep trees, and it keeps a running maximum instead of
    # materializing every node's depth in a list.
    max_depth = 0
    stack = [(node, 1)]
    while stack:
        current, depth = stack.pop()
        if depth > max_depth:
            max_depth = depth
        left = tree.children_left[current]
        right = tree.children_right[current]
        if left != -1 and right != -1:
            stack.append((left, depth + 1))
            stack.append((right, depth + 1))
    return max_depth


def export_text(decision_tree, feature_names=None, max_depth=10,
                spacing=3, decimals=2, show_weights=False):
    """Build a text report showing the rules of a decision tree.

    Note that backwards compatibility may not be supported.

    Parameters
    ----------
    decision_tree : object
        The decision tree estimator to be exported.
        It can be an instance of
        DecisionTreeClassifier or DecisionTreeRegressor.

    feature_names : list, optional (default=None)
        A list of length n_features containing the feature names.
        If None generic names will be used ("feature_0", "feature_1", ...).

    max_depth : int, optional (default=10)
        Only the first max_depth levels of the tree are exported.
        Truncated branches are marked with "truncated branch of depth N".

    spacing : int, optional (default=3)
        Number of spaces between edges. The higher it is, the wider the result.

    decimals : int, optional (default=2)
        Number of decimal digits to display.

    show_weights : bool, optional (default=False)
        If true the classification weights will be exported on each leaf.
        The classification weights are the number of samples each class.

    Returns
    -------
    report : string
        Text summary of all the rules in the decision tree.

    Examples
    --------
    >>> from sklearn.datasets import load_iris
    >>> from sklearn.tree import DecisionTreeClassifier
    >>> from sklearn.tree.export import export_text
    >>> iris = load_iris()
    >>> X = iris['data']
    >>> y = iris['target']
    >>> decision_tree = DecisionTreeClassifier(random_state=0, max_depth=2)
    >>> decision_tree = decision_tree.fit(X, y)
    >>> r = export_text(decision_tree, feature_names=iris['feature_names'])
    >>> print(r)
    |--- petal width (cm) <= 0.80
    |   |--- class: 0
    |--- petal width (cm) > 0.80
    |   |--- petal width (cm) <= 1.75
    |   |   |--- class: 1
    |   |--- petal width (cm) > 1.75
    |   |   |--- class: 2
    ...
    """
    check_is_fitted(decision_tree, 'tree_')
    tree_ = decision_tree.tree_
    class_names = decision_tree.classes_
    right_child_fmt = "{} {} <= {}\n"
    left_child_fmt = "{} {} > {}\n"
    truncation_fmt = "{} {}\n"

    if max_depth < 0:
        # NOTE(review): the "bust" typo ("must") is asserted verbatim by
        # test_export_text_errors; fix the message and the test together.
        raise ValueError("max_depth bust be >= 0, given %d" % max_depth)

    if (feature_names is not None and
            len(feature_names) != tree_.n_features):
        raise ValueError("feature_names must contain "
                         "%d elements, got %d" % (tree_.n_features,
                                                  len(feature_names)))

    if spacing <= 0:
        raise ValueError("spacing must be > 0, given %d" % spacing)

    if decimals < 0:
        raise ValueError("decimals must be >= 0, given %d" % decimals)

    if isinstance(decision_tree, DecisionTreeClassifier):
        value_fmt = "{}{} weights: {}\n"
        if not show_weights:
            value_fmt = "{}{}{}\n"
    else:
        value_fmt = "{}{} value: {}\n"

    # Map each node index to the display name of its split feature
    # (tree_.feature holds one feature index per node).
    if feature_names:
        feature_names_ = [feature_names[i] for i in tree_.feature]
    else:
        feature_names_ = ["feature_{}".format(i) for i in tree_.feature]

    # Accumulate output chunks in a local list rather than on a function
    # attribute (the previous ``export_text.report``): attribute-based
    # accumulation is shared global state, so concurrent or re-entrant
    # calls would corrupt each other's output.
    report_chunks = []

    def _add_leaf(value, class_name, indent):
        # Format one leaf line: value/weight vector and, for
        # classifiers, the predicted class.
        val = ''
        is_classification = isinstance(decision_tree,
                                       DecisionTreeClassifier)
        if show_weights or not is_classification:
            val = ["{1:.{0}f}, ".format(decimals, v) for v in value]
            # Drop the trailing ", " from the last element.
            val = '[' + ''.join(val)[:-2] + ']'
        if is_classification:
            val += ' class: ' + str(class_name)
        report_chunks.append(value_fmt.format(indent, '', val))

    def print_tree_recurse(node, depth):
        # Build the "|   |   |---" prefix for this depth.
        indent = ("|" + (" " * spacing)) * depth
        indent = indent[:-spacing] + "-" * spacing

        if tree_.n_outputs == 1:
            value = tree_.value[node][0]
        else:
            value = tree_.value[node].T[0]
        class_name = np.argmax(value)

        if (tree_.n_classes[0] != 1 and
                tree_.n_outputs == 1):
            class_name = class_names[class_name]

        if depth <= max_depth + 1:
            if tree_.feature[node] != _tree.TREE_UNDEFINED:
                # Internal node: print the split condition for each side,
                # then recurse into the corresponding child.
                name = feature_names_[node]
                threshold = tree_.threshold[node]
                threshold = "{1:.{0}f}".format(decimals, threshold)
                report_chunks.append(right_child_fmt.format(indent,
                                                            name,
                                                            threshold))
                print_tree_recurse(tree_.children_left[node], depth + 1)

                report_chunks.append(left_child_fmt.format(indent,
                                                           name,
                                                           threshold))
                print_tree_recurse(tree_.children_right[node], depth + 1)
            else:  # leaf
                _add_leaf(value, class_name, indent)
        else:
            # Past the requested depth: a bare leaf (subtree depth 1) is
            # still printed; anything deeper is summarized as truncated.
            subtree_depth = _compute_depth(tree_, node)
            if subtree_depth == 1:
                _add_leaf(value, class_name, indent)
            else:
                trunc_report = 'truncated branch of depth %d' % subtree_depth
                report_chunks.append(truncation_fmt.format(indent,
                                                           trunc_report))

    print_tree_recurse(0, 1)
    return "".join(report_chunks)
90 changes: 89 additions & 1 deletion sklearn/tree/tests/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
import pytest

from re import finditer, search
from textwrap import dedent

from numpy.random import RandomState

from sklearn.base import is_classifier
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.tree import export_graphviz, plot_tree
from sklearn.tree import export_graphviz, plot_tree, export_text
from io import StringIO
from sklearn.utils.testing import (assert_in, assert_equal, assert_raises,
assert_less_equal, assert_raises_regex,
Expand Down Expand Up @@ -311,6 +312,93 @@ def test_precision():
precision + 1)


def test_export_text_errors():
    # Each invalid keyword argument must raise ValueError with the exact
    # message produced by export_text's parameter validation.
    clf = DecisionTreeClassifier(max_depth=2, random_state=0)
    clf.fit(X, y)

    invalid_cases = [
        ({'max_depth': -1}, "max_depth bust be >= 0, given -1"),
        ({'feature_names': ['a']},
         "feature_names must contain 2 elements, got 1"),
        ({'decimals': -1}, "decimals must be >= 0, given -1"),
        ({'spacing': 0}, "spacing must be > 0, given 0"),
    ]
    for kwargs, message in invalid_cases:
        assert_raise_message(ValueError, message, export_text, clf, **kwargs)


def test_export_text():
    # Exact-output tests for export_text on classifiers and regressors.
    clf = DecisionTreeClassifier(max_depth=2, random_state=0)
    clf.fit(X, y)

    # X/y (module-level fixtures) produce a single split on feature_1
    # with one pure leaf per side.
    expected_report = dedent("""
    |--- feature_1 <= 0.00
    |   |--- class: -1
    |--- feature_1 > 0.00
    |   |--- class: 1
    """).lstrip()

    assert export_text(clf) == expected_report
    # testing that leaves at level 1 are not truncated
    assert export_text(clf, max_depth=0) == expected_report
    # a max_depth beyond the actual tree depth must not change the output
    assert export_text(clf, max_depth=10) == expected_report

    # feature_names replaces the generic "feature_i" labels.
    expected_report = dedent("""
    |--- b <= 0.00
    |   |--- class: -1
    |--- b > 0.00
    |   |--- class: 1
    """).lstrip()
    assert export_text(clf, feature_names=['a', 'b']) == expected_report

    # show_weights adds the per-class sample counts on each leaf.
    expected_report = dedent("""
    |--- feature_1 <= 0.00
    |   |--- weights: [3.00, 0.00] class: -1
    |--- feature_1 > 0.00
    |   |--- weights: [0.00, 3.00] class: 1
    """).lstrip()
    assert export_text(clf, show_weights=True) == expected_report

    # spacing controls the width of the edge markers.
    expected_report = dedent("""
    |- feature_1 <= 0.00
    | |- class: -1
    |- feature_1 > 0.00
    | |- class: 1
    """).lstrip()
    assert export_text(clf, spacing=1) == expected_report

    # Deeper tree: with max_depth=0 a subtree of depth 1 (a bare leaf)
    # is still printed, while deeper branches are truncated.
    X_l = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1], [-1, 1]]
    y_l = [-1, -1, -1, 1, 1, 1, 2]
    clf = DecisionTreeClassifier(max_depth=4, random_state=0)
    clf.fit(X_l, y_l)
    expected_report = dedent("""
    |--- feature_1 <= 0.00
    |   |--- class: -1
    |--- feature_1 > 0.00
    |   |--- truncated branch of depth 2
    """).lstrip()
    assert export_text(clf, max_depth=0) == expected_report

    # Multi-output regression: leaves report the raw value vector.
    X_mo = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
    y_mo = [[-1, -1], [-1, -1], [-1, -1], [1, 1], [1, 1], [1, 1]]

    reg = DecisionTreeRegressor(max_depth=2, random_state=0)
    reg.fit(X_mo, y_mo)

    expected_report = dedent("""
    |--- feature_1 <= 0.0
    |   |--- value: [-1.0, -1.0]
    |--- feature_1 > 0.0
    |   |--- value: [1.0, 1.0]
    """).lstrip()
    assert export_text(reg, decimals=1) == expected_report
    # show_weights has no effect on regressor output
    assert export_text(reg, decimals=1, show_weights=True) == expected_report


def test_plot_tree():
# mostly smoke tests
pytest.importorskip("matplotlib.pyplot")
Expand Down

0 comments on commit a061ada

Please sign in to comment.