🏭 Factories

🪄 Generated Factory Methods

Occasionally, you have enough information to instantiate a class directly from a configuration object. py-gen-ml can generate factory methods for such cases: a factory method unpacks the message fields into keyword arguments and instantiates the target object.
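The pattern itself is simple. The following minimal sketch uses plain dataclasses as stand-ins for both the generated pydantic model and the target class (all names here are illustrative, not part of py-gen-ml):

```python
from dataclasses import dataclass, asdict


# Stand-in for the target class (e.g. torch.nn.Linear).
@dataclass
class LinearLayer:
    in_features: int
    out_features: int
    bias: bool


# Stand-in for a generated config model with a factory method.
@dataclass
class LinearConfig:
    in_features: int
    out_features: int
    bias: bool

    def build(self) -> LinearLayer:
        # Unpack the config fields into keyword arguments.
        return LinearLayer(**asdict(self))


layer = LinearConfig(in_features=4, out_features=8, bias=True).build()
```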

To specify a factory method for a message, you can use the (pgml.factory) option. For example:

// builder_demo.proto
syntax = "proto3";

package builder_demo;

// Import the PGML extensions
import "py_gen_ml/extensions.proto";


// Linear layer configuration
message Linear {
    option (pgml.factory) = "torch.nn.Linear";
    // Number of input features
    uint32 in_features = 1;
    // Number of output features
    uint32 out_features = 2;
    // Bias
    bool bias = 3;
}

// MLP configuration
message MLP {
    // Linear layers
    repeated Linear layers = 1;
}

The generated code will look like this:

# Autogenerated code. DO NOT EDIT.
import typing
import py_gen_ml as pgml

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import torch.nn


class Linear(pgml.YamlBaseModel):
    """Linear layer configuration"""

    in_features: int
    """Number of input features"""

    out_features: int
    """Number of output features"""

    bias: bool
    """Bias"""

    def build(self) -> "torch.nn.Linear":
        import torch.nn

        return torch.nn.Linear(
            in_features=self.in_features,
            out_features=self.out_features,
            bias=self.bias,
        )


class MLP(pgml.YamlBaseModel):
    """MLP configuration"""

    layers: typing.List[Linear]
    """Linear layers"""

Notice the build method. This method is automatically generated for you. It unpacks the message fields into keyword arguments and then instantiates the class.

In your experiment code, you can now call the build method to instantiate the class:

import torch

from pgml_out.builder_demo_base import MLP

if __name__ == "__main__":
    mlp_config = MLP.from_yaml("configs/base/mlp.yaml")

    layers = [layer.build() for layer in mlp_config.layers]
    mlp = torch.nn.Sequential(*layers)
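The contents of configs/base/mlp.yaml are not shown above, but a config matching the Linear and MLP schema could look like this (field values are illustrative):

```yaml
# configs/base/mlp.yaml (illustrative)
layers:
  - in_features: 784
    out_features: 256
    bias: true
  - in_features: 256
    out_features: 10
    bias: true
```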

🧱 Using custom classes

The factory option works with any importable class, not just PyTorch classes.

For example, suppose you have a custom module that you want to instantiate from a config:

# src/example_project/modules.py
import torch.nn


class LinearBlock(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True, dropout: float = 0.0, activation: str = "relu"):
        super().__init__()
        self.linear = torch.nn.Linear(in_features, out_features, bias=bias)
        self.dropout = torch.nn.Dropout(dropout)
        self.activation = torch.nn.ReLU() if activation == "relu" else torch.nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.dropout(self.activation(self.linear(x)))

And then define the following proto:

// builder_custom_class_demo.proto
syntax = "proto3";

package builder_custom_class_demo;

import "py_gen_ml/extensions.proto";


// Linear block configuration
message LinearBlock {
    option (pgml.factory) = "example_project.modules.LinearBlock";
    // Number of input features
    uint32 in_features = 1;
    // Number of output features
    uint32 out_features = 2;
    // Bias
    bool bias = 3;
    // Dropout probability
    float dropout = 4;
    // Activation function
    string activation = 5;
}

// MLP configuration
message MLP {
    // Linear blocks
    repeated LinearBlock layers = 1;
}

The generated code will look like this:

# Autogenerated code. DO NOT EDIT.
import typing
import py_gen_ml as pgml

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import example_project.modules


class LinearBlock(pgml.YamlBaseModel):
    """Linear block configuration"""

    in_features: int
    """Number of input features"""

    out_features: int
    """Number of output features"""

    bias: bool
    """Bias"""

    dropout: float
    """Dropout probability"""

    activation: str
    """Activation function"""

    def build(self) -> "example_project.modules.LinearBlock":
        import example_project.modules

        return example_project.modules.LinearBlock(
            in_features=self.in_features,
            out_features=self.out_features,
            bias=self.bias,
            dropout=self.dropout,
            activation=self.activation,
        )


class MLP(pgml.YamlBaseModel):
    """MLP configuration"""

    layers: typing.List[LinearBlock]
    """Linear blocks"""

💥 Expanding fields as varargs

You can also expand a repeated field as variadic positional arguments. This is useful when the target class takes its arguments positionally, as torch.nn.Sequential does. Set the (pgml.as_varargs) field option to expand the field as varargs. For example:

// builder_varargs_demo.proto
syntax = "proto3";

package builder_varargs_demo;

import "py_gen_ml/extensions.proto";


// Linear layer configuration
message Linear {
    option (pgml.factory) = "torch.nn.Linear";
    // Number of input features
    uint32 in_features = 1;
    // Number of output features
    uint32 out_features = 2;
    // Bias
    bool bias = 3;
}

// MLP configuration
message MLP {
    option (pgml.factory) = "torch.nn.Sequential";
    // Linear layers
    repeated Linear layers = 1 [(pgml.as_varargs) = true];
}

The generated code will look like this:

# Autogenerated code. DO NOT EDIT.
import typing
import py_gen_ml as pgml

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import torch.nn


class Linear(pgml.YamlBaseModel):
    """Linear layer configuration"""

    in_features: int
    """Number of input features"""

    out_features: int
    """Number of output features"""

    bias: bool
    """Bias"""

    def build(self) -> "torch.nn.Linear":
        import torch.nn

        return torch.nn.Linear(
            in_features=self.in_features,
            out_features=self.out_features,
            bias=self.bias,
        )


class MLP(pgml.YamlBaseModel):
    """MLP configuration"""

    layers: typing.List[Linear]
    """Linear layers"""

    def build(self) -> "torch.nn.Sequential":
        import torch.nn

        return torch.nn.Sequential(
            *(elem.build() for elem in self.layers),
        )
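The expansion itself is easy to illustrate without torch. The sketch below mimics the pattern with a stand-in Sequential class that, like torch.nn.Sequential, accepts its modules as positional arguments (all names are illustrative, not generated by py-gen-ml):

```python
from dataclasses import dataclass
from typing import List, Tuple


class Sequential:
    """Stand-in for torch.nn.Sequential: accepts modules positionally."""

    def __init__(self, *modules):
        self.modules = list(modules)


@dataclass
class LinearConfig:
    in_features: int
    out_features: int

    def build(self) -> Tuple[int, int]:
        # Stand-in build: return a tuple instead of a real layer.
        return (self.in_features, self.out_features)


@dataclass
class MLPConfig:
    layers: List[LinearConfig]

    def build(self) -> Sequential:
        # The varargs option expands the repeated field as positional args.
        return Sequential(*(layer.build() for layer in self.layers))


mlp = MLPConfig(layers=[LinearConfig(4, 8), LinearConfig(8, 2)]).build()
```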

🐣 Nesting factories

As you may have noticed, factories can be nested: in the varargs example, the build method of MLP expands a list of Linear objects that are themselves instantiated by a factory. Nesting factories can streamline the construction of complex objects, but it also couples your schema more tightly to the objects being created.

As a rule of thumb, use factories for objects whose fields don't need factories of their own; in other words, nest factories sparingly.