247 lines
9.0 KiB
Python
247 lines
9.0 KiB
Python
from __future__ import annotations
|
|
|
|
import uuid
|
|
|
|
from loguru import logger
|
|
|
|
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
|
from pipecat.audio.vad.vad_analyzer import VADParams
|
|
from pipecat.frames.frames import (
|
|
LLMRunFrame,
|
|
OutputTransportMessageUrgentFrame,
|
|
TTSSpeakFrame,
|
|
)
|
|
from pipecat.pipeline.pipeline import Pipeline
|
|
from pipecat.pipeline.runner import PipelineRunner
|
|
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
|
from pipecat.processors.aggregators.llm_context import LLMContext
|
|
from pipecat.processors.aggregators.llm_response_universal import (
|
|
AssistantTurnStoppedMessage,
|
|
LLMContextAggregatorPair,
|
|
LLMUserAggregatorParams,
|
|
UserTurnStoppedMessage,
|
|
)
|
|
from pipecat.serializers.protobuf import ProtobufFrameSerializer
|
|
from pipecat.serializers.base_serializer import FrameSerializer
|
|
from pipecat.transports.websocket.fastapi import (
|
|
FastAPIWebsocketParams,
|
|
FastAPIWebsocketTransport,
|
|
)
|
|
from pipecat.turns.user_stop.speech_timeout_user_turn_stop_strategy import (
|
|
SpeechTimeoutUserTurnStopStrategy,
|
|
)
|
|
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
|
|
|
from .audio_filters import create_audio_input_filter
|
|
from .config import EngineConfig
|
|
from .context_sync import AssistantContextSyncProcessor
|
|
from .fastgpt_llm import FastGPTLLMService
|
|
from .product_protocol import ProductWebsocketSerializer
|
|
from .response_state import StateTagResponseProcessor
|
|
from .services import create_llm_service, create_stt_service, create_tts_service
|
|
from .text_input import ProductTextInputProcessor
|
|
from .text_stream import ProductTextStreamProcessor, maybe_sync_assistant_context
|
|
from .transcript_stream import ProductTranscriptStreamProcessor
|
|
from .turn_start import InterruptionGateUserTurnStartStrategy
|
|
|
|
|
|
async def run_voice_pipeline(websocket, config: EngineConfig) -> None:
    """Run a voice session whose websocket frames use pipecat's protobuf serializer.

    Args:
        websocket: The accepted websocket connection handed to the transport.
        config: Engine configuration forwarded unchanged to the shared runner.
    """
    protobuf_serializer = ProtobufFrameSerializer()
    await run_pipeline_with_serializer(
        websocket,
        config,
        serializer=protobuf_serializer,
        client_label="Pipecat protobuf",
    )
|
|
|
|
|
|
async def run_product_voice_pipeline(websocket, config: EngineConfig) -> None:
    """Run a voice session whose websocket frames use the product serializer.

    The serializer is configured with the sample rate and channel count from
    ``config.audio``.

    Args:
        websocket: The accepted websocket connection handed to the transport.
        config: Engine configuration forwarded unchanged to the shared runner.
    """
    audio_settings = config.audio
    product_serializer = ProductWebsocketSerializer(
        sample_rate=audio_settings.sample_rate_hz,
        channels=audio_settings.channels,
    )
    await run_pipeline_with_serializer(
        websocket,
        config,
        serializer=product_serializer,
        client_label="Product JSON",
    )
|
|
|
|
|
|
def _build_transport(
    websocket,
    config: EngineConfig,
    serializer: FrameSerializer,
) -> FastAPIWebsocketTransport:
    """Create the websocket transport with identical audio settings for input and output."""
    return FastAPIWebsocketTransport(
        websocket=websocket,
        params=FastAPIWebsocketParams(
            audio_in_enabled=True,
            audio_out_enabled=True,
            audio_in_sample_rate=config.audio.sample_rate_hz,
            audio_out_sample_rate=config.audio.sample_rate_hz,
            audio_in_channels=config.audio.channels,
            audio_out_channels=config.audio.channels,
            audio_in_filter=create_audio_input_filter(config.audio_filter, config.audio),
            serializer=serializer,
            # NOTE(review): presumably disables the transport's own session
            # timer so only the task's idle timeout applies — confirm.
            session_timeout=None,
        ),
    )


