小言_互联网的博客

百度PaddlePaddle_OCR文字识别_准确率98%

606人阅读  评论(0)

这篇文章主要介绍如何将百度 PaddlePaddle 的 OCR 文字识别功能整理、打包成基于 Flask 框架的 Web API。最终实现的效果是：传入图片的 base64 编码，返回识别出的字符串。

其他开源算法应用

图像识别
GoogLeNet、MobileNet

语音识别
MASR中文语音识别

对象检测
YOLO深度学习框架

自然语言处理
谷歌BERT

源码如下。如需讲解、完整思路说明或配置文件，请通过我其他文章中的联系方式获取。

import argparse
import base64
import hashlib
import hmac
import json
import logging as logger
import math
import os
import sys
import time
from threading import Lock, Thread

import cv2
import numpy as np
import paddle.fluid as fluid
import requests
from flask import request, Flask, Request
from paddle.fluid.core_avx import AnalysisConfig, create_paddle_predictor

# Directory containing this script. Named _CUR_DIR rather than ``__dir__``: since
# Python 3.7 (PEP 562) a module-level ``__dir__`` is *called* by dir(module), so
# binding it to a string makes dir() on this module raise TypeError.
_CUR_DIR = os.path.dirname(os.path.abspath(__file__))

# NOTE(review): run_simple is imported but not used anywhere in this module.
from werkzeug.serving import run_simple

# Make this script's directory and the directory two levels up importable.
# NOTE(review): none of the imports visible in this file appear to rely on these paths.
sys.path.append(_CUR_DIR)
sys.path.append(os.path.abspath(os.path.join(_CUR_DIR, '../..')))


class CharacterOps(object):
    """Convert between text labels and the integer indices a recognition model emits.

    Keys read from ``config``:
        character_type: "en" (built-in 36-character dictionary) or "ch" (dictionary file).
        loss_type: stored as-is; only "attention" is accepted by get_beg_end_flag_idx.
        max_text_length: stored as ``max_text_len``.
        character_dict_path: dictionary file path, required when character_type is "ch".
        use_space_char: optional; for "ch", append a space character to the dictionary.

    Raises:
        ValueError: if character_type is neither "en" nor "ch".
    """

    def __init__(self, config):
        self.character_type = config['character_type']
        self.loss_type = config['loss_type']
        self.max_text_len = config['max_text_length']
        if self.character_type == "en":
            # Default dictionary: 10 digits followed by 26 lowercase letters.
            self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
        elif self.character_type == "ch":
            # Custom dictionary loaded from a file.
            self.character_str = self._read_dict_file(config['character_dict_path'])
            if config.get('use_space_char', False):
                self.character_str += " "
        else:
            # Explicit raise (not assert) so the check survives ``python -O``, and report
            # the offending type rather than the never-assigned dictionary string.
            raise ValueError(
                "Nonsupport type of the character: {}".format(self.character_type))
        dict_character = list(self.character_str)
        self.beg_str = "sos"
        self.end_str = "eos"
        # char -> index; for a repeated character the later position wins.
        self.dict = {char: i for i, char in enumerate(dict_character)}
        self.character = dict_character

    @staticmethod
    def _read_dict_file(character_dict_path):
        """Return the UTF-8 lines of a dictionary file concatenated, CR/LF removed.

        Only line-break characters are stripped, so a line consisting of a single space
        still contributes that space (presumably one dictionary entry per line).
        """
        with open(character_dict_path, "rb") as fin:
            return "".join(line.decode('utf-8').strip("\r\n") for line in fin)

    def decode(self, text_index, is_remove_duplicate=False):
        """
        convert text-index into text-label.
        Args:
            text_index: sequence of integer indices for one image
            is_remove_duplicate: whether to drop an index equal to the one just before
                                 it (CTC-style collapsing). The default is False
        Return:
            text: text label
        """
        # Index == dictionary size is treated as the blank token and never emitted.
        ignored_tokens = [self.get_char_num()]
        char_list = []
        for idx, token in enumerate(text_index):
            if token in ignored_tokens:
                continue
            if is_remove_duplicate and idx > 0 and text_index[idx - 1] == token:
                continue
            char_list.append(self.character[int(token)])
        return ''.join(char_list)

    def get_char_num(self):
        """
        Get character num (size of the dictionary)
        """
        return len(self.character)

    def get_beg_end_flag_idx(self, beg_or_end):
        """Return the index of the start ("beg") or end ("end") token as a numpy array.

        Raises:
            ValueError: if loss_type is not "attention" or beg_or_end is not "beg"/"end".

        NOTE(review): __init__ never inserts ``beg_str``/``end_str`` into ``self.dict``
        (its keys are single dictionary characters), so for "attention" the lookup
        below raises KeyError as written.
        """
        if self.loss_type != "attention":
            raise ValueError(
                "error in get_beg_end_flag_idx when using the loss %s" % (self.loss_type))
        if beg_or_end == "beg":
            return np.array(self.dict[self.beg_str])
        if beg_or_end == "end":
            return np.array(self.dict[self.end_str])
        raise ValueError("Unsupport type %s in get_beg_end_flag_idx" % beg_or_end)


