Scalable System Design: Building AI Applications for High Load

A model that runs well at 10 users per second is no guarantee that it will handle 10,000. Scalability is a system's ability to grow its capacity to meet increased demand. In this article, we explore the patterns and techniques for scaling AI applications from prototype to production.

Horizontal vs Vertical Scaling

Vertical Scaling (Scale Up)

Increase the power of a single machine.

Before: 
├─ 4 CPU cores
├─ 16GB RAM
└─ 1x GPU

After (Vertical Scale):
├─ 16 CPU cores  ← Upgrade
├─ 64GB RAM      ← Upgrade
└─ 4x GPU        ← Upgrade

Same machine, more powerful

Pros:

  • Simple (no code changes)
  • No network latency between components
  • Easier to manage (single server)

Cons:

  • Physical limits (max CPU/RAM available)
  • Expensive (cost grows faster than capacity)
  • Single point of failure
  • Downtime required for upgrades

When to use:

  • Early stage / prototyping
  • Database servers (harder to horizontally scale)
  • Applications requiring shared memory

Horizontal Scaling (Scale Out)

Add more machines.

Before:
[Server 1]

After (Horizontal Scale):
[Server 1] [Server 2] [Server 3] [Server 4]
     └────────── Load Balancer ──────────┘

Multiple machines, distribute load

Pros:

  • (Almost) unlimited scalability
  • Better fault tolerance (redundancy)
  • Cost-effective (commodity hardware)
  • No downtime for adding servers

Cons:

  • More complex (distributed systems)
  • Network latency
  • Data consistency challenges
  • Requires stateless design

When to use:

  • Production systems
  • High traffic applications
  • Need fault tolerance
  • Cost-sensitive scaling

Scaling AI Applications

# ✅ Designed for horizontal scaling
class StatelessModelServer:
    def __init__(self):
        # Load model from shared storage
        self.model = load_from_s3("models/latest.pkl")
    
    def predict(self, features):
        # No shared state - each request independent
        return self.model.predict(features)

# Can run 100 instances of this server
# Load balancer distributes requests

# ❌ Hard to scale horizontally
class StatefulModelServer:
    def __init__(self):
        self.model = load_model()
        self.request_cache = {}  # Per-instance in-memory state!
        self.user_sessions = {}  # Sessions live only in this instance!
    
    def predict(self, user_id, features):
        # Requires session data
        session = self.user_sessions[user_id]
        # ...

Key principle: Design stateless services for horizontal scaling.
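
If per-user state cannot be avoided, move it out of process memory into a shared store so that any instance can serve any request. A minimal sketch, assuming a Redis instance at the default localhost:6379:

import json
import redis

class ExternalizedStateModelServer:
    def __init__(self):
        self.model = load_from_s3("models/latest.pkl")  # same shared-storage load as above
        self.store = redis.Redis()  # shared store, reachable by every instance
    
    def predict(self, user_id, features):
        # Any instance can read/write the session - no sticky routing needed
        raw = self.store.get(f"session:{user_id}")
        session = json.loads(raw) if raw else {}
        
        session["request_count"] = session.get("request_count", 0) + 1
        self.store.set(f"session:{user_id}", json.dumps(session), ex=1800)  # 30 min TTL
        
        return self.model.predict(features)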

Load Balancing

Distribute traffic across multiple servers.

Load Balancing Algorithms

1. Round Robin

class RoundRobinLoadBalancer:
    def __init__(self, servers):
        self.servers = servers
        self.current = 0
    
    def get_next_server(self):
        server = self.servers[self.current]
        self.current = (self.current + 1) % len(self.servers)
        return server

# Example
lb = RoundRobinLoadBalancer(['server1', 'server2', 'server3'])

lb.get_next_server()  # 'server1'
lb.get_next_server()  # 'server2'
lb.get_next_server()  # 'server3'
lb.get_next_server()  # 'server1' (repeats)

Pros: Simple, fair distribution
Cons: Ignores server load/capacity

2. Least Connections

class LeastConnectionsLoadBalancer:
    def __init__(self, servers):
        self.connections = {server: 0 for server in servers}
    
    def get_next_server(self):
        # Choose server with fewest connections
        return min(self.connections, key=self.connections.get)
    
    def mark_request_start(self, server):
        self.connections[server] += 1
    
    def mark_request_end(self, server):
        self.connections[server] -= 1

# Example
lb = LeastConnectionsLoadBalancer(['server1', 'server2', 'server3'])

# Server2 has 5 connections, others have 0
lb.connections['server2'] = 5

lb.get_next_server()  # 'server1' (server1 and server3 tie at 0; min() returns the first)

Pros: Better for long-lived connections
Cons: Overhead tracking connections

3. Weighted Load Balancing

class WeightedLoadBalancer:
    def __init__(self, servers_with_weights):
        # servers_with_weights = [('server1', 3), ('server2', 2), ('server3', 1)]
        self.servers = []
        for server, weight in servers_with_weights:
            self.servers.extend([server] * weight)
        
        self.current = 0
    
    def get_next_server(self):
        server = self.servers[self.current]
        self.current = (self.current + 1) % len(self.servers)
        return server

# Example - gpu-server gets 50% of traffic, cpu-server 33%, backup-server 17%
lb = WeightedLoadBalancer([
    ('gpu-server', 3),    # Powerful server
    ('cpu-server', 2),    # Medium server
    ('backup-server', 1)  # Weak server
])

Use case: Servers with different capacities

4. IP Hash

import hashlib

class IPHashLoadBalancer:
    def __init__(self, servers):
        self.servers = servers
    
    def get_server_for_ip(self, client_ip):
        # Hash IP to consistent server
        hash_value = int(hashlib.md5(client_ip.encode()).hexdigest(), 16)
        server_index = hash_value % len(self.servers)
        return self.servers[server_index]

# Same client IP always goes to same server
lb = IPHashLoadBalancer(['server1', 'server2', 'server3'])

lb.get_server_for_ip('192.168.1.1')  # Always maps to the same server
lb.get_server_for_ip('192.168.1.1')  # Same server on every call
lb.get_server_for_ip('192.168.1.2')  # May land on a different server

Use case: Session affinity (sticky sessions)

Health Checks

import requests
import threading
import time

class LoadBalancerWithHealthCheck:
    def __init__(self, servers):
        self.servers = servers
        self.healthy_servers = set(servers)
        
        # Run the health check loop in a background thread
        self.start_health_checks()
    
    def start_health_checks(self):
        thread = threading.Thread(target=self.health_check_loop, daemon=True)
        thread.start()
    
    def check_health(self, server):
        try:
            response = requests.get(f"http://{server}/health", timeout=2)
            return response.status_code == 200
        except requests.RequestException:
            return False
    
    def health_check_loop(self):
        while True:
            for server in self.servers:
                if self.check_health(server):
                    self.healthy_servers.add(server)
                else:
                    self.healthy_servers.discard(server)
                    print(f"⚠️ Server {server} unhealthy, removing from pool")
            
            time.sleep(10)  # Check every 10 seconds
    
    def get_next_server(self):
        if not self.healthy_servers:
            raise Exception("No healthy servers available!")
        
        # Apply any algorithm to the healthy servers only
        return list(self.healthy_servers)[0]
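
On the server side, each instance exposes the /health endpoint that the balancer polls. A minimal FastAPI sketch (the endpoint path matches the check above; model_is_loaded() is a hypothetical readiness check):

from fastapi import FastAPI, HTTPException

api = FastAPI()

@api.get("/health")
def health():
    # model_is_loaded() is a placeholder for your own readiness logic
    if model_is_loaded():
        return {"status": "ok"}
    raise HTTPException(status_code=503, detail="Model not ready")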

Cloud Load Balancers

AWS Application Load Balancer:

# ALB configuration (Terraform)
resource "aws_lb" "main" {
  name               = "ml-api-lb"
  load_balancer_type = "application"
  subnets            = var.subnet_ids
}

resource "aws_lb_target_group" "ml_api" {
  name     = "ml-api-targets"
  port     = 8000
  protocol = "HTTP"
  vpc_id   = var.vpc_id
  
  health_check {
    path                = "/health"
    interval            = 30
    timeout             = 5
    healthy_threshold   = 2
    unhealthy_threshold = 3
  }
}

resource "aws_lb_listener" "main" {
  load_balancer_arn = aws_lb.main.arn
  port              = "80"
  protocol          = "HTTP"
  
  default_action {
    type             = "forward"
    target_group_arn = aws_lb_target_group.ml_api.arn
  }
}

Features:

  • Auto-scaling integration
  • SSL termination
  • Path-based routing
  • WebSocket support

Caching Patterns for AI Apps

Caching dramatically reduces load and latency.

1. Semantic Cache

Cache embeddings/predictions for similar queries.

from sentence_transformers import SentenceTransformer, util
import redis

class SemanticCache:
    def __init__(self, similarity_threshold=0.95):
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        self.redis = redis.Redis()
        self.threshold = similarity_threshold
    
    def get(self, query):
        """Check if semantically similar query exists in cache."""
        
        # Embed query
        query_embedding = self.model.encode(query)
        
        # Get all cached queries
        cached_queries = self.redis.keys("query:*")
        
        for cached_key in cached_queries:
            cached_query = cached_key.decode().replace("query:", "")
            cached_embedding = self.model.encode(cached_query)
            
            # Calculate similarity
            similarity = util.cos_sim(query_embedding, cached_embedding).item()
            
            if similarity > self.threshold:
                # Cache hit!
                return self.redis.get(cached_key)
        
        return None
    
    def set(self, query, result):
        """Cache query result."""
        self.redis.set(f"query:{query}", result, ex=3600)  # 1 hour TTL

# Usage
cache = SemanticCache()

# First request
result = cache.get("How to train ML model?")
if result is None:
    result = expensive_llm_call("How to train ML model?")
    cache.set("How to train ML model?", result)

# Second request with similar query
result = cache.get("How do I train a machine learning model?")
# Likely a cache hit (similarity above threshold) → no LLM call needed

Benefits:

  • Reduce API costs (fewer LLM calls)
  • Lower latency (cache faster than API)
  • Works for paraphrased queries
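
One caveat: the sketch above re-encodes every cached query on each lookup, so lookups slow down as the cache grows. A variant that keeps query embeddings in process memory avoids the repeated encoding; this is a simplified sketch with no persistence (production setups typically store embeddings in Redis or a vector database):

import numpy as np
from sentence_transformers import SentenceTransformer, util

class InMemorySemanticCache:
    def __init__(self, similarity_threshold=0.95):
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
        self.threshold = similarity_threshold
        self.embeddings = []  # one embedding per cached query
        self.results = []     # cached result at the same index
    
    def get(self, query):
        if not self.embeddings:
            return None
        
        # Encode the incoming query once, compare against all cached embeddings
        query_embedding = self.model.encode(query)
        similarities = util.cos_sim(query_embedding, np.stack(self.embeddings))[0]
        best = int(similarities.argmax())
        
        if float(similarities[best]) > self.threshold:
            return self.results[best]
        return None
    
    def set(self, query, result):
        self.embeddings.append(self.model.encode(query))
        self.results.append(result)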

2. Write-through Cache

Update cache immediately when data changes.

class WriteThroughCache:
    def __init__(self):
        self.cache = redis.Redis()
        self.db = Database()  # placeholder for your persistence layer
    
    def get(self, key):
        # Try cache first
        cached = self.cache.get(key)
        if cached:
            return cached
        
        # Cache miss - load from DB
        value = self.db.get(key)
        
        # Update cache
        self.cache.set(key, value)
        return value
    
    def set(self, key, value):
        # Write to DB first
        self.db.set(key, value)
        
        # Then update cache
        self.cache.set(key, value)

# Example: Model predictions cache
cache = WriteThroughCache()

# New prediction
prediction = model.predict(features)
cache.set(f"prediction:{user_id}", prediction)

# Later requests - served from cache
cached_prediction = cache.get(f"prediction:{user_id}")  # Fast!

Pros: Cache always consistent
Cons: Write latency (update both cache + DB)

3. Cache-aside (Lazy Loading)

Application manages cache explicitly.

import json
import redis

redis_client = redis.Redis()

def get_model_prediction(user_id, features):
    cache_key = f"prediction:{user_id}"
    
    # Try cache
    cached = redis_client.get(cache_key)
    if cached:
        return json.loads(cached)
    
    # Cache miss - compute
    prediction = model.predict(features)
    
    # Store in cache
    redis_client.set(cache_key, json.dumps(prediction), ex=300)  # 5 min TTL
    
    return prediction

Pros: Only cache what's needed
Cons: Cache misses add latency

4. Embedding Cache

Pre-compute embeddings for common documents.

class EmbeddingCache:
    def __init__(self):
        self.cache = {}  # In-memory for speed
        self.model = SentenceTransformer('all-MiniLM-L6-v2')
    
    def precompute_embeddings(self, documents):
        """Pre-compute embeddings for all documents."""
        for doc_id, doc_text in documents.items():
            embedding = self.model.encode(doc_text)
            self.cache[doc_id] = embedding
    
    def get_embedding(self, doc_id):
        """Get pre-computed embedding."""
        return self.cache.get(doc_id)

# Usage
cache = EmbeddingCache()

# At startup: Pre-compute all document embeddings
cache.precompute_embeddings(knowledge_base)

# Runtime: Instant retrieval
embedding = cache.get_embedding(doc_id)  # No encoding needed!

Use case: RAG systems with static knowledge base

Cache Invalidation Strategies

# 1. Time-based (TTL)
redis_client.set(key, value, ex=3600)  # Expire after 1 hour

# 2. Event-based
def on_model_update(new_model_version):
    # Clear the prediction cache when a new model version is deployed
    redis_client.flushdb()

# 3. LRU (Least Recently Used) eviction
redis_client.config_set('maxmemory-policy', 'allkeys-lru')
redis_client.config_set('maxmemory', '1gb')

# 4. Manual invalidation
def invalidate_user_cache(user_id):
    redis_client.delete(f"predictions:{user_id}")
    redis_client.delete(f"recommendations:{user_id}")

Asynchronous Processing

Handle long-running tasks without blocking.

Message Queue Pattern

from celery import Celery
import time

# Configure Celery
app = Celery('tasks', broker='redis://localhost:6379')

# Define async task
@app.task
def train_model(dataset_id):
    """Long-running task (hours)."""
    dataset = load_dataset(dataset_id)
    model = train(dataset)
    save_model(model)
    return model.id

# API endpoint
from fastapi import FastAPI, BackgroundTasks

api = FastAPI()

@api.post("/train")
async def trigger_training(dataset_id: str):
    # Queue task (returns immediately)
    task = train_model.delay(dataset_id)
    
    return {
        "message": "Training started",
        "task_id": task.id
    }

@api.get("/status/{task_id}")
async def check_status(task_id: str):
    task = train_model.AsyncResult(task_id)
    
    return {
        "status": task.state,  # PENDING, STARTED, SUCCESS, FAILURE
        "result": task.result if task.ready() else None
    }

# User flow:
# 1. POST /train → Get task_id
# 2. Poll GET /status/{task_id} → Check progress
# 3. When SUCCESS → Get result

Kafka for Event Streaming

from kafka import KafkaProducer, KafkaConsumer
import json

# Producer (API server)
producer = KafkaProducer(
    bootstrap_servers=['localhost:9092'],
    value_serializer=lambda v: json.dumps(v).encode()
)

@api.post("/predict")
async def predict(data: dict):
    # Send to Kafka
    producer.send('predictions', {
        'user_id': data['user_id'],
        'features': data['features'],
        'timestamp': time.time()
    })
    
    return {"status": "queued"}

# Consumer (Model server)
consumer = KafkaConsumer(
    'predictions',
    bootstrap_servers=['localhost:9092'],
    value_deserializer=lambda m: json.loads(m.decode())
)

for message in consumer:
    data = message.value
    
    # Process
    prediction = model.predict(data['features'])
    
    # Send result to another topic
    producer.send('results', {
        'user_id': data['user_id'],
        'prediction': prediction
    })

Use cases:

  • Real-time analytics
  • Event-driven architectures
  • Decoupling services

Background Tasks

from fastapi import BackgroundTasks

def send_notification(user_id, prediction):
    """Send email/push notification."""
    time.sleep(5)  # Simulate slow operation
    notify_user(user_id, f"Your prediction: {prediction}")

@api.post("/predict")
async def predict(data: dict, background_tasks: BackgroundTasks):
    # Immediate prediction
    prediction = model.predict(data['features'])
    
    # Queue notification (non-blocking)
    background_tasks.add_task(send_notification, data['user_id'], prediction)
    
    # Return immediately
    return {"prediction": prediction}

# User gets response instantly
# Notification sent in background

Database Scaling Patterns

Read Replicas

┌─────────────┐
│   Primary   │ ← All WRITES go here
│   Database  │
└──────┬──────┘
       │ Replication
   ┌───┴────┬────────┐
   ▼        ▼        ▼
[Replica1][Replica2][Replica3] ← READS distributed

Implementation:

import random

class DatabaseRouter:
    def __init__(self, primary, replicas):
        self.primary = primary
        self.replicas = replicas
    
    def get_connection(self, operation):
        if operation == 'write':
            return self.primary
        else:
            # Load balance reads across replicas
            return random.choice(self.replicas)

# Usage (connect() is a stand-in for your DB driver's connection factory)
db = DatabaseRouter(
    primary=connect("primary.db.example.com"),
    replicas=[
        connect("replica1.db.example.com"),
        connect("replica2.db.example.com"),
        connect("replica3.db.example.com")
    ]
)

# Writes to primary
db.get_connection('write').execute("INSERT INTO users ...")

# Reads from replicas
db.get_connection('read').execute("SELECT * FROM users")

Sharding

Partition data across multiple databases.

import hashlib

class ShardedDatabase:
    def __init__(self, shards):
        self.shards = shards  # List of database connections
    
    def get_shard(self, user_id):
        """Route based on user_id (stable hash; Python's built-in hash() differs across processes)."""
        digest = int(hashlib.md5(str(user_id).encode()).hexdigest(), 16)
        shard_index = digest % len(self.shards)
        return self.shards[shard_index]
    
    def save_user_data(self, user_id, data):
        shard = self.get_shard(user_id)
        # Use parameterized queries in production; f-strings shown for brevity
        shard.execute(f"INSERT INTO users VALUES ({user_id}, ...)")
    
    def get_user_data(self, user_id):
        shard = self.get_shard(user_id)
        return shard.execute(f"SELECT * FROM users WHERE id={user_id}")

# Example: 4 shards
db = ShardedDatabase([
    connect("shard1.db"),
    connect("shard2.db"),
    connect("shard3.db"),
    connect("shard4.db")
])

# Each user_id maps deterministically to one of the 4 shards,
# so all of a user's reads and writes hit the same shard.

Challenges:

  • Cross-shard queries are complex
  • Rebalancing when adding shards (see the consistent hashing sketch below)
  • Transactions across shards
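
A common mitigation for the rebalancing challenge is consistent hashing: shards and keys are both placed on a hash ring, so adding a shard only moves the keys that fall between the new shard and its neighbor instead of rehashing everything. A minimal sketch (real implementations add virtual nodes for better balance):

import bisect
import hashlib

class ConsistentHashRing:
    def __init__(self, shard_names):
        # Each shard occupies one point on the ring
        self.ring = sorted((self._hash(name), name) for name in shard_names)
    
    def _hash(self, key):
        return int(hashlib.md5(str(key).encode()).hexdigest(), 16)
    
    def get_shard(self, user_id):
        # Walk clockwise to the first shard at or after the key's position
        h = self._hash(user_id)
        positions = [pos for pos, _ in self.ring]
        index = bisect.bisect(positions, h) % len(self.ring)
        return self.ring[index][1]
    
    def add_shard(self, name):
        # Only the keys between the new shard and its predecessor move
        bisect.insort(self.ring, (self._hash(name), name))

# Usage
ring = ConsistentHashRing(["shard1.db", "shard2.db", "shard3.db", "shard4.db"])
ring.get_shard(42)            # stable mapping for user_id 42
ring.add_shard("shard5.db")   # most keys keep their current shard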

Key Takeaways

  • Vertical scaling: Upgrade single machine (simple but limited)
  • Horizontal scaling: Add more machines (complex but nearly unlimited)
  • Load balancing: Round Robin (simple), Least Connections (fairness), Weighted (capacity), IP Hash (sticky sessions)
  • Caching patterns: Semantic cache (AI queries), Write-through (consistency), Cache-aside (lazy), Embedding cache (speed)
  • Async processing: Message queues (Celery, Kafka) decouple slow operations
  • Database scaling: Read replicas (read-heavy), Sharding (massive scale)
  • Design principle: Stateless services enable horizontal scaling

In the next article, we will explore Performance & Cost Optimization - token cost models, latency vs throughput trade-offs, batching, and model distillation.


This article is part of the series "From Zero to AI Engineer" - Module 10: Scalability & Observability