"""Abstract base class for AI chat providers."""

import logging
import math
import random
import time
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from typing import Any

import requests

from src.utils import redact_secrets

logger = logging.getLogger(__name__)

# Request timeouts (connect, read) in seconds
REQUEST_TIMEOUT = (10, 30)

# Retry configuration
MAX_RETRIES = 3
BACKOFF_BASE = 2.0
BACKOFF_MAX = 60.0

# Realistic Chrome User-Agent
USER_AGENT = (
    "Mozilla/5.0 (X11; Linux x86_64) "
    "AppleWebKit/537.36 (KHTML, like Gecko) "
    "Chrome/121.0.0.0 Safari/537.36"
)


class ProviderError(Exception):
    """Raised by provider methods when an operation fails.

    Attributes:
        provider_name: e.g. "chatgpt" or "claude"
        operation: e.g. "list_conversations" or "get_conversation"
        original: The underlying exception that triggered this error.
    """

    def __init__(self, provider_name: str, operation: str, original: Exception) -> None:
        self.provider_name = provider_name
        self.operation = operation
        self.original = original
        super().__init__(f"[{provider_name}] {operation} failed: {original}")


class BaseProvider(ABC):
    """Abstract base for chat providers (ChatGPT, Claude, future FileProvider)."""

    #: Provider identifier used in normalized schema and file paths
    provider_name: str = ""

    def __init__(self, session: requests.Session | None = None) -> None:
        """Store the HTTP session and install the default request headers.

        Args:
            session: Session to reuse. A new ``requests.Session`` is created
                when omitted. The default headers below are written onto the
                session, replacing any same-named headers it already has.
        """
        # Explicit None check: do not rely on the truthiness of a session object.
        self._session = session if session is not None else requests.Session()
        self._session.headers.update(
            {
                "User-Agent": USER_AGENT,
                "Accept": "application/json",
                "Accept-Language": "en-US,en;q=0.9",
            }
        )

    # ------------------------------------------------------------------
    # Abstract interface — subclasses must implement these
    # ------------------------------------------------------------------

    @abstractmethod
    def list_conversations(self, offset: int = 0, limit: int = 100) -> list[dict]:
        """Return one page of conversations from the provider API."""

    @abstractmethod
    def get_conversation(self, conv_id: str) -> dict:
        """Return the full conversation detail for a single ID."""

    @abstractmethod
    def normalize_conversation(self, raw: dict) -> dict:
        """Transform provider-specific schema to the common normalized schema."""

    # ------------------------------------------------------------------
    # Concrete helpers
    # ------------------------------------------------------------------

    def fetch_all_conversations(self, since: datetime | None = None) -> list[dict]:
        """Fetch every conversation, handling pagination automatically.

        Pages are requested through ``list_conversations`` until an empty page
        or a page shorter than the requested limit is returned.

        Args:
            since: If provided, only return conversations updated at or after
                this datetime. Filtering is done client-side, after all pages
                have been fetched. A naive datetime is assumed to be UTC.

        Returns:
            List of raw conversation dicts (not yet normalized).

        Raises:
            ProviderError: If ``list_conversations`` fails. Non-ProviderError
                exceptions are wrapped with operation
                "fetch_all_conversations".
        """
        all_convs: list[dict] = []
        offset = 0
        limit = 100
        page = 0

        while True:
            page += 1
            logger.debug(
                "[%s] Fetching conversation list page %d (offset=%d, limit=%d)",
                self.provider_name,
                page,
                offset,
                limit,
            )
            try:
                batch = self.list_conversations(offset=offset, limit=limit)
            except ProviderError:
                raise
            except Exception as e:
                raise ProviderError(self.provider_name, "fetch_all_conversations", e) from e

            if not batch:
                break

            all_convs.extend(batch)
            logger.debug(
                "[%s] Page %d: got %d conversations (total so far: %d)",
                self.provider_name,
                page,
                len(batch),
                len(all_convs),
            )

            if len(batch) < limit:
                # Received fewer than requested — we've reached the end
                break

            offset += len(batch)

        logger.info("[%s] Fetched %d total conversations", self.provider_name, len(all_convs))

        if since is None:
            return all_convs

        filtered = self._filter_since(all_convs, since)
        logger.info(
            "[%s] After --since filter: %d/%d conversations",
            self.provider_name,
            len(filtered),
            len(all_convs),
        )
        return filtered

    def _filter_since(self, conversations: list[dict], since: datetime) -> list[dict]:
        """Keep the conversations whose update timestamp is at or after *since*.

        The timestamp is read from "updated_at", falling back to
        "update_time". Conversations with no timestamp, or with one that
        cannot be parsed or compared, are kept rather than dropped.
        """
        # Imported once, outside the per-item try block, so that a broken
        # import surfaces as an error instead of silently disabling the filter.
        from src.utils import _parse_dt

        cutoff = self._as_utc(since)
        kept: list[dict] = []
        for conv in conversations:
            updated_raw = conv.get("updated_at") or conv.get("update_time") or ""
            if not updated_raw:
                kept.append(conv)
                continue
            try:
                is_recent = self._as_utc(_parse_dt(updated_raw)) >= cutoff
            except Exception:
                # Deliberate best-effort: include if we can't parse the date.
                logger.debug(
                    "[%s] Could not parse update timestamp %r; keeping conversation",
                    self.provider_name,
                    updated_raw,
                )
                is_recent = True
            if is_recent:
                kept.append(conv)
        return kept

    @staticmethod
    def _as_utc(value: datetime) -> datetime:
        """Return *value* as a timezone-aware UTC datetime.

        Naive datetimes are assumed to already be in UTC. Aware datetimes are
        converted, so two instants compare correctly whatever zone they were
        expressed in.
        """
        if value.tzinfo is None or value.utcoffset() is None:
            return value.replace(tzinfo=timezone.utc)
        return value.astimezone(timezone.utc)

    def _make_request(
        self,
        method: str,
        url: str,
        debug_log_body: bool = False,
        **kwargs: Any,
    ) -> dict | list:
        """Make an authenticated HTTP request with retry logic and unified error handling.

        One initial attempt is made, followed by up to ``MAX_RETRIES`` retries
        for 429 responses, 5xx responses, connection errors and timeouts.
        Every other failure is raised immediately.

        Args:
            method: HTTP method ("GET", "POST", etc.)
            url: Full URL.
            debug_log_body: If True, log redacted response body at DEBUG level.
            **kwargs: Passed to requests.Session.request(). ``timeout``
                defaults to ``REQUEST_TIMEOUT`` when not supplied.

        Returns:
            Parsed JSON response body.

        Raises:
            ProviderError: On 401, exhausted retries, a non-JSON response
                body, or any other unrecoverable request error.
        """
        kwargs.setdefault("timeout", REQUEST_TIMEOUT)
        operation = f"{method} {url}"

        attempt = 0
        last_exc: Exception | None = None

        # attempt runs from 1 to MAX_RETRIES + 1: the first try plus the retries.
        while attempt <= MAX_RETRIES:
            attempt += 1
            start = time.monotonic()

            try:
                response = self._session.request(method, url, **kwargs)
                elapsed_ms = int((time.monotonic() - start) * 1000)

                logger.debug(
                    "[%s] %s %s → %d %s (%dms)",
                    self.provider_name,
                    method,
                    url,
                    response.status_code,
                    response.reason,
                    elapsed_ms,
                )

                # ── 401: token expired / invalid ──────────────────────────
                if response.status_code == 401:
                    self._handle_401()
                    # Safety net in case a subclass override of _handle_401
                    # returns instead of raising.
                    raise ProviderError(
                        self.provider_name,
                        operation,
                        RuntimeError("401 Unauthorized"),
                    )

                # ── 429: rate limited ──────────────────────────────────────
                if response.status_code == 429:
                    if attempt > MAX_RETRIES:
                        raise ProviderError(
                            self.provider_name,
                            operation,
                            RuntimeError("429 Rate Limited — max retries exhausted"),
                        )
                    # Pass the response so a Retry-After header is honoured.
                    wait = self._backoff_wait(response, attempt)
                    logger.warning(
                        "[%s] Rate limited. Retry %d/%d in %.1fs",
                        self.provider_name,
                        attempt,
                        MAX_RETRIES,
                        wait,
                    )
                    time.sleep(wait)
                    continue

                # ── 5xx: server error ──────────────────────────────────────
                if response.status_code >= 500:
                    if attempt > MAX_RETRIES:
                        raise ProviderError(
                            self.provider_name,
                            operation,
                            RuntimeError(
                                f"HTTP {response.status_code} after {MAX_RETRIES} retries"
                            ),
                        )
                    wait = self._backoff_wait(None, attempt)
                    logger.warning(
                        "[%s] HTTP %d. Retry %d/%d in %.1fs",
                        self.provider_name,
                        response.status_code,
                        attempt,
                        MAX_RETRIES,
                        wait,
                    )
                    time.sleep(wait)
                    continue

                # ── Other HTTP errors ──────────────────────────────────────
                response.raise_for_status()

                # ── Success ────────────────────────────────────────────────
                try:
                    body = response.json()
                except ValueError as e:
                    # A 2xx response whose body is not JSON (e.g. an HTML page
                    # or an empty body). Every JSON decode error raised by
                    # requests is a ValueError subclass.
                    raise ProviderError(
                        self.provider_name,
                        operation,
                        RuntimeError(
                            f"HTTP {response.status_code} response body is not valid JSON: {e}"
                        ),
                    ) from e

                if debug_log_body:
                    logger.debug(
                        "[%s] Response body (redacted): %s",
                        self.provider_name,
                        redact_secrets(body),
                    )

                return body

            except ProviderError:
                raise

            except (requests.ConnectionError, requests.Timeout) as e:
                last_exc = e
                if attempt > MAX_RETRIES:
                    raise ProviderError(
                        self.provider_name,
                        operation,
                        RuntimeError(
                            f"Network error after {MAX_RETRIES} retries: {e}. "
                            "Check your internet connection."
                        ),
                    ) from e
                wait = self._backoff_wait(None, attempt)
                logger.warning(
                    "[%s] Network error (attempt %d/%d): %s. Retrying in %.1fs",
                    self.provider_name,
                    attempt,
                    MAX_RETRIES,
                    e,
                    wait,
                )
                time.sleep(wait)

            except requests.RequestException as e:
                # HTTPError from raise_for_status() plus every other
                # non-retryable requests failure (redirect loops, invalid
                # URLs, broken chunked bodies, ...).
                raise ProviderError(self.provider_name, operation, e) from e

        # Should not reach here — last_exc is set if we exhausted network retries
        raise ProviderError(
            self.provider_name,
            operation,
            last_exc or RuntimeError("Unknown error"),
        )

    def _handle_401(self) -> None:
        """Log a clear human-readable message for a 401 and raise ProviderError."""
        # Subclasses override to include provider-specific cookie name
        msg = (
            f"[{self.provider_name}] Authentication failed (401 Unauthorized). "
            "Your session token has likely expired. "
            "Run 'python -m src.main auth' to refresh your token."
        )
        logger.error(msg)
        raise ProviderError(
            self.provider_name,
            "authentication",
            RuntimeError("401 Unauthorized — token expired"),
        )

    @staticmethod
    def _backoff_wait(response: requests.Response | None, attempt: int) -> float:
        """Compute how long to sleep before the next retry, in seconds.

        A numeric ``Retry-After`` header on *response* takes precedence and is
        capped at ``BACKOFF_MAX``. Otherwise the wait is exponential backoff
        (``BACKOFF_BASE ** attempt``, capped at ``BACKOFF_MAX``) plus up to
        30% random jitter.

        Args:
            response: The rate-limited response, or None when there is no
                response to read a ``Retry-After`` header from.
            attempt: 1-based number of the attempt that just failed.

        Returns:
            A finite, non-negative number of seconds.
        """
        if response is not None:
            retry_after = response.headers.get("Retry-After")
            if retry_after:
                try:
                    delay = float(retry_after)
                except (TypeError, ValueError):
                    # Not a number of seconds (e.g. the HTTP-date form);
                    # fall through to exponential backoff.
                    delay = None
                # Negative or non-finite values would make time.sleep() raise,
                # and an uncapped value lets the server stall us indefinitely.
                if delay is not None and math.isfinite(delay) and delay >= 0:
                    return min(delay, BACKOFF_MAX)
        base = min(BACKOFF_BASE ** attempt, BACKOFF_MAX)
        jitter = random.uniform(0, base * 0.3)
        return base + jitter

    def _warn_unexpected_schema(self, context: str, key: str) -> None:
        """Log a warning about a missing key in the API response."""
        logger.warning(
            "[%s] Unexpected API response shape in %s — missing key '%s'. "
            "The API may have changed. Run with --debug for full response details.",
            self.provider_name,
            context,
            key,
        )