Remove 'alpha' from non-QKeras quantizers. Treat binary_tanh activation properly. Propagate new Softmax to keras_to_hls.
thesps committed May 27, 2020
1 parent b1d3b54 commit 5173511
Showing 6 changed files with 51 additions and 26 deletions.
8 changes: 4 additions & 4 deletions hls4ml/converters/keras/core.py
@@ -34,14 +34,14 @@ def parse_reshape_layer(keras_layer, input_names, input_shapes, data_reader, con
 
 
 class BinaryQuantizer(Quantizer):
-    def __init__(self, bits=2, alpha=1):
+    def __init__(self, bits=2):
         if bits == 1:
             hls_type = IntegerPrecisionType(width=1, signed=False)
         elif bits == 2:
             hls_type = IntegerPrecisionType(width=2)
         else:
             raise Exception('BinaryQuantizer suppots 1 or 2 bits, but called with bits={}'.format(bits))
-        super(BinaryQuantizer, self).__init__(bits, hls_type, alpha=alpha)
+        super(BinaryQuantizer, self).__init__(bits, hls_type)
 
     def __call__(self, data):
         zeros = np.zeros_like(data)
@@ -54,8 +54,8 @@ def __call__(self, data):
         return quant_data
 
 class TernaryQuantizer(Quantizer):
-    def __init__(self, alpha=1):
-        super(TernaryQuantizer, self).__init__(2, IntegerPrecisionType(width=2), alpha=alpha)
+    def __init__(self):
+        super(TernaryQuantizer, self).__init__(2, IntegerPrecisionType(width=2))
 
     def __call__(self, data):
         zeros = np.zeros_like(data)
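
For readers skimming the diff, a minimal standalone sketch (not part of the commit) of the two-level and three-level mapping quantizers like these perform; the function names, the sign rule, and the 0.5 ternary threshold are illustrative assumptions, not the hls4ml code.

import numpy as np

def binary_quantize(data):
    # Map every element to +1 or -1 based on its sign (0 is sent to +1 here);
    # this mirrors the general idea of a two-level quantizer.
    zeros = np.zeros_like(data)
    ones = np.ones_like(data)
    return np.where(data >= zeros, ones, -ones)

def ternary_quantize(data, threshold=0.5):
    # Map elements to {-1, 0, +1} using a fixed threshold; the threshold value
    # is an assumption for illustration.
    quant = np.zeros_like(data)
    quant[data > threshold] = 1
    quant[data < -threshold] = -1
    return quant

print(binary_quantize(np.array([-0.3, 0.0, 0.7])))   # [-1.  1.  1.]
print(ternary_quantize(np.array([-0.9, 0.1, 0.8])))  # [-1.  0.  1.]
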
40 changes: 26 additions & 14 deletions hls4ml/converters/keras/qkeras_parsing.py
@@ -46,29 +46,41 @@ def parse_qconv_layer(keras_layer, input_names, input_shapes, data_reader, confi
 @keras_handler('QActivation')
 def parse_qactivation_layer(keras_layer, input_names, input_shapes, data_reader, config):
     assert(keras_layer['class_name'] == 'QActivation')
-    supported_activations = ['quantized_relu', 'quantized_tanh']
+    print(keras_layer)
+    supported_activations = ['quantized_relu', 'quantized_tanh', 'binary_tanh', 'ternary_tanh']
 
     layer = parse_default_keras_layer(keras_layer, input_names)
 
     activation_config = keras_layer['config']['activation']
-    if isinstance(activation_config, str):
-        quantizer_obj = get_quantizer(activation_config)
-        activation_config = {}
-        activation_config['class_name'] = quantizer_obj.__class__.__name__
-        activation_config['config'] = quantizer_obj.get_config()
+    print(activation_config)
+    if isinstance(activation_config, str):
+        quantizer_obj = get_quantizer(activation_config)
+        activation_config = {}
+        # some activations are classes
+        if hasattr(quantizer_obj, 'get_config'):
+            print("Name: " + quantizer_obj.__class__.__name__)
+            activation_config['class_name'] = quantizer_obj.__class__.__name__
+            activation_config['config'] = quantizer_obj.get_config()
+        # some activation quantizers are just functions with no config
+        else:
+            activation_config['config'] = {}
+            if quantizer_obj.__name__ == 'binary_tanh':
+                activation_config['class_name'] = 'binary_tanh'
+                activation_config['config']['bits'] = 1
+                activation_config['config']['integer'] = 1
+            elif quantizer_obj.__name__ == 'ternary_tanh':
+                activation_config['class_name'] = 'ternary_tanh'
+                activation_config['config']['bits'] = 2
+                activation_config['config']['integer'] = 2
+            else:
+                activation_config['class_name'] = 'unknown'
 
-    act_class = activation_config['class_name']
-    if act_class not in supported_activations:
-        raise Exception('Unsupported QKeras activation: {}'.format(act_class))
+    if activation_config['class_name'] not in supported_activations:
+        raise Exception('Unsupported QKeras activation: {}'.format(activation_config['class_name']))
 
-    layer['class_name'] = 'Activation'
-    layer['activation'] = act_class.replace('quantized_', '')
-    layer['bits'] = activation_config['config']['bits'] + 1
-    layer['integer'] = activation_config['config']['integer'] + 1
-    #TODO this needs extra work in HLS model and HLS templates
 
+    layer['class_name'] = 'Activation'
+    layer['activation'] = activation_config['class_name'].replace('quantized_', '')
     return layer, [shape for shape in input_shapes[0]]
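
A small self-contained sketch (an illustrative stand-in, not the hls4ml parser) of the dictionary the new else-branch assembles when the QKeras activation is a bare function such as binary_tanh or ternary_tanh rather than a class with get_config().

def describe_function_activation(name):
    # Build the same shape of config dict the parser uses: a class name plus a
    # 'config' sub-dict carrying bits/integer for the binary and ternary cases.
    config = {'class_name': name, 'config': {}}
    if name == 'binary_tanh':
        config['config']['bits'] = 1
        config['config']['integer'] = 1
    elif name == 'ternary_tanh':
        config['config']['bits'] = 2
        config['config']['integer'] = 2
    else:
        config['class_name'] = 'unknown'
    return config

print(describe_function_activation('binary_tanh'))
# {'class_name': 'binary_tanh', 'config': {'bits': 1, 'integer': 1}}
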

2 changes: 2 additions & 0 deletions hls4ml/converters/keras_to_hls.py
@@ -620,6 +620,8 @@ def keras_to_hls(config):
             if 'activ_param' in layer:
                 act_layer['activ_param'] = layer['activ_param']
                 act_layer['class_name'] = layer['activation']
+            elif layer['activation'] == 'softmax':
+                act_layer['class_name'] = 'Softmax'
             else:
                 act_layer['class_name'] = 'Activation'
             inputs_map[layer['name']] = act_layer['name']
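
A hedged sketch of the dispatch this two-line change introduces: a 'softmax' activation is now routed to a dedicated Softmax layer class instead of the generic Activation. The helper function below is an assumption for illustration, not the keras_to_hls code.

def activation_class_name(layer):
    # Parametrized activations (e.g. LeakyReLU with an alpha) keep their own class name.
    if 'activ_param' in layer:
        return layer['activation']
    # Softmax gets its own dedicated layer class after this commit.
    elif layer['activation'] == 'softmax':
        return 'Softmax'
    # Everything else stays a generic Activation layer.
    else:
        return 'Activation'

print(activation_class_name({'activation': 'softmax'}))                        # Softmax
print(activation_class_name({'activation': 'LeakyReLU', 'activ_param': 0.1}))  # LeakyReLU
print(activation_class_name({'activation': 'relu'}))                           # Activation
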
3 changes: 1 addition & 2 deletions hls4ml/model/hls_layers.py
@@ -7,10 +7,9 @@
 from collections import OrderedDict
 
 class Quantizer(object):
-    def __init__(self, bits, hls_type, alpha):
+    def __init__(self, bits, hls_type):
         self.bits = bits
         self.hls_type = hls_type
-        self.alpha = alpha
 
     def __call__(self, data):
         raise NotImplementedError
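
A tiny stand-in (assumed names, not the hls4ml classes) showing what the slimmer base-class signature means for subclasses: they construct the parent with just a bit width and an HLS type, and no alpha is stored.

class Quantizer(object):
    # Simplified base-class contract after this commit: bit width plus target type.
    def __init__(self, bits, hls_type):
        self.bits = bits
        self.hls_type = hls_type

class TwoBitQuantizer(Quantizer):
    def __init__(self):
        # 'ap_int<2>' is a placeholder string standing in for IntegerPrecisionType(width=2).
        super(TwoBitQuantizer, self).__init__(2, 'ap_int<2>')

q = TwoBitQuantizer()
print(q.bits, q.hls_type)  # 2 ap_int<2>
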
10 changes: 8 additions & 2 deletions hls4ml/model/optimizer/passes/qkeras.py
@@ -92,9 +92,15 @@ def match(self, node):
         has_b_quant = node.get_attr('bias_quantizer') is not None
         has_w_alpha, has_b_alpha = False, False
         if has_w_quant:
-            has_w_alpha = node.get_attr('weight_quantizer').alpha != 1
+            if hasattr(node.get_attr('weight_quantizer'), 'alpha'):
+                has_w_alpha = node.get_attr('weight_quantizer').alpha != 1
+            else:
+                has_w_alpha = False
         if has_b_quant:
-            has_b_alpha = node.get_attr('bias_quantizer').alpha != 1
+            if hasattr(node.get_attr('bias_quantizer'), 'alpha'):
+                has_b_alpha = node.get_attr('bias_quantizer').alpha != 1
+            else:
+                has_b_alpha = False
         is_match = q_layer and ((has_w_quant and has_w_alpha) or (has_b_quant and has_b_alpha))
         return is_match
 
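
A possible alternative to the hasattr/else pattern, shown only as a sketch: getattr with a default of 1 treats quantizers without an alpha attribute as alpha == 1, which is the same behaviour the commit encodes. The helper and demo classes below are assumptions for illustration.

def has_nontrivial_alpha(quantizer):
    # Quantizers with no 'alpha' attribute are treated as alpha == 1, i.e. no scaling.
    return getattr(quantizer, 'alpha', 1) != 1

class WithAlpha:
    alpha = 0.5

class WithoutAlpha:
    pass

print(has_nontrivial_alpha(WithAlpha()))     # True
print(has_nontrivial_alpha(WithoutAlpha()))  # False
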
14 changes: 10 additions & 4 deletions hls4ml/utils/config.py
@@ -25,18 +25,24 @@ def _get_precision_from_quantizer(quantizer):
     if isinstance(quantizer, str):
         quantizer_obj = qkeras.get_quantizer(quantizer)
         quantizer = {}
-        quantizer['class_name'] = quantizer_obj.__class__.__name__
-        quantizer['config'] = quantizer_obj.get_config()
+        # Some activations are classes with get_config method
+        if hasattr(quantizer_obj, 'get_config'):
+            quantizer['class_name'] = quantizer_obj.__class__.__name__
+            quantizer['config'] = quantizer_obj.get_config()
+        # Some activations are just functions
+        else:
+            quantizer['class_name'] = quantizer_obj.__name__
 
     supported_quantizers = ['quantized_bits', 'quantized_relu', 'quantized_tanh']
     if quantizer['class_name'] in supported_quantizers:
         bits = int(quantizer['config']['bits']) + 1
         integer = int(quantizer['config']['integer']) + 1
 
-    elif quantizer['class_name'] in ['binary', 'stochastic_binary']:
+    elif quantizer['class_name'] in ['binary', 'stochastic_binary', 'binary_tanh']:
         bits = 2
         integer = 2
 
-    elif quantizer['class_name'] in ['ternary', 'stochastic_ternary']:
+    elif quantizer['class_name'] in ['ternary', 'stochastic_ternary', 'ternary_tanh']:
         bits = 2
         integer = 2
     else:
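
A standalone sketch (assumed helper name, not hls4ml's _get_precision_from_quantizer) of the bit-width bookkeeping in this hunk: quantized_* quantizers map to bits+1 / integer+1, while the binary and ternary families, now including binary_tanh and ternary_tanh, map to a fixed 2 bits.

def precision_bits(class_name, config=None):
    # Mirror the branch structure of the hunk above for the supported quantizer names.
    if class_name in ('quantized_bits', 'quantized_relu', 'quantized_tanh'):
        return int(config['bits']) + 1, int(config['integer']) + 1
    elif class_name in ('binary', 'stochastic_binary', 'binary_tanh',
                        'ternary', 'stochastic_ternary', 'ternary_tanh'):
        return 2, 2
    else:
        raise Exception('Unsupported quantizer: {}'.format(class_name))

print(precision_bits('quantized_relu', {'bits': 6, 'integer': 0}))  # (7, 1)
print(precision_bits('binary_tanh'))                                # (2, 2)
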
