Files
engine-v5-pipecat-core/engine/pipeline.py
2026-05-21 13:08:40 +08:00

161 lines
5.6 KiB
Python

from __future__ import annotations
from loguru import logger
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import (
LLMRunFrame,
OutputTransportMessageUrgentFrame,
TTSSpeakFrame,
)
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import (
AssistantTurnStoppedMessage,
LLMContextAggregatorPair,
LLMUserAggregatorParams,
UserTurnStoppedMessage,
)
from pipecat.serializers.protobuf import ProtobufFrameSerializer
from pipecat.serializers.base_serializer import FrameSerializer
from pipecat.transports.websocket.fastapi import (
FastAPIWebsocketParams,
FastAPIWebsocketTransport,
)
from .config import EngineConfig
from .product_protocol import ProductWebsocketSerializer
from .services import create_llm_service, create_stt_service, create_tts_service
from .text_input import ProductTextInputProcessor
from .text_stream import ProductTextStreamProcessor
async def run_voice_pipeline(websocket, config: EngineConfig) -> None:
    """Serve one websocket client using Pipecat's protobuf frame format.

    Delegates all pipeline construction to ``run_pipeline_with_serializer``,
    supplying a ``ProtobufFrameSerializer`` and the log label
    ``"Pipecat protobuf"``.
    """
    protobuf_serializer = ProtobufFrameSerializer()
    await run_pipeline_with_serializer(
        websocket,
        config,
        client_label="Pipecat protobuf",
        serializer=protobuf_serializer,
    )
async def run_product_voice_pipeline(websocket, config: EngineConfig) -> None:
    """Serve one websocket client using the product JSON wire protocol.

    The ``ProductWebsocketSerializer`` is configured with the engine's audio
    sample rate and channel count, then handed to
    ``run_pipeline_with_serializer`` under the log label ``"Product JSON"``.
    """
    audio_cfg = config.audio
    product_serializer = ProductWebsocketSerializer(
        sample_rate=audio_cfg.sample_rate_hz,
        channels=audio_cfg.channels,
    )
    await run_pipeline_with_serializer(
        websocket,
        config,
        client_label="Product JSON",
        serializer=product_serializer,
    )
async def run_pipeline_with_serializer(
    websocket,
    config: EngineConfig,
    *,
    serializer: FrameSerializer,
    client_label: str,
) -> None:
    """Build and run a voice-agent pipeline for one websocket client.

    Frame flow::

        transport.input() -> ProductTextInputProcessor -> STT
            -> user aggregator -> LLM -> ProductTextStreamProcessor -> TTS
            -> transport.output() -> assistant aggregator

    Args:
        websocket: Connection object handed to ``FastAPIWebsocketTransport``
            (presumably an accepted FastAPI ``WebSocket`` -- confirm at caller).
        config: Engine configuration; supplies the audio format, the
            STT/LLM/TTS service settings, the agent system prompt and
            greeting, and the session inactivity timeout.
        serializer: Wire-format serializer for frames on the websocket.
        client_label: Client description used only in log messages.

    The coroutine returns when ``PipelineRunner.run`` completes; the transport
    handlers cancel the task on client disconnect or session timeout.
    """
    transport = _create_transport(websocket, config, serializer)

    stt = create_stt_service(config.services.stt)
    llm = create_llm_service(config.services.llm)
    tts = create_tts_service(config.services.tts, config.audio)

    context = LLMContext(_initial_context_messages(config))
    user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
        context,
        user_params=LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer()),
    )

    pipeline = Pipeline(
        [
            transport.input(),
            ProductTextInputProcessor(),
            stt,
            user_aggregator,
            llm,
            # Emits assistant text events upstream of TTS so text reaches the
            # client ahead of the synthesized audio.
            ProductTextStreamProcessor(),
            tts,
            transport.output(),
            assistant_aggregator,
        ]
    )

    task = PipelineTask(
        pipeline,
        params=PipelineParams(
            audio_in_sample_rate=config.audio.sample_rate_hz,
            audio_out_sample_rate=config.audio.sample_rate_hz,
            enable_metrics=True,
            enable_usage_metrics=True,
            enable_heartbeats=True,
        ),
        idle_timeout_secs=config.session.inactivity_timeout_sec,
    )

    _register_transport_handlers(transport, task, config, client_label)
    _register_turn_handlers(user_aggregator, assistant_aggregator, task)

    # handle_sigint=False: signal handling is left to the hosting process
    # rather than this per-connection runner.
    runner = PipelineRunner(handle_sigint=False)
    await runner.run(task)


