83 lines
2.5 KiB
Python
83 lines
2.5 KiB
Python
"""LLM extension port contracts."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Protocol
|
|
|
|
from providers.common.base import LLMMessage, LLMStreamEvent
|
|
|
|
# Type alias for a knowledge-retriever callback: a callable taking any
# arguments and returning an awaitable that resolves to a list of
# string-keyed dictionaries.
KnowledgeRetrieverFn = Callable[..., Awaitable[List[Dict[str, Any]]]]
|
|
|
|
|
|
@dataclass(frozen=True)
class LLMServiceSpec:
    """Resolved runtime configuration for LLM service creation.

    Instances are frozen (attribute assignment after construction raises) and
    hashable.  ``knowledge_config`` is a plain ``dict`` and is therefore left
    out of the generated ``__hash__``; it still participates in ``==``.
    """

    # Required identifiers.
    provider: str
    model: str

    # Optional settings; each is None when not supplied.
    api_key: Optional[str] = None
    base_url: Optional[str] = None
    app_id: Optional[str] = None
    system_prompt: Optional[str] = None

    temperature: float = 0.7

    # default_factory gives every instance its own dict.  hash=False keeps the
    # unhashable dict out of the auto-generated __hash__ (frozen + eq), which
    # would otherwise make hash(spec) raise TypeError.  Equal specs still hash
    # equally because the hash uses a subset of the compared fields.
    knowledge_config: Dict[str, Any] = field(default_factory=dict, hash=False)

    # Optional retriever callback; None when not supplied.
    knowledge_searcher: Optional[KnowledgeRetrieverFn] = None
|
|
|
|
|
|
class LLMPort(Protocol):
    """Port for LLM providers.

    Structural interface: any object exposing these four methods with
    compatible signatures satisfies the port.
    """

    async def connect(self) -> None:
        """Establish connection to LLM provider."""
        ...

    async def disconnect(self) -> None:
        """Release LLM resources."""
        ...

    async def generate(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
    ) -> str:
        """Generate a complete assistant response.

        Args:
            messages: Conversation messages to send to the provider.
            temperature: Sampling temperature; defaults to 0.7.
            max_tokens: Optional limit, None by default (presumably an upper
                bound on generated tokens -- confirm with implementations).

        Returns:
            The full response text.
        """
        ...

    # Deliberately a plain ``def``: an ``async def`` stub without ``yield`` is
    # typed as a coroutine that *returns* an AsyncIterator, which async
    # generator implementations (consumed directly with ``async for``) do not
    # satisfy.  This also matches the declaration style of
    # ``resume_after_client_tool_result`` elsewhere in this module.
    def generate_stream(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
    ) -> AsyncIterator[LLMStreamEvent]:
        """Generate streaming assistant response events.

        Args:
            messages: Conversation messages to send to the provider.
            temperature: Sampling temperature; defaults to 0.7.
            max_tokens: Optional limit, None by default (presumably an upper
                bound on generated tokens -- confirm with implementations).

        Returns:
            An async iterator of stream events.
        """
        ...
|
|
|
|
|
|
class LLMCancellable(Protocol):
    """Opt-in capability: interrupt LLM generation that is in flight."""

    def cancel(self) -> None:
        """Abort the generation request that is in flight."""
        ...
|
|
|
|
|
|
class LLMRuntimeConfigurable(Protocol):
    """Opt-in capability: accept configuration updates at runtime."""

    def set_knowledge_config(self, config: Optional[Dict[str, Any]]) -> None:
        """Apply knowledge retrieval settings at runtime.

        Args:
            config: The settings mapping, or None.
        """
        ...

    def set_tool_schemas(self, schemas: Optional[List[Dict[str, Any]]]) -> None:
        """Apply the tool schemas used for tool calling at runtime.

        Args:
            schemas: The list of schema mappings, or None.
        """
        ...
|
|
|
|
|
|
class LLMClientToolResumable(Protocol):
    """Opt-in capability for providers that pause on client-side tool results."""

    def handles_client_tool(self, tool_name: str) -> bool:
        """Tell whether the provider owns this client tool's lifecycle.

        Args:
            tool_name: Name of the client tool in question.

        Returns:
            True when the provider owns the lifecycle of the named tool.
        """
        ...

    def resume_after_client_tool_result(
        self,
        tool_call_id: str,
        result: Dict[str, Any],
    ) -> AsyncIterator[LLMStreamEvent]:
        """Continue the provider stream once a correlated client tool result arrives.

        Args:
            tool_call_id: Identifier correlating the result with its tool call.
            result: The client-side tool result.

        Returns:
            An async iterator of stream events.
        """
        ...
|