Compare commits
49 Commits
async-reba
...
v0.0.42
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
65eeb0f1f6 | ||
|
|
1d7d0bb1ea | ||
|
|
598936bc53 | ||
|
|
b1bf6f7733 | ||
|
|
75d27aeb9f | ||
|
|
0a37caf4b4 | ||
|
|
6db65f4335 | ||
|
|
3648874301 | ||
|
|
8bcb5d7fd2 | ||
|
|
8c01a900cd | ||
|
|
d378e699d2 | ||
|
|
c25c375c41 | ||
|
|
70c3ff31fd | ||
|
|
cd2e29f285 | ||
|
|
6d4d7d763d | ||
|
|
6c1851eef8 | ||
|
|
096a15eef6 | ||
|
|
3d642df2b0 | ||
|
|
d75a02dc51 | ||
|
|
28643b453d | ||
|
|
88cca7bf68 | ||
|
|
a397b859fe | ||
|
|
8aae4e9856 | ||
|
|
92d8b37229 | ||
|
|
0801fc578b | ||
|
|
0d5cb84531 | ||
|
|
47b943a117 | ||
|
|
128355add5 | ||
|
|
0499fe41e4 | ||
|
|
6ad3437fd2 | ||
|
|
a5c73ec829 | ||
|
|
def04ac0ce | ||
|
|
5d63615b1b | ||
|
|
90ee284fe0 | ||
|
|
539e0b66fb | ||
|
|
fef393dcac | ||
|
|
ed607d5c4b | ||
|
|
37da7e44cd | ||
|
|
69c7edd60c | ||
|
|
392f210371 | ||
|
|
9a63df1ea1 | ||
|
|
f8a75cede9 | ||
|
|
4d1e370e02 | ||
|
|
d080a31a5c | ||
|
|
a90ebdfe7c | ||
|
|
c8995b82e5 | ||
|
|
6b7f924af6 | ||
|
|
51580e5349 | ||
|
|
ed49cebf2c |
38
CHANGELOG.md
38
CHANGELOG.md
@@ -1,20 +1,29 @@
|
||||
# Changelog
|
||||
|
||||
All notable changes to **pipecat** will be documented in this file.
|
||||
All notable changes to **Pipecat** will be documented in this file.
|
||||
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
## [0.0.42] - 2024-10-02
|
||||
|
||||
### Added
|
||||
|
||||
- Added Google TTS service and corresponding foundational example `07n-interruptible-google.py`
|
||||
- `SentryMetrics` has been added to report frame processor metrics to
|
||||
Sentry. This is now possible because `FrameProcessorMetrics` can now be passed
|
||||
to `FrameProcessor`.
|
||||
|
||||
- Added Google TTS service and corresponding foundational example
|
||||
`07n-interruptible-google.py`
|
||||
|
||||
- Added AWS Polly TTS support and `07m-interruptible-aws.py` as an example.
|
||||
|
||||
- Added InputParams to Azure TTS service.
|
||||
|
||||
- Added `LivekitTransport` (audio-only for now).
|
||||
|
||||
- RTVI 0.2.0 is now supported.
|
||||
|
||||
- All `FrameProcessors` can now register event handlers.
|
||||
|
||||
```
|
||||
@@ -86,8 +95,12 @@ async def on_connected(processor):
|
||||
|
||||
### Changed
|
||||
|
||||
- Updated individual update settings frame classes into a single UpdateSettingsFrame
|
||||
class for STT, LLM, and TTS.
|
||||
- Context frames are now pushed downstream from assistant context aggregators.
|
||||
|
||||
- Removed Silero VAD torch dependency.
|
||||
|
||||
- Updated individual update settings frame classes into a single
|
||||
`ServiceUpdateSettingsFrame` class.
|
||||
|
||||
- We now distinguish between input and output audio and image frames. We
|
||||
introduce `InputAudioRawFrame`, `OutputAudioRawFrame`, `InputImageRawFrame`
|
||||
@@ -107,9 +120,9 @@ async def on_connected(processor):
|
||||
pipelines is synchronous (e.g. an HTTP-based service that waits for the
|
||||
response).
|
||||
|
||||
- `StartFrame` is back a system frame so we make sure it's processed immediately
|
||||
by all processors. `EndFrame` stays a control frame since it needs to be
|
||||
ordered allowing the frames in the pipeline to be processed.
|
||||
- `StartFrame` is back a system frame to make sure it's processed immediately by
|
||||
all processors. `EndFrame` stays a control frame since it needs to be ordered
|
||||
allowing the frames in the pipeline to be processed.
|
||||
|
||||
- Updated `MoondreamService` revision to `2024-08-26`.
|
||||
|
||||
@@ -133,6 +146,11 @@ async def on_connected(processor):
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed OpenAI multiple function calls.
|
||||
|
||||
- Fixed a Cartesia TTS issue that would cause audio to be truncated in some
|
||||
cases.
|
||||
|
||||
- Fixed a `BaseOutputTransport` issue that would stop audio and video rendering
|
||||
tasks (after receiving and `EndFrame`) before the internal queue was emptied,
|
||||
causing the pipeline to finish prematurely.
|
||||
@@ -146,6 +164,10 @@ async def on_connected(processor):
|
||||
- `obj_id()` and `obj_count()` now use `itertools.count` avoiding the need of
|
||||
`threading.Lock`.
|
||||
|
||||
### Other
|
||||
|
||||
- Pipecat now uses Ruff as its formatter (https://github.com/astral-sh/ruff).
|
||||
|
||||
## [0.0.41] - 2024-08-22
|
||||
|
||||
### Added
|
||||
|
||||
@@ -82,6 +82,7 @@ async def main():
|
||||
self.frame = OutputAudioRawFrame(
|
||||
bytes(self.audio), frame.sample_rate, frame.num_channels
|
||||
)
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
class ImageGrabber(FrameProcessor):
|
||||
def __init__(self):
|
||||
@@ -93,6 +94,7 @@ async def main():
|
||||
|
||||
if isinstance(frame, URLImageRawFrame):
|
||||
self.frame = frame
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
|
||||
|
||||
|
||||
@@ -5,29 +5,24 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.frames.frames import LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantResponseAggregator,
|
||||
LLMUserResponseAggregator,
|
||||
)
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.anthropic import AnthropicLLMService
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
@@ -69,17 +64,17 @@ async def main():
|
||||
},
|
||||
]
|
||||
|
||||
tma_in = LLMUserResponseAggregator(messages)
|
||||
tma_out = LLMAssistantResponseAggregator(messages)
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
tma_in, # User responses
|
||||
context_aggregator.user(), # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
tma_out, # Assistant spoken responses
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -4,11 +4,15 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.frames.frames import LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
@@ -17,17 +21,11 @@ from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantResponseAggregator,
|
||||
LLMUserResponseAggregator,
|
||||
)
|
||||
from pipecat.services.playht import PlayHTTTSService
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.services.playht import PlayHTTTSService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
|
||||
@@ -4,11 +4,15 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.frames.frames import LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
@@ -17,17 +21,10 @@ from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantResponseAggregator,
|
||||
LLMUserResponseAggregator,
|
||||
)
|
||||
from pipecat.services.openai import OpenAITTSService
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.services.openai import OpenAILLMService, OpenAITTSService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
|
||||
@@ -5,29 +5,24 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.frames.frames import LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantResponseAggregator,
|
||||
LLMUserResponseAggregator,
|
||||
)
|
||||
from pipecat.services.ai_services import OpenAILLMContext
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.together import TogetherLLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
@@ -72,25 +67,32 @@ async def main():
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond in plain language. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
tma_in = LLMUserResponseAggregator(messages)
|
||||
tma_out = LLMAssistantResponseAggregator(messages)
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
user_aggregator = context_aggregator.user()
|
||||
assistant_aggregator = context_aggregator.assistant()
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
tma_in, # User responses
|
||||
user_aggregator, # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
tma_out, # Assistant spoken responses
|
||||
assistant_aggregator, # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True))
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
PipelineParams(
|
||||
allow_interruptions=True, enable_metrics=True, enable_usage_metrics=True
|
||||
),
|
||||
)
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
|
||||
@@ -53,7 +53,6 @@ async def main():
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = GoogleTTSService(
|
||||
credentials=os.getenv("GOOGLE_CREDENTIALS"),
|
||||
voice_id="en-US-Neural2-J",
|
||||
params=GoogleTTSService.InputParams(language="en-US", rate="1.05"),
|
||||
)
|
||||
|
||||
@@ -5,25 +5,26 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
from runner import configure
|
||||
|
||||
from pipecat.frames.frames import TextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.logger import FrameLogger
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.openai import OpenAILLMContext, OpenAILLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
@@ -35,7 +36,7 @@ async def start_fetch_weather(function_name, llm, context):
|
||||
# can interrupt itself and/or cause audio overlapping glitches.
|
||||
# possible question for Aleix and Chad about what the right way
|
||||
# to trigger speech is, now, with the new queues/async/sync refactors.
|
||||
await llm.push_frame(TextFrame("Let me check on that. "))
|
||||
# await llm.push_frame(TextFrame("Let me check on that."))
|
||||
logger.debug(f"Starting fetch_weather_from_api with function_name: {function_name}")
|
||||
|
||||
|
||||
@@ -69,9 +70,6 @@ async def main():
|
||||
# sent to the same callback with an additional function_name parameter.
|
||||
llm.register_function(None, fetch_weather_from_api, start_callback=start_fetch_weather)
|
||||
|
||||
fl_in = FrameLogger("Inner")
|
||||
fl_out = FrameLogger("Outer")
|
||||
|
||||
tools = [
|
||||
ChatCompletionToolParam(
|
||||
type="function",
|
||||
@@ -108,11 +106,9 @@ async def main():
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
# fl_in,
|
||||
transport.input(),
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
# fl_out,
|
||||
tts,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
|
||||
136
examples/foundational/14c-function-calling-together.py
Normal file
136
examples/foundational/14c-function-calling-together.py
Normal file
@@ -0,0 +1,136 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.openai import OpenAILLMContext
|
||||
from pipecat.services.together import TogetherLLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def start_fetch_weather(function_name, llm, context):
|
||||
# note: we can't push a frame to the LLM here. the bot
|
||||
# can interrupt itself and/or cause audio overlapping glitches.
|
||||
# possible question for Aleix and Chad about what the right way
|
||||
# to trigger speech is, now, with the new queues/async/sync refactors.
|
||||
# await llm.push_frame(TextFrame("Let me check on that."))
|
||||
logger.debug(f"Starting fetch_weather_from_api with function_name: {function_name}")
|
||||
|
||||
|
||||
async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback):
|
||||
await result_callback({"conditions": "nice", "temperature": "75"})
|
||||
|
||||
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, token) = await configure(session)
|
||||
|
||||
transport = DailyTransport(
|
||||
room_url,
|
||||
token,
|
||||
"Respond bot",
|
||||
DailyParams(
|
||||
audio_out_enabled=True,
|
||||
transcription_enabled=True,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
)
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
|
||||
)
|
||||
|
||||
llm = TogetherLLMService(
|
||||
api_key=os.getenv("TOGETHER_API_KEY"),
|
||||
model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
|
||||
)
|
||||
# Register a function_name of None to get all functions
|
||||
# sent to the same callback with an additional function_name parameter.
|
||||
llm.register_function(None, fetch_weather_from_api, start_callback=start_fetch_weather)
|
||||
|
||||
tools = [
|
||||
ChatCompletionToolParam(
|
||||
type="function",
|
||||
function={
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
"required": ["location", "format"],
|
||||
},
|
||||
},
|
||||
)
|
||||
]
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages, tools)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(pipeline)
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
transport.capture_participant_transcription(participant["id"])
|
||||
# Kick off the conversation.
|
||||
# await tts.say("Hi! Ask me about the weather in San Francisco.")
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
167
examples/foundational/14d-function-calling-video.py
Normal file
167
examples/foundational/14d-function-calling-video.py
Normal file
@@ -0,0 +1,167 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.openai import OpenAILLMContext, OpenAILLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
video_participant_id = None
|
||||
|
||||
|
||||
async def get_weather(function_name, tool_call_id, arguments, llm, context, result_callback):
|
||||
location = arguments["location"]
|
||||
await result_callback(f"The weather in {location} is currently 72 degrees and sunny.")
|
||||
|
||||
|
||||
async def get_image(function_name, tool_call_id, arguments, llm, context, result_callback):
|
||||
logger.debug(f"!!! IN get_image {video_participant_id}, {arguments}")
|
||||
question = arguments["question"]
|
||||
await llm.request_image_frame(user_id=video_participant_id, text_content=question)
|
||||
|
||||
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, token) = await configure(session)
|
||||
|
||||
transport = DailyTransport(
|
||||
room_url,
|
||||
token,
|
||||
"Respond bot",
|
||||
DailyParams(
|
||||
audio_out_enabled=True,
|
||||
transcription_enabled=True,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
)
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
|
||||
llm.register_function("get_weather", get_weather)
|
||||
llm.register_function("get_image", get_image)
|
||||
|
||||
tools = [
|
||||
ChatCompletionToolParam(
|
||||
type="function",
|
||||
function={
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
"required": ["location", "format"],
|
||||
},
|
||||
},
|
||||
),
|
||||
ChatCompletionToolParam(
|
||||
type="function",
|
||||
function={
|
||||
"name": "get_image",
|
||||
"description": "Get an image from the video stream.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"question": {
|
||||
"type": "string",
|
||||
"description": "The question to ask the AI to generate an image of",
|
||||
},
|
||||
},
|
||||
"required": ["question"],
|
||||
},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
system_prompt = """\
|
||||
You are a helpful assistant who converses with a user and answers questions. Respond concisely to general questions.
|
||||
|
||||
Your response will be turned into speech so use only simple words and punctuation.
|
||||
|
||||
You have access to two tools: get_weather and get_image.
|
||||
|
||||
You can respond to questions about the weather using the get_weather tool.
|
||||
|
||||
You can answer questions about the user's video stream using the get_image tool. Some examples of phrases that \
|
||||
indicate you should use the get_image tool are:
|
||||
- What do you see?
|
||||
- What's in the video?
|
||||
- Can you describe the video?
|
||||
- Tell me about what you see.
|
||||
- Tell me something interesting about what you see.
|
||||
- What's happening in the video?
|
||||
"""
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages, tools)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(pipeline)
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
global video_participant_id
|
||||
video_participant_id = participant["id"]
|
||||
transport.capture_participant_transcription(participant["id"])
|
||||
transport.capture_participant_video(video_participant_id, framerate=0)
|
||||
# Kick off the conversation.
|
||||
await tts.say("Hi! Ask me about the weather in San Francisco.")
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -5,10 +5,14 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.frames.frames import LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
@@ -26,12 +30,6 @@ from pipecat.transports.services.daily import (
|
||||
)
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
|
||||
@@ -1,137 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
|
||||
from pipecat.frames.frames import LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.together import TogetherLLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def get_current_weather(
|
||||
function_name, tool_call_id, arguments, llm, context, result_callback
|
||||
):
|
||||
logger.debug("IN get_current_weather")
|
||||
location = arguments["location"]
|
||||
await result_callback(f"The weather in {location} is currently 72 degrees and sunny.")
|
||||
|
||||
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, token) = await configure(session)
|
||||
|
||||
transport = DailyTransport(
|
||||
room_url,
|
||||
token,
|
||||
"Respond bot",
|
||||
DailyParams(
|
||||
audio_out_enabled=True,
|
||||
transcription_enabled=True,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
)
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
|
||||
)
|
||||
|
||||
llm = TogetherLLMService(
|
||||
api_key=os.getenv("TOGETHER_API_KEY"),
|
||||
model=os.getenv("TOGETHER_MODEL"),
|
||||
)
|
||||
llm.register_function("get_current_weather", get_current_weather)
|
||||
|
||||
weatherTool = {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
}
|
||||
|
||||
system_prompt = f"""\
|
||||
You have access to the following functions:
|
||||
|
||||
Use the function '{weatherTool["name"]}' to '{weatherTool["description"]}':
|
||||
{json.dumps(weatherTool)}
|
||||
|
||||
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
|
||||
|
||||
<function=example_function_name>{{\"example_name\": \"example_value\"}}</function>
|
||||
|
||||
Reminder:
|
||||
- Function calls MUST follow the specified format, start with <function= and end with </function>
|
||||
- Required parameters MUST be specified
|
||||
- Only call one function at a time
|
||||
- Put the entire function call reply on one line
|
||||
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
|
||||
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": "Wait for the user to say something."},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
context_aggregator.user(), # User speech to text
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses and tool context
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True, enable_metrics=True))
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
transport.capture_participant_transcription(participant["id"])
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMMessagesFrame(messages)])
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,4 +1,4 @@
|
||||
DAILY_SAMPLE_ROOM_URL=https://yourdomain.daily.co/yourroom # (for joining the bot to the same room repeatedly for local dev)
|
||||
DAILY_API_KEY=7df...
|
||||
OPENAI_API_KEY=sk-PL...
|
||||
ELEVENLABS_API_KEY=aeb...
|
||||
CARTESIA_API_KEY=your_cartesia_api_key_here
|
||||
|
||||
2497
examples/storytelling-chatbot/frontend/package-lock.json
generated
2497
examples/storytelling-chatbot/frontend/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -11,28 +11,28 @@
|
||||
"dependencies": {
|
||||
"@daily-co/daily-js": "^0.62.0",
|
||||
"@daily-co/daily-react": "^0.18.0",
|
||||
"@radix-ui/react-select": "^2.0.0",
|
||||
"@radix-ui/react-select": "^2.1.2",
|
||||
"@radix-ui/react-slot": "^1.0.2",
|
||||
"@tabler/icons-react": "^3.1.0",
|
||||
"@tabler/icons-react": "^3.19.0",
|
||||
"class-variance-authority": "^0.7.0",
|
||||
"clsx": "^2.1.0",
|
||||
"framer-motion": "^11.0.27",
|
||||
"next": "14.1.4",
|
||||
"react": "^18",
|
||||
"react-dom": "^18",
|
||||
"clsx": "^2.1.1",
|
||||
"framer-motion": "^11.9.0",
|
||||
"next": "^14.2.14",
|
||||
"react": "^18.3.1",
|
||||
"react-dom": "^18.3.1",
|
||||
"recoil": "^0.7.7",
|
||||
"tailwind-merge": "^2.2.2",
|
||||
"tailwind-merge": "^2.5.2",
|
||||
"tailwindcss-animate": "^1.0.7"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^20",
|
||||
"@types/react": "^18",
|
||||
"@types/react-dom": "^18",
|
||||
"autoprefixer": "^10.0.1",
|
||||
"eslint": "^8",
|
||||
"@types/node": "^20.16.10",
|
||||
"@types/react": "^18.3.11",
|
||||
"@types/react-dom": "^18.3.0",
|
||||
"autoprefixer": "^10.4.20",
|
||||
"eslint": "^8.57.1",
|
||||
"eslint-config-next": "14.1.4",
|
||||
"postcss": "^8",
|
||||
"tailwindcss": "^3.4.3",
|
||||
"typescript": "^5"
|
||||
"postcss": "^8.4.47",
|
||||
"tailwindcss": "^3.4.13",
|
||||
"typescript": "^5.6.2"
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -143,7 +143,7 @@ async def main(room_url, token=None):
|
||||
|
||||
@transport.event_handler("on_participant_left")
|
||||
async def on_participant_left(transport, participant, reason):
|
||||
intro_task.queue_frame(EndFrame())
|
||||
await intro_task.queue_frame(EndFrame())
|
||||
await main_task.queue_frame(EndFrame())
|
||||
|
||||
@transport.event_handler("on_call_state_updated")
|
||||
|
||||
@@ -38,7 +38,7 @@ anthropic = [ "anthropic~=0.34.0" ]
|
||||
aws = [ "boto3~=1.35.27" ]
|
||||
azure = [ "azure-cognitiveservices-speech~=1.40.0" ]
|
||||
cartesia = [ "cartesia~=1.0.13", "websockets~=12.0" ]
|
||||
daily = [ "daily-python~=0.10.1" ]
|
||||
daily = [ "daily-python~=0.11.0" ]
|
||||
deepgram = [ "deepgram-sdk~=3.5.0" ]
|
||||
elevenlabs = [ "websockets~=12.0" ]
|
||||
examples = [ "python-dotenv~=1.0.1", "flask~=3.0.3", "flask_cors~=4.0.1" ]
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pipecat.clocks.base_clock import BaseClock
|
||||
from pipecat.metrics.metrics import MetricsData
|
||||
@@ -527,45 +527,25 @@ class UserImageRequestFrame(ControlFrame):
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMUpdateSettingsFrame(ControlFrame):
|
||||
"""A control frame containing a request to update LLM settings."""
|
||||
class ServiceUpdateSettingsFrame(ControlFrame):
|
||||
"""A control frame containing a request to update service settings."""
|
||||
|
||||
model: Optional[str] = None
|
||||
temperature: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
top_p: Optional[float] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
max_tokens: Optional[int] = None
|
||||
seed: Optional[int] = None
|
||||
extra: dict = field(default_factory=dict)
|
||||
settings: Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TTSUpdateSettingsFrame(ControlFrame):
|
||||
"""A control frame containing a request to update TTS settings."""
|
||||
|
||||
model: Optional[str] = None
|
||||
voice: Optional[str] = None
|
||||
language: Optional[Language] = None
|
||||
speed: Optional[Union[str, float]] = None
|
||||
emotion: Optional[List[str]] = None
|
||||
engine: Optional[str] = None
|
||||
pitch: Optional[str] = None
|
||||
rate: Optional[str] = None
|
||||
volume: Optional[str] = None
|
||||
emphasis: Optional[str] = None
|
||||
style: Optional[str] = None
|
||||
style_degree: Optional[str] = None
|
||||
role: Optional[str] = None
|
||||
class LLMUpdateSettingsFrame(ServiceUpdateSettingsFrame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class STTUpdateSettingsFrame(ControlFrame):
|
||||
"""A control frame containing a request to update STT settings."""
|
||||
class TTSUpdateSettingsFrame(ServiceUpdateSettingsFrame):
|
||||
pass
|
||||
|
||||
model: Optional[str] = None
|
||||
language: Optional[Language] = None
|
||||
|
||||
@dataclass
|
||||
class STTUpdateSettingsFrame(ServiceUpdateSettingsFrame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -6,10 +6,11 @@
|
||||
|
||||
import asyncio
|
||||
|
||||
from dataclasses import dataclass
|
||||
from itertools import chain
|
||||
from typing import List
|
||||
|
||||
from pipecat.frames.frames import ControlFrame, Frame, SystemFrame
|
||||
from pipecat.frames.frames import ControlFrame, EndFrame, Frame, SystemFrame
|
||||
from pipecat.pipeline.base_pipeline import BasePipeline
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
@@ -17,6 +18,7 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class SyncFrame(ControlFrame):
|
||||
"""This frame is used to know when the internal pipelines have finished."""
|
||||
|
||||
@@ -114,19 +116,25 @@ class SyncParallelPipeline(BasePipeline):
|
||||
):
|
||||
processor = obj["processor"]
|
||||
queue = obj["queue"]
|
||||
|
||||
await processor.process_frame(frame, direction)
|
||||
|
||||
# If we have a system frame we don't need to synchrnonize anything.
|
||||
if isinstance(frame, SystemFrame):
|
||||
await main_queue.put(frame)
|
||||
if isinstance(frame, (SystemFrame, EndFrame)):
|
||||
new_frame = await queue.get()
|
||||
if isinstance(new_frame, (SystemFrame, EndFrame)):
|
||||
await main_queue.put(new_frame)
|
||||
else:
|
||||
while not isinstance(new_frame, (SystemFrame, EndFrame)):
|
||||
await main_queue.put(new_frame)
|
||||
queue.task_done()
|
||||
new_frame = await queue.get()
|
||||
else:
|
||||
await processor.process_frame(SyncFrame(), direction)
|
||||
|
||||
frame = await queue.get()
|
||||
while not isinstance(frame, SyncFrame):
|
||||
await main_queue.put(frame)
|
||||
new_frame = await queue.get()
|
||||
while not isinstance(new_frame, SyncFrame):
|
||||
await main_queue.put(new_frame)
|
||||
queue.task_done()
|
||||
frame = await queue.get()
|
||||
new_frame = await queue.get()
|
||||
|
||||
if direction == FrameDirection.UPSTREAM:
|
||||
# If we get an upstream frame we process it in each sink.
|
||||
|
||||
@@ -6,12 +6,6 @@
|
||||
|
||||
from typing import List, Type
|
||||
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContextFrame,
|
||||
OpenAILLMContext,
|
||||
)
|
||||
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
@@ -22,11 +16,16 @@ from pipecat.frames.frames import (
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolsFrame,
|
||||
StartInterruptionFrame,
|
||||
TranscriptionFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class LLMResponseAggregator(FrameProcessor):
|
||||
@@ -40,6 +39,7 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
accumulator_frame: Type[TextFrame],
|
||||
interim_accumulator_frame: Type[TextFrame] | None = None,
|
||||
handle_interruptions: bool = False,
|
||||
expect_stripped_words: bool = True, # if True, need to add spaces between words
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -50,6 +50,7 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
self._accumulator_frame = accumulator_frame
|
||||
self._interim_accumulator_frame = interim_accumulator_frame
|
||||
self._handle_interruptions = handle_interruptions
|
||||
self._expect_stripped_words = expect_stripped_words
|
||||
|
||||
# Reset our accumulator state.
|
||||
self._reset()
|
||||
@@ -111,7 +112,10 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, self._accumulator_frame):
|
||||
if self._aggregating:
|
||||
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
|
||||
if self._expect_stripped_words:
|
||||
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
|
||||
else:
|
||||
self._aggregation += frame.text
|
||||
# We have recevied a complete sentence, so if we have seen the
|
||||
# end frame and we were still aggregating, it means we should
|
||||
# send the aggregation.
|
||||
@@ -290,7 +294,7 @@ class LLMContextAggregator(LLMResponseAggregator):
|
||||
|
||||
|
||||
class LLMAssistantContextAggregator(LLMContextAggregator):
|
||||
def __init__(self, context: OpenAILLMContext):
|
||||
def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True):
|
||||
super().__init__(
|
||||
messages=[],
|
||||
context=context,
|
||||
@@ -299,6 +303,7 @@ class LLMAssistantContextAggregator(LLMContextAggregator):
|
||||
end_frame=LLMFullResponseEndFrame,
|
||||
accumulator_frame=TextFrame,
|
||||
handle_interruptions=True,
|
||||
expect_stripped_words=expect_stripped_words,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import base64
|
||||
import copy
|
||||
import io
|
||||
import json
|
||||
|
||||
@@ -60,6 +62,7 @@ class OpenAILLMContext:
|
||||
self._messages: List[ChatCompletionMessageParam] = messages if messages else []
|
||||
self._tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = tool_choice
|
||||
self._tools: List[ChatCompletionToolParam] | NotGiven = tools
|
||||
self._user_image_request_context = {}
|
||||
|
||||
@staticmethod
|
||||
def from_messages(messages: List[dict]) -> "OpenAILLMContext":
|
||||
@@ -114,6 +117,21 @@ class OpenAILLMContext:
|
||||
def get_messages_json(self) -> str:
|
||||
return json.dumps(self._messages, cls=CustomEncoder)
|
||||
|
||||
def get_messages_for_logging(self) -> str:
|
||||
msgs = []
|
||||
for message in self.messages:
|
||||
msg = copy.deepcopy(message)
|
||||
if "content" in msg:
|
||||
if isinstance(msg["content"], list):
|
||||
for item in msg["content"]:
|
||||
if item["type"] == "image_url":
|
||||
if item["image_url"]["url"].startswith("data:image/"):
|
||||
item["image_url"]["url"] = "data:image/..."
|
||||
if "mime_type" in msg and msg["mime_type"].startswith("image/"):
|
||||
msg["data"] = "..."
|
||||
msgs.append(msg)
|
||||
return json.dumps(msgs)
|
||||
|
||||
def set_tool_choice(self, tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven):
|
||||
self._tool_choice = tool_choice
|
||||
|
||||
@@ -122,6 +140,21 @@ class OpenAILLMContext:
|
||||
tools = NOT_GIVEN
|
||||
self._tools = tools
|
||||
|
||||
def add_image_frame_message(
|
||||
self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
|
||||
):
|
||||
buffer = io.BytesIO()
|
||||
Image.frombytes(format, size, image).save(buffer, format="JPEG")
|
||||
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
content = [
|
||||
{"type": "text", "text": text},
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},
|
||||
]
|
||||
if text:
|
||||
content.append({"type": "text", "text": text})
|
||||
self.add_message({"role": "user", "content": content})
|
||||
|
||||
async def call_function(
|
||||
self,
|
||||
f: Callable[
|
||||
|
||||
@@ -8,7 +8,7 @@ import asyncio
|
||||
import io
|
||||
import wave
|
||||
from abc import abstractmethod
|
||||
from typing import AsyncGenerator, List, Optional, Tuple, Union
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -45,6 +45,7 @@ class AIService(FrameProcessor):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._model_name: str = ""
|
||||
self._settings: Dict[str, Any] = {}
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
@@ -63,6 +64,16 @@ class AIService(FrameProcessor):
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
pass
|
||||
|
||||
async def _update_settings(self, settings: Dict[str, Any]):
|
||||
for key, value in settings.items():
|
||||
if key in self._settings:
|
||||
logger.debug(f"Updating setting {key} to: [{value}] for {self.name}")
|
||||
self._settings[key] = value
|
||||
elif key == "model":
|
||||
self.set_model_name(value)
|
||||
else:
|
||||
logger.warning(f"Unknown setting for {self.name} service: {key}")
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
@@ -116,7 +127,7 @@ class LLMService(AIService):
|
||||
tool_call_id: str,
|
||||
function_name: str,
|
||||
arguments: str,
|
||||
run_llm: bool,
|
||||
run_llm: bool = True,
|
||||
) -> None:
|
||||
f = None
|
||||
if function_name in self._callbacks.keys():
|
||||
@@ -169,6 +180,8 @@ class TTSService(AIService):
|
||||
self._push_stop_frames: bool = push_stop_frames
|
||||
self._stop_frame_timeout_s: float = stop_frame_timeout_s
|
||||
self._sample_rate: int = sample_rate
|
||||
self._voice_id: str = ""
|
||||
self._settings: Dict[str, Any] = {}
|
||||
|
||||
self._stop_frame_task: Optional[asyncio.Task] = None
|
||||
self._stop_frame_queue: asyncio.Queue = asyncio.Queue()
|
||||
@@ -184,52 +197,8 @@ class TTSService(AIService):
|
||||
self.set_model_name(model)
|
||||
|
||||
@abstractmethod
|
||||
async def set_voice(self, voice: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_language(self, language: Language):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_speed(self, speed: Union[str, float]):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_emotion(self, emotion: List[str]):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_engine(self, engine: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_pitch(self, pitch: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_rate(self, rate: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_volume(self, volume: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_emphasis(self, emphasis: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_style(self, style: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_style_degree(self, style_degree: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_role(self, role: str):
|
||||
pass
|
||||
def set_voice(self, voice: str):
|
||||
self._voice_id = voice
|
||||
|
||||
@abstractmethod
|
||||
async def flush_audio(self):
|
||||
@@ -259,8 +228,25 @@ class TTSService(AIService):
|
||||
await self._stop_frame_task
|
||||
self._stop_frame_task = None
|
||||
|
||||
async def _update_settings(self, settings: Dict[str, Any]):
|
||||
for key, value in settings.items():
|
||||
if key in self._settings:
|
||||
logger.debug(f"Updating TTS setting {key} to: [{value}]")
|
||||
self._settings[key] = value
|
||||
if key == "language":
|
||||
self._settings[key] = Language(value)
|
||||
elif key == "model":
|
||||
self.set_model_name(value)
|
||||
elif key == "voice":
|
||||
self.set_voice(value)
|
||||
else:
|
||||
logger.warning(f"Unknown setting for TTS service: {key}")
|
||||
|
||||
async def say(self, text: str):
|
||||
aggregate_sentences = self._aggregate_sentences
|
||||
self._aggregate_sentences = False
|
||||
await self.process_frame(TextFrame(text=text), FrameDirection.DOWNSTREAM)
|
||||
self._aggregate_sentences = aggregate_sentences
|
||||
await self.flush_audio()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
@@ -283,7 +269,7 @@ class TTSService(AIService):
|
||||
await self._push_tts_frames(frame.text)
|
||||
await self.flush_audio()
|
||||
elif isinstance(frame, TTSUpdateSettingsFrame):
|
||||
await self._update_tts_settings(frame)
|
||||
await self._update_settings(frame.settings)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -308,16 +294,18 @@ class TTSService(AIService):
|
||||
text = frame.text
|
||||
else:
|
||||
self._current_sentence += frame.text
|
||||
if match_endofsentence(self._current_sentence):
|
||||
text = self._current_sentence
|
||||
self._current_sentence = ""
|
||||
eos_end_marker = match_endofsentence(self._current_sentence)
|
||||
if eos_end_marker:
|
||||
text = self._current_sentence[:eos_end_marker]
|
||||
self._current_sentence = self._current_sentence[eos_end_marker:]
|
||||
|
||||
if text:
|
||||
await self._push_tts_frames(text)
|
||||
|
||||
async def _push_tts_frames(self, text: str):
|
||||
text = text.strip()
|
||||
if not text:
|
||||
# Don't send only whitespace. This causes problems for some TTS models. But also don't
|
||||
# strip all whitespace, as whitespace can influence prosody.
|
||||
if not text.strip():
|
||||
return
|
||||
|
||||
await self.start_processing_metrics()
|
||||
@@ -328,34 +316,6 @@ class TTSService(AIService):
|
||||
# interrupted, the text is not added to the assistant context.
|
||||
await self.push_frame(TextFrame(text))
|
||||
|
||||
async def _update_tts_settings(self, frame: TTSUpdateSettingsFrame):
|
||||
if frame.model is not None:
|
||||
await self.set_model(frame.model)
|
||||
if frame.voice is not None:
|
||||
await self.set_voice(frame.voice)
|
||||
if frame.language is not None:
|
||||
await self.set_language(frame.language)
|
||||
if frame.speed is not None:
|
||||
await self.set_speed(frame.speed)
|
||||
if frame.emotion is not None:
|
||||
await self.set_emotion(frame.emotion)
|
||||
if frame.engine is not None:
|
||||
await self.set_engine(frame.engine)
|
||||
if frame.pitch is not None:
|
||||
await self.set_pitch(frame.pitch)
|
||||
if frame.rate is not None:
|
||||
await self.set_rate(frame.rate)
|
||||
if frame.volume is not None:
|
||||
await self.set_volume(frame.volume)
|
||||
if frame.emphasis is not None:
|
||||
await self.set_emphasis(frame.emphasis)
|
||||
if frame.style is not None:
|
||||
await self.set_style(frame.style)
|
||||
if frame.style_degree is not None:
|
||||
await self.set_style_degree(frame.style_degree)
|
||||
if frame.role is not None:
|
||||
await self.set_role(frame.role)
|
||||
|
||||
async def _stop_frame_handler(self):
|
||||
try:
|
||||
has_started = False
|
||||
@@ -441,25 +401,29 @@ class STTService(AIService):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._settings: Dict[str, Any] = {}
|
||||
|
||||
@abstractmethod
|
||||
async def set_model(self, model: str):
|
||||
self.set_model_name(model)
|
||||
|
||||
@abstractmethod
|
||||
async def set_language(self, language: Language):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Returns transcript as a string"""
|
||||
pass
|
||||
|
||||
async def _update_stt_settings(self, frame: STTUpdateSettingsFrame):
|
||||
if frame.model is not None:
|
||||
await self.set_model(frame.model)
|
||||
if frame.language is not None:
|
||||
await self.set_language(frame.language)
|
||||
async def _update_settings(self, settings: Dict[str, Any]):
|
||||
logger.debug(f"Updating STT settings: {self._settings}")
|
||||
for key, value in settings.items():
|
||||
if key in self._settings:
|
||||
logger.debug(f"Updating STT setting {key} to: [{value}]")
|
||||
self._settings[key] = value
|
||||
if key == "language":
|
||||
self._settings[key] = Language(value)
|
||||
elif key == "model":
|
||||
self.set_model_name(value)
|
||||
else:
|
||||
logger.warning(f"Unknown setting for STT service: {key}")
|
||||
|
||||
async def process_audio_frame(self, frame: AudioRawFrame):
|
||||
await self.process_generator(self.run_stt(frame.audio))
|
||||
@@ -473,7 +437,7 @@ class STTService(AIService):
|
||||
# push a TextFrame. We don't really want to push audio frames down.
|
||||
await self.process_audio_frame(frame)
|
||||
elif isinstance(frame, STTUpdateSettingsFrame):
|
||||
await self._update_stt_settings(frame)
|
||||
await self._update_settings(frame.settings)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
@@ -55,6 +55,7 @@ except ModuleNotFoundError as e:
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
# internal use only -- todo: refactor
|
||||
@dataclass
|
||||
class AnthropicImageMessageFrame(Frame):
|
||||
user_image_raw_frame: UserImageRawFrame
|
||||
@@ -95,12 +96,14 @@ class AnthropicLLMService(LLMService):
|
||||
super().__init__(**kwargs)
|
||||
self._client = AsyncAnthropic(api_key=api_key)
|
||||
self.set_model_name(model)
|
||||
self._max_tokens = params.max_tokens
|
||||
self._enable_prompt_caching_beta: bool = params.enable_prompt_caching_beta or False
|
||||
self._temperature = params.temperature
|
||||
self._top_k = params.top_k
|
||||
self._top_p = params.top_p
|
||||
self._extra = params.extra if isinstance(params.extra, dict) else {}
|
||||
self._settings = {
|
||||
"max_tokens": params.max_tokens,
|
||||
"enable_prompt_caching_beta": params.enable_prompt_caching_beta or False,
|
||||
"temperature": params.temperature,
|
||||
"top_k": params.top_k,
|
||||
"top_p": params.top_p,
|
||||
"extra": params.extra if isinstance(params.extra, dict) else {},
|
||||
}
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
@@ -110,35 +113,15 @@ class AnthropicLLMService(LLMService):
|
||||
return self._enable_prompt_caching_beta
|
||||
|
||||
@staticmethod
|
||||
def create_context_aggregator(context: OpenAILLMContext) -> AnthropicContextAggregatorPair:
|
||||
def create_context_aggregator(
|
||||
context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
|
||||
) -> AnthropicContextAggregatorPair:
|
||||
user = AnthropicUserContextAggregator(context)
|
||||
assistant = AnthropicAssistantContextAggregator(user)
|
||||
assistant = AnthropicAssistantContextAggregator(
|
||||
user, expect_stripped_words=assistant_expect_stripped_words
|
||||
)
|
||||
return AnthropicContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
async def set_enable_prompt_caching_beta(self, enable_prompt_caching_beta: bool):
|
||||
logger.debug(f"Switching LLM enable_prompt_caching_beta to: [{enable_prompt_caching_beta}]")
|
||||
self._enable_prompt_caching_beta = enable_prompt_caching_beta
|
||||
|
||||
async def set_max_tokens(self, max_tokens: int):
|
||||
logger.debug(f"Switching LLM max_tokens to: [{max_tokens}]")
|
||||
self._max_tokens = max_tokens
|
||||
|
||||
async def set_temperature(self, temperature: float):
|
||||
logger.debug(f"Switching LLM temperature to: [{temperature}]")
|
||||
self._temperature = temperature
|
||||
|
||||
async def set_top_k(self, top_k: float):
|
||||
logger.debug(f"Switching LLM top_k to: [{top_k}]")
|
||||
self._top_k = top_k
|
||||
|
||||
async def set_top_p(self, top_p: float):
|
||||
logger.debug(f"Switching LLM top_p to: [{top_p}]")
|
||||
self._top_p = top_p
|
||||
|
||||
async def set_extra(self, extra: Dict[str, Any]):
|
||||
logger.debug(f"Switching LLM extra to: [{extra}]")
|
||||
self._extra = extra
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext):
|
||||
# Usage tracking. We track the usage reported by Anthropic in prompt_tokens and
|
||||
# completion_tokens. We also estimate the completion tokens from output text
|
||||
@@ -160,11 +143,11 @@ class AnthropicLLMService(LLMService):
|
||||
)
|
||||
|
||||
messages = context.messages
|
||||
if self._enable_prompt_caching_beta:
|
||||
if self._settings["enable_prompt_caching_beta"]:
|
||||
messages = context.get_messages_with_cache_control_markers()
|
||||
|
||||
api_call = self._client.messages.create
|
||||
if self._enable_prompt_caching_beta:
|
||||
if self._settings["enable_prompt_caching_beta"]:
|
||||
api_call = self._client.beta.prompt_caching.messages.create
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
@@ -174,14 +157,14 @@ class AnthropicLLMService(LLMService):
|
||||
"system": context.system,
|
||||
"messages": messages,
|
||||
"model": self.model_name,
|
||||
"max_tokens": self._max_tokens,
|
||||
"max_tokens": self._settings["max_tokens"],
|
||||
"stream": True,
|
||||
"temperature": self._temperature,
|
||||
"top_k": self._top_k,
|
||||
"top_p": self._top_p,
|
||||
"temperature": self._settings["temperature"],
|
||||
"top_k": self._settings["top_k"],
|
||||
"top_p": self._settings["top_p"],
|
||||
}
|
||||
|
||||
params.update(self._extra)
|
||||
params.update(self._settings["extra"])
|
||||
|
||||
response = await api_call(**params)
|
||||
|
||||
@@ -279,21 +262,6 @@ class AnthropicLLMService(LLMService):
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
)
|
||||
|
||||
async def _update_settings(self, frame: LLMUpdateSettingsFrame):
|
||||
if frame.model is not None:
|
||||
logger.debug(f"Switching LLM model to: [{frame.model}]")
|
||||
self.set_model_name(frame.model)
|
||||
if frame.max_tokens is not None:
|
||||
await self.set_max_tokens(frame.max_tokens)
|
||||
if frame.temperature is not None:
|
||||
await self.set_temperature(frame.temperature)
|
||||
if frame.top_k is not None:
|
||||
await self.set_top_k(frame.top_k)
|
||||
if frame.top_p is not None:
|
||||
await self.set_top_p(frame.top_p)
|
||||
if frame.extra:
|
||||
await self.set_extra(frame.extra)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
@@ -309,10 +277,10 @@ class AnthropicLLMService(LLMService):
|
||||
# to the context.
|
||||
context = AnthropicLLMContext.from_image_frame(frame)
|
||||
elif isinstance(frame, LLMUpdateSettingsFrame):
|
||||
await self._update_settings(frame)
|
||||
await self._update_settings(frame.settings)
|
||||
elif isinstance(frame, LLMEnablePromptCachingFrame):
|
||||
logger.debug(f"Setting enable prompt caching to: [{frame.enable}]")
|
||||
self._enable_prompt_caching_beta = frame.enable
|
||||
self._settings["enable_prompt_caching_beta"] = frame.enable
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -355,7 +323,6 @@ class AnthropicLLMContext(OpenAILLMContext):
|
||||
system: str | NotGiven = NOT_GIVEN,
|
||||
):
|
||||
super().__init__(messages=messages, tools=tools, tool_choice=tool_choice)
|
||||
self._user_image_request_context = {}
|
||||
|
||||
# For beta prompt caching. This is a counter that tracks the number of turns
|
||||
# we've seen above the cache threshold. We reset this when we reset the
|
||||
@@ -541,8 +508,8 @@ class AnthropicUserContextAggregator(LLMUserContextAggregator):
|
||||
|
||||
|
||||
class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
def __init__(self, user_context_aggregator: AnthropicUserContextAggregator):
|
||||
super().__init__(context=user_context_aggregator._context)
|
||||
def __init__(self, user_context_aggregator: AnthropicUserContextAggregator, **kwargs):
|
||||
super().__init__(context=user_context_aggregator._context, **kwargs)
|
||||
self._user_context_aggregator = user_context_aggregator
|
||||
self._function_call_in_progress = None
|
||||
self._function_call_result = None
|
||||
@@ -579,7 +546,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
run_llm = False
|
||||
|
||||
aggregation = self._aggregation
|
||||
self._aggregation = ""
|
||||
self._reset()
|
||||
|
||||
try:
|
||||
if self._function_call_result:
|
||||
@@ -630,5 +597,8 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
if run_llm:
|
||||
await self._user_context_aggregator.push_context_frame()
|
||||
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing frame: {e}")
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
@@ -16,8 +17,7 @@ from pipecat.frames.frames import (
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import TTSService
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
try:
|
||||
import boto3
|
||||
@@ -30,10 +30,71 @@ except ModuleNotFoundError as e:
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
def language_to_aws_language(language: Language) -> str | None:
|
||||
match language:
|
||||
case Language.CA:
|
||||
return "ca-ES"
|
||||
case Language.ZH:
|
||||
return "cmn-CN"
|
||||
case Language.DA:
|
||||
return "da-DK"
|
||||
case Language.NL:
|
||||
return "nl-NL"
|
||||
case Language.NL_BE:
|
||||
return "nl-BE"
|
||||
case Language.EN:
|
||||
return "en-US"
|
||||
case Language.EN_US:
|
||||
return "en-US"
|
||||
case Language.EN_AU:
|
||||
return "en-AU"
|
||||
case Language.EN_GB:
|
||||
return "en-GB"
|
||||
case Language.EN_NZ:
|
||||
return "en-NZ"
|
||||
case Language.EN_IN:
|
||||
return "en-IN"
|
||||
case Language.FI:
|
||||
return "fi-FI"
|
||||
case Language.FR:
|
||||
return "fr-FR"
|
||||
case Language.FR_CA:
|
||||
return "fr-CA"
|
||||
case Language.DE:
|
||||
return "de-DE"
|
||||
case Language.HI:
|
||||
return "hi-IN"
|
||||
case Language.IT:
|
||||
return "it-IT"
|
||||
case Language.JA:
|
||||
return "ja-JP"
|
||||
case Language.KO:
|
||||
return "ko-KR"
|
||||
case Language.NO:
|
||||
return "nb-NO"
|
||||
case Language.PL:
|
||||
return "pl-PL"
|
||||
case Language.PT:
|
||||
return "pt-PT"
|
||||
case Language.PT_BR:
|
||||
return "pt-BR"
|
||||
case Language.RO:
|
||||
return "ro-RO"
|
||||
case Language.RU:
|
||||
return "ru-RU"
|
||||
case Language.ES:
|
||||
return "es-ES"
|
||||
case Language.SV:
|
||||
return "sv-SE"
|
||||
case Language.TR:
|
||||
return "tr-TR"
|
||||
return None
|
||||
|
||||
|
||||
class AWSTTSService(TTSService):
|
||||
class InputParams(BaseModel):
|
||||
engine: Optional[str] = None
|
||||
language: Optional[str] = None
|
||||
language: Optional[Language] = Language.EN
|
||||
pitch: Optional[str] = None
|
||||
rate: Optional[str] = None
|
||||
volume: Optional[str] = None
|
||||
@@ -57,9 +118,16 @@ class AWSTTSService(TTSService):
|
||||
aws_secret_access_key=api_key,
|
||||
region_name=region,
|
||||
)
|
||||
self._voice_id = voice_id
|
||||
self._sample_rate = sample_rate
|
||||
self._params = params
|
||||
self._settings = {
|
||||
"sample_rate": sample_rate,
|
||||
"engine": params.engine,
|
||||
"language": params.language if params.language else Language.EN,
|
||||
"pitch": params.pitch,
|
||||
"rate": params.rate,
|
||||
"volume": params.volume,
|
||||
}
|
||||
|
||||
self.set_voice(voice_id)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
@@ -67,18 +135,18 @@ class AWSTTSService(TTSService):
|
||||
def _construct_ssml(self, text: str) -> str:
|
||||
ssml = "<speak>"
|
||||
|
||||
if self._params.language:
|
||||
ssml += f"<lang xml:lang='{self._params.language}'>"
|
||||
language = language_to_aws_language(self._settings["language"])
|
||||
ssml += f"<lang xml:lang='{language}'>"
|
||||
|
||||
prosody_attrs = []
|
||||
# Prosody tags are only supported for standard and neural engines
|
||||
if self._params.engine != "generative":
|
||||
if self._params.rate:
|
||||
prosody_attrs.append(f"rate='{self._params.rate}'")
|
||||
if self._params.pitch:
|
||||
prosody_attrs.append(f"pitch='{self._params.pitch}'")
|
||||
if self._params.volume:
|
||||
prosody_attrs.append(f"volume='{self._params.volume}'")
|
||||
if self._settings["engine"] != "generative":
|
||||
if self._settings["rate"]:
|
||||
prosody_attrs.append(f"rate='{self._settings['rate']}'")
|
||||
if self._settings["pitch"]:
|
||||
prosody_attrs.append(f"pitch='{self._settings['pitch']}'")
|
||||
if self._settings["volume"]:
|
||||
prosody_attrs.append(f"volume='{self._settings['volume']}'")
|
||||
|
||||
if prosody_attrs:
|
||||
ssml += f"<prosody {' '.join(prosody_attrs)}>"
|
||||
@@ -90,41 +158,12 @@ class AWSTTSService(TTSService):
|
||||
if prosody_attrs:
|
||||
ssml += "</prosody>"
|
||||
|
||||
if self._params.language:
|
||||
ssml += "</lang>"
|
||||
ssml += "</lang>"
|
||||
|
||||
ssml += "</speak>"
|
||||
|
||||
return ssml
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice_id = voice
|
||||
|
||||
async def set_engine(self, engine: str):
|
||||
logger.debug(f"Switching TTS engine to: [{engine}]")
|
||||
self._params.engine = engine
|
||||
|
||||
async def set_language(self, language: str):
|
||||
logger.debug(f"Switching TTS language to: [{language}]")
|
||||
self._params.language = language
|
||||
|
||||
async def set_pitch(self, pitch: str):
|
||||
logger.debug(f"Switching TTS pitch to: [{pitch}]")
|
||||
self._params.pitch = pitch
|
||||
|
||||
async def set_rate(self, rate: str):
|
||||
logger.debug(f"Switching TTS rate to: [{rate}]")
|
||||
self._params.rate = rate
|
||||
|
||||
async def set_volume(self, volume: str):
|
||||
logger.debug(f"Switching TTS volume to: [{volume}]")
|
||||
self._params.volume = volume
|
||||
|
||||
async def set_params(self, params: InputParams):
|
||||
logger.debug(f"Switching TTS params to: [{params}]")
|
||||
self._params = params
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
@@ -139,8 +178,8 @@ class AWSTTSService(TTSService):
|
||||
"TextType": "ssml",
|
||||
"OutputFormat": "pcm",
|
||||
"VoiceId": self._voice_id,
|
||||
"Engine": self._params.engine,
|
||||
"SampleRate": str(self._sample_rate),
|
||||
"Engine": self._settings["engine"],
|
||||
"SampleRate": str(self._settings["sample_rate"]),
|
||||
}
|
||||
|
||||
# Filter out None values
|
||||
@@ -150,7 +189,7 @@ class AWSTTSService(TTSService):
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
yield TTSStartedFrame()
|
||||
|
||||
if "AudioStream" in response:
|
||||
with response["AudioStream"] as stream:
|
||||
@@ -160,10 +199,10 @@ class AWSTTSService(TTSService):
|
||||
chunk = audio_data[i : i + chunk_size]
|
||||
if len(chunk) > 0:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(chunk, self._sample_rate, 1)
|
||||
frame = TTSAudioRawFrame(chunk, self._settings["sample_rate"], 1)
|
||||
yield frame
|
||||
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except (BotoCoreError, ClientError) as error:
|
||||
logger.exception(f"{self} error generating TTS: {error}")
|
||||
@@ -171,4 +210,4 @@ class AWSTTSService(TTSService):
|
||||
yield ErrorFrame(error=error_message)
|
||||
|
||||
finally:
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -4,12 +4,13 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import io
|
||||
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
@@ -26,12 +27,10 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.services.ai_services import ImageGenService, STTService, TTSService
|
||||
from pipecat.services.openai import BaseOpenAILLMService
|
||||
from pipecat.transcriptions import language
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from loguru import logger
|
||||
|
||||
# See .env.example for Azure configuration needed
|
||||
try:
|
||||
from azure.cognitiveservices.speech import (
|
||||
@@ -73,10 +72,101 @@ class AzureLLMService(BaseOpenAILLMService):
|
||||
)
|
||||
|
||||
|
||||
def language_to_azure_language(language: Language) -> str | None:
|
||||
match language:
|
||||
case Language.BG:
|
||||
return "bg-BG"
|
||||
case Language.CA:
|
||||
return "ca-ES"
|
||||
case Language.ZH:
|
||||
return "zh-CN"
|
||||
case Language.ZH_TW:
|
||||
return "zh-TW"
|
||||
case Language.CS:
|
||||
return "cs-CZ"
|
||||
case Language.DA:
|
||||
return "da-DK"
|
||||
case Language.NL:
|
||||
return "nl-NL"
|
||||
case Language.EN:
|
||||
return "en-US"
|
||||
case Language.EN_US:
|
||||
return "en-US"
|
||||
case Language.EN_AU:
|
||||
return "en-AU"
|
||||
case Language.EN_GB:
|
||||
return "en-GB"
|
||||
case Language.EN_NZ:
|
||||
return "en-NZ"
|
||||
case Language.EN_IN:
|
||||
return "en-IN"
|
||||
case Language.ET:
|
||||
return "et-EE"
|
||||
case Language.FI:
|
||||
return "fi-FI"
|
||||
case Language.NL_BE:
|
||||
return "nl-BE"
|
||||
case Language.FR:
|
||||
return "fr-FR"
|
||||
case Language.FR_CA:
|
||||
return "fr-CA"
|
||||
case Language.DE:
|
||||
return "de-DE"
|
||||
case Language.DE_CH:
|
||||
return "de-CH"
|
||||
case Language.EL:
|
||||
return "el-GR"
|
||||
case Language.HI:
|
||||
return "hi-IN"
|
||||
case Language.HU:
|
||||
return "hu-HU"
|
||||
case Language.ID:
|
||||
return "id-ID"
|
||||
case Language.IT:
|
||||
return "it-IT"
|
||||
case Language.JA:
|
||||
return "ja-JP"
|
||||
case Language.KO:
|
||||
return "ko-KR"
|
||||
case Language.LV:
|
||||
return "lv-LV"
|
||||
case Language.LT:
|
||||
return "lt-LT"
|
||||
case Language.MS:
|
||||
return "ms-MY"
|
||||
case Language.NO:
|
||||
return "nb-NO"
|
||||
case Language.PL:
|
||||
return "pl-PL"
|
||||
case Language.PT:
|
||||
return "pt-PT"
|
||||
case Language.PT_BR:
|
||||
return "pt-BR"
|
||||
case Language.RO:
|
||||
return "ro-RO"
|
||||
case Language.RU:
|
||||
return "ru-RU"
|
||||
case Language.SK:
|
||||
return "sk-SK"
|
||||
case Language.ES:
|
||||
return "es-ES"
|
||||
case Language.SV:
|
||||
return "sv-SE"
|
||||
case Language.TH:
|
||||
return "th-TH"
|
||||
case Language.TR:
|
||||
return "tr-TR"
|
||||
case Language.UK:
|
||||
return "uk-UA"
|
||||
case Language.VI:
|
||||
return "vi-VN"
|
||||
return None
|
||||
|
||||
|
||||
class AzureTTSService(TTSService):
|
||||
class InputParams(BaseModel):
|
||||
emphasis: Optional[str] = None
|
||||
language: Optional[str] = "en-US"
|
||||
language: Optional[Language] = Language.EN
|
||||
pitch: Optional[str] = None
|
||||
rate: Optional[str] = "1.05"
|
||||
role: Optional[str] = None
|
||||
@@ -99,114 +189,68 @@ class AzureTTSService(TTSService):
|
||||
speech_config = SpeechConfig(subscription=api_key, region=region)
|
||||
self._speech_synthesizer = SpeechSynthesizer(speech_config=speech_config, audio_config=None)
|
||||
|
||||
self._voice = voice
|
||||
self._sample_rate = sample_rate
|
||||
self._params = params
|
||||
self._settings = {
|
||||
"sample_rate": sample_rate,
|
||||
"emphasis": params.emphasis,
|
||||
"language": params.language if params.language else Language.EN,
|
||||
"pitch": params.pitch,
|
||||
"rate": params.rate,
|
||||
"role": params.role,
|
||||
"style": params.style,
|
||||
"style_degree": params.style_degree,
|
||||
"volume": params.volume,
|
||||
}
|
||||
|
||||
self.set_voice(voice)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
def _construct_ssml(self, text: str) -> str:
|
||||
language = language_to_azure_language(self._settings["language"])
|
||||
ssml = (
|
||||
f"<speak version='1.0' xml:lang='{self._params.language}' "
|
||||
f"<speak version='1.0' xml:lang='{language}' "
|
||||
"xmlns='http://www.w3.org/2001/10/synthesis' "
|
||||
"xmlns:mstts='http://www.w3.org/2001/mstts'>"
|
||||
f"<voice name='{self._voice}'>"
|
||||
f"<voice name='{self._voice_id}'>"
|
||||
"<mstts:silence type='Sentenceboundary' value='20ms' />"
|
||||
)
|
||||
|
||||
if self._params.style:
|
||||
ssml += f"<mstts:express-as style='{self._params.style}'"
|
||||
if self._params.style_degree:
|
||||
ssml += f" styledegree='{self._params.style_degree}'"
|
||||
if self._params.role:
|
||||
ssml += f" role='{self._params.role}'"
|
||||
if self._settings["style"]:
|
||||
ssml += f"<mstts:express-as style='{self._settings['style']}'"
|
||||
if self._settings["style_degree"]:
|
||||
ssml += f" styledegree='{self._settings['style_degree']}'"
|
||||
if self._settings["role"]:
|
||||
ssml += f" role='{self._settings['role']}'"
|
||||
ssml += ">"
|
||||
|
||||
prosody_attrs = []
|
||||
if self._params.rate:
|
||||
prosody_attrs.append(f"rate='{self._params.rate}'")
|
||||
if self._params.pitch:
|
||||
prosody_attrs.append(f"pitch='{self._params.pitch}'")
|
||||
if self._params.volume:
|
||||
prosody_attrs.append(f"volume='{self._params.volume}'")
|
||||
if self._settings["rate"]:
|
||||
prosody_attrs.append(f"rate='{self._settings['rate']}'")
|
||||
if self._settings["pitch"]:
|
||||
prosody_attrs.append(f"pitch='{self._settings['pitch']}'")
|
||||
if self._settings["volume"]:
|
||||
prosody_attrs.append(f"volume='{self._settings['volume']}'")
|
||||
|
||||
ssml += f"<prosody {' '.join(prosody_attrs)}>"
|
||||
|
||||
if self._params.emphasis:
|
||||
ssml += f"<emphasis level='{self._params.emphasis}'>"
|
||||
if self._settings["emphasis"]:
|
||||
ssml += f"<emphasis level='{self._settings['emphasis']}'>"
|
||||
|
||||
ssml += text
|
||||
|
||||
if self._params.emphasis:
|
||||
if self._settings["emphasis"]:
|
||||
ssml += "</emphasis>"
|
||||
|
||||
ssml += "</prosody>"
|
||||
|
||||
if self._params.style:
|
||||
if self._settings["style"]:
|
||||
ssml += "</mstts:express-as>"
|
||||
|
||||
ssml += "</voice></speak>"
|
||||
|
||||
return ssml
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice = voice
|
||||
|
||||
async def set_emphasis(self, emphasis: str):
|
||||
logger.debug(f"Setting TTS emphasis to: [{emphasis}]")
|
||||
self._params.emphasis = emphasis
|
||||
|
||||
async def set_language(self, language: str):
|
||||
logger.debug(f"Setting TTS language code to: [{language}]")
|
||||
self._params.language = language
|
||||
|
||||
async def set_pitch(self, pitch: str):
|
||||
logger.debug(f"Setting TTS pitch to: [{pitch}]")
|
||||
self._params.pitch = pitch
|
||||
|
||||
async def set_rate(self, rate: str):
|
||||
logger.debug(f"Setting TTS rate to: [{rate}]")
|
||||
self._params.rate = rate
|
||||
|
||||
async def set_role(self, role: str):
|
||||
logger.debug(f"Setting TTS role to: [{role}]")
|
||||
self._params.role = role
|
||||
|
||||
async def set_style(self, style: str):
|
||||
logger.debug(f"Setting TTS style to: [{style}]")
|
||||
self._params.style = style
|
||||
|
||||
async def set_style_degree(self, style_degree: str):
|
||||
logger.debug(f"Setting TTS style degree to: [{style_degree}]")
|
||||
self._params.style_degree = style_degree
|
||||
|
||||
async def set_volume(self, volume: str):
|
||||
logger.debug(f"Setting TTS volume to: [{volume}]")
|
||||
self._params.volume = volume
|
||||
|
||||
async def set_params(self, **kwargs):
|
||||
valid_params = {
|
||||
"voice": self.set_voice,
|
||||
"emphasis": self.set_emphasis,
|
||||
"language_code": self.set_language,
|
||||
"pitch": self.set_pitch,
|
||||
"rate": self.set_rate,
|
||||
"role": self.set_role,
|
||||
"style": self.set_style,
|
||||
"style_degree": self.set_style_degree,
|
||||
"volume": self.set_volume,
|
||||
}
|
||||
|
||||
for param, value in kwargs.items():
|
||||
if param in valid_params:
|
||||
await valid_params[param](value)
|
||||
else:
|
||||
logger.warning(f"Ignoring unknown parameter: {param}")
|
||||
|
||||
logger.debug(f"Updated TTS parameters: {', '.join(kwargs.keys())}")
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
@@ -219,12 +263,14 @@ class AzureTTSService(TTSService):
|
||||
if result.reason == ResultReason.SynthesizingAudioCompleted:
|
||||
await self.start_tts_usage_metrics(text)
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
yield TTSStartedFrame()
|
||||
# Azure always sends a 44-byte header. Strip it off.
|
||||
yield TTSAudioRawFrame(
|
||||
audio=result.audio_data[44:], sample_rate=self._sample_rate, num_channels=1
|
||||
audio=result.audio_data[44:],
|
||||
sample_rate=self._settings["sample_rate"],
|
||||
num_channels=1,
|
||||
)
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
elif result.reason == ResultReason.Canceled:
|
||||
cancellation_details = result.cancellation_details
|
||||
logger.warning(f"Speech synthesis canceled: {cancellation_details.reason}")
|
||||
|
||||
@@ -4,36 +4,35 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
import base64
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from typing import AsyncGenerator, Optional, Union, List
|
||||
from loguru import logger
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
StartInterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
StartFrame,
|
||||
EndFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import TTSService, WordTTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.services.ai_services import WordTTSService, TTSService
|
||||
|
||||
from loguru import logger
|
||||
|
||||
# See .env.example for Cartesia configuration needed
|
||||
try:
|
||||
from cartesia import AsyncCartesia
|
||||
import websockets
|
||||
from cartesia import AsyncCartesia
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
@@ -66,7 +65,7 @@ class CartesiaTTSService(WordTTSService):
|
||||
encoding: Optional[str] = "pcm_s16le"
|
||||
sample_rate: Optional[int] = 16000
|
||||
container: Optional[str] = "raw"
|
||||
language: Optional[str] = "en"
|
||||
language: Optional[Language] = Language.EN
|
||||
speed: Optional[Union[str, float]] = ""
|
||||
emotion: Optional[List[str]] = []
|
||||
|
||||
@@ -77,7 +76,7 @@ class CartesiaTTSService(WordTTSService):
|
||||
voice_id: str,
|
||||
cartesia_version: str = "2024-06-10",
|
||||
url: str = "wss://api.cartesia.ai/tts/websocket",
|
||||
model_id: str = "sonic-english",
|
||||
model: str = "sonic-english",
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
):
|
||||
@@ -101,17 +100,18 @@ class CartesiaTTSService(WordTTSService):
|
||||
self._api_key = api_key
|
||||
self._cartesia_version = cartesia_version
|
||||
self._url = url
|
||||
self._voice_id = voice_id
|
||||
self._model_id = model_id
|
||||
self.set_model_name(model_id)
|
||||
self._output_format = {
|
||||
"container": params.container,
|
||||
"encoding": params.encoding,
|
||||
"sample_rate": params.sample_rate,
|
||||
self._settings = {
|
||||
"output_format": {
|
||||
"container": params.container,
|
||||
"encoding": params.encoding,
|
||||
"sample_rate": params.sample_rate,
|
||||
},
|
||||
"language": params.language if params.language else Language.EN,
|
||||
"speed": params.speed,
|
||||
"emotion": params.emotion,
|
||||
}
|
||||
self._language = params.language
|
||||
self._speed = params.speed
|
||||
self._emotion = params.emotion
|
||||
self.set_model_name(model)
|
||||
self.set_voice(voice_id)
|
||||
|
||||
self._websocket = None
|
||||
self._context_id = None
|
||||
@@ -125,42 +125,28 @@ class CartesiaTTSService(WordTTSService):
|
||||
await super().set_model(model)
|
||||
logger.debug(f"Switching TTS model to: [{model}]")
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice_id = voice
|
||||
|
||||
async def set_speed(self, speed: str):
|
||||
logger.debug(f"Switching TTS speed to: [{speed}]")
|
||||
self._speed = speed
|
||||
|
||||
async def set_emotion(self, emotion: list[str]):
|
||||
logger.debug(f"Switching TTS emotion to: [{emotion}]")
|
||||
self._emotion = emotion
|
||||
|
||||
async def set_language(self, language: Language):
|
||||
logger.debug(f"Switching TTS language to: [{language}]")
|
||||
self._language = language_to_cartesia_language(language)
|
||||
|
||||
def _build_msg(
|
||||
self, text: str = "", continue_transcript: bool = True, add_timestamps: bool = True
|
||||
):
|
||||
voice_config = {"mode": "id", "id": self._voice_id}
|
||||
voice_config = {}
|
||||
voice_config["mode"] = "id"
|
||||
voice_config["id"] = self._voice_id
|
||||
|
||||
if self._speed or self._emotion:
|
||||
if self._settings["speed"] or self._settings["emotion"]:
|
||||
voice_config["__experimental_controls"] = {}
|
||||
if self._speed:
|
||||
voice_config["__experimental_controls"]["speed"] = self._speed
|
||||
if self._emotion:
|
||||
voice_config["__experimental_controls"]["emotion"] = self._emotion
|
||||
if self._settings["speed"]:
|
||||
voice_config["__experimental_controls"]["speed"] = self._settings["speed"]
|
||||
if self._settings["emotion"]:
|
||||
voice_config["__experimental_controls"]["emotion"] = self._settings["emotion"]
|
||||
|
||||
msg = {
|
||||
"transcript": text,
|
||||
"continue": continue_transcript,
|
||||
"context_id": self._context_id,
|
||||
"model_id": self._model_name,
|
||||
"model_id": self.model_name,
|
||||
"voice": voice_config,
|
||||
"output_format": self._output_format,
|
||||
"language": self._language,
|
||||
"output_format": self._settings["output_format"],
|
||||
"language": language_to_cartesia_language(self._settings["language"]),
|
||||
"add_timestamps": add_timestamps,
|
||||
}
|
||||
return json.dumps(msg)
|
||||
@@ -245,7 +231,7 @@ class CartesiaTTSService(WordTTSService):
|
||||
self.start_word_timestamps()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=base64.b64decode(msg["data"]),
|
||||
sample_rate=self._output_format["sample_rate"],
|
||||
sample_rate=self._settings["output_format"]["sample_rate"],
|
||||
num_channels=1,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
@@ -269,8 +255,8 @@ class CartesiaTTSService(WordTTSService):
|
||||
await self._connect()
|
||||
|
||||
if not self._context_id:
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
self._context_id = str(uuid.uuid4())
|
||||
|
||||
msg = self._build_msg(text=text)
|
||||
@@ -280,7 +266,7 @@ class CartesiaTTSService(WordTTSService):
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
@@ -294,7 +280,7 @@ class CartesiaHttpTTSService(TTSService):
|
||||
encoding: Optional[str] = "pcm_s16le"
|
||||
sample_rate: Optional[int] = 16000
|
||||
container: Optional[str] = "raw"
|
||||
language: Optional[str] = "en"
|
||||
language: Optional[Language] = Language.EN
|
||||
speed: Optional[Union[str, float]] = ""
|
||||
emotion: Optional[List[str]] = []
|
||||
|
||||
@@ -303,7 +289,7 @@ class CartesiaHttpTTSService(TTSService):
|
||||
*,
|
||||
api_key: str,
|
||||
voice_id: str,
|
||||
model_id: str = "sonic-english",
|
||||
model: str = "sonic-english",
|
||||
base_url: str = "https://api.cartesia.ai",
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
@@ -311,44 +297,24 @@ class CartesiaHttpTTSService(TTSService):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._api_key = api_key
|
||||
self._voice_id = voice_id
|
||||
self._model_id = model_id
|
||||
self.set_model_name(model_id)
|
||||
self._output_format = {
|
||||
"container": params.container,
|
||||
"encoding": params.encoding,
|
||||
"sample_rate": params.sample_rate,
|
||||
self._settings = {
|
||||
"output_format": {
|
||||
"container": params.container,
|
||||
"encoding": params.encoding,
|
||||
"sample_rate": params.sample_rate,
|
||||
},
|
||||
"language": params.language if params.language else Language.EN,
|
||||
"speed": params.speed,
|
||||
"emotion": params.emotion,
|
||||
}
|
||||
self._language = params.language
|
||||
self._speed = params.speed
|
||||
self._emotion = params.emotion
|
||||
self.set_voice(voice_id)
|
||||
self.set_model_name(model)
|
||||
|
||||
self._client = AsyncCartesia(api_key=api_key, base_url=base_url)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def set_model(self, model: str):
|
||||
logger.debug(f"Switching TTS model to: [{model}]")
|
||||
self._model_id = model
|
||||
await super().set_model(model)
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice_id = voice
|
||||
|
||||
async def set_speed(self, speed: str):
|
||||
logger.debug(f"Switching TTS speed to: [{speed}]")
|
||||
self._speed = speed
|
||||
|
||||
async def set_emotion(self, emotion: list[str]):
|
||||
logger.debug(f"Switching TTS emotion to: [{emotion}]")
|
||||
self._emotion = emotion
|
||||
|
||||
async def set_language(self, language: Language):
|
||||
logger.debug(f"Switching TTS language to: [{language}]")
|
||||
self._language = language_to_cartesia_language(language)
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._client.close()
|
||||
@@ -360,24 +326,24 @@ class CartesiaHttpTTSService(TTSService):
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
|
||||
try:
|
||||
voice_controls = None
|
||||
if self._speed or self._emotion:
|
||||
if self._settings["speed"] or self._settings["emotion"]:
|
||||
voice_controls = {}
|
||||
if self._speed:
|
||||
voice_controls["speed"] = self._speed
|
||||
if self._emotion:
|
||||
voice_controls["emotion"] = self._emotion
|
||||
if self._settings["speed"]:
|
||||
voice_controls["speed"] = self._settings["speed"]
|
||||
if self._settings["emotion"]:
|
||||
voice_controls["emotion"] = self._settings["emotion"]
|
||||
|
||||
output = await self._client.tts.sse(
|
||||
model_id=self._model_id,
|
||||
model_id=self._model_name,
|
||||
transcript=text,
|
||||
voice_id=self._voice_id,
|
||||
output_format=self._output_format,
|
||||
language=self._language,
|
||||
output_format=self._settings["output_format"],
|
||||
language=language_to_cartesia_language(self._settings["language"]),
|
||||
stream=False,
|
||||
_experimental_voice_controls=voice_controls,
|
||||
)
|
||||
@@ -386,7 +352,7 @@ class CartesiaHttpTTSService(TTSService):
|
||||
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=output["audio"],
|
||||
sample_rate=self._output_format["sample_rate"],
|
||||
sample_rate=self._settings["output_format"]["sample_rate"],
|
||||
num_channels=1,
|
||||
)
|
||||
yield frame
|
||||
@@ -394,4 +360,4 @@ class CartesiaHttpTTSService(TTSService):
|
||||
logger.error(f"{self} exception: {e}")
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -5,9 +5,10 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
@@ -24,8 +25,6 @@ from pipecat.services.ai_services import STTService, TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
from loguru import logger
|
||||
|
||||
# See .env.example for Deepgram configuration needed
|
||||
try:
|
||||
from deepgram import (
|
||||
@@ -57,25 +56,23 @@ class DeepgramTTSService(TTSService):
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._voice = voice
|
||||
self._sample_rate = sample_rate
|
||||
self._encoding = encoding
|
||||
self._settings = {
|
||||
"sample_rate": sample_rate,
|
||||
"encoding": encoding,
|
||||
}
|
||||
self.set_voice(voice)
|
||||
self._deepgram_client = DeepgramClient(api_key=api_key)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice = voice
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
options = SpeakOptions(
|
||||
model=self._voice,
|
||||
encoding=self._encoding,
|
||||
sample_rate=self._sample_rate,
|
||||
model=self._voice_id,
|
||||
encoding=self._settings["encoding"],
|
||||
sample_rate=self._settings["sample_rate"],
|
||||
container="none",
|
||||
)
|
||||
|
||||
@@ -87,7 +84,7 @@ class DeepgramTTSService(TTSService):
|
||||
)
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
yield TTSStartedFrame()
|
||||
|
||||
# The response.stream_memory is already a BytesIO object
|
||||
audio_buffer = response.stream_memory
|
||||
@@ -103,10 +100,12 @@ class DeepgramTTSService(TTSService):
|
||||
chunk = audio_buffer.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
frame = TTSAudioRawFrame(audio=chunk, sample_rate=self._sample_rate, num_channels=1)
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=chunk, sample_rate=self._settings["sample_rate"], num_channels=1
|
||||
)
|
||||
yield frame
|
||||
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
@@ -121,7 +120,7 @@ class DeepgramSTTService(STTService):
|
||||
url: str = "",
|
||||
live_options: LiveOptions = LiveOptions(
|
||||
encoding="linear16",
|
||||
language="en-US",
|
||||
language=Language.EN,
|
||||
model="nova-2-conversationalai",
|
||||
sample_rate=16000,
|
||||
channels=1,
|
||||
@@ -135,7 +134,7 @@ class DeepgramSTTService(STTService):
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._live_options = live_options
|
||||
self._settings = vars(live_options)
|
||||
|
||||
self._client = DeepgramClient(
|
||||
api_key, config=DeepgramClientOptions(url=url, options={"keepalive": "true"})
|
||||
@@ -147,7 +146,7 @@ class DeepgramSTTService(STTService):
|
||||
|
||||
@property
|
||||
def vad_enabled(self):
|
||||
return self._live_options.vad_events
|
||||
return self._settings["vad_events"]
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return self.vad_enabled
|
||||
@@ -155,13 +154,7 @@ class DeepgramSTTService(STTService):
|
||||
async def set_model(self, model: str):
|
||||
await super().set_model(model)
|
||||
logger.debug(f"Switching STT model to: [{model}]")
|
||||
self._live_options.model = model
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
async def set_language(self, language: Language):
|
||||
logger.debug(f"Switching STT language to: [{language}]")
|
||||
self._live_options.language = language
|
||||
self._settings["model"] = model
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
@@ -182,7 +175,7 @@ class DeepgramSTTService(STTService):
|
||||
yield None
|
||||
|
||||
async def _connect(self):
|
||||
if await self._connection.start(self._live_options):
|
||||
if await self._connection.start(self._settings):
|
||||
logger.debug(f"{self}: Connected to Deepgram")
|
||||
else:
|
||||
logger.error(f"{self}: Unable to connect to Deepgram")
|
||||
|
||||
@@ -24,6 +24,7 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import WordTTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
# See .env.example for ElevenLabs configuration needed
|
||||
try:
|
||||
@@ -49,6 +50,76 @@ def sample_rate_from_output_format(output_format: str) -> int:
|
||||
return 16000
|
||||
|
||||
|
||||
def language_to_elevenlabs_language(language: Language) -> str | None:
|
||||
match language:
|
||||
case Language.BG:
|
||||
return "bg"
|
||||
case Language.ZH:
|
||||
return "zh"
|
||||
case Language.CS:
|
||||
return "cs"
|
||||
case Language.DA:
|
||||
return "da"
|
||||
case Language.NL:
|
||||
return "nl"
|
||||
case (
|
||||
Language.EN
|
||||
| Language.EN_US
|
||||
| Language.EN_AU
|
||||
| Language.EN_GB
|
||||
| Language.EN_NZ
|
||||
| Language.EN_IN
|
||||
):
|
||||
return "en"
|
||||
case Language.FI:
|
||||
return "fi"
|
||||
case Language.FR | Language.FR_CA:
|
||||
return "fr"
|
||||
case Language.DE | Language.DE_CH:
|
||||
return "de"
|
||||
case Language.EL:
|
||||
return "el"
|
||||
case Language.HI:
|
||||
return "hi"
|
||||
case Language.HU:
|
||||
return "hu"
|
||||
case Language.ID:
|
||||
return "id"
|
||||
case Language.IT:
|
||||
return "it"
|
||||
case Language.JA:
|
||||
return "ja"
|
||||
case Language.KO:
|
||||
return "ko"
|
||||
case Language.MS:
|
||||
return "ms"
|
||||
case Language.NO:
|
||||
return "no"
|
||||
case Language.PL:
|
||||
return "pl"
|
||||
case Language.PT:
|
||||
return "pt-PT"
|
||||
case Language.PT_BR:
|
||||
return "pt-BR"
|
||||
case Language.RO:
|
||||
return "ro"
|
||||
case Language.RU:
|
||||
return "ru"
|
||||
case Language.SK:
|
||||
return "sk"
|
||||
case Language.ES:
|
||||
return "es"
|
||||
case Language.SV:
|
||||
return "sv"
|
||||
case Language.TR:
|
||||
return "tr"
|
||||
case Language.UK:
|
||||
return "uk"
|
||||
case Language.VI:
|
||||
return "vi"
|
||||
return None
|
||||
|
||||
|
||||
def calculate_word_times(
|
||||
alignment_info: Mapping[str, Any], cumulative_time: float
|
||||
) -> List[Tuple[str, float]]:
|
||||
@@ -72,7 +143,7 @@ def calculate_word_times(
|
||||
|
||||
class ElevenLabsTTSService(WordTTSService):
|
||||
class InputParams(BaseModel):
|
||||
language: Optional[str] = None
|
||||
language: Optional[Language] = Language.EN
|
||||
output_format: Literal["pcm_16000", "pcm_22050", "pcm_24000", "pcm_44100"] = "pcm_16000"
|
||||
optimize_streaming_latency: Optional[str] = None
|
||||
stability: Optional[float] = None
|
||||
@@ -124,10 +195,19 @@ class ElevenLabsTTSService(WordTTSService):
|
||||
)
|
||||
|
||||
self._api_key = api_key
|
||||
self._voice_id = voice_id
|
||||
self.set_model_name(model)
|
||||
self._url = url
|
||||
self._params = params
|
||||
self._settings = {
|
||||
"sample_rate": sample_rate_from_output_format(params.output_format),
|
||||
"language": params.language if params.language else Language.EN,
|
||||
"output_format": params.output_format,
|
||||
"optimize_streaming_latency": params.optimize_streaming_latency,
|
||||
"stability": params.stability,
|
||||
"similarity_boost": params.similarity_boost,
|
||||
"style": params.style,
|
||||
"use_speaker_boost": params.use_speaker_boost,
|
||||
}
|
||||
self.set_model_name(model)
|
||||
self.set_voice(voice_id)
|
||||
self._voice_settings = self._set_voice_settings()
|
||||
|
||||
# Websocket connection to ElevenLabs.
|
||||
@@ -142,19 +222,22 @@ class ElevenLabsTTSService(WordTTSService):
|
||||
|
||||
def _set_voice_settings(self):
|
||||
voice_settings = {}
|
||||
if self._params.stability is not None and self._params.similarity_boost is not None:
|
||||
voice_settings["stability"] = self._params.stability
|
||||
voice_settings["similarity_boost"] = self._params.similarity_boost
|
||||
if self._params.style is not None:
|
||||
voice_settings["style"] = self._params.style
|
||||
if self._params.use_speaker_boost is not None:
|
||||
voice_settings["use_speaker_boost"] = self._params.use_speaker_boost
|
||||
if (
|
||||
self._settings["stability"] is not None
|
||||
and self._settings["similarity_boost"] is not None
|
||||
):
|
||||
voice_settings["stability"] = self._settings["stability"]
|
||||
voice_settings["similarity_boost"] = self._settings["similarity_boost"]
|
||||
if self._settings["style"] is not None:
|
||||
voice_settings["style"] = self._settings["style"]
|
||||
if self._settings["use_speaker_boost"] is not None:
|
||||
voice_settings["use_speaker_boost"] = self._settings["use_speaker_boost"]
|
||||
else:
|
||||
if self._params.style is not None:
|
||||
if self._settings["style"] is not None:
|
||||
logger.warning(
|
||||
"'style' is set but will not be applied because 'stability' and 'similarity_boost' are not both set."
|
||||
)
|
||||
if self._params.use_speaker_boost is not None:
|
||||
if self._settings["use_speaker_boost"] is not None:
|
||||
logger.warning(
|
||||
"'use_speaker_boost' is set but will not be applied because 'stability' and 'similarity_boost' are not both set."
|
||||
)
|
||||
@@ -167,33 +250,13 @@ class ElevenLabsTTSService(WordTTSService):
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice_id = voice
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
async def set_voice_settings(
|
||||
self,
|
||||
stability: Optional[float] = None,
|
||||
similarity_boost: Optional[float] = None,
|
||||
style: Optional[float] = None,
|
||||
use_speaker_boost: Optional[bool] = None,
|
||||
):
|
||||
self._params.stability = stability if stability is not None else self._params.stability
|
||||
self._params.similarity_boost = (
|
||||
similarity_boost if similarity_boost is not None else self._params.similarity_boost
|
||||
)
|
||||
self._params.style = style if style is not None else self._params.style
|
||||
self._params.use_speaker_boost = (
|
||||
use_speaker_boost if use_speaker_boost is not None else self._params.use_speaker_boost
|
||||
)
|
||||
|
||||
self._set_voice_settings()
|
||||
|
||||
if self._websocket:
|
||||
msg = {"voice_settings": self._voice_settings}
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
async def _update_settings(self, settings: Dict[str, Any]):
|
||||
prev_voice = self._voice_id
|
||||
await super()._update_settings(settings)
|
||||
if not prev_voice == self._voice_id:
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
logger.debug(f"Switching TTS voice to: [{self._voice_id}]")
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
@@ -223,20 +286,20 @@ class ElevenLabsTTSService(WordTTSService):
|
||||
try:
|
||||
voice_id = self._voice_id
|
||||
model = self.model_name
|
||||
output_format = self._params.output_format
|
||||
output_format = self._settings["output_format"]
|
||||
url = f"{self._url}/v1/text-to-speech/{voice_id}/stream-input?model_id={model}&output_format={output_format}"
|
||||
|
||||
if self._params.optimize_streaming_latency:
|
||||
url += f"&optimize_streaming_latency={self._params.optimize_streaming_latency}"
|
||||
if self._settings["optimize_streaming_latency"]:
|
||||
url += f"&optimize_streaming_latency={self._settings['optimize_streaming_latency']}"
|
||||
|
||||
# language can only be used with the 'eleven_turbo_v2_5' model
|
||||
if self._params.language:
|
||||
if model == "eleven_turbo_v2_5":
|
||||
url += f"&language_code={self._params.language}"
|
||||
else:
|
||||
logger.debug(
|
||||
f"Language code [{self._params.language}] not applied. Language codes can only be used with the 'eleven_turbo_v2_5' model."
|
||||
)
|
||||
# Language can only be used with the 'eleven_turbo_v2_5' model
|
||||
language = language_to_elevenlabs_language(self._settings["language"])
|
||||
if model == "eleven_turbo_v2_5":
|
||||
url += f"&language_code={language}"
|
||||
else:
|
||||
logger.debug(
|
||||
f"Language code [{language}] not applied. Language codes can only be used with the 'eleven_turbo_v2_5' model."
|
||||
)
|
||||
|
||||
self._websocket = await websockets.connect(url)
|
||||
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
|
||||
@@ -286,7 +349,7 @@ class ElevenLabsTTSService(WordTTSService):
|
||||
self.start_word_timestamps()
|
||||
|
||||
audio = base64.b64decode(msg["audio"])
|
||||
frame = TTSAudioRawFrame(audio, self.sample_rate, 1)
|
||||
frame = TTSAudioRawFrame(audio, self._settings["sample_rate"], 1)
|
||||
await self.push_frame(frame)
|
||||
|
||||
if msg.get("alignment"):
|
||||
@@ -322,8 +385,8 @@ class ElevenLabsTTSService(WordTTSService):
|
||||
|
||||
try:
|
||||
if not self._started:
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
self._started = True
|
||||
self._cumulative_time = 0
|
||||
|
||||
@@ -331,7 +394,7 @@ class ElevenLabsTTSService(WordTTSService):
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
|
||||
@@ -6,8 +6,9 @@
|
||||
|
||||
import base64
|
||||
import json
|
||||
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
@@ -19,10 +20,9 @@ from pipecat.frames.frames import (
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import STTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
from loguru import logger
|
||||
|
||||
# See .env.example for Gladia configuration needed
|
||||
try:
|
||||
import websockets
|
||||
@@ -34,10 +34,88 @@ except ModuleNotFoundError as e:
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
def language_to_gladia_language(language: Language) -> str | None:
|
||||
match language:
|
||||
case Language.BG:
|
||||
return "bulgarian"
|
||||
case Language.CA:
|
||||
return "catalan"
|
||||
case Language.ZH:
|
||||
return "chinese"
|
||||
case Language.CS:
|
||||
return "czech"
|
||||
case Language.DA:
|
||||
return "danish"
|
||||
case Language.NL:
|
||||
return "dutch"
|
||||
case (
|
||||
Language.EN
|
||||
| Language.EN_US
|
||||
| Language.EN_AU
|
||||
| Language.EN_GB
|
||||
| Language.EN_NZ
|
||||
| Language.EN_IN
|
||||
):
|
||||
return "english"
|
||||
case Language.ET:
|
||||
return "estonian"
|
||||
case Language.FI:
|
||||
return "finnish"
|
||||
case Language.FR | Language.FR_CA:
|
||||
return "french"
|
||||
case Language.DE | Language.DE_CH:
|
||||
return "german"
|
||||
case Language.EL:
|
||||
return "greek"
|
||||
case Language.HI:
|
||||
return "hindi"
|
||||
case Language.HU:
|
||||
return "hungarian"
|
||||
case Language.ID:
|
||||
return "indonesian"
|
||||
case Language.IT:
|
||||
return "italian"
|
||||
case Language.JA:
|
||||
return "japanese"
|
||||
case Language.KO:
|
||||
return "korean"
|
||||
case Language.LV:
|
||||
return "latvian"
|
||||
case Language.LT:
|
||||
return "lithuanian"
|
||||
case Language.MS:
|
||||
return "malay"
|
||||
case Language.NO:
|
||||
return "norwegian"
|
||||
case Language.PL:
|
||||
return "polish"
|
||||
case Language.PT | Language.PT_BR:
|
||||
return "portuguese"
|
||||
case Language.RO:
|
||||
return "romanian"
|
||||
case Language.RU:
|
||||
return "russian"
|
||||
case Language.SK:
|
||||
return "slovak"
|
||||
case Language.ES:
|
||||
return "spanish"
|
||||
case Language.SV:
|
||||
return "slovenian"
|
||||
case Language.TH:
|
||||
return "thai"
|
||||
case Language.TR:
|
||||
return "turkish"
|
||||
case Language.UK:
|
||||
return "ukrainian"
|
||||
case Language.VI:
|
||||
return "vietnamese"
|
||||
return None
|
||||
|
||||
|
||||
class GladiaSTTService(STTService):
|
||||
class InputParams(BaseModel):
|
||||
sample_rate: Optional[int] = 16000
|
||||
language: Optional[str] = "english"
|
||||
language: Optional[Language] = Language.EN
|
||||
transcription_hint: Optional[str] = None
|
||||
endpointing: Optional[int] = 200
|
||||
prosody: Optional[bool] = None
|
||||
@@ -55,7 +133,13 @@ class GladiaSTTService(STTService):
|
||||
|
||||
self._api_key = api_key
|
||||
self._url = url
|
||||
self._params = params
|
||||
self._settings = {
|
||||
"sample_rate": params.sample_rate,
|
||||
"language": params.language if params.language else Language.EN,
|
||||
"transcription_hint": params.transcription_hint,
|
||||
"endpointing": params.endpointing,
|
||||
"prosody": params.prosody,
|
||||
}
|
||||
self._confidence = confidence
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
@@ -84,7 +168,11 @@ class GladiaSTTService(STTService):
|
||||
"encoding": "WAV/PCM",
|
||||
"model_type": "fast",
|
||||
"language_behaviour": "manual",
|
||||
**self._params.model_dump(exclude_none=True),
|
||||
"sample_rate": self._settings["sample_rate"],
|
||||
"language": language_to_gladia_language(self._settings["language"]),
|
||||
"transcription_hint": self._settings["transcription_hint"],
|
||||
"endpointing": self._settings["endpointing"],
|
||||
"prosody": self._settings["prosody"],
|
||||
}
|
||||
|
||||
await self._websocket.send(json.dumps(configuration))
|
||||
|
||||
@@ -30,6 +30,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import LLMService, TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
try:
|
||||
import google.ai.generativelanguage as glm
|
||||
@@ -39,7 +40,7 @@ try:
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use Google AI, you need to `pip install pipecat-ai[google]`. Also, set `GOOGLE_API_KEY` environment variable."
|
||||
"In order to use Google AI, you need to `pip install pipecat-ai[google]`. Also, set the environment variable GOOGLE_API_KEY for the GoogleLLMService and GOOGLE_APPLICATION_CREDENTIALS for the GoogleTTSService`."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
@@ -137,9 +138,7 @@ class GoogleLLMService(LLMService):
|
||||
elif isinstance(frame, VisionImageRawFrame):
|
||||
context = OpenAILLMContext.from_image_frame(frame)
|
||||
elif isinstance(frame, LLMUpdateSettingsFrame):
|
||||
if frame.model is not None:
|
||||
logger.debug(f"Switching LLM model to: [{frame.model}]")
|
||||
self.set_model_name(frame.model)
|
||||
await self._update_settings(frame.settings)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -147,13 +146,100 @@ class GoogleLLMService(LLMService):
|
||||
await self._process_context(context)
|
||||
|
||||
|
||||
def language_to_google_language(language: Language) -> str | None:
|
||||
match language:
|
||||
case Language.BG:
|
||||
return "bg-BG"
|
||||
case Language.CA:
|
||||
return "ca-ES"
|
||||
case Language.ZH:
|
||||
return "cmn-CN"
|
||||
case Language.ZH_TW:
|
||||
return "cmn-TW"
|
||||
case Language.CS:
|
||||
return "cs-CZ"
|
||||
case Language.DA:
|
||||
return "da-DK"
|
||||
case Language.NL:
|
||||
return "nl-NL"
|
||||
case Language.EN:
|
||||
return "en-US"
|
||||
case Language.EN_US:
|
||||
return "en-US"
|
||||
case Language.EN_AU:
|
||||
return "en-AU"
|
||||
case Language.EN_GB:
|
||||
return "en-GB"
|
||||
case Language.EN_IN:
|
||||
return "en-IN"
|
||||
case Language.ET:
|
||||
return "et-EE"
|
||||
case Language.FI:
|
||||
return "fi-FI"
|
||||
case Language.NL_BE:
|
||||
return "nl-BE"
|
||||
case Language.FR:
|
||||
return "fr-FR"
|
||||
case Language.FR_CA:
|
||||
return "fr-CA"
|
||||
case Language.DE:
|
||||
return "de-DE"
|
||||
case Language.EL:
|
||||
return "el-GR"
|
||||
case Language.HI:
|
||||
return "hi-IN"
|
||||
case Language.HU:
|
||||
return "hu-HU"
|
||||
case Language.ID:
|
||||
return "id-ID"
|
||||
case Language.IT:
|
||||
return "it-IT"
|
||||
case Language.JA:
|
||||
return "ja-JP"
|
||||
case Language.KO:
|
||||
return "ko-KR"
|
||||
case Language.LV:
|
||||
return "lv-LV"
|
||||
case Language.LT:
|
||||
return "lt-LT"
|
||||
case Language.MS:
|
||||
return "ms-MY"
|
||||
case Language.NO:
|
||||
return "nb-NO"
|
||||
case Language.PL:
|
||||
return "pl-PL"
|
||||
case Language.PT:
|
||||
return "pt-PT"
|
||||
case Language.PT_BR:
|
||||
return "pt-BR"
|
||||
case Language.RO:
|
||||
return "ro-RO"
|
||||
case Language.RU:
|
||||
return "ru-RU"
|
||||
case Language.SK:
|
||||
return "sk-SK"
|
||||
case Language.ES:
|
||||
return "es-ES"
|
||||
case Language.SV:
|
||||
return "sv-SE"
|
||||
case Language.TH:
|
||||
return "th-TH"
|
||||
case Language.TR:
|
||||
return "tr-TR"
|
||||
case Language.UK:
|
||||
return "uk-UA"
|
||||
case Language.VI:
|
||||
return "vi-VN"
|
||||
return None
|
||||
|
||||
|
||||
class GoogleTTSService(TTSService):
|
||||
class InputParams(BaseModel):
|
||||
pitch: Optional[str] = None
|
||||
rate: Optional[str] = None
|
||||
volume: Optional[str] = None
|
||||
emphasis: Optional[Literal["strong", "moderate", "reduced", "none"]] = None
|
||||
language: Optional[str] = None
|
||||
language: Optional[Language] = Language.EN
|
||||
gender: Optional[Literal["male", "female", "neutral"]] = None
|
||||
google_style: Optional[Literal["apologetic", "calm", "empathetic", "firm", "lively"]] = None
|
||||
|
||||
@@ -169,8 +255,17 @@ class GoogleTTSService(TTSService):
|
||||
):
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._voice_id: str = voice_id
|
||||
self._params = params
|
||||
self._settings = {
|
||||
"sample_rate": sample_rate,
|
||||
"pitch": params.pitch,
|
||||
"rate": params.rate,
|
||||
"volume": params.volume,
|
||||
"emphasis": params.emphasis,
|
||||
"language": params.language if params.language else Language.EN,
|
||||
"gender": params.gender,
|
||||
"google_style": params.google_style,
|
||||
}
|
||||
self.set_voice(voice_id)
|
||||
self._client: texttospeech_v1.TextToSpeechAsyncClient = self._create_client(
|
||||
credentials, credentials_path
|
||||
)
|
||||
@@ -190,8 +285,6 @@ class GoogleTTSService(TTSService):
|
||||
elif credentials_path:
|
||||
# Use service account JSON file if provided
|
||||
creds = service_account.Credentials.from_service_account_file(credentials_path)
|
||||
else:
|
||||
raise ValueError("Either 'credentials' or 'credentials_path' must be provided.")
|
||||
|
||||
return texttospeech_v1.TextToSpeechAsyncClient(credentials=creds)
|
||||
|
||||
@@ -203,38 +296,40 @@ class GoogleTTSService(TTSService):
|
||||
|
||||
# Voice tag
|
||||
voice_attrs = [f"name='{self._voice_id}'"]
|
||||
if self._params.language:
|
||||
voice_attrs.append(f"language='{self._params.language}'")
|
||||
if self._params.gender:
|
||||
voice_attrs.append(f"gender='{self._params.gender}'")
|
||||
|
||||
language = language_to_google_language(self._settings["language"])
|
||||
voice_attrs.append(f"language='{language}'")
|
||||
|
||||
if self._settings["gender"]:
|
||||
voice_attrs.append(f"gender='{self._settings['gender']}'")
|
||||
ssml += f"<voice {' '.join(voice_attrs)}>"
|
||||
|
||||
# Prosody tag
|
||||
prosody_attrs = []
|
||||
if self._params.pitch:
|
||||
prosody_attrs.append(f"pitch='{self._params.pitch}'")
|
||||
if self._params.rate:
|
||||
prosody_attrs.append(f"rate='{self._params.rate}'")
|
||||
if self._params.volume:
|
||||
prosody_attrs.append(f"volume='{self._params.volume}'")
|
||||
if self._settings["pitch"]:
|
||||
prosody_attrs.append(f"pitch='{self._settings['pitch']}'")
|
||||
if self._settings["rate"]:
|
||||
prosody_attrs.append(f"rate='{self._settings['rate']}'")
|
||||
if self._settings["volume"]:
|
||||
prosody_attrs.append(f"volume='{self._settings['volume']}'")
|
||||
|
||||
if prosody_attrs:
|
||||
ssml += f"<prosody {' '.join(prosody_attrs)}>"
|
||||
|
||||
# Emphasis tag
|
||||
if self._params.emphasis:
|
||||
ssml += f"<emphasis level='{self._params.emphasis}'>"
|
||||
if self._settings["emphasis"]:
|
||||
ssml += f"<emphasis level='{self._settings['emphasis']}'>"
|
||||
|
||||
# Google style tag
|
||||
if self._params.google_style:
|
||||
ssml += f"<google:style name='{self._params.google_style}'>"
|
||||
if self._settings["google_style"]:
|
||||
ssml += f"<google:style name='{self._settings['google_style']}'>"
|
||||
|
||||
ssml += text
|
||||
|
||||
# Close tags
|
||||
if self._params.google_style:
|
||||
if self._settings["google_style"]:
|
||||
ssml += "</google:style>"
|
||||
if self._params.emphasis:
|
||||
if self._settings["emphasis"]:
|
||||
ssml += "</emphasis>"
|
||||
if prosody_attrs:
|
||||
ssml += "</prosody>"
|
||||
@@ -242,46 +337,6 @@ class GoogleTTSService(TTSService):
|
||||
|
||||
return ssml
|
||||
|
||||
async def set_voice(self, voice: str) -> None:
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice_id = voice
|
||||
|
||||
async def set_language(self, language: str) -> None:
|
||||
logger.debug(f"Switching TTS language to: [{language}]")
|
||||
self._params.language = language
|
||||
|
||||
async def set_pitch(self, pitch: str) -> None:
|
||||
logger.debug(f"Switching TTS pitch to: [{pitch}]")
|
||||
self._params.pitch = pitch
|
||||
|
||||
async def set_rate(self, rate: str) -> None:
|
||||
logger.debug(f"Switching TTS rate to: [{rate}]")
|
||||
self._params.rate = rate
|
||||
|
||||
async def set_volume(self, volume: str) -> None:
|
||||
logger.debug(f"Switching TTS volume to: [{volume}]")
|
||||
self._params.volume = volume
|
||||
|
||||
async def set_emphasis(
|
||||
self, emphasis: Literal["strong", "moderate", "reduced", "none"]
|
||||
) -> None:
|
||||
logger.debug(f"Switching TTS emphasis to: [{emphasis}]")
|
||||
self._params.emphasis = emphasis
|
||||
|
||||
async def set_gender(self, gender: Literal["male", "female", "neutral"]) -> None:
|
||||
logger.debug(f"Switch TTS gender to [{gender}]")
|
||||
self._params.gender = gender
|
||||
|
||||
async def google_style(
|
||||
self, google_style: Literal["apologetic", "calm", "empathetic", "firm", "lively"]
|
||||
) -> None:
|
||||
logger.debug(f"Switching TTS google style to: [{google_style}]")
|
||||
self._params.google_style = google_style
|
||||
|
||||
async def set_params(self, params: InputParams) -> None:
|
||||
logger.debug(f"Switching TTS params to: [{params}]")
|
||||
self._params = params
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
@@ -291,11 +346,11 @@ class GoogleTTSService(TTSService):
|
||||
ssml = self._construct_ssml(text)
|
||||
synthesis_input = texttospeech_v1.SynthesisInput(ssml=ssml)
|
||||
voice = texttospeech_v1.VoiceSelectionParams(
|
||||
language_code=self._params.language, name=self._voice_id
|
||||
language_code=self._settings["language"], name=self._voice_id
|
||||
)
|
||||
audio_config = texttospeech_v1.AudioConfig(
|
||||
audio_encoding=texttospeech_v1.AudioEncoding.LINEAR16,
|
||||
sample_rate_hertz=self.sample_rate,
|
||||
sample_rate_hertz=self._settings["sample_rate"],
|
||||
)
|
||||
|
||||
request = texttospeech_v1.SynthesizeSpeechRequest(
|
||||
@@ -306,7 +361,7 @@ class GoogleTTSService(TTSService):
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
yield TTSStartedFrame()
|
||||
|
||||
# Skip the first 44 bytes to remove the WAV header
|
||||
audio_content = response.audio_content[44:]
|
||||
@@ -318,15 +373,15 @@ class GoogleTTSService(TTSService):
|
||||
if not chunk:
|
||||
break
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1)
|
||||
frame = TTSAudioRawFrame(chunk, self._settings["sample_rate"], 1)
|
||||
yield frame
|
||||
await asyncio.sleep(0) # Allow other tasks to run
|
||||
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
error_message = f"TTS generation error: {str(e)}"
|
||||
yield ErrorFrame(error=error_message)
|
||||
finally:
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -5,10 +5,10 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
@@ -20,9 +20,9 @@ from pipecat.frames.frames import (
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import TTSService
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
# See .env.example for LMNT configuration needed
|
||||
try:
|
||||
@@ -35,6 +35,32 @@ except ModuleNotFoundError as e:
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
def language_to_lmnt_language(language: Language) -> str | None:
|
||||
match language:
|
||||
case Language.DE:
|
||||
return "de"
|
||||
case (
|
||||
Language.EN
|
||||
| Language.EN_US
|
||||
| Language.EN_AU
|
||||
| Language.EN_GB
|
||||
| Language.EN_NZ
|
||||
| Language.EN_IN
|
||||
):
|
||||
return "en"
|
||||
case Language.ES:
|
||||
return "es"
|
||||
case Language.FR | Language.FR_CA:
|
||||
return "fr"
|
||||
case Language.PT | Language.PT_BR:
|
||||
return "pt"
|
||||
case Language.ZH | Language.ZH_TW:
|
||||
return "zh"
|
||||
case Language.KO:
|
||||
return "ko"
|
||||
return None
|
||||
|
||||
|
||||
class LmntTTSService(TTSService):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -42,7 +68,7 @@ class LmntTTSService(TTSService):
|
||||
api_key: str,
|
||||
voice_id: str,
|
||||
sample_rate: int = 24000,
|
||||
language: str = "en",
|
||||
language: Language = Language.EN,
|
||||
**kwargs,
|
||||
):
|
||||
# Let TTSService produce TTSStoppedFrames after a short delay of
|
||||
@@ -50,13 +76,16 @@ class LmntTTSService(TTSService):
|
||||
super().__init__(push_stop_frames=True, sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._api_key = api_key
|
||||
self._voice_id = voice_id
|
||||
self._output_format = {
|
||||
"container": "raw",
|
||||
"encoding": "pcm_s16le",
|
||||
"sample_rate": sample_rate,
|
||||
self._settings = {
|
||||
"output_format": {
|
||||
"container": "raw",
|
||||
"encoding": "pcm_s16le",
|
||||
"sample_rate": sample_rate,
|
||||
},
|
||||
"language": language,
|
||||
}
|
||||
self._language = language
|
||||
|
||||
self.set_voice(voice_id)
|
||||
|
||||
self._speech = None
|
||||
self._connection = None
|
||||
@@ -68,10 +97,6 @@ class LmntTTSService(TTSService):
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice_id = voice
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
await self._connect()
|
||||
@@ -93,7 +118,9 @@ class LmntTTSService(TTSService):
|
||||
try:
|
||||
self._speech = Speech()
|
||||
self._connection = await self._speech.synthesize_streaming(
|
||||
self._voice_id, format="raw", sample_rate=self._output_format["sample_rate"]
|
||||
self._voice_id,
|
||||
format="raw",
|
||||
sample_rate=self._settings["output_format"]["sample_rate"],
|
||||
)
|
||||
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
|
||||
except Exception as e:
|
||||
@@ -130,7 +157,7 @@ class LmntTTSService(TTSService):
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=msg["audio"],
|
||||
sample_rate=self._output_format["sample_rate"],
|
||||
sample_rate=self._settings["output_format"]["sample_rate"],
|
||||
num_channels=1,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
@@ -149,8 +176,8 @@ class LmntTTSService(TTSService):
|
||||
await self._connect()
|
||||
|
||||
if not self._started:
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
self._started = True
|
||||
|
||||
try:
|
||||
@@ -159,7 +186,7 @@ class LmntTTSService(TTSService):
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
|
||||
@@ -31,6 +31,8 @@ from pipecat.frames.frames import (
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
URLImageRawFrame,
|
||||
UserImageRawFrame,
|
||||
UserImageRequestFrame,
|
||||
VisionImageRawFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
@@ -109,14 +111,16 @@ class BaseOpenAILLMService(LLMService):
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._settings = {
|
||||
"frequency_penalty": params.frequency_penalty,
|
||||
"presence_penalty": params.presence_penalty,
|
||||
"seed": params.seed,
|
||||
"temperature": params.temperature,
|
||||
"top_p": params.top_p,
|
||||
"extra": params.extra if isinstance(params.extra, dict) else {},
|
||||
}
|
||||
self.set_model_name(model)
|
||||
self._client = self.create_client(api_key=api_key, base_url=base_url, **kwargs)
|
||||
self._frequency_penalty = params.frequency_penalty
|
||||
self._presence_penalty = params.presence_penalty
|
||||
self._seed = params.seed
|
||||
self._temperature = params.temperature
|
||||
self._top_p = params.top_p
|
||||
self._extra = params.extra if isinstance(params.extra, dict) else {}
|
||||
|
||||
def create_client(self, api_key=None, base_url=None, **kwargs):
|
||||
return AsyncOpenAI(
|
||||
@@ -132,30 +136,6 @@ class BaseOpenAILLMService(LLMService):
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def set_frequency_penalty(self, frequency_penalty: float):
|
||||
logger.debug(f"Switching LLM frequency_penalty to: [{frequency_penalty}]")
|
||||
self._frequency_penalty = frequency_penalty
|
||||
|
||||
async def set_presence_penalty(self, presence_penalty: float):
|
||||
logger.debug(f"Switching LLM presence_penalty to: [{presence_penalty}]")
|
||||
self._presence_penalty = presence_penalty
|
||||
|
||||
async def set_seed(self, seed: int):
|
||||
logger.debug(f"Switching LLM seed to: [{seed}]")
|
||||
self._seed = seed
|
||||
|
||||
async def set_temperature(self, temperature: float):
|
||||
logger.debug(f"Switching LLM temperature to: [{temperature}]")
|
||||
self._temperature = temperature
|
||||
|
||||
async def set_top_p(self, top_p: float):
|
||||
logger.debug(f"Switching LLM top_p to: [{top_p}]")
|
||||
self._top_p = top_p
|
||||
|
||||
async def set_extra(self, extra: Dict[str, Any]):
|
||||
logger.debug(f"Switching LLM extra to: [{extra}]")
|
||||
self._extra = extra
|
||||
|
||||
async def get_chat_completions(
|
||||
self, context: OpenAILLMContext, messages: List[ChatCompletionMessageParam]
|
||||
) -> AsyncStream[ChatCompletionChunk]:
|
||||
@@ -166,14 +146,14 @@ class BaseOpenAILLMService(LLMService):
|
||||
"tools": context.tools,
|
||||
"tool_choice": context.tool_choice,
|
||||
"stream_options": {"include_usage": True},
|
||||
"frequency_penalty": self._frequency_penalty,
|
||||
"presence_penalty": self._presence_penalty,
|
||||
"seed": self._seed,
|
||||
"temperature": self._temperature,
|
||||
"top_p": self._top_p,
|
||||
"frequency_penalty": self._settings["frequency_penalty"],
|
||||
"presence_penalty": self._settings["presence_penalty"],
|
||||
"seed": self._settings["seed"],
|
||||
"temperature": self._settings["temperature"],
|
||||
"top_p": self._settings["top_p"],
|
||||
}
|
||||
|
||||
params.update(self._extra)
|
||||
params.update(self._settings["extra"])
|
||||
|
||||
chunks = await self._client.chat.completions.create(**params)
|
||||
return chunks
|
||||
@@ -181,7 +161,7 @@ class BaseOpenAILLMService(LLMService):
|
||||
async def _stream_chat_completions(
|
||||
self, context: OpenAILLMContext
|
||||
) -> AsyncStream[ChatCompletionChunk]:
|
||||
logger.debug(f"Generating chat: {context.get_messages_json()}")
|
||||
logger.debug(f"Generating chat: {context.get_messages_for_logging()}")
|
||||
|
||||
messages: List[ChatCompletionMessageParam] = context.get_messages()
|
||||
|
||||
@@ -293,23 +273,6 @@ class BaseOpenAILLMService(LLMService):
|
||||
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
|
||||
)
|
||||
|
||||
async def _update_settings(self, frame: LLMUpdateSettingsFrame):
|
||||
if frame.model is not None:
|
||||
logger.debug(f"Switching LLM model to: [{frame.model}]")
|
||||
self.set_model_name(frame.model)
|
||||
if frame.frequency_penalty is not None:
|
||||
await self.set_frequency_penalty(frame.frequency_penalty)
|
||||
if frame.presence_penalty is not None:
|
||||
await self.set_presence_penalty(frame.presence_penalty)
|
||||
if frame.seed is not None:
|
||||
await self.set_seed(frame.seed)
|
||||
if frame.temperature is not None:
|
||||
await self.set_temperature(frame.temperature)
|
||||
if frame.top_p is not None:
|
||||
await self.set_top_p(frame.top_p)
|
||||
if frame.extra:
|
||||
await self.set_extra(frame.extra)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
@@ -321,7 +284,7 @@ class BaseOpenAILLMService(LLMService):
|
||||
elif isinstance(frame, VisionImageRawFrame):
|
||||
context = OpenAILLMContext.from_image_frame(frame)
|
||||
elif isinstance(frame, LLMUpdateSettingsFrame):
|
||||
await self._update_settings(frame)
|
||||
await self._update_settings(frame.settings)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -356,9 +319,13 @@ class OpenAILLMService(BaseOpenAILLMService):
|
||||
super().__init__(model=model, params=params, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def create_context_aggregator(context: OpenAILLMContext) -> OpenAIContextAggregatorPair:
|
||||
def create_context_aggregator(
|
||||
context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
|
||||
) -> OpenAIContextAggregatorPair:
|
||||
user = OpenAIUserContextAggregator(context)
|
||||
assistant = OpenAIAssistantContextAggregator(user)
|
||||
assistant = OpenAIAssistantContextAggregator(
|
||||
user, expect_stripped_words=assistant_expect_stripped_words
|
||||
)
|
||||
return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
|
||||
@@ -421,22 +388,20 @@ class OpenAITTSService(TTSService):
|
||||
):
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._voice: ValidVoice = VALID_VOICES.get(voice, "alloy")
|
||||
self._settings = {
|
||||
"sample_rate": sample_rate,
|
||||
}
|
||||
self.set_model_name(model)
|
||||
self._sample_rate = sample_rate
|
||||
self.set_voice(voice)
|
||||
|
||||
self._client = AsyncOpenAI(api_key=api_key)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice = VALID_VOICES.get(voice, self._voice)
|
||||
|
||||
async def set_model(self, model: str):
|
||||
logger.debug(f"Switching TTS model to: [{model}]")
|
||||
self._model = model
|
||||
self.set_model_name(model)
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
@@ -446,7 +411,7 @@ class OpenAITTSService(TTSService):
|
||||
async with self._client.audio.speech.with_streaming_response.create(
|
||||
input=text,
|
||||
model=self.model_name,
|
||||
voice=self._voice,
|
||||
voice=VALID_VOICES[self._voice_id],
|
||||
response_format="pcm",
|
||||
) as r:
|
||||
if r.status_code != 200:
|
||||
@@ -461,28 +426,68 @@ class OpenAITTSService(TTSService):
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
yield TTSStartedFrame()
|
||||
async for chunk in r.iter_bytes(8192):
|
||||
if len(chunk) > 0:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1)
|
||||
frame = TTSAudioRawFrame(chunk, self._settings["sample_rate"], 1)
|
||||
yield frame
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
except BadRequestError as e:
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
|
||||
|
||||
# internal use only -- todo: refactor
|
||||
@dataclass
|
||||
class OpenAIImageMessageFrame(Frame):
|
||||
user_image_raw_frame: UserImageRawFrame
|
||||
text: Optional[str] = None
|
||||
|
||||
|
||||
class OpenAIUserContextAggregator(LLMUserContextAggregator):
|
||||
def __init__(self, context: OpenAILLMContext):
|
||||
super().__init__(context=context)
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
await super().process_frame(frame, direction)
|
||||
# Our parent method has already called push_frame(). So we can't interrupt the
|
||||
# flow here and we don't need to call push_frame() ourselves.
|
||||
try:
|
||||
if isinstance(frame, UserImageRequestFrame):
|
||||
# The LLM sends a UserImageRequestFrame upstream. Cache any context provided with
|
||||
# that frame so we can use it when we assemble the image message in the assistant
|
||||
# context aggregator.
|
||||
if frame.context:
|
||||
if isinstance(frame.context, str):
|
||||
self._context._user_image_request_context[frame.user_id] = frame.context
|
||||
else:
|
||||
logger.error(
|
||||
f"Unexpected UserImageRequestFrame context type: {type(frame.context)}"
|
||||
)
|
||||
del self._context._user_image_request_context[frame.user_id]
|
||||
else:
|
||||
if frame.user_id in self._context._user_image_request_context:
|
||||
del self._context._user_image_request_context[frame.user_id]
|
||||
elif isinstance(frame, UserImageRawFrame):
|
||||
# Push a new AnthropicImageMessageFrame with the text context we cached
|
||||
# downstream to be handled by our assistant context aggregator. This is
|
||||
# necessary so that we add the message to the context in the right order.
|
||||
text = self._context._user_image_request_context.get(frame.user_id) or ""
|
||||
if text:
|
||||
del self._context._user_image_request_context[frame.user_id]
|
||||
frame = OpenAIImageMessageFrame(user_image_raw_frame=frame, text=text)
|
||||
await self.push_frame(frame)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing frame: {e}")
|
||||
|
||||
|
||||
class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
def __init__(self, user_context_aggregator: OpenAIUserContextAggregator):
|
||||
super().__init__(context=user_context_aggregator._context)
|
||||
def __init__(self, user_context_aggregator: OpenAIUserContextAggregator, **kwargs):
|
||||
super().__init__(context=user_context_aggregator._context, **kwargs)
|
||||
self._user_context_aggregator = user_context_aggregator
|
||||
self._function_calls_in_progress = {}
|
||||
self._function_call_result = None
|
||||
self._pending_image_frame_message = None
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
await super().process_frame(frame, direction)
|
||||
@@ -503,15 +508,20 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
"FunctionCallResultFrame tool_call_id does not match any function call in progress"
|
||||
)
|
||||
self._function_call_result = None
|
||||
elif isinstance(frame, OpenAIImageMessageFrame):
|
||||
self._pending_image_frame_message = frame
|
||||
await self._push_aggregation()
|
||||
|
||||
async def _push_aggregation(self):
|
||||
if not (self._aggregation or self._function_call_result):
|
||||
if not (
|
||||
self._aggregation or self._function_call_result or self._pending_image_frame_message
|
||||
):
|
||||
return
|
||||
|
||||
run_llm = False
|
||||
|
||||
aggregation = self._aggregation
|
||||
self._aggregation = ""
|
||||
self._reset()
|
||||
|
||||
try:
|
||||
if self._function_call_result:
|
||||
@@ -544,8 +554,22 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
else:
|
||||
self._context.add_message({"role": "assistant", "content": aggregation})
|
||||
|
||||
if self._pending_image_frame_message:
|
||||
frame = self._pending_image_frame_message
|
||||
self._pending_image_frame_message = None
|
||||
self._context.add_image_frame_message(
|
||||
format=frame.user_image_raw_frame.format,
|
||||
size=frame.user_image_raw_frame.size,
|
||||
image=frame.user_image_raw_frame.image,
|
||||
text=frame.text,
|
||||
)
|
||||
run_llm = True
|
||||
|
||||
if run_llm:
|
||||
await self._user_context_aggregator.push_context_frame()
|
||||
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing frame: {e}")
|
||||
|
||||
@@ -6,17 +6,21 @@
|
||||
|
||||
import io
|
||||
import struct
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from pipecat.frames.frames import Frame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame
|
||||
from pipecat.services.ai_services import TTSService
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import TTSService
|
||||
|
||||
try:
|
||||
from pyht.client import TTSOptions
|
||||
from pyht.async_client import AsyncClient
|
||||
from pyht.client import TTSOptions
|
||||
from pyht.protos.api_pb2 import Format
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
@@ -39,17 +43,23 @@ class PlayHTTTSService(TTSService):
|
||||
user_id=self._user_id,
|
||||
api_key=self._speech_key,
|
||||
)
|
||||
self._settings = {
|
||||
"sample_rate": sample_rate,
|
||||
"quality": "higher",
|
||||
"format": Format.FORMAT_WAV,
|
||||
"voice_engine": "PlayHT2.0-turbo",
|
||||
}
|
||||
self.set_voice(voice_url)
|
||||
self._options = TTSOptions(
|
||||
voice=voice_url, sample_rate=sample_rate, quality="higher", format=Format.FORMAT_WAV
|
||||
voice=self._voice_id,
|
||||
sample_rate=self._settings["sample_rate"],
|
||||
quality=self._settings["quality"],
|
||||
format=self._settings["format"],
|
||||
)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._options.voice = voice
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
@@ -60,12 +70,12 @@ class PlayHTTTSService(TTSService):
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
playht_gen = self._client.tts(
|
||||
text, voice_engine="PlayHT2.0-turbo", options=self._options
|
||||
text, voice_engine=self._settings["voice_engine"], options=self._options
|
||||
)
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
yield TTSStartedFrame()
|
||||
async for chunk in playht_gen:
|
||||
# skip the RIFF header.
|
||||
if in_header:
|
||||
@@ -83,8 +93,8 @@ class PlayHTTTSService(TTSService):
|
||||
else:
|
||||
if len(chunk):
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(chunk, 16000, 1)
|
||||
frame = TTSAudioRawFrame(chunk, self._settings["sample_rate"], 1)
|
||||
yield frame
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
|
||||
@@ -4,42 +4,18 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from asyncio import CancelledError
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
UserImageRequestFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantContextAggregator,
|
||||
LLMUserContextAggregator,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import LLMService
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
|
||||
try:
|
||||
from together import AsyncTogether
|
||||
# Together.ai is recommending OpenAI-compatible function calling, so we've switched over
|
||||
# to using the OpenAI client library here rather than the Together Python client library.
|
||||
from openai import AsyncOpenAI, DefaultAsyncHttpxClient
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
@@ -48,19 +24,7 @@ except ModuleNotFoundError as e:
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TogetherContextAggregatorPair:
|
||||
_user: "TogetherUserContextAggregator"
|
||||
_assistant: "TogetherAssistantContextAggregator"
|
||||
|
||||
def user(self) -> "TogetherUserContextAggregator":
|
||||
return self._user
|
||||
|
||||
def assistant(self) -> "TogetherAssistantContextAggregator":
|
||||
return self._assistant
|
||||
|
||||
|
||||
class TogetherLLMService(LLMService):
|
||||
class TogetherLLMService(OpenAILLMService):
|
||||
"""This class implements inference with Together's Llama 3.1 models"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
@@ -68,327 +32,45 @@ class TogetherLLMService(LLMService):
|
||||
max_tokens: Optional[int] = Field(default=4096, ge=1)
|
||||
presence_penalty: Optional[float] = Field(default=None, ge=-2.0, le=2.0)
|
||||
temperature: Optional[float] = Field(default=None, ge=0.0, le=1.0)
|
||||
# Note: top_k is currently not supported by the OpenAI client library,
|
||||
# so top_k is ignore right now.
|
||||
top_k: Optional[int] = Field(default=None, ge=0)
|
||||
top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
|
||||
extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
seed: Optional[int] = Field(default=None)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: str = "https://api.together.xyz/v1",
|
||||
model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._client = AsyncTogether(api_key=api_key)
|
||||
super().__init__(api_key=api_key, base_url=base_url, model=model, params=params, **kwargs)
|
||||
self.set_model_name(model)
|
||||
self._max_tokens = params.max_tokens
|
||||
self._frequency_penalty = params.frequency_penalty
|
||||
self._presence_penalty = params.presence_penalty
|
||||
self._temperature = params.temperature
|
||||
self._top_k = params.top_k
|
||||
self._top_p = params.top_p
|
||||
self._extra = params.extra if isinstance(params.extra, dict) else {}
|
||||
self._settings = {
|
||||
"max_tokens": params.max_tokens,
|
||||
"frequency_penalty": params.frequency_penalty,
|
||||
"presence_penalty": params.presence_penalty,
|
||||
"seed": params.seed,
|
||||
"temperature": params.temperature,
|
||||
"top_p": params.top_p,
|
||||
"extra": params.extra if isinstance(params.extra, dict) else {},
|
||||
}
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def create_context_aggregator(context: OpenAILLMContext) -> TogetherContextAggregatorPair:
|
||||
user = TogetherUserContextAggregator(context)
|
||||
assistant = TogetherAssistantContextAggregator(user)
|
||||
return TogetherContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
async def set_frequency_penalty(self, frequency_penalty: float):
|
||||
logger.debug(f"Switching LLM frequency_penalty to: [{frequency_penalty}]")
|
||||
self._frequency_penalty = frequency_penalty
|
||||
|
||||
async def set_max_tokens(self, max_tokens: int):
|
||||
logger.debug(f"Switching LLM max_tokens to: [{max_tokens}]")
|
||||
self._max_tokens = max_tokens
|
||||
|
||||
async def set_presence_penalty(self, presence_penalty: float):
|
||||
logger.debug(f"Switching LLM presence_penalty to: [{presence_penalty}]")
|
||||
self._presence_penalty = presence_penalty
|
||||
|
||||
async def set_temperature(self, temperature: float):
|
||||
logger.debug(f"Switching LLM temperature to: [{temperature}]")
|
||||
self._temperature = temperature
|
||||
|
||||
async def set_top_k(self, top_k: float):
|
||||
logger.debug(f"Switching LLM top_k to: [{top_k}]")
|
||||
self._top_k = top_k
|
||||
|
||||
async def set_top_p(self, top_p: float):
|
||||
logger.debug(f"Switching LLM top_p to: [{top_p}]")
|
||||
self._top_p = top_p
|
||||
|
||||
async def set_extra(self, extra: Dict[str, Any]):
|
||||
logger.debug(f"Switching LLM extra to: [{extra}]")
|
||||
self._extra = extra
|
||||
|
||||
async def _update_settings(self, frame: LLMUpdateSettingsFrame):
|
||||
if frame.model is not None:
|
||||
logger.debug(f"Switching LLM model to: [{frame.model}]")
|
||||
self.set_model_name(frame.model)
|
||||
if frame.frequency_penalty is not None:
|
||||
await self.set_frequency_penalty(frame.frequency_penalty)
|
||||
if frame.max_tokens is not None:
|
||||
await self.set_max_tokens(frame.max_tokens)
|
||||
if frame.presence_penalty is not None:
|
||||
await self.set_presence_penalty(frame.presence_penalty)
|
||||
if frame.temperature is not None:
|
||||
await self.set_temperature(frame.temperature)
|
||||
if frame.top_k is not None:
|
||||
await self.set_top_k(frame.top_k)
|
||||
if frame.top_p is not None:
|
||||
await self.set_top_p(frame.top_p)
|
||||
if frame.extra:
|
||||
await self.set_extra(frame.extra)
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext):
|
||||
try:
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self.start_processing_metrics()
|
||||
|
||||
logger.debug(f"Generating chat: {context.get_messages_for_logging()}")
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
params = {
|
||||
"messages": context.messages,
|
||||
"model": self.model_name,
|
||||
"max_tokens": self._max_tokens,
|
||||
"stream": True,
|
||||
"frequency_penalty": self._frequency_penalty,
|
||||
"presence_penalty": self._presence_penalty,
|
||||
"temperature": self._temperature,
|
||||
"top_k": self._top_k,
|
||||
"top_p": self._top_p,
|
||||
}
|
||||
|
||||
params.update(self._extra)
|
||||
|
||||
stream = await self._client.chat.completions.create(**params)
|
||||
|
||||
# Function calling
|
||||
got_first_chunk = False
|
||||
accumulating_function_call = False
|
||||
function_call_accumulator = ""
|
||||
|
||||
async for chunk in stream:
|
||||
# logger.debug(f"Together LLM event: {chunk}")
|
||||
if chunk.usage:
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=chunk.usage.prompt_tokens,
|
||||
completion_tokens=chunk.usage.completion_tokens,
|
||||
total_tokens=chunk.usage.total_tokens,
|
||||
)
|
||||
await self.start_llm_usage_metrics(tokens)
|
||||
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
|
||||
if not got_first_chunk:
|
||||
await self.stop_ttfb_metrics()
|
||||
if chunk.choices[0].delta.content:
|
||||
got_first_chunk = True
|
||||
if chunk.choices[0].delta.content[0] == "<":
|
||||
accumulating_function_call = True
|
||||
|
||||
if chunk.choices[0].delta.content:
|
||||
if accumulating_function_call:
|
||||
function_call_accumulator += chunk.choices[0].delta.content
|
||||
else:
|
||||
await self.push_frame(TextFrame(chunk.choices[0].delta.content))
|
||||
|
||||
if chunk.choices[0].finish_reason == "eos" and accumulating_function_call:
|
||||
await self._extract_function_call(context, function_call_accumulator)
|
||||
|
||||
except CancelledError:
|
||||
# todo: implement token counting estimates for use when the user interrupts a long generation
|
||||
# we do this in the anthropic.py service
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
finally:
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
context = None
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
context = frame.context
|
||||
elif isinstance(frame, LLMMessagesFrame):
|
||||
context = TogetherLLMContext.from_messages(frame.messages)
|
||||
elif isinstance(frame, LLMUpdateSettingsFrame):
|
||||
await self._update_settings(frame)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if context:
|
||||
await self._process_context(context)
|
||||
|
||||
async def _extract_function_call(self, context, function_call_accumulator):
|
||||
context.add_message({"role": "assistant", "content": function_call_accumulator})
|
||||
|
||||
function_regex = r"<function=(\w+)>(.*?)</function>"
|
||||
match = re.search(function_regex, function_call_accumulator)
|
||||
if match:
|
||||
function_name, args_string = match.groups()
|
||||
try:
|
||||
arguments = json.loads(args_string)
|
||||
await self.call_function(
|
||||
context=context,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
def create_client(self, api_key=None, base_url=None, **kwargs):
|
||||
logger.debug(f"Creating Together.ai client with api {base_url}")
|
||||
return AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
http_client=DefaultAsyncHttpxClient(
|
||||
limits=httpx.Limits(
|
||||
max_keepalive_connections=100, max_connections=1000, keepalive_expiry=None
|
||||
)
|
||||
return
|
||||
except json.JSONDecodeError as error:
|
||||
# We get here if the LLM returns a function call with invalid JSON arguments. This could happen
|
||||
# because of LLM non-determinism, or maybe more often because of user error in the prompt.
|
||||
# Should we do anything more than log a warning?
|
||||
logger.debug(f"Error parsing function arguments: {error}")
|
||||
|
||||
|
||||
class TogetherLLMContext(OpenAILLMContext):
|
||||
def __init__(
|
||||
self,
|
||||
messages: list[dict] | None = None,
|
||||
):
|
||||
super().__init__(messages=messages)
|
||||
|
||||
@classmethod
|
||||
def from_openai_context(cls, openai_context: OpenAILLMContext):
|
||||
self = cls(
|
||||
messages=openai_context.messages,
|
||||
),
|
||||
)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_messages(cls, messages: List[dict]) -> "TogetherLLMContext":
|
||||
return cls(messages=messages)
|
||||
|
||||
def add_message(self, message):
|
||||
try:
|
||||
self.messages.append(message)
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding message: {e}")
|
||||
|
||||
def get_messages_for_logging(self) -> str:
|
||||
return json.dumps(self.messages)
|
||||
|
||||
|
||||
class TogetherUserContextAggregator(LLMUserContextAggregator):
|
||||
def __init__(self, context: OpenAILLMContext | TogetherLLMContext):
|
||||
super().__init__(context=context)
|
||||
|
||||
if isinstance(context, OpenAILLMContext):
|
||||
self._context = TogetherLLMContext.from_openai_context(context)
|
||||
|
||||
async def push_messages_frame(self):
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
await super().process_frame(frame, direction)
|
||||
# Our parent method has already called push_frame(). So we can't interrupt the
|
||||
# flow here and we don't need to call push_frame() ourselves. Possibly something
|
||||
# to talk through (tagging @aleix). At some point we might need to refactor these
|
||||
# context aggregators.
|
||||
try:
|
||||
if isinstance(frame, UserImageRequestFrame):
|
||||
# The LLM sends a UserImageRequestFrame upstream. Cache any context provided with
|
||||
# that frame so we can use it when we assemble the image message in the assistant
|
||||
# context aggregator.
|
||||
if frame.context:
|
||||
if isinstance(frame.context, str):
|
||||
self._context._user_image_request_context[frame.user_id] = frame.context
|
||||
else:
|
||||
logger.error(
|
||||
f"Unexpected UserImageRequestFrame context type: {type(frame.context)}"
|
||||
)
|
||||
del self._context._user_image_request_context[frame.user_id]
|
||||
else:
|
||||
if frame.user_id in self._context._user_image_request_context:
|
||||
del self._context._user_image_request_context[frame.user_id]
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing frame: {e}")
|
||||
|
||||
|
||||
#
|
||||
# Claude returns a text content block along with a tool use content block. This works quite nicely
|
||||
# with streaming. We get the text first, so we can start streaming it right away. Then we get the
|
||||
# tool_use block. While the text is streaming to TTS and the transport, we can run the tool call.
|
||||
#
|
||||
# But Claude is verbose. It would be nice to come up with prompt language that suppresses Claude's
|
||||
# chattiness about it's tool thinking.
|
||||
#
|
||||
|
||||
|
||||
class TogetherAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
def __init__(self, user_context_aggregator: TogetherUserContextAggregator):
|
||||
super().__init__(context=user_context_aggregator._context)
|
||||
self._user_context_aggregator = user_context_aggregator
|
||||
self._function_call_in_progress = None
|
||||
self._function_call_result = None
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
await super().process_frame(frame, direction)
|
||||
# See note above about not calling push_frame() here.
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
self._function_call_in_progress = None
|
||||
self._function_call_finished = None
|
||||
elif isinstance(frame, FunctionCallInProgressFrame):
|
||||
self._function_call_in_progress = frame
|
||||
elif isinstance(frame, FunctionCallResultFrame):
|
||||
if (
|
||||
self._function_call_in_progress
|
||||
and self._function_call_in_progress.tool_call_id == frame.tool_call_id
|
||||
):
|
||||
self._function_call_in_progress = None
|
||||
self._function_call_result = frame
|
||||
await self._push_aggregation()
|
||||
else:
|
||||
logger.warning(
|
||||
"FunctionCallResultFrame tool_call_id does not match FunctionCallInProgressFrame tool_call_id"
|
||||
)
|
||||
self._function_call_in_progress = None
|
||||
self._function_call_result = None
|
||||
|
||||
def add_message(self, message):
|
||||
self._user_context_aggregator.add_message(message)
|
||||
|
||||
async def _push_aggregation(self):
|
||||
if not (self._aggregation or self._function_call_result):
|
||||
return
|
||||
|
||||
run_llm = False
|
||||
|
||||
aggregation = self._aggregation
|
||||
self._aggregation = ""
|
||||
|
||||
try:
|
||||
if self._function_call_result:
|
||||
frame = self._function_call_result
|
||||
self._function_call_result = None
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "tool",
|
||||
# Together expects the content here to be a string, so stringify it
|
||||
"content": str(frame.result),
|
||||
}
|
||||
)
|
||||
run_llm = True
|
||||
else:
|
||||
self._context.add_message({"role": "assistant", "content": aggregation})
|
||||
|
||||
if run_llm:
|
||||
await self._user_context_aggregator.push_messages_frame()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing frame: {e}")
|
||||
|
||||
@@ -4,10 +4,12 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import aiohttp
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
@@ -17,10 +19,7 @@ from pipecat.frames.frames import (
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import TTSService
|
||||
|
||||
import numpy as np
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
try:
|
||||
import resampy
|
||||
@@ -38,21 +37,67 @@ except ModuleNotFoundError as e:
|
||||
# https://github.com/coqui-ai/xtts-streaming-server
|
||||
|
||||
|
||||
def language_to_xtts_language(language: Language) -> str | None:
|
||||
match language:
|
||||
case Language.CS:
|
||||
return "cs"
|
||||
case Language.DE:
|
||||
return "de"
|
||||
case (
|
||||
Language.EN
|
||||
| Language.EN_US
|
||||
| Language.EN_AU
|
||||
| Language.EN_GB
|
||||
| Language.EN_NZ
|
||||
| Language.EN_IN
|
||||
):
|
||||
return "en"
|
||||
case Language.ES:
|
||||
return "es"
|
||||
case Language.FR:
|
||||
return "fr"
|
||||
case Language.HI:
|
||||
return "hi"
|
||||
case Language.HU:
|
||||
return "hu"
|
||||
case Language.IT:
|
||||
return "it"
|
||||
case Language.JA:
|
||||
return "ja"
|
||||
case Language.KO:
|
||||
return "ko"
|
||||
case Language.NL:
|
||||
return "nl"
|
||||
case Language.PL:
|
||||
return "pl"
|
||||
case Language.PT | Language.PT_BR:
|
||||
return "pt"
|
||||
case Language.RU:
|
||||
return "ru"
|
||||
case Language.TR:
|
||||
return "tr"
|
||||
case Language.ZH:
|
||||
return "zh-cn"
|
||||
return None
|
||||
|
||||
|
||||
class XTTSService(TTSService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
voice_id: str,
|
||||
language: str,
|
||||
language: Language,
|
||||
base_url: str,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._voice_id = voice_id
|
||||
self._language = language
|
||||
self._base_url = base_url
|
||||
self._settings = {
|
||||
"language": language,
|
||||
"base_url": base_url,
|
||||
}
|
||||
self.set_voice(voice_id)
|
||||
self._studio_speakers: Dict[str, Any] | None = None
|
||||
self._aiohttp_session = aiohttp_session
|
||||
|
||||
@@ -61,7 +106,7 @@ class XTTSService(TTSService):
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
async with self._aiohttp_session.get(self._base_url + "/studio_speakers") as r:
|
||||
async with self._aiohttp_session.get(self._settings["base_url"] + "/studio_speakers") as r:
|
||||
if r.status != 200:
|
||||
text = await r.text()
|
||||
logger.error(
|
||||
@@ -75,10 +120,6 @@ class XTTSService(TTSService):
|
||||
return
|
||||
self._studio_speakers = await r.json()
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice_id = voice
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
@@ -88,11 +129,13 @@ class XTTSService(TTSService):
|
||||
|
||||
embeddings = self._studio_speakers[self._voice_id]
|
||||
|
||||
url = self._base_url + "/tts_stream"
|
||||
url = self._settings["base_url"] + "/tts_stream"
|
||||
|
||||
language = language_to_xtts_language(self._settings["language"])
|
||||
|
||||
payload = {
|
||||
"text": text.replace(".", "").replace("*", ""),
|
||||
"language": self._language,
|
||||
"language": language,
|
||||
"speaker_embedding": embeddings["speaker_embedding"],
|
||||
"gpt_cond_latent": embeddings["gpt_cond_latent"],
|
||||
"add_wav_header": False,
|
||||
@@ -110,7 +153,7 @@ class XTTSService(TTSService):
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
yield TTSStartedFrame()
|
||||
|
||||
buffer = bytearray()
|
||||
async for chunk in r.content.iter_chunked(1024):
|
||||
@@ -146,4 +189,4 @@ class XTTSService(TTSService):
|
||||
frame = TTSAudioRawFrame(resampled_audio_bytes, 16000, 1)
|
||||
yield frame
|
||||
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -5,17 +5,17 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
CancelFrame,
|
||||
InputAudioRawFrame,
|
||||
StartFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
StopInterruptionFrame,
|
||||
SystemFrame,
|
||||
@@ -23,11 +23,10 @@ from pipecat.frames.frames import (
|
||||
UserStoppedSpeakingFrame,
|
||||
VADParamsUpdateFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
from pipecat.vad.vad_analyzer import VADAnalyzer, VADState
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class BaseInputTransport(FrameProcessor):
|
||||
def __init__(self, params: TransportParams, **kwargs):
|
||||
@@ -87,6 +86,7 @@ class BaseInputTransport(FrameProcessor):
|
||||
elif isinstance(frame, BotInterruptionFrame):
|
||||
logger.debug("Bot interruption")
|
||||
await self._start_interruption()
|
||||
await self.push_frame(StartInterruptionFrame())
|
||||
# All other system frames
|
||||
elif isinstance(frame, SystemFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
|
||||
import re
|
||||
|
||||
|
||||
ENDOFSENTENCE_PATTERN_STR = r"""
|
||||
(?<![A-Z]) # Negative lookbehind: not preceded by an uppercase letter (e.g., "U.S.A.")
|
||||
(?<!\d) # Negative lookbehind: not preceded by a digit (e.g., "1. Let's start")
|
||||
@@ -21,5 +20,6 @@ ENDOFSENTENCE_PATTERN_STR = r"""
|
||||
ENDOFSENTENCE_PATTERN = re.compile(ENDOFSENTENCE_PATTERN_STR, re.VERBOSE)
|
||||
|
||||
|
||||
def match_endofsentence(text: str) -> bool:
|
||||
return ENDOFSENTENCE_PATTERN.search(text.rstrip()) is not None
|
||||
def match_endofsentence(text: str) -> int:
|
||||
match = ENDOFSENTENCE_PATTERN.search(text.rstrip())
|
||||
return match.end() if match else 0
|
||||
|
||||
@@ -2,7 +2,7 @@ aiohttp~=3.10.3
|
||||
anthropic~=0.30.0
|
||||
azure-cognitiveservices-speech~=1.40.0
|
||||
boto3~=1.35.27
|
||||
daily-python~=0.10.1
|
||||
daily-python~=0.11.0
|
||||
deepgram-sdk~=3.5.0
|
||||
fal-client~=0.4.1
|
||||
fastapi~=0.112.1
|
||||
|
||||
Reference in New Issue
Block a user