在 PyTorch 中,张量的索引和切片与 Python 中的列表和 NumPy 数组的索引和切片非常相似。通过索引和切片,您可以访问、修改或提取张量的一部分。
1. 索引
索引允许你访问张量的单个元素。与 Python 列表类似,索引是从 0 开始的。
import torch
x = torch.tensor([1, 2, 3, 4, 5])
print(x[2]) # 输出 tensor(3)
对于多维张量,使用逗号隔开的索引访问其元素:
x = torch.tensor([[1, 2], [3, 4], [5, 6]])
print(x[1, 0]) # 输出 tensor(3) —— 第2行,第1列
2. 切片
切片允许你提取张量的子集。切片的基本语法是 start:stop:step
,其中所有值都是可选的。
x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
# 提取从索引2到索引7(不包括7)的元素
print(x[2:7]) # 输出 tensor([2, 3, 4, 5, 6])
# 使用步长提取元素
print(x[1:8:2]) # 输出 tensor([1, 3, 5, 7])
多维张量同样支持切片:
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
# 提取前3行的前2列
print(x[:3, :2])
# 输出
# tensor([[ 1, 2],
# [ 4, 5],
# [ 7, 8]])
3. 高级索引
除了基本的索引和切片外,PyTorch 还支持更高级的索引方法,例如使用整数列表或张量来进行索引:
x = torch.tensor([0, 1, 2, 3, 4])
# 使用整数列表进行索引
indices = [1, 3]
print(x[indices]) # 输出 tensor([1, 3])
# 使用整数张量进行索引
indices_tensor = torch.tensor([1, 3])
print(x[indices_tensor]) # 输出 tensor([1, 3])
4. 布尔索引
你还可以使用布尔值进行索引,选择满足某个条件的元素:
x = torch.tensor([0, 1, 2, 3, 4, 5])
# 创建一个布尔张量,选择所有大于3的值
mask = x > 3
print(mask) # 输出 tensor([False, False, False, False, True, True])
print(x[mask]) # 输出 tensor([4, 5])
这些索引和切片方法提供了在 PyTorch 中操作张量的强大工具,无论是为了数据处理还是为了更高级的张量操作。