Train a Text Summarizer in 5 Minutes
Have you ever wanted to fine-tune a large language model,
but didn't have the time or resources to train it from scratch?
In this post, I'll show you how to use adapter-transformers
to fine-tune Flan-T5 on the XSum dataset in five minutes.
I'll also show you how to use the resulting model to generate summaries for any text.
What is Fine-Tuning?
Fine-tuning is the process of training a pre-trained model on a new dataset.
This usually takes less time than training a model from scratch,
but it can still take a long time.
What is adapter-transformers?
adapter-transformers is a library for fine-tuning transformers,
based on HuggingFace's transformers library.
It adds a number of features that make it easier to fine-tune transformers,
such as LoRA (Low-Rank Adaptation), which lets you reduce the number of parameters
that need to be trained.
What is LoRA?
LoRA works by training two small matrices whose product has the same shape
as one of the pre-trained weight matrices, usually in the attention layers.
During training, the output of this low-rank product is added to the output
of the frozen pre-trained weights. Because adding the matrices themselves is
equivalent to adding their outputs, the product can later be merged directly
into the pre-trained weights.
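To make the idea concrete, here is a minimal NumPy sketch of a LoRA update; the hidden size, rank, and scaling factor are chosen for illustration rather than taken from the tutorial code below:

import numpy as np

d, r = 768, 8        # hidden size and LoRA rank (r is much smaller than d)
alpha = 16           # scaling factor applied to the low-rank update

W = np.random.randn(d, d)          # frozen pre-trained weight matrix
A = np.random.randn(r, d) * 0.01   # small trainable matrix
B = np.zeros((d, r))               # small trainable matrix, initialized to zero

# during training, only A and B are updated; the layer output is
# W @ x plus the scaled output of the low-rank product
x = np.random.randn(d)
y = W @ x + (alpha / r) * (B @ A) @ x

# after training, the low-rank product can be merged into the weights,
# which gives the same output with no extra cost at inference time
W_merged = W + (alpha / r) * (B @ A)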
The Flan-T5 Model
Flan-T5 is a series of models by Google that are based on the T5 architecture.
They are text-to-text transformers that can be used for a variety of tasks,
making them a great starting point for fine-tuning.
In this post, I'll be using flan-t5-base,
which is one of the smaller models in the series.
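As a quick aside, and separate from the fine-tuning code below, here is a minimal sketch of how flan-t5-base can be prompted directly through the standard transformers API; the translation prompt is just one illustration of the text-to-text format:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")

# any instruction works; here we ask for a translation
inputs = tokenizer("Translate to German: The house is small.", return_tensors="pt")
outputs = model.generate(**inputs, max_length=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))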
The XSum Dataset
XSum stands for Extreme Summarization and is a dataset of BBC news articles
paired with one-sentence summaries.
In this post, I'll be using a small subset of the XSum dataset:
1,000 training documents and their summaries,
plus smaller test and validation subsets.
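To get a feel for the data, here is a small sketch (using the same datasets library installed in the next section) that loads the training split and prints one example; the column names match those used in the tokenization code later on:

from datasets import load_dataset

# load the training split and look at a single example
dataset = load_dataset("xsum", split="train")
example = dataset[0]

print(example["document"][:300])  # the start of the full article
print(example["summary"])         # the one-sentence reference summary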
Environment Setup
I'll be using Google Colab for this post,
but you can also run it locally if you have a GPU and the necessary libraries installed.
To install the libraries that aren't installed in Google Colab, you'll have to run the following commands:
pip install -U adapter-transformers sentencepiece
pip install datasets
Tokenizing the Dataset
In order to use the Flan-T5 model, we'll have to convert the documents and summaries
into tokens that the model can understand.
To do this, we'll use the AutoTokenizer class from the transformers library.
We'll also prepend a prefix to each document, which the model uses to determine
what task it should perform. In this case, the prefix is "summarize: ",
since that is the format Flan-T5 expects for summarization.
from transformers import AutoTokenizer

# the base model that we'll be using
base_model = "google/flan-t5-base"

# the tokenizer that we'll be using
tokenizer = AutoTokenizer.from_pretrained(base_model)

# the prefix that we'll be using
prefix = 'summarize: '

# tokenize the dataset
def encode_batch(examples):
    # the name of the input column
    text_column = 'document'
    # the name of the target column
    summary_column = 'summary'
    # used to format the tokens
    padding = "max_length"

    # convert to lists of strings
    inputs, targets = [], []
    for i in range(len(examples[text_column])):
        if examples[text_column][i] and examples[summary_column][i]:
            inputs.append(examples[text_column][i])
            targets.append(examples[summary_column][i])

    # add prefix to inputs
    inputs = [prefix + inp for inp in inputs]

    # finally we can tokenize the inputs and targets
    model_inputs = tokenizer(inputs, max_length=512, padding=padding, truncation=True)
    labels = tokenizer(targets, max_length=512, padding=padding, truncation=True)

    # rename to labels for training
    model_inputs["labels"] = labels["input_ids"]

    return model_inputs
Loading the Dataset
We'll be using the datasets library to load XSum,
taking only a small subset of each split.
from datasets import load_dataset

# load and tokenize a subset of one split of the dataset
def load_split(split_name, max_items):
    # load the split
    dataset = load_dataset("xsum")[split_name]

    # only use the first max_items items
    dataset = dataset.filter(lambda _, idx: idx < max_items, with_indices=True)

    # tokenize the dataset
    dataset = dataset.map(
        encode_batch,
        batched=True,
        remove_columns=dataset.column_names,
        desc="Running tokenizer on " + split_name + " dataset",
    )

    # set the format to torch
    dataset.set_format(type="torch", columns=["input_ids", "labels"])

    return dataset
Preparing the Model
Now that we have the dataset, we can prepare the model for training.
from transformers import AutoModelForSeq2SeqLM
from transformers.adapters import LoRAConfig
import numpy as np

# start with the pretrained base model
model = AutoModelForSeq2SeqLM.from_pretrained(
    base_model
)

# set the parameters for LoRA
config = LoRAConfig(
    r=8,
    alpha=16,
    # use it on all of the layers
    intermediate_lora=True,
    output_lora=True
)

# make a new adapter for the XSum dataset
model.add_adapter("xsum", config=config)

# enable the adapter for training
model.train_adapter("xsum")
model.set_active_adapters(["xsum"])
Preparing the Trainer
Now that we have the model, we can prepare the trainer.
from transformers import TrainingArguments, AdapterTrainer, TrainerCallback

# small batch size to fit in memory
batch_size = 1

training_args = TrainingArguments(
    learning_rate=3e-4,
    num_train_epochs=1,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    logging_steps=200,
    output_dir="./training_output",
    overwrite_output_dir=True,
    remove_unused_columns=False
)

# create the trainer
trainer = AdapterTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    # load the dataset
    train_dataset=load_split("train", 1000),
    eval_dataset=load_split("test", 100),
)
Training the Model
Now that we have the trainer, we can train the model.
This will take a while, depending on your hardware.
trainer.train()
Evaluating the Model
Now that we have the trained model, we can evaluate its performance on the test split.
trainer.evaluate()
Merging the Adapters
As I mentioned earlier, we can merge the low-rank matrices into the pre-trained weights.
# merge the adapter with the model
# this will add the adapter weight matrices to the model weight matrices
model.merge_adapter("xsum")
Generating Summaries
We'll generate summaries for a few examples from the validation split
and compare them with the reference summaries.
num_validation = 10
validation_dataset = load_split('validation', num_validation)

for i in range(num_validation):
    # load the input and label
    input_ids = validation_dataset[i]['input_ids'].unsqueeze(0).to(0)
    label_ids = validation_dataset[i]['labels'].unsqueeze(0).to(0)

    # use the model to generate the output
    output = model.generate(input_ids, max_length=1024)

    # convert the tokens to text
    input_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    output_text = tokenizer.decode(output[0], skip_special_tokens=True)
    label_text = tokenizer.decode(label_ids[0], skip_special_tokens=True)

    print('Input:', input_text)
    print('Output:', output_text)
    print('Label:', label_text)
    print('---')
Saving the Model
Now that we have the trained model, we can save it and upload it to the Hugging Face Hub. Mine is saved to tripplyons/flan-t5-base-xsum.
# login to upload the model
from huggingface_hub import login
login()

from huggingface_hub import HfApi
import torch

api = HfApi()

# save the model weights locally
torch.save(model.state_dict(), 'pytorch_model.bin')

# upload the weights to the Hub
api.upload_file(
    path_or_fileobj="pytorch_model.bin",
    path_in_repo="pytorch_model.bin",
    # replace with your own username in order to upload
    repo_id="tripplyons/flan-t5-base-xsum",
    repo_type="model",
)
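If you want to load the uploaded weights back later, one option is to download the file and load it into a fresh copy of flan-t5-base. This is only a sketch, assuming the repository contains just the pytorch_model.bin uploaded above; strict=False skips any adapter-specific keys left in the saved state dict:

import torch
from huggingface_hub import hf_hub_download
from transformers import AutoModelForSeq2SeqLM

# download the merged weights (replace the repo_id with your own)
weights_path = hf_hub_download(
    repo_id="tripplyons/flan-t5-base-xsum",
    filename="pytorch_model.bin",
)

# load them into a fresh copy of the base model
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
model.load_state_dict(torch.load(weights_path), strict=False)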
Conclusion
Hopefully this tutorial was helpful in showing how to quickly fine-tune a model
for a new task using adapters. My source code for this tutorial can be found
on Google Colab.