"""Application services for the Claude-compatible API.""" from __future__ import annotations import traceback import uuid from collections.abc import AsyncIterator, Callable from typing import Any from fastapi import HTTPException, Request from fastapi.responses import StreamingResponse from loguru import logger from config.settings import Settings, get_settings from core.anthropic import get_token_count, get_user_facing_error_message from core.anthropic.sse import ANTHROPIC_SSE_RESPONSE_HEADERS, format_sse_event from core.session_tracker import SessionTracker from providers.base import BaseProvider from providers.exceptions import ( InvalidRequestError, OverloadedError, ProviderError, RateLimitError, ) from .model_router import ModelRouter, ResolvedModel from .models.anthropic import MessagesRequest, TokenCountRequest from .models.responses import TokenCountResponse from .optimization_handlers import try_optimizations from .web_tools.egress import WebFetchEgressPolicy from .web_tools.request import ( is_web_server_tool_request, openai_chat_upstream_server_tool_error, ) TokenCounter = Callable[[list[Any], str | list[Any] | None, list[Any] | None], int] ProviderGetter = Callable[[str], BaseProvider] # Providers that use ``/chat/completions`` + Anthropic-to-OpenAI conversion (not native Messages). 
_OPENAI_CHAT_UPSTREAM_IDS = frozenset({"nvidia_nim", "groq", "cerebras", "silicon"})

# Lower-case substrings that mark an otherwise-unclassified exception as
# transient: the failing model is penalized and the next candidate is tried.
_TRANSIENT_ERROR_KEYWORDS = (
    "timeout",
    "connection",
    "refused",
    "reset",
    "unavailable",
    "service",
)


def anthropic_sse_streaming_response(
    body: AsyncIterator[str],
) -> StreamingResponse:
    """Return a :class:`StreamingResponse` for Anthropic-style SSE streams."""
    return StreamingResponse(
        body,
        media_type="text/event-stream",
        headers=ANTHROPIC_SSE_RESPONSE_HEADERS,
    )


def _http_status_for_unexpected_service_exception(_exc: BaseException) -> int:
    """HTTP status for uncaught non-provider failures (stable client contract).

    Currently always 500; the exception argument is ignored.
    """
    return 500


def _log_unexpected_service_exception(
    settings: Settings,
    exc: BaseException,
    *,
    context: str,
    request_id: str | None = None,
) -> None:
    """Log service-layer failures without echoing exception text unless opted in.

    When ``settings.log_api_error_tracebacks`` is true, the exception message and
    the current traceback are logged. Otherwise only the exception type name is
    logged, so exception text (which may contain request data) stays out of logs.
    """
    if settings.log_api_error_tracebacks:
        if request_id is not None:
            logger.error("{} request_id={}: {}", context, request_id, exc)
        else:
            logger.error("{}: {}", context, exc)
        logger.error(traceback.format_exc())
        return
    if request_id is not None:
        logger.error(
            "{} request_id={} exc_type={}",
            context,
            request_id,
            type(exc).__name__,
        )
    else:
        logger.error("{} exc_type={}", context, type(exc).__name__)


def _require_non_empty_messages(messages: list[Any]) -> None:
    """Raise :class:`InvalidRequestError` when *messages* is empty."""
    if not messages:
        raise InvalidRequestError("messages cannot be empty")


def _is_transient_error(exc: BaseException) -> bool:
    """Heuristically classify *exc* as transient by keyword-matching its message."""
    message = str(exc).lower()
    return any(keyword in message for keyword in _TRANSIENT_ERROR_KEYWORDS)


def _sse_api_error_event(message: str) -> str:
    """Format an Anthropic-style SSE ``error`` event of type ``api_error``."""
    return format_sse_event(
        "error",
        {
            "type": "error",
            "error": {
                "type": "api_error",
                "message": message,
            },
        },
    )


