PyTorch中张量形状的常见操作

在 PyTorch 中,我们可以使用多种方式调整张量的形状。以下是一些常见的形状操作方法及其示例:

  1. reshape() 或 view():这两个方法都可以重新调整张量的形状。注意,这些操作并不会改变原始张量的数据,而是返回一个新的张量视图。

    import torch
    
    x = torch.arange(12)  # tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
    reshaped_x = x.reshape(3, 4)
    print(reshaped_x)
  2. squeeze():去除尺寸为1的维度。

    x = torch.randn(1, 3, 1, 5)
    squeezed_x = x.squeeze()
    print(squeezed_x.shape)  # 输出: torch.Size([3, 5])
  3. unsqueeze():在指定位置添加尺寸为1的维度。

    x = torch.randn(3, 5)
    unsqueezed_x = x.unsqueeze(0)
    print(unsqueezed_x.shape)  # 输出: torch.Size([1, 3, 5])
  4. transpose() 或 permute():交换张量的两个维度或更改多个维度的顺序。

    x = torch.randn(3, 4, 5)
    transposed_x = x.transpose(0, 1)  # 交换第0和第1个维度
    print(transposed_x.shape)  # 输出: torch.Size([4, 3, 5])
    
    permuted_x = x.permute(2, 0, 1)  # 改变维度的顺序
    print(permuted_x.shape)  # 输出: torch.Size([5, 3, 4])
  5. expand() 或 repeat():用于复制张量。

    x = torch.tensor([[1], [2], [3]])  # shape: [3, 1]
    expanded_x = x.expand(3, 4)        # shape: [3, 4], 不增加额外的内存
    print(expanded_x)
    
    repeated_x = x.repeat(1, 4)        # shape: [3, 4], 增加额外的内存
    print(repeated_x)
  6. flatten():用于将多维张量转换为一维。

    x = torch.randn(2, 3, 4)
    flattened_x = x.flatten()
    print(flattened_x.shape)  # 输出: torch.Size([24])
  7. torch.cat():此操作用于沿指定的维度连接张量序列。这意味着,对于指定的维度,张量的尺寸会增加,但其他维度的尺寸必须相同。

    import torch
    
    x = torch.tensor([[1, 2], [3, 4]])
    y = torch.tensor([[5, 6]])
    
    # 沿第0个维度(即行)连接
    z = torch.cat([x, y], dim=0)
    print(z)
    # 输出:
    # tensor([[1, 2],
    #         [3, 4],
    #         [5, 6]])
  8. torch.stack():此操作用于在新的维度上堆叠张量。这意味着张量的总维数会增加1。为了使用 torch.stack(),所有张量都必须具有相同的形状。

    x = torch.tensor([1, 2])
    y = torch.tensor([3, 4])
    
    # 沿一个新的维度堆叠
    z = torch.stack([x, y])
    print(z)
    # 输出:
    # tensor([[1, 2],
    #         [3, 4]])

通过掌握这些形状操作,你可以更加灵活地处理和转换张量,从而满足不同的深度学习任务和数据处理需求。

发表评论

您的邮箱地址不会被公开。 必填项已用 * 标注