Add InterruptibleConversationWrapper
This commit is contained in:
@@ -1,347 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from queue import Queue, PriorityQueue, Empty
|
||||
from threading import Event, Semaphore, Thread
|
||||
from typing import Any, Generator, Iterator, Optional, Type
|
||||
|
||||
from dailyai.queue_frame import QueueFrame, FrameType
|
||||
from dailyai.message_handler.message_handler import MessageHandler
|
||||
from dailyai.services.ai_services import AIServiceConfig
|
||||
|
||||
class AsyncProcessorState:
    """Integer lifecycle states of an AsyncProcessor and the moves allowed between them.

    The states are plain ints numbered in lifecycle order, which lets callers
    compare them with ``<`` / ``>=`` (for example "has this processor reached
    at least READY?").
    """

    (
        INIT,  # setting class variables, other synchronous activities
        PREPARING,  # making asynchronous requests to the LLM and other services
        READY,  # ready to start presenting to the user (may not have all data yet)
        PLAYING,  # playing the response
        INTERRUPTING,  # interrupt requested; in-flight processing is shutting down
        INTERRUPTED,  # interrupt requested; in-flight processing has stopped
        DONE,  # response has been played or interrupted
        FINALIZING,  # updating records of speech, LLM context, etc.
        FINALIZED,  # response is complete: everything is updated, or it was interrupted
    ) = range(9)

    # For every state, the states that may directly follow it.
    state_transitions = {
        INIT: [PREPARING, INTERRUPTING],
        PREPARING: [READY, INTERRUPTING],
        READY: [PLAYING, INTERRUPTING],
        PLAYING: [DONE, INTERRUPTING],
        INTERRUPTING: [INTERRUPTED],
        INTERRUPTED: [DONE],
        DONE: [FINALIZING],
        FINALIZING: [FINALIZED],
        FINALIZED: [FINALIZED],
    }
|
||||
|
||||
|
||||
@dataclass(order=True)
class StateTransitionItem:
    """A waiter parked in a priority queue until a given state is reached.

    Ordering and equality consider only ``state``; the event is excluded from
    comparison so that items can be sorted by the state they wait for.
    """

    # State the waiting thread wants to see reached.
    state: int
    # Set by the processor to release the waiting thread.
    evt: Event = field(compare=False)
|
||||
|
||||
class AsyncProcessor:
    """Base class for a response that is prepared, played and finalized in stages.

    The processor moves through the states of ``AsyncProcessorState``.
    Preparation and playback run on daemon threads started by ``prepare()`` and
    ``play()``; subclasses supply the actual work by overriding the hook
    methods at the bottom of the class (``start_preparation``, ``do_play``,
    ``do_interruption``, ...), all of which are no-ops here.
    """

    def __init__(
        self,
        services: AIServiceConfig
    ) -> None:
        """Create a processor in the INIT state.

        Args:
            services: service configuration, stored on ``self.services`` for
                subclasses to use.
        """
        self.state = AsyncProcessorState.INIT
        self.prepare_thread = None
        self.play_thread = None
        self.finalize_thread = None

        self.services: AIServiceConfig = services

        # Binary semaphore used as a mutex around state changes and the
        # registration of waiters.
        self.state_transition_semaphore = Semaphore()
        self.waiting_for_state_changes = PriorityQueue()
        self.state_queue = Queue()

        # state -> list of callables invoked (with this processor) on entry.
        self.state_change_callbacks = defaultdict(list)

        self.was_interrupted = False

        self.logger: logging.Logger = logging.getLogger("dailyai")

    def _wake_waiters(self, state: int) -> None:
        """Release every parked waiter whose target state is <= ``state``.

        Must be called with ``state_transition_semaphore`` held.
        """
        try:
            while True:
                waiter = self.waiting_for_state_changes.get_nowait()
                if waiter.state <= state:
                    waiter.evt.set()
                else:
                    # The queue is ordered by state, so nobody further back
                    # can be satisfied either; put this one back and stop.
                    self.waiting_for_state_changes.put(waiter)
                    break
        except Empty:
            pass

    def _try_transition(self, state: int) -> bool:
        """Atomically move to ``state`` if the transition table allows it.

        The validity check, the state change and the waiter wake-up all happen
        under the same lock, so a concurrent transition or a concurrently
        registering waiter cannot slip in between them.

        Returns:
            True if the state was changed, False if the transition is not
            allowed from the current state.
        """
        with self.state_transition_semaphore:
            if state not in AsyncProcessorState.state_transitions[self.state]:
                return False
            self.state: int = state
            self._wake_waiters(state)

        # Callbacks run outside the lock: they are free to trigger further
        # transitions on this processor (e.g. calling finalize()).
        for callback in list(self.state_change_callbacks[state]):
            callback(self)
        return True

    def set_state(self, state: int) -> None:
        """Move to ``state``, waking waiters and running that state's callbacks.

        Raises:
            Exception: if ``state`` is not a valid successor of the current state.
        """
        if not self._try_transition(state):
            self.logger.error(
                f"Invalid state transition from {self.state} to {state} in {self.__class__.__name__}"
            )
            raise Exception(f"Invalid state transition from {self.state} to {state}")

    #
    # This is used for state transitions that could be blocked by an interruption.
    # If we are interrupted, we silently fail this call. Use only if you know that
    # this state transition should fail if the processor has been interrupted.
    #

    def maybe_set_state(self, state: int) -> bool:
        """Move to ``state`` if allowed; return whether the move happened.

        Unlike ``set_state`` this never raises for a disallowed transition.
        """
        return self._try_transition(state)

    def wait_for_state_transition(self, state: int) -> None:
        """Block until the processor has reached ``state`` or any later state.

        Gives up after 120 seconds, logging an error; the caller is not
        otherwise told about the timeout.
        """
        with self.state_transition_semaphore:
            # Check under the lock: otherwise the state could advance between
            # the check and the registration and the wake-up would be missed.
            if self.state >= state:
                return
            evt = Event()
            self.waiting_for_state_changes.put(StateTransitionItem(state, evt))

        result = evt.wait(120.0)
        if not result:
            self.logger.error(
                f"Timed out waiting for state transition to {state} from {self.state}"
            )

    def set_state_callback(self, state: int, callback: callable) -> None:
        """Register ``callback(processor)`` to run each time ``state`` is entered."""
        self.state_change_callbacks[state].append(callback)

    def prepare(self) -> None:
        """Start preparation on a daemon thread and block until READY (or later)."""
        self.prepare_thread = Thread(target=self.async_prepare, daemon=True)
        self.prepare_thread.start()
        self.wait_for_state_transition(AsyncProcessorState.READY)

    def play(self) -> None:
        """Start playback on a daemon thread once READY; block until PLAYING (or later)."""
        self.wait_for_state_transition(AsyncProcessorState.READY)
        self.play_thread = Thread(target=self.async_play, daemon=True)
        self.play_thread.start()
        self.wait_for_state_transition(AsyncProcessorState.PLAYING)

    def finalize(self) -> None:
        """Run ``do_finalization`` between the FINALIZING and FINALIZED states."""
        # don't finalize until we're done playing.
        self.wait_for_state_transition(AsyncProcessorState.DONE)
        self.set_state(AsyncProcessorState.FINALIZING)
        self.do_finalization()
        self.set_state(AsyncProcessorState.FINALIZED)

    def interrupt(self) -> None:
        """Interrupt the response: INTERRUPTING -> INTERRUPTED -> DONE.

        A no-op when there is nothing left to interrupt, i.e. when the
        processor is already being interrupted, is done, or is finalizing or
        finalized. This also covers the race where playback completes between
        a caller's state check and its call to ``interrupt()``.
        """
        if not self.maybe_set_state(AsyncProcessorState.INTERRUPTING):
            return

        self.was_interrupted = True
        self.do_interruption()
        self.set_state(AsyncProcessorState.INTERRUPTED)
        self.set_state(AsyncProcessorState.DONE)

    def async_play(self) -> None:
        """Thread target for ``play()``: run ``do_play`` unless interrupted first."""
        self.logger.info("Starting to play")
        if self.maybe_set_state(AsyncProcessorState.PLAYING):
            self.do_play()
            self.maybe_set_state(AsyncProcessorState.DONE)

    def async_prepare(self) -> None:
        """Thread target for ``prepare()``: run the preparation hooks in order.

        Stops quietly if an interruption has made a transition impossible,
        instead of raising inside the worker thread.
        """
        if not self.maybe_set_state(AsyncProcessorState.PREPARING):
            return
        self.start_preparation()
        if not self.maybe_set_state(AsyncProcessorState.READY):
            return
        self.continue_preparation()
        self.logger.info(f"Preparation done for {self.__class__.__name__}")
        self.preparation_done()

    # ---- hooks for subclasses; all are no-ops in the base class ----

    def start_preparation(self) -> None:
        """Hook: work done in PREPARING, before the processor becomes READY."""
        pass

    def continue_preparation(self) -> None:
        """Hook: work done on the prepare thread after the processor is READY."""
        pass

    def preparation_done(self):
        """Hook: called on the prepare thread after ``continue_preparation`` returns."""
        pass

    def get_preparation_iterator(self) -> Iterator:
        """Hook: iterator of chunks to prepare; the base version yields a single None."""
        yield None

    def process_chunk(self, chunk) -> None:
        """Hook: handle one chunk produced during preparation."""
        pass

    def do_interruption(self) -> None:
        """Hook: stop in-flight work; called by ``interrupt()`` while INTERRUPTING."""
        pass

    def do_play(self) -> None:
        """Hook: present the response; called on the play thread while PLAYING."""
        pass

    def do_finalization(self) -> None:
        """Hook: record the outcome; called by ``finalize()`` while FINALIZING."""
        pass
