Merge pull request #4501 from pipecat-ai/aleix/fix-filter-incomplete-tool-calls
Fix filter-incomplete + function-calling deadlock
This commit is contained in:
1
changelog/4501.fixed.md
Normal file
1
changelog/4501.fixed.md
Normal file
@@ -0,0 +1 @@
|
||||
- Fixed bot hangs when `filter_incomplete_user_turns` was enabled and the LLM responded by calling a tool. The user turn never finalized, so the assistant aggregator gated the tool-result context push and the LLM continuation never ran. Tool calls now finalize the turn the moment they start, before the function dispatches.
|
||||
@@ -68,9 +68,9 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
tts = OpenAITTSService(
|
||||
api_key=os.environ["OPENAI_API_KEY"],
|
||||
settings=OpenAITTSService.Settings(
|
||||
instructions="Please speak clearly and at a moderate pace.",
|
||||
voice="ballad",
|
||||
),
|
||||
instructions="Please speak clearly and at a moderate pace.",
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(
|
||||
|
||||
@@ -0,0 +1,201 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Example 22: Filter Incomplete Turns
|
||||
|
||||
Demonstrates LLM-based turn completion detection to suppress bot responses when
|
||||
the user was cut off mid-thought. The LLM outputs one of three markers:
|
||||
- ✓ (complete): User finished their thought, respond normally
|
||||
- ○ (incomplete short): User was cut off, wait ~5s then prompt
|
||||
- ◐ (incomplete long): User needs time to think, wait ~10s then prompt
|
||||
|
||||
When incomplete is detected, the bot's response is suppressed. After the timeout
|
||||
expires, the LLM is automatically prompted to re-engage the user.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
AssistantTurnStoppedMessage,
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
UserTurnStoppedMessage,
|
||||
)
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.turns.user_turn_strategies import FilterIncompleteUserTurnStrategies
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
# We use lambdas to defer transport parameter creation until the transport
|
||||
# type is selected at runtime.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def get_weather(params: FunctionCallParams, location: str):
|
||||
"""Return the current weather for a location.
|
||||
|
||||
A stub that always reports the same conditions — replace with a real
|
||||
weather API in production.
|
||||
|
||||
Args:
|
||||
location (str): The city and state or country, e.g. "Paris, France".
|
||||
"""
|
||||
await params.result_callback(
|
||||
{
|
||||
"location": location,
|
||||
"temperature_celsius": 22,
|
||||
"conditions": "partly cloudy",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.environ["DEEPGRAM_API_KEY"])
|
||||
|
||||
llm = OpenAILLMService(
|
||||
api_key=os.environ["OPENAI_API_KEY"],
|
||||
settings=OpenAILLMService.Settings(
|
||||
system_instruction=(
|
||||
"You are a helpful assistant in a voice conversation. Your "
|
||||
"responses will be spoken aloud, so avoid emojis, bullet "
|
||||
"points, or other formatting that can't be spoken. Respond to "
|
||||
"what the user said in a creative, helpful, and brief way. "
|
||||
"If the user asks about the weather, call the get_weather "
|
||||
"tool and speak the result back naturally."
|
||||
),
|
||||
),
|
||||
)
|
||||
llm.register_direct_function(get_weather)
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.environ["CARTESIA_API_KEY"],
|
||||
settings=CartesiaTTSService.Settings(
|
||||
voice="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
),
|
||||
)
|
||||
|
||||
context = LLMContext(tools=ToolsSchema(standard_tools=[get_weather]))
|
||||
# `FilterIncompleteUserTurnStrategies` pairs the default detector
|
||||
# chain with `LLMTurnCompletionUserTurnStopStrategy`: detectors
|
||||
# trigger LLM inference but the public `on_user_turn_stopped` event
|
||||
# fires only when the LLM confirms ✓. The LLM marks each response
|
||||
# with one of:
|
||||
# ✓ = complete (respond normally)
|
||||
# ○ = incomplete short (wait 5s, then prompt)
|
||||
# ◐ = incomplete long (wait 10s, then prompt)
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
user_turn_strategies=FilterIncompleteUserTurnStrategies(
|
||||
# Optional: customize turn completion behavior
|
||||
# config=UserTurnCompletionConfig(
|
||||
# incomplete_short_timeout=5.0,
|
||||
# incomplete_long_timeout=10.0,
|
||||
# incomplete_short_prompt="Custom prompt...",
|
||||
# incomplete_long_prompt="Custom prompt...",
|
||||
# instructions="Custom turn completion instructions...",
|
||||
# ),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
user_aggregator, # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
assistant_aggregator, # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
context.add_message(
|
||||
{"role": "developer", "content": "Please introduce yourself to the user."}
|
||||
)
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
@user_aggregator.event_handler("on_user_turn_stopped")
|
||||
async def on_user_turn_stopped(aggregator, strategy, message: UserTurnStoppedMessage):
|
||||
timestamp = f"[{message.timestamp}] " if message.timestamp else ""
|
||||
line = f"{timestamp}user: {message.content}"
|
||||
logger.info(f"Transcript: {line}")
|
||||
|
||||
@assistant_aggregator.event_handler("on_assistant_turn_stopped")
|
||||
async def on_assistant_turn_stopped(aggregator, message: AssistantTurnStoppedMessage):
|
||||
timestamp = f"[{message.timestamp}] " if message.timestamp else ""
|
||||
line = f"{timestamp}assistant: {message.content}"
|
||||
logger.info(f"Transcript: {line}")
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -242,6 +242,7 @@ TESTS_VIDEO_AVATAR = [
|
||||
|
||||
TESTS_TURN_MANAGEMENT = [
|
||||
("turn-management/turn-management-filter-incomplete-turns.py", EVAL_COMPLETE_TURN),
|
||||
("turn-management/turn-management-filter-incomplete-turns-function-calling.py", EVAL_WEATHER),
|
||||
]
|
||||
|
||||
TESTS_THINKING = [
|
||||
|
||||
@@ -20,6 +20,7 @@ from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallsStartedFrame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMMarkerFrame,
|
||||
@@ -222,6 +223,14 @@ class UserTurnCompletionLLMServiceMixin(FrameProcessor):
|
||||
# ensures graceful degradation if the LLM disobeys and outputs additional text.
|
||||
self._turn_suppressed = False
|
||||
self._turn_complete_found = False # True when ✓ (COMPLETE) is detected
|
||||
# Set when the LLM made a tool call during this turn. Informational
|
||||
# only — broadcasting is idempotency-gated by
|
||||
# ``_turn_completion_broadcasted``.
|
||||
self._turn_had_function_call = False
|
||||
# True once ``UserTurnInferenceCompletedFrame`` has been broadcast
|
||||
# for this turn. Prevents double-broadcast when ✓ and a tool call
|
||||
# both occur in the same turn.
|
||||
self._turn_completion_broadcasted = False
|
||||
|
||||
# Timeout handling
|
||||
self._user_turn_completion_config = UserTurnCompletionConfig()
|
||||
@@ -236,6 +245,27 @@ class UserTurnCompletionLLMServiceMixin(FrameProcessor):
|
||||
"""
|
||||
self._user_turn_completion_config = config
|
||||
|
||||
async def _broadcast_turn_completion(self):
|
||||
"""Broadcast ``UserTurnInferenceCompletedFrame`` at most once per turn.
|
||||
|
||||
Called from the two places we know the LLM has committed to a
|
||||
response for the current user turn:
|
||||
|
||||
- the ``✓`` marker is detected in the text stream
|
||||
- a ``FunctionCallsStartedFrame`` is emitted — the LLM committed
|
||||
to a tool call before producing (or instead of) a marker.
|
||||
|
||||
Broadcasting on the tool-call path matters for races: the
|
||||
downstream ``UserStoppedSpeakingFrame`` needs to propagate
|
||||
before the function actually executes and a
|
||||
``FunctionCallResultFrame`` flows back to the assistant
|
||||
aggregator.
|
||||
"""
|
||||
if self._turn_completion_broadcasted:
|
||||
return
|
||||
self._turn_completion_broadcasted = True
|
||||
await self.broadcast_frame(UserTurnInferenceCompletedFrame)
|
||||
|
||||
async def _start_incomplete_timeout(self, incomplete_type: Literal["short", "long"]):
|
||||
"""Start a timeout task for incomplete turn handling.
|
||||
|
||||
@@ -325,6 +355,8 @@ class UserTurnCompletionLLMServiceMixin(FrameProcessor):
|
||||
self._turn_text_buffer = ""
|
||||
self._turn_suppressed = False
|
||||
self._turn_complete_found = False
|
||||
self._turn_had_function_call = False
|
||||
self._turn_completion_broadcasted = False
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames, handling turn completion state resets.
|
||||
@@ -351,7 +383,14 @@ class UserTurnCompletionLLMServiceMixin(FrameProcessor):
|
||||
frame: The frame to push downstream.
|
||||
direction: The direction of frame flow. Defaults to downstream.
|
||||
"""
|
||||
if isinstance(frame, LLMFullResponseEndFrame):
|
||||
if isinstance(frame, FunctionCallsStartedFrame):
|
||||
self._turn_had_function_call = True
|
||||
# Broadcast turn completion now, before the function dispatches
|
||||
# — gives ``UserStoppedSpeakingFrame`` maximum time to propagate
|
||||
# so the assistant aggregator's ``_user_speaking`` is False by
|
||||
# the time a ``FunctionCallResultFrame`` arrives.
|
||||
await self._broadcast_turn_completion()
|
||||
elif isinstance(frame, LLMFullResponseEndFrame):
|
||||
await self._turn_reset()
|
||||
|
||||
await super().push_frame(frame, direction)
|
||||
@@ -427,7 +466,9 @@ class UserTurnCompletionLLMServiceMixin(FrameProcessor):
|
||||
# LLMTurnCompletionUserTurnStopStrategy) can fire
|
||||
# `on_user_turn_stopped`. Must fire before the marker so
|
||||
# downstream consumers see the signal before the response.
|
||||
await self.broadcast_frame(UserTurnInferenceCompletedFrame)
|
||||
# Idempotent: a tool call earlier in the turn may have
|
||||
# already broadcast.
|
||||
await self._broadcast_turn_completion()
|
||||
|
||||
# Push the marker as a sideband signal that the assistant
|
||||
# aggregator will prepend to the upcoming aggregated text,
|
||||
|
||||
Reference in New Issue
Block a user