#squeeze 函数:从数组的形状中删除单维度条目,即把shape中为1的维度去掉
#unsqueeze() 是squeeze()的反向操作,增加一个维度,该维度维数为1,可以指定添加的维度。例如unsqueeze(a,1)表示在1这个维度进行添加
import torch a=tor<a style="color:transparent">本文来源gao($daima.com搞@代@#码(网5</a>ch.rand(2,3,1) print(torch.unsqueeze(a,2).size())#torch.Size([2, 3, 1, 1]) print(a.size()) #torch.Size([2, 3, 1]) print(a.squeeze().size()) #torch.Size([2, 3]) print(a.squeeze(0).size()) #torch.Size([2, 3, 1]) print(a.squeeze(-1).size()) #torch.Size([2, 3]) print(a.size()) #torch.Size([2, 3, 1]) print(a.squeeze(-2).size()) #torch.Size([2, 3, 1]) print(a.squeeze(-3).size()) #torch.Size([2, 3, 1]) print(a.squeeze(1).size()) #torch.Size([2, 3, 1]) print(a.squeeze(2).size()) #torch.Size([2, 3]) print(a.squeeze(3).size()) #RuntimeError: Dimension out of range (expected to be in range of [-3, 2], but got 3) print(a.unsqueeze().size()) #TypeError: unsqueeze() missing 1 required positional arguments: "dim" print(a.unsqueeze(-3).size()) #torch.Size([2, 1, 3, 1]) print(a.unsqueeze(-2).size()) #torch.Size([2, 3, 1, 1]) print(a.unsqueeze(-1).size()) #torch.Size([2, 3, 1, 1]) print(a.unsqueeze(0).size()) #torch.Size([1, 2, 3, 1]) print(a.unsqueeze(1).size()) #torch.Size([2, 1, 3, 1]) print(a.unsqueeze(2).size()) #torch.Size([2, 3, 1, 1]) print(a.unsqueeze(3).size()) #torch.Size([2, 3, 1, 1]) print(torch.unsqueeze(a,3)) b=torch.rand(2,1,3,1) print(b.squeeze().size()) #torch.Size([2, 3])
补充:pytorch中unsqueeze()、squeeze()、expand()、repeat()、view()、和cat()函数的总结
学习Bert模型的时候,需要使用到pytorch来进行tensor的操作,由于对pytorch和tensor不熟悉,就把pytorch中常用的、有关tensor操作的unsqueeze()、squeeze()、expand()、view()、cat()和repeat()等函数做一个总结,加深记忆。
1、unsqueeze()和squeeze()
torch.unsqueeze(input, dim,out=None) → Tensor
unsqueeze()的作用是用来增加给定tensor的维度的,unsqueeze(dim)就是在维度序号为dim的地方给tensor增加一维。例如:维度为torch.Size([768])的tensor要怎样才能变为torch.Size([1, 768, 1])呢?就可以用到unsqueeze(),直接上代码:
a=torch.randn(768) print(a.shape) # torch.Size([768]) a=a.unsqueeze(0) print(a.shape) #torch.Size([1, 768]) a = a.unsqueeze(2) print(a.shape) #torch.Size([1, 768, 1])
也可以直接使用链式编程:
a=torch.randn(768) print(a.shape) # torch.Size([768]) a=a.unsqueeze(1).unsqueeze(0) print(a.shape) #torch.Size([1, 768, 1])
tensor经过unsqueeze()处理之后,总数据量不变;维度的扩展类似于list不变直接在外面加几层[]括号。
torch.squeeze(input, dim=None, out=None) → Tensor
squeeze()的作用就是压缩维度,直接把维度为1的维给去掉。形式上表现为,去掉一层[]括号。
同时,输出的张量与原张量共享内存,如果改变其中的一个,另一个也会改变。
a=torch.randn(2,1,768) print(a) print(a.shape) #torch.Size([2, 1, 768]) a=a.squeeze() print(a) print(a.shape) #torch.Size([2, 768])
图片中的维度信息就不一样,红框中的括号层数不同。
注意的是:squeeze()只能压缩维度为1的维;其他大小的维不起作用。
a=torch.randn(2,768) print(a.shape) #torch.Size([2, 768]) a=a.squeeze() print(a.shape) #torch.Size([2, 768])