
CIFAR-10

Introduction

Here we walk through an example project for training a model to classify images from the CIFAR-10 dataset. It combines most of the concepts covered in the guides so far: a simple model, data pipeline, and optimizer, showing how to:

  • Use a variety of special attributes to define the config schema
  • Load a config from a yaml file
  • Apply CLI arguments to the config
  • Load a sweep file to sample from a set of configurations

Schema

The schema is defined in a proto file that describes the full training configuration. We will walk through it piece by piece below.

The model

The model is a simple convolutional neural network followed by a multi-layer perceptron. The proto definition looks like this:

// Model configuration
message Model {
    // Conv blocks
    ConvNet conv_net = 1;
    // MLP head
    MLP head = 2;
}

// Convolutional neural network configuration
message ConvNet {
    // Conv layer configuration
    ConvBlock block = 1;
    // Number of layers
    uint32 num_layers = 2 [(pgml.default).uint32 = 2];
}

// Multi-layer perceptron configuration
message MLP {
    // Linear layer configuration
    LinearBlock block = 1;
    // Number of layers
    uint32 num_layers = 2 [(pgml.default).uint32 = 2];
}

// Convolutional layer configuration
message ConvBlock {
    option (pgml.factory) = "cifar10.modules.ConvBlock";
    // Number of output channels
    uint32 out_channels = 1 [(pgml.default).uint32 = 128];
    // Square kernel size
    uint32 kernel_size = 2 [(pgml.default).uint32 = 3];
    // Square pool size
    uint32 pool_size = 3 [(pgml.default).uint32 = 2];
    // Activation function
    Activation activation = 4 [(pgml.default).enum = "GELU"];
}

// Linear layer configuration
message LinearBlock {
    option (pgml.factory) = "cifar10.modules.LinearBlock";
    // Number of output features
    uint32 out_features = 1 [(pgml.default).uint32 = 128];
    // Activation function
    Activation activation = 2 [(pgml.default).enum = "GELU"];
}

// Activation function
enum Activation {
    // GELU activation
    GELU = 0;
    // ReLU activation
    RELU = 1;
}

Concepts that we've used here are:

  • Nesting: we define a Model message that contains two nested messages, ConvNet and MLP, which in turn contain configurations for the convolutional and linear blocks.
  • pgml.factory: a special option that tells py-gen-ml which class can be instantiated from the values in the message (see the sketch after this list).
  • pgml.default: a special option that tells py-gen-ml to fall back to a default value when a field is not set.
  • Enums: we define an enum for the activation function.
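
To make pgml.factory and pgml.default concrete, here is a minimal sketch of how the generated config classes could be used directly. It assumes the generated module is pgml_out.config_base (as imported in the modules code further down) and that the generated classes expose a build() method, as used later in ConvNet; the exact generated API may differ.

import pgml_out.config_base as base

# Fields that are not set fall back to their pgml.default values.
block_cfg = base.ConvBlock()  # out_channels=128, kernel_size=3, pool_size=2, activation=GELU

# pgml.factory points this message at cifar10.modules.ConvBlock, so build()
# returns an instance of that class constructed from the config values.
conv_block = block_cfg.build()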

The data

We define a Data message that contains the batch size and the number of epochs to train for. The proto definition looks like this:

// Data configuration
message Data {
    // Batch size for a single GPU
    uint32 batch_size = 1 [(pgml.default).uint32 = 32];
    // Number of epochs to train
    uint32 num_epochs = 2 [(pgml.default).uint32 = 10];
}

The optimizer

We define an Optimizer message that contains the learning rate and the decay rate. The proto definition looks like this:

// Optimizer configuration
message Optimizer {
    // Learning rate
    float learning_rate = 1 [(pgml.default).float = 1e-4];
    // Decay rate
    float beta1 = 2 [(pgml.default).float = 0.99];
}

The project

We define a Project message that contains the model, the optimizer, and the data. The proto definition looks like this:

syntax = "proto3";

package cifar10;

import "py_gen_ml/extensions.proto";


// Global configuration
message Project {
    option (pgml.cli) = {
        enable: true
        arg: {name: "conv_activation", path: "net.conv_net.block.activation"}
        arg: {name: "head_activation", path: "net.head.block.activation"}
        arg: {name: "num_conv_layers", path: "net.conv_net.num_layers"}
        arg: {name: "num_mlp_layers", path: "net.head.num_layers"}
    };
    // Model configuration
    Model net = 1;
    // Optimizer configuration
    Optimizer optimizer = 2;
    // Data configuration
    Data data = 3;
}

Note the arg entries in pgml.cli: they give short, explicit CLI names to deeply nested fields and disambiguate fields that share a name across different messages (activation and num_layers each appear in both the convolutional and MLP branches), so their values can still be set from the command line.
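
To illustrate what those arg entries buy us, here is a sketch of how a flat CLI argument lands on the nested config. Constructing ProjectArgs by hand and the pgml_out.cli_args module path are assumptions for illustration; in the actual entrypoint below, the arguments are parsed by typer.

import pgml_out.cli_args as cli_args
import pgml_out.config_base as base

config = base.Project.from_yaml_files(['configs/base/default.yaml'])

# Passing --num-conv-layers 3 on the command line sets this flat field, which
# apply_cli_args propagates to net.conv_net.num_layers.
args = cli_args.ProjectArgs(num_conv_layers=3)
config = config.apply_cli_args(args)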

The entrypoint

To launch the training, we create a function that can do a range of things:

  1. Load a yaml config file
  2. Load multiple yaml config files and merge them
  3. Load a yaml sweep file to generate patches and apply them to the config using Optuna
  4. Apply CLI arguments to the config
  5. In all of these cases, train the model and return the accuracy

Even though this sounds like a lot, with py-gen-ml it is actually quite easy to do. The main function becomes:

Entrypoint
@pgml.pgml_cmd(app=app)
def main(
    config_paths: List[str] = typer.Option(..., help='Paths to config files'),
    sweep_paths: List[str] = typer.Option(default_factory=list, help='Paths to sweep files'),
    cli_args: cli_args.ProjectArgs = typer.Option(...),
) -> None:
    config = base.Project.from_yaml_files(config_paths)
    config = config.apply_cli_args(cli_args)

    if len(sweep_paths) == 0:
        train_model(config, trial=None)
        return

    sweep_config = sweep.ProjectSweep.from_yaml_files(sweep_paths)

    def objective(trial: optuna.Trial) -> float:
        sampler = pgml.OptunaSampler(trial=trial)
        patch = sampler.sample(sweep_config)
        accuracy = train_model(project=config.merge(patch), trial=trial)
        return accuracy

    study = optuna.create_study(direction='maximize')
    study.optimize(objective, n_trials=100)

Let's break this down a bit more.

  • First, we load the project config from one or more YAML files. The files are merged before the result is passed to the Pydantic model validator (see the sketch below).
  • Next, we apply the CLI arguments to the config.
  • If no sweep files were given, we simply train the model and return.
  • Otherwise, we load the sweep files and define an objective function for Optuna. The objective samples a patch from the sweep config, merges it into the base config, and calls the exact same train_model function as the no-sweep path, returning the accuracy.
  • Finally, we create an Optuna study and optimize the objective function over 100 trials.
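
For the merge step, here is a minimal sketch, assuming a hypothetical override file configs/base/small.yaml; fields from later files take precedence over earlier ones before the merged result is validated by Pydantic.

import pgml_out.config_base as base

# Later files override earlier ones field by field.
config = base.Project.from_yaml_files([
    'configs/base/default.yaml',
    'configs/base/small.yaml',  # hypothetical override, e.g. sets data.batch_size: 16
])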

The configuration

The configuration is defined in a YAML file that we pass when training the model. It looks like this:

Config
# yaml-language-server: $schema=schemas/project.json
net:
  conv_net:
    block:
      activation: GELU
      kernel_size: 3
      out_channels: 32
      pool_size: 2
  head:
    block:
      activation: GELU
      out_features: 128
    num_layers: 3
optimizer:
  beta1: 0.99
  learning_rate: 0.0001
data:
  batch_size: 32
  num_epochs: 10

Launching the training

To launch the training, we can use the following command:

python examples/cifar10/src/cifar10/train.py \
    --config-paths \
    configs/base/default.yaml

Showing CLI arguments

To show the CLI arguments, we can use the following command:

python examples/cifar10/src/cifar10/train.py --help

This will show the following:

 Usage: train.py [OPTIONS]

╭─ Options ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ *  --config-paths              TEXT         Path to config file [default: None] [required]                                             │
│    --sweep-paths               TEXT         Type of config to use [default: <class 'list'>]                                            │
│    --out-channels              INTEGER      Number of output channels. Maps to 'net.conv_net.block.out_channels' [default: None]       │
│    --kernel-size               INTEGER      Square kernel size. Maps to 'net.conv_net.block.kernel_size' [default: None]               │
│    --pool-size                 INTEGER      Square pool size. Maps to 'net.conv_net.block.pool_size' [default: None]                   │
│    --out-features              INTEGER      Number of output features. Maps to 'net.head.block.out_features' [default: None]           │
│    --batch-size                INTEGER      Batch size for a single GPU. Maps to 'data.batch_size' [default: None]                     │
│    --num-epochs                INTEGER      Number of epochs to train. Maps to 'data.num_epochs' [default: None]                       │
│    --learning-rate             FLOAT        Learning rate. Maps to 'optimizer.learning_rate' [default: None]                           │
│    --beta1                     FLOAT        Decay rate. Maps to 'optimizer.beta1' [default: None]                                      │
│    --conv-activation           [gelu|relu]  Activation function. Maps to 'net.conv_net.block.activation' [default: None]               │
│    --head-activation           [gelu|relu]  Activation function. Maps to 'net.head.block.activation' [default: None]                   │
│    --num-conv-layers           INTEGER      Number of layers. Maps to 'net.conv_net.num_layers' [default: None]                        │
│    --num-mlp-layers            INTEGER      Number of layers. Maps to 'net.head.num_layers' [default: None]                            │
│    --install-completion                     Install completion for the current shell.                                                  │
│    --show-completion                        Show completion for the current shell, to copy it or customize the installation.           │
│    --help                                   Show this message and exit.                                                                │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯

Starting a sweep

We define the following sweep configuration:

Sweep
# yaml-language-server: $schema=schemas/project.json
optimizer:
  beta1:
    low: 0.9
    high: 0.99
  learning_rate:
    log_low: 1e-5
    log_high: 1e-3

This sweeps beta1 uniformly between 0.9 and 0.99 and samples the learning rate on a logarithmic scale between 1e-5 and 1e-3.
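
As a hedged sketch of what the OptunaSampler roughly does with these ranges (the parameter names and exact calls are assumptions, not the library's documented behaviour):

import optuna


def sample_optimizer(trial: optuna.Trial) -> dict[str, float]:
    return {
        # beta1: uniform between low and high
        'beta1': trial.suggest_float('optimizer.beta1', 0.9, 0.99),
        # learning_rate: log-uniform between log_low and log_high
        'learning_rate': trial.suggest_float('optimizer.learning_rate', 1e-5, 1e-3, log=True),
    }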

To run the sweep, we can use the following command:

python examples/cifar10/src/cifar10/train.py \
    --config-paths \
    configs/base/default.yaml \
    --sweep-paths \
    configs/sweep/lr_beta1.yaml

Remaining code

Modules

We define the modules here:

Modules
from typing import Self

import pgml_out.config_base as base
import torch

CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


def create_activation(activation: base.Activation) -> torch.nn.Module:
    if activation == base.Activation.GELU:
        return torch.nn.GELU()
    elif activation == base.Activation.RELU:
        return torch.nn.ReLU()
    else:
        raise ValueError(f'Invalid activation function: {activation}')


class LinearBlock(torch.nn.Module):
    """Linear layer with activation configuration"""

    def __init__(self, out_features: int, activation: base.Activation) -> None:
        super().__init__()
        self.linear = torch.nn.LazyLinear(out_features=out_features)
        self.activation = create_activation(activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.activation(self.linear(x))


class ConvBlock(torch.nn.Module):
    """Convolutional layer with activation configuration"""

    def __init__(self, out_channels: int, kernel_size: int, pool_size: int, activation: base.Activation) -> None:
        super().__init__()
        self.conv = torch.nn.LazyConv2d(out_channels=out_channels, kernel_size=kernel_size)
        self.pool = torch.nn.MaxPool2d(kernel_size=pool_size)
        self.activation = create_activation(activation)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.activation(self.pool(self.conv(x)))


class ConvNet(torch.nn.Module):

    def __init__(self, block: base.ConvBlock, num_layers: int) -> None:
        super().__init__()
        self.layers = [block.build() for _ in range(num_layers)]
        self.net = torch.nn.Sequential(*self.layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class MLP(torch.nn.Module):

    def __init__(self, block: base.LinearBlock, num_layers: int) -> None:
        super().__init__()
        self.layers = [block.build() for _ in range(num_layers)]
        self.net = torch.nn.Sequential(*self.layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class Model(torch.nn.Module):

    def __init__(self, conv_net: ConvNet, head: MLP) -> None:
        super().__init__()
        self.conv_net = conv_net
        self.head = head
        self.class_logits = torch.nn.LazyLinear(out_features=len(CLASSES),)
        self.flatten = torch.nn.Flatten()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv_net(x)
        x = self.head(x)
        x = self.flatten(x)
        x = self.class_logits(x)
        return x

    @classmethod
    def from_config(cls, config: base.Model) -> Self:
        conv_net = ConvNet(
            block=config.conv_net.block,
            num_layers=config.conv_net.num_layers,
        )
        head = MLP(
            block=config.head.block,
            num_layers=config.head.num_layers,
        )
        return cls(conv_net, head)

Data

We have defined the data module here:

Data
import os
from typing import Any

import torch
import torchvision
import torchvision.transforms as transforms


def get_transform() -> transforms.Compose:
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ],)


def get_data_loader(transform: transforms.Compose, batch_size: int, train: bool) -> torch.utils.data.DataLoader[Any]:
    # The same helper serves both the train and the test split.
    dataset = torchvision.datasets.CIFAR10(
        root=f"{os.environ['HOME']}/data/torchvision",
        train=train,
        download=train,
        transform=transform,
    )
    return torch.utils.data.DataLoader[Any](
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        prefetch_factor=4,
    )

Trainer

We have defined the trainer here:

Trainer
class Trainer:

    def __init__(
        self,
        model: torch.nn.Module,
        train_loader: torch.utils.data.DataLoader[Any],
        test_loader: torch.utils.data.DataLoader[Any],
        criterion: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        accuracy_metric_train: torchmetrics.classification.MulticlassAccuracy,
        accuracy_metric_test: torchmetrics.classification.MulticlassAccuracy,
        train_loss_metric: torchmetrics.MeanMetric,
        test_loss_metric: torchmetrics.MeanMetric,
        trial: Optional[optuna.Trial] = None,
    ) -> None:
        self._train_loader = train_loader
        self._test_loader = test_loader
        self._model = model
        self._criterion = criterion
        self._optimizer = optimizer
        self._accuracy_metric_train = accuracy_metric_train
        self._accuracy_metric_test = accuracy_metric_test
        self._train_loss_metric = train_loss_metric
        self._test_loss_metric = test_loss_metric
        self._trial = trial

    def train(self, num_epochs: int) -> float:
        device = get_device()
        step = 0
        for epoch in tqdm.trange(num_epochs, position=0, desc='Epoch'):
            batch_bar = tqdm.tqdm(self._train_loader, position=1, desc='Batch')
            for inputs, labels in batch_bar:
                inputs = inputs.to(device)
                labels = labels.to(device)

                self._optimizer.zero_grad()
                outputs = self._model(inputs)
                loss = self._criterion(outputs, labels)
                loss.backward()
                self._optimizer.step()

                self._train_loss_metric.update(loss)
                self._accuracy_metric_train.update(outputs, labels)
                step += 1

                if step % 10 == 0:
                    batch_bar.set_postfix(
                        loss=self._train_loss_metric.compute().item(),
                        accuracy=self._accuracy_metric_train.compute().item(),
                    )
                    self._train_loss_metric.reset()
                    self._accuracy_metric_train.reset()

            self._evaluate()
            if self._trial is not None:
                self._trial.report(self._accuracy_metric_test.compute().item(), epoch)
        return self._accuracy_metric_test.compute().item()

    @torch.inference_mode()
    def _evaluate(self) -> None:
        self._model.eval()
        self._accuracy_metric_test.reset()
        self._test_loss_metric.reset()
        for images, labels in tqdm.tqdm(self._test_loader, position=1, desc='Evaluating'):
            images = images.to(get_device())
            labels = labels.to(get_device())
            outputs = self._model(images)
            loss = self._criterion(outputs, labels)
            self._test_loss_metric.update(loss)
            self._accuracy_metric_test.update(outputs, labels)
        print(f'Test accuracy: {self._accuracy_metric_test.compute().item()}')
        print(f'Test loss: {self._test_loss_metric.compute().item()}')
        self._model.train()

Train function

The train function that instantiates all the components and calls the trainer is defined here:

Train function
def train_model(project: base.Project, trial: typing.Optional[optuna.Trial] = None) -> float:
    rich.print(project)

    transform = get_transform()
    train_loader = get_data_loader(transform=transform, batch_size=project.data.batch_size, train=True)
    test_loader = get_data_loader(transform=transform, batch_size=project.data.batch_size, train=False)

    device = get_device()
    print(f'device {device}')

    model = Model.from_config(config=project.net).to(device)
    path = pathlib.Path(f"{os.environ['HOME']}/gen_ml/logs/{uuid.uuid4()}")
    print(f'Storing logs at {path}')

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(params=model.parameters(), lr=project.optimizer.learning_rate)
    accuracy_metric_train = get_accuracy_metric(num_classes=len(CLASSES))
    accuracy_metric_test = get_accuracy_metric(num_classes=len(CLASSES))
    train_loss_metric = torchmetrics.MeanMetric().to(device)
    test_loss_metric = torchmetrics.MeanMetric().to(device)
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        test_loader=test_loader,
        criterion=criterion,
        optimizer=optimizer,
        accuracy_metric_train=accuracy_metric_train,
        accuracy_metric_test=accuracy_metric_test,
        train_loss_metric=train_loss_metric,
        test_loss_metric=test_loss_metric,
        trial=trial,
    )
    accuracy = trainer.train(num_epochs=project.data.num_epochs)
    return accuracy