5.3 项目实战-猫狗分类预测

import torch
import numpy as np
from PIL import Image
from torchvision import transforms,models

model = models.vgg16(pretrained = True)
# 构建新的全连接层
model.classifier = torch.nn.Sequential(torch.nn.Linear(25088, 100),
                                       torch.nn.ReLU(),
                                       torch.nn.Dropout(p=0.5),
                                       torch.nn.Linear(100, 2))
# 这里我们载入全训练的参数,效果要好一些
model.load_state_dict(torch.load('model/猫狗分类全训练'))

label = np.array(['猫','狗'])

# 数据预处理
# 变成统一大小
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor() 
])

def predict(image_path):
    # 打开图片
    img = Image.open(image_path)
    # 数据处理,再增加一个维度
    img = transform(img).unsqueeze(0)
    # 预测得到结果
    outputs = model(img)
    # 获得最大值所在位置
    _, predicted = torch.max(outputs,1)
    # 转化为类别名称
    print(label[predicted.item()])

predict('项目实战/猫狗识别/image/test/dog/dog.1010.jpg')
## 狗

predict('项目实战/猫狗识别/image/test/cat/cat.1010.jpg')
## 猫

file

file