Compare commits

..

1 Commits

Author SHA1 Message Date
James Hush
6d54ba14e2 fix: add exception handling for FAL smart turn detection 2025-04-29 13:47:23 +08:00
24 changed files with 147 additions and 585 deletions

View File

@@ -9,15 +9,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added `RTVIObserverParams` which allows you to configure what RTVI messages
are sent to the clients.
- Added a `context_window_compression` InputParam to
`GeminiMultimodalLiveLLMService` which allows you to enable a sliding context
window for the session as well as set the token limit of the sliding window.
- Updated `SmallWebRTCConnection` to support `ice_servers` with credentials.
- Added `VADUserStartedSpeakingFrame` and `VADUserStoppedSpeakingFrame`,
indicating when the VAD detected the user to start and stop speaking. These
events are helpful when using smart turn detection, as the user's stop time
@@ -32,15 +23,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `MCPClient`; a way to connect to MCP servers and use the MCP servers'
tools.
- Added `Mem0 OSS`, along with Mem0 cloud support now the OSS version is also
available.
- Added `Mem0 OSS`, along with Mem0 cloud support now the OSS version is also available.
### Changed
- The `STTMuteFilter` now mutes `InterimTranscriptionFrame` and
`TranscriptionFrame` which allows the `STTMuteFilter` to be used in
conjunction with transports that generate transcripts, e.g. `DailyTransport`.
- Function calls now receive a single parameter `FunctionCallParams` instead of
`(function_name, tool_call_id, args, llm, context, result_callback)` which is
now deprecated.
@@ -72,9 +58,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Deprecated
- Function calls with parameters
`(function_name, tool_call_id, args, llm, context, result_callback)` are
deprectated, use a single `FunctionCallParams` parameter instead.
- Function calls with parameters `(function_name, tool_call_id, args, llm,
context, result_callback)` are deprectated, use a single `FunctionCallParams`
parameter instead.
- `TransportParams.camera_*` parameters are now deprecated, use
`TransportParams.video_*` instead.
@@ -87,13 +73,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Fixed an issue with `GeminiMultimodalLiveLLMService` where the context
contained tokens instead of words.
- Fixed an issue with HTTP Smart Turn handling, where the service returns a 500
error. Previously, this would cause an unhandled exception. Now, a 500 error
is treated as an incomplete response.
- Fixed a TTS services issue that could cause assistant output not to be
aggregated to the context when also using `TTSSpeakFrame`s.
@@ -211,9 +190,8 @@ https://en.wikipedia.org/wiki/Saint_George%27s_Day_in_Catalonia
- Fixed an issue in `SmallWebRTCTransport` where an error was thrown if the
client did not create a video transceiver.
- Fixed an issue where LLM input parameters were not working and applied
correctly in `GoogleVertexLLMService`, causing unexpected behavior during
inference.
- Fixed an issue where LLM input parameters were not working and applied correctly in `GoogleVertexLLMService`, causing
unexpected behavior during inference.
### Other

View File

