- Introduced Volcengine as a new provider for both TTS and ASR services. - Updated configuration files to include Volcengine-specific parameters such as app_id, resource_id, and uid. - Enhanced the ASR service to support streaming mode with Volcengine's API. - Modified existing tests to validate the integration of Volcengine services. - Updated documentation to reflect the addition of Volcengine as a supported provider for TTS and ASR. - Refactored service factory to accommodate Volcengine alongside existing providers.
2863 lines · 122 KiB · Python
"""Full duplex audio pipeline for AI voice conversation.
|
||
|
||
This module implements the core duplex pipeline that orchestrates:
|
||
- VAD (Voice Activity Detection)
|
||
- EOU (End of Utterance) Detection
|
||
- ASR (Automatic Speech Recognition) - optional
|
||
- LLM (Language Model)
|
||
- TTS (Text-to-Speech)
|
||
|
||
Inspired by pipecat's frame-based architecture and active-call's
|
||
event-driven design.
|
||
"""
|
||
|
||
import asyncio
|
||
import audioop
|
||
import io
|
||
import json
|
||
import time
|
||
import uuid
|
||
import wave
|
||
from pathlib import Path
|
||
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
|
||
|
||
import numpy as np
|
||
import aiohttp
|
||
from loguru import logger
|
||
|
||
from app.config import settings
|
||
from providers.factory.default import DefaultRealtimeServiceFactory
|
||
from runtime.conversation import ConversationManager, ConversationState
|
||
from runtime.events import get_event_bus
|
||
from runtime.ports import (
|
||
ASRMode,
|
||
ASRPort,
|
||
ASRServiceSpec,
|
||
LLMPort,
|
||
LLMServiceSpec,
|
||
OfflineASRPort,
|
||
RealtimeServiceFactory,
|
||
StreamingASRPort,
|
||
TTSPort,
|
||
TTSServiceSpec,
|
||
)
|
||
from tools.executor import execute_server_tool
|
||
from runtime.transports import BaseTransport
|
||
from protocol.ws_v1.schema import ev
|
||
from processors.eou import EouDetector
|
||
from processors.vad import SileroVAD, VADProcessor
|
||
from providers.common.base import LLMMessage, LLMStreamEvent
|
||
from providers.common.streaming_text import extract_tts_sentence, has_spoken_content
|
||
|
||
|
||
class DuplexPipeline:
    """
    Full duplex audio pipeline for AI voice conversation.

    Handles bidirectional audio flow with:
    - User speech detection and transcription
    - AI response generation
    - Text-to-speech synthesis
    - Barge-in (interruption) support

    Architecture (inspired by pipecat):

    User Audio → VAD → EOU → [ASR] → LLM → TTS → Audio Out
                  ↓
        Barge-in Detection → Interrupt
    """

    # Sentence terminators (CJK + ASCII) — presumably used for TTS sentence
    # segmentation; confirm against the splitting helpers.
    _SENTENCE_END_CHARS = frozenset({"。", "!", "?", ".", "!", "?", "\n"})
    # Characters allowed to trail a sentence terminator.
    _SENTENCE_TRAILING_CHARS = frozenset({"。", "!", "?", ".", "!", "?", "…", "~", "~", "\n"})
    # Closing quotes/brackets that may follow a sentence terminator.
    _SENTENCE_CLOSERS = frozenset({'"', "'", "”", "’", ")", "]", "}", ")", "】", "」", "』", "》"})
    # Minimum spoken characters before a split is allowed — TODO confirm usage.
    _MIN_SPLIT_SPOKEN_CHARS = 6
    # Timeout waiting for a client tool response.
    _TOOL_WAIT_TIMEOUT_SECONDS = 60.0
    # Timeout for server-side tool execution.
    _SERVER_TOOL_TIMEOUT_SECONDS = 15.0
    # Logical track identifiers stamped onto outbound event envelopes.
    TRACK_AUDIO_IN = "audio_in"
    TRACK_AUDIO_OUT = "audio_out"
    TRACK_CONTROL = "control"
    _PCM_FRAME_BYTES = 640  # 16k mono pcm_s16le, 20ms
    # Throttle intervals for incremental delta events (ms).
    _ASR_DELTA_THROTTLE_MS = 500
    _LLM_DELTA_THROTTLE_MS = 80
    # Hard cap on one ASR capture window (ms).
    _ASR_CAPTURE_MAX_MS = 15000
    # How long to wait for a streaming-ASR final result (ms).
    _ASR_STREAM_FINAL_TIMEOUT_MS = 800
    # Pre-roll for the opener audio (ms) — presumably; verify at call site.
    _OPENER_PRE_ROLL_MS = 180
|
||
    # Built-in tool schemas (OpenAI function-calling style) advertised to the
    # LLM when the session does not override the tool list.
    _DEFAULT_TOOL_SCHEMAS: Dict[str, Dict[str, Any]] = {
        # Server-side utility tools.
        "current_time": {
            "name": "current_time",
            "description": "Get current local time",
            "parameters": {
                "type": "object",
                "properties": {},
                "required": [],
            },
        },
        "calculator": {
            "name": "calculator",
            "description": "Execute a math expression",
            "parameters": {
                "type": "object",
                "properties": {
                    "expression": {"type": "string", "description": "Math expression, e.g. 2 + 3 * 4"},
                },
                "required": ["expression"],
            },
        },
        "code_interpreter": {
            "name": "code_interpreter",
            "description": "Safely evaluate a Python expression",
            "parameters": {
                "type": "object",
                "properties": {
                    "code": {"type": "string", "description": "Python expression to evaluate"},
                },
                "required": ["code"],
            },
        },
        # Client-device control tools (see _DEFAULT_CLIENT_EXECUTORS).
        "turn_on_camera": {
            "name": "turn_on_camera",
            "description": "Turn on client camera",
            "parameters": {
                "type": "object",
                "properties": {},
                "required": [],
            },
        },
        "turn_off_camera": {
            "name": "turn_off_camera",
            "description": "Turn off client camera",
            "parameters": {
                "type": "object",
                "properties": {},
                "required": [],
            },
        },
        "increase_volume": {
            "name": "increase_volume",
            "description": "Increase client volume",
            "parameters": {
                "type": "object",
                "properties": {
                    "step": {"type": "integer", "description": "Volume increase step, default 1"},
                },
                "required": [],
            },
        },
        "decrease_volume": {
            "name": "decrease_volume",
            "description": "Decrease client volume",
            "parameters": {
                "type": "object",
                "properties": {
                    "step": {"type": "integer", "description": "Volume decrease step, default 1"},
                },
                "required": [],
            },
        },
        # Client-side prompt/dialog tools.
        "voice_msg_prompt": {
            "name": "voice_msg_prompt",
            "description": "Speak a message prompt on client side",
            "parameters": {
                "type": "object",
                "properties": {
                    "msg": {"type": "string", "description": "Message text to speak"},
                },
                "required": ["msg"],
            },
        },
        "text_msg_prompt": {
            "name": "text_msg_prompt",
            "description": "Show a text message prompt dialog on client side",
            "parameters": {
                "type": "object",
                "properties": {
                    "msg": {"type": "string", "description": "Message text to display"},
                },
                "required": ["msg"],
            },
        },
        "voice_choice_prompt": {
            "name": "voice_choice_prompt",
            "description": "Speak a question and show options on client side, then wait for selection",
            "parameters": {
                "type": "object",
                "properties": {
                    "question": {"type": "string", "description": "Question text to show"},
                    "options": {
                        "type": "array",
                        "description": "Selectable options (string or object with id/label/value)",
                        "minItems": 2,
                        "items": {
                            "anyOf": [
                                {"type": "string"},
                                {
                                    "type": "object",
                                    "properties": {
                                        "id": {"type": "string"},
                                        "label": {"type": "string"},
                                        "value": {"type": "string"},
                                    },
                                    "required": ["label"],
                                },
                            ]
                        },
                    },
                    "voice_text": {
                        "type": "string",
                        "description": "Optional voice text. Falls back to question when omitted.",
                    },
                },
                "required": ["question", "options"],
            },
        },
        "text_choice_prompt": {
            "name": "text_choice_prompt",
            "description": "Show a text-only choice prompt on client side and wait for selection",
            "parameters": {
                "type": "object",
                "properties": {
                    "question": {"type": "string", "description": "Question text to show"},
                    "options": {
                        "type": "array",
                        "description": "Selectable options (string or object with id/label/value)",
                        "minItems": 2,
                        "items": {
                            "anyOf": [
                                {"type": "string"},
                                {
                                    "type": "object",
                                    "properties": {
                                        "id": {"type": "string"},
                                        "label": {"type": "string"},
                                        "value": {"type": "string"},
                                    },
                                    "required": ["label"],
                                },
                            ]
                        },
                    },
                },
                "required": ["question", "options"],
            },
        },
    }
|
||
    # Tools that are executed on the client device by default; everything else
    # is presumably run server-side unless overridden at runtime.
    _DEFAULT_CLIENT_EXECUTORS = frozenset({
        "turn_on_camera",
        "turn_off_camera",
        "increase_volume",
        "decrease_volume",
        "voice_msg_prompt",
        "text_msg_prompt",
        "voice_choice_prompt",
        "text_choice_prompt",
    })
    # Legacy tool-name spellings mapped to their canonical names.
    _TOOL_NAME_ALIASES = {
        "voice_message_prompt": "voice_msg_prompt",
    }
|
||
|
||
@classmethod
|
||
def _normalize_tool_name(cls, raw_name: Any) -> str:
|
||
name = str(raw_name or "").strip()
|
||
if not name:
|
||
return ""
|
||
return cls._TOOL_NAME_ALIASES.get(name, name)
|
||
|
||
    def __init__(
        self,
        transport: BaseTransport,
        session_id: str,
        llm_service: Optional[LLMPort] = None,
        tts_service: Optional[TTSPort] = None,
        asr_service: Optional[ASRPort] = None,
        system_prompt: Optional[str] = None,
        greeting: Optional[str] = None,
        knowledge_searcher: Optional[
            Callable[..., Awaitable[List[Dict[str, Any]]]]
        ] = None,
        tool_resource_resolver: Optional[
            Callable[[str], Awaitable[Optional[Dict[str, Any]]]]
        ] = None,
        server_tool_executor: Optional[
            Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]
        ] = None,
        service_factory: Optional[RealtimeServiceFactory] = None,
    ):
        """
        Initialize duplex pipeline.

        Args:
            transport: Transport for sending audio/events
            session_id: Session identifier
            llm_service: Optional injected LLM port implementation
            tts_service: Optional injected TTS port implementation
            asr_service: ASR service (optional)
            system_prompt: System prompt for LLM
            greeting: Optional greeting to speak on start
            knowledge_searcher: Optional async knowledge-base search callable
            tool_resource_resolver: Optional async resolver for per-tool resources
            server_tool_executor: Optional async executor for server-side tools
            service_factory: Factory for lazily constructing realtime services
        """
        self.transport = transport
        self.session_id = session_id
        self.event_bus = get_event_bus()
        self.track_audio_in = self.TRACK_AUDIO_IN
        self.track_audio_out = self.TRACK_AUDIO_OUT
        self.track_control = self.TRACK_CONTROL

        # Initialize VAD
        self.vad_model = SileroVAD(
            model_path=settings.vad_model_path,
            sample_rate=settings.sample_rate
        )
        self.vad_processor = VADProcessor(
            vad_model=self.vad_model,
            threshold=settings.vad_threshold
        )

        # Initialize EOU detector
        self.eou_detector = EouDetector(
            silence_threshold_ms=settings.vad_eou_threshold_ms,
            min_speech_duration_ms=settings.vad_min_speech_duration_ms
        )

        # Initialize services
        self.llm_service = llm_service
        self.tts_service = tts_service
        self.asr_service = asr_service  # Will be initialized in start()
        # Resolve offline vs streaming ASR from provider default + injected service.
        self._asr_mode: ASRMode = self._resolve_asr_mode(
            settings.asr_provider,
            getattr(asr_service, "mode", None),
        )
        self._service_factory = service_factory or DefaultRealtimeServiceFactory()
        self._knowledge_searcher = knowledge_searcher
        self._tool_resource_resolver = tool_resource_resolver
        self._server_tool_executor = server_tool_executor

        # Track last sent transcript to avoid duplicates
        self._last_sent_transcript = ""
        self._latest_asr_interim_text = ""
        self._pending_transcript_delta: str = ""
        self._last_transcript_delta_emit_ms: float = 0.0

        # Conversation manager
        self.conversation = ConversationManager(
            system_prompt=system_prompt,
            greeting=greeting
        )

        # State
        self._running = True
        self._is_bot_speaking = False
        self._current_turn_task: Optional[asyncio.Task] = None
        self._audio_buffer: bytes = b""
        max_buffer_seconds = settings.max_audio_buffer_seconds
        # 2 bytes/sample (pcm_s16le) × sample_rate × seconds.
        self._max_audio_buffer_bytes = int(settings.sample_rate * 2 * max_buffer_seconds)
        self._asr_start_min_speech_ms: int = settings.asr_start_min_speech_ms
        self._asr_capture_active: bool = False
        self._asr_capture_started_ms: float = 0.0
        self._pending_speech_audio: bytes = b""
        # Keep a short rolling pre-speech window so VAD transition latency
        # does not clip the first phoneme/character sent to ASR.
        pre_speech_ms = settings.asr_pre_speech_ms
        self._asr_pre_speech_bytes = int(settings.sample_rate * 2 * (pre_speech_ms / 1000.0))
        self._pre_speech_buffer: bytes = b""
        # Add a tiny trailing silence tail before final ASR to avoid
        # clipping the last phoneme at utterance boundaries.
        asr_final_tail_ms = settings.asr_final_tail_ms
        self._asr_final_tail_bytes = int(settings.sample_rate * 2 * (asr_final_tail_ms / 1000.0))
        self._last_vad_status: str = "Silence"
        self._process_lock = asyncio.Lock()
        # Priority outbound dispatcher (lower value = higher priority).
        self._outbound_q: asyncio.PriorityQueue[Tuple[int, int, str, Any]] = asyncio.PriorityQueue()
        self._outbound_seq = 0
        self._outbound_task: Optional[asyncio.Task] = None
        self._drop_outbound_audio = False
        self._audio_out_frame_buffer: bytes = b""

        # Interruption handling
        self._interrupt_event = asyncio.Event()

        # Latency tracking - TTFB (Time to First Byte)
        self._turn_start_time: Optional[float] = None
        self._first_audio_sent: bool = False

        # Barge-in filtering - require minimum speech duration to interrupt
        self._barge_in_speech_start_time: Optional[float] = None
        self._barge_in_min_duration_ms: int = settings.barge_in_min_duration_ms
        self._barge_in_silence_tolerance_ms: int = settings.barge_in_silence_tolerance_ms
        self._barge_in_speech_frames: int = 0  # Count speech frames
        self._barge_in_silence_frames: int = 0  # Count silence frames during potential barge-in

        # Runtime overrides injected from session.start metadata
        self._runtime_llm: Dict[str, Any] = {}
        self._runtime_asr: Dict[str, Any] = {}
        self._runtime_tts: Dict[str, Any] = {}
        self._runtime_output: Dict[str, Any] = {}
        self._runtime_system_prompt: Optional[str] = None
        self._runtime_first_turn_mode: str = "bot_first"
        self._runtime_greeting: Optional[str] = None
        self._runtime_generated_opener_enabled: Optional[bool] = None
        self._runtime_manual_opener_tool_calls: List[Any] = []
        self._runtime_opener_audio: Dict[str, Any] = {}
        self._runtime_barge_in_enabled: Optional[bool] = None
        self._runtime_barge_in_min_duration_ms: Optional[int] = None
        self._runtime_knowledge: Dict[str, Any] = {}
        self._runtime_knowledge_base_id: Optional[str] = None
        raw_default_tools = settings.tools if isinstance(settings.tools, list) else []
        self._runtime_tools: List[Any] = list(raw_default_tools)
        self._runtime_tool_executor: Dict[str, str] = {}
        self._runtime_tool_default_args: Dict[str, Dict[str, Any]] = {}
        self._runtime_tool_id_map: Dict[str, str] = {}
        self._runtime_tool_display_names: Dict[str, str] = {}
        self._runtime_tool_wait_for_response: Dict[str, bool] = {}
        self._pending_tool_waiters: Dict[str, asyncio.Future] = {}
        self._early_tool_results: Dict[str, Dict[str, Any]] = {}
        self._completed_tool_call_ids: set[str] = set()
        self._pending_client_tool_call_ids: set[str] = set()
        self._pending_client_playback_tts_ids: set[str] = set()
        self._tts_playback_context: Dict[str, Dict[str, Optional[str]]] = {}
        self._last_client_played_tts_id: Optional[str] = None
        self._last_client_played_response_id: Optional[str] = None
        self._last_client_played_turn_id: Optional[str] = None
        self._last_client_played_at_ms: Optional[int] = None
        self._next_seq: Optional[Callable[[], int]] = None
        self._local_seq: int = 0

        # Cross-service correlation IDs
        self._turn_count: int = 0
        self._response_count: int = 0
        self._tts_count: int = 0
        self._utterance_count: int = 0
        self._current_turn_id: Optional[str] = None
        self._current_utterance_id: Optional[str] = None
        self._current_response_id: Optional[str] = None
        self._current_tts_id: Optional[str] = None
        self._pending_llm_delta: str = ""
        self._last_llm_delta_emit_ms: float = 0.0

        # Resolve tool maps from the default tool list configured above.
        self._runtime_tool_executor = self._resolved_tool_executor_map()
        self._runtime_tool_default_args = self._resolved_tool_default_args_map()
        self._runtime_tool_id_map = self._resolved_tool_id_map()
        self._runtime_tool_display_names = self._resolved_tool_display_name_map()
        self._runtime_tool_wait_for_response = self._resolved_tool_wait_for_response_map()
        self._initial_greeting_emitted = False

        # Default server-tool executor: bind the resource resolver when one
        # was injected, otherwise fall back to the bare executor.
        if self._server_tool_executor is None:
            if self._tool_resource_resolver:
                async def _executor(call: Dict[str, Any]) -> Dict[str, Any]:
                    return await execute_server_tool(
                        call,
                        tool_resource_fetcher=self._tool_resource_resolver,
                    )

                self._server_tool_executor = _executor
            else:
                self._server_tool_executor = execute_server_tool

        logger.info(f"DuplexPipeline initialized for session {session_id}")
