ProbableOdyssey

Detecting cats with simple PyTorch

Today we’ll illustrate how machine learning works in a lightning-fast tutorial on how to use PyTorch – one of the leading Python frameworks for machine learning.

One fun example to build is an image model that classifies whether there’s a cat in the picture. I want to keep this post relatively high-level; we’ll come back and expand on these topics in future entries.

Here’s what we’ll need for this recipe:

First we’ll start our project with uv

uv init torch-cat --python 3.12.7
cd torch-cat
uv add torch torchvision
uv add --dev ipython jupyter

Our end goal will be to produce a main.py script that we can run as a CLI application so that

We make no pretense: this will not be a reliable model in the slightest. In fact, the performance is going to be pretty poor in the first few iterations. But here we’ll learn the fundamental techniques for using PyTorch and tuning models.

We’ll put everything together in main.py, and we’ll run jupyter lab to develop and explore.

Here’s some boilerplate we can use for main.py:

# main.py
import argparse


def main() -> None:
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "command",
        type=str,
        choices=["train", "predict"],
    )
    parser.add_argument("path", type=str, nargs="?")

    args = parser.parse_args()

    if args.command == "train":
        print("TRAIN")
        train()
        return

    if args.command == "predict":
        if not args.path:
            print("image path needed for 'predict'")
            return
        print("PREDICT")
        predict(args.path)


def train() -> None:
    ...
    # TODO Get train dataset
    # TODO Get model architecture
    # TODO configure
    # TODO Train model
    # TODO save weights


def predict(image_path: str) -> None:
    ...
    # TODO Load image
    # TODO Get model architecture
    # TODO load weights
    # TODO get prediction
    # TODO print result


if __name__ == "__main__":
    main()
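
To see how the parser treats its two positional arguments, here is a small sketch that feeds it argument lists directly instead of reading the command line. The build_parser helper is a name made up for this illustration; the argument definitions mirror the boilerplate above.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical helper: the same two positional arguments as main.py
    parser = argparse.ArgumentParser()
    parser.add_argument("command", type=str, choices=["train", "predict"])
    parser.add_argument("path", type=str, nargs="?")  # optional positional
    return parser


parser = build_parser()

train_args = parser.parse_args(["train"])
print(train_args.command, train_args.path)
# train None

predict_args = parser.parse_args(["predict", "cat.jpg"])
print(predict_args.command, predict_args.path)
# predict cat.jpg
```

Because path uses nargs="?", it falls back to None when omitted, which is why main() has to check for it before calling predict.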

Model

We’ll use a ResNet18 – a famous neural network architecture in computer vision. There are a lot of subtleties to neural network design, and we may come back to discuss them in more depth another time. For now, we just want to get up and running with machine learning!

from torchvision.models import resnet18

model = resnet18(num_classes=1)
print(model)
# ResNet(
#   (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#   (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
#   ...

For all intents and purposes, this model is a function that maps multi-dimensional arrays (“tensors”) to other arrays. Let’s simulate what happens when we pass in a batch of images:

import torch

batch_size = 10
n_channels = 3
dimensions = (32, 32)

mock_x = torch.rand(batch_size, n_channels, *dimensions)
mock_y_pred = model(mock_x)

print(mock_y_pred.shape)
# torch.Size([10, 1])
print(mock_y_pred[:3, :])
# tensor([-0.7909,  0.8365,  1.3185], grad_fn=<SliceBackward0>)

Each of the 10 images in the batch is mapped to a single number, called a “logit”. We can normalise these values with a sigmoid to interpret them as probabilities of the positive class (here, “cat”):

mock_probs = torch.sigmoid(mock_y_pred)
print(mock_probs[:3, :])
# tensor([0.3120, 0.6977, 0.7889], grad_fn=<SliceBackward0>)
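
For intuition, the sigmoid is simple enough to write by hand. The sketch below uses plain Python rather than torch, and the three logits are made-up values chosen for easy checking, not outputs of the model.

```python
import math


def sigmoid(logit: float) -> float:
    # Squash any real number into the open interval (0, 1)
    return 1.0 / (1.0 + math.exp(-logit))


for logit in (-2.0, 0.0, 2.0):
    print(f"{logit:+.1f} -> {sigmoid(logit):.4f}")
# -2.0 -> 0.1192
# +0.0 -> 0.5000
# +2.0 -> 0.8808
```

A logit of zero means the model is undecided, negative logits lean towards “not cat”, and positive logits lean towards “cat”.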

Dataset

We’ll use the CIFAR-10 dataset for this task. It’s a common benchmarking dataset used for computer vision models, and it’s conveniently built into torchvision. This dataset comprises 60,000 colour images (32 x 32) across 10 classes – one of which is “cat”. This saves us a lot of time in acquiring and pre-processing the data.

The reality of machine learning is that the majority of the time is spent on data engineering. Improving the quality of the data generally yields far bigger gains than any other technique, but it’s definitely the less “sexy” part of machine learning. For the purpose of this tutorial, we’ll skip over important details so we can focus more on learning the basics of modelling. But I want to call this out for people who want to progress further in machine learning.

In a jupyter notebook, let’s explore the data:

from torchvision.datasets import CIFAR10

train_ds = CIFAR10(root="./data", train=True, download=True)

data, label = train_ds[0]

print(data)
print(label)
# <PIL.Image.Image image mode=RGB size=32x32 at 0x7C9C5E375A60>
# 6

Let’s plot a few of the samples in this dataset:

import random

import numpy as np
import plotly.graph_objects as go  # not installed above: uv add --dev plotly
from plotly.subplots import make_subplots

CLASSES = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

samples = [train_ds[random.randint(0, len(train_ds) - 1)] for _ in range(4)]

fig = make_subplots(
    rows=2,
    cols=2,
    subplot_titles=[CLASSES[label] for _, label in samples],
)
for i, (data, _) in enumerate(samples):
    row = i // 2 + 1
    col = i % 2 + 1
    fig.add_trace(go.Image(z=np.array(data)), row=row, col=col)

fig.update_layout(height=600, width=600, showlegend=False)
fig.show()

Ideally, we want this dataset to output numpy arrays instead of PIL images. The CIFAR10 class accepts transform and target_transform arguments that allow us to pass in functions to be applied to the data and labels before they’re returned:

import numpy as np

train_ds = CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=lambda data: np.array(data, dtype=np.float32) / 255,
)

Neural network operations are vectorised in order to speed up computations drastically, meaning that they accept a “batch” of images instead of a single image at a time. We can use the DataLoader class in torch to collate samples of the dataset together into batches and convert them to tensors:

from torch.utils.data import DataLoader

train_dl = DataLoader(train_ds, batch_size=64)

x, y = next(iter(train_dl))
print(x.shape)
print(y.shape)
# torch.Size([64, 32, 32, 3])
# torch.Size([64])
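
As a rough mental model of the grouping step, here is a plain-Python sketch. The iter_batches function is a hypothetical stand-in; the real DataLoader also stacks each group into tensors and can shuffle, which this leaves out.

```python
import math


def iter_batches(samples: list, batch_size: int):
    # Yield consecutive slices; the last one may be smaller
    for start in range(0, len(samples), batch_size):
        yield samples[start:start + batch_size]


samples = list(range(10))  # stand-ins for 10 (image, label) pairs
batches = list(iter_batches(samples, batch_size=4))
print(batches)
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# CIFAR-10's training split has 50,000 images, so batch_size=64 gives:
print(math.ceil(50_000 / 64))
# 782
```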

By convention, operations in PyTorch anticipate inputs with dimensions [batch_idx, channel_idx, *spatial_idxs], so we’ll need to make sure we permute the dimensions before we feed them into the model:

x = x.permute(0, 3, 1, 2)
print(x.shape)
# torch.Size([64, 3, 32, 32])
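
To check that permute only reorders the axes and leaves the values alone, here is a small sketch on a made-up tensor (the sizes are arbitrary and much smaller than a real batch):

```python
import torch

# [batch, height, width, channel], filled with distinct integers
x = torch.arange(2 * 4 * 4 * 3).reshape(2, 4, 4, 3)
# [batch, channel, height, width]
x_chw = x.permute(0, 3, 1, 2)

print(x_chw.shape)
# torch.Size([2, 3, 4, 4])

# The same element, addressed with reordered indices
print(bool(x[1, 2, 3, 0] == x_chw[1, 0, 2, 3]))
# True
```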

Training

Putting these elements together, we can start assembling our train function:

# main.py
import random

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18

random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

CLASSES = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]


def train() -> None:
    batch_size = 64

    # Get data
    train_ds = CIFAR10(
        root="./data",
        train=True,
        download=True,
        transform=lambda data: np.array(data, dtype=np.float32) / 255,
    )
    train_dl = DataLoader(train_ds, batch_size=batch_size)

    # Get model architecture
    model = resnet18(num_classes=1)  # one output logit: "cat" vs "not cat"

    # TODO Train model
    # TODO save weights

Now how do we train a model? We need to iterate over each batch in our dataset, calculate the model predictions, calculate the error compared to the true labels, and use back-propagation to work out how to update the model weights to minimise that error. We’ll then repeat this process across several epochs until the model is “trained”.
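
Before doing this with PyTorch, it may help to see the same loop on a toy problem where the gradients can be written by hand. This is a plain-Python sketch under made-up assumptions: one feature per sample, a two-parameter logistic model, and invented hyperparameters. It is not the model we train below.

```python
import math
import random

random.seed(0)

# Toy data: one feature per sample, label 1.0 when the feature is positive
xs = [random.uniform(-1.0, 1.0) for _ in range(200)]
ys = [1.0 if x > 0 else 0.0 for x in xs]

w, b = 0.0, 0.0  # the "model weights"
lr, batch_size, n_epochs = 0.5, 20, 20


def predict_prob(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-(w * x + b)))


def mean_error() -> float:
    # Binary cross-entropy averaged over the whole dataset
    eps = 1e-12
    total = 0.0
    for x, y in zip(xs, ys):
        p = predict_prob(x)
        total -= y * math.log(p + eps) + (1 - y) * math.log(1 - p + eps)
    return total / len(xs)


initial_error = mean_error()
for epoch in range(n_epochs):
    for start in range(0, len(xs), batch_size):
        batch = list(zip(xs[start:start + batch_size], ys[start:start + batch_size]))
        # Gradient of the mean binary cross-entropy w.r.t. w and b
        grad_w = sum((predict_prob(x) - y) * x for x, y in batch) / len(batch)
        grad_b = sum(predict_prob(x) - y for x, y in batch) / len(batch)
        # Step against the gradient to reduce the error
        w -= lr * grad_w
        b -= lr * grad_b

final_error = mean_error()
print(final_error < initial_error)
# True
```

In PyTorch the two gradient lines are replaced by back-propagation and the two update lines by an optimiser, but the shape of the loop stays the same.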

Let’s go back to our notebook and step through the process with a single batch. First we’ll need an optimiser, which decides how to update the model weights based on the gradients calculated from back-propagation:

from torch.optim import Adam

optimizer = Adam(model.parameters(), lr=0.01)

Now let’s calculate the error for a single batch:

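loss_fn = torch.nn.BCEWithLogitsLoss()  # This expects logits
x, y = next(iter(train_dl))

x = x.permute(0, 3, 1, 2)
# Binary float targets shaped [batch, 1] like y_pred: 1.0 for "cat" (class index 3)
y = (y == 3).float().unsqueeze(1)
y_pred = model(x)

error = loss_fn(y_pred, y)
print(error)
# tensor(0.0987, grad_fn=...)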
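
To get a feel for what this error measures, here is the per-sample formula written out in plain Python. The bce_with_logits name is made up for this sketch; PyTorch’s own implementation uses a more numerically stable formulation and, by default, averages over the batch.

```python
import math


def bce_with_logits(logit: float, target: float) -> float:
    # Convert the logit to a probability, then score it against the target
    prob = 1.0 / (1.0 + math.exp(-logit))
    return -(target * math.log(prob) + (1.0 - target) * math.log(1.0 - prob))


print(f"{bce_with_logits(0.0, 1.0):.4f}")  # undecided model
# 0.6931
print(f"{bce_with_logits(3.0, 1.0):.4f}")  # confident and right
# 0.0486
print(f"{bce_with_logits(3.0, 0.0):.4f}")  # confident and wrong
# 3.0486
```

Confident mistakes are punished far more heavily than confident correct answers are rewarded, which is what pushes the weights in the right direction.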

We’ve calculated the error for a single batch, but how do we use this to update the model weights?

error.backward()
optimizer.step()
optimizer.zero_grad()  # Clear the gradients so they don't accumulate into the next step
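
Those three calls are easier to follow on a single number. The sketch below is an assumption-laden toy: one scalar weight, a made-up error of (w - 3)², and plain SGD instead of Adam so the arithmetic can be checked by hand.

```python
import torch

w = torch.tensor(1.0, requires_grad=True)  # a single "weight"
optimizer = torch.optim.SGD([w], lr=0.1)   # plain SGD keeps the update rule simple

error = (w - 3.0) ** 2                     # smallest when w == 3
error.backward()                           # fills w.grad with d(error)/dw = 2 * (1 - 3)
print(w.grad)
# tensor(-4.)

optimizer.step()                           # w <- w - lr * w.grad = 1 - 0.1 * (-4), about 1.4
print(w.item())

optimizer.zero_grad()                      # clear the gradient before the next batch
```

The weight moves from 1.0 towards 3.0, the value that minimises the error. Without zero_grad(), the next backward() call would add its gradient on top of the old one.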

Now let’s take this and apply it to train():

# main.py
import argparse

import torch
import numpy as np
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
from torch.optim import Adam
from PIL import Image

CLASSES = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

BATCH_SIZE = 1024
N_EPOCHS = 5
LR = 5e-5
SEED = 1234


def main():
    ...


def train() -> None:
    torch.manual_seed(SEED)  # make weight initialisation and shuffling reproducible

    # Get data
    train_ds = CIFAR10(
        root="./data",
        train=True,
        download=True,
        transform=lambda data: np.array(data, dtype=np.float32) / 255,
        target_transform=lambda label: np.array([label == 3], dtype=np.float32)
    )
    train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)

    # Get model architecture
    model = resnet18(num_classes=1)
    model.to(DEVICE)
    optimizer = Adam(model.parameters(), lr=LR)

    # Train model
    loss_fn = torch.nn.BCEWithLogitsLoss()
    for epoch in range(N_EPOCHS):
        epoch_error = 0.0
        for batch_idx, (x, y) in enumerate(train_dl):
            x, y = x.to(DEVICE), y.to(DEVICE)
            x = x.permute(0, 3, 1, 2)
            y_pred = model(x)
            error = loss_fn(y_pred, y)
            error.backward()
            optimizer.step()
            optimizer.zero_grad()
            epoch_error += float(error)
        print(f"Epoch {epoch + 1} error: {epoch_error / len(train_dl):.4}")

        # Save weights
        torch.save(model.state_dict(), "weights.pt")

Now we can train our model with

uv run python main.py train
# TRAIN
# Epoch 1 error: 0.258
# Epoch 2 error: 0.254
# Epoch 3 error: 0.2389
# Epoch 4 error: 0.2368
# ...

(You may need to lower the BATCH_SIZE constant if you’re running out of memory on your system.)

Predict

Once we’ve written the training loop, writing the predict function is relatively straightforward:

def predict(image_path: str) -> None:
    # Load image
    img = Image.open(image_path).convert("RGB")

    # Center-crop to square
    width, height = img.size
    min_dim = min(width, height)
    left = (width - min_dim) // 2
    top = (height - min_dim) // 2
    right = left + min_dim
    bottom = top + min_dim
    img = img.crop((left, top, right, bottom))

    # Resize to target size
    img = img.resize((32, 32), Image.Resampling.LANCZOS)
    data = np.array(img, dtype=np.float32) / 255

    # Get model architecture
    model = resnet18(num_classes=1)

    # Load weights
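    # map_location lets weights trained on a GPU load on a CPU-only machine
    state_dict = torch.load("weights.pt", map_location="cpu", weights_only=True)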
    model.load_state_dict(state_dict)

    # Get prediction
    model.eval()
    x = torch.tensor(data).unsqueeze(0)
    x = x.permute(0, 3, 1, 2)
    y_pred = model(x)
    y_pred = torch.sigmoid(y_pred)
    y_pred = y_pred.squeeze(0)

    # Print result
    prob_cat = float(y_pred)
    if prob_cat < 0.5:
        # Confidence in "no cat" is the complement of the cat probability
        print(f"Image doesn't have a cat (confidence: {1 - prob_cat:.4})")
    else:
        print(f"Image has a cat (confidence: {prob_cat:.4})")
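
The centre-crop arithmetic in predict can be checked on its own. Here it is pulled out into a small function; center_crop_box is a hypothetical helper used only for this illustration.

```python
def center_crop_box(width: int, height: int) -> tuple[int, int, int, int]:
    # (left, top, right, bottom) of the largest centred square
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return left, top, left + side, top + side


print(center_crop_box(640, 480))  # landscape: trim the sides
# (80, 0, 560, 480)
print(center_crop_box(300, 500))  # portrait: trim top and bottom
# (0, 100, 300, 400)
```

Cropping to a square first means the resize to 32 x 32 doesn’t stretch the image.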

Try downloading some images and see whether the model detects a cat in them:

uv run python main.py predict <path-to-image>

To get this model working sufficiently well, I had to tweak the learning rate, batch size, and number of epochs while writing this. The number of epochs is particularly difficult to guess – too few and the model is useless, too many and the model won’t generalise beyond the training data.

Conclusion

This was a bit of a silly example, but I hope this was an informative read on how to get started with PyTorch. There are many directions we can go for improving this tutorial:

For those who are curious to learn more, I highly recommend the official tutorials on the PyTorch website!
