Training a model is only half the story. How do you serve that model's predictions to millions of users with low latency and reasonable cost? Model serving is the bridge between ML research and production systems.
In this post, we will explore serving architectures, optimization techniques, and the trade-offs involved in deploying ML models.
Process large datasets in batches, with no need for real-time responses.
# Example: Daily recommendation generation
def batch_inference():
    # Load all users (millions)
    users = db.get_all_users()

    # Process in batches
    batch_size = 1000
    for i in range(0, len(users), batch_size):
        batch = users[i:i+batch_size]

        # Generate recommendations
        recommendations = model.predict(batch)

        # Store for later retrieval
        db.save_recommendations(batch, recommendations)
# Run daily via cron
# Users get pre-computed recommendations instantly
Use cases:
Characteristics:
Latency: Minutes to hours (acceptable)
Throughput: Very high (millions per hour)
Cost: Lower (can use cheaper spot instances)
Complexity: Lower (no real-time infrastructure)
Serve predictions immediately on request.
from fastapi import FastAPI
import torch
app = FastAPI()
# Load model once at startup
model = torch.load("model.pt")
model.eval()
@app.post("/predict")
async def predict(data: dict):
# Receive request
features = preprocess(data)
# Inference
with torch.no_grad():
prediction = model(features)
# Return immediately
return {"prediction": prediction.item()}
# Users get fresh predictions on-demand
Use cases:
Characteristics:
Latency: Milliseconds (critical)
Throughput: Moderate (thousands per second)
Cost: Higher (always-on servers)
Complexity: Higher (load balancing, scaling)
Combine both for optimal cost/performance.
# Batch: Pre-compute common cases
batch_precompute_popular_queries()
# Online: Handle fresh/uncommon requests
@app.post("/search")
def search(query: str):
# Check cache (from batch)
cached = redis.get(query)
if cached:
return cached
# Fall back to online inference
results = model.predict(query)
return results
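The batch side of this pattern just needs to populate the same cache the online path reads from. A minimal sketch of what batch_precompute_popular_queries could look like, assuming the same redis client object, JSON-serializable predictions, and a hypothetical get_popular_queries() helper that pulls the top queries from recent logs:

import json

def batch_precompute_popular_queries():
    # Hypothetical helper: top-N queries from recent logs
    for query in get_popular_queries(limit=10_000):
        results = model.predict(query)
        # Store as JSON with a 24h TTL so stale entries age out before the next run
        # (the online path would json.loads() what it reads back)
        redis.set(query, json.dumps(results), ex=86400)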
Persisting trained models efficiently.
PyTorch (.pt, .pth):
# Save
torch.save(model.state_dict(), "model.pth")
# Load
model = MyModel()
model.load_state_dict(torch.load("model.pth"))
model.eval()
TensorFlow (SavedModel):
import tensorflow as tf

# Save
model.save("saved_model/")

# Load
model = tf.keras.models.load_model("saved_model/")
Pros: Easy, preserves full model
Cons: Framework-specific, not optimized for inference
Cross-framework format for interoperability.
import torch.onnx
# PyTorch → ONNX
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}}
)
# ONNX Runtime (often faster than eager PyTorch/TF for inference)
import onnxruntime as ort
session = ort.InferenceSession("model.onnx")
outputs = session.run(
    None,
    {"input": input_data.numpy()}
)
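Before wiring the exported file into a serving stack, it is worth validating it. A small sketch using the onnx package's checker plus a numerical parity check against the original model (reusing model, dummy_input, and session from the snippets above):

import numpy as np
import onnx

# Structural validation of the exported graph
onnx.checker.check_model(onnx.load("model.onnx"))

# Numerical parity check: PyTorch output vs ONNX Runtime output
with torch.no_grad():
    torch_out = model(dummy_input).numpy()
ort_out = session.run(None, {"input": dummy_input.numpy()})[0]
np.testing.assert_allclose(torch_out, ort_out, rtol=1e-3, atol=1e-5)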
Benefits:
Use when:
Optimizes models for NVIDIA GPUs.
import tensorrt as trt
# Convert ONNX → TensorRT
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(TRT_LOGGER)
# Parse ONNX
network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))  # explicit batch required for ONNX
parser = trt.OnnxParser(network, TRT_LOGGER)
parser.parse_from_file("model.onnx")
# Optimize
config = builder.create_builder_config()
config.max_workspace_size = 1 << 30 # 1GB
# Build engine
engine = builder.build_engine(network, config)
# Inference
context = engine.create_execution_context()
# ... (bindings setup)
context.execute_v2(bindings)
Optimizations:
Performance gain: 2-5x faster than ONNX on NVIDIA GPUs
Use when:
PyTorch's JIT compilation.
# Trace model
traced_model = torch.jit.trace(model, example_input)
traced_model.save("model_traced.pt")
# Or script (better for control flow)
scripted_model = torch.jit.script(model)
scripted_model.save("model_scripted.pt")
# Load in C++
# torch::jit::load("model_traced.pt");
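For Python-side serving, loading the saved TorchScript artifact is symmetric and does not require the original model class definition. A minimal sketch:

# Load and run the TorchScript artifact from Python
loaded = torch.jit.load("model_traced.pt")
loaded.eval()

with torch.no_grad():
    output = loaded(example_input)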
Benefits:
Reduce precision of weights and activations.
FP32 (32-bit float): 4 bytes per parameter
FP16 (16-bit float): 2 bytes (50% reduction)
INT8 (8-bit integer): 1 byte (75% reduction)
INT4 (4-bit integer): 0.5 byte (87.5% reduction)
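To make those numbers concrete, here is the weight-only memory footprint of a hypothetical 7-billion-parameter model at each precision (activations and KV cache excluded):

# Rough weight-only memory footprint of a hypothetical 7B-parameter model
params = 7_000_000_000
bytes_per_param = {"FP32": 4, "FP16": 2, "INT8": 1, "INT4": 0.5}

for dtype, nbytes in bytes_per_param.items():
    print(f"{dtype}: ~{params * nbytes / 1e9:.1f} GB")
# FP32: ~28.0 GB, FP16: ~14.0 GB, INT8: ~7.0 GB, INT4: ~3.5 GB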
Quantize after training (no retraining needed).
import torch.quantization
# Dynamic quantization (easiest)
model_quantized = torch.quantization.quantize_dynamic(
    model,
    {torch.nn.Linear},  # Layers to quantize
    dtype=torch.qint8
)
# 4x smaller, 2-4x faster on CPU
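The size claim is easy to verify on your own model; a quick sketch comparing serialized state_dict sizes (file names are examples, and the exact ratio depends on how much of the model is nn.Linear):

import os

torch.save(model.state_dict(), "model_fp32.pth")
torch.save(model_quantized.state_dict(), "model_int8.pth")

fp32_mb = os.path.getsize("model_fp32.pth") / 1e6
int8_mb = os.path.getsize("model_int8.pth") / 1e6
print(f"FP32: {fp32_mb:.1f} MB, INT8: {int8_mb:.1f} MB")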
Static quantization (better quality):
# Requires calibration data
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
model_prepared = torch.quantization.prepare(model)
# Calibrate with representative data
for data in calibration_dataset:
    model_prepared(data)
# Convert
model_quantized = torch.quantization.convert(model_prepared)
Train with quantization in mind (best quality).
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model_prepared = torch.quantization.prepare_qat(model)
# Train normally
for epoch in range(num_epochs):
    train(model_prepared)
# Convert to quantized
model_quantized = torch.quantization.convert(model_prepared)
Method | Size | Speed | Accuracy
--------------------|---------|---------|----------
FP32 (baseline) | 100% | 1x | 100%
FP16 | 50% | 1.5-2x | 99.9%
INT8 (PTQ) | 25% | 2-4x | 98-99%
INT8 (QAT) | 25% | 2-4x | 99.5%+
INT4 | 12.5% | 3-5x | 95-98%
General rule:
Deploy models as cloud functions (AWS Lambda, Cloud Functions).
# AWS Lambda handler
import json
import boto3
import onnxruntime as ort

# Download and load model at cold start
s3 = boto3.client('s3')
s3.download_file('my-bucket', 'model.onnx', '/tmp/model.onnx')
session = ort.InferenceSession('/tmp/model.onnx')

def lambda_handler(event, context):
    # Parse input
    data = json.loads(event['body'])

    # Inference
    outputs = session.run(None, {"input": data['features']})

    return {
        'statusCode': 200,
        'body': json.dumps({'prediction': outputs[0].tolist()})
    }
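The handler can be smoke-tested locally before deploying by passing it a hand-built API Gateway style event; the feature shape below is a placeholder for whatever your model actually expects:

# Local smoke test with a fake API Gateway event
fake_event = {'body': json.dumps({'features': [[0.1, 0.2, 0.3]]})}
print(lambda_handler(fake_event, None))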
Pros:
Cons:
Best for:
Run model servers on EC2/GCE/Azure VMs.
# FastAPI server
from fastapi import FastAPI
import torch
app = FastAPI()
# Load on startup (no cold start)
@app.on_event("startup")
async def load_model():
global model
model = torch.load("model.pt")
model.eval()
@app.post("/predict")
async def predict(features: list):
with torch.no_grad():
prediction = model(torch.tensor(features))
return {"prediction": prediction.item()}
# Run with Gunicorn
# gunicorn -w 4 -k uvicorn.workers.UvicornWorker main:app
Pros:
Cons:
Best for:
Use cloud ML serving (AWS SageMaker, GCP Vertex AI).
# AWS SageMaker deployment
import sagemaker
# Create model
model = sagemaker.Model(
    model_data="s3://bucket/model.tar.gz",
    image_uri="<docker-image>",
    role=role
)

# Deploy
predictor = model.deploy(
    instance_type="ml.g4dn.xlarge",
    initial_instance_count=2,
    endpoint_name="my-model-endpoint"
)
# Predict
result = predictor.predict(data)
Pros:
Cons:
Best for:
Deploy on devices (mobile, IoT, embedded).
# TensorFlow Lite (mobile)
import tensorflow as tf

converter = tf.lite.TFLiteConverter.from_saved_model("saved_model/")
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model = converter.convert()

# Save
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)

# Android/iOS loads this model locally
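Before shipping the .tflite file to devices, it can be sanity-checked from Python with the TFLite interpreter. A minimal sketch that feeds a random tensor of the model's expected input shape:

import numpy as np

# Run the converted model with the TFLite interpreter
interpreter = tf.lite.Interpreter(model_path="model.tflite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Feed a random tensor matching the model's expected input shape
dummy = np.random.rand(*input_details[0]['shape']).astype(np.float32)
interpreter.set_tensor(input_details[0]['index'], dummy)
interpreter.invoke()
print(interpreter.get_tensor(output_details[0]['index']).shape)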
Requirements:
Optimizations:
Best for:
Improve throughput by processing multiple requests together.
Collect requests and batch them.
import asyncio
from collections import deque

import numpy as np

class DynamicBatcher:
    def __init__(self, max_batch_size=32, max_wait_ms=10):
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self.queue = deque()
        self.processing = False

    async def predict(self, input_data):
        # Add to queue
        future = asyncio.Future()
        self.queue.append((input_data, future))

        # Start batch processing if not running
        if not self.processing:
            asyncio.create_task(self.process_batch())

        # Wait for result
        return await future

    async def process_batch(self):
        self.processing = True

        # Wait for batch to fill or timeout
        await asyncio.sleep(self.max_wait_ms / 1000)

        # Collect batch
        batch_data = []
        batch_futures = []
        while self.queue and len(batch_data) < self.max_batch_size:
            data, future = self.queue.popleft()
            batch_data.append(data)
            batch_futures.append(future)

        if batch_data:
            # Batch inference
            results = model.predict(np.array(batch_data))

            # Distribute results
            for future, result in zip(batch_futures, results):
                future.set_result(result)

        self.processing = False

        # Continue if more in queue
        if self.queue:
            asyncio.create_task(self.process_batch())

# Usage
batcher = DynamicBatcher()

@app.post("/predict")
async def predict(data: dict):
    result = await batcher.predict(data['features'])
    return {"prediction": result}
Benefits:
Trade-offs:
NVIDIA Triton Inference Server: a production-grade serving system with built-in dynamic batching.
# model_config.pbtxt
name: "my_model"
platform: "onnxruntime_onnx"
max_batch_size: 32
dynamic_batching {
  preferred_batch_size: [ 8, 16, 32 ]
  max_queue_delay_microseconds: 10000
}
input [
  {
    name: "input"
    data_type: TYPE_FP32
    dims: [ 224, 224, 3 ]
  }
]
output [
  {
    name: "output"
    data_type: TYPE_FP32
    dims: [ 1000 ]
  }
]
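With this config placed in a Triton model repository, clients call the server over HTTP or gRPC. A minimal sketch using the tritonclient package, assuming the server runs locally on the default HTTP port:

import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")

# Build a request matching the config above: FP32 input of shape [batch, 224, 224, 3]
image = np.random.rand(1, 224, 224, 3).astype(np.float32)
infer_input = httpclient.InferInput("input", list(image.shape), "FP32")
infer_input.set_data_from_numpy(image)

response = client.infer(model_name="my_model", inputs=[infer_input])
print(response.as_numpy("output").shape)  # (1, 1000)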
Features:
Track model performance in production.
from prometheus_client import Counter, Histogram
import time
# Metrics
prediction_counter = Counter('predictions_total', 'Total predictions')
prediction_latency = Histogram('prediction_latency_seconds', 'Prediction latency')
prediction_errors = Counter('prediction_errors_total', 'Prediction errors')
@app.post("/predict")
async def predict(data: dict):
start_time = time.time()
try:
result = model.predict(data['features'])
prediction_counter.inc()
return {"prediction": result}
except Exception as e:
prediction_errors.inc()
raise
finally:
prediction_latency.observe(time.time() - start_time)
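These counters and histograms still need an endpoint for Prometheus to scrape. With FastAPI, one way to do this is to mount prometheus_client's ASGI app, a minimal sketch:

from prometheus_client import make_asgi_app

# Expose all registered metrics at /metrics for Prometheus to scrape
app.mount("/metrics", make_asgi_app())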
Key metrics:
In the next post, we will explore MLOps Methodology: data drift, model drift, feature stores, and automated retraining pipelines.
This post is part of the "From Zero to AI Engineer" series - Module 9: Deployment Strategy