小言_互联网的博客

TensorFlow Serving之导出自己的训练模型

479人阅读  评论(0)

0 背景

在《TensorFlow Serving之安装及调用方法》中,介绍了tensorflow serving的基本概念和安装调用方法,本文介绍如何导出自己的训练模型,生成服务所需的*pb模型和variables文件夹。

在《TensorFlow之目标检测API接口调试(超详细)》中,介绍了如何准备自己的训练数据以及如何导出训练模型,在第5步导出模型中,发现只生成了saved_model.pb模型,而variables文件夹为空,而这个文件夹中的文件对于服务而言是必不可少的,因此本文的第一步工作是如何转化生成这些文件

1 转化方法

在github上有类似的问题提问export as Savedmodel generating empty variables directory #1988,下边有人指出,只需要修改exporter.py文件中的write_saved_modle函数即可

def write_saved_model(saved_model_path,
                      trained_checkpoint_prefix,
                      inputs,
                      outputs):

  saver = tf.train.Saver()
  with tf.Session() as sess:
    saver.restore(sess, trained_checkpoint_prefix)

    builder = tf.saved_model.builder.SavedModelBuilder(saved_model_path)

    tensor_info_inputs = {
        'inputs': tf.saved_model.utils.build_tensor_info(inputs)}
    tensor_info_outputs = {}
    for k, v in outputs.items():
      tensor_info_outputs[k] = tf.saved_model.utils.build_tensor_info(v)

    detection_signature = (
        tf.saved_model.signature_def_utils.build_signature_def(
            inputs=tensor_info_inputs,
            outputs=tensor_info_outputs,
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
        ))

    builder.add_meta_graph_and_variables(
        sess,
        [tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            tf.saved_model.signature_constants
            .DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                detection_signature,
        },
    )
    builder.save()

同时将_export_inference_graph函数中最后调用write_saved_modle的地方修改为如下

  write_saved_model(saved_model_path, trained_checkpoint_prefix,
                    placeholder_tensor, outputs)

然后就可以导出生成variables文件夹。至此,如果你能成功导出,那么恭喜你已经完成任务了,但我在实现的时候,报错如下

Traceback (most recent call last):
  File "export_inference_graph_unfrozen.py", line 173, in <module>
    tf.app.run()
  File "/home/lthpc/anaconda3/envs/tensorflow/lib/python3.5/site-packages/tensorflow/python/platform/app.py", line 40, in run
    _run(main=main, argv=argv, flags_parser=_parse_flags_tolerate_undef)
  File "/home/lthpc/anaconda3/envs/tensorflow/lib/python3.5/site-packages/absl/app.py", line 299, in run
    _run_main(main, args)
  File "/home/lthpc/anaconda3/envs/tensorflow/lib/python3.5/site-packages/absl/app.py", line 250, in _run_main
    sys.exit(main(argv))
  File "export_inference_graph_unfrozen.py", line 169, in main
    FLAGS.output_directory)
  File "export_inference_graph_unfrozen.py", line 155, in export_inference_graph
    optimize_graph, output_collection_name)
  File "export_inference_graph_unfrozen.py", line 107, in _export_inference_graph
    output_tensors = detection_model.predict(preprocessed_inputs)
TypeError: predict() missing 1 required positional argument: 'true_image_shapes'

在github有人遇到过类似的问题,他说将python版本切换为2.7后就可以导出了,我尝试了下,果然成功了,因此接下来配置一个2.7版本的tensorflow环境,专门用来生成服务模型

2 环境配置

所用到的环境为:tensorflow-gpu 1.12.0 、cuda9.0、python2.7,从头开始配置环境

conda create -n py2.7 pip python=2.7
source activate py2.7
pip install --upgrade https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.12.0-cp27-none-linux_x86_64.whl
pip install tensorflow-estimator==1.10.12
pip install matplotlib
pip install pillow

配置好python2.7的环境后,下载目标检测的API代码

git clone https://github.com/tensorflow/models.git
cd models-master/research
protoc object_detection/protos/*.proto --python_out=.     
export PYTHONPATH=$PYTHONPATH:`pwd`:`pwd`/slim

按照第1步中的转化方法,修改exporter.py函数(两处修改),就可以运行export_inference_graph.py,导出模型了

python export_inference_graph.py --input_type image_tensor --pipeline_config_path mymodel/faster_rcnn_resnet50_coco.config --trained_checkpoint_prefix mymodel/input/model.ckpt-50000 --output_directory mymodel/output/

(python2.7) lthpc@lthpc:~/workspace_zong/tensorflow_serving/models-master/research/object_detection/mymodel$ tree 
.
├── faster_rcnn_resnet50_coco.config
├── input
│   ├── checkpoint
│   ├── model.ckpt-50000.data-00000-of-00001
│   ├── model.ckpt-50000.index
│   └── model.ckpt-50000.meta
└── output
    ├── checkpoint
    ├── frozen_inference_graph.pb
    ├── model.ckpt.data-00000-of-00001
    ├── model.ckpt.index
    ├── model.ckpt.meta
    ├── pipeline.config
    └── saved_model
        ├── saved_model.pb
        └── variables
            ├── variables.data-00000-of-00001
            └── variables.index

 有了上述模型,就可以进行下一步,将训练好的模型导入tensorflow服务当中了


转载:https://blog.csdn.net/zong596568821xp/article/details/101210524
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场