Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Retiarii] add validation in base trainers #3184

Merged
merged 26 commits into from
Dec 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
58de5a3
cross-graph optimization: input dedup
hzhua Nov 19, 2020
8022ab8
nni integration test of cross-graph optimization
hzhua Nov 19, 2020
510f572
update cross-graph ut
hzhua Nov 19, 2020
1b60074
Merge branch 'dev-retiarii' into dev-retiarii
hzhua Nov 19, 2020
a0f7d09
solve merge conflict
hzhua Nov 20, 2020
8d04404
Merge remote-tracking branch 'upstream/dev-retiarii' into dev-retiarii
hzhua Nov 24, 2020
4124371
fix inconsistent implementation with upstream of new code converter
hzhua Nov 25, 2020
a82825a
remove duplicated __hash__ in nni.retiarii.graph
hzhua Nov 25, 2020
c93c4f3
remove bypass optimization
hzhua Nov 25, 2020
b1de4be
use __name__ in CGOExecutionEngine logger
hzhua Nov 25, 2020
9dfacaa
Merge remote-tracking branch 'upstream/dev-retiarii' into dev-retiarii
hzhua Dec 11, 2020
21eb936
add validation in PyTorchMultiModelTrainer
hzhua Dec 11, 2020
ae41b3a
add validation in PyTorchImageClassificationTrainer
hzhua Dec 11, 2020
6752a2e
Merge remote-tracking branch 'upstream/dev-retiarii' into dev-retiarii
hzhua Dec 11, 2020
2d4bbda
format file
hzhua Dec 14, 2020
f6328d6
format file
hzhua Dec 14, 2020
2033846
remove unused training_step and validation_step in PyTorchMultiModelT…
hzhua Dec 14, 2020
fbf6e19
remove todo: add val_data_loader
hzhua Dec 14, 2020
501ad90
format
hzhua Dec 14, 2020
a31b0db
format
hzhua Dec 14, 2020
67b3a85
Merge remote-tracking branch 'upstream/dev-retiarii' into dev-retiarii
hzhua Dec 14, 2020
c0a6ea0
format
hzhua Dec 14, 2020
b1e2751
Merge remote-tracking branch 'upstream/dev-retiarii' into dev-retiarii
hzhua Dec 15, 2020
6aba977
format
hzhua Dec 15, 2020
dd0017b
fix pylint
hzhua Dec 15, 2020
ba9c00f
fix pylint
hzhua Dec 15, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 10 additions & 17 deletions nni/retiarii/execution/logical_optimizer/logical_plan.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
from typing import Dict, Tuple, List, Any

from nni.retiarii.utils import uid
from ...graph import Cell, Edge, Graph, Model, Node
from ...operation import Operation, _IOPseudoOperation

Expand All @@ -14,7 +15,7 @@ def __eq__(self, o) -> bool:
return self.server == o.server and self.device == o.device

def __hash__(self) -> int:
return hash(self.server+'_'+self.device)
return hash(self.server + '_' + self.device)


