飞道的博客

从零开始编写一个宠物识别系统(爬虫、模型训练和调优、模型部署、Web服务)

876人阅读  评论(0)

心血来潮,想从零开始编写一个相对完整的深度学习小项目。想到就做,那么首先要考虑的问题是,写什么?

思量再三,我决定写一个宠物识别系统,即给定一张图片,判断图片上的宠物是什么。宠物种类暂定为四类——猫、狗、鼠、兔。之所以想到做这个,是因为在不使用公开数据集的情况下,宠物图片数据集获取的难度相对低一些。

小项目分为如下几个部分:

  • 爬虫。从网络上下载宠物图片,构建训练用的数据集。
  • 模型构建、训练和调优。鉴于我们的数据比较少,这部分需要做迁移学习。
  • 模型部署和Web服务。将训练好的模型部署成web接口,并使用Vue.js + Element UI编写测试页面。

好嘞,开搞吧!

本文涉及到的所有代码,均已上传到GitHub:

pets_classifer (https://github.com/AaronJny/pets_classifer)

转载请注明来源:https://blog.csdn.net/aaronjny/article/details/103605988

一、爬虫

训练模型肯定是需要数据集的,那么数据集从哪来?因为是从零开始嘛,假设我们做的这个问题,业内没有公开的数据集,我们需要自己制作数据集。

一个很简单的想法是,利用搜索引擎搜索相关图片,使用爬虫批量下载,然后人工去除不正确的图片。举个例子,我们先处理猫的图片,步骤如下:

  • 1.使用搜索引擎搜索猫的图片。
  • 2.使用爬虫将搜索出的猫的图片批量下载到本地,放到一个名为cats的文件夹里面。
  • 3.人工浏览一遍图片,将“不包含猫”的图片和“除猫外还包含其他宠物(狗、鼠、兔)”的图片从文件夹中删除。

这样,猫的图片我们就搜集完成了,其他几个类别的图片也是类似的操作。不用担心人工过滤图片花费的时间较长,全部过一遍也就二十多分钟吧。

然后是搜索引擎的选择。搜索引擎用的比较多的无非两种——Google和百度。我分别使用Google和百度进行了图片搜索,发现百度的搜索结果远不如Google准确,于是就选择了Google,所以我的爬虫代码是基于Google编写的,运行我的爬虫代码需要你的网络能够访问Google。

如果你的网络不能访问Google,可以考虑自行实现基于百度的爬虫程序,逻辑都是相通的。

因为想让项目轻量级一些,故没有使用scrapy框架。爬虫使用requests+beautifulsoup4实现,并发使用gevent实现。

# -*- coding: utf-8 -*-
# @File    : spider.py
# @Author  : AaronJny
# @Time    : 2019/12/16
# @Desc    : 从谷歌下载指定图片
from gevent import monkey

monkey.patch_all()
import functools
import logging
import os
from bs4 import BeautifulSoup
from gevent.pool import Pool
import requests
import settings

# 设置日志输出格式
logging.basicConfig(format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s',
                    level=logging.INFO)

# 搜索关键词字典
keywords_map = settings.IMAGE_CLASS_KEYWORD_MAP

# 图片保存根目录
images_root = settings.IMAGES_ROOT
# 每个类别下载多少页图片
download_pages = settings.SPIDER_DOWNLOAD_PAGES
# 图片编号字典,每种图片都从0开始编号,然后递增
images_index_map = dict(zip(keywords_map.keys(), [0 for _ in keywords_map]))
# 图片去重器
duplication_filter = set()

# 请求头
headers = {
    'accept-encoding': 'gzip, deflate, br',
    'accept-language': 'zh-CN,zh;q=0.9',
    'user-agent': 'Mozilla/5.0 (Linux; Android 4.0.4; Galaxy Nexus Build/IMM76B) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/46.0.2490.76 Mobile Safari/537.36',
    'accept': '*/*',
    'referer': 'https://www.google.com/',
    'authority': 'www.google.com',
}


# 重试装饰器
def try_again_while_except(max_times=3):
    """
    当出现异常时,自动重试。
    连续失败max_times次后放弃。
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            error_cnt = 0
            error_msg = ''
            while error_cnt < max_times:
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    error_msg = str(e)
                    error_cnt += 1
            if error_msg:
                logging.error(error_msg)

        return wrapper

    return decorator


@try_again_while_except()
def download_image(session, image_url, image_class):
    """
    从给定的url中下载图片,并保存到指定路径
    """
    # 下载图片
    resp = session.get(image_url, timeout=20)
    # 检查图片是否下载成功
    if resp.status_code != 200:
        raise Exception('Response Status Code {}!'.format(resp.status_code))
    # 分配一个图片编号
    image_index = images_index_map.get(image_class, 0)
    # 更新待分配编号
    images_index_map[image_class] = image_index + 1
    # 拼接图片路径
    image_path = os.path.join(images_root, image_class, '{}.jpg'.format(image_index))
    # 保存图片
    with open(image_path, 'wb') as f:
        f.write(resp.content)
    # 成功写入了一张图片
    return True


@try_again_while_except()
def get_and_analysis_google_search_page(session, page, image_class, keyword):
    """
    使用google进行搜索,下载搜索结果页面,解析其中的图片地址,并对有效图片进一步发起请求
    """
    logging.info('Class:{} Page:{} Processing...'.format(image_class, page + 1))
    # 记录从本页成功下载的图片数量
    downloaded_cnt = 0
    # 构建请求参数
    params = (
        ('q', keyword),
        ('tbm', 'isch'),
        ('async', '_id:islrg_c,_fmt:html'),
        ('asearch', 'ichunklite'),
        ('start', str(page * 100)),
        ('ijn', str(page)),
    )
    # 进行搜索
    resp = requests.get('https://www.google.com/search', params=params, timeout=20)
    # 解析搜索结果
    bsobj = BeautifulSoup(resp.content, 'lxml')
    divs = bsobj.find_all('div', {'class': 'islrtb isv-r'})
    for div in divs:
        image_url = div.get('data-ou')
        # 只有当图片以'.jpg','.jpeg','.png'结尾时才下载图片
        if image_url.endswith('.jpg') or image_url.endswith('.jpeg') or image_url.endswith('.png'):
            # 过滤掉相同图片
            if image_url not in duplication_filter:
                # 使用去重器记录
                duplication_filter.add(image_url)
                # 下载图片
                flag = download_image(session, image_url, image_class)
                if flag:
                    downloaded_cnt += 1
    logging.info('Class:{} Page:{} Done. {} images downloaded.'.format(image_class, page + 1, downloaded_cnt))


def search_with_google(image_class, keyword):
    """
    通过google下载数据集
    """
    # 创建session对象
    session = requests.session()
    session.headers.update(headers)
    # 每个类别下载10页数据
    for page in range(download_pages):
        get_and_analysis_google_search_page(session, page, image_class, keyword)


def run():
    # 首先,创建数据文件夹
    if not os.path.exists(images_root):
        os.mkdir(images_root)
    for sub_images_dir in keywords_map.keys():
        # 对于每个图片类别都创建一个单独的文件夹保存
        sub_path = os.path.join(images_root, sub_images_dir)
        if not os.path.exists(sub_path):
            os.mkdir(sub_path)
    # 开始下载,这里使用gevent的协程池进行并发
    pool = Pool(len(keywords_map))
    for image_class, keyword in keywords_map.items():
        pool.spawn(search_with_google, image_class, keyword)
    pool.join()


if __name__ == '__main__':
    run()

项目中涉及到的所有配置参数,都提取到了settings.py中,内容如下,以供查阅:

# -*- coding: utf-8 -*-
# @File    : settings.py
# @Author  : AaronJny
# @Time    : 2019/12/16
# @Desc    :


# ##########爬虫############

# 图片类别和搜索关键词的映射关系
IMAGE_CLASS_KEYWORD_MAP = {
    'cats': '宠物猫',
    'dogs': '宠物狗',
    'mouses': '宠物鼠',
    'rabbits': '宠物兔'
}
# 图片保存根目录
IMAGES_ROOT = './images'
# 爬虫每个类别下载多少页图片
SPIDER_DOWNLOAD_PAGES = 20

# #########数据###########

# 每个类别选取的图片数量
SAMPLES_PER_CLASS = 345
# 参与训练的类别
CLASSES = ['cats', 'dogs', 'mouses', 'rabbits']
# 参与训练的类别数量
CLASS_NUM = len(CLASSES)
# 类别->编号的映射
CLASS_CODE_MAP = {
    'cats': 0,
    'dogs': 1,
    'mouses': 2,
    'rabbits': 3
}
# 编号->类别的映射
CODE_CLASS_MAP = {
    0: '猫',
    1: '狗',
    2: '鼠',
    3: '兔'
}
# 随机数种子
RANDOM_SEED = 13  # 四个类别时样本较为均衡的随机数种子
# RANDOM_SEED = 19  # 三个类别时样本较为均衡的随机数种子
# 训练集比例
TRAIN_DATASET = 0.6
# 开发集比例
DEV_DATASET = 0.2
# 测试集比例
TEST_DATASET = 0.2
# mini_batch大小
BATCH_SIZE = 16
# imagenet数据集均值
IMAGE_MEAN = [0.485, 0.456, 0.406]
# imagenet数据集标准差
IMAGE_STD = [0.299, 0.224, 0.225]

# #########训练#########

# 学习率
LEARNING_RATE = 0.001
# 训练epoch数
TRAIN_EPOCHS = 30
# 保存训练模型的路径
MODEL_PATH = './model.h5'

# ########Web#########

# Web服务端口
WEB_PORT = 5000

爬虫使用Google进行图片搜索,每个宠物搜索10页,下载其中的所有图片。当爬虫运行完成后,项目下会多出一个images文件夹,点进去有四个子文件夹,分别为catsdogsmousesrabbits。每一个子文件夹里面是对应类别的宠物图片。

其中猫图片600+张,狗图片600+张,鼠图片400+张,兔图片500+张。花二十多分钟时间,过一遍全部图片,剔除其中不符合要求的图片。注意,这一步是必做的,而且要认真对待,我吃了亏的= =

进行一轮筛选后,剩下图片张数:

宠物 图片数量
521
526
346
345

考虑各类别样本均衡的问题,无非是过采样和欠采样。因为是图片数据,也可以使用数据增强的手段,为图片数量较少的类别生成一些图片,使样本数量均衡。但出于如下原因考虑,我直接做了欠采样,即每个类别只选取了345张样本:

  • 使用数据增强的话,需要在原图片的基础上,重新生成一份数据集,嫌麻烦……
  • 使用数据增强后,样本数量比较多,无法同时读取到内存里面,只能写个生成器,处理哪一部分的时候,实时从硬盘读取。弊端有俩:①频繁读取硬盘,肯定比不上所有数据都放在内存里面,会拖慢训练速度;②还是嫌麻烦……

说到底就是自己太懒了……当然,可想而知,使用数据增强(在这里,数据增强可以作为一种过采样的方式)使数据样本都达到526,训练的效果肯定会更好,能好多少就不知道了,有兴趣的可以自行实现,没啥难点,就是麻烦点。

下面该对数据做预处理了。很多经典的模型接收的输入格式都为(None,224,224,3),由于我们的样本较少,不可避免地需要用到迁移学习,所以我们的数据格式与经典模型保持一致,也使用(None,224,224,3),下面是预处理过程:

# -*- coding: utf-8 -*-
# @File    : data.py
# @Author  : AaronJny
# @Time    : 2019/12/16
# @Desc    :
import os
import random
import tensorflow as tf
import settings

# 每个类别选取的图片数量
samples_per_class = settings.SAMPLES_PER_CLASS
# 图片根目录
images_root = settings.IMAGES_ROOT
# 类别->编码的映射
class_code_map = settings.CLASS_CODE_MAP

# 我们准备使用经典网络在imagenet数据集上的与训练权重,所以归一化时也要使用imagenet的平均值和标准差
image_mean = tf.constant(settings.IMAGE_MEAN)
image_std = tf.constant(settings.IMAGE_STD)


def normalization(x):
    """
    对输入图片x进行归一化,返回归一化的值
    """
    return (x - image_mean) / image_std


def train_preprocess(x, y):
    """
    对训练数据进行预处理。
    注意,这里的参数x是图片的路径,不是图片本身;y是图片的标签值
    """
    # 读取图片
    x = tf.io.read_file(x)
    # 解码成张量
    x = tf.image.decode_jpeg(x, channels=3)
    # 将图片缩放到[244,244],比输入[224,224]稍大一些,方便后面数据增强
    x = tf.image.resize(x, [244, 244])
    # 随机决定是否左右镜像
    if random.choice([0, 1]):
        x = tf.image.random_flip_left_right(x)
    # 随机从x中剪裁出(224,224,3)大小的图片
    x = tf.image.random_crop(x, [224, 224, 3])
    # 读完上面的代码可以发现,这里的数据增强并不增加图片数量,一张图片经过变换后,
    # 仍然只是一张图片,跟我们前面说的增加图片数量的逻辑不太一样。
    # 这么做主要是应对我们的数据集里可能会存在相同图片的情况。

    # 将图片的像素值缩放到[0,1]之间
    x = tf.cast(x, dtype=tf.float32) / 255.
    # 归一化
    x = normalization(x)

    # 将标签转成one-hot形式
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, settings.CLASS_NUM)

    return x, y


def dev_preprocess(x, y):
    """
    对验证集和测试集进行数据预处理的方法。
    和train_preprocess的主要区别在于,不进行数据增强,以保证验证结果的稳定性。
    """
    # 读取并缩放图片
    x = tf.io.read_file(x)
    x = tf.image.decode_jpeg(x, channels=3)
    x = tf.image.resize(x, [224, 224])
    # 归一化
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = normalization(x)
    # 将标签转成one-hot形式
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, settings.CLASS_NUM)

    return x, y


# (图片路径,标签)的列表
image_path_and_labels = []
# 排序,保证每次拿到的顺序都一样
sub_images_dir_list = sorted(list(os.listdir(images_root)))
# 遍历每一个子目录
for sub_images_dir in sub_images_dir_list:
    sub_path = os.path.join(images_root, sub_images_dir)
    # 如果给定路径是文件夹,并且这个类别参与训练
    if os.path.isdir(sub_path) and sub_images_dir in settings.CLASSES:
        # 获取当前类别的编码
        current_label = class_code_map.get(sub_images_dir)
        # 获取子目录下的全部图片名称
        images = sorted(list(os.listdir(sub_path)))
        # 随机打乱(排序和置随机数种子都是为了保证每次的结果都一样)
        random.seed(settings.RANDOM_SEED)
        random.shuffle(images)
        # 保留前settings.SAMPLES_PER_CLASS个
        images = images[:samples_per_class]
        # 构建(x,y)对
        for image_name in images:
            abs_image_path = os.path.join(sub_path, image_name)
            image_path_and_labels.append((abs_image_path, current_label))
# 计算各数据集样例数
total_samples = len(image_path_and_labels)  # 总样例数
train_samples = int(total_samples * settings.TRAIN_DATASET)  # 训练集样例数
dev_samples = int(total_samples * settings.DEV_DATASET)  # 开发集样例数
test_samples = total_samples - train_samples - dev_samples  # 测试集样例数
# 打乱数据集
random.seed(settings.RANDOM_SEED)
random.shuffle(image_path_and_labels)
# 将图片数据和标签数据分开,此时它们仍是一一对应的
x_data = tf.constant([img for img, label in image_path_and_labels])
y_data = tf.constant([label for img, label in image_path_and_labels])
# 开始划分数据集
# 训练集
train_db = tf.data.Dataset.from_tensor_slices((x_data[:train_samples], y_data[:train_samples]))
# 打乱顺序,数据预处理,设置批大小
train_db = train_db.shuffle(10000).map(train_preprocess).batch(settings.BATCH_SIZE)
# 开发集(验证集)
dev_db = tf.data.Dataset.from_tensor_slices(
    (x_data[train_samples:train_samples + dev_samples], y_data[train_samples:train_samples + dev_samples]))
# 数据预处理,设置批大小
dev_db = dev_db.map(dev_preprocess).batch(settings.BATCH_SIZE)
# 测试集
test_db = tf.data.Dataset.from_tensor_slices(
    (x_data[train_samples + dev_samples:], y_data[train_samples + dev_samples:]))
# 数据预处理,设置批大小
test_db = test_db.map(dev_preprocess).batch(settings.BATCH_SIZE)

二、模型构建、训练和调优

数据已经全部处理完毕,该考虑模型了。首先,我们数据集太小了,直接构建自己的网络并训练,并不是一个好方案。因为这几种宠物其实挺难区分的,所以模型需要有一定复杂度,才能很好拟合这些数据,但我们的数据又太少了,最后的结果一定是过拟合,而且还是救不回来的那种= =所以我们考虑从迁移学习入手。

什么是迁移学习?懒得重新组织语言的我,默默地从之前写的博文里面摘了一段:

一般认为,深度卷积神经网络的训练是对数据集特征的一步步抽取的过程,从简单的特征,到复杂的特征。
训练好的模型学习到的是对图像特征的抽取方法,所以在imagenet数据集上训练好的模型理论上来说,也可以直接用于抽取其他图像的特征,这也是迁移学习的基础。自然,这样的效果往往没有在新数据上重新训练的效果好,但能够节省大量的训练时间,在特定情况下非常有用。

上面说的特定情况也包括我们面临的这一种——用于实际问题的数据集过小。

说到迁移学习,我最先想到的是VGG16,就先用VGG16搞了一波。使用在imagenet数据集上预训练的VGG16网络,去除顶部的全连接层,冻结全部参数,使它们在接下来的训练中不会改变。然后加上自己的全连接层,最后的输出层节点为4,对应于我们的四分类问题。开始训练。

模型在训练集上的误差很快降到5%以下,但是在验证集上的准确率基本在70+%,很明显,过拟合了。好嘛,盘它!主要使用如下方法尝试解决过拟合问题:

  • 调节全连接层的层数和每层的节点数
  • 添加BN层(虽说不是为了解决过拟合问题诞生的,但一定程度上是有效果的)
  • 添加Dropout层
  • 调节Dropout Rate
  • 添加l2正则

一顿操作猛如虎,回头一看0-5。这些方法确实对过拟合有所缓解,验证集上的准确率也确实有所提升,但只能达到81%左右。

然后我尝试了Resnet50,当然也过拟合了,盘它!最后验证集accuracy能达到83%左右。

很明显了,在全连接层的调整意义不大,究其根本,在于VGG16和ResNet50去除了全连接层之后,参数的数量也达到了20M+。两千万的参数使得模型严重过拟合,所以我们需要换一个参数少一点的模型。

于是,我盯上了DenseNet121,它的参数数量只有7M。继续盘它!果然,在一段时间的调优后,模型的性能有了明显的提升,验证集上的accuracy达到了87%左右。虽然和ResNet相比,准确率只高了4%,但相比于ResNet50 96%的训练accuracy而言,DenseNet121的训练accuracy只有90%左右。也就是说,对于DenseNet121而言,这个问题已经不再是过拟合问题了(相差3%我是可以接受的),而是欠拟合了。

然而淡腾的是,再怎么调参,模型都很难继续拟合了,调小学习率也不行。模型本身没啥问题的话,我开始怀疑数据集有没有问题,毕竟这种无法拟合的问题有很大概率是数据导致的。于是我就去检查了一下数据集……

这就是我前面强调认真过一遍数据集的原因了,我当时只是花个几分钟粗略地过了一下,删除掉一些明显不对的图片。我第二次认真过数据集的时候才发现,有很多异常图片没有过滤掉,比如猫的目录下有狗的图片,狗的目录下有猫的图片,还有一些不同动物同框的图片,以及我自己都认不出来的图片……

文章第一部分中各类图片数量的表格,其实就是我第二遍过滤后的结果统计。

过滤完成后,模型的性能有了明显的提升,训练accuracy约为93%-94%,验证accuracy为94%,测试accuracy为92%.我们先来看一下代码,后面会对这个结果再进行分析。

首先,是模型的构建:

# -*- coding: utf-8 -*-
# @File    : models.py
# @Author  : AaronJny
# @Time    : 2019/12/16
# @Desc    :
import tensorflow as tf
import settings


def my_densenet():
    """
    创建并返回一个基于densenet的Model对象
    """
    # 获取densenet网络,使用在imagenet上训练的参数值,移除头部的全连接网络,池化层使用max_pooling
    densenet = tf.keras.applications.DenseNet121(include_top=False, weights='imagenet', pooling='max')
    # 冻结预训练的参数,在之后的模型训练中不会改变它们
    densenet.trainable = False
    # 构建模型
    model = tf.keras.Sequential([
        # 输入层,shape为(None,224,224,3)
        tf.keras.layers.Input((224, 224, 3)),
        # 输入到DenseNet121中
        densenet,
        # 将DenseNet121的输出展平,以作为全连接层的输入
        tf.keras.layers.Flatten(),
        # 添加BN层
        tf.keras.layers.BatchNormalization(),
        # 随机失活
        tf.keras.layers.Dropout(0.5),
        # 第一个全连接层,激活函数relu
        tf.keras.layers.Dense(512, activation=tf.nn.relu),
        # BN层
        tf.keras.layers.BatchNormalization(),
        # 随机失活
        tf.keras.layers.Dropout(0.5),
        # 第二个全连接层,激活函数relu
        tf.keras.layers.Dense(64, activation=tf.nn.relu),
        # BN层
        tf.keras.layers.BatchNormalization(),
        # 输出层,为了保证输出结果的稳定,这里就不添加Dropout层了
        tf.keras.layers.Dense(settings.CLASS_NUM, activation=tf.nn.softmax)
    ])

    return model


if __name__ == '__main__':
    model = my_densenet()
    model.summary()

网络的summary:

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
densenet121 (Model)          (None, 1024)              7037504   
_________________________________________________________________
flatten (Flatten)            (None, 1024)              0         
_________________________________________________________________
batch_normalization (BatchNo (None, 1024)              4096      
_________________________________________________________________
dropout (Dropout)            (None, 1024)              0         
_________________________________________________________________
dense (Dense)                (None, 512)               524800    
_________________________________________________________________
batch_normalization_1 (Batch (None, 512)               2048      
_________________________________________________________________
dropout_1 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 64)                32832     
_________________________________________________________________
batch_normalization_2 (Batch (None, 64)                256       
_________________________________________________________________
dense_2 (Dense)              (None, 4)                 260       
=================================================================
Total params: 7,601,796
Trainable params: 561,092
Non-trainable params: 7,040,704
_________________________________________________________________

参数总量7601796个,其中可训练参数561092个 。

模型和数据都已准备完毕,可以开始训练了。让我们编写一个训练用的脚本:

# -*- coding: utf-8 -*-
# @File    : train.py
# @Author  : AaronJny
# @Time    : 2019/12/17
# @Desc    :
import tensorflow as tf
from data import train_db, dev_db
import models
import settings

# 从models文件中导入模型
model = models.my_densenet()
model.summary()

# 配置优化器、损失函数、以及监控指标
model.compile(tf.keras.optimizers.Adam(settings.LEARNING_RATE), loss=tf.keras.losses.categorical_crossentropy,
              metrics=['accuracy'])

# 在每个epoch结束后尝试保存模型参数,只有当前参数的val_accuracy比之前保存的更优时,才会覆盖掉之前保存的参数
model_check_point = tf.keras.callbacks.ModelCheckpoint(filepath=settings.MODEL_PATH, monitor='val_accuracy',
                                                       save_best_only=True)
# 使用tf.keras的高级接口进行训练
model.fit_generator(train_db, epochs=settings.TRAIN_EPOCHS, validation_data=dev_db, callbacks=[model_check_point])

现在,我们可以运行脚本进行训练了,最优的参数将被保存在settings.MODEL_PATH。训练完成后,我们需要调用验证脚本,验证下模型在验证集和测试集上的表现:

# -*- coding: utf-8 -*-
# @File    : eval.py
# @Author  : AaronJny
# @Time    : 2019/12/17
# @Desc    :
import tensorflow as tf
from data import dev_db, test_db
from models import my_densenet
import settings

# 创建模型
model = my_densenet()
# 加载参数
model.load_weights(settings.MODEL_PATH)
# 因为想用tf.keras的高级接口做验证,所以还是需要编译模型
model.compile(tf.keras.optimizers.Adam(settings.LEARNING_RATE), loss=tf.keras.losses.categorical_crossentropy,
              metrics=['accuracy'])
# 验证集accuracy
print('dev', model.evaluate(dev_db))
# 测试集accuracy
print('test', model.evaluate(test_db))

输出如下:

18/18 [==============================] - 5s 304ms/step - loss: 0.1936 - accuracy: 0.9457
dev [0.19364455559601387, 0.9456522]
18/18 [==============================] - 1s 64ms/step - loss: 0.2666 - accuracy: 0.9203
test [0.26657224384446937, 0.9202899]

能够看到,模型在验证集上的准确率为94.57%,在测试集上的准确率为92.03%,已经达到我的心里预期了,毕竟这么少的数据,还要啥自行车?

随着训练epoch的增多,模型的训练accuracy始终在[0.92,0.95]左右徘徊不定,没法继续拟合。究其原因,应该还是数据的锅。我们看一下识别错的样本,在eval.py脚本中,增加下面这一段程序:

# 查看识别错误的数据
for x, y in test_db:
    y_pred = model(x)
    y_pred = tf.argmax(y_pred, axis=1).numpy()
    y_true = tf.argmax(y, axis=1).numpy()
    batch_size = y_pred.shape[0]
    for i in range(batch_size):
        if y_pred[i] != y_true[i]:
            print('{} 被错误识别成 {}!'.format(settings.CODE_CLASS_MAP[y_true[i]], settings.CODE_CLASS_MAP[y_pred[i]]))

重新跑一下eval.py脚本,输出如下:

18/18 [==============================] - 5s 291ms/step - loss: 0.1936 - accuracy: 0.9457
dev [0.19364455559601387, 0.9456522]
18/18 [==============================] - 1s 64ms/step - loss: 0.2666 - accuracy: 0.9203
test [0.26657224384446937, 0.9202899]
狗 被错误识别成 兔!
狗 被错误识别成 兔!
狗 被错误识别成 兔!
鼠 被错误识别成 兔!
狗 被错误识别成 猫!
鼠 被错误识别成 猫!
狗 被错误识别成 兔!
狗 被错误识别成 鼠!
鼠 被错误识别成 兔!
狗 被错误识别成 兔!
猫 被错误识别成 兔!
猫 被错误识别成 鼠!
猫 被错误识别成 兔!
鼠 被错误识别成 兔!
狗 被错误识别成 兔!
狗 被错误识别成 猫!
鼠 被错误识别成 兔!
狗 被错误识别成 兔!
鼠 被错误识别成 兔!
狗 被错误识别成 猫!
鼠 被错误识别成 兔!
狗 被错误识别成 兔!

来,跟我一起唱——都是兔子惹的祸~

能够看到,出错的大部分都是被误识别成兔子了。对应到数据集上,虽然已经删掉了部分问题比较大的图片,但兔子的图片确实不好认。有很多兔子图片我人工分辨都认不出是兔子(捂脸.jpg)。然后,有些兔子图片看起来很像猫,有些看起来很像狗,有些看起来很像鼠……

如果我们把兔子图片去掉,将系统改为三分类问题,准确度将大幅度提高。当然了,按理说识别的类别数量变了,除了调整输出层的节点数量外,要想取得最佳效果,模型的其他参数也需要做相应调整的。我自己已经实测了,但限于篇幅,就不演示了,如果有兴趣的话,可以直接在settings.py里进行调整,将它变为三分类问题。改这两个地方:

# 参与训练的类别
CLASSES = ['cats', 'dogs', 'mouses', 'rabbits']
# 随机数种子
RANDOM_SEED = 13  # 四个类别时样本较为均衡的随机数种子

改成:

# 参与训练的类别
CLASSES = ['cats', 'dogs', 'mouses']
# 随机数种子
RANDOM_SEED = 19  # 三个类别时样本较为均衡的随机数种子

然后重新训练和验证即可。这只是一个插曲,本文仍然以四分类问题继续说明后续内容。

三、Web接口编写

模型训练好了,我们要把它应用起来。我准备编写一个Web服务,用户可以通过浏览器上传一张图片,服务器判断此图片的类别后,返回相关数据给用户。Web后端使用Flask,小而轻,前端则选用Vue.js + Element-UI实现。

先写后端:

# -*- coding: utf-8 -*-
# @File    : app.py
# @Author  : AaronJny
# @Time    : 2019/12/18
# @Desc    :
from flask import Flask
from flask import jsonify
from flask import request, render_template
import tensorflow as tf
from models import my_densenet
import settings

app = Flask(__name__)

# 导入模型
model = my_densenet()
# 加载训练好的参数
model.load_weights(settings.MODEL_PATH)


@app.route('/', methods=['GET'])
def index():
    """
    首页,vue入口
    """
    return render_template('index.html')


@app.route('/api/v1/pets_classify/', methods=['POST'])
def pets_classify():
    """
    宠物图片分类接口,上传一张图片,返回此图片上的宠物是那种类别,概率多少
    """
    # 获取用户上传的图片
    img_str = request.files.get('file').read()
    # 进行数据预处理
    x = tf.image.decode_image(img_str, channels=3)
    x = tf.image.resize(x, (224, 224))
    x = x / 255.
    x = (x - tf.constant(settings.IMAGE_MEAN)) / tf.constant(settings.IMAGE_STD)
    x = tf.reshape(x, (1, 224, 224, 3))
    # 预测
    y_pred = model(x)
    pet_cls_code = tf.argmax(y_pred, axis=1).numpy()[0]
    pet_cls_prob = float(y_pred.numpy()[0][pet_cls_code])
    pet_cls_prob = '{}%'.format(int(pet_cls_prob * 100))
    pet_class = settings.CODE_CLASS_MAP.get(pet_cls_code)
    # 将预测结果组织成json
    res = {
        'code': 0,
        'data': {
            'pet_cls': pet_class,
            'probability': pet_cls_prob,
            'msg': '<br><br><strong style="font-size: 48px;">{}</strong> <span style="font-size: 24px;"'
                   '>概率<strong>{}</strong></span>'.format(pet_class, pet_cls_prob),
        }
    }
    # 返回json数据
    return jsonify(res)


if __name__ == '__main__':
    app.run(port=settings.WEB_PORT)

后端脚本app.py很简单,主要就两个方法。其中index方法会返回首页的html源码,是用户在浏览器端的访问入口;另一个方法pets_classify则提供了计算给定图片类别的功能。

前端文件index.html主要是提供了一个照片墙,用户上传图片到照片墙,服务器就会计算图片类别并返回相关数据。代码如下:

<!DOCTYPE html>
<html>
<head>
    <meta charset="UTF-8">
    <!-- import CSS -->
    <link rel="stylesheet" href="https://unpkg.com/element-ui/lib/theme-chalk/index.css">
</head>
<body>
<div id="app">
    <el-card class="box-card">
        <div slot="header" class="clearfix">
            <h1>宠物识别Demo</h1>
        </div>
        <el-upload
                action="http://localhost:5000/api/v1/pets_classify/"
                list-type="picture-card"
                :on-preview="handlePictureCardPreview"
                :on-success="handleUploadSuccess"
                :on-remove="handleRemove">
            <i class="el-icon-plus"></i>
        </el-upload>
        <el-dialog :visible.sync="dialogVisible">
            <img width="100%" :src="dialogImageUrl" alt="">
        </el-dialog>
    </el-card>
</div>
</body>

<!-- import Vue before Element -->
<script src="https://unpkg.com/vue/dist/vue.js"></script>
<!-- import JavaScript -->
<script src="https://unpkg.com/element-ui/lib/index.js"></script>
<script>
    new Vue({
        el: '#app',
        data() {
            return {
                dialogImageUrl: '',
                dialogVisible: false
            };
        },
        methods: {
            handleRemove(file, fileList) {
                console.log(file, fileList);
                console.log(this.dialogImageUrl);
            },
            handlePictureCardPreview(file) {
                this.dialogImageUrl = file.url;
                this.dialogVisible = true;
            },
            handleUploadSuccess(response, file, fileList) {
                this.$notify({
                    title: '识别结果',
                    message: response.data.msg,
                    dangerouslyUseHTMLString: true,
                    type: 'success',
                    duration: 3000
                });
            }
        }
    })
</script>
</html>

让我们试试效果。首先,运行app.py脚本,启动web服务,当你看到如下输出时,说明服务启动成功了:

 * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)
 * Serving Flask app "app" (lazy loading)
 * Environment: production
   WARNING: This is a development server. Do not use it in a production deployment.
   Use a production WSGI server instead.
 * Debug mode: off

因为只是开发环境,这么启动就可以了。如果是生产环境,请不要这么做,可以选择使用nginx + gunicorn +uWSGI + gevent进行部署。

四、测试

打开浏览器,输入 http://localhost:5000 进入index页面。页面长这个样子:

点击网页中的上传框,我们可以选择图片上传并识别:

当然了,这里不选择我们数据集里的图片更好,哪怕是测试集里的。你可以去网上下载、或者通过其他渠道获取这四种动物的图片来测试,这里我只做演示,就不搞那么麻烦了,直接从数据集里随便选几张照片。我们可以继续上传图片给服务器识别:










OK,演示到此为止,如果有兴趣的话可以自行测试。

结语

文章到此结束,如果您喜欢的话,给我点个赞呗~

菜鸟一只,欢迎大佬们拍砖~


转载:https://blog.csdn.net/aaronjny/article/details/103605988
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场