Skip to content

Commit

Permalink
generate_webcam.py fix
Browse files Browse the repository at this point in the history
(use --precision "webcam" in training)
  • Loading branch information
nicolai256 authored Oct 24, 2022
1 parent 8655e90 commit b48c612
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions generate_webcam.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import torch.utils.data
import torch
import torch.nn as nn
import cv2


Expand All @@ -16,8 +17,11 @@
parser.add_argument("--resolution", type=int, nargs=2, metavar=('width', 'height'), default=(480, 640))
parser.add_argument("--show_original", type=int, default=0)
parser.add_argument("--resize", type=int, default=256)
parser.add_argument("--webcam_number", type=int, default=0)
parser.add_argument("--video_path", type=str)
args = parser.parse_args()



generator = (torch.load(args.checkpoint, map_location=lambda storage, loc: storage))
generator.eval()

Expand All @@ -29,19 +33,27 @@
generator = generator.type(torch.half)

transform = build_transform()

cap = cv2.VideoCapture(0)
if args.video_path:
cap = cv2.VideoCapture(args.video_path)
else:
cap = cv2.VideoCapture(args.webcam_number)
width, height = args.resolution
cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)

#cap.set(cv2.CAP_PROP_CHANNEL, 1)
frame_counter = 0
while True:
ret, frame = cap.read()
if not ret:
cap.release()
cv2.destroyAllWindows()
exit()

frame_counter += 1
#If the last frame is reached, reset the capture and the frame_counter
if args.video_path:
if frame_counter == cap.get(cv2.CAP_PROP_FRAME_COUNT):
frame_counter = 0 #Or whatever as long as it is the same as next line
cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
x = int(frame.shape[0] / 2)
y = int(frame.shape[1] / 2)
res = min(x, y)
Expand All @@ -53,6 +65,7 @@
if device.lower() != "cpu":
net_in = net_in.type(torch.half)
net_out = generator(net_in)

im = ((net_out[0].clamp(-1, 1) + 1) * 127.5).permute((1, 2, 0)).cpu().data.numpy().astype(np.uint8)
im = cv2.cvtColor(cv2.resize(im, (2*res, 2*res)), cv2.COLOR_RGB2BGR)
if args.show_original == 1:
Expand All @@ -63,5 +76,4 @@
cap.release()
cv2.destroyAllWindows()
exit()



0 comments on commit b48c612

Please sign in to comment.