Compare commits

...

4 Commits

Author SHA1 Message Date
James Hush
5e6979cf95 Remove logs 2025-03-20 14:28:53 +08:00
James Hush
b1e9dc5bb4 Remove extra imports 2025-03-20 14:27:58 +08:00
James Hush
2d06cd2109 Send message when user is muted 2025-03-20 14:27:03 +08:00
James Hush
1a237ddae8 This is working but RTVI still gets the transcript
FEAT: Example of muting LLMMessages to LLM.
2025-03-20 13:46:13 +08:00

View File

@@ -29,18 +29,30 @@ from runner import configure
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import (
BotSpeakingFrame,
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
Frame,
InterimTranscriptionFrame,
LLMMessagesFrame,
OutputImageRawFrame,
SpriteFrame,
STTMuteFrame,
TranscriptionFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.filters.stt_mute_filter import STTMuteConfig, STTMuteFilter, STTMuteStrategy
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.processors.frameworks.rtvi import RTVIConfig, RTVIObserver, RTVIProcessor
from pipecat.processors.frameworks.rtvi import (
RTVIConfig,
RTVIObserver,
RTVIProcessor,
RTVIServerMessageFrame,
)
from pipecat.services.deepgram import DeepgramSTTService
from pipecat.services.elevenlabs import ElevenLabsTTSService
from pipecat.services.openai import OpenAILLMService
from pipecat.transports.services.daily import DailyParams, DailyTransport
@@ -49,6 +61,51 @@ load_dotenv(override=True)
# Drop the handler registered under id 0, then log to stderr at DEBUG level.
# NOTE(review): `logger` is presumably loguru's logger (import not visible in this hunk) — confirm.
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
class TranscriptionMuteProcessor(FrameProcessor):
    """Gate transcription-related frames based on the most recent STTMuteFrame.

    Behavior:
      * An ``STTMuteFrame`` updates the internal mute flag and is replaced by
        an ``RTVIServerMessageFrame`` carrying a ``user-muted-event`` payload
        with the new state. The ``STTMuteFrame`` itself is not forwarded.
      * ``TranscriptionFrame``, ``InterimTranscriptionFrame`` and
        ``LLMMessagesFrame`` are dropped while muted and forwarded otherwise.
      * Every other frame is forwarded unchanged in the direction it arrived.
    """

    def __init__(self, **kwargs):
        """Create the processor in the unmuted state.

        Args:
            **kwargs: Passed through unchanged to ``FrameProcessor.__init__``.
        """
        super().__init__(**kwargs)
        self._is_muted = False

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Update mute state, notify on mute changes, and filter gated frames.

        Args:
            frame: The incoming frame to process.
            direction: The direction of frame flow in the pipeline.
        """
        await super().process_frame(frame, direction)

        if isinstance(frame, STTMuteFrame):
            self._is_muted = frame.mute
            mute_notice = RTVIServerMessageFrame(
                data={"type": "user-muted-event", "payload": {"is_muted": self._is_muted}}
            )
            # Push the notification exactly once and await it. The earlier
            # code also called `self.push_frame(...)` without `await`, which
            # created a coroutine that never ran, and relied on the trailing
            # fall-through push to actually deliver the message.
            # NOTE(review): the STTMuteFrame is consumed here, as before —
            # confirm nothing further down the pipeline needs to see it.
            await self.push_frame(mute_notice, direction)
            return

        gated_types = (TranscriptionFrame, InterimTranscriptionFrame, LLMMessagesFrame)
        if isinstance(frame, gated_types) and self._is_muted:
            # Drop the frame while muted.
            logger.trace(
                f"{frame.__class__.__name__} suppressed - Transcription STT currently muted"
            )
            return

        await self.push_frame(frame, direction)
# Module-level list, initially empty.
# NOTE(review): presumably filled with sprite frames further down (not visible in this hunk) — confirm.
sprites = []
# Directory portion of this script's path.
script_dir = os.path.dirname(__file__)
@@ -128,7 +185,8 @@ async def main():
camera_out_height=576,
vad_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
transcription_enabled=True,
vad_audio_passthrough=True,
# transcription_enabled=True,
#
# Spanish
#
@@ -183,9 +241,20 @@ async def main():
#
rtvi = RTVIProcessor(config=RTVIConfig(config=[]))
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
stt_mute_processor = STTMuteFilter(
config=STTMuteConfig(strategies={STTMuteStrategy.ALWAYS}),
)
transcription_mute_processor = TranscriptionMuteProcessor()
pipeline = Pipeline(
[
transport.input(),
stt,
stt_mute_processor,
transcription_mute_processor,
rtvi,
context_aggregator.user(),
llm,
@@ -213,7 +282,7 @@ async def main():
@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
await transport.capture_participant_transcription(participant["id"])
# await transport.capture_participant_transcription(participant["id"])
await task.queue_frames([context_aggregator.user().get_context_frame()])
@transport.event_handler("on_participant_left")