diff --git a/dependencies/recommended.txt b/dependencies/recommended.txt
index a7d2bfbbfb..79a148c3d4 100644
--- a/dependencies/recommended.txt
+++ b/dependencies/recommended.txt
@@ -6,7 +6,7 @@
torch == 1.6.0+cpu ; sys_platform != "darwin"
torch == 1.6.0 ; sys_platform == "darwin"
torchvision == 0.7.0+cpu ; sys_platform != "darwin"
torchvision == 0.7.0 ; sys_platform == "darwin"
-pytorch-lightning >= 1.1.1, < 1.2
+pytorch-lightning >= 1.1.1
onnx
peewee
graphviz
diff --git a/docs/en_US/Compression/Overview.rst b/docs/en_US/Compression/Overview.rst
index 5b63927af6..262d9631f1 100644
--- a/docs/en_US/Compression/Overview.rst
+++ b/docs/en_US/Compression/Overview.rst
@@ -87,6 +87,8 @@ Quantization algorithms compress the original network by reducing the number of
     - DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients. `Reference Paper `__
   * - `BNN Quantizer <../Compression/Quantizer.rst#bnn-quantizer>`__
     - Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1. `Reference Paper `__
+  * - `LSQ Quantizer <../Compression/Quantizer.rst#lsq-quantizer>`__
+    - Learned step size quantization. `Reference Paper <https://arxiv.org/pdf/1902.08153.pdf>`__

Model Speedup
diff --git a/docs/en_US/Compression/Quantizer.rst b/docs/en_US/Compression/Quantizer.rst
index cc164c5296..2af973e4f2 100644
--- a/docs/en_US/Compression/Quantizer.rst
+++ b/docs/en_US/Compression/Quantizer.rst
@@ -8,6 +8,7 @@ Index of supported quantization algorithms
* `QAT Quantizer <#qat-quantizer>`__
* `DoReFa Quantizer <#dorefa-quantizer>`__
* `BNN Quantizer <#bnn-quantizer>`__
+* `LSQ Quantizer <#lsq-quantizer>`__

Naive Quantizer
---------------
@@ -86,6 +87,61 @@ note

batch normalization folding is currently not supported.

+----
+
+LSQ Quantizer
+-------------
+
+In `Learned Step Size Quantization <https://arxiv.org/pdf/1902.08153.pdf>`__, authors Steven K. Esser and Jeffrey L. McKinstry provide an algorithm to train the quantization scales with gradients.
+
+..
+
+   The authors introduce a novel means to estimate and scale the task loss gradient at each weight and activation layer's quantizer step size, such that it can be learned in conjunction with other network parameters.
+
+
+Usage
+^^^^^
+You can add the code below before your training code. Three things must be done:
+
+
+1. configure which layers to quantize and which tensors (input/output/weight) of those layers to quantize;
+2. construct the LSQ quantizer;
+3. call the ``compress`` API.
+
+
+PyTorch code
+
+.. code-block:: python
+
+   from nni.algorithms.compression.pytorch.quantization import LsqQuantizer
+   model = Mnist()
+
+   configure_list = [{
+       'quant_types': ['weight', 'input'],
+       'quant_bits': {
+           'weight': 8,
+           'input': 8,
+       },
+       'op_names': ['conv1']
+   }, {
+       'quant_types': ['output'],
+       'quant_bits': {'output': 8,},
+       'op_names': ['relu1']
+   }]
+
+   quantizer = LsqQuantizer(model, configure_list, optimizer)
+   quantizer.compress()
+
+You can view the example for more information: :githublink:`examples/model_compress/quantization/LSQ_torch_quantizer.py `
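+Since LSQ learns the step sizes during training, keep training the model after calling ``compress`` so that the scale parameters (which the quantizer registers with your optimizer) are actually updated. A minimal sketch of such a loop, reusing ``model`` and ``optimizer`` from the snippet above and assuming a standard MNIST ``train_loader`` with ``F`` as ``torch.nn.functional``:
+
+.. code-block:: python
+
+   for epoch in range(40):
+       for data, target in train_loader:
+           optimizer.zero_grad()
+           loss = F.nll_loss(model(data), target)
+           loss.backward()
+           # one step updates both the weights and the learned step sizes
+           optimizer.step()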
+User configuration for LSQ Quantizer
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Common configuration needed by compression algorithms can be found at `Specification of config_list <./QuickStart.rst>`__.
+
+Configuration needed by this algorithm:
+

----

DoReFa Quantizer
diff --git a/docs/en_US/NAS/retiarii/retiarii_index.rst b/docs/en_US/NAS/retiarii/retiarii_index.rst
index 579567f309..73ea9c9aa0 100644
--- a/docs/en_US/NAS/retiarii/retiarii_index.rst
+++ b/docs/en_US/NAS/retiarii/retiarii_index.rst
@@ -2,7 +2,13 @@
Retiarii Overview
#################

-`Retiarii `__ is a new framework to support neural architecture search and hyper-parameter tuning. It allows users to express various search space with high flexibility, to reuse many SOTA search algorithms, and to leverage system level optimizations to speed up the search process. This framework provides the following new user experiences.
+`Retiarii `__ is a deep learning framework that supports exploratory training on a neural network model space, rather than on a single neural network model.
+
+Exploratory training with Retiarii allows users to express various search spaces for **Neural Architecture Search** and **Hyper-Parameter Tuning** with high flexibility.
+
+As with the previous NAS and HPO support, the new framework retains the ability to reuse SOTA search algorithms and to leverage system-level optimizations to speed up the search process.
+
+Follow the instructions below to start your journey with Retiarii.

.. toctree::
    :maxdepth: 2

    One-shot NAS
    Advanced Tutorial
    Customize a New Strategy
-    Retiarii APIs
\ No newline at end of file
+    Retiarii APIs
diff --git a/docs/en_US/TrainingService/HybridMode.rst b/docs/en_US/TrainingService/HybridMode.rst
index 854c60da47..76cc940457 100644
--- a/docs/en_US/TrainingService/HybridMode.rst
+++ b/docs/en_US/TrainingService/HybridMode.rst
@@ -15,40 +15,25 @@ Use ``examples/trials/mnist-tfv1`` as an example. The NNI config YAML file's con

.. code-block:: yaml

-   authorName: default
-   experimentName: example_mnist
+   experimentName: MNIST
+   searchSpaceFile: search_space.json
+   trialCommand: python3 mnist.py
+   trialCodeDirectory: .
    trialConcurrency: 2
-   maxExecDuration: 1h
-   maxTrialNum: 10
-   trainingServicePlatform: hybrid
-   searchSpacePath: search_space.json
-   #choice: true, false
-   useAnnotation: false
+   trialGpuNumber: 0
+   maxExperimentDuration: 24h
+   maxTrialNumber: 100
    tuner:
-     builtinTunerName: TPE
+     name: TPE
      classArgs:
-       #choice: maximize, minimize
        optimize_mode: maximize
-   trial:
-     command: python3 mnist.py
-     codeDir: .
-     gpuNum: 1
-   hybridConfig:
-     trainingServicePlatforms:
-       - local
-       - remote
-   remoteConfig:
-     reuse: true
-   machineList:
-     - ip: 10.1.1.1
-       username: bob
-       passwd: bob123
-
-Configurations for hybrid mode:
-
-hybridConfig:
-
-* trainingServicePlatforms. required key. This field specify the platforms used in hybrid mode, the values using yaml list format. NNI support setting ``local``, ``remote``, ``aml``, ``pai`` in this field.
-
-
-.. Note:: If setting a platform in trainingServicePlatforms mode, users should also set the corresponding configuration for the platform. For example, if set ``remote`` as one of the platform, should also set ``machineList`` and ``remoteConfig`` configuration.
\ No newline at end of file
+   trainingService:
+     - platform: remote
+       machineList:
+         - host: 127.0.0.1
+           user: bob
+           password: bob
+     - platform: local
+
+To use hybrid training services, users should set the training service configurations as a list in the ``trainingService`` field.
+Currently, hybrid mode supports the ``local``, ``remote``, ``pai`` and ``aml`` training services.
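+If you prefer to generate this configuration programmatically, the same structure can be written out with PyYAML (a sketch; ``yaml`` is the third-party ``pyyaml`` package, and the host and credentials below are placeholders):
+
+.. code-block:: python
+
+   import yaml
+
+   config = {
+       'experimentName': 'MNIST',
+       'trialCommand': 'python3 mnist.py',
+       'trialConcurrency': 2,
+       'tuner': {'name': 'TPE', 'classArgs': {'optimize_mode': 'maximize'}},
+       # hybrid mode: trainingService is a list with one entry per platform
+       'trainingService': [
+           {'platform': 'remote',
+            'machineList': [{'host': '127.0.0.1', 'user': 'bob', 'password': 'bob'}]},
+           {'platform': 'local'},
+       ],
+   }
+
+   with open('config.yml', 'w') as f:
+       yaml.dump(config, f, sort_keys=False)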
diff --git a/docs/en_US/Tutorial/Nnictl.rst b/docs/en_US/Tutorial/Nnictl.rst index 98cf1a62bc..4b3d40f7b2 100644 --- a/docs/en_US/Tutorial/Nnictl.rst +++ b/docs/en_US/Tutorial/Nnictl.rst @@ -28,7 +28,6 @@ nnictl support commands: * `nnictl config <#config>`__ * `nnictl log <#log>`__ * `nnictl webui <#webui>`__ -* `nnictl tensorboard <#tensorboard>`__ * `nnictl algo <#algo>`__ * `nnictl ss_gen <#ss_gen>`__ * `nnictl --version <#version>`__ @@ -1311,97 +1310,6 @@ Manage webui - Experiment ID -:raw-html:`` - -Manage tensorboard -^^^^^^^^^^^^^^^^^^ - - -* - **nnictl tensorboard start** - - - * - Description - - Start the tensorboard process. - - * - Usage - - .. code-block:: bash - - nnictl tensorboard start - - * - Options - -.. list-table:: - :header-rows: 1 - :widths: auto - - * - Name, shorthand - - Required - - Default - - Description - * - id - - False - - - - ID of the experiment you want to set - * - --trial_id, -T - - False - - - - ID of the trial - * - --port - - False - - 6006 - - The port of the tensorboard process - - - -* - Detail - - - #. NNICTL support tensorboard function in local and remote platform for the moment, other platforms will be supported later. - #. If you want to use tensorboard, you need to write your tensorboard log data to environment variable [NNI_OUTPUT_DIR] path. - #. In local mode, nnictl will set --logdir=[NNI_OUTPUT_DIR] directly and start a tensorboard process. - #. In remote mode, nnictl will create a ssh client to copy log data from remote machine to local temp directory firstly, and then start a tensorboard process in your local machine. You need to notice that nnictl only copy the log data one time when you use the command, if you want to see the later result of tensorboard, you should execute nnictl tensorboard command again. - #. If there is only one trial job, you don't need to set trial id. If there are multiple trial jobs running, you should set the trial id, or you could use [nnictl tensorboard start --trial_id all] to map --logdir to all trial log paths. - - -* - **nnictl tensorboard stop** - - - * - Description - - Stop all of the tensorboard process. - - * - Usage - - .. code-block:: bash - - nnictl tensorboard stop - - * - Options - -.. 
list-table::
-   :header-rows: 1
-   :widths: auto
-
-   * - Name, shorthand
-     - Required
-     - Default
-     - Description
-   * - id
-     - False
-     -
-     - ID of the experiment you want to set
-
:raw-html:``
diff --git a/docs/en_US/builtin_tuner.rst b/docs/en_US/builtin_tuner.rst
index 7acd002808..1e6fd36a73 100644
--- a/docs/en_US/builtin_tuner.rst
+++ b/docs/en_US/builtin_tuner.rst
@@ -10,9 +10,7 @@ Tuner receives metrics from `Trial` to evaluate the performance of a specific pa
    :maxdepth: 1

    Overview
-    TPE
-    Random Search
-    Anneal
+    TPE / Random Search / Anneal
    Naive Evolution
    SMAC
    Metis Tuner
diff --git a/docs/en_US/conf.py b/docs/en_US/conf.py
index 9190022b9b..794f97fd1a 100644
--- a/docs/en_US/conf.py
+++ b/docs/en_US/conf.py
@@ -201,4 +201,4 @@
# -- Extension configuration -------------------------------------------------

def setup(app):
-    app.add_stylesheet('css/custom.css')
+    app.add_css_file('css/custom.css')
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 5c7426c2e9..04c0633cba 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,4 +1,4 @@
-sphinx>=3.3.1
+sphinx>=4.0
sphinx-argparse
sphinx-rtd-theme
sphinxcontrib-websupport
diff --git a/examples/model_compress/quantization/LSQ_torch_quantizer.py b/examples/model_compress/quantization/LSQ_torch_quantizer.py
new file mode 100644
index 0000000000..449a4e179c
--- /dev/null
+++ b/examples/model_compress/quantization/LSQ_torch_quantizer.py
@@ -0,0 +1,142 @@
+import torch
+import torch.nn.functional as F
+from torchvision import datasets, transforms
+from nni.algorithms.compression.pytorch.quantization import LsqQuantizer
+from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT
+
+
+class Mnist(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)
+        self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)
+        self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)
+        self.fc2 = torch.nn.Linear(500, 10)
+        self.relu1 = torch.nn.ReLU6()
+        self.relu2 = torch.nn.ReLU6()
+        self.relu3 = torch.nn.ReLU6()
+        self.max_pool1 = torch.nn.MaxPool2d(2, 2)
+        self.max_pool2 = torch.nn.MaxPool2d(2, 2)
+
+    def forward(self, x):
+        x = self.relu1(self.conv1(x))
+        x = self.max_pool1(x)
+        x = self.relu2(self.conv2(x))
+        x = self.max_pool2(x)
+        x = x.view(-1, 4 * 4 * 50)
+        x = self.relu3(self.fc1(x))
+        x = self.fc2(x)
+        return F.log_softmax(x, dim=1)
+
+
+def train(model, quantizer, device, train_loader, optimizer):
+    model.train()
+    for batch_idx, (data, target) in enumerate(train_loader):
+        data, target = data.to(device), target.to(device)
+        optimizer.zero_grad()
+        output = model(data)
+        loss = F.nll_loss(output, target)
+        loss.backward()
+        optimizer.step()
+        if batch_idx % 100 == 0:
+            print('{:2.0f}%  Loss {}'.format(100 * batch_idx / len(train_loader), loss.item()))
+
+
+def test(model, device, test_loader):
+    model.eval()
+    test_loss = 0
+    correct = 0
+    with torch.no_grad():
+        for data, target in test_loader:
+            data, target = data.to(device), target.to(device)
+            output = model(data)
+            test_loss += F.nll_loss(output, target, reduction='sum').item()
+            pred = output.argmax(dim=1, keepdim=True)
+            correct += pred.eq(target.view_as(pred)).sum().item()
+    test_loss /= len(test_loader.dataset)
+
+    print('Loss: {}  Accuracy: {}%\n'.format(
+        test_loss, 100 * correct / len(test_loader.dataset)))
+
+
+def test_trt(engine, test_loader):
+    test_loss = 0
+    correct = 0
+    time_elapsed = 0
+    for data, target in test_loader:
+        output, time = engine.inference(data)
+        test_loss += F.nll_loss(output, target,
reduction='sum').item()
+        pred = output.argmax(dim=1, keepdim=True)
+        correct += pred.eq(target.view_as(pred)).sum().item()
+        time_elapsed += time
+    test_loss /= len(test_loader.dataset)
+
+    print('Loss: {}  Accuracy: {}%'.format(
+        test_loss, 100 * correct / len(test_loader.dataset)))
+    print("Inference elapsed_time (whole dataset): {}s".format(time_elapsed))
+
+
+def main():
+    torch.manual_seed(0)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
+    train_loader = torch.utils.data.DataLoader(
+        datasets.MNIST('data', train=True, download=True, transform=trans),
+        batch_size=64, shuffle=True)
+    test_loader = torch.utils.data.DataLoader(
+        datasets.MNIST('data', train=False, transform=trans),
+        batch_size=1000, shuffle=True)
+
+    model = Mnist()
+    configure_list = [{
+        'quant_types': ['weight', 'input'],
+        'quant_bits': {'weight': 8, 'input': 8},
+        'op_names': ['conv1']
+    }, {
+        'quant_types': ['output'],
+        'quant_bits': {'output': 8, },
+        'op_names': ['relu1']
+    }, {
+        'quant_types': ['weight', 'input'],
+        'quant_bits': {'weight': 8, 'input': 8},
+        'op_names': ['conv2']
+    }, {
+        'quant_types': ['output'],
+        'quant_bits': {'output': 8},
+        'op_names': ['relu2']
+    }, {
+        'quant_types': ['output'],
+        'quant_bits': {'output': 8},
+        'op_names': ['max_pool2']
+    }
+    ]
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
+    quantizer = LsqQuantizer(model, configure_list, optimizer)
+    quantizer.compress()
+
+    model.to(device)
+    for epoch in range(40):
+        print('# Epoch {} #'.format(epoch))
+        train(model, quantizer, device, train_loader, optimizer)
+        test(model, device, test_loader)
+
+    model_path = "mnist_model.pth"
+    calibration_path = "mnist_calibration.pth"
+    calibration_config = quantizer.export_model(model_path, calibration_path)
+
+    test(model, device, test_loader)
+
+    print("calibration_config: ", calibration_config)
+
+    batch_size = 32
+    input_shape = (batch_size, 1, 28, 28)
+
+    engine = ModelSpeedupTensorRT(model, input_shape, config=calibration_config, batchsize=batch_size)
+    engine.compress()
+
+    test_trt(engine, test_loader)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/nni/algorithms/compression/pytorch/pruning/amc/channel_pruning_env.py b/nni/algorithms/compression/pytorch/pruning/amc/channel_pruning_env.py
index 443daf7efb..428f7e7532 100644
--- a/nni/algorithms/compression/pytorch/pruning/amc/channel_pruning_env.py
+++ b/nni/algorithms/compression/pytorch/pruning/amc/channel_pruning_env.py
@@ -85,7 +85,7 @@ class ChannelPruningEnv:
    args: A Namespace object containing following arguments:
        model_type: str
-            model type to prune, currently 'mobilenet' and 'mobilenetv2' are supported.
+            model type to prune, currently 'mobilenet', 'mobilenetv2' and 'resnet' are supported.
        flops_ratio: float
            preserve flops ratio.
lbound: float
diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py
index ca40e30e45..62703d449b 100644
--- a/nni/algorithms/compression/pytorch/quantization/quantizers.py
+++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py
@@ -6,9 +6,9 @@
import torch
from schema import Schema, And, Or, Optional
from nni.compression.pytorch.utils.config_validation import CompressorSchema
-from nni.compression.pytorch.compressor import Quantizer, QuantGrad, QuantType
+from nni.compression.pytorch.compressor import Quantizer, QuantForward, QuantGrad, QuantType

-__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer']
+__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer']

logger = logging.getLogger(__name__)

@@ -59,7 +59,7 @@ def update_ema(biased_ema, value, decay):
    float, float
    """
    biased_ema = biased_ema * decay + (1 - decay) * value
-    return biased_ema 
+    return biased_ema


def update_quantization_param(bits, rmin, rmax):
@@ -146,7 +146,7 @@ def __init__(self, model, config_list, optimizer=None):
            types of nn.module you want to apply quantization, eg. 'Conv2d'
        """
        super().__init__(model, config_list, optimizer)
-        self.quant_grad = QATGrad
+        self.quant_grad = QATGrad.apply
        modules_to_compress = self.get_modules_to_compress()
        self.bound_model.register_buffer("steps", torch.Tensor([1]))
        for layer, config in modules_to_compress:
@@ -474,7 +474,7 @@ class BNNQuantizer(Quantizer):

    def __init__(self, model, config_list, optimizer=None):
        super().__init__(model, config_list, optimizer)
-        self.quant_grad = ClipGrad
+        self.quant_grad = ClipGrad.apply
        modules_to_compress = self.get_modules_to_compress()
        for layer, config in modules_to_compress:
            if "weight" in config.get("quant_types", []):
@@ -559,4 +559,206 @@ def export_model(self, model_path, calibration_path=None, onnx_path=None, input_
        self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path,
                               onnx_path, input_shape, device)

-        return calibration_config
\ No newline at end of file
+        return calibration_config
+
+
+class LsqQuantizer(Quantizer):
+    """Quantizer defined in:
+       Learned Step Size Quantization (ICLR 2020)
+       https://arxiv.org/pdf/1902.08153.pdf
+    """
+
+    def __init__(self, model, config_list, optimizer=None):
+        """
+        Parameters
+        ----------
+        model : torch.nn.Module
+            the model to be quantized
+        config_list : list of dict
+            list of configurations for quantization
+            supported keys for dict:
+                - quant_types : list of string
+                    types of quantization you want to apply; currently 'weight', 'input' and 'output' are supported
+                - quant_bits : int or dict of {str : int}
+                    bits length of quantization, key is the quantization type, value is the length, eg. {'weight': 8};
+                    when the type is int, all quantization types share the same bits length
+                - quant_start_step : int
+                    disable quantization until the model has been run for a certain number of steps; this allows the
+                    network to enter a more stable state where activation quantization ranges do not exclude a
+                    significant fraction of values. Default value is 0.
+                - op_types : list of string
+                    types of nn.module you want to apply quantization, eg. 'Conv2d'
+        """
+        super().__init__(model, config_list, optimizer)
+        self.quant_grad = QuantForward()
+        modules_to_compress = self.get_modules_to_compress()
+        self.bound_model.register_buffer("steps", torch.Tensor([1]))
+        for layer, config in modules_to_compress:
+            if "weight" in config.get("quant_types", []):
+                layer.module.register_parameter("weight_scale", torch.nn.Parameter(torch.Tensor([1.0])))
+                # todo: support per-channel quantization for weight since TensorRT use it for conv weight
+                q_bit = get_bits_length(config, "weight")
+                layer.module.register_buffer('weight_bit', torch.Tensor([q_bit]))
+                qmax = 2 ** (q_bit - 1) - 1
+                qmin = -2 ** (q_bit - 1)
+                init_weight_scale = layer.module.weight.data.detach().abs().mean() * 2 / (qmax ** 0.5)
+                layer.module.weight_scale = torch.nn.Parameter(init_weight_scale)
+                layer.module.weight_qmax = qmax
+                layer.module.weight_qmin = qmin
+
+                self.optimizer.add_param_group({"params": layer.module.weight_scale})
+
+            if "output" in config.get("quant_types", []):
+                # scale of activation will be initialized using the first batch data
+                layer.module.register_parameter("output_scale", torch.nn.Parameter(torch.Tensor([1.0])))
+                q_bit = get_bits_length(config, "output")
+                layer.module.register_buffer('output_bit', torch.Tensor([q_bit]))
+                qmax = 2 ** (q_bit - 1) - 1
+                qmin = -2 ** (q_bit - 1)
+                layer.module.output_qmax = qmax
+                layer.module.output_qmin = qmin
+
+                self.optimizer.add_param_group({"params": layer.module.output_scale})
+
+            if "input" in config.get("quant_types", []):
+                # scale of input will be initialized using the first batch data
+                layer.module.register_parameter("input_scale", torch.nn.Parameter(torch.Tensor([1.0])))
+                q_bit = get_bits_length(config, "input")
+                layer.module.register_buffer('input_bit', torch.Tensor([q_bit]))
+                qmax = 2 ** (q_bit - 1) - 1
+                qmin = -2 ** (q_bit - 1)
+                layer.module.input_qmax = qmax
+                layer.module.input_qmin = qmin
+
+                self.optimizer.add_param_group({"params": layer.module.input_scale})
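+    # For intuition (a side note with hypothetical numbers): the weight scale above
+    # follows the paper's suggested initialization s = 2 * mean(|w|) / sqrt(qmax).
+    # With 8-bit weights qmax = 127, so a layer whose weights have mean(|w|) = 0.05
+    # starts from s = 2 * 0.05 / sqrt(127) ≈ 0.0089.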
+    @staticmethod
+    def grad_scale(x, scale):
+        r"""
+        Used to scale the gradient. Given tensor `x`, we have `y = grad_scale(x, scale) = x` in the forward pass,
+        which means that this function will not change the value of `x`. In the backward pass, we have:
+
+        :math:`\frac{\partial L}{\partial x}=\frac{\partial L}{\partial y}*\frac{\partial y}{\partial x}=scale*\frac{\partial L}{\partial y}`
+
+        This means that the original gradient of `x` is scaled by a factor of `scale`. Applying this function
+        to an nn.Parameter will scale its gradient without changing its value.
+        """
+        y = x
+        y_grad = x * scale
+        return (y - y_grad).detach() + y_grad
+
+    @staticmethod
+    def round_pass(x):
+        """
+        A straight-through estimator (STE) for rounding: `x` is rounded in the forward pass,
+        while the gradient passes through unchanged in the backward pass.
+        """
+        y = x.round()
+        y_grad = x
+        return (y - y_grad).detach() + y_grad
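+    # A quick sanity check of the two straight-through helpers above (a sketch, not
+    # part of the algorithm): `grad_scale` keeps the forward value but scales the
+    # gradient, while `round_pass` rounds forward and has gradient 1 backward:
+    #
+    #     x = torch.tensor(1.5, requires_grad=True)
+    #     y = LsqQuantizer.grad_scale(x, 0.25)
+    #     y.backward()
+    #     assert y.item() == 1.5 and x.grad.item() == 0.25
+    #
+    #     x = torch.tensor(1.3, requires_grad=True)
+    #     LsqQuantizer.round_pass(x).backward()
+    #     assert x.grad.item() == 1.0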
+    def quantize(self, x, scale, qmin, qmax):
+        grad_scale_factor = 1.0 / ((qmax * x.numel()) ** 0.5)
+        scale = self.grad_scale(scale, grad_scale_factor)
+        x = x / scale
+        x = torch.clamp(x, qmin, qmax)
+        x = self.round_pass(x)
+        x = x * scale
+        return x
+
+    def quantize_weight(self, wrapper, **kwargs):
+        module = wrapper.module
+
+        # todo: add support for quantizing bias; if we use TensorRT as the backend,
+        # there is no need to quantize bias
+        old_weight = module.old_weight
+        weight = self.quantize(old_weight, module.weight_scale, module.weight_qmin, module.weight_qmax)
+        module.weight = weight
+        return weight
+
+    def quantize_output(self, output, wrapper, **kwargs):
+        module = wrapper.module
+
+        # initialize the scale
+        if self.bound_model.steps == 1:
+            qmax = module.output_qmax
+            init_oup_scale = output.data.detach().abs().mean() * 2 / (qmax ** 0.5)
+            module.output_scale.data = init_oup_scale
+
+        output = self.quantize(output, module.output_scale, module.output_qmin, module.output_qmax)
+        return output
+
+    def quantize_input(self, *inputs, wrapper, **kwargs):
+        # This is hacky since it is not recommended to modify a tuple
+        # NB: support layers with multiple inputs
+        module = wrapper.module
+        # initialize the scale
+        if self.bound_model.steps == 1:
+            qmax = module.input_qmax
+            init_inp_scale = inputs[0].data.detach().abs().mean() * 2 / (qmax ** 0.5)
+            module.input_scale.data = init_inp_scale
+
+        new_input = self.quantize(inputs[0], module.input_scale, module.input_qmin, module.input_qmax)
+        list_inp = list(inputs)
+        list_inp[0] = new_input
+        return tuple(list_inp)
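+    # For intuition, `quantize` above fake-quantizes a tensor onto the integer grid
+    # defined by `scale` and maps it back. A worked example with hypothetical numbers:
+    #
+    #     x = 0.37, scale = 0.1, qmin = -128, qmax = 127
+    #     x / scale = 3.7  ->  clamp(3.7, -128, 127) = 3.7  ->  round = 4  ->  4 * scale = 0.4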
+    def export_model(self, model_path, calibration_path=None, onnx_path=None, input_shape=None, device=None):
+        """
+        Export quantized model weights and calibration parameters (optional)
+
+        Parameters
+        ----------
+        model_path : str
+            path to save the quantized model weights
+        calibration_path : str
+            (optional) path to save quantization parameters after calibration
+        onnx_path : str
+            (optional) path to save the onnx model
+        input_shape : list or tuple
+            input shape to the onnx model
+        device : torch.device
+            device of the model, used to place the dummy input tensor when exporting the onnx file;
+            the tensor is placed on cpu if ``device`` is None
+
+        Returns
+        -------
+        Dict
+        """
+        assert model_path is not None, 'model_path must be specified'
+        self._unwrap_model()
+        calibration_config = {}
+
+        for name, module in self.bound_model.named_modules():
+            if hasattr(module, 'input_bit') or hasattr(module, 'output_bit'):
+                calibration_config[name] = {}
+            if hasattr(module, 'weight_bit'):
+                calibration_config[name]['weight_bit'] = int(module.weight_bit)
+                abs_max_input = float(module.input_scale * module.input_qmax)
+                calibration_config[name]['tracked_min_input'] = -abs_max_input
+                calibration_config[name]['tracked_max_input'] = abs_max_input
+            if hasattr(module, 'output_bit'):
+                calibration_config[name]['activation_bit'] = int(module.output_bit)
+                abs_max_output = float(module.output_scale * module.output_qmax)
+                calibration_config[name]['tracked_min_activation'] = -abs_max_output
+                calibration_config[name]['tracked_max_activation'] = abs_max_output
+            self._del_simulated_attr(module)
+
+        self.export_model_save(self.bound_model, model_path, calibration_config, calibration_path, onnx_path,
+                               input_shape, device)
+
+        return calibration_config
+
+    def _del_simulated_attr(self, module):
+        """
+        delete redundant parameters in the quantized module
+        """
+        del_attr_list = ['old_weight', 'tracked_min_input', 'tracked_max_input', 'tracked_min_activation', \
+                         'tracked_max_activation', 'output_scale', 'input_scale', 'weight_scale', 'weight_bit', 'output_bit', 'input_bit']
+        for attr in del_attr_list:
+            if hasattr(module, attr):
+                delattr(module, attr)
+
+    def step_with_optimizer(self):
+        """
+        override `Compressor`'s `step` method to count optimizer steps; the step counter
+        is used to initialize the input/output scales on the first batch
+        """
+        self.bound_model.steps += 1
diff --git a/nni/algorithms/hpo/networkmorphism_tuner/networkmorphism_tuner.py b/nni/algorithms/hpo/networkmorphism_tuner/networkmorphism_tuner.py
index 6a73cad3c6..385028506d 100644
--- a/nni/algorithms/hpo/networkmorphism_tuner/networkmorphism_tuner.py
+++ b/nni/algorithms/hpo/networkmorphism_tuner/networkmorphism_tuner.py
@@ -225,7 +225,7 @@ def update(self, other_info, graph, metric_value, model_id):
        ----------
        other_info: any object
            In our case it is the father ID in the search tree.
-        graph: Graph
+        graph: graph.Graph
            An instance of Graph. The trained neural architecture.
        metric_value: float
            The final evaluated metric value.
@@ -284,7 +284,7 @@ def load_model_by_id(self, model_id): Returns ------- - load_model : Graph + load_model : graph.Graph the model graph representation """ @@ -300,7 +300,7 @@ def load_best_model(self): Returns ------- - load_model : Graph + load_model : graph.Graph the model graph representation """ return self.load_model_by_id(self.get_best_model_id()) diff --git a/nni/compression/pytorch/compressor.py b/nni/compression/pytorch/compressor.py index 7fecdc3b4f..08543caf1a 100644 --- a/nni/compression/pytorch/compressor.py +++ b/nni/compression/pytorch/compressor.py @@ -474,13 +474,13 @@ def __init__(self, module, module_name, module_type, config, quantizer): def forward(self, *inputs): if 'input' in self.config['quant_types']: - inputs = self.quantizer.quant_grad.apply( + inputs = self.quantizer.quant_grad( inputs, QuantType.QUANT_INPUT, self) if 'weight' in self.config['quant_types'] and _check_weight(self.module): - self.quantizer.quant_grad.apply( + self.quantizer.quant_grad( self.module.old_weight, QuantType.QUANT_WEIGHT, self, inputs[0]) @@ -489,12 +489,13 @@ def forward(self, *inputs): result = self.module(*inputs) if 'output' in self.config['quant_types']: - result = self.quantizer.quant_grad.apply( + result = self.quantizer.quant_grad( result, QuantType.QUANT_OUTPUT, self) return result + class Quantizer(Compressor): """ Base quantizer for pytorch quantizer @@ -502,7 +503,7 @@ class Quantizer(Compressor): def __init__(self, model, config_list, optimizer=None): super().__init__(model, config_list, optimizer) - self.quant_grad = QuantGrad + self.quant_grad = QuantGrad.apply if self.optimizer is not None: self.patch_optimizer(self.step_with_optimizer) for wrapper in self.get_modules_wrapper(): @@ -719,15 +720,7 @@ def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qma @staticmethod def forward(ctx, tensor, quant_type, wrapper, input_tensor=None, **kwargs): - if quant_type == QuantType.QUANT_INPUT: - output = wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs) - elif quant_type == QuantType.QUANT_WEIGHT: - output = wrapper.quantizer.quantize_weight(wrapper, input_tensor=input_tensor, **kwargs) - elif quant_type == QuantType.QUANT_OUTPUT: - output = wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs) - else: - raise ValueError("unrecognized QuantType.") - + output = quantize_helper(tensor, quant_type, wrapper, input_tensor, **kwargs) bits = QuantGrad.get_bits_length(wrapper.config, QType_Dict[quant_type]) qmin, qmax = torch.Tensor([0]).to(tensor.device), torch.Tensor([(1 << bits) - 1]).to(tensor.device) @@ -750,3 +743,24 @@ def _check_weight(module): return isinstance(module.weight.data, torch.Tensor) except AttributeError: return False + +def quantize_helper(tensor, quant_type, wrapper, input_tensor=None, **kwargs): + if quant_type == QuantType.QUANT_INPUT: + output = wrapper.quantizer.quantize_input(*tensor, wrapper=wrapper, **kwargs) + elif quant_type == QuantType.QUANT_WEIGHT: + output = wrapper.quantizer.quantize_weight(wrapper, input_tensor=input_tensor, **kwargs) + elif quant_type == QuantType.QUANT_OUTPUT: + output = wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs) + else: + raise ValueError("unrecognized QuantType.") + + return output + +class QuantForward(torch.nn.Module): + """ + Base class for executing quantization operations. This is for quantization algorithms + that do not need to customize gradient. 
+ """ + + def forward(self, tensor, quant_type, wrapper, input_tensor=None, **kwargs): + return quantize_helper(tensor, quant_type, wrapper, input_tensor, **kwargs) diff --git a/nni/compression/pytorch/utils/mask_conflict.py b/nni/compression/pytorch/utils/mask_conflict.py index 7b7ea03719..8e37893ba4 100644 --- a/nni/compression/pytorch/utils/mask_conflict.py +++ b/nni/compression/pytorch/utils/mask_conflict.py @@ -333,7 +333,7 @@ def fix_mask(self): elif type(m).__name__ == 'Linear': new_mask[:, merged_index] = 1. elif type(m).__name__ == 'BatchNorm2d': - new_mask = merged_index.type_as(orig_mask) + new_mask = merged_channel_mask.type_as(orig_mask) else: raise RuntimeError( f'unsupported module type: {type(m).__name__}') diff --git a/nni/retiarii/evaluator/pytorch/lightning.py b/nni/retiarii/evaluator/pytorch/lightning.py index 3a0d272b13..d316ce857a 100644 --- a/nni/retiarii/evaluator/pytorch/lightning.py +++ b/nni/retiarii/evaluator/pytorch/lightning.py @@ -165,13 +165,18 @@ def _get_validation_metrics(self): return {name: self.trainer.callback_metrics['val_' + name].item() for name in self.metrics} +class _AccuracyWithLogits(pl.metrics.Accuracy): + def update(self, pred, target): + return super().update(nn.functional.softmax(pred), target) + + @serialize_cls class _ClassificationModule(_SupervisedLearningModule): def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss, learning_rate: float = 0.001, weight_decay: float = 0., optimizer: optim.Optimizer = optim.Adam): - super().__init__(criterion, {'acc': pl.metrics.Accuracy}, + super().__init__(criterion, {'acc': _AccuracyWithLogits}, learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer) diff --git a/nni/retiarii/experiment/pytorch.py b/nni/retiarii/experiment/pytorch.py index 2e5df288c6..b7473507a4 100644 --- a/nni/retiarii/experiment/pytorch.py +++ b/nni/retiarii/experiment/pytorch.py @@ -52,7 +52,7 @@ class RetiariiExeConfig(ConfigBase): nni_manager_ip: Optional[str] = None debug: bool = False log_level: Optional[str] = None - experiment_working_directory: Optional[PathLike] = None + experiment_working_directory: PathLike = '~/nni-experiments' # remove configuration of tuner/assessor/advisor training_service: TrainingServiceConfig diff --git a/nni/runtime/log.py b/nni/runtime/log.py index 0382624a8d..0fc97c666f 100644 --- a/nni/runtime/log.py +++ b/nni/runtime/log.py @@ -88,7 +88,7 @@ def _init_logger_dispatcher() -> None: def _init_logger_trial() -> None: log_path = _prepare_log_dir(trial_env_vars.NNI_OUTPUT_DIR) / 'trial.log' - log_file = open(log_path, 'w') + log_file = open(log_path, 'a') _register_handler(StreamHandler(log_file), logging.INFO) if trial_env_vars.NNI_PLATFORM == 'local': diff --git a/nni/runtime/platform/local.py b/nni/runtime/platform/local.py index 681292ba73..b1f26462e2 100644 --- a/nni/runtime/platform/local.py +++ b/nni/runtime/platform/local.py @@ -13,7 +13,7 @@ _sysdir = trial_env_vars.NNI_SYS_DIR if not os.path.exists(os.path.join(_sysdir, '.nni')): os.makedirs(os.path.join(_sysdir, '.nni')) -_metric_file = open(os.path.join(_sysdir, '.nni', 'metrics'), 'wb') +_metric_file = open(os.path.join(_sysdir, '.nni', 'metrics'), 'ab') _outputdir = trial_env_vars.NNI_OUTPUT_DIR if not os.path.exists(_outputdir): diff --git a/nni/tools/nnictl/common_utils.py b/nni/tools/nnictl/common_utils.py index ba4aad7233..3667f8adfd 100644 --- a/nni/tools/nnictl/common_utils.py +++ b/nni/tools/nnictl/common_utils.py @@ -81,14 +81,6 @@ def get_user(): else: return os.environ['USER'] -def 
check_tensorboard_version(): - try: - import tensorboard - return tensorboard.__version__ - except: - print_error('import tensorboard error!') - exit(1) - def generate_temp_dir(): '''generate a temp folder''' def generate_folder_name(): diff --git a/nni/tools/nnictl/launcher.py b/nni/tools/nnictl/launcher.py index 581c8f1d12..7b2d20f9b8 100644 --- a/nni/tools/nnictl/launcher.py +++ b/nni/tools/nnictl/launcher.py @@ -368,7 +368,7 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi log_dir = experiment_config['logDir'] if experiment_config.get('logDir') else NNI_HOME_DIR else: log_dir = experiment_config['experimentWorkingDirectory'] if experiment_config.get('experimentWorkingDirectory') else NNI_HOME_DIR - log_level = experiment_config['logLevel'] if experiment_config.get('logLevel') else None + log_level = experiment_config['logLevel'] if experiment_config.get('logLevel') else 'info' #view experiment mode do not need debug function, when view an experiment, there will be no new logs created foreground = False if mode != 'view': diff --git a/nni/tools/nnictl/nnictl.py b/nni/tools/nnictl/nnictl.py index dbdae0d84d..fd8697337e 100644 --- a/nni/tools/nnictl/nnictl.py +++ b/nni/tools/nnictl/nnictl.py @@ -16,7 +16,6 @@ save_experiment, load_experiment from .algo_management import algo_reg, algo_unreg, algo_show, algo_list from .constants import DEFAULT_REST_PORT -from .tensorboard_utils import start_tensorboard, stop_tensorboard init(autoreset=True) if os.environ.get('COVERAGE_PROCESS_START'): @@ -250,18 +249,6 @@ def show_messsage_for_nnictl_package(args): parser_package_subparsers.add_argument('args', nargs=argparse.REMAINDER) parser_package_subparsers.set_defaults(func=show_messsage_for_nnictl_package) - #parse tensorboard command - parser_tensorboard = subparsers.add_parser('tensorboard', help='manage tensorboard') - parser_tensorboard_subparsers = parser_tensorboard.add_subparsers() - parser_tensorboard_start = parser_tensorboard_subparsers.add_parser('start', help='start tensorboard') - parser_tensorboard_start.add_argument('id', nargs='?', help='the id of experiment') - parser_tensorboard_start.add_argument('--trial_id', '-T', dest='trial_id', help='the id of trial') - parser_tensorboard_start.add_argument('--port', dest='port', default=6006, type=int, help='the port to start tensorboard') - parser_tensorboard_start.set_defaults(func=start_tensorboard) - parser_tensorboard_stop = parser_tensorboard_subparsers.add_parser('stop', help='stop tensorboard') - parser_tensorboard_stop.add_argument('id', nargs='?', help='the id of experiment') - parser_tensorboard_stop.set_defaults(func=stop_tensorboard) - #parse top command parser_top = subparsers.add_parser('top', help='monitor the experiment') parser_top.add_argument('--time', '-t', dest='time', type=int, default=3, help='the time interval to update the experiment status, ' \ diff --git a/nni/tools/nnictl/nnictl_utils.py b/nni/tools/nnictl/nnictl_utils.py index 9637ecb2e7..16942b74ef 100644 --- a/nni/tools/nnictl/nnictl_utils.py +++ b/nni/tools/nnictl/nnictl_utils.py @@ -223,14 +223,6 @@ def stop_experiment(args): rest_pid = experiments_dict.get(experiment_id).get('pid') if rest_pid: kill_command(rest_pid) - tensorboard_pid_list = experiments_dict.get(experiment_id).get('tensorboardPidList') - if tensorboard_pid_list: - for tensorboard_pid in tensorboard_pid_list: - try: - kill_command(tensorboard_pid) - except Exception as exception: - print_error(exception) - experiments_config.update_experiment(experiment_id, 
'tensorboardPidList', []) print_normal('Stop experiment success.') def trial_ls(args): diff --git a/nni/tools/nnictl/tensorboard_utils.py b/nni/tools/nnictl/tensorboard_utils.py deleted file mode 100644 index fe8262c355..0000000000 --- a/nni/tools/nnictl/tensorboard_utils.py +++ /dev/null @@ -1,166 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import os -import json -import re -import tempfile -from subprocess import call, Popen -from .rest_utils import rest_get, check_rest_server_quick, check_response -from .config_utils import Config, Experiments -from .url_utils import trial_jobs_url, get_local_urls -from .constants import REST_TIME_OUT -from .common_utils import print_normal, print_warning, print_error, print_green, detect_process, detect_port, check_tensorboard_version -from .nnictl_utils import check_experiment_id -from .ssh_utils import create_ssh_sftp_client, copy_remote_directory_to_local - -def parse_log_path(args, trial_content): - '''parse log path''' - path_list = [] - host_list = [] - for trial in trial_content: - if args.trial_id and args.trial_id != 'all' and trial.get('trialJobId') != args.trial_id: - continue - pattern = r'(?P.+)://(?P.+):(?P.*)' - match = re.search(pattern, trial['logPath']) - if match: - path_list.append(match.group('path')) - host_list.append(match.group('host')) - if not path_list: - print_error('Trial id %s error!' % args.trial_id) - exit(1) - return path_list, host_list - -def copy_data_from_remote(args, experiment_config, trial_content, path_list, host_list, temp_nni_path): - '''use ssh client to copy data from remote machine to local machien''' - machine_list = experiment_config.get('machineList') - machine_dict = {} - local_path_list = [] - for machine in machine_list: - machine_dict[machine['ip']] = {'port': machine['port'], 'passwd': machine['passwd'], 'username': machine['username'], - 'sshKeyPath': machine.get('sshKeyPath'), 'passphrase': machine.get('passphrase')} - for index, host in enumerate(host_list): - local_path = os.path.join(temp_nni_path, trial_content[index].get('trialJobId')) - local_path_list.append(local_path) - print_normal('Copying log data from %s to %s' % (host + ':' + path_list[index], local_path)) - sftp = create_ssh_sftp_client(host, machine_dict[host]['port'], machine_dict[host]['username'], machine_dict[host]['passwd'], - machine_dict[host]['sshKeyPath'], machine_dict[host]['passphrase']) - copy_remote_directory_to_local(sftp, path_list[index], local_path) - print_normal('Copy done!') - return local_path_list - -def get_path_list(args, experiment_config, trial_content, temp_nni_path): - '''get path list according to different platform''' - path_list, host_list = parse_log_path(args, trial_content) - platform = experiment_config.get('trainingServicePlatform') - if platform == 'local': - print_normal('Log path: %s' % ' '.join(path_list)) - return path_list - elif platform == 'remote': - path_list = copy_data_from_remote(args, experiment_config, trial_content, path_list, host_list, temp_nni_path) - print_normal('Log path: %s' % ' '.join(path_list)) - return path_list - else: - print_error('Not supported platform!') - exit(1) - -def format_tensorboard_log_path(path_list): - new_path_list = [] - for index, value in enumerate(path_list): - new_path_list.append('name%d:%s' % (index + 1, value)) - return ','.join(new_path_list) - -def start_tensorboard_process(args, experiment_id, path_list, temp_nni_path): - '''call cmds to start tensorboard process in local machine''' - if 
detect_port(args.port): - print_error('Port %s is used by another process, please reset port!' % str(args.port)) - exit(1) - with open(os.path.join(temp_nni_path, 'tensorboard_stdout'), 'a+') as stdout_file, \ - open(os.path.join(temp_nni_path, 'tensorboard_stderr'), 'a+') as stderr_file: - log_dir_cmd = '--logdir_spec' if check_tensorboard_version() >= '2.0' else '--logdir' - cmds = ['tensorboard', log_dir_cmd, format_tensorboard_log_path(path_list), '--port', str(args.port)] - tensorboard_process = Popen(cmds, stdout=stdout_file, stderr=stderr_file) - url_list = get_local_urls(args.port) - print_green('Start tensorboard success!') - print_normal('Tensorboard urls: ' + ' '.join(url_list)) - experiments_config = Experiments() - tensorboard_process_pid_list = experiments_config.get_all_experiments().get(experiment_id).get('tensorboardPidList') - if tensorboard_process_pid_list is None: - tensorboard_process_pid_list = [tensorboard_process.pid] - else: - tensorboard_process_pid_list.append(tensorboard_process.pid) - experiments_config.update_experiment(experiment_id, 'tensorboardPidList', tensorboard_process_pid_list) - -def stop_tensorboard(args): - '''stop tensorboard''' - experiment_id = check_experiment_id(args) - experiments_config = Experiments() - tensorboard_pid_list = experiments_config.get_all_experiments().get(experiment_id).get('tensorboardPidList') - if tensorboard_pid_list: - for tensorboard_pid in tensorboard_pid_list: - try: - cmds = ['kill', '-9', str(tensorboard_pid)] - call(cmds) - except Exception as exception: - print_error(exception) - experiments_config.update_experiment(experiment_id, 'tensorboardPidList', []) - print_normal('Stop tensorboard success!') - else: - print_error('No tensorboard configuration!') - -def adl_tensorboard_helper(args): - '''start tensorboard on adl''' - import subprocess - if args.trial_id is not None: - print_warning('Tensorboard on adl platform will show all trials. 
No trial ids needed.') - cmd = "kubectl port-forward --address 0.0.0.0 deployment/{} {}:{}".format( - "adaptdl-tensorboard" + "-" + args.id.lower(), - args.port, - 6006 - ) - print_green('Tensorboard is accessible at 0.0.0.0:{port} or localhost:{port}'.format(port=args.port)) - subprocess.run(args=cmd, shell=True) - -def start_tensorboard(args): - '''start tensorboard''' - experiment_id = check_experiment_id(args) - if not experiment_id: - return - if args.id is None: - args.id = experiment_id - experiments_config = Experiments() - experiments_dict = experiments_config.get_all_experiments() - if experiments_dict[args.id]["status"] == "STOPPED": - print_error("Experiment {} is stopped...".format(args.id)) - return - experiment_config = Config(args.id, experiments_dict[args.id]['logDir']).get_config() - if experiment_config.get('trainingServicePlatform') == 'adl': - adl_tensorboard_helper(args) - return - rest_port = experiments_dict[args.id]['port'] - rest_pid = experiments_dict[args.id]['pid'] - if not detect_process(rest_pid): - print_error('Experiment is not running...') - return - running, response = check_rest_server_quick(rest_port) - trial_content = None - if running: - response = rest_get(trial_jobs_url(rest_port), REST_TIME_OUT) - if response and check_response(response): - trial_content = json.loads(response.text) - else: - print_error('List trial failed...') - else: - print_error('Restful server is not running...') - if not trial_content: - print_error('No trial information!') - exit(1) - if len(trial_content) > 1 and not args.trial_id: - print_error('There are multiple trials, please set trial id!') - exit(1) - experiment_id = args.id - temp_nni_path = os.path.join(tempfile.gettempdir(), 'nni', experiment_id) - os.makedirs(temp_nni_path, exist_ok=True) - - path_list = get_path_list(args, experiment_config, trial_content, temp_nni_path) - start_tensorboard_process(args, experiment_id, path_list, temp_nni_path) diff --git a/pipelines/full-test-linux.yml b/pipelines/full-test-linux.yml index 57ae0c29bf..aaf09d175c 100644 --- a/pipelines/full-test-linux.yml +++ b/pipelines/full-test-linux.yml @@ -31,7 +31,7 @@ jobs: python3 -m pip install scikit-learn==0.24.1 python3 -m pip install torchvision==0.7.0 python3 -m pip install torch==1.6.0 - python3 -m pip install 'pytorch-lightning>=1.1.1,<1.2' + python3 -m pip install 'pytorch-lightning>=1.1.1' python3 -m pip install keras==2.1.6 python3 -m pip install tensorflow==2.3.1 tensorflow-estimator==2.3.0 python3 -m pip install thop diff --git a/pipelines/full-test-windows.yml b/pipelines/full-test-windows.yml index 4a2a44e70b..282b9fd1e7 100644 --- a/pipelines/full-test-windows.yml +++ b/pipelines/full-test-windows.yml @@ -28,7 +28,7 @@ jobs: python -m pip install scikit-learn==0.24.1 python -m pip install keras==2.1.6 python -m pip install torch==1.6.0 torchvision==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html - python -m pip install 'pytorch-lightning>=1.1.1,<1.2' + python -m pip install 'pytorch-lightning>=1.1.1' python -m pip install tensorflow==2.3.1 tensorflow-estimator==2.3.0 displayName: Install extra dependencies diff --git a/pipelines/integration-test-aml.yml b/pipelines/integration-test-aml.yml new file mode 100644 index 0000000000..a862b7b9b9 --- /dev/null +++ b/pipelines/integration-test-aml.yml @@ -0,0 +1,63 @@ +trigger: none +pr: none +schedules: +- cron: 0 16 * * * + branches: + include: [ master ] + +jobs: +- job: aml + pool: NNI CI REMOTE CLI + timeoutInMinutes: 120 + + steps: + - script: | + export 
NNI_RELEASE=999.$(date -u +%Y%m%d%H%M%S)
      echo "##vso[task.setvariable variable=PATH]${PATH}:${HOME}/.local/bin"
      echo "##vso[task.setvariable variable=NNI_RELEASE]${NNI_RELEASE}"

      echo "Working directory: ${PWD}"
      echo "NNI version: ${NNI_RELEASE}"
      echo "Build docker image: $(build_docker_image)"

      python3 -m pip install --upgrade pip setuptools
    displayName: Prepare

  - script: |
      set -e
      python3 setup.py build_ts
      python3 setup.py bdist_wheel -p manylinux1_x86_64
      python3 -m pip install dist/nni-${NNI_RELEASE}-py3-none-manylinux1_x86_64.whl[SMAC,BOHB]
    displayName: Build and install NNI

  - script: |
      set -e
      cd examples/tuners/customized_tuner
      python3 setup.py develop --user
      nnictl algo register --meta meta_file.yml
    displayName: Install customized tuner

  - script: |
      set -e
      docker login -u nnidev -p $(docker_hub_password)
      echo '## Build docker image ##'
      docker build --build-arg NNI_RELEASE=${NNI_RELEASE} -t nnidev/nni-nightly .
      echo '## Upload docker image ##'
      docker push nnidev/nni-nightly
    condition: eq(variables['build_docker_image'], 'true')
    displayName: Build and upload docker image

  - script: |
      set -e
      cd test
      python3 nni_test/nnitest/generate_ts_config.py \
        --ts aml \
        --subscription_id $(subscriptionId) \
        --resource_group $(resourceGroup) \
        --workspace_name $(workspaceName) \
        --compute_target $(computeTarget) \
        --nni_manager_ip $(manager_ip) \
        --nni_docker_image nnidev/nni-nightly

      python3 nni_test/nnitest/run_tests.py --config config/integration_tests.yml --ts aml
    displayName: Integration test
diff --git a/test/config/training_service.yml b/test/config/training_service.yml
index 43838069d9..cc85f57d5c 100644
--- a/test/config/training_service.yml
+++ b/test/config/training_service.yml
@@ -105,4 +105,18 @@ adl:
  storageClass:
  storageSize:
  trainingServicePlatform: adl
-
+aml:
+  nniManagerIp:
+  maxExecDuration: 15m
+  # keep maxTrialNum small to limit the number of submitted trial jobs
+  maxTrialNum: 2
+  trialConcurrency: 2
+  trainingServicePlatform: aml
+  trial:
+    gpuNum: 1
+    image:
+  amlConfig:
+    subscriptionId:
+    resourceGroup:
+    workspaceName:
+    computeTarget:
diff --git a/test/nni_test/nnitest/generate_ts_config.py b/test/nni_test/nnitest/generate_ts_config.py
index 843aba170d..7dcf5465c8 100644
--- a/test/nni_test/nnitest/generate_ts_config.py
+++ b/test/nni_test/nnitest/generate_ts_config.py
@@ -88,13 +88,24 @@ def update_training_service_config(args):
            config[args.ts]['trial']['nfs']['server'] = args.adl_nfs_server
            config[args.ts]['trial']['nfs']['path'] = args.adl_nfs_path
            config[args.ts]['trial']['nfs']['container_mount_path'] = args.nadl_fs_container_mount_path
+    elif args.ts == 'aml':
+        if args.nni_docker_image is not None:
+            config[args.ts]['trial']['image'] = args.nni_docker_image
+        if args.subscription_id is not None:
+            config[args.ts]['amlConfig']['subscriptionId'] = args.subscription_id
+        if args.resource_group is not None:
+            config[args.ts]['amlConfig']['resourceGroup'] = args.resource_group
+        if args.workspace_name is not None:
+            config[args.ts]['amlConfig']['workspaceName'] = args.workspace_name
+        if args.compute_target is not None:
+            config[args.ts]['amlConfig']['computeTarget'] = args.compute_target
    dump_yml_content(TRAINING_SERVICE_FILE, config)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
-    parser.add_argument("--ts",
type=str, choices=['pai', 'kubeflow', 'remote', 'local', 'frameworkcontroller', 'adl', 'aml'], default='pai') parser.add_argument("--nni_docker_image", type=str) parser.add_argument("--nni_manager_ip", type=str) # args for PAI @@ -129,6 +140,11 @@ def update_training_service_config(args): parser.add_argument("--adl_nfs_server", type=str) parser.add_argument("--adl_nfs_path", type=str) parser.add_argument("--adl_nfs_container_mount_path", type=str) + # args for aml + parser.add_argument("--subscription_id", type=str) + parser.add_argument("--resource_group", type=str) + parser.add_argument("--workspace_name", type=str) + parser.add_argument("--compute_target", type=str) args = parser.parse_args() update_training_service_config(args) diff --git a/test/nni_test/nnitest/run_tests.py b/test/nni_test/nnitest/run_tests.py index d0d285829a..0b888cf3b7 100644 --- a/test/nni_test/nnitest/run_tests.py +++ b/test/nni_test/nnitest/run_tests.py @@ -281,7 +281,7 @@ def run(args): parser.add_argument("--cases", type=str, default=None) parser.add_argument("--exclude", type=str, default=None) parser.add_argument("--ts", type=str, choices=['local', 'remote', 'pai', - 'kubeflow', 'frameworkcontroller', 'adl'], default='local') + 'kubeflow', 'frameworkcontroller', 'adl', 'aml'], default='local') args = parser.parse_args() run(args) diff --git a/test/ut/retiarii/test_lightning_trainer.py b/test/ut/retiarii/test_lightning_trainer.py index b9fe9b23d3..e0dd32813d 100644 --- a/test/ut/retiarii/test_lightning_trainer.py +++ b/test/ut/retiarii/test_lightning_trainer.py @@ -3,6 +3,7 @@ import nni import nni.retiarii.evaluator.pytorch.lightning as pl +import nni.runtime.platform.test import pytorch_lightning import torch import torch.nn as nn diff --git a/ts/nni_manager/common/log.ts b/ts/nni_manager/common/log.ts index 9a33a48e49..dff76028e1 100644 --- a/ts/nni_manager/common/log.ts +++ b/ts/nni_manager/common/log.ts @@ -5,151 +5,135 @@ import * as fs from 'fs'; import { Writable } from 'stream'; -import { WritableStreamBuffer } from 'stream-buffers'; -import { format } from 'util'; -import * as component from '../common/component'; -import { getExperimentStartupInfo, isReadonly } from './experimentStartupInfo'; - -const FATAL: number = 1; -const ERROR: number = 2; -const WARNING: number = 3; -const INFO: number = 4; -const DEBUG: number = 5; -const TRACE: number = 6; - -const logLevelNameMap: Map = new Map([ - ['fatal', FATAL], - ['error', ERROR], - ['warning', WARNING], - ['info', INFO], - ['debug', DEBUG], - ['trace', TRACE] + +/* log level constants */ + +export const DEBUG = 10; +export const INFO = 20; +export const WARNING = 30; +export const ERROR = 40; +export const CRITICAL = 50; + +export const TRACE = 1; +export const FATAL = 50; + +const levelNames = new Map([ + [CRITICAL, 'CRITICAL'], + [ERROR, 'ERROR'], + [WARNING, 'WARNING'], + [INFO, 'INFO'], + [DEBUG, 'DEBUG'], + [TRACE, 'TRACE'], ]); -class BufferSerialEmitter { - private buffer: Buffer; - private emitting: boolean; - private writable: Writable; +/* global_ states */ - constructor(writable: Writable) { - this.buffer = Buffer.alloc(0); - this.emitting = false; - this.writable = writable; - } +let logFile: Writable | null = null; +let logLevel: number = 0; +const loggers = new Map(); - public feed(buffer: Buffer): void { - this.buffer = Buffer.concat([this.buffer, buffer]); - if (!this.emitting) { - this.emit(); - } - } +/* major api */ - private emit(): void { - this.emitting = true; - this.writable.write(this.buffer, () => { - if (this.buffer.length 
=== 0) { - this.emitting = false; - } else { - this.emit(); - } - }); - this.buffer = Buffer.alloc(0); +export class Logger { + private name: string; + + constructor(name: string = 'root') { + this.name = name; } -} -@component.Singleton -class Logger { - private level: number = INFO; - private bufferSerialEmitter?: BufferSerialEmitter; - private writable?: Writable; - private readonly: boolean = false; - - constructor(fileName?: string) { - const logFile: string | undefined = fileName; - if (logFile) { - this.writable = fs.createWriteStream(logFile, { - flags: 'a+', - encoding: 'utf8', - autoClose: true - }); - this.bufferSerialEmitter = new BufferSerialEmitter(this.writable); - } + public trace(...args: any[]): void { + this.log(TRACE, args); + } - const logLevelName: string = getExperimentStartupInfo() - .getLogLevel(); - const logLevel: number | undefined = logLevelNameMap.get(logLevelName); - if (logLevel !== undefined) { - this.level = logLevel; - } + public debug(...args: any[]): void { + this.log(DEBUG, args); + } - this.readonly = isReadonly(); + public info(...args: any[]): void { + this.log(INFO, args); } - public close(): void { - if (this.writable) { - this.writable.destroy(); - } + public warning(...args: any[]): void { + this.log(WARNING, args); } - public trace(...param: any[]): void { - if (this.level >= TRACE) { - this.log('TRACE', param); - } + public error(...args: any[]): void { + this.log(ERROR, args); } - public debug(...param: any[]): void { - if (this.level >= DEBUG) { - this.log('DEBUG', param); - } + public critical(...args: any[]): void { + this.log(CRITICAL, args); } - public info(...param: any[]): void { - if (this.level >= INFO) { - this.log('INFO', param); - } + public fatal(...args: any[]): void { + this.log(FATAL, args); } - public warning(...param: any[]): void { - if (this.level >= WARNING) { - this.log('WARNING', param); + private log(level: number, args: any[]): void { + if (level < logLevel || logFile === null) { + return; } - } - public error(...param: any[]): void { - if (this.level >= ERROR) { - this.log('ERROR', param); + // `time.toLocaleString('sv')` trick does not work for Windows + const isoTime = new Date(new Date().toLocaleString() + ' UTC').toISOString(); + const time = isoTime.slice(0, 10) + ' ' + isoTime.slice(11, 19); + + const levelName = levelNames.has(level) ? levelNames.get(level) : level.toString(); + + const words = []; + for (const arg of args) { + if (arg === undefined) { + words.push('undefined'); + } else if (arg === null) { + words.push('null'); + } else if (typeof arg === 'object') { + const json = JSON.stringify(arg); + words.push(json === undefined ? 
arg : json); + } else { + words.push(arg); + } } + const message = words.join(' '); + + const record = `[${time}] ${levelName} (${this.name}) ${message}\n`; + logFile.write(record); } +} - public fatal(...param: any[]): void { - this.log('FATAL', param); +export function getLogger(name: string = 'root'): Logger { + let logger = loggers.get(name); + if (logger === undefined) { + logger = new Logger(name); + loggers.set(name, logger); } - - /** - * if the experiment is not in readonly mode, write log content to stream - * @param level log level - * @param param the params to be written - */ - private log(level: string, param: any[]): void { - if (!this.readonly) { - const time = new Date(); - const localTime = new Date(time.getTime() - time.getTimezoneOffset() * 60000); - const timeStr = localTime.toISOString().slice(0, -5).replace('T', ' '); - const logContent = `[${timeStr}] ${level} ${format(param)}\n`; - if (this.writable && this.bufferSerialEmitter) { - const buffer: WritableStreamBuffer = new WritableStreamBuffer(); - buffer.write(logContent); - buffer.end(); - this.bufferSerialEmitter.feed(buffer.getContents()); - } else { - console.log(logContent); - } + return logger; +} + +/* management functions */ + +export function setLogLevel(levelName: string): void { + if (levelName) { + const level = module.exports[levelName.toUpperCase()]; + if (typeof level === 'number') { + logLevel = level; + } else { + console.log('[ERROR] Bad log level:', levelName); + getLogger('logging').error('Bad log level:', levelName); } } } -function getLogger(): Logger { - return component.get(Logger); +export function startLogging(logPath: string): void { + logFile = fs.createWriteStream(logPath, { + flags: 'a+', + encoding: 'utf8', + autoClose: true + }); } -export { Logger, getLogger, logLevelNameMap }; +export function stopLogging(): void { + if (logFile !== null) { + logFile.end(); + logFile = null; + } +} diff --git a/ts/nni_manager/common/utils.ts b/ts/nni_manager/common/utils.ts index 965582b73b..8e0a5b4667 100644 --- a/ts/nni_manager/common/utils.ts +++ b/ts/nni_manager/common/utils.ts @@ -23,7 +23,6 @@ import { ExperimentStartupInfo, getExperimentStartupInfo, setExperimentStartupIn import { ExperimentConfig, Manager } from './manager'; import { ExperimentManager } from './experimentManager'; import { HyperParameters, TrainingService, TrialJobStatus } from './trainingService'; -import { logLevelNameMap } from './log'; function getExperimentRootDir(): string { return getExperimentStartupInfo() @@ -195,9 +194,6 @@ function prepareUnitTest(): void { Container.snapshot(ExperimentManager); const logLevel: string = parseArg(['--log_level', '-ll']); - if (logLevel.length > 0 && !logLevelNameMap.has(logLevel)) { - console.log(`FATAL: invalid log_level: ${logLevel}`); - } setExperimentStartupInfo(true, 'unittest', 8080, 'unittest', undefined, logLevel); mkDirPSync(getLogDir()); diff --git a/ts/nni_manager/core/nniDataStore.ts b/ts/nni_manager/core/nniDataStore.ts index 095167cc84..1d0ac3adef 100644 --- a/ts/nni_manager/core/nniDataStore.ts +++ b/ts/nni_manager/core/nniDataStore.ts @@ -71,7 +71,7 @@ class NNIDataStore implements DataStore { public storeTrialJobEvent( event: TrialJobEvent, trialJobId: string, hyperParameter?: string, jobDetail?: TrialJobDetail): Promise { - this.log.debug(`storeTrialJobEvent: event: ${event}, data: ${hyperParameter}, jobDetail: ${JSON.stringify(jobDetail)}`); + //this.log.debug(`storeTrialJobEvent: event: ${event}, data: ${hyperParameter}, jobDetail: 
${JSON.stringify(jobDetail)}`); // Use the timestamp in jobDetail as TrialJobEvent timestamp for different events let timestamp: number | undefined; diff --git a/ts/nni_manager/core/nniTensorboardManager.ts b/ts/nni_manager/core/nniTensorboardManager.ts index 721cd7e1d0..f3993b9dd7 100644 --- a/ts/nni_manager/core/nniTensorboardManager.ts +++ b/ts/nni_manager/core/nniTensorboardManager.ts @@ -116,7 +116,7 @@ class NNITensorboardManager implements TensorboardManager { private setTensorboardVersion(): void { let command = `python3 -c 'import tensorboard ; print(tensorboard.__version__)'`; if (process.platform === 'win32') { - command = `python -c 'import tensorboard ; print(tensorboard.__version__)'`; + command = `python -c "import tensorboard ; print(tensorboard.__version__)"`; } try { const tensorboardVersion = cp.execSync(command).toString(); diff --git a/ts/nni_manager/core/nnimanager.ts b/ts/nni_manager/core/nnimanager.ts index 7f89bce92c..e535052c91 100644 --- a/ts/nni_manager/core/nnimanager.ts +++ b/ts/nni_manager/core/nnimanager.ts @@ -10,7 +10,7 @@ import * as component from '../common/component'; import { DataStore, MetricDataRecord, MetricType, TrialJobInfo } from '../common/datastore'; import { NNIError } from '../common/errors'; import { getExperimentId, getDispatcherPipe } from '../common/experimentStartupInfo'; -import { getLogger, Logger } from '../common/log'; +import { Logger, getLogger, stopLogging } from '../common/log'; import { ExperimentProfile, Manager, ExperimentStatus, NNIManagerStatus, ProfileUpdateType, TrialJobStatistics @@ -362,7 +362,7 @@ class NNIManager implements Manager { hasError = true; this.log.error(`${err.stack}`); } finally { - this.log.close(); + stopLogging(); process.exit(hasError ? 1 : 0); } } diff --git a/ts/nni_manager/main.ts b/ts/nni_manager/main.ts index d0f90dd470..3b6c9bae90 100644 --- a/ts/nni_manager/main.ts +++ b/ts/nni_manager/main.ts @@ -10,7 +10,7 @@ import * as path from 'path'; import * as component from './common/component'; import { Database, DataStore } from './common/datastore'; import { setExperimentStartupInfo } from './common/experimentStartupInfo'; -import { getLogger, Logger, logLevelNameMap } from './common/log'; +import { getLogger, setLogLevel, startLogging } from './common/log'; import { Manager, ExperimentStartUpMode } from './common/manager'; import { ExperimentManager } from './common/experimentManager'; import { TensorboardManager } from './common/tensorboardManager'; @@ -47,14 +47,15 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN .to(NNITensorboardManager) .scope(Scope.Singleton); const DEFAULT_LOGFILE: string = path.join(getLogDir(), 'nnimanager.log'); - if (foreground) { - logFileName = undefined; - } else if (logFileName === undefined) { - logFileName = DEFAULT_LOGFILE; + if (!foreground) { + if (logFileName === undefined) { + startLogging(DEFAULT_LOGFILE); + } else { + startLogging(logFileName); + } + // eslint-disable-next-line @typescript-eslint/no-use-before-define + setLogLevel(logLevel); } - Container.bind(Logger).provider({ - get: (): Logger => new Logger(logFileName) - }); const ds: DataStore = component.get(DataStore); await ds.init(); @@ -110,9 +111,6 @@ if (logDir.length > 0) { } const logLevel: string = parseArg(['--log_level', '-ll']); -if (logLevel.length > 0 && !logLevelNameMap.has(logLevel)) { - console.log(`FATAL: invalid log_level: ${logLevel}`); -} const readonlyArg: string = parseArg(['--readonly', '-r']); if (!('true' || 
'false').includes(readonlyArg.toLowerCase())) { @@ -132,11 +130,9 @@ mkDirP(getLogDir()) await initContainer(foreground, mode); const restServer: NNIRestServer = component.get(NNIRestServer); await restServer.start(); - const log: Logger = getLogger(); - log.info(`Rest server listening on: ${restServer.endPoint}`); + getLogger('main').info(`Rest server listening on: ${restServer.endPoint}`); } catch (err) { - const log: Logger = getLogger(); - log.error(`${err.stack}`); + getLogger('main').error(`${err.stack}`); throw err; } }) diff --git a/ts/nni_manager/package.json b/ts/nni_manager/package.json index ae32a636fb..409876e487 100644 --- a/ts/nni_manager/package.json +++ b/ts/nni_manager/package.json @@ -76,7 +76,7 @@ "node-forge": ">=0.10.0", "dot-prop": "^4.2.1", "npm": ">=6.14.8", - "yargs": ">=16.0.3", + "yargs": "~16.0.3", "yargs-parser": ">=20.2.0", "y18n": ">=5.0.5", "acorn": ">=8.0.4", diff --git a/ts/nni_manager/yarn.lock b/ts/nni_manager/yarn.lock index cec6098baf..4681e5c278 100644 --- a/ts/nni_manager/yarn.lock +++ b/ts/nni_manager/yarn.lock @@ -1102,10 +1102,10 @@ cli-width@^2.0.0: version "2.2.0" resolved "https://registry.yarnpkg.com/cli-width/-/cli-width-2.2.0.tgz#ff19ede8a9a5e579324147b0c11f0fbcbabed639" -cliui@^7.0.2: - version "7.0.3" - resolved "https://registry.yarnpkg.com/cliui/-/cliui-7.0.3.tgz#ef180f26c8d9bff3927ee52428bfec2090427981" - integrity sha512-Gj3QHTkVMPKqwP3f7B4KPkBZRMR9r4rfi5bXFpg1a+Svvj8l7q5CnkBkVQzfxT5DFSsGk2+PascOgL0JYkL2kw== +cliui@^7.0.0: + version "7.0.4" + resolved "https://registry.yarnpkg.com/cliui/-/cliui-7.0.4.tgz#a0265ee655476fc807aea9df3df8df7783808b4f" + integrity sha512-OcRE68cOsVMXp1Yvonl/fzkQOyjLSu/8bhPDfQt0e0/Eb283TKP20Fs2MqoPsr9SwA595rRCA+QMzYc9nBP+JQ== dependencies: string-width "^4.2.0" strip-ansi "^6.0.0" @@ -1612,7 +1612,7 @@ es6-promisify@^5.0.0: dependencies: es6-promise "^4.0.3" -escalade@^3.1.1: +escalade@^3.0.2: version "3.1.1" resolved "https://registry.yarnpkg.com/escalade/-/escalade-3.1.1.tgz#d8cfdc7000965c5a0174b4a82eaa5c0552742e40" integrity sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw== @@ -2279,14 +2279,10 @@ hoek@2.x.x, hoek@^4.2.1: version "4.2.1" resolved "https://registry.yarnpkg.com/hoek/-/hoek-4.2.1.tgz#9634502aa12c445dd5a7c5734b572bb8738aacbb" -hosted-git-info@^2.1.4: - version "2.7.1" - resolved "https://registry.yarnpkg.com/hosted-git-info/-/hosted-git-info-2.7.1.tgz#97f236977bd6e125408930ff6de3eec6281ec047" - -hosted-git-info@^2.7.1, hosted-git-info@^2.8.8: - version "2.8.8" - resolved "https://registry.yarnpkg.com/hosted-git-info/-/hosted-git-info-2.8.8.tgz#7539bd4bc1e0e0a895815a2e0262420b12858488" - integrity sha512-f/wzC2QaWBs7t9IYqB4T3sR1xviIViXJRJTWBlx2Gf3g0Xi5vI7Yy4koXQ1c9OYDGHN9sBy1DQ2AB8fqZBWhUg== +hosted-git-info@^2.1.4, hosted-git-info@^2.7.1, hosted-git-info@^2.8.8: + version "2.8.9" + resolved "https://registry.yarnpkg.com/hosted-git-info/-/hosted-git-info-2.8.9.tgz#dffc0bf9a21c02209090f2aa69429e1414daf3f9" + integrity sha512-mxIDAb9Lsm6DoOJ7xH+5+X4y1LU/4Hi50L9C5sIswK3JzULS4bwk1FvjdBgvYR4bzT4tuUQiC15FE2f5HbLvYw== html-escaper@^2.0.0: version "2.0.2" @@ -3157,9 +3153,9 @@ lodash.without@~4.4.0: resolved "https://registry.yarnpkg.com/lodash.without/-/lodash.without-4.4.0.tgz#3cd4574a00b67bae373a94b748772640507b7aac" lodash@>=4.17.13, lodash@^4.17.11, lodash@^4.17.13, lodash@^4.17.14, lodash@^4.17.15: - version "4.17.20" - resolved "https://registry.yarnpkg.com/lodash/-/lodash-4.17.20.tgz#b44a9b6297bcb698f1c51a3545a2b3b368d59c52" - integrity 
sha512-PlhdFcillOINfeV7Ni6oF1TAEayyZBoZ8bcshTHqOYJYlrqzRK5hagpagky5o4HfCzzd1TRkXPMFq6cKk9rGmA== + version "4.17.21" + resolved "https://registry.yarnpkg.com/lodash/-/lodash-4.17.21.tgz#679591c564c3bffaae8454cf0b3df370c3d6911c" + integrity sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg== log-symbols@4.0.0: version "4.0.0" @@ -5050,9 +5046,9 @@ sshpk@^1.7.0: tweetnacl "~0.14.0" ssri@^6.0.0, ssri@^6.0.1: - version "6.0.1" - resolved "https://registry.yarnpkg.com/ssri/-/ssri-6.0.1.tgz#2a3c41b28dd45b62b63676ecb74001265ae9edd8" - integrity sha512-3Wge10hNcT1Kur4PDFwEieXSCMCJs/7WvSACcrMYrNp+b8kDL1/0wJch5Ni2WrtwEa2IO8OsVfeKIciKCDx/QA== + version "6.0.2" + resolved "https://registry.yarnpkg.com/ssri/-/ssri-6.0.2.tgz#157939134f20464e7301ddba3e90ffa8f7728ac5" + integrity sha512-cepbSq/neFK7xB6A50KHN0xHDotYzq58wWCa5LeWqnPrHG8GzfEjO/4O8kpmcGW+oaxkvhEJCWgbgNk4/ZV93Q== dependencies: figgy-pudding "^3.5.1" @@ -5729,7 +5725,7 @@ xtend@~4.0.1: resolved "https://registry.yarnpkg.com/xtend/-/xtend-4.0.2.tgz#bb72779f5fa465186b1f438f674fa347fdb5db54" integrity sha512-LKYU1iAXJXUgAXn9URjiu+MWhyUXHsvfp7mcuYm9dSUKK0/CjtrUwFAxD82/mCWbtLsGjFIad0wIsod4zrTAEQ== -y18n@>=5.0.5, y18n@^4.0.0, y18n@^5.0.2: +y18n@>=5.0.5, y18n@^4.0.0, y18n@^5.0.1: version "5.0.5" resolved "https://registry.yarnpkg.com/y18n/-/y18n-5.0.5.tgz#8769ec08d03b1ea2df2500acef561743bbb9ab18" integrity sha512-hsRUr4FFrvhhRH12wOdfs38Gy7k2FFzB9qgN9v3aLykRq0dRcdcpz5C9FxdS2NuhOrI/628b/KSTJ3rwHysYSg== @@ -5751,7 +5747,7 @@ yallist@^4.0.0: version "4.0.0" resolved "https://registry.yarnpkg.com/yallist/-/yallist-4.0.0.tgz#9bb92790d9c0effec63be73519e11a35019a3a72" -yargs-parser@13.1.2, yargs-parser@>=20.2.0, yargs-parser@^20.2.2: +yargs-parser@13.1.2, yargs-parser@>=20.2.0, yargs-parser@^20.0.0: version "20.2.3" resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-20.2.3.tgz#92419ba867b858c868acf8bae9bf74af0dd0ce26" integrity sha512-emOFRT9WVHw03QSvN5qor9QQT9+sw5vwxfYweivSMHTcAXPefwVae2FjO7JJjj8hCE4CzPOPeFM83VwT29HCww== @@ -5767,18 +5763,18 @@ yargs-unparser@1.6.1: is-plain-obj "^1.1.0" yargs "^14.2.3" -yargs@13.3.2, yargs@>=16.0.3, yargs@^11.0.0, yargs@^14.2.3, yargs@^15.0.2, yargs@^8.0.2: - version "16.1.0" - resolved "https://registry.yarnpkg.com/yargs/-/yargs-16.1.0.tgz#fc333fe4791660eace5a894b39d42f851cd48f2a" - integrity sha512-upWFJOmDdHN0syLuESuvXDmrRcWd1QafJolHskzaw79uZa7/x53gxQKiR07W59GWY1tFhhU/Th9DrtSfpS782g== +yargs@13.3.2, yargs@^11.0.0, yargs@^14.2.3, yargs@^15.0.2, yargs@^8.0.2, yargs@~16.0.3: + version "16.0.3" + resolved "https://registry.yarnpkg.com/yargs/-/yargs-16.0.3.tgz#7a919b9e43c90f80d4a142a89795e85399a7e54c" + integrity sha512-6+nLw8xa9uK1BOEOykaiYAJVh6/CjxWXK/q9b5FpRgNslt8s22F2xMBqVIKgCRjNgGvGPBy8Vog7WN7yh4amtA== dependencies: - cliui "^7.0.2" - escalade "^3.1.1" + cliui "^7.0.0" + escalade "^3.0.2" get-caller-file "^2.0.5" require-directory "^2.1.1" string-width "^4.2.0" - y18n "^5.0.2" - yargs-parser "^20.2.2" + y18n "^5.0.1" + yargs-parser "^20.0.0" yn@^2.0.0: version "2.0.0" diff --git a/ts/webui/package.json b/ts/webui/package.json index 9356401be3..541948418e 100644 --- a/ts/webui/package.json +++ b/ts/webui/package.json @@ -114,7 +114,7 @@ }, "resolutions": { "npm": ">=6.14.4", - "yargs": ">=16.0.3", + "yargs": "~16.0.3", "acorn": ">=8.0.4", "node-forge": ">=0.10.0", "y18n": ">=5.0.5", diff --git a/ts/webui/src/components/modals/Compare.tsx b/ts/webui/src/components/modals/Compare.tsx index 2736b49898..4c49a85de8 100644 --- a/ts/webui/src/components/modals/Compare.tsx 
+++ b/ts/webui/src/components/modals/Compare.tsx
@@ -51,6 +51,7 @@ interface CompareProps {
     title: string;
     showDetails: boolean;
     onHideDialog: () => void;
+    changeSelectTrialIds?: () => void;
 }
 
 class Compare extends React.Component {
@@ -196,8 +197,17 @@ class Compare extends React.Component {
         );
     }
 
+    private closeCompareModal = (): void => {
+        const { showDetails, changeSelectTrialIds, onHideDialog } = this.props;
+        if (showDetails === true) {
+            // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+            changeSelectTrialIds!();
+        }
+        onHideDialog();
+    };
+
     render(): React.ReactNode {
-        const { onHideDialog, trials, title, showDetails } = this.props;
+        const { trials, title, showDetails } = this.props;
         const flatten = (m: Map): Map => {
             return new Map(Array.from(m).map(([key, value]) => [key.baseName, value]));
         };
@@ -218,7 +228,7 @@ class Compare extends React.Component {
                 className='compare-modal'
                 allowTouchBodyScroll={true}
                 dragOptions={dragOptions}
-                onDismiss={onHideDialog}
+                onDismiss={this.closeCompareModal}
             >
@@ -227,7 +237,7 @@ class Compare extends React.Component {
                     styles={iconButtonStyles}
                     iconProps={{ iconName: 'Cancel' }}
                     ariaLabel='Close popup modal'
-                    onClick={onHideDialog}
+                    onClick={this.closeCompareModal}
                 />
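The Compare.tsx hunks above thread an optional callback from the trials table into the modal so that dismissing the dialog also clears the checkbox selection. A minimal standalone sketch of the same pattern (hypothetical component and prop names, not the NNI source); it uses an explicit undefined check where the patch relies on a non-null assertion:

// compare-modal-sketch.tsx -- illustrative only.
import * as React from 'react';

interface CompareModalProps {
    showDetails: boolean;
    onHideDialog: () => void;
    // Supplied only when the modal is opened from a table that owns a selection.
    changeSelectTrialIds?: () => void;
}

class CompareModalSketch extends React.Component<CompareModalProps> {
    // Close the modal; when it was opened from the table, also clear the
    // checkbox selection so stale row ids do not linger in parent state.
    private closeCompareModal = (): void => {
        const { showDetails, changeSelectTrialIds, onHideDialog } = this.props;
        if (showDetails && changeSelectTrialIds !== undefined) {
            changeSelectTrialIds();
        }
        onHideDialog();
    };

    render(): React.ReactNode {
        return <button onClick={this.closeCompareModal}>Close</button>;
    }
}

export default CompareModalSketch;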
diff --git a/ts/webui/src/components/modals/tensorboard/TensorboardUI.tsx b/ts/webui/src/components/modals/tensorboard/TensorboardUI.tsx index d0b486847b..df935a45f6 100644 --- a/ts/webui/src/components/modals/tensorboard/TensorboardUI.tsx +++ b/ts/webui/src/components/modals/tensorboard/TensorboardUI.tsx @@ -9,7 +9,7 @@ import TensorboardDialog from './TensorboardDialog'; function TensorboardUI(props): any { let refreshTensorboard = 0; - const { selectedRowIds } = props; + const { selectedRowIds, changeSelectTrialIds } = props; const [queryTensorboardList, setQueryTensorboardList] = useState([]); const [isReaptedStartTensorboard, setReaptedTensorboard] = useState(false); const [tensorboardPanelVisible, setTensorboardPanelVisible] = useState(false); @@ -130,6 +130,7 @@ function TensorboardUI(props): any { item={selectedTensorboard} onHideDialog={(): void => { setTensorboardPanelVisible(false); + changeSelectTrialIds(); }} /> )} @@ -138,7 +139,8 @@ function TensorboardUI(props): any { } TensorboardUI.propTypes = { - selectedRowIds: PropTypes.array + selectedRowIds: PropTypes.array, + changeSelectTrialIds: PropTypes.func }; export default TensorboardUI; diff --git a/ts/webui/src/components/stateless-component/NNItabs.tsx b/ts/webui/src/components/stateless-component/NNItabs.tsx index 36156f35e6..a3051dae19 100644 --- a/ts/webui/src/components/stateless-component/NNItabs.tsx +++ b/ts/webui/src/components/stateless-component/NNItabs.tsx @@ -2,19 +2,19 @@ import * as React from 'react'; import { NavLink } from 'react-router-dom'; const OVERVIEWTABS = ( - + Overview ); const DETAILTABS = ( - + Trials detail ); const NNILOGO = ( - + NNI logo ); diff --git a/ts/webui/src/components/trial-detail/TableList.tsx b/ts/webui/src/components/trial-detail/TableList.tsx index 5a7599e29d..a4405e63c4 100644 --- a/ts/webui/src/components/trial-detail/TableList.tsx +++ b/ts/webui/src/components/trial-detail/TableList.tsx @@ -6,12 +6,11 @@ import { Icon, IDropdownOption, PrimaryButton, - Selection, - SelectionMode, Stack, StackItem, TooltipHost, - DirectionalHint + DirectionalHint, + Checkbox } from '@fluentui/react'; import { EXPERIMENT, TRIALS } from '../../static/datamodel'; import { TOOLTIP_BACKGROUND_COLOR } from '../../static/const'; @@ -95,7 +94,6 @@ interface TableListState { } class TableList extends React.Component { - private _selection: Selection; private _expandedTrialIds: Set; constructor(props: TableListProps) { @@ -119,14 +117,6 @@ class TableList extends React.Component { sortInfo: { field: '', isDescend: true } }; - this._selection = new Selection({ - onSelectionChanged: (): void => { - this.setState({ - selectedRowIds: this._selection.getSelection().map(s => (s as any).id) - }); - } - }); - this._expandedTrialIds = new Set(); } @@ -185,10 +175,12 @@ class TableList extends React.Component { // TODO: use search space and metrics space from TRIALS will cause update issues. const searchSpace = TRIALS.inferredSearchSpace(EXPERIMENT.searchSpaceNew); const metricSpace = TRIALS.inferredMetricSpace(); + const { selectedRowIds } = this.state; const items = trials.map(trial => { const ret = { sequenceId: trial.sequenceId, id: trial.id, + checked: selectedRowIds.includes(trial.id) ? true : false, startTime: (trial as Trial).info.startTime, // FIXME: why do we need info here? 
            endTime: (trial as Trial).info.endTime,
            duration: trial.duration,
@@ -216,9 +208,58 @@ class TableList extends React.Component {
         }
     }
 
+    private selectedTrialOnChangeEvent = (
+        id: string,
+        _ev?: React.FormEvent,
+        checked?: boolean
+    ): void => {
+        const { displayedItems, selectedRowIds } = this.state;
+        const items = JSON.parse(JSON.stringify(displayedItems));
+        // Remove the id first, then re-add it only when the box is checked,
+        // so unchecking a row actually deselects it.
+        const temp = selectedRowIds.filter(rowId => rowId !== id);
+        if (checked === true) {
+            temp.push(id);
+        }
+        items.forEach(item => {
+            if (item.id === id) {
+                item.checked = !!checked;
+            }
+        });
+        this.setState(() => ({ displayedItems: items, selectedRowIds: temp }));
+    };
+
+    private changeSelectTrialIds = (): void => {
+        const { displayedItems } = this.state;
+        // Copy the items with `checked` cleared instead of mutating state in
+        // place, so the DetailsList sees a new array and re-renders.
+        const newDisplayedItems = displayedItems.map(item => ({ ...item, checked: false }));
+        this.setState(() => ({
+            selectedRowIds: [],
+            displayedItems: newDisplayedItems
+        }));
+    };
+
     private _buildColumnsFromTableItems(tableItems: any[]): IColumn[] {
-        // extra column, for a icon to expand the trial details panel
         const columns: IColumn[] = [
+            // select trial function
+            {
+                name: '',
+                key: '_selected',
+                fieldName: 'selected',
+                minWidth: 20,
+                maxWidth: 20,
+                isResizable: true,
+                className: 'detail-table',
+                onRender: (record): React.ReactNode => (
+                    <Checkbox
+                        checked={record.checked}
+                        className='detail-check'
+                        onChange={this.selectedTrialOnChangeEvent.bind(this, record.id)}
+                    />
+                )
+            },
+            // extra column, for an icon to expand the trial details panel
             {
                 key: '_expand',
                 name: '',
@@ -265,6 +306,7 @@ class TableList extends React.Component {
                 maxWidth: 20
             }
         ];
+        // looking at the first row only for now
         for (const k of Object.keys(tableItems[0])) {
             if (k === 'metric/default') {
@@ -493,7 +535,10 @@ class TableList extends React.Component {
                        }}
                        disabled={selectedRowIds.length === 0}
                    />
-
+
@@ -531,12 +576,12 @@ class TableList extends React.Component {
-                        displayedColumns.includes(column.key) || ['_expand', '_operation'].includes(column.key)
+                        displayedColumns.includes(column.key) ||
+                        ['_expand', '_operation', '_selected'].includes(column.key)
                     )}
                     items={displayedItems}
                     compact={true}
-                    selection={this._selection}
-                    selectionMode={SelectionMode.multiple}
+                    selectionMode={0}
                     selectionPreservedOnEmptyClick={true}
                     onRenderRow={(props): any => {
                         // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
@@ -552,6 +597,7 @@ class TableList extends React.Component {
                         onHideDialog={(): void => {
                             this.setState({ compareDialogVisible: false });
                         }}
+                        changeSelectTrialIds={this.changeSelectTrialIds}
                     />
                 )}
                 {intermediateDialogTrial !== undefined && (
diff --git a/ts/webui/src/index.tsx b/ts/webui/src/index.tsx
index 1cd2170de6..45b1861262 100644
--- a/ts/webui/src/index.tsx
+++ b/ts/webui/src/index.tsx
@@ -1,6 +1,7 @@
 import React, { lazy, Suspense } from 'react';
 import ReactDOM from 'react-dom';
 import App from './App';
+import { getPrefix } from './static/function';
 import { BrowserRouter as Router, Route, Switch } from 'react-router-dom';
 const Overview = lazy(() => import('./components/Overview'));
 const TrialsDetail = lazy(() => import('./components/TrialsDetail'));
@@ -9,8 +10,10 @@ import './index.css';
 import './static/style/loading.scss';
 import * as serviceWorker from './serviceWorker';
 
+const path = getPrefix();
+
 ReactDOM.render(
-    <Router>
+    <Router basename={path}>
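The TableList changes above replace Fluent's `Selection` object with a hand-rolled checkbox column whose state lives in `selectedRowIds`. A self-contained sketch of the toggle logic (hypothetical names, not the NNI source); the key points are that unchecking must remove the id, and that fresh arrays and objects are returned so React change detection fires:

// selection-sketch.ts -- illustrative only.
interface RowItem {
    id: string;
    checked: boolean;
}

function toggleSelection(
    selectedRowIds: string[],
    items: RowItem[],
    id: string,
    checked: boolean
): { selectedRowIds: string[]; items: RowItem[] } {
    // Drop the id first, then re-add it only when the box is now checked,
    // so repeated change events cannot insert duplicates.
    const nextIds = selectedRowIds.filter(rowId => rowId !== id);
    if (checked) {
        nextIds.push(id);
    }
    // Return fresh objects rather than mutating the inputs.
    const nextItems = items.map(item => (item.id === id ? { ...item, checked } : item));
    return { selectedRowIds: nextIds, items: nextItems };
}

// Usage: unchecking 'a' removes it from the selection.
const next = toggleSelection(['a'], [{ id: 'a', checked: true }, { id: 'b', checked: false }], 'a', false);
console.log(next.selectedRowIds); // []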
diff --git a/ts/webui/src/static/const.ts b/ts/webui/src/static/const.ts
index 8ae46ffe88..d527adae48 100644
--- a/ts/webui/src/static/const.ts
+++ b/ts/webui/src/static/const.ts
@@ -1,10 +1,16 @@
+import { getPrefix } from './function';
+
 // when there are more trials than this threshold, metrics will be updated in group of this size to avoid freezing
 const METRIC_GROUP_UPDATE_THRESHOLD = 100;
 const METRIC_GROUP_UPDATE_SIZE = 20;
 
-const MANAGER_IP = `/api/v1/nni`;
+const prefix = getPrefix();
+
+const MANAGER_IP = prefix === undefined ? '/api/v1/nni' : `${prefix}/api/v1/nni`;
 const DOWNLOAD_IP = `/logs`;
+
 const WEBUIDOC = 'https://nni.readthedocs.io/en/latest/Tutorial/WebUI.html';
+
 const trialJobStatus = [
     'UNKNOWN',
     'WAITING',
diff --git a/ts/webui/src/static/function.ts b/ts/webui/src/static/function.ts
index bf10d8ef6e..cb1c9275df 100644
--- a/ts/webui/src/static/function.ts
+++ b/ts/webui/src/static/function.ts
@@ -4,6 +4,17 @@ import { IContextualMenuProps } from '@fluentui/react';
 import { MANAGER_IP } from './const';
 import { MetricDataRecord, FinalType, TableObj, Tensorboard } from './interface';
 
+function getPrefix(): string | undefined {
+    const pathName = window.location.pathname;
+    let newPathName = pathName;
+
+    // Strip only the trailing route segment, whichever one matched.
+    for (const suffix of ['/oview', '/detail', '/experiment']) {
+        if (pathName.endsWith(suffix)) {
+            newPathName = pathName.slice(0, -suffix.length);
+            break;
+        }
+    }
+
+    return newPathName === '' ? undefined : newPathName;
+}
+
 async function requestAxios(url: string): Promise {
     const response = await axios.get(url);
     if (response.status === 200) {
@@ -346,6 +357,7 @@ function getTensorboardMenu(queryTensorboardList: Tensorboard[], stopFunc, seeDe
     return tensorboardMenu;
 }
 export {
+    getPrefix,
     convertTime,
     convertDuration,
     convertTimeAsUnit,
diff --git a/ts/webui/src/static/style/table.scss b/ts/webui/src/static/style/table.scss
index c89b8e0ad8..d5fb29527e 100644
--- a/ts/webui/src/static/style/table.scss
+++ b/ts/webui/src/static/style/table.scss
@@ -59,3 +59,22 @@
   max-height: 335px;
   overflow-y: auto;
 }
+
+$checkboxwidth: 17px;
+
+.detail-check {
+  .ms-Checkbox-checkbox {
+    width: $checkboxwidth;
+    height: $checkboxwidth;
+    border-radius: 50%;
+    border: none;
+
+    &:hover {
+      border: 1px solid grey;
+    }
+
+    i {
+      width: 12px;
+    }
+  }
+}
diff --git a/ts/webui/yarn.lock b/ts/webui/yarn.lock
index 62a766ec08..22daa008aa 100644
--- a/ts/webui/yarn.lock
+++ b/ts/webui/yarn.lock
@@ -3577,7 +3577,7 @@ cli-width@^2.0.0:
   resolved "https://registry.yarnpkg.com/cli-width/-/cli-width-2.2.1.tgz#b0433d0b4e9c847ef18868a4ef16fd5fc8271c48"
   integrity sha512-GRMWDxpOB6Dgk2E5Uo+3eEBvtOOlimMmpbFiKuLFnQzYDavtLFY3K5ona41jgN/WdRZtG7utuVSVTL4HbZHGkw==

-cliui@^7.0.2:
+cliui@^7.0.0:
   version "7.0.4"
   resolved "https://registry.yarnpkg.com/cliui/-/cliui-7.0.4.tgz#a0265ee655476fc807aea9df3df8df7783808b4f"
   integrity sha512-OcRE68cOsVMXp1Yvonl/fzkQOyjLSu/8bhPDfQt0e0/Eb283TKP20Fs2MqoPsr9SwA595rRCA+QMzYc9nBP+JQ==
@@ -5103,7 +5103,7 @@ es6-promisify@^5.0.0:
   dependencies:
     es6-promise "^4.0.3"

-escalade@^3.1.1:
+escalade@^3.0.2, escalade@^3.1.1:
   version "3.1.1"
   resolved "https://registry.yarnpkg.com/escalade/-/escalade-3.1.1.tgz#d8cfdc7000965c5a0174b4a82eaa5c0552742e40"
   integrity sha512-k0er2gUkLf8O0zKJiAhmkTnJlTvINGv7ygDNPbeIsX/TJjGJZHuh9B2UxbsaEkmlEo9MfhrSzmhIlhRlI2GXnw==
@@ -6376,9 +6376,9 @@ hoist-non-react-statics@^3.1.0:
     react-is "^16.7.0"

 hosted-git-info@^2.1.4, hosted-git-info@^2.7.1, hosted-git-info@^2.8.8:
-  version "2.8.8"
-  resolved "https://registry.yarnpkg.com/hosted-git-info/-/hosted-git-info-2.8.8.tgz#7539bd4bc1e0e0a895815a2e0262420b12858488"
-  integrity sha512-f/wzC2QaWBs7t9IYqB4T3sR1xviIViXJRJTWBlx2Gf3g0Xi5vI7Yy4koXQ1c9OYDGHN9sBy1DQ2AB8fqZBWhUg==
+  version "2.8.9"
+  resolved "https://registry.yarnpkg.com/hosted-git-info/-/hosted-git-info-2.8.9.tgz#dffc0bf9a21c02209090f2aa69429e1414daf3f9"
+  integrity 
sha512-mxIDAb9Lsm6DoOJ7xH+5+X4y1LU/4Hi50L9C5sIswK3JzULS4bwk1FvjdBgvYR4bzT4tuUQiC15FE2f5HbLvYw== hosted-git-info@^3.0.6: version "3.0.7" @@ -12345,9 +12345,9 @@ sshpk@^1.7.0: tweetnacl "~0.14.0" ssri@^6.0.0, ssri@^6.0.1: - version "6.0.1" - resolved "https://registry.yarnpkg.com/ssri/-/ssri-6.0.1.tgz#2a3c41b28dd45b62b63676ecb74001265ae9edd8" - integrity sha512-3Wge10hNcT1Kur4PDFwEieXSCMCJs/7WvSACcrMYrNp+b8kDL1/0wJch5Ni2WrtwEa2IO8OsVfeKIciKCDx/QA== + version "6.0.2" + resolved "https://registry.yarnpkg.com/ssri/-/ssri-6.0.2.tgz#157939134f20464e7301ddba3e90ffa8f7728ac5" + integrity sha512-cepbSq/neFK7xB6A50KHN0xHDotYzq58wWCa5LeWqnPrHG8GzfEjO/4O8kpmcGW+oaxkvhEJCWgbgNk4/ZV93Q== dependencies: figgy-pudding "^3.5.1" @@ -12467,7 +12467,16 @@ string-width@^3.0.0: is-fullwidth-code-point "^2.0.0" strip-ansi "^5.1.0" -string-width@^4.1.0, string-width@^4.2.0: +string-width@^4.1.0: + version "4.2.2" + resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.2.tgz#dafd4f9559a7585cfba529c6a0a4f73488ebd4c5" + integrity sha512-XBJbT3N4JhVumXE0eoLU9DCjcaF92KLNqTmFCnG1pf8duUxFGwtP6AD6nkjw9a3IdiRtL3E2w3JDiE/xi3vOeA== + dependencies: + emoji-regex "^8.0.0" + is-fullwidth-code-point "^3.0.0" + strip-ansi "^6.0.0" + +string-width@^4.2.0: version "4.2.0" resolved "https://registry.yarnpkg.com/string-width/-/string-width-4.2.0.tgz#952182c46cc7b2c313d1596e623992bd163b72b5" integrity sha512-zUz5JD+tgqtuDjMhwIg5uFVV3dtqZ9yQJlZVfq4I01/K5Paj5UHj7VyrQOJvzawSVlKpObApbfD0Ed6yJc+1eg== @@ -13316,9 +13325,9 @@ url-parse-lax@^1.0.0: prepend-http "^1.0.1" url-parse@^1.4.3: - version "1.4.7" - resolved "https://registry.yarnpkg.com/url-parse/-/url-parse-1.4.7.tgz#a8a83535e8c00a316e403a5db4ac1b9b853ae278" - integrity sha512-d3uaVyzDB9tQoSXFvuSUNFibTd9zxd2bkVrDRvF5TmvWWQwqE4lgYJ5m+x1DbecWkw+LK4RNl2CU1hHuOKPVlg== + version "1.5.1" + resolved "https://registry.yarnpkg.com/url-parse/-/url-parse-1.5.1.tgz#d5fa9890af8a5e1f274a2c98376510f6425f6e3b" + integrity sha512-HOfCOUJt7iSYzEx/UqgtwKRMC6EU91NFhsCHMv9oM03VJcVo2Qrp8T8kI9D7amFf1cu+/3CEhgb3rF9zL7k85Q== dependencies: querystringify "^2.1.1" requires-port "^1.0.0" @@ -13936,7 +13945,7 @@ xtend@^4.0.0, xtend@~4.0.1: resolved "https://registry.yarnpkg.com/xtend/-/xtend-4.0.2.tgz#bb72779f5fa465186b1f438f674fa347fdb5db54" integrity sha512-LKYU1iAXJXUgAXn9URjiu+MWhyUXHsvfp7mcuYm9dSUKK0/CjtrUwFAxD82/mCWbtLsGjFIad0wIsod4zrTAEQ== -y18n@>=5.0.5, y18n@^4.0.0, y18n@^5.0.5: +y18n@>=5.0.5, y18n@^4.0.0, y18n@^5.0.1: version "5.0.5" resolved "https://registry.yarnpkg.com/y18n/-/y18n-5.0.5.tgz#8769ec08d03b1ea2df2500acef561743bbb9ab18" integrity sha512-hsRUr4FFrvhhRH12wOdfs38Gy7k2FFzB9qgN9v3aLykRq0dRcdcpz5C9FxdS2NuhOrI/628b/KSTJ3rwHysYSg== @@ -13961,23 +13970,28 @@ yaml@^1.10.0: resolved "https://registry.yarnpkg.com/yaml/-/yaml-1.10.0.tgz#3b593add944876077d4d683fee01081bd9fff31e" integrity sha512-yr2icI4glYaNG+KWONODapy2/jDdMSDnrONSjblABjD9B4Z5LgiircSt8m8sRZFNi08kG9Sm0uSHtEmP3zaEGg== -yargs-parser@^20.2.2, yargs-parser@^20.2.3: +yargs-parser@^20.0.0: + version "20.2.7" + resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-20.2.7.tgz#61df85c113edfb5a7a4e36eb8aa60ef423cbc90a" + integrity sha512-FiNkvbeHzB/syOjIUxFDCnhSfzAL8R5vs40MgLFBorXACCOAEaWu0gRZl14vG8MR9AOJIZbmkjhusqBYZ3HTHw== + +yargs-parser@^20.2.3: version "20.2.4" resolved "https://registry.yarnpkg.com/yargs-parser/-/yargs-parser-20.2.4.tgz#b42890f14566796f85ae8e3a25290d205f154a54" integrity sha512-WOkpgNhPTlE73h4VFAFsOnomJVaovO8VqLDzy5saChRBFQFBoMYirowyW+Q9HB4HFF4Z7VZTiG3iSzJJA29yRA== -yargs@12.0.2, 
yargs@>=16.0.3, yargs@^11.0.0, yargs@^13.3.0, yargs@^13.3.2, yargs@^14.2.3, yargs@^8.0.2: - version "16.2.0" - resolved "https://registry.yarnpkg.com/yargs/-/yargs-16.2.0.tgz#1c82bf0f6b6a66eafce7ef30e376f49a12477f66" - integrity sha512-D1mvvtDG0L5ft/jGWkLpG1+m0eQxOfaBvTNELraWj22wSVUMWxZUvYgJYcKh6jGGIkJFhH4IZPQhR4TKpc8mBw== +yargs@12.0.2, yargs@^11.0.0, yargs@^13.3.0, yargs@^13.3.2, yargs@^14.2.3, yargs@^8.0.2, yargs@~16.0.3: + version "16.0.3" + resolved "https://registry.yarnpkg.com/yargs/-/yargs-16.0.3.tgz#7a919b9e43c90f80d4a142a89795e85399a7e54c" + integrity sha512-6+nLw8xa9uK1BOEOykaiYAJVh6/CjxWXK/q9b5FpRgNslt8s22F2xMBqVIKgCRjNgGvGPBy8Vog7WN7yh4amtA== dependencies: - cliui "^7.0.2" - escalade "^3.1.1" + cliui "^7.0.0" + escalade "^3.0.2" get-caller-file "^2.0.5" require-directory "^2.1.1" string-width "^4.2.0" - y18n "^5.0.5" - yargs-parser "^20.2.2" + y18n "^5.0.1" + yargs-parser "^20.0.0" zrender@5.0.4: version "5.0.4"
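The yargs pin change in both package.json files (from `>=16.0.3` to `~16.0.3`) is what rolls the lockfile entries above back to 16.0.3 and its 16.0.x dependency ranges. A quick sketch of the two range semantics, assuming the standard `semver` npm package is available:

// semver-sketch.ts -- illustrative only.
import * as semver from 'semver';

// '~16.0.3' allows patch updates within 16.0.x only.
console.log(semver.satisfies('16.0.9', '~16.0.3'));  // true
console.log(semver.satisfies('16.1.0', '~16.0.3'));  // false

// '>=16.0.3' accepts any newer release, including new majors.
console.log(semver.satisfies('17.1.0', '>=16.0.3')); // true

The tilde range therefore keeps resolution deterministic at 16.0.x, whereas the old `>=` range let fresh installs drift to whatever the latest yargs release happened to be.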