|
||
|
||
def set_event_sequence_provider(self, provider: Callable[[], int]) -> None:
|
||
"""Use session-scoped monotonic sequence provider for envelope events."""
|
||
self._next_seq = provider
|
||
|
||
def apply_runtime_overrides(self, metadata: Optional[Dict[str, Any]]) -> None:
|
||
"""
|
||
Apply runtime overrides from WS session.start metadata.
|
||
|
||
Expected metadata shape:
|
||
{
|
||
"systemPrompt": "...",
|
||
"greeting": "...",
|
||
"services": {
|
||
"llm": {...},
|
||
"asr": {...},
|
||
"tts": {...}
|
||
}
|
||
}
|
||
"""
|
||
if not metadata:
|
||
return
|
||
|
||
if "systemPrompt" in metadata:
|
||
self._runtime_system_prompt = str(metadata.get("systemPrompt") or "")
|
||
if self._runtime_system_prompt:
|
||
self.conversation.system_prompt = self._runtime_system_prompt
|
||
if "firstTurnMode" in metadata:
|
||
raw_mode = str(metadata.get("firstTurnMode") or "").strip().lower()
|
||
self._runtime_first_turn_mode = "user_first" if raw_mode == "user_first" else "bot_first"
|
||
if "greeting" in metadata:
|
||
greeting_payload = metadata.get("greeting")
|
||
if isinstance(greeting_payload, dict):
|
||
self._runtime_greeting = str(greeting_payload.get("text") or "")
|
||
generated_flag = self._coerce_bool(greeting_payload.get("generated"))
|
||
if generated_flag is not None:
|
||
self._runtime_generated_opener_enabled = generated_flag
|
||
else:
|
||
self._runtime_greeting = str(greeting_payload or "")
|
||
self.conversation.greeting = self._runtime_greeting or None
|
||
generated_opener_flag = self._coerce_bool(metadata.get("generatedOpenerEnabled"))
|
||
if generated_opener_flag is not None:
|
||
self._runtime_generated_opener_enabled = generated_opener_flag
|
||
if "manualOpenerToolCalls" in metadata:
|
||
manual_calls = metadata.get("manualOpenerToolCalls")
|
||
self._runtime_manual_opener_tool_calls = manual_calls if isinstance(manual_calls, list) else []
|
||
|
||
services = metadata.get("services") or {}
|
||
if isinstance(services, dict):
|
||
if isinstance(services.get("llm"), dict):
|
||
self._runtime_llm = services["llm"]
|
||
if isinstance(services.get("asr"), dict):
|
||
self._runtime_asr = services["asr"]
|
||
if isinstance(services.get("tts"), dict):
|
||
self._runtime_tts = services["tts"]
|
||
output = metadata.get("output") or {}
|
||
if isinstance(output, dict):
|
||
self._runtime_output = output
|
||
barge_in = metadata.get("bargeIn")
|
||
if isinstance(barge_in, dict):
|
||
barge_in_enabled = self._coerce_bool(barge_in.get("enabled"))
|
||
if barge_in_enabled is not None:
|
||
self._runtime_barge_in_enabled = barge_in_enabled
|
||
min_duration = barge_in.get("minDurationMs")
|
||
if isinstance(min_duration, (int, float, str)):
|
||
try:
|
||
self._runtime_barge_in_min_duration_ms = max(0, int(min_duration))
|
||
except (TypeError, ValueError):
|
||
self._runtime_barge_in_min_duration_ms = None
|
||
|
||
knowledge_base_id = metadata.get("knowledgeBaseId")
|
||
if knowledge_base_id is not None:
|
||
kb_id = str(knowledge_base_id).strip()
|
||
self._runtime_knowledge_base_id = kb_id or None
|
||
|
||
knowledge = metadata.get("knowledge")
|
||
if isinstance(knowledge, dict):
|
||
self._runtime_knowledge = knowledge
|
||
opener_audio = metadata.get("openerAudio")
|
||
if isinstance(opener_audio, dict):
|
||
self._runtime_opener_audio = dict(opener_audio)
|
||
kb_id = str(knowledge.get("kbId") or knowledge.get("knowledgeBaseId") or "").strip()
|
||
if kb_id:
|
||
self._runtime_knowledge_base_id = kb_id
|
||
|
||
tools_payload = metadata.get("tools")
|
||
if isinstance(tools_payload, list):
|
||
self._runtime_tools = tools_payload
|
||
self._runtime_tool_executor = self._resolved_tool_executor_map()
|
||
self._runtime_tool_default_args = self._resolved_tool_default_args_map()
|
||
self._runtime_tool_id_map = self._resolved_tool_id_map()
|
||
self._runtime_tool_display_names = self._resolved_tool_display_name_map()
|
||
self._runtime_tool_wait_for_response = self._resolved_tool_wait_for_response_map()
|
||
elif "tools" in metadata:
|
||
self._runtime_tools = []
|
||
self._runtime_tool_executor = {}
|
||
self._runtime_tool_default_args = {}
|
||
self._runtime_tool_id_map = {}
|
||
self._runtime_tool_display_names = {}
|
||
self._runtime_tool_wait_for_response = {}
|
||
|
||
if self.llm_service and hasattr(self.llm_service, "set_knowledge_config"):
|
||
self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
|
||
if self.llm_service and hasattr(self.llm_service, "set_tool_schemas"):
|
||
self.llm_service.set_tool_schemas(self._resolved_tool_schemas())
|
||
|
||
    def resolved_runtime_config(self) -> Dict[str, Any]:
        """Return current effective runtime configuration without secrets."""
        # Runtime overrides from session metadata take precedence over settings.
        llm_provider = str(self._runtime_llm.get("provider") or settings.llm_provider).lower()
        llm_base_url = (
            self._runtime_llm.get("baseUrl")
            or settings.llm_api_url
            or self._default_llm_base_url(llm_provider)
        )
        tts_provider = str(self._runtime_tts.get("provider") or settings.tts_provider).lower()
        asr_provider = str(self._runtime_asr.get("provider") or settings.asr_provider).lower()
        output_mode = str(self._runtime_output.get("mode") or "").strip().lower()
        if not output_mode:
            # Derive the output mode from whether TTS output is enabled.
            output_mode = "audio" if self._tts_output_enabled() else "text"

        # DashScope gets a provider-specific default model when none is set.
        tts_model = str(
            self._runtime_tts.get("model")
            or settings.tts_model
            or (self._default_dashscope_tts_model() if self._is_dashscope_tts_provider(tts_provider) else "")
        )
        tts_config = {
            "enabled": self._tts_output_enabled(),
            "provider": tts_provider,
            "model": tts_model,
            "voice": str(self._runtime_tts.get("voice") or settings.tts_voice),
            "speed": float(self._runtime_tts.get("speed") or settings.tts_speed),
        }
        if self._is_dashscope_tts_provider(tts_provider):
            tts_config["mode"] = self._resolved_dashscope_tts_mode()

        return {
            "output": {"mode": output_mode},
            "services": {
                "llm": {
                    "provider": llm_provider,
                    "model": str(self._runtime_llm.get("model") or settings.llm_model),
                    "baseUrl": llm_base_url,
                },
                "asr": {
                    "provider": asr_provider,
                    "mode": self._resolve_asr_mode(asr_provider, self._runtime_asr.get("mode")),
                    "model": str(self._runtime_asr.get("model") or settings.asr_model or ""),
                    "enableInterim": self._asr_interim_enabled(),
                    "interimIntervalMs": int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms),
                    "minAudioMs": int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms),
                },
                "tts": tts_config,
            },
            "tools": {
                "allowlist": self._resolved_tool_allowlist(),
            },
            "opener": {
                "generated": self._generated_opener_enabled(),
                "manualToolCallCount": len(self._resolved_manual_opener_tool_calls()),
            },
            "tracks": {
                "audio_in": self.track_audio_in,
                "audio_out": self.track_audio_out,
                "control": self.track_control,
            },
        }
|
||
|
||
def _next_event_seq(self) -> int:
|
||
if self._next_seq:
|
||
return self._next_seq()
|
||
self._local_seq += 1
|
||
return self._local_seq
|
||
|
||
def _event_source(self, event_type: str) -> str:
|
||
if event_type.startswith("transcript.") or event_type.startswith("input.speech_"):
|
||
return "asr"
|
||
if event_type.startswith("assistant.response."):
|
||
return "llm"
|
||
if event_type.startswith("assistant.tool_"):
|
||
return "tool"
|
||
if event_type.startswith("output.audio.") or event_type == "metrics.ttfb":
|
||
return "tts"
|
||
return "system"
|
||
|
||
def _new_id(self, prefix: str, counter: int) -> str:
|
||
return f"{prefix}_{counter}_{uuid.uuid4().hex[:8]}"
|
||
|
||
def _start_turn(self) -> str:
|
||
self._turn_count += 1
|
||
self._current_turn_id = self._new_id("turn", self._turn_count)
|
||
self._current_utterance_id = None
|
||
self._current_response_id = None
|
||
self._current_tts_id = None
|
||
return self._current_turn_id
|
||
|
||
def _start_response(self) -> str:
|
||
self._response_count += 1
|
||
self._current_response_id = self._new_id("resp", self._response_count)
|
||
self._current_tts_id = None
|
||
return self._current_response_id
|
||
|
||
def _start_tts(self) -> str:
|
||
self._tts_count += 1
|
||
tts_id = self._new_id("tts", self._tts_count)
|
||
self._current_tts_id = tts_id
|
||
self._tts_playback_context[tts_id] = {
|
||
"turn_id": self._current_turn_id,
|
||
"response_id": self._current_response_id,
|
||
}
|
||
return tts_id
|
||
|
||
def _finalize_utterance(self) -> str:
|
||
if self._current_utterance_id:
|
||
return self._current_utterance_id
|
||
self._utterance_count += 1
|
||
self._current_utterance_id = self._new_id("utt", self._utterance_count)
|
||
if not self._current_turn_id:
|
||
self._start_turn()
|
||
return self._current_utterance_id
|
||
|
||
def _mark_client_playback_started(self, tts_id: Optional[str]) -> None:
|
||
normalized_tts_id = str(tts_id or "").strip()
|
||
if not normalized_tts_id:
|
||
return
|
||
self._pending_client_playback_tts_ids.add(normalized_tts_id)
|
||
|
||
def _clear_client_playback_tracking(self) -> None:
|
||
self._pending_client_playback_tts_ids.clear()
|
||
self._tts_playback_context.clear()
|
||
|
||
    async def handle_output_audio_played(
        self,
        *,
        tts_id: str,
        response_id: Optional[str] = None,
        turn_id: Optional[str] = None,
        played_at_ms: Optional[int] = None,
        played_ms: Optional[int] = None,
    ) -> None:
        """Record client-side playback completion for a TTS segment.

        Args:
            tts_id: Identifier of the TTS segment the client finished playing.
            response_id: Optional response id; falls back to the context stored
                when the segment was minted.
            turn_id: Optional turn id; same fallback as response_id.
            played_at_ms: Client-reported completion timestamp (ms); server
                time is used when absent or negative.
            played_ms: Client-reported playback duration (ms), if any.
        """
        normalized_tts_id = str(tts_id or "").strip()
        if not normalized_tts_id:
            return

        was_pending = normalized_tts_id in self._pending_client_playback_tts_ids
        self._pending_client_playback_tts_ids.discard(normalized_tts_id)

        # Prefer ids supplied by the client; fall back to the context captured
        # in _start_tts().
        context = self._tts_playback_context.pop(normalized_tts_id, {})
        resolved_response_id = str(response_id or context.get("response_id") or "").strip() or None
        resolved_turn_id = str(turn_id or context.get("turn_id") or "").strip() or None

        self._last_client_played_tts_id = normalized_tts_id
        self._last_client_played_response_id = resolved_response_id
        self._last_client_played_turn_id = resolved_turn_id
        if isinstance(played_at_ms, int) and played_at_ms >= 0:
            self._last_client_played_at_ms = played_at_ms
        else:
            self._last_client_played_at_ms = self._get_timestamp_ms()

        duration_ms = played_ms if isinstance(played_ms, int) and played_ms >= 0 else None
        logger.info(
            f"[PlaybackAck] tts_id={normalized_tts_id} response_id={resolved_response_id or '-'} "
            f"turn_id={resolved_turn_id or '-'} pending_before={was_pending} "
            f"pending_now={len(self._pending_client_playback_tts_ids)} "
            f"played_ms={duration_ms if duration_ms is not None else '-'}"
        )