class AbstractLogicalNode(Node):
Expand Down Expand Up @@ -181,10 +182,8 @@ def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) \
if isinstance(new_node.operation, _IOPseudoOperation):
model_id = new_node.graph.model.model_id
if model_id not in training_config_slot:
phy_model.training_config.kwargs['model_kwargs'].append(
new_node.graph.model.training_config.kwargs.copy())
training_config_slot[model_id] = \
len(phy_model.training_config.kwargs['model_kwargs'])-1
phy_model.training_config.kwargs['model_kwargs'].append(new_node.graph.model.training_config.kwargs.copy())
training_config_slot[model_id] = len(phy_model.training_config.kwargs['model_kwargs']) - 1
slot = training_config_slot[model_id]
phy_model.training_config.kwargs['model_kwargs'][slot]['model_id'] = model_id
phy_model.training_config.kwargs['model_kwargs'][slot]['use_input'] = False
Expand Down Expand Up @@ -221,18 +220,14 @@ def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) \
tail_placement = node_placements[edge.tail]
if head_placement != tail_placement:
if head_placement.server != tail_placement.server:
raise ValueError(
'Cross-server placement is not supported.')
raise ValueError('Cross-server placement is not supported.')
# Same server different devices
if (edge.head, tail_placement) in copied_op:
to_node = copied_op[(edge.head, tail_placement)]
else:
to_operation = Operation.new(
'ToDevice', {"device": tail_placement.device})
to_node = Node(phy_graph, phy_model._uid(),
edge.head.name+"_to_"+edge.tail.name, to_operation)._register()
Edge((edge.head, edge.head_slot),
(to_node, None), _internal=True)._register()
to_operation = Operation.new('ToDevice', {"device": tail_placement.device})
to_node = Node(phy_graph, uid(), edge.head.name + "_to_" + edge.tail.name, to_operation)._register()
Edge((edge.head, edge.head_slot), (to_node, None), _internal=True)._register()
copied_op[(edge.head, tail_placement)] = to_node
edge.head = to_node
edge.head_slot = None
Expand Down Expand Up @@ -266,11 +261,9 @@ def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) \

return phy_model, node_placements

def node_replace(self, old_node: Node,
new_node: Node,
input_slot_mapping=None, output_slot_mapping=None):
def node_replace(self, old_node: Node, new_node: Node, input_slot_mapping=None, output_slot_mapping=None):
# TODO: currently, only support single input slot and output slot.
if input_slot_mapping != None or output_slot_mapping != None:
if input_slot_mapping is not None or output_slot_mapping is not None:
raise ValueError('Slot mapping is not supported')

phy_graph = old_node.graph
Expand Down
4 changes: 2 additions & 2 deletions nni/retiarii/execution/logical_optimizer/opt_dedup_input.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List, Dict, Tuple

