对于 PyTorch 的基本数据对象 Tensor (张量),在处理问题时,需要经常改变数据的维度,以便于后期的计算和进一步处理,本文旨在列举一些维度变换的方法并举例,方便大家查看。
维度查看:torch.Tensor.size()
查看当前 tensor 的维度
举个例子:
>>> import torch >>> a = torch.Tensor([[[1, 2], [3, 4], [5, 6]]]) >>> a.size() torch.Size([1, 3, 2])
张量变形:torch.Tensor.view(*args) → 来2源gaodaima#com搞(代@码&网Tensor
返回一个有相同数据但大小不同的 tensor。 返回的 tensor 必须有与原 tensor 相同的数据和相同数目的元素,但可以有不同的大小。一个 tensor 必须是连续的 contiguous() 才能被查看。
举个例子:
>>> x = torch.randn(2, 9) >>> x.size() torch.Size([2, 9]) >>> x tensor([[-1.6833, -0.4100, -1.5534, -0.6229, -1.0310, -0.8038, 0.5166, 0.9774, 0.3455], [-0.2306, 0.4217, 1.2874, -0.3618, 1.7872, -0.9012, 0.8073, -1.1238, -0.3405]]) >>> y = x.view(3, 6) >>> y.size() torch.Size([3, 6]) >>> y tensor([[-1.6833, -0.4100, -1.5534, -0.6229, -1.0310, -0.8038], [ 0.5166, 0.9774, 0.3455, -0.2306, 0.4217, 1.2874], [-0.3618, 1.7872, -0.9012, 0.8073, -1.1238, -0.3405]]) >>> z = x.view(2, 3, 3) >>> z.size() torch.Size([2, 3, 3]) >>> z tensor([[[-1.6833, -0.4100, -1.5534], [-0.6229, -1.0310, -0.8038], [ 0.5166, 0.9774, 0.3455]], [[-0.2306, 0.4217, 1.2874], [-0.3618, 1.7872, -0.9012], [ 0.8073, -1.1238, -0.3405]]])
可以看到 x 和 y 、z 中数据的数量和每个数据的大小都是相等的,只是尺寸或维度数量发生了改变。
压缩 / 解压张量:torch.squeeze()、torch.unsqueeze()
- torch.squeeze(input, dim=None, out=None)
将输入张量形状中的 1 去除并返回。如果输入是形如(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)
当给定 dim 时,那么挤压操作只在给定维度上。例如,输入形状为: (A×1×B),squeeze(input, 0) 将会保持张量不变,只有用 squeeze(input, 1),形状会变成 (A×B)。
返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。
举个例子:
>>> x = torch.randn(3, 1, 2) >>> x tensor([[[-0.1986, 0.4352]], [[ 0.0971, 0.2296]], [[ 0.8339, -0.5433]]]) >>> x.squeeze().size() # 不加参数,去掉所有为元素个数为1的维度 torch.Size([3, 2]) >>> x.squeeze() tensor([[-0.1986, 0.4352], [ 0.0971, 0.2296], [ 0.8339, -0.5433]]) >>> torch.squeeze(x, 0).size() # 加上参数,去掉第一维的元素,不起作用,因为第一维有2个元素 torch.Size([3, 1, 2]) >>> torch.squeeze(x, 1).size() # 加上参数,去掉第二维的元素,正好为 1,起作用 torch.Size([3, 2])
可以看到如果加参数,只有维度中尺寸为 1 的位置才会消失
- torch.unsqueeze(input, dim, out=None)
返回一个新的张量,对输入的制定位置插入维度 1
返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。
如果 dim 为负,则将会被转化 dim+input.dim()+1
接着用上面的数据举个例子:
>>> x.unsqueeze(0).size() torch.Size([1, 3, 1, 2]) >>> x.unsqueeze(0) tensor([[[[-0.1986, 0.4352]], [[ 0.0971, 0.2296]], [[ 0.8339, -0.5433]]]]) >>> x.unsqueeze(-1).size() torch.Size([3, 1, 2, 1]) >>> x.unsqueeze(-1) tensor([[[[-0.1986], [ 0.4352]]], [[[ 0.0971], [ 0.2296]]], [[[ 0.8339], [-0.5433]]]])