Deploy a Deep Learning Model with Flask RESTful

Learn how to deploy your deep learning model as a REST API for your customers.

In this short article, you will learn how to deploy a deep learning model as a REST API with Flask RESTful.

Ok, you've worked hard tuning your model for best performance. Congratulations! It is now time to get it out of your notebooks and to bring it to the world.

If you've never done that, you might have no clue how to proceed, especially if your background is in data science rather than in software development.

REST APIs are a very common way to provide any kind of service. You can use them in the backend of a website, or give your customers direct access to them. Maybe you have already used REST APIs yourself to interact with services such as Google Vision or Amazon Textract.

Today you will learn how to:

  • write a small Python script to classify images with a pre-trained deep neural network, ResNet50, with Keras and TensorFlow;
  • create your own REST API with Flask RESTful;
  • create an image classification API allowing anybody to upload an image and get back the classification results.

This might seem ambitious, but it will actually take us only a few dozen lines of code.

But beware! I'm leaving out all the operational aspects such as security, containerized deployment, web server, etc. These points will be addressed in a future post. In the meantime, don't use this code as is in production.

Here is the GitHub repo with the code for this tutorial.

My cat Capuchon, proudly modelling for the model

Installation

First, let's install the tools we need:

  • Python
  • Flask: a lightweight micro web framework for Python
  • Flask RESTful: an extension to quickly build REST APIs with Flask
  • TensorFlow (+ Keras): the deep learning framework, which provides our pre-trained image classification model
  • Pillow: the Python imaging library, necessary to preprocess our images before classification

As usual, we will use Anaconda. First install it, and then create an environment with the necessary tools:

conda create -n dlflask python=3.7 tensorflow flask pillow

We use Python 3.7 because, at the time of writing, more recent versions of Python seem to lead to conflicts between the dependencies of the flask and tensorflow packages.

Now activate the environment:

conda activate dlflask

Finally, install Flask RESTful with pip, as it is not available in conda:

pip install flask-restful
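Before moving on, you may want to check that everything can be imported from the new environment. Here is a minimal sketch using only the standard library; the helper name check_packages is hypothetical, and the module names are the import names of the packages above (Pillow is imported as PIL).

```python
import importlib.util


def check_packages(module_names):
    """Return a dict mapping each module name to True if it can be imported."""
    # find_spec returns None, without importing the module, when it is not installed
    return {name: importlib.util.find_spec(name) is not None
            for name in module_names}


if __name__ == '__main__':
    required = ['flask', 'flask_restful', 'tensorflow', 'PIL']
    for name, found in check_packages(required).items():
        print(f'{name:15} {"ok" if found else "MISSING"}')
```

Run it with python inside the activated dlflask environment: every line should say ok.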

Predict a cat

The first thing we need is a deep learning model to integrate in our REST API.

We don't want to waste any time on this today, so we are simply going to use a pre-trained model from Keras.

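I went for ResNet50, which is a high-performance classification model trained on ImageNet. More precisely, it was trained on the 1000-category subset of ImageNet used in the ILSVRC challenge, with about 1.2 million training images (the full ImageNet database holds more than 14 million images).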

Create a Python module called predict_resnet50.py with this code:

import os
# set before importing TensorFlow, so that it runs on a mac
# (avoids a crash caused by duplicate OpenMP libraries)
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

import numpy as np
import tensorflow.keras.applications.resnet50 as resnet50
from tensorflow.keras.preprocessing import image

# ResNet50 is trained on color images with 224x224 pixels
input_shape = (224, 224, 3)

# load the model once, when the module is imported, rather than at
# every prediction: loading is slow (weights are downloaded the first time)
model = resnet50.ResNet50(weights='imagenet', input_shape=input_shape)


def predict(fname):
    """returns top 5 categories for an image.

    :param fname: path to the file
    """
    # load and resize image ----------------------

    img = image.load_img(fname, target_size=input_shape[:2])
    x = image.img_to_array(img)

    # preprocess image ---------------------------

    # make a batch of one image: shape (1, 224, 224, 3)
    x = np.expand_dims(x, axis=0)

    # apply the preprocessing function of resnet50,
    # and keep the preprocessed array for the prediction
    x = resnet50.preprocess_input(x)

    # predict, and decode the 5 most likely ImageNet categories
    preds = model.predict(x)
    return resnet50.decode_predictions(preds, top=5)


if __name__ == '__main__':

    import pprint
    import sys

    file_name = sys.argv[1]
    results = predict(file_name)
    pprint.pprint(results)

Make sure to read the comments to understand what the script is doing.
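If you are curious about what resnet50.preprocess_input does: for ResNet50, Keras uses its "caffe" preprocessing mode, which converts the images from RGB to BGR and subtracts the mean ImageNet pixel value from each channel, without any scaling. The sketch below reproduces that with numpy for illustration only (the function name caffe_preprocess is hypothetical); in the real script, stick to the Keras function.

```python
import numpy as np

# Mean ImageNet pixel values, in BGR order, as used by the Keras "caffe" mode
IMAGENET_MEAN_BGR = np.array([103.939, 116.779, 123.68], dtype=np.float32)


def caffe_preprocess(batch_rgb):
    """Illustrative re-implementation of ResNet50 preprocessing.

    :param batch_rgb: array of shape (n, height, width, 3), RGB, values in 0-255
    """
    batch = np.asarray(batch_rgb, dtype=np.float32)
    # reverse the last axis: RGB -> BGR
    bgr = batch[..., ::-1]
    # subtract the per-channel mean (broadcast over n, height and width);
    # this creates a new array, so the input is left untouched
    return bgr - IMAGENET_MEAN_BGR


if __name__ == '__main__':
    # a batch containing a single 1x1 image with one pure red pixel
    red = np.array([[[[255.0, 0.0, 0.0]]]])
    print(caffe_preprocess(red))
```

This is why the array passed to model.predict must be the preprocessed one: the network expects zero-centered BGR values, not raw RGB pixels.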

Before going further, you should check that the script works (I'm using a picture of my cat, but you can use any image you want):

python predict_resnet50.py capuchon.jpg

You should get something like:

[[('n02123159', 'tiger_cat', 0.58581424),
  ('n02124075', 'Egyptian_cat', 0.21068987),
  ('n02123045', 'tabby', 0.14554422),
  ('n03938244', 'pillow', 0.008319859),
  ('n02127052', 'lynx', 0.006789663)]]

The predictions look pretty good! This is indeed a tiger cat, and the next two categories, Egyptian cat and tabby, are closely related kinds of cats. Then comes "pillow", albeit with a much smaller probability. This is not surprising: the cat is on a pillow.
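decode_predictions returns one list per image in the batch, each containing (wordnet_id, label, score) tuples sorted by decreasing score. If you only care about confident predictions, a small helper like the hypothetical top_labels below can drop the rest; the 0.1 threshold in the example is an arbitrary choice.

```python
def top_labels(decoded, min_score=0.0):
    """Map label -> score for the predictions of ONE image,
    keeping only those with a score of at least min_score.

    :param decoded: list of (wordnet_id, label, score) tuples,
                    e.g. predict(fname)[0]
    """
    return {label: float(score)
            for _, label, score in decoded
            if score >= min_score}


if __name__ == '__main__':
    # the predictions obtained above for the picture of Capuchon
    capuchon = [('n02123159', 'tiger_cat', 0.58581424),
                ('n02124075', 'Egyptian_cat', 0.21068987),
                ('n02123045', 'tabby', 0.14554422),
                ('n03938244', 'pillow', 0.008319859),
                ('n02127052', 'lynx', 0.006789663)]
    print(top_labels(capuchon, min_score=0.1))
    # prints {'tiger_cat': 0.58581424, 'Egyptian_cat': 0.21068987, 'tabby': 0.14554422}
```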

We see that our script is working, so let's get started with flask.

Hello World with Flask RESTful

In my opinion, there are two notable web frameworks for Python:

  • Django has batteries included: everything is there to build a complete website. This is my go-to choice for full websites or web apps, and this is what I'm using to power this blog.
  • Flask is a lightweight micro web framework: it's ideal for building simple websites or web services, but it can also handle very complex projects. The advantage of Flask over Django is that, since you add each component yourself, you know exactly what you're doing.

One could build a REST API with Flask rather easily, but it's even faster with its Flask RESTful extension.

Let's start with a simple "Hello World" example.

Create a python module called rest_api_hello.py with this code:

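from flask import Flask
from flask_restful import Resource, Api

app = Flask(__name__)
app.logger.setLevel('INFO')

# the Api object routes HTTP requests to our resources
api = Api(app)


class Hello(Resource):

    # called for GET requests; the returned dict is sent back as JSON
    def get(self):
        return {'hello': 'world'}


# serve the Hello resource at the /hello URL
api.add_resource(Hello, '/hello')

if __name__ == '__main__':
    # start the Flask development server (port 5000 by default)
    app.run(debug=True)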

Then start this app on Flask's development server:

python rest_api_hello.py

Now, you can send a request to the server with curl:

curl localhost:5000/hello

Which gives:

{
    "hello": "world"
}

Alternatively, you can point your browser to http://localhost:5000/hello, and you will get the same thing.
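curl and the browser are handy for quick checks, but your customers will most likely call the API from code. As a minimal sketch, assuming the development server above is running on its default address, a Python client only needs the standard library (get_json is a hypothetical helper name):

```python
import json
import urllib.request


def get_json(url, timeout=10):
    """Send a GET request to url and return the decoded JSON response."""
    # urlopen raises urllib.error.HTTPError for 4xx and 5xx status codes
    with urllib.request.urlopen(url, timeout=timeout) as response:
        return json.loads(response.read().decode('utf-8'))
```

With the server running, get_json('http://localhost:5000/hello') should return the dictionary {'hello': 'world'}.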

See? That's quite easy.

If you want to understand in more detail what's going on, just take an hour to follow the tutorials for Flask and Flask RESTful.

But for now, let's plug the model into our API.

Classify images with a REST API

Let's create another flask app to classify images.

To do this, create a python module called rest_api_predict.py with this code:

from flask import Flask
from flask_restful import Resource, Api, reqparse
from werkzeug.datastructures import FileStorage
from predict_resnet50 import predict
import os
import tempfile

app = Flask(__name__)
app.logger.setLevel('INFO')

api = Api(app)

# the request must contain a file, sent as multipart/form-data
# under the field name 'file'
parser = reqparse.RequestParser()
parser.add_argument('file',
                    type=FileStorage,
                    location='files',
                    required=True,
                    help='provide a file')


class Image(Resource):

    def post(self):
        args = parser.parse_args()
        the_file = args['file']
        # save a temporary copy of the file
        fd, ofname = tempfile.mkstemp()
        os.close(fd)  # we only need the path, not the open descriptor
        try:
            the_file.save(ofname)
            # predict, and keep the results for our single image
            results = predict(ofname)[0]
        finally:
            # delete the temporary file, even if the prediction failed
            os.remove(ofname)
        # formatting the results as a JSON-serializable structure:
        output = {'top_categories': []}
        for _, categ, score in results:
            output['top_categories'].append((categ, float(score)))

        return output


api.add_resource(Image, '/image')

if __name__ == '__main__':
    app.run(debug=True)

The main differences are that:

  • we implement a POST method for our /image endpoint, able to receive images.
  • the method returns a JSON object with the top categories. We cannot send back the prediction results directly: they are a list of tuples whose scores are numpy float32 values, which cannot be serialized to JSON. That's why each score is converted to a regular Python float.
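To see the problem, note that json.dumps refuses numpy float32 scalars. Instead of the explicit loop, one alternative (a sketch, not what the code above does) is to pass default=float to json.dumps: the json module calls this function on any object it cannot serialize natively.

```python
import json

import numpy as np

# one prediction as returned by decode_predictions: the score is a numpy float32
results = [('n02123159', 'tiger_cat', np.float32(0.58581424))]

try:
    json.dumps(results)
except TypeError as err:
    # e.g. "Object of type float32 is not JSON serializable"
    print('json.dumps failed:', err)

# default=float is called on every object json cannot handle natively,
# here converting the float32 score to a regular Python float
print(json.dumps(results, default=float))
```

The explicit loop in the resource has the advantage of making the output format obvious, and it only keeps the fields we want to expose.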

Start the app server:

python rest_api_predict.py

And send a request with an image (note the @, which tells curl to upload the content of the file rather than send its name as text):

curl localhost:5000/image -F file=@capuchon.jpg

This should give:

{
    "top_categories": [
        [
            "tiger_cat",
            0.5858142375946045
        ],
        [
            "Egyptian_cat",
            0.21068987250328064
        ],
        [
            "tabby",
            0.14554421603679657
        ],
        [
            "pillow",
            0.008319859392940998
        ],
        [
            "lynx",
            0.006789662875235081
        ]
    ]
}
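The Python equivalent of this curl command is a multipart/form-data POST request. The third-party requests library can build such a request for you, but to stay within the standard library, here is a minimal sketch that encodes the form by hand. The helper names encode_multipart and classify are hypothetical, and the sketch assumes the server from this section is running on localhost:5000.

```python
import json
import mimetypes
import os
import urllib.request
import uuid


def encode_multipart(field, filename, content):
    """Encode one file as a multipart/form-data body, like curl -F field=@filename.

    Returns (body, content_type).
    """
    boundary = uuid.uuid4().hex
    ctype = mimetypes.guess_type(filename)[0] or 'application/octet-stream'
    head = (f'--{boundary}\r\n'
            f'Content-Disposition: form-data; name="{field}"; '
            f'filename="{filename}"\r\n'
            f'Content-Type: {ctype}\r\n\r\n').encode('utf-8')
    tail = f'\r\n--{boundary}--\r\n'.encode('utf-8')
    return head + content + tail, f'multipart/form-data; boundary={boundary}'


def classify(path, url='http://localhost:5000/image'):
    """Upload an image to the /image endpoint and return its top categories."""
    with open(path, 'rb') as f:
        body, content_type = encode_multipart('file', os.path.basename(path), f.read())
    request = urllib.request.Request(url, data=body, method='POST',
                                     headers={'Content-Type': content_type})
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read().decode('utf-8'))['top_categories']
```

For example, classify('capuchon.jpg') returns the top categories as a list of [category, score] pairs.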

Conclusion

In this post, you have learned how to:

  • write a small Python script to classify images with a pre-trained deep neural network, ResNet50, with Keras and TensorFlow;
  • create your own REST API with Flask RESTful;
  • create an image classification API allowing anybody to upload an image and get back the classification results.

Next time, we will see how to serve the web app with a proper web server, behind a reverse proxy, over HTTPS, and with user authentication. Until you know how to do this, don't use this code in production.


Please let me know what you think in the comments! I’ll try and answer all questions.

And if you liked this article, you can subscribe to my mailing list to be notified of new posts (no more than one email per week, I promise).
