refactor: make LLMService generic over its adapter type

Previously, `LLMService.get_llm_adapter()` returned `BaseLLMAdapter`,
which forced every caller that wanted the precise adapter type to
write `adapter: SomeAdapter = self.get_llm_adapter()` and accept
pyright's complaint that the assignment doesn't match the declared
type. That pattern existed in 20 places across the LLM services.

Make `LLMService` generic over its adapter type — `LLMService(...,
Generic[TAdapter])` with `TAdapter = TypeVar("TAdapter",
bound=BaseLLMAdapter)` — so subclasses opt in via
`LLMService[XAdapter]` and callers get the precise type back from
`get_llm_adapter()` automatically.

Backward-compatible for third-party providers: code that says
`class MyService(LLMService):` (no bracket) still type-checks, with
TAdapter resolving to BaseLLMAdapter from the bound — identical to
the pre-refactor behavior. The `adapter_class` attribute keeps its
loose `type[BaseLLMAdapter] = OpenAILLMAdapter` typing so the default
remains usable; one localized cast in `__init__` bridges the loose
class attr to the precise instance attr.

In-tree subclasses opted in:

- AnthropicLLMService -> LLMService[AnthropicLLMAdapter]
- AWSBedrockLLMService -> LLMService[AWSBedrockLLMAdapter]
- AWSNovaSonicLLMService -> LLMService[AWSNovaSonicLLMAdapter]
- BaseOpenAILLMService -> LLMService[OpenAILLMAdapter] (propagates to
  ~15 OpenAI-compatible providers like Cerebras, Groq, Together)
- GeminiLiveLLMService -> LLMService[GeminiLLMAdapter]
- GoogleLLMService -> LLMService[GeminiLLMAdapter]
- GrokRealtimeLLMService -> LLMService[GrokRealtimeLLMAdapter]
- InworldRealtimeLLMService -> LLMService[InworldRealtimeLLMAdapter]
- OpenAIRealtimeLLMService -> LLMService[OpenAIRealtimeLLMAdapter]
- _BaseOpenAIResponsesLLMService -> LLMService[OpenAIResponsesLLMAdapter]
- WebsocketLLMService -> LLMService[TAdapter] (stays generic itself, so
  the multiple-inheritance case, OpenAIResponsesLLMService, can keep
  both bases agreeing on TAdapter)

All 20 redundant `adapter: SomeAdapter = self.get_llm_adapter()`
annotations are now plain `adapter = self.get_llm_adapter()`.
This commit is contained in:
Paul Kompfner
2026-04-28 09:38:01 -04:00
parent d23bdaaacd
commit 49068ff557
11 changed files with 50 additions and 36 deletions

View File

@@ -105,7 +105,7 @@ class AnthropicLLMSettings(LLMSettings):
return instance
class AnthropicLLMService(LLMService):
class AnthropicLLMService(LLMService[AnthropicLLMAdapter]):
"""LLM service for Anthropic's Claude models.
Provides inference capabilities with Claude models including support for
@@ -293,7 +293,7 @@ class AnthropicLLMService(LLMService):
effective_instruction = system_instruction or assert_given(
self._settings.system_instruction
)
adapter: AnthropicLLMAdapter = self.get_llm_adapter()
adapter = self.get_llm_adapter()
invocation_params = adapter.get_llm_invocation_params(
context,
enable_prompt_caching=assert_given(self._settings.enable_prompt_caching),
@@ -328,7 +328,7 @@ class AnthropicLLMService(LLMService):
return next((block.text for block in response.content if hasattr(block, "text")), None)
def _get_llm_invocation_params(self, context: LLMContext) -> AnthropicLLMInvocationParams:
adapter: AnthropicLLMAdapter = self.get_llm_adapter()
adapter = self.get_llm_adapter()
params: AnthropicLLMInvocationParams = adapter.get_llm_invocation_params(
context,
enable_prompt_caching=assert_given(self._settings.enable_prompt_caching),

View File

@@ -74,7 +74,7 @@ class AWSBedrockLLMSettings(LLMSettings):
)
class AWSBedrockLLMService(LLMService):
class AWSBedrockLLMService(LLMService[AWSBedrockLLMAdapter]):
"""AWS Bedrock Large Language Model service implementation.
Provides inference capabilities for AWS Bedrock models including Amazon Nova
@@ -282,7 +282,7 @@ class AWSBedrockLLMService(LLMService):
effective_instruction = system_instruction or assert_given(
self._settings.system_instruction
)
adapter: AWSBedrockLLMAdapter = self.get_llm_adapter()
adapter = self.get_llm_adapter()
params: AWSBedrockLLMInvocationParams = adapter.get_llm_invocation_params(
context, system_instruction=effective_instruction
)
@@ -371,7 +371,7 @@ class AWSBedrockLLMService(LLMService):
}
def _get_llm_invocation_params(self, context: LLMContext) -> AWSBedrockLLMInvocationParams:
adapter: AWSBedrockLLMAdapter = self.get_llm_adapter()
adapter = self.get_llm_adapter()
params: AWSBedrockLLMInvocationParams = adapter.get_llm_invocation_params(
context, system_instruction=assert_given(self._settings.system_instruction)
)

