Building a Multi-Provider AI Router: Never Depend on One LLM Again

7 min read
Building a Multi-Provider AI Router: Never Depend on One LLM Again

When Claude goes down at 2am and your production RAG pipeline starts failing, you’ll wish you’d built a multi-provider AI router. Single points of failure are the enemy of reliable AI systems, yet most developers still hardcode their applications to one LLM provider. I learned this the hard way when Anthropic’s API had an extended outage last month, taking down a client’s entire knowledge base system.

The solution? Build a router that can seamlessly switch between multiple AI providers based on availability, cost, and performance. Today I’ll walk you through building a production-ready multi-provider AI router using FastAPI that can handle local Ollama models, OpenAI, Anthropic Claude, and Google Gemini with intelligent fallback logic.

The Architecture: Smart Routing Strategies

A proper AI router needs more than just “try provider A, then B, then C”. You need intelligent routing based on multiple factors:

Cost-optimised routing sends simple queries to cheaper providers and reserves expensive ones for complex tasks. Latency-based routing prioritises local models for speed-critical applications. Capability-based routing matches specific providers to the tasks they excel at. Health-based failover automatically excludes unhealthy providers from rotation.

Here’s the core router architecture:

import asyncio
import logging
import os
import time
from collections import deque
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Dict, Any

class ProviderType(Enum):
    """Supported AI backend families; values are the wire/config identifiers."""
    OLLAMA = "ollama"
    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    GEMINI = "gemini"

@dataclass
class ProviderConfig:
    """Static configuration for one AI provider instance.

    One config produces one concrete provider in AIRouter; `name` is the
    registry key used for health tracking and logging.
    """
    # Unique name used as the registry key in AIRouter.
    name: str
    provider_type: ProviderType
    # Base URL of the provider's API, without a trailing slash.
    endpoint: str
    # Not required for local providers such as Ollama.
    api_key: Optional[str] = None
    model: str = ""
    max_tokens: int = 4000
    cost_per_token: float = 0.0  # For cost optimization
    priority: int = 1  # Lower = higher priority
    timeout: float = 30.0  # Per-request HTTP timeout in seconds
    enabled: bool = True  # Disabled providers are excluded from routing

class ProviderHealth:
    """Rolling health metrics (latency, failure streak) for one provider."""

    # Number of most-recent response times kept for the rolling average.
    _WINDOW = 100

    def __init__(self):
        # deque(maxlen=...) evicts the oldest entry automatically in O(1),
        # replacing the original list + pop(0) pattern which was O(n)
        # per insertion once the window filled up.
        self.response_times: deque = deque(maxlen=self._WINDOW)
        self.success_rate: float = 1.0
        self.last_failure: Optional[float] = None
        self.consecutive_failures: int = 0

    def record_success(self, response_time: float):
        """Record a successful request and reset the failure streak."""
        self.response_times.append(response_time)
        self.consecutive_failures = 0

    def record_failure(self):
        """Record a failed request with its timestamp."""
        self.consecutive_failures += 1
        self.last_failure = time.time()

    @property
    def avg_response_time(self) -> float:
        """Mean of recent response times.

        Returns a large sentinel (999.0) when no data has been recorded yet,
        so unknown providers sort last under latency-optimized routing.
        """
        if not self.response_times:
            return 999.0
        return sum(self.response_times) / len(self.response_times)

    @property
    def is_healthy(self) -> bool:
        """False only after 3+ consecutive failures within the last 5 minutes."""
        if self.consecutive_failures >= 3:
            if self.last_failure and time.time() - self.last_failure < 300:
                return False
        return True

Provider Implementations: Unified Interface

Each AI provider has its own API quirks, but your application shouldn’t care. Create a unified interface that abstracts away the differences:

from abc import ABC, abstractmethod
import httpx
import json

