Attention Mechanisms Explained: How Neural Networks Focus
The Problem Attention Solves
Before attention, sequence-to-sequence models encoded the entire input into a single hidden state vector, typically 256 to 1024 numbers. This vector had to contain everything the decoder needed to generate the output. For short inputs this worked adequately, but for long sequences the fixed-size bottleneck became severe. A single vector cannot faithfully represent a 500-word paragraph with all its nuances, specific details, and relationships between parts. Translation quality degraded sharply as sentence length increased, with error rates climbing steeply beyond 20 to 30 words.
Attention removes this bottleneck by allowing the decoder to access every position in the encoder's output directly. When generating each output word, the decoder computes a weighted combination of all encoder hidden states, where the weights reflect how relevant each input position is to the current output position. When translating the English word "dog" into French, for example, the decoder can attend strongly to "dog" in the input and weakly to everything else, so the information it needs does not have to survive compression into a single vector.
How Attention Computes Relevance
The core computation in attention is a compatibility function that scores how well two representations match. Given a query vector (representing "what am I looking for?") and a set of key vectors (representing "what does each position contain?"), the compatibility score between the query and each key determines the attention weights. Higher scores mean higher relevance. The scores are normalized through softmax so they sum to 1, turning them into a probability distribution over the input positions.
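As a minimal sketch of the normalization step, the snippet below applies softmax to a handful of made-up scores for a single query. The score values are hypothetical and chosen only for illustration.

```python
import math

def softmax(scores):
    """Turn raw compatibility scores into weights that sum to 1."""
    # Subtracting the max is a common numerical-stability trick;
    # it does not change the result.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical scores for one query against four input positions.
scores = [2.0, 0.5, -1.0, 0.5]
weights = softmax(scores)
```

The position with the highest score receives the largest weight, and positions with equal scores receive equal weights.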
The two most common compatibility functions are additive attention and dot-product attention. Additive attention, used in the original 2014 paper by Bahdanau et al., computes the score through a small neural network: score(q, k) = v^T * tanh(W_q * q + W_k * k), where W_q, W_k, and v are learned parameters. Scaled dot-product attention, used in transformers, simply computes the dot product and divides by a constant: score(q, k) = (q · k) / sqrt(d), where d is the dimension of the vectors. The division by sqrt(d) keeps scores from growing with the dimension, which would otherwise push softmax into a regime with very small gradients. Dot-product attention is faster because the scores for all query-key pairs can be computed with a single matrix multiplication, rather than a separate neural network evaluation for each pair.
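The two scoring functions can be sketched in a few lines of plain Python. The vectors and parameter values below are arbitrary placeholders standing in for learned weights, and the helper names are illustrative rather than taken from any library.

```python
import math

def dot_product_score(q, k):
    """Scaled dot-product score: (q . k) / sqrt(d)."""
    d = len(q)
    return sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d)

def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def additive_score(q, k, W_q, W_k, v):
    """Additive score: v^T tanh(W_q q + W_k k)."""
    hidden = [math.tanh(a + b) for a, b in zip(matvec(W_q, q), matvec(W_k, k))]
    return sum(vi * hi for vi, hi in zip(v, hidden))

# Toy 2-dimensional example; W_q, W_k, and v are placeholder values,
# not trained parameters.
q = [1.0, 0.0]
k = [0.5, 0.5]
W_q = [[1.0, 0.0], [0.0, 1.0]]
W_k = [[1.0, 0.0], [0.0, 1.0]]
v = [1.0, 1.0]

dot_score = dot_product_score(q, k)             # 0.5 / sqrt(2)
add_score = additive_score(q, k, W_q, W_k, v)   # tanh(1.5) + tanh(0.5)
```

The additive version needs a matrix-vector product and a tanh for every query-key pair, while the dot-product version reduces to one multiply-accumulate pass per pair.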
Once the attention weights are computed, the output is a weighted sum of value vectors. Each input position contributes a value vector, and the final output for a given query is the sum of all value vectors, each multiplied by its attention weight. If position 3 has an attention weight of 0.7 and position 8 has a weight of 0.2, the output is dominated by position 3's value with a smaller contribution from position 8.
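The weighted sum can be illustrated with three positions instead of many. The value vectors below are made up, and the weights 0.7, 0.2, and 0.1 mirror the kind of distribution described above.

```python
def weighted_sum(weights, values):
    """Combine value vectors using one attention weight per position."""
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]

# Three positions with 2-dimensional value vectors (illustrative numbers).
values = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
weights = [0.7, 0.2, 0.1]

output = weighted_sum(weights, values)  # approximately [0.8, 0.3]
```

The first component is dominated by the first position's value because that position carries most of the weight.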
Self-Attention vs Cross-Attention
Self-attention (also called intra-attention) computes attention within a single sequence. Each position in the sequence attends to every position in the same sequence, including itself. The queries, keys, and values all come from the same set of input vectors. Self-attention is what allows a transformer to understand that "it" in "The company said it would raise prices" refers to "the company," by computing a high attention weight between the two positions.
Cross-attention computes attention between two different sequences. One sequence provides the queries and the other provides the keys and values. In machine translation, the decoder's representations serve as queries, and the encoder's representations provide keys and values. This allows each output position to look at the input and select relevant information. In multimodal models that combine text and images, cross-attention might let text tokens attend to image features, enabling the model to ground language in visual content.
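The difference between the two forms comes down to where the queries, keys, and values come from. The sketch below uses one attention function for both cases. It is a simplification that omits the learned query, key, and value projections a real model would apply, and the variable names and vectors are illustrative.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attention(queries, keys, values):
    """Scaled dot-product attention; returns one output vector per query."""
    d = len(keys[0])
    outputs = []
    for q in queries:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in keys]
        weights = softmax(scores)
        outputs.append([sum(w * v[i] for w, v in zip(weights, values))
                        for i in range(len(values[0]))])
    return outputs

encoder_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # 3 input positions
decoder_states = [[0.5, 0.5], [2.0, 0.0]]               # 2 output positions

# Self-attention: queries, keys, and values all come from the same sequence.
self_out = attention(encoder_states, encoder_states, encoder_states)

# Cross-attention: queries from the decoder, keys and values from the encoder.
cross_out = attention(decoder_states, encoder_states, encoder_states)
```

The output always has one vector per query, so self-attention here yields three vectors and cross-attention yields two.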
The transformer encoder uses only self-attention: each token's representation is enriched by context from all other tokens in the input. The transformer decoder uses both: causal self-attention (where each position can attend only to itself and earlier positions, preventing information leakage from future tokens) and cross-attention to the encoder's output. Decoder-only models like GPT use only causal self-attention, because there is no separate encoder.
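Causal masking can be sketched as a softmax that ignores every key position after the query's own position. The example assumes, as is standard practice, that a position may attend to itself as well as to earlier positions. The function name is illustrative.

```python
import math

def causal_attention_weights(scores):
    """Row i of `scores` holds query i's scores against every key.
    Keys at positions j > i are excluded before the softmax."""
    n = len(scores)
    weights = []
    for i in range(n):
        # Keep only keys at positions 0..i; later positions get weight 0.
        visible = scores[i][: i + 1]
        m = max(visible)
        exps = [math.exp(s - m) for s in visible]
        total = sum(exps)
        weights.append([e / total for e in exps] + [0.0] * (n - i - 1))
    return weights

# Uniform scores make the effect of the mask easy to see.
scores = [[0.0] * 4 for _ in range(4)]
weights = causal_attention_weights(scores)
```

With uniform scores, the first position places all its weight on itself, the second splits its weight evenly over two positions, and so on, producing a lower-triangular weight matrix.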
Multi-Head Attention in Detail
A single attention head captures one type of relationship between positions. Multi-head attention runs multiple attention computations in parallel, each with independent learned projections for queries, keys, and values. This allows the model to simultaneously attend to information from different representation subspaces at different positions.
Concretely, if the model dimension is 512 and there are 8 heads, each head operates in a 64-dimensional space. The input is projected to 8 separate sets of 64-dimensional queries, keys, and values. Each head computes attention independently, producing 8 separate 64-dimensional output vectors per position. These are concatenated back to a 512-dimensional vector and passed through a final linear projection.
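The dimension bookkeeping can be checked with a short sketch. It assumes a common implementation choice in which one full-width projection is computed and then sliced into per-head chunks, which is equivalent to using 8 separate 512-to-64 projections. The projected vector here is a placeholder, not the output of a trained layer.

```python
d_model = 512
num_heads = 8
head_dim = d_model // num_heads   # 64

def split_heads(vector, num_heads):
    """Slice one d_model-sized vector into num_heads contiguous chunks."""
    size = len(vector) // num_heads
    return [vector[h * size:(h + 1) * size] for h in range(num_heads)]

def concat_heads(chunks):
    """Concatenate per-head vectors back into a single vector."""
    return [x for chunk in chunks for x in chunk]

# Stand-in for one position's projected query (placeholder values).
projected = [float(i) for i in range(d_model)]

heads = split_heads(projected, num_heads)   # 8 chunks of 64 numbers each
merged = concat_heads(heads)                # back to 512 numbers
```

Because the heads partition the model dimension, multi-head attention costs roughly the same as a single head operating at full width.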
Analysis of trained transformer heads reveals consistent specialization patterns. Some heads learn to attend to the immediately preceding or following position, capturing local context. Others learn syntactic relationships like subject-verb agreement across distances. Some heads specialize in coreference resolution, connecting pronouns to their antecedents. In multilingual models, certain heads appear to encode universal linguistic structures that are shared across languages. This specialization emerges entirely from training, with no explicit supervision telling heads what to focus on.
Variants and Improvements
Sparse Attention
Standard self-attention is quadratic in sequence length: for N tokens, the attention matrix has N^2 entries. At a sequence length of 100,000 tokens, a single such matrix has 10^10 entries and would require roughly 20 gigabytes of memory in half precision (2 bytes per entry). Sparse attention patterns reduce this by restricting which positions can attend to which. Local attention allows each position to attend only to a fixed window of nearby positions. Strided attention attends to every k-th position. Combining local windows with a small number of longer-range connections, such as the global tokens used in Longformer and BigBird, achieves linear or near-linear complexity while retaining the ability to model long-range dependencies.
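The scaling argument is easy to verify with arithmetic. The sketch below counts attention-matrix entries for full attention and for a local window, assuming 2 bytes per entry (half precision) and a hypothetical window of 256 positions on each side. It counts a single matrix, so the totals apply per head and per layer.

```python
def full_attention_entries(n):
    """Every position attends to every position."""
    return n * n

def local_attention_entries(n, window):
    """Each position attends to itself and up to `window` neighbors per side."""
    return sum(min(n - 1, i + window) - max(0, i - window) + 1 for i in range(n))

n = 100_000
bytes_per_entry = 2  # half precision

full_gb = full_attention_entries(n) * bytes_per_entry / 1e9
local_gb = local_attention_entries(n, window=256) * bytes_per_entry / 1e9
```

Here `full_gb` comes to 20 gigabytes while `local_gb` is about 0.1 gigabytes, because the local pattern grows with N times the window size rather than with N squared.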
Linear Attention
Linear attention approximations replace the softmax over the full N x N attention matrix with a factored computation that scales linearly with sequence length. The key insight is that if you can decompose the attention kernel into a product of feature maps, the computation can be reordered to avoid materializing the full attention matrix. Performers, Random Feature Attention, and linear transformer variants achieve this with varying degrees of approximation quality. The tradeoff is typically some degradation in modeling quality compared to full quadratic attention.
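The reordering can be demonstrated on a toy example. The sketch below assumes the feature map elu(x) + 1 used in linear transformer variants, and shows that summarizing keys and values once gives the same output as building every pairwise similarity. Note that both functions compute kernelized attention, not softmax attention. The function names and data are illustrative.

```python
import math

def phi(x):
    """Feature map elu(x) + 1, which keeps every feature positive."""
    return [xi + 1.0 if xi > 0 else math.exp(xi) for xi in x]

def quadratic_form(Q, K, V):
    """Reference: builds all N x N similarities explicitly."""
    out = []
    for q in Q:
        sims = [sum(a * b for a, b in zip(phi(q), phi(k))) for k in K]
        total = sum(sims)
        out.append([sum(s * v[j] for s, v in zip(sims, V)) / total
                    for j in range(len(V[0]))])
    return out

def linear_form(Q, K, V):
    """Same result, but keys and values are summarized once, in O(N)."""
    d, dv = len(K[0]), len(V[0])
    kv = [[0.0] * dv for _ in range(d)]   # sum over positions of phi(k) v^T
    k_sum = [0.0] * d                     # sum over positions of phi(k)
    for k, v in zip(K, V):
        fk = phi(k)
        for a in range(d):
            k_sum[a] += fk[a]
            for j in range(dv):
                kv[a][j] += fk[a] * v[j]
    out = []
    for q in Q:
        fq = phi(q)
        denom = sum(fq[a] * k_sum[a] for a in range(d))
        out.append([sum(fq[a] * kv[a][j] for a in range(d)) / denom
                    for j in range(dv)])
    return out

Q = [[0.5, -1.0], [1.0, 2.0]]
K = [[1.0, 0.0], [-0.5, 0.3], [0.2, 0.2]]
V = [[1.0, 2.0], [0.0, 1.0], [3.0, -1.0]]

slow = quadratic_form(Q, K, V)
fast = linear_form(Q, K, V)
```

The two results agree up to floating-point rounding, but `linear_form` never materializes an N x N matrix: its intermediate state has a fixed size that depends only on the feature and value dimensions.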
FlashAttention
FlashAttention does not change the mathematical computation of attention but dramatically optimizes its implementation on GPU hardware. By restructuring the computation to minimize memory reads and writes between GPU high-bandwidth memory and on-chip SRAM, FlashAttention runs roughly 2 to 4 times faster than standard attention implementations, and its memory usage grows linearly rather than quadratically with sequence length because the full attention matrix is never stored. It computes the exact same result as standard attention but exploits the memory hierarchy of modern GPUs. FlashAttention has become the default attention implementation in most training and inference frameworks.
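One ingredient that makes an exact blockwise computation possible is a running ("online") softmax. The sketch below illustrates only that idea for a single query in plain Python. It is not the fused GPU kernel, and the function names, block size, and data are illustrative assumptions.

```python
import math

def score(q, k):
    """Scaled dot-product score."""
    return sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q))

def standard_attention(q, keys, values):
    """Reference: computes all scores first, then one softmax."""
    scores = [score(q, k) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [sum(e * v[j] for e, v in zip(exps, values)) / total
            for j in range(len(values[0]))]

def blockwise_attention(q, keys, values, block_size=2):
    """Visit keys and values one block at a time, never holding all scores."""
    running_max = float("-inf")
    running_sum = 0.0
    acc = [0.0] * len(values[0])
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [score(q, k) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale everything accumulated under the previous maximum.
        scale = math.exp(running_max - new_max)
        exps = [math.exp(s - new_max) for s in scores]
        running_sum = running_sum * scale + sum(exps)
        acc = [a * scale + sum(e * v[j] for e, v in zip(exps, v_blk))
               for j, a in enumerate(acc)]
        running_max = new_max
    return [a / running_sum for a in acc]

q = [1.0, -0.5]
keys = [[0.2, 0.1], [1.5, -1.0], [-0.3, 0.8], [0.9, 0.9], [2.0, 0.0]]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 3.0], [0.5, 0.5]]

reference = standard_attention(q, keys, values)
tiled = blockwise_attention(q, keys, values, block_size=2)
```

The blockwise version only ever holds one block of scores plus a few running statistics, yet it matches the reference up to floating-point rounding.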
Grouped Query Attention
Standard multi-head attention keeps separate keys and values for each head. During inference, the keys and values of all past tokens must be cached, and for long sequences this cache dominates memory usage. Grouped query attention (GQA) shares key and value projections across groups of query heads. Multi-query attention (MQA) takes this to the extreme by sharing a single set of keys and values across all query heads. GQA, used in the largest LLaMA 2 model and many subsequent models, shrinks the key-value cache by the ratio of query heads to key-value heads, commonly 4 to 8 times, with minimal quality loss.
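The memory savings follow directly from counting cached key and value entries. The sketch below uses a hypothetical configuration (32 layers, 32 query heads of dimension 128, a 4096-token context, half precision) that is chosen for round numbers and does not describe any specific released model.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_value=2):
    """Bytes needed to cache keys and values for one sequence."""
    # Factor of 2: one tensor for keys, one for values.
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_value

# Hypothetical model configuration, assumed for illustration only.
config = dict(num_layers=32, head_dim=128, seq_len=4096)

mha = kv_cache_bytes(num_kv_heads=32, **config)  # one KV head per query head
gqa = kv_cache_bytes(num_kv_heads=8, **config)   # 4 query heads share each KV head
mqa = kv_cache_bytes(num_kv_heads=1, **config)   # all query heads share one KV head
```

In this configuration the cache is 2 GiB for standard multi-head attention, 512 MiB with 8 key-value heads, and 64 MiB with a single shared key-value head.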
Attention Beyond Language
Attention has expanded far beyond its origins in machine translation. Vision Transformers (ViTs) split images into patches and apply self-attention across patches, allowing each patch to incorporate information from spatially distant parts of the image. Audio transformers process spectrograms or raw waveforms with attention. AlphaFold 2 uses attention to model relationships between amino acid residues in a protein sequence, enabling accurate 3D structure prediction. Graph attention networks use attention to weight the contributions of neighboring nodes in a graph.
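The patch-splitting step used by Vision Transformers can be sketched on a tiny grid. The example assumes a single-channel image whose height and width are exact multiples of the patch size. The function name is illustrative.

```python
def split_into_patches(image, patch_size):
    """Cut a 2D grid (list of rows) into flattened, non-overlapping patches."""
    height, width = len(image), len(image[0])
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patch = [image[r][c]
                     for r in range(top, top + patch_size)
                     for c in range(left, left + patch_size)]
            patches.append(patch)
    return patches

# A 4x4 single-channel "image" whose pixel values are just their index.
image = [[r * 4 + c for c in range(4)] for r in range(4)]
patches = split_into_patches(image, patch_size=2)
```

Each flattened patch then plays the role that a token plays in text: it is projected to a vector and attends to every other patch.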
The universality of attention as a mechanism for relating elements in a set is what makes it so broadly applicable. Any problem that can be framed as "given a collection of elements, compute an output for each element based on its relationships to all other elements" can benefit from attention. This description fits not just sequences but also images (collections of patches), molecules (collections of atoms), social networks (collections of users), and point clouds (collections of 3D points).
Attention mechanisms compute dynamic, learned relevance weights between elements in a sequence, replacing the fixed-size bottleneck of earlier architectures. Self-attention within sequences and cross-attention between sequences are the two fundamental forms, and multi-head attention allows simultaneous capture of multiple relationship types. Efficient variants address the quadratic scaling challenge for long sequences.