Predict before you read

What does PyTorch's tensor.backward() actually compute — and why is it not the full Jacobian?

Think about what shape the gradient of a scalar loss with respect to a matrix parameter must be.

From Tokens to Embodied Minds  ·  Chapter 03 of 36

Backprop, autograd, and the chain rule on GPUs

Tape-based dynamic graphs, not static DAGs

O(sqrt(depth))
peak activation memory with gradient checkpointing, vs O(depth) without
~33%
extra compute from gradient checkpointing: one additional forward pass, recomputed during backward
VJP
vector-Jacobian product: what backward() actually computes

PyTorch's autograd is not a static computation graph that you compile and then execute. It is a tape: every operation on a tensor that requires grad records itself onto the tape during the forward pass, and backward() unwinds the tape in reverse, applying the chain rule at each node. This distinction matters because the graph is rebuilt on every forward pass. That enables dynamic control flow, variable-length sequences, and conditional computation, but it also means that some Python-level mistakes (detaching a tensor early, round-tripping a value through .item() or NumPy and back into a fresh tensor, writing to .data in place) silently corrupt the gradient computation without raising an error. Many production training failures, such as NaN losses at step 5000, gradient explosion on long sequences, or silent precision loss that shows up as a plateau, root-cause to autograd misuse rather than model architecture. Understanding the tape is the prerequisite for debugging any of them.
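To make record-then-unwind concrete, here is a toy scalar tape in plain Python. It is a teaching sketch under simplifying assumptions, not how PyTorch is implemented internally: the `Tape` and `Var` classes are hypothetical, only `+`, `*`, and `tanh` are supported, and PyTorch links nodes through `grad_fn` rather than keeping a flat list.

```python
import math

class Tape:
    """Records ops in forward order; backward() replays them in reverse."""

    def __init__(self):
        # Each record: (output Var, [(input Var, local derivative), ...])
        self.records = []

    def backward(self, output):
        output.grad = 1.0  # seed: d(output)/d(output)
        # Reverse recording order is a valid reverse topological order:
        # a record is always appended after the records that produced its inputs.
        for out, parents in reversed(self.records):
            for inp, local in parents:
                inp.grad += out.grad * local  # chain rule, accumulated over uses


class Var:
    def __init__(self, value, tape):
        self.value, self.grad, self.tape = value, 0.0, tape

    def _record(self, value, parents):
        out = Var(value, self.tape)
        self.tape.records.append((out, parents))
        return out

    def __add__(self, other):
        return self._record(self.value + other.value, [(self, 1.0), (other, 1.0)])

    def __mul__(self, other):
        return self._record(self.value * other.value,
                            [(self, other.value), (other, self.value)])

    def tanh(self):
        t = math.tanh(self.value)
        return self._record(t, [(self, 1.0 - t * t)])


tape = Tape()
x, w = Var(0.5, tape), Var(2.0, tape)
y = (x * w).tanh() + x      # forward: mul, tanh, add are appended to the tape
tape.backward(y)            # reverse: add, tanh, mul
# x.grad = w * (1 - tanh(x*w)^2) + 1 and w.grad = x * (1 - tanh(x*w)^2)
```

Because `x` is used twice (inside the product and in the final sum), its gradient is the sum of two contributions, which is exactly why the update uses `+=`.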

How the tape works

Every tensor produced by an operation on grad-requiring inputs has a grad_fn attribute pointing to the backward node of the operation that created it; leaf tensors you create directly, such as parameters, have grad_fn=None. During the forward pass, each such operation creates a node that saves whatever its backward needs and links, via grad_fn.next_functions, to the nodes that produced its inputs, so the graph lives in these links rather than in a single global structure. Each node's backward computes the VJP for its operation. When you call loss.backward(), PyTorch traverses this graph in reverse topological order and accumulates gradients into the .grad attributes of leaf tensors. The graph's saved tensors are freed after backward() by default; pass retain_graph=True only if you genuinely need multiple backward passes through the same graph.
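The sketch below inspects the graph PyTorch records and checks that `backward(gradient=v)` returns the vector-Jacobian product vᵀJ, shaped like the parameter, rather than J itself. The function `f` and the vector `v` are arbitrary illustrative choices; `torch.autograd.functional.jacobian` is used only to build the full Jacobian for comparison.

```python
import torch

torch.manual_seed(0)
W = torch.randn(3, 4, requires_grad=True)   # leaf: created directly, so grad_fn is None
x = torch.randn(4)                           # plain input, not tracked

def f(W):
    return torch.tanh(W @ x)                 # output shape (3,)

y = f(W)
print(W.is_leaf, W.grad_fn)                  # True None
print(type(y.grad_fn).__name__)              # e.g. TanhBackward0, the node that created y
print(y.grad_fn.next_functions)              # link back to the node for W @ x

# A non-scalar output needs an explicit vector v; backward() then computes
# the vector-Jacobian product v^T J and stores it in W.grad, shaped like W.
v = torch.tensor([1.0, -2.0, 0.5])
y.backward(gradient=v)
vjp = W.grad.clone()

# For comparison only: the full Jacobian has shape (3, 3, 4), output dims x W dims.
J = torch.autograd.functional.jacobian(f, W)
print(J.shape, vjp.shape)                    # torch.Size([3, 3, 4]) torch.Size([3, 4])
print(torch.allclose(vjp, torch.einsum("i,ijk->jk", v, J), atol=1e-6))
```

For a scalar loss, v is implicitly 1.0, which is why loss.backward() takes no argument and every .grad has exactly its parameter's shape.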

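The key correctness rule: if you cut a tensor out of the graph (tensor.detach(), tensor.data, or a round trip through NumPy or .item()), gradients will not flow through it. Stopping gradients is intentional for frozen layers, but note that LoRA freezes the base model by setting requires_grad=False on its weights, not by detaching activations: gradients must still flow through the frozen layers to reach the adapters beneath them. A detach is a bug when it happens accidentally inside a loss computation. The symptom is that parameters upstream of the detach receive no gradient from that path (their .grad stays None or zero), while the loss still computes to a finite value and backward() succeeds as long as some other parameter keeps the loss attached to the graph. That is why the bug is silent.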
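A minimal reproduction of the silent-detach failure, assuming a toy two-layer model (layer sizes and the `buggy` flag are illustrative). With the detach in place, the loss is finite and backward() succeeds, yet the first layer never receives a gradient. The second half shows the freezing pattern LoRA relies on: requires_grad=False on a middle layer still lets gradients pass through its activations to a trainable layer below.

```python
import math
import torch
import torch.nn as nn

def run(buggy: bool):
    torch.manual_seed(0)
    first, second = nn.Linear(8, 8), nn.Linear(8, 1)
    h = torch.relu(first(torch.randn(4, 8)))
    if buggy:
        h = h.detach()          # accidental: cuts the graph between the two layers
    loss = second(h).pow(2).mean()
    loss.backward()             # succeeds: `second` still keeps loss attached
    return loss.item(), first.weight.grad, second.weight.grad

loss, g_first, g_second = run(buggy=True)
print(math.isfinite(loss), g_first, g_second is not None)   # True None True

# Correct freezing: the middle layer's weights are frozen, but its activations
# still carry gradient down to the trainable layer beneath it.
torch.manual_seed(0)
bottom, frozen, top = nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 1)
frozen.requires_grad_(False)
out = top(torch.relu(frozen(torch.relu(bottom(torch.randn(4, 8))))))
out.pow(2).mean().backward()
print(frozen.weight.grad, bottom.weight.grad is not None)   # None True
```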

Custom backward functions via torch.autograd.Function are the standard way to implement fused ops: ops whose forward pass uses a memory-efficient implementation (e.g., FlashAttention computes attention tile by tile in on-chip SRAM rather than writing the full attention matrix to HBM) while the backward pass uses a separately optimized gradient computation. The backward() method receives the upstream gradient (grad_output) and must return one value per input to forward(): a gradient for each input that requires grad, and None for the rest.
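A minimal torch.autograd.Function, assuming SiLU (x · sigmoid(x)) as a stand-in for a fused op. It shows the recompute-instead-of-store idea in miniature: forward saves only the input, and backward recomputes sigmoid(x) from it. The class name `RecomputeSiLU` is hypothetical; PyTorch already ships `torch.nn.functional.silu`, used here only as a reference.

```python
import torch
import torch.nn.functional as F

class RecomputeSiLU(torch.autograd.Function):
    """Hypothetical example op: y = x * sigmoid(x), saving only x."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)            # store the input, not sigmoid(x)
        return x * torch.sigmoid(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        s = torch.sigmoid(x)                # recomputed instead of stored
        # d/dx [x * s(x)] = s + x * s * (1 - s); one return value per forward input
        return grad_output * (s + x * s * (1 - s))

x = torch.randn(5, requires_grad=True)
RecomputeSiLU.apply(x).sum().backward()
ours = x.grad.clone()

x.grad = None
F.silu(x).sum().backward()                  # PyTorch's built-in SiLU, for comparison
print(torch.allclose(ours, x.grad, atol=1e-6))
```

If forward() also took a non-tensor argument (say, a flag), backward() would return None in that position.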

Gradient checkpointing — the memory-compute trade

Without checkpointing, peak activation memory during training grows linearly with the number of transformer layers, because every intermediate activation from the forward pass must stay alive until its gradient is computed. For a 70B-parameter model with 80 layers at batch size 4, that can mean tens to hundreds of gigabytes of activations, depending on sequence length, held in GPU memory alongside weights and optimizer state. Gradient checkpointing (Chen et al., arXiv:1604.06174, April 2016) breaks this by keeping only the activations at segment boundaries during the forward pass and recomputing the rest, one segment at a time, during the backward pass. With about sqrt(depth) segments, peak activation memory drops from O(depth) to O(sqrt(depth)). The cost is one additional forward pass per training step; since the backward pass costs roughly twice the forward, that is roughly 33% extra compute.
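The O(sqrt(depth)) figure falls out of simple arithmetic. With n layers split into segments of k layers, you keep about n/k boundary activations for the whole backward pass plus the k activations of whichever segment is being recomputed, so peak ≈ n/k + k, minimized near k = sqrt(n). The sketch counts memory in units of one layer's output, a simplifying assumption that ignores per-layer internals, weights, and optimizer state; n = 80 matches the example above.

```python
import math

def peak_activations(n_layers: int, segment_len: int) -> int:
    """Layer outputs alive at peak: saved segment boundaries plus one recomputed segment."""
    boundaries = math.ceil(n_layers / segment_len)
    return boundaries + segment_len

def best_segment_len(n_layers: int) -> int:
    # Brute-force the segment length that minimizes n/k + k (lands near sqrt(n)).
    return min(range(1, n_layers + 1), key=lambda k: peak_activations(n_layers, k))

n = 80
k = best_segment_len(n)
print(f"segment length {k}: {peak_activations(n, k)} layer outputs at peak, "
      f"vs {n} without checkpointing")

# Compute: backward costs roughly 2x forward, so one extra forward adds F / (F + 2F).
forward_cost, backward_cost = 1.0, 2.0
overhead = forward_cost / (forward_cost + backward_cost)
print(f"extra compute: {overhead:.0%}")
```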

The PyTorch entry point is torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=False); recent releases expect use_reentrant to be passed explicitly, and the documentation recommends the non-reentrant variant. For the JHU humanoid training pipeline, this is the primary knob for fitting GR00T N1.5 fine-tuning onto a single A100 80GB: without checkpointing, the DiT activation memory alone exceeds the GPU budget at reasonable batch sizes.
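A sketch of per-block checkpointing, assuming a toy stack of MLP blocks (`Stack` is a hypothetical name) in place of real transformer or DiT layers. Each block keeps only its input during the forward pass and is rerun during backward; the resulting gradients match the un-checkpointed run, only the memory and compute profile changes.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Stack(nn.Module):
    def __init__(self, depth=6, width=64):
        super().__init__()
        self.blocks = nn.ModuleList(
            nn.Sequential(nn.Linear(width, width), nn.GELU()) for _ in range(depth))
        self.use_ckpt = True

    def forward(self, x):
        for block in self.blocks:
            if self.use_ckpt and self.training:
                # Keep only the block's input; rerun block(x) during backward.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x

torch.manual_seed(0)
model = Stack()
x = torch.randn(16, 64)

model(x).square().mean().backward()
ckpt_grads = [p.grad.clone() for p in model.parameters()]

# Same step without checkpointing, for comparison.
model.zero_grad()
model.use_ckpt = False
model(x).square().mean().backward()
print(all(torch.allclose(a, p.grad, atol=1e-6)
          for a, p in zip(ckpt_grads, model.parameters())))
```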

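Selective checkpointing recomputes only the parts of each layer whose activations are large but cheap to recompute, classically the attention core (the QK^T scores, softmax, and dropout, whose saved activations grow quadratically with sequence length), while keeping FFN and projection activations. It recovers most of the memory benefit at much lower compute overhead. This strategy is widely used in production training stacks, including Megatron-LM (as selective activation recomputation) and the CS336 training framework.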
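A sketch of the selective idea, assuming a simplified single-head pre-norm block (`Block` and `attention_core` are hypothetical names, not Megatron-LM's code). Only the score/softmax/weighted-sum core runs under `checkpoint`; the QKV projection, output projection, and FFN keep their activations as usual.

```python
import math
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

def attention_core(q, k, v):
    # Without checkpointing, autograd saves the (seq x seq) softmax output for backward.
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return torch.softmax(scores, dim=-1) @ v

class Block(nn.Module):
    """Hypothetical single-head pre-norm block with optional selective checkpointing."""

    def __init__(self, d=32, selective=True):
        super().__init__()
        self.norm1, self.norm2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.qkv, self.proj = nn.Linear(d, 3 * d), nn.Linear(d, d)
        self.ffn = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))
        self.selective = selective

    def forward(self, x):
        q, k, v = self.qkv(self.norm1(x)).chunk(3, dim=-1)
        if self.selective:
            # Only q, k, v are kept; scores and softmax are recomputed in backward.
            a = checkpoint(attention_core, q, k, v, use_reentrant=False)
        else:
            a = attention_core(q, k, v)
        x = x + self.proj(a)
        return x + self.ffn(self.norm2(x))  # FFN activations are saved as usual

torch.manual_seed(0)
block = Block()
x = torch.randn(2, 128, 32)
block(x).sum().backward()
selective_grads = [p.grad.clone() for p in block.parameters()]
```

The FFN is left alone because its matmuls are expensive to rerun relative to the memory they would free, which is the trade the paragraph describes.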

Debugging the gradient graph

torch.autograd.gradcheck(func, inputs) numerically verifies that the analytical gradient matches a finite-difference estimate. Run it on every custom backward you write before benchmarking it. The standard workflow: implement the forward pass, implement backward, call gradcheck on float64 inputs that have requires_grad=True (finite differences need the extra precision; gradcheck has no dtype argument), and fix any discrepancies before moving to FP16/BF16. If an op passes gradcheck in FP64 but its BF16 gradients drift from an FP64 reference beyond the expected rounding error, the problem is numerical precision in the backward implementation (for example, accumulating a reduction in low precision), not a logical bug.
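The workflow above, sketched on a hypothetical custom op `ScaledTanh` computing tanh(a · x) with a learnable scalar scale a. Step 1 is a real gradcheck in float64, which raises on mismatch by default. Step 2 is not a second gradcheck: it compares BF16 gradients against the FP64 reference, and the 5e-2 tolerance is an assumed budget that you would tune per op.

```python
import torch
from torch.autograd import gradcheck

class ScaledTanh(torch.autograd.Function):
    """Hypothetical custom op: y = tanh(a * x) with a learnable scalar scale a."""

    @staticmethod
    def forward(ctx, x, a):
        t = torch.tanh(a * x)
        ctx.save_for_backward(x, a, t)
        return t

    @staticmethod
    def backward(ctx, grad_out):
        x, a, t = ctx.saved_tensors
        du = grad_out * (1 - t * t)               # gradient w.r.t. u = a * x
        grad_x = du * a
        grad_a = (du * x).sum().reshape(a.shape)  # a was broadcast, so reduce back
        return grad_x, grad_a

torch.manual_seed(0)
x = torch.randn(6, dtype=torch.float64, requires_grad=True)
a = torch.tensor([0.7], dtype=torch.float64, requires_grad=True)

# Step 1: logic check in float64 against finite differences.
print(gradcheck(ScaledTanh.apply, (x, a)))

# Step 2: precision check, BF16 gradients vs the FP64 reference.
def grads(dtype):
    xd = x.detach().to(dtype).requires_grad_()
    ad = a.detach().to(dtype).requires_grad_()
    ScaledTanh.apply(xd, ad).sum().backward()
    return xd.grad.double(), ad.grad.double()

(ref_x, ref_a), (bf_x, bf_a) = grads(torch.float64), grads(torch.bfloat16)
# Assumed tolerance for BF16 rounding (about 3 significant digits); tune per op.
ok = all(torch.allclose(b, r, rtol=5e-2, atol=5e-2)
         for b, r in [(bf_x, ref_x), (bf_a, ref_a)])
print("bf16 within tolerance:", ok)
```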

torch.autograd.set_detect_anomaly(True) (or the torch.autograd.detect_anomaly() context manager) checks the output of every backward function and raises an error as soon as one produces NaN. It also records a stack trace for every forward op, so it adds substantial overhead and should not be left on in production, but it is invaluable for localizing the first op that produces a bad gradient. The error message includes the Python stack trace of the forward op that created the offending node.
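A minimal way to watch anomaly mode localize a bad gradient, assuming the classic sqrt-at-zero case: the forward value is finite, but sqrt's backward divides by 2·sqrt(x), so a zero upstream gradient becomes 0/0. Without anomaly mode the NaN lands silently in x.grad; inside detect_anomaly() backward raises a RuntimeError at the offending node.

```python
import torch

x = torch.zeros(3, requires_grad=True)

# Without anomaly mode: backward finishes and the gradient is silently NaN.
(torch.sqrt(x) * 0).sum().backward()
silent_grad = x.grad.clone()
print(silent_grad)

# With anomaly mode: backward raises at the first node whose output contains NaN,
# and PyTorch also prints the traceback of the forward call that created that node.
x.grad = None
caught = None
try:
    with torch.autograd.detect_anomaly():
        (torch.sqrt(x) * 0).sum().backward()
except RuntimeError as err:
    caught = err
print("caught:", caught)
```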

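For the JHU humanoid capstone, autograd correctness matters in three places: (1) the FlashAttention kernel in the VLM backbone, a custom backward that recomputes the attention probabilities tile by tile from the saved Q, K, V and softmax statistics rather than storing the full attention matrix; (2) any learned reward shaping you add to the RL loop in Isaac Lab, which must be differentiable only if you backpropagate through the reward itself (as in differentiable-simulation methods), since policy-gradient methods such as PPO treat rewards as constants; (3) the action-head objectives in GR00T N1.5 training, namely the flow-matching action loss and the FLARE future-latent alignment loss. Flow matching regresses a predicted velocity onto a target at one sampled noise level per example, so training backpropagates through a single denoiser call, not through the full denoising chain.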
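To show where autograd enters a flow-matching objective, here is a generic sketch under the rectified-flow convention (x_t = (1 - t) · noise + t · action, target velocity = action - noise). The `velocity_net` MLP, the sizes, and the conditioning vector are hypothetical stand-ins; GR00T N1.5's actual action head, conditioning, and time parameterization may differ. The point is structural: each training step differentiates through one network call at one sampled t.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
action_dim, cond_dim = 7, 16   # hypothetical sizes

# Hypothetical stand-in for an action head: predicts velocity from (x_t, t, condition).
velocity_net = nn.Sequential(
    nn.Linear(action_dim + 1 + cond_dim, 64), nn.SiLU(), nn.Linear(64, action_dim))

def flow_matching_loss(actions, cond):
    noise = torch.randn_like(actions)
    t = torch.rand(actions.shape[0], 1)          # one noise level per example
    x_t = (1 - t) * noise + t * actions          # point on the straight noise-to-action path
    target = actions - noise                     # constant velocity along that path
    pred = velocity_net(torch.cat([x_t, t, cond], dim=-1))
    return (pred - target).pow(2).mean()

loss = flow_matching_loss(torch.randn(32, action_dim), torch.randn(32, cond_dim))
loss.backward()   # differentiates through one velocity_net call, no sampling loop
```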

For DealLens, the autograd connection is less direct but still relevant: if you fine-tune a scoring head on top of a frozen LLM backbone, verify that requires_grad=False is set on every base model parameter before the first optimizer step. A single missed parameter means autograd computes and stores a gradient for it (backpropagating through every layer above it to do so), wasting memory; if that parameter was also handed to the optimizer, it gets updated, so the supposedly frozen backbone silently drifts and the optimizer allocates state for it.
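A sketch of that check, assuming a toy `backbone` and `head` as hypothetical stand-ins for the frozen LLM and the DealLens scoring head. Two habits help: assert that no backbone parameter requires grad, and build the optimizer only from parameters that do, so a missed flag cannot quietly put backbone weights under the optimizer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-ins: a small "backbone" for the frozen LLM, a linear scoring head.
backbone = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32))
head = nn.Linear(32, 1)
model = nn.Sequential(backbone, head)

backbone.requires_grad_(False)
leaked = [name for name, p in backbone.named_parameters() if p.requires_grad]
assert not leaked, f"backbone parameters still trainable: {leaked}"

# Build the optimizer from trainable parameters only.
trainable = [p for p in model.parameters() if p.requires_grad]
opt = torch.optim.AdamW(trainable, lr=1e-3)

before = [p.detach().clone() for p in backbone.parameters()]
x, y = torch.randn(8, 32), torch.randn(8, 1)
nn.functional.mse_loss(model(x), y).backward()
opt.step()

unchanged = all(torch.equal(b, p) for b, p in zip(before, backbone.parameters()))
print(len(trainable), unchanged)   # 2 True
```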

When to use torch.compile vs custom autograd

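torch.compile (PyTorch 2.0+) fuses ops automatically (its default TorchInductor backend generates Triton kernels on GPU) and is the right first choice. Use a custom torch.autograd.Function only when the backward pass itself must differ from what autograd would derive, as in FlashAttention, which recomputes attention tiles instead of saving them, not merely when you want the same computation to run faster.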

[Figure: PyTorch Autograd, Tape Construction and Reversal. Forward (tape records): Input x → Linear → RMSNorm → Softmax → Loss. Backward (tape unwinds, VJPs): dL/dLoss → dL/dSoftmax → dL/dNorm → dL/dW. Annotations: gradient checkpointing discards activations after forward and recomputes them during backward, O(sqrt(depth)) peak memory; custom ops implement torch.autograd.Function and are verified with gradcheck in float64 before BF16/FP16.]
Figure 3.1. The PyTorch tape records ops during the forward pass; backward() unwinds it in reverse, applying VJPs to accumulate gradients. Gradient checkpointing reduces peak memory by discarding and recomputing activations.
Retrieve before you continue

Three questions on what you just read

Q1 (Factual): What does PyTorch's backward() actually compute, and why does it not compute the full Jacobian?
Q2 (Conceptual): Why does gradient checkpointing reduce peak memory from O(depth) to O(sqrt(depth))?
Q3 (Synthetic): In GR00T N1.5 fine-tuning on a single A100, what checkpointing strategy minimizes memory while preserving reasonable throughput?