def create_predictor(args):
    """Build a CPU Paddle inference predictor for the recognition model.

    The model is loaded from the files ``__model__`` and ``params`` inside
    ``args.rec_model_dir``; when that option is empty (its default) they are looked up
    relative to the current working directory, as before. ``args.use_gpu`` is not
    consulted: the GPU is always disabled.

    Args:
        args: namespace from parse_args; reads rec_model_dir, enable_mkldnn and
            use_zero_copy_run.

    Returns:
        (predictor, input_tensor, output_tensors): the predictor, the handle of its
        last-listed input tensor (None if it has no inputs), and the list of handles
        of all its output tensors.

    Exits the process with status 1 (failure) if either model file is missing.
    """
    model_dir = getattr(args, 'rec_model_dir', '') or ''
    model_file_path = os.path.join(model_dir, "__model__")
    params_file_path = os.path.join(model_dir, "params")
    if not os.path.exists(model_file_path):
        logger.error("not find __model__ file path {}".format(model_file_path))
        sys.exit(1)
    if not os.path.exists(params_file_path):
        logger.error("not find params file path {}".format(params_file_path))
        sys.exit(1)

    config = AnalysisConfig(model_file_path, params_file_path)

    config.disable_gpu()
    config.set_cpu_math_library_num_threads(6)
    if args.enable_mkldnn:
        config.set_mkldnn_cache_capacity(10)
        config.enable_mkldnn()

    config.disable_glog_info()

    if args.use_zero_copy_run:
        # Zero-copy mode feeds tensors directly, so the feed/fetch ops are disabled.
        config.delete_pass("conv_transpose_eltwiseadd_bn_fuse_pass")
        config.switch_use_feed_fetch_ops(False)
    else:
        config.switch_use_feed_fetch_ops(True)

    predictor = create_paddle_predictor(config)
    input_tensor = None
    for name in predictor.get_input_names():
        input_tensor = predictor.get_input_tensor(name)
    output_tensors = [predictor.get_output_tensor(output_name)
                      for output_name in predictor.get_output_names()]
    return predictor, input_tensor, output_tensors


def initial_logger():
    """Configure root logging at INFO level and return this module's named logger.

    ``basicConfig`` only takes effect when the root logger has no handlers yet.
    """
    log_format = '%(asctime)s-%(levelname)s: %(message)s'
    logger.basicConfig(format=log_format, level=logger.INFO)
    return logger.getLogger(__name__)


