Authors: Garrett Byrd, Dr. Joe Schoonover
What is Fine-Tuning?
Fine-tuning a large language model (LLM) is the process of improving a model's performance on a specific task: for example, making the model "familiar" with a particular dataset, or getting it to respond in a certain style. Consider a use case where an LLM is fine-tuned on a dataset internal to a company; a user could then provide a prompt like, "Can you summarize our expenses last quarter?" Even just specializing a model to produce responses in the traditional question/answer format is a core feature of the LLMs that work in the backends of chatbots.
The Two Phases of LLMs
There are two main phases in the lifecycle of a model: training and inference. Inference is the phase where some input is provided and the model produces some output; i.e., the model infers a response. Conversely, training is the process of adjusting the (typically billions of) parameters that constitute the model. The stark difference between training and inference lies in their hardware requirements.
Generally, inference requires only that the model fits in memory. Some additional memory overhead is needed for intermediate computations, but this overhead is typically a fraction of the memory footprint of the model's parameters. Training, on the other hand, generally requires considerable memory overhead. This is because training not only needs to fit the model in memory, but also requires memory for training-specific state (e.g., gradients and the running statistics kept by optimizers), as well as batches of training data, which can themselves have a large memory footprint.
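To make the gap concrete, here is a rough back-of-the-envelope estimate (our own illustrative numbers, assuming an 8-billion-parameter model stored in 16-bit precision and full fine-tuning with Adam-style optimizer states; activations and data batches are ignored):

```python
# Rough, back-of-the-envelope memory estimate for an 8B-parameter model.
# Assumptions (illustrative, not from this post): fp16 weights (2 bytes/param),
# full fine-tuning with fp32 gradients and two fp32 optimizer states per
# parameter (as in Adam). Activations and training batches are ignored.

params = 8e9                      # parameter count (e.g., an 8B model)
GiB = 1024**3

inference = params * 2 / GiB      # fp16 weights only
weights   = params * 2 / GiB      # fp16 weights
grads     = params * 4 / GiB      # fp32 gradients
adam      = params * 4 * 2 / GiB  # fp32 first and second moments
training  = weights + grads + adam

print(f"Inference (weights only): ~{inference:.0f} GiB")
print(f"Full fine-tuning (weights + grads + optimizer states): ~{training:.0f} GiB")
```

Even under these generous assumptions, full fine-tuning needs several times the memory that inference does, which is the gap the methods below aim to close.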
Fine-tuning, in its most general sense, requires retraining a model on a new target dataset. This is problematic for a variety of reasons, one of which is that training a model is computationally expensive. The training phase doesn't just require large amounts of memory; it also demands a lot of computation to process data, compute parameter updates, and delicately adjust parameters. Now we see our problem: we want to specialize a model to perform better at a certain task (i.e., fine-tune it), but doing so requires retraining the model, which vastly increases the memory requirement compared to just using the model for inference. Our solution: reduce the memory footprint of the (fine-tuning) training process.
Fine-Tuning Methods
The above solution almost seems naive. Were it so easy, we likely wouldn't be facing the memory problem at all; that is, why not just make models smaller to begin with? The good news is that reducing the memory footprint of fine-tuning is indeed possible: as we will see, there is a buffet of methods designed for exactly this purpose, and we apply many of them to fine-tune Llama 3 with the MetaMathQA dataset on Radeon GPUs.
Parameter Efficient Fine-Tuning (PEFT)
To address the immediate issue of needing to retrain an entire model's worth of parameters, we opt for Parameter-Efficient Fine-Tuning (PEFT). PEFT methods enable efficient adaptation of large pre-trained models to various downstream applications by fine-tuning only a small number of (extra) model parameters instead of all of the model's parameters, which significantly decreases computational and storage costs. That is, instead of retraining every single parameter, we retrain a subset of parameters, or even just a small set of additional parameters appended to the frozen base weights, as sketched below.
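As a concrete sketch of how this looks with the Hugging Face peft library (the model ID and hyperparameters below are placeholder choices, not prescriptions from this post), only the small adapter ends up trainable while the base weights stay frozen:

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Load a pre-trained base model (placeholder model ID; any causal LM works).
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")

# Configure a PEFT method (here a LoRA adapter, introduced in the next section)
# that adds a small number of trainable parameters on top of the frozen base.
config = LoraConfig(
    r=8,                                   # rank of the adapter matrices
    lora_alpha=16,                         # scaling factor
    target_modules=["q_proj", "v_proj"],   # which layers receive adapters
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, config)
model.print_trainable_parameters()  # reports trainable vs. total parameter counts
```

Typically the trainable fraction reported here is well under one percent of the full model.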
Low Rank Adaptation (LoRA)
One popular PEFT implementation is Low-Rank Adaptation (LoRA), a low-rank decomposition method that reduces the number of trainable parameters, which speeds up fine-tuning of large models and uses less memory. LoRA lets us train some dense layers in a neural network indirectly by optimizing low-rank decomposition matrices of the dense layers' change during adaptation, while keeping the pre-trained weights themselves frozen.
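To make the idea concrete, here is a minimal, self-contained PyTorch sketch of a LoRA-style linear layer (our own illustration with placeholder rank and scaling values; in practice the peft library handles this for you):

```python
import torch
import torch.nn as nn

# Minimal sketch of a LoRA-adapted linear layer (illustrative only; not the
# peft library's implementation). The pre-trained weight W stays frozen, and
# only the low-rank factors A (r x in) and B (out x r) are trained, so the
# trainable parameter count drops from out*in to r*(in + out).
class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad_(False)                    # freeze pre-trained weights
        self.A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.B = nn.Parameter(torch.zeros(base.out_features, r))  # zero init: adapter starts as a no-op
        self.scale = alpha / r

    def forward(self, x):
        # y = W x + (alpha / r) * B A x
        return self.base(x) + self.scale * (x @ self.A.T @ self.B.T)

layer = LoRALinear(nn.Linear(4096, 4096), r=8)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
total = sum(p.numel() for p in layer.parameters())
print(f"trainable parameters: {trainable:,} of {total:,}")
```

For a 4096x4096 layer, the adapter adds only tens of thousands of trainable parameters against roughly 16.8 million frozen ones.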
Quantized Low Rank Adaptation (QLoRA) and Quantization
QLoRA utilizes quantization to further reduce the memory footprint of low-rank adaptation methods. Quantization is the process of discretizing an input from a representation that holds more information to one that holds less. In the context of LLMs, quantization is used to recast model parameters (often stored in high-precision floating point formats) as lower-precision floating point values, or even as integers.
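As a toy illustration of the idea (simple symmetric int8 quantization, not the 4-bit NormalFloat scheme QLoRA actually uses), the sketch below maps floating point values to 8-bit integers with a single scale factor and then recovers approximations of the originals:

```python
import torch

# Toy quantization example: store fp32 values as int8 codes plus one fp32
# scale, cutting per-value storage from 4 bytes to 1 byte at the cost of a
# small rounding error.
weights = torch.randn(6)                        # pretend these are fp32 model weights
scale = weights.abs().max() / 127               # symmetric int8 range is [-127, 127]
quantized = torch.round(weights / scale).to(torch.int8)
dequantized = quantized.float() * scale

print("original:   ", weights)
print("int8 codes: ", quantized)
print("recovered:  ", dequantized)
print("max error:  ", (weights - dequantized).abs().max().item())
```

QLoRA applies this same trade-off far more aggressively, keeping the frozen base weights in a 4-bit format while the LoRA adapters are trained in higher precision.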
On October 15th, AMD will host a live webinar on fine-tuning LLMs on AMD Radeon GPUs. You will have the chance to gain insights from our experts on optimizing LLMs to meet diverse and evolving needs. Register to attend this online event!