processors(rtvi): refactor to allow future custom pipelines
This commit is contained in:
@@ -7,8 +7,8 @@
|
||||
import asyncio
|
||||
import dataclasses
|
||||
|
||||
from typing import List, Literal, Optional, Type
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Type
|
||||
from pydantic import PrivateAttr, BaseModel, ValidationError
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
@@ -33,62 +33,76 @@ from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantResponseAggregator, LLMUserResponseAggregator)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.services.ai_services import AIService
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.openai import OpenAILLMService, OpenAILLMContext
|
||||
from pipecat.transports.base_transport import BaseTransport
|
||||
|
||||
DEFAULT_MESSAGES = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
}
|
||||
]
|
||||
|
||||
DEFAULT_MODEL = "llama3-70b-8192"
|
||||
|
||||
DEFAULT_VOICE = "79a125e8-cd45-4c13-8a67-188112f4dd22"
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class RTVILLMConfig(BaseModel):
|
||||
model: Optional[str] = None
|
||||
messages: Optional[List[dict]] = None
|
||||
class RTVIServiceOption(BaseModel):
    """A single named option that an RTVI service exposes for configuration.

    The optional handler is a coroutine function; when present it is awaited
    with the RTVI processor and the option configuration sent by the client.
    """

    # Option identifier; services index their options by this name.
    name: str
    # Coroutine invoked as handler(rtvi, option_config); None means no handler.
    handler: Optional[
        Callable[['RTVIProcessor', 'RTVIServiceOptionConfig'], Awaitable[None]]
    ] = None
|
||||
|
||||
|
||||
class RTVITTSConfig(BaseModel):
|
||||
voice: Optional[str] = None
|
||||
class RTVIService(BaseModel):
    """Description of a service that can be registered with the RTVI processor.

    In addition to the declared fields, a private name -> option index is
    rebuilt after validation so options can be looked up by name.
    """

    # Service identifier, e.g. "llm" or "tts".
    name: str
    # Frame processor class instantiated for this service.
    cls: Type[FrameProcessor]
    # Options this service declares.
    options: List[RTVIServiceOption]
    # Private index of `options` keyed by option name; filled in model_post_init.
    _options_dict: Dict[str, RTVIServiceOption] = PrivateAttr(default_factory=dict)

    def model_post_init(self, __context: Any) -> None:
        """Index the declared options by name once the model is validated."""
        # If two options share a name the later one wins, as with a plain loop.
        self._options_dict = {opt.name: opt for opt in self.options}
        return super().model_post_init(__context)
|
||||
|
||||
#
|
||||
# Client -> Pipecat messages.
|
||||
#
|
||||
|
||||
|
||||
class RTVIServiceOptionConfig(BaseModel):
    """Value supplied by the client for one named service option."""

    # Name of the option being configured.
    name: str
    # New value for the option; left untyped because it varies per option.
    value: Any
|
||||
|
||||
|
||||
class RTVIServiceConfig(BaseModel):
    """Client-supplied configuration for a single service."""

    # Name of the service the options apply to.
    service: str
    # Option values to apply to that service.
    options: List[RTVIServiceOptionConfig]
