Sunday, July 23, 2023

Key, Query, Value Matrices in Self Attention

The "Attention Is All You Need" paper introduces an attention function built from three matrices Q, K, V without much explanation. Fortunately, this YouTube tutorial on attention explains it well (narrated in Chinese). Here is the note I took from the video.

In self attention there is only one input: a list of m tokens, each a word expressed as a vector embedding. Call this input X, where Xi is the embedding of the ith word. The task is to guess the word to output at the ith position by looking at all the words in the input together with the ith word.

Wk, Wq, Wv are the parameter matrices to be learned. Each of them multiplies the input to produce K, Q, and V respectively (for example, V = Wv * X, analogous to K in step 1 below).
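To make the shapes concrete, here is a minimal NumPy sketch of this setup. The dimensions m, d, d_k and the random weights are made up for illustration; the embeddings are stacked as columns of X so that the products match the formulas in the steps that follow.

    import numpy as np

    m, d, d_k = 4, 8, 8          # sequence length, embedding size, projection size (made-up values)
    X = np.random.randn(d, m)    # column i is Xi, the embedding of the ith word

    # Learned parameter matrices (random stand-ins for trained weights)
    Wk = np.random.randn(d_k, d)
    Wq = np.random.randn(d_k, d)
    Wv = np.random.randn(d_k, d)

    K = Wk @ X                   # keys for all m words, shape (d_k, m)
    Q = Wq @ X                   # queries; column i is qi = Wq @ Xi
    V = Wv @ X                   # values for all m words, shape (d_k, m)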

1. It needs to look at all words, which is done by multiplying the Wk matrix with X. Wk is called the key matrix because it produces the keys for all words: K = Wk * X

2. It needs to look at the ith word Xi, which is transformed by the Wq matrix, aka the query matrix: qi = Wq * Xi

3. Take K from step 1, multiply it with qi from step 2, and apply a softmax. Call this result Ai = Softmax(K.transpose * qi); these are the attention weights over all words for position i.

4. Take Ai at the ith location and multiply it with V. Call this result Ci = V * Ai; this is the context vector for the ith position. Since Ai comes out of the softmax in step 3, Ci is essentially a weighted sum of the columns of V, with Ai as the weights.

5. Feed Ci into a softmax classifier to get the output word. (Steps 1 through 5 are sketched in code after this list.)
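Putting steps 1 through 5 together, continuing the sketch above (the final softmax classifier and its toy vocabulary size are made-up stand-ins, not something specified in the video):

    def softmax(z):
        e = np.exp(z - z.max())
        return e / e.sum()

    i = 2                                     # the position being predicted (arbitrary choice)
    qi = Wq @ X[:, i]                         # step 2: query for the ith word

    Ai = softmax(K.T @ qi)                    # step 3: attention weights over all m words
    Ci = V @ Ai                               # step 4: context vector = weighted sum of the columns of V

    vocab_size = 10                           # toy vocabulary (made up)
    W_out = np.random.randn(vocab_size, d_k)  # stand-in classifier weights
    probs = softmax(W_out @ Ci)               # step 5: softmax classifier over the vocabulary
    predicted_word = probs.argmax()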


Also, self attention is a special case of attention. In self attention, qi is computed from the ith word of the input, Xi. In general attention, qi is computed from the previous output at the ith position (for the very first position, the <start> token is treated as the previous output).
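As a small illustration of that distinction (y_prev is a made-up stand-in for the embedding of the previous output word):

    # Self attention: the query at position i comes from the input itself
    qi_self = Wq @ X[:, i]

    # General attention: the query comes from the previous output at position i,
    # e.g. the embedding of the previously generated word (<start> at the first position)
    y_prev = np.random.randn(d)   # stand-in for the previous output's embedding
    qi_cross = Wq @ y_prev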