161 lines
5.6 KiB
Python
161 lines
5.6 KiB
Python
from __future__ import annotations
|
|
|
|
from loguru import logger
|
|
|
|
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
|
from pipecat.frames.frames import (
|
|
LLMRunFrame,
|
|
OutputTransportMessageUrgentFrame,
|
|
TTSSpeakFrame,
|
|
)
|
|
from pipecat.pipeline.pipeline import Pipeline
|
|
from pipecat.pipeline.runner import PipelineRunner
|
|
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
|
from pipecat.processors.aggregators.llm_context import LLMContext
|
|
from pipecat.processors.aggregators.llm_response_universal import (
|
|
AssistantTurnStoppedMessage,
|
|
LLMContextAggregatorPair,
|
|
LLMUserAggregatorParams,
|
|
UserTurnStoppedMessage,
|
|
)
|
|
from pipecat.serializers.protobuf import ProtobufFrameSerializer
|
|
from pipecat.serializers.base_serializer import FrameSerializer
|
|
from pipecat.transports.websocket.fastapi import (
|
|
FastAPIWebsocketParams,
|
|
FastAPIWebsocketTransport,
|
|
)
|
|
|
|
from .config import EngineConfig
|
|
from .product_protocol import ProductWebsocketSerializer
|
|
from .services import create_llm_service, create_stt_service, create_tts_service
|
|
from .text_input import ProductTextInputProcessor
|
|
from .text_stream import ProductTextStreamProcessor
|
|
|
|
|
|
async def run_voice_pipeline(websocket, config: EngineConfig) -> None:
    """Run the voice pipeline for a client speaking Pipecat protobuf frames.

    Delegates to ``run_pipeline_with_serializer`` with a freshly constructed
    ``ProtobufFrameSerializer`` and the log label ``"Pipecat protobuf"``.

    Args:
        websocket: Connection forwarded unchanged to the shared runner.
        config: Engine configuration forwarded unchanged to the shared runner.
    """
    protobuf_serializer = ProtobufFrameSerializer()
    await run_pipeline_with_serializer(
        websocket,
        config,
        client_label="Pipecat protobuf",
        serializer=protobuf_serializer,
    )
|
|
|
|
|
|
async def run_product_voice_pipeline(websocket, config: EngineConfig) -> None:
    """Run the voice pipeline for a client speaking the product JSON protocol.

    Builds a ``ProductWebsocketSerializer`` from the sample rate and channel
    count in ``config.audio`` and delegates to
    ``run_pipeline_with_serializer`` with the log label ``"Product JSON"``.

    Args:
        websocket: Connection forwarded unchanged to the shared runner.
        config: Engine configuration; its ``audio`` section configures the
            serializer, and the whole object is forwarded to the shared runner.
    """
    audio = config.audio
    product_serializer = ProductWebsocketSerializer(
        sample_rate=audio.sample_rate_hz,
        channels=audio.channels,
    )
    await run_pipeline_with_serializer(
        websocket,
        config,
        client_label="Product JSON",
        serializer=product_serializer,
    )
