请选择 进入手机版 | 继续访问电脑版
  • 设为首页
  • 点击收藏
  • 手机版
    手机扫一扫访问
    迪恩网络手机版
  • 关注官方公众号
    微信扫一扫关注
    公众号

Torch张量的view方法有什么作用?

原作者: [db:作者] 来自: [db:来源] 收藏 邀请

我对以下代码片段中的方法view()感到困惑,不知道这个view方法起什么作用。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool  = nn.MaxPool2d(2,2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1   = nn.Linear(16*5*5, 120)
        self.fc2   = nn.Linear(120, 84)
        self.fc3   = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16*5*5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()

有困惑问题在下面这行。

x = x.view(-1, 16*5*5)

tensor.view()方法有什么作用?我已经在许多地方看到它的用法,但我无法理解它如何解析它的’参数。

如果我将负值作为参数提供给view()函数会发生什么?例如,如果我调用tensor_variable.view(1, 1, -1)会怎么样?

请举例说明view()功能的主要原理。

最佳解释

view函数旨在reshape张量形状。

假设你有一个张量

import torch
a = torch.range(1, 16)

其中a是一个张量,有16个元素,从1到16(包括在内)。如果你想重塑这个张量使其成为4 x 4张量,那么你可以使用

a = a.view(4, 4)

现在a将是一个4 x 4张量。(注意,重塑后元素的总数需要保持不变。将张量a重新整形为3 x 5张量是不合适的)

参数-1是什么意思?

如果你不知道你想要多少行,但确定列数,那么你可以将行数设置为-1(你可以将它扩展到具有更多维度的张量。只有一个轴值可以是-1)。这是告诉系统Library:给我一个具有这么多列的张量,并计算实现这一点所需的适当行数。

这可以在您上面给出的神经网络代码中看到。在前向功能中的x = self.pool(F.relu(self.conv2(x)))行之后,您将拥有一个16深度的特征图。您必须将其展平以将其提供给全连接的图层。所以告诉pytorch重新塑造你获得的张量,使其具有特定数量的列并让它自己决定行数。

从numpy和pytorch之间的相似性来看,view类似于numpy的reshape函数。

补充解释

让我们举一些例子,从简到难。

  1. view方法返回张量与self张量相同的数据(这意味着返回的张量具有相同数量的元素),但具有不同的形状。例如:

    a = torch.arange(1, 17)  # a's shape is (16,)
    
    a.view(4, 4) # output below
      1   2   3   4
      5   6   7   8
      9  10  11  12
     13  14  15  16
    [torch.FloatTensor of size 4x4]
    
    a.view(2, 2, 4) # output below
    (0 ,.,.) = 
    1   2   3   4
    5   6   7   8
    
    (1 ,.,.) = 
     9  10  11  12
    13  14  15  16
    [torch.FloatTensor of size 2x2x4]
    
  2. 假设-1不是其中一个参数,当将它们相乘时,结果必须等于张量中的元素数。如果您执行:a.view(3, 3),它将引发RuntimeError,因为对于具有16个元素的输入,形状(3 x 3)无效。换句话说:3 x 3不等于16但是9。

  3. 您可以使用-1作为传递给函数的参数之一,但只能使用一次。所发生的事情是该方法将自动计算维度。例如,a.view(2, -1, 4)等同于a.view(2, 2, 4)。 [16 /(2 x 4)= 2]

  4. 特别请注意,返回的张量共享相同的数据。如果您相对”view”中进行更改,则需要更改原始张量数据:

    b = a.view(4, 4)
    b[0, 2] = 2
    a[2] == 3.0
    False
    
  5. 现在,对于更复杂的用例。文档说每个新的视图维度必须是原始维度的子空间,或者只有跨度d,d + 1,…,d + k满足以下contiguity-like条件,对于所有i = 0,… ,k – 1,stride [i] = stride [i + 1] x size [i + 1]。否则,需要在可以查看张量之前调用contiguous()。例如:

    a = torch.rand(5, 4, 3, 2) # size (5, 4, 3, 2)
    a_t = a.permute(0, 2, 3, 1) # size (5, 3, 2, 4)
    
    # The commented line below will raise a RuntimeError, because one dimension
    # spans across two contiguous subspaces
    # a_t.view(-1, 4)
    
    # instead do:
    a_t.contiguous().view(-1, 4)
    
    # To see why the first one does not work and the second does,
    # compare a.stride() and a_t.stride()
    a.stride() # (24, 6, 2, 1)
    a_t.stride() # (24, 2, 1, 6)
    

    请注意,对于a_t,stride [0]!= stride [1] x size [1],因为24!= 2 x 3

参考资料

  • How view() method works for tensor in torch


鲜花

握手

雷人

路过

鸡蛋
专题导读
上一篇:
PyTorch入门简介发布时间:2022-05-14
下一篇:
Pandas 相关矩阵的计算与可视化发布时间:2022-05-14
热门推荐
阅读排行榜

扫描微信二维码

查看手机版网站

随时了解更新最新资讯

139-2527-9053

在线客服(服务时间 9:00~18:00)

在线QQ客服
地址:深圳市南山区西丽大学城创智工业园
电邮:jeky_zhao#qq.com
移动电话:139-2527-9053

Powered by 互联科技 X3.4© 2001-2213 极客世界.|Sitemap