View File

@@ -235,7 +235,7 @@ class AWSNovaSonicLLMSettings(LLMSettings):
endpointing_sensitivity: str | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
class AWSNovaSonicLLMService(LLMService):
class AWSNovaSonicLLMService(LLMService[AWSNovaSonicLLMAdapter]):
"""AWS Nova Sonic speech-to-speech LLM service.
Provides bidirectional audio streaming, real-time transcription, text generation,
@@ -644,7 +644,7 @@ class AWSNovaSonicLLMService(LLMService):
await self._process_completed_function_calls(send_new_results=False)
# Read context
adapter: AWSNovaSonicLLMAdapter = self.get_llm_adapter()
adapter = self.get_llm_adapter()
llm_connection_params = adapter.get_llm_invocation_params(
self._context, system_instruction=assert_given(self._settings.system_instruction)
)
@@ -1125,7 +1125,7 @@ class AWSNovaSonicLLMService(LLMService):
"""Return ``(system_instruction, tools)`` for the next session setup."""
if not self._context:
return None, []
adapter: AWSNovaSonicLLMAdapter = self.get_llm_adapter()
adapter = self.get_llm_adapter()
llm_params = adapter.get_llm_invocation_params(
self._context, system_instruction=self._settings.system_instruction
)

View File

@@ -351,7 +351,7 @@ class GeminiLiveLLMSettings(LLMSettings):
proactivity: ProactivityConfig | dict | _NotGiven = field(default_factory=lambda: NOT_GIVEN)
class GeminiLiveLLMService(LLMService):
class GeminiLiveLLMService(LLMService[GeminiLLMAdapter]):
"""Provides access to Google's Gemini Live API.
This service enables real-time conversations with Gemini, supporting both
@@ -778,7 +778,7 @@ class GeminiLiveLLMService(LLMService):
# init-provided values). Note that the determination of "effective"
# system instruction is delegated to the adapter, which still
# chooses the init-provided value if there is one.
adapter: GeminiLLMAdapter = self.get_llm_adapter()
adapter = self.get_llm_adapter()
params = adapter.get_llm_invocation_params(
self._context, system_instruction=assert_given(self._system_instruction_from_init)
)
@@ -840,7 +840,7 @@ class GeminiLiveLLMService(LLMService):
async def _process_completed_function_calls(self, send_new_results: bool):
# Check for set of completed function calls in the context
adapter: GeminiLLMAdapter = self.get_llm_adapter()
adapter = self.get_llm_adapter()
messages = adapter.get_llm_invocation_params(self._context).get("messages", [])
for message in messages:
if message.parts:
@@ -1027,7 +1027,7 @@ class GeminiLiveLLMService(LLMService):
# Add system instruction and tools to configuration, if provided.
# These settings from the context take precedence over the ones
# provided at initialization time.
adapter: GeminiLLMAdapter = self.get_llm_adapter()
adapter = self.get_llm_adapter()
system_instruction = None
tools = None
if self._context:
@@ -1333,7 +1333,7 @@ class GeminiLiveLLMService(LLMService):
self._run_llm_when_session_ready = True
return
adapter: GeminiLLMAdapter = self.get_llm_adapter()
adapter = self.get_llm_adapter()
messages = adapter.get_llm_invocation_params(self._context).get("messages", [])
if not messages:
# No messages to seed convo with, so we're ready for realtime input right away
@@ -1392,7 +1392,7 @@ class GeminiLiveLLMService(LLMService):
# Create a throwaway context just for the purpose of getting messages
# in the right format
context = LLMContext(messages=messages_list)
adapter: GeminiLLMAdapter = self.get_llm_adapter()
adapter = self.get_llm_adapter()
messages = adapter.get_llm_invocation_params(context).get("messages", [])
if not messages:

View File

