Déployer un modèle d'apprentissage en profondeur avec Flask RESTful

Apprenez à déployer votre modèle d'apprentissage profond avec une API REST pour vos clients.

Dans ce court article, vous apprendrez à déployer un modèle d'apprentissage profond avec une API REST, à l'aide de Flask RESTful.

Vous avez travaillé dur pour régler votre modèle pour de meilleures performances. Toutes mes félicitations ! Il est maintenant temps de le sortir de vos notebooks et de le rendre public.

Si vous n'avez jamais fait cela, vous n'avez peut-être aucune idée de la marche à suivre, surtout si vous avez une formation en science des données plutôt qu'en développement de logiciels.

Les API REST sont un moyen très courant de fournir tout type de service. Vous pouvez les utiliser dans le backend d'un site Web, ou même fournir un accès direct à vos API à vos clients. Peut-être avez-vous déjà utilisé vous-même les API REST pour interagir avec des services tels que Google Vision ou Amazon Textract .

Aujourd'hui, vous apprendrez à :

  • écrire un petit script python pour classer les images avec un réseau de neurones profond pré-entraîné, ResNet50, avec Keras et Tensorflow ;
  • créer votre propre API REST avec Flask RESTful ;
  • créer une API de classification d'images permettant à n'importe qui de pousser son image et de récupérer les résultats de la classification.

Cela peut sembler ambitieux, mais cela nous prendra en réalité moins de 50 lignes de code.

Mais méfiez-vous ! Je laisse de côté tous les éléments opérationnels tels que la sécurité, le déploiement conteneurisé, le serveur Web, etc. Ces points seront abordés dans un prochain article. En attendant, n'utilisez pas ce code tel qu'il est en production.

Voici le repo Github avec le code de ce tutoriel.

My cat Capuchon

Mon chat capuchon, modèle pour le modèle.

Installation

Tout d'abord, installons les outils dont nous avons besoin :

  • python
  • flask : micro framework web léger pour python
  • flask RESTful : créez rapidement des API REST avec Flask.
  • tensorflow (+ keras) : modèle de classification d'images pré-entraîné
  • Pillow : la librairie d'imagerie python, nécessaire pour pré-traiter nos images avant classification

Comme d'habitude, nous utiliserons Anaconda . Installez-le d'abord, puis créez un environnement avec les outils nécessaires :

conda create -n dlflask python=3.7 tensorflow flask pillow

Nous avons utilisé python 3.7 car, pour le moment, les versions plus récentes de python semblent entraîner des conflits entre les dépendances des packages flask et tensorflow.

Activez maintenant l'environnement :

conda activate dlflask

Enfin, nous installons flask RESTful avec pip, car il n'est pas disponible dans conda :

pip install flask-restful

Prédire un chat

La première chose dont nous avons besoin est un modèle d'apprentissage en profondeur à intégrer dans notre API REST.

Nous ne voulons pas perdre de temps là-dessus aujourd'hui, nous allons donc simplement utiliser un modèle pré-entraîné de Keras.

J'ai opté pour ResNet50, qui est un modèle de classification haute performance entraîné avec ImageNet, un échantillon de données avec 1000 catégories et 15 millions d'images au moment de la rédaction de cet article.

Créez un module python appelé predict_resnet50.py contenant ce code :

import tensorflow.keras.applications.resnet50 as resnet50
from tensorflow.keras.preprocessing import image
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'  # so that it runs on a mac


def predict(fname):
    """returns top 5 categories for an image.
    
    :param fname : path to the file 
    """
    # ResNet50 is trained on color images with 224x224 pixels
    input_shape = (224, 224, 3)

    # load and resize image ----------------------
    
    img = image.load_img(fname, target_size=input_shape[:2])
    x = image.img_to_array(img)

    # preprocess image ---------------------------

    # make a batch
    import numpy as np
    x = np.expand_dims(x, axis=0)
    print(x.shape)

    # apply the preprocessing function of resnet50
    img_array = resnet50.preprocess_input(x)

    model = resnet50.ResNet50(weights='imagenet',
                              input_shape=input_shape)
    preds = model.predict(x)
    return resnet50.decode_predictions(preds)


if __name__ == '__main__':

    import pprint
    import sys

    file_name = sys.argv[1]
    results = predict(file_name)
    pprint.pprint(results)

Assurez-vous de lire les commentaires pour comprendre ce que fait le script.

Avant d'aller plus loin, vous devriez vérifier que le script fonctionne (j'utilise l'image de mon chat, mais vous pouvez utiliser n'importe quelle image) :

python predict_resnet50.py capuchon.jpg

Vous devriez obtenir quelque chose comme :

[[ 
  ('n02123159', 'tiger_cat', 0.58581424),
  ('n02124075', 'Egyptian_cat', 0.21068987),
  ('n02123045', 'tabby', 0.14554422),
  ('n03938244', 'pillow', 0.008319859),
  ('n02127052', 'lynx', 0.006789663)
]]