|
|
|
|
|
|
async def run_pipeline_with_serializer(
    websocket,
    config: EngineConfig,
    *,
    serializer: FrameSerializer,
    client_label: str,
) -> None:
    """Assemble the STT -> LLM -> TTS pipeline on ``websocket`` and run it.

    Args:
        websocket: Connection handed to ``FastAPIWebsocketTransport``.
        config: Engine configuration; the ``audio``, ``services``, ``agent``
            and ``session`` sections are read here.
        serializer: Frame serializer installed on the websocket transport.
        client_label: Prefix for the connect / disconnect / timeout log lines.

    Returns:
        None, once ``PipelineRunner.run`` has completed for the pipeline task.
    """
    # One sample rate / channel count is shared by input, output and the task.
    sample_rate = config.audio.sample_rate_hz
    channel_count = config.audio.channels

    transport_params = FastAPIWebsocketParams(
        audio_in_enabled=True,
        audio_out_enabled=True,
        audio_in_sample_rate=sample_rate,
        audio_out_sample_rate=sample_rate,
        audio_in_channels=channel_count,
        audio_out_channels=channel_count,
        serializer=serializer,
        session_timeout=None,
    )
    transport = FastAPIWebsocketTransport(websocket=websocket, params=transport_params)

    stt = create_stt_service(config.services.stt)
    llm = create_llm_service(config.services.llm)
    tts = create_tts_service(config.services.tts, config.audio)

    # Seed the context with the system prompt; in "generated" greeting mode the
    # configured greeting text is added as a second developer message.
    seed_prompts = [config.agent.system_prompt]
    if config.agent.greeting_mode == "generated" and config.agent.greeting:
        seed_prompts.append(config.agent.greeting)
    context = LLMContext(
        [{"role": "developer", "content": prompt} for prompt in seed_prompts]
    )

    user_turn_params = LLMUserAggregatorParams(vad_analyzer=SileroVADAnalyzer())
    user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
        context,
        user_params=user_turn_params,
    )

    stages = [
        transport.input(),
        ProductTextInputProcessor(),
        stt,
        user_aggregator,
        llm,
        ProductTextStreamProcessor(),
        tts,
        transport.output(),
        assistant_aggregator,
    ]
    pipeline = Pipeline(stages)

    task_params = PipelineParams(
        audio_in_sample_rate=sample_rate,
        audio_out_sample_rate=sample_rate,
        enable_metrics=True,
        enable_usage_metrics=True,
        enable_heartbeats=True,
    )
    task = PipelineTask(
        pipeline,
        params=task_params,
        idle_timeout_secs=config.session.inactivity_timeout_sec,
    )

    @transport.event_handler("on_client_connected")
    async def _greet_on_connect(_transport, _client):
        """Log the connection and kick off the configured greeting, if any."""
        logger.info(f"{client_label} websocket client connected")
        greeting_mode = config.agent.greeting_mode
        if greeting_mode == "generated":
            # Let the LLM produce the opening line from the seeded context.
            await task.queue_frames([LLMRunFrame()])
        elif greeting_mode == "fixed" and config.agent.greeting:
            # Speak the configured greeting text directly through TTS.
            await task.queue_frames([TTSSpeakFrame(config.agent.greeting)])

    @transport.event_handler("on_client_disconnected")
    async def _cancel_on_disconnect(_transport, _client):
        """Log the disconnect and cancel the pipeline task."""
        logger.info(f"{client_label} websocket client disconnected")
        await task.cancel()

    # NOTE(review): the transport above is built with session_timeout=None, so
    # this handler presumably never fires -- confirm against the transport docs.
    @transport.event_handler("on_session_timeout")
    async def _cancel_on_session_timeout(_transport, _client):
        """Log the session timeout and cancel the pipeline task."""
        logger.info(f"{client_label} websocket session timed out")
        await task.cancel()

    @user_aggregator.event_handler("on_user_turn_stopped")
    async def _publish_user_transcript(
        _aggregator, _strategy, message: UserTurnStoppedMessage
    ):
        """Log the user turn and send its non-empty transcript to the client."""
        logger.info(f"User: {message.content}")
        transcript = (message.content or "").strip()
        if transcript:
            payload = {
                "type": "input.transcript.final",
                "text": transcript,
                "user_id": message.user_id,
                "timestamp": message.timestamp,
            }
            await task.queue_frame(OutputTransportMessageUrgentFrame(message=payload))

    # NOTE: the assistant turn started/final events are emitted by
    # ProductTextStreamProcessor, which sits upstream of TTS, so text reaches
    # the client ahead of audio. This handler only logs for server-side
    # visibility.
    @assistant_aggregator.event_handler("on_assistant_turn_stopped")
    async def _log_assistant_turn(_aggregator, message: AssistantTurnStoppedMessage):
        """Log the completed assistant turn."""
        logger.info(f"Assistant: {message.content}")

    await PipelineRunner(handle_sigint=False).run(task)
|