Introduction

In a previous post, I explored using LLMs to perform text classification tasks. The idea was that they could enable automation of more complex tasks that are otherwise not automatable. Many people just use a model like ChatGPT for this, but after learning a bit about how they work, I was getting the sense that other approaches might be more efficient and accurate.

My experimental task was to classify messages sent from customers to an online store. Messages were classified as either refund requests, order inquiries, or general feedback.

Message | Label
"The feedback submission form took a long time to load." | General Feedback
"Can you tell me when my package will arrive? Order #1337" | Order Inquiry
"I received the wrong item. How do I return it for a refund?" | Refund Request

In that post, I used a wrapper around llama.cpp, running inference on the 8B parameter, 4-bit quantized version of Llama 3. Architecturally, this works like a tiny ChatGPT, and is capable of running on pretty much any consumer hardware. By directly prompting the base model I achieved ~92% accuracy out of the box, improving to ~95% with post-processing (retrying invalid responses).

The results seemed okay, but a general purpose instruction model didn't really seem ideal for a specific task like this. I wanted to explore other options that might be more efficient or have better performance.

Two possible approaches to creating a task-specific model would be to either fine-tune an existing model or to train a smaller, specialized model from scratch.

Training a model from scratch would be extremely interesting, but it would also require a huge amount of data and computing power. On the other hand, fine-tuning is a technique that adapts pre-trained language models to specific tasks by further training on a smaller, task-specific dataset. This allows you to leverage the data and compute that went into the original pre-training of the model.

Model selection

My first naive idea was to try to fine-tune Llama 3 itself. Models such as GPT or Llama are primarily designed for text generation. Mechanically, they are trained to predict the next token in a sequence using masked self-attention, where each token can only attend to context to its left. What this means is that the neural network is only considering the words that appear before the token it is trying to predict. Let's look at an example.

Consider the sentence "Live free or die hard". GPT/Llama will take this whole sentence and learn to predict the next word at every position simultaneously:

Input | Target
Live | free
Live free | or
Live free or | die
Live free or die | hard
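The input/target pairs above can be generated mechanically. Here is a minimal sketch that splits on whitespace for readability; real models work on sub-word tokens rather than whole words, and next_token_pairs is a hypothetical helper name, not part of any library.

```python
def next_token_pairs(sentence):
    """Return (input, target) pairs for next-word prediction."""
    words = sentence.split()
    # Every prefix of the sentence is an input; the word that follows is its target.
    return [(" ".join(words[:i]), words[i]) for i in range(1, len(words))]


pairs = next_token_pairs("Live free or die hard")
for context, target in pairs:
    print(f"{context!r} -> {target!r}")
```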

This makes perfect sense if you're training a model to generate text, but it doesn't really align with what we're trying to do. We're more interested in looking at an entire chunk of text and classifying it. It turns out there are better ways to achieve this.

Other models learn other types of tasks, such as Masked Language Modeling and Next Sentence Prediction. In Masked Language Modeling, the model is tasked with predicting the missing words in a sentence. For example, given the sentence "The [MASK] jumped over the fence," the model would try to predict the masked word (likely "dog" or "cat").
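To make the masking idea concrete, here is a small sketch that hides one word and keeps it as the answer the model would be asked to recover. It is a simplification: mask_one_word is a hypothetical helper, it works on whole words, and it masks exactly one position, whereas BERT-style pre-training masks a fraction of sub-word tokens.

```python
import random


def mask_one_word(sentence, rng, mask_token="[MASK]"):
    """Replace one randomly chosen word with the mask token."""
    words = sentence.split()
    index = rng.randrange(len(words))
    target = words[index]
    masked = words[:index] + [mask_token] + words[index + 1:]
    return " ".join(masked), target


rng = random.Random(0)  # seeded so the example is repeatable
masked_sentence, answer = mask_one_word("The dog jumped over the fence", rng)
print(masked_sentence, "->", answer)
```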

In Next Sentence Prediction, the model is tasked with determining if two sentences are adjacent in a text. For instance, given the sentences "I love ice cream" and "It's my favorite dessert", the model would determine that these sentences are relatively likely to be adjacent. For "The sky is blue" and "Elephants have trunks", the model would determine that these sentences are relatively unlikely to be adjacent.

These approaches can be implemented through a mechanism called bidirectional self-attention. Unlike with masked self-attention, the model does not restrict itself to only considering tokens before the one it is trying to predict. Instead, it considers all the tokens in the input and takes a stab at its task. This makes more sense if you're trying to perform sentiment analysis, named entity recognition, or text classification. In those cases, you wouldn't want all of the machinery associated with next-token prediction, because it would be dead weight.
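The difference between the two attention styles comes down to which positions each token is allowed to look at. As a rough illustration (the function names are my own, and real implementations use tensors rather than nested lists), a 1 means "may attend" and a 0 means "blocked":

```python
def causal_mask(n):
    """Row i may attend only to positions 0..i (itself and everything to its left)."""
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]


def bidirectional_mask(n):
    """Every position may attend to every other position."""
    return [[1] * n for _ in range(n)]


tokens = ["Live", "free", "or", "die"]
print("masked (causal) self-attention:")
for token, row in zip(tokens, causal_mask(len(tokens))):
    print(f"{token:>5} {row}")
print("bidirectional self-attention:")
for token, row in zip(tokens, bidirectional_mask(len(tokens))):
    print(f"{token:>5} {row}")
```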

BERT

A model that uses bidirectional self-attention is Bidirectional Encoder Representations from Transformers (BERT). Introduced by Google in 2018, BERT is pre-trained on the Masked Language Modeling and Next Sentence Prediction tasks described above.

These pre-training tasks equip BERT with a strong grasp of language structure and context. In the words of the authors, "...the pre-trained BERT model can be fine-tuned with just one additional output layer to create state-of-the-art models for a wide range of tasks, such as question answering and language inference, without substantial task-specific architecture modifications."

My guess was that a model like this would be a better starting point for my use case.

DistilBERT

In 2019, Hugging Face released a smaller, distilled version of BERT called DistilBERT. It preserves 95% of BERT's performance while using 40% fewer parameters and running 60% faster. For comparison, it is less than 1% of the size of Llama 3 8B.

DistilBERT achieves this efficiency through a process called knowledge distillation, where a smaller model (the student) is trained to mimic the behavior of a larger model (the teacher). This results in a more compact model that can still perform well on a wide range of tasks, making it well-suited for fine-tuning on specific applications, especially when computational resources are limited.
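The core idea of distillation can be sketched in a few lines: soften both models' outputs with a temperature and penalize the student for disagreeing with the teacher. This is only an illustration of that one ingredient under my own assumptions (the function names, the temperature value, and the toy logits are made up); DistilBERT's actual training objective combines a distillation term like this with other losses.

```python
import math


def softmax(logits, temperature=1.0):
    """Turn logits into probabilities; a higher temperature gives a softer distribution."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """Cross-entropy between the teacher's soft targets and the student's predictions."""
    teacher_probs = softmax(teacher_logits, temperature)
    student_probs = softmax(student_logits, temperature)
    return -sum(t * math.log(s) for t, s in zip(teacher_probs, student_probs))


teacher = [2.0, 0.5, -1.0]
print(distillation_loss([2.0, 0.5, -1.0], teacher))  # student agrees with teacher
print(distillation_loss([-1.0, 0.5, 2.0], teacher))  # student disagrees: larger loss
```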

DistilBERT Architecture Overview

DistilBERT is based on the Transformer architecture. The model consists of several key components that work together to process and understand text. Text input is first passed through an embedding layer. This layer is responsible for converting input tokens (which are essentially words or parts of words) into dense vector representations. These vectors contain floating point numbers that encode the semantic meaning of the words in a format that the model can work with. The values of the vectors change throughout training as the model iterates on its semantic understanding.

The embedding layer learns that words like "dog" and "cat" are semantically closer to each other than to words like "car" or "building". It captures subtle relationships, such as understanding that "king" is to "man" as "queen" is to "woman", or that certain adjectives tend to precede certain types of nouns (like "delicious" often appearing before food-related words).
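"Semantically closer" is usually measured with cosine similarity between vectors. The sketch below uses tiny made-up 3-dimensional vectors purely for illustration; real DistilBERT embeddings have hundreds of dimensions and their values are learned, not hand-written.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two vectors: 1.0 means same direction."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


# Hypothetical embeddings, invented for this example.
toy_embeddings = {
    "dog": [0.90, 0.80, 0.10],
    "cat": [0.85, 0.75, 0.15],
    "car": [0.10, 0.20, 0.95],
}

dog_cat = cosine_similarity(toy_embeddings["dog"], toy_embeddings["cat"])
dog_car = cosine_similarity(toy_embeddings["dog"], toy_embeddings["car"])
print(f"dog vs cat: {dog_cat:.3f}")
print(f"dog vs car: {dog_car:.3f}")
```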

Once the text is embedded, it passes through multiple Transformer blocks. These blocks contain layers of self-attention and feed-forward neural networks. The self-attention mechanism allows the model to weigh the importance of different words in the input. This means the model can understand context and relationships between words, much like how we understand language by considering words in relation to each other. After the self-attention layer, feed-forward neural networks further process the attention output, allowing the model to capture complex patterns in the data.

Consider the sentence "The man who crossed the street was hit by a car." Self-attention allows the model to understand that "hit" is more strongly related to both "man" and "car" than to "crossed" or "street". This helps the model correctly interpret who was hit (the man) and what hit him (the car), even though these words are not adjacent in the sentence. The feed-forward networks then process this contextual information, potentially learning higher-level concepts like "traffic accidents" or "pedestrian safety" from such examples.

These two kinds of layers capture different relationships that sub-word tokens can have with each other. As it learns, the model updates its understanding of both at the same time. During inference, it draws on what it has learned in both when it produces its output, which for this project is a predicted category rather than a next token.
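The self-attention step described above can be sketched in plain Python. This is a stripped-down version under stated assumptions: it uses the input vectors directly as queries, keys, and values (a real Transformer block applies learned projections and multiple heads first), and the 2-dimensional token vectors are invented for the example.

```python
import math


def softmax(scores):
    peak = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(vectors):
    """Scaled dot-product self-attention with inputs reused as queries, keys, and values."""
    dim = len(vectors[0])
    outputs, all_weights = [], []
    for query in vectors:
        # Score the query against every token, including itself.
        scores = [
            sum(q * k for q, k in zip(query, key)) / math.sqrt(dim)
            for key in vectors
        ]
        weights = softmax(scores)
        # Each output is a weighted mix of all token vectors.
        mixed = [
            sum(w * value[d] for w, value in zip(weights, vectors))
            for d in range(dim)
        ]
        outputs.append(mixed)
        all_weights.append(weights)
    return outputs, all_weights


tokens = ["man", "hit", "car"]
toy_vectors = [[1.0, 0.2], [0.9, 0.6], [0.3, 1.0]]  # hypothetical values
outputs, weights = self_attention(toy_vectors)
for token, row in zip(tokens, weights):
    print(token, [round(w, 2) for w in row])
```

Each printed row shows how much one token draws from every token in the input, and each row sums to 1.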

Fine-tuning Process

The fine-tuning process involves adjusting the pre-trained DistilBERT model to our specific classification task.

The code below is implemented using Hugging Face's Transformers library, which provides a high-level interface for working with DistilBERT and other transformer models. This interface includes pre-built model architectures, tokenizers, and data processing utilities. The library offers components like DistilBertTokenizer for text tokenization and DistilBertForSequenceClassification for the actual model architecture. For efficient data handling and batching, the code pairs these with the Dataset and DataLoader classes from PyTorch.

These high-level components abstract away much of the complexity involved in working with transformer models. The tokenizer handles the task of converting raw text into a format the model can understand, including subword tokenization and special token management. The model class encapsulates DistilBERT's architecture, including the self-attention mechanisms and feed-forward networks, exposing simple methods for forward passes and loss computation.

Transformers delegates the numerical computations to PyTorch, a deep learning framework that handles low-level operations. PyTorch manages tensor operations, automatic differentiation for backpropagation, and GPU acceleration. It provides the foundation for defining and training neural networks, offering modules like torch.nn for neural network layers and torch.optim for optimization algorithms.

Data Preparation

The process begins by loading the labeled data. The data consists of three files, one for each category, each containing several hundred customer messages.

def load_data(file_path):
    with open(file_path, 'r') as f:
        return [line.strip() for line in f]

feedback = load_data('data/feedback.txt')
inquiries = load_data('data/inquiries.txt')
refunds = load_data('data/refunds.txt')

all_texts = feedback + inquiries + refunds
all_labels = [0] * len(feedback) + [1] * len(inquiries) + [2] * len(refunds)

This code loads the text strings and assigns numerical labels to each. In neural networks, labels refer to "the right answer" and are used during training to check the model's guesses.

Dataset and DataLoader

A custom dataset class is created to handle tokenization:

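import torch
from torch.utils.data import Dataset

class TextClassificationDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        # DataLoader needs the dataset size to build batches and shuffle
        return len(self.texts)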

    def __getitem__(self, item):
        text = str(self.texts[item])
        label = self.labels[item]

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        return {
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

A tokenizer is a tool that breaks down text into smaller units called tokens. These could be words, parts of words, or even punctuation. The tokenizer also handles converting these tokens into numbers that the model can understand. DistilBertTokenizer is used here because the input needs to be tokenized the same way it was when the model was pre-trained. Under the hood, it uses WordPiece subword segmentation.

The encode_plus method of this tokenizer is used. This method does several things: It tokenizes the input text, adds special tokens that DistilBERT expects (like [CLS] at the start and [SEP] at the end), pads or truncates the input to a specified maximum length, and creates an "attention mask" which tells the model which tokens are actual input and which are padding.
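To show what the padding, truncation, and attention mask steps produce, here is a hand-rolled sketch that operates on token ids that have already been looked up. It is not the tokenizer's implementation: pad_and_mask is a hypothetical helper, the special-token ids assume the standard uncased BERT vocabulary, and the word ids in the example are placeholders.

```python
# Special-token ids, assuming the standard uncased BERT vocabulary.
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0


def pad_and_mask(token_ids, max_len):
    """Add [CLS]/[SEP], truncate or pad to max_len, and build the attention mask."""
    # Reserve two positions for the special tokens, truncating the text if needed.
    body = token_ids[: max_len - 2]
    input_ids = [CLS_ID] + body + [SEP_ID]
    attention_mask = [1] * len(input_ids)  # 1 = real token

    padding = max_len - len(input_ids)
    input_ids += [PAD_ID] * padding
    attention_mask += [0] * padding  # 0 = padding the model should ignore
    return input_ids, attention_mask


# Placeholder ids standing in for a two-token message.
ids, mask = pad_and_mask([7592, 2088], max_len=6)
print(ids)
print(mask)
```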

Dataset is the PyTorch class that represents a dataset. Inheriting from it gives us a custom dataset that PyTorch can easily work with. The __getitem__ method is called whenever an item from the dataset needs to be accessed.

PyTorch's DataLoader is used to efficiently batch and shuffle the data:

train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=16)

The DataLoader is a PyTorch utility that helps manage batching of data and provides an iterable over the Dataset. For the training data, shuffle=True is set, which means it will randomly shuffle the data at the start of each epoch. This shuffling helps prevent the model from learning any unintended patterns based on the order of the training data.

A batch size of 16 is used, meaning the DataLoader will yield batches of 16 examples at a time. This batch size is a balance between memory usage and training speed. For the test data, shuffling is not needed, so the shuffle parameter is omitted.

During training and evaluation, these DataLoaders will be iterated over, which will give batches of data in the format the model expects. This abstraction simplifies the training loop and makes it easier to work with large datasets that might not fit into memory all at once.
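Conceptually, the batching and shuffling behaviour boils down to something like the following generator. This is a simplified stand-in written for illustration (iterate_batches is my own name), not how DataLoader is implemented; the real class also handles collating tensors, worker processes, and more.

```python
import random


def iterate_batches(examples, batch_size, shuffle=False, seed=None):
    """Yield lists of up to batch_size examples, optionally in shuffled order."""
    order = list(range(len(examples)))
    if shuffle:
        random.Random(seed).shuffle(order)
    for start in range(0, len(order), batch_size):
        yield [examples[i] for i in order[start:start + batch_size]]


examples = list(range(40))  # stand-ins for 40 tokenized messages
for batch in iterate_batches(examples, batch_size=16, shuffle=True, seed=0):
    print(len(batch), batch[:4])
```

With 40 examples and a batch size of 16, the last batch is smaller than the others, which is also the default behaviour of DataLoader.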

Train-Test Split

Before training begins, the data is split into training and testing sets:

def train_test_split(data, test_size=0.1):
    split_index = int(len(data) * (1 - test_size))
    return data[:split_index], data[split_index:]

train_data, test_data = train_test_split(list(zip(all_texts, all_labels)))
train_texts, train_labels = zip(*train_data)
test_texts, test_labels = zip(*test_data)

A manual split is used where the first 90% of the data is used for training and the last 10% for testing. This approach was chosen to make it easy for me to place all of the samples evaluated in the previous blog post into the test set. That way, performance could be more directly compared, because none of those samples would have been seen by either model during training. More on that later.
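A positional split holds out whatever sits at the end of the combined list, so it relies on the data being arranged deliberately. When there is no need to pin specific samples to the test set, a seeded shuffle before splitting is a common alternative. The sketch below is that alternative, not the approach used in this project, and shuffled_split is a hypothetical helper.

```python
import random


def shuffled_split(texts, labels, test_size=0.1, seed=42):
    """Shuffle (text, label) pairs with a fixed seed, then hold out test_size of them."""
    pairs = list(zip(texts, labels))
    random.Random(seed).shuffle(pairs)
    test_count = round(len(pairs) * test_size)
    split_index = len(pairs) - test_count
    return pairs[:split_index], pairs[split_index:]


# Toy data: 10 messages per category, labeled 0, 1, 2.
toy_texts = [f"message {i}" for i in range(30)]
toy_labels = [0] * 10 + [1] * 10 + [2] * 10
train_pairs, test_pairs = shuffled_split(toy_texts, toy_labels)
print(len(train_pairs), len(test_pairs))
print(sorted(label for _, label in test_pairs))
```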

Training Loop

The training process involves iterating over the data multiple times and updating the model's parameters to minimize the classification error.

def train_epoch(model, data_loader, optimizer, device):
    model.train()
    total_loss = 0
    total_correct = 0
    total_samples = 0

    for batch in data_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        logits = outputs.logits

        _, predicted = torch.max(logits, 1)
        total_correct += (predicted == labels).sum().item()
        total_samples += labels.size(0)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(data_loader)
    accuracy = total_correct / total_samples
    return avg_loss, accuracy

It isn't shown here, but the model passed in to this function is an instance of the DistilBertForSequenceClassification class. input_ids, attention_mask, and labels are determined during dataset preparation outside of this function call. The input_ids represent the tokenized input text, while attention_mask indicates which tokens should be attended to. labels represent the correct answers, while predicted represents the guesses.

The function performs gradient descent. At a high level, this refers to a process where the model makes a "forward pass" where it makes guesses and checks its answers. Then it makes a "backward pass" where it nudges its weights in a direction that it believes will make a better guess next time.

Calling model in this manner is the PyTorch idiom for doing a forward pass (this has always been a strange design choice to me).

The backward pass happens in two calls: loss.backward() computes the direction and amount to nudge each parameter, and optimizer.step() applies those parameter updates.

Most of the actual number crunching is carried out by PyTorch's autograd code. The attention mechanics are implemented in the DistilBertForSequenceClassification class of the Hugging Face transformers library. The parameter updates themselves are performed by the AdamW optimizer, which layers additional optimizations on top of plain gradient descent.

As you can see, there isn't really a defined API for gradient descent. It's on you to know what steps need to be done and do them.
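Stripped of the framework, the same forward/backward/update rhythm can be seen on a toy problem with a single weight. This is a hand-written sketch where the gradient is derived manually; in the real training loop, autograd computes the gradients and the optimizer applies them. The numbers are arbitrary.

```python
def train_step(weight, x, y, lr):
    """One step of gradient descent for the model prediction = weight * x."""
    prediction = weight * x                # forward pass: make a guess
    loss = (prediction - y) ** 2           # check the guess against the label
    gradient = 2 * (prediction - y) * x    # backward pass: d(loss)/d(weight)
    new_weight = weight - lr * gradient    # update: nudge the weight downhill
    return new_weight, loss


weight = 0.0
losses = []
for step in range(20):
    weight, loss = train_step(weight, x=2.0, y=6.0, lr=0.05)
    losses.append(loss)

print(f"learned weight: {weight:.4f}")  # approaches 3.0, since 3.0 * 2.0 == 6.0
print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.8f}")
```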

Model Configuration and Training

The model configuration and training takes place in the program's main method, separate from the above helper methods.

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=3)

train_dataset = TextClassificationDataset(train_texts, train_labels, tokenizer, max_len=128)
test_dataset = TextClassificationDataset(test_texts, test_labels, tokenizer, max_len=128)

First, the tokenizer and model are initialized. The pre-trained 'distilbert-base-uncased' model is used, which means it's been trained on lowercase English text. The num_labels=3 parameter tells the model it's dealing with a three-class classification problem.

max_len=128 is set when creating the datasets. This maximum length parameter determines the longest sequence of tokens the model will process. Shorter sequences will be padded, and longer ones will be truncated.

The batch size of 16 in the DataLoader means 16 examples will be processed at a time during training. This is another balance between memory usage and training speed. Larger batch sizes can lead to faster training but require more memory.

device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
model.to(device)

Next, the model is moved to the appropriate device. The code checks whether Apple's Metal Performance Shaders (MPS) backend is available, which can accelerate training on compatible Mac hardware (this was only run on a Mac; this line would need to be changed to run optimally on different hardware).
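For hardware other than a Mac, one possible variation is to check for an NVIDIA GPU first and fall back from there. I haven't run the project this way, so treat it as an untested sketch; pick_device is a hypothetical helper name.

```python
import torch


def pick_device():
    """Prefer CUDA, then Apple's MPS backend, then fall back to the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = pick_device()
print(f"Using device: {device}")
```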

optimizer = AdamW(model.parameters(), lr=2e-5)

Neural networks rely on nonlinear activation functions to model complex relationships in data. These functions, like tanh or ReLU, work best when their inputs are well-distributed across their operating range. For instance, a tanh function, which maps inputs to outputs between -1 and 1, is most effective when its inputs sit in the region where the curve is still changing, rather than far out at the extremes where it flattens and gradients shrink toward zero.

Ensuring this ideal distribution of inputs became a significant challenge as networks grew deeper. In 2015, researchers introduced Batch Normalization as a solution. This technique normalizes the inputs to each layer, improving the flow of data through the network. This worked, but it also introduced complexities. Adaptive learning rate methods like Adam (Adaptive Moment Estimation) and its later variant AdamW tackle a related problem from the optimizer's side: they scale each parameter's update using running estimates of its gradient statistics. They complement normalization rather than replace it; Transformer models such as DistilBERT use layer normalization inside each block and are commonly fine-tuned with AdamW.

A learning rate of 2e-5 (0.00002) is used here. It was selected by manually tuning the learning rate and number of epochs and evaluating network performance. It's small enough to allow for fine adjustments to the pre-trained weights without causing drastic changes that could destroy the model's pre-trained knowledge. Too high, and the model might overshoot optimal solutions; too low, and the model might train too slowly or get stuck in suboptimal solutions. This learning rate is much lower than what you would use for training a network from scratch, which is usually around 1e-3.
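To give a feel for what the optimizer does with that learning rate, here is a single-parameter sketch of an AdamW-style update in plain Python. It follows the commonly described form of the algorithm (moment estimates, bias correction, and weight decay applied directly to the parameter); the hyperparameter defaults other than the learning rate are the usual ones and are assumptions here, and the real optimizer works on whole tensors.

```python
import math


def adamw_step(param, grad, state, lr=2e-5, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    """Apply one AdamW-style update to a single scalar parameter."""
    state["t"] += 1
    # Running averages of the gradient and the squared gradient.
    state["m"] = beta1 * state["m"] + (1 - beta1) * grad
    state["v"] = beta2 * state["v"] + (1 - beta2) * grad * grad
    # Bias correction compensates for the averages starting at zero.
    m_hat = state["m"] / (1 - beta1 ** state["t"])
    v_hat = state["v"] / (1 - beta2 ** state["t"])
    # Decoupled weight decay: shrink the parameter directly.
    param = param - lr * weight_decay * param
    # Adaptive step: scaled by the gradient's recent magnitude.
    return param - lr * m_hat / (math.sqrt(v_hat) + eps)


state = {"t": 0, "m": 0.0, "v": 0.0}
updated = adamw_step(1.0, grad=0.5, state=state)
print(f"parameter moved by {1.0 - updated:.2e}")
```

On the first step the move is close to the learning rate itself regardless of the gradient's size, which is part of why such a small value keeps the pre-trained weights from changing drastically.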

num_epochs = 3
for epoch in range(num_epochs):
    print(f'Epoch {epoch + 1}/{num_epochs}')
    train_loss, train_acc = train_epoch(model, train_dataloader, optimizer, device)
    print(f'Train loss {train_loss:.4f} accuracy {train_acc:.4f}')

    val_loss, val_acc = evaluate(model, test_dataloader, device)
    print(f'Val loss {val_loss:.4f} accuracy {val_acc:.4f}')

torch.save(model.state_dict(), 'model/distilbert_classifier.pth')
print("Model saved successfully.")

The training loop runs for 3 epochs, after which the accuracy and loss of both the training and validation sets reached desirable values (see "Results and Discussion" below).

An epoch is one complete pass through the entire training dataset. In each epoch, the model is trained on the training data and then evaluated on the validation (test) data. The loss and accuracy for both training and validation sets are printed out with 4 decimal places.
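The evaluate function called in the loop isn't shown in this post. A plausible version, mirroring train_epoch but without any parameter updates, might look like the sketch below. This is my reconstruction under that assumption rather than the exact code from the project.

```python
import torch


def evaluate(model, data_loader, device):
    """Compute average loss and accuracy without updating the model."""
    model.eval()  # switch off training-only behaviour such as dropout
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    with torch.no_grad():  # no gradients are needed for evaluation
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            total_loss += outputs.loss.item()

            # The predicted class is the index of the largest logit.
            predicted = torch.argmax(outputs.logits, dim=1)
            total_correct += (predicted == labels).sum().item()
            total_samples += labels.size(0)

    return total_loss / len(data_loader), total_correct / total_samples
```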

After training, the model is saved using torch.save(). This function saves the model's state_dict. A state_dict in PyTorch is a Python dictionary that maps each layer to its parameter tensors. It contains all the learned weights and biases of the model. By saving the state_dict, all the knowledge the model has gained during training is essentially saved. The model is saved with a .pth file extension. This is a common convention in PyTorch for saved model files, standing for "PyTorch".

To use this saved model later, it would be loaded like this:

model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=3)
model.load_state_dict(torch.load('model/distilbert_classifier.pth'))
model.eval()
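Once loaded, the model returns three logits per message, one for each class. Turning those into a readable prediction is a matter of softmax plus argmax. The sketch below assumes the label order from the data preparation step (0 = feedback, 1 = inquiry, 2 = refund); logits_to_label is a hypothetical helper and the example logits are made up.

```python
import math

# Index order matches the numeric labels assigned during data preparation.
LABEL_NAMES = ["General Feedback", "Order Inquiry", "Refund Request"]


def logits_to_label(logits):
    """Convert raw class scores into (label name, probability)."""
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return LABEL_NAMES[best], probs[best]


label, confidence = logits_to_label([0.1, 3.2, -1.0])
print(f"{label} ({confidence:.1%})")
```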

Results and Discussion

For simplicity, I ran with the assumption that each string of text had one and only one classification. It would be much more useful if each string of text instead could have 0-3 classifications. This would also almost certainly bring down the accuracy of the model because determining whether a given string belongs to a single category at all seems like it would be a much more complex task than what it had to learn during training.
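One way the 0-3 classification variant could work at prediction time is to score each category independently with a sigmoid and keep every category that clears a threshold. This is a sketch of that idea only; I did not build or train such a model, the threshold is arbitrary, and training would also need a per-label binary loss instead of the single-label one used here.

```python
import math

LABEL_NAMES = ["General Feedback", "Order Inquiry", "Refund Request"]


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def multi_label_predict(logits, threshold=0.5):
    """Return every label whose independent probability clears the threshold."""
    return [
        name
        for name, logit in zip(LABEL_NAMES, logits)
        if sigmoid(logit) >= threshold
    ]


print(multi_label_predict([-3.0, 2.0, 1.5]))   # an inquiry that also asks for a refund
print(multi_label_predict([-2.0, -2.0, -2.0])) # belongs to no category
```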

After 3 epochs, the model achieves an accuracy of 100% and loss of 0.0156 on the test set and an accuracy of 99.59% and loss of 0.0278 on the training set.

Epoch 1/3
Train loss 0.6052 accuracy 0.8264
Val loss 0.1448 accuracy 1.0000
Epoch 2/3
Train loss 0.0708 accuracy 0.9959
Val loss 0.0314 accuracy 1.0000
Epoch 3/3
Train loss 0.0278 accuracy 0.9959
Val loss 0.0156 accuracy 1.0000

The task was designed to be easy enough that a network could handle it, so the performance met my expectations. The 100% accuracy of the test set includes the entirety of the tasks that Llama 3 8B only achieved 92% accuracy on, with additional test data added in. And just for kicks, I later tried coming up with weird and contrived strings of text to try to confuse the model. I could not get it to incorrectly classify anything.

Once fine-tuned to distinguish the three specific categories (refund requests, order inquiries, and general feedback), DistilBERT outperformed base Llama 3 on this specialized task at roughly 0.8% of its size. Llama 3 8B is subject to the variability of a general purpose instruction model, so it makes sense that it gets tripped up on even simple tasks from time to time.
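The size comparison is simple arithmetic. The parameter counts below are round approximations that I am assuming for illustration (about 66 million for DistilBERT and about 8 billion for Llama 3 8B), so the result is a ballpark figure.

```python
# Approximate parameter counts (assumed round figures).
DISTILBERT_PARAMS = 66_000_000
LLAMA3_8B_PARAMS = 8_000_000_000

ratio = DISTILBERT_PARAMS / LLAMA3_8B_PARAMS
print(f"DistilBERT is about {ratio:.1%} the size of Llama 3 8B")
```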

I came away completely sold on the idea of fine-tuning small models for repetitive, special purpose tasks. It is inevitable that many suitable tasks will instead be run on models that are hundreds of billions of parameters or larger via commercial APIs. This presents a huge opportunity for cost savings and efficiency that will only grow over time.

The source code for this project is available on GitHub.