CUDA out of memory when training a DETR object detection model with compute_metrics

I’m training a DETR object detection model using the Trainer API, and I have created the dataset in COCO format.

But when I run the Trainer with a custom compute_metrics function, I get a “CUDA out of memory” error. I have reduced the batch size from 16 all the way down to 1, but I still get the same out-of-memory error. Here’s how I’m creating the compute_metrics function:

# Map between integer class ids and class names, following the order of train_ds.classes.
# (Uses dict(enumerate(...)) rather than a comprehension variable named `id`, which shadows the builtin.)
id2label = dict(enumerate(train_ds.classes))
label2id = {label: idx for idx, label in id2label.items()}


@dataclass
class ModelOutput:
    """Minimal container exposing ``logits`` and ``pred_boxes`` attributes.

    Used to wrap raw prediction arrays so they can be passed to
    ``image_processor.post_process_object_detection``.
    """

    logits: torch.Tensor  # class logits for each predicted box
    pred_boxes: torch.Tensor  # predicted boxes; presumably normalized (cx, cy, w, h) — confirm for your model

class MAPEvaluator:
    """Non-streaming mAP/mAR ``compute_metrics`` callable for object detection.

    Expects ``evaluation_results.predictions`` and ``evaluation_results.label_ids``
    as per-batch lists (the loops below iterate batches, then items), e.g. as
    produced with ``eval_do_concat_batches=False``. All batches are
    post-processed and then scored in a single torchmetrics
    ``MeanAveragePrecision`` pass.

    Args:
        image_processor: processor whose ``post_process_object_detection`` turns raw
            model outputs into scored, labelled, absolute ``xyxy`` boxes.
        threshold: score threshold forwarded to ``post_process_object_detection``.
        id2label: optional mapping from class id to class name, used to name the
            per-class metrics; when ``None`` the numeric class id is used.
    """

    def __init__(self, image_processor, threshold=0.00, id2label=None):
        self.image_processor = image_processor
        self.threshold = threshold
        self.id2label = id2label

    def collect_image_sizes(self, targets):
        """Collect image sizes across the dataset as list of tensors with shape [batch_size, 2]."""
        image_sizes = []
        for batch in targets:
            batch_image_sizes = torch.tensor(np.array([x["size"] for x in batch]))
            image_sizes.append(batch_image_sizes)
        return image_sizes

    def collect_targets(self, targets, image_sizes):
        """Convert per-batch targets into torchmetrics target dicts.

        Boxes are converted from centre format (``sv.xcycwh_to_xyxy``) and scaled
        by each image's (height, width), i.e. they are assumed to be normalized to
        the image size. Returns a flat list with one ``{"boxes", "labels"}`` dict
        per image.
        """
        post_processed_targets = []
        for target_batch, image_size_batch in zip(targets, image_sizes):
            for target, (height, width) in zip(target_batch, image_size_batch):
                boxes = target["boxes"]
                boxes = sv.xcycwh_to_xyxy(boxes)
                boxes = boxes * np.array([width, height, width, height])
                boxes = torch.tensor(boxes)
                labels = torch.tensor(target["class_labels"])
                post_processed_targets.append({"boxes": boxes, "labels": labels})
        return post_processed_targets

    def collect_predictions(self, predictions, image_sizes):
        """Post-process per-batch raw model outputs into torchmetrics prediction dicts.

        Position 1 of each batch is read as the class logits and position 2 as the
        predicted boxes (assumed to match the model's output order; confirm for
        your model). Returns a flat list with one dict per image.
        """
        post_processed_predictions = []
        for batch, target_sizes in zip(predictions, image_sizes):
            batch_logits, batch_boxes = batch[1], batch[2]
            output = ModelOutput(logits=torch.tensor(batch_logits), pred_boxes=torch.tensor(batch_boxes))
            post_processed_output = self.image_processor.post_process_object_detection(
                output, threshold=self.threshold, target_sizes=target_sizes
            )
            post_processed_predictions.extend(post_processed_output)
        return post_processed_predictions

    def _format_metrics(self, metrics):
        """Flatten torchmetrics output into ``{name: float rounded to 4 places}``.

        The per-class tensors ``classes``/``map_per_class``/``mar_100_per_class``
        are replaced by one ``map_<class>`` and one ``mar_100_<class>`` entry per
        class, named through ``self.id2label`` when it is set.
        """
        metrics = dict(metrics)  # work on a copy; don't mutate the caller's dict
        # flatten() guards against 0-d tensors (e.g. when only one class is present),
        # which cannot be iterated.
        classes = metrics.pop("classes").flatten()
        map_per_class = metrics.pop("map_per_class").flatten()
        mar_100_per_class = metrics.pop("mar_100_per_class").flatten()
        for class_id, class_map, class_mar in zip(classes, map_per_class, mar_100_per_class):
            class_name = self.id2label[class_id.item()] if self.id2label is not None else class_id.item()
            metrics[f"map_{class_name}"] = class_map
            metrics[f"mar_100_{class_name}"] = class_mar
        return {k: round(v.item(), 4) for k, v in metrics.items()}

    @torch.no_grad()
    def __call__(self, evaluation_results):
        """Score all evaluation predictions and return a flat dict of rounded metrics."""
        predictions, targets = evaluation_results.predictions, evaluation_results.label_ids

        image_sizes = self.collect_image_sizes(targets)
        post_processed_targets = self.collect_targets(targets, image_sizes)
        post_processed_predictions = self.collect_predictions(predictions, image_sizes)

        evaluator = MeanAveragePrecision(box_format="xyxy", class_metrics=True)
        evaluator.warn_on_many_detections = False
        evaluator.update(post_processed_predictions, post_processed_targets)

        return self._format_metrics(evaluator.compute())

