Cleanup: no more sentence aggregator, let the TTS service deal with that; also removed the queue typing stuff from ai_services

This commit is contained in:
Moishe Lettvin
2024-01-19 13:06:15 -05:00
parent 1071dede1a
commit 6ae733ebfe
9 changed files with 82 additions and 207 deletions

View File

@@ -2,16 +2,14 @@ from enum import Enum
from dataclasses import dataclass
class FrameType(Enum):
NOOP = -1
START_STREAM = 0
END_STREAM = 1
AUDIO = 2
IMAGE = 3
SENTENCE = 4
TEXT_CHUNK = 5
LLM_MESSAGE = 6
APP_MESSAGE = 7
IMAGE_DESCRIPTION = 8
TRANSCRIPTION = 9
TEXT = 4
LLM_MESSAGE = 5
APP_MESSAGE = 6
@dataclass(frozen=True)
class QueueFrame:

View File

@@ -21,12 +21,6 @@ class AIService:
def stop(self):
pass
def allowed_input_frame_types(self) -> set[FrameType]:
return set()
def possible_output_frame_types(self) -> set[FrameType]:
return set()
async def run_to_queue(self, queue: asyncio.Queue, frames, add_end_of_stream=False) -> None:
async for frame in self.run(frames):
await queue.put(frame)
@@ -39,76 +33,50 @@ class AIService:
frames: Iterable[QueueFrame]
| AsyncIterable[QueueFrame]
| asyncio.Queue[QueueFrame],
requested_frame_types: set[FrameType] | None=None,
) -> AsyncGenerator[QueueFrame, None]:
if requested_frame_types and self.possible_output_frame_types().intersection(requested_frame_types) == set():
raise Exception(f"Requested frame types {requested_frame_types} are not supported by this service.")
try:
if isinstance(frames, AsyncIterable):
async for frame in frames:
async for output_frame in self.process_frame(frame):
yield output_frame
elif isinstance(frames, Iterable):
for frame in frames:
async for output_frame in self.process_frame(frame):
yield output_frame
elif isinstance(frames, asyncio.Queue):
while True:
frame = await frames.get()
async for output_frame in self.process_frame(frame):
yield output_frame
if frame.frame_type == FrameType.END_STREAM:
break
else:
raise Exception("Frames must be an iterable or async iterable")
if not requested_frame_types:
requested_frame_types = self.possible_output_frame_types()
if isinstance(frames, AsyncIterable):
async for frame in frames:
async for output_frame in self.process_frame(requested_frame_types, frame):
yield output_frame
elif isinstance(frames, Iterable):
for frame in frames:
async for output_frame in self.process_frame(requested_frame_types, frame):
yield output_frame
elif isinstance(frames, asyncio.Queue):
while True:
frame = await frames.get()
async for output_frame in self.process_frame(requested_frame_types, frame):
yield output_frame
if frame.frame_type == FrameType.END_STREAM:
break
else:
raise Exception("Frames must be an iterable or async iterable")
async for output_frame in self.finalize():
yield output_frame
except Exception as e:
self.logger.error("Exception occurred while running AI service", e)
raise e
@abstractmethod
async def process_frame(self, requested_frame_types:set[FrameType], frame:QueueFrame) -> AsyncGenerator[QueueFrame, None]:
# Yield something so the linter can deduce what should happen here.
yield QueueFrame(FrameType.END_STREAM, None)
class SentenceAggregator(AIService):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.current_sentence = ""
def allowed_input_frame_types(self) -> set[FrameType]:
return set([FrameType.TEXT_CHUNK, FrameType.SENTENCE])
def possible_output_frame_types(self) -> set[FrameType]:
return set([FrameType.SENTENCE])
async def process_frame(self, requested_frame_types: set[FrameType], frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
if not FrameType.SENTENCE in requested_frame_types:
return
if frame.frame_type == FrameType.TEXT_CHUNK:
if type(frame.frame_data) != str:
raise Exception(
"Sentence aggregator requires a string for the data field"
)
self.current_sentence += frame.frame_data
if self.current_sentence.endswith((".", "?", "!")):
sentence = self.current_sentence
self.current_sentence = ""
yield QueueFrame(FrameType.SENTENCE, sentence)
elif frame.frame_type == FrameType.END_STREAM:
if self.current_sentence:
yield QueueFrame(FrameType.SENTENCE, self.current_sentence)
elif frame.frame_type == FrameType.SENTENCE:
yield frame
async def process_frame(self, frame:QueueFrame) -> AsyncGenerator[QueueFrame, None]:
# This is a trick for the interpreter (and linter) to know that this is a generator.
if False:
yield QueueFrame(FrameType.NOOP, None)
@abstractmethod
async def finalize(self) -> AsyncGenerator[QueueFrame, None]:
# This is a trick for the interpreter (and linter) to know that this is a generator.
if False:
yield QueueFrame(FrameType.NOOP, None)
class LLMService(AIService):
def allowed_input_frame_types(self) -> set[FrameType]:
return set([FrameType.LLM_MESSAGE, FrameType.SENTENCE, FrameType.TRANSCRIPTION])
return set([FrameType.LLM_MESSAGE])
def allowed_output_frame_types(self) -> set[FrameType]:
return set([FrameType.SENTENCE, FrameType.TEXT_CHUNK])
return set([FrameType.TEXT])
@abstractmethod
async def run_llm_async(self, messages) -> AsyncGenerator[str, None]:
@@ -118,52 +86,58 @@ class LLMService(AIService):
async def run_llm(self, messages) -> str:
pass
async def process_frame(self, requested_frame_types: set[FrameType], frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
if frame.frame_type == FrameType.LLM_MESSAGE:
if type(frame.frame_data) != list:
raise Exception("LLM service requires a dict for the data field")
messages: list[dict[str, str]] = frame.frame_data
if FrameType.SENTENCE in requested_frame_types:
yield QueueFrame(FrameType.SENTENCE, await self.run_llm(messages))
else:
async for text_chunk in self.run_llm_async(messages):
yield QueueFrame(FrameType.TEXT_CHUNK, text_chunk)
# TODO: handle other frame types! Need to aggregate into messages
async for text_chunk in self.run_llm_async(messages):
yield QueueFrame(FrameType.TEXT, text_chunk)
class TTSService(AIService):
def __init__(self, aggregate_sentences=True):
super().__init__()
self.aggregate_sentences: bool = aggregate_sentences
self.current_sentence: str = ""
# Some TTS services require a specific sample rate. We default to 16k
def get_mic_sample_rate(self):
return 16000
def allowed_input_frame_types(self) -> set[FrameType]:
return set([FrameType.SENTENCE, FrameType.TRANSCRIPTION, FrameType.TEXT_CHUNK])
def possible_output_frame_types(self) -> set[FrameType]:
return set([FrameType.AUDIO])
# Converts the sentence to audio. Yields a list of audio frames that can
# Converts the text to audio. Yields a list of audio frames that can
# be sent to the microphone device
@abstractmethod
async def run_tts(self, sentence) -> AsyncGenerator[bytes, None]:
async def run_tts(self, text) -> AsyncGenerator[bytes, None]:
# yield empty bytes here, so linting can infer what this method does
yield bytes()
async def process_frame(self, requested_frame_types: set[FrameType], frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
if not FrameType.AUDIO in requested_frame_types:
return
async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
if frame.frame_type != FrameType.TEXT or type(frame.frame_data) != str:
raise Exception(f"TTS service requires a string for the data field, got {frame.frame_type} and frame_data type {type(frame.frame_data)}")
if type(frame.frame_data) != str:
raise Exception("TTS service requires a string for the data field")
text: str | None = None
if not self.aggregate_sentences:
text = frame.frame_data
else:
self.current_sentence += frame.frame_data
if self.current_sentence.endswith((".", "?", "!")):
text = self.current_sentence
self.current_sentence = ""
async for audio_chunk in self.run_tts(frame.frame_data):
yield QueueFrame(FrameType.AUDIO, audio_chunk)
if text:
async for audio_chunk in self.run_tts(text):
yield QueueFrame(FrameType.AUDIO, audio_chunk)
async def finalize(self):
if self.current_sentence:
async for audio_chunk in self.run_tts(self.current_sentence):
yield QueueFrame(FrameType.AUDIO, audio_chunk)
# Convenience function to send the audio for a sentence to the given queue
async def say(self, sentence, queue: asyncio.Queue):
await self.run_to_queue(queue, [QueueFrame(FrameType.SENTENCE, sentence)])
await self.run_to_queue(queue, [QueueFrame(FrameType.TEXT, sentence)])
class ImageGenService(AIService):
@@ -171,21 +145,12 @@ class ImageGenService(AIService):
super().__init__(**kwargs)
self.image_size = image_size
def allowed_input_frame_types(self) -> set[FrameType]:
return set([FrameType.SENTENCE, FrameType.TRANSCRIPTION, FrameType.TEXT_CHUNK, FrameType.IMAGE_DESCRIPTION])
def possible_output_frame_types(self) -> set[FrameType]:
return set([FrameType.IMAGE])
# Renders the image. Returns an Image object.
@abstractmethod
async def run_image_gen(self, sentence) -> tuple[str, bytes]:
pass
async def process_frame(self, requested_frame_types: set[FrameType], frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
if not FrameType.IMAGE in requested_frame_types:
return
async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
if type(frame.frame_data) != str:
raise Exception("Image service requires a string for the data field")

View File

@@ -279,7 +279,7 @@ class DailyTransportService(EventHandler):
def on_transcription_message(self, message:dict):
if self.loop:
frame = QueueFrame(FrameType.TRANSCRIPTION, message)
frame = QueueFrame(FrameType.TEXT, message)
asyncio.run_coroutine_threadsafe(self.receive_queue.put(frame), self.loop)
def on_transcription_stopped(self, stopped_by, stopped_by_error):

View File

@@ -3,25 +3,19 @@ import unittest
from typing import AsyncGenerator, Generator
from dailyai.services.ai_services import AIService, SentenceAggregator
from dailyai.services.ai_services import AIService
from dailyai.queue_frame import QueueFrame, FrameType
class SimpleAIService(AIService):
def allowed_input_frame_types(self) -> set[FrameType]:
return set([FrameType.TEXT_CHUNK])
def possible_output_frame_types(self) -> set[FrameType]:
return set([FrameType.TEXT_CHUNK])
async def process_frame(self, requested_frame_types: set[FrameType], frame: QueueFrame) -> QueueFrame | None:
return frame
async def process_frame(self, frame: QueueFrame) -> AsyncGenerator[QueueFrame, None]:
yield frame
class TestBaseAIService(unittest.IsolatedAsyncioTestCase):
async def test_async_input(self):
service = SimpleAIService()
input_frames = [
QueueFrame(FrameType.TEXT_CHUNK, "hello"),
QueueFrame(FrameType.TEXT, "hello"),
QueueFrame(FrameType.END_STREAM, None),
]
async def iterate_frames() -> AsyncGenerator[QueueFrame, None]:
@@ -29,7 +23,7 @@ class TestBaseAIService(unittest.IsolatedAsyncioTestCase):
yield frame
output_frames = []
async for frame in service.run(set([FrameType.TEXT_CHUNK]), iterate_frames()):
async for frame in service.run(iterate_frames()):
output_frames.append(frame)
self.assertEqual(input_frames, output_frames)
@@ -38,7 +32,7 @@ class TestBaseAIService(unittest.IsolatedAsyncioTestCase):
service = SimpleAIService()
input_frames = [
QueueFrame(FrameType.TEXT_CHUNK, "hello"),
QueueFrame(FrameType.TEXT, "hello"),
QueueFrame(FrameType.END_STREAM, None),
]
@@ -47,83 +41,11 @@ class TestBaseAIService(unittest.IsolatedAsyncioTestCase):
yield frame
output_frames = []
async for frame in service.run(set([FrameType.TEXT_CHUNK]), iterate_frames()):
async for frame in service.run(iterate_frames()):
output_frames.append(frame)
self.assertEqual(input_frames, output_frames)
class TestSentenceAggregator(unittest.IsolatedAsyncioTestCase):
async def test_clause(self) -> None:
input_frames = [
QueueFrame(FrameType.TEXT_CHUNK, "hello"),
QueueFrame(FrameType.END_STREAM, None),
]
service = SentenceAggregator()
output_frames = []
async for frame in service.run(set([FrameType.SENTENCE]), input_frames):
output_frames.append(frame)
self.assertEqual(1, len(output_frames))
self.assertEqual(QueueFrame(FrameType.SENTENCE, "hello"), output_frames[0])
async def test_sentence(self) -> None:
input_frames = [
QueueFrame(FrameType.TEXT_CHUNK, "hello, "),
QueueFrame(FrameType.TEXT_CHUNK, "world."),
QueueFrame(FrameType.END_STREAM, None),
]
service = SentenceAggregator()
output_frames = []
async for frame in service.run(set([FrameType.SENTENCE]), input_frames):
output_frames.append(frame)
self.assertEqual(1, len(output_frames))
self.assertEqual(QueueFrame(FrameType.SENTENCE, "hello, world."), output_frames[0])
async def test_sentence_and_clause(self) -> None:
input_frames = [
QueueFrame(FrameType.TEXT_CHUNK, "hello, "),
QueueFrame(FrameType.TEXT_CHUNK, "world."),
QueueFrame(FrameType.TEXT_CHUNK, " How are"),
QueueFrame(FrameType.END_STREAM, None),
]
service = SentenceAggregator()
output_frames = []
async for frame in service.run(set([FrameType.SENTENCE]), input_frames):
output_frames.append(frame)
self.assertEqual(2, len(output_frames))
self.assertEqual(
QueueFrame(FrameType.SENTENCE, "hello, world."), output_frames[0]
)
self.assertEqual(
QueueFrame(FrameType.SENTENCE, " How are"), output_frames[1]
)
async def test_two_sentences(self) -> None:
input_frames = [
QueueFrame(FrameType.TEXT_CHUNK, "hello, "),
QueueFrame(FrameType.TEXT_CHUNK, "world."),
QueueFrame(FrameType.TEXT_CHUNK, " How are"),
QueueFrame(FrameType.TEXT_CHUNK, " you doing?"),
QueueFrame(FrameType.END_STREAM, None),
]
service = SentenceAggregator()
output_frames = []
async for frame in service.run(set([FrameType.SENTENCE]), input_frames):
output_frames.append(frame)
self.assertEqual(2, len(output_frames))
self.assertEqual(
QueueFrame(FrameType.SENTENCE, "hello, world."), output_frames[0]
)
self.assertEqual(QueueFrame(FrameType.SENTENCE, " How are you doing?"), output_frames[1])
if __name__ == "__main__":
unittest.main()

View File

@@ -4,7 +4,6 @@ from typing import AsyncGenerator
from dailyai.queue_frame import QueueFrame, FrameType
from dailyai.services.daily_transport_service import DailyTransportService
from dailyai.services.ai_services import SentenceAggregator
from dailyai.services.azure_ai_services import AzureLLMService
from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
@@ -28,9 +27,7 @@ async def main(room_url):
tts_task = asyncio.create_task(
tts.run_to_queue(
transport.send_queue,
SentenceAggregator().run(
llm.run([QueueFrame(FrameType.LLM_MESSAGE, messages)])
)
llm.run([QueueFrame(FrameType.LLM_MESSAGE, messages)])
)
)

View File

@@ -23,7 +23,7 @@ async def main(room_url):
imagegen = OpenAIImageGenService(image_size="1024x1024")
image_task = asyncio.create_task(
imagegen.run_to_queue(transport.send_queue, [QueueFrame(FrameType.IMAGE_DESCRIPTION, "a cat in the style of picasso")])
imagegen.run_to_queue(transport.send_queue, [QueueFrame(FrameType.TEXT, "a cat in the style of picasso")])
)
@transport.event_handler("on_participant_joined")

View File

@@ -2,7 +2,6 @@ import argparse
import asyncio
import re
from dailyai.services.ai_services import SentenceAggregator
from dailyai.services.daily_transport_service import DailyTransportService
from dailyai.services.azure_ai_services import AzureLLMService, AzureTTSService
from dailyai.queue_frame import QueueFrame, FrameType
@@ -36,9 +35,7 @@ async def main(room_url:str):
llm_response_task = asyncio.create_task(
elevenlabs_tts.run_to_queue(
buffer_queue,
SentenceAggregator().run(
llm.run([QueueFrame(FrameType.LLM_MESSAGE, messages)])
),
llm.run([QueueFrame(FrameType.LLM_MESSAGE, messages)]),
True,
)
)
@@ -48,10 +45,7 @@ async def main(room_url:str):
if participant["id"] == transport.my_participant_id:
return
await azure_tts.run_to_queue(
transport.send_queue,
[QueueFrame(FrameType.SENTENCE, "My friend the LLM is now going to tell a joke about llamas.")]
)
await azure_tts.say("My friend the LLM is now going to tell a joke about llamas.", transport.send_queue)
async def buffer_to_send_queue():
while True:

View File

@@ -5,7 +5,6 @@ from asyncio.queues import Queue
import re
from dailyai.queue_frame import QueueFrame, FrameType
from dailyai.services.ai_services import SentenceAggregator
from dailyai.services.azure_ai_services import AzureLLMService
from dailyai.services.elevenlabs_ai_service import ElevenLabsTTSService
from dailyai.services.open_ai_services import OpenAIImageGenService

View File

@@ -33,7 +33,7 @@ async def main(room_url:str, token):
sentence = ""
async for frame in transport.get_receive_frames():
if frame.frame_type != FrameType.TRANSCRIPTION:
if frame.frame_type != FrameType.TEXT:
continue
message = frame.frame_data