飞道的博客

5 分钟掌握 Python 中的 Hook 钩子函数

438人阅读  评论(0)

1. 什么是Hook

经常会听到钩子函数(hook function)这个概念,最近在看目标检测开源框架mmdetection,里面也出现大量Hook的编程方式,那到底什么是hook?hook的作用是什么?

  • what is hook ?钩子hook,顾名思义,可以理解是一个挂钩,作用是有需要的时候挂一个东西上去。具体的解释是:钩子函数是把我们自己实现的hook函数在某一时刻挂接到目标挂载点上。

  • hook函数的作用 举个例子,hook的概念在windows桌面软件开发很常见,特别是各种事件触发的机制; 比如C++的MFC程序中,要监听鼠标左键按下的时间,MFC提供了一个onLeftKeyDown的钩子函数。很显然,MFC框架并没有为我们实现onLeftKeyDown具体的操作,只是为我们提供一个钩子,当我们需要处理的时候,只要去重写这个函数,把我们需要操作挂载在这个钩子里,如果我们不挂载,MFC事件触发机制中执行的就是空的操作。

从上面可知

  • hook函数是程序中预定义好的函数,这个函数处于原有程序流程当中(暴露一个钩子出来)

  • 我们需要再在有流程中钩子定义的函数块中实现某个具体的细节,需要把我们的实现,挂接或者注册(register)到钩子里,使得hook函数对目标可用

  • hook 是一种编程机制,和具体的语言没有直接的关系

  • 如果从设计模式上看,hook模式是模板方法的扩展

  • 钩子只有注册的时候,才会使用,所以原有程序的流程中,没有注册或挂载时,执行的是空(即没有执行任何操作)

本文用python来解释hook的实现方式,并展示在开源项目中hook的应用案例。hook函数和我们常听到另外一个名称:回调函数(callback function)功能是类似的,可以按照同种模式来理解。

2. hook实现例子

据我所知,hook函数最常使用在某种流程处理当中。这个流程往往有很多步骤。hook函数常常挂载在这些步骤中,为增加额外的一些操作,提供灵活性。

下面举一个简单的例子,这个例子的目的是实现一个通用往队列中插入内容的功能。流程步骤有2个

  • 需要再插入队列前,对数据进行筛选 input_filter_fn

  • 插入队列 insert_queue


   
  1. class ContentStash(object):
  2.      "" "
  3.     content stash for online operation
  4.     pipeline is
  5.     1. input_filter: filter some contents, no use to user
  6.     2. insert_queue(redis or other broker): insert useful content to queue
  7.     " ""
  8.     def __init__(self):
  9.         self.input_filter_fn = None
  10.         self.broker = []
  11.     def register_input_filter_hook(self, input_filter_fn):
  12.          "" "
  13.         register input filter function, parameter is content dict
  14.         Args:
  15.             input_filter_fn: input filter function
  16.         Returns:
  17.         " ""
  18.         self.input_filter_fn = input_filter_fn
  19.     def insert_queue(self, content):
  20.          "" "
  21.         insert content to queue
  22.         Args:
  23.             content: dict
  24.         Returns:
  25.         " ""
  26.         self.broker. append(content)
  27.     def input_pipeline(self, content, use=False):
  28.          "" "
  29.         pipeline of input for content stash
  30.         Args:
  31.             use: is use, defaul False
  32.             content: dict
  33.         Returns:
  34.         " ""
  35.          if not use:
  36.              return
  37.         # input filter
  38.          if self.input_filter_fn:
  39.             _filter = self.input_filter_fn(content)
  40.             
  41.         # insert to queue
  42.          if not _filter:
  43.             self.insert_queue(content)
  44. # test
  45. ## 实现一个你所需要的钩子实现:比如如果content 包含time就过滤掉,否则插入队列
  46. def input_filter_hook(content):
  47.      "" "
  48.     test input filter hook
  49.     Args:
  50.         content: dict
  51.     Returns: None or content
  52.     " ""
  53.      if content.get( 'time') is None:
  54.          return
  55.      else:
  56.          return content
  57. # 原有程序
  58. content = { 'filename''test.jpg''b64_file'"#test"'data': { "result""cat""probility"0.9}}
  59. content_stash = ContentStash( 'audit', work_dir= '')
  60. # 挂上钩子函数, 可以有各种不同钩子函数的实现,但是要主要函数输入输出必须保持原有程序中一致,比如这里是content
  61. content_stash.register_input_filter_hook(input_filter_hook)
  62. # 执行流程
  63. content_stash.input_pipeline(content)