def _create_transport(websocket, config: EngineConfig, serializer: FrameSerializer):
    """Create the websocket transport with matching input/output audio format."""
    return FastAPIWebsocketTransport(
        websocket=websocket,
        params=FastAPIWebsocketParams(
            audio_in_enabled=True,
            audio_out_enabled=True,
            audio_in_sample_rate=config.audio.sample_rate_hz,
            audio_out_sample_rate=config.audio.sample_rate_hz,
            audio_in_channels=config.audio.channels,
            audio_out_channels=config.audio.channels,
            serializer=serializer,
            # No transport-level session timeout; inactivity is instead bounded
            # by the PipelineTask's idle_timeout_secs.
            session_timeout=None,
        ),
    )


def _initial_context_messages(config: EngineConfig) -> list[dict]:
    """Return the seed messages for the LLM context.

    The system prompt always comes first. In ``"generated"`` greeting mode a
    non-empty greeting is appended as a second developer message; it is the
    instruction the LLM answers when the ``LLMRunFrame`` queued on client
    connect triggers the first completion.
    """
    messages = [{"role": "developer", "content": config.agent.system_prompt}]
    if config.agent.greeting and config.agent.greeting_mode == "generated":
        messages.append({"role": "developer", "content": config.agent.greeting})
    return messages


def _register_transport_handlers(
    transport,
    task: PipelineTask,
    config: EngineConfig,
    client_label: str,
) -> None:
    """Attach connect / disconnect / session-timeout handlers to ``transport``."""

    @transport.event_handler("on_client_connected")
    async def on_client_connected(_transport, _client):
        logger.info(f"{client_label} websocket client connected")
        if config.agent.greeting_mode == "fixed" and config.agent.greeting:
            # Speak the configured greeting verbatim, bypassing the LLM.
            await task.queue_frames([TTSSpeakFrame(config.agent.greeting)])
        elif config.agent.greeting_mode == "generated":
            # Let the LLM produce the opening turn from the seeded context.
            # NOTE(review): this also fires when the greeting is empty, in which
            # case the context holds only the system prompt -- confirm intended.
            await task.queue_frames([LLMRunFrame()])

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(_transport, _client):
        logger.info(f"{client_label} websocket client disconnected")
        await task.cancel()

    # NOTE(review): _create_transport passes session_timeout=None, so this
    # handler is presumably inactive; kept in case a timeout is configured.
    @transport.event_handler("on_session_timeout")
    async def on_session_timeout(_transport, _client):
        logger.info(f"{client_label} websocket session timed out")
        await task.cancel()


def _register_turn_handlers(user_aggregator, assistant_aggregator, task: PipelineTask) -> None:
    """Forward final user transcripts to the client and log both sides of a turn."""

    @user_aggregator.event_handler("on_user_turn_stopped")
    async def on_user_turn_stopped(_aggregator, _strategy, message: UserTurnStoppedMessage):
        logger.info(f"User: {message.content}")
        text = (message.content or "").strip()
        if not text:
            # Empty or whitespace-only turns are not sent to the client.
            return
        await task.queue_frame(
            OutputTransportMessageUrgentFrame(
                message={
                    "type": "input.transcript.final",
                    "text": text,
                    "user_id": message.user_id,
                    "timestamp": message.timestamp,
                }
            )
        )

    # NOTE: assistant turn started/final events are emitted by
    # ProductTextStreamProcessor, upstream of TTS, so text streams to the
    # client ahead of audio. This logger is kept for server-side visibility.
    @assistant_aggregator.event_handler("on_assistant_turn_stopped")
    async def on_assistant_turn_stopped(_aggregator, message: AssistantTurnStoppedMessage):
        logger.info(f"Assistant: {message.content}")