|
||||
|
||||
|
||||
class RTVIConfig(BaseModel):
|
||||
llm: Optional[RTVILLMConfig] = None
|
||||
tts: Optional[RTVITTSConfig] = None
|
||||
config: List[RTVIServiceConfig]
|
||||
_config_dict: Dict[str, RTVIServiceConfig] = PrivateAttr(default={})
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
self._config_dict = {}
|
||||
for c in self.config:
|
||||
self._config_dict[c.service] = c
|
||||
return super().model_post_init(__context)
|
||||
|
||||
|
||||
class RTVISetup(BaseModel):
|
||||
config: Optional[RTVIConfig] = None
|
||||
|
||||
|
||||
class RTVILLMMessageData(BaseModel):
|
||||
class RTVILLMContextData(BaseModel):
|
||||
messages: List[dict]
|
||||
|
||||
|
||||
class RTVITTSMessageData(BaseModel):
|
||||
class RTVITTSSpeakData(BaseModel):
|
||||
text: str
|
||||
interrupt: Optional[bool] = False
|
||||
|
||||
|
||||
class RTVIMessageData(BaseModel):
|
||||
setup: Optional[RTVISetup] = None
|
||||
config: Optional[RTVIConfig] = None
|
||||
llm: Optional[RTVILLMMessageData] = None
|
||||
tts: Optional[RTVITTSMessageData] = None
|
||||
|
||||
|
||||
class RTVIMessage(BaseModel):
|
||||
label: Literal["rtvi"] = "rtvi"
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: str
|
||||
id: str
|
||||
data: Optional[RTVIMessageData] = None
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
|
||||
#
|
||||
# Pipecat -> Client responses and messages.
|
||||
#
|
||||
|
||||
|
||||
class RTVIResponseData(BaseModel):
|
||||
@@ -97,7 +111,7 @@ class RTVIResponseData(BaseModel):
|
||||
|
||||
|
||||
class RTVIResponse(BaseModel):
|
||||
label: Literal["rtvi"] = "rtvi"
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["response"] = "response"
|
||||
id: str
|
||||
data: RTVIResponseData
|
||||
@@ -108,7 +122,7 @@ class RTVIErrorData(BaseModel):
|
||||
|
||||
|
||||
class RTVIError(BaseModel):
|
||||
label: Literal["rtvi"] = "rtvi"
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["error"] = "error"
|
||||
data: RTVIErrorData
|
||||
|
||||
@@ -118,7 +132,7 @@ class RTVILLMContextMessageData(BaseModel):
|
||||
|
||||
|
||||
class RTVILLMContextMessage(BaseModel):
|
||||
label: Literal["rtvi"] = "rtvi"
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["llm-context"] = "llm-context"
|
||||
data: RTVILLMContextMessageData
|
||||
|
||||
@@ -128,13 +142,13 @@ class RTVITTSTextMessageData(BaseModel):
|
||||
|
||||
|
||||
class RTVITTSTextMessage(BaseModel):
|
||||
label: Literal["rtvi"] = "rtvi"
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["tts-text"] = "tts-text"
|
||||
data: RTVITTSTextMessageData
|
||||
|
||||
|
||||
class RTVIBotReady(BaseModel):
|
||||
label: Literal["rtvi"] = "rtvi"
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["bot-ready"] = "bot-ready"
|
||||
|
||||
|
||||
@@ -146,23 +160,23 @@ class RTVITranscriptionMessageData(BaseModel):
|
||||
|
||||
|
||||
class RTVITranscriptionMessage(BaseModel):
|
||||
label: Literal["rtvi"] = "rtvi"
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["user-transcription"] = "user-transcription"
|
||||
data: RTVITranscriptionMessageData
|
||||
|
||||
|
||||
class RTVIUserStartedSpeakingMessage(BaseModel):
|
||||
label: Literal["rtvi"] = "rtvi"
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["user-started-speaking"] = "user-started-speaking"
|
||||
|
||||
|
||||
class RTVIUserStoppedSpeakingMessage(BaseModel):
|
||||
label: Literal["rtvi"] = "rtvi"
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["user-stopped-speaking"] = "user-stopped-speaking"
|
||||
|
||||
|
||||
class RTVIJSONCompletion(BaseModel):
|
||||
label: Literal["rtvi"] = "rtvi"
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["json-completion"] = "json-completion"
|
||||
data: str
|
||||
|
||||
@@ -265,29 +279,45 @@ class RTVITTSTextProcessor(FrameProcessor):
|
||||
await self.push_frame(TransportMessageFrame(message=message.model_dump(exclude_none=True)))
|
||||
|
||||
|
||||
async def handle_llm_model_update(rtvi: 'RTVIProcessor', option: RTVIServiceOptionConfig):
    """Push an LLMModelUpdateFrame built from the option's value via `rtvi`."""
    await rtvi.push_frame(LLMModelUpdateFrame(option.value))
|
||||
|
||||
|
||||
async def handle_llm_messages_update(rtvi: 'RTVIProcessor', option: RTVIServiceOptionConfig):
    """Push an LLMMessagesUpdateFrame built from the option's value via `rtvi`."""
    await rtvi.push_frame(LLMMessagesUpdateFrame(option.value))
|
||||
|
||||
|
||||
async def handle_tts_voice_update(rtvi: 'RTVIProcessor', option: RTVIServiceOptionConfig):
    """Push a TTSVoiceUpdateFrame built from the option's value via `rtvi`."""
    await rtvi.push_frame(TTSVoiceUpdateFrame(option.value))
