diff --git a/src/pipecat/services/llm_service.py b/src/pipecat/services/llm_service.py index 9310ec78b..77c1412a5 100644 --- a/src/pipecat/services/llm_service.py +++ b/src/pipecat/services/llm_service.py @@ -18,11 +18,11 @@ from typing import ( Any, Generic, Protocol, - TypeVar, cast, ) from loguru import logger +from typing_extensions import TypeVar from websockets.exceptions import ConnectionClosed from websockets.protocol import State @@ -119,7 +119,12 @@ class FunctionCallParams: function_name: str tool_call_id: str arguments: Mapping[str, Any] - llm: LLMService + # `LLMService[Any]` so any concrete subclass (regardless of how — or + # whether — it parameterizes the adapter type) can be assigned here. + # Plain `LLMService` would invoke the TypeVar default and pyright would + # treat it invariantly, rejecting `LLMService[XAdapter]` at the call + # sites that build FunctionCallParams. + llm: LLMService[Any] context: LLMContext result_callback: FunctionCallResultCallback app_resources: Any = None @@ -193,7 +198,11 @@ class FunctionCallRunnerItem: group_id: str | None = None -TAdapter = TypeVar("TAdapter", bound=BaseLLMAdapter) +# `default=BaseLLMAdapter` (PEP 696) so that unparameterized subclasses +# (e.g. third-party `class MyService(LLMService):` with no bracket) get +# `TAdapter = BaseLLMAdapter` instead of `Unknown` at type-check time — +# matching the pre-generic behavior of `get_llm_adapter()`. +TAdapter = TypeVar("TAdapter", bound=BaseLLMAdapter, default=BaseLLMAdapter) class LLMService(UserTurnCompletionLLMServiceMixin, AIService, Generic[TAdapter]): @@ -1172,7 +1181,11 @@ class WebsocketLLMService(LLMService[TAdapter], WebsocketService, Generic[TAdapt reconnect_on_error: Whether to automatically reconnect on websocket errors. **kwargs: Additional arguments passed to parent classes. """ - LLMService.__init__(self, **kwargs) + # pyright stumbles here because the TypeVar default makes + # `LLMService` resolve to `LLMService[BaseLLMAdapter]` invariantly, + # while `self` is `WebsocketLLMService[TAdapter]` for an arbitrary + # TAdapter. The runtime call is fine — generics are erased. + LLMService.__init__(self, **kwargs) # pyright: ignore[reportArgumentType] WebsocketService.__init__(self, reconnect_on_error=reconnect_on_error, **kwargs) self._register_event_handler("on_connection_error")