feat: add provider base class

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
JesseMarkowitz
2026-02-27 22:56:07 -05:00
parent 6a32e127fd
commit 6073034789

345
src/providers/base.py Normal file
View File

@@ -0,0 +1,345 @@
"""Abstract base class for AI chat providers."""
import logging
import random
import time
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any
import requests
from src.utils import redact_secrets
logger = logging.getLogger(__name__)

# Default (connect, read) timeouts in seconds for outgoing requests
REQUEST_TIMEOUT = (10, 30)

# Retry policy: number of retries, plus the base and ceiling (in seconds)
# of the exponential backoff delay
MAX_RETRIES = 3
BACKOFF_BASE = 2.0
BACKOFF_MAX = 60.0

# Browser-like User-Agent (desktop Chrome on Linux) sent with every request
USER_AGENT = " ".join(
    (
        "Mozilla/5.0 (X11; Linux x86_64)",
        "AppleWebKit/537.36 (KHTML, like Gecko)",
        "Chrome/121.0.0.0",
        "Safari/537.36",
    )
)
class ProviderError(Exception):
    """Failure of a single provider operation, wrapping its underlying cause.

    The exception message has the form
    ``[<provider_name>] <operation> failed: <original>``.

    Attributes:
        provider_name: Name of the provider, e.g. "chatgpt" or "claude".
        operation: Name of the failed operation, e.g. "list_conversations"
            or "get_conversation".
        original: The underlying exception that triggered this error.
    """

    def __init__(self, provider_name: str, operation: str, original: Exception) -> None:
        message = f"[{provider_name}] {operation} failed: {original}"
        super().__init__(message)
        # Keep the structured pieces available to callers that need more
        # than the formatted message.
        self.original = original
        self.operation = operation
        self.provider_name = provider_name
