From f3e476c3007ceac416bcd7cdd88104ec3b8899b7 Mon Sep 17 00:00:00 2001 From: Will S Date: Wed, 5 Jan 2022 12:16:06 +0000 Subject: [PATCH] Add tflite_infer.py --- README.md | 4 +- load_model/tflite_loader.py | 41 ++++++++++ tflite_infer.py | 148 ++++++++++++++++++++++++++++++++++++ 3 files changed, 191 insertions(+), 2 deletions(-) create mode 100644 load_model/tflite_loader.py create mode 100644 tflite_infer.py diff --git a/README.md b/README.md index a403105..9210460 100644 --- a/README.md +++ b/README.md @@ -71,8 +71,8 @@ python pytorch_infer.py --img-mode 0 --video-path /path/to/video python pytorch_infer.py --img-mode 0 --video-path 0 ``` ### TensorFlow/Keras/MXNet/Caffe -The other four frameworks running method is similar to pytorch, just replace `pytorch`with `tensorflow`, `keras`,`caffe`,`mxnet`, -if you want to use tensorflow, just run: +The other four frameworks running method is similar to pytorch, just replace `pytorch` with `tensorflow`, `tflite`, `keras`,`caffe` or `mxnet`. +For instance if you want to use tensorflow, just run: ``` python tensorflow_infer.py --img-path /path/to/your/img ``` diff --git a/load_model/tflite_loader.py b/load_model/tflite_loader.py new file mode 100644 index 0000000..017a2ec --- /dev/null +++ b/load_model/tflite_loader.py @@ -0,0 +1,41 @@ +# -*- encoding=utf-8 -*- +import tflite_runtime.interpreter as tflite + +import numpy as np + +def load_tflite_model(tflite_model_path): + ''' + Load the model. + :param tflite_model_path: model to tflite model. + :return: interpreter and tensor indexes + ''' + interpreter = tflite.Interpreter(tflite_model_path) + interpreter.allocate_tensors() + + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + input_idx = input_details[0]['index'] + bbox_idx = output_details[0]['index'] + scores_idx = output_details[1]['index'] + + return interpreter, (input_idx, (bbox_idx, scores_idx)) + + +def tflite_inference(interpreter, indexes, img_arr): + ''' + Receive an image array and run inference + :param interpreter: tflite interpreter. + :param indexes: tflite tensor indexes. + :param img_arr: 3D numpy array, RGB order. + :return: + ''' + input_data = np.array(img_arr, dtype=np.float32) + interpreter.set_tensor(indexes[0], input_data) + + interpreter.invoke() + + bboxes = interpreter.get_tensor(indexes[1][0]) + scores = interpreter.get_tensor(indexes[1][1]) + + return bboxes, scores diff --git a/tflite_infer.py b/tflite_infer.py new file mode 100644 index 0000000..83da529 --- /dev/null +++ b/tflite_infer.py @@ -0,0 +1,148 @@ +# -*- coding:utf-8 -*- +import cv2 +import time +import argparse + +import numpy as np +from PIL import Image +#from keras.models import model_from_json +from utils.anchor_generator import generate_anchors +from utils.anchor_decode import decode_bbox +from utils.nms import single_class_non_max_suppression +from load_model.tflite_loader import load_tflite_model, tflite_inference + +interpreter, indexes = load_tflite_model('models/face_mask_detection.tflite') +# anchor configuration +feature_map_sizes = [[33, 33], [17, 17], [9, 9], [5, 5], [3, 3]] +anchor_sizes = [[0.04, 0.056], [0.08, 0.11], [0.16, 0.22], [0.32, 0.45], [0.64, 0.72]] +anchor_ratios = [[1, 0.62, 0.42]] * 5 + +# generate anchors +anchors = generate_anchors(feature_map_sizes, anchor_sizes, anchor_ratios) + +# for inference , the batch size is 1, the model output shape is [1, N, 4], +# so we expand dim for anchors to [1, anchor_num, 4] +anchors_exp = np.expand_dims(anchors, axis=0) + +id2class = {0: 'Mask', 1: 'NoMask'} + + +def inference(image, + conf_thresh=0.5, + iou_thresh=0.4, + target_shape=(160, 160), + draw_result=True, + show_result=True + ): + ''' + Main function of detection inference + :param image: 3D numpy array of image + :param conf_thresh: the min threshold of classification probabity. + :param iou_thresh: the IOU threshold of NMS + :param target_shape: the model input size. + :param draw_result: whether to daw bounding box to the image. + :param show_result: whether to display the image. + :return: + ''' + # image = np.copy(image) + output_info = [] + height, width, _ = image.shape + image_resized = cv2.resize(image, target_shape) + image_np = image_resized / 255.0 # 归一化到0~1 + image_exp = np.expand_dims(image_np, axis=0) + y_bboxes_output, y_cls_output = tflite_inference(interpreter, indexes, image_exp) + + # remove the batch dimension, for batch is always 1 for inference. + y_bboxes = decode_bbox(anchors_exp, y_bboxes_output)[0] + y_cls = y_cls_output[0] + # To speed up, do single class NMS, not multiple classes NMS. + bbox_max_scores = np.max(y_cls, axis=1) + bbox_max_score_classes = np.argmax(y_cls, axis=1) + + # keep_idx is the alive bounding box after nms. + keep_idxs = single_class_non_max_suppression(y_bboxes, + bbox_max_scores, + conf_thresh=conf_thresh, + iou_thresh=iou_thresh, + ) + + for idx in keep_idxs: + conf = float(bbox_max_scores[idx]) + class_id = bbox_max_score_classes[idx] + bbox = y_bboxes[idx] + # clip the coordinate, avoid the value exceed the image boundary. + xmin = max(0, int(bbox[0] * width)) + ymin = max(0, int(bbox[1] * height)) + xmax = min(int(bbox[2] * width), width) + ymax = min(int(bbox[3] * height), height) + + if draw_result: + if class_id == 0: + color = (0, 255, 0) + else: + color = (255, 0, 0) + cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, 2) + cv2.putText(image, "%s: %.2f" % (id2class[class_id], conf), (xmin + 2, ymin - 2), + cv2.FONT_HERSHEY_SIMPLEX, 0.8, color) + output_info.append([class_id, conf, xmin, ymin, xmax, ymax]) + + if show_result: + Image.fromarray(image).show() + return output_info + + +def run_on_video(video_path, output_video_name, conf_thresh): + cap = cv2.VideoCapture(video_path) + height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) + width = cap.get(cv2.CAP_PROP_FRAME_WIDTH) + fps = cap.get(cv2.CAP_PROP_FPS) + fourcc = cv2.VideoWriter_fourcc(*'XVID') + # writer = cv2.VideoWriter(output_video_name, fourcc, int(fps), (int(width), int(height))) + total_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT) + if not cap.isOpened(): + raise ValueError("Video open failed.") + return + status = True + idx = 0 + while status: + start_stamp = time.time() + status, img_raw = cap.read() + img_raw = cv2.cvtColor(img_raw, cv2.COLOR_BGR2RGB) + read_frame_stamp = time.time() + if (status): + inference(img_raw, + conf_thresh, + iou_thresh=0.5, + target_shape=(260, 260), + draw_result=True, + show_result=False) + cv2.imshow('image', img_raw[:, :, ::-1]) + cv2.waitKey(1) + inference_stamp = time.time() + # writer.write(img_raw) + write_frame_stamp = time.time() + idx += 1 + print("%d of %d" % (idx, total_frames)) + print("read_frame:%f, infer time:%f, write time:%f" % (read_frame_stamp - start_stamp, + inference_stamp - read_frame_stamp, + write_frame_stamp - inference_stamp)) + # writer.release() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Face Mask Detection") + parser.add_argument('--img-mode', type=int, default=1, help='set 1 to run on image, 0 to run on video.') + parser.add_argument('--img-path', type=str, help='path to your image.') + parser.add_argument('--video-path', type=str, default='0', help='path to your video, `0` means to use camera.') + # parser.add_argument('--hdf5', type=str, help='keras hdf5 file') + args = parser.parse_args() + if args.img_mode: + imgPath = args.img_path + img = cv2.imread(imgPath) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + inference(img, show_result=True, target_shape=(260, 260)) + else: + video_path = args.video_path + if args.video_path == '0': + video_path = 0 + run_on_video(video_path, '', conf_thresh=0.5)