feat: add provider base class
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
345
src/providers/base.py
Normal file
345
src/providers/base.py
Normal file
@@ -0,0 +1,345 @@
|
|||||||
|
"""Abstract base class for AI chat providers."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from src.utils import redact_secrets
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Request timeouts (connect, read) in seconds
|
||||||
|
REQUEST_TIMEOUT = (10, 30)
|
||||||
|
|
||||||
|
# Retry configuration
|
||||||
|
MAX_RETRIES = 3
|
||||||
|
BACKOFF_BASE = 2.0
|
||||||
|
BACKOFF_MAX = 60.0
|
||||||
|
|
||||||
|
# Realistic Chrome User-Agent
|
||||||
|
USER_AGENT = (
|
||||||
|
"Mozilla/5.0 (X11; Linux x86_64) "
|
||||||
|
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||||
|
"Chrome/121.0.0.0 Safari/537.36"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderError(Exception):
|
||||||
|
"""Raised by provider methods when an operation fails.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
provider_name: e.g. "chatgpt" or "claude"
|
||||||
|
operation: e.g. "list_conversations" or "get_conversation"
|
||||||
|
original: The underlying exception that triggered this error.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, provider_name: str, operation: str, original: Exception) -> None:
|
||||||
|
self.provider_name = provider_name
|
||||||
|
self.operation = operation
|
||||||
|
self.original = original
|
||||||
|
super().__init__(f"[{provider_name}] {operation} failed: {original}")
|
||||||
|
|
||||||
|
|
||||||
|
class BaseProvider(ABC):
|
||||||
|
"""Abstract base for chat providers (ChatGPT, Claude, future FileProvider)."""
|
||||||
|
|
||||||
|
#: Provider identifier used in normalized schema and file paths
|
||||||
|
provider_name: str = ""
|
||||||
|
|
||||||
|
def __init__(self, session: requests.Session | None = None) -> None:
|
||||||
|
self._session = session or requests.Session()
|
||||||
|
self._session.headers.update(
|
||||||
|
{
|
||||||
|
"User-Agent": USER_AGENT,
|
||||||
|
"Accept": "application/json",
|
||||||
|
"Accept-Language": "en-US,en;q=0.9",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Abstract interface — subclasses must implement these
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def list_conversations(self, offset: int = 0, limit: int = 100) -> list[dict]:
|
||||||
|
"""Return one page of conversations from the provider API."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_conversation(self, conv_id: str) -> dict:
|
||||||
|
"""Return the full conversation detail for a single ID."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def normalize_conversation(self, raw: dict) -> dict:
|
||||||
|
"""Transform provider-specific schema to the common normalized schema."""
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Concrete helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def fetch_all_conversations(self, since: datetime | None = None) -> list[dict]:
|
||||||
|
"""Fetch every conversation, handling pagination automatically.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
since: If provided, only return conversations updated after this
|
||||||
|
datetime. Filtering is done client-side.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of raw conversation dicts (not yet normalized).
|
||||||
|
"""
|
||||||
|
all_convs: list[dict] = []
|
||||||
|
offset = 0
|
||||||
|
limit = 100
|
||||||
|
page = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
page += 1
|
||||||
|
logger.debug(
|
||||||
|
"[%s] Fetching conversation list page %d (offset=%d, limit=%d)",
|
||||||
|
self.provider_name,
|
||||||
|
page,
|
||||||
|
offset,
|
||||||
|
limit,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
batch = self.list_conversations(offset=offset, limit=limit)
|
||||||
|
except ProviderError:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise ProviderError(self.provider_name, "fetch_all_conversations", e) from e
|
||||||
|
|
||||||
|
if not batch:
|
||||||
|
break
|
||||||
|
|
||||||
|
all_convs.extend(batch)
|
||||||
|
logger.debug(
|
||||||
|
"[%s] Page %d: got %d conversations (total so far: %d)",
|
||||||
|
self.provider_name,
|
||||||
|
page,
|
||||||
|
len(batch),
|
||||||
|
len(all_convs),
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(batch) < limit:
|
||||||
|
# Received fewer than requested — we've reached the end
|
||||||
|
break
|
||||||
|
|
||||||
|
offset += len(batch)
|
||||||
|
|
||||||
|
logger.info("[%s] Fetched %d total conversations", self.provider_name, len(all_convs))
|
||||||
|
|
||||||
|
# Client-side date filter
|
||||||
|
if since is not None:
|
||||||
|
since_aware = since if since.tzinfo else since.replace(tzinfo=None)
|
||||||
|
filtered = []
|
||||||
|
for c in all_convs:
|
||||||
|
updated_raw = c.get("updated_at") or c.get("update_time") or ""
|
||||||
|
if updated_raw:
|
||||||
|
try:
|
||||||
|
from src.utils import _parse_dt
|
||||||
|
updated = _parse_dt(updated_raw)
|
||||||
|
if updated.replace(tzinfo=None) >= since_aware.replace(tzinfo=None):
|
||||||
|
filtered.append(c)
|
||||||
|
except Exception:
|
||||||
|
filtered.append(c) # include if we can't parse the date
|
||||||
|
else:
|
||||||
|
filtered.append(c)
|
||||||
|
logger.info(
|
||||||
|
"[%s] After --since filter: %d/%d conversations",
|
||||||
|
self.provider_name,
|
||||||
|
len(filtered),
|
||||||
|
len(all_convs),
|
||||||
|
)
|
||||||
|
return filtered
|
||||||
|
|
||||||
|
return all_convs
|
||||||
|
|
||||||
|
def _make_request(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
url: str,
|
||||||
|
debug_log_body: bool = False,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> dict | list:
|
||||||
|
"""Make an authenticated HTTP request with retry logic and unified error handling.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
method: HTTP method ("GET", "POST", etc.)
|
||||||
|
url: Full URL.
|
||||||
|
debug_log_body: If True, log redacted response body at DEBUG level.
|
||||||
|
**kwargs: Passed to requests.Session.request().
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Parsed JSON response body.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ProviderError: On 401, exhausted retries, or unrecoverable errors.
|
||||||
|
"""
|
||||||
|
kwargs.setdefault("timeout", REQUEST_TIMEOUT)
|
||||||
|
|
||||||
|
attempt = 0
|
||||||
|
last_exc: Exception | None = None
|
||||||
|
|
||||||
|
while attempt <= MAX_RETRIES:
|
||||||
|
attempt += 1
|
||||||
|
start = time.monotonic()
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = self._session.request(method, url, **kwargs)
|
||||||
|
elapsed_ms = int((time.monotonic() - start) * 1000)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
"[%s] %s %s → %d %s (%dms)",
|
||||||
|
self.provider_name,
|
||||||
|
method,
|
||||||
|
url,
|
||||||
|
response.status_code,
|
||||||
|
response.reason,
|
||||||
|
elapsed_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── 401: token expired / invalid ──────────────────────────
|
||||||
|
if response.status_code == 401:
|
||||||
|
self._handle_401()
|
||||||
|
# _handle_401 raises ProviderError — this line never runs
|
||||||
|
raise ProviderError(
|
||||||
|
self.provider_name,
|
||||||
|
f"{method} {url}",
|
||||||
|
RuntimeError("401 Unauthorized"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── 429: rate limited ──────────────────────────────────────
|
||||||
|
if response.status_code == 429:
|
||||||
|
if attempt > MAX_RETRIES:
|
||||||
|
raise ProviderError(
|
||||||
|
self.provider_name,
|
||||||
|
f"{method} {url}",
|
||||||
|
RuntimeError("429 Rate Limited — max retries exhausted"),
|
||||||
|
)
|
||||||
|
wait = self._backoff_wait(response, attempt)
|
||||||
|
logger.warning(
|
||||||
|
"[%s] Rate limited. Retry %d/%d in %.1fs",
|
||||||
|
self.provider_name,
|
||||||
|
attempt,
|
||||||
|
MAX_RETRIES,
|
||||||
|
wait,
|
||||||
|
)
|
||||||
|
time.sleep(wait)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# ── 5xx: server error ──────────────────────────────────────
|
||||||
|
if response.status_code >= 500:
|
||||||
|
if attempt > MAX_RETRIES:
|
||||||
|
raise ProviderError(
|
||||||
|
self.provider_name,
|
||||||
|
f"{method} {url}",
|
||||||
|
RuntimeError(
|
||||||
|
f"HTTP {response.status_code} after {MAX_RETRIES} retries"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
wait = self._backoff_wait(None, attempt)
|
||||||
|
logger.warning(
|
||||||
|
"[%s] HTTP %d. Retry %d/%d in %.1fs",
|
||||||
|
self.provider_name,
|
||||||
|
response.status_code,
|
||||||
|
attempt,
|
||||||
|
MAX_RETRIES,
|
||||||
|
wait,
|
||||||
|
)
|
||||||
|
time.sleep(wait)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# ── Other HTTP errors ──────────────────────────────────────
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
# ── Success ────────────────────────────────────────────────
|
||||||
|
body = response.json()
|
||||||
|
|
||||||
|
if debug_log_body:
|
||||||
|
logger.debug(
|
||||||
|
"[%s] Response body (redacted): %s",
|
||||||
|
self.provider_name,
|
||||||
|
redact_secrets(body),
|
||||||
|
)
|
||||||
|
|
||||||
|
return body
|
||||||
|
|
||||||
|
except ProviderError:
|
||||||
|
raise
|
||||||
|
|
||||||
|
except (requests.ConnectionError, requests.Timeout) as e:
|
||||||
|
last_exc = e
|
||||||
|
if attempt > MAX_RETRIES:
|
||||||
|
raise ProviderError(
|
||||||
|
self.provider_name,
|
||||||
|
f"{method} {url}",
|
||||||
|
RuntimeError(
|
||||||
|
f"Network error after {MAX_RETRIES} retries: {e}. "
|
||||||
|
"Check your internet connection."
|
||||||
|
),
|
||||||
|
) from e
|
||||||
|
wait = self._backoff_wait(None, attempt)
|
||||||
|
logger.warning(
|
||||||
|
"[%s] Network error (attempt %d/%d): %s. Retrying in %.1fs",
|
||||||
|
self.provider_name,
|
||||||
|
attempt,
|
||||||
|
MAX_RETRIES,
|
||||||
|
e,
|
||||||
|
wait,
|
||||||
|
)
|
||||||
|
time.sleep(wait)
|
||||||
|
|
||||||
|
except requests.HTTPError as e:
|
||||||
|
raise ProviderError(
|
||||||
|
self.provider_name, f"{method} {url}", e
|
||||||
|
) from e
|
||||||
|
|
||||||
|
# Should not reach here — last_exc is set if we exhausted network retries
|
||||||
|
raise ProviderError(
|
||||||
|
self.provider_name,
|
||||||
|
f"{method} {url}",
|
||||||
|
last_exc or RuntimeError("Unknown error"),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _handle_401(self) -> None:
|
||||||
|
"""Log a clear human-readable message for a 401 and raise ProviderError."""
|
||||||
|
# Subclasses override to include provider-specific cookie name
|
||||||
|
msg = (
|
||||||
|
f"[{self.provider_name}] Authentication failed (401 Unauthorized). "
|
||||||
|
"Your session token has likely expired. "
|
||||||
|
"Run 'python -m src.main auth' to refresh your token."
|
||||||
|
)
|
||||||
|
logger.error(msg)
|
||||||
|
raise ProviderError(
|
||||||
|
self.provider_name,
|
||||||
|
"authentication",
|
||||||
|
RuntimeError("401 Unauthorized — token expired"),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _backoff_wait(response: requests.Response | None, attempt: int) -> float:
|
||||||
|
"""Compute exponential backoff with jitter."""
|
||||||
|
if response is not None:
|
||||||
|
retry_after = response.headers.get("Retry-After")
|
||||||
|
if retry_after:
|
||||||
|
try:
|
||||||
|
return float(retry_after)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
base = min(BACKOFF_BASE ** attempt, BACKOFF_MAX)
|
||||||
|
jitter = random.uniform(0, base * 0.3)
|
||||||
|
return base + jitter
|
||||||
|
|
||||||
|
def _warn_unexpected_schema(self, context: str, key: str) -> None:
|
||||||
|
"""Log a warning about a missing key in the API response."""
|
||||||
|
logger.warning(
|
||||||
|
"[%s] Unexpected API response shape in %s — missing key '%s'. "
|
||||||
|
"The API may have changed. Run with --debug for full response details.",
|
||||||
|
self.provider_name,
|
||||||
|
context,
|
||||||
|
key,
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user