Compare commits

..

1 Commits

Author SHA1 Message Date
James Hush
d77ed9948d Save order 2025-04-30 16:04:38 +08:00
14 changed files with 141 additions and 411 deletions

View File

@@ -9,13 +9,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added `RTVIObserverParams` which allows you to configure what RTVI messages
are sent to the clients.
- Added a `context_window_compression` InputParam to
`GeminiMultimodalLiveLLMService` which allows you to enable a sliding context
window for the session as well as set the token limit of the sliding window.
- Updated `SmallWebRTCConnection` to support `ice_servers` with credentials.
- Added `VADUserStartedSpeakingFrame` and `VADUserStoppedSpeakingFrame`,
@@ -32,15 +25,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `MCPClient`; a way to connect to MCP servers and use the MCP servers'
tools.
- Added `Mem0 OSS`: in addition to Mem0 cloud support, the OSS version is now
also available.
- Added `Mem0 OSS`: in addition to Mem0 cloud support, the OSS version is now also available.
### Changed
- The `STTMuteFilter` now mutes `InterimTranscriptionFrame` and
`TranscriptionFrame` which allows the `STTMuteFilter` to be used in
conjunction with transports that generate transcripts, e.g. `DailyTransport`.
- Function calls now receive a single parameter `FunctionCallParams` instead of
`(function_name, tool_call_id, args, llm, context, result_callback)` which is
now deprecated.
@@ -87,9 +75,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Fixed an issue with `GeminiMultimodalLiveLLMService` where the context
contained tokens instead of words.
- Fixed an issue with HTTP Smart Turn handling, where the service returns a 500
error. Previously, this would cause an unhandled exception. Now, a 500 error
is treated as an incomplete response.
@@ -211,9 +196,8 @@ https://en.wikipedia.org/wiki/Saint_George%27s_Day_in_Catalonia
- Fixed an issue in `SmallWebRTCTransport` where an error was thrown if the
client did not create a video transceiver.
- Fixed an issue where LLM input parameters were not applied
correctly in `GoogleVertexLLMService`, causing unexpected behavior during
inference.
- Fixed an issue where LLM input parameters were not applied correctly in `GoogleVertexLLMService`, causing
unexpected behavior during inference.
### Other

View File

