在 PyTorch 中,我们可以使用多种方式调整张量的形状。以下是一些常见的形状操作方法及其示例:
-
reshape() 或 view():这两个方法都可以重新调整张量的形状。注意,这些操作并不会改变原始张量的数据,而是返回一个新的张量视图。
import torch x = torch.arange(12) # tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) reshaped_x = x.reshape(3, 4) print(reshaped_x)
-
squeeze():去除尺寸为1的维度。
x = torch.randn(1, 3, 1, 5) squeezed_x = x.squeeze() print(squeezed_x.shape) # 输出: torch.Size([3, 5])
-
unsqueeze():在指定位置添加尺寸为1的维度。
x = torch.randn(3, 5) unsqueezed_x = x.unsqueeze(0) print(unsqueezed_x.shape) # 输出: torch.Size([1, 3, 5])
-
transpose() 或 permute():交换张量的两个维度或更改多个维度的顺序。
x = torch.randn(3, 4, 5) transposed_x = x.transpose(0, 1) # 交换第0和第1个维度 print(transposed_x.shape) # 输出: torch.Size([4, 3, 5]) permuted_x = x.permute(2, 0, 1) # 改变维度的顺序 print(permuted_x.shape) # 输出: torch.Size([5, 3, 4])
-
expand() 或 repeat():用于复制张量。
x = torch.tensor([[1], [2], [3]]) # shape: [3, 1] expanded_x = x.expand(3, 4) # shape: [3, 4], 不增加额外的内存 print(expanded_x) repeated_x = x.repeat(1, 4) # shape: [3, 4], 增加额外的内存 print(repeated_x)
-
flatten():用于将多维张量转换为一维。
x = torch.randn(2, 3, 4) flattened_x = x.flatten() print(flattened_x.shape) # 输出: torch.Size([24])
-
torch.cat():此操作用于沿指定的维度连接张量序列。这意味着,对于指定的维度,张量的尺寸会增加,但其他维度的尺寸必须相同。
import torch x = torch.tensor([[1, 2], [3, 4]]) y = torch.tensor([[5, 6]]) # 沿第0个维度(即行)连接 z = torch.cat([x, y], dim=0) print(z) # 输出: # tensor([[1, 2], # [3, 4], # [5, 6]])
-
torch.stack():此操作用于在新的维度上堆叠张量。这意味着张量的总维数会增加1。为了使用
torch.stack()
,所有张量都必须具有相同的形状。x = torch.tensor([1, 2]) y = torch.tensor([3, 4]) # 沿一个新的维度堆叠 z = torch.stack([x, y]) print(z) # 输出: # tensor([[1, 2], # [3, 4]])
通过掌握这些形状操作,你可以更加灵活地处理和转换张量,从而满足不同的深度学习任务和数据处理需求。