|
||||
|
||||
# A common class for responses that use a message queue and
|
||||
# an output queue.
|
||||
|
||||
class OrchestratorResponse(AsyncProcessor):
    """AsyncProcessor that additionally carries a message handler and an output queue."""

    def __init__(
        self,
        services,
        message_handler,
        output_queue,
    ) -> None:
        """Store the collaborators shared by all orchestrator responses.

        Args:
            services: passed through to ``AsyncProcessor``.
            message_handler: conversation history handler, kept on
                ``self.message_handler``.
            output_queue: queue that produced frames are put on, kept on
                ``self.output_queue``.
        """
        super().__init__(services=services)

        self.output_queue: Queue = output_queue
        self.message_handler: MessageHandler = message_handler
|
||||
|
||||
|
||||
class LLMResponse(OrchestratorResponse):
    """Response that streams LLM output, splits it into clauses and plays them via TTS.

    The prepare thread turns the LLM stream into clauses and queues one lazy
    TTS frame generator per clause on ``chunks_in_preparation``; the play
    thread drains that queue and puts the resulting frames on the output
    queue. A ``(None, None)`` entry marks the end of the queue.
    """

    # Characters that end a clause worth sending to TTS.
    _SENTENCE_ENDINGS = (".", "!", "?")

    def __init__(
        self,
        services,
        message_handler,
        output_queue,
    ) -> None:
        super().__init__(services, message_handler, output_queue)

        # START_STREAM is sent once, just before the first frames.
        self.has_sent_first_frame = False

        self.chunks_in_preparation = Queue()

        # Clauses handed to the output queue; written back in do_finalization.
        self.llm_responses: list[str] = []

    def _is_active(self) -> bool:
        """Return True while the response is READY or PLAYING (i.e. not interrupted or done)."""
        return self.state in (
            AsyncProcessorState.READY,
            AsyncProcessorState.PLAYING,
        )

    def get_preparation_iterator(self) -> Iterator:
        """Return an iterator of clauses produced by the LLM for the current messages."""
        messages_for_llm = self.message_handler.get_llm_messages()
        self.logger.debug(f"Messages for llm: {json.dumps(messages_for_llm, indent=2)}")
        return self.clauses_from_chunks(
            self.services.llm.run_llm_async(messages_for_llm)
        )

    def clauses_from_chunks(self, chunks) -> Iterator:
        """Group streamed text chunks into clauses ending in ``.``, ``!`` or ``?``.

        Stops consuming ``chunks`` as soon as the response is no longer active.
        Whatever text is still buffered at that point, or at the end of the
        stream, is yielded as a final clause if it is not blank.
        """
        out = ""
        for chunk in chunks:
            if not self._is_active():
                break

            out += chunk

            # Plain suffix test rather than a regex: "." in a regex does not
            # match newlines, so buffered text containing a line break would
            # never be recognised as a finished sentence.
            if out.rstrip().endswith(self._SENTENCE_ENDINGS):
                yield out.strip()
                out = ""

        if out.strip():
            yield out.strip()

    def get_frames_from_tts_response(self, audio_frame) -> list[QueueFrame]:
        """Wrap one TTS audio frame in the list of queue frames to emit for it."""
        return [QueueFrame(FrameType.AUDIO, audio_frame)]

    def get_frames_from_chunk(self, chunk) -> Generator[list[QueueFrame], Any, None]:
        """Lazily run TTS on ``chunk``, yielding a list of queue frames per audio frame."""
        for audio_frame in self.services.tts.run_tts(chunk):
            yield self.get_frames_from_tts_response(audio_frame)

    def start_preparation(self) -> None:
        """Create the clause iterator; it is consumed in ``continue_preparation``."""
        self.preparation_iterator = self.get_preparation_iterator()

    def continue_preparation(self) -> None:
        """Queue every clause for playback until the stream ends or the response stops."""
        for chunk in self.preparation_iterator:
            if not self._is_active():
                break

            self.process_chunk(chunk)

    def process_chunk(self, chunk) -> None:
        """Queue ``chunk`` together with its (not yet started) TTS frame generator."""
        self.chunks_in_preparation.put((chunk, self.get_frames_from_chunk(chunk)))

    def preparation_done(self):
        """Queue the end-of-stream marker for the play thread."""
        self.chunks_in_preparation.put((None, None))

    def do_play(self) -> None:
        """Play queued clauses until the end marker arrives or the response stops."""
        while self._is_active():
            prepared_chunk = self.chunks_in_preparation.get()
            if prepared_chunk[0] is None:
                return

            self.play_prepared_chunk(prepared_chunk)

    def play_prepared_chunk(self, prepared_chunk) -> None:
        """Send one clause's frames to the output queue and record the clause.

        Blocks until the consumer has processed everything on the output queue.
        """
        chunk, tts_generator = prepared_chunk
        for frames in tts_generator:
            if not self._is_active():
                break

            if not self.has_sent_first_frame:
                self.output_queue.put(QueueFrame(FrameType.START_STREAM, None))
                self.has_sent_first_frame = True

            for frame in frames:
                self.output_queue.put(frame)

        self.output_queue.join()
        self.llm_responses.append(chunk)

    def do_finalization(self) -> None:
        """Hand the recorded clauses to the message handler as assistant output."""
        self.message_handler.add_assistant_messages(self.llm_responses)

    def do_interruption(self) -> None:
        """Unblock the play thread and wait for both worker threads to exit."""
        # The end marker wakes a play thread blocked on an empty queue.
        self.chunks_in_preparation.put((None, None))

        if self.prepare_thread and self.prepare_thread.is_alive():
            self.prepare_thread.join()

        if self.play_thread and self.play_thread.is_alive():
            self.play_thread.join()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ConversationProcessorCollection:
    """Immutable set of response classes, one per conversation phase.

    Each field holds a class (not an instance) derived from
    ``OrchestratorResponse``, or None when that phase has no processor.
    """

    # Class used for the introduction.
    introduction: Optional[Type[OrchestratorResponse]] = None
    # Class used for the "waiting" phase.
    waiting: Optional[Type[OrchestratorResponse]] = None
    # Class used to respond to the user.
    response: Optional[Type[OrchestratorResponse]] = None
    # Class used for the goodbye.
    goodbye: Optional[Type[OrchestratorResponse]] = None
