在使用 TensorFlow 完成模型训练后,你可以按照以下步骤完成模型的保存、加载、推理以及部署到物理服务器上,从而接受客户端的预测请求。
TensorFlow 提供了两种模型保存方式:SavedModel
格式和 HDF5 格式。推荐使用 SavedModel
,因为它是 TensorFlow 的标准格式,支持更广泛的部署场景。
import tensorflow as tf
# 假设你已经有一个训练好的模型
model = ... # 你的训练好的模型
# 保存为 SavedModel 格式
model.save("saved_model/my_model")
# 或者保存为 HDF5 格式
model.save("my_model.h5")
加载模型后,你可以直接使用它进行推理。
SavedModel
格式模型# 加载 SavedModel
loaded_model = tf.keras.models.load_model("saved_model/my_model")
# 假设输入数据是一个 NumPy 数组
import numpy as np
test_input = np.random.rand(1, 28, 28) # 根据模型的输入形状修改
prediction = loaded_model.predict(test_input)
print("Prediction:", prediction)
# 加载 HDF5 模型
loaded_model = tf.keras.models.load_model("my_model.h5")
# 推理同上
prediction = loaded_model.predict(test_input)
print("Prediction:", prediction)
在你的服务器上,可以选择以下方案之一来部署模型并提供推理服务:
Flask 是一个轻量级的 Python Web 框架,非常适合搭建简单的推理服务。
pip install flask tensorflow
下面是一个简单的 Flask 应用,加载模型并提供推理接口:
from flask import Flask, request, jsonify
import tensorflow as tf
import numpy as np
# 创建 Flask 应用
app = Flask(__name__)
# 加载模型
model = tf.keras.models.load_model("saved_model/my_model")
# 定义推理接口
@app.route("/predict", methods=["POST"])
def predict():
try:
# 从请求中获取输入数据(假设为 JSON 格式)
data = request.get_json()
if "input" not in data:
return jsonify({"error": "Missing 'input' in request"}), 400
# 将输入转换为 NumPy 数组
input_data = np.array(data["input"])
# 模型预测
predictions = model.predict(input_data).tolist()
# 返回预测结果
return jsonify({"predictions": predictions})
except Exception as e:
return jsonify({"error": str(e)}), 500
# 启动服务
if __name__ == "__main__":
app.run(host="0.0.0.0", port=5000)
运行上述代码后,服务将监听 0.0.0.0:5000
,可以通过 HTTP POST 请求访问推理接口。
python app.py
你可以使用工具(如 curl
或 Postman)发送预测请求。例如:
curl -X POST -H "Content-Type: application/json" -d '{"input": [[1.0, 2.0, 3.0]]}' http://127.0.0.1:5000/predict
返回结果可能如下:
{
"predictions": [[0.1, 0.9, 0.0]]
}
TensorFlow Serving 是官方提供的高性能模型部署工具。需要安装 TensorFlow Serving 并导出模型为 SavedModel
格式。
在 Ubuntu 上可以通过以下方式安装:
# 添加 TensorFlow Serving 源
echo "deb [signed-by=/usr/share/keyrings/tf-serving.gpg] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list
curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo gpg --dearmor -o /usr/share/keyrings/tf-serving.gpg
# 更新并安装
sudo apt-get update && sudo apt-get install tensorflow-model-server
假设你的模型保存在 saved_model/my_model
:
tensorflow_model_server --rest_api_port=8501 --model_name=my_model --model_base_path=$(pwd)/saved_model/my_model
TensorFlow Serving 会在 http://localhost:8501/v1/models/my_model:predict
提供 REST 接口。可以通过 curl
测试:
curl -X POST -H "Content-Type: application/json" -d '{"instances": [[1.0, 2.0, 3.0]]}' http://127.0.0.1:8501/v1/models/my_model:predict
返回结果格式会类似:
{
"predictions": [[0.1, 0.9, 0.0]]
}
确保你的服务器允许外部访问相关端口(如 5000
或 8501
)。在 Linux 上可以使用以下命令打开端口:
sudo ufw allow 5000
sudo ufw allow 8501
如果需要绑定域名或使用 HTTPS,可以使用 Nginx 或 Apache 作为反向代理。以下是 Nginx 的基本配置示例:
server {
listen 80;
server_name your_domain.com;
location / {
proxy_pass http://127.0.0.1:5000; # 或 8501
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
}
}
配置完成后,重启 Nginx:
sudo systemctl restart nginx
model.save
保存为 SavedModel
或 HDF5
格式。tf.keras.models.load_model
加载模型并调用 predict
。根据你的需求选择合适的部署方式。如果是简单项目,Flask 足够用;如果需要高性能或扩展性,推荐 TensorFlow Serving。