Compare commits
1 Commits
mb/remove-
...
hush/conte
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
77c82c64c0 |
@@ -13,12 +13,13 @@ from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.frames.frames import Frame, LLMContextFrame, LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
@@ -30,6 +31,44 @@ from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
FILTERED_WORDS = ["apple", "banana", "car"]
|
||||
|
||||
|
||||
class ContentFilterProcessor(FrameProcessor):
|
||||
"""Processor that filters LLMContextFrames containing specific words.
|
||||
|
||||
If the user's message contains any of the filtered words, the context
|
||||
is replaced with a message indicating the assistant cannot respond.
|
||||
"""
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, LLMContextFrame):
|
||||
# Check the last user message for filtered words
|
||||
messages = frame.context.messages
|
||||
if messages:
|
||||
last_message = messages[-1]
|
||||
content = last_message.get("content", "")
|
||||
if isinstance(content, str):
|
||||
content_lower = content.lower()
|
||||
if any(word in content_lower for word in FILTERED_WORDS):
|
||||
logger.info(f"Filtered content detected: {content}")
|
||||
# Create a new context with a filtered response instruction
|
||||
filtered_context = LLMContext(
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "The user is asking about something you cannot give an answer about. Tell them you don't know how to respond.",
|
||||
}
|
||||
]
|
||||
)
|
||||
await self.push_frame(LLMContextFrame(filtered_context), direction)
|
||||
return
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
|
||||
@@ -76,12 +115,14 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
context = LLMContext(messages)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
content_filter = ContentFilterProcessor()
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
context_aggregator.user(), # User responses
|
||||
content_filter, # Content filter
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
|
||||
Reference in New Issue
Block a user