Compare commits
4 Commits
aleix/queu
...
hush/openA
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e09028aca2 | ||
|
|
b17165b7ea | ||
|
|
19a4b97504 | ||
|
|
fda762d8e8 |
274
examples/foundational/99-open-ai-agent.py
Normal file
274
examples/foundational/99-open-ai-agent.py
Normal file
@@ -0,0 +1,274 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2024–2025, Daily
|
||||||
|
#
|
||||||
|
# SPDX-License-Identifier: BSD 2-Clause License
|
||||||
|
#
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import field
|
||||||
|
from typing import List, Literal, Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from agents import Agent, Runner
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from loguru import logger
|
||||||
|
from openai import AsyncStream, BaseModel
|
||||||
|
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
|
||||||
|
|
||||||
|
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||||
|
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||||
|
from pipecat.frames.frames import (
|
||||||
|
Frame,
|
||||||
|
LLMFullResponseEndFrame,
|
||||||
|
LLMFullResponseStartFrame,
|
||||||
|
LLMMessagesFrame,
|
||||||
|
LLMTextFrame,
|
||||||
|
LLMUpdateSettingsFrame,
|
||||||
|
VisionImageRawFrame,
|
||||||
|
)
|
||||||
|
from pipecat.pipeline.pipeline import Pipeline
|
||||||
|
from pipecat.pipeline.runner import PipelineRunner
|
||||||
|
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||||
|
from pipecat.processors.aggregators.llm_response import (
|
||||||
|
LLMAssistantAggregatorParams,
|
||||||
|
LLMUserAggregatorParams,
|
||||||
|
)
|
||||||
|
from pipecat.processors.aggregators.openai_llm_context import (
|
||||||
|
OpenAILLMContext,
|
||||||
|
OpenAILLMContextFrame,
|
||||||
|
)
|
||||||
|
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||||
|
from pipecat.services.ai_service import AIService
|
||||||
|
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||||
|
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||||
|
from pipecat.services.openai.base_llm import BaseOpenAILLMService
|
||||||
|
from pipecat.services.openai.llm import (
|
||||||
|
OpenAIAssistantContextAggregator,
|
||||||
|
OpenAIContextAggregatorPair,
|
||||||
|
OpenAILLMService,
|
||||||
|
OpenAIUserContextAggregator,
|
||||||
|
)
|
||||||
|
from pipecat.transports.base_transport import TransportParams
|
||||||
|
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
|
||||||
|
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||||
|
|
||||||
|
load_dotenv(override=True)
|
||||||
|
|
||||||
|
|
||||||
|
class LlmMessage(BaseModel):
|
||||||
|
# ...
|
||||||
|
role: Literal["system", "user", "assistant", "tool"]
|
||||||
|
content: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class AgentResponse(BaseModel):
|
||||||
|
content: str
|
||||||
|
msgs: list[LlmMessage] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class BackendBase(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
async def get_resp(self, messages: list[LlmMessage], extra_params) -> AgentResponse:
|
||||||
|
raise NotImplementedError("The method get_resp is not implemented.")
|
||||||
|
|
||||||
|
|
||||||
|
class ChoiceDelta(BaseModel):
|
||||||
|
content: Optional[str] = None
|
||||||
|
"""The contents of the chunk message."""
|
||||||
|
|
||||||
|
|
||||||
|
class Choice(BaseModel):
|
||||||
|
delta: ChoiceDelta
|
||||||
|
"""The contents of the chunk message."""
|
||||||
|
|
||||||
|
index: int
|
||||||
|
"""The index of the choice in the list of choices."""
|
||||||
|
|
||||||
|
|
||||||
|
class CustomLLMService(BaseOpenAILLMService):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self._client = Agent(
|
||||||
|
name="Assistant agent",
|
||||||
|
instructions="Respond with haikus.",
|
||||||
|
# tools=[get_weather],
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_client(
|
||||||
|
self,
|
||||||
|
api_key=None,
|
||||||
|
base_url=None,
|
||||||
|
organization=None,
|
||||||
|
project=None,
|
||||||
|
default_headers=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
return Agent(
|
||||||
|
name="Assistant agent",
|
||||||
|
instructions="Respond with haikus.",
|
||||||
|
# tools=[get_weather],
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_context_aggregator(
|
||||||
|
self,
|
||||||
|
context: OpenAILLMContext,
|
||||||
|
*,
|
||||||
|
user_params: LLMUserAggregatorParams = LLMUserAggregatorParams(),
|
||||||
|
assistant_params: LLMAssistantAggregatorParams = LLMAssistantAggregatorParams(),
|
||||||
|
) -> OpenAIContextAggregatorPair:
|
||||||
|
"""Create an instance of OpenAIContextAggregatorPair.
|
||||||
|
|
||||||
|
from an
|
||||||
|
OpenAILLMContext. Constructor keyword arguments for both the user and
|
||||||
|
assistant aggregators can be provided.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
context (OpenAILLMContext): The LLM context.
|
||||||
|
user_params (LLMUserAggregatorParams, optional): User aggregator parameters.
|
||||||
|
assistant_params (LLMAssistantAggregatorParams, optional): User aggregator parameters.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
OpenAIContextAggregatorPair: A pair of context aggregators, one for
|
||||||
|
the user and one for the assistant, encapsulated in an
|
||||||
|
OpenAIContextAggregatorPair.
|
||||||
|
|
||||||
|
"""
|
||||||
|
context.set_llm_adapter(self.get_llm_adapter())
|
||||||
|
user = OpenAIUserContextAggregator(context, params=user_params)
|
||||||
|
assistant = OpenAIAssistantContextAggregator(context, params=assistant_params)
|
||||||
|
return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)
|
||||||
|
|
||||||
|
async def _process_context(self, context: OpenAILLMContext):
|
||||||
|
functions_list = []
|
||||||
|
arguments_list = []
|
||||||
|
tool_id_list = []
|
||||||
|
func_idx = 0
|
||||||
|
function_name = ""
|
||||||
|
arguments = ""
|
||||||
|
tool_call_id = ""
|
||||||
|
|
||||||
|
await self.start_ttfb_metrics()
|
||||||
|
|
||||||
|
result = Runner.run_streamed(
|
||||||
|
# context=context,
|
||||||
|
starting_agent=self._client,
|
||||||
|
input=context.messages, # messages
|
||||||
|
# ---
|
||||||
|
# no func tool
|
||||||
|
# input="give me a 2 sentences about life",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"get_chat_completions: {result}")
|
||||||
|
|
||||||
|
if result is None:
|
||||||
|
logger.error("Runner.run_streamed returned None")
|
||||||
|
return
|
||||||
|
|
||||||
|
async for event in result.stream_events():
|
||||||
|
if event.type == "raw_response_event":
|
||||||
|
if event.data.type == "response.output_text.delta":
|
||||||
|
await self.push_frame(LLMTextFrame(event.data.delta))
|
||||||
|
|
||||||
|
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||||
|
await super().process_frame(frame, direction)
|
||||||
|
|
||||||
|
context = None
|
||||||
|
if isinstance(frame, OpenAILLMContextFrame):
|
||||||
|
context: OpenAILLMContext = frame.context
|
||||||
|
elif isinstance(frame, LLMMessagesFrame):
|
||||||
|
context = OpenAILLMContext.from_messages(frame.messages)
|
||||||
|
else:
|
||||||
|
await self.push_frame(frame, direction)
|
||||||
|
|
||||||
|
if context:
|
||||||
|
try:
|
||||||
|
await self.push_frame(LLMFullResponseStartFrame())
|
||||||
|
await self.start_processing_metrics()
|
||||||
|
await self._process_context(context)
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
await self._call_event_handler("on_completion_timeout")
|
||||||
|
finally:
|
||||||
|
await self.stop_processing_metrics()
|
||||||
|
await self.push_frame(LLMFullResponseEndFrame())
|
||||||
|
|
||||||
|
|
||||||
|
async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespace):
|
||||||
|
logger.info(f"Starting bot")
|
||||||
|
|
||||||
|
transport = SmallWebRTCTransport(
|
||||||
|
webrtc_connection=webrtc_connection,
|
||||||
|
params=TransportParams(
|
||||||
|
audio_in_enabled=True,
|
||||||
|
audio_out_enabled=True,
|
||||||
|
vad_analyzer=SileroVADAnalyzer(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||||
|
|
||||||
|
tts = CartesiaTTSService(
|
||||||
|
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||||
|
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = CustomLLMService(model="gpt-4.1", api_key=os.getenv("OPENAI_API_KEY"))
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
context = OpenAILLMContext(messages=messages)
|
||||||
|
context_aggregator = llm.create_context_aggregator(context)
|
||||||
|
|
||||||
|
pipeline = Pipeline(
|
||||||
|
[
|
||||||
|
transport.input(), # Transport user input
|
||||||
|
stt,
|
||||||
|
context_aggregator.user(), # User responses
|
||||||
|
llm, # LLM
|
||||||
|
tts, # TTS
|
||||||
|
transport.output(), # Transport bot output
|
||||||
|
context_aggregator.assistant(), # Assistant spoken responses
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
task = PipelineTask(
|
||||||
|
pipeline,
|
||||||
|
params=PipelineParams(
|
||||||
|
allow_interruptions=True,
|
||||||
|
enable_metrics=True,
|
||||||
|
enable_usage_metrics=True,
|
||||||
|
report_only_initial_ttfb=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@transport.event_handler("on_client_connected")
|
||||||
|
async def on_client_connected(transport, client):
|
||||||
|
logger.info(f"Client connected")
|
||||||
|
# Kick off the conversation.
|
||||||
|
# messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||||
|
# await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||||
|
|
||||||
|
@transport.event_handler("on_client_disconnected")
|
||||||
|
async def on_client_disconnected(transport, client):
|
||||||
|
logger.info(f"Client disconnected")
|
||||||
|
|
||||||
|
@transport.event_handler("on_client_closed")
|
||||||
|
async def on_client_closed(transport, client):
|
||||||
|
logger.info(f"Client closed connection")
|
||||||
|
await task.cancel()
|
||||||
|
|
||||||
|
runner = PipelineRunner(handle_sigint=False)
|
||||||
|
|
||||||
|
await runner.run(task)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
from run import main
|
||||||
|
|
||||||
|
main()
|
||||||
122
examples/foundational/agent.py
Normal file
122
examples/foundational/agent.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from agents import (
|
||||||
|
Agent,
|
||||||
|
FunctionTool,
|
||||||
|
HandoffOutputItem,
|
||||||
|
ItemHelpers,
|
||||||
|
MessageOutputItem,
|
||||||
|
RunContextWrapper,
|
||||||
|
Runner,
|
||||||
|
ToolCallItem,
|
||||||
|
ToolCallOutputItem,
|
||||||
|
function_tool,
|
||||||
|
set_default_openai_api,
|
||||||
|
set_default_openai_client,
|
||||||
|
set_tracing_disabled,
|
||||||
|
trace,
|
||||||
|
)
|
||||||
|
from httpx import get
|
||||||
|
|
||||||
|
|
||||||
|
@function_tool
|
||||||
|
async def get_weather(location: str) -> str:
|
||||||
|
"""Fetch the weather for today.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
location: The location to fetch the weather for.
|
||||||
|
"""
|
||||||
|
return f"{location} is sunny"
|
||||||
|
|
||||||
|
|
||||||
|
system_prompt = """
|
||||||
|
you are a helpful assistant for a real estate brokerage AI assistant.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
bot = Agent(
|
||||||
|
name="Assistant agent",
|
||||||
|
instructions=system_prompt,
|
||||||
|
# tools=[get_weather],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
# res = await Runner.run(
|
||||||
|
# starting_agent=bot,
|
||||||
|
# input="What is the weather today?",
|
||||||
|
# )
|
||||||
|
# print(res)
|
||||||
|
|
||||||
|
result = Runner.run_streamed(
|
||||||
|
starting_agent=bot,
|
||||||
|
# ---
|
||||||
|
# with func tool
|
||||||
|
input="Tell a joke about pirates.",
|
||||||
|
# ---
|
||||||
|
# no func tool
|
||||||
|
# input="give me a 2 sentences about life",
|
||||||
|
)
|
||||||
|
|
||||||
|
final = []
|
||||||
|
async for event in result.stream_events():
|
||||||
|
# We'll ignore the raw responses event deltas
|
||||||
|
name = getattr(event, "name", None)
|
||||||
|
# print(f"Event: {event.type} - name {name}")
|
||||||
|
# print(event)
|
||||||
|
# continue
|
||||||
|
if event.type == "raw_response_event":
|
||||||
|
if event.data.type == "response.output_text.delta":
|
||||||
|
final += event.data.delta
|
||||||
|
|
||||||
|
print(f"raw resp: {event}")
|
||||||
|
# When the agent updates, print that
|
||||||
|
elif event.type == "agent_updated_stream_event":
|
||||||
|
print(f"Agent updated: {event.new_agent.name}")
|
||||||
|
continue
|
||||||
|
# When items are generated, print them
|
||||||
|
elif event.type == "run_item_stream_event":
|
||||||
|
if event.item.type == "tool_call_item":
|
||||||
|
print("-- Tool was called")
|
||||||
|
elif event.item.type == "tool_call_output_item":
|
||||||
|
print(f"-- Tool output: {event.item.output}")
|
||||||
|
elif event.item.type == "message_output_item":
|
||||||
|
print(f"-- Message output:\n {ItemHelpers.text_message_output(event.item)}")
|
||||||
|
else:
|
||||||
|
print(f"-- Unknown item type: {event.item.type}")
|
||||||
|
pass # Ignore other event types
|
||||||
|
else:
|
||||||
|
print(f"-- Unknown out item type: {event.item.type}")
|
||||||
|
|
||||||
|
print(f"----------------------")
|
||||||
|
|
||||||
|
print(f"FinalFinalFinal: {''.join(final)}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
asyncio.run(main())
|
||||||
|
|
||||||
|
|
||||||
|
# no func tool:
|
||||||
|
#
|
||||||
|
# Event: agent_updated_stream_event - name None
|
||||||
|
# Event: raw_response_event - name None
|
||||||
|
# ...
|
||||||
|
# Event: raw_response_event - name None
|
||||||
|
# Event: run_item_stream_event - name message_output_created
|
||||||
|
|
||||||
|
# with func tool:
|
||||||
|
#
|
||||||
|
# Event: agent_updated_stream_event - name None
|
||||||
|
# Event: raw_response_event - name None
|
||||||
|
# ...
|
||||||
|
# Event: raw_response_event - name None
|
||||||
|
# Event: run_item_stream_event - name tool_called
|
||||||
|
# Event: run_item_stream_event - name tool_output
|
||||||
|
# Event: raw_response_event - name None
|
||||||
|
# ...
|
||||||
|
# Event: raw_response_event - name None
|
||||||
|
# Event: run_item_stream_event - name message_output_created
|
||||||
@@ -77,8 +77,8 @@ class Frame:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SystemFrame(Frame):
|
class SystemFrame(Frame):
|
||||||
"""A frame that takes higher priority than other frames. System frames are
|
"""System frames are frames that are not internally queued by any of the
|
||||||
handled in order and are not affected by user interruptions.
|
frame processors and should be processed immediately.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -87,9 +87,8 @@ class SystemFrame(Frame):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataFrame(Frame):
|
class DataFrame(Frame):
|
||||||
"""A frame that is processed in order and usually contains data such as LLM
|
"""Data frames are frames that will be processed in order and usually
|
||||||
context, text, audio or images. Data frames are cancelled by user
|
contain data such as LLM context, text, audio or images.
|
||||||
interruptions.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -98,9 +97,9 @@ class DataFrame(Frame):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ControlFrame(Frame):
|
class ControlFrame(Frame):
|
||||||
"""A frame that, as data frames, is processed in order and usually contains
|
"""Control frames are frames that, similar to data frames, will be processed
|
||||||
control information such as update settings or to end the pipeline after
|
in order and usually contain control information such as frames to update
|
||||||
everything is flushed. Control frames are cancelled by user interruptions.
|
settings or to end the pipeline.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -691,7 +690,7 @@ class FunctionCallResultFrame(SystemFrame):
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class STTMuteFrame(SystemFrame):
|
class STTMuteFrame(SystemFrame):
|
||||||
"""A frame to mute/unmute the STT service."""
|
"""System frame to mute/unmute the STT service."""
|
||||||
|
|
||||||
mute: bool
|
mute: bool
|
||||||
|
|
||||||
@@ -797,7 +796,7 @@ class EndFrame(ControlFrame):
|
|||||||
should be shut down. If the transport receives this frame, it will stop
|
should be shut down. If the transport receives this frame, it will stop
|
||||||
sending frames to its output channel(s) and close all its threads. Note,
|
sending frames to its output channel(s) and close all its threads. Note,
|
||||||
that this is a control frame, which means it will received in the order it
|
that this is a control frame, which means it will received in the order it
|
||||||
was sent.
|
was sent (unline system frames).
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@
|
|||||||
#
|
#
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Awaitable, Callable, Coroutine, Optional
|
from typing import Awaitable, Callable, Coroutine, Optional
|
||||||
|
|
||||||
@@ -33,51 +32,6 @@ class FrameDirection(Enum):
|
|||||||
UPSTREAM = 2
|
UPSTREAM = 2
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class FrameProcessorQueueItem:
|
|
||||||
frame: Frame
|
|
||||||
direction: FrameDirection
|
|
||||||
callback: Optional[Callable[["FrameProcessor", Frame, FrameDirection], Awaitable[None]]]
|
|
||||||
|
|
||||||
|
|
||||||
class FrameProcessorQueue:
|
|
||||||
def __init__(self):
|
|
||||||
self._queue = asyncio.Queue()
|
|
||||||
self._urgent_queue = asyncio.Queue()
|
|
||||||
self._event = asyncio.Event()
|
|
||||||
|
|
||||||
async def put(self, item: FrameProcessorQueueItem):
|
|
||||||
if isinstance(item.frame, SystemFrame):
|
|
||||||
await self._urgent_queue.put(item)
|
|
||||||
else:
|
|
||||||
await self._queue.put(item)
|
|
||||||
self._event.set()
|
|
||||||
|
|
||||||
async def get(self) -> FrameProcessorQueueItem:
|
|
||||||
# Wait for an item in any of the queues.
|
|
||||||
await self._event.wait()
|
|
||||||
|
|
||||||
if self._urgent_queue.empty():
|
|
||||||
item = await self._queue.get()
|
|
||||||
self._queue.task_done()
|
|
||||||
else:
|
|
||||||
item = await self._urgent_queue.get()
|
|
||||||
self._urgent_queue.task_done()
|
|
||||||
|
|
||||||
# Clear the event only if all queues are empty.
|
|
||||||
if self._queue.empty() and self._urgent_queue.empty():
|
|
||||||
self._event.clear()
|
|
||||||
|
|
||||||
return item
|
|
||||||
|
|
||||||
def clear(self):
|
|
||||||
self._queue = asyncio.Queue()
|
|
||||||
|
|
||||||
# Clear the event only if all queues are empty.
|
|
||||||
if self._queue.empty() and self._urgent_queue.empty():
|
|
||||||
self._event.clear()
|
|
||||||
|
|
||||||
|
|
||||||
class FrameProcessor(BaseObject):
|
class FrameProcessor(BaseObject):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -115,21 +69,18 @@ class FrameProcessor(BaseObject):
|
|||||||
self._metrics = metrics or FrameProcessorMetrics()
|
self._metrics = metrics or FrameProcessorMetrics()
|
||||||
self._metrics.set_processor_name(self.name)
|
self._metrics.set_processor_name(self.name)
|
||||||
|
|
||||||
# Processors receive frames on a streaming queue which are then
|
# Processors have an input queue. The input queue will be processed
|
||||||
# processed by a streaming task. This guarantees that all frames are
|
# immediately (default) or it will block if `pause_processing_frames()`
|
||||||
# processed in the same task. By default, the streaming queue is
|
|
||||||
# processed immediately but it may block if `pause_processing_frames()`
|
|
||||||
# is called. To resume processing frames we need to call
|
# is called. To resume processing frames we need to call
|
||||||
# `resume_processing_frames()` which will wake up the event.
|
# `resume_processing_frames()` which will wake up the event.
|
||||||
self.__should_block_frames = False
|
self.__should_block_frames = False
|
||||||
self.__streaming_event = asyncio.Event()
|
self.__input_event = asyncio.Event()
|
||||||
self.__streaming_queue = FrameProcessorQueue()
|
self.__input_frame_task: Optional[asyncio.Task] = None
|
||||||
self.__streaming_frame_task: Optional[asyncio.Task] = None
|
|
||||||
|
|
||||||
self.__process_queue = asyncio.Queue()
|
# Every processor in Pipecat should only output frames from a single
|
||||||
self.__process_task: Optional[asyncio.Task] = None
|
# task. This avoid problems like audio overlapping. System frames are the
|
||||||
self.__process_urgent_queue = asyncio.Queue()
|
# exception to this rule. This create this task.
|
||||||
self.__process_urgent_task: Optional[asyncio.Task] = None
|
self.__push_frame_task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def id(self) -> int:
|
def id(self) -> int:
|
||||||
@@ -219,8 +170,7 @@ class FrameProcessor(BaseObject):
|
|||||||
async def cleanup(self):
|
async def cleanup(self):
|
||||||
await super().cleanup()
|
await super().cleanup()
|
||||||
await self.__cancel_input_task()
|
await self.__cancel_input_task()
|
||||||
await self.__cancel_process_task()
|
await self.__cancel_push_task()
|
||||||
await self.__cancel_process_urgent_task()
|
|
||||||
|
|
||||||
def link(self, processor: "FrameProcessor"):
|
def link(self, processor: "FrameProcessor"):
|
||||||
self._next = processor
|
self._next = processor
|
||||||
@@ -265,7 +215,7 @@ class FrameProcessor(BaseObject):
|
|||||||
await self.process_frame(frame, direction)
|
await self.process_frame(frame, direction)
|
||||||
else:
|
else:
|
||||||
# We queue everything else.
|
# We queue everything else.
|
||||||
await self.__streaming_queue.put(FrameProcessorQueueItem(frame, direction, callback))
|
await self.__input_queue.put((frame, direction, callback))
|
||||||
|
|
||||||
async def pause_processing_frames(self):
|
async def pause_processing_frames(self):
|
||||||
logger.trace(f"{self}: pausing frame processing")
|
logger.trace(f"{self}: pausing frame processing")
|
||||||
@@ -273,7 +223,7 @@ class FrameProcessor(BaseObject):
|
|||||||
|
|
||||||
async def resume_processing_frames(self):
|
async def resume_processing_frames(self):
|
||||||
logger.trace(f"{self}: resuming frame processing")
|
logger.trace(f"{self}: resuming frame processing")
|
||||||
self.__streaming_event.set()
|
self.__input_event.set()
|
||||||
|
|
||||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||||
if isinstance(frame, StartFrame):
|
if isinstance(frame, StartFrame):
|
||||||
@@ -300,6 +250,47 @@ class FrameProcessor(BaseObject):
|
|||||||
if not self._check_ready(frame):
|
if not self._check_ready(frame):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if isinstance(frame, SystemFrame):
|
||||||
|
await self.__internal_push_frame(frame, direction)
|
||||||
|
else:
|
||||||
|
await self.__push_queue.put((frame, direction))
|
||||||
|
|
||||||
|
async def __start(self, frame: StartFrame):
|
||||||
|
self.__create_input_task()
|
||||||
|
self.__create_push_task()
|
||||||
|
|
||||||
|
async def __cancel(self, frame: CancelFrame):
|
||||||
|
self._cancelling = True
|
||||||
|
await self.__cancel_input_task()
|
||||||
|
await self.__cancel_push_task()
|
||||||
|
|
||||||
|
#
|
||||||
|
# Handle interruptions
|
||||||
|
#
|
||||||
|
|
||||||
|
async def _start_interruption(self):
|
||||||
|
try:
|
||||||
|
# Cancel the push frame task. This will stop pushing frames downstream.
|
||||||
|
await self.__cancel_push_task()
|
||||||
|
|
||||||
|
# Cancel the input task. This will stop processing queued frames.
|
||||||
|
await self.__cancel_input_task()
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Uncaught exception in {self}: {e}")
|
||||||
|
await self.push_error(ErrorFrame(str(e)))
|
||||||
|
raise
|
||||||
|
|
||||||
|
# Create a new input queue and task.
|
||||||
|
self.__create_input_task()
|
||||||
|
|
||||||
|
# Create a new output queue and task.
|
||||||
|
self.__create_push_task()
|
||||||
|
|
||||||
|
async def _stop_interruption(self):
|
||||||
|
# Nothing to do right now.
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def __internal_push_frame(self, frame: Frame, direction: FrameDirection):
|
||||||
try:
|
try:
|
||||||
timestamp = self._clock.get_time() if self._clock else 0
|
timestamp = self._clock.get_time() if self._clock else 0
|
||||||
if direction == FrameDirection.DOWNSTREAM and self._next:
|
if direction == FrameDirection.DOWNSTREAM and self._next:
|
||||||
@@ -332,49 +323,6 @@ class FrameProcessor(BaseObject):
|
|||||||
await self.push_error(ErrorFrame(str(e)))
|
await self.push_error(ErrorFrame(str(e)))
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def __start(self, frame: StartFrame):
|
|
||||||
self.__create_process_task()
|
|
||||||
self.__create_process_urgent_task()
|
|
||||||
self.__create_input_task()
|
|
||||||
|
|
||||||
async def __cancel(self, frame: CancelFrame):
|
|
||||||
self._cancelling = True
|
|
||||||
await self.__cancel_input_task()
|
|
||||||
await self.__cancel_process_task()
|
|
||||||
await self.__cancel_process_urgent_task()
|
|
||||||
|
|
||||||
#
|
|
||||||
# Handle interruptions
|
|
||||||
#
|
|
||||||
|
|
||||||
async def _start_interruption(self):
|
|
||||||
try:
|
|
||||||
# Cancel the streaming task.
|
|
||||||
await self.__cancel_input_task()
|
|
||||||
|
|
||||||
# Cancel the task processing frames. We do not cancel the task that
|
|
||||||
# is processing urgent frames.
|
|
||||||
await self.__cancel_process_task()
|
|
||||||
|
|
||||||
# If there's an interruption we should not block frames anymore.
|
|
||||||
self.__should_block_frames = False
|
|
||||||
|
|
||||||
# Clear the streaming queue, since we don't want to process its
|
|
||||||
# frame anymore (except system and urgent frames).
|
|
||||||
self.__streaming_queue.clear()
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception(f"Uncaught exception in {self}: {e}")
|
|
||||||
await self.push_error(ErrorFrame(str(e)))
|
|
||||||
raise
|
|
||||||
|
|
||||||
# Create a new tasks.
|
|
||||||
self.__create_process_task()
|
|
||||||
self.__create_input_task()
|
|
||||||
|
|
||||||
async def _stop_interruption(self):
|
|
||||||
# Nothing to do right now.
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _check_ready(self, frame: Frame):
|
def _check_ready(self, frame: Frame):
|
||||||
# If we are trying to push a frame but we still have no clock, it means
|
# If we are trying to push a frame but we still have no clock, it means
|
||||||
# we didn't process a StartFrame.
|
# we didn't process a StartFrame.
|
||||||
@@ -386,60 +334,49 @@ class FrameProcessor(BaseObject):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def __create_input_task(self):
|
def __create_input_task(self):
|
||||||
if not self.__streaming_frame_task:
|
if not self.__input_frame_task:
|
||||||
self.__streaming_frame_task = self.create_task(self.__streaming_frame_task_handler())
|
self.__should_block_frames = False
|
||||||
|
self.__input_event.clear()
|
||||||
|
self.__input_queue = asyncio.Queue()
|
||||||
|
self.__input_frame_task = self.create_task(self.__input_frame_task_handler())
|
||||||
|
|
||||||
async def __cancel_input_task(self):
|
async def __cancel_input_task(self):
|
||||||
if self.__streaming_frame_task:
|
if self.__input_frame_task:
|
||||||
await self.cancel_task(self.__streaming_frame_task)
|
await self.cancel_task(self.__input_frame_task)
|
||||||
self.__streaming_frame_task = None
|
self.__input_frame_task = None
|
||||||
|
|
||||||
def __create_process_task(self):
|
async def __input_frame_task_handler(self):
|
||||||
if not self.__process_task:
|
|
||||||
self.__process_queue = asyncio.Queue()
|
|
||||||
self.__process_task = self.create_task(
|
|
||||||
self.__process_task_handler(self.__process_queue)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def __cancel_process_task(self):
|
|
||||||
if self.__process_task:
|
|
||||||
await self.cancel_task(self.__process_task)
|
|
||||||
self.__process_task = None
|
|
||||||
|
|
||||||
def __create_process_urgent_task(self):
|
|
||||||
if not self.__process_urgent_task:
|
|
||||||
self.__process_urgent_task = self.create_task(
|
|
||||||
self.__process_task_handler(self.__process_urgent_queue)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def __cancel_process_urgent_task(self):
|
|
||||||
if self.__process_urgent_task:
|
|
||||||
await self.cancel_task(self.__process_urgent_task)
|
|
||||||
self.__process_urgent_task = None
|
|
||||||
|
|
||||||
async def __streaming_frame_task_handler(self):
|
|
||||||
while True:
|
while True:
|
||||||
if self.__should_block_frames:
|
if self.__should_block_frames:
|
||||||
logger.trace(f"{self}: frame processing paused")
|
logger.trace(f"{self}: frame processing paused")
|
||||||
await self.__streaming_event.wait()
|
await self.__input_event.wait()
|
||||||
self.__streaming_event.clear()
|
self.__input_event.clear()
|
||||||
self.__should_block_frames = False
|
self.__should_block_frames = False
|
||||||
logger.trace(f"{self}: frame processing resumed")
|
logger.trace(f"{self}: frame processing resumed")
|
||||||
|
|
||||||
item = await self.__streaming_queue.get()
|
(frame, direction, callback) = await self.__input_queue.get()
|
||||||
|
|
||||||
if isinstance(item.frame, SystemFrame):
|
|
||||||
await self.__process_urgent_queue.put(item)
|
|
||||||
else:
|
|
||||||
await self.__process_queue.put(item)
|
|
||||||
|
|
||||||
async def __process_task_handler(self, queue: asyncio.Queue):
|
|
||||||
while True:
|
|
||||||
item = await queue.get()
|
|
||||||
|
|
||||||
# Process the frame.
|
# Process the frame.
|
||||||
await self.process_frame(item.frame, item.direction)
|
await self.process_frame(frame, direction)
|
||||||
|
|
||||||
# If this frame has an associated callback, call it now.
|
# If this frame has an associated callback, call it now.
|
||||||
if item.callback:
|
if callback:
|
||||||
await item.callback(self, item.frame, item.direction)
|
await callback(self, frame, direction)
|
||||||
|
|
||||||
|
self.__input_queue.task_done()
|
||||||
|
|
||||||
|
def __create_push_task(self):
|
||||||
|
if not self.__push_frame_task:
|
||||||
|
self.__push_queue = asyncio.Queue()
|
||||||
|
self.__push_frame_task = self.create_task(self.__push_frame_task_handler())
|
||||||
|
|
||||||
|
async def __cancel_push_task(self):
|
||||||
|
if self.__push_frame_task:
|
||||||
|
await self.cancel_task(self.__push_frame_task)
|
||||||
|
self.__push_frame_task = None
|
||||||
|
|
||||||
|
async def __push_frame_task_handler(self):
|
||||||
|
while True:
|
||||||
|
(frame, direction) = await self.__push_queue.get()
|
||||||
|
await self.__internal_push_frame(frame, direction)
|
||||||
|
self.__push_queue.task_done()
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import unittest
|
|||||||
from pipecat.frames.frames import (
|
from pipecat.frames.frames import (
|
||||||
EndFrame,
|
EndFrame,
|
||||||
Frame,
|
Frame,
|
||||||
StartInterruptionFrame,
|
|
||||||
TextFrame,
|
TextFrame,
|
||||||
TranscriptionFrame,
|
TranscriptionFrame,
|
||||||
UserStartedSpeakingFrame,
|
UserStartedSpeakingFrame,
|
||||||
@@ -58,8 +57,8 @@ class TestFrameFilter(unittest.IsolatedAsyncioTestCase):
|
|||||||
|
|
||||||
async def test_system_frame(self):
|
async def test_system_frame(self):
|
||||||
filter = FrameFilter(types=())
|
filter = FrameFilter(types=())
|
||||||
frames_to_send = [StartInterruptionFrame()]
|
frames_to_send = [UserStartedSpeakingFrame()]
|
||||||
expected_down_frames = [StartInterruptionFrame]
|
expected_down_frames = [UserStartedSpeakingFrame]
|
||||||
await run_test(
|
await run_test(
|
||||||
filter,
|
filter,
|
||||||
frames_to_send=frames_to_send,
|
frames_to_send=frames_to_send,
|
||||||
|
|||||||
Reference in New Issue
Block a user