141 lines
4.9 KiB
Python
141 lines
4.9 KiB
Python
#
|
|
# Copyright (c) 2025, Daily
|
|
#
|
|
# SPDX-License-Identifier: BSD 2-Clause License
|
|
#
|
|
|
|
import datetime
|
|
import io
|
|
import os
|
|
import sys
|
|
import wave
|
|
|
|
import aiofiles
|
|
from dotenv import load_dotenv
|
|
from fastapi import WebSocket
|
|
from loguru import logger
|
|
|
|
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
|
from pipecat.pipeline.pipeline import Pipeline
|
|
from pipecat.pipeline.runner import PipelineRunner
|
|
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
|
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
|
from pipecat.processors.audio.audio_buffer_processor import AudioBufferProcessor
|
|
from pipecat.serializers.twilio import TwilioFrameSerializer
|
|
from pipecat.services.cartesia import CartesiaTTSService
|
|
from pipecat.services.deepgram import DeepgramSTTService
|
|
from pipecat.services.openai import OpenAILLMService
|
|
from pipecat.transports.network.fastapi_websocket import (
|
|
FastAPIWebsocketParams,
|
|
FastAPIWebsocketTransport,
|
|
)
|
|
|
|
# Load variables from a local .env file; override=True lets .env values replace
# variables already present in the process environment.
load_dotenv(override=True)

# Replace loguru's default sink (handler id 0) with a DEBUG-level stderr sink.
# Handler 0 is absent when LOGURU_AUTOINIT is disabled or when another module
# already removed it; `logger.remove(0)` raises ValueError in that case, which
# would abort importing this module. Treat it as "nothing to remove".
try:
    logger.remove(0)
except ValueError:
    pass
logger.add(sys.stderr, level="DEBUG")
|
|
|
|
|
|
async def save_audio(server_name: str, audio: bytes, sample_rate: int, num_channels: int):
    """Save raw PCM ``audio`` as a timestamped WAV file in the current directory.

    The output file is named ``<server_name>_recording_<YYYYmmdd_HHMMSS>.wav``,
    using the local wall-clock time. If ``audio`` is empty, nothing is written
    and an informational message is logged instead.

    Args:
        server_name: Prefix for the output filename.
        audio: Raw PCM frames, written with a sample width of 2 bytes.
        sample_rate: Frame rate recorded in the WAV header.
        num_channels: Channel count recorded in the WAV header.
    """
    if not audio:
        logger.info("No audio data to save")
        return

    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"{server_name}_recording_{stamp}.wav"

    # Build the complete WAV container in memory, then hand it to aiofiles in a
    # single write so the event loop is not blocked on disk I/O.
    with io.BytesIO() as wav_buffer:
        with wave.open(wav_buffer, "wb") as writer:
            writer.setsampwidth(2)
            writer.setnchannels(num_channels)
            writer.setframerate(sample_rate)
            writer.writeframes(audio)
        payload = wav_buffer.getvalue()

    async with aiofiles.open(filename, "wb") as out_file:
        await out_file.write(payload)
    logger.info(f"Merged audio saved to {filename}")
|
|
|
|
|
|
async def run_bot(websocket_client: WebSocket, stream_sid: str, testing: bool):
    """Run one voice-bot session over a Twilio media-stream websocket.

    Wires a FastAPI websocket transport (Twilio frame serialization, Silero VAD)
    into an STT -> LLM -> TTS pipeline (Deepgram -> OpenAI ``gpt-4o`` -> Cartesia).
    The call is recorded through an ``AudioBufferProcessor`` and saved with
    ``save_audio`` whenever the buffer emits ``on_audio_data``. The function
    returns once the pipeline runner finishes; a client disconnect cancels the task.

    Args:
        websocket_client: The websocket connection opened by Twilio.
        stream_sid: Twilio media stream SID handed to ``TwilioFrameSerializer``.
        testing: Test-mode switch. When True, the TTS service is created with
            ``push_silence_after_stop=True`` and the audio buffer with
            ``user_continuous_stream=False``.
    """
    transport = FastAPIWebsocketTransport(
        websocket=websocket_client,
        params=FastAPIWebsocketParams(
            audio_in_enabled=True,
            audio_out_enabled=True,
            add_wav_header=False,
            vad_enabled=True,
            vad_analyzer=SileroVADAnalyzer(),
            vad_audio_passthrough=True,
            serializer=TwilioFrameSerializer(stream_sid),
        ),
    )

    # NOTE(review): os.getenv returns None for unset keys and that None is passed
    # through unchanged; presumably a missing key only surfaces once the service
    # is used. Confirm whether a fail-fast check is wanted here.
    llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")

    stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"), audio_passthrough=True)

    tts = CartesiaTTSService(
        api_key=os.getenv("CARTESIA_API_KEY"),
        voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",  # British Lady
        push_silence_after_stop=testing,
    )

    messages = [
        {
            "role": "system",
            "content": "You are an elementary teacher in an audio call. Your output will be converted to audio so don't include special characters in your answers. Respond to what the student said in a short short sentence.",
        },
    ]

    context = OpenAILLMContext(messages)
    context_aggregator = llm.create_context_aggregator(context)

    # NOTE: Watch out! This will save all the conversation in memory. You can
    # pass `buffer_size` to get periodic callbacks.
    audiobuffer = AudioBufferProcessor(user_continuous_stream=not testing)

    pipeline = Pipeline(
        [
            transport.input(),  # Websocket input from client
            stt,  # Speech-To-Text
            context_aggregator.user(),
            llm,  # LLM
            tts,  # Text-To-Speech
            transport.output(),  # Websocket output to client
            audiobuffer,  # Used to buffer the audio in the pipeline
            context_aggregator.assistant(),
        ]
    )

    # 8 kHz in both directions, matching the audio rate of Twilio media streams.
    task = PipelineTask(
        pipeline,
        params=PipelineParams(
            audio_in_sample_rate=8000,
            audio_out_sample_rate=8000,
            allow_interruptions=True,
        ),
    )

    @transport.event_handler("on_client_connected")
    async def on_client_connected(transport, client):
        # Start recording.
        await audiobuffer.start_recording()
        # Kick off the conversation. The extra system message only reaches the LLM
        # if the context shares this `messages` list — presumably OpenAILLMContext
        # keeps a reference rather than a copy; confirm.
        messages.append({"role": "system", "content": "Please introduce yourself to the user."})
        await task.queue_frames([context_aggregator.user().get_context_frame()])

    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        await task.cancel()

    @audiobuffer.event_handler("on_audio_data")
    async def on_audio_data(buffer, audio, sample_rate, num_channels):
        # Recordings are prefixed with the client's port. Starlette reports
        # `websocket.client` as None when the ASGI scope carries no client
        # address; fall back to a fixed label so the recording is still saved
        # instead of raising AttributeError inside this handler.
        client_address = websocket_client.client
        port = client_address.port if client_address is not None else "unknown"
        server_name = f"server_{port}"
        await save_audio(server_name, audio, sample_rate, num_channels)

    # We use `handle_sigint=False` because `uvicorn` is controlling keyboard
    # interruptions. We use `force_gc=True` to force garbage collection after
    # the runner finishes running a task which could be useful for long running
    # applications with multiple clients connecting.
    runner = PipelineRunner(handle_sigint=False, force_gc=True)

    await runner.run(task)
|