1.简介
由于现在很多机器学习的实验需要设置繁琐的参数,在多次实验中,有些参数是一样的,为了方便设置参数,Gin库出现了。它允许函数或类被注释为@gin.configurable,这使得能够使用清晰而强大的语法通过简单的配置文件来设置它们的参数。这种方法减少了配置维护,同时使实验配置透明且易于重复。
简单理解,gin像一个封装了参数配置的类,使用这个类将使得大量的参数配置变得简单清晰
安装
本文来源gaodai$ma#com搞$代*码网2
pip install gin-config
[email protected]
任何函数和类都可以使用@gin.configurable装饰器
@gin.configurable def my_network(images, num_outputs, num_layers=3, weight_decay=1e-4): ...
@gin.configurable装饰器做了如下三件事:
- 把类或函数声明成了可配置的东西
- 它决定了函数或类构造函数的哪些参数是可配置的(默认情况下是其所有的参数)
- 封装类或函数,拦截调用,并向函数的可配置参数提供来自参数设置全局注册表的值(这些值是类或函数声明时没有指定的值)
为了确定哪些是可以配置的参数,@gin.configurable会使用到allowlist和denylist参数,分别声明哪些是可配的哪些是不可配的,我们通常用一个即可,默认没有用allowlist指定的都为不可配,反之亦然。
@gin.configurable('supernet', denylist=['images']) def my_network(images, num_outputs, num_layers=3, weight_decay=1e-4): ...
其中supernet是我们指定的配置名。
3.赋值
我们使用如下两种格式给参数赋值:
gin.bind_parameter('configurable_name.parameter_name', value)
configurable_name.parameter_name = value
具体例子分别如下:
gin.bind_parameter('supernet.num_layers', 5) gin.bind_parameter('supernet.weight_decay', 1e-3)
supernet.num_layers = 5 supernet.weight_decay = 1e-3
4.取值
我们可以用gin.query_parameter
来取值,具体例子如下
num_layers = gin.query_parameter('supernet.num_layers') weight_decay = gin.query_parameter('supernet.weight_decay')
5.配置参考文件
假如我们有以下代码:
@gin.configurable class DNN(object): def __init__(self, num_units=(1024, 1024)): ... def __call__(inputs, num_outputs): ... @gin.configurable(denylist=['data']) def train_model(network_fn, data, learning_rate, optimizer): ...
我们可以在gin文件里配置参数:
train_model.network_fn = @DNN() # An instance of DNN is passed. train_model.optimizer = @MomentumOptimizer # The class itself is passed. train_model.learning_rate = 0.001 DNN.num_units = (2048, 2048, 2048) MomentumOptimizer.momentum = 0.9
上面显示了两种配置参数风格。@DNN()
和@MomentumOptimizer
。对于前者将会调用DNN类的实例参数,且每次参数配置都会随着每个DNN类的实例变动。对于后者将会调用类MomentumOptimizer的默认参数。
6.使用gin文件
我们经常会和absl下flags一起使用gin,比如下面这样
from absl import flags flags.DEFINE_multi_string( 'gin_file', None, 'List of paths to the config files.') flags.DEFINE_multi_string( 'gin_param', None, 'Newline separated list of Gin parameter bindings.') FLAGS = flags.FLAGS