- Removed the backend client compatibility wrapper and associated methods to streamline backend integration. - Updated session management to utilize control plane gateways and runtime configuration providers. - Adjusted TTS service implementations to remove the EdgeTTS service and simplify service dependencies. - Enhanced documentation to reflect changes in backend integration and service architecture. - Updated configuration files to remove deprecated TTS provider options and clarify available settings.
2765 lines
117 KiB
Python
2765 lines
117 KiB
Python
"""Full duplex audio pipeline for AI voice conversation.
|
||
|
||
This module implements the core duplex pipeline that orchestrates:
|
||
- VAD (Voice Activity Detection)
|
||
- EOU (End of Utterance) Detection
|
||
- ASR (Automatic Speech Recognition) - optional
|
||
- LLM (Language Model)
|
||
- TTS (Text-to-Speech)
|
||
|
||
Inspired by pipecat's frame-based architecture and active-call's
|
||
event-driven design.
|
||
"""
|
||
|
||
import asyncio
|
||
import audioop
|
||
import io
|
||
import json
|
||
import time
|
||
import uuid
|
||
import wave
|
||
from pathlib import Path
|
||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||
|
||
import numpy as np
|
||
import aiohttp
|
||
from loguru import logger
|
||
|
||
from app.config import settings
|
||
from app.service_factory import DefaultRealtimeServiceFactory
|
||
from core.conversation import ConversationManager, ConversationState
|
||
from core.events import get_event_bus
|
||
from core.ports import (
|
||
ASRPort,
|
||
ASRServiceSpec,
|
||
LLMPort,
|
||
LLMServiceSpec,
|
||
RealtimeServiceFactory,
|
||
TTSPort,
|
||
TTSServiceSpec,
|
||
)
|
||
from core.tool_executor import execute_server_tool
|
||
from core.transports import BaseTransport
|
||
from models.ws_v1 import ev
|
||
from processors.eou import EouDetector
|
||
from processors.vad import SileroVAD, VADProcessor
|
||
from services.base import LLMMessage, LLMStreamEvent
|
||
from services.streaming_text import extract_tts_sentence, has_spoken_content
|
||
|
||
|
||
class DuplexPipeline:
    """
    Full duplex audio pipeline for AI voice conversation.

    Handles bidirectional audio flow with:
    - User speech detection and transcription
    - AI response generation
    - Text-to-speech synthesis
    - Barge-in (interruption) support

    Architecture (inspired by pipecat):

    User Audio → VAD → EOU → [ASR] → LLM → TTS → Audio Out
                        ↓
            Barge-in Detection → Interrupt
    """

    # --- Sentence segmentation of streamed LLM text for TTS ---
    # Hard sentence terminators (CJK fullwidth + ASCII punctuation + newline).
    _SENTENCE_END_CHARS = frozenset({"。", "!", "?", ".", "!", "?", "\n"})
    # Characters that may trail a terminator and still belong to the sentence.
    _SENTENCE_TRAILING_CHARS = frozenset({"。", "!", "?", ".", "!", "?", "…", "~", "~", "\n"})
    # Closing quotes/brackets that may follow a sentence terminator.
    _SENTENCE_CLOSERS = frozenset({'"', "'", "”", "’", ")", "]", "}", ")", "】", "」", "』", "》"})
    # Minimum spoken characters required before a sentence split is allowed.
    _MIN_SPLIT_SPOKEN_CHARS = 6

    # --- Tool execution timeouts (seconds) ---
    _TOOL_WAIT_TIMEOUT_SECONDS = 60.0  # waiting on a client-side tool response
    _SERVER_TOOL_TIMEOUT_SECONDS = 15.0  # server-side tool execution

    # --- Logical transport track identifiers ---
    TRACK_AUDIO_IN = "audio_in"
    TRACK_AUDIO_OUT = "audio_out"
    TRACK_CONTROL = "control"

    _PCM_FRAME_BYTES = 640  # 16k mono pcm_s16le, 20ms
    # Throttle intervals for incremental delta events pushed to the client.
    _ASR_DELTA_THROTTLE_MS = 500
    _LLM_DELTA_THROTTLE_MS = 80
    # Hard cap on the duration of a single ASR capture window.
    _ASR_CAPTURE_MAX_MS = 15000
    # Silence pre-roll inserted before a pre-recorded opener.
    _OPENER_PRE_ROLL_MS = 180

    # Built-in tool schemas, keyed by canonical tool name. Each entry follows
    # the OpenAI function-calling shape: name / description / JSON-Schema
    # "parameters" object.
    _DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = {
        "current_time": {
            "name": "current_time",
            "description": "Get current local time",
            "parameters": {
                "type": "object",
                "properties": {},
                "required": [],
            },
        },
        "calculator": {
            "name": "calculator",
            "description": "Execute a math expression",
            "parameters": {
                "type": "object",
                "properties": {
                    "expression": {"type": "string", "description": "Math expression, e.g. 2 + 3 * 4"},
                },
                "required": ["expression"],
            },
        },
        "code_interpreter": {
            "name": "code_interpreter",
            "description": "Safely evaluate a Python expression",
            "parameters": {
                "type": "object",
                "properties": {
                    "code": {"type": "string", "description": "Python expression to evaluate"},
                },
                "required": ["code"],
            },
        },
        "turn_on_camera": {
            "name": "turn_on_camera",
            "description": "Turn on client camera",
            "parameters": {
                "type": "object",
                "properties": {},
                "required": [],
            },
        },
        "turn_off_camera": {
            "name": "turn_off_camera",
            "description": "Turn off client camera",
            "parameters": {
                "type": "object",
                "properties": {},
                "required": [],
            },
        },
        "increase_volume": {
            "name": "increase_volume",
            "description": "Increase client volume",
            "parameters": {
                "type": "object",
                "properties": {
                    "step": {"type": "integer", "description": "Volume increase step, default 1"},
                },
                "required": [],
            },
        },
        "decrease_volume": {
            "name": "decrease_volume",
            "description": "Decrease client volume",
            "parameters": {
                "type": "object",
                "properties": {
                    "step": {"type": "integer", "description": "Volume decrease step, default 1"},
                },
                "required": [],
            },
        },
        "voice_msg_prompt": {
            "name": "voice_msg_prompt",
            "description": "Speak a message prompt on client side",
            "parameters": {
                "type": "object",
                "properties": {
                    "msg": {"type": "string", "description": "Message text to speak"},
                },
                "required": ["msg"],
            },
        },
        "text_msg_prompt": {
            "name": "text_msg_prompt",
            "description": "Show a text message prompt dialog on client side",
            "parameters": {
                "type": "object",
                "properties": {
                    "msg": {"type": "string", "description": "Message text to display"},
                },
                "required": ["msg"],
            },
        },
        "voice_choice_prompt": {
            "name": "voice_choice_prompt",
            "description": "Speak a question and show options on client side, then wait for selection",
            "parameters": {
                "type": "object",
                "properties": {
                    "question": {"type": "string", "description": "Question text to show"},
                    "options": {
                        "type": "array",
                        "description": "Selectable options (string or object with id/label/value)",
                        "minItems": 2,
                        "items": {
                            "anyOf": [
                                {"type": "string"},
                                {
                                    "type": "object",
                                    "properties": {
                                        "id": {"type": "string"},
                                        "label": {"type": "string"},
                                        "value": {"type": "string"},
                                    },
                                    "required": ["label"],
                                },
                            ]
                        },
                    },
                    "voice_text": {
                        "type": "string",
                        "description": "Optional voice text. Falls back to question when omitted.",
                    },
                },
                "required": ["question", "options"],
            },
        },
        "text_choice_prompt": {
            "name": "text_choice_prompt",
            "description": "Show a text-only choice prompt on client side and wait for selection",
            "parameters": {
                "type": "object",
                "properties": {
                    "question": {"type": "string", "description": "Question text to show"},
                    "options": {
                        "type": "array",
                        "description": "Selectable options (string or object with id/label/value)",
                        "minItems": 2,
                        "items": {
                            "anyOf": [
                                {"type": "string"},
                                {
                                    "type": "object",
                                    "properties": {
                                        "id": {"type": "string"},
                                        "label": {"type": "string"},
                                        "value": {"type": "string"},
                                    },
                                    "required": ["label"],
                                },
                            ]
                        },
                    },
                },
                "required": ["question", "options"],
            },
        },
    }
    # Tools that are executed on the client device by default (UI / hardware
    # actions); everything else runs through the server tool executor.
    _DEFAULT_CLIENT_EXECUTORS = frozenset({
        "turn_on_camera",
        "turn_off_camera",
        "increase_volume",
        "decrease_volume",
        "voice_msg_prompt",
        "text_msg_prompt",
        "voice_choice_prompt",
        "text_choice_prompt",
    })
    # Alternate/legacy tool names mapped to their canonical schema names.
    _TOOL_NAME_ALIASES = {
        "voice_message_prompt": "voice_msg_prompt",
    }
|
||
|
||
@classmethod
|
||
def _normalize_tool_name(cls, raw_name: Any) -> str:
|
||
name = str(raw_name or "").strip()
|
||
if not name:
|
||
return ""
|
||
return cls._TOOL_NAME_ALIASES.get(name, name)
|
||
|
||
    def __init__(
        self,
        transport: BaseTransport,
        session_id: str,
        llm_service: Optional[LLMPort] = None,
        tts_service: Optional[TTSPort] = None,
        asr_service: Optional[ASRPort] = None,
        system_prompt: Optional[str] = None,
        greeting: Optional[str] = None,
        knowledge_searcher: Optional[
            Callable[..., Awaitable[List[Dict[str, Any]]]]
        ] = None,
        tool_resource_resolver: Optional[
            Callable[[str], Awaitable[Optional[Dict[str, Any]]]]
        ] = None,
        server_tool_executor: Optional[
            Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]
        ] = None,
        service_factory: Optional[RealtimeServiceFactory] = None,
    ):
        """
        Initialize duplex pipeline.

        Args:
            transport: Transport for sending audio/events
            session_id: Session identifier
            llm_service: Optional injected LLM port implementation
            tts_service: Optional injected TTS port implementation
            asr_service: ASR service (optional)
            system_prompt: System prompt for LLM
            greeting: Optional greeting to speak on start
            knowledge_searcher: Optional async knowledge-base search callable
            tool_resource_resolver: Optional async resolver for tool resources;
                used to build the default server tool executor when no explicit
                ``server_tool_executor`` is given
            server_tool_executor: Optional async executor for server-side tools
            service_factory: Factory used to create LLM/ASR/TTS services lazily
                (defaults to DefaultRealtimeServiceFactory)
        """
        self.transport = transport
        self.session_id = session_id
        self.event_bus = get_event_bus()
        # Per-instance copies of the class-level track ids.
        self.track_audio_in = self.TRACK_AUDIO_IN
        self.track_audio_out = self.TRACK_AUDIO_OUT
        self.track_control = self.TRACK_CONTROL

        # Initialize VAD
        self.vad_model = SileroVAD(
            model_path=settings.vad_model_path,
            sample_rate=settings.sample_rate
        )
        self.vad_processor = VADProcessor(
            vad_model=self.vad_model,
            threshold=settings.vad_threshold
        )

        # Initialize EOU detector
        self.eou_detector = EouDetector(
            silence_threshold_ms=settings.vad_eou_threshold_ms,
            min_speech_duration_ms=settings.vad_min_speech_duration_ms
        )

        # Initialize services
        self.llm_service = llm_service
        self.tts_service = tts_service
        self.asr_service = asr_service  # Will be initialized in start()
        self._service_factory = service_factory or DefaultRealtimeServiceFactory()
        self._knowledge_searcher = knowledge_searcher
        self._tool_resource_resolver = tool_resource_resolver
        self._server_tool_executor = server_tool_executor

        # Track last sent transcript to avoid duplicates
        self._last_sent_transcript = ""
        self._pending_transcript_delta: str = ""
        self._last_transcript_delta_emit_ms: float = 0.0

        # Conversation manager
        self.conversation = ConversationManager(
            system_prompt=system_prompt,
            greeting=greeting
        )

        # State
        self._running = True
        self._is_bot_speaking = False
        self._current_turn_task: Optional[asyncio.Task] = None
        self._audio_buffer: bytes = b""
        # 16-bit mono PCM → 2 bytes per sample at the configured sample rate.
        max_buffer_seconds = settings.max_audio_buffer_seconds
        self._max_audio_buffer_bytes = int(settings.sample_rate * 2 * max_buffer_seconds)
        self._asr_start_min_speech_ms: int = settings.asr_start_min_speech_ms
        self._asr_capture_active: bool = False
        self._asr_capture_started_ms: float = 0.0
        self._pending_speech_audio: bytes = b""
        # Keep a short rolling pre-speech window so VAD transition latency
        # does not clip the first phoneme/character sent to ASR.
        pre_speech_ms = settings.asr_pre_speech_ms
        self._asr_pre_speech_bytes = int(settings.sample_rate * 2 * (pre_speech_ms / 1000.0))
        self._pre_speech_buffer: bytes = b""
        # Add a tiny trailing silence tail before final ASR to avoid
        # clipping the last phoneme at utterance boundaries.
        asr_final_tail_ms = settings.asr_final_tail_ms
        self._asr_final_tail_bytes = int(settings.sample_rate * 2 * (asr_final_tail_ms / 1000.0))
        self._last_vad_status: str = "Silence"
        self._process_lock = asyncio.Lock()
        # Priority outbound dispatcher (lower value = higher priority).
        self._outbound_q: asyncio.PriorityQueue[Tuple[int, int, str, Any]] = asyncio.PriorityQueue()
        self._outbound_seq = 0
        self._outbound_task: Optional[asyncio.Task] = None
        self._drop_outbound_audio = False
        self._audio_out_frame_buffer: bytes = b""

        # Interruption handling
        self._interrupt_event = asyncio.Event()

        # Latency tracking - TTFB (Time to First Byte)
        self._turn_start_time: Optional[float] = None
        self._first_audio_sent: bool = False

        # Barge-in filtering - require minimum speech duration to interrupt
        self._barge_in_speech_start_time: Optional[float] = None
        self._barge_in_min_duration_ms: int = settings.barge_in_min_duration_ms
        self._barge_in_silence_tolerance_ms: int = settings.barge_in_silence_tolerance_ms
        self._barge_in_speech_frames: int = 0  # Count speech frames
        self._barge_in_silence_frames: int = 0  # Count silence frames during potential barge-in

        # Runtime overrides injected from session.start metadata
        self._runtime_llm: Dict[str, Any] = {}
        self._runtime_asr: Dict[str, Any] = {}
        self._runtime_tts: Dict[str, Any] = {}
        self._runtime_output: Dict[str, Any] = {}
        self._runtime_system_prompt: Optional[str] = None
        self._runtime_first_turn_mode: str = "bot_first"
        self._runtime_greeting: Optional[str] = None
        self._runtime_generated_opener_enabled: Optional[bool] = None
        self._runtime_manual_opener_tool_calls: List[Any] = []
        self._runtime_opener_audio: Dict[str, Any] = {}
        self._runtime_barge_in_enabled: Optional[bool] = None
        self._runtime_barge_in_min_duration_ms: Optional[int] = None
        self._runtime_knowledge: Dict[str, Any] = {}
        self._runtime_knowledge_base_id: Optional[str] = None
        raw_default_tools = settings.tools if isinstance(settings.tools, list) else []
        self._runtime_tools: List[Any] = list(raw_default_tools)
        self._runtime_tool_executor: Dict[str, str] = {}
        self._runtime_tool_default_args: Dict[str, Dict[str, Any]] = {}
        self._runtime_tool_id_map: Dict[str, str] = {}
        self._runtime_tool_display_names: Dict[str, str] = {}
        self._runtime_tool_wait_for_response: Dict[str, bool] = {}
        # Futures keyed by tool call id, resolved when the client replies.
        self._pending_tool_waiters: Dict[str, asyncio.Future] = {}
        # Tool results that arrived before a waiter was registered.
        self._early_tool_results: Dict[str, Dict[str, Any]] = {}
        self._completed_tool_call_ids: set[str] = set()
        self._pending_client_tool_call_ids: set[str] = set()
        self._pending_client_playback_tts_ids: set[str] = set()
        # tts_id → {turn_id, response_id} context for playback acknowledgements.
        self._tts_playback_context: Dict[str, Dict[str, Optional[str]]] = {}
        self._last_client_played_tts_id: Optional[str] = None
        self._last_client_played_response_id: Optional[str] = None
        self._last_client_played_turn_id: Optional[str] = None
        self._last_client_played_at_ms: Optional[int] = None
        # Session-scoped sequence provider (see set_event_sequence_provider);
        # falls back to a local counter when unset.
        self._next_seq: Optional[Callable[[], int]] = None
        self._local_seq: int = 0

        # Cross-service correlation IDs
        self._turn_count: int = 0
        self._response_count: int = 0
        self._tts_count: int = 0
        self._utterance_count: int = 0
        self._current_turn_id: Optional[str] = None
        self._current_utterance_id: Optional[str] = None
        self._current_response_id: Optional[str] = None
        self._current_tts_id: Optional[str] = None
        self._pending_llm_delta: str = ""
        self._last_llm_delta_emit_ms: float = 0.0

        # Derive tool routing maps from the (possibly default) tool list.
        self._runtime_tool_executor = self._resolved_tool_executor_map()
        self._runtime_tool_default_args = self._resolved_tool_default_args_map()
        self._runtime_tool_id_map = self._resolved_tool_id_map()
        self._runtime_tool_display_names = self._resolved_tool_display_name_map()
        self._runtime_tool_wait_for_response = self._resolved_tool_wait_for_response_map()
        self._initial_greeting_emitted = False

        # Default server tool executor: wrap execute_server_tool, binding the
        # resource resolver when one was provided.
        if self._server_tool_executor is None:
            if self._tool_resource_resolver:
                async def _executor(call: Dict[str, Any]) -> Dict[str, Any]:
                    return await execute_server_tool(
                        call,
                        tool_resource_fetcher=self._tool_resource_resolver,
                    )

                self._server_tool_executor = _executor
            else:
                self._server_tool_executor = execute_server_tool

        logger.info(f"DuplexPipeline initialized for session {session_id}")
