【AI模型部署】基于flask的pytorch简单分类模型部署

论坛 期权论坛 编程之家     
选择匿名的用户   2021-6-2 15:51   1889   0

一.简介

通过flask框架,部署pytorch模型后,可以通过不同主机向服务端主机发送图像路径或图片请求服务,处理后返回结果。本文的返回结果是,对图片的分类结果】

1.1服务端

部署成功后,服务端接收不同主机请求的过程图如下:

服务端本地图片信息 :

1.2返回结果

其他主机的浏览器向服务端传图片路径(图片在服务端本机)

其他主机直接传送图片到服务端(图片在客户端)

二、实现过程

2.1测试flask服务

【参考】https://flask.palletsprojects.com/en/1.1.x/quickstart/#a-minimal-application

安装

pip install flask

测试程序:hello.py

from flask import Flask 
app = Flask(__name__)

@app.route('/')
def predict():
 return "hello world!  It is flask!"
 
if __name__ == '__main__':
    app.run()

运行:

python hello.py

结果

(本地浏览器访问:http://127.0.0.1:5000/

2.2 pytorch模型

代码

import io
import json

from torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, request


app = Flask(__name__)  # 固定写法
imagenet_class_index = json.load(open('imagenet_class_index.json'))
model = models.densenet121(pretrained=True)
model.eval()


def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)


def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]

#'/predict'是会影响请求的格式,可自由改名。
# 需要添加“get”方法,才能直接通过浏览器发送请求
# 请求的路径path是图片的路径,一般是在服务端本机
# 浏览器输入实例,请换自己的ip和路径:http://192.168.1.139:5005/predict?path=/home/ai004/sdg4.jpg


@app.route('/predict', methods=['GET', 'POST'])
def predict():
    if request.method == 'POST':  # 接收传输的图片
        file = request.files['file']
    # zxy add for GET
    else:
        image_file = request.args.get("path") #接收其他客户端浏览器发送的请求
        file = open(image_file, 'rb')
    img_bytes = file.read()
    class_id, class_name = get_prediction(image_bytes=img_bytes)
    return jsonify({'class_id': class_id, 'class_name': class_name})


if __name__ == '__main__':
    # app.run() # 原工程的写法,默认只能本机访问
    app.run(host='0.0.0.0', port=5005)  # 使其他主机可以访问服务

外部主机请求服务(需修改代码,指定ip等)

#从外部主机发送图片到服务器,并接收返回结果

curl -X POST -F file=@2.jpg http://192.168.1.139:5005/predict

# 从浏览器发出请求,图片在服务端本地

http://192.168.1.139:5005/predict?path=/home/ai004/sdg4.jpg

结果

如博文简介部分所示

三、完整工程

【本博客代码】https://gitee.com/zengxy2020/csdn/tree/master/flask

分享到 :
0 人收藏
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

积分:3875789
帖子:775174
精华:0
期权论坛 期权论坛
发布
内容

下载期权论坛手机APP