3.1 mnist简单数据识别
mnist数据集是手写数字识别所用的数据集,接下来,将用它实战各种深度学习算法。
这一节先进行一个最简单的,学习一下深度学习流程,所以直接全连接输出。
# 手写数字
import numpy as np
from torch import nn,optim
from torch.autograd import Variable
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
import torch
# 训练集
train_dataset=datasets.MNIST(root='./',train=True,transform=transforms.ToTensor(),download=True)
# 测试集
test_dataset=datasets.MNIST(root='./',train=False,transform=transforms.ToTensor(),download=True)
# 批次大小
batch_size =64
# 装载训练集
train_loader= DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)
# 装载测试集
test_loader= DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=True)
for i,data in enumerate(train_loader):
inputs,labels=data
print(inputs.shape)
print(labels.shape)
break
## torch.Size([64, 1, 28, 28])
## torch.Size([64])
# 定义网络结构
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
# 先定义一个简单的784个输入,10个输出的神经网络
self.fc1=nn.Linear(784,10)
# 输出是(64,10)这样的,前面的是序号,后面的是识别结果,对后面这个维度进行softmax让它转化为概率,因为从0开始,所以dim=1
self.softmax=nn.Softmax(dim=1)
def forward(self,x):
#([64, 1, 28, 28])->(64,784) -1表示自动匹配
x=x.view(x.size()[0],-1)
x=self.fc1(x)
x=self.softmax(x)
return x
LR=0.5
# 定义模型
model=Net()
# 定义代价函数
mse_loss=nn.MSELoss()
# 定义优化器
optimizer=optim.SGD(model.parameters(),LR)
# 模型训练
def train():
for i,data in enumerate(train_loader):
# 获得一个批次的数据和标签
inputs,labels=data
# 获得模型预测值
out=model(inputs)
# to onehot,把数据标签变成独热编码
# (64)->(64,1)
#tensor.scatter(dim,index,src) dim:对哪个维度进行独热编码 index:要将src中对应的值放入tensor的哪个位置 src:插入index的数值
labels=labels.reshape(-1,1)
one_hot=torch.zeros(inputs.shape[0],10).scatter(1,labels,1)
# 计算loss,mes_loss的两个数据的shape要一致
loss=mse_loss(out,one_hot)
# 梯度清0
optimizer.zero_grad()
# 梯度计算
loss.backward()
# 修改权值
optimizer.step()
# 模型测试
def test():
correct=0
for i,data in enumerate(test_loader):
# 获得一个批次的数据和标签
inputs,labels=data
# 获得模型预测值
out=model(inputs)
# 获得最大值(忽略),以及最大值所在的位置
_,predicted=torch.max(out,1)
# 计算正确的个数
correct+=(predicted==labels).sum()
# 输出正确率
print("Test acc:{0}".format(correct.item()/len(test_dataset)))
for epoch in range(10):
print('epoch:',epoch)
train()
test()
## epoch: 0
## Test acc:0.8876
## epoch: 1
## Test acc:0.9022
## epoch: 2
## Test acc:0.907
## epoch: 3
## Test acc:0.9107
## epoch: 4
## Test acc:0.9137
## epoch: 5
## Test acc:0.916
## epoch: 6
## Test acc:0.9177
## epoch: 7
## Test acc:0.9179
## epoch: 8
## Test acc:0.9182
## epoch: 9
## Test acc:0.9189
Comments NOTHING