Hello
I’m also struggling with TrOCR fine-tuning at the moment; I’ve been searching the web and asking AI tools, with no success so far.
I train models for historical sources/languages and was quite successful until a few months ago. For example, I trained a German Kurrent model (dh-unibe/trocr-kurrent-XVI-XVII on Hugging Face, test CER 0.05416, transformers 4.26.0) and a Middle Latin one (dh-unibe/trocr-essoins-middle-latin on Hugging Face, test CER 0.0622, transformers 4.44.2).
But for some weeks now, when I try to train/fine-tune new models with my existing scripts and new data, I can’t get below 50% CER anymore. I’ve tried a lot, including reworking my scripts, without success.
So this is my generic TrOCR fine-tuning script, which I normally use with a gt.txt TSV (image_line_filename \t line_text) and a folder of line images such as img_def.
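For illustration, the gt.txt looks like this (the file names and transcriptions here are made-up examples):

```
page01_line_001.png	Anno domini millesimo quingentesimo
page01_line_002.png	tertio die mensis Augusti
```

Here is the script itself: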
```python
import argparse
import os
import csv
import glob

import pandas as pd
import torch
from evaluate import load
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrOCRProcessor,
    VisionEncoderDecoderModel,
    GenerationConfig,
    default_data_collator,
)
# create the argparse object
parser = argparse.ArgumentParser(
    prog='TrOCR training',
    description='You can train a set of image-lines with a ground truth txt.',
)
# add the arguments
parser.add_argument('--fp16', default=False, action='store_true', help='use fp16 or not')
parser.add_argument('--fp16_eval', default=False, action='store_true', help='use fp16 for eval')
parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train (default 10)')
parser.add_argument('--gt_file', type=str, required=True, help='path to the ground truth file')
parser.add_argument('--img_folder', type=str, required=True, help='path to the folder containing the images')
parser.add_argument('--base_model', type=str, default='microsoft/trocr-base-handwritten', help='name or path of the base model (default microsoft/trocr-base-handwritten)')
parser.add_argument('--trocr_processor', type=str, default='microsoft/trocr-base-handwritten', help='name or path of the trocr processor (default microsoft/trocr-base-handwritten)')
parser.add_argument('--output_dir', type=str, required=True, help='path to the output directory')
parser.add_argument('--eval_split', type=float, default=0.05, help='fraction of the data to use for evaluation (default 0.05)')
parser.add_argument('--logging_steps', type=int, default=50, help='number of steps between logging (default 50)')
parser.add_argument('--eval_steps', type=int, default=100, help='number of steps between evaluation (default 100)')
parser.add_argument('--save_steps', type=int, default=100, help='number of steps between saving (default 100)')
parser.add_argument('--warmup_steps', type=int, default=100, help='number of warmup steps (default 100)')
parser.add_argument('--save_total_limit', type=int, default=5, help='maximum number of checkpoints to save (default 5)')
parser.add_argument('--max_target_length', type=int, default=256, help='maximum target sequence length in tokens (default 256)')
args = parser.parse_args()
class TrOCRDataset(Dataset):
    def __init__(self, root_dir, df, processor, max_target_length=args.max_target_length):
        self.root_dir = root_dir
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # get file name + text
        file_name = self.df['file_name'][idx]
        text = self.df['text'][idx]
        # prepare image (i.e. normalize)
        file_path = os.path.join(self.root_dir, file_name)
        image = Image.open(file_path).convert('RGB')
        pixel_values = self.processor(image, return_tensors='pt').pixel_values
        # add labels (input_ids) by encoding the text
        labels = self.processor.tokenizer(
            text,
            padding='max_length',
            max_length=self.max_target_length,
        ).input_ids
        # important: make sure that PAD tokens are ignored by the loss function
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]
        encoding = {'pixel_values': pixel_values.squeeze(), 'labels': torch.tensor(labels)}
        return encoding
def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    # put the PAD token back in place of -100 so the labels can be decoded
    labels_ids[labels_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    return {'cer': cer}
GT_FILE_TRAIN = args.gt_file
IMAGE_FOLDER = args.img_folder
BASE_MODEL = args.base_model
TROCR_PROCESSOR = args.trocr_processor
OUTPUT_DIR = args.output_dir
df = pd.read_csv(
    GT_FILE_TRAIN,
    sep='\t',
    names=['file_name', 'text'],
    header=None,
    dtype=str,
    encoding='utf-8',
    quoting=csv.QUOTE_NONE,
)
# keep only rows whose line image actually exists in the image folder
img_lines = [os.path.basename(g) for g in glob.glob(f'{IMAGE_FOLDER}/*')]
df = df[df['file_name'].isin(img_lines)]
df = df.dropna()
df.reset_index(drop=True, inplace=True)
train_df, val_df = train_test_split(df, test_size=args.eval_split)
# we reset the indices to start from zero
train_df.reset_index(drop=True, inplace=True)
val_df.reset_index(drop=True, inplace=True)
processor = TrOCRProcessor.from_pretrained(TROCR_PROCESSOR, use_fast=True)
train_dataset = TrOCRDataset(root_dir=IMAGE_FOLDER,
                             df=train_df,
                             processor=processor)
eval_dataset = TrOCRDataset(root_dir=IMAGE_FOLDER,
                            df=val_df,
                            processor=processor)
print('Number of training examples:', len(train_dataset))
print('Number of validation examples:', len(eval_dataset))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VisionEncoderDecoderModel.from_pretrained(BASE_MODEL).to(device)
# drop the stray loss_type attribute, which otherwise triggers a user warning
if hasattr(model, 'loss_type'):
    model.__dict__.pop('loss_type', None)
# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size
# change generation config
generation_config = model.generation_config if hasattr(model, 'generation_config') else GenerationConfig()
generation_config.eos_token_id = processor.tokenizer.sep_token_id
generation_config.max_length = args.max_target_length
generation_config.early_stopping = True
generation_config.length_penalty = 2.0
generation_config.num_beams = 4
model.generation_config = generation_config
training_args = Seq2SeqTrainingArguments(
    auto_find_batch_size=True,
    do_train=True,
    do_eval=True,
    predict_with_generate=True,
    fp16=args.fp16,
    fp16_full_eval=args.fp16_eval,
    overwrite_output_dir=True,
    output_dir=OUTPUT_DIR,
    num_train_epochs=args.epochs,
    eval_strategy='steps',
    save_strategy='steps',
    logging_strategy='steps',
    logging_steps=args.logging_steps,
    save_steps=args.save_steps,
    eval_steps=args.eval_steps,
    warmup_steps=args.warmup_steps,
    save_total_limit=args.save_total_limit,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    optim='adafactor',
    load_best_model_at_end=True,
    metric_for_best_model='eval_cer',
    greater_is_better=False,
    report_to='none',
)
cer_metric = load('cer')
# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,
    processing_class=processor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
)
trainer.train()
trainer.save_model(os.path.join(OUTPUT_DIR, 'last'))
processor.save_pretrained(os.path.join(OUTPUT_DIR, 'last'))
```
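For context, I call the script roughly like `python trocr_train.py --gt_file gt.txt --img_folder img_def --output_dir out --epochs 10 --fp16` (the script name and paths here are placeholders). And to see whether the high CER comes from genuinely garbled predictions rather than from the metric computation, I also decode a few training lines with a saved checkpoint. A minimal sketch (the checkpoint and image paths are placeholders):

```python
import torch
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# load the checkpoint the training script saved to OUTPUT_DIR/last
ckpt = 'out/last'  # placeholder path
processor = TrOCRProcessor.from_pretrained(ckpt)
model = VisionEncoderDecoderModel.from_pretrained(ckpt).eval()

# decode one training line and compare it to the ground truth by eye
image = Image.open('img_def/line_0001.png').convert('RGB')  # placeholder image
pixel_values = processor(image, return_tensors='pt').pixel_values
with torch.no_grad():
    ids = model.generate(pixel_values, max_length=256, num_beams=4)
print(processor.batch_decode(ids, skip_special_tokens=True)[0])
```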
These are the setups I tried. At the moment I’m working in the following environment:
- Ubuntu 22.04
- Nvidia RTX A6000
- CUDA 12.6
And the newest Python environments (I didn’t document them too well; after the failed attempt with 3.12, I decided to try 3.10, which is why that environment has the newer package versions; a small snippet for dumping these versions follows the lists):
- Python 3.12:
  - transformers 4.53.2
  - pytorch 2.7.1
  - torchvision 0.22.1
- Python 3.10:
  - transformers 4.55.0
  - pytorch 2.8.0+cu128
  - torchvision 0.22.1
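This is just how I read those versions out (nothing special, but it makes the lists above reproducible):

```python
import torch
import torchvision
import transformers

# print the package versions active in the current environment
print('transformers', transformers.__version__)
print('torch', torch.__version__, '| CUDA', torch.version.cuda)
print('torchvision', torchvision.__version__)
```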
Any idea why I cannot get below 50% CER with those setups?
Do you have recommendations for other settings/version combinations?
With transformers==4.44.2 I get far better results…
So should I just stay with the older version? Any idea what changed?
Thanks a lot in advance!
Greetings