Fine-tuning the Donut transformer for document classification
Document classification is a machine learning problem in which, given a document file as input, a model outputs its class. This task plays an important role at Qantev, where automating insurance claims requires correctly classifying all documents and extracting relevant information from them. The most naive approach is to treat this as a plain image classification problem and use a CNN to determine the class of the document. However, with this basic approach, we lose the richness a document carries in its text. To exploit both visual and textual information, we need to run an Optical Character Recognition (OCR) model as a pre-processing module and feed the extracted text features to the classifier as well.
Once we have both the image and the textual information, there are multiple multimodal architectures, such as LayoutLM [1], BROS [2] and DocFormer [3], that were developed for document information extraction but can also be used to tackle document classification. We will call these architectures OCR-based document classification. Their main drawback is that it is time consuming to first run OCR and only then perform the classification. Donut [4] offers a way to take the textual information of a document into account without running an OCR model at all. In this blog article, we will show how to apply Donut to train a document classification model.
Donut, short for Document Understanding Transformer, is an encoder-decoder, transformer-based architecture that can perform three tasks: classification, parsing and document visual question answering. Donut is an OCR-free model that excels at recognizing and interpreting text, tables and layout elements, taking only the image itself as input. It has a simple architecture that is highly effective in terms of memory usage and time cost [4]. Donut's architecture consists of a Swin Transformer [5] visual encoder that produces a set of embeddings from patches of the document image. These embeddings are then fed to a BART textual decoder that outputs a sequence of tokens, which are converted to a structured format that depends on the desired task (Fig. 1). Donut was pre-trained on millions of synthetic documents to make it applicable to documents in different languages, such as English and Chinese.
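For readers who want to follow along in code, the pre-trained Donut checkpoint can be loaded directly from the Hugging Face Hub. The snippet below is a minimal sketch, assuming the publicly released `naver-clova-ix/donut-base` checkpoint and the `transformers` library, and simply exposes the two components described above:

```python
# Minimal sketch: load the public Donut base checkpoint and inspect its parts.
from transformers import DonutProcessor, VisionEncoderDecoderModel

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")

print(type(model.encoder).__name__)  # Swin Transformer visual encoder (DonutSwinModel)
print(type(model.decoder).__name__)  # BART-style textual decoder (MBartForCausalLM)
print(len(processor.tokenizer))      # size of the decoder vocabulary
```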
For the document classification problem, the Donut paper reports performance similar to OCR-based models, with a much lower inference time and without the extra step of running OCR.
How to fine-tune Donut?
In order to find the most effective way of implementing a model for document classification, our AI team explored different ways of fine-tuning a Donut model. It is possible to import a pre-trained Donut model to be fine-tuned [4]. At first, we used the approach implemented by Niels Rogge [6], which keeps the full encoder-decoder structure and adds the classes as individual tokens at the end of the vocabulary. This is the most common method of classification with Donut found on the web. Donut's decoder has 57,525 tokens in its vocabulary by default, and we added all the document class names to the decoder vocabulary as new tokens. Since we used the RVL-CDIP dataset, this meant adding 16 new tokens.
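In code, this approach roughly amounts to registering one new token per class and resizing the decoder's embedding matrix accordingly. The snippet below is a simplified sketch; the class-token names are illustrative, and the full recipe, including the prompt format and training loop, is in Niels Rogge's notebook [6]:

```python
from transformers import DonutProcessor, VisionEncoderDecoderModel

processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")

# One new token per RVL-CDIP class (16 in total); names here are illustrative
# and may differ from those used in the reference notebook.
class_tokens = ["<letter/>", "<form/>", "<email/>", "<invoice/>"]  # ... 12 more
processor.tokenizer.add_tokens(class_tokens)

# Grow the decoder's embedding matrix so the new tokens can be predicted.
model.decoder.resize_token_embeddings(len(processor.tokenizer))
```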
Although this is the most popular technique to fine-tune Donut, we believe it is not the most efficient in terms of computational resources. By adding the classes as new tokens, we effectively discard all the tokens already learned in the decoder vocabulary and only ever predict one of the 16 tokens we added. The full decoder is therefore being used as a very large classification head.
In this blog post, we propose to simply use the encoder as a feature extractor and add a classification head on top of it, which is far more efficient in both parameters and inference time, since we no longer need the decoder's autoregressive prediction. Later, we compare the results of both approaches. A schematic of the model can be seen in the figure below: it receives a document image and outputs the class as a one-hot-encoded vector.
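The sketch below illustrates this idea under the same assumptions as before (the `naver-clova-ix/donut-base` checkpoint): we keep only the Swin encoder, mean-pool its patch embeddings, and add a single linear layer as the classification head. Other pooling strategies or a deeper head would work just as well; this is only one simple instantiation.

```python
import torch
import torch.nn as nn
from transformers import DonutProcessor, VisionEncoderDecoderModel


class DonutEncoderClassifier(nn.Module):
    """Donut's Swin encoder followed by a linear classification head."""

    def __init__(self, num_classes: int = 16):
        super().__init__()
        donut = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base")
        self.encoder = donut.encoder          # DonutSwinModel; the decoder is discarded
        hidden_size = self.encoder.config.hidden_size
        self.head = nn.Linear(hidden_size, num_classes)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # (batch, num_patches, hidden) -> mean-pool over patches -> class logits
        patch_embeddings = self.encoder(pixel_values).last_hidden_state
        pooled = patch_embeddings.mean(dim=1)
        return self.head(pooled)


processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
model = DonutEncoderClassifier(num_classes=16)
# Usage (illustrative):
#   pixel_values = processor(image, return_tensors="pt").pixel_values
#   logits = model(pixel_values)
```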
Results
Inspired by the experiment in reference [6], we used data from the RVL-CDIP dataset, a well-known dataset of scanned documents used for document image classification tasks, to carry out our experiments. As in our reference notebook, we used a subset containing only 10 images for each of the 16 classes. We trained both models for 20 epochs with a learning rate of 1.0e-5 on this small subset of RVL-CDIP, and measured the accuracy and the average inference time on a T4 GPU on Google Colab. The links to reproduce the experiments are available at the end of this blog post.
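For the encoder-only model, the training loop is the standard supervised image-classification recipe with the hyperparameters above. The condensed sketch below assumes a PyTorch `DataLoader` named `train_loader` that yields pre-processed pixel values and integer labels, and uses AdamW as an illustrative optimizer choice:

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = DonutEncoderClassifier(num_classes=16).to(device)  # class defined above
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
criterion = nn.CrossEntropyLoss()

for epoch in range(20):
    model.train()
    for pixel_values, labels in train_loader:  # assumed DataLoader over the RVL-CDIP subset
        pixel_values, labels = pixel_values.to(device), labels.to(device)
        logits = model(pixel_values)
        loss = criterion(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```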
As expected, the Encoder-only Donut achieved accuracy similar to the Donut + Add Tokens model in this benchmark experiment. On the other hand, its average inference time was significantly shorter, approximately ten times lower. This result was anticipated, as the Encoder-only Donut is a much smaller model, lacking the extensive structure and autoregressive computations required by the decoder. Note that the accuracy of both models is low because these are toy models trained for only a few epochs on a small chunk of the dataset, without any hyperparameter tuning; this was only meant as a quick sanity check. Even so, both models perform far better than a random guess of 1/16 = 6.25%.
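The inference-time comparison comes down to a simple wall-clock measurement around the forward pass (or around `generate` for the decoder-based model). A sketch of such a measurement, where the explicit synchronization matters because CUDA kernels run asynchronously:

```python
import time
import torch


@torch.no_grad()
def average_inference_time(model, pixel_values, n_runs: int = 50) -> float:
    """Average time in seconds per forward pass over n_runs on the current device."""
    model.eval()
    model(pixel_values)                      # warm-up run
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n_runs):
        model(pixel_values)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return (time.perf_counter() - start) / n_runs
```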
Our findings suggest that the encoder-only approach is not only viable but also more efficient in terms of speed, ultimately leading to more efficient and effective solutions for our clients. On Qantev's internal datasets, with far more data and proper hyperparameter tuning, both models again performed similarly, with the encoder-only variant slightly ahead of its encoder-decoder counterpart. The inference-time pattern remained the same: the Encoder-only Donut is much faster. These results are intriguing, since most of the materials and tutorials on the internet that explain how to use Donut for classification rely on the entire model (encoder and decoder).
Link to notebooks:
Here are the notebooks used for the experiments. We include a copy of Niels Rogge's notebook, used to compute the metrics of the Donut + Add Tokens approach, and a link to our own notebook, which shows how to use the Encoder-only Donut to classify documents and computes the metrics used in this article.
Copy of Niels Rogge’s Notebook:
Encoder-Only Donut Notebook:
References:
[1] Xu, Y., Li, M., Cui, L., Huang, S., Wei, F., & Zhou, M. (2020). LayoutLM: Pre-training of Text and Layout for Document Image Understanding. arXiv. https://arxiv.org/abs/1912.13318
[2] Hong, T., Kim, D., Ji, M., Hwang, W., Nam, D., & Park, S. (2022). BROS: A Pre-trained Language Model Focusing on Text and Layout for Better Key Information Extraction from Documents. arXiv. https://arxiv.org/abs/2108.04539
[3] Appalaraju, S., Jasani, B., Kota, B. U., Xie, Y., & Manmatha, R. (2021). DocFormer: End-to-End Transformer for Document Understanding. arXiv. https://arxiv.org/abs/2106.11539
[4] Kim, G., et al. (2022). OCR-free Document Understanding Transformer. arXiv. https://arxiv.org/abs/2111.15664
[5] Hugging Face. Donut SWIN Model. Retrieved from https://huggingface.co/docs/transformers/model_doc/donut#transformers.DonutSwinModel
[6] Rogge, N. Fine-tune Donut on toy RVL-CDIP (document image classification) [Notebook]. GitHub. Retrieved July 18, 2024, from https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Donut/RVL-CDIP/Fine_tune_Donut_on_toy_RVL_CDIP_(document_image_classification).ipynb