Compare commits
19 Commits
transcript...cleanup
| Author | SHA1 | Date |
|---|---|---|
| | 646db8b9bd | |
| | 42c142aff0 | |
| | 6da78dbf9c | |
| | f0d9b0613e | |
| | a661905d7f | |
| | c9c2e5f561 | |
| | 795a339542 | |
| | 31db156dfc | |
| | 690cf2e47d | |
| | ba89e41c5b | |
| | c134598a77 | |
| | b51abd2969 | |
| | 3fda9b0ecb | |
| | 95c92e5304 | |
| | b443fbdb60 | |
| | ccd2fa31e5 | |
| | 9b65286216 | |
| | 6ae733ebfe | |
| | 1071dede1a | |
1
.gitignore
vendored
@@ -2,6 +2,7 @@
|
||||
env/
|
||||
__pycache__/
|
||||
*~
|
||||
venv
|
||||
#*#
|
||||
|
||||
# Distribution / packaging
|
||||
|
||||
85
README.md
@@ -53,3 +53,88 @@ If you have those environment variables stored in an .env file, you can quickly
|
||||
```bash
|
||||
export $(grep -v '^#' .env | xargs)
|
||||
```
|
||||
|
||||
## Overview
|
||||
The Daily AI SDK allows you to build applications that can participate in WebRTC sessions and interact with AI Services. Some examples of what you can build with this:
|
||||
* conversational bots that interact 1:1 with a user, using voice recognition and text-to-speech
|
||||
* assistant bots that aggregate transcriptions from multiple participants in a meeting and provide real-time summaries or other AI-generated output
|
||||
* image-recognition bots
|
||||
* etc
|
||||
## Concepts
|
||||
### Transport Service
|
||||
The SDK provides one “transport service”, which is a wrapper around Daily’s `daily-python` client (tk add link). You can use this service to listen for events related to a WebRTC session, such as “a participant joined the meeting”.
|
||||
The transport service also exposes a send queue, and a receive queue. You can use the send queue to send audio and video to the WebRTC session, and you can listen to the receive queue to see audio, video and transcription data from the WebRTC session.
|
||||
### AI Services
|
||||
The AI Service classes provide wrappers around various AI providers, and allow you to query LLMs, convert text to speech and make images from text. The audio and images can then be placed on the transport service’s send queue, where they’ll be sent to the WebRTC session.
|
||||
### Queue Frames
|
||||
Communication between the transport service and AI services, and between various AI services, takes place in Queue Frames. These frames contain an indication of the type of data as well as the data itself.
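As a rough illustration of the idea (not the SDK's actual definition), a queue frame can be pictured as a small record that pairs a type tag with a payload. The `FrameType` members and the `QueueFrame` field names below are assumptions modeled on the names used elsewhere in this README:

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any


class FrameType(Enum):
    # Hypothetical subset of frame types, named after those mentioned in this README.
    TEXT = auto()
    AUDIO = auto()
    LLM_MESSAGE = auto()
    END_STREAM = auto()


@dataclass(frozen=True)
class QueueFrame:
    frame_type: FrameType  # what kind of data the frame carries
    frame_data: Any        # the data itself (str, bytes, list of messages, ...)


frame = QueueFrame(FrameType.TEXT, "hello world")
```

A consumer can then branch on `frame.frame_type` without inspecting the payload.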
|
||||
## Using Transports, AI Services and Frames
|
||||
AI Services all define a `.run` method. This method consumes and generates `QueueFrame` frames. The kinds of frames that can be consumed and generated depend on the kind of service. For instance, an LLM AI Service consumes `LLM_MESSAGE` frames (which define a history of interaction with an LLM) and emits `TEXT` frames (the response from the LLM).
|
||||
|
||||
The `.run` method is an `AsyncIterable`, and it takes an `Iterable`, `AsyncIterable` or `asyncio.Queue` that produces QueueFrames as a parameter. This makes it easy to chain AI Services and to consume input from the Transport’s `receive_queue`.
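One way to picture how a single `.run` method could accept all three input kinds is a small adapter that turns each of them into an async stream. This is only a sketch: `iterate_frames`, `upper_case_service`, and the `None` end-of-queue sentinel are hypothetical names and conventions, not part of the SDK.

```python
import asyncio
from typing import Any, AsyncIterable, AsyncIterator, Iterable, Union

FrameSource = Union[Iterable[Any], AsyncIterable[Any], asyncio.Queue]


async def iterate_frames(source: FrameSource, end_marker: Any = None) -> AsyncIterator[Any]:
    """Yield frames from a sync iterable, an async iterable, or an asyncio.Queue."""
    if isinstance(source, asyncio.Queue):
        # Assumption: a sentinel value on the queue marks the end of the stream.
        while True:
            frame = await source.get()
            source.task_done()
            if frame is end_marker:
                break
            yield frame
    elif hasattr(source, "__aiter__"):
        async for frame in source:
            yield frame
    else:
        for frame in source:
            yield frame


async def upper_case_service(source: FrameSource) -> AsyncIterator[str]:
    # A toy "service": consumes text frames and yields transformed text frames.
    async for frame in iterate_frames(source):
        yield frame.upper()


async def demo() -> list:
    # Services chain by passing one async generator as the input to the next.
    return [f async for f in upper_case_service(iterate_frames(["a", "b"]))]
```

Because every stage both accepts and returns an async stream, stages compose by simple nesting, which is the chaining property described above.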
|
||||
|
||||
AI Services also have a `.run_to_queue` method. This method is not an AsyncIterable, but instead sends processed QueueFrames to a queue. This makes it easy to send the output of an AI Service to the Transport’s `send_queue`.
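A minimal sketch of what a `run_to_queue`-style helper might do, assuming the argument order used in the examples in this README (queue first, frame source second, optional end-of-stream flag last). The `END_STREAM` sentinel and `fake_service` are stand-ins invented for illustration:

```python
import asyncio
from typing import Any, AsyncIterable

END_STREAM = object()  # hypothetical sentinel standing in for an END_STREAM frame


async def run_to_queue(
    queue: asyncio.Queue,
    frames: AsyncIterable[Any],
    add_end_of_stream: bool = False,
) -> None:
    """Drain an async stream of frames into a queue, optionally marking the end."""
    async for frame in frames:
        await queue.put(frame)
    if add_end_of_stream:
        await queue.put(END_STREAM)


async def fake_service():
    # Stand-in for the output of a service's .run() method.
    for chunk in (b"audio-1", b"audio-2"):
        yield chunk


async def demo() -> list:
    send_queue: asyncio.Queue = asyncio.Queue()
    await run_to_queue(send_queue, fake_service(), add_end_of_stream=True)
    return [send_queue.get_nowait() for _ in range(send_queue.qsize())]
```

Unlike `.run`, this helper returns nothing to iterate over; its only effect is the frames it leaves on the queue.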
|
||||
|
||||
AI Services also define convenience functions that let you bypass creating QueueFrames for some simple cases (e.g. using the TTS service to convert a string to audio output and send that audio to the transport’s `send_queue`). See below for examples.
|
||||
## Examples
|
||||
### Say Something
|
||||
The base TTS AI service exposes a `.say` method. After creating a transport and TTS service, you can use this method like so:
|
||||
```
|
||||
transport = DailyTransportService(...)
|
||||
tts = AzureTTSService()
|
||||
await tts.say("hello world", transport.send_queue)
|
||||
```
|
||||
This will call the TTS service to render the text to audio frames, then put the audio frames on the transport’s send queue. The transport will then send those frames along to the WebRTC session.
|
||||
|
||||
### Speak an LLM response
|
||||
Given a system prompt contained in a `messages` array, you can emit the LLM’s response as audio with a chain like this:
|
||||
```
|
||||
transport = DailyTransportService(...) # setup parameters omitted
|
||||
tts = AzureTTSService()
|
||||
llm = AzureLLMService()
|
||||
messages = [...] # system prompt omitted for brevity
|
||||
|
||||
await tts.run_to_queue(
    transport.send_queue,
    llm.run([QueueFrame(FrameType.LLM_MESSAGE, messages)])
)
|
||||
```
|
||||
In this code, the LLM service object sends the messages to Azure’s OpenAI implementation, which streams chunks back asynchronously. Those chunks are aggregated by the TTS Service to ensure the best audio response (TTS works best when it gets complete sentences, so it can inflect correctly), then sent to Azure’s TTS service, converted to audio frames, and sent to the WebRTC session via the Daily transport.
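The sentence aggregation step can be sketched roughly as follows. This is an illustrative approximation, not the TTS service's actual logic: `sentences_from_chunks` and `fake_llm_stream` are hypothetical names, and the sentence test is a simple check for trailing `.`, `!` or `?`.

```python
import asyncio
import re
from typing import AsyncIterable, AsyncIterator

SENTENCE_END = re.compile(r"[.!?]\s*$")


async def sentences_from_chunks(chunks: AsyncIterable[str]) -> AsyncIterator[str]:
    """Buffer streamed text and yield it one sentence-like unit at a time."""
    buffer = ""
    async for chunk in chunks:
        buffer += chunk
        if SENTENCE_END.search(buffer):  # the buffer currently ends a sentence
            yield buffer.strip()
            buffer = ""
    if buffer.strip():  # flush any trailing partial sentence
        yield buffer.strip()


async def fake_llm_stream():
    # Stand-in for streamed LLM output arriving in arbitrary pieces.
    for chunk in ["Hel", "lo there.", " How are", " you?", " Bye"]:
        yield chunk


async def demo() -> list:
    return [s async for s in sentences_from_chunks(fake_llm_stream())]
```

Note that this sketch only checks the end of the buffer, so a chunk that contains a sentence boundary in its middle is passed along together with the text that follows it.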
|
||||
|
||||
### Pre-cache an LLM response
|
||||
Sometimes LLMs can be slower than we’d like for natural-feeling communication. Here’s an example where we take advantage of the time it takes to speak some pre-defined text to get a head start on the LLM response:
|
||||
|
||||
(TK link to 04- sample)
|
||||
|
||||
In this sample, while we are joining the call, we set up a buffer queue to receive the audio frames from the LLM response and start an asynchronous task to fill this buffer:
|
||||
```
|
||||
buffer_queue = asyncio.Queue()
llm_response_task = asyncio.create_task(
    elevenlabs_tts.run_to_queue(
        buffer_queue,
        llm.run([QueueFrame(FrameType.LLM_MESSAGE, messages)]),
        True,
    )
)
|
||||
```
|
||||
|
||||
Then, when we’ve joined the call, we speak the static text:
|
||||
```
|
||||
await azure_tts.say("My friend...", transport.send_queue)
|
||||
```
|
||||
|
||||
As that text is being spoken, the asynchronous LLM task continues in the background. When the text is done, we pull the frames off the buffer queue and put them in the transport’s `send_queue`:
|
||||
```
|
||||
async def buffer_to_send_queue():
    while True:
        frame = await buffer_queue.get()
        await transport.send_queue.put(frame)
        buffer_queue.task_done()
        if frame.frame_type == FrameType.END_STREAM:
            break

await asyncio.gather(llm_response_task, buffer_to_send_queue())
|
||||
|
||||
```
|
||||
|
||||
One thing to note here is the last parameter to `run_to_queue` in the first code block above: this causes the `run_to_queue` method to send an `END_STREAM` frame when it’s done rendering. This lets us know when to stop our `buffer_to_send_queue` task above.
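The whole pre-caching pattern can be tried without any real services. In the sketch below, `slow_llm_audio`, `say`, and `drain` are hypothetical stand-ins for the LLM/TTS chain, `tts.say`, and `buffer_to_send_queue`, and the `END_STREAM` sentinel replaces a real end-of-stream frame:

```python
import asyncio

END_STREAM = object()  # hypothetical stand-in for an END_STREAM frame


async def slow_llm_audio(buffer: asyncio.Queue) -> None:
    # Stand-in for the LLM + TTS chain: produces frames slowly, then marks the end.
    for i in range(3):
        await asyncio.sleep(0.01)
        await buffer.put(f"llm-audio-{i}")
    await buffer.put(END_STREAM)


async def say(text: str, send_queue: asyncio.Queue) -> None:
    # Stand-in for tts.say(): "speaking" takes time, during which the LLM task runs.
    await asyncio.sleep(0.02)
    await send_queue.put(f"static:{text}")


async def drain(buffer: asyncio.Queue, send_queue: asyncio.Queue) -> None:
    # Forward buffered frames, including the end marker, then stop.
    while True:
        frame = await buffer.get()
        await send_queue.put(frame)
        buffer.task_done()
        if frame is END_STREAM:
            break


async def demo() -> list:
    buffer, send_queue = asyncio.Queue(), asyncio.Queue()
    llm_task = asyncio.create_task(slow_llm_audio(buffer))  # head start
    await say("My friend...", send_queue)
    await asyncio.gather(llm_task, drain(buffer, send_queue))
    return [send_queue.get_nowait() for _ in range(send_queue.qsize())]
```

The static text always reaches the send queue first because draining only begins after `say` returns, while the background task has been filling the buffer in the meantime.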
|
||||
|
||||
@@ -16,7 +16,8 @@ dependencies = [
|
||||
"pyht",
|
||||
"opentelemetry-sdk",
|
||||
"aiohttp",
|
||||
"fal"
|
||||
"fal",
|
||||
"faster_whisper"
|
||||
]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
autopep8==2.0.4
|
||||
build==1.0.3
|
||||
packaging==23.2
|
||||
pyproject_hooks==1.0.0
|
||||
pyproject_hooks==1.0.0
|
||||
@@ -1,347 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from queue import Queue, PriorityQueue, Empty
|
||||
from threading import Event, Semaphore, Thread
|
||||
from typing import Any, Generator, Iterator, Optional, Type
|
||||
|
||||
from dailyai.queue_frame import QueueFrame, FrameType
|
||||
from dailyai.message_handler.message_handler import MessageHandler
|
||||
from dailyai.services.ai_services import AIServiceConfig
|
||||
|
||||
class AsyncProcessorState:
|
||||
# Setting class variables, other synchronous activities
|
||||
INIT = 0
|
||||
|
||||
# Making asynchronous requests to LLM and other services to render response
|
||||
PREPARING = 1
|
||||
|
||||
# Ready to start presenting to user (but may not have all data yet)
|
||||
READY = 2
|
||||
|
||||
# Playing response
|
||||
PLAYING = 3
|
||||
|
||||
# An interrupt has been requested and the response is shutting down in-flight processing
|
||||
INTERRUPTING = 4
|
||||
|
||||
# An interrupt has been requested and the response is finished stopping in-flight processing
|
||||
INTERRUPTED = 5
|
||||
|
||||
# Response has been played or interrupted
|
||||
DONE = 6
|
||||
|
||||
# Response is being finalized (updating records of speech, updating LLM context, etc.)
|
||||
FINALIZING = 7
|
||||
|
||||
# Response is complete. This could mean that everything is updated, or that the response
|
||||
# was interrupted.
|
||||
FINALIZED = 8
|
||||
|
||||
state_transitions = {
|
||||
INIT: [PREPARING, INTERRUPTING],
|
||||
PREPARING: [READY, INTERRUPTING],
|
||||
READY: [PLAYING, INTERRUPTING],
|
||||
PLAYING: [DONE, INTERRUPTING],
|
||||
INTERRUPTING: [INTERRUPTED],
|
||||
INTERRUPTED: [DONE],
|
||||
DONE: [FINALIZING],
|
||||
FINALIZING: [FINALIZED],
|
||||
FINALIZED: [FINALIZED],
|
||||
}
|
||||
|
||||
|
||||
@dataclass(order=True)
|
||||
class StateTransitionItem:
|
||||
state: int
|
||||
evt: Event = field(compare=False)
|
||||
|
||||
class AsyncProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
services: AIServiceConfig
|
||||
) -> None:
|
||||
self.state = AsyncProcessorState.INIT
|
||||
self.prepare_thread = None
|
||||
self.play_thread = None
|
||||
self.finalize_thread = None
|
||||
|
||||
self.services: AIServiceConfig = services
|
||||
|
||||
self.state_transition_semaphore = Semaphore()
|
||||
self.waiting_for_state_changes = PriorityQueue()
|
||||
self.state_queue = Queue()
|
||||
|
||||
self.state_change_callbacks = defaultdict(list)
|
||||
|
||||
self.was_interrupted = False
|
||||
|
||||
self.logger: logging.Logger = logging.getLogger("dailyai")
|
||||
|
||||
def set_state(self, state: int) -> None:
|
||||
if state in AsyncProcessorState.state_transitions[self.state]:
|
||||
self.state_transition_semaphore.acquire()
|
||||
|
||||
self.state: int = state
|
||||
self.state_transition_semaphore.release()
|
||||
|
||||
# wake up any threads waiting for this state transition
|
||||
try:
|
||||
while True:
|
||||
waiter = self.waiting_for_state_changes.get_nowait()
|
||||
if waiter.state <= state:
|
||||
waiter.evt.set()
|
||||
else:
|
||||
self.waiting_for_state_changes.put(waiter)
|
||||
break
|
||||
except Empty:
|
||||
pass
|
||||
|
||||
# make all the callbacks for this state
|
||||
for callback in self.state_change_callbacks[state]:
|
||||
callback(self)
|
||||
else:
|
||||
self.logger.error(
|
||||
f"Invalid state transition from {self.state} to {state} in {self.__class__.__name__}"
|
||||
)
|
||||
raise Exception(f"Invalid state transition from {self.state} to {state}")
|
||||
|
||||
#
|
||||
# This is used for state transitions that could be blocked by an interruption.
|
||||
# If we are interrupted, we silently fail this call. Use only if you know that
|
||||
# this state transition should fail if the processor has been interrupted.
|
||||
#
|
||||
|
||||
def maybe_set_state(self, state: int) -> bool:
|
||||
if state in AsyncProcessorState.state_transitions[self.state]:
|
||||
self.set_state(state)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def wait_for_state_transition(self, state: int) -> None:
|
||||
if self.state >= state:
|
||||
return
|
||||
|
||||
self.state_transition_semaphore.acquire()
|
||||
|
||||
evt = Event()
|
||||
self.waiting_for_state_changes.put(StateTransitionItem(state, evt))
|
||||
self.state_transition_semaphore.release()
|
||||
result = evt.wait(120.0)
|
||||
if not result:
|
||||
self.logger.error(
|
||||
f"Timed out waiting for state transition to {state} from {self.state}"
|
||||
)
|
||||
|
||||
def set_state_callback(self, state: int, callback: callable) -> None:
|
||||
self.state_change_callbacks[state].append(callback)
|
||||
|
||||
def prepare(self) -> None:
|
||||
self.prepare_thread = Thread(target=self.async_prepare, daemon=True)
|
||||
self.prepare_thread.start()
|
||||
self.wait_for_state_transition(AsyncProcessorState.READY)
|
||||
|
||||
def play(self) -> None:
|
||||
self.wait_for_state_transition(AsyncProcessorState.READY)
|
||||
self.play_thread = Thread(target=self.async_play, daemon=True)
|
||||
self.play_thread.start()
|
||||
self.wait_for_state_transition(AsyncProcessorState.PLAYING)
|
||||
|
||||
def finalize(self) -> None:
|
||||
# don't finalize until we're done playing.
|
||||
self.wait_for_state_transition(AsyncProcessorState.DONE)
|
||||
self.set_state(AsyncProcessorState.FINALIZING)
|
||||
self.do_finalization()
|
||||
self.set_state(AsyncProcessorState.FINALIZED)
|
||||
|
||||
def interrupt(self) -> None:
|
||||
# nothing to interrupt if we're already finalizing or finalized, no-op
|
||||
if self.state in [
|
||||
AsyncProcessorState.FINALIZING,
|
||||
AsyncProcessorState.FINALIZED,
|
||||
]:
|
||||
return
|
||||
|
||||
self.set_state(AsyncProcessorState.INTERRUPTING)
|
||||
self.was_interrupted = True
|
||||
self.do_interruption()
|
||||
self.set_state(AsyncProcessorState.INTERRUPTED)
|
||||
self.set_state(AsyncProcessorState.DONE)
|
||||
|
||||
def async_play(self) -> None:
|
||||
self.logger.info(f"Starting to play")
|
||||
if self.maybe_set_state(AsyncProcessorState.PLAYING):
|
||||
self.do_play()
|
||||
self.maybe_set_state(AsyncProcessorState.DONE)
|
||||
|
||||
def async_prepare(self) -> None:
|
||||
self.set_state(AsyncProcessorState.PREPARING)
|
||||
self.start_preparation()
|
||||
self.set_state(AsyncProcessorState.READY)
|
||||
self.continue_preparation()
|
||||
self.logger.info(f"Preparation done for {self.__class__.__name__}")
|
||||
self.preparation_done()
|
||||
|
||||
def start_preparation(self) -> None:
|
||||
pass
|
||||
|
||||
def continue_preparation(self) -> None:
|
||||
pass
|
||||
|
||||
def preparation_done(self):
|
||||
pass
|
||||
|
||||
def get_preparation_iterator(self) -> Iterator:
|
||||
yield None
|
||||
|
||||
def process_chunk(self, chunk) -> None:
|
||||
pass
|
||||
|
||||
def do_interruption(self) -> None:
|
||||
pass
|
||||
|
||||
def do_play(self) -> None:
|
||||
pass
|
||||
|
||||
def do_finalization(self) -> None:
|
||||
pass
|
||||
|
||||
# A common class for responses that use a message queue and
|
||||
# an output queue.
|
||||
|
||||
class OrchestratorResponse(AsyncProcessor):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
services,
|
||||
message_handler,
|
||||
output_queue,
|
||||
) -> None:
|
||||
super().__init__(services)
|
||||
|
||||
self.message_handler: MessageHandler = message_handler
|
||||
self.output_queue: Queue = output_queue
|
||||
|
||||
|
||||
class LLMResponse(OrchestratorResponse):
|
||||
def __init__(
|
||||
self,
|
||||
services,
|
||||
message_handler,
|
||||
output_queue,
|
||||
) -> None:
|
||||
super().__init__(services, message_handler, output_queue)
|
||||
|
||||
self.has_sent_first_frame = False
|
||||
|
||||
self.chunks_in_preparation = Queue()
|
||||
|
||||
self.llm_responses: list[str] = []
|
||||
|
||||
def get_preparation_iterator(self) -> Iterator:
|
||||
messages_for_llm = self.message_handler.get_llm_messages()
|
||||
self.logger.debug(f"Messages for llm: {json.dumps(messages_for_llm, indent=2)}")
|
||||
return self.clauses_from_chunks(
|
||||
self.services.llm.run_llm_async(messages_for_llm)
|
||||
)
|
||||
|
||||
def clauses_from_chunks(self, chunks) -> Iterator:
|
||||
out = ""
|
||||
for chunk in chunks:
|
||||
if self.state not in [
|
||||
AsyncProcessorState.READY,
|
||||
AsyncProcessorState.PLAYING,
|
||||
]:
|
||||
break
|
||||
|
||||
out += chunk
|
||||
|
||||
if re.match(r"^.*[.!?]$", out): # it looks like a sentence
|
||||
yield out.strip()
|
||||
out = ""
|
||||
|
||||
if out.strip():
|
||||
yield out.strip()
|
||||
|
||||
def get_frames_from_tts_response(self, audio_frame) -> list[QueueFrame]:
|
||||
return [QueueFrame(FrameType.AUDIO, audio_frame)]
|
||||
|
||||
def get_frames_from_chunk(self, chunk) -> Generator[list[QueueFrame], Any, None]:
|
||||
for audio_frame in self.services.tts.run_tts(chunk):
|
||||
yield self.get_frames_from_tts_response(audio_frame)
|
||||
|
||||
def start_preparation(self) -> None:
|
||||
self.preparation_iterator = self.get_preparation_iterator()
|
||||
|
||||
def continue_preparation(self) -> None:
|
||||
for chunk in self.preparation_iterator:
|
||||
if self.state not in [
|
||||
AsyncProcessorState.READY,
|
||||
AsyncProcessorState.PLAYING,
|
||||
]:
|
||||
break
|
||||
|
||||
self.process_chunk(chunk)
|
||||
|
||||
def process_chunk(self, chunk) -> None:
|
||||
self.chunks_in_preparation.put((chunk, self.get_frames_from_chunk(chunk)))
|
||||
|
||||
def preparation_done(self):
|
||||
self.chunks_in_preparation.put((None, None))
|
||||
|
||||
def do_play(self) -> None:
|
||||
while True:
|
||||
if self.state not in [
|
||||
AsyncProcessorState.READY,
|
||||
AsyncProcessorState.PLAYING,
|
||||
]:
|
||||
break
|
||||
prepared_chunk = self.chunks_in_preparation.get()
|
||||
if prepared_chunk[0] is None:
|
||||
return
|
||||
|
||||
self.play_prepared_chunk(prepared_chunk)
|
||||
|
||||
def play_prepared_chunk(self, prepared_chunk) -> None:
|
||||
chunk, tts_generator = prepared_chunk
|
||||
for frames in tts_generator:
|
||||
if self.state not in [
|
||||
AsyncProcessorState.READY,
|
||||
AsyncProcessorState.PLAYING,
|
||||
]:
|
||||
break
|
||||
|
||||
if not self.has_sent_first_frame:
|
||||
self.output_queue.put(QueueFrame(FrameType.START_STREAM, None))
|
||||
self.has_sent_first_frame = True
|
||||
|
||||
for frame in frames:
|
||||
self.output_queue.put(frame)
|
||||
|
||||
self.output_queue.join()
|
||||
self.llm_responses.append(chunk)
|
||||
|
||||
def do_finalization(self) -> None:
|
||||
self.message_handler.add_assistant_messages(self.llm_responses)
|
||||
|
||||
def do_interruption(self) -> None:
|
||||
self.chunks_in_preparation.put((None, None))
|
||||
|
||||
if self.prepare_thread and self.prepare_thread.is_alive():
|
||||
self.prepare_thread.join()
|
||||
|
||||
if self.play_thread and self.play_thread.is_alive():
|
||||
self.play_thread.join()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConversationProcessorCollection:
|
||||
introduction: Optional[Type[OrchestratorResponse]] = None
|
||||
waiting: Optional[Type[OrchestratorResponse]] = None
|
||||
response: Optional[Type[OrchestratorResponse]] = None
|
||||
goodbye: Optional[Type[OrchestratorResponse]] = None
|
||||
76
src/dailyai/conversation_wrappers.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import functools
|
||||
from typing import AsyncGenerator, Awaitable, Callable
|
||||
from dailyai.queue_aggregators import LLMContextAggregator
|
||||
from dailyai.queue_frame import EndStreamQueueFrame, QueueFrame, TranscriptionQueueFrame
|
||||
|
||||
class InterruptibleConversationWrapper:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
frame_generator: Callable[[], AsyncGenerator[QueueFrame, None]],
|
||||
runner: Callable[
|
||||
[str, LLMContextAggregator, LLMContextAggregator], Awaitable[None]
|
||||
],
|
||||
interrupt: Callable[[], None],
|
||||
my_participant_id: str|None,
|
||||
llm_messages: list[dict[str, str]],
|
||||
llm_context_aggregator_in=LLMContextAggregator,
|
||||
llm_context_aggregator_out=LLMContextAggregator,
|
||||
delay_before_speech_seconds: float = 1.0,
|
||||
):
|
||||
self._frame_generator: Callable[[], AsyncGenerator[QueueFrame, None]] = frame_generator
|
||||
self._runner: Callable[
|
||||
[str, LLMContextAggregator, LLMContextAggregator], Awaitable[None]
|
||||
] = runner
|
||||
self._interrupt: Callable[[], None] = interrupt
|
||||
self._my_participant_id = my_participant_id
|
||||
self._messages: list[dict[str, str]] = llm_messages
|
||||
self._delay_before_speech_seconds = delay_before_speech_seconds
|
||||
self._llm_context_aggregator_in = llm_context_aggregator_in
|
||||
self._llm_context_aggregator_out = llm_context_aggregator_out
|
||||
|
||||
self._current_phrase = ""
|
||||
|
||||
def update_messages(self, new_messages: list[dict[str, str]], task: asyncio.Task | None):
|
||||
if task:
|
||||
if not task.cancelled():
|
||||
self._current_phrase = ""
|
||||
self._messages = new_messages
|
||||
|
||||
async def speak_after_delay(self, user_speech, messages):
|
||||
await asyncio.sleep(self._delay_before_speech_seconds)
|
||||
tma_in = self._llm_context_aggregator_in(
|
||||
messages, "user", self._my_participant_id, False
|
||||
)
|
||||
tma_out = self._llm_context_aggregator_out(
|
||||
messages, "assistant", self._my_participant_id
|
||||
)
|
||||
|
||||
await self._runner(user_speech, tma_in, tma_out)
|
||||
|
||||
async def run_conversation(self):
|
||||
current_response_task = None
|
||||
|
||||
async for frame in self._frame_generator():
|
||||
if isinstance(frame, EndStreamQueueFrame):
|
||||
break
|
||||
elif not isinstance(frame, TranscriptionQueueFrame):
|
||||
continue
|
||||
|
||||
if frame.participantId == self._my_participant_id:
|
||||
continue
|
||||
|
||||
if current_response_task:
|
||||
current_response_task.cancel()
|
||||
self._interrupt()
|
||||
|
||||
self._current_phrase += " " + frame.text
|
||||
current_llm_messages = copy.deepcopy(self._messages)
|
||||
current_response_task = asyncio.create_task(
|
||||
self.speak_after_delay(self._current_phrase, current_llm_messages)
|
||||
)
|
||||
current_response_task.add_done_callback(
|
||||
functools.partial(self.update_messages, current_llm_messages)
|
||||
)
|
||||
@@ -1,127 +0,0 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
from dataclasses import dataclass
|
||||
from queue import Queue, Empty
|
||||
from threading import Thread
|
||||
|
||||
from dailyai.storage.search import SearchIndexer
|
||||
from dailyai.services.ai_services import AIServiceConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message:
|
||||
type: str
|
||||
timestamp: float
|
||||
message: str
|
||||
|
||||
|
||||
class MessageHandler:
|
||||
def __init__(self, intro):
|
||||
self.messages: list[Message] = [Message("system", time.time(), intro)]
|
||||
self.last_user_message_idx:int | None = None
|
||||
self.finalized_user_message_idx: int | None = None
|
||||
|
||||
def add_user_message(self, message) -> None:
|
||||
if self.last_user_message_idx is not None and self.last_user_message_idx != self.finalized_user_message_idx:
|
||||
previous_message: str = self.messages[self.last_user_message_idx].message
|
||||
self.messages[self.last_user_message_idx] = Message(
|
||||
"user", time.time(), ' '.join([previous_message, message])
|
||||
)
|
||||
self.messages = self.messages[: self.last_user_message_idx + 1]
|
||||
else:
|
||||
self.messages.append(Message("user", time.time(), message))
|
||||
|
||||
self.last_user_message_idx = len(self.messages) - 1
|
||||
|
||||
def add_assistant_message(self, message) -> None:
|
||||
if self.messages[-1].type == "assistant":
|
||||
self.messages[-1].message += " " + message
|
||||
else:
|
||||
self.messages.append(Message("assistant", time.time(), message))
|
||||
|
||||
def add_assistant_messages(self, messages) -> None:
|
||||
self.messages.append(Message("assistant", time.time(), " ".join(messages)))
|
||||
|
||||
def get_llm_messages(self) -> list[dict[str, str]]:
|
||||
return [{"role": m.type, "content": m.message} for m in self.messages]
|
||||
|
||||
def finalize_user_message(self) -> None:
|
||||
self.finalized_user_message_idx = self.last_user_message_idx
|
||||
|
||||
def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
class IndexingMessageHandler(MessageHandler):
|
||||
def __init__(
|
||||
self, intro, services: AIServiceConfig, indexer: SearchIndexer
|
||||
) -> None:
|
||||
super().__init__(intro)
|
||||
self.services = services
|
||||
|
||||
self.search_indexer = indexer
|
||||
|
||||
self.last_written_idx = 0
|
||||
self.storage_message_queue = Queue()
|
||||
|
||||
self.index_writer_thread = Thread(target=self.storage_writer, daemon=True)
|
||||
self.index_writer_thread.start()
|
||||
|
||||
self.logger = logging.getLogger("dailyai")
|
||||
|
||||
def shutdown(self):
|
||||
self.finalize_user_message()
|
||||
self.storage_message_queue.put(None)
|
||||
self.index_writer_thread.join()
|
||||
|
||||
def storage_writer(self) -> None:
|
||||
while True:
|
||||
try:
|
||||
message_idx = self.storage_message_queue.get()
|
||||
self.storage_message_queue.task_done()
|
||||
|
||||
if message_idx is None:
|
||||
return
|
||||
|
||||
if message_idx <= self.last_written_idx:
|
||||
continue
|
||||
|
||||
self.last_written_idx = message_idx
|
||||
|
||||
message = self.messages[message_idx]
|
||||
content = message.message
|
||||
if message.type == "user":
|
||||
content = self.cleanup_user_message(content)
|
||||
|
||||
# sometimes the LLM returns a string wrapped in quotes and sometimes it doesn't.
|
||||
# if it didn't, wrap it in quotes
|
||||
if content[0] != '"':
|
||||
content = '"' + content + '"'
|
||||
|
||||
self.search_indexer.index_text(content)
|
||||
except Empty:
|
||||
pass
|
||||
|
||||
def cleanup_user_message(self, user_message) -> str:
|
||||
return user_message
|
||||
|
||||
def finalize_user_message(self):
|
||||
super().finalize_user_message()
|
||||
self.write_messages_to_storage()
|
||||
|
||||
def write_messages_to_storage(self):
|
||||
if self.finalized_user_message_idx is None:
|
||||
return
|
||||
|
||||
for idx in range(self.last_written_idx, len(self.messages)):
|
||||
self.logger.info(
|
||||
f"Writing to storage: {self.messages[idx].type} {self.messages[idx].message}"
|
||||
)
|
||||
if (
|
||||
self.messages[idx].type == "user"
|
||||
and idx > self.finalized_user_message_idx
|
||||
):
|
||||
break
|
||||
|
||||
if self.messages[idx].type != "system":
|
||||
self.storage_message_queue.put(idx)
|
||||
@@ -1,409 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import wave
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from queue import Queue, Empty
|
||||
from opentelemetry import trace, context
|
||||
|
||||
from dailyai.async_processor.async_processor import (
|
||||
AsyncProcessor,
|
||||
AsyncProcessorState,
|
||||
ConversationProcessorCollection,
|
||||
OrchestratorResponse,
|
||||
LLMResponse,
|
||||
)
|
||||
from dailyai.queue_frame import QueueFrame, FrameType
|
||||
from dailyai.services.ai_services import AIServiceConfig
|
||||
from dailyai.message_handler.message_handler import MessageHandler
|
||||
|
||||
from threading import Thread, Semaphore, Event, Timer
|
||||
|
||||
from opentelemetry import context
|
||||
from opentelemetry.context.context import Context
|
||||
|
||||
from daily import (
|
||||
EventHandler,
|
||||
CallClient,
|
||||
Daily,
|
||||
VirtualCameraDevice,
|
||||
VirtualMicrophoneDevice,
|
||||
VirtualSpeakerDevice,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrchestratorConfig:
|
||||
room_url: str
|
||||
token: str
|
||||
bot_name: str
|
||||
expiration: float
|
||||
|
||||
# Note that we use this as a default parameter value in the Orchestrator
|
||||
# constructor. The dataclass is defined with frozen=True, so this should
|
||||
# be safe.
|
||||
default_conversation_collection = ConversationProcessorCollection(
|
||||
introduction=LLMResponse,
|
||||
waiting=None,
|
||||
response=LLMResponse,
|
||||
goodbye=None,
|
||||
)
|
||||
|
||||
|
||||
class Orchestrator(EventHandler):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
daily_config: OrchestratorConfig,
|
||||
ai_service_config: AIServiceConfig,
|
||||
message_handler: MessageHandler,
|
||||
conversation_processors: ConversationProcessorCollection = default_conversation_collection,
|
||||
tracer=None,
|
||||
):
|
||||
self.bot_name: str = daily_config.bot_name
|
||||
self.room_url: str = daily_config.room_url
|
||||
self.token: str = daily_config.token
|
||||
self.expiration: float = daily_config.expiration
|
||||
|
||||
self.logger: logging.Logger = logging.getLogger("dailyai")
|
||||
self.tracer = tracer or trace.get_tracer("orchestrator")
|
||||
|
||||
self.ctx: Context = context.get_current()
|
||||
|
||||
self.transcription = ""
|
||||
self.last_fragment_at = None
|
||||
self.talked_at = None
|
||||
self.paused_at = None
|
||||
|
||||
self.logger.info(f"Creating Response for introductions")
|
||||
self.services: AIServiceConfig = ai_service_config
|
||||
self.output_queue = Queue()
|
||||
self.is_interrupted = Event()
|
||||
self.stop_threads = Event()
|
||||
self.story_started = False
|
||||
|
||||
self.message_handler = message_handler
|
||||
self.conversation_processors: ConversationProcessorCollection = conversation_processors
|
||||
|
||||
if conversation_processors.introduction is not None:
|
||||
intro = conversation_processors.introduction(
|
||||
services=self.services, message_handler=self.message_handler, output_queue=self.output_queue
|
||||
)
|
||||
intro.prepare()
|
||||
intro.set_state_callback(AsyncProcessorState.DONE, self.on_intro_played)
|
||||
intro.set_state_callback(AsyncProcessorState.FINALIZED, self.on_intro_finished)
|
||||
self.logger.info(f"Introduction is preparing")
|
||||
|
||||
self.current_response: AsyncProcessor = intro
|
||||
self.can_interrupt = False
|
||||
# self.response_event.set()
|
||||
self.response_semaphore = Semaphore()
|
||||
|
||||
self.speech_timeout = None
|
||||
self.interrupt_time = None
|
||||
|
||||
self.logger.info("Configuring daily")
|
||||
self.configure_daily()
|
||||
|
||||
def configure_daily(self):
|
||||
Daily.init()
|
||||
self.client = CallClient(event_handler=self)
|
||||
|
||||
self.logger.info(f"Mic sample rate: {self.services.tts.get_mic_sample_rate()}")
|
||||
self.mic: VirtualMicrophoneDevice = Daily.create_microphone_device(
|
||||
"mic", sample_rate=self.services.tts.get_mic_sample_rate(), channels=1
|
||||
)
|
||||
self.speaker: VirtualSpeakerDevice = Daily.create_speaker_device(
|
||||
"speaker", sample_rate=16000, channels=1
|
||||
)
|
||||
self.camera: VirtualCameraDevice = Daily.create_camera_device(
|
||||
"camera", width=720, height=1280, color_format="RGB"
|
||||
)
|
||||
|
||||
Daily.select_speaker_device("speaker")
|
||||
|
||||
self.client.set_user_name(self.bot_name)
|
||||
self.client.join(self.room_url, self.token, completion=self.call_joined)
|
||||
|
||||
self.client.update_inputs(
|
||||
{
|
||||
"camera": {
|
||||
"isEnabled": True,
|
||||
"settings": {
|
||||
"deviceId": "camera",
|
||||
},
|
||||
},
|
||||
"microphone": {
|
||||
"isEnabled": True,
|
||||
"settings": {
|
||||
"deviceId": "mic",
|
||||
"customConstraints": {
|
||||
"autoGainControl": {"exact": False},
|
||||
"echoCancellation": {"exact": False},
|
||||
"noiseSuppression": {"exact": False},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
self.client.update_publishing(
|
||||
{
|
||||
"camera": {
|
||||
"sendSettings": {
|
||||
"maxQuality": "low",
|
||||
"encodings": {
|
||||
"low": {
|
||||
"maxBitrate": 250000,
|
||||
"scaleResolutionDownBy": 1.333,
|
||||
"maxFramerate": 8,
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
self.my_participant_id = self.client.participants()["local"]["id"]
|
||||
|
||||
def start(self) -> None:
|
||||
# TODO: this loop could, I think, be replaced with a timer and an event
|
||||
self.participant_left = False
|
||||
|
||||
try:
|
||||
participant_count: int = len(self.client.participants())
|
||||
self.logger.info(f"{participant_count} participants in room")
|
||||
while time.time() < self.expiration and not self.participant_left:
|
||||
# all handling of incoming transcriptions happens in on_transcription_message
|
||||
time.sleep(1)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Exception {e}")
|
||||
finally:
|
||||
self.client.leave()
|
||||
|
||||
def stop(self):
|
||||
self.logger.info("Stop current response")
|
||||
if self.current_response:
|
||||
if self.current_response.state < AsyncProcessorState.INTERRUPTED:
|
||||
self.current_response.interrupt()
|
||||
|
||||
self.logger.info("Wait for state transition")
|
||||
self.current_response.wait_for_state_transition(AsyncProcessorState.FINALIZED)
|
||||
|
||||
self.stop_threads.set()
|
||||
self.camera_thread.join()
|
||||
self.logger.info("Camera thread stopped")
|
||||
|
||||
self.logger.info("Put stop in output queue")
|
||||
self.output_queue.put(QueueFrame(FrameType.END_STREAM, None))
|
||||
|
||||
self.frame_consumer_thread.join()
|
||||
self.logger.info("Orchestrator stopped.")
|
||||
|
||||
def on_intro_played(self, intro):
|
||||
self.logger.info("Introduction has played")
|
||||
self.can_interrupt = True
|
||||
intro.finalize()
|
||||
|
||||
def on_intro_finished(self, intro):
|
||||
self.logger.info("Introduction has finished")
|
||||
waiting = self.conversation_processors.waiting(self.services, self.message_handler, self.output_queue)
|
||||
waiting.prepare()
|
||||
waiting.play()
|
||||
|
||||
def on_response_played(self, response):
|
||||
response.finalize()
|
||||
|
||||
def on_response_finished(self, response):
|
||||
if not response.was_interrupted:
|
||||
self.message_handler.finalize_user_message()
|
||||
|
||||
def call_joined(self, join_data, client_error):
|
||||
self.logger.info(f"Call_joined: {join_data}, {client_error}")
|
||||
self.client.start_transcription(
|
||||
{
|
||||
"language": "en",
|
||||
"tier": "nova",
|
||||
"model": "2-conversationalai",
|
||||
"profanity_filter": True,
|
||||
"redact": False,
|
||||
"extra": {
|
||||
"endpointing": True,
|
||||
"punctuate": False,
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
def on_participant_joined(self, participant):
|
||||
with self.tracer.start_as_current_span("on_participant_joined", context=self.ctx):
|
||||
self.logger.info(f"on_participant_joined: {participant}")
|
||||
|
||||
# TODO: figure out the architecture to get the story id to the client
|
||||
# self.client.send_app_message({"event": "story-id", "storyID": self.story_id})
|
||||
time.sleep(2)
|
||||
|
||||
if not self.story_started:
|
||||
self.action()
|
||||
self.story_started = True
|
||||
|
||||
def on_participant_left(self, participant, reason):
|
||||
self.logger.info(f"Participant {participant} left")
|
||||
if len(self.client.participants()) < 2:
|
||||
self.participant_left = True
|
||||
|
||||
def on_app_message(self, message, sender):
|
||||
with self.tracer.start_as_current_span("on_app_message", context=self.ctx):
|
||||
self.logger.info(f"on_app_message {message} from {sender}")
|
||||
if "isSpeaking" in message and message["isSpeaking"] == True:
|
||||
self.handle_user_started_talking()
|
||||
|
||||
if "isSpeaking" in message and message["isSpeaking"] == False:
|
||||
self.handle_user_stopped_talking()
|
||||
|
||||
def on_transcription_message(self, message):
|
||||
with self.tracer.start_as_current_span("on_transcription_message", context=self.ctx):
|
||||
if message["session_id"] != self.my_participant_id:
|
||||
self.handle_transcription_fragment(message['text'])
|
||||
|
||||
def on_transcription_stopped(self, stopped_by, stopped_by_error):
|
||||
self.logger.info(f"Transcription stopped {stopped_by}, {stopped_by_error}")
|
||||
|
||||
def on_transcription_error(self, message):
|
||||
self.logger.error(f"Transcription error {message}")
|
||||
|
||||
def on_transcription_started(self, status):
|
||||
self.logger.info(f"Transcription started {status}")
|
||||
|
||||
def set_image(self, image: bytes):
|
||||
self.image: bytes | None = image
|
||||
|
||||
def run_camera(self):
|
||||
try:
|
||||
while not self.stop_threads.is_set():
|
||||
if self.image:
|
||||
self.camera.write_frame(self.image)
|
||||
|
||||
time.sleep(1.0 / 8.0) # 8 fps
|
||||
except Exception as e:
|
||||
self.logger.error(f"Exception {e} in camera thread.")
|
||||
|
||||
def handle_user_started_talking(self):
|
||||
# TODO: allow configuration of the timer timeout
|
||||
self.logger.info("user started talking")
|
||||
self.speech_timeout = Timer(1.0, self.utterance_interrupt)
# A Timer only fires once it has been started.
self.speech_timeout.start()
|
||||
|
||||
def handle_user_stopped_talking(self):
|
||||
self.logger.info("user stopped talking, canceling utterance interrupt")
|
||||
if self.speech_timeout:
|
||||
self.speech_timeout.cancel()
|
||||
|
||||
def utterance_interrupt(self):
|
||||
self.logger.info("utterance interrupt")
|
||||
self.is_interrupted.set()
|
||||
|
||||
def handle_transcription_fragment(self, fragment):
|
||||
if not self.can_interrupt:
|
||||
return
|
||||
|
||||
# start generating a new response. We'll do the fast parts of the interrupt
|
||||
# now but wait for the state transition after we've kicked off the prepare
|
||||
# on the new response.
|
||||
if (
|
||||
self.current_response
|
||||
and self.current_response.state < AsyncProcessorState.INTERRUPTED
|
||||
):
|
||||
self.interrupt_time = time.perf_counter()
|
||||
self.is_interrupted.set()
|
||||
self.current_response.interrupt()
|
||||
|
||||
self.message_handler.add_user_message(fragment)
|
||||
|
||||
response_type: type[OrchestratorResponse] | type[LLMResponse] = self.conversation_processors.response or LLMResponse
|
||||
new_response: OrchestratorResponse = response_type(
|
||||
self.services, self.message_handler, self.output_queue
|
||||
)
|
||||
new_response.set_state_callback(
|
||||
AsyncProcessorState.DONE, self.on_response_played
|
||||
)
|
||||
new_response.set_state_callback(
|
||||
AsyncProcessorState.FINALIZED, self.on_response_finished
|
||||
)
|
||||
new_response.prepare()
|
||||
|
||||
self.response_semaphore.acquire()
|
||||
if (
|
||||
self.current_response
|
||||
and self.current_response.state < AsyncProcessorState.INTERRUPTED
|
||||
):
|
||||
self.current_response.wait_for_state_transition(
|
||||
AsyncProcessorState.FINALIZED
|
||||
)
|
||||
|
||||
self.current_response = new_response
|
||||
self.current_response.play()
|
||||
|
||||
self.response_semaphore.release()
|
||||
|
||||
def action(self):
|
||||
self.logger.info("Starting camera thread")
|
||||
self.image: bytes | None = None
|
||||
self.camera_thread = Thread(target=self.run_camera, daemon=True)
|
||||
self.camera_thread.start()
|
||||
|
||||
self.logger.info("Starting frame consumer thread")
|
||||
self.frame_consumer_thread = Thread(target=self.frame_consumer, daemon=True)
|
||||
self.frame_consumer_thread.start()
|
||||
|
||||
self.logger.info("Playing introduction")
|
||||
self.can_interrupt = False
|
||||
self.current_response.play()
|
||||
|
||||
def frame_consumer(self):
|
||||
self.logger.info("🎬 Starting frame consumer thread")
|
||||
b = bytearray()
|
||||
smallest_write_size = 3200
|
||||
all_audio_frames = bytearray()
|
||||
while True:
|
||||
try:
|
||||
frame:QueueFrame = self.output_queue.get()
|
||||
if frame.frame_type == FrameType.END_STREAM:
|
||||
self.logger.info("Stopping frame consumer thread")
|
||||
return
|
||||
|
||||
# if interrupted, we just pull frames off the queue and discard them
|
||||
if not self.is_interrupted.is_set():
|
||||
if frame:
|
||||
if frame.frame_type == FrameType.AUDIO:
|
||||
chunk = frame.frame_data
|
||||
|
||||
all_audio_frames.extend(chunk)
|
||||
|
||||
b.extend(chunk)
|
||||
l = len(b) - (len(b) % smallest_write_size)
|
||||
if l:
|
||||
self.mic.write_frames(bytes(b[:l]))
|
||||
b = b[l:]
|
||||
elif frame.frame_type == FrameType.IMAGE:
|
||||
self.set_image(frame.frame_data)
|
||||
elif len(b):
|
||||
self.mic.write_frames(bytes(b))
|
||||
b = bytearray()
|
||||
else:
|
||||
if self.interrupt_time:
|
||||
self.logger.info(f"Lag to stop stream after interruption {time.perf_counter() - self.interrupt_time}")
|
||||
self.interrupt_time = None
|
||||
|
||||
if frame.frame_type == FrameType.START_STREAM:
|
||||
self.is_interrupted.clear()
|
||||
|
||||
self.output_queue.task_done()
|
||||
except Empty:
|
||||
try:
|
||||
if len(b):
|
||||
self.mic.write_frames(bytes(b))
|
||||
except Exception as e:
|
||||
self.logger.error(f"Exception in frame_consumer: {e}, {len(b)}")
|
||||
|
||||
b = bytearray()
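The buffering in `frame_consumer` above can be isolated into a small helper. This is a minimal sketch, not the code from this change: `take_writable` is a hypothetical name, and it only mirrors the arithmetic that keeps each write a whole multiple of `smallest_write_size` while carrying the remainder over to the next frame.

```python
def take_writable(buffer: bytearray, chunk: bytes, write_size: int = 3200) -> bytes:
    """Append chunk to buffer and return the largest prefix whose length
    is a multiple of write_size. The remainder stays in buffer.

    Hypothetical helper; the name and signature are assumptions.
    """
    buffer.extend(chunk)
    writable_len = len(buffer) - (len(buffer) % write_size)
    ready = bytes(buffer[:writable_len])
    del buffer[:writable_len]  # keep only the leftover bytes
    return ready


pending = bytearray()
first = take_writable(pending, bytes(5000))   # 3200 bytes ready, 1800 left over
second = take_writable(pending, bytes(1400))  # leftover + 1400 = 3200 bytes ready
```

If the microphone device carries 16-bit mono PCM at 16 kHz (an assumption; only the speaker settings are visible here), 3200 bytes corresponds to 100 ms of audio.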
|
||||
48
src/dailyai/queue_aggregators.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import asyncio
|
||||
|
||||
from dailyai.queue_frame import LLMMessagesQueueFrame, QueueFrame, TextQueueFrame
|
||||
from dailyai.services.ai_services import AIService
|
||||
|
||||
from typing import AsyncGenerator, List
|
||||
|
||||
class QueueTee:
|
||||
async def run_to_queue_and_generate(
|
||||
self,
|
||||
output_queue: asyncio.Queue,
|
||||
generator: AsyncGenerator[QueueFrame, None]
|
||||
) -> AsyncGenerator[QueueFrame, None]:
|
||||
async for frame in generator:
|
||||
await output_queue.put(frame)
|
||||
yield frame
|
||||
|
||||
async def run_to_queues(
|
||||
self,
|
||||
output_queues: List[asyncio.Queue],
|
||||
generator: AsyncGenerator[QueueFrame, None]
|
||||
):
|
||||
async for frame in generator:
|
||||
for queue in output_queues:
|
||||
await queue.put(frame)
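The tee pattern in `run_to_queue_and_generate` can be tried without any of the SDK classes. The sketch below is an assumption-labelled stand-in: `tee_to_queue`, `numbers` and `demo` are hypothetical names, and plain integers take the place of `QueueFrame` objects.

```python
import asyncio
from typing import Any, AsyncGenerator


async def tee_to_queue(
    queue: asyncio.Queue, generator: AsyncGenerator[Any, None]
) -> AsyncGenerator[Any, None]:
    """Put every item on the queue and also yield it to the caller."""
    async for item in generator:
        await queue.put(item)
        yield item


async def numbers() -> AsyncGenerator[int, None]:
    # Stand-in for a stream of frames.
    for n in range(3):
        yield n


async def demo():
    queue: asyncio.Queue = asyncio.Queue()
    seen = [item async for item in tee_to_queue(queue, numbers())]
    queued = [queue.get_nowait() for _ in range(queue.qsize())]
    return seen, queued
```

Running `asyncio.run(demo())` gives the consumer and the queue the same items in the same order, which is what lets one generator feed both a downstream service and a side channel.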
|
||||
|
||||
class LLMContextAggregator(AIService):
|
||||
def __init__(self, messages: list[dict], role:str, bot_participant_id=None, complete_sentences=True):
|
||||
self.messages = messages
|
||||
self.bot_participant_id = bot_participant_id
|
||||
self.role = role
|
||||
self.sentence = ""
|
||||
self.complete_sentences = complete_sentences
|
||||
|
||||
async def process_frame(self, frame:QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
# TODO: split up transcription by participant
|
||||
if isinstance(frame, TextQueueFrame):
|
||||
if self.complete_sentences:
|
||||
self.sentence += frame.text
|
||||
if self.sentence.endswith((".", "?", "!")):
|
||||
self.messages.append({"role": self.role, "content": self.sentence})
|
||||
self.sentence = ""
|
||||
yield LLMMessagesQueueFrame(self.messages)
|
||||
else:
|
||||
self.messages.append({"role": self.role, "content": frame.text})
|
||||
yield LLMMessagesQueueFrame(self.messages)
|
||||
|
||||
yield frame
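The sentence accumulation used by `LLMContextAggregator` (and by the TTS service further down) follows one rule: keep appending text until the buffer ends with sentence punctuation. A minimal sketch of that rule, with a hypothetical `SentenceBuffer` class that is not part of this change:

```python
from typing import Optional

SENTENCE_ENDINGS = (".", "?", "!")


class SentenceBuffer:
    """Accumulates text chunks and releases them one sentence at a time."""

    def __init__(self) -> None:
        self.pending = ""

    def push(self, text: str) -> Optional[str]:
        # Return the buffered text once it ends with sentence punctuation.
        self.pending += text
        if self.pending.endswith(SENTENCE_ENDINGS):
            sentence, self.pending = self.pending, ""
            return sentence
        return None

    def flush(self) -> Optional[str]:
        # Release whatever is left, for example at end of stream.
        sentence, self.pending = self.pending, ""
        return sentence or None
```

Because the check is `endswith`, a chunk such as `"Hi. How"` is not split at the inner period; it is held until a later chunk ends with punctuation or the buffer is flushed.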
|
||||
@@ -1,19 +1,41 @@
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
class FrameType(Enum):
|
||||
START_STREAM = 0
|
||||
END_STREAM = 1
|
||||
AUDIO = 2
|
||||
IMAGE = 3
|
||||
SENTENCE = 4
|
||||
TEXT_CHUNK = 5
|
||||
LLM_MESSAGE = 6
|
||||
APP_MESSAGE = 7
|
||||
IMAGE_DESCRIPTION = 8
|
||||
TRANSCRIPTION = 9
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class QueueFrame:
|
||||
frame_type: FrameType
|
||||
frame_data: str | dict | bytes | list | None
|
||||
pass
|
||||
|
||||
class ControlQueueFrame(QueueFrame):
|
||||
pass
|
||||
|
||||
class StartStreamQueueFrame(ControlQueueFrame):
|
||||
pass
|
||||
|
||||
class EndStreamQueueFrame(ControlQueueFrame):
|
||||
pass
|
||||
|
||||
@dataclass()
|
||||
class AudioQueueFrame(QueueFrame):
|
||||
data: bytes
|
||||
|
||||
@dataclass()
|
||||
class ImageQueueFrame(QueueFrame):
|
||||
url: str | None
|
||||
image: bytes
|
||||
|
||||
@dataclass()
|
||||
class TextQueueFrame(QueueFrame):
|
||||
text: str
|
||||
|
||||
@dataclass()
|
||||
class TranscriptionQueueFrame(TextQueueFrame):
|
||||
participantId: str
|
||||
timestamp: str
|
||||
|
||||
@dataclass()
|
||||
class LLMMessagesQueueFrame(QueueFrame):
|
||||
messages: list[dict[str,str]] # TODO: define this more concretely!
|
||||
|
||||
@dataclass()
class AppMessageQueueFrame(QueueFrame):
|
||||
message: Any
|
||||
participantId: str
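The move from a `FrameType` enum to one class per frame kind means services can dispatch with `isinstance`, and a subclass such as `TranscriptionQueueFrame` is accepted wherever its parent `TextQueueFrame` is. The sketch below redeclares a reduced version of the hierarchy so it runs on its own; `describe` is a hypothetical helper, not part of the SDK.

```python
from dataclasses import dataclass


class QueueFrame:
    pass


class ControlQueueFrame(QueueFrame):
    pass


class EndStreamQueueFrame(ControlQueueFrame):
    pass


@dataclass
class TextQueueFrame(QueueFrame):
    text: str


@dataclass
class TranscriptionQueueFrame(TextQueueFrame):
    participantId: str
    timestamp: str


def describe(frame: QueueFrame) -> str:
    """Hypothetical dispatcher showing isinstance-based routing."""
    if isinstance(frame, ControlQueueFrame):
        return "control"
    if isinstance(frame, TextQueueFrame):
        return f"text: {frame.text}"
    return "other"
```

A transcription frame is routed as text without any extra branch, which is the behaviour the services in this change appear to rely on when they test for `TextQueueFrame`.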
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
Pillow==10.1.0
|
||||
typing_extensions==4.9.0
|
||||
typing_extensions==4.9.0
|
||||
faster-whisper==0.10.0
|
||||
@@ -1,73 +0,0 @@
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from dailyai.queue_frame import FrameType, QueueFrame
|
||||
from dailyai.services.ai_services import AIService
|
||||
|
||||
class SentenceAggregator(AIService):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.current_sentence = ""
|
||||
|
||||
def allowed_input_frame_types(self) -> set[FrameType]:
|
||||
return set([FrameType.TEXT_CHUNK, FrameType.SENTENCE])
|
||||
|
||||
def possible_output_frame_types(self) -> set[FrameType]:
|
||||
return set([FrameType.SENTENCE])
|
||||
|
||||
async def process_frame(
|
||||
self, requested_frame_types: set[FrameType], frame: QueueFrame
|
||||
) -> AsyncGenerator[QueueFrame, None]:
|
||||
if not FrameType.SENTENCE in requested_frame_types:
|
||||
return
|
||||
|
||||
if frame.frame_type == FrameType.TEXT_CHUNK:
|
||||
if type(frame.frame_data) != str:
|
||||
raise Exception(
|
||||
"Sentence aggregator requires a string for the data field"
|
||||
)
|
||||
|
||||
self.current_sentence += frame.frame_data
|
||||
if self.current_sentence.endswith((".", "?", "!")):
|
||||
sentence = self.current_sentence
|
||||
self.current_sentence = ""
|
||||
yield QueueFrame(FrameType.SENTENCE, sentence)
|
||||
elif frame.frame_type == FrameType.END_STREAM:
|
||||
if self.current_sentence:
|
||||
yield QueueFrame(FrameType.SENTENCE, self.current_sentence)
|
||||
elif frame.frame_type == FrameType.SENTENCE:
|
||||
yield frame
|
||||
|
||||
|
||||
class TranscriptionSentenceAggregator(AIService):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.current_sentence = ""
|
||||
|
||||
def allowed_input_frame_types(self) -> set[FrameType]:
|
||||
return set([FrameType.TEXT_CHUNK, FrameType.SENTENCE])
|
||||
|
||||
def possible_output_frame_types(self) -> set[FrameType]:
|
||||
return set([FrameType.SENTENCE])
|
||||
|
||||
async def process_frame(
|
||||
self, requested_frame_types: set[FrameType], frame: QueueFrame
|
||||
) -> AsyncGenerator[QueueFrame, None]:
|
||||
if not FrameType.SENTENCE in requested_frame_types:
|
||||
return
|
||||
|
||||
if frame.frame_type == FrameType.TEXT_CHUNK:
|
||||
if type(frame.frame_data) != str:
|
||||
raise Exception(
|
||||
"Sentence aggregator requires a string for the data field"
|
||||
)
|
||||
|
||||
self.current_sentence += frame.frame_data
|
||||
if self.current_sentence.endswith((".", "?", "!")):
|
||||
sentence = self.current_sentence
|
||||
self.current_sentence = ""
|
||||
yield QueueFrame(FrameType.SENTENCE, sentence)
|
||||
elif frame.frame_type == FrameType.END_STREAM:
|
||||
if self.current_sentence:
|
||||
yield QueueFrame(FrameType.SENTENCE, self.current_sentence)
|
||||
elif frame.frame_type == FrameType.SENTENCE:
|
||||
yield frame
|
||||
@@ -1,17 +1,22 @@
|
||||
import asyncio
|
||||
import io
|
||||
import logging
|
||||
import re
|
||||
import wave
|
||||
|
||||
from httpx import request
|
||||
|
||||
from dailyai.queue_frame import QueueFrame, FrameType
|
||||
from dailyai.queue_frame import (
|
||||
AudioQueueFrame,
|
||||
ControlQueueFrame,
|
||||
EndStreamQueueFrame,
|
||||
ImageQueueFrame,
|
||||
LLMMessagesQueueFrame,
|
||||
QueueFrame,
|
||||
TextQueueFrame,
|
||||
)
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import AsyncGenerator, Iterable
|
||||
from typing import AsyncGenerator, AsyncIterable, BinaryIO, Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from collections.abc import Iterable, AsyncIterable
|
||||
|
||||
class AIService:
|
||||
|
||||
@@ -21,95 +26,56 @@ class AIService:
|
||||
def stop(self):
|
||||
pass
|
||||
|
||||
def allowed_input_frame_types(self) -> set[FrameType]:
|
||||
return set()
|
||||
|
||||
def possible_output_frame_types(self) -> set[FrameType]:
|
||||
return set()
|
||||
|
||||
async def run_to_queue(self, queue: asyncio.Queue, frames, add_end_of_stream=False) -> None:
|
||||
async for frame in self.run(frames):
|
||||
await queue.put(frame)
|
||||
|
||||
if add_end_of_stream:
|
||||
await queue.put(QueueFrame(FrameType.END_STREAM, None))
|
||||
await queue.put(EndStreamQueueFrame())
|
||||
|
||||
async def run(
|
||||
self,
|
||||
frames: Iterable[QueueFrame]
|
||||
| AsyncIterable[QueueFrame]
|
||||
| asyncio.Queue[QueueFrame],
|
||||
requested_frame_types: set[FrameType] | None=None,
|
||||
) -> AsyncGenerator[QueueFrame, None]:
|
||||
if requested_frame_types and self.possible_output_frame_types().intersection(requested_frame_types) == set():
|
||||
raise Exception(f"Requested frame types {requested_frame_types} are not supported by this service.")
|
||||
try:
|
||||
if isinstance(frames, AsyncIterable):
|
||||
async for frame in frames:
|
||||
async for output_frame in self.process_frame(frame):
|
||||
yield output_frame
|
||||
elif isinstance(frames, Iterable):
|
||||
for frame in frames:
|
||||
async for output_frame in self.process_frame(frame):
|
||||
yield output_frame
|
||||
elif isinstance(frames, asyncio.Queue):
|
||||
while True:
|
||||
frame = await frames.get()
|
||||
async for output_frame in self.process_frame(frame):
|
||||
yield output_frame
|
||||
if isinstance(frame, EndStreamQueueFrame):
|
||||
break
|
||||
else:
|
||||
raise Exception("Frames must be an iterable, an async iterable, or an asyncio.Queue")
|
||||
|
||||
if not requested_frame_types:
|
||||
requested_frame_types = self.possible_output_frame_types()
|
||||
|
||||
if isinstance(frames, AsyncIterable):
|
||||
async for frame in frames:
|
||||
async for output_frame in self.process_frame(requested_frame_types, frame):
|
||||
yield output_frame
|
||||
elif isinstance(frames, Iterable):
|
||||
for frame in frames:
|
||||
async for output_frame in self.process_frame(requested_frame_types, frame):
|
||||
yield output_frame
|
||||
elif isinstance(frames, asyncio.Queue):
|
||||
while True:
|
||||
frame = await frames.get()
|
||||
async for output_frame in self.process_frame(requested_frame_types, frame):
|
||||
yield output_frame
|
||||
if frame.frame_type == FrameType.END_STREAM:
|
||||
break
|
||||
else:
|
||||
raise Exception("Frames must be an iterable or async iterable")
|
||||
async for output_frame in self.finalize():
|
||||
yield output_frame
|
||||
except Exception as e:
|
||||
self.logger.error("Exception occurred while running AI service: %s", e)
|
||||
raise e
|
||||
|
||||
@abstractmethod
|
||||
async def process_frame(self, requested_frame_types:set[FrameType], frame:QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
# Yield something so the linter can deduce what should happen here.
|
||||
yield QueueFrame(FrameType.END_STREAM, None)
|
||||
|
||||
class SentenceAggregator(AIService):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.current_sentence = ""
|
||||
|
||||
def allowed_input_frame_types(self) -> set[FrameType]:
|
||||
return set([FrameType.TEXT_CHUNK, FrameType.SENTENCE])
|
||||
|
||||
def possible_output_frame_types(self) -> set[FrameType]:
|
||||
return set([FrameType.SENTENCE])
|
||||
|
||||
async def process_frame(self, requested_frame_types: set[FrameType], frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
if not FrameType.SENTENCE in requested_frame_types:
|
||||
return
|
||||
|
||||
if frame.frame_type == FrameType.TEXT_CHUNK:
|
||||
if type(frame.frame_data) != str:
|
||||
raise Exception(
|
||||
"Sentence aggregator requires a string for the data field"
|
||||
)
|
||||
|
||||
self.current_sentence += frame.frame_data
|
||||
if self.current_sentence.endswith((".", "?", "!")):
|
||||
sentence = self.current_sentence
|
||||
self.current_sentence = ""
|
||||
yield QueueFrame(FrameType.SENTENCE, sentence)
|
||||
elif frame.frame_type == FrameType.END_STREAM:
|
||||
if self.current_sentence:
|
||||
yield QueueFrame(FrameType.SENTENCE, self.current_sentence)
|
||||
elif frame.frame_type == FrameType.SENTENCE:
|
||||
async def process_frame(self, frame:QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
if isinstance(frame, ControlQueueFrame):
|
||||
yield frame
|
||||
|
||||
@abstractmethod
|
||||
async def finalize(self) -> AsyncGenerator[QueueFrame, None]:
|
||||
# This is a trick for the interpreter (and linter) to know that this is a generator.
|
||||
if False:
|
||||
yield QueueFrame()
|
||||
|
||||
class LLMService(AIService):
|
||||
def allowed_input_frame_types(self) -> set[FrameType]:
|
||||
return set([FrameType.LLM_MESSAGE, FrameType.SENTENCE, FrameType.TRANSCRIPTION])
|
||||
|
||||
def allowed_output_frame_types(self) -> set[FrameType]:
|
||||
return set([FrameType.SENTENCE, FrameType.TEXT_CHUNK])
|
||||
|
||||
@abstractmethod
|
||||
async def run_llm_async(self, messages) -> AsyncGenerator[str, None]:
|
||||
yield ""
|
||||
@@ -118,52 +84,59 @@ class LLMService(AIService):
|
||||
async def run_llm(self, messages) -> str:
|
||||
pass
|
||||
|
||||
async def process_frame(self, requested_frame_types: set[FrameType], frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
if frame.frame_type == FrameType.LLM_MESSAGE:
|
||||
if type(frame.frame_data) != list:
|
||||
raise Exception("LLM service requires a dict for the data field")
|
||||
|
||||
messages: list[dict[str, str]] = frame.frame_data
|
||||
if FrameType.SENTENCE in requested_frame_types:
|
||||
yield QueueFrame(FrameType.SENTENCE, await self.run_llm(messages))
|
||||
else:
|
||||
async for text_chunk in self.run_llm_async(messages):
|
||||
yield QueueFrame(FrameType.TEXT_CHUNK, text_chunk)
|
||||
|
||||
# TODO: handle other frame types! Need to aggregate into messages
|
||||
async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
if isinstance(frame, ControlQueueFrame):
|
||||
yield frame
|
||||
elif isinstance(frame, LLMMessagesQueueFrame):
|
||||
async for text_chunk in self.run_llm_async(frame.messages):
|
||||
yield TextQueueFrame(text_chunk)
|
||||
|
||||
|
||||
class TTSService(AIService):
|
||||
def __init__(self, aggregate_sentences=True):
|
||||
super().__init__()
|
||||
self.aggregate_sentences: bool = aggregate_sentences
|
||||
self.current_sentence: str = ""
|
||||
|
||||
# Some TTS services require a specific sample rate. We default to 16k
|
||||
def get_mic_sample_rate(self):
|
||||
return 16000
|
||||
|
||||
def allowed_input_frame_types(self) -> set[FrameType]:
|
||||
return set([FrameType.SENTENCE, FrameType.TRANSCRIPTION, FrameType.TEXT_CHUNK])
|
||||
|
||||
def possible_output_frame_types(self) -> set[FrameType]:
|
||||
return set([FrameType.AUDIO])
|
||||
|
||||
# Converts the sentence to audio. Yields a list of audio frames that can
|
||||
# Converts the text to audio. Yields a list of audio frames that can
|
||||
# be sent to the microphone device
|
||||
@abstractmethod
|
||||
async def run_tts(self, sentence) -> AsyncGenerator[bytes, None]:
|
||||
async def run_tts(self, text) -> AsyncGenerator[bytes, None]:
|
||||
# yield empty bytes here, so linting can infer what this method does
|
||||
yield bytes()
|
||||
|
||||
async def process_frame(self, requested_frame_types: set[FrameType], frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
if not FrameType.AUDIO in requested_frame_types:
|
||||
async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
if isinstance(frame, ControlQueueFrame):
|
||||
yield frame
|
||||
return
|
||||
elif not isinstance(frame, TextQueueFrame):
|
||||
return
|
||||
|
||||
if type(frame.frame_data) != str:
|
||||
raise Exception("TTS service requires a string for the data field")
|
||||
text: str | None = None
|
||||
if not self.aggregate_sentences:
|
||||
text = frame.text
|
||||
else:
|
||||
self.current_sentence += frame.text
|
||||
if self.current_sentence.endswith((".", "?", "!")):
|
||||
text = self.current_sentence
|
||||
self.current_sentence = ""
|
||||
|
||||
async for audio_chunk in self.run_tts(frame.frame_data):
|
||||
yield QueueFrame(FrameType.AUDIO, audio_chunk)
|
||||
if text:
|
||||
async for audio_chunk in self.run_tts(text):
|
||||
yield AudioQueueFrame(audio_chunk)
|
||||
|
||||
async def finalize(self):
|
||||
if self.current_sentence:
|
||||
async for audio_chunk in self.run_tts(self.current_sentence):
|
||||
yield AudioQueueFrame(audio_chunk)
|
||||
|
||||
# Convenience function to send the audio for a sentence to the given queue
|
||||
async def say(self, sentence, queue: asyncio.Queue):
|
||||
await self.run_to_queue(queue, [QueueFrame(FrameType.SENTENCE, sentence)])
|
||||
await self.run_to_queue(queue, [TextQueueFrame(sentence)])
|
||||
|
||||
|
||||
class ImageGenService(AIService):
|
||||
@@ -171,30 +144,53 @@ class ImageGenService(AIService):
|
||||
super().__init__(**kwargs)
|
||||
self.image_size = image_size
|
||||
|
||||
def allowed_input_frame_types(self) -> set[FrameType]:
|
||||
return set([FrameType.SENTENCE, FrameType.TRANSCRIPTION, FrameType.TEXT_CHUNK, FrameType.IMAGE_DESCRIPTION])
|
||||
|
||||
def possible_output_frame_types(self) -> set[FrameType]:
|
||||
return set([FrameType.IMAGE])
|
||||
|
||||
# Renders the image. Returns a (url, image bytes) tuple.
|
||||
@abstractmethod
|
||||
async def run_image_gen(self, sentence) -> tuple[str, bytes]:
|
||||
async def run_image_gen(self, sentence:str) -> tuple[str, bytes]:
|
||||
pass
|
||||
|
||||
async def process_frame(self, requested_frame_types: set[FrameType], frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
if not FrameType.IMAGE in requested_frame_types:
|
||||
async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
if not isinstance(frame, TextQueueFrame):
|
||||
yield frame
|
||||
return
|
||||
|
||||
if type(frame.frame_data) != str:
|
||||
raise Exception("Image service requires a string for the data field")
|
||||
(url, image_data) = await self.run_image_gen(frame.text)
|
||||
yield ImageQueueFrame(url, image_data)
|
||||
|
||||
(_, image_data) = await self.run_image_gen(frame.frame_data)
|
||||
yield QueueFrame(FrameType.IMAGE, image_data)
|
||||
class STTService(AIService):
|
||||
"""STTService is a base class for speech-to-text services."""
|
||||
|
||||
_frame_rate: int
|
||||
def __init__(self, frame_rate: int = 16000, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._frame_rate = frame_rate
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def run_stt(self, audio: BinaryIO) -> str:
|
||||
"""Returns transcript as a string"""
|
||||
pass
|
||||
|
||||
async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
"""Processes a frame of audio data, either buffering or transcribing it."""
|
||||
if not isinstance(frame, AudioQueueFrame):
|
||||
return
|
||||
|
||||
data = frame.data
|
||||
content = io.BufferedRandom(io.BytesIO())
|
||||
ww = wave.open(content, "wb")
|
||||
ww.setnchannels(1)
|
||||
ww.setsampwidth(2)
|
||||
ww.setframerate(self._frame_rate)
|
||||
ww.writeframesraw(data)
|
||||
ww.close()
|
||||
content.seek(0)
|
||||
text = await self.run_stt(content)
|
||||
yield TextQueueFrame(text)
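The WAV wrapping step in `STTService.process_frame` can be exercised on its own with the standard library. This is a sketch under the same assumptions the service makes (mono, 16-bit samples); `pcm_to_wav` is a hypothetical helper name.

```python
import io
import wave


def pcm_to_wav(pcm: bytes, frame_rate: int = 16000) -> io.BytesIO:
    """Wrap raw 16-bit mono PCM in an in-memory WAV container."""
    content = io.BytesIO()
    with wave.open(content, "wb") as ww:
        ww.setnchannels(1)   # mono
        ww.setsampwidth(2)   # 2 bytes per sample, i.e. 16-bit
        ww.setframerate(frame_rate)
        ww.writeframes(pcm)
    # wave leaves a caller-supplied file object open, so rewind it for readers.
    content.seek(0)
    return content


wav_file = pcm_to_wav(b"\x00\x00" * 1600)  # 100 ms of silence at 16 kHz
```

The returned object is a seekable binary stream, so it can be passed to anything that accepts a `BinaryIO`, such as the `run_stt` method declared above.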
|
||||
|
||||
@dataclass
|
||||
class AIServiceConfig:
|
||||
tts: TTSService
|
||||
image: ImageGenService
|
||||
llm: LLMService
|
||||
stt: STTService
|
||||
|
||||
@@ -5,7 +5,6 @@ import json
|
||||
from openai import AsyncAzureOpenAI
|
||||
|
||||
import os
|
||||
import requests
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
@@ -16,7 +15,10 @@ from PIL import Image
|
||||
from azure.cognitiveservices.speech import SpeechSynthesizer, SpeechConfig, ResultReason, CancellationReason
|
||||
|
||||
class AzureTTSService(TTSService):
|
||||
def __init__(self, speech_key=None, speech_region=None):
|
||||
|
||||
def __init__(
|
||||
self, speech_key=None, speech_region=None, voice_name="en-US-SaraNeural"
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
speech_key = speech_key or os.getenv("AZURE_SPEECH_SERVICE_KEY")
|
||||
@@ -25,20 +27,19 @@ class AzureTTSService(TTSService):
|
||||
self.speech_config = SpeechConfig(subscription=speech_key, region=speech_region)
|
||||
self.speech_synthesizer = SpeechSynthesizer(speech_config=self.speech_config, audio_config=None)
|
||||
|
||||
self.voice_name = voice_name
|
||||
|
||||
async def run_tts(self, sentence) -> AsyncGenerator[bytes, None]:
|
||||
self.logger.info("Running azure tts")
|
||||
ssml = "<speak version='1.0' xml:lang='en-US' xmlns='http://www.w3.org/2001/10/synthesis' " \
|
||||
ssml = f"<speak version='1.0' xml:lang='en-US' xmlns='http://www.w3.org/2001/10/synthesis' " \
|
||||
"xmlns:mstts='http://www.w3.org/2001/mstts'>" \
|
||||
"<voice name='en-US-SaraNeural'>" \
|
||||
f"<voice name='{self.voice_name}'>" \
|
||||
"<mstts:silence type='Sentenceboundary' value='20ms' />" \
|
||||
"<mstts:express-as style='lyrical' styledegree='2' role='SeniorFemale'>" \
|
||||
"<prosody rate='1.05'>" \
|
||||
f"{sentence}" \
|
||||
"</prosody></mstts:express-as></voice></speak> "
|
||||
try:
|
||||
result = await asyncio.to_thread(self.speech_synthesizer.speak_ssml, (ssml))
|
||||
except Exception as e:
|
||||
self.logger.error("Error in azure tts", e)
|
||||
result = await asyncio.to_thread(self.speech_synthesizer.speak_ssml, (ssml))
|
||||
self.logger.info("Got azure tts result")
|
||||
if result.reason == ResultReason.SynthesizingAudioCompleted:
|
||||
self.logger.info("Returning result")
|
||||
@@ -95,81 +96,58 @@ class AzureLLMService(LLMService):
|
||||
|
||||
class AzureImageGenServiceREST(ImageGenService):
|
||||
|
||||
def __init__(self, image_size:str, api_key=None, azure_endpoint=None, api_version=None, model=None):
|
||||
def __init__(
|
||||
self,
|
||||
image_size: str,
|
||||
api_key: str | None = None,
|
||||
azure_endpoint: str | None = None,
|
||||
api_version: str | None = None,
|
||||
model: str | None = None,
|
||||
aiohttp_session: aiohttp.ClientSession | None=None,
|
||||
timeout_seconds=120,
|
||||
):
|
||||
super().__init__(image_size=image_size)
|
||||
self.api_key = api_key or os.getenv("AZURE_DALLE_KEY")
|
||||
self.azure_endpoint = azure_endpoint or os.getenv("AZURE_DALLE_ENDPOINT")
|
||||
self.api_version = api_version or "2023-06-01-preview"
|
||||
self.model = model or os.getenv("AZURE_DALLE_DEPLOYMENT_ID")
|
||||
self.aiohttp_session: aiohttp.ClientSession = (
|
||||
aiohttp_session or aiohttp.ClientSession()
|
||||
)
|
||||
self.timeout_seconds = timeout_seconds
|
||||
|
||||
async def run_image_gen(self, sentence) -> tuple[str, bytes]:
|
||||
# TODO hoist the session to app-level
|
||||
async with aiohttp.ClientSession() as session:
|
||||
url = f"{self.azure_endpoint}openai/images/generations:submit?api-version={self.api_version}"
|
||||
headers= { "api-key": self.api_key, "Content-Type": "application/json" }
|
||||
body = {
|
||||
# Enter your prompt text here
|
||||
"prompt": sentence,
|
||||
"size": self.image_size,
|
||||
"n": 1,
|
||||
}
|
||||
async with session.post(url, headers=headers, json=body) as submission:
|
||||
operation_location = submission.headers['operation-location']
|
||||
url = f"{self.azure_endpoint}openai/images/generations:submit?api-version={self.api_version}"
|
||||
headers= { "api-key": self.api_key, "Content-Type": "application/json" }
|
||||
body = {
|
||||
"prompt": sentence,
|
||||
"size": self.image_size,
|
||||
"n": 1,
|
||||
}
|
||||
async with self.aiohttp_session.post(
|
||||
url, headers=headers, json=body
|
||||
) as submission:
|
||||
operation_location = submission.headers['operation-location']
|
||||
|
||||
status = ""
|
||||
attempts_left = 120
|
||||
json_response = None
|
||||
while status != "succeeded":
|
||||
attempts_left -= 1
|
||||
if attempts_left == 0:
|
||||
raise Exception("Image generation timed out")
|
||||
status = ""
|
||||
attempts_left = self.timeout_seconds
|
||||
json_response = None
|
||||
while status != "succeeded":
|
||||
attempts_left -= 1
|
||||
if attempts_left == 0:
|
||||
raise Exception("Image generation timed out")
|
||||
|
||||
await asyncio.sleep(1)
|
||||
response = await session.get(operation_location, headers=headers)
|
||||
json_response = await response.json()
|
||||
status = json_response["status"]
|
||||
await asyncio.sleep(1)
|
||||
response = await self.aiohttp_session.get(operation_location, headers=headers)
|
||||
json_response = await response.json()
|
||||
status = json_response["status"]
|
||||
|
||||
image_url = json_response["result"]["data"][0]["url"] if json_response else None
|
||||
if not image_url:
|
||||
raise Exception("Image generation failed")
|
||||
image_url = json_response["result"]["data"][0]["url"] if json_response else None
|
||||
if not image_url:
|
||||
raise Exception("Image generation failed")
|
||||
|
||||
# Load the image from the url
|
||||
async with session.get(image_url) as response:
|
||||
image_stream = io.BytesIO(await response.content.read())
|
||||
image = Image.open(image_stream)
|
||||
return (image_url, image.tobytes())
|
||||
|
||||
|
||||
class AzureImageGenService(ImageGenService):
|
||||
|
||||
def __init__(self, api_key=None, azure_endpoint=None, api_version=None, model=None):
|
||||
super().__init__()
|
||||
|
||||
api_key = api_key or os.getenv("AZURE_DALLE_KEY")
|
||||
azure_endpoint = azure_endpoint or os.getenv("AZURE_DALLE_ENDPOINT")
|
||||
api_version = api_version or "2023-06-01-preview"
|
||||
self.model = model or os.getenv("AZURE_DALLE_DEPLOYMENT_ID")
|
||||
|
||||
self.client = AzureOpenAI(
|
||||
api_key=api_key,
|
||||
azure_endpoint=azure_endpoint,
|
||||
api_version=api_version,
|
||||
)
|
||||
|
||||
async def run_image_gen(self, sentence) -> tuple[str, bytes]:
|
||||
self.logger.info("Generating azure image", sentence)
|
||||
|
||||
image = self.client.images.generate(
|
||||
model=self.model,
|
||||
prompt=sentence,
|
||||
n=1,
|
||||
size=self.image_size,
|
||||
)
|
||||
|
||||
url = image["data"][0]["url"]
|
||||
response = requests.get(url)
|
||||
|
||||
dalle_stream = io.BytesIO(response.content)
|
||||
dalle_im = Image.open(dalle_stream.tobytes())
|
||||
|
||||
return (url, dalle_im)
|
||||
# Load the image from the url
|
||||
async with self.aiohttp_session.get(image_url) as response:
|
||||
image_stream = io.BytesIO(await response.content.read())
|
||||
image = Image.open(image_stream)
|
||||
return (image_url, image.tobytes())
|
||||
|
||||
@@ -1,15 +1,25 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import types
|
||||
|
||||
from functools import partial
|
||||
from queue import Queue, Empty
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from dailyai.queue_frame import QueueFrame, FrameType
|
||||
from dailyai.queue_frame import (
|
||||
AudioQueueFrame,
|
||||
EndStreamQueueFrame,
|
||||
ImageQueueFrame,
|
||||
QueueFrame,
|
||||
StartStreamQueueFrame,
|
||||
TextQueueFrame,
|
||||
TranscriptionQueueFrame,
|
||||
)
|
||||
|
||||
from threading import Thread, Event, Timer
|
||||
from threading import Thread, Event
|
||||
|
||||
from daily import (
|
||||
EventHandler,
|
||||
@@ -21,12 +31,22 @@ from daily import (
|
||||
)
|
||||
|
||||
class DailyTransportService(EventHandler):
|
||||
_daily_initialized = False
|
||||
_lock = threading.Lock()
|
||||
|
||||
speaker_enabled: bool
|
||||
speaker_sample_rate: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
room_url: str,
|
||||
token: str | None,
|
||||
bot_name: str,
|
||||
duration: float = 10,
|
||||
min_others_count: int = 1,
|
||||
start_transcription: bool = True,
|
||||
speaker_enabled: bool = False,
|
||||
speaker_sample_rate: int = 16000,
|
||||
):
|
||||
super().__init__()
|
||||
self.bot_name: str = bot_name
|
||||
@@ -34,6 +54,8 @@ class DailyTransportService(EventHandler):
|
||||
self.token: str | None = token
|
||||
self.duration: float = duration
|
||||
self.expiration = time.time() + duration * 60
|
||||
self.min_others_count = min_others_count
|
||||
self.start_transcription = start_transcription
|
||||
|
||||
# This queue is used to marshal frames from the async send queue to the thread that emits audio & video.
|
||||
# We need this to maintain the asynchronous behavior of asyncio queues -- to give async functions
|
||||
@@ -49,11 +71,14 @@ class DailyTransportService(EventHandler):
|
||||
self.camera_width = 1024
|
||||
self.camera_height = 768
|
||||
self.camera_enabled = False
|
||||
self.speaker_enabled = speaker_enabled
|
||||
self.speaker_sample_rate = speaker_sample_rate
|
||||
|
||||
self.send_queue = asyncio.Queue()
|
||||
self.receive_queue = asyncio.Queue()
|
||||
|
||||
self.other_participant_has_joined = False
|
||||
self.my_participant_id = None
|
||||
|
||||
self.camera_thread = None
|
||||
self.frame_consumer_thread = None
|
||||
@@ -91,6 +116,7 @@ class DailyTransportService(EventHandler):
|
||||
handler(*args, **kwargs)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Exception in event handler {event_name}: {e}")
|
||||
raise e
|
||||
|
||||
def add_event_handler(self, event_name: str, handler):
|
||||
if not event_name.startswith("on_"):
|
||||
@@ -114,7 +140,11 @@ class DailyTransportService(EventHandler):
         return decorator
 
     def configure_daily(self):
-        Daily.init()
+        # Only initialize Daily once
+        if not DailyTransportService._daily_initialized:
+            with DailyTransportService._lock:
+                Daily.init()
+                DailyTransportService._daily_initialized = True
         self.client = CallClient(event_handler=self)
 
         if self.mic_enabled:
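The guarded `Daily.init()` call in this hunk is a once-per-process initialization shared by every transport instance. A minimal sketch of the pattern, assuming a hypothetical `OnceInitialized` class and a `_global_init` stand-in for the expensive call (neither is in the SDK). Unlike the diff, the sketch re-checks the flag after taking the lock, so two threads that both pass the first check cannot both run the setup.

```python
import threading


class OnceInitialized:
    """Hypothetical class showing one-time, thread-safe global setup."""

    _initialized = False
    _lock = threading.Lock()
    init_calls = 0  # counts how often the expensive setup actually ran

    @classmethod
    def _global_init(cls) -> None:
        # Stand-in for an expensive, once-per-process call.
        cls.init_calls += 1

    def __init__(self) -> None:
        cls = type(self)
        if not cls._initialized:  # fast path: no lock once set up
            with cls._lock:
                if not cls._initialized:  # re-check while holding the lock
                    cls._global_init()
                    cls._initialized = True
```

Creating many instances, even from several threads at once, should leave `OnceInitialized.init_calls` at 1.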
@@ -127,9 +157,11 @@ class DailyTransportService(EventHandler):
|
||||
"camera", width=self.camera_width, height=self.camera_height, color_format="RGB"
|
||||
)
|
||||
|
||||
self.speaker: VirtualSpeakerDevice = Daily.create_speaker_device(
|
||||
"speaker", sample_rate=16000, channels=1
|
||||
)
|
||||
if self.speaker_enabled:
|
||||
self.speaker: VirtualSpeakerDevice = Daily.create_speaker_device(
|
||||
"speaker", sample_rate=self.speaker_sample_rate, channels=1
|
||||
)
|
||||
Daily.select_speaker_device("speaker")
|
||||
|
||||
self.image: bytes | None = None
|
||||
self.camera_thread = Thread(target=self.run_camera, daemon=True)
|
||||
@@ -139,10 +171,9 @@ class DailyTransportService(EventHandler):
|
||||
self.frame_consumer_thread = Thread(target=self.frame_consumer, daemon=True)
|
||||
self.frame_consumer_thread.start()
|
||||
|
||||
Daily.select_speaker_device("speaker")
|
||||
|
||||
self.client.set_user_name(self.bot_name)
|
||||
self.client.join(self.room_url, self.token, completion=self.call_joined)
|
||||
self.my_participant_id = self.client.participants()["local"]["id"]
|
||||
|
||||
self.client.update_inputs(
|
||||
{
|
||||
@@ -183,16 +214,28 @@ class DailyTransportService(EventHandler):
             }
         )
 
-        if self.token:
+        if self.token and self.start_transcription:
             self.client.start_transcription(self.transcription_settings)
 
-        self.my_participant_id = self.client.participants()["local"]["id"]
+    def _receive_audio(self):
+        """Receive audio from the Daily call and put it on the receive queue"""
+        seconds = 1
+        desired_frame_count = self.speaker_sample_rate * seconds
+        while True:
+            buffer = self.speaker.read_frames(desired_frame_count)
+            if len(buffer) > 0:
+                frame = AudioQueueFrame(buffer)
+                if self.loop:
+                    asyncio.run_coroutine_threadsafe(self.receive_queue.put(frame), self.loop)
 
-    async def get_receive_frames(self):
+    def interrupt(self):
+        self.is_interrupted.set()
+
+    async def get_receive_frames(self) -> AsyncGenerator[QueueFrame, None]:
         while True:
             frame = await self.receive_queue.get()
             yield frame
-            if frame.frame_type == FrameType.END_STREAM:
+            if isinstance(frame, EndStreamQueueFrame):
                 break
 
     def get_async_send_queue(self):
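`_receive_audio` runs on a plain thread but feeds an `asyncio.Queue`, so it hands each frame to the event loop with `asyncio.run_coroutine_threadsafe`. A minimal sketch of that thread-to-loop handoff, with hypothetical names (`produce_from_thread`, `collect`) and `None` used as an end-of-stream marker purely for this illustration:

```python
import asyncio
import threading


def produce_from_thread(
    loop: asyncio.AbstractEventLoop, queue: asyncio.Queue, chunks: list
) -> None:
    """Runs in a worker thread and hands each non-empty chunk to the loop's queue."""
    for chunk in chunks:
        if len(chunk) > 0:
            # queue.put is a coroutine, so it is scheduled on the loop
            # instead of being called directly from this thread.
            asyncio.run_coroutine_threadsafe(queue.put(chunk), loop)
    # None is a hypothetical end-of-stream marker for this sketch.
    asyncio.run_coroutine_threadsafe(queue.put(None), loop)


async def collect(chunks: list) -> list:
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()
    worker = threading.Thread(
        target=produce_from_thread, args=(loop, queue, chunks), daemon=True
    )
    worker.start()

    received = []
    while (item := await queue.get()) is not None:
        received.append(item)
    return received
```

For example, `asyncio.run(collect([b"ab", b"", b"cd"]))` drops the empty chunk and returns the rest in order.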
@@ -203,7 +246,7 @@ class DailyTransportService(EventHandler):
|
||||
frame: QueueFrame | list = await self.send_queue.get()
|
||||
self.threadsafe_send_queue.put(frame)
|
||||
self.send_queue.task_done()
|
||||
if type(frame) == QueueFrame and frame.frame_type == FrameType.END_STREAM:
|
||||
if isinstance(frame, EndStreamQueueFrame):
|
||||
break
|
||||
|
||||
async def wait_for_send_queue_to_empty(self):
|
||||
@@ -217,24 +260,25 @@ class DailyTransportService(EventHandler):
|
||||
async def run(self) -> None:
|
||||
self.configure_daily()
|
||||
|
||||
self.participant_left = False
|
||||
self.do_shutdown = False
|
||||
|
||||
async_output_queue_marshal_task = asyncio.create_task(self.marshal_frames())
|
||||
|
||||
try:
|
||||
participant_count: int = len(self.client.participants())
|
||||
self.logger.info(f"{participant_count} participants in room")
|
||||
while time.time() < self.expiration and not self.participant_left and not self.stop_threads.is_set():
|
||||
while time.time() < self.expiration and not self.do_shutdown and not self.stop_threads.is_set():
|
||||
await asyncio.sleep(1)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Exception {e}")
|
||||
raise e
|
||||
finally:
|
||||
self.client.leave()
|
||||
|
||||
self.stop_threads.set()
|
||||
|
||||
await self.receive_queue.put(QueueFrame(FrameType.END_STREAM, None))
|
||||
await self.send_queue.put(QueueFrame(FrameType.END_STREAM, None))
|
||||
await self.receive_queue.put(EndStreamQueueFrame())
|
||||
await self.send_queue.put(EndStreamQueueFrame())
|
||||
await async_output_queue_marshal_task
|
||||
|
||||
if self.camera_thread and self.camera_thread.is_alive():
|
||||
@@ -250,6 +294,9 @@ class DailyTransportService(EventHandler):
|
||||
|
||||
def call_joined(self, join_data, client_error):
|
||||
self.logger.info(f"Call_joined: {join_data}, {client_error}")
|
||||
if self.speaker_enabled:
|
||||
t = Thread(target=self._receive_audio, daemon=True)
|
||||
t.start()
|
||||
|
||||
def on_error(self, error):
|
||||
self.logger.error(f"on_error: {error}")
|
||||
@@ -263,8 +310,8 @@ class DailyTransportService(EventHandler):
|
||||
self.on_first_other_participant_joined()
|
||||
|
||||
def on_participant_left(self, participant, reason):
|
||||
if len(self.client.participants()) < 2:
|
||||
self.participant_left = True
|
||||
if len(self.client.participants()) < self.min_others_count + 1:
|
||||
self.do_shutdown = True
|
||||
pass
|
||||
|
||||
def on_app_message(self, message, sender):
|
||||
@@ -272,7 +319,12 @@ class DailyTransportService(EventHandler):
 
     def on_transcription_message(self, message:dict):
         if self.loop:
-            frame = QueueFrame(FrameType.TRANSCRIPTION, message)
+            participantId = ""
+            if "participantId" in message:
+                participantId = message["participantId"]
+            elif "session_id" in message:
+                participantId = message["session_id"]
+            frame = TranscriptionQueueFrame(message["text"], participantId, message["timestamp"])
             asyncio.run_coroutine_threadsafe(self.receive_queue.put(frame), self.loop)
 
     def on_transcription_stopped(self, stopped_by, stopped_by_error):
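The new `on_transcription_message` body accepts two message shapes: one keyed by `participantId` and one keyed by `session_id`, falling back to an empty string. A small sketch of that lookup as a standalone function (the name `participant_id_from` is hypothetical, and the sample messages are made up for illustration):

```python
def participant_id_from(message: dict) -> str:
    """Return the sender id from a transcription message, or "" if absent."""
    if "participantId" in message:
        return message["participantId"]
    if "session_id" in message:
        return message["session_id"]
    return ""


# Made-up sample messages covering the three cases.
samples = [
    {"text": "hi", "participantId": "abc", "timestamp": "t0"},
    {"text": "hi", "session_id": "xyz", "timestamp": "t1"},
    {"text": "hi", "timestamp": "t2"},
]
ids = [participant_id_from(m) for m in samples]
```

Pulling the lookup into a helper keeps the event handler short and makes the fallback order easy to unit test.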
@@ -296,6 +348,7 @@ class DailyTransportService(EventHandler):
|
||||
time.sleep(1.0 / 8) # 8 fps
|
||||
except Exception as e:
|
||||
self.logger.error(f"Exception {e} in camera thread.")
|
||||
raise e
|
||||
|
||||
def frame_consumer(self):
|
||||
self.logger.info("🎬 Starting frame consumer thread")
|
||||
@@ -305,15 +358,15 @@ class DailyTransportService(EventHandler):
|
||||
while True:
|
||||
try:
|
||||
frames_or_frame: QueueFrame | list[QueueFrame] = self.threadsafe_send_queue.get()
|
||||
if type(frames_or_frame) == QueueFrame:
|
||||
if isinstance(frames_or_frame, QueueFrame):
|
||||
frames: list[QueueFrame] = [frames_or_frame]
|
||||
elif type(frames_or_frame) == list:
|
||||
elif isinstance(frames_or_frame, list):
|
||||
frames: list[QueueFrame] = frames_or_frame
|
||||
else:
|
||||
raise Exception("Unknown type in output queue")
|
||||
|
||||
for frame in frames:
|
||||
if frame.frame_type == FrameType.END_STREAM:
|
||||
if isinstance(frame, EndStreamQueueFrame):
|
||||
self.logger.info("Stopping frame consumer thread")
|
||||
self.threadsafe_send_queue.task_done()
|
||||
return
|
||||
@@ -321,8 +374,8 @@ class DailyTransportService(EventHandler):
|
||||
# if interrupted, we just pull frames off the queue and discard them
|
||||
if not self.is_interrupted.is_set():
|
||||
if frame:
|
||||
if frame.frame_type == FrameType.AUDIO:
|
||||
chunk = frame.frame_data
|
||||
if isinstance(frame, AudioQueueFrame):
|
||||
chunk = frame.data
|
||||
|
||||
all_audio_frames.extend(chunk)
|
||||
|
||||
@@ -331,19 +384,19 @@ class DailyTransportService(EventHandler):
|
||||
if l:
|
||||
self.mic.write_frames(bytes(b[:l]))
|
||||
b = b[l:]
|
||||
elif frame.frame_type == FrameType.IMAGE:
|
||||
self.set_image(frame.frame_data)
|
||||
elif isinstance(frame, ImageQueueFrame):
|
||||
self.set_image(frame.image)
|
||||
elif len(b):
|
||||
self.mic.write_frames(bytes(b))
|
||||
b = bytearray()
|
||||
else:
|
||||
if self.interrupt_time:
|
||||
self.logger.info(
|
||||
f"Lag to stop stream after interruption {time.perf_counter() - self.interrupt_time}"
|
||||
)
|
||||
self.interrupt_time = None
|
||||
# if there are leftover audio bytes, write them now; failing to do so
|
||||
# can cause static in the audio stream.
|
||||
if len(b):
|
||||
self.mic.write_frames(bytes(b))
|
||||
b = bytearray()
|
||||
|
||||
if frame.frame_type == FrameType.START_STREAM:
|
||||
if isinstance(frame, StartStreamQueueFrame):
|
||||
self.is_interrupted.clear()
|
||||
|
||||
self.threadsafe_send_queue.task_done()
|
||||
@@ -353,5 +406,6 @@ class DailyTransportService(EventHandler):
|
||||
self.mic.write_frames(bytes(b))
|
||||
except Exception as e:
|
||||
self.logger.error(f"Exception in frame_consumer: {e}, {len(b)}")
|
||||
raise e
|
||||
|
||||
b = bytearray()
|
||||
|
||||
@@ -9,28 +9,30 @@ from dailyai.services.ai_services import TTSService
 
 
 class ElevenLabsTTSService(TTSService):
-    def __init__(self, api_key=None, voice_id=None):
+    def __init__(self, api_key=None, voice_id=None, aiohttp_session:aiohttp.ClientSession=None):
         super().__init__()
 
         self.api_key = api_key or os.getenv("ELEVENLABS_API_KEY")
         self.voice_id = voice_id or os.getenv("ELEVENLABS_VOICE_ID")
+        self.aiohttp_session = aiohttp_session or aiohttp.ClientSession()
 
     async def run_tts(self, sentence) -> AsyncGenerator[bytes, None]:
-        async with aiohttp.ClientSession() as session:
-            url = f"https://api.elevenlabs.io/v1/text-to-speech/{self.voice_id}/stream"
-            payload = {"text": sentence, "model_id": "eleven_turbo_v2"}
-            querystring = {"output_format": "pcm_16000", "optimize_streaming_latency": 2}
-            headers = {
-                "xi-api-key": self.api_key,
-                "Content-Type": "application/json",
-            }
-            async with session.post(url, json=payload, headers=headers, params=querystring) as r:
-                if r.status != 200:
-                    self.logger.error(
-                        f"audio fetch status code: {r.status}, error: {r.text}"
-                    )
-                    return
+        url = f"https://api.elevenlabs.io/v1/text-to-speech/{self.voice_id}/stream"
+        payload = {"text": sentence, "model_id": "eleven_turbo_v2"}
+        querystring = {"output_format": "pcm_16000", "optimize_streaming_latency": 2}
+        headers = {
+            "xi-api-key": self.api_key,
+            "Content-Type": "application/json",
+        }
+        async with self.aiohttp_session.post(
+            url, json=payload, headers=headers, params=querystring
+        ) as r:
+            if r.status != 200:
+                self.logger.error(
+                    f"audio fetch status code: {r.status}, error: {r.text}"
+                )
+                return
 
-                async for chunk in r.content:
-                    if chunk:
-                        yield chunk
+            async for chunk in r.content:
+                if chunk:
+                    yield chunk
@@ -34,9 +34,9 @@ class FalImageGenService(ImageGenService):
|
||||
raise Exception("Image generation failed")
|
||||
|
||||
return image_url
|
||||
print(f"fetching image url...")
|
||||
print("fetching image url...")
|
||||
image_url = await asyncio.to_thread(get_image_url, sentence, self.image_size)
|
||||
print(f"got image url, downloading image...")
|
||||
print("got image url, downloading image...")
|
||||
# Load the image from the url
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_url) as response:
|
||||
|
||||
72
src/dailyai/services/local_stt_service.py
Normal file
@@ -0,0 +1,72 @@
+import array
+import io
+import math
+from typing import AsyncGenerator
+import wave
+from dailyai.queue_frame import AudioQueueFrame, QueueFrame, TextQueueFrame
+from dailyai.services.ai_services import STTService
+
+
+class LocalSTTService(STTService):
+    _content: io.BufferedRandom
+    _wave: wave.Wave_write
+    _current_silence_frames: int
+
+    # Configuration
+    _min_rms: int
+    _max_silence_frames: int
+    _frame_rate: int
+
+    def __init__(self,
+                 min_rms: int = 400,
+                 max_silence_frames: int = 3,
+                 frame_rate: int = 16000,
+                 **kwargs):
+        super().__init__(frame_rate, **kwargs)
+        self._current_silence_frames = 0
+        self._min_rms = min_rms
+        self._max_silence_frames = max_silence_frames
+        self._frame_rate = frame_rate
+        self._new_wave()
+
+    def _new_wave(self):
+        """Creates a new wave object and content buffer."""
+        self._content = io.BufferedRandom(io.BytesIO())
+        ww = wave.open(self._content, "wb")
+        ww.setnchannels(1)
+        ww.setsampwidth(2)
+        ww.setframerate(self._frame_rate)
+        self._wave = ww
+
+    async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
+        """Processes a frame of audio data, either buffering or transcribing it."""
+        if not isinstance(frame, AudioQueueFrame):
+            return
+
+        data = frame.data
+        # Try to filter out empty background noise
+        # (Very rudimentary approach, can be improved)
+        rms = self._get_volume(data)
+        if rms >= self._min_rms:
+            # If volume is high enough, write new data to wave file
+            self._wave.writeframesraw(data)
+
+        # If buffer is not empty and we detect a 3-frame pause in speech,
+        # transcribe the audio gathered so far.
+        if self._content.tell() > 0 and self._current_silence_frames > self._max_silence_frames:
+            self._current_silence_frames = 0
+            self._wave.close()
+            self._content.seek(0)
+            text = await self.run_stt(self._content)
+            self._new_wave()
+            yield TextQueueFrame(text)
+        # If we get this far, this is a frame of silence
+        self._current_silence_frames += 1
+
+    def _get_volume(self, audio: bytes) -> float:
+        # https://docs.python.org/3/library/array.html
+        audio_array = array.array('h', audio)
+        squares = [sample**2 for sample in audio_array]
+        mean = sum(squares) / len(audio_array)
+        rms = math.sqrt(mean)
+        return rms
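`LocalSTTService` gates audio on RMS volume and transcribes once a run of quiet frames follows buffered speech. A minimal sketch of that idea under stated assumptions: `rms_volume` and `SilenceGate` are hypothetical names, the input is signed 16-bit PCM in native byte order, and the buffer is a plain `bytearray` rather than a WAV writer. Note that this variant counts only quiet frames as silence and resets the counter when speech is heard, which differs from the file above, where the counter is incremented on every frame.

```python
import array
import math
from typing import Optional


def rms_volume(audio: bytes) -> float:
    """Root-mean-square amplitude of signed 16-bit PCM samples."""
    samples = array.array("h", audio)
    if len(samples) == 0:
        return 0.0
    return math.sqrt(sum(s * s for s in samples) / len(samples))


class SilenceGate:
    """Hypothetical helper: buffers loud frames, flushes after a pause."""

    def __init__(self, min_rms: float = 400, max_silence_frames: int = 3):
        self.min_rms = min_rms
        self.max_silence_frames = max_silence_frames
        self._buffer = bytearray()
        self._silence = 0

    def feed(self, frame: bytes) -> Optional[bytes]:
        """Return the buffered utterance once enough quiet frames follow it."""
        if rms_volume(frame) >= self.min_rms:
            self._buffer.extend(frame)
            self._silence = 0  # speech resets the pause counter
            return None
        self._silence += 1
        if self._buffer and self._silence > self.max_silence_frames:
            utterance = bytes(self._buffer)
            self._buffer.clear()
            self._silence = 0
            return utterance
        return None
```

The bytes returned by `feed` would then be what gets passed to the speech-to-text call.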
@@ -1,6 +1,4 @@
|
||||
import requests
|
||||
import aiohttp
|
||||
import asyncio
|
||||
from PIL import Image
|
||||
import io
|
||||
from openai import AsyncOpenAI
|
||||
@@ -9,7 +7,7 @@ import os
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from dailyai.services.ai_services import AIService, TTSService, LLMService, ImageGenService
|
||||
from dailyai.services.ai_services import LLMService, ImageGenService
|
||||
|
||||
|
||||
class OpenAILLMService(LLMService):
|
||||
@@ -50,11 +48,19 @@ class OpenAILLMService(LLMService):
|
||||
return None
|
||||
|
||||
class OpenAIImageGenService(ImageGenService):
|
||||
def __init__(self, image_size:str, api_key=None, model=None):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size: str,
|
||||
api_key=None,
|
||||
model=None,
|
||||
aiohttp_session: aiohttp.ClientSession | None = None,
|
||||
):
|
||||
super().__init__(image_size=image_size)
|
||||
api_key = api_key or os.getenv("OPEN_AI_KEY")
|
||||
self.model = model or os.getenv("OPEN_AI_IMAGE_MODEL") or "dall-e-3"
|
||||
self.client = AsyncOpenAI(api_key=api_key)
|
||||
self.aiohttp_session=aiohttp_session or aiohttp.ClientSession()
|
||||
|
||||
async def run_image_gen(self, sentence) -> tuple[str, bytes]:
|
||||
self.logger.info("Generating OpenAI image", sentence)
|
||||
@@ -70,10 +76,7 @@ class OpenAIImageGenService(ImageGenService):
|
||||
raise Exception("No image provided in response", image)
|
||||
|
||||
# Load the image from the url
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(image_url) as response:
|
||||
image_stream = io.BytesIO(await response.content.read())
|
||||
image = Image.open(image_stream)
|
||||
return (image_url, image.tobytes())
|
||||
|
||||
return (image_url, dalle_im.tobytes())
|
||||
async with self.aiohttp_session.get(image_url) as response:
|
||||
image_stream = io.BytesIO(await response.content.read())
|
||||
image = Image.open(image_stream)
|
||||
return (image_url, image.tobytes())
|
||||
|
||||
55
src/dailyai/services/whisper_ai_services.py
Normal file
@@ -0,0 +1,55 @@
+"""This module implements Whisper transcription with a locally-downloaded model."""
+import asyncio
+from enum import Enum
+import logging
+from typing import BinaryIO
+from faster_whisper import WhisperModel
+from dailyai.services.local_stt_service import LocalSTTService
+
+
+class Model(Enum):
+    """Class of basic Whisper model selection options"""
+    TINY = "tiny"
+    BASE = "base"
+    MEDIUM = "medium"
+    LARGE = "large-v3"
+    DISTIL_LARGE_V2 = "Systran/faster-distil-whisper-large-v2"
+    DISTIL_MEDIUM_EN = "Systran/faster-distil-whisper-medium.en"
+
+
+class WhisperSTTService(LocalSTTService):
+    """Class to transcribe audio with a locally-downloaded Whisper model"""
+    _model: WhisperModel
+
+    # Model configuration
+    _model_name: Model
+    _device: str
+    _compute_type: str
+
+    def __init__(self, model_name: Model = Model.DISTIL_MEDIUM_EN,
+                 device: str = "auto",
+                 compute_type: str = "default"):
+
+        super().__init__()
+        self.logger: logging.Logger = logging.getLogger("dailyai")
+        self._model_name = model_name
+        self._device = device
+        self._compute_type = compute_type
+        self._load()
+
+    def _load(self):
+        """Loads the Whisper model. Note that if this is the first time
+        this model is being run, it will take time to download."""
+        model = WhisperModel(
+            self._model_name.value,
+            device=self._device,
+            compute_type=self._compute_type)
+        self._model = model
+
+    async def run_stt(self, audio: BinaryIO = None) -> str:
+        """Transcribes given audio using Whisper"""
+        segments, _ = await asyncio.to_thread(self._model.transcribe, audio)
+        res: str = ""
+        for segment in segments:
+            res += f"{segment.text} "
+        return res
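`run_stt` pushes the blocking `transcribe` call onto a worker thread with `asyncio.to_thread`, so the event loop can keep servicing other coroutines. A minimal sketch of that offloading, where `blocking_transcribe` is a hypothetical stand-in for the model call and a ticker task shows the loop staying responsive. One caveat, stated as an assumption about the wrapped library: if it returns a lazy iterator of segments, the iteration would also need to happen inside the worker thread to keep the heavy work off the loop. The stand-in sidesteps this by returning a list.

```python
import asyncio
import time


def blocking_transcribe(audio: bytes) -> list:
    """Hypothetical stand-in for a blocking model call; returns text segments."""
    time.sleep(0.05)  # simulate slow, blocking work
    return ["hello", "world"]


async def run_stt(audio: bytes) -> str:
    # Run the blocking call in a worker thread so the event loop stays free.
    segments = await asyncio.to_thread(blocking_transcribe, audio)
    return " ".join(segments)


async def main() -> tuple:
    ticks = 0

    async def ticker() -> None:
        nonlocal ticks
        while True:
            await asyncio.sleep(0.01)
            ticks += 1

    task = asyncio.create_task(ticker())
    text = await run_stt(b"\x00\x00")
    task.cancel()
    return text, ticks
```

Running `asyncio.run(main())` returns the joined text together with a tick count greater than zero, which indicates the loop was not blocked during the call.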
@@ -3,33 +3,27 @@ import unittest
|
||||
|
||||
from typing import AsyncGenerator, Generator
|
||||
|
||||
from dailyai.services.ai_services import AIService, SentenceAggregator
|
||||
from dailyai.queue_frame import QueueFrame, FrameType
|
||||
from dailyai.services.ai_services import AIService
|
||||
from dailyai.queue_frame import EndStreamQueueFrame, QueueFrame, TextQueueFrame
|
||||
|
||||
class SimpleAIService(AIService):
|
||||
def allowed_input_frame_types(self) -> set[FrameType]:
|
||||
return set([FrameType.TEXT_CHUNK])
|
||||
|
||||
def possible_output_frame_types(self) -> set[FrameType]:
|
||||
return set([FrameType.TEXT_CHUNK])
|
||||
|
||||
async def process_frame(self, requested_frame_types: set[FrameType], frame: QueueFrame) -> QueueFrame | None:
|
||||
return frame
|
||||
async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
yield frame
|
||||
|
||||
class TestBaseAIService(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_async_input(self):
|
||||
service = SimpleAIService()
|
||||
|
||||
input_frames = [
|
||||
QueueFrame(FrameType.TEXT_CHUNK, "hello"),
|
||||
QueueFrame(FrameType.END_STREAM, None),
|
||||
TextQueueFrame("hello"),
|
||||
EndStreamQueueFrame()
|
||||
]
|
||||
async def iterate_frames() -> AsyncGenerator[QueueFrame, None]:
|
||||
for frame in input_frames:
|
||||
yield frame
|
||||
|
||||
output_frames = []
|
||||
async for frame in service.run(set([FrameType.TEXT_CHUNK]), iterate_frames()):
|
||||
async for frame in service.run(iterate_frames()):
|
||||
output_frames.append(frame)
|
||||
|
||||
self.assertEqual(input_frames, output_frames)
|
||||
@@ -37,93 +31,18 @@ class TestBaseAIService(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_nonasync_input(self):
|
||||
service = SimpleAIService()
|
||||
|
||||
input_frames = [
|
||||
QueueFrame(FrameType.TEXT_CHUNK, "hello"),
|
||||
QueueFrame(FrameType.END_STREAM, None),
|
||||
]
|
||||
input_frames = [TextQueueFrame("hello"), EndStreamQueueFrame()]
|
||||
|
||||
def iterate_frames() -> Generator[QueueFrame, None, None]:
|
||||
for frame in input_frames:
|
||||
yield frame
|
||||
|
||||
output_frames = []
|
||||
async for frame in service.run(set([FrameType.TEXT_CHUNK]), iterate_frames()):
|
||||
async for frame in service.run(iterate_frames()):
|
||||
output_frames.append(frame)
|
||||
|
||||
self.assertEqual(input_frames, output_frames)
|
||||
|
||||
|
||||
class TestSentenceAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_clause(self) -> None:
|
||||
input_frames = [
|
||||
QueueFrame(FrameType.TEXT_CHUNK, "hello"),
|
||||
QueueFrame(FrameType.END_STREAM, None),
|
||||
]
|
||||
|
||||
service = SentenceAggregator()
|
||||
output_frames = []
|
||||
async for frame in service.run(set([FrameType.SENTENCE]), input_frames):
|
||||
output_frames.append(frame)
|
||||
|
||||
self.assertEqual(1, len(output_frames))
|
||||
self.assertEqual(QueueFrame(FrameType.SENTENCE, "hello"), output_frames[0])
|
||||
|
||||
async def test_sentence(self) -> None:
|
||||
input_frames = [
|
||||
QueueFrame(FrameType.TEXT_CHUNK, "hello, "),
|
||||
QueueFrame(FrameType.TEXT_CHUNK, "world."),
|
||||
QueueFrame(FrameType.END_STREAM, None),
|
||||
]
|
||||
|
||||
service = SentenceAggregator()
|
||||
output_frames = []
|
||||
async for frame in service.run(set([FrameType.SENTENCE]), input_frames):
|
||||
output_frames.append(frame)
|
||||
|
||||
self.assertEqual(1, len(output_frames))
|
||||
self.assertEqual(QueueFrame(FrameType.SENTENCE, "hello, world."), output_frames[0])
|
||||
|
||||
async def test_sentence_and_clause(self) -> None:
|
||||
input_frames = [
|
||||
QueueFrame(FrameType.TEXT_CHUNK, "hello, "),
|
||||
QueueFrame(FrameType.TEXT_CHUNK, "world."),
|
||||
QueueFrame(FrameType.TEXT_CHUNK, " How are"),
|
||||
QueueFrame(FrameType.END_STREAM, None),
|
||||
]
|
||||
|
||||
service = SentenceAggregator()
|
||||
output_frames = []
|
||||
async for frame in service.run(set([FrameType.SENTENCE]), input_frames):
|
||||
output_frames.append(frame)
|
||||
|
||||
self.assertEqual(2, len(output_frames))
|
||||
self.assertEqual(
|
||||
QueueFrame(FrameType.SENTENCE, "hello, world."), output_frames[0]
|
||||
)
|
||||
self.assertEqual(
|
||||
QueueFrame(FrameType.SENTENCE, " How are"), output_frames[1]
|
||||
)
|
||||
|
||||
async def test_two_sentences(self) -> None:
|
||||
input_frames = [
|
||||
QueueFrame(FrameType.TEXT_CHUNK, "hello, "),
|
||||
QueueFrame(FrameType.TEXT_CHUNK, "world."),
|
||||
QueueFrame(FrameType.TEXT_CHUNK, " How are"),
|
||||
QueueFrame(FrameType.TEXT_CHUNK, " you doing?"),
|
||||
QueueFrame(FrameType.END_STREAM, None),
|
||||
]
|
||||
|
||||
service = SentenceAggregator()
|
||||
output_frames = []
|
||||
async for frame in service.run(set([FrameType.SENTENCE]), input_frames):
|
||||
output_frames.append(frame)
|
||||
|
||||
self.assertEqual(2, len(output_frames))
|
||||
self.assertEqual(
|
||||
QueueFrame(FrameType.SENTENCE, "hello, world."), output_frames[0]
|
||||
)
|
||||
self.assertEqual(QueueFrame(FrameType.SENTENCE, " How are you doing?"), output_frames[1])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
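The updated tests exercise the new service contract: `process_frame` is an async generator over typed frames, and `run` accepts either a plain iterable or an async iterable of frames. A minimal sketch of that contract, using hypothetical stand-ins (`Frame`, `TextFrame`, `EndStreamFrame`, `UppercaseService`) rather than the SDK's own classes, whose `run` implementation is not shown in this diff:

```python
import asyncio
from collections.abc import AsyncGenerator, AsyncIterable, Iterable
from dataclasses import dataclass
from typing import Union


class Frame:
    """Hypothetical base class standing in for QueueFrame."""


@dataclass
class TextFrame(Frame):
    text: str


@dataclass
class EndStreamFrame(Frame):
    pass


class UppercaseService:
    """Hypothetical service: transforms text frames, passes others through."""

    async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
        # Dispatch on the frame's class instead of a frame_type enum.
        if isinstance(frame, TextFrame):
            yield TextFrame(frame.text.upper())
        else:
            yield frame

    async def run(
        self, frames: Union[Iterable[Frame], AsyncIterable[Frame]]
    ) -> AsyncGenerator[Frame, None]:
        if isinstance(frames, AsyncIterable):
            async for frame in frames:
                async for out in self.process_frame(frame):
                    yield out
        else:
            for frame in frames:
                async for out in self.process_frame(frame):
                    yield out


async def collect(frames) -> list:
    return [f async for f in UppercaseService().run(frames)]
```

Because the frames are dataclasses, two frames with equal fields compare equal, which is what lets a test assert on whole lists of frames.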
@@ -1,180 +0,0 @@
|
||||
import time
|
||||
import unittest
|
||||
|
||||
from queue import Queue, Empty
|
||||
from threading import Thread, Event
|
||||
from typing import Generator
|
||||
|
||||
from dailyai.async_processor.async_processor import (
|
||||
AsyncProcessor,
|
||||
AsyncProcessorState,
|
||||
LLMResponse,
|
||||
)
|
||||
from dailyai.message_handler.message_handler import MessageHandler
|
||||
from dailyai.queue_frame import QueueFrame, FrameType
|
||||
from dailyai.services.ai_services import (
|
||||
AIServiceConfig,
|
||||
ImageGenService,
|
||||
LLMService,
|
||||
TTSService,
|
||||
)
|
||||
"""
|
||||
class MockTTSService(TTSService):
|
||||
def run_tts(self, sentence):
|
||||
for word in sentence.split(' '):
|
||||
time.sleep(0.1)
|
||||
yield bytes(word, "utf-8")
|
||||
|
||||
class MockLLMService(LLMService):
|
||||
def run_llm_async(self, messages) -> Generator[str, None, None]:
|
||||
for i in ["Hello ", "there.", "How are ", "you?", "I ", "hope ", "you ", "are ", "well."]:
|
||||
time.sleep(0.1)
|
||||
yield i
|
||||
|
||||
class MockImageService(ImageGenService):
|
||||
def run_image_gen(self, sentence) -> None:
|
||||
return None
|
||||
|
||||
class TestResponse(unittest.TestCase):
|
||||
def test_base_state_transitions(self):
|
||||
mock_tts_service = MockTTSService()
|
||||
mock_llm_service = MockLLMService()
|
||||
mock_image_service = MockImageService()
|
||||
processor = AsyncProcessor(AIServiceConfig(tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service))
|
||||
processor.prepare()
|
||||
processor.play()
|
||||
processor.finalize()
|
||||
self.assertEqual(processor.state, AsyncProcessorState.FINALIZED)
|
||||
|
||||
def test_state_transitions(self):
|
||||
output_queue = Queue()
|
||||
mock_tts_service = MockTTSService()
|
||||
mock_llm_service = MockLLMService()
|
||||
mock_image_service = MockImageService()
|
||||
message_handler = MessageHandler("Hello World")
|
||||
processor = LLMResponse(
|
||||
AIServiceConfig(
|
||||
tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service
|
||||
),
|
||||
message_handler,
|
||||
output_queue,
|
||||
)
|
||||
processor.prepare()
|
||||
processor.play()
|
||||
|
||||
# Consume the output from the output queue. It's necessary to mark these tasks as done for the
|
||||
# play function to return.
|
||||
expected_words = ["Hello", "there.", "How", "are", "you?", "I", "hope", "you", "are", "well."]
|
||||
|
||||
# remove the "start_stream" message from the queue
|
||||
output_queue.get()
|
||||
output_queue.task_done()
|
||||
|
||||
while expected_words:
|
||||
actual_word:QueueFrame = output_queue.get()
|
||||
word = expected_words.pop(0)
|
||||
self.assertEqual(actual_word.frame_type, FrameType.AUDIO_FRAME)
|
||||
self.assertEqual(actual_word.frame_data, bytes(word, "utf-8"))
|
||||
output_queue.task_done()
|
||||
|
||||
processor.finalize()
|
||||
|
||||
self.assertEqual(processor.state, AsyncProcessorState.FINALIZED)
|
||||
|
||||
def test_interrupt_preparation(self):
|
||||
output_queue = Queue()
|
||||
mock_tts_service = MockTTSService()
|
||||
mock_llm_service = MockLLMService()
|
||||
mock_image_service = MockImageService()
|
||||
message_handler = MessageHandler("System Message")
|
||||
processor = LLMResponse(
|
||||
AIServiceConfig(
|
||||
tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service
|
||||
),
|
||||
message_handler,
|
||||
output_queue,
|
||||
)
|
||||
processor.prepare()
|
||||
interrupt_request_at = time.perf_counter()
|
||||
processor.interrupt()
|
||||
processor.finalize()
|
||||
finalized_at = time.perf_counter()
|
||||
self.assertTrue(0.1 < finalized_at - interrupt_request_at < 0.2)
|
||||
print(f"delta: {interrupt_request_at, finalized_at}")
|
||||
self.assertEqual(processor.state, AsyncProcessorState.FINALIZED)
|
||||
|
||||
    def test_interrupt_play(self):
        output_queue = Queue()
        mock_tts_service = MockTTSService()
        mock_llm_service = MockLLMService()
        mock_image_service = MockImageService()
        message_handler = MessageHandler("System Message")
        processor = LLMResponse(
            AIServiceConfig(
                tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service
            ),
            message_handler,
            output_queue,
        )
        processor.prepare()
        processor.play()

        stop_processing_output_queue = Event()
        def process_output_queue_async():
            # Consume the output from the output queue. It's necessary to mark these tasks as done for the
            # play function to return.
            time.sleep(0.1)
            expected_words = ["Hello", "there.", "How", "are", "you?", "I", "hope", "you", "are", "well."]
            while expected_words and not stop_processing_output_queue.is_set():
                try:
                    actual_word:QueueFrame = output_queue.get_nowait()
                    if actual_word.frame_type == FrameType.AUDIO_FRAME:
                        time.sleep(0.1)
                        word = expected_words.pop(0)
                        self.assertEqual(actual_word.frame_type, FrameType.AUDIO_FRAME)
                        self.assertEqual(actual_word.frame_data, bytes(word, "utf-8"))
                    output_queue.task_done()
                except Empty:
                    pass

        process_output_queue = Thread(target=process_output_queue_async, daemon=True)
        process_output_queue.start()

        time.sleep(0.5)
        processor.interrupt()

        stop_processing_output_queue.set()
        process_output_queue.join()

        processor.finalize()
        self.assertEqual(processor.state, AsyncProcessorState.FINALIZED)

    def test_statechange_callback(self):
        mock_tts_service = MockTTSService()
        mock_llm_service = MockLLMService()
        mock_image_service = MockImageService()
        processor = AsyncProcessor(
            AIServiceConfig(
                tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service
            )
        )
        is_finalized = False
        def set_is_finalized(async_processor:AsyncProcessor):
            nonlocal is_finalized
            is_finalized = True

        processor.set_state_callback(
            AsyncProcessorState.FINALIZED, set_is_finalized
        )
        processor.prepare()
        self.assertFalse(is_finalized)
        processor.play()
        self.assertFalse(is_finalized)
        processor.finalize()
        self.assertTrue(is_finalized)
        self.assertEqual(processor.state, AsyncProcessorState.FINALIZED)


if __name__ == '__main__':
    unittest.main()
"""
@@ -1,147 +0,0 @@
import time
import unittest

from unittest.mock import MagicMock, call

from dailyai.message_handler.message_handler import MessageHandler, IndexingMessageHandler
from dailyai.services.ai_services import (
    AIServiceConfig,
    TTSService,
    LLMService,
    ImageGenService,
)
from ..storage.search import SearchIndexer


class TestMessageHandler(unittest.TestCase):
    def test_simple_intro(self):
        message_handler = MessageHandler("Hello world")
        self.assertEqual(
            message_handler.get_llm_messages(),
            [{"role": "system", "content": "Hello world"}],
        )

    def test_simple_user_message(self):
        message_handler = MessageHandler("System prompt")
        message_handler.add_user_message("User message")
        self.assertEqual(
            message_handler.get_llm_messages(),
            [
                {"role": "system", "content": "System prompt"},
                {"role": "user", "content": "User message"},
            ],
        )

    def test_simple_user_and_assistant_message(self):
        message_handler = MessageHandler("System prompt")
        message_handler.add_user_message("User message")
        message_handler.add_assistant_message("Assistant message")
        self.assertEqual(
            message_handler.get_llm_messages(),
            [
                {"role": "system", "content": "System prompt"},
                {"role": "user", "content": "User message"},
                {"role": "assistant", "content": "Assistant message"},
            ],
        )

    def test_user_message_overwrite(self):
        message_handler = MessageHandler("System prompt")
        message_handler.add_user_message("User message")
        message_handler.add_assistant_message("Assistant message")
        message_handler.add_user_message("plus something else")
        self.assertEqual(
            message_handler.get_llm_messages(),
            [
                {"role": "system", "content": "System prompt"},
                {"role": "user", "content": "User message plus something else"},
            ],
        )

    def test_user_message_after_assistant(self):
        message_handler = MessageHandler("System prompt")
        message_handler.add_user_message("User message")
        message_handler.add_assistant_message("Assistant message")
        message_handler.finalize_user_message()
        message_handler.add_user_message("other user message")
        self.assertEqual(
            message_handler.get_llm_messages(),
            [
                {"role": "system", "content": "System prompt"},
                {"role": "user", "content": "User message"},
                {"role": "assistant", "content": "Assistant message"},
                {"role": "user", "content": "other user message"},
            ],
        )


class MockTTSService(TTSService):
    def run_tts(self, sentence):
        for word in sentence.split(" "):
            time.sleep(0.1)
            yield bytes(word, "utf-8")


class MockLLMService(LLMService):
    def run_llm(self, messages) -> str:
        return "Parsed user message."

class MockImageService(ImageGenService):
    def run_image_gen(self, sentence) -> None:
        return None


class TestStorageMessageHandler(unittest.TestCase):
    def test_user_message_finalized(self):
        mock_tts_service = MockTTSService()
        mock_llm_service = MockLLMService()
        mock_image_service = MockImageService()

        service_config = AIServiceConfig(
            tts=mock_tts_service, llm=mock_llm_service, image=mock_image_service
        )

        mock_indexer = MagicMock(spec=SearchIndexer)

        message_handler = IndexingMessageHandler(
            "Hello world", service_config, mock_indexer
        )
        message_handler.cleanup_user_message = MagicMock(return_value="Parsed user message.")
        message_handler.add_user_message("User message")
        message_handler.add_assistant_message("Assistant message will be ignored")
        message_handler.add_user_message("plus something else")
        message_handler.finalize_user_message()
        message_handler.add_assistant_message(
            "New assistant message will not be ignored"
        )
        message_handler.add_user_message("User message second time")
        message_handler.add_assistant_message("Assistant message second time")
        message_handler.write_messages_to_storage()

        time.sleep(0.5)
        message_handler.cleanup_user_message.assert_called_with("User message plus something else")
        self.assertEqual(
            mock_indexer.mock_calls,
            [
                call.index_text('"Parsed user message."'),
                call.index_text("New assistant message will not be ignored"),
            ],
        )

        mock_indexer.reset_mock()

        message_handler.finalize_user_message()

        time.sleep(0.5)

        self.assertEqual(
            mock_indexer.mock_calls,
            [
                call.index_text('"Parsed user message."'),
                call.index_text("Assistant message second time"),
            ],
        )


if __name__ == "__main__":
    unittest.main()
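Review note: the removed TestMessageHandler cases describe a merge rule that is easy to miss when reading the diff. Until finalize_user_message() is called, a new user message is appended to the open one, and any assistant reply recorded in between is dropped. The following is a minimal sketch that satisfies those four test cases. It is not the dailyai implementation: the class name SketchMessageHandler and the _open_user_index field are hypothetical, and the real MessageHandler may track this state differently.

```python
class SketchMessageHandler:
    """Hypothetical stand-in that mirrors the behaviour asserted in the removed tests."""

    def __init__(self, system_prompt):
        self._messages = [{"role": "system", "content": system_prompt}]
        # Index of the user message that can still be extended, or None.
        self._open_user_index = None

    def add_user_message(self, text):
        if self._open_user_index is not None:
            # The user kept talking: discard any provisional assistant reply
            # and extend the open user message.
            del self._messages[self._open_user_index + 1:]
            self._messages[self._open_user_index]["content"] += " " + text
        else:
            self._messages.append({"role": "user", "content": text})
            self._open_user_index = len(self._messages) - 1

    def add_assistant_message(self, text):
        self._messages.append({"role": "assistant", "content": text})

    def finalize_user_message(self):
        # After this call the next user message starts a new entry.
        self._open_user_index = None

    def get_llm_messages(self):
        return list(self._messages)
```

With this sketch, add_user_message("User message"), add_assistant_message("Assistant message"), add_user_message("plus something else") leaves only the system prompt and one user entry, matching test_user_message_overwrite.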
@@ -1,10 +1,7 @@
import argparse
import asyncio
from typing import AsyncGenerator

from dailyai.queue_frame import QueueFrame, FrameType
from dailyai.services.daily_transport_service import DailyTransportService
from dailyai.services.azure_ai_services import AzureTTSService
from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService

async def main(room_url):
@@ -1,10 +1,8 @@
import argparse
import asyncio
from typing import AsyncGenerator

from dailyai.queue_frame import QueueFrame, FrameType
from dailyai.queue_frame import LLMMessagesQueueFrame
from dailyai.services.daily_transport_service import DailyTransportService
from dailyai.services.ai_services import SentenceAggregator
from dailyai.services.azure_ai_services import AzureLLMService
from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService

@@ -28,9 +26,7 @@ async def main(room_url):
    tts_task = asyncio.create_task(
        tts.run_to_queue(
            transport.send_queue,
            SentenceAggregator().run(
                llm.run([QueueFrame(FrameType.LLM_MESSAGE, messages)])
            )
            llm.run([LLMMessagesQueueFrame(messages)]),
        )
    )

@@ -1,9 +1,10 @@
import argparse
import asyncio

from dailyai.queue_frame import QueueFrame, FrameType
from dailyai.queue_frame import TextQueueFrame
from dailyai.services.daily_transport_service import DailyTransportService
from dailyai.services.open_ai_services import OpenAIImageGenService
from dailyai.services.azure_ai_services import AzureImageGenServiceREST

local_joined = False
participant_joined = False
@@ -21,9 +22,9 @@ async def main(room_url):
    transport.camera_width = 1024
    transport.camera_height = 1024

    imagegen = OpenAIImageGenService(image_size="1024x1024")
    imagegen = AzureImageGenServiceREST(image_size="1024x1024")
    image_task = asyncio.create_task(
        imagegen.run_to_queue(transport.send_queue, [QueueFrame(FrameType.IMAGE_DESCRIPTION, "a cat in the style of picasso")])
        imagegen.run_to_queue(transport.send_queue, [TextQueueFrame("a cat in the style of picasso")])
    )

    @transport.event_handler("on_participant_joined")
@@ -2,10 +2,9 @@ import argparse
import asyncio
import re

from dailyai.services.ai_services import SentenceAggregator
from dailyai.services.daily_transport_service import DailyTransportService
from dailyai.services.azure_ai_services import AzureLLMService, AzureTTSService
from dailyai.queue_frame import QueueFrame, FrameType
from dailyai.queue_frame import EndStreamQueueFrame, LLMMessagesQueueFrame
from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService

async def main(room_url:str):
@@ -36,9 +35,7 @@ async def main(room_url:str):
    llm_response_task = asyncio.create_task(
        elevenlabs_tts.run_to_queue(
            buffer_queue,
            SentenceAggregator().run(
                llm.run([QueueFrame(FrameType.LLM_MESSAGE, messages)])
            ),
            llm.run([LLMMessagesQueueFrame(messages)]),
            True,
        )
    )
@@ -48,17 +45,14 @@ async def main(room_url:str):
        if participant["id"] == transport.my_participant_id:
            return

        await azure_tts.run_to_queue(
            transport.send_queue,
            [QueueFrame(FrameType.SENTENCE, "My friend the LLM is now going to tell a joke about llamas.")]
        )
        await azure_tts.say("My friend the LLM is now going to tell a joke about llamas.", transport.send_queue)

    async def buffer_to_send_queue():
        while True:
            frame = await buffer_queue.get()
            await transport.send_queue.put(frame)
            buffer_queue.task_done()
            if frame.frame_type == FrameType.END_STREAM:
            if isinstance(frame, EndStreamQueueFrame):
                break

    await asyncio.gather(llm_response_task, buffer_to_send_queue())
@@ -1,14 +1,9 @@
import argparse
import asyncio

from asyncio.queues import Queue
import re

from dailyai.queue_frame import QueueFrame, FrameType
from dailyai.services.ai_services import SentenceAggregator
from dailyai.services.azure_ai_services import AzureLLMService
from dailyai.queue_frame import AudioQueueFrame, ImageQueueFrame
from dailyai.services.azure_ai_services import AzureImageGenServiceREST, AzureLLMService
from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
from dailyai.services.open_ai_services import OpenAIImageGenService
from dailyai.services.daily_transport_service import DailyTransportService
from dailyai.services.fal_ai_services import FalImageGenService

@@ -27,9 +22,9 @@ async def main(room_url):
    transport.camera_height = 1024

    llm = AzureLLMService()
    dalle = FalImageGenService(image_size="1024x1024")
    #dalle = FalImageGenService(image_size="1024x1024")
    tts = ElevenLabsTTSService(voice_id="ErXwobaYiN019PkySvjV")
    # dalle = OpenAIImageGenService(image_size="1024x1024")
    dalle = AzureImageGenServiceREST(image_size="1024x1024")

    # Get a complete audio chunk from the given text. Splitting this into its own
    # coroutine lets us ensure proper ordering of the audio chunks on the send queue.
@@ -49,14 +44,20 @@ async def main(room_url):
        ]

        image_description = await llm.run_llm(messages)
        if not image_description:
            return

        to_speak = f"{month}: {image_description}"
        audio_task = asyncio.create_task(get_all_audio(to_speak))
        image_task = asyncio.create_task(dalle.run_image_gen(image_description))
        (audio, image_data) = await asyncio.gather(
            get_all_audio(to_speak), dalle.run_image_gen(image_description)
            audio_task, image_task
        )

        return {
            "month": month,
            "text": image_description,
            "image_url": image_data[0],
            "image": image_data[1],
            "audio": audio,
        }
@@ -85,8 +86,8 @@ async def main(room_url):
            data = await month_data_task
            await transport.send_queue.put(
                [
                    QueueFrame(FrameType.IMAGE, data["image"]),
                    QueueFrame(FrameType.AUDIO, data["audio"]),
                    ImageQueueFrame(data["image_url"], data["image"]),
                    AudioQueueFrame(data["audio"]),
                ]
            )

@@ -3,11 +3,10 @@ import asyncio
import requests
import time
import urllib.parse
from dailyai.services.ai_services import SentenceAggregator

from dailyai.services.daily_transport_service import DailyTransportService
from dailyai.services.azure_ai_services import AzureLLMService, AzureTTSService
from dailyai.queue_frame import QueueFrame, FrameType
from dailyai.queue_aggregators import LLMContextAggregator

async def main(room_url:str, token):
    global transport
@@ -18,7 +17,7 @@ async def main(room_url:str, token):
        room_url,
        token,
        "Respond bot",
        1,
        5,
    )
    transport.mic_enabled = True
    transport.mic_sample_rate = 16000
@@ -27,33 +26,31 @@ async def main(room_url:str, token):
    llm = AzureLLMService()
    tts = AzureTTSService()

    @transport.event_handler("on_first_other_participant_joined")
    async def on_first_other_participant_joined(transport):
        await tts.say("Hi, I'm listening!", transport.send_queue)

    async def handle_transcriptions():
        messages = [
            {"role": "system", "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio. Respond to what the user said in a creative and helpful way."},
        ]

        sentence = ""
        async for frame in transport.get_receive_frames():
            if frame.frame_type != FrameType.TRANSCRIPTION:
                continue

            message = frame.frame_data
            if message["session_id"] == transport.my_participant_id:
                continue

            # todo: we could differentiate between transcriptions from different participants
            sentence += message["text"]
            if sentence.endswith((".", "?", "!")):
                messages.append({"role": "user", "content": sentence})
                sentence = ''

                full_response = ""
                async for response in llm.run_llm_async_sentences(messages):
                    full_response += response
                    async for audio in tts.run_tts(response):
                        await transport.send_queue.put(QueueFrame(FrameType.AUDIO, audio))

                messages.append({"role": "assistant", "content": full_response})
        tma_in = LLMContextAggregator(
            messages, "user", transport.my_participant_id
        )
        tma_out = LLMContextAggregator(
            messages, "assistant", transport.my_participant_id
        )
        await tts.run_to_queue(
            transport.send_queue,
            tma_out.run(
                llm.run(
                    tma_in.run(
                        transport.get_receive_frames()
                    )
                )
            )
        )

    transport.transcription_settings["extra"]["punctuate"] = True
    await asyncio.gather(transport.run(), handle_transcriptions())
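Review note: the hunk above replaces a hand-written transcription loop with nested run(...) calls, where each stage consumes the frames produced by the stage inside it. The sketch below illustrates that general pattern with plain async generators. It makes no claim about how LLMContextAggregator, llm.run, or tts.run_to_queue are implemented; source, upper_case, exclaim, and drain_to_queue are hypothetical stages that pass strings instead of queue frames.

```python
import asyncio


async def source(items):
    # Innermost stage: yields the raw input items one at a time.
    for item in items:
        yield item


async def upper_case(frames):
    # Middle stage: transforms each frame as it arrives.
    async for frame in frames:
        yield frame.upper()


async def exclaim(frames):
    # Outer stage: another per-frame transformation.
    async for frame in frames:
        yield frame + "!"


async def drain_to_queue(queue, frames):
    # Final stage: pulls frames through the whole chain and enqueues them.
    async for frame in frames:
        await queue.put(frame)


async def pipeline_demo():
    queue = asyncio.Queue()
    await drain_to_queue(queue, exclaim(upper_case(source(["hello", "there"]))))
    results = []
    while not queue.empty():
        results.append(queue.get_nowait())
    return results


def run_pipeline_demo():
    return asyncio.run(pipeline_demo())
```

Nothing runs until the final stage starts iterating, so frames flow through the chain one at a time rather than being buffered between stages.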
98
src/samples/foundational/07-interruptible.py
Normal file
@@ -0,0 +1,98 @@
import argparse
import asyncio
import requests
import time
import urllib.parse
from dailyai.conversation_wrappers import InterruptibleConversationWrapper

from dailyai.queue_frame import StartStreamQueueFrame, TextQueueFrame
from dailyai.services.daily_transport_service import DailyTransportService
from dailyai.services.azure_ai_services import AzureLLMService
from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService


async def main(room_url:str, token):
    global transport
    global llm
    global tts

    transport = DailyTransportService(
        room_url,
        token,
        "Respond bot",
        5,
    )
    transport.mic_enabled = True
    transport.mic_sample_rate = 16000
    transport.camera_enabled = False

    llm = AzureLLMService()
    tts = ElevenLabsTTSService(voice_id="ErXwobaYiN019PkySvjV")

    async def run_response(user_speech, tma_in, tma_out):
        await tts.run_to_queue(
            transport.send_queue,
            tma_out.run(
                llm.run(
                    tma_in.run(
                        [StartStreamQueueFrame(), TextQueueFrame(user_speech)]
                    )
                )
            ),
        )

    @transport.event_handler("on_first_other_participant_joined")
    async def on_first_other_participant_joined(transport):
        await tts.say("Hi, I'm listening!", transport.send_queue)

    async def run_conversation():
        messages = [
            {"role": "system", "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio. Respond to what the user said in a creative and helpful way."},
        ]

        conversation_wrapper = InterruptibleConversationWrapper(
            frame_generator=transport.get_receive_frames,
            runner=run_response,
            interrupt=transport.interrupt,
            my_participant_id=transport.my_participant_id,
            llm_messages=messages,
        )
        await conversation_wrapper.run_conversation()

    transport.transcription_settings["extra"]["punctuate"] = False
    await asyncio.gather(transport.run(), run_conversation())


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Simple Daily Bot Sample")
    parser.add_argument(
        "-u", "--url", type=str, required=True, help="URL of the Daily room to join"
    )
    parser.add_argument(
        "-k",
        "--apikey",
        type=str,
        required=True,
        help="Daily API Key (needed to create token)",
    )

    args, unknown = parser.parse_known_args()

    # Create a meeting token for the given room with an expiration 1 hour in the future.
    room_name: str = urllib.parse.urlparse(args.url).path[1:]
    expiration: float = time.time() + 60 * 60

    res: requests.Response = requests.post(
        f"https://api.daily.co/v1/meeting-tokens",
        headers={"Authorization": f"Bearer {args.apikey}"},
        json={
            "properties": {"room_name": room_name, "is_owner": True, "exp": expiration}
        },
    )

    if res.status_code != 200:
        raise Exception(f"Failed to create meeting token: {res.status_code} {res.text}")

    token: str = res.json()["token"]

    asyncio.run(main(args.url, token))
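Review note: this sample hands run_response and transport.interrupt to InterruptibleConversationWrapper, whose implementation is not part of this diff. One common way to make a response interruptible in asyncio is to run it as a task and cancel that task when new user speech arrives. The sketch below shows only that generic technique, as an assumption about the overall shape. speak and interrupt_demo are hypothetical names, and the wrapper may work quite differently.

```python
import asyncio


async def speak(words, spoken):
    # Stand-in for a TTS pipeline: "plays" one word every 10 ms.
    for word in words:
        await asyncio.sleep(0.01)
        spoken.append(word)


async def interrupt_demo():
    spoken = []
    words = [f"word{i}" for i in range(50)]
    # Start the response without waiting for it to finish.
    response = asyncio.create_task(speak(words, spoken))
    # Simulate the user starting to talk again shortly afterwards.
    await asyncio.sleep(0.03)
    # Interrupt the in-flight response.
    response.cancel()
    try:
        await response
    except asyncio.CancelledError:
        # Cancellation is the expected outcome in this sketch.
        pass
    return spoken, response.cancelled()


def run_interrupt_demo():
    return asyncio.run(interrupt_demo())
```

Only a prefix of the words is "spoken" before the cancel takes effect; how many depends on timing, so the sketch makes no claim about the exact count.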
44
src/samples/foundational/07-whisper-transcription.py
Normal file
@@ -0,0 +1,44 @@
import argparse
import asyncio

from dailyai.services.daily_transport_service import DailyTransportService
from dailyai.services.whisper_ai_services import WhisperSTTService


async def main(room_url: str):
    global transport
    global stt

    transport = DailyTransportService(
        room_url,
        None,
        "Transcription bot",
    )
    transport.mic_enabled = False
    transport.camera_enabled = False
    transport.speaker_enabled = True
    stt = WhisperSTTService()
    transcription_output_queue = asyncio.Queue()

    async def handle_transcription():
        print("`````````TRANSCRIPTION`````````")
        while True:
            item = await transcription_output_queue.get()
            print(item.text)

    async def handle_speaker():
        await stt.run_to_queue(
            transcription_output_queue,
            transport.get_receive_frames()
        )
    await asyncio.gather(transport.run(), handle_speaker(), handle_transcription())


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Simple Daily Bot Sample")
    parser.add_argument(
        "-u", "--url", type=str, required=True, help="URL of the Daily room to join"
    )

    args, unknown = parser.parse_known_args()
    asyncio.run(main(args.url))
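Review note: handle_transcription in this sample loops forever on transcription_output_queue.get(), so the asyncio.gather call only returns if something cancels it. The sketch below shows the same producer/consumer arrangement with an explicit end-of-stream sentinel so that it can run to completion on its own. The None sentinel and the names produce, consume, and queue_demo are assumptions made for illustration; the sample itself defines no such sentinel.

```python
import asyncio


async def produce(queue, texts):
    # Stand-in for stt.run_to_queue: pushes results onto the queue.
    for text in texts:
        await queue.put(text)
    # Hypothetical end-of-stream marker; the sample above has none.
    await queue.put(None)


async def consume(queue, seen):
    # Stand-in for handle_transcription: reads until the sentinel arrives.
    while True:
        item = await queue.get()
        if item is None:
            break
        seen.append(item)


async def queue_demo():
    queue = asyncio.Queue()
    seen = []
    # Run producer and consumer concurrently, as the sample does with gather.
    await asyncio.gather(produce(queue, ["hello", "world"]), consume(queue, seen))
    return seen


def run_queue_demo():
    return asyncio.run(queue_demo())
```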