Migrate from TensorFlow to PyTorch for better GPU utilization

Architecture Improvement: PyTorch Migration

Current State

Our content moderation model uses TensorFlow 2.10.0:

import tensorflow as tf

model = tf.keras.models.load_model("moderation_model.h5")
prediction = model.predict(image_tensor)

Issues:

  • TensorFlow 2.10.0 has CVE-2024-77777 (related to #13)
  • GPU utilization: 45% (inefficient)
  • Inference time: 180ms p99
  • Model size: 450MB (too large for edge deployment)

Proposed Solution

Migrate to PyTorch 2.1.0:

import torch

# Load the TorchScript model once and keep it on the GPU for inference
model = torch.jit.load("moderation_model.pt").eval().cuda()

with torch.inference_mode(), torch.cuda.amp.autocast():  # no autograd, mixed precision
    prediction = model(image_tensor.cuda())
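
The snippet above loads a TorchScript archive; a minimal sketch of how moderation_model.pt could be produced from the fine-tuned PyTorch model follows. A torchvision ResNet-50 is used purely as a stand-in architecture:

import torch
import torchvision

# Stand-in architecture for illustration; the real fine-tuned moderation model goes here
model = torchvision.models.resnet50(weights=None)
model.eval()

# Trace with a representative input shape, then serialize as TorchScript
example = torch.randn(1, 3, 224, 224)
scripted = torch.jit.trace(model, example)
torch.jit.save(scripted, "moderation_model.pt")

Reaching the 180MB size target may additionally require half-precision weights (model.half()) before tracing; that is an assumption, not part of this plan.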

Benefits:

  • No known CVEs in PyTorch 2.1.0 (latest release at time of writing)
  • GPU utilization: 45% → 85% expected (+40 percentage points)
  • Inference time (p99): 180ms → 95ms expected (-85ms)
  • Model size: 450MB → 180MB (TorchScript)
  • Easier debugging with eager execution

Migration Plan

Phase 1 (Week 1-2): Model Conversion

  1. Export the TensorFlow model to ONNX (see the conversion sketch after this list)
  2. Convert the ONNX model to PyTorch
  3. Fine-tune the converted model
  4. Validate accuracy (target: >99% prediction agreement with the TF model)
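
A minimal sketch of steps 1 and 4 using the tf2onnx and onnxruntime packages; the input shape, tensor name, and file paths are placeholders, and the same agreement check would be repeated against the final PyTorch model:

import numpy as np
import tensorflow as tf
import tf2onnx
import onnxruntime as ort

# Step 1: export the Keras model to ONNX (input signature is a placeholder)
model = tf.keras.models.load_model("moderation_model.h5")
spec = (tf.TensorSpec((None, 224, 224, 3), tf.float32, name="input"),)
tf2onnx.convert.from_keras(model, input_signature=spec, opset=13,
                           output_path="moderation_model.onnx")

# Step 4: measure prediction agreement between the original and converted models
sess = ort.InferenceSession("moderation_model.onnx", providers=["CPUExecutionProvider"])
sample = np.random.rand(8, 224, 224, 3).astype(np.float32)  # use a real eval set in practice
tf_pred = np.argmax(model.predict(sample), axis=1)
onnx_pred = np.argmax(sess.run(None, {"input": sample})[0], axis=1)
print(f"Prediction agreement: {np.mean(tf_pred == onnx_pred):.2%}")  # target: >99%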

Phase 2 (Week 3): Infrastructure

  1. Update Docker images with PyTorch
  2. Update CI/CD pipelines (add a PyTorch smoke test; see the sketch after this list)
  3. Performance benchmarking
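
A minimal CI smoke test for the new image, assuming the container ships torch 2.1.0 with CUDA and the TorchScript artifact is available at a known path:

import torch

# Fail the pipeline early if the runtime or GPU is not what the image promises
assert torch.__version__.startswith("2.1"), torch.__version__
assert torch.cuda.is_available(), "CUDA device not visible inside the container"

# Load the TorchScript artifact and push one dummy batch end to end
model = torch.jit.load("moderation_model.pt").eval().cuda()
with torch.inference_mode():
    out = model(torch.randn(1, 3, 224, 224, device="cuda"))
print("smoke test ok:", tuple(out.shape))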

Phase 3 (Week 4): Deployment

  1. Deploy to staging
  2. A/B test (50% TF, 50% PyTorch; see the routing sketch after this list)
  3. Monitor accuracy and performance against the TF baseline
  4. Full rollout once accuracy and latency targets are met
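
One way the 50/50 split could be implemented at the serving layer; this is only a sketch, and the backend endpoints, request-id key, and routing function are hypothetical:

import hashlib

TF_BACKEND = "http://moderation-tf:8501"        # hypothetical service endpoints
TORCH_BACKEND = "http://moderation-torch:8080"

def pick_backend(request_id: str, torch_fraction: float = 0.5) -> str:
    # Hash the request id so each request is routed to the same backend consistently
    bucket = int(hashlib.sha256(request_id.encode()).hexdigest(), 16) % 100
    return TORCH_BACKEND if bucket < torch_fraction * 100 else TF_BACKEND

# e.g. pick_backend("req-1234") returns one of the two endpoints, stable per request id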

Performance Testing

# Benchmark inference latency
python benchmark.py --framework pytorch --batch-size 32 --iterations 1000
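
benchmark.py is not included in this issue; a minimal sketch of what it might look like, matching the flags above (model path and input shape are placeholders, and only the pytorch branch is shown):

import argparse
import numpy as np
import torch

parser = argparse.ArgumentParser()
parser.add_argument("--framework", choices=["pytorch", "tensorflow"], default="pytorch")
parser.add_argument("--batch-size", type=int, default=32)
parser.add_argument("--iterations", type=int, default=1000)
args = parser.parse_args()  # only the pytorch path is sketched here

model = torch.jit.load("moderation_model.pt").eval().cuda()
batch = torch.randn(args.batch_size, 3, 224, 224, device="cuda")

latencies = []
with torch.inference_mode():
    for _ in range(args.iterations):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        model(batch)
        end.record()
        torch.cuda.synchronize()  # wait for the GPU before reading the timer
        latencies.append(start.elapsed_time(end))  # milliseconds

for q in (50, 95, 99):
    print(f"p{q}: {np.percentile(latencies, q):.1f}ms")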

TensorFlow Results:
- p50: 85ms
- p95: 156ms
- p99: 180ms
- GPU util: 45%

PyTorch Results (Expected):
- p50: 42ms (-43ms)
- p95: 78ms (-78ms)
- p99: 95ms (-85ms)
- GPU util: 85% (+40%)

Risks

  1. Model accuracy degradation: Mitigate with thorough validation
  2. API compatibility: Update API contracts if needed
  3. Production incident: Mitigate with gradual rollout

Related

  • Fixes: #13 (CVE-2024-77777 TensorFlow vulnerability)
  • Related: #18 (Pickle RCE - also in ai-recommendation-engine)