Detecting cats with simple PyTorch
Today we’ll illustrate how machine learning works in a lightning-fast tutorial on how to use PyTorch – one of the leading Python frameworks for machine learning.
One fun example to build is an image model that classifies whether there’s a cat in the picture. I want to keep this post relatively high-level; we’ll come back and expand on these topics in future entries.
Here’s what we’ll need for this recipe:
- uv (Python environment management)
- torch (the engine for defining neural networks)
- torchvision (ready-to-use datasets and models)
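First, we’ll start our project with uv. We also add plotly as a dev dependency, since we’ll use it to plot images in the notebook:

```shell
uv init torch-cat --python 3.12.7
cd torch-cat
uv add torch torchvision
uv add --dev ipython jupyter plotly
```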
Our end goal will be to produce a `main.py` script that we can run as a CLI application, so that:

- `uv run python main.py train` will train and save a neural network
- `uv run python main.py predict <image-path>` will guess whether the image contains a cat
We make no pretense: this will not be a reliable model in the slightest. In fact, the performance is going to be pretty poor in the first few iterations. But along the way we’ll learn the fundamental techniques for using PyTorch and tuning models.
We’ll put everything together in `main.py`, and we’ll run Jupyter Lab (`uv run jupyter lab`) to develop and explore.

Here’s some boilerplate we can use for `main.py`:
```python
# main.py
import argparse


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "command",
        type=str,
        choices=["train", "predict"],
    )
    parser.add_argument("path", type=str, nargs="?")
    args = parser.parse_args()

    if args.command == "train":
        print("TRAIN")
        train()
        return

    if args.command == "predict":
        if not args.path:
            print("image path needed for 'predict'")
            return
        print("PREDICT")
        predict(args.path)


def train() -> None:
    ...
    # TODO Get train dataset
    # TODO Get model architecture
    # TODO configure
    # TODO Train model
    # TODO save weights


def predict(image_path: str) -> None:
    ...
    # TODO Load image
    # TODO Get model architecture
    # TODO load weights
    # TODO get prediction
    # TODO print result


if __name__ == "__main__":
    main()
```
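The only slightly unusual part of the boilerplate is the optional positional `path` argument (`nargs="?"`). A small sketch to check how it parses without running the script: it rebuilds the same parser and hands `parse_args` an explicit argument list. The `build_parser` helper is our own illustration, not part of the boilerplate.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Same arguments as in main(): a required command and an optional path
    parser = argparse.ArgumentParser()
    parser.add_argument("command", type=str, choices=["train", "predict"])
    parser.add_argument("path", type=str, nargs="?")
    return parser


parser = build_parser()

# With nargs="?" and no default, a missing path is simply None
train_args = parser.parse_args(["train"])
print(train_args.command, train_args.path)  # train None

predict_args = parser.parse_args(["predict", "cat.jpg"])
print(predict_args.command, predict_args.path)  # predict cat.jpg
```

An unknown command (say `evaluate`) is rejected by `choices`: argparse prints a usage error and exits, so `main()` never has to handle it.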
Model
We’ll use ResNet18, a famous neural network architecture in computer vision. There are a lot of subtleties to neural network design, and we may come back to discuss them in more depth another time. For now, we just want to get up and running with machine learning!
```python
from torchvision.models import resnet18

model = resnet18(num_classes=1)
print(model)
# ResNet(
#   (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#   (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   ...
```
For all intents and purposes, this model is a function that maps multi-dimensional arrays (“tensors”) to other arrays. Let’s simulate what happens when we pass in a batch of images:
```python
import torch

batch_size = 10
n_channels = 3
dimensions = (32, 32)

# A batch of 10 random "images", each with 3 channels of 32 x 32 pixels
mock_x = torch.rand(batch_size, n_channels, *dimensions)
mock_y_pred = model(mock_x)
print(mock_y_pred.shape)
# torch.Size([10, 1])
print(mock_y_pred[:3, :])
# tensor([[-0.7909],
#         [ 0.8365],
#         [ 1.3185]], grad_fn=<SliceBackward0>)
```
Each of the 10 images in the batch is mapped to a single number, called a “logit”. We can apply the sigmoid function to interpret these values as the probability that each image belongs to our class:
```python
mock_probs = torch.sigmoid(mock_y_pred)
print(mock_probs[:3, :])
# tensor([[0.3120],
#         [0.6977],
#         [0.7889]], grad_fn=<SliceBackward0>)
```
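The sigmoid is σ(z) = 1 / (1 + e^(-z)): it squashes any real-valued logit into the interval (0, 1), and a logit of exactly 0 corresponds to a probability of 0.5. A minimal pure-Python sketch (the `sigmoid` helper is our own, not a PyTorch function) reproduces the first probability above:

```python
import math


def sigmoid(z: float) -> float:
    # Maps a logit in (-inf, inf) to a probability in (0, 1)
    return 1.0 / (1.0 + math.exp(-z))


print(f"{sigmoid(-0.7909):.4f}")  # 0.3120
print(sigmoid(0.0))  # 0.5

# Positive logits map above 0.5, negative logits below it
for logit in [-0.7909, 0.0, 0.8365, 1.3185]:
    print(f"logit {logit:+.4f} -> probability {sigmoid(logit):.4f}")
```

This is also why a threshold of 0.5 on the probability is the same as a threshold of 0 on the logit.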
Dataset
We’ll use the CIFAR-10 dataset for this task. It’s a common benchmarking dataset for computer vision models, and it’s conveniently built into torchvision. The dataset comprises 60,000 colour images (32 x 32 pixels, split into 50,000 training and 10,000 test images) across 10 classes, one of which is “cat”. This saves us a lot of time in acquiring and pre-processing the data.
The reality of machine learning is that most of the time is spent on data engineering. Improving the quality of the data usually yields far bigger gains than any other technique, but it’s definitely the less “sexy” part of machine learning. For the purpose of this tutorial, we’ll skip over important details so we can focus on the basics of modelling, but I want to call this out for anyone who wants to progress further in machine learning.
In a Jupyter notebook, let’s explore the data:

```python
from torchvision.datasets import CIFAR10

train_ds = CIFAR10(root="./data", train=True, download=True)
data, label = train_ds[0]
print(data)
print(label)
# <PIL.Image.Image image mode=RGB size=32x32 at 0x7C9C5E375A60>
# 6
```
Let’s plot a few of the samples in this dataset:
```python
import random

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

CLASSES = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

# Pick 4 random (image, label) pairs from the training set
samples = [train_ds[random.randint(0, len(train_ds) - 1)] for _ in range(4)]

fig = make_subplots(
    rows=2,
    cols=2,
    subplot_titles=[CLASSES[label] for _, label in samples],
)
for i, (data, _) in enumerate(samples):
    row = i // 2 + 1
    col = i % 2 + 1
    fig.add_trace(go.Image(z=np.array(data)), row=row, col=col)

fig.update_layout(height=600, width=600, showlegend=False)
fig.show()
```
Ideally, we want this dataset to output NumPy arrays instead of PIL images. The `CIFAR10` class accepts `transform` and `target_transform` arguments that allow us to pass in functions to be applied to the data and labels before they’re returned:
```python
import numpy as np

train_ds = CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=lambda data: np.array(data, dtype=np.float32) / 255,
)
```
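The transform turns each 32 x 32 RGB image from 8-bit integers in [0, 255] into 32-bit floats in [0, 1]. A sketch of that effect, using a random NumPy array as a stand-in for a PIL image (the pixels are made up, and `to_float_array` is our own name for the lambda above):

```python
import numpy as np

# Stand-in for a CIFAR-10 image: height x width x RGB channels, 8-bit pixels
rng = np.random.default_rng(0)
fake_image = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)


def to_float_array(data) -> np.ndarray:
    # Same conversion as the transform lambda passed to CIFAR10
    return np.array(data, dtype=np.float32) / 255


scaled = to_float_array(fake_image)
print(scaled.shape, scaled.dtype)  # (32, 32, 3) float32
print(scaled.min() >= 0.0, scaled.max() <= 1.0)  # True True
```

Note that the channel axis is still last, which matters when we batch the images and feed them to the model.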
Neural network operations are vectorised to speed up computation drastically, meaning that they accept a “batch” of images instead of a single image at a time. We can use the `DataLoader` class from torch to collate samples of the dataset together into batches and convert them to tensors:
```python
from torch.utils.data import DataLoader

train_dl = DataLoader(train_ds, batch_size=64)
x, y = next(iter(train_dl))
print(x.shape)
print(y.shape)
# torch.Size([64, 32, 32, 3])
# torch.Size([64])
```
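Roughly speaking, for samples that are NumPy arrays, the DataLoader’s default collation stacks them along a new leading batch dimension (and converts the result to a tensor), while integer labels are gathered into a 1-D tensor. A NumPy-only sketch of that stacking step, with random arrays standing in for real CIFAR-10 samples:

```python
import numpy as np

batch_size = 64
rng = np.random.default_rng(0)

# Pretend these are 64 (image, label) pairs returned by train_ds[i]
samples = [
    (rng.random((32, 32, 3), dtype=np.float32), int(rng.integers(0, 10)))
    for _ in range(batch_size)
]

# Stack the images along a new first axis and gather the labels
images = np.stack([image for image, _ in samples])
labels = np.array([label for _, label in samples])

print(images.shape)  # (64, 32, 32, 3)
print(labels.shape)  # (64,)
```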
By convention, operations in PyTorch expect inputs with dimensions `[batch_idx, channel_idx, *spatial_idxs]`, so we’ll need to permute the dimensions before we feed them into the model:
```python
x = x.permute(0, 3, 1, 2)
print(x.shape)
# torch.Size([64, 3, 32, 32])
```
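`permute(0, 3, 1, 2)` moves axis 3 (channels) into position 1 and shifts height and width along. NumPy’s `transpose` takes the same axis order, so a small sketch can confirm that the pixel values are only rearranged, and that a plain `reshape` to the same shape would scramble them instead:

```python
import numpy as np

# A tiny batch in "channels-last" layout: (batch, height, width, channels)
x_nhwc = np.arange(2 * 4 * 4 * 3).reshape(2, 4, 4, 3)

# Same axis order as x.permute(0, 3, 1, 2) in PyTorch
x_nchw = np.transpose(x_nhwc, (0, 3, 1, 2))
print(x_nchw.shape)  # (2, 3, 4, 4)

# Pixel (row 1, col 2) of channel 0 in image 0 holds the same value in both layouts
print(x_nhwc[0, 1, 2, 0] == x_nchw[0, 0, 1, 2])  # True

# reshape keeps the memory order, so it does NOT move channels correctly
print(np.array_equal(x_nhwc.reshape(2, 3, 4, 4), x_nchw))  # False
```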
Training
Putting these elements together, we can start assembling our `train` function:
```python
# main.py
import random

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18

# Seed every source of randomness so runs are reproducible
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

CLASSES = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]


def train() -> None:
    batch_size = 64

    # Get data
    train_ds = CIFAR10(
        root="./data",
        train=True,
        download=True,
        transform=lambda data: np.array(data, dtype=np.float32) / 255,
    )
    train_dl = DataLoader(train_ds, batch_size=batch_size)

    # Get model architecture: a single output logit for "cat or not cat"
    model = resnet18(num_classes=1)

    # TODO Train model
    # TODO save weights
```
Now, how do we train a model? We need to iterate over each batch in our dataset, calculate the model’s predictions, calculate the error compared to the true labels, and use back-propagation to update the model weights to minimise the error. We then repeat this process across several epochs (full passes over the dataset) until the model is “trained”.
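The same loop structure works for any differentiable model. As a self-contained illustration (not the ResNet itself), here is a NumPy sketch that trains a one-feature logistic-regression “model” on synthetic data. The gradients are written out by hand where PyTorch would use back-propagation, and the update is plain gradient descent rather than Adam; the data, learning rate and epoch count are all made up for the example:

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic data: the label is 1 when the (noisy) feature is positive
x = rng.normal(size=256)
y = (x + 0.1 * rng.normal(size=256) > 0).astype(np.float64)

w, b = 0.0, 0.0  # the "model weights"
lr = 0.5
batch_size = 32
history = []

for epoch in range(5):
    epoch_error = 0.0
    n_batches = 0
    for start in range(0, len(x), batch_size):
        xb = x[start:start + batch_size]
        yb = y[start:start + batch_size]

        # 1. Model predictions: sigmoid of the logits w * x + b
        p = 1.0 / (1.0 + np.exp(-(w * xb + b)))

        # 2. Error against the true labels (mean binary cross-entropy)
        p_safe = np.clip(p, 1e-12, 1 - 1e-12)  # avoid log(0)
        error = -np.mean(yb * np.log(p_safe) + (1 - yb) * np.log(1 - p_safe))

        # 3. Gradients of the error w.r.t. w and b (what backward() computes)
        grad_logits = (p - yb) / len(xb)
        grad_w = np.sum(grad_logits * xb)
        grad_b = np.sum(grad_logits)

        # 4. Update the weights against the gradient (what optimizer.step() does)
        w -= lr * grad_w
        b -= lr * grad_b

        epoch_error += error
        n_batches += 1

    history.append(epoch_error / n_batches)
    print(f"Epoch {epoch + 1} error: {history[-1]:.4f}")
```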
Let’s go back to our notebook and step through the process with a single batch. First, we’ll need an optimiser, which tells the model how to update its weights based on the gradients calculated by back-propagation:
```python
from torch.optim import Adam

optimizer = Adam(model.parameters(), lr=0.01)
```
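Adam keeps running averages of each parameter’s gradient and squared gradient, and scales every step by them. Below is a single-parameter sketch of the Adam update rule as described by Kingma and Ba, using PyTorch’s default `betas=(0.9, 0.999)` and `eps=1e-8`. The toy objective f(θ) = θ² and the `adam_minimise` helper are hypothetical, chosen only to show the mechanics:

```python
import math


def adam_minimise(grad_fn, theta, lr=0.1, betas=(0.9, 0.999), eps=1e-8, steps=500):
    beta1, beta2 = betas
    m, v = 0.0, 0.0  # running averages of the gradient and the squared gradient
    for t in range(1, steps + 1):
        g = grad_fn(theta)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        # Bias correction: m and v start at zero, so early averages are too small
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        # Step size is roughly lr, scaled by how consistent the gradient has been
        theta -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return theta


# Minimise f(theta) = theta**2, whose gradient is 2 * theta, starting from 5.0
result = adam_minimise(lambda th: 2 * th, theta=5.0)
print(result)
```

In PyTorch, `optimizer.step()` applies this kind of update to every tensor in `model.parameters()` at once, using the gradients stored by `backward()`.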
Now let’s calculate the error for a single batch:
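```python
loss_fn = torch.nn.BCEWithLogitsLoss()  # This expects logits, not probabilities

x, y = next(iter(train_dl))
x = x.permute(0, 3, 1, 2)
# BCEWithLogitsLoss needs float targets with the same shape as the
# predictions, so turn the class labels into "is it a cat?" (class 3)
y = (y == CLASSES.index("cat")).float().unsqueeze(1)

y_pred = model(x)
error = loss_fn(y_pred, y)
print(error)
# tensor(0.0987, grad_fn=...)
```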
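For a logit z and a target y of 0 or 1, `BCEWithLogitsLoss` computes the binary cross-entropy of σ(z), -[y log σ(z) + (1 - y) log(1 - σ(z))], averaged over the batch by default. Fusing the sigmoid into the loss allows a numerically stable rearrangement such as max(z, 0) - z·y + log(1 + e^(-|z|)). A pure-Python sketch (helper names are our own) checking that both forms agree:

```python
import math


def bce_naive(z: float, y: float) -> float:
    # Binary cross-entropy on the probability sigmoid(z)
    p = 1.0 / (1.0 + math.exp(-z))
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


def bce_with_logits(z: float, y: float) -> float:
    # Algebraically equivalent form that never takes log(0)
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))


def mean_bce_with_logits(logits, targets) -> float:
    # "mean" reduction: average the per-example losses over the batch
    return sum(bce_with_logits(z, y) for z, y in zip(logits, targets)) / len(logits)


print(bce_naive(2.0, 1.0), bce_with_logits(2.0, 1.0))  # same value twice
# A confidently wrong prediction costs about 50; the naive form would hit log(0) here
print(bce_with_logits(50.0, 0.0))
print(mean_bce_with_logits([-0.7909, 0.8365, 1.3185], [0.0, 1.0, 0.0]))
```

This is also why the model outputs raw logits and we only apply `torch.sigmoid` when we want probabilities to read.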
We’ve calculated the error for a single batch, but how do we use this to update the model weights?
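```python
error.backward()       # back-propagation: compute the gradient of the error w.r.t. every weight
optimizer.step()       # update the weights using those gradients
optimizer.zero_grad()  # reset the gradients, otherwise they accumulate across batches
```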
Now let’s take this and apply it to `train()`:
```python
# main.py
import argparse

import numpy as np
import torch
from PIL import Image
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18

CLASSES = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

# Use a GPU (NVIDIA or Apple Silicon) if one is available
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

BATCH_SIZE = 1024
N_EPOCHS = 5
LR = 5e-5
SEED = 1234


def main():
    ...  # same as the boilerplate above


def train() -> None:
    torch.manual_seed(SEED)

    # Get data: images as floats in [0, 1], labels as 1.0 (cat) or 0.0 (not cat)
    train_ds = CIFAR10(
        root="./data",
        train=True,
        download=True,
        transform=lambda data: np.array(data, dtype=np.float32) / 255,
        target_transform=lambda label: np.array([label == 3], dtype=np.float32),  # 3 == "cat"
    )
    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

    # Get model architecture
    model = resnet18(num_classes=1)
    model.to(DEVICE)
    optimizer = Adam(model.parameters(), lr=LR)

    # Train model
    loss_fn = torch.nn.BCEWithLogitsLoss()
    for epoch in range(N_EPOCHS):
        epoch_error = 0.0
        for batch_idx, (x, y) in enumerate(train_dl):
            x, y = x.to(DEVICE), y.to(DEVICE)
            x = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

            y_pred = model(x)
            error = loss_fn(y_pred, y)

            error.backward()
            optimizer.step()
            optimizer.zero_grad()

            epoch_error += float(error)
        print(f"Epoch {epoch + 1} error: {epoch_error / len(train_dl):.4}")

    # Save weights
    torch.save(model.state_dict(), "weights.pt")
```
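The `target_transform` turns CIFAR-10’s integer class labels into a one-element float array that is 1.0 for cats (class index 3) and 0.0 otherwise. The DataLoader then stacks these into a `(batch_size, 1)` target, matching the model’s `(batch_size, 1)` output. A NumPy sketch of that mapping on a handful of made-up labels (`to_cat_target` is our own name for the lambda):

```python
import numpy as np

CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
CAT = CLASSES.index("cat")  # 3


def to_cat_target(label: int) -> np.ndarray:
    # Same mapping as the target_transform lambda in train()
    return np.array([label == CAT], dtype=np.float32)


labels = [6, 3, 0, 3]  # frog, cat, airplane, cat
targets = np.stack([to_cat_target(label) for label in labels])
print(targets.shape)    # (4, 1)
print(targets.ravel())  # [0. 1. 0. 1.]
```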
Now we can train our model with:

```shell
uv run python main.py train
# TRAIN
# Epoch 1 error: 0.258
# Epoch 2 error: 0.254
# Epoch 3 error: 0.2389
# Epoch 4 error: 0.2368
# ...
```
(You may need to lower the `BATCH_SIZE` hyperparameter if you’re running out of memory on your system.)
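To put these error values in context: the CIFAR-10 training set has 5,000 images of each class, so only 10% of the images are cats. The best constant guess, a model that ignores the image and always predicts that 10% base rate, would score an average BCE of about 0.325. Training errors around 0.24 to 0.26 therefore suggest the network is picking up at least some signal beyond the class balance (on the training data, at least). A quick check of that baseline:

```python
import math

# Fraction of cat images in the CIFAR-10 training set (5,000 of 50,000)
p_cat = 5_000 / 50_000

# Mean BCE of always predicting p_cat: cats contribute -log(p_cat),
# non-cats contribute -log(1 - p_cat), weighted by how often each occurs
baseline = -(p_cat * math.log(p_cat) + (1 - p_cat) * math.log(1 - p_cat))
print(f"{baseline:.4f}")  # 0.3251
```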
Predict
Once we’ve written the training loop, writing the `predict` function is relatively straightforward:
```python
def predict(image_path: str) -> None:
    # Load image
    img = Image.open(image_path).convert("RGB")

    # Center-crop to square
    width, height = img.size
    min_dim = min(width, height)
    left = (width - min_dim) // 2
    top = (height - min_dim) // 2
    right = left + min_dim
    bottom = top + min_dim
    img = img.crop((left, top, right, bottom))

    # Resize to the 32 x 32 resolution the model was trained on
    img = img.resize((32, 32), Image.Resampling.LANCZOS)
    data = np.array(img, dtype=np.float32) / 255

    # Get model architecture
    model = resnet18(num_classes=1)

    # Load weights (mapped to the CPU, so weights trained on a GPU load anywhere)
    state_dict = torch.load("weights.pt", map_location="cpu")
    model.load_state_dict(state_dict)

    # Get prediction
    model.eval()  # use BatchNorm's running statistics instead of batch statistics
    x = torch.tensor(data).unsqueeze(0)  # add a batch dimension: (1, 32, 32, 3)
    x = x.permute(0, 3, 1, 2)            # (1, 3, 32, 32)
    with torch.no_grad():                # no gradients needed for inference
        y_pred = model(x)
    y_pred = torch.sigmoid(y_pred)
    y_pred = y_pred.squeeze(0)

    # Print result
    prob_cat = float(y_pred)
    if prob_cat < 0.5:
        # Confidence in "no cat" is the probability of the other outcome
        print(f"Image doesn't have a cat (confidence: {1 - prob_cat:.4})")
    else:
        print(f"Image has a cat (confidence: {prob_cat:.4})")
```
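The center-crop keeps the largest square that fits in the middle of the image, so the resize to 32 x 32 doesn’t squash the aspect ratio. The box arithmetic can be checked on its own: `center_square_box` is a hypothetical helper that mirrors the code above and returns the `(left, top, right, bottom)` tuple that PIL’s `Image.crop` expects.

```python
def center_square_box(width: int, height: int) -> tuple[int, int, int, int]:
    # Side length of the largest square that fits inside the image
    min_dim = min(width, height)
    # Offsets that center the square horizontally and vertically
    left = (width - min_dim) // 2
    top = (height - min_dim) // 2
    return (left, top, left + min_dim, top + min_dim)


print(center_square_box(640, 480))  # (80, 0, 560, 480): trims the sides of a landscape photo
print(center_square_box(300, 500))  # (0, 100, 300, 400): trims the top and bottom of a portrait
```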
Try downloading some images and see whether the model detects a cat:

```shell
uv run python main.py predict <path-to-image>
```
To get this model working sufficiently well, I had to tweak the learning rate, batch size and number of epochs while I was writing this. The number of epochs is particularly difficult to guess: too few and the model is useless, too many and the model won’t generalise beyond the training data.
Conclusion
This was a bit of a silly example, but I hope it was an informative read on how to get started with PyTorch. There are many directions we could take to improve this tutorial:
- Using validation and test splits to ensure we’re not overfitting
- Using a higher-quality dataset
- Balancing the dataset
- Adding noise to the dataset with “augmentations”
- Trying other architectures or loss functions
For those who are curious to learn more, I highly recommend the official tutorials on the PyTorch website!