forked from killthekitten/kaggle-carvana-2017
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathensemble_cpu.py
59 lines (44 loc) · 1.68 KB
/
ensemble_cpu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import threading
from scipy.misc.pilutil import imread, imsave
from params import args
import numpy as np
import os
from utils import ThreadsafeIter
def average_strategy(images):
    """Return the element-wise mean of a stack of images.

    Args:
        images: Array-like whose first axis indexes the individual images.

    Returns:
        The mean taken along the first axis, i.e. one averaged image.
    """
    # Unweighted average along axis 0 is simply the arithmetic mean.
    return np.mean(images, axis=0)
def hard_voting(images):
    """Combine a stack of masks by per-pixel majority vote.

    Every image is first binarised by rounding ``pixel / 255`` (pixel values
    are assumed to lie in the 0-255 range), then each output pixel is set to
    255 when more than half of the images voted "on" and to 0 otherwise.
    An exact tie (vote fraction of 0.5) resolves to 0 because ``np.round``
    rounds halves to the nearest even value.

    Args:
        images: Array-like whose first axis indexes the individual images.
            Plain lists/tuples of images are accepted as well as arrays.

    Returns:
        Float array with the shape of a single image, containing only the
        values 0.0 and 255.0.
    """
    # Coerce once so that list/tuple input works just like it does for
    # average_strategy; an ndarray is passed through without copying.
    stack = np.asarray(images)
    votes = np.round(stack / 255.)
    vote_fraction = np.sum(votes, axis=0) / stack.shape[0]
    return np.round(vote_fraction) * 255.
def ensemble_image(files, dirs, ensembling_dir, strategy):
    """Ensemble the images for every file name drawn from ``files``.

    For each file name, the image of that name is read in grayscale
    (``mode='L'``) from every directory in ``dirs`` that contains it, the
    images are combined with the selected strategy, and the result is saved
    under the same file name inside ``ensembling_dir``.

    Args:
        files: Iterable of file names to process.
        dirs: Directories to look the file name up in; directories that do
            not contain a given file are skipped for that file.
        ensembling_dir: Directory the combined images are written to.
        strategy: ``'average'`` or ``'hard_voting'``.

    Raises:
        ValueError: If ``strategy`` is not one of the known strategies.
    """
    strategies = {
        'average': average_strategy,
        'hard_voting': hard_voting,
    }
    # Resolve the strategy once, before touching the disk: it does not depend
    # on the file, and a bad value should fail before any image is read.
    if strategy not in strategies:
        raise ValueError('Unknown ensembling strategy')
    combine = strategies[strategy]

    for file_name in files:
        stack = []
        for source_dir in dirs:
            file_path = os.path.join(source_dir, file_name)
            if os.path.exists(file_path):
                stack.append(imread(file_path, mode='L'))
        ensembled = combine(np.array(stack))
        imsave(os.path.join(ensembling_dir, file_name), ensembled)
def ensemble(dirs, strategy, ensembling_dir, n_threads):
    """Ensemble every file listed in ``dirs[0]`` using worker threads.

    The file names found in the first directory are wrapped in a
    ``ThreadsafeIter`` that is shared by ``n_threads`` threads, each running
    ``ensemble_image``. The call returns once all threads have finished.

    Args:
        dirs: Directories holding the images to combine; the first one
            determines which file names are processed.
        strategy: Ensembling strategy name, forwarded to ``ensemble_image``.
        ensembling_dir: Output directory; created if it does not exist yet.
        n_threads: Number of worker threads to start.
    """
    files = ThreadsafeIter(os.listdir(dirs[0]))

    # Make sure the output directory exists before any worker starts: a
    # failure to save would otherwise be raised inside the worker threads,
    # and exceptions raised there are not propagated to this caller.
    # An empty path means "current directory" for os.path.join, so skip it.
    if ensembling_dir and not os.path.isdir(ensembling_dir):
        os.makedirs(ensembling_dir)

    workers = [
        threading.Thread(
            target=ensemble_image,
            args=(files, dirs, ensembling_dir, strategy),
        )
        for _ in range(n_threads)
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()
if __name__ == '__main__':
    # All settings are read from the `args` object imported from `params`.
    n_threads, ensembling_dir, strategy = (
        args.ensembling_cpu_threads,
        args.ensembling_dir,
        args.ensembling_strategy,
    )
    dir_names = args.dirs_to_ensemble
    folds_dir = args.folds_dir

    # The directories to ensemble are given relative to the folds directory.
    dirs = [os.path.join(folds_dir, name) for name in dir_names]

    # Refuse to start if any input directory is missing; report the first one.
    missing = [path for path in dirs if not os.path.exists(path)]
    if missing:
        raise ValueError(missing[0] + " doesn't exist")

    ensemble(dirs, strategy, ensembling_dir, n_threads)