def _get_client_ip(request: Request) -> str | None:
    """Extract the client IP from gateway headers, or return None for direct connections.

    Headers are checked in order: ``X-Forwarded-For`` (left-most entry),
    ``X-Real-IP``, ``X-Client-IP``. If only a ``Via`` header is present, the
    immediate peer address (the gateway/proxy itself) is returned. These
    headers can be set by the client, so the result is only suitable for
    grouping requests, not for authentication.
    """
    forwarded = request.headers.get("X-Forwarded-For")
    if forwarded:
        return forwarded.split(",")[0].strip()
    real_ip = request.headers.get("X-Real-IP")
    if real_ip:
        return real_ip
    client_ip = request.headers.get("X-Client-IP")
    if client_ip:
        return client_ip
    if request.headers.get("Via"):
        # Gateway/proxy IP. ``request.client`` may be None when the server does
        # not supply a peer address; dereferencing it would raise AttributeError.
        return request.client.host if request.client is not None else None
    return None  # Direct connection


def _get_session_id(request: Request) -> str:
    """Get a session ID from ``X-Session-ID`` or fall back to the gateway IP.

    Claude Code sends ``X-Session-ID`` when started with ``--session-id <id>``.
    Without that header the ID is ``"gateway_<ip>"`` when a client IP can be
    derived from gateway headers, and ``"direct"`` otherwise.
    """
    session = request.headers.get("X-Session-ID")
    if session:
        return session
    ip = _get_client_ip(request)
    return f"gateway_{ip}" if ip else "direct"


