
Transformer Attention, a short introduction

Here, we present a short introduction to the attention mechanism and the transformer model that have taken over NLP in recent years. We assume a working understanding of the fundamentals of machine learning, including what recurrent neural nets (RNNs) are and how they work, because the transformer architecture essentially seeks to improve on them for the purposes of language translation. We note that the attention mechanism is much more general than the multi-head attention blocks used in transformers, and can indeed be applied to tasks like vision, as we emphasize in reviewing more recent work (on visual transformers) at the end of this essay.

Vanilla RNNs, LSTMs, and motivating the need for attention

First observe that historically (pre-2010s), the canonical approach to machine translation was to use a combination of formal linguistic knowledge and statistical heuristics to produce workable translations. Then, around 2010, as compute grew and neural networks began to gain traction, they became canonical for NLP tasks, the most common of which is language translation. That is, sequence-to-sequence translation of sentences and phrases in one language into another. Much as ConvNets take the main ideas from vanilla feedforward nets and add the inductive biases of spatial locality and translational invariance, RNNs add the bias of sequential operation: we humans read phrases one word at a time. Likewise, recall that RNNs are composed of many layers, each of which roughly takes the form \( h_t = f(W_1 x_{t-1} + W_2 h_{t-1}) \), where \( x_{t-1} \) represents the \((t-1)\)-th word in the input stream, and \( h_{t-1} \) represents the previous hidden state (some latent vector representation of the words encountered so far). Importantly, the weight matrices \( W_1, W_2 \) are the same throughout time.
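
To make this recurrence concrete, here is a minimal sketch, in plain numpy, of a single RNN layer unrolled over a short "sentence" of random word vectors; the names (rnn_forward, W_in, W_rec) are illustrative only, not from any library. Notice that the same two weight matrices are reused at every step.

```python
import numpy as np

# Minimal sketch of one vanilla RNN layer unrolled over a sentence.
# Illustrative names only; follows h_t = f(W_1 x_{t-1} + W_2 h_{t-1}) with f = tanh.

def rnn_forward(xs, W_in, W_rec, h0):
    """xs: list of input word vectors; returns the hidden state after each step."""
    h = h0
    hs = []
    for x in xs:                              # one step per word: the sequential bias
        h = np.tanh(W_in @ x + W_rec @ h)     # same W_in, W_rec at every time step
        hs.append(h)
    return hs

d_in, d_hid, T = 8, 16, 5
rng = np.random.default_rng(0)
xs = [rng.normal(size=d_in) for _ in range(T)]   # stand-in for 5 embedded words
W_in = 0.1 * rng.normal(size=(d_hid, d_in))
W_rec = 0.1 * rng.normal(size=(d_hid, d_hid))
hs = rnn_forward(xs, W_in, W_rec, np.zeros(d_hid))
print(len(hs), hs[-1].shape)                     # 5 hidden states of dimension 16
```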

These standard RNNs were the state of the art (SOTA) for a while, but suffered from the vanishing (and exploding) gradient problem. Because each layer represented one word token, i.e. one time step, backpropagating through an RNN on a long sentence meant applying the chain rule many, many times when updating the weights. As we nested many gradients, the product roughly took the form \( x^L \), where \( L \) is the number of time steps, and so the updates we made to the weights were either far too large, for \( x > 1 \), or almost non-existent, for \( x < 1 \).
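
As a toy illustration of why this happens (an assumption-laden simplification, not taken from any paper): for a purely linear recurrence \( h_t = W h_{t-1} \), the Jacobian of \( h_T \) with respect to \( h_0 \) is just the matrix power \( W^T \), so its size scales roughly like the spectral radius of \( W \) raised to the power \( T \).

```python
import numpy as np

# Toy illustration of vanishing/exploding gradients for a linear recurrence
# h_t = W h_{t-1}: the Jacobian of h_T w.r.t. h_0 is the matrix power W^T.

rng = np.random.default_rng(0)
W = rng.normal(size=(16, 16))
W /= np.max(np.abs(np.linalg.eigvals(W)))        # rescale to spectral radius 1

for rho in (0.9, 1.1):
    J = np.linalg.matrix_power(rho * W, 100)     # Jacobian across 100 time steps
    print(f"spectral radius {rho}: ||dh_T/dh_0|| = {np.linalg.norm(J):.2e}")
# radius 0.9 -> the norm shrinks toward zero (vanishing updates);
# radius 1.1 -> it blows up (exploding updates).
```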

To solve this problem of information not persisting through time, people added a sort of "permanent memory" to RNNs, somewhat like the RAM of a computer. These were called "cell memories," and were just a new form of state that persisted in the RNN over time. The motivation is to remember things from long ago in the sentence just as well as information from more recent tokens. Of course, we want to be able to modify this new memory, so they introduced "gates," which are really just (like everything else under these abstract words we're using) mathematical functions that update the cell memory. What updates would we like to perform on this intended long-term memory? The natural ones are "forgetting," inputting information (from the hidden state into the cell), and outputting information (stored in the cell a while ago, into the current state \( h_t \)). These three gates are exactly what we have in this new architecture.

Formally, for instance, the forget gate takes the form \( f^{(t)} = \sigma(W_f h^{(t-1)} + U_f x^{(t-1)}) \), and then we can write the cell state over time as \( c^{(t)} = c^{(t-1)} \circ f^{(t)} \), where \( \circ \) denotes the Hadamard (elementwise) product. This makes sense, as it means the cell memory at the next time step is what it was just now, with the entries we intend to forget scaled toward zero. Similarly, we can write equations for input to the cell memory (from the hidden state) and output (from the cell memory back into the hidden state), which complete the update equations. This new architecture, with a cell and gates obeying these equations, is called the long short-term memory (LSTM) architecture, and it is able to remember longer-term dependencies by storing them in this cell memory.
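
To see all three gates and the cell update in one place, here is a minimal sketch of a single LSTM step in plain numpy, following the standard LSTM equations (bias terms omitted, names illustrative). The line computing c is the "keep some of the old, write in some of the new" update described above.

```python
import numpy as np

# Minimal sketch of one LSTM step (standard equations, biases omitted).
# The cell update is elementwise and additive: c_t = f_t * c_{t-1} + i_t * c~_t.

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, W, U):
    """W, U: dicts of weight matrices for the three gates and the candidate cell."""
    f = sigmoid(W["f"] @ h_prev + U["f"] @ x)        # forget gate
    i = sigmoid(W["i"] @ h_prev + U["i"] @ x)        # input gate
    o = sigmoid(W["o"] @ h_prev + U["o"] @ x)        # output gate
    c_tilde = np.tanh(W["c"] @ h_prev + U["c"] @ x)  # candidate new memory
    c = f * c_prev + i * c_tilde                     # keep some old, write some new
    h = o * np.tanh(c)                               # expose part of the cell as h_t
    return h, c

d_in, d_hid = 8, 16
rng = np.random.default_rng(0)
W = {k: 0.1 * rng.normal(size=(d_hid, d_hid)) for k in "fioc"}
U = {k: 0.1 * rng.normal(size=(d_hid, d_in)) for k in "fioc"}
h, c = lstm_step(rng.normal(size=d_in), np.zeros(d_hid), np.zeros(d_hid), W, U)
print(h.shape, c.shape)
```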

The key is that new information is written into the cell additively rather than through repeated matrix multiplications, so nesting over time doesn't cause its contents to vanish. It's also conceptually crucial to note that the updates to the cell (i.e. what and when to forget information) are learned through the matrices \( W_f, U_f \) (and their analogues for input and output), so that we can train the LSTM to store long-term dependencies optimally (we don't want to store a latent representation of early words if that information is never needed to translate later parts of the sentence, so what to store and how to store it should itself be learned). Still, though, its recall over hundreds of tokens is imperfect. This motivates the use of the famed attention mechanism in RNN design.

[modern LSTM architectures -- enc/dec and stacking]

The attention mechanism

Let's put LSTMs aside for now and try to think up a fundamentally new and better way to remember longer-term dependencies. This is where the fun begins. Around 2014, before the introduction of the transformer model itself, a mechanism was proposed to allow RNNs to mimic the way human translators parse text: by looking back and forth between the text to translate and their growing translation. The key idea was that the bottleneck for information in a traditional encoder-decoder RNN architecture is the latent representation \( z \) passed between the encoder and decoder, in which information from the beginning of the sequence (that may still be relevant later) has mostly been washed out by transformations from more recent inputs and hidden states. So, in the spirit of the human translator, we can allow our decoder to "look back" at earlier inputs by simply taking the representations of early words, concatenating them to \( z \), and feeding that into the decoder.
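
Here is a minimal sketch of one way to implement this "looking back": the decoder scores every encoder hidden state against its own current state, normalizes the scores into weights, and concatenates the resulting weighted summary (the "context") onto its input. Simple dot-product scoring is used for brevity; the original 2014 proposal (Bahdanau et al.) learns a small scoring network instead, so treat the details here as illustrative.

```python
import numpy as np

# Minimal sketch of "looking back": score each encoder state against the current
# decoder state, softmax the scores into weights, and append the weighted context.

def softmax(s):
    e = np.exp(s - s.max())
    return e / e.sum()

def attend(decoder_state, encoder_states):
    """encoder_states: (T, d) array of per-word representations from the encoder."""
    scores = encoder_states @ decoder_state    # how relevant is each source word?
    weights = softmax(scores)                  # a distribution over source words
    context = weights @ encoder_states         # weighted average of source words
    return np.concatenate([decoder_state, context]), weights

rng = np.random.default_rng(0)
enc = rng.normal(size=(6, 16))                 # 6 source words, 16-dim representations
dec = rng.normal(size=16)                      # current decoder hidden state
augmented, w = attend(dec, enc)
print(augmented.shape, w.round(2))             # (32,) vector; weights sum to 1
```

The weights tell the decoder which source words to attend to while producing the current output word, which is exactly the back-and-forth behaviour of the human translator we wanted to mimic.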

[TODO: fill out all below]

The Transformer architecture

Recent work

(multi-head attention, visual transformers)

Immediate questions