3. hook在开源框架中的应用

3.1 keras

在深度学习训练流程中,hook函数体现的淋漓尽致。

一个训练过程(不包括数据准备),会轮询多次训练集,每次称为一个epoch,每个epoch又分为多个batch来训练。流程先后拆解成:

  • 开始训练

  • 训练一个epoch前

  • 训练一个batch前

  • 训练一个batch后

  • 训练一个epoch后

  • 评估验证集

  • 结束训练

这些步骤是穿插在训练一个batch数据的过程中,这些可以理解成是钩子函数,我们可能需要在这些钩子函数中实现一些定制化的东西,比如在训练一个epoch后我们要保存下训练的模型,在结束训练时用最好的模型执行下测试集的效果等等。

keras中是通过各种回调函数来实现钩子hook功能的。这里放一个callback的父类,定制时只要继承这个父类,实现你过关注的钩子就可以了。


   
  1. @keras_export( 'keras.callbacks.Callback')
  2. class Callback(object):
  3.    "" "Abstract base class used to build new callbacks.
  4.   Attributes:
  5.       params: Dict. Training parameters
  6.           (eg. verbosity, batch size, number of epochs...).
  7.       model: Instance of `keras.models.Model`.
  8.           Reference of the model being trained.
  9.   The `logs` dictionary that callback methods
  10.   take as argument will contain keys for quantities relevant to
  11.   the current batch or epoch (see method-specific docstrings).
  12.   " ""
  13.   def __init__(self):
  14.     self.validation_data = None  # pylint: disable=g-missing-from-attributes
  15.     self.model = None
  16.     # Whether this Callback should only run on the chief worker in a
  17.     # Multi-Worker setting.
  18.     # TODO(omalleyt): Make this attr public once solution is stable.
  19.     self._chief_worker_only = None
  20.     self._supports_tf_logs = False
  21.   def set_params(self, params):
  22.     self.params = params
  23.   def set_model(self, model):
  24.     self.model = model
  25.   @doc_controls.for_subclass_implementers
  26.   @generic_utils. default
  27.   def on_batch_begin(self, batch, logs=None):
  28.      "" "A backwards compatibility alias for `on_train_batch_begin`." ""
  29.   @doc_controls.for_subclass_implementers
  30.   @generic_utils. default
  31.   def on_batch_end(self, batch, logs=None):
  32.      "" "A backwards compatibility alias for `on_train_batch_end`." ""
  33.   @doc_controls.for_subclass_implementers
  34.   def on_epoch_begin(self, epoch, logs=None):
  35.      "" "Called at the start of an epoch.
  36.     Subclasses should override for any actions to run. This function should only
  37.     be called during TRAIN mode.
  38.     Arguments:
  39.         epoch: Integer, index of epoch.
  40.         logs: Dict. Currently no data is passed to this argument for this method
  41.           but that may change in the future.
  42.     " ""
  43.   @doc_controls.for_subclass_implementers
  44.   def on_epoch_end(self, epoch, logs=None):
  45.      "" "Called at the end of an epoch.
  46.     Subclasses should override for any actions to run. This function should only
  47.     be called during TRAIN mode.
  48.     Arguments:
  49.         epoch: Integer, index of epoch.
  50.         logs: Dict, metric results for this training epoch, and for the
  51.           validation epoch if validation is performed. Validation result keys
  52.           are prefixed with `val_`.
  53.     " ""
  54.   @doc_controls.for_subclass_implementers
  55.   @generic_utils. default
  56.   def on_train_batch_begin(self, batch, logs=None):
  57.      "" "Called at the beginning of a training batch in `fit` methods.
  58.     Subclasses should override for any actions to run.
  59.     Arguments:
  60.         batch: Integer, index of batch within the current epoch.
  61.         logs: Dict, contains the return value of `model.train_step`. Typically,
  62.           the values of the `Model`'s metrics are returned.  Example:
  63.           `{'loss': 0.2, 'accuracy': 0.7}`.
  64.     " ""
  65.     # For backwards compatibility.
  66.     self.on_batch_begin(batch, logs=logs)
  67.   @doc_controls.for_subclass_implementers
  68.   @generic_utils. default
  69.   def on_train_batch_end(self, batch, logs=None):
  70.      "" "Called at the end of a training batch in `fit` methods.
  71.     Subclasses should override for any actions to run.
  72.     Arguments:
  73.         batch: Integer, index of batch within the current epoch.
  74.         logs: Dict. Aggregated metric results up until this batch.
  75.     " ""
  76.     # For backwards compatibility.
  77.     self.on_batch_end(batch, logs=logs)
  78.   @doc_controls.for_subclass_implementers
  79.   @generic_utils. default
  80.   def on_test_batch_begin(self, batch, logs=None):
  81.      "" "Called at the beginning of a batch in `evaluate` methods.
  82.     Also called at the beginning of a validation batch in the `fit`
  83.     methods, if validation data is provided.
  84.     Subclasses should override for any actions to run.
  85.     Arguments:
  86.         batch: Integer, index of batch within the current epoch.
  87.         logs: Dict, contains the return value of `model.test_step`. Typically,
  88.           the values of the `Model`'s metrics are returned.  Example:
  89.           `{'loss': 0.2, 'accuracy': 0.7}`.
  90.     " ""
  91.   @doc_controls.for_subclass_implementers
  92.   @generic_utils. default
  93.   def on_test_batch_end(self, batch, logs=None):
  94.      "" "Called at the end of a batch in `evaluate` methods.
  95.     Also called at the end of a validation batch in the `fit`
  96.     methods, if validation data is provided.
  97.     Subclasses should override for any actions to run.
  98.     Arguments:
  99.         batch: Integer, index of batch within the current epoch.
  100.         logs: Dict. Aggregated metric results up until this batch.
  101.     " ""
  102.   @doc_controls.for_subclass_implementers
  103.   @generic_utils. default
  104.   def on_predict_batch_begin(self, batch, logs=None):
  105.      "" "Called at the beginning of a batch in `predict` methods.
  106.     Subclasses should override for any actions to run.
  107.     Arguments:
  108.         batch: Integer, index of batch within the current epoch.
  109.         logs: Dict, contains the return value of `model.predict_step`,
  110.           it typically returns a dict with a key 'outputs' containing
  111.           the model's outputs.
  112.     " ""
  113.   @doc_controls.for_subclass_implementers
  114.   @generic_utils. default
  115.   def on_predict_batch_end(self, batch, logs=None):
  116.      "" "Called at the end of a batch in `predict` methods.
  117.     Subclasses should override for any actions to run.
  118.     Arguments:
  119.         batch: Integer, index of batch within the current epoch.
  120.         logs: Dict. Aggregated metric results up until this batch.
  121.     " ""
  122.   @doc_controls.for_subclass_implementers
  123.   def on_train_begin(self, logs=None):
  124.      "" "Called at the beginning of training.
  125.     Subclasses should override for any actions to run.
  126.     Arguments:
  127.         logs: Dict. Currently no data is passed to this argument for this method
  128.           but that may change in the future.
  129.     " ""
  130.   @doc_controls.for_subclass_implementers
  131.   def on_train_end(self, logs=None):
  132.      "" "Called at the end of training.
  133.     Subclasses should override for any actions to run.
  134.     Arguments:
  135.         logs: Dict. Currently the output of the last call to `on_epoch_end()`
  136.           is passed to this argument for this method but that may change in
  137.           the future.
  138.     " ""
  139.   @doc_controls.for_subclass_implementers
  140.   def on_test_begin(self, logs=None):
  141.      "" "Called at the beginning of evaluation or validation.
  142.     Subclasses should override for any actions to run.
  143.     Arguments:
  144.         logs: Dict. Currently no data is passed to this argument for this method
  145.           but that may change in the future.
  146.     " ""
  147.   @doc_controls.for_subclass_implementers
  148.   def on_test_end(self, logs=None):
  149.      "" "Called at the end of evaluation or validation.
  150.     Subclasses should override for any actions to run.
  151.     Arguments:
  152.         logs: Dict. Currently the output of the last call to
  153.           `on_test_batch_end()` is passed to this argument for this method
  154.           but that may change in the future.
  155.     " ""
  156.   @doc_controls.for_subclass_implementers
  157.   def on_predict_begin(self, logs=None):
  158.      "" "Called at the beginning of prediction.
  159.     Subclasses should override for any actions to run.
  160.     Arguments:
  161.         logs: Dict. Currently no data is passed to this argument for this method
  162.           but that may change in the future.
  163.     " ""
  164.   @doc_controls.for_subclass_implementers
  165.   def on_predict_end(self, logs=None):
  166.      "" "Called at the end of prediction.
  167.     Subclasses should override for any actions to run.
  168.     Arguments:
  169.         logs: Dict. Currently no data is passed to this argument for this method
  170.           but that may change in the future.
  171.     " ""
  172.   def _implements_train_batch_hooks(self):
  173.      "" "Determines if this Callback should be called for each train batch." ""
  174.      return (not generic_utils.is_default(self.on_batch_begin) or
  175.             not generic_utils.is_default(self.on_batch_end) or
  176.             not generic_utils.is_default(self.on_train_batch_begin) or
  177.             not generic_utils.is_default(self.on_train_batch_end))

