Exercise VII: Decision Trees and Random Forests#

Decision trees are as easy to fit with sklearn as any of the other models we’ve studied so far. As a quick example, let’s try to classify the iris dataset, which we’re already familiar with:


import warnings
import pandas as pd

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
warnings.filterwarnings("ignore")

data = load_iris(as_frame=True)

X, y = data.data, data.target
y = y.replace({index: name for index, name in enumerate(data["target_names"])})

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
decision_tree = DecisionTreeClassifier(random_state=42)
_ = decision_tree.fit(X_train, y_train)

Trees are also easy to visualize, thanks to the plot_tree() function:

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(16, 12))
_ = plot_tree(decision_tree,
              class_names=data["target_names"].tolist(),
              feature_names=data["feature_names"],
              filled=True,
              fontsize=11,
              node_ids=True,
              proportion=True,
              rounded=True,
              ax=ax)
[Figure: the fitted decision tree rendered by plot_tree(), with nodes colored by predicted class.]
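If a full figure is more than we need, sklearn can also render the same tree as plain text via the export_text() function, e.g.:

from sklearn.tree import export_text

# Text rendering of the fitted tree, one line per split/leaf
print(export_text(decision_tree, feature_names=data["feature_names"]))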

Finally, we can evaluate the model’s performance using the appropriate metrics, e.g.:

from sklearn.metrics import ConfusionMatrixDisplay

disp = ConfusionMatrixDisplay.from_estimator(decision_tree,
                                             X_test,
                                             y_test,
                                             cmap=plt.cm.Blues,
                                             normalize="true",
                                             display_labels=data["target_names"])
_ = disp.ax_.set_title("Confusion Matrix")

[Figure: row-normalized confusion matrix for the decision tree on the iris test set.]
from sklearn.metrics import classification_report

y_predicted = decision_tree.predict(X_test)
print(classification_report(y_test, y_predicted))
              precision    recall  f1-score   support

      setosa       1.00      1.00      1.00        10
  versicolor       1.00      1.00      1.00         9
   virginica       1.00      1.00      1.00        11

    accuracy                           1.00        30
   macro avg       1.00      1.00      1.00        30
weighted avg       1.00      1.00      1.00        30
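A perfect score on a 30-sample test set should be interpreted with some caution. As a quick sanity check (a supplementary sketch, not part of the original exercise output), we could cross-validate an identically configured tree over the full dataset:

from sklearn.model_selection import cross_val_score

# 5-fold cross-validation accuracy of a fresh tree with the same settings
scores = cross_val_score(DecisionTreeClassifier(random_state=42), X, y, cv=5)
print(f"5-fold CV accuracy: {scores.mean():.2f} (+/- {scores.std():.2f})")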

Predicting ASD Diagnosis#

Once again, we will try to create a model that classifies ASD diagnosis using four FreeSurfer metrics extracted over 360 brain regions. In the previous exercise we used an \(\ell_2\)-regularized logistic regression model (estimated using the LogisticRegressionCV class) to select and fit a model; this time we will use the RandomForestClassifier and GradientBoostingClassifier classes to do the same.


from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Load
tsv_url = "https://raw.githubusercontent.com/neurohackademy/nh2020-curriculum/master/tu-machine-learning-yarkoni/data/abide2.tsv"
data = pd.read_csv(tsv_url, delimiter="\t")

# Clean
IGNORED_COLUMNS = ["age_resid", "sex"]
REPLACE_DICT = {"group": {1: "ASD", 2: "Control"}}
data.drop(columns=IGNORED_COLUMNS, inplace=True)
data.replace(REPLACE_DICT, inplace=True)

# Feature matrix
X = data.filter(regex="^fs").copy()  # Select columns starting with "fs"
scaler = StandardScaler()
X.loc[:, :] = scaler.fit_transform(X.loc[:, :])

# Target vector
y = data["group"] == "ASD"

# Train/test split
TRAIN_SIZE = 900

X_train, X_test, y_train, y_test = train_test_split(X,
                                                    y,
                                                    train_size=TRAIN_SIZE,
                                                    random_state=0)
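Before fitting, it can be worth verifying the size and class balance of the resulting split (a quick illustrative check):

# Sample counts and class proportions in the training set
print(f"Training samples: {len(X_train)}, test samples: {len(X_test)}")
print(y_train.value_counts(normalize=True).round(2))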

Random Forest#

Model Creation#

from sklearn.ensemble import RandomForestClassifier

random_forest = RandomForestClassifier(random_state=0)
_ = random_forest.fit(X_train, y_train)
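Because each tree is trained on a bootstrap sample, a random forest can also estimate its generalization performance from the held-out ("out-of-bag") samples. A minimal sketch, assuming default settings otherwise:

# Refit with out-of-bag scoring enabled (illustrative)
rf_oob = RandomForestClassifier(oob_score=True, random_state=0)
_ = rf_oob.fit(X_train, y_train)
print(f"OOB accuracy estimate: {rf_oob.oob_score_:.2f}")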

Model Application#

y_predicted = random_forest.predict(X_test)

Model Evaluation#

Confusion Matrix#
disp = ConfusionMatrixDisplay.from_estimator(random_forest,
                                             X_test,
                                             y_test,
                                             cmap=plt.cm.Blues,
                                             normalize="true")
_ = disp.ax_.set_title("Confusion Matrix")

[Figure: row-normalized confusion matrix for the random forest on the ABIDE II test set.]
Classification Report#
print(classification_report(y_test, y_predicted))
              precision    recall  f1-score   support

       False       0.68      0.72      0.70        58
        True       0.62      0.57      0.59        46

    accuracy                           0.65       104
   macro avg       0.65      0.64      0.65       104
weighted avg       0.65      0.65      0.65       104

Not too bad! We’ve improved our accuracy from 0.63 to 0.65.

Feature Importance#

One of the great advantages of tree-based models is their interpretability. sklearn exposes one measure (“Gini importance”, i.e. mean decrease in impurity) as a built-in attribute of the fitted random forest estimator:

random_forest.feature_importances_
array([2.96342846e-04, 5.08910846e-05, 6.69469236e-04, ...,
       5.30759983e-04, 2.16033616e-04, 4.88856474e-04], shape=(1440,))
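To make this array readable, we can attach the feature names and sort it, e.g.:

# Top ten features by built-in (impurity-based) importance
builtin_importance = pd.Series(random_forest.feature_importances_, index=X.columns)
print(builtin_importance.nlargest(10))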

However, the documentation states that:

“impurity-based feature importances can be misleading for high cardinality features (many unique values).”

In our case, the features are continuous (and therefore of high cardinality), so we are better off using the permutation_importance() function (for more information, see Permutation Importance vs Random Forest Feature Importance).

from sklearn.inspection import permutation_importance

importance = permutation_importance(random_forest,
                                    X_test,
                                    y_test,
                                    random_state=0,
                                    n_jobs=8)
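The returned object is a Bunch holding the mean and standard deviation of the score drop across permutation repeats (importances_mean, importances_std), along with the raw per-repeat values. For instance, we can peek at the most influential features:

# Top ten features by mean permutation importance
perm_importance = pd.Series(importance.importances_mean, index=X_test.columns)
print(perm_importance.nlargest(10))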


import numpy as np

MEASUREMENT_DICT = {
    "fsArea": "Surface Area",
    "fsCT": "Cortical Thinkness",
    "fsVol": "Cortical Volume",
    "fsLGI": "Local Gyrification Index"
}
HEMISPHERE_DICT = {"L": "Left", "R": "Right"}
REGION_IDS = range(1, 181)
FEATURE_INDEX = pd.MultiIndex.from_product(
    [HEMISPHERE_DICT.values(), REGION_IDS,
     MEASUREMENT_DICT.values()],
    names=["Hemisphere", "Region ID", "Measurement"])
COLUMN_NAMES = ["Identifier", "Importance"]


def parse_importance(X: pd.DataFrame, importance: np.ndarray) -> pd.DataFrame:
    """Arrange raw importance values in a dataframe indexed by
    (Hemisphere, Region ID, Measurement)."""
    feature_info = pd.DataFrame(index=FEATURE_INDEX, columns=COLUMN_NAMES)
    for i, column_name in enumerate(X.columns):
        measurement, hemisphere, identifier, _ = column_name.split("_")
        measurement = MEASUREMENT_DICT.get(measurement)
        hemisphere = HEMISPHERE_DICT.get(hemisphere)
        region_id = i % 180 + 1
        feature_info.loc[(hemisphere, region_id,
                          measurement), :] = identifier, importance[i]
    return feature_info


importance_series = parse_importance(X, importance["importances_mean"])
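The resulting table is indexed by hemisphere, region, and measurement; a quick peek at its structure (illustrative):

print(importance_series.head())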


import nibabel as nib

from nilearn import datasets
from nilearn import plotting
from nilearn import surface
from nilearn.image import new_img_like

HCP_NIFTI_PATH = "../chapter_06/HCP-MMP1_on_MNI152_ICBM2009a_nlin.nii.gz"
HCP_IMAGE = nib.load(HCP_NIFTI_PATH)
HCP_DATA = np.round(HCP_IMAGE.get_fdata())


def plot_importance(feature_importance: pd.DataFrame,
                    measurement: str) -> None:
    """
    Plots coefficient estimation results using a "glass brain" plot.
    
    Parameters
    ----------
    coefficient_values : pd.DataFrame
        Formatted dataframe containing coefficient values indexed by
        (Hemisphere, Region ID, Measurement)
    measurement: str
        String identifier for the desired type of measurement
    """

    # Create a copy of the HCP-MMP1 atlas array
    template = HCP_DATA.copy()

    # Replace region indices in the template with their matching importance values
    for hemisphere in HEMISPHERE_DICT.values():
        # Query appropriate rows
        selection = feature_importance.xs((hemisphere, measurement),
                                          level=("Hemisphere", "Measurement"))
        # Extract an array of importance values
        values = selection["Importance"].values
        if (values > 0).any():
            # Replace region indices with values
            for region_id in REGION_IDS:
                template_id = region_id if hemisphere == "Left" else region_id + 180
                template[template == template_id] = values[region_id - 1]
        else:
            return

    # Create a `nibabel.nifti1.Nifti1Image` instance for `plot_glass_brain`
    importance_nifti = new_img_like(HCP_IMAGE, template, HCP_IMAGE.affine)

    _ = plotting.plot_glass_brain(importance_nifti,
                                  display_mode='ortho',
                                  colorbar=True,
                                  title=measurement)


for measurement in MEASUREMENT_DICT.values():
    plot_importance(importance_series, measurement)
[Figure: glass brain plots of permutation importance values for Surface Area, Cortical Thickness, Cortical Volume, and Local Gyrification Index.]

Gradient Boosting#

Model Creation#

from sklearn.ensemble import GradientBoostingClassifier

gradient_boosting = GradientBoostingClassifier(random_state=0)
_ = gradient_boosting.fit(X_train, y_train)

Model Application#

y_predicted_gb = gradient_boosting.predict(X_test)

Model Evaluation#

Confusion Matrix#
disp = ConfusionMatrixDisplay.from_estimator(gradient_boosting,
                                             X_test,
                                             y_test,
                                             cmap=plt.cm.Blues,
                                             normalize="true")
_ = disp.ax_.set_title("Confusion Matrix")

[Figure: row-normalized confusion matrix for the gradient boosting classifier on the ABIDE II test set.]
Classification Report#
print(classification_report(y_test, y_predicted_gb))
              precision    recall  f1-score   support

       False       0.71      0.67      0.69        58
        True       0.61      0.65      0.63        46

    accuracy                           0.66       104
   macro avg       0.66      0.66      0.66       104
weighted avg       0.67      0.66      0.66       104

We end up with slightly better overall accuracy and precision than the random forest achieved (0.66 vs. 0.65 accuracy).
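To make the comparison explicit, we can score both models side by side (a brief sketch):

from sklearn.metrics import accuracy_score

# Test-set accuracy for each fitted ensemble
for name, model in [("Random forest", random_forest),
                    ("Gradient boosting", gradient_boosting)]:
    print(f"{name}: {accuracy_score(y_test, model.predict(X_test)):.3f}")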

Feature Importance#
from sklearn.inspection import permutation_importance

gb_importance = permutation_importance(gradient_boosting,
                                       X_test,
                                       y_test,
                                       random_state=0,
                                       n_jobs=8)


gb_importance_series = parse_importance(X, gb_importance["importances_mean"])
for measurement in MEASUREMENT_DICT.values():
    plot_importance(gb_importance_series, measurement)
[Figure: glass brain plots of the gradient boosting model’s permutation importance values for Surface Area, Cortical Thickness, Cortical Volume, and Local Gyrification Index.]
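As an optional final check (not part of the original exercise), we might ask how much the two models agree on which features matter, e.g. by correlating their permutation importance vectors:

# Pearson correlation between the two models' mean permutation importances
agreement = np.corrcoef(importance.importances_mean,
                        gb_importance.importances_mean)[0, 1]
print(f"RF/GB importance correlation: {agreement:.2f}")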

Hyperparameter Tuning#