diff --git a/src/dailyai/queue_frame.py b/src/dailyai/queue_frame.py index 7b4d1ca83..7e77ff89d 100644 --- a/src/dailyai/queue_frame.py +++ b/src/dailyai/queue_frame.py @@ -2,16 +2,14 @@ from enum import Enum from dataclasses import dataclass class FrameType(Enum): + NOOP = -1 START_STREAM = 0 END_STREAM = 1 AUDIO = 2 IMAGE = 3 - SENTENCE = 4 - TEXT_CHUNK = 5 - LLM_MESSAGE = 6 - APP_MESSAGE = 7 - IMAGE_DESCRIPTION = 8 - TRANSCRIPTION = 9 + TEXT = 4 + LLM_MESSAGE = 5 + APP_MESSAGE = 6 @dataclass(frozen=True) class QueueFrame: diff --git a/src/dailyai/services/ai_services.py b/src/dailyai/services/ai_services.py index 83c1c7099..9ad8e0399 100644 --- a/src/dailyai/services/ai_services.py +++ b/src/dailyai/services/ai_services.py @@ -21,12 +21,6 @@ class AIService: def stop(self): pass - def allowed_input_frame_types(self) -> set[FrameType]: - return set() - - def possible_output_frame_types(self) -> set[FrameType]: - return set() - async def run_to_queue(self, queue: asyncio.Queue, frames, add_end_of_stream=False) -> None: async for frame in self.run(frames): await queue.put(frame) @@ -39,76 +33,50 @@ class AIService: frames: Iterable[QueueFrame] | AsyncIterable[QueueFrame] | asyncio.Queue[QueueFrame], - requested_frame_types: set[FrameType] | None=None, ) -> AsyncGenerator[QueueFrame, None]: - if requested_frame_types and self.possible_output_frame_types().intersection(requested_frame_types) == set(): - raise Exception(f"Requested frame types {requested_frame_types} are not supported by this service.") + try: + if isinstance(frames, AsyncIterable): + async for frame in frames: + async for output_frame in self.process_frame(frame): + yield output_frame + elif isinstance(frames, Iterable): + for frame in frames: + async for output_frame in self.process_frame(frame): + yield output_frame + elif isinstance(frames, asyncio.Queue): + while True: + frame = await frames.get() + async for output_frame in self.process_frame(frame): + yield output_frame + if frame.frame_type == FrameType.END_STREAM: + break + else: + raise Exception("Frames must be an iterable or async iterable") - if not requested_frame_types: - requested_frame_types = self.possible_output_frame_types() - - if isinstance(frames, AsyncIterable): - async for frame in frames: - async for output_frame in self.process_frame(requested_frame_types, frame): - yield output_frame - elif isinstance(frames, Iterable): - for frame in frames: - async for output_frame in self.process_frame(requested_frame_types, frame): - yield output_frame - elif isinstance(frames, asyncio.Queue): - while True: - frame = await frames.get() - async for output_frame in self.process_frame(requested_frame_types, frame): - yield output_frame - if frame.frame_type == FrameType.END_STREAM: - break - else: - raise Exception("Frames must be an iterable or async iterable") + async for output_frame in self.finalize(): + yield output_frame + except Exception as e: + self.logger.error("Exception occurred while running AI service", e) + raise e @abstractmethod - async def process_frame(self, requested_frame_types:set[FrameType], frame:QueueFrame) -> AsyncGenerator[QueueFrame, None]: - # Yield something so the linter can deduce what should happen here. - yield QueueFrame(FrameType.END_STREAM, None) - -class SentenceAggregator(AIService): - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.current_sentence = "" - - def allowed_input_frame_types(self) -> set[FrameType]: - return set([FrameType.TEXT_CHUNK, FrameType.SENTENCE]) - - def possible_output_frame_types(self) -> set[FrameType]: - return set([FrameType.SENTENCE]) - - async def process_frame(self, requested_frame_types: set[FrameType], frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]: - if not FrameType.SENTENCE in requested_frame_types: - return - - if frame.frame_type == FrameType.TEXT_CHUNK: - if type(frame.frame_data) != str: - raise Exception( - "Sentence aggregator requires a string for the data field" - ) - - self.current_sentence += frame.frame_data - if self.current_sentence.endswith((".", "?", "!")): - sentence = self.current_sentence - self.current_sentence = "" - yield QueueFrame(FrameType.SENTENCE, sentence) - elif frame.frame_type == FrameType.END_STREAM: - if self.current_sentence: - yield QueueFrame(FrameType.SENTENCE, self.current_sentence) - elif frame.frame_type == FrameType.SENTENCE: - yield frame + async def process_frame(self, frame:QueueFrame) -> AsyncGenerator[QueueFrame, None]: + # This is a trick for the interpreter (and linter) to know that this is a generator. + if False: + yield QueueFrame(FrameType.NOOP, None) + @abstractmethod + async def finalize(self) -> AsyncGenerator[QueueFrame, None]: + # This is a trick for the interpreter (and linter) to know that this is a generator. + if False: + yield QueueFrame(FrameType.NOOP, None) class LLMService(AIService): def allowed_input_frame_types(self) -> set[FrameType]: - return set([FrameType.LLM_MESSAGE, FrameType.SENTENCE, FrameType.TRANSCRIPTION]) + return set([FrameType.LLM_MESSAGE]) def allowed_output_frame_types(self) -> set[FrameType]: - return set([FrameType.SENTENCE, FrameType.TEXT_CHUNK]) + return set([FrameType.TEXT]) @abstractmethod async def run_llm_async(self, messages) -> AsyncGenerator[str, None]: @@ -118,52 +86,58 @@ class LLMService(AIService): async def run_llm(self, messages) -> str: pass - async def process_frame(self, requested_frame_types: set[FrameType], frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]: + async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]: if frame.frame_type == FrameType.LLM_MESSAGE: if type(frame.frame_data) != list: raise Exception("LLM service requires a dict for the data field") messages: list[dict[str, str]] = frame.frame_data - if FrameType.SENTENCE in requested_frame_types: - yield QueueFrame(FrameType.SENTENCE, await self.run_llm(messages)) - else: - async for text_chunk in self.run_llm_async(messages): - yield QueueFrame(FrameType.TEXT_CHUNK, text_chunk) - - # TODO: handle other frame types! Need to aggregate into messages + async for text_chunk in self.run_llm_async(messages): + yield QueueFrame(FrameType.TEXT, text_chunk) class TTSService(AIService): + def __init__(self, aggregate_sentences=True): + super().__init__() + self.aggregate_sentences: bool = aggregate_sentences + self.current_sentence: str = "" + # Some TTS services require a specific sample rate. We default to 16k def get_mic_sample_rate(self): return 16000 - def allowed_input_frame_types(self) -> set[FrameType]: - return set([FrameType.SENTENCE, FrameType.TRANSCRIPTION, FrameType.TEXT_CHUNK]) - - def possible_output_frame_types(self) -> set[FrameType]: - return set([FrameType.AUDIO]) - - # Converts the sentence to audio. Yields a list of audio frames that can + # Converts the text to audio. Yields a list of audio frames that can # be sent to the microphone device @abstractmethod - async def run_tts(self, sentence) -> AsyncGenerator[bytes, None]: + async def run_tts(self, text) -> AsyncGenerator[bytes, None]: # yield empty bytes here, so linting can infer what this method does yield bytes() - async def process_frame(self, requested_frame_types: set[FrameType], frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]: - if not FrameType.AUDIO in requested_frame_types: - return + async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]: + if frame.frame_type != FrameType.TEXT or type(frame.frame_data) != str: + raise Exception(f"TTS service requires a string for the data field, got {frame.frame_type} and frame_data type {type(frame.frame_data)}") - if type(frame.frame_data) != str: - raise Exception("TTS service requires a string for the data field") + text: str | None = None + if not self.aggregate_sentences: + text = frame.frame_data + else: + self.current_sentence += frame.frame_data + if self.current_sentence.endswith((".", "?", "!")): + text = self.current_sentence + self.current_sentence = "" - async for audio_chunk in self.run_tts(frame.frame_data): - yield QueueFrame(FrameType.AUDIO, audio_chunk) + if text: + async for audio_chunk in self.run_tts(text): + yield QueueFrame(FrameType.AUDIO, audio_chunk) + + async def finalize(self): + if self.current_sentence: + async for audio_chunk in self.run_tts(self.current_sentence): + yield QueueFrame(FrameType.AUDIO, audio_chunk) # Convenience function to send the audio for a sentence to the given queue async def say(self, sentence, queue: asyncio.Queue): - await self.run_to_queue(queue, [QueueFrame(FrameType.SENTENCE, sentence)]) + await self.run_to_queue(queue, [QueueFrame(FrameType.TEXT, sentence)]) class ImageGenService(AIService): @@ -171,21 +145,12 @@ class ImageGenService(AIService): super().__init__(**kwargs) self.image_size = image_size - def allowed_input_frame_types(self) -> set[FrameType]: - return set([FrameType.SENTENCE, FrameType.TRANSCRIPTION, FrameType.TEXT_CHUNK, FrameType.IMAGE_DESCRIPTION]) - - def possible_output_frame_types(self) -> set[FrameType]: - return set([FrameType.IMAGE]) - # Renders the image. Returns an Image object. @abstractmethod async def run_image_gen(self, sentence) -> tuple[str, bytes]: pass - async def process_frame(self, requested_frame_types: set[FrameType], frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]: - if not FrameType.IMAGE in requested_frame_types: - return - + async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]: if type(frame.frame_data) != str: raise Exception("Image service requires a string for the data field") diff --git a/src/dailyai/services/daily_transport_service.py b/src/dailyai/services/daily_transport_service.py index 3e8cfda18..5d300cad3 100644 --- a/src/dailyai/services/daily_transport_service.py +++ b/src/dailyai/services/daily_transport_service.py @@ -279,7 +279,7 @@ class DailyTransportService(EventHandler): def on_transcription_message(self, message:dict): if self.loop: - frame = QueueFrame(FrameType.TRANSCRIPTION, message) + frame = QueueFrame(FrameType.TEXT, message) asyncio.run_coroutine_threadsafe(self.receive_queue.put(frame), self.loop) def on_transcription_stopped(self, stopped_by, stopped_by_error): diff --git a/src/dailyai/tests/test_ai_services.py b/src/dailyai/tests/test_ai_services.py index 6467442a1..bfe8dbde8 100644 --- a/src/dailyai/tests/test_ai_services.py +++ b/src/dailyai/tests/test_ai_services.py @@ -3,25 +3,19 @@ import unittest from typing import AsyncGenerator, Generator -from dailyai.services.ai_services import AIService, SentenceAggregator +from dailyai.services.ai_services import AIService from dailyai.queue_frame import QueueFrame, FrameType class SimpleAIService(AIService): - def allowed_input_frame_types(self) -> set[FrameType]: - return set([FrameType.TEXT_CHUNK]) - - def possible_output_frame_types(self) -> set[FrameType]: - return set([FrameType.TEXT_CHUNK]) - - async def process_frame(self, requested_frame_types: set[FrameType], frame: QueueFrame) -> QueueFrame | None: - return frame + async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]: + yield frame class TestBaseAIService(unittest.IsolatedAsyncioTestCase): async def test_async_input(self): service = SimpleAIService() input_frames = [ - QueueFrame(FrameType.TEXT_CHUNK, "hello"), + QueueFrame(FrameType.TEXT, "hello"), QueueFrame(FrameType.END_STREAM, None), ] async def iterate_frames() -> AsyncGenerator[QueueFrame, None]: @@ -29,7 +23,7 @@ class TestBaseAIService(unittest.IsolatedAsyncioTestCase): yield frame output_frames = [] - async for frame in service.run(set([FrameType.TEXT_CHUNK]), iterate_frames()): + async for frame in service.run(iterate_frames()): output_frames.append(frame) self.assertEqual(input_frames, output_frames) @@ -38,7 +32,7 @@ class TestBaseAIService(unittest.IsolatedAsyncioTestCase): service = SimpleAIService() input_frames = [ - QueueFrame(FrameType.TEXT_CHUNK, "hello"), + QueueFrame(FrameType.TEXT, "hello"), QueueFrame(FrameType.END_STREAM, None), ] @@ -47,83 +41,11 @@ class TestBaseAIService(unittest.IsolatedAsyncioTestCase): yield frame output_frames = [] - async for frame in service.run(set([FrameType.TEXT_CHUNK]), iterate_frames()): + async for frame in service.run(iterate_frames()): output_frames.append(frame) self.assertEqual(input_frames, output_frames) -class TestSentenceAggregator(unittest.IsolatedAsyncioTestCase): - async def test_clause(self) -> None: - input_frames = [ - QueueFrame(FrameType.TEXT_CHUNK, "hello"), - QueueFrame(FrameType.END_STREAM, None), - ] - - service = SentenceAggregator() - output_frames = [] - async for frame in service.run(set([FrameType.SENTENCE]), input_frames): - output_frames.append(frame) - - self.assertEqual(1, len(output_frames)) - self.assertEqual(QueueFrame(FrameType.SENTENCE, "hello"), output_frames[0]) - - async def test_sentence(self) -> None: - input_frames = [ - QueueFrame(FrameType.TEXT_CHUNK, "hello, "), - QueueFrame(FrameType.TEXT_CHUNK, "world."), - QueueFrame(FrameType.END_STREAM, None), - ] - - service = SentenceAggregator() - output_frames = [] - async for frame in service.run(set([FrameType.SENTENCE]), input_frames): - output_frames.append(frame) - - self.assertEqual(1, len(output_frames)) - self.assertEqual(QueueFrame(FrameType.SENTENCE, "hello, world."), output_frames[0]) - - async def test_sentence_and_clause(self) -> None: - input_frames = [ - QueueFrame(FrameType.TEXT_CHUNK, "hello, "), - QueueFrame(FrameType.TEXT_CHUNK, "world."), - QueueFrame(FrameType.TEXT_CHUNK, " How are"), - QueueFrame(FrameType.END_STREAM, None), - ] - - service = SentenceAggregator() - output_frames = [] - async for frame in service.run(set([FrameType.SENTENCE]), input_frames): - output_frames.append(frame) - - self.assertEqual(2, len(output_frames)) - self.assertEqual( - QueueFrame(FrameType.SENTENCE, "hello, world."), output_frames[0] - ) - self.assertEqual( - QueueFrame(FrameType.SENTENCE, " How are"), output_frames[1] - ) - - async def test_two_sentences(self) -> None: - input_frames = [ - QueueFrame(FrameType.TEXT_CHUNK, "hello, "), - QueueFrame(FrameType.TEXT_CHUNK, "world."), - QueueFrame(FrameType.TEXT_CHUNK, " How are"), - QueueFrame(FrameType.TEXT_CHUNK, " you doing?"), - QueueFrame(FrameType.END_STREAM, None), - ] - - service = SentenceAggregator() - output_frames = [] - async for frame in service.run(set([FrameType.SENTENCE]), input_frames): - output_frames.append(frame) - - self.assertEqual(2, len(output_frames)) - self.assertEqual( - QueueFrame(FrameType.SENTENCE, "hello, world."), output_frames[0] - ) - self.assertEqual(QueueFrame(FrameType.SENTENCE, " How are you doing?"), output_frames[1]) - - if __name__ == "__main__": unittest.main() diff --git a/src/samples/theoretical-to-real/02-llm-say-one-thing.py b/src/samples/theoretical-to-real/02-llm-say-one-thing.py index a2842315c..3ea0f2714 100644 --- a/src/samples/theoretical-to-real/02-llm-say-one-thing.py +++ b/src/samples/theoretical-to-real/02-llm-say-one-thing.py @@ -4,7 +4,6 @@ from typing import AsyncGenerator from dailyai.queue_frame import QueueFrame, FrameType from dailyai.services.daily_transport_service import DailyTransportService -from dailyai.services.ai_services import SentenceAggregator from dailyai.services.azure_ai_services import AzureLLMService from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService @@ -28,9 +27,7 @@ async def main(room_url): tts_task = asyncio.create_task( tts.run_to_queue( transport.send_queue, - SentenceAggregator().run( - llm.run([QueueFrame(FrameType.LLM_MESSAGE, messages)]) - ) + llm.run([QueueFrame(FrameType.LLM_MESSAGE, messages)]) ) ) diff --git a/src/samples/theoretical-to-real/03-still-frame.py b/src/samples/theoretical-to-real/03-still-frame.py index 79261214d..c69b12547 100644 --- a/src/samples/theoretical-to-real/03-still-frame.py +++ b/src/samples/theoretical-to-real/03-still-frame.py @@ -23,7 +23,7 @@ async def main(room_url): imagegen = OpenAIImageGenService(image_size="1024x1024") image_task = asyncio.create_task( - imagegen.run_to_queue(transport.send_queue, [QueueFrame(FrameType.IMAGE_DESCRIPTION, "a cat in the style of picasso")]) + imagegen.run_to_queue(transport.send_queue, [QueueFrame(FrameType.TEXT, "a cat in the style of picasso")]) ) @transport.event_handler("on_participant_joined") diff --git a/src/samples/theoretical-to-real/04-utterance-and-speech.py b/src/samples/theoretical-to-real/04-utterance-and-speech.py index 93130d0c3..296d17e6d 100644 --- a/src/samples/theoretical-to-real/04-utterance-and-speech.py +++ b/src/samples/theoretical-to-real/04-utterance-and-speech.py @@ -2,7 +2,6 @@ import argparse import asyncio import re -from dailyai.services.ai_services import SentenceAggregator from dailyai.services.daily_transport_service import DailyTransportService from dailyai.services.azure_ai_services import AzureLLMService, AzureTTSService from dailyai.queue_frame import QueueFrame, FrameType @@ -36,9 +35,7 @@ async def main(room_url:str): llm_response_task = asyncio.create_task( elevenlabs_tts.run_to_queue( buffer_queue, - SentenceAggregator().run( - llm.run([QueueFrame(FrameType.LLM_MESSAGE, messages)]) - ), + llm.run([QueueFrame(FrameType.LLM_MESSAGE, messages)]), True, ) ) @@ -48,10 +45,7 @@ async def main(room_url:str): if participant["id"] == transport.my_participant_id: return - await azure_tts.run_to_queue( - transport.send_queue, - [QueueFrame(FrameType.SENTENCE, "My friend the LLM is now going to tell a joke about llamas.")] - ) + await azure_tts.say("My friend the LLM is now going to tell a joke about llamas.", transport.send_queue) async def buffer_to_send_queue(): while True: diff --git a/src/samples/theoretical-to-real/05-sync-speech-and-text.py b/src/samples/theoretical-to-real/05-sync-speech-and-text.py index e38020cec..d589e778a 100644 --- a/src/samples/theoretical-to-real/05-sync-speech-and-text.py +++ b/src/samples/theoretical-to-real/05-sync-speech-and-text.py @@ -5,7 +5,6 @@ from asyncio.queues import Queue import re from dailyai.queue_frame import QueueFrame, FrameType -from dailyai.services.ai_services import SentenceAggregator from dailyai.services.azure_ai_services import AzureLLMService from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService from dailyai.services.open_ai_services import OpenAIImageGenService diff --git a/src/samples/theoretical-to-real/06-listen-and-respond.py b/src/samples/theoretical-to-real/06-listen-and-respond.py index d4bcf492f..ef34d1832 100644 --- a/src/samples/theoretical-to-real/06-listen-and-respond.py +++ b/src/samples/theoretical-to-real/06-listen-and-respond.py @@ -33,7 +33,7 @@ async def main(room_url:str, token): sentence = "" async for frame in transport.get_receive_frames(): - if frame.frame_type != FrameType.TRANSCRIPTION: + if frame.frame_type != FrameType.TEXT: continue message = frame.frame_data