@@ -124,7 +124,7 @@ class GoogleLLMSettings(LLMSettings):
return instance
class GoogleLLMService(LLMService):
class GoogleLLMService(LLMService[GeminiLLMAdapter]):
"""Google AI (Gemini) LLM service implementation.
This class implements inference with Google's AI models, translating internally

View File

@@ -189,7 +189,7 @@ _NON_FATAL_ERROR_CODES = {
}
class InworldRealtimeLLMService(LLMService):
class InworldRealtimeLLMService(LLMService[InworldRealtimeLLMAdapter]):
"""Inworld Realtime LLM service for real-time audio and text communication.
Implements the Inworld Realtime API with WebSocket communication for
@@ -664,7 +664,7 @@ class InworldRealtimeLLMService(LLMService):
async def _send_session_update(self):
"""Update session settings on the server."""
settings = assert_given(self._settings.session_properties)
adapter: InworldRealtimeLLMAdapter = self.get_llm_adapter()
adapter = self.get_llm_adapter()
if self._context:
llm_invocation_params = adapter.get_llm_invocation_params(
@@ -963,7 +963,7 @@ class InworldRealtimeLLMService(LLMService):
self._run_llm_when_api_session_ready = True
return
adapter: InworldRealtimeLLMAdapter = self.get_llm_adapter()
adapter = self.get_llm_adapter()
if self._llm_needs_conversation_setup:
logger.debug(

View File

@@ -16,7 +16,10 @@ from collections.abc import Awaitable, Callable, Mapping, Sequence
from dataclasses import dataclass
from typing import (
Any,
Generic,
Protocol,
TypeVar,
cast,
)
from loguru import logger
@@ -190,7 +193,10 @@ class FunctionCallRunnerItem:
group_id: str | None = None
class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
TAdapter = TypeVar("TAdapter", bound=BaseLLMAdapter)
class LLMService(UserTurnCompletionLLMServiceMixin, AIService, Generic[TAdapter]):
"""Base class for all LLM services.
Handles function calling registration and execution with support for both
@@ -222,6 +228,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
"""
_settings: LLMSettings
_adapter: TAdapter
# OpenAILLMAdapter is used as the default adapter since it aligns with most LLM implementations.
# However, subclasses should override this with a more specific adapter when necessary.
@@ -269,7 +276,12 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
self._filter_incomplete_user_turns: bool = False
self._async_tool_cancellation_enabled: bool = False
self._base_system_instruction: str | None = None
self._adapter = self.adapter_class()
# `adapter_class` is typed as `type[BaseLLMAdapter]` so subclasses
# don't need to spell out the generic parameter just to subclass
# (backward compatibility for 3rd-party providers outside this repo).
# Cast to TAdapter to keep `_adapter` and `get_llm_adapter()` precisely
# typed for callers that opt into `LLMService[XAdapter]`.
self._adapter = cast(TAdapter, self.adapter_class())
self._functions: dict[str | None, FunctionCallRegistryItem] = {}
self._function_call_tasks: dict[asyncio.Task | None, FunctionCallRunnerItem] = {}
self._sequential_runner_task: asyncio.Task | None = None
@@ -280,7 +292,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService):
self._register_event_handler("on_function_calls_cancelled")
self._register_event_handler("on_completion_timeout")
def get_llm_adapter(self) -> BaseLLMAdapter:
def get_llm_adapter(self) -> TAdapter:
"""Get the LLM adapter instance.
Returns:
@@ -1112,7 +1124,7 @@ class WebsocketReconnectedError(Exception):
pass
class WebsocketLLMService(LLMService, WebsocketService):
class WebsocketLLMService(LLMService[TAdapter], WebsocketService, Generic[TAdapter]):
"""Base class for websocket-based LLM services.
Each LLM inference is a discrete request/response exchange: send one

View File

@@ -26,7 +26,7 @@ from openai._types import NotGiven as OpenAINotGiven
from openai.types.chat import ChatCompletionChunk
from pydantic import BaseModel, Field
from pipecat.adapters.services.open_ai_adapter import OpenAILLMInvocationParams
from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter, OpenAILLMInvocationParams
from pipecat.frames.frames import (
Frame,
LLMContextFrame,
@@ -71,7 +71,7 @@ class OpenAILLMSettings(LLMSettings):
)
class BaseOpenAILLMService(LLMService):
class BaseOpenAILLMService(LLMService[OpenAILLMAdapter]):
"""Base class for all services that use the AsyncOpenAI client.
This service consumes LLMContextFrame frames, which contain a reference to

View File

@@ -194,7 +194,7 @@ class OpenAIRealtimeLLMSettings(LLMSettings):
return instance
class OpenAIRealtimeLLMService(LLMService):
class OpenAIRealtimeLLMService(LLMService[OpenAIRealtimeLLMAdapter]):
"""OpenAI Realtime LLM service providing real-time audio and text communication.
Implements the OpenAI Realtime API with WebSocket communication for low-latency
@@ -657,7 +657,7 @@ class OpenAIRealtimeLLMService(LLMService):
async def _send_session_update(self):
settings = assert_given(self._settings.session_properties)
adapter: OpenAIRealtimeLLMAdapter = self.get_llm_adapter()
adapter = self.get_llm_adapter()
if self._context:
llm_invocation_params = adapter.get_llm_invocation_params(
@@ -1002,7 +1002,7 @@ class OpenAIRealtimeLLMService(LLMService):
self._run_llm_when_api_session_ready = True
return
adapter: OpenAIRealtimeLLMAdapter = self.get_llm_adapter()
adapter = self.get_llm_adapter()
# Configure the LLM for this session if needed
if self._llm_needs_conversation_setup:

View File

@@ -115,7 +115,7 @@ class OpenAIResponsesLLMSettings(LLMSettings):
# ---------------------------------------------------------------------------
class _BaseOpenAIResponsesLLMService(LLMService):
class _BaseOpenAIResponsesLLMService(LLMService[OpenAIResponsesLLMAdapter]):
"""Shared base for HTTP and WebSocket OpenAI Responses API services.
Contains settings, adapter reference, HTTP client creation, parameter
@@ -294,7 +294,7 @@ class _BaseOpenAIResponsesLLMService(LLMService):
Returns:
The LLM's response as a string, or None if no response is generated.
"""
adapter: OpenAIResponsesLLMAdapter = self.get_llm_adapter()
adapter = self.get_llm_adapter()
effective_instruction = system_instruction or assert_given(
self._settings.system_instruction
)
@@ -353,7 +353,9 @@ class _BaseOpenAIResponsesLLMService(LLMService):
# ---------------------------------------------------------------------------
class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService, WebsocketLLMService):
class OpenAIResponsesLLMService(
_BaseOpenAIResponsesLLMService, WebsocketLLMService[OpenAIResponsesLLMAdapter]
):
"""OpenAI Responses API LLM service using WebSocket transport.
Maintains a persistent WebSocket connection to ``wss://api.openai.com/v1/responses``
@@ -747,7 +749,7 @@ class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService, WebsocketLLMServ
if self._needs_drain:
await self._drain_cancelled_response()
adapter: OpenAIResponsesLLMAdapter = self.get_llm_adapter()
adapter = self.get_llm_adapter()
logger.debug(
f"{self}: Generating response from universal context "
f"{adapter.get_messages_for_logging(context)}"
@@ -987,7 +989,7 @@ class OpenAIResponsesHttpLLMService(_BaseOpenAIResponsesLLMService):
@traced_llm
async def _process_context(self, context: LLMContext):
adapter: OpenAIResponsesLLMAdapter = self.get_llm_adapter()
adapter = self.get_llm_adapter()
logger.debug(
f"{self}: Generating response from universal context "
f"{adapter.get_messages_for_logging(context)}"

View File

@@ -179,7 +179,7 @@ class GrokRealtimeLLMSettings(LLMSettings):
return instance
class GrokRealtimeLLMService(LLMService):
class GrokRealtimeLLMService(LLMService[GrokRealtimeLLMAdapter]):
"""Grok Realtime Voice Agent LLM service providing real-time audio and text communication.
Implements the Grok Voice Agent API with WebSocket communication for low-latency
@@ -596,7 +596,7 @@ class GrokRealtimeLLMService(LLMService):
async def _send_session_update(self):
"""Update session settings on the server."""
settings = assert_given(self._settings.session_properties)
adapter: GrokRealtimeLLMAdapter = self.get_llm_adapter()
adapter = self.get_llm_adapter()
if self._context:
llm_invocation_params = adapter.get_llm_invocation_params(
@@ -871,7 +871,7 @@ class GrokRealtimeLLMService(LLMService):
self._run_llm_when_api_session_ready = True
return
adapter: GrokRealtimeLLMAdapter = self.get_llm_adapter()
adapter = self.get_llm_adapter()
if self._llm_needs_conversation_setup:
logger.debug(