processors(rtvi): refactor to allow future custom pipelines

This commit is contained in:
Aleix Conchillo Flaqué
2024-07-26 10:26:36 -07:00
parent 028e38a86b
commit c1e8a5e522

View File

@@ -7,8 +7,8 @@
import asyncio
import dataclasses
from typing import List, Literal, Optional, Type
from pydantic import BaseModel, ValidationError
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Type
from pydantic import PrivateAttr, BaseModel, ValidationError
from pipecat.frames.frames import (
BotInterruptionFrame,
@@ -33,62 +33,76 @@ from pipecat.pipeline.pipeline import Pipeline
from pipecat.processors.aggregators.llm_response import (
LLMAssistantResponseAggregator, LLMUserResponseAggregator)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.ai_services import AIService
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.services.openai import OpenAILLMService, OpenAILLMContext
from pipecat.transports.base_transport import BaseTransport
DEFAULT_MESSAGES = [
{
"role": "system",
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
}
]
DEFAULT_MODEL = "llama3-70b-8192"
DEFAULT_VOICE = "79a125e8-cd45-4c13-8a67-188112f4dd22"
from loguru import logger
class RTVILLMConfig(BaseModel):
model: Optional[str] = None
messages: Optional[List[dict]] = None
class RTVIServiceOption(BaseModel):
    """One configurable option of a registered RTVI service.

    When `handler` is set it is awaited with the processor and the incoming
    option configuration each time a config update names this option.
    """
    # Option name; services index their options by this value.
    name: str
    # Optional coroutine function: (processor, option config) -> None.
    handler: Optional[
        Callable[["RTVIProcessor", "RTVIServiceOptionConfig"], Awaitable[None]]
    ] = None
class RTVITTSConfig(BaseModel):
voice: Optional[str] = None
class RTVIService(BaseModel):
    """A service that can be registered with the RTVI processor.

    Bundles the service name, the frame processor class used to build it and
    the options that may be configured on it.
    """
    name: str
    cls: Type[FrameProcessor]
    options: List[RTVIServiceOption]
    # Name -> option index derived from `options`. A factory is used so the
    # declaration does not carry a shared mutable dict literal as its default.
    _options_dict: Dict[str, RTVIServiceOption] = PrivateAttr(default_factory=dict)

    def model_post_init(self, __context: Any) -> None:
        """Index `options` by name once the model has been initialised."""
        # If two options share a name the later one wins, as before.
        self._options_dict = {option.name: option for option in self.options}
        return super().model_post_init(__context)
#
# Client -> Pipecat messages.
#
class RTVIServiceOptionConfig(BaseModel):
    """Client-supplied value for a single service option."""
    # Name of the option being configured.
    name: str
    # New value for the option; its shape is not validated here.
    value: Any
class RTVIServiceConfig(BaseModel):
    """Client-supplied configuration for one service."""
    # Name of the service the options apply to.
    service: str
    # Option values to apply to that service.
    options: List[RTVIServiceOptionConfig]
class RTVIConfig(BaseModel):
llm: Optional[RTVILLMConfig] = None
tts: Optional[RTVITTSConfig] = None
config: List[RTVIServiceConfig]
_config_dict: Dict[str, RTVIServiceConfig] = PrivateAttr(default={})
def model_post_init(self, __context: Any) -> None:
self._config_dict = {}
for c in self.config:
self._config_dict[c.service] = c
return super().model_post_init(__context)
class RTVISetup(BaseModel):
config: Optional[RTVIConfig] = None
class RTVILLMMessageData(BaseModel):
class RTVILLMContextData(BaseModel):
    """Payload of the `llm-append-context` and `llm-update-context` messages."""
    # Messages to append to, or replace, the LLM context.
    messages: List[dict]
class RTVITTSMessageData(BaseModel):
class RTVITTSSpeakData(BaseModel):
    """Payload of the `tts-speak` message."""
    # Text to be spoken.
    text: str
    # When true, a TTS interruption is handled before speaking.
    interrupt: Optional[bool] = False
class RTVIMessageData(BaseModel):
setup: Optional[RTVISetup] = None
config: Optional[RTVIConfig] = None
llm: Optional[RTVILLMMessageData] = None
tts: Optional[RTVITTSMessageData] = None
class RTVIMessage(BaseModel):
label: Literal["rtvi"] = "rtvi"
label: Literal["rtvi-ai"] = "rtvi-ai"
type: str
id: str
data: Optional[RTVIMessageData] = None
data: Optional[Dict[str, Any]] = None
#
# Pipecat -> Client responses and messages.
#
class RTVIResponseData(BaseModel):
@@ -97,7 +111,7 @@ class RTVIResponseData(BaseModel):
class RTVIResponse(BaseModel):
label: Literal["rtvi"] = "rtvi"
label: Literal["rtvi-ai"] = "rtvi-ai"
type: Literal["response"] = "response"
id: str
data: RTVIResponseData
@@ -108,7 +122,7 @@ class RTVIErrorData(BaseModel):
class RTVIError(BaseModel):
label: Literal["rtvi"] = "rtvi"
label: Literal["rtvi-ai"] = "rtvi-ai"
type: Literal["error"] = "error"
data: RTVIErrorData
@@ -118,7 +132,7 @@ class RTVILLMContextMessageData(BaseModel):
class RTVILLMContextMessage(BaseModel):
label: Literal["rtvi"] = "rtvi"
label: Literal["rtvi-ai"] = "rtvi-ai"
type: Literal["llm-context"] = "llm-context"
data: RTVILLMContextMessageData
@@ -128,13 +142,13 @@ class RTVITTSTextMessageData(BaseModel):
class RTVITTSTextMessage(BaseModel):
label: Literal["rtvi"] = "rtvi"
label: Literal["rtvi-ai"] = "rtvi-ai"
type: Literal["tts-text"] = "tts-text"
data: RTVITTSTextMessageData
class RTVIBotReady(BaseModel):
label: Literal["rtvi"] = "rtvi"
label: Literal["rtvi-ai"] = "rtvi-ai"
type: Literal["bot-ready"] = "bot-ready"
@@ -146,23 +160,23 @@ class RTVITranscriptionMessageData(BaseModel):
class RTVITranscriptionMessage(BaseModel):
label: Literal["rtvi"] = "rtvi"
label: Literal["rtvi-ai"] = "rtvi-ai"
type: Literal["user-transcription"] = "user-transcription"
data: RTVITranscriptionMessageData
class RTVIUserStartedSpeakingMessage(BaseModel):
label: Literal["rtvi"] = "rtvi"
label: Literal["rtvi-ai"] = "rtvi-ai"
type: Literal["user-started-speaking"] = "user-started-speaking"
class RTVIUserStoppedSpeakingMessage(BaseModel):
label: Literal["rtvi"] = "rtvi"
label: Literal["rtvi-ai"] = "rtvi-ai"
type: Literal["user-stopped-speaking"] = "user-stopped-speaking"
class RTVIJSONCompletion(BaseModel):
label: Literal["rtvi"] = "rtvi"
label: Literal["rtvi-ai"] = "rtvi-ai"
type: Literal["json-completion"] = "json-completion"
data: str
@@ -265,29 +279,45 @@ class RTVITTSTextProcessor(FrameProcessor):
await self.push_frame(TransportMessageFrame(message=message.model_dump(exclude_none=True)))
async def handle_llm_model_update(rtvi: 'RTVIProcessor', option: RTVIServiceOptionConfig):
    """Push an `LLMModelUpdateFrame` built from the option's value through `rtvi`."""
    await rtvi.push_frame(LLMModelUpdateFrame(option.value))
async def handle_llm_messages_update(rtvi: 'RTVIProcessor', option: RTVIServiceOptionConfig):
    """Push an `LLMMessagesUpdateFrame` built from the option's value through `rtvi`."""
    await rtvi.push_frame(LLMMessagesUpdateFrame(option.value))
async def handle_tts_voice_update(rtvi: 'RTVIProcessor', option: RTVIServiceOptionConfig):
    """Push a `TTSVoiceUpdateFrame` built from the option's value through `rtvi`."""
    await rtvi.push_frame(TTSVoiceUpdateFrame(option.value))
# Default "llm" service: an OpenAI-compatible LLM processor whose model and
# messages can be changed through config updates.
DEFAULT_LLM_SERVICE = RTVIService(
    name="llm",
    cls=OpenAILLMService,
    options=[
        RTVIServiceOption(
            name="model",
            handler=handle_llm_model_update,
        ),
        RTVIServiceOption(
            name="messages",
            handler=handle_llm_messages_update,
        ),
    ],
)
# Default "tts" service: a Cartesia TTS processor whose voice can be changed
# through config updates.
DEFAULT_TTS_SERVICE = RTVIService(
    name="tts",
    cls=CartesiaTTSService,
    options=[
        RTVIServiceOption(
            name="voice_id",
            handler=handle_tts_voice_update,
        ),
    ],
)
class RTVIProcessor(FrameProcessor):
def __init__(
self,
*,
transport: BaseTransport,
setup: RTVISetup | None = None,
llm_api_key: str = "",
llm_base_url: str = "https://api.groq.com/openai/v1",
tts_api_key: str = "",
llm_cls: Type[AIService] = OpenAILLMService,
tts_cls: Type[AIService] = CartesiaTTSService):
def __init__(self, *, transport: BaseTransport):
super().__init__()
self._transport = transport
self._setup = setup
self._llm_api_key = llm_api_key
self._llm_base_url = llm_base_url
self._tts_api_key = tts_api_key
self._llm_cls = llm_cls
self._tts_cls = tts_cls
self._config: RTVIConfig | None = None
self._ctor_args: Dict[str, Any] = {}
self._start_frame: Frame | None = None
self._llm: FrameProcessor | None = None
self._tts: FrameProcessor | None = None
self._pipeline: FrameProcessor | None = None
self._first_participant_joined: bool = False
@@ -297,9 +327,24 @@ class RTVIProcessor(FrameProcessor):
"on_first_participant_joined",
self._on_first_participant_joined)
# Register default services.
self._registered_services: Dict[str, RTVIService] = {}
self.register_service(DEFAULT_LLM_SERVICE)
self.register_service(DEFAULT_TTS_SERVICE)
self._frame_handler_task = self.get_event_loop().create_task(self._frame_handler())
self._frame_queue = asyncio.Queue()
def register_service(self, service: RTVIService):
    """Register `service` under its name, replacing any entry with the same name."""
    self._registered_services.update({service.name: service})
