Compare commits
1 Commits
hush/rtviS
...
hush/muteT
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d77ed9948d |
22
CHANGELOG.md
22
CHANGELOG.md
@@ -9,13 +9,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- Added `RTVIObserverParams` which allows you to configure what RTVI messages
|
||||
are sent to the clients.
|
||||
|
||||
- Added a `context_window_compression` InputParam to
|
||||
`GeminiMultimodalLiveLLMService` which allows you to enable a sliding context
|
||||
window for the session as well as set the token limit of the sliding window.
|
||||
|
||||
- Updated `SmallWebRTCConnection` to support `ice_servers` with credentials.
|
||||
|
||||
- Added `VADUserStartedSpeakingFrame` and `VADUserStoppedSpeakingFrame`,
|
||||
@@ -32,15 +25,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Added `MCPClient`; a way to connect to MCP servers and use the MCP servers'
|
||||
tools.
|
||||
|
||||
- Added `Mem0 OSS`, along with Mem0 cloud support now the OSS version is also
|
||||
available.
|
||||
- Added `Mem0 OSS`, along with Mem0 cloud support now the OSS version is also available.
|
||||
|
||||
### Changed
|
||||
|
||||
- The `STTMuteFilter` now mutes `InterimTranscriptionFrame` and
|
||||
`TranscriptionFrame` which allows the `STTMuteFilter` to be used in
|
||||
conjunction with transports that generate transcripts, e.g. `DailyTransport`.
|
||||
|
||||
- Function calls now receive a single parameter `FunctionCallParams` instead of
|
||||
`(function_name, tool_call_id, args, llm, context, result_callback)` which is
|
||||
now deprecated.
|
||||
@@ -87,9 +75,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue with `GeminiMultimodalLiveLLMService` where the context
|
||||
contained tokens instead of words.
|
||||
|
||||
- Fixed an issue with HTTP Smart Turn handling, where the service returns a 500
|
||||
error. Previously, this would cause an unhandled exception. Now, a 500 error
|
||||
is treated as an incomplete response.
|
||||
@@ -211,9 +196,8 @@ https://en.wikipedia.org/wiki/Saint_George%27s_Day_in_Catalonia
|
||||
- Fixed an issue in `SmallWebRTCTransport` where an error was thrown if the
|
||||
client did not create a video transceiver.
|
||||
|
||||
- Fixed an issue where LLM input parameters were not working and applied
|
||||
correctly in `GoogleVertexLLMService`, causing unexpected behavior during
|
||||
inference.
|
||||
- Fixed an issue where LLM input parameters were not working and applied correctly in `GoogleVertexLLMService`, causing
|
||||
unexpected behavior during inference.
|
||||
|
||||
### Other
|
||||
|
||||
|
||||
@@ -50,6 +50,7 @@ autodoc_mock_imports = [
|
||||
"pyht.protos",
|
||||
"pyht.protos.api_pb2",
|
||||
"pipecat_ai_playht", # PlayHT wrapper
|
||||
"vllm",
|
||||
"aiortc",
|
||||
"aiortc.mediastreams",
|
||||
"cv2",
|
||||
@@ -75,6 +76,7 @@ autodoc_mock_imports = [
|
||||
"openpipe",
|
||||
"simli",
|
||||
"soundfile",
|
||||
# Existing mocks
|
||||
"pipecat_ai_krisp",
|
||||
"pyaudio",
|
||||
"_tkinter",
|
||||
@@ -85,66 +87,6 @@ autodoc_mock_imports = [
|
||||
"pydantic.Field",
|
||||
"pydantic._internal._model_construction",
|
||||
"pydantic._internal._fields",
|
||||
# Moondream dependencies
|
||||
"torch",
|
||||
"transformers",
|
||||
"intel_extension_for_pytorch",
|
||||
# Ultravox dependencies
|
||||
"huggingface_hub",
|
||||
"vllm",
|
||||
"vllm.engine.arg_utils",
|
||||
"transformers.AutoTokenizer",
|
||||
# Langchain dependencies
|
||||
"langchain_core",
|
||||
"langchain_core.messages",
|
||||
"langchain_core.runnables",
|
||||
"langchain_core.messages.AIMessageChunk",
|
||||
"langchain_core.runnables.Runnable",
|
||||
# LiveKit dependencies
|
||||
"livekit",
|
||||
"livekit.rtc",
|
||||
"livekit_api",
|
||||
"livekit_protocol",
|
||||
"tenacity",
|
||||
"tenacity.retry",
|
||||
"tenacity.stop_after_attempt",
|
||||
"tenacity.wait_exponential",
|
||||
"rtc",
|
||||
"rtc.Room",
|
||||
"rtc.RoomOptions",
|
||||
"rtc.AudioSource",
|
||||
"rtc.LocalAudioTrack",
|
||||
"rtc.TrackPublishOptions",
|
||||
"rtc.TrackSource",
|
||||
"rtc.AudioStream",
|
||||
"rtc.AudioFrameEvent",
|
||||
"rtc.AudioFrame",
|
||||
"rtc.Track",
|
||||
"rtc.TrackKind",
|
||||
"rtc.RemoteParticipant",
|
||||
"rtc.RemoteTrackPublication",
|
||||
"rtc.DataPacket",
|
||||
# Riva dependencies
|
||||
"riva",
|
||||
"riva.client",
|
||||
"riva.client.Auth",
|
||||
"riva.client.ASRService",
|
||||
"riva.client.StreamingRecognitionConfig",
|
||||
"riva.client.RecognitionConfig",
|
||||
"riva.client.AudioEncoding",
|
||||
"riva.client.proto.riva_tts_pb2",
|
||||
"riva.client.SpeechSynthesisService",
|
||||
# Local CoreML Smart Turn dependencies
|
||||
"coremltools",
|
||||
"coremltools.models",
|
||||
"coremltools.models.MLModel",
|
||||
"torch",
|
||||
"torch.nn",
|
||||
"torch.nn.functional",
|
||||
"transformers",
|
||||
"transformers.AutoFeatureExtractor",
|
||||
# Also add specific classes that are imported
|
||||
"AutoFeatureExtractor",
|
||||
]
|
||||
|
||||
# HTML output settings
|
||||
@@ -176,25 +118,12 @@ def verify_modules():
|
||||
},
|
||||
}
|
||||
|
||||
# Skip importing modules that are in autodoc_mock_imports
|
||||
skipped_modules = set(autodoc_mock_imports)
|
||||
|
||||
missing = []
|
||||
for category, modules in required_modules.items():
|
||||
if isinstance(modules, dict):
|
||||
# Handle nested structure
|
||||
for subcategory, submodules in modules.items():
|
||||
for module in submodules:
|
||||
# Check if module is in autodoc_mock_imports
|
||||
if (
|
||||
f"pipecat.{category}.{subcategory}.{module}" in skipped_modules
|
||||
or module in skipped_modules
|
||||
):
|
||||
logger.info(
|
||||
f"Skipping import of mocked module: pipecat.{category}.{subcategory}.{module}"
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
__import__(f"pipecat.{category}.{subcategory}.{module}")
|
||||
logger.info(
|
||||
@@ -208,11 +137,6 @@ def verify_modules():
|
||||
else:
|
||||
# Handle flat structure
|
||||
for module in modules:
|
||||
# Check if module is in autodoc_mock_imports
|
||||
if f"pipecat.{category}.{module}" in skipped_modules or module in skipped_modules:
|
||||
logger.info(f"Skipping import of mocked module: pipecat.{category}.{module}")
|
||||
continue
|
||||
|
||||
try:
|
||||
__import__(f"pipecat.{category}.{module}")
|
||||
logger.info(f"Successfully imported pipecat.{category}.{module}")
|
||||
|
||||
@@ -26,23 +26,20 @@ pipecat-ai[grok]
|
||||
pipecat-ai[groq]
|
||||
# pipecat-ai[krisp] # Mocked
|
||||
pipecat-ai[koala]
|
||||
# pipecat-ai[langchain] # Mocked
|
||||
# pipecat-ai[livekit] # Mocked
|
||||
pipecat-ai[langchain]
|
||||
pipecat-ai[livekit]
|
||||
pipecat-ai[lmnt]
|
||||
pipecat-ai[local]
|
||||
# pipecat-ai[local-smart-turn] # Mocked
|
||||
# pipecat-ai[mem0] # Mocked
|
||||
# pipecat-ai[mlx-whisper] # Mocked
|
||||
# pipecat-ai[moondream] # Mocked
|
||||
pipecat-ai[moondream]
|
||||
pipecat-ai[nim]
|
||||
# pipecat-ai[neuphonic] # Mocked
|
||||
pipecat-ai[noisereduce]
|
||||
pipecat-ai[openai]
|
||||
# pipecat-ai[openpipe]
|
||||
# pipecat-ai[playht] # Mocked due to grpcio conflict with riva
|
||||
pipecat-ai[qwen]
|
||||
pipecat-ai[remote-smart-turn]
|
||||
# pipecat-ai[riva] # Mocked
|
||||
pipecat-ai[riva]
|
||||
pipecat-ai[silero]
|
||||
pipecat-ai[simli]
|
||||
pipecat-ai[soundfile]
|
||||
|
||||
@@ -89,7 +89,6 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
|
||||
llm = GeminiMultimodalLiveLLMService(
|
||||
api_key=os.getenv("GOOGLE_API_KEY"),
|
||||
system_instruction=system_instruction,
|
||||
transcribe_user_audio=True,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,18 +1,17 @@
|
||||
import {
|
||||
RTVIClientAudio,
|
||||
RTVIClientVideo,
|
||||
useRTVIClient,
|
||||
useRTVIClientTransportState,
|
||||
} from "@pipecat-ai/client-react";
|
||||
import { RTVIProvider } from "./providers/RTVIProvider";
|
||||
import { ConnectButton } from "./components/ConnectButton";
|
||||
import { StatusDisplay } from "./components/StatusDisplay";
|
||||
import { DebugDisplay } from "./components/DebugDisplay";
|
||||
import "./App.css";
|
||||
} from '@pipecat-ai/client-react';
|
||||
import { RTVIProvider } from './providers/RTVIProvider';
|
||||
import { ConnectButton } from './components/ConnectButton';
|
||||
import { StatusDisplay } from './components/StatusDisplay';
|
||||
import { DebugDisplay } from './components/DebugDisplay';
|
||||
import './App.css';
|
||||
|
||||
function BotVideo() {
|
||||
const transportState = useRTVIClientTransportState();
|
||||
const isConnected = transportState !== "disconnected";
|
||||
const isConnected = transportState !== 'disconnected';
|
||||
|
||||
return (
|
||||
<div className="bot-container">
|
||||
@@ -24,31 +23,11 @@ function BotVideo() {
|
||||
}
|
||||
|
||||
function AppContent() {
|
||||
const client = useRTVIClient();
|
||||
return (
|
||||
<div className="app">
|
||||
<div className="status-bar">
|
||||
<StatusDisplay />
|
||||
<ConnectButton />
|
||||
<div
|
||||
className="controls"
|
||||
onClick={async () => {
|
||||
if (!client) {
|
||||
console.error("RTVI client is not initialized");
|
||||
return;
|
||||
}
|
||||
client.action({
|
||||
service: "tts",
|
||||
action: "say",
|
||||
arguments: [
|
||||
{ name: "text", value: "Hello, world!" },
|
||||
{ name: "interrupt", value: false },
|
||||
],
|
||||
});
|
||||
}}
|
||||
>
|
||||
<button className="connect-btn">Say something</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="main-content">
|
||||
|
||||
@@ -20,7 +20,6 @@ the conversation flow.
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, Dict
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
@@ -33,24 +32,19 @@ from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
Frame,
|
||||
LLMMessagesAppendFrame,
|
||||
InterimTranscriptionFrame,
|
||||
OutputImageRawFrame,
|
||||
SpriteFrame,
|
||||
TTSSpeakFrame,
|
||||
STTMuteFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.filters.stt_mute_filter import STTMuteConfig, STTMuteFilter, STTMuteStrategy
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.processors.frameworks.rtvi import (
|
||||
ActionResult,
|
||||
RTVIAction,
|
||||
RTVIActionArgument,
|
||||
RTVIObserver,
|
||||
RTVIProcessor,
|
||||
RTVIService,
|
||||
)
|
||||
from pipecat.processors.frameworks.rtvi import RTVIConfig, RTVIObserver, RTVIProcessor
|
||||
from pipecat.services.elevenlabs.tts import ElevenLabsTTSService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
@@ -113,6 +107,30 @@ class TalkingAnimation(FrameProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
class TranscriptionMuteProcessor(FrameProcessor):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._is_muted = False
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, STTMuteFrame):
|
||||
self._is_muted = frame.mute
|
||||
|
||||
if isinstance(
|
||||
frame,
|
||||
(TranscriptionFrame, InterimTranscriptionFrame),
|
||||
):
|
||||
# Only pass VAD-related frames when not muted
|
||||
if not self._is_muted:
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
logger.trace(
|
||||
f"{frame.__class__.__name__} suppressed - Transcription currently muted"
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main bot execution function.
|
||||
|
||||
@@ -191,61 +209,25 @@ async def main():
|
||||
#
|
||||
# RTVI events for Pipecat client UI
|
||||
#
|
||||
|
||||
rtvi = RTVIProcessor(config=RTVIConfig(config=[]))
|
||||
|
||||
rtvi_tts = RTVIService(
|
||||
name="tts",
|
||||
options=[],
|
||||
# Configure the mute processor with both strategies
|
||||
stt_mute_processor = STTMuteFilter(
|
||||
config=STTMuteConfig(
|
||||
strategies={
|
||||
STTMuteStrategy.ALWAYS,
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
async def action_tts_say_handler(
|
||||
rtvi: RTVIProcessor, service: str, arguments: Dict[str, Any]
|
||||
) -> ActionResult:
|
||||
if "interrupt" in arguments and arguments["interrupt"]:
|
||||
# interrupting breaks function handling
|
||||
await rtvi.interrupt_bot()
|
||||
if "text" in arguments:
|
||||
save = arguments["save"] if "save" in arguments else False
|
||||
frame = TTSSpeakFrame(text=arguments["text"])
|
||||
await rtvi.push_frame(frame)
|
||||
if save:
|
||||
llm_frame = LLMMessagesAppendFrame(
|
||||
messages=[{"role": "assistant", "content": arguments["text"]}]
|
||||
)
|
||||
await rtvi.push_frame(llm_frame)
|
||||
|
||||
return True
|
||||
|
||||
action_tts_say = RTVIAction(
|
||||
service="tts",
|
||||
action="say",
|
||||
result="bool",
|
||||
arguments=[
|
||||
RTVIActionArgument(name="text", type="string"),
|
||||
RTVIActionArgument(name="save_in_context", type="bool"),
|
||||
],
|
||||
handler=action_tts_say_handler,
|
||||
)
|
||||
|
||||
async def action_tts_interrupt_handler(
|
||||
rtvi: RTVIProcessor, service: str, arguments: Dict[str, Any]
|
||||
) -> ActionResult:
|
||||
await rtvi.interrupt_bot()
|
||||
return True
|
||||
|
||||
action_tts_interrupt = RTVIAction(
|
||||
service="tts", action="interrupt", result="bool", handler=action_tts_interrupt_handler
|
||||
)
|
||||
|
||||
rtvi.register_service(rtvi_tts)
|
||||
rtvi.register_action(action_tts_say)
|
||||
rtvi.register_action(action_tts_interrupt)
|
||||
transcription_mute_processor = TranscriptionMuteProcessor()
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
rtvi,
|
||||
stt_mute_processor, # Add the mute processor before STT
|
||||
transcription_mute_processor,
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
tts,
|
||||
|
||||
@@ -24,12 +24,10 @@ from pipecat.frames.frames import (
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
StopInterruptionFrame,
|
||||
STTMuteFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
@@ -177,8 +175,6 @@ class STTMuteFilter(FrameProcessor):
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
TranscriptionFrame,
|
||||
),
|
||||
):
|
||||
# Only pass VAD-related frames when not muted
|
||||
|
||||
@@ -61,9 +61,6 @@ from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.services.llm_service import (
|
||||
FunctionCallParams, # TODO(aleix): we shouldn't import `services` from `processors`
|
||||
)
|
||||
from pipecat.transports.base_input import BaseInputTransport
|
||||
from pipecat.transports.base_output import BaseOutputTransport
|
||||
from pipecat.transports.base_transport import BaseTransport
|
||||
@@ -395,32 +392,6 @@ class RTVIServerMessageFrame(SystemFrame):
|
||||
return f"{self.name}(data: {self.data})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class RTVIObserverParams:
|
||||
"""
|
||||
Parameters for configuring RTVI Observer behavior.
|
||||
|
||||
Attributes:
|
||||
bot_llm_enabled (bool): Indicates if the bot's LLM messages should be sent.
|
||||
bot_tts_enabled (bool): Indicates if the bot's TTS messages should be sent.
|
||||
bot_speaking_enabled (bool): Indicates if the bot's started/stopped speaking messages should be sent.
|
||||
user_llm_enabled (bool): Indicates if the user's LLM input messages should be sent.
|
||||
user_speaking_enabled (bool): Indicates if the user's started/stopped speaking messages should be sent.
|
||||
user_transcription_enabled (bool): Indicates if user's transcription messages should be sent.
|
||||
metrics_enabled (bool): Indicates if metrics messages should be sent.
|
||||
errors_enabled (bool): Indicates if errors messages should be sent.
|
||||
"""
|
||||
|
||||
bot_llm_enabled: bool = True
|
||||
bot_tts_enabled: bool = True
|
||||
bot_speaking_enabled: bool = True
|
||||
user_llm_enabled: bool = True
|
||||
user_speaking_enabled: bool = True
|
||||
user_transcription_enabled: bool = True
|
||||
metrics_enabled: bool = True
|
||||
errors_enabled: bool = True
|
||||
|
||||
|
||||
class RTVIObserver(BaseObserver):
|
||||
"""Pipeline frame observer for RTVI server message handling.
|
||||
|
||||
@@ -433,17 +404,14 @@ class RTVIObserver(BaseObserver):
|
||||
are handled by the RTVIProcessor.
|
||||
|
||||
Args:
|
||||
rtvi (RTVIProcessor): The RTVI processor to push frames to.
|
||||
params (RTVIObserverParams): Settings to enable/disable specific messages.
|
||||
rtvi (FrameProcessor): The RTVI processor to push frames to.
|
||||
"""
|
||||
|
||||
def __init__(self, rtvi: "RTVIProcessor", *, params: RTVIObserverParams = RTVIObserverParams()):
|
||||
def __init__(self, rtvi: FrameProcessor):
|
||||
super().__init__()
|
||||
self._rtvi = rtvi
|
||||
self._params = params
|
||||
self._bot_transcription = ""
|
||||
self._frames_seen = set()
|
||||
rtvi.set_errors_enabled(self._params.errors_enabled)
|
||||
|
||||
async def on_push_frame(
|
||||
self,
|
||||
@@ -470,41 +438,35 @@ class RTVIObserver(BaseObserver):
|
||||
# again the next time we see the frame.
|
||||
mark_as_seen = True
|
||||
|
||||
if (
|
||||
isinstance(frame, (UserStartedSpeakingFrame, UserStoppedSpeakingFrame))
|
||||
and self._params.user_speaking_enabled
|
||||
):
|
||||
if isinstance(frame, (UserStartedSpeakingFrame, UserStoppedSpeakingFrame)):
|
||||
await self._handle_interruptions(frame)
|
||||
elif (
|
||||
isinstance(frame, (BotStartedSpeakingFrame, BotStoppedSpeakingFrame))
|
||||
and (direction == FrameDirection.UPSTREAM)
|
||||
and self._params.bot_speaking_enabled
|
||||
elif isinstance(frame, (BotStartedSpeakingFrame, BotStoppedSpeakingFrame)) and (
|
||||
direction == FrameDirection.UPSTREAM
|
||||
):
|
||||
await self._handle_bot_speaking(frame)
|
||||
elif (
|
||||
isinstance(frame, (TranscriptionFrame, InterimTranscriptionFrame))
|
||||
and self._params.user_transcription_enabled
|
||||
):
|
||||
elif isinstance(frame, (TranscriptionFrame, InterimTranscriptionFrame)):
|
||||
await self._handle_user_transcriptions(frame)
|
||||
elif isinstance(frame, OpenAILLMContextFrame) and self._params.user_llm_enabled:
|
||||
elif isinstance(frame, OpenAILLMContextFrame):
|
||||
await self._handle_context(frame)
|
||||
elif isinstance(frame, LLMFullResponseStartFrame) and self._params.bot_llm_enabled:
|
||||
elif isinstance(frame, UserStartedSpeakingFrame):
|
||||
await self._push_bot_transcription()
|
||||
elif isinstance(frame, LLMFullResponseStartFrame):
|
||||
await self.push_transport_message_urgent(RTVIBotLLMStartedMessage())
|
||||
elif isinstance(frame, LLMFullResponseEndFrame) and self._params.bot_llm_enabled:
|
||||
elif isinstance(frame, LLMFullResponseEndFrame):
|
||||
await self.push_transport_message_urgent(RTVIBotLLMStoppedMessage())
|
||||
elif isinstance(frame, LLMTextFrame) and self._params.bot_llm_enabled:
|
||||
elif isinstance(frame, LLMTextFrame):
|
||||
await self._handle_llm_text_frame(frame)
|
||||
elif isinstance(frame, TTSStartedFrame) and self._params.bot_tts_enabled:
|
||||
elif isinstance(frame, TTSStartedFrame):
|
||||
await self.push_transport_message_urgent(RTVIBotTTSStartedMessage())
|
||||
elif isinstance(frame, TTSStoppedFrame) and self._params.bot_tts_enabled:
|
||||
elif isinstance(frame, TTSStoppedFrame):
|
||||
await self.push_transport_message_urgent(RTVIBotTTSStoppedMessage())
|
||||
elif isinstance(frame, TTSTextFrame) and self._params.bot_tts_enabled:
|
||||
elif isinstance(frame, TTSTextFrame):
|
||||
if isinstance(src, BaseOutputTransport):
|
||||
message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=frame.text))
|
||||
await self.push_transport_message_urgent(message)
|
||||
else:
|
||||
mark_as_seen = False
|
||||
elif isinstance(frame, MetricsFrame) and self._params.metrics_enabled:
|
||||
elif isinstance(frame, MetricsFrame):
|
||||
await self._handle_metrics(frame)
|
||||
elif isinstance(frame, RTVIServerMessageFrame):
|
||||
message = RTVIServerMessage(data=frame.data)
|
||||
@@ -647,7 +609,6 @@ class RTVIProcessor(FrameProcessor):
|
||||
self._bot_ready = False
|
||||
self._client_ready = False
|
||||
self._client_ready_id = ""
|
||||
self._errors_enabled = True
|
||||
|
||||
self._registered_actions: Dict[str, RTVIAction] = {}
|
||||
self._registered_services: Dict[str, RTVIService] = {}
|
||||
@@ -687,23 +648,26 @@ class RTVIProcessor(FrameProcessor):
|
||||
await self._update_config(self._config, False)
|
||||
await self._send_bot_ready()
|
||||
|
||||
def set_errors_enabled(self, enabled: bool):
|
||||
self._errors_enabled = enabled
|
||||
|
||||
async def interrupt_bot(self):
|
||||
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
|
||||
|
||||
async def send_error(self, error: str):
|
||||
await self._send_error_frame(ErrorFrame(error=error))
|
||||
message = RTVIError(data=RTVIErrorData(error=error, fatal=False))
|
||||
await self._push_transport_message(message)
|
||||
|
||||
async def handle_message(self, message: RTVIMessage):
|
||||
await self._message_queue.put(message)
|
||||
|
||||
async def handle_function_call(self, params: FunctionCallParams):
|
||||
async def handle_function_call(
|
||||
self,
|
||||
function_name: str,
|
||||
tool_call_id: str,
|
||||
arguments: Mapping[str, Any],
|
||||
):
|
||||
fn = RTVILLMFunctionCallMessageData(
|
||||
function_name=params.function_name,
|
||||
tool_call_id=params.tool_call_id,
|
||||
arguments=params.arguments,
|
||||
function_name=function_name,
|
||||
tool_call_id=tool_call_id,
|
||||
arguments=arguments,
|
||||
)
|
||||
message = RTVILLMFunctionCallMessage(data=fn)
|
||||
await self._push_transport_message(message, exclude_none=False)
|
||||
@@ -953,14 +917,12 @@ class RTVIProcessor(FrameProcessor):
|
||||
await self._push_transport_message(message)
|
||||
|
||||
async def _send_error_frame(self, frame: ErrorFrame):
|
||||
if self._errors_enabled:
|
||||
message = RTVIError(data=RTVIErrorData(error=frame.error, fatal=frame.fatal))
|
||||
await self._push_transport_message(message)
|
||||
message = RTVIError(data=RTVIErrorData(error=frame.error, fatal=frame.fatal))
|
||||
await self._push_transport_message(message)
|
||||
|
||||
async def _send_error_response(self, id: str, error: str):
|
||||
if self._errors_enabled:
|
||||
message = RTVIErrorResponse(id=id, data=RTVIErrorResponseData(error=error))
|
||||
await self._push_transport_message(message)
|
||||
message = RTVIErrorResponse(id=id, data=RTVIErrorResponseData(error=error))
|
||||
await self._push_transport_message(message)
|
||||
|
||||
def _action_id(self, service: str, action: str) -> str:
|
||||
return f"{service}:{action}"
|
||||
|
||||
@@ -93,55 +93,49 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
|
||||
"""Aggregates and emits text fragments as a transcript message.
|
||||
|
||||
This method uses a heuristic to automatically detect whether text fragments
|
||||
contain embedded spacing (spaces at the beginning or end of fragments) or not,
|
||||
and applies the appropriate joining strategy. It handles fragments from different
|
||||
TTS services with different formatting patterns.
|
||||
use pre-spacing (spaces at the beginning of fragments) or not, and applies
|
||||
the appropriate joining strategy. It handles fragments from different TTS
|
||||
services with different formatting patterns.
|
||||
|
||||
Examples:
|
||||
Fragments with embedded spacing (concatenated):
|
||||
Pre-spaced fragments (concatenated):
|
||||
```
|
||||
TTSTextFrame: ["Hello"]
|
||||
TTSTextFrame: [" there"] # Leading space
|
||||
TTSTextFrame: [" there"]
|
||||
TTSTextFrame: ["!"]
|
||||
TTSTextFrame: [" How"] # Leading space
|
||||
TTSTextFrame: [" How"]
|
||||
TTSTextFrame: ["'s"]
|
||||
TTSTextFrame: [" it"] # Leading space
|
||||
TTSTextFrame: [" it"]
|
||||
TTSTextFrame: [" going"]
|
||||
TTSTextFrame: ["?"]
|
||||
```
|
||||
Result: "Hello there! How's it"
|
||||
Result: "Hello there! How's it going?"
|
||||
|
||||
Fragments with trailing spaces (concatenated):
|
||||
```
|
||||
TTSTextFrame: ["Hel"]
|
||||
TTSTextFrame: ["lo "] # Trailing space
|
||||
TTSTextFrame: ["to "] # Trailing space
|
||||
TTSTextFrame: ["you"]
|
||||
```
|
||||
Result: "Hello to you"
|
||||
|
||||
Word-by-word fragments without spacing (joined with spaces):
|
||||
Word-by-word fragments (joined with spaces):
|
||||
```
|
||||
TTSTextFrame: ["Hello"]
|
||||
TTSTextFrame: ["there"]
|
||||
TTSTextFrame: ["how"]
|
||||
TTSTextFrame: ["are"]
|
||||
TTSTextFrame: ["you"]
|
||||
TTSTextFrame: ["there!"]
|
||||
TTSTextFrame: ["How"]
|
||||
TTSTextFrame: ["is"]
|
||||
TTSTextFrame: ["it"]
|
||||
TTSTextFrame: ["going?"]
|
||||
```
|
||||
Result: "Hello there how are you"
|
||||
Result: "Hello there! How is it going?"
|
||||
"""
|
||||
if self._current_text_parts and self._aggregation_start_time:
|
||||
has_leading_spaces = any(
|
||||
part and part[0].isspace() for part in self._current_text_parts[1:]
|
||||
)
|
||||
has_trailing_spaces = any(
|
||||
part and part[-1].isspace() for part in self._current_text_parts[:-1]
|
||||
)
|
||||
# Heuristic to detect pre-spaced fragments
|
||||
uses_prespacing = False
|
||||
if len(self._current_text_parts) > 1:
|
||||
# Check if any fragment after the first one starts with whitespace
|
||||
has_spaced_parts = any(
|
||||
part and part[0].isspace() for part in self._current_text_parts[1:]
|
||||
)
|
||||
if has_spaced_parts:
|
||||
uses_prespacing = True
|
||||
|
||||
# If there are embedded spaces in the fragments, use direct concatenation
|
||||
contains_spacing_between_fragments = has_leading_spaces or has_trailing_spaces
|
||||
|
||||
# Apply corresponding joining method
|
||||
if contains_spacing_between_fragments:
|
||||
# Fragments already have spacing - just concatenate
|
||||
# Apply appropriate joining method
|
||||
if uses_prespacing:
|
||||
# Pre-spaced fragments - just concatenate
|
||||
content = "".join(self._current_text_parts)
|
||||
else:
|
||||
# Word-by-word fragments - join with spaces
|
||||
|
||||
@@ -193,10 +193,3 @@ def parse_server_event(str):
|
||||
except Exception as e:
|
||||
print(f"Error parsing server event: {e}")
|
||||
return None
|
||||
|
||||
|
||||
class ContextWindowCompressionConfig(BaseModel):
|
||||
"""Configuration for context window compression."""
|
||||
|
||||
sliding_window: Optional[bool] = Field(default=True)
|
||||
trigger_tokens: Optional[int] = Field(default=None)
|
||||
|
||||
@@ -223,16 +223,6 @@ class GeminiMultimodalLiveUserContextAggregator(OpenAIUserContextAggregator):
|
||||
|
||||
|
||||
class GeminiMultimodalLiveAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
# The LLMAssistantContextAggregator uses TextFrames to aggregate the LLM output,
|
||||
# but the GeminiMultimodalLiveAssistantContextAggregator pushes LLMTextFrames and TTSTextFrames. We
|
||||
# need to override this proces_frame for LLMTextFrame, so that only the TTSTextFrames
|
||||
# are process. This ensures that the context gets only one set of messages.
|
||||
# GeminiMultimodalLiveLLMService also pushes TranscriptionFrames, so we need to
|
||||
# ignore pushing those as well, as they're also TextFrames.
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
if not isinstance(frame, (LLMTextFrame, TranscriptionFrame)):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
async def handle_user_image_frame(self, frame: UserImageRawFrame):
|
||||
# We don't want to store any images in the context. Revisit this later
|
||||
# when the API evolves.
|
||||
@@ -275,15 +265,6 @@ class GeminiVADParams(BaseModel):
|
||||
silence_duration_ms: Optional[int] = Field(default=None)
|
||||
|
||||
|
||||
class ContextWindowCompressionParams(BaseModel):
|
||||
"""Parameters for context window compression."""
|
||||
|
||||
enabled: bool = Field(default=False)
|
||||
trigger_tokens: Optional[int] = Field(
|
||||
default=None
|
||||
) # None = use default (80% of context window)
|
||||
|
||||
|
||||
class InputParams(BaseModel):
|
||||
frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
|
||||
max_tokens: Optional[int] = Field(default=4096, ge=1)
|
||||
@@ -299,7 +280,6 @@ class InputParams(BaseModel):
|
||||
default=GeminiMediaResolution.UNSPECIFIED
|
||||
)
|
||||
vad: Optional[GeminiVADParams] = Field(default=None)
|
||||
context_window_compression: Optional[ContextWindowCompressionParams] = Field(default=None)
|
||||
extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@@ -354,6 +334,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
self._bot_is_speaking = False
|
||||
self._user_audio_buffer = bytearray()
|
||||
self._bot_audio_buffer = bytearray()
|
||||
self._bot_text_buffer = ""
|
||||
|
||||
self._sample_rate = 24000
|
||||
|
||||
@@ -374,9 +355,6 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
"language": self._language_code,
|
||||
"media_resolution": params.media_resolution,
|
||||
"vad": params.vad,
|
||||
"context_window_compression": params.context_window_compression.model_dump()
|
||||
if params.context_window_compression
|
||||
else {},
|
||||
"extra": params.extra if isinstance(params.extra, dict) else {},
|
||||
}
|
||||
|
||||
@@ -436,9 +414,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
#
|
||||
|
||||
async def _handle_interruption(self):
|
||||
self._bot_is_speaking = False
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
pass
|
||||
|
||||
async def _handle_user_started_speaking(self, frame):
|
||||
self._user_is_speaking = True
|
||||
@@ -461,12 +437,10 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
text = await self._transcribe_audio(audio, context)
|
||||
if not text:
|
||||
return
|
||||
# Sometimes the transcription contains newlines; we want to remove them.
|
||||
cleaned_text = text.rstrip("\n")
|
||||
logger.debug(f"[Transcription:user] {cleaned_text}")
|
||||
context.add_message({"role": "user", "content": [{"type": "text", "text": cleaned_text}]})
|
||||
logger.debug(f"[Transcription:user] {text}")
|
||||
context.add_message({"role": "user", "content": [{"type": "text", "text": text}]})
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(text=cleaned_text, user_id="user", timestamp=time_now_iso8601())
|
||||
TranscriptionFrame(text=text, user_id="user", timestamp=time_now_iso8601())
|
||||
)
|
||||
|
||||
async def _transcribe_audio(self, audio, context):
|
||||
@@ -587,21 +561,6 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
}
|
||||
}
|
||||
|
||||
# Add context window compression if enabled
|
||||
if self._settings.get("context_window_compression", {}).get("enabled", False):
|
||||
compression_config = {}
|
||||
# Add sliding window (always true if compression is enabled)
|
||||
compression_config["sliding_window"] = {}
|
||||
|
||||
# Add trigger_tokens if specified
|
||||
trigger_tokens = self._settings.get("context_window_compression", {}).get(
|
||||
"trigger_tokens"
|
||||
)
|
||||
if trigger_tokens is not None:
|
||||
compression_config["trigger_tokens"] = trigger_tokens
|
||||
|
||||
config_data["setup"]["context_window_compression"] = compression_config
|
||||
|
||||
# Add VAD configuration if provided
|
||||
if self._settings.get("vad"):
|
||||
vad_config = {}
|
||||
@@ -852,6 +811,14 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
if not part:
|
||||
return
|
||||
|
||||
text = part.text
|
||||
if text:
|
||||
if not self._bot_text_buffer:
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
|
||||
self._bot_text_buffer += text
|
||||
await self.push_frame(LLMTextFrame(text=text))
|
||||
|
||||
inline_data = part.inlineData
|
||||
if not inline_data:
|
||||
return
|
||||
@@ -866,7 +833,6 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
if not self._bot_is_speaking:
|
||||
self._bot_is_speaking = True
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
|
||||
self._bot_audio_buffer.extend(audio)
|
||||
frame = TTSAudioRawFrame(
|
||||
@@ -892,20 +858,24 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
|
||||
async def _handle_evt_turn_complete(self, evt):
|
||||
self._bot_is_speaking = False
|
||||
text = self._bot_text_buffer
|
||||
self._bot_text_buffer = ""
|
||||
|
||||
if text:
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
async def _handle_evt_output_transcription(self, evt):
|
||||
if not evt.serverContent.outputTranscription:
|
||||
return
|
||||
|
||||
text = evt.serverContent.outputTranscription.text
|
||||
|
||||
if not text:
|
||||
return
|
||||
|
||||
await self.push_frame(LLMTextFrame(text=text))
|
||||
await self.push_frame(TTSTextFrame(text=text))
|
||||
if text:
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self.push_frame(LLMTextFrame(text=text))
|
||||
await self.push_frame(TTSTextFrame(text=text))
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
def create_context_aggregator(
|
||||
self,
|
||||
@@ -936,6 +906,6 @@ class GeminiMultimodalLiveLLMService(LLMService):
|
||||
GeminiMultimodalLiveContext.upgrade(context)
|
||||
user = GeminiMultimodalLiveUserContextAggregator(context, params=user_params)
|
||||
|
||||
assistant_params.expect_stripped_words = False
|
||||
assistant_params.expect_stripped_words = True
|
||||
assistant = GeminiMultimodalLiveAssistantContextAggregator(context, params=assistant_params)
|
||||
return GeminiMultimodalLiveContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
@@ -14,7 +14,6 @@ from pipecat.frames.frames import (
|
||||
FunctionCallResultFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolsFrame,
|
||||
LLMTextFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
@@ -171,14 +170,6 @@ class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
|
||||
|
||||
|
||||
class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator):
|
||||
# The LLMAssistantContextAggregator uses TextFrames to aggregate the LLM output,
|
||||
# but the OpenAIRealtimeLLMService pushes LLMTextFrames and TTSTextFrames. We
|
||||
# need to override this proces_frame for LLMTextFrame, so that only the TTSTextFrames
|
||||
# are process. This ensures that the context gets only one set of messages.
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
if not isinstance(frame, LLMTextFrame):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
|
||||
await super().handle_function_call_result(frame)
|
||||
|
||||
|
||||
@@ -74,7 +74,7 @@ class RimeTTSService(AudioContextWordTTSService):
|
||||
*,
|
||||
api_key: str,
|
||||
voice_id: str,
|
||||
url: str = "wss://users.rime.ai/ws2",
|
||||
url: str = "wss://users-ws.rime.ai/ws2",
|
||||
model: str = "mistv2",
|
||||
sample_rate: Optional[int] = None,
|
||||
params: InputParams = InputParams(),
|
||||
|
||||
@@ -12,9 +12,7 @@ from pipecat.frames.frames import (
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
InputAudioRawFrame,
|
||||
InterimTranscriptionFrame,
|
||||
STTMuteFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
@@ -102,45 +100,6 @@ class TestSTTMuteFilter(unittest.IsolatedAsyncioTestCase):
|
||||
expected_down_frames=expected_returned_frames,
|
||||
)
|
||||
|
||||
async def test_transcription_frames_with_always_strategy(self):
|
||||
filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.ALWAYS}))
|
||||
|
||||
frames_to_send = [
|
||||
# Bot speaking - should mute
|
||||
BotStartedSpeakingFrame(),
|
||||
SleepFrame(sleep=0.1), # Wait for StartedSpeaking to process
|
||||
InterimTranscriptionFrame(
|
||||
user_id="user1", text="This should be suppressed", timestamp="1234567890"
|
||||
),
|
||||
TranscriptionFrame(
|
||||
user_id="user1", text="This should be suppressed", timestamp="1234567890"
|
||||
),
|
||||
SleepFrame(sleep=0.1), # Wait for transcription frames to queue
|
||||
BotStoppedSpeakingFrame(),
|
||||
# Bot not speaking - should pass through
|
||||
InterimTranscriptionFrame(
|
||||
user_id="user1", text="This should pass", timestamp="1234567891"
|
||||
),
|
||||
TranscriptionFrame(
|
||||
user_id="user1", text="This should pass through", timestamp="1234567891"
|
||||
),
|
||||
]
|
||||
|
||||
expected_returned_frames = [
|
||||
BotStartedSpeakingFrame,
|
||||
STTMuteFrame, # mute=True
|
||||
BotStoppedSpeakingFrame,
|
||||
STTMuteFrame, # mute=False
|
||||
InterimTranscriptionFrame, # Only passes through after bot stops speaking
|
||||
TranscriptionFrame, # Only passes through after bot stops speaking
|
||||
]
|
||||
|
||||
await run_test(
|
||||
filter,
|
||||
frames_to_send=frames_to_send,
|
||||
expected_down_frames=expected_returned_frames,
|
||||
)
|
||||
|
||||
# TODO: Revisit once we figure out how to test SystemFrames and DataFrames
|
||||
# async def test_function_call_strategy(self):
|
||||
# filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.FUNCTION_CALL}))
|
||||
|
||||
Reference in New Issue
Block a user