PyTorch的索引和切片

在 PyTorch 中,张量的索引和切片与 Python 中的列表和 NumPy 数组的索引和切片非常相似。通过索引和切片,您可以访问、修改或提取张量的一部分。

1. 索引

索引允许你访问张量的单个元素。与 Python 列表类似,索引是从 0 开始的。

import torch

x = torch.tensor([1, 2, 3, 4, 5])
print(x[2])  # 输出 tensor(3)

对于多维张量,使用逗号隔开的索引访问其元素:

x = torch.tensor([[1, 2], [3, 4], [5, 6]])
print(x[1, 0])  # 输出 tensor(3) —— 第2行,第1列

2. 切片

切片允许你提取张量的子集。切片的基本语法是 start:stop:step,其中所有值都是可选的。

x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

# 提取从索引2到索引7(不包括7)的元素
print(x[2:7])  # 输出 tensor([2, 3, 4, 5, 6])

# 使用步长提取元素
print(x[1:8:2])  # 输出 tensor([1, 3, 5, 7])

多维张量同样支持切片:

x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])

# 提取前3行的前2列
print(x[:3, :2])  
# 输出
# tensor([[ 1,  2],
#         [ 4,  5],
#         [ 7,  8]])

3. 高级索引

除了基本的索引和切片外,PyTorch 还支持更高级的索引方法,例如使用整数列表或张量来进行索引:

x = torch.tensor([0, 1, 2, 3, 4])

# 使用整数列表进行索引
indices = [1, 3]
print(x[indices])  # 输出 tensor([1, 3])

# 使用整数张量进行索引
indices_tensor = torch.tensor([1, 3])
print(x[indices_tensor])  # 输出 tensor([1, 3])

4. 布尔索引

你还可以使用布尔值进行索引,选择满足某个条件的元素:

x = torch.tensor([0, 1, 2, 3, 4, 5])

# 创建一个布尔张量,选择所有大于3的值
mask = x > 3
print(mask)  # 输出 tensor([False, False, False, False, True, True])
print(x[mask])  # 输出 tensor([4, 5])

这些索引和切片方法提供了在 PyTorch 中操作张量的强大工具,无论是为了数据处理还是为了更高级的张量操作。

发表评论

您的邮箱地址不会被公开。 必填项已用 * 标注