|
||
|
||
def set_event_sequence_provider(self, provider: Callable[[], int]) -> None:
|
||
"""Use session-scoped monotonic sequence provider for envelope events."""
|
||
self._next_seq = provider
|
||
|
||
def apply_runtime_overrides(self, metadata: Optional[Dict[str, Any]]) -> None:
|
||
"""
|
||
Apply runtime overrides from WS session.start metadata.
|
||
|
||
Expected metadata shape:
|
||
{
|
||
"systemPrompt": "...",
|
||
"greeting": "...",
|
||
"services": {
|
||
"llm": {...},
|
||
"asr": {...},
|
||
"tts": {...}
|
||
}
|
||
}
|
||
"""
|
||
if not metadata:
|
||
return
|
||
|
||
if "systemPrompt" in metadata:
|
||
self._runtime_system_prompt = str(metadata.get("systemPrompt") or "")
|
||
if self._runtime_system_prompt:
|
||
self.conversation.system_prompt = self._runtime_system_prompt
|
||
if "firstTurnMode" in metadata:
|
||
raw_mode = str(metadata.get("firstTurnMode") or "").strip().lower()
|
||
self._runtime_first_turn_mode = "user_first" if raw_mode == "user_first" else "bot_first"
|
||
if "greeting" in metadata:
|
||
greeting_payload = metadata.get("greeting")
|
||
if isinstance(greeting_payload, dict):
|
||
self._runtime_greeting = str(greeting_payload.get("text") or "")
|
||
generated_flag = self._coerce_bool(greeting_payload.get("generated"))
|
||
if generated_flag is not None:
|
||
self._runtime_generated_opener_enabled = generated_flag
|
||
else:
|
||
self._runtime_greeting = str(greeting_payload or "")
|
||
self.conversation.greeting = self._runtime_greeting or None
|
||
generated_opener_flag = self._coerce_bool(metadata.get("generatedOpenerEnabled"))
|
||
if generated_opener_flag is not None:
|
||
self._runtime_generated_opener_enabled = generated_opener_flag
|
||
if "manualOpenerToolCalls" in metadata:
|
||
manual_calls = metadata.get("manualOpenerToolCalls")
|
||
self._runtime_manual_opener_tool_calls = manual_calls if isinstance(manual_calls, list) else []
|
||
|
||
services = metadata.get("services") or {}
|
||
if isinstance(services, dict):
|
||
if isinstance(services.get("llm"), dict):
|
||
self._runtime_llm = services["llm"]
|
||
if isinstance(services.get("asr"), dict):
|
||
self._runtime_asr = services["asr"]
|
||
if isinstance(services.get("tts"), dict):
|
||
self._runtime_tts = services["tts"]
|
||
output = metadata.get("output") or {}
|
||
if isinstance(output, dict):
|
||
self._runtime_output = output
|
||
barge_in = metadata.get("bargeIn")
|
||
if isinstance(barge_in, dict):
|
||
barge_in_enabled = self._coerce_bool(barge_in.get("enabled"))
|
||
if barge_in_enabled is not None:
|
||
self._runtime_barge_in_enabled = barge_in_enabled
|
||
min_duration = barge_in.get("minDurationMs")
|
||
if isinstance(min_duration, (int, float, str)):
|
||
try:
|
||
self._runtime_barge_in_min_duration_ms = max(0, int(min_duration))
|
||
except (TypeError, ValueError):
|
||
self._runtime_barge_in_min_duration_ms = None
|
||
|
||
knowledge_base_id = metadata.get("knowledgeBaseId")
|
||
if knowledge_base_id is not None:
|
||
kb_id = str(knowledge_base_id).strip()
|
||
self._runtime_knowledge_base_id = kb_id or None
|
||
|
||
knowledge = metadata.get("knowledge")
|
||
if isinstance(knowledge, dict):
|
||
self._runtime_knowledge = knowledge
|
||
opener_audio = metadata.get("openerAudio")
|
||
if isinstance(opener_audio, dict):
|
||
self._runtime_opener_audio = dict(opener_audio)
|
||
kb_id = str(knowledge.get("kbId") or knowledge.get("knowledgeBaseId") or "").strip()
|
||
if kb_id:
|
||
self._runtime_knowledge_base_id = kb_id
|
||
|
||
tools_payload = metadata.get("tools")
|
||
if isinstance(tools_payload, list):
|
||
self._runtime_tools = tools_payload
|
||
self._runtime_tool_executor = self._resolved_tool_executor_map()
|
||
self._runtime_tool_default_args = self._resolved_tool_default_args_map()
|
||
self._runtime_tool_id_map = self._resolved_tool_id_map()
|
||
self._runtime_tool_display_names = self._resolved_tool_display_name_map()
|
||
self._runtime_tool_wait_for_response = self._resolved_tool_wait_for_response_map()
|
||
elif "tools" in metadata:
|
||
self._runtime_tools = []
|
||
self._runtime_tool_executor = {}
|
||
self._runtime_tool_default_args = {}
|
||
self._runtime_tool_id_map = {}
|
||
self._runtime_tool_display_names = {}
|
||
self._runtime_tool_wait_for_response = {}
|
||
|
||
if self.llm_service and hasattr(self.llm_service, "set_knowledge_config"):
|
||
self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
|
||
if self.llm_service and hasattr(self.llm_service, "set_tool_schemas"):
|
||
self.llm_service.set_tool_schemas(self._resolved_tool_schemas())
|
||
|
||
    def resolved_runtime_config(self) -> Dict[str, Any]:
        """Return current effective runtime configuration without secrets.

        Merges session overrides over static settings (runtime value wins,
        then settings, then a provider-specific default) and returns a
        client-safe summary: no API keys or credentials are included.
        """
        llm_provider = str(self._runtime_llm.get("provider") or settings.llm_provider).lower()
        llm_base_url = (
            self._runtime_llm.get("baseUrl")
            or settings.llm_api_url
            or self._default_llm_base_url(llm_provider)
        )
        tts_provider = str(self._runtime_tts.get("provider") or settings.tts_provider).lower()
        asr_provider = str(self._runtime_asr.get("provider") or settings.asr_provider).lower()
        output_mode = str(self._runtime_output.get("mode") or "").strip().lower()
        if not output_mode:
            # No explicit mode: derive it from whether TTS output is enabled.
            output_mode = "audio" if self._tts_output_enabled() else "text"

        # DashScope gets a known-good default model when none is configured.
        tts_model = str(
            self._runtime_tts.get("model")
            or settings.tts_model
            or (self._default_dashscope_tts_model() if self._is_dashscope_tts_provider(tts_provider) else "")
        )
        tts_config = {
            "enabled": self._tts_output_enabled(),
            "provider": tts_provider,
            "model": tts_model,
            "voice": str(self._runtime_tts.get("voice") or settings.tts_voice),
            "speed": float(self._runtime_tts.get("speed") or settings.tts_speed),
        }
        if self._is_dashscope_tts_provider(tts_provider):
            # Only DashScope exposes a commit/server_commit synthesis mode.
            tts_config["mode"] = self._resolved_dashscope_tts_mode()

        return {
            "output": {"mode": output_mode},
            "services": {
                "llm": {
                    "provider": llm_provider,
                    "model": str(self._runtime_llm.get("model") or settings.llm_model),
                    "baseUrl": llm_base_url,
                },
                "asr": {
                    "provider": asr_provider,
                    "model": str(self._runtime_asr.get("model") or settings.asr_model or ""),
                    "interimIntervalMs": int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms),
                    "minAudioMs": int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms),
                },
                "tts": tts_config,
            },
            "tools": {
                "allowlist": self._resolved_tool_allowlist(),
            },
            "opener": {
                "generated": self._generated_opener_enabled(),
                "manualToolCallCount": len(self._resolved_manual_opener_tool_calls()),
            },
            "tracks": {
                "audio_in": self.track_audio_in,
                "audio_out": self.track_audio_out,
                "control": self.track_control,
            },
        }
