Create a first simple neural network to classify handwritten digits.
This tutorial is a hands-on introduction to machine learning for beginners.
Getting started with machine learning can be quite difficult when you're randomly looking for information on the web.
Here, my goal is to help you with a concrete example of image recognition, with just a little bit of code, and no maths.
After a short introduction to machine learning, you will learn:
We will work in python, which is a wonderful choice for data science. If you're not a python developer but know a bit of C, C++ or Java for example, you'll be fine. That will be an excellent occasion to discover python. And, who knows, you could even fall in love with this language too!
Machine learning is a field of artificial intelligence in which a system is designed to learn automatically given a set of input data. After the system has learnt (we say that the system has been trained), we can use it to make predictions for new data, unseen before.
This approach makes it possible to solve complex problems which are difficult or impossible to solve with traditional sequential programming.
Examples of machine learning applications include:
Let's get a feel for how a neural network can be trained for classification.
In this post, our goal is to get started hands on with machine learning fast and easy, so I'm only going to give you a simplified explanation for now. There will be a more detailed post about the training principles later on, so stay tuned if you're interested.
The network is presented with a succession of training examples. Each training example consists of:
In the drawing above, the first image is processed by the neural network, which produces an answer: this is a 9.
At first, the connections between the neurons in the network are random, and the network is not able to do anything useful. It just provides a random answer.
The answer is compared to the label. In this case, the answer (9) is different from the label (the digit is actually a 3), and some feedback is given to the neural network so that it can improve. The connections between the neurons are modified, favoring the ones that tend to give a correct answer.
After the modification, the next examples are considered, and the neural network learns in an iterative process.
The number of training examples needed to train the network properly could be of the order of a few hundred for networks with a simple architecture, and millions for complex networks.
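If it helps to see this loop in code, here is a deliberately simplified sketch; predict and adjust are just placeholders standing in for the real learning machinery, which is left for the more detailed post mentioned above:
import random

def predict(network, image):
    # at first, the connections are random, so the answer is essentially random
    return random.randint(0, 9)

def adjust(network, answer, label):
    # in a real network, the connections would be nudged so that
    # the correct answer becomes more likely next time
    return network

def train(network, examples):
    for image, label in examples:
        answer = predict(network, image)              # e.g. "this is a 9"
        if answer != label:                           # the label says it's actually a 3
            network = adjust(network, answer, label)  # give feedback to the network
    return network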
If you're already running this tutorial in your jupyter notebook, please skip this section.
We will use a variety of tools from scipy, the scientific python library:
Scipy is actually not a single library, but an "ecosystem" of interdependent python packages.
This ecosystem is full of snakes and beasts fighting for survival -- you do not want to venture in there alone.
And indeed, six years ago, when I first got started with scipy, I tried to manually install all the packages I needed on top of the version of python already installed on my system.
I spent almost a day fighting against conflicting dependencies for these packages. For example, scikit-learn might need numpy version A, but pandas needs numpy version B. Or, one of these packages requires a version of python more recent than the one you have, meaning that you need to install an additional version of python and deal with your two versions later on.
And then, I discovered Anaconda.
As stated on Anaconda's website:
With over 6 million users, the open source Anaconda Distribution is the fastest and easiest way to do Python and R data science and machine learning on Linux, Windows, and Mac OS X. It's the industry standard for developing, testing, and training on a single machine.
In a nutshell, the Anaconda team maintains a repository of more than 1400 data science packages, all compatible, and provides tools to install a version of python and these packages at the push of a button, in under five minutes.
Let's do it now!
First, download anaconda for your system:
Run the installer, and finally start the Anaconda Navigator. On Windows, you can find it by clicking the Windows start button and typing anaconda.
In the Anaconda Navigator window, click on the Home tab, and launch the jupyter notebook.
Create a new notebook. In your notebook, you should see an empty cell, where you can write python code. Copy-paste the following lines, and execute the cell by pressing shift + enter.
print('hello world!')
for i in range(10):
    print(i)
A new cell appears. Import numpy and matplotlib (remember that you need to execute the cell):
import matplotlib.pyplot as plt
import numpy as np
This is the standard way to import these modules. You could very well choose other names, but these ones are used by almost everybody, so it's easier to use them as well.
Now let's try and do our first plot, just to make sure that numpy and matplotlib are working:
# create a numpy 1-D array with 16 evenly spaced values, from 0 to 3.
x = np.linspace(0, 3, 16)
print(x)
# create a new numpy array.
# x**2 means that each element of x is squared.
y = x**2
print(y)
# plot y versus x, you should get a parabola.
# check that for x = 1 we have y = 1, and that for x = 2, y = 4.
plt.plot(x, y)
💡 A word of caution:
It is very easy to get lost in the documentation of all these tools, and to waste a lot of time.
For example, if you check the documentation of the plt.plot method (I won't give you the link ;-) but you could google it), you will see that there are lots of ways to call it, with many optional parameters. But after all, do we need to know more than this:
plt.plot(x,y)
plots y vs x?
If you want to have fun, I suggest following this tutorial until the end without digging deeper.
You'll train your first neural net easily and, in the process, you'll get an understanding of the most important scikit-learn, numpy, and matplotlib tools. That's more than enough for a variety of machine learning tasks, and you can always learn more about specific features of these tools when you need them later on (you'll know!).
To go further, here's an excellent scipy lecture .
Now that you have access to the jupyter notebook, I have good news. You won't need to keep copy-pasting code from this page to your notebook.
Instead, just do the following:
Downloads/maldives-master
Downloads/maldives-master/handwritten_digits_sklearn
handwritten_digits_sklearn.ipynb
You should see this page appear in the notebook. From now on, follow the tutorial in the notebook. You should execute the cells as they come, or execute them all in one go. You can even add cells or modify existing cells to experiment a bit.
scikit-learn comes with several test datasets. Let's load the handwritten digits dataset:
from sklearn import datasets
digits = datasets.load_digits()
In python, the dir function returns the names of the attributes of an object, in other words, what information is stored in the object in the form of other objects. Let's use this function to check what can be found in the digits object:
dir(digits)
Let's have a look at some of these attributes in more detail. We are going to start by checking their type:
print(type(digits.images))
print(type(digits.target))
images and target are ndarrays (N-dimensional arrays) from the numpy package. The shape attribute of an ndarray gives the number of dimensions and the size along each dimension of the array. For example:
digits.images.shape
digits.images is an array with 3 dimensions. The first dimension indexes images, and we see that we have 1797 images in total. The next two dimensions correspond to the x and y coordinates of the pixels in each image. Each image has 8x8 = 64 pixels. In other words, this array could be represented in 3D as a pile of images with 8x8 pixels each.
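To make these three dimensions concrete, here is how you would pick out a single pixel (the indices are arbitrary, just to illustrate the indexing):
# value of the pixel at row 2, column 3 of image 0
# (image index first, then the two pixel coordinates)
print(digits.images[0, 2, 3])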
Let's look at the data of the first 8x8 image. Each slot in the array corresponds to a pixel, and the value in the slot is the amount of black in the pixel:
print(digits.images[0])
Now let's display this image (sometimes the plot does not appear; just rerun the cell if you don't see the image):
import matplotlib.pyplot as plt
plt.imshow(digits.images[0], cmap='binary')
plt.show()
The image is low resolution. The original digits were of much higher resolution, and the resolution was decreased when creating the dataset for scikit-learn, to make it easier and faster to train a machine learning algorithm to recognize these digits.
Now let's investigate the target attribute:
print(digits.target.shape)
print(digits.target)
It is a 1-dimensional array with 1797 slots. Looking into the array, we see that it contains the true numbers corresponding to each image. For example, the first target is 0, and corresponds to the image drawn just above.
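As a quick sanity check, you can print the first label and verify that it matches the digit displayed above:
# the first label, which should be 0
print(digits.target[0])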
Let's have a look at some more images using this function:
def plot_multi(i):
    '''Plots 16 digits, starting with digit i'''
    nplots = 16
    fig = plt.figure(figsize=(15,15))
    for j in range(nplots):
        plt.subplot(4,4,j+1)
        plt.imshow(digits.images[i+j], cmap='binary')
        plt.title(digits.target[i+j])
        plt.axis('off')
    plt.show()
plot_multi(0)
You can have a look at the next digits by calling plot_multi(16), plot_multi(32), etc. You will probably see that with such a low resolution, it's quite difficult to recognize some of the digits, even for a human. In these conditions, our neural network will also be limited by the low quality of the input images. Can the neural network perform at least as well as a human? It would already be an achievement!
With scikit-learn, creating, training, and evaluating a neural network can be done with only a few lines of code.
We will make a very simple neural network, with three layers: an input layer that receives the image pixels, one hidden layer, and an output layer that gives the answer.
This is a dense neural network, which means that each node in each layer is connected to all nodes in the previous and next layers.
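To get a feel for what "dense" means in practice, you can count the connections. The sketch below assumes the sizes we will use in this tutorial: 64 input pixels, 15 hidden neurons, and 10 output classes:
# number of connections in a dense 64 -> 15 -> 10 network (bias terms not counted)
n_input, n_hidden, n_output = 64, 15, 10
print(n_input * n_hidden + n_hidden * n_output)  # 1110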
The input layer requires a 1-dimensional array as input, but our images are 2D. So we need to flatten all images:
y = digits.target
x = digits.images.reshape((len(digits.images), -1))
x.shape
We now have 1797 flattened images. The two dimensions of our 8x8 images have been collapsed into a single dimension by writing the rows of 8 pixels as they come, one after the other. The first image that we looked at earlier is now represented by a 1-D array with 8x8 = 64 slots. Please check that the values below are the same as in the original 2-D image.
x[0]
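If you prefer to check this programmatically rather than by eye, a one-liner does it (assuming numpy is still imported as np):
# reshaping the flattened image back to 8x8 should reproduce the original 2-D image
print(np.array_equal(x[0].reshape(8, 8), digits.images[0]))  # expect True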
Let's now split our data into a training sample and a testing sample:
x_train = x[:1000]
y_train = y[:1000]
x_test = x[1000:]
y_test = y[1000:]
The first 1000 images and labels are going to be used for training. The rest of the dataset will be used later to test the performance of our network.
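By the way, scikit-learn also provides a helper for this kind of split. We don't need it here, but this is how the same split would look (shuffle=False reproduces the simple slicing above):
# equivalent split using scikit-learn's helper
from sklearn.model_selection import train_test_split
x_train, x_test, y_train, y_test = train_test_split(
    x, y, train_size=1000, shuffle=False)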
We can now create the neural network. We use one hidden layer with 15 neurons, and scikit-learn is smart enough to find out how many neurons to use in the input and output layers. Don't pay attention to the other parameters, we'll cover them in future posts.
from sklearn.neural_network import MLPClassifier
mlp = MLPClassifier(hidden_layer_sizes=(15,), activation='logistic', alpha=1e-4,
                    solver='sgd', tol=1e-4, random_state=1,
                    learning_rate_init=.1, verbose=True)
Finally, we can train the neural network:
mlp.fit(x_train,y_train)
The training was extremely fast because the neural network is simple and the input dataset is small. Now that the network has been trained, let's see what it can say about our test images:
predictions = mlp.predict(x_test)
predictions[:50]  # we just look at the first 50 examples in the test sample
These predictions should be fairly close to the targets of our test sample. Let's check by eye (please compare the values of these arrays):
y_test[:50]  # true labels for the first 50 examples in the test sample
Not bad! We see that most (if not all) predictions match the true labels.
But can we be a bit more quantitative? We can compute the accuracy of the classifier, which is the probability for a digit to be classified in the right category. Again, scikit-learn comes with a handy tool to do that:
from sklearn.metrics import accuracy_score
accuracy_score(y_test, predictions)
This number is the probability for the digits in the test sample to be classified in the right category, meaning that we get 91.6% of the digits right, and 8.4% wrong.
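If you're curious where this number comes from, it is simply the fraction of predictions that match the true labels, which you can compute by hand from the arrays above:
# accuracy computed by hand: fraction of matching predictions
print((predictions == y_test).mean())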
We managed to get a 91.6% accuracy with this very simple neural network. Not too bad!
However, this is only a first try.
Actually, I must confess that I chose to use a simplistic network to keep the performance on the low side, so that we can optimize it later on.
In this hands-on tutorial, you have learned:
It's only the beginning! In future posts we will:
Please let me know what you think in the comments! I’ll try and answer all questions.
And if you liked this article, you can subscribe to my mailing list to be notified of new posts (no more than one mail per week I promise.)