• 欢迎访问搞代码网站,推荐使用最新版火狐浏览器和Chrome浏览器访问本网站!
  • 如果您觉得本站非常有看点,那么赶紧使用Ctrl+D 收藏搞代码吧

pytorch中的matmul与mm,bmm区别说明

python 搞代码 4年前 (2022-01-07) 18次浏览 已收录 0个评论
文章目录[隐藏]

这篇文章主要介绍了pytorch中的matmul与mm,bmm区别说明,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝赐教

pytorch中matmul和mm和bmm区别 matmulmmbmm结论

先看下官网上对这三个函数的介绍。

matmul

mm

bmm

顾名思义, 就是两个batch矩阵乘法.

结论

从官方文档可以看出

1、mm只能进行矩阵乘法,也就是输入的两个tensor维度只能是( n × m ) (n\times m)(n×m)和( m × p ) (m\times p)(m×p)

2、bmm是两个三维张量相乘, 两个输入tensor维度是( b × n × m ) (b\times n\times m)(b×n×m)和( b × m × p ) (b\times m\times p)(b×m×p), 第一维b代表batch size,输出为( b × n × p ) (b\times n \times p)(b×n×p)

3、matmul可以进行张量乘法, 输入可以是高维.

补充:torch中的几种乘法。torch.mm, torch.mul, torch.matmul

一、点乘

点乘都是broadcast的,可以用torch.mul(a, b)实现,也可以直接用*实现。

 >>> a = torch.ones(3,4) >>> a tensor([[1., 1., 1., 1.], [1., 1., 1., 1.], [1., 1., 1., 1.]]) >>> b = torch.Tensor([1,2,3]).reshape((3,1)) >>> b tensor([[1.], [2.], [3.]]) >>> torch.mul(a, b) tensor([[1., 1., 1., 1.], [2., 2., 2., 2.], [3., 3., 3., 3.]])

当a, b维度不一致时,会自动填充到相同维度相点乘。

二、矩阵乘

矩阵相乘有torch.mm和torch.matmul两个函数。其中前一个是针对二维矩阵,后一个是高维。当torch.mm用于大于二维时将报错来源gaodai#ma#com搞@代~码网

 >>> a = torch.ones(3,4) >>> b = torch.ones(4,2) >>> torch.mm(a, b) tensor([[4., 4.], [4., 4.], [4., 4.]])
 >>> a = torch.ones(3,4) >>> b = torch.ones(5,4,2) >>> torch.matmul(a, b).shape torch.Size([5, 3, 2])
 >>> a = torch.ones(5,4,2) >>> b = torch.ones(5,2,3) >>> torch.matmul(a, b).shape torch.Size([5, 4, 3])
 >>> a = torch.ones(5,4,2) >>> b = torch.ones(5,2,3) >>> torch.matmul(b, a).shape 报错。 

以上为个人经验,希望能给大家一个参考,也希望大家多多支持gaodaima搞代码网。如有错误或未考虑完全的地方,望不吝赐教。

以上就是pytorch中的matmul与mm,bmm区别说明的详细内容,更多请关注gaodaima搞代码网其它相关文章!


搞代码网(gaodaima.com)提供的所有资源部分来自互联网,如果有侵犯您的版权或其他权益,请说明详细缘由并提供版权或权益证明然后发送到邮箱[email protected],我们会在看到邮件的第一时间内为您处理,或直接联系QQ:872152909。本网站采用BY-NC-SA协议进行授权
转载请注明原文链接:pytorch中的matmul与mm,bmm区别说明

喜欢 (0)
[搞代码]
分享 (0)
发表我的评论
取消评论

表情 贴图 加粗 删除线 居中 斜体 签到

Hi,您需要填写昵称和邮箱!

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址