optim 的基本使用
for do:
1. 计算loss
2. 清空梯度
3. 反传梯度
4. 更新参数
optim的完整流程
cifiron = nn.MSELoss() optimiter = torch.optim.SGD(net.parameters(),lr=0.01,momentum=0.9) for i in range(iters): out = net(inputs) loss = cifiron(out,label) optimiter.zero_grad() # 清空之前保留的梯度信息 loss.backward() # 将mini_batch 的loss 信息反传回去 optimiter.step() # 根据 optim参数 和 梯度 更新参数 w.data -= w.grad*lr
网络参数 默认使用统一的 优化器参数
如下设置 网络全局参数 使用统一的优化器参数
optimiter = torch.optim.Adam(net.parameters(),lr=0.01,momentum=0.9)
如下设置将optimizer的可更新参数分为不同的三组,每组使用不同的策略
optimizer = torch.optim.SGD([ {'params': other_params}, {'params': first_params, 'lr': 0.01*args.learning_rate}, {'params': second_params, 'weight_decay': args.weight_decay}], lr=args.learning_rate, momentum=args.momentum, )
我们追溯一下构造Optim的过程
为了更好的看整个过程,去掉了很多 条件判断 语句,如 >0 <0
# 首先是 子类Adam 的构造函数 class Adam(Optimizer): def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False): defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad) ''' 构造了 参数params,可以有两种传入格式,分别对应 1. 全局参数 net.parameters() 2. 不同参数组 [{'params': other_params}, {'params': first_params, 'lr': 0.1*lr}] 和 <全局> 的默认参数字典defaults ''' # 然后调用 父类Optimizer 的构造函数 super(Adam, self).__init__(params, defaults) # 看一下 Optim类的构造函数 只有两个输入 params 和 defaults class Optimizer(object): def __init__(self, params, defaults): torch._C._log_api_usage_once("python.optimizer") self.defaults = defaults self.state = defaultdict(dict) self.param_groups = [] # 自身构造的参数组,每个组使用一套参数 param_groups = list(params) if len(param_groups) == 0: raise ValueError("optimizer got an empty parameter list") # 如果传入的net.parameters(),将其转换为 字典 if not isinstance(param_groups[0], dict): param_groups = [{'params': param_groups}] for param_group in param_groups: #add_param_group 这个函数,主要是处理一下每个参数组其它属性参数(lr,eps) self.add_param_group(param_group) def add_param_group(self, param_group): # 如果当前 参数组中 不存在默认参数的设置,则使用全局参数属性进行覆盖 ''' [{'params': other_params}, {'params': first_params, 'lr': 0.1*lr}] 如第一个参数组 只提供了参数列表,没有其它的参数属性,则使用全局属性覆盖,第二个参数组 则设置了自身的lr为全局 (0.1*lr) ''' for name, default in self.defaults.items(): if default is required and name not in param_group: raise ValueError("parameter group didn't specify a value of required optimization parameter " + name) else: param_group.setdefault(name, default) # 判断 是否有一个参数 出现在不同的参数组中,否则会报错 param_set = set() for group in self.param_groups: param_set.update(set(group['params'])) if not param_set.isdisjoint(set(param_group['params'])): raise ValueErro<b>本文来源gao@!dai!ma.com搞$$代^@码!网</b>r("some parameters appear in more than one parameter group") # 然后 更新自身的参数组中 self.param_groups.append(param_group)