Digit Recognizer with Tensorflow JS

A short project that trains a model with Keras in Python and uses the TFJS library for in-browser prediction.

See in action

Requires

This code has been tested on TensorFlow 1.11 and higher using Python 3.

# Start a virtual environment
virtualenv venv
source venv/bin/activate

# Install libraries
pip3 install tensorflow==1.11
pip3 install tensorflowjs

Train model and convert to the format required by Tensorflow JS

python3 model_builder.py

The Python code uses the following line, as mentioned in the docs, to convert the model to the required format:

tfjs.converters.save_keras_model(model, 'tfjs_target_dir')

Alternatively, we could use the bash command:

tensorflowjs_converter --input_format keras \
                       my_mnist_model.h5 \
                       tfjs_target_dir

Notice that our model is very simple: a single hidden layer with ReLU activations, a dropout layer, and finally an output layer with softmax activations.

Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 512)               401920    
_________________________________________________________________
dropout (Dropout)            (None, 512)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 10)                5130      
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________
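
As a quick sanity check, the Param # column in the summary above can be reproduced by hand. The sketch below assumes the input is a flattened 28x28 MNIST image (784 values), which is consistent with the 401,920 parameters reported for the first dense layer; the helper name `dense_params` is made up for illustration and is not part of this repo.

```python
def dense_params(n_inputs, n_units):
    """Parameter count of a fully connected layer: one weight per
    input/unit pair plus one bias per unit."""
    return n_inputs * n_units + n_units


# Assumption: each image is flattened to 28 * 28 = 784 input values.
n_pixels = 28 * 28

hidden = dense_params(n_pixels, 512)  # dense (Dense), 512 units
output = dense_params(512, 10)        # dense_1 (Dense), 10 classes
total = hidden + output               # the dropout layer adds no parameters

print(hidden, output, total)
```

The three printed values should line up with the `dense`, `dense_1` and `Total params` rows of the summary.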

Credits

Part of the JavaScript code is borrowed from the TFJS mnist-core example. To avoid a lot of frontend work, I also used the template from the TFJS examples repo.

Bonus tip: It helps to clone the TFJS examples repo and run every example just to see what TFJS is capable of (a lot). You can use the following custom-made bash commands to do so.

# replace with suitable command for non-macOS from https://yarnpkg.com/lang/en/docs/install
brew install yarn
git clone https://github.com/tensorflow/tfjs-examples
cd tfjs-examples
# goes into each directory and runs yarn watch without hanging up and gets you back to terminal
ls -d */ | xargs -I {} bash -c "cd '{}' && pwd && yarn && nohup yarn watch > /dev/null 2>&1 &"

Side Note

The focus of this project is not to attain very high accuracy on actual hand-written images in the demo, but to make a TensorFlow JS script work with a model trained in Python. Even so, there were several challenges due to backward-incompatibility issues in TFJS.

One reliable way to improve accuracy would be to modify how the canvas output is converted into a 2D tensor (Issue). The current workaround produces non-continuous segments (look here), whereas actual MNIST images contain continuous segments. Fixing this would require significant changes to the canvas event listener functions. Since the results are already reasonably good, this is not necessary.
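
To illustrate the idea of continuous segments, here is a minimal sketch of one possible approach: joining consecutive sampled pointer positions with straight lines before building the 28x28 grid. This is not the code used in this project (the frontend is JavaScript); it is a Python illustration, and the function names `line_cells` and `rasterize_stroke` as well as the choice of Bresenham's line algorithm are assumptions made for the example.

```python
def line_cells(x0, y0, x1, y1):
    """Return the grid cells on a straight line between two integer
    points, using Bresenham's line algorithm."""
    cells = []
    dx, dy = abs(x1 - x0), -abs(y1 - y0)
    sx = 1 if x0 < x1 else -1
    sy = 1 if y0 < y1 else -1
    err = dx + dy
    while True:
        cells.append((x0, y0))
        if x0 == x1 and y0 == y1:
            break
        e2 = 2 * err
        if e2 >= dy:  # step along x
            err += dy
            x0 += sx
        if e2 <= dx:  # step along y
            err += dx
            y0 += sy
    return cells


def rasterize_stroke(points, size=28):
    """Turn a list of sampled (x, y) positions into a size x size grid,
    filling the gaps between consecutive samples."""
    grid = [[0.0] * size for _ in range(size)]
    if len(points) == 1:
        # A single tap: mark just that cell (as a zero-length line).
        points = [points[0], points[0]]
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        for x, y in line_cells(x0, y0, x1, y1):
            if 0 <= x < size and 0 <= y < size:
                grid[y][x] = 1.0
    return grid


# Hypothetical samples that are several cells apart.
stroke = [(2, 2), (6, 2), (6, 5)]
grid = rasterize_stroke(stroke)
```

Marking only the three sampled cells would leave gaps; with interpolation, every cell between consecutive samples is filled, so the stroke is connected the way MNIST digits are.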