Skip to content

🖥️ CLI Argument Parsing

✨ Implicit Argument References

py-gen-ml generates a smart CLI argument parser using Pydantic base models. It shortens CLI argument names for deeply nested fields in your config when there's exactly one path to a field and the field name is unique.

Example protobuf structure:

// cli_demo.proto
syntax = "proto3";

package cli_demo;

import "py_gen_ml/extensions.proto";

// Global configuration
message CLIDemo {
    // Enables generation of a CLI args class for this message.
    option (pgml.cli).enable = true;
    // Data configuration
    Data data = 1;
    // Model configuration
    Model model = 2;
    // Training configuration
    Training training = 3;
}

// Dataset configuration
message Dataset {
    // Path to the dataset
    string path = 1;
}

// Data configuration
message Data {
    // Dataset configuration
    Dataset dataset = 1;
    // Number of workers for loading the dataset
    uint32 num_workers = 2;
}

// Model configuration
message Model {
    // Number of layers
    uint32 num_layers = 1;
}

// Training configuration
message Training {
    // Number of epochs
    uint32 num_epochs = 1;
}

This generates a CLI args class:

# Autogenerated code. DO NOT EDIT.
import py_gen_ml as pgml
import typing

import pydantic
import typer

from . import cli_demo_base as base


class CLIDemoArgs(pgml.YamlBaseModel):
    """Global configuration"""

    # Each field is an optional override for one nested config value:
    # ``typer.Option`` supplies the --help text, ``pydantic.Field(None)`` makes
    # the value default to ``None``, and ``pgml.ArgRef`` holds the dotted path
    # inside ``CLIDemo`` that the argument maps to. The field names are the bare
    # leaf names because each of them occurs only once in ``CLIDemo``.

    path: typing.Annotated[
        typing.Optional[str],
        typer.Option(help="Path to the dataset. Maps to 'data.dataset.path'"),
        pydantic.Field(None),
        pgml.ArgRef("data.dataset.path"),
    ]
    """Path to the dataset"""

    num_layers: typing.Annotated[
        typing.Optional[int],
        typer.Option(help="Number of layers. Maps to 'model.num_layers'"),
        pydantic.Field(None),
        pgml.ArgRef("model.num_layers"),
    ]
    """Number of layers"""

    num_workers: typing.Annotated[
        typing.Optional[int],
        typer.Option(
            help="Number of workers for loading the dataset. Maps to 'data.num_workers'"
        ),
        pydantic.Field(None),
        pgml.ArgRef("data.num_workers"),
    ]
    """Number of workers for loading the dataset"""

    num_epochs: typing.Annotated[
        typing.Optional[int],
        typer.Option(help="Number of epochs. Maps to 'training.num_epochs'"),
        pydantic.Field(None),
        pgml.ArgRef("training.num_epochs"),
    ]
    """Number of epochs"""

🚪 Generated Entrypoint

It also generates a skeleton entrypoint:

import pgml_out.cli_demo_base as base
import pgml_out.cli_demo_sweep as sweep
import pgml_out.cli_demo_cli_args as cli_args
import typer
import py_gen_ml as pgml
import optuna
import typing

# Typer application that the ``main`` command is registered on.
app = typer.Typer(pretty_exceptions_enable=False)


def run_trial(
    cli_demo: base.CLIDemo,
    trial: typing.Optional[optuna.Trial] = None,
) -> typing.Union[float, typing.Sequence[float]]:
    """Evaluate one ``cli_demo`` configuration and return its objective value(s).

    ``trial`` is ``None`` for a plain run. During a sweep it is the current
    Optuna trial, and the hyperparameters sampled for it are already present
    in ``cli_demo``.

    This is a placeholder: replace the body with your own model and training code.
    """
    # TODO: Implement this function
    return 0.0

@pgml.pgml_cmd(app=app)
def main(
    config_paths: typing.List[str] = typer.Option(..., help="Paths to config files"),
    sweep_paths: typing.List[str] = typer.Option(
        default_factory=list,
        help="Paths to sweep files"
    ),
    cli_args: cli_args.CLIDemoArgs = typer.Option(...),
) -> None:
    # No docstring on purpose: Typer uses a command's docstring as its --help
    # description, which would change the CLI output.
    #
    # Inside this body ``cli_args`` is the parsed-arguments parameter, which
    # shadows the ``cli_args`` module. The module is only used in the annotation
    # above, and that annotation is evaluated when ``main`` is defined.

    # Load the base config from YAML, then overlay any CLI overrides.
    cli_demo = base.CLIDemo.from_yaml_files(config_paths)
    cli_demo = cli_demo.apply_cli_args(cli_args)
    if not sweep_paths:
        # No sweep files: run a single trial with the (possibly overridden) config.
        run_trial(cli_demo)
        return
    cli_demo_sweep = sweep.CLIDemoSweep.from_yaml_files(sweep_paths)

    def objective(trial: optuna.Trial) -> typing.Union[
        float,
        typing.Sequence[float]
    ]:
        # Sample a patch from the sweep space and merge it onto the base config.
        optuna_sampler = pgml.OptunaSampler(trial)
        cli_demo_patch = optuna_sampler.sample(cli_demo_sweep)
        cli_demo_patched = cli_demo.merge(cli_demo_patch)
        return run_trial(cli_demo_patched, trial)

    # NOTE(review): a single ``direction`` creates a single-objective study. If
    # run_trial returns a sequence of floats, pass ``directions=[...]`` instead.
    study = optuna.create_study(direction='maximize')
    study.optimize(objective, n_trials=100)


if __name__ == "__main__":
    app()

It is a standard Typer app, so you can run it like a normal Python script:

python src/pgml_out/cli_demo_entrypoint.py --help

This should print something like:

 Usage: cli_demo_entrypoint.py [OPTIONS]                                    

╭─ Options ────────────────────────────────────────────────────────────────╮
│ *  --config-paths              TEXT     Paths to config files            │
│                                         [default: None]                  │
│                                         [required]                       │
│    --sweep-paths               TEXT     Paths to sweep files             │
│                                         [default: <class 'list'>]        │
│    --num-epochs                INTEGER  Number of epochs. Maps to        │
│                                         'training.num_epochs'            │
│                                         [default: None]                  │
│    --path                      TEXT     Path to the dataset. Maps to     │
│                                         'data.dataset.path'              │
│                                         [default: None]                  │
│    --num-layers                INTEGER  Number of layers. Maps to        │
│                                         'model.num_layers'               │
│                                         [default: None]                  │
│    --num-workers               INTEGER  Number of workers for loading    │
│                                         the dataset. Maps to             │
│                                         'data.num_workers'               │
│                                         [default: None]                  │
│    --install-completion                 Install completion for the       │
│                                         current shell.                   │
│    --show-completion                    Show completion for the current  │
│                                         shell, to copy it or customize   │
│                                         the installation.                │
│    --help                               Show this message and exit.      │
╰──────────────────────────────────────────────────────────────────────────╯

Notice how the argument names are just the names of the fields in the innermost message of the nested structure. Because each of these field names occurs only once in the whole CLIDemo structure, the short name is enough to identify the intended field.

💡 Workflow

We recommend copying the generated entrypoint and modifying it to fit your needs.

For example, you might write a run_trial function that interfaces with your model and training code.

⏩ Shortening CLI arguments

As stated before, CLI argument names are shortened for deeply nested fields in your config when there's exactly one path to a field and the field name is unique. If a field name is not unique, the names of its enclosing fields are prepended, one level at a time, until the resulting argument name is unique.

Take for example the following protobuf file:

// cli_demo_deep.proto
syntax = "proto3";

package cli_demo_deep;

import "py_gen_ml/extensions.proto";


// Global configuration
message CliDemoDeep {
    // Enables generation of a CLI args class for this message.
    option (pgml.cli).enable = true;
    // Data configuration
    Data data = 1;
    // Model configuration
    Model model = 2;
    // Training configuration
    Training training = 3;
}

// Dataset configuration
message Dataset {
    // Path to the dataset
    string path = 1;
}

// Data configuration
message Data {
    // Training dataset configuration
    Dataset train_dataset = 1;
    // Test dataset configuration
    Dataset test_dataset = 2;
    // Number of workers for loading the dataset
    uint32 num_workers = 3;
}

// Model configuration
message Model {
    // Number of layers
    uint32 num_layers = 1;
}

// Training configuration
message Training {
    // Number of epochs
    uint32 num_epochs = 1;
}

This generates the following CLI arguments:

# Autogenerated code. DO NOT EDIT.
import py_gen_ml as pgml
import typing

import pydantic
import typer

from . import cli_demo_deep_base as base


class CliDemoDeepArgs(pgml.YamlBaseModel):
    """Global configuration"""

    # ``path`` occurs under both ``data.train_dataset`` and ``data.test_dataset``,
    # so those two arguments are prefixed with their enclosing field name
    # (``train_dataset_path``, ``test_dataset_path``). The other fields are
    # unique and keep their bare names. Each field defaults to ``None``, and its
    # ``pgml.ArgRef`` holds the full dotted path the argument maps to.

    train_dataset_path: typing.Annotated[
        typing.Optional[str],
        typer.Option(help="Path to the dataset. Maps to 'data.train_dataset.path'"),
        pydantic.Field(None),
        pgml.ArgRef("data.train_dataset.path"),
    ]
    """Path to the dataset"""

    test_dataset_path: typing.Annotated[
        typing.Optional[str],
        typer.Option(help="Path to the dataset. Maps to 'data.test_dataset.path'"),
        pydantic.Field(None),
        pgml.ArgRef("data.test_dataset.path"),
    ]
    """Path to the dataset"""

    num_workers: typing.Annotated[
        typing.Optional[int],
        typer.Option(
            help="Number of workers for loading the dataset. Maps to 'data.num_workers'"
        ),
        pydantic.Field(None),
        pgml.ArgRef("data.num_workers"),
    ]
    """Number of workers for loading the dataset"""

    num_epochs: typing.Annotated[
        typing.Optional[int],
        typer.Option(help="Number of epochs. Maps to 'training.num_epochs'"),
        pydantic.Field(None),
        pgml.ArgRef("training.num_epochs"),
    ]
    """Number of epochs"""

    num_layers: typing.Annotated[
        typing.Optional[int],
        typer.Option(help="Number of layers. Maps to 'model.num_layers'"),
        pydantic.Field(None),
        pgml.ArgRef("model.num_layers"),
    ]
    """Number of layers"""

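Notice how data.train_dataset.path is shortened to train_dataset_path and data.test_dataset.path is shortened to test_dataset_path: prepending one enclosing field name is enough to tell the two path fields apart. The fields num_workers, num_epochs, and num_layers are already unique, so they keep their bare names.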

🎯 Explicit Argument References

For more control, use explicit argument references in your protobuf:

// cli_extension_demo.proto
syntax = "proto3";

package cli_extension_demo;

import "py_gen_ml/extensions.proto";


// Global configuration
message CliExtensionDemo {
    // Enables CLI generation and pins explicit argument names: each `arg`
    // entry maps a CLI `name` to a dotted field `path`.
    option (pgml.cli) = {
        enable: true
        arg: { name: "train_path", path: "data.train_dataset.path" }
        arg: { name: "test_path", path: "data.test_dataset.path" }
    };
    // Data configuration
    Data data = 1;
    // Model configuration
    Model model = 2;
    // Training configuration
    Training training = 3;
}

// Dataset configuration
message Dataset {
    // Path to the dataset
    string path = 1;
}

// Data configuration
message Data {
    // Training dataset configuration
    Dataset train_dataset = 1;
    // Test dataset configuration
    Dataset test_dataset = 2;
    // Number of workers for loading the dataset
    uint32 num_workers = 3;
}

// Model configuration
message Model {
    // Number of layers
    uint32 num_layers = 1;
}

// Training configuration
message Training {
    // Number of epochs
    uint32 num_epochs = 1;
}

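The explicit argument references replace the implicitly derived names for those fields (train_path and test_path instead of train_dataset_path and test_dataset_path), while the remaining fields keep their implicit names: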

# Autogenerated code. DO NOT EDIT.
import py_gen_ml as pgml
import typing

import pydantic
import typer

from . import cli_extension_demo_base as base


class CliExtensionDemoArgs(pgml.YamlBaseModel):
    """Global configuration"""

    # ``train_path`` and ``test_path`` are the names given by the explicit
    # ``arg`` entries of the ``(pgml.cli)`` option in cli_extension_demo.proto.
    # The remaining fields use the implicitly derived names. Each field defaults
    # to ``None``, and its ``pgml.ArgRef`` holds the full dotted path the argument
    # maps to.

    train_path: typing.Annotated[
        typing.Optional[str],
        typer.Option(help="Path to the dataset. Maps to 'data.train_dataset.path'"),
        pydantic.Field(None),
        pgml.ArgRef("data.train_dataset.path"),
    ]
    """Path to the dataset"""

    test_path: typing.Annotated[
        typing.Optional[str],
        typer.Option(help="Path to the dataset. Maps to 'data.test_dataset.path'"),
        pydantic.Field(None),
        pgml.ArgRef("data.test_dataset.path"),
    ]
    """Path to the dataset"""

    num_epochs: typing.Annotated[
        typing.Optional[int],
        typer.Option(help="Number of epochs. Maps to 'training.num_epochs'"),
        pydantic.Field(None),
        pgml.ArgRef("training.num_epochs"),
    ]
    """Number of epochs"""

    num_workers: typing.Annotated[
        typing.Optional[int],
        typer.Option(
            help="Number of workers for loading the dataset. Maps to 'data.num_workers'"
        ),
        pydantic.Field(None),
        pgml.ArgRef("data.num_workers"),
    ]
    """Number of workers for loading the dataset"""

    num_layers: typing.Annotated[
        typing.Optional[int],
        typer.Option(help="Number of layers. Maps to 'model.num_layers'"),
        pydantic.Field(None),
        pgml.ArgRef("model.num_layers"),
    ]
    """Number of layers"""

📚 Summary

With py-gen-ml, you get powerful, flexible CLI argument parsing that adapts to your needs, whether using implicit shortcuts or explicit references.