飞道的博客

神经网络听上去高大上?带你从零开始训练一个网络(基于MNIST)

409人阅读  评论(0)

1 什么是神经网络?

我们知道人工智能发展过程中有三个主要流派:

  • 符号主义学派
  • 连接主义学派
  • 行为主义学派

其中连接主义学派认为:人脑中万亿个神经元细胞间错综复杂的互相连接,是智能产生的来源。连接主义的核心是仿生学和神经科学,关注的是神经网络间的连接机制和学习算法,致力于通过计算机表示大量神经元,以模拟大脑的智力。这就是神经网络的起源,其正式定义如下:

人工神经网络(Artificial Neural Networks,ANNs)也简称为神经网络或称作连接模型,它是一种模仿动物神经网络行为特征,进行分布式并行信息处理的算法数学模型。这种网络依靠系统的复杂程度,通过调整内部大量节点之间相互连接的关系,从而达到处理信息的目的。

如下图所示是一个用于图形图像处理的卷积神经网络框架,今天就从零开始训练处下面这种网络。

2 卷积神经网络

卷积神经网络(Convolutional Neural Network, CNN)在全连接神经网络的基础上增加了特征提取层,主要用于计算机视觉领域,处理模式识别、图像分类、目标检测等问题。

CNN相比FCNN更适于处理视觉任务的原因在于,其实现了高维信息的聚合与压缩

举例而言,以一张二维图片的像素为输入,则FCNN的输入层神经元个数将非常庞大,再考虑每个神经元都相邻层所有神经元相连,因此作为优化目标的连接权矩阵指数增长,带来无法接受的学习时间复杂度。而CNN通过对图片信息的特征筛选,滤除图片掺杂的大量冗余信息,再通过简单的全连接网络映射到输出空间去,将大幅降低复杂度。

3 实验流程

本问针对经典的MNIST手写数字分类实验,基于Pytorch框架自主设计神经网络,测试网络性能,并进行一定的可视化分析。实验的流程如下:

  1. 搭建卷积神经网络;
  2. 加载数据集。下载MNIST手写数字数据集,划分训练集、验证集和测试集,并封装为可迭代的数据加载器对象;
  3. 训练模型。定义损失函数和优化方法,通过前向传播计算损失,再基于反向传播优化模型参数,迭代至训练误差收敛后保存模型到本地;

3.1 搭建神经网络

如下所示,搭建卷积神经网络。其中

  • Conv2d:卷积层
  • MaxPool2d:池化层
  • ReLu:激活函数

这些神经网络组件的具体原理与作用另开文章分析,本章关注于应用实践。

class CNN(nn.Module):
    '''
    * @breif: 卷积神经网络
    '''    
    def __init__(self):
        super().__init__()
        self.convPoolLayer_1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU()
        )
        self.convPoolLayer_2 = nn.Sequential(
            nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU()
        )
        self.fcLayer = nn.Linear(320, 10)

    def forward(self, x):
        batchSize = x.size(0)
        x = self.convPoolLayer_1(x)
        x = self.convPoolLayer_2(x)
        x = x.reshape(batchSize, -1)
        x = self.fcLayer(x)
        return x

 

3.2 加载数据集

使用pytorch提供的Dataset类进行数据集加载和预览

from abc import abstractmethod
import numpy as np
from torchvision.datasets import mnist
from torch.utils.data import Dataset
from PIL import Image

class mnistData(Dataset):
    '''
    * @breif: MNIST数据集抽象接口
    * @param[in]: dataPath -> 数据集存放路径
    * @param[in]: transforms -> 数据集变换
    '''    
    def __init__(self, dataPath: str, transforms=None) -> None:
        super().__init__()
        self.dataPath = dataPath
        self.transforms = transforms
        self.data, self.label = [], []

    def __len__(self) -> int:
        return len(self.label)

    def __getitem__(self, idx: int):
        img = self.data[idx]
        if self.transforms:
            img = self.transforms(img)
        return img, self.label[idx]

    def loadData(self, train: bool) -> list:
        '''
        * @breif: 下载与加载数据集
        * @param[in]: train -> 是否为训练集
        * @retval: 数据与标签列表
        '''    
        # 如果指定目录下不存在数据集则下载
        dataSet   = mnist.MNIST(self.dataPath, train=train, download=True)
        # 初始化数据与标签
        data  = [ i[0] for i in dataSet ]
        label = [ i[1] for i in dataSet ]
        return data, label

 

3.3 训练模型

考虑到该实践是多分类问题,因此最终网络的输出是十维向量并经过softmax转化为概率分布,损失函数设计为交叉熵,优化方法选择随机梯度下降算法。

for images, labels in trainBar:
    images, labels = images.to(config.device), labels.to(config.device)
    # 梯度清零
    opt.zero_grad()
    # 正向传播
    outputs = model(images)
    # 计算损失
    loss = F.cross_entropy(outputs, labels)
    # 反向传播
    loss.backward()
    # 模型更新
    opt.step()

训练过程如下:

4 算法分析

经过测试,同样的学习率下,CNN在20代左右就已收敛,但FCNN在20代左右才开始收敛。测试集泛化误差表明CNN的预测准确率达到95%,但FCNN只有70%,因此在图像分类问题上,CNN的效率和准确率都远远高于FCNN。因此CNN的学习和泛化能力较强,只要用已知模式对CNN加以训练,网络就具有对输入输出之间的映射和表达能力。

本期图书推荐

《Python深度学习:模型、方法与实现》

【书籍简介】

  • 本书集合了基于应用领域的高级深度学习的模型、方法和实现。本书分为四部分。第1部分介绍了深度学习的构建和神经网络背后的数学知识。第二部分讨论深度学习在计算机视觉领域的应用。第三部分阐述了自然语言和序列处理。讲解了使用神经网络提取复杂的单词向量表示。讨论了各种类型的循环网络,如长短期记忆网络和门控循环单元网络。第四部分介绍一些虽然还没有被广泛采用但有前途的深度学习技术,包括如何在自动驾驶上应用深度学习。
  • 学完本书,读者将掌握与深度学习相关的关键概念,学会如何使用TensorFlow和PyTorch实现相应的AI解决方案。

【抽奖方式】

  1. 关注博主,点赞收藏文章,并做出有效评论
  2. 根据评论记录随机抽取2位用户赠送实体图书
  3. 截止日期:7.17日晚8点,届时通过blink公布获奖信息,请中奖用户及时私信

🔥 更多精彩专栏


👇源码获取 · 技术交流 · 抱团学习 · 咨询分享 请联系👇

转载:https://blog.csdn.net/FRIGIDWINTER/article/details/125707637
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场