#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import os
import sys

import aiohttp
from dotenv import load_dotenv
from loguru import logger
from runner import configure

from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.groq import GroqLLMService, GroqSTTService, GroqTTSService
from pipecat.transports.services.daily import DailyParams, DailyTransport

load_dotenv(override=True)

# Replace loguru's default handler with a DEBUG-level stderr sink.
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")


async def main():
    """Run an interruptible voice bot in a Daily room using Groq for STT, LLM and TTS.

    The room URL and token come from ``runner.configure``; the Groq API key is
    read from the ``GROQ_API_KEY`` environment variable. The pipeline runs until
    the runner returns, and the task is cancelled when a participant leaves.
    """
    async with aiohttp.ClientSession() as session:
        room_url, token = await configure(session)

        daily_params = DailyParams(
            audio_out_enabled=True,
            # transcription_enabled=True,
            vad_enabled=True,
            vad_analyzer=SileroVADAnalyzer(),
            vad_audio_passthrough=True,
        )
        transport = DailyTransport(room_url, token, "Respond bot", daily_params)

        # All three Groq services share the same key.
        groq_api_key = os.getenv("GROQ_API_KEY")
        stt = GroqSTTService(api_key=groq_api_key)
        llm = GroqLLMService(api_key=groq_api_key, model="llama-3.3-70b-versatile")
        tts = GroqTTSService(api_key=groq_api_key, voice_id="Atlas-PlayAI")

        system_prompt = {
            "role": "system",
            "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
        }
        # This list object is handed to the context and appended to later by
        # the join handler, so it must stay the same object.
        messages = [system_prompt]

        context = OpenAILLMContext(messages)
        context_aggregator = llm.create_context_aggregator(context)

        stages = [
            transport.input(),  # Transport user input
            stt,
            context_aggregator.user(),  # User responses
            llm,  # LLM
            tts,  # TTS
            transport.output(),  # Transport bot output
            context_aggregator.assistant(),  # Assistant spoken responses
        ]
        pipeline = Pipeline(stages)

        task_params = PipelineParams(
            allow_interruptions=True,
            enable_metrics=True,
            enable_usage_metrics=True,
        )
        task = PipelineTask(pipeline, params=task_params)

        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant):
            await transport.capture_participant_transcription(participant["id"])
            # Kick off the conversation by asking the LLM to introduce itself.
            intro_request = {"role": "system", "content": "Please introduce yourself to the user."}
            messages.append(intro_request)
            await task.queue_frames([context_aggregator.user().get_context_frame()])

        @transport.event_handler("on_participant_left")
        async def on_participant_left(transport, participant, reason):
            # End the session as soon as a participant leaves.
            await task.cancel()

        runner = PipelineRunner()
        await runner.run(task)


if __name__ == "__main__":
    asyncio.run(main())