|
||||
|
||||
# Default "llm" service: backed by OpenAILLMService, with updatable model and
# messages options.
DEFAULT_LLM_SERVICE = RTVIService(
    name="llm",
    cls=OpenAILLMService,
    options=[
        RTVIServiceOption(name="model", handler=handle_llm_model_update),
        RTVIServiceOption(name="messages", handler=handle_llm_messages_update),
    ],
)

# Default "tts" service: backed by CartesiaTTSService, with an updatable voice.
DEFAULT_TTS_SERVICE = RTVIService(
    name="tts",
    cls=CartesiaTTSService,
    options=[
        RTVIServiceOption(name="voice_id", handler=handle_tts_voice_update),
    ],
)
|
||||
|
||||
|
||||
class RTVIProcessor(FrameProcessor):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
transport: BaseTransport,
|
||||
setup: RTVISetup | None = None,
|
||||
llm_api_key: str = "",
|
||||
llm_base_url: str = "https://api.groq.com/openai/v1",
|
||||
tts_api_key: str = "",
|
||||
llm_cls: Type[AIService] = OpenAILLMService,
|
||||
tts_cls: Type[AIService] = CartesiaTTSService):
|
||||
def __init__(self, *, transport: BaseTransport):
|
||||
super().__init__()
|
||||
self._transport = transport
|
||||
self._setup = setup
|
||||
self._llm_api_key = llm_api_key
|
||||
self._llm_base_url = llm_base_url
|
||||
self._tts_api_key = tts_api_key
|
||||
self._llm_cls = llm_cls
|
||||
self._tts_cls = tts_cls
|
||||
self._config: RTVIConfig | None = None
|
||||
self._ctor_args: Dict[str, Any] = {}
|
||||
|
||||
self._start_frame: Frame | None = None
|
||||
self._llm: FrameProcessor | None = None
|
||||
self._tts: FrameProcessor | None = None
|
||||
self._pipeline: FrameProcessor | None = None
|
||||
self._first_participant_joined: bool = False
|
||||
|
||||
@@ -297,9 +327,24 @@ class RTVIProcessor(FrameProcessor):
|
||||
"on_first_participant_joined",
|
||||
self._on_first_participant_joined)
|
||||
|
||||
# Register default services.
|
||||
self._registered_services: Dict[str, RTVIService] = {}
|
||||
self.register_service(DEFAULT_LLM_SERVICE)
|
||||
self.register_service(DEFAULT_TTS_SERVICE)
|
||||
|
||||
self._frame_handler_task = self.get_event_loop().create_task(self._frame_handler())
|
||||
self._frame_queue = asyncio.Queue()
|
||||
|
||||
def register_service(self, service: RTVIService):
    """Register `service` under its name, replacing any existing entry."""
    registry = self._registered_services
    registry[service.name] = service
|
||||
|
||||
def setup_on_start(self, config: RTVIConfig | None, ctor_args: Dict[str, Any]):
    """Store the configuration and constructor arguments for later use.

    Nothing is built here; the values are only saved on the processor.
    """
    self._config, self._ctor_args = config, ctor_args
|
||||
|
||||
async def update_config(self, config: RTVIConfig):
    """Apply `config` by delegating to the internal config-update handler."""
    await self._handle_config_update(config)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