def _create_session_llm(config: EngineConfig):
    """Create the LLM service for this session and log which backend is in use."""
    llm_config = config.services.llm
    # Fall back to a random per-session id when no chat id is configured.
    chat_id = llm_config.chat_id or f"voice_{uuid.uuid4().hex[:16]}"
    llm = create_llm_service(
        llm_config,
        chat_id=chat_id,
        session_variables={"session_id": chat_id, "channel": "voice"},
        greeting_prompt=config.agent.greeting,
    )
    if llm_config.is_fastgpt:
        logger.info(f"LLM backend=fastgpt chatId={chat_id} appId={llm_config.app_id or '-'}")
    else:
        logger.info(f"LLM backend=openai model={llm_config.model}")
    return llm


def _initial_context_messages(config: EngineConfig) -> list[dict[str, str]]:
    """Return the messages that seed the LLM context before the first turn."""
    messages: list[dict[str, str]] = []
    if config.services.llm.uses_local_context_history:
        messages = [{"role": "system", "content": config.agent.system_prompt}]
        # NOTE(review): the greeting seed is assumed to apply only when local
        # history is in use — confirm against the original indentation.
        if config.agent.greeting and config.agent.greeting_mode == "generated":
            messages.append({"role": "system", "content": config.agent.greeting})
    return messages


def _build_user_turn_strategies(config: EngineConfig) -> UserTurnStrategies:
    """Build the start/stop strategies that decide when a user turn begins and ends."""
    # Replace pipecat's default stop strategy (Smart Turn v3) with a simple
    # silence-timeout strategy. Smart Turn v3 was finalizing every short
    # Chinese phrase as a complete turn, which caused one logical utterance
    # to become several LLM calls and several user bubbles in the UI. The
    # timeout strategy waits for `user_speech_timeout_sec` of silence
    # (re-armed every time the user resumes speaking) before declaring the
    # turn finished — which is what we actually want for streaming ASRs.
    return UserTurnStrategies(
        start=[
            InterruptionGateUserTurnStartStrategy(
                min_chars_when_bot_speaking=config.turn.interruption_min_chars,
                allowed_short_replies=config.turn.interruption_short_replies,
                use_interim=config.turn.interruption_use_interim,
            ),
        ],
        stop=[
            SpeechTimeoutUserTurnStopStrategy(
                user_speech_timeout=config.turn.user_speech_timeout_sec,
            ),
        ],
    )


def _build_user_aggregator_params(config: EngineConfig) -> LLMUserAggregatorParams:
    """Build the user-aggregator parameters: Silero VAD plus the turn strategies."""
    vad_params = VADParams(
        confidence=config.turn.vad.confidence,
        start_secs=config.turn.vad.start_secs,
        stop_secs=config.turn.vad.stop_secs,
        min_volume=config.turn.vad.min_volume,
    )
    user_turn_strategies = _build_user_turn_strategies(config)
    return LLMUserAggregatorParams(
        vad_analyzer=SileroVADAnalyzer(params=vad_params),
        user_turn_strategies=user_turn_strategies,
    )


async def _greeting_frames(config: EngineConfig, llm) -> list:
    """Return the frames to queue on connect for the configured greeting mode.

    An empty list means no greeting is sent.
    """
    agent = config.agent
    if agent.greeting_mode == "fixed" and agent.greeting:
        return [TTSSpeakFrame(agent.greeting)]
    if agent.greeting_mode == "generated":
        # NOTE(review): branch nesting was reconstructed from unindented
        # source; confirm that every "generated" path without a welcome text
        # falls back to an LLM run.
        if isinstance(llm, FastGPTLLMService):
            welcome = await llm.fetch_welcome_text()
            if welcome:
                return [TTSSpeakFrame(welcome)]
        return [LLMRunFrame()]
    return []


