5.3 项目实战-猫狗分类预测
import torch
import numpy as np
from PIL import Image
from torchvision import transforms,models
model = models.vgg16(pretrained = True)
# 构建新的全连接层
model.classifier = torch.nn.Sequential(torch.nn.Linear(25088, 100),
torch.nn.ReLU(),
torch.nn.Dropout(p=0.5),
torch.nn.Linear(100, 2))
# 这里我们载入全训练的参数,效果要好一些
model.load_state_dict(torch.load('model/猫狗分类全训练'))
label = np.array(['猫','狗'])
# 数据预处理
# 变成统一大小
transform = transforms.Compose([
transforms.Resize(224),
transforms.ToTensor()
])
def predict(image_path):
# 打开图片
img = Image.open(image_path)
# 数据处理,再增加一个维度
img = transform(img).unsqueeze(0)
# 预测得到结果
outputs = model(img)
# 获得最大值所在位置
_, predicted = torch.max(outputs,1)
# 转化为类别名称
print(label[predicted.item()])
predict('项目实战/猫狗识别/image/test/dog/dog.1010.jpg')
## 狗
predict('项目实战/猫狗识别/image/test/cat/cat.1010.jpg')
## 猫


Comments NOTHING