wip
This commit is contained in:
committed by
Mark Backman
parent
59fdfe697d
commit
060bb4c26b
101
examples/foundational/07y-interruptible-groq.py
Normal file
101
examples/foundational/07y-interruptible-groq.py
Normal file
@@ -0,0 +1,101 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.groq import GroqLLMService, GroqSTTService, GroqTTSService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
|
||||
# Pull configuration from a local .env file; override=True is passed so values
# from the file take precedence over variables already set in the environment.
load_dotenv(override=True)

# Replace loguru's pre-configured handler (id 0) with a stderr sink at DEBUG level.
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def main():
    """Run an interruptible Daily voice bot whose STT, LLM and TTS all use Groq.

    The bot joins the room returned by ``configure``, asks the LLM to introduce
    itself when the first participant joins, and cancels the pipeline task when
    a participant leaves.
    """
    # All three Groq services authenticate with the same key.
    groq_api_key = os.getenv("GROQ_API_KEY")

    async with aiohttp.ClientSession() as http_session:
        room_url, token = await configure(http_session)

        # Daily's own transcription is not enabled here; speech-to-text is done
        # by the GroqSTTService placed in the pipeline below.
        daily_params = DailyParams(
            audio_out_enabled=True,
            vad_enabled=True,
            vad_analyzer=SileroVADAnalyzer(),
            vad_audio_passthrough=True,
        )
        transport = DailyTransport(room_url, token, "Respond bot", daily_params)

        stt = GroqSTTService(api_key=groq_api_key)
        llm = GroqLLMService(api_key=groq_api_key, model="llama-3.3-70b-versatile")
        tts = GroqTTSService(api_key=groq_api_key, voice_id="Atlas-PlayAI")

        system_prompt = (
            "You are a helpful LLM in a WebRTC call. "
            "Your goal is to demonstrate your capabilities in a succinct way. "
            "Your output will be converted to audio so don't include special "
            "characters in your answers. "
            "Respond to what the user said in a creative and helpful way."
        )
        # This same list object is handed to the context and appended to later
        # by the join handler, so it must not be copied.
        messages = [{"role": "system", "content": system_prompt}]

        context = OpenAILLMContext(messages)
        context_aggregator = llm.create_context_aggregator(context)

        processors = [
            transport.input(),  # audio from the participant
            stt,  # speech -> text
            context_aggregator.user(),  # collect user turns into the context
            llm,  # text -> response
            tts,  # response -> speech
            transport.output(),  # audio back to the room
            context_aggregator.assistant(),  # record what the bot actually said
        ]
        pipeline = Pipeline(processors)

        task_params = PipelineParams(
            allow_interruptions=True,
            enable_metrics=True,
            enable_usage_metrics=True,
        )
        task = PipelineTask(pipeline, params=task_params)

        @transport.event_handler("on_first_participant_joined")
        async def on_first_participant_joined(transport, participant):
            await transport.capture_participant_transcription(participant["id"])
            # Seed the conversation so the bot speaks first.
            messages.append({"role": "system", "content": "Please introduce yourself to the user."})
            await task.queue_frames([context_aggregator.user().get_context_frame()])

        @transport.event_handler("on_participant_left")
        async def on_participant_left(transport, participant, reason):
            # The demo ends as soon as a participant leaves.
            await task.cancel()

        await PipelineRunner().run(task)


if __name__ == "__main__":
    asyncio.run(main())
