Compare commits
2 Commits
main
...
aleix/no-n
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
06043ce9b1 | ||
|
|
3f3a853d71 |
@@ -15,6 +15,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Changed
|
||||
|
||||
- It's no longer necessary to call `super().start/stop/cancel(frame)` if you
|
||||
subclass and implement `AIService.start/stop/cancel()`. This is all now done
|
||||
internally and will avoid possible issues if you forget to add it.
|
||||
|
||||
- It's no longer necessary to call `super().process_frame(frame, direction)` if
|
||||
you subclass and implement `FrameProcessor.process_frame()`. This is all now
|
||||
done internally and will avoid possible issues if you forget to add it.
|
||||
|
||||
@@ -111,11 +111,11 @@ class AIService(FrameProcessor):
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
if isinstance(frame, StartFrame):
|
||||
await self.start(frame)
|
||||
await self._start(frame)
|
||||
elif isinstance(frame, CancelFrame):
|
||||
await self.cancel(frame)
|
||||
await self._cancel(frame)
|
||||
elif isinstance(frame, EndFrame):
|
||||
await self.stop(frame)
|
||||
await self._stop(frame)
|
||||
|
||||
async def process_generator(self, generator: AsyncGenerator[Frame | None, None]):
|
||||
async for f in generator:
|
||||
@@ -125,6 +125,15 @@ class AIService(FrameProcessor):
|
||||
else:
|
||||
await self.push_frame(f)
|
||||
|
||||
async def _start(self, frame: StartFrame):
|
||||
await self.start(frame)
|
||||
|
||||
async def _stop(self, frame: EndFrame):
|
||||
await self.stop(frame)
|
||||
|
||||
async def _cancel(self, frame: CancelFrame):
|
||||
await self.cancel(frame)
|
||||
|
||||
|
||||
class LLMService(AIService):
|
||||
"""This class is a no-op but serves as a base class for LLM services."""
|
||||
@@ -248,19 +257,16 @@ class TTSService(AIService):
|
||||
pass
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
if self._push_stop_frames:
|
||||
self._stop_frame_task = self.get_event_loop().create_task(self._stop_frame_handler())
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
if self._stop_frame_task:
|
||||
self._stop_frame_task.cancel()
|
||||
await self._stop_frame_task
|
||||
self._stop_frame_task = None
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
if self._stop_frame_task:
|
||||
self._stop_frame_task.cancel()
|
||||
await self._stop_frame_task
|
||||
@@ -286,8 +292,6 @@ class TTSService(AIService):
|
||||
await self.queue_frame(TTSSpeakFrame(text))
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TextFrame):
|
||||
await self._process_text_frame(frame)
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
@@ -404,8 +408,6 @@ class WordTTSService(TTSService):
|
||||
await self._stop_words_task()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
|
||||
await self.flush_audio()
|
||||
|
||||
@@ -492,8 +494,6 @@ class STTService(AIService):
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Processes a frame of audio data, either buffering or transcribing it."""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, AudioRawFrame):
|
||||
# In this service we accumulate audio internally and at the end we
|
||||
# push a TextFrame. We also push audio downstream in case someone
|
||||
@@ -591,8 +591,6 @@ class ImageGenService(AIService):
|
||||
pass
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TextFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
await self.start_processing_metrics()
|
||||
@@ -614,8 +612,6 @@ class VisionService(AIService):
|
||||
pass
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, VisionImageRawFrame):
|
||||
await self.start_processing_metrics()
|
||||
await self.process_generator(self.run_vision(frame))
|
||||
|
||||
@@ -270,8 +270,6 @@ class AnthropicLLMService(LLMService):
|
||||
)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
context = None
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
context: "AnthropicLLMContext" = AnthropicLLMContext.upgrade_to_anthropic(frame.context)
|
||||
@@ -611,7 +609,6 @@ class AnthropicUserContextAggregator(LLMUserContextAggregator):
|
||||
self._context = AnthropicLLMContext.from_openai_context(context)
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
await super().process_frame(frame, direction)
|
||||
# Our parent method has already called push_frame(). So we can't interrupt the
|
||||
# flow here and we don't need to call push_frame() ourselves. Possibly something
|
||||
# to talk through (tagging @aleix). At some point we might need to refactor these
|
||||
@@ -664,7 +661,6 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
self._pending_image_frame_message = None
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
await super().process_frame(frame, direction)
|
||||
# See note above about not calling push_frame() here.
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
self._function_call_in_progress = None
|
||||
|
||||
@@ -61,15 +61,12 @@ class AssemblyAISTTService(STTService):
|
||||
self._settings["language"] = language
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
|
||||
@@ -676,16 +676,13 @@ class AzureSTTService(STTService):
|
||||
yield None
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
self._speech_recognizer.start_continuous_recognition_async()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
self._speech_recognizer.stop_continuous_recognition_async()
|
||||
self._audio_stream.close()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
self._speech_recognizer.stop_continuous_recognition_async()
|
||||
self._audio_stream.close()
|
||||
|
||||
|
||||
@@ -84,15 +84,12 @@ class CanonicalMetricsService(AIService):
|
||||
self._output_dir = output_dir
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._process_audio()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._process_audio()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _process_audio(self):
|
||||
|
||||
@@ -287,8 +287,6 @@ class CartesiaTTSService(WordTTSService):
|
||||
await self._connect_websocket()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# If we received a TTSSpeakFrame and the LLM response included text (it
|
||||
# might be that it's only a function calling response) we pause
|
||||
# processing more frames until we receive a BotStoppedSpeakingFrame.
|
||||
|
||||
@@ -176,15 +176,12 @@ class DeepgramSTTService(STTService):
|
||||
await self._connect()
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
|
||||
@@ -272,8 +272,6 @@ class ElevenLabsTTSService(WordTTSService):
|
||||
await self.add_word_timestamps([("LLMFullResponseEndFrame", 0), ("Reset", 0)])
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# If we received a TTSSpeakFrame and the LLM response included text (it
|
||||
# might be that it's only a function calling response) we pause
|
||||
# processing more frames until we receive a BotStoppedSpeakingFrame.
|
||||
|
||||
@@ -107,7 +107,6 @@ class GeminiMultimodalLiveContext(OpenAILLMContext):
|
||||
|
||||
class GeminiMultimodalLiveUserContextAggregator(OpenAIUserContextAggregator):
|
||||
async def process_frame(self, frame, direction):
|
||||
await super().process_frame(frame, direction)
|
||||
# kind of a hack just to pass the LLMMessagesAppendFrame through, but it's fine for now
|
||||
if isinstance(frame, LLMMessagesAppendFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -229,15 +228,12 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
#
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
#
|
||||
@@ -308,8 +304,6 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
#
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# logger.debug(f"Processing frame: {frame}")
|
||||
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
|
||||
@@ -177,18 +177,15 @@ class GladiaSTTService(STTService):
|
||||
return language_to_gladia_language(language)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
response = await self._setup_gladia()
|
||||
self._websocket = await websockets.connect(response["url"])
|
||||
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._send_stop_recording()
|
||||
await self._websocket.close()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._websocket.close()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
|
||||
@@ -652,8 +652,6 @@ class GoogleLLMService(LLMService):
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
context = None
|
||||
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
|
||||
@@ -286,8 +286,6 @@ class BaseOpenAILLMService(LLMService):
|
||||
)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
context = None
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
context: OpenAILLMContext = frame.context
|
||||
@@ -475,7 +473,6 @@ class OpenAIUserContextAggregator(LLMUserContextAggregator):
|
||||
super().__init__(context=context)
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
await super().process_frame(frame, direction)
|
||||
# Our parent method has already called push_frame(). So we can't interrupt the
|
||||
# flow here and we don't need to call push_frame() ourselves.
|
||||
try:
|
||||
@@ -516,7 +513,6 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
self._pending_image_frame_message = None
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
await super().process_frame(frame, direction)
|
||||
# See note above about not calling push_frame() here.
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
self._function_calls_in_progress.clear()
|
||||
|
||||
@@ -148,7 +148,6 @@ class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
|
||||
async def process_frame(
|
||||
self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM
|
||||
):
|
||||
await super().process_frame(frame, direction)
|
||||
# Parent does not push LLMMessagesUpdateFrame. This ensures that in a typical pipeline,
|
||||
# messages are only processed by the user context aggregator, which is generally what we want. But
|
||||
# we also need to send new messages over the websocket, so the openai realtime API has them
|
||||
|
||||
@@ -112,15 +112,12 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
#
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
#
|
||||
@@ -173,8 +170,6 @@ class OpenAIRealtimeBetaLLMService(LLMService):
|
||||
#
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
pass
|
||||
elif isinstance(frame, OpenAILLMContextFrame):
|
||||
|
||||
@@ -265,8 +265,6 @@ class PlayHTTTSService(TTSService):
|
||||
await self._connect_websocket()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# If we received a TTSSpeakFrame and the LLM response included text (it
|
||||
# might be that it's only a function calling response) we pause
|
||||
# processing more frames until we receive a BotStoppedSpeakingFrame.
|
||||
|
||||
@@ -187,17 +187,14 @@ class ParakeetSTTService(STTService):
|
||||
return False
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
self._thread_task = self.get_event_loop().create_task(self._thread_task_handler())
|
||||
self._response_task = self.get_event_loop().create_task(self._response_task_handler())
|
||||
self._response_queue = asyncio.Queue()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._stop_tasks()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._stop_tasks()
|
||||
|
||||
async def _stop_tasks(self):
|
||||
|
||||
@@ -92,7 +92,6 @@ class TavusVideoService(AIService):
|
||||
await self._send_audio_message(audio_base64, done=done)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
if isinstance(frame, TTSStartedFrame):
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
@@ -101,8 +101,6 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
|
||||
self._next_send_time = 0
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
await self._write_frame(frame)
|
||||
self._next_send_time = 0
|
||||
|
||||
@@ -139,8 +139,6 @@ class WebsocketServerOutputTransport(BaseOutputTransport):
|
||||
self._websocket = websocket
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
await self._write_frame(frame)
|
||||
self._next_send_time = 0
|
||||
|
||||
@@ -727,8 +727,6 @@ class DailyInputTransport(BaseInputTransport):
|
||||
#
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, UserImageRequestFrame):
|
||||
await self.request_participant_image(frame.user_id)
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ from pipecat.frames.frames import (
|
||||
AudioRawFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
OutputAudioRawFrame,
|
||||
StartFrame,
|
||||
@@ -334,12 +333,6 @@ class LiveKitInputTransport(BaseInputTransport):
|
||||
await self._client.disconnect()
|
||||
logger.info("LiveKitInputTransport stopped")
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
if isinstance(frame, EndFrame):
|
||||
await self.stop(frame)
|
||||
else:
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._client.disconnect()
|
||||
@@ -411,12 +404,6 @@ class LiveKitOutputTransport(BaseOutputTransport):
|
||||
await self._client.disconnect()
|
||||
logger.info("LiveKitOutputTransport stopped")
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
if isinstance(frame, EndFrame):
|
||||
await self.stop(frame)
|
||||
else:
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._client.disconnect()
|
||||
|
||||
Reference in New Issue
Block a user