ProbableOdyssey

Detecting cats with simple PyTorch

Today we’ll illustrate how machine learning works in a lightning-fast tutorial on how to use PyTorch – one of the leading Python frameworks for machine learning.

One fun example to build is an image model that classifies whether there’s a cat in the picture. I want to keep this post relatively high-level; we’ll come back and expand on these topics in future entries.

Here’s what we’ll need for this recipe:

First we’ll start our project with uv (plotly is only needed for plotting samples in the notebook):

```shell
uv init torch-cat --python 3.12.7
cd torch-cat
uv add torch torchvision
uv add --dev ipython jupyter plotly
```

Our end goal will be to produce a main.py script that we can run as a CLI application, with a train command to fit the model and a predict command to classify a single image.

We make no pretense: this will not be a reliable model in the slightest. In fact, the performance is going to be pretty poor in the first few iterations. But along the way we’ll learn the fundamental techniques for using PyTorch and tuning models.

We’ll put everything together in main.py, and we’ll use JupyterLab (uv run jupyter lab) to develop and explore.

Here’s some boilerplate we can use for main.py:

```python
# main.py
import argparse


def main() -> None:
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "command",
        type=str,
        choices=["train", "predict"],
    )
    parser.add_argument("path", type=str, nargs="?")

    args = parser.parse_args()

    if args.command == "train":
        print("TRAIN")
        train()
        return

    if args.command == "predict":
        if not args.path:
            print("image path needed for 'predict'")
            return
        print("PREDICT")
        predict(args.path)


def train() -> None:
    ...
    # TODO Get train dataset
    # TODO Get model architecture
    # TODO configure
    # TODO Train model
    # TODO save weights


def predict(image_path: str) -> None:
    ...
    # TODO Load image
    # TODO Get model architecture
    # TODO load weights
    # TODO get prediction
    # TODO print result


if __name__ == "__main__":
    main()
```
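To check how the two positional arguments behave without touching the command line, we can hand parse_args an explicit list of strings instead of letting it read sys.argv. This is a small sketch that rebuilds the same parser inside a hypothetical build_parser() helper; nargs="?" is what makes path optional, so it comes back as None for train.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Same arguments as main() above, factored out (hypothetical helper) so it can be exercised directly
    parser = argparse.ArgumentParser()
    parser.add_argument("command", type=str, choices=["train", "predict"])
    parser.add_argument("path", type=str, nargs="?")  # optional second positional argument
    return parser


parser = build_parser()

args = parser.parse_args(["train"])
print(args.command, args.path)  # train None

args = parser.parse_args(["predict", "cat.jpg"])
print(args.command, args.path)  # predict cat.jpg
```

Anything outside choices (say, "evaluate") makes argparse print a usage error and exit, so main() never has to handle an unknown command itself.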

Model

We’ll use a ResNet18 – a famous neural network architecture in computer vision. There are a lot of subtleties to neural network design, and we may come back to discuss them in more depth another time. For now, we just want to get up and running with machine learning!

```python
from torchvision.models import resnet18

model = resnet18(num_classes=1)
print(model)
# ResNet(
#   (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#   (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   ...
```

For all intents and purposes, this model is a function that maps multi-dimensional arrays (“tensors”) to other arrays. Let’s simulate what happens when we pass in a batch of images:

```python
import torch

batch_size = 10
n_channels = 3
dimensions = (32, 32)

# A batch of 10 random "images", each with 3 colour channels and 32 x 32 pixels
mock_x = torch.rand(batch_size, n_channels, *dimensions)
mock_y_pred = model(mock_x)

print(mock_y_pred.shape)
# torch.Size([10, 1])
print(mock_y_pred[:3, :])
# tensor([[-0.7909],
#         [ 0.8365],
#         [ 1.3185]], grad_fn=<SliceBackward0>)
```

The model outputs one number per class, called a “logit”, and since we set num_classes=1, each image is mapped to a single logit. We can squash these logits into the range (0, 1) with a sigmoid to interpret them as probabilities:

```python
mock_probs = torch.sigmoid(mock_y_pred)
print(mock_probs[:3, :])
# tensor([[0.3120],
#         [0.6977],
#         [0.7889]], grad_fn=<SliceBackward0>)
```
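The sigmoid is σ(z) = 1 / (1 + e^(-z)): large positive logits go towards 1, large negative ones towards 0, and a logit of exactly 0 lands on 0.5. A plain-Python sketch (no torch involved) applied to the three logits printed above, together with its inverse, the log-odds:

```python
import math


def sigmoid(z: float) -> float:
    # Squash a logit into a probability in (0, 1)
    return 1.0 / (1.0 + math.exp(-z))


def logit(p: float) -> float:
    # Inverse of the sigmoid: the log-odds of a probability
    return math.log(p / (1.0 - p))


for z in [-0.7909, 0.8365, 1.3185]:
    print(f"{z:+.4f} -> {sigmoid(z):.4f}")
# -0.7909 -> 0.3120
# +0.8365 -> 0.6977
# +1.3185 -> 0.7889
```

Because the model is untrained, these probabilities are meaningless for now; the point is only that any real number becomes a valid probability.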

Dataset

We’ll use the CIFAR-10 dataset for this task. It’s a common benchmarking dataset used for computer vision models, and it’s conveniently built into torchvision. This dataset comprises 60,000 colour images (32 x 32) across 10 classes – one of which is “cat”. This saves us a lot of time in acquiring and pre-processing the data.

The reality of machine learning is that most of the time is spent on data engineering. Improving the quality of the data often yields bigger gains than any modelling technique, but it’s definitely the less “sexy” part of machine learning. For the purpose of this tutorial, we’ll skip over these important details so we can focus on the basics of modelling, but I want to call this out for anyone who wants to progress further in machine learning.

In a Jupyter notebook, let’s explore the data:

```python
from torchvision.datasets import CIFAR10

train_ds = CIFAR10(root="./data", train=True, download=True)

data, label = train_ds[0]

print(data)
print(label)
# <PIL.Image.Image image mode=RGB size=32x32 at 0x7C9C5E375A60>
# 6
```

Let’s plot a few of the samples in this dataset:

```python
import random

import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots

CLASSES = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

# Pick 4 random (image, label) pairs from the training set
samples = [train_ds[random.randint(0, len(train_ds) - 1)] for _ in range(4)]

fig = make_subplots(
    rows=2,
    cols=2,
    subplot_titles=[CLASSES[label] for _, label in samples],
)
for i, (data, _) in enumerate(samples):
    row = i // 2 + 1
    col = i % 2 + 1
    fig.add_trace(go.Image(z=np.array(data)), row=row, col=col)

fig.update_layout(height=600, width=600, showlegend=False)
fig.show()
```

Ideally, we want this dataset to output NumPy arrays instead of PIL images. The CIFAR10 class accepts transform and target_transform arguments, which let us pass in functions to be applied to the data and labels before they’re returned:

```python
import numpy as np

train_ds = CIFAR10(
    root="./data",
    train=True,
    download=True,
    # Convert each PIL image to a float array and scale pixels from 0-255 to 0-1
    transform=lambda data: np.array(data, dtype=np.float32) / 255,
)
```

Neural network operations are vectorised to speed up computation drastically, which means they accept a “batch” of images rather than a single image at a time. We can use PyTorch’s DataLoader class to collate samples from the dataset into batches and convert them to tensors:

```python
from torch.utils.data import DataLoader

train_dl = DataLoader(train_ds, batch_size=64)

x, y = next(iter(train_dl))
print(x.shape)
print(y.shape)
# torch.Size([64, 32, 32, 3])
# torch.Size([64])
```
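Conceptually, the DataLoader walks through the dataset in order (unless shuffle=True) and groups consecutive samples into chunks of batch_size; with the default drop_last=False, the final chunk is simply smaller when the dataset size isn’t a multiple of the batch size. A pure-Python sketch of that grouping, using a hypothetical batches helper, which also shows how many batches one epoch over CIFAR-10’s 50,000 training images takes (this is what len(train_dl) reports):

```python
import math
from typing import Iterator, Sequence, TypeVar

T = TypeVar("T")


def batches(items: Sequence[T], batch_size: int) -> Iterator[list[T]]:
    # Yield consecutive chunks of batch_size items; the last chunk may be shorter
    for start in range(0, len(items), batch_size):
        yield list(items[start : start + batch_size])


print([len(b) for b in batches(range(10), 4)])  # [4, 4, 2]

# Batches per epoch for the 50,000 CIFAR-10 training images
print(math.ceil(50_000 / 64))    # 782
print(math.ceil(50_000 / 1024))  # 49
```

The real DataLoader additionally stacks each chunk into a single tensor, which is where the leading 64 in torch.Size([64, 32, 32, 3]) comes from.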

By convention, PyTorch’s image operations expect inputs with dimensions [batch_idx, channel_idx, *spatial_idxs], but our batches come out as [batch, height, width, channel]. So we’ll need to permute the dimensions before we feed them into the model:

```python
x = x.permute(0, 3, 1, 2)
print(x.shape)
# torch.Size([64, 3, 32, 32])
```
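The arguments to permute name the source dimension for each output position: permute(0, 3, 1, 2) means output dim 0 comes from input dim 0 (batch), output dim 1 from input dim 3 (channels), then height and width. A pure-Python sketch of that rule applied to shapes only (permuted_shape is a hypothetical helper, not a PyTorch function):

```python
def permuted_shape(shape: tuple[int, ...], dims: tuple[int, ...]) -> tuple[int, ...]:
    # Output dimension i takes its size from input dimension dims[i]
    if sorted(dims) != list(range(len(shape))):
        raise ValueError("dims must be a permutation of the input dimensions")
    return tuple(shape[d] for d in dims)


print(permuted_shape((64, 32, 32, 3), (0, 3, 1, 2)))  # (64, 3, 32, 32)
```

In PyTorch, permute returns a view of the same data with reordered strides, so no pixel values are copied.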

Training

Putting these elements together, we can start assembling our train function:

```python
# main.py
import random

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18

# Fix the random seeds so runs are reproducible
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

CLASSES = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]


def train() -> None:
    batch_size = 64

    # Get data
    train_ds = CIFAR10(
        root="./data",
        train=True,
        download=True,
        transform=lambda data: np.array(data, dtype=np.float32) / 255,
    )
    train_dl = DataLoader(train_ds, batch_size=batch_size)

    # Get model architecture: a single output logit for "cat" vs "not cat"
    model = resnet18(num_classes=1)

    # TODO Train model
    # TODO save weights
```

Now, how do we train a model? For each batch in our dataset, we calculate the model’s predictions, measure the error against the true labels, and use back-propagation to update the model weights so that the error shrinks. We then repeat this over several epochs (full passes through the dataset) until the model is “trained”.
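The same loop can be sketched in plain Python on a toy problem: fitting a hypothetical straight line y = w·x + b with mean squared error, rather than a ResNet with binary cross-entropy. The gradients are derived by hand here, which is exactly the part back-propagation automates in PyTorch:

```python
import random

random.seed(0)

# Hypothetical noise-free toy data drawn from y = 2x + 1
data = [(x, 2.0 * x + 1.0) for x in [i / 10 for i in range(-20, 21)]]

w, b = 0.0, 0.0  # the model's weights, starting from an arbitrary guess
lr = 0.05
batch_size = 8

for epoch in range(200):  # one epoch = one full pass over the data
    random.shuffle(data)
    for start in range(0, len(data), batch_size):
        batch = data[start : start + batch_size]
        # Forward pass: predictions and mean squared error for this batch
        preds = [w * x + b for x, _ in batch]
        errors = [p - y for p, (_, y) in zip(preds, batch)]
        loss = sum(e * e for e in errors) / len(batch)  # the quantity being minimised
        # Backward pass: d(loss)/dw and d(loss)/db, derived by hand
        grad_w = sum(2 * e * x for e, (x, _) in zip(errors, batch)) / len(batch)
        grad_b = sum(2 * e for e in errors) / len(batch)
        # Update step: move each weight a little against its gradient
        w -= lr * grad_w
        b -= lr * grad_b

print(f"w={w:.3f}, b={b:.3f}")  # w=2.000, b=1.000
```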

Let’s go back to our notebook and step through the process with a single batch. First we need an optimiser, which decides how to update the model weights based on the gradients calculated by back-propagation:

```python
from torch.optim import Adam

optimizer = Adam(model.parameters(), lr=0.01)
```
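Adam keeps running averages of each weight’s gradient and squared gradient, and scales every step by their ratio, so lr sets roughly how far a weight can move per step. Below is a plain-Python sketch of the standard Adam update rule for a single scalar weight, minimising the hypothetical loss (w - 3)², with PyTorch’s default betas=(0.9, 0.999) and eps=1e-8. It is an illustration of the rule, not PyTorch’s implementation:

```python
import math
from typing import Callable


def adam_minimise(grad_fn: Callable[[float], float], w: float, lr: float, steps: int,
                  beta1: float = 0.9, beta2: float = 0.999, eps: float = 1e-8) -> float:
    m, v = 0.0, 0.0  # running averages of the gradient and the squared gradient
    for t in range(1, steps + 1):
        g = grad_fn(w)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        # Bias correction: m and v start at zero, so early averages would be too small
        m_hat = m / (1 - beta1**t)
        v_hat = v / (1 - beta2**t)
        w -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return w


# The hypothetical loss (w - 3)^2 has gradient 2 * (w - 3) and its minimum at w = 3
w_final = adam_minimise(lambda w: 2 * (w - 3), w=0.0, lr=0.05, steps=1000)
print(f"{w_final:.3f}")
```

On the very first step m_hat / sqrt(v_hat) is exactly ±1, so the weight moves by lr regardless of how large the gradient is, which is one reason Adam is forgiving about the scale of the loss.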

Now let’s calculate the error for a single batch:

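```python
loss_fn = torch.nn.BCEWithLogitsLoss()  # Expects raw logits; applies the sigmoid internally
x, y = next(iter(train_dl))

x = x.permute(0, 3, 1, 2)
# Turn class indices into binary float targets shaped like the output: 1.0 for "cat" (index 3)
y = (y == 3).float().unsqueeze(1)
y_pred = model(x)

error = loss_fn(y_pred, y)
print(error)
# tensor(0.0987, grad_fn=...)
```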

We’ve calculated the error for a single batch, but how do we use this to update the model weights?

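```python
error.backward()       # Back-propagate: compute the gradient of the error for every weight
optimizer.step()       # Update the weights using those gradients
optimizer.zero_grad()  # Reset the gradients, otherwise the next backward() adds to them
```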
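The zero_grad() call matters because PyTorch’s backward() adds to each parameter’s .grad rather than overwriting it. A toy plain-Python imitation of that behaviour for a single weight with loss (w·x - y)² (the Param class and backward function are hypothetical stand-ins, not PyTorch APIs):

```python
class Param:
    # Hypothetical stand-in for a tensor with a .grad attribute
    def __init__(self, value: float) -> None:
        self.value = value
        self.grad = 0.0


def backward(p: Param, x: float, y: float) -> None:
    # d/dw of (w*x - y)^2 is 2*(w*x - y)*x; accumulated into .grad the way PyTorch does
    p.grad += 2 * (p.value * x - y) * x


w = Param(1.0)
backward(w, x=2.0, y=3.0)
print(w.grad)  # -4.0
backward(w, x=2.0, y=3.0)  # no reset in between: the new gradient is added on top
print(w.grad)  # -8.0
w.grad = 0.0  # conceptually what optimizer.zero_grad() does for every parameter
```

In recent PyTorch versions zero_grad() actually sets .grad to None by default rather than filling it with zeros, which has the same effect on the next backward().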

Now let’s take this and apply it to train():

```python
# main.py
import argparse

import torch
import numpy as np
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
from torch.optim import Adam
from PIL import Image

CLASSES = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

# Use a GPU (CUDA or Apple MPS) if one is available
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

BATCH_SIZE = 1024
N_EPOCHS = 5
LR = 5e-5
SEED = 1234


def main():
    ...  # unchanged from the boilerplate above


def train() -> None:
    torch.manual_seed(SEED)

    # Get data: scale pixels to [0, 1]; labels become 1.0 for "cat" (index 3), else 0.0
    train_ds = CIFAR10(
        root="./data",
        train=True,
        download=True,
        transform=lambda data: np.array(data, dtype=np.float32) / 255,
        target_transform=lambda label: np.array([label == 3], dtype=np.float32),
    )
    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

    # Get model architecture
    model = resnet18(num_classes=1)
    model.to(DEVICE)
    optimizer = Adam(model.parameters(), lr=LR)

    # Train model
    loss_fn = torch.nn.BCEWithLogitsLoss()
    for epoch in range(N_EPOCHS):
        epoch_error = 0.0
        for batch_idx, (x, y) in enumerate(train_dl):
            x, y = x.to(DEVICE), y.to(DEVICE)
            x = x.permute(0, 3, 1, 2)
            y_pred = model(x)
            error = loss_fn(y_pred, y)
            error.backward()
            optimizer.step()
            optimizer.zero_grad()
            epoch_error += float(error)
        print(f"Epoch {epoch + 1} error: {epoch_error / len(train_dl):.4}")

    # Save weights once training has finished
    torch.save(model.state_dict(), "weights.pt")
```

Now we can train our model with:

```shell
uv run python main.py train
# TRAIN
# Epoch 1 error: 0.258
# Epoch 2 error: 0.254
# Epoch 3 error: 0.2389
# Epoch 4 error: 0.2368
# ...
```

(You may need to reduce BATCH_SIZE if you’re running out of memory on your system.)
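The target_transform in train() turns CIFAR-10’s ten classes into a yes/no question: index 3 in CLASSES is “cat”, so every label becomes 1.0 for cats and 0.0 for everything else. CIFAR-10 is balanced (5,000 training images per class), so only 10% of the targets are 1.0. The sketch below works out what that means for the epoch errors: a model that ignores the pixels and always predicts a 10% cat probability already scores 90% accuracy and a binary cross-entropy of about 0.325. That gives a rough yardstick; training errors below it suggest the model is picking up at least some signal from the images.

```python
import math

CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]
CAT = CLASSES.index("cat")  # 3


def to_binary(label: int) -> float:
    # Same idea as the target_transform: 1.0 for "cat", 0.0 for anything else
    return float(label == CAT)


def bce(p: float, y: float) -> float:
    # Binary cross-entropy for one prediction p = P(cat) and a target y in {0, 1}
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


# Balanced CIFAR-10 training labels: 5,000 of each class index 0..9
labels = [c for c in range(len(CLASSES)) for _ in range(5_000)]
targets = [to_binary(label) for label in labels]
print(sum(targets) / len(targets))  # 0.1

# Baseline that ignores the image: always answer "not cat", with P(cat) = 0.1
accuracy = sum(1 for y in targets if y == 0.0) / len(targets)
baseline = sum(bce(0.1, y) for y in targets) / len(targets)
print(accuracy)            # 0.9
print(f"{baseline:.3f}")   # 0.325
```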

Predict

Once we’ve written the training loop, writing the predict function is relatively straightforward:

```python
def predict(image_path: str) -> None:
    # Load image
    img = Image.open(image_path).convert("RGB")

    # Center-crop to square
    width, height = img.size
    min_dim = min(width, height)
    left = (width - min_dim) // 2
    top = (height - min_dim) // 2
    right = left + min_dim
    bottom = top + min_dim
    img = img.crop((left, top, right, bottom))

    # Resize to CIFAR-10's 32 x 32 and scale pixels to [0, 1], as in training
    img = img.resize((32, 32), Image.Resampling.LANCZOS)
    data = np.array(img, dtype=np.float32) / 255

    # Get model architecture
    model = resnet18(num_classes=1)

    # Load weights (map_location lets weights saved from a GPU load on the CPU)
    state_dict = torch.load("weights.pt", map_location="cpu")
    model.load_state_dict(state_dict)

    # Get prediction: eval mode for BatchNorm, and no gradients needed
    model.eval()
    x = torch.tensor(data).unsqueeze(0)  # add a batch dimension: [1, 32, 32, 3]
    x = x.permute(0, 3, 1, 2)            # -> [1, 3, 32, 32]
    with torch.no_grad():
        y_pred = model(x)
    y_pred = torch.sigmoid(y_pred)
    y_pred = y_pred.squeeze(0)

    # Print result (confidence is the probability of the predicted answer)
    prob_cat = float(y_pred)
    if prob_cat < 0.5:
        print(f"Image doesn't have a cat (confidence: {1 - prob_cat:.4})")
    else:
        print(f"Image has a cat (confidence: {prob_cat:.4})")
```
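The center-crop arithmetic is easy to get off by one, so here it is pulled out into a hypothetical pure function with a couple of checks. PIL’s Image.crop takes a (left, top, right, bottom) box, and the result is the largest square centred in the image, which predict then resizes to 32 x 32:

```python
def center_square_box(width: int, height: int) -> tuple[int, int, int, int]:
    # Largest centred square, as a (left, top, right, bottom) box for PIL's Image.crop
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return left, top, left + side, top + side


print(center_square_box(640, 480))  # (80, 0, 560, 480)
print(center_square_box(300, 301))  # (0, 0, 300, 300)
```

Cropping before resizing keeps the subject’s proportions intact; resizing a 640 x 480 photo straight to 32 x 32 would squash it horizontally.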

Try downloading some images and see whether the model detects a cat:

```shell
uv run python main.py predict <path-to-image>
```

To get this model working reasonably well, I had to tweak the learning rate, batch size and number of epochs while writing this. The number of epochs is particularly hard to guess: too few and the model is useless, too many and it won’t generalise beyond the training data.
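One common way to avoid guessing the epoch count, which this post doesn’t use, is early stopping: evaluate the model on held-out data after every epoch (torchvision’s CIFAR10(train=False) provides a separate 10,000-image test split), keep the weights from the best epoch, and stop once the validation error hasn’t improved for a few epochs, the “patience”. A framework-free sketch of just the stopping rule, run on a made-up list of validation errors:

```python
def early_stopping_epoch(val_errors: list[float], patience: int) -> tuple[int, int]:
    """Return (best_epoch, stop_epoch), both 1-indexed, for a sequence of validation errors."""
    best_epoch, best_error = 0, float("inf")
    for epoch, error in enumerate(val_errors, start=1):
        if error < best_error:
            best_epoch, best_error = epoch, error  # new best: this is where weights would be saved
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch  # no improvement for `patience` epochs: stop here
    return best_epoch, len(val_errors)


# Made-up validation errors: improving at first, then rising as the model overfits
history = [0.30, 0.27, 0.25, 0.24, 0.245, 0.25, 0.26, 0.27]
print(early_stopping_epoch(history, patience=3))  # (4, 7)
```

In a real loop, the weights saved at best_epoch (not the last ones) would be the ones written to weights.pt.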

Conclusion

This was a bit of a silly example, but I hope it was an informative read on how to get started with PyTorch. There are many directions we can go to improve this tutorial:

For those who are curious to learn more, I highly recommend the official tutorials on the PyTorch website!
