4.模型的保存与载入
4.1 模型的保存
模型保存实际上就是保存已经训练好的参数。
# 周期训练
for epoch in range(10):
print('epoch:',epoch)
train()
test()
# 在训练完成之后就可保存模型
torch.save(model.state_dict(),'model/简单识别1.pth')
4.2 模型的载入
# 定义网络结构
class Net(nn.Module):
def __init__(self):
def forward(self,x):
return x
LR=0.5
# 定义模型
model=Net()
# 在定义模型之后就可对其进行载入
model.load_state_dict(torch.load(model/简单识别1.pth))
Comments NOTHING