
Using balancap/SSD-Tensorflow: Training and Predicting on Your Own Dataset


The SSD implementation used here is on GitHub: https://github.com/balancap/SSD-Tensorflow

Environment: Python 3.6.0, TensorFlow 1.11, Keras 2.1.5

 

After downloading, the checkpoints directory already contains a pretrained model, which you can use to run predictions on the bundled test images or to localise objects in video; see https://blog.csdn.net/zzz_cming/article/details/81128460 for that.

Next, let's set up our own dataset.

1. Create a new folder in the project root to hold the original images, the annotation files, and the text files listing which samples go into the training and validation sets; here it is named VOC2007 and follows the VOC layout (Annotations/, JPEGImages/, ImageSets/Main/).

How to produce Annotations and JPEGImages is covered in my earlier post: https://blog.csdn.net/jiugeshao/article/details/116084611

2. Next, generate the four txt files under ImageSets\Main that the VOC2007 dataset format expects. With the ratios used below, trainval takes 70% of all samples and train takes 80% of trainval, so the final split is 56% train, 14% val and 30% test.

Create a GenerateTXT.py file for this (run it from inside VOC2007 so the relative paths resolve):


  
```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# author: Icecream.Shao
# original author: zzZ_CMing, CSDN: https://blog.csdn.net/zzZ_CMing
import os
import random

# trainval takes 70% of all samples; train takes 80% of trainval,
# giving a 56% / 14% / 30% train / val / test split.
trainval_percent = 0.7
train_percent = 0.8
xmlfilepath = 'Annotations/'
txtsavepath = 'ImageSets/Main'

total_xml = os.listdir(xmlfilepath)
num = len(total_xml)
indices = range(num)
tv = int(num * trainval_percent)
tr = int(tv * train_percent)
trainval = random.sample(indices, tv)
train = random.sample(trainval, tr)

ftrainval = open(txtsavepath + '/trainval.txt', 'w')
ftest = open(txtsavepath + '/test.txt', 'w')
ftrain = open(txtsavepath + '/train.txt', 'w')
fval = open(txtsavepath + '/val.txt', 'w')

for i in indices:
    name = total_xml[i][:-4] + '\n'   # strip the '.xml' extension
    if i in trainval:
        ftrainval.write(name)
        if i in train:
            ftrain.write(name)
        else:
            fval.write(name)
    else:
        ftest.write(name)

ftrainval.close()
ftrain.close()
fval.close()
ftest.close()
print('Well Done!!!')
```
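As a quick sanity check (a minimal sketch, assuming the script was run from inside VOC2007), you can verify that the generated splits are consistent:

```python
# Sanity check for the generated split files (sketch): train and val
# should partition trainval, and trainval should be disjoint from test.
splits = {}
for split in ('trainval', 'train', 'val', 'test'):
    with open('ImageSets/Main/%s.txt' % split) as f:
        splits[split] = {line.strip() for line in f if line.strip()}

print({name: len(s) for name, s in splits.items()})
assert splits['train'] | splits['val'] == splits['trainval']
assert not splits['trainval'] & splits['test']
```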

3. Every framework expects its own data format, so the dataset has to be converted; this is done with tf_convert_data.py in the project root:


  
````python
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Convert a dataset to TFRecords format, which can be easily integrated into
a TensorFlow pipeline.

Usage:
```shell
python tf_convert_data.py \
    --dataset_name=pascalvoc \
    --dataset_dir=/tmp/pascalvoc \
    --output_name=pascalvoc \
    --output_dir=/tmp/
```
"""
import tensorflow as tf

from datasets import pascalvoc_to_tfrecords

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string(
    'dataset_name', 'pascalvoc',
    'The name of the dataset to convert.')
tf.app.flags.DEFINE_string(
    'dataset_dir', '.\\VOC2007\\',
    'Directory where the original dataset is stored.')
tf.app.flags.DEFINE_string(
    'output_name', 'mydata_train',
    'Basename used for TFRecords output files.')
tf.app.flags.DEFINE_string(
    'output_dir', '.\\tfrecords\\',
    'Output directory where to store TFRecords files.')


def main(_):
    if not FLAGS.dataset_dir:
        raise ValueError('You must supply the dataset directory with --dataset_dir')
    print('Dataset directory:', FLAGS.dataset_dir)
    print('Output directory:', FLAGS.output_dir)

    if FLAGS.dataset_name == 'pascalvoc':
        pascalvoc_to_tfrecords.run(FLAGS.dataset_dir, FLAGS.output_dir, FLAGS.output_name)
    else:
        raise ValueError('Dataset [%s] was not recognized.' % FLAGS.dataset_name)

if __name__ == '__main__':
    tf.app.run()
````

 

4. Modify the VOC_LABELS variable in datasets/pascalvoc_common.py.
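A minimal sketch of that edit for this post's single 'sidewalk' class; the (index, category) tuple format follows the repository's original VOC_LABELS, while the category string here is an assumption:

```python
# datasets/pascalvoc_common.py -- sketch for one custom class.
# Each entry maps a class name to (label index, category); index 0 must
# remain the 'none' background class.
VOC_LABELS = {
    'none': (0, 'Background'),
    'sidewalk': (1, 'Custom'),   # assumed category name
}
```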

5. Modify the code in pascalvoc_to_tfrecords.py. Two local changes matter here: _process_image reads images with a '.bmp' extension to match this dataset (upstream reads '.jpg'), and SAMPLES_PER_FILES is lowered to 5 so each shard stays small:


  
```python
# Copyright 2015 Paul Balanca. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts Pascal VOC data to TFRecords file format with Example protos.

The raw Pascal VOC data set is expected to reside in JPEG files located in the
directory 'JPEGImages'. Similarly, bounding box annotations are supposed to be
stored in the 'Annotations' directory.

This TensorFlow script converts the training and evaluation data into
a sharded data set consisting of 1024 and 128 TFRecord files, respectively.
Each validation TFRecord file contains ~500 records. Each training TFRecord
file contains ~1000 records. Each record within the TFRecord file is a
serialized Example proto. The Example proto contains the following fields:

    image/encoded: string containing JPEG encoded image in RGB colorspace
    image/height: integer, image height in pixels
    image/width: integer, image width in pixels
    image/channels: integer, specifying the number of channels, always 3
    image/format: string, specifying the format, always 'JPEG'
    image/object/bbox/xmin: list of float specifying the 0+ human annotated
        bounding boxes
    image/object/bbox/xmax: list of float specifying the 0+ human annotated
        bounding boxes
    image/object/bbox/ymin: list of float specifying the 0+ human annotated
        bounding boxes
    image/object/bbox/ymax: list of float specifying the 0+ human annotated
        bounding boxes
    image/object/bbox/label: list of integer specifying the classification index.
    image/object/bbox/label_text: list of string descriptions.

Note that the length of xmin is identical to the length of xmax, ymin and ymax
for each example.
"""
import os
import sys
import random

import numpy as np
import tensorflow as tf

import xml.etree.ElementTree as ET

from datasets.dataset_utils import int64_feature, float_feature, bytes_feature
from datasets.pascalvoc_common import VOC_LABELS

# Original dataset organisation.
DIRECTORY_ANNOTATIONS = 'Annotations/'
DIRECTORY_IMAGES = 'JPEGImages/'

# TFRecords conversion parameters.
RANDOM_SEED = 4242
SAMPLES_PER_FILES = 5


def _process_image(directory, name):
    """Process an image and annotation file.

    Args:
      filename: string, path to an image file e.g., '/path/to/example.JPG'.
      coder: instance of ImageCoder to provide TensorFlow image coding utils.
    Returns:
      image_buffer: string, JPEG encoding of RGB image.
      height: integer, image height in pixels.
      width: integer, image width in pixels.
    """
    # Read the image file ('.bmp' here to match the dataset used in this post).
    filename = directory + DIRECTORY_IMAGES + name + '.bmp'
    image_data = tf.gfile.FastGFile(filename, 'rb').read()

    # Read the XML annotation file.
    filename = os.path.join(directory, DIRECTORY_ANNOTATIONS, name + '.xml')
    tree = ET.parse(filename)
    root = tree.getroot()

    # Image shape.
    size = root.find('size')
    shape = [int(size.find('height').text),
             int(size.find('width').text),
             int(size.find('depth').text)]
    # Find annotations.
    bboxes = []
    labels = []
    labels_text = []
    difficult = []
    truncated = []
    for obj in root.findall('object'):
        label = obj.find('name').text
        labels.append(int(VOC_LABELS[label][0]))
        labels_text.append(label.encode('ascii'))

        if obj.find('difficult'):
            difficult.append(int(obj.find('difficult').text))
        else:
            difficult.append(0)
        if obj.find('truncated'):
            truncated.append(int(obj.find('truncated').text))
        else:
            truncated.append(0)

        bbox = obj.find('bndbox')
        bboxes.append((float(bbox.find('ymin').text) / shape[0],
                       float(bbox.find('xmin').text) / shape[1],
                       float(bbox.find('ymax').text) / shape[0],
                       float(bbox.find('xmax').text) / shape[1]
                       ))
    return image_data, shape, bboxes, labels, labels_text, difficult, truncated


def _convert_to_example(image_data, labels, labels_text, bboxes, shape,
                        difficult, truncated):
    """Build an Example proto for an image example.

    Args:
      image_data: string, JPEG encoding of RGB image;
      labels: list of integers, identifier for the ground truth;
      labels_text: list of strings, human-readable labels;
      bboxes: list of bounding boxes; each box is a list of floats
          specifying [xmin, ymin, xmax, ymax]. All boxes are assumed to belong
          to the same label as the image label.
      shape: 3 integers, image shapes in pixels.
    Returns:
      Example proto
    """
    xmin = []
    ymin = []
    xmax = []
    ymax = []
    for b in bboxes:
        assert len(b) == 4
        # pylint: disable=expression-not-assigned
        [l.append(point) for l, point in zip([ymin, xmin, ymax, xmax], b)]
        # pylint: enable=expression-not-assigned

    image_format = b'JPEG'
    example = tf.train.Example(features=tf.train.Features(feature={
            'image/height': int64_feature(shape[0]),
            'image/width': int64_feature(shape[1]),
            'image/channels': int64_feature(shape[2]),
            'image/shape': int64_feature(shape),
            'image/object/bbox/xmin': float_feature(xmin),
            'image/object/bbox/xmax': float_feature(xmax),
            'image/object/bbox/ymin': float_feature(ymin),
            'image/object/bbox/ymax': float_feature(ymax),
            'image/object/bbox/label': int64_feature(labels),
            'image/object/bbox/label_text': bytes_feature(labels_text),
            'image/object/bbox/difficult': int64_feature(difficult),
            'image/object/bbox/truncated': int64_feature(truncated),
            'image/format': bytes_feature(image_format),
            'image/encoded': bytes_feature(image_data)}))
    return example


def _add_to_tfrecord(dataset_dir, name, tfrecord_writer):
    """Loads data from image and annotations files and add them to a TFRecord.

    Args:
      dataset_dir: Dataset directory;
      name: Image name to add to the TFRecord;
      tfrecord_writer: The TFRecord writer to use for writing.
    """
    image_data, shape, bboxes, labels, labels_text, difficult, truncated = \
        _process_image(dataset_dir, name)
    example = _convert_to_example(image_data, labels, labels_text,
                                  bboxes, shape, difficult, truncated)
    tfrecord_writer.write(example.SerializeToString())


def _get_output_filename(output_dir, name, idx):
    return '%s/%s_%03d.tfrecord' % (output_dir, name, idx)


def run(dataset_dir, output_dir, name='voc_train', shuffling=False):
    """Runs the conversion operation.

    Args:
      dataset_dir: The dataset directory where the dataset is stored.
      output_dir: Output directory.
    """
    if not tf.gfile.Exists(dataset_dir):
        tf.gfile.MakeDirs(dataset_dir)

    # Dataset filenames, and shuffling.
    path = os.path.join(dataset_dir, DIRECTORY_ANNOTATIONS)
    filenames = sorted(os.listdir(path))
    if shuffling:
        random.seed(RANDOM_SEED)
        random.shuffle(filenames)

    # Process dataset files.
    i = 0
    fidx = 0
    while i < len(filenames):
        # Open new TFRecord file.
        tf_filename = _get_output_filename(output_dir, name, fidx)
        with tf.python_io.TFRecordWriter(tf_filename) as tfrecord_writer:
            j = 0
            while i < len(filenames) and j < SAMPLES_PER_FILES:
                sys.stdout.write('\r>> Converting image %d/%d' % (i+1, len(filenames)))
                sys.stdout.flush()

                filename = filenames[i]
                img_name = filename[:-4]
                _add_to_tfrecord(dataset_dir, img_name, tfrecord_writer)
                i += 1
                j += 1
            fidx += 1

    # Finally, write the labels file:
    # labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES))
    # dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
    print('\nFinished converting the Pascal VOC dataset!')
```

6. After executing tf_convert_data.py above (with the flag defaults shown, this is simply `python tf_convert_data.py --dataset_name=pascalvoc --dataset_dir=./VOC2007/ --output_name=mydata_train --output_dir=./tfrecords/`), the tfrecords directory contains the generated files with the .tfrecord suffix, named like mydata_train_000.tfrecord.

7. Then modify the following files (a sketch of the typical changes follows this list):

datasets/pascalvoc_2007.py

nets/ssd_vgg_300.py

eval_ssd_network.py

train_ssd_network.py
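The original post shows these edits as screenshots. As a hedged sketch of the usual single-class adjustments (the counts and split sizes below are placeholders; substitute your real sample numbers):

```python
# datasets/pascalvoc_2007.py -- describe the new dataset (sketch):
FILE_PATTERN = 'mydata_%s_*.tfrecord'         # must match the shard basenames
TRAIN_STATISTICS = {
    'none': (0, 0),
    'sidewalk': (100, 100),                   # placeholder: (images, objects)
}
SPLITS_TO_SIZES = {'train': 100, 'test': 30}  # placeholder split sizes
NUM_CLASSES = 1                               # classes excluding background

# nets/ssd_vgg_300.py  -- SSDNet.default_params: num_classes=2 (1 class + background)
# eval_ssd_network.py  -- 'num_classes' flag default: 2
# train_ssd_network.py -- 'num_classes' flag default: 2
```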

 

Run train_ssd_network.py to start training; a typical invocation, using the repository's flags with paths assumed from this project layout, would be `python train_ssd_network.py --train_dir=./train_model/ --dataset_dir=./tfrecords/ --dataset_name=pascalvoc_2007 --dataset_split_name=train --model_name=ssd_300_vgg --batch_size=16`. When training finishes, the train_model directory contains the checkpoints saved at the configured interval.

 

8. Use the trained model to predict on a sidewalk image:


  
```python
# -*- coding:utf-8 -*-
# original author: zzZ_CMing, CSDN: https://blog.csdn.net/zzZ_CMing
"""
Reference: https://blog.csdn.net/qq_35608277/article/details/78660469
According to that post, the code originates from the official Microsoft
repository on GitHub.
"""
import sys
# Make sure the SSD-Tensorflow modules can be imported.
sys.path.append('./SSD-Tensorflow-master/')

import cv2
import tensorflow as tf
import matplotlib.pyplot as plt
import matplotlib.cm as mpcm
import matplotlib.image as mpimg
from notebooks import visualization
from nets import ssd_vgg_300, ssd_common, np_methods
from preprocessing import ssd_vgg_preprocessing

slim = tf.contrib.slim

# TensorFlow session: grow GPU memory on demand.
gpu_options = tf.GPUOptions(allow_growth=True)
config = tf.ConfigProto(log_device_placement=False, gpu_options=gpu_options)
isess = tf.InteractiveSession(config=config)

l_VOC_CLASS = ['sidewalk']

# Build the SSD input pipeline: resize to 300x300 with warp resizing.
net_shape = (300, 300)
img_input = tf.placeholder(tf.uint8, shape=(None, None, 3))
data_format = 'NHWC'
image_pre, labels_pre, bboxes_pre, bbox_img = ssd_vgg_preprocessing.preprocess_for_eval(
    img_input, None, None, net_shape, data_format,
    resize=ssd_vgg_preprocessing.Resize.WARP_RESIZE)
image_4d = tf.expand_dims(image_pre, 0)

# Define the SSD model.
reuse = True if 'ssd_net' in locals() else None
ssd_net = ssd_vgg_300.SSDNet()
with slim.arg_scope(ssd_net.arg_scope(data_format=data_format)):
    predictions, localisations, _, _ = ssd_net.net(image_4d, is_training=False, reuse=reuse)

# Restore the trained model.
ckpt_filename = '../train_model/model.ckpt-20000'
isess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(isess, ckpt_filename)

ssd_anchors = ssd_net.anchors(net_shape)


def colors_subselect(colors, num_classes=2):
    dt = len(colors) // num_classes
    sub_colors = []
    for i in range(num_classes):
        color = colors[i * dt]
        if isinstance(color[0], float):
            sub_colors.append([int(c * 255) for c in color])
        else:
            sub_colors.append([c for c in color])
    return sub_colors


def bboxes_draw_on_img(img, classes, scores, bboxes, colors, thickness=2):
    shape = img.shape
    for i in range(bboxes.shape[0]):
        bbox = bboxes[i]
        color = colors[classes[i]]
        # Draw the bounding box...
        p1 = (int(bbox[0] * shape[0]), int(bbox[1] * shape[1]))
        p2 = (int(bbox[2] * shape[0]), int(bbox[3] * shape[1]))
        cv2.rectangle(img, p1[::-1], p2[::-1], color, thickness)
        # Draw the text...
        s = '%s/%.3f' % (l_VOC_CLASS[int(classes[i]) - 1], scores[i])
        p1 = (p1[0] - 5, p1[1])
        # cv2.putText(img, s, p1[::-1], cv2.FONT_HERSHEY_DUPLEX, 1.5, color, 3)


colors_plasma = colors_subselect(mpcm.plasma.colors, num_classes=21)


# Main pipeline function.
def process_image(img, case, select_threshold=0.15, nms_threshold=0.1, net_shape=(300, 300)):
    # select_threshold: box threshold -- a box whose class prediction score is
    #     above this value is considered to have successfully framed an object.
    # nms_threshold: overlap threshold -- if two boxes for the same object
    #     overlap by more than this, the de-duplication step below removes one.
    # Run the SSD model: image_4d is the 4-D input, predictions/localisations
    # are the class and coordinate predictions, and rbbox_img is the maximum
    # detection extent, fixed here to [0, 0, 1, 1], i.e. the whole image.
    rimg, rpredictions, rlocalisations, rbbox_img = isess.run(
        [image_4d, predictions, localisations, bbox_img],
        feed_dict={img_input: img})

    # ssd_bboxes_select() takes, for each feature layer, the class prediction
    # scores, the normalised coordinates and the anchor boxes, and applies the
    # threshold to obtain each layer's detections with class and coordinates.
    rclasses, rscores, rbboxes = np_methods.ssd_bboxes_select(
        rpredictions, rlocalisations, ssd_anchors,
        select_threshold=select_threshold,
        img_shape=net_shape,
        num_classes=21, decode=True)
    """
    This function does quite a lot, so in more detail: its inputs, for each of
    the six feature layers, are
        rpredictions: the class predictions,
        rlocalisations: the coordinate predictions,
        ssd_anchors: the anchor-box data,
    where the class and coordinate predictions are per box, per pixel of the
    feature layer, and the anchor data holds the per-box correction terms.
    From the coordinate predictions and the anchors, the function computes each
    box's centre and size, applying a small empirical correction to obtain more
    accurate coordinates. Any box whose class prediction score then exceeds the
    threshold is considered to have framed an object, and its data is exported:
        rclasses: the class,
        rscores: the class score,
        rbboxes: the coordinates.
    Note that the same object may be detected on several feature layers with
    slightly different boxes; the duplicates are not removed here but by a
    dedicated function below.
    """
    # Clip boxes that extend beyond the detection edge.
    rbboxes = np_methods.bboxes_clip(rbbox_img, rbboxes)
    rclasses, rscores, rbboxes = np_methods.bboxes_sort(rclasses, rscores, rbboxes, top_k=400)
    # Non-maximum suppression: drop duplicate detections of the same object.
    rclasses, rscores, rbboxes = np_methods.bboxes_nms(rclasses, rscores, rbboxes, nms_threshold=nms_threshold)
    # Map the box coordinates back onto the original image (all coordinates
    # above are normalised, so this is the inverse operation).
    rbboxes = np_methods.bboxes_resize(rbbox_img, rbboxes)

    if case == 1:
        bboxes_draw_on_img(img, rclasses, rscores, rbboxes, colors_plasma, thickness=8)
        return img
    else:
        return rclasses, rscores, rbboxes


"""
# Localisation only, without analysing the predictions:
case = 1
img = cv2.imread("../demo/person.jpg")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
plt.imshow(process_image(img, case))
plt.show()
"""
# Localisation plus analysis of the predictions.
case = 2
path = '../VOC2007/JPEGImages/166.bmp'
# Read the image.
img = mpimg.imread(path)
# Run the main pipeline function.
rclasses, rscores, rbboxes = process_image(img, case)
# visualization.bboxes_draw_on_img(img, rclasses, rscores, rbboxes, visualization.colors_plasma)
# Show the classification result.
visualization.plt_bboxes(img, rclasses, rscores, rbboxes)
```

 

The prediction results are as follows:

(Figure: original image)

(Figure: prediction result)

 

When I have more time, I will fine-tune the parameters to get a better detection result.

 

The full project code is available at the link below:

Link: https://pan.baidu.com/s/1EDWix2XvzF8URTxlbNLJCA
Extraction code: 3kyb
 


Reposted from: https://blog.csdn.net/jiugeshao/article/details/116902463