Compare commits
53 Commits
v0.0.53
...
hush/callT
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a751130a76 | ||
|
|
b29ac3c7a8 | ||
|
|
5222488fb5 | ||
|
|
c2fef9584b | ||
|
|
fc6aa6eae8 | ||
|
|
ddd5bf70ab | ||
|
|
aa59744444 | ||
|
|
067ddfe505 | ||
|
|
a64df978e7 | ||
|
|
7167719761 | ||
|
|
e1430be9f9 | ||
|
|
c2fe8e7fdb | ||
|
|
31c77d8e35 | ||
|
|
2a60d54830 | ||
|
|
b3c99887dc | ||
|
|
38ad75cc17 | ||
|
|
2debac314c | ||
|
|
e0c9a1a1a2 | ||
|
|
4cdcca588e | ||
|
|
a90e81e2eb | ||
|
|
0ba60c9e28 | ||
|
|
5ca5fbd825 | ||
|
|
2b52e2c109 | ||
|
|
7e8fc2e7e2 | ||
|
|
0d79a9eaa6 | ||
|
|
f89b9ec23f | ||
|
|
20d5824e56 | ||
|
|
f23baa78d8 | ||
|
|
cacd6ba3fa | ||
|
|
f87ecd3a51 | ||
|
|
b96a922aa8 | ||
|
|
401d3ff267 | ||
|
|
ab4221a4db | ||
|
|
bd6f82cf94 | ||
|
|
dd21b424d6 | ||
|
|
76884877dd | ||
|
|
0d6c680133 | ||
|
|
a27fe4bde2 | ||
|
|
177cb2ca8b | ||
|
|
3c970a3cee | ||
|
|
af02f8f1cd | ||
|
|
2e0fb198bf | ||
|
|
4f758c5a3b | ||
|
|
3e0836b340 | ||
|
|
2f23693bf3 | ||
|
|
b7dd9748cf | ||
|
|
d4d9c3b7ae | ||
|
|
090bc81ec5 | ||
|
|
e3d53d3d9a | ||
|
|
262d3a19c9 | ||
|
|
491feb691c | ||
|
|
e4f83b237e | ||
|
|
14e5419913 |
4
.github/workflows/tests.yaml
vendored
4
.github/workflows/tests.yaml
vendored
@@ -1,4 +1,4 @@
|
||||
name: test
|
||||
name: tests
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
@@ -49,4 +49,4 @@ jobs:
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
pytest --ignore-glob="*to_be_updated*" --ignore-glob=*pipeline_source* src tests
|
||||
pytest
|
||||
|
||||
40
CHANGELOG.md
40
CHANGELOG.md
@@ -5,12 +5,48 @@ All notable changes to **Pipecat** will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
|
||||
- It is now possible to specify the period of the `PipelineTask` heartbeat
|
||||
frames with `heartbeats_period_secs`.
|
||||
|
||||
### Changed
|
||||
|
||||
- Modified `TranscriptProcessor` to use TTS text frames for more accurate assistant
|
||||
transcripts. Assistant messages are now aggregated based on bot speaking boundaries
|
||||
rather than LLM context, providing better handling of interruptions and partial
|
||||
utterances.
|
||||
|
||||
- Updated foundational examples `28a-transcription-processor-openai.py`,
|
||||
`28b-transcript-processor-anthropic.py`, and
|
||||
`28c-transcription-processor-gemini.py` to use the updated
|
||||
`TranscriptProcessor`.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed a type error when using `voice_settings` in `ElevenLabsHttpTTSService`.
|
||||
|
||||
- Fixed an issue where `OpenAIRealtimeBetaLLMService` function calling resulted
|
||||
in an error.
|
||||
|
||||
### Performance
|
||||
|
||||
- Replaced audio resampling library `resampy` with `soxr`. Resampling a 2:21s
|
||||
audio file from 24KHz to 16KHz took 1.41s with `resampy` and 0.031s with
|
||||
`soxr` with similar audio quality.
|
||||
|
||||
### Other
|
||||
|
||||
- Added initial unit test infrastructure.
|
||||
|
||||
## [0.0.53] - 2025-01-18
|
||||
|
||||
### Added
|
||||
|
||||
- Added `ElevenLabsHttpTTSService` and the
|
||||
`07d-interruptible-elevenlabs-http.py` foundational example.
|
||||
- Added `ElevenLabsHttpTTSService` which uses EleveLabs' HTTP API instead of the
|
||||
websocket one.
|
||||
|
||||
- Introduced pipeline frame observers. Observers can view all the frames that go
|
||||
through the pipeline without the need to inject processors in the
|
||||
|
||||
10
README.md
10
README.md
@@ -2,7 +2,7 @@
|
||||
<img alt="pipecat" width="300px" height="auto" src="https://raw.githubusercontent.com/pipecat-ai/pipecat/main/pipecat.png">
|
||||
</div></h1>
|
||||
|
||||
[](https://pypi.org/project/pipecat-ai) [](https://docs.pipecat.ai) [](https://discord.gg/pipecat) <a href="https://app.commanddash.io/agent/github_pipecat-ai_pipecat"><img src="https://img.shields.io/badge/AI-Code%20Agent-EB9FDA"></a>
|
||||
[](https://pypi.org/project/pipecat-ai)  [](https://docs.pipecat.ai) [](https://discord.gg/pipecat) <a href="https://app.commanddash.io/agent/github_pipecat-ai_pipecat"><img src="https://img.shields.io/badge/AI-Code%20Agent-EB9FDA"></a>
|
||||
|
||||
Pipecat is an open source Python framework for building voice and multimodal conversational agents. It handles the complex orchestration of AI services, network transport, audio processing, and multimodal interactions, letting you focus on creating engaging experiences.
|
||||
|
||||
@@ -53,12 +53,6 @@ To keep things lightweight, only the core framework is included by default. If y
|
||||
pip install "pipecat-ai[option,...]"
|
||||
```
|
||||
|
||||
Or you can install all of them with:
|
||||
|
||||
```shell
|
||||
pip install "pipecat-ai[all]"
|
||||
```
|
||||
|
||||
Available options include:
|
||||
|
||||
| Category | Services | Install Command Example |
|
||||
@@ -195,7 +189,7 @@ pip install "path_to_this_repo[option,...]"
|
||||
From the root directory, run:
|
||||
|
||||
```shell
|
||||
pytest --doctest-modules --ignore-glob="*to_be_updated*" --ignore-glob=*pipeline_source* src tests
|
||||
pytest
|
||||
```
|
||||
|
||||
## Setting up your editor
|
||||
|
||||
@@ -4,6 +4,7 @@ pip-tools~=7.4.1
|
||||
pre-commit~=4.0.1
|
||||
pyright~=1.1.392
|
||||
pytest~=8.3.4
|
||||
pytest-asyncio~=0.25.2
|
||||
ruff~=0.9.1
|
||||
setuptools~=75.8.0
|
||||
setuptools_scm~=8.1.0
|
||||
|
||||
@@ -42,7 +42,7 @@ Next, follow the steps in the README for each demo.
|
||||
| [Dialin Chatbot](dialin-chatbot) | A chatbot that connects to an incoming phone call from Daily or Twilio. | Deepgram, ElevenLabs, OpenAI, Daily, Twilio |
|
||||
| [Twilio Chatbot](twilio-chatbot) | A chatbot that connects to an incoming phone call from Twilio. | Deepgram, ElevenLabs, OpenAI, Daily, Twilio |
|
||||
| [studypal](studypal) | A chatbot to have a conversation about any article on the web | |
|
||||
| [WebSocket Chatbot Server](websocket-server) | A real-time websocket server that handles audio streaming and bot interactions with speech-to-text and text-to-speech capabilities | `python-websockets`, `openai`, `deepgram`, `silero-tts`, `numpy` |
|
||||
| [WebSocket Chatbot Server](websocket-server) | A real-time websocket server that handles audio streaming and bot interactions with speech-to-text and text-to-speech capabilities. | Cartesia, Deepgram, OpenAI, Websockets |
|
||||
|
||||
> [!IMPORTANT]
|
||||
> These example projects use Daily as a WebRTC transport and can be joined using their hosted Prebuilt UI.
|
||||
|
||||
@@ -37,7 +37,16 @@ Run `bot_runner.py` to handle incoming HTTP requests:
|
||||
|
||||
Then target the following URL:
|
||||
|
||||
`POST /daily_start_bot`
|
||||
```bash
|
||||
curl -X POST 'http://localhost:7860/daily_start_bot' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"callId": "callId-from-call",
|
||||
"callDomain": "callDomain-from-call"
|
||||
}'
|
||||
```
|
||||
|
||||
Use [this guide](https://docs.pipecat.ai/guides/telephony/daily-webrtc) to connect a phone number purchased from Daily to the bot.
|
||||
|
||||
For more configuration options, please consult Daily's API documentation.
|
||||
|
||||
@@ -82,4 +91,4 @@ If you're using Twilio as a number vendor:
|
||||
|
||||
## Need to do something more advanced?
|
||||
|
||||
This demo covers the basics of bot telephony. If you want to know more about working with PSTN / SIP, please ping us on [Discord](https://discord.gg/pipecat).
|
||||
This demo covers the basics of bot telephony. If you want to know more about working with PSTN / SIP, please ping us on [Discord](https://discord.gg/pipecat).
|
||||
|
||||
@@ -1,3 +1,8 @@
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
@@ -5,13 +10,16 @@ import sys
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import EndFrame
|
||||
from pipecat.frames.frames import EndFrame, TextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.ai_services import LLMService
|
||||
from pipecat.services.deepgram import DeepgramSTTService
|
||||
from pipecat.services.elevenlabs import ElevenLabsTTSService
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.transports.services.daily import DailyDialinSettings, DailyParams, DailyTransport
|
||||
@@ -55,16 +63,62 @@ async def main(room_url: str, token: str, callId: str, callDomain: str):
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
|
||||
|
||||
content = f"""
|
||||
You are a delivery service customer support specialist supporting customers with their orders.
|
||||
Begin with: "Hello, this is Hailey from customer support. What can I help you with today?"
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are Chatbot, a friendly, helpful robot. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way, but keep your responses brief. Start by saying 'Oh, hello! Who dares dial me at this hour?!'.",
|
||||
"content": content,
|
||||
},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages)
|
||||
tools = [
|
||||
ChatCompletionToolParam(
|
||||
type="function",
|
||||
function={
|
||||
"name": "transfer_call",
|
||||
"description": "Transfer the call to a person. This function is used to connect the call to a real person. Examples of real people are: managers, supervisors, or other customer support specialists. Any person is okay as long as they are not a bot.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"call_id": {
|
||||
"type": "string",
|
||||
"description": "This is always {callId}.",
|
||||
},
|
||||
"summary": {
|
||||
"type": "string",
|
||||
"description": """
|
||||
Provide a concise summary in 3-5 sentences. Highlight any important details or unusual aspects of the conversation.
|
||||
""",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages, tools)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
async def default_transfer_call(
|
||||
function_name, tool_call_id, args, llm: LLMService, context, result_callback
|
||||
):
|
||||
logger.debug(f"default_transfer_call: {function_name} {tool_call_id} {args}")
|
||||
await result_callback(
|
||||
{
|
||||
"transfer_call": False,
|
||||
"reason": "To transfer call calls, please dial in to the room using a phone or a SIP client.",
|
||||
}
|
||||
)
|
||||
|
||||
llm.register_function(
|
||||
function_name="transfer_call",
|
||||
callback=default_transfer_call,
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
@@ -87,6 +141,44 @@ async def main(room_url: str, token: str, callId: str, callDomain: str):
|
||||
async def on_participant_left(transport, participant, reason):
|
||||
await task.queue_frame(EndFrame())
|
||||
|
||||
@transport.event_handler("on_dialin_ready")
|
||||
async def on_dialin_ready(_, sip_endpoint):
|
||||
logger.info(f"on_dialin_ready: {sip_endpoint}")
|
||||
|
||||
@transport.event_handler("on_dialin_connected")
|
||||
async def on_dialin_connected(transport, event):
|
||||
logger.info(f"on_dialin_connected: {event}")
|
||||
sip_session_id = event["sessionId"]
|
||||
|
||||
async def transfer_call(
|
||||
function_name, tool_call_id, args, llm: LLMService, context, result_callback
|
||||
):
|
||||
logger.debug(f"transfer_call: {function_name} {tool_call_id} {args}")
|
||||
|
||||
# sip_url = "sip:your_user_name@sip.linphone.org"
|
||||
|
||||
sip_url = (
|
||||
f"sip:your_username@dailyco.sip.twilio.com?x-daily_id={room_url.split('/')[-1]}"
|
||||
)
|
||||
|
||||
try:
|
||||
await transport.sip_refer(
|
||||
settings={
|
||||
"sessionId": sip_session_id,
|
||||
"toEndPoint": sip_url,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred during SIP refer: {e}")
|
||||
await result_callback({"transfer_call": False})
|
||||
|
||||
await result_callback({"transfer_call": True})
|
||||
|
||||
llm.register_function(
|
||||
function_name="transfer_call",
|
||||
callback=transfer_call,
|
||||
)
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
@@ -15,7 +15,11 @@ from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import EndFrame, TranscriptionMessage, TranscriptionUpdateFrame
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
TranscriptionMessage,
|
||||
TranscriptionUpdateFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
@@ -33,13 +37,49 @@ logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
class TranscriptHandler:
|
||||
"""Simple handler to demonstrate transcript processing.
|
||||
"""Handles real-time transcript processing and output.
|
||||
|
||||
Maintains a list of conversation messages and logs them with timestamps.
|
||||
Maintains a list of conversation messages and outputs them either to a log
|
||||
or to a file as they are received. Each message includes its timestamp and role.
|
||||
|
||||
Attributes:
|
||||
messages: List of all processed transcript messages
|
||||
output_file: Optional path to file where transcript is saved. If None, outputs to log only.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, output_file: Optional[str] = None):
|
||||
"""Initialize handler with optional file output.
|
||||
|
||||
Args:
|
||||
output_file: Path to output file. If None, outputs to log only.
|
||||
"""
|
||||
self.messages: List[TranscriptionMessage] = []
|
||||
self.output_file: Optional[str] = output_file
|
||||
logger.debug(
|
||||
f"TranscriptHandler initialized {'with output_file=' + output_file if output_file else 'with log output only'}"
|
||||
)
|
||||
|
||||
async def save_message(self, message: TranscriptionMessage):
|
||||
"""Save a single transcript message.
|
||||
|
||||
Outputs the message to the log and optionally to a file.
|
||||
|
||||
Args:
|
||||
message: The message to save
|
||||
"""
|
||||
timestamp = f"[{message.timestamp}] " if message.timestamp else ""
|
||||
line = f"{timestamp}{message.role}: {message.content}"
|
||||
|
||||
# Always log the message
|
||||
logger.info(f"Transcript: {line}")
|
||||
|
||||
# Optionally write to file
|
||||
if self.output_file:
|
||||
try:
|
||||
with open(self.output_file, "a", encoding="utf-8") as f:
|
||||
f.write(line + "\n")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving transcript message to file: {e}")
|
||||
|
||||
async def on_transcript_update(
|
||||
self, processor: TranscriptProcessor, frame: TranscriptionUpdateFrame
|
||||
@@ -50,13 +90,11 @@ class TranscriptHandler:
|
||||
processor: The TranscriptProcessor that emitted the update
|
||||
frame: TranscriptionUpdateFrame containing new messages
|
||||
"""
|
||||
self.messages.extend(frame.messages)
|
||||
logger.debug(f"Received transcript update with {len(frame.messages)} new messages")
|
||||
|
||||
# Log the new messages
|
||||
logger.info("New transcript messages:")
|
||||
for msg in frame.messages:
|
||||
timestamp = f"[{msg.timestamp}] " if msg.timestamp else ""
|
||||
logger.info(f"{timestamp}{msg.role}: {msg.content}")
|
||||
self.messages.append(msg)
|
||||
await self.save_message(msg)
|
||||
|
||||
|
||||
async def main():
|
||||
@@ -99,7 +137,8 @@ async def main():
|
||||
|
||||
# Create transcript processor and handler
|
||||
transcript = TranscriptProcessor()
|
||||
transcript_handler = TranscriptHandler()
|
||||
transcript_handler = TranscriptHandler() # Output to log only
|
||||
# transcript_handler = TranscriptHandler(output_file="transcript.txt") # Output to file and log
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
@@ -110,8 +149,8 @@ async def main():
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
transcript.assistant(), # Assistant transcripts
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
@@ -130,7 +169,8 @@ async def main():
|
||||
|
||||
@transport.event_handler("on_participant_left")
|
||||
async def on_participant_left(transport, participant, reason):
|
||||
await task.queue_frame(EndFrame())
|
||||
# Stop the pipeline immediately when the participant leaves
|
||||
await task.queue_frame(CancelFrame())
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
@@ -15,7 +15,11 @@ from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import EndFrame, TranscriptionMessage, TranscriptionUpdateFrame
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
TranscriptionMessage,
|
||||
TranscriptionUpdateFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
@@ -33,13 +37,49 @@ logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
class TranscriptHandler:
|
||||
"""Simple handler to demonstrate transcript processing.
|
||||
"""Handles real-time transcript processing and output.
|
||||
|
||||
Maintains a list of conversation messages and logs them with timestamps.
|
||||
Maintains a list of conversation messages and outputs them either to a log
|
||||
or to a file as they are received. Each message includes its timestamp and role.
|
||||
|
||||
Attributes:
|
||||
messages: List of all processed transcript messages
|
||||
output_file: Optional path to file where transcript is saved. If None, outputs to log only.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, output_file: Optional[str] = None):
|
||||
"""Initialize handler with optional file output.
|
||||
|
||||
Args:
|
||||
output_file: Path to output file. If None, outputs to log only.
|
||||
"""
|
||||
self.messages: List[TranscriptionMessage] = []
|
||||
self.output_file: Optional[str] = output_file
|
||||
logger.debug(
|
||||
f"TranscriptHandler initialized {'with output_file=' + output_file if output_file else 'with log output only'}"
|
||||
)
|
||||
|
||||
async def save_message(self, message: TranscriptionMessage):
|
||||
"""Save a single transcript message.
|
||||
|
||||
Outputs the message to the log and optionally to a file.
|
||||
|
||||
Args:
|
||||
message: The message to save
|
||||
"""
|
||||
timestamp = f"[{message.timestamp}] " if message.timestamp else ""
|
||||
line = f"{timestamp}{message.role}: {message.content}"
|
||||
|
||||
# Always log the message
|
||||
logger.info(f"Transcript: {line}")
|
||||
|
||||
# Optionally write to file
|
||||
if self.output_file:
|
||||
try:
|
||||
with open(self.output_file, "a", encoding="utf-8") as f:
|
||||
f.write(line + "\n")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving transcript message to file: {e}")
|
||||
|
||||
async def on_transcript_update(
|
||||
self, processor: TranscriptProcessor, frame: TranscriptionUpdateFrame
|
||||
@@ -50,13 +90,11 @@ class TranscriptHandler:
|
||||
processor: The TranscriptProcessor that emitted the update
|
||||
frame: TranscriptionUpdateFrame containing new messages
|
||||
"""
|
||||
self.messages.extend(frame.messages)
|
||||
logger.debug(f"Received transcript update with {len(frame.messages)} new messages")
|
||||
|
||||
# Log the new messages
|
||||
logger.info("New transcript messages:")
|
||||
for msg in frame.messages:
|
||||
timestamp = f"[{msg.timestamp}] " if msg.timestamp else ""
|
||||
logger.info(f"{timestamp}{msg.role}: {msg.content}")
|
||||
self.messages.append(msg)
|
||||
await self.save_message(msg)
|
||||
|
||||
|
||||
async def main():
|
||||
@@ -99,7 +137,8 @@ async def main():
|
||||
|
||||
# Create transcript processor and handler
|
||||
transcript = TranscriptProcessor()
|
||||
transcript_handler = TranscriptHandler()
|
||||
transcript_handler = TranscriptHandler() # Output to log only
|
||||
# transcript_handler = TranscriptHandler(output_file="transcript.txt") # Output to file and log
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
@@ -110,8 +149,8 @@ async def main():
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
transcript.assistant(), # Assistant transcripts
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
@@ -130,7 +169,8 @@ async def main():
|
||||
|
||||
@transport.event_handler("on_participant_left")
|
||||
async def on_participant_left(transport, participant, reason):
|
||||
await task.queue_frame(EndFrame())
|
||||
# Stop the pipeline immediately when the participant leaves
|
||||
await task.queue_frame(CancelFrame())
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
@@ -15,7 +15,11 @@ from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import EndFrame, TranscriptionMessage, TranscriptionUpdateFrame
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
TranscriptionMessage,
|
||||
TranscriptionUpdateFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
@@ -34,13 +38,49 @@ logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
class TranscriptHandler:
|
||||
"""Simple handler to demonstrate transcript processing.
|
||||
"""Handles real-time transcript processing and output.
|
||||
|
||||
Maintains a list of conversation messages and logs them with timestamps.
|
||||
Maintains a list of conversation messages and outputs them either to a log
|
||||
or to a file as they are received. Each message includes its timestamp and role.
|
||||
|
||||
Attributes:
|
||||
messages: List of all processed transcript messages
|
||||
output_file: Optional path to file where transcript is saved. If None, outputs to log only.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, output_file: Optional[str] = None):
|
||||
"""Initialize handler with optional file output.
|
||||
|
||||
Args:
|
||||
output_file: Path to output file. If None, outputs to log only.
|
||||
"""
|
||||
self.messages: List[TranscriptionMessage] = []
|
||||
self.output_file: Optional[str] = output_file
|
||||
logger.debug(
|
||||
f"TranscriptHandler initialized {'with output_file=' + output_file if output_file else 'with log output only'}"
|
||||
)
|
||||
|
||||
async def save_message(self, message: TranscriptionMessage):
|
||||
"""Save a single transcript message.
|
||||
|
||||
Outputs the message to the log and optionally to a file.
|
||||
|
||||
Args:
|
||||
message: The message to save
|
||||
"""
|
||||
timestamp = f"[{message.timestamp}] " if message.timestamp else ""
|
||||
line = f"{timestamp}{message.role}: {message.content}"
|
||||
|
||||
# Always log the message
|
||||
logger.info(f"Transcript: {line}")
|
||||
|
||||
# Optionally write to file
|
||||
if self.output_file:
|
||||
try:
|
||||
with open(self.output_file, "a", encoding="utf-8") as f:
|
||||
f.write(line + "\n")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving transcript message to file: {e}")
|
||||
|
||||
async def on_transcript_update(
|
||||
self, processor: TranscriptProcessor, frame: TranscriptionUpdateFrame
|
||||
@@ -51,13 +91,11 @@ class TranscriptHandler:
|
||||
processor: The TranscriptProcessor that emitted the update
|
||||
frame: TranscriptionUpdateFrame containing new messages
|
||||
"""
|
||||
self.messages.extend(frame.messages)
|
||||
logger.debug(f"Received transcript update with {len(frame.messages)} new messages")
|
||||
|
||||
# Log the new messages
|
||||
logger.info("New transcript messages:")
|
||||
for msg in frame.messages:
|
||||
timestamp = f"[{msg.timestamp}] " if msg.timestamp else ""
|
||||
logger.info(f"{timestamp}{msg.role}: {msg.content}")
|
||||
self.messages.append(msg)
|
||||
await self.save_message(msg)
|
||||
|
||||
|
||||
async def main():
|
||||
@@ -102,7 +140,8 @@ async def main():
|
||||
|
||||
# Create transcript processor and handler
|
||||
transcript = TranscriptProcessor()
|
||||
transcript_handler = TranscriptHandler()
|
||||
transcript_handler = TranscriptHandler() # Output to log only
|
||||
# transcript_handler = TranscriptHandler(output_file="transcript.txt") # Output to file and log
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
@@ -113,8 +152,8 @@ async def main():
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
transcript.assistant(), # Assistant transcripts
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
@@ -140,7 +179,8 @@ async def main():
|
||||
|
||||
@transport.event_handler("on_participant_left")
|
||||
async def on_participant_left(transport, participant, reason):
|
||||
await task.queue_frame(EndFrame())
|
||||
# Stop the pipeline immediately when the participant leaves
|
||||
await task.queue_frame(CancelFrame())
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
|
||||
156
examples/foundational/32-double-room.py
Normal file
156
examples/foundational/32-double-room.py
Normal file
@@ -0,0 +1,156 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.frames.frames import BotSpeakingFrame, EndFrame, Frame, TextFrame, TTSSpeakFrame
|
||||
from pipecat.observers.base_observer import BaseObserver
|
||||
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.filters.function_filter import FunctionFilter
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.transports.services.daily import DailyOutputTransport, DailyParams, DailyTransport
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
class DebugObserver(BaseObserver):
|
||||
"""Observer to log interruptions and bot speaking events to the console.
|
||||
|
||||
Logs all frame instances of:
|
||||
- StartInterruptionFrame
|
||||
- BotStartedSpeakingFrame
|
||||
- BotStoppedSpeakingFrame
|
||||
|
||||
This allows you to see the frame flow from processor to processor through the pipeline for these frames.
|
||||
Log format: [EVENT TYPE]: [source processor] → [destination processor] at [timestamp]s
|
||||
"""
|
||||
|
||||
async def on_push_frame(
|
||||
self,
|
||||
src: FrameProcessor,
|
||||
dst: FrameProcessor,
|
||||
frame: Frame,
|
||||
direction: FrameDirection,
|
||||
timestamp: int,
|
||||
):
|
||||
arrow = "→" if direction == FrameDirection.DOWNSTREAM else "←"
|
||||
# Convert timestamp to seconds for readability
|
||||
time_sec = timestamp / 1_000_000_000
|
||||
|
||||
if isinstance(frame, BotSpeakingFrame):
|
||||
return
|
||||
|
||||
if isinstance(dst, DailyOutputTransport):
|
||||
logger.debug(f"{frame} {src} {arrow} {dst} at {time_sec:.2f}s")
|
||||
|
||||
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, _) = await configure(session)
|
||||
|
||||
transport1 = DailyTransport(
|
||||
"https://hush.daily.co/sip",
|
||||
None,
|
||||
"Don't Do Anything",
|
||||
DailyParams(audio_out_enabled=True),
|
||||
)
|
||||
|
||||
transport2 = DailyTransport(
|
||||
"https://hush.daily.co/demo",
|
||||
None,
|
||||
"Summarize Call",
|
||||
DailyParams(audio_out_enabled=True),
|
||||
)
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
|
||||
)
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
async def true_filter(frame) -> bool:
|
||||
return True
|
||||
|
||||
async def false_filter(frame) -> bool:
|
||||
return False
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport1.input(),
|
||||
transport2.input(),
|
||||
ParallelPipeline(
|
||||
[transport1.output()],
|
||||
[tts, transport2.output()],
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
PipelineParams(
|
||||
allow_interruptions=True,
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
report_only_initial_ttfb=True,
|
||||
observers=[DebugObserver()],
|
||||
),
|
||||
)
|
||||
|
||||
# Register an event handler so we can play the audio when the
|
||||
# participant joins.
|
||||
@transport1.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
participant_name = participant.get("info", {}).get("userName", "")
|
||||
logger.info(f"-- {participant_name} joined transport1")
|
||||
|
||||
def get_call_summary():
|
||||
"""In a real app this would be a call to a database or API."""
|
||||
# Randomly choose between two options
|
||||
message = random.choice(
|
||||
[
|
||||
"Alice needs help finding her customer record.",
|
||||
"Bob is calling about his lost password.",
|
||||
]
|
||||
)
|
||||
|
||||
return message
|
||||
|
||||
@transport2.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
participant_name = participant.get("info", {}).get("userName", "")
|
||||
logger.info(f"-- {participant_name} joined transport2")
|
||||
call_summary = get_call_summary()
|
||||
await task.queue_frames(
|
||||
[
|
||||
TTSSpeakFrame(
|
||||
f"Hi {participant_name}! Here's the summary of the call: {call_summary}"
|
||||
),
|
||||
EndFrame(),
|
||||
]
|
||||
)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -67,8 +67,8 @@ talking_frame = SpriteFrame(images=sprites)
|
||||
|
||||
|
||||
class TalkingAnimation(FrameProcessor):
|
||||
"""This class starts a talking animation when it receives an first AudioFrame,
|
||||
and then returns to a "quiet" sprite when it sees a TTSStoppedFrame.
|
||||
"""This class starts a talking animation when it receives an first BotStartedSpeakingFrame,
|
||||
and then returns to a "quiet" sprite when it sees a BotStoppedSpeakingFrame.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
|
||||
@@ -66,7 +66,7 @@ The build UI files can be found in `frontend/out`
|
||||
|
||||
Start the API / bot manager:
|
||||
|
||||
`python src/bot_runner.py`
|
||||
`python src/bot_runner.py --host localhost`
|
||||
|
||||
If you'd like to run a custom domain or port:
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ ELEVENLABS_API_KEY=
|
||||
ELEVENLABS_VOICE_ID=
|
||||
FAL_KEY=
|
||||
OPENAI_API_KEY=
|
||||
GOOGLE_API_KEY=
|
||||
|
||||
ENV= # dev | production
|
||||
RUN_AS_VM= # Set this if you want to run bots on process (not launch a new VM)
|
||||
@@ -1,4 +1,4 @@
|
||||
import React, { useState } from "react";
|
||||
import React, { useState, useEffect } from "react";
|
||||
import {
|
||||
useDaily,
|
||||
useParticipantIds,
|
||||
@@ -33,7 +33,9 @@ const Story: React.FC<StoryProps> = ({ handleLeave }) => {
|
||||
setTimeout(() => daily.setLocalAudio(true), 500);
|
||||
setStoryState("user");
|
||||
} else {
|
||||
daily.setLocalAudio(false);
|
||||
// Uncomment the next line to mute the mic while the
|
||||
// assistant it talking. Leave it commented to allow for interruptions
|
||||
// daily.setLocalAudio(false);
|
||||
setStoryState("assistant");
|
||||
}
|
||||
},
|
||||
@@ -58,7 +60,7 @@ const Story: React.FC<StoryProps> = ({ handleLeave }) => {
|
||||
{participantIds.length >= 1 ? (
|
||||
<VideoTile
|
||||
sessionId={participantIds[0]}
|
||||
inactive={storyState === "user"}
|
||||
inactive={false}
|
||||
/>
|
||||
) : (
|
||||
<span className="p-3 rounded-full bg-gray-900/60 animate-pulse">
|
||||
@@ -71,7 +73,7 @@ const Story: React.FC<StoryProps> = ({ handleLeave }) => {
|
||||
)}
|
||||
<DailyAudio />
|
||||
</div>
|
||||
<UserInputIndicator active={storyState === "user"} />
|
||||
<UserInputIndicator active={true} />
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -43,25 +43,8 @@
|
||||
transition: opacity 0.5s ease;
|
||||
}
|
||||
|
||||
|
||||
@keyframes pulse {
|
||||
0% {
|
||||
outline-width: 6px;
|
||||
@apply outline-teal-500/10;
|
||||
}
|
||||
50% {
|
||||
outline-width: 24px;
|
||||
@apply outline-teal-500/50;
|
||||
}
|
||||
100% {
|
||||
outline-width: 6px;
|
||||
@apply outline-teal-500/10;
|
||||
}
|
||||
}
|
||||
|
||||
.micIconActive{
|
||||
@apply bg-teal-950 border-teal-500 outline-teal-500/20;
|
||||
animation: pulse 2s infinite ease-in-out;
|
||||
}
|
||||
|
||||
.micIconActive svg{
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import React, { useState, useEffect } from "react";
|
||||
import React, { useState, useEffect, useRef } from "react";
|
||||
|
||||
import { useAppMessage } from "@daily-co/daily-react";
|
||||
import { DailyEventObjectAppMessage } from "@daily-co/daily-js";
|
||||
@@ -13,12 +13,31 @@ interface Props {
|
||||
|
||||
export default function UserInputIndicator({ active }: Props) {
|
||||
const [transcription, setTranscription] = useState<string[]>([]);
|
||||
const timeoutRef = useRef<NodeJS.Timeout>();
|
||||
|
||||
const resetTimeout = () => {
|
||||
if (timeoutRef.current) {
|
||||
clearTimeout(timeoutRef.current);
|
||||
}
|
||||
timeoutRef.current = setTimeout(() => {
|
||||
setTranscription([]);
|
||||
}, 5000);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (timeoutRef.current) {
|
||||
clearTimeout(timeoutRef.current);
|
||||
}
|
||||
};
|
||||
}, []);
|
||||
|
||||
useAppMessage({
|
||||
onAppMessage: (e: DailyEventObjectAppMessage<any>) => {
|
||||
if (e.fromId && e.fromId === "transcription") {
|
||||
if (e.data.user_id === "" && e.data.is_final) {
|
||||
setTranscription((t) => [...t, ...e.data.text.split(" ")]);
|
||||
resetTimeout();
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
@@ -2,4 +2,4 @@ async_timeout
|
||||
fastapi
|
||||
uvicorn
|
||||
python-dotenv
|
||||
pipecat-ai[daily,elevenlabs,openai,fal]
|
||||
pipecat-ai[daily,openai,fal,google,cartesia]
|
||||
|
||||
@@ -13,16 +13,23 @@ import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from processors import StoryImageProcessor, StoryProcessor
|
||||
from prompts import CUE_USER_TURN, LLM_BASE_PROMPT, LLM_INTRO_PROMPT
|
||||
from prompts import CUE_USER_TURN, LLM_BASE_PROMPT
|
||||
from utils.helpers import load_images, load_sounds
|
||||
|
||||
from pipecat.frames.frames import EndFrame, LLMMessagesFrame, StopTaskFrame
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import EndFrame, StopTaskFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.logger import FrameLogger
|
||||
from pipecat.services.cartesia import CartesiaHttpTTSService, CartesiaTTSService
|
||||
from pipecat.services.elevenlabs import ElevenLabsTTSService
|
||||
from pipecat.services.fal import FalImageGenService
|
||||
from pipecat.services.google import GoogleLLMService
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.transports.services.daily import (
|
||||
DailyParams,
|
||||
@@ -53,6 +60,7 @@ async def main(room_url, token=None):
|
||||
camera_out_width=768,
|
||||
camera_out_height=768,
|
||||
transcription_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
vad_enabled=True,
|
||||
),
|
||||
)
|
||||
@@ -61,11 +69,10 @@ async def main(room_url, token=None):
|
||||
|
||||
# -------------- Services --------------- #
|
||||
|
||||
llm_service = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
|
||||
llm_service = GoogleLLMService(api_key=os.getenv("GOOGLE_API_KEY"))
|
||||
|
||||
tts_service = ElevenLabsTTSService(
|
||||
api_key=os.getenv("ELEVENLABS_API_KEY"),
|
||||
voice_id=os.getenv("ELEVENLABS_VOICE_ID"),
|
||||
api_key=os.getenv("ELEVENLABS_API_KEY"), voice_id=os.getenv("ELEVENLABS_VOICE_ID")
|
||||
)
|
||||
|
||||
fal_service_params = FalImageGenService.InputParams(
|
||||
@@ -74,7 +81,7 @@ async def main(room_url, token=None):
|
||||
|
||||
fal_service = FalImageGenService(
|
||||
aiohttp_session=session,
|
||||
model="fal-ai/fast-lightning-sdxl",
|
||||
model="fal-ai/stable-diffusion-v35-medium",
|
||||
params=fal_service_params,
|
||||
key=os.getenv("FAL_KEY"),
|
||||
)
|
||||
@@ -97,35 +104,8 @@ async def main(room_url, token=None):
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
# The intro pipeline is used to start
|
||||
# the story (as per LLM_INTRO_PROMPT)
|
||||
intro_pipeline = Pipeline([llm_service, tts_service, transport.output()])
|
||||
|
||||
intro_task = PipelineTask(intro_pipeline)
|
||||
|
||||
logger.debug("Waiting for participant...")
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
logger.debug("Participant joined, storytime commence!")
|
||||
await transport.capture_participant_transcription(participant["id"])
|
||||
await intro_task.queue_frames(
|
||||
[
|
||||
images["book1"],
|
||||
LLMMessagesFrame([LLM_INTRO_PROMPT]),
|
||||
DailyTransportMessageFrame(CUE_USER_TURN),
|
||||
sounds["listening"],
|
||||
images["book2"],
|
||||
StopTaskFrame(),
|
||||
]
|
||||
)
|
||||
|
||||
# We run the intro pipeline. This will start the transport. The intro
|
||||
# task will exit after StopTaskFrame is processed.
|
||||
await runner.run(intro_task)
|
||||
|
||||
# The main story pipeline is used to continue the story based on user
|
||||
# input.
|
||||
main_pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
@@ -139,11 +119,32 @@ async def main(room_url, token=None):
|
||||
]
|
||||
)
|
||||
|
||||
main_task = PipelineTask(main_pipeline)
|
||||
main_task = PipelineTask(
|
||||
main_pipeline,
|
||||
PipelineParams(
|
||||
allow_interruptions=True,
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
report_only_initial_ttfb=True,
|
||||
),
|
||||
)
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
logger.debug("Participant joined, storytime commence!")
|
||||
await transport.capture_participant_transcription(participant["id"])
|
||||
await main_task.queue_frames(
|
||||
[
|
||||
images["book1"],
|
||||
context_aggregator.user().get_context_frame(),
|
||||
DailyTransportMessageFrame(CUE_USER_TURN),
|
||||
# sounds["listening"],
|
||||
images["book2"],
|
||||
]
|
||||
)
|
||||
|
||||
@transport.event_handler("on_participant_left")
|
||||
async def on_participant_left(transport, participant, reason):
|
||||
await intro_task.queue_frame(EndFrame())
|
||||
await main_task.queue_frame(EndFrame())
|
||||
|
||||
@transport.event_handler("on_call_state_updated")
|
||||
|
||||
@@ -114,7 +114,7 @@ async def start_bot(request: Request) -> JSONResponse:
|
||||
else:
|
||||
try:
|
||||
subprocess.Popen(
|
||||
[f"python3 -m bot -u {room.url} -t {token}"],
|
||||
[f"python -m bot -u {room.url} -t {token}"],
|
||||
shell=True,
|
||||
bufsize=1,
|
||||
cwd=os.path.dirname(os.path.abspath(__file__)),
|
||||
@@ -175,7 +175,7 @@ async def virtualize_bot(room_url: str, token: str):
|
||||
image = data[0]["config"]["image"]
|
||||
|
||||
# Machine configuration
|
||||
cmd = f"python3 src/bot.py -u {room_url} -t {token}"
|
||||
cmd = f"python src/bot.py -u {room_url} -t {token}"
|
||||
cmd = cmd.split()
|
||||
worker_props = {
|
||||
"config": {
|
||||
@@ -215,7 +215,7 @@ async def virtualize_bot(room_url: str, token: str):
|
||||
if __name__ == "__main__":
|
||||
# Check environment variables
|
||||
required_env_vars = [
|
||||
"OPENAI_API_KEY",
|
||||
"GOOGLE_API_KEY",
|
||||
"DAILY_API_KEY",
|
||||
"FAL_KEY",
|
||||
"ELEVENLABS_VOICE_ID",
|
||||
|
||||
@@ -37,8 +37,7 @@ class StoryPromptFrame(TextFrame):
|
||||
|
||||
|
||||
class StoryImageProcessor(FrameProcessor):
|
||||
"""
|
||||
Processor for image prompt frames that will be sent to the FAL service.
|
||||
"""Processor for image prompt frames that will be sent to the FAL service.
|
||||
|
||||
This processor is responsible for consuming frames of type `StoryImageFrame`.
|
||||
It processes them by passing it to the FAL service.
|
||||
@@ -68,8 +67,7 @@ class StoryImageProcessor(FrameProcessor):
|
||||
|
||||
|
||||
class StoryProcessor(FrameProcessor):
|
||||
"""
|
||||
Primary frame processor. It takes the frames generated by the LLM
|
||||
"""Primary frame processor. It takes the frames generated by the LLM
|
||||
and processes them into image prompts and story pages (sentences).
|
||||
For a clearer picture of how this works, reference prompts.py
|
||||
|
||||
@@ -97,44 +95,10 @@ class StoryProcessor(FrameProcessor):
|
||||
await self.push_frame(sounds["talking"])
|
||||
|
||||
elif isinstance(frame, TextFrame):
|
||||
# We want to look for sentence breaks in the text
|
||||
# but since TextFrames are streamed from the LLM
|
||||
# we need to keep a buffer of the text we've seen so far
|
||||
# Add new text to the buffer
|
||||
self._text += frame.text
|
||||
|
||||
# IMAGE PROMPT
|
||||
# Looking for: < [image prompt] > in the LLM response
|
||||
# We prompted our LLM to add an image prompt in the response
|
||||
# so we use regex matching to find it and yield a StoryImageFrame
|
||||
if re.search(r"<.*?>", self._text):
|
||||
if not re.search(r"<.*?>.*?>", self._text):
|
||||
# Pass any frames until we have a closing bracket
|
||||
# otherwise the image prompt will be passed to TTS
|
||||
pass
|
||||
# Extract the image prompt from the text using regex
|
||||
image_prompt = re.search(r"<(.*?)>", self._text).group(1)
|
||||
# Remove the image prompt from the text
|
||||
self._text = re.sub(r"<.*?>", "", self._text, count=1)
|
||||
# Process the image prompt frame
|
||||
await self.push_frame(StoryImageFrame(image_prompt))
|
||||
|
||||
# STORY PAGE
|
||||
# Looking for: [break] in the LLM response
|
||||
# We prompted our LLM to add a [break] after each sentence
|
||||
# so we use regex matching to find it in the LLM response
|
||||
if re.search(r".*\[[bB]reak\].*", self._text):
|
||||
# Remove the [break] token from the text
|
||||
# so it isn't spoken out loud by the TTS
|
||||
self._text = re.sub(r"\[[bB]reak\]", "", self._text, flags=re.IGNORECASE)
|
||||
self._text = self._text.replace("\n", " ")
|
||||
if len(self._text) > 2:
|
||||
# Append the sentence to the story
|
||||
self._story.append(self._text)
|
||||
await self.push_frame(StoryPageFrame(self._text))
|
||||
# Assert that it's the LLMs turn, until we're finished
|
||||
await self.push_frame(DailyTransportMessageFrame(CUE_ASSISTANT_TURN))
|
||||
# Clear the buffer
|
||||
self._text = ""
|
||||
# Process any complete patterns in the order they appear
|
||||
await self.process_text_content()
|
||||
|
||||
# End of a full LLM response
|
||||
# Driven by the prompt, the LLM should have asked the user for input
|
||||
@@ -150,3 +114,38 @@ class StoryProcessor(FrameProcessor):
|
||||
# Anything that is not a TextFrame pass through
|
||||
else:
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def process_text_content(self):
|
||||
"""Process text content in order of appearance, handling both image prompts and story breaks."""
|
||||
while True:
|
||||
# Find the first occurrence of each pattern
|
||||
image_match = re.search(r"<(.*?)>", self._text)
|
||||
break_match = re.search(r"\[[bB]reak\]", self._text)
|
||||
|
||||
# If neither pattern is found, we're done processing
|
||||
if not image_match and not break_match:
|
||||
break
|
||||
|
||||
# Find which pattern comes first in the text
|
||||
image_pos = image_match.start() if image_match else float("inf")
|
||||
break_pos = break_match.start() if break_match else float("inf")
|
||||
|
||||
if image_pos < break_pos:
|
||||
# Process image prompt first
|
||||
image_prompt = image_match.group(1)
|
||||
# Remove the image prompt from the text
|
||||
self._text = self._text[: image_match.start()] + self._text[image_match.end() :]
|
||||
await self.push_frame(StoryImageFrame(image_prompt))
|
||||
else:
|
||||
# Process story break first
|
||||
parts = re.split(r"\[[bB]reak\]", self._text, flags=re.IGNORECASE, maxsplit=1)
|
||||
before_break = parts[0].replace("\n", " ").strip()
|
||||
|
||||
if len(before_break) > 2:
|
||||
self._story.append(before_break)
|
||||
await self.push_frame(StoryPageFrame(before_break))
|
||||
# await self.push_frame(sounds["ding"])
|
||||
await self.push_frame(DailyTransportMessageFrame(CUE_ASSISTANT_TURN))
|
||||
|
||||
# Keep the remainder (if any) in the buffer
|
||||
self._text = parts[1].strip() if len(parts) > 1 else ""
|
||||
|
||||
@@ -1,31 +1,34 @@
|
||||
LLM_INTRO_PROMPT = {
|
||||
"role": "system",
|
||||
"content": "You are a creative storyteller who loves to tell whimsical, fantastical stories. \
|
||||
Your goal is to craft an engaging and fun story. \
|
||||
Start by asking the user what kind of story they'd like to hear. Don't provide any examples. \
|
||||
Keep your response to only a few sentences.",
|
||||
}
|
||||
|
||||
|
||||
LLM_BASE_PROMPT = {
|
||||
"role": "system",
|
||||
"content": "You are a creative storyteller who loves tell whimsical, fantastical stories. \
|
||||
Your goal is to craft an engaging and fun story. \
|
||||
Keep all responses short and no more than a few sentences. Include [break] after each sentence of the story. \
|
||||
\
|
||||
Start each sentence with an image prompt, wrapped in triangle braces, that I can use to generate an illustration representing the upcoming scene. \
|
||||
Image prompts should always be wrapped in triangle braces, like this: <image prompt goes here>. \
|
||||
You should provide as much descriptive detail in your image prompt as you can to help recreate the current scene depicted by the sentence. \
|
||||
For any recurring characters, you should provide a description of them in the image prompt each time, for example: <a brown fluffy dog ...>. \
|
||||
Please do not include any character names in the image prompts, just their descriptions. \
|
||||
Image prompts should focus on key visual attributes of all characters each time, for example <a brown fluffy dog and the tiny red cat ...>. \
|
||||
Please use the following structure for your image prompts: characters, setting, action, and mood. \
|
||||
Image prompts should be less than 150-200 characters and start in lowercase. \
|
||||
\
|
||||
Responses should use the format: <...> story sentence [break] <...> story sentence [break] ... \
|
||||
After each response, ask me how I'd like the story to continue and wait for my input. \
|
||||
Please ensure your responses are less than 3-4 sentences long. \
|
||||
Please refrain from using any explicit language or content. Do not tell scary stories.",
|
||||
"content": """You are a creative storyteller who loves tell whimsical, fantastical stories.
|
||||
Your goal is to craft an engaging and fun story.
|
||||
Keep all responses short and no more than a few sentences.
|
||||
Start by asking the user what kind of story they'd like to hear. Don't provide any examples.
|
||||
After they've answered the question, start telling the story. Include [break] after each sentence of the story.
|
||||
|
||||
Start each sentence with an image prompt, wrapped in triangle braces, that I can use to generate an illustration representing the upcoming scene.
|
||||
Image prompts should always be wrapped in triangle braces, like this: <image prompt goes here>.
|
||||
You should provide as much descriptive detail in your image prompt as you can to help recreate the current scene depicted by the sentence.
|
||||
For any recurring characters, you should provide a description of them in the image prompt each time, for example: <a brown fluffy dog ...>.
|
||||
Please do not include any character names in the image prompts, just their descriptions.
|
||||
Image prompts should focus on key visual attributes of all characters each time, for example <a brown fluffy dog and the tiny red cat ...>.
|
||||
Please use the following structure for your image prompts: characters, setting, action, and mood.
|
||||
Image prompts should be less than 150-200 characters and start in lowercase.
|
||||
|
||||
STORY SENTENCE OUTPUT FORMAT:
|
||||
<image description 1>
|
||||
story sentence 1 [break]
|
||||
<image description 2>
|
||||
story sentence 2 [break]
|
||||
<image description 3>
|
||||
story sentence 3 [break]
|
||||
How would you like the story to continue?
|
||||
END OF EXAMPLE OUTPUT
|
||||
|
||||
Generate three story sentences, then ask what should happen next and wait for my input. You can propose an idea for how the story should proceed, but make sure to tell me I can suggest whatever I want. \
|
||||
Please ensure your responses are less than 5 sentences long. \
|
||||
Please refrain from using any explicit language or content. Do not tell scary stories.
|
||||
Once you've started telling the story, EVERY RESPONSE should follow the story sentence output format. It is VERY IMPORTANT that you continue to include <image descriptions> and [break] between story sentences. DO NOT RESPOND without image descriptions and break tags.""",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ dependencies = [
|
||||
"protobuf~=5.29.3",
|
||||
"pydantic~=2.10.5",
|
||||
"pyloudnorm~=0.1.1",
|
||||
"resampy~=0.4.3"
|
||||
"soxr~=0.5.0"
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
@@ -85,7 +85,13 @@ openrouter = [ "openai~=1.59.6" ]
|
||||
where = ["src"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
addopts = "--verbose"
|
||||
testpaths = ["tests"]
|
||||
pythonpath = ["src"]
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
filterwarnings = [
|
||||
"ignore:'audioop' is deprecated:DeprecationWarning",
|
||||
]
|
||||
|
||||
[tool.setuptools_scm]
|
||||
local_scheme = "no-local-version"
|
||||
|
||||
@@ -8,14 +8,14 @@ import audioop
|
||||
|
||||
import numpy as np
|
||||
import pyloudnorm as pyln
|
||||
import resampy
|
||||
import soxr
|
||||
|
||||
|
||||
def resample_audio(audio: bytes, original_rate: int, target_rate: int) -> bytes:
|
||||
if original_rate == target_rate:
|
||||
return audio
|
||||
audio_data = np.frombuffer(audio, dtype=np.int16)
|
||||
resampled_audio = resampy.resample(audio_data, original_rate, target_rate)
|
||||
resampled_audio = soxr.resample(audio_data, original_rate, target_rate)
|
||||
return resampled_audio.astype(np.int16).tobytes()
|
||||
|
||||
|
||||
|
||||
56
src/pipecat/pipeline/base_task.py
Normal file
56
src/pipecat/pipeline/base_task.py
Normal file
@@ -0,0 +1,56 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import AsyncIterable, Iterable
|
||||
|
||||
from pipecat.frames.frames import Frame
|
||||
|
||||
|
||||
class BaseTask(ABC):
|
||||
@abstractmethod
|
||||
def has_finished(self) -> bool:
|
||||
"""Indicates whether the tasks has finished. That is, all processors
|
||||
have stopped.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def stop_when_done(self):
|
||||
"""This is a helper function that sends an EndFrame to the pipeline in
|
||||
order to stop the task after everything in it has been processed.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def cancel(self):
|
||||
"""
|
||||
Stops the running pipeline immediately.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def run(self):
|
||||
"""
|
||||
Starts running the given pipeline.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def queue_frame(self, frame: Frame):
|
||||
"""
|
||||
Queue a frame to be pushed down the pipeline.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def queue_frames(self, frames: Iterable[Frame] | AsyncIterable[Frame]):
|
||||
"""
|
||||
Queues multiple frames to be pushed down the pipeline.
|
||||
"""
|
||||
pass
|
||||
@@ -27,6 +27,7 @@ from pipecat.frames.frames import (
|
||||
from pipecat.metrics.metrics import ProcessingMetricsData, TTFBMetricsData
|
||||
from pipecat.observers.base_observer import BaseObserver
|
||||
from pipecat.pipeline.base_pipeline import BasePipeline
|
||||
from pipecat.pipeline.base_task import BaseTask
|
||||
from pipecat.pipeline.task_observer import TaskObserver
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.utils import obj_count, obj_id
|
||||
@@ -45,6 +46,7 @@ class PipelineParams(BaseModel):
|
||||
send_initial_empty_metrics: bool = True
|
||||
report_only_initial_ttfb: bool = False
|
||||
observers: List[BaseObserver] = []
|
||||
heartbeats_period_secs: float = HEARTBEAT_SECONDS
|
||||
|
||||
|
||||
class Source(FrameProcessor):
|
||||
@@ -85,7 +87,7 @@ class Sink(FrameProcessor):
|
||||
await self._down_queue.put(frame)
|
||||
|
||||
|
||||
class PipelineTask:
|
||||
class PipelineTask(BaseTask):
|
||||
def __init__(
|
||||
self,
|
||||
pipeline: BasePipeline,
|
||||
@@ -121,7 +123,7 @@ class PipelineTask:
|
||||
|
||||
self._observer = TaskObserver(params.observers)
|
||||
|
||||
def has_finished(self):
|
||||
def has_finished(self) -> bool:
|
||||
"""Indicates whether the tasks has finished. That is, all processors
|
||||
have stopped.
|
||||
|
||||
@@ -315,7 +317,7 @@ class PipelineTask:
|
||||
|
||||
async def _heartbeat_push_handler(self):
|
||||
"""
|
||||
This tasks pushes a heartbeat frame every HEARTBEAT_SECONDS.
|
||||
This tasks pushes a heartbeat frame every heartbeat period.
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
@@ -323,7 +325,7 @@ class PipelineTask:
|
||||
# task will just stop waiting for the pipeline to finish not
|
||||
# allowing more frames to be pushed.
|
||||
await self._source.queue_frame(HeartbeatFrame(timestamp=self._clock.get_time()))
|
||||
await asyncio.sleep(HEARTBEAT_SECONDS)
|
||||
await asyncio.sleep(self._params.heartbeats_period_secs)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class FrameFilter(FrameProcessor):
|
||||
def __init__(self, types: Tuple[Type[Frame]]):
|
||||
def __init__(self, types: Tuple[Type[Frame], ...]):
|
||||
super().__init__()
|
||||
self._types = types
|
||||
|
||||
|
||||
@@ -186,7 +186,7 @@ class FrameProcessor:
|
||||
self.__should_block_frames = True
|
||||
|
||||
async def resume_processing_frames(self):
|
||||
logger.trace("f{self}: resuming frame processing")
|
||||
logger.trace(f"{self}: resuming frame processing")
|
||||
self.__input_event.set()
|
||||
self.__should_block_frames = False
|
||||
|
||||
@@ -293,8 +293,7 @@ class FrameProcessor:
|
||||
await self.__input_frame_task
|
||||
|
||||
async def __input_frame_task_handler(self):
|
||||
running = True
|
||||
while running:
|
||||
while True:
|
||||
try:
|
||||
if self.__should_block_frames:
|
||||
logger.trace(f"{self}: frame processing paused")
|
||||
@@ -311,8 +310,6 @@ class FrameProcessor:
|
||||
if callback:
|
||||
await callback(self, frame, direction)
|
||||
|
||||
running = not isinstance(frame, EndFrame)
|
||||
|
||||
self.__input_queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
logger.trace(f"{self}: cancelled input task")
|
||||
@@ -330,12 +327,10 @@ class FrameProcessor:
|
||||
await self.__push_frame_task
|
||||
|
||||
async def __push_frame_task_handler(self):
|
||||
running = True
|
||||
while running:
|
||||
while True:
|
||||
try:
|
||||
(frame, direction) = await self.__push_queue.get()
|
||||
await self.__internal_push_frame(frame, direction)
|
||||
running = not isinstance(frame, EndFrame)
|
||||
self.__push_queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
logger.trace(f"{self}: cancelled push task")
|
||||
|
||||
@@ -62,6 +62,9 @@ from pipecat.utils.string import match_endofsentence
|
||||
|
||||
RTVI_PROTOCOL_VERSION = "0.3.0"
|
||||
|
||||
RTVI_MESSAGE_LABEL = "rtvi-ai"
|
||||
RTVIMessageLiteral = Literal["rtvi-ai"]
|
||||
|
||||
ActionResult = Union[bool, int, float, str, list, dict]
|
||||
|
||||
|
||||
@@ -154,7 +157,7 @@ class RTVIActionFrame(DataFrame):
|
||||
|
||||
|
||||
class RTVIMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: str
|
||||
id: str
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
@@ -170,7 +173,7 @@ class RTVIErrorResponseData(BaseModel):
|
||||
|
||||
|
||||
class RTVIErrorResponse(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["error-response"] = "error-response"
|
||||
id: str
|
||||
data: RTVIErrorResponseData
|
||||
@@ -182,7 +185,7 @@ class RTVIErrorData(BaseModel):
|
||||
|
||||
|
||||
class RTVIError(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["error"] = "error"
|
||||
data: RTVIErrorData
|
||||
|
||||
@@ -192,7 +195,7 @@ class RTVIDescribeConfigData(BaseModel):
|
||||
|
||||
|
||||
class RTVIDescribeConfig(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["config-available"] = "config-available"
|
||||
id: str
|
||||
data: RTVIDescribeConfigData
|
||||
@@ -203,14 +206,14 @@ class RTVIDescribeActionsData(BaseModel):
|
||||
|
||||
|
||||
class RTVIDescribeActions(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["actions-available"] = "actions-available"
|
||||
id: str
|
||||
data: RTVIDescribeActionsData
|
||||
|
||||
|
||||
class RTVIConfigResponse(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["config"] = "config"
|
||||
id: str
|
||||
data: RTVIConfig
|
||||
@@ -221,7 +224,7 @@ class RTVIActionResponseData(BaseModel):
|
||||
|
||||
|
||||
class RTVIActionResponse(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["action-response"] = "action-response"
|
||||
id: str
|
||||
data: RTVIActionResponseData
|
||||
@@ -233,7 +236,7 @@ class RTVIBotReadyData(BaseModel):
|
||||
|
||||
|
||||
class RTVIBotReady(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["bot-ready"] = "bot-ready"
|
||||
id: str
|
||||
data: RTVIBotReadyData
|
||||
@@ -246,7 +249,7 @@ class RTVILLMFunctionCallMessageData(BaseModel):
|
||||
|
||||
|
||||
class RTVILLMFunctionCallMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["llm-function-call"] = "llm-function-call"
|
||||
data: RTVILLMFunctionCallMessageData
|
||||
|
||||
@@ -256,7 +259,7 @@ class RTVILLMFunctionCallStartMessageData(BaseModel):
|
||||
|
||||
|
||||
class RTVILLMFunctionCallStartMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["llm-function-call-start"] = "llm-function-call-start"
|
||||
data: RTVILLMFunctionCallStartMessageData
|
||||
|
||||
@@ -269,22 +272,22 @@ class RTVILLMFunctionCallResultData(BaseModel):
|
||||
|
||||
|
||||
class RTVIBotLLMStartedMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["bot-llm-started"] = "bot-llm-started"
|
||||
|
||||
|
||||
class RTVIBotLLMStoppedMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["bot-llm-stopped"] = "bot-llm-stopped"
|
||||
|
||||
|
||||
class RTVIBotTTSStartedMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["bot-tts-started"] = "bot-tts-started"
|
||||
|
||||
|
||||
class RTVIBotTTSStoppedMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["bot-tts-stopped"] = "bot-tts-stopped"
|
||||
|
||||
|
||||
@@ -293,19 +296,19 @@ class RTVITextMessageData(BaseModel):
|
||||
|
||||
|
||||
class RTVIBotTranscriptionMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["bot-transcription"] = "bot-transcription"
|
||||
data: RTVITextMessageData
|
||||
|
||||
|
||||
class RTVIBotLLMTextMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["bot-llm-text"] = "bot-llm-text"
|
||||
data: RTVITextMessageData
|
||||
|
||||
|
||||
class RTVIBotTTSTextMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["bot-tts-text"] = "bot-tts-text"
|
||||
data: RTVITextMessageData
|
||||
|
||||
@@ -317,7 +320,7 @@ class RTVIAudioMessageData(BaseModel):
|
||||
|
||||
|
||||
class RTVIBotTTSAudioMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["bot-tts-audio"] = "bot-tts-audio"
|
||||
data: RTVIAudioMessageData
|
||||
|
||||
@@ -330,39 +333,39 @@ class RTVIUserTranscriptionMessageData(BaseModel):
|
||||
|
||||
|
||||
class RTVIUserTranscriptionMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["user-transcription"] = "user-transcription"
|
||||
data: RTVIUserTranscriptionMessageData
|
||||
|
||||
|
||||
class RTVIUserLLMTextMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["user-llm-text"] = "user-llm-text"
|
||||
data: RTVITextMessageData
|
||||
|
||||
|
||||
class RTVIUserStartedSpeakingMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["user-started-speaking"] = "user-started-speaking"
|
||||
|
||||
|
||||
class RTVIUserStoppedSpeakingMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["user-stopped-speaking"] = "user-stopped-speaking"
|
||||
|
||||
|
||||
class RTVIBotStartedSpeakingMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["bot-started-speaking"] = "bot-started-speaking"
|
||||
|
||||
|
||||
class RTVIBotStoppedSpeakingMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["bot-stopped-speaking"] = "bot-stopped-speaking"
|
||||
|
||||
|
||||
class RTVIMetricsMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
label: RTVIMessageLiteral = RTVI_MESSAGE_LABEL
|
||||
type: Literal["metrics"] = "metrics"
|
||||
data: Mapping[str, Any]
|
||||
|
||||
@@ -875,7 +878,11 @@ class RTVIProcessor(FrameProcessor):
|
||||
|
||||
async def _handle_transport_message(self, frame: TransportMessageUrgentFrame):
|
||||
try:
|
||||
message = RTVIMessage.model_validate(frame.message)
|
||||
transport_message = frame.message
|
||||
if transport_message.get("label") != RTVI_MESSAGE_LABEL:
|
||||
logger.warning(f"Ignoring not RTVI message: {transport_message}")
|
||||
return
|
||||
message = RTVIMessage.model_validate(transport_message)
|
||||
await self._message_queue.put(message)
|
||||
except ValidationError as e:
|
||||
await self.send_error(f"Invalid RTVI transport message: {e}")
|
||||
|
||||
@@ -4,17 +4,23 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotStoppedSpeakingFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
OpenAILLMContextAssistantTimestampFrame,
|
||||
StartInterruptionFrame,
|
||||
TranscriptionFrame,
|
||||
TranscriptionMessage,
|
||||
TranscriptionUpdateFrame,
|
||||
TTSTextFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
|
||||
class BaseTranscriptProcessor(FrameProcessor):
|
||||
@@ -64,89 +70,74 @@ class UserTranscriptProcessor(BaseTranscriptProcessor):
|
||||
|
||||
|
||||
class AssistantTranscriptProcessor(BaseTranscriptProcessor):
|
||||
"""Processes assistant LLM context frames into timestamped conversation messages."""
|
||||
"""Processes assistant TTS text frames into timestamped conversation messages.
|
||||
|
||||
This processor aggregates TTS text frames into complete utterances and emits them as
|
||||
transcript messages. Utterances are completed when:
|
||||
- The bot stops speaking (BotStoppedSpeakingFrame)
|
||||
- The bot is interrupted (StartInterruptionFrame)
|
||||
- The pipeline ends (EndFrame)
|
||||
|
||||
Attributes:
|
||||
_current_text_parts: List of text fragments being aggregated for current utterance
|
||||
_aggregation_start_time: Timestamp when the current utterance began
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Initialize processor with empty message stores."""
|
||||
"""Initialize processor with aggregation state."""
|
||||
super().__init__(**kwargs)
|
||||
self._pending_assistant_messages: List[TranscriptionMessage] = []
|
||||
self._current_text_parts: List[str] = []
|
||||
self._aggregation_start_time: Optional[str] | None = None
|
||||
|
||||
def _extract_messages(self, messages: List[dict]) -> List[TranscriptionMessage]:
|
||||
"""Extract assistant messages from the OpenAI standard message format.
|
||||
async def _emit_aggregated_text(self):
|
||||
"""Emit aggregated text as a transcript message."""
|
||||
if self._current_text_parts and self._aggregation_start_time:
|
||||
content = " ".join(self._current_text_parts).strip()
|
||||
if content:
|
||||
logger.debug(f"Emitting aggregated assistant message: {content}")
|
||||
message = TranscriptionMessage(
|
||||
role="assistant",
|
||||
content=content,
|
||||
timestamp=self._aggregation_start_time,
|
||||
)
|
||||
await self._emit_update([message])
|
||||
else:
|
||||
logger.debug("No content to emit after stripping whitespace")
|
||||
|
||||
Args:
|
||||
messages: List of messages in OpenAI format, which can be either:
|
||||
- Simple format: {"role": "user", "content": "Hello"}
|
||||
- Content list: {"role": "user", "content": [{"type": "text", "text": "Hello"}]}
|
||||
|
||||
Returns:
|
||||
List[TranscriptionMessage]: Normalized conversation messages
|
||||
"""
|
||||
result = []
|
||||
for msg in messages:
|
||||
if msg["role"] != "assistant":
|
||||
continue
|
||||
|
||||
content = msg.get("content")
|
||||
if isinstance(content, str):
|
||||
if content:
|
||||
result.append(TranscriptionMessage(role="assistant", content=content))
|
||||
elif isinstance(content, list):
|
||||
text_parts = []
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get("type") == "text":
|
||||
text_parts.append(part["text"])
|
||||
|
||||
if text_parts:
|
||||
result.append(
|
||||
TranscriptionMessage(role="assistant", content=" ".join(text_parts))
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _find_new_messages(self, current: List[TranscriptionMessage]) -> List[TranscriptionMessage]:
|
||||
"""Find unprocessed messages from current list.
|
||||
|
||||
Args:
|
||||
current: List of current messages
|
||||
|
||||
Returns:
|
||||
List[TranscriptionMessage]: New messages not yet processed
|
||||
"""
|
||||
if not self._processed_messages:
|
||||
return current
|
||||
|
||||
processed_len = len(self._processed_messages)
|
||||
if len(current) <= processed_len:
|
||||
return []
|
||||
|
||||
return current[processed_len:]
|
||||
# Reset aggregation state
|
||||
self._current_text_parts = []
|
||||
self._aggregation_start_time = None
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames into assistant conversation messages.
|
||||
|
||||
Handles different frame types:
|
||||
- TTSTextFrame: Aggregates text for current utterance
|
||||
- BotStoppedSpeakingFrame: Completes current utterance
|
||||
- StartInterruptionFrame: Completes current utterance due to interruption
|
||||
- EndFrame: Completes current utterance at pipeline end
|
||||
- CancelFrame: Completes current utterance due to cancellation
|
||||
|
||||
Args:
|
||||
frame: Input frame to process
|
||||
direction: Frame processing direction
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
standard_messages = []
|
||||
for msg in frame.context.messages:
|
||||
converted = frame.context.to_standard_messages(msg)
|
||||
standard_messages.extend(converted)
|
||||
if isinstance(frame, TTSTextFrame):
|
||||
# Start timestamp on first text part
|
||||
if not self._aggregation_start_time:
|
||||
self._aggregation_start_time = time_now_iso8601()
|
||||
|
||||
current_messages = self._extract_messages(standard_messages)
|
||||
new_messages = self._find_new_messages(current_messages)
|
||||
self._pending_assistant_messages.extend(new_messages)
|
||||
self._current_text_parts.append(frame.text)
|
||||
|
||||
elif isinstance(frame, OpenAILLMContextAssistantTimestampFrame):
|
||||
if self._pending_assistant_messages:
|
||||
for msg in self._pending_assistant_messages:
|
||||
msg.timestamp = frame.timestamp
|
||||
await self._emit_update(self._pending_assistant_messages)
|
||||
self._pending_assistant_messages = []
|
||||
elif isinstance(frame, (BotStoppedSpeakingFrame, StartInterruptionFrame, CancelFrame)):
|
||||
# Emit accumulated text when bot finishes speaking or is interrupted
|
||||
await self._emit_aggregated_text()
|
||||
|
||||
elif isinstance(frame, EndFrame):
|
||||
# Emit any remaining text when pipeline ends
|
||||
await self._emit_aggregated_text()
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -170,8 +161,8 @@ class TranscriptProcessor:
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
transcript.assistant_tts(), # Assistant transcripts
|
||||
context_aggregator.assistant(),
|
||||
transcript.assistant(), # Assistant transcripts
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -93,11 +93,11 @@ class ProtobufFrameSerializer(FrameSerializer):
|
||||
id = getattr(args, "id", None)
|
||||
name = getattr(args, "name", None)
|
||||
pts = getattr(args, "pts", None)
|
||||
if not id and "id" in args_dict:
|
||||
if "id" in args_dict:
|
||||
del args_dict["id"]
|
||||
if not name and "name" in args_dict:
|
||||
if "name" in args_dict:
|
||||
del args_dict["name"]
|
||||
if not pts and "pts" in args_dict:
|
||||
if "pts" in args_dict:
|
||||
del args_dict["pts"]
|
||||
|
||||
# Create the instance
|
||||
@@ -105,10 +105,10 @@ class ProtobufFrameSerializer(FrameSerializer):
|
||||
|
||||
# Set special fields
|
||||
if id:
|
||||
setattr(instance, "id", getattr(args, "id", None))
|
||||
setattr(instance, "id", id)
|
||||
if name:
|
||||
setattr(instance, "name", getattr(args, "name", None))
|
||||
setattr(instance, "name", name)
|
||||
if pts:
|
||||
setattr(instance, "pts", getattr(args, "pts", None))
|
||||
setattr(instance, "pts", pts)
|
||||
|
||||
return instance
|
||||
|
||||
@@ -18,6 +18,7 @@ from pipecat.frames.frames import CancelFrame, EndFrame, Frame
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import AIService
|
||||
|
||||
try:
|
||||
import aiofiles
|
||||
|
||||
@@ -524,13 +524,13 @@ class ElevenLabsHttpTTSService(TTSService):
|
||||
|
||||
url = f"{self._base_url}/v1/text-to-speech/{self._voice_id}/stream"
|
||||
|
||||
payload = {
|
||||
payload: Dict[str, Union[str, Dict[str, Union[float, bool]]]] = {
|
||||
"text": text,
|
||||
"model_id": self._model_name,
|
||||
}
|
||||
|
||||
if self._voice_settings:
|
||||
payload["voice_settings"] = json.dumps(self._voice_settings)
|
||||
payload["voice_settings"] = self._voice_settings
|
||||
|
||||
if self._settings["language"]:
|
||||
payload["language_code"] = self._settings["language"]
|
||||
|
||||
@@ -288,6 +288,10 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
)
|
||||
|
||||
async def _handle_transcribe_model_audio(self, audio, context):
|
||||
# Early return if modalities are not set to audio.
|
||||
if self._settings["modalities"] != GeminiMultimodalModalities.AUDIO:
|
||||
return
|
||||
|
||||
text = await self._transcribe_audio(audio, context)
|
||||
logger.debug(f"[Transcription:model] {text}")
|
||||
# We add user messages directly to the context. We don't do that for assistant messages,
|
||||
|
||||
@@ -221,7 +221,7 @@ class BaseOpenAILLMService(LLMService):
|
||||
)
|
||||
await self.start_llm_usage_metrics(tokens)
|
||||
|
||||
if len(chunk.choices) == 0:
|
||||
if chunk.choices is None or len(chunk.choices) == 0:
|
||||
continue
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
|
||||
@@ -6,10 +6,16 @@
|
||||
|
||||
import copy
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import Frame, LLMMessagesUpdateFrame, LLMSetToolsFrame
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallResultProperties,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolsFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
@@ -174,10 +180,13 @@ class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator)
|
||||
if not self._function_call_result:
|
||||
return
|
||||
|
||||
properties: Optional[FunctionCallResultProperties] = None
|
||||
|
||||
self._reset()
|
||||
try:
|
||||
run_llm = True
|
||||
frame = self._function_call_result
|
||||
properties = frame.properties
|
||||
self._function_call_result = None
|
||||
if frame.result:
|
||||
# The "tool_call" message from the LLM that triggered the function call
|
||||
@@ -211,11 +220,20 @@ class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator)
|
||||
await self._user_context_aggregator.push_frame(
|
||||
RealtimeFunctionCallResultFrame(result_frame=frame)
|
||||
)
|
||||
run_llm = frame.run_llm
|
||||
if properties and properties.run_llm is not None:
|
||||
# If the tool call result has a run_llm property, use it
|
||||
run_llm = properties.run_llm
|
||||
else:
|
||||
# Default behavior is to run the LLM if there are no function calls in progress
|
||||
run_llm = not bool(self._function_calls_in_progress)
|
||||
|
||||
if run_llm:
|
||||
await self._user_context_aggregator.push_context_frame()
|
||||
|
||||
# Emit the on_context_updated callback once the function call result is added to the context
|
||||
if properties and properties.on_context_updated is not None:
|
||||
await properties.on_context_updated()
|
||||
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
|
||||
|
||||
@@ -22,8 +22,8 @@ pydantic~=2.8.2
|
||||
pyloudnorm~=0.1.1
|
||||
pyht~=0.1.4
|
||||
python-dotenv~=1.0.1
|
||||
resampy~=0.4.3
|
||||
silero-vad~=5.1
|
||||
soxr~=0.5.0
|
||||
together~=1.2.7
|
||||
transformers~=4.44.0
|
||||
websockets~=13.1
|
||||
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
@@ -1,122 +1,70 @@
|
||||
import asyncio
|
||||
import doctest
|
||||
import functools
|
||||
#
|
||||
# Copyright (c) 2024-2025 Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
AudioRawFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
ImageRawFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
OutputAudioRawFrame,
|
||||
OutputImageRawFrame,
|
||||
TextFrame,
|
||||
)
|
||||
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.processors.aggregators.gated import GatedAggregator
|
||||
from pipecat.processors.aggregators.sentence import SentenceAggregator
|
||||
from pipecat.processors.text_transformer import StatelessTextTransformer
|
||||
from tests.utils import run_test
|
||||
|
||||
|
||||
class TestDailyFrameAggregators(unittest.IsolatedAsyncioTestCase):
|
||||
@unittest.skip("FIXME: This test is failing")
|
||||
class TestSentenceAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_sentence_aggregator(self):
|
||||
sentence = "Hello, world. How are you? I am fine"
|
||||
expected_sentences = ["Hello, world.", " How are you?", " I am fine "]
|
||||
aggregator = SentenceAggregator()
|
||||
|
||||
sentence = "Hello, world. How are you? I am fine!"
|
||||
|
||||
frames_to_send = []
|
||||
for word in sentence.split(" "):
|
||||
async for sentence in aggregator.process_frame(TextFrame(word + " ")):
|
||||
self.assertIsInstance(sentence, TextFrame)
|
||||
if isinstance(sentence, TextFrame):
|
||||
self.assertEqual(sentence.text, expected_sentences.pop(0))
|
||||
frames_to_send.append(TextFrame(text=word + " "))
|
||||
|
||||
async for sentence in aggregator.process_frame(EndFrame()):
|
||||
if len(expected_sentences):
|
||||
self.assertIsInstance(sentence, TextFrame)
|
||||
if isinstance(sentence, TextFrame):
|
||||
self.assertEqual(sentence.text, expected_sentences.pop(0))
|
||||
else:
|
||||
self.assertIsInstance(sentence, EndFrame)
|
||||
expected_returned_frames = [TextFrame, TextFrame, TextFrame]
|
||||
|
||||
self.assertEqual(expected_sentences, [])
|
||||
(received_down, _) = await run_test(aggregator, frames_to_send, expected_returned_frames)
|
||||
assert received_down[-3].text == "Hello, world. "
|
||||
assert received_down[-2].text == "How are you? "
|
||||
assert received_down[-1].text == "I am fine! "
|
||||
|
||||
@unittest.skip("FIXME: This test is failing")
|
||||
async def test_gated_accumulator(self):
|
||||
|
||||
class TestGatedAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_gated_aggregator(self):
|
||||
gated_aggregator = GatedAggregator(
|
||||
gate_open_fn=lambda frame: isinstance(frame, ImageRawFrame),
|
||||
gate_close_fn=lambda frame: isinstance(frame, LLMFullResponseStartFrame),
|
||||
start_open=False,
|
||||
)
|
||||
|
||||
frames = [
|
||||
frames_to_send = [
|
||||
LLMFullResponseStartFrame(),
|
||||
TextFrame("Hello, "),
|
||||
TextFrame("world."),
|
||||
AudioRawFrame(b"hello"),
|
||||
ImageRawFrame(b"image", (0, 0)),
|
||||
AudioRawFrame(b"world"),
|
||||
OutputAudioRawFrame(audio=b"hello", sample_rate=16000, num_channels=1),
|
||||
OutputImageRawFrame(image=b"image", size=(0, 0), format="RGB"),
|
||||
OutputAudioRawFrame(audio=b"world", sample_rate=16000, num_channels=1),
|
||||
LLMFullResponseEndFrame(),
|
||||
]
|
||||
|
||||
expected_output_frames = [
|
||||
ImageRawFrame(b"image", (0, 0)),
|
||||
LLMFullResponseStartFrame(),
|
||||
TextFrame("Hello, "),
|
||||
TextFrame("world."),
|
||||
AudioRawFrame(b"hello"),
|
||||
AudioRawFrame(b"world"),
|
||||
LLMFullResponseEndFrame(),
|
||||
expected_returned_frames = [
|
||||
OutputImageRawFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
TextFrame,
|
||||
TextFrame,
|
||||
OutputAudioRawFrame,
|
||||
OutputAudioRawFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
]
|
||||
for frame in frames:
|
||||
async for out_frame in gated_aggregator.process_frame(frame):
|
||||
self.assertEqual(out_frame, expected_output_frames.pop(0))
|
||||
self.assertEqual(expected_output_frames, [])
|
||||
|
||||
@unittest.skip("FIXME: This test is failing")
|
||||
async def test_parallel_pipeline(self):
|
||||
async def slow_add(sleep_time: float, name: str, x: str):
|
||||
await asyncio.sleep(sleep_time)
|
||||
return ":".join([x, name])
|
||||
|
||||
pipe1_annotation = StatelessTextTransformer(functools.partial(slow_add, 0.1, "pipe1"))
|
||||
pipe2_annotation = StatelessTextTransformer(functools.partial(slow_add, 0.2, "pipe2"))
|
||||
sentence_aggregator = SentenceAggregator()
|
||||
add_dots = StatelessTextTransformer(lambda x: x + ".")
|
||||
|
||||
source = asyncio.Queue()
|
||||
sink = asyncio.Queue()
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
ParallelPipeline([[pipe1_annotation], [sentence_aggregator, pipe2_annotation]]),
|
||||
add_dots,
|
||||
],
|
||||
source,
|
||||
sink,
|
||||
(received_down, _) = await run_test(
|
||||
gated_aggregator, frames_to_send, expected_returned_frames
|
||||
)
|
||||
|
||||
frames = [TextFrame("Hello, "), TextFrame("world."), EndFrame()]
|
||||
|
||||
expected_output_frames: list[Frame] = [
|
||||
TextFrame(text="Hello, :pipe1."),
|
||||
TextFrame(text="world.:pipe1."),
|
||||
TextFrame(text="Hello, world.:pipe2."),
|
||||
EndFrame(),
|
||||
]
|
||||
|
||||
for frame in frames:
|
||||
await source.put(frame)
|
||||
|
||||
await pipeline.run_pipeline()
|
||||
|
||||
while not sink.empty():
|
||||
frame = await sink.get()
|
||||
self.assertEqual(frame, expected_output_frames.pop(0))
|
||||
|
||||
|
||||
def load_tests(loader, tests, ignore):
|
||||
"""Run doctests on the aggregators module."""
|
||||
from pipecat.processors import aggregators
|
||||
|
||||
tests.addTests(doctest.DocTestSuite(aggregators))
|
||||
return tests
|
||||
|
||||
@@ -1,3 +1,9 @@
|
||||
#
|
||||
# Copyright (c) 2024-2025 Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import unittest
|
||||
|
||||
|
||||
|
||||
94
tests/test_filters.py
Normal file
94
tests/test_filters.py
Normal file
@@ -0,0 +1,94 @@
|
||||
#
|
||||
# Copyright (c) 2024-2025 Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
Frame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.filters.frame_filter import FrameFilter
|
||||
from pipecat.processors.filters.function_filter import FunctionFilter
|
||||
from pipecat.processors.filters.identity_filter import IdentityFilter
|
||||
from pipecat.processors.filters.wake_check_filter import WakeCheckFilter
|
||||
from tests.utils import EndTestFrame, run_test
|
||||
|
||||
|
||||
class TestIdentifyFilter(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_identity(self):
|
||||
filter = IdentityFilter()
|
||||
frames_to_send = [UserStartedSpeakingFrame(), UserStoppedSpeakingFrame()]
|
||||
expected_returned_frames = [UserStartedSpeakingFrame, UserStoppedSpeakingFrame]
|
||||
await run_test(filter, frames_to_send, expected_returned_frames)
|
||||
|
||||
|
||||
class TestFrameFilter(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_text_frame(self):
|
||||
filter = FrameFilter(types=(TextFrame, EndTestFrame))
|
||||
frames_to_send = [TextFrame(text="Hello Pipecat!")]
|
||||
expected_returned_frames = [TextFrame]
|
||||
await run_test(filter, frames_to_send, expected_returned_frames)
|
||||
|
||||
async def test_end_frame(self):
|
||||
filter = FrameFilter(types=(EndFrame, EndTestFrame))
|
||||
frames_to_send = [EndFrame()]
|
||||
expected_returned_frames = [EndFrame]
|
||||
await run_test(filter, frames_to_send, expected_returned_frames)
|
||||
|
||||
async def test_system_frame(self):
|
||||
filter = FrameFilter(types=(EndTestFrame,))
|
||||
frames_to_send = [UserStartedSpeakingFrame()]
|
||||
expected_returned_frames = [UserStartedSpeakingFrame]
|
||||
await run_test(filter, frames_to_send, expected_returned_frames)
|
||||
|
||||
|
||||
class TestFunctionFilter(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_passthrough(self):
|
||||
async def passthrough(frame: Frame):
|
||||
return True
|
||||
|
||||
filter = FunctionFilter(filter=passthrough)
|
||||
frames_to_send = [TextFrame(text="Hello Pipecat!")]
|
||||
expected_returned_frames = [TextFrame]
|
||||
await run_test(filter, frames_to_send, expected_returned_frames)
|
||||
|
||||
async def test_no_passthrough(self):
|
||||
async def no_passthrough(frame: Frame):
|
||||
return False
|
||||
|
||||
filter = FunctionFilter(filter=no_passthrough)
|
||||
frames_to_send = [TextFrame(text="Hello Pipecat!")]
|
||||
expected_returned_frames = [TextFrame]
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
run_test(filter, frames_to_send, expected_returned_frames), timeout=0.5
|
||||
)
|
||||
assert False
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
|
||||
class TestWakeCheckFilter(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_no_wake_word(self):
|
||||
filter = WakeCheckFilter(wake_phrases=["Hey, Pipecat"])
|
||||
frames_to_send = [TranscriptionFrame(user_id="test", text="Phrase 1", timestamp="")]
|
||||
expected_returned_frames = []
|
||||
await run_test(filter, frames_to_send, expected_returned_frames)
|
||||
|
||||
async def test_wake_word(self):
|
||||
filter = WakeCheckFilter(wake_phrases=["Hey, Pipecat"])
|
||||
frames_to_send = [
|
||||
TranscriptionFrame(user_id="test", text="Hey, Pipecat", timestamp=""),
|
||||
TranscriptionFrame(user_id="test", text="Phrase 1", timestamp=""),
|
||||
]
|
||||
expected_returned_frames = [TranscriptionFrame, TranscriptionFrame]
|
||||
(received_down, _) = await run_test(filter, frames_to_send, expected_returned_frames)
|
||||
assert received_down[-1].text == "Phrase 1"
|
||||
@@ -93,7 +93,3 @@ class TestLangchain(unittest.IsolatedAsyncioTestCase):
|
||||
# This next one would fail with:
|
||||
# AssertionError: ' H e l l o d e a r h u m a n' != 'Hello dear human'
|
||||
# self.assertEqual(tma_out.messages[-1]["content"], self.expected_response)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
|
||||
import openai
|
||||
import pyaudio
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from pipecat.frames.frames import AudioRawFrame, ErrorFrame
|
||||
from pipecat.services.openai import OpenAITTSService
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class TestWhisperOpenAIService(unittest.IsolatedAsyncioTestCase):
|
||||
@unittest.skip("FIXME: This test is failing")
|
||||
async def test_whisper_tts(self):
|
||||
pa = pyaudio.PyAudio()
|
||||
stream = pa.open(format=pyaudio.paInt16, channels=1, rate=24_000, output=True)
|
||||
|
||||
tts = OpenAITTSService(voice="nova")
|
||||
|
||||
async for frame in tts.run_tts("Hello, there. Nice to meet you, seems to work well"):
|
||||
self.assertIsInstance(frame, AudioRawFrame)
|
||||
stream.write(frame.audio)
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
stream.stop_stream()
|
||||
pa.terminate()
|
||||
|
||||
tts = OpenAITTSService(voice="invalid_voice")
|
||||
with self.assertRaises(openai.BadRequestError):
|
||||
async for frame in tts.run_tts("wont work"):
|
||||
self.assertIsInstance(frame, ErrorFrame)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,116 +1,92 @@
|
||||
#
|
||||
# Copyright (c) 2024-2025 Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import unittest
|
||||
from unittest.mock import Mock
|
||||
|
||||
from pipecat.frames.frames import EndFrame, TextFrame
|
||||
from pipecat.frames.frames import EndFrame, HeartbeatFrame, TextFrame
|
||||
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.processors.aggregators.sentence import SentenceAggregator
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.filters.identity_filter import IdentityFilter
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
from pipecat.processors.text_transformer import StatelessTextTransformer
|
||||
from tests.utils import HeartbeatsObserver, run_test
|
||||
|
||||
|
||||
class TestDailyPipeline(unittest.IsolatedAsyncioTestCase):
|
||||
@unittest.skip("FIXME: This test is failing")
|
||||
async def test_pipeline_simple(self):
|
||||
aggregator = SentenceAggregator()
|
||||
class TestPipeline(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_pipeline_single(self):
|
||||
pipeline = Pipeline([IdentityFilter()])
|
||||
|
||||
outgoing_queue = asyncio.Queue()
|
||||
incoming_queue = asyncio.Queue()
|
||||
pipeline = Pipeline([aggregator], incoming_queue, outgoing_queue)
|
||||
frames_to_send = [TextFrame(text="Hello from Pipecat!")]
|
||||
expected_returned_frames = [TextFrame]
|
||||
await run_test(pipeline, frames_to_send, expected_returned_frames)
|
||||
|
||||
await incoming_queue.put(TextFrame("Hello, "))
|
||||
await incoming_queue.put(TextFrame("world."))
|
||||
await incoming_queue.put(EndFrame())
|
||||
async def test_pipeline_multiple(self):
|
||||
identity1 = IdentityFilter()
|
||||
identity2 = IdentityFilter()
|
||||
identity3 = IdentityFilter()
|
||||
|
||||
await pipeline.run_pipeline()
|
||||
pipeline = Pipeline([identity1, identity2, identity3])
|
||||
|
||||
self.assertEqual(await outgoing_queue.get(), TextFrame("Hello, world."))
|
||||
self.assertIsInstance(await outgoing_queue.get(), EndFrame)
|
||||
frames_to_send = [TextFrame(text="Hello from Pipecat!")]
|
||||
expected_returned_frames = [TextFrame]
|
||||
await run_test(pipeline, frames_to_send, expected_returned_frames)
|
||||
|
||||
@unittest.skip("FIXME: This test is failing")
|
||||
async def test_pipeline_multiple_stages(self):
|
||||
sentence_aggregator = SentenceAggregator()
|
||||
to_upper = StatelessTextTransformer(lambda x: x.upper())
|
||||
add_space = StatelessTextTransformer(lambda x: x + " ")
|
||||
|
||||
outgoing_queue = asyncio.Queue()
|
||||
incoming_queue = asyncio.Queue()
|
||||
pipeline = Pipeline(
|
||||
[add_space, sentence_aggregator, to_upper], incoming_queue, outgoing_queue
|
||||
class TestParallelPipeline(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_parallel_single(self):
|
||||
pipeline = ParallelPipeline([IdentityFilter()])
|
||||
|
||||
frames_to_send = [TextFrame(text="Hello from Pipecat!")]
|
||||
expected_returned_frames = [TextFrame]
|
||||
await run_test(pipeline, frames_to_send, expected_returned_frames)
|
||||
|
||||
async def test_parallel_multiple(self):
|
||||
"""Should only passthrough one instance of TextFrame."""
|
||||
pipeline = ParallelPipeline([IdentityFilter()], [IdentityFilter()])
|
||||
|
||||
frames_to_send = [TextFrame(text="Hello from Pipecat!")]
|
||||
expected_returned_frames = [TextFrame]
|
||||
await run_test(pipeline, frames_to_send, expected_returned_frames)
|
||||
|
||||
|
||||
class TestPipelineTask(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_task_single(self):
|
||||
pipeline = Pipeline([IdentityFilter()])
|
||||
task = PipelineTask(pipeline)
|
||||
|
||||
await task.queue_frame(TextFrame(text="Hello!"))
|
||||
await task.queue_frames([TextFrame(text="Bye!"), EndFrame()])
|
||||
await task.run()
|
||||
assert task.has_finished()
|
||||
|
||||
async def test_task_heartbeats(self):
|
||||
heartbeats_counter = 0
|
||||
|
||||
async def heartbeat_received(processor: FrameProcessor, heartbeat: HeartbeatFrame):
|
||||
nonlocal heartbeats_counter
|
||||
heartbeats_counter += 1
|
||||
|
||||
identity = IdentityFilter()
|
||||
pipeline = Pipeline([identity])
|
||||
heartbeats_observer = HeartbeatsObserver(
|
||||
target=identity, heartbeat_callback=heartbeat_received
|
||||
)
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_heartbeats=True, heartbeats_period_secs=0.2, observers=[heartbeats_observer]
|
||||
),
|
||||
)
|
||||
|
||||
sentence = "Hello, world. It's me, a pipeline."
|
||||
for c in sentence:
|
||||
await incoming_queue.put(TextFrame(c))
|
||||
await incoming_queue.put(EndFrame())
|
||||
expected_heartbeats = 1.0 / 0.2
|
||||
|
||||
await pipeline.run_pipeline()
|
||||
|
||||
self.assertEqual(await outgoing_queue.get(), TextFrame("H E L L O , W O R L D ."))
|
||||
self.assertEqual(
|
||||
await outgoing_queue.get(),
|
||||
TextFrame(" I T ' S M E , A P I P E L I N E ."),
|
||||
)
|
||||
# leftover little bit because of the spacing
|
||||
self.assertEqual(
|
||||
await outgoing_queue.get(),
|
||||
TextFrame(" "),
|
||||
)
|
||||
self.assertIsInstance(await outgoing_queue.get(), EndFrame)
|
||||
|
||||
|
||||
class TestLogFrame(unittest.TestCase):
|
||||
class MockProcessor(FrameProcessor):
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
def setUp(self):
|
||||
self.processor1 = self.MockProcessor("processor1")
|
||||
self.processor2 = self.MockProcessor("processor2")
|
||||
self.pipeline = Pipeline(processors=[self.processor1, self.processor2])
|
||||
self.pipeline._name = "MyClass"
|
||||
self.pipeline._logger = Mock()
|
||||
|
||||
@unittest.skip("FIXME: This test is failing")
|
||||
def test_log_frame_from_source(self):
|
||||
frame = Mock(__class__=Mock(__name__="MyFrame"))
|
||||
self.pipeline._log_frame(frame, depth=1)
|
||||
self.pipeline._logger.debug.assert_called_once_with(
|
||||
"MyClass source -> MyFrame -> processor1"
|
||||
)
|
||||
|
||||
@unittest.skip("FIXME: This test is failing")
|
||||
def test_log_frame_to_sink(self):
|
||||
frame = Mock(__class__=Mock(__name__="MyFrame"))
|
||||
self.pipeline._log_frame(frame, depth=3)
|
||||
self.pipeline._logger.debug.assert_called_once_with(
|
||||
"MyClass processor2 -> MyFrame -> sink"
|
||||
)
|
||||
|
||||
@unittest.skip("FIXME: This test is failing")
|
||||
def test_log_frame_repeated_log(self):
|
||||
frame = Mock(__class__=Mock(__name__="MyFrame"))
|
||||
self.pipeline._log_frame(frame, depth=2)
|
||||
self.pipeline._logger.debug.assert_called_once_with(
|
||||
"MyClass processor1 -> MyFrame -> processor2"
|
||||
)
|
||||
self.pipeline._log_frame(frame, depth=2)
|
||||
self.pipeline._logger.debug.assert_called_with("MyClass ... repeated")
|
||||
|
||||
@unittest.skip("FIXME: This test is failing")
|
||||
def test_log_frame_reset_repeated_log(self):
|
||||
frame1 = Mock(__class__=Mock(__name__="MyFrame1"))
|
||||
frame2 = Mock(__class__=Mock(__name__="MyFrame2"))
|
||||
self.pipeline._log_frame(frame1, depth=2)
|
||||
self.pipeline._logger.debug.assert_called_once_with(
|
||||
"MyClass processor1 -> MyFrame1 -> processor2"
|
||||
)
|
||||
self.pipeline._log_frame(frame1, depth=2)
|
||||
self.pipeline._logger.debug.assert_called_with("MyClass ... repeated")
|
||||
self.pipeline._log_frame(frame2, depth=2)
|
||||
self.pipeline._logger.debug.assert_called_with(
|
||||
"MyClass processor1 -> MyFrame2 -> processor2"
|
||||
)
|
||||
await task.queue_frame(TextFrame(text="Hello!"))
|
||||
try:
|
||||
await asyncio.wait_for(task.run(), timeout=1.0)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
assert heartbeats_counter == expected_heartbeats
|
||||
|
||||
@@ -1,6 +1,16 @@
|
||||
#
|
||||
# Copyright (c) 2024-2025 Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import AudioRawFrame, TextFrame, TranscriptionFrame
|
||||
from pipecat.frames.frames import (
|
||||
OutputAudioRawFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.serializers.protobuf import ProtobufFrameSerializer
|
||||
|
||||
|
||||
@@ -8,22 +18,19 @@ class TestProtobufFrameSerializer(unittest.IsolatedAsyncioTestCase):
|
||||
def setUp(self):
|
||||
self.serializer = ProtobufFrameSerializer()
|
||||
|
||||
@unittest.skip("FIXME: This test is failing")
|
||||
async def test_roundtrip(self):
|
||||
text_frame = TextFrame(text="hello world")
|
||||
frame = self.serializer.deserialize(self.serializer.serialize(text_frame))
|
||||
self.assertEqual(frame, TextFrame(text="hello world"))
|
||||
self.assertEqual(text_frame, frame)
|
||||
|
||||
transcription_frame = TranscriptionFrame(
|
||||
text="Hello there!", participantId="123", timestamp="2021-01-01"
|
||||
text="Hello there!", user_id="123", timestamp="2021-01-01"
|
||||
)
|
||||
frame = self.serializer.deserialize(self.serializer.serialize(transcription_frame))
|
||||
self.assertEqual(frame, transcription_frame)
|
||||
|
||||
audio_frame = AudioRawFrame(data=b"1234567890")
|
||||
audio_frame = OutputAudioRawFrame(audio=b"1234567890", sample_rate=16000, num_channels=1)
|
||||
frame = self.serializer.deserialize(self.serializer.serialize(audio_frame))
|
||||
self.assertEqual(frame, audio_frame)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
self.assertEqual(frame.audio, audio_frame.audio)
|
||||
self.assertEqual(frame.sample_rate, audio_frame.sample_rate)
|
||||
self.assertEqual(frame.num_channels, audio_frame.num_channels)
|
||||
|
||||
@@ -1,28 +1,15 @@
|
||||
#
|
||||
# Copyright (c) 2024-2025 Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import unittest
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from pipecat.frames.frames import EndFrame, Frame, TextFrame
|
||||
from pipecat.services.ai_services import AIService, match_endofsentence
|
||||
from pipecat.utils.string import match_endofsentence
|
||||
|
||||
|
||||
class SimpleAIService(AIService):
|
||||
async def process_frame(self, frame: Frame) -> AsyncGenerator[Frame, None]:
|
||||
yield frame
|
||||
|
||||
|
||||
class TestBaseAIService(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_simple_processing(self):
|
||||
service = SimpleAIService()
|
||||
|
||||
input_frames = [TextFrame("hello"), EndFrame()]
|
||||
|
||||
output_frames = []
|
||||
for input_frame in input_frames:
|
||||
async for output_frame in service.process_frame(input_frame):
|
||||
output_frames.append(output_frame)
|
||||
|
||||
self.assertEqual(input_frames, output_frames)
|
||||
|
||||
class TestUtilsString(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_endofsentence(self):
|
||||
assert match_endofsentence("This is a sentence.")
|
||||
assert match_endofsentence("This is a sentence! ")
|
||||
@@ -51,7 +38,3 @@ class TestBaseAIService(unittest.IsolatedAsyncioTestCase):
|
||||
for i in chinese_sentences:
|
||||
assert match_endofsentence(i)
|
||||
assert not match_endofsentence("你好,")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,3 +1,9 @@
|
||||
#
|
||||
# Copyright (c) 2024-2025 Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
# import asyncio
|
||||
# import unittest
|
||||
# from unittest.mock import AsyncMock, patch, Mock
|
||||
|
||||
120
tests/utils.py
Normal file
120
tests/utils.py
Normal file
@@ -0,0 +1,120 @@
|
||||
#
|
||||
# Copyright (c) 2024-2025 Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Awaitable, Callable, Sequence, Tuple
|
||||
|
||||
from pipecat.clocks.system_clock import SystemClock
|
||||
from pipecat.frames.frames import (
|
||||
ControlFrame,
|
||||
Frame,
|
||||
HeartbeatFrame,
|
||||
StartFrame,
|
||||
)
|
||||
from pipecat.observers.base_observer import BaseObserver
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
@dataclass
|
||||
class EndTestFrame(ControlFrame):
|
||||
pass
|
||||
|
||||
|
||||
class HeartbeatsObserver(BaseObserver):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
target: FrameProcessor,
|
||||
heartbeat_callback: Callable[[FrameProcessor, HeartbeatFrame], Awaitable[None]],
|
||||
):
|
||||
self._target = target
|
||||
self._callback = heartbeat_callback
|
||||
|
||||
async def on_push_frame(
|
||||
self,
|
||||
src: FrameProcessor,
|
||||
dst: FrameProcessor,
|
||||
frame: Frame,
|
||||
direction: FrameDirection,
|
||||
timestamp: int,
|
||||
):
|
||||
if src == self._target and isinstance(frame, HeartbeatFrame):
|
||||
await self._callback(self._target, frame)
|
||||
|
||||
|
||||
class QueuedFrameProcessor(FrameProcessor):
|
||||
def __init__(self, queue: asyncio.Queue, ignore_start: bool = True):
|
||||
super().__init__()
|
||||
self._queue = queue
|
||||
self._ignore_start = ignore_start
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
if self._ignore_start and isinstance(frame, StartFrame):
|
||||
return
|
||||
await self._queue.put(frame)
|
||||
|
||||
|
||||
async def run_test(
|
||||
processor: FrameProcessor,
|
||||
frames_to_send: Sequence[Frame],
|
||||
expected_down_frames: Sequence[type],
|
||||
expected_up_frames: Sequence[type] = [],
|
||||
) -> Tuple[Sequence[Frame], Sequence[Frame]]:
|
||||
received_up = asyncio.Queue()
|
||||
received_down = asyncio.Queue()
|
||||
up_processor = QueuedFrameProcessor(received_up)
|
||||
down_processor = QueuedFrameProcessor(received_down)
|
||||
|
||||
up_processor.link(processor)
|
||||
processor.link(down_processor)
|
||||
|
||||
await processor.queue_frame(StartFrame(clock=SystemClock()))
|
||||
|
||||
for frame in frames_to_send:
|
||||
await processor.process_frame(frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
await processor.queue_frame(EndTestFrame())
|
||||
await processor.queue_frame(EndTestFrame(), FrameDirection.UPSTREAM)
|
||||
|
||||
#
|
||||
# Down frames
|
||||
#
|
||||
received_down_frames: Sequence[Frame] = []
|
||||
running = True
|
||||
while running:
|
||||
frame = await received_down.get()
|
||||
running = not isinstance(frame, EndTestFrame)
|
||||
if running:
|
||||
received_down_frames.append(frame)
|
||||
|
||||
print("received DOWN frames =", received_down_frames)
|
||||
|
||||
assert len(received_down_frames) == len(expected_down_frames)
|
||||
|
||||
for real, expected in zip(received_down_frames, expected_down_frames):
|
||||
assert isinstance(real, expected)
|
||||
|
||||
#
|
||||
# Up frames
|
||||
#
|
||||
received_up_frames: Sequence[Frame] = []
|
||||
running = True
|
||||
while running:
|
||||
frame = await received_up.get()
|
||||
running = not isinstance(frame, EndTestFrame)
|
||||
if running:
|
||||
received_up_frames.append(frame)
|
||||
|
||||
print("received UP frames =", received_up_frames)
|
||||
|
||||
assert len(received_up_frames) == len(expected_up_frames)
|
||||
|
||||
for real, expected in zip(received_up_frames, expected_up_frames):
|
||||
assert isinstance(real, expected)
|
||||
|
||||
return (received_down_frames, received_up_frames)
|
||||
Reference in New Issue
Block a user