# Non-batched evaluator: score threshold 0.01 is passed to post-processing,
# and per-class metric names come from id2label.
eval_compute_metrics_fn = MAPEvaluator(
    image_processor=processor,
    threshold=0.01,
    id2label=id2label,
)

Below is the training setup for the custom dataset:

training_args = TrainingArguments(
    output_dir="Malaria-finetune",
    report_to="none",
    num_train_epochs=10,
    max_grad_norm=0.1,
    learning_rate=5e-5,
    warmup_steps=300,
    # Only affects training; evaluation uses per_device_eval_batch_size.
    per_device_train_batch_size=1,
    dataloader_num_workers=2,
    metric_for_best_model="eval_map",
    greater_is_better=True,
    load_best_model_at_end=True,
    eval_strategy="epoch",
    save_strategy="epoch",
    save_total_limit=2,
    remove_unused_columns=False,
    eval_do_concat_batches=False,
    # Move accumulated eval outputs to the CPU after every prediction step instead of
    # keeping the whole evaluation set's outputs on the GPU until compute_metrics runs.
    eval_accumulation_steps=1,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=pytorch_dataset_train,
    eval_dataset=pytorch_dataset_valid,
    processing_class=processor,
    data_collator=collate_fn,
    compute_metrics=eval_compute_metrics_fn,
)
trainer.train()

I have looked at a few discussions about similar issues:

I have tried all of the suggestions there, but I wasn’t able to fix the problem with this DETR model.

From my research, passing preprocess_logits_for_metrics to the Trainer might help. I implemented it like this:

def preprocess_logits_for_metrics(logits, labels):
    """
    Trim one eval batch's model outputs down to what compute_metrics reads.

    Trainer calls this on every evaluation step and accumulates the returned
    value instead of the raw outputs, so dropping unused entries here is what
    saves memory. MAPEvaluator reads position 1 (class logits) and position 2
    (predicted boxes) of each batch, so the first three entries are kept in
    place and everything after them is discarded. For DETR, the discarded
    entries are presumably the large decoder/encoder hidden states; confirm
    this for your model.

    Args:
        logits: model outputs for one eval batch, either a tuple/list of outputs
            or a single tensor.
        labels: unused; part of the signature Trainer calls this with.

    Returns:
        The first three outputs as a tuple, or ``logits`` unchanged when it is
        not a tuple/list.
    """
    if isinstance(logits, (tuple, list)):
        # Keep the positions intact so batch[1] / batch[2] still index logits / boxes.
        return tuple(logits[:3])
    return logits

and passed it to the Trainer:

# Same Trainer setup as before, but also passing preprocess_logits_for_metrics,
# which Trainer applies to each eval batch's outputs before accumulating them.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    processing_class=processor,
    train_dataset=pytorch_dataset_train,
    eval_dataset=pytorch_dataset_valid,
    compute_metrics=eval_compute_metrics_fn,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
trainer.train()

This didn’t work either. Please help.

1 Like

I am having a similar issue with RT-DETR. I found your post when searching online. Did you ever solve this?

1 Like

This seems to have been resolved on GitHub…?

I’m considering hiring a custom app development company for a mid-sized fintech app. I came across TekRevol and liked their portfolio and focus on digital transformation.

Before jumping in, what should I look out for when partnering with a software development firm?

  • Contracts and IP?
  • Tech flexibility?
  • Project management transparency?

This was helpful. I implemented my own solution, similar to the one on GitHub, and so far I haven’t hit any errors. I’m posting it here in case someone else finds it useful, or in case someone spots bugs I missed — either way, please let me know. Thanks!

@dataclass
class ModelOutput:
    """Minimal container exposing ``logits`` and ``pred_boxes`` attributes.

    Used to wrap raw prediction tensors so they can be passed to
    ``image_processor.post_process_object_detection``.
    """

    logits: torch.Tensor  # class logits for each predicted box
    pred_boxes: torch.Tensor  # predicted boxes; presumably normalized (cx, cy, w, h) — confirm for your model

