Recognizing Digits using TensorFlow.js in Google Chrome

Deep Learning | 31 July 2018

#100DaysOfMLCode

In this blog post, we will create a simple web application that provides a canvas (ready for mobile, desktop, laptop and tablet) on which the user draws a digit, and uses a deep neural network (MLP or CNN) to predict which digit the user has drawn.

As we already know the capabilities offered by TensorFlow.js, we will extend the ideas to create two Deep Neural Networks (MLP and CNN) in the Keras Python environment to recognize digits, and use TensorFlow.js to predict the digit drawn by the user on a canvas in a web browser. In this learning path, we will restrict the user to drawing a single digit in the range [0, 9], and later we will extend the idea to multiple digits.

Below is the interactive demo that you can use to draw a digit between [0, 9] and the Deep Neural Network (MLP or CNN) that is running in your browser will predict what that digit is in the form of a bar chart.

Important: This is a highly experimental demo for mobile and tablet devices. Please try this demo on a laptop or desktop if you encounter issues on mobile and tablet devices. Also, please use Google Chrome as the browser to try this demo. Other browsers are not supported as of now.

Note: To follow this tutorial, I assume you have basic knowledge of Python, HTML5, CSS3, Sass, JavaScript, jQuery and basic command line usage.

Deep Neural Network for Digit Recognition

I posted a tutorial a year ago on how to build Deep Neural Nets (specifically a Multi-Layer Perceptron) to recognize hand-written digits using Keras and Python here. I highly encourage you to read that post before proceeding.

I assume you are familiar with Keras before proceeding (else please read my other tutorials on Keras here and here). If you are a beginner to Keras, I strongly encourage you to visit the Keras documentation, where you can find plenty of information on how to use the library and learn from the examples found here.

For this tutorial, we will learn to create two popular Deep Neural Networks (DNN), namely a Multi-Layer Perceptron (MLP) and a Convolutional Neural Network (CNN). For this digit recognition problem, the MLP achieves pretty good accuracy. But we will build a CNN too and compare the performance of both models live.

Simple MLP using Keras and Python

We will train a Keras MLP model in Python and dump out the model and weights in TensorFlow.js Layers format (as we did here). After that, we will load the dumped model and weights in our browser and use TensorFlow.js to make predictions.

To do this, here is the Python code that fetches and loads the MNIST dataset, trains a simple multi-layer perceptron with two hidden layers of 512 and 256 neurons on 60000 labeled training images, validates the trained model on the 10000 labeled test images, and saves the model along with its weights in TensorFlow.js Layers format in model_save_path.

mnist_mlp.py
# organize imports
import numpy as np
from keras.models import Sequential
from keras.layers.core import Dense, Activation, Dropout
from keras.datasets import mnist
from keras.utils import np_utils
import tensorflowjs as tfjs

# fix a random seed for reproducibility
np.random.seed(9)

# user inputs
nb_epoch            = 25
num_classes         = 10
batch_size          = 64
train_size          = 60000
test_size           = 10000
v_length            = 784
model_save_path     = "output/mlp"

# split the mnist data into train and test
(trainData, trainLabels), (testData, testLabels) = mnist.load_data()

# reshape and scale the data
trainData   = trainData.reshape(train_size, v_length)
testData    = testData.reshape(test_size, v_length)
trainData   = trainData.astype("float32")
testData    = testData.astype("float32")
trainData  /= 255
testData   /= 255

# convert class vectors to binary class matrices --> one-hot encoding
mTrainLabels  = np_utils.to_categorical(trainLabels, num_classes)
mTestLabels   = np_utils.to_categorical(testLabels, num_classes)

# create the MLP model
model = Sequential()
model.add(Dense(512, input_shape=(v_length,)))
model.add(Activation("relu"))
model.add(Dense(256))
model.add(Activation("relu"))
model.add(Dropout(0.2))
model.add(Dense(num_classes))
model.add(Activation("softmax"))

# compile the model
model.compile(loss="categorical_crossentropy",
              optimizer="rmsprop",
              metrics=["accuracy"])

# fit the model
history = model.fit(trainData, 
                    mTrainLabels,
                    validation_data=(testData, mTestLabels),
                    batch_size=batch_size,
                    epochs=nb_epoch,   # Keras 2 renamed the nb_epoch argument to epochs
                    verbose=2)

# evaluate the model
scores = model.evaluate(testData, mTestLabels, verbose=0)

# print the results
print ("[INFO] test score - {}".format(scores[0]))
print ("[INFO] test accuracy - {}".format(scores[1]))

# save tf.js specific files in model_save_path
tfjs.converters.save_keras_model(model, model_save_path)

[INFO] test score - 0.16604800749952142
[INFO] test accuracy - 0.9828

Two important things to watch carefully here are the image size and the input vector size of the MLP. The Keras function mnist.load_data() loads images of size [28, 28]. We flatten each image into a vector of size 784 for the MLP. We also scale the pixel values to the range [0, 1] so that the network trains better. These are the image preprocessing operations that we will repeat in the front-end using JavaScript.
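
To make the flatten-and-scale step concrete, here is a minimal sketch in plain JavaScript that works on ordinary arrays rather than tensors. The helper name flattenAndScale is hypothetical and only illustrates the arithmetic; the actual front-end code later in this post uses TensorFlow.js for the same job.

```javascript
// Flatten a 2D grayscale image (rows of 0-255 values) into a 1D vector
// scaled to [0, 1], mirroring reshape(..., 784) and "/= 255" in the Python code.
function flattenAndScale(image2d) {
  const flat = [];
  for (const row of image2d) {
    for (const pixel of row) {
      flat.push(pixel / 255);
    }
  }
  return flat;
}

// Build a blank 28x28 image with one white pixel at row 1, column 2.
const image = Array.from({ length: 28 }, () => new Array(28).fill(0));
image[1][2] = 255;

const vector = flattenAndScale(image);
console.log(vector.length);      // 784
console.log(vector[1 * 28 + 2]); // 1
```

The pixel at row r and column c ends up at index r * 28 + c, which is the row-major layout that a reshape to 784 values produces.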

After training and validating the MLP, we save the model architecture and weights using tensorflowjs under model_save_path. We will upload this folder to our website from where we could easily load this Keras model in TensorFlow.js using HTTPS request to make predictions in the browser.

Simple CNN using Keras and Python

One caveat of using the MNIST dataset from Keras is that it is well cleaned: the images are centered, cropped and aligned perfectly, so very little preprocessing is required from the developer. A digit drawn on an HTML5 canvas is different: every user has different handwriting and draws at a different position and size. The MLP that we created above treats each pixel as an independent input, so it isn’t well suited to such cases.

That’s why we also create a CNN (Convolutional Neural Network), which learns features directly from the 2D image instead of from a flattened 1D vector.

Below is the python code snippet to create a simple CNN that fetches and loads MNIST dataset, trains a simple CNN with

  • One Convolution2D() layer with 32 feature maps with size [5, 5] and relu activation function that takes in the input canvas size of [28, 28, 1].
  • One MaxPooling2D() layer with a pool size of [2, 2].
  • One Dropout() layer with argument 0.2 (meaning it randomly sets 20% of its inputs to zero during training to reduce overfitting).
  • One Flatten() layer that converts the 3D feature maps into a 1D vector for the next layer.
  • One Dense() layer with 128 neurons activated by relu.
  • Final softmax-activated dense layer with num_classes neurons.

mnist_cnn.py
# organize imports
import numpy as np
from keras.models import Sequential
from keras.layers.core import Dense, Activation, Dropout, Flatten
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.datasets import mnist
from keras.utils import np_utils
import tensorflowjs as tfjs

# fix a random seed for reproducibility
np.random.seed(9)

# user inputs
nb_epoch            = 10
num_classes         = 10
batch_size          = 200
train_size          = 60000
test_size           = 10000
model_save_path     = "output/cnn"

# split the mnist data into train and test
(trainData, trainLabels), (testData, testLabels) = mnist.load_data()

# reshape and scale the data
trainData = trainData.reshape(trainData.shape[0], 28, 28, 1)
testData  = testData.reshape(testData.shape[0], 28, 28, 1)
trainData   = trainData.astype("float32")
testData    = testData.astype("float32")
trainData  /= 255
testData   /= 255

# convert class vectors to binary class matrices --> one-hot encoding
mTrainLabels  = np_utils.to_categorical(trainLabels, num_classes)
mTestLabels   = np_utils.to_categorical(testLabels, num_classes)

# create the CNN model
model = Sequential()
# Keras 2 renamed the border_mode argument to padding
model.add(Convolution2D(32, (5, 5), padding='valid', input_shape=(28, 28, 1), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.2))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dense(num_classes, activation='softmax'))

# compile model
model.compile(loss='categorical_crossentropy', 
        optimizer='adam', 
        metrics=['accuracy'])

# fit the model
history = model.fit(trainData, 
                    mTrainLabels,
                    validation_data=(testData, mTestLabels),
                    batch_size=batch_size,
                    epochs=nb_epoch,   # Keras 2 renamed the nb_epoch argument to epochs
                    verbose=2)

# evaluate the model
scores = model.evaluate(testData, mTestLabels, verbose=0)

# print the results
print ("[INFO] test score - {}".format(scores[0]))
print ("[INFO] test accuracy - {}".format(scores[1]))

# save tf.js specific files in model_save_path
tfjs.converters.save_keras_model(model, model_save_path)

[INFO] test score - 0.03036247751080955
[INFO] test accuracy - 0.9904

Again, look carefully at the preprocessing steps that we do here for the CNN and the input shape that we pass in to the Convolution2D layer. For this tutorial, we have used the TensorFlow image ordering format (channels last), i.e. each image has shape [height, width, channels] = [28, 28, 1].
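
To see how the input shape flows through the layers listed above, the following sketch works out the output size of each layer by hand. It assumes a stride of 1 for the convolution and a pooling stride equal to the pool size (the Keras defaults when no strides are given); validOutputSize is a hypothetical helper name.

```javascript
// Output size of a "valid" convolution or pooling window along one axis.
function validOutputSize(inputSize, windowSize, stride) {
  return Math.floor((inputSize - windowSize) / stride) + 1;
}

// Input in channels-last order: [height, width, channels].
const input = [28, 28, 1];

// Convolution2D with 32 feature maps of size 5x5, valid padding, stride 1.
const conv = [validOutputSize(input[0], 5, 1), validOutputSize(input[1], 5, 1), 32];

// MaxPooling2D with a 2x2 pool; the stride defaults to the pool size.
const pool = [validOutputSize(conv[0], 2, 2), validOutputSize(conv[1], 2, 2), conv[2]];

// Flatten: all values laid out in one vector for the Dense layers.
const flat = pool[0] * pool[1] * pool[2];

console.log(conv); // [ 24, 24, 32 ]
console.log(pool); // [ 12, 12, 32 ]
console.log(flat); // 4608
```

So the Dense layer with 128 neurons receives a vector of 4608 values for each image.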

Notice how the test accuracy jumped from 0.9828 (MLP) to 0.9904 using CNN. Similar to MLP, we use tensorflowjs to dump the CNN model + weights to the model_save_path and we can load it in our server or webpage to make predictions.

JavaScript handlers for Mouse and Touch

Now let’s get into the front-end code for this tutorial. Before we start anything with JavaScript, we first need to load the following JS libraries for everything in this tutorial to work. We need to append these lines in the head tag of our HTML.

index.html
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@latest"></script>
<script type="text/javascript" src="https://code.jquery.com/jquery-2.1.1.min.js"></script>  
<script src="https://cdnjs.cloudflare.com/ajax/libs/Chart.js/2.4.0/Chart.min.js"></script>
<script type="text/javascript" src="/js/app.js"></script>

For the user to draw a digit on a mobile, desktop, tablet or laptop, we need to create an HTML5 element called canvas. Inside the canvas, the user will draw the digit. We will feed the digit drawn by the user into the deep neural network that we have created to make predictions.

Below is the HTML5 code to create the UI of our simple web app. We have two buttons namely Clear to clear the canvas and Predict to make predictions using our deep neural network model. We also have a select option to select any one of the two Deep Neural Nets that we have created (MLP or CNN).

index.html
<div class="digit-demo-container">
  <h3 id="digit-recognizer-live-demo">Digit Recognizer using TensorFlow.js Demo</h3>
  <div class="flex-two" style="margin-top: 20px;">
    <button id="clear_canvas" class="material-button-pink" onclick="clearCanvas(this.id)">Clear</button>
    <select id="select_model">
      <option>MLP</option>
      <option>CNN</option>
    </select>
    <button id="predict_canvas" class="material-button-pink" onclick="predict(this.id)">Predict</button>
  </div>
  <div class="flex-two">
    <div id="canvas_box_wrapper" class="canvas-box-wrapper">
      <div id="canvas_box" class="canvas-box"></div>
    </div>
    <div id="result_box">
      <canvas id="chart_box" width="100" height="100"></canvas>
    </div>
  </div>
</div>

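For a smooth user experience, we will add some style to the HTML5 code created above. Shown below is the app.scss code, which we will convert to app.css using the sass app.scss app.css command. The snippet uses a mobile mixin and a $font_body variable; these are not defined in the snippet and are assumed to come from the rest of the site’s stylesheet, so define your own equivalents before compiling. If you are not familiar with Sass, please learn about Sass here.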

app.scss
.digit-demo-container {
  background-color: #4CAF50;
  border-radius: 5px;
  margin: 20px auto;

  h3 {
    width: 100%;
    font-size: 15px;
    background-color: #356937;
    padding: 10px;
    color: white;
    margin: 0px;
    border: 1px solid #295b2b;
    border-bottom: 0px;
    border-top-left-radius: 5px;
    border-top-right-radius: 5px;
  }

  select {
    margin-top: 0px;
    height: 30px;
    font-size: 12px;
  }
}

.flex-two {
  display: flex;
  flex-wrap: wrap;
  div {
    flex: 1;
    padding: 20px;
    text-align: center;
    @include mobile {
      padding: 5px;
    }
  }
}

.canvas-box-wrapper {
  display: block !important;

}

.material-button-pink {
  background-color: #FFD740;
  height: 30px;
  color: black;
  display: block;
  font-family: $font_body;
  font-weight: bold;
  padding: 5px 10px;
  cursor: pointer;
  border: 1px solid #a58e3a;
  font-size: 12px;
  transition: all 0.3s cubic-bezier(0.25, 0.8, 0.25, 1); 
  margin: 0px auto;
  border-radius: 5px;
}

.material-button-pink:focus {
  outline: none;
}

.material-button-pink:hover {
  box-shadow: 0 1px 8px rgba(0,0,0,0.3), 0 5px 10px rgba(0,0,0,0.22);
}

#chart_box {
  background-color: white;
  padding: 20px;
  border-radius: 5px;
  @include mobile {
    padding: 5px;
  }
}

Drawing inside a canvas is a little tricky because mouse and touch input arrive through different events, and we need to handle both. We bind the mouse events with jQuery and the touch events with addEventListener(). Below are the events that we will be handling.

For Desktop & Laptop
  • mousedown
  • mousemove
  • mouseup
  • mouseleave
For Tablet & Mobile
  • touchstart
  • touchmove
  • touchend
  • touchleave (from an early Touch Events draft; current browsers fire touchcancel instead)

Additionally, we will add two more user-defined functions namely addUserGesture and drawOnCanvas which I will be explaining shortly.

First, we will have a div with id canvas_box inside which we will dynamically create a canvas. Below are the HTML code and the JavaScript code to create the canvas on which the user will draw.

index.html
<div id="canvas_box" class="canvas-box">
  <button id="clear_canvas" class="material-button-pink" onclick="clearCanvas(this.id)">Clear</button>
</div>

app.js
// GLOBAL variables
var modelName = "digitrecognizermlp";
let model;

// canvas related variables
// you can change these variables
var canvasWidth             = 150;
var canvasHeight            = 150;
var canvasStrokeStyle       = "white";
var canvasLineJoin          = "round";
var canvasLineWidth         = 12;
var canvasBackgroundColor   = "black";
var canvasId                = "canvas";

// variables to hold coordinates and dragging boolean
var clickX = new Array();
var clickY = new Array();
var clickD = new Array();
var drawing;

document.getElementById('chart_box').innerHTML = "";
document.getElementById('chart_box').style.display = "none";

//---------------------
// Create canvas
//---------------------
var canvasBox = document.getElementById('canvas_box');
var canvas    = document.createElement("canvas");

canvas.setAttribute("width", canvasWidth);
canvas.setAttribute("height", canvasHeight);
canvas.setAttribute("id", canvasId);
canvas.style.backgroundColor = canvasBackgroundColor;
canvasBox.appendChild(canvas);
if(typeof G_vmlCanvasManager != 'undefined') {
  canvas = G_vmlCanvasManager.initElement(canvas);
}

var ctx = canvas.getContext("2d");

Notice that we get the context ctx of the dynamically created canvas using canvas.getContext("2d").

When the user draws on the canvas, we need to register the X and Y position of the pointer relative to the canvas. To do that, we handle the mousedown and touchstart events. For mobile and tablet devices, we also need to prevent the default scroll behavior when the canvas is touched, using e.preventDefault().

When the user starts drawing, we pass the X and Y values to addUserGesture() function and set the drawing flag true.

The two code snippets below do this for both mobile and desktop devices.

app.js
//---------------------
// MOUSE DOWN function
//---------------------
$("#canvas").mousedown(function(e) {
  var mouseX = e.pageX - this.offsetLeft;
  var mouseY = e.pageY - this.offsetTop;

  drawing = true;
  addUserGesture(mouseX, mouseY);
  drawOnCanvas();
});

//---------------------
// TOUCH START function
//---------------------
canvas.addEventListener("touchstart", function (e) {
  if (e.target == canvas) {
      e.preventDefault();
    }

  var rect  = canvas.getBoundingClientRect();
  var touch = e.touches[0];

  var mouseX = touch.clientX - rect.left;
  var mouseY = touch.clientY - rect.top;

  drawing = true;
  addUserGesture(mouseX, mouseY);
  drawOnCanvas();

}, false);
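
The coordinate arithmetic in the touch handler can be isolated into a small pure function, which makes it easy to check without a browser. This is only a sketch: toCanvasCoords is a hypothetical helper, and rect stands in for the object returned by canvas.getBoundingClientRect().

```javascript
// Convert viewport coordinates (clientX / clientY) into canvas-local coordinates.
// `rect` has the shape returned by canvas.getBoundingClientRect().
function toCanvasCoords(clientX, clientY, rect) {
  return {
    x: clientX - rect.left,
    y: clientY - rect.top
  };
}

// A canvas whose top-left corner sits at (40, 100) in the viewport.
const rect = { left: 40, top: 100, width: 150, height: 150 };
const point = toCanvasCoords(65, 130, rect);
console.log(point); // { x: 25, y: 30 }
```

Mouse events also expose clientX and clientY, so the same helper could serve both input types if you prefer a single code path.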

So far we have asked JavaScript only to start recording positions. But the user normally moves a finger or the cursor to draw something on the canvas. To record that movement, we handle the mousemove and touchmove events.

Only if the drawing flag is set (i.e. the user has started to draw) do we record the position X, Y and pass it to addUserGesture() with the dragging argument set to true. Then we call drawOnCanvas() to redraw the user’s drawing, which I will explain in a while.

app.js
//---------------------
// MOUSE MOVE function
//---------------------
$("#canvas").mousemove(function(e) {
  if(drawing) {
    var mouseX = e.pageX - this.offsetLeft;
    var mouseY = e.pageY - this.offsetTop;
    addUserGesture(mouseX, mouseY, true);
    drawOnCanvas();
  }
});

//---------------------
// TOUCH MOVE function
//---------------------
canvas.addEventListener("touchmove", function (e) {
  if (e.target == canvas) {
      e.preventDefault();
    }
  if(drawing) {
    var rect = canvas.getBoundingClientRect();
    var touch = e.touches[0];

    var mouseX = touch.clientX - rect.left;
    var mouseY = touch.clientY - rect.top;

    addUserGesture(mouseX, mouseY, true);
    drawOnCanvas();
  }
}, false);

In all other cases, we simply set the drawing variable to false. Below is the code snippet to do that.

app.js
//---------------------
// MOUSE UP function
//---------------------
$("#canvas").mouseup(function(e) {
  drawing = false;
});

//---------------------
// TOUCH END function
//---------------------
canvas.addEventListener("touchend", function (e) {
  if (e.target == canvas) {
      e.preventDefault();
    }
  drawing = false;
}, false);

//----------------------
// MOUSE LEAVE function
//----------------------
$("#canvas").mouseleave(function(e) {
  drawing = false;
});

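//-----------------------
// TOUCH CANCEL function
//-----------------------
// "touchleave" comes from an early Touch Events draft and is not fired by
// current browsers; "touchcancel" is the standard event for an interrupted touch.
canvas.addEventListener("touchcancel", function (e) {
  if (e.target == canvas) {
      e.preventDefault();
    }
  drawing = false;
}, false);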
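
Taken together, the handlers behave like a tiny state machine around the drawing flag. The sketch below models that logic without any DOM, so it can be run in Node. createStrokeRecorder is a hypothetical name, and unlike the handlers above it stores false (rather than undefined) as the dragging value of the first point of a stroke.

```javascript
// Hypothetical recorder that mirrors the `drawing` flag and the
// clickX / clickY / clickD arrays used by the handlers.
function createStrokeRecorder() {
  const state = { drawing: false, clickX: [], clickY: [], clickD: [] };

  function add(x, y, dragging) {
    state.clickX.push(x);
    state.clickY.push(y);
    state.clickD.push(dragging);
  }

  return {
    state,
    // mousedown / touchstart: begin a stroke
    start(x, y) { state.drawing = true; add(x, y, false); },
    // mousemove / touchmove: record only while a stroke is active
    move(x, y) { if (state.drawing) add(x, y, true); },
    // mouseup / mouseleave / touchend: stop recording
    end() { state.drawing = false; }
  };
}

const recorder = createStrokeRecorder();
recorder.move(5, 5);      // ignored: no stroke has started
recorder.start(10, 10);
recorder.move(12, 14);
recorder.end();
recorder.move(50, 50);    // ignored: the stroke has ended
console.log(recorder.state.clickX); // [ 10, 12 ]
console.log(recorder.state.clickD); // [ false, true ]
```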

Finally, let’s understand what the drawOnCanvas() function does. First, we clear the canvas on each move or touch and then redraw every recorded X, Y position using the ctx we obtained earlier for our canvas. We make use of canvas attributes such as strokeStyle, lineJoin and lineWidth, and canvas functions such as beginPath(), moveTo(), lineTo(), closePath() and stroke() to visualize what the user has drawn.

To clear the canvas, we simply use clearRect function and pass in the width and height of the canvas, and we reinitialize the position arrays.

app.js
//----------------------
// ADD CLICK function
//----------------------
function addUserGesture(x, y, dragging) {
  clickX.push(x);
  clickY.push(y);
  clickD.push(dragging);
}

//----------------------
// RE DRAW function
//----------------------
function drawOnCanvas() {
  ctx.clearRect(0, 0, ctx.canvas.width, ctx.canvas.height);

  ctx.strokeStyle = canvasStrokeStyle;
  ctx.lineJoin    = canvasLineJoin;
  ctx.lineWidth   = canvasLineWidth;

  for (var i = 0; i < clickX.length; i++) {
    ctx.beginPath();
    if(clickD[i] && i) {
      ctx.moveTo(clickX[i-1], clickY[i-1]);
    } else {
      ctx.moveTo(clickX[i]-1, clickY[i]);
    }
    ctx.lineTo(clickX[i], clickY[i]);
    ctx.closePath();
    ctx.stroke();
  }
}

//----------------------
// CLEAR CANVAS function
//----------------------
function clearCanvas(id) {
  ctx.clearRect(0, 0, canvasWidth, canvasHeight);
  clickX = new Array();
  clickY = new Array();
  clickD = new Array();
}
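
The rule that drawOnCanvas() applies to each recorded point can be sketched as a pure function that returns the line segments instead of drawing them. toSegments is a hypothetical helper written only to illustrate the logic.

```javascript
// Turn the recorded points into line segments, following the same rule as
// drawOnCanvas(): a dragged point connects to the previous point, any other
// point starts a new stroke as a 1-pixel-long dot.
function toSegments(clickX, clickY, clickD) {
  const segments = [];
  for (let i = 0; i < clickX.length; i++) {
    const from = (clickD[i] && i > 0)
      ? { x: clickX[i - 1], y: clickY[i - 1] }
      : { x: clickX[i] - 1, y: clickY[i] };
    segments.push({ from, to: { x: clickX[i], y: clickY[i] } });
  }
  return segments;
}

const segments = toSegments([10, 12, 40], [10, 14, 40], [false, true, false]);
console.log(segments.length);  // 3
console.log(segments[1].from); // { x: 10, y: 10 }
console.log(segments[2].from); // { x: 39, y: 40 }
```

The third point is not a drag, so it does not connect back to (12, 14); this is what lets the user lift the pen and start a separate stroke.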

That’s it! We now have all the JavaScript functions needed to let the user draw on a canvas on any device: desktop, laptop, tablet or mobile. Now, let’s move on to the deep neural network models that we dumped earlier.

You can upload the entire model folder dumped earlier to your web app or server. We write three important functions related to our model in JavaScript.

  1. loadModel() with select element handler
  2. preprocessCanvas()
  3. predict()

Load Model with select element handler

First, we load the model trained in the Keras Python environment using a simple HTTPS request and the tf.loadModel() function (renamed tf.loadLayersModel() in TensorFlow.js 1.0 and later). We confirm that the model has loaded by logging a message with console.log(). Below is the code snippet to load the trained Keras model using TensorFlow.js.

app.js
//-----------------------
// select model handler
//-----------------------
$("#select_model").change(function() {
    var select_model  = document.getElementById("select_model");
    var select_option = select_model.options[select_model.selectedIndex].value;

    if (select_option == "MLP") {
      modelName = "digitrecognizermlp";

    } else if (select_option == "CNN") {
      modelName = "digitrecognizercnn";

    } else {
      modelName = "digitrecognizermlp";
    }

    loadModel(modelName);
});

//-----------------------------------------
// loader for the currently selected model
//-----------------------------------------
async function loadModel() {
  console.log("model loading..");

  // clear the model variable
  model = undefined;
  
  // load the model using an HTTPS request (where you have stored your model files)
  // note: tf.loadModel() was renamed to tf.loadLayersModel() in TensorFlow.js 1.0
  model = await tf.loadLayersModel("https://gogul09.github.io/models/" + modelName + "/model.json");
  
  console.log("model loaded..");
}

loadModel();
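
The mapping from the select option to a model URL can also be written as two small pure functions, which keeps the fallback to the MLP in one place. This is a sketch only: modelNameFor and modelUrlFor are hypothetical helpers, and the base URL is the one used in the loader above, which you would replace with your own host.

```javascript
// Base URL taken from the loader above; replace it with your own host.
const MODEL_BASE_URL = "https://gogul09.github.io/models/";

// Map the text of the <select> option to a model folder name,
// falling back to the MLP for anything unrecognized.
function modelNameFor(option) {
  return option === "CNN" ? "digitrecognizercnn" : "digitrecognizermlp";
}

function modelUrlFor(option) {
  return MODEL_BASE_URL + modelNameFor(option) + "/model.json";
}

console.log(modelUrlFor("CNN"));
// https://gogul09.github.io/models/digitrecognizercnn/model.json
console.log(modelNameFor("anything else")); // digitrecognizermlp
```
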
Preprocess Canvas

After loading the model, we need to preprocess the canvas drawn by the user to feed it to the DNN (MLP or CNN) that we have trained using Keras.

Warning: Preprocessing the HTML5 canvas element is the crucial step in this application.

Preprocessing for MLP
  1. We use tf.fromPixels() (tf.browser.fromPixels() in TensorFlow.js 1.0 and later) and pass in the HTML5 canvas element directly without any transformations.
  2. We resize the canvas to our MLP input image size of [28, 28] using the tf.resizeNearestNeighbor() function.
  3. We convert the canvas image to grayscale by averaging the color channels with tf.mean(2), which leaves a two-dimensional [28, 28] tensor.
  4. We convert all the values in the canvas to float using the tf.toFloat() function and reshape the two-dimensional matrix into a row vector of shape [1, 784] using tf.reshape() to feed it to our MLP model.
  5. Finally, we return the tensor after dividing each value in it by 255.0 using tf.div(), as we did earlier during MLP model training.

Preprocessing for CNN
  1. We use tf.fromPixels() (tf.browser.fromPixels() in TensorFlow.js 1.0 and later) and pass in the HTML5 canvas element directly without any transformations.
  2. We resize the canvas to our CNN input image size of [28, 28] using the tf.resizeNearestNeighbor() function.
  3. We convert the canvas image to grayscale by averaging the color channels with tf.mean(2), which leaves a two-dimensional [28, 28] tensor.
  4. We then expand the grayscale image to 4 dimensions, as the CNN expects a 4D input. To do this, we use tf.expandDims(2) to get a 3D tensor of shape [28, 28, 1] and then tf.expandDims() to get a 4D tensor of shape [1, 28, 28, 1].
  5. We convert all the values in the canvas to float using the tf.toFloat() function.
  6. Finally, we return the tensor after dividing each value in it by 255.0 using tf.div(), as we did earlier during CNN model training.

app.js
//-----------------------------------------------
// preprocess the canvas to be DNN friendly
//-----------------------------------------------
function preprocessCanvas(image, modelName) {

  // if no model name is given, warn the user and return nothing
  if (modelName === undefined) {
    alert("No model defined..")
  } 

  // if model is digitrecognizermlp, perform all the preprocessing
  else if (modelName === "digitrecognizermlp") {
    
    // convert the input image to digitrecognizermlp's input shape of (1, 784)
    // note: tf.fromPixels() moved to tf.browser.fromPixels() in TensorFlow.js 1.0
    let tensor = tf.browser.fromPixels(image)
        .resizeNearestNeighbor([28, 28])
        .mean(2)
        .toFloat()
        .reshape([1 , 784]);
    return tensor.div(255.0);
  }

  // if model is digitrecognizercnn, perform all the preprocessing
  else if (modelName === "digitrecognizercnn") {
    // convert the input image to digitrecognizercnn's input shape of (1, 28, 28, 1)
    let tensor = tf.browser.fromPixels(image)
        .resizeNearestNeighbor([28, 28])
        .mean(2)
        .expandDims(2)
        .expandDims()
        .toFloat();
    console.log(tensor.shape);
    return tensor.div(255.0);
  }

  // else warn the user about the unknown model name
  else {
    alert("Unknown model name..")
  }
}
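
To see what the MLP branch does to the pixel data, here is a rough equivalent on plain arrays that runs without TensorFlow.js. It is a sketch under stated assumptions: the helper names are hypothetical, the resize picks the source pixel with a simple floor rule that may not match the library pixel for pixel, and grayscale conversion is done before the resize (with nearest-neighbor sampling the order does not change the result).

```javascript
// Average the color channels of each pixel: [h][w][3] -> [h][w].
function toGrayscale(rgbImage) {
  return rgbImage.map(row =>
    row.map(pixel => pixel.reduce((sum, c) => sum + c, 0) / pixel.length));
}

// Nearest-neighbor resize of a 2D array to outH x outW.
function resizeNearest(image, outH, outW) {
  const inH = image.length;
  const inW = image[0].length;
  const out = [];
  for (let r = 0; r < outH; r++) {
    const srcR = Math.min(inH - 1, Math.floor(r * inH / outH));
    const row = [];
    for (let c = 0; c < outW; c++) {
      const srcC = Math.min(inW - 1, Math.floor(c * inW / outW));
      row.push(image[srcR][srcC]);
    }
    out.push(row);
  }
  return out;
}

// MLP input: one row of 784 values in [0, 1].
function preprocessForMlp(rgbImage) {
  const small = resizeNearest(toGrayscale(rgbImage), 28, 28);
  return [small.flat().map(v => v / 255)];
}

// A 56x56 all-white RGB image standing in for the canvas pixels.
const white = Array.from({ length: 56 }, () =>
  Array.from({ length: 56 }, () => [255, 255, 255]));

const mlpInput = preprocessForMlp(white);
console.log(mlpInput.length);    // 1
console.log(mlpInput[0].length); // 784
console.log(mlpInput[0][0]);     // 1
```

For the CNN, the same 28 x 28 grayscale values would be kept as a grid and wrapped in two extra dimensions to reach the shape [1, 28, 28, 1] instead of being flattened.
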
Predict

Finally, we are ready to predict what the user has drawn using our loaded DNN (MLP or CNN) model with preprocessed canvas tensor available.

We use the method model.predict() and pass in our canvas tensor as the argument and get the predictions using data(). We convert the predictions into a JavaScript array and use displayChart() to display the predictions in a visually pleasing format.

Displaying the model predictions in the form of a chart is optional. By the way, the model predictions are now available in the variable results.

app.js
//----------------------------
// Bounding box for centering
//----------------------------
function boundingBox() {
  var minX = Math.min.apply(Math, clickX) - 20;
  var maxX = Math.max.apply(Math, clickX) + 20;
  
  var minY = Math.min.apply(Math, clickY) - 20;
  var maxY = Math.max.apply(Math, clickY) + 20;

  var tempCanvas = document.createElement("canvas"),
  tCtx = tempCanvas.getContext("2d");

  tempCanvas.width  = maxX - minX;
  tempCanvas.height = maxY - minY;

  tCtx.drawImage(canvas, minX, minY, maxX - minX, maxY - minY, 0, 0, maxX - minX, maxY - minY);

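  // show the cropped image if the page has an element with id "canvas_image"
  // (it is not part of the markup shown earlier)
  var imgBox = document.getElementById("canvas_image");
  if (imgBox) {
    imgBox.src = tempCanvas.toDataURL();
  }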

  return tempCanvas;
}

//--------------------------------------------
// predict function for digit recognizer mlp
//--------------------------------------------
async function predict() {

  // get the user drawn region alone cropped
  var croppedCanvas = boundingBox();

  // show the cropped image container, if the page has one
  var output = document.getElementById("canvas_output");
  if (output) {
    output.style.display = "block";
  }

  // preprocess canvas
  let tensor = preprocessCanvas(croppedCanvas, modelName);

  // make predictions on the preprocessed image tensor
  let predictions = await model.predict(tensor).data();

  // get the model's prediction results
  let results = Array.from(predictions)

  // display the predictions in chart
  displayChart(results)

  console.log(results);
}
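
The cropping arithmetic in boundingBox() can be checked in isolation with a pure function. The sketch below is a variation rather than a copy of the code above: paddedBoundingBox is a hypothetical helper, and it additionally clamps the box to the canvas edges, which the original does not do. It assumes at least one point has been recorded.

```javascript
// Compute a padded bounding box around the recorded points, clamped so it
// never extends past the canvas edges. Assumes xs and ys are non-empty.
function paddedBoundingBox(xs, ys, pad, canvasWidth, canvasHeight) {
  const minX = Math.max(0, Math.min(...xs) - pad);
  const minY = Math.max(0, Math.min(...ys) - pad);
  const maxX = Math.min(canvasWidth, Math.max(...xs) + pad);
  const maxY = Math.min(canvasHeight, Math.max(...ys) + pad);
  return { x: minX, y: minY, width: maxX - minX, height: maxY - minY };
}

const box = paddedBoundingBox([10, 60, 140], [30, 90, 70], 20, 150, 150);
console.log(box); // { x: 0, y: 10, width: 150, height: 100 }
```

The returned x, y, width and height are the values you would pass as the source rectangle of drawImage().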

Simple Bar Chart to display the predictions

This section of this tutorial is optional. It is for people like me who are obsessed with data visualization.

You can use the lines of code below to display the predictions of our DNN (MLP or CNN) model in the form of a bar chart. I have used Chart.js, which is an open-source JavaScript charting library.

app.js
//------------------------------
// Chart to display predictions
//------------------------------
var chart = "";
var firstTime = 0;
function loadChart(label, data, modelSelected) {
  var context = document.getElementById('chart_box').getContext('2d');
  chart = new Chart(context, {
      // we are in need of a bar chart
      type: 'bar',

      // we feed in data dynamically using data variable
      // that is passed as an argument to this function
      data: {
          labels: label,
          datasets: [{
              label: modelSelected + " prediction",
              backgroundColor: '#f50057',
              borderColor: 'rgb(255, 99, 132)',
              data: data,
          }]
      },

      // you can also play around with options for the 
      // chart if you find time!
      options: {}
  });
}

//----------------------------
// display chart with updated
// drawing from canvas
//----------------------------
function displayChart(data) {
  var select_model  = document.getElementById("select_model");
  var select_option = select_model.options[select_model.selectedIndex].value;
  
  var label = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9"];
  if (firstTime == 0) {
    loadChart(label, data, select_option);
    firstTime = 1;
  } else {
    chart.destroy();
    loadChart(label, data, select_option);
  }
  document.getElementById('chart_box').style.display = "block";
}
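
If you only need the single most likely digit rather than the full chart, you can take the index of the largest probability in the results array. A minimal sketch follows; argMax is a hypothetical helper, and the probabilities are made up for illustration.

```javascript
// Return the index of the largest value, i.e. the most probable digit.
function argMax(values) {
  let best = 0;
  for (let i = 1; i < values.length; i++) {
    if (values[i] > values[best]) best = i;
  }
  return best;
}

// Made-up softmax output for illustration; real values come from model.predict().
const results = [0.01, 0.02, 0.05, 0.80, 0.01, 0.03, 0.02, 0.03, 0.02, 0.01];
const digit = argMax(results);
console.log(digit);          // 3
console.log(results[digit]); // 0.8
```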

That’s it! Finally, we have built something awesome using multiple languages: JavaScript, Python, HTML5 and CSS3. It’s all possible because of two amazing deep learning libraries, Keras and TensorFlow.js.

If you found something useful to add to this article, found a bug in the code, or would like to improve some of the points mentioned, feel free to write it down in the comments. Hope you found something useful here.

Happy learning!