|
||
|
||
    def _envelope_event(self, event: Dict[str, Any]) -> Dict[str, Any]:
        """Stamp an outbound event with session id, sequence number, source,
        track and correlation ids. Mutates and returns the same dict."""
        event_type = str(event.get("type") or "")
        source = str(event.get("source") or self._event_source(event_type))
        track_id = event.get("trackId")
        if not track_id:
            # Route by originating subsystem when the caller did not pin a track.
            if source == "asr":
                track_id = self.track_audio_in
            elif source in {"llm", "tts", "tool"}:
                track_id = self.track_audio_out
            else:
                track_id = self.track_control

        data = event.get("data")
        if not isinstance(data, dict):
            data = {}
        # Explicit ids on the event override the pipeline's current ids below.
        explicit_turn_id = str(event.get("turn_id") or "").strip() or None
        explicit_utterance_id = str(event.get("utterance_id") or "").strip() or None
        explicit_response_id = str(event.get("response_id") or "").strip() or None
        explicit_tts_id = str(event.get("tts_id") or "").strip() or None
        # Default to the pipeline's current correlation ids (setdefault so a
        # caller-provided value inside data is preserved).
        if self._current_turn_id:
            data.setdefault("turn_id", self._current_turn_id)
        if self._current_utterance_id:
            data.setdefault("utterance_id", self._current_utterance_id)
        if self._current_response_id:
            data.setdefault("response_id", self._current_response_id)
        if self._current_tts_id:
            data.setdefault("tts_id", self._current_tts_id)
        if explicit_turn_id:
            data["turn_id"] = explicit_turn_id
        if explicit_utterance_id:
            data["utterance_id"] = explicit_utterance_id
        if explicit_response_id:
            data["response_id"] = explicit_response_id
        if explicit_tts_id:
            data["tts_id"] = explicit_tts_id

        # Fold any remaining top-level extras into data without clobbering
        # reserved envelope fields or values already present.
        for k, v in event.items():
            if k in {
                "type",
                "timestamp",
                "sessionId",
                "seq",
                "source",
                "trackId",
                "data",
                "turn_id",
                "utterance_id",
                "response_id",
                "tts_id",
            }:
                continue
            data.setdefault(k, v)

        event["sessionId"] = self.session_id
        event["seq"] = self._next_event_seq()
        event["source"] = source
        event["trackId"] = track_id
        event["data"] = data
        return event
|
||
|
||
@staticmethod
|
||
def _coerce_bool(value: Any) -> Optional[bool]:
|
||
if isinstance(value, bool):
|
||
return value
|
||
if isinstance(value, (int, float)):
|
||
return bool(value)
|
||
if isinstance(value, str):
|
||
normalized = value.strip().lower()
|
||
if normalized in {"1", "true", "yes", "on", "enabled"}:
|
||
return True
|
||
if normalized in {"0", "false", "no", "off", "disabled"}:
|
||
return False
|
||
return None
|
||
|
||
@staticmethod
|
||
def _coerce_json_object(value: Any) -> Optional[Dict[str, Any]]:
|
||
if isinstance(value, dict):
|
||
return dict(value)
|
||
if isinstance(value, str):
|
||
raw = value.strip()
|
||
if not raw:
|
||
return None
|
||
try:
|
||
parsed = json.loads(raw)
|
||
except json.JSONDecodeError:
|
||
logger.warning("Ignoring invalid JSON object config: {}", raw[:120])
|
||
return None
|
||
if isinstance(parsed, dict):
|
||
return parsed
|
||
return None
|
||
|
||
@staticmethod
|
||
def _is_dashscope_tts_provider(provider: Any) -> bool:
|
||
normalized = str(provider or "").strip().lower()
|
||
return normalized == "dashscope"
|
||
|
||
@staticmethod
|
||
def _resolve_asr_mode(provider: Any, raw_mode: Any = None) -> ASRMode:
|
||
normalized_mode = str(raw_mode or "").strip().lower()
|
||
if normalized_mode in {"offline", "streaming"}:
|
||
return normalized_mode # type: ignore[return-value]
|
||
normalized_provider = str(provider or "").strip().lower()
|
||
if normalized_provider in {"dashscope", "volcengine"}:
|
||
return "streaming"
|
||
return "offline"
|
||
|
||
    def _offline_asr(self) -> OfflineASRPort:
        # Narrowing cast: callers must only use this when _asr_mode == "offline".
        return self.asr_service  # type: ignore[return-value]
|
||
|
||
    def _streaming_asr(self) -> StreamingASRPort:
        # Narrowing cast: callers must only use this when _asr_mode == "streaming".
        return self.asr_service  # type: ignore[return-value]
|
||
|
||
@staticmethod
|
||
def _default_llm_base_url(provider: Any) -> Optional[str]:
|
||
normalized = str(provider or "").strip().lower()
|
||
if normalized == "siliconflow":
|
||
return "https://api.siliconflow.cn/v1"
|
||
return None
|
||
|
||
@staticmethod
|
||
def _default_dashscope_tts_model() -> str:
|
||
return "qwen3-tts-flash-realtime"
|
||
|
||
def _resolved_dashscope_tts_mode(self) -> str:
|
||
raw_mode = str(self._runtime_tts.get("mode") or settings.tts_mode or "commit").strip().lower()
|
||
if raw_mode in {"commit", "server_commit"}:
|
||
return raw_mode
|
||
return "commit"
|
||
|
||
def _use_engine_sentence_split_for_tts(self) -> bool:
|
||
tts_provider = str(self._runtime_tts.get("provider") or settings.tts_provider).strip().lower()
|
||
if not self._is_dashscope_tts_provider(tts_provider):
|
||
return True
|
||
# DashScope commit mode is client-driven and expects engine-side segmentation.
|
||
# server_commit mode lets DashScope handle segmentation on appended text.
|
||
return self._resolved_dashscope_tts_mode() != "server_commit"
|
||
|
||
def _tts_output_enabled(self) -> bool:
|
||
enabled = self._coerce_bool(self._runtime_tts.get("enabled"))
|
||
if enabled is not None:
|
||
return enabled
|
||
|
||
output_mode = str(self._runtime_output.get("mode") or "").strip().lower()
|
||
if output_mode in {"text", "text_only", "text-only"}:
|
||
return False
|
||
|
||
return True
|
||
|
||
def _generated_opener_enabled(self) -> bool:
|
||
return self._runtime_generated_opener_enabled is True
|
||
|
||
def _bot_starts_first(self) -> bool:
|
||
return self._runtime_first_turn_mode != "user_first"
|
||
|
||
def _barge_in_enabled(self) -> bool:
|
||
if self._runtime_barge_in_enabled is not None:
|
||
return self._runtime_barge_in_enabled
|
||
return True
|
||
|
||
def _resolved_barge_in_min_duration_ms(self) -> int:
|
||
if self._runtime_barge_in_min_duration_ms is not None:
|
||
return self._runtime_barge_in_min_duration_ms
|
||
return self._barge_in_min_duration_ms
|
||
|
||
def _asr_interim_enabled(self) -> bool:
|
||
current_mode = self._asr_mode
|
||
if not self.asr_service:
|
||
current_mode = self._resolve_asr_mode(
|
||
self._runtime_asr.get("provider") or settings.asr_provider,
|
||
self._runtime_asr.get("mode"),
|
||
)
|
||
if current_mode != "offline":
|
||
return True
|
||
enabled = self._coerce_bool(self._runtime_asr.get("enableInterim"))
|
||
if enabled is not None:
|
||
return enabled
|
||
return bool(settings.asr_enable_interim)
|
||
|
||
def _barge_in_silence_tolerance_frames(self) -> int:
    """Convert the barge-in silence-tolerance window from ms into whole audio frames."""
    frame_ms = max(1, settings.chunk_size_ms)
    frames = int(np.ceil(self._barge_in_silence_tolerance_ms / frame_ms))
    return max(1, frames)
|
||
|
||
async def _generate_runtime_greeting(self) -> Optional[str]:
    """Ask the LLM for a one-sentence Chinese opener; return None on failure.

    Returns:
        The stripped greeting text, or None when the LLM is unavailable,
        the call raises, or the model produced an empty answer.

    NOTE(review): this uses ``LLMMessage``, which does not appear in the
    import block visible at the top of this file — confirm it is imported
    elsewhere in the module.
    """
    if not self.llm_service:
        return None

    system_context = (self.conversation.system_prompt or self._runtime_system_prompt or "").strip()
    # Keep context concise to avoid overloading greeting generation.
    if len(system_context) > 1200:
        system_context = system_context[:1200]
    # Prompt text is user-facing/runtime data — kept verbatim.
    system_prompt = (
        "你是语音通话助手的开场白生成器。"
        "请只输出一句自然、简洁、友好的中文开场白。"
        "不要使用引号,不要使用 markdown,不要加解释。"
    )
    user_prompt = "请生成一句中文开场白(不超过25个汉字)。"
    if system_context:
        user_prompt += f"\n\n以下是该助手的系统提示词,请据此决定语气、角色和边界:\n{system_context}"

    try:
        generated = await self.llm_service.generate(
            [
                LLMMessage(role="system", content=system_prompt),
                LLMMessage(role="user", content=user_prompt),
            ],
            temperature=0.7,
            max_tokens=64,
        )
    except Exception as exc:
        # Best-effort: a failed greeting generation must not break session start.
        logger.warning(f"Failed to generate runtime greeting: {exc}")
        return None

    text = (generated or "").strip()
    if not text:
        return None
    # Strip stray surrounding quotes the model may add despite instructions.
    return text.strip().strip('"').strip("'")
|
||
|
||
async def start(self) -> None:
    """Start the pipeline and connect services.

    Builds (if not already injected) and connects, in order: LLM, TTS
    (only when audio output is enabled, with a mock-provider fallback when
    the real backend is unreachable), then ASR. Runtime session metadata
    overrides YAML ``settings`` for every provider parameter. Finally the
    single outbound sender loop is started.

    Raises:
        Exception: re-raised after logging when any service fails to start
            (the TTS connect failure alone is downgraded to the mock fallback).
    """
    try:
        # Connect LLM service
        if not self.llm_service:
            llm_provider = (self._runtime_llm.get("provider") or settings.llm_provider).lower()
            llm_api_key = self._runtime_llm.get("apiKey")
            llm_base_url = (
                self._runtime_llm.get("baseUrl")
                or settings.llm_api_url
                or self._default_llm_base_url(llm_provider)
            )
            llm_model = self._runtime_llm.get("model") or settings.llm_model
            self.llm_service = self._service_factory.create_llm_service(
                LLMServiceSpec(
                    provider=llm_provider,
                    model=str(llm_model),
                    api_key=str(llm_api_key).strip() if llm_api_key else None,
                    base_url=str(llm_base_url).strip() if llm_base_url else None,
                    system_prompt=self.conversation.system_prompt,
                    temperature=settings.llm_temperature,
                    knowledge_config=self._resolved_knowledge_config(),
                    knowledge_searcher=self._knowledge_searcher,
                )
            )

        # Push knowledge/tool configuration even for injected services,
        # but only when the adapter supports it.
        if hasattr(self.llm_service, "set_knowledge_config"):
            self.llm_service.set_knowledge_config(self._resolved_knowledge_config())
        if hasattr(self.llm_service, "set_tool_schemas"):
            self.llm_service.set_tool_schemas(self._resolved_tool_schemas())

        await self.llm_service.connect()

        tts_output_enabled = self._tts_output_enabled()

        # Connect TTS service only when audio output is enabled.
        if tts_output_enabled:
            if not self.tts_service:
                tts_provider = (self._runtime_tts.get("provider") or settings.tts_provider).lower()
                tts_api_key = self._runtime_tts.get("apiKey")
                tts_api_url = self._runtime_tts.get("baseUrl") or settings.tts_api_url
                tts_voice = self._runtime_tts.get("voice") or settings.tts_voice
                tts_model = self._runtime_tts.get("model") or settings.tts_model
                # app_id/resource_id/cluster/uid are provider-specific
                # credentials (e.g. Volcengine); None when not configured.
                tts_app_id = self._runtime_tts.get("appId") or settings.tts_app_id
                tts_resource_id = self._runtime_tts.get("resourceId") or settings.tts_resource_id
                tts_cluster = self._runtime_tts.get("cluster") or settings.tts_cluster
                tts_uid = self._runtime_tts.get("uid") or settings.tts_uid
                tts_speed = float(self._runtime_tts.get("speed") or settings.tts_speed)
                tts_mode = self._resolved_dashscope_tts_mode()
                runtime_mode = str(self._runtime_tts.get("mode") or "").strip()
                if runtime_mode and not self._is_dashscope_tts_provider(tts_provider):
                    logger.warning(
                        "services.tts.mode is DashScope-only and will be ignored "
                        f"for provider={tts_provider}"
                    )
                self.tts_service = self._service_factory.create_tts_service(
                    TTSServiceSpec(
                        provider=tts_provider,
                        api_key=str(tts_api_key).strip() if tts_api_key else None,
                        api_url=str(tts_api_url).strip() if tts_api_url else None,
                        voice=str(tts_voice),
                        model=str(tts_model).strip() if tts_model else None,
                        app_id=str(tts_app_id).strip() if tts_app_id else None,
                        resource_id=str(tts_resource_id).strip() if tts_resource_id else None,
                        cluster=str(tts_cluster).strip() if tts_cluster else None,
                        uid=str(tts_uid).strip() if tts_uid else None,
                        sample_rate=settings.sample_rate,
                        speed=tts_speed,
                        mode=str(tts_mode),
                    )
                )

            try:
                await self.tts_service.connect()
            except Exception as e:
                # Degrade gracefully: a broken TTS backend must not kill the
                # session; swap in the mock adapter instead.
                logger.warning(f"TTS backend unavailable ({e}); falling back to default TTS adapter")
                self.tts_service = self._service_factory.create_tts_service(
                    TTSServiceSpec(
                        provider="mock",
                        voice="mock",
                        sample_rate=settings.sample_rate,
                    )
                )
                await self.tts_service.connect()
        else:
            self.tts_service = None
            logger.info("TTS output disabled by runtime metadata")

        # Connect ASR service
        if not self.asr_service:
            asr_provider = (self._runtime_asr.get("provider") or settings.asr_provider).lower()
            asr_api_key = self._runtime_asr.get("apiKey")
            asr_api_url = self._runtime_asr.get("baseUrl") or settings.asr_api_url
            asr_model = self._runtime_asr.get("model") or settings.asr_model
            asr_app_id = self._runtime_asr.get("appId") or settings.asr_app_id
            asr_resource_id = self._runtime_asr.get("resourceId") or settings.asr_resource_id
            asr_cluster = self._runtime_asr.get("cluster") or settings.asr_cluster
            asr_uid = self._runtime_asr.get("uid") or settings.asr_uid
            # Extra provider request params come as JSON (runtime first, YAML fallback).
            asr_request_params = self._coerce_json_object(self._runtime_asr.get("requestParams"))
            if asr_request_params is None:
                asr_request_params = self._coerce_json_object(settings.asr_request_params_json)
            asr_enable_interim = self._coerce_bool(self._runtime_asr.get("enableInterim"))
            if asr_enable_interim is None:
                asr_enable_interim = bool(settings.asr_enable_interim)
            asr_interim_interval = int(self._runtime_asr.get("interimIntervalMs") or settings.asr_interim_interval_ms)
            asr_min_audio_ms = int(self._runtime_asr.get("minAudioMs") or settings.asr_min_audio_ms)
            asr_mode = self._resolve_asr_mode(asr_provider, self._runtime_asr.get("mode"))

            self.asr_service = self._service_factory.create_asr_service(
                ASRServiceSpec(
                    provider=asr_provider,
                    sample_rate=settings.sample_rate,
                    mode=asr_mode,
                    language="auto",
                    api_key=str(asr_api_key).strip() if asr_api_key else None,
                    api_url=str(asr_api_url).strip() if asr_api_url else None,
                    model=str(asr_model).strip() if asr_model else None,
                    app_id=str(asr_app_id).strip() if asr_app_id else None,
                    resource_id=str(asr_resource_id).strip() if asr_resource_id else None,
                    cluster=str(asr_cluster).strip() if asr_cluster else None,
                    uid=str(asr_uid).strip() if asr_uid else None,
                    request_params=asr_request_params,
                    enable_interim=asr_enable_interim,
                    interim_interval_ms=asr_interim_interval,
                    min_audio_for_interim_ms=asr_min_audio_ms,
                    on_transcript=self._on_transcript_callback,
                )
            )
        # Re-resolve the effective mode, preferring what the created service
        # itself reports (an adapter may override the requested mode).
        self._asr_mode = self._resolve_asr_mode(
            self._runtime_asr.get("provider") or settings.asr_provider,
            getattr(self.asr_service, "mode", self._runtime_asr.get("mode")),
        )

        await self.asr_service.connect()

        logger.info("DuplexPipeline services connected (asr_mode={})", self._asr_mode)
        if not self._outbound_task or self._outbound_task.done():
            self._outbound_task = asyncio.create_task(self._outbound_loop())

    except Exception as e:
        logger.error(f"Failed to start pipeline: {e}")
        raise
