
Experiment Tracking & Model Registry

Log, compare, and manage ML experiments systematically


Every ML project involves hundreds of experiments — different architectures, hyperparameters, datasets, preprocessing steps. Without a system to track them, you'll quickly lose track of what worked and why.

Why Track Experiments?

Without tracking, you'll inevitably face these problems:

  • "Which hyperparameters gave me that 94% accuracy last week?"
  • "I can't reproduce the result from three months ago."
  • "Did I use the v2 or v3 dataset for this model?"
  • "Which team member's model is actually in production?"

Experiment tracking solves all of these by creating a structured, searchable record of every training run.

The Three Pillars of Experiment Tracking

Every experiment tracker captures three things:

1. **Parameters** — hyperparameters, config values, data versions
2. **Metrics** — loss, accuracy, F1 score, latency, etc.
3. **Artifacts** — trained models, plots, datasets, config files
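Conceptually, a tracked run is just a structured record of these three things. A minimal plain-Python sketch of that idea (the `ExperimentRun` class is hypothetical, not any library's API):

```python
from dataclasses import dataclass, field

@dataclass
class ExperimentRun:
    # 1. Parameters: everything needed to re-create the run
    params: dict = field(default_factory=dict)
    # 2. Metrics: everything used to judge the run
    metrics: dict = field(default_factory=dict)
    # 3. Artifacts: paths to models, plots, configs, etc.
    artifacts: list = field(default_factory=list)

run = ExperimentRun(
    params={"learning_rate": 0.001, "hidden_units": 64},
    metrics={"val_accuracy": 0.94},
    artifacts=["model.keras", "loss_curve.png"],
)
```

Real trackers add storage, search, and a UI on top of this record, but the data model is essentially the same.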

MLflow

MLflow is the most popular open-source experiment tracking platform. It has four main components:

1. **Tracking** — log parameters, metrics, and artifacts for each run
2. **Projects** — package ML code in a reproducible format
3. **Models** — a standard format for packaging models with metadata
4. **Model Registry** — a central hub for managing the model lifecycle (staging → production)

Logging with MLflow

```python
import mlflow
import mlflow.tensorflow
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import tensorflow as tf

# Set the experiment name
mlflow.set_experiment("iris-classifier")

# Start a tracked run
with mlflow.start_run(run_name="dense-relu-v1"):
    # Log hyperparameters
    mlflow.log_param("hidden_units", 64)
    mlflow.log_param("learning_rate", 0.001)
    mlflow.log_param("epochs", 50)
    mlflow.log_param("optimizer", "adam")
    mlflow.log_param("dropout", 0.2)

    # Build and train model
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(64, activation='relu', input_shape=(4,)),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(3, activation='softmax')
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )

    # n_informative + n_redundant must not exceed n_features
    X, y = make_classification(n_features=4, n_classes=3,
                               n_informative=3, n_redundant=1,
                               random_state=42)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

    history = model.fit(X_train, y_train, epochs=50, verbose=0,
                        validation_data=(X_test, y_test))

    # Log final metrics
    final_loss = history.history['loss'][-1]
    final_acc = history.history['accuracy'][-1]
    val_acc = history.history['val_accuracy'][-1]

    mlflow.log_metric("final_loss", final_loss)
    mlflow.log_metric("final_accuracy", final_acc)
    mlflow.log_metric("val_accuracy", val_acc)

    # Log metrics over time (step-wise)
    for epoch, (loss, acc) in enumerate(
        zip(history.history['loss'], history.history['accuracy'])
    ):
        mlflow.log_metric("train_loss", loss, step=epoch)
        mlflow.log_metric("train_accuracy", acc, step=epoch)

    # Log the model as an artifact
    mlflow.tensorflow.log_model(model, "model")

    print(f"Run logged! Val accuracy: {val_acc:.4f}")
```

MLflow UI

Launch the UI to visually compare runs:

```bash
mlflow ui --port 5000
```

Then open http://localhost:5000 in a browser.

The UI lets you:

  • Compare metrics across runs as tables or charts
  • Filter and search runs by parameters or metrics
  • Download artifacts (models, plots, data)
  • View run details and metadata

MLflow Model Registry

The Model Registry manages the lifecycle of models through stages (a new version starts in the "None" stage):

None → Staging → Production → Archived

```python
import mlflow
from mlflow.tracking import MlflowClient

client = MlflowClient()

# Register a model from a run
# (run_id is the ID of a previously tracked run, e.g. from the MLflow UI)
model_uri = f"runs:/{run_id}/model"
model_name = "iris-classifier"
result = mlflow.register_model(model_uri, model_name)
print(f"Registered version: {result.version}")

# Transition to staging
client.transition_model_version_stage(
    name=model_name,
    version=result.version,
    stage="Staging"
)

# After validation, promote to production
client.transition_model_version_stage(
    name=model_name,
    version=result.version,
    stage="Production"
)

# Load the production model for serving
production_model = mlflow.tensorflow.load_model(
    f"models:/{model_name}/Production"
)

# Archive the old version
client.transition_model_version_stage(
    name=model_name,
    version=1,  # old version
    stage="Archived"
)
```

    Weights & Biases (W&B)

    Weights & Biases is a popular managed platform for experiment tracking with a beautiful web dashboard and collaboration features.

Key features:

  • Dashboard: Interactive plots, tables, and run comparisons
  • Sweeps: Automated hyperparameter search (grid, random, Bayesian)
  • Artifacts: Version datasets, models, and other large files
  • Reports: Collaborative documents with embedded visualizations

```python
import wandb
from wandb.keras import WandbCallback
import tensorflow as tf
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Initialize a W&B run
wandb.init(
    project="iris-classifier",
    config={
        "hidden_units": 64,
        "learning_rate": 0.001,
        "epochs": 50,
        "optimizer": "adam",
        "dropout": 0.2,
    }
)

# Build model using config
config = wandb.config
model = tf.keras.Sequential([
    tf.keras.layers.Dense(config.hidden_units, activation='relu',
                          input_shape=(4,)),
    tf.keras.layers.Dropout(config.dropout),
    tf.keras.layers.Dense(3, activation='softmax')
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=config.learning_rate),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Prepare data (n_informative + n_redundant must not exceed n_features)
X, y = make_classification(n_features=4, n_classes=3,
                           n_informative=3, n_redundant=1,
                           random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# The WandbCallback automatically logs metrics every epoch
history = model.fit(
    X_train, y_train,
    epochs=config.epochs,
    validation_data=(X_test, y_test),
    callbacks=[WandbCallback()]
)

# Log additional custom metrics
wandb.log({
    "best_val_accuracy": max(history.history['val_accuracy']),
    "final_train_loss": history.history['loss'][-1],
})

# Save the model as an artifact
artifact = wandb.Artifact("iris-model", type="model")
model.save("iris_model.keras")
artifact.add_file("iris_model.keras")
wandb.log_artifact(artifact)

wandb.finish()
```

W&B Sweeps (Hyperparameter Search)

```python
sweep_config = {
    "method": "bayes",    # or "grid", "random"
    "metric": {"name": "val_accuracy", "goal": "maximize"},
    "parameters": {
        "hidden_units": {"values": [32, 64, 128, 256]},
        "learning_rate": {"min": 0.0001, "max": 0.01},
        "dropout": {"min": 0.0, "max": 0.5},
    }
}

sweep_id = wandb.sweep(sweep_config, project="iris-classifier")
wandb.agent(sweep_id, function=train_and_log, count=20)
```

Comparison: MLflow vs W&B

| Feature        | MLflow                    | Weights & Biases                  |
|----------------|---------------------------|-----------------------------------|
| Hosting        | Self-hosted (open-source) | Managed cloud (free tier)         |
| UI             | Functional                | Polished, interactive             |
| Sweeps         | Manual                    | Built-in (Bayesian, grid, random) |
| Model Registry | Built-in                  | Via Artifacts                     |
| Collaboration  | Git-based                 | Web-based teams                   |
| Cost           | Free (infra costs)        | Free tier; paid for teams         |
| Integration    | TF, PyTorch, Sklearn      | TF, PyTorch, Sklearn, HF          |
| Best For       | Self-hosted, enterprise   | Teams wanting a managed solution  |

Best Practices for Reproducibility

1. **Log everything**: Random seeds, library versions, data hashes
2. **Version your data**: Use DVC or W&B Artifacts to track dataset versions
3. **Pin dependencies**: Use requirements.txt or conda environment.yml
4. **Save configs as files**: Don't rely on code comments
5. **Use git commit hashes**: Link each run to the exact code version
6. **Tag important runs**: Mark your best runs so you can find them later
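Several of these practices can be automated with a small helper that collects the seed, environment, code version, and data hash before each run. A sketch (the `gather_run_metadata` helper is hypothetical; pass its result to e.g. `mlflow.log_params` or `wandb.config.update`):

```python
import hashlib
import platform
import random
import subprocess

def gather_run_metadata(seed, data_path=None):
    """Collect the reproducibility info worth logging with every run."""
    random.seed(seed)  # seed before training, and log the value itself
    meta = {
        "seed": seed,
        "python_version": platform.python_version(),
        "platform": platform.platform(),
    }
    try:
        # Tie the run to the exact code version
        meta["git_commit"] = subprocess.check_output(
            ["git", "rev-parse", "HEAD"], text=True).strip()
    except Exception:
        meta["git_commit"] = "unknown"  # e.g. not inside a git repo
    if data_path is not None:
        # Hash the dataset so you always know which version was used
        with open(data_path, "rb") as f:
            meta["data_sha256"] = hashlib.sha256(f.read()).hexdigest()
    return meta

meta = gather_run_metadata(seed=42)
# e.g. mlflow.log_params(meta) or wandb.config.update(meta)
```

Calling this at the top of every training script makes each tracked run reproducible by construction rather than by discipline.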