Compare commits

...

8 Commits

Author SHA1 Message Date
Moishe Lettvin
646db8b9bd cleanup continues 2024-01-26 07:57:41 -05:00
Moishe Lettvin
42c142aff0 ... 2024-01-25 14:55:51 -05:00
Moishe Lettvin
6da78dbf9c getting started on cleanup 2024-01-25 13:50:10 -05:00
Moishe Lettvin
f0d9b0613e Add faster_whisper to module dependencies; remove unneeded import 2024-01-25 11:27:00 -05:00
Moishe Lettvin
a661905d7f Merge pull request #9 from daily-co/interruptions
Interruptable conversation wrapper
2024-01-25 11:24:57 -05:00
Moishe Lettvin
c9c2e5f561 Remove unnecessary try/except 2024-01-25 11:18:55 -05:00
Moishe Lettvin
795a339542 Add InterruptibleConversationWrapper 2024-01-25 11:15:04 -05:00
Liza
31db156dfc Local Whisper transcription (#10)
* First pass at Whisper transcription

* deletions

* Revise based on feedback, add autopep8
2024-01-25 13:43:25 +01:00
25 changed files with 550 additions and 1353 deletions

View File

@@ -16,7 +16,8 @@ dependencies = [
"pyht",
"opentelemetry-sdk",
"aiohttp",
"fal"
"fal",
"faster_whisper"
]
[tool.setuptools.packages.find]

View File

@@ -1,3 +1,4 @@
autopep8==2.0.4
build==1.0.3
packaging==23.2
pyproject_hooks==1.0.0
pyproject_hooks==1.0.0

View File

@@ -1,347 +0,0 @@
import json
import logging
import re
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
from queue import Queue, PriorityQueue, Empty
from threading import Event, Semaphore, Thread
from typing import Any, Generator, Iterator, Optional, Type
from dailyai.queue_frame import QueueFrame, FrameType
from dailyai.message_handler.message_handler import MessageHandler
from dailyai.services.ai_services import AIServiceConfig
class AsyncProcessorState:
# Setting class variables, other synchronous activities
INIT = 0
# Making asynchronous requests to LLM and other services to render response
PREPARING = 1
# Ready to start presenting to user (but may not have all data yet)
READY = 2
# Playing response
PLAYING = 3
# An interrupt has been requested and the response is shutting down in-flight processing
INTERRUPTING = 4
# An interrupt has been requested and the response is finished stopping in-flight processing
INTERRUPTED = 5
# Response has been played or interrupted
DONE = 6
# Response is being finalized (updating records of speech, updating LLM context, etc.)
FINALIZING = 7
# Response is complete. This could mean that everything is updated, or that the response
# was interrupted.
FINALIZED = 8
state_transitions = {
INIT: [PREPARING, INTERRUPTING],
PREPARING: [READY, INTERRUPTING],
READY: [PLAYING, INTERRUPTING],
PLAYING: [DONE, INTERRUPTING],
INTERRUPTING: [INTERRUPTED],
INTERRUPTED: [DONE],
DONE: [FINALIZING],
FINALIZING: [FINALIZED],
FINALIZED: [FINALIZED],
}
@dataclass(order=True)
class StateTransitionItem:
state: int
evt: Event = field(compare=False)
class AsyncProcessor:
def __init__(
self,
services: AIServiceConfig
) -> None:
self.state = AsyncProcessorState.INIT
self.prepare_thread = None
self.play_thread = None
self.finalize_thread = None
self.services: AIServiceConfig = services
self.state_transition_semaphore = Semaphore()
self.waiting_for_state_changes = PriorityQueue()
self.state_queue = Queue()
self.state_change_callbacks = defaultdict(list)
self.was_interrupted = False
self.logger: logging.Logger = logging.getLogger("dailyai")
def set_state(self, state: int) -> None:
if state in AsyncProcessorState.state_transitions[self.state]:
self.state_transition_semaphore.acquire()
self.state: int = state
self.state_transition_semaphore.release()
# wake up any threads waiting for this state transition
try:
while True:
waiter = self.waiting_for_state_changes.get_nowait()
if waiter.state <= state:
waiter.evt.set()
else:
self.waiting_for_state_changes.put(waiter)
break
except Empty:
pass
# make all the callbacks for this state
for callback in self.state_change_callbacks[state]:
callback(self)
else:
self.logger.error(
f"Invalid state transition from {self.state} to {state} in {self.__class__.__name__}"
)
raise Exception(f"Invalid state transition from {self.state} to {state}")
#
# This is used for state transitions that could be blocked by an interruption.
# If we are interrupted, we silently fail this call. Use only if you know that
# this state transition should fail if the processor has been interrupted.
#
def maybe_set_state(self, state: int) -> bool:
if state in AsyncProcessorState.state_transitions[self.state]:
self.set_state(state)
return True
else:
return False
def wait_for_state_transition(self, state: int) -> None:
if self.state >= state:
return
self.state_transition_semaphore.acquire()
evt = Event()
self.waiting_for_state_changes.put(StateTransitionItem(state, evt))
self.state_transition_semaphore.release()
result = evt.wait(120.0)
if not result:
self.logger.error(
f"Timed out waiting for state transition to {state} from {self.state}"
)
def set_state_callback(self, state: int, callback: callable) -> None:
self.state_change_callbacks[state].append(callback)
def prepare(self) -> None:
self.prepare_thread = Thread(target=self.async_prepare, daemon=True)
self.prepare_thread.start()
self.wait_for_state_transition(AsyncProcessorState.READY)
def play(self) -> None:
self.wait_for_state_transition(AsyncProcessorState.READY)
self.play_thread = Thread(target=self.async_play, daemon=True)
self.play_thread.start()
self.wait_for_state_transition(AsyncProcessorState.PLAYING)
def finalize(self) -> None:
# don't finalize until we're done playing.
self.wait_for_state_transition(AsyncProcessorState.DONE)
self.set_state(AsyncProcessorState.FINALIZING)
self.do_finalization()
self.set_state(AsyncProcessorState.FINALIZED)
def interrupt(self) -> None:
# nothing to interrupt if we're already finalizing or finalized, no-op
if self.state in [
AsyncProcessorState.FINALIZING,
AsyncProcessorState.FINALIZED,
]:
return
self.set_state(AsyncProcessorState.INTERRUPTING)
self.was_interrupted = True
self.do_interruption()
self.set_state(AsyncProcessorState.INTERRUPTED)
self.set_state(AsyncProcessorState.DONE)
def async_play(self) -> None:
self.logger.info(f"Starting to play")
if self.maybe_set_state(AsyncProcessorState.PLAYING):
self.do_play()
self.maybe_set_state(AsyncProcessorState.DONE)
def async_prepare(self) -> None:
self.set_state(AsyncProcessorState.PREPARING)
self.start_preparation()
self.set_state(AsyncProcessorState.READY)
self.continue_preparation()
self.logger.info(f"Preparation done for {self.__class__.__name__}")
self.preparation_done()
def start_preparation(self) -> None:
pass
def continue_preparation(self) -> None:
pass
def preparation_done(self):
pass
def get_preparation_iterator(self) -> Iterator:
yield None
def process_chunk(self, chunk) -> None:
pass
def do_interruption(self) -> None:
pass
def do_play(self) -> None:
pass
def do_finalization(self) -> None:
pass
# A common class for responses that use a message queue and
# an output queue.
class OrchestratorResponse(AsyncProcessor):
def __init__(
self,
services,
message_handler,
output_queue,
) -> None:
super().__init__(services)
self.message_handler: MessageHandler = message_handler
self.output_queue: Queue = output_queue
class LLMResponse(OrchestratorResponse):
def __init__(
self,
services,
message_handler,
output_queue,
) -> None:
super().__init__(services, message_handler, output_queue)
self.has_sent_first_frame = False
self.chunks_in_preparation = Queue()
self.llm_responses: list[str] = []
def get_preparation_iterator(self) -> Iterator:
messages_for_llm = self.message_handler.get_llm_messages()
self.logger.debug(f"Messages for llm: {json.dumps(messages_for_llm, indent=2)}")
return self.clauses_from_chunks(
self.services.llm.run_llm_async(messages_for_llm)
)
def clauses_from_chunks(self, chunks) -> Iterator:
out = ""
for chunk in chunks:
if self.state not in [
AsyncProcessorState.READY,
AsyncProcessorState.PLAYING,
]:
break
out += chunk
if re.match(r"^.*[.!?]$", out): # it looks like a sentence
yield out.strip()
out = ""
if out.strip():
yield out.strip()
def get_frames_from_tts_response(self, audio_frame) -> list[QueueFrame]:
return [QueueFrame(FrameType.AUDIO, audio_frame)]
def get_frames_from_chunk(self, chunk) -> Generator[list[QueueFrame], Any, None]:
for audio_frame in self.services.tts.run_tts(chunk):
yield self.get_frames_from_tts_response(audio_frame)
def start_preparation(self) -> None:
self.preparation_iterator = self.get_preparation_iterator()
def continue_preparation(self) -> None:
for chunk in self.preparation_iterator:
if self.state not in [
AsyncProcessorState.READY,
AsyncProcessorState.PLAYING,
]:
break
self.process_chunk(chunk)
def process_chunk(self, chunk) -> None:
self.chunks_in_preparation.put((chunk, self.get_frames_from_chunk(chunk)))
def preparation_done(self):
self.chunks_in_preparation.put((None, None))
def do_play(self) -> None:
while True:
if self.state not in [
AsyncProcessorState.READY,
AsyncProcessorState.PLAYING,
]:
break
prepared_chunk = self.chunks_in_preparation.get()
if prepared_chunk[0] == None:
return
self.play_prepared_chunk(prepared_chunk)
def play_prepared_chunk(self, prepared_chunk) -> None:
chunk, tts_generator = prepared_chunk
for frames in tts_generator:
if self.state not in [
AsyncProcessorState.READY,
AsyncProcessorState.PLAYING,
]:
break
if not self.has_sent_first_frame:
self.output_queue.put(QueueFrame(FrameType.START_STREAM, None))
self.has_sent_first_frame = True
for frame in frames:
self.output_queue.put(frame)
self.output_queue.join()
self.llm_responses.append(chunk)
def do_finalization(self) -> None:
self.message_handler.add_assistant_messages(self.llm_responses)
def do_interruption(self) -> None:
self.chunks_in_preparation.put((None, None))
if self.prepare_thread and self.prepare_thread.is_alive():
self.prepare_thread.join()
if self.play_thread and self.play_thread.is_alive():
self.play_thread.join()
@dataclass(frozen=True)
class ConversationProcessorCollection:
introduction: Optional[Type[OrchestratorResponse]] = None
waiting: Optional[Type[OrchestratorResponse]] = None
response: Optional[Type[OrchestratorResponse]] = None
goodbye: Optional[Type[OrchestratorResponse]] = None

View File

@@ -0,0 +1,76 @@
import asyncio
import copy
import functools
from typing import AsyncGenerator, Awaitable, Callable
from dailyai.queue_aggregators import LLMContextAggregator
from dailyai.queue_frame import EndStreamQueueFrame, QueueFrame, TranscriptionQueueFrame
class InterruptibleConversationWrapper:
def __init__(
self,
frame_generator: Callable[[], AsyncGenerator[QueueFrame, None]],
runner: Callable[
[str, LLMContextAggregator, LLMContextAggregator], Awaitable[None]
],
interrupt: Callable[[], None],
my_participant_id: str|None,
llm_messages: list[dict[str, str]],
llm_context_aggregator_in=LLMContextAggregator,
llm_context_aggregator_out=LLMContextAggregator,
delay_before_speech_seconds: float = 1.0,
):
self._frame_generator: Callable[[], AsyncGenerator[QueueFrame, None]] = frame_generator
self._runner: Callable[
[str, LLMContextAggregator, LLMContextAggregator], Awaitable[None]
] = runner
self._interrupt: Callable[[], None] = interrupt
self._my_participant_id = my_participant_id
self._messages: list[dict[str, str]] = llm_messages
self._delay_before_speech_seconds = delay_before_speech_seconds
self._llm_context_aggregator_in = llm_context_aggregator_in
self._llm_context_aggregator_out = llm_context_aggregator_out
self._current_phrase = ""
def update_messages(self, new_messages: list[dict[str, str]], task: asyncio.Task | None):
if task:
if not task.cancelled():
self._current_phrase = ""
self._messages = new_messages
async def speak_after_delay(self, user_speech, messages):
await asyncio.sleep(self._delay_before_speech_seconds)
tma_in = self._llm_context_aggregator_in(
messages, "user", self._my_participant_id, False
)
tma_out = self._llm_context_aggregator_out(
messages, "assistant", self._my_participant_id
)
await self._runner(user_speech, tma_in, tma_out)
async def run_conversation(self):
current_response_task = None
async for frame in self._frame_generator():
if isinstance(frame, EndStreamQueueFrame):
break
elif not isinstance(frame, TranscriptionQueueFrame):
continue
if frame.participantId == self._my_participant_id:
continue
if current_response_task:
current_response_task.cancel()
self._interrupt()
self._current_phrase += " " + frame.text
current_llm_messages = copy.deepcopy(self._messages)
current_response_task = asyncio.create_task(
self.speak_after_delay(self._current_phrase, current_llm_messages)
)
current_response_task.add_done_callback(
functools.partial(self.update_messages, current_llm_messages)
)

View File

@@ -1,127 +0,0 @@
import logging
import time
from dataclasses import dataclass
from queue import Queue, Empty
from threading import Thread
from dailyai.storage.search import SearchIndexer
from dailyai.services.ai_services import AIServiceConfig
@dataclass
class Message:
type: str
timestamp: float
message: str
class MessageHandler:
def __init__(self, intro):
self.messages: list[Message] = [Message("system", time.time(), intro)]
self.last_user_message_idx:int | None = None
self.finalized_user_message_idx: int | None = None
def add_user_message(self, message) -> None:
if self.last_user_message_idx is not None and self.last_user_message_idx != self.finalized_user_message_idx:
previous_message: str = self.messages[self.last_user_message_idx].message
self.messages[self.last_user_message_idx] = Message(
"user", time.time(), ' '.join([previous_message, message])
)
self.messages = self.messages[: self.last_user_message_idx + 1]
else:
self.messages.append(Message("user", time.time(), message))
self.last_user_message_idx = len(self.messages) - 1
def add_assistant_message(self, message) -> None:
if self.messages[-1].type == "assistant":
self.messages[-1].message += " " + message
else:
self.messages.append(Message("assistant", time.time(), message))
def add_assistant_messages(self, messages) -> None:
self.messages.append(Message("assistant", time.time(), " ".join(messages)))
def get_llm_messages(self) -> list[dict[str, str]]:
return [{"role": m.type, "content": m.message} for m in self.messages]
def finalize_user_message(self) -> None:
self.finalized_user_message_idx = self.last_user_message_idx
def shutdown(self) -> None:
pass
class IndexingMessageHandler(MessageHandler):
def __init__(
self, intro, services: AIServiceConfig, indexer: SearchIndexer
) -> None:
super().__init__(intro)
self.services = services
self.search_indexer = indexer
self.last_written_idx = 0
self.storage_message_queue = Queue()
self.index_writer_thread = Thread(target=self.storage_writer, daemon=True)
self.index_writer_thread.start()
self.logger = logging.getLogger("dailyai")
def shutdown(self):
self.finalize_user_message()
self.storage_message_queue.put(None)
self.index_writer_thread.join()
def storage_writer(self) -> None:
while True:
try:
message_idx = self.storage_message_queue.get()
self.storage_message_queue.task_done()
if message_idx is None:
return
if message_idx <= self.last_written_idx:
continue
self.last_written_idx = message_idx
message = self.messages[message_idx]
content = message.message
if message.type == "user":
content = self.cleanup_user_message(content)
# sometimes the LLM returns a string wrapped in quotes and sometimes it doesn't.
# if it didn't, wrap it in quotes
if content[0] != '"':
content = '"' + content + '"'
self.search_indexer.index_text(content)
except Empty:
pass
def cleanup_user_message(self, user_message) -> str:
return user_message
def finalize_user_message(self):
super().finalize_user_message()
self.write_messages_to_storage()
def write_messages_to_storage(self):
if self.finalized_user_message_idx is None:
return
for idx in range(self.last_written_idx, len(self.messages)):
self.logger.info(
f"Writing to storage: {self.messages[idx].type} {self.messages[idx].message}"
)
if (
self.messages[idx].type == "user"
and idx > self.finalized_user_message_idx
):
break
if self.messages[idx].type != "system":
self.storage_message_queue.put(idx)

View File

@@ -1,409 +0,0 @@
import logging
import os
import time
import wave
from dataclasses import dataclass
from enum import Enum
from queue import Queue, Empty
from opentelemetry import trace, context
from dailyai.async_processor.async_processor import (
AsyncProcessor,
AsyncProcessorState,
ConversationProcessorCollection,
OrchestratorResponse,
LLMResponse,
)
from dailyai.queue_frame import QueueFrame, FrameType
from dailyai.services.ai_services import AIServiceConfig
from dailyai.message_handler.message_handler import MessageHandler
from threading import Thread, Semaphore, Event, Timer
from opentelemetry import context
from opentelemetry.context.context import Context
from daily import (
EventHandler,
CallClient,
Daily,
VirtualCameraDevice,
VirtualMicrophoneDevice,
VirtualSpeakerDevice,
)
@dataclass
class OrchestratorConfig:
room_url: str
token: str
bot_name: str
expiration: float
# Note that we use this as a default parameter value in the Orchestrator
# constructor. The dataclass is defined with Frozen=True, so this should
# be safe.
default_conversation_collection = ConversationProcessorCollection(
introduction=LLMResponse,
waiting=None,
response=LLMResponse,
goodbye=None,
)
class Orchestrator(EventHandler):
def __init__(
self,
daily_config: OrchestratorConfig,
ai_service_config: AIServiceConfig,
message_handler: MessageHandler,
conversation_processors: ConversationProcessorCollection = default_conversation_collection,
tracer=None,
):
self.bot_name: str = daily_config.bot_name
self.room_url: str = daily_config.room_url
self.token: str = daily_config.token
self.expiration: float = daily_config.expiration
self.logger: logging.Logger = logging.getLogger("dailyai")
self.tracer = tracer or trace.get_tracer("orchestrator")
self.ctx: Context = context.get_current()
self.transcription = ""
self.last_fragment_at = None
self.talked_at = None
self.paused_at = None
self.logger.info(f"Creating Response for introductions")
self.services: AIServiceConfig = ai_service_config
self.output_queue = Queue()
self.is_interrupted = Event()
self.stop_threads = Event()
self.story_started = False
self.message_handler = message_handler
self.conversation_processors: ConversationProcessorCollection = conversation_processors
if conversation_processors.introduction is not None:
intro = conversation_processors.introduction(
services=self.services, message_handler=self.message_handler, output_queue=self.output_queue
)
intro.prepare()
intro.set_state_callback(AsyncProcessorState.DONE, self.on_intro_played)
intro.set_state_callback(AsyncProcessorState.FINALIZED, self.on_intro_finished)
self.logger.info(f"Introduction is preparing")
self.current_response: AsyncProcessor = intro
self.can_interrupt = False
# self.response_event.set()
self.response_semaphore = Semaphore()
self.speech_timeout = None
self.interrupt_time = None
self.logger.info("Configuring daily")
self.configure_daily()
def configure_daily(self):
Daily.init()
self.client = CallClient(event_handler=self)
self.logger.info(f"Mic sample rate: {self.services.tts.get_mic_sample_rate()}")
self.mic: VirtualMicrophoneDevice = Daily.create_microphone_device(
"mic", sample_rate=self.services.tts.get_mic_sample_rate(), channels=1
)
self.speaker: VirtualSpeakerDevice = Daily.create_speaker_device(
"speaker", sample_rate=16000, channels=1
)
self.camera: VirtualCameraDevice = Daily.create_camera_device(
"camera", width=720, height=1280, color_format="RGB"
)
Daily.select_speaker_device("speaker")
self.client.set_user_name(self.bot_name)
self.client.join(self.room_url, self.token, completion=self.call_joined)
self.client.update_inputs(
{
"camera": {
"isEnabled": True,
"settings": {
"deviceId": "camera",
},
},
"microphone": {
"isEnabled": True,
"settings": {
"deviceId": "mic",
"customConstraints": {
"autoGainControl": {"exact": False},
"echoCancellation": {"exact": False},
"noiseSuppression": {"exact": False},
},
},
},
}
)
self.client.update_publishing(
{
"camera": {
"sendSettings": {
"maxQuality": "low",
"encodings": {
"low": {
"maxBitrate": 250000,
"scaleResolutionDownBy": 1.333,
"maxFramerate": 8,
}
},
}
}
}
)
self.my_participant_id = self.client.participants()["local"]["id"]
def start(self) -> None:
# TODO: this loop could, I think, be replaced with a timer and an event
self.participant_left = False
try:
participant_count: int = len(self.client.participants())
self.logger.info(f"{participant_count} participants in room")
while time.time() < self.expiration and not self.participant_left:
# all handling of incoming transcriptions happens in on_transcription_message
time.sleep(1)
except Exception as e:
self.logger.error(f"Exception {e}")
finally:
self.client.leave()
def stop(self):
self.logger.info("Stop current response")
if self.current_response:
if self.current_response.state < AsyncProcessorState.INTERRUPTED:
self.current_response.interrupt()
self.logger.info("Wait for state transition")
self.current_response.wait_for_state_transition(AsyncProcessorState.FINALIZED)
self.stop_threads.set()
self.camera_thread.join()
self.logger.info("Camera thread stopped")
self.logger.info("Put stop in output queue")
self.output_queue.put(QueueFrame(FrameType.END_STREAM, None))
self.frame_consumer_thread.join()
self.logger.info("Orchestrator stopped.")
def on_intro_played(self, intro):
self.logger.info(f"Introduction has played")
self.can_interrupt = True
intro.finalize()
def on_intro_finished(self, intro):
self.logger.info(f"Introduction has finished")
waiting = self.conversation_processors.waiting(self.services, self.message_handler, self.output_queue)
waiting.prepare()
waiting.play()
def on_response_played(self, response):
response.finalize()
def on_response_finished(self, response):
if not response.was_interrupted:
self.message_handler.finalize_user_message()
def call_joined(self, join_data, client_error):
self.logger.info(f"Call_joined: {join_data}, {client_error}")
self.client.start_transcription(
{
"language": "en",
"tier": "nova",
"model": "2-conversationalai",
"profanity_filter": True,
"redact": False,
"extra": {
"endpointing": True,
"punctuate": False,
}
}
)
def on_participant_joined(self, participant):
with self.tracer.start_as_current_span("on_participant_joined", context=self.ctx):
self.logger.info(f"on_participant_joined: {participant}")
# TODO: figure out the architecture to get the story id to the client
# self.client.send_app_message({"event": "story-id", "storyID": self.story_id})
time.sleep(2)
if not self.story_started:
self.action()
self.story_started = True
def on_participant_left(self, participant, reason):
self.logger.info(f"Participant {participant} left")
if len(self.client.participants()) < 2:
self.participant_left = True
def on_app_message(self, message, sender):
with self.tracer.start_as_current_span("on_app_message", context=self.ctx):
self.logger.info(f"on_app_message {message} from {sender}")
if "isSpeaking" in message and message["isSpeaking"] == True:
self.handle_user_started_talking()
if "isSpeaking" in message and message["isSpeaking"] == False:
self.handle_user_stopped_talking()
def on_transcription_message(self, message):
with self.tracer.start_as_current_span("on_transcription_message", context=self.ctx):
if message["session_id"] != self.my_participant_id:
self.handle_transcription_fragment(message['text'])
def on_transcription_stopped(self, stopped_by, stopped_by_error):
self.logger.info(f"Transcription stopped {stopped_by}, {stopped_by_error}")
def on_transcription_error(self, message):
self.logger.error(f"Transcription error {message}")
def on_transcription_started(self, status):
self.logger.info(f"Transcription started {status}")
def set_image(self, image: bytes):
self.image: bytes | None = image
def run_camera(self):
try:
while not self.stop_threads.is_set():
if self.image:
self.camera.write_frame(self.image)
time.sleep(1.0 / 8.0) # 8 fps
except Exception as e:
self.logger.error(f"Exception {e} in camera thread.")
def handle_user_started_talking(self):
# TODO: allow configuration of the timer timeout
self.logger.error("user started talking")
self.speech_timeout = Timer(1.0, self.utterance_interrupt)
def handle_user_stopped_talking(self):
self.logger.error("user stopped talking, canceling utterance interrupt")
if self.speech_timeout:
self.speech_timeout.cancel()
def utterance_interrupt(self):
self.logger.error("utterance interrupt")
self.is_interrupted.set()
def handle_transcription_fragment(self, fragment):
if not self.can_interrupt:
return
# start generating a new response. We'll do the fast parts of the interrupt
# now but wait for the state transition after we've kicked off the prepare
# on the new response.
if (
self.current_response
and self.current_response.state < AsyncProcessorState.INTERRUPTED
):
self.interrupt_time = time.perf_counter()
self.is_interrupted.set()
self.current_response.interrupt()
self.message_handler.add_user_message(fragment)
response_type: type[OrchestratorResponse] | type[LLMResponse] = self.conversation_processors.response or LLMResponse
new_response: OrchestratorResponse = response_type(
self.services, self.message_handler, self.output_queue
)
new_response.set_state_callback(
AsyncProcessorState.DONE, self.on_response_played
)
new_response.set_state_callback(
AsyncProcessorState.FINALIZED, self.on_response_finished
)
new_response.prepare()
self.response_semaphore.acquire()
if (
self.current_response
and self.current_response.state < AsyncProcessorState.INTERRUPTED
):
self.current_response.wait_for_state_transition(
AsyncProcessorState.FINALIZED
)
self.current_response = new_response
self.current_response.play()
self.response_semaphore.release()
def action(self):
self.logger.info("Starting camera thread")
self.image: bytes | None = None
self.camera_thread = Thread(target=self.run_camera, daemon=True)
self.camera_thread.start()
self.logger.info("Starting frame consumer thread")
self.frame_consumer_thread = Thread(target=self.frame_consumer, daemon=True)
self.frame_consumer_thread.start()
self.logger.info("Playing introduction")
self.can_interrupt = False
self.current_response.play()
def frame_consumer(self):
self.logger.info("🎬 Starting frame consumer thread")
b = bytearray()
smallest_write_size = 3200
all_audio_frames = bytearray()
while True:
try:
frame:QueueFrame = self.output_queue.get()
if frame.frame_type == FrameType.END_STREAM:
self.logger.info("Stopping frame consumer thread")
return
# if interrupted, we just pull frames off the queue and discard them
if not self.is_interrupted.is_set():
if frame:
if frame.frame_type == FrameType.AUDIO:
chunk = frame.frame_data
all_audio_frames.extend(chunk)
b.extend(chunk)
l = len(b) - (len(b) % smallest_write_size)
if l:
self.mic.write_frames(bytes(b[:l]))
b = b[l:]
elif frame.frame_type == FrameType.IMAGE:
self.set_image(frame.frame_data)
elif len(b):
self.mic.write_frames(bytes(b))
b = bytearray()
else:
if self.interrupt_time:
self.logger.info(f"Lag to stop stream after interruption {time.perf_counter() - self.interrupt_time}")
self.interrupt_time = None
if frame.frame_type == FrameType.START_STREAM:
self.is_interrupted.clear()
self.output_queue.task_done()
except Empty:
try:
if len(b):
self.mic.write_frames(bytes(b))
except Exception as e:
self.logger.error(f"Exception in frame_consumer: {e}, {len(b)}")
b = bytearray()

View File

@@ -25,23 +25,24 @@ class QueueTee:
await queue.put(frame)
class LLMContextAggregator(AIService):
def __init__(self, messages: list[dict], role:str, bot_participant_id=None):
def __init__(self, messages: list[dict], role:str, bot_participant_id=None, complete_sentences=True):
self.messages = messages
self.bot_participant_id = bot_participant_id
self.role = role
self.sentence = ""
self.complete_sentences = complete_sentences
async def process_frame(self, frame:QueueFrame) -> AsyncGenerator[QueueFrame, None]:
content: str = ""
# TODO: split up transcription by participant
if isinstance(frame, TextQueueFrame):
content = frame.text
self.sentence += content
if self.sentence.endswith((".", "?", "!")):
self.messages.append({"role": self.role, "content": self.sentence})
self.sentence = ""
yield LLMMessagesQueueFrame(self.messages)
if self.complete_sentences:
self.sentence += frame.text
if self.sentence.endswith((".", "?", "!")):
self.messages.append({"role": self.role, "content": self.sentence})
self.sentence = ""
yield LLMMessagesQueueFrame(self.messages)
else:
self.messages.append({"role": self.role, "content": frame.text})
yield LLMMessagesQueueFrame(self.messages)
yield frame

View File

@@ -5,10 +5,13 @@ from typing import Any
class QueueFrame:
pass
class StartStreamQueueFrame(QueueFrame):
class ControlQueueFrame(QueueFrame):
pass
class EndStreamQueueFrame(QueueFrame):
class StartStreamQueueFrame(ControlQueueFrame):
pass
class EndStreamQueueFrame(ControlQueueFrame):
pass
@dataclass()

View File

@@ -1,2 +1,3 @@
Pillow==10.1.0
typing_extensions==4.9.0
typing_extensions==4.9.0
faster-whisper==0.10.0

View File

@@ -1,8 +1,11 @@
import asyncio
import io
import logging
import wave
from dailyai.queue_frame import (
AudioQueueFrame,
ControlQueueFrame,
EndStreamQueueFrame,
ImageQueueFrame,
LLMMessagesQueueFrame,
@@ -11,7 +14,7 @@ from dailyai.queue_frame import (
)
from abc import abstractmethod
from typing import AsyncGenerator, AsyncIterable, Iterable
from typing import AsyncGenerator, AsyncIterable, BinaryIO, Iterable
from dataclasses import dataclass
@@ -63,9 +66,8 @@ class AIService:
@abstractmethod
async def process_frame(self, frame:QueueFrame) -> AsyncGenerator[QueueFrame, None]:
# This is a trick for the interpreter (and linter) to know that this is a generator.
if False:
yield QueueFrame()
if isinstance(frame, ControlQueueFrame):
yield frame
@abstractmethod
async def finalize(self) -> AsyncGenerator[QueueFrame, None]:
@@ -83,7 +85,9 @@ class LLMService(AIService):
pass
async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
if isinstance(frame, LLMMessagesQueueFrame):
if isinstance(frame, ControlQueueFrame):
yield frame
elif isinstance(frame, LLMMessagesQueueFrame):
async for text_chunk in self.run_llm_async(frame.messages):
yield TextQueueFrame(text_chunk)
@@ -106,9 +110,11 @@ class TTSService(AIService):
yield bytes()
async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
if not isinstance(frame, TextQueueFrame):
if isinstance(frame, ControlQueueFrame):
yield frame
return
elif not isinstance(frame, TextQueueFrame):
return
text: str | None = None
if not self.aggregate_sentences:
@@ -145,14 +151,46 @@ class ImageGenService(AIService):
async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
if not isinstance(frame, TextQueueFrame):
yield frame
return
(url, image_data) = await self.run_image_gen(frame.text)
yield ImageQueueFrame(url, image_data)
class STTService(AIService):
"""STTService is a base class for speech-to-text services."""
_frame_rate: int
def __init__(self, frame_rate: int = 16000, **kwargs):
super().__init__(**kwargs)
self._frame_rate = frame_rate
@abstractmethod
async def run_stt(self, audio: BinaryIO) -> str:
"""Returns transcript as a string"""
pass
async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
"""Processes a frame of audio data, either buffering or transcribing it."""
if not isinstance(frame, AudioQueueFrame):
return
data = frame.data
content = io.BufferedRandom(io.BytesIO())
ww = wave.open(self._content, "wb")
ww.setnchannels(1)
ww.setsampwidth(2)
ww.setframerate(self._frame_rate)
ww.writeframesraw(data)
ww.close()
content.seek(0)
text = await self.run_stt(content)
yield TextQueueFrame(text)
@dataclass
class AIServiceConfig:
tts: TTSService
image: ImageGenService
llm: LLMService
stt: STTService

View File

@@ -5,7 +5,6 @@ import json
from openai import AsyncAzureOpenAI
import os
import requests
from collections.abc import AsyncGenerator
@@ -16,7 +15,10 @@ from PIL import Image
from azure.cognitiveservices.speech import SpeechSynthesizer, SpeechConfig, ResultReason, CancellationReason
class AzureTTSService(TTSService):
def __init__(self, speech_key=None, speech_region=None):
def __init__(
self, speech_key=None, speech_region=None, voice_name="en-US-SaraNeural"
):
super().__init__()
speech_key = speech_key or os.getenv("AZURE_SPEECH_SERVICE_KEY")
@@ -25,20 +27,19 @@ class AzureTTSService(TTSService):
self.speech_config = SpeechConfig(subscription=speech_key, region=speech_region)
self.speech_synthesizer = SpeechSynthesizer(speech_config=self.speech_config, audio_config=None)
self.voice_name = voice_name
async def run_tts(self, sentence) -> AsyncGenerator[bytes, None]:
self.logger.info("Running azure tts")
ssml = "<speak version='1.0' xml:lang='en-US' xmlns='http://www.w3.org/2001/10/synthesis' " \
ssml = f"<speak version='1.0' xml:lang='en-US' xmlns='http://www.w3.org/2001/10/synthesis' " \
"xmlns:mstts='http://www.w3.org/2001/mstts'>" \
"<voice name='en-US-SaraNeural'>" \
f"<voice name={self.voice_name}>" \
"<mstts:silence type='Sentenceboundary' value='20ms' />" \
"<mstts:express-as style='lyrical' styledegree='2' role='SeniorFemale'>" \
"<prosody rate='1.05'>" \
f"{sentence}" \
"</prosody></mstts:express-as></voice></speak> "
try:
result = await asyncio.to_thread(self.speech_synthesizer.speak_ssml, (ssml))
except Exception as e:
self.logger.error("Error in azure tts", e)
result = await asyncio.to_thread(self.speech_synthesizer.speak_ssml, (ssml))
self.logger.info("Got azure tts result")
if result.reason == ResultReason.SynthesizingAudioCompleted:
self.logger.info("Returning result")
@@ -95,81 +96,58 @@ class AzureLLMService(LLMService):
class AzureImageGenServiceREST(ImageGenService):
def __init__(self, image_size:str, api_key=None, azure_endpoint=None, api_version=None, model=None):
def __init__(
self,
image_size: str,
api_key: str | None = None,
azure_endpoint: str | None = None,
api_version: str | None = None,
model: str | None = None,
aiohttp_session: aiohttp.ClientSession | None=None,
timeout_seconds=120,
):
super().__init__(image_size=image_size)
self.api_key = api_key or os.getenv("AZURE_DALLE_KEY")
self.azure_endpoint = azure_endpoint or os.getenv("AZURE_DALLE_ENDPOINT")
self.api_version = api_version or "2023-06-01-preview"
self.model = model or os.getenv("AZURE_DALLE_DEPLOYMENT_ID")
self.aiohttp_session: aiohttp.ClientSession = (
aiohttp_session or aiohttp.ClientSession()
)
self.timeout_seconds = timeout_seconds
async def run_image_gen(self, sentence) -> tuple[str, bytes]:
# TODO hoist the session to app-level
async with aiohttp.ClientSession() as session:
url = f"{self.azure_endpoint}openai/images/generations:submit?api-version={self.api_version}"
headers= { "api-key": self.api_key, "Content-Type": "application/json" }
body = {
# Enter your prompt text here
"prompt": sentence,
"size": self.image_size,
"n": 1,
}
async with session.post(url, headers=headers, json=body) as submission:
operation_location = submission.headers['operation-location']
url = f"{self.azure_endpoint}openai/images/generations:submit?api-version={self.api_version}"
headers= { "api-key": self.api_key, "Content-Type": "application/json" }
body = {
"prompt": sentence,
"size": self.image_size,
"n": 1,
}
async with self.aiohttp_session.post(
url, headers=headers, json=body
) as submission:
operation_location = submission.headers['operation-location']
status = ""
attempts_left = 120
json_response = None
while status != "succeeded":
attempts_left -= 1
if attempts_left == 0:
raise Exception("Image generation timed out")
status = ""
attempts_left = self.timeout_seconds
json_response = None
while status != "succeeded":
attempts_left -= 1
if attempts_left == 0:
raise Exception("Image generation timed out")
await asyncio.sleep(1)
response = await session.get(operation_location, headers=headers)
json_response = await response.json()
status = json_response["status"]
await asyncio.sleep(1)
response = await self.aiohttp_session.get(operation_location, headers=headers)
json_response = await response.json()
status = json_response["status"]
image_url = json_response["result"]["data"][0]["url"] if json_response else None
if not image_url:
raise Exception("Image generation failed")
image_url = json_response["result"]["data"][0]["url"] if json_response else None
if not image_url:
raise Exception("Image generation failed")
# Load the image from the url
async with session.get(image_url) as response:
image_stream = io.BytesIO(await response.content.read())
image = Image.open(image_stream)
return (image_url, image.tobytes())
class AzureImageGenService(ImageGenService):
def __init__(self, api_key=None, azure_endpoint=None, api_version=None, model=None):
super().__init__()
api_key = api_key or os.getenv("AZURE_DALLE_KEY")
azure_endpoint = azure_endpoint or os.getenv("AZURE_DALLE_ENDPOINT")
api_version = api_version or "2023-06-01-preview"
self.model = model or os.getenv("AZURE_DALLE_DEPLOYMENT_ID")
self.client = AzureOpenAI(
api_key=api_key,
azure_endpoint=azure_endpoint,
api_version=api_version,
)
async def run_image_gen(self, sentence) -> tuple[str, bytes]:
self.logger.info("Generating azure image", sentence)
image = self.client.images.generate(
model=self.model,
prompt=sentence,
n=1,
size=self.image_size,
)
url = image["data"][0]["url"]
response = requests.get(url)
dalle_stream = io.BytesIO(response.content)
dalle_im = Image.open(dalle_stream.tobytes())
return (url, dalle_im)
# Load the image from the url
async with self.aiohttp_session.get(image_url) as response:
image_stream = io.BytesIO(await response.content.read())
image = Image.open(image_stream)
return (image_url, image.tobytes())

View File

@@ -7,6 +7,7 @@ import types
from functools import partial
from queue import Queue, Empty
from typing import AsyncGenerator
from dailyai.queue_frame import (
AudioQueueFrame,
@@ -14,6 +15,7 @@ from dailyai.queue_frame import (
ImageQueueFrame,
QueueFrame,
StartStreamQueueFrame,
TextQueueFrame,
TranscriptionQueueFrame,
)
@@ -31,6 +33,10 @@ from daily import (
class DailyTransportService(EventHandler):
_daily_initialized = False
_lock = threading.Lock()
speaker_enabled: bool
speaker_sample_rate: int
def __init__(
self,
room_url: str,
@@ -38,6 +44,9 @@ class DailyTransportService(EventHandler):
bot_name: str,
duration: float = 10,
min_others_count: int = 1,
start_transcription: bool = True,
speaker_enabled: bool = False,
speaker_sample_rate: int = 16000,
):
super().__init__()
self.bot_name: str = bot_name
@@ -46,6 +55,7 @@ class DailyTransportService(EventHandler):
self.duration: float = duration
self.expiration = time.time() + duration * 60
self.min_others_count = min_others_count
self.start_transcription = start_transcription
# This queue is used to marshal frames from the async send queue to the thread that emits audio & video.
# We need this to maintain the asynchronous behavior of asyncio queues -- to give async functions
@@ -61,6 +71,8 @@ class DailyTransportService(EventHandler):
self.camera_width = 1024
self.camera_height = 768
self.camera_enabled = False
self.speaker_enabled = speaker_enabled
self.speaker_sample_rate = speaker_sample_rate
self.send_queue = asyncio.Queue()
self.receive_queue = asyncio.Queue()
@@ -104,6 +116,7 @@ class DailyTransportService(EventHandler):
handler(*args, **kwargs)
except Exception as e:
self.logger.error(f"Exception in event handler {event_name}: {e}")
raise e
def add_event_handler(self, event_name: str, handler):
if not event_name.startswith("on_"):
@@ -144,9 +157,11 @@ class DailyTransportService(EventHandler):
"camera", width=self.camera_width, height=self.camera_height, color_format="RGB"
)
self.speaker: VirtualSpeakerDevice = Daily.create_speaker_device(
"speaker", sample_rate=16000, channels=1
)
if self.speaker_enabled:
self.speaker: VirtualSpeakerDevice = Daily.create_speaker_device(
"speaker", sample_rate=self.speaker_sample_rate, channels=1
)
Daily.select_speaker_device("speaker")
self.image: bytes | None = None
self.camera_thread = Thread(target=self.run_camera, daemon=True)
@@ -156,8 +171,6 @@ class DailyTransportService(EventHandler):
self.frame_consumer_thread = Thread(target=self.frame_consumer, daemon=True)
self.frame_consumer_thread.start()
Daily.select_speaker_device("speaker")
self.client.set_user_name(self.bot_name)
self.client.join(self.room_url, self.token, completion=self.call_joined)
self.my_participant_id = self.client.participants()["local"]["id"]
@@ -201,10 +214,24 @@ class DailyTransportService(EventHandler):
}
)
if self.token:
if self.token and self.start_transcription:
self.client.start_transcription(self.transcription_settings)
async def get_receive_frames(self):
def _receive_audio(self):
"""Receive audio from the Daily call and put it on the receive queue"""
seconds = 1
desired_frame_count = self.speaker_sample_rate * seconds
while True:
buffer = self.speaker.read_frames(desired_frame_count)
if len(buffer) > 0:
frame = AudioQueueFrame(buffer)
if self.loop:
asyncio.run_coroutine_threadsafe(self.receive_queue.put(frame), self.loop)
def interrupt(self):
self.is_interrupted.set()
async def get_receive_frames(self) -> AsyncGenerator[QueueFrame, None]:
while True:
frame = await self.receive_queue.get()
yield frame
@@ -244,6 +271,7 @@ class DailyTransportService(EventHandler):
await asyncio.sleep(1)
except Exception as e:
self.logger.error(f"Exception {e}")
raise e
finally:
self.client.leave()
@@ -266,6 +294,9 @@ class DailyTransportService(EventHandler):
def call_joined(self, join_data, client_error):
self.logger.info(f"Call_joined: {join_data}, {client_error}")
if self.speaker_enabled:
t = Thread(target=self._receive_audio, daemon=True)
t.start()
def on_error(self, error):
self.logger.error(f"on_error: {error}")
@@ -317,6 +348,7 @@ class DailyTransportService(EventHandler):
time.sleep(1.0 / 8) # 8 fps
except Exception as e:
self.logger.error(f"Exception {e} in camera thread.")
raise e
def frame_consumer(self):
self.logger.info("🎬 Starting frame consumer thread")
@@ -358,11 +390,11 @@ class DailyTransportService(EventHandler):
self.mic.write_frames(bytes(b))
b = bytearray()
else:
if self.interrupt_time:
self.logger.info(
f"Lag to stop stream after interruption {time.perf_counter() - self.interrupt_time}"
)
self.interrupt_time = None
# if there are leftover audio bytes, write them now; failing to do so
# can cause static in the audio stream.
if len(b):
self.mic.write_frames(bytes(b))
b = bytearray()
if isinstance(frame, StartStreamQueueFrame):
self.is_interrupted.clear()
@@ -374,5 +406,6 @@ class DailyTransportService(EventHandler):
self.mic.write_frames(bytes(b))
except Exception as e:
self.logger.error(f"Exception in frame_consumer: {e}, {len(b)}")
raise e
b = bytearray()

View File

@@ -9,28 +9,30 @@ from dailyai.services.ai_services import TTSService
class ElevenLabsTTSService(TTSService):
def __init__(self, api_key=None, voice_id=None):
def __init__(self, api_key=None, voice_id=None, aiohttp_session:aiohttp.ClientSession=None):
super().__init__()
self.api_key = api_key or os.getenv("ELEVENLABS_API_KEY")
self.voice_id = voice_id or os.getenv("ELEVENLABS_VOICE_ID")
self.aiohttp_session = aiohttp_session or aiohttp.ClientSession()
async def run_tts(self, sentence) -> AsyncGenerator[bytes, None]:
async with aiohttp.ClientSession() as session:
url = f"https://api.elevenlabs.io/v1/text-to-speech/{self.voice_id}/stream"
payload = {"text": sentence, "model_id": "eleven_turbo_v2"}
querystring = {"output_format": "pcm_16000", "optimize_streaming_latency": 2}
headers = {
"xi-api-key": self.api_key,
"Content-Type": "application/json",
}
async with session.post(url, json=payload, headers=headers, params=querystring) as r:
if r.status != 200:
self.logger.error(
f"audio fetch status code: {r.status}, error: {r.text}"
)
return
url = f"https://api.elevenlabs.io/v1/text-to-speech/{self.voice_id}/stream"
payload = {"text": sentence, "model_id": "eleven_turbo_v2"}
querystring = {"output_format": "pcm_16000", "optimize_streaming_latency": 2}
headers = {
"xi-api-key": self.api_key,
"Content-Type": "application/json",
}
async with self.aiohttp_session.post(
url, json=payload, headers=headers, params=querystring
) as r:
if r.status != 200:
self.logger.error(
f"audio fetch status code: {r.status}, error: {r.text}"
)
return
async for chunk in r.content:
if chunk:
yield chunk
async for chunk in r.content:
if chunk:
yield chunk

View File

@@ -34,9 +34,9 @@ class FalImageGenService(ImageGenService):
raise Exception("Image generation failed")
return image_url
print(f"fetching image url...")
print("fetching image url...")
image_url = await asyncio.to_thread(get_image_url, sentence, self.image_size)
print(f"got image url, downloading image...")
print("got image url, downloading image...")
# Load the image from the url
async with aiohttp.ClientSession() as session:
async with session.get(image_url) as response:

View File

@@ -0,0 +1,72 @@
import array
import io
import math
from typing import AsyncGenerator
import wave
from dailyai.queue_frame import AudioQueueFrame, QueueFrame, TextQueueFrame
from dailyai.services.ai_services import STTService
class LocalSTTService(STTService):
_content: io.BufferedRandom
_wave: wave.Wave_write
_current_silence_frames: int
# Configuration
_min_rms: int
_max_silence_frames: int
_frame_rate: int
def __init__(self,
min_rms: int = 400,
max_silence_frames: int = 3,
frame_rate: int = 16000,
**kwargs):
super().__init__(frame_rate, **kwargs)
self._current_silence_frames = 0
self._min_rms = min_rms
self._max_silence_frames = max_silence_frames
self._frame_rate = frame_rate
self._new_wave()
def _new_wave(self):
"""Creates a new wave object and content buffer."""
self._content = io.BufferedRandom(io.BytesIO())
ww = wave.open(self._content, "wb")
ww.setnchannels(1)
ww.setsampwidth(2)
ww.setframerate(self._frame_rate)
self._wave = ww
async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
"""Processes a frame of audio data, either buffering or transcribing it."""
if not isinstance(frame, AudioQueueFrame):
return
data = frame.data
# Try to filter out empty background noise
# (Very rudimentary approach, can be improved)
rms = self._get_volume(data)
if rms >= self._min_rms:
# If volume is high enough, write new data to wave file
self._wave.writeframesraw(data)
# If buffer is not empty and we detect a 3-frame pause in speech,
# transcribe the audio gathered so far.
if self._content.tell() > 0 and self._current_silence_frames > self._max_silence_frames:
self._current_silence_frames = 0
self._wave.close()
self._content.seek(0)
text = await self.run_stt(self._content)
self._new_wave()
yield TextQueueFrame(text)
# If we get this far, this is a frame of silence
self._current_silence_frames += 1
def _get_volume(self, audio: bytes) -> float:
# https://docs.python.org/3/library/array.html
audio_array = array.array('h', audio)
squares = [sample**2 for sample in audio_array]
mean = sum(squares) / len(audio_array)
rms = math.sqrt(mean)
return rms

View File

@@ -1,6 +1,4 @@
import requests
import aiohttp
import asyncio
from PIL import Image
import io
from openai import AsyncOpenAI
@@ -9,7 +7,7 @@ import os
import json
from collections.abc import AsyncGenerator
from dailyai.services.ai_services import AIService, TTSService, LLMService, ImageGenService
from dailyai.services.ai_services import LLMService, ImageGenService
class OpenAILLMService(LLMService):
@@ -50,11 +48,19 @@ class OpenAILLMService(LLMService):
return None
class OpenAIImageGenService(ImageGenService):
def __init__(self, image_size:str, api_key=None, model=None):
def __init__(
self,
image_size: str,
api_key=None,
model=None,
aiohttp_session: aiohttp.ClientSession | None = None,
):
super().__init__(image_size=image_size)
api_key = api_key or os.getenv("OPEN_AI_KEY")
self.model = model or os.getenv("OPEN_AI_IMAGE_MODEL") or "dall-e-3"
self.client = AsyncOpenAI(api_key=api_key)
self.aiohttp_session=aiohttp_session or aiohttp.ClientSession()
async def run_image_gen(self, sentence) -> tuple[str, bytes]:
self.logger.info("Generating OpenAI image", sentence)
@@ -70,10 +76,7 @@ class OpenAIImageGenService(ImageGenService):
raise Exception("No image provided in response", image)
# Load the image from the url
async with aiohttp.ClientSession() as session:
async with session.get(image_url) as response:
image_stream = io.BytesIO(await response.content.read())
image = Image.open(image_stream)
return (image_url, image.tobytes())
return (image_url, dalle_im.tobytes())
async with self.aiohttp_session.get(image_url) as response:
image_stream = io.BytesIO(await response.content.read())
image = Image.open(image_stream)
return (image_url, image.tobytes())

View File

@@ -0,0 +1,55 @@
"""This module implements Whisper transcription with a locally-downloaded model."""
import asyncio
from enum import Enum
import logging
from typing import BinaryIO
from faster_whisper import WhisperModel
from dailyai.services.local_stt_service import LocalSTTService
class Model(Enum):
"""Class of basic Whisper model selection options"""
TINY = "tiny"
BASE = "base"
MEDIUM = "medium"
LARGE = "large-v3"
DISTIL_LARGE_V2 = "Systran/faster-distil-whisper-large-v2"
DISTIL_MEDIUM_EN = "Systran/faster-distil-whisper-medium.en"
class WhisperSTTService(LocalSTTService):
"""Class to transcribe audio with a locally-downloaded Whisper model"""
_model: WhisperModel
# Model configuration
_model_name: Model
_device: str
_compute_type: str
def __init__(self, model_name: Model = Model.DISTIL_MEDIUM_EN,
device: str = "auto",
compute_type: str = "default"):
super().__init__()
self.logger: logging.Logger = logging.getLogger("dailyai")
self._model_name = model_name
self._device = device
self._compute_type = compute_type
self._load()
def _load(self):
"""Loads the Whisper model. Note that if this is the first time
this model is being run, it will take time to download."""
model = WhisperModel(
self._model_name.value,
device=self._device,
compute_type=self._compute_type)
self._model = model
async def run_stt(self, audio: BinaryIO = None) -> str:
"""Transcribes given audio using Whisper"""
segments, _ = await asyncio.to_thread(self._model.transcribe, audio)
res: str = ""
for segment in segments:
res += f"{segment.text} "
return res

View File

@@ -1,180 +0,0 @@
import time
import unittest
from queue import Queue, Empty
from threading import Thread, Event
from typing import Generator
from dailyai.async_processor.async_processor import (
AsyncProcessor,
AsyncProcessorState,
LLMResponse,
)
from dailyai.message_handler.message_handler import MessageHandler
from dailyai.queue_frame import QueueFrame, FrameType
from dailyai.services.ai_services import (
AIServiceConfig,
ImageGenService,
LLMService,
TTSService,
)
"""
class MockTTSService(TTSService):
def run_tts(self, sentence):
for word in sentence.split(' '):
time.sleep(0.1)
yield bytes(word, "utf-8")
class MockLLMService(LLMService):
def run_llm_async(self, messages) -> Generator[str, None, None]:
for i in ["Hello ", "there.", "How are ", "you?", "I ", "hope ", "you ", "are ", "well."]:
time.sleep(0.1)
yield i
class MockImageService(ImageGenService):
def run_image_gen(self, sentence) -> None:
return None
class TestResponse(unittest.TestCase):
def test_base_state_transitions(self):
mock_tts_service = MockTTSService()
mock_llm_service = MockLLMService()
mock_image_service = MockImageService()
processor = AsyncProcessor(AIServiceConfig(tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service))
processor.prepare()
processor.play()
processor.finalize()
self.assertEqual(processor.state, AsyncProcessorState.FINALIZED)
def test_state_transitions(self):
output_queue = Queue()
mock_tts_service = MockTTSService()
mock_llm_service = MockLLMService()
mock_image_service = MockImageService()
message_handler = MessageHandler("Hello World")
processor = LLMResponse(
AIServiceConfig(
tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service
),
message_handler,
output_queue,
)
processor.prepare()
processor.play()
# Consume the output from the output queue. It's necessary to mark these tasks as done for the
# play function to return.
expected_words = ["Hello", "there.", "How", "are", "you?", "I", "hope", "you", "are", "well."]
# remove the "start_stream" message from the queue
output_queue.get()
output_queue.task_done()
while expected_words:
actual_word:QueueFrame = output_queue.get()
word = expected_words.pop(0)
self.assertEqual(actual_word.frame_type, FrameType.AUDIO_FRAME)
self.assertEqual(actual_word.frame_data, bytes(word, "utf-8"))
output_queue.task_done()
processor.finalize()
self.assertEqual(processor.state, AsyncProcessorState.FINALIZED)
def test_interrupt_preparation(self):
output_queue = Queue()
mock_tts_service = MockTTSService()
mock_llm_service = MockLLMService()
mock_image_service = MockImageService()
message_handler = MessageHandler("System Message")
processor = LLMResponse(
AIServiceConfig(
tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service
),
message_handler,
output_queue,
)
processor.prepare()
interrupt_request_at = time.perf_counter()
processor.interrupt()
processor.finalize()
finalized_at = time.perf_counter()
self.assertTrue(0.1 < finalized_at - interrupt_request_at < 0.2)
print(f"delta: {interrupt_request_at, finalized_at}")
self.assertEqual(processor.state, AsyncProcessorState.FINALIZED)
def test_interrupt_play(self):
output_queue = Queue()
mock_tts_service = MockTTSService()
mock_llm_service = MockLLMService()
mock_image_service = MockImageService()
message_handler = MessageHandler("System Message")
processor = LLMResponse(
AIServiceConfig(
tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service
),
message_handler,
output_queue,
)
processor.prepare()
processor.play()
stop_processing_output_queue = Event()
def process_output_queue_async():
# Consume the output from the output queue. It's necessary to mark these tasks as done for the
# play function to return.
time.sleep(0.1)
expected_words = ["Hello", "there.", "How", "are", "you?", "I", "hope", "you", "are", "well."]
while expected_words and not stop_processing_output_queue.is_set():
try:
actual_word:QueueFrame = output_queue.get_nowait()
if actual_word.frame_type == FrameType.AUDIO_FRAME:
time.sleep(0.1)
word = expected_words.pop(0)
self.assertEqual(actual_word.frame_type, FrameType.AUDIO_FRAME)
self.assertEqual(actual_word.frame_data, bytes(word, "utf-8"))
output_queue.task_done()
except Empty:
pass
process_output_queue = Thread(target=process_output_queue_async, daemon=True)
process_output_queue.start()
time.sleep(0.5)
processor.interrupt()
stop_processing_output_queue.set()
process_output_queue.join()
processor.finalize()
self.assertEqual(processor.state, AsyncProcessorState.FINALIZED)
def test_statechange_callback(self):
mock_tts_service = MockTTSService()
mock_llm_service = MockLLMService()
mock_image_service = MockImageService()
processor = AsyncProcessor(
AIServiceConfig(
tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service
)
)
is_finalized = False
def set_is_finalized(async_processor:AsyncProcessor):
nonlocal is_finalized
is_finalized = True
processor.set_state_callback(
AsyncProcessorState.FINALIZED, set_is_finalized
)
processor.prepare()
self.assertFalse(is_finalized)
processor.play()
self.assertFalse(is_finalized)
processor.finalize()
self.assertTrue(is_finalized)
self.assertEqual(processor.state, AsyncProcessorState.FINALIZED)
if __name__ == '__main__':
unittest.main()
"""

View File

@@ -1,147 +0,0 @@
import time
import unittest
from unittest.mock import MagicMock, call
from dailyai.message_handler.message_handler import MessageHandler, IndexingMessageHandler
from dailyai.services.ai_services import (
AIServiceConfig,
TTSService,
LLMService,
ImageGenService,
)
from ..storage.search import SearchIndexer
class TestMessageHandler(unittest.TestCase):
def test_simple_intro(self):
message_handler = MessageHandler("Hello world")
self.assertEqual(
message_handler.get_llm_messages(),
[{"role": "system", "content": "Hello world"}],
)
def test_simple_user_message(self):
message_handler = MessageHandler("System prompt")
message_handler.add_user_message("User message")
self.assertEqual(
message_handler.get_llm_messages(),
[
{"role": "system", "content": "System prompt"},
{"role": "user", "content": "User message"},
],
)
def test_simple_user_and_assistant_message(self):
message_handler = MessageHandler("System prompt")
message_handler.add_user_message("User message")
message_handler.add_assistant_message("Assistant message")
self.assertEqual(
message_handler.get_llm_messages(),
[
{"role": "system", "content": "System prompt"},
{"role": "user", "content": "User message"},
{"role": "assistant", "content": "Assistant message"},
],
)
def test_user_message_overwrite(self):
message_handler = MessageHandler("System prompt")
message_handler.add_user_message("User message")
message_handler.add_assistant_message("Assistant message")
message_handler.add_user_message("plus something else")
self.assertEqual(
message_handler.get_llm_messages(),
[
{"role": "system", "content": "System prompt"},
{"role": "user", "content": "User message plus something else"},
],
)
def test_user_message_after_assistant(self):
message_handler = MessageHandler("System prompt")
message_handler.add_user_message("User message")
message_handler.add_assistant_message("Assistant message")
message_handler.finalize_user_message()
message_handler.add_user_message("other user message")
self.assertEqual(
message_handler.get_llm_messages(),
[
{"role": "system", "content": "System prompt"},
{"role": "user", "content": "User message"},
{"role": "assistant", "content": "Assistant message"},
{"role": "user", "content": "other user message"},
],
)
class MockTTSService(TTSService):
def run_tts(self, sentence):
for word in sentence.split(" "):
time.sleep(0.1)
yield bytes(word, "utf-8")
class MockLLMService(LLMService):
def run_llm(self, messages) -> str:
return "Parsed user message."
class MockImageService(ImageGenService):
def run_image_gen(self, sentence) -> None:
return None
class TestStorageMessageHandler(unittest.TestCase):
def test_user_message_finalized(self):
mock_tts_service = MockTTSService()
mock_llm_service = MockLLMService()
mock_image_service = MockImageService()
service_config = AIServiceConfig(
tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service
)
mock_indexer = MagicMock(spec=SearchIndexer)
message_handler = IndexingMessageHandler(
"Hello world", service_config, mock_indexer
)
message_handler.cleanup_user_message = MagicMock(return_value="Parsed user message.")
message_handler.add_user_message("User message")
message_handler.add_assistant_message("Assistant message will be ignored")
message_handler.add_user_message("plus something else")
message_handler.finalize_user_message()
message_handler.add_assistant_message(
"New assistant message will not be ignored"
)
message_handler.add_user_message("User message second time")
message_handler.add_assistant_message("Assistant message second time")
message_handler.write_messages_to_storage()
time.sleep(0.5)
message_handler.cleanup_user_message.assert_called_with("User message plus something else")
self.assertEqual(
mock_indexer.mock_calls,
[
call.index_text('"Parsed user message."'),
call.index_text("New assistant message will not be ignored"),
],
)
mock_indexer.reset_mock()
message_handler.finalize_user_message()
time.sleep(0.5)
self.assertEqual(
mock_indexer.mock_calls,
[
call.index_text('"Parsed user message."'),
call.index_text("Assistant message second time"),
],
)
if __name__ == "__main__":
unittest.main()

View File

@@ -4,6 +4,7 @@ import asyncio
from dailyai.queue_frame import TextQueueFrame
from dailyai.services.daily_transport_service import DailyTransportService
from dailyai.services.open_ai_services import OpenAIImageGenService
from dailyai.services.azure_ai_services import AzureImageGenServiceREST
local_joined = False
participant_joined = False
@@ -21,7 +22,7 @@ async def main(room_url):
transport.camera_width = 1024
transport.camera_height = 1024
imagegen = OpenAIImageGenService(image_size="1024x1024")
imagegen = AzureImageGenServiceREST(image_size="1024x1024")
image_task = asyncio.create_task(
imagegen.run_to_queue(transport.send_queue, [TextQueueFrame("a cat in the style of picasso")])
)

View File

@@ -2,7 +2,7 @@ import argparse
import asyncio
from dailyai.queue_frame import AudioQueueFrame, ImageQueueFrame
from dailyai.services.azure_ai_services import AzureLLMService
from dailyai.services.azure_ai_services import AzureImageGenServiceREST, AzureLLMService
from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
from dailyai.services.daily_transport_service import DailyTransportService
from dailyai.services.fal_ai_services import FalImageGenService
@@ -22,9 +22,9 @@ async def main(room_url):
transport.camera_height = 1024
llm = AzureLLMService()
dalle = FalImageGenService(image_size="1024x1024")
#dalle = FalImageGenService(image_size="1024x1024")
tts = ElevenLabsTTSService(voice_id="ErXwobaYiN019PkySvjV")
# dalle = OpenAIImageGenService(image_size="1024x1024")
dalle = AzureImageGenServiceREST(image_size="1024x1024")
# Get a complete audio chunk from the given text. Splitting this into its own
# coroutine lets us ensure proper ordering of the audio chunks on the send queue.

View File

@@ -0,0 +1,98 @@
import argparse
import asyncio
import requests
import time
import urllib.parse
from dailyai.conversation_wrappers import InterruptibleConversationWrapper
from dailyai.queue_frame import StartStreamQueueFrame, TextQueueFrame
from dailyai.services.daily_transport_service import DailyTransportService
from dailyai.services.azure_ai_services import AzureLLMService
from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
async def main(room_url:str, token):
global transport
global llm
global tts
transport = DailyTransportService(
room_url,
token,
"Respond bot",
5,
)
transport.mic_enabled = True
transport.mic_sample_rate = 16000
transport.camera_enabled = False
llm = AzureLLMService()
tts = ElevenLabsTTSService(voice_id="ErXwobaYiN019PkySvjV")
async def run_response(user_speech, tma_in, tma_out):
await tts.run_to_queue(
transport.send_queue,
tma_out.run(
llm.run(
tma_in.run(
[StartStreamQueueFrame(), TextQueueFrame(user_speech)]
)
)
),
)
@transport.event_handler("on_first_other_participant_joined")
async def on_first_other_participant_joined(transport):
await tts.say("Hi, I'm listening!", transport.send_queue)
async def run_conversation():
messages = [
{"role": "system", "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio. Respond to what the user said in a creative and helpful way."},
]
conversation_wrapper = InterruptibleConversationWrapper(
frame_generator=transport.get_receive_frames,
runner=run_response,
interrupt=transport.interrupt,
my_participant_id=transport.my_participant_id,
llm_messages=messages,
)
await conversation_wrapper.run_conversation()
transport.transcription_settings["extra"]["punctuate"] = False
await asyncio.gather(transport.run(), run_conversation())
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Simple Daily Bot Sample")
parser.add_argument(
"-u", "--url", type=str, required=True, help="URL of the Daily room to join"
)
parser.add_argument(
"-k",
"--apikey",
type=str,
required=True,
help="Daily API Key (needed to create token)",
)
args, unknown = parser.parse_known_args()
# Create a meeting token for the given room with an expiration 1 hour in the future.
room_name: str = urllib.parse.urlparse(args.url).path[1:]
expiration: float = time.time() + 60 * 60
res: requests.Response = requests.post(
f"https://api.daily.co/v1/meeting-tokens",
headers={"Authorization": f"Bearer {args.apikey}"},
json={
"properties": {"room_name": room_name, "is_owner": True, "exp": expiration}
},
)
if res.status_code != 200:
raise Exception(f"Failed to create meeting token: {res.status_code} {res.text}")
token: str = res.json()["token"]
asyncio.run(main(args.url, token))

View File

@@ -0,0 +1,44 @@
import argparse
import asyncio
from dailyai.services.daily_transport_service import DailyTransportService
from dailyai.services.whisper_ai_services import WhisperSTTService
async def main(room_url: str):
global transport
global stt
transport = DailyTransportService(
room_url,
None,
"Transcription bot",
)
transport.mic_enabled = False
transport.camera_enabled = False
transport.speaker_enabled = True
stt = WhisperSTTService()
transcription_output_queue = asyncio.Queue()
async def handle_transcription():
print("`````````TRANSCRIPTION`````````")
while True:
item = await transcription_output_queue.get()
print(item.text)
async def handle_speaker():
await stt.run_to_queue(
transcription_output_queue,
transport.get_receive_frames()
)
await asyncio.gather(transport.run(), handle_speaker(), handle_transcription())
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Simple Daily Bot Sample")
parser.add_argument(
"-u", "--url", type=str, required=True, help="URL of the Daily room to join"
)
args, unknown = parser.parse_known_args()
asyncio.run(main(args.url))