RoFormer paper explained and implemented in JAX
In this article, we will go through the RoFormer paper, which introduced rotary position embeddings (RoPE) for the transformer architecture. We will also implement them using the JAX deep learning framework.
Before jumping into RoPE (rotary position embedding), let's first discuss the positional encoding scheme introduced in the original transformer paper.
If you want to understand transformers better, I would suggest these videos: (I), (II), and (III). I will assume that you know the basics of transformers and how multi-head attention works.
THE NEED FOR POSITIONAL ENCODING
- The self-attention formulation is as follows: for token embeddings $x_1, \dots, x_N$, the queries, keys, and values are the linear projections $q_m = W_q x_m$, $k_n = W_k x_n$, and $v_n = W_v x_n$
- Here, the attention weight from the query of the token at position $m$ to the token at position $n$ is given by $a_{m,n} = \frac{\exp\left(q_m^\top k_n / \sqrt{d}\right)}{\sum_{j=1}^{N} \exp\left(q_m^\top k_j / \sqrt{d}\right)}$
- We take the dot product of the projected query vector with the key vectors of all the tokens (preceding tokens for the decoder, all tokens for the encoder), take a softmax over the resulting scores, and then compute a weighted sum of the value vectors to obtain the output $o_m = \sum_{n} a_{m,n} v_n$
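The steps above can be sketched in JAX as a minimal single-head attention function (the function name and toy shapes are illustrative, not from the article). It also demonstrates why positional information is needed: without it, permuting the input tokens simply permutes the output rows, so attention by itself cannot tell token order apart.

```python
import jax
import jax.numpy as jnp


def attention(q, k, v):
    # q, k, v: (seq_len, d) arrays of query/key/value vectors
    scores = q @ k.T / jnp.sqrt(q.shape[-1])   # (seq_len, seq_len) dot products
    weights = jax.nn.softmax(scores, axis=-1)  # softmax over keys for each query
    return weights @ v                          # weighted sum of value vectors


key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (4, 8))  # 4 toy tokens, dimension 8

out = attention(x, x, x)

# Shuffle the tokens: the outputs are the same rows, just reordered,
# i.e. plain attention is permutation-equivariant and position-blind.
perm = jnp.array([2, 0, 3, 1])
out_perm = attention(x[perm], x[perm], x[perm])
assert jnp.allclose(out[perm], out_perm, atol=1e-5)
```

In a real transformer layer, `q`, `k`, and `v` would come from separate learned projections of the token embeddings; sharing `x` here just keeps the sketch short.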