Compare commits

...

2 Commits

Author SHA1 Message Date
James Hush
dddfd791e1 Replace hello with banana 2025-04-18 14:18:32 +08:00
James Hush
e721c2086c Add banana processor 2025-04-18 14:14:22 +08:00
2 changed files with 29 additions and 2 deletions

View File

@@ -9,10 +9,11 @@ import os
from dotenv import load_dotenv
from loguru import logger
from pipecat.frames.frames import EndFrame, TTSSpeakFrame
from pipecat.frames.frames import EndFrame, TranscriptionFrame, TTSSpeakFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.frame_processor import FrameProcessor
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.transports.base_transport import TransportParams
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport

View File

@@ -10,7 +10,7 @@ from dotenv import load_dotenv
from loguru import logger
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import Frame, MetricsFrame
from pipecat.frames.frames import Frame, MetricsFrame, TranscriptionFrame, TTSSpeakFrame
from pipecat.metrics.metrics import (
LLMUsageMetricsData,
ProcessingMetricsData,
@@ -32,7 +32,30 @@ from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
load_dotenv(override=True)
# Custom processor that prints a message if it receives a TranscriptionFrame that says "banana"
class BananaProcessor(FrameProcessor):
"""A custom processor that listens for transcription frames containing the word 'banana'."""
def __init__(self):
super().__init__()
async def process_frame(self, frame: Frame, direction: FrameDirection):
# Ensure the super method is called first
await super().process_frame(frame, direction)
if isinstance(frame, TranscriptionFrame):
logger.debug(f"Received transcription frame: {frame.text}")
if "banana" in frame.text.lower():
logger.info("---- Received 'banana' in transcription frame")
# Push the frame after processing
await self.push_frame(frame)
class MetricsLogger(FrameProcessor):
def __init__(self):
super().__init__()
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
@@ -87,10 +110,13 @@ async def run_bot(webrtc_connection: SmallWebRTCConnection):
context = OpenAILLMContext(messages)
context_aggregator = llm.create_context_aggregator(context)
banana = BananaProcessor()
pipeline = Pipeline(
[
transport.input(),
stt,
banana,
context_aggregator.user(),
llm,
tts,