What the function does
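torch.bmm(a, b) computes a batched matrix multiplication of two tensors. If tensor a has size (B, h, w) and tensor b has size (B, w, h), the result has size (B, h, h). More generally, the two batch sizes must be equal and the last dimension of a must match the second dimension of b. Note that both tensors must be exactly 3-D.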
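The shape rule above can be checked with a short sketch. The sizes `batch`, `n`, `m`, and `p` are hypothetical values picked only for illustration; the comparison against a per-matrix loop is one way to see what "batched" means here, not the way PyTorch implements it internally.

```python
import torch

# Hypothetical sizes, chosen only for illustration.
batch, n, m, p = 2, 3, 4, 5

a = torch.randn(batch, n, m)  # (batch, n, m)
b = torch.randn(batch, m, p)  # (batch, m, p)

# Batched matrix multiplication: result has shape (batch, n, p).
out = torch.bmm(a, b)

# Reference: multiply each pair of matrices separately and stack the results.
expected = torch.stack([a[i] @ b[i] for i in range(batch)])

print(out.shape)
print(torch.allclose(out, expected, atol=1e-5))
```

Each slice `out[i]` is the ordinary matrix product of `a[i]` and `b[i]`, which is why the batch sizes must be equal and the inner dimensions must match.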
Code examples
>>> import torch
>>> c = torch.randn((2, 5))
>>> print(c)
tensor([[ 1.0559, -0.3533,  0.5194,  0.9526, -0.2483],
        [-0.1293,  0.4809, -0.5268, -0.3673,  0.0666]])
>>> d = torch.reshape(c, (5, 2))
>>> print(d)
tensor([[ 1.0559, -0.3533],
        [ 0.5194,  0.9526],
        [-0.2483, -0.1293],
        [ 0.4809, -0.5268],
        [-0.3673,  0.0666]])
>>> e = torch.bmm(c, d)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
Passing 2-D tensors raises an error!
>>> ccc = torch.randn((1, 2, 2, 5))
>>> ddd = torch.randn((1, 2, 5, 2))
>>> e = torch.bmm(ccc, ddd)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: invalid argument 1: expected 3D tensor, got 4D at /opt/conda/conda-bld/pytorch_1535490206202/work/aten/src/TH/generic/THTensorMath.cpp:2304
Passing 4-D tensors raises an error as well!
>>> cc = torch.randn((2, 2, 5))
>>> print(cc)
tensor([[[ 1.4873, -0.7482, -0.6734, -0.9682,  1.2869],
         [ 0.0550, -0.4461, -0.1102, -0.0797, -0.8349]],

        [[-0.6872,  1.1920, -0.9732,  0.4580,  0.7901],
         [ 0.3035,  0.2022,  0.8815,  0.9982, -1.1892]]])
>>> dd = torch.reshape(cc, (2, 5, 2))
>>> print(dd)
tensor([[[ 1.4873, -0.7482],
         [-0.6734, -0.9682],
         [ 1.2869,  0.0550],
         [-0.4461, -0.1102],
         [-0.0797, -0.8349]],

        [[-0.6872,  1.1920],
         [-0.9732,  0.4580],
         [ 0.7901,  0.3035],
         [ 0.2022,  0.8815],
         [ 0.9982, -1.1892]]])
>>> e = torch.bmm(cc, dd)
>>> print(e)
tensor([[[ 2.1787, -1.3931],
         [ 0.3425,  1.0906]],

        [[-0.5754, -1.1045],
         [-0.6941,  3.0161]]])
>>> e.size()
torch.Size([2, 2, 2])
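For the 2-D and 4-D shapes that torch.bmm rejects, the following is a minimal sketch of some possible workarounds. It is not part of the original examples: it assumes torch.mm and torch.matmul are acceptable substitutes in your code, and the variable names simply mirror the ones used above.

```python
import torch

# 2-D case: torch.mm multiplies two plain matrices.
c = torch.randn(2, 5)
d = torch.randn(5, 2)
e2 = torch.mm(c, d)  # shape (2, 2)

# Alternatively, add a batch dimension of size 1, call bmm, then remove it.
e2_bmm = torch.bmm(c.unsqueeze(0), d.unsqueeze(0)).squeeze(0)

# 4-D case: torch.matmul treats all leading dimensions as batch dimensions.
ccc = torch.randn(1, 2, 2, 5)
ddd = torch.randn(1, 2, 5, 2)
e4 = torch.matmul(ccc, ddd)  # shape (1, 2, 2, 2)

# Alternatively, flatten the leading dimensions into one batch dimension,
# call bmm, then restore the original leading dimensions.
e4_bmm = torch.bmm(
    ccc.reshape(-1, 2, 5),
    ddd.reshape(-1, 5, 2),
).reshape(1, 2, 2, 2)

print(e2.shape, e4.shape)
```

torch.matmul is the more general choice when the number of dimensions is not fixed; torch.bmm is stricter, which can help catch shape mistakes early.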