# 🏭 Factories

## 🪄 Generated Factory Methods
Occasionally, you have enough information to instantiate a class directly from a configuration object. For such cases, py-gen-ml can generate factory methods. A factory method unpacks the message fields into keyword arguments and then instantiates an object.

To specify a factory method for a message, use the `(pgml.factory)` option. For example:
```proto
// builder_demo.proto
syntax = "proto3";

package builder_demo;

// Import the PGML extensions
import "py_gen_ml/extensions.proto";

// Linear layer configuration
message Linear {
    option (pgml.factory) = "torch.nn.Linear";

    // Number of input features
    uint32 in_features = 1;
    // Number of output features
    uint32 out_features = 2;
    // Bias
    bool bias = 3;
}

// MLP configuration
message MLP {
    // Linear layers
    repeated Linear layers = 1;
}
```
The generated code will look like this:
```python
# Autogenerated code. DO NOT EDIT.
import typing

import py_gen_ml as pgml

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import torch.nn


class Linear(pgml.YamlBaseModel):
    """Linear layer configuration"""

    in_features: int
    """Number of input features"""
    out_features: int
    """Number of output features"""
    bias: bool
    """Bias"""

    def build(self) -> "torch.nn.Linear":
        import torch.nn
        return torch.nn.Linear(
            in_features=self.in_features,
            out_features=self.out_features,
            bias=self.bias,
        )


class MLP(pgml.YamlBaseModel):
    """MLP configuration"""

    layers: typing.List[Linear]
    """Linear layers"""
```
Notice the `build` method. It is generated automatically: it unpacks the message fields into keyword arguments and instantiates the target class.
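Conceptually, the generated `build` is just keyword unpacking. The following framework-free sketch illustrates the mechanism; `FakeLinear` is a stand-in for `torch.nn.Linear` so the sketch runs without PyTorch, and is not part of py-gen-ml:

```python
# Minimal sketch of what a generated build() does: unpack config fields into
# keyword arguments and call the factory target.
from dataclasses import dataclass, asdict


class FakeLinear:
    """Stand-in for torch.nn.Linear, so this sketch runs without PyTorch."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias


@dataclass
class LinearConfig:
    in_features: int
    out_features: int
    bias: bool

    def build(self) -> FakeLinear:
        # Fields become keyword arguments of the target class.
        return FakeLinear(**asdict(self))


layer = LinearConfig(in_features=4, out_features=8, bias=False).build()
# layer.in_features == 4, layer.bias is False
```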
In your experiment code, you can now call the `build` method to instantiate the class:
```python
import torch

from pgml_out.builder_demo_base import MLP

if __name__ == "__main__":
    mlp_config = MLP.from_yaml("configs/base/mlp.yaml")
    layers = [layer.build() for layer in mlp_config.layers]
    mlp = torch.nn.Sequential(*layers)
```
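The YAML file loaded above is not shown in this example; a config matching the `MLP` schema could look like the following (the field values here are hypothetical):

```yaml
# configs/base/mlp.yaml (hypothetical values)
layers:
  - in_features: 784
    out_features: 256
    bias: true
  - in_features: 256
    out_features: 10
    bias: true
```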
## 🧱 Using custom classes

The factory extension works with any class, not just PyTorch classes: you can use it to instantiate any class that your code can import. For example, suppose you have a custom class that you want to instantiate:
```python
# src/example_project/modules.py
import torch.nn


class LinearBlock(torch.nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        dropout: float = 0.0,
        activation: str = "relu",
    ):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features, bias=bias)
        self.dropout = torch.nn.Dropout(dropout)
        self.activation = torch.nn.ReLU() if activation == "relu" else torch.nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.activation(self.linear(x)))
```
And then define the following proto:
```proto
// builder_custom_class_demo.proto
syntax = "proto3";

package builder_custom_class_demo;

import "py_gen_ml/extensions.proto";

// Linear block configuration
message LinearBlock {
    option (pgml.factory) = "example_project.modules.LinearBlock";

    // Number of input features
    uint32 in_features = 1;
    // Number of output features
    uint32 out_features = 2;
    // Bias
    bool bias = 3;
    // Dropout probability
    float dropout = 4;
    // Activation function
    string activation = 5;
}

// MLP configuration
message MLP {
    // Linear blocks
    repeated LinearBlock layers = 1;
}
```
The generated code will look like this:
```python
# Autogenerated code. DO NOT EDIT.
import typing

import py_gen_ml as pgml

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import example_project.modules


class LinearBlock(pgml.YamlBaseModel):
    """Linear block configuration"""

    in_features: int
    """Number of input features"""
    out_features: int
    """Number of output features"""
    bias: bool
    """Bias"""
    dropout: float
    """Dropout probability"""
    activation: str
    """Activation function"""

    def build(self) -> "example_project.modules.LinearBlock":
        import example_project.modules
        return example_project.modules.LinearBlock(
            in_features=self.in_features,
            out_features=self.out_features,
            bias=self.bias,
            dropout=self.dropout,
            activation=self.activation,
        )


class MLP(pgml.YamlBaseModel):
    """MLP configuration"""

    layers: typing.List[LinearBlock]
    """Linear blocks"""
```
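As before, a YAML config for this schema might look like the following (the field values here are hypothetical):

```yaml
# configs/base/mlp.yaml (hypothetical values)
layers:
  - in_features: 784
    out_features: 256
    bias: true
    dropout: 0.1
    activation: relu
  - in_features: 256
    out_features: 10
    bias: true
    dropout: 0.0
    activation: gelu
```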
## 💥 Expanding fields as varargs

You can also expand a repeated field as variadic positional arguments. This is useful when the factory target takes its arguments as varargs rather than as a single list, as `torch.nn.Sequential` does. Use the `(pgml.as_varargs)` option on the field to enable this. For example:
```proto
// builder_varargs_demo.proto
syntax = "proto3";

package builder_varargs_demo;

import "py_gen_ml/extensions.proto";

// Linear layer configuration
message Linear {
    option (pgml.factory) = "torch.nn.Linear";

    // Number of input features
    uint32 in_features = 1;
    // Number of output features
    uint32 out_features = 2;
    // Bias
    bool bias = 3;
}

// MLP configuration
message MLP {
    option (pgml.factory) = "torch.nn.Sequential";

    // Linear layers
    repeated Linear layers = 1 [(pgml.as_varargs) = true];
}
```
The generated code will look like this:
```python
# Autogenerated code. DO NOT EDIT.
import typing

import py_gen_ml as pgml

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import torch.nn


class Linear(pgml.YamlBaseModel):
    """Linear layer configuration"""

    in_features: int
    """Number of input features"""
    out_features: int
    """Number of output features"""
    bias: bool
    """Bias"""

    def build(self) -> "torch.nn.Linear":
        import torch.nn
        return torch.nn.Linear(
            in_features=self.in_features,
            out_features=self.out_features,
            bias=self.bias,
        )


class MLP(pgml.YamlBaseModel):
    """MLP configuration"""

    layers: typing.List[Linear]
    """Linear layers"""

    def build(self) -> "torch.nn.Sequential":
        import torch.nn
        return torch.nn.Sequential(
            *(elem.build() for elem in self.layers),
        )
```
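The varargs expansion itself is plain `*` splatting. The following framework-free sketch illustrates it; `FakeSequential` and the tuple returned by `LinearConfig.build` are stand-ins for the `torch.nn` classes, so the sketch runs without PyTorch:

```python
# Sketch of (pgml.as_varargs): each element of the repeated field is built,
# then the results are passed as positional varargs to the factory target.
from dataclasses import dataclass
from typing import List, Tuple


class FakeSequential:
    """Stand-in for torch.nn.Sequential, which takes modules as varargs."""

    def __init__(self, *modules):
        self.modules = list(modules)


@dataclass
class LinearConfig:
    in_features: int
    out_features: int

    def build(self) -> Tuple[int, int]:
        # Stand-in for building a real layer.
        return (self.in_features, self.out_features)


@dataclass
class MLPConfig:
    layers: List[LinearConfig]

    def build(self) -> FakeSequential:
        # Build each element, then expand the results as varargs.
        return FakeSequential(*(layer.build() for layer in self.layers))


mlp = MLPConfig(layers=[LinearConfig(4, 8), LinearConfig(8, 2)]).build()
# mlp.modules == [(4, 8), (8, 2)]
```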
## 🐣 Nesting factories

As you may have noticed, factories can be nested. In the varargs section, the `build` method on `MLP` takes varargs of `Linear` objects that are themselves instantiated with a factory. Nesting factories can streamline the instantiation of complex objects, but it also couples your schema more tightly to the objects being created. It is usually best to use factories for objects whose fields don't require factories of their own; in other words, nest factories sparingly.