Install
Install via the SkillsCat registry: `npx skillscat add oneboxcream/claude-code-setup/python-ml-patterns`
SKILL.md
Python ML Patterns
环境约束
- Python 3.8.13 — 禁止使用 3.9+ 语法
- 类型注解使用 `typing` 模块:`Dict`、`List`、`Optional`、`Union`、`Tuple`(不要用 3.9+ 的 `list[int]` / `dict[str, int]` 或 3.10+ 的 `X | None`)
- 使用 `from __future__ import annotations` 实现注解的延迟求值
训练代码模式
Config 管理
from dataclasses import dataclass, field
from typing import List, Optional
@dataclass
class TrainConfig:
    """Hyper-parameters for a training run; override any field per run.

    Intended meaning of each field (the training loop that consumes this
    config is not shown here, so confirm against it):
        model_name: Name/identifier of the model being trained.
        batch_size: Samples per forward/backward step.
        learning_rate: Optimizer learning rate.
        num_epochs: Number of training epochs.
        gradient_accumulation_steps: Steps whose gradients are accumulated
            before each optimizer update (presumably giving an effective
            batch of ``batch_size * gradient_accumulation_steps``).
        fp16: Whether to train in half / mixed precision.
        output_dir: Directory for checkpoints and logs.
        seed: Random seed for the run.
    """

    model_name: str = "model_v1"
    batch_size: int = 8
    learning_rate: float = 1e-4
    num_epochs: int = 10
    gradient_accumulation_steps: int = 4
    fp16: bool = True
    output_dir: str = "./outputs"
    seed: int = 42


# ---- 显存管理 (GPU memory management) ----
import gc
import torch
def clear_gpu_memory():
    """Collect unreachable Python objects, then release cached CUDA memory.

    ``gc.collect()`` runs first so that tensors owned by unreachable objects
    are freed before ``torch.cuda.empty_cache()`` hands the caching
    allocator's unused blocks back to the driver.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


# When switching models, release the old one *before* building the new one,
# so the two never occupy GPU memory at the same time.
def switch_model(old_model, new_model_fn, offload_to_cpu: bool = True):
    """Release ``old_model``'s GPU memory, then build and return a new model.

    ``del old_model`` alone only removes this function's local name. The
    caller's variable still references the old model until this call
    returns, so nothing would be freed while ``new_model_fn`` runs. For a
    ``torch.nn.Module`` the parameters and buffers are therefore first moved
    to the CPU (``Module.to`` modifies the module in place). This releases
    their GPU storage even though other references to the module remain.

    Args:
        old_model: The model being replaced; treat it as discarded after the
            call (it may have been moved to the CPU).
        new_model_fn: Zero-argument factory that constructs the new model.
        offload_to_cpu: Move an ``nn.Module`` ``old_model`` to the CPU before
            freeing. Set to False to leave ``old_model`` untouched. Other
            objects holding GPU tensors (e.g. an optimizer's state) are never
            touched and must be dropped by the caller.

    Returns:
        Whatever ``new_model_fn()`` returns.
    """
    if offload_to_cpu and isinstance(old_model, torch.nn.Module):
        old_model.to("cpu")
    del old_model
    clear_gpu_memory()
    return new_model_fn()


# ---- 可复现性 (reproducibility) ----
def set_seed(seed: int, deterministic: bool = False):
    """Seed Python, NumPy and PyTorch RNGs so runs are reproducible.

    Args:
        seed: Seed applied to ``random``, ``numpy.random``, ``torch`` (CPU)
            and, when CUDA is available, every CUDA device.
        deterministic: If True, also make cuDNN choose deterministic kernels
            and disable its auto-tuner (``benchmark``). This trades speed for
            run-to-run identical GPU results. The default (False) leaves the
            cuDNN flags untouched, matching the original behavior.
    """
    import random

    import numpy as np

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


# ---- 数据处理模式 (data processing patterns) ----
Dataset 模板
from torch.utils.data import Dataset
from typing import Dict, Any
class MultiModalDataset(Dataset):
    """Map-style dataset template; subclasses must implement ``_load_data``.

    Args:
        data_path: Dataset location, passed unchanged to ``_load_data``.
        transform: Optional callable applied to every item returned by
            ``__getitem__``.
    """

    def __init__(self, data_path: str, transform=None):
        self.data = self._load_data(data_path)
        self.transform = transform

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        item = self.data[idx]
        # ``is not None`` so a user-supplied callable is applied even if it
        # happens to be falsy.
        if self.transform is not None:
            item = self.transform(item)
        return item

    def _load_data(self, path: str):
        """Return a sized, indexable collection of samples for ``path``.

        To avoid running out of memory, prefer returning a lightweight index
        (e.g. file paths or record offsets). Decode heavy payloads lazily in
        ``__getitem__`` / ``transform`` rather than loading everything here.

        Raises:
            NotImplementedError: Always, in this template. Without it, the
                implicit ``None`` would only fail later, inside ``len()``.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement _load_data(path)"
        )


# ---- 数据验证 (data validation) ----
def validate_dataset(
    dataset, num_samples=5, required_keys=("input_ids", "labels")
):
    """Quickly check the output format of ``dataset``.

    Inspects the first ``num_samples`` items (fewer if the dataset is
    shorter). Each item must be a ``dict`` containing every key in
    ``required_keys``. The shape and dtype of every value that has a
    ``shape`` attribute are printed.

    Args:
        dataset: Object supporting ``len()`` and integer indexing.
        num_samples: Maximum number of leading items to inspect.
        required_keys: Keys every item must contain.

    Raises:
        AssertionError: If an item is not a dict or lacks a required key.
            It is raised explicitly instead of via ``assert``, so the checks
            still run under ``python -O``.
    """
    for i in range(min(num_samples, len(dataset))):
        item = dataset[i]
        if not isinstance(item, dict):
            raise AssertionError(f"Item {i} is not dict")
        # Check required fields.
        for key in required_keys:
            if key not in item:
                raise AssertionError(f"Missing key: {key} in item {i}")
        # Print tensor/array shapes for a quick visual check.
        for key, val in item.items():
            if hasattr(val, 'shape'):
                print(f" {key}: {val.shape} {val.dtype}")


# ---- 评估模式 (evaluation patterns) ----
Metrics 收集
from collections import defaultdict
from typing import Dict, List
class MetricsCollector:
    """Accumulate per-step metric values and report their means.

    Example:
        >>> mc = MetricsCollector()
        >>> mc.update(loss=1.0)
        >>> mc.update(loss=3.0)
        >>> mc.summary()
        {'loss': 2.0}
    """

    def __init__(self):
        # metric name -> list of recorded values
        self.metrics = defaultdict(list)

    def update(self, **kwargs):
        """Record one value per keyword argument, e.g. ``update(loss=0.3)``.

        Zero-dimensional values with an ``.item()`` method (scalar tensors,
        NumPy scalars) are stored as plain Python numbers. This stops the
        collector from keeping GPU memory or autograd graphs alive.
        """
        for k, v in kwargs.items():
            if hasattr(v, "item") and getattr(v, "ndim", None) == 0:
                v = v.item()
            self.metrics[k].append(v)

    def summary(self) -> Dict[str, float]:
        """Return the arithmetic mean of each metric that has values.

        Empty entries (e.g. created by reading ``self.metrics[name]`` on the
        defaultdict) are skipped instead of raising ``ZeroDivisionError``.
        """
        return {k: sum(v) / len(v) for k, v in self.metrics.items() if v}

    def to_table(self) -> str:
        """Render ``summary()`` as a two-column Markdown table."""
        summary = self.summary()
        lines = [f"| {'Metric':<20} | {'Value':>10} |"]
        lines.append(f"|{'-'*22}|{'-'*12}|")
        for k, v in summary.items():
            lines.append(f"| {k:<20} | {v:>10.4f} |")
        return "\n".join(lines)


# ---- 反模式(避免) (anti-patterns to avoid) ----
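- ❌ 在循环中反复 `.to(device)` — 模型和常量张量应在循环外提前移动
- ❌ 不设 `model.eval()` 就推理 — Dropout / BatchNorm 仍按训练模式工作,结果不稳定
- ❌ 忘记 `torch.no_grad()` 导致显存泄漏 — 推理时 autograd 会保留计算图;若把输出累积起来,显存会持续增长
- ❌ 用 `os.path.join` 拼接而不检查路径存在性
- ❌ 硬编码绝对路径 — 应使用 config 或参数传入
- ❌ `from typing import *` — 应显式导入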