Compare commits
9 Commits
hush/AGENT
...
async-reba
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9cd7c82e77 | ||
|
|
43161c816e | ||
|
|
6644c06af1 | ||
|
|
ed47212e07 | ||
|
|
db9cb74364 | ||
|
|
f64902eb25 | ||
|
|
e115a274d6 | ||
|
|
00239c2fd4 | ||
|
|
c0f9ad19fe |
22
CHANGELOG.md
22
CHANGELOG.md
@@ -48,15 +48,10 @@ async def on_connected(processor):
|
||||
frames. To achieve that, each frame processor should only output frames from a
|
||||
single task.
|
||||
|
||||
In this version we introduce synchronous and asynchronous frame
|
||||
processors. The synchronous processors push output frames from the same task
|
||||
that they receive input frames, and therefore only pushing frames from one
|
||||
task. Asynchronous frame processors can have internal tasks to perform things
|
||||
asynchronously (e.g. receiving data from a websocket) but they also have a
|
||||
single task where they push frames from.
|
||||
|
||||
By default, frame processors are synchronous. To change a frame processor to
|
||||
asynchronous you only need to pass `sync=False` to the base class constructor.
|
||||
In this version all the frame processors have their own task to push
|
||||
frames. That is, when `push_frame()` is called the given frame will be put
|
||||
into an internal queue (with the exception of system frames) and a frame
|
||||
processor task will push it out.
|
||||
|
||||
- Added pipeline clocks. A pipeline clock is used by the output transport to
|
||||
know when a frame needs to be presented. For that, all frames now have an
|
||||
@@ -68,9 +63,7 @@ async def on_connected(processor):
|
||||
`SystemClock`). This clock will be passed to each frame processor via the
|
||||
`StartFrame`.
|
||||
|
||||
- Added `CartesiaHttpTTSService`. This is a synchronous frame processor
|
||||
(i.e. given an input text frame it will wait for the whole output before
|
||||
returning).
|
||||
- Added `CartesiaHttpTTSService`.
|
||||
|
||||
- `DailyTransport` now supports setting the audio bitrate to improve audio
|
||||
quality through the `DailyParams.audio_out_bitrate` parameter. The new
|
||||
@@ -110,8 +103,9 @@ async def on_connected(processor):
|
||||
pipelines to be executed concurrently. The difference between a
|
||||
`SyncParallelPipeline` and a `ParallelPipeline` is that, given an input frame,
|
||||
the `SyncParallelPipeline` will wait for all the internal pipelines to
|
||||
complete. This is achieved by ensuring all the processors in each of the
|
||||
internal pipelines are synchronous.
|
||||
complete. This is achieved by making sure the last processor in each of the
|
||||
pipelines is synchronous (e.g. an HTTP-based service that waits for the
|
||||
response).
|
||||
|
||||
- `StartFrame` is back a system frame so we make sure it's processed immediately
|
||||
by all processors. `EndFrame` stays a control frame since it needs to be
|
||||
|
||||
@@ -86,13 +86,13 @@ async def main():
|
||||
),
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
|
||||
|
||||
tts = CartesiaHttpTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
|
||||
|
||||
imagegen = FalImageGenService(
|
||||
params=FalImageGenService.InputParams(image_size="square_hd"),
|
||||
aiohttp_session=session,
|
||||
@@ -107,8 +107,10 @@ async def main():
|
||||
# that, each pipeline runs concurrently and `SyncParallelPipeline` will
|
||||
# wait for the input frame to be processed.
|
||||
#
|
||||
# Note that `SyncParallelPipeline` requires all processors in it to be
|
||||
# synchronous (which is the default for most processors).
|
||||
# Note that `SyncParallelPipeline` requires the last processor in each
|
||||
# of the pipelines to be synchronous. In this case, we use
|
||||
# `CartesiaHttpTTSService` and `FalImageGenService` which make HTTP
|
||||
# requests and wait for the response.
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
llm, # LLM
|
||||
|
||||
@@ -121,8 +121,10 @@ async def main():
|
||||
# `SyncParallelPipeline` will wait for the input frame to be
|
||||
# processed.
|
||||
#
|
||||
# Note that `SyncParallelPipeline` requires all processors in it to
|
||||
# be synchronous (which is the default for most processors).
|
||||
# Note that `SyncParallelPipeline` requires the last processor in
|
||||
# each of the pipelines to be synchronous. In this case, we use
|
||||
# `CartesiaHttpTTSService` and `FalImageGenService` which make HTTP
|
||||
# requests and wait for the response.
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
llm, # LLM
|
||||
|
||||
@@ -5,10 +5,15 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
from runner import configure
|
||||
|
||||
from pipecat.frames.frames import TextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
@@ -19,14 +24,6 @@ from pipecat.services.openai import OpenAILLMContext, OpenAILLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
@@ -34,7 +31,12 @@ logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def start_fetch_weather(function_name, llm, context):
|
||||
await llm.push_frame(TextFrame("Let me check on that."))
|
||||
# note: we can't push a frame to the LLM here. the bot
|
||||
# can interrupt itself and/or cause audio overlapping glitches.
|
||||
# possible question for Aleix and Chad about what the right way
|
||||
# to trigger speech is, now, with the new queues/async/sync refactors.
|
||||
await llm.push_frame(TextFrame("Let me check on that. "))
|
||||
logger.debug(f"Starting fetch_weather_from_api with function_name: {function_name}")
|
||||
|
||||
|
||||
async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback):
|
||||
@@ -106,11 +108,11 @@ async def main():
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
fl_in,
|
||||
# fl_in,
|
||||
transport.input(),
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
fl_out,
|
||||
# fl_out,
|
||||
tts,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
|
||||
@@ -585,6 +585,7 @@ class FunctionCallResultFrame(DataFrame):
|
||||
tool_call_id: str
|
||||
arguments: str
|
||||
result: Any
|
||||
run_llm: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -9,14 +9,20 @@ import asyncio
|
||||
from itertools import chain
|
||||
from typing import List
|
||||
|
||||
from pipecat.frames.frames import ControlFrame, Frame, SystemFrame
|
||||
from pipecat.pipeline.base_pipeline import BasePipeline
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.frames.frames import Frame
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class SyncFrame(ControlFrame):
|
||||
"""This frame is used to know when the internal pipelines have finished."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class Source(FrameProcessor):
|
||||
def __init__(self, upstream_queue: asyncio.Queue):
|
||||
super().__init__()
|
||||
@@ -67,13 +73,16 @@ class SyncParallelPipeline(BasePipeline):
|
||||
raise TypeError(f"SyncParallelPipeline argument {processors} is not a list")
|
||||
|
||||
# We add a source at the beginning of the pipeline and a sink at the end.
|
||||
source = Source(self._up_queue)
|
||||
sink = Sink(self._down_queue)
|
||||
up_queue = asyncio.Queue()
|
||||
down_queue = asyncio.Queue()
|
||||
source = Source(up_queue)
|
||||
sink = Sink(down_queue)
|
||||
processors: List[FrameProcessor] = [source] + processors + [sink]
|
||||
|
||||
# Keep track of sources and sinks.
|
||||
self._sources.append(source)
|
||||
self._sinks.append(sink)
|
||||
# Keep track of sources and sinks. We also keep the output queue of
|
||||
# the source and the sinks so we can use it later.
|
||||
self._sources.append({"processor": source, "queue": down_queue})
|
||||
self._sinks.append({"processor": sink, "queue": up_queue})
|
||||
|
||||
# Create pipeline
|
||||
pipeline = Pipeline(processors)
|
||||
@@ -94,17 +103,46 @@ class SyncParallelPipeline(BasePipeline):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# The last processor of each pipeline needs to be synchronous otherwise
|
||||
# this element won't work. Since, we know it should be synchronous we
|
||||
# push a SyncFrame. Since frames are ordered we know this frame will be
|
||||
# pushed after the synchronous processor has pushed its data allowing us
|
||||
# to synchrnonize all the internal pipelines by waiting for the
|
||||
# SyncFrame in all of them.
|
||||
async def wait_for_sync(
|
||||
obj, main_queue: asyncio.Queue, frame: Frame, direction: FrameDirection
|
||||
):
|
||||
processor = obj["processor"]
|
||||
queue = obj["queue"]
|
||||
await processor.process_frame(frame, direction)
|
||||
|
||||
# If we have a system frame we don't need to synchrnonize anything.
|
||||
if isinstance(frame, SystemFrame):
|
||||
await main_queue.put(frame)
|
||||
else:
|
||||
await processor.process_frame(SyncFrame(), direction)
|
||||
|
||||
frame = await queue.get()
|
||||
while not isinstance(frame, SyncFrame):
|
||||
await main_queue.put(frame)
|
||||
queue.task_done()
|
||||
frame = await queue.get()
|
||||
|
||||
if direction == FrameDirection.UPSTREAM:
|
||||
# If we get an upstream frame we process it in each sink.
|
||||
await asyncio.gather(*[s.process_frame(frame, direction) for s in self._sinks])
|
||||
await asyncio.gather(
|
||||
*[wait_for_sync(s, self._up_queue, frame, direction) for s in self._sinks]
|
||||
)
|
||||
elif direction == FrameDirection.DOWNSTREAM:
|
||||
# If we get a downstream frame we process it in each source.
|
||||
await asyncio.gather(*[s.process_frame(frame, direction) for s in self._sources])
|
||||
await asyncio.gather(
|
||||
*[wait_for_sync(s, self._down_queue, frame, direction) for s in self._sources]
|
||||
)
|
||||
|
||||
seen_ids = set()
|
||||
while not self._up_queue.empty():
|
||||
frame = await self._up_queue.get()
|
||||
if frame and frame.id not in seen_ids:
|
||||
if frame.id not in seen_ids:
|
||||
await self.push_frame(frame, FrameDirection.UPSTREAM)
|
||||
seen_ids.add(frame.id)
|
||||
self._up_queue.task_done()
|
||||
@@ -112,7 +150,7 @@ class SyncParallelPipeline(BasePipeline):
|
||||
seen_ids = set()
|
||||
while not self._down_queue.empty():
|
||||
frame = await self._down_queue.get()
|
||||
if frame and frame.id not in seen_ids:
|
||||
if frame.id not in seen_ids:
|
||||
await self.push_frame(frame, FrameDirection.DOWNSTREAM)
|
||||
seen_ids.add(frame.id)
|
||||
self._down_queue.task_done()
|
||||
|
||||
@@ -69,6 +69,19 @@ class Source(FrameProcessor):
|
||||
await self._up_queue.put(StopTaskFrame())
|
||||
|
||||
|
||||
class Sink(FrameProcessor):
|
||||
def __init__(self, down_queue: asyncio.Queue):
|
||||
super().__init__()
|
||||
self._down_queue = down_queue
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# We really just want to know when the EndFrame reached the sink.
|
||||
if isinstance(frame, EndFrame):
|
||||
await self._down_queue.put(frame)
|
||||
|
||||
|
||||
class PipelineTask:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -84,12 +97,16 @@ class PipelineTask:
|
||||
self._params = params
|
||||
self._finished = False
|
||||
|
||||
self._down_queue = asyncio.Queue()
|
||||
self._up_queue = asyncio.Queue()
|
||||
self._down_queue = asyncio.Queue()
|
||||
self._push_queue = asyncio.Queue()
|
||||
|
||||
self._source = Source(self._up_queue)
|
||||
self._source.link(pipeline)
|
||||
|
||||
self._sink = Sink(self._down_queue)
|
||||
pipeline.link(self._sink)
|
||||
|
||||
def has_finished(self):
|
||||
return self._finished
|
||||
|
||||
@@ -103,19 +120,19 @@ class PipelineTask:
|
||||
# out-of-band from the main streaming task which is what we want since
|
||||
# we want to cancel right away.
|
||||
await self._source.push_frame(CancelFrame())
|
||||
self._process_down_task.cancel()
|
||||
self._process_push_task.cancel()
|
||||
self._process_up_task.cancel()
|
||||
await self._process_down_task
|
||||
await self._process_push_task
|
||||
await self._process_up_task
|
||||
|
||||
async def run(self):
|
||||
self._process_up_task = asyncio.create_task(self._process_up_queue())
|
||||
self._process_down_task = asyncio.create_task(self._process_down_queue())
|
||||
await asyncio.gather(self._process_up_task, self._process_down_task)
|
||||
self._process_push_task = asyncio.create_task(self._process_push_queue())
|
||||
await asyncio.gather(self._process_up_task, self._process_push_task)
|
||||
self._finished = True
|
||||
|
||||
async def queue_frame(self, frame: Frame):
|
||||
await self._down_queue.put(frame)
|
||||
await self._push_queue.put(frame)
|
||||
|
||||
async def queue_frames(self, frames: Iterable[Frame] | AsyncIterable[Frame]):
|
||||
if isinstance(frames, AsyncIterable):
|
||||
@@ -133,7 +150,7 @@ class PipelineTask:
|
||||
data.append(ProcessingMetricsData(processor=p.name, value=0.0))
|
||||
return MetricsFrame(data=data)
|
||||
|
||||
async def _process_down_queue(self):
|
||||
async def _process_push_queue(self):
|
||||
self._clock.start()
|
||||
|
||||
start_frame = StartFrame(
|
||||
@@ -154,11 +171,13 @@ class PipelineTask:
|
||||
should_cleanup = True
|
||||
while running:
|
||||
try:
|
||||
frame = await self._down_queue.get()
|
||||
frame = await self._push_queue.get()
|
||||
await self._source.process_frame(frame, FrameDirection.DOWNSTREAM)
|
||||
if isinstance(frame, EndFrame):
|
||||
await self._wait_for_endframe()
|
||||
running = not (isinstance(frame, StopTaskFrame) or isinstance(frame, EndFrame))
|
||||
should_cleanup = not isinstance(frame, StopTaskFrame)
|
||||
self._down_queue.task_done()
|
||||
self._push_queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
# Cleanup only if we need to.
|
||||
@@ -169,6 +188,12 @@ class PipelineTask:
|
||||
self._process_up_task.cancel()
|
||||
await self._process_up_task
|
||||
|
||||
async def _wait_for_endframe(self):
|
||||
# NOTE(aleix): the Sink element just pushes EndFrames to the down queue,
|
||||
# so just wait for it. In the future we might do something else here,
|
||||
# but for now this is fine.
|
||||
await self._down_queue.get()
|
||||
|
||||
async def _process_up_queue(self):
|
||||
while True:
|
||||
try:
|
||||
|
||||
@@ -133,6 +133,7 @@ class OpenAILLMContext:
|
||||
tool_call_id: str,
|
||||
arguments: str,
|
||||
llm: FrameProcessor,
|
||||
run_llm: bool = True,
|
||||
) -> None:
|
||||
# Push a SystemFrame downstream. This frame will let our assistant context aggregator
|
||||
# know that we are in the middle of a function call. Some contexts/aggregators may
|
||||
@@ -153,6 +154,7 @@ class OpenAILLMContext:
|
||||
tool_call_id=tool_call_id,
|
||||
arguments=arguments,
|
||||
result=result,
|
||||
run_llm=run_llm,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -37,7 +37,6 @@ class FrameProcessor:
|
||||
*,
|
||||
name: str | None = None,
|
||||
metrics: FrameProcessorMetrics | None = None,
|
||||
sync: bool = True,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -47,7 +46,6 @@ class FrameProcessor:
|
||||
self._prev: "FrameProcessor" | None = None
|
||||
self._next: "FrameProcessor" | None = None
|
||||
self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_running_loop()
|
||||
self._sync = sync
|
||||
|
||||
self._event_handlers: dict = {}
|
||||
|
||||
@@ -66,11 +64,8 @@ class FrameProcessor:
|
||||
|
||||
# Every processor in Pipecat should only output frames from a single
|
||||
# task. This avoid problems like audio overlapping. System frames are
|
||||
# the exception to this rule.
|
||||
#
|
||||
# This create this task.
|
||||
if not self._sync:
|
||||
self.__create_push_task()
|
||||
# the exception to this rule. This create this task.
|
||||
self.__create_push_task()
|
||||
|
||||
@property
|
||||
def interruptions_allowed(self):
|
||||
@@ -167,7 +162,7 @@ class FrameProcessor:
|
||||
await self.push_frame(error, FrameDirection.UPSTREAM)
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
if self._sync or isinstance(frame, SystemFrame):
|
||||
if isinstance(frame, SystemFrame):
|
||||
await self.__internal_push_frame(frame, direction)
|
||||
else:
|
||||
await self.__push_queue.put((frame, direction))
|
||||
@@ -194,13 +189,12 @@ class FrameProcessor:
|
||||
#
|
||||
|
||||
async def _start_interruption(self):
|
||||
if not self._sync:
|
||||
# Cancel the task. This will stop pushing frames downstream.
|
||||
self.__push_frame_task.cancel()
|
||||
await self.__push_frame_task
|
||||
# Cancel the task. This will stop pushing frames downstream.
|
||||
self.__push_frame_task.cancel()
|
||||
await self.__push_frame_task
|
||||
|
||||
# Create a new queue and task.
|
||||
self.__create_push_task()
|
||||
# Create a new queue and task.
|
||||
self.__create_push_task()
|
||||
|
||||
async def _stop_interruption(self):
|
||||
# Nothing to do right now.
|
||||
|
||||
@@ -516,7 +516,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
params: RTVIProcessorParams = RTVIProcessorParams(),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sync=False, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
self._config = config
|
||||
self._params = params
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ class GStreamerPipelineSource(FrameProcessor):
|
||||
clock_sync: bool = True
|
||||
|
||||
def __init__(self, *, pipeline: str, out_params: OutputParams = OutputParams(), **kwargs):
|
||||
super().__init__(sync=False, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._out_params = out_params
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ class IdleFrameProcessor(FrameProcessor):
|
||||
types: List[type] = [],
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sync=False, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._callback = callback
|
||||
self._timeout = timeout
|
||||
|
||||
@@ -31,7 +31,7 @@ class UserIdleProcessor(FrameProcessor):
|
||||
timeout: float,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sync=False, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._callback = callback
|
||||
self._timeout = timeout
|
||||
|
||||
@@ -110,7 +110,13 @@ class LLMService(AIService):
|
||||
return function_name in self._callbacks.keys()
|
||||
|
||||
async def call_function(
|
||||
self, *, context: OpenAILLMContext, tool_call_id: str, function_name: str, arguments: str
|
||||
self,
|
||||
*,
|
||||
context: OpenAILLMContext,
|
||||
tool_call_id: str,
|
||||
function_name: str,
|
||||
arguments: str,
|
||||
run_llm: bool,
|
||||
) -> None:
|
||||
f = None
|
||||
if function_name in self._callbacks.keys():
|
||||
@@ -120,7 +126,12 @@ class LLMService(AIService):
|
||||
else:
|
||||
return None
|
||||
await context.call_function(
|
||||
f, function_name=function_name, tool_call_id=tool_call_id, arguments=arguments, llm=self
|
||||
f,
|
||||
function_name=function_name,
|
||||
tool_call_id=tool_call_id,
|
||||
arguments=arguments,
|
||||
llm=self,
|
||||
run_llm=run_llm,
|
||||
)
|
||||
|
||||
# QUESTION FOR CB: maybe this isn't needed anymore?
|
||||
@@ -144,6 +155,10 @@ class TTSService(AIService):
|
||||
# if True, TTSService will push TextFrames and LLMFullResponseEndFrames,
|
||||
# otherwise subclass must do it
|
||||
push_text_frames: bool = True,
|
||||
# if True, TTSService will push TTSStoppedFrames, otherwise subclass must do it
|
||||
push_stop_frames: bool = False,
|
||||
# if push_stop_frames is True, wait for this idle period before pushing TTSStoppedFrame
|
||||
stop_frame_timeout_s: float = 1.0,
|
||||
# TTS output sample rate
|
||||
sample_rate: int = 16000,
|
||||
**kwargs,
|
||||
@@ -151,9 +166,15 @@ class TTSService(AIService):
|
||||
super().__init__(**kwargs)
|
||||
self._aggregate_sentences: bool = aggregate_sentences
|
||||
self._push_text_frames: bool = push_text_frames
|
||||
self._current_sentence: str = ""
|
||||
self._push_stop_frames: bool = push_stop_frames
|
||||
self._stop_frame_timeout_s: float = stop_frame_timeout_s
|
||||
self._sample_rate: int = sample_rate
|
||||
|
||||
self._stop_frame_task: Optional[asyncio.Task] = None
|
||||
self._stop_frame_queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
self._current_sentence: str = ""
|
||||
|
||||
@property
|
||||
def sample_rate(self) -> int:
|
||||
return self._sample_rate
|
||||
@@ -210,13 +231,72 @@ class TTSService(AIService):
|
||||
async def set_role(self, role: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def flush_audio(self):
|
||||
pass
|
||||
|
||||
# Converts the text to audio.
|
||||
@abstractmethod
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
pass
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
if self._push_stop_frames:
|
||||
self._stop_frame_task = self.get_event_loop().create_task(self._stop_frame_handler())
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
if self._stop_frame_task:
|
||||
self._stop_frame_task.cancel()
|
||||
await self._stop_frame_task
|
||||
self._stop_frame_task = None
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
if self._stop_frame_task:
|
||||
self._stop_frame_task.cancel()
|
||||
await self._stop_frame_task
|
||||
self._stop_frame_task = None
|
||||
|
||||
async def say(self, text: str):
|
||||
await self.process_frame(TextFrame(text=text), FrameDirection.DOWNSTREAM)
|
||||
await self.flush_audio()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TextFrame):
|
||||
await self._process_text_frame(frame)
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
await self._handle_interruption(frame, direction)
|
||||
elif isinstance(frame, LLMFullResponseEndFrame) or isinstance(frame, EndFrame):
|
||||
sentence = self._current_sentence
|
||||
self._current_sentence = ""
|
||||
await self._push_tts_frames(sentence)
|
||||
if isinstance(frame, LLMFullResponseEndFrame):
|
||||
if self._push_text_frames:
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, TTSSpeakFrame):
|
||||
await self._push_tts_frames(frame.text)
|
||||
await self.flush_audio()
|
||||
elif isinstance(frame, TTSUpdateSettingsFrame):
|
||||
await self._update_tts_settings(frame)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
await super().push_frame(frame, direction)
|
||||
|
||||
if self._push_stop_frames and (
|
||||
isinstance(frame, StartInterruptionFrame)
|
||||
or isinstance(frame, TTSStartedFrame)
|
||||
or isinstance(frame, TTSAudioRawFrame)
|
||||
or isinstance(frame, TTSStoppedFrame)
|
||||
):
|
||||
await self._stop_frame_queue.put(frame)
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
self._current_sentence = ""
|
||||
@@ -276,88 +356,6 @@ class TTSService(AIService):
|
||||
if frame.role is not None:
|
||||
await self.set_role(frame.role)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TextFrame):
|
||||
await self._process_text_frame(frame)
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
await self._handle_interruption(frame, direction)
|
||||
elif isinstance(frame, LLMFullResponseEndFrame) or isinstance(frame, EndFrame):
|
||||
sentence = self._current_sentence
|
||||
self._current_sentence = ""
|
||||
await self._push_tts_frames(sentence)
|
||||
if isinstance(frame, LLMFullResponseEndFrame):
|
||||
if self._push_text_frames:
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, TTSSpeakFrame):
|
||||
await self._push_tts_frames(frame.text)
|
||||
elif isinstance(frame, TTSUpdateSettingsFrame):
|
||||
await self._update_tts_settings(frame)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
class AsyncTTSService(TTSService):
|
||||
def __init__(
|
||||
self,
|
||||
# if True, TTSService will push TTSStoppedFrames, otherwise subclass must do it
|
||||
push_stop_frames: bool = False,
|
||||
# if push_stop_frames is True, wait for this idle period before pushing TTSStoppedFrame
|
||||
stop_frame_timeout_s: float = 1.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sync=False, **kwargs)
|
||||
self._push_stop_frames: bool = push_stop_frames
|
||||
self._stop_frame_timeout_s: float = stop_frame_timeout_s
|
||||
self._stop_frame_task: Optional[asyncio.Task] = None
|
||||
self._stop_frame_queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
@abstractmethod
|
||||
async def flush_audio(self):
|
||||
pass
|
||||
|
||||
async def say(self, text: str):
|
||||
await super().say(text)
|
||||
await self.flush_audio()
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
if self._push_stop_frames:
|
||||
self._stop_frame_task = self.get_event_loop().create_task(self._stop_frame_handler())
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
if self._stop_frame_task:
|
||||
self._stop_frame_task.cancel()
|
||||
await self._stop_frame_task
|
||||
self._stop_frame_task = None
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
if self._stop_frame_task:
|
||||
self._stop_frame_task.cancel()
|
||||
await self._stop_frame_task
|
||||
self._stop_frame_task = None
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
if isinstance(frame, TTSSpeakFrame):
|
||||
await self.flush_audio()
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
await super().push_frame(frame, direction)
|
||||
|
||||
if self._push_stop_frames and (
|
||||
isinstance(frame, StartInterruptionFrame)
|
||||
or isinstance(frame, TTSStartedFrame)
|
||||
or isinstance(frame, TTSAudioRawFrame)
|
||||
or isinstance(frame, TTSStoppedFrame)
|
||||
):
|
||||
await self._stop_frame_queue.put(frame)
|
||||
|
||||
async def _stop_frame_handler(self):
|
||||
try:
|
||||
has_started = False
|
||||
@@ -378,7 +376,7 @@ class AsyncTTSService(TTSService):
|
||||
pass
|
||||
|
||||
|
||||
class AsyncWordTTSService(AsyncTTSService):
|
||||
class WordTTSService(TTSService):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._initial_word_timestamp = -1
|
||||
|
||||
@@ -26,7 +26,7 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.services.ai_services import AsyncWordTTSService, TTSService
|
||||
from pipecat.services.ai_services import WordTTSService, TTSService
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -61,7 +61,7 @@ def language_to_cartesia_language(language: Language) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
class CartesiaTTSService(AsyncWordTTSService):
|
||||
class CartesiaTTSService(WordTTSService):
|
||||
class InputParams(BaseModel):
|
||||
encoding: Optional[str] = "pcm_s16le"
|
||||
sample_rate: Optional[int] = 16000
|
||||
|
||||
@@ -23,7 +23,7 @@ from pipecat.frames.frames import (
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import AsyncWordTTSService
|
||||
from pipecat.services.ai_services import WordTTSService
|
||||
|
||||
# See .env.example for ElevenLabs configuration needed
|
||||
try:
|
||||
@@ -70,7 +70,7 @@ def calculate_word_times(
|
||||
return word_times
|
||||
|
||||
|
||||
class ElevenLabsTTSService(AsyncWordTTSService):
|
||||
class ElevenLabsTTSService(WordTTSService):
|
||||
class InputParams(BaseModel):
|
||||
language: Optional[str] = None
|
||||
output_format: Literal["pcm_16000", "pcm_22050", "pcm_24000", "pcm_44100"] = "pcm_16000"
|
||||
|
||||
@@ -51,7 +51,7 @@ class GladiaSTTService(STTService):
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sync=False, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._api_key = api_key
|
||||
self._url = url
|
||||
|
||||
@@ -20,7 +20,7 @@ from pipecat.frames.frames import (
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import AsyncTTSService
|
||||
from pipecat.services.ai_services import TTSService
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -35,7 +35,7 @@ except ModuleNotFoundError as e:
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class LmntTTSService(AsyncTTSService):
|
||||
class LmntTTSService(TTSService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -47,7 +47,7 @@ class LmntTTSService(AsyncTTSService):
|
||||
):
|
||||
# Let TTSService produce TTSStoppedFrames after a short delay of
|
||||
# no activity.
|
||||
super().__init__(sync=False, push_stop_frames=True, sample_rate=sample_rate, **kwargs)
|
||||
super().__init__(push_stop_frames=True, sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._api_key = api_key
|
||||
self._voice_id = voice_id
|
||||
|
||||
@@ -205,6 +205,10 @@ class BaseOpenAILLMService(LLMService):
|
||||
return chunks
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext):
|
||||
functions_list = []
|
||||
arguments_list = []
|
||||
tool_id_list = []
|
||||
func_idx = 0
|
||||
function_name = ""
|
||||
arguments = ""
|
||||
tool_call_id = ""
|
||||
@@ -242,6 +246,14 @@ class BaseOpenAILLMService(LLMService):
|
||||
# yield a frame containing the function name and the arguments.
|
||||
|
||||
tool_call = chunk.choices[0].delta.tool_calls[0]
|
||||
if tool_call.index != func_idx:
|
||||
functions_list.append(function_name)
|
||||
arguments_list.append(arguments)
|
||||
tool_id_list.append(tool_call_id)
|
||||
function_name = ""
|
||||
arguments = ""
|
||||
tool_call_id = ""
|
||||
func_idx += 1
|
||||
if tool_call.function and tool_call.function.name:
|
||||
function_name += tool_call.function.name
|
||||
tool_call_id = tool_call.id
|
||||
@@ -257,21 +269,29 @@ class BaseOpenAILLMService(LLMService):
|
||||
# the context, and re-prompt to get a chat answer. If we don't have a registered
|
||||
# handler, raise an exception.
|
||||
if function_name and arguments:
|
||||
if self.has_function(function_name):
|
||||
await self._handle_function_call(context, tool_call_id, function_name, arguments)
|
||||
else:
|
||||
raise OpenAIUnhandledFunctionException(
|
||||
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
|
||||
)
|
||||
# added to the list as last function name and arguments not added to the list
|
||||
functions_list.append(function_name)
|
||||
arguments_list.append(arguments)
|
||||
tool_id_list.append(tool_call_id)
|
||||
|
||||
async def _handle_function_call(self, context, tool_call_id, function_name, arguments):
|
||||
arguments = json.loads(arguments)
|
||||
await self.call_function(
|
||||
context=context,
|
||||
tool_call_id=tool_call_id,
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
)
|
||||
total_items = len(functions_list)
|
||||
for index, (function_name, arguments, tool_id) in enumerate(
|
||||
zip(functions_list, arguments_list, tool_id_list), start=1
|
||||
):
|
||||
if self.has_function(function_name):
|
||||
run_llm = index == total_items
|
||||
arguments = json.loads(arguments)
|
||||
await self.call_function(
|
||||
context=context,
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
tool_call_id=tool_id,
|
||||
run_llm=run_llm,
|
||||
)
|
||||
else:
|
||||
raise OpenAIUnhandledFunctionException(
|
||||
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
|
||||
)
|
||||
|
||||
async def _update_settings(self, frame: LLMUpdateSettingsFrame):
|
||||
if frame.model is not None:
|
||||
@@ -461,31 +481,27 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
def __init__(self, user_context_aggregator: OpenAIUserContextAggregator):
|
||||
super().__init__(context=user_context_aggregator._context)
|
||||
self._user_context_aggregator = user_context_aggregator
|
||||
self._function_call_in_progress = None
|
||||
self._function_calls_in_progress = {}
|
||||
self._function_call_result = None
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
await super().process_frame(frame, direction)
|
||||
# See note above about not calling push_frame() here.
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
self._function_call_in_progress = None
|
||||
self._function_calls_in_progress.clear()
|
||||
self._function_call_finished = None
|
||||
elif isinstance(frame, FunctionCallInProgressFrame):
|
||||
self._function_call_in_progress = frame
|
||||
self._function_calls_in_progress[frame.tool_call_id] = frame
|
||||
elif isinstance(frame, FunctionCallResultFrame):
|
||||
if (
|
||||
self._function_call_in_progress
|
||||
and self._function_call_in_progress.tool_call_id == frame.tool_call_id
|
||||
):
|
||||
self._function_call_in_progress = None
|
||||
if frame.tool_call_id in self._function_calls_in_progress:
|
||||
del self._function_calls_in_progress[frame.tool_call_id]
|
||||
self._function_call_result = frame
|
||||
# TODO-CB: Kwin wants us to refactor this out of here but I REFUSE
|
||||
await self._push_aggregation()
|
||||
else:
|
||||
logger.warning(
|
||||
"FunctionCallResultFrame tool_call_id does not match FunctionCallInProgressFrame tool_call_id"
|
||||
"FunctionCallResultFrame tool_call_id does not match any function call in progress"
|
||||
)
|
||||
self._function_call_in_progress = None
|
||||
self._function_call_result = None
|
||||
|
||||
async def _push_aggregation(self):
|
||||
@@ -524,7 +540,7 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
"tool_call_id": frame.tool_call_id,
|
||||
}
|
||||
)
|
||||
run_llm = True
|
||||
run_llm = frame.run_llm
|
||||
else:
|
||||
self._context.add_message({"role": "assistant", "content": aggregation})
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ from loguru import logger
|
||||
|
||||
class BaseInputTransport(FrameProcessor):
|
||||
def __init__(self, params: TransportParams, **kwargs):
|
||||
super().__init__(sync=False, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._params = params
|
||||
|
||||
|
||||
@@ -43,7 +43,7 @@ from pipecat.utils.time import nanoseconds_to_seconds
|
||||
|
||||
class BaseOutputTransport(FrameProcessor):
|
||||
def __init__(self, params: TransportParams, **kwargs):
|
||||
super().__init__(sync=False, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._params = params
|
||||
|
||||
|
||||
@@ -7,9 +7,9 @@
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
StopTaskFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
@@ -32,6 +32,7 @@ from langchain_core.language_models import FakeStreamingListLLM
|
||||
class TestLangchain(unittest.IsolatedAsyncioTestCase):
|
||||
class MockProcessor(FrameProcessor):
|
||||
def __init__(self, name):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
self.token: list[str] = []
|
||||
# Start collecting tokens when we see the start frame
|
||||
@@ -55,13 +56,13 @@ class TestLangchain(unittest.IsolatedAsyncioTestCase):
|
||||
def setUp(self):
|
||||
self.expected_response = "Hello dear human"
|
||||
self.fake_llm = FakeStreamingListLLM(responses=[self.expected_response])
|
||||
self.mock_proc = self.MockProcessor("token_collector")
|
||||
|
||||
async def test_langchain(self):
|
||||
messages = [("system", "Say hello to {name}"), ("human", "{input}")]
|
||||
prompt = ChatPromptTemplate.from_messages(messages).partial(name="Thomas")
|
||||
chain = prompt | self.fake_llm
|
||||
proc = LangchainProcessor(chain=chain)
|
||||
self.mock_proc = self.MockProcessor("token_collector")
|
||||
|
||||
tma_in = LLMUserResponseAggregator(messages)
|
||||
tma_out = LLMAssistantResponseAggregator(messages)
|
||||
@@ -81,7 +82,7 @@ class TestLangchain(unittest.IsolatedAsyncioTestCase):
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(text="Hi World", user_id="user", timestamp="now"),
|
||||
UserStoppedSpeakingFrame(),
|
||||
StopTaskFrame(),
|
||||
EndFrame(),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user