"""Model routing for Claude-compatible requests.""" from __future__ import annotations from dataclasses import dataclass from loguru import logger from config.provider_ids import SUPPORTED_PROVIDER_IDS from config.settings import Settings from core.model_capabilities import find_best_model_for_task from core.session_tracker import SessionTracker from core.task_detector import TaskDetector from providers.rate_limit import GlobalRateLimiter from .gateway_model_ids import decode_gateway_model_id from .models.anthropic import MessagesRequest, TokenCountRequest # Default NIM models to include in auto routing (in order of preference) DEFAULT_NIM_AUTO_MODELS = [ "nvidia_nim/qwen/qwen3-coder-480b-a35b-instruct", "nvidia_nim/z-ai/glm4.7", "nvidia_nim/stepfun-ai/step-3.5-flash", "nvidia_nim/mistralai/mistral-large-3-675b-instruct-2512", "nvidia_nim/abacusai/dracarys-llama-3.1-70b-instruct", "nvidia_nim/bytedance/seed-oss-36b-instruct", "nvidia_nim/mistralai/mistral-nemotron", ] @dataclass(frozen=True, slots=True) class ResolvedModel: original_model: str provider_id: str provider_model: str provider_model_ref: str thinking_enabled: bool @dataclass(frozen=True, slots=True) class RoutedMessagesRequest: request: MessagesRequest resolved: ResolvedModel @dataclass(frozen=True, slots=True) class RoutedTokenCountRequest: request: TokenCountRequest resolved: ResolvedModel class ModelRouter: """Resolve incoming Claude model names to configured provider/model pairs.""" def __init__(self, settings: Settings): self._settings = settings def _is_auto(self, model_name: str) -> bool: """Return whether the model name refers to the virtual 'auto' model.""" name_lower = model_name.lower() return name_lower == "auto" or name_lower == "anthropic/auto" def _normalize_candidate_ref(self, raw_ref: str) -> str | None: """Normalize auto candidate refs to ``provider/model`` when possible.""" candidate = (raw_ref or "").strip() if not candidate: return None provider_id, separator, remainder = candidate.partition("/") if separator and provider_id in SUPPORTED_PROVIDER_IDS and remainder: return f"{provider_id}/{remainder}" # Treat bare model ids and vendor/model ids as NVIDIA NIM models. return f"nvidia_nim/{candidate}" def resolve(self, claude_model_name: str) -> ResolvedModel: # Special virtual model 'auto' maps to the configured default MODEL and # enables provider-side fallbacks. Resolve it to the configured model # while preserving the original requested name. if self._is_auto(claude_model_name): # If the user configured an explicit AUTO_MODEL_ORDER, try each # provider/model pair in order and pick the first provider that is # plausibly configured. Fall back to the single configured MODEL. 
            order_csv = (self._settings.auto_model_order or "").strip()
            if order_csv:
                for cand in [c.strip() for c in order_csv.split(",") if c.strip()]:
                    if "/" not in cand:
                        # Entries must be provider-prefixed; skip malformed ones.
                        continue
                    provider_id = Settings.parse_provider_type(cand)
                    provider_model = Settings.parse_model_name(cand)
                    if self._settings.provider_is_configured(provider_id):
                        thinking_enabled = self._settings.resolve_thinking(
                            claude_model_name
                        )
                        return ResolvedModel(
                            original_model=claude_model_name,
                            provider_id=provider_id,
                            provider_model=provider_model,
                            provider_model_ref=cand,
                            thinking_enabled=thinking_enabled,
                        )

            # No explicit order matched (or none was configured); fall back to
            # the default MODEL.
            provider_model_ref = self._settings.model
            provider_id = Settings.parse_provider_type(provider_model_ref)
            provider_model = Settings.parse_model_name(provider_model_ref)
            thinking_enabled = self._settings.resolve_thinking(claude_model_name)
            return ResolvedModel(
                original_model=claude_model_name,
                provider_id=provider_id,
                provider_model=provider_model,
                provider_model_ref=provider_model_ref,
                thinking_enabled=thinking_enabled,
            )

        (
            direct_provider_id,
            direct_provider_model,
            force_thinking_enabled,
        ) = self._direct_provider_model(claude_model_name)
        if direct_provider_id is not None and direct_provider_model is not None:
            thinking_enabled = (
                force_thinking_enabled
                if force_thinking_enabled is not None
                else self._settings.resolve_thinking(direct_provider_model)
            )
            logger.debug(
                "MODEL DIRECT: '{}' -> provider='{}' model='{}' thinking={}",
                claude_model_name,
                direct_provider_id,
                direct_provider_model,
                thinking_enabled,
            )
            return ResolvedModel(
                original_model=claude_model_name,
                provider_id=direct_provider_id,
                provider_model=direct_provider_model,
                provider_model_ref=claude_model_name,
                thinking_enabled=thinking_enabled,
            )

        provider_model_ref = self._settings.resolve_model(claude_model_name)
        thinking_enabled = self._settings.resolve_thinking(claude_model_name)
        provider_id = Settings.parse_provider_type(provider_model_ref)
        provider_model = Settings.parse_model_name(provider_model_ref)
        if provider_model != claude_model_name:
            logger.debug(
                "MODEL MAPPING: '{}' -> '{}'", claude_model_name, provider_model
            )
        return ResolvedModel(
            original_model=claude_model_name,
            provider_id=provider_id,
            provider_model=provider_model,
            provider_model_ref=provider_model_ref,
            thinking_enabled=thinking_enabled,
        )

    def resolve_candidates(self, claude_model_name: str) -> list[ResolvedModel]:
        """Resolve a model name to a prioritized list of candidates.

        Used by the 'auto' routing logic to implement provider-side failover.
        Considers session load for fair resource sharing across multiple clients.

        Priority order:
        1. AUTO_MODEL_ORDER (if configured)
        2. MODEL (primary)
        3. NVIDIA NIM fallback models (if configured, or DEFAULT_NIM_AUTO_MODELS)
        4. MODEL_OPUS, MODEL_SONNET, MODEL_HAIKU
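
        Example (illustrative; the returned refs depend entirely on which
        providers are configured):
            >>> router = ModelRouter(settings)  # doctest: +SKIP
            >>> [c.provider_model_ref for c in router.resolve_candidates("auto")]
            ['nvidia_nim/qwen/qwen3-coder-480b-a35b-instruct', ...]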
        """
        if not self._is_auto(claude_model_name):
            return [self.resolve(claude_model_name)]

        healthy_candidates: list[ResolvedModel] = []
        blocked_candidates: list[ResolvedModel] = []
        seen: set[str] = set()
        session_tracker = SessionTracker.get_instance()

        def add_candidate(ref: str | None, source: str) -> None:
            normalized_ref = self._normalize_candidate_ref(ref or "")
            if normalized_ref is None or normalized_ref in seen:
                return
            provider_id = Settings.parse_provider_type(normalized_ref)
            provider_model = Settings.parse_model_name(normalized_ref)
            if self._settings.provider_is_configured(provider_id):
                seen.add(normalized_ref)
                resolved = ResolvedModel(
                    original_model=claude_model_name,
                    provider_id=provider_id,
                    provider_model=provider_model,
                    provider_model_ref=normalized_ref,
                    thinking_enabled=self._settings.resolve_thinking(
                        claude_model_name
                    ),
                )
                limiter = GlobalRateLimiter.get_scoped_instance(provider_id)
                is_blocked = limiter.is_blocked()
                # The Zen provider has no rate limits, so never consider it blocked.
                if provider_id == "zen":
                    is_blocked = False
                # Check model health (recent failures).
                is_healthy = limiter.is_healthy(normalized_ref)
                if is_blocked or not is_healthy:
                    reason = "BLOCKED" if is_blocked else "UNHEALTHY"
                    logger.debug(
                        "Routing: candidate '{}' (from {}) is {} (health={})",
                        normalized_ref,
                        source,
                        reason,
                        is_healthy,
                    )
                    blocked_candidates.append(resolved)
                else:
                    logger.debug(
                        "Routing: added candidate '{}' (from {})",
                        normalized_ref,
                        source,
                    )
                    healthy_candidates.append(resolved)
            else:
                logger.debug(
                    "Routing: candidate '{}' (from {}) is NOT CONFIGURED",
                    normalized_ref,
                    source,
                )

        # 1. AUTO_MODEL_ORDER (user-configured priority)
        order_csv = (self._settings.auto_model_order or "").strip()
        if order_csv:
            for cand in [c.strip() for c in order_csv.split(",") if c.strip()]:
                add_candidate(cand, "AUTO_MODEL_PRIORITY")

        # 2. Primary MODEL
        add_candidate(self._settings.model, "MODEL")

        # 3. NVIDIA NIM fallbacks: use the configured list or the defaults.
        nim_csv = (self._settings.nvidia_nim_fallback_models or "").strip()
        if nim_csv:
            for cand in [c.strip() for c in nim_csv.split(",") if c.strip()]:
                add_candidate(cand, "NVIDIA_NIM_FALLBACK_MODELS")
        else:
            # Use the default NIM models when no explicit fallback is configured.
            for cand in DEFAULT_NIM_AUTO_MODELS:
                add_candidate(cand, "DEFAULT_NIM_AUTO_MODELS")
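
        # Note: bare refs such as "qwen/qwen3-coder-480b-a35b-instruct" are
        # normalized to "nvidia_nim/..." by _normalize_candidate_ref above.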
        # 4. Model-specific overrides
        add_candidate(self._settings.model_opus, "MODEL_OPUS")
        add_candidate(self._settings.model_sonnet, "MODEL_SONNET")
        add_candidate(self._settings.model_haiku, "MODEL_HAIKU")

        # Smart ordering: Zen goes first (no rate limits), then sort by load.
        def provider_priority(c: ResolvedModel) -> tuple[int, int]:
            # Priority: zen > others, then by active request count.
            is_zen = 0 if c.provider_id == "zen" else 1
            active = session_tracker._provider_active.get(c.provider_id, 0)
            return (is_zen, active)

        healthy_candidates.sort(key=provider_priority)
        all_candidates = healthy_candidates + blocked_candidates

        logger.info(
            "Routing: resolved '{}' to {} candidates: {}",
            claude_model_name,
            len(all_candidates),
            ", ".join(c.provider_model_ref for c in all_candidates),
        )
        return all_candidates

    def _direct_provider_model(
        self, model_name: str
    ) -> tuple[str | None, str | None, bool | None]:
        decoded = decode_gateway_model_id(model_name)
        if decoded is not None:
            if decoded.provider_id not in SUPPORTED_PROVIDER_IDS:
                return None, None, None
            return (
                decoded.provider_id,
                decoded.provider_model,
                decoded.force_thinking_enabled,
            )

        provider_id, separator, provider_model = model_name.partition("/")
        if not separator:
            return None, None, None
        if provider_id not in SUPPORTED_PROVIDER_IDS:
            return None, None, None
        if not provider_model:
            return None, None, None
        return provider_id, provider_model, None

    def resolve_messages_request(
        self, request: MessagesRequest
    ) -> RoutedMessagesRequest:
        """Return an internal routed request context."""
        resolved = self.resolve(request.model)
        routed = request.model_copy(deep=True)
        routed.model = resolved.provider_model
        return RoutedMessagesRequest(request=routed, resolved=resolved)

    def resolve_token_count_request(
        self, request: TokenCountRequest
    ) -> RoutedTokenCountRequest:
        """Return an internal token-count request context."""
        resolved = self.resolve(request.model)
        routed = request.model_copy(
            update={"model": resolved.provider_model}, deep=True
        )
        return RoutedTokenCountRequest(request=routed, resolved=resolved)

    def resolve_with_task_awareness(
        self,
        claude_model_name: str,
        messages: list,
    ) -> ResolvedModel:
        """Resolve model with task-based capability matching.

        For the 'auto' model, detects task requirements and routes to the
        best-capable model.
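
        Falls back to load-based selection when detection confidence is below
        0.7 or only general text generation is required.

        Example (illustrative; `messages` is an Anthropic-style message list):
            >>> router.resolve_with_task_awareness("auto", messages)  # doctest: +SKIP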
""" if not self._is_auto(claude_model_name): return self.resolve(claude_model_name) # Detect what capabilities are needed detector = TaskDetector() requirements = detector.detect_requirements(messages) logger.info( "Task-aware routing: detected requirements={} confidence={:.2f}", requirements.required_capabilities, requirements.confidence, ) # Get available candidates candidates = self.resolve_candidates(claude_model_name) if not candidates: # Fallback to default return self.resolve(claude_model_name) # If confidence is low or only general text needed, use load-based selection if requirements.confidence < 0.7 or ( not requirements.requires_vision and not requirements.requires_coding and not requirements.requires_reasoning ): logger.debug( "Task-aware routing: low confidence, using load-based selection" ) return candidates[0] # Find best model matching required capabilities required_caps = set() if requirements.requires_coding: required_caps.add("coding") if requirements.requires_reasoning: required_caps.add("reasoning") if requirements.requires_vision: required_caps.add("vision") if required_caps: model_refs = [c.provider_model_ref for c in candidates] best = find_best_model_for_task(required_caps, model_refs) if best: # Find the matching candidate for cand in candidates: if cand.provider_model_ref == best.model_ref: logger.info( "Task-aware routing: selected {} for capabilities={}", best.model_ref, required_caps, ) return cand # Default to first candidate (load-balanced) return candidates[0] def get_routing_hint(self, messages: list) -> str: """Get a hint about what kind of model would be best.""" detector = TaskDetector() requirements = detector.detect_requirements(messages) return detector.get_priority_hint(requirements)