|
||
|
||
def _next_event_seq(self) -> int:
|
||
if self._next_seq:
|
||
return self._next_seq()
|
||
self._local_seq += 1
|
||
return self._local_seq
|
||
|
||
def _event_source(self, event_type: str) -> str:
|
||
if event_type.startswith("transcript.") or event_type.startswith("input.speech_"):
|
||
return "asr"
|
||
if event_type.startswith("assistant.response."):
|
||
return "llm"
|
||
if event_type.startswith("assistant.tool_"):
|
||
return "tool"
|
||
if event_type.startswith("output.audio.") or event_type == "metrics.ttfb":
|
||
return "tts"
|
||
return "system"
|
||
|
||
def _new_id(self, prefix: str, counter: int) -> str:
|
||
return f"{prefix}_{counter}_{uuid.uuid4().hex[:8]}"
|
||
|
||
def _start_turn(self) -> str:
|
||
self._turn_count += 1
|
||
self._current_turn_id = self._new_id("turn", self._turn_count)
|
||
self._current_utterance_id = None
|
||
self._current_response_id = None
|
||
self._current_tts_id = None
|
||
return self._current_turn_id
|
||
|
||
def _start_response(self) -> str:
|
||
self._response_count += 1
|
||
self._current_response_id = self._new_id("resp", self._response_count)
|
||
self._current_tts_id = None
|
||
return self._current_response_id
|
||
|
||
def _start_tts(self) -> str:
|
||
self._tts_count += 1
|
||
tts_id = self._new_id("tts", self._tts_count)
|
||
self._current_tts_id = tts_id
|
||
self._tts_playback_context[tts_id] = {
|
||
"turn_id": self._current_turn_id,
|
||
"response_id": self._current_response_id,
|
||
}
|
||
return tts_id
|
||
|
||
def _finalize_utterance(self) -> str:
|
||
if self._current_utterance_id:
|
||
return self._current_utterance_id
|
||
self._utterance_count += 1
|
||
self._current_utterance_id = self._new_id("utt", self._utterance_count)
|
||
if not self._current_turn_id:
|
||
self._start_turn()
|
||
return self._current_utterance_id
|
||
|
||
def _mark_client_playback_started(self, tts_id: Optional[str]) -> None:
|
||
normalized_tts_id = str(tts_id or "").strip()
|
||
if not normalized_tts_id:
|
||
return
|
||
self._pending_client_playback_tts_ids.add(normalized_tts_id)
|
||
|
||
def _clear_client_playback_tracking(self) -> None:
|
||
self._pending_client_playback_tts_ids.clear()
|
||
self._tts_playback_context.clear()
|
||
|
||
async def handle_output_audio_played(
|
||
self,
|
||
*,
|
||
tts_id: str,
|
||
response_id: Optional[str] = None,
|
||
turn_id: Optional[str] = None,
|
||
played_at_ms: Optional[int] = None,
|
||
played_ms: Optional[int] = None,
|
||
) -> None:
|
||
"""Record client-side playback completion for a TTS segment."""
|
||
normalized_tts_id = str(tts_id or "").strip()
|
||
if not normalized_tts_id:
|
||
return
|
||
|
||
was_pending = normalized_tts_id in self._pending_client_playback_tts_ids
|
||
self._pending_client_playback_tts_ids.discard(normalized_tts_id)
|
||
|
||
context = self._tts_playback_context.pop(normalized_tts_id, {})
|
||
resolved_response_id = str(response_id or context.get("response_id") or "").strip() or None
|
||
resolved_turn_id = str(turn_id or context.get("turn_id") or "").strip() or None
|
||
|
||
self._last_client_played_tts_id = normalized_tts_id
|
||
self._last_client_played_response_id = resolved_response_id
|
||
self._last_client_played_turn_id = resolved_turn_id
|
||
if isinstance(played_at_ms, int) and played_at_ms >= 0:
|
||
self._last_client_played_at_ms = played_at_ms
|
||
else:
|
||
self._last_client_played_at_ms = self._get_timestamp_ms()
|
||
|
||
duration_ms = played_ms if isinstance(played_ms, int) and played_ms >= 0 else None
|
||
logger.info(
|
||
f"[PlaybackAck] tts_id={normalized_tts_id} response_id={resolved_response_id or '-'} "
|
||
f"turn_id={resolved_turn_id or '-'} pending_before={was_pending} "
|
||
f"pending_now={len(self._pending_client_playback_tts_ids)} "
|
||
f"played_ms={duration_ms if duration_ms is not None else '-'}"
|
||
)
|
||
|
||
def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
|
||
event_type = str(event.get("type") or "")
|
||
source = str(event.get("source") or self._event_source(event_type))
|
||
track_id = event.get("trackId")
|
||
if not track_id:
|
||
if source == "asr":
|
||
track_id = self.track_audio_in
|
||
elif source in {"llm", "tts", "tool"}:
|
||
track_id = self.track_audio_out
|
||
else:
|
||
track_id = self.track_control
|
||
|
||
data = event.get("data")
|
||
if not isinstance(data, dict):
|
||
data = {}
|
||
explicit_turn_id = str(event.get("turn_id") or "").strip() or None
|
||
explicit_utterance_id = str(event.get("utterance_id") or "").strip() or None
|
||
explicit_response_id = str(event.get("response_id") or "").strip() or None
|
||
explicit_tts_id = str(event.get("tts_id") or "").strip() or None
|
||
if self._current_turn_id:
|
||
data.setdefault("turn_id", self._current_turn_id)
|
||
if self._current_utterance_id:
|
||
data.setdefault("utterance_id", self._current_utterance_id)
|
||
if self._current_response_id:
|
||
data.setdefault("response_id", self._current_response_id)
|
||
if self._current_tts_id:
|
||
data.setdefault("tts_id", self._current_tts_id)
|
||
if explicit_turn_id:
|
||
data["turn_id"] = explicit_turn_id
|
||
if explicit_utterance_id:
|
||
data["utterance_id"] = explicit_utterance_id
|
||
if explicit_response_id:
|
||
data["response_id"] = explicit_response_id
|
||
if explicit_tts_id:
|
||
data["tts_id"] = explicit_tts_id
|
||
|
||
for k, v in event.items():
|
||
if k in {
|
||
"type",
|
||
"timestamp",
|
||
"sessionId",
|
||
"seq",
|
||
"source",
|
||
"trackId",
|
||
"data",
|
||
"turn_id",
|
||
"utterance_id",
|
||
"response_id",
|
||
"tts_id",
|
||
}:
|
||
continue
|
||
data.setdefault(k, v)
|
||
|
||
event["sessionId"] = self.session_id
|
||
event["seq"] = self._next_event_seq()
|
||
event["source"] = source
|
||
event["trackId"] = track_id
|
||
event["data"] = data
|
||
return event
|
||
|
||
@staticmethod
|
||
def _coerce_bool(value: Any) -> Optional[bool]:
|
||
if isinstance(value, bool):
|
||
return value
|
||
if isinstance(value, (int, float)):
|
||
return bool(value)
|
||
if isinstance(value, str):
|
||
normalized = value.strip().lower()
|
||
if normalized in {"1", "true", "yes", "on", "enabled"}:
|
||
return True
|
||
if normalized in {"0", "false", "no", "off", "disabled"}:
|
||
return False
|
||
return None
|
||
|
||
@staticmethod
|
||
def _is_dashscope_tts_provider(provider: Any) -> bool:
|
||
normalized = str(provider or "").strip().lower()
|
||
return normalized == "dashscope"
|
||
|
||
@staticmethod
|
||
def _default_llm_base_url(provider: Any) -> Optional[str]:
|
||
normalized = str(provider or "").strip().lower()
|
||
if normalized == "siliconflow":
|
||
return "https://api.siliconflow.cn/v1"
|
||
return None
|
||
|
||
@staticmethod
|
||
def _default_dashscope_tts_model() -> str:
|
||
return "qwen3-tts-flash-realtime"
|
||
|
||
def _resolved_dashscope_tts_mode(self) -> str:
|
||
raw_mode = str(self._runtime_tts.get("mode") or settings.tts_mode or "commit").strip().lower()
|
||
if raw_mode in {"commit", "server_commit"}:
|
||
return raw_mode
|
||
return "commit"
|
||
|
||
def _use_engine_sentence_split_for_tts(self) -> bool:
|
||
tts_provider = str(self._runtime_tts.get("provider") or settings.tts_provider).strip().lower()
|
||
if not self._is_dashscope_tts_provider(tts_provider):
|
||
return True
|
||
# DashScope commit mode is client-driven and expects engine-side segmentation.
|
||
# server_commit mode lets DashScope handle segmentation on appended text.
|
||
return self._resolved_dashscope_tts_mode() != "server_commit"
|
||
|
||
def _tts_output_enabled(self) -> bool:
|
||
enabled = self._coerce_bool(self._runtime_tts.get("enabled"))
|
||
if enabled is not None:
|
||
return enabled
|
||
|
||
output_mode = str(self._runtime_output.get("mode") or "").strip().lower()
|
||
if output_mode in {"text", "text_only", "text-only"}:
|
||
return False
|
||
|
||
return True
|
||
|
||
def _generated_opener_enabled(self) -> bool:
|
||
return self._runtime_generated_opener_enabled is True
|
||
|
||
def _bot_starts_first(self) -> bool:
|
||
return self._runtime_first_turn_mode != "user_first"
|
||
|
||
def _barge_in_enabled(self) -> bool:
|
||
if self._runtime_barge_in_enabled is not None:
|
||
return self._runtime_barge_in_enabled
|
||
return True
|
||
|
||
def _resolved_barge_in_min_duration_ms(self) -> int:
|
||
if self._runtime_barge_in_min_duration_ms is not None:
|
||
return self._runtime_barge_in_min_duration_ms
|
||
return self._barge_in_min_duration_ms
|
||
|
||
def _barge_in_silence_tolerance_frames(self) -> int:
|
||
"""Convert silence tolerance from ms to frame count using current chunk size."""
|
||
chunk_ms = max(1, settings.chunk_size_ms)
|
||
return max(1, int(np.ceil(self._barge_in_silence_tolerance_ms / chunk_ms)))
|
||
|
||
async def _generate_runtime_greeting(self) -> Optional[str]:
    """Ask the LLM for a one-sentence Chinese opener for this session.

    Returns the cleaned greeting text, or None when no LLM service is
    available, generation fails, or the model returns empty output.
    """
    if not self.llm_service:
        return None

    system_context = (self.conversation.system_prompt or self._runtime_system_prompt or "").strip()
    # Keep context concise to avoid overloading greeting generation.
    if len(system_context) > 1200:
        system_context = system_context[:1200]
    system_prompt = (
        "你是语音通话助手的开场白生成器。"
        "请只输出一句自然、简洁、友好的中文开场白。"
        "不要使用引号,不要使用 markdown,不要加解释。"
    )
    user_prompt = "请生成一句中文开场白(不超过25个汉字)。"
    if system_context:
        user_prompt += f"\n\n以下是该助手的系统提示词,请据此决定语气、角色和边界:\n{system_context}"

    try:
        generated = await self.llm_service.generate(
            [
                LLMMessage(role="system", content=system_prompt),
                LLMMessage(role="user", content=user_prompt),
            ],
            temperature=0.7,
            max_tokens=64,
        )
    except Exception as exc:
        # Opener generation is best-effort; caller falls back to the static greeting.
        logger.warning(f"Failed to generate runtime greeting: {exc}")
        return None

    text = (generated or "").strip()
    if not text:
        return None
    # Strip stray wrapping quotes the model sometimes adds despite instructions.
    return text.strip().strip('"').strip("'")
async def start(self) -> None:
    """Start the pipeline and connect services.

    Builds LLM, TTS (only when audio output is enabled), and ASR services
    from runtime session metadata (self._runtime_*) with YAML settings as
    fallback, connects them, and launches the single outbound sender loop.
    Re-raises any startup failure after logging it.
    """
    try:
        # Connect LLM service
        if not self.llm_service:
            llm_provider = (self._runtime_llm.get("provider") or settings.llm_provider).lower()
            llm_api_key = self._runtime_llm.get("apiKey")
            llm_base_url = (
                self._runtime_llm.get("baseUrl")
                or settings.llm_api_url
                or self._default_llm_base_url(llm_provider)
            )
            llm_model = self._runtime_llm.get("model") or settings.llm_model
            self.llm_service = self._service_factory.create_llm_service(
                LLMServiceSpec(
                    provider=llm_provider,
                    model=str(llm_model),
                    api_key=str(llm_api_key).strip() if llm_api_key else None,
                    base_url=str(llm_base_url).strip() if llm_base_url else None,
                    system_prompt=self.conversation.system_prompt,
                    temperature=settings.llm_temperature,
                    knowledge_config=self._resolved_knowledge_config(),
                    knowledge_searcher=self._knowledge_searcher,
                )
            )

        # Push current knowledge/tool configuration even onto a pre-built service.
        if hasattr(self.llm_service, "set_knowledge_config"):
            self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
        if hasattr(self.llm_service, "set_tool_schemas"):
            self.llm_service.set_tool_schemas(self._resolved_tool_schemas())

        await self.llm_service.connect()

        tts_output_enabled = self._tts_output_enabled()

        # Connect TTS service only when audio output is enabled.
        if tts_output_enabled:
            if not self.tts_service:
                tts_provider = (self._runtime_tts.get("provider") or settings.tts_provider).lower()
                tts_api_key = self._runtime_tts.get("apiKey")
                tts_api_url = self._runtime_tts.get("baseUrl") or settings.tts_api_url
                tts_voice = self._runtime_tts.get("voice") or settings.tts_voice
                tts_model = self._runtime_tts.get("model") or settings.tts_model
                tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed)
                tts_mode = self._resolved_dashscope_tts_mode()
                runtime_mode = str(self._runtime_tts.get("mode") or "").strip()
                if runtime_mode and not self._is_dashscope_tts_provider(tts_provider):
                    logger.warning(
                        "services.tts.mode is DashScope-only and will be ignored "
                        f"for provider={tts_provider}"
                    )
                self.tts_service = self._service_factory.create_tts_service(
                    TTSServiceSpec(
                        provider=tts_provider,
                        api_key=str(tts_api_key).strip() if tts_api_key else None,
                        api_url=str(tts_api_url).strip() if tts_api_url else None,
                        voice=str(tts_voice),
                        model=str(tts_model).strip() if tts_model else None,
                        sample_rate=settings.sample_rate,
                        speed=tts_speed,
                        mode=str(tts_mode),
                    )
                )

            try:
                await self.tts_service.connect()
            except Exception as e:
                # Degrade gracefully: a dead TTS backend should not kill the call.
                logger.warning(f"TTS backend unavailable ({e}); falling back to default TTS adapter")
                self.tts_service = self._service_factory.create_tts_service(
                    TTSServiceSpec(
                        provider="mock",
                        voice="mock",
                        sample_rate=settings.sample_rate,
                    )
                )
                await self.tts_service.connect()
        else:
            self.tts_service = None
            logger.info("TTS output disabled by runtime metadata")

        # Connect ASR service
        if not self.asr_service:
            asr_provider = (self._runtime_asr.get("provider") or settings.asr_provider).lower()
            asr_api_key = self._runtime_asr.get("apiKey")
            asr_api_url = self._runtime_asr.get("baseUrl") or settings.asr_api_url
            asr_model = self._runtime_asr.get("model") or settings.asr_model
            asr_interim_interval = int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms)
            asr_min_audio_ms = int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms)

            self.asr_service = self._service_factory.create_asr_service(
                ASRServiceSpec(
                    provider=asr_provider,
                    sample_rate=settings.sample_rate,
                    language="auto",
                    api_key=str(asr_api_key).strip() if asr_api_key else None,
                    api_url=str(asr_api_url).strip() if asr_api_url else None,
                    model=str(asr_model).strip() if asr_model else None,
                    interim_interval_ms=asr_interim_interval,
                    min_audio_for_interim_ms=asr_min_audio_ms,
                    on_transcript=self._on_transcript_callback,
                )
            )

        await self.asr_service.connect()

        logger.info("DuplexPipeline services connected")
        # Ensure exactly one outbound sender loop is running.
        if not self._outbound_task or self._outbound_task.done():
            self._outbound_task = asyncio.create_task(self._outbound_loop())

    except Exception as e:
        logger.error(f"Failed to start pipeline: {e}")
        raise
