The Scikit-learn Workflow

Learn the end-to-end sklearn pattern: split, preprocess, train, evaluate, and tune


Scikit-learn provides a consistent, elegant API for machine learning in Python. Every algorithm follows the same pattern: fit, predict, score. In this lesson, you'll learn the complete workflow from raw data to tuned model.

The Golden Pattern

Every sklearn estimator follows this interface:

model = SomeModel(hyperparameters)
model.fit(X_train, y_train)
predictions = model.predict(X_test)
score = model.score(X_test, y_test)

This consistency is what makes scikit-learn so powerful. Once you learn the pattern, you can swap algorithms with a single line change.

The Estimator API

Every scikit-learn model implements fit() to learn from data, predict() to make predictions, and score() to evaluate performance. Transformers also implement transform() and fit_transform(). This unified API is the foundation of sklearn's design philosophy.
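To see the pattern in action, here is a short sketch on the bundled iris dataset: the same three calls work for any classifier, so swapping algorithms really is a one-line change.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Identical fit/predict/score calls for two different estimators
for model in (LogisticRegression(max_iter=1000),
              DecisionTreeClassifier(random_state=42)):
    model.fit(X_train, y_train)
    predictions = model.predict(X_test)
    score = model.score(X_test, y_test)
    print(f"{type(model).__name__}: {score:.3f}")
```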

Step 1: Train/Test Split

Before anything else, you must split your data into training and test sets. The test set simulates unseen real-world data.

python
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris

# Load a dataset
X, y = load_iris(return_X_y=True)

# Split: 80% train, 20% test
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,
    random_state=42,
    stratify=y  # Preserve class proportions
)

print(f"Training set: {X_train.shape[0]} samples")
print(f"Test set:     {X_test.shape[0]} samples")

Always Use stratify for Classification

When doing classification, pass stratify=y to train_test_split. This ensures each class is represented proportionally in both splits. Without it, you might get a training set with no examples of a rare class.
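A quick check with numpy's bincount confirms this: iris has exactly 50 samples per class, so a stratified 80/20 split keeps each class at 40 train / 10 test samples.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Each of the 3 iris classes keeps its share in both splits
print("Full: ", np.bincount(y))        # [50 50 50]
print("Train:", np.bincount(y_train))  # [40 40 40]
print("Test: ", np.bincount(y_test))   # [10 10 10]
```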

Step 2: Preprocessing

Real-world data needs cleaning and transformation before models can use it. Common preprocessing steps include scaling numeric features, encoding categorical features, and handling missing values.

python
from sklearn.preprocessing import StandardScaler
from sklearn.impute import SimpleImputer

# StandardScaler: zero mean, unit variance
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)  # Learn stats from train
X_test_scaled = scaler.transform(X_test)        # Apply same stats to test

print(f"Train mean before: {X_train[:, 0].mean():.2f}")
print(f"Train mean after:  {X_train_scaled[:, 0].mean():.6f}")

# SimpleImputer: fill missing values
imputer = SimpleImputer(strategy="median")
X_clean = imputer.fit_transform(X_train)

Data Leakage: The #1 ML Mistake

NEVER fit your scaler (or any transformer) on the test data. Call fit_transform() on training data, then transform() on test data. If you fit on the full dataset before splitting, information from the test set 'leaks' into your training process, giving you falsely optimistic results that won't hold up in production.
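Here is a sketch of the wrong and right orderings side by side, using small synthetic data (the arrays are illustrative, not from the lesson's dataset). Note that after correct scaling, the test set's mean need not be exactly zero, because it was scaled with the training set's statistics.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
X = rng.normal(loc=5.0, scale=2.0, size=(100, 3))
y = rng.integers(0, 2, size=100)

# WRONG: the scaler sees test rows before the split,
# so test-set statistics leak into preprocessing
scaler_leaky = StandardScaler().fit(X)

# RIGHT: split first, then fit only on the training rows
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
scaler = StandardScaler().fit(X_train)
X_train_scaled = scaler.transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Train means are exactly 0; test means are merely close to 0
print(X_train_scaled.mean(axis=0).round(6))
print(X_test_scaled.mean(axis=0).round(3))
```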

Step 3: Pipelines

Pipelines chain preprocessing and modeling steps together, preventing data leakage and making your code cleaner.

python
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# Create a pipeline: scale then classify
pipe = Pipeline([
    ("scaler", StandardScaler()),
    ("classifier", SVC(kernel="rbf", C=1.0))
])

# fit() scales the data, then trains the SVM
pipe.fit(X_train, y_train)

# score() scales the data, then predicts and evaluates
accuracy = pipe.score(X_test, y_test)
print(f"Pipeline accuracy: {accuracy:.4f}")

Why Pipelines Prevent Leakage

When you use a Pipeline, the scaler is fit only on training folds during cross-validation. Without a pipeline, you might accidentally fit the scaler on the full dataset first. Pipelines enforce correct data flow automatically.
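A sketch of this in practice: pass the whole pipeline to cross_val_score and the scaler is re-fit from scratch on each fold's training portion, so no fold ever sees its own validation data during preprocessing.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)

pipe = Pipeline([
    ("scaler", StandardScaler()),
    ("classifier", SVC(kernel="rbf", C=1.0)),
])

# Each fold fits the scaler on its own training portion only
scores = cross_val_score(pipe, X, y, cv=5)
print(f"Mean CV accuracy: {scores.mean():.3f}")
```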

Step 4: Cross-Validation

A single train/test split can be misleading. Cross-validation gives a more robust estimate of model performance by training and evaluating on multiple different splits.

python
from sklearn.model_selection import cross_val_score
from sklearn.ensemble import RandomForestClassifier

model = RandomForestClassifier(n_estimators=100, random_state=42)

# 5-fold cross-validation
scores = cross_val_score(model, X_train, y_train, cv=5, scoring="accuracy")

print(f"CV Scores: {scores}")
print(f"Mean:      {scores.mean():.4f}")
print(f"Std:       {scores.std():.4f}")

Step 5: Hyperparameter Tuning with GridSearchCV

GridSearchCV exhaustively searches through a specified parameter grid, evaluating each combination with cross-validation.

python
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

# Define the parameter grid
param_grid = {
    "C": [0.1, 1, 10, 100],
    "kernel": ["rbf", "linear"],
    "gamma": ["scale", "auto"]
}

grid_search = GridSearchCV(
    SVC(),
    param_grid,
    cv=5,
    scoring="accuracy",
    n_jobs=-1,       # Use all CPU cores
    verbose=1
)

grid_search.fit(X_train, y_train)

print(f"Best params: {grid_search.best_params_}")
print(f"Best CV score: {grid_search.best_score_:.4f}")
print(f"Test score: {grid_search.score(X_test, y_test):.4f}")

Algorithm Comparison Quick Reference

Algorithm           Type            Best For                                 Key Hyperparameters
LogisticRegression  Classification  Linear boundaries, interpretability      C, penalty
SVC / SVR           Both            Small-medium datasets, non-linear        C, kernel, gamma
RandomForest        Both            General purpose, feature importance      n_estimators, max_depth
GradientBoosting    Both            Competitions, tabular data               n_estimators, learning_rate
KNeighbors          Both            Simple baselines, small datasets         n_neighbors, weights
KMeans              Clustering      Spherical clusters                       n_clusters
PCA                 Reduction       Dimensionality reduction, visualization  n_components

Start Simple

Always start with a simple baseline model (like LogisticRegression or a DummyClassifier). If a complex model only beats the baseline by 1%, the extra complexity may not be worth it. Simple models are faster to train, easier to debug, and more interpretable.
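A minimal baseline sketch: DummyClassifier with strategy="most_frequent" always predicts the majority class, which gives you the floor any real model must clear.

```python
from sklearn.datasets import load_iris
from sklearn.dummy import DummyClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Baseline: always predict the most frequent training class
baseline = DummyClassifier(strategy="most_frequent").fit(X_train, y_train)
model = LogisticRegression(max_iter=1000).fit(X_train, y_train)

print(f"Baseline accuracy: {baseline.score(X_test, y_test):.3f}")
print(f"Model accuracy:    {model.score(X_test, y_test):.3f}")
```

On balanced iris the baseline lands around 1/3 accuracy; if your "real" model barely beats that number, something is wrong with the features or the setup.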