Compare commits
30 Commits
hush/muteT
...
hush/rtviS
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fd6379cb6a | ||
|
|
7618d7511a | ||
|
|
9ca7bad978 | ||
|
|
10289c1f1c | ||
|
|
213d5d6abc | ||
|
|
01f37f769d | ||
|
|
52b393537a | ||
|
|
a6a4d3d71f | ||
|
|
c52de0f5de | ||
|
|
a1e1255f16 | ||
|
|
c4f758725e | ||
|
|
7bc9a78ce6 | ||
|
|
f8be71b32c | ||
|
|
957fa5546d | ||
|
|
039cb8fcae | ||
|
|
8e05f2f1a1 | ||
|
|
8467aa1ed3 | ||
|
|
9c5878af3d | ||
|
|
ef29800fe9 | ||
|
|
7e09933070 | ||
|
|
82a9d7f992 | ||
|
|
facbebb15f | ||
|
|
2ba60fc41f | ||
|
|
685f951ae2 | ||
|
|
27d4c927a8 | ||
|
|
20a59e8c56 | ||
|
|
d9a0a93667 | ||
|
|
154d5d1859 | ||
|
|
a192217256 | ||
|
|
6821b1cdab |
22
CHANGELOG.md
22
CHANGELOG.md
@@ -9,6 +9,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- Added `RTVIObserverParams` which allows you to configure what RTVI messages
|
||||
are sent to the clients.
|
||||
|
||||
- Added a `context_window_compression` InputParam to
|
||||
`GeminiMultimodalLiveLLMService` which allows you to enable a sliding context
|
||||
window for the session as well as set the token limit of the sliding window.
|
||||
|
||||
- Updated `SmallWebRTCConnection` to support `ice_servers` with credentials.
|
||||
|
||||
- Added `VADUserStartedSpeakingFrame` and `VADUserStoppedSpeakingFrame`,
|
||||
@@ -25,10 +32,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Added `MCPClient`; a way to connect to MCP servers and use the MCP servers'
|
||||
tools.
|
||||
|
||||
- Added `Mem0 OSS`, along with Mem0 cloud support now the OSS version is also available.
|
||||
- Added `Mem0 OSS`, along with Mem0 cloud support now the OSS version is also
|
||||
available.
|
||||
|
||||
### Changed
|
||||
|
||||
- The `STTMuteFilter` now mutes `InterimTranscriptionFrame` and
|
||||
`TranscriptionFrame` which allows the `STTMuteFilter` to be used in
|
||||
conjunction with transports that generate transcripts, e.g. `DailyTransport`.
|
||||
|
||||
- Function calls now receive a single parameter `FunctionCallParams` instead of
|
||||
`(function_name, tool_call_id, args, llm, context, result_callback)` which is
|
||||
now deprecated.
|
||||
@@ -75,6 +87,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue with `GeminiMultimodalLiveLLMService` where the context
|
||||
contained tokens instead of words.
|
||||
|
||||
- Fixed an issue with HTTP Smart Turn handling, where the service returns a 500
|
||||
error. Previously, this would cause an unhandled exception. Now, a 500 error
|
||||
is treated as an incomplete response.
|
||||
@@ -196,8 +211,9 @@ https://en.wikipedia.org/wiki/Saint_George%27s_Day_in_Catalonia
|
||||
- Fixed an issue in `SmallWebRTCTransport` where an error was thrown if the
|
||||
client did not create a video transceiver.
|
||||
|
||||
- Fixed an issue where LLM input parameters were not working and applied correctly in `GoogleVertexLLMService`, causing
|
||||
unexpected behavior during inference.
|
||||
- Fixed an issue where LLM input parameters were not working and applied
|
||||
correctly in `GoogleVertexLLMService`, causing unexpected behavior during
|
||||
inference.
|
||||
|
||||
### Other
|
||||
|
||||
|
||||
@@ -50,7 +50,6 @@ autodoc_mock_imports = [
|
||||
"pyht.protos",
|
||||
"pyht.protos.api_pb2",
|
||||
"pipecat_ai_playht", # PlayHT wrapper
|
||||
"vllm",
|
||||
"aiortc",
|
||||
"aiortc.mediastreams",
|
||||
"cv2",
|
||||
@@ -76,7 +75,6 @@ autodoc_mock_imports = [
|
||||
"openpipe",
|
||||
"simli",
|
||||
"soundfile",
|
||||
# Existing mocks
|
||||
"pipecat_ai_krisp",
|
||||
"pyaudio",
|
||||
"_tkinter",
|
||||
@@ -87,6 +85,66 @@ autodoc_mock_imports = [
|
||||
"pydantic.Field",
|
||||
"pydantic._internal._model_construction",
|
||||
"pydantic._internal._fields",
|
||||
# Moondream dependencies
|
||||
"torch",
|
||||
"transformers",
|
||||
"intel_extension_for_pytorch",
|
||||
# Ultravox dependencies
|
||||
"huggingface_hub",
|
||||
"vllm",
|
||||
"vllm.engine.arg_utils",
|
||||
"transformers.AutoTokenizer",
|
||||
# Langchain dependencies
|
||||
"langchain_core",
|
||||
"langchain_core.messages",
|
||||
"langchain_core.runnables",
|
||||
"langchain_core.messages.AIMessageChunk",
|
||||
"langchain_core.runnables.Runnable",
|
||||
# LiveKit dependencies
|
||||
"livekit",
|
||||
"livekit.rtc",
|
||||
"livekit_api",
|
||||
"livekit_protocol",
|
||||
"tenacity",
|
||||
"tenacity.retry",
|
||||
"tenacity.stop_after_attempt",
|
||||
"tenacity.wait_exponential",
|
||||
"rtc",
|
||||
"rtc.Room",
|
||||
"rtc.RoomOptions",
|
||||
"rtc.AudioSource",
|
||||
"rtc.LocalAudioTrack",
|
||||
"rtc.TrackPublishOptions",
|
||||
"rtc.TrackSource",
|
||||
"rtc.AudioStream",
|
||||
"rtc.AudioFrameEvent",
|
||||
"rtc.AudioFrame",
|
||||
"rtc.Track",
|
||||
"rtc.TrackKind",
|
||||
"rtc.RemoteParticipant",
|
||||
"rtc.RemoteTrackPublication",
|
||||
"rtc.DataPacket",
|
||||
# Riva dependencies
|
||||
"riva",
|
||||
"riva.client",
|
||||
"riva.client.Auth",
|
||||
"riva.client.ASRService",
|
||||
"riva.client.StreamingRecognitionConfig",
|
||||
"riva.client.RecognitionConfig",
|
||||
"riva.client.AudioEncoding",
|
||||
"riva.client.proto.riva_tts_pb2",
|
||||
"riva.client.SpeechSynthesisService",
|
||||
# Local CoreML Smart Turn dependencies
|
||||
"coremltools",
|
||||
"coremltools.models",
|
||||
"coremltools.models.MLModel",
|
||||
"torch",
|
||||
"torch.nn",
|
||||
"torch.nn.functional",
|
||||
"transformers",
|
||||
"transformers.AutoFeatureExtractor",
|
||||
# Also add specific classes that are imported
|
||||
"AutoFeatureExtractor",
|
||||
]
|
||||
|
||||
# HTML output settings
|
||||
@@ -118,12 +176,25 @@ def verify_modules():
|
||||
},
|
||||
}
|
||||
|
||||
# Skip importing modules that are in autodoc_mock_imports
|
||||
skipped_modules = set(autodoc_mock_imports)
|
||||
|
||||
missing = []
|
||||
for category, modules in required_modules.items():
|
||||
if isinstance(modules, dict):
|
||||
# Handle nested structure
|
||||
for subcategory, submodules in modules.items():
|
||||
for module in submodules:
|
||||
# Check if module is in autodoc_mock_imports
|
||||
if (
|
||||
f"pipecat.{category}.{subcategory}.{module}" in skipped_modules
|
||||
or module in skipped_modules
|
||||
):
|
||||
logger.info(
|
||||
f"Skipping import of mocked module: pipecat.{category}.{subcategory}.{module}"
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
__import__(f"pipecat.{category}.{subcategory}.{module}")
|
||||
logger.info(
|
||||
@@ -137,6 +208,11 @@ def verify_modules():
|
||||
else:
|
||||
# Handle flat structure
|
||||
for module in modules:
|
||||
# Check if module is in autodoc_mock_imports
|
||||
if f"pipecat.{category}.{module}" in skipped_modules or module in skipped_modules:
|
||||
logger.info(f"Skipping import of mocked module: pipecat.{category}.{module}")
|
||||
continue
|
||||
|
||||
try:
|
||||
__import__(f"pipecat.{category}.{module}")
|
||||
logger.info(f"Successfully imported pipecat.{category}.{module}")
|
||||
|
||||
@@ -26,20 +26,23 @@ pipecat-ai[grok]
|
||||
pipecat-ai[groq]
|
||||
# pipecat-ai[krisp] # Mocked
|
||||
pipecat-ai[koala]
|
||||
pipecat-ai[langchain]
|
||||
pipecat-ai[livekit]
|
||||
# pipecat-ai[langchain] # Mocked
|
||||
# pipecat-ai[livekit] # Mocked
|
||||
pipecat-ai[lmnt]
|
||||
pipecat-ai[local]
|
||||
# pipecat-ai[local-smart-turn] # Mocked
|
||||
# pipecat-ai[mem0] # Mocked
|
||||
# pipecat-ai[mlx-whisper] # Mocked
|
||||
pipecat-ai[moondream]
|
||||
# pipecat-ai[moondream] # Mocked
|
||||
pipecat-ai[nim]
|
||||
# pipecat-ai[neuphonic] # Mocked
|
||||
pipecat-ai[noisereduce]
|
||||
pipecat-ai[openai]
|
||||
# pipecat-ai[openpipe]
|
||||
# pipecat-ai[playht] # Mocked due to grpcio conflict with riva
|
||||
pipecat-ai[riva]
|
||||
pipecat-ai[qwen]
|
||||
pipecat-ai[remote-smart-turn]
|
||||
# pipecat-ai[riva] # Mocked
|
||||
pipecat-ai[silero]
|
||||
pipecat-ai[simli]
|
||||
pipecat-ai[soundfile]
|
||||
|
||||
@@ -89,6 +89,7 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
|
||||
llm = GeminiMultimodalLiveLLMService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
system_instruction=system_instruction,
|
||||
transcribe_user_audio=True,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
import {
|
||||
RTVIClientAudio,
|
||||
RTVIClientVideo,
|
||||
useRTVIClient,
|
||||
useRTVIClientTransportState,
|
||||
} from '@pipecat-ai/client-react';
|
||||
import { RTVIProvider } from './providers/RTVIProvider';
|
||||
import { ConnectButton } from './components/ConnectButton';
|
||||
import { StatusDisplay } from './components/StatusDisplay';
|
||||
import { DebugDisplay } from './components/DebugDisplay';
|
||||
import './App.css';
|
||||
} from "@pipecat-ai/client-react";
|
||||
import { RTVIProvider } from "./providers/RTVIProvider";
|
||||
import { ConnectButton } from "./components/ConnectButton";
|
||||
import { StatusDisplay } from "./components/StatusDisplay";
|
||||
import { DebugDisplay } from "./components/DebugDisplay";
|
||||
import "./App.css";
|
||||
|
||||
function BotVideo() {
|
||||
const transportState = useRTVIClientTransportState();
|
||||
const isConnected = transportState !== 'disconnected';
|
||||
const isConnected = transportState !== "disconnected";
|
||||
|
||||
return (
|
||||
<div className="bot-container">
|
||||
@@ -23,11 +24,31 @@ function BotVideo() {
|
||||
}
|
||||
|
||||
function AppContent() {
|
||||
const client = useRTVIClient();
|
||||
return (
|
||||
<div className="app">
|
||||
<div className="status-bar">
|
||||
<StatusDisplay />
|
||||
<ConnectButton />
|
||||
<div
|
||||
className="controls"
|
||||
onClick={async () => {
|
||||
if (!client) {
|
||||
console.error("RTVI client is not initialized");
|
||||
return;
|
||||
}
|
||||
client.action({
|
||||
service: "tts",
|
||||
action: "say",
|
||||
arguments: [
|
||||
{ name: "text", value: "Hello, world!" },
|
||||
{ name: "interrupt", value: false },
|
||||
],
|
||||
});
|
||||
}}
|
||||
>
|
||||
<button className="connect-btn">Say something</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="main-content">
|
||||
|
||||
@@ -20,6 +20,7 @@ the conversation flow.
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
@@ -32,15 +33,24 @@ from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
Frame,
|
||||
LLMMessagesAppendFrame,
|
||||
OutputImageRawFrame,
|
||||
SpriteFrame,
|
||||
TTSSpeakFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.processors.frameworks.rtvi import RTVIConfig, RTVIObserver, RTVIProcessor
|
||||
from pipecat.processors.frameworks.rtvi import (
|
||||
ActionResult,
|
||||
RTVIAction,
|
||||
RTVIActionArgument,
|
||||
RTVIObserver,
|
||||
RTVIProcessor,
|
||||
RTVIService,
|
||||
)
|
||||
from pipecat.services.elevenlabs.tts import ElevenLabsTTSService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
@@ -181,8 +191,57 @@ async def main():
|
||||
#
|
||||
# RTVI events for Pipecat client UI
|
||||
#
|
||||
|
||||
rtvi = RTVIProcessor(config=RTVIConfig(config=[]))
|
||||
|
||||
rtvi_tts = RTVIService(
|
||||
name="tts",
|
||||
options=[],
|
||||
)
|
||||
|
||||
async def action_tts_say_handler(
|
||||
rtvi: RTVIProcessor, service: str, arguments: Dict[str, Any]
|
||||
) -> ActionResult:
|
||||
if "interrupt" in arguments and arguments["interrupt"]:
|
||||
# interrupting breaks function handling
|
||||
await rtvi.interrupt_bot()
|
||||
if "text" in arguments:
|
||||
save = arguments["save"] if "save" in arguments else False
|
||||
frame = TTSSpeakFrame(text=arguments["text"])
|
||||
await rtvi.push_frame(frame)
|
||||
if save:
|
||||
llm_frame = LLMMessagesAppendFrame(
|
||||
messages=[{"role": "assistant", "content": arguments["text"]}]
|
||||
)
|
||||
await rtvi.push_frame(llm_frame)
|
||||
|
||||
return True
|
||||
|
||||
action_tts_say = RTVIAction(
|
||||
service="tts",
|
||||
action="say",
|
||||
result="bool",
|
||||
arguments=[
|
||||
RTVIActionArgument(name="text", type="string"),
|
||||
RTVIActionArgument(name="save_in_context", type="bool"),
|
||||
],
|
||||
handler=action_tts_say_handler,
|
||||
)
|
||||
|
||||
async def action_tts_interrupt_handler(
|
||||
rtvi: RTVIProcessor, service: str, arguments: Dict[str, Any]
|
||||
) -> ActionResult:
|
||||
await rtvi.interrupt_bot()
|
||||
return True
|
||||
|
||||
action_tts_interrupt = RTVIAction(
|
||||
service="tts", action="interrupt", result="bool", handler=action_tts_interrupt_handler
|
||||
)
|
||||
|
||||
rtvi.register_service(rtvi_tts)
|
||||
rtvi.register_action(action_tts_say)
|
||||
rtvi.register_action(action_tts_interrupt)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
|
||||
@@ -24,10 +24,12 @@ from pipecat.frames.frames import (
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
StopInterruptionFrame,
|
||||
STTMuteFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
@@ -175,6 +177,8 @@ class STTMuteFilter(FrameProcessor):
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
TranscriptionFrame,
|
||||
),
|
||||
):
|
||||
# Only pass VAD-related frames when not muted
|
||||
|
||||
@@ -61,6 +61,9 @@ from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.services.llm_service import (
|
||||
FunctionCallParams, # TODO(aleix): we shouldn't import `services` from `processors`
|
||||
)
|
||||
from pipecat.transports.base_input import BaseInputTransport
|
||||
from pipecat.transports.base_output import BaseOutputTransport
|
||||
from pipecat.transports.base_transport import BaseTransport
|
||||
@@ -392,6 +395,32 @@ class RTVIServerMessageFrame(SystemFrame):
|
||||
return f"{self.name}(data: {self.data})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class RTVIObserverParams:
|
||||
"""
|
||||
Parameters for configuring RTVI Observer behavior.
|
||||
|
||||
Attributes:
|
||||
bot_llm_enabled (bool): Indicates if the bot's LLM messages should be sent.
|
||||
bot_tts_enabled (bool): Indicates if the bot's TTS messages should be sent.
|
||||
bot_speaking_enabled (bool): Indicates if the bot's started/stopped speaking messages should be sent.
|
||||
user_llm_enabled (bool): Indicates if the user's LLM input messages should be sent.
|
||||
user_speaking_enabled (bool): Indicates if the user's started/stopped speaking messages should be sent.
|
||||
user_transcription_enabled (bool): Indicates if user's transcription messages should be sent.
|
||||
metrics_enabled (bool): Indicates if metrics messages should be sent.
|
||||
errors_enabled (bool): Indicates if errors messages should be sent.
|
||||
"""
|
||||
|
||||
bot_llm_enabled: bool = True
|
||||
bot_tts_enabled: bool = True
|
||||
bot_speaking_enabled: bool = True
|
||||
user_llm_enabled: bool = True
|
||||
user_speaking_enabled: bool = True
|
||||
user_transcription_enabled: bool = True
|
||||
metrics_enabled: bool = True
|
||||
errors_enabled: bool = True
|
||||
|
||||
|
||||
class RTVIObserver(BaseObserver):
|
||||
"""Pipeline frame observer for RTVI server message handling.
|
||||
|
||||
@@ -404,14 +433,17 @@ class RTVIObserver(BaseObserver):
|
||||
are handled by the RTVIProcessor.
|
||||
|
||||
Args:
|
||||
rtvi (FrameProcessor): The RTVI processor to push frames to.
|
||||
rtvi (RTVIProcessor): The RTVI processor to push frames to.
|
||||
params (RTVIObserverParams): Settings to enable/disable specific messages.
|
||||
"""
|
||||
|
||||
def __init__(self, rtvi: FrameProcessor):
|
||||
def __init__(self, rtvi: "RTVIProcessor", *, params: RTVIObserverParams = RTVIObserverParams()):
|
||||
super().__init__()
|
||||
self._rtvi = rtvi
|
||||
self._params = params
|
||||
self._bot_transcription = ""
|
||||
self._frames_seen = set()
|
||||
rtvi.set_errors_enabled(self._params.errors_enabled)
|
||||
|
||||
async def on_push_frame(
|
||||
self,
|
||||
@@ -438,35 +470,41 @@ class RTVIObserver(BaseObserver):
|
||||
# again the next time we see the frame.
|
||||
mark_as_seen = True
|
||||
|
||||
if isinstance(frame, (UserStartedSpeakingFrame, UserStoppedSpeakingFrame)):
|
||||
if (
|
||||
isinstance(frame, (UserStartedSpeakingFrame, UserStoppedSpeakingFrame))
|
||||
and self._params.user_speaking_enabled
|
||||
):
|
||||
await self._handle_interruptions(frame)
|
||||
elif isinstance(frame, (BotStartedSpeakingFrame, BotStoppedSpeakingFrame)) and (
|
||||
direction == FrameDirection.UPSTREAM
|
||||
elif (
|
||||
isinstance(frame, (BotStartedSpeakingFrame, BotStoppedSpeakingFrame))
|
||||
and (direction == FrameDirection.UPSTREAM)
|
||||
and self._params.bot_speaking_enabled
|
||||
):
|
||||
await self._handle_bot_speaking(frame)
|
||||
elif isinstance(frame, (TranscriptionFrame, InterimTranscriptionFrame)):
|
||||
elif (
|
||||
isinstance(frame, (TranscriptionFrame, InterimTranscriptionFrame))
|
||||
and self._params.user_transcription_enabled
|
||||
):
|
||||
await self._handle_user_transcriptions(frame)
|
||||
elif isinstance(frame, OpenAILLMContextFrame):
|
||||
elif isinstance(frame, OpenAILLMContextFrame) and self._params.user_llm_enabled:
|
||||
await self._handle_context(frame)
|
||||
elif isinstance(frame, UserStartedSpeakingFrame):
|
||||
await self._push_bot_transcription()
|
||||
elif isinstance(frame, LLMFullResponseStartFrame):
|
||||
elif isinstance(frame, LLMFullResponseStartFrame) and self._params.bot_llm_enabled:
|
||||
await self.push_transport_message_urgent(RTVIBotLLMStartedMessage())
|
||||
elif isinstance(frame, LLMFullResponseEndFrame):
|
||||
elif isinstance(frame, LLMFullResponseEndFrame) and self._params.bot_llm_enabled:
|
||||
await self.push_transport_message_urgent(RTVIBotLLMStoppedMessage())
|
||||
elif isinstance(frame, LLMTextFrame):
|
||||
elif isinstance(frame, LLMTextFrame) and self._params.bot_llm_enabled:
|
||||
await self._handle_llm_text_frame(frame)
|
||||
elif isinstance(frame, TTSStartedFrame):
|
||||
elif isinstance(frame, TTSStartedFrame) and self._params.bot_tts_enabled:
|
||||
await self.push_transport_message_urgent(RTVIBotTTSStartedMessage())
|
||||
elif isinstance(frame, TTSStoppedFrame):
|
||||
elif isinstance(frame, TTSStoppedFrame) and self._params.bot_tts_enabled:
|
||||
await self.push_transport_message_urgent(RTVIBotTTSStoppedMessage())
|
||||
elif isinstance(frame, TTSTextFrame):
|
||||
elif isinstance(frame, TTSTextFrame) and self._params.bot_tts_enabled:
|
||||
if isinstance(src, BaseOutputTransport):
|
||||
message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=frame.text))
|
||||
await self.push_transport_message_urgent(message)
|
||||
else:
|
||||
mark_as_seen = False
|
||||
elif isinstance(frame, MetricsFrame):
|
||||
elif isinstance(frame, MetricsFrame) and self._params.metrics_enabled:
|
||||
await self._handle_metrics(frame)
|
||||
elif isinstance(frame, RTVIServerMessageFrame):
|
||||
message = RTVIServerMessage(data=frame.data)
|
||||
@@ -609,6 +647,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
self._bot_ready = False
|
||||
self._client_ready = False
|
||||
self._client_ready_id = ""
|
||||
self._errors_enabled = True
|
||||
|
||||
self._registered_actions: Dict[str, RTVIAction] = {}
|
||||
self._registered_services: Dict[str, RTVIService] = {}
|
||||
@@ -648,26 +687,23 @@ class RTVIProcessor(FrameProcessor):
|
||||
await self._update_config(self._config, False)
|
||||
await self._send_bot_ready()
|
||||
|
||||
def set_errors_enabled(self, enabled: bool):
|
||||
self._errors_enabled = enabled
|
||||
|
||||
async def interrupt_bot(self):
|
||||
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
|
||||
|
||||
async def send_error(self, error: str):
|
||||
message = RTVIError(data=RTVIErrorData(error=error, fatal=False))
|
||||
await self._push_transport_message(message)
|
||||
await self._send_error_frame(ErrorFrame(error=error))
|
||||
|
||||
async def handle_message(self, message: RTVIMessage):
|
||||
await self._message_queue.put(message)
|
||||
|
||||
async def handle_function_call(
|
||||
self,
|
||||
function_name: str,
|
||||
tool_call_id: str,
|
||||
arguments: Mapping[str, Any],
|
||||
):
|
||||
async def handle_function_call(self, params: FunctionCallParams):
|
||||
fn = RTVILLMFunctionCallMessageData(
|
||||
function_name=function_name,
|
||||
tool_call_id=tool_call_id,
|
||||
arguments=arguments,
|
||||
function_name=params.function_name,
|
||||
tool_call_id=params.tool_call_id,
|
||||
arguments=params.arguments,
|
||||
)
|
||||
message = RTVILLMFunctionCallMessage(data=fn)
|
||||
await self._push_transport_message(message, exclude_none=False)
|
||||
@@ -917,12 +953,14 @@ class RTVIProcessor(FrameProcessor):
|
||||
await self._push_transport_message(message)
|
||||
|
||||
async def _send_error_frame(self, frame: ErrorFrame):
|
||||
message = RTVIError(data=RTVIErrorData(error=frame.error, fatal=frame.fatal))
|
||||
await self._push_transport_message(message)
|
||||
if self._errors_enabled:
|
||||
message = RTVIError(data=RTVIErrorData(error=frame.error, fatal=frame.fatal))
|
||||
await self._push_transport_message(message)
|
||||
|
||||
async def _send_error_response(self, id: str, error: str):
|
||||
message = RTVIErrorResponse(id=id, data=RTVIErrorResponseData(error=error))
|
||||
await self._push_transport_message(message)
|
||||
if self._errors_enabled:
|
||||
message = RTVIErrorResponse(id=id, data=RTVIErrorResponseData(error=error))
|
||||
await self._push_transport_message(message)
|
||||
|
||||
def _action_id(self, service: str, action: str) -> str:
|
||||
return f"{service}:{action}"
|
||||
|
||||
@@ -93,49 +93,55 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
|
||||
"""Aggregates and emits text fragments as a transcript message.
|
||||
|
||||
This method uses a heuristic to automatically detect whether text fragments
|
||||
use pre-spacing (spaces at the beginning of fragments) or not, and applies
|
||||
the appropriate joining strategy. It handles fragments from different TTS
|
||||
services with different formatting patterns.
|
||||
contain embedded spacing (spaces at the beginning or end of fragments) or not,
|
||||
and applies the appropriate joining strategy. It handles fragments from different
|
||||
TTS services with different formatting patterns.
|
||||
|
||||
Examples:
|
||||
Pre-spaced fragments (concatenated):
|
||||
Fragments with embedded spacing (concatenated):
|
||||
```
|
||||
TTSTextFrame: ["Hello"]
|
||||
TTSTextFrame: [" there"]
|
||||
TTSTextFrame: [" there"] # Leading space
|
||||
TTSTextFrame: ["!"]
|
||||
TTSTextFrame: [" How"]
|
||||
TTSTextFrame: [" How"] # Leading space
|
||||
TTSTextFrame: ["'s"]
|
||||
TTSTextFrame: [" it"]
|
||||
TTSTextFrame: [" going"]
|
||||
TTSTextFrame: ["?"]
|
||||
TTSTextFrame: [" it"] # Leading space
|
||||
```
|
||||
Result: "Hello there! How's it going?"
|
||||
Result: "Hello there! How's it"
|
||||
|
||||
Word-by-word fragments (joined with spaces):
|
||||
Fragments with trailing spaces (concatenated):
|
||||
```
|
||||
TTSTextFrame: ["Hel"]
|
||||
TTSTextFrame: ["lo "] # Trailing space
|
||||
TTSTextFrame: ["to "] # Trailing space
|
||||
TTSTextFrame: ["you"]
|
||||
```
|
||||
Result: "Hello to you"
|
||||
|
||||
Word-by-word fragments without spacing (joined with spaces):
|
||||
```
|
||||
TTSTextFrame: ["Hello"]
|
||||
TTSTextFrame: ["there!"]
|
||||
TTSTextFrame: ["How"]
|
||||
TTSTextFrame: ["is"]
|
||||
TTSTextFrame: ["it"]
|
||||
TTSTextFrame: ["going?"]
|
||||
TTSTextFrame: ["there"]
|
||||
TTSTextFrame: ["how"]
|
||||
TTSTextFrame: ["are"]
|
||||
TTSTextFrame: ["you"]
|
||||
```
|
||||
Result: "Hello there! How is it going?"
|
||||
Result: "Hello there how are you"
|
||||
"""
|
||||
if self._current_text_parts and self._aggregation_start_time:
|
||||
# Heuristic to detect pre-spaced fragments
|
||||
uses_prespacing = False
|
||||
if len(self._current_text_parts) > 1:
|
||||
# Check if any fragment after the first one starts with whitespace
|
||||
has_spaced_parts = any(
|
||||
part and part[0].isspace() for part in self._current_text_parts[1:]
|
||||
)
|
||||
if has_spaced_parts:
|
||||
uses_prespacing = True
|
||||
has_leading_spaces = any(
|
||||
part and part[0].isspace() for part in self._current_text_parts[1:]
|
||||
)
|
||||
has_trailing_spaces = any(
|
||||
part and part[-1].isspace() for part in self._current_text_parts[:-1]
|
||||
)
|
||||
|
||||
# Apply appropriate joining method
|
||||
if uses_prespacing:
|
||||
# Pre-spaced fragments - just concatenate
|
||||
# If there are embedded spaces in the fragments, use direct concatenation
|
||||
contains_spacing_between_fragments = has_leading_spaces or has_trailing_spaces
|
||||
|
||||
# Apply corresponding joining method
|
||||
if contains_spacing_between_fragments:
|
||||
# Fragments already have spacing - just concatenate
|
||||
content = "".join(self._current_text_parts)
|
||||
else:
|
||||
# Word-by-word fragments - join with spaces
|
||||
|
||||
@@ -193,3 +193,10 @@ def parse_server_event(str):
|
||||
except Exception as e:
|
||||
print(f"Error parsing server event: {e}")
|
||||
return None
|
||||
|
||||
|
||||
class ContextWindowCompressionConfig(BaseModel):
|
||||
"""Configuration for context window compression."""
|
||||
|
||||
sliding_window: Optional[bool] = Field(default=True)
|
||||
trigger_tokens: Optional[int] = Field(default=None)
|
||||
|
||||
@@ -223,6 +223,16 @@ class GeminiMultimodalLiveUserContextAggregator(OpenAIUserContextAggregator):
|
||||
|
||||
|
||||
class GeminiMultimodalLiveAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
# The LLMAssistantContextAggregator uses TextFrames to aggregate the LLM output,
|
||||
# but the GeminiMultimodalLiveAssistantContextAggregator pushes LLMTextFrames and TTSTextFrames. We
|
||||
# need to override this proces_frame for LLMTextFrame, so that only the TTSTextFrames
|
||||
# are process. This ensures that the context gets only one set of messages.
|
||||
# GeminiMultimodalLiveLLMService also pushes TranscriptionFrames, so we need to
|
||||
# ignore pushing those as well, as they're also TextFrames.
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
if not isinstance(frame, (LLMTextFrame, TranscriptionFrame)):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
async def handle_user_image_frame(self, frame: UserImageRawFrame):
|
||||
# We don't want to store any images in the context. Revisit this later
|
||||
# when the API evolves.
|
||||
@@ -265,6 +275,15 @@ class GeminiVADParams(BaseModel):
|
||||
silence_duration_ms: Optional[int] = Field(default=None)
|
||||
|
||||
|
||||
class ContextWindowCompressionParams(BaseModel):
|
||||
"""Parameters for context window compression."""
|
||||
|
||||
enabled: bool = Field(default=False)
|
||||
trigger_tokens: Optional[int] = Field(
|
||||
default=None
|
||||
) # None = use default (80% of context window)
|
||||
|
||||
|
||||
class InputParams(BaseModel):
|
||||
frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
|
||||
max_tokens: Optional[int] = Field(default=4096, ge=1)
|
||||
@@ -280,6 +299,7 @@ class InputParams(BaseModel):
|
||||
default=GeminiMediaResolution.UNSPECIFIED
|
||||
)
|
||||
vad: Optional[GeminiVADParams] = Field(default=None)
|
||||
context_window_compression: Optional[ContextWindowCompressionParams] = Field(default=None)
|
||||
extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@@ -334,7 +354,6 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
self._bot_is_speaking = False
|
||||
self._user_audio_buffer = bytearray()
|
||||
self._bot_audio_buffer = bytearray()
|
||||
self._bot_text_buffer = ""
|
||||
|
||||
self._sample_rate = 24000
|
||||
|
||||
@@ -355,6 +374,9 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
"language": self._language_code,
|
||||
"media_resolution": params.media_resolution,
|
||||
"vad": params.vad,
|
||||
"context_window_compression": params.context_window_compression.model_dump()
|
||||
if params.context_window_compression
|
||||
else {},
|
||||
"extra": params.extra if isinstance(params.extra, dict) else {},
|
||||
}
|
||||
|
||||
@@ -414,7 +436,9 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
#
|
||||
|
||||
async def _handle_interruption(self):
|
||||
pass
|
||||
self._bot_is_speaking = False
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
async def _handle_user_started_speaking(self, frame):
|
||||
self._user_is_speaking = True
|
||||
@@ -437,10 +461,12 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
text = await self._transcribe_audio(audio, context)
|
||||
if not text:
|
||||
return
|
||||
logger.debug(f"[Transcription:user] {text}")
|
||||
context.add_message({"role": "user", "content": [{"type": "text", "text": text}]})
|
||||
# Sometimes the transcription contains newlines; we want to remove them.
|
||||
cleaned_text = text.rstrip("\n")
|
||||
logger.debug(f"[Transcription:user] {cleaned_text}")
|
||||
context.add_message({"role": "user", "content": [{"type": "text", "text": cleaned_text}]})
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(text=text, user_id="user", timestamp=time_now_iso8601())
|
||||
TranscriptionFrame(text=cleaned_text, user_id="user", timestamp=time_now_iso8601())
|
||||
)
|
||||
|
||||
async def _transcribe_audio(self, audio, context):
|
||||
@@ -561,6 +587,21 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
}
|
||||
}
|
||||
|
||||
# Add context window compression if enabled
|
||||
if self._settings.get("context_window_compression", {}).get("enabled", False):
|
||||
compression_config = {}
|
||||
# Add sliding window (always true if compression is enabled)
|
||||
compression_config["sliding_window"] = {}
|
||||
|
||||
# Add trigger_tokens if specified
|
||||
trigger_tokens = self._settings.get("context_window_compression", {}).get(
|
||||
"trigger_tokens"
|
||||
)
|
||||
if trigger_tokens is not None:
|
||||
compression_config["trigger_tokens"] = trigger_tokens
|
||||
|
||||
config_data["setup"]["context_window_compression"] = compression_config
|
||||
|
||||
# Add VAD configuration if provided
|
||||
if self._settings.get("vad"):
|
||||
vad_config = {}
|
||||
@@ -811,14 +852,6 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
if not part:
|
||||
return
|
||||
|
||||
text = part.text
|
||||
if text:
|
||||
if not self._bot_text_buffer:
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
|
||||
self._bot_text_buffer += text
|
||||
await self.push_frame(LLMTextFrame(text=text))
|
||||
|
||||
inline_data = part.inlineData
|
||||
if not inline_data:
|
||||
return
|
||||
@@ -833,6 +866,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
if not self._bot_is_speaking:
|
||||
self._bot_is_speaking = True
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
|
||||
self._bot_audio_buffer.extend(audio)
|
||||
frame = TTSAudioRawFrame(
|
||||
@@ -858,24 +892,20 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
|
||||
async def _handle_evt_turn_complete(self, evt):
|
||||
self._bot_is_speaking = False
|
||||
text = self._bot_text_buffer
|
||||
self._bot_text_buffer = ""
|
||||
|
||||
if text:
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
async def _handle_evt_output_transcription(self, evt):
|
||||
if not evt.serverContent.outputTranscription:
|
||||
return
|
||||
|
||||
text = evt.serverContent.outputTranscription.text
|
||||
if text:
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self.push_frame(LLMTextFrame(text=text))
|
||||
await self.push_frame(TTSTextFrame(text=text))
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
if not text:
|
||||
return
|
||||
|
||||
await self.push_frame(LLMTextFrame(text=text))
|
||||
await self.push_frame(TTSTextFrame(text=text))
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
@@ -906,6 +936,6 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
GeminiMultimodalLiveContext.upgrade(context)
|
||||
user = GeminiMultimodalLiveUserContextAggregator(context, params=user_params)
|
||||
|
||||
assistant_params.expect_stripped_words = True
|
||||
assistant_params.expect_stripped_words = False
|
||||
assistant = GeminiMultimodalLiveAssistantContextAggregator(context, params=assistant_params)
|
||||
return GeminiMultimodalLiveContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
@@ -14,6 +14,7 @@ from pipecat.frames.frames import (
|
||||
FunctionCallResultFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolsFrame,
|
||||
LLMTextFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
@@ -170,6 +171,14 @@ class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
|
||||
|
||||
|
||||
class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
# The LLMAssistantContextAggregator uses TextFrames to aggregate the LLM output,
|
||||
# but the OpenAIRealtimeLLMService pushes LLMTextFrames and TTSTextFrames. We
|
||||
# need to override this proces_frame for LLMTextFrame, so that only the TTSTextFrames
|
||||
# are process. This ensures that the context gets only one set of messages.
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
if not isinstance(frame, LLMTextFrame):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
|
||||
await super().handle_function_call_result(frame)
|
||||
|
||||
|
||||
@@ -74,7 +74,7 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
*,
|
||||
api_key: str,
|
||||
voice_id: str,
|
||||
url: str = "wss://users-ws.rime.ai/ws2",
|
||||
url: str = "wss://users.rime.ai/ws2",
|
||||
model: str = "mistv2",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: InputParams = InputParams(),
|
||||
|
||||
@@ -12,7 +12,9 @@ from pipecat.frames.frames import (
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
STTMuteFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
@@ -100,6 +102,45 @@ class TestSTTMuteFilter(unittest.IsolatedAsyncioTestCase):
|
||||
expected_down_frames=expected_returned_frames,
|
||||
)
|
||||
|
||||
async def test_transcription_frames_with_always_strategy(self):
|
||||
filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.ALWAYS}))
|
||||
|
||||
frames_to_send = [
|
||||
# Bot speaking - should mute
|
||||
BotStartedSpeakingFrame(),
|
||||
SleepFrame(sleep=0.1), # Wait for StartedSpeaking to process
|
||||
InterimTranscriptionFrame(
|
||||
user_id="user1", text="This should be suppressed", timestamp="1234567890"
|
||||
),
|
||||
TranscriptionFrame(
|
||||
user_id="user1", text="This should be suppressed", timestamp="1234567890"
|
||||
),
|
||||
SleepFrame(sleep=0.1), # Wait for transcription frames to queue
|
||||
BotStoppedSpeakingFrame(),
|
||||
# Bot not speaking - should pass through
|
||||
InterimTranscriptionFrame(
|
||||
user_id="user1", text="This should pass", timestamp="1234567891"
|
||||
),
|
||||
TranscriptionFrame(
|
||||
user_id="user1", text="This should pass through", timestamp="1234567891"
|
||||
),
|
||||
]
|
||||
|
||||
expected_returned_frames = [
|
||||
BotStartedSpeakingFrame,
|
||||
STTMuteFrame, # mute=True
|
||||
BotStoppedSpeakingFrame,
|
||||
STTMuteFrame, # mute=False
|
||||
InterimTranscriptionFrame, # Only passes through after bot stops speaking
|
||||
TranscriptionFrame, # Only passes through after bot stops speaking
|
||||
]
|
||||
|
||||
await run_test(
|
||||
filter,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
)
|
||||
|
||||
# TODO: Revisit once we figure out how to test SystemFrames and DataFrames
|
||||
# async def test_function_call_strategy(self):
|
||||
# filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.FUNCTION_CALL}))
|
||||
|
||||
Reference in New Issue
Block a user