Performer vs Softmax — From Kernel View to Fair Speed Tests
TL;DR. I reimplemented softmax attention and a Performer-style linear attention layer from first principles. My first attempt gave terrible perplexity and no speedup. After aligning the implementations (multi-head Wq/Wk/Wv, 1/√d_head scaling, causal prefix sums, unbiased random features), perplexity became close to softmax on TinyStories. Speed still didn’t beat softmax at seq_len=80—because the linear-time benefits only show up when context length N is large relative to the random-feature rank m and when kernels are well-fused. This post walks through the kernel view of softmax, the fixes that mattered, and a rigorous way to measure the crossover.
- Motivation
- Dataset Acquisition and Preprocessing
- Minimal, corrected Performer-style MHA (sketch)
- Experimental setup
- Code
- Results (representative)
- Why no speedup (yet)
- What I would do next
- References
For a single head, classic attention uses the exponential kernel $$\kappa(q,k) = e^{q\cdot k}.$$ Softmax attention for a query $q_i$ can be written as a normalized kernel smoother: $$y_i = \frac{\sum_j \kappa(q_i,k_j)\,v_j}{\sum_j \kappa(q_i,k_j)}.$$ If we factor the kernel as an inner product of features, $\kappa(q,k)=\phi(q)^\top\psi(k)$, then $$y_i = \frac{\phi(q_i)^\top\Big(\sum_j \psi(k_j)v_j^\top\Big)}{\phi(q_i)^\top\Big(\sum_j \psi(k_j)\Big)}.$$ This moves the $\sum_j$ outside the per-token computation, enabling global sums (bidirectional) or prefix sums (causal). That algebra is what turns O(N²) into O(N·m).
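To make the factorization concrete, here is a tiny single-head check (a sketch, not the training code): for any finite feature map, the factored computation reproduces the explicit N×N kernel smoother. The elementwise exp used as `phi` is an arbitrary positive map chosen only for illustration.

import torch

torch.manual_seed(0)
N, d = 16, 8
Q, K, V = torch.randn(N, d), torch.randn(N, d), torch.randn(N, d)
phi = torch.exp                                   # any feature map satisfies the identity below

# Quadratic route: materialize the N x N kernel matrix, then normalize row-wise.
Kmat = phi(Q) @ phi(K).T                          # Kmat[i, j] = phi(q_i)·phi(k_j)
y_quad = (Kmat @ V) / Kmat.sum(dim=1, keepdim=True)

# Linear route: sum over j once, then contract each phi(q_i) against the sums.
KV = phi(K).T @ V                                 # (m, d)  = sum_j phi(k_j) v_j^T
z  = phi(K).sum(dim=0)                            # (m,)    = sum_j phi(k_j)
y_lin = (phi(Q) @ KV) / (phi(Q) @ z).unsqueeze(-1)

assert torch.allclose(y_quad, y_lin, rtol=1e-4, atol=1e-4)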
Three ways to factor the exponential kernel into features:

- Random positive features (FAVOR+): draw $\Omega\sim\mathcal N(0,I)$ with m rows and set $$\phi(x)=\frac{1}{\sqrt m}\exp\big(\Omega x - \tfrac{1}{2}\|x\|^2\big).$$ As m grows, the estimator's variance shrinks (see the quick unbiasedness check after this list).
- Taylor truncation: $$e^{q\cdot k} \approx \sum_{r=0}^{n} \frac{(q\cdot k)^r}{r!},$$ realized with tensor-power features since $(q\cdot k)^r = (q^{\otimes r})\cdot (k^{\otimes r})$ (a theoretical lens; dimensions blow up as $d^r$).
- Limit definition: $$e^{x}=\lim_{n\to\infty}\big(1+\tfrac{x}{n}\big)^{n}.$$ Finite n gives a polynomial kernel, realized via an n-fold tensor power of an augmented vector.
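And the unbiasedness check referenced above: a Monte Carlo sketch, assuming only the FAVOR+ definition from the first bullet. With Ω ~ N(0, I), the average of φ(q)·φ(k) should approach e^{q·k} as m grows; small input norms are used to keep the estimator's variance low.

import torch

torch.manual_seed(0)
d, m = 8, 200_000
q, k = 0.3 * torch.randn(d), 0.3 * torch.randn(d)

omega = torch.randn(m, d)
phi = lambda x: torch.exp(omega @ x - 0.5 * x.pow(2).sum()) / m ** 0.5

estimate = phi(q) @ phi(k)          # Monte Carlo estimate of the exp kernel
exact = torch.exp(q @ k)
print(f"estimate {estimate:.4f} vs exact {exact:.4f}")   # should agree to ~2-3 decimals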
Core pieces that mattered in practice:

- Standard Wq/Wk/Wv projections plus an output projection; split into heads.
- 1/√d_head scaling applied to both the q and k paths.
- Positive random features (FAVOR+) with an unbiased estimator, Ω ~ N(0, I).
- Prefix sums for causal attention; global sums for bidirectional.
- A small ε in the denominators for numerical stability.

Tip: for an orthogonal Ω variant, multiply the orthonormal rows by chi-distributed radii (chi with d degrees of freedom) so the rows emulate Gaussian draws.
- Data: TinyStories subset (train 50k, 20% valid).
- Tokenization: fastai TextBlock, seq_len=80.
- Model: 2 layers, embed_dim=64, n_head=4, GELU FFN×4, LayerNorm, dropout≈0.
- Training: 2 epochs, 1cycle, lr≈2e−3.
- Performer rank: proj_dim m ∈ {64,128} (final runs used m=128).
import math, random, time
from pathlib import Path

import pandas as pd
import torch
import torch.nn as nn

from fastai.text.all import *
class RandomFeatureMap(nn.Module):
"""
FAVOR+ style positive random features for k(q,k) = exp(q·k).
phi(x) = exp(Ω x - ||x||^2 / 2) / sqrt(m), with Ω ~ N(0,I).
"""
def __init__(self, dim: int, proj_dim: int, orthogonal: bool = False, seed: int | None = None):
super().__init__()
self.proj_dim, self.dim = proj_dim, dim
g = torch.Generator()
if seed is not None: g.manual_seed(seed)
if not orthogonal:
omega = torch.randn(proj_dim, dim, generator=g)
else:
            # Orthogonal Random Features: orthonormal blocks rescaled by chi-distributed radii
blocks = []
remain = proj_dim
while remain > 0:
b = min(dim, remain)
Q, _ = torch.linalg.qr(torch.randn(dim, dim, generator=g))
# draw radii so that rows emulate N(0, I)
radii = torch.sqrt(torch.distributions.Chi2(dim).sample((b,)))
blocks.append((radii.unsqueeze(1) * Q[:b, :]))
remain -= b
omega = torch.cat(blocks, dim=0)
self.register_buffer("omega", omega)
def phi(self, x: torch.Tensor) -> torch.Tensor:
"""
x: (B, H, N, d)
returns: (B, H, N, m)
"""
proj = torch.einsum("bhnd,md->bhnm", x, self.omega) # (B,H,N,M)
sq = (x.pow(2).sum(dim=-1, keepdim=True)) * 0.5 # (B,H,N,1)
return torch.exp(proj - sq) / math.sqrt(self.proj_dim) # positive features
class MultiheadPerformerAttention(nn.Module):
def __init__(self, embed_dim: int, num_heads: int, proj_dim: int,
causal: bool = True, eps: float = 1e-6, orthogonal: bool = False):
super().__init__()
assert embed_dim % num_heads == 0
self.h = num_heads
self.d = embed_dim // num_heads
self.m = proj_dim
self.eps = eps
self.causal = causal
# Standard projections + output (like MHA)
self.wq = nn.Linear(embed_dim, embed_dim, bias=False)
self.wk = nn.Linear(embed_dim, embed_dim, bias=False)
self.wv = nn.Linear(embed_dim, embed_dim, bias=False)
self.wo = nn.Linear(embed_dim, embed_dim, bias=False)
self.feature = RandomFeatureMap(self.d, proj_dim, orthogonal=orthogonal)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, E = x.shape
H, D = self.h, self.d
        # [B,N,E] -> [B,H,N,D]; q and k are each scaled by 1/sqrt(d_head)
q = self.wq(x).view(B, N, H, D).transpose(1, 2) / math.sqrt(D)
k = self.wk(x).view(B, N, H, D).transpose(1, 2) / math.sqrt(D)
v = self.wv(x).view(B, N, H, D).transpose(1, 2)
phi_q = self.feature.phi(q) # (B,H,N,M)
phi_k = self.feature.phi(k) # (B,H,N,M)
if not self.causal:
# Global sums
KV = torch.einsum("bhnm,bhnd->bhmd", phi_k, v) # (B,H,M,D)
z = torch.einsum("bhnm,bhm->bhn", phi_q, phi_k.sum(2)) # (B,H,N)
y = torch.einsum("bhnm,bhmd->bhnd", phi_q, KV) # (B,H,N,D)
y = y / (z.clamp_min(self.eps).unsqueeze(-1))
else:
# Causal: prefix sums over N
# Outer product per position: (B,H,N,M,D)
outer = phi_k.unsqueeze(-1) * v.unsqueeze(-2)
KVpref = outer.cumsum(dim=2) # (B,H,N,M,D)
bpref = phi_k.cumsum(dim=2) # (B,H,N,M)
# Contract M: (B,H,N,D) and (B,H,N)
y = torch.einsum("bhnm,bhnmd->bhnd", phi_q, KVpref)
z = torch.einsum("bhnm,bhnm->bhn", phi_q, bpref)
y = y / (z.clamp_min(self.eps).unsqueeze(-1))
y = y.transpose(1, 2).contiguous().view(B, N, E) # [B,N,E]
return self.wo(y)
class PerformerBlock(nn.Module):
def __init__(self, dim, n_head=8, proj_dim=64, causal=True, dropout=0.0, orthogonal=False):
super().__init__()
self.attn = MultiheadPerformerAttention(dim, n_head, proj_dim, causal=causal, orthogonal=orthogonal)
self.norm1 = nn.LayerNorm(dim)
self.ffn = nn.Sequential(
nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim)
)
self.norm2 = nn.LayerNorm(dim)
self.drop = nn.Dropout(dropout)
def forward(self, x):
x = self.norm1(x + self.drop(self.attn(x)))
x = self.norm2(x + self.drop(self.ffn(x)))
return x
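# Quick shape check for the Performer block above (a sketch; the sizes mirror the
# experimental setup: embed_dim=64, n_head=4, proj_dim=128, seq_len=80).
_blk = PerformerBlock(dim=64, n_head=4, proj_dim=128, causal=True)
assert _blk(torch.randn(2, 80, 64)).shape == (2, 80, 64)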
from datasets import load_dataset
hf_ds = load_dataset("roneneldan/TinyStories") # {'train': Dataset, 'validation': Dataset}
TARGET_SIZE = 50_000
train_full = hf_ds["train"]
SEED = 42
rng = random.Random(SEED)
indices = list(range(len(train_full)))
rng.shuffle(indices)
# Keep only the first TARGET_SIZE indices (or the whole set if it’s smaller)
sub_idx = indices[:min(TARGET_SIZE, len(train_full))]
train_subset = train_full.select(sub_idx)   # new Dataset with at most TARGET_SIZE rows
# fastai TextBlock that tokenises and numericalises the "text" column for language modelling.
text_block = TextBlock.from_df(
    text_cols="text",
    is_lm=True,
    seq_len=80,   # context window
)
# Build the DataBlock. fastai's tokenizer transforms expect a pandas DataFrame,
# so convert the HF Dataset before building the dataloaders.
train_df = train_subset.to_pandas()

tinystories_block = DataBlock(
    blocks=(text_block,),
    get_x=ColReader("text"),
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
)

dls = tinystories_block.dataloaders(
    train_df,
bs=128,
num_workers=12,
pin_memory=True,
shuffle=True,
)
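# Sanity check: a language-model batch should be (bs, seq_len) token ids,
# with targets shifted by one position (expected shapes: [128, 80] for both).
xb, yb = dls.one_batch()
print(xb.shape, yb.shape)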
class SimpleMultiHeadAttention(Module):
"""
A lightweight multi-head attention that mimics fastai's implementation.
- `d_model` : hidden dimension (same as embed_dim)
- `n_head` : number of attention heads
- `bias` : whether to use bias in the linear projections
- `causal` : if True, applies an upper-triangular mask
"""
def __init__(self,
d_model: int,
n_head: int = 8,
bias: bool = True,
causal: bool = False,
attn_dropout: float = 0.):
assert d_model % n_head == 0, "d_model must be divisible by n_head"
self.n_head = n_head
self.d_head = d_model // n_head
self.scale = 1.0 / math.sqrt(self.d_head)
# Combined q,k,v projection for efficiency
self.W_qkv = nn.Linear(d_model, d_model * 3, bias=bias)
self.out_proj = nn.Linear(d_model, d_model, bias=bias)
self.dropout = nn.Dropout(attn_dropout)
self.causal = causal
def forward(self, x):
"""
x : (B, N, d_model)
returns (B, N, d_model)
"""
B, N, _ = x.shape
# qkv shape → (B, N, 3, n_head, d_head)
qkv = self.W_qkv(x).view(B, N, 3, self.n_head, self.d_head)
q, k, v = qkv.unbind(dim=2) # each -> (B, N, n_head, d_head)
# transpose for batched matmul: (B, n_head, N, d_head)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# scaled dot-product
attn_weights = (q @ k.transpose(-2, -1)) * self.scale # (B, n_head, N, N)
if self.causal:
# Upper-triangular mask: allow only j ≤ i
mask = torch.triu(torch.ones(N, N, device=x.device), diagonal=1).bool()
attn_weights = attn_weights.masked_fill(mask, float('-inf'))
attn = torch.softmax(attn_weights, dim=-1)
attn = self.dropout(attn)
# Weighted sum of values
out = (attn @ v) # (B, n_head, N, d_head)
out = out.transpose(1, 2).contiguous().view(B, N, -1) # (B, N, d_model)
return self.out_proj(out)
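# Causality check for the attention above (a sketch): with causal=True, the output at
# position i must not change when tokens j > i are perturbed.
_attn = SimpleMultiHeadAttention(d_model=64, n_head=4, causal=True)
_x1 = torch.randn(1, 10, 64)
_x2 = _x1.clone()
_x2[:, 5:] = torch.randn(1, 5, 64)          # perturb only the "future" positions
assert torch.allclose(_attn(_x1)[:, :5], _attn(_x2)[:, :5], atol=1e-5)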
class TransformerBlock(Module):
"""
Minimal fastai-style transformer block.
Parameters match the original fastai class so you can use it interchangeably.
"""
def __init__(self,
d_model: int,
n_head: int = 8,
d_ff: int = None, # hidden size of the FFN; default = 4 * d_model
bias: bool = True,
dropout: float = 0., # dropout after attention & FFN
attn_dropout: float = 0., # dropout inside attention
ff_dropout: float = 0., # dropout inside the feed-forward
causal: bool = False,
act_fn: nn.Module = nn.GELU()):
self.attn = SimpleMultiHeadAttention(d_model,
n_head=n_head,
bias=bias,
causal=causal,
attn_dropout=attn_dropout)
self.ln1 = nn.LayerNorm(d_model)
# Feed-forward network (two linear layers + activation)
ff_dim = d_ff or 4 * d_model
self.ff = nn.Sequential(
nn.Linear(d_model, ff_dim, bias=bias),
act_fn,
nn.Dropout(ff_dropout),
nn.Linear(ff_dim, d_model, bias=bias),
nn.Dropout(dropout)
)
self.ln2 = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
"""
x : (B, N, d_model)
Returns the same shape after applying:
self-attention + residual + norm
feed-forward + residual + norm
"""
attn_out = self.attn(x)
x = x + self.dropout(attn_out)
x = self.ln1(x)
ff_out = self.ff(x)
x = x + ff_out
x = self.ln2(x)
return x
class SoftmaxLM(Module):
"""Standard fastai transformer with softmax attention."""
    def __init__(self, vocab_sz, embed_dim, n_head, n_layer, causal=True):
        self.emb = nn.Embedding(vocab_sz, embed_dim)
        self.pos = nn.Parameter(torch.randn(1, 512, embed_dim))  # learned absolute positional embeddings (max len 512)
self.blocks = nn.ModuleList([
TransformerBlock(d_model=embed_dim,
n_head=n_head,
causal=causal) for _ in range(n_layer)
])
self.ln = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, vocab_sz)
def forward(self, x):
b, n = x.shape
x = self.emb(x) + self.pos[:, :n]
for blk in self.blocks:
x = blk(x)
x = self.ln(x)
return self.head(x)
class PerformerLM(Module):
"""Transformer that swaps the softmax block for Performer-style linear attention."""
def __init__(self, vocab_sz, embed_dim, n_head, n_layer, proj_dim, causal=False):
self.emb = nn.Embedding(vocab_sz, embed_dim)
self.pos = nn.Parameter(torch.randn(1, 512, embed_dim))
self.blocks = nn.ModuleList([
PerformerBlock(dim=embed_dim,
n_head=n_head,
proj_dim=proj_dim,
causal=causal) for _ in range(n_layer)
])
self.ln = nn.LayerNorm(embed_dim)
self.head = nn.Linear(embed_dim, vocab_sz)
def forward(self, x):
b, n = x.shape
x = self.emb(x) + self.pos[:, :n]
for blk in self.blocks:
x = blk(x)
x = self.ln(x)
return self.head(x)
class EpochTimer(Callback):
"""Records wall-clock time for each epoch."""
def before_fit(self):
self.epoch_times = []
def after_epoch(self):
self.epoch_times.append(time.time() - self.epoch_start)
def before_epoch(self):
self.epoch_start = time.time()
def train_one(model, name, epochs=2, lr=2e-3):
    timer = EpochTimer()
    learn = Learner(dls, model,
                    loss_func=CrossEntropyLossFlat(),
                    metrics=[Perplexity()],   # fastai's built-in perplexity metric
                    cbs=timer)
    print(f"\n=== Training {name} ===")
    learn.fit_one_cycle(epochs, lr)
    final_ppl = float(learn.validate()[1])    # validate() returns [loss, perplexity]
    return timer.epoch_times, final_ppl
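# Hyperparameters for the comparison, assumed here to match the setup section
# (2 layers, embed_dim=64, 4 heads, m=128, causal LM); vocab size comes from the dataloaders.
vocab = dls.vocab
if isinstance(vocab[0], (list, L)): vocab = vocab[0]   # unwrap nested [text_vocab, ...] if present
vocab_sz  = len(vocab)
embed_dim = 64
n_head    = 4
n_layer   = 2
proj_dim  = 128
causal    = True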
softmax_model = SoftmaxLM(vocab_sz, embed_dim, n_head, n_layer)
soft_times, soft_ppl = train_one(softmax_model, "Softmax-Transformer")
performer_model = PerformerLM(vocab_sz, embed_dim, n_head, n_layer,
proj_dim=proj_dim, causal=causal)
perf_times, perf_ppl = train_one(performer_model, "Performer-Transformer")
summary = pd.DataFrame({
"Model": ["Softmax-Transformer", "Performer-Transformer"],
"Final Perplexity": [soft_ppl, perf_ppl],
"Epoch times (sec)": [soft_times, perf_times],
"Avg time (sec)": [sum(soft_times)/len(soft_times),
sum(perf_times)/len(perf_times)]
})
print("\n=== Comparison ===")
print(summary.to_string(index=False))
| Model | Valid PPL | Epoch time (s) |
|---|---|---|
| Softmax Transformer | ~1.09 | ~127 |
| Performer Transformer (m=128) | ~1.13 | ~128 |
Perplexity parity achieved; no speedup at N=80.
Back-of-envelope FLOPs per head (ignoring projections/FFN):

- Softmax: $4N^{2}D$
- Performer: $c\,N\,m\,D$, with c ≈ 6–10 (feature projections, exp, and cumsums).

Break-even: $$4N^2D \approx c\,N\,m\,D\;\Rightarrow\; N \approx \frac{c}{4}\,m.$$

With c ≈ 8 and m = 128, the break-even is N ≈ 256. At short contexts (N ≪ 256), softmax's O(N²) cost is still small, and modern softmax kernels (FlashAttention/xFormers) are heavily fused. A naïve einsum-based linear attention will not be faster.
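A two-line calculator for this break-even estimate (illustrative arithmetic only; c is a fudge factor, not a measurement):

def break_even_context(m: int, c: float = 8.0) -> float:
    # 4*N^2*D ≈ c*N*m*D  =>  N ≈ c*m/4
    return c * m / 4

print([(m, break_even_context(m)) for m in (64, 128, 256)])   # [(64, 128.0), (128, 256.0), (256, 512.0)]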
What I would do next
- Long-context grid: an (N × m) sweep to locate the speed crossover on my GPU (see the timing sketch after this list).
- Fused kernels: compare softmax (FlashAttention) vs Performer (fused FAVOR+) to reflect realistic deployments.
- Hybrid maps: combine a small random-feature bank with a low-degree deterministic Taylor tail to reduce variance.
- Layer-wise m: try larger m only in upper layers or selected heads.
- Task sensitivity: evaluate on tasks where attention patterns are smoother (long-range modeling), where linear attention often shines.
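A minimal version of the first item, sketched with torch.utils.benchmark against the attention modules defined above (assumptions: CUDA if available, batch size 1 to keep the naive causal cumsum's memory in check, and N/m grids that are placeholders for whatever your GPU allows):

import torch.utils.benchmark as benchmark

device = "cuda" if torch.cuda.is_available() else "cpu"

def time_forward(module, N, d_model=64, bs=1):
    x = torch.randn(bs, N, d_model, device=device)
    mod = module.to(device).eval()
    t = benchmark.Timer(stmt="with torch.no_grad(): mod(x)",
                        globals={"mod": mod, "x": x, "torch": torch})
    return t.timeit(20).median * 1e3   # median ms per forward pass

rows = []
for N in (128, 512, 2048, 4096):
    soft_ms = time_forward(SimpleMultiHeadAttention(64, n_head=4, causal=True), N)
    for m_feat in (64, 128, 256):
        perf_ms = time_forward(MultiheadPerformerAttention(64, 4, m_feat, causal=True), N)
        rows.append({"N": N, "m": m_feat, "softmax_ms": soft_ms, "performer_ms": perf_ms})
print(pd.DataFrame(rows))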
References
- Su Jianlin — the view that softmax attention is linear attention in an infinite-dimensional feature space, plus constructive feature maps for the exp kernel.
- Choromanski et al., "Rethinking Attention with Performers" (ICLR 2021): FAVOR+ positive random features and linear-time attention.
- Dao et al., FlashAttention series: fused softmax attention kernels for long contexts.