torch.stack()函数的参数形式为:torch.stack(inputs,dim=0,out=None)
,其作用是将若干个形状相同的张量在dim维度上连接,生成一个扩维的张量。比如,我们原本有若干个2维张量,连接之后可以得到一个3维的张量。
注:python的序列数据只有list和tuple。
0<=dim<len(outputs.shape)
。注:len(outputs.shape)是生成数据的维度大小,也就是outputs的维度值,它比inputs的维度要多出一个。
此处主要以二维向量的stack为主,因为二维向量经过stack之后变成三维向量,大家容易理解:
dim=0时,将tensor进行叠加从而形成三维向量,这种情况比较容易理解。
dim=1时,将每个tensor的第i行抽出连接组成一个新的2维tensor,然后进行叠加从而形成三维向量。
dim=2时,将每个tensor的第i列抽出连接组成一个新的2维tensor,然后进行叠加从而形成三维向量。