async def emit_initial_greeting(self) -> None:
    """
    Emit opener after session activation.

    Ordering target:
    1) frontend receives `session.started` (shows connected/ready)
    2) frontend receives opener text event
    3) frontend receives opener audio events/chunks

    Idempotent: a guard flag makes repeat calls no-ops. Does nothing when
    the session is configured as user-first.
    """
    if self._initial_greeting_emitted:
        return

    self._initial_greeting_emitted = True
    if not self._bot_starts_first():
        return

    if self._generated_opener_enabled() and self._resolved_tool_schemas():
        # Run generated opener as a normal tool-capable assistant turn.
        # Use an empty user input so the opener can be driven by system prompt policy.
        if self._current_turn_task and not self._current_turn_task.done():
            logger.info("Skip initial generated opener: assistant turn already in progress")
            return
        self._current_turn_task = asyncio.create_task(self._handle_turn(""))
        logger.info("Initial generated opener started with tool-calling path")
        return

    # Manual opener path: run pre-configured tool calls before speaking.
    manual_opener_execution: Dict[str, List[Dict[str, Any]]] = {"toolCalls": [], "toolResults": []}
    if not self._generated_opener_enabled() and self._resolved_manual_opener_tool_calls():
        self._start_turn()
        self._start_response()
        manual_opener_execution = await self._execute_manual_opener_tool_calls()

    greeting_to_speak = self.conversation.greeting
    if self._generated_opener_enabled():
        generated_greeting = await self._generate_runtime_greeting()
        if generated_greeting:
            # Cache the generated greeting on the conversation for later reuse.
            greeting_to_speak = generated_greeting
            self.conversation.greeting = generated_greeting

    if not greeting_to_speak:
        # No text to speak; if manual opener tools ran, follow up with an
        # assistant turn that reacts to their results instead.
        if (
            not self._generated_opener_enabled()
            and manual_opener_execution.get("toolCalls")
            and not (self._current_turn_task and not self._current_turn_task.done())
        ):
            follow_up_context = self._build_manual_opener_follow_up_context(manual_opener_execution)
            self._current_turn_task = asyncio.create_task(
                self._handle_turn("", system_context=follow_up_context)
            )
            logger.info("Initial manual opener follow-up started")
        return

    if not self._current_turn_id:
        self._start_turn()
    if not self._current_response_id:
        self._start_response()
    # Text event first so the UI shows the opener before audio starts.
    await self._send_event(
        ev(
            "assistant.response.final",
            text=greeting_to_speak,
            trackId=self.track_audio_out,
        ),
        priority=20,
    )
    await self.conversation.add_assistant_turn(greeting_to_speak)

    # Give client mic capture a short head start so opener can be interrupted immediately.
    await asyncio.sleep(self._OPENER_PRE_ROLL_MS / 1000.0)

    used_preloaded_audio = await self._play_preloaded_opener_audio()
    if self._tts_output_enabled() and not used_preloaded_audio:
        # Keep opener text ahead of opener voice start.
        await self._speak(greeting_to_speak, audio_event_priority=30)
async def _play_preloaded_opener_audio(self) -> bool:
    """
    Play opener audio from runtime metadata cache or YAML-configured local file.

    Returns True when preloaded audio is played successfully; False when
    TTS output is disabled, no preloaded PCM is available, or playback
    fails (caller then falls back to live TTS).
    """
    if not self._tts_output_enabled():
        return False

    pcm_bytes = await self._load_preloaded_opener_pcm()
    if not pcm_bytes:
        return False

    try:
        # Make sure a prior barge-in isn't still discarding outbound audio.
        self._drop_outbound_audio = False
        tts_id = self._start_tts()
        self._mark_client_playback_started(tts_id)
        await self._send_event(
            {
                **ev(
                    "output.audio.start",
                    trackId=self.track_audio_out,
                )
            },
            priority=30,
        )

        self._is_bot_speaking = True
        await self._send_audio(pcm_bytes, priority=50)
        # Flush any sub-frame remainder as a final padded frame.
        await self._flush_audio_out_frames(priority=50)
        await self._send_event(
            {
                **ev(
                    "output.audio.end",
                    trackId=self.track_audio_out,
                )
            },
            priority=30,
        )
        return True
    except Exception as e:
        logger.warning(f"Failed to play preloaded opener audio, fallback to TTS: {e}")
        return False
    finally:
        # Always clear the speaking flag, success or failure.
        self._is_bot_speaking = False
async def _load_preloaded_opener_pcm(self) -> Optional[bytes]:
    """Load opener PCM bytes, preferring backend metadata over a local file.

    Resolution order:
    1) Runtime metadata: fetch pcmUrl (relative URLs resolved against
       settings.backend_url) with a 10s total timeout.
    2) YAML fallback: read settings.duplex_opener_audio_file; ".wav" is
       decoded/resampled, any other suffix is assumed raw pcm_s16le 16k mono.
    Returns None when neither source yields audio.
    """
    # 1) Runtime metadata from backend config
    opener_audio = self._runtime_opener_audio if isinstance(self._runtime_opener_audio, dict) else {}
    if bool(opener_audio.get("enabled")) and bool(opener_audio.get("ready")):
        pcm_url = str(opener_audio.get("pcmUrl") or "").strip()
        if pcm_url:
            resolved_url = pcm_url
            if pcm_url.startswith("/"):
                # Relative path: resolve against the configured backend base URL.
                backend_url = str(settings.backend_url or "").strip().rstrip("/")
                if backend_url:
                    resolved_url = f"{backend_url}{pcm_url}"
            try:
                timeout = aiohttp.ClientTimeout(total=10)
                async with aiohttp.ClientSession(timeout=timeout) as session:
                    async with session.get(resolved_url) as resp:
                        resp.raise_for_status()
                        payload = await resp.read()
                        if payload:
                            return payload
            except Exception as e:
                # Fetch failure is non-fatal; fall through to the local file.
                logger.warning(f"Failed to fetch opener audio from backend ({resolved_url}): {e}")

    # 2) Standalone fallback via YAML
    opener_audio_file = str(settings.duplex_opener_audio_file or "").strip()
    if not opener_audio_file:
        return None
    path = Path(opener_audio_file)
    if not path.is_absolute():
        path = (Path.cwd() / path).resolve()
    if not path.exists() or not path.is_file():
        logger.warning(f"Configured opener audio file does not exist: {path}")
        return None
    try:
        raw = path.read_bytes()
        suffix = path.suffix.lower()
        if suffix == ".wav":
            pcm, _ = self._wav_to_pcm16_mono_16k(raw)
            return pcm
        # .pcm raw pcm_s16le 16k mono
        return raw
    except Exception as e:
        logger.warning(f"Failed to read opener audio file {path}: {e}")
        return None
def _wav_to_pcm16_mono_16k(self, wav_bytes: bytes) -> Tuple[bytes, int]:
|
||
with wave.open(io.BytesIO(wav_bytes), "rb") as wav_file:
|
||
channels = wav_file.getnchannels()
|
||
sample_width = wav_file.getsampwidth()
|
||
sample_rate = wav_file.getframerate()
|
||
nframes = wav_file.getnframes()
|
||
raw = wav_file.readframes(nframes)
|
||
|
||
if sample_width != 2:
|
||
raise ValueError(f"Unsupported WAV sample width: {sample_width * 8}bit")
|
||
if channels > 1:
|
||
raw = audioop.tomono(raw, sample_width, 0.5, 0.5)
|
||
if sample_rate != 16000:
|
||
raw, _ = audioop.ratecv(raw, sample_width, 1, sample_rate, 16000, None)
|
||
duration_ms = int((len(raw) / (16000 * 2)) * 1000)
|
||
return raw, duration_ms
|
||
|
||
async def _enqueue_outbound(self, kind: str, payload: Any, priority: int) -> None:
|
||
"""Queue outbound message with priority ordering."""
|
||
self._outbound_seq += 1
|
||
await self._outbound_q.put((priority, self._outbound_seq, kind, payload))
|
||
|
||
async def _send_event(self, event: Dict[str, Any], priority: int = 20) -> None:
|
||
await self._enqueue_outbound("event", self._envelope_event(event), priority)
|
||
|
||
async def _send_audio(self, pcm_bytes: bytes, priority: int = 50) -> None:
|
||
if not pcm_bytes:
|
||
return
|
||
self._audio_out_frame_buffer += pcm_bytes
|
||
while len(self._audio_out_frame_buffer) >= self._PCM_FRAME_BYTES:
|
||
frame = self._audio_out_frame_buffer[: self._PCM_FRAME_BYTES]
|
||
self._audio_out_frame_buffer = self._audio_out_frame_buffer[self._PCM_FRAME_BYTES :]
|
||
await self._enqueue_outbound("audio", frame, priority)
|
||
|
||
async def _flush_audio_out_frames(self, priority: int = 50) -> None:
|
||
"""Flush remaining outbound audio as one padded 20ms PCM frame."""
|
||
if not self._audio_out_frame_buffer:
|
||
return
|
||
tail = self._audio_out_frame_buffer
|
||
self._audio_out_frame_buffer = b""
|
||
if len(tail) < self._PCM_FRAME_BYTES:
|
||
tail = tail + (b"\x00" * (self._PCM_FRAME_BYTES - len(tail)))
|
||
await self._enqueue_outbound("audio", tail, priority)
|
||
|
||
async def _emit_transcript_delta(self, text: str) -> None:
    """Emit an interim transcript chunk on the input audio track."""
    payload = {
        **ev(
            "transcript.delta",
            trackId=self.track_audio_in,
            text=text,
        )
    }
    await self._send_event(payload, priority=30)
async def _emit_llm_delta(
    self,
    text: str,
    *,
    turn_id: Optional[str] = None,
    utterance_id: Optional[str] = None,
    response_id: Optional[str] = None,
) -> None:
    """Emit an assistant text delta, tagging ids only when provided."""
    event = {
        **ev(
            "assistant.response.delta",
            trackId=self.track_audio_out,
            text=text,
        )
    }
    for key, value in (
        ("turn_id", turn_id),
        ("utterance_id", utterance_id),
        ("response_id", response_id),
    ):
        if value:
            event[key] = value
    await self._send_event(event, priority=20)
async def _flush_pending_llm_delta(
|
||
self,
|
||
*,
|
||
turn_id: Optional[str] = None,
|
||
utterance_id: Optional[str] = None,
|
||
response_id: Optional[str] = None,
|
||
) -> None:
|
||
if not self._pending_llm_delta:
|
||
return
|
||
chunk = self._pending_llm_delta
|
||
self._pending_llm_delta = ""
|
||
self._last_llm_delta_emit_ms = time.monotonic() * 1000.0
|
||
await self._emit_llm_delta(
|
||
chunk,
|
||
turn_id=turn_id,
|
||
utterance_id=utterance_id,
|
||
response_id=response_id,
|
||
)
|
||
|
||
async def _outbound_loop(self) -> None:
|
||
"""Single sender loop that enforces priority for interrupt events."""
|
||
while True:
|
||
_priority, _seq, kind, payload = await self._outbound_q.get()
|
||
try:
|
||
if kind == "stop":
|
||
return
|
||
if kind == "audio":
|
||
if self._drop_outbound_audio:
|
||
continue
|
||
await self.transport.send_audio(payload)
|
||
elif kind == "event":
|
||
await self.transport.send_event(payload)
|
||
except Exception as e:
|
||
logger.error(f"Outbound send error ({kind}): {e}")
|
||
finally:
|
||
self._outbound_q.task_done()
|
||
|
||
async def process_audio(self, pcm_bytes: bytes) -> None:
    """
    Process incoming audio chunk.

    This is the main entry point for audio from the user. Per chunk it:
    maintains a rolling pre-speech window, runs VAD, tracks barge-in while
    the bot is speaking, gates/feeds ASR capture, and checks for end of
    utterance (which triggers the LLM turn). Serialized by _process_lock.

    Args:
        pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
    """
    if not self._running:
        return

    try:
        async with self._process_lock:
            # Rolling pre-speech window, later used to prime ASR capture.
            if pcm_bytes:
                self._pre_speech_buffer += pcm_bytes
                if len(self._pre_speech_buffer) > self._asr_pre_speech_bytes:
                    self._pre_speech_buffer = self._pre_speech_buffer[-self._asr_pre_speech_bytes:]

            # 1. Process through VAD
            vad_result = self.vad_processor.process(pcm_bytes, settings.chunk_size_ms)

            vad_status = "Silence"
            if vad_result:
                event_type, probability = vad_result
                vad_status = "Speech" if event_type == "speaking" else "Silence"

                # Emit VAD event
                await self.event_bus.publish(event_type, {
                    "trackId": self.track_audio_in,
                    "probability": probability
                })
                await self._send_event(
                    ev(
                        "input.speech_started" if event_type == "speaking" else "input.speech_stopped",
                        trackId=self.track_audio_in,
                        probability=probability,
                    ),
                    priority=30,
                )
            else:
                # No state change - keep previous status
                vad_status = self._last_vad_status

            # Update state based on VAD
            if vad_status == "Speech" and self._last_vad_status != "Speech":
                await self._on_speech_start()

            self._last_vad_status = vad_status

            # 2. Check for barge-in (user speaking while bot speaking)
            # Filter false interruptions by requiring minimum speech duration
            if self._is_bot_speaking and self._barge_in_enabled():
                if vad_status == "Speech":
                    # User is speaking while bot is speaking
                    self._barge_in_silence_frames = 0  # Reset silence counter

                    if self._barge_in_speech_start_time is None:
                        # Start tracking speech duration
                        self._barge_in_speech_start_time = time.time()
                        self._barge_in_speech_frames = 1
                        logger.debug("Potential barge-in detected, tracking duration...")
                    else:
                        self._barge_in_speech_frames += 1
                        # Check if speech duration exceeds threshold
                        speech_duration_ms = (time.time() - self._barge_in_speech_start_time) * 1000
                        if speech_duration_ms >= self._resolved_barge_in_min_duration_ms():
                            logger.info(f"Barge-in confirmed after {speech_duration_ms:.0f}ms of speech ({self._barge_in_speech_frames} frames)")
                            await self._handle_barge_in()
                else:
                    # Silence frame during potential barge-in
                    if self._barge_in_speech_start_time is not None:
                        self._barge_in_silence_frames += 1
                        # Allow brief silence gaps (VAD flickering)
                        if self._barge_in_silence_frames > self._barge_in_silence_tolerance_frames():
                            # Too much silence - reset barge-in tracking
                            logger.debug(f"Barge-in cancelled after {self._barge_in_silence_frames} silence frames")
                            self._barge_in_speech_start_time = None
                            self._barge_in_speech_frames = 0
                            self._barge_in_silence_frames = 0
            elif self._is_bot_speaking and not self._barge_in_enabled():
                # Barge-in disabled at runtime: keep tracking state cleared.
                self._barge_in_speech_start_time = None
                self._barge_in_speech_frames = 0
                self._barge_in_silence_frames = 0

            # 3. Buffer audio for ASR.
            # Gate ASR startup by a short speech-duration threshold to reduce
            # false positives from micro noises, then always close the turn
            # by EOU once ASR has started.
            just_started_asr = False
            if vad_status == "Speech" and not self._asr_capture_active:
                self._pending_speech_audio += pcm_bytes
                pending_ms = (len(self._pending_speech_audio) / (settings.sample_rate * 2)) * 1000.0
                if pending_ms >= self._asr_start_min_speech_ms:
                    await self._start_asr_capture()
                    just_started_asr = True

            if self._asr_capture_active:
                # The chunk that started capture was already sent inside
                # _start_asr_capture; don't send it twice.
                if not just_started_asr:
                    self._audio_buffer += pcm_bytes
                    if len(self._audio_buffer) > self._max_audio_buffer_bytes:
                        # Keep only the most recent audio to cap memory usage
                        self._audio_buffer = self._audio_buffer[-self._max_audio_buffer_bytes:]
                    await self.asr_service.send_audio(pcm_bytes)

            # For SiliconFlow ASR, trigger interim transcription periodically
            # The service handles timing internally via start_interim_transcription()

            # 4. Check for End of Utterance - this triggers LLM response
            if self.eou_detector.process(vad_status, force_eligible=self._asr_capture_active):
                await self._on_end_of_utterance()
            elif (
                self._asr_capture_active
                and self._asr_capture_started_ms > 0.0
                and (time.monotonic() * 1000.0 - self._asr_capture_started_ms) >= self._ASR_CAPTURE_MAX_MS
            ):
                # Watchdog: never let an ASR capture run unbounded.
                logger.warning(
                    f"[EOU] Force finalize after ASR capture timeout: {self._ASR_CAPTURE_MAX_MS}ms"
                )
                await self._on_end_of_utterance()
            elif (
                vad_status == "Silence"
                and not self.eou_detector.is_speaking
                and not self._asr_capture_active
                and self.conversation.state == ConversationState.LISTENING
            ):
                # Speech was too short to pass ASR gate; reset turn so next
                # utterance can start cleanly.
                self._pending_speech_audio = b""
                self._audio_buffer = b""
                self._last_sent_transcript = ""
                await self.conversation.set_state(ConversationState.IDLE)

    except Exception as e:
        logger.error(f"Pipeline audio processing error: {e}", exc_info=True)
