diff --git a/src/dailyai/async_processor/async_processor.py b/src/dailyai/async_processor/async_processor.py index a4bf9a83a..192de7b0e 100644 --- a/src/dailyai/async_processor/async_processor.py +++ b/src/dailyai/async_processor/async_processor.py @@ -4,14 +4,14 @@ import re from collections import defaultdict from dataclasses import dataclass, field +from enum import Enum from queue import Queue, PriorityQueue, Empty from threading import Event, Semaphore, Thread -from typing import Any, Generator, Iterator, Optional, Type, TypedDict +from typing import Any, Generator, Iterator, Optional, Type -from dailyai.services.ai_services import AIServiceConfig +from dailyai.output_queue import OutputQueueFrame, FrameType from dailyai.message_handler.message_handler import MessageHandler - -frame_idx = 0 +from dailyai.services.ai_services import AIServiceConfig class AsyncProcessorState: # Setting class variables, other synchronous activities @@ -211,6 +211,9 @@ class AsyncProcessor: def do_finalization(self) -> None: pass +# A common class for responses that use a message queue and +# an output queue. + class OrchestratorResponse(AsyncProcessor): def __init__( @@ -265,10 +268,10 @@ class LLMResponse(OrchestratorResponse): if out.strip(): yield out.strip() - def get_frames_from_tts_response(self, audio_frame) -> list[dict[str, Any]]: - return [{"type": "audio_frame", "data": audio_frame}] + def get_frames_from_tts_response(self, audio_frame) -> list[OutputQueueFrame]: + return [OutputQueueFrame(FrameType.AUDIO_FRAME, audio_frame)] - def get_frames_from_chunk(self, chunk) -> Generator[list[dict[str, Any]], Any, None]: + def get_frames_from_chunk(self, chunk) -> Generator[list[OutputQueueFrame], Any, None]: for audio_frame in self.services.tts.run_tts(chunk): yield self.get_frames_from_tts_response(audio_frame) @@ -299,14 +302,13 @@ class LLMResponse(OrchestratorResponse): ]: break prepared_chunk = self.chunks_in_preparation.get() - if prepared_chunk[0] is None: + if prepared_chunk[0] == None: return self.play_prepared_chunk(prepared_chunk) def play_prepared_chunk(self, prepared_chunk) -> None: chunk, tts_generator = prepared_chunk - global frame_idx for frames in tts_generator: if self.state not in [ AsyncProcessorState.READY, @@ -315,14 +317,11 @@ class LLMResponse(OrchestratorResponse): break if not self.has_sent_first_frame: - self.output_queue.put({"type": "start_stream", "idx": frame_idx}) - frame_idx += 1 + self.output_queue.put(OutputQueueFrame(FrameType.START_STREAM, None)) self.has_sent_first_frame = True for frame in frames: - frame["idx"] = frame_idx self.output_queue.put(frame) - frame_idx += 1 self.output_queue.join() self.llm_responses.append(chunk) diff --git a/src/dailyai/orchestrator.py b/src/dailyai/orchestrator.py index 920af0422..9b7ed7dbb 100644 --- a/src/dailyai/orchestrator.py +++ b/src/dailyai/orchestrator.py @@ -4,6 +4,7 @@ import time import wave from dataclasses import dataclass +from enum import Enum from queue import Queue, Empty from opentelemetry import trace, context @@ -14,6 +15,7 @@ from dailyai.async_processor.async_processor import ( OrchestratorResponse, LLMResponse, ) +from dailyai.output_queue import OutputQueueFrame, FrameType from dailyai.services.ai_services import AIServiceConfig from dailyai.message_handler.message_handler import MessageHandler @@ -49,6 +51,7 @@ default_conversation_collection = ConversationProcessorCollection( goodbye=None, ) + class Orchestrator(EventHandler): def __init__( @@ -194,7 +197,7 @@ class Orchestrator(EventHandler): self.logger.info("Camera thread stopped") self.logger.info("Put stop in output queue") - self.output_queue.put({"type": "stop"}) + self.output_queue.put(OutputQueueFrame(FrameType.END_STREAM, None)) self.frame_consumer_thread.join() self.logger.info("Orchestrator stopped.") @@ -357,36 +360,18 @@ class Orchestrator(EventHandler): self.logger.info("🎬 Starting frame consumer thread") b = bytearray() smallest_write_size = 3200 - expected_idx = 0 all_audio_frames = bytearray() while True: try: - frame = self.output_queue.get() - if frame["type"] == "stop": + frame:OutputQueueFrame = self.output_queue.get() + if frame.frame_type == FrameType.END_STREAM: self.logger.info("Stopping frame consumer thread") - if os.getenv("WRITE_BOT_AUDIO", False): - filename = f"conversation-{len(all_audio_frames)}.wav" - with wave.open(filename, "wb") as f: - f.setnchannels(1) - f.setframerate(16000) - f.setsampwidth(2) - f.setcomptype("NONE", "not compressed") - f.writeframes(all_audio_frames) - return - - if frame["idx"] != expected_idx and frame["idx"] != 0: - self.logger.error( - f"🎬 Expected frame {expected_idx}, got {frame['idx']}" - ) - - expected_idx += 1 - # if interrupted, we just pull frames off the queue and discard them if not self.is_interrupted.is_set(): if frame: - if frame["type"] == "audio_frame": - chunk = frame["data"] + if frame.frame_type == FrameType.AUDIO_FRAME: + chunk = frame.frame_data all_audio_frames.extend(chunk) @@ -395,8 +380,8 @@ class Orchestrator(EventHandler): if l: self.mic.write_frames(bytes(b[:l])) b = b[l:] - elif frame["type"] == "image_frame": - self.set_image(frame["data"]) + elif frame.frame_type == FrameType.IMAGE_FRAME: + self.set_image(frame.frame_data) elif len(b): self.mic.write_frames(bytes(b)) b = bytearray() diff --git a/src/dailyai/output_queue.py b/src/dailyai/output_queue.py new file mode 100644 index 000000000..11b36b95e --- /dev/null +++ b/src/dailyai/output_queue.py @@ -0,0 +1,14 @@ +from enum import Enum +from dataclasses import dataclass + +class FrameType(Enum): + AUDIO_FRAME = 1 + IMAGE_FRAME = 2 + START_STREAM = 3 + END_STREAM = 4 + + +@dataclass(frozen=True) +class OutputQueueFrame: + frame_type: FrameType + frame_data: bytes diff --git a/src/dailyai/tests/test_asyncprocessor.py b/src/dailyai/tests/test_asyncprocessor.py index b4a7893b5..e61dd525f 100644 --- a/src/dailyai/tests/test_asyncprocessor.py +++ b/src/dailyai/tests/test_asyncprocessor.py @@ -5,18 +5,19 @@ from queue import Queue, Empty from threading import Thread, Event from typing import Generator -from dailyai.services.ai_services import ( - AIServiceConfig, - ImageGenService, - LLMService, - TTSService -) -from dailyai.message_handler.message_handler import MessageHandler from dailyai.async_processor.async_processor import ( AsyncProcessor, AsyncProcessorState, LLMResponse, ) +from dailyai.message_handler.message_handler import MessageHandler +from dailyai.output_queue import OutputQueueFrame, FrameType +from dailyai.services.ai_services import ( + AIServiceConfig, + ImageGenService, + LLMService, + TTSService, +) class MockTTSService(TTSService): def run_tts(self, sentence): @@ -70,10 +71,10 @@ class TestResponse(unittest.TestCase): output_queue.task_done() while expected_words: - actual_word = output_queue.get() + actual_word:OutputQueueFrame = output_queue.get() word = expected_words.pop(0) - self.assertEqual(actual_word['type'], 'audio_frame') - self.assertEqual(actual_word['data'], bytes(word, "utf-8")) + self.assertEqual(actual_word.frame_type, FrameType.AUDIO_FRAME) + self.assertEqual(actual_word.frame_data, bytes(word, "utf-8")) output_queue.task_done() processor.finalize() @@ -126,12 +127,12 @@ class TestResponse(unittest.TestCase): expected_words = ["Hello", "there.", "How", "are", "you?", "I", "hope", "you", "are", "well."] while expected_words and not stop_processing_output_queue.is_set(): try: - actual_word = output_queue.get_nowait() - if actual_word['type'] == 'audio_frame': + actual_word:OutputQueueFrame = output_queue.get_nowait() + if actual_word.frame_type == FrameType.AUDIO_FRAME: time.sleep(0.1) word = expected_words.pop(0) - self.assertEqual(actual_word['type'], 'audio_frame') - self.assertEqual(actual_word['data'], bytes(word, "utf-8")) + self.assertEqual(actual_word.frame_type, FrameType.AUDIO_FRAME) + self.assertEqual(actual_word.frame_data, bytes(word, "utf-8")) output_queue.task_done() except Empty: pass diff --git a/src/samples/static-sprite/sprite-sample.py b/src/samples/static-sprite/sprite-sample.py index aac0d3d1e..8ce589feb 100644 --- a/src/samples/static-sprite/sprite-sample.py +++ b/src/samples/static-sprite/sprite-sample.py @@ -8,14 +8,15 @@ from PIL import Image from dailyai.async_processor.async_processor import ( ConversationProcessorCollection, - LLMResponse + LLMResponse, + OrchestratorResponse ) from dailyai.orchestrator import OrchestratorConfig, Orchestrator from dailyai.message_handler.message_handler import MessageHandler from dailyai.services.ai_services import AIServiceConfig from dailyai.services.azure_ai_services import AzureImageGenService, AzureTTSService, AzureLLMService -class StaticSpriteResponse(LLMResponse): +class StaticSpriteResponse(OrchestratorResponse): def __init__( self, @@ -28,7 +29,8 @@ class StaticSpriteResponse(LLMResponse): self.filename = None # override this in subclasses def start_preparation(self) -> None: - full_path = os.path.join(os.path.dirname(__file__), "/sprites/") + full_path = os.path.join(os.path.dirname(__file__), "sprites/", self.filename) + print(full_path) with Image.open(full_path) as img: self.image_bytes = img.tobytes() @@ -43,6 +45,12 @@ class IntroSpriteResponse(StaticSpriteResponse): self.filename = "intro.png" +class WaitingSpriteResponse(StaticSpriteResponse): + def __init__(self, services, message_handler, output_queue) -> None: + super().__init__(services, message_handler, output_queue) + self.filename = "waiting.png" + + def add_bot_to_room(room_url, token, expiration) -> None: # A simple prompt for a simple sample. @@ -74,9 +82,9 @@ def add_bot_to_room(room_url, token, expiration) -> None: ) sprite_conversation_processors = ConversationProcessorCollection( - intro = IntroSpriteResponse, - waiting = WaitingSpriteResponse, - response = ResponseSpriteResponse, + introduction=IntroSpriteResponse, + waiting=WaitingSpriteResponse, + response=LLMResponse, ) orchestrator_config = OrchestratorConfig( @@ -90,6 +98,7 @@ def add_bot_to_room(room_url, token, expiration) -> None: orchestrator_config, services, message_handler, + sprite_conversation_processors ) orchestrator.start() diff --git a/src/samples/static-sprite/sprites/sc-listen-1.png b/src/samples/static-sprite/sprites/wait.png similarity index 100% rename from src/samples/static-sprite/sprites/sc-listen-1.png rename to src/samples/static-sprite/sprites/wait.png