Skip to content

Commit

Permalink
adjust fetch model to keras json
Browse files Browse the repository at this point in the history
  • Loading branch information
Duchstf committed Jul 29, 2020
1 parent 004ac1b commit 921fa67
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import hls4ml

#Fetch a keras model from our example repository
#This will download our example model to your working directory and return an example configuration file
config = hls4ml.utils.fetch_example_model('keras_3layer.h5')
config = hls4ml.utils.fetch_example_model('KERAS_3layer.json')

print(config) #You can print it to see some default parameters

Expand Down
19 changes: 14 additions & 5 deletions hls4ml/utils/example_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def fetch_example_model(model_name):
https://github.com/hls-fpga-machine-learning/example-models
Args:
- model_name: string, name of the example model in the repo. Example: 'keras_3_layer.h5'
- model_name: string, name of the example model in the repo. Example: 'KERAS_3layer.json'
"""

Expand All @@ -20,9 +20,9 @@ def fetch_example_model(model_name):
model_config = None

#Check for model's type to update link
if '.h5' in model_name:
if '.json' in model_name:
model_type = 'keras'
model_config = 'KerasH5'
model_config = 'KerasJson'
elif '.pt' in model_name:
model_type = 'pytorch'
model_config = 'PytorchModel'
Expand All @@ -36,13 +36,22 @@ def fetch_example_model(model_name):
raise TypeError('Model type is not supported in hls4ml.')


download_link += model_type + '/' + model_name
download_link_model = download_link + model_type + '/' + model_name

#Initiate the configuration file
config = create_vivado_config()

#Download the example model
urlretrieve(download_link, model_name)
urlretrieve(download_link_model, model_name)

#If the model is a keras model then have to download its weight file as well
if model_type == 'keras':
model_weight_name = model_name[:-5] + "_weights.h5"

download_link_weight = download_link + model_type + '/' + model_weight_name
urlretrieve(download_link_weight, model_weight_name)

config['KerasH5'] = model_weight_name #Set configuration for the weight file

#Additional configuration parameters
config[model_config] = model_name
Expand Down

0 comments on commit 921fa67

Please sign in to comment.