|
||||
76
src/dailyai/conversation_wrappers.py
Normal file
76
src/dailyai/conversation_wrappers.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import functools
|
||||
from typing import AsyncGenerator, Awaitable, Callable
|
||||
from dailyai.queue_aggregators import LLMContextAggregator
|
||||
from dailyai.queue_frame import EndStreamQueueFrame, QueueFrame, TranscriptionQueueFrame
|
||||
|
||||
class InterruptibleConversationWrapper:
    """Drives a conversation in which new user speech interrupts the pending reply.

    Every transcription frame from another participant cancels the reply that
    is currently scheduled or running, extends the phrase heard so far, and
    schedules a new reply after a short delay. Only a reply that finishes
    without being cancelled commits its messages and resets the phrase.
    """

    def __init__(
        self,
        frame_generator: Callable[[], AsyncGenerator[QueueFrame, None]],
        runner: Callable[
            [str, LLMContextAggregator, LLMContextAggregator], Awaitable[None]
        ],
        interrupt: Callable[[], None],
        my_participant_id: str|None,
        llm_messages: list[dict[str, str]],
        llm_context_aggregator_in=LLMContextAggregator,
        llm_context_aggregator_out=LLMContextAggregator,
        delay_before_speech_seconds: float = 1.0,
    ):
        """Store the collaborators.

        Args:
            frame_generator: called once by ``run_conversation``; yields the
                incoming frames.
            runner: awaited with ``(user_speech, aggregator_in, aggregator_out)``
                to produce a reply.
            interrupt: called whenever a scheduled reply is cancelled.
            my_participant_id: frames from this participant are ignored.
            llm_messages: initial message list.
            llm_context_aggregator_in: factory for the "user" aggregator.
            llm_context_aggregator_out: factory for the "assistant" aggregator.
            delay_before_speech_seconds: pause between the last transcription
                frame and the start of the reply.
        """
        self._frame_generator: Callable[[], AsyncGenerator[QueueFrame, None]] = frame_generator
        self._runner: Callable[
            [str, LLMContextAggregator, LLMContextAggregator], Awaitable[None]
        ] = runner
        self._interrupt: Callable[[], None] = interrupt
        self._my_participant_id = my_participant_id
        self._messages: list[dict[str, str]] = llm_messages
        self._delay_before_speech_seconds = delay_before_speech_seconds
        self._llm_context_aggregator_in = llm_context_aggregator_in
        self._llm_context_aggregator_out = llm_context_aggregator_out

        # User speech accumulated since the last reply that ran to completion.
        self._current_phrase = ""

    def update_messages(self, new_messages: list[dict[str, str]], task: asyncio.Task | None):
        """Done-callback for a reply task: commit its messages unless it was cancelled.

        Args:
            new_messages: the message list the reply task worked on.
            task: the finished task; nothing happens when it is None or cancelled.
        """
        if task is None or task.cancelled():
            return

        self._current_phrase = ""
        self._messages = new_messages

    async def speak_after_delay(self, user_speech, messages):
        """Wait for the configured delay, then run the reply for ``user_speech``.

        Cancelling the task during the delay prevents the runner from starting.
        """
        await asyncio.sleep(self._delay_before_speech_seconds)
        tma_in = self._llm_context_aggregator_in(
            messages, "user", self._my_participant_id, False
        )
        tma_out = self._llm_context_aggregator_out(
            messages, "assistant", self._my_participant_id
        )

        await self._runner(user_speech, tma_in, tma_out)

    async def run_conversation(self):
        """Consume frames until an EndStreamQueueFrame, scheduling interruptible replies.

        NOTE(review): a reply still pending when the stream ends is neither
        awaited nor cancelled here — confirm that is what callers expect.
        """
        current_response_task = None

        async for frame in self._frame_generator():
            if isinstance(frame, EndStreamQueueFrame):
                break
            elif not isinstance(frame, TranscriptionQueueFrame):
                continue

            # Ignore the bot's own speech.
            if frame.participantId == self._my_participant_id:
                continue

            if current_response_task:
                current_response_task.cancel()
                self._interrupt()

            # Join with a single space and strip, so the first fragment of an
            # utterance does not start with a stray leading space.
            self._current_phrase = f"{self._current_phrase} {frame.text}".strip()

            # Each reply works on its own copy; the copy is only adopted (in
            # update_messages) if the reply completes without cancellation.
            current_llm_messages = copy.deepcopy(self._messages)
            current_response_task = asyncio.create_task(
                self.speak_after_delay(self._current_phrase, current_llm_messages)
            )
            current_response_task.add_done_callback(
                functools.partial(self.update_messages, current_llm_messages)
            )
|
||||
@@ -1,127 +0,0 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
from dataclasses import dataclass
|
||||
from queue import Queue, Empty
|
||||
from threading import Thread
|
||||
|
||||
from dailyai.storage.search import SearchIndexer
|
||||
from dailyai.services.ai_services import AIServiceConfig
|
||||
|
||||
|
||||
@dataclass
class Message:
    """One entry of the conversation history."""

    # Who produced the message, e.g. "system", "user" or "assistant".
    type: str
    # When the entry was created or last rewritten.
    timestamp: float
    # The text of the message.
    message: str
|
||||
|
||||
|
||||
class MessageHandler:
    """Keeps the conversation history as a list of ``Message`` records.

    User speech arrives in fragments. Until a user message is finalized,
    further fragments are merged into it and anything recorded after it is
    dropped, so an interrupted exchange collapses back into one user message.
    """

    def __init__(self, intro):
        """Start the history with ``intro`` as the system message."""
        self.messages: list[Message] = [Message("system", time.time(), intro)]
        # Index of the most recent user message, if any.
        self.last_user_message_idx: int | None = None
        # Index of the most recent user message that has been finalized.
        self.finalized_user_message_idx: int | None = None

    def add_user_message(self, message) -> None:
        """Record a fragment of user speech.

        If the latest user message has not been finalized, the fragment is
        appended to it and every later message is discarded; otherwise a new
        user message is started.
        """
        pending_idx = self.last_user_message_idx
        if pending_idx is not None and pending_idx != self.finalized_user_message_idx:
            previous_message: str = self.messages[pending_idx].message
            self.messages[pending_idx] = Message(
                "user", time.time(), ' '.join([previous_message, message])
            )
            self.messages = self.messages[: pending_idx + 1]
        else:
            self.messages.append(Message("user", time.time(), message))

        self.last_user_message_idx = len(self.messages) - 1

    def add_assistant_message(self, message) -> None:
        """Append ``message`` to the trailing assistant message, or start a new one."""
        if self.messages[-1].type == "assistant":
            self.messages[-1].message += " " + message
        else:
            self.messages.append(Message("assistant", time.time(), message))

    def add_assistant_messages(self, messages) -> None:
        """Record ``messages`` as one new assistant message, joined by spaces.

        Nothing is recorded when there is no text at all (for example a
        response interrupted before anything was spoken): an empty assistant
        message would only pollute the history sent to the LLM.
        """
        text = " ".join(messages)
        if not text.strip():
            return
        self.messages.append(Message("assistant", time.time(), text))

    def get_llm_messages(self) -> list[dict[str, str]]:
        """Return the history as ``{"role": ..., "content": ...}`` dicts."""
        return [{"role": m.type, "content": m.message} for m in self.messages]

    def finalize_user_message(self) -> None:
        """Mark the latest user message as final, so later speech starts a new one."""
        self.finalized_user_message_idx = self.last_user_message_idx

    def shutdown(self) -> None:
        """Release resources; nothing to do in the base handler."""
        pass
|
||||
|
||||
class IndexingMessageHandler(MessageHandler):
    """MessageHandler that also writes finalized messages to a search index.

    Message indices are queued by ``write_messages_to_storage`` and consumed
    by a daemon thread running ``storage_writer``.
    """

    def __init__(
        self, intro, services: AIServiceConfig, indexer: SearchIndexer
    ) -> None:
        """Set up the handler and start the background writer thread.

        Args:
            intro: system message, passed to ``MessageHandler``.
            services: kept on ``self.services``.
            indexer: object whose ``index_text`` receives each message's text.
        """
        super().__init__(intro)
        self.services = services

        self.search_indexer = indexer

        # Created before the thread starts so the writer can log safely.
        self.logger = logging.getLogger("dailyai")

        # Highest message index the writer has handled; index 0 is the system
        # message, which is never written.
        self.last_written_idx = 0
        self.storage_message_queue = Queue()

        self.index_writer_thread = Thread(target=self.storage_writer, daemon=True)
        self.index_writer_thread.start()

    def shutdown(self):
        """Finalize the pending user message, then stop and join the writer thread."""
        self.finalize_user_message()
        self.storage_message_queue.put(None)
        self.index_writer_thread.join()

    def storage_writer(self) -> None:
        """Thread target: index queued messages until a None marker arrives.

        Indices at or below ``last_written_idx`` are skipped, so each message
        is indexed at most once. A failure while indexing one message is
        logged and does not stop the thread.
        """
        while True:
            message_idx = self.storage_message_queue.get()
            self.storage_message_queue.task_done()

            if message_idx is None:
                return

            if message_idx <= self.last_written_idx:
                continue

            self.last_written_idx = message_idx

            try:
                message = self.messages[message_idx]
                content = message.message
                if message.type == "user":
                    content = self.cleanup_user_message(content)

                if not content:
                    # Nothing to index for an empty message.
                    continue

                # sometimes the LLM returns a string wrapped in quotes and sometimes it doesn't.
                # if it didn't, wrap it in quotes
                if not content.startswith('"'):
                    content = '"' + content + '"'

                self.search_indexer.index_text(content)
            except Exception:
                # Boundary of the writer thread: log and keep serving the queue.
                self.logger.exception(f"Failed to index message {message_idx}")

    def cleanup_user_message(self, user_message) -> str:
        """Hook to tidy user text before indexing; returns it unchanged here."""
        return user_message

    def finalize_user_message(self):
        """Finalize the latest user message and queue new messages for indexing."""
        super().finalize_user_message()
        self.write_messages_to_storage()

    def write_messages_to_storage(self):
        """Queue the indices of messages that are ready to be indexed.

        Walks the history from ``last_written_idx`` and stops at the first
        user message that comes after the finalized one. System messages are
        never queued. Does nothing until a user message has been finalized.
        """
        if self.finalized_user_message_idx is None:
            return

        for idx in range(self.last_written_idx, len(self.messages)):
            self.logger.info(
                f"Writing to storage: {self.messages[idx].type} {self.messages[idx].message}"
            )
            if (
                self.messages[idx].type == "user"
                and idx > self.finalized_user_message_idx
            ):
                break

            if self.messages[idx].type != "system":
                self.storage_message_queue.put(idx)