@@ -50,6 +50,7 @@ autodoc_mock_imports = [
"pyht.protos",
"pyht.protos.api_pb2",
"pipecat_ai_playht", # PlayHT wrapper
"vllm",
"aiortc",
"aiortc.mediastreams",
"cv2",
@@ -75,6 +76,7 @@ autodoc_mock_imports = [
"openpipe",
"simli",
"soundfile",
# Existing mocks
"pipecat_ai_krisp",
"pyaudio",
"_tkinter",
@@ -85,66 +87,6 @@ autodoc_mock_imports = [
"pydantic.Field",
"pydantic._internal._model_construction",
"pydantic._internal._fields",
# Moondream dependencies
"torch",
"transformers",
"intel_extension_for_pytorch",
# Ultravox dependencies
"huggingface_hub",
"vllm",
"vllm.engine.arg_utils",
"transformers.AutoTokenizer",
# Langchain dependencies
"langchain_core",
"langchain_core.messages",
"langchain_core.runnables",
"langchain_core.messages.AIMessageChunk",
"langchain_core.runnables.Runnable",
# LiveKit dependencies
"livekit",
"livekit.rtc",
"livekit_api",
"livekit_protocol",
"tenacity",
"tenacity.retry",
"tenacity.stop_after_attempt",
"tenacity.wait_exponential",
"rtc",
"rtc.Room",
"rtc.RoomOptions",
"rtc.AudioSource",
"rtc.LocalAudioTrack",
"rtc.TrackPublishOptions",
"rtc.TrackSource",
"rtc.AudioStream",
"rtc.AudioFrameEvent",
"rtc.AudioFrame",
"rtc.Track",
"rtc.TrackKind",
"rtc.RemoteParticipant",
"rtc.RemoteTrackPublication",
"rtc.DataPacket",
# Riva dependencies
"riva",
"riva.client",
"riva.client.Auth",
"riva.client.ASRService",
"riva.client.StreamingRecognitionConfig",
"riva.client.RecognitionConfig",
"riva.client.AudioEncoding",
"riva.client.proto.riva_tts_pb2",
"riva.client.SpeechSynthesisService",
# Local CoreML Smart Turn dependencies
"coremltools",
"coremltools.models",
"coremltools.models.MLModel",
"torch",
"torch.nn",
"torch.nn.functional",
"transformers",
"transformers.AutoFeatureExtractor",
# Also add specific classes that are imported
"AutoFeatureExtractor",
]
# HTML output settings
@@ -176,25 +118,12 @@ def verify_modules():
},
}
# Skip importing modules that are in autodoc_mock_imports
skipped_modules = set(autodoc_mock_imports)
missing = []
for category, modules in required_modules.items():
if isinstance(modules, dict):
# Handle nested structure
for subcategory, submodules in modules.items():
for module in submodules:
# Check if module is in autodoc_mock_imports
if (
f"pipecat.{category}.{subcategory}.{module}" in skipped_modules
or module in skipped_modules
):
logger.info(
f"Skipping import of mocked module: pipecat.{category}.{subcategory}.{module}"
)
continue
try:
__import__(f"pipecat.{category}.{subcategory}.{module}")
logger.info(
@@ -208,11 +137,6 @@ def verify_modules():
else:
# Handle flat structure
for module in modules:
# Check if module is in autodoc_mock_imports
if f"pipecat.{category}.{module}" in skipped_modules or module in skipped_modules:
logger.info(f"Skipping import of mocked module: pipecat.{category}.{module}")
continue
try:
__import__(f"pipecat.{category}.{module}")
logger.info(f"Successfully imported pipecat.{category}.{module}")

View File

@@ -26,23 +26,20 @@ pipecat-ai[grok]
pipecat-ai[groq]
# pipecat-ai[krisp] # Mocked
pipecat-ai[koala]
# pipecat-ai[langchain] # Mocked
# pipecat-ai[livekit] # Mocked
pipecat-ai[langchain]
pipecat-ai[livekit]
pipecat-ai[lmnt]
pipecat-ai[local]
# pipecat-ai[local-smart-turn] # Mocked
# pipecat-ai[mem0] # Mocked
# pipecat-ai[mlx-whisper] # Mocked
# pipecat-ai[moondream] # Mocked
pipecat-ai[moondream]
pipecat-ai[nim]
# pipecat-ai[neuphonic] # Mocked
pipecat-ai[noisereduce]
pipecat-ai[openai]
# pipecat-ai[openpipe]
# pipecat-ai[playht] # Mocked due to grpcio conflict with riva
pipecat-ai[qwen]
pipecat-ai[remote-smart-turn]
# pipecat-ai[riva] # Mocked
pipecat-ai[riva]
pipecat-ai[silero]
pipecat-ai[simli]
pipecat-ai[soundfile]

View File

@@ -1,4 +1,5 @@
python-dotenv==1.0.1
modal==0.71.3
pipecat-ai[daily,silero,cartesia,openai]
pipecat-ai[daily,silero,cartesia,openai]==0.0.52
fastapi==0.115.6
aiohttp==3.11.11

View File

@@ -89,7 +89,6 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
llm = GeminiMultimodalLiveLLMService(
api_key=os.getenv("GOOGLE_API_KEY"),
system_instruction=system_instruction,
transcribe_user_audio=True,
tools=tools,
)

View File

@@ -20,7 +20,7 @@ from fastapi.responses import RedirectResponse
from loguru import logger
from pipecat_ai_small_webrtc_prebuilt.frontend import SmallWebRTCPrebuiltUI
from pipecat.transports.network.webrtc_connection import IceServer, SmallWebRTCConnection
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
# Load environment variables
load_dotenv(override=True)
@@ -30,11 +30,7 @@ app = FastAPI()
# Store connections by pc_id
pcs_map: Dict[str, SmallWebRTCConnection] = {}
ice_servers = [
IceServer(
urls="stun:stun.l.google.com:19302",
)
]
ice_servers = ["stun:stun.l.google.com:19302"]
# Mount the frontend at /
app.mount("/client", SmallWebRTCPrebuiltUI)

View File

@@ -18,7 +18,7 @@ from fastapi.responses import RedirectResponse
from loguru import logger
from pipecat_ai_small_webrtc_prebuilt.frontend import SmallWebRTCPrebuiltUI
from pipecat.transports.network.webrtc_connection import IceServer, SmallWebRTCConnection
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
# Load environment variables
load_dotenv(override=True)
@@ -28,11 +28,7 @@ app = FastAPI()
# Store connections by pc_id
pcs_map: Dict[str, SmallWebRTCConnection] = {}
ice_servers = [
IceServer(
urls="stun:stun.l.google.com:19302",
)
]
ice_servers = ["stun:stun.l.google.com:19302"]
# Mount the frontend at /
app.mount("/prebuilt", SmallWebRTCPrebuiltUI)

View File

@@ -18,7 +18,7 @@ from fastapi.responses import RedirectResponse
from loguru import logger
from pipecat_ai_small_webrtc_prebuilt.frontend import SmallWebRTCPrebuiltUI
from pipecat.transports.network.webrtc_connection import IceServer, SmallWebRTCConnection
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
# Load environment variables
load_dotenv(override=True)
@@ -28,11 +28,7 @@ app = FastAPI()
# Store connections by pc_id
pcs_map: Dict[str, SmallWebRTCConnection] = {}
ice_servers = [
IceServer(
urls="stun:stun.l.google.com:19302",
)
]
ice_servers = ["stun:stun.l.google.com:19302"]
# Mount the frontend at /
app.mount("/prebuilt", SmallWebRTCPrebuiltUI)

View File

@@ -46,46 +46,6 @@ http://localhost:7860
---
## WebRTC ICE Servers Configuration
When implementing WebRTC in your project, **STUN** (Session Traversal Utilities for NAT) and **TURN** (Traversal Using Relays around NAT)
servers are usually needed in cases where users are behind routers or firewalls.
In local networks (e.g., testing within the same home or office network), you usually dont need to configure STUN or TURN servers.
In such cases, WebRTC can often directly establish peer-to-peer connections without needing to traverse NAT or firewalls.
### What are STUN and TURN Servers?
- **STUN Server**: Helps clients discover their public IP address and port when they're behind a NAT (Network Address Translation) device (like a router).
This allows WebRTC to attempt direct peer-to-peer communication by providing the public-facing IP and port.
- **TURN Server**: Used as a fallback when direct peer-to-peer communication isn't possible due to strict NATs or firewalls blocking connections.
The TURN server relays media traffic between peers.
### Why are ICE Servers Important?
**ICE (Interactive Connectivity Establishment)** is a framework used by WebRTC to handle network traversal and NAT issues.
The `iceServers` configuration provides a list of **STUN** and **TURN** servers that WebRTC uses to find the best way to connect two peers.
### Example Configuration for ICE Servers
Heres how you can configure a basic `iceServers` object in WebRTC for testing purposes, using Google's public STUN server:
```javascript
const config = {
iceServers: [
{
urls: ["stun:stun.l.google.com:19302"], // Google's public STUN server
}
],
};
```
> For testing purposes, you can either use public **STUN** servers (like Google's) or set up your own **TURN** server.
If you're running your own TURN server, make sure to include your server URL, username, and credential in the configuration.
---
### 💡 Notes
- Ensure all dependencies are installed before running the server.
- Check the `.env` file for missing configurations.

View File

@@ -24,47 +24,27 @@
let connected = false
let peerConnection = null
const waitForIceGatheringComplete = async (pc, timeoutMs = 2000) => {
/*const waitForIceGatheringComplete = async (pc) => {
if (pc.iceGatheringState === 'complete') return;
console.log("Waiting for ICE gathering to complete. Current state:", pc.iceGatheringState);
return new Promise((resolve) => {
let timeoutId;
const checkState = () => {
console.log("icegatheringstatechange:", pc.iceGatheringState);
if (pc.iceGatheringState === 'complete') {
cleanup();
pc.removeEventListener('icegatheringstatechange', checkState);
resolve();
}
};
const onTimeout = () => {
console.warn(`ICE gathering timed out after ${timeoutMs} ms.`);
cleanup();
resolve();
};
const cleanup = () => {
pc.removeEventListener('icegatheringstatechange', checkState);
clearTimeout(timeoutId);
};
pc.addEventListener('icegatheringstatechange', checkState);
timeoutId = setTimeout(onTimeout, timeoutMs);
// Checking the state again to avoid any eventual race condition
checkState();
});
};
}*/
const createSmallWebRTCConnection = async (audioTrack) => {
const config = {
iceServers: [],
};
const pc = new RTCPeerConnection(config)
addPeerConnectionEventListeners(pc)
const pc = new RTCPeerConnection()
pc.ontrack = e => audioEl.srcObject = e.streams[0]
// SmallWebRTCTransport expects to receive both transceivers
pc.addTransceiver(audioTrack, { direction: 'sendrecv' })
pc.addTransceiver('video', { direction: 'sendrecv' })
await pc.setLocalDescription(await pc.createOffer())
await waitForIceGatheringComplete(pc)
//await waitForIceGatheringComplete(pc)
const offer = pc.localDescription
const response = await fetch('/api/offer', {
body: JSON.stringify({ sdp: offer.sdp, type: offer.type}),
@@ -77,37 +57,16 @@
}
const connect = async () => {
_onConnecting()
const audioStream = await navigator.mediaDevices.getUserMedia({audio: true})
peerConnection= await createSmallWebRTCConnection(audioStream.getAudioTracks()[0])
}
const addPeerConnectionEventListeners = (pc) => {
pc.oniceconnectionstatechange = () => {
console.log("oniceconnectionstatechange", pc?.iceConnectionState)
}
pc.onconnectionstatechange = () => {
console.log("onconnectionstatechange", pc?.connectionState)
let connectionState = pc?.connectionState
peerConnection.onconnectionstatechange = () => {
let connectionState = peerConnection?.connectionState
if (connectionState === 'connected') {
_onConnected()
} else if (connectionState === 'disconnected') {
_onDisconnected()
}
}
pc.onicecandidate = (event) => {
if (event.candidate) {
console.log("New ICE candidate:", event.candidate);
} else {
console.log("All ICE candidates have been sent.");
}
};
}
const _onConnecting = () => {
statusEl.textContent = "Connecting"
buttonEl.textContent = "Disconnect"
connected = true
}
const _onConnected = () => {

View File

@@ -17,7 +17,7 @@ from fastapi import BackgroundTasks, FastAPI
from fastapi.responses import FileResponse
from loguru import logger
from pipecat.transports.network.webrtc_connection import IceServer, SmallWebRTCConnection
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
# Load environment variables
load_dotenv(override=True)
@@ -28,13 +28,6 @@ app = FastAPI()
pcs_map: Dict[str, SmallWebRTCConnection] = {}
ice_servers = [
IceServer(
urls="stun:stun.l.google.com:19302",
)
]
@app.post("/api/offer")
async def offer(request: dict, background_tasks: BackgroundTasks):
pc_id = request.get("pc_id")
@@ -44,7 +37,7 @@ async def offer(request: dict, background_tasks: BackgroundTasks):
logger.info(f"Reusing existing connection for pc_id: {pc_id}")
await pipecat_connection.renegotiate(sdp=request["sdp"], type=request["type"])
else:
pipecat_connection = SmallWebRTCConnection(ice_servers)
pipecat_connection = SmallWebRTCConnection()
await pipecat_connection.initialize(sdp=request["sdp"], type=request["type"])
@pipecat_connection.event_handler("closed")

View File

@@ -1,18 +1,17 @@
import {
RTVIClientAudio,
RTVIClientVideo,
useRTVIClient,
useRTVIClientTransportState,
} from "@pipecat-ai/client-react";
import { RTVIProvider } from "./providers/RTVIProvider";
import { ConnectButton } from "./components/ConnectButton";
import { StatusDisplay } from "./components/StatusDisplay";
import { DebugDisplay } from "./components/DebugDisplay";
import "./App.css";
} from '@pipecat-ai/client-react';
import { RTVIProvider } from './providers/RTVIProvider';
import { ConnectButton } from './components/ConnectButton';
import { StatusDisplay } from './components/StatusDisplay';
import { DebugDisplay } from './components/DebugDisplay';
import './App.css';
function BotVideo() {
const transportState = useRTVIClientTransportState();
const isConnected = transportState !== "disconnected";
const isConnected = transportState !== 'disconnected';
return (
<div className="bot-container">
@@ -24,31 +23,11 @@ function BotVideo() {
}
function AppContent() {
const client = useRTVIClient();
return (
<div className="app">
<div className="status-bar">
<StatusDisplay />
<ConnectButton />
<div
className="controls"
onClick={async () => {
if (!client) {
console.error("RTVI client is not initialized");
return;
}
client.action({
service: "tts",
action: "say",
arguments: [
{ name: "text", value: "Hello, world!" },
{ name: "interrupt", value: false },
],
});
}}
>
<button className="connect-btn">Say something</button>
</div>
</div>
<div className="main-content">

View File

@@ -20,7 +20,6 @@ the conversation flow.
import asyncio
import os
import sys
from typing import Any, Dict
import aiohttp
from dotenv import load_dotenv
@@ -33,24 +32,15 @@ from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
Frame,
LLMMessagesAppendFrame,
OutputImageRawFrame,
SpriteFrame,
TTSSpeakFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.processors.frameworks.rtvi import (
ActionResult,
RTVIAction,
RTVIActionArgument,
RTVIObserver,
RTVIProcessor,
RTVIService,
)
from pipecat.processors.frameworks.rtvi import RTVIConfig, RTVIObserver, RTVIProcessor
from pipecat.services.elevenlabs.tts import ElevenLabsTTSService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.services.daily import DailyParams, DailyTransport
@@ -191,57 +181,8 @@ async def main():
#
# RTVI events for Pipecat client UI
#
rtvi = RTVIProcessor(config=RTVIConfig(config=[]))
rtvi_tts = RTVIService(
name="tts",
options=[],
)
async def action_tts_say_handler(
rtvi: RTVIProcessor, service: str, arguments: Dict[str, Any]
) -> ActionResult:
if "interrupt" in arguments and arguments["interrupt"]:
# interrupting breaks function handling
await rtvi.interrupt_bot()
if "text" in arguments:
save = arguments["save"] if "save" in arguments else False
frame = TTSSpeakFrame(text=arguments["text"])
await rtvi.push_frame(frame)
if save:
llm_frame = LLMMessagesAppendFrame(
messages=[{"role": "assistant", "content": arguments["text"]}]
)
await rtvi.push_frame(llm_frame)
return True
action_tts_say = RTVIAction(
service="tts",
action="say",
result="bool",
arguments=[
RTVIActionArgument(name="text", type="string"),
RTVIActionArgument(name="save_in_context", type="bool"),
],
handler=action_tts_say_handler,
)
async def action_tts_interrupt_handler(
rtvi: RTVIProcessor, service: str, arguments: Dict[str, Any]
) -> ActionResult:
await rtvi.interrupt_bot()
return True
action_tts_interrupt = RTVIAction(
service="tts", action="interrupt", result="bool", handler=action_tts_interrupt_handler
)
rtvi.register_service(rtvi_tts)
rtvi.register_action(action_tts_say)
rtvi.register_action(action_tts_interrupt)
pipeline = Pipeline(
[
transport.input(),

View File

@@ -176,6 +176,8 @@ class BaseSmartTurn(BaseTurnAnalyzer):
f"End of Turn complete due to stop_secs. Silence in ms: {self._silence_ms}"
)
state = EndOfTurnState.COMPLETE
except Exception as e:
logger.error(f"Error during prediction: {e}")
else:
logger.trace(f"params: {self._params}, stop_ms: {self._stop_ms}")

View File

@@ -40,7 +40,7 @@ class HttpSmartTurnAnalyzer(BaseSmartTurn):
async def _send_raw_request(self, data_bytes: bytes) -> Dict[str, Any]:
headers = {"Content-Type": "application/octet-stream"}
headers.update(self._headers)
logger.trace(f"Sending {len(data_bytes)} bytes as raw body to {self._url}...")
try:
timeout = aiohttp.ClientTimeout(total=self._params.stop_secs)
@@ -50,30 +50,23 @@ class HttpSmartTurnAnalyzer(BaseSmartTurn):
logger.trace("\n--- Response ---")
logger.trace(f"Status Code: {response.status}")
# Check if successful
if response.status != 200:
if response.status == 200:
try:
json_data = await response.json()
logger.trace("Response JSON:")
logger.trace(json_data)
return json_data
except aiohttp.ContentTypeError:
# Non-JSON response
text = await response.text()
logger.trace("Response Content (non-JSON):")
logger.trace(text)
raise Exception(f"Non-JSON response: {text}")
else:
error_text = await response.text()
logger.trace("Response Content (Error):")
logger.trace(error_text)
if response.status == 500:
logger.warning(f"Smart turn service returned 500 error: {error_text}")
raise Exception(f"Server returned HTTP 500: {error_text}")
else:
response.raise_for_status()
# Process successful response
try:
json_data = await response.json()
logger.trace("Response JSON:")
logger.trace(json_data)
return json_data
except aiohttp.ContentTypeError:
# Non-JSON response
text = await response.text()
logger.trace("Response Content (non-JSON):")
logger.trace(text)
raise Exception(f"Non-JSON response: {text}")
response.raise_for_status()
except asyncio.TimeoutError:
logger.error(f"Request timed out after {self._params.stop_secs} seconds")
@@ -83,14 +76,5 @@ class HttpSmartTurnAnalyzer(BaseSmartTurn):
raise Exception("Failed to send raw request to Daily Smart Turn.")
async def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]:
try:
serialized_array = self._serialize_array(audio_array)
return await self._send_raw_request(serialized_array)
except Exception as e:
logger.error(f"Smart turn prediction failed: {str(e)}")
# Return an incomplete prediction when a failure occurs
return {
"prediction": 0,
"probability": 0.0,
"metrics": {"inference_time": 0.0, "total_time": 0.0},
}
serialized_array = self._serialize_array(audio_array)
return await self._send_raw_request(serialized_array)

View File

@@ -24,12 +24,10 @@ from pipecat.frames.frames import (
FunctionCallInProgressFrame,
FunctionCallResultFrame,
InputAudioRawFrame,
InterimTranscriptionFrame,
StartFrame,
StartInterruptionFrame,
StopInterruptionFrame,
STTMuteFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
@@ -177,8 +175,6 @@ class STTMuteFilter(FrameProcessor):
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
InputAudioRawFrame,
InterimTranscriptionFrame,
TranscriptionFrame,
),
):
# Only pass VAD-related frames when not muted

View File

@@ -61,9 +61,6 @@ from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContextFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.llm_service import (
FunctionCallParams, # TODO(aleix): we shouldn't import `services` from `processors`
)
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import BaseOutputTransport
from pipecat.transports.base_transport import BaseTransport
@@ -395,32 +392,6 @@ class RTVIServerMessageFrame(SystemFrame):
return f"{self.name}(data: {self.data})"
@dataclass
class RTVIObserverParams:
"""
Parameters for configuring RTVI Observer behavior.
Attributes:
bot_llm_enabled (bool): Indicates if the bot's LLM messages should be sent.
bot_tts_enabled (bool): Indicates if the bot's TTS messages should be sent.
bot_speaking_enabled (bool): Indicates if the bot's started/stopped speaking messages should be sent.
user_llm_enabled (bool): Indicates if the user's LLM input messages should be sent.
user_speaking_enabled (bool): Indicates if the user's started/stopped speaking messages should be sent.
user_transcription_enabled (bool): Indicates if user's transcription messages should be sent.
metrics_enabled (bool): Indicates if metrics messages should be sent.
errors_enabled (bool): Indicates if errors messages should be sent.
"""
bot_llm_enabled: bool = True
bot_tts_enabled: bool = True
bot_speaking_enabled: bool = True
user_llm_enabled: bool = True
user_speaking_enabled: bool = True
user_transcription_enabled: bool = True
metrics_enabled: bool = True
errors_enabled: bool = True
class RTVIObserver(BaseObserver):
"""Pipeline frame observer for RTVI server message handling.
@@ -433,17 +404,14 @@ class RTVIObserver(BaseObserver):
are handled by the RTVIProcessor.
Args:
rtvi (RTVIProcessor): The RTVI processor to push frames to.
params (RTVIObserverParams): Settings to enable/disable specific messages.
rtvi (FrameProcessor): The RTVI processor to push frames to.
"""
def __init__(self, rtvi: "RTVIProcessor", *, params: RTVIObserverParams = RTVIObserverParams()):
def __init__(self, rtvi: FrameProcessor):
super().__init__()
self._rtvi = rtvi
self._params = params
self._bot_transcription = ""
self._frames_seen = set()
rtvi.set_errors_enabled(self._params.errors_enabled)
async def on_push_frame(
self,
@@ -470,41 +438,35 @@ class RTVIObserver(BaseObserver):
# again the next time we see the frame.
mark_as_seen = True
if (
isinstance(frame, (UserStartedSpeakingFrame, UserStoppedSpeakingFrame))
and self._params.user_speaking_enabled
):
if isinstance(frame, (UserStartedSpeakingFrame, UserStoppedSpeakingFrame)):
await self._handle_interruptions(frame)
elif (
isinstance(frame, (BotStartedSpeakingFrame, BotStoppedSpeakingFrame))
and (direction == FrameDirection.UPSTREAM)
and self._params.bot_speaking_enabled
elif isinstance(frame, (BotStartedSpeakingFrame, BotStoppedSpeakingFrame)) and (
direction == FrameDirection.UPSTREAM
):
await self._handle_bot_speaking(frame)
elif (
isinstance(frame, (TranscriptionFrame, InterimTranscriptionFrame))
and self._params.user_transcription_enabled
):
elif isinstance(frame, (TranscriptionFrame, InterimTranscriptionFrame)):
await self._handle_user_transcriptions(frame)
elif isinstance(frame, OpenAILLMContextFrame) and self._params.user_llm_enabled:
elif isinstance(frame, OpenAILLMContextFrame):
await self._handle_context(frame)
elif isinstance(frame, LLMFullResponseStartFrame) and self._params.bot_llm_enabled:
elif isinstance(frame, UserStartedSpeakingFrame):
await self._push_bot_transcription()
elif isinstance(frame, LLMFullResponseStartFrame):
await self.push_transport_message_urgent(RTVIBotLLMStartedMessage())
elif isinstance(frame, LLMFullResponseEndFrame) and self._params.bot_llm_enabled:
elif isinstance(frame, LLMFullResponseEndFrame):
await self.push_transport_message_urgent(RTVIBotLLMStoppedMessage())
elif isinstance(frame, LLMTextFrame) and self._params.bot_llm_enabled:
elif isinstance(frame, LLMTextFrame):
await self._handle_llm_text_frame(frame)
elif isinstance(frame, TTSStartedFrame) and self._params.bot_tts_enabled:
elif isinstance(frame, TTSStartedFrame):
await self.push_transport_message_urgent(RTVIBotTTSStartedMessage())
elif isinstance(frame, TTSStoppedFrame) and self._params.bot_tts_enabled:
elif isinstance(frame, TTSStoppedFrame):
await self.push_transport_message_urgent(RTVIBotTTSStoppedMessage())
elif isinstance(frame, TTSTextFrame) and self._params.bot_tts_enabled:
elif isinstance(frame, TTSTextFrame):
if isinstance(src, BaseOutputTransport):
message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=frame.text))
await self.push_transport_message_urgent(message)
else:
mark_as_seen = False
elif isinstance(frame, MetricsFrame) and self._params.metrics_enabled:
elif isinstance(frame, MetricsFrame):
await self._handle_metrics(frame)
elif isinstance(frame, RTVIServerMessageFrame):
message = RTVIServerMessage(data=frame.data)
@@ -647,7 +609,6 @@ class RTVIProcessor(FrameProcessor):
self._bot_ready = False
self._client_ready = False
self._client_ready_id = ""
self._errors_enabled = True
self._registered_actions: Dict[str, RTVIAction] = {}
self._registered_services: Dict[str, RTVIService] = {}
@@ -687,23 +648,26 @@ class RTVIProcessor(FrameProcessor):
await self._update_config(self._config, False)
await self._send_bot_ready()
def set_errors_enabled(self, enabled: bool):
self._errors_enabled = enabled
async def interrupt_bot(self):
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
async def send_error(self, error: str):
await self._send_error_frame(ErrorFrame(error=error))
message = RTVIError(data=RTVIErrorData(error=error, fatal=False))
await self._push_transport_message(message)
async def handle_message(self, message: RTVIMessage):
await self._message_queue.put(message)
async def handle_function_call(self, params: FunctionCallParams):
async def handle_function_call(
self,
function_name: str,
tool_call_id: str,
arguments: Mapping[str, Any],
):
fn = RTVILLMFunctionCallMessageData(
function_name=params.function_name,
tool_call_id=params.tool_call_id,
arguments=params.arguments,
function_name=function_name,
tool_call_id=tool_call_id,
arguments=arguments,
)
message = RTVILLMFunctionCallMessage(data=fn)
await self._push_transport_message(message, exclude_none=False)
@@ -953,14 +917,12 @@ class RTVIProcessor(FrameProcessor):
await self._push_transport_message(message)
async def _send_error_frame(self, frame: ErrorFrame):
if self._errors_enabled:
message = RTVIError(data=RTVIErrorData(error=frame.error, fatal=frame.fatal))
await self._push_transport_message(message)
message = RTVIError(data=RTVIErrorData(error=frame.error, fatal=frame.fatal))
await self._push_transport_message(message)
async def _send_error_response(self, id: str, error: str):
if self._errors_enabled:
message = RTVIErrorResponse(id=id, data=RTVIErrorResponseData(error=error))
await self._push_transport_message(message)
message = RTVIErrorResponse(id=id, data=RTVIErrorResponseData(error=error))
await self._push_transport_message(message)
def _action_id(self, service: str, action: str) -> str:
return f"{service}:{action}"

View File

@@ -93,55 +93,49 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
"""Aggregates and emits text fragments as a transcript message.
This method uses a heuristic to automatically detect whether text fragments
contain embedded spacing (spaces at the beginning or end of fragments) or not,
and applies the appropriate joining strategy. It handles fragments from different
TTS services with different formatting patterns.
use pre-spacing (spaces at the beginning of fragments) or not, and applies
the appropriate joining strategy. It handles fragments from different TTS
services with different formatting patterns.
Examples:
Fragments with embedded spacing (concatenated):
Pre-spaced fragments (concatenated):
```
TTSTextFrame: ["Hello"]
TTSTextFrame: [" there"] # Leading space
TTSTextFrame: [" there"]
TTSTextFrame: ["!"]
TTSTextFrame: [" How"] # Leading space
TTSTextFrame: [" How"]
TTSTextFrame: ["'s"]
TTSTextFrame: [" it"] # Leading space
TTSTextFrame: [" it"]
TTSTextFrame: [" going"]
TTSTextFrame: ["?"]
```
Result: "Hello there! How's it"
Result: "Hello there! How's it going?"
Fragments with trailing spaces (concatenated):
```
TTSTextFrame: ["Hel"]
TTSTextFrame: ["lo "] # Trailing space
TTSTextFrame: ["to "] # Trailing space
TTSTextFrame: ["you"]
```
Result: "Hello to you"
Word-by-word fragments without spacing (joined with spaces):
Word-by-word fragments (joined with spaces):
```
TTSTextFrame: ["Hello"]
TTSTextFrame: ["there"]
TTSTextFrame: ["how"]
TTSTextFrame: ["are"]
TTSTextFrame: ["you"]
TTSTextFrame: ["there!"]
TTSTextFrame: ["How"]
TTSTextFrame: ["is"]
TTSTextFrame: ["it"]
TTSTextFrame: ["going?"]
```
Result: "Hello there how are you"
Result: "Hello there! How is it going?"
"""
if self._current_text_parts and self._aggregation_start_time:
has_leading_spaces = any(
part and part[0].isspace() for part in self._current_text_parts[1:]
)
has_trailing_spaces = any(
part and part[-1].isspace() for part in self._current_text_parts[:-1]
)
# Heuristic to detect pre-spaced fragments
uses_prespacing = False
if len(self._current_text_parts) > 1:
# Check if any fragment after the first one starts with whitespace
has_spaced_parts = any(
part and part[0].isspace() for part in self._current_text_parts[1:]
)
if has_spaced_parts:
uses_prespacing = True
# If there are embedded spaces in the fragments, use direct concatenation
contains_spacing_between_fragments = has_leading_spaces or has_trailing_spaces
# Apply corresponding joining method
if contains_spacing_between_fragments:
# Fragments already have spacing - just concatenate
# Apply appropriate joining method
if uses_prespacing:
# Pre-spaced fragments - just concatenate
content = "".join(self._current_text_parts)
else:
# Word-by-word fragments - join with spaces

View File

@@ -193,10 +193,3 @@ def parse_server_event(str):
except Exception as e:
print(f"Error parsing server event: {e}")
return None
class ContextWindowCompressionConfig(BaseModel):
"""Configuration for context window compression."""
sliding_window: Optional[bool] = Field(default=True)
trigger_tokens: Optional[int] = Field(default=None)

View File

@@ -223,16 +223,6 @@ class GeminiMultimodalLiveUserContextAggregator(OpenAIUserContextAggregator):
class GeminiMultimodalLiveAssistantContextAggregator(OpenAIAssistantContextAggregator):
# The LLMAssistantContextAggregator uses TextFrames to aggregate the LLM output,
# but the GeminiMultimodalLiveAssistantContextAggregator pushes LLMTextFrames and TTSTextFrames. We
# need to override this proces_frame for LLMTextFrame, so that only the TTSTextFrames
# are process. This ensures that the context gets only one set of messages.
# GeminiMultimodalLiveLLMService also pushes TranscriptionFrames, so we need to
# ignore pushing those as well, as they're also TextFrames.
async def process_frame(self, frame: Frame, direction: FrameDirection):
if not isinstance(frame, (LLMTextFrame, TranscriptionFrame)):
await super().process_frame(frame, direction)
async def handle_user_image_frame(self, frame: UserImageRawFrame):
# We don't want to store any images in the context. Revisit this later
# when the API evolves.
@@ -275,15 +265,6 @@ class GeminiVADParams(BaseModel):
silence_duration_ms: Optional[int] = Field(default=None)
class ContextWindowCompressionParams(BaseModel):
"""Parameters for context window compression."""
enabled: bool = Field(default=False)
trigger_tokens: Optional[int] = Field(
default=None
) # None = use default (80% of context window)
class InputParams(BaseModel):
frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
max_tokens: Optional[int] = Field(default=4096, ge=1)
@@ -299,7 +280,6 @@ class InputParams(BaseModel):
default=GeminiMediaResolution.UNSPECIFIED
)
vad: Optional[GeminiVADParams] = Field(default=None)
context_window_compression: Optional[ContextWindowCompressionParams] = Field(default=None)
extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
@@ -354,6 +334,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
self._bot_is_speaking = False
self._user_audio_buffer = bytearray()
self._bot_audio_buffer = bytearray()
self._bot_text_buffer = ""
self._sample_rate = 24000
@@ -374,9 +355,6 @@ class GeminiMultimodalLiveLLMService(LLMService):
"language": self._language_code,
"media_resolution": params.media_resolution,
"vad": params.vad,
"context_window_compression": params.context_window_compression.model_dump()
if params.context_window_compression
else {},
"extra": params.extra if isinstance(params.extra, dict) else {},
}
@@ -436,9 +414,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
#
async def _handle_interruption(self):
self._bot_is_speaking = False
await self.push_frame(TTSStoppedFrame())
await self.push_frame(LLMFullResponseEndFrame())
pass
async def _handle_user_started_speaking(self, frame):
self._user_is_speaking = True
@@ -461,12 +437,10 @@ class GeminiMultimodalLiveLLMService(LLMService):
text = await self._transcribe_audio(audio, context)
if not text:
return
# Sometimes the transcription contains newlines; we want to remove them.
cleaned_text = text.rstrip("\n")
logger.debug(f"[Transcription:user] {cleaned_text}")
context.add_message({"role": "user", "content": [{"type": "text", "text": cleaned_text}]})
logger.debug(f"[Transcription:user] {text}")
context.add_message({"role": "user", "content": [{"type": "text", "text": text}]})
await self.push_frame(
TranscriptionFrame(text=cleaned_text, user_id="user", timestamp=time_now_iso8601())
TranscriptionFrame(text=text, user_id="user", timestamp=time_now_iso8601())
)
async def _transcribe_audio(self, audio, context):
@@ -587,21 +561,6 @@ class GeminiMultimodalLiveLLMService(LLMService):
}
}
# Add context window compression if enabled
if self._settings.get("context_window_compression", {}).get("enabled", False):
compression_config = {}
# Add sliding window (always true if compression is enabled)
compression_config["sliding_window"] = {}
# Add trigger_tokens if specified
trigger_tokens = self._settings.get("context_window_compression", {}).get(
"trigger_tokens"
)
if trigger_tokens is not None:
compression_config["trigger_tokens"] = trigger_tokens
config_data["setup"]["context_window_compression"] = compression_config
# Add VAD configuration if provided
if self._settings.get("vad"):
vad_config = {}
@@ -852,6 +811,14 @@ class GeminiMultimodalLiveLLMService(LLMService):
if not part:
return
text = part.text
if text:
if not self._bot_text_buffer:
await self.push_frame(LLMFullResponseStartFrame())
self._bot_text_buffer += text
await self.push_frame(LLMTextFrame(text=text))
inline_data = part.inlineData
if not inline_data:
return
@@ -866,7 +833,6 @@ class GeminiMultimodalLiveLLMService(LLMService):
if not self._bot_is_speaking:
self._bot_is_speaking = True
await self.push_frame(TTSStartedFrame())
await self.push_frame(LLMFullResponseStartFrame())
self._bot_audio_buffer.extend(audio)
frame = TTSAudioRawFrame(
@@ -892,20 +858,24 @@ class GeminiMultimodalLiveLLMService(LLMService):
async def _handle_evt_turn_complete(self, evt):
self._bot_is_speaking = False
text = self._bot_text_buffer
self._bot_text_buffer = ""
if text:
await self.push_frame(LLMFullResponseEndFrame())
await self.push_frame(TTSStoppedFrame())
await self.push_frame(LLMFullResponseEndFrame())
async def _handle_evt_output_transcription(self, evt):
if not evt.serverContent.outputTranscription:
return
text = evt.serverContent.outputTranscription.text
if not text:
return
await self.push_frame(LLMTextFrame(text=text))
await self.push_frame(TTSTextFrame(text=text))
if text:
await self.push_frame(LLMFullResponseStartFrame())
await self.push_frame(LLMTextFrame(text=text))
await self.push_frame(TTSTextFrame(text=text))
await self.push_frame(LLMFullResponseEndFrame())
def create_context_aggregator(
self,
@@ -936,6 +906,6 @@ class GeminiMultimodalLiveLLMService(LLMService):
GeminiMultimodalLiveContext.upgrade(context)
user = GeminiMultimodalLiveUserContextAggregator(context, params=user_params)
assistant_params.expect_stripped_words = False
assistant_params.expect_stripped_words = True
assistant = GeminiMultimodalLiveAssistantContextAggregator(context, params=assistant_params)
return GeminiMultimodalLiveContextAggregatorPair(_user=user, _assistant=assistant)

View File

@@ -14,7 +14,6 @@ from pipecat.frames.frames import (
FunctionCallResultFrame,
LLMMessagesUpdateFrame,
LLMSetToolsFrame,
LLMTextFrame,
)
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.frame_processor import FrameDirection
@@ -171,14 +170,6 @@ class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator):
# The LLMAssistantContextAggregator uses TextFrames to aggregate the LLM output,
# but the OpenAIRealtimeLLMService pushes LLMTextFrames and TTSTextFrames. We
# need to override this proces_frame for LLMTextFrame, so that only the TTSTextFrames
# are process. This ensures that the context gets only one set of messages.
async def process_frame(self, frame: Frame, direction: FrameDirection):
if not isinstance(frame, LLMTextFrame):
await super().process_frame(frame, direction)
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
await super().handle_function_call_result(frame)

View File

@@ -74,7 +74,7 @@ class RimeTTSService(AudioContextWordTTSService):
*,
api_key: str,
voice_id: str,
url: str = "wss://users.rime.ai/ws2",
url: str = "wss://users-ws.rime.ai/ws2",
model: str = "mistv2",
sample_rate: Optional[int] = None,
params: InputParams = InputParams(),

View File

@@ -7,7 +7,7 @@
import asyncio
import json
import time
from typing import Any, List, Literal, Optional, Union
from typing import Any, Literal, Optional, Union
from av.frame import Frame
from loguru import logger
@@ -87,21 +87,13 @@ class SmallWebRTCTrack:
return getattr(self._track, name)
# Alias so we don't need to expose RTCIceServer
IceServer = RTCIceServer
class SmallWebRTCConnection(BaseObject):
def __init__(self, ice_servers: Optional[Union[List[str], List[IceServer]]] = None):
def __init__(self, ice_servers=None):
super().__init__()
if not ice_servers:
self.ice_servers: List[IceServer] = []
elif all(isinstance(s, IceServer) for s in ice_servers):
self.ice_servers = ice_servers
elif all(isinstance(s, str) for s in ice_servers):
self.ice_servers = [IceServer(urls=s) for s in ice_servers]
if ice_servers:
self.ice_servers = [RTCIceServer(urls=server) for server in ice_servers]
else:
raise TypeError("ice_servers must be either List[str] or List[RTCIceServer]")
self.ice_servers = []
self._connect_invoked = False
self._track_map = {}
self._track_getters = {

View File

@@ -12,9 +12,7 @@ from pipecat.frames.frames import (
FunctionCallInProgressFrame,
FunctionCallResultFrame,
InputAudioRawFrame,
InterimTranscriptionFrame,
STTMuteFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
@@ -102,45 +100,6 @@ class TestSTTMuteFilter(unittest.IsolatedAsyncioTestCase):
expected_down_frames=expected_returned_frames,
)
async def test_transcription_frames_with_always_strategy(self):
filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.ALWAYS}))
frames_to_send = [
# Bot speaking - should mute
BotStartedSpeakingFrame(),
SleepFrame(sleep=0.1), # Wait for StartedSpeaking to process
InterimTranscriptionFrame(
user_id="user1", text="This should be suppressed", timestamp="1234567890"
),
TranscriptionFrame(
user_id="user1", text="This should be suppressed", timestamp="1234567890"
),
SleepFrame(sleep=0.1), # Wait for transcription frames to queue
BotStoppedSpeakingFrame(),
# Bot not speaking - should pass through
InterimTranscriptionFrame(
user_id="user1", text="This should pass", timestamp="1234567891"
),
TranscriptionFrame(
user_id="user1", text="This should pass through", timestamp="1234567891"
),
]
expected_returned_frames = [
BotStartedSpeakingFrame,
STTMuteFrame, # mute=True
BotStoppedSpeakingFrame,
STTMuteFrame, # mute=False
InterimTranscriptionFrame, # Only passes through after bot stops speaking
TranscriptionFrame, # Only passes through after bot stops speaking
]
await run_test(
filter,
frames_to_send=frames_to_send,
expected_down_frames=expected_returned_frames,
)
# TODO: Revisit once we figure out how to test SystemFrames and DataFrames
# async def test_function_call_strategy(self):
# filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.FUNCTION_CALL}))