Model Serving Architecture: From Trained Model to Production API

Training a model is only half the story. How do you serve that model's predictions to millions of users with low latency and reasonable cost? Model serving is the bridge between ML research and production systems.

In this post, we will explore serving architectures, optimization techniques, and the trade-offs involved in deploying ML models.

Batch Processing vs Online Inference

Batch Processing (Offline Inference)

Process large datasets in batches, with no need for real-time responses.

# Example: Daily recommendation generation
def batch_inference():
    # Load all users (millions)
    users = db.get_all_users()
    
    # Process in batches
    batch_size = 1000
    for i in range(0, len(users), batch_size):
        batch = users[i:i+batch_size]
        
        # Generate recommendations
        recommendations = model.predict(batch)
        
        # Store for later retrieval
        db.save_recommendations(batch, recommendations)

# Run daily via cron
# Users get pre-computed recommendations instantly

Use cases:

  • Nightly ETL jobs
  • Daily recommendations (Netflix, Spotify)
  • Fraud detection on historical transactions
  • Batch translation of documents

Characteristics:

Latency: Minutes to hours (acceptable)
Throughput: Very high (millions per hour)
Cost: Lower (can use cheaper spot instances)
Complexity: Lower (no real-time infrastructure)

Online Inference (Real-time)

Serve predictions immediately on request.

from fastapi import FastAPI
import torch

app = FastAPI()

# Load model once at startup
model = torch.load("model.pt")
model.eval()

@app.post("/predict")
async def predict(data: dict):
    # Receive request; preprocess() (assumed to be defined elsewhere) turns raw JSON into a tensor
    features = preprocess(data)
    
    # Inference
    with torch.no_grad():
        prediction = model(features)
    
    # Return immediately
    return {"prediction": prediction.item()}

# Users get fresh predictions on-demand

Use cases:

  • Search ranking
  • Real-time fraud detection
  • Chatbots
  • Autonomous vehicles
  • Medical diagnosis

Characteristics:

Latency: Milliseconds (critical)
Throughput: Moderate (thousands per second)
Cost: Higher (always-on servers)
Complexity: Higher (load balancing, scaling)

Hybrid Approach

Combine both for optimal cost/performance.

# Batch: Pre-compute common cases
batch_precompute_popular_queries()

# Online: Handle fresh/uncommon requests
@app.post("/search")
def search(query: str):
    # Check cache first (pre-populated by the batch job; redis client assumed to be configured)
    cached = redis.get(query)
    if cached:
        return cached
    
    # Fall back to online inference
    results = model.predict(query)
    return results

Model Serialization Formats

Persisting trained models efficiently.

1. Native Framework Formats

PyTorch (.pt, .pth):

# Save
torch.save(model.state_dict(), "model.pth")

# Load
model = MyModel()
model.load_state_dict(torch.load("model.pth"))
model.eval()

TensorFlow (SavedModel):

# Save
model.save("saved_model/")

# Load
model = tf.keras.models.load_model("saved_model/")

Pros: Simple, stays within the training framework
Cons: Framework-specific; loading a state_dict still requires the original model class; not optimized for inference

2. ONNX (Open Neural Network Exchange)

Cross-framework format for interoperability.

import torch.onnx

# PyTorch → ONNX
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}}
)

# ONNX Runtime (faster than PyTorch/TF)
import onnxruntime as ort

session = ort.InferenceSession("model.onnx")
outputs = session.run(
    None,
    {"input": input_data.numpy()}
)

Benefits:

  • Framework-agnostic (train in PyTorch, serve in C++)
  • Optimized for inference (graph optimization)
  • Often 2-10x faster than eager-mode framework execution (model- and hardware-dependent)

Use when:

  • Need cross-platform deployment
  • Performance critical
  • Serving across different frameworks

3. TensorRT (NVIDIA GPU Optimization)

Optimizes models for NVIDIA GPUs.

import tensorrt as trt

# Convert ONNX → TensorRT
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(TRT_LOGGER)

# Parse ONNX (the ONNX parser requires an explicit-batch network)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
parser = trt.OnnxParser(network, TRT_LOGGER)
parser.parse_from_file("model.onnx")

# Optimize
config = builder.create_builder_config()
config.max_workspace_size = 1 << 30  # 1GB

# Build engine (TensorRT 8.x API; newer releases use build_serialized_network)
engine = builder.build_engine(network, config)

# Inference
context = engine.create_execution_context()
# ... (bindings setup)
context.execute_v2(bindings)

Optimizations:

  • Layer fusion (merge operations)
  • Precision calibration (FP32 → FP16 → INT8)
  • Kernel auto-tuning

Performance gain: typically 2-5x faster than ONNX Runtime on NVIDIA GPUs
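
As a concrete illustration of precision calibration, here is a minimal sketch (assuming the builder and config objects from the snippet above and a TensorRT 8.x install) that enables FP16, and optionally INT8, on the config before the engine is built:

# Enable FP16 kernels where the hardware supports them
if builder.platform_has_fast_fp16:
    config.set_flag(trt.BuilderFlag.FP16)

# INT8 additionally requires a calibrator fed with representative inputs
# (a subclass of trt.IInt8EntropyCalibrator2 - omitted here for brevity)
# if builder.platform_has_fast_int8:
#     config.set_flag(trt.BuilderFlag.INT8)
#     config.int8_calibrator = my_calibrator  # hypothetical calibrator object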

Use when:

  • Deploying on NVIDIA hardware
  • Maximum GPU performance needed
  • Can tolerate longer build times

4. TorchScript

PyTorch's JIT compilation.

# Trace model
traced_model = torch.jit.trace(model, example_input)
traced_model.save("model_traced.pt")

# Or script (better for control flow)
scripted_model = torch.jit.script(model)
scripted_model.save("model_scripted.pt")

# Load in C++
# torch::jit::load("model_traced.pt");

Benefits:

  • Deploy PyTorch models without Python runtime
  • Faster than eager execution
  • Mobile deployment (PyTorch Mobile)

Quantization - Reduce Model Size & Latency

Reduce precision of weights and activations.

Precision Levels

FP32 (32-bit float):  4 bytes per parameter
FP16 (16-bit float):  2 bytes (50% reduction)
INT8 (8-bit integer): 1 byte  (75% reduction)
INT4 (4-bit integer): 0.5 byte (87.5% reduction)
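
To make these numbers concrete, a quick back-of-the-envelope calculation (using a hypothetical 7-billion-parameter model) shows how precision alone determines the memory needed just to hold the weights:

# Weight-storage footprint at different precisions (parameter count is hypothetical)
params = 7_000_000_000
for precision, bytes_per_param in [("FP32", 4), ("FP16", 2), ("INT8", 1), ("INT4", 0.5)]:
    print(f"{precision}: {params * bytes_per_param / 1e9:.1f} GB")
# FP32: 28.0 GB, FP16: 14.0 GB, INT8: 7.0 GB, INT4: 3.5 GB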

Post-Training Quantization (PTQ)

Quantize after training (no retraining needed).

import torch.quantization

# Dynamic quantization (easiest)
model_quantized = torch.quantization.quantize_dynamic(
    model,
    {torch.nn.Linear},  # Layers to quantize
    dtype=torch.qint8
)

# 4x smaller, 2-4x faster on CPU

Static quantization (better quality):

# Requires calibration data
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
model_prepared = torch.quantization.prepare(model)

# Calibrate with representative data
for data in calibration_dataset:
    model_prepared(data)

# Convert
model_quantized = torch.quantization.convert(model_prepared)

Quantization-Aware Training (QAT)

Train with quantization in mind (best quality).

model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model_prepared = torch.quantization.prepare_qat(model)

# Train normally
for epoch in range(num_epochs):
    train(model_prepared)

# Convert to quantized
model_quantized = torch.quantization.convert(model_prepared)

Accuracy vs Performance Trade-off

Method              | Size    | Speed   | Accuracy
--------------------|---------|---------|----------
FP32 (baseline)     | 100%    | 1x      | 100%
FP16                | 50%     | 1.5-2x  | 99.9%
INT8 (PTQ)          | 25%     | 2-4x    | 98-99%
INT8 (QAT)          | 25%     | 2-4x    | 99.5%+
INT4                | 12.5%   | 3-5x    | 95-98%

General rule:

  • FP16: Almost free performance boost (see the sketch after this list)
  • INT8 PTQ: Good for most models
  • INT8 QAT: When accuracy matters
  • INT4: Only if accuracy loss acceptable
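
For GPU inference, the FP16 path is often just a cast away. A minimal sketch (assuming a CUDA-capable GPU, the model object from the earlier snippets, and a placeholder input_tensor; some architectures prefer torch.autocast instead of a full cast):

# Cast weights and inputs to half precision for GPU inference
model_fp16 = model.half().cuda().eval()

with torch.no_grad():
    output = model_fp16(input_tensor.half().cuda())
# Roughly half the memory; the speedup depends on the GPU's FP16 / Tensor Core support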

Deployment Patterns

1. Serverless Inference

Deploy models as cloud functions (AWS Lambda, Cloud Functions).

# AWS Lambda handler
import json
import boto3
import onnxruntime as ort

# Load model at cold start (module scope, so warm invocations reuse it)
s3 = boto3.client('s3')
s3.download_file('my-bucket', 'model.onnx', '/tmp/model.onnx')

session = ort.InferenceSession('/tmp/model.onnx')

def lambda_handler(event, context):
    # Parse input
    data = json.loads(event['body'])
    
    # Inference
    outputs = session.run(None, {"input": data['features']})
    
    return {
        'statusCode': 200,
        'body': json.dumps({'prediction': outputs[0].tolist()})
    }

Pros:

  • No server management
  • Auto-scaling
  • Pay per request

Cons:

  • Cold start latency (1-10s)
  • Size limits (250MB Lambda)
  • Timeout limits (15min Lambda)

Best for:

  • Infrequent inference
  • Variable load
  • Small models (<250MB)

2. Dedicated Instances

Run model servers on EC2/GCE/Azure VMs.

# FastAPI server
from fastapi import FastAPI
import torch

app = FastAPI()

# Load on startup (no cold start)
@app.on_event("startup")
async def load_model():
    global model
    model = torch.load("model.pt")
    model.eval()

@app.post("/predict")
async def predict(features: list):
    with torch.no_grad():
        prediction = model(torch.tensor(features))
    return {"prediction": prediction.item()}

# Run with Gunicorn
# gunicorn -w 4 -k uvicorn.workers.UvicornWorker main:app

Pros:

  • No cold starts
  • Full control
  • No size limits

Cons:

  • Always-on cost
  • Manual scaling
  • Infrastructure management

Best for:

  • Consistent high traffic
  • Large models
  • Low latency requirements

3. Managed Services

Use cloud ML serving (AWS SageMaker, GCP Vertex AI).

# AWS SageMaker deployment
import sagemaker

# Create model (role is your SageMaker execution role; image_uri left as a placeholder)
model = sagemaker.Model(
    model_data="s3://bucket/model.tar.gz",
    image_uri="<docker-image>",
    role=role
)

# Deploy
predictor = model.deploy(
    instance_type="ml.g4dn.xlarge",
    initial_instance_count=2,
    endpoint_name="my-model-endpoint"
)

# Predict
result = predictor.predict(data)

Pros:

  • Built-in auto-scaling
  • A/B testing
  • Model monitoring
  • Multi-model endpoints

Cons:

  • Higher cost
  • Vendor lock-in
  • Less control

Best for:

  • Enterprise production
  • Multiple models
  • Need monitoring/governance

4. Edge Deployment

Deploy on devices (mobile, IoT, embedded).

# TensorFlow Lite (mobile)
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model("saved_model/")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Save
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)

# Android/iOS loads this model locally

Requirements:

  • Small model size (<10MB ideal)
  • Fast inference (<100ms)
  • Low power consumption

Optimizations:

  • Quantization (INT8/INT4)
  • Pruning (remove weights; see the sketch after this list)
  • Knowledge distillation (small student model)
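
For example, magnitude pruning can be applied to an existing PyTorch model with the built-in pruning utilities. A minimal sketch (the 30% ratio is an arbitrary example, and model is assumed to be the model from earlier snippets):

import torch
import torch.nn.utils.prune as prune

# Zero out the 30% smallest-magnitude weights in every Linear layer (L1 criterion)
for module in model.modules():
    if isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name="weight", amount=0.3)
        prune.remove(module, "weight")  # bake the zeros into the weight tensor permanently

Note that unstructured pruning only zeroes weights; the size and latency wins come when it is combined with sparse storage, structured pruning, or a follow-up quantization step.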

Best for:

  • Privacy-sensitive apps
  • Offline functionality
  • Low latency (<10ms)

Batching Strategies

Improve throughput by processing multiple requests together.

Dynamic Batching

Collect requests and batch them.

import asyncio
from collections import deque

import numpy as np

class DynamicBatcher:
    def __init__(self, max_batch_size=32, max_wait_ms=10):
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self.queue = deque()
        self.processing = False
    
    async def predict(self, input_data):
        # Add to queue
        future = asyncio.Future()
        self.queue.append((input_data, future))
        
        # Start batch processing if not running
        if not self.processing:
            asyncio.create_task(self.process_batch())
        
        # Wait for result
        return await future
    
    async def process_batch(self):
        self.processing = True
        
        # Wait for batch to fill or timeout
        await asyncio.sleep(self.max_wait_ms / 1000)
        
        # Collect batch
        batch_data = []
        batch_futures = []
        
        while self.queue and len(batch_data) < self.max_batch_size:
            data, future = self.queue.popleft()
            batch_data.append(data)
            batch_futures.append(future)
        
        if batch_data:
            # Batch inference
            results = model.predict(np.array(batch_data))
            
            # Distribute results
            for future, result in zip(batch_futures, results):
                future.set_result(result)
        
        self.processing = False
        
        # Continue if more in queue
        if self.queue:
            asyncio.create_task(self.process_batch())

# Usage
batcher = DynamicBatcher()

@app.post("/predict")
async def predict(data: dict):
    result = await batcher.predict(data['features'])
    return {"prediction": result}

Benefits:

  • Better GPU utilization (GPUs love batches)
  • Higher throughput (10-100x)

Trade-offs:

  • Slightly higher latency (wait time)
  • Complexity

NVIDIA Triton Inference Server

Production-grade server with built-in batching.

# model_config.pbtxt
name: "my_model"
platform: "onnxruntime_onnx"
max_batch_size: 32

dynamic_batching {
  preferred_batch_size: [ 8, 16, 32 ]
  max_queue_delay_microseconds: 10000
}

input [
  {
    name: "input"
    data_type: TYPE_FP32
    dims: [ 224, 224, 3 ]
  }
]

output [
  {
    name: "output"
    data_type: TYPE_FP32
    dims: [ 1000 ]
  }
]

Features:

  • Multi-framework (TensorRT, ONNX, PyTorch, TensorFlow)
  • Dynamic batching
  • Model versioning
  • Ensemble models (pipeline)
  • Concurrent model execution
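
On the client side, calling such an endpoint might look like the following sketch (assuming the tritonclient package, a Triton server on localhost:8000, and the model/tensor names from the config above):

import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")

# Build a request matching the config: one image, FP32, 224x224x3
image = np.random.rand(1, 224, 224, 3).astype(np.float32)
infer_input = httpclient.InferInput("input", list(image.shape), "FP32")
infer_input.set_data_from_numpy(image)

response = client.infer(model_name="my_model", inputs=[infer_input])
logits = response.as_numpy("output")  # shape (1, 1000)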

Monitoring & Observability

Track model performance in production.

from prometheus_client import Counter, Histogram
import time

# Metrics
prediction_counter = Counter('predictions_total', 'Total predictions')
prediction_latency = Histogram('prediction_latency_seconds', 'Prediction latency')
prediction_errors = Counter('prediction_errors_total', 'Prediction errors')

@app.post("/predict")
async def predict(data: dict):
    start_time = time.time()
    
    try:
        result = model.predict(data['features'])
        prediction_counter.inc()
        return {"prediction": result}
    except Exception as e:
        prediction_errors.inc()
        raise
    finally:
        prediction_latency.observe(time.time() - start_time)

Key metrics:

  • Throughput (requests/second)
  • Latency (p50, p95, p99)
  • Error rate
  • GPU utilization
  • Model drift (input distribution changes; see the sketch below)
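
As one way to watch for input drift, a simple per-feature statistical test can compare live traffic against a reference sample from training data. A minimal sketch (assuming SciPy is available and feature values are being logged):

from scipy.stats import ks_2samp

def feature_drifted(reference_values, live_values, alpha=0.01):
    """Two-sample Kolmogorov-Smirnov test on one numeric feature."""
    statistic, p_value = ks_2samp(reference_values, live_values)
    return p_value < alpha  # True -> the live distribution differs significantly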

Key Takeaways

  • Batch inference for offline jobs (high throughput, low cost)
  • Online inference for real-time (low latency, higher cost)
  • Model formats: Native → ONNX (portable) → TensorRT (GPU optimized)
  • Quantization: FP16 (almost free), INT8 (2-4x faster), INT4 (aggressive)
  • Deployment patterns: Serverless (variable load), Dedicated (consistent), Managed (enterprise), Edge (offline/privacy)
  • Dynamic batching improves GPU utilization (10-100x throughput)
  • Monitor: Latency, throughput, errors, drift

In the next post, we will explore MLOps Methodology: data drift, model drift, feature stores, and automated retraining pipelines.


This article is part of the series "From Zero to AI Engineer" - Module 9: Deployment Strategy