Transformers are the secret sauce that makes ChatGPT, DALL-E, and other GPT-based systems so powerful. No, I'm not talking about Optimus Prime and his Autobot pals 🚗🤖 - I'm talking about the neural network architecture. Transformers are the "T" in GPT-4 (Generative Pre-trained Transformer 4), and their development has enabled machines to understand language with an unprecedented level of accuracy.
This article attempts to break down the transformer architecture. To motivate why the architecture was groundbreaking, in part one I discuss the limitations of its predecessors: RNNs and LSTMs. Then I dive into the components of the architecture in part two.
TLDR:
Language models predict the next best word in a sentence. In order to do this accurately, the model needs to understand the whole sentence.
For example, to predict a likely word after server, the model must examine preceding words to determine if the sentence is about a computer or a waiter. This informs whether food or crash is a better successor.
RNNs understand the whole sentence by processing each word in a sentence sequentially.
Due to the nature of sequential processing, RNNs take a long time to train and are ineffective on long pieces of text. LSTMs incrementally improve RNNs, but they still take a long time to train.
Transformers introduced a new way of understanding a whole sentence without sequential processing. As a result, the training process can be parallelized, enabling the model to be trained on far more data. The two key concepts in transformers that enable this are positional encoding and self-attention.
If none of that made any sense yet, don’t fret: the point of this article is to explain these concepts in detail.
Part I: Predecessors to Transformers, Recurrent Neural Networks (RNNs) and Long Short-Term Memory Networks (LSTMs)
A property of language is that the order of words matters. Compare the following sentences: “Grandma I am ready to eat” vs “I am ready to eat Grandma” #yum. Even though both sentences use the same words, the order of words drastically changes the sentence’s meaning. Thus, for a machine to understand a sentence, models must capture word order and not just look at each word in isolation.
Before transformers, recurrent neural networks (RNNs) captured word order by processing each word in a sentence sequentially. For example, when predicting the next word after the phrase the server is, RNNs first process the, then server, and finally is. When processing the current token, an RNN takes its output from the previous token (its hidden state) as an additional input, as seen in the diagram below. Hence, when looking at a word and trying to predict the next word, the model has historical information about the previous words.
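To make the sequential dependency concrete, here is a minimal sketch of an RNN forward pass in Python. To keep it short, I assume each word has already been turned into a single number and that the weights `w_x`, `w_h`, and `b` are made-up scalars; real RNNs use vectors and weight matrices, but the loop structure is the same.

```python
import math


def rnn_forward(inputs, w_x, w_h, b):
    """Run a toy scalar RNN over a sequence and return every hidden state."""
    h = 0.0  # the hidden state starts out empty
    states = []
    for x in inputs:
        # The new state depends on the previous state, so step t cannot start
        # until step t - 1 has finished. This is why the loop can't be parallelized.
        h = math.tanh(w_x * x + w_h * h + b)
        states.append(h)
    return states


# Hypothetical numeric stand-ins for "the", "server", "is".
states = rnn_forward([0.2, 0.9, -0.4], w_x=1.5, w_h=0.8, b=0.0)
print(states)  # the final state carries information from every earlier word
```

Changing the first input changes the final state, which is exactly the "historical information" the RNN carries forward.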
Sequential processing led to two big limitations:
Training RNNs takes a long time. Processing each word sequentially means that the training process cannot be parallelized across the words in a sequence. Thus, adding more computing power (e.g., GPUs) does little to decrease the time it takes to train a model. Hence, it is impractical to train RNNs on very large amounts of data, which limits the model’s understanding of language.
Vanishing gradient problem. Theoretically, passing previous states to the current state means that the model can remember previous words. In practice, however, information about early words gets diluted each time it is combined with a new word. Worse, during training the error signal (the gradient) that flows back to early words is multiplied by a small factor at every step, so it shrinks toward zero and the model barely learns from those words. By the time RNNs are at the end of the paragraph, they forget what the beginning of the paragraph was about. This can be thought of as RNNs running out of memory. Hence, they don’t work well on long inputs.
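A quick way to see the vanishing gradient is to apply the chain rule to a toy scalar RNN, h_t = tanh(w_x·x_t + w_h·h_(t-1) + b). The sketch below (same made-up scalar setup, nothing from a real model) multiplies together the per-step factors w_h·(1 − h_t²) to get how strongly the final state depends on the starting state. Because tanh’s slope is at most 1, each factor here is at most 0.5 in size, so the product shrinks exponentially as the sequence gets longer.

```python
import math


def gradient_through_time(inputs, w_x, w_h, b):
    """Return d h_T / d h_0 for a toy scalar RNN, computed with the chain rule."""
    h = 0.0
    grad = 1.0  # d h_0 / d h_0
    for x in inputs:
        h = math.tanh(w_x * x + w_h * h + b)
        # d h_t / d h_{t-1} = w_h * tanh'(z) = w_h * (1 - h_t ** 2)
        grad *= w_h * (1.0 - h * h)
    return grad


short_seq = gradient_through_time([0.3] * 5, w_x=1.0, w_h=0.5, b=0.0)
long_seq = gradient_through_time([0.3] * 50, w_x=1.0, w_h=0.5, b=0.0)
print(short_seq, long_seq)  # the 50-step gradient is many orders of magnitude smaller
```

With a recurrent weight well above 1, the same product can instead blow up, which is the related exploding gradient problem.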
There were some improvements made to RNNs to work around these limitations. For example, the long short-term memory network (LSTM) is a type of RNN that mitigates the vanishing gradient problem. It does this with a “forget gate” that decides which past information to discard so that only the important previous context is kept. That way, the “memory” of an RNN is conserved and the model can remember the context of a sentence for a relatively long time.
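Here is a minimal sketch of a single LSTM step, again with scalar inputs and hypothetical weights stored in a dict `p` (the key names are made up for this example, not taken from any library). Besides the forget gate, a standard LSTM also has an input gate and an output gate. The key line is the cell-state update `c = f * c_prev + i * g`: when the forget gate `f` is close to 1, old memory passes through almost unchanged, and that additive path is what lets gradients survive over many steps.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def lstm_step(x, h_prev, c_prev, p):
    """One step of a toy scalar LSTM cell. Returns the new (hidden, cell) state."""
    f = sigmoid(p["wf"] * x + p["uf"] * h_prev + p["bf"])    # forget gate: how much old memory to keep
    i = sigmoid(p["wi"] * x + p["ui"] * h_prev + p["bi"])    # input gate: how much new info to write
    g = math.tanh(p["wg"] * x + p["ug"] * h_prev + p["bg"])  # candidate new memory
    o = sigmoid(p["wo"] * x + p["uo"] * h_prev + p["bo"])    # output gate: how much memory to expose
    c = f * c_prev + i * g  # additive update of the long-term memory
    h = o * math.tanh(c)
    return h, c


# Hypothetical weights: all zero except the gate biases.
keep = {k: 0.0 for k in ["wf", "uf", "wi", "ui", "wg", "ug", "bg", "wo", "uo", "bo"]}
keep.update(bf=10.0, bi=-10.0)  # forget gate ~1 (keep memory), input gate ~0 (ignore new input)
forget = dict(keep, bf=-10.0)   # forget gate ~0 (wipe memory)

_, c_kept = lstm_step(0.7, 0.0, 1.0, keep)
_, c_wiped = lstm_step(0.7, 0.0, 1.0, forget)
print(c_kept, c_wiped)  # close to 1.0 and close to 0.0
```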
LSTMs were the kings of NLP for a while. Still, they were sequential by nature, which limited the amount of data they could be trained on.
Part II: Steps in the Transformer Architecture
In come transformers. They were introduced in a 2017 paper titled Attention Is All You Need by researchers at Google and the University of Toronto. The transformer architecture encodes the context of a sentence within the encoding of each word. That way, all the words can be processed in parallel instead of one after another. Said differently, transformers move the burden of capturing word order from the neural network structure to the encoding of an individual word. This means that the training process can be parallelized, so adding more computers/GPUs decreases the time it takes to train the model, thereby enabling the model to be trained on far more data and wildly improving results.
The easiest way to understand the transformer is by stepping through a sample input. Consider the sentence: “in the warehouse, the server just”. When training a language model, the next word is known (e.g., exploded), and we tweak the model’s parameters until it assigns that word a high probability. In the diagram below, the parameters are part of each of the boxes (e.g., the blue feed-forward box). Once the model is trained, the input still flows through these same steps, but the output is determined by the learned parameters.
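“Tweaking the parameters” usually means minimizing a loss that is small when the model assigns a high probability to the known next word. A common choice is cross-entropy, which for a single prediction is just the negative log of the probability given to the correct word. The probabilities below are made up for illustration.

```python
import math


def next_word_loss(probs, target):
    """Cross-entropy for one prediction: -log(probability assigned to the true next word)."""
    return -math.log(probs[target])


# Hypothetical model outputs for "in the warehouse, the server just ___".
before = {"exploded": 0.05, "smiled": 0.60, "crashed": 0.35}
after = {"exploded": 0.70, "smiled": 0.05, "crashed": 0.25}

print(next_word_loss(before, "exploded"))  # large loss: the right word was unlikely
print(next_word_loss(after, "exploded"))   # smaller loss after training nudged the parameters
```

Gradient descent then adjusts every parameter slightly in the direction that lowers this number, averaged over many training examples.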
These are the steps an input goes through in the transformer architecture:
Step 0: Tokenization. To process the English sentence, the input is divided into smaller units called tokens. For simplicity, I equate a token to a word. These tokens or words are then mapped to a number, creating a numerical representation of the input. In our example, the input would be turned into a sequence of numbers like: 102, 23, 56, 23, 64, 43 (both occurrences of the map to 23).
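Here is a toy word-level tokenizer that reproduces the example numbers above. The vocabulary is made up to match those IDs; real models such as GPT use subword tokenizers (byte-pair encoding) with vocabularies of tens of thousands of entries, so a single word can become several tokens.

```python
import re

# Hypothetical vocabulary chosen to match the IDs in the example above.
VOCAB = {"in": 102, "the": 23, "warehouse": 56, "server": 64, "just": 43}


def tokenize(text):
    """Lowercase, split into words (dropping punctuation), and map each word to its ID."""
    words = re.findall(r"[a-z]+", text.lower())
    return [VOCAB[w] for w in words]


print(tokenize("in the warehouse, the server just"))
# [102, 23, 56, 23, 64, 43]
```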
Step 1: Embedding. Each token is mapped to a vector that meaningfully represents the word. The token 102 would be turned into an n-dimensional vector like [321, 23, … 54]. If you’re curious, my intro-to-ML post explains how embeddings work in depth.
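In code, an embedding layer is just a lookup table: row i holds the vector for token ID i. In this sketch the table is filled with random numbers and has a hypothetical size (128 tokens, 4 dimensions); in a real model the values are learned during training and the vectors have hundreds or thousands of dimensions.

```python
import random

VOCAB_SIZE = 128  # hypothetical; must be larger than the biggest token ID
EMBED_DIM = 4     # hypothetical; real models use far more dimensions

rng = random.Random(0)
# One row per token ID. Random here; learned during training in a real model.
embedding_table = [[rng.uniform(-1.0, 1.0) for _ in range(EMBED_DIM)] for _ in range(VOCAB_SIZE)]


def embed(token_ids):
    """Replace each token ID with its vector from the embedding table."""
    return [embedding_table[t] for t in token_ids]


vectors = embed([102, 23, 56, 23, 64, 43])
print(len(vectors), len(vectors[0]))  # 6 4
```

Note that both occurrences of the (token 23) get the identical vector at this point; nothing yet tells them apart.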
Step 2: Positional encoding. This is the first big win that enables parallelization. Positional encoding slaps a number onto every word’s embedding that represents its position in the sentence. That way the model can account for word order without processing each word sequentially. More precisely, positional encoding injects information about the word’s position into the embedding. It looks something like:
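For the curious, the original paper uses fixed sine and cosine waves of different frequencies: for position pos and dimension pair i, PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d)), where d is the embedding size. The resulting vector is added element-wise to the word’s embedding. Below is a small Python sketch of that formula; GPT-2, by contrast, learns its position vectors during training, so treat this as one concrete option rather than what every model does.

```python
import math


def positional_encoding(seq_len, d_model):
    """Sinusoidal positional encodings from "Attention Is All You Need" (d_model even)."""
    pe = []
    for pos in range(seq_len):
        row = []
        for i in range(0, d_model, 2):
            angle = pos / (10000 ** (i / d_model))  # i here is the even index 2k
            row.append(math.sin(angle))  # even dimensions use sine
            row.append(math.cos(angle))  # odd dimensions use cosine
        pe.append(row)
    return pe


def add_positions(embeddings):
    """Add each position's encoding to the embedding at that position."""
    pe = positional_encoding(len(embeddings), len(embeddings[0]))
    return [[e + p for e, p in zip(emb, pos_vec)] for emb, pos_vec in zip(embeddings, pe)]


pe = positional_encoding(seq_len=6, d_model=4)
print(pe[0])  # [0.0, 1.0, 0.0, 1.0]: position 0 is sin(0)/cos(0) in every pair
```

Because each position gets a distinct pattern, two identical embeddings (like the two occurrences of the) end up as different vectors once their positions are added.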
Step 3: Self-attention, which occurs in the multi-head attention blocks, is the second big win. Self-attention involves focusing on (or attending to) the most relevant parts of a sentence while processing a single word. For example, when processing the word server, the model would pay attention to the word warehouse to determine whether the sentence is talking about a machine server or a human server.
Self-attention allows the model to understand the context surrounding a word and generate more accurate outputs. In order to disambiguate homographs (words that are spelled the same but mean different things), recognize parts of speech, and identify word tenses, the model must pay attention to other parts of the sentence. Self-attention enables this (and without the need for sequential processing). During the self-attention step, the model calculates attention weights for the input and generates an output that encodes information about how each word should attend to all the other words in the sentence.
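Under the hood, the original paper computes self-attention as softmax(QKᵀ / √d_k)·V, where the queries Q, keys K, and values V are three different linear projections of the input vectors and d_k is the key size. Here is a single-head sketch in plain Python with hypothetical toy inputs and identity projections, so each word attends most to the words whose vectors have the largest dot product with its own. Multi-head attention runs several of these in parallel with different learned projections and concatenates the results.

```python
import math


def matmul(a, b):
    """Multiply an n x m matrix a by an m x p matrix b (both lists of rows)."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m):
    return [list(col) for col in zip(*m)]


def softmax(row):
    top = max(row)  # subtracting the max keeps exp() from overflowing
    exps = [math.exp(x - top) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention over a sequence x (n x d)."""
    q = matmul(x, w_q)  # queries: what each word is looking for
    k = matmul(x, w_k)  # keys: what each word offers to the others
    v = matmul(x, w_v)  # values: the information each word passes along
    d_k = len(k[0])
    scores = matmul(q, transpose(k))  # scores[i][j]: how relevant word j is to word i
    weights = [softmax([s / math.sqrt(d_k) for s in row]) for row in scores]
    return matmul(weights, v), weights  # each output row is a weighted average of values


x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # three toy word vectors
identity = [[1.0, 0.0], [0.0, 1.0]]       # hypothetical projection weights
out, weights = self_attention(x, identity, identity, identity)
for row in weights:
    print([round(w, 2) for w in row])  # each row sums to 1
```

Every row of `weights` can be computed independently of the others, which is why this step parallelizes so well.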
Step 4: Feed Forward Neural Network. The feed-forward network takes the information from the previous steps and further transforms each word’s representation, applying the same small network to every position independently. Given the word and the context it is in, the resulting representation is what the model uses to predict the next most probable word. Training a model to predict the correct next word is done using the gradient descent and backpropagation algorithms. If you’re curious about how these algorithms work, check out this previous post on how language models are trained.
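In the original paper, each position’s vector goes through the same two-layer network: FFN(x) = max(0, x·W1 + b1)·W2 + b2, that is, a linear layer, a ReLU, and another linear layer. The weights below are hypothetical toy values chosen so the arithmetic is easy to check by hand.

```python
def linear(v, w, b):
    """Compute v @ w + b for a vector v and a weight matrix w given as a list of rows."""
    return [sum(v[k] * w[k][j] for k in range(len(v))) + b[j] for j in range(len(b))]


def relu(v):
    return [max(0.0, a) for a in v]


def feed_forward(v, w1, b1, w2, b2):
    """Position-wise feed-forward network: linear -> ReLU -> linear."""
    return linear(relu(linear(v, w1, b1)), w2, b2)


# Hypothetical weights: model size 2, hidden size 3.
w1 = [[1.0, -1.0, 0.5], [0.0, 1.0, 0.5]]
b1 = [0.0, 0.0, 0.0]
w2 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
b2 = [0.0, 0.0]

sequence = [[2.0, 1.0], [0.0, 3.0]]
# The same network is applied to every position independently.
outputs = [feed_forward(v, w1, b1, w2, b2) for v in sequence]
print(outputs[0])  # [3.5, 1.5]
```

Because no position looks at any other here, this step parallelizes just as easily as self-attention.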
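Step 5: Linear and Softmax. The linear layer turns the output of the previous step into a score for every word in the vocabulary, and the softmax function converts those scores into probabilities between 0 and 1 (summing to 1) of how likely each word is to come next.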
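Here is a small sketch of this final step, assuming a hypothetical three-word vocabulary and made-up weights. The linear layer produces one raw score (often called a logit) per vocabulary word, and softmax exponentiates and normalizes those scores so they are positive and sum to 1.

```python
import math

VOCAB = ["food", "crash", "exploded"]  # hypothetical tiny vocabulary


def linear(v, w, b):
    """Project a hidden vector v to one score per vocabulary word."""
    return [sum(v[k] * w[k][j] for k in range(len(v))) + b[j] for j in range(len(b))]


def softmax(scores):
    """Turn raw scores into probabilities that are positive and sum to 1."""
    top = max(scores)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


hidden = [1.0, 0.5]                          # hypothetical output of the feed-forward step
w_out = [[0.2, 1.0, 1.5], [-0.4, 0.8, 0.6]]  # hypothetical weights (2 x vocab size)
b_out = [0.0, 0.0, 0.0]

logits = linear(hidden, w_out, b_out)
probs = softmax(logits)
for word, p in zip(VOCAB, probs):
    print(f"{word}: {p:.2f}")
```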
At this stage, the model has a set of words with their corresponding probabilities of being a suitable next word. Simply selecting the word with the highest probability (using a greedy algorithm) can lead to repetitive and looping responses. To avoid this issue, most models use a sampling technique where one of the high-probability words is randomly selected. This is why ChatGPT often provides different responses to the same input. In the GPT playground, you can adjust the "temperature" variable to control the level of randomness.
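The usual way temperature works is to divide the scores by the temperature before applying softmax: a temperature below 1 sharpens the distribution toward the top word (approaching greedy selection as it nears 0), while a temperature above 1 flattens it so less likely words get picked more often. The sketch below uses Python’s `random` module with a fixed seed so runs are repeatable; the candidate words and scores are made up.

```python
import math
import random


def sample_next_word(logits, temperature, rng):
    """Pick an index from raw scores using temperature sampling."""
    if temperature <= 0:
        # Zero temperature is treated as greedy decoding: always take the top score.
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [s / temperature for s in logits]
    top = max(scaled)
    weights = [math.exp(s - top) for s in scaled]  # softmax numerators
    # random.choices normalizes the weights itself, so no need to divide by their sum.
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


words = ["exploded", "crashed", "smiled"]  # hypothetical candidates
logits = [2.0, 1.0, 0.1]                   # hypothetical scores
rng = random.Random(42)

cold = [words[sample_next_word(logits, 0.1, rng)] for _ in range(10)]
hot = [words[sample_next_word(logits, 2.0, rng)] for _ in range(10)]
print(cold)  # almost always "exploded"
print(hot)   # more variety
```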
Part III: Confession & Caveat & Conclusion
Confession: I obfuscated one thing in the walk-through above. The original transformer actually has two components: an encoder and a decoder. Since the encoder and decoder are similar, researchers have had success using different combinations of the two. GPT only uses the decoder, while BERT, a language model Google has used in Search, only uses the encoder. The example I walked through follows an input through a decoder but not through an encoder, though the flow is rather similar.
This is the diagram of the transformer from the original paper:
It shows that the encoder and decoder both use the same key components, namely: (i) embeddings, (ii) positional encoding, (iii) multi-head attention, where self-attention takes place, and (iv) feed-forward networks.
It is also worth noting that language models have multiple transformer layers stacked on top of each other. Each additional layer adds parameters and generally increases the model's capability. For example, the smallest version of GPT-2 has 12 decoder layers stacked on top of each other, while the largest version has 48. A decoder-based model like GPT would only have decoders stacked on top of each other, while an encoder-decoder model like T5, which follows the original transformer design, has a stack of encoders and a stack of decoders. In the diagram, Nx indicates that the block is repeated N times, representing the number of layers.
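To see how stacking layers drives up the parameter count, here is a rough back-of-the-envelope calculator. It assumes a GPT-2-style decoder block: attention projections (about 4·d² weights), a feed-forward network with a hidden size of 4·d (about 8·d²), plus biases and layer-norm parameters, with the output layer sharing weights with the token embedding. The inputs used below (12 layers, 768-dimensional vectors, a 50,257-token vocabulary, 1,024 positions) are the published settings for the smallest GPT-2; the formula is a simplification and won't exactly match models with a different layout. For that configuration it lands at roughly 124 million, in line with the size usually quoted for the smallest GPT-2.

```python
def gpt2_style_param_count(n_layers, d_model, vocab_size, n_positions):
    """Rough parameter count for a GPT-2-style decoder-only transformer."""
    attention = 4 * d_model * d_model + 4 * d_model     # Q, K, V and output projections + biases
    feed_forward = 8 * d_model * d_model + 5 * d_model  # d -> 4d -> d, with biases
    layer_norms = 2 * 2 * d_model                       # two layer norms, each with scale and shift
    per_layer = attention + feed_forward + layer_norms

    token_embeddings = vocab_size * d_model  # shared with the output layer, so counted once
    position_embeddings = n_positions * d_model
    final_layer_norm = 2 * d_model
    return n_layers * per_layer + token_embeddings + position_embeddings + final_layer_norm


small = gpt2_style_param_count(n_layers=12, d_model=768, vocab_size=50257, n_positions=1024)
print(f"{small:,}")  # 124,439,808

# Each extra layer adds the same per-layer cost on top of the fixed embedding cost.
print(f"{gpt2_style_param_count(13, 768, 50257, 1024) - small:,}")  # 7,087,872
```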
Huge thank you to Shyamoli Sanghi 👑, B.A. in Machine Learning, who graciously fact-checked this post.
I hope this post was helpful! Subscribe if you are interested in having more posts like this sent to your inbox. If you have any feedback or corrections, send them my way. If you have any questions, leave a comment. And if you think anyone else would appreciate this post, kindly share it with them.