From e724720e7626409a6dcc9884e9a798c2ee8d40e4 Mon Sep 17 00:00:00 2001 From: Moishe Lettvin Date: Mon, 25 Dec 2023 19:09:11 -0500 Subject: [PATCH] Getting started --- .gitignore | 5 + __init__.py | 0 async_processor/async_processor.py | 373 ++++++++++++++++++++++++ message_handler/message_handler.py | 144 +++++++++ orchestrator.py | 453 +++++++++++++++++++++++++++++ requirements.txt | 2 + services/ai_services.py | 56 ++++ services/azure_ai_services.py | 116 ++++++++ services/cloudflare_ai_service.py | 65 +++++ services/deepgram_ai_service.py | 28 ++ services/elevenlabs_ai_service.py | 38 +++ services/google_ai_service.py | 26 ++ services/huggingface_ai_service.py | 26 ++ services/mock_ai_service.py | 27 ++ services/open_ai_service.py | 57 ++++ services/playht_ai_service.py | 56 ++++ storage/search.py | 50 ++++ tests/test_asyncprocessor.py | 179 ++++++++++++ tests/test_message_handler.py | 129 ++++++++ 19 files changed, 1830 insertions(+) create mode 100644 .gitignore create mode 100644 __init__.py create mode 100644 async_processor/async_processor.py create mode 100644 message_handler/message_handler.py create mode 100644 orchestrator.py create mode 100644 requirements.txt create mode 100644 services/ai_services.py create mode 100644 services/azure_ai_services.py create mode 100644 services/cloudflare_ai_service.py create mode 100644 services/deepgram_ai_service.py create mode 100644 services/elevenlabs_ai_service.py create mode 100644 services/google_ai_service.py create mode 100644 services/huggingface_ai_service.py create mode 100644 services/mock_ai_service.py create mode 100644 services/open_ai_service.py create mode 100644 services/playht_ai_service.py create mode 100644 storage/search.py create mode 100644 tests/test_asyncprocessor.py create mode 100644 tests/test_message_handler.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..8f629269b --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +.vscode +env/ +__pycache__/ +*~ +#*# \ No newline at end of file diff --git a/__init__.py b/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/async_processor/async_processor.py b/async_processor/async_processor.py new file mode 100644 index 000000000..64c0b2efb --- /dev/null +++ b/async_processor/async_processor.py @@ -0,0 +1,373 @@ +import json +import logging +import re + +from collections import defaultdict +from dataclasses import dataclass, field +from queue import Queue, PriorityQueue, Empty +from threading import Event, Semaphore, Thread +from typing import Iterator, Optional, Type, TypedDict +from typing_extensions import Unpack + +from services.ai_services import AIServiceConfig +from message_handler.message_handler import MessageHandler + +frame_idx = 0 + + +class AsyncProcessorState: + # Setting class variables, other synchronous activities + INIT = 0 + + # Making asynchronous requests to LLM and other services to render response + PREPARING = 1 + + # Ready to start presenting to user (but may not have all data yet) + READY = 2 + + # Playing response + PLAYING = 3 + + # An interrupt has been requested and the response is shutting down in-flight processing + INTERRUPTING = 4 + + # An interrupt has been requested and the response is finished stopping in-flight processing + INTERRUPTED = 5 + + # Response has been played or interrupted + DONE = 6 + + # Response is being finalized (updating records of speech, updating LLM context, etc.) + FINALIZING = 7 + + # Response is complete. This could mean that everything is updated, or that the response + # was interrupted. + FINALIZED = 8 + + state_transitions = { + INIT: [PREPARING, INTERRUPTING], + PREPARING: [READY, INTERRUPTING], + READY: [PLAYING, INTERRUPTING], + PLAYING: [DONE, INTERRUPTING], + INTERRUPTING: [INTERRUPTED], + INTERRUPTED: [DONE], + DONE: [FINALIZING], + FINALIZING: [FINALIZED], + FINALIZED: [FINALIZED], + } + + +@dataclass(order=True) +class StateTransitionItem: + state: int + evt: Event = field(compare=False) + +class AsyncProcessor: + def __init__( + self, + services: AIServiceConfig + ) -> None: + self.state = AsyncProcessorState.INIT + self.prepare_thread = None + self.play_thread = None + self.finalize_thread = None + + self.services: AIServiceConfig = services + + self.state_transition_semaphore = Semaphore() + self.waiting_for_state_changes = PriorityQueue() + self.state_queue = Queue() + + self.state_change_callbacks = defaultdict(list) + + self.was_interrupted = False + + self.logger = logging.getLogger("bot-instance") + + def set_state(self, state: int) -> None: + if state in AsyncProcessorState.state_transitions[self.state]: + self.state_transition_semaphore.acquire() + + self.state = state + self.state_transition_semaphore.release() + + # wake up any threads waiting for this state transition + try: + while True: + waiter = self.waiting_for_state_changes.get_nowait() + if waiter.state <= state: + waiter.evt.set() + else: + self.waiting_for_state_changes.put(waiter) + break + except Empty: + pass + + # make all the callbacks for this state + for callback in self.state_change_callbacks[state]: + callback(self) + else: + self.logger.error( + f"Invalid state transition from {self.state} to {state} in {self.__class__.__name__}" + ) + raise Exception(f"Invalid state transition from {self.state} to {state}") + + # + # This is used for state transitions that could be blocked by an interruption. + # If we are interrupted, we silently fail this call. Use only if you know that + # this state transition should fail if the processor has been interrupted. + # + + def maybe_set_state(self, state: int) -> bool: + if state in AsyncProcessorState.state_transitions[self.state]: + self.set_state(state) + return True + else: + return False + + def wait_for_state_transition(self, state: int) -> None: + if self.state >= state: + return + + self.state_transition_semaphore.acquire() + + evt = Event() + self.waiting_for_state_changes.put(StateTransitionItem(state, evt)) + self.state_transition_semaphore.release() + result = evt.wait(120.0) + if not result: + self.logger.error( + f"Timed out waiting for state transition to {state} from {self.state}" + ) + + def set_state_callback(self, state: int, callback: callable) -> None: + self.state_change_callbacks[state].append(callback) + + def prepare(self) -> None: + self.prepare_thread = Thread(target=self.async_prepare, daemon=True) + self.prepare_thread.start() + self.wait_for_state_transition(AsyncProcessorState.READY) + + def play(self) -> None: + self.wait_for_state_transition(AsyncProcessorState.READY) + self.play_thread = Thread(target=self.async_play, daemon=True) + self.play_thread.start() + self.wait_for_state_transition(AsyncProcessorState.PLAYING) + + def finalize(self) -> None: + # don't finalize until we're done playing. + self.wait_for_state_transition(AsyncProcessorState.DONE) + self.set_state(AsyncProcessorState.FINALIZING) + self.do_finalization() + self.set_state(AsyncProcessorState.FINALIZED) + + def interrupt(self) -> None: + # nothing to interrupt if we're already finalizing or finalized, no-op + if self.state in [ + AsyncProcessorState.FINALIZING, + AsyncProcessorState.FINALIZED, + ]: + return + + self.set_state(AsyncProcessorState.INTERRUPTING) + self.was_interrupted = True + self.do_interruption() + self.set_state(AsyncProcessorState.INTERRUPTED) + self.set_state(AsyncProcessorState.DONE) + + def async_play(self) -> None: + self.logger.info(f"starting to play") + if self.maybe_set_state(AsyncProcessorState.PLAYING): + self.do_play() + self.maybe_set_state(AsyncProcessorState.DONE) + + def async_prepare(self) -> None: + self.set_state(AsyncProcessorState.PREPARING) + self.preparation_iterator = self.get_preparation_iterator() + self.set_state(AsyncProcessorState.READY) + for chunk in self.preparation_iterator: + if self.state not in [ + AsyncProcessorState.READY, + AsyncProcessorState.PLAYING, + ]: + break + + self.process_chunk(chunk) + + self.logger.info(f"Preparation done for {self.__class__.__name__}") + self.preparation_done() + + def preparation_done(self): + pass + + def get_preparation_iterator(self) -> Iterator: + yield None + + def process_chunk(self, chunk) -> None: + pass + + def do_interruption(self) -> None: + pass + + def do_play(self) -> None: + pass + + def do_finalization(self) -> None: + pass + + +class ResponseArgs(TypedDict): + services: AIServiceConfig + message_handler: MessageHandler + output_queue: Queue + +class Response(AsyncProcessor): + def __init__( + self, + services, + message_handler, + output_queue, + ) -> None: + super().__init__(services) + + self.message_handler: MessageHandler = message_handler + self.output_queue: Queue = output_queue + self.has_sent_first_frame = False + + self.chunks_in_preparation = Queue() + + #self.sprite_loader = sprite_loader.SpriteLoader() + + self.llm_responses: list[str] = [] + + def get_preparation_iterator(self) -> Iterator: + messages_for_llm = self.message_handler.get_llm_messages() + self.logger.error(f"messages for llm: {json.dumps(messages_for_llm, indent=2)}") + return self.clauses_from_chunks( + self.services.llm.run_llm_async(messages_for_llm) + ) + + def clauses_from_chunks(self, chunks) -> Iterator: + out = "" + for chunk in chunks: + if self.state not in [ + AsyncProcessorState.READY, + AsyncProcessorState.PLAYING, + ]: + break + + out += chunk + + if re.match(r"^.*[.!?]$", out): # it looks like a sentence + yield out.strip() + out = "" + + if out.strip(): + yield out.strip() + + def process_chunk(self, chunk) -> None: + # could also put other generators in this tuple + self.logger.info(f"putting chunk in preparation queue {chunk}") + + def get_frames_from_chunk(chunk): + image_list = [ + "sc-talk", + "sc-default", + "sc-default", + "sc-default", + "sc-talk", + "sc-default", + "sc-default", + "sc-default", + "sc-default", + "sc-talk", + "sc-talk", + "sc-default", + "sc-default", + "sc-talk", + "sc-talk", + "sc-default", + "sc-talk", + "sc-default", + "sc-default", + "sc-default", + "sc-talk", + "sc-talk", + "sc-talk", + "sc-talk", + "sc-talk", + "sc-talk", + "sc-default", + "sc-default", + "sc-talk", + "sc-talk", + ] + image_list_idx = 0 + for frame in self.services.tts.run_tts(chunk): + yield (bytearray(frame), None) #self.sprite_loader.get_sprite_bytes(image_list[image_list_idx])) + image_list_idx = (image_list_idx + 1) % len(image_list) + + self.chunks_in_preparation.put((chunk, get_frames_from_chunk(chunk))) + + def preparation_done(self): + self.chunks_in_preparation.put((None, None)) + + def do_play(self) -> None: + while True: + if self.state not in [ + AsyncProcessorState.READY, + AsyncProcessorState.PLAYING, + ]: + break + prepared_chunk = self.chunks_in_preparation.get() + if prepared_chunk[0] is None: + return + + self.play_prepared_chunk(prepared_chunk) + + def play_prepared_chunk(self, prepared_chunk) -> None: + chunk, tts_generator = prepared_chunk + global frame_idx + for tts_chunk in tts_generator: + if self.state not in [ + AsyncProcessorState.READY, + AsyncProcessorState.PLAYING, + ]: + break + + if not self.has_sent_first_frame: + self.output_queue.put({"type": "start_stream", "idx": frame_idx}) + frame_idx += 1 + self.has_sent_first_frame = True + + (audio_frame, video_frame) = tts_chunk + self.output_queue.put( + {"type": "image_frame", "data": video_frame, "idx": frame_idx} + ) + self.output_queue.put( + {"type": "audio_frame", "data": audio_frame, "idx": frame_idx + 1} + ) + frame_idx += 2 + + self.output_queue.join() + self.llm_responses.append(chunk) + + def do_finalization(self) -> None: + self.message_handler.add_assistant_messages(self.llm_responses) + + def do_interruption(self) -> None: + self.chunks_in_preparation.put((None, None)) + + if self.prepare_thread and self.prepare_thread.is_alive(): + self.prepare_thread.join() + + if self.play_thread and self.play_thread.is_alive(): + self.play_thread.join() + + +@dataclass +class ConversationProcessorCollection: + introduction: Optional[Type[Response]] = None + waiting: Optional[Type[Response]] = None + response: Optional[Type[Response]] = None + goodbye: Optional[Type[Response]] = None diff --git a/message_handler/message_handler.py b/message_handler/message_handler.py new file mode 100644 index 000000000..0c3850497 --- /dev/null +++ b/message_handler/message_handler.py @@ -0,0 +1,144 @@ +import logging +import time + +from dataclasses import dataclass +from queue import Queue, Empty +from threading import Thread + +from storage.search import SearchIndexer +from services.ai_services import AIServiceConfig + + +@dataclass +class Message: + type: str + timestamp: float + message: str + + +class MessageHandler: + def __init__(self, intro): + self.messages = [Message("system", time.time(), intro)] + self.last_user_message_idx = None + + def add_user_message(self, message): + if (self.last_user_message_idx is not None and self.last_user_message_idx != self.finalized_user_message_idx): + previous_message = self.messages[self.last_user_message_idx].message + self.messages[self.last_user_message_idx] = Message( + "user", time.time(), ' '.join([previous_message, message]) + ) + self.messages = self.messages[: self.last_user_message_idx + 1] + else: + self.messages.append(Message("user", time.time(), message)) + + self.last_user_message_idx = len(self.messages) - 1 + + def add_assistant_message(self, message): + if self.messages[-1].type == "assistant": + self.messages[-1].message += " " + message + else: + self.messages.append(Message("assistant", time.time(), message)) + + def add_assistant_messages(self, messages): + self.messages.append(Message("assistant", time.time(), " ".join(messages))) + + def get_llm_messages(self): + return [{"role": m.type, "content": m.message} for m in self.messages] + + def finalize_user_message(self): + pass + + def shutdown(self): + pass + +class IndexingMessageHandler(MessageHandler): + def __init__( + self, intro, services: AIServiceConfig, indexer: SearchIndexer + ) -> None: + super().__init__(intro) + self.services = services + + self.search_indexer = indexer + + self.last_written_idx = 0 + self.index_message_queue = Queue() + + self.index_writer_thread = Thread(target=self.indexer_writer, daemon=True) + self.index_writer_thread.start() + + self.finalized_user_message_idx = None + + self.logger = logging.getLogger("bot-instance") + + def shutdown(self): + self.finalize_user_message() + self.index_message_queue.put(None) + self.index_writer_thread.join() + + def indexer_writer(self): + while True: + try: + message_idx = self.index_message_queue.get() + self.index_message_queue.task_done() + + if message_idx is None: + return + + if message_idx <= self.last_written_idx: + continue + + self.last_written_idx = message_idx + + message = self.messages[message_idx] + content = message.message + if message.type == "user": + content = self.cleanup_user_message(content) + + # sometimes the LLM returns a string wrapped in quotes and sometimes it doesn't. + # if it didn't, wrap it in quotes + if content[0] != '"': + content = '"' + content + '"' + + self.search_indexer.index_text(content) + except Empty: + pass + + def cleanup_user_message(self, user_message) -> str: + messages = [ + { + "role": "system", + "content": """ + You are an assistant who is very good at making transcriptions + of human speech into well-capitalized and punctuated text, without + changing any words or the order of the words. Please change this + transcription to something suitable for the printed page. + """, + }, + {"role": "user", "content": user_message}, + ] + result = self.services.llm.run_llm(messages) + if result: + user_message = result + + return user_message + + def finalize_user_message(self): + self.finalized_user_message_idx = self.last_user_message_idx + self.write_messages_to_index() + + def write_messages_to_index(self): + if self.finalized_user_message_idx is None: + return + + for idx in range(self.last_written_idx, len(self.messages)): + self.logger.info( + f"writing to index: {self.messages[idx].type} {self.messages[idx].message}" + ) + if ( + self.messages[idx].type == "user" + and idx > self.finalized_user_message_idx + ): + break + + if self.messages[idx].type != "system": + self.index_message_queue.put(idx) diff --git a/orchestrator.py b/orchestrator.py new file mode 100644 index 000000000..84ece2667 --- /dev/null +++ b/orchestrator.py @@ -0,0 +1,453 @@ +import logging +import os +import time +import wave + +from dataclasses import dataclass +from queue import Queue, Empty + +from daily_ai.async_processor import ( + AsyncProcessor, + AsyncProcessorState, + ConversationProcessorCollection, + Response, +) +from daily_ai.services.ai_services import AIServiceConfig +from daily_ai.message_handler import MessageHandler + +from threading import Thread, Semaphore, Event, Timer + +from opentelemetry import context +from opentelemetry.context.context import Context + +from daily import ( + EventHandler, + CallClient, + Daily, + VirtualCameraDevice, + VirtualMicrophoneDevice, + VirtualSpeakerDevice, +) + + +@dataclass +class OrchestratorConfig: + room_url: str + token: str + bot_name: str + expiration: float + + +class Orchestrator(EventHandler): + def __init__( + self, + daily_config: OrchestratorConfig, + ai_service_config: AIServiceConfig, + conversation_processors: ConversationProcessorCollection, + message_handler: MessageHandler, + tracer, + ): + self.bot_name: str = daily_config.bot_name + self.room_url: str = daily_config.room_url + self.token: str = daily_config.token + self.expiration: float = daily_config.expiration + + self.logger: logging.Logger = logging.getLogger("bot-instance") + self.tracer = tracer + self.ctx: Context = context.get_current() + + self.transcription = "" + self.last_fragment_at = None + self.talked_at = None + self.paused_at = None + + self.logger.info(f"Creating Response for introductions") + self.services: AIServiceConfig = ai_service_config + self.output_queue = Queue() + self.is_interrupted = Event() + self.stop_threads = Event() + self.story_started = False + + self.message_handler = message_handler + + if conversation_processors.introduction is not None: + intro = conversation_processors.introduction( + services=self.services, message_handler=self.message_handler, output_queue=self.output_queue + ) + intro.prepare() + intro.set_state_callback(AsyncProcessorState.DONE, self.on_intro_played) + intro.set_state_callback(AsyncProcessorState.FINALIZED, self.on_intro_finished) + self.logger.info(f"Response is preparing") + + self.current_response: AsyncProcessor = intro + self.can_interrupt = False + # self.response_event.set() + self.response_semaphore = Semaphore() + + self.speech_timeout = None + self.interrupt_time = None + + self.logger.info("configuring daily") + self.configure_daily() + + def configure_daily(self): + Daily.init() + self.client = CallClient(event_handler=self) + + self.logger.info(f"mic sample rate: {self.services.tts.get_mic_sample_rate()}") + self.mic: VirtualMicrophoneDevice = Daily.create_microphone_device( + "mic", sample_rate=self.services.tts.get_mic_sample_rate(), channels=1 + ) + self.speaker: VirtualSpeakerDevice = Daily.create_speaker_device( + "speaker", sample_rate=16000, channels=1 + ) + self.camera: VirtualCameraDevice = Daily.create_camera_device( + "camera", width=720, height=1280, color_format="RGB" + ) + + Daily.select_speaker_device("speaker") + + self.client.set_user_name(self.bot_name) + self.client.join(self.room_url, self.token, completion=self.call_joined) + + self.client.update_inputs( + { + "camera": { + "isEnabled": True, + "settings": { + "deviceId": "camera", + }, + }, + "microphone": { + "isEnabled": True, + "settings": { + "deviceId": "mic", + "customConstraints": { + "autoGainControl": {"exact": False}, + "echoCancellation": {"exact": False}, + "noiseSuppression": {"exact": False}, + }, + }, + }, + } + ) + + self.client.update_publishing( + { + "camera": { + "sendSettings": { + "maxQuality": "low", + "encodings": { + "low": { + "maxBitrate": 250000, + "scaleResolutionDownBy": 1.333, + "maxFramerate": 8, + } + }, + } + } + } + ) + + self.my_participant_id = self.client.participants()["local"]["id"] + + def start(self) -> None: + # TODO: this loop could, I think, be replaced with a timer and an event + self.participant_left = False + + try: + participant_count: int = len(self.client.participants()) + self.logger.info(f"{participant_count} participants in room") + while time.time() < self.expiration and not self.participant_left: + # all handling of incoming transcriptions happens in on_transcription_message + time.sleep(1) + except Exception as e: + self.logger.error(f"Exception {e}") + finally: + self.client.leave() + + def stop(self): + self.logger.info("stop current response") + if self.current_response: + if self.current_response.state < AsyncProcessorState.INTERRUPTED: + self.current_response.interrupt() + + self.logger.info("wait for state transition") + self.current_response.wait_for_state_transition(AsyncProcessorState.FINALIZED) + + self.stop_threads.set() + self.camera_thread.join() + self.logger.info("camera thread stopped") + + self.logger.info("put stop in output queue") + self.output_queue.put({"type": "stop"}) + + self.frame_consumer_thread.join() + self.logger.info("orchestrator stopped.") + + def on_intro_played(self, intro): + self.can_interrupt = True + intro.finalize() + + def on_intro_finished(self, intro): + pass + + def on_response_played(self, response): + response.finalize() + self.display_waiting() + + def on_response_finished(self, response): + if not response.was_interrupted: + self.message_handler.finalize_user_message() + + def call_joined(self, join_data, client_error): + self.logger.info(f"call_joined: {join_data}, {client_error}") + self.client.start_transcription( + { + "language": "en", + "tier": "nova", + "model": "2-conversationalai", + "profanity_filter": True, + "redact": False, + "extra": { + "endpointing": True, + "punctuate": False, + } + } + ) + + def on_participant_joined(self, participant): + with self.tracer.start_as_current_span("on_participant_joined", context=self.ctx): + self.logger.info(f"on_participant_joined: {participant}") + + # TODO: figure out the architecture to get the story id to the client + # self.client.send_app_message({"event": "story-id", "storyID": self.story_id}) + time.sleep(2) + + if not self.story_started: + self.action() + self.story_started = True + + def on_participant_left(self, participant, reason): + if len(self.client.participants()) < 2: + self.logger.info("participant left") + self.participant_left = True + + def on_app_message(self, message, sender): + with self.tracer.start_as_current_span("on_app_message", context=self.ctx): + self.logger.info(f"on_app_message {message} from {sender}") + if "isSpeaking" in message and message["isSpeaking"] == True: + self.handle_user_started_talking() + + if "isSpeaking" in message and message["isSpeaking"] == False: + self.handle_user_stopped_talking() + + def on_transcription_message(self, message): + with self.tracer.start_as_current_span("on_transcription_message", context=self.ctx): + if message["session_id"] != self.my_participant_id: + self.handle_transcription_fragment(message['text']) + + def on_transcription_stopped(self, stopped_by, stopped_by_error): + self.logger.info(f"transcription stopped {stopped_by}, {stopped_by_error}") + + def on_transcription_error(self, message): + self.logger.error(f"transcription error {message}") + + def on_transcription_started(self, status): + self.logger.info(f"transcription started {status}") + + def set_image(self, image: bytes): + self.image: bytes | None = image + + def run_camera(self): + try: + while not self.stop_threads.is_set(): + if self.image: + self.camera.write_frame(self.image) + + time.sleep(1.0 / 8.0) # 8 fps + except Exception as e: + self.logger.error(f"Exception {e} in camera thread.") + + print("==== camera thread exitings") + + def handle_user_started_talking(self): + # TODO: allow configuration of the timer timeout + self.logger.error("user started talking") + self.speech_timeout = Timer(1.0, self.utterance_interrupt) + + def handle_user_stopped_talking(self): + self.logger.error("user stopped talking, canceling utterance interrupt") + if self.speech_timeout: + self.speech_timeout.cancel() + + def utterance_interrupt(self): + self.logger.error("utterance interrupt") + self.is_interrupted.set() + + def handle_transcription_fragment(self, fragment): + if not self.can_interrupt: + return + + # start generating a new response. We'll do the fast parts of the interrupt + # now but wait for the state transition after we've kicked off the prepare + # on the new response. + if ( + self.current_response + and self.current_response.state < AsyncProcessorState.INTERRUPTED + ): + self.interrupt_time = time.perf_counter() + self.is_interrupted.set() + self.current_response.interrupt() + + self.display_thinking() + self.message_handler.add_user_message(fragment) + + new_response = Response(self.services, self.message_handler, self.output_queue) + new_response.set_state_callback( + AsyncProcessorState.DONE, self.on_response_played + ) + new_response.set_state_callback( + AsyncProcessorState.FINALIZED, self.on_response_finished + ) + new_response.prepare() + + self.response_semaphore.acquire() + if ( + self.current_response + and self.current_response.state < AsyncProcessorState.INTERRUPTED + ): + self.current_response.wait_for_state_transition( + AsyncProcessorState.FINALIZED + ) + + self.current_response = new_response + self.current_response.play() + + self.response_semaphore.release() + + def display_waiting(self): + # I don't love this design, need to think more about how to do this well + listening_images = [ + "sc-listen-1", + "sc-listen-1", + "sc-listen-1", + "sc-listen-1", + "sc-listen-2", + "sc-listen-1", + "sc-listen-1", + "sc-listen-1", + "sc-listen-1", + "sc-listen-1", + "sc-listen-1", + "sc-listen-2", + "sc-listen-1", + "sc-listen-2", + "sc-listen-1", + "sc-listen-1", + "sc-listen-1", + "sc-listen-1", + "sc-listen-1", + "sc-listen-1", + "sc-listen-1", + "sc-listen-2", + "sc-listen-1", + "sc-listen-1", + "sc-listen-1", + "sc-listen-1", + "sc-listen-2", + "sc-listen-1", + "sc-listen-1", + "sc-listen-1", + ] + #self.display_images(listening_images) + + def display_thinking(self): + thinking_images = [ + "sc-think-1", + "sc-think-1", + "sc-think-2", + "sc-think-2", + "sc-think-3", + "sc-think-3", + "sc-think-4", + "sc-think-4", + ] + #self.display_images(thinking_images) + + def action(self): + self.logger.info("starting camera thread") + self.image: bytes | None = None + self.camera_thread = Thread(target=self.run_camera, daemon=True) + self.camera_thread.start() + + self.frame_consumer_thread = Thread(target=self.frame_consumer, daemon=True) + self.frame_consumer_thread.start() + + self.can_interrupt = False + self.current_response.play() + + def frame_consumer(self): + self.logger.info("🎬 Starting frame consumer thread") + b = bytearray() + smallest_write_size = 3200 + expected_idx = 0 + all_audio_frames = bytearray() + while True: + try: + frame = self.output_queue.get() + if frame["type"] == "stop": + self.logger.info("🎬 Stopping frame consumer thread") + + if os.getenv("WRITE_BOT_AUDIO", False): + filename = f"conversation-{len(all_audio_frames)}.wav" + with wave.open(filename, "wb") as f: + f.setnchannels(1) + f.setframerate(16000) + f.setsampwidth(2) + f.setcomptype("NONE", "not compressed") + f.writeframes(all_audio_frames) + return + + if frame["idx"] != expected_idx and frame["idx"] != 0: + self.logger.error( + f"🎬 Expected frame {expected_idx}, got {frame['idx']}" + ) + + expected_idx += 1 + + # if interrupted, we just pull frames off the queue and discard them + if not self.is_interrupted.is_set(): + if frame: + if frame["type"] == "audio_frame": + chunk = frame["data"] + + all_audio_frames.extend(chunk) + + b.extend(chunk) + l = len(b) - (len(b) % smallest_write_size) + if l: + self.mic.write_frames(bytes(b[:l])) + b = b[l:] + elif frame["type"] == "image_frame": + self.set_image(frame["data"]) + elif len(b): + self.mic.write_frames(bytes(b)) + b = bytearray() + else: + if self.interrupt_time: + self.logger.info(f"====== lag to stop stream ====== {time.perf_counter() - self.interrupt_time}") + self.interrupt_time = None + + if frame["type"] == "start_stream": + self.is_interrupted.clear() + + self.output_queue.task_done() + except Empty: + try: + if len(b): + self.mic.write_frames(bytes(b)) + except Exception as e: + self.logger.error(f"Exception in frame_consumer: {e}, {len(b)}") + + b = bytearray() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 000000000..69cf2c592 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +Pillow==10.1.0 +typing_extensions==4.9.0 diff --git a/services/ai_services.py b/services/ai_services.py new file mode 100644 index 000000000..0a04da8b6 --- /dev/null +++ b/services/ai_services.py @@ -0,0 +1,56 @@ +import logging + +from abc import abstractmethod +from dataclasses import dataclass +from typing import Generator +from PIL import Image + + +class AIService: + def __init__(self): + self.logger = logging.getLogger("bot-instance") + + def close(self): + pass + + +class LLMService(AIService): + # Generate a set of responses to a prompt. Yields a list of responses. + @abstractmethod + def run_llm_async( + self, messages + ) -> Generator[str, None, None]: + pass + + # Generate a responses to a prompt. Returns the response + @abstractmethod + def run_llm( + self, messages + ) -> str or None: + pass + + +class TTSService(AIService): + # Some TTS services require a specific sample rate. We default to 16k + def get_mic_sample_rate(self): + return 16000 + + # Converts the sentence to audio. Yields a list of audio frames that can + # be sent to the microphone device + @abstractmethod + def run_tts(self, sentence) -> Generator[bytes, None, None]: + pass + + +class ImageGenService(AIService): + # Renders the image. Returns an Image object. + @abstractmethod + def run_image_gen(self, sentence) -> Image.Image: + pass + + +@dataclass +class AIServiceConfig: + tts: TTSService + image: ImageGenService + llm: LLMService diff --git a/services/azure_ai_services.py b/services/azure_ai_services.py new file mode 100644 index 000000000..f0ab18fb1 --- /dev/null +++ b/services/azure_ai_services.py @@ -0,0 +1,116 @@ +import json +import io +import openai +import os +import requests + +from typing import Generator + +from daily_ai.services.ai_services import LLMService, TTSService, ImageGenService +from PIL import Image + +# See .env.example for Azure configuration needed +from azure.cognitiveservices.speech import SpeechSynthesizer, SpeechConfig, ResultReason, CancellationReason + +class AzureTTSService(TTSService): + def __init__(self): + super().__init__() + + self.speech_key = os.getenv("AZURE_SPEECH_SERVICE_KEY") + self.speech_region = os.getenv("AZURE_SPEECH_SERVICE_REGION") + + self.speech_config = SpeechConfig(subscription=self.speech_key, region=self.speech_region) + self.speech_synthesizer = SpeechSynthesizer(speech_config=self.speech_config, audio_config=None) + + def run_tts(self, sentence) -> Generator[bytes, None, None]: + self.logger.info("⌨️ running azure tts async") + ssml = "" \ + "" \ + "" \ + "" \ + "" \ + f"{sentence}" \ + " " + result = self.speech_synthesizer.speak_ssml(ssml) + self.logger.info("⌨️ got azure tts result") + if result.reason == ResultReason.SynthesizingAudioCompleted: + self.logger.info("⌨️ returning result") + # azure always sends a 44-byte header. Strip it off. + yield result.audio_data[44:] + elif result.reason == ResultReason.Canceled: + cancellation_details = result.cancellation_details + self.logger.info("Speech synthesis canceled: {}".format(cancellation_details.reason)) + if cancellation_details.reason == CancellationReason.Error: + self.logger.info("Error details: {}".format(cancellation_details.error_details)) + +class AzureLLMService(LLMService): + def get_response(self, messages, stream): + return openai.ChatCompletion.create( + api_type="azure", + api_version="2023-06-01-preview", + api_key=os.getenv("AZURE_CHATGPT_KEY"), + api_base=os.getenv("AZURE_CHATGPT_ENDPOINT"), + deployment_id=os.getenv("AZURE_CHATGPT_DEPLOYMENT_ID"), + stream=stream, + messages=messages, + ) + + + def run_llm_async(self, messages) -> Generator[str, None, None]: + local_messages = messages.copy() + messages_for_log = json.dumps(local_messages) + self.logger.info(f"==== generating chat via azure: {messages_for_log}") + + response = self.get_response(local_messages, stream=True) + + for chunk in response: + if len(chunk["choices"]) == 0: + continue + + if "content" in chunk["choices"][0]["delta"]: + if ( + chunk["choices"][0]["delta"]["content"] != {} + ): # streaming a content chunk + yield chunk["choices"][0]["delta"]["content"] + + + def run_llm(self, messages) -> str or None: + local_messages = messages.copy() + messages_for_log = json.dumps(local_messages) + self.logger.info(f"==== generating chat via azure: {messages_for_log}") + + response = self.get_response(local_messages, stream=False) + if ( + response + and len(response["choices"]) > 0 + and "message" in response["choices"][0] + and "content" in response["choices"][0]["message"] + ): + return response["choices"][0]["message"]["content"] + else: + return None + + +class AzureImageGenService(ImageGenService): + def run_image_gen(self, sentence) -> Image.Image: + self.logger.info("generating azure image", sentence) + + image = openai.Image.create( + api_type = 'azure', + api_version = '2023-06-01-preview', + api_key = os.getenv('AZURE_DALLE_KEY'), + api_base = os.getenv('AZURE_DALLE_ENDPOINT'), + deployment_id = os.getenv("AZURE_DALLE_DEPLOYMENT_ID"), + prompt=f'{sentence} in the style of {self.image_style}', + n=1, + size=f"1024x1024", + ) + + url = image["data"][0]["url"] + response = requests.get(url) + + dalle_stream = io.BytesIO(response.content) + dalle_im = Image.open(dalle_stream) + + return (url, dalle_im) diff --git a/services/cloudflare_ai_service.py b/services/cloudflare_ai_service.py new file mode 100644 index 000000000..c249da58a --- /dev/null +++ b/services/cloudflare_ai_service.py @@ -0,0 +1,65 @@ +import requests +import os +from services.ai_service import AIService + +# Note that Cloudflare's AI workers are still in beta. +# https://developers.cloudflare.com/workers-ai/ +class CloudflareAIService(AIService): + def __init__(self): + super().__init__() + self.cloudflare_account_id = os.getenv("CLOUDFLARE_ACCOUNT_ID") + self.cloudflare_api_token = os.getenv("CLOUDFLARE_API_TOKEN") + + self.api_base_url = f'https://api.cloudflare.com/client/v4/accounts/{self.cloudflare_account_id}/ai/run/' + self.headers = {"Authorization": f'Bearer {self.cloudflare_api_token}'} + + # base endpoint, used by the others + def run(self, model, input): + response = requests.post(f"{self.api_base_url}{model}", headers=self.headers, json=input) + return response.json() + + # https://developers.cloudflare.com/workers-ai/models/llm/ + def run_llm(self, messages, latest_user_message=None, stream = True): + input = { + "messages": [ + { "role": "system", "content": "You are a friendly assistant" }, + { "role": "user", "content": sentence } + ] + } + + return self.run("@cf/meta/llama-2-7b-chat-int8", input) + + # https://developers.cloudflare.com/workers-ai/models/translation/ + def run_text_translation(self, sentence, source_language, target_language): + return self.run('@cf/meta/m2m100-1.2b', { + "text": sentence, + "source_lang": source_language, + "target_lang": target_language + }) + + # https://developers.cloudflare.com/workers-ai/models/sentiment-analysis/ + def run_text_sentiment(self, sentence): + return self.run("@cf/huggingface/distilbert-sst-2-int8", {"text": sentence}) + + # https://developers.cloudflare.com/workers-ai/models/image-classification/ + def run_image_classification(self, image_url): + response = requests.get(image_url) + + if response.status_code != 200: + return {"error": "There was a problem downloading the image."} + + if response.status_code == 200: + data = response.content + inputs = {"image": list(data)} + + return self.run("@cf/microsoft/resnet-50", inputs) + + # https://developers.cloudflare.com/workers-ai/models/embedding/ + def run_embeddings(self, texts, size="medium"): + models = { + "small": "@cf/baai/bge-small-en-v1.5", # 384 output dimensions + "medium": "@cf/baai/bge-base-en-v1.5", # 768 output dimensions + "large": "@cf/baai/bge-large-en-v1.5" #1024 output dimensions + } + + return self.run(models[size], {"text": texts}) diff --git a/services/deepgram_ai_service.py b/services/deepgram_ai_service.py new file mode 100644 index 000000000..77686e927 --- /dev/null +++ b/services/deepgram_ai_service.py @@ -0,0 +1,28 @@ +import os +import requests + +from services.ai_service import AIService +from PIL import Image + + +class DeepgramAIService(AIService): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.api_key = os.getenv("DEEPGRAM_API_KEY") + + def get_mic_sample_rate(self): + return 24000 + + def run_tts(self, sentence): + self.logger.info(f"running deepgram tts for {sentence}") + base_url = "https://api.beta.deepgram.com/v1/speak" + voice = os.getenv("DEEPGRAM_VOICE") or "alpha-apollo-en-v1" # move this to an environment variable + request_url = f"{base_url}?model={voice}&encoding=linear16&container=none" + headers = {"authorization": f"token {self.api_key}"} + + r = requests.post(request_url, headers=headers, data=sentence) + self.logger.info( + f"audio fetch status code: {r.status_code}, content length: {len(r.content)}" + ) + yield r.content diff --git a/services/elevenlabs_ai_service.py b/services/elevenlabs_ai_service.py new file mode 100644 index 000000000..180927760 --- /dev/null +++ b/services/elevenlabs_ai_service.py @@ -0,0 +1,38 @@ +import os +import requests +import time + +from typing import Generator + +from daily_ai.services.ai_services import TTSService + + +class ElevenLabsTTSService(TTSService): + def __init__(self): + super().__init__() + + self.api_key = os.getenv("ELEVENLABS_API_KEY") + self.voice_id = os.getenv("ELEVENLABS_VOICE_ID") + + def run_tts(self, sentence) -> Generator[bytes, None, None]: + url = f"https://api.elevenlabs.io/v1/text-to-speech/{self.voice_id}/stream" + payload = {"text": sentence, "model_id": "eleven_turbo_v2"} + querystring = {"output_format": "pcm_16000", "optimize_streaming_latency": 2} + headers = { + "xi-api-key": self.api_key, + "Content-Type": "application/json", + } + + r = requests.request( + "POST", url, json=payload, headers=headers, params=querystring, stream=True + ) + + if r.status_code != 200: + self.logger.error( + f"audio fetch status code: {r.status_code}, error: {r.text}" + ) + return + + for chunk in r.iter_content(chunk_size=3200): + if chunk: + yield chunk diff --git a/services/google_ai_service.py b/services/google_ai_service.py new file mode 100644 index 000000000..c8b0715df --- /dev/null +++ b/services/google_ai_service.py @@ -0,0 +1,26 @@ +from services.ai_service import AIService +import openai +import os + +# To use Google Cloud's AI products, you'll need to install Google Cloud CLI and enable the TTS and in your project: https://cloud.google.com/sdk/docs/install +from google.cloud import texttospeech + +class GoogleAIService(AIService): + def __init__(self): + super().__init__() + + self.client = texttospeech.TextToSpeechClient() + self.voice = texttospeech.VoiceSelectionParams( + language_code="en-GB", name="en-GB-Neural2-F" + ) + + self.audio_config = texttospeech.AudioConfig( + audio_encoding = texttospeech.AudioEncoding.LINEAR16, + sample_rate_hertz = 16000 + ) + + def run_tts(self, sentence): + print("running google tts") + synthesis_input = texttospeech.SynthesisInput(text = sentence.strip()) + result = self.client.synthesize_speech(input=synthesis_input, voice=self.voice, audio_config=self.audio_config) + return result diff --git a/services/huggingface_ai_service.py b/services/huggingface_ai_service.py new file mode 100644 index 000000000..4492cda26 --- /dev/null +++ b/services/huggingface_ai_service.py @@ -0,0 +1,26 @@ +from services.ai_service import AIService +from transformers import pipeline + +# These functions are just intended for testing, not production use. If you'd like to use HuggingFace, you should use your own models, or do some research into the specific models that will work best for your use case. +class HuggingFaceAIService(AIService): + def __init__(self): + super().__init__() + + def run_text_sentiment(self, sentence): + classifier = pipeline("sentiment-analysis") + return classifier(sentence) + + # available models at https://huggingface.co/Helsinki-NLP (**not all models use 2-character language codes**) + def run_text_translation(self, sentence, source_language, target_language): + translator = pipeline(f"translation", model=f"Helsinki-NLP/opus-mt-{source_language}-{target_language}") + print(translator(sentence)) + + return translator(sentence)[0]["translation_text"] + + def run_text_summarization(self, sentence): + summarizer = pipeline("summarization") + return summarizer(sentence) + + def run_image_classification(self, image_path): + classifier = pipeline("image-classification") + return classifier(image_path) diff --git a/services/mock_ai_service.py b/services/mock_ai_service.py new file mode 100644 index 000000000..27d8154bf --- /dev/null +++ b/services/mock_ai_service.py @@ -0,0 +1,27 @@ +import io +import requests +import time +from PIL import Image +from services.ai_service import AIService + +class MockAIService(AIService): + def __init__(self): + super().__init__() + + def run_tts(self, sentence): + print("running tts", sentence) + time.sleep(2) + + def run_image_gen(self, sentence): + image_url = "https://d3d00swyhr67nd.cloudfront.net/w800h800/collection/ASH/ASHM/ASH_ASHM_WA1940_2_22-001.jpg" + response = requests.get(image_url) + image_stream = io.BytesIO(response.content) + image = Image.open(image_stream) + time.sleep(1) + return (image_url, image) + + def run_llm(self, messages, latest_user_message=None, stream = True): + for i in range(5): + time.sleep(1) + yield({"choices": [{"delta": {"content": f"hello {i}!"}}]}) + diff --git a/services/open_ai_service.py b/services/open_ai_service.py new file mode 100644 index 000000000..88ce45ac1 --- /dev/null +++ b/services/open_ai_service.py @@ -0,0 +1,57 @@ +from services.ai_service import AIService +import requests +from PIL import Image +import io +import openai +import os +import time +import json + +class OpenAIService(AIService): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def run_llm(self, messages, latest_user_message=None, stream = True): + local_messages = messages.copy() + if latest_user_message: + local_messages.append({"role": "user", "content": latest_user_message}) + messages_for_log = json.dumps(local_messages, indent=2) + self.logger.info(f"==== generating chat via openai: {messages_for_log}") + + model = os.getenv("OPEN_AI_MODEL") + if not model: + model = "gpt-4" + response = openai.ChatCompletion.create( + api_type = 'openai', + api_version = '2020-11-07', + api_base = "https://api.openai.com/v1", + api_key = os.getenv("OPEN_AI_KEY"), + model=model, + stream=stream, + messages=local_messages + ) + + return response + + def run_image_gen(self, sentence): + self.logger.info("🖌️ generating openai image async for ", sentence) + start = time.time() + + image = openai.Image.create( + api_type = 'openai', + api_version = '2020-11-07', + api_base = "https://api.openai.com/v1", + api_key = os.getenv("OPEN_AI_KEY"), + prompt=f'{sentence} in the style of {self.image_style}', + n=1, + size=f"1024x1024", + ) + image_url = image["data"][0]["url"] + self.logger.info("🖌️ generated image from url", image["data"][0]["url"]) + response = requests.get(image_url) + self.logger.info("🖌️ got image from url", response) + dalle_stream = io.BytesIO(response.content) + dalle_im = Image.open(dalle_stream) + self.logger.info("🖌️ total time", time.time() - start) + + return (image_url, dalle_im) diff --git a/services/playht_ai_service.py b/services/playht_ai_service.py new file mode 100644 index 000000000..d38c59b72 --- /dev/null +++ b/services/playht_ai_service.py @@ -0,0 +1,56 @@ +import io +import os +import struct +from pyht import Client +from dotenv import load_dotenv +from pyht.client import TTSOptions +from pyht.protos.api_pb2 import Format + +from services.ai_service import AIService + +class PlayHTAIService(AIService): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + self.speech_key = os.getenv("PLAY_HT_KEY") or '' + self.user_id = os.getenv("PLAY_HT_USER_ID") or '' + + self.client = Client( + user_id=self.user_id, + api_key=self.speech_key, + ) + self.options = TTSOptions( + voice="s3://voice-cloning-zero-shot/820da3d2-3a3b-42e7-844d-e68db835a206/sarah/manifest.json", + sample_rate=16000, + quality="higher", + format=Format.FORMAT_WAV + ) + + def close(self): + super().close() + self.client.close() + + def run_tts(self, sentence): + b = bytearray() + in_header = True + for chunk in self.client.tts(sentence, self.options): + # skip the RIFF header. + if in_header: + b.extend(chunk) + if len(b) <= 36: + continue + else: + fh = io.BytesIO(b) + fh.seek(36) + (data, size) = struct.unpack('<4sI', fh.read(8)) + self.logger.info(f"first attempt: data: {data}, size: {hex(size)}, position: {fh.tell()}") + while data != b'data': + fh.read(size) + (data, size) = struct.unpack('<4sI', fh.read(8)) + self.logger.info(f"subsequent data: {data}, size: {hex(size)}, position: {fh.tell()}, data != data: {data != b'data'}") + self.logger.info("position: ", fh.tell()) + in_header = False + else: + if len(chunk): + yield chunk + diff --git a/storage/search.py b/storage/search.py new file mode 100644 index 000000000..cb924b244 --- /dev/null +++ b/storage/search.py @@ -0,0 +1,50 @@ +import os +import random +import time + +""" +from algoliasearch.configs import SearchConfig +from algoliasearch.search_client import SearchClient +""" + +class SearchIndexer(): + def __init__(self, story_id): + pass + + def index_text(self, text): + pass + + def index_image(self, text): + pass +""" +class AlgoliaSearchIndexer(SearchIndexer): + def __init__(self, story_id): + self.index = None + self.story_id = story_id + + self.search_enabled = os.getenv('ALGOLIA_APP_ID') and os.getenv('ALGOLIA_API_KEY') + if self.search_enabled: + config = SearchConfig(os.getenv('ALGOLIA_APP_ID'), os.getenv('ALGOLIA_API_KEY')) + self.algolia = SearchClient.create_with_config(config) + self.index = self.algolia.init_index('daily-llm-conversations') + + def index_text(self, text): + if self.index: + res = self.index.save_object({ + "objectID": hex(random.getrandbits(128))[2:], + "storyID": self.story_id, + "type": "text", + "text": text, + "createdAt": int(time.time()) + }).wait() + + def index_image(self, url): + if self.index: + self.index.save_object({ + "objectID": hex(random.getrandbits(128))[2:], + "storyID": self.story_id, + "type": "image", + "image": url, + "createdAt": int(time.time()) + }).wait() +""" diff --git a/tests/test_asyncprocessor.py b/tests/test_asyncprocessor.py new file mode 100644 index 000000000..fe219e214 --- /dev/null +++ b/tests/test_asyncprocessor.py @@ -0,0 +1,179 @@ +import time +import unittest + +from queue import Queue, Empty +from threading import Thread, Event +from typing import Generator + +from services.ai_services import LLMService, TTSService, ImageGenService +from message_handler.message_handler import MessageHandler +from async_processor.async_processor import ( + AsyncProcessor, + AIServiceConfig, + AsyncProcessorState, + Response +) + +class MockTTSService(TTSService): + def run_tts(self, sentence): + for word in sentence.split(' '): + time.sleep(0.1) + yield bytes(word, "utf-8") + +class MockLLMService(LLMService): + def run_llm_async(self, messages) -> Generator[str, None, None]: + for i in ["Hello ", "there.", "How are ", "you?", "I ", "hope ", "you ", "are ", "well."]: + time.sleep(0.1) + yield i + + +class MockImageService(ImageGenService): + def run_image_gen(self, sentence) -> None: + return None + +class TestResponse(unittest.TestCase): + def test_base_state_transitions(self): + mock_tts_service = MockTTSService() + mock_llm_service = MockLLMService() + mock_image_service = MockImageService() + processor = AsyncProcessor(AIServiceConfig(tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service)) + processor.prepare() + processor.play() + processor.finalize() + self.assertEqual(processor.state, AsyncProcessorState.FINALIZED) + + def test_state_transitions(self): + output_queue = Queue() + mock_tts_service = MockTTSService() + mock_llm_service = MockLLMService() + mock_image_service = MockImageService() + message_handler = MessageHandler("Hello World") + processor = Response( + AIServiceConfig( + tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service + ), + message_handler, + output_queue, + ) + processor.prepare() + processor.play() + + # Consume the output from the output queue. It's necessary to mark these tasks as done for the + # play function to return. + expected_words = ["Hello", "there.", "How", "are", "you?", "I", "hope", "you", "are", "well."] + + # remove the "start_stream" message from the queue + output_queue.get() + output_queue.task_done() + + while expected_words: + # get the corresponding video frame off the queue. + output_queue.get() + output_queue.task_done() + + actual_word = output_queue.get() + word = expected_words.pop(0) + self.assertEqual(actual_word['type'], 'audio_frame') + self.assertEqual(actual_word['data'], bytes(word, "utf-8")) + output_queue.task_done() + + processor.finalize() + + self.assertEqual(processor.state, AsyncProcessorState.FINALIZED) + + def test_interrupt_preparation(self): + output_queue = Queue() + mock_tts_service = MockTTSService() + mock_llm_service = MockLLMService() + mock_image_service = MockImageService() + message_handler = MessageHandler("System Message") + processor = Response( + AIServiceConfig( + tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service + ), + message_handler, + output_queue, + ) + processor.prepare() + interrupt_request_at = time.perf_counter() + processor.interrupt() + processor.finalize() + finalized_at = time.perf_counter() + self.assertTrue(0.1 < finalized_at - interrupt_request_at < 0.2) + print(f"delta: {interrupt_request_at, finalized_at}") + self.assertEqual(processor.state, AsyncProcessorState.FINALIZED) + + def test_interrupt_play(self): + output_queue = Queue() + mock_tts_service = MockTTSService() + mock_llm_service = MockLLMService() + mock_image_service = MockImageService() + message_handler = MessageHandler("System Message") + processor = Response( + AIServiceConfig( + tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service + ), + message_handler, + output_queue, + ) + processor.prepare() + processor.play() + + stop_processing_output_queue = Event() + def process_output_queue_async(): + # Consume the output from the output queue. It's necessary to mark these tasks as done for the + # play function to return. + time.sleep(0.1) + expected_words = ["Hello", "there.", "How", "are", "you?", "I", "hope", "you", "are", "well."] + while expected_words and not stop_processing_output_queue.is_set(): + try: + actual_word = output_queue.get_nowait() + if actual_word['type'] == 'audio_frame': + time.sleep(0.1) + word = expected_words.pop(0) + self.assertEqual(actual_word['type'], 'audio_frame') + self.assertEqual(actual_word['data'], bytes(word, "utf-8")) + output_queue.task_done() + except Empty: + pass + + process_output_queue = Thread(target=process_output_queue_async, daemon=True) + process_output_queue.start() + + time.sleep(0.5) + processor.interrupt() + + stop_processing_output_queue.set() + process_output_queue.join() + + processor.finalize() + self.assertEqual(processor.state, AsyncProcessorState.FINALIZED) + + def test_statechange_callback(self): + mock_tts_service = MockTTSService() + mock_llm_service = MockLLMService() + mock_image_service = MockImageService() + processor = AsyncProcessor( + AIServiceConfig( + tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service + ) + ) + is_finalized = False + def set_is_finalized(async_processor:AsyncProcessor): + nonlocal is_finalized + is_finalized = True + + processor.set_state_callback( + AsyncProcessorState.FINALIZED, set_is_finalized + ) + processor.prepare() + self.assertFalse(is_finalized) + processor.play() + self.assertFalse(is_finalized) + processor.finalize() + self.assertTrue(is_finalized) + self.assertEqual(processor.state, AsyncProcessorState.FINALIZED) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_message_handler.py b/tests/test_message_handler.py new file mode 100644 index 000000000..0c4deed20 --- /dev/null +++ b/tests/test_message_handler.py @@ -0,0 +1,129 @@ +import time +import unittest + +from unittest.mock import MagicMock, call + +from message_handler.message_handler import MessageHandler, IndexingMessageHandler +from services.ai_services import AIService, AIServiceConfig +from storage.search import SearchIndexer + + +class TestMessageHandler(unittest.TestCase): + def test_simple_intro(self): + message_handler = MessageHandler("Hello world") + self.assertEqual( + message_handler.get_llm_messages(), + [{"role": "system", "content": "Hello world"}], + ) + + def test_simple_user_message(self): + message_handler = MessageHandler("System prompt") + message_handler.add_user_message("User message") + self.assertEqual( + message_handler.get_llm_messages(), + [ + {"role": "system", "content": "System prompt"}, + {"role": "user", "content": "User message"}, + ], + ) + + def test_simple_user_and_assistant_message(self): + message_handler = MessageHandler("System prompt") + message_handler.add_user_message("User message") + message_handler.add_assistant_message("Assistant message") + self.assertEqual( + message_handler.get_llm_messages(), + [ + {"role": "system", "content": "System prompt"}, + {"role": "user", "content": "User message"}, + {"role": "assistant", "content": "Assistant message"}, + ], + ) + + def test_user_message_overwrite(self): + message_handler = MessageHandler("System prompt") + message_handler.add_user_message("User message") + message_handler.add_assistant_message("Assistant message") + message_handler.add_user_message("User message plus something else") + self.assertEqual( + message_handler.get_llm_messages(), + [ + {"role": "system", "content": "System prompt"}, + {"role": "user", "content": "User message plus something else"}, + ], + ) + + def test_user_message_after_assistant(self): + message_handler = MessageHandler("System prompt") + message_handler.add_user_message("User message") + message_handler.add_assistant_message("Assistant message") + message_handler.add_user_message("other user message") + self.assertEqual( + message_handler.get_llm_messages(), + [ + {"role": "system", "content": "System prompt"}, + {"role": "user", "content": "User message"}, + {"role": "assistant", "content": "Assistant message"}, + {"role": "user", "content": "other user message"}, + ], + ) + + +class MockAIService(AIService): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def run_llm(self, messages, latest_user_message=None, stream=True): + return {"choices": [{"message": {"content": "Parsed user message."}}]} + + +class TestIndexingMessageHandler(unittest.TestCase): + def test_user_message_finalized(self): + mock_ai_service = MockAIService() + service_config = AIServiceConfig( + mock_ai_service, mock_ai_service, mock_ai_service + ) + + mock_indexer = MagicMock(spec=SearchIndexer) + + message_handler = IndexingMessageHandler( + "Hello world", "story_id", service_config, mock_indexer + ) + message_handler.add_user_message("User message") + message_handler.add_assistant_message("Assistant message will be ignored") + message_handler.add_user_message("User message plus something else") + message_handler.finalize_user_message() + message_handler.add_assistant_message( + "New assistant message will not be ignored" + ) + message_handler.add_user_message("User message second time") + message_handler.add_assistant_message("Assistant message second time") + message_handler.write_messages_to_index() + + time.sleep(0.5) + + self.assertEqual( + mock_indexer.mock_calls, + [ + call.index_text("Parsed user message."), + call.index_text("New assistant message will not be ignored"), + ], + ) + + mock_indexer.reset_mock() + + message_handler.finalize_user_message() + + time.sleep(0.5) + + self.assertEqual( + mock_indexer.mock_calls, + [ + call.index_text("Parsed user message."), + call.index_text("Assistant message second time"), + ], + ) + + +if __name__ == "__main__": + unittest.main()