MAMBA and State Space Models Explained
This article walks through a new class of deep learning models: structured state space models (S4) and Mamba. It covers:
- Transformers and RNNs, with their strengths and weaknesses
- S4 models and architecture details
- Mamba architecture
EDIT: (UPDATE March 5 2024)
A brilliant visual introduction to Mamba.
RNN
Before transformers were introduced, we did sequence modelling using recurrent neural networks (RNN) and long short-term memory (LSTM).
- In the figure above, the RNN architecture is unrolled over time. The output of the RNN cell at each step is fed back into the cell at the next step, and during backpropagation the network is unrolled over time to compute the gradients for the shared weights. During training, a sequence of N tokens needs O(N) sequential steps, because each step depends on the hidden state of the previous one. During inference, generating each new token costs O(1), because the hidden state behaves like a Markovian state: it aggregates everything the RNN has seen so far, so we only need the last hidden state and the current token to generate the next token.
- However, the RNN architecture suffers from vanishing and exploding gradient problems. The hidden state has a fixed size, so it may not properly encode all the information it has seen so far, and information from early tokens gets forgotten. The LSTM architecture is more complex and tries to ameliorate this by adding gates (such as the forget and input gates) that control what is kept and what is updated. Still, vanishing gradients remain an issue and LSTMs do not scale well to longer sequences. Training is also not parallelizable across time steps. During inference, however, each step takes constant time, so RNNs scale well at inference.
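The points above can be illustrated with a minimal sketch of a scalar Elman-style RNN. The function names and the weights (w_h, w_x, b) are hypothetical and chosen only for illustration; a real RNN uses weight matrices and vector states.

```python
import math

def rnn_step(h_prev, x_t, w_h, w_x, b):
    """One RNN step on scalars: h_t = tanh(w_h * h_prev + w_x * x_t + b)."""
    return math.tanh(w_h * h_prev + w_x * x_t + b)

def run_rnn(xs, w_h=0.5, w_x=1.0, b=0.0, h0=0.0):
    """Process a sequence token by token and return every hidden state."""
    h = h0
    states = []
    for x_t in xs:  # N sequential steps: step t needs the state from step t-1
        h = rnn_step(h, x_t, w_h, w_x, b)
        states.append(h)
    return states

def gradient_through_time(states, w_h=0.5):
    """d h_T / d h_0 for the scalar RNN: a product of w_h * (1 - h_t^2) factors."""
    grad = 1.0
    for h_t in states:
        grad *= w_h * (1.0 - h_t ** 2)
    return grad
```

Generating the next token only needs the last hidden state, so `rnn_step(states[-1], x_new, ...)` gives the same result as re-running the whole sequence. With |w_h| < 1, every factor in `gradient_through_time` is smaller than 1 in magnitude, so the gradient shrinks exponentially with sequence length, which is the vanishing gradient problem in its simplest form.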
TRANSFORMERS
- To ameliorate the issues faced in RNN-based sequence modelling, a new class of models called transformers was introduced. They scale QUADRATICALLY with sequence length, as each token attends to the other tokens. The original transformer has encoder and decoder layers (BERT-family models use only the encoder, GPT-family models only the decoder). In the encoder layers, every token attends to all the tokens, hence the computation is O(N²). In the decoder, attention is masked (the attention computation for the current token does not consider future tokens), but the computation is still O(N²).
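- Hence, transformers do not scale well to long sequences. At the time of writing, the longest context window among SOTA models is 200K tokens, offered by Claude from Anthropic. Also, during inference, we need to keep the KV-cache in memory along with the model. The cache grows linearly with the number of tokens, and since every new token attends to the whole cache, the total attention compute over a generated sequence grows quadratically.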
- A paper that was hard for me to digest was the attention sinks paper. The authors observed that the initial tokens receive a disproportionate amount of attention, as all the generated tokens attend to the first few tokens. Hence, if you use sliding window attention, which keeps the per-token cost constant, you still always need to keep the first few tokens. The paper also shows that if the first four tokens are replaced with newline characters (\n), they still receive high attention while generating tokens. Hence, it is a positional effect, not a semantic one. So, to improve inference speed, we can use a sliding window, but always include the initial four tokens in the attention computation.
- Another thing is that when you speak, you don’t go back and attend to every word you have said. You keep a Markovian state at each instant, a rough summary of what you have said so far, which helps you continue. Attention, in contrast, considers all the tokens generated so far.
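The cost argument and the sliding window with attention sinks can be sketched by counting which positions each new token attends to. This is only an illustration of the idea described above, not the paper's implementation: the function names are hypothetical, and details such as how positions are encoded inside the cache are ignored.

```python
def kept_positions(seq_len, window, num_sinks=4):
    """Positions a new token attends to: the first few 'sink' tokens plus a recent window."""
    sinks = list(range(min(num_sinks, seq_len)))
    recent_start = max(len(sinks), seq_len - window)
    return sinks + list(range(recent_start, seq_len))

def full_attention_cost(seq_len):
    """Query-key scores computed when every token attends to itself and all earlier tokens."""
    return sum(t + 1 for t in range(seq_len))  # N(N+1)/2, i.e. O(N^2)

def streaming_cost(seq_len, window, num_sinks=4):
    """Same count when each token only attends to the sinks plus the recent window."""
    return sum(len(kept_positions(t + 1, window, num_sinks)) for t in range(seq_len))
```

For example, `kept_positions(10, 3)` returns `[0, 1, 2, 3, 7, 8, 9]`: the four initial tokens and the three most recent ones. The full cost grows quadratically with the sequence length, while the streaming cost grows linearly because each token attends to at most `num_sinks + window` positions.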
WHAT WOULD A BETTER ARCHITECTURE LOOK LIKE?
- The RNN architecture is well suited for inference, as each step is a constant time operation that only needs the current hidden state. However, RNNs suffer from vanishing gradients and cannot be trained in parallel, which is what the transformer architecture addressed. HOWEVER, transformers scale quadratically with sequence length during training and inference. Hence, to challenge transformers we need a sub-quadratic approach that is nearly constant time per token during inference.
- The transformer architecture is parallelizable during training, but it scales quadratically. Hence, we want something that scales well during both training (parallelizable) and inference (constant time per token).
The structured state space (S4) model was introduced in this paper to address these requirements. The Annotated S4 helped me understand S4, and its accompanying video was also useful.
LET’S DIVE IN
RECURRENT VIEW FOR STATE SPACES
- The state space model maps a 1-D input (u(t)) to a 1-D output (y(t)) via an N-D latent state (x(t)): $x'(t) = A x(t) + B u(t)$ and $y(t) = C x(t) + D u(t)$. Hence, A ϵ (N x N), B ϵ (N x 1) and C ϵ (1 x N).
- Drawing a parallel to language modelling: each dimension of the token embedding gets its own state space model. This is loosely similar to a transformer, where we split the embeddings into a number of head dimensions and then concatenate them.
- The D-matrix operation is similar to a skip connection in neural networks.
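- However, the input data (for example, a sequence of tokens) is discrete rather than continuous, so the continuous-time model cannot be applied directly. The authors also found issues with vanishing gradients in the plain continuous setting. Hence, we will be moving into a recurrent view of the state space model.
- The recurrent view of the SSM first involves discretising the A, B and C matrices with a step size Δ. The discretization follows the bilinear method and results in $\bar{A} = (I - \frac{\Delta}{2} A)^{-1}(I + \frac{\Delta}{2} A)$, $\bar{B} = (I - \frac{\Delta}{2} A)^{-1} \Delta B$ and $\bar{C} = C$.
- Finally, the recurrent view is $x_k = \bar{A} x_{k-1} + \bar{B} u_k$ and $y_k = \bar{C} x_k$ (the D term is left out here, as it is just a skip connection).
- However, we know that recurrent networks cannot be trained efficiently, as they are not parallelizable across time steps, and we would encounter vanishing gradient problems. Hence, for efficient parallelization, we can move from the recurrent view to a convolutional view.
- The convolutional view is derived by unrolling the recurrence from $x_{-1} = 0$: $y_0 = \bar{C}\bar{B} u_0$, $y_1 = \bar{C}\bar{A}\bar{B} u_0 + \bar{C}\bar{B} u_1$, and in general $y_k = \bar{C}\bar{A}^{k}\bar{B} u_0 + \bar{C}\bar{A}^{k-1}\bar{B} u_1 + \dots + \bar{C}\bar{B} u_k$.
- An interesting observation is that it looks a lot like an attention matrix.
- Now, the convolutional kernel looks like the following: $\bar{K} = (\bar{C}\bar{B}, \bar{C}\bar{A}\bar{B}, \dots, \bar{C}\bar{A}^{L-1}\bar{B})$ for a sequence of length L, so that $y = \bar{K} * u$.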
- The convolutional view of the SSM can be parallelized during training, which addresses the shortcomings of the RNN view of SSM. This video and set of diagrams provide a good understanding of the convolutional view. The convolutional view starts from
- The last step of an input vector of dimension 4
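The equivalence between the recurrent and convolutional views can be checked numerically with a small sketch. To stay within the Python standard library it assumes a diagonal A, so each state dimension is a scalar recurrence; the function names are hypothetical, and the D term is left out.

```python
def discretize_bilinear(a, b, delta):
    """Bilinear discretization for a diagonal A, given as lists of per-state scalars."""
    a_bar = [(1 + delta * a_n / 2) / (1 - delta * a_n / 2) for a_n in a]
    b_bar = [delta * b_n / (1 - delta * a_n / 2) for a_n, b_n in zip(a, b)]
    return a_bar, b_bar

def ssm_recurrent(a_bar, b_bar, c, u):
    """Recurrent view: x_k = A_bar x_{k-1} + B_bar u_k, y_k = C x_k, one step per input."""
    x = [0.0] * len(a_bar)
    ys = []
    for u_k in u:
        x = [a_n * x_n + b_n * u_k for a_n, b_n, x_n in zip(a_bar, b_bar, x)]
        ys.append(sum(c_n * x_n for c_n, x_n in zip(c, x)))
    return ys

def ssm_kernel(a_bar, b_bar, c, length):
    """Convolution kernel K_k = C A_bar^k B_bar for k = 0 .. length-1."""
    return [
        sum(c_n * a_n ** k * b_n for a_n, b_n, c_n in zip(a_bar, b_bar, c))
        for k in range(length)
    ]

def causal_conv(kernel, u):
    """Convolutional view: y_k = sum_j K_j u_{k-j}; every y_k can be computed independently."""
    return [sum(kernel[j] * u[k - j] for j in range(k + 1)) for k in range(len(u))]
```

Calling `ssm_recurrent(a_bar, b_bar, c, u)` and `causal_conv(ssm_kernel(a_bar, b_bar, c, len(u)), u)` on the same inputs gives the same outputs up to floating point error. The recurrent loop has to run in order, while each output of the convolution depends only on the kernel and the inputs, which is why it parallelizes during training.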
BUT, THERE IS AN ISSUE WITH THE CONVOLUTIONAL VIEW (INFERENCE PART)
- In autoregressive models, tokens are generated one at a time during inference and future tokens are not available. With the convolutional view, every new token would require convolving the kernel with the entire sequence generated so far, instead of a cheap update of the hidden state.
- Also, in MAMBA the B and C matrices and the step size Δ are computed from the current token, so they change at every time step. The convolutional kernel assumes the same A, B and C at every step, hence a single kernel can’t be built in advance.
- The transformer architecture can parallelize during training by taking the embeddings of a sequence of tokens, adding positional embeddings and computing self-attention, so that the model is aware of the context across the whole sequence. Here, all the tokens interact with each other, and the query, key and value projection matrices update their weights to form a better context. The positional embeddings inform the model about the order of the tokens.
- With the convolutional view, the kernel has to be as long as the sequence, and computing it involves repeated powers of the N x N matrix A, which is expensive in compute and memory. As inference and the dynamically changing parameters make it harder to use the convolutional view, the authors kept the recurrent view, but addressed the computational inefficiency of its sequential nature.
PREFIX SUM AND PARALLEL SCANNING ALGORITHMS
- First, let’s clarify the naming in the MAMBA paper
State Space Models (SSMs) → Structured State Space Model (S4) → Selective scan State Space Models (S6 or the MAMBA paper).
- The structured part comes from the structured initialization of the matrices, such as a HiPPO matrix for the A matrix. A diagonal initialization for A also leads to performance improvements.
- The extra selective scan lets the MAMBA architecture scale and match the performance of transformers. Let’s discuss the scan part in Mamba, which is the parallel scan. The authors mention the following resource regarding the parallel scan algorithm
- The parallel scan operation is covered well in this set of videos from Udacity on GPU programming.
- The Blelloch scan works in the following manner (for exclusive cumulative sum)
- The Blelloch scan algorithm computes the cumulative sum of a sequence in two phases. The up-sweep (reduce) phase builds partial sums in a tree until it reaches the sum of the whole sequence. The down-sweep phase then walks back down the tree, reusing the partial sums from the previous levels to fill in the cumulative sum at every position.
- The recurrent view can now be efficiently parallelized: instead of a sum, the scan operator applies the SSM state update, which works because the linear recurrence is associative.
- Also, other GPU hardware-aware optimizations were implemented to address the memory bottlenecks. In the image below, the parameters Δ, A, B and C are loaded from the slower GPU memory (GPU High Bandwidth Memory or HBM) into the faster GPU memory (GPU SRAM), where the discretization and the scan are performed, and only the final outputs are written back to HBM. The intermediate hidden states are not stored; they are recomputed during the backward pass. B, C and Δ are themselves computed from the input embedding x_t (if we are talking about NLP).
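The two phases of the Blelloch scan, and the way a linear recurrence fits into it, can be sketched as follows. This is a sequential Python simulation rather than a GPU kernel: within each level the loop iterations are independent, which is what a GPU would run in parallel. The function names and the (a, b) pair encoding of the recurrence h_t = a_t * h_{t-1} + b_t are assumptions made for illustration, not the MAMBA implementation.

```python
def blelloch_exclusive_scan(values, op, identity):
    """Exclusive scan with an associative operator, using up-sweep and down-sweep phases."""
    n = len(values)
    size = 1
    while size < n:
        size *= 2
    tree = list(values) + [identity] * (size - n)  # pad to a power of two

    # Up-sweep (reduce): build partial results up to the total at the last slot.
    step = 1
    while step < size:
        for i in range(2 * step - 1, size, 2 * step):  # independent within a level
            tree[i] = op(tree[i - step], tree[i])
        step *= 2

    # Down-sweep: push prefixes back down the tree.
    tree[size - 1] = identity
    step = size // 2
    while step >= 1:
        for i in range(2 * step - 1, size, 2 * step):  # independent within a level
            left_total = tree[i - step]
            tree[i - step] = tree[i]            # left child inherits the parent's prefix
            tree[i] = op(tree[i], left_total)   # right child: parent's prefix, then left subtree
        step //= 2
    return tree[:n]

def combine(first, second):
    """Compose two updates h -> a*h + b, applying `first` and then `second`."""
    a1, b1 = first
    a2, b2 = second
    return (a1 * a2, a2 * b1 + b2)

def linear_recurrence_via_scan(a, b):
    """All states of h_t = a_t * h_{t-1} + b_t (starting from h = 0), computed with the scan."""
    elems = list(zip(a, b))
    prefixes = blelloch_exclusive_scan(elems, combine, (1.0, 0.0))
    return [combine(p, e)[1] for p, e in zip(prefixes, elems)]
```

With addition as the operator, `blelloch_exclusive_scan([3, 1, 7, 0, 4, 1, 6, 3], lambda x, y: x + y, 0)` returns `[0, 3, 4, 11, 11, 15, 16, 22]`. Swapping the operator for `combine` turns the same two sweeps into a way of evaluating the recurrence, and the operand order in the down-sweep matters because `combine` is associative but not commutative.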
NOW, COMING TO THE SELECTIVE PART IN S6
- The input embedding (x_t) for NLP problems is projected to the B_t, Δ_t (discretization step size) and C_t parameters. The learnable discretization parameter Δ_t converts the continuous A and B matrices to discrete time. MAMBA uses the zero-order hold rule, $\bar{A} = \exp(\Delta A)$ and $\bar{B} = (\Delta A)^{-1}(\exp(\Delta A) - I) \cdot \Delta B$.
- An important note here: the discretization is not the same as sampling a continuous-time function. Rather it is approximating the continuous-time function in a discrete setting.
- The discretization parameter acts like forget and update gates (similar to LSTMs). Smaller values of Δ keep Abar close to the identity, which sustains the information in the hidden state and gives the current token little importance. Larger values of Δ reset the hidden state and focus on the current token. Hence, the model dynamically selects (selective) what to forget and what to update.
- The transformer architecture does not have a hidden state; it attends to all the previous tokens to build context, hence its quadratic nature. The MAMBA architecture keeps the context information in the hidden state, hence inference is constant time per token.
- The scan algorithms can expedite the training procedures, hence we can parallelize the workloads in MAMBA.
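The selective mechanism can be sketched for a single input channel with a diagonal A. The weights w_b, w_c, w_delta and bias_delta are hypothetical stand-ins for the learned projections (in the real model they act on the full embedding vector), the softplus keeps Δ positive, and the discretization is the zero-order hold rule written for scalar entries of A. The hidden state is called h here to keep it apart from the input x_t.

```python
import math

def softplus(z):
    """Smooth positive function, used here to keep the step size delta > 0."""
    return math.log1p(math.exp(z))

def zoh_discretize(a, b_t, delta):
    """Zero-order hold for a diagonal A with non-zero entries (lists of scalars)."""
    a_bar = [math.exp(delta * a_n) for a_n in a]
    b_bar = [(math.exp(delta * a_n) - 1.0) / a_n * b_n for a_n, b_n in zip(a, b_t)]
    return a_bar, b_bar

def selective_ssm(xs, a, w_b, w_c, w_delta, bias_delta):
    """Recurrence where B_t, C_t and delta_t are recomputed from every input x_t."""
    h = [0.0] * len(a)
    ys = []
    for x_t in xs:
        delta_t = softplus(w_delta * x_t + bias_delta)  # input-dependent step size
        b_t = [w * x_t for w in w_b]                    # input-dependent B
        c_t = [w * x_t for w in w_c]                    # input-dependent C
        a_bar, b_bar = zoh_discretize(a, b_t, delta_t)
        h = [ab * h_n + bb * x_t for ab, bb, h_n in zip(a_bar, b_bar, h)]
        ys.append(sum(c_n * h_n for c_n, h_n in zip(c_t, h)))
    return ys
```

With a negative entry in A, `zoh_discretize([-1.0], [1.0], 1e-6)` gives an A_bar very close to 1 and a B_bar very close to 0, so the state persists and the input is almost ignored. With a step size of 50 the same call gives an A_bar close to 0, so the previous state is overwritten by the current input. Because the parameters change at every step, the fixed convolution kernel no longer applies, and the recurrence is what the scan parallelizes.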
SOME GOOD RESOURCES
- The Annotated S4: https://srush.github.io/annotated-s4/
- Umar Jamil Video
- Gabriel Video
- West Coast Machine Learning Group Video (recommended): As of today, there are four parts
The article got too long, so I will write a Part 2 on MAMBA, where I will dive deeper into the code. If you have any queries, I am one comment away. Thank you for reading till the end.