From a0e2f8efb6782c3979050f4d9dadf1bfd5d8c58f Mon Sep 17 00:00:00 2001 From: Zhenhua Han Date: Tue, 15 Dec 2020 12:05:21 +0800 Subject: [PATCH] [Retiarii] add validation in base trainers (#3184) --- .../logical_optimizer/logical_plan.py | 27 ++--- .../logical_optimizer/opt_dedup_input.py | 4 +- nni/retiarii/trainer/pytorch/base.py | 112 ++++++++++++------ test/ut/retiarii/test_cgo_engine.py | 3 +- test/ut/retiarii/test_dedup_input.py | 5 +- 5 files changed, 92 insertions(+), 59 deletions(-) diff --git a/nni/retiarii/execution/logical_optimizer/logical_plan.py b/nni/retiarii/execution/logical_optimizer/logical_plan.py index 06ca3ef7c8..5901d0a6a6 100644 --- a/nni/retiarii/execution/logical_optimizer/logical_plan.py +++ b/nni/retiarii/execution/logical_optimizer/logical_plan.py @@ -1,6 +1,7 @@ import copy from typing import Dict, Tuple, List, Any +from nni.retiarii.utils import uid from ...graph import Cell, Edge, Graph, Model, Node from ...operation import Operation, _IOPseudoOperation @@ -14,7 +15,7 @@ def __eq__(self, o) -> bool: return self.server == o.server and self.device == o.device def __hash__(self) -> int: - return hash(self.server+'_'+self.device) + return hash(self.server + '_' + self.device) class AbstractLogicalNode(Node): @@ -181,10 +182,8 @@ def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) \ if isinstance(new_node.operation, _IOPseudoOperation): model_id = new_node.graph.model.model_id if model_id not in training_config_slot: - phy_model.training_config.kwargs['model_kwargs'].append( - new_node.graph.model.training_config.kwargs.copy()) - training_config_slot[model_id] = \ - len(phy_model.training_config.kwargs['model_kwargs'])-1 + phy_model.training_config.kwargs['model_kwargs'].append(new_node.graph.model.training_config.kwargs.copy()) + training_config_slot[model_id] = len(phy_model.training_config.kwargs['model_kwargs']) - 1 slot = training_config_slot[model_id] 
phy_model.training_config.kwargs['model_kwargs'][slot]['model_id'] = model_id phy_model.training_config.kwargs['model_kwargs'][slot]['use_input'] = False @@ -221,18 +220,14 @@ def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) \ tail_placement = node_placements[edge.tail] if head_placement != tail_placement: if head_placement.server != tail_placement.server: - raise ValueError( - 'Cross-server placement is not supported.') + raise ValueError('Cross-server placement is not supported.') # Same server different devices if (edge.head, tail_placement) in copied_op: to_node = copied_op[(edge.head, tail_placement)] else: - to_operation = Operation.new( - 'ToDevice', {"device": tail_placement.device}) - to_node = Node(phy_graph, phy_model._uid(), - edge.head.name+"_to_"+edge.tail.name, to_operation)._register() - Edge((edge.head, edge.head_slot), - (to_node, None), _internal=True)._register() + to_operation = Operation.new('ToDevice', {"device": tail_placement.device}) + to_node = Node(phy_graph, uid(), edge.head.name + "_to_" + edge.tail.name, to_operation)._register() + Edge((edge.head, edge.head_slot), (to_node, None), _internal=True)._register() copied_op[(edge.head, tail_placement)] = to_node edge.head = to_node edge.head_slot = None @@ -266,11 +261,9 @@ def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) \ return phy_model, node_placements - def node_replace(self, old_node: Node, - new_node: Node, - input_slot_mapping=None, output_slot_mapping=None): + def node_replace(self, old_node: Node, new_node: Node, input_slot_mapping=None, output_slot_mapping=None): # TODO: currently, only support single input slot and output slot. 
- if input_slot_mapping != None or output_slot_mapping != None: + if input_slot_mapping is not None or output_slot_mapping is not None: raise ValueError('Slot mapping is not supported') phy_graph = old_node.graph diff --git a/nni/retiarii/execution/logical_optimizer/opt_dedup_input.py b/nni/retiarii/execution/logical_optimizer/opt_dedup_input.py index 4b50346f0b..24612e45b6 100644 --- a/nni/retiarii/execution/logical_optimizer/opt_dedup_input.py +++ b/nni/retiarii/execution/logical_optimizer/opt_dedup_input.py @@ -1,5 +1,6 @@ from typing import List, Dict, Tuple +from nni.retiarii.utils import uid from ...graph import Graph, Model, Node from .interface import AbstractOptimizer from .logical_plan import (AbstractLogicalNode, LogicalGraph, LogicalPlan, @@ -78,8 +79,7 @@ def convert(self, logical_plan: LogicalPlan) -> None: assert(nodes_to_dedup[0] == root_node) nodes_to_skip.add(root_node) else: - dedup_node = DedupInputNode(logical_plan.logical_graph, - logical_plan.lp_model._uid(), nodes_to_dedup)._register() + dedup_node = DedupInputNode(logical_plan.logical_graph, uid(), nodes_to_dedup)._register() for edge in logical_plan.logical_graph.edges: if edge.head in nodes_to_dedup: edge.head = dedup_node diff --git a/nni/retiarii/trainer/pytorch/base.py b/nni/retiarii/trainer/pytorch/base.py index 6d2156c0b5..8b4a4e4f9b 100644 --- a/nni/retiarii/trainer/pytorch/base.py +++ b/nni/retiarii/trainer/pytorch/base.py @@ -36,7 +36,8 @@ def get_default_transform(dataset: str) -> Any: transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), - transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + transforms.Normalize((0.4914, 0.4822, 0.4465), + (0.2023, 0.1994, 0.2010)), ]) # unsupported dataset, return None return None @@ -79,20 +80,30 @@ def __init__(self, model, Keyword arguments passed to trainer. Will be passed to Trainer class in future. Currently, only the key ``max_epochs`` is useful. 
""" + super( + PyTorchImageClassificationTrainer, + self).__init__( + model, + dataset_cls, + dataset_kwargs, + dataloader_kwargs, + optimizer_cls, + optimizer_kwargs, + trainer_kwargs) self._use_cuda = torch.cuda.is_available() self.model = model if self._use_cuda: self.model.cuda() self._loss_fn = nn.CrossEntropyLoss() - self._dataset = getattr(datasets, dataset_cls)(transform=get_default_transform(dataset_cls), - **(dataset_kwargs or {})) - self._optimizer = getattr(torch.optim, optimizer_cls)( - model.parameters(), **(optimizer_kwargs or {})) + self._train_dataset = getattr(datasets, dataset_cls)(train=True, transform=get_default_transform(dataset_cls), + **(dataset_kwargs or {})) + self._val_dataset = getattr(datasets, dataset_cls)(train=False, transform=get_default_transform(dataset_cls), + **(dataset_kwargs or {})) + self._optimizer = getattr(torch.optim, optimizer_cls)(model.parameters(), **(optimizer_kwargs or {})) self._trainer_kwargs = trainer_kwargs or {'max_epochs': 10} - # TODO: we will need at least two (maybe three) data loaders in future. 
- self._dataloader = DataLoader( - self._dataset, **(dataloader_kwargs or {})) + self._train_dataloader = DataLoader(self._train_dataset, **(dataloader_kwargs or {})) + self._val_dataloader = DataLoader(self._val_dataset, **(dataloader_kwargs or {})) def _accuracy(self, input, target): # pylint: disable=redefined-builtin _, predict = torch.max(input.data, 1) @@ -137,12 +148,12 @@ def validation_epoch_end(self, outputs: List[Dict[str, Any]]) -> Dict[str, Any]: def _validate(self): validation_outputs = [] - for i, batch in enumerate(self._dataloader): + for i, batch in enumerate(self._val_dataloader): validation_outputs.append(self.validation_step(batch, i)) return self.validation_epoch_end(validation_outputs) def _train(self): - for i, batch in enumerate(self._dataloader): + for i, batch in enumerate(self._train_dataloader): loss = self.training_step(batch, i) loss.backward() @@ -157,25 +168,32 @@ class PyTorchMultiModelTrainer(BaseTrainer): def __init__(self, multi_model, kwargs=[]): self.multi_model = multi_model self.kwargs = kwargs - self._dataloaders = [] - self._datasets = [] + self._train_dataloaders = [] + self._train_datasets = [] + self._val_dataloaders = [] + self._val_datasets = [] self._optimizers = [] self._trainers = [] self._loss_fn = nn.CrossEntropyLoss() - self.max_steps = None - if 'max_steps' in self.kwargs: - self.max_steps = self.kwargs['max_steps'] + self.max_steps = self.kwargs['max_steps'] if 'max_steps' in self.kwargs else None + self.n_model = len(self.kwargs['model_kwargs']) for m in self.kwargs['model_kwargs']: if m['use_input']: dataset_cls = m['dataset_cls'] dataset_kwargs = m['dataset_kwargs'] dataloader_kwargs = m['dataloader_kwargs'] - dataset = getattr(datasets, dataset_cls)(transform=get_default_transform(dataset_cls), - **(dataset_kwargs or {})) - dataloader = DataLoader(dataset, **(dataloader_kwargs or {})) - self._datasets.append(dataset) - self._dataloaders.append(dataloader) + train_dataset = getattr(datasets, 
dataset_cls)(train=True, transform=get_default_transform(dataset_cls), + **(dataset_kwargs or {})) + val_dataset = getattr(datasets, dataset_cls)(train=False, transform=get_default_transform(dataset_cls), + **(dataset_kwargs or {})) + train_dataloader = DataLoader(train_dataset, **(dataloader_kwargs or {})) + val_dataloader = DataLoader(val_dataset, **(dataloader_kwargs or {})) + self._train_datasets.append(train_dataset) + self._train_dataloaders.append(train_dataloader) + + self._val_datasets.append(val_dataset) + self._val_dataloaders.append(val_dataloader) if m['use_output']: optimizer_cls = m['optimizer_cls'] @@ -195,9 +213,10 @@ def fit(self) -> None: max_epochs = max([x['trainer_kwargs']['max_epochs'] for x in self.kwargs['model_kwargs']]) for _ in range(max_epochs): self._train() + nni.report_final_result(self._validate()) def _train(self): - for batch_idx, multi_model_batch in enumerate(zip(*self._dataloaders)): + for batch_idx, multi_model_batch in enumerate(zip(*self._train_dataloaders)): for opt in self._optimizers: opt.zero_grad() xs = [] @@ -225,16 +244,9 @@ def _train(self): summed_loss.backward() for opt in self._optimizers: opt.step() - if batch_idx % 50 == 0: - nni.report_intermediate_result(report_loss) if self.max_steps and batch_idx >= self.max_steps: return - def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]: - x, y = self.training_step_before_model(batch, batch_idx) - y_hat = self.model(x) - return self.training_step_after_model(x, y, y_hat) - def training_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, device=None): x, y = batch if device: @@ -245,17 +257,47 @@ def training_step_after_model(self, x, y, y_hat): loss = self._loss_fn(y_hat, y) return loss - def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]: - x, y = self.validation_step_before_model(batch, batch_idx) - y_hat = self.model(x) - return 
self.validation_step_after_model(x, y, y_hat) + def _validate(self): + all_val_outputs = {idx: [] for idx in range(self.n_model)} + for batch_idx, multi_model_batch in enumerate(zip(*self._val_dataloaders)): + xs = [] + ys = [] + for idx, batch in enumerate(multi_model_batch): + x, y = self.training_step_before_model(batch, batch_idx, f'cuda:{idx}') + xs.append(x) + ys.append(y) + if len(ys) != len(xs): + raise ValueError('len(ys) should be equal to len(xs)') - def validation_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int): + y_hats = self.multi_model(*xs) + + for output_idx, yhat in enumerate(y_hats): + if len(ys) == len(y_hats): + acc = self.validation_step_after_model(xs[output_idx], ys[output_idx], yhat) + elif len(ys) == 1: + acc = self.validation_step_after_model(xs[0], ys[0].to(yhat.get_device()), yhat) + else: + raise ValueError('len(ys) should be either 1 or len(y_hats)') + all_val_outputs[output_idx].append(acc) + + report_acc = {} + for idx in all_val_outputs: + avg_acc = np.mean([x['val_acc'] for x in all_val_outputs[idx]]).item() + report_acc[self.kwargs['model_kwargs'][idx]['model_id']] = avg_acc + nni.report_intermediate_result(report_acc) + return report_acc + + def validation_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, device=None): x, y = batch - if self._use_cuda: - x, y = x.cuda(), y.cuda() + if device: + x, y = x.cuda(torch.device(device)), y.cuda(torch.device(device)) return x, y def validation_step_after_model(self, x, y, y_hat): acc = self._accuracy(y_hat, y) return {'val_acc': acc} + + def _accuracy(self, input, target): # pylint: disable=redefined-builtin + _, predict = torch.max(input.data, 1) + correct = predict.eq(target.data).cpu().sum().item() + return correct / input.size(0) diff --git a/test/ut/retiarii/test_cgo_engine.py b/test/ut/retiarii/test_cgo_engine.py index 0592abdab9..6963c8f54f 100644 --- a/test/ut/retiarii/test_cgo_engine.py +++ 
b/test/ut/retiarii/test_cgo_engine.py @@ -42,8 +42,7 @@ def test_submit_models(self): protocol._in_file = open('generated/debug_protocol_out_file.py', 'rb') models = _load_mnist(2) - anything = lambda: None - advisor = RetiariiAdvisor(anything) + advisor = RetiariiAdvisor() submit_models(*models) if torch.cuda.is_available() and torch.cuda.device_count() >= 2: diff --git a/test/ut/retiarii/test_dedup_input.py b/test/ut/retiarii/test_dedup_input.py index 7b721647b2..4864447d1c 100644 --- a/test/ut/retiarii/test_dedup_input.py +++ b/test/ut/retiarii/test_dedup_input.py @@ -54,9 +54,8 @@ def test_dedup_input(self): lp_dump = lp.logical_graph._dump() self.assertTrue(correct_dump[0] == json.dumps(lp_dump)) - - anything = lambda: None - advisor = RetiariiAdvisor(anything) + + advisor = RetiariiAdvisor() cgo = CGOExecutionEngine() phy_models = cgo._assemble(lp)