Detecting cats with simple PyTorch
Today we’ll illustrate how machine learning works with a lightning-fast tutorial on PyTorch – one of the leading Python frameworks for machine learning.
One fun example to build is an image model that classifies whether there’s a cat in the picture. I want to keep this post relatively high-level; we’ll come back and expand on these topics in future entries.
Here’s what we’ll need for this recipe:
- uv (Python environment management)
- torch (the engine for defining neural networks)
- torchvision (ready-to-use datasets and models)
First, we’ll start our project with uv:
```bash
uv init torch-cat --python 3.12.7
cd torch-cat
uv add torch torchvision
uv add --dev ipython jupyter
```
Our end goal will be to produce a main.py script that we can run as a CLI application, so that:
- uv run python main.py train will train and save a neural network
- uv run python main.py predict <image-path> will guess whether the image contains a cat
We make no pretense: this will not be a reliable model in the slightest. In fact, the performance is going to be pretty poor in the first few iterations. But here we’ll learn the fundamental techniques necessary for using PyTorch and how to tune models.
We’ll put everything together in main.py, and we’ll run jupyter lab to develop and explore.
Here’s some boilerplate we can use for main.py:
```python
# main.py
import argparse


def main() -> None:
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "command",
        type=str,
        choices=["train", "predict"],
    )
    parser.add_argument("path", type=str, nargs="?")

    args = parser.parse_args()

    if args.command == "train":
        print("TRAIN")
        train()
        return

    if args.command == "predict":
        if not args.path:
            print("image path needed for 'predict'")
            return
        print("PREDICT")
        predict(args.path)


def train() -> None:
    ...
    # TODO Get train dataset
    # TODO Get model architecture
    # TODO Configure training
    # TODO Train model
    # TODO Save weights


def predict(image_path: str) -> None:
    ...
    # TODO Load image
    # TODO Get model architecture
    # TODO Load weights
    # TODO Get prediction
    # TODO Print result


if __name__ == "__main__":
    main()
```
Model
We’ll use a ResNet18 – a famous neural network architecture in computer vision. There are a lot of subtleties to neural network design; we may come back to discuss them in more depth another time. For now, we just want to get up and running with machine learning!
```python
from torchvision.models import resnet18

model = resnet18(num_classes=1)  # a single output logit: "cat" vs "not cat"
print(model)
# ResNet(
#   (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#   (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   ...
```
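To get a rough sense of scale, we can count the model’s trainable parameters in the notebook – for ResNet18 at this size, it comes to roughly 11 million:

```python
# Count trainable parameters to get a feel for the model's size
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{n_params:,}")
# roughly 11 million for this ResNet18
```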
For all intents and purposes, this model is a function that maps multi-dimensional arrays (“tensors”) to other arrays. Let’s simulate what happens when we pass in a batch of images:
```python
import torch

batch_size = 10
n_channels = 3
dimensions = (32, 32)

mock_x = torch.rand(batch_size, n_channels, *dimensions)
mock_y_pred = model(mock_x)

print(mock_y_pred.shape)
# torch.Size([10, 1])
print(mock_y_pred[:3, :])
# tensor([[-0.7909], [0.8365], [1.3185]], grad_fn=<SliceBackward0>)
```
So each of the 10 images in the batch is mapped to a single number, called a “logit”. We can normalise these values with a sigmoid to interpret each one as a probability – here, the probability that the image contains a cat:
```python
mock_probs = torch.sigmoid(mock_y_pred)
print(mock_probs[:3, :])
# tensor([[0.3120], [0.6520], [0.5434]], grad_fn=<SliceBackward0>)
```
Dataset
We’ll use the CIFAR-10 dataset for this task. It’s a common benchmarking dataset for computer vision models, and it’s conveniently built into torchvision. It comprises 60,000 colour images (32 × 32) across 10 classes – one of which is “cat”. This saves us a lot of time acquiring and pre-processing the data.
The reality of machine learning is that the majority of the time is spent on data engineering. Improving the quality of the data generally yields bigger gains than any other technique, but it’s definitely the less “sexy” part of machine learning. For the purposes of this tutorial, we’ll skip over important details so we can focus on the basics of modelling. But I want to call this out for anyone who wants to progress further in machine learning.
In a Jupyter notebook, let’s explore the data:
```python
from torchvision.datasets import CIFAR10

train_ds = CIFAR10(root="./data", train=True, download=True)

data, label = train_ds[0]

print(data)
print(label)
# <PIL.Image.Image image mode=RGB size=32x32 at 0x7C9C5E375A60>
# 6
```
Let’s plot a few of the samples in this dataset (plotly isn’t in our dependencies yet, so run uv add --dev plotly first):
```python
import random

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

CLASSES = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

samples = [train_ds[random.randint(0, len(train_ds) - 1)] for _ in range(4)]

fig = make_subplots(
    rows=2,
    cols=2,
    subplot_titles=[CLASSES[label] for _, label in samples],
)
for i, (data, _) in enumerate(samples):
    row = i // 2 + 1
    col = i % 2 + 1
    fig.add_trace(go.Image(z=np.array(data)), row=row, col=col)

fig.update_layout(height=600, width=600, showlegend=False)
fig.show()
```
Ideally, we want this dataset to output numpy arrays instead of PIL images. The CIFAR10 class accepts transform and target_transform arguments that allow us to pass in functions to be applied to the data and labels before they’re returned:
```python
import numpy as np

train_ds = CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=lambda data: np.array(data, dtype=np.float32) / 255,
)
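```

A quick sanity check that the transform does what we expect – the first sample should now be a float32 array of shape (32, 32, 3):

```python
data, label = train_ds[0]
print(data.shape, data.dtype)
# (32, 32, 3) float32
```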
Neural network operations are vectorised to speed up computation drastically, meaning they accept a “batch” of images instead of a single image at a time. We can use the DataLoader class in torch to collate samples of the dataset into batches and convert them to tensors:
```python
from torch.utils.data import DataLoader

train_dl = DataLoader(train_ds, batch_size=64)

x, y = next(iter(train_dl))
print(x.shape)
print(y.shape)
# torch.Size([64, 32, 32, 3])
# torch.Size([64])
```
By convention, operations in PyTorch expect inputs with dimensions [batch_idx, channel_idx, *spatial_idxs], so we’ll need to permute the dimensions before we feed them into the model:
```python
x = x.permute(0, 3, 1, 2)
print(x.shape)
# torch.Size([64, 3, 32, 32])
```
Training
Putting these elements together, we can start assembling our train function:
```python
# main.py
import random

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18

random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

CLASSES = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]


def train() -> None:
    batch_size = 64

    # Get data
    train_ds = CIFAR10(
        root="./data",
        train=True,
        download=True,
        transform=lambda data: np.array(data, dtype=np.float32) / 255,
    )
    train_dl = DataLoader(train_ds, batch_size=batch_size)

    # Get model architecture (a single output logit: cat vs not-cat)
    model = resnet18(num_classes=1)

    # TODO Train model
    # TODO Save weights
```
Now, how do we train a model? We need to iterate over each batch in our dataset, calculate the model’s predictions, calculate the error compared to the true labels, and use back-propagation to update the model weights to minimise the error. We’ll then repeat this process across several epochs until the model is “trained”.
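In code, the loop we’re about to build looks roughly like this sketch (the names match what we’ll define below):

```python
# Sketch of the training loop we're about to build
for epoch in range(n_epochs):
    for x, y in train_dl:
        y_pred = model(x)           # forward pass: predictions
        error = loss_fn(y_pred, y)  # compare against the true labels
        error.backward()            # back-propagate gradients
        optimizer.step()            # update the weights
        optimizer.zero_grad()       # reset gradients for the next batch
```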
Let’s go back to our notebook and step through the process with a single batch. First, we’ll need a way to update the model weights: an optimiser, which decides how to change each weight based on the gradients calculated from back-propagation:
```python
from torch.optim import Adam

optimizer = Adam(model.parameters(), lr=0.01)
```
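To build intuition for what step() does, here’s a minimal sketch of the update rule for plain SGD – Adam layers momentum and per-parameter scaling on top of this idea:

```python
# Sketch: what a plain SGD optimiser step boils down to
with torch.no_grad():
    for p in model.parameters():
        if p.grad is not None:
            p -= 0.01 * p.grad  # move each weight against its gradient
```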
Now let’s calculate the error for a single batch. One wrinkle: the DataLoader yields the raw 0–9 class labels, but our single-logit model needs a binary “is it a cat?” target, so we convert the labels first:
```python
loss_fn = torch.nn.BCEWithLogitsLoss()  # This expects logits
x, y = next(iter(train_dl))

x = x.permute(0, 3, 1, 2)
y = (y == 3).float().unsqueeze(1)  # binary target: 1.0 if "cat" (class 3), shape [64, 1]
y_pred = model(x)

error = loss_fn(y_pred, y)
print(error)
# tensor(0.0987, grad_fn=...)
```
We’ve calculated the error for a single batch, but how do we use this to update the model weights?
```python
error.backward()
optimizer.step()
optimizer.zero_grad()  # Gradients accumulate by default, so reset them after each step
```
Now let’s take this and apply it to train():
```python
# main.py
import argparse
import random

import numpy as np
import torch
from PIL import Image
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18

CLASSES = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

BATCH_SIZE = 1024
N_EPOCHS = 5
LR = 5e-5
SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


def main():
    ...


def train() -> None:
    # Get data
    train_ds = CIFAR10(
        root="./data",
        train=True,
        download=True,
        transform=lambda data: np.array(data, dtype=np.float32) / 255,
        target_transform=lambda label: np.array([label == 3], dtype=np.float32),
    )
    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

    # Get model architecture
    model = resnet18(num_classes=1)
    model.to(DEVICE)
    optimizer = Adam(model.parameters(), lr=LR)

    # Train model
    loss_fn = torch.nn.BCEWithLogitsLoss()
    for epoch in range(N_EPOCHS):
        epoch_error = 0.0
        for batch_idx, (x, y) in enumerate(train_dl):
            x, y = x.to(DEVICE), y.to(DEVICE)
            x = x.permute(0, 3, 1, 2)
            y_pred = model(x)
            error = loss_fn(y_pred, y)
            error.backward()
            optimizer.step()
            optimizer.zero_grad()
            epoch_error += float(error)
        print(f"Epoch {epoch + 1} error: {epoch_error / len(train_dl):.4}")

    # Save weights
    torch.save(model.state_dict(), "weights.pt")
```
Now we can train our model with:
```bash
uv run python main.py train
# TRAIN
# Epoch 1 error: 0.258
# Epoch 2 error: 0.254
# Epoch 3 error: 0.2389
# Epoch 4 error: 0.2368
# ...
```
(You may need to lower the BATCH_SIZE hyperparameter if you’re running out of memory on your system.)
Predict
Once we’ve done the training loop, writing the predict function is relatively straightforward:
```python
def predict(image_path: str) -> None:
    # Load image
    img = Image.open(image_path).convert("RGB")

    # Center-crop to square
    width, height = img.size
    min_dim = min(width, height)
    left = (width - min_dim) // 2
    top = (height - min_dim) // 2
    right = left + min_dim
    bottom = top + min_dim
    img = img.crop((left, top, right, bottom))

    # Resize to the model's input size
    img = img.resize((32, 32), Image.Resampling.LANCZOS)
    data = np.array(img, dtype=np.float32) / 255

    # Get model architecture
    model = resnet18(num_classes=1)

    # Load weights (map to CPU in case training ran on a GPU)
    state_dict = torch.load("weights.pt", map_location="cpu")
    model.load_state_dict(state_dict)

    # Get prediction
    model.eval()
    x = torch.tensor(data).unsqueeze(0)
    x = x.permute(0, 3, 1, 2)
    with torch.no_grad():
        y_pred = model(x)
    y_pred = torch.sigmoid(y_pred)
    y_pred = y_pred.squeeze(0)

    # Print result
    prob_cat = float(y_pred)
    if prob_cat < 0.5:
        print(f"Image doesn't have a cat (confidence: {1 - prob_cat:.4})")
    else:
        print(f"Image has a cat (confidence: {prob_cat:.4})")
```
Try downloading some images and see if the model detects whether there’s a cat:
```bash
uv run python main.py predict <path-to-image>
```
To get this model working sufficiently well, I had to tweak the learning rate, batch size, and number of epochs while writing this. The number of epochs is particularly difficult to guess – too few and the model is useless; too many and the model won’t generalise beyond the training data.
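One way to take the guesswork out of the epoch count is to hold out some data and watch the error on it after every epoch, stopping (or keeping the best weights) once it stops improving. Here’s a minimal sketch that reuses the names from train() above and presses CIFAR-10’s train=False split into service as a makeshift validation set:

```python
# Sketch: measure held-out error after each epoch to detect overfitting
val_ds = CIFAR10(
    root="./data",
    train=False,  # CIFAR-10's test split, standing in for a validation set
    download=True,
    transform=lambda data: np.array(data, dtype=np.float32) / 255,
    target_transform=lambda label: np.array([label == 3], dtype=np.float32),
)
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE)

model.eval()  # switch off training-only behaviour (e.g. batch-norm updates)
val_error = 0.0
with torch.no_grad():
    for x, y in val_dl:
        x, y = x.to(DEVICE), y.to(DEVICE)
        x = x.permute(0, 3, 1, 2)
        val_error += float(loss_fn(model(x), y))
print(f"Validation error: {val_error / len(val_dl):.4}")
model.train()  # back to training mode before the next epoch
```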
Conclusion
This was a bit of a silly example, but I hope it was an informative read on how to get started with PyTorch. There are many directions we could go to improve this tutorial:
- Using validation and test splits to ensure we’re not overfitting
- Using a higher-quality dataset
- Balancing the dataset
- Adding noise to the dataset with “augmentations” (see the sketch after this list)
- Trying other architectures or loss functions
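As a taste of the augmentations point, here’s a minimal sketch using torchvision.transforms – note that ToTensor already produces [C, H, W] float tensors in [0, 1], so the permute step would no longer be needed:

```python
from torchvision import transforms

# Sketch: random flips and jittered crops show the model slightly
# different images each epoch, which helps it generalise
augment = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1], shape [C, H, W]
])

train_ds = CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=augment,
    target_transform=lambda label: np.array([label == 3], dtype=np.float32),
)
```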
For those who are curious to learn more, I highly recommend the official tutorials on the pytorch website!