diff --git a/src/pipecat/processors/frameworks/rtvi.py b/src/pipecat/processors/frameworks/rtvi.py index 9bd08c489..343f0a217 100644 --- a/src/pipecat/processors/frameworks/rtvi.py +++ b/src/pipecat/processors/frameworks/rtvi.py @@ -7,8 +7,8 @@ import asyncio import dataclasses -from typing import List, Literal, Optional, Type -from pydantic import BaseModel, ValidationError +from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Type +from pydantic import PrivateAttr, BaseModel, ValidationError from pipecat.frames.frames import ( BotInterruptionFrame, @@ -33,62 +33,76 @@ from pipecat.pipeline.pipeline import Pipeline from pipecat.processors.aggregators.llm_response import ( LLMAssistantResponseAggregator, LLMUserResponseAggregator) from pipecat.processors.frame_processor import FrameDirection, FrameProcessor -from pipecat.services.ai_services import AIService from pipecat.services.cartesia import CartesiaTTSService from pipecat.services.openai import OpenAILLMService, OpenAILLMContext from pipecat.transports.base_transport import BaseTransport -DEFAULT_MESSAGES = [ - { - "role": "system", - "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.", - } -] - -DEFAULT_MODEL = "llama3-70b-8192" - -DEFAULT_VOICE = "79a125e8-cd45-4c13-8a67-188112f4dd22" +from loguru import logger -class RTVILLMConfig(BaseModel): - model: Optional[str] = None - messages: Optional[List[dict]] = None +class RTVIServiceOption(BaseModel): + name: str + handler: Optional[Callable[['RTVIProcessor', + 'RTVIServiceOptionConfig'], + Awaitable[None]]] = None -class RTVITTSConfig(BaseModel): - voice: Optional[str] = None +class RTVIService(BaseModel): + name: str + cls: Type[FrameProcessor] + options: List[RTVIServiceOption] + _options_dict: Dict[str, RTVIServiceOption] = PrivateAttr(default={}) + + def model_post_init(self, __context: Any) -> None: + self._options_dict = {} + for option in self.options: + self._options_dict[option.name] = option + return super().model_post_init(__context) + +# +# Client -> Pipecat messages. +# + + +class RTVIServiceOptionConfig(BaseModel): + name: str + value: Any + + +class RTVIServiceConfig(BaseModel): + service: str + options: List[RTVIServiceOptionConfig] class RTVIConfig(BaseModel): - llm: Optional[RTVILLMConfig] = None - tts: Optional[RTVITTSConfig] = None + config: List[RTVIServiceConfig] + _config_dict: Dict[str, RTVIServiceConfig] = PrivateAttr(default={}) + + def model_post_init(self, __context: Any) -> None: + self._config_dict = {} + for c in self.config: + self._config_dict[c.service] = c + return super().model_post_init(__context) -class RTVISetup(BaseModel): - config: Optional[RTVIConfig] = None - - -class RTVILLMMessageData(BaseModel): +class RTVILLMContextData(BaseModel): messages: List[dict] -class RTVITTSMessageData(BaseModel): +class RTVITTSSpeakData(BaseModel): text: str interrupt: Optional[bool] = False -class RTVIMessageData(BaseModel): - setup: Optional[RTVISetup] = None - config: Optional[RTVIConfig] = None - llm: Optional[RTVILLMMessageData] = None - tts: Optional[RTVITTSMessageData] = None - - class RTVIMessage(BaseModel): - label: Literal["rtvi"] = "rtvi" + label: Literal["rtvi-ai"] = "rtvi-ai" type: str id: str - data: Optional[RTVIMessageData] = None + data: Optional[Dict[str, Any]] = None + +# +# Pipecat -> Client responses and messages. +# class RTVIResponseData(BaseModel): @@ -97,7 +111,7 @@ class RTVIResponseData(BaseModel): class RTVIResponse(BaseModel): - label: Literal["rtvi"] = "rtvi" + label: Literal["rtvi-ai"] = "rtvi-ai" type: Literal["response"] = "response" id: str data: RTVIResponseData @@ -108,7 +122,7 @@ class RTVIErrorData(BaseModel): class RTVIError(BaseModel): - label: Literal["rtvi"] = "rtvi" + label: Literal["rtvi-ai"] = "rtvi-ai" type: Literal["error"] = "error" data: RTVIErrorData @@ -118,7 +132,7 @@ class RTVILLMContextMessageData(BaseModel): class RTVILLMContextMessage(BaseModel): - label: Literal["rtvi"] = "rtvi" + label: Literal["rtvi-ai"] = "rtvi-ai" type: Literal["llm-context"] = "llm-context" data: RTVILLMContextMessageData @@ -128,13 +142,13 @@ class RTVITTSTextMessageData(BaseModel): class RTVITTSTextMessage(BaseModel): - label: Literal["rtvi"] = "rtvi" + label: Literal["rtvi-ai"] = "rtvi-ai" type: Literal["tts-text"] = "tts-text" data: RTVITTSTextMessageData class RTVIBotReady(BaseModel): - label: Literal["rtvi"] = "rtvi" + label: Literal["rtvi-ai"] = "rtvi-ai" type: Literal["bot-ready"] = "bot-ready" @@ -146,23 +160,23 @@ class RTVITranscriptionMessageData(BaseModel): class RTVITranscriptionMessage(BaseModel): - label: Literal["rtvi"] = "rtvi" + label: Literal["rtvi-ai"] = "rtvi-ai" type: Literal["user-transcription"] = "user-transcription" data: RTVITranscriptionMessageData class RTVIUserStartedSpeakingMessage(BaseModel): - label: Literal["rtvi"] = "rtvi" + label: Literal["rtvi-ai"] = "rtvi-ai" type: Literal["user-started-speaking"] = "user-started-speaking" class RTVIUserStoppedSpeakingMessage(BaseModel): - label: Literal["rtvi"] = "rtvi" + label: Literal["rtvi-ai"] = "rtvi-ai" type: Literal["user-stopped-speaking"] = "user-stopped-speaking" class RTVIJSONCompletion(BaseModel): - label: Literal["rtvi"] = "rtvi" + label: Literal["rtvi-ai"] = "rtvi-ai" type: Literal["json-completion"] = "json-completion" data: str @@ -265,29 +279,45 @@ class RTVITTSTextProcessor(FrameProcessor): await self.push_frame(TransportMessageFrame(message=message.model_dump(exclude_none=True))) +async def handle_llm_model_update(rtvi: 'RTVIProcessor', option: RTVIServiceOptionConfig): + frame = LLMModelUpdateFrame(option.value) + await rtvi.push_frame(frame) + + +async def handle_llm_messages_update(rtvi: 'RTVIProcessor', option: RTVIServiceOptionConfig): + frame = LLMMessagesUpdateFrame(option.value) + await rtvi.push_frame(frame) + + +async def handle_tts_voice_update(rtvi: 'RTVIProcessor', option: RTVIServiceOptionConfig): + frame = TTSVoiceUpdateFrame(option.value) + await rtvi.push_frame(frame) + +DEFAULT_LLM_SERVICE = RTVIService( + name="llm", + cls=OpenAILLMService, + options=[ + RTVIServiceOption(name="model", handler=handle_llm_model_update), + RTVIServiceOption(name="messages", handler=handle_llm_messages_update) + ]) + +DEFAULT_TTS_SERVICE = RTVIService( + name="tts", + cls=CartesiaTTSService, + options=[ + RTVIServiceOption(name="voice_id", handler=handle_tts_voice_update), + ]) + + class RTVIProcessor(FrameProcessor): - def __init__( - self, - *, - transport: BaseTransport, - setup: RTVISetup | None = None, - llm_api_key: str = "", - llm_base_url: str = "https://api.groq.com/openai/v1", - tts_api_key: str = "", - llm_cls: Type[AIService] = OpenAILLMService, - tts_cls: Type[AIService] = CartesiaTTSService): + def __init__(self, *, transport: BaseTransport): super().__init__() self._transport = transport - self._setup = setup - self._llm_api_key = llm_api_key - self._llm_base_url = llm_base_url - self._tts_api_key = tts_api_key - self._llm_cls = llm_cls - self._tts_cls = tts_cls + self._config: RTVIConfig | None = None + self._ctor_args: Dict[str, Any] = {} + self._start_frame: Frame | None = None - self._llm: FrameProcessor | None = None - self._tts: FrameProcessor | None = None self._pipeline: FrameProcessor | None = None self._first_participant_joined: bool = False @@ -297,9 +327,24 @@ class RTVIProcessor(FrameProcessor): "on_first_participant_joined", self._on_first_participant_joined) + # Register default services. + self._registered_services: Dict[str, RTVIService] = {} + self.register_service(DEFAULT_LLM_SERVICE) + self.register_service(DEFAULT_TTS_SERVICE) + self._frame_handler_task = self.get_event_loop().create_task(self._frame_handler()) self._frame_queue = asyncio.Queue() + def register_service(self, service: RTVIService): + self._registered_services[service.name] = service + + def setup_on_start(self, config: RTVIConfig | None, ctor_args: Dict[str, Any]): + self._config = config + self._ctor_args = ctor_args + + async def update_config(self, config: RTVIConfig): + await self._handle_config_update(config) + async def process_frame(self, frame: Frame, direction: FrameDirection): await super().process_frame(frame, direction) @@ -309,11 +354,10 @@ class RTVIProcessor(FrameProcessor): await self._frame_queue.put((frame, direction)) if isinstance(frame, StartFrame): - self._start_frame = frame try: - await self._handle_setup(self._setup) + await self._handle_pipeline_setup(frame, self._config) except Exception as e: - await self._send_error(f"unable to setup RTVI: {e}") + await self._send_error(f"unable to setup RTVI pipeline: {e}") async def cleanup(self): self._frame_handler_task.cancel() @@ -379,90 +423,81 @@ class RTVIProcessor(FrameProcessor): try: message = RTVIMessage.model_validate(frame.message) except ValidationError as e: - await self._send_error(f"invalid message: {e}") + await self._send_error(f"Invalid incoming message: {e}") + logger.warning(f"Invalid incoming message: {e}") return try: success = True error = None match message.type: - case "setup": - setup = None - if message.data: - setup = message.data.setup - await self._handle_setup(message.id, setup) case "config-update": - await self._handle_config_update(message.data.config) + await self._handle_config_update(RTVIConfig.model_validate(message.data)) case "llm-get-context": await self._handle_llm_get_context() case "llm-append-context": - await self._handle_llm_append_context(message.data.llm) + await self._handle_llm_append_context(RTVILLMContextData.model_validate(message.data)) case "llm-update-context": - await self._handle_llm_update_context(message.data.llm) + await self._handle_llm_update_context(RTVILLMContextData.model_validate(message.data)) case "tts-speak": - await self._handle_tts_speak(message.data.tts) + await self._handle_tts_speak(RTVITTSSpeakData.model_validate(message.data)) case "tts-interrupt": await self._handle_tts_interrupt() case _: success = False - error = f"unsupported type {message.type}" + error = f"Unsupported type {message.type}" await self._send_response(message.id, success, error) except ValidationError as e: - await self._send_response(message.id, False, f"invalid message: {e}") + await self._send_response(message.id, False, f"Invalid incoming message: {e}") + logger.warning(f"Invalid incoming message: {e}") except Exception as e: - await self._send_response(message.id, False, f"{e}") + await self._send_response(message.id, False, f"Exception processing message: {e}") + logger.warning(f"Exception processing message: {e}") - async def _handle_setup(self, setup: RTVISetup | None): - model = DEFAULT_MODEL - if setup and setup.config and setup.config.llm and setup.config.llm.model: - model = setup.config.llm.model + async def _handle_pipeline_setup(self, start_frame: StartFrame, config: RTVIConfig | None): + # TODO(aleix): We shouldn't need to save this in `self._tma_in`. + self._tma_in = LLMUserResponseAggregator() + tma_out = LLMAssistantResponseAggregator() - messages = DEFAULT_MESSAGES - if setup and setup.config and setup.config.llm and setup.config.llm.messages: - messages = setup.config.llm.messages + llm_cls = self._registered_services["llm"].cls + llm_args = self._ctor_args["llm"] + llm = llm_cls(**llm_args) - voice = DEFAULT_VOICE - if setup and setup.config and setup.config.tts and setup.config.tts.voice: - voice = setup.config.tts.voice - - self._tma_in = LLMUserResponseAggregator(messages) - self._tma_out = LLMAssistantResponseAggregator(messages) - - self._llm = self._llm_cls( - name="LLM", - base_url=self._llm_base_url, - api_key=self._llm_api_key, - model=model) - - self._tts = self._tts_cls(name="TTS", api_key=self._tts_api_key, voice_id=voice) + tts_cls = self._registered_services["tts"].cls + tts_args = self._ctor_args["tts"] + tts = tts_cls(**tts_args) # TODO-CB: Eventually we'll need to switch the context aggregators to use the # OpenAI context frames instead of message frames - context = OpenAILLMContext(messages=messages) - self._fc = FunctionCaller(context) + context = OpenAILLMContext() + fc = FunctionCaller(context) - self._tts_text = RTVITTSTextProcessor() + tts_text = RTVITTSTextProcessor() pipeline = Pipeline([ self._tma_in, - self._llm, - self._fc, - self._tts, - self._tts_text, - self._tma_out, + llm, + fc, + tts, + tts_text, + tma_out, self._transport.output(), ]) parent = self.get_parent() - if parent and self._start_frame: + if parent: parent.link(pipeline) # We need to initialize the new pipeline with the same settings # as the initial one. - start_frame = dataclasses.replace(self._start_frame) + start_frame = dataclasses.replace(start_frame) await self.push_frame(start_frame) + # Configure the pipeline + if config: + await self._handle_config_update(config) + # Send new initial metrics with the new processors processors = parent.processors_with_metrics() processors.extend(pipeline.processors_with_metrics()) @@ -474,17 +509,16 @@ class RTVIProcessor(FrameProcessor): await self._maybe_send_bot_ready() - async def _handle_config_update(self, config: RTVIConfig): - # Change voice before LLM updates, so we can hear the new vocie. - if config.tts and config.tts.voice: - frame = TTSVoiceUpdateFrame(config.tts.voice) - await self.push_frame(frame) - if config.llm and config.llm.model: - frame = LLMModelUpdateFrame(config.llm.model) - await self.push_frame(frame) - if config.llm and config.llm.messages: - frame = LLMMessagesUpdateFrame(config.llm.messages) - await self.push_frame(frame) + async def _handle_config_service(self, config: RTVIServiceConfig): + service = self._registered_services[config.service] + for option in config.options: + handler = service._options_dict[option.name].handler + if handler: + await handler(self, option) + + async def _handle_config_update(self, data: RTVIConfig): + for config in data.config: + await self._handle_config_service(config) async def _handle_llm_get_context(self): data = RTVILLMContextMessageData(messages=self._tma_in.messages) @@ -492,17 +526,17 @@ class RTVIProcessor(FrameProcessor): frame = TransportMessageFrame(message=message.model_dump(exclude_none=True)) await self.push_frame(frame) - async def _handle_llm_append_context(self, data: RTVILLMMessageData): + async def _handle_llm_append_context(self, data: RTVILLMContextData): if data and data.messages: frame = LLMMessagesAppendFrame(data.messages) await self.push_frame(frame) - async def _handle_llm_update_context(self, data: RTVILLMMessageData): + async def _handle_llm_update_context(self, data: RTVILLMContextData): if data and data.messages: frame = LLMMessagesUpdateFrame(data.messages) await self.push_frame(frame) - async def _handle_tts_speak(self, data: RTVITTSMessageData): + async def _handle_tts_speak(self, data: RTVITTSSpeakData): if data and data.text: if data.interrupt: await self._handle_tts_interrupt() @@ -539,7 +573,7 @@ class RTVIProcessor(FrameProcessor): self._pipeline = pipeline parent = self.get_parent() - if parent and self._start_frame: + if parent: parent.link(pipeline) message = RTVIResponse(id=id, data=RTVIResponseData(success=success, error=error))