This repository has been archived by the owner on Sep 18, 2024. It is now read-only.
[Retiarii] support hypermodule: autoactivation #3868
Merged
QuanluZhang merged 10 commits into microsoft:master from QuanluZhang:dev-retiarii-module on Jul 15, 2021
Changes from 8 commits
Commits (10):
9ad789e - add the first hypermodule, autoactivation (QuanluZhang)
32c4f51 - support base execution engine (QuanluZhang)
4e42cdf - support various number of core units (QuanluZhang)
ca88604 - add doc docstring (QuanluZhang)
96bde5b - fix pylint (QuanluZhang)
1945988 - Merge branch 'master' of https://github.com/microsoft/nni into dev-re… (QuanluZhang)
f081b8e - add ut for autoactivation (QuanluZhang)
185b8ca - add one more test for base engine (QuanluZhang)
5d8fe77 - remove redundant test (QuanluZhang)
89e55ae - Merge branch 'master' of https://github.com/microsoft/nni into dev-re… (QuanluZhang)
New file added (@@ -0,0 +1,9 @@):
Hypermodules
============

A hypermodule is a (PyTorch) module that contains many architecture/hyperparameter candidates for that module. By using hypermodules in a user-defined model, NNI helps users automatically find the best architecture/hyperparameters of the hypermodules for this model. This follows the design philosophy of Retiarii: users write a DNN model as a space.

Several hypermodules have been proposed in the NAS community, such as AutoActivation and AutoDropout. Some of them are implemented in the Retiarii framework.

.. autoclass:: nni.retiarii.nn.pytorch.AutoActivation
   :members:
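As an illustration only (not part of this diff), here is a minimal sketch of dropping a hypermodule such as AutoActivation into a user-defined model in place of a fixed activation; the import path follows the autoclass directive above, and the Net class and layer sizes are made up for the example:

import torch.nn as nn
from nni.retiarii.nn.pytorch import AutoActivation

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(32, 16)
        # searchable activation instead of a fixed nn.ReLU()
        self.act = AutoActivation()

    def forward(self, x):
        return self.act(self.fc(x))

When such a model is used as a Retiarii model space, the framework explores the activation candidates inside AutoActivation together with any other mutations in the model.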
Modified file (@@ -1,3 +1,4 @@):
from .api import *
from .component import *
from .nn import *
from .hypermodule import *
New file added (@@ -0,0 +1,249 @@):
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn

from nni.retiarii.serializer import basic_unit

from .api import LayerChoice
from ...utils import version_larger_equal

__all__ = ['AutoActivation']

TorchVersion = '1.5.0'

# ============== unary function modules ==============

@basic_unit
class UnaryIdentity(nn.Module):
    def forward(self, x):
        return x

@basic_unit
class UnaryNegative(nn.Module):
    def forward(self, x):
        return -x

@basic_unit
class UnaryAbs(nn.Module):
    def forward(self, x):
        return torch.abs(x)

@basic_unit
class UnarySquare(nn.Module):
    def forward(self, x):
        return torch.square(x)

@basic_unit
class UnaryPow(nn.Module):
    def forward(self, x):
        return torch.pow(x, 3)

@basic_unit
class UnarySqrt(nn.Module):
    def forward(self, x):
        return torch.sqrt(x)

@basic_unit
class UnaryMul(nn.Module):
    def __init__(self):
        super().__init__()
        # element-wise for now, will change to per-channel trainable parameter
        self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32))  # pylint: disable=not-callable
    def forward(self, x):
        return x * self.beta

@basic_unit
class UnaryAdd(nn.Module):
    def __init__(self):
        super().__init__()
        # element-wise for now, will change to per-channel trainable parameter
        self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32))  # pylint: disable=not-callable
    def forward(self, x):
        return x + self.beta

@basic_unit
class UnaryLogAbs(nn.Module):
    def forward(self, x):
        return torch.log(torch.abs(x) + 1e-7)

@basic_unit
class UnaryExp(nn.Module):
    def forward(self, x):
        return torch.exp(x)

@basic_unit
class UnarySin(nn.Module):
    def forward(self, x):
        return torch.sin(x)

@basic_unit
class UnaryCos(nn.Module):
    def forward(self, x):
        return torch.cos(x)

@basic_unit
class UnarySinh(nn.Module):
    def forward(self, x):
        return torch.sinh(x)

@basic_unit
class UnaryCosh(nn.Module):
    def forward(self, x):
        return torch.cosh(x)

@basic_unit
class UnaryTanh(nn.Module):
    def forward(self, x):
        return torch.tanh(x)

if not version_larger_equal(torch.__version__, TorchVersion):
    @basic_unit
    class UnaryAsinh(nn.Module):
        def forward(self, x):
            return torch.asinh(x)

@basic_unit
class UnaryAtan(nn.Module):
    def forward(self, x):
        return torch.atan(x)

if not version_larger_equal(torch.__version__, TorchVersion):
    @basic_unit
    class UnarySinc(nn.Module):
        def forward(self, x):
            return torch.sinc(x)

@basic_unit
class UnaryMax(nn.Module):
    def forward(self, x):
        return torch.max(x, torch.zeros_like(x))

@basic_unit
class UnaryMin(nn.Module):
    def forward(self, x):
        return torch.min(x, torch.zeros_like(x))

@basic_unit
class UnarySigmoid(nn.Module):
    def forward(self, x):
        return torch.sigmoid(x)

@basic_unit
class UnaryLogExp(nn.Module):
    def forward(self, x):
        return torch.log(1 + torch.exp(x))

@basic_unit
class UnaryExpSquare(nn.Module):
    def forward(self, x):
        return torch.exp(-torch.square(x))

@basic_unit
class UnaryErf(nn.Module):
    def forward(self, x):
        return torch.erf(x)

unary_modules = ['UnaryIdentity', 'UnaryNegative', 'UnaryAbs', 'UnarySquare', 'UnaryPow',
                 'UnarySqrt', 'UnaryMul', 'UnaryAdd', 'UnaryLogAbs', 'UnaryExp', 'UnarySin', 'UnaryCos',
                 'UnarySinh', 'UnaryCosh', 'UnaryTanh', 'UnaryAtan', 'UnaryMax',
                 'UnaryMin', 'UnarySigmoid', 'UnaryLogExp', 'UnaryExpSquare', 'UnaryErf']

if not version_larger_equal(torch.__version__, TorchVersion):
    unary_modules.append('UnaryAsinh')
    unary_modules.append('UnarySinc')

# ============== binary function modules ==============

@basic_unit
class BinaryAdd(nn.Module):
    def forward(self, x):
        return x[0] + x[1]

@basic_unit
class BinaryMul(nn.Module):
    def forward(self, x):
        return x[0] * x[1]

@basic_unit
class BinaryMinus(nn.Module):
    def forward(self, x):
        return x[0] - x[1]

@basic_unit
class BinaryDivide(nn.Module):
    def forward(self, x):
        return x[0] / (x[1] + 1e-7)

@basic_unit
class BinaryMax(nn.Module):
    def forward(self, x):
        return torch.max(x[0], x[1])

@basic_unit
class BinaryMin(nn.Module):
    def forward(self, x):
        return torch.min(x[0], x[1])

@basic_unit
class BinarySigmoid(nn.Module):
    def forward(self, x):
        return torch.sigmoid(x[0]) * x[1]

@basic_unit
class BinaryExpSquare(nn.Module):
    def __init__(self):
        super().__init__()
        self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32))  # pylint: disable=not-callable
    def forward(self, x):
        return torch.exp(-self.beta * torch.square(x[0] - x[1]))

@basic_unit
class BinaryExpAbs(nn.Module):
    def __init__(self):
        super().__init__()
        self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32))  # pylint: disable=not-callable
    def forward(self, x):
        return torch.exp(-self.beta * torch.abs(x[0] - x[1]))

@basic_unit
class BinaryParamAdd(nn.Module):
    def __init__(self):
        super().__init__()
        self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32))  # pylint: disable=not-callable
    def forward(self, x):
        return self.beta * x[0] + (1 - self.beta) * x[1]

binary_modules = ['BinaryAdd', 'BinaryMul', 'BinaryMinus', 'BinaryDivide', 'BinaryMax',
                  'BinaryMin', 'BinarySigmoid', 'BinaryExpSquare', 'BinaryExpAbs', 'BinaryParamAdd']

class AutoActivation(nn.Module):
    """
    This module is an implementation of the paper "Searching for Activation Functions"
    (https://arxiv.org/abs/1710.05941).
    NOTE: the current `beta` is not a per-channel parameter.

    Parameters
    ----------
    unit_num : int
        the number of core units
    """
    def __init__(self, unit_num = 1):
        super().__init__()
        self.unaries = nn.ModuleList()
        self.binaries = nn.ModuleList()
        self.first_unary = LayerChoice([eval('{}()'.format(unary)) for unary in unary_modules])
        for _ in range(unit_num):
            one_unary = LayerChoice([eval('{}()'.format(unary)) for unary in unary_modules])
            self.unaries.append(one_unary)
        for _ in range(unit_num):
            one_binary = LayerChoice([eval('{}()'.format(binary)) for binary in binary_modules])
            self.binaries.append(one_binary)

    def forward(self, x):
        out = self.first_unary(x)
        for unary, binary in zip(self.unaries, self.binaries):
            out = binary(torch.stack([out, unary(x)]))
        return out
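As a rough sketch of the data flow (not part of the diff): with unit_num core units the module applies the first unary operator to the input and then folds in one (unary, binary) pair per unit, out = binary_i(stack([out, unary_i(x)])); for the element-wise candidates the output keeps the input's shape:

import torch

act = AutoActivation(unit_num=2)  # first unary plus 2 (unary, binary) core units
x = torch.randn(4, 8)
y = act(x)                        # same shape as x for element-wise candidates

Note that the LayerChoice instances are placeholders to be resolved by an NNI exploration strategy; running the module directly like this is only meant to illustrate the wiring, not the intended standalone usage.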
Review comment: You can write one in the base class and Python will automatically inherit it.
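The exact referent of this comment is not visible in the excerpt. As a generic illustration of the suggestion, the trainable beta parameter that is repeated in UnaryMul, UnaryAdd, BinaryExpSquare, BinaryExpAbs and BinaryParamAdd could be defined once in a hypothetical base class and inherited by the subclasses (the basic_unit decorator is omitted here for brevity):

import torch
import torch.nn as nn

class _WithBeta(nn.Module):
    # hypothetical helper, not in the PR: defines beta once for all subclasses
    def __init__(self):
        super().__init__()
        self.beta = nn.Parameter(torch.tensor(1, dtype=torch.float32))

class UnaryMul(_WithBeta):
    def forward(self, x):
        return x * self.beta  # beta is inherited from _WithBeta

class UnaryAdd(_WithBeta):
    def forward(self, x):
        return x + self.beta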