Compare commits

...

1 Commits

Author SHA1 Message Date
Chad Bailey
f1b093d9f1 translation use case working 2024-04-02 15:12:47 +00:00
5 changed files with 72 additions and 33 deletions

View File

@@ -52,7 +52,7 @@ class TranslationProcessor(FrameProcessor):
}, },
{"role": "user", "content": frame.text}, {"role": "user", "content": frame.text},
] ]
yield LLMMessagesFrame(context) yield LLMMessagesFrame(context, participantId=frame.participantId)
else: else:
yield frame yield frame
@@ -62,11 +62,15 @@ class TranslationSubtitles(FrameProcessor):
self._language = language self._language = language
async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
print(f"!!! got a frame: {frame}")
if isinstance(frame, TextFrame): if isinstance(frame, TextFrame):
print(f"!!! got a textframe: {frame}")
app_message = { app_message = {
"language": self._language, "language": self._language,
"text": frame.text "text": frame.text,
"speakerId": frame.participantId
} }
print(f"!!! App message contents: {app_message}")
yield SendAppMessageFrame(app_message, None) yield SendAppMessageFrame(app_message, None)
yield frame yield frame
else: else:

View File

@@ -171,7 +171,8 @@ class LLMAssistantContextAggregator(LLMContextAggregator):
class SentenceAggregator(FrameProcessor): class SentenceAggregator(FrameProcessor):
"""This frame processor aggregates text frames into complete sentences. """This frame processor aggregates text frames into complete sentences. It separates
frames by participant if the participant_id field is defined in the TextFrame.
Frame input/output: Frame input/output:
TextFrame("Hello,") -> None TextFrame("Hello,") -> None
@@ -183,26 +184,33 @@ class SentenceAggregator(FrameProcessor):
... print(frame.text) ... print(frame.text)
>>> aggregator = SentenceAggregator() >>> aggregator = SentenceAggregator()
>>> asyncio.run(print_frames(aggregator, TextFrame("Hello,"))) >>> asyncio.run(print_frames(aggregator, TextFrame("Hello,", participantId="abcd")))
>>> asyncio.run(print_frames(aggregator, TextFrame(" world."))) >>> asyncio.run(print_frames(aggregator, TextFrame(" world.", participantId="cdef")))
Hello, world. world.
>>> asyncio.run(print_frames(aggregator, TextFrame(" everyone.", participantId="abcd")))
Hello, everyone.
""" """
def __init__(self): def __init__(self):
self.aggregation = "" self.aggregations = {}
async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
if isinstance(frame, TextFrame): if isinstance(frame, TextFrame):
pax_id = "none"
if frame.participantId:
pax_id = frame.participantId
if pax_id not in self.aggregations:
self.aggregations[pax_id] = ""
m = re.search("(.*[?.!])(.*)", frame.text) m = re.search("(.*[?.!])(.*)", frame.text)
if m: if m:
yield TextFrame(self.aggregation + m.group(1)) yield TextFrame(self.aggregations[pax_id] + m.group(1), participantId=pax_id)
self.aggregation = m.group(2) self.aggregations[pax_id] = m.group(2)
else: else:
self.aggregation += frame.text self.aggregations[pax_id] += frame.text
elif isinstance(frame, EndFrame): elif isinstance(frame, EndFrame):
if self.aggregation: for key in self.aggregations:
yield TextFrame(self.aggregation) if self.aggregations[key]:
yield frame yield TextFrame(self.aggregation, participantId=key)
else: else:
yield frame yield frame
@@ -245,15 +253,24 @@ class LLMFullResponseAggregator(FrameProcessor):
""" """
def __init__(self): def __init__(self):
self.aggregation = "" self.aggregations = {}
async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]: async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
if isinstance(frame, TextFrame): if isinstance(frame, TextFrame):
self.aggregation += frame.text pax_id = "none"
if frame.participantId:
pax_id = frame.participantId
if pax_id not in self.aggregations:
self.aggregations[pax_id] = ""
self.aggregations[pax_id] += frame.text
elif isinstance(frame, LLMResponseEndFrame): elif isinstance(frame, LLMResponseEndFrame):
yield TextFrame(self.aggregation) pax_id = "none"
if frame.participantId:
pax_id = frame.participantId
if self.aggregations[pax_id]:
yield TextFrame(self.aggregations[pax_id], participantId=pax_id)
yield frame yield frame
self.aggregation = "" self.aggregations[pax_id] = ""
else: else:
yield frame yield frame
@@ -390,3 +407,8 @@ class GatedAggregator(FrameProcessor):
self.accumulator = [] self.accumulator = []
else: else:
self.accumulator.append(frame) self.accumulator.append(frame)
if __name__ == "__main__":
import doctest
doctest.testmod()

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, List from typing import Any, List, Optional
from dailyai.services.openai_llm_context import OpenAILLMContext from dailyai.services.openai_llm_context import OpenAILLMContext
import dailyai.pipeline.protobufs.frames_pb2 as frame_protos import dailyai.pipeline.protobufs.frames_pb2 as frame_protos
@@ -48,15 +48,17 @@ class PipelineStartedFrame(ControlFrame):
pass pass
class LLMResponseStartFrame(ControlFrame): @dataclass
class LLMResponseStartFrame(Frame):
"""Used to indicate the beginning of an LLM response. Following TextFrames """Used to indicate the beginning of an LLM response. Following TextFrames
are part of the LLM response until an LLMResponseEndFrame""" are part of the LLM response until an LLMResponseEndFrame"""
pass participantId: Optional[str] = None
class LLMResponseEndFrame(ControlFrame): @dataclass
class LLMResponseEndFrame(Frame):
"""Indicates the end of an LLM response.""" """Indicates the end of an LLM response."""
pass participantId: Optional[str] = None
@dataclass() @dataclass()
@@ -64,6 +66,7 @@ class AudioFrame(Frame):
"""A chunk of audio. Will be played by the transport if the transport's mic """A chunk of audio. Will be played by the transport if the transport's mic
has been enabled.""" has been enabled."""
data: bytes data: bytes
participantId: Optional[str] = None
def __str__(self): def __str__(self):
return f"{self.__class__.__name__}, size: {len(self.data)} B" return f"{self.__class__.__name__}, size: {len(self.data)} B"
@@ -75,6 +78,7 @@ class ImageFrame(Frame):
enabled.""" enabled."""
url: str | None url: str | None
image: bytes image: bytes
participantId: Optional[str] = None
def __str__(self): def __str__(self):
return f"{self.__class__.__name__}, url: {self.url}, image size: {len(self.image)} B" return f"{self.__class__.__name__}, url: {self.url}, image size: {len(self.image)} B"
@@ -96,16 +100,16 @@ class TextFrame(Frame):
"""A chunk of text. Emitted by LLM services, consumed by TTS services, can """A chunk of text. Emitted by LLM services, consumed by TTS services, can
be used to send text through pipelines.""" be used to send text through pipelines."""
text: str text: str
participantId: Optional[str] = None
def __str__(self): def __str__(self):
return f'{self.__class__.__name__}: "{self.text}"' return f'{self.__class__.__name__}: "{self.text}", participantId: {self.participantId}'
@dataclass() @dataclass(kw_only=True)
class TranscriptionFrame(TextFrame): class TranscriptionFrame(TextFrame):
"""A text frame with transcription-specific data. Will be placed in the """A text frame with transcription-specific data. Will be placed in the
transport's receive queue when a participant speaks.""" transport's receive queue when a participant speaks."""
participantId: str
timestamp: str timestamp: str
def __str__(self): def __str__(self):
@@ -133,6 +137,10 @@ class LLMMessagesFrame(Frame):
Note that the messages property on this class is mutable, and will be Note that the messages property on this class is mutable, and will be
be updated by various ResponseAggregator frame processors.""" be updated by various ResponseAggregator frame processors."""
messages: List[dict] messages: List[dict]
participantId: Optional[str] = None
def __str__(self):
return f"{self.__class__.__name__}, participantId: {self.participantId}"
@dataclass() @dataclass()
@@ -141,6 +149,7 @@ class OpenAILLMContextFrame(Frame):
OpenAI API. The context in this message is also mutable, and will be OpenAI API. The context in this message is also mutable, and will be
changed by the OpenAIContextAggregator frame processor.""" changed by the OpenAIContextAggregator frame processor."""
context: OpenAILLMContext context: OpenAILLMContext
participantId: Optional[str] = None
@dataclass() @dataclass()
@@ -189,6 +198,7 @@ class LLMFunctionStartFrame(Frame):
start preparing to make a function call, if it can do so in the absence of start preparing to make a function call, if it can do so in the absence of
any arguments.""" any arguments."""
function_name: str function_name: str
participantId: Optional[str] = None
@dataclass() @dataclass()
@@ -196,3 +206,4 @@ class LLMFunctionCallFrame(Frame):
"""Emitted when the LLM has received an entire function call completion.""" """Emitted when the LLM has received an entire function call completion."""
function_name: str function_name: str
arguments: str arguments: str
participantId: Optional[str] = None

View File

@@ -83,8 +83,8 @@ class BaseOpenAILLMService(LLMService):
function_name = "" function_name = ""
arguments = "" arguments = ""
print(f"%%% I'm yielding a start frame, and in here, frame is {frame}")
yield LLMResponseStartFrame() yield LLMResponseStartFrame(participantId=frame.participantId)
chunk_stream: AsyncStream[ChatCompletionChunk] = ( chunk_stream: AsyncStream[ChatCompletionChunk] = (
await self._stream_chat_completions(context) await self._stream_chat_completions(context)
) )
@@ -107,18 +107,19 @@ class BaseOpenAILLMService(LLMService):
tool_call = chunk.choices[0].delta.tool_calls[0] tool_call = chunk.choices[0].delta.tool_calls[0]
if tool_call.function and tool_call.function.name: if tool_call.function and tool_call.function.name:
function_name += tool_call.function.name function_name += tool_call.function.name
yield LLMFunctionStartFrame(function_name=tool_call.function.name) yield LLMFunctionStartFrame(function_name=tool_call.function.name, participantId=frame.participantId)
if tool_call.function and tool_call.function.arguments: if tool_call.function and tool_call.function.arguments:
# Keep iterating through the response to collect all the argument fragments and # Keep iterating through the response to collect all the argument fragments and
# yield a complete LLMFunctionCallFrame after run_llm_async # yield a complete LLMFunctionCallFrame after run_llm_async
# completes # completes
arguments += tool_call.function.arguments arguments += tool_call.function.arguments
elif chunk.choices[0].delta.content: elif chunk.choices[0].delta.content:
yield TextFrame(chunk.choices[0].delta.content) print(f"%%% yielding text frame for {frame.participantId}")
yield TextFrame(chunk.choices[0].delta.content, participantId=frame.participantId)
# if we got a function name and arguments, yield the frame with all the info so # if we got a function name and arguments, yield the frame with all the info so
# frame consumers can take action based on the function call. # frame consumers can take action based on the function call.
if function_name and arguments: if function_name and arguments:
yield LLMFunctionCallFrame(function_name=function_name, arguments=arguments) yield LLMFunctionCallFrame(function_name=function_name, arguments=arguments, participantId=frame.participantId)
print(f"%%% yielding llm response end frame for {frame.participantId}")
yield LLMResponseEndFrame() yield LLMResponseEndFrame(participantId=frame.participantId)

View File

@@ -132,6 +132,7 @@ class DailyTransport(ThreadedTransport, EventHandler):
self.mic.write_frames(frame) self.mic.write_frames(frame)
def send_app_message(self, message: Any, participantId: str | None): def send_app_message(self, message: Any, participantId: str | None):
print(f"### about to try to send {message} to {participantId}")
self.client.send_app_message(message, participantId) self.client.send_app_message(message, participantId)
def read_audio_frames(self, desired_frame_count): def read_audio_frames(self, desired_frame_count):
@@ -270,7 +271,7 @@ class DailyTransport(ThreadedTransport, EventHandler):
participantId = message["session_id"] participantId = message["session_id"]
if self._my_participant_id and participantId != self._my_participant_id: if self._my_participant_id and participantId != self._my_participant_id:
frame = TranscriptionFrame( frame = TranscriptionFrame(
message["text"], participantId, message["timestamp"]) message["text"], participantId=participantId, timestamp=message["timestamp"])
asyncio.run_coroutine_threadsafe( asyncio.run_coroutine_threadsafe(
self.receive_queue.put(frame), self._loop) self.receive_queue.put(frame), self._loop)