Motivation

Softmax attention is O(N²) in sequence length. Linear attention (e.g., Performer) replaces the exponential kernel with a low-rank feature map so the compute becomes O(N·m). In theory, that should be faster, so why did my early experiments show no speedup?

Softmax attention as a kernel smoother

For a single head, classic attention uses the exponential kernel $$\kappa(q,k) = e^{q\cdot k}.$$ Softmax attention for a query q_i can be written as a normalized kernel smoother: $$y_i = \frac{\sum_j \kappa(q_i,k_j)\,v_j}{\sum_j \kappa(q_i,k_j)}.$$ If we factor the kernel as an inner product of features, $$\kappa(q,k)=\phi(q)^\top\psi(k),$$ then $$y_i = \frac{\phi(q_i)^\top\Big(\sum_j \psi(k_j)v_j^\top\Big)}{\phi(q_i)^\top\Big(\sum_j \psi(k_j)\Big)}.$$ This moves the $\sum_j$ outside the per-token computation, enabling global sums (bidirectional) or prefix sums (causal). That algebra is what turns O(N²) into O(N·m).
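
To see the reordering concretely, here is a tiny numerical check (a throwaway sketch, not part of the model code): for any finite, non-negative feature maps φ and ψ, the factored form reproduces the naive double sum exactly.

import torch

torch.manual_seed(0)
N, d, m = 7, 4, 5
phi_q = torch.rand(N, m)            # stand-ins for phi(q_i) and psi(k_j); any non-negative features work
psi_k = torch.rand(N, m)
v     = torch.randn(N, d)

# Naive O(N^2): materialize the full kernel matrix, then normalize row-wise
K = phi_q @ psi_k.T                                  # K[i, j] = phi(q_i) · psi(k_j)
y_naive = (K @ v) / K.sum(dim=1, keepdim=True)

# Factored O(N·m): accumulate sum_j psi(k_j) v_j^T and sum_j psi(k_j) once, reuse for every query
KV = psi_k.T @ v                                     # (m, d)
z  = psi_k.sum(dim=0)                                # (m,)
y_fast = (phi_q @ KV) / (phi_q @ z).unsqueeze(-1)

print(torch.allclose(y_naive, y_fast, atol=1e-6))    # True: same output, different cost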

Three constructive feature maps for exp(q·k)

  1. Random positive features (FAVOR+): draw $\Omega\sim\mathcal N(0,I)$ and set $$\phi(x)=\frac{1}{\sqrt m}\exp\big(\Omega x - \tfrac{1}{2}\|x\|^2\big).$$ The estimate $\phi(q)^\top\phi(k)$ is unbiased for $e^{q\cdot k}$, and its variance shrinks as m grows.

  2. Taylor truncation: $$e^{q\cdot k} \approx \sum_{r=0}^{n} \frac{(q\cdot k)^r}{r!},$$ realized with tensor-power features via $$(q\cdot k)^r = (q^{\otimes r})\cdot (k^{\otimes r}).$$ Mostly a theoretical lens, since the feature dimension blows up as d^r (see the sketch after this list).

  3. Limit definition: $$e^{x}=\lim_{n\to\infty}\big(1+\tfrac{x}{n}\big)^{n}.$$ Finite n → polynomial kernel realized via an n-fold tensor power of an augmented vector.
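
As a quick illustration of the tensor-power idea in item 2 (a throwaway sketch, not used in the model below), degree-2 features [1, x, x⊗x/√2] reproduce the truncated Taylor series of exp(q·k) exactly:

import math
import torch

torch.manual_seed(0)
d = 6
q, k = torch.randn(d) * 0.3, torch.randn(d) * 0.3    # small norms keep the truncation accurate

def taylor2_features(x):
    # phi(x) = [1, x, vec(x x^T)/sqrt(2!)], so phi(q)·phi(k) = 1 + q·k + (q·k)^2 / 2
    return torch.cat([torch.ones(1), x, torch.outer(x, x).reshape(-1) / math.sqrt(2)])

approx = taylor2_features(q) @ taylor2_features(k)
exact  = torch.exp(q @ k)
print(float(approx), float(exact))                   # close for small q·k; feature dim is 1 + d + d^2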

Minimal, corrected Performer-style MHA (sketch)

Core pieces that mattered in practice:

  • Standard Wq/Wk/Wv + output projection; split into heads.

  • The 1/√d_head temperature split across the q and k paths (each scaled by d_head^(-1/4)), so the feature-map kernel matches softmax's q·k/√d_head.

  • Positive random features (FAVOR+) with unbiased Ω ~ N(0, I).

  • Prefix sums for causal attention; global sums for bidirectional.

  • Small ε in denominators.

    Tip: for an orthogonal Ω variant, multiply the orthonormal rows by chi-distributed radii so the rows emulate Gaussian draws.

Experimental setup

  • Data: TinyStories subset (train 50k, 20% valid).
  • Tokenization: fastai TextBlock, seq_len=80.
  • Model: 2 layers, embed_dim=64, n_head=4, GELU FFN×4, LayerNorm, dropout≈0.
  • Training: 2 epochs, 1cycle, lr≈2e−3.
  • Performer rank: proj_dim m ∈ {64,128} (final runs used m=128).

Code

import math, random, time
from pathlib import Path

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from fastai.text.all import *

class RandomFeatureMap(nn.Module):
    """
    FAVOR+ style positive random features for k(q,k) = exp(q·k).
    phi(x) = exp(Ω x - ||x||^2 / 2) / sqrt(m), with Ω ~ N(0,I).
    """
    def __init__(self, dim: int, proj_dim: int, orthogonal: bool = False, seed: int | None = None):
        super().__init__()
        self.proj_dim, self.dim = proj_dim, dim

        g = torch.Generator()
        if seed is not None: g.manual_seed(seed)

        if not orthogonal:
            omega = torch.randn(proj_dim, dim, generator=g)
        else:
            # Orthogonal random features: orthonormal blocks of size dim, rescaled by chi-distributed radii
            blocks = []
            remain = proj_dim
            while remain > 0:
                b = min(dim, remain)
                Q, _ = torch.linalg.qr(torch.randn(dim, dim, generator=g))
                # draw radii so that rows emulate N(0, I)
                radii = torch.sqrt(torch.distributions.Chi2(dim).sample((b,)))
                blocks.append((radii.unsqueeze(1) * Q[:b, :]))
                remain -= b
            omega = torch.cat(blocks, dim=0)

        self.register_buffer("omega", omega)

    def phi(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, H, N, d)
        returns: (B, H, N, m)
        """
        proj = torch.einsum("bhnd,md->bhnm", x, self.omega)              # (B,H,N,M)
        sq   = (x.pow(2).sum(dim=-1, keepdim=True)) * 0.5                # (B,H,N,1)
        return torch.exp(proj - sq) / math.sqrt(self.proj_dim)           # positive features
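
# Quick sanity check (illustrative only, not part of the training script):
# phi(q)·phi(k) is an unbiased Monte-Carlo estimate of exp(q·k), and the
# estimate tightens as the number of random features m grows.
q_demo = torch.randn(1, 1, 1, 16) * 0.5
k_demo = torch.randn(1, 1, 1, 16) * 0.5
exact_demo = torch.exp((q_demo * k_demo).sum(-1))
for m_demo in (16, 64, 256, 1024):
    fm_demo = RandomFeatureMap(dim=16, proj_dim=m_demo, seed=0)
    est_demo = (fm_demo.phi(q_demo) * fm_demo.phi(k_demo)).sum(-1)
    print(m_demo, float(est_demo), float(exact_demo))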


class MultiheadPerformerAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, proj_dim: int,
                 causal: bool = True, eps: float = 1e-6, orthogonal: bool = False):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.h  = num_heads
        self.d  = embed_dim // num_heads
        self.m  = proj_dim
        self.eps = eps
        self.causal = causal

        # Standard projections + output (like MHA)
        self.wq = nn.Linear(embed_dim, embed_dim, bias=False)
        self.wk = nn.Linear(embed_dim, embed_dim, bias=False)
        self.wv = nn.Linear(embed_dim, embed_dim, bias=False)
        self.wo = nn.Linear(embed_dim, embed_dim, bias=False)

        self.feature = RandomFeatureMap(self.d, proj_dim, orthogonal=orthogonal)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, E = x.shape
        H, D = self.h, self.d

        # [B,N,E] -> [B,H,N,D]; split the 1/sqrt(d_head) temperature across q and k
        # (d_head**-0.25 each) so the feature kernel approximates exp(q·k / sqrt(d_head))
        q = self.wq(x).view(B, N, H, D).transpose(1, 2) * D ** -0.25
        k = self.wk(x).view(B, N, H, D).transpose(1, 2) * D ** -0.25
        v = self.wv(x).view(B, N, H, D).transpose(1, 2)

        phi_q = self.feature.phi(q)     # (B,H,N,M)
        phi_k = self.feature.phi(k)     # (B,H,N,M)

        if not self.causal:
            # Global sums
            KV = torch.einsum("bhnm,bhnd->bhmd", phi_k, v)                  # (B,H,M,D)
            z  = torch.einsum("bhnm,bhm->bhn",     phi_q, phi_k.sum(2))     # (B,H,N)
            y  = torch.einsum("bhnm,bhmd->bhnd",   phi_q, KV)               # (B,H,N,D)
            y  = y / (z.clamp_min(self.eps).unsqueeze(-1))
        else:
            # Causal: prefix sums over N
            # Outer product per position: (B,H,N,M,D)
            outer = phi_k.unsqueeze(-1) * v.unsqueeze(-2)
            KVpref = outer.cumsum(dim=2)                                     # (B,H,N,M,D)
            bpref  = phi_k.cumsum(dim=2)                                     # (B,H,N,M)
            # Contract M: (B,H,N,D) and (B,H,N)
            y  = torch.einsum("bhnm,bhnmd->bhnd", phi_q, KVpref)
            z  = torch.einsum("bhnm,bhnm->bhn",   phi_q, bpref)
            y  = y / (z.clamp_min(self.eps).unsqueeze(-1))

        y = y.transpose(1, 2).contiguous().view(B, N, E)                     # [B,N,E]
        return self.wo(y)
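
# Smoke test (illustrative only): causal Performer attention must not look ahead,
# so perturbing tokens at positions >= 5 should leave outputs at positions < 5 unchanged.
perf_attn = MultiheadPerformerAttention(embed_dim=64, num_heads=4, proj_dim=32, causal=True)
x_demo = torch.randn(2, 10, 64)
x_pert = x_demo.clone()
x_pert[:, 5:] += 1.0                                   # change only the "future" tokens
with torch.no_grad():
    y_ref, y_pert = perf_attn(x_demo), perf_attn(x_pert)
print(torch.allclose(y_ref[:, :5], y_pert[:, :5], atol=1e-5))   # expected: True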

class PerformerBlock(nn.Module):
    def __init__(self, dim, n_head=8, proj_dim=64, causal=True, dropout=0.0, orthogonal=False):
        super().__init__()
        self.attn = MultiheadPerformerAttention(dim, n_head, proj_dim, causal=causal, orthogonal=orthogonal)
        self.norm1 = nn.LayerNorm(dim)
        self.ffn   = nn.Sequential(
            nn.Linear(dim, dim*4), nn.GELU(), nn.Linear(dim*4, dim)
        )
        self.norm2 = nn.LayerNorm(dim)
        self.drop  = nn.Dropout(dropout)

    def forward(self, x):
        x = self.norm1(x + self.drop(self.attn(x)))
        x = self.norm2(x + self.drop(self.ffn(x)))
        return x

from datasets import load_dataset

hf_ds = load_dataset("roneneldan/TinyStories")   # {'train': Dataset, 'validation': Dataset}

TARGET_SIZE = 50_000                
train_full  = hf_ds["train"]

SEED = 42
rng = random.Random(SEED)
indices = list(range(len(train_full)))
rng.shuffle(indices)

# Keep the first TARGET_SIZE shuffled indices (or the whole set if it’s smaller)
sub_idx = indices[:min(TARGET_SIZE, len(train_full))]
train_subset = train_full.select(sub_idx)   # new Dataset with at most TARGET_SIZE rows

# fastai's TextBlock.from_df tokenises a DataFrame column, so convert the HF subset first.
df_train = train_subset.to_pandas()

# Define a fastai TextBlock that tokenises the “text” column for language modelling.
text_block = TextBlock.from_df(
    text_cols="text",
    is_lm=True,
    seq_len=80,                 # keep your window size
)

# Build the DataBlock and dataloaders from the DataFrame.
tinystories_block = DataBlock(
    blocks=(text_block,),
    get_x=ColReader("text"),
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
)

dls = tinystories_block.dataloaders(
    df_train,
    bs=128,
    num_workers=12,
    pin_memory=True,
    shuffle=True,
)
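
# Quick inspection (optional): vocab size and batch shapes, given bs=128 and seq_len=80 above.
xb, yb = dls.one_batch()
print(len(dls.vocab), xb.shape, yb.shape)              # targets are the inputs shifted by one token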

class SimpleMultiHeadAttention(Module):
    """
    A lightweight multi-head attention that mimics fastai's implementation.
    - `d_model` : hidden dimension (same as embed_dim)
    - `n_head`  : number of attention heads
    - `bias`    : whether to use bias in the linear projections
    - `causal`  : if True, applies an upper-triangular mask
    """
    def __init__(self,
                 d_model: int,
                 n_head: int = 8,
                 bias: bool = True,
                 causal: bool = False,
                 attn_dropout: float = 0.):
        
        assert d_model % n_head == 0, "d_model must be divisible by n_head"
        self.n_head   = n_head
        self.d_head   = d_model // n_head
        self.scale    = 1.0 / math.sqrt(self.d_head)

        # Combined q,k,v projection for efficiency
        self.W_qkv = nn.Linear(d_model, d_model * 3, bias=bias)

        self.out_proj  = nn.Linear(d_model, d_model, bias=bias)
        self.dropout   = nn.Dropout(attn_dropout)
        self.causal    = causal

    def forward(self, x):
        """
        x : (B, N, d_model)
        returns (B, N, d_model)
        """
        B, N, _ = x.shape
        # qkv shape → (B, N, 3, n_head, d_head)
        qkv = self.W_qkv(x).view(B, N, 3, self.n_head, self.d_head)
        q, k, v = qkv.unbind(dim=2)                     # each -> (B, N, n_head, d_head)

        # transpose for batched matmul: (B, n_head, N, d_head)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # scaled dot-product
        attn_weights = (q @ k.transpose(-2, -1)) * self.scale   # (B, n_head, N, N)

        if self.causal:
            # Upper-triangular mask: allow only j ≤ i
            mask = torch.triu(torch.ones(N, N, device=x.device), diagonal=1).bool()
            attn_weights = attn_weights.masked_fill(mask, float('-inf'))

        attn = torch.softmax(attn_weights, dim=-1)
        attn = self.dropout(attn)

        # Weighted sum of values
        out = (attn @ v)                                   # (B, n_head, N, d_head)
        out = out.transpose(1, 2).contiguous().view(B, N, -1)  # (B, N, d_model)

        return self.out_proj(out)
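
# Sanity check (optional; assumes PyTorch >= 2.0 for F.scaled_dot_product_attention):
# the hand-written attention should match the fused reference on the same q, k, v.
torch.manual_seed(0)
mha_demo = SimpleMultiHeadAttention(d_model=64, n_head=4, causal=True)
x_demo = torch.randn(2, 16, 64)
with torch.no_grad():
    B_, N_, _ = x_demo.shape
    qkv_demo = mha_demo.W_qkv(x_demo).view(B_, N_, 3, mha_demo.n_head, mha_demo.d_head)
    q_, k_, v_ = (t.transpose(1, 2) for t in qkv_demo.unbind(dim=2))
    ref = F.scaled_dot_product_attention(q_, k_, v_, is_causal=True)
    ref = mha_demo.out_proj(ref.transpose(1, 2).reshape(B_, N_, -1))
    print(torch.allclose(mha_demo(x_demo), ref, atol=1e-5))     # expected: True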

class TransformerBlock(Module):
    """
    Minimal fastai-style transformer block.
    Parameters match the original fastai class so you can use it interchangeably.
    """
    def __init__(self,
                 d_model: int,
                 n_head: int = 8,
                 d_ff: int = None,          # hidden size of the FFN; default = 4 * d_model
                 bias: bool = True,
                 dropout: float = 0.,       # dropout after attention & FFN
                 attn_dropout: float = 0., # dropout inside attention
                 ff_dropout: float = 0.,    # dropout inside the feed-forward
                 causal: bool = False,
                 act_fn: nn.Module = nn.GELU()):
        
        self.attn = SimpleMultiHeadAttention(d_model,
                                            n_head=n_head,
                                            bias=bias,
                                            causal=causal,
                                            attn_dropout=attn_dropout)

        self.ln1 = nn.LayerNorm(d_model)

        # Feed-forward network (two linear layers + activation)
        ff_dim = d_ff or 4 * d_model
        self.ff = nn.Sequential(
            nn.Linear(d_model, ff_dim, bias=bias),
            act_fn,
            nn.Dropout(ff_dropout),
            nn.Linear(ff_dim, d_model, bias=bias),
            nn.Dropout(dropout)
        )

        self.ln2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        x : (B, N, d_model)
        Returns the same shape after applying:
          self-attention + residual + norm
          feed-forward + residual + norm
        """
        attn_out = self.attn(x)
        x = x + self.dropout(attn_out)
        x = self.ln1(x)

        ff_out = self.ff(x)
        x = x + ff_out
        x = self.ln2(x)

        return x


class SoftmaxLM(Module):
    """Standard fastai transformer with softmax attention."""
    def __init__(self, vocab_sz, embed_dim, n_head, n_layer, causal=True):
        self.emb  = nn.Embedding(vocab_sz, embed_dim)
        self.pos  = nn.Parameter(torch.randn(1, 512, embed_dim))   # learned positional embeddings (max 512 tokens)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model=embed_dim,
                             n_head=n_head,
                             causal=causal) for _ in range(n_layer)
        ])
        self.ln   = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_sz)

    def forward(self, x):
        b, n = x.shape
        x = self.emb(x) + self.pos[:, :n]
        for blk in self.blocks:
            x = blk(x)
        x = self.ln(x)
        return self.head(x)


class PerformerLM(Module):
    """Transformer that swaps the softmax block for Performer-style linear attention."""
    def __init__(self, vocab_sz, embed_dim, n_head, n_layer, proj_dim, causal=False):
        self.emb  = nn.Embedding(vocab_sz, embed_dim)
        self.pos  = nn.Parameter(torch.randn(1, 512, embed_dim))   # learned positional embeddings (max 512 tokens)
        self.blocks = nn.ModuleList([
            PerformerBlock(dim=embed_dim,
                          n_head=n_head,
                          proj_dim=proj_dim,
                          causal=causal) for _ in range(n_layer)
        ])
        self.ln   = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_sz)

    def forward(self, x):
        b, n = x.shape
        x = self.emb(x) + self.pos[:, :n]
        for blk in self.blocks:
            x = blk(x)
        x = self.ln(x)
        return self.head(x)


class EpochTimer(Callback):
    """Records wall-clock time for each epoch."""
    def before_fit(self):
        self.epoch_times = []

    def after_epoch(self):
        self.epoch_times.append(time.time() - self.epoch_start)

    def before_epoch(self):
        self.epoch_start = time.time()


def train_one(model, name, epochs=2, lr=2e-3):
    timer = EpochTimer()
    learn = Learner(dls, model,
                    loss_func=CrossEntropyLossFlat(),
                    metrics=[Perplexity()],         # fastai's built-in perplexity metric
                    cbs=timer)
    print(f"\n=== Training {name} ===")
    learn.fit_one_cycle(epochs, lr)
    final_ppl = float(learn.validate()[1])          # validate() -> (loss, perplexity)
    return timer.epoch_times, final_ppl


# Hyperparameters matching the experimental setup above
vocab_sz  = len(dls.vocab)
embed_dim = 64
n_head    = 4
n_layer   = 2
proj_dim  = 128
causal    = True            # autoregressive language modelling

softmax_model = SoftmaxLM(vocab_sz, embed_dim, n_head, n_layer)
soft_times, soft_ppl = train_one(softmax_model, "Softmax-Transformer")

performer_model = PerformerLM(vocab_sz, embed_dim, n_head, n_layer,
                                proj_dim=proj_dim, causal=causal)
perf_times, perf_ppl = train_one(performer_model, "Performer-Transformer")

summary = pd.DataFrame({
    "Model": ["Softmax-Transformer", "Performer-Transformer"],
    "Final Perplexity": [soft_ppl, perf_ppl],
    "Epoch times (sec)": [soft_times, perf_times],
    "Avg time (sec)": [sum(soft_times)/len(soft_times),
                        sum(perf_times)/len(perf_times)]
})
print("\n=== Comparison ===")
print(summary.to_string(index=False))

Results (representative)

Model                           Valid PPL   Epoch time (s)
Softmax Transformer             ~1.09       ~127
Performer Transformer (m=128)   ~1.13       ~128

Perplexity parity achieved; no speedup at N=80.

Why no speedup (yet)

Back-of-envelope FLOPs per head (ignoring projections/FFN):

Softmax: $\approx 4N^{2}D$

Performer: $\approx c\,N\,m\,D$, with c ≈ 6–10 (feature projections, exp, and cumsums).

Break-even: $$4N^2D \approx c\,N\,m\,D\;\Rightarrow\; N \approx \frac{c}{4}\,m.$$

With c ≈ 8 and m = 128, the break-even point is N ≈ (8/4)·128 = 256.
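
A two-line helper makes the arithmetic explicit (a throwaway sketch; c is a rough fudge factor, not a measured constant):

def breakeven_seq_len(m: int, c: float = 8.0) -> float:
    # 4·N²·D = c·N·m·D  =>  N = c·m/4
    return c * m / 4

print([(m, breakeven_seq_len(m)) for m in (64, 128, 256)])   # m=128 -> N ≈ 256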

At short contexts (N ≪ 256), softmax’s O(N²) cost is still small, and modern softmax kernels (FlashAttention/xFormers) are heavily fused. A naïve einsum-based linear attention will not be faster.

What I would do next

  • Long-context grid: (N×m) sweep to locate the speed crossover on my GPU.

  • Fused kernels: compare Softmax(FlashAttention) vs Performer(fused FAVOR+) to reflect realistic deployments.

  • Hybrid maps: combine a small random-feature bank with a low-degree deterministic Taylor tail to reduce variance.

  • Layer-wise m: try larger m only in upper layers or selected heads.

  • Task sensitivity: evaluate on tasks where attention patterns are smoother (long-range modeling), where linear attention often shines.

References

  • Su Jianlin — the view that softmax attention is linear attention in an infinite-dimensional feature space, plus constructive feature maps for the exp kernel.

  • Choromanski et al., Rethinking Attention with Performers (ICLR 2021): FAVOR+ positive random features and linear-time attention.

  • Dao et al., FlashAttention series: fused softmax attention kernels for long contexts.