diff --git a/src/dailyai/async_processor/__init__.py b/src/dailyai/async_processor/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/dailyai/async_processor/async_processor.py b/src/dailyai/async_processor/async_processor.py deleted file mode 100644 index 4acd6cc46..000000000 --- a/src/dailyai/async_processor/async_processor.py +++ /dev/null @@ -1,347 +0,0 @@ -import json -import logging -import re - -from collections import defaultdict -from dataclasses import dataclass, field -from enum import Enum -from queue import Queue, PriorityQueue, Empty -from threading import Event, Semaphore, Thread -from typing import Any, Generator, Iterator, Optional, Type - -from dailyai.queue_frame import QueueFrame, FrameType -from dailyai.message_handler.message_handler import MessageHandler -from dailyai.services.ai_services import AIServiceConfig - -class AsyncProcessorState: - # Setting class variables, other synchronous activities - INIT = 0 - - # Making asynchronous requests to LLM and other services to render response - PREPARING = 1 - - # Ready to start presenting to user (but may not have all data yet) - READY = 2 - - # Playing response - PLAYING = 3 - - # An interrupt has been requested and the response is shutting down in-flight processing - INTERRUPTING = 4 - - # An interrupt has been requested and the response is finished stopping in-flight processing - INTERRUPTED = 5 - - # Response has been played or interrupted - DONE = 6 - - # Response is being finalized (updating records of speech, updating LLM context, etc.) - FINALIZING = 7 - - # Response is complete. This could mean that everything is updated, or that the response - # was interrupted. - FINALIZED = 8 - - state_transitions = { - INIT: [PREPARING, INTERRUPTING], - PREPARING: [READY, INTERRUPTING], - READY: [PLAYING, INTERRUPTING], - PLAYING: [DONE, INTERRUPTING], - INTERRUPTING: [INTERRUPTED], - INTERRUPTED: [DONE], - DONE: [FINALIZING], - FINALIZING: [FINALIZED], - FINALIZED: [FINALIZED], - } - - -@dataclass(order=True) -class StateTransitionItem: - state: int - evt: Event = field(compare=False) - -class AsyncProcessor: - def __init__( - self, - services: AIServiceConfig - ) -> None: - self.state = AsyncProcessorState.INIT - self.prepare_thread = None - self.play_thread = None - self.finalize_thread = None - - self.services: AIServiceConfig = services - - self.state_transition_semaphore = Semaphore() - self.waiting_for_state_changes = PriorityQueue() - self.state_queue = Queue() - - self.state_change_callbacks = defaultdict(list) - - self.was_interrupted = False - - self.logger: logging.Logger = logging.getLogger("dailyai") - - def set_state(self, state: int) -> None: - if state in AsyncProcessorState.state_transitions[self.state]: - self.state_transition_semaphore.acquire() - - self.state: int = state - self.state_transition_semaphore.release() - - # wake up any threads waiting for this state transition - try: - while True: - waiter = self.waiting_for_state_changes.get_nowait() - if waiter.state <= state: - waiter.evt.set() - else: - self.waiting_for_state_changes.put(waiter) - break - except Empty: - pass - - # make all the callbacks for this state - for callback in self.state_change_callbacks[state]: - callback(self) - else: - self.logger.error( - f"Invalid state transition from {self.state} to {state} in {self.__class__.__name__}" - ) - raise Exception(f"Invalid state transition from {self.state} to {state}") - - # - # This is used for state transitions that could be blocked by an interruption. - # If we are interrupted, we silently fail this call. Use only if you know that - # this state transition should fail if the processor has been interrupted. - # - - def maybe_set_state(self, state: int) -> bool: - if state in AsyncProcessorState.state_transitions[self.state]: - self.set_state(state) - return True - else: - return False - - def wait_for_state_transition(self, state: int) -> None: - if self.state >= state: - return - - self.state_transition_semaphore.acquire() - - evt = Event() - self.waiting_for_state_changes.put(StateTransitionItem(state, evt)) - self.state_transition_semaphore.release() - result = evt.wait(120.0) - if not result: - self.logger.error( - f"Timed out waiting for state transition to {state} from {self.state}" - ) - - def set_state_callback(self, state: int, callback: callable) -> None: - self.state_change_callbacks[state].append(callback) - - def prepare(self) -> None: - self.prepare_thread = Thread(target=self.async_prepare, daemon=True) - self.prepare_thread.start() - self.wait_for_state_transition(AsyncProcessorState.READY) - - def play(self) -> None: - self.wait_for_state_transition(AsyncProcessorState.READY) - self.play_thread = Thread(target=self.async_play, daemon=True) - self.play_thread.start() - self.wait_for_state_transition(AsyncProcessorState.PLAYING) - - def finalize(self) -> None: - # don't finalize until we're done playing. - self.wait_for_state_transition(AsyncProcessorState.DONE) - self.set_state(AsyncProcessorState.FINALIZING) - self.do_finalization() - self.set_state(AsyncProcessorState.FINALIZED) - - def interrupt(self) -> None: - # nothing to interrupt if we're already finalizing or finalized, no-op - if self.state in [ - AsyncProcessorState.FINALIZING, - AsyncProcessorState.FINALIZED, - ]: - return - - self.set_state(AsyncProcessorState.INTERRUPTING) - self.was_interrupted = True - self.do_interruption() - self.set_state(AsyncProcessorState.INTERRUPTED) - self.set_state(AsyncProcessorState.DONE) - - def async_play(self) -> None: - self.logger.info(f"Starting to play") - if self.maybe_set_state(AsyncProcessorState.PLAYING): - self.do_play() - self.maybe_set_state(AsyncProcessorState.DONE) - - def async_prepare(self) -> None: - self.set_state(AsyncProcessorState.PREPARING) - self.start_preparation() - self.set_state(AsyncProcessorState.READY) - self.continue_preparation() - self.logger.info(f"Preparation done for {self.__class__.__name__}") - self.preparation_done() - - def start_preparation(self) -> None: - pass - - def continue_preparation(self) -> None: - pass - - def preparation_done(self): - pass - - def get_preparation_iterator(self) -> Iterator: - yield None - - def process_chunk(self, chunk) -> None: - pass - - def do_interruption(self) -> None: - pass - - def do_play(self) -> None: - pass - - def do_finalization(self) -> None: - pass - -# A common class for responses that use a message queue and -# an output queue. - -class OrchestratorResponse(AsyncProcessor): - - def __init__( - self, - services, - message_handler, - output_queue, - ) -> None: - super().__init__(services) - - self.message_handler: MessageHandler = message_handler - self.output_queue: Queue = output_queue - - -class LLMResponse(OrchestratorResponse): - def __init__( - self, - services, - message_handler, - output_queue, - ) -> None: - super().__init__(services, message_handler, output_queue) - - self.has_sent_first_frame = False - - self.chunks_in_preparation = Queue() - - self.llm_responses: list[str] = [] - - def get_preparation_iterator(self) -> Iterator: - messages_for_llm = self.message_handler.get_llm_messages() - self.logger.debug(f"Messages for llm: {json.dumps(messages_for_llm, indent=2)}") - return self.clauses_from_chunks( - self.services.llm.run_llm_async(messages_for_llm) - ) - - def clauses_from_chunks(self, chunks) -> Iterator: - out = "" - for chunk in chunks: - if self.state not in [ - AsyncProcessorState.READY, - AsyncProcessorState.PLAYING, - ]: - break - - out += chunk - - if re.match(r"^.*[.!?]$", out): # it looks like a sentence - yield out.strip() - out = "" - - if out.strip(): - yield out.strip() - - def get_frames_from_tts_response(self, audio_frame) -> list[QueueFrame]: - return [QueueFrame(FrameType.AUDIO, audio_frame)] - - def get_frames_from_chunk(self, chunk) -> Generator[list[QueueFrame], Any, None]: - for audio_frame in self.services.tts.run_tts(chunk): - yield self.get_frames_from_tts_response(audio_frame) - - def start_preparation(self) -> None: - self.preparation_iterator = self.get_preparation_iterator() - - def continue_preparation(self) -> None: - for chunk in self.preparation_iterator: - if self.state not in [ - AsyncProcessorState.READY, - AsyncProcessorState.PLAYING, - ]: - break - - self.process_chunk(chunk) - - def process_chunk(self, chunk) -> None: - self.chunks_in_preparation.put((chunk, self.get_frames_from_chunk(chunk))) - - def preparation_done(self): - self.chunks_in_preparation.put((None, None)) - - def do_play(self) -> None: - while True: - if self.state not in [ - AsyncProcessorState.READY, - AsyncProcessorState.PLAYING, - ]: - break - prepared_chunk = self.chunks_in_preparation.get() - if prepared_chunk[0] == None: - return - - self.play_prepared_chunk(prepared_chunk) - - def play_prepared_chunk(self, prepared_chunk) -> None: - chunk, tts_generator = prepared_chunk - for frames in tts_generator: - if self.state not in [ - AsyncProcessorState.READY, - AsyncProcessorState.PLAYING, - ]: - break - - if not self.has_sent_first_frame: - self.output_queue.put(QueueFrame(FrameType.START_STREAM, None)) - self.has_sent_first_frame = True - - for frame in frames: - self.output_queue.put(frame) - - self.output_queue.join() - self.llm_responses.append(chunk) - - def do_finalization(self) -> None: - self.message_handler.add_assistant_messages(self.llm_responses) - - def do_interruption(self) -> None: - self.chunks_in_preparation.put((None, None)) - - if self.prepare_thread and self.prepare_thread.is_alive(): - self.prepare_thread.join() - - if self.play_thread and self.play_thread.is_alive(): - self.play_thread.join() - - -@dataclass(frozen=True) -class ConversationProcessorCollection: - introduction: Optional[Type[OrchestratorResponse]] = None - waiting: Optional[Type[OrchestratorResponse]] = None - response: Optional[Type[OrchestratorResponse]] = None - goodbye: Optional[Type[OrchestratorResponse]] = None diff --git a/src/dailyai/conversation_wrappers.py b/src/dailyai/conversation_wrappers.py new file mode 100644 index 000000000..79d751f36 --- /dev/null +++ b/src/dailyai/conversation_wrappers.py @@ -0,0 +1,76 @@ +import asyncio +import copy +import functools +from typing import AsyncGenerator, Awaitable, Callable +from dailyai.queue_aggregators import LLMContextAggregator +from dailyai.queue_frame import EndStreamQueueFrame, QueueFrame, TranscriptionQueueFrame + +class InterruptibleConversationWrapper: + + def __init__( + self, + frame_generator: Callable[[], AsyncGenerator[QueueFrame, None]], + runner: Callable[ + [str, LLMContextAggregator, LLMContextAggregator], Awaitable[None] + ], + interrupt: Callable[[], None], + my_participant_id: str|None, + llm_messages: list[dict[str, str]], + llm_context_aggregator_in=LLMContextAggregator, + llm_context_aggregator_out=LLMContextAggregator, + delay_before_speech_seconds: float = 1.0, + ): + self._frame_generator: Callable[[], AsyncGenerator[QueueFrame, None]] = frame_generator + self._runner: Callable[ + [str, LLMContextAggregator, LLMContextAggregator], Awaitable[None] + ] = runner + self._interrupt: Callable[[], None] = interrupt + self._my_participant_id = my_participant_id + self._messages: list[dict[str, str]] = llm_messages + self._delay_before_speech_seconds = delay_before_speech_seconds + self._llm_context_aggregator_in = llm_context_aggregator_in + self._llm_context_aggregator_out = llm_context_aggregator_out + + self._current_phrase = "" + + def update_messages(self, new_messages: list[dict[str, str]], task: asyncio.Task | None): + if task: + if not task.cancelled(): + self._current_phrase = "" + self._messages = new_messages + + async def speak_after_delay(self, user_speech, messages): + await asyncio.sleep(self._delay_before_speech_seconds) + tma_in = self._llm_context_aggregator_in( + messages, "user", self._my_participant_id, False + ) + tma_out = self._llm_context_aggregator_out( + messages, "assistant", self._my_participant_id + ) + + await self._runner(user_speech, tma_in, tma_out) + + async def run_conversation(self): + current_response_task = None + + async for frame in self._frame_generator(): + if isinstance(frame, EndStreamQueueFrame): + break + elif not isinstance(frame, TranscriptionQueueFrame): + continue + + if frame.participantId == self._my_participant_id: + continue + + if current_response_task: + current_response_task.cancel() + self._interrupt() + + self._current_phrase += " " + frame.text + current_llm_messages = copy.deepcopy(self._messages) + current_response_task = asyncio.create_task( + self.speak_after_delay(self._current_phrase, current_llm_messages) + ) + current_response_task.add_done_callback( + functools.partial(self.update_messages, current_llm_messages) + ) diff --git a/src/dailyai/message_handler/__init__.py b/src/dailyai/message_handler/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/dailyai/message_handler/message_handler.py b/src/dailyai/message_handler/message_handler.py deleted file mode 100644 index b9570a016..000000000 --- a/src/dailyai/message_handler/message_handler.py +++ /dev/null @@ -1,127 +0,0 @@ -import logging -import time - -from dataclasses import dataclass -from queue import Queue, Empty -from threading import Thread - -from dailyai.storage.search import SearchIndexer -from dailyai.services.ai_services import AIServiceConfig - - -@dataclass -class Message: - type: str - timestamp: float - message: str - - -class MessageHandler: - def __init__(self, intro): - self.messages: list[Message] = [Message("system", time.time(), intro)] - self.last_user_message_idx:int | None = None - self.finalized_user_message_idx: int | None = None - - def add_user_message(self, message) -> None: - if self.last_user_message_idx is not None and self.last_user_message_idx != self.finalized_user_message_idx: - previous_message: str = self.messages[self.last_user_message_idx].message - self.messages[self.last_user_message_idx] = Message( - "user", time.time(), ' '.join([previous_message, message]) - ) - self.messages = self.messages[: self.last_user_message_idx + 1] - else: - self.messages.append(Message("user", time.time(), message)) - - self.last_user_message_idx = len(self.messages) - 1 - - def add_assistant_message(self, message) -> None: - if self.messages[-1].type == "assistant": - self.messages[-1].message += " " + message - else: - self.messages.append(Message("assistant", time.time(), message)) - - def add_assistant_messages(self, messages) -> None: - self.messages.append(Message("assistant", time.time(), " ".join(messages))) - - def get_llm_messages(self) -> list[dict[str, str]]: - return [{"role": m.type, "content": m.message} for m in self.messages] - - def finalize_user_message(self) -> None: - self.finalized_user_message_idx = self.last_user_message_idx - - def shutdown(self) -> None: - pass - -class IndexingMessageHandler(MessageHandler): - def __init__( - self, intro, services: AIServiceConfig, indexer: SearchIndexer - ) -> None: - super().__init__(intro) - self.services = services - - self.search_indexer = indexer - - self.last_written_idx = 0 - self.storage_message_queue = Queue() - - self.index_writer_thread = Thread(target=self.storage_writer, daemon=True) - self.index_writer_thread.start() - - self.logger = logging.getLogger("dailyai") - - def shutdown(self): - self.finalize_user_message() - self.storage_message_queue.put(None) - self.index_writer_thread.join() - - def storage_writer(self) -> None: - while True: - try: - message_idx = self.storage_message_queue.get() - self.storage_message_queue.task_done() - - if message_idx is None: - return - - if message_idx <= self.last_written_idx: - continue - - self.last_written_idx = message_idx - - message = self.messages[message_idx] - content = message.message - if message.type == "user": - content = self.cleanup_user_message(content) - - # sometimes the LLM returns a string wrapped in quotes and sometimes it doesn't. - # if it didn't, wrap it in quotes - if content[0] != '"': - content = '"' + content + '"' - - self.search_indexer.index_text(content) - except Empty: - pass - - def cleanup_user_message(self, user_message) -> str: - return user_message - - def finalize_user_message(self): - super().finalize_user_message() - self.write_messages_to_storage() - - def write_messages_to_storage(self): - if self.finalized_user_message_idx is None: - return - - for idx in range(self.last_written_idx, len(self.messages)): - self.logger.info( - f"Writing to storage: {self.messages[idx].type} {self.messages[idx].message}" - ) - if ( - self.messages[idx].type == "user" - and idx > self.finalized_user_message_idx - ): - break - - if self.messages[idx].type != "system": - self.storage_message_queue.put(idx) diff --git a/src/dailyai/orchestrator.py b/src/dailyai/orchestrator.py deleted file mode 100644 index dcae4a78d..000000000 --- a/src/dailyai/orchestrator.py +++ /dev/null @@ -1,409 +0,0 @@ -import logging -import os -import time -import wave - -from dataclasses import dataclass -from enum import Enum -from queue import Queue, Empty -from opentelemetry import trace, context - -from dailyai.async_processor.async_processor import ( - AsyncProcessor, - AsyncProcessorState, - ConversationProcessorCollection, - OrchestratorResponse, - LLMResponse, -) -from dailyai.queue_frame import QueueFrame, FrameType -from dailyai.services.ai_services import AIServiceConfig -from dailyai.message_handler.message_handler import MessageHandler - -from threading import Thread, Semaphore, Event, Timer - -from opentelemetry import context -from opentelemetry.context.context import Context - -from daily import ( - EventHandler, - CallClient, - Daily, - VirtualCameraDevice, - VirtualMicrophoneDevice, - VirtualSpeakerDevice, -) - - -@dataclass -class OrchestratorConfig: - room_url: str - token: str - bot_name: str - expiration: float - -# Note that we use this as a default parameter value in the Orchestrator -# constructor. The dataclass is defined with Frozen=True, so this should -# be safe. -default_conversation_collection = ConversationProcessorCollection( - introduction=LLMResponse, - waiting=None, - response=LLMResponse, - goodbye=None, -) - - -class Orchestrator(EventHandler): - - def __init__( - self, - daily_config: OrchestratorConfig, - ai_service_config: AIServiceConfig, - message_handler: MessageHandler, - conversation_processors: ConversationProcessorCollection = default_conversation_collection, - tracer=None, - ): - self.bot_name: str = daily_config.bot_name - self.room_url: str = daily_config.room_url - self.token: str = daily_config.token - self.expiration: float = daily_config.expiration - - self.logger: logging.Logger = logging.getLogger("dailyai") - self.tracer = tracer or trace.get_tracer("orchestrator") - - self.ctx: Context = context.get_current() - - self.transcription = "" - self.last_fragment_at = None - self.talked_at = None - self.paused_at = None - - self.logger.info(f"Creating Response for introductions") - self.services: AIServiceConfig = ai_service_config - self.output_queue = Queue() - self.is_interrupted = Event() - self.stop_threads = Event() - self.story_started = False - - self.message_handler = message_handler - self.conversation_processors: ConversationProcessorCollection = conversation_processors - - if conversation_processors.introduction is not None: - intro = conversation_processors.introduction( - services=self.services, message_handler=self.message_handler, output_queue=self.output_queue - ) - intro.prepare() - intro.set_state_callback(AsyncProcessorState.DONE, self.on_intro_played) - intro.set_state_callback(AsyncProcessorState.FINALIZED, self.on_intro_finished) - self.logger.info(f"Introduction is preparing") - - self.current_response: AsyncProcessor = intro - self.can_interrupt = False - # self.response_event.set() - self.response_semaphore = Semaphore() - - self.speech_timeout = None - self.interrupt_time = None - - self.logger.info("Configuring daily") - self.configure_daily() - - def configure_daily(self): - Daily.init() - self.client = CallClient(event_handler=self) - - self.logger.info(f"Mic sample rate: {self.services.tts.get_mic_sample_rate()}") - self.mic: VirtualMicrophoneDevice = Daily.create_microphone_device( - "mic", sample_rate=self.services.tts.get_mic_sample_rate(), channels=1 - ) - self.speaker: VirtualSpeakerDevice = Daily.create_speaker_device( - "speaker", sample_rate=16000, channels=1 - ) - self.camera: VirtualCameraDevice = Daily.create_camera_device( - "camera", width=720, height=1280, color_format="RGB" - ) - - Daily.select_speaker_device("speaker") - - self.client.set_user_name(self.bot_name) - self.client.join(self.room_url, self.token, completion=self.call_joined) - - self.client.update_inputs( - { - "camera": { - "isEnabled": True, - "settings": { - "deviceId": "camera", - }, - }, - "microphone": { - "isEnabled": True, - "settings": { - "deviceId": "mic", - "customConstraints": { - "autoGainControl": {"exact": False}, - "echoCancellation": {"exact": False}, - "noiseSuppression": {"exact": False}, - }, - }, - }, - } - ) - - self.client.update_publishing( - { - "camera": { - "sendSettings": { - "maxQuality": "low", - "encodings": { - "low": { - "maxBitrate": 250000, - "scaleResolutionDownBy": 1.333, - "maxFramerate": 8, - } - }, - } - } - } - ) - - self.my_participant_id = self.client.participants()["local"]["id"] - - def start(self) -> None: - # TODO: this loop could, I think, be replaced with a timer and an event - self.participant_left = False - - try: - participant_count: int = len(self.client.participants()) - self.logger.info(f"{participant_count} participants in room") - while time.time() < self.expiration and not self.participant_left: - # all handling of incoming transcriptions happens in on_transcription_message - time.sleep(1) - except Exception as e: - self.logger.error(f"Exception {e}") - finally: - self.client.leave() - - def stop(self): - self.logger.info("Stop current response") - if self.current_response: - if self.current_response.state < AsyncProcessorState.INTERRUPTED: - self.current_response.interrupt() - - self.logger.info("Wait for state transition") - self.current_response.wait_for_state_transition(AsyncProcessorState.FINALIZED) - - self.stop_threads.set() - self.camera_thread.join() - self.logger.info("Camera thread stopped") - - self.logger.info("Put stop in output queue") - self.output_queue.put(QueueFrame(FrameType.END_STREAM, None)) - - self.frame_consumer_thread.join() - self.logger.info("Orchestrator stopped.") - - def on_intro_played(self, intro): - self.logger.info(f"Introduction has played") - self.can_interrupt = True - intro.finalize() - - def on_intro_finished(self, intro): - self.logger.info(f"Introduction has finished") - waiting = self.conversation_processors.waiting(self.services, self.message_handler, self.output_queue) - waiting.prepare() - waiting.play() - - def on_response_played(self, response): - response.finalize() - - def on_response_finished(self, response): - if not response.was_interrupted: - self.message_handler.finalize_user_message() - - def call_joined(self, join_data, client_error): - self.logger.info(f"Call_joined: {join_data}, {client_error}") - self.client.start_transcription( - { - "language": "en", - "tier": "nova", - "model": "2-conversationalai", - "profanity_filter": True, - "redact": False, - "extra": { - "endpointing": True, - "punctuate": False, - } - } - ) - - def on_participant_joined(self, participant): - with self.tracer.start_as_current_span("on_participant_joined", context=self.ctx): - self.logger.info(f"on_participant_joined: {participant}") - - # TODO: figure out the architecture to get the story id to the client - # self.client.send_app_message({"event": "story-id", "storyID": self.story_id}) - time.sleep(2) - - if not self.story_started: - self.action() - self.story_started = True - - def on_participant_left(self, participant, reason): - self.logger.info(f"Participant {participant} left") - if len(self.client.participants()) < 2: - self.participant_left = True - - def on_app_message(self, message, sender): - with self.tracer.start_as_current_span("on_app_message", context=self.ctx): - self.logger.info(f"on_app_message {message} from {sender}") - if "isSpeaking" in message and message["isSpeaking"] == True: - self.handle_user_started_talking() - - if "isSpeaking" in message and message["isSpeaking"] == False: - self.handle_user_stopped_talking() - - def on_transcription_message(self, message): - with self.tracer.start_as_current_span("on_transcription_message", context=self.ctx): - if message["session_id"] != self.my_participant_id: - self.handle_transcription_fragment(message['text']) - - def on_transcription_stopped(self, stopped_by, stopped_by_error): - self.logger.info(f"Transcription stopped {stopped_by}, {stopped_by_error}") - - def on_transcription_error(self, message): - self.logger.error(f"Transcription error {message}") - - def on_transcription_started(self, status): - self.logger.info(f"Transcription started {status}") - - def set_image(self, image: bytes): - self.image: bytes | None = image - - def run_camera(self): - try: - while not self.stop_threads.is_set(): - if self.image: - self.camera.write_frame(self.image) - - time.sleep(1.0 / 8.0) # 8 fps - except Exception as e: - self.logger.error(f"Exception {e} in camera thread.") - - def handle_user_started_talking(self): - # TODO: allow configuration of the timer timeout - self.logger.error("user started talking") - self.speech_timeout = Timer(1.0, self.utterance_interrupt) - - def handle_user_stopped_talking(self): - self.logger.error("user stopped talking, canceling utterance interrupt") - if self.speech_timeout: - self.speech_timeout.cancel() - - def utterance_interrupt(self): - self.logger.error("utterance interrupt") - self.is_interrupted.set() - - def handle_transcription_fragment(self, fragment): - if not self.can_interrupt: - return - - # start generating a new response. We'll do the fast parts of the interrupt - # now but wait for the state transition after we've kicked off the prepare - # on the new response. - if ( - self.current_response - and self.current_response.state < AsyncProcessorState.INTERRUPTED - ): - self.interrupt_time = time.perf_counter() - self.is_interrupted.set() - self.current_response.interrupt() - - self.message_handler.add_user_message(fragment) - - response_type: type[OrchestratorResponse] | type[LLMResponse] = self.conversation_processors.response or LLMResponse - new_response: OrchestratorResponse = response_type( - self.services, self.message_handler, self.output_queue - ) - new_response.set_state_callback( - AsyncProcessorState.DONE, self.on_response_played - ) - new_response.set_state_callback( - AsyncProcessorState.FINALIZED, self.on_response_finished - ) - new_response.prepare() - - self.response_semaphore.acquire() - if ( - self.current_response - and self.current_response.state < AsyncProcessorState.INTERRUPTED - ): - self.current_response.wait_for_state_transition( - AsyncProcessorState.FINALIZED - ) - - self.current_response = new_response - self.current_response.play() - - self.response_semaphore.release() - - def action(self): - self.logger.info("Starting camera thread") - self.image: bytes | None = None - self.camera_thread = Thread(target=self.run_camera, daemon=True) - self.camera_thread.start() - - self.logger.info("Starting frame consumer thread") - self.frame_consumer_thread = Thread(target=self.frame_consumer, daemon=True) - self.frame_consumer_thread.start() - - self.logger.info("Playing introduction") - self.can_interrupt = False - self.current_response.play() - - def frame_consumer(self): - self.logger.info("🎬 Starting frame consumer thread") - b = bytearray() - smallest_write_size = 3200 - all_audio_frames = bytearray() - while True: - try: - frame:QueueFrame = self.output_queue.get() - if frame.frame_type == FrameType.END_STREAM: - self.logger.info("Stopping frame consumer thread") - return - - # if interrupted, we just pull frames off the queue and discard them - if not self.is_interrupted.is_set(): - if frame: - if frame.frame_type == FrameType.AUDIO: - chunk = frame.frame_data - - all_audio_frames.extend(chunk) - - b.extend(chunk) - l = len(b) - (len(b) % smallest_write_size) - if l: - self.mic.write_frames(bytes(b[:l])) - b = b[l:] - elif frame.frame_type == FrameType.IMAGE: - self.set_image(frame.frame_data) - elif len(b): - self.mic.write_frames(bytes(b)) - b = bytearray() - else: - if self.interrupt_time: - self.logger.info(f"Lag to stop stream after interruption {time.perf_counter() - self.interrupt_time}") - self.interrupt_time = None - - if frame.frame_type == FrameType.START_STREAM: - self.is_interrupted.clear() - - self.output_queue.task_done() - except Empty: - try: - if len(b): - self.mic.write_frames(bytes(b)) - except Exception as e: - self.logger.error(f"Exception in frame_consumer: {e}, {len(b)}") - - b = bytearray() diff --git a/src/dailyai/queue_aggregators.py b/src/dailyai/queue_aggregators.py index ee6ca34bb..659af1cd7 100644 --- a/src/dailyai/queue_aggregators.py +++ b/src/dailyai/queue_aggregators.py @@ -25,23 +25,24 @@ class QueueTee: await queue.put(frame) class LLMContextAggregator(AIService): - def __init__(self, messages: list[dict], role:str, bot_participant_id=None): + def __init__(self, messages: list[dict], role:str, bot_participant_id=None, complete_sentences=True): self.messages = messages self.bot_participant_id = bot_participant_id self.role = role self.sentence = "" + self.complete_sentences = complete_sentences async def process_frame(self, frame:QueueFrame) -> AsyncGenerator[QueueFrame, None]: - content: str = "" - # TODO: split up transcription by participant if isinstance(frame, TextQueueFrame): - content = frame.text - - self.sentence += content - if self.sentence.endswith((".", "?", "!")): - self.messages.append({"role": self.role, "content": self.sentence}) - self.sentence = "" - yield LLMMessagesQueueFrame(self.messages) + if self.complete_sentences: + self.sentence += frame.text + if self.sentence.endswith((".", "?", "!")): + self.messages.append({"role": self.role, "content": self.sentence}) + self.sentence = "" + yield LLMMessagesQueueFrame(self.messages) + else: + self.messages.append({"role": self.role, "content": frame.text}) + yield LLMMessagesQueueFrame(self.messages) yield frame diff --git a/src/dailyai/queue_frame.py b/src/dailyai/queue_frame.py index d72345aaf..81b391f36 100644 --- a/src/dailyai/queue_frame.py +++ b/src/dailyai/queue_frame.py @@ -5,10 +5,13 @@ from typing import Any class QueueFrame: pass -class StartStreamQueueFrame(QueueFrame): +class ControlQueueFrame(QueueFrame): pass -class EndStreamQueueFrame(QueueFrame): +class StartStreamQueueFrame(ControlQueueFrame): + pass + +class EndStreamQueueFrame(ControlQueueFrame): pass @dataclass() diff --git a/src/dailyai/services/ai_services.py b/src/dailyai/services/ai_services.py index 24c676709..88b84bea7 100644 --- a/src/dailyai/services/ai_services.py +++ b/src/dailyai/services/ai_services.py @@ -7,6 +7,7 @@ import wave from dailyai.queue_frame import ( AudioQueueFrame, + ControlQueueFrame, EndStreamQueueFrame, ImageQueueFrame, LLMMessagesQueueFrame, @@ -28,11 +29,15 @@ class AIService: pass async def run_to_queue(self, queue: asyncio.Queue, frames, add_end_of_stream=False) -> None: - async for frame in self.run(frames): - await queue.put(frame) + try: + async for frame in self.run(frames): + await queue.put(frame) - if add_end_of_stream: - await queue.put(EndStreamQueueFrame()) + if add_end_of_stream: + await queue.put(EndStreamQueueFrame()) + except Exception as e: + print("Exception in run_to_queue", e) + raise e async def run( self, @@ -67,9 +72,8 @@ class AIService: @abstractmethod async def process_frame(self, frame:QueueFrame) -> AsyncGenerator[QueueFrame, None]: - # This is a trick for the interpreter (and linter) to know that this is a generator. - if False: - yield QueueFrame() + if isinstance(frame, ControlQueueFrame): + yield frame @abstractmethod async def finalize(self) -> AsyncGenerator[QueueFrame, None]: @@ -87,7 +91,9 @@ class LLMService(AIService): pass async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]: - if isinstance(frame, LLMMessagesQueueFrame): + if isinstance(frame, ControlQueueFrame): + yield frame + elif isinstance(frame, LLMMessagesQueueFrame): async for text_chunk in self.run_llm_async(frame.messages): yield TextQueueFrame(text_chunk) @@ -110,9 +116,11 @@ class TTSService(AIService): yield bytes() async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]: - if not isinstance(frame, TextQueueFrame): + if isinstance(frame, ControlQueueFrame): yield frame return + elif not isinstance(frame, TextQueueFrame): + return text: str | None = None if not self.aggregate_sentences: @@ -149,6 +157,7 @@ class ImageGenService(AIService): async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]: if not isinstance(frame, TextQueueFrame): + yield frame return (url, image_data) = await self.run_image_gen(frame.text) diff --git a/src/dailyai/services/azure_ai_services.py b/src/dailyai/services/azure_ai_services.py index b723e77e4..46449f208 100644 --- a/src/dailyai/services/azure_ai_services.py +++ b/src/dailyai/services/azure_ai_services.py @@ -35,10 +35,7 @@ class AzureTTSService(TTSService): "" \ f"{sentence}" \ " " - try: - result = await asyncio.to_thread(self.speech_synthesizer.speak_ssml, (ssml)) - except Exception as e: - self.logger.error("Error in azure tts", e) + result = await asyncio.to_thread(self.speech_synthesizer.speak_ssml, (ssml)) self.logger.info("Got azure tts result") if result.reason == ResultReason.SynthesizingAudioCompleted: self.logger.info("Returning result") diff --git a/src/dailyai/services/daily_transport_service.py b/src/dailyai/services/daily_transport_service.py index aed5d9c9e..36b218f77 100644 --- a/src/dailyai/services/daily_transport_service.py +++ b/src/dailyai/services/daily_transport_service.py @@ -7,6 +7,7 @@ import types from functools import partial from queue import Queue, Empty +from typing import AsyncGenerator from dailyai.queue_frame import ( AudioQueueFrame, @@ -14,6 +15,7 @@ from dailyai.queue_frame import ( ImageQueueFrame, QueueFrame, StartStreamQueueFrame, + TextQueueFrame, TranscriptionQueueFrame, ) @@ -114,6 +116,7 @@ class DailyTransportService(EventHandler): handler(*args, **kwargs) except Exception as e: self.logger.error(f"Exception in event handler {event_name}: {e}") + raise e def add_event_handler(self, event_name: str, handler): if not event_name.startswith("on_"): @@ -214,7 +217,6 @@ class DailyTransportService(EventHandler): if self.token and self.start_transcription: self.client.start_transcription(self.transcription_settings) - def _receive_audio(self): """Receive audio from the Daily call and put it on the receive queue""" seconds = 1 @@ -223,9 +225,13 @@ class DailyTransportService(EventHandler): buffer = self.speaker.read_frames(desired_frame_count) if len(buffer) > 0: frame = AudioQueueFrame(buffer) - asyncio.run_coroutine_threadsafe(self.receive_queue.put(frame), self.loop) + if self.loop: + asyncio.run_coroutine_threadsafe(self.receive_queue.put(frame), self.loop) - async def get_receive_frames(self): + def interrupt(self): + self.is_interrupted.set() + + async def get_receive_frames(self) -> AsyncGenerator[QueueFrame, None]: while True: frame = await self.receive_queue.get() yield frame @@ -265,6 +271,7 @@ class DailyTransportService(EventHandler): await asyncio.sleep(1) except Exception as e: self.logger.error(f"Exception {e}") + raise e finally: self.client.leave() @@ -341,6 +348,7 @@ class DailyTransportService(EventHandler): time.sleep(1.0 / 8) # 8 fps except Exception as e: self.logger.error(f"Exception {e} in camera thread.") + raise e def frame_consumer(self): self.logger.info("🎬 Starting frame consumer thread") @@ -382,11 +390,11 @@ class DailyTransportService(EventHandler): self.mic.write_frames(bytes(b)) b = bytearray() else: - if self.interrupt_time: - self.logger.info( - f"Lag to stop stream after interruption {time.perf_counter() - self.interrupt_time}" - ) - self.interrupt_time = None + # if there are leftover audio bytes, write them now; failing to do so + # can cause static in the audio stream. + if len(b): + self.mic.write_frames(bytes(b)) + b = bytearray() if isinstance(frame, StartStreamQueueFrame): self.is_interrupted.clear() @@ -398,5 +406,6 @@ class DailyTransportService(EventHandler): self.mic.write_frames(bytes(b)) except Exception as e: self.logger.error(f"Exception in frame_consumer: {e}, {len(b)}") + raise e b = bytearray() diff --git a/src/dailyai/tests/test_asyncprocessor.py b/src/dailyai/tests/test_asyncprocessor.py deleted file mode 100644 index fcb2781e4..000000000 --- a/src/dailyai/tests/test_asyncprocessor.py +++ /dev/null @@ -1,180 +0,0 @@ -import time -import unittest - -from queue import Queue, Empty -from threading import Thread, Event -from typing import Generator - -from dailyai.async_processor.async_processor import ( - AsyncProcessor, - AsyncProcessorState, - LLMResponse, -) -from dailyai.message_handler.message_handler import MessageHandler -from dailyai.queue_frame import QueueFrame, FrameType -from dailyai.services.ai_services import ( - AIServiceConfig, - ImageGenService, - LLMService, - TTSService, -) -""" -class MockTTSService(TTSService): - def run_tts(self, sentence): - for word in sentence.split(' '): - time.sleep(0.1) - yield bytes(word, "utf-8") - -class MockLLMService(LLMService): - def run_llm_async(self, messages) -> Generator[str, None, None]: - for i in ["Hello ", "there.", "How are ", "you?", "I ", "hope ", "you ", "are ", "well."]: - time.sleep(0.1) - yield i - -class MockImageService(ImageGenService): - def run_image_gen(self, sentence) -> None: - return None - -class TestResponse(unittest.TestCase): - def test_base_state_transitions(self): - mock_tts_service = MockTTSService() - mock_llm_service = MockLLMService() - mock_image_service = MockImageService() - processor = AsyncProcessor(AIServiceConfig(tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service)) - processor.prepare() - processor.play() - processor.finalize() - self.assertEqual(processor.state, AsyncProcessorState.FINALIZED) - - def test_state_transitions(self): - output_queue = Queue() - mock_tts_service = MockTTSService() - mock_llm_service = MockLLMService() - mock_image_service = MockImageService() - message_handler = MessageHandler("Hello World") - processor = LLMResponse( - AIServiceConfig( - tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service - ), - message_handler, - output_queue, - ) - processor.prepare() - processor.play() - - # Consume the output from the output queue. It's necessary to mark these tasks as done for the - # play function to return. - expected_words = ["Hello", "there.", "How", "are", "you?", "I", "hope", "you", "are", "well."] - - # remove the "start_stream" message from the queue - output_queue.get() - output_queue.task_done() - - while expected_words: - actual_word:QueueFrame = output_queue.get() - word = expected_words.pop(0) - self.assertEqual(actual_word.frame_type, FrameType.AUDIO_FRAME) - self.assertEqual(actual_word.frame_data, bytes(word, "utf-8")) - output_queue.task_done() - - processor.finalize() - - self.assertEqual(processor.state, AsyncProcessorState.FINALIZED) - - def test_interrupt_preparation(self): - output_queue = Queue() - mock_tts_service = MockTTSService() - mock_llm_service = MockLLMService() - mock_image_service = MockImageService() - message_handler = MessageHandler("System Message") - processor = LLMResponse( - AIServiceConfig( - tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service - ), - message_handler, - output_queue, - ) - processor.prepare() - interrupt_request_at = time.perf_counter() - processor.interrupt() - processor.finalize() - finalized_at = time.perf_counter() - self.assertTrue(0.1 < finalized_at - interrupt_request_at < 0.2) - print(f"delta: {interrupt_request_at, finalized_at}") - self.assertEqual(processor.state, AsyncProcessorState.FINALIZED) - - def test_interrupt_play(self): - output_queue = Queue() - mock_tts_service = MockTTSService() - mock_llm_service = MockLLMService() - mock_image_service = MockImageService() - message_handler = MessageHandler("System Message") - processor = LLMResponse( - AIServiceConfig( - tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service - ), - message_handler, - output_queue, - ) - processor.prepare() - processor.play() - - stop_processing_output_queue = Event() - def process_output_queue_async(): - # Consume the output from the output queue. It's necessary to mark these tasks as done for the - # play function to return. - time.sleep(0.1) - expected_words = ["Hello", "there.", "How", "are", "you?", "I", "hope", "you", "are", "well."] - while expected_words and not stop_processing_output_queue.is_set(): - try: - actual_word:QueueFrame = output_queue.get_nowait() - if actual_word.frame_type == FrameType.AUDIO_FRAME: - time.sleep(0.1) - word = expected_words.pop(0) - self.assertEqual(actual_word.frame_type, FrameType.AUDIO_FRAME) - self.assertEqual(actual_word.frame_data, bytes(word, "utf-8")) - output_queue.task_done() - except Empty: - pass - - process_output_queue = Thread(target=process_output_queue_async, daemon=True) - process_output_queue.start() - - time.sleep(0.5) - processor.interrupt() - - stop_processing_output_queue.set() - process_output_queue.join() - - processor.finalize() - self.assertEqual(processor.state, AsyncProcessorState.FINALIZED) - - def test_statechange_callback(self): - mock_tts_service = MockTTSService() - mock_llm_service = MockLLMService() - mock_image_service = MockImageService() - processor = AsyncProcessor( - AIServiceConfig( - tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service - ) - ) - is_finalized = False - def set_is_finalized(async_processor:AsyncProcessor): - nonlocal is_finalized - is_finalized = True - - processor.set_state_callback( - AsyncProcessorState.FINALIZED, set_is_finalized - ) - processor.prepare() - self.assertFalse(is_finalized) - processor.play() - self.assertFalse(is_finalized) - processor.finalize() - self.assertTrue(is_finalized) - self.assertEqual(processor.state, AsyncProcessorState.FINALIZED) - - -if __name__ == '__main__': - unittest.main() -""" diff --git a/src/dailyai/tests/test_message_handler.py b/src/dailyai/tests/test_message_handler.py deleted file mode 100644 index 9869755c2..000000000 --- a/src/dailyai/tests/test_message_handler.py +++ /dev/null @@ -1,147 +0,0 @@ -import time -import unittest - -from unittest.mock import MagicMock, call - -from dailyai.message_handler.message_handler import MessageHandler, IndexingMessageHandler -from dailyai.services.ai_services import ( - AIServiceConfig, - TTSService, - LLMService, - ImageGenService, -) -from ..storage.search import SearchIndexer - - -class TestMessageHandler(unittest.TestCase): - def test_simple_intro(self): - message_handler = MessageHandler("Hello world") - self.assertEqual( - message_handler.get_llm_messages(), - [{"role": "system", "content": "Hello world"}], - ) - - def test_simple_user_message(self): - message_handler = MessageHandler("System prompt") - message_handler.add_user_message("User message") - self.assertEqual( - message_handler.get_llm_messages(), - [ - {"role": "system", "content": "System prompt"}, - {"role": "user", "content": "User message"}, - ], - ) - - def test_simple_user_and_assistant_message(self): - message_handler = MessageHandler("System prompt") - message_handler.add_user_message("User message") - message_handler.add_assistant_message("Assistant message") - self.assertEqual( - message_handler.get_llm_messages(), - [ - {"role": "system", "content": "System prompt"}, - {"role": "user", "content": "User message"}, - {"role": "assistant", "content": "Assistant message"}, - ], - ) - - def test_user_message_overwrite(self): - message_handler = MessageHandler("System prompt") - message_handler.add_user_message("User message") - message_handler.add_assistant_message("Assistant message") - message_handler.add_user_message("plus something else") - self.assertEqual( - message_handler.get_llm_messages(), - [ - {"role": "system", "content": "System prompt"}, - {"role": "user", "content": "User message plus something else"}, - ], - ) - - def test_user_message_after_assistant(self): - message_handler = MessageHandler("System prompt") - message_handler.add_user_message("User message") - message_handler.add_assistant_message("Assistant message") - message_handler.finalize_user_message() - message_handler.add_user_message("other user message") - self.assertEqual( - message_handler.get_llm_messages(), - [ - {"role": "system", "content": "System prompt"}, - {"role": "user", "content": "User message"}, - {"role": "assistant", "content": "Assistant message"}, - {"role": "user", "content": "other user message"}, - ], - ) - - -class MockTTSService(TTSService): - def run_tts(self, sentence): - for word in sentence.split(" "): - time.sleep(0.1) - yield bytes(word, "utf-8") - - -class MockLLMService(LLMService): - def run_llm(self, messages) -> str: - return "Parsed user message." - -class MockImageService(ImageGenService): - def run_image_gen(self, sentence) -> None: - return None - - -class TestStorageMessageHandler(unittest.TestCase): - def test_user_message_finalized(self): - mock_tts_service = MockTTSService() - mock_llm_service = MockLLMService() - mock_image_service = MockImageService() - - service_config = AIServiceConfig( - tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service - ) - - mock_indexer = MagicMock(spec=SearchIndexer) - - message_handler = IndexingMessageHandler( - "Hello world", service_config, mock_indexer - ) - message_handler.cleanup_user_message = MagicMock(return_value="Parsed user message.") - message_handler.add_user_message("User message") - message_handler.add_assistant_message("Assistant message will be ignored") - message_handler.add_user_message("plus something else") - message_handler.finalize_user_message() - message_handler.add_assistant_message( - "New assistant message will not be ignored" - ) - message_handler.add_user_message("User message second time") - message_handler.add_assistant_message("Assistant message second time") - message_handler.write_messages_to_storage() - - time.sleep(0.5) - message_handler.cleanup_user_message.assert_called_with("User message plus something else") - self.assertEqual( - mock_indexer.mock_calls, - [ - call.index_text('"Parsed user message."'), - call.index_text("New assistant message will not be ignored"), - ], - ) - - mock_indexer.reset_mock() - - message_handler.finalize_user_message() - - time.sleep(0.5) - - self.assertEqual( - mock_indexer.mock_calls, - [ - call.index_text('"Parsed user message."'), - call.index_text("Assistant message second time"), - ], - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/src/samples/foundational/07-interruptible.py b/src/samples/foundational/07-interruptible.py new file mode 100644 index 000000000..96b80325d --- /dev/null +++ b/src/samples/foundational/07-interruptible.py @@ -0,0 +1,98 @@ +import argparse +import asyncio +import requests +import time +import urllib.parse +from dailyai.conversation_wrappers import InterruptibleConversationWrapper + +from dailyai.queue_frame import StartStreamQueueFrame, TextQueueFrame +from dailyai.services.daily_transport_service import DailyTransportService +from dailyai.services.azure_ai_services import AzureLLMService +from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService + + +async def main(room_url:str, token): + global transport + global llm + global tts + + transport = DailyTransportService( + room_url, + token, + "Respond bot", + 5, + ) + transport.mic_enabled = True + transport.mic_sample_rate = 16000 + transport.camera_enabled = False + + llm = AzureLLMService() + tts = ElevenLabsTTSService(voice_id="ErXwobaYiN019PkySvjV") + + async def run_response(user_speech, tma_in, tma_out): + await tts.run_to_queue( + transport.send_queue, + tma_out.run( + llm.run( + tma_in.run( + [StartStreamQueueFrame(), TextQueueFrame(user_speech)] + ) + ) + ), + ) + + @transport.event_handler("on_first_other_participant_joined") + async def on_first_other_participant_joined(transport): + await tts.say("Hi, I'm listening!", transport.send_queue) + + async def run_conversation(): + messages = [ + {"role": "system", "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio. Respond to what the user said in a creative and helpful way."}, + ] + + conversation_wrapper = InterruptibleConversationWrapper( + frame_generator=transport.get_receive_frames, + runner=run_response, + interrupt=transport.interrupt, + my_participant_id=transport.my_participant_id, + llm_messages=messages, + ) + await conversation_wrapper.run_conversation() + + transport.transcription_settings["extra"]["punctuate"] = False + await asyncio.gather(transport.run(), run_conversation()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Simple Daily Bot Sample") + parser.add_argument( + "-u", "--url", type=str, required=True, help="URL of the Daily room to join" + ) + parser.add_argument( + "-k", + "--apikey", + type=str, + required=True, + help="Daily API Key (needed to create token)", + ) + + args, unknown = parser.parse_known_args() + + # Create a meeting token for the given room with an expiration 1 hour in the future. + room_name: str = urllib.parse.urlparse(args.url).path[1:] + expiration: float = time.time() + 60 * 60 + + res: requests.Response = requests.post( + f"https://api.daily.co/v1/meeting-tokens", + headers={"Authorization": f"Bearer {args.apikey}"}, + json={ + "properties": {"room_name": room_name, "is_owner": True, "exp": expiration} + }, + ) + + if res.status_code != 200: + raise Exception(f"Failed to create meeting token: {res.status_code} {res.text}") + + token: str = res.json()["token"] + + asyncio.run(main(args.url, token))