
【Keras+计算机视觉+Tensorflow】实现基于YOLO和Deep Sort的目标检测与跟踪实战(附源码和数据集)

        YOLO是端到端的物体检测深度卷积神经网络。YOLO可以一次性预测多个候选框,并直接在输出层回归物体位置区域和区域内物体所属类别;而Faster R-CNN仍然沿用R-CNN那种将物体位置区域框与物体类别分开训练的思想,只是利用RPN网络,将提取候选框的步骤放在深度卷积神经网络内部实现。YOLO最大的优势就是速度快,可满足端到端训练和实时检测要求。

二、Deep Sort多目标跟踪算法 












4:目标跟踪部分Deep Sort 

  1. import os
  2. import numpy as np
  3. import copy
  4. import colorsys
  5. from timeit import default_timer as timer
  6. from keras import backend as K
  7. from keras.models import load_model
  8. from keras.layers import Input
  9. from PIL import Image, ImageFont, ImageDraw
  10. from nets.yolo4 import yolo_body,yolo_eval
  11. from utils.utils import letterbox_image
  12. #--------------------------------------------#
  13. # 使用自己训练好的模型预测需要修改2个参数
  14. # model_path和classes_path都需要修改!
  15. #--------------------------------------------#
  16. class YOLO( object):
  17. _defaults = {
  18. "model_path" : 'model_data/yolo4_weight.h5',
  19. "anchors_path" : 'model_data/yolo_anchors.txt',
  20. "classes_path" : 'model_data/coco_classes.txt',
  21. "score" : 0.5,
  22. "iou" : 0.3,
  23. "max_boxes" : 100,
  24. # 显存比较小可以使用416x416
  25. # 显存比较大可以使用608x608
  26. "model_image_size" : ( 416, 416)
  27. }
  28. @classmethod
  29. def get_defaults( cls, n):
  30. if n in cls._defaults:
  31. return cls._defaults[n]
  32. else:
  33. return "Unrecognized attribute name '" + n + "'"
  34. #---------------------------------------------------#
  35. # 初始化yolo
  36. #---------------------------------------------------#
  37. def __init__( self, **kwargs):
  38. self.__dict__.update(self._defaults)
  39. self.class_names = self._get_class()
  40. self.anchors = self._get_anchors()
  41. self.sess = K.get_session()
  42. self.boxes, self.scores, self.classes = self.generate()
  43. #---------------------------------------------------#
  44. # 获得所有的分类
  45. #---------------------------------------------------#
  46. def _get_class( self):
  47. classes_path = os.path.expanduser(self.classes_path)
  48. with open(classes_path) as f:
  49. class_names = f.readlines()
  50. class_names = [c.strip() for c in class_names]
  51. return class_names
  52. #---------------------------------------------------#
  53. # 获得所有的先验框
  54. #---------------------------------------------------#
  55. def _get_anchors( self):
  56. anchors_path = os.path.expanduser(self.anchors_path)
  57. with open(anchors_path) as f:
  58. anchors = f.readline()
  59. anchors = [ float(x) for x in anchors.split( ',')]
  60. return np.array(anchors).reshape(- 1, 2)
  61. #---------------------------------------------------#
  62. # 获得所有的分类
  63. #---------------------------------------------------#
  64. def generate( self):
  65. model_path = os.path.expanduser(self.model_path)
  66. assert model_path.endswith( '.h5'), 'Keras model or weights must be a .h5 file.'
  67. # 计算anchor数量
  68. num_anchors = len(self.anchors)
  69. num_classes = len(self.class_names)
  70. # 载入模型,如果原来的模型里已经包括了模型结构则直接载入。
  71. # 否则先构建模型再载入
  72. try:
  73. self.yolo_model = load_model(model_path, compile= False)
  74. except:
  75. self.yolo_model = yolo_body(Input(shape=( None, None, 3)), num_anchors// 3, num_classes)
  76. self.yolo_model.load_weights(self.model_path)
  77. else:
  78. assert self.yolo_model.layers[- 1].output_shape[- 1] == \
  79. num_anchors/ len(self.yolo_model.output) * (num_classes + 5), \
  80. 'Mismatch between model and given anchor and class sizes'
  81. print( '{} model, anchors, and classes loaded.'. format(model_path))
  82. # 画框设置不同的颜色
  83. hsv_tuples = [(x / len(self.class_names), 1., 1.)
  84. for x in range( len(self.class_names))]
  85. self.colors = list( map( lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
  86. self.colors = list(
  87. map( lambda x: ( int(x[ 0] * 255), int(x[ 1] * 255), int(x[ 2] * 255)),
  88. self.colors))
  89. # 打乱颜色
  90. np.random.seed( 10101)
  91. np.random.shuffle(self.colors)
  92. np.random.seed( None)
  93. self.input_image_shape = K.placeholder(shape=( 2, ))
  94. boxes, scores, classes = yolo_eval(self.yolo_model.output, self.anchors,
  95. num_classes, self.input_image_shape, max_boxes = self.max_boxes,
  96. score_threshold = self.score, iou_threshold = self.iou)
  97. return boxes, scores, classes
  98. '''
  99. 函数名称:detect_image
  100. 函数作用:目标跟踪程序(YOLO V4)
  101. 函数输入:frame: 图像
  102. 函数输出:
  103. boxs_person:行人检测框【x1, y1, w, h】
  104. boxs_others:其他类别检测框【x1, y1, x2, y2】
  105. labels_others: 其他框的类别
  106. '''
  107. def detect_image( self, image):
  108. new_image_size = (self.model_image_size[ 1],self.model_image_size[ 0])
  109. boxed_image = letterbox_image(image, new_image_size)
  110. image_data = np.array(boxed_image, dtype= 'float32')
  111. image_data /= 255.
  112. image_data = np.expand_dims(image_data, 0) # Add batch dimension.
  113. boxs_person = []
  114. boxs_others = []
  115. labels_others = []
  116. # 预测结果
  117. out_boxes, out_scores, out_classes = self.sess.run(
  118. [self.boxes, self.scores, self.classes],
  119. feed_dict={
  120. self.yolo_model. input: image_data,
  121. self.input_image_shape: [image.size[ 1], image.size[ 0]],
  122. K.learning_phase(): 0
  123. })
  124. for i, c in list( enumerate(out_classes)):
  125. predicted_class = self.class_names[c]
  126. box = out_boxes[i]
  127. score = out_scores[i]
  128. top, left, bottom, right = box
  129. ###输入deepSort的格式如下###
  130. box_deepsort = [left,top,right-left,bottom-top]
  131. box_other = [left,top,right,bottom]
  132. if predicted_class == 'person':
  133. boxs_person.append(box_deepsort)
  134. else:
  135. boxs_others.append(box_other)
  136. labels_others.append(predicted_class)
  137. return boxs_person,boxs_others,labels_others
  138. def close_session( self):
  139. self.sess.close()

 Deep Sort算法代码

  1. #!python3
  2. #--coding:utf8--
  3. from yolo import YOLO
  4. from PIL import Image
  5. import os
  6. import sys
  7. import time
  8. import logging
  9. import random
  10. from random import randint
  11. import math
  12. import statistics
  13. import getopt
  14. from ctypes import *
  15. import numpy as np
  16. import cv2
  17. from deep_sort import nn_matching
  18. from deep_sort.detection import Detection
  19. from deep_sort.tracker import Tracker
  20. from tools import generate_detections as gdet
  21. from deep_sort.detection import Detection as ddet
  22. from collections import deque
  23. from deep_sort import preprocessing
  24. '''
  25. 函数名称:track_deepsort
  26. 函数作用:目标跟踪程序
  27. 函数输入:frame:图像
  28. boxs_person:行人检测框【x1,y1,w,h】
  29. boxs_others:其他类别检测框【x1,y1,x2,y2】
  30. labels_others:其他框的类别
  31. encoder:跟踪器的编码器
  32. tracker: 跟踪器
  33. pts: 运动点初始化值
  34. show_results:是否显示结果
  35. 函数输出:tracker 跟踪器
  36. pts 运动轨迹
  37. '''
  38. def track_deepsort( frame,boxs_person,boxs_others,labels_others,encoder,tracker,pts,show_results=True):
  39. nms_max_overlap = 1.0
  40. features = encoder(frame, boxs_person)
  41. detections = [Detection(bbox, 1.0, feature) for bbox, feature in zip(boxs_person, features)]
  42. boxes = np.array([d.tlwh for d in detections])
  43. scores = np.array([d.confidence for d in detections])
  44. indices = preprocessing.non_max_suppression(boxes, nms_max_overlap, scores)
  45. detections = [detections[i] for i in indices]
  46. # 跟踪
  47. tracker.predict()
  48. tracker.update(detections)
  49. i = int( 0)
  50. indexIDs = []
  51. ##########结果显示###########
  52. if show_results:
  53. for det in detections:
  54. bbox = det.to_tlbr()
  55. cv2.rectangle(frame, ( int(bbox[ 0]), int(bbox[ 1])), ( int(bbox[ 2]), int(bbox[ 3])), ( 255, 255, 255), 2)
  56. for ii in range( len(boxs_others)):
  57. bbox = boxs_others[ii]
  58. label = labels_others[ii]
  59. cv2.rectangle(frame, ( int(bbox[ 0]), int(bbox[ 1])), ( int(bbox[ 2]), int(bbox[ 3])), ( 0, 255, 255), 2)
  60. cv2.putText(frame, str(label), ( int(bbox[ 0]), int(bbox[ 1])), cv2.FONT_HERSHEY_COMPLEX, 0.5, ( 0, 255, 255), 2)
  61. for track in tracker.tracks:
  62. if not track.is_confirmed() or track.time_since_update > 1:
  63. continue
  64. # boxes.append([track[0], track[1], track[2], track[3]])
  65. indexIDs.append( int(track.track_id))
  66. bbox = track.to_tlbr()
  67. cv2.rectangle(frame, ( int(bbox[ 0]), int(bbox[ 1])), ( int(bbox[ 2]), int(bbox[ 3])), ( 0, 255, 0), 3)
  68. cv2.putText(frame, str(track.track_id), ( int(bbox[ 0]), int(bbox[ 1] - 50)), 0, 5e-3 * 150, ( 0, 255, 0), 2)
  69. i = i + 1
  70. center = ( int(((bbox[ 0]) + (bbox[ 2])) / 2), int(((bbox[ 1]) + (bbox[ 3])) / 2))
  71. pts[track.track_id].append(center)
  72. # draw motion path
  73. for j in range( 1, len(pts[track.track_id])):
  74. if pts[track.track_id][j - 1] is None or pts[track.track_id][j] is None:
  75. continue
  76. thickness = int(np.sqrt( 64 / float(j + 1)) * 2)
  77. cv2.line(frame, (pts[track.track_id][j - 1]), (pts[track.track_id][j]), ( 0, 255, 255), thickness)
  78. return tracker,pts
  79. if __name__ == "__main__":
  80. yolo = YOLO()
  81. ####设置跟踪参数###
  82. max_cosine_distance = 0.5
  83. nn_budget = 20
  84. metric = nn_matching.NearestNeighborDistanceMetric( "cosine", max_cosine_distance,
  85. nn_budget) # 最近邻距离度量,对于每个目标,返回到目前为止已观察到的任何样本的最近距离(欧式或余弦)。
  86. tracker = Tracker(metric) # 由距离度量方法构造一个 Tracker。
  87. writeVideo_flag = False
  88. ###轨迹点定义##
  89. pts = [deque(maxlen= 30) for _ in range( 9999)]
  90. model_filename = './model_data/mars-small128.pb' ###DeepSort 模型位置##
  91. encoder = gdet.create_box_encoder(model_filename, batch_size= 1)
  92. Obj_centre = [[] for i in range( 200)]
  93. Obj_pre_direction = [[] for i in range( 200)]
  94. ShowFlag = True ##是否显示结果
  95. ####打开摄像机###
  96. # 创建VideoCapture,传入0即打开系统默认摄像头
  97. # cap = cv2.VideoCapture(0)
  98. #######读取视频######################################
  99. video_path = 'structure.mp4'
  100. video_capture = cv2.VideoCapture(video_path)
  101. key = ''
  102. count = 0
  103. save_path = './saveimg/'
  104. while key != 113: # for 'q' key
  105. ###读取图像###
  106. ret, frame = video_capture.read()
  107. #######目标检测########################
  108. frame2 = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGRA2RGBA))
  109. boxs_person,boxs_others,labels_others = yolo.detect_image(frame2)
  110. #######目标跟踪########################
  111. tracker, pts = track_deepsort(frame, boxs_person, boxs_others, labels_others, encoder, tracker, pts)
  112. #######显示检测及跟踪结果####
  113. cv2.namedWindow( "YOLO3_Deep_SORT", 0)
  114. cv2.resizeWindow( 'YOLO3_Deep_SORT', 1024, 768)
  115. cv2.imshow( 'YOLO3_Deep_SORT', frame)
  116. cv2.waitKey( 3)
  117. count += 1
  118. jpg_name = os.path.join(save_path, str(count).zfill( 6)+ '.jpg')
  119. cv2.imwrite(jpg_name,frame)


结果显示,在目标检测环节,当人群交叉、光照突变时可能出现漏检的现象,这将导致目标跟踪环节出现跟踪错误。应该进一步调整目标跟踪策略,使目标跟踪算法具有更强的鲁棒性,尤其是解决人员聚集情况下的目标跟踪问题。

