07 Árboles de Decisión#

Versión v.1#

Este notebook puede seguir modificándose; esta versión (v.1) corresponde al 10/07/2024, hora de Caracas.

Aprendizaje Automático [UCV]#


1. Intuición#

1.1 Funciones para graficar#

import matplotlib.pyplot as plt
import pandas as pd
from sklearn.datasets import load_iris
import warnings
warnings.filterwarnings('ignore')
---------------------------------------------------------------------------
ModuleNotFoundError                       Traceback (most recent call last)
Cell In[1], line 1
----> 1 import matplotlib.pyplot as plt
      2 import pandas as pd
      3 from sklearn.datasets import load_iris

ModuleNotFoundError: No module named 'matplotlib'
def plot_data_multiclass(X, y, idx=(0, 1), xi=None, yi=None):
  """Scatter two features of ``X`` coloured by class, optionally marking points.

  Parameters
  ----------
  X : numpy.ndarray
      2-D feature matrix; columns ``idx[0]`` and ``idx[1]`` are plotted.
  y : array-like
      Class label of each row of ``X``, used as the marker colour.
  idx : tuple of 2 int, default (0, 1)
      Column indices of ``X`` for the horizontal and vertical axes.
  xi, yi : sequence of float, optional
      Coordinates of extra points to highlight with a black ring. Both must
      be given for anything to be highlighted. When exactly two points are
      given they are also joined by a black segment.

  Returns
  -------
  int or None
      ``1`` when no extra points are drawn (kept so existing notebook cells
      display the same output), ``None`` otherwise.
  """
  plt.figure(figsize=(10, 10))
  plt.scatter(X[:, idx[0]], X[:, idx[1]], c=y, s=30, cmap=plt.cm.Paired)

  if xi is None or yi is None:
    plt.show()
    return 1

  # One colour value per highlighted point (3, 4, 5, ...), so any number of
  # points can be highlighted; one or two points get the same values as before.
  plt.scatter(xi, yi, c=list(range(3, 3 + len(xi))), s=30)

  # Hollow black ring around each highlighted point.
  plt.scatter(
      xi,
      yi,
      s=100,
      linewidth=1,
      facecolors="none",
      edgecolors="k",
  )
  if len(xi) == 2:
    plt.plot(xi, yi, 'k-')
  plt.show()

1.2 Separando con hiperplanos simples#

Supongamos que solo vamos a separar de la forma \((j, t_i)\) con:

  1. \(j\) el índice de alguna característica.

  2. \(t_i\) un umbral sobre esa característica, que define el intervalo \(x_j \leq t_i\).

X, y = load_iris(return_X_y=True)
type(X)
numpy.ndarray
X
y
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

Algunas veces este criterio resulta poco útil:

plot_data_multiclass(X, y)
../../_images/585411bc45316b342456df5bfe60567dbcf87e8db6c8e14678bfd00d2b7174f7.png
1

Sin embargo, en otros casos sí resulta útil:

plot_data_multiclass(X, y, (0, 2))
../../_images/9e718b7641daad5cd00d2027ba9e32f961f30180bb6f895406ac7990d5997fa8.png
1

También podemos cambiar el par de dimensiones y podría funcionar igualmente:

plot_data_multiclass(X, y, (1, 2))
../../_images/5e098159327f8094e1970e34ad317060317a4992314a69fb6564149aca5702ec.png
1

1.3 3D graph#

Veamos cómo se podría entender en tres dimensiones.

1.3.1 Usando plotly#

import plotly.express as px
df = px.data.iris()
df
sepal_length sepal_width petal_length petal_width species species_id
0 5.1 3.5 1.4 0.2 setosa 1
1 4.9 3.0 1.4 0.2 setosa 1
2 4.7 3.2 1.3 0.2 setosa 1
3 4.6 3.1 1.5 0.2 setosa 1
4 5.0 3.6 1.4 0.2 setosa 1
... ... ... ... ... ... ...
145 6.7 3.0 5.2 2.3 virginica 3
146 6.3 2.5 5.0 1.9 virginica 3
147 6.5 3.0 5.2 2.0 virginica 3
148 6.2 3.4 5.4 2.3 virginica 3
149 5.9 3.0 5.1 1.8 virginica 3

150 rows × 6 columns

# 2-D scatter of the two sepal measurements: colour encodes the species and
# marker size encodes the petal length.
fig = px.scatter(
    df, x='sepal_length', y='sepal_width',
    color='species',
    size = 'petal_length'
)

fig.show()
# Same data in 3-D, adding petal width as the z axis (the import and data
# reload repeat the previous cell so this cell can run on its own).
import plotly.express as px
df = px.data.iris()

fig = px.scatter_3d(
    df, x='sepal_length', y='sepal_width', z='petal_width',
    color='species',
    size = 'petal_length',

)

fig.show()

1.3.2 Usando scipy y plotly#

import scipy.io
import plotly.graph_objs as go
import numpy as np
iris = load_iris()
# Put features and class ids side by side. np.c_ promotes everything to float,
# so the class id column is cast back to int right afterwards.
df = pd.DataFrame(
    np.c_[iris['data'], iris['target']],
    columns=iris['feature_names'] + ['target'],
)
df['target'] = df['target'].astype(int)
# Readable class name for every row (e.g. 0 -> 'setosa').
df['species'] = df['target'].map(dict(enumerate(iris['target_names'])))
def _class_colors(classes):
  """Return a flat array of marker colour values for ``classes``.

  Plotly's ``marker.color`` accepts numbers (mapped through a colorscale) or
  CSS colour strings, but not arbitrary labels such as ``'setosa'``. Numeric
  labels are passed through unchanged; any other labels are replaced by
  integer codes in order of first appearance.
  """
  values = np.asarray(classes).reshape(-1)
  if pd.api.types.is_numeric_dtype(values):
    return values
  codes, _ = pd.factorize(values)
  return codes


def md_graph(X, y, targets):
  """Draw an interactive 3-D scatter plot of three columns of ``X``.

  Parameters
  ----------
  X : pandas.DataFrame
      Feature table; must contain the three columns named in ``targets``.
  y : array-like (e.g. a one-column DataFrame or a Series)
      Class label of each row of ``X``; only used to colour the markers.
      Numeric labels and text labels (e.g. species names) are both accepted.
  targets : sequence of 3 str
      Column names of ``X`` plotted on the x, y and z axes, in that order.
  """
  # Compute colours from the labels before anything else, without reusing
  # the name ``y`` for a coordinate.
  colors = _class_colors(y)

  fig = go.Figure(
      data=[
          go.Scatter3d(
              x=X[targets[0]],
              y=X[targets[1]],
              z=X[targets[2]],
              # Scatter3d defaults to 'lines+markers', which would join the
              # samples with lines in row order; only points are wanted here.
              mode="markers",
              marker=dict(
                  size=6,
                  color=colors,
                  opacity=0.8
              )
          )
      ]
  )

  fig.update_layout(
      scene=dict(
          xaxis_title=targets[0],
          yaxis_title=targets[1],
          zaxis_title=targets[2]),
      scene_camera=dict(
          up=dict(x=0, y=0, z=10),
          center=dict(x=0, y=0, z=0),
          eye=dict(x=1.5, y=3, z=0)
      )
  )
  fig.show()
df
sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) target species
0 5.1 3.5 1.4 0.2 0 setosa
1 4.9 3.0 1.4 0.2 0 setosa
2 4.7 3.2 1.3 0.2 0 setosa
3 4.6 3.1 1.5 0.2 0 setosa
4 5.0 3.6 1.4 0.2 0 setosa
... ... ... ... ... ... ...
145 6.7 3.0 5.2 2.3 2 virginica
146 6.3 2.5 5.0 1.9 2 virginica
147 6.5 3.0 5.2 2.0 2 virginica
148 6.2 3.4 5.4 2.3 2 virginica
149 5.9 3.0 5.1 1.8 2 virginica

150 rows × 6 columns

X = df.loc[:, ["sepal length (cm)"]]
X
sepal length (cm)
0 5.1
1 4.9
2 4.7
3 4.6
4 5.0
... ...
145 6.7
146 6.3
147 6.5
148 6.2
149 5.9

150 rows × 1 columns

X_012, y_012 = df.iloc[:, [0, 1, 2]], df.loc[:, ["species"]]
md_graph(X=X_012, y=y_012, targets=["sepal length (cm)", "sepal width (cm)", "petal length (cm)"] )
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-168-4c4f56603923> in <cell line: 1>()
----> 1 md_graph(X=X_012, y=y_012, targets=["sepal length (cm)", "sepal width (cm)", "petal length (cm)"] )

<ipython-input-160-2f942e4ec992> in md_graph(X, y, targets)
      8   fig = go.Figure(
      9       data=[
---> 10           go.Scatter3d(
     11               x=x,
     12               y=y,

/usr/local/lib/python3.10/dist-packages/plotly/graph_objs/_scatter3d.py in __init__(self, arg, connectgaps, customdata, customdatasrc, error_x, error_y, error_z, hoverinfo, hoverinfosrc, hoverlabel, hovertemplate, hovertemplatesrc, hovertext, hovertextsrc, ids, idssrc, legend, legendgroup, legendgrouptitle, legendrank, legendwidth, line, marker, meta, metasrc, mode, name, opacity, projection, scene, showlegend, stream, surfaceaxis, surfacecolor, text, textfont, textposition, textpositionsrc, textsrc, texttemplate, texttemplatesrc, uid, uirevision, visible, x, xcalendar, xhoverformat, xsrc, y, ycalendar, yhoverformat, ysrc, z, zcalendar, zhoverformat, zsrc, **kwargs)
   2668         _v = marker if marker is not None else _v
   2669         if _v is not None:
-> 2670             self["marker"] = _v
   2671         _v = arg.pop("meta", None)
   2672         _v = meta if meta is not None else _v

/usr/local/lib/python3.10/dist-packages/plotly/basedatatypes.py in __setitem__(self, prop, value)
   4863                 # ### Handle compound property ###
   4864                 if isinstance(validator, CompoundValidator):
-> 4865                     self._set_compound_prop(prop, value)
   4866 
   4867                 # ### Handle compound array property ###

/usr/local/lib/python3.10/dist-packages/plotly/basedatatypes.py in _set_compound_prop(self, prop, val)
   5274         # ------------
   5275         validator = self._get_validator(prop)
-> 5276         val = validator.validate_coerce(val, skip_invalid=self._skip_invalid)
   5277 
   5278         # Save deep copies of current and new states

/usr/local/lib/python3.10/dist-packages/_plotly_utils/basevalidators.py in validate_coerce(self, v, skip_invalid, _validate)
   2473 
   2474         elif isinstance(v, dict):
-> 2475             v = self.data_class(v, skip_invalid=skip_invalid, _validate=_validate)
   2476 
   2477         elif isinstance(v, self.data_class):

/usr/local/lib/python3.10/dist-packages/plotly/graph_objs/scatter3d/_marker.py in __init__(self, arg, autocolorscale, cauto, cmax, cmid, cmin, color, coloraxis, colorbar, colorscale, colorsrc, line, opacity, reversescale, showscale, size, sizemin, sizemode, sizeref, sizesrc, symbol, symbolsrc, **kwargs)
   1265         _v = color if color is not None else _v
   1266         if _v is not None:
-> 1267             self["color"] = _v
   1268         _v = arg.pop("coloraxis", None)
   1269         _v = coloraxis if coloraxis is not None else _v

/usr/local/lib/python3.10/dist-packages/plotly/basedatatypes.py in __setitem__(self, prop, value)
   4871                 # ### Handle simple property ###
   4872                 else:
-> 4873                     self._set_prop(prop, value)
   4874             else:
   4875                 # Make sure properties dict is initialized

/usr/local/lib/python3.10/dist-packages/plotly/basedatatypes.py in _set_prop(self, prop, val)
   5215                 return
   5216             else:
-> 5217                 raise err
   5218 
   5219         # val is None

/usr/local/lib/python3.10/dist-packages/plotly/basedatatypes.py in _set_prop(self, prop, val)
   5210 
   5211         try:
-> 5212             val = validator.validate_coerce(val)
   5213         except ValueError as err:
   5214             if self._skip_invalid:

/usr/local/lib/python3.10/dist-packages/_plotly_utils/basevalidators.py in validate_coerce(self, v, should_raise)
   1352 
   1353                 if invalid_els and should_raise:
-> 1354                     self.raise_invalid_elements(invalid_els)
   1355 
   1356                 # ### Check that elements have valid colors types ###

/usr/local/lib/python3.10/dist-packages/_plotly_utils/basevalidators.py in raise_invalid_elements(self, invalid_els)
    301     def raise_invalid_elements(self, invalid_els):
    302         if invalid_els:
--> 303             raise ValueError(
    304                 """
    305     Invalid element(s) received for the '{name}' property of {pname}

ValueError: 
    Invalid element(s) received for the 'color' property of scatter3d.marker
        Invalid elements include: ['setosa', 'setosa', 'setosa', 'setosa', 'setosa', 'setosa', 'setosa', 'setosa', 'setosa', 'setosa']

    The 'color' property is a color and may be specified as:
      - A hex string (e.g. '#ff0000')
      - An rgb/rgba string (e.g. 'rgb(255,0,0)')
      - An hsl/hsla string (e.g. 'hsl(0,100%,50%)')
      - An hsv/hsva string (e.g. 'hsv(0,100%,100%)')
      - A named CSS color:
            aliceblue, antiquewhite, aqua, aquamarine, azure,
            beige, bisque, black, blanchedalmond, blue,
            blueviolet, brown, burlywood, cadetblue,
            chartreuse, chocolate, coral, cornflowerblue,
            cornsilk, crimson, cyan, darkblue, darkcyan,
            darkgoldenrod, darkgray, darkgrey, darkgreen,
            darkkhaki, darkmagenta, darkolivegreen, darkorange,
            darkorchid, darkred, darksalmon, darkseagreen,
            darkslateblue, darkslategray, darkslategrey,
            darkturquoise, darkviolet, deeppink, deepskyblue,
            dimgray, dimgrey, dodgerblue, firebrick,
            floralwhite, forestgreen, fuchsia, gainsboro,
            ghostwhite, gold, goldenrod, gray, grey, green,
            greenyellow, honeydew, hotpink, indianred, indigo,
            ivory, khaki, lavender, lavenderblush, lawngreen,
            lemonchiffon, lightblue, lightcoral, lightcyan,
            lightgoldenrodyellow, lightgray, lightgrey,
            lightgreen, lightpink, lightsalmon, lightseagreen,
            lightskyblue, lightslategray, lightslategrey,
            lightsteelblue, lightyellow, lime, limegreen,
            linen, magenta, maroon, mediumaquamarine,
            mediumblue, mediumorchid, mediumpurple,
            mediumseagreen, mediumslateblue, mediumspringgreen,
            mediumturquoise, mediumvioletred, midnightblue,
            mintcream, mistyrose, moccasin, navajowhite, navy,
            oldlace, olive, olivedrab, orange, orangered,
            orchid, palegoldenrod, palegreen, paleturquoise,
            palevioletred, papayawhip, peachpuff, peru, pink,
            plum, powderblue, purple, red, rosybrown,
            royalblue, rebeccapurple, saddlebrown, salmon,
            sandybrown, seagreen, seashell, sienna, silver,
            skyblue, slateblue, slategray, slategrey, snow,
            springgreen, steelblue, tan, teal, thistle, tomato,
            turquoise, violet, wheat, white, whitesmoke,
            yellow, yellowgreen
      - A number that will be interpreted as a color
        according to scatter3d.marker.colorscale
      - A list or array of any of the above

1.3.2.a Solved#

X_012, y_012 = df.iloc[:, [0, 1, 2]], df.loc[:, ["target"]]
X_012["sepal length (cm)"]
0      5.1
1      4.9
2      4.7
3      4.6
4      5.0
      ... 
145    6.7
146    6.3
147    6.5
148    6.2
149    5.9
Name: sepal length (cm), Length: 150, dtype: float64
md_graph(X=X_012, y=y_012, targets=["sepal length (cm)", "sepal width (cm)", "petal length (cm)"] )

2. Formalizando matemáticamente#

Vayamos a la definición de árboles de clasificación en scikit-learn.

Dados los vectores de entrenamiento \(x_i \in \mathbb{R}^p\), \(i = 1, \ldots, n\), y un vector de etiquetas \(y \in \mathbb{R}^n\), un árbol de decisión divide recursivamente el espacio de características de modo que las muestras con las mismas etiquetas o con valores objetivo similares queden agrupadas.

2.1 Los candidatos#

Sean los datos en el nodo \(i\) representados por \(Q_i\) con \(n_i\) muestras.

Para cada división candidata \(\theta=\left(j, t_i\right)\), que consta de una característica \(j\) y un umbral \(t_i\), el nodo se divide en

  1. \(Q_i^{\text{izq}}(\theta)\) y

  2. \(Q_i^{\text{der}}(\theta)\)

Donde

\[\begin{split} \begin{array}{r} Q_i^{\text{izq}}(\theta)=\left\{(x, y) \mid x_j \leq t_i\right\} \\ Q_i^{\text{der}}(\theta)=Q_i \setminus Q_i^{\text{izq}}(\theta) \end{array} \end{split}\]

Canvas

canvas.png

2.1 Midiendo la calidad de un candidato#

El problema es que hay infinitos cortes \(\theta\) posibles, por lo que necesitamos un método para medir qué tan bueno es cada uno dentro de un conjunto de opciones y escoger el mejor.

La calidad de una división candidata del nodo \(i\) se calcula utilizando una función de impureza o función de pérdida \(H()\), cuya elección depende de la tarea que se resuelve (clasificación o regresión).

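Fíjense en que \(H\) se deja sin especificar: su forma concreta depende del criterio que se elija. Aquí \(n_i^{\text{izq}}\) y \(n_i^{\text{der}}\) son el número de muestras de cada subconjunto.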

\[ G\left(Q_i, \theta\right)=\frac{n_i^{\text {izq}}}{n_i} H\left(Q_i^{\text {izq}}(\theta)\right)+\frac{n_i^{\text {der}}}{n_i} H\left(Q_i^{\text {der}}(\theta)\right) \]

Se selecciona el \(\theta\) que minimiza la impureza: \[ \theta^*=\operatorname{argmin}_\theta G\left(Q_i, \theta\right) \]

Fíjense en que \(Q_i\) permanece fijo, porque lo que buscamos es el mejor \(\theta\) para ese nodo \(i\).

Después se repite el proceso de manera recursiva sobre los subconjuntos \(Q_i^{\text{izq}}\left(\theta^*\right)\) y \(Q_i^{\text{der}}\left(\theta^*\right)\) hasta que se cumpla algún criterio de parada.

2.2 Criterios de parada#

El proceso se detiene cuando se alcanza la profundidad máxima permitida, cuando \(n_i < \text{min\_samples}\) o cuando \(n_i = 1\).

2.3 Criterios para clasificación#

Si el objetivo es hacer clasificación sobre las clases \(n_c \in \{0, 1, \ldots, N-1\}\), para el nodo \(i\) sea \[ p_{i n_c}=\frac{1}{n_i} \sum_{y \in Q_i} I(y=n_c) \] la proporción de la clase \(n_c\) en el nodo \(i\). Si \(i\) es un nodo terminal, `predict_proba` para esa región es \(p_{i n_c}\). Criterios (métricas de calidad) comunes son:

Índice de Gini: \[ H\left(Q_i\right)=\sum_{n_c} p_{i n_c}\left(1-p_{i n_c}\right) \]

Log loss o entropía: \[ H\left(Q_i\right)=-\sum_{n_c} p_{i n_c} \log \left(p_{i n_c}\right) \]

Canvas

canvas.png

2.4 Limitaciones#

2.4.1 Pros#

  1. Son simples de entender y de interpretar. ¡El modelo se puede visualizar!

  2. Requiere poco preprocesamiento de datos y algunos algoritmos admiten valores faltantes.

  3. El costo de predicción es logarítmico en el número de muestras usadas para entrenar el árbol.

  4. Generalizable a muchas clases.

2.4.2 Cons#

  1. Son propensos al sobreajuste (overfitting), sobre todo cuando el árbol es muy profundo.

  2. Los modelos pueden ser demasiado complejos y no generalizan bien.

  3. Son inestables: si reentrenamos el modelo con datos nuevos, el árbol puede cambiar radicalmente.

  4. La mayoría de las implementaciones son heurísticas (voraces), porque encontrar el árbol óptimo es un problema NP-completo.

2.3 Código#

from sklearn.datasets import load_iris
from sklearn import tree
# Fit a decision tree with default parameters on all four iris features and
# draw the fitted tree with matplotlib.
iris = load_iris()
X, y = iris.data, iris.target
clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, y)
tree.plot_tree(clf)
[Text(0.5, 0.9166666666666666, 'x[2] <= 2.45\ngini = 0.667\nsamples = 150\nvalue = [50, 50, 50]'),
 Text(0.4230769230769231, 0.75, 'gini = 0.0\nsamples = 50\nvalue = [50, 0, 0]'),
 Text(0.5769230769230769, 0.75, 'x[3] <= 1.75\ngini = 0.5\nsamples = 100\nvalue = [0, 50, 50]'),
 Text(0.3076923076923077, 0.5833333333333334, 'x[2] <= 4.95\ngini = 0.168\nsamples = 54\nvalue = [0, 49, 5]'),
 Text(0.15384615384615385, 0.4166666666666667, 'x[3] <= 1.65\ngini = 0.041\nsamples = 48\nvalue = [0, 47, 1]'),
 Text(0.07692307692307693, 0.25, 'gini = 0.0\nsamples = 47\nvalue = [0, 47, 0]'),
 Text(0.23076923076923078, 0.25, 'gini = 0.0\nsamples = 1\nvalue = [0, 0, 1]'),
 Text(0.46153846153846156, 0.4166666666666667, 'x[3] <= 1.55\ngini = 0.444\nsamples = 6\nvalue = [0, 2, 4]'),
 Text(0.38461538461538464, 0.25, 'gini = 0.0\nsamples = 3\nvalue = [0, 0, 3]'),
 Text(0.5384615384615384, 0.25, 'x[2] <= 5.45\ngini = 0.444\nsamples = 3\nvalue = [0, 2, 1]'),
 Text(0.46153846153846156, 0.08333333333333333, 'gini = 0.0\nsamples = 2\nvalue = [0, 2, 0]'),
 Text(0.6153846153846154, 0.08333333333333333, 'gini = 0.0\nsamples = 1\nvalue = [0, 0, 1]'),
 Text(0.8461538461538461, 0.5833333333333334, 'x[2] <= 4.85\ngini = 0.043\nsamples = 46\nvalue = [0, 1, 45]'),
 Text(0.7692307692307693, 0.4166666666666667, 'x[1] <= 3.1\ngini = 0.444\nsamples = 3\nvalue = [0, 1, 2]'),
 Text(0.6923076923076923, 0.25, 'gini = 0.0\nsamples = 2\nvalue = [0, 0, 2]'),
 Text(0.8461538461538461, 0.25, 'gini = 0.0\nsamples = 1\nvalue = [0, 1, 0]'),
 Text(0.9230769230769231, 0.4166666666666667, 'gini = 0.0\nsamples = 43\nvalue = [0, 0, 43]')]
../../_images/855972308f1832cff9158227db70032aba33ac17fce6f1f64e751d6d17f41834.png

2.3.2 Otras opciones para graficar#

import plotly.express as px
df = px.data.iris()

fig = px.scatter_3d(
    df, x='sepal_length', y='sepal_width', z='petal_width',
    color='species',
    size = 'petal_length'
)

fig.show()
import graphviz
dot_data = tree.export_graphviz(clf, out_file=None)
graph = graphviz.Source(dot_data)
graph
../../_images/07ccb058f413ef7de671cb3983b9771a6d7ef6b0b6bba7330ce20e7c337521e3.svg

2.3.4 Intentemos nosotros#

Probabilidades#

Si el objetivo es hacer clasificación sobre las clases \(n_c \in \{0, 1, \ldots, N-1\}\), para el nodo \(i\) tenemos \[ p_{i n_c}=\frac{1}{n_i} \sum_{y \in Q_i} I(y=n_c) \]

Gini#

\[ H\left(Q_i\right)=\displaystyle \sum_{n_c} p_{i n_c}\left(1-p_{i n_c}\right) \]
X, y = load_iris(return_X_y=True)
def p_i_nc(y, nc):
    """Return the proportion of labels in ``y`` equal to ``nc``.

    Implements p_{i n_c} = (1 / n_i) * sum_{y in Q_i} I(y = n_c), where ``y``
    holds the labels of the samples in node Q_i. ``y`` may be a list or an
    array. An empty node returns 0.0 instead of raising ZeroDivisionError.
    """
    # asarray so that a plain list is compared element-wise, not as a whole.
    y = np.asarray(y)
    if len(y) == 0:
        return 0.0
    return 1 / len(y) * np.sum(y == nc)


def gini(y, N):
    """Return the Gini impurity H(Q_i) = sum_{n_c in N} p (1 - p) of ``y``.

    ``N`` is the collection of class labels to sum over (e.g.
    ``np.unique(y)``); each proportion p is computed once with ``p_i_nc``.
    """
    proportions = [p_i_nc(y, nc) for nc in N]
    # Builtin sum over a generator (np.sum on a generator is deprecated).
    return sum(p * (1 - p) for p in proportions)

image.png

2.3.4.a Primera iteración#

\(Q_i = \{1, \ldots, 150\}\) (todas las muestras); el primer corte que consideraremos es \(\theta=(3, 0.8)\). El índice de Gini del nodo raíz, antes de dividir, es:

gini(y, np.unique(y))
0.6666666666666667

2.3.4.b Segunda iteración#

dot_data = tree.export_graphviz(clf, out_file=None,
                      feature_names=iris.feature_names,
                     class_names=iris.target_names,
                     filled=True, rounded=True,
                     special_characters=True)
graph = graphviz.Source(dot_data)
graph
../../_images/a65695f2ef39480b7a209829aceb98bd2ea36eb3e47c72bf17e4f89d0d178864.svg
gini(y[X[:, 3] <= 0.8], np.unique(y[X[:, 3] <= 0.8]))
0.0
gini(y[X[:, 3] > 0.8], np.unique(y[X[:, 3] > 0.8]))
0.5

2.3.4.c Tercera iteración (izquierda)#

image.png

gini(y[(X[:, 3] > 0.8)&(X[:, 3] <= 1.75)], np.unique(y[(X[:, 3] > 0.8)&(X[:, 3] <= 1.75)]))
0.04079861111111115

image.png

gini(y[(X[:, 3] > 0.8)&(X[:, 3] <= 1.75)&(X[:, 2] <= 4.95)], np.unique(y[(X[:, 3] > 0.8)&(X[:, 3] <= 1.75)&(X[:, 2] <= 4.95)]))

2.5 Región de decisión#

2.5.1 Código#

Ejemplo tomado de «Plot the decision surface of decision trees trained on the iris dataset» (galería de ejemplos de scikit-learn).

 import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets import load_iris
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.tree import DecisionTreeClassifier

# Parameters
n_classes = 3
plot_colors = "ryb"
plot_step = 0.02




for pairidx, pair in enumerate([[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]):
    # We only take the two corresponding features
    X = iris.data[:, pair]
    y = iris.target

    # Train
    clf = DecisionTreeClassifier().fit(X, y)

    # Plot the decision boundary
    ax = plt.subplot(2, 3, pairidx + 1)
    plt.tight_layout(h_pad=0.5, w_pad=0.5, pad=2.5)
    DecisionBoundaryDisplay.from_estimator(
        clf,
        X,
        cmap=plt.cm.RdYlBu,
        response_method="predict",
        ax=ax,
        xlabel=iris.feature_names[pair[0]],
        ylabel=iris.feature_names[pair[1]],
    )

    # Plot the training points
    for i, color in zip(range(n_classes), plot_colors):
        idx = np.where(y == i)
        plt.scatter(
            X[idx, 0],
            X[idx, 1],
            c=color,
            label=iris.target_names[i],
            cmap=plt.cm.RdYlBu,
            edgecolor="black",
            s=15,
        )

plt.suptitle("Decision surface of decision trees trained on pairs of features")
plt.legend(loc="lower right", borderpad=0, handletextpad=0)
_ = plt.axis("tight")
../../_images/8cd40e5a4ae7693cb476afc2033ab624598c82dd2197d654604ea1fe44425635.png