Introduction

Transformers have revolutionized deep learning across domains from NLP to computer vision, but how do they actually work? While we know how to build and train these models, their internal mechanics remain somewhat mysterious.

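Recently I came across the paper "Attention is Not All You Need: Pure Attention Loses Rank Doubly Exponentially with Depth". It argues that attention on its own is not enough: without MLPs and skip connections, stacked self-attention layers degenerate, so those components play a huge role in what transformers can learn. We'll start by exploring two fascinating concepts from it: attention paths and rank collapse.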

The Mathematical Foundation of Self-Attention

Before diving into the more complex concepts, let's review the basic building blocks of self-attention networks (SANs).

In a SAN, the input is a sequence of vectors $x_1, x_2, \dots, x_N$, stacked as the rows of $X \in \mathbb{R}^{N \times d}$, where $x_i \in \mathbb{R}^d$ represents a token embedding, $N$ is the sequence length, and $d$ is the embedding dimension.

Queries, Keys, Values

For each attention head, we transform the input into three roles:

  • Queries: $Q = XW_Q$, where $W_Q \in \mathbb{R}^{d \times d_k}$
  • Keys: $K = XW_K$, where $W_K \in \mathbb{R}^{d \times d_k}$
  • Values: $V = XW_V$, where $W_V \in \mathbb{R}^{d \times d_v}$

Here, $d_k$ and $d_v$ are dimensions of keys and values, typically $d_k = d_v = d/H$ where $H$ is the number of heads.

Attention Scores

We compute how much each token attends to others:

$A = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)$, where the softmax is applied to each row

  • $QK^T \in \mathbb{R}^{N \times N}$ is a matrix of dot products measuring similarity between queries and keys
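  • $\sqrt{d_k}$ scales the scores down: for queries and keys with roughly unit-variance entries, a dot product has variance about $d_k$, and without this scaling large scores would saturate the softmax and shrink its gradients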
  • $A_{ij}$ is the attention weight from token $i$ to token $j$

Output

The output of a single attention head is:

$\text{head} = AV$

So, $\text{head} \in \mathbb{R}^{N \times d_v}$.
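The steps above fit in a few lines of NumPy. This is a minimal sketch with random weights; the toy sizes ($N=5$, $d=8$, $d_k=d_v=4$) and the function names `softmax` and `attention_head` are illustrative choices, not part of the original text.

```python
import numpy as np

def softmax(scores):
    # Row-wise softmax; subtracting the row max avoids overflow in exp.
    shifted = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)

def attention_head(X, W_Q, W_K, W_V):
    """One self-attention head. X is (N, d); returns A (N, N) and head (N, d_v)."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    d_k = W_Q.shape[1]
    A = softmax(Q @ K.T / np.sqrt(d_k))  # A[i, j]: weight of token i on token j
    return A, A @ V

rng = np.random.default_rng(0)
N, d, d_k, d_v = 5, 8, 4, 4
X = rng.normal(size=(N, d))
A, head = attention_head(X, rng.normal(size=(d, d_k)),
                         rng.normal(size=(d, d_k)), rng.normal(size=(d, d_v)))
print(A.sum(axis=1))  # every row sums to 1
print(head.shape)     # (5, 4)
```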

Multi-Head Attention

The complete multi-head attention operation uses $H$ heads, each with its own $W_Q^h, W_K^h, W_V^h$:

$\text{MultiHead}(X) = \text{Concat}(\text{head}_1, ..., \text{head}_H)W_O$

where $W_O \in \mathbb{R}^{Hd_v \times d}$ projects the concatenated outputs back to $d$-dimensional space.
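A sketch of the multi-head version, assuming $d_k = d_v = d/H$ and random weights. The `heads` list of `(W_Q, W_K, W_V)` tuples is a hypothetical container chosen for readability; the point is the concatenation along the feature axis followed by the $W_O$ projection back to width $d$.

```python
import numpy as np

def softmax(scores):
    shifted = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)

def multi_head(X, heads, W_O):
    """heads: list of (W_Q, W_K, W_V) tuples, one per head; W_O: (H * d_v, d)."""
    outputs = []
    for W_Q, W_K, W_V in heads:
        Q, K, V = X @ W_Q, X @ W_K, X @ W_V
        A = softmax(Q @ K.T / np.sqrt(W_Q.shape[1]))
        outputs.append(A @ V)                     # (N, d_v)
    return np.concatenate(outputs, axis=1) @ W_O  # (N, H*d_v) @ (H*d_v, d) -> (N, d)

rng = np.random.default_rng(0)
N, d, H = 5, 8, 2
d_k = d_v = d // H
heads = [tuple(rng.normal(size=(d, d_k)) for _ in range(3)) for _ in range(H)]
W_O = rng.normal(size=(H * d_v, d))
out = multi_head(rng.normal(size=(N, d)), heads, W_O)
print(out.shape)  # (5, 8)
```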

Layer Structure

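In a full transformer layer, the attention is wrapped in a residual connection and followed by an MLP with its own residual connection (layer normalization is omitted here for simplicity):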

$X' = X + \text{MultiHead}(X) \text{ (residual connection)}$

$X_{out} = X' + \text{MLP}(X') \text{ (another residual connection)}$

The MLP typically consists of two linear layers with a nonlinearity, applied to each row (token) of $X'$ independently: $\text{MLP}(x) = W_2 \cdot \text{ReLU}(W_1 x + b_1) + b_2$
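Putting the pieces together, here is a minimal sketch of one layer, assuming no layer normalization and an arbitrary hidden width `d_ff = 32` (both are assumptions, not taken from the text). The MLP is written row-wise, so $W_1$ has shape $(d_{ff}, d)$ exactly as in $W_1 x$.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def multi_head(X, heads, W_O):
    # heads: list of (W_Q, W_K, W_V); concatenated head outputs are projected by W_O.
    outs = [softmax((X @ Wq) @ (X @ Wk).T / np.sqrt(Wq.shape[1])) @ (X @ Wv)
            for Wq, Wk, Wv in heads]
    return np.concatenate(outs, axis=1) @ W_O

def mlp(X, W1, b1, W2, b2):
    # Row-wise version of MLP(x) = W2 ReLU(W1 x + b1) + b2 for every token x.
    return np.maximum(X @ W1.T + b1, 0.0) @ W2.T + b2

def layer(X, heads, W_O, W1, b1, W2, b2):
    X_prime = X + multi_head(X, heads, W_O)        # residual around attention
    return X_prime + mlp(X_prime, W1, b1, W2, b2)  # residual around the MLP

rng = np.random.default_rng(0)
N, d, H, d_ff = 5, 8, 2, 32
heads = [tuple(rng.normal(size=(d, d // H)) for _ in range(3)) for _ in range(H)]
W_O = rng.normal(size=(H * (d // H), d))
W1, b1 = rng.normal(size=(d_ff, d)), np.zeros(d_ff)
W2, b2 = rng.normal(size=(d, d_ff)), np.zeros(d)
X_out = layer(rng.normal(size=(N, d)), heads, W_O, W1, b1, W2, b2)
print(X_out.shape)  # (5, 8)
```

Setting $W_O$ and $W_2$ (and $b_2$) to zero makes the layer return its input unchanged, which is one quick way to see the two skip connections at work.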

Understanding Attention Paths

What is a Path?

In the context of transformers, a path refers to a specific route through the self-attention network. At each layer, you can choose one of the $H$ attention heads or bypass the layer via a skip connection.

Mathematically, we can define a path $P = (p_1, p_2, ..., p_L)$, where $p_l \in \{0, 1, ..., H\}$:

  • $p_l = h$ (1 to $H$): Use head $h$ at layer $l$
  • $p_l = 0$: Bypass the layer (skip connection)


Visualizing Paths

Let's interpret two example paths:

  1. Red Path (2, ..., 1): This path starts with head 2 in the first layer and ends with head 1 in the last layer, with unspecified heads in between. One reading is that it mixes different heads, and therefore possibly different aspects of the input, at different layers.

  2. Blue Path (H, 2, ..., 2): This path starts with the last head ($H$) and then uses head 2 in every subsequent layer. One reading is a more specialized route. Note, however, that "head 2" in different layers has different parameters, and head indices are only labels, so the index alone does not say that head $H$ is broader than head 2.

Why Do Paths Matter?

The total output $X_L$ of a SAN can be decomposed as a sum of contributions from all possible paths:

$X_L = \sum_{P \in \mathcal{P}} w_P X_L^P$

where $\mathcal{P}$ is the set of all possible paths ($(H+1)^L$ of them, since each of the $L$ layers offers $H$ heads plus the skip connection), $X_L^P$ is the output of path $P$, and $w_P$ weights its contribution.
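The decomposition can be checked exactly in a simplified setting. This sketch assumes that every attention matrix is frozen (input-independent, here random row-stochastic matrices) and that each head's $W_V^h$ and the matching slice of $W_O$ are folded into one $d \times d$ matrix. Under those assumptions each layer is linear, so the forward pass expands into a sum over all $(H+1)^L$ paths with every $w_P = 1$. The toy sizes and helper names are hypothetical.

```python
import itertools
import numpy as np

rng = np.random.default_rng(0)
N, d, H, L = 4, 3, 2, 3   # toy sizes: 4 tokens, width 3, 2 heads, 3 layers

def random_row_stochastic(n):
    # Hypothetical frozen attention matrix: positive entries, each row sums to 1.
    M = rng.random((n, n))
    return M / M.sum(axis=1, keepdims=True)

# P[l][h]: frozen attention matrix of head h in layer l.
# W[l][h]: W_V^h times the matching slice of W_O, folded into one d x d matrix.
P = [[random_row_stochastic(N) for _ in range(H)] for _ in range(L)]
W = [[0.5 * rng.normal(size=(d, d)) for _ in range(H)] for _ in range(L)]
X0 = rng.normal(size=(N, d))

def head(l, h, X):
    return P[l][h] @ X @ W[l][h]

def full_forward(X):
    # Every layer: skip connection plus the sum of all heads.
    for l in range(L):
        X = X + sum(head(l, h, X) for h in range(H))
    return X

def path_output(path, X):
    # path[l] == 0 means "skip layer l"; path[l] == h in 1..H means "use head h".
    for l, p in enumerate(path):
        if p != 0:
            X = head(l, p - 1, X)
    return X

paths = list(itertools.product(range(H + 1), repeat=L))
X_full = full_forward(X0)
X_paths = sum(path_output(path, X0) for path in paths)  # all w_P = 1 here

print(len(paths))                    # 27, i.e. (H + 1) ** L
print(np.allclose(X_full, X_paths))  # True
```

In a real transformer the attention matrices depend on the input, so this identity is a linearized picture rather than a full description of the network.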

This decomposition reveals:

  1. Interpretability: By analyzing individual paths, we can better understand how information flows through the network
  2. Efficiency: If only a few paths contribute significantly, we could potentially prune the others
  3. Theory: It provides a framework for understanding how depth functions in these networks

The challenge? With $H=8$ and $L=6$, there are $9^6 = 531{,}441$ paths! This complexity motivates the study of "weakly interdependent" paths.

Weak Interdependence Between Paths

For paths to be "weakly interdependent," their interactions must be small or negligible. Several mathematical properties can characterize this:

1. Orthogonality of Path Contributions

If the outputs $X_L^P$ and $X_L^Q$ of two distinct paths $P$ and $Q$ are orthogonal under the Frobenius inner product $\langle A, B \rangle = \text{tr}(A^T B)$:

$\langle X_L^P, X_L^Q \rangle = 0 \text{ for } P \neq Q$

This ensures each path contributes to the output in a way that is linearly independent of other paths.
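One way to check this criterion numerically is to normalize the inner product into a cosine similarity, where 0 means orthogonal. In this sketch `X_P` and `X_Q` are random stand-ins for path outputs (hypothetical, not taken from a model), and the Frobenius inner product is assumed, since the criterion does not fix which inner product on matrices is meant.

```python
import numpy as np

def frobenius_inner(A, B):
    # <A, B> = tr(A^T B) = sum of elementwise products
    return float(np.sum(A * B))

def path_cosine(A, B):
    # Normalized overlap in [-1, 1]; 0 means the two path outputs are orthogonal.
    return frobenius_inner(A, B) / (np.linalg.norm(A) * np.linalg.norm(B))

rng = np.random.default_rng(0)
X_P = rng.normal(size=(6, 4))
X_Q = rng.normal(size=(6, 4))
# One Gram-Schmidt step: remove the component of X_Q along X_P.
X_Q_orth = X_Q - frobenius_inner(X_Q, X_P) / frobenius_inner(X_P, X_P) * X_P

print(round(path_cosine(X_P, X_P), 6))          # 1.0
print(abs(path_cosine(X_P, X_Q_orth)) < 1e-10)  # True
```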

2. Low Covariance Between Paths

If the covariance between path outputs is small:

$\text{Cov}(X_L^P, X_L^Q) = \mathbb{E}[(X_L^P - \mathbb{E}[X_L^P])(X_L^Q - \mathbb{E}[X_L^Q])^T] \approx 0$

This indicates that fluctuations in one path's contribution don't strongly correlate with fluctuations in another's.
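To estimate this covariance in practice, one can treat each path output as a flattened random vector and sample it over many inputs. The sketch below uses synthetic samples (hypothetical, not from a real network): one "path" unrelated to `U` and one that tracks it. The Frobenius norm of the cross-covariance summarizes how far it is from 0.

```python
import numpy as np

def cross_covariance(U, V):
    """Empirical Cov(u, v) from samples: U is (n_samples, p), V is (n_samples, q)."""
    U_c = U - U.mean(axis=0)
    V_c = V - V.mean(axis=0)
    return U_c.T @ V_c / (U.shape[0] - 1)  # (p, q), unbiased (n - 1) normalization

rng = np.random.default_rng(0)
n, p = 2000, 6
U = rng.normal(size=(n, p))                          # flattened outputs of path P
V_indep = rng.normal(size=(n, p))                    # an unrelated path Q
V_coupled = 0.8 * U + 0.2 * rng.normal(size=(n, p))  # a path Q that tracks path P

print(np.linalg.norm(cross_covariance(U, V_indep)))    # small
print(np.linalg.norm(cross_covariance(U, V_coupled)))  # much larger
```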

3. Additive Contributions with Minimal Interaction

The output $X_L$ is well approximated by the linear sum $\sum_P w_P X_L^P$, i.e., higher-order interaction terms (e.g., products like $X_L^P X_L^Q$) are negligible.

4. Small Cross-Path Gradients

For parameters $\theta_P$ and $\theta_Q$ associated with paths $P$ and $Q$, the mixed second derivatives of the training loss $\mathcal{L}$ (not the depth $L$) are small:

$\left\|\frac{\partial^2 \mathcal{L}}{\partial \theta_P \partial \theta_Q}\right\| \ll 1 \text{ for } P \neq Q$

This means that updating one path's parameters barely changes the gradient seen by another path during training.

5. Low Mutual Information

The mutual information between the outputs of paths $P$ and $Q$ is small:

$I(X_L^P; X_L^Q) \approx 0$

This formalizes the idea that the paths are nearly independent.

Gateway Choices: Rethinking Information Flow

The concept of "gateway choices" offers a different way to think about how information flows through transformers. Rather than treating all information equally, certain components act as selective gateways that determine which information is emphasized or propagated.

This has several implications:

  1. Selective Routing: Transformers dynamically adjust their focus based on the input, prioritizing task-critical data.
  2. Dynamic Processing: The flow of information adapts to the context, with different aspects prioritized at different stages.
  3. Hierarchical Refinement: Early gateways might extract broad features, while later ones refine these into more specific representations.

Mathematically, gateway choices already exist in transformers:

  • Attention Weights: The softmax determines how much information flows between tokens
  • Residual Connections: Allow some information to bypass transformation

Rank Collapse: A Key Constraint

Perhaps the most intriguing concept is "rank collapse": the phenomenon where representations or attention matrices become low-rank as they pass through the network layers, in the extreme reaching rank 1, where every token's representation is identical.

What is Rank?

In linear algebra, the rank of a matrix is the number of linearly independent rows or columns. It reflects the amount of unique information the matrix contains.
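A quick illustration with `numpy.linalg.matrix_rank`. The three small matrices are hand-picked examples: one full-rank, one with a row that is just a multiple of another, and one whose rows are all identical, which is exactly the rank-1 state that rank collapse drives token representations toward.

```python
import numpy as np

full = np.array([[1.0, 0.0, 2.0],
                 [0.0, 1.0, 1.0],
                 [1.0, 1.0, 0.0]])
redundant = np.array([[1.0, 2.0, 3.0],
                      [2.0, 4.0, 6.0],   # 2 x row 0: no new information
                      [1.0, 0.0, 1.0]])
collapsed = np.tile([0.5, -1.0, 2.0], (4, 1))  # 4 identical "token" rows

print(np.linalg.matrix_rank(full))       # 3
print(np.linalg.matrix_rank(redundant))  # 2
print(np.linalg.matrix_rank(collapsed))  # 1
```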

The Rank Collapse Claim

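The paper's central claim is that pure attention, i.e., self-attention layers without skip connections or MLPs, "loses rank doubly exponentially" with depth. Concretely, the distance between the output and the nearest matrix of the form $\mathbf{1}x^T$ (all rows identical, hence rank 1) is bounded by a quantity that shrinks doubly exponentially in the number of layers $L$: the exponent itself grows exponentially with $L$, so the collapse accelerates as layers are stacked.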
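A minimal sketch of pure attention collapsing, under loudly simplified assumptions: $W_Q = W_K = W_V = I$ (hypothetical, not the paper's setup), a single head, no skip connection, and no MLP. As a stand-in for the paper's residual it measures the largest per-column range across tokens, which is 0 exactly when all rows are identical. This quantity can never grow here, because each output row is a convex combination of the input rows.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def spread(X):
    # Largest per-column range across tokens; 0 exactly when X = 1 x^T (rank 1).
    return float((X.max(axis=0) - X.min(axis=0)).max())

def pure_attention_layer(X):
    # Hypothetical simplification: W_Q = W_K = W_V = I, no skip, no MLP.
    d = X.shape[1]
    P = softmax(X @ X.T / np.sqrt(d))  # row-stochastic, strictly positive
    return P @ X

rng = np.random.default_rng(0)
X = rng.normal(size=(6, 4))
spreads = [spread(X)]
for _ in range(6):
    X = pure_attention_layer(X)
    spreads.append(spread(X))

for depth, s in enumerate(spreads):
    print(f"layer {depth}: spread = {s:.3e}")
```

The printout shows how quickly the token rows converge in this toy setting; it illustrates the direction of the effect, not the paper's exact bound.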

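Rank collapse is potentially caused by several properties of the attention operation:

  1. Softmax Saturation: If the dot products between queries and keys are large, the softmax rows become nearly one-hot. Sparsity alone does not lower the rank (a permutation-like pattern, where each token attends to a different key, is full rank), but when many queries saturate onto the same key, their rows become identical and the attention matrix loses rank. At the other extreme, very small scores make every row nearly uniform, which is rank 1.
  2. Correlation in Queries and Keys: Since both are linear projections of the same input $X$, they can be redundant.
  3. Layer Composition: Low rank at one layer propagates to the next, since $\text{rank}(PXW) \le \text{rank}(X)$, creating a feedback loop.
  4. Scaling Effects: The factor $1/\sqrt{d_k}$ sets how peaked the softmax is, and therefore where the attention matrix sits between the uniform and the one-hot extremes.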
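The saturation cases above can be checked directly. The score matrices below are hand-built (hypothetical), and the scale 50 is an arbitrary choice that forces the softmax into its one-hot regime.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

N = 4
uniform = softmax(np.zeros((N, N)))                           # tiny scores: every row is 1/N
same_key = softmax(np.tile([50.0, 0.0, 0.0, 0.0], (N, 1)))    # saturated, all queries pick key 0
diagonal = softmax(50.0 * np.eye(N))                          # saturated, each query picks its own key

for name, A in [("uniform", uniform), ("same key", same_key), ("diagonal", diagonal)]:
    print(name, np.linalg.matrix_rank(A))
# uniform 1
# same key 1
# diagonal 4
```

Both the uniform and the "everyone picks the same key" patterns are rank 1, while the equally sparse diagonal pattern keeps full rank.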

The Self-Attention Head Output

Writing the output of head $h$ with an explicit value bias $b_{(V,h)}$:

$\text{SA}_{(h)}(X) = P_{(h)} \cdot X \cdot W_{(V,h)} + \mathbf{1}b_{(V,h)}^T$

The properties of $P_{(h)}$ (the softmax attention matrix, i.e., $A$ from earlier; not to be confused with a path $P$) are crucial:

  • Row-Stochastic: Each row sums to 1, so every output row is a weighted average (convex combination) of the value vectors
  • Non-Negative: All entries are between 0 and 1
  • Variable Sparsity: Depending on input patterns, it can range from uniform to nearly one-hot

If $P_{(h)}$ becomes low-rank (e.g., because most tokens attend to the same keys), it constrains the rank of the output, since $\text{rank}(P_{(h)} X W_{(V,h)}) \le \text{rank}(P_{(h)})$, leading to rank collapse.
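A small check of this bound with random $X$, $W_{(V,h)}$, and $b_{(V,h)}$ (toy sizes chosen arbitrarily). The bias term $\mathbf{1}b^T$ has rank at most 1, so in general $\text{rank}(\text{SA}_{(h)}(X)) \le \text{rank}(P_{(h)}) + 1$. When every token attends only to the same token, the bias points in the same all-rows-equal direction and the output stays at rank 1.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d, d_v = 5, 4, 3
X = rng.normal(size=(N, d))
W_V = rng.normal(size=(d, d_v))
b_V = rng.normal(size=d_v)
ones = np.ones((N, 1))

def sa_head(P):
    # SA(X) = P X W_V + 1 b_V^T, with P a row-stochastic attention matrix
    return P @ X @ W_V + ones @ b_V[None, :]

P_collapsed = np.zeros((N, N))
P_collapsed[:, 2] = 1.0    # every token attends only to token 2
P_identity = np.eye(N)     # every token attends only to itself

print(np.linalg.matrix_rank(sa_head(P_collapsed)))  # 1
print(np.linalg.matrix_rank(sa_head(P_identity)))   # 3
```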

Coming Up in Part 2

In the next installment, we'll explore:

  • Dimensional Consistency in Transformers
  • Attention Mechanism Reformulation
  • The Mathematics of Rank Collapse
  • Concatenation vs. Summation in Multi-Head Attention
  • Scaling Factors and Bias Terms

Stay tuned for more deep dives into transformer theory in the upcoming parts of this series!