Les prédictions sont plutôt bonnes ! il s'agit bien d'un chat tigré, et les deux catégories suivantes sont également des chats tigrés. Vient ensuite « oreiller », mais avec une probabilité beaucoup plus faible. Ce n'est pas surprenant : le chat est sur l'oreiller.

Nous voyons que notre script fonctionne, alors commençons avec flask.

Hello World avec Flask RESTful

Les deux frameworks web python les plus connus sont:

  • Django : tout est là pour construire un site web complet. C'est mon choix de prédilection pour les sites Web complexes ou les applications Web, et c'est ce que j'utilise pour alimenter ce blog .
  • flask est un micro framework web léger : il est idéal pour créer des sites web ou des services web simples. Mais il est aussi possible de faire des choses très complexes. L'avantage de Flask par rapport à Django est que vous savez ce que vous faites.

On pourrait construire une API REST avec Flask assez facilement, mais c'est encore plus rapide avec son extension Flask RESTful.

Commençons par un simple exemple "Hello World".

Créez un module python appelé rest_api_hello.py avec ce code :

from flask import Flask
from flask_restful import Resource, Api

app = Flask(__name__)
app.logger.setLevel('INFO')

api = Api(app)


class Hello(Resource):

    def get(self):
        return {'hello': 'world'}


api.add_resource(Hello, '/hello')

if __name__ == '__main__':
    app.run(debug=True)

Démarrez ensuite cette application sur le serveur de débogage de Flask :

python rest_api_hello.py

Maintenant, vous pouvez envoyer une requête au serveur avec curl :

curl localhost:5000/hello

Qui donne:

{
"hello": "world"
}

Alternativement, vous pouvez pointer votre navigateur sur http://localhost:5000/hello , et vous obtiendrez la même chose.

Vous voyez que c'est assez simple.

Si vous souhaitez comprendre plus en détail ce qui se passe, prenez juste une heure pour suivre les tutoriels pour Flask et Flask RESTful.

Mais pour l'instant, connectons le modèle à notre API.

Classer les images avec une API REST

Créons une autre application de flacon pour classer les images.

Pour ce faire, créez un module python appelé rest_api_predict.py avec ce code :

from flask import Flask
from flask_restful import Resource, Api, reqparse
from werkzeug.datastructures import FileStorage
from predict_resnet50 import predict
import tempfile

app = Flask(__name__)
app.logger.setLevel('INFO')

api = Api(app)

parser = reqparse.RequestParser()
parser.add_argument('file',
                    type=FileStorage,
                    location='files',
                    required=True,
                    help='provide a file')

class Image(Resource):

    def post(self):
        args = parser.parse_args()
        the_file = args['file']
        # save a temporary copy of the file
        ofile, ofname = tempfile.mkstemp()
        the_file.save(ofname)
        # predict
        results = predict(ofname)[0]
        # formatting the results as a JSON-serializable structure:
        output = {'top_categories': []}
        for _, categ, score in results:
            output['top_categories'].append((categ, float(score)))

        return output


api.add_resource(Image, '/image')

if __name__ == '__main__':
    app.run(debug=True)

Les principales différences sont que :

  • nous implémentons une méthode POST pour notre /image , capable de recevoir des images.
  • la méthode renvoie un JSON avec les catégories supérieures. Nous ne pouvons pas renvoyer directement les résultats des prédictions, car le dictionnaire des résultats ne peut pas être sérialisé en tant que JSON car il contient des objets float32.

Démarrez le serveur d'applications :

python rest_api_predict.py

Et envoyez une demande avec une image (notez le @):

curl localhost:5000/image -F file=@capuchon.jpg

Cela devrait donner :

{
  "top_categories": [
     ["tiger_cat", 0.585814237594604], 
     ["Egyptian_cat", 0.21068987250328064],
     ["tabby", 0.14554421603679657], 
     ["pillow", 0.008319859392940998],
     ["lynx", 0.006789662875235081]
 }

Conclusion

Dans cet article, vous avez appris à :

  • écrire un petit script python pour classer les images avec un réseau de neurones profonds pré-entraîné, ResNet50, avec Keras et Tensorflow ;
  • créer votre propre API REST avec Flask RESTful ;
  • créer une API de classification d'images permettant à n'importe qui de pousser son image et de récupérer les résultats de la classification.

La prochaine fois, nous verrons comment servir l'application Web avec un serveur Web approprié, protégé derrière un proxy inverse, en https et avec authentification de l'utilisateur. Tant que vous ne savez pas comment procéder, n'utilisez pas ce code en production.


N'hésitez pas à me donner votre avis dans les commentaires ! Je répondrai à toutes les questions.

Et si vous avez aimé cet article, vous pouvez souscrire à ma newsletter pour être prévenu lorsque j'en sortirai un nouveau. Pas plus d'un mail par semaine, promis!

Retour


Encore plus de data science et de machine learning !

Rejoignez ma mailing list pour plus de posts et du contenu exclusif:

Je ne partagerai jamais vos infos.
Partagez si vous aimez cet article: