From e925df3563b7191e3a579d1dc33caebb8ff976f4 Mon Sep 17 00:00:00 2001
From: tmynn
Date: Fri, 30 Sep 2022 11:48:39 +0000
Subject: [PATCH 1/2] [feat] Add MXNet integration

---
 CHANGELOG.md              |   1 +
 aim/mxnet.py              |   2 +
 aim/sdk/adapters/mxnet.py | 154 ++++++++++++++++++++++++++++++++++++++
 examples/mxnet_track.py   |  71 ++++++++++++++++++
 4 files changed, 228 insertions(+)
 create mode 100644 aim/mxnet.py
 create mode 100644 aim/sdk/adapters/mxnet.py
 create mode 100644 examples/mxnet_track.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f453a4be96..34d437de25 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,6 +12,7 @@
 - Add fast.ai integration (tmynn)
 - Add command for dangling params cleanup (mihran113)
 - Deprecate Python 3.6 (alberttorosyan)
+- Add MXNet integration (tmynn)
 
 ### Fixes:
 
diff --git a/aim/mxnet.py b/aim/mxnet.py
new file mode 100644
index 0000000000..403d33d40b
--- /dev/null
+++ b/aim/mxnet.py
@@ -0,0 +1,2 @@
+# Alias to SDK mxnet interface
+from aim.sdk.adapters.mxnet import AimLoggingHandler # noqa F401
diff --git a/aim/sdk/adapters/mxnet.py b/aim/sdk/adapters/mxnet.py
new file mode 100644
index 0000000000..3763e4246c
--- /dev/null
+++ b/aim/sdk/adapters/mxnet.py
@@ -0,0 +1,154 @@
+import time
+import numpy as np
+from mxnet.gluon.contrib.estimator.utils import _check_metrics
+from mxnet.gluon.contrib.estimator import TrainBegin, TrainEnd, BatchBegin, BatchEnd, Estimator
+from typing import Optional, Union, Any, List
+from aim.sdk.run import Run
+from aim.ext.resource.configs import DEFAULT_SYSTEM_TRACKING_INT
+
+
+class AimLoggingHandler(TrainBegin, TrainEnd, BatchBegin, BatchEnd):
+    """Aim wrapper on top of the MXNet basic logging handler that applies to every Gluon estimator by default.
+    :py:class:`AimLoggingHandler` logs hyper-parameters, training statistics,
+    and other useful information during training.
+    Parameters
+    ----------
+    log_interval: int or str, default 'epoch'
+        Logging interval during training.
+        log_interval='epoch': display metrics every epoch
+        log_interval=integer k: display metrics every interval of k batches
+    metrics : list of EvalMetrics
+        Metrics to be logged, logged at batch end, epoch end, train end.
+    priority : scalar, default np.Inf
+        Priority level of the AimLoggingHandler. Priority level is sorted in
+        ascending order. The lower the number is, the higher priority level the
+        handler is.
+    """
+
+    def __init__(self, log_interval: Union[int, str] = 'epoch',
+                 repo: Optional[str] = None,
+                 experiment_name: Optional[str] = None,
+                 system_tracking_interval: Optional[int] = DEFAULT_SYSTEM_TRACKING_INT,
+                 log_system_params: bool = True,
+                 metrics: List[Any] = None,
+                 priority=np.Inf,):
+        super().__init__()
+        if not isinstance(log_interval, int) and log_interval != 'epoch':
+            raise ValueError("log_interval must be either an integer or string 'epoch'")
+
+        self.metrics = _check_metrics(metrics)
+        self.batch_index = 0
+        self.current_epoch = 0
+        self.processed_samples = 0
+        # the logging handler needs to be called last to make sure all states are updated
+        # it will also shut down logging at train end
+        self.priority = priority
+        self.log_interval = log_interval
+        self.log_interval_time = 0
+
+        self.repo = repo
+        self.experiment_name = experiment_name
+        self.system_tracking_interval = system_tracking_interval
+        self.log_system_params = log_system_params
+        self._run = None
+        self._run_hash = None
+
+    def train_begin(self, estimator: Optional[Estimator], *args, **kwargs):
+        self.train_start = time.time()
+        trainer = estimator.trainer
+        optimizer = trainer.optimizer.__class__.__name__
+        lr = trainer.learning_rate
+
+        estimator.logger.info("Training begin: using optimizer %s "
+                              "with current learning rate %.4f ",
+                              optimizer, lr)
+        if estimator.max_epoch:
+            estimator.logger.info("Train for %d epochs.", estimator.max_epoch)
+        else:
+            estimator.logger.info("Train for %d batches.", estimator.max_batch)
+        # reset all counters
+        self.current_epoch = 0
+        self.batch_index = 0
+        self.processed_samples = 0
+        self.log_interval_time = 0
+
+        params = {
+            "arch": estimator.net.name,
+            "loss": estimator.loss.name,
+            "optimizer": optimizer,
+            "lr": lr,
+            "max_epoch": estimator.max_epoch,
+            "max_batch": estimator.max_batch
+        }
+
+        self.setup(estimator, params)
+
+    def train_end(self, estimator: Optional[Estimator], *args, **kwargs):
+        train_time = time.time() - self.train_start
+        msg = 'Train finished using total %ds with %d epochs. ' % (train_time, self.current_epoch)
+        # log every result in train stats including train/validation loss & metrics
+        for metric in self.metrics:
+            name, value = metric.get()
+            msg += '%s: %.4f, ' % (name, value)
+        estimator.logger.info(msg.rstrip(', '))
+
+    def batch_begin(self, estimator: Optional[Estimator], *args, **kwargs):
+        if isinstance(self.log_interval, int):
+            self.batch_start = time.time()
+
+    def batch_end(self, estimator: Optional[Estimator], *args, **kwargs):
+        if isinstance(self.log_interval, int):
+            batch_time = time.time() - self.batch_start
+            msg = '[Epoch %d][Batch %d]' % (self.current_epoch, self.batch_index)
+            self.processed_samples += kwargs['batch'][0].shape[0]
+            msg += '[Samples %s] ' % (self.processed_samples)
+            self.log_interval_time += batch_time
+            if self.batch_index % self.log_interval == 0:
+                msg += 'time/interval: %.3fs ' % self.log_interval_time
+                self.log_interval_time = 0
+                for metric in self.metrics:
+                    # only log current training loss & metric after each interval
+                    name, value = metric.get()
+                    msg += '%s: %.4f, ' % (name, value)
+
+                    context_name, metric_name = name.split(" ")
+                    context = {'subset': context_name}
+                    self._run.track(value, metric_name, step=self.batch_index, context=context)
+                estimator.logger.info(msg.rstrip(', '))
+        self.batch_index += 1
+
+    @property
+    def experiment(self) -> Run:
+        if not self._run:
+            self.setup()
+        return self._run
+
+    def setup(self, estimator: Optional[Estimator] = None, args=None):
+        if not self._run:
+            if self._run_hash:
+                self._run = Run(
+                    self._run_hash,
+                    repo=self.repo,
+                    system_tracking_interval=self.system_tracking_interval,
+                    log_system_params=self.log_system_params,
+                )
+            else:
+                self._run = Run(
+                    repo=self.repo,
+                    experiment=self.experiment_name,
+                    system_tracking_interval=self.system_tracking_interval,
+                    log_system_params=self.log_system_params,
+                )
+                self._run_hash = self._run.hash
+
+        # Log config parameters
+        if args:
+            try:
+                for key in args:
+                    self._run.set(key, args[key], strict=False)
+            except Exception as e:
+                estimator.logger.warning(f'Aim could not log config parameters -> {e}')
+
+    def __del__(self):
+        if self._run and self._run.active:
+            self._run.close()
diff --git a/examples/mxnet_track.py b/examples/mxnet_track.py
new file mode 100644
index 0000000000..0f7bb602c5
--- /dev/null
+++ b/examples/mxnet_track.py
@@ -0,0 +1,71 @@
+import warnings
+import mxnet as mx
+from mxnet import gluon
+from mxnet.gluon.model_zoo import vision
+from mxnet.gluon.contrib.estimator import estimator
+
+from aim.mxnet import AimLoggingHandler
+
+
+gpu_count = mx.context.num_gpus()
+ctx = [mx.gpu(i) for i in range(gpu_count)] if gpu_count > 0 else mx.cpu()
+
+
+# Get the training data
+fashion_mnist_train = gluon.data.vision.FashionMNIST(train=True)
+
+# Get the validation data
+fashion_mnist_val = gluon.data.vision.FashionMNIST(train=False)
+
+transforms = [gluon.data.vision.transforms.Resize(224), # We pick 224 as the model we use takes an input of size 224.
+              gluon.data.vision.transforms.ToTensor()]
+
+# Now we will stack all these together.
+transforms = gluon.data.vision.transforms.Compose(transforms)
+
+# Apply the transformations
+fashion_mnist_train = fashion_mnist_train.transform_first(transforms)
+fashion_mnist_val = fashion_mnist_val.transform_first(transforms)
+
+batch_size = 256 # Batch size of the images
+# The number of parallel workers for loading the data using Data Loaders.
+num_workers = 4
+
+train_data_loader = gluon.data.DataLoader(fashion_mnist_train, batch_size=batch_size,
+                                          shuffle=True, num_workers=num_workers)
+val_data_loader = gluon.data.DataLoader(fashion_mnist_val, batch_size=batch_size,
+                                        shuffle=False, num_workers=num_workers)
+
+
+resnet_18_v1 = vision.resnet18_v1(pretrained=False, classes=10)
+resnet_18_v1.initialize(init=mx.init.Xavier(), ctx=ctx)
+
+loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
+
+learning_rate = 0.04 # You can experiment with your own learning rate here
+num_epochs = 2 # You can run training for more epochs
+trainer = gluon.Trainer(resnet_18_v1.collect_params(),
+                        'sgd', {'learning_rate': learning_rate})
+
+
+train_acc = mx.metric.Accuracy() # Metric to monitor
+train_loss = mx.metric.Loss()
+val_acc = mx.metric.Accuracy()
+
+# Define the estimator, by passing to it the model, loss function, metrics, trainer object and context
+est = estimator.Estimator(net=resnet_18_v1,
+                          loss=loss_fn,
+                          train_metrics=[train_acc, train_loss],
+                          val_metrics=val_acc,
+                          trainer=trainer,
+                          context=ctx)
+
+aim_log_handler = AimLoggingHandler(repo='.tmp_mxnet', experiment_name='mxnet_example',
+                                    log_interval=1, metrics=[train_acc, train_loss, val_acc])
+
+# ignore warnings for nightly test on CI only
+with warnings.catch_warnings():
+    warnings.simplefilter("ignore")
+    # Magic line
+    est.fit(train_data=train_data_loader, val_data=val_data_loader,
+            epochs=num_epochs, event_handlers=[aim_log_handler])

From d06315dd656bdf9269401bc88e72e9434f2d45d7 Mon Sep 17 00:00:00 2001
From: tmynn
Date: Fri, 30 Sep 2022 16:56:27 +0000
Subject: [PATCH 2/2] [fix] Add epoch_begin and epoch_end listeners

---
 aim/sdk/adapters/mxnet.py | 32 ++++++++++++++++++++++++++++++--
 1 file changed, 30 insertions(+), 2 deletions(-)

diff --git a/aim/sdk/adapters/mxnet.py b/aim/sdk/adapters/mxnet.py
index 3763e4246c..d7dab49454 100644
--- a/aim/sdk/adapters/mxnet.py
+++ b/aim/sdk/adapters/mxnet.py
@@ -1,13 +1,13 @@
 import time
 import numpy as np
 from mxnet.gluon.contrib.estimator.utils import _check_metrics
-from mxnet.gluon.contrib.estimator import TrainBegin, TrainEnd, BatchBegin, BatchEnd, Estimator
+from mxnet.gluon.contrib.estimator import TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, BatchEnd, Estimator
 from typing import Optional, Union, Any, List
 from aim.sdk.run import Run
 from aim.ext.resource.configs import DEFAULT_SYSTEM_TRACKING_INT
 
 
-class AimLoggingHandler(TrainBegin, TrainEnd, BatchBegin, BatchEnd):
+class AimLoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, BatchEnd):
     """Aim wrapper on top of the MXNet basic logging handler that applies to every Gluon estimator by default.
     :py:class:`AimLoggingHandler` logs hyper-parameters, training statistics,
     and other useful information during training.
@@ -92,6 +92,34 @@ def train_end(self, estimator: Optional[Estimator], *args, **kwargs):
             msg += '%s: %.4f, ' % (name, value)
         estimator.logger.info(msg.rstrip(', '))
 
+    def epoch_begin(self, estimator: Optional[Estimator], *args, **kwargs):
+        if isinstance(self.log_interval, int) or self.log_interval == 'epoch':
+            is_training = False
+            for metric in self.metrics:
+                if 'training' in metric.name:
+                    is_training = True
+            self.epoch_start = time.time()
+            if is_training:
+                estimator.logger.info("[Epoch %d] Begin, current learning rate: %.4f",
+                                      self.current_epoch, estimator.trainer.learning_rate)
+            else:
+                estimator.logger.info("Validation Begin")
+
+    def epoch_end(self, estimator: Optional[Estimator], *args, **kwargs):
+        if isinstance(self.log_interval, int) or self.log_interval == 'epoch':
+            epoch_time = time.time() - self.epoch_start
+            msg = '[Epoch %d] Finished in %.3fs, ' % (self.current_epoch, epoch_time)
+            for metric in self.metrics:
+                name, value = metric.get()
+                msg += '%s: %.4f, ' % (name, value)
+
+                context_name, metric_name = name.split(" ")
+                context = {'subset': context_name}
+                self._run.track(value, metric_name, step=self.batch_index, context=context)
+            estimator.logger.info(msg.rstrip(', '))
+            self.current_epoch += 1
+            self.batch_index = 0
+
     def batch_begin(self, estimator: Optional[Estimator], *args, **kwargs):
         if isinstance(self.log_interval, int):
             self.batch_start = time.time()
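
Note on the tracking logic: both batch_end and epoch_end assume each Gluon estimator metric is named "<subset> <metric>" (e.g. "training accuracy" or "validation accuracy"), which is what the name.split(" ") call relies on when building the Aim context. A minimal, illustrative sketch of that mapping follows; the helper name below is hypothetical and not part of the patches.

    # Hypothetical helper mirroring the name.split(" ") logic in
    # batch_end/epoch_end: a composite Gluon metric name becomes an Aim
    # metric name plus a {'subset': ...} context.
    def to_aim_metric(name: str, value: float):
        subset, metric_name = name.split(" ")
        return metric_name, value, {'subset': subset}

    print(to_aim_metric('training accuracy', 0.92))
    # -> ('accuracy', 0.92, {'subset': 'training'})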