This skill should be used when the user asks to "enable torch.compile", "add compilation to training", "kernel fusion", "TorchDynamo integration", "TorchInductor optimization", "reduce-overhead mode", "max-autotune mode", "fix graph breaks", "compile health check", "shape stabilization", "dynamic shapes for compile", "bucketing for torch.compile", "compile allowlist", "compile blocklist", "compile smoketest", "CUDA graphs for training", "maybe_compile wrapper", "debug recompiles", "TORCH_LOGS compile", "compile + DDP", "compile + FSDP", "torch.compiler.disable", or needs guidance on torch.compile integration, shape management, compilation debugging, or safe fallback patterns.
Install
npx skillscat add sovr610/refffiy/compiler-kernel-fusion-torch-compile-integration
Install via the SkillsCat registry.
Compiler & Kernel Fusion (torch.compile) Integration
Overview
Guide implementation of an opt-in compilation layer using torch.compile (TorchDynamo + AOTAutograd + TorchInductor) for training and inference speedup. Cover the safe wrapper pattern (never breaks training), compilation modes, shape stabilization (bucketing + mark_dynamic), health checks, allowlist/blocklist for selective compilation, distributed training interactions (DDP/FSDP), debugging, and integration with the benchmark harness.
Design principle: Safe by default. If compilation fails or regresses, fall back to eager with explicit logs.
Public Contract
maybe_compile
Safe wrapper — the single entry point for all compilation.
def maybe_compile(
    model: nn.Module,
    cfg: CompileConfig,
    logger: logging.Logger,
    sample_batch: Optional[Dict] = None,
) -> nn.Module: ...
CompileConfig
All compilation flags in one place.
@dataclass
class CompileConfig:
    enabled: bool = False
    mode: str = "default"  # default | reduce-overhead | max-autotune
    dynamic: Optional[bool] = None  # None=auto, False=static, True=force dynamic
    backend: str = "inductor"
    fullgraph: bool = False
    options: Optional[Dict] = None
    allowlist: Sequence[str] = ()  # Sequence, so the immutable tuple default type-checks
    blocklist: Sequence[str] = ()
    healthcheck: bool = True
    fail_policy: str = "fallback_eager"  # fallback_eager | raise
ShapeStabilizer
Prevent recompile thrashing from variable sequence lengths.
class ShapeStabilizer:
    # Sequence rather than List, so the tuple default matches the annotation
    def __init__(self, buckets: Sequence[int] = (256, 512, 1024, 2048)): ...
    def bucket_batch(self, input_ids: Tensor) -> Tensor: ...
    def mark_dynamic_dims(self, batch: Dict[str, Tensor], dim: int = 1): ...
CompileSmoketest
Validate compilation works before committing to a full run.
class CompileSmoketest:
    def __init__(self, steps: int = 3): ...
    def run(self, model: nn.Module, sample_batch: Dict,
            optimizer: Optimizer, loss_fn: Callable) -> bool: ...
Key Concepts
Compilation Modes
| Mode | Behavior | Trade-off |
|---|---|---|
| default | Basic TorchInductor fusion | Fastest compile, moderate speedup |
| reduce-overhead | Enables CUDA graphs | Lower overhead per step, more memory, stricter shape requirements |
| max-autotune | Autotuned kernels + CUDA graphs | Slowest compile, best steady-state throughput |
The Safe Wrapper Pattern
maybe_compile(model, cfg, logger, sample_batch)
    1. Return model unchanged if cfg.enabled=False
    2. Try torch.compile(model, backend, mode, dynamic, fullgraph, options)
    3. If compile raises: log exception, return eager model
    4. If healthcheck=True and sample_batch provided:
        Run 1-3 training steps (forward → loss → backward → optimizer.step)
        If smoketest fails: log error, return eager model
    5. Return compiled model
Never hard-crash by default. Log: compile enabled, backend, mode, dynamic, fullgraph, success/failure, compile wall time, exception + stack on failure.
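The steps above can be sketched in Python as follows. This is a minimal illustration under stated assumptions, not the contents of assets/compile_wrap_template.py: the `smoketest` and `compile_fn` parameters are hypothetical hooks added here so the fallback logic can be exercised without a GPU, and `CompileConfig` is trimmed to the fields the function reads. Two details are worth keeping in mind: torch.compile raises if `mode` and `options` are both passed, and compilation is lazy, so most failures only surface on the first call of the compiled model rather than at wrap time.

```python
import logging
import time
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class CompileConfig:
    """Trimmed copy of the config, reduced to what this sketch reads."""
    enabled: bool = False
    mode: str = "default"
    dynamic: Optional[bool] = None
    backend: str = "inductor"
    fullgraph: bool = False
    options: Optional[Dict[str, Any]] = None
    healthcheck: bool = True
    fail_policy: str = "fallback_eager"  # fallback_eager | raise


def maybe_compile(model, cfg, logger, sample_batch=None,
                  smoketest=None, compile_fn=None):
    """Return a compiled model, or the untouched eager model on failure.

    `smoketest(compiled, batch) -> bool` and `compile_fn` are hypothetical
    hooks for this sketch; the public contract lists only the first four
    parameters.
    """
    if not cfg.enabled:
        return model
    if compile_fn is None:
        import torch  # lazy import: disabled runs never touch the compiler
        compile_fn = torch.compile

    # torch.compile rejects `mode` and `options` together, so pass only one.
    kwargs = dict(backend=cfg.backend, dynamic=cfg.dynamic,
                  fullgraph=cfg.fullgraph)
    if cfg.options:
        kwargs["options"] = cfg.options
    else:
        kwargs["mode"] = cfg.mode

    logger.info("compile enabled: %s", kwargs)
    start = time.perf_counter()
    try:
        compiled = compile_fn(model, **kwargs)
        # Most compile errors appear on the first call, so the smoketest
        # runs inside the same try block as the wrap itself.
        if cfg.healthcheck and sample_batch is not None and smoketest is not None:
            if not smoketest(compiled, sample_batch):
                raise RuntimeError("compile smoketest failed")
    except Exception:
        # logger.exception records the message plus the stack trace.
        logger.exception("compile failed after %.2fs",
                         time.perf_counter() - start)
        if cfg.fail_policy == "raise":
            raise
        return model
    logger.info("compile succeeded in %.2fs", time.perf_counter() - start)
    return compiled
```

Preferring `options` over `mode` when both are set is one possible policy; rejecting the combination during config validation would be equally reasonable.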
Shape Stabilization (Priority Order)
A) Bucketing (preferred for LLM training)
Bucket examples by length into fixed bins (256/512/1024/2048). Pad within bucket. Result: small set of stable shapes, kernels stay efficient.
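As a rough illustration of the bucketing rule, the sketch below pads a single example to the smallest bucket that fits it. It works on plain Python lists rather than tensors to stay dependency-free; the helper names `pick_bucket` and `pad_to_bucket` are hypothetical, and raising on over-long inputs (instead of truncating) is an assumption, since the text does not specify that policy.

```python
from bisect import bisect_left
from typing import List, Sequence


def pick_bucket(length: int, buckets: Sequence[int]) -> int:
    """Return the smallest bucket >= length; `buckets` must be sorted ascending."""
    i = bisect_left(buckets, length)
    if i == len(buckets):
        raise ValueError(f"length {length} exceeds largest bucket {buckets[-1]}")
    return buckets[i]


def pad_to_bucket(token_ids: List[int],
                  buckets: Sequence[int] = (256, 512, 1024, 2048),
                  pad_token_id: int = 0) -> List[int]:
    """Right-pad one example so its length lands exactly on a bucket boundary."""
    target = pick_bucket(len(token_ids), buckets)
    return token_ids + [pad_token_id] * (target - len(token_ids))


padded = pad_to_bucket([7] * 300)
print(len(padded))  # 512: the smallest bucket that holds 300 tokens
```

With four buckets, the compiled model sees at most four distinct sequence lengths, which bounds the number of shape-driven compilations.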
B) Automatic dynamic shapes (dynamic=None, the default)
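Compilation assumes static shapes first. When a guarded dimension changes, the guard fails and the model is recompiled with that dimension treated as dynamic, so later changes in the same dimension do not trigger further recompiles.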
C) Pre-mark dynamic dimensions (advanced)
torch._dynamo.mark_dynamic(input_ids, 1, min=1, max=max_seq_len)  # second argument is the dimension index, passed positionally
Must be called on input tensors before invoking compiled code (not inside forward).
Avoid: dynamic=True (blanket force) — PyTorch docs explicitly call this "not recommended" / testing-oriented.
Allowlist / Blocklist
Compile only core compute modules (attention, MLP, layernorm, projections). Disable:
- Data preprocessing, tokenization
- Metrics/logging
- Text generation sampling loops (branchy, data-dependent)
- Custom ops not stable under Inductor
Use @torch.compiler.disable on known-problem functions. By default it also disables compilation for everything the function calls; pass recursive=False to disable only the decorated function, not its callees.
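One way to turn the allowlist and blocklist into a concrete selection is glob matching on module names. The sketch below is an assumption about semantics, not something the contract above specifies: patterns are shell-style globs, an empty allowlist means "everything is eligible", and the blocklist always wins. The function name `select_compile_targets` is hypothetical.

```python
from fnmatch import fnmatchcase
from typing import Iterable, List, Sequence


def select_compile_targets(module_names: Iterable[str],
                           allowlist: Sequence[str] = (),
                           blocklist: Sequence[str] = ()) -> List[str]:
    """Return names that match the allowlist (or all names if it is empty)
    and match no blocklist pattern."""
    selected = []
    for name in module_names:
        allowed = not allowlist or any(fnmatchcase(name, p) for p in allowlist)
        blocked = any(fnmatchcase(name, p) for p in blocklist)
        if allowed and not blocked:
            selected.append(name)
    return selected


names = ["blocks.0.attn", "blocks.0.mlp", "sampler", "metrics.accuracy"]
print(select_compile_targets(names, allowlist=["blocks.*"], blocklist=["*.mlp"]))
# ['blocks.0.attn']
```

In a real model the names would typically come from `model.named_modules()`, and each selected submodule would be replaced by its `torch.compile(...)` wrapper while the rest stays eager.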
Debugging Recompiles
| Setting | Purpose |
|---|---|
| TORCH_LOGS=dynamic | Shows guard installation and recompile triggers |
| TORCH_LOGS=perf_hints | Diagnoses CUDA graph applicability |
| TORCH_LOGS=recompiles | Logs when and why recompiles happen |
| torch._dynamo.config.suppress_errors = True | Dev-only: continue on compile errors (not for production) |
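The TORCH_LOGS flags in the table can be combined as a comma-separated list. A small sketch of launching a training command with them set, keeping the parent environment clean; the helper names and the `train.py` entry point in the usage note are hypothetical.

```python
import os
import subprocess
import sys
from typing import Dict, List, Sequence


def debug_env(log_flags: Sequence[str] = ("recompiles", "dynamic")) -> Dict[str, str]:
    """Copy the current environment and add a comma-separated TORCH_LOGS value."""
    env = dict(os.environ)
    env["TORCH_LOGS"] = ",".join(log_flags)
    return env


def run_with_compile_logs(argv: List[str],
                          log_flags: Sequence[str] = ("recompiles",)) -> int:
    """Run a command in a child process with TORCH_LOGS set; return its exit code."""
    return subprocess.run(argv, env=debug_env(log_flags)).returncode
```

For example, `run_with_compile_logs([sys.executable, "train.py"], ("recompiles", "perf_hints"))` would run a hypothetical `train.py` with both log categories enabled, without leaking the variable into later eager runs in the same shell.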
Distributed Training (DDP/FSDP)
torch.compile supports DDP and FSDP. Policy:
- Treat DDP/FSDP + compile as "supported but gated by smoketest"
- If wrapping creates issues: compile only inner transformer blocks, keep wrappers eager
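- DDP: wrap the model with DDP first, then compile (the order the PyTorch DDP documentation recommends, so that Dynamo's DDPOptimizer can split graphs at bucket boundaries); overlap/bucketing interactions may still need tuning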
- FSDP: configuration constraints exist; test with smoketest
Compile Health Check
The smoketest runs 1-3 complete training steps:
- forward → loss → backward → optimizer.step → zero_grad
- Same device/precision/distributed wrapper as real training
- Deterministic mini-batch (fixed tokens) for stability
- Catches: compilation-time failures, backward failures, autograd interactions
First step includes compilation time — benchmark harness must warm up past this.
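A minimal sketch of such a health check is shown below. It is duck-typed so it imports nothing from PyTorch: `model`, `optimizer`, and the loss only need the usual `__call__`, `backward()`, `step()`, and `zero_grad()` methods. Calling the model as `model(**sample_batch)` and the loss as `loss_fn(outputs, sample_batch)` are assumptions; the contract above fixes only the parameter names. Treating a non-finite loss as a failure is also an addition of this sketch.

```python
import logging
import math
from typing import Any, Callable, Dict

log = logging.getLogger(__name__)


class CompileSmoketest:
    """Run a few full training steps and report whether all of them succeeded."""

    def __init__(self, steps: int = 3):
        self.steps = steps

    def run(self, model: Callable[..., Any], sample_batch: Dict[str, Any],
            optimizer: Any, loss_fn: Callable[[Any, Dict[str, Any]], Any]) -> bool:
        try:
            for step in range(self.steps):
                outputs = model(**sample_batch)        # first call triggers compilation
                loss = loss_fn(outputs, sample_batch)  # assumed signature
                loss.backward()                        # exercises the compiled backward
                optimizer.step()
                optimizer.zero_grad()
                if not math.isfinite(float(loss)):     # NaN/inf counts as failure
                    log.error("smoketest step %d produced a non-finite loss", step)
                    return False
        except Exception:
            log.exception("smoketest failed")
            return False
        return True
```

Running more than one step matters: the first step compiles, while later steps confirm that the compiled forward and backward can be replayed after the optimizer has changed the parameters.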
Benchmark Integration
Run same benchmark twice: eager vs compiled. Report:
- Compile time (time-to-first-step)
- Steady-state tokens/sec (after warmup)
- Peak memory delta
- Recompile count
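The first two report items reduce to simple arithmetic over per-step wall times, sketched below. The names `RunStats`, `summarize`, and `speedup` are hypothetical, and the sketch assumes the caller has already collected step times in seconds (on GPU, with a device synchronization before each timestamp). Peak memory and recompile counts need framework-specific counters and are left out.

```python
from dataclasses import dataclass
from statistics import mean
from typing import Sequence


@dataclass
class RunStats:
    time_to_first_step_s: float  # includes compilation for the compiled run
    tokens_per_sec: float        # steady state, warmup steps excluded


def summarize(step_times_s: Sequence[float], tokens_per_step: int,
              warmup_steps: int = 5) -> RunStats:
    """Summarize per-step wall times (seconds) from one benchmark run."""
    steady = step_times_s[warmup_steps:]
    if not steady:
        raise ValueError("need more steps than warmup_steps")
    return RunStats(step_times_s[0], tokens_per_step / mean(steady))


def speedup(eager: RunStats, compiled: RunStats) -> float:
    """Steady-state throughput ratio; values above 1.0 mean compilation helped."""
    return compiled.tokens_per_sec / eager.tokens_per_sec
```

Reporting time-to-first-step separately from steady-state throughput keeps a slow compile from hiding a real speedup, and a fast compile from hiding a regression.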
Configuration Surface
@dataclass
class CompileConfig:
    enabled: bool = False
    mode: str = "default"
    dynamic: Optional[bool] = None
    backend: str = "inductor"
    fullgraph: bool = False
    options: Optional[Dict] = None  # epilogue_fusion, shape_padding, etc.
    allowlist: Sequence[str] = ()  # Sequence, so the immutable tuple default type-checks
    blocklist: Sequence[str] = ()
    healthcheck: bool = True
    fail_policy: str = "fallback_eager"  # fallback_eager | raise
    smoketest_steps: int = 3

@dataclass
class BucketConfig:
    buckets: Sequence[int] = (256, 512, 1024, 2048)
    pad_token_id: int = 0
Done-When Gates
- Safe Fallback: With enabled=True and a deliberately broken model (e.g., unsupported op), maybe_compile catches the error, logs it, and returns the original eager model without crashing.
- Smoketest Detects Failure: CompileSmoketest.run() returns False when backward fails on a compiled model and True when compilation succeeds end-to-end.
- Shape Stability: ShapeStabilizer.bucket_batch() produces padded tensors matching bucket boundaries; compiled model runs without recompiles across all bucket sizes.
Resources
Reference Files
- references/compile-modes.md: default/reduce-overhead/max-autotune behavior, CUDA graphs, backend options, mode selection guide
- references/shape-stabilization.md: Bucketing algorithm, mark_dynamic API, recompile avoidance, guard system, automatic dynamic shapes
- references/debugging-profiling.md: TORCH_LOGS flags, graph break diagnosis, suppress_errors, perf_hints, compile time measurement
- references/distributed-compile.md: DDP + compile, FSDP + compile, wrapping order, overlap/bucketing, selective block compilation
- references/testing-matrix.md: Test scenarios for all components
Asset Files
- assets/compile_config_template.py: CompileConfig, BucketConfig dataclasses with validation
- assets/compile_wrap_template.py: maybe_compile safe wrapper with logging, timing, fallback
- assets/shape_stabilize_template.py: ShapeStabilizer with bucketing, mark_dynamic, pad/unpad
- assets/compile_smoketest_template.py: CompileSmoketest with forward-backward-optimizer validation
- assets/compile_benchmark_template.py: Eager vs compiled benchmark comparison, report generation
Scripts
- scripts/validate_compile.py: Validates done-when gates
- scripts/gen_compile_tests.py: Generates 100+ pytest test cases
- scripts/compile_playbook.py: Generates docs/torch_compile.md playbook with gotchas and debug knobs