Training a model is one thing; keeping it healthy in production is an entirely different story. Model performance degrades over time, data distributions shift, and you need to retrain continuously. MLOps (Machine Learning Operations) is the set of practices, tools, and culture for operationalizing ML systems.
In this post, we'll explore the challenges of production ML and how MLOps addresses them.
Google defines 3 levels of MLOps maturity.
Developer laptop:
├─ Jupyter notebook
├─ Train model
├─ Save model.pkl
└─ Email to DevOps team
DevOps:
├─ Manually copy to server
├─ Update deployment script
└─ Deploy
Issues:
✗ No version control for data/models
✗ No reproducibility
✗ Manual, error-prone deployment
✗ No monitoring
Characteristics:
When acceptable: PoC, research projects
Automated training pipeline:
├─ Data validation
├─ Feature engineering
├─ Model training
├─ Model evaluation
├─ Model versioning
└─ Automated deployment (if metrics OK)
Monitoring:
├─ Model performance
├─ Data quality
└─ Alerts
Characteristics:
Components:
# Training pipeline (Airflow DAG)
from datetime import datetime

from airflow import DAG
from airflow.operators.python import PythonOperator

def validate_data():
    # Check data quality before training (data, MIN_SAMPLES are defined elsewhere)
    assert data.isnull().sum().sum() == 0, "Found missing values"
    assert len(data) > MIN_SAMPLES, "Not enough samples"

def train_model():
    # Train with tracked metrics
    import mlflow

    with mlflow.start_run():
        model = train(data)
        metrics = evaluate(model, test_data)
        mlflow.log_metrics(metrics)
        mlflow.sklearn.log_model(model, "model")

        # Auto-deploy only if good enough
        if metrics['accuracy'] > THRESHOLD:
            deploy_model(model)

dag = DAG(
    'ml_training_pipeline',
    start_date=datetime(2024, 1, 1),
    schedule_interval='@weekly',
    catchup=False,
)

t1 = PythonOperator(task_id='validate', python_callable=validate_data, dag=dag)
t2 = PythonOperator(task_id='train', python_callable=train_model, dag=dag)

t1 >> t2
Full automation:
├─ Continuous Integration (CI)
│ ├─ Code tests
│ ├─ Data validation tests
│ ├─ Model validation tests
│ └─ Build artifacts
│
├─ Continuous Training (CT)
│ ├─ Automated retraining on new data
│ ├─ Automated model evaluation
│ └─ Automated promotion to registry
│
└─ Continuous Deployment (CD)
├─ Automated deployment to staging
├─ Integration tests
├─ Canary/Blue-green deployment
└─ Production deployment
Characteristics:
Example CI/CD pipeline:
# .github/workflows/ml-pipeline.yml
name: ML Pipeline

on:
  push:
    branches: [main]
  schedule:
    - cron: '0 0 * * 0'  # Weekly retraining

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - name: Run unit tests
        run: pytest tests/
      - name: Validate data schema
        run: python scripts/validate_data.py
      - name: Test model code
        run: pytest tests/model_tests.py

  train:
    needs: test
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - name: Train model
        run: python train.py
      - name: Evaluate model
        run: python evaluate.py
      - name: Register model
        if: success()
        run: python register_model.py

  deploy:
    needs: train
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v2
      - name: Deploy to staging
        run: kubectl apply -f k8s/staging/
      - name: Integration tests
        run: pytest tests/integration/
      - name: Deploy to production
        if: success()
        run: kubectl apply -f k8s/production/
Models degrade over time due to distribution changes.
The input distribution changes, but the relationship between input and output stays the same.
# Example: E-commerce price prediction
# Training data (2020): Average price $50
# Production (2023): Average price $75 (inflation)
# P(X) changed, but P(Y|X) same
# Model trained on $50 range performs poorly on $75 range
Detection:
from scipy.stats import ks_2samp
def detect_data_drift(reference_data, production_data, threshold=0.05):
"""Kolmogorov-Smirnov test for distribution shift."""
drifts = {}
for column in reference_data.columns:
# Statistical test
statistic, p_value = ks_2samp(
reference_data[column],
production_data[column]
)
# Drift if distributions significantly different
drifts[column] = {
'drifted': p_value < threshold,
'p_value': p_value
}
return drifts
# Example
import pandas as pd

reference = pd.read_csv('training_data.csv')
current = pd.read_csv('last_week_production.csv')
drift_report = detect_data_drift(reference, current)
for feature, result in drift_report.items():
if result['drifted']:
print(f"⚠️ Drift detected in {feature}: p={result['p_value']}")
Mitigation:
The relationship between input and output changes.
# Example: Fraud detection
# 2020: Fraudsters use pattern A
# 2023: Fraudsters adapt, use pattern B
# P(Y|X) changed
# Model learned old patterns, can't detect new fraud
Detection:
import numpy as np

def detect_concept_drift(model, validation_data, feature_cols, target_col='y', window_size=1000):
    """Monitor model performance over consecutive time windows.

    Assumes validation_data is a DataFrame sorted by time, containing the
    feature columns, the target column, and a 'timestamp' column.
    """
    performance_history = []
    for i in range(0, len(validation_data), window_size):
        window = validation_data.iloc[i:i+window_size]
        # Evaluate on this window
        accuracy = model.score(window[feature_cols], window[target_col])
        performance_history.append({
            'timestamp': window['timestamp'].max(),
            'accuracy': accuracy
        })

    # Alert if recent performance drops below the early baseline
    recent_accuracy = np.mean([p['accuracy'] for p in performance_history[-5:]])
    baseline_accuracy = np.mean([p['accuracy'] for p in performance_history[:10]])
    if recent_accuracy < baseline_accuracy * 0.9:  # 10% drop
        alert("Concept drift detected!")  # alert() is a placeholder for your alerting hook

    return performance_history
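As with the data-drift check above, a minimal usage sketch might look like the following; clf, the column names, and the CSV path are hypothetical placeholders, not part of any specific dataset.
# Hypothetical usage: clf is an already-fitted sklearn-style model, and the CSV
# has feature columns, a 'label' column, and a 'timestamp' column.
validation_df = pd.read_csv('last_month_production.csv').sort_values('timestamp')
history = detect_concept_drift(
    clf,
    validation_df,
    feature_cols=['amount', 'num_items'],
    target_col='label'
)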
Mitigation:
Evidently AI:
from evidently.report import Report
from evidently.metric_preset import DataDriftPreset
report = Report(metrics=[
DataDriftPreset()
])
report.run(
reference_data=reference_df,
current_data=current_df
)
report.save_html("drift_report.html")
WhyLabs:
import whylogs as why
# Log data profiles
profile = why.log(production_data)
# Compare with baseline
drift_report = profile.compare(baseline_profile)
# Alert if drift
if drift_report.has_drift():
send_alert(drift_report)
Centralized repository for features.
Problems without Feature Store:
Training:
├─ Data Scientist extracts features from raw data
├─ Trains model
└─ Saves features.csv
Serving:
├─ ML Engineer re-implements feature extraction
├─ Inconsistencies (training ≠ serving)
└─ Bugs, delays
Training/Serving Skew:
Training: AVG(price) = $50
Serving: AVG(price) = $48 (slightly different calculation)
Result: Model performs poorly!
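To make the skew concrete, here is a small hypothetical sketch: the same "average purchase price" feature implemented twice, once in the training pipeline and once in the serving service, with a subtly different filter. The column names and values are illustrative only.
import pandas as pd

purchases = pd.DataFrame({
    'price': [45.0, 55.0, 50.0, 42.0],
    'refunded': [False, False, False, True],
})

# Training pipeline: the data scientist excludes refunded orders
avg_price_training = purchases.loc[~purchases['refunded'], 'price'].mean()  # 50.0

# Serving path: the feature is re-implemented and the refund filter is forgotten
avg_price_serving = purchases['price'].mean()  # 48.0

# Same "feature", two different values -> training/serving skew
print(avg_price_training, avg_price_serving)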
Feature Store solves:
┌─────────────────────────────────────────┐
│ Feature Store │
├─────────────────────────────────────────┤
│ │
│ Offline Store (Historical features) │
│ ├─ S3/BigQuery │
│ └─ For training │
│ │
│ Online Store (Real-time features) │
│ ├─ Redis/DynamoDB │
│ └─ For serving (low latency) │
│ │
│ Feature Registry │
│ ├─ Feature definitions │
│ ├─ Metadata │
│ └─ Lineage │
└─────────────────────────────────────────┘
from feast import FeatureStore, Entity, FeatureView, Field
from feast.types import Float32, Int64
from datetime import timedelta
# Define entity
user = Entity(
name="user",
join_keys=["user_id"]
)
# Define features
user_features = FeatureView(
name="user_features",
entities=[user],
ttl=timedelta(days=1),
schema=[
Field(name="age", dtype=Int64),
Field(name="avg_purchase", dtype=Float32),
Field(name="total_purchases", dtype=Int64)
],
source=... # Data source
)
# Initialize store
store = FeatureStore(repo_path=".")
# Training: Get historical features
training_data = store.get_historical_features(
entity_df=user_df, # User IDs with timestamps
features=[
"user_features:age",
"user_features:avg_purchase",
"user_features:total_purchases"
]
).to_df()
# Serving: Get online features
features = store.get_online_features(
features=[
"user_features:age",
"user_features:avg_purchase"
],
entity_rows=[{"user_id": 123}]
).to_dict()
Benefits:
Trigger retraining based on conditions.
1. Time-based (Schedule):
# Retrain every Sunday
# cron: 0 0 * * 0
@weekly_schedule
def retrain():
data = fetch_last_week_data()
model = train(data)
if evaluate(model) > THRESHOLD:
deploy(model)
2. Performance-based:
def monitor_performance():
"""Check model metrics daily."""
recent_accuracy = get_production_accuracy(last_7_days)
if recent_accuracy < BASELINE * 0.95: # 5% drop
trigger_retraining()
3. Data-based:
def check_data_volume():
"""Retrain when enough new data."""
new_samples = count_samples_since_last_training()
if new_samples > MIN_NEW_SAMPLES:
trigger_retraining()
4. Drift-based:
def check_drift():
"""Retrain on significant drift."""
drift_detected = detect_data_drift(reference, current)
if any(d['drifted'] for d in drift_detected.values()):
trigger_retraining()
import time
from datetime import datetime

from airflow import DAG
from airflow.operators.dummy import DummyOperator
from airflow.operators.python import PythonOperator, BranchPythonOperator
def check_triggers():
"""Decide if retraining needed."""
time_based = is_scheduled_time()
perf_based = performance_degraded()
drift_based = drift_detected()
if time_based or perf_based or drift_based:
return 'retrain_model'
else:
return 'skip_training'
def retrain_model():
# Fetch fresh data
data = fetch_data(last_n_days=30)
# Train
model = train(data)
# Evaluate
metrics = evaluate(model, test_data)
# Version
model_version = save_to_registry(model, metrics)
return model_version
def validate_model(model_version):
"""A/B test or shadow deployment."""
# Deploy to 5% traffic
deploy_canary(model_version, traffic_percent=5)
# Monitor for 1 hour
time.sleep(3600)
# Check metrics
canary_metrics = get_canary_metrics(model_version)
baseline_metrics = get_production_metrics()
if canary_metrics['accuracy'] >= baseline_metrics['accuracy']:
return 'deploy_full'
else:
rollback_canary()
return 'keep_current'
def deploy_full(model_version):
"""Full production deployment."""
deploy_production(model_version)
notify_team(f"Model {model_version} deployed")
dag = DAG(
    'automated_retraining',
    start_date=datetime(2024, 1, 1),
    schedule_interval='@daily',
    catchup=False,
)

# Note: in a real DAG, model_version would be passed between tasks via XCom
# (e.g. op_kwargs or ti.xcom_pull), not as direct Python arguments.
check = BranchPythonOperator(
    task_id='check_triggers',
    python_callable=check_triggers,
    dag=dag,
)
retrain = PythonOperator(
    task_id='retrain_model',
    python_callable=retrain_model,
    dag=dag,
)
validate = BranchPythonOperator(
    task_id='validate_model',
    python_callable=validate_model,
    dag=dag,
)
deploy = PythonOperator(
    task_id='deploy_full',
    python_callable=deploy_full,
    dag=dag,
)
skip = DummyOperator(task_id='skip_training', dag=dag)
keep_current = DummyOperator(task_id='keep_current', dag=dag)  # branch target when the canary underperforms

check >> [retrain, skip]
retrain >> validate >> [deploy, keep_current]
Track models across lifecycle.
import mlflow
# Log model during training
with mlflow.start_run():
# Train
model = train(data)
# Log metrics
mlflow.log_metric("accuracy", accuracy)
mlflow.log_metric("f1_score", f1)
# Log model
mlflow.sklearn.log_model(
model,
"model",
registered_model_name="fraud_detection"
)
# Promote to staging
client = mlflow.tracking.MlflowClient()
client.transition_model_version_stage(
name="fraud_detection",
version=5,
stage="Staging"
)
# After validation, promote to production
client.transition_model_version_stage(
name="fraud_detection",
version=5,
stage="Production"
)
# Load production model
model_uri = "models:/fraud_detection/Production"
model = mlflow.sklearn.load_model(model_uri)
Stages: None → Staging → Production → Archived
Metadata tracked:
Compare experiments systematically.
import wandb
# Initialize
wandb.init(
project="image-classification",
config={
"learning_rate": 0.001,
"architecture": "ResNet50",
"batch_size": 32,
"epochs": 10
}
)
# Log during training
for epoch in range(wandb.config.epochs):
train_loss = train_epoch()
val_loss = validate()
wandb.log({
"epoch": epoch,
"train_loss": train_loss,
"val_loss": val_loss
})
# Log final model weights as an artifact
# (wandb.log_artifact expects a file path or a wandb.Artifact, not the model object itself)
save_model(model, "model.pt")  # framework-specific save, e.g. torch.save(...)
wandb.log_artifact("model.pt", name="model", type="model")
# Compare experiments in UI
# wandb.ai/project/runs
What to track:
In the next post, we'll explore Scalable System Design: horizontal scaling, load balancing, caching patterns, and asynchronous processing for AI applications.
This article is part of the series "From Zero to AI Engineer" - Module 9: Deployment Strategy