Friday, September 08, 2023

Key, Query, Value Matrices in Masked Self-Attention of Decoder-Only Transformers

   StatQuest uploaded a good video explaining how a Decoder-Only Transformer works. Most of the content covers how the Key, Query, and Value matrices are used. It is quite complex, so here I am going to explain it in a more intuitive way (based on my own understanding).

A sentence is first split into tokens, and each token has a word embedding and a position i in the sentence. The word embedding plus the positional encoding for position i gives a vector for the word at position i. Nothing special so far.
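Here is a minimal Python sketch of building word_vector_i. The vocabulary and embedding numbers are made up (in a real model they are learned), and the sine/cosine positional encoding is an assumption: it is one common choice, not the only one. Positions are 0-based in the code.

```python
import math

# Hypothetical toy embedding table (d_model = 4); the values are made up.
EMBEDDINGS = {
    "what": [0.1, 0.2, 0.3, 0.4],
    "is": [0.5, 0.1, 0.0, 0.2],
    "statquest": [0.3, 0.9, 0.1, 0.7],
}

def positional_encoding(pos, d_model):
    """Sinusoidal positional encoding (assumed here as one common choice)."""
    pe = []
    for k in range(d_model):
        # Each pair of dimensions shares a frequency: even k uses sin, odd k uses cos.
        angle = pos / (10000 ** (2 * (k // 2) / d_model))
        pe.append(math.sin(angle) if k % 2 == 0 else math.cos(angle))
    return pe

def word_vector(token, pos):
    """word_vector_i = embedding(token) + positional_encoding(i)."""
    emb = EMBEDDINGS[token]
    pe = positional_encoding(pos, len(emb))
    return [e + p for e, p in zip(emb, pe)]

tokens = ["what", "is", "statquest"]
word_vectors = [word_vector(t, i) for i, t in enumerate(tokens)]
```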

Now, at each position i, we use this vector (call it word_vector_i) to build another vector that represents the context of the sentence so far. This new vector should be based on the word at position i and all the previous words at positions 1 to i-1. To combine these vectors, we take a weighted sum. That is the overall idea:

    vector_with_context (i) =  w1 * value_vector_1 + w2 * value_vector_2 + ... + wi * value_vector_i
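The weighted sum itself is simple. A small sketch in plain Python, with made-up weights and value vectors:

```python
def weighted_sum(weights, value_vectors):
    """vector_with_context(i) = w1*value_vector_1 + ... + wi*value_vector_i."""
    assert len(weights) == len(value_vectors)
    dim = len(value_vectors[0])
    result = [0.0] * dim
    for w, v in zip(weights, value_vectors):
        for k in range(dim):
            result[k] += w * v[k]
    return result

# Three positions with equal weights: the result is the average of the value vectors.
ctx = weighted_sum([1 / 3, 1 / 3, 1 / 3], [[3.0, 0.0], [0.0, 3.0], [3.0, 3.0]])
```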

But wait, it is not ideal to use the embedding + positional encoding (word_vector_i) directly as value_vector_i. Instead, we transform it with a matrix Mv, whose values are learned during training. So,

    value_vector_i = Mv * word_vector_i
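A sketch of value_vector_i = Mv * word_vector_i, treating Mv as a list of rows and word_vector_i as a column vector. The matrix entries are made up; in a real model they are learned.

```python
def matvec(M, x):
    """Multiply matrix M (a list of rows) by the column vector x."""
    assert all(len(row) == len(x) for row in M)
    return [sum(m * xk for m, xk in zip(row, x)) for row in M]

# Hypothetical 2x2 matrix Mv and two word vectors (values made up).
Mv = [[1.0, 0.0],
      [0.5, 2.0]]
word_vectors = [[1.0, 2.0], [0.0, 1.0]]
value_vectors = [matvec(Mv, wv) for wv in word_vectors]
# value_vectors == [[1.0, 4.5], [0.0, 2.0]]
```

Many libraries store vectors as rows and compute `x @ W` instead; the idea is the same, only the layout of the matrix differs.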

Weight w1 says how strongly the 1st word is related to the ith word, weight w2 says how strongly the 2nd word is related to the ith word, and so on. To measure how similar two words are, we take the dot product of their vectors.

But wait, again it is not ideal to use the embedding + positional encoding (word_vector_i) directly, so we transform the word vectors with matrices too. Actually, two matrices: one (Mq) transforms word_vector_i into a query vector, and the other (Mk) transforms the word vectors at positions 1 through i into key vectors. Position i itself needs a key as well, because the weighted sum above also includes wi.

    query_vector_i = Mq * word_vector_i

    key_vector_1 = Mk * word_vector_1

    key_vector_2 = Mk * word_vector_2

    ...

    key_vector_i = Mk * word_vector_i
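The query and the keys can be sketched the same way. Mq, Mk and the word vectors below are made-up toy values; the current position i is 1-based to match the formulas, and the keys cover positions 1 through i (position i included, since the weighted sum has a wi term).

```python
def matvec(M, x):
    """Multiply matrix M (a list of rows) by the column vector x."""
    return [sum(m * xk for m, xk in zip(row, x)) for row in M]

# Hypothetical learned matrices (made-up values), both mapping size 2 -> size 2.
Mq = [[1.0, 0.0], [0.0, 1.0]]
Mk = [[0.0, 1.0], [1.0, 0.0]]

word_vectors = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # positions 1..3
i = 3  # current position (1-based)

query_vector_i = matvec(Mq, word_vectors[i - 1])
# One key per position 1..i; words after position i are never looked at.
key_vectors = [matvec(Mk, word_vectors[j - 1]) for j in range(1, i + 1)]
```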

Mq and Mk are also learned. The weight for each position j (from 1 to i) is the dot product of the query and that position's key:

    wj = query_vector_i · key_vector_j
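Computing the raw weights is just one dot product per key. Continuing with toy query and key values (made-up numbers):

```python
def dot(a, b):
    """Dot product of two equal-length vectors."""
    assert len(a) == len(b)
    return sum(x * y for x, y in zip(a, b))

query_vector_i = [5.0, 6.0]
key_vectors = [[2.0, 1.0], [4.0, 3.0], [6.0, 5.0]]  # keys for positions 1..i
raw_weights = [dot(query_vector_i, k) for k in key_vectors]
# raw_weights == [16.0, 38.0, 60.0]
```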

But wait, these raw weights can be any real number and do not sum to 1. So we run a softmax over all of them to get scaled weights that are positive and sum to 1. Applying these weights to the value vectors gives vector_with_context(i). This is the output of masked self-attention: "masked" because position i only looks at positions 1 to i and never at the words after it.
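Putting the pieces together, here is a sketch of masked self-attention for every position, assuming a single attention head and column-vector matrices. The `scale` flag is an addition not described above: standard Transformers divide the dot products by the square root of the key size before the softmax, and this sketch does that only when asked.

```python
import math

def matvec(M, x):
    return [sum(m * xk for m, xk in zip(row, x)) for row in M]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def softmax(scores):
    # Subtracting the max avoids overflow and does not change the result.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def masked_self_attention(word_vectors, Mq, Mk, Mv, scale=False):
    """Return vector_with_context(i) for every position i (0-based here).

    Position i only uses positions 0..i: that restriction is the "mask".
    """
    queries = [matvec(Mq, x) for x in word_vectors]
    keys = [matvec(Mk, x) for x in word_vectors]
    values = [matvec(Mv, x) for x in word_vectors]
    d_k = len(keys[0])
    outputs = []
    for i in range(len(word_vectors)):
        scores = [dot(queries[i], keys[j]) for j in range(i + 1)]  # j <= i only
        if scale:
            scores = [s / math.sqrt(d_k) for s in scores]
        weights = softmax(scores)
        out = [0.0] * len(values[0])
        for w, v in zip(weights, values[: i + 1]):
            out = [o + w * vk for o, vk in zip(out, v)]
        outputs.append(out)
    return outputs

identity = [[1.0, 0.0], [0.0, 1.0]]  # toy matrices, not learned values
ctx = masked_self_attention([[1.0, 0.0], [0.0, 1.0]], identity, identity, identity)
```

The first position can only attend to itself, so its output is just its own value vector, no matter what comes later in the sentence.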

To predict the next word (at position i+1), feed vector_with_context(i) into a fully connected layer that outputs a vector with one score for each word in the vocabulary.
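A sketch of the fully connected layer: one row of weights (plus a bias) per vocabulary word. The vocabulary, weights, and bias are hypothetical toy values.

```python
def fully_connected(W, b, x):
    """scores = W x + b, with one row of W (and one bias) per vocabulary word."""
    return [sum(w * xk for w, xk in zip(row, x)) + bias for row, bias in zip(W, b)]

vocab = ["what", "is", "statquest", "<EOS>"]  # hypothetical vocabulary
# Made-up weights: 4 vocabulary words x 2 input dimensions.
W = [[1.0, 0.0],
     [0.0, 1.0],
     [1.0, 1.0],
     [-1.0, 0.0]]
b = [0.0, 0.0, 0.0, 0.5]
vector_with_context = [0.2, 0.7]
scores = fully_connected(W, b, vector_with_context)
best = vocab[scores.index(max(scores))]  # word with the highest score
```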

But wait, using only the masked self-attention output (vector_with_context(i)) isn't ideal, so we add it to the embedding + positional encoding (word_vector_i, as described above) before the fully connected layer. So the prediction of the next word really depends on 3 things: the word embedding, the positional encoding, and the masked self-attention output. Because this adds an earlier vector to a later one, skipping over the attention step, it is called a residual connection.
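The residual connection is an element-wise sum, which is also why the sizes have to match. A minimal sketch with made-up vectors:

```python
def residual(attention_output, word_vector):
    """Element-wise sum; only defined when both vectors have the same length."""
    if len(attention_output) != len(word_vector):
        raise ValueError("residual connection needs matching dimensions")
    return [a + x for a, x in zip(attention_output, word_vector)]

word_vector_i = [1.0, 2.0]           # embedding + positional encoding (made up)
vector_with_context_i = [0.5, -1.0]  # masked self-attention output (made up)
combined = residual(vector_with_context_i, word_vector_i)
# combined == [1.5, 1.0]; this is what the fully connected layer receives.
```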

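(Note: since the residual connection adds the masked self-attention output to word_vector_i, their dimensions must match. So Mv must produce vectors the same size as word_vector_i, and Mq and Mk must produce vectors of the same size as each other so their dot product is defined. In this simple single-head setup all three typically use the same size, so the matrix shapes are largely predetermined by the embedding size.)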

Of course, a softmax is then applied to the result vector to turn the scores into a probability for each word.
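A sketch of the final softmax over the vocabulary scores. The vocabulary and scores are made up; picking the most probable word is greedy decoding, one choice among several (sampling from the probabilities is another).

```python
import math

def softmax(scores):
    m = max(scores)  # subtracting the max avoids overflow and does not change the result
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

vocab = ["what", "is", "statquest", "<EOS>"]  # hypothetical vocabulary
scores = [0.2, 0.7, 0.9, 0.3]                  # made-up fully connected layer output
probs = softmax(scores)
next_word = vocab[probs.index(max(probs))]     # greedy choice: the most probable word
```

Softmax keeps the order of the scores, so the greedy choice is the same word as the highest raw score; the probabilities matter for training and for sampling.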

  - What if it predicted the next word in my prompt wrong?

      Run your optimizer to train Mk, Mq, Mv and the fully connected layer so that it predicts the right word.

  - What if I want to generate a reply?

      Run the same process (without training) at every position in your prompt. At the end of the prompt (at the end-of-sentence token), let the transformer predict the next word. Your transformer is now generating a reply! Keep appending each predicted word to the end of the sentence and predicting again, until it outputs the end-of-sentence token.
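The generation loop can be sketched independently of the model. `predict_next` here is a hypothetical stand-in for the whole decoder (embedding, masked self-attention, residual, fully connected layer, softmax), and `max_new_tokens` is an added safety cap not mentioned above.

```python
from typing import Callable, List

def generate(prompt: List[str], predict_next: Callable[[List[str]], str],
             eos: str = "<EOS>", max_new_tokens: int = 20) -> List[str]:
    """Greedy generation: append each predicted word until <EOS> or the length cap."""
    tokens = list(prompt)
    if not tokens or tokens[-1] != eos:
        tokens.append(eos)  # the prompt ends with the end-of-sentence token
    reply = []
    for _ in range(max_new_tokens):
        word = predict_next(tokens)  # the decoder sees every token so far
        if word == eos:
            break
        reply.append(word)
        tokens.append(word)
    return reply

# Hypothetical stand-in for a trained decoder: a fixed lookup on the last token.
TOY_NEXT = {"<EOS>": "awesome", "awesome": "<EOS>"}
reply = generate(["what", "is", "statquest"], lambda toks: TOY_NEXT[toks[-1]])
# reply == ["awesome"]
```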