class BaseProvider(ABC):
    """Abstract base for all AI providers.

    Holds the shared config and an async HTTP client, and defines the
    generate/health_check contract every concrete provider implements.
    """
    def __init__(self, config: ProviderConfig):
        self.config = config
        # One client per provider so each gets its own configured timeout.
        # NOTE(review): the client is never closed explicitly — consider an
        # aclose() hook if providers are ever torn down at runtime.
        self.client = httpx.AsyncClient(timeout=config.timeout)
        
    @abstractmethod
    async def generate(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """Return a dict with keys: content, model, provider, tokens_used."""
        pass
        
    @abstractmethod
    async def health_check(self) -> bool:
        """Return True when the provider is reachable and serviceable."""
        pass

class OllamaProvider(BaseProvider):
    """Provider for a local Ollama server (/api/generate, non-streaming)."""

    async def generate(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """Generate a completion via Ollama.

        Returns content/model/provider plus a rough whitespace-split token
        count, since this endpoint does not report usage.
        """
        # `or` rather than dict.get's default: callers (e.g. the FastAPI
        # endpoint) may pass an explicit max_tokens=None, which must still
        # fall back to the configured limit.
        max_tokens = kwargs.get("max_tokens") or self.config.max_tokens
        payload = {
            "model": self.config.model,
            "prompt": prompt,
            "stream": False,
            "options": {
                "num_ctx": max_tokens
            }
        }

        response = await self.client.post(
            f"{self.config.endpoint}/api/generate",
            json=payload
        )
        response.raise_for_status()

        data = response.json()
        return {
            "content": data["response"],
            "model": self.config.model,
            "provider": self.config.name,
            "tokens_used": len(data["response"].split())  # Rough estimate
        }

    async def health_check(self) -> bool:
        """Return True when the Ollama tags endpoint answers 200."""
        try:
            response = await self.client.get(f"{self.config.endpoint}/api/tags")
            return response.status_code == 200
        except Exception:
            # Narrowed from a bare `except:`, which would also swallow
            # asyncio.CancelledError and break task cancellation.
            return False

class OpenAIProvider(BaseProvider):
    """Provider for the OpenAI Chat Completions API."""

    async def generate(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """Generate a chat completion; tokens_used comes from the API's usage field."""
        # Explicit None handling: callers (e.g. the FastAPI endpoint) may pass
        # max_tokens=None / temperature=None, which must still fall back to
        # the configured defaults instead of being sent as null.
        max_tokens = kwargs.get("max_tokens") or self.config.max_tokens
        temperature = kwargs.get("temperature")
        payload = {
            "model": self.config.model,
            "messages": [{"role": "user", "content": prompt}],
            "max_tokens": max_tokens,
            "temperature": 0.7 if temperature is None else temperature
        }

        headers = {
            "Authorization": f"Bearer {self.config.api_key}",
            "Content-Type": "application/json"
        }

        response = await self.client.post(
            f"{self.config.endpoint}/chat/completions",
            json=payload,
            headers=headers
        )
        response.raise_for_status()

        data = response.json()
        return {
            "content": data["choices"][0]["message"]["content"],
            "model": data["model"],
            "provider": self.config.name,
            "tokens_used": data["usage"]["total_tokens"]
        }

    async def health_check(self) -> bool:
        """Return True when the models list endpoint answers 200."""
        try:
            headers = {"Authorization": f"Bearer {self.config.api_key}"}
            response = await self.client.get(
                f"{self.config.endpoint}/models",
                headers=headers
            )
            return response.status_code == 200
        except Exception:
            # Narrowed from a bare `except:`, which would also swallow
            # asyncio.CancelledError and break task cancellation.
            return False

class AnthropicProvider(BaseProvider):
    """Provider for the Anthropic Messages API."""

    async def generate(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """Generate a message; tokens_used is the sum of input and output tokens."""
        # `or` rather than dict.get's default: callers (e.g. the FastAPI
        # endpoint) may pass an explicit max_tokens=None, which must still
        # fall back to the configured limit (the API requires max_tokens).
        max_tokens = kwargs.get("max_tokens") or self.config.max_tokens
        payload = {
            "model": self.config.model,
            "max_tokens": max_tokens,
            "messages": [{"role": "user", "content": prompt}]
        }

        headers = {
            "x-api-key": self.config.api_key,
            "Content-Type": "application/json",
            "anthropic-version": "2023-06-01"
        }

        response = await self.client.post(
            f"{self.config.endpoint}/messages",
            json=payload,
            headers=headers
        )
        response.raise_for_status()

        data = response.json()
        return {
            "content": data["content"][0]["text"],
            "model": data["model"],
            "provider": self.config.name,
            "tokens_used": data["usage"]["input_tokens"] + data["usage"]["output_tokens"]
        }

    async def health_check(self) -> bool:
        """Probe availability with a minimal 1-token message.

        Anthropic has no dedicated health endpoint, so this makes a (tiny)
        real API call.
        """
        try:
            payload = {
                "model": self.config.model,
                "max_tokens": 1,
                "messages": [{"role": "user", "content": "hi"}]
            }
            headers = {
                "x-api-key": self.config.api_key,
                "Content-Type": "application/json",
                "anthropic-version": "2023-06-01"
            }
            response = await self.client.post(
                f"{self.config.endpoint}/messages",
                json=payload,
                headers=headers
            )
            return response.status_code == 200
        except Exception:
            # Narrowed from a bare `except:` so CancelledError propagates.
            return False

The Router Engine: Intelligence in Action

Now for the brain of the operation. The router needs to select providers intelligently and handle failures gracefully:

from typing import Union
import random

class RoutingStrategy(Enum):
    """Provider-selection policies understood by AIRouter._select_provider."""
    COST_OPTIMIZED = "cost_optimized"
    LATENCY_OPTIMIZED = "latency_optimized" 
    ROUND_ROBIN = "round_robin"
    PRIORITY_BASED = "priority_based"

class AIRouter:
    """Routes generate() calls across providers with health tracking,
    pluggable selection strategies, and automatic failover."""

    def __init__(self, providers: List[ProviderConfig], strategy: RoutingStrategy = RoutingStrategy.PRIORITY_BASED):
        self.providers: Dict[str, BaseProvider] = {}
        self.health: Dict[str, ProviderHealth] = {}
        self.strategy = strategy
        self.round_robin_index = 0

        # One concrete provider + one fresh health record per config.
        for config in providers:
            self.providers[config.name] = self._create_provider(config)
            self.health[config.name] = ProviderHealth()

    def _create_provider(self, config: ProviderConfig) -> BaseProvider:
        """Map a config's provider_type to its concrete implementation.

        Raises ValueError for unmapped types (e.g. GEMINI until implemented).
        """
        provider_map = {
            ProviderType.OLLAMA: OllamaProvider,
            ProviderType.OPENAI: OpenAIProvider,
            ProviderType.ANTHROPIC: AnthropicProvider,
            # Add GeminiProvider here when implemented
        }

        provider_class = provider_map.get(config.provider_type)
        if not provider_class:
            raise ValueError(f"Unknown provider type: {config.provider_type}")

        return provider_class(config)

    def _get_healthy_providers(self) -> List[str]:
        """Names of providers that are both enabled and currently healthy."""
        return [
            name for name, health in self.health.items()
            if health.is_healthy and self.providers[name].config.enabled
        ]

    def _select_provider(self, healthy_providers: List[str]) -> str:
        """Pick one candidate according to the active routing strategy."""
        if not healthy_providers:
            raise RuntimeError("No healthy providers available")

        if self.strategy == RoutingStrategy.COST_OPTIMIZED:
            # Cheapest per-token provider first.
            return min(healthy_providers,
                      key=lambda p: self.providers[p].config.cost_per_token)

        elif self.strategy == RoutingStrategy.LATENCY_OPTIMIZED:
            # Fastest average response time first.
            return min(healthy_providers,
                      key=lambda p: self.health[p].avg_response_time)

        elif self.strategy == RoutingStrategy.ROUND_ROBIN:
            provider = healthy_providers[self.round_robin_index % len(healthy_providers)]
            self.round_robin_index += 1
            return provider

        else:  # PRIORITY_BASED
            # Lower priority number = higher priority.
            return min(healthy_providers,
                      key=lambda p: self.providers[p].config.priority)

    async def generate(self, prompt: str, max_retries: int = 3, **kwargs) -> Dict[str, Any]:
        """Route one generation request, failing over between providers.

        Providers that fail during THIS request are deprioritized for the
        remaining attempts (the original removed them from a list that was
        rebuilt every iteration, so the same provider could be retried
        immediately). Raises RuntimeError once all attempts are exhausted.
        """
        last_error = None
        failed_this_request = set()

        for attempt in range(max_retries):
            healthy_providers = self._get_healthy_providers()
            # Prefer providers that have not failed during this request;
            # fall back to the full healthy pool if that excludes everyone.
            candidates = [p for p in healthy_providers
                          if p not in failed_this_request] or healthy_providers

            if not candidates:
                await asyncio.sleep(2 ** attempt)  # Exponential backoff
                continue

            provider_name = self._select_provider(candidates)
            provider = self.providers[provider_name]

            start_time = time.time()

            try:
                result = await provider.generate(prompt, **kwargs)
            except Exception as e:
                self.health[provider_name].record_failure()
                failed_this_request.add(provider_name)
                last_error = e
                logging.warning(f"Provider {provider_name} failed: {str(e)}")
            else:
                response_time = time.time() - start_time
                self.health[provider_name].record_success(response_time)
                logging.info(f"Request routed to {provider_name} in {response_time:.2f}s")
                return result

        raise RuntimeError(f"All providers failed after {max_retries} attempts. Last error: {last_error}")

    async def health_check_all(self):
        """Run health checks on all providers concurrently."""
        tasks = [
            self._check_provider_health(name, provider)
            for name, provider in self.providers.items()
        ]
        await asyncio.gather(*tasks)

    async def _check_provider_health(self, name: str, provider: BaseProvider):
        """Record a failure when a provider's health probe fails or raises."""
        try:
            is_healthy = await provider.health_check()
            if not is_healthy:
                self.health[name].record_failure()
        except Exception:
            self.health[name].record_failure()

FastAPI Integration: Production-Ready API

Wrap everything in a FastAPI application with proper monitoring and observability:

from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel
import uvicorn
import asyncio

class GenerateRequest(BaseModel):
    """Request body for POST /generate."""
    prompt: str
    max_tokens: Optional[int] = None  # None -> each provider's configured limit
    temperature: Optional[float] = 0.7
    strategy: Optional[str] = None  # Optional RoutingStrategy value for this request

class GenerateResponse(BaseModel):
    """Response body for POST /generate, including which provider served it."""
    content: str
    model: str
    provider: str
    tokens_used: int
    response_time: float  # End-to-end seconds, measured by the endpoint

app = FastAPI(title="Multi-Provider AI Router", version="1.0.0")

# Provider fleet, tried in ascending `priority` order under PRIORITY_BASED
# routing. API keys are read from the environment instead of being hardcoded
# in source; the literal fallback only keeps the demo snippet runnable.
providers = [
    ProviderConfig(
        name="local_llama",
        provider_type=ProviderType.OLLAMA,
        endpoint="http://localhost:11434",
        model="llama2:7b",
        cost_per_token=0.0,  # Local model: no per-token cost
        priority=1
    ),
    ProviderConfig(
        name="openai_gpt4",
        provider_type=ProviderType.OPENAI,
        endpoint="https://api.openai.com/v1",
        api_key=os.environ.get("OPENAI_API_KEY", "your-api-key"),
        model="gpt-4",
        cost_per_token=0.00003,
        priority=2
    ),
    ProviderConfig(
        name="claude_sonnet",
        provider_type=ProviderType.ANTHROPIC,
        endpoint="https://api.anthropic.com/v1",
        api_key=os.environ.get("ANTHROPIC_API_KEY", "your-api-key"),
        model="claude-3-sonnet-20240229",
        cost_per_token=0.000015,
        priority=3
    )
]

router = AIRouter(providers, RoutingStrategy.PRIORITY_BASED)

@app.on_event("startup")
async def startup_event():
    """Start the periodic background health checker when the app boots."""
    # NOTE(review): on_event is deprecated in newer FastAPI in favour of
    # lifespan handlers; kept for compatibility with the original API.
    asyncio.create_task(health_check_loop())

async def health_check_loop():
    """Sweep all providers' health probes every 30 seconds, forever."""
    while True:
        try:
            await router.health_check_all()
        except Exception:
            # A transient error must not kill the background loop for good
            # (the original had no handler, so one raised exception would
            # silently stop all future health checks).
            logging.exception("Health check sweep failed")
        await asyncio.sleep(30)  # Check every 30 seconds

@app.post("/generate", response_model=GenerateResponse)
async def generate_text(request: GenerateRequest):
    """Generate text via the router, optionally overriding the routing
    strategy for this request only.

    Returns 400 for an unknown strategy name and 500 for provider failures.
    """
    start_time = time.time()

    # NOTE(review): mutating the shared router.strategy is racy under
    # concurrent requests — another in-flight request can observe the
    # override. A per-call strategy argument on AIRouter.generate would be
    # the cleaner fix; the finally below at least guarantees restoration.
    original_strategy = router.strategy
    try:
        if request.strategy:
            try:
                router.strategy = RoutingStrategy(request.strategy)
            except ValueError:
                raise HTTPException(status_code=400,
                                    detail=f"Unknown strategy: {request.strategy}")

        result = await router.generate(
            request.prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature
        )

        response_time = time.time() - start_time

        return GenerateResponse(
            content=result["content"],
            model=result["model"],
            provider=result["provider"],
            tokens_used=result["tokens_used"],
            response_time=response_time
        )

    except HTTPException:
        raise  # Preserve deliberate HTTP status codes (e.g. the 400 above)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Restore even on failure — the original only restored on the
        # success path, so one failed override request would permanently
        # change the router's strategy.
        router.strategy = original_strategy

Need this built for your business?

Get In Touch