
[Retiarii] add validation in base trainers (#3184)
hzhua authored Dec 15, 2020
1 parent 59cd398 commit a0e2f8e
Showing 5 changed files with 92 additions and 59 deletions.
27 changes: 10 additions & 17 deletions nni/retiarii/execution/logical_optimizer/logical_plan.py
@@ -1,6 +1,7 @@
import copy
from typing import Dict, Tuple, List, Any

from nni.retiarii.utils import uid
from ...graph import Cell, Edge, Graph, Model, Node
from ...operation import Operation, _IOPseudoOperation

@@ -14,7 +15,7 @@ def __eq__(self, o) -> bool:
return self.server == o.server and self.device == o.device

def __hash__(self) -> int:
return hash(self.server+'_'+self.device)
return hash(self.server + '_' + self.device)


class AbstractLogicalNode(Node):
@@ -181,10 +182,8 @@ def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) \
if isinstance(new_node.operation, _IOPseudoOperation):
model_id = new_node.graph.model.model_id
if model_id not in training_config_slot:
phy_model.training_config.kwargs['model_kwargs'].append(
new_node.graph.model.training_config.kwargs.copy())
training_config_slot[model_id] = \
len(phy_model.training_config.kwargs['model_kwargs'])-1
phy_model.training_config.kwargs['model_kwargs'].append(new_node.graph.model.training_config.kwargs.copy())
training_config_slot[model_id] = len(phy_model.training_config.kwargs['model_kwargs']) - 1
slot = training_config_slot[model_id]
phy_model.training_config.kwargs['model_kwargs'][slot]['model_id'] = model_id
phy_model.training_config.kwargs['model_kwargs'][slot]['use_input'] = False
@@ -221,18 +220,14 @@ def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) \
tail_placement = node_placements[edge.tail]
if head_placement != tail_placement:
if head_placement.server != tail_placement.server:
raise ValueError(
'Cross-server placement is not supported.')
raise ValueError('Cross-server placement is not supported.')
# Same server different devices
if (edge.head, tail_placement) in copied_op:
to_node = copied_op[(edge.head, tail_placement)]
else:
to_operation = Operation.new(
'ToDevice', {"device": tail_placement.device})
to_node = Node(phy_graph, phy_model._uid(),
edge.head.name+"_to_"+edge.tail.name, to_operation)._register()
Edge((edge.head, edge.head_slot),
(to_node, None), _internal=True)._register()
to_operation = Operation.new('ToDevice', {"device": tail_placement.device})
to_node = Node(phy_graph, uid(), edge.head.name + "_to_" + edge.tail.name, to_operation)._register()
Edge((edge.head, edge.head_slot), (to_node, None), _internal=True)._register()
copied_op[(edge.head, tail_placement)] = to_node
edge.head = to_node
edge.head_slot = None
@@ -266,11 +261,9 @@ def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) \

return phy_model, node_placements

def node_replace(self, old_node: Node,
new_node: Node,
input_slot_mapping=None, output_slot_mapping=None):
def node_replace(self, old_node: Node, new_node: Node, input_slot_mapping=None, output_slot_mapping=None):
# TODO: currently, only support single input slot and output slot.
if input_slot_mapping != None or output_slot_mapping != None:
if input_slot_mapping is not None or output_slot_mapping is not None:
raise ValueError('Slot mapping is not supported')

phy_graph = old_node.graph
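
The key change in this file, beyond formatting cleanups, is that newly created nodes (such as the ToDevice copy nodes) take their ids from the shared uid() helper imported from nni.retiarii.utils rather than the per-model _uid() method, so ids stay unique across all models assembled into one physical model. As a rough illustration only (the actual helper in nni.retiarii.utils may be implemented differently), such a process-wide counter could look like:

    import itertools

    _uid_counter = itertools.count()

    def uid() -> int:
        # Return a monotonically increasing id, unique within the current process.
        return next(_uid_counter)
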
4 changes: 2 additions & 2 deletions nni/retiarii/execution/logical_optimizer/opt_dedup_input.py
@@ -1,5 +1,6 @@
from typing import List, Dict, Tuple

from nni.retiarii.utils import uid
from ...graph import Graph, Model, Node
from .interface import AbstractOptimizer
from .logical_plan import (AbstractLogicalNode, LogicalGraph, LogicalPlan,
@@ -78,8 +79,7 @@ def convert(self, logical_plan: LogicalPlan) -> None:
assert(nodes_to_dedup[0] == root_node)
nodes_to_skip.add(root_node)
else:
dedup_node = DedupInputNode(logical_plan.logical_graph,
logical_plan.lp_model._uid(), nodes_to_dedup)._register()
dedup_node = DedupInputNode(logical_plan.logical_graph, uid(), nodes_to_dedup)._register()
for edge in logical_plan.logical_graph.edges:
if edge.head in nodes_to_dedup:
edge.head = dedup_node
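
The dedup-input optimizer gets the same treatment: the DedupInputNode is now created with the shared uid() as well. Conceptually, this pass merges the equivalent input nodes of the models in a logical plan into a single DedupInputNode and redirects the affected edges, so that models trained together draw from one shared input. A simplified sketch of that rewiring pattern (illustrative only, not the optimizer's actual code; make_dedup_node is a hypothetical factory standing in for DedupInputNode construction):

    def dedup_inputs(logical_graph, groups, make_dedup_node):
        # `groups` is a list of lists of interchangeable input nodes.
        for nodes_to_dedup in groups:
            if len(nodes_to_dedup) <= 1:
                continue  # nothing to merge
            dedup_node = make_dedup_node(nodes_to_dedup)
            # Point every edge that started at a duplicated node to the merged node.
            for edge in logical_graph.edges:
                if edge.head in nodes_to_dedup:
                    edge.head = dedup_node
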
112 changes: 77 additions & 35 deletions nni/retiarii/trainer/pytorch/base.py
@@ -36,7 +36,8 @@ def get_default_transform(dataset: str) -> Any:
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)),
])
# unsupported dataset, return None
return None
@@ -79,20 +80,30 @@ def __init__(self, model,
Keyword arguments passed to trainer. Will be passed to Trainer class in future. Currently,
only the key ``max_epochs`` is useful.
"""
super(
PyTorchImageClassificationTrainer,
self).__init__(
model,
dataset_cls,
dataset_kwargs,
dataloader_kwargs,
optimizer_cls,
optimizer_kwargs,
trainer_kwargs)
self._use_cuda = torch.cuda.is_available()
self.model = model
if self._use_cuda:
self.model.cuda()
self._loss_fn = nn.CrossEntropyLoss()
self._dataset = getattr(datasets, dataset_cls)(transform=get_default_transform(dataset_cls),
**(dataset_kwargs or {}))
self._optimizer = getattr(torch.optim, optimizer_cls)(
model.parameters(), **(optimizer_kwargs or {}))
self._train_dataset = getattr(datasets, dataset_cls)(train=True, transform=get_default_transform(dataset_cls),
**(dataset_kwargs or {}))
self._val_dataset = getattr(datasets, dataset_cls)(train=False, transform=get_default_transform(dataset_cls),
**(dataset_kwargs or {}))
self._optimizer = getattr(torch.optim, optimizer_cls)(model.parameters(), **(optimizer_kwargs or {}))
self._trainer_kwargs = trainer_kwargs or {'max_epochs': 10}

# TODO: we will need at least two (maybe three) data loaders in future.
self._dataloader = DataLoader(
self._dataset, **(dataloader_kwargs or {}))
self._train_dataloader = DataLoader(self._train_dataset, **(dataloader_kwargs or {}))
self._val_dataloader = DataLoader(self._val_dataset, **(dataloader_kwargs or {}))

def _accuracy(self, input, target): # pylint: disable=redefined-builtin
_, predict = torch.max(input.data, 1)
@@ -137,12 +148,12 @@ def validation_epoch_end(self, outputs: List[Dict[str, Any]]) -> Dict[str, Any]:

def _validate(self):
validation_outputs = []
for i, batch in enumerate(self._dataloader):
for i, batch in enumerate(self._val_dataloader):
validation_outputs.append(self.validation_step(batch, i))
return self.validation_epoch_end(validation_outputs)

def _train(self):
for i, batch in enumerate(self._dataloader):
for i, batch in enumerate(self._train_dataloader):
loss = self.training_step(batch, i)
loss.backward()

@@ -157,25 +168,32 @@ class PyTorchMultiModelTrainer(BaseTrainer):
def __init__(self, multi_model, kwargs=[]):
self.multi_model = multi_model
self.kwargs = kwargs
self._dataloaders = []
self._datasets = []
self._train_dataloaders = []
self._train_datasets = []
self._val_dataloaders = []
self._val_datasets = []
self._optimizers = []
self._trainers = []
self._loss_fn = nn.CrossEntropyLoss()
self.max_steps = None
if 'max_steps' in self.kwargs:
self.max_steps = self.kwargs['max_steps']
self.max_steps = self.kwargs['max_steps'] if 'max_steps' in self.kwargs else None
self.n_model = len(self.kwargs['model_kwargs'])

for m in self.kwargs['model_kwargs']:
if m['use_input']:
dataset_cls = m['dataset_cls']
dataset_kwargs = m['dataset_kwargs']
dataloader_kwargs = m['dataloader_kwargs']
dataset = getattr(datasets, dataset_cls)(transform=get_default_transform(dataset_cls),
**(dataset_kwargs or {}))
dataloader = DataLoader(dataset, **(dataloader_kwargs or {}))
self._datasets.append(dataset)
self._dataloaders.append(dataloader)
train_dataset = getattr(datasets, dataset_cls)(train=True, transform=get_default_transform(dataset_cls),
**(dataset_kwargs or {}))
val_dataset = getattr(datasets, dataset_cls)(train=False, transform=get_default_transform(dataset_cls),
**(dataset_kwargs or {}))
train_dataloader = DataLoader(train_dataset, **(dataloader_kwargs or {}))
val_dataloader = DataLoader(val_dataset, **(dataloader_kwargs or {}))
self._train_datasets.append(train_dataset)
self._train_dataloaders.append(train_dataloader)

self._val_datasets.append(val_dataset)
self._val_dataloaders.append(val_dataloader)

if m['use_output']:
optimizer_cls = m['optimizer_cls']
Expand All @@ -195,9 +213,10 @@ def fit(self) -> None:
max_epochs = max([x['trainer_kwargs']['max_epochs'] for x in self.kwargs['model_kwargs']])
for _ in range(max_epochs):
self._train()
nni.report_final_result(self._validate())

def _train(self):
for batch_idx, multi_model_batch in enumerate(zip(*self._dataloaders)):
for batch_idx, multi_model_batch in enumerate(zip(*self._train_dataloaders)):
for opt in self._optimizers:
opt.zero_grad()
xs = []
@@ -225,16 +244,9 @@ def _train(self):
summed_loss.backward()
for opt in self._optimizers:
opt.step()
if batch_idx % 50 == 0:
nni.report_intermediate_result(report_loss)
if self.max_steps and batch_idx >= self.max_steps:
return

def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
x, y = self.training_step_before_model(batch, batch_idx)
y_hat = self.model(x)
return self.training_step_after_model(x, y, y_hat)

def training_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, device=None):
x, y = batch
if device:
@@ -245,17 +257,47 @@ def training_step_after_model(self, x, y, y_hat):
loss = self._loss_fn(y_hat, y)
return loss

def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
x, y = self.validation_step_before_model(batch, batch_idx)
y_hat = self.model(x)
return self.validation_step_after_model(x, y, y_hat)
def _validate(self):
all_val_outputs = {idx: [] for idx in range(self.n_model)}
for batch_idx, multi_model_batch in enumerate(zip(*self._val_dataloaders)):
xs = []
ys = []
for idx, batch in enumerate(multi_model_batch):
x, y = self.training_step_before_model(batch, batch_idx, f'cuda:{idx}')
xs.append(x)
ys.append(y)
if len(ys) != len(xs):
raise ValueError('len(ys) should be equal to len(xs)')

def validation_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
y_hats = self.multi_model(*xs)

for output_idx, yhat in enumerate(y_hats):
if len(ys) == len(y_hats):
acc = self.validation_step_after_model(xs[output_idx], ys[output_idx], yhat)
elif len(ys) == 1:
acc = self.validation_step_after_model(xs[0], ys[0].to(yhat.get_device()), yhat)
else:
raise ValueError('len(ys) should be either 1 or len(y_hats)')
all_val_outputs[output_idx].append(acc)

report_acc = {}
for idx in all_val_outputs:
avg_acc = np.mean([x['val_acc'] for x in all_val_outputs[idx]]).item()
report_acc[self.kwargs['model_kwargs'][idx]['model_id']] = avg_acc
nni.report_intermediate_result(report_acc)
return report_acc

def validation_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, device=None):
x, y = batch
if self._use_cuda:
x, y = x.cuda(), y.cuda()
if device:
x, y = x.cuda(torch.device(device)), y.cuda(torch.device(device))
return x, y

def validation_step_after_model(self, x, y, y_hat):
acc = self._accuracy(y_hat, y)
return {'val_acc': acc}

def _accuracy(self, input, target): # pylint: disable=redefined-builtin
_, predict = torch.max(input.data, 1)
correct = predict.eq(target.data).cpu().sum().item()
return correct / input.size(0)
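
The substance of the commit is in this file: both trainers now build separate training and validation datasets/dataloaders, run validation after training, and report accuracy through NNI (intermediate results per validation pass, plus a final result in the multi-model trainer). A hedged usage sketch of the single-model trainer under these changes; the toy model, data paths, and the assumption that the trainer exposes a fit() entry point (as the multi-model trainer does) are illustrative rather than taken from the commit:

    import torch.nn as nn
    from nni.retiarii.trainer.pytorch.base import PyTorchImageClassificationTrainer

    model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))  # toy CIFAR-10 classifier
    trainer = PyTorchImageClassificationTrainer(
        model,
        dataset_cls='CIFAR10',                     # resolved via torchvision.datasets
        dataset_kwargs={'root': 'data/cifar10', 'download': True},
        dataloader_kwargs={'batch_size': 32},
        optimizer_cls='SGD',                       # resolved via torch.optim
        optimizer_kwargs={'lr': 1e-3},
        trainer_kwargs={'max_epochs': 1},
    )
    trainer.fit()  # trains on the train split, then validates on the held-out test split
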
3 changes: 1 addition & 2 deletions test/ut/retiarii/test_cgo_engine.py
@@ -42,8 +42,7 @@ def test_submit_models(self):
protocol._in_file = open('generated/debug_protocol_out_file.py', 'rb')

models = _load_mnist(2)
anything = lambda: None
advisor = RetiariiAdvisor(anything)
advisor = RetiariiAdvisor()
submit_models(*models)

if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
5 changes: 2 additions & 3 deletions test/ut/retiarii/test_dedup_input.py
@@ -54,9 +54,8 @@ def test_dedup_input(self):
lp_dump = lp.logical_graph._dump()

self.assertTrue(correct_dump[0] == json.dumps(lp_dump))

anything = lambda: None
advisor = RetiariiAdvisor(anything)

advisor = RetiariiAdvisor()
cgo = CGOExecutionEngine()

phy_models = cgo._assemble(lp)
