-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdetect_mnist.py
32 lines (28 loc) · 1.07 KB
/
detect_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
#================================================================
#
# File name : detect_mnist.py
# Author : PyLessons
# Created date: 2020-08-12
# Website : https://pylessons.com/
# GitHub : https://github.com/pythonlessons/TensorFlow-2.x-YOLOv3
# Description : mnist object detection example
#
#================================================================
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import cv2
import numpy as np
import random
import time
import tensorflow as tf
from yolov3.yolov4 import Create_Yolo
from yolov3.utils import detect_image
from yolov3.configs import *
def load_image_paths(label_txt):
    """Return the image path from every non-blank line of an annotation file.

    Each line is split on whitespace and its first field is taken as the
    image path; the remaining fields (presumably box annotations — not
    used here) are ignored. Blank lines are skipped so a trailing newline
    cannot produce an empty entry.

    Args:
        label_txt: path to the annotation text file.

    Returns:
        A list of image path strings, in file order.
    """
    with open(label_txt, encoding="utf-8") as label_file:
        return [line.split()[0] for line in label_file if line.strip()]


def main(label_txt="mnist/mnist_test.txt", output_path="mnist_test.jpg"):
    """Run the trained YOLO model on random test images in an endless loop.

    The model is built and its weights loaded once, before the loop. The
    annotation file is read once, and each iteration picks a random image
    from it.

    Args:
        label_txt: annotation file listing the candidate test images.
        output_path: third positional argument to ``detect_image``
            (presumably where the annotated image is written — confirm
            against ``yolov3.utils.detect_image``).

    Raises:
        ValueError: if ``label_txt`` contains no image entries.
    """
    image_paths = load_image_paths(label_txt)
    if not image_paths:
        raise ValueError(f"no image entries found in {label_txt!r}")

    # Build the network and load weights once; doing this per iteration
    # rebuilds the whole model for every image.
    yolo = Create_Yolo(input_size=YOLO_INPUT_SIZE, CLASSES=TRAIN_CLASSES)
    yolo.load_weights(f"./checkpoints/{TRAIN_MODEL_NAME}") # use keras weights

    while True:
        # Sample from every entry; a fixed randint(0, 200) index crashes on
        # files shorter than 201 lines and never reaches later ones.
        image_path = random.choice(image_paths)
        detect_image(yolo, image_path, output_path, input_size=YOLO_INPUT_SIZE, show=True, CLASSES=TRAIN_CLASSES, rectangle_colors=(255,0,0))


if __name__ == "__main__":
    main()