async def process_text(self, text: str) -> None:
    """
    Process text input (chat command), bypassing ASR entirely.

    Stops any in-flight assistant speech, finalizes the previous
    utterance, records the user turn, and launches a new assistant turn.

    Args:
        text: User text input
    """
    if not self._running:
        return

    logger.info(f"Processing text input: {text[:50]}...")

    # Interrupt whatever the bot is currently saying.
    await self._stop_current_speech()

    self._start_turn()
    self._finalize_utterance()

    # Record the user turn, then answer it asynchronously.
    await self.conversation.end_user_turn(text)
    self._current_turn_task = asyncio.create_task(self._handle_turn(text))
async def interrupt(self) -> None:
    """Manual interrupt command: treat it exactly like a detected barge-in."""
    await self._handle_barge_in()
async def _on_transcript_callback(self, text: str, is_final: bool) -> None:
    """
    Callback for ASR transcription results.

    Streams transcription to client for display. Final results are sent
    immediately and clear any pending interim; interim results are
    throttled to one delta event per self._ASR_DELTA_THROTTLE_MS.

    Args:
        text: Transcribed text
        is_final: Whether this is the final transcription
    """
    # Avoid sending duplicate transcripts
    if text == self._last_sent_transcript and not is_final:
        return

    now_ms = time.monotonic() * 1000.0
    self._last_sent_transcript = text

    if is_final:
        # A final result supersedes any pending interim delta.
        self._pending_transcript_delta = ""
        self._last_transcript_delta_emit_ms = 0.0
        await self._send_event(
            {
                **ev(
                    "transcript.final",
                    trackId=self.track_audio_in,
                    text=text,
                )
            },
            priority=30,
        )
        logger.debug(f"Sent transcript (final): {text[:50]}...")
        return

    # Interim path: buffer the latest text and emit only when the
    # throttle window has elapsed (or nothing was emitted yet).
    self._pending_transcript_delta = text
    should_emit = (
        self._last_transcript_delta_emit_ms <= 0.0
        or now_ms - self._last_transcript_delta_emit_ms >= self._ASR_DELTA_THROTTLE_MS
    )
    if should_emit and self._pending_transcript_delta:
        delta = self._pending_transcript_delta
        self._pending_transcript_delta = ""
        self._last_transcript_delta_emit_ms = now_ms
        await self._emit_transcript_delta(delta)

    if not is_final:
        logger.info(f"[ASR] ASR interim: {text[:100]}")
        logger.debug(f"Sent transcript (interim): {text[:50]}...")
async def _on_speech_start(self) -> None:
    """Handle user starting to speak.

    Only when a fresh user turn begins (conversation was IDLE or
    INTERRUPTED) does this reset the per-utterance capture state; a
    Silence->Speech flicker inside an ongoing turn must not clear the
    active ASR capture.
    """
    if self.conversation.state in (ConversationState.IDLE, ConversationState.INTERRUPTED):
        self._start_turn()
        self._finalize_utterance()
        await self.conversation.start_user_turn()
        self._audio_buffer = b""
        self._last_sent_transcript = ""
        self.eou_detector.reset()
        self._asr_capture_active = False
        self._asr_capture_started_ms = 0.0
        self._pending_speech_audio = b""

        # Clear ASR buffer. Interim starts only after ASR capture is activated.
        if hasattr(self.asr_service, 'clear_buffer'):
            self.asr_service.clear_buffer()

    logger.debug("User speech started")
async def _start_asr_capture(self) -> None:
    """Start ASR capture for the current turn after min speech gate passes.

    Primes the recognizer with pre-speech context plus the gated pending
    speech, then marks capture active and records the start timestamp for
    the capture-timeout watchdog.
    """
    if self._asr_capture_active:
        return

    if hasattr(self.asr_service, 'start_interim_transcription'):
        await self.asr_service.start_interim_transcription()

    # Prime ASR with a short pre-speech context window so the utterance
    # start isn't lost while waiting for VAD to transition to Speech.
    pre_roll = self._pre_speech_buffer
    # _pre_speech_buffer already includes current speech frames; avoid
    # duplicating onset audio when we append pending speech below.
    if self._pending_speech_audio and len(pre_roll) > len(self._pending_speech_audio):
        pre_roll = pre_roll[:-len(self._pending_speech_audio)]
    elif self._pending_speech_audio:
        pre_roll = b""
    capture_audio = pre_roll + self._pending_speech_audio
    if capture_audio:
        await self.asr_service.send_audio(capture_audio)
        # Keep only the most recent audio to cap memory usage.
        self._audio_buffer = capture_audio[-self._max_audio_buffer_bytes:]

    self._asr_capture_active = True
    self._asr_capture_started_ms = time.monotonic() * 1000.0
    logger.debug(
        f"ASR capture started after speech gate ({self._asr_start_min_speech_ms}ms), "
        f"capture={len(capture_audio)} bytes"
    )
async def _on_end_of_utterance(self) -> None:
    """Handle end of user utterance.

    Finalizes ASR for the captured audio, emits the final transcript when
    it wasn't already sent via callback, clears per-utterance state, and
    launches the assistant turn for the recognized text. No-ops (with a
    state reset) when the conversation has moved past listening states or
    ASR produced no meaningful text.
    """
    if self.conversation.state not in (ConversationState.LISTENING, ConversationState.INTERRUPTED):
        # Prevent a stale ASR capture watchdog from repeatedly forcing EOU
        # once the conversation has already moved past user-listening states.
        self._asr_capture_active = False
        self._asr_capture_started_ms = 0.0
        self._pending_speech_audio = b""
        return

    # Add a tiny trailing silence tail to stabilize final-token decoding.
    if self._asr_final_tail_bytes > 0:
        final_tail = b"\x00" * self._asr_final_tail_bytes
        await self.asr_service.send_audio(final_tail)

    # Stop interim transcriptions
    if hasattr(self.asr_service, 'stop_interim_transcription'):
        await self.asr_service.stop_interim_transcription()

    # Get final transcription from ASR service
    user_text = ""

    if hasattr(self.asr_service, 'get_final_transcription'):
        # SiliconFlow ASR - get final transcription
        user_text = await self.asr_service.get_final_transcription()
    elif hasattr(self.asr_service, 'get_and_clear_text'):
        # Buffered ASR - get accumulated text
        user_text = self.asr_service.get_and_clear_text()

    # Skip if no meaningful text
    if not user_text or not user_text.strip():
        logger.debug("[EOU] Detected but no transcription - skipping")
        # Reset for next utterance
        self._audio_buffer = b""
        self._last_sent_transcript = ""
        self._asr_capture_active = False
        self._asr_capture_started_ms = 0.0
        self._pending_speech_audio = b""
        # Return to idle; don't force LISTENING which causes buffering on silence
        await self.conversation.set_state(ConversationState.IDLE)
        return

    logger.info(f"[EOU] Detected - user said: {user_text[:100]}...")
    self._finalize_utterance()

    # For ASR backends that already emitted final via callback,
    # avoid duplicating transcript.final on EOU.
    if user_text != self._last_sent_transcript:
        await self._send_event({
            **ev(
                "transcript.final",
                trackId=self.track_audio_in,
                text=user_text,
            )
        }, priority=25)

    # Clear buffers
    self._audio_buffer = b""
    self._last_sent_transcript = ""
    self._pending_transcript_delta = ""
    self._last_transcript_delta_emit_ms = 0.0
    self._asr_capture_active = False
    self._asr_capture_started_ms = 0.0
    self._pending_speech_audio = b""

    # Process the turn - trigger LLM response
    # Cancel any existing turn to avoid overlapping assistant responses
    await self._stop_current_speech()
    await self.conversation.end_user_turn(user_text)
    self._current_turn_task = asyncio.create_task(self._handle_turn(user_text))
def _resolved_knowledge_config(self) -> Dict[str, Any]:
|
||
cfg: Dict[str, Any] = {}
|
||
if isinstance(self._runtime_knowledge, dict):
|
||
cfg.update(self._runtime_knowledge)
|
||
kb_id = self._runtime_knowledge_base_id or str(
|
||
cfg.get("kbId") or cfg.get("knowledgeBaseId") or ""
|
||
).strip()
|
||
if kb_id:
|
||
cfg["kbId"] = kb_id
|
||
cfg.setdefault("enabled", True)
|
||
return cfg
|
||
|
||
def _resolved_tool_schemas(self) -> List[Dict[str, Any]]:
|
||
schemas: List[Dict[str, Any]] = []
|
||
seen: set[str] = set()
|
||
for item in self._runtime_tools:
|
||
if isinstance(item, str):
|
||
tool_name = self._normalize_tool_name(item)
|
||
if not tool_name or tool_name in seen:
|
||
continue
|
||
seen.add(tool_name)
|
||
base = self._DEFAULT_TOOL_SCHEMAS.get(tool_name)
|
||
if base:
|
||
schemas.append(
|
||
{
|
||
"type": "function",
|
||
"function": {
|
||
"name": base["name"],
|
||
"description": base.get("description") or "",
|
||
"parameters": base.get("parameters") or {"type": "object", "properties": {}},
|
||
},
|
||
}
|
||
)
|
||
else:
|
||
schemas.append(
|
||
{
|
||
"type": "function",
|
||
"function": {
|
||
"name": tool_name,
|
||
"description": f"Execute tool '{tool_name}'",
|
||
"parameters": {"type": "object", "properties": {}},
|
||
},
|
||
}
|
||
)
|
||
continue
|
||
|
||
if not isinstance(item, dict):
|
||
continue
|
||
|
||
fn = item.get("function")
|
||
if isinstance(fn, dict) and fn.get("name"):
|
||
fn_name = self._normalize_tool_name(fn.get("name"))
|
||
if not fn_name or fn_name in seen:
|
||
continue
|
||
seen.add(fn_name)
|
||
schemas.append(
|
||
{
|
||
"type": "function",
|
||
"function": {
|
||
"name": fn_name,
|
||
"description": str(fn.get("description") or item.get("description") or ""),
|
||
"parameters": fn.get("parameters") or {"type": "object", "properties": {}},
|
||
},
|
||
}
|
||
)
|
||
continue
|
||
|
||
if item.get("name"):
|
||
item_name = self._normalize_tool_name(item.get("name"))
|
||
if not item_name or item_name in seen:
|
||
continue
|
||
seen.add(item_name)
|
||
schemas.append(
|
||
{
|
||
"type": "function",
|
||
"function": {
|
||
"name": item_name,
|
||
"description": str(item.get("description") or ""),
|
||
"parameters": item.get("parameters") or {"type": "object", "properties": {}},
|
||
},
|
||
}
|
||
)
|
||
return schemas
|
||
|
||
def _resolved_tool_executor_map(self) -> Dict[str, str]:
|
||
result: Dict[str, str] = {}
|
||
for item in self._runtime_tools:
|
||
if isinstance(item, str):
|
||
name = self._normalize_tool_name(item)
|
||
if name in self._DEFAULT_CLIENT_EXECUTORS:
|
||
result[name] = "client"
|
||
continue
|
||
if not isinstance(item, dict):
|
||
continue
|
||
fn = item.get("function")
|
||
if isinstance(fn, dict) and fn.get("name"):
|
||
name = self._normalize_tool_name(fn.get("name"))
|
||
else:
|
||
name = self._normalize_tool_name(item.get("name"))
|
||
if not name:
|
||
continue
|
||
executor = str(item.get("executor") or item.get("run_on") or "").strip().lower()
|
||
if executor in {"client", "server"}:
|
||
result[name] = executor
|
||
return result
|
||
|
||
def _resolved_tool_default_args_map(self) -> Dict[str, Dict[str, Any]]:
|
||
result: Dict[str, Dict[str, Any]] = {}
|
||
for item in self._runtime_tools:
|
||
if not isinstance(item, dict):
|
||
continue
|
||
fn = item.get("function")
|
||
if isinstance(fn, dict) and fn.get("name"):
|
||
name = self._normalize_tool_name(fn.get("name"))
|
||
else:
|
||
name = self._normalize_tool_name(item.get("name"))
|
||
if not name:
|
||
continue
|
||
raw_defaults = item.get("defaultArgs")
|
||
if raw_defaults is None:
|
||
raw_defaults = item.get("default_args")
|
||
if isinstance(raw_defaults, dict):
|
||
result[name] = dict(raw_defaults)
|
||
return result
|
||
|
||
def _resolved_tool_wait_for_response_map(self) -> Dict[str, bool]:
    """Map normalized tool names to their explicit wait-for-response flags."""
    wait_flags: Dict[str, bool] = {}
    for spec in self._runtime_tools:
        if not isinstance(spec, dict):
            continue
        fn_spec = spec.get("function")
        if isinstance(fn_spec, dict) and fn_spec.get("name"):
            tool = self._normalize_tool_name(fn_spec.get("name"))
        else:
            tool = self._normalize_tool_name(spec.get("name"))
        if not tool:
            continue
        # Accept both camelCase and snake_case spellings; only real booleans count.
        flag = spec.get("waitForResponse")
        if flag is None:
            flag = spec.get("wait_for_response")
        if isinstance(flag, bool):
            wait_flags[tool] = flag
    return wait_flags
