PyTorch中的一个tensor分为头信息区(Tensor)和存储区(Storage)两部分。
头信息区主要保存着tensor的形状(size)、步长(stride)、数据类型(type)等信息,而真正的数据则保存成连续数组,存储在存储区。
一般情况下,一个tensor都会有相对应的Storage,但有时候多个tensor都对应着相同的一个Storage,这几个tensor只是头信息区不同。
张量的步长(stride)是定义了在存储中每个维度的偏移量,以便找到下一个元素。具体来说,步长指定了在存储中移动指针时应该跨过的字节数。步长是一个元组,其中每个元素对应于张量的一个维度。步长的长度与张量的维度数相同。为了更好地理解步长的概念和功能,让我们通过一些示例来说明。假设我们有一个形状为(2, 3, 4)的三维张量tensor,并且通过以下代码创建和初始化它:
import torch
tensor = torch.zeros(2, 3, 4)
print(tensor.stride())
现在我们可以查看张量的步长,在PyTorch中可以使用stride()方法来实现,结果为:
(12, 4, 1)
表示在存储中移动指针时,每个维度需要跨越的字节数。 我们可以看到,在这个示例中,第一个维度需要跨越12个字节才能找到下一个元素,第二个维度需要跨越4个字节,而第三个维度只需要跨越1个字节。这是因为在内存中,张量中的元素是以连续的方式排列的,所以根据张量的尺寸和数据类型,PyTorch会自动计算合适的步长。
除了查看张量的步长,我们还可以通过修改步长来改变张量的存储方式。 PyTorch提供了as_strided()函数来实现这一点。
torch.as_strided(input, size, stride, storage_offset=0)—>Tensor
此方法是根据现有tensor以及给定的步长来创建一个视图(类型仍然为tensor)。视图是指创建一个方便查看的东西,与原数据共享内存,它并不占用内存,也不存储数据,只是将原有的数据进行整理,显示其中部分内容或者进行重排序后显示出来等等。
接下来介绍各个参数:
input:此参数指定了在哪个数据上创建视图,input需为tensor。
size:指定了生成的视图的大小,需要为一个矩阵(当然此矩阵大小可以大于原矩阵,但是也有限制),可以是tensor或者list等等。
stride:输出tensor的步长,根据原矩阵和步长生成了新矩阵,此参数后面会细讲。
storage_offset:输出张量的基础存储中的偏移量。
考虑以下示例,我们有一个(2, 3)的二维张量tensor2,并且通过以下代码创建和初始化它:
tensor2 = torch.zeros(2, 3)
我们可以使用as_strided()函数来修改tensor2的步长,如下所示:
new_tensor = torch.as_strided(tensor2, (2, 6), (3, 1))
在这个示例中,我们将tensor2的步长设定为(3, 1),这将导致新张量new_tensor在存储中跳过一些元素。具体来说,对于new_tensor中的每个元素,我们需要跳过3个字节才能到达下一个元素,而不是原始tensor2中的1个字节。