Train a Text Summarizer in 5 Minutes

Have you ever wanted to adapt a large language model to a new task, but didn't have the time or resources to update all of its weights? In this post, I'll show you how to use adapter-transformers to fine-tune Flan-T5 on the XSum summarization dataset in about five minutes. I'll also show you how to use the resulting model to generate summaries for any text.

What is Fine-Tuning?

Fine-tuning is the process of continuing to train a pre-trained model on a new, usually smaller, task-specific dataset. This takes far less time than training a model from scratch, but updating every weight of a large model can still be slow and memory-hungry.

What is adapter-transformers?

adapter-transformers is a library for parameter-efficient fine-tuning, built on top of Hugging Face's transformers library. It adds methods that make fine-tuning cheaper, such as LoRA (Low-Rank Adaptation), which keeps the pre-trained weights frozen and trains only a small number of new parameters.

What is LoRA?

Diagram of LoRA
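LoRA freezes a pre-trained weight matrix W (d × k) and trains two small matrices alongside it: A (r × k) and B (d × r), where the rank r is much smaller than d and k. Their product BA has the same shape as W and acts as a learned update to it, usually in the attention layers (in practice the update is also scaled by a constant, alpha / r). During training, the low-rank path is computed separately and its output is added to the frozen layer's output: h = Wx + BAx. Because matrix multiplication distributes over addition, this equals (W + BA)x, so once training is done, BA can be added directly into W and the model runs exactly as fast as the original.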
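To make that identity concrete, here is a minimal pure-Python sketch with tiny made-up matrices. The shapes, values, and the alpha / r scaling constant are illustrative assumptions, not adapter-transformers internals. It computes the layer output the training way (separate low-rank path) and the merged way, and checks that they match.

```python
# Pure-Python check of the LoRA identity: W x + s * B (A x) == (W + s * B A) x.
# All shapes and values are made-up illustrations, not real model weights.

def matmul(a, b):
    """Multiply an n x m matrix by an m x p matrix (both stored as lists of rows)."""
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def add(a, b):
    """Element-wise sum of two matrices with the same shape."""
    return [[x + y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]

def scale(a, s):
    """Multiply every entry of a matrix by the scalar s."""
    return [[s * x for x in row] for row in a]

d, k, r = 4, 4, 2      # W is d x k; the LoRA rank r is smaller than d and k
alpha = 4              # assumed scaling numerator; the update is scaled by alpha / r
s = alpha / r

W = [[float(i + j) for j in range(k)] for i in range(d)]        # frozen weight, d x k
A = [[0.1 * (i - j) for j in range(k)] for i in range(r)]       # trainable, r x k
B = [[0.01 * (i + 2 * j) for j in range(r)] for i in range(d)]  # trainable, d x r
x = [[1.0], [2.0], [3.0], [4.0]]                                 # input column vector, k x 1

# Training-time view: frozen path plus a separate low-rank path
h_train = add(matmul(W, x), scale(matmul(B, matmul(A, x)), s))

# Merged view: fold the scaled update into the weight once, then do one multiplication
W_merged = add(W, scale(matmul(B, A), s))
h_merged = matmul(W_merged, x)

flat_train = [v for row in h_train for v in row]
flat_merged = [v for row in h_merged for v in row]
print(all(abs(p - q) < 1e-9 for p, q in zip(flat_train, flat_merged)))  # True
```

In adapter-transformers, merge_adapter performs this folding on the real model; the sketch only shows why the outputs are unchanged.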

The Flan-T5 Model

Flan-T5 is a family of models from Google: T5 checkpoints that were further fine-tuned on a large collection of instruction-formatted tasks. Because they are text-to-text transformers that already handle a variety of tasks, they make a great starting point for fine-tuning.
In this post, I'll be using flan-t5-base, which is one of the smaller models in the series.

The XSum Dataset

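XSum (Extreme Summarization) is a dataset of BBC news articles, each paired with a single-sentence summary.
To keep training fast, I'll use only a small subset of it: the first 1,000 training documents, 100 test documents for evaluation, and 10 validation documents for spot-checking the generated summaries.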

Environment Setup

I'll be using Google Colab for this post, but you can also run it locally if you have a GPU and the necessary libraries installed.
To install the libraries that Colab doesn't include by default, run the following commands (prefix each one with ! when running it in a notebook cell):
pip install -U adapter-transformers sentencepiece
pip install datasets

Tokenizing the Dataset

To use the Flan-T5 model, we first have to convert the documents and summaries into token IDs that the model can understand.
To do this, we'll use the AutoTokenizer class from the transformers library. We'll also add the prefix "summarize: " to the start of every document, because Flan-T5 expects a task prefix like this to tell it which task to perform.
from transformers import AutoTokenizer

# the base model that we'll be using
base_model = "google/flan-t5-base"

# the tokenizer that we'll be using
tokenizer = AutoTokenizer.from_pretrained(base_model)

# the prefix that we'll be using
prefix = 'summarize: '

# tokenize a batch of examples (called by Dataset.map with batched=True)
def encode_batch(examples):
    # the name of the input column
    text_column = 'document'
    # the name of the target column
    summary_column = 'summary'
    # pad every sequence to max_length so all examples have the same shape
    padding = "max_length"

    # keep only pairs where both the document and the summary are non-empty
    inputs, targets = [], []
    for document, summary in zip(examples[text_column], examples[summary_column]):
        if document and summary:
            inputs.append(document)
            targets.append(summary)

    # add the task prefix to every input
    inputs = [prefix + inp for inp in inputs]

    # tokenize the inputs and targets (XSum summaries are one sentence, so 128 tokens is plenty)
    model_inputs = tokenizer(inputs, max_length=512, padding=padding, truncation=True)
    labels = tokenizer(targets, max_length=128, padding=padding, truncation=True)

    # the model expects the target token IDs under the key "labels"
    model_inputs["labels"] = labels["input_ids"]

    return model_inputs
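When Dataset.map is called with batched=True, encode_batch receives a dict that maps each column name to a list of values. The sketch below runs only the filtering and prefixing steps on a hypothetical three-row batch, with no tokenizer involved, so you can see which rows survive. The batch contents and the helper name select_and_prefix are made up for illustration.

```python
# Hypothetical toy batch, shaped like what datasets' map(batched=True) passes in:
# a dict mapping each column name to a list of values.
examples = {
    "document": ["First article text.", "", "Third article text."],
    "summary": ["First summary.", "Second summary.", ""],
    "id": ["1", "2", "3"],
}

prefix = "summarize: "

def select_and_prefix(batch, text_column="document", summary_column="summary"):
    """Mirror the pre-tokenization part of encode_batch: drop rows with an empty
    document or summary, then prepend the task prefix to each remaining document."""
    inputs, targets = [], []
    for doc, summ in zip(batch[text_column], batch[summary_column]):
        if doc and summ:
            inputs.append(prefix + doc)
            targets.append(summ)
    return inputs, targets

inputs, targets = select_and_prefix(examples)
print(inputs)   # ['summarize: First article text.']
print(targets)  # ['First summary.']
```

Because rows can be dropped, a batch may come back shorter than it went in. datasets only accepts that from a batched map when the original columns are removed, which is one reason the loading code passes remove_columns.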

Loading the Dataset

We'll use the datasets library to load XSum, taking only the first few examples of each split we need.
from datasets import load_dataset

# load the first max_items examples of a split and tokenize them
def load_split(split_name, max_items):
    # the slice syntax (e.g. "train[:1000]") keeps only the first max_items examples
    dataset = load_dataset("xsum", split=f"{split_name}[:{max_items}]")
    # tokenize the dataset, replacing the raw text columns with token IDs
    dataset = dataset.map(
        encode_batch,
        batched=True,
        remove_columns=dataset.column_names,
        desc=f"Running tokenizer on {split_name} dataset",
    )
    # return PyTorch tensors; keep attention_mask so the model ignores padding
    dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

    return dataset

Preparing the Model

Now that we have the dataset, we can prepare the model for training.
from transformers import AutoModelForSeq2SeqLM
from transformers.adapters import LoRAConfig

# start with the pretrained base model
model = AutoModelForSeq2SeqLM.from_pretrained(base_model)

# set the parameters for LoRA: rank r=8 and scaling numerator alpha=16
config = LoRAConfig(
    r=8,
    alpha=16,
    # LoRA is applied to the attention layers by default;
    # these two flags also add it to the feed-forward (intermediate and output) layers
    intermediate_lora=True,
    output_lora=True
)

# make a new adapter for the XSum dataset
model.add_adapter("xsum", config=config)
# freeze the pre-trained weights so that only the adapter's parameters are trained
model.train_adapter("xsum")
# use the adapter in forward passes
model.set_active_adapters(["xsum"])
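To get a feel for what r=8 saves, this small helper counts the parameters LoRA trains for one weight matrix: B contributes d_out × r values and A contributes r × d_in. The 768 × 768 shape is an assumption chosen to match the hidden size of the T5-base architecture; this is a per-matrix illustration, not a count for the whole model.

```python
def lora_trainable_params(d_out, d_in, r):
    """Parameters LoRA trains for one d_out x d_in weight: B (d_out x r) plus A (r x d_in)."""
    return r * (d_out + d_in)

# Assumption: a 768 x 768 projection (hidden size 768, as in the T5-base architecture)
full = 768 * 768
lora = lora_trainable_params(768, 768, r=8)
print(full, lora, f"{lora / full:.2%}")  # 589824 12288 2.08%
```

Doubling r doubles the LoRA count, so r trades adapter capacity against memory and training speed.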

Preparing the Trainer

Now that we have the model, we can prepare the trainer.
from transformers import TrainingArguments, AdapterTrainer

# small batch size to fit in memory
batch_size = 1

training_args = TrainingArguments(
    learning_rate=3e-4,
    num_train_epochs=1,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # print the training loss every 200 steps
    logging_steps=200,
    output_dir="./training_output",
    overwrite_output_dir=True,
    remove_unused_columns=False
)

# create the trainer
trainer = AdapterTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    # load the dataset
    train_dataset=load_split("train", 1000),
    eval_dataset=load_split("test", 100),
)
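With per_device_train_batch_size=1 and one epoch, every training example is one optimizer step. This back-of-the-envelope sketch, assuming a single GPU, no gradient accumulation, and that the empty-text filter drops none of the 1,000 examples, shows how many steps to expect and how many loss values logging_steps=200 will print. The helper count_steps is hypothetical, not part of the Trainer API.

```python
import math

def count_steps(num_examples, batch_size, epochs=1, num_devices=1):
    """Optimizer steps when every batch triggers one update (no gradient accumulation)."""
    steps_per_epoch = math.ceil(num_examples / (batch_size * num_devices))
    return steps_per_epoch * epochs

steps = count_steps(num_examples=1000, batch_size=1)
logged = steps // 200  # the loss is logged once every logging_steps steps
print(steps, logged)  # 1000 5
```

If memory allows, raising the batch size to 8 would cut this to count_steps(1000, 8) == 125 steps.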

Training the Model

Now that we have the trainer, we can train the model. How long this takes depends heavily on your hardware, so a GPU is strongly recommended.
trainer.train()

Evaluating the Model

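Now that we have the trained model, we can evaluate it on the 100 test examples. Note that trainer.evaluate() reports the loss on that split (eval_loss); it does not compute summary-quality metrics such as ROUGE.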
trainer.evaluate()

Merging the Adapters

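As described in the LoRA section, the trained low-rank update can be added directly into the frozen weight matrices. Merging removes the separate adapter computation, so generation runs as fast as with the original model.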
# merge the adapter with the model
# this will add the adapter weight matrices to the model weight matrices
model.merge_adapter("xsum")

Generating Summaries

To spot-check the model, we'll generate summaries for the first 10 examples of the validation split and compare them with the reference summaries.
num_validation = 10

validation_dataset = load_split('validation', num_validation)

# switch off dropout for generation
model.eval()

for i in range(len(validation_dataset)):
    # load the input and label, add a batch dimension, and move them to the model's device
    input_ids = validation_dataset[i]['input_ids'].unsqueeze(0).to(model.device)
    label_ids = validation_dataset[i]['labels'].unsqueeze(0).to(model.device)
    # use the model to generate the output (XSum summaries are a single sentence)
    output = model.generate(input_ids, max_length=128)
    # convert the tokens to text
    input_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    output_text = tokenizer.decode(output[0], skip_special_tokens=True)
    label_text = tokenizer.decode(label_ids[0], skip_special_tokens=True)

    print('Input:', input_text)
    print('Output:', output_text)
    print('Label:', label_text)
    print('---')
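Reading the summaries side by side is a good sanity check, but a number is easier to compare across runs. The sketch below is a simplified, ROUGE-1-style unigram F1 written for illustration (lowercase whitespace tokens, no stemming or punctuation handling), so its scores will not match official ROUGE; for reported results, use a full implementation such as the rouge_score package.

```python
from collections import Counter

def unigram_f1(prediction, reference):
    """Simplified ROUGE-1-style F1: overlap of lowercase whitespace tokens."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    # count each shared token at most as often as it appears in both texts
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)

print(unigram_f1("the cat sat on the mat", "the cat lay on the mat"))
```

Inside the generation loop, unigram_f1(output_text, label_text) gives each example a score between 0 and 1.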

Saving the Model

Now that the model is trained, we can upload it to the Hugging Face Hub. Mine is saved to tripplyons/flan-t5-base-xsum.
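# log in to upload the model (prompts for a Hugging Face access token)
from huggingface_hub import HfApi, login
import torch

login()

api = HfApi()
# replace with your own username in order to upload
repo_id = "tripplyons/flan-t5-base-xsum"
# create the repository if it does not already exist
api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)

# save the model's weights (including the merged LoRA update) locally
torch.save(model.state_dict(), 'pytorch_model.bin')

# upload the weights file to the repository
api.upload_file(
    path_or_fileobj="pytorch_model.bin",
    path_in_repo="pytorch_model.bin",
    repo_id=repo_id,
    repo_type="model",
)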

Conclusion

Hopefully this tutorial was helpful in showing how to quickly fine-tune a model for a new task using adapters. My source code for this tutorial can be found on Google Colab.