Overview
Vision Transformers (ViTs) have become a go-to architecture for image classification. This tutorial walks you through fine-tuning a pretrained ViT on a custom dataset using HuggingFace's transformers library.
Prerequisites
- Python 3.10+
- A GPU (Colab free tier works)
- Basic PyTorch knowledge
pip install transformers datasets torch torchvision pillow
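To confirm the environment is ready, a quick check (this tutorial assumes a transformers release recent enough to support the eval_strategy argument used in Step 4, roughly v4.41 or newer):

import torch
import transformers

print(transformers.__version__, torch.__version__)
print('GPU available:', torch.cuda.is_available())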
Step 1: Prepare Your Dataset
Organize images into folders by class. Name the split folders train and validation so the datasets library recognizes them automatically (a folder named val may not map cleanly onto the split key you reference later):

data/
  train/
    cats/
    dogs/
  validation/
    cats/
    dogs/
Load with HuggingFace datasets:
from datasets import load_dataset
dataset = load_dataset('imagefolder', data_dir='./data')
print(dataset)
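The imagefolder builder infers class labels from the directory names. A quick way to inspect what it found, and to grab the class names for later (split keys and features here follow the layout above):

labels = dataset['train'].features['label'].names
print(labels)  # ['cats', 'dogs']
print(dataset['train'].num_rows, dataset['validation'].num_rows)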
Step 2: Load the Pretrained Model
from transformers import ViTForImageClassification, ViTImageProcessor

model_name = 'google/vit-base-patch16-224'
processor = ViTImageProcessor.from_pretrained(model_name)
model = ViTForImageClassification.from_pretrained(
    model_name,
    num_labels=2,
    # The checkpoint ships a 1000-class ImageNet head; this swaps in a fresh 2-class head
    ignore_mismatched_sizes=True
)
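Optionally, attach the human-readable class names from Step 1 to the model config, so predictions and the exported model report cats/dogs instead of LABEL_0/LABEL_1. A small sketch reusing the labels list from above:

model.config.id2label = {i: name for i, name in enumerate(labels)}
model.config.label2id = {name: i for i, name in enumerate(labels)}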
Step 3: Preprocess Images
def transform(batch):
    # Convert to RGB first: grayscale or RGBA images would otherwise trip up the processor
    images = [img.convert('RGB') for img in batch['image']]
    inputs = processor(images, return_tensors='pt')
    inputs['labels'] = batch['label']
    return inputs

dataset = dataset.with_transform(transform)
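with_transform applies the preprocessing lazily, each time an example is accessed. A quick sanity check that one example comes out in the shape the model expects:

sample = dataset['train'][0]
print(sample['pixel_values'].shape)  # should be torch.Size([3, 224, 224])
print(sample['labels'])              # an integer class id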
Step 4: Train
from transformers import TrainingArguments, Trainer
import numpy as np

args = TrainingArguments(
    output_dir='./vit-finetuned',
    num_train_epochs=5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    eval_strategy='epoch',
    save_strategy='epoch',
    learning_rate=2e-5,
    load_best_model_at_end=True,
    # Important: without this, the Trainer drops the 'image' column before our transform runs
    remove_unused_columns=False,
)

# Without compute_metrics, trainer.evaluate() reports only the loss, and the
# accuracy lookup in Step 5 would fail with a KeyError.
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return {'accuracy': (preds == labels).mean()}

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['validation'],
    compute_metrics=compute_metrics,
)
trainer.train()
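If a run is interrupted, training can resume from the most recent checkpoint in output_dir (checkpoints are written every epoch because of save_strategy='epoch'):

trainer.train(resume_from_checkpoint=True)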
Step 5: Evaluate and Export
metrics = trainer.evaluate()
print(f"Accuracy: {metrics['eval_accuracy']:.2%}")
model.save_pretrained('./vit-finetuned')
processor.save_pretrained('./vit-finetuned')
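As a final sanity check, the exported folder can be reloaded with the image-classification pipeline and run on a single image (the file path below is a placeholder; substitute one of your own images):

from transformers import pipeline

classifier = pipeline('image-classification', model='./vit-finetuned')
print(classifier('path/to/some_image.jpg'))  # e.g. [{'label': 'cats', 'score': 0.97}, ...]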
Tips
- Use data augmentation (random crops, flips) for small datasets; a sketch follows this list
- Lower the learning rate for larger models
- Use mixed precision (fp16=True) for faster training
- Monitor with Weights & Biases for experiment tracking
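For the augmentation tip, here is a minimal sketch using torchvision (already in the install line above). The train_transform function below is a variation of Step 3's transform, applied to the training split only; validation keeps the deterministic preprocessing:

from torchvision.transforms import Compose, RandomHorizontalFlip, RandomResizedCrop

# Random crops and flips run on the PIL images before the ViT processor
augment = Compose([
    RandomResizedCrop(224, scale=(0.8, 1.0)),
    RandomHorizontalFlip(),
])

def train_transform(batch):
    images = [augment(img.convert('RGB')) for img in batch['image']]
    inputs = processor(images, return_tensors='pt')
    inputs['labels'] = batch['label']
    return inputs

# Override the transform on the training split only (in-place)
dataset['train'].set_transform(train_transform)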