The attention mechanism is the basic building block of almost all so-called foundation models and their derivatives, including the infamous, overhyped LLMs everyone talks about. The mechanism, introduced in “Neural Machine Translation by Jointly Learning to Align and Translate” primarily to address sequence-to-sequence neural machine translation and later adopted in “Attention Is All You Need” as the foundation of transformer models, has a significant downside: in its basic form it is computed in O(n²), where n is the size of the context. Why does this matter? Nowadays we want to use as large a context as possible - ranging from 4K to 1M (declared), or at least 128K (effective) according to the RULER benchmark - and quadratic complexity makes models consume vast amounts of processing time and memory, putting SoTA models out of reach for “mortals”, even with massive quantization.

Another troubling aspect of attention is the attention sink - a phenomenon described in “Efficient streaming language models with attention sinks”, “Vision transformers need registers”, and “Massive activations in large language models”, where specific tokens receive the vast majority of the attention scores based only on their position and not their “content”. Commonly the first token (e.g., [CLS] or other special tokens) draws an outsized share of attention from the other tokens. This concentration of attention acts as a bottleneck, limiting the model’s ability to effectively incorporate other relevant information in the sequence and functioning like an attention “black hole” that absorbs focus without proportionate benefit.
Before diving into the contribution of this publication, let’s do a quick recap of the attention mechanism, which will be crucial for understanding the proposed changes. The first stage of Multi-Head Softmax Attention is the calculation of the Query (Q), Key (K), and Value (V) projections.
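In standard transformer notation, with X denoting the (pre-normalized) input hidden states, these are plain linear maps:

$$Q = X W_Q, \qquad K = X W_K, \qquad V = X W_V$$

where $W_Q$, $W_K$, and $W_V$ are learned weight matrices.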
Next, attention is computed as a weighted sum of the values, where the weights are softmax-normalized similarity scores between queries and keys.
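This is the familiar scaled dot-product attention formula:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$

where $d_k$ is the per-head key dimension used for scaling.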
In multi-head attention, this computation is performed in parallel for each head, with each head having its own projection matrices:
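$$\mathrm{head}_i = \mathrm{Attention}\!\left(X W_Q^{i},\; X W_K^{i},\; X W_V^{i}\right)$$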
Finally, the concatenated head outputs are passed through a final output projection layer:
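In equation form:

$$\mathrm{MultiHead}(X) = \mathrm{Concat}\!\left(\mathrm{head}_1, \ldots, \mathrm{head}_h\right) W_O$$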
Gating mechanisms are known from early sequence-processing models such as LSTMs (“Long Short-Term Memory”), GRUs (“Gate-Variants of Gated Recurrent Unit (GRU) Neural Networks”), and Highway Networks, where they played a crucial role in regulating the flow of information. The basic gating mechanism used in this paper is a multiplication between the gate input (Y) and a gating score, computed as a sigmoid activation over a separate linear projection of the pre-normalized hidden states (X). The gating mechanism also increases expressiveness by introducing an additional non-linearity between the Value (W_V) and Output (W_O) projection layers, which in “standard” attention are equivalent to a single low-rank linear projection. An additional factor contributing to the effectiveness of this mechanism is the input-dependent sparsity induced by the gating scores.
source: "Neural Machine Translation by Jointly Learning to Align and Translate"
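In equation form (with $W_\theta$ denoting the gate’s own learned projection and $\odot$ element-wise multiplication - my reading of the setup described above):

$$Y' = Y \odot \sigma\!\left(X W_\theta\right)$$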
The gating mechanism can be applied to attention in several variants, which differ along the following aspects:
Position:
G2-G4 - gate applied to the input projections
G1 - gate applied before the output projection
G5 - gate applied after the output projection
Granularity:
Headwise - a single gating score regulates the entire output of an attention head
Element-wise - a separate gating score regulates each input dimension
Sharing:
Head-specific - each attention head has its own gating weights/scores
Shared - gating weights and scores are shared across all heads
Results were obtained by the authors across over 30 variants of 15B MoE and 1.7B dense models, trained on up to 3.5T tokens.
Best Gating configuration
Experimental results show that the best performance is achieved with element-wise, head-specific, multiplicative gating with a sigmoid activation, placed before the output projection (position G1).
source: "Neural Machine Translation by Jointly Learning to Align and Translate"
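To make this concrete, here is a minimal PyTorch sketch of that configuration (my own illustration under standard assumptions, not the authors’ implementation): the gate is a sigmoid over a separate linear projection of the same pre-normalized hidden states, applied element-wise and per head to the SDPA output, right before the output projection.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedMultiHeadAttention(nn.Module):
    """Multi-head attention with element-wise, head-specific G1 gating."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)
        # Gate projection: one score per head *and* per head dimension
        # (element-wise, head-specific), computed from the same
        # pre-normalized hidden states that feed Q/K/V.
        self.gate_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, d_model) -- pre-normalized hidden states
        b, t, _ = x.shape

        def split(h):  # (b, t, d_model) -> (b, n_heads, t, d_head)
            return h.view(b, t, self.n_heads, self.d_head).transpose(1, 2)

        q, k, v = split(self.q_proj(x)), split(self.k_proj(x)), split(self.v_proj(x))

        # Standard scaled dot-product attention (causal mask for an LM).
        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        # G1 gating: sigmoid(X W_theta), multiplied element-wise into each
        # head's output before the heads are concatenated and projected.
        gate = torch.sigmoid(split(self.gate_proj(x)))
        attn_out = attn_out * gate

        # Merge heads and apply the output projection W_O.
        attn_out = attn_out.transpose(1, 2).reshape(b, t, -1)
        return self.o_proj(attn_out)
```

In this sketch the only extra parameters are a single d_model × d_model gate projection per attention layer.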
Attention Sink
The gating mechanism was shown to significantly reduce the attention-sink problem: the baseline model assigned, on average, 46.7% of attention scores (globally) to the first token.
source: "Neural Machine Translation by Jointly Learning to Align and Translate"
An in-depth analysis of the model shows that in Layer 21, where 83% of attention scores fell on the first token, the gating mechanism reduced that share to 4%, and in Layer 23 from 41% to 1%.
source: "Neural Machine Translation by Jointly Learning to Align and Translate"
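For context, the sink statistic itself is easy to measure - a quick sketch (my own, assuming softmax-normalized attention maps are available):

```python
import torch

def first_token_attention_share(attn: torch.Tensor) -> float:
    """attn: (batch, heads, query_len, key_len) softmax-normalized attention
    weights. Returns the average attention mass assigned to the first key
    position, i.e. the share absorbed by the would-be attention sink."""
    return attn[..., 0].mean().item()
```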
Training loss and stability
The gating mechanism also results in improved training stability and a lower final loss, as shown in the chart comparing training against a baseline model with the same set of hyperparameters (dense models with 1.7B parameters, a dataset of 3.5T tokens).
source: "Neural Machine Translation by Jointly Learning to Align and Translate"
Gating is a very simple mechanism, yet it delivers significant improvements: it enhances model expressiveness through added non-linearity, introduces input-dependent sparsity, and almost completely removes the attention-sink problem. Perhaps most importantly, it supports better generalization over long sequences (greater effective context length) and enables context-length extension without retraining the whole model - adding and training only the gating mechanism leads to a better distribution of attention across all tokens.