Compare commits
4 Commits
aleix/intr
...
hush/trans
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5e6979cf95 | ||
|
|
b1e9dc5bb4 | ||
|
|
2d06cd2109 | ||
|
|
1a237ddae8 |
@@ -29,18 +29,30 @@ from runner import configure
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import (
|
||||
BotSpeakingFrame,
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
LLMMessagesFrame,
|
||||
OutputImageRawFrame,
|
||||
SpriteFrame,
|
||||
STTMuteFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.filters.stt_mute_filter import STTMuteConfig, STTMuteFilter, STTMuteStrategy
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.processors.frameworks.rtvi import RTVIConfig, RTVIObserver, RTVIProcessor
|
||||
from pipecat.processors.frameworks.rtvi import (
|
||||
RTVIConfig,
|
||||
RTVIObserver,
|
||||
RTVIProcessor,
|
||||
RTVIServerMessageFrame,
|
||||
)
|
||||
from pipecat.services.deepgram import DeepgramSTTService
|
||||
from pipecat.services.elevenlabs import ElevenLabsTTSService
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
@@ -49,6 +61,51 @@ load_dotenv(override=True)
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
class TranscriptionMuteProcessor(FrameProcessor):
    """Takes in STTMuteFrame and mutes transcription frames based on its content.

    The processor remembers the most recent ``STTMuteFrame.mute`` value.
    Each ``STTMuteFrame`` is replaced by an ``RTVIServerMessageFrame``
    carrying a ``"user-muted-event"`` payload, which is what gets forwarded
    (the original ``STTMuteFrame`` itself is not pushed on).

    While muted, ``TranscriptionFrame``, ``InterimTranscriptionFrame`` and
    ``LLMMessagesFrame`` instances are dropped; every other frame is
    forwarded unchanged in the direction it arrived.
    """

    def __init__(self, **kwargs):
        """Initialise the processor in the un-muted state.

        Args:
            **kwargs: Passed through unchanged to ``FrameProcessor.__init__``.
        """
        super().__init__(**kwargs)
        # Last mute state seen in an STTMuteFrame; starts un-muted.
        self._is_muted = False

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Process incoming frames and mute TranscriptionFrame based on STTMuteFrame content.

        Args:
            frame: The incoming frame to process
            direction: The direction of frame flow in the pipeline
        """
        await super().process_frame(frame, direction)

        if isinstance(frame, STTMuteFrame):
            self._is_muted = frame.mute

            # Swap the mute frame for a server message describing the new
            # state. It is forwarded exactly once by the fall-through branch
            # at the bottom of this method. The previous code additionally
            # called ``self.push_frame(frame)`` here without ``await``; that
            # created a coroutine that never ran (and emitted a "coroutine
            # was never awaited" RuntimeWarning), so the call is removed
            # rather than awaited, which would have sent the message twice.
            frame = RTVIServerMessageFrame(
                data={"type": "user-muted-event", "payload": {"is_muted": self._is_muted}}
            )

        if isinstance(
            frame,
            (
                TranscriptionFrame,
                InterimTranscriptionFrame,
                LLMMessagesFrame,
            ),
        ):
            # Only pass frames when not muted
            if not self._is_muted:
                await self.push_frame(frame, direction)
            else:
                logger.trace(
                    f"{frame.__class__.__name__} suppressed - Transcription STT currently muted"
                )
        else:
            await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
sprites = []
|
||||
script_dir = os.path.dirname(__file__)
|
||||
|
||||
@@ -128,7 +185,8 @@ async def main():
|
||||
camera_out_height=576,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
transcription_enabled=True,
|
||||
vad_audio_passthrough=True,
|
||||
# transcription_enabled=True,
|
||||
#
|
||||
# Spanish
|
||||
#
|
||||
@@ -183,9 +241,20 @@ async def main():
|
||||
#
|
||||
rtvi = RTVIProcessor(config=RTVIConfig(config=[]))
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
stt_mute_processor = STTMuteFilter(
|
||||
config=STTMuteConfig(strategies={STTMuteStrategy.ALWAYS}),
|
||||
)
|
||||
|
||||
transcription_mute_processor = TranscriptionMuteProcessor()
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
stt,
|
||||
stt_mute_processor,
|
||||
transcription_mute_processor,
|
||||
rtvi,
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
@@ -213,7 +282,7 @@ async def main():
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
await transport.capture_participant_transcription(participant["id"])
|
||||
# await transport.capture_participant_transcription(participant["id"])
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
|
||||
@transport.event_handler("on_participant_left")
|
||||
|
||||
Reference in New Issue
Block a user