refactor: give TAdapter a default to restore precise typing for unparameterized LLMService subclasses

After making LLMService generic, an unparameterized subclass
(`class MyService(LLMService):` with no bracket — the third-party
provider pattern) saw `get_llm_adapter()` return `Unknown` rather
than `BaseLLMAdapter` as it did before the refactor.

Add `default=BaseLLMAdapter` (PEP 696) on the TypeVar — via
`typing_extensions.TypeVar` so older Python targets keep working —
so unparameterized callers get `LLMService[BaseLLMAdapter]` and
`get_llm_adapter()` returns `BaseLLMAdapter`, matching the
pre-refactor type precision.

Two internal fallouts of having a default (where the default makes
unannotated `LLMService` resolve invariantly to
`LLMService[BaseLLMAdapter]`):

- `FunctionCallParams.llm` is now `LLMService[Any]` so concrete
  parameterizations like `LLMService[OpenAILLMAdapter]` can be
  passed where the field is set.
- The explicit `LLMService.__init__(self, **kwargs)` in
  `WebsocketLLMService.__init__` gets a `pyright: ignore[reportArgumentType]`
  comment — pyright's invariance handling can't see through the
  multi-inheritance + generic + default combination, but the
  runtime call is correct (generics are erased).
This commit is contained in:
Paul Kompfner
2026-04-28 10:21:25 -04:00
parent c4f5f1ebbb
commit 1cd73b1ef8

View File

@@ -18,11 +18,11 @@ from typing import (
Any,
Generic,
Protocol,
TypeVar,
cast,
)
from loguru import logger
from typing_extensions import TypeVar
from websockets.exceptions import ConnectionClosed
from websockets.protocol import State
@@ -119,7 +119,12 @@ class FunctionCallParams:
function_name: str
tool_call_id: str
arguments: Mapping[str, Any]
llm: LLMService
# `LLMService[Any]` so any concrete subclass (regardless of how — or
# whether — it parameterizes the adapter type) can be assigned here.
# Plain `LLMService` would invoke the TypeVar default and pyright would
# treat it invariantly, rejecting `LLMService[XAdapter]` at the call
# sites that build FunctionCallParams.
llm: LLMService[Any]
context: LLMContext
result_callback: FunctionCallResultCallback
app_resources: Any = None
@@ -193,7 +198,11 @@ class FunctionCallRunnerItem:
group_id: str | None = None
TAdapter = TypeVar("TAdapter", bound=BaseLLMAdapter)
# `default=BaseLLMAdapter` (PEP 696) so that unparameterized subclasses
# (e.g. third-party `class MyService(LLMService):` with no bracket) get
# `TAdapter = BaseLLMAdapter` instead of `Unknown` at type-check time —
# matching the pre-generic behavior of `get_llm_adapter()`.
TAdapter = TypeVar("TAdapter", bound=BaseLLMAdapter, default=BaseLLMAdapter)
class LLMService(UserTurnCompletionLLMServiceMixin, AIService, Generic[TAdapter]):
@@ -1172,7 +1181,11 @@ class WebsocketLLMService(LLMService[TAdapter], WebsocketService, Generic[TAdapt
reconnect_on_error: Whether to automatically reconnect on websocket errors.
**kwargs: Additional arguments passed to parent classes.
"""
LLMService.__init__(self, **kwargs)
# pyright stumbles here because the TypeVar default makes
# `LLMService` resolve to `LLMService[BaseLLMAdapter]` invariantly,
# while `self` is `WebsocketLLMService[TAdapter]` for an arbitrary
# TAdapter. The runtime call is fine — generics are erased.
LLMService.__init__(self, **kwargs) # pyright: ignore[reportArgumentType]
WebsocketService.__init__(self, reconnect_on_error=reconnect_on_error, **kwargs)
self._register_event_handler("on_connection_error")