class TextRecognizer(object):
    """Recognize the text in cropped text-line images with a CTC recognition model.

    Configured from the namespace returned by ``parse_args``. Calling an instance on a
    list of images returns ``(rec_res, predict_time)``: ``rec_res[i]`` is
    ``[text, score]`` for ``img_list[i]`` and ``predict_time`` is the summed wall time
    of the predictor runs, in seconds.
    """

    def __init__(self, args):
        if args.use_pdserving is False:
            # Local inference. When use_pdserving is set, no predictor is created here,
            # and __call__ (which needs self.predictor / self.use_zero_copy_run) cannot run.
            self.predictor, self.input_tensor, self.output_tensors = \
                create_predictor(args)
            self.use_zero_copy_run = args.use_zero_copy_run
        # Network input shape "C, H, W", e.g. "3, 32, 320".
        self.rec_image_shape = [int(v) for v in args.rec_image_shape.split(",")]
        self.character_type = args.rec_char_type
        self.rec_batch_num = args.rec_batch_num
        self.rec_algorithm = args.rec_algorithm
        self.text_len = args.max_text_length
        char_ops_params = {
            "character_type": args.rec_char_type,
            "character_dict_path": args.rec_char_dict_path,
            "use_space_char": args.use_space_char,
            "max_text_length": args.max_text_length,
            'loss_type': 'ctc',
        }

        self.loss_type = 'ctc'
        self.char_ops = CharacterOps(char_ops_params)

    def resize_norm_img(self, img, max_wh_ratio):
        """Resize *img* to the model height, scale to [-1, 1] and right-pad with zeros.

        Args:
            img: (H, W, C) image array whose C equals the configured channel count.
            max_wh_ratio: largest width/height ratio in the current batch; for "ch"
                models the padded width grows with it so the whole batch shares it.

        Returns:
            float32 array of shape (C, imgH, padded_width).
        """
        imgC, imgH, imgW = self.rec_image_shape
        assert imgC == img.shape[2]
        wh_ratio = max(max_wh_ratio, imgW * 1.0 / imgH)
        if self.character_type == "ch":
            # Widen the canvas to the batch's widest ratio. Uses the configured height
            # (previously a hard-coded 32) so models with other input heights size correctly.
            imgW = int(imgH * wh_ratio)
        h, w = img.shape[:2]
        ratio = w / float(h)
        # Keep the aspect ratio at height imgH, but never exceed the canvas width.
        if math.ceil(imgH * ratio) > imgW:
            resized_w = imgW
        else:
            resized_w = int(math.ceil(imgH * ratio))
        resized_image = cv2.resize(img, (resized_w, imgH))
        resized_image = resized_image.astype('float32')
        # HWC -> CHW, then map [0, 255] to [-1, 1].
        resized_image = resized_image.transpose((2, 0, 1)) / 255
        resized_image -= 0.5
        resized_image /= 0.5
        padding_im = np.zeros((imgC, imgH, imgW), dtype=np.float32)
        padding_im[:, :, 0:resized_w] = resized_image
        return padding_im

    def __call__(self, img_list):
        """Recognize every image in *img_list*, batching lines of similar aspect ratio.

        Args:
            img_list: list of (H, W, C) image arrays, one text line each.

        Returns:
            (rec_res, predict_time) as described on the class. Lines whose every step
            decodes to the blank class keep ``['', 0.0]``.
        """
        img_num = len(img_list)
        # Sorting by width/height ratio groups similar lines, so batches need less padding.
        width_list = [img.shape[1] / float(img.shape[0]) for img in img_list]
        indices = np.argsort(np.array(width_list))

        # One independent [text, score] list per image (not ``[[...]] * n`` aliasing one list).
        rec_res = [['', 0.0] for _ in range(img_num)]
        batch_num = self.rec_batch_num
        predict_time = 0
        for beg_img_no in range(0, img_num, batch_num):
            end_img_no = min(img_num, beg_img_no + batch_num)
            batch_ids = indices[beg_img_no:end_img_no]
            # The widest aspect ratio in the batch decides the shared padded width.
            max_wh_ratio = 0
            for ino in batch_ids:
                h, w = img_list[ino].shape[0:2]
                max_wh_ratio = max(max_wh_ratio, w * 1.0 / h)
            norm_img_batch = np.concatenate(
                [self.resize_norm_img(img_list[ino], max_wh_ratio)[np.newaxis, :]
                 for ino in batch_ids],
                axis=0)
            norm_img_batch = norm_img_batch.copy()

            starttime = time.time()
            if self.use_zero_copy_run:
                self.input_tensor.copy_from_cpu(norm_img_batch)
                self.predictor.zero_copy_run()
            else:
                norm_img_batch = fluid.core.PaddleTensor(norm_img_batch)
                self.predictor.run([norm_img_batch])

            # Output 0 holds label indices (decoded via char_ops); output 1 holds per-step
            # class scores. Level-0 LoD offsets delimit each image's rows in both.
            rec_idx_batch = self.output_tensors[0].copy_to_cpu()
            rec_idx_lod = self.output_tensors[0].lod()[0]
            predict_batch = self.output_tensors[1].copy_to_cpu()
            predict_lod = self.output_tensors[1].lod()[0]
            predict_time += time.time() - starttime
            for rno in range(len(rec_idx_lod) - 1):
                rec_idx_tmp = rec_idx_batch[rec_idx_lod[rno]:rec_idx_lod[rno + 1], 0]
                preds_text = self.char_ops.decode(rec_idx_tmp)
                probs = predict_batch[predict_lod[rno]:predict_lod[rno + 1], :]
                ind = np.argmax(probs, axis=1)
                # The last class is treated as blank; score = mean top score of the
                # non-blank steps.
                blank = probs.shape[1]
                valid_ind = np.where(ind != (blank - 1))[0]
                if len(valid_ind) == 0:
                    continue
                score = np.mean(probs[valid_ind, ind[valid_ind]])
                rec_res[indices[beg_img_no + rno]] = [preds_text, score]

        return rec_res, predict_time