class GroqTTSService(InterruptibleTTSService):
    """Text-to-speech service that synthesizes audio through the Groq speech API.

    Each ``run_tts`` call issues one ``audio.speech.create`` request and streams
    the response back as ``TTSAudioRawFrame`` objects, bracketed by
    ``TTSStartedFrame`` and ``TTSStoppedFrame``.

    NOTE(review): the service talks plain HTTP through ``AsyncGroq``; the
    websocket/connection hooks at the bottom of the class are no-ops,
    presumably only present to satisfy the base class -- confirm whether a
    non-websocket base class would be a better fit.
    """

    class InputParams(BaseModel):
        """Optional synthesis settings.

        NOTE(review): these values are stored on the service but are not
        forwarded to the API request by ``run_tts`` -- confirm which of them
        the Groq endpoint accepts before wiring them through.
        """

        language: Optional[Language] = Language.EN
        speed: Optional[float] = 1.0
        seed: Optional[int] = None

    # Audio format assumed when the stream carries no parsable WAV ``fmt `` chunk.
    GROQ_SAMPLE_RATE = 48000
    GROQ_NUM_CHANNELS = 1

    # Size of the chunks requested from the HTTP response stream.
    _GROQ_CHUNK_SIZE = 4096
    # Upper bound on how much data is buffered while looking for the WAV
    # ``data`` chunk before giving up and assuming a canonical header.
    _GROQ_MAX_HEADER_BYTES = 65536
    # Size of a canonical PCM WAV header (RIFF + fmt + data chunk headers).
    _GROQ_CANONICAL_HEADER_BYTES = 44

    def __init__(
        self,
        *,
        api_key: str,
        output_format: str = "wav",
        params: Optional[InputParams] = None,
        model_name: str = "playai-tts",
        voice_id: str = "Atlas-PlayAI",
        **kwargs,
    ):
        """Create the service and its Groq client.

        Args:
            api_key: Groq API key used to build the ``AsyncGroq`` client.
            output_format: Value sent as ``response_format`` in the request.
            params: Optional synthesis settings; a fresh ``InputParams`` is
                created per instance when omitted.
            model_name: Value sent as ``model`` in the request.
            voice_id: Value sent as ``voice`` in the request.
            **kwargs: Forwarded to the base class constructor.
        """
        super().__init__(
            pause_frame_processing=True,
            **kwargs,
        )

        self._api_key = api_key
        self._model_name = model_name
        self._output_format = output_format
        self._voice_id = voice_id
        # Build the default here rather than in the signature so instances do
        # not share one mutable model object.
        self._params = params if params is not None else self.InputParams()

        self._client = AsyncGroq(api_key=self._api_key)

    def can_generate_metrics(self) -> bool:
        """Report that this service produces metrics."""
        return True

    @staticmethod
    def _locate_wav_audio(buf: bytes) -> Optional[tuple]:
        """Find where PCM audio starts in the beginning of a byte stream.

        Args:
            buf: All bytes received so far from the start of the stream.

        Returns:
            ``None`` when ``buf`` looks like a RIFF file but the ``data`` chunk
            has not been reached yet (more bytes are needed). Otherwise a tuple
            ``(offset, sample_rate, num_channels)`` where ``offset`` is the
            index of the first audio byte; ``sample_rate`` and ``num_channels``
            are ``None`` when no ``fmt `` chunk was seen. A stream that does
            not start with ``RIFF`` yields ``(0, None, None)``.
        """
        import struct

        # Not (a prefix of) a RIFF header: treat the whole stream as audio.
        if not b"RIFF".startswith(buf[:4]):
            return 0, None, None

        sample_rate = None
        num_channels = None
        # Skip "RIFF" + 4-byte size + 4-byte form type, then walk the chunks.
        pos = 12
        while pos + 8 <= len(buf):
            chunk_id = buf[pos : pos + 4]
            (chunk_size,) = struct.unpack_from("<I", buf, pos + 4)
            body = pos + 8
            if chunk_id == b"data":
                return body, sample_rate, num_channels
            if chunk_id == b"fmt ":
                if body + 8 > len(buf):
                    return None
                _format_tag, num_channels, sample_rate = struct.unpack_from("<HHI", buf, body)
            # RIFF chunks are padded to an even number of bytes.
            pos = body + chunk_size + (chunk_size & 1)
        return None

    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        """Synthesize ``text`` and stream the resulting audio frames.

        Only the WAV header at the very start of the stream is removed; every
        audio byte after it is forwarded. The sample rate and channel count
        come from the header when available and otherwise fall back to
        ``GROQ_SAMPLE_RATE`` / ``GROQ_NUM_CHANNELS``.

        Args:
            text: The text to synthesize.

        Yields:
            ``TTSStartedFrame``, then one ``TTSAudioRawFrame`` per non-empty
            audio chunk, then ``TTSStoppedFrame``.
        """
        logger.debug(f"{self}: Generating TTS [{text}]")
        measuring_ttfb = True
        await self.start_ttfb_metrics()
        yield TTSStartedFrame()

        response = await self._client.audio.speech.create(
            model=self._model_name,
            voice=self._voice_id,
            response_format=self._output_format,
            input=text,
        )

        sample_rate = self.GROQ_SAMPLE_RATE
        num_channels = self.GROQ_NUM_CHANNELS
        pending = b""  # bytes held back until the header has been resolved
        header_resolved = False

        async for data in response.iter_bytes(self._GROQ_CHUNK_SIZE):
            if measuring_ttfb:
                await self.stop_ttfb_metrics()
                measuring_ttfb = False

            if not header_resolved:
                pending += data
                located = self._locate_wav_audio(pending)
                if located is None:
                    if len(pending) < self._GROQ_MAX_HEADER_BYTES:
                        # Header continues into the next chunk.
                        continue
                    logger.warning(
                        f"{self}: WAV data chunk not found, assuming a canonical header"
                    )
                    located = (self._GROQ_CANONICAL_HEADER_BYTES, None, None)
                offset, header_rate, header_channels = located
                sample_rate = header_rate or sample_rate
                num_channels = header_channels or num_channels
                # Keep the audio that shares a chunk with the header.
                data = pending[offset:]
                pending = b""
                header_resolved = True

            if data:
                yield TTSAudioRawFrame(data, sample_rate, num_channels)

        yield TTSStoppedFrame()

    async def _connect(self) -> None:
        """No-op: this service opens no persistent connection."""
        pass

    async def _disconnect(self) -> None:
        """No-op counterpart of ``_connect``."""
        pass

    async def _connect_websocket(self) -> None:
        """No-op: this service does not use a websocket."""
        pass

    async def _disconnect_websocket(self) -> None:
        """No-op counterpart of ``_connect_websocket``."""
        pass

    async def _receive_messages(self) -> None:
        """No-op: there is no websocket message stream to read."""
        pass