|
||
|
||
async def emit_initial_greeting(self) -> None:
    """
    Emit opener after session activation.

    Ordering target:
    1) frontend receives `session.started` (shows connected/ready)
    2) frontend receives opener text event
    3) frontend receives opener audio events/chunks

    Idempotent: guarded by _initial_greeting_emitted. Four opener paths:
    tool-capable generated opener, manual opener tool calls, a (possibly
    LLM-generated) greeting text, or nothing when the user speaks first.
    """
    if self._initial_greeting_emitted:
        return

    # Set the guard before any await so concurrent calls cannot double-emit.
    self._initial_greeting_emitted = True
    if not self._bot_starts_first():
        return

    if self._generated_opener_enabled() and self._resolved_tool_schemas():
        # Run generated opener as a normal tool-capable assistant turn.
        # Use an empty user input so the opener can be driven by system prompt policy.
        if self._current_turn_task and not self._current_turn_task.done():
            logger.info("Skip initial generated opener: assistant turn already in progress")
            return
        self._current_turn_task = asyncio.create_task(self._handle_turn(""))
        logger.info("Initial generated opener started with tool-calling path")
        return

    manual_opener_execution: Dict[str, List[Dict[str, Any]]] = {"toolCalls": [], "toolResults": []}
    if not self._generated_opener_enabled() and self._resolved_manual_opener_tool_calls():
        self._start_turn()
        self._start_response()
        manual_opener_execution = await self._execute_manual_opener_tool_calls()

    greeting_to_speak = self.conversation.greeting
    if self._generated_opener_enabled():
        generated_greeting = await self._generate_runtime_greeting()
        if generated_greeting:
            # Persist the generated greeting so later turns see it as history.
            greeting_to_speak = generated_greeting
            self.conversation.greeting = generated_greeting

    if not greeting_to_speak:
        # No text to speak: if manual opener tools actually ran, let the LLM
        # produce a follow-up turn seeded with their results.
        if (
            not self._generated_opener_enabled()
            and manual_opener_execution.get("toolCalls")
            and not (self._current_turn_task and not self._current_turn_task.done())
        ):
            follow_up_context = self._build_manual_opener_follow_up_context(manual_opener_execution)
            self._current_turn_task = asyncio.create_task(
                self._handle_turn("", system_context=follow_up_context)
            )
            logger.info("Initial manual opener follow-up started")
        return

    if not self._current_turn_id:
        self._start_turn()
    if not self._current_response_id:
        self._start_response()
    # Text first (ordering target step 2) ...
    await self._send_event(
        ev(
            "assistant.response.final",
            text=greeting_to_speak,
            trackId=self.track_audio_out,
        ),
        priority=20,
    )
    await self.conversation.add_assistant_turn(greeting_to_speak)

    # Give client mic capture a short head start so opener can be interrupted immediately.
    await asyncio.sleep(self._OPENER_PRE_ROLL_MS / 1000.0)

    # ... then audio (ordering target step 3), preferring preloaded PCM.
    used_preloaded_audio = await self._play_preloaded_opener_audio()
    if self._tts_output_enabled() and not used_preloaded_audio:
        # Keep opener text ahead of opener voice start.
        await self._speak(greeting_to_speak, audio_event_priority=30)
|
||
|
||
async def _play_preloaded_opener_audio(self) -> bool:
    """
    Play opener audio from runtime metadata cache or YAML-configured local file.

    Returns True when preloaded audio is played successfully; False when
    TTS output is disabled, no preloaded PCM is available, or playback
    fails (the caller then falls back to live TTS).
    """
    if not self._tts_output_enabled():
        return False

    pcm_bytes = await self._load_preloaded_opener_pcm()
    if not pcm_bytes:
        return False

    try:
        # Re-enable the audio path in case a previous barge-in left it muted.
        self._drop_outbound_audio = False
        tts_id = self._start_tts()
        self._mark_client_playback_started(tts_id)
        await self._send_event(
            {
                **ev(
                    "output.audio.start",
                    trackId=self.track_audio_out,
                )
            },
            priority=30,
        )

        # Mark speaking while frames stream so barge-in detection applies.
        self._is_bot_speaking = True
        await self._send_audio(pcm_bytes, priority=50)
        await self._flush_audio_out_frames(priority=50)
        await self._send_event(
            {
                **ev(
                    "output.audio.end",
                    trackId=self.track_audio_out,
                )
            },
            priority=30,
        )
        return True
    except Exception as e:
        logger.warning(f"Failed to play preloaded opener audio, fallback to TTS: {e}")
        return False
    finally:
        # Always clear the speaking flag, even on failure.
        self._is_bot_speaking = False
|
||
|
||
async def _load_preloaded_opener_pcm(self) -> Optional[bytes]:
    """Fetch opener PCM: backend-provided URL first, then a local YAML-configured file.

    Returns raw pcm_s16le bytes, or None when neither source yields audio.
    WAV files are decoded/resampled; other files are assumed to already be
    raw 16 kHz mono s16le PCM.
    """
    # 1) Runtime metadata from backend config
    opener_audio = self._runtime_opener_audio if isinstance(self._runtime_opener_audio, dict) else {}
    if bool(opener_audio.get("enabled")) and bool(opener_audio.get("ready")):
        pcm_url = str(opener_audio.get("pcmUrl") or "").strip()
        if pcm_url:
            resolved_url = pcm_url
            # Relative URLs are resolved against the configured backend host.
            if pcm_url.startswith("/"):
                backend_url = str(settings.backend_url or "").strip().rstrip("/")
                if backend_url:
                    resolved_url = f"{backend_url}{pcm_url}"
            try:
                timeout = aiohttp.ClientTimeout(total=10)
                async with aiohttp.ClientSession(timeout=timeout) as session:
                    async with session.get(resolved_url) as resp:
                        resp.raise_for_status()
                        payload = await resp.read()
                        if payload:
                            return payload
            except Exception as e:
                # Fetch failure falls through to the local-file fallback below.
                logger.warning(f"Failed to fetch opener audio from backend ({resolved_url}): {e}")

    # 2) Standalone fallback via YAML
    opener_audio_file = str(settings.duplex_opener_audio_file or "").strip()
    if not opener_audio_file:
        return None
    path = Path(opener_audio_file)
    if not path.is_absolute():
        path = (Path.cwd() / path).resolve()
    if not path.exists() or not path.is_file():
        logger.warning(f"Configured opener audio file does not exist: {path}")
        return None
    try:
        raw = path.read_bytes()
        suffix = path.suffix.lower()
        if suffix == ".wav":
            pcm, _ = self._wav_to_pcm16_mono_16k(raw)
            return pcm
        # .pcm raw pcm_s16le 16k mono
        return raw
    except Exception as e:
        logger.warning(f"Failed to read opener audio file {path}: {e}")
        return None
|
||
|
||
def _wav_to_pcm16_mono_16k(self, wav_bytes: bytes) -> Tuple[bytes, int]:
|
||
with wave.open(io.BytesIO(wav_bytes), "rb") as wav_file:
|
||
channels = wav_file.getnchannels()
|
||
sample_width = wav_file.getsampwidth()
|
||
sample_rate = wav_file.getframerate()
|
||
nframes = wav_file.getnframes()
|
||
raw = wav_file.readframes(nframes)
|
||
|
||
if sample_width != 2:
|
||
raise ValueError(f"Unsupported WAV sample width: {sample_width * 8}bit")
|
||
if channels > 1:
|
||
raw = audioop.tomono(raw, sample_width, 0.5, 0.5)
|
||
if sample_rate != 16000:
|
||
raw, _ = audioop.ratecv(raw, sample_width, 1, sample_rate, 16000, None)
|
||
duration_ms = int((len(raw) / (16000 * 2)) * 1000)
|
||
return raw, duration_ms
|
||
|
||
async def _enqueue_outbound(self, kind: str, payload: Any, priority: int) -> None:
|
||
"""Queue outbound message with priority ordering."""
|
||
self._outbound_seq += 1
|
||
await self._outbound_q.put((priority, self._outbound_seq, kind, payload))
|
||
|
||
async def _send_event(self, event: Dict[str, Any], priority: int = 20) -> None:
|
||
await self._enqueue_outbound("event", self._envelope_event(event), priority)
|
||
|
||
async def _send_audio(self, pcm_bytes: bytes, priority: int = 50) -> None:
|
||
if not pcm_bytes:
|
||
return
|
||
self._audio_out_frame_buffer += pcm_bytes
|
||
while len(self._audio_out_frame_buffer) >= self._PCM_FRAME_BYTES:
|
||
frame = self._audio_out_frame_buffer[: self._PCM_FRAME_BYTES]
|
||
self._audio_out_frame_buffer = self._audio_out_frame_buffer[self._PCM_FRAME_BYTES :]
|
||
await self._enqueue_outbound("audio", frame, priority)
|
||
|
||
async def _flush_audio_out_frames(self, priority: int = 50) -> None:
|
||
"""Flush remaining outbound audio as one padded 20ms PCM frame."""
|
||
if not self._audio_out_frame_buffer:
|
||
return
|
||
tail = self._audio_out_frame_buffer
|
||
self._audio_out_frame_buffer = b""
|
||
if len(tail) < self._PCM_FRAME_BYTES:
|
||
tail = tail + (b"\x00" * (self._PCM_FRAME_BYTES - len(tail)))
|
||
await self._enqueue_outbound("audio", tail, priority)
|
||
|
||
async def _emit_transcript_delta(self, text: str) -> None:
    """Push a transcript.delta event for the inbound audio track."""
    payload = {
        **ev(
            "transcript.delta",
            trackId=self.track_audio_in,
            text=text,
        )
    }
    await self._send_event(payload, priority=30)
|
||
|
||
async def _emit_llm_delta(
    self,
    text: str,
    *,
    turn_id: Optional[str] = None,
    utterance_id: Optional[str] = None,
    response_id: Optional[str] = None,
) -> None:
    """Emit an assistant.response.delta event, tagging any ids that are set."""
    event = {
        **ev(
            "assistant.response.delta",
            trackId=self.track_audio_out,
            text=text,
        )
    }
    # Only attach identifiers that are actually present (truthy).
    for key, value in (
        ("turn_id", turn_id),
        ("utterance_id", utterance_id),
        ("response_id", response_id),
    ):
        if value:
            event[key] = value
    await self._send_event(event, priority=20)