def parse_args(argv=None):
    """Parse the prediction-engine and text-recognizer options.

    Args:
        argv: list of argument strings to parse. None (the default) parses
            ``sys.argv[1:]`` exactly as before; passing a list allows calling this
            programmatically without touching the process command line.

    Returns:
        argparse.Namespace holding every option below.
    """

    def str2bool(v):
        # argparse hands over raw strings; only "true"/"t"/"1" (any case) mean True.
        return v.lower() in ("true", "t", "1")

    parser = argparse.ArgumentParser()
    # params for prediction engine
    parser.add_argument("--use_gpu", type=str2bool, default=False)

    # params for text recognizer
    parser.add_argument("--rec_algorithm", type=str, default='CRNN')
    parser.add_argument("--rec_model_dir", type=str, default='')
    parser.add_argument("--rec_image_shape", type=str, default="3, 32, 320")
    parser.add_argument("--rec_char_type", type=str, default='ch')
    parser.add_argument("--rec_batch_num", type=int, default=120)
    parser.add_argument("--max_text_length", type=int, default=25)
    parser.add_argument(
        "--rec_char_dict_path",
        type=str,
        default="ppocr_keys_v1.txt")
    parser.add_argument("--use_space_char", type=str2bool, default=True)

    parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
    parser.add_argument("--use_zero_copy_run", type=str2bool, default=False)

    parser.add_argument("--use_pdserving", type=str2bool, default=False)

    return parser.parse_args(argv)


def base64_to_image(base64_code):
    """Decode base64-encoded image file bytes into a 3-channel image matrix.

    ``cv2.IMREAD_COLOR`` yields channels in BGR order (not RGB).

    Raises:
        binascii.Error: if ``base64_code`` is not valid base64.
        ValueError: if the decoded bytes are not an image OpenCV can decode
            (cv2.imdecode signals this by returning None instead of raising).
    """
    img_data = base64.b64decode(base64_code)
    img_array = np.frombuffer(img_data, np.uint8)
    img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError('decoded base64 payload is not a readable image')
    return img


# Building a TextRecognizer loads the Paddle model from disk, so instances are cached
# (one per distinct argument set) and reused across requests instead of rebuilt each time.
_RECOGNIZER_CACHE = {}
# Guards the cache and every call into a cached recognizer, since request threads now
# share one predictor. NOTE(review): assumes the Paddle predictor is not safe to run
# concurrently from several threads.
_RECOGNIZER_LOCK = Lock()


def _get_text_recognizer(args):
    """Return the cached TextRecognizer for *args*, creating it on first use.

    Must be called with _RECOGNIZER_LOCK held. A failed construction is not cached.
    """
    key = repr(sorted(vars(args).items()))
    recognizer = _RECOGNIZER_CACHE.get(key)
    if recognizer is None:
        recognizer = TextRecognizer(args)
        _RECOGNIZER_CACHE[key] = recognizer
    return recognizer


def main(args, image_str):
    """Recognize the text in one base64-encoded image.

    Args:
        args: namespace from parse_args, used to configure the recognizer.
        image_str: base64-encoded image file bytes.

    Returns:
        The recognized text on success, otherwise one of the sentinel strings
        'img_str' (the payload could not be decoded into an image) or
        'text_recognizer' (recognition raised or produced no result).
    """
    try:
        img = base64_to_image(image_str)
    except Exception as e:
        print(e)
        return 'img_str'
    if img is None:
        # cv2.imdecode reports undecodable bytes by returning None rather than raising.
        return 'img_str'
    img_list = [img]
    try:
        with _RECOGNIZER_LOCK:
            text_recognizer = _get_text_recognizer(args)
            rec_res, predict_time = text_recognizer(img_list)
    except Exception as e:
        print(e)
        return 'text_recognizer'
    if not rec_res:
        return 'text_recognizer'
    print("Predict:%s" % (rec_res[0]))
    print("Total predict time for %d images:%.3f" %
          (len(img_list), predict_time))
    return rec_res[0][0]


# Flask application object; routes are registered on it with @app.route.
app = Flask(import_name='ocr')


