4.模型的保存与载入

4.1 模型的保存

模型保存实际上就是保存已经训练好的参数。

# 周期训练
for epoch in range(10):
    print('epoch:',epoch)
    train()
    test()

# 在训练完成之后就可保存模型
torch.save(model.state_dict(),'model/简单识别1.pth')

4.2 模型的载入

# 定义网络结构
class Net(nn.Module):
    def __init__(self):

    def forward(self,x):

        return x

LR=0.5
# 定义模型
model=Net()
# 在定义模型之后就可对其进行载入
model.load_state_dict(torch.load(model/简单识别1.pth))