Fine-Tuning BERT for multiclass categorisation with Amazon SageMaker

Posted by on September 15, 2021


Intro and Context

When FreeAgent customers import their bank transactions, we predict which accounting categories the transactions belong to by making requests to a machine learning model managed with Amazon SageMaker. The model inputs are bank transaction descriptions (e.g. ‘Costa Coffee Edinburgh’) vectorized with a tf-idf bag-of-words approach [1], and the model itself is a linear support vector machine (see e.g. Chapter 12 of The Elements of Statistical Learning [2]).

This approach is simple in terms of both preprocessing and modelling, and it has been running on the smallest real-time inference instance SageMaker has to offer, so costs have been low! However, we discovered a scalability issue that has been hampering us for some time: the vocabulary learned when vectorizing the training transactions, which includes both unigrams and bigrams, can reach tens of millions of terms. When we train the model with a one-vs-rest strategy, the number of parameters we have to store is roughly the vocabulary size multiplied by the number of target classes. Adding more target classes therefore adds to the large matrix multiplications required to serve predictions, affecting model latency, and increases the memory required to hold the model object.
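To make the scaling argument concrete, the minimal scikit-learn sketch below fits a tf-idf vectorizer over unigrams and bigrams and a LinearSVC to a handful of invented transactions. The data, the category names and the choice of these two scikit-learn classes are assumptions for illustration, not necessarily our production configuration. The learned coef_ matrix has one row per class and one column per vocabulary term.

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import LinearSVC

# Invented transaction descriptions and categories, for illustration only
texts = [
    "Costa Coffee Edinburgh",
    "Costa Coffee Glasgow",
    "TFL Travel London",
    "Trainline ticket Edinburgh",
    "Aviva insurance premium",
    "Direct Line insurance",
]
categories = ["MEALS", "MEALS", "TRAVEL", "TRAVEL", "INSURANCE", "INSURANCE"]

# Bag-of-words tf-idf features over unigrams and bigrams
vectorizer = TfidfVectorizer(ngram_range=(1, 2))
X = vectorizer.fit_transform(texts)

# With more than two classes, LinearSVC fits one binary classifier per class (one-vs-rest)
clf = LinearSVC().fit(X, categories)

vocab_size = len(vectorizer.vocabulary_)
n_classes = len(clf.classes_)

# One weight per (class, vocabulary term) pair plus one intercept per class
n_params = clf.coef_.size + clf.intercept_.size
print(clf.coef_.shape)  # (n_classes, vocab_size)
print(n_params)         # n_classes * (vocab_size + 1)
```

Scaling this up, a hypothetical 20-million-term vocabulary with 15 classes would mean about 300 million weights, roughly 2.4 GB as 64-bit floats.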

For this reason, and out of curiosity as to whether vector representations based on an attention mechanism [3] would perform better than bag-of-words, we were keen to explore using a BERT [4] model for transaction categorisation. In this post we discuss how we used the Hugging Face transformers library [5] to fine-tune a BERT model to categorise our bank transactions.

The work described was carried out together with our summer intern, Harry Tullett.

Fine-Tuning BERT for multiclass problems

BERT is an approach for constructing vector representations of natural language input based on the transformer architecture [6]. The representations are learned by pre-training on very large text corpora and can then be used as inputs when learning to perform various downstream tasks, a process referred to as fine-tuning. BERT has been instrumental in the adoption of transfer learning for natural language processing, in the same way that ImageNet [7] was for computer vision. The family of transformer-based models [8] currently achieves state-of-the-art performance [9] on tasks such as machine translation, named-entity recognition, question answering and sentiment analysis.

Since we already use SageMaker and other AWS services such as S3 and SageMaker Studio, we wanted to stay within this ecosystem and make use of the fantastic set of pre-trained models and training APIs provided by Hugging Face. To get started quickly, we followed an example notebook that uses the Hugging Face framework in the SageMaker Python SDK to fine-tune a binary categorisation model.

What we struggled to find was an example showing how to fine-tune for multiclass categorisation with a custom dataset. This post contributes a description of how to modify the above example to train multiclass categorisation models in SageMaker using CSV data stored in S3.

Our setup

We use SageMaker Studio with a Python 3 (PyTorch 1.6 Python 3.6 CPU Optimized) kernel to run our code, and install the required packages in the first cell of the notebook:

!pip install 'sagemaker>=2.48.0' 'transformers==4.9.2' 'datasets[s3]==1.11.0' --upgrade

Configure estimator source and output

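The source_dir and output_path arguments of the Hugging Face Estimator define the locations of the training script source (packaged as a tar.gz file) and of the model outputs, respectively. These should be set to appropriate S3 prefixes. The S3 paths, the entry_point script name, the instance type and the framework versions below are placeholders in the style of the example notebook; adjust them to your account and to the versions available to you:

from sagemaker import get_execution_role
from sagemaker.huggingface import HuggingFace
hyperparameters = {'epochs': 1, 'train_batch_size': 32,
                   'model_name': 'distilbert-base-uncased'}
huggingface_estimator = HuggingFace(
    entry_point='train.py', source_dir='s3://bucket/prefix/scripts/train_script.tar.gz',
    output_path='s3://bucket/prefix/output', role=get_execution_role(),
    instance_type='ml.p3.2xlarge', instance_count=1,
    transformers_version='4.6', pytorch_version='1.7', py_version='py36',
    hyperparameters=hyperparameters)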

We first tested this setup by saving the tokenized IMDb data from the example to our own S3 buckets, using it as input and calling fit on the estimator, to check that our SageMaker Studio role had the permissions required to interact with S3 and SageMaker. The SageMaker training job completed successfully and the model outputs were written to the expected S3 location.

Read custom data from S3

Satisfied that our permissions were set correctly, we started tackling the multiclass problem. Our training and validation data are stored as CSV files in S3. We load these into pandas DataFrames with read_csv, which uses the s3fs package to read directly from S3:

import pandas as pd

train_df = pd.read_csv('s3://bucket/prefix/data/training/train.csv')
val_df = pd.read_csv('s3://bucket/prefix/data/validation/val.csv')

Encode target variable and format headers 

Each row of these dataframes contains the text description of a bank transaction in a ‘text’ column and the assigned accounting category (e.g. INSURANCE) in a ‘category’ column.

This categorical target is encoded as an integer, and the column is renamed to ‘labels’, as expected by the Hugging Face model. The encoder is fitted on the training data only and then reused to transform the validation data, so that both datasets share the same category-to-integer mapping:

from sklearn.preprocessing import LabelEncoder
le = LabelEncoder()
train_df['category'] = le.fit_transform(train_df['category'])
train_df.rename(columns={'category': 'labels'}, inplace=True)
val_df['category'] = le.transform(val_df['category'])
val_df.rename(columns={'category': 'labels'}, inplace=True)
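Fitting the encoder on the training data and only transforming the validation data keeps the category-to-integer mapping identical across both datasets. The short sketch below uses invented categories to show the mapping, how integer predictions can be turned back into category names with inverse_transform, and that a category absent from the training data cannot be encoded.

```python
from sklearn.preprocessing import LabelEncoder

# Invented category columns for the training and validation splits
train_categories = ["INSURANCE", "TRAVEL", "MEALS", "TRAVEL"]
val_categories = ["MEALS", "INSURANCE"]

le = LabelEncoder()
train_labels = le.fit_transform(train_categories)  # learns the category -> integer mapping
val_labels = le.transform(val_categories)          # reuses it; never refit on validation data

# classes_ is sorted, so integer k corresponds to le.classes_[k]
print(le.classes_.tolist())                        # ['INSURANCE', 'MEALS', 'TRAVEL']
print(train_labels.tolist(), val_labels.tolist())  # [0, 2, 1, 2] [1, 0]

# Map integer predictions back to category names
print(le.inverse_transform([2, 0]).tolist())       # ['TRAVEL', 'INSURANCE']

# A category never seen in training cannot be encoded
try:
    le.transform(["UTILITIES"])
except ValueError as err:
    print("unseen label:", err)
```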

Create Hugging Face Dataset objects

As in the original example notebook, we load the data into a Hugging Face Dataset object, but we do this from memory using the from_pandas method:

from datasets import Dataset
train_dataset = Dataset.from_pandas(train_df)

Tokenize and pad the input data

The text column is tokenized and padded before the data are passed to the fit method of the Estimator. This means applying the pre-trained tokenizer associated with our chosen BERT model to the bank transaction text to generate input ids and an attention mask.

We tokenize batches of data and pad them to a common length (i.e. a common number of tokens), because all transactions in a batch passed to the model must have the same length. The attention mask records which positions hold real tokens and which hold padding, so that the padding is ignored. The input ids are used to look up the vector representations of the corresponding tokens in the pre-trained BERT model.
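The lookup itself is just row indexing into an embedding matrix. The NumPy sketch below uses a tiny random matrix and invented token ids, with 0 as the padding id (as in BERT's uncased vocabulary). It shows only the initial token embeddings: the real model then adds position information and passes them through its transformer layers to produce context-dependent representations.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy embedding matrix; distilbert-base-uncased uses 30522 tokens x 768 dimensions
vocab_size, hidden_size = 10, 4
embeddings = rng.normal(size=(vocab_size, hidden_size))

# A padded batch of two sequences of invented token ids (0 = padding id)
input_ids = np.array([[5, 7, 2, 0],
                      [3, 9, 0, 0]])
attention_mask = (input_ids != 0).astype(int)   # 1 = real token, 0 = padding

# The lookup is plain row indexing: one embedding vector per token position
token_vectors = embeddings[input_ids]
print(token_vectors.shape)   # (2, 4, 4): (batch, sequence length, hidden size)
```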

The fine-tuning process involves passing the vectors representing the token sequences to a feed-forward neural network head attached to the BERT architecture, which outputs a score for each of the target classes; a softmax converts these scores into probabilities. Backpropagation through this entire architecture computes the gradients used to update the weights and biases so as to minimise the cost function (cross-entropy) for our classification problem.
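The head and loss can be sketched in NumPy as below. This is a deliberately simplified stand-in: random vectors replace BERT's outputs and a single linear layer serves as the head, whereas, as far as we know, the DistilBERT sequence classification model in transformers puts an extra feed-forward layer in front of its final linear layer, and PyTorch computes the gradients automatically. The sketch shows the softmax, the cross-entropy loss and the gradient that backpropagation starts from.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes, except num_labels = 15 as in our use case
batch_size, hidden_size, num_labels = 3, 8, 15

# Stand-in for BERT's output vector for each transaction in the batch
sequence_vectors = rng.normal(size=(batch_size, hidden_size))
labels = np.array([0, 4, 14])                      # true class ids

# A single linear layer as the classification head (a simplification)
W = rng.normal(scale=0.1, size=(hidden_size, num_labels))
b = np.zeros(num_labels)
logits = sequence_vectors @ W + b                  # one score per class

# Softmax (shifted by the row max for numerical stability) gives probabilities
shifted = logits - logits.max(axis=1, keepdims=True)
probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)

# Cross-entropy: mean negative log-probability of each true class
loss = -np.log(probs[np.arange(batch_size), labels]).mean()

# Gradient of the loss w.r.t. the logits; backpropagation carries it back
# through W and b and, during fine-tuning, through every BERT layer too
grad_logits = probs.copy()
grad_logits[np.arange(batch_size), labels] -= 1.0
grad_logits /= batch_size
grad_W = sequence_vectors.T @ grad_logits
grad_b = grad_logits.sum(axis=0)
```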

The code for the tokenization is as follows:

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

def tokenize(batch):
    return tokenizer(batch['text'], padding='longest', truncation=True)

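# Tokenize in batches of 100 and drop the raw 'text' column
train_dataset_tokenized = train_dataset.map(
    tokenize, batched=True, batch_size=100, remove_columns=['text'])
train_dataset_tokenized.set_format(
    type='torch', columns=['input_ids', 'attention_mask', 'labels'])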

Note that we’ve changed the tokenize helper defined in the notebook to use padding='longest', which pads all the transactions in a batch to the length of the transaction with the most tokens. In our use case this is typically around 50 tokens. With padding='max_length', every transaction is instead padded to 512 tokens, the maximum sequence length the BERT model accepts.
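To see why this matters, the plain-Python sketch below mimics what padding='longest' does on a toy batch of invented token ids: each sequence is right-padded with a padding id (0 here) to the longest length in the batch, and the attention mask marks real tokens with 1. It illustrates the idea only and is not the tokenizer's actual implementation.

```python
def pad_to_longest(batch_ids, pad_id=0):
    """Right-pad every token-id list in the batch to the batch's longest length.

    Returns (input_ids, attention_mask); the mask is 1 for real tokens, 0 for padding.
    """
    longest = max(len(ids) for ids in batch_ids)
    input_ids = [ids + [pad_id] * (longest - len(ids)) for ids in batch_ids]
    attention_mask = [[1] * len(ids) + [0] * (longest - len(ids)) for ids in batch_ids]
    return input_ids, attention_mask


# Invented token ids for three short transactions
batch = [[101, 6849, 4157, 102], [101, 102], [101, 2005, 3604, 2414, 102]]
input_ids, attention_mask = pad_to_longest(batch)

# Count padding tokens under each strategy
real_tokens = sum(len(ids) for ids in batch)
padding_longest = len(batch) * len(input_ids[0]) - real_tokens
padding_max_length = len(batch) * 512 - real_tokens

print(padding_longest)      # 4
print(padding_max_length)   # 1525
```

Even on this toy batch, padding to 512 adds hundreds of times more padding tokens than padding to the longest sequence, which is where the savings in compute and memory come from.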

This turns out to make a big difference to training time, GPU memory utilization during training and the size of the output model artifacts. An even better strategy would be to batch transactions of similar size together, as most of our examples have far fewer than 50 tokens. We pass the tokenize method batches of 100 transactions and have to be careful to train the model with batches of the same size, so that every example in a training batch has the same length.
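One simple way to batch transactions of similar size together, sketched here as an assumption rather than something we have implemented, is to sort examples by token count before chunking them into batches. The token counts below are invented; the helper compares the padding added by consecutive batching with the padding added by length-sorted batching.

```python
def length_bucketed_batches(token_lengths, batch_size):
    """Group example indices into batches of similar token length.

    Sorting by length before chunking means each batch is padded only up to the
    longest of a set of similarly sized examples.
    """
    order = sorted(range(len(token_lengths)), key=lambda i: token_lengths[i])
    return [order[i:i + batch_size] for i in range(0, len(order), batch_size)]


def padding_cost(token_lengths, batches):
    """Padding tokens added when every batch is padded to its longest member."""
    total = 0
    for batch in batches:
        lengths = [token_lengths[i] for i in batch]
        total += max(lengths) * len(lengths) - sum(lengths)
    return total


# Invented token counts: mostly short transactions with a few long ones
token_lengths = [5, 48, 6, 7, 50, 4, 45, 8]
n = len(token_lengths)
consecutive = [list(range(i, min(i + 2, n))) for i in range(0, n, 2)]
bucketed = length_bucketed_batches(token_lengths, batch_size=2)

print(padding_cost(token_lengths, consecutive))  # 127
print(padding_cost(token_lengths, bucketed))     # 41
```

In practice the order of the bucketed batches would usually be shuffled between epochs. The transformers TrainingArguments class also has a group_by_length option built on a similar idea, which may be worth checking.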

At this point the format of the dataset is set to torch so that PyTorch tensors are returned, and the unused ‘text’ column is dropped.

Final step: save formatted padded data

After these transformations we’re almost there! The datasets are saved to S3 as follows:

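from datasets.filesystems import S3FileSystem
s3 = S3FileSystem()

# save_to_disk writes a directory of Arrow files (not a CSV) under this prefix
training_input_path = 's3://bucket/prefix/data/training/train_tokenized'
train_dataset_tokenized.save_to_disk(training_input_path, fs=s3)

# val_dataset_tokenized is built from val_df in exactly the same way
val_input_path = 's3://bucket/prefix/data/validation/val_tokenized'
val_dataset_tokenized.save_to_disk(val_input_path, fs=s3)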

Modify the training script for multiclass categorisation

A couple of very minor changes need to be made to the ‘’ training script defined in the notebook to support multiclass categorisation. These are to add the num_labels argument to the model definition (the number of classes can be read from the label encoder, e.g. len(le.classes_), and is 15 in our case):

model = AutoModelForSequenceClassification.from_pretrained(
    args.model_name, num_labels=15)  # 15 == len(le.classes_) in our case

and to change the average argument in the compute_metrics function from 'binary' to 'micro' (or another multiclass option, such as 'macro' or 'weighted'):

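from sklearn.metrics import accuracy_score, precision_recall_fscore_support

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    # 'micro' replaces the notebook's 'binary'; 'macro' or 'weighted' also work
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='micro')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }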
This file is then packaged as ‘train_script.tar.gz’ and uploaded to the S3 location given by the Estimator's source_dir argument.
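The packaging and upload step can be sketched in Python as follows. This is an assumed approach rather than necessarily how we did it: the script file name is hypothetical, the script is stored at the root of the archive (SageMaker resolves entry_point relative to source_dir), and the upload uses the SageMaker SDK's S3Uploader.upload; the AWS CLI or boto3 would work equally well.

```python
import tarfile
from pathlib import Path


def make_source_tarball(script_path, tar_path="train_script.tar.gz"):
    """Pack a training script into a gzipped tarball, at the archive root.

    SageMaker resolves entry_point relative to source_dir, so the script is
    stored under its bare file name rather than its full local path.
    """
    script = Path(script_path)
    with tarfile.open(tar_path, "w:gz") as tar:
        tar.add(str(script), arcname=script.name)
    return tar_path


def upload_tarball(tar_path, s3_prefix):
    """Upload the tarball to an S3 prefix and return its S3 URI (needs AWS credentials)."""
    from sagemaker.s3 import S3Uploader  # imported lazily so packaging works offline

    return S3Uploader.upload(tar_path, s3_prefix)
```

For example, with hypothetical paths: upload_tarball(make_source_tarball("train.py"), "s3://bucket/prefix/scripts").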

Finally, we change train_batch_size in the hyperparameters to 100 to match the tokenizer batches, as discussed above, redefine our Estimator and call fit:

# huggingface_estimator redefined as above, but with 'train_batch_size': 100
huggingface_estimator.fit({'train': training_input_path, 'test': val_input_path})


The logs are streamed to the notebook when running from SageMaker Studio, so it’s easy to follow the progress of the training job. When the job completes, you’ll see the model outputs in the specified S3 location.

If you’re only interested in the final model for making predictions, you can add save_strategy='no' to the TrainingArguments in the training script. This stops intermediate checkpoints from being saved and greatly reduces the size of the model outputs.

Use the model artifacts for inference

This model can be deployed to a SageMaker real-time inference endpoint following the steps in the notebook. 

However, if you just want to experiment offline, you can download the model outputs (a model.tar.gz archive), extract them and use a Hugging Face pipeline to output confidence scores:

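from transformers import AutoModelForSequenceClassification
from transformers import TextClassificationPipeline
from transformers import AutoTokenizer

# Local directory holding the extracted model.tar.gz (placeholder path)
model_dir = 'model'
model = AutoModelForSequenceClassification.from_pretrained(model_dir)
# The tokenizer is only in model_dir if it was saved alongside the model;
# otherwise load it from 'distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_dir)
pipe = TextClassificationPipeline(
    model=model, tokenizer=tokenizer, return_all_scores=True)
pipe(['Costa Coffee Edinburgh', 'TFL Travel London'])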
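Unless id2label is set in the model configuration, the pipeline names the classes LABEL_0, LABEL_1 and so on, in the order of the integers produced by the label encoder. The hypothetical helper below maps those names back to accounting categories with inverse_transform. It assumes the nested list-of-dicts output that return_all_scores=True produces in the transformers version pinned earlier, and that the LabelEncoder fitted during preprocessing has been kept (for example, saved with joblib); the encoder and scores here are invented.

```python
from sklearn.preprocessing import LabelEncoder


def to_category_scores(pipeline_output, label_encoder):
    """Convert per-class scores labelled 'LABEL_<k>' into {category: score} dicts."""
    results = []
    for scores in pipeline_output:
        ids = [int(s["label"].rsplit("_", 1)[-1]) for s in scores]
        names = label_encoder.inverse_transform(ids)
        results.append({str(name): s["score"] for name, s in zip(names, scores)})
    return results


# Invented encoder and pipeline output for one transaction and three classes
le = LabelEncoder().fit(["INSURANCE", "MEALS", "TRAVEL"])
example_output = [[
    {"label": "LABEL_0", "score": 0.05},
    {"label": "LABEL_1", "score": 0.90},
    {"label": "LABEL_2", "score": 0.05},
]]
scores = to_category_scores(example_output, le)
best = max(scores[0], key=scores[0].get)
print(best)  # MEALS
```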


Very preliminary evaluation suggests that the accuracy of predictions from our fine-tuned BERT model, used completely out of the box, is slightly higher than that of the linear SVM. This, combined with the scalability advantages described in the introduction, provides good motivation for further investigation into this modelling approach.

Hopefully this post saves a bit of searching through documentation for someone! We’d be keen to hear from anyone working in the area of transaction categorisation, particularly which preprocessing and modelling approaches you’ve been working with, so please do get in touch if this was of interest!


  1. scikit-learn documentation: text feature extraction
  2. Hastie, Tibshirani and Friedman, The Elements of Statistical Learning (2nd edition), Chapter 12
  3. Distill blog post on attention
  4. Devlin et al. (2018), BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
  5. Hugging Face Transformers library
  6. Vaswani et al. (2017), Attention Is All You Need
  7. ImageNet website
  8. Must-read papers on pre-trained language models
  9. Current state of the art in machine learning tasks
  10. SageMaker Python SDK documentation for the Hugging Face framework
  11. YouTube video on fine-tuning a Hugging Face transformer model using SageMaker
  12. Getting Started notebook accompanying the above YouTube video
