-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtrain_continual_learning_few_shot_system.py
105 lines (90 loc) · 5.89 KB
/
train_continual_learning_few_shot_system.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from torch.utils.data import DataLoader
from utils.parser_utils import get_args
# Parse the command-line configuration; `args` drives every choice below
# (model type, dataset, transforms, loaders, experiment settings).
# NOTE(review): this call sits between imports; the ordering may be deliberate
# (e.g. a later import depending on state set up by get_args) — confirm before
# moving it. `device` is not referenced anywhere in the code shown here.
args, device = get_args()
from utils.dataset_tools import check_download_dataset
from data import ConvertToThreeChannels, FewShotLearningDatasetParallel
from torchvision import transforms
from experiment_builder import ExperimentBuilder
# Star import supplies the few-shot classifier classes selected below.
from few_shot_learning_system import *
# Combines the arguments, model, data and experiment builders to run an experiment.

# Select the few-shot classifier named by args.classifier_type. Every option is
# constructed from the full parsed-argument namespace (args.__dict__).
if args.classifier_type == 'maml++_high-end':
    model = EmbeddingMAMLFewShotClassifier(**args.__dict__)
elif args.classifier_type == 'maml++_low-end':
    model = VGGMAMLFewShotClassifier(**args.__dict__)
elif args.classifier_type == 'vgg-fine-tune-scratch':
    model = FineTuneFromScratchFewShotClassifier(**args.__dict__)
elif args.classifier_type == 'vgg-fine-tune-pretrained':
    model = FineTuneFromPretrainedFewShotClassifier(**args.__dict__)
elif args.classifier_type == 'vgg-matching_network':
    model = MatchingNetworkFewShotClassifier(**args.__dict__)
else:
    # Same exception type as before, but say which value was rejected and
    # which values are accepted so a typo on the command line is obvious.
    raise NotImplementedError(
        f"Unsupported classifier_type {args.classifier_type!r}; expected one of "
        "'maml++_high-end', 'maml++_low-end', 'vgg-fine-tune-scratch', "
        "'vgg-fine-tune-pretrained', 'vgg-matching_network'."
    )
# Number of tasks requested per epoch for the validation and test splits.
# NOTE(review): this is independent of args.num_evaluation_tasks, which is
# passed to ExperimentBuilder below — confirm the two are meant to differ.
NUM_EVAL_TASKS_PER_EPOCH = 600

# Per-image preprocessing. Kept under its own name so the torchvision
# `transforms` module imported above is not shadowed by this list.
if args.image_channels == 3:
    image_transforms = [
        transforms.Resize(size=(args.image_height, args.image_width)),
        transforms.ToTensor(),
        ConvertToThreeChannels(),
        # Per-channel mean / std; these are the standard ImageNet statistics.
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
elif args.image_channels == 1:
    image_transforms = [
        transforms.Resize(size=(args.image_height, args.image_width)),
        transforms.ToTensor(),
    ]
else:
    # No preprocessing pipeline exists for other channel counts; fail fast
    # (before any download) instead of building datasets without transforms.
    raise ValueError(
        f"Unsupported image_channels={args.image_channels!r}; expected 1 or 3."
    )

check_download_dataset(dataset_name=args.dataset_name)


def _build_setup_dict(set_name, num_tasks_per_epoch):
    """Return FewShotLearningDatasetParallel keyword arguments for one split.

    Every split shares the configuration read from ``args`` and the same
    ``image_transforms``; only ``set_name`` and ``num_tasks_per_epoch`` vary.
    """
    return dict(dataset_name=args.dataset_name,
                indexes_of_folders_indicating_class=args.indexes_of_folders_indicating_class,
                train_val_test_split=args.train_val_test_split,
                labels_as_int=args.labels_as_int,
                transforms=image_transforms,
                num_classes_per_set=args.num_classes_per_set,
                num_samples_per_support_class=args.num_samples_per_support_class,
                num_samples_per_target_class=args.num_samples_per_target_class,
                seed=args.seed,
                sets_are_pre_split=args.sets_are_pre_split,
                load_into_memory=args.load_into_memory,
                set_name=set_name,
                num_tasks_per_epoch=num_tasks_per_epoch,
                num_channels=args.image_channels,
                num_support_sets=args.num_support_sets,
                overwrite_classes_in_each_task=args.overwrite_classes_in_each_task,
                class_change_interval=args.class_change_interval)


def _make_loader(dataset):
    """Wrap *dataset* in a DataLoader with the shared batch size and worker count."""
    return DataLoader(dataset, batch_size=args.batch_size,
                      num_workers=args.num_dataprovider_workers)


# The training split's task count is total_epochs * total_iter_per_epoch;
# validation and test use the fixed count above.
train_setup_dict = _build_setup_dict('train', args.total_epochs * args.total_iter_per_epoch)
val_setup_dict = _build_setup_dict('val', NUM_EVAL_TASKS_PER_EPOCH)
test_setup_dict = _build_setup_dict('test', NUM_EVAL_TASKS_PER_EPOCH)

train_data = FewShotLearningDatasetParallel(**train_setup_dict)
val_data = FewShotLearningDatasetParallel(**val_setup_dict)
test_data = FewShotLearningDatasetParallel(**test_setup_dict)

data_dict = {'train': _make_loader(train_data),
             'val': _make_loader(val_data),
             'test': _make_loader(test_data)}
# Assemble the experiment driver from the chosen model, the split loaders and
# the run-level settings in `args`, then start the experiment.
experiment_kwargs = {
    'model': model,
    'data_dict': data_dict,
    'experiment_name': args.experiment_name,
    'continue_from_epoch': args.continue_from_epoch,
    'total_iter_per_epoch': args.total_iter_per_epoch,
    'num_evaluation_tasks': args.num_evaluation_tasks,
    'total_epochs': args.total_epochs,
    'batch_size': args.batch_size,
    'max_models_to_save': args.max_models_to_save,
    'evaluate_on_test_set_only': args.evaluate_on_test_set_only,
    'args': args,
}
maml_system = ExperimentBuilder(**experiment_kwargs)
maml_system.run_experiment()