|
||||
@@ -1,409 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import wave
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from queue import Queue, Empty
|
||||
from opentelemetry import trace, context
|
||||
|
||||
from dailyai.async_processor.async_processor import (
|
||||
AsyncProcessor,
|
||||
AsyncProcessorState,
|
||||
ConversationProcessorCollection,
|
||||
OrchestratorResponse,
|
||||
LLMResponse,
|
||||
)
|
||||
from dailyai.queue_frame import QueueFrame, FrameType
|
||||
from dailyai.services.ai_services import AIServiceConfig
|
||||
from dailyai.message_handler.message_handler import MessageHandler
|
||||
|
||||
from threading import Thread, Semaphore, Event, Timer
|
||||
|
||||
from opentelemetry import context
|
||||
from opentelemetry.context.context import Context
|
||||
|
||||
from daily import (
|
||||
EventHandler,
|
||||
CallClient,
|
||||
Daily,
|
||||
VirtualCameraDevice,
|
||||
VirtualMicrophoneDevice,
|
||||
VirtualSpeakerDevice,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class OrchestratorConfig:
    """Connection settings for the room the Orchestrator joins."""

    # URL of the room to join.
    room_url: str
    # Token presented when joining the room.
    token: str
    # User name the bot sets for itself in the room.
    bot_name: str
    # Point in time, on the time.time() clock, after which the bot leaves.
    expiration: float
|
||||
|
||||
# NOTE: this object is shared as the default value of the Orchestrator
# constructor's `conversation_processors` parameter. Sharing it is safe
# because the ConversationProcessorCollection dataclass is declared with
# frozen=True, so no caller can mutate it.
#
# `waiting` and `goodbye` are left at their default of None.
default_conversation_collection = ConversationProcessorCollection(
    introduction=LLMResponse,
    response=LLMResponse,
)
|
||||
|
||||
|
||||
class Orchestrator(EventHandler):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
daily_config: OrchestratorConfig,
|
||||
ai_service_config: AIServiceConfig,
|
||||
message_handler: MessageHandler,
|
||||
conversation_processors: ConversationProcessorCollection = default_conversation_collection,
|
||||
tracer=None,
|
||||
):
|
||||
self.bot_name: str = daily_config.bot_name
|
||||
self.room_url: str = daily_config.room_url
|
||||
self.token: str = daily_config.token
|
||||
self.expiration: float = daily_config.expiration
|
||||
|
||||
self.logger: logging.Logger = logging.getLogger("dailyai")
|
||||
self.tracer = tracer or trace.get_tracer("orchestrator")
|
||||
|
||||
self.ctx: Context = context.get_current()
|
||||
|
||||
self.transcription = ""
|
||||
self.last_fragment_at = None
|
||||
self.talked_at = None
|
||||
self.paused_at = None
|
||||
|
||||
self.logger.info(f"Creating Response for introductions")
|
||||
self.services: AIServiceConfig = ai_service_config
|
||||
self.output_queue = Queue()
|
||||
self.is_interrupted = Event()
|
||||
self.stop_threads = Event()
|
||||
self.story_started = False
|
||||
|
||||
self.message_handler = message_handler
|
||||
self.conversation_processors: ConversationProcessorCollection = conversation_processors
|
||||
|
||||
if conversation_processors.introduction is not None:
|
||||
intro = conversation_processors.introduction(
|
||||
services=self.services, message_handler=self.message_handler, output_queue=self.output_queue
|
||||
)
|
||||
intro.prepare()
|
||||
intro.set_state_callback(AsyncProcessorState.DONE, self.on_intro_played)
|
||||
intro.set_state_callback(AsyncProcessorState.FINALIZED, self.on_intro_finished)
|
||||
self.logger.info(f"Introduction is preparing")
|
||||
|
||||
self.current_response: AsyncProcessor = intro
|
||||
self.can_interrupt = False
|
||||
# self.response_event.set()
|
||||
self.response_semaphore = Semaphore()
|
||||
|
||||
self.speech_timeout = None
|
||||
self.interrupt_time = None
|
||||
|
||||
self.logger.info("Configuring daily")
|
||||
self.configure_daily()
|
||||
|
||||
def configure_daily(self):
|
||||
Daily.init()
|
||||
self.client = CallClient(event_handler=self)
|
||||
|
||||
self.logger.info(f"Mic sample rate: {self.services.tts.get_mic_sample_rate()}")
|
||||
self.mic: VirtualMicrophoneDevice = Daily.create_microphone_device(
|
||||
"mic", sample_rate=self.services.tts.get_mic_sample_rate(), channels=1
|
||||
)
|
||||
self.speaker: VirtualSpeakerDevice = Daily.create_speaker_device(
|
||||
"speaker", sample_rate=16000, channels=1
|
||||
)
|
||||
self.camera: VirtualCameraDevice = Daily.create_camera_device(
|
||||
"camera", width=720, height=1280, color_format="RGB"
|
||||
)
|
||||
|
||||
Daily.select_speaker_device("speaker")
|
||||
|
||||
self.client.set_user_name(self.bot_name)
|
||||
self.client.join(self.room_url, self.token, completion=self.call_joined)
|
||||
|
||||
self.client.update_inputs(
|
||||
{
|
||||
"camera": {
|
||||
"isEnabled": True,
|
||||
"settings": {
|
||||
"deviceId": "camera",
|
||||
},
|
||||
},
|
||||
"microphone": {
|
||||
"isEnabled": True,
|
||||
"settings": {
|
||||
"deviceId": "mic",
|
||||
"customConstraints": {
|
||||
"autoGainControl": {"exact": False},
|
||||
"echoCancellation": {"exact": False},
|
||||
"noiseSuppression": {"exact": False},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
self.client.update_publishing(
|
||||
{
|
||||
"camera": {
|
||||
"sendSettings": {
|
||||
"maxQuality": "low",
|
||||
"encodings": {
|
||||
"low": {
|
||||
"maxBitrate": 250000,
|
||||
"scaleResolutionDownBy": 1.333,
|
||||
"maxFramerate": 8,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
self.my_participant_id = self.client.participants()["local"]["id"]
|
||||
|
||||
def start(self) -> None:
|
||||
# TODO: this loop could, I think, be replaced with a timer and an event
|
||||
self.participant_left = False
|
||||
|
||||
try:
|
||||
participant_count: int = len(self.client.participants())
|
||||
self.logger.info(f"{participant_count} participants in room")
|
||||
while time.time() < self.expiration and not self.participant_left:
|
||||
# all handling of incoming transcriptions happens in on_transcription_message
|
||||
time.sleep(1)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Exception {e}")
|
||||
finally:
|
||||
self.client.leave()
|
||||
|
||||
def stop(self):
|
||||
self.logger.info("Stop current response")
|
||||
if self.current_response:
|
||||
if self.current_response.state < AsyncProcessorState.INTERRUPTED:
|
||||
self.current_response.interrupt()
|
||||
|
||||
self.logger.info("Wait for state transition")
|
||||
self.current_response.wait_for_state_transition(AsyncProcessorState.FINALIZED)
|
||||
|
||||
self.stop_threads.set()
|
||||
self.camera_thread.join()
|
||||
self.logger.info("Camera thread stopped")
|
||||
|
||||
self.logger.info("Put stop in output queue")
|
||||
self.output_queue.put(QueueFrame(FrameType.END_STREAM, None))
|
||||
|
||||
self.frame_consumer_thread.join()
|
||||
self.logger.info("Orchestrator stopped.")
|
||||
|
||||
def on_intro_played(self, intro):
|
||||
self.logger.info(f"Introduction has played")
|
||||
self.can_interrupt = True
|
||||
intro.finalize()
|
||||
|
||||
def on_intro_finished(self, intro):
|
||||
self.logger.info(f"Introduction has finished")
|
||||
waiting = self.conversation_processors.waiting(self.services, self.message_handler, self.output_queue)
|
||||
waiting.prepare()
|
||||
waiting.play()
|
||||
|
||||
def on_response_played(self, response):
|
||||
response.finalize()
|
||||
|
||||
def on_response_finished(self, response):
|
||||
if not response.was_interrupted:
|
||||
self.message_handler.finalize_user_message()
|
||||
|
||||
def call_joined(self, join_data, client_error):
|
||||
self.logger.info(f"Call_joined: {join_data}, {client_error}")
|
||||
self.client.start_transcription(
|
||||
{
|
||||
"language": "en",
|
||||
"tier": "nova",
|
||||
"model": "2-conversationalai",
|
||||
"profanity_filter": True,
|
||||
"redact": False,
|
||||
"extra": {
|
||||
"endpointing": True,
|
||||
"punctuate": False,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
def on_participant_joined(self, participant):
|
||||
with self.tracer.start_as_current_span("on_participant_joined", context=self.ctx):
|
||||
self.logger.info(f"on_participant_joined: {participant}")
|
||||
|
||||
# TODO: figure out the architecture to get the story id to the client
|
||||
# self.client.send_app_message({"event": "story-id", "storyID": self.story_id})
|
||||
time.sleep(2)
|
||||
|
||||
if not self.story_started:
|
||||
self.action()
|
||||
self.story_started = True
|
||||
|
||||
def on_participant_left(self, participant, reason):
|
||||
self.logger.info(f"Participant {participant} left")
|
||||
if len(self.client.participants()) < 2:
|
||||
self.participant_left = True
|
||||
|
||||
def on_app_message(self, message, sender):
|
||||
with self.tracer.start_as_current_span("on_app_message", context=self.ctx):
|
||||
self.logger.info(f"on_app_message {message} from {sender}")
|
||||
if "isSpeaking" in message and message["isSpeaking"] == True:
|
||||
self.handle_user_started_talking()
|
||||
|
||||
if "isSpeaking" in message and message["isSpeaking"] == False:
|
||||
self.handle_user_stopped_talking()
|
||||
|
||||
def on_transcription_message(self, message):
|
||||
with self.tracer.start_as_current_span("on_transcription_message", context=self.ctx):
|
||||
if message["session_id"] != self.my_participant_id:
|
||||
self.handle_transcription_fragment(message['text'])
|
||||
|
||||
def on_transcription_stopped(self, stopped_by, stopped_by_error):
|
||||
self.logger.info(f"Transcription stopped {stopped_by}, {stopped_by_error}")
|
||||
|
||||
def on_transcription_error(self, message):
|
||||
self.logger.error(f"Transcription error {message}")
|
||||
|
||||
def on_transcription_started(self, status):
|
||||
self.logger.info(f"Transcription started {status}")
|
||||
|
||||
def set_image(self, image: bytes):
|
||||
self.image: bytes | None = image
|
||||
|
||||
def run_camera(self):
|
||||
try:
|
||||
while not self.stop_threads.is_set():
|
||||
if self.image:
|
||||
self.camera.write_frame(self.image)
|
||||
|
||||
time.sleep(1.0 / 8.0) # 8 fps
|
||||
except Exception as e:
|
||||
self.logger.error(f"Exception {e} in camera thread.")
|
||||
|
||||
def handle_user_started_talking(self):
|
||||
# TODO: allow configuration of the timer timeout
|
||||
self.logger.error("user started talking")
|
||||
self.speech_timeout = Timer(1.0, self.utterance_interrupt)
|
||||
|
||||
def handle_user_stopped_talking(self):
|
||||
self.logger.error("user stopped talking, canceling utterance interrupt")
|
||||
if self.speech_timeout:
|
||||
self.speech_timeout.cancel()
|
||||
|
||||
def utterance_interrupt(self):
|
||||
self.logger.error("utterance interrupt")
|
||||
self.is_interrupted.set()
|
||||
|
||||
def handle_transcription_fragment(self, fragment):
|
||||
if not self.can_interrupt:
|
||||
return
|
||||
|
||||
# start generating a new response. We'll do the fast parts of the interrupt
|
||||
# now but wait for the state transition after we've kicked off the prepare
|
||||
# on the new response.
|
||||
if (
|
||||
self.current_response
|
||||
and self.current_response.state < AsyncProcessorState.INTERRUPTED
|
||||
):
|
||||
self.interrupt_time = time.perf_counter()
|
||||
self.is_interrupted.set()
|
||||
self.current_response.interrupt()
|
||||
|
||||
self.message_handler.add_user_message(fragment)
|
||||
|
||||
response_type: type[OrchestratorResponse] | type[LLMResponse] = self.conversation_processors.response or LLMResponse
|
||||
new_response: OrchestratorResponse = response_type(
|
||||
self.services, self.message_handler, self.output_queue
|
||||
)
|
||||
new_response.set_state_callback(
|
||||
AsyncProcessorState.DONE, self.on_response_played
|
||||
)
|
||||
new_response.set_state_callback(
|
||||
AsyncProcessorState.FINALIZED, self.on_response_finished
|
||||
)
|
||||
new_response.prepare()
|
||||
|
||||
self.response_semaphore.acquire()
|
||||
if (
|
||||
self.current_response
|
||||
and self.current_response.state < AsyncProcessorState.INTERRUPTED
|
||||
):
|
||||
self.current_response.wait_for_state_transition(
|
||||
AsyncProcessorState.FINALIZED
|
||||
)
|
||||
|
||||
self.current_response = new_response
|
||||
self.current_response.play()
|
||||
|
||||
self.response_semaphore.release()
|
||||
|
||||
def action(self):
|
||||
self.logger.info("Starting camera thread")
|
||||
self.image: bytes | None = None
|
||||
self.camera_thread = Thread(target=self.run_camera, daemon=True)
|
||||
self.camera_thread.start()
|
||||
|
||||
self.logger.info("Starting frame consumer thread")
|
||||
self.frame_consumer_thread = Thread(target=self.frame_consumer, daemon=True)
|
||||
self.frame_consumer_thread.start()
|
||||
|
||||
self.logger.info("Playing introduction")
|
||||
self.can_interrupt = False
|
||||
self.current_response.play()
|
||||
|
||||
def frame_consumer(self):
    """Drain ``self.output_queue`` and forward frames until END_STREAM arrives.

    Audio frames are buffered and written to ``self.mic`` in multiples of
    ``smallest_write_size`` bytes; image frames are handed to
    ``self.set_image``. While ``self.is_interrupted`` is set, frames are
    pulled off the queue and discarded until a START_STREAM frame clears
    the flag.

    Every frame taken from the queue is marked done with ``task_done()``,
    including the END_STREAM frame that ends the loop, so a caller blocked
    in ``output_queue.join()`` is released.
    """
    self.logger.info("🎬 Starting frame consumer thread")
    pending = bytearray()  # audio bytes received but not yet written to the mic
    smallest_write_size = 3200
    while True:
        try:
            frame: QueueFrame = self.output_queue.get()
            # A None frame has no frame_type to inspect; the not-interrupted
            # branch below treats a falsy frame as a cue to flush the buffer.
            frame_type = None if frame is None else frame.frame_type

            if frame_type == FrameType.END_STREAM:
                self.logger.info("Stopping frame consumer thread")
                # Balance the get() above before leaving the loop.
                self.output_queue.task_done()
                return

            # if interrupted, we just pull frames off the queue and discard them
            if not self.is_interrupted.is_set():
                if frame:
                    if frame_type == FrameType.AUDIO:
                        pending.extend(frame.frame_data)
                        # Write only whole multiples of smallest_write_size;
                        # the remainder stays buffered for the next frame.
                        writable = len(pending) - (len(pending) % smallest_write_size)
                        if writable:
                            self.mic.write_frames(bytes(pending[:writable]))
                            del pending[:writable]
                    elif frame_type == FrameType.IMAGE:
                        self.set_image(frame.frame_data)
                elif pending:
                    self.mic.write_frames(bytes(pending))
                    pending = bytearray()
            else:
                if self.interrupt_time:
                    self.logger.info(f"Lag to stop stream after interruption {time.perf_counter() - self.interrupt_time}")
                    self.interrupt_time = None

                if frame_type == FrameType.START_STREAM:
                    self.is_interrupted.clear()

            self.output_queue.task_done()
        except Empty:
            # NOTE(review): a blocking get() with no timeout does not raise
            # Empty on a standard queue.Queue; kept in case output_queue is
            # a different queue type - confirm.
            try:
                if pending:
                    self.mic.write_frames(bytes(pending))
            except Exception as e:
                self.logger.error(f"Exception in frame_consumer: {e}, {len(pending)}")

            pending = bytearray()
|
||||
@@ -25,23 +25,24 @@ class QueueTee:
|
||||
await queue.put(frame)
|
||||
|
||||
class LLMContextAggregator(AIService):
|
||||
def __init__(self, messages: list[dict], role:str, bot_participant_id=None):
|
||||
def __init__(self, messages: list[dict], role:str, bot_participant_id=None, complete_sentences=True):
|
||||
self.messages = messages
|
||||
self.bot_participant_id = bot_participant_id
|
||||
self.role = role
|
||||
self.sentence = ""
|
||||
self.complete_sentences = complete_sentences
|
||||
|
||||
async def process_frame(self, frame:QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
content: str = ""
|
||||
|
||||
# TODO: split up transcription by participant
|
||||
if isinstance(frame, TextQueueFrame):
|
||||
content = frame.text
|
||||
|
||||
self.sentence += content
|
||||
if self.sentence.endswith((".", "?", "!")):
|
||||
self.messages.append({"role": self.role, "content": self.sentence})
|
||||
self.sentence = ""
|
||||
yield LLMMessagesQueueFrame(self.messages)
|
||||
if self.complete_sentences:
|
||||
self.sentence += frame.text
|
||||
if self.sentence.endswith((".", "?", "!")):
|
||||
self.messages.append({"role": self.role, "content": self.sentence})
|
||||
self.sentence = ""
|
||||
yield LLMMessagesQueueFrame(self.messages)
|
||||
else:
|
||||
self.messages.append({"role": self.role, "content": frame.text})
|
||||
yield LLMMessagesQueueFrame(self.messages)
|
||||
|
||||
yield frame
|
||||
|
||||
@@ -5,10 +5,13 @@ from typing import Any
|
||||
class QueueFrame:
|
||||
pass
|
||||
|
||||
class StartStreamQueueFrame(QueueFrame):
|
||||
class ControlQueueFrame(QueueFrame):
|
||||
pass
|
||||
|
||||
class EndStreamQueueFrame(QueueFrame):
|
||||
class StartStreamQueueFrame(ControlQueueFrame):
|
||||
pass
|
||||
|
||||
class EndStreamQueueFrame(ControlQueueFrame):
|
||||
pass
|
||||
|
||||
@dataclass()
|
||||
|
||||
@@ -7,6 +7,7 @@ import wave
|
||||
|
||||
from dailyai.queue_frame import (
|
||||
AudioQueueFrame,
|
||||
ControlQueueFrame,
|
||||
EndStreamQueueFrame,
|
||||
ImageQueueFrame,
|
||||
LLMMessagesQueueFrame,
|
||||
@@ -28,11 +29,15 @@ class AIService:
|
||||
pass
|
||||
|
||||
async def run_to_queue(self, queue: asyncio.Queue, frames, add_end_of_stream=False) -> None:
|
||||
async for frame in self.run(frames):
|
||||
await queue.put(frame)
|
||||
try:
|
||||
async for frame in self.run(frames):
|
||||
await queue.put(frame)
|
||||
|
||||
if add_end_of_stream:
|
||||
await queue.put(EndStreamQueueFrame())
|
||||
if add_end_of_stream:
|
||||
await queue.put(EndStreamQueueFrame())
|
||||
except Exception as e:
|
||||
print("Exception in run_to_queue", e)
|
||||
raise e
|
||||
|
||||
async def run(
|
||||
self,
|
||||
@@ -67,9 +72,8 @@ class AIService:
|
||||
|
||||
@abstractmethod
|
||||
async def process_frame(self, frame:QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
# This is a trick for the interpreter (and linter) to know that this is a generator.
|
||||
if False:
|
||||
yield QueueFrame()
|
||||
if isinstance(frame, ControlQueueFrame):
|
||||
yield frame
|
||||
|
||||
@abstractmethod
|
||||
async def finalize(self) -> AsyncGenerator[QueueFrame, None]:
|
||||
@@ -87,7 +91,9 @@ class LLMService(AIService):
|
||||
pass
|
||||
|
||||
async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
if isinstance(frame, LLMMessagesQueueFrame):
|
||||
if isinstance(frame, ControlQueueFrame):
|
||||
yield frame
|
||||
elif isinstance(frame, LLMMessagesQueueFrame):
|
||||
async for text_chunk in self.run_llm_async(frame.messages):
|
||||
yield TextQueueFrame(text_chunk)
|
||||
|
||||
@@ -110,9 +116,11 @@ class TTSService(AIService):
|
||||
yield bytes()
|
||||
|
||||
async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
if not isinstance(frame, TextQueueFrame):
|
||||
if isinstance(frame, ControlQueueFrame):
|
||||
yield frame
|
||||
return
|
||||
elif not isinstance(frame, TextQueueFrame):
|
||||
return
|
||||
|
||||
text: str | None = None
|
||||
if not self.aggregate_sentences:
|
||||
@@ -149,6 +157,7 @@ class ImageGenService(AIService):
|
||||
|
||||
async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
if not isinstance(frame, TextQueueFrame):
|
||||
yield frame
|
||||
return
|
||||
|
||||
(url, image_data) = await self.run_image_gen(frame.text)
|
||||
|
||||
@@ -35,10 +35,7 @@ class AzureTTSService(TTSService):
|
||||
"<prosody rate='1.05'>" \
|
||||
f"{sentence}" \
|
||||
"</prosody></mstts:express-as></voice></speak> "
|
||||
try:
|
||||
result = await asyncio.to_thread(self.speech_synthesizer.speak_ssml, (ssml))
|
||||
except Exception as e:
|
||||
self.logger.error("Error in azure tts", e)
|
||||
result = await asyncio.to_thread(self.speech_synthesizer.speak_ssml, (ssml))
|
||||
self.logger.info("Got azure tts result")
|
||||
if result.reason == ResultReason.SynthesizingAudioCompleted:
|
||||
self.logger.info("Returning result")
|
||||
|
||||
@@ -7,6 +7,7 @@ import types
|
||||
|
||||
from functools import partial
|
||||
from queue import Queue, Empty
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from dailyai.queue_frame import (
|
||||
AudioQueueFrame,
|
||||
@@ -14,6 +15,7 @@ from dailyai.queue_frame import (
|
||||
ImageQueueFrame,
|
||||
QueueFrame,
|
||||
StartStreamQueueFrame,
|
||||
TextQueueFrame,
|
||||
TranscriptionQueueFrame,
|
||||
)
|
||||
|
||||
@@ -114,6 +116,7 @@ class DailyTransportService(EventHandler):
|
||||
handler(*args, **kwargs)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Exception in event handler {event_name}: {e}")
|
||||
raise e
|
||||
|
||||
def add_event_handler(self, event_name: str, handler):
|
||||
if not event_name.startswith("on_"):
|
||||
@@ -214,7 +217,6 @@ class DailyTransportService(EventHandler):
|
||||
if self.token and self.start_transcription:
|
||||
self.client.start_transcription(self.transcription_settings)
|
||||
|
||||
|
||||
def _receive_audio(self):
|
||||
"""Receive audio from the Daily call and put it on the receive queue"""
|
||||
seconds = 1
|
||||
@@ -223,9 +225,13 @@ class DailyTransportService(EventHandler):
|
||||
buffer = self.speaker.read_frames(desired_frame_count)
|
||||
if len(buffer) > 0:
|
||||
frame = AudioQueueFrame(buffer)
|
||||
asyncio.run_coroutine_threadsafe(self.receive_queue.put(frame), self.loop)
|
||||
if self.loop:
|
||||
asyncio.run_coroutine_threadsafe(self.receive_queue.put(frame), self.loop)
|
||||
|
||||
async def get_receive_frames(self):
|
||||
def interrupt(self):
|
||||
self.is_interrupted.set()
|
||||
|
||||
async def get_receive_frames(self) -> AsyncGenerator[QueueFrame, None]:
|
||||
while True:
|
||||
frame = await self.receive_queue.get()
|
||||
yield frame
|
||||
@@ -265,6 +271,7 @@ class DailyTransportService(EventHandler):
|
||||
await asyncio.sleep(1)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Exception {e}")
|
||||
raise e
|
||||
finally:
|
||||
self.client.leave()
|
||||
|
||||
@@ -341,6 +348,7 @@ class DailyTransportService(EventHandler):
|
||||
time.sleep(1.0 / 8) # 8 fps
|
||||
except Exception as e:
|
||||
self.logger.error(f"Exception {e} in camera thread.")
|
||||
raise e
|
||||
|
||||
def frame_consumer(self):
|
||||
self.logger.info("🎬 Starting frame consumer thread")
|
||||
@@ -382,11 +390,11 @@ class DailyTransportService(EventHandler):
|
||||
self.mic.write_frames(bytes(b))
|
||||
b = bytearray()
|
||||
else:
|
||||
if self.interrupt_time:
|
||||
self.logger.info(
|
||||
f"Lag to stop stream after interruption {time.perf_counter() - self.interrupt_time}"
|
||||
)
|
||||
self.interrupt_time = None
|
||||
# if there are leftover audio bytes, write them now; failing to do so
|
||||
# can cause static in the audio stream.
|
||||
if len(b):
|
||||
self.mic.write_frames(bytes(b))
|
||||
b = bytearray()
|
||||
|
||||
if isinstance(frame, StartStreamQueueFrame):
|
||||
self.is_interrupted.clear()
|
||||
@@ -398,5 +406,6 @@ class DailyTransportService(EventHandler):
|
||||
self.mic.write_frames(bytes(b))
|
||||
except Exception as e:
|
||||
self.logger.error(f"Exception in frame_consumer: {e}, {len(b)}")
|
||||
raise e
|
||||
|
||||
b = bytearray()
|
||||
|
||||
@@ -1,180 +0,0 @@
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from queue import Queue, Empty
|
||||
from threading import Thread, Event
|
||||
from typing import Generator
|
||||
|
||||
from dailyai.async_processor.async_processor import (
|
||||
AsyncProcessor,
|
||||
AsyncProcessorState,
|
||||
LLMResponse,
|
||||
)
|
||||
from dailyai.message_handler.message_handler import MessageHandler
|
||||
from dailyai.queue_frame import QueueFrame, FrameType
|
||||
from dailyai.services.ai_services import (
|
||||
AIServiceConfig,
|
||||
ImageGenService,
|
||||
LLMService,
|
||||
TTSService,
|
||||
)
|
||||
"""
|
||||
class MockTTSService(TTSService):
|
||||
def run_tts(self, sentence):
|
||||
for word in sentence.split(' '):
|
||||
time.sleep(0.1)
|
||||
yield bytes(word, "utf-8")
|
||||
|
||||
class MockLLMService(LLMService):
|
||||
def run_llm_async(self, messages) -> Generator[str, None, None]:
|
||||
for i in ["Hello ", "there.", "How are ", "you?", "I ", "hope ", "you ", "are ", "well."]:
|
||||
time.sleep(0.1)
|
||||
yield i
|
||||
|
||||
class MockImageService(ImageGenService):
|
||||
def run_image_gen(self, sentence) -> None:
|
||||
return None
|
||||
|
||||
class TestResponse(unittest.TestCase):
|
||||
def test_base_state_transitions(self):
|
||||
mock_tts_service = MockTTSService()
|
||||
mock_llm_service = MockLLMService()
|
||||
mock_image_service = MockImageService()
|
||||
processor = AsyncProcessor(AIServiceConfig(tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service))
|
||||
processor.prepare()
|
||||
processor.play()
|
||||
processor.finalize()
|
||||
self.assertEqual(processor.state, AsyncProcessorState.FINALIZED)
|
||||
|
||||
def test_state_transitions(self):
|
||||
output_queue = Queue()
|
||||
mock_tts_service = MockTTSService()
|
||||
mock_llm_service = MockLLMService()
|
||||
mock_image_service = MockImageService()
|
||||
message_handler = MessageHandler("Hello World")
|
||||
processor = LLMResponse(
|
||||
AIServiceConfig(
|
||||
tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service
|
||||
),
|
||||
message_handler,
|
||||
output_queue,
|
||||
)
|
||||
processor.prepare()
|
||||
processor.play()
|
||||
|
||||
# Consume the output from the output queue. It's necessary to mark these tasks as done for the
|
||||
# play function to return.
|
||||
expected_words = ["Hello", "there.", "How", "are", "you?", "I", "hope", "you", "are", "well."]
|
||||
|
||||
# remove the "start_stream" message from the queue
|
||||
output_queue.get()
|
||||
output_queue.task_done()
|
||||
|
||||
while expected_words:
|
||||
actual_word:QueueFrame = output_queue.get()
|
||||
word = expected_words.pop(0)
|
||||
self.assertEqual(actual_word.frame_type, FrameType.AUDIO_FRAME)
|
||||
self.assertEqual(actual_word.frame_data, bytes(word, "utf-8"))
|
||||
output_queue.task_done()
|
||||
|
||||
processor.finalize()
|
||||
|
||||
self.assertEqual(processor.state, AsyncProcessorState.FINALIZED)
|
||||
|
||||
def test_interrupt_preparation(self):
|
||||
output_queue = Queue()
|
||||
mock_tts_service = MockTTSService()
|
||||
mock_llm_service = MockLLMService()
|
||||
mock_image_service = MockImageService()
|
||||
message_handler = MessageHandler("System Message")
|
||||
processor = LLMResponse(
|
||||
AIServiceConfig(
|
||||
tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service
|
||||
),
|
||||
message_handler,
|
||||
output_queue,
|
||||
)
|
||||
processor.prepare()
|
||||
interrupt_request_at = time.perf_counter()
|
||||
processor.interrupt()
|
||||
processor.finalize()
|
||||
finalized_at = time.perf_counter()
|
||||
self.assertTrue(0.1 < finalized_at - interrupt_request_at < 0.2)
|
||||
print(f"delta: {interrupt_request_at, finalized_at}")
|
||||
self.assertEqual(processor.state, AsyncProcessorState.FINALIZED)
|
||||
|
||||
def test_interrupt_play(self):
|
||||
output_queue = Queue()
|
||||
mock_tts_service = MockTTSService()
|
||||
mock_llm_service = MockLLMService()
|
||||
mock_image_service = MockImageService()
|
||||
message_handler = MessageHandler("System Message")
|
||||
processor = LLMResponse(
|
||||
AIServiceConfig(
|
||||
tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service
|
||||
),
|
||||
message_handler,
|
||||
output_queue,
|
||||
)
|
||||
processor.prepare()
|
||||
processor.play()
|
||||
|
||||
stop_processing_output_queue = Event()
|
||||
def process_output_queue_async():
|
||||
# Consume the output from the output queue. It's necessary to mark these tasks as done for the
|
||||
# play function to return.
|
||||
time.sleep(0.1)
|
||||
expected_words = ["Hello", "there.", "How", "are", "you?", "I", "hope", "you", "are", "well."]
|
||||
while expected_words and not stop_processing_output_queue.is_set():
|
||||
try:
|
||||
actual_word:QueueFrame = output_queue.get_nowait()
|
||||
if actual_word.frame_type == FrameType.AUDIO_FRAME:
|
||||
time.sleep(0.1)
|
||||
word = expected_words.pop(0)
|
||||
self.assertEqual(actual_word.frame_type, FrameType.AUDIO_FRAME)
|
||||
self.assertEqual(actual_word.frame_data, bytes(word, "utf-8"))
|
||||
output_queue.task_done()
|
||||
except Empty:
|
||||
pass
|
||||
|
||||
process_output_queue = Thread(target=process_output_queue_async, daemon=True)
|
||||
process_output_queue.start()
|
||||
|
||||
time.sleep(0.5)
|
||||
processor.interrupt()
|
||||
|
||||
stop_processing_output_queue.set()
|
||||
process_output_queue.join()
|
||||
|
||||
processor.finalize()
|
||||
self.assertEqual(processor.state, AsyncProcessorState.FINALIZED)
|
||||
|
||||
def test_statechange_callback(self):
|
||||
mock_tts_service = MockTTSService()
|
||||
mock_llm_service = MockLLMService()
|
||||
mock_image_service = MockImageService()
|
||||
processor = AsyncProcessor(
|
||||
AIServiceConfig(
|
||||
tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service
|
||||
)
|
||||
)
|
||||
is_finalized = False
|
||||
def set_is_finalized(async_processor:AsyncProcessor):
|
||||
nonlocal is_finalized
|
||||
is_finalized = True
|
||||
|
||||
processor.set_state_callback(
|
||||
AsyncProcessorState.FINALIZED, set_is_finalized
|
||||
)
|
||||
processor.prepare()
|
||||
self.assertFalse(is_finalized)
|
||||
processor.play()
|
||||
self.assertFalse(is_finalized)
|
||||
processor.finalize()
|
||||
self.assertTrue(is_finalized)
|
||||
self.assertEqual(processor.state, AsyncProcessorState.FINALIZED)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
"""
|
||||
@@ -1,147 +0,0 @@
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from unittest.mock import MagicMock, call
|
||||
|
||||
from dailyai.message_handler.message_handler import MessageHandler, IndexingMessageHandler
|
||||
from dailyai.services.ai_services import (
|
||||
AIServiceConfig,
|
||||
TTSService,
|
||||
LLMService,
|
||||
ImageGenService,
|
||||
)
|
||||
from ..storage.search import SearchIndexer
|
||||
|
||||
|
||||
class TestMessageHandler(unittest.TestCase):
    """Checks the message list that MessageHandler.get_llm_messages() returns."""

    def test_simple_intro(self):
        # A freshly built handler reports only the system prompt.
        handler = MessageHandler("Hello world")
        expected = [{"role": "system", "content": "Hello world"}]
        self.assertEqual(handler.get_llm_messages(), expected)

    def test_simple_user_message(self):
        # One user message is listed after the system prompt.
        handler = MessageHandler("System prompt")
        handler.add_user_message("User message")
        expected = [
            {"role": "system", "content": "System prompt"},
            {"role": "user", "content": "User message"},
        ]
        self.assertEqual(handler.get_llm_messages(), expected)

    def test_simple_user_and_assistant_message(self):
        # User then assistant messages are listed in the order they were added.
        handler = MessageHandler("System prompt")
        handler.add_user_message("User message")
        handler.add_assistant_message("Assistant message")
        expected = [
            {"role": "system", "content": "System prompt"},
            {"role": "user", "content": "User message"},
            {"role": "assistant", "content": "Assistant message"},
        ]
        self.assertEqual(handler.get_llm_messages(), expected)

    def test_user_message_overwrite(self):
        # Without finalize_user_message(), a second user message is joined
        # onto the first and the assistant message in between is not listed.
        handler = MessageHandler("System prompt")
        handler.add_user_message("User message")
        handler.add_assistant_message("Assistant message")
        handler.add_user_message("plus something else")
        expected = [
            {"role": "system", "content": "System prompt"},
            {"role": "user", "content": "User message plus something else"},
        ]
        self.assertEqual(handler.get_llm_messages(), expected)

    def test_user_message_after_assistant(self):
        # After finalize_user_message(), the next user message becomes a
        # separate entry and the earlier exchange is kept.
        handler = MessageHandler("System prompt")
        handler.add_user_message("User message")
        handler.add_assistant_message("Assistant message")
        handler.finalize_user_message()
        handler.add_user_message("other user message")
        expected = [
            {"role": "system", "content": "System prompt"},
            {"role": "user", "content": "User message"},
            {"role": "assistant", "content": "Assistant message"},
            {"role": "user", "content": "other user message"},
        ]
        self.assertEqual(handler.get_llm_messages(), expected)
|
||||
|
||||
|
||||
class MockTTSService(TTSService):
    """TTS stand-in that yields each space-separated word as UTF-8 bytes."""

    def run_tts(self, sentence):
        words = sentence.split(" ")
        for token in words:
            # Pause before every chunk, as the original mock does.
            time.sleep(0.1)
            yield token.encode("utf-8")
|
||||
|
||||
|
||||
class MockLLMService(LLMService):
    """LLM stand-in that answers every request with the same fixed text."""

    def run_llm(self, messages) -> str:
        canned_reply = "Parsed user message."
        return canned_reply
|
||||
|
||||
class MockImageService(ImageGenService):
    """Image-generation stand-in that produces nothing."""

    def run_image_gen(self, sentence) -> None:
        # No image is generated; callers always receive None.
        return
|
||||
|
||||
|
||||
class TestStorageMessageHandler(unittest.TestCase):
    """Checks which texts IndexingMessageHandler passes to its indexer."""

    def test_user_message_finalized(self):
        services = AIServiceConfig(
            tts=MockTTSService(),
            llm=MockLLMService(),
            image=MockImageService(),
        )
        indexer = MagicMock(spec=SearchIndexer)
        handler = IndexingMessageHandler("Hello world", services, indexer)

        # Replace the cleanup step so its input can be asserted on and its
        # output is fixed.
        handler.cleanup_user_message = MagicMock(return_value="Parsed user message.")

        handler.add_user_message("User message")
        handler.add_assistant_message("Assistant message will be ignored")
        handler.add_user_message("plus something else")
        handler.finalize_user_message()
        handler.add_assistant_message("New assistant message will not be ignored")
        handler.add_user_message("User message second time")
        handler.add_assistant_message("Assistant message second time")
        handler.write_messages_to_storage()

        # NOTE(review): presumably the indexing happens off-thread and the
        # pause gives it time to finish - confirm against the handler.
        time.sleep(0.5)
        handler.cleanup_user_message.assert_called_with("User message plus something else")
        first_round = [
            call.index_text('"Parsed user message."'),
            call.index_text("New assistant message will not be ignored"),
        ]
        self.assertEqual(indexer.mock_calls, first_round)

        indexer.reset_mock()

        handler.finalize_user_message()

        time.sleep(0.5)

        second_round = [
            call.index_text('"Parsed user message."'),
            call.index_text("Assistant message second time"),
        ]
        self.assertEqual(indexer.mock_calls, second_round)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
98
src/samples/foundational/07-interruptible.py
Normal file
98
src/samples/foundational/07-interruptible.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import requests
|
||||
import time
|
||||
import urllib.parse
|
||||
from dailyai.conversation_wrappers import InterruptibleConversationWrapper
|
||||
|
||||
from dailyai.queue_frame import StartStreamQueueFrame, TextQueueFrame
|
||||
from dailyai.services.daily_transport_service import DailyTransportService
|
||||
from dailyai.services.azure_ai_services import AzureLLMService
|
||||
from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
|
||||
|
||||
|
||||
async def main(room_url:str, token):
    """Join the Daily room and run an interruptible LLM-to-speech conversation.

    Builds the transport, LLM and TTS services (published as module globals),
    registers a greeting for the first other participant, and then runs the
    transport alongside an InterruptibleConversationWrapper until both finish.
    """
    global transport, llm, tts

    # NOTE(review): the meaning of the trailing 5 is defined by
    # DailyTransportService, which is not visible here - confirm.
    transport = DailyTransportService(room_url, token, "Respond bot", 5)
    transport.mic_enabled = True
    transport.mic_sample_rate = 16000
    transport.camera_enabled = False

    llm = AzureLLMService()
    tts = ElevenLabsTTSService(voice_id="ErXwobaYiN019PkySvjV")

    async def run_response(user_speech, tma_in, tma_out):
        # Chain: input frames -> tma_in -> LLM -> tma_out -> TTS -> send queue.
        input_frames = [StartStreamQueueFrame(), TextQueueFrame(user_speech)]
        aggregated_input = tma_in.run(input_frames)
        llm_output = llm.run(aggregated_input)
        aggregated_output = tma_out.run(llm_output)
        await tts.run_to_queue(transport.send_queue, aggregated_output)

    @transport.event_handler("on_first_other_participant_joined")
    async def on_first_other_participant_joined(transport):
        await tts.say("Hi, I'm listening!", transport.send_queue)

    async def run_conversation():
        system_prompt = {
            "role": "system",
            "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio. Respond to what the user said in a creative and helpful way.",
        }
        messages = [system_prompt]

        wrapper = InterruptibleConversationWrapper(
            frame_generator=transport.get_receive_frames,
            runner=run_response,
            interrupt=transport.interrupt,
            my_participant_id=transport.my_participant_id,
            llm_messages=messages,
        )
        await wrapper.run_conversation()

    transport.transcription_settings["extra"]["punctuate"] = False
    await asyncio.gather(transport.run(), run_conversation())
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: parse the room URL and API key, request a meeting
    # token from the Daily REST API, then run the bot in that room.
    parser = argparse.ArgumentParser(description="Simple Daily Bot Sample")
    parser.add_argument(
        "-u", "--url", type=str, required=True, help="URL of the Daily room to join"
    )
    parser.add_argument(
        "-k",
        "--apikey",
        type=str,
        required=True,
        help="Daily API Key (needed to create token)",
    )

    # parse_known_args tolerates extra command-line arguments instead of
    # exiting with an error.
    args, unknown = parser.parse_known_args()

    # Create a meeting token for the given room with an expiration 1 hour in the future.
    room_name: str = urllib.parse.urlparse(args.url).path[1:]
    expiration: float = time.time() + 60 * 60

    res: requests.Response = requests.post(
        "https://api.daily.co/v1/meeting-tokens",
        headers={"Authorization": f"Bearer {args.apikey}"},
        json={
            "properties": {"room_name": room_name, "is_owner": True, "exp": expiration}
        },
        # requests waits indefinitely by default; fail instead of hanging
        # if the API does not respond.
        timeout=30,
    )

    if res.status_code != 200:
        raise Exception(f"Failed to create meeting token: {res.status_code} {res.text}")

    token: str = res.json()["token"]

    asyncio.run(main(args.url, token))
|
||||
Reference in New Issue
Block a user