This commit is contained in:
Kwindla Hultman Kramer
2025-03-25 18:47:01 -07:00
committed by Mark Backman
parent 59fdfe697d
commit 060bb4c26b
3 changed files with 179 additions and 2 deletions

View File

@@ -0,0 +1,101 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import os
import sys
import aiohttp
from dotenv import load_dotenv
from loguru import logger
from runner import configure
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.groq import GroqLLMService, GroqSTTService, GroqTTSService
from pipecat.transports.services.daily import DailyParams, DailyTransport
load_dotenv(override=True)
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
async def main():
async with aiohttp.ClientSession() as session:
(room_url, token) = await configure(session)
transport = DailyTransport(
room_url,
token,
"Respond bot",
DailyParams(
audio_out_enabled=True,
# transcription_enabled=True,
vad_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
vad_audio_passthrough=True,
),
)
stt = GroqSTTService(api_key=os.getenv("GROQ_API_KEY"))
llm = GroqLLMService(api_key=os.getenv("GROQ_API_KEY"), model="llama-3.3-70b-versatile")
tts = GroqTTSService(api_key=os.getenv("GROQ_API_KEY"), voice_id="Atlas-PlayAI")
messages = [
{
"role": "system",
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
},
]
context = OpenAILLMContext(messages)
context_aggregator = llm.create_context_aggregator(context)
pipeline = Pipeline(
[
transport.input(), # Transport user input
stt,
context_aggregator.user(), # User responses
llm, # LLM
tts, # TTS
transport.output(), # Transport bot output
context_aggregator.assistant(), # Assistant spoken responses
]
)
task = PipelineTask(
pipeline,
params=PipelineParams(
allow_interruptions=True,
enable_metrics=True,
enable_usage_metrics=True,
),
)
@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
await transport.capture_participant_transcription(participant["id"])
# Kick off the conversation.
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
await task.queue_frames([context_aggregator.user().get_context_frame()])
@transport.event_handler("on_participant_left")
async def on_participant_left(transport, participant, reason):
await task.cancel()
runner = PipelineRunner()
await runner.run(task)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -56,7 +56,7 @@ fish = [ "ormsgpack~=1.7.0", "websockets~=13.1" ]
gladia = [ "websockets~=13.1" ]
google = [ "google-cloud-speech~=2.31.1", "google-cloud-texttospeech~=2.25.1", "google-genai~=1.7.0", "google-generativeai~=0.8.4" ]
grok = []
groq = []
groq = [ "groq~=0.20.0" ]
gstreamer = [ "pygobject~=3.50.0" ]
fireworks = []
krisp = [ "pipecat-ai-krisp~=0.3.0" ]

View File

@@ -5,10 +5,14 @@
#
from typing import Optional
from typing import AsyncGenerator, Optional
from groq import AsyncGroq
from loguru import logger
from pydantic import BaseModel
from pipecat.frames.frames import Frame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame
from pipecat.services.ai_services import InterruptibleTTSService
from pipecat.services.base_whisper import BaseWhisperSTTService, Transcription
from pipecat.services.openai import OpenAILLMService
from pipecat.transcriptions.language import Language
@@ -98,3 +102,75 @@ class GroqSTTService(BaseWhisperSTTService):
kwargs["temperature"] = self._temperature
return await self._client.audio.transcriptions.create(**kwargs)
class GroqTTSService(InterruptibleTTSService):
class InputParams(BaseModel):
language: Optional[Language] = Language.EN
speed: Optional[float] = 1.0
seed: Optional[int] = None
def __init__(
self,
*,
api_key: str,
output_format: str = "wav",
params: InputParams = InputParams(),
model_name: str = "playai-tts",
voice_id: str = "Atlas-PlayAI",
**kwargs,
):
super().__init__(
pause_frame_processing=True,
**kwargs,
)
self._api_key = api_key
self._model_name = model_name
self._output_format = output_format
self._voice_id = voice_id
self._params = params
self._client = AsyncGroq(api_key=self._api_key)
def can_generate_metrics(self) -> bool:
return True
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"{self}: Generating TTS [{text}]")
measuring_ttfb = True
await self.start_ttfb_metrics()
yield TTSStartedFrame()
response = await self._client.audio.speech.create(
model=self._model_name,
voice=self._voice_id,
response_format=self._output_format,
input=text,
)
async for data in response.iter_bytes(4096):
if measuring_ttfb:
await self.stop_ttfb_metrics()
measuring_ttfb = False
# remove wav header if present
if data.startswith(b"RIFF"):
continue
yield TTSAudioRawFrame(data, 48000, 1)
yield TTSStoppedFrame()
async def _connect(self) -> None:
pass
async def _disconnect(self) -> None:
pass
async def _connect_websocket(self) -> None:
pass
async def _disconnect_websocket(self) -> None:
pass
async def _receive_messages(self) -> None:
pass