• 欢迎访问搞代码网站,推荐使用最新版火狐浏览器和Chrome浏览器访问本网站!
  • 如果您觉得本站非常有看点,那么赶紧使用Ctrl+D 收藏搞代码吧

pytorch交叉熵损失函数的weight参数的使用

python 搞代码 4年前 (2022-01-08) 115次浏览 已收录 0个评论
文章目录[隐藏]

这篇文章主要介绍了pytorch交叉熵损失函数的weight参数的使用,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

首先

必须将权重也转为Tensor的cuda格式;

然后

将该class_weight作为交叉熵函数对应参数的输入值。

 class_weight = torch.FloatTensor([0.13859937, 0.5821059, 0.63871904, 2.30220396, 7.1588294, 0]).cuda()

补充:关于pytorch的CrossEntropyLoss的weight参数

首先这个weight参数比想象中的要考虑的多

你可以试试下面代码

 import torch import torch.nn as nn inputs = torch.FloatTensor([0,1,0,0,0,1]) outputs = torch.LongTensor([0,1]) inputs = inputs.view((1,3,2)) outputs = outputs.view((1,2)) weight_CE = torch.FloatTensor([1,1,1]) ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE) loss = ce(inputs,outputs) print(loss)
 tensor(1.4803)

这里的手动计算是:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *1)/ 2 = 1.4803

加权呢?

 import torch import torch.nn as nn inputs = torch.FloatTensor([0,1,0,0,0,1]) outputs = torch.LongTensor([0,1]) inputs = inputs.view((1<p style="color:transparent">来源gao!%daima.com搞$代*!码$网</p>,3,2)) outputs = outputs.view((1,2)) weight_CE = torch.FloatTensor([1,2,3]) ce = nn.CrossEntropyLoss(ignore_index=255,weight=weight_CE) loss = ce(inputs,outputs) print(loss) 
 tensor(1.6075)

手算发现,并不是单纯的那权重相乘:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 * 1 + loss2 * 2)/ 2 = 2.4113

而是

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

求平均 = (loss1 *1 + loss2 *2) / 3 = 1.6075

发现了么,加权后,除以的是权重的和,不是数目的和。

我们再验证一遍:

 import torch import torch.nn as nn inputs = torch.FloatTensor([0,1,2,0,0,0,0,0,0,1,0,0.5]) outputs = torch.LongTensor([0,1,2,2]) inputs = inputs.view((1,3,4)) outputs = outputs.view((1,4)) weight_CE = torch.FloatTensor([1,2,3]) ce = nn.CrossEntropyLoss(weight=weight_CE) # ce = nn.CrossEntropyLoss(ignore_index=255) loss = ce(inputs,outputs) print(loss) 
 tensor(1.5472)

手算:

loss1 = 0 + ln(e0 + e0 + e0) = 1.098

loss2 = 0 + ln(e1 + e0 + e1) = 1.86

loss3 = 0 + ln(e2 + e0 + e0) = 2.2395

loss4 = -0.5 + ln(e0.5 + e0 + e0) = 0.7943

求平均 = (loss1 * 1 + loss2 * 2+loss3 * 3+loss4 * 3) / 9 = 1.5472

可能有人对loss的CE计算过程有疑问,我这里细致写写交叉熵的计算过程,就拿最后一个例子的loss4的计算说明

以上就是pytorch交叉熵损失函数的weight参数的使用的详细内容,更多请关注gaodaima搞代码网其它相关文章!


搞代码网(gaodaima.com)提供的所有资源部分来自互联网,如果有侵犯您的版权或其他权益,请说明详细缘由并提供版权或权益证明然后发送到邮箱[email protected],我们会在看到邮件的第一时间内为您处理,或直接联系QQ:872152909。本网站采用BY-NC-SA协议进行授权
转载请注明原文链接:pytorch交叉熵损失函数的weight参数的使用

喜欢 (0)
[搞代码]
分享 (0)
发表我的评论
取消评论

表情 贴图 加粗 删除线 居中 斜体 签到

Hi,您需要填写昵称和邮箱!

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址