Support alpha='auto_po2' in quantizers. Adds QKerasFactorizeAlpha Optimizer pass to factorize out the alpha scale and insert a new 'ApplyAlpha' (BatchNormalization) layer to apply it back. Attaches data_unquantized to WeightVariables to retain access to the unquantized weights later (used in the QKerasFactorizeAlpha pass).
thesps committed May 13, 2020
1 parent 4840d73 commit 2a44923
Showing 3 changed files with 15 additions and 11 deletions.
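
A note on the idea behind the change (not part of the commit itself): with alpha='auto_po2' the QKeras quantizer produces weights that carry a per-channel power-of-two scale. The pass divides that scale out of the weights stored on the layer and re-applies it afterwards through the new 'ApplyAlpha' (BatchNormalization-style) layer, so the end-to-end result is unchanged. A minimal numpy sketch of why the factorization is exact; the shapes and values are illustrative only:

import numpy as np

# Illustrative quantized weights (3 inputs, 2 output channels) and a
# per-channel power-of-two scale such as alpha='auto_po2' would choose.
qweights = np.array([[0.5,   1.0],
                     [-0.25, 0.5],
                     [0.125, -1.0]])
scale = np.array([0.5, 2.0])        # one power-of-two scale per output channel

w_po2 = qweights / scale            # weights left on the layer after the pass
x = np.random.randn(4, 3)           # a batch of inputs

y_direct = x @ qweights             # layer using the fully quantized weights
y_split  = (x @ w_po2) * scale      # de-scaled layer followed by ApplyAlpha (scale, bias=0)
assert np.allclose(y_direct, y_split)
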
13 changes: 8 additions & 5 deletions hls4ml/model/hls_layers.py
@@ -119,7 +119,7 @@ def size_cpp(self):
         return '*'.join([str(k) for k in self.dim_names])
 
 class WeightVariable(Variable):
-    def __init__(self, var_name, type_name, precision, data, **kwargs):
+    def __init__(self, var_name, type_name, precision, quantizer, data, **kwargs):
         super(WeightVariable, self).__init__(var_name, type_name, precision, **kwargs)
         self.data = data
         self.nzeros = -1
@@ -131,6 +131,7 @@ def __init__(self, var_name, type_name, precision, data, **kwargs):
         self.max = np.max(self.data)
         self._iterator = None
         self.update_precision(precision)
+        self.quantizer = quantizer
 
     def __iter__(self):
         self._iterator = np.nditer(self.data, order='C')
@@ -173,8 +174,8 @@ def definition_cpp(self):
         return '{type} {name}[{size}]'.format(type=self.type.name, name=self.cppname, size=self.data_length)
 
 class CompressedWeightVariable(WeightVariable):
-    def __init__(self, var_name, type_name, precision, data, reuse_factor, **kwargs):
-        super(CompressedWeightVariable, self).__init__(var_name, type_name, precision, data, **kwargs)
+    def __init__(self, var_name, type_name, precision, quantizer, data, reuse_factor, **kwargs):
+        super(CompressedWeightVariable, self).__init__(var_name, type_name, precision, quantizer, data, **kwargs)
         self.extra_zeros = 0
         self.data_length = np.prod(data.shape) - self.nzeros
         while self.data_length % reuse_factor != 0:
@@ -345,16 +346,18 @@ def add_weights_variable(self, name, var_name=None, type_name=None, precision=No
         elif isinstance(data, six.string_types):
             data = self.model.get_weights_data(self.name, data)
 
+        data_unquantized = data
         if quantizer is not None:
             precision = quantizer.hls_type
             type_name = name + '{index}_t'
             data = quantizer(data)
 
         if compression:
-            var = CompressedWeightVariable(var_name, type_name=type_name, precision=precision, data=data, reuse_factor=self.reuse_factor, index=self.index)
+            var = CompressedWeightVariable(var_name, type_name=type_name, precision=precision, quantizer=quantizer, data=data, reuse_factor=self.reuse_factor, index=self.index)
         else:
-            var = WeightVariable(var_name, type_name=type_name, precision=precision, data=data, index=self.index)
+            var = WeightVariable(var_name, type_name=type_name, precision=precision, quantizer=quantizer, data=data, index=self.index)
 
+        var.data_unquantized = data_unquantized
         self.weights[name] = var
         self.precision[var.type.name] = var.type
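
As a rough illustration of the new bookkeeping in add_weights_variable (not from the commit; the quantizer below is a hypothetical stand-in), a WeightVariable now carries three things a later pass can use: the quantized data, the quantizer itself, and the original unquantized data:

import numpy as np

class FakeQuantizer:
    """Hypothetical stand-in for an hls4ml quantizer: rounds to a 1/4 grid."""
    hls_type = 'ap_fixed<8,2>'
    def __call__(self, data):
        return np.round(np.asarray(data) * 4) / 4

quantizer = FakeQuantizer()
raw = np.array([0.1, 0.61, -0.3])

# Mirrors the flow above: keep the original data, then quantize.
data_unquantized = raw
data = quantizer(raw)

# var = WeightVariable(..., quantizer=quantizer, data=data, ...)
# var.data_unquantized = data_unquantized
# A later pass (e.g. QKerasFactorizeAlpha) can call var.quantizer on
# var.data_unquantized to recompute quantities such as the alpha scale.
print(data)              # [ 0.    0.5  -0.25]
print(data_unquantized)  # [ 0.1   0.61 -0.3 ]
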

10 changes: 5 additions & 5 deletions hls4ml/model/optimizer/passes/qkeras.py
@@ -87,9 +87,9 @@ def match(self, node):
 
     def transform(self, model, node):
         # The quantizer has to be applied to set the scale attribute
-        # Should work whether the weights have been quantized or not
-        quantizer = node.get_attr('weight_quantizer').quantizer_fn # get quantizer
-        weights = node.weights['weight'].data # get weights
+        # This must be applied to the _unquantized_ weights to obtain the correct scale
+        quantizer = node.weights['weight'].quantizer.quantizer_fn # get QKeras quantizer
+        weights = node.weights['weight'].data_unquantized # get weights
         qweights = quantizer(tf.convert_to_tensor(weights))
         scale = quantizer.scale.numpy()
         unscale = 1. / scale
@@ -99,17 +99,17 @@ def transform(self, model, node):
         node.weights['weight'].data = new_weights.numpy()
 
         # insert a Batch Normalization layer to apply the alpha scale
-        next_node = next((x for x in model.graph.values() if x.inputs[0] == node.outputs[0]), None)
         attrs = {
             'name' : node.get_attr('name') + '_alpha',
             'class_name' : 'Alpha',
             'inputs' : node.outputs,
             'n_in' : node.get_attr('n_out'),
             'n_filt' : node.get_attr('n_filt') if node.get_attr('n_filt') is not None else -1,
             'reuse_factor' : node.get_attr('reuse_factor'),
             'bias_t' : 'ap_fixed<16,6>', # TODO automate this
             'scale_t' : 'ap_fixed<16,6>' # TODO automate this
         }
-        alpha_layer = model.make_node('ApplyAlpha', node.name + '_alpha', attrs, node.outputs, next_node.inputs)
+        alpha_layer = model.make_node('ApplyAlpha', node.name + '_alpha', attrs, node.outputs)
         alpha_layer.add_weights(scale, np.zeros(scale.shape))
         model.insert_node(alpha_layer)
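
For context on where the scale comes from (a sketch, not part of the commit; it assumes QKeras and TensorFlow 2 are installed, and the bit width and kernel shape are arbitrary): calling a quantized_bits quantizer built with alpha='auto_po2' sets its .scale attribute to per-channel powers of two, which is exactly what the pass above divides out of the weights and hands to the ApplyAlpha layer:

import numpy as np
import tensorflow as tf
from qkeras import quantized_bits

q = quantized_bits(bits=4, integer=0, alpha='auto_po2')

w = np.random.randn(16, 8).astype(np.float32)   # unquantized Dense kernel
qw = q(tf.convert_to_tensor(w)).numpy()         # quantized weights, scale included
scale = q.scale.numpy()                         # one scale per output channel

# auto_po2 rounds the scale to a power of two, so re-applying it in
# ApplyAlpha costs only a shift in hardware.
assert np.all(np.log2(scale) == np.round(np.log2(scale)))

new_weights = qw / scale                        # the de-scaled weights stored on the node
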

3 changes: 2 additions & 1 deletion hls4ml/model/profiling.py
@@ -1,5 +1,6 @@
 import importlib
 from hls4ml.model.hls_model import HLSModel
+import qkeras
 
 libs = [('numpy', 'np'), ('pandas', 'pandas'), ('tensorflow', 'tensorflow'),
         ('seaborn', 'sb'), ('matplotlib.pyplot', 'plt')]
@@ -356,7 +357,7 @@ def get_ymodel_keras(keras_model, X):
         if not _is_ignored_layer(layer):
             #If the layer has activation integrated then separate them
             #Note that if the layer is a standalone activation layer then skip this
-            if hasattr(layer, 'activation') and not isinstance(layer,keras.layers.Activation):
+            if hasattr(layer, 'activation') and not (isinstance(layer,keras.layers.Activation) or isinstance(layer, qkeras.qlayers.QActivation)):
                 if layer.activation:
 
                     if layer.activation.__class__.__name__ == "linear":
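
The reason for the extra check in get_ymodel_keras: a standalone QActivation exposes an .activation attribute just like a Dense layer with a fused activation, so without it the profiler would treat it as a layer with a fused activation and try to split it. A quick sketch, assuming QKeras is installed:

import qkeras
from qkeras import QActivation
from tensorflow import keras

act = QActivation('quantized_relu(8)')
dense = keras.layers.Dense(4, activation='relu')

for layer in (act, dense):
    standalone = isinstance(layer, (keras.layers.Activation, qkeras.qlayers.QActivation))
    print(layer.__class__.__name__, hasattr(layer, 'activation'), standalone)
# QActivation True True   -> treated as standalone, not split
# Dense True False        -> fused activation is separated as before
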
