model.state_dict()函数的作用是保存模型,如下所示:
import torch.nn as nn
import torch.nn.functional as F
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.nn1 = nn.Linear(2, 3)
self.nn2 = nn.Linear(3, 6)
def forward(self, x):
x = F.relu(self.nn1(x))
return F.relu(self.nn2(x))
model = MyModel()
torch.save(model.state_dict(), "model_weights.pth")
state_dict是字典类型,具体来说,是OrderedDict字典类型。OrderedDict类型本质是list列表,其元素为元组,如下所示:
print(model.state_dict())
OrderedDict([
("nn1.weight", tensor([[-0.1838, -0.2477],[ 0.4845, 0.3157],[-0.5628, 0.3612]])),
("nn1.bias", tensor([-0.4328, -0.6779, 0.3845])),
("nn2.weight", tensor([[-5.0585e-01, -4.6973e-01, 1.6044e-02],[-3.4606e-01, 1.1130e-01, -2.0727e-01],
[-3.9844e-02, -4.2531e-01, 8.2558e-02],[ 3.3171e-02, -3.4334e-01, 4.5039e-01],
[-2.5320e-04, -5.2037e-01, 1.3504e-02],[-3.0776e-01, 8.9345e-02, -1.1076e-01]])),
("nn2.bias", tensor([ 0.1229, -0.2344, 0.0568, -0.3430, 0.2715, -0.3521]))
])