RoFormer paper explained and implemented in JAX

Astarag Mohapatra
5 min read · Nov 13, 2023

In this article, we will go through the RoFormer paper, which introduced rotary positional embeddings (RoPE) for the transformer architecture. We will also implement them using the JAX deep learning framework.

Before jumping into RoPE (rotary positional embedding), let’s first discuss the positional encoding introduced in the original transformer paper.

If you want to understand transformers better, I would suggest these videos (I), (II), and (III). I will assume that you know the basics of transformers and how multi-head attention works.

THE NEED FOR POSITIONAL ENCODING

  • The self-attention formulation is as follows: each token embedding $x_i$ is projected into a query, a key, and a value vector, $q_m = W_q x_m$, $k_n = W_k x_n$, $v_n = W_v x_n$.
  • Here the attention output for the query vector of the token at position $m$ is given by (following the notation of the RoPE paper)

    $a_{m,n} = \frac{\exp\left(q_m^\top k_n / \sqrt{d}\right)}{\sum_{j=1}^{N} \exp\left(q_m^\top k_j / \sqrt{d}\right)}, \qquad o_m = \sum_{n=1}^{N} a_{m,n} v_n$

  • We take the dot product of the query vector with the key vectors of all the tokens (preceding tokens for the decoder, all tokens for the encoder), apply softmax, and then take the attention-weighted sum (a matrix multiplication) with the value vectors. A minimal JAX sketch of this computation follows the list.
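
To make the formulation above concrete, here is a minimal single-head self-attention sketch in JAX. The projection matrices, the toy dimensions, and the `causal` flag are illustrative assumptions for this article, not the exact setup from the RoFormer paper.

```python
import jax
import jax.numpy as jnp

def self_attention(x, W_q, W_k, W_v, causal=False):
    """Single-head scaled dot-product self-attention over a sequence x of shape (N, d_model)."""
    q = x @ W_q                               # queries, shape (N, d)
    k = x @ W_k                               # keys,    shape (N, d)
    v = x @ W_v                               # values,  shape (N, d)
    d = q.shape[-1]
    scores = q @ k.T / jnp.sqrt(d)            # (N, N): q_m . k_n / sqrt(d) for every pair (m, n)
    if causal:
        # decoder-style masking: position m may only attend to tokens n <= m
        mask = jnp.tril(jnp.ones_like(scores, dtype=bool))
        scores = jnp.where(mask, scores, -jnp.inf)
    attn = jax.nn.softmax(scores, axis=-1)    # a_{m,n}: softmax over the key positions
    return attn @ v                           # o_m = sum_n a_{m,n} v_n

# Toy usage with made-up dimensions
key = jax.random.PRNGKey(0)
k_x, k_q, k_k, k_v = jax.random.split(key, 4)
N, d_model, d = 6, 16, 8
x   = jax.random.normal(k_x, (N, d_model))
W_q = jax.random.normal(k_q, (d_model, d))
W_k = jax.random.normal(k_k, (d_model, d))
W_v = jax.random.normal(k_v, (d_model, d))
out = self_attention(x, W_q, W_k, W_v, causal=True)
print(out.shape)  # (6, 8)
```

Note that this computation only ever sees the token embeddings, so it is permutation-equivariant on its own; rotary embeddings will change how $q_m$ and $k_n$ are constructed before the dot product, which is what the rest of the article covers.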
