
Tree-Based Models & Ensembles

Master decision trees, random forests, and gradient boosting for tabular data

~55 min

Tree-based models are the most successful family of algorithms for tabular data. They require minimal preprocessing, handle mixed feature types, and when ensembled, consistently win Kaggle competitions.

Decision Trees

A decision tree makes predictions by learning a series of if/else rules from the data. At each node, it asks a question about one feature and splits the data into two groups.

How Splitting Works

The tree algorithm tries every possible feature and threshold, choosing the split that best separates the target classes (classification) or reduces variance (regression).

Splitting Criteria for Classification:

  • Gini Impurity: Measures how often a randomly chosen sample would be misclassified. A pure node (all one class) has Gini = 0.
  • Entropy / Information Gain: Measures the reduction in uncertainty. Also 0 for a pure node.

Splitting Criteria for Regression:

  • MSE (Mean Squared Error): Split to minimize variance in each child node.

In practice, Gini and entropy produce very similar trees. Gini is faster to compute and is the sklearn default.
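As a quick illustration (hand-rolled helper functions, not part of sklearn's API), both classification criteria can be computed directly from the class proportions in a node:

```python
import numpy as np

def gini(labels):
    # Gini impurity: 1 - sum(p_k^2) over class proportions p_k
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def entropy(labels):
    # Entropy: -sum(p_k * log2(p_k)); pure nodes give 0
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

pure = np.array([0, 0, 0, 0])
mixed = np.array([0, 0, 1, 1])
print(gini(pure))      # 0.0 -- pure node
print(gini(mixed))     # 0.5 -- worst case for two classes
print(entropy(mixed))  # 1.0 bit of uncertainty
```

A split is scored by the weighted average impurity of the two children; the algorithm picks the split that lowers it the most.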

```python
from sklearn.tree import DecisionTreeClassifier, export_text
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
feature_names = load_iris().feature_names
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Train a decision tree with pruning
tree = DecisionTreeClassifier(
    max_depth=3,
    min_samples_leaf=5,
    random_state=42
)
tree.fit(X_train, y_train)

# Print the tree as text
print(export_text(tree, feature_names=feature_names))
print(f"\nAccuracy: {tree.score(X_test, y_test):.4f}")
print(f"Tree depth: {tree.get_depth()}")
print(f"Number of leaves: {tree.get_n_leaves()}")
```

Pruning: Controlling Tree Complexity

Without constraints, a decision tree will keep splitting until every leaf contains a single sample -- perfect training accuracy but terrible generalization. Pruning limits the tree's growth:

| Parameter | What It Does |
| --- | --- |
| max_depth | Maximum depth of the tree |
| min_samples_split | Minimum samples required to split a node |
| min_samples_leaf | Minimum samples in each leaf node |
| max_features | Number of features to consider at each split |
| ccp_alpha | Cost-complexity pruning threshold |

Start with max_depth=5 and min_samples_leaf=5, then tune from there.
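ccp_alpha is less intuitive than the depth-based parameters, but sklearn can compute the candidate alpha values for you. A minimal sketch on the iris data (the dataset choice is illustrative):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Compute the effective alphas at which subtrees would be pruned away
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(
    X_train, y_train
)

# Fit one tree per alpha: larger alpha -> smaller tree
for alpha in path.ccp_alphas[:-1]:  # the last alpha collapses the tree to its root
    t = DecisionTreeClassifier(ccp_alpha=alpha, random_state=42)
    t.fit(X_train, y_train)
    print(f"alpha={alpha:.4f}  leaves={t.get_n_leaves()}  "
          f"test acc={t.score(X_test, y_test):.3f}")
```

Picking the alpha with the best cross-validated score gives a principled alternative to hand-tuning max_depth.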

Single Decision Trees Overfit Easily

An unpruned decision tree will memorize the training data (100% train accuracy) but perform poorly on new data. Always use pruning parameters. Better yet, use an ensemble method like Random Forest, which combines many trees to reduce overfitting.

Random Forest: Bagging with Trees

Random Forest builds many decision trees and averages their predictions. Two sources of randomness make each tree different:

1. Bootstrap sampling (bagging): Each tree is trained on a random sample (with replacement) of the training data
2. Random feature subsets: At each split, only a random subset of features is considered

This diversity among trees reduces overfitting and makes the ensemble more robust than any individual tree.
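A quick way to see why bootstrap samples differ from each other: on average only about 63.2% of the rows (1 - 1/e) land in any one bootstrap sample, and the rest are "out-of-bag" and can be used to validate that tree. A small simulation (illustrative, not sklearn internals):

```python
import numpy as np

rng = np.random.default_rng(42)
n = 10_000

# Bootstrap sample: draw n row indices with replacement
boot = rng.integers(0, n, size=n)
unique_frac = len(np.unique(boot)) / n

# Roughly 63.2% of rows appear at least once; ~36.8% are out-of-bag
print(f"Unique samples in bootstrap: {unique_frac:.3f}")
```

This is also why RandomForestClassifier can report an oob_score_ without a separate validation set.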

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
import numpy as np

X, y = load_breast_cancer(return_X_y=True)
feature_names = load_breast_cancer().feature_names
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

rf = RandomForestClassifier(
    n_estimators=200,      # Number of trees
    max_depth=10,          # Limit tree depth
    min_samples_leaf=2,
    max_features="sqrt",   # sqrt(n_features) at each split
    random_state=42,
    n_jobs=-1              # Parallelize
)
rf.fit(X_train, y_train)

print(f"Training accuracy: {rf.score(X_train, y_train):.4f}")
print(f"Test accuracy:     {rf.score(X_test, y_test):.4f}")

# Feature importance
importances = rf.feature_importances_
top_5 = np.argsort(importances)[::-1][:5]
print("\nTop 5 important features:")
for idx in top_5:
    print(f"  {feature_names[idx]}: {importances[idx]:.4f}")
```

Gradient Boosting: Sequential Error Correction

Unlike Random Forests (which build trees in parallel), Gradient Boosting builds trees sequentially. Each new tree tries to correct the errors of all previous trees.

1. Start with a simple prediction (e.g., the mean)
2. Calculate the residuals (errors)
3. Train a small tree to predict the residuals
4. Add the new tree's predictions (scaled by learning_rate) to the ensemble
5. Repeat

The learning_rate controls how much each tree contributes. Lower values need more trees but often generalize better.
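The loop above can be sketched from scratch for regression with squared error, where the residuals are simply y minus the current prediction. This is a toy sketch on synthetic data (the hyperparameters here are illustrative, not recommendations):

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
X = rng.uniform(-3, 3, size=(200, 1))
y = np.sin(X[:, 0]) + rng.normal(0, 0.1, size=200)

learning_rate = 0.1
n_estimators = 50

# Step 1: start with a constant prediction (the mean)
pred = np.full_like(y, y.mean())
trees = []
for _ in range(n_estimators):
    residuals = y - pred                   # Step 2: current errors
    t = DecisionTreeRegressor(max_depth=2, random_state=0)
    t.fit(X, residuals)                    # Step 3: fit a small tree to the residuals
    pred += learning_rate * t.predict(X)   # Step 4: add its scaled prediction
    trees.append(t)                        # Step 5: repeat

mse = np.mean((y - pred) ** 2)
print(f"Train MSE after boosting: {mse:.4f} (baseline: {np.var(y):.4f})")
```

Real implementations generalize step 2 to the gradient of an arbitrary differentiable loss, which is what makes the method "gradient" boosting.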

```python
from sklearn.ensemble import GradientBoostingClassifier

gb = GradientBoostingClassifier(
    n_estimators=200,
    learning_rate=0.1,
    max_depth=3,
    min_samples_leaf=5,
    subsample=0.8,     # Stochastic gradient boosting
    random_state=42
)
gb.fit(X_train, y_train)

print(f"Training accuracy: {gb.score(X_train, y_train):.4f}")
print(f"Test accuracy:     {gb.score(X_test, y_test):.4f}")
```

Modern Boosting Libraries: XGBoost, LightGBM, CatBoost

While sklearn has GradientBoosting, dedicated libraries are much faster and more feature-rich:

| Library | Strengths | Best For |
| --- | --- | --- |
| XGBoost | Battle-tested, built-in regularization, handles missing values | General tabular data, competitions |
| LightGBM | Fastest training, leaf-wise growth, handles large datasets | Large datasets, speed-critical applications |
| CatBoost | Native categorical feature handling, least tuning needed | Datasets with many categorical features |

All three follow the sklearn API (fit/predict/score) and can be used with sklearn utilities like cross_val_score and GridSearchCV.

```python
# XGBoost example (pip install xgboost)
from xgboost import XGBClassifier

xgb = XGBClassifier(
    n_estimators=200,
    learning_rate=0.1,
    max_depth=5,
    subsample=0.8,
    colsample_bytree=0.8,
    eval_metric="logloss",
    random_state=42,
    n_jobs=-1
)

xgb.fit(X_train, y_train)
print(f"XGBoost Test accuracy: {xgb.score(X_test, y_test):.4f}")

# LightGBM example (pip install lightgbm)
from lightgbm import LGBMClassifier

lgbm = LGBMClassifier(
    n_estimators=200,
    learning_rate=0.1,
    max_depth=5,
    subsample=0.8,
    colsample_bytree=0.8,
    random_state=42,
    verbose=-1,
    n_jobs=-1
)

lgbm.fit(X_train, y_train)
print(f"LightGBM Test accuracy: {lgbm.score(X_test, y_test):.4f}")
```
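Because all three libraries follow the sklearn API, they plug straight into sklearn's model-selection utilities. A sketch using cross_val_score, with sklearn's own GradientBoostingClassifier as a stand-in so the snippet runs without extra installs; swapping in XGBClassifier or LGBMClassifier is a one-line change:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import cross_val_score

X, y = load_breast_cancer(return_X_y=True)

clf = GradientBoostingClassifier(
    n_estimators=100, max_depth=3, random_state=42
)

# 5-fold cross-validated accuracy, folds evaluated in parallel
scores = cross_val_score(clf, X, y, cv=5, n_jobs=-1)
print(f"CV accuracy: {scores.mean():.4f} +/- {scores.std():.4f}")
```

The same pattern works with GridSearchCV for tuning n_estimators, learning_rate, and max_depth jointly.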

Bagging vs Boosting

Bagging (Random Forest) trains trees independently in parallel, each on a random subset of data. It reduces variance (overfitting). Boosting (GradientBoosting, XGBoost) trains trees sequentially, each correcting the previous trees' errors. It reduces bias (underfitting). Boosting is usually more accurate but more prone to overfitting if not tuned carefully.

Feature Importance

Tree-based models naturally compute feature importance based on how much each feature reduces impurity across all trees. This is one of the biggest advantages of tree-based models -- you get interpretability for free.

Types of feature importance:

  • Impurity-based (default): How much each feature reduces Gini/entropy across splits. Fast but biased toward high-cardinality features.
  • Permutation importance: Shuffle one feature and measure the drop in accuracy. Slower but more reliable.

```python
from sklearn.inspection import permutation_importance

# Permutation importance (more reliable)
perm_imp = permutation_importance(
    rf, X_test, y_test, n_repeats=10, random_state=42, n_jobs=-1
)

print("Permutation Importance (top 5):")
top_5_perm = np.argsort(perm_imp.importances_mean)[::-1][:5]
for idx in top_5_perm:
    mean = perm_imp.importances_mean[idx]
    std = perm_imp.importances_std[idx]
    print(f"  {feature_names[idx]}: {mean:.4f} +/- {std:.4f}")
```

Gradient Boosting is King of Tabular Data

For structured/tabular data (spreadsheets, databases), gradient boosting models (especially XGBoost and LightGBM) consistently outperform other methods, including deep learning. Deep learning excels at images, text, and audio, but for tables with rows and columns, trees still win.