Memory-Efficient Backprop: Solving Unsloth's Challenge E
- Starting Point: What I Knew (and Didn't Know)
- Building the Foundation:
- Manual gradient computation
- Checkpointing Intuition
- The Implementation
- Testing and Benchmarks
- Real Model Integration: OPT-125M
- Reflections and What's Next
- What's Next: GRPO and Beyond
Starting Point: What I Knew (and Didn't Know)
When I first encountered Unsloth's Challenges, I got hooked: they seemed like a perfect way to test my understanding and identify knowledge gaps. I chose Challenge E on memory-efficient backpropagation as my starting point. It looked approachable given my foundations, but I had to be honest with myself about what I didn't know.
What I brought to the table:
- Solid PyTorch fundamentals, but I'd never worked with torch.autograd.Function or gradient checkpointing in detail
- Strong mathematical foundations: calculus (chain rule) and linear algebra came naturally
The Challenge:
The problem was clear: in LLMs, the final layer projects hidden states to vocabulary logits using $\sigma(XW)$. For large vocabularies (128K tokens), this creates massive memory spikes.
The numbers: With batch=4, seq_len=4096, hidden_dim=4096, vocab=128K, the logits alone consume 4GB in bfloat16 (8GB in float32).
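To make the numbers concrete, here's the back-of-the-envelope arithmetic (treating 128K as roughly 128,000 tokens):

batch, seq_len, vocab = 4, 4096, 128_000
elements = batch * seq_len * vocab                 # ~2.1 billion logits
print(f"bfloat16: {elements * 2 / 1e9:.1f} GB")    # ~4.2 GB
print(f"float32:  {elements * 4 / 1e9:.1f} GB")    # ~8.4 GB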
The goal: Implement a memory-efficient backpropagation that:
- Reduces VRAM usage by 50%
- Doesn't hardcode gradients (must use autograd)
- Works with cross entropy and other loss functions
- Supports dynamic chunk sizes
- Matches training loss on real models (Llama 1B)
My approach: before diving into code, I wanted to fill those knowledge gaps and build understanding systematically.
Building the Foundation:
Rather than jumping straight to implementation, I mapped out prerequisite concepts as a dependency graph, each node representing a focused exercise building on previous ones.
Layer 1: Core Autograd Mechanics
Node 1 - Custom autograd basics: I started by understanding torch.autograd.Function. The key insight came from implementing a simple Square function:
import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # Upstream gradient times the local gradient d(x**2)/dx = 2x
        return grad_output * 2 * x
My "aha" moment: grad_output is the upstream gradient, and we multiply it by the local gradient (2x). The chain rule in action.
Manual gradient computation
Implementing ReLU
To cement my understanding, I implemented a custom ReLU:
class ReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.clamp(x, min=0)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # Gradient passes through where the input was positive, is zero elsewhere
        local_grad = torch.zeros_like(x)
        local_grad[x > 0] = 1
        return grad_output * local_grad
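A quick way to confirm it behaves like the built-in: run both on the same input and compare gradients (a small self-check, not part of the challenge itself):

x = torch.randn(4, 4, requires_grad=True)
x_ref = x.detach().clone().requires_grad_()
ReLU.apply(x).sum().backward()
torch.relu(x_ref).sum().backward()
assert torch.allclose(x.grad, x_ref.grad)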
Node 2: Manual Gradient Computation
Next, I derived gradients for a linear layer by hand. Given Y = X @ W and loss = Y.sum():
dL/dY = tensor of ones (shape matches Y)
dL/dW = X.T @ dL/dY
dL/dX = dL/dY @ W.T
I verified these against PyTorch's autograd:
x = torch.randn(size=(1, 3), requires_grad=True)
w = torch.randn(size=(3, 2), requires_grad=True)
y = x @ w
l = y.sum()
l.backward()
dl_dy = torch.ones_like(y)
dl_dw = x.T @ dl_dy
dl_dx = dl_dy @ w.T
assert torch.allclose(w.grad, dl_dw)
assert torch.allclose(x.grad, dl_dx)
Checkpointing Intuition
Node 3: The memory-compute tradeoff:
Another "aha" moment came when I understood what we're actually checkpointing:
| Approach | What's stored | Peak memory |
|---|---|---|
| Normal | Full logits [batch × seq × vocab] | 4GB |
| Chunked (2 chunks) | Half the logits at a time | 2GB |
We save the inputs (X, labels), throw away the massive logits, and recompute them chunk-by-chunk during backward. The key insight: gradients from independent chunks can simply be summed; that's exactly gradient accumulation.
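Here's a tiny standalone illustration of that summing property (independent of the challenge code): backprop through the full batch gives the same weight gradient as accumulating backward passes chunk by chunk, as long as each chunk's loss is a plain sum.

import torch

torch.manual_seed(0)
W = torch.randn(8, 3, requires_grad=True)
X = torch.randn(16, 8)

# Full-batch backward
(X @ W).sum().backward()
full_grad = W.grad.clone()

# Chunked backward: gradients accumulate in W.grad across chunks
W.grad = None
for x_chunk in torch.chunk(X, 4, dim=0):
    (x_chunk @ W).sum().backward()

assert torch.allclose(full_grad, W.grad)

With a mean-reduced loss the chunks have to be reweighted, which is exactly what the forward pass below does.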
The Implementation
Node 5-6: Building the Chunked Forward and Backward
With the foundations in place, I started implementing MemoryEfficientLinear. The forward pass seemed straightforward: chunk the inputs, compute losses, combine with weighted average.
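The per-chunk loss comes from a user-supplied forward_function (the transformation_function used in the tests further down). A minimal cross-entropy version looks something like this (the challenge's own helper may differ in details):

from torch.nn import CrossEntropyLoss

def transformation_function(batch, linear, labels):
    # Project hidden states to vocab logits, upcast to float32, mean cross entropy
    logits = linear(batch).float()
    return CrossEntropyLoss(reduction="mean")(logits, labels)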
My first forward pass attempt:
def forward(ctx, X, linear, labels, forward_function, num_chunks):
    Xs = torch.chunk(X, num_chunks, dim=0)
    Ys = torch.chunk(labels, num_chunks)
    losses = []
    ws = []
    for i, _ in enumerate(Xs):
        loss = forward_function(Xs[i], linear, Ys[i])
        ws.append(torch.tensor(len(Xs[i])))
        losses.append(loss)
    ws = torch.stack(ws)
    losses = torch.stack(losses)
    ws = ws / ws.sum()
    ctx.save_for_backward(X, labels)
    ctx.linear = linear
    ctx.forward_function = forward_function
    ctx.num_chunks = num_chunks
    ctx.ws = ws
    return (ws * losses).sum()
Why weighted average?
Each chunk might have different sizes (especially the last one), and since CrossEntropyLoss uses reduction="mean", I needed to weight by chunk size to get the correct overall mean.
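A small numeric example makes the rule obvious: with chunk sizes 3 and 1, weighting the per-chunk means by 3/4 and 1/4 recovers the mean over all 4 elements.

vals = torch.tensor([1.0, 2.0, 3.0, 10.0])
chunks = torch.split(vals, [3, 1])
ws = torch.tensor([3.0, 1.0]) / 4.0
weighted = sum(w * c.mean() for w, c in zip(ws, chunks))
assert torch.isclose(weighted, vals.mean())   # both equal 4.0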
The Backward Pass — Where Things Got Tricky
The backward implementation required recomputing the forward pass chunk-by-chunk:
def backward(ctx, grad_output):
    X, labels = ctx.saved_tensors
    linear = ctx.linear
    forward_function = ctx.forward_function
    num_chunks = ctx.num_chunks
    ws = ctx.ws
    linear.zero_grad()
    X_chunks = torch.chunk(X, num_chunks, dim=0)
    label_chunks = torch.chunk(labels, num_chunks)
    dX_chunks = []
    for i, (x_chunk, label_chunk) in enumerate(zip(X_chunks, label_chunks)):
        x_chunk = x_chunk.clone().requires_grad_()
        with torch.enable_grad():
            loss = forward_function(x_chunk, linear, label_chunk)
        loss.backward(grad_output * ws[i])  # Weighted!
        dX_chunks.append(x_chunk.grad.detach().clone())
    dX = torch.cat(dX_chunks, dim=0)
    return dX, None, None, None, None
Bug Hunt #1: The Mystery of 2× Gradients
When I first tested this, the assertions failed:
Normal weight grad: [[ 0.0737, 0.0557, 0.1557], ...]
Efficient weight grad: [[ 0.1473, 0.1115, 0.3114], ...]
The efficient gradients were exactly double!
Debugging process:
- Added print statements: discovered linear.weight.grad already had values at the start of backward
- Realized the forward pass was creating a computation graph connecting to linear.weight
- The fix: Wrap forward computations in torch.no_grad() to prevent double-counting
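Concretely, the chunk loop in forward becomes (only the changed part shown):

for i, _ in enumerate(Xs):
    # No graph needed here: backward() recomputes each chunk anyway
    with torch.no_grad():
        loss = forward_function(Xs[i], linear, Ys[i])
    ws.append(torch.tensor(len(Xs[i])))
    losses.append(loss)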
Bug Hunt #2: The Weighted Gradient Mystery
Even after fixing the torch.no_grad() issue, I was still getting 2× gradients. More debugging:
print(f"Chunk loss requires_grad: {loss.requires_grad}")
# Output: False
The chunk losses weren't tracking gradients! The problem: PyTorch runs backward() with gradient tracking disabled, and the saved tensors aren't attached to any graph, so the recomputed forward built no graph at all.
The fix: wrap the forward recomputation in with torch.enable_grad().
But there was still one more issue — I was passing the full grad_output to each chunk's backward, when I should have been scaling by the weight:
loss.backward(grad_output * ws[i]) # Not just grad_output!
Why? The forward pass computed a weighted average: (ws * losses).sum(). By the chain rule:
dL/d(loss_i) = ws[i] * grad_output

Each chunk contributes proportionally to the final loss, so its gradient must be scaled accordingly.
After this fix, the assertions finally passed!
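The scaling rule itself is easy to verify in isolation: the gradient of a weighted sum with respect to each term is exactly its weight.

ws = torch.tensor([0.75, 0.25])
losses = torch.tensor([1.3, 2.7], requires_grad=True)
(ws * losses).sum().backward()
assert torch.allclose(losses.grad, ws)   # dL/d(loss_i) = ws[i]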
Bug Hunt #3: Device Mismatch on GPU
When I moved to GPU testing for memory benchmarks:
RuntimeError: Expected all tensors to be on the same device,
but found at least two devices, cuda:0 and cpu!
The culprit: torch.tensor(len(Xs[i])) creates tensors on CPU by default.
The fix: Specify device explicitly:
ws.append(torch.tensor(len(Xs[i]), device=X.device))
Testing and Benchmarks
Validating Correctness
First, I verified the implementation produced identical results to the standard approach:
# Normal approach
loss = transformation_function(X, linear, labels)
loss.backward()
normal_weight_grad = linear.weight.grad.clone()
# Efficient approach
loss_efficient = MemoryEfficientLinear.apply(X, linear, labels, transformation_function, 2)
loss_efficient.backward()
assert torch.allclose(normal_weight_grad, linear.weight.grad, atol=1e-6)
Testing other loss functions (MSE):
from torch.nn import MSELoss

def mse_transformation_function(batch, linear, labels):
    x = linear(batch).float()
    loss_fn = MSELoss()
    return loss_fn(x, labels)
Both cross entropy and MSE passed — the implementation generalizes!
Memory Benchmarks
Toy example (batch=4096, hidden=1024, vocab=8192):
Normal peak: 0.69 GB
Efficient (4 chunks): 0.37 GB
Savings: 46.2%
Nearly hit the 50% target! As expected, more chunks = more savings.
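For anyone reproducing these numbers, peak memory can be read from CUDA's allocator statistics; a minimal harness (assuming X, linear, and labels from the toy setup already live on the GPU) looks like:

def measure_peak_gb(run_fn):
    # Reset the allocator's high-water mark, run one forward+backward, read it back
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    run_fn()
    return torch.cuda.max_memory_allocated() / 1e9

normal_gb = measure_peak_gb(lambda: transformation_function(X, linear, labels).backward())
efficient_gb = measure_peak_gb(
    lambda: MemoryEfficientLinear.apply(X, linear, labels, transformation_function, 4).backward()
)
print(f"Savings: {100 * (1 - efficient_gb / normal_gb):.1f}%")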
Real Model Integration: OPT-125M
I integrated with Hugging Face's OPT-125M by creating a wrapper:
from torch import nn

class MemoryEfficientLMHead(nn.Module):
    def __init__(self, original_lm_head, num_chunks=4):
        super().__init__()
        self.linear = original_lm_head
        self.num_chunks = num_chunks

    def forward(self, hidden_states, labels=None):
        if labels is None:
            return self.linear(hidden_states)  # Inference path: plain projection
        batch, seq_len, hidden = hidden_states.shape
        hidden_flat = hidden_states.view(-1, hidden)
        labels_flat = labels.view(-1)
        loss = MemoryEfficientLinear.apply(
            hidden_flat, self.linear, labels_flat,
            transformation_function, self.num_chunks
        )
        return loss
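Wiring it up is then a matter of running the OPT decoder for hidden states and calling the wrapper directly instead of the stock lm_head-plus-loss path. A sketch with a dummy batch (shifted by one so each position predicts the next token, as the standard causal LM loss does):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
model.train()
efficient_head = MemoryEfficientLMHead(model.lm_head, num_chunks=4)

input_ids = torch.randint(0, model.config.vocab_size, (8, 512))   # batch=8, seq_len=512
hidden = model.model(input_ids).last_hidden_state                 # [batch, seq_len, hidden]

loss = efficient_head(hidden[:, :-1, :].contiguous(),
                      labels=input_ids[:, 1:].contiguous())
loss.backward()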
Results (batch=8, seq_len=512):
Normal: 3.20 GB
Efficient: 2.70 GB
Savings: 15.7%
Loss match: (8.556 vs 8.556)
Lower savings percentage than the toy example because model weights and other activations dominate memory, but the logits reduction is working as expected.
Reflections and What's Next
What I Learned
This challenge was a masterclass in how theory meets practice. On paper, "chunk the computation and accumulate gradients" sounds simple. In reality:
- Autograd is subtle: Understanding when gradients flow, when they accumulate, and when they get zeroed required careful reasoning about PyTorch's execution model
- Debugging is systematic: The 2× gradient bug taught me to add strategic print statements and work backwards from symptoms to root causes
- Weighted gradients matter: The chain rule doesn't just apply to functions — it applies to how you combine losses
The biggest insight: Breaking down complex problems into a dependency graph of smaller exercises isn't just good pedagogy; it's how you build robust implementations. Each node gave me one piece of the puzzle, and when bugs appeared, I knew exactly which piece to inspect.
Scoring Against the Rubric
According to Unsloth's criteria:
- VRAM 50% reduction: (2 points)
- Cross entropy works: (1 point)
- Other functions work: (1 point)
- No hardcoded gradients: Using autograd throughout (required)
- Dynamic chunk sizes: Tested 1, 2, 4, 8 chunks (1 point)
- Real model integration: (1 point conceptually :D, still have to test with Llama 1B)
What's Next: GRPO and Beyond
The remaining 4 points come from GRPO (Group Relative Policy Optimization) integration — a reinforcement learning technique for LLM fine-tuning.
The challenge: GRPO generates multiple responses per prompt and needs log probabilities (not just losses) for all of them. This requires:
- Generalizing MemoryEfficientLinear to return vectors instead of scalars
- Handling per-token gradients instead of aggregated loss gradients
- Integration with RL training loops
Conceptual approach:
- Modify transformation function to return log probabilities: [N] instead of scalar
- Chunk and concatenate outputs instead of weighted averaging
- Slice grad_output to match each chunk during backward
This is where the implementation gets significantly more complex, but the foundation is solid.
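For what it's worth, here is a rough, untested sketch of that direction, just to show the shape of the change: a hypothetical logprob_function returns one log-probability per token, the forward concatenates chunk outputs instead of averaging them, and the backward slices grad_output to match each chunk.

class MemoryEfficientLogProbs(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, linear, labels, logprob_function, num_chunks):
        # Recomputed in backward, so no graph is kept here
        with torch.no_grad():
            outs = [logprob_function(x, linear, y)
                    for x, y in zip(torch.chunk(X, num_chunks), torch.chunk(labels, num_chunks))]
        ctx.save_for_backward(X, labels)
        ctx.linear = linear
        ctx.logprob_function = logprob_function
        ctx.num_chunks = num_chunks
        return torch.cat(outs)   # [N] per-token log-probabilities

    @staticmethod
    def backward(ctx, grad_output):
        X, labels = ctx.saved_tensors
        dX_chunks, offset = [], 0
        for x, y in zip(torch.chunk(X, ctx.num_chunks), torch.chunk(labels, ctx.num_chunks)):
            x = x.detach().clone().requires_grad_()
            with torch.enable_grad():
                out = ctx.logprob_function(x, ctx.linear, y)
            # Slice the upstream gradient to this chunk's portion of the output
            out.backward(grad_output[offset:offset + out.shape[0]])
            offset += out.shape[0]
            dX_chunks.append(x.grad)
        return torch.cat(dX_chunks), None, None, None, None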
What started as "I don't know custom autograd" became a working memory-efficient backpropagation system tested on real models. The journey reinforced that in ML systems work, deeply understanding the fundamentals (autograd mechanics, gradient flow, memory management) is what separates implementations that "kind of work" from ones that are correct, efficient, and generalizable.