@@ -309,11 +354,10 @@ class RTVIProcessor(FrameProcessor):
|
||||
await self._frame_queue.put((frame, direction))
|
||||
|
||||
if isinstance(frame, StartFrame):
|
||||
self._start_frame = frame
|
||||
try:
|
||||
await self._handle_setup(self._setup)
|
||||
await self._handle_pipeline_setup(frame, self._config)
|
||||
except Exception as e:
|
||||
await self._send_error(f"unable to setup RTVI: {e}")
|
||||
await self._send_error(f"unable to setup RTVI pipeline: {e}")
|
||||
|
||||
async def cleanup(self):
|
||||
self._frame_handler_task.cancel()
|
||||
@@ -379,90 +423,81 @@ class RTVIProcessor(FrameProcessor):
|
||||
try:
|
||||
message = RTVIMessage.model_validate(frame.message)
|
||||
except ValidationError as e:
|
||||
await self._send_error(f"invalid message: {e}")
|
||||
await self._send_error(f"Invalid incoming message: {e}")
|
||||
logger.warning(f"Invalid incoming message: {e}")
|
||||
return
|
||||
|
||||
try:
|
||||
success = True
|
||||
error = None
|
||||
match message.type:
|
||||
case "setup":
|
||||
setup = None
|
||||
if message.data:
|
||||
setup = message.data.setup
|
||||
await self._handle_setup(message.id, setup)
|
||||
case "config-update":
|
||||
await self._handle_config_update(message.data.config)
|
||||
await self._handle_config_update(RTVIConfig.model_validate(message.data))
|
||||
case "llm-get-context":
|
||||
await self._handle_llm_get_context()
|
||||
case "llm-append-context":
|
||||
await self._handle_llm_append_context(message.data.llm)
|
||||
await self._handle_llm_append_context(RTVILLMContextData.model_validate(message.data))
|
||||
case "llm-update-context":
|
||||
await self._handle_llm_update_context(message.data.llm)
|
||||
await self._handle_llm_update_context(RTVILLMContextData.model_validate(message.data))
|
||||
case "tts-speak":
|
||||
await self._handle_tts_speak(message.data.tts)
|
||||
await self._handle_tts_speak(RTVITTSSpeakData.model_validate(message.data))
|
||||
case "tts-interrupt":
|
||||
await self._handle_tts_interrupt()
|
||||
case _:
|
||||
success = False
|
||||
error = f"unsupported type {message.type}"
|
||||
error = f"Unsupported type {message.type}"
|
||||
|
||||
await self._send_response(message.id, success, error)
|
||||
except ValidationError as e:
|
||||
await self._send_response(message.id, False, f"invalid message: {e}")
|
||||
await self._send_response(message.id, False, f"Invalid incoming message: {e}")
|
||||
logger.warning(f"Invalid incoming message: {e}")
|
||||
except Exception as e:
|
||||
await self._send_response(message.id, False, f"{e}")
|
||||
await self._send_response(message.id, False, f"Exception processing message: {e}")
|
||||
logger.warning(f"Exception processing message: {e}")
|
||||
|
||||
async def _handle_setup(self, setup: RTVISetup | None):
|
||||
model = DEFAULT_MODEL
|
||||
if setup and setup.config and setup.config.llm and setup.config.llm.model:
|
||||
model = setup.config.llm.model
|
||||
async def _handle_pipeline_setup(self, start_frame: StartFrame, config: RTVIConfig | None):
|
||||
# TODO(aleix): We shouldn't need to save this in `self._tma_in`.
|
||||
self._tma_in = LLMUserResponseAggregator()
|
||||
tma_out = LLMAssistantResponseAggregator()
|
||||
|
||||
messages = DEFAULT_MESSAGES
|
||||
if setup and setup.config and setup.config.llm and setup.config.llm.messages:
|
||||
messages = setup.config.llm.messages
|
||||
llm_cls = self._registered_services["llm"].cls
|
||||
llm_args = self._ctor_args["llm"]
|
||||
llm = llm_cls(**llm_args)
|
||||
|
||||
voice = DEFAULT_VOICE
|
||||
if setup and setup.config and setup.config.tts and setup.config.tts.voice:
|
||||
voice = setup.config.tts.voice
|
||||
|
||||
self._tma_in = LLMUserResponseAggregator(messages)
|
||||
self._tma_out = LLMAssistantResponseAggregator(messages)
|
||||
|
||||
self._llm = self._llm_cls(
|
||||
name="LLM",
|
||||
base_url=self._llm_base_url,
|
||||
api_key=self._llm_api_key,
|
||||
model=model)
|
||||
|
||||
self._tts = self._tts_cls(name="TTS", api_key=self._tts_api_key, voice_id=voice)
|
||||
tts_cls = self._registered_services["tts"].cls
|
||||
tts_args = self._ctor_args["tts"]
|
||||
tts = tts_cls(**tts_args)
|
||||
|
||||
# TODO-CB: Eventually we'll need to switch the context aggregators to use the
|
||||
# OpenAI context frames instead of message frames
|
||||
context = OpenAILLMContext(messages=messages)
|
||||
self._fc = FunctionCaller(context)
|
||||
context = OpenAILLMContext()
|
||||
fc = FunctionCaller(context)
|
||||
|
||||
self._tts_text = RTVITTSTextProcessor()
|
||||
tts_text = RTVITTSTextProcessor()
|
||||
|
||||
pipeline = Pipeline([
|
||||
self._tma_in,
|
||||
self._llm,
|
||||
self._fc,
|
||||
self._tts,
|
||||
self._tts_text,
|
||||
self._tma_out,
|
||||
llm,
|
||||
fc,
|
||||
tts,
|
||||
tts_text,
|
||||
tma_out,
|
||||
self._transport.output(),
|
||||
])
|
||||
|
||||
parent = self.get_parent()
|
||||
if parent and self._start_frame:
|
||||
if parent:
|
||||
parent.link(pipeline)
|
||||
|
||||
# We need to initialize the new pipeline with the same settings
|
||||
# as the initial one.
|
||||
start_frame = dataclasses.replace(self._start_frame)
|
||||
start_frame = dataclasses.replace(start_frame)
|
||||
await self.push_frame(start_frame)
|
||||
|
||||
# Configure the pipeline
|
||||
if config:
|
||||
await self._handle_config_update(config)
|
||||
|
||||
# Send new initial metrics with the new processors
|
||||
processors = parent.processors_with_metrics()
|
||||
processors.extend(pipeline.processors_with_metrics())
|
||||
@@ -474,17 +509,16 @@ class RTVIProcessor(FrameProcessor):
|
||||
|
||||
await self._maybe_send_bot_ready()
|
||||
|
||||
async def _handle_config_update(self, config: RTVIConfig):
|
||||
# Change voice before LLM updates, so we can hear the new voice.
|
||||
if config.tts and config.tts.voice:
|
||||
frame = TTSVoiceUpdateFrame(config.tts.voice)
|
||||
await self.push_frame(frame)
|
||||
if config.llm and config.llm.model:
|
||||
frame = LLMModelUpdateFrame(config.llm.model)
|
||||
await self.push_frame(frame)
|
||||
if config.llm and config.llm.messages:
|
||||
frame = LLMMessagesUpdateFrame(config.llm.messages)
|
||||
await self.push_frame(frame)
|
||||
async def _handle_config_service(self, config: RTVIServiceConfig):
    """Apply every option in `config` to the matching registered service.

    Options are processed in order; each option's handler (if any) is awaited
    with this processor and the option configuration.

    Args:
        config: Target service name plus the option values to apply.

    Raises:
        KeyError: If the service is not registered, or an option is not
            declared by that service. The message names the offending
            service/option so the error reported to the client is actionable.
    """
    service = self._registered_services.get(config.service)
    if service is None:
        raise KeyError(f"unknown RTVI service '{config.service}'")

    for option in config.options:
        declared = service._options_dict.get(option.name)
        if declared is None:
            raise KeyError(
                f"unknown option '{option.name}' for RTVI service '{config.service}'")
        # Options declared without a handler are accepted but have no effect.
        if declared.handler:
            await declared.handler(self, option)