class BaseProvider(ABC):
    """Abstract base for chat providers (ChatGPT, Claude, future FileProvider).

    Subclasses implement the three abstract methods. This base class supplies
    automatic pagination (``fetch_all_conversations``), an HTTP helper with
    retry and backoff (``_make_request``), and shared logging helpers.
    """

    #: Provider identifier used in normalized schema and file paths
    provider_name: str = ""

    def __init__(self, session: requests.Session | None = None) -> None:
        """Attach an HTTP session and install the default request headers.

        Args:
            session: Session to reuse. A new ``requests.Session`` is created
                when omitted. The default headers are written onto whichever
                session ends up being used, including one passed in.
        """
        # Explicit None check: never discard a caller-supplied session just
        # because the object happens to evaluate as falsy.
        self._session = requests.Session() if session is None else session
        self._session.headers.update(
            {
                "User-Agent": USER_AGENT,
                "Accept": "application/json",
                "Accept-Language": "en-US,en;q=0.9",
            }
        )

    # ------------------------------------------------------------------
    # Abstract interface — subclasses must implement these
    # ------------------------------------------------------------------

    @abstractmethod
    def list_conversations(self, offset: int = 0, limit: int = 100) -> list[dict]:
        """Return one page of conversations from the provider API."""

    @abstractmethod
    def get_conversation(self, conv_id: str) -> dict:
        """Return the full conversation detail for a single ID."""

    @abstractmethod
    def normalize_conversation(self, raw: dict) -> dict:
        """Transform provider-specific schema to the common normalized schema."""

    # ------------------------------------------------------------------
    # Concrete helpers
    # ------------------------------------------------------------------

    def fetch_all_conversations(self, since: datetime | None = None) -> list[dict]:
        """Fetch every conversation, handling pagination automatically.

        Pages are requested through ``list_conversations`` until an empty or
        short page (fewer items than the page size) is returned.

        Args:
            since: If provided, only return conversations updated at or after
                this datetime. Filtering is done client-side. Timezone-aware
                values are compared as instants; naive values are interpreted
                as UTC.

        Returns:
            List of raw conversation dicts (not yet normalized).

        Raises:
            ProviderError: If listing a page fails. Non-``ProviderError``
                exceptions from ``list_conversations`` are wrapped.
        """
        all_convs: list[dict] = []
        offset = 0
        limit = 100
        page = 0
        while True:
            page += 1
            logger.debug(
                "[%s] Fetching conversation list page %d (offset=%d, limit=%d)",
                self.provider_name,
                page,
                offset,
                limit,
            )
            try:
                batch = self.list_conversations(offset=offset, limit=limit)
            except ProviderError:
                raise
            except Exception as e:
                raise ProviderError(self.provider_name, "fetch_all_conversations", e) from e
            if not batch:
                break
            all_convs.extend(batch)
            logger.debug(
                "[%s] Page %d: got %d conversations (total so far: %d)",
                self.provider_name,
                page,
                len(batch),
                len(all_convs),
            )
            if len(batch) < limit:
                # Received fewer than requested — we've reached the end
                break
            offset += len(batch)
        logger.info("[%s] Fetched %d total conversations", self.provider_name, len(all_convs))

        if since is None:
            return all_convs

        # Client-side date filter
        filtered = self._filter_updated_since(all_convs, since)
        logger.info(
            "[%s] After --since filter: %d/%d conversations",
            self.provider_name,
            len(filtered),
            len(all_convs),
        )
        return filtered

    def _filter_updated_since(self, conversations: list[dict], since: datetime) -> list[dict]:
        """Keep conversations whose update timestamp is at or after *since*.

        The timestamp is read from ``updated_at``, falling back to
        ``update_time``. Conversations without a timestamp, or with one that
        cannot be parsed, are kept so a parsing problem never drops data.
        """
        # Imported here rather than at module level; resolved once per call
        # instead of once per conversation.
        from src.utils import _parse_dt

        cutoff = self._as_utc(since)
        kept: list[dict] = []
        for conv in conversations:
            updated_raw = conv.get("updated_at") or conv.get("update_time") or ""
            if not updated_raw:
                kept.append(conv)
                continue
            try:
                is_recent = self._as_utc(_parse_dt(updated_raw)) >= cutoff
            except Exception:
                # Best effort: include the conversation if the date is unusable.
                logger.debug(
                    "[%s] Could not parse update timestamp %r; keeping conversation",
                    self.provider_name,
                    updated_raw,
                )
                kept.append(conv)
                continue
            if is_recent:
                kept.append(conv)
        return kept

    @staticmethod
    def _as_utc(moment: datetime) -> datetime:
        """Return *moment* as an aware UTC datetime; naive values are taken as UTC."""
        from datetime import timezone

        if moment.utcoffset() is None:
            return moment.replace(tzinfo=timezone.utc)
        return moment.astimezone(timezone.utc)

    def _make_request(
        self,
        method: str,
        url: str,
        debug_log_body: bool = False,
        **kwargs: Any,
    ) -> dict | list:
        """Make an authenticated HTTP request with retry logic and unified error handling.

        Rate-limit (429), server (5xx) and network errors are retried up to
        ``MAX_RETRIES`` times with a backoff delay between attempts. A 401 is
        never retried.

        Args:
            method: HTTP method ("GET", "POST", etc.)
            url: Full URL.
            debug_log_body: If True, log redacted response body at DEBUG level.
            **kwargs: Passed to requests.Session.request(). ``timeout``
                defaults to ``REQUEST_TIMEOUT`` when not supplied.

        Returns:
            Parsed JSON response body.

        Raises:
            ProviderError: On 401, exhausted retries, a response body that is
                not valid JSON, or any other unrecoverable request error.
        """
        kwargs.setdefault("timeout", REQUEST_TIMEOUT)
        operation = f"{method} {url}"
        attempt = 0
        last_exc: Exception | None = None
        # One initial attempt plus MAX_RETRIES retries.
        while attempt <= MAX_RETRIES:
            attempt += 1
            start = time.monotonic()
            try:
                response = self._session.request(method, url, **kwargs)
                elapsed_ms = int((time.monotonic() - start) * 1000)
                logger.debug(
                    "[%s] %s %s -> %d %s (%dms)",
                    self.provider_name,
                    method,
                    url,
                    response.status_code,
                    response.reason,
                    elapsed_ms,
                )

                # ── 401: token expired / invalid ──────────────────────────
                if response.status_code == 401:
                    self._handle_401()
                    # Safety net in case an override of _handle_401 returns
                    # instead of raising.
                    raise ProviderError(
                        self.provider_name,
                        operation,
                        RuntimeError("401 Unauthorized"),
                    )

                # ── 429: rate limited ──────────────────────────────────────
                if response.status_code == 429:
                    if attempt > MAX_RETRIES:
                        raise ProviderError(
                            self.provider_name,
                            operation,
                            RuntimeError("429 Rate Limited — max retries exhausted"),
                        )
                    wait = self._backoff_wait(response, attempt)
                    logger.warning(
                        "[%s] Rate limited. Retry %d/%d in %.1fs",
                        self.provider_name,
                        attempt,
                        MAX_RETRIES,
                        wait,
                    )
                    time.sleep(wait)
                    continue

                # ── 5xx: server error ──────────────────────────────────────
                if response.status_code >= 500:
                    if attempt > MAX_RETRIES:
                        raise ProviderError(
                            self.provider_name,
                            operation,
                            RuntimeError(
                                f"HTTP {response.status_code} after {MAX_RETRIES} retries"
                            ),
                        )
                    wait = self._backoff_wait(None, attempt)
                    logger.warning(
                        "[%s] HTTP %d. Retry %d/%d in %.1fs",
                        self.provider_name,
                        response.status_code,
                        attempt,
                        MAX_RETRIES,
                        wait,
                    )
                    time.sleep(wait)
                    continue

                # ── Other HTTP errors ──────────────────────────────────────
                response.raise_for_status()

                # ── Success ────────────────────────────────────────────────
                try:
                    body = response.json()
                except ValueError as e:
                    # A successful status with a non-JSON body (for example an
                    # HTML page) must surface as ProviderError, not as a raw
                    # decoding error.
                    raise ProviderError(
                        self.provider_name,
                        operation,
                        RuntimeError(
                            f"HTTP {response.status_code} response was not valid JSON: {e}"
                        ),
                    ) from e
                if debug_log_body:
                    logger.debug(
                        "[%s] Response body (redacted): %s",
                        self.provider_name,
                        redact_secrets(body),
                    )
                return body
            except ProviderError:
                raise
            except (requests.ConnectionError, requests.Timeout) as e:
                last_exc = e
                if attempt > MAX_RETRIES:
                    raise ProviderError(
                        self.provider_name,
                        operation,
                        RuntimeError(
                            f"Network error after {MAX_RETRIES} retries: {e}. "
                            "Check your internet connection."
                        ),
                    ) from e
                wait = self._backoff_wait(None, attempt)
                logger.warning(
                    "[%s] Network error (attempt %d/%d): %s. Retrying in %.1fs",
                    self.provider_name,
                    attempt,
                    MAX_RETRIES,
                    e,
                    wait,
                )
                time.sleep(wait)
            except requests.HTTPError as e:
                raise ProviderError(self.provider_name, operation, e) from e
            except requests.RequestException as e:
                # Any other requests-level failure is unrecoverable here;
                # keep the documented contract of raising ProviderError.
                raise ProviderError(self.provider_name, operation, e) from e

        # Should not reach here — last_exc is set if we exhausted network retries
        raise ProviderError(
            self.provider_name,
            operation,
            last_exc or RuntimeError("Unknown error"),
        )

    def _handle_401(self) -> None:
        """Log a clear human-readable message for a 401 and raise ProviderError.

        Raises:
            ProviderError: Always.
        """
        # Subclasses override to include provider-specific cookie name
        msg = (
            f"[{self.provider_name}] Authentication failed (401 Unauthorized). "
            "Your session token has likely expired. "
            "Run 'python -m src.main auth' to refresh your token."
        )
        logger.error(msg)
        raise ProviderError(
            self.provider_name,
            "authentication",
            RuntimeError("401 Unauthorized — token expired"),
        )

    @staticmethod
    def _backoff_wait(response: requests.Response | None, attempt: int) -> float:
        """Compute how long to wait before the next retry, in seconds.

        A usable numeric ``Retry-After`` header on *response* takes
        precedence. Otherwise the delay is ``BACKOFF_BASE ** attempt`` capped
        at ``BACKOFF_MAX``, plus up to 30% random jitter.

        Args:
            response: Response whose ``Retry-After`` header should be
                honoured, or None to always use exponential backoff.
            attempt: 1-based number of the attempt that just failed.

        Returns:
            Non-negative, finite delay in seconds.
        """
        if response is not None:
            retry_after = response.headers.get("Retry-After")
            if retry_after:
                try:
                    seconds = float(retry_after)
                except ValueError:
                    # Not a number (e.g. the HTTP-date form of the header):
                    # fall back to exponential backoff.
                    seconds = None
                # Negative, NaN or infinite values would crash or hang
                # time.sleep(); treat them as unusable. NaN fails both
                # comparisons, so it is rejected here too.
                if seconds is not None and 0 <= seconds < float("inf"):
                    return seconds
        base = min(BACKOFF_BASE ** attempt, BACKOFF_MAX)
        jitter = random.uniform(0, base * 0.3)
        return base + jitter

    def _warn_unexpected_schema(self, context: str, key: str) -> None:
        """Log a warning about a missing key in the API response.

        Args:
            context: Where the unexpected shape was seen.
            key: The key that was expected but missing.
        """
        logger.warning(
            "[%s] Unexpected API response shape in %s — missing key '%s'. "
            "The API may have changed. Run with --debug for full response details.",
            self.provider_name,
            context,
            key,
        )