Attention Paths and Rank Collapse Part 1
This is the first part of a multi-part series diving into the theoretical underpinnings of transformer models. In this installment, we'll explore the concepts of attention paths and rank collapse, two fundamental ideas that help explain how transformers actually work under the hood.
- Introduction
- The Mathematical Foundation of Self-Attention
- Understanding Attention Paths
- Weak Interdependence Between Paths
- Gateway Choices: Rethinking Information Flow
- Rank Collapse: A Key Constraint
- Coming Up in Part 2
- References
Introduction
Transformers have revolutionized deep learning across domains from NLP to computer vision, but how do they actually work? While we know how to build and train these models, their internal mechanics remain somewhat mysterious.
Recently I came across the paper Attention is not all you need (Dong et al., 2021; full title: "Attention is not all you need: pure attention loses rank doubly exponentially with depth"), which argues that attention on its own is not enough: the MLP blocks and skip connections play a huge role in what these models learn. We'll start by exploring two fascinating concepts from that work: attention paths and rank collapse.
The Mathematical Foundation of Self-Attention
Before diving into the more complex concepts, let's review the basic building blocks of self-attention networks (SANs).
In a SAN, the input is a sequence of vectors $X = [x_1, x_2, ..., x_N]$, where $x_i \in \mathbb{R}^d$ represents a token embedding, $N$ is the sequence length, and $d$ is the embedding dimension.
Queries, Keys, Values
For each attention head, we transform the input into three roles:
- Queries: $Q = XW_Q$, where $W_Q \in \mathbb{R}^{d \times d_k}$
- Keys: $K = XW_K$, where $W_K \in \mathbb{R}^{d \times d_k}$
- Values: $V = XW_V$, where $W_V \in \mathbb{R}^{d \times d_v}$
Here, $d_k$ and $d_v$ are the key and value dimensions; typically $d_k = d_v = d/H$, where $H$ is the number of heads.
Attention Scores
We compute how much each token attends to others:
$A = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)$
- $QK^T \in \mathbb{R}^{N \times N}$ is a matrix of dot products measuring similarity between queries and keys
- $\sqrt{d_k}$ scales the scores to prevent large values that could destabilize the softmax
- $A_{ij}$ is the attention weight from token $i$ to token $j$
Output
The output of a single attention head is:
$\text{head} = AV$
So, $\text{head} \in \mathbb{R}^{N \times d_v}$.
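To make the shapes concrete, here is a minimal NumPy sketch of a single attention head. The sizes ($N$, $d$, $d_k$, $d_v$) and the random weights are purely illustrative, not tied to any real model.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)       # stabilize against overflow
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def attention_head(X, W_Q, W_K, W_V):
    """Single self-attention head: X is (N, d), returns (N, d_v)."""
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V           # queries, keys, values
    d_k = Q.shape[-1]
    A = softmax(Q @ K.T / np.sqrt(d_k), axis=-1)  # (N, N) row-stochastic attention
    return A @ V                                  # weighted sum of values

rng = np.random.default_rng(0)
N, d, d_k, d_v = 5, 16, 8, 8                      # toy sizes, illustrative only
X = rng.normal(size=(N, d))
head = attention_head(X,
                      rng.normal(size=(d, d_k)),
                      rng.normal(size=(d, d_k)),
                      rng.normal(size=(d, d_v)))
print(head.shape)                                 # (5, 8), i.e. (N, d_v) as derived above
```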
Multi-Head Attention
The complete multi-head attention operation uses $H$ heads, each with its own $W_Q^h, W_K^h, W_V^h$:
$\text{MultiHead}(X) = \text{Concat}(\text{head}_1, ..., \text{head}_H)W_O$
where $W_O \in \mathbb{R}^{Hd_v \times d}$ projects the concatenated outputs back to $d$-dimensional space.
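Continuing the sketch above (reusing `attention_head` and the toy shapes), multi-head attention simply concatenates the head outputs along the feature axis and projects back to $d$ dimensions with $W_O$:

```python
def multi_head_attention(X, heads, W_O):
    """heads: list of (W_Q, W_K, W_V) triples; W_O has shape (H * d_v, d)."""
    outs = [attention_head(X, W_Q, W_K, W_V) for W_Q, W_K, W_V in heads]
    return np.concatenate(outs, axis=-1) @ W_O    # concat features, project back to d

H = 2
heads = [(rng.normal(size=(d, d_k)),
          rng.normal(size=(d, d_k)),
          rng.normal(size=(d, d_v))) for _ in range(H)]
W_O = rng.normal(size=(H * d_v, d))
print(multi_head_attention(X, heads, W_O).shape)  # (5, 16): back to (N, d)
```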
Layer Structure
In a SAN layer, we add residual connections and an MLP:
$X' = X + \text{MultiHead}(X) \text{ (residual connection)}$
$X_{out} = X' + \text{MLP}(X') \text{ (another residual connection)}$
The MLP typically consists of two linear layers with a nonlinearity: $\text{MLP}(x) = W_2 \cdot \text{ReLU}(W_1 x + b_1) + b_2$
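And a full SAN layer, again continuing the same sketch. Layer normalization is omitted here to match the simplified equations above.

```python
def mlp(X, W_1, b_1, W_2, b_2):
    return np.maximum(X @ W_1 + b_1, 0) @ W_2 + b_2   # two linear maps with a ReLU

def san_layer(X, heads, W_O, mlp_params):
    X = X + multi_head_attention(X, heads, W_O)       # residual around attention
    return X + mlp(X, *mlp_params)                    # residual around the MLP

d_ff = 32                                             # hidden width, illustrative
mlp_params = (rng.normal(size=(d, d_ff)), np.zeros(d_ff),
              rng.normal(size=(d_ff, d)), np.zeros(d))
print(san_layer(X, heads, W_O, mlp_params).shape)     # (5, 16): shape is preserved
```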
Understanding Attention Paths
What is a Path?
In the context of transformers, a path refers to a specific route through the self-attention network. At each layer, you can choose one of the $H$ attention heads or bypass the layer via a skip connection.
Mathematically, we can define a path $P = (p_1, p_2, ..., p_L)$, where $p_l \in \{0, 1, ..., H\}$:
- $p_l = h$ (1 to $H$): Use head $h$ at layer $l$
- $p_l = 0$: Bypass the layer (skip connection)
Visualizing Paths
Let's interpret two example paths:
- Red Path $(2, \ldots, 1)$: This path starts with head 2 in the first layer and ends with head 1 in the last layer, with unspecified heads in between. This represents a diverse approach, exploring different aspects of the input at different layers.
- Blue Path $(H, 2, \ldots, 2)$: This path starts with the last head ($H$) and then consistently uses head 2 for all subsequent layers. This represents a more specialized strategy: beginning broadly but then maintaining consistency.
Why Do Paths Matter?
The total output $X_L$ of a SAN can be decomposed as a sum of contributions from all possible paths:
$X_L = \sum_{P \in \mathcal{P}} w_P X_L^P$
where $\mathcal{P}$ is the set of all $(H+1)^L$ possible paths ($H$ head choices plus the skip connection at each of the $L$ layers), $X_L^P$ is the output of path $P$, and $w_P$ weights its contribution.
This decomposition reveals:
- Interpretability: By analyzing individual paths, we can better understand how information flows through the network
- Efficiency: If only a few paths contribute significantly, we could potentially prune the others
- Theory: It provides a framework for understanding how depth functions in these networks
The challenge? With $H=8$ and $L=6$, there are $9^6 = 531{,}441$ paths! This complexity motivates the study of "weakly interdependent" paths.
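To see the decomposition in action, here is a small NumPy sketch. It makes two simplifying assumptions: the attention matrices are held fixed per layer, so each layer is the linear map $X \mapsto X + \sum_h P_h X W_h$ (with $W_h$ standing in for the combined value and output projections of head $h$), and the MLP and biases are omitted. Under those assumptions, summing over all $(H+1)^L$ paths reproduces the full forward pass exactly.

```python
import itertools
import numpy as np

rng = np.random.default_rng(0)
N, d, H, L = 4, 8, 2, 3                    # toy sizes: tokens, width, heads, layers

def random_stochastic(n):
    """Random row-stochastic matrix standing in for a softmax attention pattern."""
    M = rng.random((n, n))
    return M / M.sum(axis=1, keepdims=True)

# Fixed per-layer attention matrices P[l][h] and per-head projections W[l][h].
P = [[random_stochastic(N) for _ in range(H)] for _ in range(L)]
W = [[0.1 * rng.normal(size=(d, d)) for _ in range(H)] for _ in range(L)]

X0 = rng.normal(size=(N, d))

# Full forward pass: each layer is X -> X + sum_h P_h X W_h (skip + all heads).
X = X0
for l in range(L):
    X = X + sum(P[l][h] @ X @ W[l][h] for h in range(H))
full_output = X

# Path-by-path: choice 0 = skip (identity), choice h = use head h-1 at that layer.
path_sum = np.zeros_like(X0)
for path in itertools.product(range(H + 1), repeat=L):
    Z = X0
    for l, p in enumerate(path):
        Z = Z if p == 0 else P[l][p - 1] @ Z @ W[l][p - 1]
    path_sum += Z

print((H + 1) ** L)                        # 27 paths in this toy setting
print(np.allclose(full_output, path_sum))  # True: the decomposition is exact here
```

With $H=8$ and $L=6$, the same loop would have to visit 531,441 paths, which is exactly the combinatorial blow-up mentioned above.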
Weak Interdependence Between Paths
For paths to be "weakly interdependent," their interactions must be small or negligible. Several mathematical properties can characterize this:
1. Orthogonality of Path Contributions
If the outputs $X_L^P$ and $X_L^Q$ of two distinct paths $P$ and $Q$ are orthogonal:
$\langle X_L^P, X_L^Q \rangle = 0 \text{ for } P \neq Q$
This ensures each path contributes to the output in a way that is linearly independent of other paths.
2. Low Covariance Between Paths
If the covariance between path outputs is small:
$\text{Cov}(X_L^P, X_L^Q) = \mathbb{E}[(X_L^P - \mathbb{E}[X_L^P])(X_L^Q - \mathbb{E}[X_L^Q])^T] \approx 0$
This indicates that fluctuations in one path's contribution don't strongly correlate with fluctuations in another's.
3. Additive Contributions with Minimal Interaction
The output $X_L$ is well-approximated by the linear sum $\sum_P w_P X_L^P$, with higher-order interaction terms (e.g., products like $X_L^P X_L^Q$) being negligible.
4. Small Cross-Path Gradients
For parameters $\theta_P$ and $\theta_Q$ associated with paths $P$ and $Q$, the cross-gradient terms in the loss function $L$ are small:
$\left|\frac{\partial^2 L}{\partial \theta_P \partial \theta_Q}\right| \ll 1 \text{ for } P \neq Q$
This means optimizing one path has little impact on another during training.
5. Low Mutual Information
The mutual information between the outputs of paths $P$ and $Q$ is small:
$I(X_L^P; X_L^Q) \approx 0$
This formalizes the idea that the paths are nearly independent.
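Properties 1 and 2 can be probed numerically. Continuing the toy path decomposition above, the sketch below flattens each path's contribution and measures pairwise cosine similarities and covariances. This is only a rough empirical proxy: the covariance here is taken across the coordinates of a single input rather than over an input distribution, and a random toy model will not necessarily exhibit weak interdependence.

```python
# Collect each path's contribution separately (same toy model as above).
contributions = []
for path in itertools.product(range(H + 1), repeat=L):
    Z = X0
    for l, p in enumerate(path):
        Z = Z if p == 0 else P[l][p - 1] @ Z @ W[l][p - 1]
    contributions.append(Z.ravel())
C = np.stack(contributions)                # shape ((H+1)^L, N*d)

# Property 1: pairwise cosine similarity (0 would mean orthogonal contributions).
Cn = C / np.linalg.norm(C, axis=1, keepdims=True)
cosine = Cn @ Cn.T

# Property 2: covariance between path outputs, estimated across coordinates.
cov = np.cov(C)

off_diag = ~np.eye(len(C), dtype=bool)
print("mean |cosine| between distinct paths:    ", np.abs(cosine[off_diag]).mean())
print("mean |covariance| between distinct paths:", np.abs(cov[off_diag]).mean())
```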
Gateway Choices: Rethinking Information Flow
The concept of "gateway choices" changes our understanding of how information flows through transformers. Rather than treating all information equally, certain components act as selective gateways that determine which information is emphasized or propagated.
This has several implications:
- Selective Routing: Transformers dynamically adjust their focus based on the input, prioritizing task-critical data.
- Dynamic Processing: The flow of information adapts to the context, with different aspects prioritized at different stages.
- Hierarchical Refinement: Early gateways might extract broad features, while later ones refine these into more specific representations.
Mathematically, gateway choices already exist in transformers:
- Attention Weights: The softmax determines how much information flows between tokens
- Residual Connections: Allow some information to bypass transformation
Rank Collapse: A Key Constraint
Perhaps the most intriguing concept is "rank collapse" - the phenomenon where representations or attention matrices become low-rank as they pass through the network layers.
What is Rank?
In linear algebra, the rank of a matrix is the number of linearly independent rows or columns. It reflects the amount of unique information the matrix contains.
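As a quick illustration (toy numbers only), a matrix with a row that is a linear combination of the others has rank lower than its size:

```python
import numpy as np

M = np.array([[1., 2., 3.],
              [2., 4., 6.],      # exactly 2x the first row: no new information
              [0., 1., 1.]])
print(np.linalg.matrix_rank(M))  # 2, not 3
```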
The Rank Collapse Claim
The claim that attention "doubly exponentially loses rank" means that, without skip connections and MLPs, stacking attention layers drives the representations toward a rank-1 matrix, with the distance to that matrix shrinking at a doubly exponential rate in depth.
This is potentially caused by several properties of the attention operation:
- Softmax Saturation: If the dot products between queries and keys are large, the softmax saturates and many rows concentrate on the same few keys, which can reduce the effective rank of the attention matrix.
- Correlation in Queries and Keys: Since both are derived from the same input, their dot products can be highly correlated across tokens, producing redundant attention rows.
- Layer Composition: Low rank at one layer propagates to the next, creating a compounding effect across depth.
- Scaling Effects: The $\sqrt{d_k}$ scaling factor controls how quickly the softmax saturates, which indirectly shapes rank dynamics.
The Self-Attention Head Output
Looking at the formula for self-attention head output:
$\text{SA}_{(h)}(X) = P_{(h)} \cdot X \cdot W_{(V,h)} + \mathbf{1}b_{(V,h)}^T$
This is the same per-head computation as before ($AV$, written here in the paper's notation with the value bias $b_{(V,h)}$ made explicit). The properties of $P_{(h)}$ (the attention matrix produced by the softmax) are crucial:
- Row-Stochastic: Each row sums to 1, imposing a constraint
- Non-Negative: All entries are between 0 and 1
- Variable Sparsity: Depending on input patterns, it can range from uniform to nearly one-hot
If $P_{(h)}$ becomes low-rank (e.g., due to most tokens attending to the same keys), it constrains the rank of the output, leading to rank collapse.
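To get a feel for this, the sketch below is a toy experiment (not the paper's exact setup or norm): it stacks pure single-head attention layers with no skip connections and no MLP, and tracks the relative Frobenius distance from the representation to the nearest matrix of the form $\mathbf{1}\bar{x}^T$ (all rows equal to the column-wise mean $\bar{x}$), a rough proxy for closeness to rank 1.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d, L = 8, 16, 10                         # tokens, width, depth (toy sizes)

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    return np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)

def rank1_residual(X):
    """Relative distance from X to the rank-1 matrix whose rows all equal the mean row."""
    return np.linalg.norm(X - X.mean(axis=0, keepdims=True)) / np.linalg.norm(X)

X = rng.normal(size=(N, d))
print("layer 0:", rank1_residual(X))
for l in range(1, L + 1):
    W_Q, W_K, W_V = (rng.normal(size=(d, d)) / np.sqrt(d) for _ in range(3))
    A = softmax((X @ W_Q) @ (X @ W_K).T / np.sqrt(d))
    X = A @ X @ W_V                          # pure attention: no skip connection, no MLP
    print(f"layer {l}:", rank1_residual(X))  # the residual shrinks rapidly with depth
```

In the paper's analysis, skip connections and the MLP are precisely what counteract this collapse, which is why they matter so much in practice.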
Coming Up in Part 2
In the next installment, we'll explore:
- Dimensional Consistency in Transformers
- Attention Mechanism Reformulation
- The Mathematics of Rank Collapse
- Concatenation vs. Summation in Multi-Head Attention
- Scaling Factors and Bias Terms
Stay tuned for more deep dives into transformer theory in the upcoming parts of this series!