def setup_on_start(self, config: RTVIConfig | None, ctor_args: Dict[str, Any]):
    """Store what is needed to build the pipeline when the start frame arrives.

    Args:
        config: Initial configuration applied once the pipeline is set up,
            or None to skip the initial configuration.
        ctor_args: Constructor keyword arguments, keyed by service name.
    """
    self._config, self._ctor_args = config, ctor_args
async def update_config(self, config: RTVIConfig):
    """Apply `config` to the registered services.

    Public entry point that delegates to `_handle_config_update`.
    """
    await self._handle_config_update(config)
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
@@ -309,11 +354,10 @@ class RTVIProcessor(FrameProcessor):
await self._frame_queue.put((frame, direction))
if isinstance(frame, StartFrame):
self._start_frame = frame
try:
await self._handle_setup(self._setup)
await self._handle_pipeline_setup(frame, self._config)
except Exception as e:
await self._send_error(f"unable to setup RTVI: {e}")
await self._send_error(f"unable to setup RTVI pipeline: {e}")
async def cleanup(self):
self._frame_handler_task.cancel()
@@ -379,90 +423,81 @@ class RTVIProcessor(FrameProcessor):
try:
message = RTVIMessage.model_validate(frame.message)
except ValidationError as e:
await self._send_error(f"invalid message: {e}")
await self._send_error(f"Invalid incoming message: {e}")
logger.warning(f"Invalid incoming message: {e}")
return
try:
success = True
error = None
match message.type:
case "setup":
setup = None
if message.data:
setup = message.data.setup
await self._handle_setup(message.id, setup)
case "config-update":
await self._handle_config_update(message.data.config)
await self._handle_config_update(RTVIConfig.model_validate(message.data))
case "llm-get-context":
await self._handle_llm_get_context()
case "llm-append-context":
await self._handle_llm_append_context(message.data.llm)
await self._handle_llm_append_context(RTVILLMContextData.model_validate(message.data))
case "llm-update-context":
await self._handle_llm_update_context(message.data.llm)
await self._handle_llm_update_context(RTVILLMContextData.model_validate(message.data))
case "tts-speak":
await self._handle_tts_speak(message.data.tts)
await self._handle_tts_speak(RTVITTSSpeakData.model_validate(message.data))
case "tts-interrupt":
await self._handle_tts_interrupt()
case _:
success = False
error = f"unsupported type {message.type}"
error = f"Unsupported type {message.type}"
await self._send_response(message.id, success, error)
except ValidationError as e:
await self._send_response(message.id, False, f"invalid message: {e}")
await self._send_response(message.id, False, f"Invalid incoming message: {e}")
logger.warning(f"Invalid incoming message: {e}")
except Exception as e:
await self._send_response(message.id, False, f"{e}")
await self._send_response(message.id, False, f"Exception processing message: {e}")
logger.warning(f"Exception processing message: {e}")
async def _handle_setup(self, setup: RTVISetup | None):
model = DEFAULT_MODEL
if setup and setup.config and setup.config.llm and setup.config.llm.model:
model = setup.config.llm.model
async def _handle_pipeline_setup(self, start_frame: StartFrame, config: RTVIConfig | None):
# TODO(aleix): We shouldn't need to save this in `self._tma_in`.
self._tma_in = LLMUserResponseAggregator()
tma_out = LLMAssistantResponseAggregator()
messages = DEFAULT_MESSAGES
if setup and setup.config and setup.config.llm and setup.config.llm.messages:
messages = setup.config.llm.messages
llm_cls = self._registered_services["llm"].cls
llm_args = self._ctor_args["llm"]
llm = llm_cls(**llm_args)
voice = DEFAULT_VOICE
if setup and setup.config and setup.config.tts and setup.config.tts.voice:
voice = setup.config.tts.voice
self._tma_in = LLMUserResponseAggregator(messages)
self._tma_out = LLMAssistantResponseAggregator(messages)
self._llm = self._llm_cls(
name="LLM",
base_url=self._llm_base_url,
api_key=self._llm_api_key,
model=model)
self._tts = self._tts_cls(name="TTS", api_key=self._tts_api_key, voice_id=voice)
tts_cls = self._registered_services["tts"].cls
tts_args = self._ctor_args["tts"]
tts = tts_cls(**tts_args)
# TODO-CB: Eventually we'll need to switch the context aggregators to use the
# OpenAI context frames instead of message frames
context = OpenAILLMContext(messages=messages)
self._fc = FunctionCaller(context)
context = OpenAILLMContext()
fc = FunctionCaller(context)
self._tts_text = RTVITTSTextProcessor()
tts_text = RTVITTSTextProcessor()
pipeline = Pipeline([
self._tma_in,
self._llm,
self._fc,
self._tts,
self._tts_text,
self._tma_out,
llm,
fc,
tts,
tts_text,
tma_out,
self._transport.output(),
])
parent = self.get_parent()
if parent and self._start_frame:
if parent:
parent.link(pipeline)
# We need to initialize the new pipeline with the same settings
# as the initial one.
start_frame = dataclasses.replace(self._start_frame)
start_frame = dataclasses.replace(start_frame)
await self.push_frame(start_frame)
# Configure the pipeline
if config:
await self._handle_config_update(config)
# Send new initial metrics with the new processors
processors = parent.processors_with_metrics()
processors.extend(pipeline.processors_with_metrics())
@@ -474,17 +509,16 @@ class RTVIProcessor(FrameProcessor):
await self._maybe_send_bot_ready()
async def _handle_config_update(self, config: RTVIConfig):
# Change voice before LLM updates, so we can hear the new voice.
if config.tts and config.tts.voice:
frame = TTSVoiceUpdateFrame(config.tts.voice)
await self.push_frame(frame)
if config.llm and config.llm.model:
frame = LLMModelUpdateFrame(config.llm.model)
await self.push_frame(frame)
if config.llm and config.llm.messages:
frame = LLMMessagesUpdateFrame(config.llm.messages)
await self.push_frame(frame)
async def _handle_config_service(self, config: RTVIServiceConfig):
    """Apply every option of one service configuration.

    Looks up the registered service named by `config.service` and awaits the
    handler of each listed option, in order. Options that declare no handler
    are skipped.

    Raises:
        KeyError: If the service is not registered, or an option is not
            declared by that service. Options listed before the failing one
            have already been applied.
    """
    try:
        service = self._registered_services[config.service]
    except KeyError:
        # A bare KeyError only carries the key; say what was being looked up.
        raise KeyError(f"Unknown RTVI service '{config.service}'") from None

    for option in config.options:
        try:
            handler = service._options_dict[option.name].handler
        except KeyError:
            raise KeyError(
                f"Unknown option '{option.name}' for RTVI service '{config.service}'"
            ) from None
        if handler:
            await handler(self, option)
async def _handle_config_update(self, data: RTVIConfig):
    """Apply each service configuration in `data.config`, one after another."""
    for service_config in data.config:
        await self._handle_config_service(service_config)
async def _handle_llm_get_context(self):
data = RTVILLMContextMessageData(messages=self._tma_in.messages)
@@ -492,17 +526,17 @@ class RTVIProcessor(FrameProcessor):
frame = TransportMessageFrame(message=message.model_dump(exclude_none=True))
await self.push_frame(frame)
async def _handle_llm_append_context(self, data: RTVILLMMessageData):
async def _handle_llm_append_context(self, data: RTVILLMContextData):
if data and data.messages:
frame = LLMMessagesAppendFrame(data.messages)
await self.push_frame(frame)
async def _handle_llm_update_context(self, data: RTVILLMMessageData):
async def _handle_llm_update_context(self, data: RTVILLMContextData):
if data and data.messages:
frame = LLMMessagesUpdateFrame(data.messages)
await self.push_frame(frame)
async def _handle_tts_speak(self, data: RTVITTSMessageData):
async def _handle_tts_speak(self, data: RTVITTSSpeakData):
if data and data.text:
if data.interrupt:
await self._handle_tts_interrupt()
@@ -539,7 +573,7 @@ class RTVIProcessor(FrameProcessor):
self._pipeline = pipeline
parent = self.get_parent()
if parent and self._start_frame:
if parent:
parent.link(pipeline)
message = RTVIResponse(id=id, data=RTVIResponseData(success=success, error=error))