from nni.retiarii.utils import uid
from ...graph import Graph, Model, Node
from .interface import AbstractOptimizer
from .logical_plan import (AbstractLogicalNode, LogicalGraph, LogicalPlan,
Expand Down Expand Up @@ -78,8 +79,7 @@ def convert(self, logical_plan: LogicalPlan) -> None:
assert(nodes_to_dedup[0] == root_node)
nodes_to_skip.add(root_node)
else:
dedup_node = DedupInputNode(logical_plan.logical_graph,
logical_plan.lp_model._uid(), nodes_to_dedup)._register()
dedup_node = DedupInputNode(logical_plan.logical_graph, uid(), nodes_to_dedup)._register()
for edge in logical_plan.logical_graph.edges:
if edge.head in nodes_to_dedup:
edge.head = dedup_node
Expand Down
112 changes: 77 additions & 35 deletions nni/retiarii/trainer/pytorch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def get_default_transform(dataset: str) -> Any:
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)),
])
# unsupported dataset, return None
return None
Expand Down Expand Up @@ -79,20 +80,30 @@ def __init__(self, model,
Keyword arguments passed to trainer. Will be passed to Trainer class in future. Currently,
only the key ``max_epochs`` is useful.
"""
super(
PyTorchImageClassificationTrainer,
self).__init__(
model,
dataset_cls,
dataset_kwargs,
dataloader_kwargs,
optimizer_cls,
optimizer_kwargs,
trainer_kwargs)
self._use_cuda = torch.cuda.is_available()
self.model = model
if self._use_cuda:
self.model.cuda()
self._loss_fn = nn.CrossEntropyLoss()
self._dataset = getattr(datasets, dataset_cls)(transform=get_default_transform(dataset_cls),
**(dataset_kwargs or {}))
self._optimizer = getattr(torch.optim, optimizer_cls)(
model.parameters(), **(optimizer_kwargs or {}))
self._train_dataset = getattr(datasets, dataset_cls)(train=True, transform=get_default_transform(dataset_cls),
**(dataset_kwargs or {}))
self._val_dataset = getattr(datasets, dataset_cls)(train=False, transform=get_default_transform(dataset_cls),
**(dataset_kwargs or {}))
self._optimizer = getattr(torch.optim, optimizer_cls)(model.parameters(), **(optimizer_kwargs or {}))
self._trainer_kwargs = trainer_kwargs or {'max_epochs': 10}

# TODO: we will need at least two (maybe three) data loaders in future.
self._dataloader = DataLoader(
self._dataset, **(dataloader_kwargs or {}))
self._train_dataloader = DataLoader(self._train_dataset, **(dataloader_kwargs or {}))
self._val_dataloader = DataLoader(self._val_dataset, **(dataloader_kwargs or {}))

def _accuracy(self, input, target): # pylint: disable=redefined-builtin
_, predict = torch.max(input.data, 1)
Expand Down Expand Up @@ -137,12 +148,12 @@ def validation_epoch_end(self, outputs: List[Dict[str, Any]]) -> Dict[str, Any]:

def _validate(self):
validation_outputs = []
for i, batch in enumerate(self._dataloader):
for i, batch in enumerate(self._val_dataloader):
validation_outputs.append(self.validation_step(batch, i))
return self.validation_epoch_end(validation_outputs)

def _train(self):
for i, batch in enumerate(self._dataloader):
for i, batch in enumerate(self._train_dataloader):
loss = self.training_step(batch, i)
loss.backward()

Expand All @@ -157,25 +168,32 @@ class PyTorchMultiModelTrainer(BaseTrainer):
def __init__(self, multi_model, kwargs=[]):
self.multi_model = multi_model
self.kwargs = kwargs
self._dataloaders = []
self._datasets = []
self._train_dataloaders = []
self._train_datasets = []
self._val_dataloaders = []
self._val_datasets = []
self._optimizers = []
self._trainers = []
self._loss_fn = nn.CrossEntropyLoss()
self.max_steps = None
if 'max_steps' in self.kwargs:
self.max_steps = self.kwargs['max_steps']
self.max_steps = self.kwargs['max_steps'] if 'makx_steps' in self.kwargs else None
self.n_model = len(self.kwargs['model_kwargs'])

for m in self.kwargs['model_kwargs']:
if m['use_input']:
dataset_cls = m['dataset_cls']
dataset_kwargs = m['dataset_kwargs']
dataloader_kwargs = m['dataloader_kwargs']
dataset = getattr(datasets, dataset_cls)(transform=get_default_transform(dataset_cls),
**(dataset_kwargs or {}))
dataloader = DataLoader(dataset, **(dataloader_kwargs or {}))
self._datasets.append(dataset)
self._dataloaders.append(dataloader)
train_dataset = getattr(datasets, dataset_cls)(train=True, transform=get_default_transform(dataset_cls),
**(dataset_kwargs or {}))
val_dataset = getattr(datasets, dataset_cls)(train=False, transform=get_default_transform(dataset_cls),
**(dataset_kwargs or {}))
train_dataloader = DataLoader(train_dataset, **(dataloader_kwargs or {}))
val_dataloader = DataLoader(val_dataset, **(dataloader_kwargs or {}))
self._train_datasets.append(train_dataset)
self._train_dataloaders.append(train_dataloader)

self._val_datasets.append(val_dataset)
self._val_dataloaders.append(val_dataloader)

if m['use_output']:
optimizer_cls = m['optimizer_cls']
Expand All @@ -195,9 +213,10 @@ def fit(self) -> None:
max_epochs = max([x['trainer_kwargs']['max_epochs'] for x in self.kwargs['model_kwargs']])
for _ in range(max_epochs):
self._train()
nni.report_final_result(self._validate())

def _train(self):
for batch_idx, multi_model_batch in enumerate(zip(*self._dataloaders)):
for batch_idx, multi_model_batch in enumerate(zip(*self._train_dataloaders)):
for opt in self._optimizers:
opt.zero_grad()
xs = []
Expand Down Expand Up @@ -225,16 +244,9 @@ def _train(self):
summed_loss.backward()
for opt in self._optimizers:
opt.step()
if batch_idx % 50 == 0:
nni.report_intermediate_result(report_loss)
if self.max_steps and batch_idx >= self.max_steps:
return

def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
x, y = self.training_step_before_model(batch, batch_idx)
y_hat = self.model(x)
return self.training_step_after_model(x, y, y_hat)

def training_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, device=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest using self.device

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In MultiModel, different models' inputs may need to be placed on different devices (called in _train). Currently, the trainer simply hard-codes one GPU per model.

BTW, train_step and validation_step are not used in PyTorchImageClassificationTrainer. Removed.

x, y = batch
if device:
Expand All @@ -245,17 +257,47 @@ def training_step_after_model(self, x, y, y_hat):
loss = self._loss_fn(y_hat, y)
return loss

def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
x, y = self.validation_step_before_model(batch, batch_idx)
y_hat = self.model(x)
return self.validation_step_after_model(x, y, y_hat)
def _validate(self):
all_val_outputs = {idx: [] for idx in range(self.n_model)}
for batch_idx, multi_model_batch in enumerate(zip(*self._val_dataloaders)):
xs = []
ys = []
for idx, batch in enumerate(multi_model_batch):
x, y = self.training_step_before_model(batch, batch_idx, f'cuda:{idx}')
xs.append(x)
ys.append(y)
if len(ys) != len(xs):
raise ValueError('len(ys) should be equal to len(xs)')

def validation_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
y_hats = self.multi_model(*xs)

for output_idx, yhat in enumerate(y_hats):
if len(ys) == len(y_hats):
acc = self.validation_step_after_model(xs[output_idx], ys[output_idx], yhat)
elif len(ys) == 1:
acc = self.validation_step_after_model(xs[0], ys[0].to(yhat.get_device()), yhat)
else:
raise ValueError('len(ys) should be either 1 or len(y_hats)')
all_val_outputs[output_idx].append(acc)

report_acc = {}
for idx in all_val_outputs:
avg_acc = np.mean([x['val_acc'] for x in all_val_outputs[idx]]).item()
report_acc[self.kwargs['model_kwargs'][idx]['model_id']] = avg_acc
nni.report_intermediate_result(report_acc)
return report_acc

def validation_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, device=None):
x, y = batch
if self._use_cuda:
x, y = x.cuda(), y.cuda()
if device:
x, y = x.cuda(torch.device(device)), y.cuda(torch.device(device))
return x, y

def validation_step_after_model(self, x, y, y_hat):
acc = self._accuracy(y_hat, y)
return {'val_acc': acc}

def _accuracy(self, input, target): # pylint: disable=redefined-builtin
_, predict = torch.max(input.data, 1)
correct = predict.eq(target.data).cpu().sum().item()
return correct / input.size(0)
3 changes: 1 addition & 2 deletions test/ut/retiarii/test_cgo_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,7 @@ def test_submit_models(self):
protocol._in_file = open('generated/debug_protocol_out_file.py', 'rb')

models = _load_mnist(2)
anything = lambda: None
advisor = RetiariiAdvisor(anything)
advisor = RetiariiAdvisor()
submit_models(*models)

if torch.cuda.is_available() and torch.cuda.device_count() >= 2:
Expand Down
5 changes: 2 additions & 3 deletions test/ut/retiarii/test_dedup_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,8 @@ def test_dedup_input(self):
lp_dump = lp.logical_graph._dump()

self.assertTrue(correct_dump[0] == json.dumps(lp_dump))

anything = lambda: None
advisor = RetiariiAdvisor(anything)

advisor = RetiariiAdvisor()
cgo = CGOExecutionEngine()

phy_models = cgo._assemble(lp)
Expand Down