diff --git a/src/pipecat/services/anthropic/llm.py b/src/pipecat/services/anthropic/llm.py index 138aa6981..c028ec7aa 100644 --- a/src/pipecat/services/anthropic/llm.py +++ b/src/pipecat/services/anthropic/llm.py @@ -105,7 +105,7 @@ class AnthropicLLMSettings(LLMSettings): return instance -class AnthropicLLMService(LLMService): +class AnthropicLLMService(LLMService[AnthropicLLMAdapter]): """LLM service for Anthropic's Claude models. Provides inference capabilities with Claude models including support for @@ -293,7 +293,7 @@ class AnthropicLLMService(LLMService): effective_instruction = system_instruction or assert_given( self._settings.system_instruction ) - adapter: AnthropicLLMAdapter = self.get_llm_adapter() + adapter = self.get_llm_adapter() invocation_params = adapter.get_llm_invocation_params( context, enable_prompt_caching=assert_given(self._settings.enable_prompt_caching), @@ -328,7 +328,7 @@ class AnthropicLLMService(LLMService): return next((block.text for block in response.content if hasattr(block, "text")), None) def _get_llm_invocation_params(self, context: LLMContext) -> AnthropicLLMInvocationParams: - adapter: AnthropicLLMAdapter = self.get_llm_adapter() + adapter = self.get_llm_adapter() params: AnthropicLLMInvocationParams = adapter.get_llm_invocation_params( context, enable_prompt_caching=assert_given(self._settings.enable_prompt_caching), diff --git a/src/pipecat/services/aws/llm.py b/src/pipecat/services/aws/llm.py index 5420addf0..7266821a1 100644 --- a/src/pipecat/services/aws/llm.py +++ b/src/pipecat/services/aws/llm.py @@ -74,7 +74,7 @@ class AWSBedrockLLMSettings(LLMSettings): ) -class AWSBedrockLLMService(LLMService): +class AWSBedrockLLMService(LLMService[AWSBedrockLLMAdapter]): """AWS Bedrock Large Language Model service implementation. Provides inference capabilities for AWS Bedrock models including Amazon Nova @@ -282,7 +282,7 @@ class AWSBedrockLLMService(LLMService): effective_instruction = system_instruction or assert_given( self._settings.system_instruction ) - adapter: AWSBedrockLLMAdapter = self.get_llm_adapter() + adapter = self.get_llm_adapter() params: AWSBedrockLLMInvocationParams = adapter.get_llm_invocation_params( context, system_instruction=effective_instruction ) @@ -371,7 +371,7 @@ class AWSBedrockLLMService(LLMService): } def _get_llm_invocation_params(self, context: LLMContext) -> AWSBedrockLLMInvocationParams: - adapter: AWSBedrockLLMAdapter = self.get_llm_adapter() + adapter = self.get_llm_adapter() params: AWSBedrockLLMInvocationParams = adapter.get_llm_invocation_params( context, system_instruction=assert_given(self._settings.system_instruction) ) diff --git a/src/pipecat/services/aws/nova_sonic/llm.py b/src/pipecat/services/aws/nova_sonic/llm.py index d3d927f57..55a4dd813 100644 --- a/src/pipecat/services/aws/nova_sonic/llm.py +++ b/src/pipecat/services/aws/nova_sonic/llm.py @@ -235,7 +235,7 @@ class AWSNovaSonicLLMSettings(LLMSettings): endpointing_sensitivity: str | None | _NotGiven = field(default_factory=lambda: NOT_GIVEN) -class AWSNovaSonicLLMService(LLMService): +class AWSNovaSonicLLMService(LLMService[AWSNovaSonicLLMAdapter]): """AWS Nova Sonic speech-to-speech LLM service. Provides bidirectional audio streaming, real-time transcription, text generation, @@ -644,7 +644,7 @@ class AWSNovaSonicLLMService(LLMService): await self._process_completed_function_calls(send_new_results=False) # Read context - adapter: AWSNovaSonicLLMAdapter = self.get_llm_adapter() + adapter = self.get_llm_adapter() llm_connection_params = adapter.get_llm_invocation_params( self._context, system_instruction=assert_given(self._settings.system_instruction) ) @@ -1125,7 +1125,7 @@ class AWSNovaSonicLLMService(LLMService): """Return ``(system_instruction, tools)`` for the next session setup.""" if not self._context: return None, [] - adapter: AWSNovaSonicLLMAdapter = self.get_llm_adapter() + adapter = self.get_llm_adapter() llm_params = adapter.get_llm_invocation_params( self._context, system_instruction=self._settings.system_instruction ) diff --git a/src/pipecat/services/google/gemini_live/llm.py b/src/pipecat/services/google/gemini_live/llm.py index 7c0c22933..d6966cbc1 100644 --- a/src/pipecat/services/google/gemini_live/llm.py +++ b/src/pipecat/services/google/gemini_live/llm.py @@ -351,7 +351,7 @@ class GeminiLiveLLMSettings(LLMSettings): proactivity: ProactivityConfig | dict | _NotGiven = field(default_factory=lambda: NOT_GIVEN) -class GeminiLiveLLMService(LLMService): +class GeminiLiveLLMService(LLMService[GeminiLLMAdapter]): """Provides access to Google's Gemini Live API. This service enables real-time conversations with Gemini, supporting both @@ -778,7 +778,7 @@ class GeminiLiveLLMService(LLMService): # init-provided values). Note that the determination of "effective" # system instruction is delegated to the adapter, which still # chooses the init-provided value if there is one. - adapter: GeminiLLMAdapter = self.get_llm_adapter() + adapter = self.get_llm_adapter() params = adapter.get_llm_invocation_params( self._context, system_instruction=assert_given(self._system_instruction_from_init) ) @@ -840,7 +840,7 @@ class GeminiLiveLLMService(LLMService): async def _process_completed_function_calls(self, send_new_results: bool): # Check for set of completed function calls in the context - adapter: GeminiLLMAdapter = self.get_llm_adapter() + adapter = self.get_llm_adapter() messages = adapter.get_llm_invocation_params(self._context).get("messages", []) for message in messages: if message.parts: @@ -1027,7 +1027,7 @@ class GeminiLiveLLMService(LLMService): # Add system instruction and tools to configuration, if provided. # These settings from the context take precedence over the ones # provided at initialization time. - adapter: GeminiLLMAdapter = self.get_llm_adapter() + adapter = self.get_llm_adapter() system_instruction = None tools = None if self._context: @@ -1333,7 +1333,7 @@ class GeminiLiveLLMService(LLMService): self._run_llm_when_session_ready = True return - adapter: GeminiLLMAdapter = self.get_llm_adapter() + adapter = self.get_llm_adapter() messages = adapter.get_llm_invocation_params(self._context).get("messages", []) if not messages: # No messages to seed convo with, so we're ready for realtime input right away @@ -1392,7 +1392,7 @@ class GeminiLiveLLMService(LLMService): # Create a throwaway context just for the purpose of getting messages # in the right format context = LLMContext(messages=messages_list) - adapter: GeminiLLMAdapter = self.get_llm_adapter() + adapter = self.get_llm_adapter() messages = adapter.get_llm_invocation_params(context).get("messages", []) if not messages: diff --git a/src/pipecat/services/google/llm.py b/src/pipecat/services/google/llm.py index faba2868b..355291e1b 100644 --- a/src/pipecat/services/google/llm.py +++ b/src/pipecat/services/google/llm.py @@ -124,7 +124,7 @@ class GoogleLLMSettings(LLMSettings): return instance -class GoogleLLMService(LLMService): +class GoogleLLMService(LLMService[GeminiLLMAdapter]): """Google AI (Gemini) LLM service implementation. This class implements inference with Google's AI models, translating internally diff --git a/src/pipecat/services/inworld/realtime/llm.py b/src/pipecat/services/inworld/realtime/llm.py index 46b025006..c859c42be 100644 --- a/src/pipecat/services/inworld/realtime/llm.py +++ b/src/pipecat/services/inworld/realtime/llm.py @@ -189,7 +189,7 @@ _NON_FATAL_ERROR_CODES = { } -class InworldRealtimeLLMService(LLMService): +class InworldRealtimeLLMService(LLMService[InworldRealtimeLLMAdapter]): """Inworld Realtime LLM service for real-time audio and text communication. Implements the Inworld Realtime API with WebSocket communication for @@ -664,7 +664,7 @@ class InworldRealtimeLLMService(LLMService): async def _send_session_update(self): """Update session settings on the server.""" settings = assert_given(self._settings.session_properties) - adapter: InworldRealtimeLLMAdapter = self.get_llm_adapter() + adapter = self.get_llm_adapter() if self._context: llm_invocation_params = adapter.get_llm_invocation_params( @@ -963,7 +963,7 @@ class InworldRealtimeLLMService(LLMService): self._run_llm_when_api_session_ready = True return - adapter: InworldRealtimeLLMAdapter = self.get_llm_adapter() + adapter = self.get_llm_adapter() if self._llm_needs_conversation_setup: logger.debug( diff --git a/src/pipecat/services/llm_service.py b/src/pipecat/services/llm_service.py index 6d8caacab..9310ec78b 100644 --- a/src/pipecat/services/llm_service.py +++ b/src/pipecat/services/llm_service.py @@ -16,7 +16,10 @@ from collections.abc import Awaitable, Callable, Mapping, Sequence from dataclasses import dataclass from typing import ( Any, + Generic, Protocol, + TypeVar, + cast, ) from loguru import logger @@ -190,7 +193,10 @@ class FunctionCallRunnerItem: group_id: str | None = None -class LLMService(UserTurnCompletionLLMServiceMixin, AIService): +TAdapter = TypeVar("TAdapter", bound=BaseLLMAdapter) + + +class LLMService(UserTurnCompletionLLMServiceMixin, AIService, Generic[TAdapter]): """Base class for all LLM services. Handles function calling registration and execution with support for both @@ -222,6 +228,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): """ _settings: LLMSettings + _adapter: TAdapter # OpenAILLMAdapter is used as the default adapter since it aligns with most LLM implementations. # However, subclasses should override this with a more specific adapter when necessary. @@ -269,7 +276,12 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): self._filter_incomplete_user_turns: bool = False self._async_tool_cancellation_enabled: bool = False self._base_system_instruction: str | None = None - self._adapter = self.adapter_class() + # `adapter_class` is typed as `type[BaseLLMAdapter]` so subclasses + # don't need to spell out the generic parameter just to subclass + # (backward compatibility for 3rd-party providers outside this repo). + # Cast to TAdapter to keep `_adapter` and `get_llm_adapter()` precisely + # typed for callers that opt into `LLMService[XAdapter]`. + self._adapter = cast(TAdapter, self.adapter_class()) self._functions: dict[str | None, FunctionCallRegistryItem] = {} self._function_call_tasks: dict[asyncio.Task | None, FunctionCallRunnerItem] = {} self._sequential_runner_task: asyncio.Task | None = None @@ -280,7 +292,7 @@ class LLMService(UserTurnCompletionLLMServiceMixin, AIService): self._register_event_handler("on_function_calls_cancelled") self._register_event_handler("on_completion_timeout") - def get_llm_adapter(self) -> BaseLLMAdapter: + def get_llm_adapter(self) -> TAdapter: """Get the LLM adapter instance. Returns: @@ -1112,7 +1124,7 @@ class WebsocketReconnectedError(Exception): pass -class WebsocketLLMService(LLMService, WebsocketService): +class WebsocketLLMService(LLMService[TAdapter], WebsocketService, Generic[TAdapter]): """Base class for websocket-based LLM services. Each LLM inference is a discrete request/response exchange: send one diff --git a/src/pipecat/services/openai/base_llm.py b/src/pipecat/services/openai/base_llm.py index 8f494b193..144927a8c 100644 --- a/src/pipecat/services/openai/base_llm.py +++ b/src/pipecat/services/openai/base_llm.py @@ -26,7 +26,7 @@ from openai._types import NotGiven as OpenAINotGiven from openai.types.chat import ChatCompletionChunk from pydantic import BaseModel, Field -from pipecat.adapters.services.open_ai_adapter import OpenAILLMInvocationParams +from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter, OpenAILLMInvocationParams from pipecat.frames.frames import ( Frame, LLMContextFrame, @@ -71,7 +71,7 @@ class OpenAILLMSettings(LLMSettings): ) -class BaseOpenAILLMService(LLMService): +class BaseOpenAILLMService(LLMService[OpenAILLMAdapter]): """Base class for all services that use the AsyncOpenAI client. This service consumes LLMContextFrame frames, which contain a reference to diff --git a/src/pipecat/services/openai/realtime/llm.py b/src/pipecat/services/openai/realtime/llm.py index cf4f8943e..8519fddd8 100644 --- a/src/pipecat/services/openai/realtime/llm.py +++ b/src/pipecat/services/openai/realtime/llm.py @@ -194,7 +194,7 @@ class OpenAIRealtimeLLMSettings(LLMSettings): return instance -class OpenAIRealtimeLLMService(LLMService): +class OpenAIRealtimeLLMService(LLMService[OpenAIRealtimeLLMAdapter]): """OpenAI Realtime LLM service providing real-time audio and text communication. Implements the OpenAI Realtime API with WebSocket communication for low-latency @@ -657,7 +657,7 @@ class OpenAIRealtimeLLMService(LLMService): async def _send_session_update(self): settings = assert_given(self._settings.session_properties) - adapter: OpenAIRealtimeLLMAdapter = self.get_llm_adapter() + adapter = self.get_llm_adapter() if self._context: llm_invocation_params = adapter.get_llm_invocation_params( @@ -1002,7 +1002,7 @@ class OpenAIRealtimeLLMService(LLMService): self._run_llm_when_api_session_ready = True return - adapter: OpenAIRealtimeLLMAdapter = self.get_llm_adapter() + adapter = self.get_llm_adapter() # Configure the LLM for this session if needed if self._llm_needs_conversation_setup: diff --git a/src/pipecat/services/openai/responses/llm.py b/src/pipecat/services/openai/responses/llm.py index 6528303f9..b15e36f71 100644 --- a/src/pipecat/services/openai/responses/llm.py +++ b/src/pipecat/services/openai/responses/llm.py @@ -115,7 +115,7 @@ class OpenAIResponsesLLMSettings(LLMSettings): # --------------------------------------------------------------------------- -class _BaseOpenAIResponsesLLMService(LLMService): +class _BaseOpenAIResponsesLLMService(LLMService[OpenAIResponsesLLMAdapter]): """Shared base for HTTP and WebSocket OpenAI Responses API services. Contains settings, adapter reference, HTTP client creation, parameter @@ -294,7 +294,7 @@ class _BaseOpenAIResponsesLLMService(LLMService): Returns: The LLM's response as a string, or None if no response is generated. """ - adapter: OpenAIResponsesLLMAdapter = self.get_llm_adapter() + adapter = self.get_llm_adapter() effective_instruction = system_instruction or assert_given( self._settings.system_instruction ) @@ -353,7 +353,9 @@ class _BaseOpenAIResponsesLLMService(LLMService): # --------------------------------------------------------------------------- -class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService, WebsocketLLMService): +class OpenAIResponsesLLMService( + _BaseOpenAIResponsesLLMService, WebsocketLLMService[OpenAIResponsesLLMAdapter] +): """OpenAI Responses API LLM service using WebSocket transport. Maintains a persistent WebSocket connection to ``wss://api.openai.com/v1/responses`` @@ -747,7 +749,7 @@ class OpenAIResponsesLLMService(_BaseOpenAIResponsesLLMService, WebsocketLLMServ if self._needs_drain: await self._drain_cancelled_response() - adapter: OpenAIResponsesLLMAdapter = self.get_llm_adapter() + adapter = self.get_llm_adapter() logger.debug( f"{self}: Generating response from universal context " f"{adapter.get_messages_for_logging(context)}" @@ -987,7 +989,7 @@ class OpenAIResponsesHttpLLMService(_BaseOpenAIResponsesLLMService): @traced_llm async def _process_context(self, context: LLMContext): - adapter: OpenAIResponsesLLMAdapter = self.get_llm_adapter() + adapter = self.get_llm_adapter() logger.debug( f"{self}: Generating response from universal context " f"{adapter.get_messages_for_logging(context)}" diff --git a/src/pipecat/services/xai/realtime/llm.py b/src/pipecat/services/xai/realtime/llm.py index 448b5e339..55a7be4cd 100644 --- a/src/pipecat/services/xai/realtime/llm.py +++ b/src/pipecat/services/xai/realtime/llm.py @@ -179,7 +179,7 @@ class GrokRealtimeLLMSettings(LLMSettings): return instance -class GrokRealtimeLLMService(LLMService): +class GrokRealtimeLLMService(LLMService[GrokRealtimeLLMAdapter]): """Grok Realtime Voice Agent LLM service providing real-time audio and text communication. Implements the Grok Voice Agent API with WebSocket communication for low-latency @@ -596,7 +596,7 @@ class GrokRealtimeLLMService(LLMService): async def _send_session_update(self): """Update session settings on the server.""" settings = assert_given(self._settings.session_properties) - adapter: GrokRealtimeLLMAdapter = self.get_llm_adapter() + adapter = self.get_llm_adapter() if self._context: llm_invocation_params = adapter.get_llm_invocation_params( @@ -871,7 +871,7 @@ class GrokRealtimeLLMService(LLMService): self._run_llm_when_api_session_ready = True return - adapter: GrokRealtimeLLMAdapter = self.get_llm_adapter() + adapter = self.get_llm_adapter() if self._llm_needs_conversation_setup: logger.debug(