Apprenez à déployer votre modèle d'apprentissage profond avec une API REST pour vos clients.
Dans ce court article, vous apprendrez à déployer un modèle d'apprentissage profond avec une API REST, à l'aide de Flask RESTful.
Vous avez travaillé dur pour régler votre modèle pour de meilleures performances. Toutes mes félicitations ! Il est maintenant temps de le sortir de vos notebooks et de le rendre public.
Si vous n'avez jamais fait cela, vous n'avez peut-être aucune idée de la marche à suivre, surtout si vous avez une formation en science des données plutôt qu'en développement de logiciels.
Les API REST sont un moyen très courant de fournir tout type de service. Vous pouvez les utiliser dans le backend d'un site Web, ou même fournir un accès direct à vos API à vos clients. Peut-être avez-vous déjà utilisé vous-même les API REST pour interagir avec des services tels que Google Vision ou Amazon Textract .
Aujourd'hui, vous apprendrez à :
Cela peut sembler ambitieux, mais cela nous prendra en réalité moins de 50 lignes de code.
Mais méfiez-vous ! Je laisse de côté tous les éléments opérationnels tels que la sécurité, le déploiement conteneurisé, le serveur Web, etc. Ces points seront abordés dans un prochain article. En attendant, n'utilisez pas ce code tel qu'il est en production.
Voici le repo Github avec le code de ce tutoriel.
Mon chat capuchon, modèle pour le modèle.
Tout d'abord, installons les outils dont nous avons besoin :
Comme d'habitude, nous utiliserons Anaconda . Installez-le d'abord, puis créez un environnement avec les outils nécessaires :
conda create -n dlflask python=3.7 tensorflow flask pillow
Nous avons utilisé python 3.7 car, pour le moment, les versions plus récentes de python semblent entraîner des conflits entre les dépendances des packages flask et tensorflow.
Activez maintenant l'environnement :
conda activate dlflask
Enfin, nous installons flask RESTful avec pip, car il n'est pas disponible dans conda :
pip install flask-restful
La première chose dont nous avons besoin est un modèle d'apprentissage en profondeur à intégrer dans notre API REST.
Nous ne voulons pas perdre de temps là-dessus aujourd'hui, nous allons donc simplement utiliser un modèle pré-entraîné de Keras.
J'ai opté pour ResNet50, qui est un modèle de classification haute performance entraîné avec ImageNet, un échantillon de données avec 1000 catégories et 15 millions d'images au moment de la rédaction de cet article.
Créez un module python appelé predict_resnet50.py
contenant ce code :
import tensorflow.keras.applications.resnet50 as resnet50
from tensorflow.keras.preprocessing import image
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' # so that it runs on a mac
def predict(fname):
"""returns top 5 categories for an image.
:param fname : path to the file
"""
# ResNet50 is trained on color images with 224x224 pixels
input_shape = (224, 224, 3)
# load and resize image ----------------------
img = image.load_img(fname, target_size=input_shape[:2])
x = image.img_to_array(img)
# preprocess image ---------------------------
# make a batch
import numpy as np
x = np.expand_dims(x, axis=0)
print(x.shape)
# apply the preprocessing function of resnet50
img_array = resnet50.preprocess_input(x)
model = resnet50.ResNet50(weights='imagenet',
input_shape=input_shape)
preds = model.predict(x)
return resnet50.decode_predictions(preds)
if __name__ == '__main__':
import pprint
import sys
file_name = sys.argv[1]
results = predict(file_name)
pprint.pprint(results)
Assurez-vous de lire les commentaires pour comprendre ce que fait le script.
Avant d'aller plus loin, vous devriez vérifier que le script fonctionne (j'utilise l'image de mon chat, mais vous pouvez utiliser n'importe quelle image) :
python predict_resnet50.py capuchon.jpg
Vous devriez obtenir quelque chose comme :
[[
('n02123159', 'tiger_cat', 0.58581424),
('n02124075', 'Egyptian_cat', 0.21068987),
('n02123045', 'tabby', 0.14554422),
('n03938244', 'pillow', 0.008319859),
('n02127052', 'lynx', 0.006789663)
]]
Les prédictions sont plutôt bonnes ! il s'agit bien d'un chat tigré, et les deux catégories suivantes sont également des chats tigrés. Vient ensuite « oreiller », mais avec une probabilité beaucoup plus faible. Ce n'est pas surprenant : le chat est sur l'oreiller.
Nous voyons que notre script fonctionne, alors commençons avec flask.
Les deux frameworks web python les plus connus sont:
On pourrait construire une API REST avec Flask assez facilement, mais c'est encore plus rapide avec son extension Flask RESTful.
Commençons par un simple exemple "Hello World".
Créez un module python appelé rest_api_hello.py
avec ce code :
from flask import Flask
from flask_restful import Resource, Api
app = Flask(__name__)
app.logger.setLevel('INFO')
api = Api(app)
class Hello(Resource):
def get(self):
return {'hello': 'world'}
api.add_resource(Hello, '/hello')
if __name__ == '__main__':
app.run(debug=True)
Démarrez ensuite cette application sur le serveur de débogage de Flask :
python rest_api_hello.py
Maintenant, vous pouvez envoyer une requête au serveur avec curl :
curl localhost:5000/hello
Qui donne:
{
"hello": "world"
}
Alternativement, vous pouvez pointer votre navigateur sur http://localhost:5000/hello , et vous obtiendrez la même chose.
Vous voyez que c'est assez simple.
Si vous souhaitez comprendre plus en détail ce qui se passe, prenez juste une heure pour suivre les tutoriels pour Flask et Flask RESTful.
Mais pour l'instant, connectons le modèle à notre API.
Créons une autre application de flacon pour classer les images.
Pour ce faire, créez un module python appelé rest_api_predict.py
avec ce code :
from flask import Flask
from flask_restful import Resource, Api, reqparse
from werkzeug.datastructures import FileStorage
from predict_resnet50 import predict
import tempfile
app = Flask(__name__)
app.logger.setLevel('INFO')
api = Api(app)
parser = reqparse.RequestParser()
parser.add_argument('file',
type=FileStorage,
location='files',
required=True,
help='provide a file')
class Image(Resource):
def post(self):
args = parser.parse_args()
the_file = args['file']
# save a temporary copy of the file
ofile, ofname = tempfile.mkstemp()
the_file.save(ofname)
# predict
results = predict(ofname)[0]
# formatting the results as a JSON-serializable structure:
output = {'top_categories': []}
for _, categ, score in results:
output['top_categories'].append((categ, float(score)))
return output
api.add_resource(Image, '/image')
if __name__ == '__main__':
app.run(debug=True)
Les principales différences sont que :
/image
, capable de recevoir des images.Démarrez le serveur d'applications :
python rest_api_predict.py
Et envoyez une demande avec une image (notez le @):
curl localhost:5000/image -F file=@capuchon.jpg
Cela devrait donner :
{
"top_categories": [
["tiger_cat", 0.585814237594604],
["Egyptian_cat", 0.21068987250328064],
["tabby", 0.14554421603679657],
["pillow", 0.008319859392940998],
["lynx", 0.006789662875235081]
}
Dans cet article, vous avez appris à :
La prochaine fois, nous verrons comment servir l'application Web avec un serveur Web approprié, protégé derrière un proxy inverse, en https et avec authentification de l'utilisateur. Tant que vous ne savez pas comment procéder, n'utilisez pas ce code en production.
N'hésitez pas à me donner votre avis dans les commentaires ! Je répondrai à toutes les questions.
Et si vous avez aimé cet article, vous pouvez souscrire à ma newsletter pour être prévenu lorsque j'en sortirai un nouveau. Pas plus d'un mail par semaine, promis!
Rejoignez ma mailing list pour plus de posts et du contenu exclusif: