Compare commits

...

2 Commits

Author SHA1 Message Date
Aleix Conchillo Flaqué
06043ce9b1 missing no longer necessary to call super().process_frame(frame, direction) 2024-12-12 14:53:56 -08:00
Aleix Conchillo Flaqué
3f3a853d71 no longer necessary to call AIService super().start/stop/cancel(frame) 2024-12-12 14:45:20 -08:00
22 changed files with 16 additions and 82 deletions

View File

@@ -15,6 +15,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- It's no longer necessary to call `super().start(frame)`,
`super().stop(frame)` or `super().cancel(frame)` if you subclass `AIService`
and implement `start()`, `stop()` or `cancel()`. This is now all done
internally, which avoids possible issues if you forget to add the call.
- It's no longer necessary to call `super().process_frame(frame, direction)` if
you subclass `FrameProcessor` and implement `process_frame()`. This is now all
done internally, which avoids possible issues if you forget to add the call.

View File

@@ -111,11 +111,11 @@ class AIService(FrameProcessor):
async def process_frame(self, frame: Frame, direction: FrameDirection):
if isinstance(frame, StartFrame):
await self.start(frame)
await self._start(frame)
elif isinstance(frame, CancelFrame):
await self.cancel(frame)
await self._cancel(frame)
elif isinstance(frame, EndFrame):
await self.stop(frame)
await self._stop(frame)
async def process_generator(self, generator: AsyncGenerator[Frame | None, None]):
async for f in generator:
@@ -125,6 +125,15 @@ class AIService(FrameProcessor):
else:
await self.push_frame(f)
async def _start(self, frame: StartFrame):
await self.start(frame)
async def _stop(self, frame: EndFrame):
await self.stop(frame)
async def _cancel(self, frame: CancelFrame):
await self.cancel(frame)
class LLMService(AIService):
"""This class is a no-op but serves as a base class for LLM services."""
@@ -248,19 +257,16 @@ class TTSService(AIService):
pass
async def start(self, frame: StartFrame):
await super().start(frame)
if self._push_stop_frames:
self._stop_frame_task = self.get_event_loop().create_task(self._stop_frame_handler())
async def stop(self, frame: EndFrame):
await super().stop(frame)
if self._stop_frame_task:
self._stop_frame_task.cancel()
await self._stop_frame_task
self._stop_frame_task = None
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
if self._stop_frame_task:
self._stop_frame_task.cancel()
await self._stop_frame_task
@@ -286,8 +292,6 @@ class TTSService(AIService):
await self.queue_frame(TTSSpeakFrame(text))
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, TextFrame):
await self._process_text_frame(frame)
elif isinstance(frame, StartInterruptionFrame):
@@ -404,8 +408,6 @@ class WordTTSService(TTSService):
await self._stop_words_task()
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
await self.flush_audio()
@@ -492,8 +494,6 @@ class STTService(AIService):
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Processes a frame of audio data, either buffering or transcribing it."""
await super().process_frame(frame, direction)
if isinstance(frame, AudioRawFrame):
# In this service we accumulate audio internally and at the end we
# push a TextFrame. We also push audio downstream in case someone
@@ -591,8 +591,6 @@ class ImageGenService(AIService):
pass
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, TextFrame):
await self.push_frame(frame, direction)
await self.start_processing_metrics()
@@ -614,8 +612,6 @@ class VisionService(AIService):
pass
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, VisionImageRawFrame):
await self.start_processing_metrics()
await self.process_generator(self.run_vision(frame))

View File

@@ -270,8 +270,6 @@ class AnthropicLLMService(LLMService):
)
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
context = None
if isinstance(frame, OpenAILLMContextFrame):
context: "AnthropicLLMContext" = AnthropicLLMContext.upgrade_to_anthropic(frame.context)
@@ -611,7 +609,6 @@ class AnthropicUserContextAggregator(LLMUserContextAggregator):
self._context = AnthropicLLMContext.from_openai_context(context)
async def process_frame(self, frame, direction):
await super().process_frame(frame, direction)
# Our parent method has already called push_frame(). So we can't interrupt the
# flow here and we don't need to call push_frame() ourselves. Possibly something
# to talk through (tagging @aleix). At some point we might need to refactor these
@@ -664,7 +661,6 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
self._pending_image_frame_message = None
async def process_frame(self, frame, direction):
await super().process_frame(frame, direction)
# See note above about not calling push_frame() here.
if isinstance(frame, StartInterruptionFrame):
self._function_call_in_progress = None

View File

@@ -61,15 +61,12 @@ class AssemblyAISTTService(STTService):
self._settings["language"] = language
async def start(self, frame: StartFrame):
await super().start(frame)
await self._connect()
async def stop(self, frame: EndFrame):
await super().stop(frame)
await self._disconnect()
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
await self._disconnect()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:

View File

@@ -676,16 +676,13 @@ class AzureSTTService(STTService):
yield None
async def start(self, frame: StartFrame):
await super().start(frame)
self._speech_recognizer.start_continuous_recognition_async()
async def stop(self, frame: EndFrame):
await super().stop(frame)
self._speech_recognizer.stop_continuous_recognition_async()
self._audio_stream.close()
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
self._speech_recognizer.stop_continuous_recognition_async()
self._audio_stream.close()

View File

@@ -84,15 +84,12 @@ class CanonicalMetricsService(AIService):
self._output_dir = output_dir
async def stop(self, frame: EndFrame):
await super().stop(frame)
await self._process_audio()
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
await self._process_audio()
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
await self.push_frame(frame, direction)
async def _process_audio(self):

View File

@@ -287,8 +287,6 @@ class CartesiaTTSService(WordTTSService):
await self._connect_websocket()
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
# If we received a TTSSpeakFrame and the LLM response included text (it
# might be that it's only a function calling response) we pause
# processing more frames until we receive a BotStoppedSpeakingFrame.

View File

@@ -176,15 +176,12 @@ class DeepgramSTTService(STTService):
await self._connect()
async def start(self, frame: StartFrame):
await super().start(frame)
await self._connect()
async def stop(self, frame: EndFrame):
await super().stop(frame)
await self._disconnect()
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
await self._disconnect()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:

View File

@@ -272,8 +272,6 @@ class ElevenLabsTTSService(WordTTSService):
await self.add_word_timestamps([("LLMFullResponseEndFrame", 0), ("Reset", 0)])
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
# If we received a TTSSpeakFrame and the LLM response included text (it
# might be that it's only a function calling response) we pause
# processing more frames until we receive a BotStoppedSpeakingFrame.

View File

@@ -107,7 +107,6 @@ class GeminiMultimodalLiveContext(OpenAILLMContext):
class GeminiMultimodalLiveUserContextAggregator(OpenAIUserContextAggregator):
async def process_frame(self, frame, direction):
await super().process_frame(frame, direction)
# kind of a hack just to pass the LLMMessagesAppendFrame through, but it's fine for now
if isinstance(frame, LLMMessagesAppendFrame):
await self.push_frame(frame, direction)
@@ -229,15 +228,12 @@ class GeminiMultimodalLiveLLMService(LLMService):
#
async def start(self, frame: StartFrame):
await super().start(frame)
await self._connect()
async def stop(self, frame: EndFrame):
await super().stop(frame)
await self._disconnect()
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
await self._disconnect()
#
@@ -308,8 +304,6 @@ class GeminiMultimodalLiveLLMService(LLMService):
#
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
# logger.debug(f"Processing frame: {frame}")
if isinstance(frame, TranscriptionFrame):

View File

@@ -177,18 +177,15 @@ class GladiaSTTService(STTService):
return language_to_gladia_language(language)
async def start(self, frame: StartFrame):
await super().start(frame)
response = await self._setup_gladia()
self._websocket = await websockets.connect(response["url"])
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
async def stop(self, frame: EndFrame):
await super().stop(frame)
await self._send_stop_recording()
await self._websocket.close()
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
await self._websocket.close()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:

View File

@@ -652,8 +652,6 @@ class GoogleLLMService(LLMService):
await self.push_frame(LLMFullResponseEndFrame())
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
context = None
if isinstance(frame, OpenAILLMContextFrame):

View File

@@ -286,8 +286,6 @@ class BaseOpenAILLMService(LLMService):
)
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
context = None
if isinstance(frame, OpenAILLMContextFrame):
context: OpenAILLMContext = frame.context
@@ -475,7 +473,6 @@ class OpenAIUserContextAggregator(LLMUserContextAggregator):
super().__init__(context=context)
async def process_frame(self, frame, direction):
await super().process_frame(frame, direction)
# Our parent method has already called push_frame(). So we can't interrupt the
# flow here and we don't need to call push_frame() ourselves.
try:
@@ -516,7 +513,6 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
self._pending_image_frame_message = None
async def process_frame(self, frame, direction):
await super().process_frame(frame, direction)
# See note above about not calling push_frame() here.
if isinstance(frame, StartInterruptionFrame):
self._function_calls_in_progress.clear()

View File

@@ -148,7 +148,6 @@ class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
async def process_frame(
self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM
):
await super().process_frame(frame, direction)
# Parent does not push LLMMessagesUpdateFrame. This ensures that in a typical pipeline,
# messages are only processed by the user context aggregator, which is generally what we want. But
# we also need to send new messages over the websocket, so the openai realtime API has them

View File

@@ -112,15 +112,12 @@ class OpenAIRealtimeBetaLLMService(LLMService):
#
async def start(self, frame: StartFrame):
await super().start(frame)
await self._connect()
async def stop(self, frame: EndFrame):
await super().stop(frame)
await self._disconnect()
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
await self._disconnect()
#
@@ -173,8 +170,6 @@ class OpenAIRealtimeBetaLLMService(LLMService):
#
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, TranscriptionFrame):
pass
elif isinstance(frame, OpenAILLMContextFrame):

View File

@@ -265,8 +265,6 @@ class PlayHTTTSService(TTSService):
await self._connect_websocket()
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
# If we received a TTSSpeakFrame and the LLM response included text (it
# might be that it's only a function calling response) we pause
# processing more frames until we receive a BotStoppedSpeakingFrame.

View File

@@ -187,17 +187,14 @@ class ParakeetSTTService(STTService):
return False
async def start(self, frame: StartFrame):
await super().start(frame)
self._thread_task = self.get_event_loop().create_task(self._thread_task_handler())
self._response_task = self.get_event_loop().create_task(self._response_task_handler())
self._response_queue = asyncio.Queue()
async def stop(self, frame: EndFrame):
await super().stop(frame)
await self._stop_tasks()
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
await self._stop_tasks()
async def _stop_tasks(self):

View File

@@ -92,7 +92,6 @@ class TavusVideoService(AIService):
await self._send_audio_message(audio_base64, done=done)
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, TTSStartedFrame):
await self.start_processing_metrics()
await self.start_ttfb_metrics()

View File

@@ -101,8 +101,6 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
self._next_send_time = 0
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, StartInterruptionFrame):
await self._write_frame(frame)
self._next_send_time = 0

View File

@@ -139,8 +139,6 @@ class WebsocketServerOutputTransport(BaseOutputTransport):
self._websocket = websocket
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, StartInterruptionFrame):
await self._write_frame(frame)
self._next_send_time = 0

View File

@@ -727,8 +727,6 @@ class DailyInputTransport(BaseInputTransport):
#
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, UserImageRequestFrame):
await self.request_participant_image(frame.user_id)

View File

@@ -16,7 +16,6 @@ from pipecat.frames.frames import (
AudioRawFrame,
CancelFrame,
EndFrame,
Frame,
InputAudioRawFrame,
OutputAudioRawFrame,
StartFrame,
@@ -334,12 +333,6 @@ class LiveKitInputTransport(BaseInputTransport):
await self._client.disconnect()
logger.info("LiveKitInputTransport stopped")
async def process_frame(self, frame: Frame, direction: FrameDirection):
if isinstance(frame, EndFrame):
await self.stop(frame)
else:
await super().process_frame(frame, direction)
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
await self._client.disconnect()
@@ -411,12 +404,6 @@ class LiveKitOutputTransport(BaseOutputTransport):
await self._client.disconnect()
logger.info("LiveKitOutputTransport stopped")
async def process_frame(self, frame: Frame, direction: FrameDirection):
if isinstance(frame, EndFrame):
await self.stop(frame)
else:
await super().process_frame(frame, direction)
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
await self._client.disconnect()