3.2 mnist简单数据识别-使用交叉熵代价函数进行优化

# 分类应用中一般使用用交叉熵代价函数,回归中使用二次代价函数
# 手写数字
import numpy as np
from torch import nn,optim
from torch.autograd import Variable
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
import torch

# 训练集
train_dataset=datasets.MNIST(root='./',train=True,transform=transforms.ToTensor(),download=True)
# 测试集
test_dataset=datasets.MNIST(root='./',train=False,transform=transforms.ToTensor(),download=True)

# 批次大小
batch_size =64
# 装载训练集
train_loader= DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)
# 装载测试集
test_loader= DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=True)

for i,data in enumerate(train_loader):
    inputs,labels=data
    print(inputs.shape)
    print(labels.shape)
    break
    ## torch.Size([64, 1, 28, 28])
    ## torch.Size([64])

# 定义网络结构
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        # 先定义一个简单的784个输入,10个输出的神经网络
        self.fc1=nn.Linear(784,10)
        # 输出是(64,10)这样的,前面的是序号,后面的是识别结果,对后面这个维度进行softmax让它转化为概率,因为从0开始,所以dim=1
        self.softmax=nn.Softmax(dim=1)
    def forward(self,x):
        #([64, 1, 28, 28])->(64,784)   -1表示自动匹配
        x=x.view(x.size()[0],-1)
        x=self.fc1(x)
        x=self.softmax(x)
        return x

LR=0.5
# 定义模型
model=Net()
# 定义代价函数
mse_loss=nn.CrossEntropyLoss()
# 定义优化器
optimizer=optim.SGD(model.parameters(),LR)

# 模型训练
def train():
    for i,data in enumerate(train_loader):
        # 获得一个批次的数据和标签
        inputs,labels=data
        # 获得模型预测值
        out=model(inputs)
        # 交叉熵代价函数,不需要shape一致,它会自动独热编码
        loss=mse_loss(out,labels)
        # 梯度清0
        optimizer.zero_grad()
        # 梯度计算
        loss.backward()
        # 修改权值
        optimizer.step()
# 模型测试
def test():
    correct=0
    for i,data in enumerate(test_loader):
        # 获得一个批次的数据和标签
        inputs,labels=data
        # 获得模型预测值
        out=model(inputs)
        # 获得最大值(忽略),以及最大值所在的位置
        _,predicted=torch.max(out,1)
        # 计算正确的个数
        correct+=(predicted==labels).sum()
    # 输出正确率
    print("Test acc:{0}".format(correct.item()/len(test_dataset)))

# 可以发现使用交叉熵效果更好,参数收敛速度更快
for epoch in range(10):
    print('epoch:',epoch)
    train()
    test()
    ## epoch: 0
    ## Test acc:0.9067
    ## epoch: 1
    ## Test acc:0.9141
    ## epoch: 2
    ## Test acc:0.9171
    ## epoch: 3
    ## Test acc:0.9203
    ## epoch: 4
    ## Test acc:0.9216
    ## epoch: 5
    ## Test acc:0.9229
    ## epoch: 6
    ## Test acc:0.9231
    ## epoch: 7
    ## Test acc:0.9244
    ## epoch: 8
    ## Test acc:0.9247
    ## epoch: 9
    ## Test acc:0.9242