A model that works well at 10 users/second is not guaranteed to handle 10,000 users/second. Scalability is a system's ability to increase capacity to meet growing demand. In this post, we will explore the patterns and techniques for scaling AI applications from prototype to production.
Increase the power of a single machine.
Before:
├─ 4 CPU cores
├─ 16GB RAM
└─ 1x GPU
After (Vertical Scale):
├─ 16 CPU cores ← Upgrade
├─ 64GB RAM ← Upgrade
└─ 4x GPU ← Upgrade
Same machine, more powerful
Pros:
- Simple: no code or architecture changes required
- No distributed-systems complexity (no load balancer, no data partitioning)
Cons:
- Hard physical ceiling on how much CPU/RAM/GPU you can buy
- Single point of failure; cost grows steeply at the high end
When to use: early prototypes, workloads that still fit comfortably on one machine, or as a quick stopgap before investing in horizontal scaling (a worker-count sketch follows below).
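Vertical scaling is mostly an infrastructure change, but the application usually needs a small configuration tweak to exploit the new capacity. A minimal sketch, assuming a FastAPI app exposed as app in a hypothetical main.py and served with Uvicorn, that sizes the worker pool to the available cores:

# Hypothetical launcher: scale the number of worker processes with the CPU count,
# so a 4-core → 16-core upgrade actually translates into more concurrent capacity.
import multiprocessing

import uvicorn

if __name__ == "__main__":
    uvicorn.run(
        "main:app",                           # assumption: your FastAPI app lives in main.py
        host="0.0.0.0",
        port=8000,
        workers=multiprocessing.cpu_count(),  # e.g. 4 before the upgrade, 16 after
    )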
Add more machines.
Before:
[Server 1]
After (Horizontal Scale):
[Server 1] [Server 2] [Server 3] [Server 4]
└────────── Load Balancer ──────────┘
Multiple machines, distribute load
Pros:
- Near-unlimited capacity: add machines as traffic grows
- Fault tolerance: one server failing does not take down the service
Cons:
- More moving parts: load balancer, service discovery, shared state
- The application must be designed for it (stateless services)
When to use: production systems with growing or spiky traffic, and anywhere high availability matters.
# ✅ Designed for horizontal scaling
class StatelessModelServer:
    def __init__(self):
        # Load model from shared storage
        self.model = load_from_s3("models/latest.pkl")

    def predict(self, features):
        # No shared state - each request independent
        return self.model.predict(features)

# Can run 100 instances of this server
# Load balancer distributes requests
# ❌ Hard to scale horizontally
class StatefulModelServer:
    def __init__(self):
        self.model = load_model()
        self.request_cache = {}   # Shared state!
        self.user_sessions = {}   # Shared state!

    def predict(self, user_id, features):
        # Requires session data
        session = self.user_sessions[user_id]
        # ...
Key principle: Design stateless services for horizontal scaling.
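As a hedged sketch of how the stateful server above could be refactored: move the per-user session data into a shared store such as Redis, so every instance stays stateless and any of them can serve any request (load_from_s3 is the same hypothetical helper used earlier).

import json

import redis

class ExternalizedStateModelServer:
    def __init__(self):
        # Model still comes from shared storage, as in the stateless example above
        self.model = load_from_s3("models/latest.pkl")
        # Shared state lives in Redis, not in this process
        self.sessions = redis.Redis()

    def predict(self, user_id, features):
        # Fetch session data from the shared store (empty dict if none yet)
        raw = self.sessions.get(f"session:{user_id}")
        session = json.loads(raw) if raw else {}

        prediction = self.model.predict(features)

        # Persist the updated session; the next request may land on another instance
        session["last_prediction"] = str(prediction)
        self.sessions.set(f"session:{user_id}", json.dumps(session), ex=1800)
        return prediction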
Distribute traffic across multiple servers.
1. Round Robin
class RoundRobinLoadBalancer:
    def __init__(self, servers):
        self.servers = servers
        self.current = 0

    def get_next_server(self):
        server = self.servers[self.current]
        self.current = (self.current + 1) % len(self.servers)
        return server

# Example
lb = RoundRobinLoadBalancer(['server1', 'server2', 'server3'])
lb.get_next_server()  # 'server1'
lb.get_next_server()  # 'server2'
lb.get_next_server()  # 'server3'
lb.get_next_server()  # 'server1' (repeats)
Pros: Simple, fair distribution
Cons: Ignores server load/capacity
2. Least Connections
class LeastConnectionsLoadBalancer:
    def __init__(self, servers):
        self.connections = {server: 0 for server in servers}

    def get_next_server(self):
        # Choose server with fewest connections
        return min(self.connections, key=self.connections.get)

    def mark_request_start(self, server):
        self.connections[server] += 1

    def mark_request_end(self, server):
        self.connections[server] -= 1

# Example
lb = LeastConnectionsLoadBalancer(['server1', 'server2', 'server3'])
# Server2 has 5 connections, others have 0
lb.connections['server2'] = 5
lb.get_next_server()  # 'server1' (fewest connections; ties resolved by order)
Pros: Better for long-lived connections
Cons: Overhead tracking connections
3. Weighted Load Balancing
class WeightedLoadBalancer:
    def __init__(self, servers_with_weights):
        # servers_with_weights = [('server1', 3), ('server2', 2), ('server3', 1)]
        self.servers = []
        for server, weight in servers_with_weights:
            self.servers.extend([server] * weight)
        self.current = 0

    def get_next_server(self):
        server = self.servers[self.current]
        self.current = (self.current + 1) % len(self.servers)
        return server

# Example - gpu-server gets 50% of traffic, cpu-server 33%, backup-server 17%
lb = WeightedLoadBalancer([
    ('gpu-server', 3),     # Powerful server
    ('cpu-server', 2),     # Medium server
    ('backup-server', 1)   # Weak server
])
Use case: Servers with different capacities
4. IP Hash
import hashlib

class IPHashLoadBalancer:
    def __init__(self, servers):
        self.servers = servers

    def get_server_for_ip(self, client_ip):
        # Hash IP to consistent server
        hash_value = int(hashlib.md5(client_ip.encode()).hexdigest(), 16)
        server_index = hash_value % len(self.servers)
        return self.servers[server_index]

# Same client IP always goes to same server
lb = IPHashLoadBalancer(['server1', 'server2', 'server3'])
lb.get_server_for_ip('192.168.1.1')  # Always 'server2'
lb.get_server_for_ip('192.168.1.1')  # Always 'server2'
lb.get_server_for_ip('192.168.1.2')  # Always 'server1'
Use case: Session affinity (sticky sessions)
import threading
import time

import requests

class LoadBalancerWithHealthCheck:
    def __init__(self, servers):
        self.servers = servers
        self.healthy_servers = set(servers)
        # Start health check loop
        self.start_health_checks()

    def start_health_checks(self):
        # Run the health check loop in a background thread
        thread = threading.Thread(target=self.health_check_loop, daemon=True)
        thread.start()

    def check_health(self, server):
        try:
            response = requests.get(f"http://{server}/health", timeout=2)
            return response.status_code == 200
        except requests.RequestException:
            return False

    def health_check_loop(self):
        while True:
            for server in self.servers:
                if self.check_health(server):
                    self.healthy_servers.add(server)
                else:
                    self.healthy_servers.discard(server)
                    print(f"⚠️ Server {server} unhealthy, removing from pool")
            time.sleep(10)  # Check every 10 seconds

    def get_next_server(self):
        if not self.healthy_servers:
            raise Exception("No healthy servers available!")
        # Use any algorithm on healthy servers only
        return list(self.healthy_servers)[0]
AWS Application Load Balancer:
# ALB configuration (Terraform)
resource "aws_lb" "main" {
  name               = "ml-api-lb"
  load_balancer_type = "application"
  subnets            = var.subnet_ids
}

resource "aws_lb_target_group" "ml_api" {
  name     = "ml-api-targets"
  port     = 8000
  protocol = "HTTP"
  vpc_id   = var.vpc_id

  health_check {
    path                = "/health"
    interval            = 30
    timeout             = 5
    healthy_threshold   = 2
    unhealthy_threshold = 3
  }
}

resource "aws_lb_listener" "main" {
  load_balancer_arn = aws_lb.main.arn
  port              = "80"
  protocol          = "HTTP"

  default_action {
    type             = "forward"
    target_group_arn = aws_lb_target_group.ml_api.arn
  }
}
Features:
- Built-in health checks (as configured above) with automatic removal of unhealthy targets
- Integration with Auto Scaling groups and ECS/EKS services
- Path- and host-based routing, HTTPS/TLS termination
Caching dramatically reduces load and latency.
Cache embeddings/predictions for similar queries.
from sentence_transformers import SentenceTransformer, util
import redis

class SemanticCache:
    def __init__(self, similarity_threshold=0.95):
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        self.redis = redis.Redis()
        self.threshold = similarity_threshold

    def get(self, query):
        """Check if semantically similar query exists in cache."""
        # Embed query
        query_embedding = self.model.encode(query)

        # Get all cached queries
        # (linear scan with re-encoding, kept simple for clarity; production systems
        #  store the embeddings or use a vector index instead)
        cached_queries = self.redis.keys("query:*")
        for cached_key in cached_queries:
            cached_query = cached_key.decode().replace("query:", "")
            cached_embedding = self.model.encode(cached_query)

            # Calculate similarity
            similarity = util.cos_sim(query_embedding, cached_embedding).item()
            if similarity > self.threshold:
                # Cache hit!
                return self.redis.get(cached_key)

        return None

    def set(self, query, result):
        """Cache query result."""
        self.redis.set(f"query:{query}", result, ex=3600)  # 1 hour TTL

# Usage
cache = SemanticCache()

# First request
result = cache.get("How to train ML model?")
if result is None:
    result = expensive_llm_call("How to train ML model?")
    cache.set("How to train ML model?", result)

# Second request with similar query
result = cache.get("How do I train a machine learning model?")
# Cache hit! (95%+ similar) → No LLM call needed
Benefits:
- Avoids repeated LLM calls for paraphrased questions → lower cost
- Cache hits return in milliseconds instead of seconds
- Higher hit rate than exact string matching (a variant that stores embeddings is sketched below)
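Because re-encoding every cached query on each lookup gets slow as the cache grows, a common variant stores the embedding next to the result and compares in one batched call. A minimal sketch with an in-memory list (class and variable names are illustrative; at scale you would use a vector database):

import numpy as np
from sentence_transformers import SentenceTransformer, util

class SemanticCacheWithStoredEmbeddings:
    def __init__(self, similarity_threshold=0.95):
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        self.threshold = similarity_threshold
        self.entries = []  # list of (embedding, result) pairs

    def get(self, query):
        if not self.entries:
            return None
        query_embedding = self.model.encode(query)
        embeddings = np.stack([emb for emb, _ in self.entries])
        # One batched cosine-similarity call instead of re-encoding cached queries
        similarities = util.cos_sim(query_embedding, embeddings)[0]
        best = int(similarities.argmax())
        if float(similarities[best]) > self.threshold:
            return self.entries[best][1]
        return None

    def set(self, query, result):
        self.entries.append((self.model.encode(query), result))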
Update cache immediately when data changes.
class WriteThroughCache:
    def __init__(self):
        self.cache = redis.Redis()
        self.db = Database()  # placeholder for your database client

    def get(self, key):
        # Try cache first
        cached = self.cache.get(key)
        if cached:
            return cached

        # Cache miss - load from DB
        value = self.db.get(key)

        # Update cache
        self.cache.set(key, value)
        return value

    def set(self, key, value):
        # Write to DB first
        self.db.set(key, value)
        # Then update cache
        self.cache.set(key, value)

# Example: Model predictions cache
cache = WriteThroughCache()

# New prediction
prediction = model.predict(features)
cache.set(f"prediction:{user_id}", prediction)

# Later requests - served from cache
cached_prediction = cache.get(f"prediction:{user_id}")  # Fast!
Pros: Cache always consistent
Cons: Write latency (update both cache + DB)
Application manages cache explicitly.
def get_model_prediction(user_id, features):
    cache_key = f"prediction:{user_id}"

    # Try cache
    cached = redis.get(cache_key)
    if cached:
        return json.loads(cached)

    # Cache miss - compute
    prediction = model.predict(features)

    # Store in cache
    redis.set(cache_key, json.dumps(prediction), ex=300)  # 5 min TTL
    return prediction
Pros: Only cache what's needed
Cons: Cache misses add latency
Pre-compute embeddings for common documents.
class EmbeddingCache:
    def __init__(self):
        self.cache = {}  # In-memory for speed
        self.model = SentenceTransformer('all-MiniLM-L6-v2')

    def precompute_embeddings(self, documents):
        """Pre-compute embeddings for all documents."""
        for doc_id, doc_text in documents.items():
            embedding = self.model.encode(doc_text)
            self.cache[doc_id] = embedding

    def get_embedding(self, doc_id):
        """Get pre-computed embedding."""
        return self.cache.get(doc_id)

# Usage
cache = EmbeddingCache()

# At startup: Pre-compute all document embeddings
cache.precompute_embeddings(knowledge_base)

# Runtime: Instant retrieval
embedding = cache.get_embedding(doc_id)  # No encoding needed!
Use case: RAG systems with static knowledge base
# 1. Time-based (TTL)
redis.set(key, value, ex=3600)  # Expire after 1 hour

# 2. Event-based
def on_model_update(new_model_version):
    # Clear prediction cache
    redis.flushdb()

# 3. LRU (Least Recently Used)
redis.config_set('maxmemory-policy', 'allkeys-lru')
redis.config_set('maxmemory', '1gb')

# 4. Manual invalidation
def invalidate_user_cache(user_id):
    redis.delete(f"predictions:{user_id}")
    redis.delete(f"recommendations:{user_id}")
Handle long-running tasks without blocking.
from celery import Celery
import time

# Configure Celery (broker queues tasks, backend stores results/status)
app = Celery(
    'tasks',
    broker='redis://localhost:6379',
    backend='redis://localhost:6379'
)

# Define async task
@app.task
def train_model(dataset_id):
    """Long-running task (hours)."""
    dataset = load_dataset(dataset_id)
    model = train(dataset)
    save_model(model)
    return model.id

# API endpoint
from fastapi import FastAPI

api = FastAPI()

@api.post("/train")
async def trigger_training(dataset_id: str):
    # Queue task (returns immediately)
    task = train_model.delay(dataset_id)
    return {
        "message": "Training started",
        "task_id": task.id
    }

@api.get("/status/{task_id}")
async def check_status(task_id: str):
    task = train_model.AsyncResult(task_id)
    return {
        "status": task.state,  # PENDING, STARTED, SUCCESS, FAILURE
        "result": task.result if task.ready() else None
    }

# User flow:
# 1. POST /train → Get task_id
# 2. Poll GET /status/{task_id} → Check progress
# 3. When SUCCESS → Get result
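To make the user flow concrete, here is a small client-side sketch of the polling loop; the base URL and dataset ID are placeholder assumptions:

import time

import requests

BASE_URL = "http://localhost:8000"  # assumption: where the FastAPI app is served

# 1. Trigger training and get the task_id
task_id = requests.post(f"{BASE_URL}/train", params={"dataset_id": "ds-123"}).json()["task_id"]

# 2. Poll the status endpoint until the task finishes
while True:
    status = requests.get(f"{BASE_URL}/status/{task_id}").json()
    if status["status"] in ("SUCCESS", "FAILURE"):
        break
    time.sleep(10)

# 3. Use the result
print(status["result"])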
from kafka import KafkaProducer, KafkaConsumer
import json

# Producer (API server)
producer = KafkaProducer(
    bootstrap_servers=['localhost:9092'],
    value_serializer=lambda v: json.dumps(v).encode()
)

@api.post("/predict")
async def predict(data: dict):
    # Send to Kafka
    producer.send('predictions', {
        'user_id': data['user_id'],
        'features': data['features'],
        'timestamp': time.time()
    })
    return {"status": "queued"}

# Consumer (Model server)
consumer = KafkaConsumer(
    'predictions',
    bootstrap_servers=['localhost:9092'],
    value_deserializer=lambda m: json.loads(m.decode())
)

for message in consumer:
    data = message.value

    # Process
    prediction = model.predict(data['features'])

    # Send result to another topic
    producer.send('results', {
        'user_id': data['user_id'],
        'prediction': prediction
    })
Use cases:
- High-throughput prediction pipelines where requests can be processed asynchronously
- Decoupling API servers from model servers so each scales independently
- Event-driven workflows (e.g., re-score a user whenever their data changes)
from fastapi import BackgroundTasks

def send_notification(user_id, prediction):
    """Send email/push notification."""
    time.sleep(5)  # Simulate slow operation
    notify_user(user_id, f"Your prediction: {prediction}")

@api.post("/predict")
async def predict(data: dict, background_tasks: BackgroundTasks):
    # Immediate prediction
    prediction = model.predict(data['features'])

    # Queue notification (non-blocking)
    background_tasks.add_task(send_notification, data['user_id'], prediction)

    # Return immediately
    return {"prediction": prediction}

# User gets response instantly
# Notification sent in background
          ┌─────────────┐
          │   Primary   │  ← All WRITES go here
          │  Database   │
          └──────┬──────┘
                 │ Replication
       ┌─────────┼─────────┐
       ▼         ▼         ▼
   [Replica1][Replica2][Replica3]  ← READS distributed
Implementation:
import random

class DatabaseRouter:
    def __init__(self, primary, replicas):
        self.primary = primary
        self.replicas = replicas

    def get_connection(self, operation):
        if operation == 'write':
            return self.primary
        else:
            # Load balance reads across replicas
            return random.choice(self.replicas)

# Usage (hostnames shown for clarity; in practice these would be connection objects)
db = DatabaseRouter(
    primary="primary.db.example.com",
    replicas=[
        "replica1.db.example.com",
        "replica2.db.example.com",
        "replica3.db.example.com"
    ]
)

# Writes to primary
db.get_connection('write').execute("INSERT INTO users ...")

# Reads from replicas
db.get_connection('read').execute("SELECT * FROM users")
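One caveat with read replicas is replication lag: a read issued right after a write may hit a replica that has not caught up yet. A common mitigation is "read your own writes" routing; here is a minimal sketch extending the router above (the per-user stickiness window of 5 seconds is an assumption):

import random
import time

class LagAwareRouter(DatabaseRouter):
    def __init__(self, primary, replicas, stickiness_seconds=5):
        super().__init__(primary, replicas)
        self.last_write = {}  # user_id → timestamp of that user's last write
        self.stickiness = stickiness_seconds

    def get_connection_for_user(self, operation, user_id):
        if operation == 'write':
            self.last_write[user_id] = time.time()
            return self.primary
        # Users who wrote recently read from the primary to avoid stale replicas
        if time.time() - self.last_write.get(user_id, 0) < self.stickiness:
            return self.primary
        return random.choice(self.replicas)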
Partition data across multiple databases.
class ShardedDatabase:
    def __init__(self, shards):
        self.shards = shards  # List of database connections

    def get_shard(self, user_id):
        """Route based on user_id."""
        # Note: for string IDs use a stable hash (e.g. hashlib.md5) -
        # Python's built-in hash() differs between processes
        shard_index = hash(user_id) % len(self.shards)
        return self.shards[shard_index]

    def save_user_data(self, user_id, data):
        shard = self.get_shard(user_id)
        shard.execute(f"INSERT INTO users VALUES ({user_id}, ...)")

    def get_user_data(self, user_id):
        shard = self.get_shard(user_id)
        return shard.execute(f"SELECT * FROM users WHERE id={user_id}")

# Example: 4 shards
db = ShardedDatabase([
    connect("shard1.db"),
    connect("shard2.db"),
    connect("shard3.db"),
    connect("shard4.db")
])

# User 1 → Shard 2
# User 2 → Shard 1
# User 3 → Shard 4
# etc.
Challenges:
- Cross-shard queries and joins become expensive or impossible
- Resharding: with naive modulo hashing, adding or removing a shard remaps most keys (see the consistent-hashing sketch below)
- Hot shards when a few users generate most of the traffic
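For the resharding challenge in particular, consistent hashing is a common mitigation: keys are placed on a hash ring so that adding or removing a shard only moves a small fraction of them. A minimal sketch (class name and virtual-node count are illustrative):

import bisect
import hashlib

class ConsistentHashRing:
    def __init__(self, shards, virtual_nodes=100):
        # Each shard gets many points ("virtual nodes") on the ring for an even spread
        self.ring = []
        for shard in shards:
            for i in range(virtual_nodes):
                self.ring.append((self._hash(f"{shard}:{i}"), shard))
        self.ring.sort()
        self._points = [h for h, _ in self.ring]

    @staticmethod
    def _hash(value):
        return int(hashlib.md5(value.encode()).hexdigest(), 16)

    def get_shard(self, user_id):
        # Walk clockwise from the key's position to the next shard point on the ring
        idx = bisect.bisect(self._points, self._hash(str(user_id))) % len(self.ring)
        return self.ring[idx][1]

# Usage
ring = ConsistentHashRing(["shard1.db", "shard2.db", "shard3.db", "shard4.db"])
ring.get_shard(12345)  # maps to one shard and stays mostly stable when shards change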
In the next post, we will explore Performance & Cost Optimization - token cost models, latency vs. throughput trade-offs, batching, and model distillation.
This article is part of the "From Zero to AI Engineer" series - Module 10: Scalability & Observability