OneBoxCream

python-ml-patterns

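Python machine learning code patterns and best practices, covering PyTorch training, data processing, model evaluation, and GPU memory management. Activated automatically when writing ML-related code.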

OneBoxCream 0 Updated 3mo ago

Install

npx skillscat add oneboxcream/claude-code-setup/python-ml-patterns

Install via the SkillsCat registry.

SKILL.md

Python ML Patterns

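Environment constraints

  • Python 3.8.13: do not use syntax introduced in 3.9 or later
  • Type annotations use the typing module: Dict, List, Optional, Union, Tuple
  • Use from __future__ import annotations for postponed evaluation of annotations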
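
As a minimal sketch of what these constraints look like in practice, the snippet below annotates two small functions with typing generics only. The function names (`summarize_lengths`, `split_pair`) and their parameters are hypothetical and exist purely for illustration.

```python
from __future__ import annotations  # annotations are stored as strings, not evaluated at definition time

from typing import Dict, List, Optional, Tuple, Union


def summarize_lengths(
    batches: List[List[int]],
    max_len: Optional[int] = None,
) -> Dict[str, Union[int, float]]:
    """Return the number of sequences and their mean (optionally clipped) length."""
    lengths = [len(b) if max_len is None else min(len(b), max_len) for b in batches]
    count = len(lengths)
    mean = sum(lengths) / count if count else 0.0
    return {"count": count, "mean": mean}


def split_pair(pair: Tuple[str, int]) -> str:
    """Format a (name, value) pair as 'name=value'."""
    name, value = pair
    return f"{name}={value}"
```

The future import only postpones annotations. Expressions that run at import time, such as a type alias written as `list[int]`, still fail on Python 3.8, which is why the typing generics are used throughout.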

Training code patterns

Config management

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class TrainConfig:
    model_name: str = "model_v1"
    batch_size: int = 8
    learning_rate: float = 1e-4
    num_epochs: int = 10
    gradient_accumulation_steps: int = 4
    fp16: bool = True
    output_dir: str = "./outputs"
    seed: int = 42
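
One possible way to use such a config is to build it from a dictionary of overrides and save it alongside the run outputs. The sketch below uses a trimmed copy of the dataclass so that it runs on its own; `config_from_dict`, `save_config`, and the `effective_batch_size` property are hypothetical additions, not part of the original skill.

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass, fields
from typing import Any, Dict


@dataclass
class TrainConfig:
    # Trimmed copy of the config above so the sketch is self-contained.
    batch_size: int = 8
    learning_rate: float = 1e-4
    gradient_accumulation_steps: int = 4
    output_dir: str = "./outputs"

    @property
    def effective_batch_size(self) -> int:
        # Samples contributing to one optimizer step on a single device.
        return self.batch_size * self.gradient_accumulation_steps


def config_from_dict(overrides: Dict[str, Any]) -> TrainConfig:
    """Build a config, rejecting keys that are not declared fields."""
    known = {f.name for f in fields(TrainConfig)}
    unknown = set(overrides) - known
    if unknown:
        raise ValueError(f"Unknown config keys: {sorted(unknown)}")
    return TrainConfig(**overrides)


def save_config(config: TrainConfig, path: str) -> None:
    """Write the config as JSON so a run can be reconstructed later."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(asdict(config), f, indent=2)


cfg = config_from_dict({"batch_size": 16})
with tempfile.TemporaryDirectory() as tmp:
    cfg_path = os.path.join(tmp, "config.json")
    save_config(cfg, cfg_path)
    with open(cfg_path, encoding="utf-8") as f:
        restored = config_from_dict(json.load(f))
```

Rejecting unknown keys catches typos such as `lr` instead of `learning_rate` before a long run starts.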

GPU memory management

import gc
import torch

def clear_gpu_memory():
    gc.collect()
    torch.cuda.empty_cache()

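# Release the old model before loading a new one.
# Note: `del old_model` removes only this function's local name. The memory is
# freed only when no other reference remains, so callers should not keep their
# own reference to the old model.
def switch_model(old_model, new_model_fn):
    del old_model
    clear_gpu_memory()
    return new_model_fn()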
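
The reference-counting behaviour behind the `switch_model` pattern can be checked without a GPU. The sketch below is a dependency-free illustration under CPython semantics: `FakeModel` is a hypothetical stand-in for a real model, and a weak reference is used to observe when the old object is actually released.

```python
import gc
import weakref


class FakeModel:
    """Stand-in for a large model; hypothetical, used only for this demo."""


def switch_model(old_model, new_model_fn):
    del old_model  # removes only this function's local name
    gc.collect()
    return new_model_fn()


model = FakeModel()
probe = weakref.ref(model)  # lets us observe whether the object is still alive

new_model = switch_model(model, FakeModel)
alive_after_switch = probe() is not None  # the caller's `model` still holds it

model = new_model  # rebinding drops the last reference to the old object
gc.collect()
alive_after_rebind = probe() is not None
```

Here the old object survives the call and disappears only once the caller rebinds `model`. With a real model, that means both sets of weights can be resident at the same time unless the caller drops its reference first.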

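Reproducibility

import random

import numpy as np
import torch

def set_seed(seed: int) -> None:
    """Seed the Python, NumPy, and PyTorch random number generators."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)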
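
A quick way to check that seeding has the intended effect is to run the same piece of code twice from the same seed and compare the results. The helper below is a hypothetical sketch: `is_reproducible` and `seed_python` are illustrative names, and `seed_python` seeds only the standard-library RNG so the example has no dependencies. In a real project the `set_seed` function above would be passed as `seed_fn`.

```python
import random
from typing import Any, Callable, List


def seed_python(seed: int) -> None:
    # Stand-in for set_seed: seeds only the stdlib RNG to keep the sketch dependency-free.
    random.seed(seed)


def is_reproducible(
    run: Callable[[], Any],
    seed: int,
    seed_fn: Callable[[int], None] = seed_python,
) -> bool:
    """Run `run` twice, reseeding before each call, and compare the results."""
    seed_fn(seed)
    first = run()
    seed_fn(seed)
    second = run()
    return first == second


def sample_indices() -> List[int]:
    # Toy "experiment": shuffle ten indices.
    indices = list(range(10))
    random.shuffle(indices)
    return indices


reproducible = is_reproducible(sample_indices, seed=42)
```

For tensor outputs, `first == second` is an elementwise comparison, so a comparison such as `torch.equal` would be needed instead.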

Data processing patterns

Dataset template

from torch.utils.data import Dataset
from typing import Dict, Any

class MultiModalDataset(Dataset):
    def __init__(self, data_path: str, transform=None):
        self.data = self._load_data(data_path)
        self.transform = transform

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        item = self.data[idx]
        if self.transform:
            item = self.transform(item)
        return item

    def _load_data(self, path: str):
        # Load lazily to avoid running out of memory
        ...
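
The template leaves `_load_data` open. One possible lazy-loading strategy, assuming the data is stored as JSON Lines (an assumption, since the source does not specify a format), is to index byte offsets once and parse a record only when it is requested. `LazyJsonlRecords` is a hypothetical helper; an object like it could be returned from `_load_data`, because the template only needs `len()` and indexing.

```python
import json
import os
import tempfile
from typing import Any, Dict, List


class LazyJsonlRecords:
    """Index line offsets up front; parse a record only when it is requested."""

    def __init__(self, path: str) -> None:
        self.path = path
        self.offsets: List[int] = []
        with open(path, "rb") as f:
            while True:
                offset = f.tell()
                line = f.readline()
                if not line:
                    break
                if line.strip():  # skip blank lines
                    self.offsets.append(offset)

    def __len__(self) -> int:
        return len(self.offsets)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        # Reopen per access: simple, and avoids sharing one file handle across workers.
        with open(self.path, "rb") as f:
            f.seek(self.offsets[idx])
            return json.loads(f.readline().decode("utf-8"))


with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "data.jsonl")
    with open(demo_path, "w", encoding="utf-8") as f:
        f.write('{"text": "a", "label": 0}\n{"text": "b", "label": 1}\n')
    records = LazyJsonlRecords(demo_path)
    num_records = len(records)
    second = records[1]
```

Only the list of integer offsets stays in memory, at the cost of one file open and one seek per item.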

Data validation

def validate_dataset(dataset, num_samples=5):
    """Quickly check the output format of a dataset."""
    for i in range(min(num_samples, len(dataset))):
        item = dataset[i]
        assert isinstance(item, dict), f"Item {i} is not dict"
        # Check required fields
        for key in ["input_ids", "labels"]:
            assert key in item, f"Missing key: {key} in item {i}"
        # Print the shape and dtype of tensor-like values
        for key, val in item.items():
            if hasattr(val, 'shape'):
                print(f"  {key}: {val.shape} {val.dtype}")
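
Assertions are removed when Python runs with `-O`, and they stop at the first failure. As a hedged alternative, the sketch below collects every problem into a list instead. `find_dataset_problems` and its `required_keys` parameter are hypothetical names, and the toy data uses plain lists rather than tensors so that the example has no dependencies.

```python
from typing import Any, List, Sequence


def find_dataset_problems(
    dataset: Sequence[Any],
    required_keys: Sequence[str] = ("input_ids", "labels"),
    num_samples: int = 5,
) -> List[str]:
    """Return human-readable problems found in the first `num_samples` items."""
    problems: List[str] = []
    for i in range(min(num_samples, len(dataset))):
        item = dataset[i]
        if not isinstance(item, dict):
            problems.append(f"item {i}: expected dict, got {type(item).__name__}")
            continue
        for key in required_keys:
            if key not in item:
                problems.append(f"item {i}: missing key {key!r}")
    return problems


toy = [
    {"input_ids": [1, 2], "labels": [0]},
    {"input_ids": [3]},
    "not a dict",
]
problems = find_dataset_problems(toy)
```

An empty list means the sampled items passed, so the result can drive either a log message or a hard failure, depending on the caller.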

Evaluation patterns

Metrics collection

from collections import defaultdict
from typing import Dict, List

class MetricsCollector:
    def __init__(self):
        self.metrics = defaultdict(list)

    def update(self, **kwargs):
        for k, v in kwargs.items():
            self.metrics[k].append(v)

    def summary(self) -> Dict[str, float]:
        return {k: sum(v) / len(v) for k, v in self.metrics.items()}

    def to_table(self) -> str:
        summary = self.summary()
        lines = [f"| {'Metric':<20} | {'Value':>10} |"]
        lines.append(f"|{'-'*22}|{'-'*12}|")
        for k, v in summary.items():
            lines.append(f"| {k:<20} | {v:>10.4f} |")
        return "\n".join(lines)
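
`summary()` averages the recorded values with equal weight. When each value is itself a per-batch mean and the last batch is smaller, that differs slightly from the per-sample mean. The variant below is a hypothetical sketch (`WeightedMetrics` and its `n` argument are not part of the original skill) that weights each value by the batch size.

```python
from collections import defaultdict
from typing import DefaultDict, Dict


class WeightedMetrics:
    """Accumulate per-batch means weighted by the number of samples in each batch."""

    def __init__(self) -> None:
        self.totals: DefaultDict[str, float] = defaultdict(float)
        self.counts: DefaultDict[str, int] = defaultdict(int)

    def update(self, n: int, **kwargs: float) -> None:
        # `n` is the number of samples the batch-level values were averaged over.
        for k, v in kwargs.items():
            self.totals[k] += v * n
            self.counts[k] += n

    def summary(self) -> Dict[str, float]:
        return {k: self.totals[k] / self.counts[k] for k in self.totals if self.counts[k]}


m = WeightedMetrics()
m.update(n=8, loss=1.0)
m.update(n=2, loss=2.0)  # smaller final batch
weighted = m.summary()["loss"]  # (8 * 1.0 + 2 * 2.0) / 10
```

In this toy case the unweighted mean of the two batch values would be 1.5, while the per-sample mean is 1.2.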

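Anti-patterns (avoid)

  • ❌ Calling .to(device) repeatedly inside a loop: move the model and any fixed tensors once, up front
  • ❌ Running inference without calling model.eval()
  • ❌ Forgetting torch.no_grad() during inference, so autograd keeps building graphs and GPU memory usage grows
  • ❌ Building paths with os.path.join without checking that they exist
  • ❌ Hard-coding absolute paths: pass them in through config or arguments
  • ❌ from typing import *: import names explicitly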
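
The two path-related items can be addressed with one small helper that joins path parts onto a root supplied by the caller and fails early if the result is missing. This is a minimal sketch; `resolve_existing` is a hypothetical name, and the root directory is assumed to come from config or a command-line argument rather than being hard-coded.

```python
import tempfile
from pathlib import Path
from typing import Union


def resolve_existing(root: Union[str, Path], *parts: str) -> Path:
    """Join `parts` onto `root` and fail early if the result does not exist."""
    path = Path(root).joinpath(*parts)
    if not path.exists():
        raise FileNotFoundError(f"Expected path does not exist: {path}")
    return path


with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "train.jsonl").write_text("{}\n", encoding="utf-8")
    found = resolve_existing(tmp, "train.jsonl").name
    try:
        resolve_existing(tmp, "missing.jsonl")
        missing_raised = False
    except FileNotFoundError:
        missing_raised = True
```

Failing at path-resolution time gives a clearer error than a crash deep inside a data loader several minutes into a run.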