LoRA (Low-Rank Adaptation) paper in-depth explanation
This article series explains two papers on fine-tuning neural networks and large language models. First, we will cover LoRA, and then QLoRA (Quantized LoRA) in the next article. We will go in-depth into both papers and explain all the moving parts. So, let’s start with LoRA.
YOU NEED TO KNOW THE BASICS OF THE TRANSFORMER. IF NOT, PLEASE REFER TO THESE VIDEOS, 1 AND 2, AND THEN COME BACK TO THE ARTICLE.
INTRODUCTION AND WHY WE NEED LoRA
The standard procedure for fine-tuning on your dataset is to add some adapter layers (small linear layers) to the base pre-trained model, freeze the weights of the pre-trained model, and update only the weights of those adapter layers. In language models based on the transformer architecture, we add these adapter layers to the transformer layers and tune them. There are some issues with this approach: training overhead and inference latency. We need to train these adapter layers, and we have to add adapters to every transformer layer. The GPT-3 model has 96 multi-head attention layers, so as we add more and more adapter layers, the training overhead increases. Also, during inference, these extra layers increase the inference latency. A minimal sketch of this adapter approach is shown below.
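To make the adapter picture concrete, here is a minimal PyTorch sketch (not from the paper) of freezing a pre-trained model and training only a small bottleneck adapter. The dimensions, the stand-in `base_model`, and the optimizer settings are illustrative assumptions.

```python
import torch
import torch.nn as nn

class Adapter(nn.Module):
    """A small bottleneck adapter inserted after a transformer sub-layer (illustrative sizes)."""
    def __init__(self, d_model: int = 768, bottleneck: int = 64):
        super().__init__()
        self.down = nn.Linear(d_model, bottleneck)   # project down to a small bottleneck
        self.up = nn.Linear(bottleneck, d_model)     # project back up to the model dimension
        self.act = nn.ReLU()

    def forward(self, x):
        # Residual connection: the adapter only learns a correction on top of the frozen output.
        return x + self.up(self.act(self.down(x)))

# `base_model` is a stand-in for the pre-trained transformer.
base_model = nn.Sequential(nn.Linear(768, 768), nn.Linear(768, 768))
for p in base_model.parameters():
    p.requires_grad = False              # freeze every pre-trained weight

adapter = Adapter(d_model=768)           # in practice, one adapter per transformer layer
optimizer = torch.optim.Adam(adapter.parameters(), lr=1e-4)  # only adapter weights are updated
```

Because the adapters sit inside every layer of the network, they must be evaluated on every forward pass, which is exactly where the extra training cost and inference latency come from.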
SOME MORE ISSUES WITH TRADITIONAL FINE-TUNING
- The adapter layers need to be processed sequentially. As discussed earlier, there are 96 layers in GPT-3, where each layer attends to certain linguistic properties. Each layer gets its own adapter layers, arranged hierarchically, so the forward and backward passes through them must run sequentially.
- We cannot shard the model easily. In sharding, we split the weights of the model across different devices and then accumulate the gradients. Sharding is essential because these models are memory-intensive and GPU memory is small, but the extra adapter layers and their sequential nature make it hard to shard the model and store it across different devices.
HYPOTHESIS AND LoRA EXPLANATION
- The hypothesis of the LoRA paper is that pre-trained, over-parameterized models reside in a low intrinsic dimension, and that the change in weights during fine-tuning also has a low intrinsic rank.
- Rather than optimizing all the parameters of the dense layers directly, we can represent the weight update in a lower dimension, in the spirit of SVD (Singular Value Decomposition), and then run gradient descent on that low-dimensional representation.
- Singular Value Decomposition factors any matrix M into two orthonormal matrices and a diagonal matrix: M = U ∑ Vᵀ
- Here the U and V matrices are orthonormal and act as rotations on M. The ∑ matrix is diagonal, with the singular values in decreasing order, σ1 ≥ σ2 ≥ σ3 ≥ … ≥ σn.
- To compress a matrix from rank n down to rank r, we keep the first r singular values, σ1 to σr, and zero out the remaining diagonal elements. This gives a lower-rank approximation of the matrix (see the sketch after this list).
- So take a dense (feed-forward) layer in the transformer, whose dimension is four times the transformer output dimension (as mentioned in the paper).
- Here, the model dimension is the transformer output at each layer, and d_ffn is the dimension of the feed-forward layer that follows the self-attention module. Using the SVD idea, we can compress the weight update from d_ffn down to a lower dimension and then run gradient descent. Since weight updates in this lower dimension can explain the high-dimensional weight matrices, we can fine-tune the model this way.
- In the self-attention module, the dense layers of interest are the query, key, and value projection matrices. In the paper, the authors apply this dimensionality reduction to the query and value matrices only (because of the computation budget), but in the future directions they mention exploring the other matrices as well.
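To make the truncation step concrete, here is a small numpy sketch of keeping only the top-r singular values of a matrix. The matrix size and the value of r are arbitrary choices, not numbers from the paper.

```python
import numpy as np

d = 512      # illustrative layer width, not from the paper
r = 8        # target rank

W = np.random.randn(d, d)

# Full SVD: W = U @ diag(S) @ Vt, with singular values sorted in decreasing order
U, S, Vt = np.linalg.svd(W, full_matrices=False)

# Keep only the top-r singular directions (zero out sigma_{r+1} ... sigma_d)
W_r = U[:, :r] @ np.diag(S[:r]) @ Vt[:r, :]

print(np.linalg.matrix_rank(W_r))                      # r
print(np.linalg.norm(W - W_r) / np.linalg.norm(W))     # relative approximation error
```

For a random matrix like this one, the approximation error is large; the paper's hypothesis is precisely that the fine-tuning update is not such a matrix and loses little information under a low-rank representation.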
GOING DEEPER INTO THE ARCHITECTURE
- We have seen how LoRA works at a high level; now let’s dig deeper, starting with the architecture of the weight matrices.
- The query, key, and value projection layers each have a dimension of d x d, which is high-dimensional. The orange part of the figure shows the matrices A and B. The A matrix is initialized as a Gaussian with zero mean and standard deviation σ, and the B matrix is initialized as all zeros. B has dimension d x r and A has dimension r x d. The output h is then given by h = W0x + ΔWx = W0x + BAx
- Here, W0 is the frozen pre-trained weight matrix and x is the input. ΔW is the weight update, which has a low rank r and is represented as the product BA. So you fine-tune only the matrices B and A (a minimal sketch follows this list).
- Here is an interesting observation: you can swap out the matrices A and B for a different A` and B` trained on another task, with zero latency overhead, because BA can be merged into W0. You can therefore keep a collection of task-specific A’s and B’s, almost like a mixture of expert models, and choose the right pair on the fly. In traditional fine-tuning, this arrangement is not possible. The paper also notes that adapter-based methods converge to an MLP, and prefix-based methods to a model that cannot take long input sequences, since part of the sequence length must be reserved for the prefix tokens.
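Here is a minimal sketch of such a LoRA layer in PyTorch, following the parameterization above (B initialized to zeros, A to a small Gaussian). The α/r scaling factor comes from the paper; the concrete dimensions, the 0.01 standard deviation, and the `merge` helper are illustrative assumptions, not the reference implementation.

```python
import math
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Frozen dense layer W0 plus a trainable low-rank update BA (sketch)."""
    def __init__(self, d_in: int, d_out: int, r: int = 4, alpha: int = 4):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(d_out, d_in), requires_grad=False)  # frozen W0
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.A = nn.Parameter(torch.randn(r, d_in) * 0.01)    # Gaussian init (assumed sigma)
        self.B = nn.Parameter(torch.zeros(d_out, r))          # zero init, so BA = 0 at the start
        self.scale = alpha / r

    def forward(self, x):
        # h = W0 x + (alpha/r) * B A x  -- only A and B receive gradients
        return x @ self.weight.T + self.scale * (x @ self.A.T @ self.B.T)

    def merge(self):
        # Fold BA into W0 for inference: zero extra latency. Swapping to another task's
        # A/B would mean subtracting the current update first, then adding the new one.
        self.weight.data += self.scale * (self.B @ self.A)

layer = LoRALinear(768, 768, r=4)
h = layer(torch.randn(2, 768))   # h = W0 x + (alpha/r) * B A x
layer.merge()                    # fold the update into W0 for zero-latency inference
```

Because B starts at zero, BA is zero at the beginning of training, so fine-tuning starts exactly from the pre-trained model; after calling `merge`, inference costs exactly the same as the original layer.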
SOME PRACTICAL BENEFITS
- There are memory and storage benefits: VRAM usage during training is reduced by up to 2/3 when r << d. For GPT-3 with 175B parameters, this cuts VRAM consumption from 1.2TB to 350GB.
- Checkpointing overhead is also reduced: the checkpoint size shrinks by roughly 10,000x (from 350GB to about 35MB) for the GPT-3 model, since only A and B need to be stored. They also observed a 25% speedup during training. A rough back-of-the-envelope check of the checkpoint numbers is sketched below.
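As a rough sanity check on those numbers, here is a back-of-the-envelope sketch assuming GPT-3-scale dimensions (d_model = 12288, 96 layers), r = 4 applied to the query and value projections only, and FP16 storage; the byte-size assumption is mine, not the paper's.

```python
# Rough checkpoint-size arithmetic for GPT-3 175B (illustrative assumptions)
d_model, n_layers, r = 12288, 96, 4
bytes_per_param = 2                               # assuming FP16 storage

full_params = 175e9
lora_params = n_layers * 2 * (2 * d_model * r)    # q and v projections; each gets B (d x r) + A (r x d)

print(f"full checkpoint : {full_params * bytes_per_param / 1e9:.0f} GB")   # ~350 GB
print(f"LoRA checkpoint : {lora_params * bytes_per_param / 1e6:.0f} MB")   # ~38 MB, in the ballpark of the quoted 35 MB
```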
LIMITATIONS
- We cannot train on batches that mix multiple tasks. Each task has its own A and B matrices, so a tuning run cannot include examples from another task. Likewise, during inference we cannot batch inputs from multiple tasks together; we can only batch inputs belonging to the particular task for which we trained and obtained the A and B matrices.
UNDERSTANDING THE LOW-RANK UPDATES
The authors try to answer three questions about these low-rank updates:
- Which subset of weight matrices do we need to adapt to maximize the fine-tuning performance?
- What is the ideal r to use?
- What is the connection between W and ΔW?
QUESTION 1: WHICH MATRICES TO FINE-TUNE?
- The table shows validation accuracy on the two tasks, WikiSQL and MultiNLI, for different ranks and different subsets of attention weight matrices. Under a fixed parameter budget, there is a trade-off between the rank r and the number of matrices we fine-tune: if we increase r, we can only fine-tune a smaller subset of the query, key, and value matrices. The results show that a lower rank r applied to a larger number of matrices is preferred.
QUESTION 2: WHAT IS THE IDEAL VALUE OF R?
- From the table above, we can observe that increasing r does not keep improving accuracy. So the Goldilocks zone for the rank r is 4 or 8. This signifies that a rank of four already captures enough information in ΔW that it is preferable to adapt more weight matrices than to adapt a single type of weight matrix with a larger rank. Hence, increasing r does not cover a more meaningful subspace, so it is not worth the extra computational overhead.
QUESTION 3: WHAT IS THE RELATION BETWEEN W AND ΔW?
- To answer the third question, the authors performed two analyses: (i) the subspace similarity between the updates learned with different values of r, and (ii) how ΔW compares with W.
THE ABOVE PLOT IS ONLY FOR THE 48th LAYER OUT OF 96 LAYERS IN GPT-3
- Here, i indexes the top-i singular directions of the A matrix learned with r = 8, and j indexes the top-j singular directions of the A matrix learned with r = 64. They measure a Grassmann-based, normalized subspace similarity φ(i, j) = ||Ui^T Uj||²_F / min(i, j), which lies between 0 and 1, where Ui holds the top-i singular vectors of the r = 8 update and Uj the top-j singular vectors of the r = 64 update. The similarity is highest for the leading directions, which suggests that the most important directions learned with different values of r overlap (a small sketch of this computation follows this list).
- They also compare ΔW learned from two different random seeds with itself and, as a baseline, with a random Gaussian matrix. They train the query and value updates twice with different seeds; if the two resulting ΔW matrices are more similar to each other than to the random matrix, then fine-tuning is not a random process, and the weight updates converge to similar subspaces across random seeds.
- You can see that the leading, low-dimensional directions of the weight updates are similar across runs, and as r increases they tend to diverge (for r ≥ 8), while the random matrices show no similarity at all. Hence, we would always prefer a smaller r and adapt more matrices.
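Here is a small numpy sketch of that normalized subspace similarity. The matrices below are random stand-ins for the learned A matrices (with assumed shapes), just to show the mechanics of the computation.

```python
import numpy as np

def subspace_similarity(A1, A2, i, j):
    """phi(A1, A2, i, j) = ||U1_i^T U2_j||_F^2 / min(i, j), a value in [0, 1]."""
    U1 = np.linalg.svd(A1, full_matrices=False)[0][:, :i]   # top-i singular vectors of A1
    U2 = np.linalg.svd(A2, full_matrices=False)[0][:, :j]   # top-j singular vectors of A2
    return np.linalg.norm(U1.T @ U2, "fro") ** 2 / min(i, j)

d = 512
A_r8  = np.random.randn(d, 8)    # stand-in for the update learned with r = 8
A_r64 = np.random.randn(d, 64)   # stand-in for the update learned with r = 64

print(subspace_similarity(A_r8, A_r64, 4, 4))    # similarity of the top-4 directions
print(subspace_similarity(A_r8, A_r64, 8, 64))   # similarity using all directions
```

For independent random matrices like these, the similarity is close to zero, which is exactly the behaviour of the random baseline in the paper; the learned updates, by contrast, show high similarity in their leading directions.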
This brings us to the end of the LoRA paper. We will continue with the QLoRA paper in the next article, where it will be clearly explained. If you have any queries, I am one comment away. Thanks!