class ClaudeProxyService:
    """Coordinate request optimization, model routing, and providers."""

    def __init__(
        self,
        settings: Settings,
        provider_getter: ProviderGetter,
        model_router: ModelRouter | None = None,
        token_counter: TokenCounter = get_token_count,
    ):
        """Wire the service to its collaborators.

        Args:
            settings: Configuration used for routing, web tools, logging and
                session retention.
            provider_getter: Resolves a provider ID to a provider instance.
            model_router: Router to use; defaults to ``ModelRouter(settings)``.
            token_counter: Counts input tokens from (messages, system, tools).
        """
        self._settings = settings
        self._provider_getter = provider_getter
        self._model_router = model_router or ModelRouter(settings)
        self._token_counter = token_counter
        # Use the injected settings (not the global ``get_settings()``) so the
        # retention window agrees with every other setting this service reads.
        self._session_tracker = SessionTracker.get_instance(
            retention_seconds=settings.session_retention_minutes * 60
        )

    def create_message(self, request: Request, request_data: MessagesRequest) -> object:
        """Create a message response or streaming response with optional failover.

        When routing yields several candidates, the response is an SSE stream
        driven by :meth:`_stream_with_fallbacks`; otherwise the single candidate
        is handled by :meth:`_create_single_message`.

        Raises:
            ProviderError: Re-raised unchanged.
            HTTPException: For any other failure, with a user-facing message.
                NOTE(review): empty messages / no candidates raise
                ``InvalidRequestError``; whether that is a ``ProviderError``
                subclass is defined in ``providers.exceptions``.
        """
        try:
            _require_non_empty_messages(request_data.messages)
            candidates = self._model_router.resolve_candidates(request_data.model)
            if not candidates:
                raise InvalidRequestError(
                    f"No configured models available for '{request_data.model}'"
                )
            primary = candidates[0]
            logger.info(
                "REQUEST_MODEL_ROUTING: requested={} resolved_provider={} resolved_model={}",
                request_data.model,
                primary.provider_id,
                primary.provider_model,
            )
            # With multiple candidates, wrap the stream in a failover loop.
            if len(candidates) > 1:
                return anthropic_sse_streaming_response(
                    self._stream_with_fallbacks(request, candidates, request_data)
                )
            return self._create_single_message(request, primary, request_data)
        except ProviderError:
            raise
        except Exception as e:
            _log_unexpected_service_exception(
                self._settings, e, context="CREATE_MESSAGE_ERROR"
            )
            raise HTTPException(
                status_code=_http_status_for_unexpected_service_exception(e),
                detail=get_user_facing_error_message(e),
            ) from e

    def _create_single_message(
        self, request: Request, resolved: ResolvedModel, request_data: MessagesRequest
    ) -> object:
        """Create a single message response from a resolved model.

        Order of handling: reject server-tool requests that an OpenAI-chat
        upstream cannot serve; answer web server tool requests locally when
        enabled; return an optimization short-circuit if one applies; otherwise
        stream from the resolved provider.
        """
        routed_request = request_data.model_copy(deep=True)
        routed_request.model = resolved.provider_model
        if resolved.provider_id in _OPENAI_CHAT_UPSTREAM_IDS:
            tool_err = openai_chat_upstream_server_tool_error(
                routed_request,
                web_tools_enabled=self._settings.enable_web_server_tools,
            )
            if tool_err is not None:
                raise InvalidRequestError(tool_err)
        if self._settings.enable_web_server_tools and is_web_server_tool_request(
            routed_request
        ):
            from .web_tools.streaming import stream_web_server_tool_response

            input_tokens = self._token_counter(
                routed_request.messages, routed_request.system, routed_request.tools
            )
            logger.info("Optimization: Handling Anthropic web server tool")
            egress = WebFetchEgressPolicy(
                allow_private_network_targets=self._settings.web_fetch_allow_private_networks,
                allowed_schemes=self._settings.web_fetch_allowed_scheme_set(),
            )
            return anthropic_sse_streaming_response(
                stream_web_server_tool_response(
                    routed_request,
                    input_tokens=input_tokens,
                    web_fetch_egress=egress,
                    verbose_client_errors=self._settings.log_api_error_tracebacks,
                ),
            )
        optimized = try_optimizations(routed_request, self._settings)
        if optimized is not None:
            return optimized
        provider = self._provider_getter(resolved.provider_id)
        provider.preflight_stream(
            routed_request,
            thinking_enabled=resolved.thinking_enabled,
        )
        session_id = _get_session_id(request)
        self._session_tracker.track_request_sync(session_id, resolved.provider_id)
        request_id = f"req_{uuid.uuid4().hex[:12]}"
        logger.info(
            "API_REQUEST: request_id={} model={} messages={}",
            request_id,
            routed_request.model,
            len(routed_request.messages),
        )
        input_tokens = self._token_counter(
            routed_request.messages, routed_request.system, routed_request.tools
        )
        return anthropic_sse_streaming_response(
            provider.stream_response(
                routed_request,
                input_tokens=input_tokens,
                request_id=request_id,
                thinking_enabled=resolved.thinking_enabled,
            ),
        )

    async def _stream_with_fallbacks(
        self,
        request: Request,
        candidates: list[ResolvedModel],
        request_data: MessagesRequest,
    ) -> AsyncIterator[str]:
        """Stream from the first candidate that succeeds, falling back on failure.

        For each candidate, in order:

        * skip it (no penalty) while its provider's scoped rate limiter reports
          it blocked, except for the ``"zen"`` provider;
        * skip it when the limiter reports the model as unhealthy;
        * otherwise preflight, track the session and stream provider events.

        A failure before any event is yielded is logged and the next candidate
        is tried. Rate-limit/overload errors, ``TimeoutError`` and errors whose
        message matches ``_TRANSIENT_ERROR_KEYWORDS`` also record a failure on
        the limiter. Once an event has been yielded the stream is committed: a
        later failure emits an SSE ``error`` event and re-raises instead of
        splicing another provider's output into the same response.

        If no candidate succeeds, an ``api_error`` SSE event is yielded and the
        last error is raised (``InvalidRequestError`` when every candidate was
        skipped as blocked).
        """
        from providers.rate_limit import GlobalRateLimiter

        last_exc: Exception | None = None
        total = len(candidates)
        for attempt, resolved in enumerate(candidates, start=1):
            # Reset per candidate so failure accounting never hits an unbound
            # name or the previous candidate's limiter when an exception fires
            # before this candidate's limiter has been resolved.
            limiter = None
            committed = False
            try:
                limiter = GlobalRateLimiter.get_scoped_instance(resolved.provider_id)
                if limiter.is_blocked() and resolved.provider_id != "zen":
                    # Silently skip — no failure penalty for a temporary rate limit.
                    logger.debug(
                        "Skipping blocked provider '{}' (no penalty)",
                        resolved.provider_id,
                    )
                    continue
                if not limiter.is_healthy(resolved.provider_model_ref):
                    logger.warning(
                        "Provider '{}' has recent failures, skipping to next candidate...",
                        resolved.provider_model_ref,
                    )
                    last_exc = Exception("Recent failures")
                    continue

                provider = self._provider_getter(resolved.provider_id)
                routed_request = request_data.model_copy(deep=True)
                routed_request.model = resolved.provider_model
                provider.preflight_stream(
                    routed_request,
                    thinking_enabled=resolved.thinking_enabled,
                )
                session_id = _get_session_id(request)
                self._session_tracker.track_request_sync(
                    session_id, resolved.provider_id
                )
                request_id = f"req_{uuid.uuid4().hex[:12]}"
                logger.info(
                    "API_REQUEST (auto fallback {}/{}): request_id={} provider={} model={}",
                    attempt,
                    total,
                    request_id,
                    resolved.provider_id,
                    resolved.provider_model,
                )
                input_tokens = self._token_counter(
                    routed_request.messages, routed_request.system, routed_request.tools
                )
                async for event in provider.stream_response(
                    routed_request,
                    input_tokens=input_tokens,
                    request_id=request_id,
                    thinking_enabled=resolved.thinking_enabled,
                ):
                    # The first yielded event commits the response to this provider.
                    committed = True
                    yield event
                return  # Success, exit the fallback loop.
            except Exception as exc:
                penalize = isinstance(
                    exc, (RateLimitError, OverloadedError, TimeoutError)
                ) or _is_transient_error(exc)

                if committed:
                    # Output from this provider already reached the client, so
                    # falling back would concatenate two different responses.
                    logger.error(
                        "Provider '{}' failed mid-stream ({}); not falling back after output was sent",
                        resolved.provider_id,
                        type(exc).__name__,
                    )
                    if penalize and limiter is not None:
                        limiter.record_failure(resolved.provider_model_ref)
                    yield _sse_api_error_event(get_user_facing_error_message(exc))
                    raise

                if isinstance(exc, (RateLimitError, OverloadedError)):
                    logger.warning(
                        "Provider '{}' is rate limited or overloaded ({}). Trying next candidate...",
                        resolved.provider_id,
                        exc.status_code,
                    )
                elif isinstance(exc, TimeoutError):
                    # Timeout = slow model; try the next candidate for a faster response.
                    logger.warning(
                        "Provider '{}' timed out ({}). Trying next candidate...",
                        resolved.provider_id,
                        type(exc).__name__,
                    )
                elif penalize:
                    logger.warning(
                        "Provider '{}' failed with transient error ({}): {}. Trying next candidate...",
                        resolved.provider_id,
                        type(exc).__name__,
                        exc,
                    )
                else:
                    logger.error(
                        "Provider '{}' failed with unexpected error: {}. Trying next candidate...",
                        resolved.provider_id,
                        exc,
                    )
                if penalize and limiter is not None:
                    limiter.record_failure(resolved.provider_model_ref)
                last_exc = exc

        err_msg = str(last_exc) if last_exc is not None else "No candidates succeeded"
        yield _sse_api_error_event(f"All fallback candidates failed: {err_msg}")
        if last_exc is not None:
            raise last_exc
        raise InvalidRequestError("No candidates succeeded")

    def count_tokens(self, request_data: TokenCountRequest) -> TokenCountResponse:
        """Count tokens for a request after applying configured model routing.

        Raises:
            ProviderError: Re-raised unchanged.
            HTTPException: For any other failure, with a user-facing message.
        """
        request_id = f"req_{uuid.uuid4().hex[:12]}"
        with logger.contextualize(request_id=request_id):
            try:
                _require_non_empty_messages(request_data.messages)
                routed = self._model_router.resolve_token_count_request(request_data)
                tokens = self._token_counter(
                    routed.request.messages, routed.request.system, routed.request.tools
                )
                logger.info(
                    "COUNT_TOKENS: request_id={} model={} messages={} input_tokens={}",
                    request_id,
                    routed.request.model,
                    len(routed.request.messages),
                    tokens,
                )
                return TokenCountResponse(input_tokens=tokens)
            except ProviderError:
                raise
            except Exception as e:
                _log_unexpected_service_exception(
                    self._settings,
                    e,
                    context="COUNT_TOKENS_ERROR",
                    request_id=request_id,
                )
                raise HTTPException(
                    status_code=_http_status_for_unexpected_service_exception(e),
                    detail=get_user_facing_error_message(e),
                ) from e