Blog of 小言_互联网

Building the ResNet18 Network Architecture with PyTorch


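It is well known that as a neural network keeps getting deeper, two problems appear:

1. Vanishing and exploding gradients.
2. Degradation: accuracy saturates and then drops, even on the training set.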
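To make problem 1 concrete, here is a toy sketch of my own (a chain of plain sigmoid functions with no weights, not a real network): by the chain rule, the gradient that reaches the input is the product of the per-layer derivatives. Since sigmoid'(x) = s(1 - s) is at most 0.25, that product shrinks at least geometrically with depth. The same mechanism explodes the gradient when the per-layer factors are larger than 1 (for example, with large weights).

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def plain_chain_grad(x, depth):
    """d(output)/d(input) through `depth` stacked sigmoids (toy model, no weights)."""
    grad = 1.0
    for _ in range(depth):
        s = sigmoid(x)
        grad *= s * (1.0 - s)  # chain rule: multiply by sigmoid'(x) = s(1 - s) <= 0.25
        x = s                  # output of this layer is the input of the next
    return grad


for depth in (1, 5, 10, 20):
    print(depth, plain_chain_grad(0.0, depth))
```

At depth 1 the gradient is exactly 0.25; by depth 20 it is below 0.25 ** 20, so the first layers receive almost no learning signal.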

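To address these problems, Kaiming He et al. proposed the residual neural network (ResNet) in 2015. Thanks to its shortcut connections, even a very deep ResNet does not show this degradation: adding more layers does not make its training accuracy noticeably worse, as happens with a plain network of the same depth. So how exactly is it built?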
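The key formula, in the notation of the ResNet paper: each block learns a residual F(x) and adds its input back through the shortcut. If the extra layers ideally should do nothing, the block only has to push F toward zero, which is easier than learning an identity mapping from scratch. The gradient line below is my own chain-rule step, ignoring the ReLU applied after the addition:

```latex
\mathbf{y} = \mathcal{F}(\mathbf{x}, \{W_i\}) + \mathbf{x}
\quad\Longrightarrow\quad
\frac{\partial \mathcal{L}}{\partial \mathbf{x}}
  = \frac{\partial \mathcal{L}}{\partial \mathbf{y}}
    \left( I + \frac{\partial \mathcal{F}}{\partial \mathbf{x}} \right)
```

The identity term I passes the gradient back unchanged even when the derivative of F is small, which helps keep it from vanishing across many blocks. When x and F(x) have different shapes, the shortcut becomes a projection W_s x; in the code below this is the 1×1 convolution in `self.extra`.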

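Take ResNet18 as an example. It is built by stacking residual blocks:
1 convolutional layer + 8 residual blocks (2 convolutional layers each) + 1 fully connected layer, i.e. 1 + 8 × 2 + 1 = 18 layers with weights (pooling layers and the 1×1 shortcut convolutions are not counted).
The structure is shown in the figure below:
[Figure: ResNet18 architecture diagram (source: the first reference link)]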
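The layer count and the feature-map sizes in the shape comments of the code can both be checked with a little arithmetic. This is a sketch; the helper names `conv_out` and `resnet_summary` are hypothetical. It uses the standard output-size formula floor((n + 2p - k) / s) + 1 for convolution and pooling layers. The block counts (2, 2, 2, 2) for ResNet18 and (3, 4, 6, 3) for ResNet34 are from the ResNet paper.

```python
def conv_out(size, kernel, stride, padding):
    """Output height/width of a conv or pooling layer: floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def resnet_summary(blocks_per_stage=(2, 2, 2, 2), input_size=224):
    """Depth and per-stage feature-map size of a basic-block ResNet (toy helper)."""
    size = conv_out(input_size, kernel=7, stride=2, padding=3)  # 7x7 stem conv
    size = conv_out(size, kernel=3, stride=2, padding=1)        # 3x3 max pooling
    stage_sizes = []
    for stage in range(len(blocks_per_stage)):
        stride = 1 if stage == 0 else 2  # stages 2-4 halve h and w in their first block
        # later blocks use stride 1 and padding 1, so they keep the size unchanged
        size = conv_out(size, kernel=3, stride=stride, padding=1)
        stage_sizes.append(size)
    # 1 stem conv + 2 convs per basic block + 1 fully connected layer
    depth = 1 + 2 * sum(blocks_per_stage) + 1
    return depth, stage_sizes


print(resnet_summary())              # (18, [56, 28, 14, 7])
print(resnet_summary((3, 4, 6, 3)))  # (34, [56, 28, 14, 7])
```

For a 224×224 input this reproduces the sizes 56 → 28 → 14 → 7 used in the comments of the code, and the same formula gives the 34 layers of ResNet34.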

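Personally, I think ResNet is one of the best networks for deep learning beginners to study: once you understand the principle, you can build the network just by looking at the diagram, and the code is clean and easy to follow. One blogger described it as "simple and practical", and I agree. When I first learned ResNet, I built a 10-layer ResNet and trained it on my garbage-classification dataset; tonight I extended it to 18 layers.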

The code is as follows:

import torch
from torch import nn
from torch.nn import functional as F


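class ResBlk(nn.Module):
    """
    Basic ResNet block: two 3x3 convolutions plus a shortcut connection
    """

    def __init__(self, ch_in, ch_out, stride=1):
        """
        :param ch_in: number of input channels
        :param ch_out: number of output channels
        :param stride: stride of the first 3x3 conv (2 halves h and w)
        """
        super(ResBlk, self).__init__()

        # bias=False because every conv is immediately followed by BatchNorm
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # identity shortcut by default (an empty Sequential returns its input)
        self.extra = nn.Sequential()
        if stride != 1 or ch_out != ch_in:
            # projection shortcut: a 1x1 conv matches both channels and spatial size
            # [b, ch_in, h, w] => [b, ch_out, h/stride, w/stride]
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(ch_out)
            )

    def forward(self, x):
        """
        :param x: [b, ch_in, h, w]
        :return: [b, ch_out, h/stride, w/stride]
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # residual connection: add the (possibly projected) input to the block output
        out = self.extra(x) + out
        out = F.relu(out)

        return out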


class ResNet18(nn.Module):

    def __init__(self, num_class):
        super(ResNet18, self).__init__()

        # stem: 7x7 conv -> BN -> ReLU -> 3x3 max pooling
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        # conv1: [b, 3, 224, 224] => [b, 64, 56, 56]

        # stage 1: two separate 64-channel blocks, each with its own weights
        self.blk1 = ResBlk(64, 64)
        # [b, 64, 56, 56] => [b, 64, 56, 56]

        self.blk1_1 = ResBlk(64, 64)
        # [b, 64, 56, 56] => [b, 64, 56, 56]

        # stage 2: the first block halves h and w and doubles the channels
        self.blk2 = ResBlk(64, 128, stride=2)
        # [b, 64, 56, 56] => [b, 128, 28, 28]

        self.blk2_1 = ResBlk(128, 128)
        # [b, 128, 28, 28] => [b, 128, 28, 28]

        # stage 3
        self.blk3 = ResBlk(128, 256, stride=2)
        # [b, 128, 28, 28] => [b, 256, 14, 14]

        self.blk3_1 = ResBlk(256, 256)
        # [b, 256, 14, 14] => [b, 256, 14, 14]

        # stage 4
        self.blk4 = ResBlk(256, 512, stride=2)
        # [b, 256, 14, 14] => [b, 512, 7, 7]

        self.blk4_1 = ResBlk(512, 512)
        # [b, 512, 7, 7] => [b, 512, 7, 7]

        # global average pooling; unlike AvgPool2d(kernel_size=7) it also works for inputs other than 224x224
        self.pool2 = nn.AdaptiveAvgPool2d((1, 1))

        self.outlayer = nn.Linear(512, num_class)

    def forward(self, x):
        """
        :param x: [b, 3, 224, 224]
        :return: [b, num_class] unnormalized class scores (logits)
        """
        x = self.conv1(x)  # [b, 3, 224, 224] => [b, 64, 56, 56]

        x = self.blk1(x)  # [b, 64, 56, 56] => [b, 64, 56, 56]

        x = self.blk1_1(x)  # [b, 64, 56, 56] => [b, 64, 56, 56]

        x = self.blk2(x)  # [b, 64, 56, 56] => [b, 128, 28, 28]

        x = self.blk2_1(x)  # [b, 128, 28, 28] => [b, 128, 28, 28]

        x = self.blk3(x)  # [b, 128, 28, 28] => [b, 256, 14, 14]

        x = self.blk3_1(x)  # [b, 256, 14, 14] => [b, 256, 14, 14]

        x = self.blk4(x)  # [b, 256, 14, 14] => [b, 512, 7, 7]

        x = self.blk4_1(x)  # [b, 512, 7, 7] => [b, 512, 7, 7]

        x = self.pool2(x)  # [b, 512, 7, 7] => [b, 512, 1, 1]

        x = x.view(x.size(0), -1)  # flatten: [b, 512, 1, 1] => [b, 512]

        x = self.outlayer(x)  # [b, 512] => [b, num_class]

        return x


def main():
    model = ResNet18(7)  # number of output classes
    tmp = torch.randn(1, 3, 224, 224)  # one random 3-channel 224x224 image
    out = model(tmp)
    print('resnet:', out.shape)  # torch.Size([1, 7])


if __name__ == '__main__':
    main()
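Since the eight blocks follow a fixed pattern, they can also be generated by a loop over a per-stage block count. The sketch below is my own illustration, not the code above: the names `BasicBlock`, `ResNet` and `blocks` are hypothetical, and it assumes the same block design (bias-free convs, 1×1 projection shortcut) plus adaptive average pooling. The counts (2, 2, 2, 2) and (3, 4, 6, 3) for ResNet18 and ResNet34 are from the ResNet paper.

```python
import torch
from torch import nn
from torch.nn import functional as F


class BasicBlock(nn.Module):
    """Two 3x3 convs plus a shortcut, the same idea as ResBlk."""

    def __init__(self, ch_in, ch_out, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(ch_in, ch_out, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(ch_out)
        self.shortcut = nn.Sequential()  # identity when shapes already match
        if stride != 1 or ch_in != ch_out:
            self.shortcut = nn.Sequential(  # 1x1 projection shortcut
                nn.Conv2d(ch_in, ch_out, 1, stride=stride, bias=False),
                nn.BatchNorm2d(ch_out),
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + self.shortcut(x))


class ResNet(nn.Module):
    """Basic-block ResNet: blocks=(2, 2, 2, 2) is ResNet18, (3, 4, 6, 3) is ResNet34."""

    def __init__(self, blocks=(2, 2, 2, 2), num_class=7):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(3, stride=2, padding=1),
        )
        layers, ch_in = [], 64
        for i, (n_blocks, ch_out) in enumerate(zip(blocks, (64, 128, 256, 512))):
            stride = 1 if i == 0 else 2  # only the first block of stages 2-4 downsamples
            for j in range(n_blocks):
                layers.append(BasicBlock(ch_in, ch_out, stride if j == 0 else 1))
                ch_in = ch_out
        self.stages = nn.Sequential(*layers)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_class)

    def forward(self, x):
        x = self.stages(self.stem(x))
        x = torch.flatten(self.pool(x), 1)  # [b, 512, 1, 1] => [b, 512]
        return self.fc(x)


model = ResNet(num_class=7)
out = model(torch.randn(2, 3, 224, 224))
print(out.shape)  # torch.Size([2, 7])
```

With `num_class=1000` this sketch has 11,689,512 trainable parameters, the same count as torchvision's `resnet18`, which is a quick way to check that the layer layout matches the reference implementation.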

References:

1. https://blog.csdn.net/weixin_44331304/article/details/106127552 (source of the figure above; recommended for readers new to ResNet)
2. https://www.bilibili.com/video/BV1j64y1D748?p=102


Reposted from: https://blog.csdn.net/weixin_48638088/article/details/115495988