
Fine-Tuning a Vision Transformer (ViT) on Custom Data

Learn how to fine-tune a pretrained Vision Transformer for image classification on your own dataset using HuggingFace Transformers and PyTorch. Covers data preparation, training with the Trainer API, evaluation, export, and practical training tips.

Mohammed Gamal · 2026-03-14 · 10 min read · Intermediate
Computer Vision · Vision Transformers · PyTorch · Deep Learning · HuggingFace

Overview

Vision Transformers (ViTs) are widely used for image classification, and starting from a pretrained checkpoint typically needs far less data and compute than training one from scratch. This tutorial walks you through fine-tuning a pretrained ViT on a custom dataset using HuggingFace's transformers library.


Prerequisites

  • Python 3.10+
  • A GPU (Colab free tier works)
  • Basic PyTorch knowledge

Install the dependencies (Trainer also needs accelerate):

pip install transformers datasets torch torchvision pillow accelerate

Step 1: Prepare Your Dataset

Organize images into folders by class:

data/
  train/
    cats/
    dogs/
  val/
    cats/
    dogs/

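Load it with HuggingFace datasets. The imagefolder builder infers labels from the class folder names and exposes the val folder as a split named validation: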

from datasets import load_dataset

dataset = load_dataset('imagefolder', data_dir='./data')
print(dataset)
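Before loading, it can help to sanity-check the directory tree. The stdlib-only sketch below imitates how a folder-based loader turns this layout into labeled splits; it is not the datasets implementation. It assumes label ids follow the alphabetical order of class names and that a val folder maps to a validation split; the scan_image_folder helper and the SPLIT_ALIASES table are hypothetical.

```python
from pathlib import Path

# Hypothetical folder-name -> split-name table (assumption, mirrors the val -> validation mapping)
SPLIT_ALIASES = {"train": "train", "val": "validation", "valid": "validation",
                 "validation": "validation", "test": "test"}
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".webp"}

def scan_image_folder(root):
    """Return (class_names, {split: [(relative_path, label_id), ...]}) for root/<split>/<class>/<file>."""
    root = Path(root)
    split_dirs = [d for d in sorted(root.iterdir()) if d.is_dir() and d.name in SPLIT_ALIASES]
    # Gather class names across all splits, then sort them so label ids are stable
    class_names = sorted({c.name for s in split_dirs for c in s.iterdir() if c.is_dir()})
    label_of = {name: i for i, name in enumerate(class_names)}
    splits = {}
    for split_dir in split_dirs:
        entries = []
        for class_dir in sorted(c for c in split_dir.iterdir() if c.is_dir()):
            for f in sorted(class_dir.iterdir()):
                # Skip anything that does not look like an image file
                if f.suffix.lower() in IMAGE_EXTS:
                    entries.append((str(f.relative_to(root)), label_of[class_dir.name]))
        splits.setdefault(SPLIT_ALIASES[split_dir.name], []).extend(entries)
    return class_names, splits
```

For the layout above, scan_image_folder('./data') should report the class names ['cats', 'dogs'] and the splits train and validation, which is what print(dataset) is expected to show.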

Step 2: Load the Pretrained Model

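from transformers import ViTForImageClassification, ViTImageProcessor

model_name = 'google/vit-base-patch16-224'
processor = ViTImageProcessor.from_pretrained(model_name)

# Take the class names from the folder structure so labels stay in sync
labels = dataset['train'].features['label'].names
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=len(labels),
    id2label={i: name for i, name in enumerate(labels)},
    label2id={name: i for i, name in enumerate(labels)},
    ignore_mismatched_sizes=True,  # replace the 1000-class ImageNet head with a new one
)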
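ignore_mismatched_sizes=True is needed because this checkpoint ships with a 1000-class ImageNet classification head that cannot be loaded into a head with a different number of outputs, so the head is freshly initialized while the encoder weights are reused. The arithmetic below shows what the model sees and how small the new head is, assuming ViT-Base's hidden size of 768 and the 16×16 patches at 224×224 resolution named in the checkpoint; the helper names are illustrative.

```python
def vit_sequence_length(image_size=224, patch_size=16):
    # Non-overlapping patches per side, squared, plus one [CLS] token
    patches_per_side = image_size // patch_size
    return patches_per_side ** 2 + 1

def linear_head_params(hidden_size, num_labels):
    # Weight matrix (hidden_size x num_labels) plus one bias per label
    return hidden_size * num_labels + num_labels

HIDDEN = 768  # ViT-Base hidden size (assumed for this checkpoint)
print(vit_sequence_length())             # 197 tokens: 14 * 14 patches + [CLS]
print(linear_head_params(HIDDEN, 1000))  # 769000 parameters in the ImageNet head
print(linear_head_params(HIDDEN, 2))     # 1538 parameters in the new 2-class head
```

Only the small head starts from random weights, which is one reason a low learning rate such as 2e-5 is a sensible default for fine-tuning.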

Step 3: Preprocess Images

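def transform(batch):
    # Convert to RGB so grayscale or RGBA files also get 3 channels
    images = [img.convert('RGB') for img in batch['image']]
    # Resize and normalize into a 'pixel_values' tensor
    inputs = processor(images, return_tensors='pt')
    inputs['labels'] = batch['label']
    return inputs

# Applied lazily, each time examples are accessed
dataset = dataset.with_transform(transform)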
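Besides resizing each image to 224×224, ViTImageProcessor rescales pixel values to [0, 1] and then normalizes each channel with the processor's image_mean and image_std. The plain-Python sketch below reproduces that per-value arithmetic, assuming a mean and standard deviation of 0.5 for every channel; check processor.image_mean and processor.image_std to confirm the values for your checkpoint. normalize_pixel is a hypothetical helper, not part of the library.

```python
def normalize_pixel(value, mean=0.5, std=0.5, rescale_factor=1 / 255):
    """Map a raw 0-255 channel value through a rescale-then-normalize step."""
    scaled = value * rescale_factor  # rescale: [0, 255] -> [0, 1]
    return (scaled - mean) / std     # normalize: [0, 1] -> [-1, 1] when mean = std = 0.5

for raw in (0, 128, 255):
    print(raw, round(normalize_pixel(raw), 4))
```

With these assumed statistics, black maps to -1 and white to 1, so a custom inference pipeline must apply the same processor settings that were used during training.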

Step 4: Train

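Because the transform runs lazily, the Trainer must keep the raw image column (remove_unused_columns=False), needs a collator to batch the examples, and needs a metric function to report accuracy:

import numpy as np
import torch
from transformers import TrainingArguments, Trainer

def collate_fn(examples):
    # Stack per-example tensors into one batch
    return {
        'pixel_values': torch.stack([ex['pixel_values'] for ex in examples]),
        'labels': torch.tensor([ex['labels'] for ex in examples]),
    }

def compute_metrics(eval_pred):
    # eval_pred holds (logits, label_ids) as NumPy arrays
    logits, labels = eval_pred
    return {'accuracy': float((np.argmax(logits, axis=-1) == labels).mean())}

args = TrainingArguments(
    output_dir='./vit-finetuned',
    num_train_epochs=5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    eval_strategy='epoch',  # called evaluation_strategy in older transformers releases
    save_strategy='epoch',
    learning_rate=2e-5,
    load_best_model_at_end=True,
    metric_for_best_model='accuracy',
    remove_unused_columns=False,  # keep 'image' so the transform can read it
)

trainer = Trainer(
    model=model,
    args=args,
    data_collator=collate_fn,
    train_dataset=dataset['train'],
    eval_dataset=dataset['validation'],  # imagefolder names the val/ folder 'validation'
    compute_metrics=compute_metrics,
)

trainer.train()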
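With these arguments, Trainer uses its default linear learning-rate schedule with no warmup: the rate starts at learning_rate and decays linearly to zero over all optimizer steps. The sketch below estimates the step count and the rate at a given step, assuming a single device, no gradient accumulation, and that the last partial batch of each epoch is kept; total_training_steps, linear_lr, and the 1,000-image dataset size are illustrative assumptions.

```python
import math

def total_training_steps(num_examples, batch_size=16, epochs=5):
    # One optimizer step per batch; a final partial batch still counts
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return steps_per_epoch * epochs

def linear_lr(step, total_steps, base_lr=2e-5):
    # Linear decay from base_lr at step 0 to 0 at total_steps (no warmup)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps))

total = total_training_steps(1000)  # hypothetical: 1,000 training images
print(total)                        # 63 steps per epoch * 5 epochs = 315
print(linear_lr(0, total), linear_lr(total // 2, total), linear_lr(total, total))
```

On small datasets the total step count is modest, so the rate falls quickly; raising num_train_epochs stretches the decay over more steps.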

Step 5: Evaluate and Export

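# Evaluate the best checkpoint (restored by load_best_model_at_end) on the validation split
metrics = trainer.evaluate()
print(f"Accuracy: {metrics['eval_accuracy']:.2%}")

# Save weights, config (including label names) and preprocessing settings together
model.save_pretrained('./vit-finetuned')
processor.save_pretrained('./vit-finetuned')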
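The exported folder can be reloaded with from_pretrained, and the model then returns one logit per class for each image. Turning those logits into a prediction amounts to a softmax, an argmax, and an id2label lookup, sketched below in plain Python; the logit values and the id2label dictionary are made-up stand-ins for outputs.logits and model.config.id2label.

```python
import math

def softmax(logits):
    # Subtract the max for numerical stability before exponentiating
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def predict(logits, id2label):
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)  # argmax
    return id2label[best], probs[best]

# Hypothetical values standing in for outputs.logits[0] and model.config.id2label
example_logits = [2.1, -0.4]
example_id2label = {0: "cats", 1: "dogs"}
label, confidence = predict(example_logits, example_id2label)
print(f"{label} ({confidence:.1%})")
```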

Tips

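  • Use data augmentation (random crops, horizontal flips) on the training split, especially for small datasets
  • Lower the learning rate for larger models
  • Enable mixed precision with fp16=True in TrainingArguments for faster training on a CUDA GPU
  • Track experiments with Weights & Biases by setting report_to='wandb' in TrainingArguments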
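In practice, random crops and flips are usually done with torchvision.transforms (for example RandomResizedCrop and RandomHorizontalFlip) inside the training transform only, leaving validation images untouched. To show what these two operations do without extra dependencies, the sketch below implements them on an image stored as a nested list of pixel rows; that representation and the helper names are assumptions for illustration, not how datasets stores images.

```python
import random

def horizontal_flip(image, p=0.5, rng=random):
    # Mirror every row left-to-right with probability p
    if rng.random() < p:
        return [row[::-1] for row in image]
    return [row[:] for row in image]

def random_crop(image, height, width, rng=random):
    # Pick a random top-left corner so the crop fits inside the image
    top = rng.randint(0, len(image) - height)
    left = rng.randint(0, len(image[0]) - width)
    return [row[left:left + width] for row in image[top:top + height]]

def augment(image, crop_size, rng=random):
    # Crop first, then maybe flip, as a training-only augmentation
    return horizontal_flip(random_crop(image, crop_size, crop_size, rng), rng=rng)

img = [[r * 4 + c for c in range(4)] for r in range(4)]  # toy 4x4 "image"
print(augment(img, 3, random.Random(0)))
```

Passing a seeded random.Random makes the augmentation reproducible, which is handy when debugging a transform.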
