Compare commits
46 Commits
hush/smart
...
hush/rtviS
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fd6379cb6a | ||
|
|
7618d7511a | ||
|
|
9ca7bad978 | ||
|
|
10289c1f1c | ||
|
|
213d5d6abc | ||
|
|
01f37f769d | ||
|
|
52b393537a | ||
|
|
a6a4d3d71f | ||
|
|
c52de0f5de | ||
|
|
a1e1255f16 | ||
|
|
c4f758725e | ||
|
|
7bc9a78ce6 | ||
|
|
f8be71b32c | ||
|
|
957fa5546d | ||
|
|
039cb8fcae | ||
|
|
8e05f2f1a1 | ||
|
|
8467aa1ed3 | ||
|
|
9c5878af3d | ||
|
|
ef29800fe9 | ||
|
|
7e09933070 | ||
|
|
82a9d7f992 | ||
|
|
facbebb15f | ||
|
|
2ba60fc41f | ||
|
|
685f951ae2 | ||
|
|
27d4c927a8 | ||
|
|
20a59e8c56 | ||
|
|
d9a0a93667 | ||
|
|
154d5d1859 | ||
|
|
a192217256 | ||
|
|
10cdc47e05 | ||
|
|
2b4d41a548 | ||
|
|
962f8062a5 | ||
|
|
d80d385b2f | ||
|
|
b347ca472f | ||
|
|
c3c4952abf | ||
|
|
f369ab4c1a | ||
|
|
62b41c6789 | ||
|
|
2872bc7902 | ||
|
|
9658b75a10 | ||
|
|
63de9039e6 | ||
|
|
9352396d7e | ||
|
|
d1ab1d38b7 | ||
|
|
080f70d91c | ||
|
|
ebed1fc6ea | ||
|
|
6821b1cdab | ||
|
|
144ae9b611 |
34
CHANGELOG.md
34
CHANGELOG.md
@@ -9,6 +9,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- Added `RTVIObserverParams` which allows you to configure what RTVI messages
|
||||
are sent to the clients.
|
||||
|
||||
- Added a `context_window_compression` InputParam to
|
||||
`GeminiMultimodalLiveLLMService` which allows you to enable a sliding context
|
||||
window for the session as well as set the token limit of the sliding window.
|
||||
|
||||
- Updated `SmallWebRTCConnection` to support `ice_servers` with credentials.
|
||||
|
||||
- Added `VADUserStartedSpeakingFrame` and `VADUserStoppedSpeakingFrame`,
|
||||
indicating when the VAD detected the user to start and stop speaking. These
|
||||
events are helpful when using smart turn detection, as the user's stop time
|
||||
@@ -23,10 +32,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Added `MCPClient`; a way to connect to MCP servers and use the MCP servers'
|
||||
tools.
|
||||
|
||||
- Added `Mem0 OSS`, along with Mem0 cloud support now the OSS version is also available.
|
||||
- Added `Mem0 OSS`, along with Mem0 cloud support now the OSS version is also
|
||||
available.
|
||||
|
||||
### Changed
|
||||
|
||||
- The `STTMuteFilter` now mutes `InterimTranscriptionFrame` and
|
||||
`TranscriptionFrame` which allows the `STTMuteFilter` to be used in
|
||||
conjunction with transports that generate transcripts, e.g. `DailyTransport`.
|
||||
|
||||
- Function calls now receive a single parameter `FunctionCallParams` instead of
|
||||
`(function_name, tool_call_id, args, llm, context, result_callback)` which is
|
||||
now deprecated.
|
||||
@@ -58,9 +72,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Deprecated
|
||||
|
||||
- Function calls with parameters `(function_name, tool_call_id, args, llm,
|
||||
context, result_callback)` are deprectated, use a single `FunctionCallParams`
|
||||
parameter instead.
|
||||
- Function calls with parameters
|
||||
`(function_name, tool_call_id, args, llm, context, result_callback)` are
|
||||
deprectated, use a single `FunctionCallParams` parameter instead.
|
||||
|
||||
- `TransportParams.camera_*` parameters are now deprecated, use
|
||||
`TransportParams.video_*` instead.
|
||||
@@ -73,6 +87,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue with `GeminiMultimodalLiveLLMService` where the context
|
||||
contained tokens instead of words.
|
||||
|
||||
- Fixed an issue with HTTP Smart Turn handling, where the service returns a 500
|
||||
error. Previously, this would cause an unhandled exception. Now, a 500 error
|
||||
is treated as an incomplete response.
|
||||
|
||||
- Fixed a TTS services issue that could cause assistant output not to be
|
||||
aggregated to the context when also using `TTSSpeakFrame`s.
|
||||
|
||||
@@ -190,8 +211,9 @@ https://en.wikipedia.org/wiki/Saint_George%27s_Day_in_Catalonia
|
||||
- Fixed an issue in `SmallWebRTCTransport` where an error was thrown if the
|
||||
client did not create a video transceiver.
|
||||
|
||||
- Fixed an issue where LLM input parameters were not working and applied correctly in `GoogleVertexLLMService`, causing
|
||||
unexpected behavior during inference.
|
||||
- Fixed an issue where LLM input parameters were not working and applied
|
||||
correctly in `GoogleVertexLLMService`, causing unexpected behavior during
|
||||
inference.
|
||||
|
||||
### Other
|
||||
|
||||
|
||||
@@ -50,7 +50,6 @@ autodoc_mock_imports = [
|
||||
"pyht.protos",
|
||||
"pyht.protos.api_pb2",
|
||||
"pipecat_ai_playht", # PlayHT wrapper
|
||||
"vllm",
|
||||
"aiortc",
|
||||
"aiortc.mediastreams",
|
||||
"cv2",
|
||||
@@ -76,7 +75,6 @@ autodoc_mock_imports = [
|
||||
"openpipe",
|
||||
"simli",
|
||||
"soundfile",
|
||||
# Existing mocks
|
||||
"pipecat_ai_krisp",
|
||||
"pyaudio",
|
||||
"_tkinter",
|
||||
@@ -87,6 +85,66 @@ autodoc_mock_imports = [
|
||||
"pydantic.Field",
|
||||
"pydantic._internal._model_construction",
|
||||
"pydantic._internal._fields",
|
||||
# Moondream dependencies
|
||||
"torch",
|
||||
"transformers",
|
||||
"intel_extension_for_pytorch",
|
||||
# Ultravox dependencies
|
||||
"huggingface_hub",
|
||||
"vllm",
|
||||
"vllm.engine.arg_utils",
|
||||
"transformers.AutoTokenizer",
|
||||
# Langchain dependencies
|
||||
"langchain_core",
|
||||
"langchain_core.messages",
|
||||
"langchain_core.runnables",
|
||||
"langchain_core.messages.AIMessageChunk",
|
||||
"langchain_core.runnables.Runnable",
|
||||
# LiveKit dependencies
|
||||
"livekit",
|
||||
"livekit.rtc",
|
||||
"livekit_api",
|
||||
"livekit_protocol",
|
||||
"tenacity",
|
||||
"tenacity.retry",
|
||||
"tenacity.stop_after_attempt",
|
||||
"tenacity.wait_exponential",
|
||||
"rtc",
|
||||
"rtc.Room",
|
||||
"rtc.RoomOptions",
|
||||
"rtc.AudioSource",
|
||||
"rtc.LocalAudioTrack",
|
||||
"rtc.TrackPublishOptions",
|
||||
"rtc.TrackSource",
|
||||
"rtc.AudioStream",
|
||||
"rtc.AudioFrameEvent",
|
||||
"rtc.AudioFrame",
|
||||
"rtc.Track",
|
||||
"rtc.TrackKind",
|
||||
"rtc.RemoteParticipant",
|
||||
"rtc.RemoteTrackPublication",
|
||||
"rtc.DataPacket",
|
||||
# Riva dependencies
|
||||
"riva",
|
||||
"riva.client",
|
||||
"riva.client.Auth",
|
||||
"riva.client.ASRService",
|
||||
"riva.client.StreamingRecognitionConfig",
|
||||
"riva.client.RecognitionConfig",
|
||||
"riva.client.AudioEncoding",
|
||||
"riva.client.proto.riva_tts_pb2",
|
||||
"riva.client.SpeechSynthesisService",
|
||||
# Local CoreML Smart Turn dependencies
|
||||
"coremltools",
|
||||
"coremltools.models",
|
||||
"coremltools.models.MLModel",
|
||||
"torch",
|
||||
"torch.nn",
|
||||
"torch.nn.functional",
|
||||
"transformers",
|
||||
"transformers.AutoFeatureExtractor",
|
||||
# Also add specific classes that are imported
|
||||
"AutoFeatureExtractor",
|
||||
]
|
||||
|
||||
# HTML output settings
|
||||
@@ -118,12 +176,25 @@ def verify_modules():
|
||||
},
|
||||
}
|
||||
|
||||
# Skip importing modules that are in autodoc_mock_imports
|
||||
skipped_modules = set(autodoc_mock_imports)
|
||||
|
||||
missing = []
|
||||
for category, modules in required_modules.items():
|
||||
if isinstance(modules, dict):
|
||||
# Handle nested structure
|
||||
for subcategory, submodules in modules.items():
|
||||
for module in submodules:
|
||||
# Check if module is in autodoc_mock_imports
|
||||
if (
|
||||
f"pipecat.{category}.{subcategory}.{module}" in skipped_modules
|
||||
or module in skipped_modules
|
||||
):
|
||||
logger.info(
|
||||
f"Skipping import of mocked module: pipecat.{category}.{subcategory}.{module}"
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
__import__(f"pipecat.{category}.{subcategory}.{module}")
|
||||
logger.info(
|
||||
@@ -137,6 +208,11 @@ def verify_modules():
|
||||
else:
|
||||
# Handle flat structure
|
||||
for module in modules:
|
||||
# Check if module is in autodoc_mock_imports
|
||||
if f"pipecat.{category}.{module}" in skipped_modules or module in skipped_modules:
|
||||
logger.info(f"Skipping import of mocked module: pipecat.{category}.{module}")
|
||||
continue
|
||||
|
||||
try:
|
||||
__import__(f"pipecat.{category}.{module}")
|
||||
logger.info(f"Successfully imported pipecat.{category}.{module}")
|
||||
|
||||
@@ -26,20 +26,23 @@ pipecat-ai[grok]
|
||||
pipecat-ai[groq]
|
||||
# pipecat-ai[krisp] # Mocked
|
||||
pipecat-ai[koala]
|
||||
pipecat-ai[langchain]
|
||||
pipecat-ai[livekit]
|
||||
# pipecat-ai[langchain] # Mocked
|
||||
# pipecat-ai[livekit] # Mocked
|
||||
pipecat-ai[lmnt]
|
||||
pipecat-ai[local]
|
||||
# pipecat-ai[local-smart-turn] # Mocked
|
||||
# pipecat-ai[mem0] # Mocked
|
||||
# pipecat-ai[mlx-whisper] # Mocked
|
||||
pipecat-ai[moondream]
|
||||
# pipecat-ai[moondream] # Mocked
|
||||
pipecat-ai[nim]
|
||||
# pipecat-ai[neuphonic] # Mocked
|
||||
pipecat-ai[noisereduce]
|
||||
pipecat-ai[openai]
|
||||
# pipecat-ai[openpipe]
|
||||
# pipecat-ai[playht] # Mocked due to grpcio conflict with riva
|
||||
pipecat-ai[riva]
|
||||
pipecat-ai[qwen]
|
||||
pipecat-ai[remote-smart-turn]
|
||||
# pipecat-ai[riva] # Mocked
|
||||
pipecat-ai[silero]
|
||||
pipecat-ai[simli]
|
||||
pipecat-ai[soundfile]
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
python-dotenv==1.0.1
|
||||
modal==0.71.3
|
||||
pipecat-ai[daily,silero,cartesia,openai]==0.0.52
|
||||
pipecat-ai[daily,silero,cartesia,openai]
|
||||
fastapi==0.115.6
|
||||
aiohttp==3.11.11
|
||||
|
||||
@@ -89,6 +89,7 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
|
||||
llm = GeminiMultimodalLiveLLMService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
system_instruction=system_instruction,
|
||||
transcribe_user_audio=True,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from fastapi.responses import RedirectResponse
|
||||
from loguru import logger
|
||||
from pipecat_ai_small_webrtc_prebuilt.frontend import SmallWebRTCPrebuiltUI
|
||||
|
||||
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||
from pipecat.transports.network.webrtc_connection import IceServer, SmallWebRTCConnection
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv(override=True)
|
||||
@@ -30,7 +30,11 @@ app = FastAPI()
|
||||
# Store connections by pc_id
|
||||
pcs_map: Dict[str, SmallWebRTCConnection] = {}
|
||||
|
||||
ice_servers = ["stun:stun.l.google.com:19302"]
|
||||
ice_servers = [
|
||||
IceServer(
|
||||
urls="stun:stun.l.google.com:19302",
|
||||
)
|
||||
]
|
||||
|
||||
# Mount the frontend at /
|
||||
app.mount("/client", SmallWebRTCPrebuiltUI)
|
||||
|
||||
@@ -18,7 +18,7 @@ from fastapi.responses import RedirectResponse
|
||||
from loguru import logger
|
||||
from pipecat_ai_small_webrtc_prebuilt.frontend import SmallWebRTCPrebuiltUI
|
||||
|
||||
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||
from pipecat.transports.network.webrtc_connection import IceServer, SmallWebRTCConnection
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv(override=True)
|
||||
@@ -28,7 +28,11 @@ app = FastAPI()
|
||||
# Store connections by pc_id
|
||||
pcs_map: Dict[str, SmallWebRTCConnection] = {}
|
||||
|
||||
ice_servers = ["stun:stun.l.google.com:19302"]
|
||||
ice_servers = [
|
||||
IceServer(
|
||||
urls="stun:stun.l.google.com:19302",
|
||||
)
|
||||
]
|
||||
|
||||
# Mount the frontend at /
|
||||
app.mount("/prebuilt", SmallWebRTCPrebuiltUI)
|
||||
|
||||
@@ -18,7 +18,7 @@ from fastapi.responses import RedirectResponse
|
||||
from loguru import logger
|
||||
from pipecat_ai_small_webrtc_prebuilt.frontend import SmallWebRTCPrebuiltUI
|
||||
|
||||
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||
from pipecat.transports.network.webrtc_connection import IceServer, SmallWebRTCConnection
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv(override=True)
|
||||
@@ -28,7 +28,11 @@ app = FastAPI()
|
||||
# Store connections by pc_id
|
||||
pcs_map: Dict[str, SmallWebRTCConnection] = {}
|
||||
|
||||
ice_servers = ["stun:stun.l.google.com:19302"]
|
||||
ice_servers = [
|
||||
IceServer(
|
||||
urls="stun:stun.l.google.com:19302",
|
||||
)
|
||||
]
|
||||
|
||||
# Mount the frontend at /
|
||||
app.mount("/prebuilt", SmallWebRTCPrebuiltUI)
|
||||
|
||||
@@ -46,6 +46,46 @@ http://localhost:7860
|
||||
|
||||
---
|
||||
|
||||
## WebRTC ICE Servers Configuration
|
||||
|
||||
When implementing WebRTC in your project, **STUN** (Session Traversal Utilities for NAT) and **TURN** (Traversal Using Relays around NAT)
|
||||
servers are usually needed in cases where users are behind routers or firewalls.
|
||||
|
||||
In local networks (e.g., testing within the same home or office network), you usually don’t need to configure STUN or TURN servers.
|
||||
In such cases, WebRTC can often directly establish peer-to-peer connections without needing to traverse NAT or firewalls.
|
||||
|
||||
### What are STUN and TURN Servers?
|
||||
|
||||
- **STUN Server**: Helps clients discover their public IP address and port when they're behind a NAT (Network Address Translation) device (like a router).
|
||||
This allows WebRTC to attempt direct peer-to-peer communication by providing the public-facing IP and port.
|
||||
|
||||
- **TURN Server**: Used as a fallback when direct peer-to-peer communication isn't possible due to strict NATs or firewalls blocking connections.
|
||||
The TURN server relays media traffic between peers.
|
||||
|
||||
### Why are ICE Servers Important?
|
||||
|
||||
**ICE (Interactive Connectivity Establishment)** is a framework used by WebRTC to handle network traversal and NAT issues.
|
||||
The `iceServers` configuration provides a list of **STUN** and **TURN** servers that WebRTC uses to find the best way to connect two peers.
|
||||
|
||||
### Example Configuration for ICE Servers
|
||||
|
||||
Here’s how you can configure a basic `iceServers` object in WebRTC for testing purposes, using Google's public STUN server:
|
||||
|
||||
```javascript
|
||||
const config = {
|
||||
iceServers: [
|
||||
{
|
||||
urls: ["stun:stun.l.google.com:19302"], // Google's public STUN server
|
||||
}
|
||||
],
|
||||
};
|
||||
```
|
||||
|
||||
> For testing purposes, you can either use public **STUN** servers (like Google's) or set up your own **TURN** server.
|
||||
If you're running your own TURN server, make sure to include your server URL, username, and credential in the configuration.
|
||||
|
||||
---
|
||||
|
||||
### 💡 Notes
|
||||
- Ensure all dependencies are installed before running the server.
|
||||
- Check the `.env` file for missing configurations.
|
||||
|
||||
@@ -24,27 +24,47 @@
|
||||
let connected = false
|
||||
let peerConnection = null
|
||||
|
||||
/*const waitForIceGatheringComplete = async (pc) => {
|
||||
const waitForIceGatheringComplete = async (pc, timeoutMs = 2000) => {
|
||||
if (pc.iceGatheringState === 'complete') return;
|
||||
console.log("Waiting for ICE gathering to complete. Current state:", pc.iceGatheringState);
|
||||
return new Promise((resolve) => {
|
||||
let timeoutId;
|
||||
const checkState = () => {
|
||||
console.log("icegatheringstatechange:", pc.iceGatheringState);
|
||||
if (pc.iceGatheringState === 'complete') {
|
||||
pc.removeEventListener('icegatheringstatechange', checkState);
|
||||
cleanup();
|
||||
resolve();
|
||||
}
|
||||
};
|
||||
const onTimeout = () => {
|
||||
console.warn(`ICE gathering timed out after ${timeoutMs} ms.`);
|
||||
cleanup();
|
||||
resolve();
|
||||
};
|
||||
const cleanup = () => {
|
||||
pc.removeEventListener('icegatheringstatechange', checkState);
|
||||
clearTimeout(timeoutId);
|
||||
};
|
||||
pc.addEventListener('icegatheringstatechange', checkState);
|
||||
timeoutId = setTimeout(onTimeout, timeoutMs);
|
||||
// Checking the state again to avoid any eventual race condition
|
||||
checkState();
|
||||
});
|
||||
}*/
|
||||
};
|
||||
|
||||
|
||||
const createSmallWebRTCConnection = async (audioTrack) => {
|
||||
const pc = new RTCPeerConnection()
|
||||
const config = {
|
||||
iceServers: [],
|
||||
};
|
||||
const pc = new RTCPeerConnection(config)
|
||||
addPeerConnectionEventListeners(pc)
|
||||
pc.ontrack = e => audioEl.srcObject = e.streams[0]
|
||||
// SmallWebRTCTransport expects to receive both transceivers
|
||||
pc.addTransceiver(audioTrack, { direction: 'sendrecv' })
|
||||
pc.addTransceiver('video', { direction: 'sendrecv' })
|
||||
await pc.setLocalDescription(await pc.createOffer())
|
||||
//await waitForIceGatheringComplete(pc)
|
||||
await waitForIceGatheringComplete(pc)
|
||||
const offer = pc.localDescription
|
||||
const response = await fetch('/api/offer', {
|
||||
body: JSON.stringify({ sdp: offer.sdp, type: offer.type}),
|
||||
@@ -57,16 +77,37 @@
|
||||
}
|
||||
|
||||
const connect = async () => {
|
||||
_onConnecting()
|
||||
const audioStream = await navigator.mediaDevices.getUserMedia({audio: true})
|
||||
peerConnection= await createSmallWebRTCConnection(audioStream.getAudioTracks()[0])
|
||||
peerConnection.onconnectionstatechange = () => {
|
||||
let connectionState = peerConnection?.connectionState
|
||||
}
|
||||
|
||||
const addPeerConnectionEventListeners = (pc) => {
|
||||
pc.oniceconnectionstatechange = () => {
|
||||
console.log("oniceconnectionstatechange", pc?.iceConnectionState)
|
||||
}
|
||||
pc.onconnectionstatechange = () => {
|
||||
console.log("onconnectionstatechange", pc?.connectionState)
|
||||
let connectionState = pc?.connectionState
|
||||
if (connectionState === 'connected') {
|
||||
_onConnected()
|
||||
} else if (connectionState === 'disconnected') {
|
||||
_onDisconnected()
|
||||
}
|
||||
}
|
||||
pc.onicecandidate = (event) => {
|
||||
if (event.candidate) {
|
||||
console.log("New ICE candidate:", event.candidate);
|
||||
} else {
|
||||
console.log("All ICE candidates have been sent.");
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
const _onConnecting = () => {
|
||||
statusEl.textContent = "Connecting"
|
||||
buttonEl.textContent = "Disconnect"
|
||||
connected = true
|
||||
}
|
||||
|
||||
const _onConnected = () => {
|
||||
|
||||
@@ -17,7 +17,7 @@ from fastapi import BackgroundTasks, FastAPI
|
||||
from fastapi.responses import FileResponse
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||
from pipecat.transports.network.webrtc_connection import IceServer, SmallWebRTCConnection
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv(override=True)
|
||||
@@ -28,6 +28,13 @@ app = FastAPI()
|
||||
pcs_map: Dict[str, SmallWebRTCConnection] = {}
|
||||
|
||||
|
||||
ice_servers = [
|
||||
IceServer(
|
||||
urls="stun:stun.l.google.com:19302",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@app.post("/api/offer")
|
||||
async def offer(request: dict, background_tasks: BackgroundTasks):
|
||||
pc_id = request.get("pc_id")
|
||||
@@ -37,7 +44,7 @@ async def offer(request: dict, background_tasks: BackgroundTasks):
|
||||
logger.info(f"Reusing existing connection for pc_id: {pc_id}")
|
||||
await pipecat_connection.renegotiate(sdp=request["sdp"], type=request["type"])
|
||||
else:
|
||||
pipecat_connection = SmallWebRTCConnection()
|
||||
pipecat_connection = SmallWebRTCConnection(ice_servers)
|
||||
await pipecat_connection.initialize(sdp=request["sdp"], type=request["type"])
|
||||
|
||||
@pipecat_connection.event_handler("closed")
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
import {
|
||||
RTVIClientAudio,
|
||||
RTVIClientVideo,
|
||||
useRTVIClient,
|
||||
useRTVIClientTransportState,
|
||||
} from '@pipecat-ai/client-react';
|
||||
import { RTVIProvider } from './providers/RTVIProvider';
|
||||
import { ConnectButton } from './components/ConnectButton';
|
||||
import { StatusDisplay } from './components/StatusDisplay';
|
||||
import { DebugDisplay } from './components/DebugDisplay';
|
||||
import './App.css';
|
||||
} from "@pipecat-ai/client-react";
|
||||
import { RTVIProvider } from "./providers/RTVIProvider";
|
||||
import { ConnectButton } from "./components/ConnectButton";
|
||||
import { StatusDisplay } from "./components/StatusDisplay";
|
||||
import { DebugDisplay } from "./components/DebugDisplay";
|
||||
import "./App.css";
|
||||
|
||||
function BotVideo() {
|
||||
const transportState = useRTVIClientTransportState();
|
||||
const isConnected = transportState !== 'disconnected';
|
||||
const isConnected = transportState !== "disconnected";
|
||||
|
||||
return (
|
||||
<div className="bot-container">
|
||||
@@ -23,11 +24,31 @@ function BotVideo() {
|
||||
}
|
||||
|
||||
function AppContent() {
|
||||
const client = useRTVIClient();
|
||||
return (
|
||||
<div className="app">
|
||||
<div className="status-bar">
|
||||
<StatusDisplay />
|
||||
<ConnectButton />
|
||||
<div
|
||||
className="controls"
|
||||
onClick={async () => {
|
||||
if (!client) {
|
||||
console.error("RTVI client is not initialized");
|
||||
return;
|
||||
}
|
||||
client.action({
|
||||
service: "tts",
|
||||
action: "say",
|
||||
arguments: [
|
||||
{ name: "text", value: "Hello, world!" },
|
||||
{ name: "interrupt", value: false },
|
||||
],
|
||||
});
|
||||
}}
|
||||
>
|
||||
<button className="connect-btn">Say something</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="main-content">
|
||||
|
||||
@@ -20,6 +20,7 @@ the conversation flow.
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
@@ -32,15 +33,24 @@ from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
Frame,
|
||||
LLMMessagesAppendFrame,
|
||||
OutputImageRawFrame,
|
||||
SpriteFrame,
|
||||
TTSSpeakFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.processors.frameworks.rtvi import RTVIConfig, RTVIObserver, RTVIProcessor
|
||||
from pipecat.processors.frameworks.rtvi import (
|
||||
ActionResult,
|
||||
RTVIAction,
|
||||
RTVIActionArgument,
|
||||
RTVIObserver,
|
||||
RTVIProcessor,
|
||||
RTVIService,
|
||||
)
|
||||
from pipecat.services.elevenlabs.tts import ElevenLabsTTSService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
@@ -181,8 +191,57 @@ async def main():
|
||||
#
|
||||
# RTVI events for Pipecat client UI
|
||||
#
|
||||
|
||||
rtvi = RTVIProcessor(config=RTVIConfig(config=[]))
|
||||
|
||||
rtvi_tts = RTVIService(
|
||||
name="tts",
|
||||
options=[],
|
||||
)
|
||||
|
||||
async def action_tts_say_handler(
|
||||
rtvi: RTVIProcessor, service: str, arguments: Dict[str, Any]
|
||||
) -> ActionResult:
|
||||
if "interrupt" in arguments and arguments["interrupt"]:
|
||||
# interrupting breaks function handling
|
||||
await rtvi.interrupt_bot()
|
||||
if "text" in arguments:
|
||||
save = arguments["save"] if "save" in arguments else False
|
||||
frame = TTSSpeakFrame(text=arguments["text"])
|
||||
await rtvi.push_frame(frame)
|
||||
if save:
|
||||
llm_frame = LLMMessagesAppendFrame(
|
||||
messages=[{"role": "assistant", "content": arguments["text"]}]
|
||||
)
|
||||
await rtvi.push_frame(llm_frame)
|
||||
|
||||
return True
|
||||
|
||||
action_tts_say = RTVIAction(
|
||||
service="tts",
|
||||
action="say",
|
||||
result="bool",
|
||||
arguments=[
|
||||
RTVIActionArgument(name="text", type="string"),
|
||||
RTVIActionArgument(name="save_in_context", type="bool"),
|
||||
],
|
||||
handler=action_tts_say_handler,
|
||||
)
|
||||
|
||||
async def action_tts_interrupt_handler(
|
||||
rtvi: RTVIProcessor, service: str, arguments: Dict[str, Any]
|
||||
) -> ActionResult:
|
||||
await rtvi.interrupt_bot()
|
||||
return True
|
||||
|
||||
action_tts_interrupt = RTVIAction(
|
||||
service="tts", action="interrupt", result="bool", handler=action_tts_interrupt_handler
|
||||
)
|
||||
|
||||
rtvi.register_service(rtvi_tts)
|
||||
rtvi.register_action(action_tts_say)
|
||||
rtvi.register_action(action_tts_interrupt)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
|
||||
@@ -40,7 +40,7 @@ class HttpSmartTurnAnalyzer(BaseSmartTurn):
|
||||
async def _send_raw_request(self, data_bytes: bytes) -> Dict[str, Any]:
|
||||
headers = {"Content-Type": "application/octet-stream"}
|
||||
headers.update(self._headers)
|
||||
logger.trace(f"Sending {len(data_bytes)} bytes as raw body to {self._url}...")
|
||||
|
||||
try:
|
||||
timeout = aiohttp.ClientTimeout(total=self._params.stop_secs)
|
||||
|
||||
@@ -50,23 +50,30 @@ class HttpSmartTurnAnalyzer(BaseSmartTurn):
|
||||
logger.trace("\n--- Response ---")
|
||||
logger.trace(f"Status Code: {response.status}")
|
||||
|
||||
if response.status == 200:
|
||||
try:
|
||||
json_data = await response.json()
|
||||
logger.trace("Response JSON:")
|
||||
logger.trace(json_data)
|
||||
return json_data
|
||||
except aiohttp.ContentTypeError:
|
||||
# Non-JSON response
|
||||
text = await response.text()
|
||||
logger.trace("Response Content (non-JSON):")
|
||||
logger.trace(text)
|
||||
raise Exception(f"Non-JSON response: {text}")
|
||||
else:
|
||||
# Check if successful
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.trace("Response Content (Error):")
|
||||
logger.trace(error_text)
|
||||
response.raise_for_status()
|
||||
|
||||
if response.status == 500:
|
||||
logger.warning(f"Smart turn service returned 500 error: {error_text}")
|
||||
raise Exception(f"Server returned HTTP 500: {error_text}")
|
||||
else:
|
||||
response.raise_for_status()
|
||||
|
||||
# Process successful response
|
||||
try:
|
||||
json_data = await response.json()
|
||||
logger.trace("Response JSON:")
|
||||
logger.trace(json_data)
|
||||
return json_data
|
||||
except aiohttp.ContentTypeError:
|
||||
# Non-JSON response
|
||||
text = await response.text()
|
||||
logger.trace("Response Content (non-JSON):")
|
||||
logger.trace(text)
|
||||
raise Exception(f"Non-JSON response: {text}")
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Request timed out after {self._params.stop_secs} seconds")
|
||||
@@ -76,5 +83,14 @@ class HttpSmartTurnAnalyzer(BaseSmartTurn):
|
||||
raise Exception("Failed to send raw request to Daily Smart Turn.")
|
||||
|
||||
async def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]:
|
||||
serialized_array = self._serialize_array(audio_array)
|
||||
return await self._send_raw_request(serialized_array)
|
||||
try:
|
||||
serialized_array = self._serialize_array(audio_array)
|
||||
return await self._send_raw_request(serialized_array)
|
||||
except Exception as e:
|
||||
logger.error(f"Smart turn prediction failed: {str(e)}")
|
||||
# Return an incomplete prediction when a failure occurs
|
||||
return {
|
||||
"prediction": 0,
|
||||
"probability": 0.0,
|
||||
"metrics": {"inference_time": 0.0, "total_time": 0.0},
|
||||
}
|
||||
|
||||
@@ -24,10 +24,12 @@ from pipecat.frames.frames import (
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
StopInterruptionFrame,
|
||||
STTMuteFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
@@ -175,6 +177,8 @@ class STTMuteFilter(FrameProcessor):
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
TranscriptionFrame,
|
||||
),
|
||||
):
|
||||
# Only pass VAD-related frames when not muted
|
||||
|
||||
@@ -61,6 +61,9 @@ from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.services.llm_service import (
|
||||
FunctionCallParams, # TODO(aleix): we shouldn't import `services` from `processors`
|
||||
)
|
||||
from pipecat.transports.base_input import BaseInputTransport
|
||||
from pipecat.transports.base_output import BaseOutputTransport
|
||||
from pipecat.transports.base_transport import BaseTransport
|
||||
@@ -392,6 +395,32 @@ class RTVIServerMessageFrame(SystemFrame):
|
||||
return f"{self.name}(data: {self.data})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class RTVIObserverParams:
|
||||
"""
|
||||
Parameters for configuring RTVI Observer behavior.
|
||||
|
||||
Attributes:
|
||||
bot_llm_enabled (bool): Indicates if the bot's LLM messages should be sent.
|
||||
bot_tts_enabled (bool): Indicates if the bot's TTS messages should be sent.
|
||||
bot_speaking_enabled (bool): Indicates if the bot's started/stopped speaking messages should be sent.
|
||||
user_llm_enabled (bool): Indicates if the user's LLM input messages should be sent.
|
||||
user_speaking_enabled (bool): Indicates if the user's started/stopped speaking messages should be sent.
|
||||
user_transcription_enabled (bool): Indicates if user's transcription messages should be sent.
|
||||
metrics_enabled (bool): Indicates if metrics messages should be sent.
|
||||
errors_enabled (bool): Indicates if errors messages should be sent.
|
||||
"""
|
||||
|
||||
bot_llm_enabled: bool = True
|
||||
bot_tts_enabled: bool = True
|
||||
bot_speaking_enabled: bool = True
|
||||
user_llm_enabled: bool = True
|
||||
user_speaking_enabled: bool = True
|
||||
user_transcription_enabled: bool = True
|
||||
metrics_enabled: bool = True
|
||||
errors_enabled: bool = True
|
||||
|
||||
|
||||
class RTVIObserver(BaseObserver):
|
||||
"""Pipeline frame observer for RTVI server message handling.
|
||||
|
||||
@@ -404,14 +433,17 @@ class RTVIObserver(BaseObserver):
|
||||
are handled by the RTVIProcessor.
|
||||
|
||||
Args:
|
||||
rtvi (FrameProcessor): The RTVI processor to push frames to.
|
||||
rtvi (RTVIProcessor): The RTVI processor to push frames to.
|
||||
params (RTVIObserverParams): Settings to enable/disable specific messages.
|
||||
"""
|
||||
|
||||
def __init__(self, rtvi: FrameProcessor):
|
||||
def __init__(self, rtvi: "RTVIProcessor", *, params: RTVIObserverParams = RTVIObserverParams()):
|
||||
super().__init__()
|
||||
self._rtvi = rtvi
|
||||
self._params = params
|
||||
self._bot_transcription = ""
|
||||
self._frames_seen = set()
|
||||
rtvi.set_errors_enabled(self._params.errors_enabled)
|
||||
|
||||
async def on_push_frame(
|
||||
self,
|
||||
@@ -438,35 +470,41 @@ class RTVIObserver(BaseObserver):
|
||||
# again the next time we see the frame.
|
||||
mark_as_seen = True
|
||||
|
||||
if isinstance(frame, (UserStartedSpeakingFrame, UserStoppedSpeakingFrame)):
|
||||
if (
|
||||
isinstance(frame, (UserStartedSpeakingFrame, UserStoppedSpeakingFrame))
|
||||
and self._params.user_speaking_enabled
|
||||
):
|
||||
await self._handle_interruptions(frame)
|
||||
elif isinstance(frame, (BotStartedSpeakingFrame, BotStoppedSpeakingFrame)) and (
|
||||
direction == FrameDirection.UPSTREAM
|
||||
elif (
|
||||
isinstance(frame, (BotStartedSpeakingFrame, BotStoppedSpeakingFrame))
|
||||
and (direction == FrameDirection.UPSTREAM)
|
||||
and self._params.bot_speaking_enabled
|
||||
):
|
||||
await self._handle_bot_speaking(frame)
|
||||
elif isinstance(frame, (TranscriptionFrame, InterimTranscriptionFrame)):
|
||||
elif (
|
||||
isinstance(frame, (TranscriptionFrame, InterimTranscriptionFrame))
|
||||
and self._params.user_transcription_enabled
|
||||
):
|
||||
await self._handle_user_transcriptions(frame)
|
||||
elif isinstance(frame, OpenAILLMContextFrame):
|
||||
elif isinstance(frame, OpenAILLMContextFrame) and self._params.user_llm_enabled:
|
||||
await self._handle_context(frame)
|
||||
elif isinstance(frame, UserStartedSpeakingFrame):
|
||||
await self._push_bot_transcription()
|
||||
elif isinstance(frame, LLMFullResponseStartFrame):
|
||||
elif isinstance(frame, LLMFullResponseStartFrame) and self._params.bot_llm_enabled:
|
||||
await self.push_transport_message_urgent(RTVIBotLLMStartedMessage())
|
||||
elif isinstance(frame, LLMFullResponseEndFrame):
|
||||
elif isinstance(frame, LLMFullResponseEndFrame) and self._params.bot_llm_enabled:
|
||||
await self.push_transport_message_urgent(RTVIBotLLMStoppedMessage())
|
||||
elif isinstance(frame, LLMTextFrame):
|
||||
elif isinstance(frame, LLMTextFrame) and self._params.bot_llm_enabled:
|
||||
await self._handle_llm_text_frame(frame)
|
||||
elif isinstance(frame, TTSStartedFrame):
|
||||
elif isinstance(frame, TTSStartedFrame) and self._params.bot_tts_enabled:
|
||||
await self.push_transport_message_urgent(RTVIBotTTSStartedMessage())
|
||||
elif isinstance(frame, TTSStoppedFrame):
|
||||
elif isinstance(frame, TTSStoppedFrame) and self._params.bot_tts_enabled:
|
||||
await self.push_transport_message_urgent(RTVIBotTTSStoppedMessage())
|
||||
elif isinstance(frame, TTSTextFrame):
|
||||
elif isinstance(frame, TTSTextFrame) and self._params.bot_tts_enabled:
|
||||
if isinstance(src, BaseOutputTransport):
|
||||
message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=frame.text))
|
||||
await self.push_transport_message_urgent(message)
|
||||
else:
|
||||
mark_as_seen = False
|
||||
elif isinstance(frame, MetricsFrame):
|
||||
elif isinstance(frame, MetricsFrame) and self._params.metrics_enabled:
|
||||
await self._handle_metrics(frame)
|
||||
elif isinstance(frame, RTVIServerMessageFrame):
|
||||
message = RTVIServerMessage(data=frame.data)
|
||||
@@ -609,6 +647,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
self._bot_ready = False
|
||||
self._client_ready = False
|
||||
self._client_ready_id = ""
|
||||
self._errors_enabled = True
|
||||
|
||||
self._registered_actions: Dict[str, RTVIAction] = {}
|
||||
self._registered_services: Dict[str, RTVIService] = {}
|
||||
@@ -648,26 +687,23 @@ class RTVIProcessor(FrameProcessor):
|
||||
await self._update_config(self._config, False)
|
||||
await self._send_bot_ready()
|
||||
|
||||
def set_errors_enabled(self, enabled: bool):
|
||||
self._errors_enabled = enabled
|
||||
|
||||
async def interrupt_bot(self):
|
||||
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
|
||||
|
||||
async def send_error(self, error: str):
|
||||
message = RTVIError(data=RTVIErrorData(error=error, fatal=False))
|
||||
await self._push_transport_message(message)
|
||||
await self._send_error_frame(ErrorFrame(error=error))
|
||||
|
||||
async def handle_message(self, message: RTVIMessage):
|
||||
await self._message_queue.put(message)
|
||||
|
||||
async def handle_function_call(
|
||||
self,
|
||||
function_name: str,
|
||||
tool_call_id: str,
|
||||
arguments: Mapping[str, Any],
|
||||
):
|
||||
async def handle_function_call(self, params: FunctionCallParams):
|
||||
fn = RTVILLMFunctionCallMessageData(
|
||||
function_name=function_name,
|
||||
tool_call_id=tool_call_id,
|
||||
arguments=arguments,
|
||||
function_name=params.function_name,
|
||||
tool_call_id=params.tool_call_id,
|
||||
arguments=params.arguments,
|
||||
)
|
||||
message = RTVILLMFunctionCallMessage(data=fn)
|
||||
await self._push_transport_message(message, exclude_none=False)
|
||||
@@ -917,12 +953,14 @@ class RTVIProcessor(FrameProcessor):
|
||||
await self._push_transport_message(message)
|
||||
|
||||
async def _send_error_frame(self, frame: ErrorFrame):
|
||||
message = RTVIError(data=RTVIErrorData(error=frame.error, fatal=frame.fatal))
|
||||
await self._push_transport_message(message)
|
||||
if self._errors_enabled:
|
||||
message = RTVIError(data=RTVIErrorData(error=frame.error, fatal=frame.fatal))
|
||||
await self._push_transport_message(message)
|
||||
|
||||
async def _send_error_response(self, id: str, error: str):
|
||||
message = RTVIErrorResponse(id=id, data=RTVIErrorResponseData(error=error))
|
||||
await self._push_transport_message(message)
|
||||
if self._errors_enabled:
|
||||
message = RTVIErrorResponse(id=id, data=RTVIErrorResponseData(error=error))
|
||||
await self._push_transport_message(message)
|
||||
|
||||
def _action_id(self, service: str, action: str) -> str:
|
||||
return f"{service}:{action}"
|
||||
|
||||
@@ -93,49 +93,55 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
|
||||
"""Aggregates and emits text fragments as a transcript message.
|
||||
|
||||
This method uses a heuristic to automatically detect whether text fragments
|
||||
use pre-spacing (spaces at the beginning of fragments) or not, and applies
|
||||
the appropriate joining strategy. It handles fragments from different TTS
|
||||
services with different formatting patterns.
|
||||
contain embedded spacing (spaces at the beginning or end of fragments) or not,
|
||||
and applies the appropriate joining strategy. It handles fragments from different
|
||||
TTS services with different formatting patterns.
|
||||
|
||||
Examples:
|
||||
Pre-spaced fragments (concatenated):
|
||||
Fragments with embedded spacing (concatenated):
|
||||
```
|
||||
TTSTextFrame: ["Hello"]
|
||||
TTSTextFrame: [" there"]
|
||||
TTSTextFrame: [" there"] # Leading space
|
||||
TTSTextFrame: ["!"]
|
||||
TTSTextFrame: [" How"]
|
||||
TTSTextFrame: [" How"] # Leading space
|
||||
TTSTextFrame: ["'s"]
|
||||
TTSTextFrame: [" it"]
|
||||
TTSTextFrame: [" going"]
|
||||
TTSTextFrame: ["?"]
|
||||
TTSTextFrame: [" it"] # Leading space
|
||||
```
|
||||
Result: "Hello there! How's it going?"
|
||||
Result: "Hello there! How's it"
|
||||
|
||||
Word-by-word fragments (joined with spaces):
|
||||
Fragments with trailing spaces (concatenated):
|
||||
```
|
||||
TTSTextFrame: ["Hel"]
|
||||
TTSTextFrame: ["lo "] # Trailing space
|
||||
TTSTextFrame: ["to "] # Trailing space
|
||||
TTSTextFrame: ["you"]
|
||||
```
|
||||
Result: "Hello to you"
|
||||
|
||||
Word-by-word fragments without spacing (joined with spaces):
|
||||
```
|
||||
TTSTextFrame: ["Hello"]
|
||||
TTSTextFrame: ["there!"]
|
||||
TTSTextFrame: ["How"]
|
||||
TTSTextFrame: ["is"]
|
||||
TTSTextFrame: ["it"]
|
||||
TTSTextFrame: ["going?"]
|
||||
TTSTextFrame: ["there"]
|
||||
TTSTextFrame: ["how"]
|
||||
TTSTextFrame: ["are"]
|
||||
TTSTextFrame: ["you"]
|
||||
```
|
||||
Result: "Hello there! How is it going?"
|
||||
Result: "Hello there how are you"
|
||||
"""
|
||||
if self._current_text_parts and self._aggregation_start_time:
|
||||
# Heuristic to detect pre-spaced fragments
|
||||
uses_prespacing = False
|
||||
if len(self._current_text_parts) > 1:
|
||||
# Check if any fragment after the first one starts with whitespace
|
||||
has_spaced_parts = any(
|
||||
part and part[0].isspace() for part in self._current_text_parts[1:]
|
||||
)
|
||||
if has_spaced_parts:
|
||||
uses_prespacing = True
|
||||
has_leading_spaces = any(
|
||||
part and part[0].isspace() for part in self._current_text_parts[1:]
|
||||
)
|
||||
has_trailing_spaces = any(
|
||||
part and part[-1].isspace() for part in self._current_text_parts[:-1]
|
||||
)
|
||||
|
||||
# Apply appropriate joining method
|
||||
if uses_prespacing:
|
||||
# Pre-spaced fragments - just concatenate
|
||||
# If there are embedded spaces in the fragments, use direct concatenation
|
||||
contains_spacing_between_fragments = has_leading_spaces or has_trailing_spaces
|
||||
|
||||
# Apply corresponding joining method
|
||||
if contains_spacing_between_fragments:
|
||||
# Fragments already have spacing - just concatenate
|
||||
content = "".join(self._current_text_parts)
|
||||
else:
|
||||
# Word-by-word fragments - join with spaces
|
||||
|
||||
@@ -193,3 +193,10 @@ def parse_server_event(str):
|
||||
except Exception as e:
|
||||
print(f"Error parsing server event: {e}")
|
||||
return None
|
||||
|
||||
|
||||
class ContextWindowCompressionConfig(BaseModel):
|
||||
"""Configuration for context window compression."""
|
||||
|
||||
sliding_window: Optional[bool] = Field(default=True)
|
||||
trigger_tokens: Optional[int] = Field(default=None)
|
||||
|
||||
@@ -223,6 +223,16 @@ class GeminiMultimodalLiveUserContextAggregator(OpenAIUserContextAggregator):
|
||||
|
||||
|
||||
class GeminiMultimodalLiveAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
# The LLMAssistantContextAggregator uses TextFrames to aggregate the LLM output,
|
||||
# but the GeminiMultimodalLiveAssistantContextAggregator pushes LLMTextFrames and TTSTextFrames. We
|
||||
# need to override this proces_frame for LLMTextFrame, so that only the TTSTextFrames
|
||||
# are process. This ensures that the context gets only one set of messages.
|
||||
# GeminiMultimodalLiveLLMService also pushes TranscriptionFrames, so we need to
|
||||
# ignore pushing those as well, as they're also TextFrames.
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
if not isinstance(frame, (LLMTextFrame, TranscriptionFrame)):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
async def handle_user_image_frame(self, frame: UserImageRawFrame):
|
||||
# We don't want to store any images in the context. Revisit this later
|
||||
# when the API evolves.
|
||||
@@ -265,6 +275,15 @@ class GeminiVADParams(BaseModel):
|
||||
silence_duration_ms: Optional[int] = Field(default=None)
|
||||
|
||||
|
||||
class ContextWindowCompressionParams(BaseModel):
|
||||
"""Parameters for context window compression."""
|
||||
|
||||
enabled: bool = Field(default=False)
|
||||
trigger_tokens: Optional[int] = Field(
|
||||
default=None
|
||||
) # None = use default (80% of context window)
|
||||
|
||||
|
||||
class InputParams(BaseModel):
|
||||
frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
|
||||
max_tokens: Optional[int] = Field(default=4096, ge=1)
|
||||
@@ -280,6 +299,7 @@ class InputParams(BaseModel):
|
||||
default=GeminiMediaResolution.UNSPECIFIED
|
||||
)
|
||||
vad: Optional[GeminiVADParams] = Field(default=None)
|
||||
context_window_compression: Optional[ContextWindowCompressionParams] = Field(default=None)
|
||||
extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@@ -334,7 +354,6 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
self._bot_is_speaking = False
|
||||
self._user_audio_buffer = bytearray()
|
||||
self._bot_audio_buffer = bytearray()
|
||||
self._bot_text_buffer = ""
|
||||
|
||||
self._sample_rate = 24000
|
||||
|
||||
@@ -355,6 +374,9 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
"language": self._language_code,
|
||||
"media_resolution": params.media_resolution,
|
||||
"vad": params.vad,
|
||||
"context_window_compression": params.context_window_compression.model_dump()
|
||||
if params.context_window_compression
|
||||
else {},
|
||||
"extra": params.extra if isinstance(params.extra, dict) else {},
|
||||
}
|
||||
|
||||
@@ -414,7 +436,9 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
#
|
||||
|
||||
async def _handle_interruption(self):
|
||||
pass
|
||||
self._bot_is_speaking = False
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
async def _handle_user_started_speaking(self, frame):
|
||||
self._user_is_speaking = True
|
||||
@@ -437,10 +461,12 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
text = await self._transcribe_audio(audio, context)
|
||||
if not text:
|
||||
return
|
||||
logger.debug(f"[Transcription:user] {text}")
|
||||
context.add_message({"role": "user", "content": [{"type": "text", "text": text}]})
|
||||
# Sometimes the transcription contains newlines; we want to remove them.
|
||||
cleaned_text = text.rstrip("\n")
|
||||
logger.debug(f"[Transcription:user] {cleaned_text}")
|
||||
context.add_message({"role": "user", "content": [{"type": "text", "text": cleaned_text}]})
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(text=text, user_id="user", timestamp=time_now_iso8601())
|
||||
TranscriptionFrame(text=cleaned_text, user_id="user", timestamp=time_now_iso8601())
|
||||
)
|
||||
|
||||
async def _transcribe_audio(self, audio, context):
|
||||
@@ -561,6 +587,21 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
}
|
||||
}
|
||||
|
||||
# Add context window compression if enabled
|
||||
if self._settings.get("context_window_compression", {}).get("enabled", False):
|
||||
compression_config = {}
|
||||
# Add sliding window (always true if compression is enabled)
|
||||
compression_config["sliding_window"] = {}
|
||||
|
||||
# Add trigger_tokens if specified
|
||||
trigger_tokens = self._settings.get("context_window_compression", {}).get(
|
||||
"trigger_tokens"
|
||||
)
|
||||
if trigger_tokens is not None:
|
||||
compression_config["trigger_tokens"] = trigger_tokens
|
||||
|
||||
config_data["setup"]["context_window_compression"] = compression_config
|
||||
|
||||
# Add VAD configuration if provided
|
||||
if self._settings.get("vad"):
|
||||
vad_config = {}
|
||||
@@ -811,14 +852,6 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
if not part:
|
||||
return
|
||||
|
||||
text = part.text
|
||||
if text:
|
||||
if not self._bot_text_buffer:
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
|
||||
self._bot_text_buffer += text
|
||||
await self.push_frame(LLMTextFrame(text=text))
|
||||
|
||||
inline_data = part.inlineData
|
||||
if not inline_data:
|
||||
return
|
||||
@@ -833,6 +866,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
if not self._bot_is_speaking:
|
||||
self._bot_is_speaking = True
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
|
||||
self._bot_audio_buffer.extend(audio)
|
||||
frame = TTSAudioRawFrame(
|
||||
@@ -858,24 +892,20 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
|
||||
async def _handle_evt_turn_complete(self, evt):
|
||||
self._bot_is_speaking = False
|
||||
text = self._bot_text_buffer
|
||||
self._bot_text_buffer = ""
|
||||
|
||||
if text:
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
async def _handle_evt_output_transcription(self, evt):
|
||||
if not evt.serverContent.outputTranscription:
|
||||
return
|
||||
|
||||
text = evt.serverContent.outputTranscription.text
|
||||
if text:
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self.push_frame(LLMTextFrame(text=text))
|
||||
await self.push_frame(TTSTextFrame(text=text))
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
if not text:
|
||||
return
|
||||
|
||||
await self.push_frame(LLMTextFrame(text=text))
|
||||
await self.push_frame(TTSTextFrame(text=text))
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
@@ -906,6 +936,6 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
GeminiMultimodalLiveContext.upgrade(context)
|
||||
user = GeminiMultimodalLiveUserContextAggregator(context, params=user_params)
|
||||
|
||||
assistant_params.expect_stripped_words = True
|
||||
assistant_params.expect_stripped_words = False
|
||||
assistant = GeminiMultimodalLiveAssistantContextAggregator(context, params=assistant_params)
|
||||
return GeminiMultimodalLiveContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
@@ -14,6 +14,7 @@ from pipecat.frames.frames import (
|
||||
FunctionCallResultFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolsFrame,
|
||||
LLMTextFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
@@ -170,6 +171,14 @@ class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
|
||||
|
||||
|
||||
class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
# The LLMAssistantContextAggregator uses TextFrames to aggregate the LLM output,
|
||||
# but the OpenAIRealtimeLLMService pushes LLMTextFrames and TTSTextFrames. We
|
||||
# need to override this proces_frame for LLMTextFrame, so that only the TTSTextFrames
|
||||
# are process. This ensures that the context gets only one set of messages.
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
if not isinstance(frame, LLMTextFrame):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
|
||||
await super().handle_function_call_result(frame)
|
||||
|
||||
|
||||
@@ -74,7 +74,7 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
*,
|
||||
api_key: str,
|
||||
voice_id: str,
|
||||
url: str = "wss://users-ws.rime.ai/ws2",
|
||||
url: str = "wss://users.rime.ai/ws2",
|
||||
model: str = "mistv2",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: InputParams = InputParams(),
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Literal, Optional, Union
|
||||
from typing import Any, List, Literal, Optional, Union
|
||||
|
||||
from av.frame import Frame
|
||||
from loguru import logger
|
||||
@@ -87,13 +87,21 @@ class SmallWebRTCTrack:
|
||||
return getattr(self._track, name)
|
||||
|
||||
|
||||
# Alias so we don't need to expose RTCIceServer
|
||||
IceServer = RTCIceServer
|
||||
|
||||
|
||||
class SmallWebRTCConnection(BaseObject):
|
||||
def __init__(self, ice_servers=None):
|
||||
def __init__(self, ice_servers: Optional[Union[List[str], List[IceServer]]] = None):
|
||||
super().__init__()
|
||||
if ice_servers:
|
||||
self.ice_servers = [RTCIceServer(urls=server) for server in ice_servers]
|
||||
if not ice_servers:
|
||||
self.ice_servers: List[IceServer] = []
|
||||
elif all(isinstance(s, IceServer) for s in ice_servers):
|
||||
self.ice_servers = ice_servers
|
||||
elif all(isinstance(s, str) for s in ice_servers):
|
||||
self.ice_servers = [IceServer(urls=s) for s in ice_servers]
|
||||
else:
|
||||
self.ice_servers = []
|
||||
raise TypeError("ice_servers must be either List[str] or List[RTCIceServer]")
|
||||
self._connect_invoked = False
|
||||
self._track_map = {}
|
||||
self._track_getters = {
|
||||
|
||||
@@ -12,7 +12,9 @@ from pipecat.frames.frames import (
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
STTMuteFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
@@ -100,6 +102,45 @@ class TestSTTMuteFilter(unittest.IsolatedAsyncioTestCase):
|
||||
expected_down_frames=expected_returned_frames,
|
||||
)
|
||||
|
||||
async def test_transcription_frames_with_always_strategy(self):
|
||||
filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.ALWAYS}))
|
||||
|
||||
frames_to_send = [
|
||||
# Bot speaking - should mute
|
||||
BotStartedSpeakingFrame(),
|
||||
SleepFrame(sleep=0.1), # Wait for StartedSpeaking to process
|
||||
InterimTranscriptionFrame(
|
||||
user_id="user1", text="This should be suppressed", timestamp="1234567890"
|
||||
),
|
||||
TranscriptionFrame(
|
||||
user_id="user1", text="This should be suppressed", timestamp="1234567890"
|
||||
),
|
||||
SleepFrame(sleep=0.1), # Wait for transcription frames to queue
|
||||
BotStoppedSpeakingFrame(),
|
||||
# Bot not speaking - should pass through
|
||||
InterimTranscriptionFrame(
|
||||
user_id="user1", text="This should pass", timestamp="1234567891"
|
||||
),
|
||||
TranscriptionFrame(
|
||||
user_id="user1", text="This should pass through", timestamp="1234567891"
|
||||
),
|
||||
]
|
||||
|
||||
expected_returned_frames = [
|
||||
BotStartedSpeakingFrame,
|
||||
STTMuteFrame, # mute=True
|
||||
BotStoppedSpeakingFrame,
|
||||
STTMuteFrame, # mute=False
|
||||
InterimTranscriptionFrame, # Only passes through after bot stops speaking
|
||||
TranscriptionFrame, # Only passes through after bot stops speaking
|
||||
]
|
||||
|
||||
await run_test(
|
||||
filter,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
)
|
||||
|
||||
# TODO: Revisit once we figure out how to test SystemFrames and DataFrames
|
||||
# async def test_function_call_strategy(self):
|
||||
# filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.FUNCTION_CALL}))
|
||||
|
||||
Reference in New Issue
Block a user