@@ -50,6 +50,7 @@ autodoc_mock_imports = [
"pyht.protos",
"pyht.protos.api_pb2",
"pipecat_ai_playht", # PlayHT wrapper
"vllm",
"aiortc",
"aiortc.mediastreams",
"cv2",
@@ -75,6 +76,7 @@ autodoc_mock_imports = [
"openpipe",
"simli",
"soundfile",
# Existing mocks
"pipecat_ai_krisp",
"pyaudio",
"_tkinter",
@@ -85,66 +87,6 @@ autodoc_mock_imports = [
"pydantic.Field",
"pydantic._internal._model_construction",
"pydantic._internal._fields",
# Moondream dependencies
"torch",
"transformers",
"intel_extension_for_pytorch",
# Ultravox dependencies
"huggingface_hub",
"vllm",
"vllm.engine.arg_utils",
"transformers.AutoTokenizer",
# Langchain dependencies
"langchain_core",
"langchain_core.messages",
"langchain_core.runnables",
"langchain_core.messages.AIMessageChunk",
"langchain_core.runnables.Runnable",
# LiveKit dependencies
"livekit",
"livekit.rtc",
"livekit_api",
"livekit_protocol",
"tenacity",
"tenacity.retry",
"tenacity.stop_after_attempt",
"tenacity.wait_exponential",
"rtc",
"rtc.Room",
"rtc.RoomOptions",
"rtc.AudioSource",
"rtc.LocalAudioTrack",
"rtc.TrackPublishOptions",
"rtc.TrackSource",
"rtc.AudioStream",
"rtc.AudioFrameEvent",
"rtc.AudioFrame",
"rtc.Track",
"rtc.TrackKind",
"rtc.RemoteParticipant",
"rtc.RemoteTrackPublication",
"rtc.DataPacket",
# Riva dependencies
"riva",
"riva.client",
"riva.client.Auth",
"riva.client.ASRService",
"riva.client.StreamingRecognitionConfig",
"riva.client.RecognitionConfig",
"riva.client.AudioEncoding",
"riva.client.proto.riva_tts_pb2",
"riva.client.SpeechSynthesisService",
# Local CoreML Smart Turn dependencies
"coremltools",
"coremltools.models",
"coremltools.models.MLModel",
"torch",
"torch.nn",
"torch.nn.functional",
"transformers",
"transformers.AutoFeatureExtractor",
# Also add specific classes that are imported
"AutoFeatureExtractor",
]
# HTML output settings
@@ -176,25 +118,12 @@ def verify_modules():
},
}
# Skip importing modules that are in autodoc_mock_imports
skipped_modules = set(autodoc_mock_imports)
missing = []
for category, modules in required_modules.items():
if isinstance(modules, dict):
# Handle nested structure
for subcategory, submodules in modules.items():
for module in submodules:
# Check if module is in autodoc_mock_imports
if (
f"pipecat.{category}.{subcategory}.{module}" in skipped_modules
or module in skipped_modules
):
logger.info(
f"Skipping import of mocked module: pipecat.{category}.{subcategory}.{module}"
)
continue
try:
__import__(f"pipecat.{category}.{subcategory}.{module}")
logger.info(
@@ -208,11 +137,6 @@ def verify_modules():
else:
# Handle flat structure
for module in modules:
# Check if module is in autodoc_mock_imports
if f"pipecat.{category}.{module}" in skipped_modules or module in skipped_modules:
logger.info(f"Skipping import of mocked module: pipecat.{category}.{module}")
continue
try:
__import__(f"pipecat.{category}.{module}")
logger.info(f"Successfully imported pipecat.{category}.{module}")

View File

@@ -26,23 +26,20 @@ pipecat-ai[grok]
pipecat-ai[groq]
# pipecat-ai[krisp] # Mocked
pipecat-ai[koala]
# pipecat-ai[langchain] # Mocked
# pipecat-ai[livekit] # Mocked
pipecat-ai[langchain]
pipecat-ai[livekit]
pipecat-ai[lmnt]
pipecat-ai[local]
# pipecat-ai[local-smart-turn] # Mocked
# pipecat-ai[mem0] # Mocked
# pipecat-ai[mlx-whisper] # Mocked
# pipecat-ai[moondream] # Mocked
pipecat-ai[moondream]
pipecat-ai[nim]
# pipecat-ai[neuphonic] # Mocked
pipecat-ai[noisereduce]
pipecat-ai[openai]
# pipecat-ai[openpipe]
# pipecat-ai[playht] # Mocked due to grpcio conflict with riva
pipecat-ai[qwen]
pipecat-ai[remote-smart-turn]
# pipecat-ai[riva] # Mocked
pipecat-ai[riva]
pipecat-ai[silero]
pipecat-ai[simli]
pipecat-ai[soundfile]

View File

@@ -89,7 +89,6 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection, _: argparse.Namespac
llm = GeminiMultimodalLiveLLMService(
api_key=os.getenv("GOOGLE_API_KEY"),
system_instruction=system_instruction,
transcribe_user_audio=True,
tools=tools,
)

View File

@@ -1,18 +1,17 @@
import {
RTVIClientAudio,
RTVIClientVideo,
useRTVIClient,
useRTVIClientTransportState,
} from "@pipecat-ai/client-react";
import { RTVIProvider } from "./providers/RTVIProvider";
import { ConnectButton } from "./components/ConnectButton";
import { StatusDisplay } from "./components/StatusDisplay";
import { DebugDisplay } from "./components/DebugDisplay";
import "./App.css";
} from '@pipecat-ai/client-react';
import { RTVIProvider } from './providers/RTVIProvider';
import { ConnectButton } from './components/ConnectButton';
import { StatusDisplay } from './components/StatusDisplay';
import { DebugDisplay } from './components/DebugDisplay';
import './App.css';
function BotVideo() {
const transportState = useRTVIClientTransportState();
const isConnected = transportState !== "disconnected";
const isConnected = transportState !== 'disconnected';
return (
<div className="bot-container">
@@ -24,31 +23,11 @@ function BotVideo() {
}
function AppContent() {
const client = useRTVIClient();
return (
<div className="app">
<div className="status-bar">
<StatusDisplay />
<ConnectButton />
<div
className="controls"
onClick={async () => {
if (!client) {
console.error("RTVI client is not initialized");
return;
}
client.action({
service: "tts",
action: "say",
arguments: [
{ name: "text", value: "Hello, world!" },
{ name: "interrupt", value: false },
],
});
}}
>
<button className="connect-btn">Say something</button>
</div>
</div>
<div className="main-content">

View File

@@ -20,7 +20,6 @@ the conversation flow.
import asyncio
import os
import sys
from typing import Any, Dict
import aiohttp
from dotenv import load_dotenv
@@ -33,24 +32,19 @@ from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
Frame,
LLMMessagesAppendFrame,
InterimTranscriptionFrame,
OutputImageRawFrame,
SpriteFrame,
TTSSpeakFrame,
STTMuteFrame,
TranscriptionFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.filters.stt_mute_filter import STTMuteConfig, STTMuteFilter, STTMuteStrategy
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.processors.frameworks.rtvi import (
ActionResult,
RTVIAction,
RTVIActionArgument,
RTVIObserver,
RTVIProcessor,
RTVIService,
)
from pipecat.processors.frameworks.rtvi import RTVIConfig, RTVIObserver, RTVIProcessor
from pipecat.services.elevenlabs.tts import ElevenLabsTTSService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.services.daily import DailyParams, DailyTransport
@@ -113,6 +107,30 @@ class TalkingAnimation(FrameProcessor):
await self.push_frame(frame, direction)
class TranscriptionMuteProcessor(FrameProcessor):
    """Suppress user transcription frames while STT is muted.

    The processor remembers the ``mute`` flag of the most recent
    ``STTMuteFrame`` it has seen. While muted, ``TranscriptionFrame`` and
    ``InterimTranscriptionFrame`` instances are dropped; every other frame
    (including the ``STTMuteFrame`` itself) is forwarded unchanged so the rest
    of the pipeline keeps working.
    """

    def __init__(self):
        super().__init__()
        # Mirrors the `mute` flag of the last STTMuteFrame seen; starts unmuted.
        self._is_muted = False

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Track mute state and forward or suppress ``frame``.

        Args:
            frame: The frame travelling through the pipeline.
            direction: The direction the frame is travelling in; it is
                preserved when the frame is forwarded.
        """
        await super().process_frame(frame, direction)

        if isinstance(frame, STTMuteFrame):
            self._is_muted = frame.mute

        is_transcription = isinstance(frame, (TranscriptionFrame, InterimTranscriptionFrame))
        if is_transcription and self._is_muted:
            # Only transcription frames are withheld, and only while muted.
            logger.trace(
                f"{frame.__class__.__name__} suppressed - Transcription currently muted"
            )
            return

        # Everything else must keep flowing: a processor that does not push a
        # frame removes it from the pipeline.
        await self.push_frame(frame, direction)
async def main():
"""Main bot execution function.
@@ -191,61 +209,25 @@ async def main():
#
# RTVI events for Pipecat client UI
#
rtvi = RTVIProcessor(config=RTVIConfig(config=[]))
rtvi_tts = RTVIService(
name="tts",
options=[],
# Configure the mute processor with the ALWAYS strategy
stt_mute_processor = STTMuteFilter(
config=STTMuteConfig(
strategies={
STTMuteStrategy.ALWAYS,
}
),
)
async def action_tts_say_handler(
    rtvi: RTVIProcessor, service: str, arguments: Dict[str, Any]
) -> ActionResult:
    """Handle the RTVI ``tts``/``say`` action.

    Args:
        rtvi: Processor used to interrupt the bot and to push the frames.
        service: Name of the service the action was invoked on (unused).
        arguments: Action arguments. ``text`` is spoken; a truthy
            ``interrupt`` interrupts the bot first; a truthy
            ``save_in_context`` (or the legacy ``save`` key) also appends the
            text to the LLM context as an assistant message.

    Returns:
        Always ``True``.
    """
    if arguments.get("interrupt"):
        # interrupting breaks function handling
        await rtvi.interrupt_bot()

    if "text" in arguments:
        text = arguments["text"]
        # The action declares this flag as `save_in_context`; the handler used
        # to read only `save`, so the declared argument was silently ignored.
        # `save` is still honored for callers that relied on the old key.
        save = arguments.get("save_in_context", arguments.get("save", False))

        await rtvi.push_frame(TTSSpeakFrame(text=text))
        if save:
            llm_frame = LLMMessagesAppendFrame(
                messages=[{"role": "assistant", "content": text}]
            )
            await rtvi.push_frame(llm_frame)

    return True
action_tts_say = RTVIAction(
service="tts",
action="say",
result="bool",
arguments=[
RTVIActionArgument(name="text", type="string"),
RTVIActionArgument(name="save_in_context", type="bool"),
],
handler=action_tts_say_handler,
)
async def action_tts_interrupt_handler(
rtvi: RTVIProcessor, service: str, arguments: Dict[str, Any]
) -> ActionResult:
await rtvi.interrupt_bot()
return True
action_tts_interrupt = RTVIAction(
service="tts", action="interrupt", result="bool", handler=action_tts_interrupt_handler
)
rtvi.register_service(rtvi_tts)
rtvi.register_action(action_tts_say)
rtvi.register_action(action_tts_interrupt)
transcription_mute_processor = TranscriptionMuteProcessor()
pipeline = Pipeline(
[
transport.input(),
rtvi,
stt_mute_processor, # Add the mute processor before STT
transcription_mute_processor,
context_aggregator.user(),
llm,
tts,

View File

@@ -24,12 +24,10 @@ from pipecat.frames.frames import (
FunctionCallInProgressFrame,
FunctionCallResultFrame,
InputAudioRawFrame,
InterimTranscriptionFrame,
StartFrame,
StartInterruptionFrame,
StopInterruptionFrame,
STTMuteFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
@@ -177,8 +175,6 @@ class STTMuteFilter(FrameProcessor):
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
InputAudioRawFrame,
InterimTranscriptionFrame,
TranscriptionFrame,
),
):
# Only pass VAD-related frames when not muted

View File

@@ -61,9 +61,6 @@ from pipecat.processors.aggregators.openai_llm_context import (
OpenAILLMContextFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.llm_service import (
FunctionCallParams, # TODO(aleix): we shouldn't import `services` from `processors`
)
from pipecat.transports.base_input import BaseInputTransport
from pipecat.transports.base_output import BaseOutputTransport
from pipecat.transports.base_transport import BaseTransport
@@ -395,32 +392,6 @@ class RTVIServerMessageFrame(SystemFrame):
return f"{self.name}(data: {self.data})"
@dataclass
class RTVIObserverParams:
"""
Parameters for configuring RTVI Observer behavior.
Attributes:
bot_llm_enabled (bool): Indicates if the bot's LLM messages should be sent.
bot_tts_enabled (bool): Indicates if the bot's TTS messages should be sent.
bot_speaking_enabled (bool): Indicates if the bot's started/stopped speaking messages should be sent.
user_llm_enabled (bool): Indicates if the user's LLM input messages should be sent.
user_speaking_enabled (bool): Indicates if the user's started/stopped speaking messages should be sent.
user_transcription_enabled (bool): Indicates if user's transcription messages should be sent.
metrics_enabled (bool): Indicates if metrics messages should be sent.
errors_enabled (bool): Indicates if errors messages should be sent.
"""
bot_llm_enabled: bool = True
bot_tts_enabled: bool = True
bot_speaking_enabled: bool = True
user_llm_enabled: bool = True
user_speaking_enabled: bool = True
user_transcription_enabled: bool = True
metrics_enabled: bool = True
errors_enabled: bool = True
class RTVIObserver(BaseObserver):
"""Pipeline frame observer for RTVI server message handling.
@@ -433,17 +404,14 @@ class RTVIObserver(BaseObserver):
are handled by the RTVIProcessor.
Args:
rtvi (RTVIProcessor): The RTVI processor to push frames to.
params (RTVIObserverParams): Settings to enable/disable specific messages.
rtvi (FrameProcessor): The RTVI processor to push frames to.
"""
def __init__(self, rtvi: "RTVIProcessor", *, params: RTVIObserverParams = RTVIObserverParams()):
def __init__(self, rtvi: FrameProcessor):
super().__init__()
self._rtvi = rtvi
self._params = params
self._bot_transcription = ""
self._frames_seen = set()
rtvi.set_errors_enabled(self._params.errors_enabled)
async def on_push_frame(
self,
@@ -470,41 +438,35 @@ class RTVIObserver(BaseObserver):
# again the next time we see the frame.
mark_as_seen = True
if (
isinstance(frame, (UserStartedSpeakingFrame, UserStoppedSpeakingFrame))
and self._params.user_speaking_enabled
):
if isinstance(frame, (UserStartedSpeakingFrame, UserStoppedSpeakingFrame)):
await self._handle_interruptions(frame)
elif (
isinstance(frame, (BotStartedSpeakingFrame, BotStoppedSpeakingFrame))
and (direction == FrameDirection.UPSTREAM)
and self._params.bot_speaking_enabled
elif isinstance(frame, (BotStartedSpeakingFrame, BotStoppedSpeakingFrame)) and (
direction == FrameDirection.UPSTREAM
):
await self._handle_bot_speaking(frame)
elif (
isinstance(frame, (TranscriptionFrame, InterimTranscriptionFrame))
and self._params.user_transcription_enabled
):
elif isinstance(frame, (TranscriptionFrame, InterimTranscriptionFrame)):
await self._handle_user_transcriptions(frame)
elif isinstance(frame, OpenAILLMContextFrame) and self._params.user_llm_enabled:
elif isinstance(frame, OpenAILLMContextFrame):
await self._handle_context(frame)
elif isinstance(frame, LLMFullResponseStartFrame) and self._params.bot_llm_enabled:
elif isinstance(frame, UserStartedSpeakingFrame):
await self._push_bot_transcription()
elif isinstance(frame, LLMFullResponseStartFrame):
await self.push_transport_message_urgent(RTVIBotLLMStartedMessage())
elif isinstance(frame, LLMFullResponseEndFrame) and self._params.bot_llm_enabled:
elif isinstance(frame, LLMFullResponseEndFrame):
await self.push_transport_message_urgent(RTVIBotLLMStoppedMessage())
elif isinstance(frame, LLMTextFrame) and self._params.bot_llm_enabled:
elif isinstance(frame, LLMTextFrame):
await self._handle_llm_text_frame(frame)
elif isinstance(frame, TTSStartedFrame) and self._params.bot_tts_enabled:
elif isinstance(frame, TTSStartedFrame):
await self.push_transport_message_urgent(RTVIBotTTSStartedMessage())
elif isinstance(frame, TTSStoppedFrame) and self._params.bot_tts_enabled:
elif isinstance(frame, TTSStoppedFrame):
await self.push_transport_message_urgent(RTVIBotTTSStoppedMessage())
elif isinstance(frame, TTSTextFrame) and self._params.bot_tts_enabled:
elif isinstance(frame, TTSTextFrame):
if isinstance(src, BaseOutputTransport):
message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=frame.text))
await self.push_transport_message_urgent(message)
else:
mark_as_seen = False
elif isinstance(frame, MetricsFrame) and self._params.metrics_enabled:
elif isinstance(frame, MetricsFrame):
await self._handle_metrics(frame)
elif isinstance(frame, RTVIServerMessageFrame):
message = RTVIServerMessage(data=frame.data)
@@ -647,7 +609,6 @@ class RTVIProcessor(FrameProcessor):
self._bot_ready = False
self._client_ready = False
self._client_ready_id = ""
self._errors_enabled = True
self._registered_actions: Dict[str, RTVIAction] = {}
self._registered_services: Dict[str, RTVIService] = {}
@@ -687,23 +648,26 @@ class RTVIProcessor(FrameProcessor):
await self._update_config(self._config, False)
await self._send_bot_ready()
def set_errors_enabled(self, enabled: bool):
self._errors_enabled = enabled
async def interrupt_bot(self):
await self.push_frame(BotInterruptionFrame(), FrameDirection.UPSTREAM)
async def send_error(self, error: str):
await self._send_error_frame(ErrorFrame(error=error))
message = RTVIError(data=RTVIErrorData(error=error, fatal=False))
await self._push_transport_message(message)
async def handle_message(self, message: RTVIMessage):
await self._message_queue.put(message)
async def handle_function_call(self, params: FunctionCallParams):
async def handle_function_call(
self,
function_name: str,
tool_call_id: str,
arguments: Mapping[str, Any],
):
fn = RTVILLMFunctionCallMessageData(
function_name=params.function_name,
tool_call_id=params.tool_call_id,
arguments=params.arguments,
function_name=function_name,
tool_call_id=tool_call_id,
arguments=arguments,
)
message = RTVILLMFunctionCallMessage(data=fn)
await self._push_transport_message(message, exclude_none=False)
@@ -953,14 +917,12 @@ class RTVIProcessor(FrameProcessor):
await self._push_transport_message(message)
async def _send_error_frame(self, frame: ErrorFrame):
if self._errors_enabled:
message = RTVIError(data=RTVIErrorData(error=frame.error, fatal=frame.fatal))
await self._push_transport_message(message)
message = RTVIError(data=RTVIErrorData(error=frame.error, fatal=frame.fatal))
await self._push_transport_message(message)
async def _send_error_response(self, id: str, error: str):
if self._errors_enabled:
message = RTVIErrorResponse(id=id, data=RTVIErrorResponseData(error=error))
await self._push_transport_message(message)
message = RTVIErrorResponse(id=id, data=RTVIErrorResponseData(error=error))
await self._push_transport_message(message)
def _action_id(self, service: str, action: str) -> str:
return f"{service}:{action}"

View File

@@ -93,55 +93,49 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
"""Aggregates and emits text fragments as a transcript message.
This method uses a heuristic to automatically detect whether text fragments
contain embedded spacing (spaces at the beginning or end of fragments) or not,
and applies the appropriate joining strategy. It handles fragments from different
TTS services with different formatting patterns.
use pre-spacing (spaces at the beginning of fragments) or not, and applies
the appropriate joining strategy. It handles fragments from different TTS
services with different formatting patterns.
Examples:
Fragments with embedded spacing (concatenated):
Pre-spaced fragments (concatenated):
```
TTSTextFrame: ["Hello"]
TTSTextFrame: [" there"] # Leading space
TTSTextFrame: [" there"]
TTSTextFrame: ["!"]
TTSTextFrame: [" How"] # Leading space
TTSTextFrame: [" How"]
TTSTextFrame: ["'s"]
TTSTextFrame: [" it"] # Leading space
TTSTextFrame: [" it"]
TTSTextFrame: [" going"]
TTSTextFrame: ["?"]
```
Result: "Hello there! How's it"
Result: "Hello there! How's it going?"
Fragments with trailing spaces (concatenated):
```
TTSTextFrame: ["Hel"]
TTSTextFrame: ["lo "] # Trailing space
TTSTextFrame: ["to "] # Trailing space
TTSTextFrame: ["you"]
```
Result: "Hello to you"
Word-by-word fragments without spacing (joined with spaces):
Word-by-word fragments (joined with spaces):
```
TTSTextFrame: ["Hello"]
TTSTextFrame: ["there"]
TTSTextFrame: ["how"]
TTSTextFrame: ["are"]
TTSTextFrame: ["you"]
TTSTextFrame: ["there!"]
TTSTextFrame: ["How"]
TTSTextFrame: ["is"]
TTSTextFrame: ["it"]
TTSTextFrame: ["going?"]
```
Result: "Hello there how are you"
Result: "Hello there! How is it going?"
"""
if self._current_text_parts and self._aggregation_start_time:
has_leading_spaces = any(
part and part[0].isspace() for part in self._current_text_parts[1:]
)
has_trailing_spaces = any(
part and part[-1].isspace() for part in self._current_text_parts[:-1]
)
# Heuristic to detect pre-spaced fragments
uses_prespacing = False
if len(self._current_text_parts) > 1:
# Check if any fragment after the first one starts with whitespace
has_spaced_parts = any(
part and part[0].isspace() for part in self._current_text_parts[1:]
)
if has_spaced_parts:
uses_prespacing = True
# If there are embedded spaces in the fragments, use direct concatenation
contains_spacing_between_fragments = has_leading_spaces or has_trailing_spaces
# Apply corresponding joining method
if contains_spacing_between_fragments:
# Fragments already have spacing - just concatenate
# Apply appropriate joining method
if uses_prespacing:
# Pre-spaced fragments - just concatenate
content = "".join(self._current_text_parts)
else:
# Word-by-word fragments - join with spaces

View File

@@ -193,10 +193,3 @@ def parse_server_event(str):
except Exception as e:
print(f"Error parsing server event: {e}")
return None
class ContextWindowCompressionConfig(BaseModel):
"""Configuration for context window compression."""
sliding_window: Optional[bool] = Field(default=True)
trigger_tokens: Optional[int] = Field(default=None)

View File

@@ -223,16 +223,6 @@ class GeminiMultimodalLiveUserContextAggregator(OpenAIUserContextAggregator):
class GeminiMultimodalLiveAssistantContextAggregator(OpenAIAssistantContextAggregator):
# The LLMAssistantContextAggregator uses TextFrames to aggregate the LLM output,
# but the GeminiMultimodalLiveLLMService pushes both LLMTextFrames and TTSTextFrames.
# We need to override process_frame to skip LLMTextFrames, so that only the
# TTSTextFrames are processed. This ensures that the context gets only one set of
# messages. GeminiMultimodalLiveLLMService also pushes TranscriptionFrames, so we
# need to skip those as well, as they're also TextFrames.
async def process_frame(self, frame: Frame, direction: FrameDirection):
if not isinstance(frame, (LLMTextFrame, TranscriptionFrame)):
await super().process_frame(frame, direction)
async def handle_user_image_frame(self, frame: UserImageRawFrame):
# We don't want to store any images in the context. Revisit this later
# when the API evolves.
@@ -275,15 +265,6 @@ class GeminiVADParams(BaseModel):
silence_duration_ms: Optional[int] = Field(default=None)
class ContextWindowCompressionParams(BaseModel):
"""Parameters for context window compression."""
enabled: bool = Field(default=False)
trigger_tokens: Optional[int] = Field(
default=None
) # None = use default (80% of context window)
class InputParams(BaseModel):
frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0)
max_tokens: Optional[int] = Field(default=4096, ge=1)
@@ -299,7 +280,6 @@ class InputParams(BaseModel):
default=GeminiMediaResolution.UNSPECIFIED
)
vad: Optional[GeminiVADParams] = Field(default=None)
context_window_compression: Optional[ContextWindowCompressionParams] = Field(default=None)
extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
@@ -354,6 +334,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
self._bot_is_speaking = False
self._user_audio_buffer = bytearray()
self._bot_audio_buffer = bytearray()
self._bot_text_buffer = ""
self._sample_rate = 24000
@@ -374,9 +355,6 @@ class GeminiMultimodalLiveLLMService(LLMService):
"language": self._language_code,
"media_resolution": params.media_resolution,
"vad": params.vad,
"context_window_compression": params.context_window_compression.model_dump()
if params.context_window_compression
else {},
"extra": params.extra if isinstance(params.extra, dict) else {},
}
@@ -436,9 +414,7 @@ class GeminiMultimodalLiveLLMService(LLMService):
#
async def _handle_interruption(self):
self._bot_is_speaking = False
await self.push_frame(TTSStoppedFrame())
await self.push_frame(LLMFullResponseEndFrame())
pass
async def _handle_user_started_speaking(self, frame):
self._user_is_speaking = True
@@ -461,12 +437,10 @@ class GeminiMultimodalLiveLLMService(LLMService):
text = await self._transcribe_audio(audio, context)
if not text:
return
# Sometimes the transcription contains newlines; we want to remove them.
cleaned_text = text.rstrip("\n")
logger.debug(f"[Transcription:user] {cleaned_text}")
context.add_message({"role": "user", "content": [{"type": "text", "text": cleaned_text}]})
logger.debug(f"[Transcription:user] {text}")
context.add_message({"role": "user", "content": [{"type": "text", "text": text}]})
await self.push_frame(
TranscriptionFrame(text=cleaned_text, user_id="user", timestamp=time_now_iso8601())
TranscriptionFrame(text=text, user_id="user", timestamp=time_now_iso8601())
)
async def _transcribe_audio(self, audio, context):
@@ -587,21 +561,6 @@ class GeminiMultimodalLiveLLMService(LLMService):
}
}
# Add context window compression if enabled
if self._settings.get("context_window_compression", {}).get("enabled", False):
compression_config = {}
# Add sliding window (always true if compression is enabled)
compression_config["sliding_window"] = {}
# Add trigger_tokens if specified
trigger_tokens = self._settings.get("context_window_compression", {}).get(
"trigger_tokens"
)
if trigger_tokens is not None:
compression_config["trigger_tokens"] = trigger_tokens
config_data["setup"]["context_window_compression"] = compression_config
# Add VAD configuration if provided
if self._settings.get("vad"):
vad_config = {}
@@ -852,6 +811,14 @@ class GeminiMultimodalLiveLLMService(LLMService):
if not part:
return
text = part.text
if text:
if not self._bot_text_buffer:
await self.push_frame(LLMFullResponseStartFrame())
self._bot_text_buffer += text
await self.push_frame(LLMTextFrame(text=text))
inline_data = part.inlineData
if not inline_data:
return
@@ -866,7 +833,6 @@ class GeminiMultimodalLiveLLMService(LLMService):
if not self._bot_is_speaking:
self._bot_is_speaking = True
await self.push_frame(TTSStartedFrame())
await self.push_frame(LLMFullResponseStartFrame())
self._bot_audio_buffer.extend(audio)
frame = TTSAudioRawFrame(
@@ -892,20 +858,24 @@ class GeminiMultimodalLiveLLMService(LLMService):
async def _handle_evt_turn_complete(self, evt):
self._bot_is_speaking = False
text = self._bot_text_buffer
self._bot_text_buffer = ""
if text:
await self.push_frame(LLMFullResponseEndFrame())
await self.push_frame(TTSStoppedFrame())
await self.push_frame(LLMFullResponseEndFrame())
async def _handle_evt_output_transcription(self, evt):
if not evt.serverContent.outputTranscription:
return
text = evt.serverContent.outputTranscription.text
if not text:
return
await self.push_frame(LLMTextFrame(text=text))
await self.push_frame(TTSTextFrame(text=text))
if text:
await self.push_frame(LLMFullResponseStartFrame())
await self.push_frame(LLMTextFrame(text=text))
await self.push_frame(TTSTextFrame(text=text))
await self.push_frame(LLMFullResponseEndFrame())
def create_context_aggregator(
self,
@@ -936,6 +906,6 @@ class GeminiMultimodalLiveLLMService(LLMService):
GeminiMultimodalLiveContext.upgrade(context)
user = GeminiMultimodalLiveUserContextAggregator(context, params=user_params)
assistant_params.expect_stripped_words = False
assistant_params.expect_stripped_words = True
assistant = GeminiMultimodalLiveAssistantContextAggregator(context, params=assistant_params)
return GeminiMultimodalLiveContextAggregatorPair(_user=user, _assistant=assistant)

View File

@@ -14,7 +14,6 @@ from pipecat.frames.frames import (
FunctionCallResultFrame,
LLMMessagesUpdateFrame,
LLMSetToolsFrame,
LLMTextFrame,
)
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.frame_processor import FrameDirection
@@ -171,14 +170,6 @@ class OpenAIRealtimeUserContextAggregator(OpenAIUserContextAggregator):
class OpenAIRealtimeAssistantContextAggregator(OpenAIAssistantContextAggregator):
# The LLMAssistantContextAggregator uses TextFrames to aggregate the LLM output,
# but the OpenAIRealtimeLLMService pushes LLMTextFrames and TTSTextFrames. We
# need to override process_frame to skip LLMTextFrames, so that only the TTSTextFrames
# are processed. This ensures that the context gets only one set of messages.
async def process_frame(self, frame: Frame, direction: FrameDirection):
if not isinstance(frame, LLMTextFrame):
await super().process_frame(frame, direction)
async def handle_function_call_result(self, frame: FunctionCallResultFrame):
await super().handle_function_call_result(frame)

View File

@@ -74,7 +74,7 @@ class RimeTTSService(AudioContextWordTTSService):
*,
api_key: str,
voice_id: str,
url: str = "wss://users.rime.ai/ws2",
url: str = "wss://users-ws.rime.ai/ws2",
model: str = "mistv2",
sample_rate: Optional[int] = None,
params: InputParams = InputParams(),

View File

@@ -12,9 +12,7 @@ from pipecat.frames.frames import (
FunctionCallInProgressFrame,
FunctionCallResultFrame,
InputAudioRawFrame,
InterimTranscriptionFrame,
STTMuteFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
@@ -102,45 +100,6 @@ class TestSTTMuteFilter(unittest.IsolatedAsyncioTestCase):
expected_down_frames=expected_returned_frames,
)
async def test_transcription_frames_with_always_strategy(self):
    """Transcriptions are dropped while the bot speaks and pass afterwards."""
    always_config = STTMuteConfig(strategies={STTMuteStrategy.ALWAYS})
    mute_filter = STTMuteFilter(config=always_config)

    # Phase 1: the bot is speaking, so both transcription kinds must be muted.
    while_bot_speaks = [
        BotStartedSpeakingFrame(),
        SleepFrame(sleep=0.1),  # Wait for StartedSpeaking to process
        InterimTranscriptionFrame(
            user_id="user1", text="This should be suppressed", timestamp="1234567890"
        ),
        TranscriptionFrame(
            user_id="user1", text="This should be suppressed", timestamp="1234567890"
        ),
        SleepFrame(sleep=0.1),  # Wait for transcription frames to queue
        BotStoppedSpeakingFrame(),
    ]

    # Phase 2: the bot has stopped, so transcriptions flow through again.
    after_bot_stops = [
        InterimTranscriptionFrame(
            user_id="user1", text="This should pass", timestamp="1234567891"
        ),
        TranscriptionFrame(
            user_id="user1", text="This should pass through", timestamp="1234567891"
        ),
    ]

    expected_down_types = [
        BotStartedSpeakingFrame,
        STTMuteFrame,  # mute=True
        BotStoppedSpeakingFrame,
        STTMuteFrame,  # mute=False
        InterimTranscriptionFrame,  # Only passes through after bot stops speaking
        TranscriptionFrame,  # Only passes through after bot stops speaking
    ]

    await run_test(
        mute_filter,
        frames_to_send=while_bot_speaks + after_bot_stops,
        expected_down_frames=expected_down_types,
    )
# TODO: Revisit once we figure out how to test SystemFrames and DataFrames
# async def test_function_call_strategy(self):
# filter = STTMuteFilter(config=STTMuteConfig(strategies={STTMuteStrategy.FUNCTION_CALL}))