async def run_pipeline_with_serializer(
    websocket,
    config: EngineConfig,
    *,
    serializer: FrameSerializer,
    client_label: str,
) -> None:
    """Build the voice pipeline for one websocket client and run it to completion.

    Processor order: transport input, text input, STT, transcript stream,
    context sync, user aggregator, LLM, optional state-tag processor, text
    stream, TTS, transport output, assistant aggregator.

    Args:
        websocket: The websocket connection passed to the transport.
        config: Engine configuration supplying audio, service, turn, agent
            and session settings.
        serializer: Frame serializer used by the websocket transport.
        client_label: Prefix used in connection-lifecycle log messages.

    Returns:
        None, once ``PipelineRunner.run`` returns for the created task.
    """
    transport = _build_transport(websocket, config, serializer)

    # Services are created in the same order as before: STT, LLM, TTS.
    stt = create_stt_service(config.services.stt, config.audio)
    llm = _create_session_llm(config)
    tts = create_tts_service(config.services.tts, config.audio)

    context = LLMContext(_initial_context_messages(config))
    user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
        context,
        user_params=_build_user_aggregator_params(config),
    )

    text_stream = ProductTextStreamProcessor()
    context_sync = AssistantContextSyncProcessor(
        text_stream=text_stream,
        assistant_aggregator=assistant_aggregator,
    )

    processors = [
        transport.input(),
        ProductTextInputProcessor(),
        stt,
        ProductTranscriptStreamProcessor(),
        context_sync,
        user_aggregator,
        llm,
    ]
    if config.agent.response_state.enabled:
        processors.append(StateTagResponseProcessor(config.agent.response_state))
    processors.extend(
        [
            text_stream,
            tts,
            transport.output(),
            assistant_aggregator,
        ]
    )
    pipeline = Pipeline(processors)

    task = PipelineTask(
        pipeline,
        params=PipelineParams(
            audio_in_sample_rate=config.audio.sample_rate_hz,
            audio_out_sample_rate=config.audio.sample_rate_hz,
            enable_metrics=True,
            enable_usage_metrics=True,
            enable_heartbeats=True,
        ),
        idle_timeout_secs=config.session.inactivity_timeout_sec,
    )

    @transport.event_handler("on_client_connected")
    async def on_client_connected(_transport, _client):
        logger.info(f"{client_label} websocket client connected")
        frames = await _greeting_frames(config, llm)
        if frames:
            await task.queue_frames(frames)

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(_transport, _client):
        logger.info(f"{client_label} websocket client disconnected")
        await task.cancel()

    @transport.event_handler("on_session_timeout")
    async def on_session_timeout(_transport, _client):
        logger.info(f"{client_label} websocket session timed out")
        await task.cancel()

    @user_aggregator.event_handler("on_user_turn_stopped")
    async def on_user_turn_stopped(_aggregator, _strategy, message: UserTurnStoppedMessage):
        logger.info(f"User: {message.content}")
        text = (message.content or "").strip()
        if not text:
            # Nothing to report to the client for an empty transcript.
            return
        await task.queue_frame(
            OutputTransportMessageUrgentFrame(
                message={
                    "type": "input.transcript.final",
                    "text": text,
                    "user_id": message.user_id,
                    "timestamp": message.timestamp,
                }
            )
        )

    # NOTE: assistant turn started/final events are emitted by
    # ProductTextStreamProcessor, upstream of TTS, so text streams to the
    # client ahead of audio. This logger is kept for server-side visibility.
    @assistant_aggregator.event_handler("on_assistant_turn_stopped")
    async def on_assistant_turn_stopped(_aggregator, message: AssistantTurnStoppedMessage):
        logger.info(f"Assistant: {message.content}")
        maybe_sync_assistant_context(
            _aggregator,
            text_stream,
            committed_text=message.content or "",
        )
        # The return value is intentionally discarded, as in the original code.
        text_stream.take_interrupted_stream_text()

    runner = PipelineRunner(handle_sigint=False)
    await runner.run(task)
|