|
||
|
||
async def _flush_pending_llm_delta(
|
||
self,
|
||
*,
|
||
turn_id: Optional[str] = None,
|
||
utterance_id: Optional[str] = None,
|
||
response_id: Optional[str] = None,
|
||
) -> None:
|
||
if not self._pending_llm_delta:
|
||
return
|
||
chunk = self._pending_llm_delta
|
||
self._pending_llm_delta = ""
|
||
self._last_llm_delta_emit_ms = time.monotonic() * 1000.0
|
||
await self._emit_llm_delta(
|
||
chunk,
|
||
turn_id=turn_id,
|
||
utterance_id=utterance_id,
|
||
response_id=response_id,
|
||
)
|
||
|
||
async def _outbound_loop(self) -> None:
|
||
"""Single sender loop that enforces priority for interrupt events."""
|
||
while True:
|
||
_priority, _seq, kind, payload = await self._outbound_q.get()
|
||
try:
|
||
if kind == "stop":
|
||
return
|
||
if kind == "audio":
|
||
if self._drop_outbound_audio:
|
||
continue
|
||
await self.transport.send_audio(payload)
|
||
elif kind == "event":
|
||
await self.transport.send_event(payload)
|
||
except Exception as e:
|
||
logger.error(f"Outbound send error ({kind}): {e}")
|
||
finally:
|
||
self._outbound_q.task_done()
|
||
|
||
async def process_audio(self, pcm_bytes: bytes) -> None:
    """
    Process incoming audio chunk.

    This is the main entry point for audio from the user. Under a single
    lock it runs, in order: pre-speech ring buffering, VAD, barge-in
    tracking, gated ASR capture, and EOU detection (which kicks off the
    assistant turn). Errors are logged, never raised.

    Args:
        pcm_bytes: PCM audio data (16-bit, mono, 16kHz)
    """
    if not self._running:
        return

    try:
        async with self._process_lock:
            if pcm_bytes:
                # Ring buffer of recent audio used to prime ASR with
                # pre-speech context (see _start_asr_capture).
                self._pre_speech_buffer += pcm_bytes
                if len(self._pre_speech_buffer) > self._asr_pre_speech_bytes:
                    self._pre_speech_buffer = self._pre_speech_buffer[-self._asr_pre_speech_bytes:]

            # 1. Process through VAD
            vad_result = self.vad_processor.process(pcm_bytes, settings.chunk_size_ms)

            vad_status = "Silence"
            if vad_result:
                event_type, probability = vad_result
                vad_status = "Speech" if event_type == "speaking" else "Silence"

                # Emit VAD event
                await self.event_bus.publish(event_type, {
                    "trackId": self.track_audio_in,
                    "probability": probability
                })
                await self._send_event(
                    ev(
                        "input.speech_started" if event_type == "speaking" else "input.speech_stopped",
                        trackId=self.track_audio_in,
                        probability=probability,
                    ),
                    priority=30,
                )
            else:
                # No state change - keep previous status
                vad_status = self._last_vad_status

            # Update state based on VAD
            if vad_status == "Speech" and self._last_vad_status != "Speech":
                await self._on_speech_start()

            self._last_vad_status = vad_status

            # 2. Check for barge-in (user speaking while bot speaking)
            # Filter false interruptions by requiring minimum speech duration
            if self._is_bot_speaking and self._barge_in_enabled():
                if vad_status == "Speech":
                    # User is speaking while bot is speaking
                    self._barge_in_silence_frames = 0  # Reset silence counter

                    if self._barge_in_speech_start_time is None:
                        # Start tracking speech duration
                        self._barge_in_speech_start_time = time.time()
                        self._barge_in_speech_frames = 1
                        logger.debug("Potential barge-in detected, tracking duration...")
                    else:
                        self._barge_in_speech_frames += 1
                        # Check if speech duration exceeds threshold
                        speech_duration_ms = (time.time() - self._barge_in_speech_start_time) * 1000
                        if speech_duration_ms >= self._resolved_barge_in_min_duration_ms():
                            logger.info(f"Barge-in confirmed after {speech_duration_ms:.0f}ms of speech ({self._barge_in_speech_frames} frames)")
                            await self._handle_barge_in()
                else:
                    # Silence frame during potential barge-in
                    if self._barge_in_speech_start_time is not None:
                        self._barge_in_silence_frames += 1
                        # Allow brief silence gaps (VAD flickering)
                        if self._barge_in_silence_frames > self._barge_in_silence_tolerance_frames():
                            # Too much silence - reset barge-in tracking
                            logger.debug(f"Barge-in cancelled after {self._barge_in_silence_frames} silence frames")
                            self._barge_in_speech_start_time = None
                            self._barge_in_speech_frames = 0
                            self._barge_in_silence_frames = 0
            elif self._is_bot_speaking and not self._barge_in_enabled():
                # Barge-in disabled: keep trackers zeroed while bot speaks.
                self._barge_in_speech_start_time = None
                self._barge_in_speech_frames = 0
                self._barge_in_silence_frames = 0

            # 3. Buffer audio for ASR.
            # Gate ASR startup by a short speech-duration threshold to reduce
            # false positives from micro noises, then always close the turn
            # by EOU once ASR has started.
            just_started_asr = False
            if vad_status == "Speech" and not self._asr_capture_active:
                self._pending_speech_audio += pcm_bytes
                pending_ms = (len(self._pending_speech_audio) / (settings.sample_rate * 2)) * 1000.0
                if pending_ms >= self._asr_start_min_speech_ms:
                    await self._start_asr_capture()
                    just_started_asr = True

            if self._asr_capture_active:
                # When the capture was started this very chunk, the chunk was
                # already forwarded inside _start_asr_capture via the
                # pending-speech buffer, so skip re-sending it here.
                if not just_started_asr:
                    self._audio_buffer += pcm_bytes
                    if len(self._audio_buffer) > self._max_audio_buffer_bytes:
                        # Keep only the most recent audio to cap memory usage
                        self._audio_buffer = self._audio_buffer[-self._max_audio_buffer_bytes:]
                    await self.asr_service.send_audio(pcm_bytes)

            # For SiliconFlow ASR, trigger interim transcription periodically
            # The service handles timing internally via start_interim_transcription()

            # 4. Check for End of Utterance - this triggers LLM response
            if self.eou_detector.process(vad_status, force_eligible=self._asr_capture_active):
                await self._on_end_of_utterance()
            elif (
                self._asr_capture_active
                and self._asr_capture_started_ms > 0.0
                and (time.monotonic() * 1000.0 - self._asr_capture_started_ms) >= self._ASR_CAPTURE_MAX_MS
            ):
                # Watchdog: never let a capture run unbounded if EOU misfires.
                logger.warning(
                    f"[EOU] Force finalize after ASR capture timeout: {self._ASR_CAPTURE_MAX_MS}ms"
                )
                await self._on_end_of_utterance()
            elif (
                vad_status == "Silence"
                and not self.eou_detector.is_speaking
                and not self._asr_capture_active
                and self.conversation.state == ConversationState.LISTENING
            ):
                # Speech was too short to pass ASR gate; reset turn so next
                # utterance can start cleanly.
                self._pending_speech_audio = b""
                self._audio_buffer = b""
                self._last_sent_transcript = ""
                await self.conversation.set_state(ConversationState.IDLE)

    except Exception as e:
        logger.error(f"Pipeline audio processing error: {e}", exc_info=True)
|
||
|
||
async def process_text(self, text: str) -> None:
    """
    Process text input (chat command).

    Allows direct text input to bypass ASR. Cancels any in-flight
    assistant speech first, then starts a fresh assistant turn.

    Args:
        text: User text input
    """
    if not self._running:
        return

    logger.info(f"Processing text input: {text[:50]}...")

    # Cancel any current speaking
    await self._stop_current_speech()

    self._start_turn()
    self._finalize_utterance()

    # Start new turn
    await self.conversation.end_user_turn(text)
    self._current_turn_task = asyncio.create_task(self._handle_turn(text))
|
||
|
||
async def interrupt(self) -> None:
    """Manually interrupt current bot speech (delegates to the barge-in handler)."""
    await self._handle_barge_in()
|
||
|
||
async def _on_transcript_callback(self, text: str, is_final: bool) -> None:
    """
    Callback for ASR transcription results.

    Streams transcription to client for display. Finals are always sent;
    interims are deduplicated and throttled by _ASR_DELTA_THROTTLE_MS.

    Args:
        text: Transcribed text
        is_final: Whether this is the final transcription
    """
    if not is_final and not self._asr_interim_enabled():
        return

    # Avoid sending duplicate transcripts
    if text == self._last_sent_transcript and not is_final:
        return

    now_ms = time.monotonic() * 1000.0
    self._last_sent_transcript = text

    if is_final:
        # Final result clears all interim tracking state.
        self._latest_asr_interim_text = ""
        self._pending_transcript_delta = ""
        self._last_transcript_delta_emit_ms = 0.0
        await self._send_event(
            {
                **ev(
                    "transcript.final",
                    trackId=self.track_audio_in,
                    text=text,
                )
            },
            priority=30,
        )
        logger.debug(f"Sent transcript (final): {text[:50]}...")
        return

    # Interim path: keep the latest text as EOU fallback and throttle deltas.
    self._latest_asr_interim_text = text
    self._pending_transcript_delta = text
    should_emit = (
        self._last_transcript_delta_emit_ms <= 0.0
        or now_ms - self._last_transcript_delta_emit_ms >= self._ASR_DELTA_THROTTLE_MS
    )
    if should_emit and self._pending_transcript_delta:
        delta = self._pending_transcript_delta
        self._pending_transcript_delta = ""
        self._last_transcript_delta_emit_ms = now_ms
        await self._emit_transcript_delta(delta)

    if not is_final:
        logger.info(f"[ASR] ASR interim: {text[:100]}")
    logger.debug(f"Sent transcript (interim): {text[:50]}...")
||
|
||
async def _on_speech_start(self) -> None:
    """Handle user starting to speak.

    Only acts from IDLE/INTERRUPTED: opens a new user turn and resets all
    per-utterance buffers, EOU state, and the ASR capture gate.
    """
    if self.conversation.state in (ConversationState.IDLE, ConversationState.INTERRUPTED):
        self._start_turn()
        self._finalize_utterance()
        await self.conversation.start_user_turn()
        self._audio_buffer = b""
        self._last_sent_transcript = ""
        self._latest_asr_interim_text = ""
        self.eou_detector.reset()
        self._asr_capture_active = False
        self._asr_capture_started_ms = 0.0
        self._pending_speech_audio = b""

        # Clear any residue from the previous utterance in the ASR backend.
        if self._asr_mode == "streaming":
            self._streaming_asr().clear_utterance()
        else:
            self._offline_asr().clear_buffer()

        logger.debug("User speech started")
|
||
|
||
async def _start_asr_capture(self) -> None:
    """Start ASR capture for the current turn after min speech gate passes.

    Opens the streaming utterance (or interim transcription for offline
    ASR), then sends the pre-speech ring buffer plus the gated pending
    speech so the utterance onset is not lost.
    """
    if self._asr_capture_active:
        return

    if self._asr_mode == "streaming":
        await self._streaming_asr().begin_utterance()
    else:
        if self._asr_interim_enabled():
            await self._offline_asr().start_interim_transcription()

    # Prime ASR with a short pre-speech context window so the utterance
    # start isn't lost while waiting for VAD to transition to Speech.
    pre_roll = self._pre_speech_buffer
    # _pre_speech_buffer already includes current speech frames; avoid
    # duplicating onset audio when we append pending speech below.
    if self._pending_speech_audio and len(pre_roll) > len(self._pending_speech_audio):
        pre_roll = pre_roll[:-len(self._pending_speech_audio)]
    elif self._pending_speech_audio:
        pre_roll = b""
    capture_audio = pre_roll + self._pending_speech_audio
    if capture_audio:
        await self.asr_service.send_audio(capture_audio)
        # Seed the rolling buffer with what was just sent, capped for memory.
        self._audio_buffer = capture_audio[-self._max_audio_buffer_bytes:]

    self._asr_capture_active = True
    self._asr_capture_started_ms = time.monotonic() * 1000.0
    logger.debug(
        f"ASR capture started after speech gate ({self._asr_start_min_speech_ms}ms), "
        f"capture={len(capture_audio)} bytes"
    )
|
||
|
||
async def _on_end_of_utterance(self) -> None:
    """Handle end of user utterance.

    Finalizes ASR (streaming or offline), emits transcript.final when not
    already sent by the ASR callback, clears all per-utterance state, and
    launches the assistant turn. Empty transcriptions reset to IDLE.
    """
    if self.conversation.state not in (ConversationState.LISTENING, ConversationState.INTERRUPTED):
        # Prevent a stale ASR capture watchdog from repeatedly forcing EOU
        # once the conversation has already moved past user-listening states.
        self._asr_capture_active = False
        self._asr_capture_started_ms = 0.0
        self._pending_speech_audio = b""
        return

    user_text = ""
    if self._asr_mode == "streaming":
        streaming_asr = self._streaming_asr()
        await streaming_asr.end_utterance()
        user_text = await streaming_asr.wait_for_final_transcription(
            timeout_ms=self._ASR_STREAM_FINAL_TIMEOUT_MS
        )
        # Timeout/empty final: fall back to the last interim we saw.
        if not user_text.strip():
            user_text = self._latest_asr_interim_text
    else:
        # Add a tiny trailing silence tail to stabilize final-token decoding.
        if self._asr_final_tail_bytes > 0:
            final_tail = b"\x00" * self._asr_final_tail_bytes
            await self.asr_service.send_audio(final_tail)
        await self._offline_asr().stop_interim_transcription()
        user_text = await self._offline_asr().get_final_transcription()

    # Skip if no meaningful text
    if not user_text or not user_text.strip():
        logger.debug("[EOU] Detected but no transcription - skipping")
        # Reset for next utterance
        self._audio_buffer = b""
        self._last_sent_transcript = ""
        self._latest_asr_interim_text = ""
        self._asr_capture_active = False
        self._asr_capture_started_ms = 0.0
        self._pending_speech_audio = b""
        # Return to idle; don't force LISTENING which causes buffering on silence
        await self.conversation.set_state(ConversationState.IDLE)
        return

    logger.info(f"[EOU] Detected - user said: {user_text[:100]}...")
    self._finalize_utterance()

    # For ASR backends that already emitted final via callback,
    # avoid duplicating transcript.final on EOU.
    if user_text != self._last_sent_transcript:
        await self._send_event({
            **ev(
                "transcript.final",
                trackId=self.track_audio_in,
                text=user_text,
            )
        }, priority=25)

    # Clear buffers
    self._audio_buffer = b""
    self._last_sent_transcript = ""
    self._latest_asr_interim_text = ""
    self._pending_transcript_delta = ""
    self._last_transcript_delta_emit_ms = 0.0
    self._asr_capture_active = False
    self._asr_capture_started_ms = 0.0
    self._pending_speech_audio = b""

    # Process the turn - trigger LLM response
    # Cancel any existing turn to avoid overlapping assistant responses
    await self._stop_current_speech()
    await self.conversation.end_user_turn(user_text)
    self._current_turn_task = asyncio.create_task(self._handle_turn(user_text))
