小言_互联网的博客

2020年tensorflow定制训练模型笔记(1)——object detection的安装

273人阅读  评论(0)

自己看着网上的很多教程摸索了好几天,终于能够自己训练。事实上,网上关于这个API的教程还是非常多的,但我实际做起来发现其实在某些关键部分缺少点步骤,会把我这样的小白搞得一头雾水、无从下手,最后在无穷无尽的报错中崩溃。所以我决定写这篇笔记,一来帮助最初像我一样的小白轻松搞定,二来就是为自己做笔记 ,以后万一忘记了,可以回来看看回想一下。

电脑配置

  • cpu: i7-8750H
  • gpu: 1060 6G
  • 内存: 16G
  • 操作系统:win10

我的环境

  • anaconda 3.x
  • python 3.6
  • tensorflow 1.15

cpu或gpu版本都可,具体怎么安装我就不介绍了,2.0版本我也试过,只有一步骤我暂时没有办法用2.0解决所以不写2.0。

tf2.0与1.0的一些区别

报错:AttributeError: module 'tensorflow' has no attribute 'contrib'
这个库在测试环节和训练环节使用,而tf2.0移除了contrib。就是这个原因使得我无法用tf2.0成功训练模型。
其他步骤1.0和2.0通用。因为我用的是高版本的1.0,可以识别一些2.0的内容,所以我的代码都改成适用于2.0的内容。例如tf.Session ->tf.compat.v1.Session
当然tf2.0也有专门升级的代码:

tf_upgrade_v2 --infile foo.py --outfile foo-upgraded.py
#tf_upgrade_v2 --infile 1.*版.py文件 --outfile 生成的2.0版.py文件

正式开始前,我还想说明下,有些类似报错找不到这个库的问题,我这里就不多说了,直接去anaconda去下载缺少的库,一些特殊的库我会特别说一下的。
另外,我这篇是自己的学习笔记,我也用了大量的别人的代码,用了的我基本都会写明出处,如果不妥的,请手下留情。

1.object detection

1.1下载

首先,我们得从GitHub上下载object detection。下载地址如下:
https://github.com/tensorflow/models
下载后解压你就会得到models-master的文件夹(ps:他文件里的路径都叫models,不知道为什么解压后叫models-master,这里因为我自己是改成models了,为了方便我自己截图演示,后面出现的都是models)

1.2配置

激活tensorflow环境,来到models/research/object_detection/ 文件夹下你会看到一个object_detection_tutorial.ipynb文件,这是一个demo文件,在tensorflow环境下用jupyter打开它出现如下界面:(我现在时间是2020.2.26,这是现在最新的版本)
我们按照他的install步骤一步步来:
第一行:安装tensorflow2.0,我们不要理他,1.0也是可以用的,把这句话用#注释掉或直接删掉。
第二行:安装pycocotools,windows安装pycocotools必须先得安装cython,所以我们先去下载cython到tensorflow的环境里。你可以使用命令行,但我没试过,我是在anaconda里完成的,我个人感觉很方便。
安装cython后,你要么在jupyter里运行这一行代码。我实际操作中感觉太慢了,就去上网找了另一个方法。在tf环境下用命令行:

pip install git+https://github.com/philferriere/cocoapi.git#subdirectory=PythonAPI

这好像是哪个大佬写的支持 Windows 的 COCO 地址,原网址是:https://www.jianshu.com/p/8658cda3d553
需要注意的是这个网是国外的,所以国内下载还是会有点慢,你要不等着要不那个(不说了)
第三行:可运行可不运行,就是下载这个API
第四行:配置路径,在jupyter里运行这个代码有点慢,所以我直接在命令行里操作

