levy-n

model-interpretability

Model interpretability, explainability, and debugging tools. Covers SHAP (TreeExplainer, DeepExplainer, KernelExplainer), feature importance analysis, LIME, attention visualization, Grad-CAM for CNNs, confusion matrix analysis, error analysis patterns, and model fairness auditing. Use when user asks about 'SHAP', 'feature importance', 'explainability', 'interpretability', 'why did the model predict', 'Grad-CAM', 'LIME', 'attention weights', 'confusion matrix', 'error analysis', 'model debugging', 'fairness', 'bias detection', or 'what did the model learn'.

Install

npx skillscat add levy-n/claude-useful-skills/model-interpretability

Install via the SkillsCat registry.

SKILL.md

Model Interpretability - Understanding What Models Learn

Model interpretation: understanding why the model decided what it decided.

Quick Start - SHAP Feature Importance

import shap
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer

# Train model
data = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, test_size=0.2, random_state=42
)

model = xgb.XGBClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

# SHAP explainer
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)

# Global feature importance (which features matter most?)
shap.summary_plot(shap_values, X_test, feature_names=data.feature_names)

# Single prediction explanation (why THIS prediction?)
shap.waterfall_plot(shap.Explanation(
    values=shap_values[0],
    base_values=explainer.expected_value,
    data=X_test[0],
    feature_names=data.feature_names
))

When This Skill Activates

Use this skill when:

  • Understanding why a model made a specific prediction
  • Analyzing global feature importance
  • Debugging model errors
  • Visualizing attention in transformers
  • Creating Grad-CAM heatmaps for CNNs
  • Auditing models for bias/fairness
  • Explaining models to stakeholders

Core Patterns

Pattern 1: SHAP for Different Model Types

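import shap

# 1. Tree models (XGBoost, LightGBM, Random Forest): fastest, exact
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)

# 2. Deep learning (PyTorch, TensorFlow): background_data is a small sample
#    of training inputs (e.g. ~100 rows) used as the reference distribution
explainer = shap.DeepExplainer(model, background_data)
shap_values = explainer.shap_values(X_test)

# 3. Any model (model-agnostic): slowest but universal
#    (use model.predict for regressors, predict_proba for class probabilities)
explainer = shap.KernelExplainer(model.predict_proba, shap.sample(X_train, 100))
shap_values = explainer.shap_values(X_test[:50])  # explain a small subset; cost grows fast

# Visualizations (binary XGBoost/LightGBM model, X_test as a pandas DataFrame)
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)                   # NumPy array (samples, features)
shap.summary_plot(shap_values, X_test)                        # Beeswarm (global)
shap.summary_plot(shap_values, X_test, plot_type="bar")       # Bar chart (global)
shap.dependence_plot("feature_name", shap_values, X_test)     # Interaction ("feature_name" = a column)
shap.force_plot(explainer.expected_value, shap_values[0], X_test.iloc[0],
                matplotlib=True)                              # Force diagram

# Waterfall plots need a shap.Explanation object, not a raw array
explanation = explainer(X_test)
shap.plots.waterfall(explanation[i])                          # Single prediction (i = row index)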

SHAP Intuition:
┌──────────────────────────────────────────────────────┐
│  Base value: 0.5 (average prediction)                │
│                                                      │
│  Feature A (age=65):      +0.15  ──▶                 │
│  Feature B (income=high): +0.10  ──▶                 │
│  Feature C (location=NY): -0.05  ◀──                 │
│  Feature D (history=yes): +0.20  ──▶                 │
│                                                      │
│  Final prediction: 0.5 + 0.15 + 0.10 - 0.05 + 0.20   │
│                  = 0.90 (high risk)                  │
│                                                      │
│  Each feature has a Shapley value = its contribution │
│  Sum of all SHAP values = prediction - base value    │
└──────────────────────────────────────────────────────┘

Note: for gradient-boosted classifiers such as XGBoost and LightGBM, TreeExplainer explains the raw model output (log-odds) by default. The additive sum therefore holds in log-odds units, not in probabilities.
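The sketch below makes the additivity rule in the box concrete. It computes exact Shapley values by brute force for a hypothetical three-feature toy model. It assumes an "interventional" value function, meaning features outside a coalition are filled in from a background sample; this is one common convention, not the only one. `exact_shapley` and `toy_model` are illustrative names, not part of the `shap` API. The number of coalitions grows as 2^M in the number of features. That is why `TreeExplainer` relies on a specialized tree algorithm and `KernelExplainer` relies on sampling.

```python
from itertools import combinations
from math import factorial

import numpy as np


def exact_shapley(f, x, background):
    """Exact Shapley values for one instance x by enumerating every coalition.

    Value of a coalition S: mean of f over the background rows after setting
    the features in S to x's values (an "interventional" value function).
    The number of coalitions grows as 2^M for M features.
    """
    x = np.asarray(x, dtype=float)
    background = np.asarray(background, dtype=float)
    m = x.size

    def value(coalition):
        data = background.copy()
        cols = list(coalition)
        data[:, cols] = x[cols]  # features in the coalition take x's values
        return f(data).mean()

    phi = np.zeros(m)
    for i in range(m):
        others = [j for j in range(m) if j != i]
        for size in range(m):
            # Shapley weight |S|! (M - |S| - 1)! / M!
            weight = factorial(size) * factorial(m - size - 1) / factorial(m)
            for s in combinations(others, size):
                phi[i] += weight * (value(s + (i,)) - value(s))
    return phi, value(())  # (Shapley values, base value)


def toy_model(X):
    # Hypothetical model: an additive term plus an interaction
    return 2.0 * X[:, 0] + X[:, 1] * X[:, 2]


rng = np.random.default_rng(0)
background = rng.normal(size=(100, 3))
x = np.array([1.0, 2.0, 3.0])

phi, base = exact_shapley(toy_model, x, background)
print(phi, base)
print(np.isclose(base + phi.sum(), toy_model(x[None, :])[0]))  # efficiency: base + sum(phi) = f(x)
```

Feature 0 enters the toy model additively, so its Shapley value reduces to 2 × (x₀ − background mean of feature 0). The interaction term x₁·x₂ is shared between features 1 and 2.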

Pattern 2: Feature Importance Comparison

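import numpy as np
import matplotlib.pyplot as plt
import shap
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance

# Assumes X_train, X_test, y_train, y_test (binary target) and feature_names exist
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

# Method 1: Built-in feature importance (MDI: impurity-based, computed on training data)
mdi_importance = model.feature_importances_

# Method 2: Permutation importance (model-agnostic, measured on held-out data)
perm_importance = permutation_importance(
    model, X_test, y_test, n_repeats=30, random_state=42
)

# Method 3: SHAP (mean |SHAP value| per feature, for class 1)
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)
if isinstance(shap_values, list):      # older SHAP: one array per class
    shap_values = shap_values[1]
elif shap_values.ndim == 3:            # newer SHAP: (samples, features, classes)
    shap_values = shap_values[:, :, 1]
shap_importance = np.abs(shap_values).mean(axis=0)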

# Compare all three
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
top_k = 10

for ax, importance, title in zip(axes,
    [mdi_importance, perm_importance.importances_mean, shap_importance],
    ["MDI (Built-in)", "Permutation", "SHAP"]):

    idx = np.argsort(importance)[-top_k:]
    ax.barh(range(top_k), importance[idx])
    ax.set_yticks(range(top_k))
    ax.set_yticklabels(np.array(feature_names)[idx])
    ax.set_title(title)

plt.tight_layout()
plt.show()
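Permutation importance is simple enough to sketch directly. The function below shuffles one column at a time and records the drop in score. It is a hypothetical helper that mirrors the idea behind `sklearn.inspection.permutation_importance`, not its exact implementation. The toy model ignores feature 1, so shuffling that column cannot change any prediction.

```python
import numpy as np


def permutation_importance_np(predict, X, y, score, n_repeats=10, seed=0):
    """Mean and std of the score drop when each column of X is shuffled."""
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=float)
    baseline = score(y, predict(X))
    drops = np.zeros((X.shape[1], n_repeats))
    for j in range(X.shape[1]):
        for r in range(n_repeats):
            X_perm = X.copy()
            # Shuffling breaks the link between feature j and the target
            X_perm[:, j] = rng.permutation(X_perm[:, j])
            drops[j, r] = baseline - score(y, predict(X_perm))
    return drops.mean(axis=1), drops.std(axis=1)


def accuracy(y_true, y_pred):
    return np.mean(y_true == y_pred)


def predict(A):
    # Hypothetical "model" that only looks at feature 0
    return (A[:, 0] > 0).astype(int)


rng = np.random.default_rng(42)
X = rng.normal(size=(500, 2))
y = (X[:, 0] > 0).astype(int)

mean_drop, std_drop = permutation_importance_np(predict, X, y, accuracy)
print(mean_drop)  # large drop for feature 0, zero for the ignored feature 1
```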

Pattern 3: Grad-CAM for CNN Visualization

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

class GradCAM:
    """Gradient-weighted Class Activation Mapping for CNNs."""

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        # Register hooks (keep the handles so they can be removed later)
        self._handles = [
            target_layer.register_forward_hook(self._forward_hook),
            target_layer.register_full_backward_hook(self._backward_hook),
        ]

    def _forward_hook(self, module, input, output):
        self.activations = output.detach()

    def _backward_hook(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    def remove_hooks(self):
        for handle in self._handles:
            handle.remove()

    def generate(self, input_tensor, target_class=None):
        # Forward pass (batch of one image); gradients must be enabled
        output = self.model(input_tensor)

        if target_class is None:
            target_class = output.argmax(dim=1).item()

        # Backward pass for target class
        self.model.zero_grad()
        output[0, target_class].backward()

        # Compute weights (global average pooling of gradients)
        weights = self.gradients.mean(dim=[2, 3], keepdim=True)

        # Weighted combination of activation maps
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = F.relu(cam)  # Only positive contributions

        # Upsample from feature-map size (e.g. 7x7) to the input resolution
        cam = F.interpolate(cam, size=input_tensor.shape[-2:],
                            mode="bilinear", align_corners=False)

        # Normalize to [0, 1]
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-8)

        return cam.squeeze().cpu().numpy()

# Usage (assumes a torchvision ResNet `model`, a preprocessed `input_image` tensor
# of shape (3, H, W), `original_image` as an (H, W, 3) array, and `device`)
model.eval()
grad_cam = GradCAM(model, model.layer4[-1])  # Last conv block

cam = grad_cam.generate(input_image.unsqueeze(0).to(device))

# Visualize
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
axes[0].imshow(original_image)
axes[0].set_title("Original")
axes[1].imshow(cam, cmap='jet')
axes[1].set_title("Grad-CAM")
axes[2].imshow(original_image)
axes[2].imshow(cam, cmap='jet', alpha=0.5)
axes[2].set_title("Overlay")
plt.show()
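The core Grad-CAM arithmetic in `generate` does not depend on PyTorch. This NumPy sketch shows how the per-channel gradient averages weight the activation maps. It also shows how the ReLU discards a channel whose gradients argue against the target class. The helper name and the 2×2 toy arrays are invented for illustration.

```python
import numpy as np


def grad_cam_from_arrays(activations, gradients):
    """Grad-CAM map from one layer's activations and gradients, both (C, H, W)."""
    weights = gradients.mean(axis=(1, 2))                  # one weight per channel
    cam = np.tensordot(weights, activations, axes=(0, 0))  # sum_k weight_k * A_k -> (H, W)
    cam = np.maximum(cam, 0.0)                             # ReLU: keep positive evidence
    cam -= cam.min()
    return cam / (cam.max() + 1e-8)                        # normalize to [0, 1]


# Two 2x2 channels: channel 0 has positive gradients, channel 1 negative
activations = np.array([[[1.0, 0.0], [0.0, 0.0]],
                        [[0.0, 0.0], [0.0, 1.0]]])
gradients = np.array([[[1.0, 1.0], [1.0, 1.0]],
                      [[-1.0, -1.0], [-1.0, -1.0]]])

cam = grad_cam_from_arrays(activations, gradients)
print(cam)  # only the top-left cell is highlighted
```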

Pattern 4: Confusion Matrix Deep Analysis

from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

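# Assumes integer labels 0..K-1 and class_names listing the K names in that order
labels = np.arange(len(class_names))

# Predictions
y_pred = model.predict(X_test)

# Basic confusion matrix (rows = true class, columns = predicted class)
cm = confusion_matrix(y_test, y_pred, labels=labels)

# Normalized confusion matrix: each row sums to 1 (per-class recall on the diagonal)
cm_normalized = confusion_matrix(y_test, y_pred, labels=labels, normalize='true')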

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Counts
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=axes[0])
axes[0].set_title("Confusion Matrix (Counts)")
axes[0].set_ylabel("True Label")
axes[0].set_xlabel("Predicted Label")

# Percentages
sns.heatmap(cm_normalized, annot=True, fmt='.2%', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=axes[1])
axes[1].set_title("Confusion Matrix (Normalized)")
plt.tight_layout()
plt.show()

# Detailed classification report
print(classification_report(y_test, y_pred, target_names=class_names))

# Error analysis: most confused (true, predicted) pairs, largest first
off_diag = cm.copy()
np.fill_diagonal(off_diag, 0)
pairs = np.argwhere(off_diag > 5)  # ignore rare confusions (threshold is arbitrary)
pairs = sorted(pairs, key=lambda p: off_diag[p[0], p[1]], reverse=True)
for true_idx, pred_idx in pairs:
    print(f"  {class_names[true_idx]} confused with "
          f"{class_names[pred_idx]}: {cm[true_idx, pred_idx]} times")
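For reference, the counts, the row normalization, and the confused-pair ranking can also be computed with NumPy alone. This sketch assumes integer labels 0..K−1; `confusion_and_top_pairs` and the example labels are hypothetical. Rows are true classes and columns are predicted classes, the same convention as `sklearn.metrics.confusion_matrix`.

```python
import numpy as np


def confusion_and_top_pairs(y_true, y_pred, n_classes, top_k=3):
    """Counts matrix, row-normalized matrix, and the most frequent
    off-diagonal (true, predicted, count) pairs."""
    cm = np.zeros((n_classes, n_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)  # count each (true, predicted) pair

    # Row-normalize; rows with no samples stay all-zero instead of dividing by 0
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_norm = np.divide(cm, row_sums, out=np.zeros(cm.shape), where=row_sums > 0)

    # Rank off-diagonal cells by count (stable sort keeps ties in index order)
    off_diag = cm.copy()
    np.fill_diagonal(off_diag, 0)
    order = np.argsort(-off_diag, axis=None, kind="stable")[:top_k]
    rows, cols = np.unravel_index(order, cm.shape)
    pairs = [(int(t), int(p), int(off_diag[t, p]))
             for t, p in zip(rows, cols) if off_diag[t, p] > 0]
    return cm, cm_norm, pairs


y_true = np.array([0, 0, 0, 1, 1, 2, 2, 2, 2])
y_pred = np.array([0, 1, 1, 1, 1, 2, 2, 0, 1])
cm, cm_norm, pairs = confusion_and_top_pairs(y_true, y_pred, n_classes=3)
print(cm)
print(pairs)  # class 0 predicted as class 1 is the most frequent confusion
```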

Pattern 5: Attention Visualization for Transformers

from transformers import AutoTokenizer, AutoModel
import torch
import matplotlib.pyplot as plt
import seaborn as sns

# Load model with attention output
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_attentions=True)

# Get attention weights
text = "The cat sat on the mat because it was tired"
inputs = tokenizer(text, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# outputs.attentions: tuple of (batch, heads, seq_len, seq_len)
attention = outputs.attentions  # 12 layers × 12 heads

# Visualize attention from last layer, head 0
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
attn_matrix = attention[-1][0, 0].numpy()  # Last layer, first head

plt.figure(figsize=(10, 8))
sns.heatmap(attn_matrix, xticklabels=tokens, yticklabels=tokens,
            cmap="Blues", annot=False)
plt.title("BERT Attention (Layer 12, Head 1)")
plt.xlabel("Attending To")
plt.ylabel("Attending From")
plt.tight_layout()
plt.show()

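# Average attention over all layers and heads: a smoother summary that hides
# head-specific patterns. stack -> (layers, batch, heads, seq, seq); mean -> (batch, seq, seq)
avg_attention = torch.stack(attention).mean(dim=[0, 2]).squeeze(0).numpy()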
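Each row of the heatmap is a probability distribution: one token's attention weights over all tokens sum to 1. The sketch below computes scaled dot-product attention weights, softmax(QKᵀ/√d_k), for a single head in NumPy. In BERT, Q and K come from learned projections of the hidden states. Here they are random matrices, an assumption made only to keep the example self-contained.

```python
import numpy as np


def attention_weights(Q, K):
    """Scaled dot-product attention weights for one head: softmax(Q K^T / sqrt(d_k))."""
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # subtract row max for numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
seq_len, d_k = 5, 8
Q = rng.normal(size=(seq_len, d_k))  # hypothetical query vectors
K = rng.normal(size=(seq_len, d_k))  # hypothetical key vectors

A = attention_weights(Q, K)
print(A.shape)         # row = attending-from token, column = attending-to token
print(A.sum(axis=-1))  # each row sums to 1
```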

Pattern 6: Error Analysis Pipeline

import pandas as pd
import numpy as np

def error_analysis(model, X_test, y_test, feature_names, class_names):
    """Systematic error analysis pipeline (assumes integer labels 0..K-1)."""

    y_test = np.asarray(y_test)  # avoid index misalignment with a pandas Series
    y_pred = model.predict(X_test)
    y_prob = model.predict_proba(X_test)

    # Create analysis DataFrame
    df = pd.DataFrame(np.asarray(X_test), columns=feature_names)
    df['true_label'] = y_test
    df['predicted'] = y_pred
    df['correct'] = y_test == y_pred
    df['confidence'] = y_prob.max(axis=1)

    # 1. Overall accuracy
    print(f"Overall accuracy: {df['correct'].mean():.3f}")

    # 2. Accuracy by class
    print("\nAccuracy by class:")
    for cls in np.unique(y_test):
        mask = df['true_label'] == cls
        acc = df[mask]['correct'].mean()
        print(f"  {class_names[cls]}: {acc:.3f} ({mask.sum()} samples)")

    # 3. High-confidence errors (model is SURE but WRONG)
    high_conf_errors = df[(~df['correct']) & (df['confidence'] > 0.9)]
    print(f"\nHigh-confidence errors: {len(high_conf_errors)}")

    # 4. Low-confidence correct (model is UNSURE but RIGHT)
    low_conf_correct = df[(df['correct']) & (df['confidence'] < 0.6)]
    print(f"Low-confidence correct: {len(low_conf_correct)}")

    # 5. Features whose means differ most between errors and correct predictions,
    #    in standard deviations so features on different scales are comparable
    errors, correct = df[~df['correct']], df[df['correct']]
    if len(errors) and len(correct):
        print("\nFeature differences (errors vs correct):")
        std_diff = {feat: (errors[feat].mean() - correct[feat].mean())
                          / (df[feat].std() + 1e-12)
                    for feat in feature_names}
        for feat in sorted(std_diff, key=lambda f: abs(std_diff[f]), reverse=True)[:5]:
            print(f"  {feat}: errors={errors[feat].mean():.3f}, "
                  f"correct={correct[feat].mean():.3f} ({std_diff[feat]:+.2f} std)")

    return df

analysis_df = error_analysis(model, X_test, y_test, feature_names, class_names)
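High-confidence errors are usually the first cases worth reading by hand. They are good candidates for label noise, or for inputs the features cannot describe. The minimal NumPy sketch below returns them most-confident first. `rank_confident_errors` is a hypothetical helper, and it assumes integer labels that index the `predict_proba` columns.

```python
import numpy as np


def rank_confident_errors(y_true, y_prob, threshold=0.9):
    """Indices of wrong predictions whose top-class probability exceeds
    `threshold`, sorted from most to least confident."""
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)
    y_pred = y_prob.argmax(axis=1)
    confidence = y_prob.max(axis=1)
    idx = np.flatnonzero((y_pred != y_true) & (confidence > threshold))
    return idx[np.argsort(-confidence[idx], kind="stable")]


y_true = np.array([0, 1, 1, 0])
y_prob = np.array([[0.95, 0.05],   # correct
                   [0.97, 0.03],   # wrong, confidence 0.97
                   [0.20, 0.80],   # correct
                   [0.01, 0.99]])  # wrong, confidence 0.99
print(rank_confident_errors(y_true, y_prob))  # row 3 first, then row 1
```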

Pattern 7: LIME Explanations

from lime.lime_tabular import LimeTabularExplainer

# Create LIME explainer (expects NumPy arrays; use X_train.values for a DataFrame)
explainer = LimeTabularExplainer(
    X_train,
    feature_names=feature_names,
    class_names=class_names,
    mode='classification'
)

# Explain single prediction
explanation = explainer.explain_instance(
    X_test[0],
    model.predict_proba,
    num_features=10,
    top_labels=1
)

# Show in notebook
explanation.show_in_notebook()

# Or get as list (for the label that was actually explained)
label = explanation.available_labels()[0]
print(f"Feature contributions for {class_names[label]}:")
for feature, weight in explanation.as_list(label=label):
    direction = "+" if weight > 0 else "-"
    print(f"  {direction} {feature}: {weight:.4f}")
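LIME's core idea is to fit a simple, proximity-weighted model around one instance. The NumPy sketch below is a simplified version with three steps:

1. Perturb x with Gaussian noise scaled per feature.
2. Weight the samples with an exponential kernel. The default width is 0.75·√(number of features), which matches LIME's documented tabular default.
3. Fit a weighted least-squares linear model.

The `lime` library itself discretizes continuous features by default and samples from training-data statistics, so its numbers will differ. `local_surrogate` is a hypothetical helper.

```python
import numpy as np


def local_surrogate(predict_fn, x, scale, n_samples=2000, kernel_width=None, seed=0):
    """Simplified LIME-style explanation: per-feature slopes of a
    proximity-weighted linear fit around x."""
    rng = np.random.default_rng(seed)
    x = np.asarray(x, dtype=float)
    scale = np.asarray(scale, dtype=float)
    if kernel_width is None:
        kernel_width = 0.75 * np.sqrt(x.size)

    Z = x + rng.normal(size=(n_samples, x.size)) * scale   # perturbed neighbors
    dist = np.linalg.norm((Z - x) / scale, axis=1)          # distance in per-feature std units
    w = np.exp(-dist**2 / kernel_width**2)                  # exponential proximity kernel

    # Weighted least squares: scale rows by sqrt(weight), then solve
    A = np.column_stack([np.ones(n_samples), Z - x])        # intercept + centered features
    sw = np.sqrt(w)
    coef, *_ = np.linalg.lstsq(A * sw[:, None], predict_fn(Z) * sw, rcond=None)
    return coef[1:]                                          # local slope per feature


def toy_predict(Z):
    # Hypothetical model; a linear one, so the surrogate should recover it exactly
    return 3.0 * Z[:, 0] - 2.0 * Z[:, 1] + 1.0


slopes = local_surrogate(toy_predict, x=[0.5, -1.0], scale=[1.0, 1.0])
print(slopes)  # close to [3, -2]
```

For a nonlinear model the slopes describe only the neighborhood defined by `scale` and `kernel_width`. Changing either changes the explanation, which is a known sensitivity of LIME-style methods.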

Reference Navigation

For detailed content, see:

  • SHAP Explainability: reference/shap_explainability.md - All explainer types, Shapley theory
  • Feature Importance: reference/feature_importance.md - MDI, Permutation, comparison
  • Model Debugging: reference/model_debugging.md - Error analysis, failure modes
  • Visualization Tools: reference/visualization_tools.md - Grad-CAM, attention, confusion matrix

Common Mistakes to Avoid

1. Using MDI Feature Importance Alone

# WRONG: MDI is biased toward high-cardinality features and is computed on training data
importance = model.feature_importances_  # Can be misleading!

# CORRECT: Always validate with permutation importance (on held-out data) or SHAP
from sklearn.inspection import permutation_importance
perm_importance = permutation_importance(model, X_test, y_test)

2. Wrong Background Data for SHAP

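# WRONG: the full test set as background. KernelExplainer's cost grows with
# background size, and the baseline should reflect the training distribution
explainer = shap.KernelExplainer(model.predict, X_test)

# CORRECT: a small summary of the training data (or shap.kmeans(X_train, 10))
explainer = shap.KernelExplainer(model.predict, shap.sample(X_train, 100))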

3. Interpreting Correlation as Causation

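SHAP tells you: "Feature X contributes +0.3 to this prediction"
This means: "Relative to the base value (the average prediction over the background data), X's observed value accounts for +0.3 of this prediction"
This does NOT mean: "Changing X by one unit changes the prediction by 0.3"
This does NOT mean: "X causes the outcome"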

Always combine SHAP with domain knowledge!
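For a linear model with independent features, SHAP values have a closed form: φᵢ = wᵢ (xᵢ − E[xᵢ]), with the expectation taken over the background data. The toy numbers below are all hypothetical. They show why a SHAP value is not a sensitivity: feature 1 receives φ = −2.0, yet raising it by one unit moves the prediction by only w₁ = −1.0.

```python
import numpy as np

# Hypothetical linear model f(x) = w . x + b
w = np.array([0.5, -1.0])
b = 0.2


def f(A):
    return A @ w + b


X_background = np.array([[1.0, 2.0], [3.0, 0.0], [2.0, 1.0]])  # feature means: [2.0, 1.0]
x = np.array([4.0, 3.0])

base_value = f(X_background).mean()              # E[f(X)] over the background
phi = w * (x - X_background.mean(axis=0))        # SHAP values: [1.0, -2.0]
print(np.isclose(base_value + phi.sum(), f(x)))  # additivity holds: True

# Effect of actually raising feature 1 by one unit
delta = f(x + np.array([0.0, 1.0])) - f(x)
print(phi[1], delta)                             # -2.0 vs. approximately -1.0 (= w[1])
```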

Teaching Mode

Interpretability Spectrum

More Interpretable ◀─────────────────────▶ Less Interpretable
                                             (Black Box)
┌──────────┬──────────┬──────────┬──────────┬──────────┐
│ Linear   │ Decision │ Random   │ Neural   │ Deep     │
│ Reg.     │ Tree     │ Forest   │ Network  │ LLM      │
│          │          │          │          │          │
│ Coeff =  │ Rules    │ Feature  │ Need     │ Need     │
│ feature  │ visible  │ import.  │ SHAP /   │ Attention│
│ import.  │          │          │ Grad-CAM │ / Probes │
└──────────┴──────────┴──────────┴──────────┴──────────┘

Rule: As model complexity increases,
      interpretability effort increases too.

Cross-References

  • ML evaluation: ../ml-fundamentals/SKILL.md - Metrics, cross-validation
  • CNN details: ../cnn-vision/SKILL.md - For Grad-CAM architecture context
  • Transformers: ../transformers-llm/SKILL.md - For attention visualization
  • MLOps: ../mlops-experiment/SKILL.md - Logging interpretability artifacts