
Model Deployment

Deploy ML Models with FastAPI, Docker, and Cloud Services


The model is trained; now it needs to be deployed so users can actually use it. This lesson covers the full deployment stack: FastAPI → Docker → Cloud.

🎯 Objectives

  • Build REST API với FastAPI
  • Containerize với Docker
  • Batch vs Real-time serving
  • Cloud deployment options

1. Model Serving Patterns

1.1 Overview

Serving Patterns

| | Real-time (Online) | Batch (Offline) |
| --- | --- | --- |
| Interface | REST API, gRPC | Scheduled jobs |
| Speed | < 100ms | Daily/hourly predictions |
| Volume | One-at-a-time | Millions at once |
| Examples | Fraud detection, chatbot | Email campaign targeting, product recommendations, churn scoring |
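To make the "millions at once" column concrete, here is a minimal sketch of a batch scoring job. In production the model would come from `joblib.load("model_artifacts/model.pkl")` and the customer table from a warehouse query; here a tiny stand-in model is fitted inline so the sketch is self-contained, and the column names are illustrative.

```python
# Batch scoring sketch: score an entire table in one vectorized call.
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression

# Stand-in for a real trained model (in production: joblib.load(...))
rng = np.random.default_rng(0)
X_train = rng.normal(size=(200, 4))
y_train = (X_train[:, 0] + X_train[:, 1] > 0).astype(int)
model = LogisticRegression().fit(X_train, y_train)

# Stand-in for a warehouse query returning the full customer table
customers = pd.DataFrame(
    rng.normal(size=(1000, 4)),
    columns=["age", "monthly_spend", "tenure_months", "support_tickets"],
)

# One predict_proba call scores every row at once -- no per-request API overhead
customers["churn_probability"] = model.predict_proba(customers.values)[:, 1]
# customers.to_csv("churn_scores.csv", index=False)  # or write back to the warehouse
```

A scheduler (cron, Airflow) would run this hourly or daily, which is exactly the "Scheduled jobs" interface from the table above.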

1.2 Model Serialization

Python
import joblib
import json
import os

# Save model + metadata
def save_model(model, preprocessor, metadata, path="model_artifacts"):
    os.makedirs(path, exist_ok=True)

    joblib.dump(model, f"{path}/model.pkl")
    joblib.dump(preprocessor, f"{path}/preprocessor.pkl")

    with open(f"{path}/metadata.json", "w") as f:
        json.dump(metadata, f, indent=2)

    print(f"Model saved to {path}/")

save_model(
    model=trained_model,
    preprocessor=pipeline_preprocessor,
    metadata={
        "model_type": "GradientBoosting",
        "version": "1.2.0",
        "accuracy": 0.94,
        "features": feature_names,
        "training_date": "2025-01-15"
    }
)
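The serving side needs a counterpart that reloads these artifacts. A sketch, assuming the directory layout produced by `save_model` above; the `expected_keys` check is an illustrative addition, not part of the original:

```python
# Reload artifacts and sanity-check the metadata before serving.
import json
import joblib

def load_model(path="model_artifacts"):
    model = joblib.load(f"{path}/model.pkl")
    preprocessor = joblib.load(f"{path}/preprocessor.pkl")
    with open(f"{path}/metadata.json") as f:
        metadata = json.load(f)

    # Fail fast if the artifact predates the schema this service expects
    expected_keys = {"model_type", "version", "features"}
    missing = expected_keys - metadata.keys()
    if missing:
        raise ValueError(f"metadata.json missing keys: {missing}")

    return model, preprocessor, metadata
```

Validating metadata at load time catches version mismatches at startup instead of at the first prediction request.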

2. REST API with FastAPI

2.1 Basic API

Python
# api.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
import joblib
import numpy as np

app = FastAPI(
    title="Churn Prediction API",
    version="1.0.0"
)

# Load model at startup
model = joblib.load("model_artifacts/model.pkl")
preprocessor = joblib.load("model_artifacts/preprocessor.pkl")

# Request schema
class CustomerData(BaseModel):
    age: int = Field(ge=18, le=100, description="Customer age")
    monthly_spend: float = Field(ge=0, description="Monthly spending")
    tenure_months: int = Field(ge=0, description="Months as customer")
    support_tickets: int = Field(ge=0, description="Support tickets filed")
    contract_type: str = Field(description="Contract: monthly/annual/two_year")

# Response schema
class PredictionResponse(BaseModel):
    churn_probability: float
    prediction: str
    confidence: float

@app.post("/predict", response_model=PredictionResponse)
def predict(data: CustomerData):
    try:
        # Convert to array (contract_type is validated above but not used
        # by this simplified 4-feature model)
        features = np.array([[
            data.age,
            data.monthly_spend,
            data.tenure_months,
            data.support_tickets
        ]])

        # Preprocess
        features_processed = preprocessor.transform(features)

        # Predict
        proba = model.predict_proba(features_processed)[0]
        churn_prob = float(proba[1])
        prediction = "churn" if churn_prob > 0.5 else "no_churn"
        confidence = float(max(proba))

        return PredictionResponse(
            churn_probability=round(churn_prob, 4),
            prediction=prediction,
            confidence=round(confidence, 4)
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
def health():
    return {"status": "healthy", "model_version": "1.0.0"}

2.2 Batch Endpoint

Python
from typing import List

class BatchRequest(BaseModel):
    customers: List[CustomerData]

class BatchResponse(BaseModel):
    predictions: List[PredictionResponse]
    total: int

@app.post("/predict/batch", response_model=BatchResponse)
def predict_batch(request: BatchRequest):
    if len(request.customers) > 1000:
        raise HTTPException(400, "Max 1000 records per batch")

    features = np.array([
        [c.age, c.monthly_spend, c.tenure_months, c.support_tickets]
        for c in request.customers
    ])

    features_processed = preprocessor.transform(features)
    probas = model.predict_proba(features_processed)

    predictions = []
    for proba in probas:
        churn_prob = float(proba[1])
        predictions.append(PredictionResponse(
            churn_probability=round(churn_prob, 4),
            prediction="churn" if churn_prob > 0.5 else "no_churn",
            confidence=round(float(max(proba)), 4)
        ))

    return BatchResponse(predictions=predictions, total=len(predictions))

2.3 Run & Test

Bash
# Start server
uvicorn api:app --host 0.0.0.0 --port 8000 --reload

# Test (another terminal)
curl -X POST "http://localhost:8000/predict" \
  -H "Content-Type: application/json" \
  -d '{"age": 35, "monthly_spend": 89.5, "tenure_months": 24, "support_tickets": 3, "contract_type": "monthly"}'
Python
# Python client
import requests

response = requests.post("http://localhost:8000/predict", json={
    "age": 35,
    "monthly_spend": 89.5,
    "tenure_months": 24,
    "support_tickets": 3,
    "contract_type": "monthly"
})
print(response.json())
# {"churn_probability": 0.7234, "prediction": "churn", "confidence": 0.7234}

3. Docker Containerization

3.1 Dockerfile

dockerfile
# Dockerfile
FROM python:3.11-slim

WORKDIR /app

# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application
COPY api.py .
COPY model_artifacts/ model_artifacts/

# Expose port
EXPOSE 8000

# Health check (python:3.11-slim ships without curl, so probe with the stdlib)
HEALTHCHECK CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1

# Run
CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "8000"]

3.2 Requirements

Text
# requirements.txt
fastapi==0.109.0
uvicorn==0.27.0
scikit-learn==1.4.0
joblib==1.3.2
numpy==1.26.3
pydantic==2.5.3
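A `.dockerignore` alongside the Dockerfile keeps training data, notebooks, and the git history out of the build context, which speeds up builds and shrinks the image. The entries below are illustrative and depend on your repo layout:

```
# .dockerignore
__pycache__/
*.ipynb
.git/
data/
tests/
.venv/
```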

3.3 Build & Run

Bash
# Build image
docker build -t churn-model:v1.0 .

# Run container
docker run -d -p 8000:8000 --name churn-api churn-model:v1.0

# Test
curl http://localhost:8000/health

# View logs
docker logs churn-api

# Stop
docker stop churn-api

3.4 Docker Compose (with monitoring)

yaml
# docker-compose.yml
version: '3.8'

services:
  model-api:
    build: .
    ports:
      - "8000:8000"
    environment:
      - MODEL_VERSION=1.0.0
      - LOG_LEVEL=info
    volumes:
      - ./model_artifacts:/app/model_artifacts
    restart: unless-stopped
    healthcheck:
      # python:3.11-slim has no curl; probe with the stdlib instead
      test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]
      interval: 30s
      timeout: 10s
      retries: 3

  prometheus:
    image: prom/prometheus
    ports:
      - "9090:9090"
    volumes:
      - ./prometheus.yml:/etc/prometheus/prometheus.yml

  grafana:
    image: grafana/grafana
    ports:
      - "3000:3000"
    depends_on:
      - prometheus
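The compose file mounts a `./prometheus.yml` that is not shown. A minimal sketch of that scrape config, assuming the API exposes a `/metrics` endpoint (e.g. via an instrumentation library); the job and target names are illustrative:

```yaml
# prometheus.yml -- minimal scrape config for the compose setup above
global:
  scrape_interval: 15s

scrape_configs:
  - job_name: "model-api"
    metrics_path: /metrics
    static_configs:
      # "model-api" resolves via the compose network's service name
      - targets: ["model-api:8000"]
```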

4. Cloud Deployment Options

4.1 Comparison

| Platform | Pros | Cons | Best For |
| --- | --- | --- | --- |
| AWS SageMaker | Full ML platform, auto-scaling | Complex, expensive | Enterprise |
| GCP Vertex AI | Good integration | GCP lock-in | GCP users |
| Azure ML | Enterprise features | Complex pricing | MS shops |
| Hugging Face | Easy, free tier | Limited compute | NLP/demo |
| Railway/Render | Simple deploy | Basic features | MVPs/startups |
| Self-hosted (K8s) | Full control | High maintenance | Large teams |

4.2 Deploy to Hugging Face Spaces (Free)

Python
# app.py (Gradio version for HF Spaces)
import gradio as gr
import joblib
import numpy as np

model = joblib.load("model_artifacts/model.pkl")

def predict_churn(age, monthly_spend, tenure, tickets):
    features = np.array([[age, monthly_spend, tenure, tickets]])
    proba = model.predict_proba(features)[0]
    churn_prob = proba[1]

    label = "High Risk" if churn_prob > 0.7 else "Medium Risk" if churn_prob > 0.4 else "Low Risk"
    return {
        "Churn Probability": f"{churn_prob:.1%}",
        "Risk Level": label
    }

demo = gr.Interface(
    fn=predict_churn,
    inputs=[
        gr.Slider(18, 80, value=35, label="Age"),
        gr.Number(value=89.5, label="Monthly Spend ($)"),
        gr.Slider(0, 120, value=24, label="Tenure (months)"),
        gr.Slider(0, 20, value=2, label="Support Tickets")
    ],
    outputs=gr.JSON(label="Prediction"),
    title="Customer Churn Predictor",
    description="Predict customer churn probability"
)

demo.launch()

4.3 Deploy with BentoML

Python
# service.py
import bentoml
import numpy as np
from bentoml.io import JSON

# Save the model to the BentoML store first:
# bentoml.sklearn.save_model("churn_model", trained_model)

runner = bentoml.sklearn.get("churn_model:latest").to_runner()
svc = bentoml.Service("churn_prediction", runners=[runner])

@svc.api(input=JSON(), output=JSON())
def predict(input_data: dict) -> dict:
    features = np.array([[
        input_data["age"],
        input_data["monthly_spend"],
        input_data["tenure_months"],
        input_data["support_tickets"]
    ]])

    proba = runner.predict_proba.run(features)[0]
    return {
        "churn_probability": float(proba[1]),
        "prediction": "churn" if proba[1] > 0.5 else "no_churn"
    }
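To package this service into a deployable bento, BentoML reads a `bentofile.yaml` build config. A sketch, assuming the `service.py` above; file names and package pins are illustrative:

```yaml
# bentofile.yaml -- consumed by `bentoml build`
service: "service:svc"
include:
  - "service.py"
python:
  packages:
    - scikit-learn
    - numpy
```

After `bentoml build`, `bentoml containerize churn_prediction:latest` produces a Docker image ready for any of the cloud options above.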

5. Production Best Practices

5.1 API Design Checklist

| Practice | Details |
| --- | --- |
| Input validation | Pydantic models, type checks |
| Error handling | Proper HTTP codes, error messages |
| Logging | Request/response logging, prediction logging |
| Versioning | /v1/predict, /v2/predict |
| Rate limiting | Prevent abuse |
| Authentication | API keys, JWT tokens |
| Documentation | Auto-generated OpenAPI/Swagger |
| Health check | /health endpoint |

5.2 Logging Predictions

Python
import logging
from datetime import datetime

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("prediction_logger")

@app.post("/predict")
def predict(data: CustomerData):
    start_time = datetime.now()

    # ... prediction logic ...

    # Log for monitoring
    latency_ms = (datetime.now() - start_time).total_seconds() * 1000
    logger.info(
        f"prediction_log | "
        f"input={data.model_dump()} | "  # .dict() is deprecated in Pydantic v2
        f"output={prediction} | "
        f"probability={churn_prob:.4f} | "
        f"latency_ms={latency_ms:.1f} | "
        f"model_version=1.0.0"
    )

    return response

📝 Quiz

  1. How do real-time serving and batch serving differ?

    • Batch is more accurate
    • Real-time serves one request at a time (under 100ms); batch processes records in bulk (minutes are fine)
    • Real-time needs no API
    • Batch needs no model
  2. Why use Docker for ML deployment?

    • It runs faster
    • Reproducible environment, consistent across dev/staging/prod
    • It is mandatory
    • Docker trains models better
  3. Why is FastAPI popular for ML serving?

    • Async, auto-docs, type validation, high performance
    • Only FastAPI can deploy ML
    • It is free
    • Google created it

🎯 Key Takeaways

  1. FastAPI — best Python framework for ML APIs
  2. Docker — the standard for containerized deployment
  3. Pydantic — input validation is essential in production
  4. Health checks — monitor API availability
  5. Logging — log every prediction for monitoring

🚀 Next Lesson

Feature Store & Model Monitoring — feature engineering at scale and monitoring model drift!