|
||
|
||
def _resolved_knowledge_config(self) -> Dict[str, Any]:
|
||
cfg: Dict[str, Any] = {}
|
||
if isinstance(self._runtime_knowledge, dict):
|
||
cfg.update(self._runtime_knowledge)
|
||
kb_id = self._runtime_knowledge_base_id or str(
|
||
cfg.get("kbId") or cfg.get("knowledgeBaseId") or ""
|
||
).strip()
|
||
if kb_id:
|
||
cfg["kbId"] = kb_id
|
||
cfg.setdefault("enabled", True)
|
||
return cfg
|
||
|
||
def _resolved_tool_schemas(self) -> List[Dict[str, Any]]:
|
||
schemas: List[Dict[str, Any]] = []
|
||
seen: set[str] = set()
|
||
for item in self._runtime_tools:
|
||
if isinstance(item, str):
|
||
tool_name = self._normalize_tool_name(item)
|
||
if not tool_name or tool_name in seen:
|
||
continue
|
||
seen.add(tool_name)
|
||
base = self._DEFAULT_TOOL_SCHEMAS.get(tool_name)
|
||
if base:
|
||
schemas.append(
|
||
{
|
||
"type": "function",
|
||
"function": {
|
||
"name": base["name"],
|
||
"description": base.get("description") or "",
|
||
"parameters": base.get("parameters") or {"type": "object", "properties": {}},
|
||
},
|
||
}
|
||
)
|
||
else:
|
||
schemas.append(
|
||
{
|
||
"type": "function",
|
||
"function": {
|
||
"name": tool_name,
|
||
"description": f"Execute tool '{tool_name}'",
|
||
"parameters": {"type": "object", "properties": {}},
|
||
},
|
||
}
|
||
)
|
||
continue
|
||
|
||
if not isinstance(item, dict):
|
||
continue
|
||
|
||
fn = item.get("function")
|
||
if isinstance(fn, dict) and fn.get("name"):
|
||
fn_name = self._normalize_tool_name(fn.get("name"))
|
||
if not fn_name or fn_name in seen:
|
||
continue
|
||
seen.add(fn_name)
|
||
schemas.append(
|
||
{
|
||
"type": "function",
|
||
"function": {
|
||
"name": fn_name,
|
||
"description": str(fn.get("description") or item.get("description") or ""),
|
||
"parameters": fn.get("parameters") or {"type": "object", "properties": {}},
|
||
},
|
||
}
|
||
)
|
||
continue
|
||
|
||
if item.get("name"):
|
||
item_name = self._normalize_tool_name(item.get("name"))
|
||
if not item_name or item_name in seen:
|
||
continue
|
||
seen.add(item_name)
|
||
schemas.append(
|
||
{
|
||
"type": "function",
|
||
"function": {
|
||
"name": item_name,
|
||
"description": str(item.get("description") or ""),
|
||
"parameters": item.get("parameters") or {"type": "object", "properties": {}},
|
||
},
|
||
}
|
||
)
|
||
return schemas
|
||
|
||
def _resolved_tool_executor_map(self) -> Dict[str, str]:
|
||
result: Dict[str, str] = {}
|
||
for item in self._runtime_tools:
|
||
if isinstance(item, str):
|
||
name = self._normalize_tool_name(item)
|
||
if name in self._DEFAULT_CLIENT_EXECUTORS:
|
||
result[name] = "client"
|
||
continue
|
||
if not isinstance(item, dict):
|
||
continue
|
||
fn = item.get("function")
|
||
if isinstance(fn, dict) and fn.get("name"):
|
||
name = self._normalize_tool_name(fn.get("name"))
|
||
else:
|
||
name = self._normalize_tool_name(item.get("name"))
|
||
if not name:
|
||
continue
|
||
executor = str(item.get("executor") or item.get("run_on") or "").strip().lower()
|
||
if executor in {"client", "server"}:
|
||
result[name] = executor
|
||
return result
|
||
|
||
def _resolved_tool_default_args_map(self) -> Dict[str, Dict[str, Any]]:
|
||
result: Dict[str, Dict[str, Any]] = {}
|
||
for item in self._runtime_tools:
|
||
if not isinstance(item, dict):
|
||
continue
|
||
fn = item.get("function")
|
||
if isinstance(fn, dict) and fn.get("name"):
|
||
name = self._normalize_tool_name(fn.get("name"))
|
||
else:
|
||
name = self._normalize_tool_name(item.get("name"))
|
||
if not name:
|
||
continue
|
||
raw_defaults = item.get("defaultArgs")
|
||
if raw_defaults is None:
|
||
raw_defaults = item.get("default_args")
|
||
if isinstance(raw_defaults, dict):
|
||
result[name] = dict(raw_defaults)
|
||
return result
|
||
|
||
def _resolved_tool_wait_for_response_map(self) -> Dict[str, bool]:
|
||
result: Dict[str, bool] = {}
|
||
for item in self._runtime_tools:
|
||
if not isinstance(item, dict):
|
||
continue
|
||
fn = item.get("function")
|
||
if isinstance(fn, dict) and fn.get("name"):
|
||
name = self._normalize_tool_name(fn.get("name"))
|
||
else:
|
||
name = self._normalize_tool_name(item.get("name"))
|
||
if not name:
|
||
continue
|
||
raw_wait = item.get("waitForResponse")
|
||
if raw_wait is None:
|
||
raw_wait = item.get("wait_for_response")
|
||
if isinstance(raw_wait, bool):
|
||
result[name] = raw_wait
|
||
return result
|
||
|
||
def _resolved_tool_id_map(self) -> Dict[str, str]:
|
||
result: Dict[str, str] = {}
|
||
for item in self._runtime_tools:
|
||
if not isinstance(item, dict):
|
||
continue
|
||
fn = item.get("function")
|
||
if isinstance(fn, dict) and fn.get("name"):
|
||
alias = self._normalize_tool_name(fn.get("name"))
|
||
else:
|
||
alias = self._normalize_tool_name(item.get("name"))
|
||
if not alias:
|
||
continue
|
||
tool_id = self._normalize_tool_name(item.get("toolId") or item.get("tool_id") or alias)
|
||
if tool_id:
|
||
result[alias] = tool_id
|
||
return result
|
||
|
||
def _resolved_tool_display_name_map(self) -> Dict[str, str]:
|
||
result: Dict[str, str] = {}
|
||
for item in self._runtime_tools:
|
||
if not isinstance(item, dict):
|
||
continue
|
||
fn = item.get("function")
|
||
if isinstance(fn, dict) and fn.get("name"):
|
||
name = self._normalize_tool_name(fn.get("name"))
|
||
else:
|
||
name = self._normalize_tool_name(item.get("name"))
|
||
if not name:
|
||
continue
|
||
display_name = str(
|
||
item.get("displayName")
|
||
or item.get("display_name")
|
||
or name
|
||
).strip()
|
||
if display_name:
|
||
result[name] = display_name
|
||
tool_id = self._normalize_tool_name(item.get("toolId") or item.get("tool_id") or "")
|
||
if tool_id:
|
||
result[tool_id] = display_name
|
||
return result
|
||
|
||
def _resolved_tool_allowlist(self) -> List[str]:
|
||
names: set[str] = set()
|
||
for item in self._runtime_tools:
|
||
if isinstance(item, str):
|
||
name = self._normalize_tool_name(item)
|
||
if name:
|
||
names.add(name)
|
||
continue
|
||
if not isinstance(item, dict):
|
||
continue
|
||
fn = item.get("function")
|
||
if isinstance(fn, dict) and fn.get("name"):
|
||
names.add(self._normalize_tool_name(fn.get("name")))
|
||
elif item.get("name"):
|
||
names.add(self._normalize_tool_name(item.get("name")))
|
||
return sorted([name for name in names if name])
|
||
|
||
def _resolved_manual_opener_tool_calls(self) -> List[Dict[str, Any]]:
|
||
result: List[Dict[str, Any]] = []
|
||
for item in self._runtime_manual_opener_tool_calls:
|
||
if not isinstance(item, dict):
|
||
continue
|
||
tool_name = self._normalize_tool_name(str(
|
||
item.get("toolName")
|
||
or item.get("tool_name")
|
||
or item.get("name")
|
||
or ""
|
||
).strip())
|
||
if not tool_name:
|
||
continue
|
||
args_raw = item.get("arguments")
|
||
args: Dict[str, Any] = {}
|
||
if isinstance(args_raw, dict):
|
||
args = dict(args_raw)
|
||
elif isinstance(args_raw, str):
|
||
text_value = args_raw.strip()
|
||
if text_value:
|
||
try:
|
||
parsed = json.loads(text_value)
|
||
if isinstance(parsed, dict):
|
||
args = parsed
|
||
except Exception:
|
||
logger.warning(f"[OpenerTool] ignore invalid JSON args for tool={tool_name}")
|
||
result.append({"toolName": tool_name, "arguments": args})
|
||
return result[:8]
|
||
|
||
def _tool_name(self, tool_call: Dict[str, Any]) -> str:
|
||
fn = tool_call.get("function")
|
||
if isinstance(fn, dict):
|
||
return self._normalize_tool_name(fn.get("name"))
|
||
return ""
|
||
|
||
def _tool_id_for_name(self, tool_name: str) -> str:
|
||
normalized = self._normalize_tool_name(tool_name)
|
||
return self._normalize_tool_name(self._runtime_tool_id_map.get(normalized) or normalized)
|
||
|
||
def _tool_display_name(self, tool_name: str) -> str:
|
||
normalized = self._normalize_tool_name(tool_name)
|
||
return str(self._runtime_tool_display_names.get(normalized) or normalized).strip()
|
||
|
||
def _tool_wait_for_response(self, tool_name: str) -> bool:
|
||
normalized = self._normalize_tool_name(tool_name)
|
||
return bool(self._runtime_tool_wait_for_response.get(normalized, False))
|
||
|
||
def _tool_executor(self, tool_call: Dict[str, Any]) -> str:
|
||
name = self._tool_name(tool_call)
|
||
if name and name in self._runtime_tool_executor:
|
||
return self._runtime_tool_executor[name]
|
||
# Default to server execution unless explicitly marked as client.
|
||
return "server"
|
||
|
||
def _tool_arguments(self, tool_call: Dict[str, Any]) -> Dict[str, Any]:
|
||
fn = tool_call.get("function")
|
||
if not isinstance(fn, dict):
|
||
return {}
|
||
raw = fn.get("arguments")
|
||
if isinstance(raw, dict):
|
||
return raw
|
||
if isinstance(raw, str) and raw.strip():
|
||
try:
|
||
parsed = json.loads(raw)
|
||
return parsed if isinstance(parsed, dict) else {"raw": raw}
|
||
except Exception:
|
||
return {"raw": raw}
|
||
return {}
|
||
|
||
def _apply_tool_default_args(self, tool_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
|
||
normalized_tool_name = self._normalize_tool_name(tool_name)
|
||
defaults = self._runtime_tool_default_args.get(normalized_tool_name)
|
||
if not isinstance(defaults, dict) or not defaults:
|
||
return args
|
||
merged = dict(defaults)
|
||
if isinstance(args, dict):
|
||
merged.update(args)
|
||
return merged
|
||
|
||
def _build_manual_opener_follow_up_context(self, payload: Dict[str, List[Dict[str, Any]]]) -> str:
|
||
tool_calls = payload.get("toolCalls") if isinstance(payload.get("toolCalls"), list) else []
|
||
tool_results = payload.get("toolResults") if isinstance(payload.get("toolResults"), list) else []
|
||
return (
|
||
"Initial opener tool calls already executed. Continue with a natural assistant follow-up. "
|
||
"If tool results include user selections or values, use them in your response. "
|
||
"Never expose internal tool ids or raw payloads.\n"
|
||
f"opener_tool_calls={json.dumps(tool_calls, ensure_ascii=False)}\n"
|
||
f"opener_tool_results={json.dumps(tool_results, ensure_ascii=False)}"
|
||
)
|
||
|
||
    async def _execute_manual_opener_tool_calls(self) -> Dict[str, List[Dict[str, Any]]]:
        """Execute the configured opener tool calls before the first turn.

        For each sanitized call: emits an ``assistant.tool_call`` event,
        then either waits for a client-side result (only when the tool is
        marked wait-for-response) or runs the server executor with a
        timeout. Returns ``{"toolCalls": [...], "toolResults": [...]}``
        suitable for `_build_manual_opener_follow_up_context`.
        """
        calls = self._resolved_manual_opener_tool_calls()
        tool_calls_for_context: List[Dict[str, Any]] = []
        tool_results_for_context: List[Dict[str, Any]] = []
        if not calls:
            return {"toolCalls": tool_calls_for_context, "toolResults": tool_results_for_context}

        for call in calls:
            tool_name = str(call.get("toolName") or "").strip()
            if not tool_name:
                continue
            tool_id = self._tool_id_for_name(tool_name)
            tool_display_name = self._tool_display_name(tool_name) or tool_name
            tool_arguments = call.get("arguments") if isinstance(call.get("arguments"), dict) else {}
            merged_tool_arguments = self._apply_tool_default_args(tool_name, tool_arguments)
            # Synthetic call id: opener calls never come from the LLM.
            call_id = f"call_opener_{uuid.uuid4().hex[:10]}"
            wait_for_response = self._tool_wait_for_response(tool_name)
            # OpenAI-shaped tool call; arguments serialized as JSON text.
            tool_call = {
                "id": call_id,
                "type": "function",
                "function": {
                    "name": tool_name,
                    "arguments": json.dumps(merged_tool_arguments, ensure_ascii=False),
                },
            }
            executor = self._tool_executor(tool_call)
            tool_calls_for_context.append(
                {
                    "tool_call_id": call_id,
                    "tool_name": tool_name,
                    "tool_id": tool_id,
                    "tool_display_name": tool_display_name,
                    "arguments": merged_tool_arguments,
                    "wait_for_response": wait_for_response,
                    "executor": executor,
                }
            )

            # Announce the call to the client before executing anywhere.
            await self._send_event(
                {
                    **ev(
                        "assistant.tool_call",
                        trackId=self.track_audio_out,
                        tool_call_id=call_id,
                        tool_name=tool_name,
                        tool_id=tool_id,
                        tool_display_name=tool_display_name,
                        wait_for_response=wait_for_response,
                        arguments=merged_tool_arguments,
                        executor=executor,
                        timeout_ms=int(self._TOOL_WAIT_TIMEOUT_SECONDS * 1000),
                        tool_call={**tool_call, "executor": executor, "wait_for_response": wait_for_response},
                    )
                },
                priority=22,
            )
            logger.info(
                f"[OpenerTool] execute name={tool_name} call_id={call_id} executor={executor} "
                f"wait_for_response={wait_for_response}"
            )

            if executor == "client":
                # Track the pending id so unsolicited results can be rejected.
                self._pending_client_tool_call_ids.add(call_id)
                # Fire-and-forget unless the tool requires a synchronous result.
                if wait_for_response:
                    result = await self._wait_for_single_tool_result(call_id)
                    await self._emit_tool_result(result, source="client")
                    tool_results_for_context.append(result if isinstance(result, dict) else {"tool_call_id": call_id})
                continue

            # Server path: rewrite the function name to the backend tool id
            # on a copy so the context copy keeps the user-facing alias.
            call_for_executor = dict(tool_call)
            fn_for_executor = (
                dict(call_for_executor.get("function"))
                if isinstance(call_for_executor.get("function"), dict)
                else None
            )
            if isinstance(fn_for_executor, dict):
                fn_for_executor["name"] = tool_id
                call_for_executor["function"] = fn_for_executor
            try:
                result = await asyncio.wait_for(
                    self._server_tool_executor(call_for_executor),
                    timeout=self._SERVER_TOOL_TIMEOUT_SECONDS,
                )
            except asyncio.TimeoutError:
                # Synthesize a 504-style result so downstream handling is uniform.
                result = {
                    "tool_call_id": call_id,
                    "name": tool_name,
                    "output": {"message": "server tool timeout"},
                    "status": {"code": 504, "message": "server_tool_timeout"},
                }
            await self._emit_tool_result(result, source="server")
            tool_results_for_context.append(result if isinstance(result, dict) else {"tool_call_id": call_id})

        return {"toolCalls": tool_calls_for_context, "toolResults": tool_results_for_context}
|
||
|
||
def _normalize_tool_result(self, result: Dict[str, Any]) -> Dict[str, Any]:
|
||
status = result.get("status") if isinstance(result.get("status"), dict) else {}
|
||
status_code = int(status.get("code") or 0) if status else 0
|
||
status_message = str(status.get("message") or "") if status else ""
|
||
tool_call_id = str(result.get("tool_call_id") or result.get("id") or "")
|
||
tool_name = str(result.get("name") or "unknown_tool")
|
||
tool_display_name = self._tool_display_name(tool_name) or tool_name
|
||
ok = bool(200 <= status_code < 300)
|
||
retryable = status_code >= 500 or status_code in {429, 408}
|
||
error: Optional[Dict[str, Any]] = None
|
||
if not ok:
|
||
error = {
|
||
"code": status_code or 500,
|
||
"message": status_message or "tool_execution_failed",
|
||
"retryable": retryable,
|
||
}
|
||
return {
|
||
"tool_call_id": tool_call_id,
|
||
"tool_name": tool_name,
|
||
"tool_display_name": tool_display_name,
|
||
"ok": ok,
|
||
"error": error,
|
||
"status": {"code": status_code, "message": status_message},
|
||
}
|
||
|
||
    async def _emit_tool_result(self, result: Dict[str, Any], source: str) -> None:
        """Log a tool result and publish it as an ``assistant.tool_result`` event.

        Args:
            result: Raw result dict from the client or server executor.
            source: Where the tool ran — "client" or "server" (forwarded
                verbatim into the event).
        """
        tool_name = str(result.get("name") or "unknown_tool")
        # NOTE(review): tool_display_name below is computed but not used —
        # the event uses the display name from _normalize_tool_result instead.
        tool_display_name = self._tool_display_name(tool_name) or tool_name
        call_id = str(result.get("tool_call_id") or result.get("id") or "")
        status = result.get("status") if isinstance(result.get("status"), dict) else {}
        status_code = int(status.get("code") or 0) if status else 0
        status_message = str(status.get("message") or "") if status else ""
        logger.info(
            f"[Tool] emit result source={source} name={tool_name} call_id={call_id} "
            f"status={status_code} {status_message}".strip()
        )
        # Event carries both the normalized summary fields and the raw result.
        normalized = self._normalize_tool_result(result)
        await self._send_event(
            {
                **ev(
                    "assistant.tool_result",
                    trackId=self.track_audio_out,
                    source=source,
                    tool_call_id=normalized["tool_call_id"],
                    tool_name=normalized["tool_name"],
                    tool_display_name=normalized["tool_display_name"],
                    ok=normalized["ok"],
                    error=normalized["error"],
                    result=result,
                )
            },
            priority=22,
        )
|
||
|
||
async def handle_tool_call_results(self, results: List[Dict[str, Any]]) -> None:
|
||
"""Handle client tool execution results."""
|
||
if not isinstance(results, list):
|
||
return
|
||
|
||
for item in results:
|
||
if not isinstance(item, dict):
|
||
continue
|
||
call_id = str(item.get("tool_call_id") or item.get("id") or "").strip()
|
||
if not call_id:
|
||
continue
|
||
if self._pending_client_tool_call_ids and call_id not in self._pending_client_tool_call_ids:
|
||
logger.warning(f"[Tool] ignore unsolicited client result call_id={call_id}")
|
||
continue
|
||
if call_id in self._completed_tool_call_ids:
|
||
logger.debug(f"[Tool] ignore duplicate client result call_id={call_id}")
|
||
continue
|
||
status = item.get("status") if isinstance(item.get("status"), dict) else {}
|
||
status_code = int(status.get("code") or 0) if status else 0
|
||
status_message = str(status.get("message") or "") if status else ""
|
||
tool_name = str(item.get("name") or "unknown_tool")
|
||
logger.info(
|
||
f"[Tool] received client result name={tool_name} call_id={call_id} "
|
||
f"status={status_code} {status_message}".strip()
|
||
)
|
||
|
||
waiter = self._pending_tool_waiters.get(call_id)
|
||
if waiter and not waiter.done():
|
||
waiter.set_result(item)
|
||
self._completed_tool_call_ids.add(call_id)
|
||
continue
|
||
self._early_tool_results[call_id] = item
|
||
self._completed_tool_call_ids.add(call_id)
|
||
|
||
    async def _wait_for_single_tool_result(self, call_id: str) -> Dict[str, Any]:
        """Wait for one client tool result, with dedup and timeout handling.

        Resolution order: already-handled ids get a synthetic 208 result;
        results that arrived early are popped and returned; otherwise a
        future is registered for `handle_tool_call_results` to resolve,
        bounded by ``_TOOL_WAIT_TIMEOUT_SECONDS`` (timeout yields a
        synthetic 504 result).
        """
        # Already consumed and nothing buffered: report "already handled".
        if call_id in self._completed_tool_call_ids and call_id not in self._early_tool_results:
            return {
                "tool_call_id": call_id,
                "status": {"code": 208, "message": "tool_call result already handled"},
                "output": "",
            }
        # Result arrived before we started waiting — consume the stash.
        if call_id in self._early_tool_results:
            self._completed_tool_call_ids.add(call_id)
            return self._early_tool_results.pop(call_id)

        loop = asyncio.get_running_loop()
        future = loop.create_future()
        self._pending_tool_waiters[call_id] = future
        try:
            return await asyncio.wait_for(future, timeout=self._TOOL_WAIT_TIMEOUT_SECONDS)
        except asyncio.TimeoutError:
            # Mark completed so a late client result is ignored as duplicate.
            self._completed_tool_call_ids.add(call_id)
            return {
                "tool_call_id": call_id,
                "status": {"code": 504, "message": "tool_call timeout"},
                "output": "",
            }
        finally:
            # Always clear waiter bookkeeping, on success, timeout or cancel.
            self._pending_tool_waiters.pop(call_id, None)
            self._pending_client_tool_call_ids.discard(call_id)
|
||
|
||
def _normalize_stream_event(self, item: Any) -> LLMStreamEvent:
|
||
if isinstance(item, LLMStreamEvent):
|
||
return item
|
||
if isinstance(item, str):
|
||
return LLMStreamEvent(type="text_delta", text=item)
|
||
if isinstance(item, dict):
|
||
event_type = str(item.get("type") or "")
|
||
if event_type in {"text_delta", "tool_call", "done"}:
|
||
return LLMStreamEvent(
|
||
type=event_type, # type: ignore[arg-type]
|
||
text=item.get("text"),
|
||
tool_call=item.get("tool_call"),
|
||
)
|
||
return LLMStreamEvent(type="done")
|
||
|
||
    async def _handle_turn(self, user_text: str, system_context: Optional[str] = None) -> None:
        """
        Handle a complete conversation turn.

        Uses sentence-by-sentence streaming TTS for lower latency. Runs up
        to ``max_rounds`` LLM rounds: each round streams text deltas (sent
        to TTS sentence by sentence) and collects tool calls; when tool
        calls were made, their results are appended to the message list and
        another round is started. Interruptible at every await point via
        ``_interrupt_event``.

        Args:
            user_text: User's transcribed text
            system_context: Optional extra system message appended to the
                conversation history for this turn only.
        """
        try:
            # Ensure turn/utterance/response ids exist for event correlation.
            if not self._current_turn_id:
                self._start_turn()
            if not self._current_utterance_id:
                self._finalize_utterance()
            turn_id = self._current_turn_id
            utterance_id = self._current_utterance_id
            response_id = self._start_response()
            # Start latency tracking
            self._turn_start_time = time.time()
            self._first_audio_sent = False

            full_response = ""
            messages = self.conversation.get_messages()
            if system_context and system_context.strip():
                messages = [*messages, LLMMessage(role="system", content=system_context.strip())]
            # Hard cap on LLM rounds per turn (initial + tool follow-ups).
            max_rounds = 3

            await self.conversation.start_assistant_turn()
            self._is_bot_speaking = True
            self._interrupt_event.clear()
            self._drop_outbound_audio = False

            first_audio_sent = False
            self._pending_llm_delta = ""
            self._last_llm_delta_emit_ms = 0.0
            for _ in range(max_rounds):
                if self._interrupt_event.is_set():
                    break

                sentence_buffer = ""
                pending_punctuation = ""
                round_response = ""
                tool_calls: List[Dict[str, Any]] = []
                # Once a tool call is seen, further text deltas are suppressed.
                allow_text_output = True
                use_engine_sentence_split = self._use_engine_sentence_split_for_tts()

                async for raw_event in self.llm_service.generate_stream(messages):
                    if self._interrupt_event.is_set():
                        break

                    event = self._normalize_stream_event(raw_event)
                    if event.type == "tool_call":
                        # Flush buffered deltas before announcing the call.
                        await self._flush_pending_llm_delta(
                            turn_id=turn_id,
                            utterance_id=utterance_id,
                            response_id=response_id,
                        )
                        tool_call = event.tool_call if isinstance(event.tool_call, dict) else None
                        if not tool_call:
                            continue
                        allow_text_output = False
                        executor = self._tool_executor(tool_call)
                        # Enrich a copy of the call with routing metadata.
                        enriched_tool_call = dict(tool_call)
                        enriched_tool_call["executor"] = executor
                        tool_name = self._tool_name(enriched_tool_call) or "unknown_tool"
                        tool_id = self._tool_id_for_name(tool_name)
                        tool_display_name = self._tool_display_name(tool_name) or tool_name
                        wait_for_response = self._tool_wait_for_response(tool_name)
                        enriched_tool_call["wait_for_response"] = wait_for_response
                        call_id = str(enriched_tool_call.get("id") or "").strip()
                        fn_payload = (
                            dict(enriched_tool_call.get("function"))
                            if isinstance(enriched_tool_call.get("function"), dict)
                            else None
                        )
                        raw_args = str(fn_payload.get("arguments") or "") if isinstance(fn_payload, dict) else ""
                        tool_arguments = self._tool_arguments(enriched_tool_call)
                        # Merge configured defaults under the LLM-provided args.
                        merged_tool_arguments = self._apply_tool_default_args(tool_name, tool_arguments)
                        try:
                            merged_args_text = json.dumps(merged_tool_arguments, ensure_ascii=False)
                        except Exception:
                            merged_args_text = raw_args if raw_args else "{}"
                        if isinstance(fn_payload, dict):
                            fn_payload["arguments"] = merged_args_text
                            enriched_tool_call["function"] = fn_payload
                        args_preview = raw_args if len(raw_args) <= 160 else f"{raw_args[:160]}..."
                        logger.info(
                            f"[Tool] call requested name={tool_name} call_id={call_id} "
                            f"executor={executor} args={args_preview} merged_args={merged_args_text}"
                        )
                        tool_calls.append(enriched_tool_call)
                        if executor == "client" and call_id:
                            self._pending_client_tool_call_ids.add(call_id)
                        await self._send_event(
                            {
                                **ev(
                                    "assistant.tool_call",
                                    trackId=self.track_audio_out,
                                    tool_call_id=call_id,
                                    tool_name=tool_name,
                                    tool_id=tool_id,
                                    tool_display_name=tool_display_name,
                                    wait_for_response=wait_for_response,
                                    arguments=tool_arguments,
                                    executor=executor,
                                    timeout_ms=int(self._TOOL_WAIT_TIMEOUT_SECONDS * 1000),
                                    tool_call=enriched_tool_call,
                                )
                            },
                            priority=22,
                        )
                        continue

                    if event.type != "text_delta":
                        continue

                    text_chunk = event.text or ""
                    if not text_chunk:
                        continue

                    if not allow_text_output:
                        continue

                    full_response += text_chunk
                    round_response += text_chunk
                    sentence_buffer += text_chunk
                    await self.conversation.update_assistant_text(text_chunk)
                    # Throttle transcript-delta events to the client.
                    self._pending_llm_delta += text_chunk
                    now_ms = time.monotonic() * 1000.0
                    if (
                        self._last_llm_delta_emit_ms <= 0.0
                        or now_ms - self._last_llm_delta_emit_ms >= self._LLM_DELTA_THROTTLE_MS
                    ):
                        await self._flush_pending_llm_delta(
                            turn_id=turn_id,
                            utterance_id=utterance_id,
                            response_id=response_id,
                        )

                    if use_engine_sentence_split:
                        # Drain every complete sentence currently in the buffer.
                        while True:
                            split_result = extract_tts_sentence(
                                sentence_buffer,
                                end_chars=self._SENTENCE_END_CHARS,
                                trailing_chars=self._SENTENCE_TRAILING_CHARS,
                                closers=self._SENTENCE_CLOSERS,
                                min_split_spoken_chars=self._MIN_SPLIT_SPOKEN_CHARS,
                                hold_trailing_at_buffer_end=True,
                                force=False,
                            )
                            if not split_result:
                                break
                            sentence, sentence_buffer = split_result
                            if not sentence:
                                continue

                            # Prepend punctuation held back from a prior split.
                            sentence = f"{pending_punctuation}{sentence}".strip()
                            pending_punctuation = ""
                            if not sentence:
                                continue

                            # Punctuation-only fragments are carried forward.
                            if not has_spoken_content(sentence):
                                pending_punctuation = sentence
                                continue

                            if self._tts_output_enabled() and not self._interrupt_event.is_set():
                                # First audible sentence opens the audio track.
                                if not first_audio_sent:
                                    tts_id = self._start_tts()
                                    self._mark_client_playback_started(tts_id)
                                    await self._send_event(
                                        {
                                            **ev(
                                                "output.audio.start",
                                                trackId=self.track_audio_out,
                                            ),
                                            "turn_id": turn_id,
                                            "utterance_id": utterance_id,
                                            "response_id": response_id,
                                        },
                                        priority=30,
                                    )
                                    first_audio_sent = True

                                await self._speak_sentence(
                                    sentence,
                                    fade_in_ms=0,
                                    fade_out_ms=8,
                                    turn_id=turn_id,
                                    utterance_id=utterance_id,
                                    response_id=response_id,
                                )

                # Stream finished: speak whatever text remains unsplit.
                if use_engine_sentence_split:
                    remaining_text = f"{pending_punctuation}{sentence_buffer}".strip()
                else:
                    remaining_text = sentence_buffer.strip()
                await self._flush_pending_llm_delta(
                    turn_id=turn_id,
                    utterance_id=utterance_id,
                    response_id=response_id,
                )
                if (
                    self._tts_output_enabled()
                    and remaining_text
                    and has_spoken_content(remaining_text)
                    and not self._interrupt_event.is_set()
                ):
                    if not first_audio_sent:
                        tts_id = self._start_tts()
                        self._mark_client_playback_started(tts_id)
                        await self._send_event(
                            {
                                **ev(
                                    "output.audio.start",
                                    trackId=self.track_audio_out,
                                ),
                                "turn_id": turn_id,
                                "utterance_id": utterance_id,
                                "response_id": response_id,
                            },
                            priority=30,
                        )
                        first_audio_sent = True
                    await self._speak_sentence(
                        remaining_text,
                        fade_in_ms=0,
                        fade_out_ms=8,
                        turn_id=turn_id,
                        utterance_id=utterance_id,
                        response_id=response_id,
                    )

                # No tool calls in this round: the turn is complete.
                if not tool_calls:
                    break

                # Execute collected tool calls (client waits or server runs).
                tool_results: List[Dict[str, Any]] = []
                for call in tool_calls:
                    call_id = str(call.get("id") or "").strip()
                    if not call_id:
                        continue
                    executor = str(call.get("executor") or "server").strip().lower()
                    tool_name = self._tool_name(call) or "unknown_tool"
                    tool_id = self._tool_id_for_name(tool_name)
                    logger.info(f"[Tool] execute start name={tool_name} call_id={call_id} executor={executor}")
                    if executor == "client":
                        result = await self._wait_for_single_tool_result(call_id)
                        await self._emit_tool_result(result, source="client")
                        tool_results.append(result)
                        continue

                    # Server path: rewrite the alias to the backend tool id
                    # on a copy before dispatching.
                    call_for_executor = dict(call)
                    fn_for_executor = (
                        dict(call_for_executor.get("function"))
                        if isinstance(call_for_executor.get("function"), dict)
                        else None
                    )
                    if isinstance(fn_for_executor, dict):
                        fn_for_executor["name"] = tool_id
                        call_for_executor["function"] = fn_for_executor
                    try:
                        result = await asyncio.wait_for(
                            self._server_tool_executor(call_for_executor),
                            timeout=self._SERVER_TOOL_TIMEOUT_SECONDS,
                        )
                    except asyncio.TimeoutError:
                        result = {
                            "tool_call_id": call_id,
                            "name": self._tool_name(call) or "unknown_tool",
                            "output": {"message": "server tool timeout"},
                            "status": {"code": 504, "message": "server_tool_timeout"},
                        }
                    await self._emit_tool_result(result, source="server")
                    tool_results.append(result)

                # Feed the tool results back to the LLM for the next round.
                messages = [
                    *messages,
                    LLMMessage(
                        role="assistant",
                        content=round_response.strip(),
                    ),
                    LLMMessage(
                        role="system",
                        content=(
                            "Tool execution results are available. "
                            "Continue answering the user naturally using these results. "
                            "Do not request the same tool again in this turn.\n"
                            f"tool_calls={json.dumps(tool_calls, ensure_ascii=False)}\n"
                            f"tool_results={json.dumps(tool_results, ensure_ascii=False)}"
                        ),
                    ),
                ]

            if full_response and not self._interrupt_event.is_set():
                await self._flush_pending_llm_delta(
                    turn_id=turn_id,
                    utterance_id=utterance_id,
                    response_id=response_id,
                )
                await self._send_event(
                    {
                        **ev(
                            "assistant.response.final",
                            trackId=self.track_audio_out,
                            text=full_response,
                        ),
                        "turn_id": turn_id,
                        "utterance_id": utterance_id,
                        "response_id": response_id,
                    },
                    priority=20,
                )

            # Send track end
            if first_audio_sent:
                await self._flush_audio_out_frames(priority=50)
                await self._send_event({
                    **ev(
                        "output.audio.end",
                        trackId=self.track_audio_out,
                    ),
                    "turn_id": turn_id,
                    "utterance_id": utterance_id,
                    "response_id": response_id,
                }, priority=10)

            # End assistant turn
            await self.conversation.end_assistant_turn(
                was_interrupted=self._interrupt_event.is_set()
            )

        except asyncio.CancelledError:
            logger.info("Turn handling cancelled")
            await self.conversation.end_assistant_turn(was_interrupted=True)
        except Exception as e:
            logger.error(f"Turn handling error: {e}", exc_info=True)
            await self.conversation.end_assistant_turn(was_interrupted=True)
        finally:
            self._is_bot_speaking = False
            # Reset barge-in tracking when bot finishes speaking
            self._barge_in_speech_start_time = None
            self._barge_in_speech_frames = 0
            self._barge_in_silence_frames = 0
            self._current_response_id = None
            self._current_tts_id = None
|
||
|
||
    async def _speak_sentence(
        self,
        text: str,
        fade_in_ms: int = 0,
        fade_out_ms: int = 8,
        turn_id: Optional[str] = None,
        utterance_id: Optional[str] = None,
        response_id: Optional[str] = None,
    ) -> None:
        """
        Synthesize and send a single sentence.

        Streams TTS chunks for one sentence, applies short edge fades to
        avoid clicks at sentence boundaries, and emits a one-time TTFB
        metric for the turn's first audio chunk. Silently returns when TTS
        output is disabled, the text is blank, an interrupt is pending, or
        no TTS service is configured. TTS errors are logged, not raised.

        Args:
            text: Sentence to speak
            fade_in_ms: Fade-in duration for sentence start chunks
            fade_out_ms: Fade-out duration for sentence end chunks
            turn_id: Turn id attached to emitted events (correlation only)
            utterance_id: Utterance id attached to emitted events
            response_id: Response id attached to emitted events
        """
        if not self._tts_output_enabled():
            return

        if not text.strip() or self._interrupt_event.is_set() or not self.tts_service:
            return

        logger.info(f"[TTS] split sentence: {text!r}")

        try:
            is_first_chunk = True
            async for chunk in self.tts_service.synthesize_stream(text):
                # Check interrupt at the start of each iteration
                if self._interrupt_event.is_set():
                    logger.debug("TTS sentence interrupted")
                    break

                # Track and log first audio packet latency (TTFB)
                if not self._first_audio_sent and self._turn_start_time:
                    ttfb_ms = (time.time() - self._turn_start_time) * 1000
                    self._first_audio_sent = True
                    logger.info(f"[TTFB] Server first audio packet latency: {ttfb_ms:.0f}ms (session {self.session_id})")

                    # Send TTFB event to client
                    await self._send_event({
                        **ev(
                            "metrics.ttfb",
                            trackId=self.track_audio_out,
                            latencyMs=round(ttfb_ms),
                        ),
                        "turn_id": turn_id,
                        "utterance_id": utterance_id,
                        "response_id": response_id,
                    }, priority=25)

                # Double-check interrupt right before sending audio
                if self._interrupt_event.is_set():
                    break

                # Fade in only on the sentence's first chunk and fade out
                # only on the provider-flagged final chunk.
                smoothed_audio = self._apply_edge_fade(
                    pcm_bytes=chunk.audio,
                    sample_rate=chunk.sample_rate,
                    fade_in=is_first_chunk,
                    fade_out=bool(chunk.is_final),
                    fade_in_ms=fade_in_ms,
                    fade_out_ms=fade_out_ms,
                )
                is_first_chunk = False

                await self._send_audio(smoothed_audio, priority=50)
        except asyncio.CancelledError:
            logger.debug("TTS sentence cancelled")
        except Exception as e:
            logger.error(f"TTS sentence error: {e}")
|
||
|
||
def _apply_edge_fade(
|
||
self,
|
||
pcm_bytes: bytes,
|
||
sample_rate: int,
|
||
fade_in: bool = False,
|
||
fade_out: bool = False,
|
||
fade_in_ms: int = 0,
|
||
fade_out_ms: int = 8,
|
||
) -> bytes:
|
||
"""Apply short edge fades to reduce click/pop at sentence boundaries."""
|
||
if not pcm_bytes or (not fade_in and not fade_out):
|
||
return pcm_bytes
|
||
|
||
try:
|
||
samples = np.frombuffer(pcm_bytes, dtype="<i2").astype(np.float32)
|
||
if samples.size == 0:
|
||
return pcm_bytes
|
||
|
||
if fade_in and fade_in_ms > 0:
|
||
fade_in_samples = int(sample_rate * (fade_in_ms / 1000.0))
|
||
fade_in_samples = max(1, min(fade_in_samples, samples.size))
|
||
samples[:fade_in_samples] *= np.linspace(0.0, 1.0, fade_in_samples, endpoint=True)
|
||
if fade_out:
|
||
fade_out_samples = int(sample_rate * (fade_out_ms / 1000.0))
|
||
fade_out_samples = max(1, min(fade_out_samples, samples.size))
|
||
samples[-fade_out_samples:] *= np.linspace(1.0, 0.0, fade_out_samples, endpoint=True)
|
||
|
||
return np.clip(samples, -32768, 32767).astype("<i2").tobytes()
|
||
except Exception:
|
||
# Fallback: never block audio delivery on smoothing failure.
|
||
return pcm_bytes
|
||
|
||
async def _speak(self, text: str, audio_event_priority: int = 10) -> None:
    """
    Synthesize `text` with the TTS service and stream the audio to the client.

    Emits `output.audio.start` before streaming, a `metrics.ttfb` event on the
    first audio chunk, and `output.audio.end` after the stream completes or is
    interrupted. Interruption is observed via `self._interrupt_event` between
    chunks; cancellation is re-raised so callers can unwind.

    Args:
        text: Text to speak; whitespace-only text (or a missing TTS service)
            makes this a no-op.
        audio_event_priority: Priority for output.audio.start/end events.
    """
    if not self._tts_output_enabled():
        return

    if not text.strip() or not self.tts_service:
        return

    try:
        # Re-enable outbound audio in case a prior barge-in left the drop flag set.
        self._drop_outbound_audio = False
        # Start latency tracking for greeting
        speak_start_time = time.time()
        first_audio_sent = False

        # Send track start event
        tts_id = self._start_tts()
        self._mark_client_playback_started(tts_id)
        await self._send_event({
            **ev(
                "output.audio.start",
                trackId=self.track_audio_out,
            )
        }, priority=audio_event_priority)

        self._is_bot_speaking = True

        # Stream TTS audio
        async for chunk in self.tts_service.synthesize_stream(text):
            # Barge-in can be signalled between chunks; stop sending promptly.
            if self._interrupt_event.is_set():
                logger.info("TTS interrupted by barge-in")
                break

            # Track and log first audio packet latency (TTFB)
            if not first_audio_sent:
                ttfb_ms = (time.time() - speak_start_time) * 1000
                first_audio_sent = True
                logger.info(f"[TTFB] Greeting first audio packet latency: {ttfb_ms:.0f}ms (session {self.session_id})")

                # Send TTFB event to client
                await self._send_event({
                    **ev(
                        "metrics.ttfb",
                        trackId=self.track_audio_out,
                        latencyMs=round(ttfb_ms),
                    )
                }, priority=25)

            # Send audio to client
            await self._send_audio(chunk.audio, priority=50)

            # Small delay to prevent flooding
            await asyncio.sleep(0.01)

        # Send track end event (flush any partially buffered frame first)
        await self._flush_audio_out_frames(priority=50)
        await self._send_event({
            **ev(
                "output.audio.end",
                trackId=self.track_audio_out,
            )
        }, priority=audio_event_priority)

    except asyncio.CancelledError:
        logger.info("TTS cancelled")
        # Propagate cancellation; the `finally` below still clears the flag.
        raise
    except Exception as e:
        logger.error(f"TTS error: {e}")
    finally:
        # Always clear the speaking flag, even on interrupt/cancel/error.
        self._is_bot_speaking = False
|
||
|
||
async def _handle_barge_in(self) -> None:
    """Handle user barge-in (interruption).

    Order matters: the interrupt event is set and the client is notified
    BEFORE TTS/LLM cancellation, so no stale audio is sent or replayed.
    No-op unless the bot is currently speaking.
    """
    if not self._is_bot_speaking:
        return

    logger.info("Barge-in detected - interrupting bot speech")

    # Reset barge-in tracking
    self._barge_in_speech_start_time = None
    self._barge_in_speech_frames = 0
    self._barge_in_silence_frames = 0

    # IMPORTANT: Signal interruption FIRST to stop audio sending
    self._interrupt_event.set()
    self._is_bot_speaking = False
    self._drop_outbound_audio = True
    self._audio_out_frame_buffer = b""
    self._clear_client_playback_tracking()
    # Capture the ids of the turn being interrupted before any reset occurs.
    interrupted_turn_id = self._current_turn_id
    interrupted_utterance_id = self._current_utterance_id
    interrupted_response_id = self._current_response_id

    # Send interrupt event to client IMMEDIATELY
    # This must happen BEFORE canceling services, so client knows to discard in-flight audio
    await self._send_event({
        **ev(
            "response.interrupted",
            trackId=self.track_audio_out,
        ),
        "turn_id": interrupted_turn_id,
        "utterance_id": interrupted_utterance_id,
        "response_id": interrupted_response_id,
    }, priority=0)

    # Cancel TTS
    if self.tts_service:
        await self.tts_service.cancel()

    # Cancel LLM (cancel() is optional on LLM implementations, hence hasattr)
    if self.llm_service and hasattr(self.llm_service, 'cancel'):
        self.llm_service.cancel()

    # Interrupt conversation only if there is no active turn task.
    # When a turn task exists, it will handle end_assistant_turn() to avoid double callbacks.
    if not (self._current_turn_task and not self._current_turn_task.done()):
        await self.conversation.interrupt()

    # Reset for new user turn
    await self.conversation.start_user_turn()
    self._audio_buffer = b""
    self.eou_detector.reset()
    self._asr_capture_active = False
    self._asr_capture_started_ms = 0.0
    self._pending_speech_audio = b""
|
||
|
||
async def _stop_current_speech(self) -> None:
    """Tear down any in-flight assistant speech.

    Cancels the active turn task (if any), stops TTS/LLM work, and resets
    the speaking and interrupt flags so a new turn can start cleanly.
    """
    # Immediately stop forwarding buffered assistant audio to the client.
    self._drop_outbound_audio = True
    self._audio_out_frame_buffer = b""
    self._clear_client_playback_tracking()

    turn_task = self._current_turn_task
    if turn_task and not turn_task.done():
        self._interrupt_event.set()
        turn_task.cancel()
        try:
            await turn_task
        except asyncio.CancelledError:
            # Expected: we cancelled the task on purpose.
            pass

    # Ensure underlying services are cancelled to avoid leaking work/audio
    if self.tts_service:
        await self.tts_service.cancel()
    llm = self.llm_service
    if llm and hasattr(llm, 'cancel'):
        llm.cancel()

    self._is_bot_speaking = False
    self._interrupt_event.clear()
|
||
|
||
async def cleanup(self) -> None:
    """Release pipeline resources.

    Stops any active speech, drains/stops the outbound worker task, and
    disconnects the LLM, TTS, and ASR services.
    """
    logger.info(f"Cleaning up DuplexPipeline for session {self.session_id}")

    self._running = False
    await self._stop_current_speech()

    outbound = self._outbound_task
    if outbound and not outbound.done():
        # The "stop" sentinel at minimal priority jumps the queue and
        # tells the outbound worker to exit; then wait for it.
        await self._enqueue_outbound("stop", None, priority=-1000)
        await outbound
        self._outbound_task = None

    # Disconnect services (same order as elsewhere: LLM, TTS, ASR).
    for service in (self.llm_service, self.tts_service, self.asr_service):
        if service:
            await service.disconnect()
|
||
|
||
def _get_timestamp_ms(self) -> int:
|
||
"""Get current timestamp in milliseconds."""
|
||
import time
|
||
return int(time.time() * 1000)
|
||
|
||
@property
def is_speaking(self) -> bool:
    """Whether assistant audio is still active on either side.

    True while the server is mid-synthesis, or while the client still has
    delivered audio it has not acknowledged playing.
    """
    if self._is_bot_speaking:
        return True
    return self.is_client_playing_audio
|
||
|
||
@property
def is_client_playing_audio(self) -> bool:
    """Whether the client still holds assistant audio it has not acknowledged playing."""
    if self._pending_client_playback_tts_ids:
        return True
    return False
|
||
|
||
@property
def state(self) -> ConversationState:
    """Current state reported by the underlying conversation manager."""
    current_state = self.conversation.state
    return current_state
|