脚本专栏 
首页 > 脚本专栏 > 浏览文章

pytorch 利用lstm做mnist手写数字识别分类的实例

(编辑:jimmy 日期: 2025/1/16 浏览:3 次 )

代码如下,U我认为对于新手来说最重要的是学会rnn读取数据的格式。

# -*- coding: utf-8 -*-
"""
Created on Tue Oct 9 08:53:25 2018
@author: www
"""
 
import sys
sys.path.append('..')
 
import torch
import datetime
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision import transforms as tfs
from torchvision.datasets import MNIST
 
#定义数据
data_tf = tfs.Compose([
   tfs.ToTensor(),
   tfs.Normalize([0.5], [0.5])
])
train_set = MNIST('E:/data', train=True, transform=data_tf, download=True)
test_set = MNIST('E:/data', train=False, transform=data_tf, download=True)
 
train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)
 
#定义模型
class rnn_classify(nn.Module):
   def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
     super(rnn_classify, self).__init__()
     self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)#使用两层lstm
     self.classifier = nn.Linear(hidden_feature, num_class)#将最后一个的rnn使用全连接的到最后的输出结果
     
   def forward(self, x):
     #x的大小为(batch,1,28,28),所以我们需要将其转化为rnn的输入格式(28,batch,28)
     x = x.squeeze() #去掉(batch,1,28,28)中的1,变成(batch, 28,28)
     x = x.permute(2, 0, 1)#将最后一维放到第一维,变成(batch,28,28)
     out, _ = self.rnn(x) #使用默认的隐藏状态,得到的out是(28, batch, hidden_feature)
     out = out[-1,:,:]#取序列中的最后一个,大小是(batch, hidden_feature)
     out = self.classifier(out) #得到分类结果
     return out
     
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 1e-1)
 
#定义训练过程
def get_acc(output, label):
  total = output.shape[0]
  _, pred_label = output.max(1)
  num_correct = (pred_label == label).sum().item()
  return num_correct / total
  
  
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
  if torch.cuda.is_available():
    net = net.cuda()
  prev_time = datetime.datetime.now()
  for epoch in range(num_epochs):
    train_loss = 0
    train_acc = 0
    net = net.train()
    for im, label in train_data:
      if torch.cuda.is_available():
        im = Variable(im.cuda()) # (bs, 3, h, w)
        label = Variable(label.cuda()) # (bs, h, w)
      else:
        im = Variable(im)
        label = Variable(label)
      # forward
      output = net(im)
      loss = criterion(output, label)
      # backward
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
 
      train_loss += loss.item()
      train_acc += get_acc(output, label)
 
    cur_time = datetime.datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)
    if valid_data is not None:
      valid_loss = 0
      valid_acc = 0
      net = net.eval()
      for im, label in valid_data:
        if torch.cuda.is_available():
          im = Variable(im.cuda())
          label = Variable(label.cuda())
        else:
          im = Variable(im)
          label = Variable(label)
        output = net(im)
        loss = criterion(output, label)
        valid_loss += loss.item()
        valid_acc += get_acc(output, label)
      epoch_str = (
        "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
        % (epoch, train_loss / len(train_data),
          train_acc / len(train_data), valid_loss / len(valid_data),
          valid_acc / len(valid_data)))
    else:
      epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
             (epoch, train_loss / len(train_data),
             train_acc / len(train_data)))
    prev_time = cur_time
    print(epoch_str + time_str)
    
train(net, train_data, test_data, 10, optimizer, criterion)    

以上这篇pytorch 利用lstm做mnist手写数字识别分类的实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。

上一篇:pytorch1.0中torch.nn.Conv2d用法详解
下一篇:pytorch实现mnist分类的示例讲解
一句话新闻
高通与谷歌联手!首款骁龙PC优化Chrome浏览器发布
高通和谷歌日前宣布,推出首次面向搭载骁龙的Windows PC的优化版Chrome浏览器。
在对骁龙X Elite参考设计的初步测试中,全新的Chrome浏览器在Speedometer 2.1基准测试中实现了显著的性能提升。
预计在2024年年中之前,搭载骁龙X Elite计算平台的PC将面世。该浏览器的提前问世,有助于骁龙PC问世就获得满血表现。
谷歌高级副总裁Hiroshi Lockheimer表示,此次与高通的合作将有助于确保Chrome用户在当前ARM兼容的PC上获得最佳的浏览体验。
友情链接:杰晶网络 DDR爱好者之家 南强小屋 黑松山资源网 白云城资源网 站点导航 SiteMap