refactor: give TAdapter a default to restore precise typing for unparameterized LLMService subclasses
After making LLMService generic, an unparameterized subclass (`class MyService(LLMService):` with no bracket — the third-party provider pattern) saw `get_llm_adapter()` return `Unknown` rather than `BaseLLMAdapter` as it did before the refactor. Add `default=BaseLLMAdapter` (PEP 696) on the TypeVar — via `typing_extensions.TypeVar` so older Python targets keep working — so unparameterized callers get `LLMService[BaseLLMAdapter]` and `get_llm_adapter()` returns `BaseLLMAdapter`, matching the pre-refactor type precision. Two internal fallouts of having a default (where the default makes unannotated `LLMService` resolve invariantly to `LLMService[BaseLLMAdapter]`): - `FunctionCallParams.llm` is now `LLMService[Any]` so concrete parameterizations like `LLMService[OpenAILLMAdapter]` can be passed where the field is set. - The explicit `LLMService.__init__(self, **kwargs)` in `WebsocketLLMService.__init__` gets a `pyright: ignore[reportArgumentType]` comment — pyright's invariance handling can't see through the multi-inheritance + generic + default combination, but the runtime call is correct (generics are erased).
This commit is contained in:
@@ -18,11 +18,11 @@ from typing import (
|
||||
Any,
|
||||
Generic,
|
||||
Protocol,
|
||||
TypeVar,
|
||||
cast,
|
||||
)
|
||||
|
||||
from loguru import logger
|
||||
from typing_extensions import TypeVar
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
from websockets.protocol import State
|
||||
|
||||
@@ -119,7 +119,12 @@ class FunctionCallParams:
|
||||
function_name: str
|
||||
tool_call_id: str
|
||||
arguments: Mapping[str, Any]
|
||||
llm: LLMService
|
||||
# `LLMService[Any]` so any concrete subclass (regardless of how — or
|
||||
# whether — it parameterizes the adapter type) can be assigned here.
|
||||
# Plain `LLMService` would invoke the TypeVar default and pyright would
|
||||
# treat it invariantly, rejecting `LLMService[XAdapter]` at the call
|
||||
# sites that build FunctionCallParams.
|
||||
llm: LLMService[Any]
|
||||
context: LLMContext
|
||||
result_callback: FunctionCallResultCallback
|
||||
app_resources: Any = None
|
||||
@@ -193,7 +198,11 @@ class FunctionCallRunnerItem:
|
||||
group_id: str | None = None
|
||||
|
||||
|
||||
TAdapter = TypeVar("TAdapter", bound=BaseLLMAdapter)
|
||||
# `default=BaseLLMAdapter` (PEP 696) so that unparameterized subclasses
|
||||
# (e.g. third-party `class MyService(LLMService):` with no bracket) get
|
||||
# `TAdapter = BaseLLMAdapter` instead of `Unknown` at type-check time —
|
||||
# matching the pre-generic behavior of `get_llm_adapter()`.
|
||||
TAdapter = TypeVar("TAdapter", bound=BaseLLMAdapter, default=BaseLLMAdapter)
|
||||
|
||||
|
||||
class LLMService(UserTurnCompletionLLMServiceMixin, AIService, Generic[TAdapter]):
|
||||
@@ -1172,7 +1181,11 @@ class WebsocketLLMService(LLMService[TAdapter], WebsocketService, Generic[TAdapt
|
||||
reconnect_on_error: Whether to automatically reconnect on websocket errors.
|
||||
**kwargs: Additional arguments passed to parent classes.
|
||||
"""
|
||||
LLMService.__init__(self, **kwargs)
|
||||
# pyright stumbles here because the TypeVar default makes
|
||||
# `LLMService` resolve to `LLMService[BaseLLMAdapter]` invariantly,
|
||||
# while `self` is `WebsocketLLMService[TAdapter]` for an arbitrary
|
||||
# TAdapter. The runtime call is fine — generics are erased.
|
||||
LLMService.__init__(self, **kwargs) # pyright: ignore[reportArgumentType]
|
||||
WebsocketService.__init__(self, reconnect_on_error=reconnect_on_error, **kwargs)
|
||||
self._register_event_handler("on_connection_error")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user