Understanding the Matrix Operations In Transformer Architecture

I have wanted to write this article for a long time. There are many excellent articles that explain the transformer architecture in detail, such as The Annotated Transformer, which covers every building block of the transformer. I would like to extend that work by explaining how the matrix operations happen during training as well as during inference.

The transformer architecture trains faster than an LSTM because it processes all positions of a sequence with parallel, batched matrix multiplications behind the scenes, instead of stepping through the tokens one at a time. This process is hard to follow because it happens in a multi-dimensional space, so in the explanation below I simplify it by splitting the operations into two-dimensional matrix multiplications (repeated across the batch), which are easier to understand.
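The "batched" multiplication mentioned above can be checked in a few lines. This is a minimal NumPy sketch (NumPy is an assumption; the article itself shows no code) using random matrices in place of real activations. It shows that one batched call over a stack of matrices gives the same result as looping over the batch and doing ordinary 2-D products one at a time:

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 10, 64))   # a batch of 4 matrices, each [10 x 64]
B = rng.standard_normal((4, 64, 10))   # a batch of 4 matrices, each [64 x 10]

# One batched call: [4 x 10 x 64] @ [4 x 64 x 10] -> [4 x 10 x 10]
batched = A @ B

# The same result, one batch element (one 2-D matrix product) at a time
looped = np.stack([A[i] @ B[i] for i in range(A.shape[0])])

print(batched.shape)                 # (4, 10, 10)
print(np.allclose(batched, looped))  # True
```

On a GPU, the single batched call is what allows all examples (and, inside attention, all positions) to be processed together.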

Transformer Architecture

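The major components of the transformer are the word embeddings combined with positional embeddings, a stack of encoder blocks, a stack of decoder blocks, and a final linear layer with a softmax that predicts the next word. Each encoder block contains multi-head self-attention followed by a feed-forward network; each decoder block additionally attends to the encoder output.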

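Training of the transformer happens in parallel: all tokens of a sentence are passed in one shot, unlike an LSTM, which has to step through them one at a time. Inference in the transformer happens in an autoregressive way: given the start-of-sequence (SOS) token, the model predicts the first word, then feeds it back to predict the next word, and so on.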
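The contrast between parallel training and autoregressive inference can be sketched as follows. Two assumptions are made here: during training, the decoder uses a causal (look-ahead) mask so that every target position is computed in the same pass while only seeing earlier positions; and inference uses greedy decoding (beam search is another common choice). `causal_mask`, `greedy_decode` and `step_fn` are hypothetical names; `step_fn` stands in for a full decoder forward pass that returns the id of the most likely next token.

```python
import numpy as np

def causal_mask(seq_len: int) -> np.ndarray:
    """[seq_len x seq_len] boolean mask: position i may attend to positions j <= i."""
    return np.tril(np.ones((seq_len, seq_len), dtype=bool))

def greedy_decode(step_fn, sos_id: int, eos_id: int, max_len: int) -> list:
    """Autoregressive inference: one decoder call per generated token.

    step_fn is a hypothetical stand-in for the decoder: it receives the tokens
    generated so far and returns the id of the most likely next token.
    """
    tokens = [sos_id]
    for _ in range(max_len):
        next_id = step_fn(tokens)
        tokens.append(next_id)
        if next_id == eos_id:   # stop once the end-of-sequence token is produced
            break
    return tokens

# Training: a single mask covers all target positions at once (4 here for readability)
print(causal_mask(4).astype(int))

# Inference: a toy step_fn that predicts "number of tokens so far" as the next id
print(greedy_decode(lambda toks: len(toks), sos_id=0, eos_id=4, max_len=10))
# [0, 1, 2, 3, 4]
```

The training side needs one pass per batch, while the inference loop needs one pass per generated word, which is why inference cannot be parallelized across positions in the same way.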

To understand the mathematical operations in the transformer, let's take grammatical error correction as an example.
source sentence : I going to school today.
target sentence : I am going to school today.
Let's assume the base transformer architecture, which has 6 blocks, 8 attention heads and a hidden-state dimension of 512. Since all 6 blocks perform similar operations, the explanation below covers only one block.

Training Phase

Once we have parallel data (source and target sentences) like the example above, we can start training.
Let's take a batch size of 4, a sequence length of 10 tokens and a vocabulary size of 30k.

tokens = [4 x 10] -- batch size x sequence length (number of tokens)
one-hot encoded tokens = [4 x 10 x 30k]
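Here is a small NumPy sketch of these two shapes. The token ids are random, hypothetical stand-ins for the output of a real tokenizer; only the shapes follow the article.

```python
import numpy as np

batch_size, seq_len, vocab_size = 4, 10, 30_000
rng = np.random.default_rng(0)

# Hypothetical token ids standing in for the output of a real tokenizer
tokens = rng.integers(0, vocab_size, size=(batch_size, seq_len))    # [4 x 10]

# One-hot encoding: a single 1.0 at each token's id along the last axis
one_hot = np.zeros((batch_size, seq_len, vocab_size), dtype=np.float32)
np.put_along_axis(one_hot, tokens[..., None], 1.0, axis=-1)         # [4 x 10 x 30000]

print(tokens.shape, one_hot.shape)   # (4, 10) (4, 10, 30000)
```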

The word embedding layer is a matrix that projects the one-hot encoded tokens into a hidden dimension of size 512.
Wt = [30k x 512]
word embedding = (one-hot encoded tokens) * (Wt) = [4 x 10 x 30k] * [30k x 512] == [4 x 10 x 512]
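Multiplying by a one-hot vector simply selects one row of `Wt`, which is why deep learning frameworks implement the embedding layer as a table lookup rather than an actual matrix multiplication. The NumPy sketch below checks that both views agree; the token ids and `Wt` are random stand-ins for real data and learned weights.

```python
import numpy as np

batch_size, seq_len, vocab_size, d_model = 4, 10, 30_000, 512
rng = np.random.default_rng(0)

tokens = rng.integers(0, vocab_size, size=(batch_size, seq_len))      # [4 x 10]
# Random stand-in for the learned embedding matrix
Wt = rng.standard_normal((vocab_size, d_model), dtype=np.float32)     # [30000 x 512]

one_hot = np.zeros((batch_size, seq_len, vocab_size), dtype=np.float32)
np.put_along_axis(one_hot, tokens[..., None], 1.0, axis=-1)           # [4 x 10 x 30000]

# Matrix view: [4 x 10 x 30000] @ [30000 x 512] -> [4 x 10 x 512]
emb_matmul = one_hot @ Wt

# Lookup view: the one-hot product just selects row tokens[b, t] of Wt
emb_lookup = Wt[tokens]                                               # [4 x 10 x 512]

print(emb_matmul.shape, np.allclose(emb_matmul, emb_lookup))   # (4, 10, 512) True
```

The lookup avoids materializing the large, mostly-zero one-hot tensor, which is what makes it the practical choice.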

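The positional embedding is a vector of length 512 that uniquely represents a position (index) in the sequence of tokens. A better explanation is available here.
positional embedding = [10 x 512] (one vector per position, shared by every sentence in the batch, so it is broadcast to [4 x 10 x 512])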

input to encoder = word embedding + positional embedding -- [4 x 10 x 512] (element-wise matrix addition)
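The article does not say how the positional embedding is computed, so the sketch below assumes the fixed sinusoidal encoding from the original transformer paper (learned positional embeddings are another common option). `sinusoidal_positions` is a hypothetical helper name, and the word embeddings are random stand-ins. It also shows that the [10 x 512] table is broadcast over the batch during the addition.

```python
import numpy as np

def sinusoidal_positions(seq_len: int, d_model: int) -> np.ndarray:
    """Sinusoidal positional encoding (assumed variant): [seq_len x d_model]."""
    pos = np.arange(seq_len)[:, None]                 # [seq_len x 1]
    even = np.arange(0, d_model, 2)[None, :]          # even dimension indices, [1 x d_model/2]
    angles = pos / np.power(10000.0, even / d_model)  # [seq_len x d_model/2]
    pe = np.zeros((seq_len, d_model))
    pe[:, 0::2] = np.sin(angles)   # even dimensions use sine
    pe[:, 1::2] = np.cos(angles)   # odd dimensions use cosine
    return pe

batch_size, seq_len, d_model = 4, 10, 512
rng = np.random.default_rng(0)
word_embedding = rng.standard_normal((batch_size, seq_len, d_model))  # stand-in, [4 x 10 x 512]

pe = sinusoidal_positions(seq_len, d_model)   # [10 x 512]
encoder_input = word_embedding + pe           # broadcast over the batch -> [4 x 10 x 512]

print(pe.shape, encoder_input.shape)   # (10, 512) (4, 10, 512)
```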

Self Attention

There are 8 heads in the self-attention layer of the transformer. All the heads perform the same calculations, but each head has its own parameters (no parameter sharing).

Attention Head 1
Q1 = K1 = V1 = [4 x 10 x 512](input to encoder)
Here we learn three parameter matrices, Q1w, K1w and V1w, which project the input into a lower-dimensional space of size 64 (64 because 512 / 8 heads = 64; concatenating the vectors from the 8 attention heads brings the size back to 512).
Q1w, K1w, V1w = [512 x 64] each (replicate the matrices batch-size times) == [4 x 512 x 64]
Q1^ = Q1 * Q1w = [4 x 10 x 64]
K1^ = K1 * K1w = [4 x 10 x 64]
V1^ = V1 * V1w = [4 x 10 x 64]
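A NumPy sketch of the projections for head 1, with random stand-ins for the input and the learned weights. In NumPy (and in deep learning frameworks), `@` broadcasts a [512 x 64] weight over the batch automatically, so the "replicate batch-size times" step does not need to be done by hand; the last line checks that explicit replication gives the same answer.

```python
import numpy as np

batch_size, seq_len, d_model, n_heads = 4, 10, 512, 8
d_k = d_model // n_heads          # 512 / 8 = 64
rng = np.random.default_rng(0)

x = rng.standard_normal((batch_size, seq_len, d_model))   # input to encoder, [4 x 10 x 512]

# Head 1 weights (random stand-ins for learned parameters), each [512 x 64]
Q1w = rng.standard_normal((d_model, d_k))
K1w = rng.standard_normal((d_model, d_k))
V1w = rng.standard_normal((d_model, d_k))

# Self-attention: Q1 = K1 = V1 = x; the 2-D weight is broadcast over the batch
Q1_hat = x @ Q1w   # [4 x 10 x 64]
K1_hat = x @ K1w   # [4 x 10 x 64]
V1_hat = x @ V1w   # [4 x 10 x 64]

# Same result as explicitly replicating the weight batch_size times: [4 x 512 x 64]
Q1w_rep = np.broadcast_to(Q1w, (batch_size, d_model, d_k))
print(np.allclose(Q1_hat, x @ Q1w_rep))   # True
```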
matrix [ m x n ] * [n x k ] = [m x k]
Attention: Q1^ x K1^(transpose) = [4 x 10 x 64] * [4 x 64 x 10] == [4 x 10 x 10] --> scale by 1/sqrt(64) = 1/8 --> Softmax over the last dimension --> [4 x 10 x 10]
Assuming for simplicity that there are no padding tokens in the batch, we now have an attention matrix that shows the importance of each word with respect to every other word in the sentence. This is bidirectional attention, because each word can also attend to the words that come after it.
Weighted Context Vector = Attention * V1^ == [4 x 10 x 10] * [4 x 10 x 64] == [4 x 10 x 64]
matrix [ p x m x n] * [p x n x k] == [p x m x k]
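The attention computation for one head fits in a few lines of NumPy. This sketch follows the scaled dot-product attention of the original transformer paper, without any mask (matching the no-padding assumption above). `softmax` and `attention_head` are hypothetical helper names, and the inputs are random stand-ins for Q1^, K1^ and V1^.

```python
import numpy as np

def softmax(scores: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtracting the row max improves numerical stability without changing the result
    shifted = scores - scores.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def attention_head(Q_hat, K_hat, V_hat):
    """Scaled dot-product attention for one head, no masking."""
    d_k = Q_hat.shape[-1]                         # 64
    K_t = np.swapaxes(K_hat, -1, -2)              # [4 x 64 x 10]
    scores = Q_hat @ K_t / np.sqrt(d_k)           # [4 x 10 x 10]
    attention = softmax(scores, axis=-1)          # each row sums to 1
    context = attention @ V_hat                   # [4 x 10 x 10] @ [4 x 10 x 64] -> [4 x 10 x 64]
    return attention, context

rng = np.random.default_rng(0)
Q1_hat, K1_hat, V1_hat = (rng.standard_normal((4, 10, 64)) for _ in range(3))
attention, context = attention_head(Q1_hat, K1_hat, V1_hat)
print(attention.shape, context.shape)   # (4, 10, 10) (4, 10, 64)
```

Row i of each [10 x 10] attention matrix holds the weights that word i gives to all 10 words, so multiplying by V1^ produces a weighted average of the value vectors for every word at once.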

Output of Attention head 1 = Weighted Context Vector = [4 x 10 x 64]

Attention Head 2
Q2 = K2 = V2 = [4 x 10 x 512](input to encoder)
Here again we learn three parameter matrices, Q2w, K2w and V2w, separate from those of head 1.
Q2w,K2w,V2w = [512 x 64]
Q2^ = Q2 * Q2w = [4 x 10 x 64]
K2^ = K2 * K2w = [4 x 10 x 64]
V2^ = V2 * V2w = [4 x 10 x 64]
matrix [ m x n ] * [n x k ] = [m x k]
Attention: Q2^ x K2^(transpose) = [4 x 10 x 64] * [4 x 64 x 10] == [4 x 10 x 10] --> scale by 1/sqrt(64) = 1/8 --> Softmax over the last dimension --> [4 x 10 x 10]
Again assuming no padding tokens, this gives head 2 its own bidirectional attention matrix, which generally differs from head 1's because its weights are different.
Weighted Context Vector = Attention * V2^ == [4 x 10 x 10] * [4 x 10 x 64] == [4 x 10 x 64]
matrix [ p x m x n] * [p x n x k] == [p x m x k]

Output of Attention head 2 = Weighted Context Vector = [4 x 10 x 64]
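Because every head performs the same operations on the same input, the 8 heads do not have to be computed one after another. The NumPy sketch below (random, scaled stand-ins for the learned weights; the head-axis layout is an assumption, though it is a common one) computes the heads one by one and then all together with an extra head axis, and checks that both give the same result. This is one of the "parallel batched multiplications" mentioned at the start.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

batch_size, seq_len, d_model, n_heads = 4, 10, 512, 8
d_k = d_model // n_heads                      # 64
rng = np.random.default_rng(0)
x = rng.standard_normal((batch_size, seq_len, d_model))   # input to encoder

# Unshared weights: Wq[h] is head h's [512 x 64] matrix (scaled random stand-ins)
Wq, Wk, Wv = (rng.standard_normal((n_heads, d_model, d_k)) / np.sqrt(d_model)
              for _ in range(3))

# 1) Head by head
per_head = []
for h in range(n_heads):
    q, k, v = x @ Wq[h], x @ Wk[h], x @ Wv[h]                # [4 x 10 x 64] each
    attn = softmax(q @ k.swapaxes(-1, -2) / np.sqrt(d_k))    # [4 x 10 x 10]
    per_head.append(attn @ v)                                # [4 x 10 x 64]

# 2) All 8 heads in one batched computation with a head axis: [4 x 8 x 10 x 64]
q = np.einsum("bsd,hdk->bhsk", x, Wq)
k = np.einsum("bsd,hdk->bhsk", x, Wk)
v = np.einsum("bsd,hdk->bhsk", x, Wv)
attn = softmax(q @ k.swapaxes(-1, -2) / np.sqrt(d_k))        # [4 x 8 x 10 x 10]
all_heads = attn @ v                                         # [4 x 8 x 10 x 64]

print(all(np.allclose(all_heads[:, h], per_head[h]) for h in range(n_heads)))  # True
```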

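Similar calculations happen in the other attention heads as well.
Output of the multi-head self-attention layer = concatenation of the outputs from the 8 heads = [4 x 10 x 512] (the 8 matrices are concatenated on the last dimension), followed by a learned output projection Wo = [512 x 512], which keeps the shape [4 x 10 x 512]. This is not yet the output of block 1 of the encoder: the block then applies a residual connection and layer normalization, a position-wise feed-forward network (again with a residual connection and layer normalization), and outputs a tensor of the same shape [4 x 10 x 512], which becomes the input to block 2.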
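The concatenation step and the output projection can be sketched in NumPy as follows. The head outputs and `Wo` (the output projection called W^O in the original transformer paper) are random stand-ins; the sketch only tracks shapes and stops before the residual connection, layer normalization and feed-forward network.

```python
import numpy as np

batch_size, seq_len, d_model, n_heads = 4, 10, 512, 8
d_k = d_model // n_heads
rng = np.random.default_rng(0)

# Random stand-ins for the outputs of the 8 heads, each [4 x 10 x 64]
head_outputs = [rng.standard_normal((batch_size, seq_len, d_k)) for _ in range(n_heads)]

# Concatenate on the last dimension: 8 heads x 64 = 512
concat = np.concatenate(head_outputs, axis=-1)                    # [4 x 10 x 512]

# Output projection mixes information across heads (random stand-in for learned weights)
Wo = rng.standard_normal((d_model, d_model)) / np.sqrt(d_model)   # [512 x 512]
mha_output = concat @ Wo                                          # [4 x 10 x 512]

print(concat.shape, mha_output.shape)   # (4, 10, 512) (4, 10, 512)
```

Without `Wo`, each 64-wide slice of the concatenated vector would only carry information from its own head; the projection lets the following layers combine what the different heads attended to.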