PyTorch中的一个tensor分为头信息区(Tensor)和存储区(Storage)两部分。
头信息区主要保存着tensor的形状(size)、步长(stride)、数据类型(type)等信息,而真正的数据则保存成连续数组,存储在存储区。
一般情况下,一个tensor都会有相对应的Storage,但有时候多个tensor都对应着相同的一个Storage,这几个tensor只是头信息区不同。如下所示,转置操作就是这种情况。
import torch
## 建立一个 size 是 (4, 3) 的 tensor
a = torch.arange(1, 13).view(4, 3)
## 对 a 进行转置操作
a_t = a.t()
print(a.is_contiguous()) ## True
print(a_t.is_contiguous()) ## False
a_t_v = a_t.view(-1) ## 会报错
通过上面的例子我们可以看到,对一个 tensor 进行转置操作之后会改变它的 contiguous 特性,然后再使用 view 操作就会报错。这时候我们就需要在 view 操作之前把转置之后的 tensor 变成 contiguous。
import torch
## 建立一个 size 是 (4, 3) 的 tensor
a = torch.arange(1, 13).view(4, 3)
## 对 a 进行转置操作并把转置之后的 tensor 变连续
a_t = a.t().contiguous()
print(a.is_contiguous()) ## True
print(a_t.is_contiguous()) ## True
a_t_v = a_t.view(-1) ## 正常运行
大家可以思考一下,为什么转置之后 tensor 就不连续了呢?tensor.contiguous() 又是怎么使 tensor 变的连续的呢?
在上面的代码示例中, 执行 transpose 操作的时候系统就不会另外分配内存,所以示例中的 a 和 a_t 是共用一块内存的,只不过读取内存的顺序有所不同。因此,对 a 来说,数据是按维度存储在内存里的,但对 a_t 来说,它的维度改变了,数据存的方式却并没有改变,这就解释了为什么 a 是 contiguous,而 a_t 不是 contiguous。再解释对 a_t 使用 tensor.contiguous(),它的作用是为 a_t 申请一块新的内存,并让数据按照 a_t 的维度进行存储。它们之间的差别可以通过读取 tensor 的 stride 来区分,如下所示:
import torch
## 建立一个 size 是 (4, 3) 的 tensor
a = torch.arange(1, 13).view(4, 3)
print(a.stride()) ## (3, 1)
## 对 a 进行转置操作
a_t = a.t()
print(a_t.stride()) ## (1, 3)
a_t = a_t.contiguous()
print(a_t.stride()) ## (4, 1)
view 操作要求 tensor 具有相应的 contiguous 特性,其实并不是说 tensor 一定要是 contiguous 才能进行 view 操作,而是 view 操作执行的维度上数据一定是要按这个维度进行存储的。