计算图(Computational Graph)就是记录运算过程的有向无环图,比如前向传播时输入张量经过加、减、乘、除得到输出张量,那么计算图就会记录输入输出张量、加减乘除运算和一些中间变量,这是进行反向传播的前提。
(1)Pytorch的计算图就是动态的,几乎每进行一次运算都会拓展原先的计算图,最后生成完成。
(2)当反向传播完成,计算图默认会被清除,所以只能用生成的计算图进行一次反向传播。
(3)retain_graph
参数可以保持计算图,从而避免别清除掉,其用法为:loss.backward(retain_graph=True)
。代码举例如下所示:
import torch
x = torch.randn((1,4),dtype=torch.float32,requires_grad=True)
y = x ** 2
z = y * 4
print(x)
print(y)
print(z)
loss1 = z.mean()
loss2 = z.sum()
print(loss1,loss2)
loss1.backward() # 这个代码执行正常,但是计算图会释放
loss2.backward() # 这时会引发错误
程序正常执行到最后一行报错:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
代码分析:计算loss1的backward的时候,计算图x-y-z
结构被释放了,而计算loss2的backward仍然试图利用x-y-z
的结构,因此会报错。只需要设置retain_graph
参数为True即可保留计算图,从而两个loss的backward()不会相互影响。正确的代码应当是:
# 这里参数表明保留backward后的中间参数。
loss1.backward(retain_graph=True)
# 执行完这个代码后,所有中间变量都会被释放,以便下一次的循环
loss2.backward()