cd models/research/
protoc object_detection/protos/*.proto --python_out=.

第五行:安装包,也是直接在命令行操作

pip install .

1.3测试

至此,我们也就完全搭好了环境,install的代码以后其基本都不会用了,全部注释掉或者删掉。接下来直接所有的运行代码,也许你会成功,但我是失败的,我到最后要出结果的时候内核死机了,一开始我为了这个问题苦苦寻求答案很久,直到最近我才发觉这可能是系统的问题,因为写这代码的人是在Linux上运行的,可能代码里的哪一部分与win有冲突,反正这demo代码里有块内容是运行不起来的,不是我们环境和电脑的问题。
办法嘛,是有的,我找了好多教程,才找到一个适合我们这种情况的代码

import os
import sys
import cv2
import numpy as np
import tensorflow as tf
sys.path.append("..")

from utils import label_map_util
from utils import visualization_utils as vis_util


class TOD(object):
    def __init__(self):
        # 这是用于对象检测的实际模型的路径,如果没有这个pb文件,说明你还未下载。可以用demo里的下载代码来替换。
        self.PATH_TO_CKPT = 'ssd_mobilenet_v1_coco_2017_11_17/frozen_inference_graph.pb'
        #  用于为每个框添加正确标签的字符串的列表的路径。
        self.PATH_TO_LABELS = os.path.join('data', 'mscoco_label_map.pbtxt')

        self.NUM_CLASSES = 90

        self.detection_graph = self._load_model()
        self.category_index = self._load_label_map()

    def _load_model(self):
        detection_graph = tf.Graph()
        with detection_graph.as_default():
            od_graph_def = tf.compat.v1.GraphDef()
            with tf.io.gfile.GFile(self.PATH_TO_CKPT, 'rb') as fid:
                serialized_graph = fid.read()
                od_graph_def.ParseFromString(serialized_graph)
                tf.import_graph_def(od_graph_def, name='')
        return detection_graph

    def _load_label_map(self):
        label_map = label_map_util.load_labelmap(self.PATH_TO_LABELS)
        categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=self.NUM_CLASSES, use_display_name=True)
        category_index = label_map_util.create_category_index(categories)
        return category_index

    def detect(self, image):
        with self.detection_graph.as_default():
            with tf.compat.v1.Session(graph=self.detection_graph) as sess:
                # Expand dimensions since the model expects images to have shape: [1, None, None, 3]  扩展维度,因为模型期望图像具有以下形状:[1,None, None, 3]
                image_np_expanded = np.expand_dims(image, axis=0)
                image_tensor = self.detection_graph.get_tensor_by_name('image_tensor:0')
                # Each box represents a part of the image where a particular object was detected.  每个框表示检测到特定对象的图像的一部分。
                boxes = self.detection_graph.get_tensor_by_name('detection_boxes:0')
                # Each score represent how level of confidence for each of the objects.  每个分数表示每个对象的置信度。
                # Score is shown on the result image, together with the class label.  分数与类标签一起显示在结果图像上。
                scores = self.detection_graph.get_tensor_by_name('detection_scores:0')
                classes = self.detection_graph.get_tensor_by_name('detection_classes:0')
                num_detections = self.detection_graph.get_tensor_by_name('num_detections:0')
                # Actual detection.  实际检测。
                (boxes, scores, classes, num_detections) = sess.run(
                    [boxes, scores, classes, num_detections],
                    feed_dict={image_tensor: image_np_expanded})
                # Visualization of the results of a detection.  检测结果的可视化。
                vis_util.visualize_boxes_and_labels_on_image_array(
                    image,
                    np.squeeze(boxes),
                    np.squeeze(classes).astype(np.int32),
                    np.squeeze(scores),
                    self.category_index,
                    use_normalized_coordinates=True,
                    line_thickness=8)

        while True:
            cv2.namedWindow("detection", cv2.WINDOW_NORMAL)
            cv2.imshow("detection", image)
            if cv2.waitKey(110) & 0xff == 27:
                break


if __name__ == '__main__':
    image = cv2.imread('test_images/image1.jpg')#测试照片的路径
    detecotr = TOD()
    detecotr.detect(image)

该代码出自Tensorflow object detection API 搭建属于自己的物体识别模型(转载修改)这篇文章。我个人是非常感谢这位博主,它解决了我的测试不成功的问题,上面这个代码要注意的是三个路径,我在代码里已经标注出来了,这三个地方根据自己文件夹的情况自行更改。测试路径你也可以改成一个for循环测试一组图片。
运行此代码,这张带标注的照片终于出现了:

小结

因为之前被内核死掉的事崩溃太久了,我不仅新版本demo会崩溃,老版本的demo也是加载不出照片,只有这个用上cv的代码才可以。这张照片一出来,我当时才舒了一口气,顿时对后面的训练充满信心了有木有。
下一篇我们继续,接下来是生成训练文件的事了…
如果你知道具体的原因,可以的话就教教我吧,在下面的评论留个言或者私我。


转载:https://blog.csdn.net/weixin_45569617/article/details/104511366
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场