Handwritten Digit Recognition with scikit-learn


French version

About this tutorial

This tutorial is a hands-on introduction to machine learning for beginners.

Getting started with machine learning can be quite difficult when you're searching the web for information at random.

Here, my goal is to help you get started with a concrete example of image recognition, with just a little bit of code, and no maths.

After a short introduction to machine learning, you will learn:

  • the principles of supervised machine learning for classification,
  • how to install the whole scientific python suite,
  • how to access and validate the training data for your network,
  • how to create and train your network,
  • how to use the trained network and test its performance.

Prerequisites

We will work in python, which is a wonderful choice for data science. If you're not a python developer but know a bit of C, C++ or Java for example, you'll be fine. This will be an excellent opportunity to discover python. And, who knows, you could even fall in love with this language too!

Why machine learning?

Machine learning is a field of artificial intelligence in which a system is designed to learn automatically from a set of input data. After the system has learnt (we say that the system has been trained), we can use it to make predictions for new data, unseen before.

This approach makes it possible to solve complex problems that are difficult or impossible to solve with traditional programming, where every rule has to be written explicitly by hand.

Examples of machine learning applications include:

  • autonomous cars: given the data from sensors like cameras and radars, the car is trained to drive on its own. Google's car still needs to learn about the right lane ;-)
  • drones: the drone pilot only needs to give simple instructions (up, down, left, right, or just 3D coordinates), and the drone automatically performs complex adjustments to keep stability, or to fly in formation
  • robots
  • predicting real estate prices from a set of variables like location, number of rooms, and even the text of the real estate ad. I'll certainly do a tutorial about that in the near future.
  • google ads, which predicts the probability that you will be interested in a given ad, in order to show you the most promising ones
  • collaborative recommendation systems that give you the youtube videos or amazon products you will like
  • defect identification on production chains
  • identification of clusters of like-minded people on the social networks, and of the most important influencers within these groups
  • tagging photos (just type cat or food in the search box of your google photos library if you have one)
  • translation systems like google translate
  • spam email filtering
  • automatic generation of paintings: deep dream, neural doodle
  • ...

Machine learning for classification

Let's get a feel for how a neural network can be trained for classification.

In this post, our goal is to get started with hands-on machine learning quickly and easily, so I'm only going to give you a simplified explanation for now. There will be a more detailed post about the training principles later on, so stay tuned if you're interested.

Supervised learning

The network is presented with a succession of training examples. Each training example consists of:

  • the image of a digit
  • a label, which tells us which digit the image truly represents. For a given image, the label could be told to us by the person who wrote the digit in the first place.

In the drawing above, the first image is processed by the neural network, which produces an answer: this is a 9.

At first, the connections between the neurons in the network are random, and the network is not able to do anything useful. It just provides a random answer.

The answer is compared to the label. In this case, the answer (9) is different from the label (the digit is actually a 3), and some feedback is given to the neural network so that it can improve. The connections between the neurons are modified, favoring the ones that tend to give a correct answer.

After the modification, the next examples are considered, and the neural network learns in an iterative process.
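The feedback loop described above can be sketched in a few lines of numpy. This is a minimal sketch and not the network used later in this tutorial. It swaps in the simplest possible model, a single perceptron (one neuron with two input connections), and a hypothetical toy task: learning the logical AND of two inputs instead of recognizing digit images. The names `x_toy`, `labels` and `answer` are illustrative only. The connections start random. Each answer is compared with its label, and the connections are nudged whenever the answer is wrong.

```python
import numpy as np

# Hypothetical toy training set: the logical AND of two inputs.
# Each row is one training example, each entry of labels is its true answer.
x_toy = np.array([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=float)
labels = np.array([0, 0, 0, 1])

rng = np.random.default_rng(0)
weights = rng.normal(size=2)  # the connections are random at first
bias = 0.0
learning_rate = 0.1


def answer(inputs):
    """The neuron answers 1 if the weighted sum of its inputs is positive."""
    return int(np.dot(weights, inputs) + bias > 0)


for epoch in range(100):
    n_errors = 0
    for inputs, label in zip(x_toy, labels):
        # compare the answer with the label: the error is -1, 0 or +1
        error = label - answer(inputs)
        if error != 0:
            # feedback: adjust the connections that led to the wrong answer
            n_errors += 1
            weights += learning_rate * error * inputs
            bias += learning_rate * error
    if n_errors == 0:
        break  # every training example is now answered correctly

predictions = [answer(inputs) for inputs in x_toy]
print(predictions)  # [0, 0, 0, 1] once training has converged
```

AND is linearly separable, so the perceptron rule is guaranteed to stop making mistakes after a finite number of updates. The multi-layer network trained later in this tutorial relies on a more elaborate version of the same idea: the error is propagated back through all layers (backpropagation).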

The number of training examples needed to train the network properly could be of the order of a few hundred for networks with a simple architecture, and millions for complex networks.

Installing python and its scientific library

If you're already running this tutorial in your jupyter notebook, please skip this section.

We will use a variety of tools from scipy, the scientific python library:

  • scikit-learn: one of the leading machine-learning toolkits for python. It will provide easy access to the handwritten digits dataset, and allow us to define and train our neural network in a few lines of code
  • numpy: core package providing powerful tools to manipulate data arrays, such as our digit images
  • matplotlib: visualization tools, essential to check what we are doing
  • jupyter: the web server that will allow you to follow this tutorial and run the code directly in your web browser.

Scipy is actually not a single library, but an "ecosystem" of interdependent python packages.

This ecosystem is full of snakes and beasts fighting for survival, and you do not want to venture in there alone.

And indeed, six years ago, when I first got started with scipy, I tried to manually install all the packages I needed on top of the version of python already installed on my system.

I spent almost a day fighting against conflicting dependencies for these packages. For example, scikit-learn might need numpy version A, but pandas needs numpy version B. Or, one of these packages requires a version of python more recent than the one you have, meaning that you need to install an additional version of python and deal with your two versions later on.

And then, I discovered Anaconda.

As stated on Anaconda's website:

With over 6 million users, the open source Anaconda Distribution is the fastest and easiest way to do Python and R data science and machine learning on Linux, Windows, and Mac OS X. It's the industry standard for developing, testing, and training on a single machine.

In a nutshell, the anaconda team maintains a repository of more than 1400 data science packages, all compatible with each other, and provides tools to install a version of python and these packages at the push of a button, in under five minutes.

Let's do it now!

First, download anaconda for your system:

  • Choose the python 3.X version: python 2 has not been maintained since January 2020.
  • If you're using Windows or Linux, make sure to pick the 64 bit installer if you have a 64 bit system.

Run the installer, and finally start the Anaconda Navigator. On Windows, you can find it by clicking the Windows start button and typing anaconda.

In the Anaconda Navigator window, click on the Home tab, and launch the jupyter notebook.

Create a new notebook. In your notebook, you should see an empty cell, where you can write python code. Copy-paste the following lines, and execute the cell by pressing shift + enter.

print('hello world!')
for i in range(10):
    print(i)

A new cell appears. Import numpy and matplotlib (remember that you need to execute the cell):

import matplotlib.pyplot as plt 
import numpy as np

This is a standard way to import these modules:

  • the pyplot module of matplotlib is called plt in this context
  • the numpy module is called np

You could very well choose other names, but these are used by almost everybody, so it's easier to use them as well.

Now let's try and do our first plot, just to make sure that numpy and matplotlib are working:

# create a numpy 1-D array with 16 evenly spaced values, from 0 to 3 (both included).
x = np.linspace(0, 3, 16)
print(x)
# create a new numpy array.
# x**2 means that each element of x is squared.
y = x**2
print(y)
# plot y versus x, you should get a parabola.
# check that for x = 1 we have y = 1, and that for x = 2, y = 4.
plt.plot(x, y)
plt.show()
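You may also want to check which versions of the main packages your installation provides. Your outputs can differ slightly from the ones shown in this tutorial depending on these versions. A minimal check, relying on the standard `__version__` string that numpy, matplotlib and scikit-learn each expose:

```python
import sys

import matplotlib
import numpy as np
import sklearn

# Print the version of python and of each package used in this tutorial.
print("python      ", sys.version.split()[0])
print("numpy       ", np.__version__)
print("matplotlib  ", matplotlib.__version__)
print("scikit-learn", sklearn.__version__)
```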

💡 A word of caution:

It is very easy to get lost in the documentation of all these tools, and to waste a lot of time.

For example, if you check the documentation of the plt.plot method (I won't give you the link ;-) but you could google it), you will see that there are lots of ways to call it, with many optional parameters. But after all, do we need to know more than this: plt.plot(x, y) plots y vs x?

If you want to have fun, I suggest following this tutorial to the end without digging deeper.

You'll train your first neural net easily and in the process, you'll get an understanding of the most important scikit-learn, numpy, and matplotlib tools. That's more than enough for a variety of machine learning tasks, and you can always learn more about specific features of these tools when you need them later on (you'll know!)

To go further, here's an excellent scipy lecture.


Now that you have access to the jupyter notebook, I have good news. You won't need to keep copy-pasting code from this page to your notebook.

Instead, just do the following:

  • download the repository containing this notebook
  • unzip it, say to Downloads/maldives-master
  • launch a jupyter notebook from the anaconda navigator
  • in the notebook, navigate to Downloads/maldives-master/handwritten_digits_sklearn
  • open handwritten_digits_sklean.ipynb

You should see this page appear in the notebook. From now on, follow the tutorial in the notebook. You should execute the cells as they come, or execute them all in one go. You can even add cells or modify existing cells to experiment a bit.

The digits dataset

scikit-learn comes with several test datasets. Let's load the handwritten digits dataset:

In [1]:
from sklearn import datasets
digits = datasets.load_digits()

In python, the dir function returns the names of the attributes of an object, in other words which information is stored in the object in the form of other objects. Let's use this function to check what can be found in the digits object:

In [2]:
dir(digits)
Out[2]:
['DESCR', 'data', 'images', 'target', 'target_names']

Let's have a look at some of these attributes in more detail. We are going to start by checking their type:

In [3]:
print(type(digits.images))
print(type(digits.target))
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>

images and target are ndarrays (N-dimensional arrays) from the numpy package. The shape attribute of an ndarray gives the number of dimensions and the size along each dimension of the array. For example:

In [4]:
digits.images.shape
Out[4]:
(1797, 8, 8)

digits.images is an array with 3 dimensions. The first dimension indexes images, and we see that we have 1797 images in total. The next two dimensions correspond to the rows and columns of pixels in each image (the y and x coordinates). Each image has 8x8 = 64 pixels. In other words, this array could be represented in 3D as a pile of images with 8x8 pixels each.

Let's look at the data of the first 8x8 image. Each slot in the array corresponds to a pixel, and the value in the slot is the amount of black in the pixel:

In [5]:
print(digits.images[0])
[[ 0.  0.  5. 13.  9.  1.  0.  0.]
 [ 0.  0. 13. 15. 10. 15.  5.  0.]
 [ 0.  3. 15.  2.  0. 11.  8.  0.]
 [ 0.  4. 12.  0.  0.  8.  8.  0.]
 [ 0.  5.  8.  0.  0.  9.  8.  0.]
 [ 0.  4. 11.  0.  1. 12.  7.  0.]
 [ 0.  2. 14.  5. 10. 12.  0.  0.]
 [ 0.  0.  6. 13. 10.  0.  0.  0.]]
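The pixel values are not arbitrary. According to the dataset description stored in `digits.DESCR`, the original 32x32 bitmaps were divided into non-overlapping 4x4 blocks. Each value counts the "on" pixels in one block, so it is an integer between 0 and 16. A quick sketch to check this on the whole dataset:

```python
import numpy as np
from sklearn import datasets

digits = datasets.load_digits()

# Smallest and largest pixel values over the whole dataset.
print(digits.images.min(), digits.images.max())  # 0.0 16.0

# All values are whole numbers, even though they are stored as floats.
print(np.all(digits.images == np.round(digits.images)))  # True
```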

Now let's display this image: (sometimes, the plot does not appear, just rerun this cell if you don't see the image)

In [6]:
import matplotlib.pyplot as plt
plt.imshow(digits.images[0],cmap='binary')
plt.show()

The image is low resolution. The original digits were of much higher resolution, and the resolution has been decreased when creating the dataset for scikit-learn to make it easier and faster to train a machine learning algorithm to recognize these digits.

Now let's investigate the target attribute:

In [7]:
print(digits.target.shape)
print(digits.target)
(1797,)
[0 1 2 ... 8 9 8]

It is a 1-dimensional array with 1797 slots. Looking into the array, we see that it contains the true numbers corresponding to each image. For example, the first target is 0, and corresponds to the image drawn just above.
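Before training, it is worth checking that every digit appears about equally often. A strongly unbalanced dataset can bias a classifier towards the most frequent classes. A minimal sketch using `np.bincount`, which counts how many times each non-negative integer appears in an array:

```python
import numpy as np
from sklearn import datasets

digits = datasets.load_digits()

# counts[d] is the number of images labelled with digit d.
counts = np.bincount(digits.target)
for digit, count in enumerate(counts):
    print(digit, count)
print("total:", counts.sum())  # total: 1797
```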

Let's have a look at some more images using this function:

In [8]:
def plot_multi(i):
    '''Plots 16 digits, starting with digit i'''
    nplots = 16
    fig = plt.figure(figsize=(15,15))
    for j in range(nplots):
        plt.subplot(4,4,j+1)
        plt.imshow(digits.images[i+j], cmap='binary')
        plt.title(digits.target[i+j])
        plt.axis('off')
    plt.show()
In [9]:
plot_multi(0)

You can have a look at the next digits by calling plot_multi(16), plot_multi(32), etc. You will probably see that with such a low resolution, it's quite difficult to recognize some of the digits, even for a human. Under these conditions, our neural network will also be limited by the low quality of the input images. Can the neural network perform at least as well as a human? That would already be an achievement!

Building the network and preparing the input data

With scikit-learn, creating, training, and evaluating a neural network can be done with only a few lines of code.

We will make a very simple neural network, with three layers:

  • an input layer, with 64 nodes, one node per pixel in the input images. These nodes are not real neurons: they do nothing but take their input value and send it to the neurons of the next layer
  • a hidden layer with 15 neurons. We could choose a different number, and also add more hidden layers with different numbers of neurons
  • an output layer with 10 neurons corresponding to our 10 classes of digits, from 0 to 9.

This is a dense neural network, which means that each node in each layer is connected to all nodes in the previous and next layers.

Simple dense network
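The number of connections in such a dense network follows directly from the layer sizes. Two consecutive layers with a and b nodes contribute a * b connections. The sketch below also counts one bias per hidden and output neuron. This is an assumption not stated above, but it matches what scikit-learn's `MLPClassifier` stores in its `intercepts_` attribute. These are the numbers that training has to adjust:

```python
# Layer sizes of the network described above: 64 inputs, 15 hidden, 10 outputs.
layer_sizes = [64, 15, 10]

# One weight per connection between consecutive layers: 64*15 + 15*10.
n_weights = sum(a * b for a, b in zip(layer_sizes[:-1], layer_sizes[1:]))
# One bias per neuron in the hidden and output layers: 15 + 10.
n_biases = sum(layer_sizes[1:])

print(n_weights, n_biases, n_weights + n_biases)  # 1110 25 1135
```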

The input layer requires a 1-dimensional array as input, but our images are 2D. So we need to flatten all images:

In [10]:
y = digits.target
x = digits.images.reshape((len(digits.images), -1))
x.shape
Out[10]:
(1797, 64)

We now have 1797 flattened images. The two dimensions of our 8x8 images have been collapsed into a single dimension by writing the rows of 8 pixels as they come, one after the other. The first image that we looked at earlier is now represented by a 1-D array with 8x8 = 64 slots. Please check that the values below are the same as in the original 2-D image.

In [11]:
x[0]
Out[11]:
array([ 0.,  0.,  5., 13.,  9.,  1.,  0.,  0.,  0.,  0., 13., 15., 10.,
       15.,  5.,  0.,  0.,  3., 15.,  2.,  0., 11.,  8.,  0.,  0.,  4.,
       12.,  0.,  0.,  8.,  8.,  0.,  0.,  5.,  8.,  0.,  0.,  9.,  8.,
        0.,  0.,  4., 11.,  0.,  1., 12.,  7.,  0.,  0.,  2., 14.,  5.,
       10., 12.,  0.,  0.,  0.,  0.,  6., 13., 10.,  0.,  0.,  0.])
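The row-by-row flattening can also be checked in code: the second row of the 8x8 image should become slots 8 to 15 of the flattened array. As a side note, the `data` attribute listed by `dir(digits)` earlier already holds the images in this flattened form. This means `digits.data` could be used instead of the `reshape` call:

```python
import numpy as np
from sklearn import datasets

digits = datasets.load_digits()
x = digits.images.reshape((len(digits.images), -1))

# Row 1 of the first image lands in slots 8..15 of the flattened version.
print(np.array_equal(x[0][8:16], digits.images[0][1]))  # True

# digits.data contains exactly the same flattened images.
print(np.array_equal(x, digits.data))  # True
```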

Let's now split our data into a training sample and a testing sample:

In [12]:
x_train = x[:1000]
y_train = y[:1000]
x_test = x[1000:]
y_test = y[1000:]

The first 1000 images and labels are going to be used for training. The rest of the dataset will be used later to test the performance of our network.
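scikit-learn also provides `train_test_split`, in `sklearn.model_selection`, for this task. With `shuffle=False` and an integer `train_size`, it should reproduce the slicing above exactly. Leaving shuffling on (the default) with a fixed `random_state` would instead draw a random subset of images for training. That can be preferable when a dataset is stored in some particular order. A sketch of the unshuffled version:

```python
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split

digits = datasets.load_digits()
x = digits.images.reshape((len(digits.images), -1))
y = digits.target

# shuffle=False keeps the original order: the first 1000 samples for training.
x_train, x_test, y_train, y_test = train_test_split(
    x, y, train_size=1000, shuffle=False)

print(x_train.shape, x_test.shape)  # (1000, 64) (797, 64)
print(np.array_equal(x_train, x[:1000]), np.array_equal(x_test, x[1000:]))  # True True
```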

We can now create the neural network. We use one hidden layer with 15 neurons, and scikit-learn is smart enough to work out from the training data how many neurons to use in the input and output layers. Don't pay attention to the other parameters; we'll cover them in future posts.

In [18]:
from sklearn.neural_network import MLPClassifier

mlp = MLPClassifier(hidden_layer_sizes=(15,), activation='logistic', alpha=1e-4,
                    solver='sgd', tol=1e-4, random_state=1,
                    learning_rate_init=.1, verbose=True)

Finally, we can train the neural network:

In [19]:
mlp.fit(x_train,y_train)
Iteration 1, loss = 2.22958289
Iteration 2, loss = 1.91207743
Iteration 3, loss = 1.62507727
Iteration 4, loss = 1.32649842
Iteration 5, loss = 1.06100535
Iteration 6, loss = 0.83995513
Iteration 7, loss = 0.67806075
Iteration 8, loss = 0.55175832
Iteration 9, loss = 0.45840445
Iteration 10, loss = 0.39149735
Iteration 11, loss = 0.33676351
Iteration 12, loss = 0.29059880
Iteration 13, loss = 0.25437208
Iteration 14, loss = 0.22838372
Iteration 15, loss = 0.20200554
Iteration 16, loss = 0.18186565
Iteration 17, loss = 0.16461183
Iteration 18, loss = 0.14990228
Iteration 19, loss = 0.13892154
Iteration 20, loss = 0.12833784
Iteration 21, loss = 0.12138920
Iteration 22, loss = 0.11407971
Iteration 23, loss = 0.10677664
Iteration 24, loss = 0.10037149
Iteration 25, loss = 0.09593187
Iteration 26, loss = 0.09250135
Iteration 27, loss = 0.08676698
Iteration 28, loss = 0.08356043
Iteration 29, loss = 0.08209789
Iteration 30, loss = 0.07649168
Iteration 31, loss = 0.07410898
Iteration 32, loss = 0.07126869
Iteration 33, loss = 0.06926956
Iteration 34, loss = 0.06578496
Iteration 35, loss = 0.06374913
Iteration 36, loss = 0.06175492
Iteration 37, loss = 0.05975664
Iteration 38, loss = 0.05764485
Iteration 39, loss = 0.05623663
Iteration 40, loss = 0.05420966
Iteration 41, loss = 0.05413911
Iteration 42, loss = 0.05256140
Iteration 43, loss = 0.05020265
Iteration 44, loss = 0.04902779
Iteration 45, loss = 0.04788382
Iteration 46, loss = 0.04655532
Iteration 47, loss = 0.04586089
Iteration 48, loss = 0.04451758
Iteration 49, loss = 0.04341598
Iteration 50, loss = 0.04238096
Iteration 51, loss = 0.04162200
Iteration 52, loss = 0.04076839
Iteration 53, loss = 0.04003180
Iteration 54, loss = 0.03907774
Iteration 55, loss = 0.03815565
Iteration 56, loss = 0.03791975
Iteration 57, loss = 0.03706276
Iteration 58, loss = 0.03617874
Iteration 59, loss = 0.03593227
Iteration 60, loss = 0.03504175
Iteration 61, loss = 0.03441259
Iteration 62, loss = 0.03397449
Iteration 63, loss = 0.03326990
Iteration 64, loss = 0.03305025
Iteration 65, loss = 0.03244893
Iteration 66, loss = 0.03191504
Iteration 67, loss = 0.03132169
Iteration 68, loss = 0.03079707
Iteration 69, loss = 0.03044946
Iteration 70, loss = 0.03005546
Iteration 71, loss = 0.02960555
Iteration 72, loss = 0.02912799
Iteration 73, loss = 0.02859103
Iteration 74, loss = 0.02825959
Iteration 75, loss = 0.02788968
Iteration 76, loss = 0.02748725
Iteration 77, loss = 0.02721247
Iteration 78, loss = 0.02686225
Iteration 79, loss = 0.02635636
Iteration 80, loss = 0.02607439
Iteration 81, loss = 0.02577613
Iteration 82, loss = 0.02553642
Iteration 83, loss = 0.02518749
Iteration 84, loss = 0.02484300
Iteration 85, loss = 0.02455379
Iteration 86, loss = 0.02432480
Iteration 87, loss = 0.02398548
Iteration 88, loss = 0.02376004
Iteration 89, loss = 0.02341261
Iteration 90, loss = 0.02318255
Iteration 91, loss = 0.02296065
Iteration 92, loss = 0.02274048
Iteration 93, loss = 0.02241054
Iteration 94, loss = 0.02208181
Iteration 95, loss = 0.02190861
Iteration 96, loss = 0.02174404
Iteration 97, loss = 0.02156939
Iteration 98, loss = 0.02119768
Iteration 99, loss = 0.02101874
Iteration 100, loss = 0.02078230
Iteration 101, loss = 0.02061573
Iteration 102, loss = 0.02039802
Iteration 103, loss = 0.02017245
Iteration 104, loss = 0.01997162
Iteration 105, loss = 0.01989280
Iteration 106, loss = 0.01963828
Iteration 107, loss = 0.01941850
Iteration 108, loss = 0.01933154
Iteration 109, loss = 0.01911473
Iteration 110, loss = 0.01905371
Iteration 111, loss = 0.01876085
Iteration 112, loss = 0.01860656
Iteration 113, loss = 0.01848655
Iteration 114, loss = 0.01834844
Iteration 115, loss = 0.01818981
Iteration 116, loss = 0.01798523
Iteration 117, loss = 0.01783630
Iteration 118, loss = 0.01771441
Iteration 119, loss = 0.01749814
Iteration 120, loss = 0.01738339
Iteration 121, loss = 0.01726549
Iteration 122, loss = 0.01709638
Iteration 123, loss = 0.01698340
Iteration 124, loss = 0.01684606
Iteration 125, loss = 0.01667016
Iteration 126, loss = 0.01654172
Iteration 127, loss = 0.01641832
Iteration 128, loss = 0.01630111
Iteration 129, loss = 0.01623051
Iteration 130, loss = 0.01612736
Iteration 131, loss = 0.01590220
Iteration 132, loss = 0.01582485
Iteration 133, loss = 0.01571372
Iteration 134, loss = 0.01560349
Iteration 135, loss = 0.01557688
Iteration 136, loss = 0.01534420
Iteration 137, loss = 0.01527883
Iteration 138, loss = 0.01517545
Iteration 139, loss = 0.01503663
Iteration 140, loss = 0.01501192
Iteration 141, loss = 0.01482535
Iteration 142, loss = 0.01471388
Iteration 143, loss = 0.01463948
Iteration 144, loss = 0.01454059
Iteration 145, loss = 0.01441742
Iteration 146, loss = 0.01431741
Iteration 147, loss = 0.01428414
Iteration 148, loss = 0.01416364
Iteration 149, loss = 0.01406742
Iteration 150, loss = 0.01402651
Iteration 151, loss = 0.01389720
Iteration 152, loss = 0.01381412
Iteration 153, loss = 0.01371300
Iteration 154, loss = 0.01362465
Iteration 155, loss = 0.01357048
Iteration 156, loss = 0.01348760
Training loss did not improve more than tol=0.000100 for two consecutive epochs. Stopping.
Out[19]:
MLPClassifier(activation='logistic', alpha=0.0001, batch_size='auto',
       beta_1=0.9, beta_2=0.999, early_stopping=False, epsilon=1e-08,
       hidden_layer_sizes=(15,), learning_rate='constant',
       learning_rate_init=0.1, max_iter=200, momentum=0.9,
       nesterovs_momentum=True, power_t=0.5, random_state=1, shuffle=True,
       solver='sgd', tol=0.0001, validation_fraction=0.1, verbose=True,
       warm_start=False)
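After `fit` returns, the trained connections can be inspected. In scikit-learn, `coefs_` holds one weight matrix per pair of consecutive layers and `intercepts_` holds the biases. With the `sgd` solver, `loss_curve_` records the loss at each iteration, the same values printed in the log above. A short sketch that retrains the same network without the verbose log. Exact loss values and iteration counts may vary with your scikit-learn version:

```python
from sklearn import datasets
from sklearn.neural_network import MLPClassifier

digits = datasets.load_digits()
x = digits.images.reshape((len(digits.images), -1))
y = digits.target

mlp = MLPClassifier(hidden_layer_sizes=(15,), activation='logistic', alpha=1e-4,
                    solver='sgd', tol=1e-4, random_state=1,
                    learning_rate_init=.1)
mlp.fit(x[:1000], y[:1000])

# One weight matrix per pair of layers: (inputs, hidden) and (hidden, outputs).
print([w.shape for w in mlp.coefs_])       # [(64, 15), (15, 10)]
print([b.shape for b in mlp.intercepts_])  # [(15,), (10,)]

# Loss at the first and last iterations, and number of iterations run.
print(mlp.loss_curve_[0], mlp.loss_curve_[-1], mlp.n_iter_)
```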

The training was extremely fast because the neural network is simple and the input dataset is small. Now that the network has been trained, let's see what it can say about our test images:

In [20]:
predictions = mlp.predict(x_test)
predictions[:50] 
# we just look at the 1st 50 examples in the test sample
Out[20]:
array([1, 4, 0, 5, 3, 6, 9, 6, 1, 7, 5, 4, 4, 7, 2, 8, 2, 2, 5, 7, 9, 5,
       4, 4, 9, 0, 8, 9, 8, 0, 1, 2, 3, 4, 5, 6, 7, 8, 3, 0, 1, 2, 3, 4,
       5, 6, 7, 8, 5, 0])
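Each prediction comes from the 10 neurons of the output layer. `predict_proba` returns their outputs, normalized as one probability per digit for each image. `predict` then picks the most probable digit. A sketch checking this on the test sample, retraining the same network as above:

```python
import numpy as np
from sklearn import datasets
from sklearn.neural_network import MLPClassifier

digits = datasets.load_digits()
x = digits.images.reshape((len(digits.images), -1))
y = digits.target

mlp = MLPClassifier(hidden_layer_sizes=(15,), activation='logistic', alpha=1e-4,
                    solver='sgd', tol=1e-4, random_state=1,
                    learning_rate_init=.1)
mlp.fit(x[:1000], y[:1000])

proba = mlp.predict_proba(x[1000:])
print(proba.shape)            # (797, 10): one probability per digit for each test image
print(np.round(proba[0], 3))  # probabilities for the first test image

# The probabilities of each image sum to 1 ...
print(np.allclose(proba.sum(axis=1), 1.0))  # True
# ... and predict() returns the digit with the highest probability.
print(np.array_equal(mlp.classes_[proba.argmax(axis=1)], mlp.predict(x[1000:])))  # True
```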

These predictions should match the labels of our test sample fairly closely. Let's check by eye (please compare the values of these arrays):

In [21]:
y_test[:50] 
# true labels for the 1st 50 examples in the test sample
Out[21]:
array([1, 4, 0, 5, 3, 6, 9, 6, 1, 7, 5, 4, 4, 7, 2, 8, 2, 2, 5, 7, 9, 5,
       4, 4, 9, 0, 8, 9, 8, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4,
       5, 6, 7, 8, 9, 0])

Not bad! 48 of these first 50 predictions match the true labels, and the two mistakes are both nines.
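Comparing by eye quickly gets tedious, and numpy can list the positions where two arrays differ. In the sketch below, the two arrays are copied by hand from the outputs printed above, so the result holds for those values only. Your own run may differ slightly depending on your scikit-learn version:

```python
import numpy as np

# First 50 predictions and true labels, copied from the outputs above.
predictions_50 = np.array([1, 4, 0, 5, 3, 6, 9, 6, 1, 7, 5, 4, 4, 7, 2, 8, 2, 2, 5, 7, 9, 5,
                           4, 4, 9, 0, 8, 9, 8, 0, 1, 2, 3, 4, 5, 6, 7, 8, 3, 0, 1, 2, 3, 4,
                           5, 6, 7, 8, 5, 0])
labels_50 = np.array([1, 4, 0, 5, 3, 6, 9, 6, 1, 7, 5, 4, 4, 7, 2, 8, 2, 2, 5, 7, 9, 5,
                      4, 4, 9, 0, 8, 9, 8, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4,
                      5, 6, 7, 8, 9, 0])

# Indices of the test images where the prediction differs from the label.
wrong = np.flatnonzero(predictions_50 != labels_50)
print(wrong)                  # [38 48]
print(predictions_50[wrong])  # [3 5]
print(labels_50[wrong])       # [9 9]
```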

But can we be a bit more quantitative? We can compute the accuracy of the classifier, which is the probability for a digit to be classified in the right category. Again, scikit-learn comes with a handy tool to do that:

In [22]:
from sklearn.metrics import accuracy_score
accuracy_score(y_test, predictions)
Out[22]:
0.9159347553324969

This number is the fraction of test digits classified in the right category: 730 of the 797 test images are classified correctly. In other words, we get 91.6% of the digits right, and 8.4% wrong.
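Accuracy is simply the fraction of matching entries, so it can also be computed by hand with numpy. A confusion matrix (`sklearn.metrics.confusion_matrix`) goes one step further and shows which digits get confused with which. In the sketch below, which retrains the same network, the entry at row i and column j counts the test images of digit i predicted as digit j. The correct answers therefore sit on the diagonal:

```python
import numpy as np
from sklearn import datasets
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.neural_network import MLPClassifier

digits = datasets.load_digits()
x = digits.images.reshape((len(digits.images), -1))
y = digits.target
x_train, y_train, x_test, y_test = x[:1000], y[:1000], x[1000:], y[1000:]

mlp = MLPClassifier(hidden_layer_sizes=(15,), activation='logistic', alpha=1e-4,
                    solver='sgd', tol=1e-4, random_state=1,
                    learning_rate_init=.1)
mlp.fit(x_train, y_train)
predictions = mlp.predict(x_test)

# Fraction of test images whose prediction equals the label.
manual_accuracy = np.mean(predictions == y_test)
print(manual_accuracy, accuracy_score(y_test, predictions))

# Rows: true digit, columns: predicted digit.
cm = confusion_matrix(y_test, predictions)
print(cm)
# The diagonal holds the correct answers, so trace / total is the accuracy again.
print(cm.trace() / cm.sum())
```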

We managed to get a 91.6% accuracy with this very simple neural network. Not too bad!

However, this is only a first try.

Actually, I must confess that I chose to use a simplistic network to keep the performance on the low side, so that we can optimize it later on.

Conclusion and outlook

In this hands-on tutorial, you have learned:

  • The principles of supervised machine learning for classification,
  • How to install and use the scientific python suite for machine learning,
  • How to investigate your input dataset,
  • How to train a neural network for image recognition, reaching an accuracy larger than 90% for digit classification.

It's only the beginning! In future posts we will:

  • see if we can optimize our network to further increase the accuracy,
  • use deep learning (much more complex networks) to reach extreme accuracies,
  • dive a bit more into the mechanism of the training to understand why we have created the neural network with these parameters.




Please let me know what you think in the comments! I’ll try and answer all questions.

And if you liked this article, you can subscribe to my newsletter to be notified of new posts (no more than one mail per week, I promise).
