Compare commits

...

1 Commits

Author SHA1 Message Date
Chad Bailey
f1b093d9f1 translation use case working 2024-04-02 15:12:47 +00:00
5 changed files with 72 additions and 33 deletions

View File

@@ -52,7 +52,7 @@ class TranslationProcessor(FrameProcessor):
},
{"role": "user", "content": frame.text},
]
yield LLMMessagesFrame(context)
yield LLMMessagesFrame(context, participantId=frame.participantId)
else:
yield frame
@@ -62,11 +62,15 @@ class TranslationSubtitles(FrameProcessor):
self._language = language
async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
print(f"!!! got a frame: {frame}")
if isinstance(frame, TextFrame):
print(f"!!! got a textframe: {frame}")
app_message = {
"language": self._language,
"text": frame.text
"text": frame.text,
"speakerId": frame.participantId
}
print(f"!!! App message contents: {app_message}")
yield SendAppMessageFrame(app_message, None)
yield frame
else:

View File

@@ -171,7 +171,8 @@ class LLMAssistantContextAggregator(LLMContextAggregator):
class SentenceAggregator(FrameProcessor):
"""This frame processor aggregates text frames into complete sentences.
"""This frame processor aggregates text frames into complete sentences. It separates
frames by participant if the participant_id field is defined in the TextFrame.
Frame input/output:
TextFrame("Hello,") -> None
@@ -183,26 +184,33 @@ class SentenceAggregator(FrameProcessor):
... print(frame.text)
>>> aggregator = SentenceAggregator()
>>> asyncio.run(print_frames(aggregator, TextFrame("Hello,")))
>>> asyncio.run(print_frames(aggregator, TextFrame(" world.")))
Hello, world.
>>> asyncio.run(print_frames(aggregator, TextFrame("Hello,", participantId="abcd")))
>>> asyncio.run(print_frames(aggregator, TextFrame(" world.", participantId="cdef")))
world.
>>> asyncio.run(print_frames(aggregator, TextFrame(" everyone.", participantId="abcd")))
Hello, everyone.
"""
def __init__(self):
self.aggregation = ""
self.aggregations = {}
async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
if isinstance(frame, TextFrame):
pax_id = "none"
if frame.participantId:
pax_id = frame.participantId
if pax_id not in self.aggregations:
self.aggregations[pax_id] = ""
m = re.search("(.*[?.!])(.*)", frame.text)
if m:
yield TextFrame(self.aggregation + m.group(1))
self.aggregation = m.group(2)
yield TextFrame(self.aggregations[pax_id] + m.group(1), participantId=pax_id)
self.aggregations[pax_id] = m.group(2)
else:
self.aggregation += frame.text
self.aggregations[pax_id] += frame.text
elif isinstance(frame, EndFrame):
if self.aggregation:
yield TextFrame(self.aggregation)
yield frame
for key in self.aggregations:
if self.aggregations[key]:
yield TextFrame(self.aggregation, participantId=key)
else:
yield frame
@@ -245,15 +253,24 @@ class LLMFullResponseAggregator(FrameProcessor):
"""
def __init__(self):
self.aggregation = ""
self.aggregations = {}
async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
if isinstance(frame, TextFrame):
self.aggregation += frame.text
pax_id = "none"
if frame.participantId:
pax_id = frame.participantId
if pax_id not in self.aggregations:
self.aggregations[pax_id] = ""
self.aggregations[pax_id] += frame.text
elif isinstance(frame, LLMResponseEndFrame):
yield TextFrame(self.aggregation)
pax_id = "none"
if frame.participantId:
pax_id = frame.participantId
if self.aggregations[pax_id]:
yield TextFrame(self.aggregations[pax_id], participantId=pax_id)
yield frame
self.aggregation = ""
self.aggregations[pax_id] = ""
else:
yield frame
@@ -390,3 +407,8 @@ class GatedAggregator(FrameProcessor):
self.accumulator = []
else:
self.accumulator.append(frame)
if __name__ == "__main__":
import doctest
doctest.testmod()

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Any, List
from typing import Any, List, Optional
from dailyai.services.openai_llm_context import OpenAILLMContext
import dailyai.pipeline.protobufs.frames_pb2 as frame_protos
@@ -48,15 +48,17 @@ class PipelineStartedFrame(ControlFrame):
pass
class LLMResponseStartFrame(ControlFrame):
@dataclass
class LLMResponseStartFrame(Frame):
"""Used to indicate the beginning of an LLM response. Following TextFrames
are part of the LLM response until an LLMResponseEndFrame"""
pass
participantId: Optional[str] = None
class LLMResponseEndFrame(ControlFrame):
@dataclass
class LLMResponseEndFrame(Frame):
"""Indicates the end of an LLM response."""
pass
participantId: Optional[str] = None
@dataclass()
@@ -64,6 +66,7 @@ class AudioFrame(Frame):
"""A chunk of audio. Will be played by the transport if the transport's mic
has been enabled."""
data: bytes
participantId: Optional[str] = None
def __str__(self):
return f"{self.__class__.__name__}, size: {len(self.data)} B"
@@ -75,6 +78,7 @@ class ImageFrame(Frame):
enabled."""
url: str | None
image: bytes
participantId: Optional[str] = None
def __str__(self):
return f"{self.__class__.__name__}, url: {self.url}, image size: {len(self.image)} B"
@@ -96,16 +100,16 @@ class TextFrame(Frame):
"""A chunk of text. Emitted by LLM services, consumed by TTS services, can
be used to send text through pipelines."""
text: str
participantId: Optional[str] = None
def __str__(self):
return f'{self.__class__.__name__}: "{self.text}"'
return f'{self.__class__.__name__}: "{self.text}", participantId: {self.participantId}'
@dataclass()
@dataclass(kw_only=True)
class TranscriptionFrame(TextFrame):
"""A text frame with transcription-specific data. Will be placed in the
transport's receive queue when a participant speaks."""
participantId: str
timestamp: str
def __str__(self):
@@ -133,6 +137,10 @@ class LLMMessagesFrame(Frame):
Note that the messages property on this class is mutable, and will be
be updated by various ResponseAggregator frame processors."""
messages: List[dict]
participantId: Optional[str] = None
def __str__(self):
return f"{self.__class__.__name__}, participantId: {self.participantId}"
@dataclass()
@@ -141,6 +149,7 @@ class OpenAILLMContextFrame(Frame):
OpenAI API. The context in this message is also mutable, and will be
changed by the OpenAIContextAggregator frame processor."""
context: OpenAILLMContext
participantId: Optional[str] = None
@dataclass()
@@ -189,6 +198,7 @@ class LLMFunctionStartFrame(Frame):
start preparing to make a function call, if it can do so in the absence of
any arguments."""
function_name: str
participantId: Optional[str] = None
@dataclass()
@@ -196,3 +206,4 @@ class LLMFunctionCallFrame(Frame):
"""Emitted when the LLM has received an entire function call completion."""
function_name: str
arguments: str
participantId: Optional[str] = None

View File

@@ -83,8 +83,8 @@ class BaseOpenAILLMService(LLMService):
function_name = ""
arguments = ""
yield LLMResponseStartFrame()
print(f"%%% I'm yielding a start frame, and in here, frame is {frame}")
yield LLMResponseStartFrame(participantId=frame.participantId)
chunk_stream: AsyncStream[ChatCompletionChunk] = (
await self._stream_chat_completions(context)
)
@@ -107,18 +107,19 @@ class BaseOpenAILLMService(LLMService):
tool_call = chunk.choices[0].delta.tool_calls[0]
if tool_call.function and tool_call.function.name:
function_name += tool_call.function.name
yield LLMFunctionStartFrame(function_name=tool_call.function.name)
yield LLMFunctionStartFrame(function_name=tool_call.function.name, participantId=frame.participantId)
if tool_call.function and tool_call.function.arguments:
# Keep iterating through the response to collect all the argument fragments and
# yield a complete LLMFunctionCallFrame after run_llm_async
# completes
arguments += tool_call.function.arguments
elif chunk.choices[0].delta.content:
yield TextFrame(chunk.choices[0].delta.content)
print(f"%%% yielding text frame for {frame.participantId}")
yield TextFrame(chunk.choices[0].delta.content, participantId=frame.participantId)
# if we got a function name and arguments, yield the frame with all the info so
# frame consumers can take action based on the function call.
if function_name and arguments:
yield LLMFunctionCallFrame(function_name=function_name, arguments=arguments)
yield LLMResponseEndFrame()
yield LLMFunctionCallFrame(function_name=function_name, arguments=arguments, participantId=frame.participantId)
print(f"%%% yielding llm response end frame for {frame.participantId}")
yield LLMResponseEndFrame(participantId=frame.participantId)

View File

@@ -132,6 +132,7 @@ class DailyTransport(ThreadedTransport, EventHandler):
self.mic.write_frames(frame)
def send_app_message(self, message: Any, participantId: str | None):
print(f"### about to try to send {message} to {participantId}")
self.client.send_app_message(message, participantId)
def read_audio_frames(self, desired_frame_count):
@@ -270,7 +271,7 @@ class DailyTransport(ThreadedTransport, EventHandler):
participantId = message["session_id"]
if self._my_participant_id and participantId != self._my_participant_id:
frame = TranscriptionFrame(
message["text"], participantId, message["timestamp"])
message["text"], participantId=participantId, timestamp=message["timestamp"])
asyncio.run_coroutine_threadsafe(
self.receive_queue.put(frame), self._loop)