Cleanup: no more sentence aggregator, let the TTS service deal with that; also removed the queue typing stuff from ai_services
This commit is contained in:
@@ -2,16 +2,14 @@ from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
|
||||
class FrameType(Enum):
|
||||
NOOP = -1
|
||||
START_STREAM = 0
|
||||
END_STREAM = 1
|
||||
AUDIO = 2
|
||||
IMAGE = 3
|
||||
SENTENCE = 4
|
||||
TEXT_CHUNK = 5
|
||||
LLM_MESSAGE = 6
|
||||
APP_MESSAGE = 7
|
||||
IMAGE_DESCRIPTION = 8
|
||||
TRANSCRIPTION = 9
|
||||
TEXT = 4
|
||||
LLM_MESSAGE = 5
|
||||
APP_MESSAGE = 6
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class QueueFrame:
|
||||
|
||||
@@ -21,12 +21,6 @@ class AIService:
|
||||
def stop(self):
|
||||
pass
|
||||
|
||||
def allowed_input_frame_types(self) -> set[FrameType]:
|
||||
return set()
|
||||
|
||||
def possible_output_frame_types(self) -> set[FrameType]:
|
||||
return set()
|
||||
|
||||
async def run_to_queue(self, queue: asyncio.Queue, frames, add_end_of_stream=False) -> None:
|
||||
async for frame in self.run(frames):
|
||||
await queue.put(frame)
|
||||
@@ -39,76 +33,50 @@ class AIService:
|
||||
frames: Iterable[QueueFrame]
|
||||
| AsyncIterable[QueueFrame]
|
||||
| asyncio.Queue[QueueFrame],
|
||||
requested_frame_types: set[FrameType] | None=None,
|
||||
) -> AsyncGenerator[QueueFrame, None]:
|
||||
if requested_frame_types and self.possible_output_frame_types().intersection(requested_frame_types) == set():
|
||||
raise Exception(f"Requested frame types {requested_frame_types} are not supported by this service.")
|
||||
try:
|
||||
if isinstance(frames, AsyncIterable):
|
||||
async for frame in frames:
|
||||
async for output_frame in self.process_frame(frame):
|
||||
yield output_frame
|
||||
elif isinstance(frames, Iterable):
|
||||
for frame in frames:
|
||||
async for output_frame in self.process_frame(frame):
|
||||
yield output_frame
|
||||
elif isinstance(frames, asyncio.Queue):
|
||||
while True:
|
||||
frame = await frames.get()
|
||||
async for output_frame in self.process_frame(frame):
|
||||
yield output_frame
|
||||
if frame.frame_type == FrameType.END_STREAM:
|
||||
break
|
||||
else:
|
||||
raise Exception("Frames must be an iterable or async iterable")
|
||||
|
||||
if not requested_frame_types:
|
||||
requested_frame_types = self.possible_output_frame_types()
|
||||
|
||||
if isinstance(frames, AsyncIterable):
|
||||
async for frame in frames:
|
||||
async for output_frame in self.process_frame(requested_frame_types, frame):
|
||||
yield output_frame
|
||||
elif isinstance(frames, Iterable):
|
||||
for frame in frames:
|
||||
async for output_frame in self.process_frame(requested_frame_types, frame):
|
||||
yield output_frame
|
||||
elif isinstance(frames, asyncio.Queue):
|
||||
while True:
|
||||
frame = await frames.get()
|
||||
async for output_frame in self.process_frame(requested_frame_types, frame):
|
||||
yield output_frame
|
||||
if frame.frame_type == FrameType.END_STREAM:
|
||||
break
|
||||
else:
|
||||
raise Exception("Frames must be an iterable or async iterable")
|
||||
async for output_frame in self.finalize():
|
||||
yield output_frame
|
||||
except Exception as e:
|
||||
self.logger.error("Exception occurred while running AI service", e)
|
||||
raise e
|
||||
|
||||
@abstractmethod
|
||||
async def process_frame(self, requested_frame_types:set[FrameType], frame:QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
# Yield something so the linter can deduce what should happen here.
|
||||
yield QueueFrame(FrameType.END_STREAM, None)
|
||||
|
||||
class SentenceAggregator(AIService):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.current_sentence = ""
|
||||
|
||||
def allowed_input_frame_types(self) -> set[FrameType]:
|
||||
return set([FrameType.TEXT_CHUNK, FrameType.SENTENCE])
|
||||
|
||||
def possible_output_frame_types(self) -> set[FrameType]:
|
||||
return set([FrameType.SENTENCE])
|
||||
|
||||
async def process_frame(self, requested_frame_types: set[FrameType], frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
if not FrameType.SENTENCE in requested_frame_types:
|
||||
return
|
||||
|
||||
if frame.frame_type == FrameType.TEXT_CHUNK:
|
||||
if type(frame.frame_data) != str:
|
||||
raise Exception(
|
||||
"Sentence aggregator requires a string for the data field"
|
||||
)
|
||||
|
||||
self.current_sentence += frame.frame_data
|
||||
if self.current_sentence.endswith((".", "?", "!")):
|
||||
sentence = self.current_sentence
|
||||
self.current_sentence = ""
|
||||
yield QueueFrame(FrameType.SENTENCE, sentence)
|
||||
elif frame.frame_type == FrameType.END_STREAM:
|
||||
if self.current_sentence:
|
||||
yield QueueFrame(FrameType.SENTENCE, self.current_sentence)
|
||||
elif frame.frame_type == FrameType.SENTENCE:
|
||||
yield frame
|
||||
async def process_frame(self, frame:QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
# This is a trick for the interpreter (and linter) to know that this is a generator.
|
||||
if False:
|
||||
yield QueueFrame(FrameType.NOOP, None)
|
||||
|
||||
@abstractmethod
|
||||
async def finalize(self) -> AsyncGenerator[QueueFrame, None]:
|
||||
# This is a trick for the interpreter (and linter) to know that this is a generator.
|
||||
if False:
|
||||
yield QueueFrame(FrameType.NOOP, None)
|
||||
|
||||
class LLMService(AIService):
|
||||
def allowed_input_frame_types(self) -> set[FrameType]:
|
||||
return set([FrameType.LLM_MESSAGE, FrameType.SENTENCE, FrameType.TRANSCRIPTION])
|
||||
return set([FrameType.LLM_MESSAGE])
|
||||
|
||||
def allowed_output_frame_types(self) -> set[FrameType]:
|
||||
return set([FrameType.SENTENCE, FrameType.TEXT_CHUNK])
|
||||
return set([FrameType.TEXT])
|
||||
|
||||
@abstractmethod
|
||||
async def run_llm_async(self, messages) -> AsyncGenerator[str, None]:
|
||||
@@ -118,52 +86,58 @@ class LLMService(AIService):
|
||||
async def run_llm(self, messages) -> str:
|
||||
pass
|
||||
|
||||
async def process_frame(self, requested_frame_types: set[FrameType], frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
if frame.frame_type == FrameType.LLM_MESSAGE:
|
||||
if type(frame.frame_data) != list:
|
||||
raise Exception("LLM service requires a dict for the data field")
|
||||
|
||||
messages: list[dict[str, str]] = frame.frame_data
|
||||
if FrameType.SENTENCE in requested_frame_types:
|
||||
yield QueueFrame(FrameType.SENTENCE, await self.run_llm(messages))
|
||||
else:
|
||||
async for text_chunk in self.run_llm_async(messages):
|
||||
yield QueueFrame(FrameType.TEXT_CHUNK, text_chunk)
|
||||
|
||||
# TODO: handle other frame types! Need to aggregate into messages
|
||||
async for text_chunk in self.run_llm_async(messages):
|
||||
yield QueueFrame(FrameType.TEXT, text_chunk)
|
||||
|
||||
|
||||
class TTSService(AIService):
|
||||
def __init__(self, aggregate_sentences=True):
|
||||
super().__init__()
|
||||
self.aggregate_sentences: bool = aggregate_sentences
|
||||
self.current_sentence: str = ""
|
||||
|
||||
# Some TTS services require a specific sample rate. We default to 16k
|
||||
def get_mic_sample_rate(self):
|
||||
return 16000
|
||||
|
||||
def allowed_input_frame_types(self) -> set[FrameType]:
|
||||
return set([FrameType.SENTENCE, FrameType.TRANSCRIPTION, FrameType.TEXT_CHUNK])
|
||||
|
||||
def possible_output_frame_types(self) -> set[FrameType]:
|
||||
return set([FrameType.AUDIO])
|
||||
|
||||
# Converts the sentence to audio. Yields a list of audio frames that can
|
||||
# Converts the text to audio. Yields a list of audio frames that can
|
||||
# be sent to the microphone device
|
||||
@abstractmethod
|
||||
async def run_tts(self, sentence) -> AsyncGenerator[bytes, None]:
|
||||
async def run_tts(self, text) -> AsyncGenerator[bytes, None]:
|
||||
# yield empty bytes here, so linting can infer what this method does
|
||||
yield bytes()
|
||||
|
||||
async def process_frame(self, requested_frame_types: set[FrameType], frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
if not FrameType.AUDIO in requested_frame_types:
|
||||
return
|
||||
async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
if frame.frame_type != FrameType.TEXT or type(frame.frame_data) != str:
|
||||
raise Exception(f"TTS service requires a string for the data field, got {frame.frame_type} and frame_data type {type(frame.frame_data)}")
|
||||
|
||||
if type(frame.frame_data) != str:
|
||||
raise Exception("TTS service requires a string for the data field")
|
||||
text: str | None = None
|
||||
if not self.aggregate_sentences:
|
||||
text = frame.frame_data
|
||||
else:
|
||||
self.current_sentence += frame.frame_data
|
||||
if self.current_sentence.endswith((".", "?", "!")):
|
||||
text = self.current_sentence
|
||||
self.current_sentence = ""
|
||||
|
||||
async for audio_chunk in self.run_tts(frame.frame_data):
|
||||
yield QueueFrame(FrameType.AUDIO, audio_chunk)
|
||||
if text:
|
||||
async for audio_chunk in self.run_tts(text):
|
||||
yield QueueFrame(FrameType.AUDIO, audio_chunk)
|
||||
|
||||
async def finalize(self):
|
||||
if self.current_sentence:
|
||||
async for audio_chunk in self.run_tts(self.current_sentence):
|
||||
yield QueueFrame(FrameType.AUDIO, audio_chunk)
|
||||
|
||||
# Convenience function to send the audio for a sentence to the given queue
|
||||
async def say(self, sentence, queue: asyncio.Queue):
|
||||
await self.run_to_queue(queue, [QueueFrame(FrameType.SENTENCE, sentence)])
|
||||
await self.run_to_queue(queue, [QueueFrame(FrameType.TEXT, sentence)])
|
||||
|
||||
|
||||
class ImageGenService(AIService):
|
||||
@@ -171,21 +145,12 @@ class ImageGenService(AIService):
|
||||
super().__init__(**kwargs)
|
||||
self.image_size = image_size
|
||||
|
||||
def allowed_input_frame_types(self) -> set[FrameType]:
|
||||
return set([FrameType.SENTENCE, FrameType.TRANSCRIPTION, FrameType.TEXT_CHUNK, FrameType.IMAGE_DESCRIPTION])
|
||||
|
||||
def possible_output_frame_types(self) -> set[FrameType]:
|
||||
return set([FrameType.IMAGE])
|
||||
|
||||
# Renders the image. Returns an Image object.
|
||||
@abstractmethod
|
||||
async def run_image_gen(self, sentence) -> tuple[str, bytes]:
|
||||
pass
|
||||
|
||||
async def process_frame(self, requested_frame_types: set[FrameType], frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
if not FrameType.IMAGE in requested_frame_types:
|
||||
return
|
||||
|
||||
async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
if type(frame.frame_data) != str:
|
||||
raise Exception("Image service requires a string for the data field")
|
||||
|
||||
|
||||
@@ -279,7 +279,7 @@ class DailyTransportService(EventHandler):
|
||||
|
||||
def on_transcription_message(self, message:dict):
|
||||
if self.loop:
|
||||
frame = QueueFrame(FrameType.TRANSCRIPTION, message)
|
||||
frame = QueueFrame(FrameType.TEXT, message)
|
||||
asyncio.run_coroutine_threadsafe(self.receive_queue.put(frame), self.loop)
|
||||
|
||||
def on_transcription_stopped(self, stopped_by, stopped_by_error):
|
||||
|
||||
@@ -3,25 +3,19 @@ import unittest
|
||||
|
||||
from typing import AsyncGenerator, Generator
|
||||
|
||||
from dailyai.services.ai_services import AIService, SentenceAggregator
|
||||
from dailyai.services.ai_services import AIService
|
||||
from dailyai.queue_frame import QueueFrame, FrameType
|
||||
|
||||
class SimpleAIService(AIService):
|
||||
def allowed_input_frame_types(self) -> set[FrameType]:
|
||||
return set([FrameType.TEXT_CHUNK])
|
||||
|
||||
def possible_output_frame_types(self) -> set[FrameType]:
|
||||
return set([FrameType.TEXT_CHUNK])
|
||||
|
||||
async def process_frame(self, requested_frame_types: set[FrameType], frame: QueueFrame) -> QueueFrame | None:
|
||||
return frame
|
||||
async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
|
||||
yield frame
|
||||
|
||||
class TestBaseAIService(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_async_input(self):
|
||||
service = SimpleAIService()
|
||||
|
||||
input_frames = [
|
||||
QueueFrame(FrameType.TEXT_CHUNK, "hello"),
|
||||
QueueFrame(FrameType.TEXT, "hello"),
|
||||
QueueFrame(FrameType.END_STREAM, None),
|
||||
]
|
||||
async def iterate_frames() -> AsyncGenerator[QueueFrame, None]:
|
||||
@@ -29,7 +23,7 @@ class TestBaseAIService(unittest.IsolatedAsyncioTestCase):
|
||||
yield frame
|
||||
|
||||
output_frames = []
|
||||
async for frame in service.run(set([FrameType.TEXT_CHUNK]), iterate_frames()):
|
||||
async for frame in service.run(iterate_frames()):
|
||||
output_frames.append(frame)
|
||||
|
||||
self.assertEqual(input_frames, output_frames)
|
||||
@@ -38,7 +32,7 @@ class TestBaseAIService(unittest.IsolatedAsyncioTestCase):
|
||||
service = SimpleAIService()
|
||||
|
||||
input_frames = [
|
||||
QueueFrame(FrameType.TEXT_CHUNK, "hello"),
|
||||
QueueFrame(FrameType.TEXT, "hello"),
|
||||
QueueFrame(FrameType.END_STREAM, None),
|
||||
]
|
||||
|
||||
@@ -47,83 +41,11 @@ class TestBaseAIService(unittest.IsolatedAsyncioTestCase):
|
||||
yield frame
|
||||
|
||||
output_frames = []
|
||||
async for frame in service.run(set([FrameType.TEXT_CHUNK]), iterate_frames()):
|
||||
async for frame in service.run(iterate_frames()):
|
||||
output_frames.append(frame)
|
||||
|
||||
self.assertEqual(input_frames, output_frames)
|
||||
|
||||
|
||||
class TestSentenceAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_clause(self) -> None:
|
||||
input_frames = [
|
||||
QueueFrame(FrameType.TEXT_CHUNK, "hello"),
|
||||
QueueFrame(FrameType.END_STREAM, None),
|
||||
]
|
||||
|
||||
service = SentenceAggregator()
|
||||
output_frames = []
|
||||
async for frame in service.run(set([FrameType.SENTENCE]), input_frames):
|
||||
output_frames.append(frame)
|
||||
|
||||
self.assertEqual(1, len(output_frames))
|
||||
self.assertEqual(QueueFrame(FrameType.SENTENCE, "hello"), output_frames[0])
|
||||
|
||||
async def test_sentence(self) -> None:
|
||||
input_frames = [
|
||||
QueueFrame(FrameType.TEXT_CHUNK, "hello, "),
|
||||
QueueFrame(FrameType.TEXT_CHUNK, "world."),
|
||||
QueueFrame(FrameType.END_STREAM, None),
|
||||
]
|
||||
|
||||
service = SentenceAggregator()
|
||||
output_frames = []
|
||||
async for frame in service.run(set([FrameType.SENTENCE]), input_frames):
|
||||
output_frames.append(frame)
|
||||
|
||||
self.assertEqual(1, len(output_frames))
|
||||
self.assertEqual(QueueFrame(FrameType.SENTENCE, "hello, world."), output_frames[0])
|
||||
|
||||
async def test_sentence_and_clause(self) -> None:
|
||||
input_frames = [
|
||||
QueueFrame(FrameType.TEXT_CHUNK, "hello, "),
|
||||
QueueFrame(FrameType.TEXT_CHUNK, "world."),
|
||||
QueueFrame(FrameType.TEXT_CHUNK, " How are"),
|
||||
QueueFrame(FrameType.END_STREAM, None),
|
||||
]
|
||||
|
||||
service = SentenceAggregator()
|
||||
output_frames = []
|
||||
async for frame in service.run(set([FrameType.SENTENCE]), input_frames):
|
||||
output_frames.append(frame)
|
||||
|
||||
self.assertEqual(2, len(output_frames))
|
||||
self.assertEqual(
|
||||
QueueFrame(FrameType.SENTENCE, "hello, world."), output_frames[0]
|
||||
)
|
||||
self.assertEqual(
|
||||
QueueFrame(FrameType.SENTENCE, " How are"), output_frames[1]
|
||||
)
|
||||
|
||||
async def test_two_sentences(self) -> None:
|
||||
input_frames = [
|
||||
QueueFrame(FrameType.TEXT_CHUNK, "hello, "),
|
||||
QueueFrame(FrameType.TEXT_CHUNK, "world."),
|
||||
QueueFrame(FrameType.TEXT_CHUNK, " How are"),
|
||||
QueueFrame(FrameType.TEXT_CHUNK, " you doing?"),
|
||||
QueueFrame(FrameType.END_STREAM, None),
|
||||
]
|
||||
|
||||
service = SentenceAggregator()
|
||||
output_frames = []
|
||||
async for frame in service.run(set([FrameType.SENTENCE]), input_frames):
|
||||
output_frames.append(frame)
|
||||
|
||||
self.assertEqual(2, len(output_frames))
|
||||
self.assertEqual(
|
||||
QueueFrame(FrameType.SENTENCE, "hello, world."), output_frames[0]
|
||||
)
|
||||
self.assertEqual(QueueFrame(FrameType.SENTENCE, " How are you doing?"), output_frames[1])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import AsyncGenerator
|
||||
|
||||
from dailyai.queue_frame import QueueFrame, FrameType
|
||||
from dailyai.services.daily_transport_service import DailyTransportService
|
||||
from dailyai.services.ai_services import SentenceAggregator
|
||||
from dailyai.services.azure_ai_services import AzureLLMService
|
||||
from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
|
||||
|
||||
@@ -28,9 +27,7 @@ async def main(room_url):
|
||||
tts_task = asyncio.create_task(
|
||||
tts.run_to_queue(
|
||||
transport.send_queue,
|
||||
SentenceAggregator().run(
|
||||
llm.run([QueueFrame(FrameType.LLM_MESSAGE, messages)])
|
||||
)
|
||||
llm.run([QueueFrame(FrameType.LLM_MESSAGE, messages)])
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ async def main(room_url):
|
||||
|
||||
imagegen = OpenAIImageGenService(image_size="1024x1024")
|
||||
image_task = asyncio.create_task(
|
||||
imagegen.run_to_queue(transport.send_queue, [QueueFrame(FrameType.IMAGE_DESCRIPTION, "a cat in the style of picasso")])
|
||||
imagegen.run_to_queue(transport.send_queue, [QueueFrame(FrameType.TEXT, "a cat in the style of picasso")])
|
||||
)
|
||||
|
||||
@transport.event_handler("on_participant_joined")
|
||||
|
||||
@@ -2,7 +2,6 @@ import argparse
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
from dailyai.services.ai_services import SentenceAggregator
|
||||
from dailyai.services.daily_transport_service import DailyTransportService
|
||||
from dailyai.services.azure_ai_services import AzureLLMService, AzureTTSService
|
||||
from dailyai.queue_frame import QueueFrame, FrameType
|
||||
@@ -36,9 +35,7 @@ async def main(room_url:str):
|
||||
llm_response_task = asyncio.create_task(
|
||||
elevenlabs_tts.run_to_queue(
|
||||
buffer_queue,
|
||||
SentenceAggregator().run(
|
||||
llm.run([QueueFrame(FrameType.LLM_MESSAGE, messages)])
|
||||
),
|
||||
llm.run([QueueFrame(FrameType.LLM_MESSAGE, messages)]),
|
||||
True,
|
||||
)
|
||||
)
|
||||
@@ -48,10 +45,7 @@ async def main(room_url:str):
|
||||
if participant["id"] == transport.my_participant_id:
|
||||
return
|
||||
|
||||
await azure_tts.run_to_queue(
|
||||
transport.send_queue,
|
||||
[QueueFrame(FrameType.SENTENCE, "My friend the LLM is now going to tell a joke about llamas.")]
|
||||
)
|
||||
await azure_tts.say("My friend the LLM is now going to tell a joke about llamas.", transport.send_queue)
|
||||
|
||||
async def buffer_to_send_queue():
|
||||
while True:
|
||||
|
||||
@@ -5,7 +5,6 @@ from asyncio.queues import Queue
|
||||
import re
|
||||
|
||||
from dailyai.queue_frame import QueueFrame, FrameType
|
||||
from dailyai.services.ai_services import SentenceAggregator
|
||||
from dailyai.services.azure_ai_services import AzureLLMService
|
||||
from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
|
||||
from dailyai.services.open_ai_services import OpenAIImageGenService
|
||||
|
||||
@@ -33,7 +33,7 @@ async def main(room_url:str, token):
|
||||
|
||||
sentence = ""
|
||||
async for frame in transport.get_receive_frames():
|
||||
if frame.frame_type != FrameType.TRANSCRIPTION:
|
||||
if frame.frame_type != FrameType.TEXT:
|
||||
continue
|
||||
|
||||
message = frame.frame_data
|
||||
|
||||
Reference in New Issue
Block a user