# export_larq_model.py
# (Captured from a GitHub file view: 74 lines, 62 loc, 3.18 KB.)
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import cv2
import numpy as np
import larq_compute_engine as lce
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants
from yolov3.dataset import Dataset
from yolov3.yolov4 import Create_Yolo
from yolov3.utils import load_yolo_weights, postprocess_boxes, nms_no_gather
from yolov3.configs import *
import shutil
import json
import time
# Let TensorFlow grow GPU memory on demand for the first visible GPU instead
# of reserving all of it up front. TensorFlow raises RuntimeError if memory
# growth is set after the device has already been initialised; that case is
# reported (with the underlying message) rather than aborting the export.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        tf.config.experimental.set_memory_growth(gpus[0], True)
    except RuntimeError as err:
        print(f"RuntimeError in tf.config.experimental.set_memory_growth: {err}")
# Custom Keras layer, for easy exporting
class PostProcess(tf.keras.layers.Layer):
    """Keras layer that appends box post-processing and NMS to a detector.

    Wrapping the post-processing in a layer lets it be traced into the same
    Keras graph as the detector, so it becomes part of the model that is
    converted for export.

    Args:
        iou_threshold: IoU threshold forwarded to ``nms_no_gather``.
        score_threshold: score threshold forwarded to ``postprocess_boxes``.
        **kwargs: forwarded unchanged to ``tf.keras.layers.Layer``.
    """

    def __init__(self, iou_threshold, score_threshold, **kwargs):
        # Initialise the base Layer first (as Keras recommends) so its
        # attribute tracking exists before any attributes are assigned.
        super().__init__(**kwargs)
        self.iou_threshold = iou_threshold
        self.score_threshold = score_threshold

    def post_prediction_process(self, pred_boxes):
        """Filter and split raw predictions and run NMS on them.

        Args:
            pred_boxes: prediction tensor whose last dimension holds the
                per-box values; all leading dimensions are flattened.

        Returns:
            Tuple ``(boxes, box_scores, box_classes, selected_indices)``:
            ``boxes`` holds the first 4 values of every post-processed row,
            ``box_scores`` the 5th value, ``box_classes`` the 6th value cast
            to int32, and ``selected_indices`` is the result of
            ``nms_no_gather``. NOTE(review): the rows are NOT gathered by
            ``selected_indices`` here; consumers presumably apply the
            indices themselves — confirm against the inference code.
        """
        # Collapse all leading dimensions so each prediction is one row while
        # keeping the last (per-box values) dimension intact.
        flattened_boxes = tf.reshape(pred_boxes, (-1, tf.shape(pred_boxes)[-1]))
        boxes = postprocess_boxes(flattened_boxes, score_threshold=self.score_threshold)
        selected_indices = nms_no_gather(boxes, iou_threshold=self.iou_threshold)
        # The split requires postprocess_boxes to return exactly 6 values per
        # row, assumed to be 4 box coordinates, a score and a class id.
        boxes, box_scores, box_classes = tf.split(boxes, (4, 1, 1), axis=-1)
        box_scores = tf.squeeze(box_scores, axis=-1)
        box_classes = tf.squeeze(tf.cast(box_classes, dtype=tf.int32), axis=-1)
        return boxes, box_scores, box_classes, selected_indices

    def call(self, y_pred):
        """Apply ``post_prediction_process`` to the detector predictions."""
        return self.post_prediction_process(y_pred)

    def get_config(self):
        """Return the layer config including both thresholds.

        Without this, serialising a model containing this layer would drop
        the constructor arguments and the layer could not be re-created.
        """
        config = super().get_config()
        config.update({
            "iou_threshold": self.iou_threshold,
            "score_threshold": self.score_threshold,
        })
        return config
def _darknet_weights_path():
    """Return the Darknet weights file configured for the current YOLO_TYPE.

    Raises:
        ValueError: if YOLO_TYPE has no Darknet weights mapping here.
    """
    if YOLO_TYPE == "yolov4":
        return YOLO_V4_TINY_WEIGHTS if TRAIN_YOLO_TINY else YOLO_V4_WEIGHTS
    if YOLO_TYPE == "yolov3":
        return YOLO_V3_TINY_WEIGHTS if TRAIN_YOLO_TINY else YOLO_V3_WEIGHTS
    if YOLO_TYPE == "yolov2":
        return YOLO_V2_WEIGHTS
    raise ValueError(f"No Darknet weights configured for YOLO_TYPE={YOLO_TYPE!r}")


def _build_yolo_model():
    """Build the YOLO Keras model described by the (star-imported) config globals.

    Returns:
        The model from ``Create_Yolo``, loaded with Darknet weights when
        ``YOLO_CUSTOM_WEIGHTS`` is falsy, otherwise with the checkpoint at
        ``./checkpoints/{TRAIN_MODEL_NAME}``.

    Raises:
        NotImplementedError: if ``YOLO_FRAMEWORK == "trt"``.
        ValueError: if ``YOLO_FRAMEWORK`` (or, for Darknet weights,
            ``YOLO_TYPE``) is not recognised.
    """
    if YOLO_FRAMEWORK == "trt":
        # tf.saved_model.load(...).signatures['serving_default'] is a
        # ConcreteFunction, not a Keras model: it has no `.input`/`.output`
        # to wrap with PostProcess, and lce.convert_keras_model needs a Keras
        # model. Fail clearly instead of with an AttributeError.
        raise NotImplementedError(
            "YOLO_FRAMEWORK='trt' cannot be exported with larq_compute_engine: "
            "the SavedModel signature is not a Keras model. "
            "Set YOLO_FRAMEWORK='tf' to export."
        )
    if YOLO_FRAMEWORK != "tf":
        raise ValueError(f"Unsupported YOLO_FRAMEWORK: {YOLO_FRAMEWORK!r}")

    if not YOLO_CUSTOM_WEIGHTS:
        yolo = Create_Yolo(input_size=YOLO_INPUT_SIZE, CLASSES=YOLO_COCO_CLASSES)
        load_yolo_weights(yolo, _darknet_weights_path())  # use Darknet weights
    else:
        yolo = Create_Yolo(input_size=YOLO_INPUT_SIZE, CLASSES=TRAIN_CLASSES)
        yolo.load_weights(f"./checkpoints/{TRAIN_MODEL_NAME}")  # use custom weights
    return yolo


if __name__ == '__main__':
    yolo = _build_yolo_model()
    # Append post-processing/NMS so the converted graph outputs boxes,
    # scores, class ids and the NMS-selected indices directly.
    post_processed_output = PostProcess(TEST_IOU_THRESHOLD, TEST_SCORE_THRESHOLD)(yolo.output)
    yolo = tf.keras.models.Model(yolo.input, post_processed_output)
    flatbuffer_bytes = lce.convert_keras_model(yolo)

    # export
    exported_model_path = f'checkpoints/{TRAIN_MODEL_NAME}.tflite'
    # The checkpoints directory may not exist yet (e.g. when only Darknet
    # weights were used), so create it before writing.
    os.makedirs(os.path.dirname(exported_model_path), exist_ok=True)
    with open(exported_model_path, "wb") as flatbuffer_file:
        flatbuffer_file.write(flatbuffer_bytes)
    print(f'exported to: {exported_model_path}')