def _resolved_tool_id_map(self) -> Dict[str, str]:
    """Map tool aliases to backend tool ids (the alias itself when none given)."""
    id_by_alias: Dict[str, str] = {}
    for spec in self._runtime_tools:
        if not isinstance(spec, dict):
            continue
        fn_spec = spec.get("function")
        if isinstance(fn_spec, dict) and fn_spec.get("name"):
            alias = self._normalize_tool_name(fn_spec.get("name"))
        else:
            alias = self._normalize_tool_name(spec.get("name"))
        if not alias:
            continue
        backend_id = self._normalize_tool_name(spec.get("toolId") or spec.get("tool_id") or alias)
        if backend_id:
            id_by_alias[alias] = backend_id
    return id_by_alias
def _resolved_tool_display_name_map(self) -> Dict[str, str]:
    """Map tool names (and their backend ids) to human-readable display labels."""
    labels: Dict[str, str] = {}
    for spec in self._runtime_tools:
        if not isinstance(spec, dict):
            continue
        fn_spec = spec.get("function")
        if isinstance(fn_spec, dict) and fn_spec.get("name"):
            tool = self._normalize_tool_name(fn_spec.get("name"))
        else:
            tool = self._normalize_tool_name(spec.get("name"))
        if not tool:
            continue
        label = str(
            spec.get("displayName")
            or spec.get("display_name")
            or tool
        ).strip()
        if label:
            labels[tool] = label
        # Register the same label under the backend tool id as well, if any.
        backend_id = self._normalize_tool_name(spec.get("toolId") or spec.get("tool_id") or "")
        if backend_id:
            labels[backend_id] = label
    return labels
def _resolved_tool_allowlist(self) -> List[str]:
    """Return the sorted set of tool names this session is allowed to use."""
    allowed: set[str] = set()
    for spec in self._runtime_tools:
        if isinstance(spec, str):
            candidate = self._normalize_tool_name(spec)
            if candidate:
                allowed.add(candidate)
            continue
        if not isinstance(spec, dict):
            continue
        fn_spec = spec.get("function")
        if isinstance(fn_spec, dict) and fn_spec.get("name"):
            allowed.add(self._normalize_tool_name(fn_spec.get("name")))
        elif spec.get("name"):
            allowed.add(self._normalize_tool_name(spec.get("name")))
    # Normalization may yield empty strings; drop them before sorting.
    return sorted(candidate for candidate in allowed if candidate)
def _resolved_manual_opener_tool_calls(self) -> List[Dict[str, Any]]:
    """Normalize the configured opener tool calls; keep at most eight."""
    normalized_calls: List[Dict[str, Any]] = []
    for raw_call in self._runtime_manual_opener_tool_calls:
        if not isinstance(raw_call, dict):
            continue
        raw_name = str(
            raw_call.get("toolName")
            or raw_call.get("tool_name")
            or raw_call.get("name")
            or ""
        ).strip()
        tool_name = self._normalize_tool_name(raw_name)
        if not tool_name:
            continue
        raw_args = raw_call.get("arguments")
        call_args: Dict[str, Any] = {}
        if isinstance(raw_args, dict):
            call_args = dict(raw_args)
        elif isinstance(raw_args, str):
            candidate = raw_args.strip()
            if candidate:
                try:
                    decoded = json.loads(candidate)
                except Exception:
                    logger.warning(f"[OpenerTool] ignore invalid JSON args for tool={tool_name}")
                else:
                    # Only dict-shaped payloads are usable as tool arguments.
                    if isinstance(decoded, dict):
                        call_args = decoded
        normalized_calls.append({"toolName": tool_name, "arguments": call_args})
    return normalized_calls[:8]
def _tool_name(self, tool_call: Dict[str, Any]) -> str:
    """Extract the normalized function name from an OpenAI-style tool call."""
    fn_spec = tool_call.get("function")
    if not isinstance(fn_spec, dict):
        return ""
    return self._normalize_tool_name(fn_spec.get("name"))
def _tool_id_for_name(self, tool_name: str) -> str:
    """Resolve a tool alias to its backend id, falling back to the alias."""
    alias = self._normalize_tool_name(tool_name)
    mapped = self._runtime_tool_id_map.get(alias) or alias
    return self._normalize_tool_name(mapped)
def _tool_display_name(self, tool_name: str) -> str:
    """Return the configured display label for a tool, or the name itself."""
    key = self._normalize_tool_name(tool_name)
    label = self._runtime_tool_display_names.get(key) or key
    return str(label).strip()
def _tool_wait_for_response(self, tool_name: str) -> bool:
    """Whether the pipeline should block until the client reports a result."""
    key = self._normalize_tool_name(tool_name)
    return bool(self._runtime_tool_wait_for_response.get(key, False))
def _tool_executor(self, tool_call: Dict[str, Any]) -> str:
    """Return where this tool call should run: "client" or "server"."""
    tool = self._tool_name(tool_call)
    if not tool:
        # Unnamed calls default to server-side execution.
        return "server"
    # Default to server execution unless explicitly marked as client.
    return self._runtime_tool_executor.get(tool, "server")
