feat: add provider base class
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
345
src/providers/base.py
Normal file
345
src/providers/base.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""Abstract base class for AI chat providers."""
|
||||
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
|
||||
from src.utils import redact_secrets
|
||||
|
||||
# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)

# (connect, read) timeouts in seconds; used as the default ``timeout``
# for every outgoing request.
REQUEST_TIMEOUT: tuple[int, int] = (10, 30)

# Retry policy: number of retries allowed after the first attempt, and the
# parameters of the exponential backoff (delay = BACKOFF_BASE ** attempt,
# clamped to BACKOFF_MAX seconds before jitter is added).
MAX_RETRIES: int = 3
BACKOFF_BASE: float = 2.0
BACKOFF_MAX: float = 60.0

# Browser-like User-Agent header (desktop Chrome on Linux).
USER_AGENT: str = (
    "Mozilla/5.0 (X11; Linux x86_64) "
    "AppleWebKit/537.36 (KHTML, like Gecko) "
    "Chrome/121.0.0.0 Safari/537.36"
)
|
||||
|
||||
|
||||
class ProviderError(Exception):
    """Raised by provider methods when an operation fails.

    The exception message has the form
    ``[<provider_name>] <operation> failed: <original>``.

    Attributes:
        provider_name: e.g. "chatgpt" or "claude"
        operation: e.g. "list_conversations" or "get_conversation"
        original: The underlying exception that triggered this error.
    """

    def __init__(self, provider_name: str, operation: str, original: Exception) -> None:
        self.provider_name = provider_name
        self.operation = operation
        self.original = original
        super().__init__(f"[{provider_name}] {operation} failed: {original}")

    def __reduce__(self) -> tuple:
        """Rebuild the error from its three constructor arguments.

        Without this, pickling or copying re-creates the instance from
        ``self.args`` (the single formatted message), which does not match
        the three-argument ``__init__`` and raises ``TypeError``.
        """
        return (type(self), (self.provider_name, self.operation, self.original))
|
||||
|
||||
|
||||
class BaseProvider(ABC):
    """Abstract base for chat providers (ChatGPT, Claude, future FileProvider).

    Subclasses implement the three abstract methods; this class supplies
    pagination (``fetch_all_conversations``) and an HTTP helper with retry,
    backoff and unified error wrapping (``_make_request``).
    """

    #: Provider identifier used in normalized schema and file paths
    provider_name: str = ""

    def __init__(self, session: requests.Session | None = None) -> None:
        """Store the HTTP session and install the default request headers.

        Args:
            session: Session to use for all requests. A new
                ``requests.Session`` is created when omitted. The default
                headers are written onto the session that is used, so a
                caller-supplied session is modified in place.
        """
        self._session = session if session is not None else requests.Session()
        self._session.headers.update(
            {
                "User-Agent": USER_AGENT,
                "Accept": "application/json",
                "Accept-Language": "en-US,en;q=0.9",
            }
        )

    # ------------------------------------------------------------------
    # Abstract interface — subclasses must implement these
    # ------------------------------------------------------------------

    @abstractmethod
    def list_conversations(self, offset: int = 0, limit: int = 100) -> list[dict]:
        """Return one page of conversations from the provider API."""

    @abstractmethod
    def get_conversation(self, conv_id: str) -> dict:
        """Return the full conversation detail for a single ID."""

    @abstractmethod
    def normalize_conversation(self, raw: dict) -> dict:
        """Transform provider-specific schema to the common normalized schema."""

    # ------------------------------------------------------------------
    # Concrete helpers
    # ------------------------------------------------------------------

    def fetch_all_conversations(self, since: datetime | None = None) -> list[dict]:
        """Fetch every conversation, handling pagination automatically.

        Pages are requested through ``list_conversations`` until an empty
        page or a page shorter than the requested limit is returned.

        Args:
            since: If provided, only return conversations updated at or
                after this datetime. Filtering is done client-side, after
                all pages have been fetched. Aware datetimes are converted
                to UTC before comparison; naive ones are compared as-is.

        Returns:
            List of raw conversation dicts (not yet normalized).

        Raises:
            ProviderError: If ``list_conversations`` fails. Non-ProviderError
                exceptions are wrapped with operation
                ``"fetch_all_conversations"``.
        """
        all_convs: list[dict] = []
        offset = 0
        limit = 100
        page = 0

        while True:
            page += 1
            logger.debug(
                "[%s] Fetching conversation list page %d (offset=%d, limit=%d)",
                self.provider_name,
                page,
                offset,
                limit,
            )
            try:
                batch = self.list_conversations(offset=offset, limit=limit)
            except ProviderError:
                raise
            except Exception as e:
                raise ProviderError(self.provider_name, "fetch_all_conversations", e) from e

            if not batch:
                break

            all_convs.extend(batch)
            logger.debug(
                "[%s] Page %d: got %d conversations (total so far: %d)",
                self.provider_name,
                page,
                len(batch),
                len(all_convs),
            )

            if len(batch) < limit:
                # Received fewer than requested — we've reached the end
                break

            offset += len(batch)

        logger.info("[%s] Fetched %d total conversations", self.provider_name, len(all_convs))

        if since is None:
            return all_convs

        # Client-side date filter
        filtered = self._filter_updated_since(all_convs, since)
        logger.info(
            "[%s] After --since filter: %d/%d conversations",
            self.provider_name,
            len(filtered),
            len(all_convs),
        )
        return filtered

    @staticmethod
    def _to_naive_utc(dt: datetime) -> datetime:
        """Return *dt* as a naive datetime; aware values are converted to UTC first."""
        if dt.utcoffset() is None:
            # Naive (or offset-less) value: nothing to convert, compare as-is.
            return dt.replace(tzinfo=None)
        from datetime import timezone

        return dt.astimezone(timezone.utc).replace(tzinfo=None)

    def _filter_updated_since(self, conversations: list[dict], since: datetime) -> list[dict]:
        """Keep conversations whose update time is at or after *since*.

        Conversations with no ``updated_at``/``update_time`` value, or whose
        value cannot be parsed, are kept rather than dropped.
        """
        # Imported once here instead of once per conversation inside the loop.
        from src.utils import _parse_dt

        cutoff = self._to_naive_utc(since)
        kept: list[dict] = []
        for conv in conversations:
            updated_raw = conv.get("updated_at") or conv.get("update_time") or ""
            if not updated_raw:
                kept.append(conv)
                continue
            try:
                updated = self._to_naive_utc(_parse_dt(updated_raw))
            except Exception as exc:
                # Best effort: include the conversation if we can't parse the date.
                logger.debug(
                    "[%s] Could not parse update time %r (%s); keeping conversation",
                    self.provider_name,
                    updated_raw,
                    exc,
                )
                kept.append(conv)
                continue
            if updated >= cutoff:
                kept.append(conv)
        return kept

    def _make_request(
        self,
        method: str,
        url: str,
        debug_log_body: bool = False,
        **kwargs: Any,
    ) -> dict | list:
        """Make an authenticated HTTP request with retry logic and unified error handling.

        The request is attempted up to ``MAX_RETRIES + 1`` times. HTTP 429,
        HTTP 5xx, connection errors and timeouts are retried after a backoff
        delay; every other failure is raised immediately.

        Args:
            method: HTTP method ("GET", "POST", etc.)
            url: Full URL.
            debug_log_body: If True, log redacted response body at DEBUG level.
            **kwargs: Passed to requests.Session.request(). ``timeout``
                defaults to ``REQUEST_TIMEOUT`` when not supplied.

        Returns:
            Parsed JSON response body.

        Raises:
            ProviderError: On 401, exhausted retries, other HTTP or transport
                errors, or a response body that is not valid JSON.
        """
        kwargs.setdefault("timeout", REQUEST_TIMEOUT)

        operation = f"{method} {url}"
        attempt = 0
        last_exc: Exception | None = None

        while attempt <= MAX_RETRIES:
            attempt += 1
            start = time.monotonic()

            try:
                response = self._session.request(method, url, **kwargs)
                elapsed_ms = int((time.monotonic() - start) * 1000)

                logger.debug(
                    "[%s] %s %s → %d %s (%dms)",
                    self.provider_name,
                    method,
                    url,
                    response.status_code,
                    response.reason,
                    elapsed_ms,
                )

                # ── 401: token expired / invalid ──────────────────────────
                if response.status_code == 401:
                    self._handle_401()
                    # The base _handle_401 raises; this is a safety net in case
                    # an override returns normally.
                    raise ProviderError(
                        self.provider_name,
                        operation,
                        RuntimeError("401 Unauthorized"),
                    )

                # ── 429: rate limited ──────────────────────────────────────
                if response.status_code == 429:
                    if attempt > MAX_RETRIES:
                        raise ProviderError(
                            self.provider_name,
                            operation,
                            RuntimeError("429 Rate Limited — max retries exhausted"),
                        )
                    # Pass the response so a Retry-After header can be honoured.
                    wait = self._backoff_wait(response, attempt)
                    logger.warning(
                        "[%s] Rate limited. Retry %d/%d in %.1fs",
                        self.provider_name,
                        attempt,
                        MAX_RETRIES,
                        wait,
                    )
                    time.sleep(wait)
                    continue

                # ── 5xx: server error ──────────────────────────────────────
                if response.status_code >= 500:
                    if attempt > MAX_RETRIES:
                        raise ProviderError(
                            self.provider_name,
                            operation,
                            RuntimeError(
                                f"HTTP {response.status_code} after {MAX_RETRIES} retries"
                            ),
                        )
                    wait = self._backoff_wait(None, attempt)
                    logger.warning(
                        "[%s] HTTP %d. Retry %d/%d in %.1fs",
                        self.provider_name,
                        response.status_code,
                        attempt,
                        MAX_RETRIES,
                        wait,
                    )
                    time.sleep(wait)
                    continue

                # ── Other HTTP errors ──────────────────────────────────────
                response.raise_for_status()

                # ── Success ────────────────────────────────────────────────
                try:
                    body = response.json()
                except ValueError as e:
                    # Empty or non-JSON body: report it as a provider failure
                    # instead of leaking a raw decoding error to the caller.
                    raise ProviderError(
                        self.provider_name,
                        operation,
                        RuntimeError(
                            f"Invalid JSON in HTTP {response.status_code} response: {e}"
                        ),
                    ) from e

                if debug_log_body:
                    logger.debug(
                        "[%s] Response body (redacted): %s",
                        self.provider_name,
                        redact_secrets(body),
                    )

                return body

            except ProviderError:
                raise

            except (requests.ConnectionError, requests.Timeout) as e:
                last_exc = e
                if attempt > MAX_RETRIES:
                    raise ProviderError(
                        self.provider_name,
                        operation,
                        RuntimeError(
                            f"Network error after {MAX_RETRIES} retries: {e}. "
                            "Check your internet connection."
                        ),
                    ) from e
                wait = self._backoff_wait(None, attempt)
                logger.warning(
                    "[%s] Network error (attempt %d/%d): %s. Retrying in %.1fs",
                    self.provider_name,
                    attempt,
                    MAX_RETRIES,
                    e,
                    wait,
                )
                time.sleep(wait)

            except requests.RequestException as e:
                # HTTPError from raise_for_status() and any other non-retryable
                # requests failure (e.g. too many redirects, invalid URL).
                raise ProviderError(self.provider_name, operation, e) from e

        # Should not reach here — last_exc is set if we exhausted network retries
        raise ProviderError(
            self.provider_name,
            operation,
            last_exc or RuntimeError("Unknown error"),
        )

    def _handle_401(self) -> None:
        """Log a clear human-readable message for a 401 and raise ProviderError.

        Raises:
            ProviderError: Always, with operation ``"authentication"``.
        """
        # Subclasses override to include provider-specific cookie name
        msg = (
            f"[{self.provider_name}] Authentication failed (401 Unauthorized). "
            "Your session token has likely expired. "
            "Run 'python -m src.main auth' to refresh your token."
        )
        logger.error(msg)
        raise ProviderError(
            self.provider_name,
            "authentication",
            RuntimeError("401 Unauthorized — token expired"),
        )

    @staticmethod
    def _backoff_wait(response: requests.Response | None, attempt: int) -> float:
        """Compute how many seconds to wait before the next retry.

        Args:
            response: Response whose ``Retry-After`` header should be
                honoured, or None to always use exponential backoff.
            attempt: 1-based number of the attempt that just failed.

        Returns:
            The ``Retry-After`` value when it is a finite, non-negative
            number of seconds; otherwise ``BACKOFF_BASE ** attempt`` capped at
            ``BACKOFF_MAX``, plus up to 30% random jitter.
        """
        if response is not None:
            retry_after = response.headers.get("Retry-After")
            if retry_after:
                try:
                    seconds = float(retry_after)
                except ValueError:
                    # Not a plain number (e.g. an HTTP-date): use backoff.
                    seconds = None
                # Reject negative, NaN and infinite values: time.sleep() would
                # raise or never return on them.
                if seconds is not None and 0.0 <= seconds < float("inf"):
                    return seconds
        base = min(BACKOFF_BASE ** attempt, BACKOFF_MAX)
        jitter = random.uniform(0, base * 0.3)
        return base + jitter

    def _warn_unexpected_schema(self, context: str, key: str) -> None:
        """Log a warning about a missing key in the API response.

        Args:
            context: Where the key was expected (included in the message).
            key: Name of the missing key.
        """
        logger.warning(
            "[%s] Unexpected API response shape in %s — missing key '%s'. "
            "The API may have changed. Run with --debug for full response details.",
            self.provider_name,
            context,
            key,
        )
|
||||
Reference in New Issue
Block a user