Merge pull request #275 from pipecat-ai/aleix/add-missing-keyword-separators
add missing keyword separators
This commit is contained in:
12
CHANGELOG.md
12
CHANGELOG.md
@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- Added `XTTSService`. This is a local Text-To-Speech service.
|
||||
See https://github.com/coqui-ai/TTS
|
||||
|
||||
- It is now possible to specify a Silero VAD version when using `SileroVADAnalyzer`
|
||||
or `SileroVAD`.
|
||||
|
||||
@@ -25,8 +28,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
processing metrics indicate the time a processor needs to generate all its
|
||||
output. Note that not all processors generate these kind of metrics.
|
||||
|
||||
### Changed
|
||||
|
||||
- `WhisperSTTService` model can now also be a string.
|
||||
|
||||
- Added missing * keyword separators in services.
|
||||
|
||||
### Fixed
|
||||
|
||||
- `WebsocketServerTransport` doesn't try to send frames anymore if serializers
|
||||
returns `None`.
|
||||
|
||||
- Fixed an issue where exceptions that occurred inside frame processors were
|
||||
being swallowed and not displayed.
|
||||
|
||||
|
||||
96
examples/foundational/07i-interruptible-xtts.py
Normal file
96
examples/foundational/07i-interruptible-xtts.py
Normal file
@@ -0,0 +1,96 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
from pipecat.frames.frames import LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantResponseAggregator, LLMUserResponseAggregator)
|
||||
from pipecat.services.deepgram import DeepgramSTTService, DeepgramTTSService
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.services.xtts import XTTSService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def main(room_url: str, token):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
transport = DailyTransport(
|
||||
room_url,
|
||||
token,
|
||||
"Respond bot",
|
||||
DailyParams(
|
||||
audio_out_enabled=True,
|
||||
transcription_enabled=True,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
)
|
||||
)
|
||||
|
||||
tts = XTTSService(
|
||||
aiohttp_session=session,
|
||||
voice_id="Claribel Dervla",
|
||||
language="en",
|
||||
base_url="http://localhost:8000"
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
model="gpt-4o")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
tma_in = LLMUserResponseAggregator(messages)
|
||||
tma_out = LLMAssistantResponseAggregator(messages)
|
||||
|
||||
pipeline = Pipeline([
|
||||
transport.input(), # Transport user input
|
||||
tma_in, # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
tma_out # Assistant spoken responses
|
||||
])
|
||||
|
||||
task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True))
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
transport.capture_participant_transcription(participant["id"])
|
||||
# Kick off the conversation.
|
||||
messages.append(
|
||||
{"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([LLMMessagesFrame(messages)])
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
(url, token) = configure()
|
||||
asyncio.run(main(url, token))
|
||||
@@ -14,10 +14,11 @@ class AsyncFrameProcessor(FrameProcessor):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name: str | None = None,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
**kwargs):
|
||||
super().__init__(name, loop, **kwargs)
|
||||
super().__init__(name=name, loop=loop, **kwargs)
|
||||
|
||||
self._create_push_task()
|
||||
|
||||
|
||||
@@ -66,6 +66,7 @@ class FrameProcessor:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
name: str | None = None,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
**kwargs):
|
||||
|
||||
@@ -118,7 +118,7 @@ class LLMService(AIService):
|
||||
|
||||
|
||||
class TTSService(AIService):
|
||||
def __init__(self, aggregate_sentences: bool = True, **kwargs):
|
||||
def __init__(self, *, aggregate_sentences: bool = True, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._aggregate_sentences: bool = aggregate_sentences
|
||||
self._current_sentence: str = ""
|
||||
@@ -180,6 +180,7 @@ class STTService(AIService):
|
||||
"""STTService is a base class for speech-to-text services."""
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
min_volume: float = 0.6,
|
||||
max_silence_secs: float = 0.3,
|
||||
max_buffer_secs: float = 1.5,
|
||||
|
||||
@@ -41,6 +41,7 @@ class AnthropicLLMService(LLMService):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
model: str = "claude-3-opus-20240229",
|
||||
max_tokens: int = 1024):
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
#
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from typing import AsyncGenerator
|
||||
@@ -18,11 +17,10 @@ from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
SystemFrame,
|
||||
TranscriptionFrame)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import AIService, AsyncAIService, TTSService
|
||||
from pipecat.services.ai_services import AsyncAIService, TTSService
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -96,6 +94,7 @@ class DeepgramTTSService(TTSService):
|
||||
|
||||
class DeepgramSTTService(AsyncAIService):
|
||||
def __init__(self,
|
||||
*,
|
||||
api_key: str,
|
||||
url: str = "",
|
||||
live_options: LiveOptions = LiveOptions(
|
||||
|
||||
@@ -19,6 +19,7 @@ except ModuleNotFoundError as e:
|
||||
|
||||
class FireworksLLMService(BaseOpenAILLMService):
|
||||
def __init__(self,
|
||||
*,
|
||||
model: str = "accounts/fireworks/models/firefunction-v1",
|
||||
base_url: str = "https://api.fireworks.ai/inference/v1"):
|
||||
super().__init__(model, base_url)
|
||||
|
||||
@@ -42,7 +42,7 @@ class GoogleLLMService(LLMService):
|
||||
franca for all LLM services, so that it is easy to switch between different LLMs.
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: str, model: str = "gemini-1.5-flash-latest", **kwargs):
|
||||
def __init__(self, *, api_key: str, model: str = "gemini-1.5-flash-latest", **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
gai.configure(api_key=api_key)
|
||||
self._client = gai.GenerativeModel(model)
|
||||
|
||||
@@ -46,6 +46,7 @@ def detect_device():
|
||||
class MoondreamService(VisionService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model="vikhyatk/moondream2",
|
||||
revision="2024-04-02",
|
||||
use_cpu=False
|
||||
|
||||
@@ -9,5 +9,5 @@ from pipecat.services.openai import BaseOpenAILLMService
|
||||
|
||||
class OLLamaLLMService(BaseOpenAILLMService):
|
||||
|
||||
def __init__(self, model: str = "llama2", base_url: str = "http://localhost:11434/v1"):
|
||||
def __init__(self, *, model: str = "llama2", base_url: str = "http://localhost:11434/v1"):
|
||||
super().__init__(model=model, base_url=base_url, api_key="ollama")
|
||||
|
||||
@@ -67,7 +67,7 @@ class BaseOpenAILLMService(LLMService):
|
||||
calls from the LLM.
|
||||
"""
|
||||
|
||||
def __init__(self, model: str, api_key=None, base_url=None, **kwargs):
|
||||
def __init__(self, *, model: str, api_key=None, base_url=None, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._model: str = model
|
||||
self._client = self.create_client(api_key=api_key, base_url=base_url, **kwargs)
|
||||
@@ -236,8 +236,8 @@ class BaseOpenAILLMService(LLMService):
|
||||
|
||||
class OpenAILLMService(BaseOpenAILLMService):
|
||||
|
||||
def __init__(self, model="gpt-4o", **kwargs):
|
||||
super().__init__(model, **kwargs)
|
||||
def __init__(self, *, model: str = "gpt-4o", **kwargs):
|
||||
super().__init__(model=model, **kwargs)
|
||||
|
||||
|
||||
class OpenAIImageGenService(ImageGenService):
|
||||
|
||||
@@ -25,6 +25,7 @@ class OpenPipeLLMService(BaseOpenAILLMService):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model: str = "gpt-4o",
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
@@ -33,9 +34,9 @@ class OpenPipeLLMService(BaseOpenAILLMService):
|
||||
tags: Dict[str, str] | None = None,
|
||||
**kwargs):
|
||||
super().__init__(
|
||||
model,
|
||||
api_key,
|
||||
base_url,
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
openpipe_api_key=openpipe_api_key,
|
||||
openpipe_base_url=openpipe_base_url,
|
||||
**kwargs)
|
||||
|
||||
@@ -42,7 +42,8 @@ class WhisperSTTService(STTService):
|
||||
"""Class to transcribe audio with a locally-downloaded Whisper model"""
|
||||
|
||||
def __init__(self,
|
||||
model: Model = Model.DISTIL_MEDIUM_EN,
|
||||
*,
|
||||
model: str | Model = Model.DISTIL_MEDIUM_EN,
|
||||
device: str = "auto",
|
||||
compute_type: str = "default",
|
||||
no_speech_prob: float = 0.4,
|
||||
@@ -51,7 +52,7 @@ class WhisperSTTService(STTService):
|
||||
super().__init__(**kwargs)
|
||||
self._device: str = device
|
||||
self._compute_type = compute_type
|
||||
self._model_name: Model = model
|
||||
self._model_name: str | Model = model
|
||||
self._no_speech_prob = no_speech_prob
|
||||
self._model: WhisperModel | None = None
|
||||
self._load()
|
||||
@@ -64,7 +65,7 @@ class WhisperSTTService(STTService):
|
||||
this model is being run, it will take time to download."""
|
||||
logger.debug("Loading Whisper model...")
|
||||
self._model = WhisperModel(
|
||||
self._model_name.value,
|
||||
self._model_name.value if isinstance(self._model_name, Enum) else self._model_name,
|
||||
device=self._device,
|
||||
compute_type=self._compute_type)
|
||||
logger.debug("Loaded Whisper model")
|
||||
|
||||
@@ -24,13 +24,14 @@ except ModuleNotFoundError as e:
|
||||
logger.error("In order to use XTTS, you need to `pip install pipecat-ai[xtts]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
#####
|
||||
## The server below can connect to XTTS through a local running docker
|
||||
##
|
||||
## Docker command: $ docker run --gpus=all -e COQUI_TOS_AGREED=1 --rm -p 8000:80 ghcr.io/coqui-ai/xtts-streaming-server:latest-cuda121
|
||||
##
|
||||
## You can find more information on the official repo: https://github.com/coqui-ai/xtts-streaming-server
|
||||
####
|
||||
|
||||
# The server below can connect to XTTS through a local running docker
|
||||
#
|
||||
# Docker command: $ docker run --gpus=all -e COQUI_TOS_AGREED=1 --rm -p 8000:80 ghcr.io/coqui-ai/xtts-streaming-server:latest-cuda121
|
||||
#
|
||||
# You can find more information on the official repo:
|
||||
# https://github.com/coqui-ai/xtts-streaming-server
|
||||
|
||||
|
||||
class XTTSService(TTSService):
|
||||
|
||||
@@ -40,7 +41,7 @@ class XTTSService(TTSService):
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
voice_id: str,
|
||||
language: str,
|
||||
base_url:str,
|
||||
base_url: str,
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -58,13 +59,13 @@ class XTTSService(TTSService):
|
||||
embeddings = self._studio_speakers[self._voice_id]
|
||||
|
||||
url = self._base_url + "/tts_stream"
|
||||
|
||||
payload={
|
||||
"text": text.replace('.','').replace('*',''),
|
||||
|
||||
payload = {
|
||||
"text": text.replace('.', '').replace('*', ''),
|
||||
"language": self._language,
|
||||
"speaker_embedding": embeddings["speaker_embedding"],
|
||||
"gpt_cond_latent": embeddings["gpt_cond_latent"],
|
||||
"add_wav_header": True,
|
||||
"add_wav_header": False,
|
||||
"stream_chunk_size": 20,
|
||||
}
|
||||
|
||||
@@ -76,7 +77,7 @@ class XTTSService(TTSService):
|
||||
logger.error(f"{self} error getting audio (status: {r.status}, error: {text})")
|
||||
yield ErrorFrame(f"Error getting audio (status: {r.status}, error: {text})")
|
||||
return
|
||||
|
||||
|
||||
buffer = bytearray()
|
||||
|
||||
async for chunk in r.content.iter_chunked(1024):
|
||||
@@ -84,14 +85,14 @@ class XTTSService(TTSService):
|
||||
await self.stop_ttfb_metrics()
|
||||
# Append new chunk to the buffer
|
||||
buffer.extend(chunk)
|
||||
|
||||
|
||||
# Check if buffer has enough data for processing
|
||||
while len(buffer) >= 48000: # Assuming at least 0.5 seconds of audio data at 24000 Hz
|
||||
# Process the buffer up to a safe size for resampling
|
||||
process_data = buffer[:48000]
|
||||
# Remove processed data from buffer
|
||||
buffer = buffer[48000:]
|
||||
|
||||
|
||||
# Convert the byte data to numpy array for resampling
|
||||
audio_np = np.frombuffer(process_data, dtype=np.int16)
|
||||
# Resample the audio from 24000 Hz to 16000 Hz
|
||||
@@ -108,4 +109,4 @@ class XTTSService(TTSService):
|
||||
resampled_audio = resampy.resample(audio_np, 24000, 16000)
|
||||
resampled_audio_bytes = resampled_audio.astype(np.int16).tobytes()
|
||||
frame = AudioRawFrame(resampled_audio_bytes, 16000, 1)
|
||||
yield frame
|
||||
yield frame
|
||||
|
||||
@@ -124,6 +124,9 @@ class WebsocketServerOutputTransport(BaseOutputTransport):
|
||||
self._websocket = websocket
|
||||
|
||||
async def write_raw_audio_frames(self, frames: bytes):
|
||||
if not self._websocket:
|
||||
return
|
||||
|
||||
self._audio_buffer += frames
|
||||
while len(self._audio_buffer) >= self._params.audio_frame_size:
|
||||
frame = AudioRawFrame(
|
||||
@@ -148,8 +151,8 @@ class WebsocketServerOutputTransport(BaseOutputTransport):
|
||||
frame = wav_frame
|
||||
|
||||
proto = self._params.serializer.serialize(frame)
|
||||
|
||||
await self._websocket.send(proto)
|
||||
if proto:
|
||||
await self._websocket.send(proto)
|
||||
|
||||
self._audio_buffer = self._audio_buffer[self._params.audio_frame_size:]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user