CIFAR-10
Introduction
Here we will walk through an example project for training a model to classify images from the CIFAR10 dataset. This combines most of the concepts we have seen in the guides so far. It uses a simple model, dataset and optimizer, and shows how to:
- Use a variety of special attributes to define the config schema
- Load a config from a yaml file
- Apply CLI arguments to the config
- Load a sweep file to sample from a set of configurations
Schema
The schema is defined in a proto file. This file describes the full configuration that we will use to train the model. We will walk through it piece by piece below.
The model
The model definition is a simple convolutional neural network followed by a multi-layer perceptron. In the proto definition, it looks like this:
// Model configuration
message Model {
// Conv blocks
ConvNet conv_net = 1;
// MLP head
MLP head = 2;
}
// Convolutional neural network configuration
message ConvNet {
// Conv layer configuration
ConvBlock block = 1;
// Number of layers
uint32 num_layers = 2 [(pgml.default).uint32 = 2];
}
// Multi-layer perceptron configuration
message MLP {
// Linear layer configuration
LinearBlock block = 1;
// Number of layers
uint32 num_layers = 2 [(pgml.default).uint32 = 2];
}
// Convolutional layer configuration
message ConvBlock {
option (pgml.factory) = "cifar10.modules.ConvBlock";
// Number of output channels
uint32 out_channels = 1 [(pgml.default).uint32 = 128];
// Square kernel size
uint32 kernel_size = 2 [(pgml.default).uint32 = 3];
// Square pool size
uint32 pool_size = 3 [(pgml.default).uint32 = 2];
// Activation function
Activation activation = 4 [(pgml.default).enum = "GELU"];
}
// Linear layer configuration
message LinearBlock {
option (pgml.factory) = "cifar10.modules.LinearBlock";
// Number of output features
uint32 out_features = 1 [(pgml.default).uint32 = 128];
// Activation function
Activation activation = 2 [(pgml.default).enum = "GELU"];
}
// Activation function
enum Activation {
// GELU activation
GELU = 0;
// ReLU activation
RELU = 1;
}
Concepts that we've used here are:
- Nesting: we define a Model message that contains two nested messages, ConvNet and MLP, which in turn contain blocks of convolutional and linear layers.
- pgml.factory: a special attribute that tells py-gen-ml that an instance of the given class can be built from these values.
- pgml.default: a special attribute that tells py-gen-ml to use a default value if the field is not set.
- Enums: we define an enum for the activation function.
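To make the factory and default behaviour concrete, here is a small sketch of how the generated config classes can be used (assuming, as the module code further down does, that py-gen-ml generates Pydantic models into pgml_out.config_base; the keyword-argument construction and printed values are illustrative):

import pgml_out.config_base as base

# Fields that are left unset fall back to the pgml.default values from the schema
block_cfg = base.ConvBlock(out_channels=64)
print(block_cfg.kernel_size)   # 3, the default declared in the proto
print(block_cfg.activation)    # Activation.GELU

# Because ConvBlock declares pgml.factory = "cifar10.modules.ConvBlock",
# build() constructs an instance of that class from the configured values
conv_layer = block_cfg.build()  # an instance of cifar10.modules.ConvBlock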
The data
We define a Data message that contains the batch size and the number of epochs to train. In the proto definition, it looks like this:
// Data configuration
message Data {
// Batch size for a single GPU
uint32 batch_size = 1 [(pgml.default).uint32 = 32];
// Number of epochs to train
uint32 num_epochs = 2 [(pgml.default).uint32 = 10];
}
The optimizer
We define an Optimizer message that contains the learning rate and the decay rate. In the proto definition, it looks like this:
// Optimizer configuration
message Optimizer {
// Learning rate
float learning_rate = 1 [(pgml.default).float = 1e-4];
// Decay rate
float beta1 = 2 [(pgml.default).float = 0.99];
}
The project
We define a Project message that contains the model, the optimizer, and the data. In the proto definition, it looks like this:
syntax = "proto3";
package cifar10;
import "py_gen_ml/extensions.proto";
// Global configuration
message Project {
option (pgml.cli) = {
enable: true
arg: {name: "conv_activation", path: "net.conv_net.block.activation"}
arg: {name: "head_activation", path: "net.head.block.activation"}
arg: {name: "num_conv_layers", path: "net.conv_net.num_layers"}
arg: {name: "num_mlp_layers", path: "net.head.num_layers"}
};
// Model configuration
Model net = 1;
// Optimizer configuration
Optimizer optimizer = 2;
// Data configuration
Data data = 3;
}
The entrypoint
To launch the training, we create a function that can do a range of things:
- Load a yaml config file
- Load multiple yaml config files and merge them
- Load a yaml sweep file to generate patches and apply them to the config using Optuna
- Apply CLI arguments to the config
- Whatever the input is, train the model and return the accuracy
Even though this sounds like a lot, with py-gen-ml it is actually quite easy to do. The main function now becomes:
Let's break this down a bit more.
- At line 7, we load the project config from one or more yaml files. This is where files will be merged before they are passed to the Pydantic model validator.
- At line 8, we apply the CLI arguments to the config.
- At line 10, we check if there are any sweep files. If there are, we load them at line 14. If there are no sweep files, we simply train the model and return the accuracy.
- At lines 16-20, we define an objective function for Optuna to use. This is the function that will be called to train the model and return the accuracy. Note that we are using the exact same train_model function that we used at line 10 earlier.
- At lines 22 and 23, we create an Optuna study and optimize the objective function.
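The Optuna-specific part of this breakdown boils down to a sketch like the following, where patched_project stands in for the base config with one sampled sweep patch applied, and the study direction and number of trials are assumptions rather than the values used in the actual example:

def objective(trial: optuna.Trial) -> float:
    # py-gen-ml samples a patch from the sweep file for this trial and
    # applies it to the base project config; the patched config is passed
    # to the very same train_model function used for a plain run
    patched_project = ...  # base config with the sampled sweep values applied
    return train_model(patched_project, trial=trial)

study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=20)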
The configuration
The configuration is defined in a yaml file. This is the file that we will use to train the model. It looks like this:
# yaml-language-server: $schema=schemas/project.json
net:
conv_net:
block:
activation: GELU
kernel_size: 3
out_channels: 32
pool_size: 2
head:
block:
activation: GELU
out_features: 128
num_layers: 3
optimizer:
beta1: 0.99
learning_rate: 0.0001
data:
batch_size: 32
num_epochs: 10
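Because multiple config files can be loaded and merged, overrides can live in small, focused yaml files. A hypothetical override file (the exact merge semantics are up to py-gen-ml; here we assume later files take precedence over earlier ones) could look like this:

# yaml-language-server: $schema=schemas/project.json
data:
  batch_size: 64
optimizer:
  learning_rate: 0.001

Passing this file to --config-paths together with the base config would then change only these two values.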
Launching the training
To launch the training, we can use the following command:
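python examples/cifar10/src/cifar10/train.py \
    --config-paths \
    configs/base/default.yaml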
Showing CLI arguments
To show the CLI arguments, we can use the following command:
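python examples/cifar10/src/cifar10/train.py --help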
This will show the following:
Usage: train.py [OPTIONS]
╭─ Options ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
│ * --config-paths TEXT Path to config file [default: None] [required] │
│ --sweep-paths TEXT Type of config to use [default: <class 'list'>] │
│ --out-channels INTEGER Number of output channels. Maps to 'net.conv_net.block.out_channels' [default: None] │
│ --kernel-size INTEGER Square kernel size. Maps to 'net.conv_net.block.kernel_size' [default: None] │
│ --pool-size INTEGER Square pool size. Maps to 'net.conv_net.block.pool_size' [default: None] │
│ --out-features INTEGER Number of output features. Maps to 'net.head.block.out_features' [default: None] │
│ --batch-size INTEGER Batch size for a single GPU. Maps to 'data.batch_size' [default: None] │
│ --num-epochs INTEGER Number of epochs to train. Maps to 'data.num_epochs' [default: None] │
│ --learning-rate FLOAT Learning rate. Maps to 'optimizer.learning_rate' [default: None] │
│ --beta1 FLOAT Decay rate. Maps to 'optimizer.beta1' [default: None] │
│ --conv-activation [gelu|relu] Activation function. Maps to 'net.conv_net.block.activation' [default: None] │
│ --head-activation [gelu|relu] Activation function. Maps to 'net.head.block.activation' [default: None] │
│ --num-conv-layers INTEGER Number of layers. Maps to 'net.conv_net.num_layers' [default: None] │
│ --num-mlp-layers INTEGER Number of layers. Maps to 'net.head.num_layers' [default: None] │
│ --install-completion Install completion for the current shell. │
│ --show-completion Show completion for the current shell, to copy it or customize the installation. │
│ --help Show this message and exit. │
╰────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
Starting a sweep
We define the following sweep configuration:
# yaml-language-server: $schema=schemas/project.json
optimizer:
beta1:
low: 0.9
high: 0.99
learning_rate:
log_low: 1e-5
log_high: 1e-3
This is a sweep over the learning rate and beta1 parameters.
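Roughly speaking, the low/high bounds correspond to a uniform Optuna suggestion and the log_low/log_high bounds to a log-uniform one. A hand-written equivalent of what gets sampled for each trial would look something like this (the exact parameter names that py-gen-ml registers with Optuna may differ):

beta1 = trial.suggest_float('optimizer.beta1', 0.9, 0.99)
learning_rate = trial.suggest_float('optimizer.learning_rate', 1e-5, 1e-3, log=True)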
To run the sweep, we can use the following command:
python examples/cifar10/src/cifar10/train.py \
--config-paths \
configs/base/default.yaml \
--sweep-paths \
configs/sweep/lr_beta1.yaml
Remaining code
Modules
We have defined the modules here:
from typing import Self
import pgml_out.config_base as base
import torch
CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
def create_activation(activation: base.Activation) -> torch.nn.Module:
if activation == base.Activation.GELU:
return torch.nn.GELU()
elif activation == base.Activation.RELU:
return torch.nn.ReLU()
else:
raise ValueError(f'Invalid activation function: {activation}')
class LinearBlock(torch.nn.Module):
"""Linear layer with activation configuration"""
def __init__(self, out_features: int, activation: base.Activation) -> None:
super().__init__()
self.linear = torch.nn.LazyLinear(out_features=out_features)
self.activation = create_activation(activation)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.activation(self.linear(x))
class ConvBlock(torch.nn.Module):
"""Convolutional layer with activation configuration"""
def __init__(self, out_channels: int, kernel_size: int, pool_size: int, activation: base.Activation) -> None:
super().__init__()
self.conv = torch.nn.LazyConv2d(out_channels=out_channels, kernel_size=kernel_size)
self.pool = torch.nn.MaxPool2d(kernel_size=pool_size)
self.activation = create_activation(activation)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.activation(self.pool(self.conv(x)))
class ConvNet(torch.nn.Module):
def __init__(self, block: base.ConvBlock, num_layers: int) -> None:
super().__init__()
self.layers = [block.build() for _ in range(num_layers)]
self.net = torch.nn.Sequential(*self.layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
class MLP(torch.nn.Module):
def __init__(self, block: base.LinearBlock, num_layers: int) -> None:
super().__init__()
self.layers = [block.build() for _ in range(num_layers)]
self.net = torch.nn.Sequential(*self.layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
class Model(torch.nn.Module):
def __init__(self, conv_net: ConvNet, head: MLP) -> None:
super().__init__()
self.conv_net = conv_net
self.head = head
self.class_logits = torch.nn.LazyLinear(out_features=len(CLASSES),)
self.flatten = torch.nn.Flatten()
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv_net(x)
x = self.head(x)
x = self.flatten(x)
x = self.class_logits(x)
return x
@classmethod
def from_config(cls, config: base.Model) -> Self:
conv_net = ConvNet(
block=config.conv_net.block,
num_layers=config.conv_net.num_layers,
)
head = MLP(
block=config.head.block,
num_layers=config.head.num_layers,
)
return cls(conv_net, head)
Data
We have defined the data module here:
import os
from typing import Any
import torch
import torchvision
import torchvision.transforms as transforms
def get_transform() -> transforms.Compose:
return transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
],)
def get_data_loader(transform: transforms.Compose, batch_size: int, train: bool) -> torch.utils.data.DataLoader[Any]:
trainset = torchvision.datasets.CIFAR10(
root=f"{os.environ['HOME']}/data/torchvision",
train=train,
download=train,
transform=transform,
)
return torch.utils.data.DataLoader[Any](
trainset,
batch_size=batch_size,
shuffle=True,
num_workers=4,
prefetch_factor=4,
)
Trainer
We have defined the trainer here:
class Trainer:
def __init__(
self,
model: torch.nn.Module,
train_loader: torch.utils.data.DataLoader[Any],
test_loader: torch.utils.data.DataLoader[Any],
criterion: torch.nn.Module,
optimizer: torch.optim.Optimizer,
accuracy_metric_train: torchmetrics.classification.MulticlassAccuracy,
accuracy_metric_test: torchmetrics.classification.MulticlassAccuracy,
train_loss_metric: torchmetrics.MeanMetric,
test_loss_metric: torchmetrics.MeanMetric,
trial: Optional[optuna.Trial] = None,
) -> None:
self._train_loader = train_loader
self._test_loader = test_loader
self._model = model
self._criterion = criterion
self._optimizer = optimizer
self._accuracy_metric_train = accuracy_metric_train
self._accuracy_metric_test = accuracy_metric_test
self._train_loss_metric = train_loss_metric
self._test_loss_metric = test_loss_metric
self._trial = trial
def train(self, num_epochs: int) -> float:
device = get_device()
step = 0
for epoch in tqdm.trange(num_epochs, position=0, desc='Epoch'):
batch_bar = tqdm.tqdm(self._train_loader, position=1, desc='Batch')
for inputs, labels in batch_bar:
inputs = inputs.to(device)
labels = labels.to(device)
self._optimizer.zero_grad()
outputs = self._model(inputs)
loss = self._criterion(outputs, labels)
loss.backward()
self._optimizer.step()
self._train_loss_metric.update(loss)
self._accuracy_metric_train.update(outputs, labels)
step += 1
if step % 10 == 0:
batch_bar.set_postfix(
loss=self._train_loss_metric.compute().item(),
accuracy=self._accuracy_metric_train.compute().item(),
)
self._train_loss_metric.reset()
self._accuracy_metric_train.reset()
self._evaluate()
if self._trial is not None:
self._trial.report(self._accuracy_metric_test.compute().item(), epoch)
return self._accuracy_metric_test.compute().item()
@torch.inference_mode()
def _evaluate(self) -> None:
self._model.eval()
self._accuracy_metric_test.reset()
self._test_loss_metric.reset()
for images, labels in tqdm.tqdm(self._test_loader, position=1, desc='Evaluating'):
images = images.to(get_device())
labels = labels.to(get_device())
outputs = self._model(images)
loss = self._criterion(outputs, labels)
self._test_loss_metric.update(loss)
self._accuracy_metric_test.update(outputs, labels)
print(f'Test accuracy: {self._accuracy_metric_test.compute().item()}')
print(f'Test loss: {self._test_loss_metric.compute().item()}')
self._model.train()
Train function
The train function that instantiates all the components and calls the trainer is defined here:
def train_model(project: base.Project, trial: typing.Optional[optuna.Trial] = None) -> float:
rich.print(project)
transform = get_transform()
train_loader = get_data_loader(transform=transform, batch_size=project.data.batch_size, train=True)
test_loader = get_data_loader(transform=transform, batch_size=project.data.batch_size, train=False)
device = get_device()
print(f'device {device}')
model = Model.from_config(config=project.net).to(device)
path = pathlib.Path(f"{os.environ['HOME']}/gen_ml/logs/{uuid.uuid4()}")
print(f'Storing logs at {path}')
criterion = nn.CrossEntropyLoss()
# Use the configured beta1 as Adam's first-moment decay rate
optimizer = optim.Adam(params=model.parameters(), lr=project.optimizer.learning_rate, betas=(project.optimizer.beta1, 0.999))
accuracy_metric_train = get_accuracy_metric(num_classes=len(CLASSES))
accuracy_metric_test = get_accuracy_metric(num_classes=len(CLASSES))
train_loss_metric = torchmetrics.MeanMetric().to(device)
test_loss_metric = torchmetrics.MeanMetric().to(device)
trainer = Trainer(
model=model,
train_loader=train_loader,
test_loader=test_loader,
criterion=criterion,
optimizer=optimizer,
accuracy_metric_train=accuracy_metric_train,
accuracy_metric_test=accuracy_metric_test,
train_loss_metric=train_loss_metric,
test_loss_metric=test_loss_metric,
trial=trial,
)
accuracy = trainer.train(num_epochs=project.data.num_epochs)
return accuracy