def _tool_arguments(self, tool_call: Dict[str, Any]) -> Dict[str, Any]:
|
||
fn = tool_call.get("function")
|
||
if not isinstance(fn, dict):
|
||
return {}
|
||
raw = fn.get("arguments")
|
||
if isinstance(raw, dict):
|
||
return raw
|
||
if isinstance(raw, str) and raw.strip():
|
||
try:
|
||
parsed = json.loads(raw)
|
||
return parsed if isinstance(parsed, dict) else {"raw": raw}
|
||
except Exception:
|
||
return {"raw": raw}
|
||
return {}
|
||
|
||
def _apply_tool_default_args(self, tool_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
    """Overlay caller-provided args on top of the tool's configured defaults."""
    key = self._normalize_tool_name(tool_name)
    defaults = self._runtime_tool_default_args.get(key)
    if not isinstance(defaults, dict) or not defaults:
        # Nothing configured: hand back the original args untouched.
        return args
    combined = dict(defaults)
    if isinstance(args, dict):
        # Explicit arguments win over defaults.
        combined.update(args)
    return combined
def _build_manual_opener_follow_up_context(self, payload: Dict[str, List[Dict[str, Any]]]) -> str:
|
||
tool_calls = payload.get("toolCalls") if isinstance(payload.get("toolCalls"), list) else []
|
||
tool_results = payload.get("toolResults") if isinstance(payload.get("toolResults"), list) else []
|
||
return (
|
||
"Initial opener tool calls already executed. Continue with a natural assistant follow-up. "
|
||
"If tool results include user selections or values, use them in your response. "
|
||
"Never expose internal tool ids or raw payloads.\n"
|
||
f"opener_tool_calls={json.dumps(tool_calls, ensure_ascii=False)}\n"
|
||
f"opener_tool_results={json.dumps(tool_results, ensure_ascii=False)}"
|
||
)
|
||
|
||
async def _execute_manual_opener_tool_calls(self) -> Dict[str, List[Dict[str, Any]]]:
    """Execute the pre-configured opener tool calls before the first turn.

    For each resolved opener call this emits an ``assistant.tool_call``
    event, then either waits for the client result (client executor with
    wait_for_response) or runs the server tool executor with a timeout.

    Returns:
        Dict with "toolCalls" and "toolResults" lists suitable for
        _build_manual_opener_follow_up_context.
    """
    calls = self._resolved_manual_opener_tool_calls()
    tool_calls_for_context: List[Dict[str, Any]] = []
    tool_results_for_context: List[Dict[str, Any]] = []
    if not calls:
        return {"toolCalls": tool_calls_for_context, "toolResults": tool_results_for_context}

    for call in calls:
        tool_name = str(call.get("toolName") or "").strip()
        if not tool_name:
            continue
        tool_id = self._tool_id_for_name(tool_name)
        tool_display_name = self._tool_display_name(tool_name) or tool_name
        tool_arguments = call.get("arguments") if isinstance(call.get("arguments"), dict) else {}
        merged_tool_arguments = self._apply_tool_default_args(tool_name, tool_arguments)
        # Synthetic call id so opener calls are distinguishable in logs/events.
        call_id = f"call_opener_{uuid.uuid4().hex[:10]}"
        wait_for_response = self._tool_wait_for_response(tool_name)
        tool_call = {
            "id": call_id,
            "type": "function",
            "function": {
                "name": tool_name,
                "arguments": json.dumps(merged_tool_arguments, ensure_ascii=False),
            },
        }
        executor = self._tool_executor(tool_call)
        tool_calls_for_context.append(
            {
                "tool_call_id": call_id,
                "tool_name": tool_name,
                "tool_id": tool_id,
                "tool_display_name": tool_display_name,
                "arguments": merged_tool_arguments,
                "wait_for_response": wait_for_response,
                "executor": executor,
            }
        )

        # Announce the call to the client regardless of executor.
        await self._send_event(
            {
                **ev(
                    "assistant.tool_call",
                    trackId=self.track_audio_out,
                    tool_call_id=call_id,
                    tool_name=tool_name,
                    tool_id=tool_id,
                    tool_display_name=tool_display_name,
                    wait_for_response=wait_for_response,
                    arguments=merged_tool_arguments,
                    executor=executor,
                    timeout_ms=int(self._TOOL_WAIT_TIMEOUT_SECONDS * 1000),
                    tool_call={**tool_call, "executor": executor, "wait_for_response": wait_for_response},
                )
            },
            priority=22,
        )
        logger.info(
            f"[OpenerTool] execute name={tool_name} call_id={call_id} executor={executor} "
            f"wait_for_response={wait_for_response}"
        )

        if executor == "client":
            self._pending_client_tool_call_ids.add(call_id)
            # Fire-and-forget unless the tool explicitly requires a result.
            if wait_for_response:
                result = await self._wait_for_single_tool_result(call_id)
                await self._emit_tool_result(result, source="client")
                tool_results_for_context.append(result if isinstance(result, dict) else {"tool_call_id": call_id})
            continue

        # Server path: rewrite the function name to the backend tool id
        # before handing the call to the server tool executor.
        call_for_executor = dict(tool_call)
        fn_for_executor = (
            dict(call_for_executor.get("function"))
            if isinstance(call_for_executor.get("function"), dict)
            else None
        )
        if isinstance(fn_for_executor, dict):
            fn_for_executor["name"] = tool_id
            call_for_executor["function"] = fn_for_executor
        try:
            result = await asyncio.wait_for(
                self._server_tool_executor(call_for_executor),
                timeout=self._SERVER_TOOL_TIMEOUT_SECONDS,
            )
        except asyncio.TimeoutError:
            # Surface a synthetic 504 result rather than aborting the opener.
            result = {
                "tool_call_id": call_id,
                "name": tool_name,
                "output": {"message": "server tool timeout"},
                "status": {"code": 504, "message": "server_tool_timeout"},
            }
        await self._emit_tool_result(result, source="server")
        tool_results_for_context.append(result if isinstance(result, dict) else {"tool_call_id": call_id})

    return {"toolCalls": tool_calls_for_context, "toolResults": tool_results_for_context}
def _normalize_tool_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
    """Shape a raw tool result into the canonical event payload.

    Produces tool_call_id/tool_name/tool_display_name plus an ``ok`` flag,
    an optional ``error`` object, and a normalized ``status``.
    """
    raw_status = result.get("status")
    status = raw_status if isinstance(raw_status, dict) else {}
    if status:
        status_code = int(status.get("code") or 0)
        status_message = str(status.get("message") or "")
    else:
        status_code = 0
        status_message = ""
    call_id = str(result.get("tool_call_id") or result.get("id") or "")
    name = str(result.get("name") or "unknown_tool")
    display_name = self._tool_display_name(name) or name
    succeeded = bool(200 <= status_code < 300)
    # 5xx plus timeout/rate-limit codes are worth retrying.
    can_retry = status_code >= 500 or status_code in {429, 408}
    failure: Optional[Dict[str, Any]] = None
    if not succeeded:
        failure = {
            "code": status_code or 500,
            "message": status_message or "tool_execution_failed",
            "retryable": can_retry,
        }
    return {
        "tool_call_id": call_id,
        "tool_name": name,
        "tool_display_name": display_name,
        "ok": succeeded,
        "error": failure,
        "status": {"code": status_code, "message": status_message},
    }
async def _emit_tool_result(self, result: Dict[str, Any], source: str) -> None:
    """Log a tool result and forward it to the client as ``assistant.tool_result``.

    Args:
        result: Raw result dict from the client or the server executor.
        source: Origin of the result, "client" or "server".
    """
    tool_name = str(result.get("name") or "unknown_tool")
    tool_display_name = self._tool_display_name(tool_name) or tool_name
    call_id = str(result.get("tool_call_id") or result.get("id") or "")
    status = result.get("status") if isinstance(result.get("status"), dict) else {}
    status_code = int(status.get("code") or 0) if status else 0
    status_message = str(status.get("message") or "") if status else ""
    logger.info(
        f"[Tool] emit result source={source} name={tool_name} call_id={call_id} "
        f"status={status_code} {status_message}".strip()
    )
    # Normalized fields carry the canonical shape; the raw result is attached too.
    normalized = self._normalize_tool_result(result)
    await self._send_event(
        {
            **ev(
                "assistant.tool_result",
                trackId=self.track_audio_out,
                source=source,
                tool_call_id=normalized["tool_call_id"],
                tool_name=normalized["tool_name"],
                tool_display_name=normalized["tool_display_name"],
                ok=normalized["ok"],
                error=normalized["error"],
                result=result,
            )
        },
        priority=22,
    )
async def handle_tool_call_results(self, results: List[Dict[str, Any]]) -> None:
    """Handle client tool execution results.

    Routes each result to the coroutine awaiting it (see
    _wait_for_single_tool_result). Results that arrive before a waiter is
    registered are parked in _early_tool_results for later pickup.
    Unsolicited and duplicate results are dropped.

    Args:
        results: Result dicts from the client; each is expected to carry
            ``tool_call_id`` (or ``id``) and optionally ``status``/``name``.
    """
    if not isinstance(results, list):
        return

    for item in results:
        if not isinstance(item, dict):
            continue
        call_id = str(item.get("tool_call_id") or item.get("id") or "").strip()
        if not call_id:
            continue
        # Only accept results for calls we actually dispatched to the client.
        if self._pending_client_tool_call_ids and call_id not in self._pending_client_tool_call_ids:
            logger.warning(f"[Tool] ignore unsolicited client result call_id={call_id}")
            continue
        if call_id in self._completed_tool_call_ids:
            logger.debug(f"[Tool] ignore duplicate client result call_id={call_id}")
            continue
        status = item.get("status") if isinstance(item.get("status"), dict) else {}
        status_code = int(status.get("code") or 0) if status else 0
        status_message = str(status.get("message") or "") if status else ""
        tool_name = str(item.get("name") or "unknown_tool")
        logger.info(
            f"[Tool] received client result name={tool_name} call_id={call_id} "
            f"status={status_code} {status_message}".strip()
        )

        waiter = self._pending_tool_waiters.get(call_id)
        if waiter and not waiter.done():
            # Deliver directly to the coroutine blocked on this call id.
            waiter.set_result(item)
            self._completed_tool_call_ids.add(call_id)
            continue
        # No waiter registered yet: stash the result so a future waiter
        # can consume it immediately.
        self._early_tool_results[call_id] = item
        self._completed_tool_call_ids.add(call_id)
async def _wait_for_single_tool_result(self, call_id: str) -> Dict[str, Any]:
    """Wait for the client to report the result of one tool call.

    Consumes an early-arrived result immediately if one was parked by
    handle_tool_call_results; otherwise registers a future and waits up to
    _TOOL_WAIT_TIMEOUT_SECONDS. An already-consumed call id yields a 208
    placeholder; a timeout yields a 504 placeholder.
    """
    # Already handled and nothing parked: report "already handled" (208 style).
    if call_id in self._completed_tool_call_ids and call_id not in self._early_tool_results:
        return {
            "tool_call_id": call_id,
            "status": {"code": 208, "message": "tool_call result already handled"},
            "output": "",
        }
    if call_id in self._early_tool_results:
        # Result arrived before we started waiting; consume it now.
        self._completed_tool_call_ids.add(call_id)
        return self._early_tool_results.pop(call_id)

    loop = asyncio.get_running_loop()
    future = loop.create_future()
    self._pending_tool_waiters[call_id] = future
    try:
        return await asyncio.wait_for(future, timeout=self._TOOL_WAIT_TIMEOUT_SECONDS)
    except asyncio.TimeoutError:
        self._completed_tool_call_ids.add(call_id)
        return {
            "tool_call_id": call_id,
            "status": {"code": 504, "message": "tool_call timeout"},
            "output": "",
        }
    finally:
        # Always drop bookkeeping so a late result cannot resolve a dead future.
        self._pending_tool_waiters.pop(call_id, None)
        self._pending_client_tool_call_ids.discard(call_id)
def _normalize_stream_event(self, item: Any) -> LLMStreamEvent:
    """Coerce heterogeneous LLM stream items into LLMStreamEvent instances."""
    if isinstance(item, LLMStreamEvent):
        return item
    if isinstance(item, str):
        # Bare strings are treated as incremental assistant text.
        return LLMStreamEvent(type="text_delta", text=item)
    if isinstance(item, dict):
        kind = str(item.get("type") or "")
        if kind in {"text_delta", "tool_call", "done"}:
            return LLMStreamEvent(
                type=kind,  # type: ignore[arg-type]
                text=item.get("text"),
                tool_call=item.get("tool_call"),
            )
    # Anything unrecognized terminates the stream gracefully.
    return LLMStreamEvent(type="done")
async def _handle_turn(self, user_text: str, system_context: Optional[str] = None) -> None:
    """
    Handle a complete conversation turn.

    Streams LLM output, splitting it into sentences and synthesizing each
    sentence as soon as it is complete for lower perceived latency. Tool
    calls emitted by the LLM suspend text output for the round; their
    results are fed back as a system message for up to ``max_rounds``
    LLM rounds.

    Args:
        user_text: User's transcribed text.
            NOTE(review): not referenced in this body — presumably the
            conversation manager already holds it; confirm before removing.
        system_context: Optional extra system message appended for this
            turn only (e.g. opener tool-call context).
    """
    try:
        if not self._current_turn_id:
            self._start_turn()
        if not self._current_utterance_id:
            self._finalize_utterance()
        turn_id = self._current_turn_id
        utterance_id = self._current_utterance_id
        response_id = self._start_response()
        # Start latency tracking
        self._turn_start_time = time.time()
        self._first_audio_sent = False

        full_response = ""
        messages = self.conversation.get_messages()
        if system_context and system_context.strip():
            messages = [*messages, LLMMessage(role="system", content=system_context.strip())]
        # Upper bound on LLM rounds (initial response + tool follow-ups).
        max_rounds = 3

        await self.conversation.start_assistant_turn()
        self._is_bot_speaking = True
        self._interrupt_event.clear()
        self._drop_outbound_audio = False

        first_audio_sent = False
        self._pending_llm_delta = ""
        self._last_llm_delta_emit_ms = 0.0
        for _ in range(max_rounds):
            if self._interrupt_event.is_set():
                break

            sentence_buffer = ""
            pending_punctuation = ""
            round_response = ""
            tool_calls: List[Dict[str, Any]] = []
            allow_text_output = True
            use_engine_sentence_split = self._use_engine_sentence_split_for_tts()

            async for raw_event in self.llm_service.generate_stream(messages):
                if self._interrupt_event.is_set():
                    break

                event = self._normalize_stream_event(raw_event)
                if event.type == "tool_call":
                    # Flush buffered text deltas before announcing the call.
                    await self._flush_pending_llm_delta(
                        turn_id=turn_id,
                        utterance_id=utterance_id,
                        response_id=response_id,
                    )
                    tool_call = event.tool_call if isinstance(event.tool_call, dict) else None
                    if not tool_call:
                        continue
                    # Once a tool call appears, further text in this round
                    # is suppressed; the follow-up round speaks instead.
                    allow_text_output = False
                    executor = self._tool_executor(tool_call)
                    enriched_tool_call = dict(tool_call)
                    enriched_tool_call["executor"] = executor
                    tool_name = self._tool_name(enriched_tool_call) or "unknown_tool"
                    tool_id = self._tool_id_for_name(tool_name)
                    tool_display_name = self._tool_display_name(tool_name) or tool_name
                    wait_for_response = self._tool_wait_for_response(tool_name)
                    enriched_tool_call["wait_for_response"] = wait_for_response
                    call_id = str(enriched_tool_call.get("id") or "").strip()
                    fn_payload = (
                        dict(enriched_tool_call.get("function"))
                        if isinstance(enriched_tool_call.get("function"), dict)
                        else None
                    )
                    raw_args = str(fn_payload.get("arguments") or "") if isinstance(fn_payload, dict) else ""
                    tool_arguments = self._tool_arguments(enriched_tool_call)
                    merged_tool_arguments = self._apply_tool_default_args(tool_name, tool_arguments)
                    try:
                        merged_args_text = json.dumps(merged_tool_arguments, ensure_ascii=False)
                    except Exception:
                        merged_args_text = raw_args if raw_args else "{}"
                    if isinstance(fn_payload, dict):
                        fn_payload["arguments"] = merged_args_text
                        enriched_tool_call["function"] = fn_payload
                    args_preview = raw_args if len(raw_args) <= 160 else f"{raw_args[:160]}..."
                    logger.info(
                        f"[Tool] call requested name={tool_name} call_id={call_id} "
                        f"executor={executor} args={args_preview} merged_args={merged_args_text}"
                    )
                    tool_calls.append(enriched_tool_call)
                    if executor == "client" and call_id:
                        self._pending_client_tool_call_ids.add(call_id)
                    await self._send_event(
                        {
                            **ev(
                                "assistant.tool_call",
                                trackId=self.track_audio_out,
                                tool_call_id=call_id,
                                tool_name=tool_name,
                                tool_id=tool_id,
                                tool_display_name=tool_display_name,
                                wait_for_response=wait_for_response,
                                arguments=tool_arguments,
                                executor=executor,
                                timeout_ms=int(self._TOOL_WAIT_TIMEOUT_SECONDS * 1000),
                                tool_call=enriched_tool_call,
                            )
                        },
                        priority=22,
                    )
                    continue

                if event.type != "text_delta":
                    continue

                text_chunk = event.text or ""
                if not text_chunk:
                    continue

                if not allow_text_output:
                    continue

                full_response += text_chunk
                round_response += text_chunk
                sentence_buffer += text_chunk
                await self.conversation.update_assistant_text(text_chunk)
                self._pending_llm_delta += text_chunk
                now_ms = time.monotonic() * 1000.0
                # Throttle delta events to at most one per _LLM_DELTA_THROTTLE_MS.
                if (
                    self._last_llm_delta_emit_ms <= 0.0
                    or now_ms - self._last_llm_delta_emit_ms >= self._LLM_DELTA_THROTTLE_MS
                ):
                    await self._flush_pending_llm_delta(
                        turn_id=turn_id,
                        utterance_id=utterance_id,
                        response_id=response_id,
                    )

                if use_engine_sentence_split:
                    # Drain every complete sentence from the buffer.
                    while True:
                        split_result = extract_tts_sentence(
                            sentence_buffer,
                            end_chars=self._SENTENCE_END_CHARS,
                            trailing_chars=self._SENTENCE_TRAILING_CHARS,
                            closers=self._SENTENCE_CLOSERS,
                            min_split_spoken_chars=self._MIN_SPLIT_SPOKEN_CHARS,
                            hold_trailing_at_buffer_end=True,
                            force=False,
                        )
                        if not split_result:
                            break
                        sentence, sentence_buffer = split_result
                        if not sentence:
                            continue

                        sentence = f"{pending_punctuation}{sentence}".strip()
                        pending_punctuation = ""
                        if not sentence:
                            continue

                        # Punctuation-only fragments are carried into the
                        # next sentence instead of being synthesized.
                        if not has_spoken_content(sentence):
                            pending_punctuation = sentence
                            continue

                        if self._tts_output_enabled() and not self._interrupt_event.is_set():
                            if not first_audio_sent:
                                tts_id = self._start_tts()
                                self._mark_client_playback_started(tts_id)
                                await self._send_event(
                                    {
                                        **ev(
                                            "output.audio.start",
                                            trackId=self.track_audio_out,
                                        ),
                                        "turn_id": turn_id,
                                        "utterance_id": utterance_id,
                                        "response_id": response_id,
                                    },
                                    priority=30,
                                )
                                first_audio_sent = True

                            await self._speak_sentence(
                                sentence,
                                fade_in_ms=0,
                                fade_out_ms=8,
                                turn_id=turn_id,
                                utterance_id=utterance_id,
                                response_id=response_id,
                            )

            # Speak whatever is left in the buffer after the stream ends.
            if use_engine_sentence_split:
                remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
            else:
                remaining_text = sentence_buffer.strip()
            await self._flush_pending_llm_delta(
                turn_id=turn_id,
                utterance_id=utterance_id,
                response_id=response_id,
            )
            if (
                self._tts_output_enabled()
                and remaining_text
                and has_spoken_content(remaining_text)
                and not self._interrupt_event.is_set()
            ):
                if not first_audio_sent:
                    tts_id = self._start_tts()
                    self._mark_client_playback_started(tts_id)
                    await self._send_event(
                        {
                            **ev(
                                "output.audio.start",
                                trackId=self.track_audio_out,
                            ),
                            "turn_id": turn_id,
                            "utterance_id": utterance_id,
                            "response_id": response_id,
                        },
                        priority=30,
                    )
                    first_audio_sent = True
                await self._speak_sentence(
                    remaining_text,
                    fade_in_ms=0,
                    fade_out_ms=8,
                    turn_id=turn_id,
                    utterance_id=utterance_id,
                    response_id=response_id,
                )

            if not tool_calls:
                break

            # Execute the round's tool calls (client waits / server runs).
            tool_results: List[Dict[str, Any]] = []
            for call in tool_calls:
                call_id = str(call.get("id") or "").strip()
                if not call_id:
                    continue
                executor = str(call.get("executor") or "server").strip().lower()
                tool_name = self._tool_name(call) or "unknown_tool"
                tool_id = self._tool_id_for_name(tool_name)
                logger.info(f"[Tool] execute start name={tool_name} call_id={call_id} executor={executor}")
                if executor == "client":
                    result = await self._wait_for_single_tool_result(call_id)
                    await self._emit_tool_result(result, source="client")
                    tool_results.append(result)
                    continue

                # Server path: rewrite the function name to the backend id.
                call_for_executor = dict(call)
                fn_for_executor = (
                    dict(call_for_executor.get("function"))
                    if isinstance(call_for_executor.get("function"), dict)
                    else None
                )
                if isinstance(fn_for_executor, dict):
                    fn_for_executor["name"] = tool_id
                    call_for_executor["function"] = fn_for_executor
                try:
                    result = await asyncio.wait_for(
                        self._server_tool_executor(call_for_executor),
                        timeout=self._SERVER_TOOL_TIMEOUT_SECONDS,
                    )
                except asyncio.TimeoutError:
                    result = {
                        "tool_call_id": call_id,
                        "name": self._tool_name(call) or "unknown_tool",
                        "output": {"message": "server tool timeout"},
                        "status": {"code": 504, "message": "server_tool_timeout"},
                    }
                await self._emit_tool_result(result, source="server")
                tool_results.append(result)

            # Feed results back so the next round can answer with them.
            messages = [
                *messages,
                LLMMessage(
                    role="assistant",
                    content=round_response.strip(),
                ),
                LLMMessage(
                    role="system",
                    content=(
                        "Tool execution results are available. "
                        "Continue answering the user naturally using these results. "
                        "Do not request the same tool again in this turn.\n"
                        f"tool_calls={json.dumps(tool_calls, ensure_ascii=False)}\n"
                        f"tool_results={json.dumps(tool_results, ensure_ascii=False)}"
                    ),
                ),
            ]

        if full_response and not self._interrupt_event.is_set():
            await self._flush_pending_llm_delta(
                turn_id=turn_id,
                utterance_id=utterance_id,
                response_id=response_id,
            )
            await self._send_event(
                {
                    **ev(
                        "assistant.response.final",
                        trackId=self.track_audio_out,
                        text=full_response,
                    ),
                    "turn_id": turn_id,
                    "utterance_id": utterance_id,
                    "response_id": response_id,
                },
                priority=20,
            )

        # Send track end
        if first_audio_sent:
            await self._flush_audio_out_frames(priority=50)
            await self._send_event({
                **ev(
                    "output.audio.end",
                    trackId=self.track_audio_out,
                ),
                "turn_id": turn_id,
                "utterance_id": utterance_id,
                "response_id": response_id,
            }, priority=10)

        # End assistant turn
        await self.conversation.end_assistant_turn(
            was_interrupted=self._interrupt_event.is_set()
        )

    except asyncio.CancelledError:
        logger.info("Turn handling cancelled")
        await self.conversation.end_assistant_turn(was_interrupted=True)
    except Exception as e:
        logger.error(f"Turn handling error: {e}", exc_info=True)
        await self.conversation.end_assistant_turn(was_interrupted=True)
    finally:
        self._is_bot_speaking = False
        # Reset barge-in tracking when bot finishes speaking
        self._barge_in_speech_start_time = None
        self._barge_in_speech_frames = 0
        self._barge_in_silence_frames = 0
        self._current_response_id = None
        self._current_tts_id = None
async def _speak_sentence(
    self,
    text: str,
    fade_in_ms: int = 0,
    fade_out_ms: int = 8,
    turn_id: Optional[str] = None,
    utterance_id: Optional[str] = None,
    response_id: Optional[str] = None,
) -> None:
    """
    Synthesize and send a single sentence.

    Args:
        text: Sentence to speak
        fade_in_ms: Fade-in duration for sentence start chunks
        fade_out_ms: Fade-out duration for sentence end chunks
        turn_id: Correlation id attached to emitted metrics events
        utterance_id: Correlation id of the current user utterance
        response_id: Correlation id of the current assistant response
    """
    if not self._tts_output_enabled():
        return

    if not text.strip() or self._interrupt_event.is_set() or not self.tts_service:
        return

    logger.info(f"[TTS] split sentence: {text!r}")

    try:
        is_first_chunk = True
        async for chunk in self.tts_service.synthesize_stream(text):
            # Check interrupt at the start of each iteration
            if self._interrupt_event.is_set():
                logger.debug("TTS sentence interrupted")
                break

            # Track and log first audio packet latency (TTFB)
            if not self._first_audio_sent and self._turn_start_time:
                ttfb_ms = (time.time() - self._turn_start_time) * 1000
                self._first_audio_sent = True
                logger.info(f"[TTFB] Server first audio packet latency: {ttfb_ms:.0f}ms (session {self.session_id})")

                # Send TTFB event to client
                await self._send_event({
                    **ev(
                        "metrics.ttfb",
                        trackId=self.track_audio_out,
                        latencyMs=round(ttfb_ms),
                    ),
                    "turn_id": turn_id,
                    "utterance_id": utterance_id,
                    "response_id": response_id,
                }, priority=25)

            # Double-check interrupt right before sending audio
            if self._interrupt_event.is_set():
                break

            # Smooth chunk edges to avoid clicks at sentence boundaries.
            smoothed_audio = self._apply_edge_fade(
                pcm_bytes=chunk.audio,
                sample_rate=chunk.sample_rate,
                fade_in=is_first_chunk,
                fade_out=bool(chunk.is_final),
                fade_in_ms=fade_in_ms,
                fade_out_ms=fade_out_ms,
            )
            is_first_chunk = False

            await self._send_audio(smoothed_audio, priority=50)
    except asyncio.CancelledError:
        # Cancellation is expected on barge-in; swallow quietly.
        logger.debug("TTS sentence cancelled")
    except Exception as e:
        logger.error(f"TTS sentence error: {e}")
def _apply_edge_fade(
|
||
self,
|
||
pcm_bytes: bytes,
|
||
sample_rate: int,
|
||
fade_in: bool = False,
|
||
fade_out: bool = False,
|
||
fade_in_ms: int = 0,
|
||
fade_out_ms: int = 8,
|
||
) -> bytes:
|
||
"""Apply short edge fades to reduce click/pop at sentence boundaries."""
|
||
if not pcm_bytes or (not fade_in and not fade_out):
|
||
return pcm_bytes
|
||
|
||
try:
|
||
samples = np.frombuffer(pcm_bytes, dtype="<i2").astype(np.float32)
|
||
if samples.size == 0:
|
||
return pcm_bytes
|
||
|
||
if fade_in and fade_in_ms > 0:
|
||
fade_in_samples = int(sample_rate * (fade_in_ms / 1000.0))
|
||
fade_in_samples = max(1, min(fade_in_samples, samples.size))
|
||
samples[:fade_in_samples] *= np.linspace(0.0, 1.0, fade_in_samples, endpoint=True)
|
||
if fade_out:
|
||
fade_out_samples = int(sample_rate * (fade_out_ms / 1000.0))
|
||
fade_out_samples = max(1, min(fade_out_samples, samples.size))
|
||
samples[-fade_out_samples:] *= np.linspace(1.0, 0.0, fade_out_samples, endpoint=True)
|
||
|
||
return np.clip(samples, -32768, 32767).astype("<i2").tobytes()
|
||
except Exception:
|
||
# Fallback: never block audio delivery on smoothing failure.
|
||
return pcm_bytes
|
||
|
||
async def _speak(self, text: str, audio_event_priority: int = 10) -> None:
    """
    Synthesize and send speech.

    Non-turn speech path (e.g. greetings): wraps the TTS stream with
    output.audio.start/end events and reports TTFB on the first chunk.

    Args:
        text: Text to speak
        audio_event_priority: Priority for output.audio.start/end events
    """
    if not self._tts_output_enabled():
        return

    if not text.strip() or not self.tts_service:
        return

    try:
        self._drop_outbound_audio = False
        # Start latency tracking for greeting
        speak_start_time = time.time()
        first_audio_sent = False

        # Send track start event
        tts_id = self._start_tts()
        self._mark_client_playback_started(tts_id)
        await self._send_event({
            **ev(
                "output.audio.start",
                trackId=self.track_audio_out,
            )
        }, priority=audio_event_priority)

        self._is_bot_speaking = True

        # Stream TTS audio
        async for chunk in self.tts_service.synthesize_stream(text):
            if self._interrupt_event.is_set():
                logger.info("TTS interrupted by barge-in")
                break

            # Track and log first audio packet latency (TTFB)
            if not first_audio_sent:
                ttfb_ms = (time.time() - speak_start_time) * 1000
                first_audio_sent = True
                logger.info(f"[TTFB] Greeting first audio packet latency: {ttfb_ms:.0f}ms (session {self.session_id})")

                # Send TTFB event to client
                await self._send_event({
                    **ev(
                        "metrics.ttfb",
                        trackId=self.track_audio_out,
                        latencyMs=round(ttfb_ms),
                    )
                }, priority=25)

            # Send audio to client
            await self._send_audio(chunk.audio, priority=50)

            # Small delay to prevent flooding
            await asyncio.sleep(0.01)

        # Send track end event
        await self._flush_audio_out_frames(priority=50)
        await self._send_event({
            **ev(
                "output.audio.end",
                trackId=self.track_audio_out,
            )
        }, priority=audio_event_priority)

    except asyncio.CancelledError:
        logger.info("TTS cancelled")
        raise
    except Exception as e:
        logger.error(f"TTS error: {e}")
    finally:
        self._is_bot_speaking = False
async def _handle_barge_in(self) -> None:
    """Handle user barge-in (interruption).

    Ordering matters: the interrupt flag and client notification go out
    BEFORE services are cancelled so the client can discard in-flight
    audio immediately. Afterwards the pipeline is reset for a new user
    turn.
    """
    if not self._is_bot_speaking:
        return

    logger.info("Barge-in detected - interrupting bot speech")

    # Reset barge-in tracking
    self._barge_in_speech_start_time = None
    self._barge_in_speech_frames = 0
    self._barge_in_silence_frames = 0

    # IMPORTANT: Signal interruption FIRST to stop audio sending
    self._interrupt_event.set()
    self._is_bot_speaking = False
    self._drop_outbound_audio = True
    self._audio_out_frame_buffer = b""
    self._clear_client_playback_tracking()
    # Capture ids before they are reset so the interrupt event correlates.
    interrupted_turn_id = self._current_turn_id
    interrupted_utterance_id = self._current_utterance_id
    interrupted_response_id = self._current_response_id

    # Send interrupt event to client IMMEDIATELY
    # This must happen BEFORE canceling services, so client knows to discard in-flight audio
    await self._send_event({
        **ev(
            "response.interrupted",
            trackId=self.track_audio_out,
        ),
        "turn_id": interrupted_turn_id,
        "utterance_id": interrupted_utterance_id,
        "response_id": interrupted_response_id,
    }, priority=0)

    # Cancel TTS
    if self.tts_service:
        await self.tts_service.cancel()

    # Cancel LLM
    if self.llm_service and hasattr(self.llm_service, 'cancel'):
        self.llm_service.cancel()

    # Interrupt conversation only if there is no active turn task.
    # When a turn task exists, it will handle end_assistant_turn() to avoid double callbacks.
    if not (self._current_turn_task and not self._current_turn_task.done()):
        await self.conversation.interrupt()

    # Reset for new user turn
    await self.conversation.start_user_turn()
    self._audio_buffer = b""
    self.eou_detector.reset()
    self._asr_capture_active = False
    self._asr_capture_started_ms = 0.0
    self._pending_speech_audio = b""
async def _stop_current_speech(self) -> None:
|
||
"""Stop any current speech task."""
|
||
self._drop_outbound_audio = True
|
||
self._audio_out_frame_buffer = b""
|
||
self._clear_client_playback_tracking()
|
||
if self._current_turn_task and not self._current_turn_task.done():
|
||
self._interrupt_event.set()
|
||
self._current_turn_task.cancel()
|
||
try:
|
||
await self._current_turn_task
|
||
except asyncio.CancelledError:
|
||
pass
|
||
|
||
# Ensure underlying services are cancelled to avoid leaking work/audio
|
||
if self.tts_service:
|
||
await self.tts_service.cancel()
|
||
if self.llm_service and hasattr(self.llm_service, 'cancel'):
|
||
self.llm_service.cancel()
|
||
|
||
self._is_bot_speaking = False
|
||
self._interrupt_event.clear()
|
||
|
||
async def cleanup(self) -> None:
    """Tear down the pipeline: stop speech, drain the outbound worker, disconnect services."""
    logger.info(f"Cleaning up DuplexPipeline for session {self.session_id}")

    self._running = False
    await self._stop_current_speech()

    # Ask the outbound worker to stop (top priority) and wait for it to drain.
    outbound = self._outbound_task
    if outbound and not outbound.done():
        await self._enqueue_outbound("stop", None, priority=-1000)
        await outbound
        self._outbound_task = None

    # Disconnect whichever services were configured, in the usual order.
    for service in (self.llm_service, self.tts_service, self.asr_service):
        if service:
            await service.disconnect()
def _get_timestamp_ms(self) -> int:
|
||
"""Get current timestamp in milliseconds."""
|
||
import time
|
||
return int(time.time() * 1000)
|
||
|
||
@property
def is_speaking(self) -> bool:
    """True while assistant audio is live on either end.

    Covers both the server still emitting audio and the client still
    holding unacknowledged assistant playback.
    """
    if self._is_bot_speaking:
        return True
    return self.is_client_playing_audio
@property
def is_client_playing_audio(self) -> bool:
    """Whether any assistant audio is still awaiting a client playback ack."""
    pending = self._pending_client_playback_tts_ids
    return True if pending else False
@property
def state(self) -> ConversationState:
    """Expose the underlying ConversationManager's current state."""
    conversation = self.conversation
    return conversation.state