diff --git a/README.md b/README.md index de8f20f..62c972e 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,31 @@ Agent 配置路径优先级 - 如果要引用环境变量,请在 YAML 显式写 `${ENV_VAR}`。 - `siliconflow` 独立 section 已移除;请在 `agent.llm / agent.tts / agent.asr` 内通过 `provider`、`api_key`、`api_url`、`model` 配置。 +## Backend Integration + +Engine runtime now supports adapter-based backend integration: + +- `BACKEND_MODE=auto|http|disabled` +- `BACKEND_URL` + `BACKEND_TIMEOUT_SEC` +- `HISTORY_ENABLED=true|false` + +Behavior: + +- `auto`: use HTTP backend only when `BACKEND_URL` is set, otherwise engine-only mode. +- `http`: force HTTP backend; falls back to engine-only mode when URL is missing. +- `disabled`: force engine-only mode (no backend calls). + +History write path is now asynchronous and buffered per session: + +- `HISTORY_QUEUE_MAX_SIZE` +- `HISTORY_RETRY_MAX_ATTEMPTS` +- `HISTORY_RETRY_BACKOFF_SEC` +- `HISTORY_FINALIZE_DRAIN_TIMEOUT_SEC` + +This keeps turn processing responsive even when backend history APIs are slow/failing. + +Detailed notes: `docs/backend_integration.md`. + 测试 ``` diff --git a/app/backend_adapters.py b/app/backend_adapters.py new file mode 100644 index 0000000..a05bd8f --- /dev/null +++ b/app/backend_adapters.py @@ -0,0 +1,357 @@ +"""Backend adapter implementations for engine integration ports.""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional + +import aiohttp +from loguru import logger + +from app.config import settings + + +class NullBackendAdapter: + """No-op adapter for engine-only runtime without backend dependencies.""" + + async def fetch_assistant_config(self, assistant_id: str) -> Optional[Dict[str, Any]]: + _ = assistant_id + return None + + async def create_call_record( + self, + *, + user_id: int, + assistant_id: Optional[str], + source: str = "debug", + ) -> Optional[str]: + _ = (user_id, assistant_id, source) + return None + + async def add_transcript( + self, + *, + call_id: str, + turn_index: int, + speaker: str, + content: str, + start_ms: int, + end_ms: int, + confidence: Optional[float] = None, + duration_ms: Optional[int] = None, + ) -> bool: + _ = (call_id, turn_index, speaker, content, start_ms, end_ms, confidence, duration_ms) + return False + + async def finalize_call_record( + self, + *, + call_id: str, + status: str, + duration_seconds: int, + ) -> bool: + _ = (call_id, status, duration_seconds) + return False + + async def search_knowledge_context( + self, + *, + kb_id: str, + query: str, + n_results: int = 5, + ) -> List[Dict[str, Any]]: + _ = (kb_id, query, n_results) + return [] + + async def fetch_tool_resource(self, tool_id: str) -> Optional[Dict[str, Any]]: + _ = tool_id + return None + + +class HistoryDisabledBackendAdapter: + """Adapter wrapper that disables history writes while keeping reads available.""" + + def __init__(self, delegate: HttpBackendAdapter | NullBackendAdapter): + self._delegate = delegate + + async def fetch_assistant_config(self, assistant_id: str) -> Optional[Dict[str, Any]]: + return await self._delegate.fetch_assistant_config(assistant_id) + + async def create_call_record( + self, + *, + user_id: int, + assistant_id: Optional[str], + source: str = "debug", + ) -> Optional[str]: + _ = (user_id, assistant_id, source) + return None + + async def add_transcript( + self, + *, + call_id: str, + turn_index: int, + speaker: str, + content: str, + start_ms: int, + end_ms: int, + confidence: Optional[float] = None, + duration_ms: Optional[int] = None, + ) -> bool: + _ = (call_id, turn_index, speaker, content, start_ms, end_ms, confidence, duration_ms) + return False + + async def finalize_call_record( + self, + *, + call_id: str, + status: str, + duration_seconds: int, + ) -> bool: + _ = (call_id, status, duration_seconds) + return False + + async def search_knowledge_context( + self, + *, + kb_id: str, + query: str, + n_results: int = 5, + ) -> List[Dict[str, Any]]: + return await self._delegate.search_knowledge_context( + kb_id=kb_id, + query=query, + n_results=n_results, + ) + + async def fetch_tool_resource(self, tool_id: str) -> Optional[Dict[str, Any]]: + return await self._delegate.fetch_tool_resource(tool_id) + + +class HttpBackendAdapter: + """HTTP implementation of backend integration ports.""" + + def __init__(self, backend_url: str, timeout_sec: int = 10): + base_url = str(backend_url or "").strip().rstrip("/") + if not base_url: + raise ValueError("backend_url is required for HttpBackendAdapter") + self._base_url = base_url + self._timeout_sec = timeout_sec + + def _timeout(self) -> aiohttp.ClientTimeout: + return aiohttp.ClientTimeout(total=self._timeout_sec) + + async def fetch_assistant_config(self, assistant_id: str) -> Optional[Dict[str, Any]]: + """Fetch assistant config payload from backend API. + + Expected response shape: + { + "assistant": {...}, + "voice": {...} | null + } + """ + url = f"{self._base_url}/api/assistants/{assistant_id}/config" + + try: + async with aiohttp.ClientSession(timeout=self._timeout()) as session: + async with session.get(url) as resp: + if resp.status == 404: + logger.warning(f"Assistant config not found: {assistant_id}") + return None + resp.raise_for_status() + payload = await resp.json() + if not isinstance(payload, dict): + logger.warning("Assistant config payload is not a dict; ignoring") + return None + return payload + except Exception as exc: + logger.warning(f"Failed to fetch assistant config ({assistant_id}): {exc}") + return None + + async def create_call_record( + self, + *, + user_id: int, + assistant_id: Optional[str], + source: str = "debug", + ) -> Optional[str]: + """Create a call record via backend history API and return call_id.""" + url = f"{self._base_url}/api/history" + payload: Dict[str, Any] = { + "user_id": user_id, + "assistant_id": assistant_id, + "source": source, + "status": "connected", + } + + try: + async with aiohttp.ClientSession(timeout=self._timeout()) as session: + async with session.post(url, json=payload) as resp: + resp.raise_for_status() + data = await resp.json() + call_id = str((data or {}).get("id") or "") + return call_id or None + except Exception as exc: + logger.warning(f"Failed to create history call record: {exc}") + return None + + async def add_transcript( + self, + *, + call_id: str, + turn_index: int, + speaker: str, + content: str, + start_ms: int, + end_ms: int, + confidence: Optional[float] = None, + duration_ms: Optional[int] = None, + ) -> bool: + """Append a transcript segment to backend history.""" + if not call_id: + return False + + url = f"{self._base_url}/api/history/{call_id}/transcripts" + payload: Dict[str, Any] = { + "turn_index": turn_index, + "speaker": speaker, + "content": content, + "confidence": confidence, + "start_ms": start_ms, + "end_ms": end_ms, + "duration_ms": duration_ms, + } + + try: + async with aiohttp.ClientSession(timeout=self._timeout()) as session: + async with session.post(url, json=payload) as resp: + resp.raise_for_status() + return True + except Exception as exc: + logger.warning(f"Failed to append history transcript (call_id={call_id}, turn={turn_index}): {exc}") + return False + + async def finalize_call_record( + self, + *, + call_id: str, + status: str, + duration_seconds: int, + ) -> bool: + """Finalize a call record with status and duration.""" + if not call_id: + return False + + url = f"{self._base_url}/api/history/{call_id}" + payload: Dict[str, Any] = { + "status": status, + "duration_seconds": duration_seconds, + } + + try: + async with aiohttp.ClientSession(timeout=self._timeout()) as session: + async with session.put(url, json=payload) as resp: + resp.raise_for_status() + return True + except Exception as exc: + logger.warning(f"Failed to finalize history call record ({call_id}): {exc}") + return False + + async def search_knowledge_context( + self, + *, + kb_id: str, + query: str, + n_results: int = 5, + ) -> List[Dict[str, Any]]: + """Search backend knowledge base and return retrieval results.""" + if not kb_id or not query.strip(): + return [] + try: + safe_n_results = max(1, int(n_results)) + except (TypeError, ValueError): + safe_n_results = 5 + + url = f"{self._base_url}/api/knowledge/search" + payload: Dict[str, Any] = { + "kb_id": kb_id, + "query": query, + "nResults": safe_n_results, + } + + try: + async with aiohttp.ClientSession(timeout=self._timeout()) as session: + async with session.post(url, json=payload) as resp: + if resp.status == 404: + logger.warning(f"Knowledge base not found for retrieval: {kb_id}") + return [] + resp.raise_for_status() + data = await resp.json() + if not isinstance(data, dict): + return [] + results = data.get("results", []) + if not isinstance(results, list): + return [] + return [r for r in results if isinstance(r, dict)] + except Exception as exc: + logger.warning(f"Knowledge search failed (kb_id={kb_id}): {exc}") + return [] + + async def fetch_tool_resource(self, tool_id: str) -> Optional[Dict[str, Any]]: + """Fetch tool resource configuration from backend API.""" + if not tool_id: + return None + + url = f"{self._base_url}/api/tools/resources/{tool_id}" + try: + async with aiohttp.ClientSession(timeout=self._timeout()) as session: + async with session.get(url) as resp: + if resp.status == 404: + return None + resp.raise_for_status() + data = await resp.json() + return data if isinstance(data, dict) else None + except Exception as exc: + logger.warning(f"Failed to fetch tool resource ({tool_id}): {exc}") + return None + + +def build_backend_adapter( + *, + backend_url: Optional[str], + backend_mode: str = "auto", + history_enabled: bool = True, + timeout_sec: int = 10, +) -> HttpBackendAdapter | NullBackendAdapter | HistoryDisabledBackendAdapter: + """Create backend adapter implementation based on runtime settings.""" + mode = str(backend_mode or "auto").strip().lower() + has_url = bool(str(backend_url or "").strip()) + + base_adapter: HttpBackendAdapter | NullBackendAdapter + if mode in {"disabled", "off", "none", "null", "engine_only", "engine-only"}: + base_adapter = NullBackendAdapter() + elif mode == "http": + if has_url: + base_adapter = HttpBackendAdapter(backend_url=str(backend_url), timeout_sec=timeout_sec) + else: + logger.warning("BACKEND_MODE=http but BACKEND_URL is empty; falling back to NullBackendAdapter") + base_adapter = NullBackendAdapter() + else: + if has_url: + base_adapter = HttpBackendAdapter(backend_url=str(backend_url), timeout_sec=timeout_sec) + else: + base_adapter = NullBackendAdapter() + + if not history_enabled: + return HistoryDisabledBackendAdapter(base_adapter) + return base_adapter + + +def build_backend_adapter_from_settings() -> HttpBackendAdapter | NullBackendAdapter | HistoryDisabledBackendAdapter: + """Create backend adapter using current app settings.""" + return build_backend_adapter( + backend_url=settings.backend_url, + backend_mode=settings.backend_mode, + history_enabled=settings.history_enabled, + timeout_sec=settings.backend_timeout_sec, + ) diff --git a/app/backend_client.py b/app/backend_client.py index b750564..93ea183 100644 --- a/app/backend_client.py +++ b/app/backend_client.py @@ -1,56 +1,19 @@ -"""Backend API client for assistant config and history persistence.""" +"""Compatibility wrappers around backend adapter implementations.""" from __future__ import annotations from typing import Any, Dict, List, Optional -import aiohttp -from loguru import logger +from app.backend_adapters import build_backend_adapter_from_settings -from app.config import settings + +def _adapter(): + return build_backend_adapter_from_settings() async def fetch_assistant_config(assistant_id: str) -> Optional[Dict[str, Any]]: - """Fetch assistant config payload from backend API. - - Expected response shape: - { - "assistant": {...}, - "voice": {...} | null - } - """ - if not settings.backend_url: - logger.warning("BACKEND_URL not set; skipping assistant config fetch") - return None - - url = f"{settings.backend_url.rstrip('/')}/api/assistants/{assistant_id}/config" - timeout = aiohttp.ClientTimeout(total=settings.backend_timeout_sec) - - try: - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.get(url) as resp: - if resp.status == 404: - logger.warning(f"Assistant config not found: {assistant_id}") - return None - resp.raise_for_status() - payload = await resp.json() - if not isinstance(payload, dict): - logger.warning("Assistant config payload is not a dict; ignoring") - return None - return payload - except Exception as exc: - logger.warning(f"Failed to fetch assistant config ({assistant_id}): {exc}") - return None - - -def _backend_base_url() -> Optional[str]: - if not settings.backend_url: - return None - return settings.backend_url.rstrip("/") - - -def _timeout() -> aiohttp.ClientTimeout: - return aiohttp.ClientTimeout(total=settings.backend_timeout_sec) + """Fetch assistant config payload from backend adapter.""" + return await _adapter().fetch_assistant_config(assistant_id) async def create_history_call_record( @@ -60,28 +23,11 @@ async def create_history_call_record( source: str = "debug", ) -> Optional[str]: """Create a call record via backend history API and return call_id.""" - base_url = _backend_base_url() - if not base_url: - return None - - url = f"{base_url}/api/history" - payload: Dict[str, Any] = { - "user_id": user_id, - "assistant_id": assistant_id, - "source": source, - "status": "connected", - } - - try: - async with aiohttp.ClientSession(timeout=_timeout()) as session: - async with session.post(url, json=payload) as resp: - resp.raise_for_status() - data = await resp.json() - call_id = str((data or {}).get("id") or "") - return call_id or None - except Exception as exc: - logger.warning(f"Failed to create history call record: {exc}") - return None + return await _adapter().create_call_record( + user_id=user_id, + assistant_id=assistant_id, + source=source, + ) async def add_history_transcript( @@ -96,29 +42,16 @@ async def add_history_transcript( duration_ms: Optional[int] = None, ) -> bool: """Append a transcript segment to backend history.""" - base_url = _backend_base_url() - if not base_url or not call_id: - return False - - url = f"{base_url}/api/history/{call_id}/transcripts" - payload: Dict[str, Any] = { - "turn_index": turn_index, - "speaker": speaker, - "content": content, - "confidence": confidence, - "start_ms": start_ms, - "end_ms": end_ms, - "duration_ms": duration_ms, - } - - try: - async with aiohttp.ClientSession(timeout=_timeout()) as session: - async with session.post(url, json=payload) as resp: - resp.raise_for_status() - return True - except Exception as exc: - logger.warning(f"Failed to append history transcript (call_id={call_id}, turn={turn_index}): {exc}") - return False + return await _adapter().add_transcript( + call_id=call_id, + turn_index=turn_index, + speaker=speaker, + content=content, + start_ms=start_ms, + end_ms=end_ms, + confidence=confidence, + duration_ms=duration_ms, + ) async def finalize_history_call_record( @@ -128,24 +61,11 @@ async def finalize_history_call_record( duration_seconds: int, ) -> bool: """Finalize a call record with status and duration.""" - base_url = _backend_base_url() - if not base_url or not call_id: - return False - - url = f"{base_url}/api/history/{call_id}" - payload: Dict[str, Any] = { - "status": status, - "duration_seconds": duration_seconds, - } - - try: - async with aiohttp.ClientSession(timeout=_timeout()) as session: - async with session.put(url, json=payload) as resp: - resp.raise_for_status() - return True - except Exception as exc: - logger.warning(f"Failed to finalize history call record ({call_id}): {exc}") - return False + return await _adapter().finalize_call_record( + call_id=call_id, + status=status, + duration_seconds=duration_seconds, + ) async def search_knowledge_context( @@ -155,57 +75,13 @@ async def search_knowledge_context( n_results: int = 5, ) -> List[Dict[str, Any]]: """Search backend knowledge base and return retrieval results.""" - base_url = _backend_base_url() - if not base_url: - return [] - if not kb_id or not query.strip(): - return [] - try: - safe_n_results = max(1, int(n_results)) - except (TypeError, ValueError): - safe_n_results = 5 - - url = f"{base_url}/api/knowledge/search" - payload: Dict[str, Any] = { - "kb_id": kb_id, - "query": query, - "nResults": safe_n_results, - } - - try: - async with aiohttp.ClientSession(timeout=_timeout()) as session: - async with session.post(url, json=payload) as resp: - if resp.status == 404: - logger.warning(f"Knowledge base not found for retrieval: {kb_id}") - return [] - resp.raise_for_status() - data = await resp.json() - if not isinstance(data, dict): - return [] - results = data.get("results", []) - if not isinstance(results, list): - return [] - return [r for r in results if isinstance(r, dict)] - except Exception as exc: - logger.warning(f"Knowledge search failed (kb_id={kb_id}): {exc}") - return [] + return await _adapter().search_knowledge_context( + kb_id=kb_id, + query=query, + n_results=n_results, + ) async def fetch_tool_resource(tool_id: str) -> Optional[Dict[str, Any]]: """Fetch tool resource configuration from backend API.""" - base_url = _backend_base_url() - if not base_url or not tool_id: - return None - - url = f"{base_url}/api/tools/resources/{tool_id}" - try: - async with aiohttp.ClientSession(timeout=_timeout()) as session: - async with session.get(url) as resp: - if resp.status == 404: - return None - resp.raise_for_status() - data = await resp.json() - return data if isinstance(data, dict) else None - except Exception as exc: - logger.warning(f"Failed to fetch tool resource ({tool_id}): {exc}") - return None + return await _adapter().fetch_tool_resource(tool_id) diff --git a/app/config.py b/app/config.py index 7eaf74d..1f6a2ea 100644 --- a/app/config.py +++ b/app/config.py @@ -468,9 +468,21 @@ class Settings(BaseSettings): ws_require_auth: bool = Field(default=False, description="Require auth in hello message even when ws_api_key is not set") # Backend bridge configuration (for call/transcript persistence) + backend_mode: str = Field( + default="auto", + description="Backend integration mode: auto | http | disabled" + ) backend_url: Optional[str] = Field(default=None, description="Backend API base URL (e.g. http://localhost:8787)") backend_timeout_sec: int = Field(default=10, description="Backend API request timeout in seconds") + history_enabled: bool = Field(default=True, description="Enable history write bridge") history_default_user_id: int = Field(default=1, description="Fallback user_id for history records") + history_queue_max_size: int = Field(default=256, description="Max buffered transcript writes per session") + history_retry_max_attempts: int = Field(default=2, description="Retry attempts for each transcript write") + history_retry_backoff_sec: float = Field(default=0.2, description="Base retry backoff for transcript writes") + history_finalize_drain_timeout_sec: float = Field( + default=1.5, + description="Max wait before finalizing history when queue is still draining" + ) # Agent YAML metadata agent_config_path: Optional[str] = Field(default=None, description="Resolved agent YAML path") diff --git a/app/main.py b/app/main.py index bf05cc9..c13daba 100644 --- a/app/main.py +++ b/app/main.py @@ -20,6 +20,7 @@ except ImportError: logger.warning("aiortc not available - WebRTC endpoint will be disabled") from app.config import settings +from app.backend_adapters import build_backend_adapter_from_settings from core.transports import SocketTransport, WebRtcTransport, BaseTransport from core.session import Session from processors.tracks import Resampled16kTrack @@ -75,6 +76,7 @@ app.add_middleware( # Active sessions storage active_sessions: Dict[str, Session] = {} +backend_gateway = build_backend_adapter_from_settings() # Configure logging logger.remove() @@ -164,7 +166,7 @@ async def websocket_endpoint(websocket: WebSocket): # Create transport and session transport = SocketTransport(websocket) - session = Session(session_id, transport) + session = Session(session_id, transport, backend_gateway=backend_gateway) active_sessions[session_id] = session logger.info(f"WebSocket connection established: {session_id}") @@ -243,7 +245,7 @@ async def webrtc_endpoint(websocket: WebSocket): # Create transport and session transport = WebRtcTransport(websocket, pc) - session = Session(session_id, transport) + session = Session(session_id, transport, backend_gateway=backend_gateway) active_sessions[session_id] = session logger.info(f"WebRTC connection established: {session_id}") diff --git a/core/duplex_pipeline.py b/core/duplex_pipeline.py index a148fc6..551c63c 100644 --- a/core/duplex_pipeline.py +++ b/core/duplex_pipeline.py @@ -15,7 +15,7 @@ import asyncio import json import time import uuid -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple import numpy as np from loguru import logger @@ -86,7 +86,16 @@ class DuplexPipeline: tts_service: Optional[BaseTTSService] = None, asr_service: Optional[BaseASRService] = None, system_prompt: Optional[str] = None, - greeting: Optional[str] = None + greeting: Optional[str] = None, + knowledge_searcher: Optional[ + Callable[..., Awaitable[List[Dict[str, Any]]]] + ] = None, + tool_resource_resolver: Optional[ + Callable[[str], Awaitable[Optional[Dict[str, Any]]]] + ] = None, + server_tool_executor: Optional[ + Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]] + ] = None, ): """ Initialize duplex pipeline. @@ -127,6 +136,9 @@ class DuplexPipeline: self.llm_service = llm_service self.tts_service = tts_service self.asr_service = asr_service # Will be initialized in start() + self._knowledge_searcher = knowledge_searcher + self._tool_resource_resolver = tool_resource_resolver + self._server_tool_executor = server_tool_executor # Track last sent transcript to avoid duplicates self._last_sent_transcript = "" @@ -215,6 +227,18 @@ class DuplexPipeline: self._pending_llm_delta: str = "" self._last_llm_delta_emit_ms: float = 0.0 + if self._server_tool_executor is None: + if self._tool_resource_resolver: + async def _executor(call: Dict[str, Any]) -> Dict[str, Any]: + return await execute_server_tool( + call, + tool_resource_fetcher=self._tool_resource_resolver, + ) + + self._server_tool_executor = _executor + else: + self._server_tool_executor = execute_server_tool + logger.info(f"DuplexPipeline initialized for session {session_id}") def set_event_sequence_provider(self, provider: Callable[[], int]) -> None: @@ -559,6 +583,7 @@ class DuplexPipeline: base_url=llm_base_url, model=llm_model, knowledge_config=self._resolved_knowledge_config(), + knowledge_searcher=self._knowledge_searcher, ) else: logger.warning("LLM provider unsupported or API key missing - using mock LLM") @@ -1491,7 +1516,7 @@ class DuplexPipeline: try: result = await asyncio.wait_for( - execute_server_tool(call), + self._server_tool_executor(call), timeout=self._SERVER_TOOL_TIMEOUT_SECONDS, ) except asyncio.TimeoutError: diff --git a/core/history_bridge.py b/core/history_bridge.py new file mode 100644 index 0000000..ead9a3b --- /dev/null +++ b/core/history_bridge.py @@ -0,0 +1,244 @@ +"""Async history bridge for non-blocking transcript persistence.""" + +from __future__ import annotations + +import asyncio +import time +from dataclasses import dataclass +from typing import Any, Optional + +from loguru import logger + + +@dataclass +class _HistoryTranscriptJob: + call_id: str + turn_index: int + speaker: str + content: str + start_ms: int + end_ms: int + duration_ms: int + + +class SessionHistoryBridge: + """Session-scoped buffered history writer with background retries.""" + + _STOP_SENTINEL = object() + + def __init__( + self, + *, + history_writer: Any, + enabled: bool, + queue_max_size: int, + retry_max_attempts: int, + retry_backoff_sec: float, + finalize_drain_timeout_sec: float, + ): + self._history_writer = history_writer + self._enabled = bool(enabled and history_writer is not None) + self._queue_max_size = max(1, int(queue_max_size)) + self._retry_max_attempts = max(0, int(retry_max_attempts)) + self._retry_backoff_sec = max(0.0, float(retry_backoff_sec)) + self._finalize_drain_timeout_sec = max(0.0, float(finalize_drain_timeout_sec)) + + self._call_id: Optional[str] = None + self._turn_index: int = 0 + self._started_mono: Optional[float] = None + self._finalized: bool = False + self._worker_task: Optional[asyncio.Task] = None + self._finalize_lock = asyncio.Lock() + self._queue: asyncio.Queue[_HistoryTranscriptJob | object] = asyncio.Queue(maxsize=self._queue_max_size) + + @property + def enabled(self) -> bool: + return self._enabled + + @property + def call_id(self) -> Optional[str]: + return self._call_id + + async def start_call( + self, + *, + user_id: int, + assistant_id: Optional[str], + source: str, + ) -> Optional[str]: + """Create remote call record and start background worker.""" + if not self._enabled or self._call_id: + return self._call_id + + call_id = await self._history_writer.create_call_record( + user_id=user_id, + assistant_id=assistant_id, + source=source, + ) + if not call_id: + return None + + self._call_id = str(call_id) + self._turn_index = 0 + self._finalized = False + self._started_mono = time.monotonic() + self._ensure_worker() + return self._call_id + + def elapsed_ms(self) -> int: + if self._started_mono is None: + return 0 + return max(0, int((time.monotonic() - self._started_mono) * 1000)) + + def enqueue_turn(self, *, role: str, text: str) -> bool: + """Queue one transcript write without blocking the caller.""" + if not self._enabled or not self._call_id or self._finalized: + return False + + content = str(text or "").strip() + if not content: + return False + + speaker = "human" if str(role or "").strip().lower() == "user" else "ai" + end_ms = self.elapsed_ms() + estimated_duration_ms = max(300, min(12000, len(content) * 80)) + start_ms = max(0, end_ms - estimated_duration_ms) + + job = _HistoryTranscriptJob( + call_id=self._call_id, + turn_index=self._turn_index, + speaker=speaker, + content=content, + start_ms=start_ms, + end_ms=end_ms, + duration_ms=max(1, end_ms - start_ms), + ) + self._turn_index += 1 + self._ensure_worker() + + try: + self._queue.put_nowait(job) + return True + except asyncio.QueueFull: + logger.warning( + "History queue full; dropping transcript call_id={} turn={}", + self._call_id, + job.turn_index, + ) + return False + + async def finalize(self, *, status: str) -> bool: + """Finalize history record once; waits briefly for queue drain.""" + if not self._enabled or not self._call_id: + return False + + async with self._finalize_lock: + if self._finalized: + return True + + await self._drain_queue() + ok = await self._history_writer.finalize_call_record( + call_id=self._call_id, + status=status, + duration_seconds=self.duration_seconds(), + ) + if ok: + self._finalized = True + await self._stop_worker() + return ok + + async def shutdown(self) -> None: + """Stop worker task and release queue resources.""" + await self._stop_worker() + + def duration_seconds(self) -> int: + if self._started_mono is None: + return 0 + return max(0, int(time.monotonic() - self._started_mono)) + + def _ensure_worker(self) -> None: + if self._worker_task and not self._worker_task.done(): + return + self._worker_task = asyncio.create_task(self._worker_loop()) + + async def _drain_queue(self) -> None: + if self._finalize_drain_timeout_sec <= 0: + return + try: + await asyncio.wait_for(self._queue.join(), timeout=self._finalize_drain_timeout_sec) + except asyncio.TimeoutError: + logger.warning("History queue drain timed out after {}s", self._finalize_drain_timeout_sec) + + async def _stop_worker(self) -> None: + task = self._worker_task + if not task: + return + if task.done(): + self._worker_task = None + return + + sent = False + try: + self._queue.put_nowait(self._STOP_SENTINEL) + sent = True + except asyncio.QueueFull: + pass + + if not sent: + try: + await asyncio.wait_for(self._queue.put(self._STOP_SENTINEL), timeout=0.5) + except asyncio.TimeoutError: + task.cancel() + + try: + await asyncio.wait_for(task, timeout=1.5) + except asyncio.TimeoutError: + task.cancel() + try: + await task + except Exception: + pass + except asyncio.CancelledError: + pass + finally: + self._worker_task = None + + async def _worker_loop(self) -> None: + while True: + item = await self._queue.get() + try: + if item is self._STOP_SENTINEL: + return + + assert isinstance(item, _HistoryTranscriptJob) + await self._write_with_retry(item) + except Exception as exc: + logger.warning("History worker write failed unexpectedly: {}", exc) + finally: + self._queue.task_done() + + async def _write_with_retry(self, job: _HistoryTranscriptJob) -> bool: + for attempt in range(self._retry_max_attempts + 1): + ok = await self._history_writer.add_transcript( + call_id=job.call_id, + turn_index=job.turn_index, + speaker=job.speaker, + content=job.content, + start_ms=job.start_ms, + end_ms=job.end_ms, + duration_ms=job.duration_ms, + ) + if ok: + return True + + if attempt >= self._retry_max_attempts: + logger.warning( + "History write dropped after retries call_id={} turn={}", + job.call_id, + job.turn_index, + ) + return False + + if self._retry_backoff_sec > 0: + await asyncio.sleep(self._retry_backoff_sec * (2**attempt)) + return False diff --git a/core/ports/__init__.py b/core/ports/__init__.py new file mode 100644 index 0000000..7d7c9dd --- /dev/null +++ b/core/ports/__init__.py @@ -0,0 +1,17 @@ +"""Port interfaces for engine-side integration boundaries.""" + +from core.ports.backend import ( + AssistantConfigProvider, + BackendGateway, + HistoryWriter, + KnowledgeSearcher, + ToolResourceResolver, +) + +__all__ = [ + "AssistantConfigProvider", + "BackendGateway", + "HistoryWriter", + "KnowledgeSearcher", + "ToolResourceResolver", +] diff --git a/core/ports/backend.py b/core/ports/backend.py new file mode 100644 index 0000000..227c743 --- /dev/null +++ b/core/ports/backend.py @@ -0,0 +1,84 @@ +"""Backend integration ports. + +These interfaces define the boundary between engine runtime logic and +backend-side capabilities (config lookup, history persistence, retrieval, +and tool resource discovery). +""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Protocol + + +class AssistantConfigProvider(Protocol): + """Port for loading trusted assistant runtime configuration.""" + + async def fetch_assistant_config(self, assistant_id: str) -> Optional[Dict[str, Any]]: + """Fetch assistant configuration payload.""" + + +class HistoryWriter(Protocol): + """Port for persisting call and transcript history.""" + + async def create_call_record( + self, + *, + user_id: int, + assistant_id: Optional[str], + source: str = "debug", + ) -> Optional[str]: + """Create a call record and return backend call ID.""" + + async def add_transcript( + self, + *, + call_id: str, + turn_index: int, + speaker: str, + content: str, + start_ms: int, + end_ms: int, + confidence: Optional[float] = None, + duration_ms: Optional[int] = None, + ) -> bool: + """Append one transcript turn segment.""" + + async def finalize_call_record( + self, + *, + call_id: str, + status: str, + duration_seconds: int, + ) -> bool: + """Finalize a call record.""" + + +class KnowledgeSearcher(Protocol): + """Port for RAG / knowledge retrieval operations.""" + + async def search_knowledge_context( + self, + *, + kb_id: str, + query: str, + n_results: int = 5, + ) -> List[Dict[str, Any]]: + """Search a knowledge source and return ranked snippets.""" + + +class ToolResourceResolver(Protocol): + """Port for resolving tool metadata/configuration.""" + + async def fetch_tool_resource(self, tool_id: str) -> Optional[Dict[str, Any]]: + """Fetch tool resource configuration.""" + + +class BackendGateway( + AssistantConfigProvider, + HistoryWriter, + KnowledgeSearcher, + ToolResourceResolver, + Protocol, +): + """Composite backend gateway interface used by engine services.""" + diff --git a/core/session.py b/core/session.py index b741425..2be41e2 100644 --- a/core/session.py +++ b/core/session.py @@ -9,15 +9,11 @@ from enum import Enum from typing import Optional, Dict, Any, List from loguru import logger -from app.backend_client import ( - fetch_assistant_config, - create_history_call_record, - add_history_transcript, - finalize_history_call_record, -) +from app.backend_adapters import build_backend_adapter_from_settings from core.transports import BaseTransport from core.duplex_pipeline import DuplexPipeline from core.conversation import ConversationTurn +from core.history_bridge import SessionHistoryBridge from core.workflow_runner import WorkflowRunner, WorkflowTransition, WorkflowNodeDef, WorkflowEdgeDef from app.config import settings from services.base import LLMMessage @@ -76,7 +72,13 @@ class Session: "config_version_id", } - def __init__(self, session_id: str, transport: BaseTransport, use_duplex: bool = None): + def __init__( + self, + session_id: str, + transport: BaseTransport, + use_duplex: bool = None, + backend_gateway: Optional[Any] = None, + ): """ Initialize session. @@ -88,12 +90,23 @@ class Session: self.id = session_id self.transport = transport self.use_duplex = use_duplex if use_duplex is not None else settings.duplex_enabled + self._backend_gateway = backend_gateway or build_backend_adapter_from_settings() + self._history_bridge = SessionHistoryBridge( + history_writer=self._backend_gateway, + enabled=settings.history_enabled, + queue_max_size=settings.history_queue_max_size, + retry_max_attempts=settings.history_retry_max_attempts, + retry_backoff_sec=settings.history_retry_backoff_sec, + finalize_drain_timeout_sec=settings.history_finalize_drain_timeout_sec, + ) self.pipeline = DuplexPipeline( transport=transport, session_id=session_id, system_prompt=settings.duplex_system_prompt, - greeting=settings.duplex_greeting + greeting=settings.duplex_greeting, + knowledge_searcher=getattr(self._backend_gateway, "search_knowledge_context", None), + tool_resource_resolver=getattr(self._backend_gateway, "fetch_tool_resource", None), ) # Session state @@ -107,10 +120,6 @@ class Session: # Track IDs self.current_track_id: str = self.TRACK_CONTROL self._event_seq: int = 0 - self._history_call_id: Optional[str] = None - self._history_turn_index: int = 0 - self._history_call_started_mono: Optional[float] = None - self._history_finalized: bool = False self._cleanup_lock = asyncio.Lock() self._cleaned_up = False self.workflow_runner: Optional[WorkflowRunner] = None @@ -424,11 +433,12 @@ class Session: logger.info(f"Session {self.id} cleaning up") await self._finalize_history(status="connected") await self.pipeline.cleanup() + await self._history_bridge.shutdown() await self.transport.close() async def _start_history_bridge(self, metadata: Dict[str, Any]) -> None: """Initialize backend history call record for this session.""" - if self._history_call_id: + if self._history_bridge.call_id: return history_meta: Dict[str, Any] = {} @@ -444,7 +454,7 @@ class Session: assistant_id = history_meta.get("assistantId", metadata.get("assistantId")) source = str(history_meta.get("source", metadata.get("source", "debug"))) - call_id = await create_history_call_record( + call_id = await self._history_bridge.start_call( user_id=user_id, assistant_id=str(assistant_id) if assistant_id else None, source=source, @@ -452,10 +462,6 @@ class Session: if not call_id: return - self._history_call_id = call_id - self._history_call_started_mono = time.monotonic() - self._history_turn_index = 0 - self._history_finalized = False logger.info(f"Session {self.id} history bridge enabled (call_id={call_id}, source={source})") async def _on_turn_complete(self, turn: ConversationTurn) -> None: @@ -467,48 +473,11 @@ class Session: elif role == "assistant": await self._maybe_advance_workflow(turn.text.strip()) - if not self._history_call_id: - return - if not turn.text or not turn.text.strip(): - return - - role = (turn.role or "").lower() - speaker = "human" if role == "user" else "ai" - - end_ms = 0 - if self._history_call_started_mono is not None: - end_ms = max(0, int((time.monotonic() - self._history_call_started_mono) * 1000)) - estimated_duration_ms = max(300, min(12000, len(turn.text.strip()) * 80)) - start_ms = max(0, end_ms - estimated_duration_ms) - - turn_index = self._history_turn_index - await add_history_transcript( - call_id=self._history_call_id, - turn_index=turn_index, - speaker=speaker, - content=turn.text.strip(), - start_ms=start_ms, - end_ms=end_ms, - duration_ms=max(1, end_ms - start_ms), - ) - self._history_turn_index += 1 + self._history_bridge.enqueue_turn(role=turn.role or "", text=turn.text or "") async def _finalize_history(self, status: str) -> None: """Finalize history call record once.""" - if not self._history_call_id or self._history_finalized: - return - - duration_seconds = 0 - if self._history_call_started_mono is not None: - duration_seconds = max(0, int(time.monotonic() - self._history_call_started_mono)) - - ok = await finalize_history_call_record( - call_id=self._history_call_id, - status=status, - duration_seconds=duration_seconds, - ) - if ok: - self._history_finalized = True + await self._history_bridge.finalize(status=status) def _bootstrap_workflow(self, metadata: Dict[str, Any]) -> Dict[str, Any]: """Parse workflow payload and return initial runtime overrides.""" @@ -795,10 +764,12 @@ class Session: ) if assistant_id is None: return {} - if not settings.backend_url: + + provider = getattr(self._backend_gateway, "fetch_assistant_config", None) + if not callable(provider): return {} - payload = await fetch_assistant_config(str(assistant_id).strip()) + payload = await provider(str(assistant_id).strip()) if not isinstance(payload, dict): return {} diff --git a/core/tool_executor.py b/core/tool_executor.py index 407e199..4505436 100644 --- a/core/tool_executor.py +++ b/core/tool_executor.py @@ -4,11 +4,13 @@ import asyncio import ast import operator from datetime import datetime -from typing import Any, Dict +from typing import Any, Awaitable, Callable, Dict, Optional import aiohttp -from app.backend_client import fetch_tool_resource +from app.backend_adapters import build_backend_adapter_from_settings + +ToolResourceFetcher = Callable[[str], Awaitable[Optional[Dict[str, Any]]]] _BIN_OPS = { ast.Add: operator.add, @@ -170,11 +172,21 @@ def _extract_tool_args(tool_call: Dict[str, Any]) -> Dict[str, Any]: return {} -async def execute_server_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]: +async def fetch_tool_resource(tool_id: str) -> Optional[Dict[str, Any]]: + """Default tool resource resolver via backend adapter.""" + adapter = build_backend_adapter_from_settings() + return await adapter.fetch_tool_resource(tool_id) + + +async def execute_server_tool( + tool_call: Dict[str, Any], + tool_resource_fetcher: Optional[ToolResourceFetcher] = None, +) -> Dict[str, Any]: """Execute a server-side tool and return normalized result payload.""" call_id = str(tool_call.get("id") or "").strip() tool_name = _extract_tool_name(tool_call) args = _extract_tool_args(tool_call) + resource_fetcher = tool_resource_fetcher or fetch_tool_resource if tool_name == "calculator": expression = str(args.get("expression") or "").strip() @@ -257,7 +269,7 @@ async def execute_server_tool(tool_call: Dict[str, Any]) -> Dict[str, Any]: } if tool_name and tool_name not in {"calculator", "code_interpreter", "current_time"}: - resource = await fetch_tool_resource(tool_name) + resource = await resource_fetcher(tool_name) if resource and str(resource.get("category") or "") == "query": method = str(resource.get("http_method") or "GET").strip().upper() if method not in {"GET", "POST", "PUT", "PATCH", "DELETE"}: diff --git a/docs/backend_integration.md b/docs/backend_integration.md new file mode 100644 index 0000000..1f5d14d --- /dev/null +++ b/docs/backend_integration.md @@ -0,0 +1,47 @@ +# Backend Integration and History Bridge + +This engine uses adapter-based backend integration so core runtime logic can run +with or without an external backend service. + +## Runtime Modes + +Configure with environment variables: + +- `BACKEND_MODE=auto|http|disabled` +- `BACKEND_URL` +- `BACKEND_TIMEOUT_SEC` +- `HISTORY_ENABLED=true|false` + +Mode behavior: + +- `auto`: use HTTP backend adapter only when `BACKEND_URL` is set. +- `http`: force HTTP backend adapter (falls back to null adapter when URL is missing). +- `disabled`: force null adapter and run engine-only. + +## Architecture + +- Ports: `core/ports/backend.py` +- Adapters: `app/backend_adapters.py` +- Compatibility wrappers: `app/backend_client.py` + +`Session` and `DuplexPipeline` receive backend capabilities via injected adapter +methods instead of hard-coding backend client imports. + +## Async History Writes + +Session history persistence is handled by `core/history_bridge.py`. + +Design: + +- transcript writes are queued with `put_nowait` (non-blocking turn path) +- background worker drains queue +- failed writes retry with exponential backoff +- finalize waits briefly for queue drain before sending call finalize +- finalize is idempotent + +Related settings: + +- `HISTORY_QUEUE_MAX_SIZE` +- `HISTORY_RETRY_MAX_ATTEMPTS` +- `HISTORY_RETRY_BACKOFF_SEC` +- `HISTORY_FINALIZE_DRAIN_TIMEOUT_SEC` diff --git a/services/llm.py b/services/llm.py index 51bfbe4..eb7f89c 100644 --- a/services/llm.py +++ b/services/llm.py @@ -7,10 +7,10 @@ for real-time voice conversation. import os import asyncio import uuid -from typing import AsyncIterator, Optional, List, Dict, Any +from typing import AsyncIterator, Optional, List, Dict, Any, Callable, Awaitable from loguru import logger -from app.backend_client import search_knowledge_context +from app.backend_adapters import build_backend_adapter_from_settings from services.base import BaseLLMService, LLMMessage, LLMStreamEvent, ServiceState # Try to import openai @@ -37,6 +37,7 @@ class OpenAILLMService(BaseLLMService): base_url: Optional[str] = None, system_prompt: Optional[str] = None, knowledge_config: Optional[Dict[str, Any]] = None, + knowledge_searcher: Optional[Callable[..., Awaitable[List[Dict[str, Any]]]]] = None, ): """ Initialize OpenAI LLM service. @@ -60,6 +61,11 @@ class OpenAILLMService(BaseLLMService): self.client: Optional[AsyncOpenAI] = None self._cancel_event = asyncio.Event() self._knowledge_config: Dict[str, Any] = knowledge_config or {} + if knowledge_searcher is None: + adapter = build_backend_adapter_from_settings() + self._knowledge_searcher = adapter.search_knowledge_context + else: + self._knowledge_searcher = knowledge_searcher self._tool_schemas: List[Dict[str, Any]] = [] _RAG_DEFAULT_RESULTS = 5 @@ -224,7 +230,7 @@ class OpenAILLMService(BaseLLMService): n_results = self._coerce_int(cfg.get("nResults"), self._RAG_DEFAULT_RESULTS) n_results = max(1, min(n_results, self._RAG_MAX_RESULTS)) - results = await search_knowledge_context( + results = await self._knowledge_searcher( kb_id=kb_id, query=latest_user, n_results=n_results, diff --git a/tests/test_backend_adapters.py b/tests/test_backend_adapters.py new file mode 100644 index 0000000..d55f5e2 --- /dev/null +++ b/tests/test_backend_adapters.py @@ -0,0 +1,150 @@ +import aiohttp +import pytest + +from app.backend_adapters import ( + HistoryDisabledBackendAdapter, + HttpBackendAdapter, + NullBackendAdapter, + build_backend_adapter, +) + + +@pytest.mark.asyncio +async def test_build_backend_adapter_without_url_returns_null_adapter(): + adapter = build_backend_adapter( + backend_url=None, + backend_mode="auto", + history_enabled=True, + timeout_sec=3, + ) + assert isinstance(adapter, NullBackendAdapter) + + assert await adapter.fetch_assistant_config("assistant_1") is None + assert ( + await adapter.create_call_record( + user_id=1, + assistant_id="assistant_1", + source="debug", + ) + is None + ) + assert ( + await adapter.add_transcript( + call_id="call_1", + turn_index=0, + speaker="human", + content="hi", + start_ms=0, + end_ms=100, + confidence=0.9, + duration_ms=100, + ) + is False + ) + assert ( + await adapter.finalize_call_record( + call_id="call_1", + status="connected", + duration_seconds=2, + ) + is False + ) + assert await adapter.search_knowledge_context(kb_id="kb_1", query="hello", n_results=3) == [] + assert await adapter.fetch_tool_resource("tool_1") is None + + +@pytest.mark.asyncio +async def test_http_backend_adapter_create_call_record_posts_expected_payload(monkeypatch): + captured = {} + + class _FakeResponse: + def __init__(self, status=200, payload=None): + self.status = status + self._payload = payload if payload is not None else {} + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + async def json(self): + return self._payload + + def raise_for_status(self): + if self.status >= 400: + raise RuntimeError("http_error") + + class _FakeClientSession: + def __init__(self, timeout=None): + captured["timeout"] = timeout + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + def post(self, url, json=None): + captured["url"] = url + captured["json"] = json + return _FakeResponse(status=200, payload={"id": "call_123"}) + + monkeypatch.setattr("app.backend_adapters.aiohttp.ClientSession", _FakeClientSession) + + adapter = build_backend_adapter( + backend_url="http://localhost:8100", + backend_mode="auto", + history_enabled=True, + timeout_sec=7, + ) + assert isinstance(adapter, HttpBackendAdapter) + + call_id = await adapter.create_call_record( + user_id=99, + assistant_id="assistant_9", + source="debug", + ) + + assert call_id == "call_123" + assert captured["url"] == "http://localhost:8100/api/history" + assert captured["json"] == { + "user_id": 99, + "assistant_id": "assistant_9", + "source": "debug", + "status": "connected", + } + assert isinstance(captured["timeout"], aiohttp.ClientTimeout) + assert captured["timeout"].total == 7 + + +@pytest.mark.asyncio +async def test_backend_mode_disabled_forces_null_even_with_url(): + adapter = build_backend_adapter( + backend_url="http://localhost:8100", + backend_mode="disabled", + history_enabled=True, + timeout_sec=7, + ) + assert isinstance(adapter, NullBackendAdapter) + + +@pytest.mark.asyncio +async def test_history_disabled_wraps_backend_adapter(): + adapter = build_backend_adapter( + backend_url="http://localhost:8100", + backend_mode="auto", + history_enabled=False, + timeout_sec=7, + ) + assert isinstance(adapter, HistoryDisabledBackendAdapter) + assert await adapter.create_call_record(user_id=1, assistant_id="a1", source="debug") is None + assert await adapter.add_transcript( + call_id="c1", + turn_index=0, + speaker="human", + content="hi", + start_ms=0, + end_ms=10, + duration_ms=10, + ) is False diff --git a/tests/test_history_bridge.py b/tests/test_history_bridge.py new file mode 100644 index 0000000..2f9dd80 --- /dev/null +++ b/tests/test_history_bridge.py @@ -0,0 +1,147 @@ +import asyncio +import time + +import pytest + +from core.history_bridge import SessionHistoryBridge + + +class _FakeHistoryWriter: + def __init__(self, *, add_delay_s: float = 0.0, add_result: bool = True): + self.add_delay_s = add_delay_s + self.add_result = add_result + self.created_call_ids = [] + self.transcripts = [] + self.finalize_calls = 0 + self.finalize_statuses = [] + self.finalize_at = None + self.last_transcript_at = None + + async def create_call_record(self, *, user_id: int, assistant_id: str | None, source: str = "debug"): + _ = (user_id, assistant_id, source) + call_id = "call_test_1" + self.created_call_ids.append(call_id) + return call_id + + async def add_transcript( + self, + *, + call_id: str, + turn_index: int, + speaker: str, + content: str, + start_ms: int, + end_ms: int, + confidence: float | None = None, + duration_ms: int | None = None, + ) -> bool: + _ = confidence + if self.add_delay_s > 0: + await asyncio.sleep(self.add_delay_s) + self.transcripts.append( + { + "call_id": call_id, + "turn_index": turn_index, + "speaker": speaker, + "content": content, + "start_ms": start_ms, + "end_ms": end_ms, + "duration_ms": duration_ms, + } + ) + self.last_transcript_at = time.monotonic() + return self.add_result + + async def finalize_call_record(self, *, call_id: str, status: str, duration_seconds: int) -> bool: + _ = (call_id, duration_seconds) + self.finalize_calls += 1 + self.finalize_statuses.append(status) + self.finalize_at = time.monotonic() + return True + +@pytest.mark.asyncio +async def test_slow_backend_does_not_block_enqueue(): + writer = _FakeHistoryWriter(add_delay_s=0.15, add_result=True) + bridge = SessionHistoryBridge( + history_writer=writer, + enabled=True, + queue_max_size=32, + retry_max_attempts=0, + retry_backoff_sec=0.01, + finalize_drain_timeout_sec=1.0, + ) + + try: + call_id = await bridge.start_call(user_id=1, assistant_id="assistant_1", source="debug") + assert call_id == "call_test_1" + + t0 = time.perf_counter() + queued = bridge.enqueue_turn(role="user", text="hello world") + elapsed_s = time.perf_counter() - t0 + + assert queued is True + assert elapsed_s < 0.02 + + await bridge.finalize(status="connected") + assert len(writer.transcripts) == 1 + assert writer.finalize_calls == 1 + finally: + await bridge.shutdown() + + +@pytest.mark.asyncio +async def test_failing_backend_retries_but_enqueue_remains_non_blocking(): + writer = _FakeHistoryWriter(add_delay_s=0.01, add_result=False) + bridge = SessionHistoryBridge( + history_writer=writer, + enabled=True, + queue_max_size=32, + retry_max_attempts=2, + retry_backoff_sec=0.01, + finalize_drain_timeout_sec=0.5, + ) + + try: + await bridge.start_call(user_id=1, assistant_id="assistant_1", source="debug") + t0 = time.perf_counter() + assert bridge.enqueue_turn(role="assistant", text="retry me") + elapsed_s = time.perf_counter() - t0 + assert elapsed_s < 0.02 + + await bridge.finalize(status="connected") + + # Initial try + 2 retries + assert len(writer.transcripts) == 3 + assert writer.finalize_calls == 1 + finally: + await bridge.shutdown() + + +@pytest.mark.asyncio +async def test_finalize_is_idempotent_and_waits_for_queue_drain(): + writer = _FakeHistoryWriter(add_delay_s=0.05, add_result=True) + bridge = SessionHistoryBridge( + history_writer=writer, + enabled=True, + queue_max_size=32, + retry_max_attempts=0, + retry_backoff_sec=0.01, + finalize_drain_timeout_sec=1.0, + ) + + try: + await bridge.start_call(user_id=1, assistant_id="assistant_1", source="debug") + assert bridge.enqueue_turn(role="user", text="first") + + ok_1 = await bridge.finalize(status="connected") + ok_2 = await bridge.finalize(status="connected") + + assert ok_1 is True + assert ok_2 is True + assert len(writer.transcripts) == 1 + assert writer.finalize_calls == 1 + assert writer.last_transcript_at is not None + assert writer.finalize_at is not None + assert writer.finalize_at >= writer.last_transcript_at + finally: + await bridge.shutdown()