在 PyTorch 中,expand()
和 repeat()
都是用来复制张量的方法,但它们的方式和用途有所不同。下面详细介绍这两个函数:
1. expand()
expand()
用于沿特定维度广播(扩展)张量,但它并不真正复制数据,而是返回一个新的视图,该视图在某些维度上具有更大的大小。这意味着 expand()
创建的张量与原始张量共享内存。
基本用法:
tensor.expand(*sizes)
sizes
是一个可变参数,表示要扩展的每个维度的大小。
示例:
import torch
x = torch.tensor([[1], [2], [3]]) # shape: (3, 1)
y = x.expand(3, 4) # shape: (3, 4)
print(y)
输出:
tensor([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3]])
在 PyTorch 中使用 expand()
方法时,新的大小必须满足以下条件才与原始大小兼容:
维度数量:新的张量维度数量可以与原始张量相同或更大。如果新的维度数量较大,则原始张量会被视为在左侧有大小为 1 的额外维度。
维度大小:
- 如果原始张量在某个维度上的大小为 1,那么这个维度可以在新张量中被扩展为任何大小。
- 如果原始张量在某个维度上的大小大于 1,那么新张量在该维度上的大小必须与原始张量相同。
让我们通过几个例子来说明这一点:
例1:
原始张量形状:(3, 1)
允许的扩展形状包括:(3, 4)
、(3, 7)
等,因为第二个维度大小为 1,所以可以被扩展为任何大小。
例2:
原始张量形状:(1, 4)
允许的扩展形状包括:(5, 4)
、(1, 4)
、(3, 4)
等,因为第一个维度大小为 1,所以可以被扩展为任何大小,但第二个维度必须保持大小为 4。
例3:
原始张量形状:(2, 3)
允许的扩展形状:只有 (2, 3)
。因为两个维度的大小都大于1,所以它们不能被扩展。
例4:
原始张量形状:(5,)
允许的扩展形状包括:(5,)
、(1, 5)
、(4, 5)
等。增加的新维度大小可以为任何值,但原始维度大小必须保持为 5。
这些规则确保 expand()
操作能够在不复制数据的情况下,通过只改变形状和步长来有效地进行广播。
2. repeat()
与 expand()
不同,repeat()
真实地复制数据来增加张量的大小。您需要为每个维度提供一个重复次数。
基本用法:
tensor.repeat(*sizes)
sizes
是一个可变参数,表示要重复的每个维度的次数。
示例:
x = torch.tensor([[1, 2]]) # shape: (1, 2)
y = x.repeat(3, 2) # shape: (3, 4)
print(y)
输出:
tensor([[1, 2, 1, 2],
[1, 2, 1, 2],
[1, 2, 1, 2]])
区别:
-
内存使用:
expand()
不真正复制数据,只是返回一个新的、具有更大维度的视图,而repeat()
实际上复制了数据。 -
用途:
expand()
通常用于广播操作,这是一种非常高效的方法,特别是在与其他形状不同的张量进行操作时(例如,在矩阵与向量的加法中)。而repeat()
则更多地用于真实地复制和堆叠数据。
记住这两者之间的差异是很重要的,因为在某些情况下使用不恰当的方法可能会导致不必要的内存使用或计算效率低下。