-
Notifications
You must be signed in to change notification settings - Fork 81
/
Copy path remote_dataflow.py
31 lines (25 loc) · 1.02 KB
/
remote_dataflow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# -*- coding: utf-8 -*-
# !/usr/bin/env python
import argparse
from tensorpack.dataflow.remote import send_dataflow_zmq
from data_load import DataLoader, AudioMeta
from hparam import hparam as hp
import multiprocessing
def resolve_num_thread(requested, cpu_count=None):
    """Return the number of dataflow threads to use.

    Args:
        requested: thread count from the command line. A falsy value
            (0 or None) means "derive it from the CPU count".
        cpu_count: number of CPUs the fallback is based on; defaults to
            ``multiprocessing.cpu_count()``.

    Returns:
        ``requested`` when it is truthy, otherwise an ``int`` equal to
        ``cpu_count // 1.5`` truncated, clamped to at least 1.
    """
    if requested:
        return requested
    if cpu_count is None:
        cpu_count = multiprocessing.cpu_count()
    # ``cpu_count // 1.5`` is a float (e.g. 5.0) and is 0.0 on a single-core
    # machine; convert to int and never ask for zero threads.
    return max(1, int(cpu_count // 1.5))


def main():
    """Parse CLI arguments, build the training dataflow and send it via ZMQ.

    Usage::

        python remote_dataflow.py CASE -dest_url URL [-data_path PATH] [-num_thread N]
    """
    # get arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('case', type=str, help='experiment case name.')
    parser.add_argument('-data_path', type=str,
                        help='overrides hp.train.data_path from the yaml file.')
    # Required: without it send_dataflow_zmq would receive None and fail
    # far from the actual mistake.
    parser.add_argument('-dest_url', type=str, required=True,
                        help='address passed to send_dataflow_zmq.')
    parser.add_argument('-num_thread', type=int, default=1,
                        help='number of dataflow threads; 0 derives it from the CPU count.')
    args = parser.parse_args()

    # set hyper-parameters from yaml file, then apply CLI overrides
    hp.set_hparam_yaml(case=args.case)
    if args.data_path:
        hp.train.data_path = args.data_path

    # dataflow
    audio_meta = AudioMeta(hp.train.data_path)
    # NOTE(review): the second DataLoader argument (1) is presumably a batch
    # size — confirm against data_load.DataLoader.
    data_loader = DataLoader(audio_meta, 1)
    # Use the resolved count (not args.num_thread directly) so that the
    # CPU-based fallback actually takes effect.
    num_thread = resolve_num_thread(args.num_thread)
    dataflow = data_loader.dataflow(nr_prefetch=5000, nr_thread=num_thread)
    send_dataflow_zmq(dataflow, args.dest_url)


if __name__ == '__main__':
    main()