这些钩子的原始程序是在模型训练流程中的

keras源码位置: tensorflow\python\keras\engine\training.py

部分摘录如下(## I am hook):


   
  1. # Container that configures and calls  `tf.keras.Callback`s.
  2.        if not isinstance(callbacks, callbacks_module.CallbackList):
  3.         callbacks = callbacks_module.CallbackList(
  4.             callbacks,
  5.             add_history=True,
  6.             add_progbar=verbose !=  0,
  7.             model=self,
  8.             verbose=verbose,
  9.             epochs=epochs,
  10.             steps=data_handler.inferred_steps)
  11.       ## I am hook
  12.       callbacks.on_train_begin()
  13.       training_logs = None
  14.       # Handle fault-tolerance  for multi-worker.
  15.       # TODO(omalleyt): Fix the ordering issues that mean this has to
  16.       # happen after  `callbacks.on_train_begin`.
  17.       data_handler._initial_epoch = (  # pylint: disable=protected-access
  18.           self._maybe_load_initial_epoch_from_ckpt(initial_epoch))
  19.        for epoch, iterator in data_handler.enumerate_epochs():
  20.         self.reset_metrics()
  21.         callbacks.on_epoch_begin(epoch)
  22.         with data_handler.catch_stop_iteration():
  23.            for step in data_handler.steps():
  24.             with trace.Trace(
  25.                  'TraceContext',
  26.                 graph_type= 'train',
  27.                 epoch_num=epoch,
  28.                 step_num=step,
  29.                 batch_size=batch_size):
  30.               ## I am hook
  31.               callbacks.on_train_batch_begin(step)
  32.               tmp_logs = train_function(iterator)
  33.                if data_handler.should_sync:
  34.                 context.async_wait()
  35.               logs = tmp_logs  # No error, now safe to assign to logs.
  36.               end_step = step + data_handler.step_increment
  37.               callbacks.on_train_batch_end(end_step, logs)
  38.         epoch_logs =  copy. copy(logs)
  39.         # Run validation.
  40.         ## I am hook
  41.         callbacks.on_epoch_end(epoch, epoch_logs)

3.2 mmdetection

mmdetection是一个目标检测的开源框架,集成了许多不同的目标检测深度学习算法(pytorch版),如faster-rcnn, fpn, retianet等。里面也大量使用了hook,暴露给应用实现流程中具体部分。

详见https://github.com/open-mmlab/mmdetection

这里看一个训练的调用例子(摘录)(https://github.com/open-mmlab/mmdetection/blob/5d592154cca589c5113e8aadc8798bbc73630d98/mmdet/apis/train.py


   
  1. def train_detector(model,
  2.                    dataset,
  3.                    cfg,
  4.                    distributed=False,
  5.                    validate=False,
  6.                    timestamp=None,
  7.                    meta=None):
  8.     logger = get_root_logger(cfg.log_level)
  9.     # prepare data loaders
  10.     # put model on gpus
  11.     # build runner
  12.     optimizer = build_optimizer(model, cfg.optimizer)
  13.     runner = EpochBasedRunner(
  14.         model,
  15.         optimizer=optimizer,
  16.         work_dir=cfg.work_dir,
  17.         logger=logger,
  18.         meta=meta)
  19.     # an ugly workaround to  make .log and .log.json filenames the same
  20.     runner.timestamp = timestamp
  21.     # fp16 setting
  22.     # register hooks
  23.     runner.register_training_hooks(cfg.lr_config, optimizer_config,
  24.                                    cfg.checkpoint_config, cfg.log_config,
  25.                                    cfg.get( 'momentum_config', None))
  26.      if distributed:
  27.         runner.register_hook(DistSamplerSeedHook())
  28.     # register eval hooks
  29.      if validate:
  30.         # Support batch_size >  1 in validation
  31.         eval_cfg = cfg.get( 'evaluation', {})
  32.         eval_hook = DistEvalHook  if distributed  else EvalHook
  33.         runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
  34.     # user-defined hooks
  35.      if cfg.get( 'custom_hooks', None):
  36.         custom_hooks = cfg.custom_hooks
  37.         assert isinstance(custom_hooks, list), \
  38.             f 'custom_hooks expect list type, but got {type(custom_hooks)}'
  39.          for hook_cfg in cfg.custom_hooks:
  40.             assert isinstance(hook_cfg, dict), \
  41.                  'Each item in custom_hooks expects dict type, but got ' \
  42.                 f '{type(hook_cfg)}'
  43.             hook_cfg = hook_cfg. copy()
  44.             priority = hook_cfg.pop( 'priority''NORMAL')
  45.             hook = build_from_cfg(hook_cfg, HOOKS)
  46.             runner.register_hook(hook, priority=priority)

4. 总结

本文介绍了hook的概念和应用,并给出了python的实现细则。希望对比有帮助。总结如下:

  • hook函数是流程中预定义好的一个步骤,没有实现

  • 挂载或者注册时, 流程执行就会执行这个钩子函数

  • 回调函数和hook函数功能上是一致的

  • hook设计方式带来灵活性,如果流程中有一个步骤,你想让调用方来实现,你可以用hook函数

作者简介:wedo实验君, 数据分析师;热爱生活,热爱写作

赞 赏 作 者

推荐阅读

有人在代码里下毒!慎用 pip install 命令

让 Pandas DataFrame 性能飞升 40 倍

Flask框架钩子函数使用方式及应用场景分析

点击下方阅读原文加入社区会员

点赞鼓励一下


转载:https://blog.csdn.net/BF02jgtRS00XKtCx/article/details/110458293
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场