|
||||
@@ -56,7 +56,7 @@ fish = [ "ormsgpack~=1.7.0", "websockets~=13.1" ]
|
||||
gladia = [ "websockets~=13.1" ]
|
||||
google = [ "google-cloud-speech~=2.31.1", "google-cloud-texttospeech~=2.25.1", "google-genai~=1.7.0", "google-generativeai~=0.8.4" ]
|
||||
grok = []
|
||||
groq = []
|
||||
groq = [ "groq~=0.20.0" ]
|
||||
gstreamer = [ "pygobject~=3.50.0" ]
|
||||
fireworks = []
|
||||
krisp = [ "pipecat-ai-krisp~=0.3.0" ]
|
||||
|
||||
@@ -5,10 +5,14 @@
|
||||
#
|
||||
|
||||
|
||||
from typing import Optional
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from groq import AsyncGroq
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import Frame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame
|
||||
from pipecat.services.ai_services import InterruptibleTTSService
|
||||
from pipecat.services.base_whisper import BaseWhisperSTTService, Transcription
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.transcriptions.language import Language
|
||||
@@ -98,3 +102,75 @@ class GroqSTTService(BaseWhisperSTTService):
|
||||
kwargs["temperature"] = self._temperature
|
||||
|
||||
return await self._client.audio.transcriptions.create(**kwargs)
|
||||
|
||||
|
||||
class GroqTTSService(InterruptibleTTSService):
    """Text-to-speech service backed by Groq's audio speech endpoint.

    Each ``run_tts`` call requests synthesized audio for the given text through
    ``AsyncGroq.audio.speech.create`` and forwards it as ``TTSAudioRawFrame``
    chunks bracketed by ``TTSStartedFrame`` and ``TTSStoppedFrame``.
    """

    # Audio format stamped on every emitted frame.
    # NOTE(review): assumed to match what the Groq model returns - confirm
    # against the Groq text-to-speech documentation.
    _GROQ_SAMPLE_RATE = 48000
    _GROQ_NUM_CHANNELS = 1

    # Number of bytes requested per iteration of the response body.
    _STREAM_CHUNK_SIZE = 4096

    class InputParams(BaseModel):
        """Optional synthesis settings.

        NOTE(review): these are stored on the service but are not currently
        forwarded to the Groq request.
        """

        language: Optional[Language] = Language.EN
        speed: Optional[float] = 1.0
        seed: Optional[int] = None

    def __init__(
        self,
        *,
        api_key: str,
        output_format: str = "wav",
        params: InputParams = InputParams(),
        model_name: str = "playai-tts",
        voice_id: str = "Atlas-PlayAI",
        **kwargs,
    ):
        """Create the service and its Groq client.

        Args:
            api_key: Groq API key used to build the ``AsyncGroq`` client.
            output_format: Value sent as ``response_format`` in the request.
            params: Synthesis settings (see ``InputParams``).
            model_name: Value sent as ``model`` in the request.
            voice_id: Value sent as ``voice`` in the request.
            **kwargs: Forwarded to ``InterruptibleTTSService``.
        """
        super().__init__(
            pause_frame_processing=True,
            **kwargs,
        )

        self._api_key = api_key
        self._model_name = model_name
        self._output_format = output_format
        self._voice_id = voice_id
        self._params = params

        self._client = AsyncGroq(api_key=self._api_key)

    def can_generate_metrics(self) -> bool:
        """Report that this service produces metrics."""
        return True

    @staticmethod
    def _strip_wav_header(chunk: bytes) -> bytes:
        """Return the PCM payload of ``chunk`` with any leading WAV header removed.

        A chunk that does not start with ``RIFF`` is returned unchanged. Otherwise
        the RIFF sub-chunks are walked until the ``data`` sub-chunk is found and
        everything after its 8-byte header is returned. If the ``data`` sub-chunk
        does not start inside ``chunk``, an empty payload is returned.
        """
        if not chunk.startswith(b"RIFF"):
            return chunk

        # Skip "RIFF", the 4-byte RIFF size and the "WAVE" form type.
        offset = 12
        while offset + 8 <= len(chunk):
            chunk_id = chunk[offset : offset + 4]
            chunk_size = int.from_bytes(chunk[offset + 4 : offset + 8], "little")
            offset += 8
            if chunk_id == b"data":
                # The declared data size is ignored on purpose: streamed WAVs
                # often carry a placeholder size, and the rest is audio anyway.
                return chunk[offset:]
            # RIFF sub-chunks are padded to an even number of bytes.
            offset += chunk_size + (chunk_size & 1)
        return b""

    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
        """Synthesize ``text`` with Groq and yield the resulting audio frames.

        Args:
            text: The text to convert to speech.

        Yields:
            A ``TTSStartedFrame``, then one ``TTSAudioRawFrame`` per non-empty
            audio chunk, then a ``TTSStoppedFrame``.
        """
        logger.debug(f"{self}: Generating TTS [{text}]")
        await self.start_ttfb_metrics()
        yield TTSStartedFrame()

        response = await self._client.audio.speech.create(
            model=self._model_name,
            voice=self._voice_id,
            response_format=self._output_format,
            input=text,
        )

        first_chunk = True
        async for data in response.iter_bytes(self._STREAM_CHUNK_SIZE):
            if first_chunk:
                first_chunk = False
                await self.stop_ttfb_metrics()
                # Only the first chunk can carry the WAV header. Strip just the
                # header bytes so the audio that shares the chunk is kept
                # (dropping the whole chunk would clip the start of the speech).
                data = self._strip_wav_header(data)
            if not data:
                continue
            yield TTSAudioRawFrame(data, self._GROQ_SAMPLE_RATE, self._GROQ_NUM_CHANNELS)

        yield TTSStoppedFrame()

    # The methods below are intentionally empty: this service issues one request
    # per utterance through AsyncGroq and keeps no connection of its own.
    # NOTE(review): presumably these hooks are expected by InterruptibleTTSService
    # - confirm against the base class.

    async def _connect(self) -> None:
        """No-op; there is no persistent connection to open."""
        pass

    async def _disconnect(self) -> None:
        """No-op; there is no persistent connection to close."""
        pass

    async def _connect_websocket(self) -> None:
        """No-op; this service does not use a websocket."""
        pass

    async def _disconnect_websocket(self) -> None:
        """No-op; this service does not use a websocket."""
        pass

    async def _receive_messages(self) -> None:
        """No-op; audio is read directly from the response in ``run_tts``."""
        pass
|
||||
|
||||
Reference in New Issue
Block a user