Attention
Attention is what separates static embeddings from dynamic embeddings - it allows word embeddings to be updated, aka attended to, by the contextual words surrounding them
Attention stemmed from NLP Seq2Seq tasks like next-word prediction and translation, where using the surrounding context of a word was one of the major breakthroughs in achieving better Seq2Seq results
We need to remember that the embedding for "bank" starts as the same embedding in the metric space in these scenarios, but by attending to it with Attention, we can change its position! It's as simple as that, so at the end of attending to the vector, the vector for bank in "river bank" may point in a completely different direction than the vector for bank in "bank vault" - just because of how the other words add or detract from it geometrically in its metric space. Bank + having "river" in the sentence moves the vector closer to a sand dune, whereas Bank + "teller" in the sentence moves it closer to a financial worker
How is this done? Attention mechanisms in our DNN models. There are multiple forms of Attention - the most useful / used are Self Attention, Encoder-Decoder Attention, and Masked Self Attention - each of them helps to attend to a current query word / position based on its surroundings. A single head of this Attention mechanism would only update certain "relationships", or attended-to geometric shifts, but multiple different Attention heads can learn a dynamic range of relationships
In the days before Attention, there would be Encoders that take an input sentence and write it to a fixed-size embedding, and then a separately trained Decoder that would take the embedding and output a new sentence - this architecture for Seq2Seq tasks is fine for short sequences, but degrades with longer sequences you try to "stuff into" a fixed-size embedding
Therefore, Attention helps us to remove this fixed size constraint between encoders and decoders
All of these Attention mechanisms are tunable matrices of weights - they are learned and updated through the model training process, and it's why we need to "bring along the model" during inference...otherwise we can't use the Attention!
Below shows an example of how an embedding like creature would change based on surrounding context
Keys, Queries, and Values Intuition
Lots of the excerpts here are from the D2L AI Blog on Attention
Consider a K:V database, it may have tuples such as ('Luke', 'Sprangers'), ('Donald', 'Duck'), ('Jeff', 'Bezos'), ..., ('Puke', 'Skywalker')
with first and last names - if we wanted to query the database, we'd simply query based on a key like Database.get('Luke')
and it would return Sprangers
If we allow for approximate matches, we might get ['Sprangers', 'Skywalker']
Therefore, our Queries and our Keys have a relationship! They can be exact, or they can be similar (Luke and Puke are similar), and based on those similarities, we may or may not want to return a Value
What if we wanted to return a portion of the Value, based on how similar our Query was to our Key? Luke and Luke are 1:1, so we fully return Sprangers, but Luke and Puke are slightly different, so what if we returned something like (letters in common / total letters) * Value, or 75% of the Value?
Luke and Puke have 3 of their 4 letters in common, so we can return 75% of Skywalker!
This is the exact idea of Attention - over a database D of key/value pairs, Attention(q, D) = Σᵢ α(q, kᵢ) · vᵢ
This attention can be read as "for every key / value pair (kᵢ, vᵢ), compare the key kᵢ with the query q, and based on how similar they are return that much of the value. The total sum of all those weighted values is the value we will return"
Therefore, we can mimic an exact lookup if we define our similarity function α(q, kᵢ) as a Kronecker Delta: α(q, kᵢ) = 1 if q = kᵢ, else 0
This means we'll return the list of all values whose key matches Luke, each weighted by α(q, kᵢ), which is just 1
Another example would be if
D = [
([1, 0], 4),
([1, 1], 6),
([0, 1], 6)
]
And our Query is [0.5, 0]
- if we compare that query to each of our Keys with some similarity score, we might get a weight of 0.5 for the first key and -0.5 for the second (and 0 for the third)
so we would have 0.5 * 4 + (-0.5) * 6 = -1, and that would be our answer!
Notice these raw weights can be negative and don't sum to 1
So these comparisons of Queries and Keys result in some weights, and typically we will compare our Query to every Key, and stuff the resulting set through a Softmax function to get weights that sum to 1
To ensure weights are non-negative, we can use exponentiation
This is exactly what will come through in most attention calculations!
It is differentiable, and its gradient never vanishes! These are very desirable properties
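To make this concrete, here is a minimal PyTorch sketch of the lookup-as-weighted-sum idea above, using the toy keys and values from the example. The dot product + softmax scoring and the variable names here are my own choices for illustration, not taken from the D2L code.

```python
import torch
import torch.nn.functional as F

# Toy "database": keys are 2-d vectors, values are scalars (as in the example above)
keys = torch.tensor([[1.0, 0.0],
                     [1.0, 1.0],
                     [0.0, 1.0]])
values = torch.tensor([4.0, 6.0, 6.0])

query = torch.tensor([0.5, 0.0])

# Score the query against every key (here: a dot product), then softmax so the
# weights are non-negative and sum to 1
scores = keys @ query               # shape: (3,)
weights = F.softmax(scores, dim=0)  # roughly [0.38, 0.38, 0.23]

# The returned "value" is the weighted sum of all values
output = (weights * values).sum()
print(weights, output)
```

The exponentiation inside the softmax is what guarantees non-negative weights, and the normalization is what makes them sum to 1.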
Bahdanau RNN Attention
Things first started off with Bahdanau Style RNN Attention via Neural Machine Translation by Jointly Learning to Align and Translate (2014)
This paper discusses how most architectures of the time have a 2 pronged setup:
- An encoder to take the source sentence into a fixed length vector
- In most scenarios there's an encoder-per-language
- A decoder to take that fixed length vector, and output a translated sentence
- There's also typically a decoder-per-language as well
- Therefore, for every language there's an encoder-decoder pair which is jointly trained to maximize the probability of correct translation
- Why is this bad?
- Requires every translation problem to be "squished" into the same fixed-length vector
- Examples cited of how the performance of this encoder-decoder architecture deteriorates rapidly as the length of input sentence increases
- Proposal in this paper:
- "We introduce an extension of the encoder-decoder model which learns to align and translate jointly"
- This means it sets up (aligns) and decodes (translates) on the fly over the entire sentence, and not just all at once
- "Each time the proposed model generates a word in a translation, it (soft-)searches for a set of positions in a source sentence where the most relevant information is concentrated
- This means it uses some sort of comparison (later seen as attention) to figure out what words are most relevant in the translation
- The model then predicts a target word based on the context vectors associated with these source positions and all previously generated target words
- The model predicts the next word based on attention between this word and the input sentence + previously generated words
- This was a breakthrough in attention, but apparently was proposed here earlier!
- "We introduce an extension of the encoder-decoder model which learns to align and translate jointly"
- Altogether, this architecture looks to break away from encoding the entire input into one single vector by encoding the input into a sequence of vectors which it uses adaptively while decoding
- In the paper's architecture figure, the bottom portion is an encoder which receives the source sentence as input
- The top part is a decoder, which outputs the translated sentence
Encoder RNN Attention
- Input: A sequence of vectors x = (x₁, ..., x_T)
- Encoder Hidden States: hₜ = f(xₜ, hₜ₋₁)
- f is some non-linear function
- This will take the current input word, and the output of the last recurrence
- These hidden states are bi-directional
- hₜ = [→hₜ ; ←hₜ]
- Concatenation of the forward and backward hidden states!
- It is formalized that hₜ is the hidden state (annotation) of the encoder at time t
- n is the dimensionality of our encoder hidden state
- Encoder Hidden States (annotations):
- These annotations hₜ live entirely in the encoder block
- They are fixed once the input is encoded
Decoder RNN Attention
The decoder is often trained to predict the next word yᵢ given the context vector c and all the previously predicted words {y₁, ..., yᵢ₋₁}
Our translation can be defined as y = (y₁, ..., y_T), which is just the sequence of words output so far!
It does this by defining a probability distribution over the translation output word by decomposing the joint probability: p(y) = Πᵢ p(yᵢ | {y₁, ..., yᵢ₋₁}, c)
Since our y is just our entire word sequence, the probability we're solving for is "the probability that this is the sequence of words given our context vector"
So we're just choosing the next most likely word so that the probability of seeing all these words in a sequence is highest
The sequence "Hi, what's your", if we looked over all potential next words, would most likely have the highest predicted outcome of "Hi, what's your name"
Each of these conditional probabilities is typically modeled with a non-linear function g such that p(yᵢ | {y₁, ..., yᵢ₋₁}, c) = g(yᵢ₋₁, sᵢ, cᵢ)
So, at each time step i the decoder combines:
- Decoder hidden state formula: sᵢ = f(sᵢ₋₁, yᵢ₋₁, cᵢ), based on
- the last hidden state sᵢ₋₁
- the last output word yᵢ₋₁
- the context vector at that step, cᵢ = Σⱼ αᵢⱼ hⱼ
- Where αᵢⱼ is the weight of the decoder state compared to each annotation hⱼ, and eᵢⱼ is the similarity score between the two
Learning To Align And Translate
The rest of the architecture proposes using a bi-directional RNN as an encoder, and then a decoder that emulates searching the source sentence during translation
The "searching" is done by comparing the decoders last hidden state to each encoder hidden state in our alignment model, and then creating a distribution of weights (softmax) to create a context vector. This context vector, the previous hidden state, and the previous hidden word help us to compute the next word!
**P.S. the alignment model here is very similar to self-attention in the future
Since sentences aren't isomorphic (one-to-one and onto), there may be 2 words squished into 1, 1 word expanded to 2, or 2 non-adjacent words that are used in outputting 1
Realistically, any Seq2Seq task that isn't isomorphic would benefit from this structure
Summary
- Encoder
- hⱼ is the encoder hidden state / annotation for input position j
- hⱼ = [→hⱼ ; ←hⱼ]
- Concatenation of the forward and backward RNN states!
- Decoder
- For each step i, we need to come up with a context vector cᵢ
- eᵢⱼ = a(sᵢ₋₁, hⱼ) is an alignment model which scores similarities between the decoder hidden state at step i-1 and all encoder states, where each one is denoted hⱼ at some time j
- a is a small feed-forward NN
- eᵢⱼ = a(sᵢ₋₁, hⱼ) = vₐᵀ tanh(Wₐ sᵢ₋₁ + Uₐ hⱼ)
- What does this mean?
- Wₐ sᵢ₋₁ is a weighted version of our decoder hidden state at step i-1
- Uₐ hⱼ is a weighted version of our encoder hidden state at time j
- Therefore, tanh acts as an activation function which helps us to score how close the decoder hidden state at step i-1 and our current encoder hidden state at time j are
- Normalize with softmax to get attention weights: αᵢⱼ = exp(eᵢⱼ) / Σₖ exp(eᵢₖ)
- αᵢⱼ will be the attention score of hⱼ relative to all other encoder states
- If there are 5 words in the input, and we're at step i, this will score how well sᵢ₋₁ matches hⱼ compared to all other annotations
- This will convert scores into a probability distribution
- Since our attention model helps to compute scores across encoder hidden states for a decoder hidden state, taking the softmax here will then give us the relative weight of each encoder hidden state to a decoder hidden state
- Form the context vector as the weighted sum of these annotations
- cᵢ = Σⱼ αᵢⱼ hⱼ
- Where αᵢⱼ is the weight of sᵢ₋₁ compared to each annotation hⱼ, and eᵢⱼ is a similarity metric between the two
- sᵢ = f(sᵢ₋₁, yᵢ₋₁, cᵢ)
- sᵢ is the decoder hidden state
- yᵢ₋₁ is the last output word from the decoder
- To bring it all out:
- cᵢ = Σⱼ softmax(a(sᵢ₋₁, hⱼ)) · hⱼ
- Our query is sᵢ₋₁, and our keys / values are the annotations hⱼ
- Our attention score is based on a(sᵢ₋₁, hⱼ), which is multiplied by hⱼ to attend to it
- All of this will be the basis of Self-Attention in the future, and for that you can just read this as: the context vector is based on the similarity of an input annotation with the rest of the annotations
D2L AI Code Implementation In PyTorch
Intuition
- In the paper they even mention "this implements a mechanism of attention in the decoder"
- In the encoder, the only major trick is doing bi-directional hidden states and then concatenating them
- This produces the annotations themselves
- These annotations are then fed through the alignment, alpha weight, and context vectors before being used in the decoder with the last output word and hidden state
- This context vector allows us to "attend to" the last output word and last hidden state
- What's missing relative to Transformers? Here we drag along a recurrent hidden state the entire time, whereas in Transformers we just re-compute the context vector for each token
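As a rough sketch (not the actual D2L implementation), the alignment model plus context-vector computation described above could look like the following in PyTorch. The class name, dimensions, and the choice of nn.Linear layers are my own assumptions for illustration.

```python
import torch
import torch.nn as nn

class AdditiveAttention(nn.Module):
    """Bahdanau-style alignment: e_ij = v^T tanh(W_s s_(i-1) + W_h h_j)."""
    def __init__(self, dec_hidden, enc_hidden, attn_dim):
        super().__init__()
        self.W_s = nn.Linear(dec_hidden, attn_dim, bias=False)  # weights the decoder state
        self.W_h = nn.Linear(enc_hidden, attn_dim, bias=False)  # weights the annotations
        self.v = nn.Linear(attn_dim, 1, bias=False)

    def forward(self, dec_state, enc_annotations):
        # dec_state:       (batch, dec_hidden)          -> the query s_(i-1)
        # enc_annotations: (batch, src_len, enc_hidden) -> the keys / values h_j
        scores = self.v(torch.tanh(
            self.W_s(dec_state).unsqueeze(1) + self.W_h(enc_annotations)
        )).squeeze(-1)                              # e_ij: (batch, src_len)
        alphas = torch.softmax(scores, dim=-1)      # attention weights alpha_ij
        context = torch.bmm(alphas.unsqueeze(1), enc_annotations).squeeze(1)
        return context, alphas                      # context vector c_i and weights

# Illustrative usage: a batch of 2, source length 5
attn = AdditiveAttention(dec_hidden=16, enc_hidden=32, attn_dim=8)
context, alphas = attn(torch.randn(2, 16), torch.randn(2, 5, 32))
```

The returned context vector cᵢ would then be fed, together with the previous decoder state and previous output word, into the decoder RNN cell as described in the summary above.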
Transformer Attention
The above RNN discussion is useful, as it shows how we can utilize the building blocks of forward and backward passes, and even achieve attention mechanisms using basic building blocks
The rest of the discussion is around Attention blocks in Transformer Architectures, primarily using a similar encoder-decoder structure "on steroids"
Key, Query, and Value Matrices
This setup allows us to create a paradigm of:
- Queries (Q):
- Represents the word being attended to / compared to
- Used to calculate attention scores with all Keys
- Key (K):
- Represents the context words being compared to the Query
- Used to compute the relevance of each context word to the Query
- Value (V):
- Another representation of the context words, but separate and different from the Keys
- The same input context words are multiplied by 2 different K, V matrices, which results in 2 different Key and Value vectors for the same context word
- It is basically a representation of each word's content, so at the end a scored SUM() over all the words' value vectors, weighted by the attention scores, produces the final output
These matrices are learned during training and updated via backpropagation
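A tiny sketch of this projection step, with illustrative dimensions (512-dim model, 64-dim projections are common choices, but nothing in the text above fixes them):

```python
import torch
import torch.nn as nn

d_model, d_k = 512, 64                      # illustrative dimensions

# Three independent learned projection matrices
W_q = nn.Linear(d_model, d_k, bias=False)
W_k = nn.Linear(d_model, d_k, bias=False)
W_v = nn.Linear(d_model, d_k, bias=False)

x = torch.randn(10, d_model)                # 10 token embeddings

# The same token embedding produces three different vectors: its query, key, and value
Q, K, V = W_q(x), W_k(x), W_v(x)            # each: (10, d_k)
```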
Encoding Blocks
The main layer we focus on in our Encoding blocks is Self Attention, but alongside this there are other linear layers that help to stabilize our context creation
Self Attention
Self Attention allows words in a single sentence / document to attend to each other to update word embeddings in itself. It's most commonly used when we want a sentence's word embeddings to be updated by other words in the same sentence, but there's nothing stopping us from using it over an entire document.
It was born out of the example of desiring a different embedding outcome of the word bank in:
- The river bank was dirty
- I went to the bank to deposit money
Via Self Attention, the word "bank" in the two sentences above would be different, because the other words in the sentence "attended to" it
Self Attention is a mechanism that uses context words (Keys) to update the embedding of a current word (Query). It allows embeddings to dynamically adjust based on their surrounding context.
Example
Consider the phrase "fluffy blue creature." The embedding for "creature" is updated by attending to "fluffy" and "blue," which contribute the most to its contextual meaning.
How Self Attention Works
TLDR;
- The Query vector represents the current word
- The Key vector is an embedding representing every other word
- We multiply the Query by every Key to find out how "similar", or "attended to", the Query should be by each Key
- Then we softmax it to find the percentage of influence each Key should have on the Query
- Finally we multiply that softmaxed representation by the Value vectors, which are the input embeddings multiplied by the Value matrix, ultimately letting each Key context word attend to our Query by some percentage
- At the end, we sum together all of the resulting value vectors, and this resulting SUM of weighted value vectors is our attended-to output embedding
- In the below example:
- The dark blue vector from the left is the Query
- The light blue vectors on top are the Keys
- We multiply them together + softmax
- Multiply the result of that by each Value vector on the bottom
In depth mathematical explanation below
- Input Transformation:
- Each input embedding x is transformed into three vectors: Query (Q), Key (K), and Value (V)
- These are computed by multiplying the input embedding with learned weight matrices: Q = x·W_Q, K = x·W_K, V = x·W_V
- Self-Attention Calculation:
- Step 1: Compute attention scores by taking the dot product of the Query vector with all Key vectors: scores = Q·Kᵀ
- Step 2: Scale the scores to prevent large values: scores / √dₖ
- Where dₖ is the dimensionality of the Key vectors
- As the size of the input embedding grows, so does the average size of the dot product that produces the weights
- Remember the dot product is a scalar value
- It grows by a factor of √dₖ, where dₖ = num dimensions
- Therefore, we can counteract this by normalizing it via √dₖ as the denominator
- Step 3: Apply softmax to convert scores into probabilities: softmax(Q·Kᵀ / √dₖ)
- Step 4: Compute the weighted sum of Value vectors: Attention(Q, K, V) = softmax(Q·Kᵀ / √dₖ)·V
- Output:
- The output is a context-aware representation of the word, influenced by its relationship with other words in the sequence.
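A compact sketch of the four steps above; the function name and shapes are my own choices for illustration:

```python
import math
import torch

def self_attention(Q, K, V):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = K.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)   # steps 1 + 2: dot products, scaled
    weights = torch.softmax(scores, dim=-1)              # step 3: probabilities per query
    return weights @ V                                    # step 4: weighted sum of values

# Self-attention over a sequence of 10 tokens with d_k = 64 (Q, K, V from the same tokens)
Q = K = V = torch.randn(10, 64)
out = self_attention(Q, K, V)        # (10, 64): one context-aware vector per token
```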
Multi-Head Attention
- Instead of using a single set of W_Q, W_K, W_V matrices, Multi-Head Attention uses multiple sets to capture different types of relationships between words (e.g., syntactic vs. semantic).
- Each head computes its own attention output.
- Outputs from all heads are concatenated and passed through a final weight matrix W_O: MultiHead(Q, K, V) = Concat(head₁, ..., headₕ)·W_O
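For a quick sketch, PyTorch's built-in nn.MultiheadAttention already bundles the per-head projections, the concatenation, and the final output projection (W_O); the sizes below are illustrative:

```python
import torch
import torch.nn as nn

# 8 heads, each attending over 512 / 8 = 64 dimensions; the head outputs are
# concatenated and passed through the output projection internally
mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)

x = torch.randn(2, 10, 512)                 # (batch, seq_len, d_model)
out, attn_weights = mha(x, x, x)            # self-attention: Q = K = V = x
print(out.shape)                            # torch.Size([2, 10, 512])
```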
Positional Encoding
- Since Self Attention does not inherently consider word order, Positional Encoding is added to input embeddings to encode word positions
- Positional encodings are vectors added to each input embedding, allowing the model to distinguish between words based on their positions in the sequence
- Why is sinusoidal relevant and useful?
- Allows the Transformer to learn relative positions via linear functions (e.g., the encoding for position pos+k can be derived from the encoding for position pos by a linear transformation)
- We all know neural nets like linear functions! So it's helpful in ensuring a relationship that's understandable
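A minimal sketch of the sinusoidal encoding from the original Transformer paper, added directly onto the token embeddings; dimensions are illustrative:

```python
import torch

def sinusoidal_positional_encoding(seq_len, d_model):
    """PE[pos, 2i] = sin(pos / 10000^(2i/d_model)), PE[pos, 2i+1] = cos(same)."""
    positions = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)        # (seq_len, 1)
    div_terms = torch.pow(10000.0, torch.arange(0, d_model, 2).float() / d_model)
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(positions / div_terms)   # even dimensions
    pe[:, 1::2] = torch.cos(positions / div_terms)   # odd dimensions
    return pe

embeddings = torch.randn(10, 512)                         # 10 token embeddings
x = embeddings + sinusoidal_positional_encoding(10, 512)  # inject word order
```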
Residual Connections and Normalization
- Each encoder layer includes a residual connection and normalization layers to stabilize training and improve gradient flow
- This happens after both Self Attention Layer and Feed Forward Layer in the "Add and Normalize" bubble
- Add the residual (the original input for that sublayer) to the output of the sublayer
- In the case of Self Attention layer, we add the output of Self Attention to the original input word (non-attended to word)
- Apply LayerNorm to the result
- This just means normalizing the actual numeric values across the word's embedding vector
- **If the diagram shows a block over the whole sentence, it just means the operation is applied to all words, but always independently for each word
- Why is any of this useful:
- Helps with gradient vanishing and exploding, and also ensures input stability
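A small sketch of the "Add and Normalize" step around a self-attention sublayer; this is my own minimal wiring, not a full encoder implementation:

```python
import torch
import torch.nn as nn

d_model = 512
self_attn = nn.MultiheadAttention(d_model, num_heads=8, batch_first=True)
norm = nn.LayerNorm(d_model)

x = torch.randn(2, 10, d_model)           # input to the sublayer

# Add & Normalize: add the sublayer output back onto its own input (the residual),
# then apply LayerNorm to each token's vector independently
attn_out, _ = self_attn(x, x, x)
x = norm(x + attn_out)
```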
Summary of Self Attention Encoding
1. Input Embeddings:
- We take our input words, process them, and retrieve static embeddings
- This only happens in the first encoding layer
2. Positional Encoding:
- Add positional information to embeddings to account for word order
3. Self Attention:
3.1 Input Transformation:
- Positionally encoded embeddings are transformed into Q, K, V using learned weight matrices.
3.2 Self Attention Calculation:
- Compute attention scores using dot products of Q and K, scale them, and apply softmax.
3.3 Weighted Sum:
- Use the attention weights to compute a weighted sum of V, and add that onto the input word, producing the output.
3.4 Residual + Normalization:
- Add together the input and the self-attended output, then apply LayerNorm
3.5 Feed Forward Layer:
- Each position's output from the self-attention layer is passed through a fully connected feed-forward neural network (the same network is applied independently to each position)
- Essentially just gives the model another chance to find and model more transformations / features, while also potentially allowing different dimensionalities to be stacked together
- If we have 10 words in our input, we want to ensure the final output is the same dimensionality as the input
- I don't know if this is exactly necessary
4. Multi-Head Attention:
- Use multiple sets of W_Q, W_K, W_V to capture diverse relationships, then concatenate the results.
This diagram below shows one single encoding block using Self Attention
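In code, one such encoding block could be sketched roughly as follows (the dimensions and the ReLU choice are illustrative assumptions, not fixed by the text above):

```python
import torch
import torch.nn as nn

class EncoderBlock(nn.Module):
    """One encoder block: self-attention -> add & norm -> feed-forward -> add & norm."""
    def __init__(self, d_model=512, num_heads=8, d_ff=2048):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):                      # x: (batch, seq_len, d_model)
        attn_out, _ = self.self_attn(x, x, x)  # each token attends to every token
        x = self.norm1(x + attn_out)           # residual + LayerNorm
        x = self.norm2(x + self.ffn(x))        # position-wise FFN, residual + LayerNorm
        return x                               # same shape, so blocks can be stacked
```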
Masked Self Attention
- In Masked Self Attention, it's the same process as Self Attention except we mask certain words so that their softmax weight results in 0, effectively removing them from attention scoring
- In BERT training we mask a number of words inside of the sentence
- In GPT2 training we mask all future words (right hand of sentence from any word)
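A minimal sketch of the GPT-style causal mask: positions above the diagonal are set to -inf before the softmax, so their attention weights become 0 (BERT-style masking instead replaces selected input tokens, which is a different mechanism):

```python
import torch

seq_len = 5
# Lower-triangular visibility: position i may only attend to positions <= i.
# True entries in the mask are the "future" positions we want to hide.
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

scores = torch.randn(seq_len, seq_len)                    # stand-in for Q K^T / sqrt(d_k)
scores = scores.masked_fill(causal_mask, float('-inf'))   # hide future positions
weights = torch.softmax(scores, dim=-1)                   # rows sum to 1, future weights = 0
```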
Context Size and Scaling Challenges
- The size of the Q·Kᵀ score matrix grows quadratically with the context size (n → n²), making it computationally expensive for long sequences.
- To address this, masking is used to prevent future words from influencing current words during training (e.g., in autoregressive tasks).
- Context size n
- The size of the Q·Kᵀ matrix at the end is the square of the context size, since we compare every query against every key, and it's a matrix! So it's n*n = n², which is very hard to scale
- It does help that we mask roughly half the entries, because we don't want future words to alter our current word and let it cheat
- Since for an entire sentence during training we try to predict the next word at every position, if there are 5 words there are 5 training examples (one per prefix) and not just 1
- We don't want words 4 and 5 to interfere with training examples 1, 2, 3
Encoder-Decoder Attention
Encoder-Decoder Attention is a mechanism used in Seq2Seq tasks (e.g., translation, summarization) to transform an input sequence into an output sequence. It combines Self Attention within each of the encoder and decoder blocks, and then cross-attention between the encoder and decoder
How Encoder-Decoder Attention Works
- Encoder:
- The Encoder Portion is completely described by what we wrote above in Summary of Self Attention Encoding
- TLDR;
- The encoder processes the input sequence and generates a sequence of hidden states that represent the context of the input
- Each encoder block consists of:
- Input Embedding:
- The first encoding layer typically uses positional encoding + static embeddings from Word2Vec or GloVe
- Self Attention Layer:
- Allows each token in the input sequence to attend to other tokens in the sequence
- This captures relationships between tokens in the input
- Feed Forward Layer:
- Applies a fully connected feed-forward network to each token independently
- Typically two linear transformations with a ReLU/GeLU in between:
FFN(x) = max(0, xW₁ + b₁)W₂ + b₂
- Residual Connection + LayerNorm
- Add the input and output of the layer together, and then normalize across the vector
- The output of each encoder block is passed to the next encoder block as input, and the final encoder block produces the contextual embeddings for the entire input sequence
- These are further transformed into K, V contextual output embeddings
- This confused me at first, but basically the output of an encoder block is same dimensionality as word embedding input, so it can flow through
- This is usually known as d_model
- This allows us to stack encoder blocks arbitrarily
- Architecture:
- Composed of multiple identical blocks (e.g., 6 blocks by default, but this is a hyperparameter)
- Each block contains:
- Self Attention Layer: Captures relationships within the input sequence
- Feed Forward Layer: Processes each token independently
- Decoder:
- The decoder generates the output sequence one token at a time, using both the encoder's output and its own previous outputs
- Input:
- The contextual embeddings output from the final Encoding Layer
- These K, V contextual output embeddings are passed to each Decoder block
- The previous Decoder block(s) output (previously generated word)
- Each decoder block consists of:
- Masked Self Attention Layer:
- Allows each token in the output sequence to attend to previously generated tokens in the sequence (auto-regressive behavior)
- Future tokens are masked to prevent the model from "cheating" by looking ahead
- So self-attention only happens from words on the left, not all Keys
- Query: Current token's embedding
- Key and Values: All already generated words to the left
- Similar to self-attention except we ignore all to the right
- Encoder-Decoder Attention Layer:
- Attends to the encoder's output (contextual embeddings) to incorporate information from the input sequence
- Query: Comes from the decoder's self-attention output
- i.e. it's the decoder's current representation of a token after masked self-attention
- Key and Values: Encoder's output for each input token
- Feed Forward Layer:
- Applies a fully connected feed-forward network to each token independently
- Architecture:
- Composed of multiple identical blocks (e.g., 6 blocks by default).
- Each block contains:
- Self Attention Layer: Captures relationships within the output sequence
- Encoder-Decoder Attention Layer: Incorporates information from the encoder's output
- Feed Forward Layer: Processes each token independently
- Example:
- Input sentence has 5 words in total
- Remember, the encoder put out 5 total vectors, one for each input word
- Let's walk through the third word in the decoder output, meaning the first two have already been generated
- Decoder Self Attention
- Input: The embeddings for the first, second, and third generated tokens so far
- The query is the decoder's embedding for the 3rd target word
- K, V are the decoder's embeddings for the 1st, 2nd, and 3rd target words generated so far (a token attends to itself and everything to its left)
- Masking: The self-attention is masked so the third position can only "see" the first, second, and third tokens (not future tokens)
- What happens: The third token attends to itself and all previous tokens (but not future ones), using their embeddings as keys and values
- Encoder Decoder Cross Attention
- Input: The output of the decoder’s self-attention for the third token (now a context-aware vector), and the encoder’s output for all input tokens
- Q, K, V:
- The query is the third word's attended to vector (after self-attention and residual/LayerNorm in decoder)
- The keys and values are the encoder’s output vectors for each input token (these are fixed for the whole output sequence)
- What happens: The third token’s representation attends to all positions in the input sequence, using the encoder’s outputs as keys and values
- Final Decoder Output:
- The final decoder layer produces a vector of floats for each token, which is passed through:
- A linear layer to expand the vector to the vocabulary size
- A softmax layer to produce a probability distribution over the vocabulary for the next token
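A tiny sketch of that final projection step; the vocabulary size and dimensions are illustrative, and real decoders typically sample from the distribution rather than always taking the argmax:

```python
import torch
import torch.nn as nn

d_model, vocab_size = 512, 30000                # illustrative sizes

to_vocab = nn.Linear(d_model, vocab_size)       # expand the decoder vector to vocabulary size

decoder_out = torch.randn(1, d_model)           # final decoder vector for the current position
logits = to_vocab(decoder_out)                  # (1, vocab_size)
probs = torch.softmax(logits, dim=-1)           # probability distribution over the next token
next_token = probs.argmax(dim=-1)               # e.g. greedy decoding
```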
Visual Representation
- Encoder Block:
- Self Attention → Feed Forward → Output to next encoder block.
- Decoder Block:
- Self Attention → Encoder-Decoder Attention → Feed Forward → Output to next decoder block.
- Final Decoder Output:
- The final decoder output is passed through a linear layer and softmax to produce the next token.
Summary of Encoder-Decoder Attention
- Encoder:
- Processes the input sequence and generates contextual embeddings using self-attention.
- Decoder:
- Generates the output sequence token by token using:
- Self Attention: Captures relationships within the output sequence.
- Encoder-Decoder Attention: Incorporates information from the input sequence.
- Auto-Regressive Decoding: Tokens are predicted auto-regressively, meaning words can only condition on leftward context while generating
- Final Output:
- The decoder's output is passed through a linear layer and softmax to produce the next token.
- Training:
- The model is trained using a cross-entropy-style loss (equivalently, KL divergence against the target distribution), with each token in the output sequence contributing to the loss.
Vision Transformers (ViT)
In using transformers for vision, the overall architecture is largely the same - flattening the image structure out into a sequence, using augmentation for new examples, and then (optionally) doing self-supervised "fill in the blank" style training
All changes are relatively minor:
- Input:
- Text: Input is a sequence of tokens
- Vision: Input is an image split into fixed size patches (e.g. 16x16)
- Each patch gets flattened and linearly projected to form a "patch embedding", similar to static word embeddings
- A [CLS] token is typically prepended and used for classification tasks
- Positional Encoding:
- Text: Added to token embeddings to encode word order
- Vision: Added to patch embeddings to encode spatial information of each patch in the image
- Objective:
- Text: Predict the next word (causal), fill in the blank, or generate a sequence (translation / summarization)
- Vision: Usually image classification, or can also be segmentation, detection, or masked patch prediction (fill in the blank)
- Architecture: Basically the same without any major overhauls
- Self Supervision:
- Text: Fill in the blank, next sentence prediction
- Vision: Fill in the blank (patch), or pixel reconstruction which aims to recreate the original image from corrupted or downsampled versions
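A minimal sketch of the patch-embedding step described above, using a strided convolution as the flatten-and-project operation (a common trick; the image size, patch size, and embedding dimension here are illustrative assumptions):

```python
import torch
import torch.nn as nn

patch_size, d_model = 16, 768                   # 16x16 patches, illustrative embed dim

# A strided convolution splits the image into non-overlapping patches and
# linearly projects each one in a single operation
patchify = nn.Conv2d(3, d_model, kernel_size=patch_size, stride=patch_size)

image = torch.randn(1, 3, 224, 224)             # one RGB image
patches = patchify(image)                       # (1, 768, 14, 14) -> 14*14 = 196 patches
patch_embeddings = patches.flatten(2).transpose(1, 2)      # (1, 196, 768)

# Prepend a learnable [CLS] token; positional encodings would be added next (not shown)
cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
tokens = torch.cat([cls_token, patch_embeddings], dim=1)   # (1, 197, 768)
```

From here on, the token sequence flows through the same encoder blocks described in the Transformer sections above.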