1.torch.optim优化器实现L2正则化
torch.optim集成了很多优化器,如SGD,Adadelta,Adam,Adagrad,RMSprop等,这些优化器自带的一个参数weight_decay,用于指定权值衰减率,相当于L2正则化中的λ参数,注意torch.optim集成的优化器只有L2正则化方法,你可以查看注释,参数weight_decay 的解析是:
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
使用torch.optim的优化器,可如下设置L2正则化
optimizer = optim.Adam(model.parameters(),lr=learning_rate,weight_decay=0.01)
但是这种方法存在几个问题,
(1)一般正则化,只是对模型的权重W参数进行惩罚,而偏置参数b是不进行惩罚的,而torch.optim的优化器weight_decay参数指定的权值衰减是对网络中的所有参数,包括权值w和偏置b同时进行惩罚。很多时候如果对b 进行L2正则化将会导致严重的欠拟合,因此这个时候一般只需要对权值w进行正则即可。(PS:这个我真不确定,源码解析是 weight decay (L2 penalty) ,但有些网友说这种方法会对参数偏置b也进行惩罚,可解惑的网友给个明确的答复)
(2)缺点:torch.optim的优化器固定实现L2正则化,不能实现L1正则化。如果需要L1正则化,可如下实现:
(3)根据正则化的公式,加入正则化后,loss会变原来大,比如weight_decay=1的loss为10,那么weight_decay=100时,loss输出应该也提高100倍左右。而采用torch.optim的优化器的方法,如果你依然采用loss_fun= nn.CrossEntropyLoss()进行计算loss,你会发现,不管你怎么改变weight_decay的大小,loss会跟之前没有加正则化的大小差不多。这是因为你的loss_fun损失函数没有把权重W的损失加上。
(4)采用torch.optim的优化器实现正则化的方法,是没问题的!只不过很容易让人产生误解,对鄙人而言,我更喜欢TensorFlow的正则化实现方法,只需要tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES),实现过程几乎跟正则化的公式对应的上。
(5)Github项目源码:点击进入
为了,解决这些问题,我特定自定义正则化的方法,类似于TensorFlow正则化实现方法。
2. 如何判断正则化作用了模型?
一般来说,正则化的主要作用是避免模型产生过拟合,当然啦,过拟合问题,有时候是难以判断的。但是,要判断正则化是否作用了模型,还是很容易的。下面我给出两组训练时产生的loss和Accuracy的log信息,一组是未加入正则化的,一组是加入正则化:
2.1 未加入正则化loss和Accuracy
优化器采用Adam,并且设置参数weight_decay=0.0,即无正则化的方法
optimizer = optim.Adam(model.parameters(),lr=learning_rate,weight_decay=0.0)
训练时输出的 loss和Accuracy信息
step/epoch:0/0,Train Loss: 2.418065, Acc: [0.15625] step/epoch:10/0,Train Loss: 5.194936, Acc: [0.34375] step/epoch:20/0,Train Loss: 0.973226, Acc: [0.8125] step/epoch:30/0,Train Loss: 1.215165, Acc: [0.65625] step/epoch:40/0,Train Loss: 1.808068, Acc: [0.65625] step/epoch:50/0,Train Loss: 1.661446, Acc: [0.625] step/epoch:60/0,Train Loss: 1.552345, Acc:<i style="color:transparent">本文来源gaodai$ma#com搞$代*码*网(</i> [0.6875] step/epoch:70/0,Train Loss: 1.052912, Acc: [0.71875] step/epoch:80/0,Train Loss: 0.910738, Acc: [0.75] step/epoch:90/0,Train Loss: 1.142454, Acc: [0.6875] step/epoch:100/0,Train Loss: 0.546968, Acc: [0.84375] step/epoch:110/0,Train Loss: 0.415631, Acc: [0.9375] step/epoch:120/0,Train Loss: 0.533164, Acc: [0.78125] step/epoch:130/0,Train Loss: 0.956079, Acc: [0.6875] step/epoch:140/0,Train Loss: 0.711397, Acc: [0.8125]