|
||||
|
||||
async def _handle_config_update(self, data: RTVIConfig):
    """Apply each per-service configuration in `data`, one after another."""
    for service_config in data.config:
        await self._handle_config_service(service_config)
|
||||
|
||||
async def _handle_llm_get_context(self):
|
||||
data = RTVILLMContextMessageData(messages=self._tma_in.messages)
|
||||
@@ -492,17 +526,17 @@ class RTVIProcessor(FrameProcessor):
|
||||
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _handle_llm_append_context(self, data: RTVILLMMessageData):
|
||||
async def _handle_llm_append_context(self, data: RTVILLMContextData):
|
||||
if data and data.messages:
|
||||
frame = LLMMessagesAppendFrame(data.messages)
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _handle_llm_update_context(self, data: RTVILLMMessageData):
|
||||
async def _handle_llm_update_context(self, data: RTVILLMContextData):
|
||||
if data and data.messages:
|
||||
frame = LLMMessagesUpdateFrame(data.messages)
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _handle_tts_speak(self, data: RTVITTSMessageData):
|
||||
async def _handle_tts_speak(self, data: RTVITTSSpeakData):
|
||||
if data and data.text:
|
||||
if data.interrupt:
|
||||
await self._handle_tts_interrupt()
|
||||
@@ -539,7 +573,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
self._pipeline = pipeline
|
||||
|
||||
parent = self.get_parent()
|
||||
if parent and self._start_frame:
|
||||
if parent:
|
||||
parent.link(pipeline)
|
||||
|
||||
message = RTVIResponse(id=id, data=RTVIResponseData(success=success, error=error))
|
||||
|
||||
Reference in New Issue
Block a user