[feat] Add MXNet integration #2219

Merged
merged 2 commits into from
Oct 3, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -12,6 +12,7 @@
- Add fast.ai integration (tmynn)
- Add command for dangling params cleanup (mihran113)
- Deprecate Python 3.6 (alberttorosyan)
- Add MXNet integration (tmynn)

### Fixes:

2 changes: 2 additions & 0 deletions aim/mxnet.py
@@ -0,0 +1,2 @@
# Alias to SDK mxnet interface
from aim.sdk.adapters.mxnet import AimLoggingHandler # noqa F401
182 changes: 182 additions & 0 deletions aim/sdk/adapters/mxnet.py
@@ -0,0 +1,182 @@
import time
import numpy as np
from mxnet.gluon.contrib.estimator.utils import _check_metrics
from mxnet.gluon.contrib.estimator import TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, BatchEnd, Estimator
from typing import Optional, Union, Any, List
from aim.sdk.run import Run
from aim.ext.resource.configs import DEFAULT_SYSTEM_TRACKING_INT


class AimLoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, BatchEnd):
"""Aim wrapper on top of the Mxnet Basic Logging Handler that applies to every Gluon estimator by default.
:py:class:`AimLoggingHandler` logs hyper-parameters, training statistics,
and other useful information during training
Parameters
----------
log_interval: int or str, default 'epoch'
Logging interval during training.
log_interval='epoch': display metrics every epoch
log_interval=integer k: display metrics every interval of k batches
metrics : list of EvalMetrics
Metrics to be logged, logged at batch end, epoch end, train end.
priority : scalar, default np.Inf
Priority level of the AimLoggingHandler. Priority level is sorted in
ascending order. The lower the number is, the higher priority level the
handler is.
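Examples
--------
A minimal usage sketch (assuming a Gluon ``Estimator`` named ``est``, data loaders
``train_loader``/``val_loader``, and metric objects ``train_acc``/``val_acc`` are
already defined, as in ``examples/mxnet_track.py``; the repo path and experiment
name below are placeholders)::

aim_handler = AimLoggingHandler(repo='my_aim_repo', experiment_name='my_experiment',
log_interval='epoch', metrics=[train_acc, val_acc])
est.fit(train_data=train_loader, val_data=val_loader,
epochs=3, event_handlers=[aim_handler])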
"""

def __init__(self, log_interval: Union[int, str] = 'epoch',
repo: Optional[str] = None,
experiment_name: Optional[str] = None,
system_tracking_interval: Optional[int] = DEFAULT_SYSTEM_TRACKING_INT,
log_system_params: bool = True,
metrics: Optional[List[Any]] = None,
priority=np.Inf,):
super().__init__()
if not isinstance(log_interval, int) and log_interval != 'epoch':
raise ValueError("log_interval must be either an integer or string 'epoch'")

self.metrics = _check_metrics(metrics)
self.batch_index = 0
self.current_epoch = 0
self.processed_samples = 0
# the logging handler needs to be called last so that all states are updated;
# it will also shut down logging at train end
self.priority = priority
self.log_interval = log_interval
self.log_interval_time = 0

self.repo = repo
self.experiment_name = experiment_name
self.system_tracking_interval = system_tracking_interval
self.log_system_params = log_system_params
self._run = None
self._run_hash = None

def train_begin(self, estimator: Optional[Estimator], *args, **kwargs):
self.train_start = time.time()
trainer = estimator.trainer
optimizer = trainer.optimizer.__class__.__name__
lr = trainer.learning_rate

estimator.logger.info("Training begin: using optimizer %s "
"with current learning rate %.4f ",
optimizer, lr)
if estimator.max_epoch:
estimator.logger.info("Train for %d epochs.", estimator.max_epoch)
else:
estimator.logger.info("Train for %d batches.", estimator.max_batch)
# reset all counters
self.current_epoch = 0
self.batch_index = 0
self.processed_samples = 0
self.log_interval_time = 0

params = {
"arch": estimator.net.name,
"loss": estimator.loss.name,
"optimizer": optimizer,
"lr": lr,
"max_epoch": estimator.max_epoch,
"max_batch": estimator.max_batch
}

self.setup(estimator, params)

def train_end(self, estimator: Optional[Estimator], *args, **kwargs):
train_time = time.time() - self.train_start
msg = 'Train finished using total %ds with %d epochs. ' % (train_time, self.current_epoch)
# log every result in train stats including train/validation loss & metrics
for metric in self.metrics:
name, value = metric.get()
msg += '%s: %.4f, ' % (name, value)
estimator.logger.info(msg.rstrip(', '))

def epoch_begin(self, estimator: Optional[Estimator], *args, **kwargs):
if isinstance(self.log_interval, int) or self.log_interval == 'epoch':
is_training = False
for metric in self.metrics:
if 'training' in metric.name:
is_training = True
self.epoch_start = time.time()
if is_training:
estimator.logger.info("[Epoch %d] Begin, current learning rate: %.4f",
self.current_epoch, estimator.trainer.learning_rate)
else:
estimator.logger.info("Validation Begin")

def epoch_end(self, estimator: Optional[Estimator], *args, **kwargs):
if isinstance(self.log_interval, int) or self.log_interval == 'epoch':
epoch_time = time.time() - self.epoch_start
msg = '[Epoch %d] Finished in %.3fs, ' % (self.current_epoch, epoch_time)
for metric in self.metrics:
name, value = metric.get()
msg += '%s: %.4f, ' % (name, value)

# metric names follow the "<subset> <metric>" pattern (e.g. "training accuracy"):
# the subset becomes the Aim context and the remainder the tracked metric name
context_name, metric_name = name.split(" ")
context = {'subset': context_name}
self._run.track(value, metric_name, step=self.batch_index, context=context)
estimator.logger.info(msg.rstrip(', '))
self.current_epoch += 1
self.batch_index = 0

def batch_begin(self, estimator: Optional[Estimator], *args, **kwargs):
if isinstance(self.log_interval, int):
self.batch_start = time.time()

def batch_end(self, estimator: Optional[Estimator], *args, **kwargs):
if isinstance(self.log_interval, int):
batch_time = time.time() - self.batch_start
msg = '[Epoch %d][Batch %d]' % (self.current_epoch, self.batch_index)
self.processed_samples += kwargs['batch'][0].shape[0]
msg += '[Samples %s] ' % (self.processed_samples)
self.log_interval_time += batch_time
if self.batch_index % self.log_interval == 0:
msg += 'time/interval: %.3fs ' % self.log_interval_time
self.log_interval_time = 0
for metric in self.metrics:
# only log current training loss & metric after each interval
name, value = metric.get()
msg += '%s: %.4f, ' % (name, value)

context_name, metric_name = name.split(" ")
context = {'subset': context_name}
self._run.track(value, metric_name, step=self.batch_index, context=context)
estimator.logger.info(msg.rstrip(', '))
self.batch_index += 1

@property
def experiment(self) -> Run:
if not self._run:
self.setup()
return self._run

def setup(self, estimator: Optional[Estimator] = None, args=None):
if not self._run:
if self._run_hash:
self._run = Run(
self._run_hash,
repo=self.repo,
system_tracking_interval=self.system_tracking_interval,
log_system_params=self.log_system_params,
)
else:
self._run = Run(
repo=self.repo,
experiment=self.experiment_name,
system_tracking_interval=self.system_tracking_interval,
log_system_params=self.log_system_params,
)
self._run_hash = self._run.hash

# Log config parameters
if args:
try:
for key in args:
self._run.set(key, args[key], strict=False)
except Exception as e:
estimator.logger.warning(f'Aim could not log config parameters -> {e}')

def __del__(self):
if self._run and self._run.active:
self._run.close()
71 changes: 71 additions & 0 deletions examples/mxnet_track.py
@@ -0,0 +1,71 @@
import warnings
import mxnet as mx
from mxnet import gluon
from mxnet.gluon.model_zoo import vision
from mxnet.gluon.contrib.estimator import estimator

from aim.mxnet import AimLoggingHandler


gpu_count = mx.context.num_gpus()
ctx = [mx.gpu(i) for i in range(gpu_count)] if gpu_count > 0 else mx.cpu()


# Get the training data
fashion_mnist_train = gluon.data.vision.FashionMNIST(train=True)

# Get the validation data
fashion_mnist_val = gluon.data.vision.FashionMNIST(train=False)

transforms = [gluon.data.vision.transforms.Resize(224), # We pick 224 as the model we use takes an input of size 224.
gluon.data.vision.transforms.ToTensor()]

# Now we will stack all these together.
transforms = gluon.data.vision.transforms.Compose(transforms)

# Apply the transformations
fashion_mnist_train = fashion_mnist_train.transform_first(transforms)
fashion_mnist_val = fashion_mnist_val.transform_first(transforms)

batch_size = 256 # Batch size of the images
# The number of parallel workers for loading the data using Data Loaders.
num_workers = 4

train_data_loader = gluon.data.DataLoader(fashion_mnist_train, batch_size=batch_size,
shuffle=True, num_workers=num_workers)
val_data_loader = gluon.data.DataLoader(fashion_mnist_val, batch_size=batch_size,
shuffle=False, num_workers=num_workers)


resnet_18_v1 = vision.resnet18_v1(pretrained=False, classes=10)
resnet_18_v1.initialize(init=mx.init.Xavier(), ctx=ctx)

loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

learning_rate = 0.04 # You can experiment with your own learning rate here
num_epochs = 2 # You can run training for more epochs
trainer = gluon.Trainer(resnet_18_v1.collect_params(),
'sgd', {'learning_rate': learning_rate})


train_acc = mx.metric.Accuracy() # Metric to monitor
train_loss = mx.metric.Loss()
val_acc = mx.metric.Accuracy()

# Define the estimator, by passing to it the model, loss function, metrics, trainer object and context
est = estimator.Estimator(net=resnet_18_v1,
loss=loss_fn,
train_metrics=[train_acc, train_loss],
val_metrics=val_acc,
trainer=trainer,
context=ctx)

aim_log_handler = AimLoggingHandler(repo='.tmp_mxnet', experiment_name='mxnet_example',
log_interval=1, metrics=[train_acc, train_loss, val_acc])

# ignore warnings for nightly test on CI only
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# Magic line
est.fit(train_data=train_data_loader, val_data=val_data_loader,
epochs=num_epochs, event_handlers=[aim_log_handler])
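
# Once training is done, the tracked params and metrics live in the Aim repo created
# above ('.tmp_mxnet'). Assuming the Aim CLI is installed, the runs can be explored
# in the web UI with: `aim up --repo .tmp_mxnet`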