class BatchMAPEvaluator:
    """Streaming mAP/mAR ``compute_metrics`` for ``TrainingArguments(batch_eval_metrics=True)``.

    Each eval batch is post-processed and folded into a single torchmetrics
    ``MeanAveragePrecision`` instance, so predictions for the whole evaluation
    set never have to be held in memory at once.

    Args:
        image_processor: processor whose ``post_process_object_detection`` turns raw
            model outputs into scored, labelled, absolute ``xyxy`` boxes.
        threshold: score threshold forwarded to ``post_process_object_detection``.
        id2label: optional mapping from class id to class name, used to name the
            per-class metrics; when ``None`` the numeric class id is used.
    """

    def __init__(self, image_processor, threshold=0.00, id2label=None):
        self.image_processor = image_processor
        self.threshold = threshold
        self.id2label = id2label
        self.evaluator = MeanAveragePrecision(box_format="xyxy", class_metrics=True)
        self.evaluator.warn_on_many_detections = False

    def reset(self):
        """Reset the evaluator state for a new evaluation run."""
        self.evaluator.reset()

    @staticmethod
    def _to_cpu_tensor(value):
        """Return ``value`` as a CPU tensor: move tensors, convert anything else with ``torch.tensor``."""
        if isinstance(value, torch.Tensor):
            return value.cpu()
        return torch.tensor(value)

    def process_batch(self, predictions, targets):
        """
        Post-process a single eval batch and add it to the running evaluator state.

        Args:
            predictions: indexable model outputs; position 1 is read as the class
                logits and position 2 as the predicted boxes (described by the
                author as ``(loss, logits, pred_boxes)``).
            targets: list of dicts, each with "size", "boxes", "class_labels" as tensors
        """
        # One size row per image, unpacked below as (height, width).
        image_sizes = torch.stack([x["size"] for x in targets]).cpu()

        output = ModelOutput(
            logits=self._to_cpu_tensor(predictions[1]),
            pred_boxes=self._to_cpu_tensor(predictions[2]),
        )
        post_processed_predictions = self.image_processor.post_process_object_detection(
            output, threshold=self.threshold, target_sizes=image_sizes
        )

        post_processed_targets = []
        for target, (height, width) in zip(targets, image_sizes):
            boxes = target["boxes"].cpu().numpy()
            labels = target["class_labels"].cpu()
            # Convert centre-format boxes to xyxy and scale to pixels (boxes are
            # presumably normalized to the image size) to match the predictions.
            boxes = sv.xcycwh_to_xyxy(boxes)
            boxes = boxes * np.array([width.item(), height.item(), width.item(), height.item()])
            post_processed_targets.append({"boxes": torch.tensor(boxes), "labels": labels})

        self.evaluator.update(post_processed_predictions, post_processed_targets)

    def _format_metrics(self, metrics):
        """Flatten torchmetrics output into ``{name: float rounded to 4 places}``.

        The per-class tensors are replaced by one ``map_<class>`` and one
        ``mar_100_<class>`` entry per class, named through ``self.id2label``
        when it is set.
        """
        metrics = dict(metrics)  # work on a copy; don't mutate the caller's dict
        # flatten() guards against 0-d tensors (e.g. a single class), which cannot be iterated.
        classes = metrics.pop("classes").flatten()
        map_per_class = metrics.pop("map_per_class").flatten()
        mar_100_per_class = metrics.pop("mar_100_per_class").flatten()
        for class_id, class_map, class_mar in zip(classes, map_per_class, mar_100_per_class):
            class_name = self.id2label[class_id.item()] if self.id2label is not None else class_id.item()
            metrics[f"map_{class_name}"] = class_map
            metrics[f"mar_100_{class_name}"] = class_mar
        return {k: round(v.item(), 4) for k, v in metrics.items()}

    def compute(self):
        """Compute final metrics after all batches have been processed."""
        return self._format_metrics(self.evaluator.compute())

    @torch.no_grad()
    def __call__(self, evaluation_results, compute_result: bool):
        """Accumulate one eval batch; on the final call also return the metrics and reset.

        Returns an empty dict for intermediate batches.
        """
        # With batch_eval_metrics=True, Trainer passes the final batch's data together
        # with compute_result=True, so every call (including the last) must be folded
        # into the running state before metrics are computed.
        self.process_batch(evaluation_results.predictions, evaluation_results.label_ids)
        if not compute_result:
            return {}
        try:
            return self.compute()
        finally:
            # Always clear accumulated state, even if compute() raises, so the next
            # evaluation run does not re-count this run's batches.
            self.reset()

# Streaming evaluator: metric state is accumulated batch by batch instead of
# holding every eval prediction in memory.
# NOTE(review): a 0.30 score threshold discards lower-scoring detections before
# mAP is computed, which typically lowers mAP/mAR compared with a near-zero threshold.
batch_eval_compute_metrics_fn = BatchMAPEvaluator(
    image_processor=processor,
    threshold=0.30,
    id2label=id2label,
)

# batch_eval_metrics=True makes Trainer call compute_metrics once per eval batch
# with a compute_result flag, rather than once on the accumulated predictions.
training_args = TrainingArguments(
    ...,
    batch_eval_metrics=True,
)

# Register the streaming evaluator with the Trainer.
trainer = Trainer(
    ...,
    compute_metrics=batch_eval_compute_metrics_fn,
)
1 Like