Beginner-friendly image classifier built with PyTorch and CIFAR-10.
An image classifier is an ML model that recognizes objects in images. We can build image classifiers by feeding tens of thousands of labelled images to a neural network. Tools like PyTorch train these networks by evaluating their performance against the dataset.
Let's build an image classifier that detects planes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks. We'll download a dataset, configure a neural network, train a model, and evaluate its performance.
A model is only as good as its dataset.
Training tools need lots of high-quality data to build accurate models. We'll use the CIFAR-10 dataset of 60,000 photos to build our image classifier. Get started by downloading the dataset with torchvision
and previewing a handful of images from it.
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
# Download the CIFAR-10 dataset to ./data
batch_size = 10
# Convert each image to a tensor, then shift/scale every RGB channel with
# mean 0.5 and std 0.5 (the preview below reverses this with `img / 2 + 0.5`).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
print("Downloading training data...")
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
print("Downloading testing data...")
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
# Our model will recognize these kinds of objects
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# Grab one batch of images from our training data.
# Use the built-in next(): DataLoader iterators in recent PyTorch releases
# no longer expose a `.next()` method.
dataiter = iter(trainloader)
images, labels = next(dataiter)
# Two rows of subplots; round the column count up so an odd batch_size
# still has a slot for every image.
columns = (batch_size + 1) // 2
for i in range(batch_size):
    # Add new subplot
    plt.subplot(2, columns, i + 1)
    # Plot the image: undo the normalization, then move the channel axis
    # last, which is the layout imshow is given here.
    img = images[i] / 2 + 0.5
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.axis('off')
    # Add the image's label
    plt.title(classes[labels[i]])
plt.suptitle('Preview of Training Data', size=20)
plt.show()
Downloading training data... Files already downloaded and verified Downloading testing data... Files already downloaded and verified
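Now that we have our dataset, we need to define a neural network in PyTorch. Our neural network will take an image as input and output one score for each of the ten object classes; the class with the highest score is the network's prediction.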
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# Define a convolutional neural network
class Net(nn.Module):
    """Small convolutional classifier: two conv/pool stages, then three linear layers.

    The first linear layer takes 16 * 5 * 5 features, which is what the two
    5x5 convolutions and 2x2 poolings produce from a 3-channel 32x32 image.

    Args:
        num_classes: Number of scores produced per image. Defaults to 10,
            matching the ten classes used in this file.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        # A single pooling layer is reused after both convolutions.
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one raw (un-normalized) score per class for each image in `x`."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch dimension before the linear layers.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # No activation on the last layer: the scores are returned as-is.
        x = self.fc3(x)
        return x
# Loss function: cross-entropy between the network's class scores and the true labels
criterion = nn.CrossEntropyLoss()
# Build the network, then hand its parameters to SGD with momentum
net = Net()
optimizer = optim.SGD(
    params=net.parameters(),
    momentum=0.9,
    lr=0.001,
)
print("Your network is ready for training!")
Your network is ready for training!
PyTorch trains our network by adjusting its parameters and evaluating its performance against our labelled dataset.
from tqdm import tqdm
# Number of full passes over the training data
EPOCHS = 2
print("Training...")
for epoch in range(EPOCHS):
    # tqdm wraps the loader only to draw a progress bar for this epoch
    for inputs, labels in tqdm(trainloader, desc=f"Epoch {epoch + 1} of {EPOCHS}", leave=True, ncols=80):
        # Clear gradients left over from the previous batch
        optimizer.zero_grad()
        # Forward pass, loss, backward pass, parameter update
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
# Save our trained model (parameters only, via state_dict)
PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)
Training...
Epoch 1 of 2: 100%|████████████████████████| 5000/5000 [00:20<00:00, 241.32it/s] Epoch 2 of 2: 100%|████████████████████████| 5000/5000 [00:20<00:00, 239.70it/s]
Let's test our model!
# Pick photos from the test set, so the model is shown images it was not
# trained on. testloader is not shuffled, so this is its first batch.
dataiter = iter(testloader)
images, labels = next(dataiter)
# Load our model
net = Net()
net.load_state_dict(torch.load(PATH))
# Put the network in evaluation mode. It has no dropout or batch-norm layers,
# so this does not change its outputs, but it is the standard step before inference.
net.eval()
# Analyze images; gradients are not needed for prediction
with torch.no_grad():
    outputs = net(images)
# The index of the highest score is the predicted class
_, predicted = torch.max(outputs, 1)
# Show results
# Two rows of subplots; round the column count up so an odd batch_size
# still has a slot for every image.
columns = (batch_size + 1) // 2
for i in range(batch_size):
    # Add new subplot
    plt.subplot(2, columns, i + 1)
    # Plot the image: undo the normalization, then move the channel axis last
    img = images[i] / 2 + 0.5
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.axis('off')
    # Add the predicted label: green when correct, red and parenthesized when wrong
    color = "green"
    label = classes[predicted[i]]
    if labels[i] != predicted[i]:
        color = "red"
        label = "(" + label + ")"
    plt.title(label, color=color)
plt.suptitle('Objects Found by Model', size=20)
plt.show()
Let's conclude by evaluating our model's overall performance.
# Tally, for every class, how many test images it has and how many of
# those the model labelled correctly
correct_pred = dict.fromkeys(classes, 0)
total_pred = dict.fromkeys(classes, 0)
with torch.no_grad():
    for images, labels in testloader:
        outputs = net(images)
        _, predictions = torch.max(outputs, dim=1)
        # Walk the batch one (true label, predicted label) pair at a time
        for label, prediction in zip(labels, predictions):
            classname = classes[label]
            total_pred[classname] += 1
            if label == prediction:
                correct_pred[classname] += 1
# Report the percentage of correct predictions per class
for classname, total_count in total_pred.items():
    accuracy = 100 * float(correct_pred[classname]) / total_count
    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')
Accuracy for class: plane is 55.8 % Accuracy for class: car is 56.2 % Accuracy for class: bird is 40.3 % Accuracy for class: cat is 25.0 % Accuracy for class: deer is 46.4 % Accuracy for class: dog is 40.8 % Accuracy for class: frog is 57.3 % Accuracy for class: horse is 62.8 % Accuracy for class: ship is 69.7 % Accuracy for class: truck is 61.6 %