Compare commits
8 Commits
cb/09-bots
...
cleanup
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
646db8b9bd | ||
|
|
42c142aff0 | ||
|
|
6da78dbf9c | ||
|
|
f0d9b0613e | ||
|
|
a661905d7f | ||
|
|
c9c2e5f561 | ||
|
|
795a339542 | ||
|
|
31db156dfc |
@@ -16,7 +16,8 @@ dependencies = [
|
||||
"pyht",
|
||||
"opentelemetry-sdk",
|
||||
"aiohttp",
|
||||
"fal"
|
||||
"fal",
|
||||
"faster_whisper"
|
||||
]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
autopep8==2.0.4
|
||||
build==1.0.3
|
||||
packaging==23.2
|
||||
pyproject_hooks==1.0.0
|
||||
pyproject_hooks==1.0.0
|
||||
@@ -1,347 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from queue import Queue, PriorityQueue, Empty
|
||||
from threading import Event, Semaphore, Thread
|
||||
from typing import Any, Generator, Iterator, Optional, Type
|
||||
|
||||
from dailyai.queue_frame import QueueFrame, FrameType
|
||||
from dailyai.message_handler.message_handler import MessageHandler
|
||||
from dailyai.services.ai_services import AIServiceConfig
|
||||
|
||||
class AsyncProcessorState:
|
||||
# Setting class variables, other synchronous activities
|
||||
INIT = 0
|
||||
|
||||
# Making asynchronous requests to LLM and other services to render response
|
||||
PREPARING = 1
|
||||
|
||||
# Ready to start presenting to user (but may not have all data yet)
|
||||
READY = 2
|
||||
|
||||
# Playing response
|
||||
PLAYING = 3
|
||||
|
||||
# An interrupt has been requested and the response is shutting down in-flight processing
|
||||
INTERRUPTING = 4
|
||||
|
||||
# An interrupt has been requested and the response is finished stopping in-flight processing
|
||||
INTERRUPTED = 5
|
||||
|
||||
# Response has been played or interrupted
|
||||
DONE = 6
|
||||
|
||||
# Response is being finalized (updating records of speech, updating LLM context, etc.)
|
||||
FINALIZING = 7
|
||||
|
||||
# Response is complete. This could mean that everything is updated, or that the response
|
||||
# was interrupted.
|
||||
FINALIZED = 8
|
||||
|
||||
state_transitions = {
|
||||
INIT: [PREPARING, INTERRUPTING],
|
||||
PREPARING: [READY, INTERRUPTING],
|
||||
READY: [PLAYING, INTERRUPTING],
|
||||
PLAYING: [DONE, INTERRUPTING],
|
||||
INTERRUPTING: [INTERRUPTED],
|
||||
INTERRUPTED: [DONE],
|
||||
DONE: [FINALIZING],
|
||||
FINALIZING: [FINALIZED],
|
||||
FINALIZED: [FINALIZED],
|
||||
}
|
||||
|
||||
|
||||
@dataclass(order=True)
|
||||
class StateTransitionItem:
|
||||
state: int
|
||||
evt: Event = field(compare=False)
|
||||
|
||||
class AsyncProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
services: AIServiceConfig
|
||||
) -> None:
|
||||
self.state = AsyncProcessorState.INIT
|
||||
self.prepare_thread = None
|
||||
self.play_thread = None
|
||||
self.finalize_thread = None
|
||||
|
||||
self.services: AIServiceConfig = services
|
||||
|
||||
self.state_transition_semaphore = Semaphore()
|
||||
self.waiting_for_state_changes = PriorityQueue()
|
||||
self.state_queue = Queue()
|
||||
|
||||
self.state_change_callbacks = defaultdict(list)
|
||||
|
||||
self.was_interrupted = False
|
||||
|
||||
self.logger: logging.Logger = logging.getLogger("dailyai")
|
||||
|
||||
def set_state(self, state: int) -> None:
|
||||
if state in AsyncProcessorState.state_transitions[self.state]:
|
||||
self.state_transition_semaphore.acquire()
|
||||
|
||||
self.state: int = state
|
||||
self.state_transition_semaphore.release()
|
||||
|
||||
# wake up any threads waiting for this state transition
|
||||
try:
|
||||
while True:
|
||||
waiter = self.waiting_for_state_changes.get_nowait()
|
||||
if waiter.state <= state:
|
||||
waiter.evt.set()
|
||||
else:
|
||||
self.waiting_for_state_changes.put(waiter)
|
||||
break
|
||||
except Empty:
|
||||
pass
|
||||
|
||||
# make all the callbacks for this state
|
||||
for callback in self.state_change_callbacks[state]:
|
||||
callback(self)
|
||||
else:
|
||||
self.logger.error(
|
||||
f"Invalid state transition from {self.state} to {state} in {self.__class__.__name__}"
|
||||
)
|
||||
raise Exception(f"Invalid state transition from {self.state} to {state}")
|
||||
|
||||
#
|
||||
# This is used for state transitions that could be blocked by an interruption.
|
||||
# If we are interrupted, we silently fail this call. Use only if you know that
|
||||
# this state transition should fail if the processor has been interrupted.
|
||||
#
|
||||
|
||||
def maybe_set_state(self, state: int) -> bool:
|
||||
if state in AsyncProcessorState.state_transitions[self.state]:
|
||||
self.set_state(state)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def wait_for_state_transition(self, state: int) -> None:
|
||||
if self.state >= state:
|
||||
return
|
||||
|
||||
self.state_transition_semaphore.acquire()
|
||||
|
||||
evt = Event()
|
||||
self.waiting_for_state_changes.put(StateTransitionItem(state, evt))
|
||||
self.state_transition_semaphore.release()
|
||||
result = evt.wait(120.0)
|
||||
if not result:
|
||||
self.logger.error(
|
||||
f"Timed out waiting for state transition to {state} from {self.state}"
|
||||
)
|
||||
|
||||
def set_state_callback(self, state: int, callback: callable) -> None:
|
||||
self.state_change_callbacks[state].append(callback)
|
||||
|
||||
def prepare(self) -> None:
|
||||
self.prepare_thread = Thread(target=self.async_prepare, daemon=True)
|
||||
self.prepare_thread.start()
|
||||
self.wait_for_state_transition(AsyncProcessorState.READY)
|
||||
|
||||
def play(self) -> None:
|
||||
self.wait_for_state_transition(AsyncProcessorState.READY)
|
||||
self.play_thread = Thread(target=self.async_play, daemon=True)
|
||||
self.play_thread.start()
|
||||
self.wait_for_state_transition(AsyncProcessorState.PLAYING)
|
||||
|
||||
def finalize(self) -> None:
|
||||
# don't finalize until we're done playing.
|
||||
self.wait_for_state_transition(AsyncProcessorState.DONE)
|
||||
self.set_state(AsyncProcessorState.FINALIZING)
|
||||
self.do_finalization()
|
||||
self.set_state(AsyncProcessorState.FINALIZED)
|
||||
|
||||
def interrupt(self) -> None:
|
||||
# nothing to interrupt if we're already finalizing or finalized, no-op
|
||||
if self.state in [
|
||||
AsyncProcessorState.FINALIZING,
|
||||
AsyncProcessorState.FINALIZED,
|
||||
]:
|
||||
return
|
||||
|
||||
self.set_state(AsyncProcessorState.INTERRUPTING)
|
||||
self.was_interrupted = True
|
||||
self.do_interruption()
|
||||
self.set_state(AsyncProcessorState.INTERRUPTED)
|
||||
self.set_state(AsyncProcessorState.DONE)
|
||||
|
||||
def async_play(self) -> None:
|
||||
self.logger.info(f"Starting to play")
|
||||
if self.maybe_set_state(AsyncProcessorState.PLAYING):
|
||||
self.do_play()
|
||||
self.maybe_set_state(AsyncProcessorState.DONE)
|
||||
|
||||
def async_prepare(self) -> None:
|
||||
self.set_state(AsyncProcessorState.PREPARING)
|
||||
self.start_preparation()
|
||||
self.set_state(AsyncProcessorState.READY)
|
||||
self.continue_preparation()
|
||||
self.logger.info(f"Preparation done for {self.__class__.__name__}")
|
||||
self.preparation_done()
|
||||
|
||||
def start_preparation(self) -> None:
|
||||
pass
|
||||
|
||||
def continue_preparation(self) -> None:
|
||||
pass
|
||||
|
||||
def preparation_done(self):
|
||||
pass
|
||||
|
||||
def get_preparation_iterator(self) -> Iterator:
|
||||
yield None
|
||||
|
||||
def process_chunk(self, chunk) -> None:
|
||||
pass
|
||||
|
||||
def do_interruption(self) -> None:
|
||||
pass
|
||||
|
||||
def do_play(self) -> None:
|
||||
pass
|
||||
|
||||
def do_finalization(self) -> None:
|
||||
pass
|
||||
|
||||
# A common class for responses that use a message queue and
|
||||
# an output queue.
|
||||
|
||||
class OrchestratorResponse(AsyncProcessor):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
services,
|
||||
message_handler,
|
||||
output_queue,
|
||||
) -> None:
|
||||
super().__init__(services)
|
||||
|
||||
self.message_handler: MessageHandler = message_handler
|
||||
self.output_queue: Queue = output_queue
|
||||
|
||||
|
||||
class LLMResponse(OrchestratorResponse):
|
||||
def __init__(
|
||||
self,
|
||||
services,
|
||||
message_handler,
|
||||
output_queue,
|
||||
) -> None:
|
||||
super().__init__(services, message_handler, output_queue)
|
||||
|
||||
self.has_sent_first_frame = False
|
||||
|
||||
self.chunks_in_preparation = Queue()
|
||||
|
||||
self.llm_responses: list[str] = []
|
||||
|
||||
def get_preparation_iterator(self) -> Iterator:
|
||||
messages_for_llm = self.message_handler.get_llm_messages()
|
||||
self.logger.debug(f"Messages for llm: {json.dumps(messages_for_llm, indent=2)}")
|
||||
return self.clauses_from_chunks(
|
||||
self.services.llm.run_llm_async(messages_for_llm)
|
||||
)
|
||||
|
||||
def clauses_from_chunks(self, chunks) -> Iterator:
|
||||
out = ""
|
||||
for chunk in chunks:
|
||||
if self.state not in [
|
||||
AsyncProcessorState.READY,
|
||||
AsyncProcessorState.PLAYING,
|
||||
]:
|
||||
break
|
||||
|
||||
out += chunk
|
||||
|
||||
if re.match(r"^.*[.!?]$", out): # it looks like a sentence
|
||||
yield out.strip()
|
||||
out = ""
|
||||
|
||||
if out.strip():
|
||||
yield out.strip()
|
||||
|
||||
def get_frames_from_tts_response(self, audio_frame) -> list[QueueFrame]:
|
||||
return [QueueFrame(FrameType.AUDIO, audio_frame)]
|
||||
|
||||
def get_frames_from_chunk(self, chunk) -> Generator[list[QueueFrame], Any, None]:
|
||||
for audio_frame in self.services.tts.run_tts(chunk):
|
||||
yield self.get_frames_from_tts_response(audio_frame)
|
||||
|
||||
def start_preparation(self) -> None:
|
||||
self.preparation_iterator = self.get_preparation_iterator()
|
||||
|
||||
def continue_preparation(self) -> None:
|
||||
for chunk in self.preparation_iterator:
|
||||
if self.state not in [
|
||||
AsyncProcessorState.READY,
|
||||
AsyncProcessorState.PLAYING,
|
||||
]:
|
||||
break
|
||||
|
||||
self.process_chunk(chunk)
|
||||
|
||||
def process_chunk(self, chunk) -> None:
|
||||
self.chunks_in_preparation.put((chunk, self.get_frames_from_chunk(chunk)))
|
||||
|
||||
def preparation_done(self):
|
||||
self.chunks_in_preparation.put((None, None))
|
||||
|
||||
def do_play(self) -> None:
|
||||
while True:
|
||||
if self.state not in [
|
||||
AsyncProcessorState.READY,
|
||||
AsyncProcessorState.PLAYING,
|
||||
]:
|
||||
break
|
||||
prepared_chunk = self.chunks_in_preparation.get()
|
||||
if prepared_chunk[0] == None:
|
||||
return
|
||||
|
||||
self.play_prepared_chunk(prepared_chunk)
|
||||
|
||||
def play_prepared_chunk(self, prepared_chunk) -> None:
|
||||
chunk, tts_generator = prepared_chunk
|
||||
for frames in tts_generator:
|
||||
if self.state not in [
|
||||
AsyncProcessorState.READY,
|
||||
AsyncProcessorState.PLAYING,
|
||||
]:
|
||||
break
|
||||
|
||||
if not self.has_sent_first_frame:
|
||||
self.output_queue.put(QueueFrame(FrameType.START_STREAM, None))
|
||||
self.has_sent_first_frame = True
|
||||
|
||||
for frame in frames:
|
||||
self.output_queue.put(frame)
|
||||
|
||||
self.output_queue.join()
|
||||
self.llm_responses.append(chunk)
|
||||
|
||||
def do_finalization(self) -> None:
|
||||
self.message_handler.add_assistant_messages(self.llm_responses)
|
||||
|
||||
def do_interruption(self) -> None:
|
||||
self.chunks_in_preparation.put((None, None))
|
||||
|
||||
if self.prepare_thread and self.prepare_thread.is_alive():
|
||||
self.prepare_thread.join()
|
||||
|
||||
if self.play_thread and self.play_thread.is_alive():
|
||||
self.play_thread.join()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConversationProcessorCollection:
|
||||
introduction: Optional[Type[OrchestratorResponse]] = None
|
||||
waiting: Optional[Type[OrchestratorResponse]] = None
|
||||
response: Optional[Type[OrchestratorResponse]] = None
|
||||
goodbye: Optional[Type[OrchestratorResponse]] = None
|
||||
76
src/dailyai/conversation_wrappers.py
Normal file
76
src/dailyai/conversation_wrappers.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import functools
|
||||
from typing import AsyncGenerator, Awaitable, Callable
|
||||
from dailyai.queue_aggregators import LLMContextAggregator
|
||||
from dailyai.queue_frame import EndStreamQueueFrame, QueueFrame, TranscriptionQueueFrame
|
||||
|
||||
class InterruptibleConversationWrapper:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
frame_generator: Callable[[], AsyncGenerator[QueueFrame, None]],
|
||||
runner: Callable[
|
||||
[str, LLMContextAggregator, LLMContextAggregator], Awaitable[None]
|
||||
],
|
||||
interrupt: Callable[[], None],
|
||||
my_participant_id: str|None,
|
||||
llm_messages: list[dict[str, str]],
|
||||
llm_context_aggregator_in=LLMContextAggregator,
|
||||
llm_context_aggregator_out=LLMContextAggregator,
|
||||
delay_before_speech_seconds: float = 1.0,
|
||||
):
|
||||
self._frame_generator: Callable[[], AsyncGenerator[QueueFrame, None]] = frame_generator
|
||||
self._runner: Callable[
|
||||
[str, LLMContextAggregator, LLMContextAggregator], Awaitable[None]
|
||||
] = runner
|
||||
self._interrupt: Callable[[], None] = interrupt
|
||||
self._my_participant_id = my_participant_id
|
||||
self._messages: list[dict[str, str]] = llm_messages
|
||||
self._delay_before_speech_seconds = delay_before_speech_seconds
|
||||
self._llm_context_aggregator_in = llm_context_aggregator_in
|
||||
self._llm_context_aggregator_out = llm_context_aggregator_out
|
||||
|
||||
self._current_phrase = ""
|
||||
|
||||
def update_messages(self, new_messages: list[dict[str, str]], task: asyncio.Task | None):
|
||||
if task:
|
||||
if not task.cancelled():
|
||||
self._current_phrase = ""
|
||||
self._messages = new_messages
|
||||
|
||||
async def speak_after_delay(self, user_speech, messages):
|
||||
await asyncio.sleep(self._delay_before_speech_seconds)
|
||||
tma_in = self._llm_context_aggregator_in(
|
||||
messages, "user", self._my_participant_id, False
|
||||
)
|
||||
tma_out = self._llm_context_aggregator_out(
|
||||
messages, "assistant", self._my_participant_id
|
||||
)
|
||||
|
||||
await self._runner(user_speech, tma_in, tma_out)
|
||||
|
||||
async def run_conversation(self):
|
||||
current_response_task = None
|
||||
|
||||
async for frame in self._frame_generator():
|
||||
if isinstance(frame, EndStreamQueueFrame):
|
||||
break
|
||||
elif not isinstance(frame, TranscriptionQueueFrame):
|
||||
continue
|
||||
|
||||
if frame.participantId == self._my_participant_id:
|
||||
continue
|
||||
|
||||
if current_response_task:
|
||||
current_response_task.cancel()
|
||||
self._interrupt()
|
||||
|
||||
self._current_phrase += " " + frame.text
|
||||
current_llm_messages = copy.deepcopy(self._messages)
|
||||
current_response_task = asyncio.create_task(
|
||||
self.speak_after_delay(self._current_phrase, current_llm_messages)
|
||||
)
|
||||
current_response_task.add_done_callback(
|
||||
functools.partial(self.update_messages, current_llm_messages)
|
||||
)
|
||||
@@ -1,127 +0,0 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
from dataclasses import dataclass
|
||||
from queue import Queue, Empty
|
||||
from threading import Thread
|
||||
|
||||
from dailyai.storage.search import SearchIndexer
|
||||
from dailyai.services.ai_services import AIServiceConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
type: str
|
||||
timestamp: float
|
||||
message: str
|
||||
|
||||
|
||||
class MessageHandler:
|
||||
def __init__(self, intro):
|
||||
self.messages: list[Message] = [Message("system", time.time(), intro)]
|
||||
self.last_user_message_idx:int | None = None
|
||||
self.finalized_user_message_idx: int | None = None
|
||||
|
||||
def add_user_message(self, message) -> None:
|
||||
if self.last_user_message_idx is not None and self.last_user_message_idx != self.finalized_user_message_idx:
|
||||
previous_message: str = self.messages[self.last_user_message_idx].message
|
||||
self.messages[self.last_user_message_idx] = Message(
|
||||
"user", time.time(), ' '.join([previous_message, message])
|
||||
)
|
||||
self.messages = self.messages[: self.last_user_message_idx + 1]
|
||||
else:
|
||||
self.messages.append(Message("user", time.time(), message))
|
||||
|
||||
self.last_user_message_idx = len(self.messages) - 1
|
||||
|
||||
def add_assistant_message(self, message) -> None:
|
||||
if self.messages[-1].type == "assistant":
|
||||
self.messages[-1].message += " " + message
|
||||
else:
|
||||
self.messages.append(Message("assistant", time.time(), message))
|
||||
|
||||
def add_assistant_messages(self, messages) -> None:
|
||||
self.messages.append(Message("assistant", time.time(), " ".join(messages)))
|
||||
|
||||
def get_llm_messages(self) -> list[dict[str, str]]:
|
||||
return [{"role": m.type, "content": m.message} for m in self.messages]
|
||||
|
||||
def finalize_user_message(self) -> None:
|
||||
self.finalized_user_message_idx = self.last_user_message_idx
|
||||
|
||||
def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
class IndexingMessageHandler(MessageHandler):
|
||||
def __init__(
|
||||
self, intro, services: AIServiceConfig, indexer: SearchIndexer
|
||||
) -> None:
|
||||
super().__init__(intro)
|
||||
self.services = services
|
||||
|
||||
self.search_indexer = indexer
|
||||
|
||||
self.last_written_idx = 0
|
||||
self.storage_message_queue = Queue()
|
||||
|
||||
self.index_writer_thread = Thread(target=self.storage_writer, daemon=True)
|
||||
self.index_writer_thread.start()
|
||||
|
||||
self.logger = logging.getLogger("dailyai")
|
||||
|
||||
def shutdown(self):
|
||||
self.finalize_user_message()
|
||||
self.storage_message_queue.put(None)
|
||||
self.index_writer_thread.join()
|
||||
|
||||
def storage_writer(self) -> None:
|
||||
while True:
|
||||
try:
|
||||
message_idx = self.storage_message_queue.get()
|
||||
self.storage_message_queue.task_done()
|
||||
|
||||
if message_idx is None:
|
||||
return
|
||||
|
||||
if message_idx <= self.last_written_idx:
|
||||
continue
|
||||
|
||||
self.last_written_idx = message_idx
|
||||
|
||||
message = self.messages[message_idx]
|
||||
content = message.message
|
||||
if message.type == "user":
|
||||
content = self.cleanup_user_message(content)
|
||||
|
||||
# sometimes the LLM returns a string wrapped in quotes and sometimes it doesn't.
|
||||
# if it didn't, wrap it in quotes
|
||||
if content[0] != '"':
|
||||
content = '"' + content + '"'
|
||||
|
||||
self.search_indexer.index_text(content)
|
||||
except Empty:
|
||||
pass
|
||||
|
||||
def cleanup_user_message(self, user_message) -> str:
|
||||
return user_message
|
||||
|
||||
def finalize_user_message(self):
|
||||
super().finalize_user_message()
|
||||
self.write_messages_to_storage()
|
||||
|
||||
def write_messages_to_storage(self):
|
||||
if self.finalized_user_message_idx is None:
|
||||
return
|
||||
|
||||
for idx in range(self.last_written_idx, len(self.messages)):
|
||||
self.logger.info(
|
||||
f"Writing to storage: {self.messages[idx].type} {self.messages[idx].message}"
|
||||
)
|
||||
if (
|
||||
self.messages[idx].type == "user"
|
||||
and idx > self.finalized_user_message_idx
|
||||
):
|
||||
break
|
||||
|
||||
if self.messages[idx].type != "system":
|
||||
self.storage_message_queue.put(idx)
|
||||
@@ -1,409 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import wave
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from queue import Queue, Empty
|
||||
from opentelemetry import trace, context
|
||||
|
||||
from dailyai.async_processor.async_processor import (
|
||||
AsyncProcessor,
|
||||
AsyncProcessorState,
|
||||
ConversationProcessorCollection,
|
||||
OrchestratorResponse,
|
||||
LLMResponse,
|
||||
)
|
||||
from dailyai.queue_frame import QueueFrame, FrameType
|
||||
from dailyai.services.ai_services import AIServiceConfig
|
||||
from dailyai.message_handler.message_handler import MessageHandler
|
||||
|
||||
from threading import Thread, Semaphore, Event, Timer
|
||||
|
||||
from opentelemetry import context
|
||||
from opentelemetry.context.context import Context
|
||||
|
||||
from daily import (
|
||||
EventHandler,
|
||||
CallClient,
|
||||
Daily,
|
||||
VirtualCameraDevice,
|
||||
VirtualMicrophoneDevice,
|
||||
VirtualSpeakerDevice,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrchestratorConfig:
|
||||
room_url: str
|
||||
token: str
|
||||
bot_name: str
|
||||
expiration: float
|
||||
|
||||
# Note that we use this as a default parameter value in the Orchestrator
|
||||
# constructor. The dataclass is defined with Frozen=True, so this should
|
||||
# be safe.
|
||||
default_conversation_collection = ConversationProcessorCollection(
|
||||
introduction=LLMResponse,
|
||||
waiting=None,
|
||||
response=LLMResponse,
|
||||
goodbye=None,
|
||||
)
|
||||
|
||||
|
||||
class Orchestrator(EventHandler):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
daily_config: OrchestratorConfig,
|
||||
ai_service_config: AIServiceConfig,
|
||||
message_handler: MessageHandler,
|
||||
conversation_processors: ConversationProcessorCollection = default_conversation_collection,
|
||||
tracer=None,
|
||||
):
|
||||
self.bot_name: str = daily_config.bot_name
|
||||
self.room_url: str = daily_config.room_url
|
||||
self.token: str = daily_config.token
|
||||
self.expiration: float = daily_config.expiration
|
||||
|
||||
self.logger: logging.Logger = logging.getLogger("dailyai")
|
||||
self.tracer = tracer or trace.get_tracer("orchestrator")
|
||||
|
||||
self.ctx: Context = context.get_current()
|
||||
|
||||
self.transcription = ""
|
||||
self.last_fragment_at = None
|
||||
self.talked_at = None
|
||||
self.paused_at = None
|
||||
|
||||
self.logger.info(f"Creating Response for introductions")
|
||||
self.services: AIServiceConfig = ai_service_config
|
||||
self.output_queue = Queue()
|
||||
self.is_interrupted = Event()
|
||||
self.stop_threads = Event()
|
||||
self.story_started = False
|
||||
|
||||
self.message_handler = message_handler
|
||||
self.conversation_processors: ConversationProcessorCollection = conversation_processors
|
||||
|
||||
if conversation_processors.introduction is not None:
|
||||
intro = conversation_processors.introduction(
|
||||
services=self.services, message_handler=self.message_handler, output_queue=self.output_queue
|
||||
)
|
||||
intro.prepare()
|
||||
intro.set_state_callback(AsyncProcessorState.DONE, self.on_intro_played)
|
||||
intro.set_state_callback(AsyncProcessorState.FINALIZED, self.on_intro_finished)
|
||||
self.logger.info(f"Introduction is preparing")
|
||||
|
||||
self.current_response: AsyncProcessor = intro
|
||||
self.can_interrupt = False
|
||||
# self.response_event.set()
|
||||
self.response_semaphore = Semaphore()
|
||||
|
||||
self.speech_timeout = None
|
||||
self.interrupt_time = None
|
||||
|
||||
self.logger.info("Configuring daily")
|
||||
self.configure_daily()
|
||||
|
||||
def configure_daily(self):
|
||||
Daily.init()
|
||||
self.client = CallClient(event_handler=self)
|
||||
|
||||
self.logger.info(f"Mic sample rate: {self.services.tts.get_mic_sample_rate()}")
|
||||
self.mic: VirtualMicrophoneDevice = Daily.create_microphone_device(
|
||||
"mic", sample_rate=self.services.tts.get_mic_sample_rate(), channels=1
|
||||
)
|
||||
self.speaker: VirtualSpeakerDevice = Daily.create_speaker_device(
|
||||
"speaker", sample_rate=16000, channels=1
|
||||
)
|
||||
self.camera: VirtualCameraDevice = Daily.create_camera_device(
|
||||
"camera", width=720, height=1280, color_format="RGB"
|
||||
)
|
||||
|
||||
Daily.select_speaker_device("speaker")
|
||||
|
||||
self.client.set_user_name(self.bot_name)
|
||||
self.client.join(self.room_url, self.token, completion=self.call_joined)
|
||||
|
||||
self.client.update_inputs(
|
||||
{
|
||||
"camera": {
|
||||
"isEnabled": True,
|
||||
"settings": {
|
||||
"deviceId": "camera",
|
||||
},
|
||||
},
|
||||
"microphone": {
|
||||
"isEnabled": True,
|
||||
"settings": {
|
||||
"deviceId": "mic",
|
||||
"customConstraints": {
|
||||
"autoGainControl": {"exact": False},
|
||||
"echoCancellation": {"exact": False},
|
||||
"noiseSuppression": {"exact": False},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
self.client.update_publishing(
|
||||
{
|
||||
"camera": {
|
||||
"sendSettings": {
|
||||
"maxQuality": "low",
|
||||
"encodings": {
|
||||
"low": {
|
||||
"maxBitrate": 250000,
|
||||
"scaleResolutionDownBy": 1.333,
|
||||
"maxFramerate": 8,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
self.my_participant_id = self.client.participants()["local"]["id"]
|
||||
|
||||
def start(self) -> None:
|
||||
# TODO: this loop could, I think, be replaced with a timer and an event
|
||||
self.participant_left = False
|
||||
|
||||
try:
|
||||
participant_count: int = len(self.client.participants())
|
||||
self.logger.info(f"{participant_count} participants in room")
|
||||
while time.time() < self.expiration and not self.participant_left:
|
||||
# all handling of incoming transcriptions happens in on_transcription_message
|
||||
time.sleep(1)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Exception {e}")
|
||||
finally:
|
||||
self.client.leave()
|
||||
|
||||
def stop(self):
|
||||
self.logger.info("Stop current response")
|
||||
if self.current_response:
|
||||
if self.current_response.state < AsyncProcessorState.INTERRUPTED:
|
||||
self.current_response.interrupt()
|
||||
|
||||
self.logger.info("Wait for state transition")
|
||||
self.current_response.wait_for_state_transition(AsyncProcessorState.FINALIZED)
|
||||
|
||||
self.stop_threads.set()
|
||||
self.camera_thread.join()
|
||||
self.logger.info("Camera thread stopped")
|
||||
|
||||
self.logger.info("Put stop in output queue")
|
||||
self.output_queue.put(QueueFrame(FrameType.END_STREAM, None))
|
||||
|
||||
self.frame_consumer_thread.join()
|
||||
self.logger.info("Orchestrator stopped.")
|
||||
|
||||
def on_intro_played(self, intro):
|
||||
self.logger.info(f"Introduction has played")
|
||||
self.can_interrupt = True
|
||||
intro.finalize()
|
||||
|
||||
def on_intro_finished(self, intro):
|
||||
self.logger.info(f"Introduction has finished")
|
||||
waiting = self.conversation_processors.waiting(self.services, self.message_handler, self.output_queue)
|
||||
waiting.prepare()
|
||||
waiting.play()
|
||||
|
||||
def on_response_played(self, response):
|
||||
response.finalize()
|
||||
|
||||
def on_response_finished(self, response):
|
||||
if not response.was_interrupted:
|
||||
self.message_handler.finalize_user_message()
|
||||
|
||||
def call_joined(self, join_data, client_error):
|
||||
self.logger.info(f"Call_joined: {join_data}, {client_error}")
|
||||
self.client.start_transcription(
|
||||
{
|
||||
"language": "en",
|
||||
"tier": "nova",
|
||||
"model": "2-conversationalai",
|
||||
"profanity_filter": True,
|
||||
"redact": False,
|
||||
"extra": {
|
||||
"endpointing": True,
|
||||
"punctuate": False,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
def on_participant_joined(self, participant):
|
||||
with self.tracer.start_as_current_span("on_participant_joined", context=self.ctx):
|
||||
self.logger.info(f"on_participant_joined: {participant}")
|
||||
|
||||
# TODO: figure out the architecture to get the story id to the client
|
||||
# self.client.send_app_message({"event": "story-id", "storyID": self.story_id})
|
||||
time.sleep(2)
|
||||
|
||||
if not self.story_started:
|
||||
self.action()
|
||||
self.story_started = True
|
||||
|
||||
def on_participant_left(self, participant, reason):
|
||||
self.logger.info(f"Participant {participant} left")
|
||||
if len(self.client.participants()) < 2:
|
||||
self.participant_left = True
|
||||
|
||||
def on_app_message(self, message, sender):
|
||||
with self.tracer.start_as_current_span("on_app_message", context=self.ctx):
|
||||
self.logger.info(f"on_app_message {message} from {sender}")
|
||||
if "isSpeaking" in message and message["isSpeaking"] == True:
|
||||
self.handle_user_started_talking()
|
||||
|
||||
if "isSpeaking" in message and message["isSpeaking"] == False:
|
||||
self.handle_user_stopped_talking()
|
||||
|
||||
def on_transcription_message(self, message):
|
||||
with self.tracer.start_as_current_span("on_transcription_message", context=self.ctx):
|
||||
if message["session_id"] != self.my_participant_id:
|
||||
self.handle_transcription_fragment(message['text'])
|
||||
|
||||
def on_transcription_stopped(self, stopped_by, stopped_by_error):
|
||||
self.logger.info(f"Transcription stopped {stopped_by}, {stopped_by_error}")
|
||||
|
||||
def on_transcription_error(self, message):
|
||||
self.logger.error(f"Transcription error {message}")
|
||||
|
||||
def on_transcription_started(self, status):
|
||||
self.logger.info(f"Transcription started {status}")
|
||||
|
||||
def set_image(self, image: bytes):
|
||||
self.image: bytes | None = image
|
||||
|
||||
def run_camera(self):
|
||||
try:
|
||||
while not self.stop_threads.is_set():
|
||||
if self.image:
|
||||
self.camera.write_frame(self.image)
|
||||
|
||||
time.sleep(1.0 / 8.0) # 8 fps
|
||||
except Exception as e:
|
||||
self.logger.error(f"Exception {e} in camera thread.")
|
||||
|
||||
def handle_user_started_talking(self):
|
||||
# TODO: allow configuration of the timer timeout
|
||||
self.logger.error("user started talking")
|
||||
self.speech_timeout = Timer(1.0, self.utterance_interrupt)
|
||||
|
||||
def handle_user_stopped_talking(self):
|
||||
self.logger.error("user stopped talking, canceling utterance interrupt")
|
||||
if self.speech_timeout:
|
||||
self.speech_timeout.cancel()
|
||||
|
||||
def utterance_interrupt(self):
|
||||
self.logger.error("utterance interrupt")
|
||||
self.is_interrupted.set()
|
||||
|
||||
def handle_transcription_fragment(self, fragment):
|
||||
if not self.can_interrupt:
|
||||
return
|
||||
|
||||
# start generating a new response. We'll do the fast parts of the interrupt
|
||||
# now but wait for the state transition after we've kicked off the prepare
|
||||
# on the new response.
|
||||
if (
|
||||
self.current_response
|
||||
and self.current_response.state < AsyncProcessorState.INTERRUPTED
|
||||
):
|
||||
self.interrupt_time = time.perf_counter()
|
||||
self.is_interrupted.set()
|
||||
self.current_response.interrupt()
|
||||
|
||||
self.message_handler.add_user_message(fragment)
|
||||
|
||||
response_type: type[OrchestratorResponse] | type[LLMResponse] = self.conversation_processors.response or LLMResponse
|
||||
new_response: OrchestratorResponse = response_type(
|
||||
self.services, self.message_handler, self.output_queue
|
||||
)
|
||||
new_response.set_state_callback(
|
||||
AsyncProcessorState.DONE, self.on_response_played
|
||||
)
|
||||
new_response.set_state_callback(
|
||||
AsyncProcessorState.FINALIZED, self.on_response_finished
|
||||
)
|
||||
new_response.prepare()
|
||||
|
||||
self.response_semaphore.acquire()
|
||||
if (
|
||||
self.current_response
|
||||
and self.current_response.state < AsyncProcessorState.INTERRUPTED
|
||||
):
|
||||
self.current_response.wait_for_state_transition(
|
||||
AsyncProcessorState.FINALIZED
|
||||
)
|
||||
|
||||
self.current_response = new_response
|
||||
self.current_response.play()
|
||||
|
||||
self.response_semaphore.release()
|
||||
|
||||
def action(self):
|
||||
self.logger.info("Starting camera thread")
|
||||
self.image: bytes | None = None
|
||||
self.camera_thread = Thread(target=self.run_camera, daemon=True)
|
||||
self.camera_thread.start()
|
||||
|
||||
self.logger.info("Starting frame consumer thread")
|
||||
self.frame_consumer_thread = Thread(target=self.frame_consumer, daemon=True)
|
||||
self.frame_consumer_thread.start()
|
||||
|
||||
self.logger.info("Playing introduction")
|
||||
self.can_interrupt = False
|
||||
self.current_response.play()
|
||||
|
||||
def frame_consumer(self):
|
||||
self.logger.info("🎬 Starting frame consumer thread")
|
||||
b = bytearray()
|
||||
smallest_write_size = 3200
|
||||
all_audio_frames = bytearray()
|
||||
while True:
|
||||
try:
|
||||
frame:QueueFrame = self.output_queue.get()
|
||||
if frame.frame_type == FrameType.END_STREAM:
|
||||
self.logger.info("Stopping frame consumer thread")
|
||||
return
|
||||
|
||||
# if interrupted, we just pull frames off the queue and discard them
|
||||
if not self.is_interrupted.is_set():
|
||||
if frame:
|
||||
if frame.frame_type == FrameType.AUDIO:
|
||||
chunk = frame.frame_data
|
||||
|
||||
all_audio_frames.extend(chunk)
|
||||
|
||||
b.extend(chunk)
|
||||
l = len(b) - (len(b) % smallest_write_size)
|
||||
if l:
|
||||
self.mic.write_frames(bytes(b[:l]))
|
||||
b = b[l:]
|
||||
elif frame.frame_type == FrameType.IMAGE:
|
||||
self.set_image(frame.frame_data)
|
||||
elif len(b):
|
||||
self.mic.write_frames(bytes(b))
|
||||
b = bytearray()
|
||||
else:
|
||||
if self.interrupt_time:
|
||||
self.logger.info(f"Lag to stop stream after interruption {time.perf_counter() - self.interrupt_time}")
|
||||
self.interrupt_time = None
|
||||
|
||||
if frame.frame_type == FrameType.START_STREAM:
|
||||
self.is_interrupted.clear()
|
||||
|
||||
self.output_queue.task_done()
|
||||
except Empty:
|
||||
try:
|
||||
if len(b):
|
||||
self.mic.write_frames(bytes(b))
|
||||
except Exception as e:
|
||||
self.logger.error(f"Exception in frame_consumer: {e}, {len(b)}")
|
||||
|
||||
b = bytearray()
|
||||
@@ -25,23 +25,24 @@ class QueueTee:
|
||||
await queue.put(frame)
|
||||
|
||||
class LLMContextAggregator(AIService):
|
||||
def __init__(self, messages: list[dict], role:str, bot_participant_id=None):
|
||||
def __init__(self, messages: list[dict], role:str, bot_participant_id=None, complete_sentences=True):
|
||||
self.messages = messages
|
||||
self.bot_participant_id = bot_participant_id
|
||||
self.role = role
|
||||
self.sentence = ""
|
||||
self.complete_sentences = complete_sentences
|
||||
|
||||
async def process_frame(self, frame:QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
content: str = ""
|
||||
|
||||
# TODO: split up transcription by participant
|
||||
if isinstance(frame, TextQueueFrame):
|
||||
content = frame.text
|
||||
|
||||
self.sentence += content
|
||||
if self.sentence.endswith((".", "?", "!")):
|
||||
self.messages.append({"role": self.role, "content": self.sentence})
|
||||
self.sentence = ""
|
||||
yield LLMMessagesQueueFrame(self.messages)
|
||||
if self.complete_sentences:
|
||||
self.sentence += frame.text
|
||||
if self.sentence.endswith((".", "?", "!")):
|
||||
self.messages.append({"role": self.role, "content": self.sentence})
|
||||
self.sentence = ""
|
||||
yield LLMMessagesQueueFrame(self.messages)
|
||||
else:
|
||||
self.messages.append({"role": self.role, "content": frame.text})
|
||||
yield LLMMessagesQueueFrame(self.messages)
|
||||
|
||||
yield frame
|
||||
|
||||
@@ -5,10 +5,13 @@ from typing import Any
|
||||
class QueueFrame:
|
||||
pass
|
||||
|
||||
class StartStreamQueueFrame(QueueFrame):
|
||||
class ControlQueueFrame(QueueFrame):
|
||||
pass
|
||||
|
||||
class EndStreamQueueFrame(QueueFrame):
|
||||
class StartStreamQueueFrame(ControlQueueFrame):
|
||||
pass
|
||||
|
||||
class EndStreamQueueFrame(ControlQueueFrame):
|
||||
pass
|
||||
|
||||
@dataclass()
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
Pillow==10.1.0
|
||||
typing_extensions==4.9.0
|
||||
typing_extensions==4.9.0
|
||||
faster-whisper==0.10.0
|
||||
@@ -1,8 +1,11 @@
|
||||
import asyncio
|
||||
import io
|
||||
import logging
|
||||
import wave
|
||||
|
||||
from dailyai.queue_frame import (
|
||||
AudioQueueFrame,
|
||||
ControlQueueFrame,
|
||||
EndStreamQueueFrame,
|
||||
ImageQueueFrame,
|
||||
LLMMessagesQueueFrame,
|
||||
@@ -11,7 +14,7 @@ from dailyai.queue_frame import (
|
||||
)
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import AsyncGenerator, AsyncIterable, Iterable
|
||||
from typing import AsyncGenerator, AsyncIterable, BinaryIO, Iterable
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@@ -63,9 +66,8 @@ class AIService:
|
||||
|
||||
@abstractmethod
|
||||
async def process_frame(self, frame:QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
# This is a trick for the interpreter (and linter) to know that this is a generator.
|
||||
if False:
|
||||
yield QueueFrame()
|
||||
if isinstance(frame, ControlQueueFrame):
|
||||
yield frame
|
||||
|
||||
@abstractmethod
|
||||
async def finalize(self) -> AsyncGenerator[QueueFrame, None]:
|
||||
@@ -83,7 +85,9 @@ class LLMService(AIService):
|
||||
pass
|
||||
|
||||
async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
if isinstance(frame, LLMMessagesQueueFrame):
|
||||
if isinstance(frame, ControlQueueFrame):
|
||||
yield frame
|
||||
elif isinstance(frame, LLMMessagesQueueFrame):
|
||||
async for text_chunk in self.run_llm_async(frame.messages):
|
||||
yield TextQueueFrame(text_chunk)
|
||||
|
||||
@@ -106,9 +110,11 @@ class TTSService(AIService):
|
||||
yield bytes()
|
||||
|
||||
async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
if not isinstance(frame, TextQueueFrame):
|
||||
if isinstance(frame, ControlQueueFrame):
|
||||
yield frame
|
||||
return
|
||||
elif not isinstance(frame, TextQueueFrame):
|
||||
return
|
||||
|
||||
text: str | None = None
|
||||
if not self.aggregate_sentences:
|
||||
@@ -145,14 +151,46 @@ class ImageGenService(AIService):
|
||||
|
||||
async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
if not isinstance(frame, TextQueueFrame):
|
||||
yield frame
|
||||
return
|
||||
|
||||
(url, image_data) = await self.run_image_gen(frame.text)
|
||||
yield ImageQueueFrame(url, image_data)
|
||||
|
||||
class STTService(AIService):
|
||||
"""STTService is a base class for speech-to-text services."""
|
||||
|
||||
_frame_rate: int
|
||||
def __init__(self, frame_rate: int = 16000, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._frame_rate = frame_rate
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def run_stt(self, audio: BinaryIO) -> str:
|
||||
"""Returns transcript as a string"""
|
||||
pass
|
||||
|
||||
async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
"""Processes a frame of audio data, either buffering or transcribing it."""
|
||||
if not isinstance(frame, AudioQueueFrame):
|
||||
return
|
||||
|
||||
data = frame.data
|
||||
content = io.BufferedRandom(io.BytesIO())
|
||||
ww = wave.open(self._content, "wb")
|
||||
ww.setnchannels(1)
|
||||
ww.setsampwidth(2)
|
||||
ww.setframerate(self._frame_rate)
|
||||
ww.writeframesraw(data)
|
||||
ww.close()
|
||||
content.seek(0)
|
||||
text = await self.run_stt(content)
|
||||
yield TextQueueFrame(text)
|
||||
|
||||
@dataclass
|
||||
class AIServiceConfig:
|
||||
tts: TTSService
|
||||
image: ImageGenService
|
||||
llm: LLMService
|
||||
stt: STTService
|
||||
|
||||
@@ -5,7 +5,6 @@ import json
|
||||
from openai import AsyncAzureOpenAI
|
||||
|
||||
import os
|
||||
import requests
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
@@ -16,7 +15,10 @@ from PIL import Image
|
||||
from azure.cognitiveservices.speech import SpeechSynthesizer, SpeechConfig, ResultReason, CancellationReason
|
||||
|
||||
class AzureTTSService(TTSService):
|
||||
def __init__(self, speech_key=None, speech_region=None):
|
||||
|
||||
def __init__(
|
||||
self, speech_key=None, speech_region=None, voice_name="en-US-SaraNeural"
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
speech_key = speech_key or os.getenv("AZURE_SPEECH_SERVICE_KEY")
|
||||
@@ -25,20 +27,19 @@ class AzureTTSService(TTSService):
|
||||
self.speech_config = SpeechConfig(subscription=speech_key, region=speech_region)
|
||||
self.speech_synthesizer = SpeechSynthesizer(speech_config=self.speech_config, audio_config=None)
|
||||
|
||||
self.voice_name = voice_name
|
||||
|
||||
async def run_tts(self, sentence) -> AsyncGenerator[bytes, None]:
|
||||
self.logger.info("Running azure tts")
|
||||
ssml = "<speak version='1.0' xml:lang='en-US' xmlns='http://www.w3.org/2001/10/synthesis' " \
|
||||
ssml = f"<speak version='1.0' xml:lang='en-US' xmlns='http://www.w3.org/2001/10/synthesis' " \
|
||||
"xmlns:mstts='http://www.w3.org/2001/mstts'>" \
|
||||
"<voice name='en-US-SaraNeural'>" \
|
||||
f"<voice name={self.voice_name}>" \
|
||||
"<mstts:silence type='Sentenceboundary' value='20ms' />" \
|
||||
"<mstts:express-as style='lyrical' styledegree='2' role='SeniorFemale'>" \
|
||||
"<prosody rate='1.05'>" \
|
||||
f"{sentence}" \
|
||||
"</prosody></mstts:express-as></voice></speak> "
|
||||
try:
|
||||
result = await asyncio.to_thread(self.speech_synthesizer.speak_ssml, (ssml))
|
||||
except Exception as e:
|
||||
self.logger.error("Error in azure tts", e)
|
||||
result = await asyncio.to_thread(self.speech_synthesizer.speak_ssml, (ssml))
|
||||
self.logger.info("Got azure tts result")
|
||||
if result.reason == ResultReason.SynthesizingAudioCompleted:
|
||||
self.logger.info("Returning result")
|
||||
@@ -95,81 +96,58 @@ class AzureLLMService(LLMService):
|
||||
|
||||
class AzureImageGenServiceREST(ImageGenService):
|
||||
|
||||
def __init__(self, image_size:str, api_key=None, azure_endpoint=None, api_version=None, model=None):
|
||||
def __init__(
|
||||
self,
|
||||
image_size: str,
|
||||
api_key: str | None = None,
|
||||
azure_endpoint: str | None = None,
|
||||
api_version: str | None = None,
|
||||
model: str | None = None,
|
||||
aiohttp_session: aiohttp.ClientSession | None=None,
|
||||
timeout_seconds=120,
|
||||
):
|
||||
super().__init__(image_size=image_size)
|
||||
self.api_key = api_key or os.getenv("AZURE_DALLE_KEY")
|
||||
self.azure_endpoint = azure_endpoint or os.getenv("AZURE_DALLE_ENDPOINT")
|
||||
self.api_version = api_version or "2023-06-01-preview"
|
||||
self.model = model or os.getenv("AZURE_DALLE_DEPLOYMENT_ID")
|
||||
self.aiohttp_session: aiohttp.ClientSession = (
|
||||
aiohttp_session or aiohttp.ClientSession()
|
||||
)
|
||||
self.timeout_seconds = timeout_seconds
|
||||
|
||||
async def run_image_gen(self, sentence) -> tuple[str, bytes]:
|
||||
# TODO hoist the session to app-level
|
||||
async with aiohttp.ClientSession() as session:
|
||||
url = f"{self.azure_endpoint}openai/images/generations:submit?api-version={self.api_version}"
|
||||
headers= { "api-key": self.api_key, "Content-Type": "application/json" }
|
||||
body = {
|
||||
# Enter your prompt text here
|
||||
"prompt": sentence,
|
||||
"size": self.image_size,
|
||||
"n": 1,
|
||||
}
|
||||
async with session.post(url, headers=headers, json=body) as submission:
|
||||
operation_location = submission.headers['operation-location']
|
||||
url = f"{self.azure_endpoint}openai/images/generations:submit?api-version={self.api_version}"
|
||||
headers= { "api-key": self.api_key, "Content-Type": "application/json" }
|
||||
body = {
|
||||
"prompt": sentence,
|
||||
"size": self.image_size,
|
||||
"n": 1,
|
||||
}
|
||||
async with self.aiohttp_session.post(
|
||||
url, headers=headers, json=body
|
||||
) as submission:
|
||||
operation_location = submission.headers['operation-location']
|
||||
|
||||
status = ""
|
||||
attempts_left = 120
|
||||
json_response = None
|
||||
while status != "succeeded":
|
||||
attempts_left -= 1
|
||||
if attempts_left == 0:
|
||||
raise Exception("Image generation timed out")
|
||||
status = ""
|
||||
attempts_left = self.timeout_seconds
|
||||
json_response = None
|
||||
while status != "succeeded":
|
||||
attempts_left -= 1
|
||||
if attempts_left == 0:
|
||||
raise Exception("Image generation timed out")
|
||||
|
||||
await asyncio.sleep(1)
|
||||
response = await session.get(operation_location, headers=headers)
|
||||
json_response = await response.json()
|
||||
status = json_response["status"]
|
||||
await asyncio.sleep(1)
|
||||
response = await self.aiohttp_session.get(operation_location, headers=headers)
|
||||
json_response = await response.json()
|
||||
status = json_response["status"]
|
||||
|
||||
image_url = json_response["result"]["data"][0]["url"] if json_response else None
|
||||
if not image_url:
|
||||
raise Exception("Image generation failed")
|
||||
image_url = json_response["result"]["data"][0]["url"] if json_response else None
|
||||
if not image_url:
|
||||
raise Exception("Image generation failed")
|
||||
|
||||
# Load the image from the url
|
||||
async with session.get(image_url) as response:
|
||||
image_stream = io.BytesIO(await response.content.read())
|
||||
image = Image.open(image_stream)
|
||||
return (image_url, image.tobytes())
|
||||
|
||||
|
||||
class AzureImageGenService(ImageGenService):
|
||||
|
||||
def __init__(self, api_key=None, azure_endpoint=None, api_version=None, model=None):
|
||||
super().__init__()
|
||||
|
||||
api_key = api_key or os.getenv("AZURE_DALLE_KEY")
|
||||
azure_endpoint = azure_endpoint or os.getenv("AZURE_DALLE_ENDPOINT")
|
||||
api_version = api_version or "2023-06-01-preview"
|
||||
self.model = model or os.getenv("AZURE_DALLE_DEPLOYMENT_ID")
|
||||
|
||||
self.client = AzureOpenAI(
|
||||
api_key=api_key,
|
||||
azure_endpoint=azure_endpoint,
|
||||
api_version=api_version,
|
||||
)
|
||||
|
||||
async def run_image_gen(self, sentence) -> tuple[str, bytes]:
|
||||
self.logger.info("Generating azure image", sentence)
|
||||
|
||||
image = self.client.images.generate(
|
||||
model=self.model,
|
||||
prompt=sentence,
|
||||
n=1,
|
||||
size=self.image_size,
|
||||
)
|
||||
|
||||
url = image["data"][0]["url"]
|
||||
response = requests.get(url)
|
||||
|
||||
dalle_stream = io.BytesIO(response.content)
|
||||
dalle_im = Image.open(dalle_stream.tobytes())
|
||||
|
||||
return (url, dalle_im)
|
||||
# Load the image from the url
|
||||
async with self.aiohttp_session.get(image_url) as response:
|
||||
image_stream = io.BytesIO(await response.content.read())
|
||||
image = Image.open(image_stream)
|
||||
return (image_url, image.tobytes())
|
||||
|
||||
@@ -7,6 +7,7 @@ import types
|
||||
|
||||
from functools import partial
|
||||
from queue import Queue, Empty
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from dailyai.queue_frame import (
|
||||
AudioQueueFrame,
|
||||
@@ -14,6 +15,7 @@ from dailyai.queue_frame import (
|
||||
ImageQueueFrame,
|
||||
QueueFrame,
|
||||
StartStreamQueueFrame,
|
||||
TextQueueFrame,
|
||||
TranscriptionQueueFrame,
|
||||
)
|
||||
|
||||
@@ -31,6 +33,10 @@ from daily import (
|
||||
class DailyTransportService(EventHandler):
|
||||
_daily_initialized = False
|
||||
_lock = threading.Lock()
|
||||
|
||||
speaker_enabled: bool
|
||||
speaker_sample_rate: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
room_url: str,
|
||||
@@ -38,6 +44,9 @@ class DailyTransportService(EventHandler):
|
||||
bot_name: str,
|
||||
duration: float = 10,
|
||||
min_others_count: int = 1,
|
||||
start_transcription: bool = True,
|
||||
speaker_enabled: bool = False,
|
||||
speaker_sample_rate: int = 16000,
|
||||
):
|
||||
super().__init__()
|
||||
self.bot_name: str = bot_name
|
||||
@@ -46,6 +55,7 @@ class DailyTransportService(EventHandler):
|
||||
self.duration: float = duration
|
||||
self.expiration = time.time() + duration * 60
|
||||
self.min_others_count = min_others_count
|
||||
self.start_transcription = start_transcription
|
||||
|
||||
# This queue is used to marshal frames from the async send queue to the thread that emits audio & video.
|
||||
# We need this to maintain the asynchronous behavior of asyncio queues -- to give async functions
|
||||
@@ -61,6 +71,8 @@ class DailyTransportService(EventHandler):
|
||||
self.camera_width = 1024
|
||||
self.camera_height = 768
|
||||
self.camera_enabled = False
|
||||
self.speaker_enabled = speaker_enabled
|
||||
self.speaker_sample_rate = speaker_sample_rate
|
||||
|
||||
self.send_queue = asyncio.Queue()
|
||||
self.receive_queue = asyncio.Queue()
|
||||
@@ -104,6 +116,7 @@ class DailyTransportService(EventHandler):
|
||||
handler(*args, **kwargs)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Exception in event handler {event_name}: {e}")
|
||||
raise e
|
||||
|
||||
def add_event_handler(self, event_name: str, handler):
|
||||
if not event_name.startswith("on_"):
|
||||
@@ -144,9 +157,11 @@ class DailyTransportService(EventHandler):
|
||||
"camera", width=self.camera_width, height=self.camera_height, color_format="RGB"
|
||||
)
|
||||
|
||||
self.speaker: VirtualSpeakerDevice = Daily.create_speaker_device(
|
||||
"speaker", sample_rate=16000, channels=1
|
||||
)
|
||||
if self.speaker_enabled:
|
||||
self.speaker: VirtualSpeakerDevice = Daily.create_speaker_device(
|
||||
"speaker", sample_rate=self.speaker_sample_rate, channels=1
|
||||
)
|
||||
Daily.select_speaker_device("speaker")
|
||||
|
||||
self.image: bytes | None = None
|
||||
self.camera_thread = Thread(target=self.run_camera, daemon=True)
|
||||
@@ -156,8 +171,6 @@ class DailyTransportService(EventHandler):
|
||||
self.frame_consumer_thread = Thread(target=self.frame_consumer, daemon=True)
|
||||
self.frame_consumer_thread.start()
|
||||
|
||||
Daily.select_speaker_device("speaker")
|
||||
|
||||
self.client.set_user_name(self.bot_name)
|
||||
self.client.join(self.room_url, self.token, completion=self.call_joined)
|
||||
self.my_participant_id = self.client.participants()["local"]["id"]
|
||||
@@ -201,10 +214,24 @@ class DailyTransportService(EventHandler):
|
||||
}
|
||||
)
|
||||
|
||||
if self.token:
|
||||
if self.token and self.start_transcription:
|
||||
self.client.start_transcription(self.transcription_settings)
|
||||
|
||||
async def get_receive_frames(self):
|
||||
def _receive_audio(self):
|
||||
"""Receive audio from the Daily call and put it on the receive queue"""
|
||||
seconds = 1
|
||||
desired_frame_count = self.speaker_sample_rate * seconds
|
||||
while True:
|
||||
buffer = self.speaker.read_frames(desired_frame_count)
|
||||
if len(buffer) > 0:
|
||||
frame = AudioQueueFrame(buffer)
|
||||
if self.loop:
|
||||
asyncio.run_coroutine_threadsafe(self.receive_queue.put(frame), self.loop)
|
||||
|
||||
def interrupt(self):
|
||||
self.is_interrupted.set()
|
||||
|
||||
async def get_receive_frames(self) -> AsyncGenerator[QueueFrame, None]:
|
||||
while True:
|
||||
frame = await self.receive_queue.get()
|
||||
yield frame
|
||||
@@ -244,6 +271,7 @@ class DailyTransportService(EventHandler):
|
||||
await asyncio.sleep(1)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Exception {e}")
|
||||
raise e
|
||||
finally:
|
||||
self.client.leave()
|
||||
|
||||
@@ -266,6 +294,9 @@ class DailyTransportService(EventHandler):
|
||||
|
||||
def call_joined(self, join_data, client_error):
|
||||
self.logger.info(f"Call_joined: {join_data}, {client_error}")
|
||||
if self.speaker_enabled:
|
||||
t = Thread(target=self._receive_audio, daemon=True)
|
||||
t.start()
|
||||
|
||||
def on_error(self, error):
|
||||
self.logger.error(f"on_error: {error}")
|
||||
@@ -317,6 +348,7 @@ class DailyTransportService(EventHandler):
|
||||
time.sleep(1.0 / 8) # 8 fps
|
||||
except Exception as e:
|
||||
self.logger.error(f"Exception {e} in camera thread.")
|
||||
raise e
|
||||
|
||||
def frame_consumer(self):
|
||||
self.logger.info("🎬 Starting frame consumer thread")
|
||||
@@ -358,11 +390,11 @@ class DailyTransportService(EventHandler):
|
||||
self.mic.write_frames(bytes(b))
|
||||
b = bytearray()
|
||||
else:
|
||||
if self.interrupt_time:
|
||||
self.logger.info(
|
||||
f"Lag to stop stream after interruption {time.perf_counter() - self.interrupt_time}"
|
||||
)
|
||||
self.interrupt_time = None
|
||||
# if there are leftover audio bytes, write them now; failing to do so
|
||||
# can cause static in the audio stream.
|
||||
if len(b):
|
||||
self.mic.write_frames(bytes(b))
|
||||
b = bytearray()
|
||||
|
||||
if isinstance(frame, StartStreamQueueFrame):
|
||||
self.is_interrupted.clear()
|
||||
@@ -374,5 +406,6 @@ class DailyTransportService(EventHandler):
|
||||
self.mic.write_frames(bytes(b))
|
||||
except Exception as e:
|
||||
self.logger.error(f"Exception in frame_consumer: {e}, {len(b)}")
|
||||
raise e
|
||||
|
||||
b = bytearray()
|
||||
|
||||
@@ -9,28 +9,30 @@ from dailyai.services.ai_services import TTSService
|
||||
|
||||
|
||||
class ElevenLabsTTSService(TTSService):
|
||||
def __init__(self, api_key=None, voice_id=None):
|
||||
def __init__(self, api_key=None, voice_id=None, aiohttp_session:aiohttp.ClientSession=None):
|
||||
super().__init__()
|
||||
|
||||
self.api_key = api_key or os.getenv("ELEVENLABS_API_KEY")
|
||||
self.voice_id = voice_id or os.getenv("ELEVENLABS_VOICE_ID")
|
||||
self.aiohttp_session = aiohttp_session or aiohttp.ClientSession()
|
||||
|
||||
async def run_tts(self, sentence) -> AsyncGenerator[bytes, None]:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
url = f"https://api.elevenlabs.io/v1/text-to-speech/{self.voice_id}/stream"
|
||||
payload = {"text": sentence, "model_id": "eleven_turbo_v2"}
|
||||
querystring = {"output_format": "pcm_16000", "optimize_streaming_latency": 2}
|
||||
headers = {
|
||||
"xi-api-key": self.api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
async with session.post(url, json=payload, headers=headers, params=querystring) as r:
|
||||
if r.status != 200:
|
||||
self.logger.error(
|
||||
f"audio fetch status code: {r.status}, error: {r.text}"
|
||||
)
|
||||
return
|
||||
url = f"https://api.elevenlabs.io/v1/text-to-speech/{self.voice_id}/stream"
|
||||
payload = {"text": sentence, "model_id": "eleven_turbo_v2"}
|
||||
querystring = {"output_format": "pcm_16000", "optimize_streaming_latency": 2}
|
||||
headers = {
|
||||
"xi-api-key": self.api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
async with self.aiohttp_session.post(
|
||||
url, json=payload, headers=headers, params=querystring
|
||||
) as r:
|
||||
if r.status != 200:
|
||||
self.logger.error(
|
||||
f"audio fetch status code: {r.status}, error: {r.text}"
|
||||
)
|
||||
return
|
||||
|
||||
async for chunk in r.content:
|
||||
if chunk:
|
||||
yield chunk
|
||||
async for chunk in r.content:
|
||||
if chunk:
|
||||
yield chunk
|
||||
|
||||
@@ -34,9 +34,9 @@ class FalImageGenService(ImageGenService):
|
||||
raise Exception("Image generation failed")
|
||||
|
||||
return image_url
|
||||
print(f"fetching image url...")
|
||||
print("fetching image url...")
|
||||
image_url = await asyncio.to_thread(get_image_url, sentence, self.image_size)
|
||||
print(f"got image url, downloading image...")
|
||||
print("got image url, downloading image...")
|
||||
# Load the image from the url
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_url) as response:
|
||||
|
||||
72
src/dailyai/services/local_stt_service.py
Normal file
72
src/dailyai/services/local_stt_service.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import array
|
||||
import io
|
||||
import math
|
||||
from typing import AsyncGenerator
|
||||
import wave
|
||||
from dailyai.queue_frame import AudioQueueFrame, QueueFrame, TextQueueFrame
|
||||
from dailyai.services.ai_services import STTService
|
||||
|
||||
|
||||
class LocalSTTService(STTService):
|
||||
_content: io.BufferedRandom
|
||||
_wave: wave.Wave_write
|
||||
_current_silence_frames: int
|
||||
|
||||
# Configuration
|
||||
_min_rms: int
|
||||
_max_silence_frames: int
|
||||
_frame_rate: int
|
||||
|
||||
def __init__(self,
|
||||
min_rms: int = 400,
|
||||
max_silence_frames: int = 3,
|
||||
frame_rate: int = 16000,
|
||||
**kwargs):
|
||||
super().__init__(frame_rate, **kwargs)
|
||||
self._current_silence_frames = 0
|
||||
self._min_rms = min_rms
|
||||
self._max_silence_frames = max_silence_frames
|
||||
self._frame_rate = frame_rate
|
||||
self._new_wave()
|
||||
|
||||
def _new_wave(self):
|
||||
"""Creates a new wave object and content buffer."""
|
||||
self._content = io.BufferedRandom(io.BytesIO())
|
||||
ww = wave.open(self._content, "wb")
|
||||
ww.setnchannels(1)
|
||||
ww.setsampwidth(2)
|
||||
ww.setframerate(self._frame_rate)
|
||||
self._wave = ww
|
||||
|
||||
async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
"""Processes a frame of audio data, either buffering or transcribing it."""
|
||||
if not isinstance(frame, AudioQueueFrame):
|
||||
return
|
||||
|
||||
data = frame.data
|
||||
# Try to filter out empty background noise
|
||||
# (Very rudimentary approach, can be improved)
|
||||
rms = self._get_volume(data)
|
||||
if rms >= self._min_rms:
|
||||
# If volume is high enough, write new data to wave file
|
||||
self._wave.writeframesraw(data)
|
||||
|
||||
# If buffer is not empty and we detect a 3-frame pause in speech,
|
||||
# transcribe the audio gathered so far.
|
||||
if self._content.tell() > 0 and self._current_silence_frames > self._max_silence_frames:
|
||||
self._current_silence_frames = 0
|
||||
self._wave.close()
|
||||
self._content.seek(0)
|
||||
text = await self.run_stt(self._content)
|
||||
self._new_wave()
|
||||
yield TextQueueFrame(text)
|
||||
# If we get this far, this is a frame of silence
|
||||
self._current_silence_frames += 1
|
||||
|
||||
def _get_volume(self, audio: bytes) -> float:
|
||||
# https://docs.python.org/3/library/array.html
|
||||
audio_array = array.array('h', audio)
|
||||
squares = [sample**2 for sample in audio_array]
|
||||
mean = sum(squares) / len(audio_array)
|
||||
rms = math.sqrt(mean)
|
||||
return rms
|
||||
@@ -1,6 +1,4 @@
|
||||
import requests
|
||||
import aiohttp
|
||||
import asyncio
|
||||
from PIL import Image
|
||||
import io
|
||||
from openai import AsyncOpenAI
|
||||
@@ -9,7 +7,7 @@ import os
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from dailyai.services.ai_services import AIService, TTSService, LLMService, ImageGenService
|
||||
from dailyai.services.ai_services import LLMService, ImageGenService
|
||||
|
||||
|
||||
class OpenAILLMService(LLMService):
|
||||
@@ -50,11 +48,19 @@ class OpenAILLMService(LLMService):
|
||||
return None
|
||||
|
||||
class OpenAIImageGenService(ImageGenService):
|
||||
def __init__(self, image_size:str, api_key=None, model=None):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size: str,
|
||||
api_key=None,
|
||||
model=None,
|
||||
aiohttp_session: aiohttp.ClientSession | None = None,
|
||||
):
|
||||
super().__init__(image_size=image_size)
|
||||
api_key = api_key or os.getenv("OPEN_AI_KEY")
|
||||
self.model = model or os.getenv("OPEN_AI_IMAGE_MODEL") or "dall-e-3"
|
||||
self.client = AsyncOpenAI(api_key=api_key)
|
||||
self.aiohttp_session=aiohttp_session or aiohttp.ClientSession()
|
||||
|
||||
async def run_image_gen(self, sentence) -> tuple[str, bytes]:
|
||||
self.logger.info("Generating OpenAI image", sentence)
|
||||
@@ -70,10 +76,7 @@ class OpenAIImageGenService(ImageGenService):
|
||||
raise Exception("No image provided in response", image)
|
||||
|
||||
# Load the image from the url
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_url) as response:
|
||||
image_stream = io.BytesIO(await response.content.read())
|
||||
image = Image.open(image_stream)
|
||||
return (image_url, image.tobytes())
|
||||
|
||||
return (image_url, dalle_im.tobytes())
|
||||
async with self.aiohttp_session.get(image_url) as response:
|
||||
image_stream = io.BytesIO(await response.content.read())
|
||||
image = Image.open(image_stream)
|
||||
return (image_url, image.tobytes())
|
||||
|
||||
55
src/dailyai/services/whisper_ai_services.py
Normal file
55
src/dailyai/services/whisper_ai_services.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""This module implements Whisper transcription with a locally-downloaded model."""
|
||||
import asyncio
|
||||
from enum import Enum
|
||||
import logging
|
||||
from typing import BinaryIO
|
||||
from faster_whisper import WhisperModel
|
||||
from dailyai.services.local_stt_service import LocalSTTService
|
||||
|
||||
|
||||
class Model(Enum):
|
||||
"""Class of basic Whisper model selection options"""
|
||||
TINY = "tiny"
|
||||
BASE = "base"
|
||||
MEDIUM = "medium"
|
||||
LARGE = "large-v3"
|
||||
DISTIL_LARGE_V2 = "Systran/faster-distil-whisper-large-v2"
|
||||
DISTIL_MEDIUM_EN = "Systran/faster-distil-whisper-medium.en"
|
||||
|
||||
|
||||
class WhisperSTTService(LocalSTTService):
|
||||
"""Class to transcribe audio with a locally-downloaded Whisper model"""
|
||||
_model: WhisperModel
|
||||
|
||||
# Model configuration
|
||||
_model_name: Model
|
||||
_device: str
|
||||
_compute_type: str
|
||||
|
||||
def __init__(self, model_name: Model = Model.DISTIL_MEDIUM_EN,
|
||||
device: str = "auto",
|
||||
compute_type: str = "default"):
|
||||
|
||||
super().__init__()
|
||||
self.logger: logging.Logger = logging.getLogger("dailyai")
|
||||
self._model_name = model_name
|
||||
self._device = device
|
||||
self._compute_type = compute_type
|
||||
self._load()
|
||||
|
||||
def _load(self):
|
||||
"""Loads the Whisper model. Note that if this is the first time
|
||||
this model is being run, it will take time to download."""
|
||||
model = WhisperModel(
|
||||
self._model_name.value,
|
||||
device=self._device,
|
||||
compute_type=self._compute_type)
|
||||
self._model = model
|
||||
|
||||
async def run_stt(self, audio: BinaryIO = None) -> str:
|
||||
"""Transcribes given audio using Whisper"""
|
||||
segments, _ = await asyncio.to_thread(self._model.transcribe, audio)
|
||||
res: str = ""
|
||||
for segment in segments:
|
||||
res += f"{segment.text} "
|
||||
return res
|
||||
@@ -1,180 +0,0 @@
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from queue import Queue, Empty
|
||||
from threading import Thread, Event
|
||||
from typing import Generator
|
||||
|
||||
from dailyai.async_processor.async_processor import (
|
||||
AsyncProcessor,
|
||||
AsyncProcessorState,
|
||||
LLMResponse,
|
||||
)
|
||||
from dailyai.message_handler.message_handler import MessageHandler
|
||||
from dailyai.queue_frame import QueueFrame, FrameType
|
||||
from dailyai.services.ai_services import (
|
||||
AIServiceConfig,
|
||||
ImageGenService,
|
||||
LLMService,
|
||||
TTSService,
|
||||
)
|
||||
"""
|
||||
class MockTTSService(TTSService):
|
||||
def run_tts(self, sentence):
|
||||
for word in sentence.split(' '):
|
||||
time.sleep(0.1)
|
||||
yield bytes(word, "utf-8")
|
||||
|
||||
class MockLLMService(LLMService):
|
||||
def run_llm_async(self, messages) -> Generator[str, None, None]:
|
||||
for i in ["Hello ", "there.", "How are ", "you?", "I ", "hope ", "you ", "are ", "well."]:
|
||||
time.sleep(0.1)
|
||||
yield i
|
||||
|
||||
class MockImageService(ImageGenService):
|
||||
def run_image_gen(self, sentence) -> None:
|
||||
return None
|
||||
|
||||
class TestResponse(unittest.TestCase):
|
||||
def test_base_state_transitions(self):
|
||||
mock_tts_service = MockTTSService()
|
||||
mock_llm_service = MockLLMService()
|
||||
mock_image_service = MockImageService()
|
||||
processor = AsyncProcessor(AIServiceConfig(tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service))
|
||||
processor.prepare()
|
||||
processor.play()
|
||||
processor.finalize()
|
||||
self.assertEqual(processor.state, AsyncProcessorState.FINALIZED)
|
||||
|
||||
def test_state_transitions(self):
|
||||
output_queue = Queue()
|
||||
mock_tts_service = MockTTSService()
|
||||
mock_llm_service = MockLLMService()
|
||||
mock_image_service = MockImageService()
|
||||
message_handler = MessageHandler("Hello World")
|
||||
processor = LLMResponse(
|
||||
AIServiceConfig(
|
||||
tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service
|
||||
),
|
||||
message_handler,
|
||||
output_queue,
|
||||
)
|
||||
processor.prepare()
|
||||
processor.play()
|
||||
|
||||
# Consume the output from the output queue. It's necessary to mark these tasks as done for the
|
||||
# play function to return.
|
||||
expected_words = ["Hello", "there.", "How", "are", "you?", "I", "hope", "you", "are", "well."]
|
||||
|
||||
# remove the "start_stream" message from the queue
|
||||
output_queue.get()
|
||||
output_queue.task_done()
|
||||
|
||||
while expected_words:
|
||||
actual_word:QueueFrame = output_queue.get()
|
||||
word = expected_words.pop(0)
|
||||
self.assertEqual(actual_word.frame_type, FrameType.AUDIO_FRAME)
|
||||
self.assertEqual(actual_word.frame_data, bytes(word, "utf-8"))
|
||||
output_queue.task_done()
|
||||
|
||||
processor.finalize()
|
||||
|
||||
self.assertEqual(processor.state, AsyncProcessorState.FINALIZED)
|
||||
|
||||
def test_interrupt_preparation(self):
|
||||
output_queue = Queue()
|
||||
mock_tts_service = MockTTSService()
|
||||
mock_llm_service = MockLLMService()
|
||||
mock_image_service = MockImageService()
|
||||
message_handler = MessageHandler("System Message")
|
||||
processor = LLMResponse(
|
||||
AIServiceConfig(
|
||||
tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service
|
||||
),
|
||||
message_handler,
|
||||
output_queue,
|
||||
)
|
||||
processor.prepare()
|
||||
interrupt_request_at = time.perf_counter()
|
||||
processor.interrupt()
|
||||
processor.finalize()
|
||||
finalized_at = time.perf_counter()
|
||||
self.assertTrue(0.1 < finalized_at - interrupt_request_at < 0.2)
|
||||
print(f"delta: {interrupt_request_at, finalized_at}")
|
||||
self.assertEqual(processor.state, AsyncProcessorState.FINALIZED)
|
||||
|
||||
def test_interrupt_play(self):
|
||||
output_queue = Queue()
|
||||
mock_tts_service = MockTTSService()
|
||||
mock_llm_service = MockLLMService()
|
||||
mock_image_service = MockImageService()
|
||||
message_handler = MessageHandler("System Message")
|
||||
processor = LLMResponse(
|
||||
AIServiceConfig(
|
||||
tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service
|
||||
),
|
||||
message_handler,
|
||||
output_queue,
|
||||
)
|
||||
processor.prepare()
|
||||
processor.play()
|
||||
|
||||
stop_processing_output_queue = Event()
|
||||
def process_output_queue_async():
|
||||
# Consume the output from the output queue. It's necessary to mark these tasks as done for the
|
||||
# play function to return.
|
||||
time.sleep(0.1)
|
||||
expected_words = ["Hello", "there.", "How", "are", "you?", "I", "hope", "you", "are", "well."]
|
||||
while expected_words and not stop_processing_output_queue.is_set():
|
||||
try:
|
||||
actual_word:QueueFrame = output_queue.get_nowait()
|
||||
if actual_word.frame_type == FrameType.AUDIO_FRAME:
|
||||
time.sleep(0.1)
|
||||
word = expected_words.pop(0)
|
||||
self.assertEqual(actual_word.frame_type, FrameType.AUDIO_FRAME)
|
||||
self.assertEqual(actual_word.frame_data, bytes(word, "utf-8"))
|
||||
output_queue.task_done()
|
||||
except Empty:
|
||||
pass
|
||||
|
||||
process_output_queue = Thread(target=process_output_queue_async, daemon=True)
|
||||
process_output_queue.start()
|
||||
|
||||
time.sleep(0.5)
|
||||
processor.interrupt()
|
||||
|
||||
stop_processing_output_queue.set()
|
||||
process_output_queue.join()
|
||||
|
||||
processor.finalize()
|
||||
self.assertEqual(processor.state, AsyncProcessorState.FINALIZED)
|
||||
|
||||
def test_statechange_callback(self):
|
||||
mock_tts_service = MockTTSService()
|
||||
mock_llm_service = MockLLMService()
|
||||
mock_image_service = MockImageService()
|
||||
processor = AsyncProcessor(
|
||||
AIServiceConfig(
|
||||
tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service
|
||||
)
|
||||
)
|
||||
is_finalized = False
|
||||
def set_is_finalized(async_processor:AsyncProcessor):
|
||||
nonlocal is_finalized
|
||||
is_finalized = True
|
||||
|
||||
processor.set_state_callback(
|
||||
AsyncProcessorState.FINALIZED, set_is_finalized
|
||||
)
|
||||
processor.prepare()
|
||||
self.assertFalse(is_finalized)
|
||||
processor.play()
|
||||
self.assertFalse(is_finalized)
|
||||
processor.finalize()
|
||||
self.assertTrue(is_finalized)
|
||||
self.assertEqual(processor.state, AsyncProcessorState.FINALIZED)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
"""
|
||||
@@ -1,147 +0,0 @@
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from unittest.mock import MagicMock, call
|
||||
|
||||
from dailyai.message_handler.message_handler import MessageHandler, IndexingMessageHandler
|
||||
from dailyai.services.ai_services import (
|
||||
AIServiceConfig,
|
||||
TTSService,
|
||||
LLMService,
|
||||
ImageGenService,
|
||||
)
|
||||
from ..storage.search import SearchIndexer
|
||||
|
||||
|
||||
class TestMessageHandler(unittest.TestCase):
|
||||
def test_simple_intro(self):
|
||||
message_handler = MessageHandler("Hello world")
|
||||
self.assertEqual(
|
||||
message_handler.get_llm_messages(),
|
||||
[{"role": "system", "content": "Hello world"}],
|
||||
)
|
||||
|
||||
def test_simple_user_message(self):
|
||||
message_handler = MessageHandler("System prompt")
|
||||
message_handler.add_user_message("User message")
|
||||
self.assertEqual(
|
||||
message_handler.get_llm_messages(),
|
||||
[
|
||||
{"role": "system", "content": "System prompt"},
|
||||
{"role": "user", "content": "User message"},
|
||||
],
|
||||
)
|
||||
|
||||
def test_simple_user_and_assistant_message(self):
|
||||
message_handler = MessageHandler("System prompt")
|
||||
message_handler.add_user_message("User message")
|
||||
message_handler.add_assistant_message("Assistant message")
|
||||
self.assertEqual(
|
||||
message_handler.get_llm_messages(),
|
||||
[
|
||||
{"role": "system", "content": "System prompt"},
|
||||
{"role": "user", "content": "User message"},
|
||||
{"role": "assistant", "content": "Assistant message"},
|
||||
],
|
||||
)
|
||||
|
||||
def test_user_message_overwrite(self):
|
||||
message_handler = MessageHandler("System prompt")
|
||||
message_handler.add_user_message("User message")
|
||||
message_handler.add_assistant_message("Assistant message")
|
||||
message_handler.add_user_message("plus something else")
|
||||
self.assertEqual(
|
||||
message_handler.get_llm_messages(),
|
||||
[
|
||||
{"role": "system", "content": "System prompt"},
|
||||
{"role": "user", "content": "User message plus something else"},
|
||||
],
|
||||
)
|
||||
|
||||
def test_user_message_after_assistant(self):
|
||||
message_handler = MessageHandler("System prompt")
|
||||
message_handler.add_user_message("User message")
|
||||
message_handler.add_assistant_message("Assistant message")
|
||||
message_handler.finalize_user_message()
|
||||
message_handler.add_user_message("other user message")
|
||||
self.assertEqual(
|
||||
message_handler.get_llm_messages(),
|
||||
[
|
||||
{"role": "system", "content": "System prompt"},
|
||||
{"role": "user", "content": "User message"},
|
||||
{"role": "assistant", "content": "Assistant message"},
|
||||
{"role": "user", "content": "other user message"},
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class MockTTSService(TTSService):
|
||||
def run_tts(self, sentence):
|
||||
for word in sentence.split(" "):
|
||||
time.sleep(0.1)
|
||||
yield bytes(word, "utf-8")
|
||||
|
||||
|
||||
class MockLLMService(LLMService):
|
||||
def run_llm(self, messages) -> str:
|
||||
return "Parsed user message."
|
||||
|
||||
class MockImageService(ImageGenService):
|
||||
def run_image_gen(self, sentence) -> None:
|
||||
return None
|
||||
|
||||
|
||||
class TestStorageMessageHandler(unittest.TestCase):
|
||||
def test_user_message_finalized(self):
|
||||
mock_tts_service = MockTTSService()
|
||||
mock_llm_service = MockLLMService()
|
||||
mock_image_service = MockImageService()
|
||||
|
||||
service_config = AIServiceConfig(
|
||||
tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service
|
||||
)
|
||||
|
||||
mock_indexer = MagicMock(spec=SearchIndexer)
|
||||
|
||||
message_handler = IndexingMessageHandler(
|
||||
"Hello world", service_config, mock_indexer
|
||||
)
|
||||
message_handler.cleanup_user_message = MagicMock(return_value="Parsed user message.")
|
||||
message_handler.add_user_message("User message")
|
||||
message_handler.add_assistant_message("Assistant message will be ignored")
|
||||
message_handler.add_user_message("plus something else")
|
||||
message_handler.finalize_user_message()
|
||||
message_handler.add_assistant_message(
|
||||
"New assistant message will not be ignored"
|
||||
)
|
||||
message_handler.add_user_message("User message second time")
|
||||
message_handler.add_assistant_message("Assistant message second time")
|
||||
message_handler.write_messages_to_storage()
|
||||
|
||||
time.sleep(0.5)
|
||||
message_handler.cleanup_user_message.assert_called_with("User message plus something else")
|
||||
self.assertEqual(
|
||||
mock_indexer.mock_calls,
|
||||
[
|
||||
call.index_text('"Parsed user message."'),
|
||||
call.index_text("New assistant message will not be ignored"),
|
||||
],
|
||||
)
|
||||
|
||||
mock_indexer.reset_mock()
|
||||
|
||||
message_handler.finalize_user_message()
|
||||
|
||||
time.sleep(0.5)
|
||||
|
||||
self.assertEqual(
|
||||
mock_indexer.mock_calls,
|
||||
[
|
||||
call.index_text('"Parsed user message."'),
|
||||
call.index_text("Assistant message second time"),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -4,6 +4,7 @@ import asyncio
|
||||
from dailyai.queue_frame import TextQueueFrame
|
||||
from dailyai.services.daily_transport_service import DailyTransportService
|
||||
from dailyai.services.open_ai_services import OpenAIImageGenService
|
||||
from dailyai.services.azure_ai_services import AzureImageGenServiceREST
|
||||
|
||||
local_joined = False
|
||||
participant_joined = False
|
||||
@@ -21,7 +22,7 @@ async def main(room_url):
|
||||
transport.camera_width = 1024
|
||||
transport.camera_height = 1024
|
||||
|
||||
imagegen = OpenAIImageGenService(image_size="1024x1024")
|
||||
imagegen = AzureImageGenServiceREST(image_size="1024x1024")
|
||||
image_task = asyncio.create_task(
|
||||
imagegen.run_to_queue(transport.send_queue, [TextQueueFrame("a cat in the style of picasso")])
|
||||
)
|
||||
|
||||
@@ -2,7 +2,7 @@ import argparse
|
||||
import asyncio
|
||||
|
||||
from dailyai.queue_frame import AudioQueueFrame, ImageQueueFrame
|
||||
from dailyai.services.azure_ai_services import AzureLLMService
|
||||
from dailyai.services.azure_ai_services import AzureImageGenServiceREST, AzureLLMService
|
||||
from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
|
||||
from dailyai.services.daily_transport_service import DailyTransportService
|
||||
from dailyai.services.fal_ai_services import FalImageGenService
|
||||
@@ -22,9 +22,9 @@ async def main(room_url):
|
||||
transport.camera_height = 1024
|
||||
|
||||
llm = AzureLLMService()
|
||||
dalle = FalImageGenService(image_size="1024x1024")
|
||||
#dalle = FalImageGenService(image_size="1024x1024")
|
||||
tts = ElevenLabsTTSService(voice_id="ErXwobaYiN019PkySvjV")
|
||||
# dalle = OpenAIImageGenService(image_size="1024x1024")
|
||||
dalle = AzureImageGenServiceREST(image_size="1024x1024")
|
||||
|
||||
# Get a complete audio chunk from the given text. Splitting this into its own
|
||||
# coroutine lets us ensure proper ordering of the audio chunks on the send queue.
|
||||
|
||||
98
src/samples/foundational/07-interruptible.py
Normal file
98
src/samples/foundational/07-interruptible.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import requests
|
||||
import time
|
||||
import urllib.parse
|
||||
from dailyai.conversation_wrappers import InterruptibleConversationWrapper
|
||||
|
||||
from dailyai.queue_frame import StartStreamQueueFrame, TextQueueFrame
|
||||
from dailyai.services.daily_transport_service import DailyTransportService
|
||||
from dailyai.services.azure_ai_services import AzureLLMService
|
||||
from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
|
||||
|
||||
|
||||
async def main(room_url:str, token):
|
||||
global transport
|
||||
global llm
|
||||
global tts
|
||||
|
||||
transport = DailyTransportService(
|
||||
room_url,
|
||||
token,
|
||||
"Respond bot",
|
||||
5,
|
||||
)
|
||||
transport.mic_enabled = True
|
||||
transport.mic_sample_rate = 16000
|
||||
transport.camera_enabled = False
|
||||
|
||||
llm = AzureLLMService()
|
||||
tts = ElevenLabsTTSService(voice_id="ErXwobaYiN019PkySvjV")
|
||||
|
||||
async def run_response(user_speech, tma_in, tma_out):
|
||||
await tts.run_to_queue(
|
||||
transport.send_queue,
|
||||
tma_out.run(
|
||||
llm.run(
|
||||
tma_in.run(
|
||||
[StartStreamQueueFrame(), TextQueueFrame(user_speech)]
|
||||
)
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
@transport.event_handler("on_first_other_participant_joined")
|
||||
async def on_first_other_participant_joined(transport):
|
||||
await tts.say("Hi, I'm listening!", transport.send_queue)
|
||||
|
||||
async def run_conversation():
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio. Respond to what the user said in a creative and helpful way."},
|
||||
]
|
||||
|
||||
conversation_wrapper = InterruptibleConversationWrapper(
|
||||
frame_generator=transport.get_receive_frames,
|
||||
runner=run_response,
|
||||
interrupt=transport.interrupt,
|
||||
my_participant_id=transport.my_participant_id,
|
||||
llm_messages=messages,
|
||||
)
|
||||
await conversation_wrapper.run_conversation()
|
||||
|
||||
transport.transcription_settings["extra"]["punctuate"] = False
|
||||
await asyncio.gather(transport.run(), run_conversation())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Simple Daily Bot Sample")
|
||||
parser.add_argument(
|
||||
"-u", "--url", type=str, required=True, help="URL of the Daily room to join"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-k",
|
||||
"--apikey",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Daily API Key (needed to create token)",
|
||||
)
|
||||
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
# Create a meeting token for the given room with an expiration 1 hour in the future.
|
||||
room_name: str = urllib.parse.urlparse(args.url).path[1:]
|
||||
expiration: float = time.time() + 60 * 60
|
||||
|
||||
res: requests.Response = requests.post(
|
||||
f"https://api.daily.co/v1/meeting-tokens",
|
||||
headers={"Authorization": f"Bearer {args.apikey}"},
|
||||
json={
|
||||
"properties": {"room_name": room_name, "is_owner": True, "exp": expiration}
|
||||
},
|
||||
)
|
||||
|
||||
if res.status_code != 200:
|
||||
raise Exception(f"Failed to create meeting token: {res.status_code} {res.text}")
|
||||
|
||||
token: str = res.json()["token"]
|
||||
|
||||
asyncio.run(main(args.url, token))
|
||||
44
src/samples/foundational/07-whisper-transcription.py
Normal file
44
src/samples/foundational/07-whisper-transcription.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
|
||||
from dailyai.services.daily_transport_service import DailyTransportService
|
||||
from dailyai.services.whisper_ai_services import WhisperSTTService
|
||||
|
||||
|
||||
async def main(room_url: str):
|
||||
global transport
|
||||
global stt
|
||||
|
||||
transport = DailyTransportService(
|
||||
room_url,
|
||||
None,
|
||||
"Transcription bot",
|
||||
)
|
||||
transport.mic_enabled = False
|
||||
transport.camera_enabled = False
|
||||
transport.speaker_enabled = True
|
||||
stt = WhisperSTTService()
|
||||
transcription_output_queue = asyncio.Queue()
|
||||
|
||||
async def handle_transcription():
|
||||
print("`````````TRANSCRIPTION`````````")
|
||||
while True:
|
||||
item = await transcription_output_queue.get()
|
||||
print(item.text)
|
||||
|
||||
async def handle_speaker():
|
||||
await stt.run_to_queue(
|
||||
transcription_output_queue,
|
||||
transport.get_receive_frames()
|
||||
)
|
||||
await asyncio.gather(transport.run(), handle_speaker(), handle_transcription())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Simple Daily Bot Sample")
|
||||
parser.add_argument(
|
||||
"-u", "--url", type=str, required=True, help="URL of the Daily room to join"
|
||||
)
|
||||
|
||||
args, unknown = parser.parse_known_args()
|
||||
asyncio.run(main(args.url))
|
||||
Reference in New Issue
Block a user