def _expected_ocr_code():
    """Return the md5 hex digest a client must send as ``code``, or None if unconfigured.

    The shared secret is read from the ``OCR_MD5_STR`` environment variable. It falls
    back to a module-level ``md5_str``, the name this handler always referenced but
    which is never assigned in this module. Deployments that inject it keep working,
    and an unconfigured server rejects requests instead of crashing with NameError.
    """
    secret = os.environ.get('OCR_MD5_STR') or globals().get('md5_str')
    if not secret:
        return None
    return hashlib.new('md5', secret.encode(encoding='UTF-8')).hexdigest()


def _ocr_code_is_valid(code):
    """Return True if the client-supplied *code* matches the expected digest."""
    expected = _expected_ocr_code()
    if expected is None or not isinstance(code, str):
        # No secret configured, or a non-string code that can never equal the hex digest.
        return False
    # Constant-time comparison so response timing does not leak the expected digest.
    return hmac.compare_digest(code.encode('utf-8'), expected.encode('utf-8'))


@app.route('/ocr', methods=['POST'])  # OCR endpoint
def ocr():
    """Handle POST /ocr: verify the caller's code, then recognize a base64 image.

    Expects a JSON object body with ``code`` (md5 hex digest of the shared secret) and
    ``image`` (base64-encoded image file bytes). Always answers with a JSON string:
    ``status`` 1 plus ``data`` (recognized text) on success; otherwise ``status`` 0
    (malformed request) or -1 (rejected / failed) plus ``msg``.
    """
    try:
        # May raise for a malformed body (and, in newer Flask versions, for a
        # non-JSON content type).
        json_str = request.json
    except Exception as e:
        print(e)
        return json.dumps({'status': 0, 'msg': 'json wrong!'})
    if not json_str:
        return json.dumps({'status': 0, 'msg': 'json is null'})
    if not isinstance(json_str, dict):
        # A JSON array or scalar has no keys to look up.
        return json.dumps({'status': 0, 'msg': 'json wrong!'})
    if 'code' not in json_str:
        return json.dumps({'status': 0, 'msg': 'Missing parameter'})
    if 'image' not in json_str:
        return json.dumps({'status': -1, 'msg': 'image is null'})
    image_str = json_str['image']
    if not image_str:
        return json.dumps({
            'status': -1,
            'msg': 'The parameter is empty or the parameter is not standard'
        })
    if not _ocr_code_is_valid(json_str['code']):
        return json.dumps({'status': -1, 'msg': 'Code verification failed'})

    rec_res = main(parse_args(), image_str)
    if rec_res == 'img_str':
        print('该base64字符串无法解析')
        return json.dumps({'status': -1, 'msg': 'The base64 string cannot be parsed'})
    if rec_res == 'text_recognizer':
        print('图片识别异常')
        return json.dumps({'status': -1, 'msg': 'The picture is not recognized'})
    return json.dumps({'status': 1, 'data': rec_res})


def application(url=None, interval=10, timeout=10):
    """Send a GET request to *url* forever, printing each response body.

    NOTE(review): the default URL looks like a task-status / heartbeat endpoint; confirm.

    Args:
        url: URL to request. When None, the module-level ``url_bert`` is used; it is
            assigned only in the ``__main__`` block.
        interval: seconds to sleep between requests.
        timeout: per-request timeout in seconds, so a stalled server cannot hang the loop.

    Network errors are printed and the loop continues, rather than ending the thread.
    """
    if url is None:
        url = url_bert
    while True:
        try:
            dd = requests.get(url, timeout=timeout)
            print(dd.text)
        except requests.RequestException as e:
            print(e)
        time.sleep(interval)


def start_app(host='192.168.0.128', port=52013):
    """Run the Flask development server for ``app``; blocks until the server stops.

    Args:
        host: interface address to bind (defaults to the previously hard-coded address).
        port: TCP port to listen on (defaults to the previously hard-coded port).
    """
    app.run(host, port=port)


if __name__ == '__main__':
    print('start app server!')
    # Module-level URL that application() reads when called without arguments.
    url_bert = 'http://192.168.0.128:8080/HT/api/TaskSave?task=PROCPPS.OCR文字识别服务&notice=1&key=202101041434'
    # Web server and polling loop run side by side in non-daemon threads, which keep
    # the process alive after this block finishes.
    server_thread = Thread(target=start_app)
    server_thread.start()
    heartbeat_thread = Thread(target=application)
    heartbeat_thread.start()
    # Printed right after the threads start, not when the server shuts down.
    print('end app server!')


转载自：https://blog.csdn.net/qq_30803353/article/details/113754183
查看评论
* 以上用户言论只代表其个人观点,不代表本网站的观点或立场