Model interpretability, explainability, and debugging tools. Covers SHAP (TreeExplainer, DeepExplainer, KernelExplainer), feature importance analysis, LIME, attention visualization, Grad-CAM for CNNs, confusion matrix analysis, error analysis patterns, and model fairness auditing. Use when user asks about 'SHAP', 'feature importance', 'explainability', 'interpretability', 'why did the model predict', 'Grad-CAM', 'LIME', 'attention weights', 'confusion matrix', 'error analysis', 'model debugging', 'fairness', 'bias detection', or 'what did the model learn'.
Install
npx skillscat add levy-n/claude-useful-skills/model-interpretability
Installs via the SkillsCat registry.
Model Interpretability - Understanding What Models Learn
Model interpretability: understanding why the model decided what it decided.
Quick Start - SHAP Feature Importance
import shap
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
# Train model
data = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, test_size=0.2, random_state=42
)
model = xgb.XGBClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)
# SHAP explainer
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)
# Global feature importance (which features matter most?)
shap.summary_plot(shap_values, X_test, feature_names=data.feature_names)
# Single prediction explanation (why THIS prediction?)
shap.waterfall_plot(shap.Explanation(
    values=shap_values[0],
    base_values=explainer.expected_value,
    data=X_test[0],
    feature_names=data.feature_names
))

When This Skill Activates
Use this skill when:
- Understanding why a model made a specific prediction
- Analyzing global feature importance
- Debugging model errors
- Visualizing attention in transformers
- Creating Grad-CAM heatmaps for CNNs
- Auditing models for bias/fairness
- Explaining models to stakeholders
Core Patterns
Pattern 1: SHAP for Different Model Types
import shap
# 1. Tree models (XGBoost, LightGBM, Random Forest) — FASTEST
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)
# 2. Deep learning (PyTorch, TensorFlow)
explainer = shap.DeepExplainer(model, background_data)
shap_values = explainer.shap_values(X)
# 3. Any model (model-agnostic) — SLOWEST but universal
explainer = shap.KernelExplainer(model.predict, shap.sample(X_train, 100))
shap_values = explainer.shap_values(X_test[:50])
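# Visualizations (pass the same rows that were explained, e.g. X_test[:50] for KernelExplainer above)
shap.summary_plot(shap_values, X_test)  # Beeswarm (global)
shap.summary_plot(shap_values, X_test, plot_type="bar")  # Bar chart (global)
shap.waterfall_plot(shap.Explanation(  # Single prediction; needs an Explanation, not a raw array
    values=shap_values[i], base_values=explainer.expected_value, data=X_test[i]))
shap.dependence_plot("feature_name", shap_values, X_test)  # Interaction
shap.force_plot(explainer.expected_value, shap_values[0])  # Force diagram

SHAP Intuition: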
┌─────────────────────────────────────────────────────┐
│ Base prediction: 0.5 (average) │
│ │
│ Feature A (age=65): +0.15 ──▶ │
│ Feature B (income=high): +0.10 ──▶ │
│ Feature C (location=NY): -0.05 ◀── │
│ Feature D (history=yes): +0.20 ──▶ │
│ │
│ Final prediction: 0.5 + 0.15 + 0.10 - 0.05 + 0.20│
│ = 0.90 (high risk) │
│ │
│ Each feature has a Shapley value = its contribution│
│ Sum of all SHAP values = prediction - base value │
└─────────────────────────────────────────────────────┘

Pattern 2: Feature Importance Comparison
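The additivity property from the SHAP intuition box (sum of SHAP values = prediction - base value) can be checked by brute force on a toy model before comparing importance methods. The sketch below does not use the `shap` library. `exact_shapley`, the toy `predict` function, and the random background data are hypothetical names and values chosen for illustration. The value function, which averages predictions over background rows, is one common choice and is an assumption here.

```python
import itertools
import math

import numpy as np


def exact_shapley(predict, x, background):
    """Brute-force Shapley values for one row (hypothetical helper).

    Features outside a coalition keep their background values, and the
    predictions are averaged over the background rows.
    """
    n_features = x.shape[0]

    def value(coalition):
        rows = background.copy()
        idx = list(coalition)
        if idx:
            rows[:, idx] = x[idx]  # fix coalition features to the explained row
        return float(predict(rows).mean())

    phi = np.zeros(n_features)
    for i in range(n_features):
        others = [j for j in range(n_features) if j != i]
        for size in range(n_features):
            # Shapley weight: |S|! (n - |S| - 1)! / n!
            weight = (math.factorial(size)
                      * math.factorial(n_features - size - 1)
                      / math.factorial(n_features))
            for subset in itertools.combinations(others, size):
                phi[i] += weight * (value(subset + (i,)) - value(subset))
    return phi, value(())


def predict(rows):
    # Toy model with an interaction term (an assumption for illustration)
    return 0.5 + 0.2 * rows[:, 0] + 0.1 * rows[:, 1] * rows[:, 2]


rng = np.random.default_rng(0)
background = rng.normal(size=(20, 3))
x = np.array([1.0, 2.0, -1.0])

phi, base_value = exact_shapley(predict, x, background)
prediction = float(predict(x[None, :])[0])

print("SHAP values:", np.round(phi, 4))
print("base + sum :", round(base_value + phi.sum(), 6))
print("prediction :", round(prediction, 6))
```

The last two printed numbers should agree up to floating-point error. Enumerating every coalition costs exponential time in the number of features, which is why the explainers in Pattern 1 rely on model-specific shortcuts or sampling.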
import numpy as np
import matplotlib.pyplot as plt
import shap
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance

model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

# Method 1: Built-in feature importance (MDI - impurity-based)
mdi_importance = model.feature_importances_

# Method 2: Permutation importance (model-agnostic, more reliable)
perm_importance = permutation_importance(
    model, X_test, y_test, n_repeats=30, random_state=42
)

# Method 3: SHAP (per-prediction attributions, averaged for a global view)
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)
# Older SHAP versions return a list with one array per class; newer ones
# return a single array of shape (n_samples, n_features, n_classes)
class1_values = shap_values[1] if isinstance(shap_values, list) else shap_values[..., 1]
shap_importance = np.abs(class1_values).mean(axis=0)  # class 1

# Compare all three (feature_names: list of column names, defined by the caller)
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
top_k = 10
for ax, importance, title in zip(
        axes,
        [mdi_importance, perm_importance.importances_mean, shap_importance],
        ["MDI (Built-in)", "Permutation", "SHAP"]):
    idx = np.argsort(importance)[-top_k:]
    ax.barh(range(top_k), importance[idx])
    ax.set_yticks(range(top_k))
    ax.set_yticklabels(np.array(feature_names)[idx])
    ax.set_title(title)
plt.tight_layout()
plt.show()

Pattern 3: Grad-CAM for CNN Visualization
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt


class GradCAM:
    """Gradient-weighted Class Activation Mapping for CNNs."""

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        # Register hooks to capture activations and their gradients
        target_layer.register_forward_hook(self._forward_hook)
        target_layer.register_full_backward_hook(self._backward_hook)

    def _forward_hook(self, module, input, output):
        self.activations = output.detach()

    def _backward_hook(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    def generate(self, input_tensor, target_class=None):
        # Forward pass
        output = self.model(input_tensor)
        if target_class is None:
            target_class = output.argmax(dim=1).item()
        # Backward pass for target class
        self.model.zero_grad()
        output[0, target_class].backward()
        # Compute weights (global average pooling of gradients)
        weights = self.gradients.mean(dim=[2, 3], keepdim=True)
        # Weighted combination of activation maps
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = F.relu(cam)  # Only positive contributions
        # Upsample from feature-map size to input size so the overlay aligns
        cam = F.interpolate(cam, size=input_tensor.shape[2:],
                            mode='bilinear', align_corners=False)
        # Normalize to [0, 1]
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-8)
        return cam.squeeze().cpu().numpy()


# Usage (model, input_image, original_image, and device are defined by the caller)
model.eval()
grad_cam = GradCAM(model, model.layer4[-1])  # Last conv layer (ResNet-style naming)
cam = grad_cam.generate(input_image.unsqueeze(0).to(device))

# Visualize
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
axes[0].imshow(original_image)
axes[0].set_title("Original")
axes[1].imshow(cam, cmap='jet')
axes[1].set_title("Grad-CAM")
axes[2].imshow(original_image)
axes[2].imshow(cam, cmap='jet', alpha=0.5)
axes[2].set_title("Overlay")
plt.show()

Pattern 4: Confusion Matrix Deep Analysis
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
# Predictions
y_pred = model.predict(X_test)
# Basic confusion matrix
cm = confusion_matrix(y_test, y_pred)
# Normalized confusion matrix (percentages)
cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# Counts
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=axes[0])
axes[0].set_title("Confusion Matrix (Counts)")
axes[0].set_ylabel("True Label")
axes[0].set_xlabel("Predicted Label")
# Percentages
sns.heatmap(cm_normalized, annot=True, fmt='.2%', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=axes[1])
axes[1].set_title("Confusion Matrix (Normalized)")
plt.tight_layout()
plt.show()
# Detailed classification report
print(classification_report(y_test, y_pred, target_names=class_names))
# Error analysis: find most confused pairs
errors = np.argwhere((cm > 0) & ~np.eye(len(class_names), dtype=bool))
for true_idx, pred_idx in errors:
    if cm[true_idx, pred_idx] > 5:  # Significant errors
        print(f"  {class_names[true_idx]} confused with "
              f"{class_names[pred_idx]}: {cm[true_idx, pred_idx]} times")

Pattern 5: Attention Visualization for Transformers
from transformers import AutoTokenizer, AutoModel
import torch
import matplotlib.pyplot as plt
import seaborn as sns
# Load model with attention output
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_attentions=True)
# Get attention weights
text = "The cat sat on the mat because it was tired"
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
# outputs.attentions: tuple of (batch, heads, seq_len, seq_len)
attention = outputs.attentions # 12 layers × 12 heads
# Visualize attention from last layer, head 0
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
attn_matrix = attention[-1][0, 0].numpy() # Last layer, first head
plt.figure(figsize=(10, 8))
sns.heatmap(attn_matrix, xticklabels=tokens, yticklabels=tokens,
            cmap="Blues", annot=False)
plt.title("BERT Attention (Layer 12, Head 1)")
plt.xlabel("Attending To")
plt.ylabel("Attending From")
plt.tight_layout()
plt.show()
# Average attention across all layers and heads (more stable)
avg_attention = torch.stack(attention).mean(dim=[0, 2]).squeeze().numpy()

Pattern 6: Error Analysis Pipeline
import pandas as pd
import numpy as np
def error_analysis(model, X_test, y_test, feature_names, class_names):
    """Systematic error analysis pipeline."""
    y_true = np.asarray(y_test)  # positional labels; avoids index misalignment with a Series
    y_pred = model.predict(X_test)
    y_prob = model.predict_proba(X_test)

    # Create analysis DataFrame
    df = pd.DataFrame(X_test, columns=feature_names)
    df['true_label'] = y_true
    df['predicted'] = y_pred
    df['correct'] = y_true == y_pred
    df['confidence'] = y_prob.max(axis=1)

    # 1. Overall accuracy
    print(f"Overall accuracy: {df['correct'].mean():.3f}")

    # 2. Accuracy by class
    print("\nAccuracy by class:")
    for cls in np.unique(y_true):
        mask = df['true_label'] == cls
        acc = df[mask]['correct'].mean()
        print(f"  {class_names[cls]}: {acc:.3f} ({mask.sum()} samples)")

    # 3. High-confidence errors (model is SURE but WRONG)
    high_conf_errors = df[(~df['correct']) & (df['confidence'] > 0.9)]
    print(f"\nHigh-confidence errors: {len(high_conf_errors)}")

    # 4. Low-confidence correct (model is UNSURE but RIGHT)
    low_conf_correct = df[(df['correct']) & (df['confidence'] < 0.6)]
    print(f"Low-confidence correct: {len(low_conf_correct)}")

    # 5. Feature statistics for errors vs correct (first five features)
    print("\nFeature differences (errors vs correct):")
    for feat in feature_names[:5]:
        err_mean = df[~df['correct']][feat].mean()
        cor_mean = df[df['correct']][feat].mean()
        if abs(err_mean - cor_mean) > 0.1:
            print(f"  {feat}: errors={err_mean:.3f}, correct={cor_mean:.3f}")
    return df


analysis_df = error_analysis(model, X_test, y_test, feature_names, class_names)

Pattern 7: LIME Explanations
from lime.lime_tabular import LimeTabularExplainer
# Create LIME explainer
explainer = LimeTabularExplainer(
    X_train,
    feature_names=feature_names,
    class_names=class_names,
    mode='classification'
)
# Explain single prediction
explanation = explainer.explain_instance(
    X_test[0],
    model.predict_proba,
    num_features=10,
    top_labels=1
)
# Show in notebook
explanation.show_in_notebook()
# Or get as list. With top_labels=1 only the top predicted label is
# explained, so request that label explicitly (as_list defaults to label 1)
top_label = explanation.available_labels()[0]
print("Feature contributions:")
for feature, weight in explanation.as_list(label=top_label):
    direction = "+" if weight > 0 else "-"
    print(f"  {direction} {feature}: {weight:.4f}")

Reference Navigation
For detailed content, see:
- SHAP Explainability: reference/shap_explainability.md - All explainer types, Shapley theory
- Feature Importance: reference/feature_importance.md - MDI, Permutation, comparison
- Model Debugging: reference/model_debugging.md - Error analysis, failure modes
- Visualization Tools: reference/visualization_tools.md - Grad-CAM, attention, confusion matrix
Common Mistakes to Avoid
1. Using MDI Feature Importance Alone
# WRONG: MDI is biased toward high-cardinality features
importance = model.feature_importances_  # Can be misleading!
# CORRECT: Always validate with permutation importance or SHAP
from sklearn.inspection import permutation_importance
perm_importance = permutation_importance(model, X_test, y_test)

2. Wrong Background Data for SHAP
# WRONG: Using test data as background
explainer = shap.KernelExplainer(model.predict, X_test)  # Baseline tied to the evaluation set, and slow on a full dataset
# CORRECT: Use training data (or sample)
explainer = shap.KernelExplainer(model.predict, shap.sample(X_train, 100))

3. Interpreting Correlation as Causation
SHAP tells you: "Feature X contributes +0.3 to prediction"
This means: "Relative to the base value, X's value moved this prediction by +0.3"
This does NOT mean: "X causes the outcome"
Always combine SHAP with domain knowledge!

Teaching Mode
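To make the correlation-versus-causation caveat from Common Mistakes concrete, here is a minimal synthetic sketch. It assumes a hypothetical data-generating process in which only `z` drives the outcome and `x` is a noisy proxy of `z`. The variable names, coefficients, and the `true_outcome` helper are invented for illustration. Because the model has a single feature, its SHAP value is exactly the prediction minus the mean prediction, so no explainer library is needed.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5000

# Hypothetical process: only z causes y; x is a noisy proxy of z
z = rng.normal(size=n)
x = z + 0.3 * rng.normal(size=n)
y = 2.0 * z + 0.1 * rng.normal(size=n)

# Fit a linear model that only sees the proxy x (z is unobserved)
design = np.column_stack([x, np.ones(n)])
coef, intercept = np.linalg.lstsq(design, y, rcond=None)[0]

# With one feature, the SHAP value is f(x) - E[f(x)] = coef * (x - mean(x))
phi_x = coef * (x - x.mean())


def true_outcome(z_val, x_val):
    """Noise-free outcome under the assumed process; x_val is deliberately unused."""
    return 2.0 * z_val


# Intervene on x while holding z fixed: the outcome does not move
causal_effect = true_outcome(1.0, 5.0) - true_outcome(1.0, 0.0)

print(f"model coefficient on x:       {coef:.2f}")
print(f"mean |SHAP| for x:            {np.abs(phi_x).mean():.2f}")
print(f"true causal effect of x on y: {causal_effect:.2f}")
```

The model assigns large attributions to `x` because `x` predicts `y`, yet changing `x` directly would leave `y` untouched under the assumed process. SHAP describes the model's behavior, not the mechanism that generated the data.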
Interpretability Spectrum
More Interpretable ◀─────────────────────▶ Less Interpretable
(Black Box)
┌──────────┬──────────┬──────────┬──────────┬──────────┐
│ Linear │ Decision │ Random │ Neural │ Deep │
│ Reg. │ Tree │ Forest │ Network │ LLM │
│ │ │ │ │ │
│ Coeff = │ Rules │ Feature │ Need │ Need │
│ feature │ visible │ import. │ SHAP / │ Attention│
│ import. │ │ │ Grad-CAM │ / Probes │
└──────────┴──────────┴──────────┴──────────┴──────────┘
Rule: As model complexity increases,
interpretability effort increases too.

Cross-References
- ML evaluation: ../ml-fundamentals/SKILL.md - Metrics, cross-validation
- CNN details: ../cnn-vision/SKILL.md - For Grad-CAM architecture context
- Transformers: ../transformers-llm/SKILL.md - For attention visualization
- MLOps: ../mlops-experiment/SKILL.md - Logging interpretability artifacts