在PyTorch中,任何一个向量tensor都具有自动梯度计算的功能,因为动态计算图是PyTorch天然的特性,但是梯度计算是件十分耗费内存资源任务,所以某些情况下禁用梯度计算十分有必要。
torch.no_grad()
函数是禁用梯度计算的上下文管理器。当我们确信不会调用backward()
时,禁用梯度计算很有用,因为它将减少计算的内存消耗。在这种模式下,即使输入的向量的requires_grad=True
,每次计算的结果也将为requires_grad=False
。但是,有种例外:所有工厂函数或创建新张量的函数,都不受此模式的影响,如下代码所示。
>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad():
... y = x * 2
>>> y.requires_grad
False
>>> @torch.no_grad()
... def doubler(x):
... return x * 2
>>> z = doubler(x)
>>> z.requires_grad
False
>>> @torch.no_grad
... def tripler(x):
... return x * 3
>>> z = tripler(x)
>>> z.requires_grad
False
>>> # 工厂函数并不受no_grad的影响
>>> with torch.no_grad():
... a = torch.nn.Parameter(torch.rand(10))
>>> a.requires_grad
True
工厂函数是用于生成tensor的函数。常见的工厂函数有torch.rand
、torch.randint
、torch.randn
、torch.eye
等。