Compare commits
77 Commits
async-reba
...
v0.0.43
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
66a76af341 | ||
|
|
d402d91c2f | ||
|
|
b05130a089 | ||
|
|
b3cc0779f0 | ||
|
|
cbecae40a9 | ||
|
|
5b8753c8b6 | ||
|
|
3c5f9457f1 | ||
|
|
e32e56d0bc | ||
|
|
788aec665b | ||
|
|
3cada03a92 | ||
|
|
e21fb520f9 | ||
|
|
3403197a90 | ||
|
|
8cdb9ab1ad | ||
|
|
5dbf26d283 | ||
|
|
8001bab9b0 | ||
|
|
12d0686adc | ||
|
|
a28a5e954a | ||
|
|
bb966a89d2 | ||
|
|
4a74eb3321 | ||
|
|
1f54ee6991 | ||
|
|
ea2a05a04b | ||
|
|
5692ca586c | ||
|
|
a11ad81f02 | ||
|
|
c49b31e6ad | ||
|
|
7796a272ce | ||
|
|
27dcf83f37 | ||
|
|
72db83528d | ||
|
|
45c7d36b2e | ||
|
|
65eeb0f1f6 | ||
|
|
1d7d0bb1ea | ||
|
|
598936bc53 | ||
|
|
b1bf6f7733 | ||
|
|
75d27aeb9f | ||
|
|
0a37caf4b4 | ||
|
|
6db65f4335 | ||
|
|
3648874301 | ||
|
|
8bcb5d7fd2 | ||
|
|
8c01a900cd | ||
|
|
d378e699d2 | ||
|
|
c25c375c41 | ||
|
|
70c3ff31fd | ||
|
|
cd2e29f285 | ||
|
|
6d4d7d763d | ||
|
|
6c1851eef8 | ||
|
|
096a15eef6 | ||
|
|
3d642df2b0 | ||
|
|
d75a02dc51 | ||
|
|
28643b453d | ||
|
|
88cca7bf68 | ||
|
|
a397b859fe | ||
|
|
8aae4e9856 | ||
|
|
92d8b37229 | ||
|
|
0801fc578b | ||
|
|
0d5cb84531 | ||
|
|
47b943a117 | ||
|
|
128355add5 | ||
|
|
0499fe41e4 | ||
|
|
6ad3437fd2 | ||
|
|
a5c73ec829 | ||
|
|
def04ac0ce | ||
|
|
5d63615b1b | ||
|
|
90ee284fe0 | ||
|
|
539e0b66fb | ||
|
|
fef393dcac | ||
|
|
ed607d5c4b | ||
|
|
37da7e44cd | ||
|
|
69c7edd60c | ||
|
|
392f210371 | ||
|
|
9a63df1ea1 | ||
|
|
f8a75cede9 | ||
|
|
4d1e370e02 | ||
|
|
d080a31a5c | ||
|
|
a90ebdfe7c | ||
|
|
c8995b82e5 | ||
|
|
6b7f924af6 | ||
|
|
51580e5349 | ||
|
|
ed49cebf2c |
86
CHANGELOG.md
86
CHANGELOG.md
@@ -1,20 +1,55 @@
|
||||
# Changelog
|
||||
|
||||
All notable changes to **pipecat** will be documented in this file.
|
||||
All notable changes to **Pipecat** will be documented in this file.
|
||||
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
## [0.0.43] - 2024-10-10
|
||||
|
||||
### Added
|
||||
|
||||
- Added Google TTS service and corresponding foundational example `07n-interruptible-google.py`
|
||||
- Added a new util called `MarkdownTextFilter` which is a subclass of a new
|
||||
base class called `BaseTextFilter`. This is a configurable utility which
|
||||
is intended to filter text received by TTS services.
|
||||
|
||||
- Added new `RTVIUserLLMTextProcessor`. This processor will send an RTVI
|
||||
`user-llm-text` message with the user content's that was sent to the LLM.
|
||||
|
||||
### Changed
|
||||
|
||||
- `TransportMessageFrame` doesn't have an `urgent` field anymore, instead
|
||||
there's now a `TransportMessageUrgentFrame` which is a `SystemFrame` and
|
||||
therefore skip all internal queuing.
|
||||
|
||||
- For TTS services, convert inputted languages to match each service's language
|
||||
format
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue where changing a language with the Deepgram STT service
|
||||
wouldn't apply the change. This was fixed by disconnecting and reconnecting
|
||||
when the language changes.
|
||||
|
||||
## [0.0.42] - 2024-10-02
|
||||
|
||||
### Added
|
||||
|
||||
- `SentryMetrics` has been added to report frame processor metrics to
|
||||
Sentry. This is now possible because `FrameProcessorMetrics` can now be passed
|
||||
to `FrameProcessor`.
|
||||
|
||||
- Added Google TTS service and corresponding foundational example
|
||||
`07n-interruptible-google.py`
|
||||
|
||||
- Added AWS Polly TTS support and `07m-interruptible-aws.py` as an example.
|
||||
|
||||
- Added InputParams to Azure TTS service.
|
||||
|
||||
- Added `LivekitTransport` (audio-only for now).
|
||||
|
||||
- RTVI 0.2.0 is now supported.
|
||||
|
||||
- All `FrameProcessors` can now register event handlers.
|
||||
|
||||
```
|
||||
@@ -48,15 +83,10 @@ async def on_connected(processor):
|
||||
frames. To achieve that, each frame processor should only output frames from a
|
||||
single task.
|
||||
|
||||
In this version we introduce synchronous and asynchronous frame
|
||||
processors. The synchronous processors push output frames from the same task
|
||||
that they receive input frames, and therefore only pushing frames from one
|
||||
task. Asynchronous frame processors can have internal tasks to perform things
|
||||
asynchronously (e.g. receiving data from a websocket) but they also have a
|
||||
single task where they push frames from.
|
||||
|
||||
By default, frame processors are synchronous. To change a frame processor to
|
||||
asynchronous you only need to pass `sync=False` to the base class constructor.
|
||||
In this version all the frame processors have their own task to push
|
||||
frames. That is, when `push_frame()` is called the given frame will be put
|
||||
into an internal queue (with the exception of system frames) and a frame
|
||||
processor task will push it out.
|
||||
|
||||
- Added pipeline clocks. A pipeline clock is used by the output transport to
|
||||
know when a frame needs to be presented. For that, all frames now have an
|
||||
@@ -68,9 +98,7 @@ async def on_connected(processor):
|
||||
`SystemClock`). This clock will be passed to each frame processor via the
|
||||
`StartFrame`.
|
||||
|
||||
- Added `CartesiaHttpTTSService`. This is a synchronous frame processor
|
||||
(i.e. given an input text frame it will wait for the whole output before
|
||||
returning).
|
||||
- Added `CartesiaHttpTTSService`.
|
||||
|
||||
- `DailyTransport` now supports setting the audio bitrate to improve audio
|
||||
quality through the `DailyParams.audio_out_bitrate` parameter. The new
|
||||
@@ -93,8 +121,12 @@ async def on_connected(processor):
|
||||
|
||||
### Changed
|
||||
|
||||
- Updated individual update settings frame classes into a single UpdateSettingsFrame
|
||||
class for STT, LLM, and TTS.
|
||||
- Context frames are now pushed downstream from assistant context aggregators.
|
||||
|
||||
- Removed Silero VAD torch dependency.
|
||||
|
||||
- Updated individual update settings frame classes into a single
|
||||
`ServiceUpdateSettingsFrame` class.
|
||||
|
||||
- We now distinguish between input and output audio and image frames. We
|
||||
introduce `InputAudioRawFrame`, `OutputAudioRawFrame`, `InputImageRawFrame`
|
||||
@@ -110,12 +142,13 @@ async def on_connected(processor):
|
||||
pipelines to be executed concurrently. The difference between a
|
||||
`SyncParallelPipeline` and a `ParallelPipeline` is that, given an input frame,
|
||||
the `SyncParallelPipeline` will wait for all the internal pipelines to
|
||||
complete. This is achieved by ensuring all the processors in each of the
|
||||
internal pipelines are synchronous.
|
||||
complete. This is achieved by making sure the last processor in each of the
|
||||
pipelines is synchronous (e.g. an HTTP-based service that waits for the
|
||||
response).
|
||||
|
||||
- `StartFrame` is back a system frame so we make sure it's processed immediately
|
||||
by all processors. `EndFrame` stays a control frame since it needs to be
|
||||
ordered allowing the frames in the pipeline to be processed.
|
||||
- `StartFrame` is back a system frame to make sure it's processed immediately by
|
||||
all processors. `EndFrame` stays a control frame since it needs to be ordered
|
||||
allowing the frames in the pipeline to be processed.
|
||||
|
||||
- Updated `MoondreamService` revision to `2024-08-26`.
|
||||
|
||||
@@ -139,6 +172,11 @@ async def on_connected(processor):
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed OpenAI multiple function calls.
|
||||
|
||||
- Fixed a Cartesia TTS issue that would cause audio to be truncated in some
|
||||
cases.
|
||||
|
||||
- Fixed a `BaseOutputTransport` issue that would stop audio and video rendering
|
||||
tasks (after receiving and `EndFrame`) before the internal queue was emptied,
|
||||
causing the pipeline to finish prematurely.
|
||||
@@ -152,6 +190,10 @@ async def on_connected(processor):
|
||||
- `obj_id()` and `obj_count()` now use `itertools.count` avoiding the need of
|
||||
`threading.Lock`.
|
||||
|
||||
### Other
|
||||
|
||||
- Pipecat now uses Ruff as its formatter (https://github.com/astral-sh/ruff).
|
||||
|
||||
## [0.0.41] - 2024-08-22
|
||||
|
||||
### Added
|
||||
|
||||
@@ -86,13 +86,13 @@ async def main():
|
||||
),
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
|
||||
|
||||
tts = CartesiaHttpTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
|
||||
|
||||
imagegen = FalImageGenService(
|
||||
params=FalImageGenService.InputParams(image_size="square_hd"),
|
||||
aiohttp_session=session,
|
||||
@@ -107,8 +107,10 @@ async def main():
|
||||
# that, each pipeline runs concurrently and `SyncParallelPipeline` will
|
||||
# wait for the input frame to be processed.
|
||||
#
|
||||
# Note that `SyncParallelPipeline` requires all processors in it to be
|
||||
# synchronous (which is the default for most processors).
|
||||
# Note that `SyncParallelPipeline` requires the last processor in each
|
||||
# of the pipelines to be synchronous. In this case, we use
|
||||
# `CartesiaHttpTTSService` and `FalImageGenService` which make HTTP
|
||||
# requests and wait for the response.
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
llm, # LLM
|
||||
|
||||
@@ -82,6 +82,7 @@ async def main():
|
||||
self.frame = OutputAudioRawFrame(
|
||||
bytes(self.audio), frame.sample_rate, frame.num_channels
|
||||
)
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
class ImageGrabber(FrameProcessor):
|
||||
def __init__(self):
|
||||
@@ -93,6 +94,7 @@ async def main():
|
||||
|
||||
if isinstance(frame, URLImageRawFrame):
|
||||
self.frame = frame
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
|
||||
|
||||
@@ -121,8 +123,10 @@ async def main():
|
||||
# `SyncParallelPipeline` will wait for the input frame to be
|
||||
# processed.
|
||||
#
|
||||
# Note that `SyncParallelPipeline` requires all processors in it to
|
||||
# be synchronous (which is the default for most processors).
|
||||
# Note that `SyncParallelPipeline` requires the last processor in
|
||||
# each of the pipelines to be synchronous. In this case, we use
|
||||
# `CartesiaHttpTTSService` and `FalImageGenService` which make HTTP
|
||||
# requests and wait for the response.
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
llm, # LLM
|
||||
|
||||
@@ -5,29 +5,24 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.frames.frames import LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantResponseAggregator,
|
||||
LLMUserResponseAggregator,
|
||||
)
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.anthropic import AnthropicLLMService
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
@@ -69,17 +64,17 @@ async def main():
|
||||
},
|
||||
]
|
||||
|
||||
tma_in = LLMUserResponseAggregator(messages)
|
||||
tma_out = LLMAssistantResponseAggregator(messages)
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
tma_in, # User responses
|
||||
context_aggregator.user(), # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
tma_out, # Assistant spoken responses
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -4,11 +4,15 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.frames.frames import LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
@@ -17,17 +21,11 @@ from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantResponseAggregator,
|
||||
LLMUserResponseAggregator,
|
||||
)
|
||||
from pipecat.services.playht import PlayHTTTSService
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.services.playht import PlayHTTTSService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
|
||||
@@ -4,11 +4,15 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.frames.frames import LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
@@ -17,17 +21,10 @@ from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantResponseAggregator,
|
||||
LLMUserResponseAggregator,
|
||||
)
|
||||
from pipecat.services.openai import OpenAITTSService
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.services.openai import OpenAILLMService, OpenAITTSService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
|
||||
@@ -5,29 +5,24 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.frames.frames import LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantResponseAggregator,
|
||||
LLMUserResponseAggregator,
|
||||
)
|
||||
from pipecat.services.ai_services import OpenAILLMContext
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.together import TogetherLLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
@@ -72,25 +67,32 @@ async def main():
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond in plain language. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
tma_in = LLMUserResponseAggregator(messages)
|
||||
tma_out = LLMAssistantResponseAggregator(messages)
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
user_aggregator = context_aggregator.user()
|
||||
assistant_aggregator = context_aggregator.assistant()
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
tma_in, # User responses
|
||||
user_aggregator, # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
tma_out, # Assistant spoken responses
|
||||
assistant_aggregator, # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True))
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
PipelineParams(
|
||||
allow_interruptions=True, enable_metrics=True, enable_usage_metrics=True
|
||||
),
|
||||
)
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
|
||||
@@ -53,7 +53,6 @@ async def main():
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = GoogleTTSService(
|
||||
credentials=os.getenv("GOOGLE_CREDENTIALS"),
|
||||
voice_id="en-US-Neural2-J",
|
||||
params=GoogleTTSService.InputParams(language="en-US", rate="1.05"),
|
||||
)
|
||||
|
||||
@@ -9,11 +9,9 @@ import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
from pipecat.frames.frames import TextFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.logger import FrameLogger
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.openai import OpenAILLMContext, OpenAILLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
@@ -34,7 +32,12 @@ logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def start_fetch_weather(function_name, llm, context):
|
||||
await llm.push_frame(TextFrame("Let me check on that."))
|
||||
# note: we can't push a frame to the LLM here. the bot
|
||||
# can interrupt itself and/or cause audio overlapping glitches.
|
||||
# possible question for Aleix and Chad about what the right way
|
||||
# to trigger speech is, now, with the new queues/async/sync refactors.
|
||||
# await llm.push_frame(TextFrame("Let me check on that."))
|
||||
logger.debug(f"Starting fetch_weather_from_api with function_name: {function_name}")
|
||||
|
||||
|
||||
async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback):
|
||||
@@ -67,9 +70,6 @@ async def main():
|
||||
# sent to the same callback with an additional function_name parameter.
|
||||
llm.register_function(None, fetch_weather_from_api, start_callback=start_fetch_weather)
|
||||
|
||||
fl_in = FrameLogger("Inner")
|
||||
fl_out = FrameLogger("Outer")
|
||||
|
||||
tools = [
|
||||
ChatCompletionToolParam(
|
||||
type="function",
|
||||
@@ -106,11 +106,9 @@ async def main():
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
fl_in,
|
||||
transport.input(),
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
fl_out,
|
||||
tts,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
|
||||
136
examples/foundational/14c-function-calling-together.py
Normal file
136
examples/foundational/14c-function-calling-together.py
Normal file
@@ -0,0 +1,136 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.openai import OpenAILLMContext
|
||||
from pipecat.services.together import TogetherLLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def start_fetch_weather(function_name, llm, context):
|
||||
# note: we can't push a frame to the LLM here. the bot
|
||||
# can interrupt itself and/or cause audio overlapping glitches.
|
||||
# possible question for Aleix and Chad about what the right way
|
||||
# to trigger speech is, now, with the new queues/async/sync refactors.
|
||||
# await llm.push_frame(TextFrame("Let me check on that."))
|
||||
logger.debug(f"Starting fetch_weather_from_api with function_name: {function_name}")
|
||||
|
||||
|
||||
async def fetch_weather_from_api(function_name, tool_call_id, args, llm, context, result_callback):
|
||||
await result_callback({"conditions": "nice", "temperature": "75"})
|
||||
|
||||
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, token) = await configure(session)
|
||||
|
||||
transport = DailyTransport(
|
||||
room_url,
|
||||
token,
|
||||
"Respond bot",
|
||||
DailyParams(
|
||||
audio_out_enabled=True,
|
||||
transcription_enabled=True,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
)
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
|
||||
)
|
||||
|
||||
llm = TogetherLLMService(
|
||||
api_key=os.getenv("TOGETHER_API_KEY"),
|
||||
model="meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
|
||||
)
|
||||
# Register a function_name of None to get all functions
|
||||
# sent to the same callback with an additional function_name parameter.
|
||||
llm.register_function(None, fetch_weather_from_api, start_callback=start_fetch_weather)
|
||||
|
||||
tools = [
|
||||
ChatCompletionToolParam(
|
||||
type="function",
|
||||
function={
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
"required": ["location", "format"],
|
||||
},
|
||||
},
|
||||
)
|
||||
]
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages, tools)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(pipeline)
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
transport.capture_participant_transcription(participant["id"])
|
||||
# Kick off the conversation.
|
||||
# await tts.say("Hi! Ask me about the weather in San Francisco.")
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
167
examples/foundational/14d-function-calling-video.py
Normal file
167
examples/foundational/14d-function-calling-video.py
Normal file
@@ -0,0 +1,167 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.openai import OpenAILLMContext, OpenAILLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
video_participant_id = None
|
||||
|
||||
|
||||
async def get_weather(function_name, tool_call_id, arguments, llm, context, result_callback):
|
||||
location = arguments["location"]
|
||||
await result_callback(f"The weather in {location} is currently 72 degrees and sunny.")
|
||||
|
||||
|
||||
async def get_image(function_name, tool_call_id, arguments, llm, context, result_callback):
|
||||
logger.debug(f"!!! IN get_image {video_participant_id}, {arguments}")
|
||||
question = arguments["question"]
|
||||
await llm.request_image_frame(user_id=video_participant_id, text_content=question)
|
||||
|
||||
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, token) = await configure(session)
|
||||
|
||||
transport = DailyTransport(
|
||||
room_url,
|
||||
token,
|
||||
"Respond bot",
|
||||
DailyParams(
|
||||
audio_out_enabled=True,
|
||||
transcription_enabled=True,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
)
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
|
||||
llm.register_function("get_weather", get_weather)
|
||||
llm.register_function("get_image", get_image)
|
||||
|
||||
tools = [
|
||||
ChatCompletionToolParam(
|
||||
type="function",
|
||||
function={
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"format": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"],
|
||||
"description": "The temperature unit to use. Infer this from the users location.",
|
||||
},
|
||||
},
|
||||
"required": ["location", "format"],
|
||||
},
|
||||
},
|
||||
),
|
||||
ChatCompletionToolParam(
|
||||
type="function",
|
||||
function={
|
||||
"name": "get_image",
|
||||
"description": "Get an image from the video stream.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"question": {
|
||||
"type": "string",
|
||||
"description": "The question to ask the AI to generate an image of",
|
||||
},
|
||||
},
|
||||
"required": ["question"],
|
||||
},
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
system_prompt = """\
|
||||
You are a helpful assistant who converses with a user and answers questions. Respond concisely to general questions.
|
||||
|
||||
Your response will be turned into speech so use only simple words and punctuation.
|
||||
|
||||
You have access to two tools: get_weather and get_image.
|
||||
|
||||
You can respond to questions about the weather using the get_weather tool.
|
||||
|
||||
You can answer questions about the user's video stream using the get_image tool. Some examples of phrases that \
|
||||
indicate you should use the get_image tool are:
|
||||
- What do you see?
|
||||
- What's in the video?
|
||||
- Can you describe the video?
|
||||
- Tell me about what you see.
|
||||
- Tell me something interesting about what you see.
|
||||
- What's happening in the video?
|
||||
"""
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages, tools)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
context_aggregator.user(),
|
||||
llm,
|
||||
tts,
|
||||
transport.output(),
|
||||
context_aggregator.assistant(),
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(pipeline)
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
global video_participant_id
|
||||
video_participant_id = participant["id"]
|
||||
transport.capture_participant_transcription(participant["id"])
|
||||
transport.capture_participant_video(video_participant_id, framerate=0)
|
||||
# Kick off the conversation.
|
||||
await tts.say("Hi! Ask me about the weather in San Francisco.")
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -5,10 +5,14 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.frames.frames import LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
@@ -26,12 +30,6 @@ from pipecat.transports.services.daily import (
|
||||
)
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
|
||||
@@ -1,137 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
|
||||
from pipecat.frames.frames import LLMMessagesFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.together import TogetherLLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
from pipecat.vad.silero import SileroVADAnalyzer
|
||||
|
||||
from runner import configure
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def get_current_weather(
|
||||
function_name, tool_call_id, arguments, llm, context, result_callback
|
||||
):
|
||||
logger.debug("IN get_current_weather")
|
||||
location = arguments["location"]
|
||||
await result_callback(f"The weather in {location} is currently 72 degrees and sunny.")
|
||||
|
||||
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, token) = await configure(session)
|
||||
|
||||
transport = DailyTransport(
|
||||
room_url,
|
||||
token,
|
||||
"Respond bot",
|
||||
DailyParams(
|
||||
audio_out_enabled=True,
|
||||
transcription_enabled=True,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
),
|
||||
)
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
|
||||
)
|
||||
|
||||
llm = TogetherLLMService(
|
||||
api_key=os.getenv("TOGETHER_API_KEY"),
|
||||
model=os.getenv("TOGETHER_MODEL"),
|
||||
)
|
||||
llm.register_function("get_current_weather", get_current_weather)
|
||||
|
||||
weatherTool = {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
}
|
||||
|
||||
system_prompt = f"""\
|
||||
You have access to the following functions:
|
||||
|
||||
Use the function '{weatherTool["name"]}' to '{weatherTool["description"]}':
|
||||
{json.dumps(weatherTool)}
|
||||
|
||||
If you choose to call a function ONLY reply in the following format with no prefix or suffix:
|
||||
|
||||
<function=example_function_name>{{\"example_name\": \"example_value\"}}</function>
|
||||
|
||||
Reminder:
|
||||
- Function calls MUST follow the specified format, start with <function= and end with </function>
|
||||
- Required parameters MUST be specified
|
||||
- Only call one function at a time
|
||||
- Put the entire function call reply on one line
|
||||
- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls
|
||||
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": "Wait for the user to say something."},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
context_aggregator.user(), # User speech to text
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses and tool context
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True, enable_metrics=True))
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
transport.capture_participant_transcription(participant["id"])
|
||||
# Kick off the conversation.
|
||||
await task.queue_frames([LLMMessagesFrame(messages)])
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -1,4 +1,4 @@
|
||||
DAILY_SAMPLE_ROOM_URL=https://yourdomain.daily.co/yourroom # (for joining the bot to the same room repeatedly for local dev)
|
||||
DAILY_API_KEY=7df...
|
||||
OPENAI_API_KEY=sk-PL...
|
||||
ELEVENLABS_API_KEY=aeb...
|
||||
CARTESIA_API_KEY=your_cartesia_api_key_here
|
||||
|
||||
2497
examples/storytelling-chatbot/frontend/package-lock.json
generated
2497
examples/storytelling-chatbot/frontend/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -11,28 +11,28 @@
|
||||
"dependencies": {
|
||||
"@daily-co/daily-js": "^0.62.0",
|
||||
"@daily-co/daily-react": "^0.18.0",
|
||||
"@radix-ui/react-select": "^2.0.0",
|
||||
"@radix-ui/react-select": "^2.1.2",
|
||||
"@radix-ui/react-slot": "^1.0.2",
|
||||
"@tabler/icons-react": "^3.1.0",
|
||||
"@tabler/icons-react": "^3.19.0",
|
||||
"class-variance-authority": "^0.7.0",
|
||||
"clsx": "^2.1.0",
|
||||
"framer-motion": "^11.0.27",
|
||||
"next": "14.1.4",
|
||||
"react": "^18",
|
||||
"react-dom": "^18",
|
||||
"clsx": "^2.1.1",
|
||||
"framer-motion": "^11.9.0",
|
||||
"next": "^14.2.14",
|
||||
"react": "^18.3.1",
|
||||
"react-dom": "^18.3.1",
|
||||
"recoil": "^0.7.7",
|
||||
"tailwind-merge": "^2.2.2",
|
||||
"tailwind-merge": "^2.5.2",
|
||||
"tailwindcss-animate": "^1.0.7"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^20",
|
||||
"@types/react": "^18",
|
||||
"@types/react-dom": "^18",
|
||||
"autoprefixer": "^10.0.1",
|
||||
"eslint": "^8",
|
||||
"@types/node": "^20.16.10",
|
||||
"@types/react": "^18.3.11",
|
||||
"@types/react-dom": "^18.3.0",
|
||||
"autoprefixer": "^10.4.20",
|
||||
"eslint": "^8.57.1",
|
||||
"eslint-config-next": "14.1.4",
|
||||
"postcss": "^8",
|
||||
"tailwindcss": "^3.4.3",
|
||||
"typescript": "^5"
|
||||
"postcss": "^8.4.47",
|
||||
"tailwindcss": "^3.4.13",
|
||||
"typescript": "^5.6.2"
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -143,7 +143,7 @@ async def main(room_url, token=None):
|
||||
|
||||
@transport.event_handler("on_participant_left")
|
||||
async def on_participant_left(transport, participant, reason):
|
||||
intro_task.queue_frame(EndFrame())
|
||||
await intro_task.queue_frame(EndFrame())
|
||||
await main_task.queue_frame(EndFrame())
|
||||
|
||||
@transport.event_handler("on_call_state_updated")
|
||||
|
||||
@@ -21,6 +21,7 @@ classifiers = [
|
||||
]
|
||||
dependencies = [
|
||||
"aiohttp~=3.10.3",
|
||||
"Markdown~=3.7",
|
||||
"numpy~=1.26.4",
|
||||
"loguru~=0.7.2",
|
||||
"Pillow~=10.4.0",
|
||||
@@ -38,8 +39,8 @@ anthropic = [ "anthropic~=0.34.0" ]
|
||||
aws = [ "boto3~=1.35.27" ]
|
||||
azure = [ "azure-cognitiveservices-speech~=1.40.0" ]
|
||||
cartesia = [ "cartesia~=1.0.13", "websockets~=12.0" ]
|
||||
daily = [ "daily-python~=0.10.1" ]
|
||||
deepgram = [ "deepgram-sdk~=3.5.0" ]
|
||||
daily = [ "daily-python~=0.11.0" ]
|
||||
deepgram = [ "deepgram-sdk~=3.7.3" ]
|
||||
elevenlabs = [ "websockets~=12.0" ]
|
||||
examples = [ "python-dotenv~=1.0.1", "flask~=3.0.3", "flask_cors~=4.0.1" ]
|
||||
fal = [ "fal-client~=0.4.1" ]
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
#
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from pipecat.clocks.base_clock import BaseClock
|
||||
from pipecat.metrics.metrics import MetricsData
|
||||
@@ -269,7 +269,6 @@ class TTSSpeakFrame(DataFrame):
|
||||
@dataclass
|
||||
class TransportMessageFrame(DataFrame):
|
||||
message: Any
|
||||
urgent: bool = False
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}(message: {self.message})"
|
||||
@@ -405,6 +404,14 @@ class BotInterruptionFrame(SystemFrame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransportMessageUrgentFrame(SystemFrame):
|
||||
message: Any
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}(message: {self.message})"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MetricsFrame(SystemFrame):
|
||||
"""Emitted by processor that can compute metrics like latencies."""
|
||||
@@ -527,45 +534,25 @@ class UserImageRequestFrame(ControlFrame):
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMUpdateSettingsFrame(ControlFrame):
|
||||
"""A control frame containing a request to update LLM settings."""
|
||||
class ServiceUpdateSettingsFrame(ControlFrame):
|
||||
"""A control frame containing a request to update service settings."""
|
||||
|
||||
model: Optional[str] = None
|
||||
temperature: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
top_p: Optional[float] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
max_tokens: Optional[int] = None
|
||||
seed: Optional[int] = None
|
||||
extra: dict = field(default_factory=dict)
|
||||
settings: Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TTSUpdateSettingsFrame(ControlFrame):
|
||||
"""A control frame containing a request to update TTS settings."""
|
||||
|
||||
model: Optional[str] = None
|
||||
voice: Optional[str] = None
|
||||
language: Optional[Language] = None
|
||||
speed: Optional[Union[str, float]] = None
|
||||
emotion: Optional[List[str]] = None
|
||||
engine: Optional[str] = None
|
||||
pitch: Optional[str] = None
|
||||
rate: Optional[str] = None
|
||||
volume: Optional[str] = None
|
||||
emphasis: Optional[str] = None
|
||||
style: Optional[str] = None
|
||||
style_degree: Optional[str] = None
|
||||
role: Optional[str] = None
|
||||
class LLMUpdateSettingsFrame(ServiceUpdateSettingsFrame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class STTUpdateSettingsFrame(ControlFrame):
|
||||
"""A control frame containing a request to update STT settings."""
|
||||
class TTSUpdateSettingsFrame(ServiceUpdateSettingsFrame):
|
||||
pass
|
||||
|
||||
model: Optional[str] = None
|
||||
language: Optional[Language] = None
|
||||
|
||||
@dataclass
|
||||
class STTUpdateSettingsFrame(ServiceUpdateSettingsFrame):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -585,6 +572,7 @@ class FunctionCallResultFrame(DataFrame):
|
||||
tool_call_id: str
|
||||
arguments: str
|
||||
result: Any
|
||||
run_llm: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -120,7 +120,7 @@ class ParallelPipeline(BasePipeline):
|
||||
|
||||
# If we get an EndFrame we stop our queue processing tasks and wait on
|
||||
# all the pipelines to finish.
|
||||
if isinstance(frame, CancelFrame) or isinstance(frame, EndFrame):
|
||||
if isinstance(frame, (CancelFrame, EndFrame)):
|
||||
# Use None to indicate when queues should be done processing.
|
||||
await self._up_queue.put(None)
|
||||
await self._down_queue.put(None)
|
||||
|
||||
@@ -6,17 +6,25 @@
|
||||
|
||||
import asyncio
|
||||
|
||||
from dataclasses import dataclass
|
||||
from itertools import chain
|
||||
from typing import List
|
||||
|
||||
from pipecat.frames.frames import ControlFrame, EndFrame, Frame, SystemFrame
|
||||
from pipecat.pipeline.base_pipeline import BasePipeline
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.frames.frames import Frame
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@dataclass
|
||||
class SyncFrame(ControlFrame):
|
||||
"""This frame is used to know when the internal pipelines have finished."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class Source(FrameProcessor):
|
||||
def __init__(self, upstream_queue: asyncio.Queue):
|
||||
super().__init__()
|
||||
@@ -67,13 +75,16 @@ class SyncParallelPipeline(BasePipeline):
|
||||
raise TypeError(f"SyncParallelPipeline argument {processors} is not a list")
|
||||
|
||||
# We add a source at the beginning of the pipeline and a sink at the end.
|
||||
source = Source(self._up_queue)
|
||||
sink = Sink(self._down_queue)
|
||||
up_queue = asyncio.Queue()
|
||||
down_queue = asyncio.Queue()
|
||||
source = Source(up_queue)
|
||||
sink = Sink(down_queue)
|
||||
processors: List[FrameProcessor] = [source] + processors + [sink]
|
||||
|
||||
# Keep track of sources and sinks.
|
||||
self._sources.append(source)
|
||||
self._sinks.append(sink)
|
||||
# Keep track of sources and sinks. We also keep the output queue of
|
||||
# the source and the sinks so we can use it later.
|
||||
self._sources.append({"processor": source, "queue": down_queue})
|
||||
self._sinks.append({"processor": sink, "queue": up_queue})
|
||||
|
||||
# Create pipeline
|
||||
pipeline = Pipeline(processors)
|
||||
@@ -94,17 +105,52 @@ class SyncParallelPipeline(BasePipeline):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# The last processor of each pipeline needs to be synchronous otherwise
|
||||
# this element won't work. Since, we know it should be synchronous we
|
||||
# push a SyncFrame. Since frames are ordered we know this frame will be
|
||||
# pushed after the synchronous processor has pushed its data allowing us
|
||||
# to synchrnonize all the internal pipelines by waiting for the
|
||||
# SyncFrame in all of them.
|
||||
async def wait_for_sync(
|
||||
obj, main_queue: asyncio.Queue, frame: Frame, direction: FrameDirection
|
||||
):
|
||||
processor = obj["processor"]
|
||||
queue = obj["queue"]
|
||||
|
||||
await processor.process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, (SystemFrame, EndFrame)):
|
||||
new_frame = await queue.get()
|
||||
if isinstance(new_frame, (SystemFrame, EndFrame)):
|
||||
await main_queue.put(new_frame)
|
||||
else:
|
||||
while not isinstance(new_frame, (SystemFrame, EndFrame)):
|
||||
await main_queue.put(new_frame)
|
||||
queue.task_done()
|
||||
new_frame = await queue.get()
|
||||
else:
|
||||
await processor.process_frame(SyncFrame(), direction)
|
||||
new_frame = await queue.get()
|
||||
while not isinstance(new_frame, SyncFrame):
|
||||
await main_queue.put(new_frame)
|
||||
queue.task_done()
|
||||
new_frame = await queue.get()
|
||||
|
||||
if direction == FrameDirection.UPSTREAM:
|
||||
# If we get an upstream frame we process it in each sink.
|
||||
await asyncio.gather(*[s.process_frame(frame, direction) for s in self._sinks])
|
||||
await asyncio.gather(
|
||||
*[wait_for_sync(s, self._up_queue, frame, direction) for s in self._sinks]
|
||||
)
|
||||
elif direction == FrameDirection.DOWNSTREAM:
|
||||
# If we get a downstream frame we process it in each source.
|
||||
await asyncio.gather(*[s.process_frame(frame, direction) for s in self._sources])
|
||||
await asyncio.gather(
|
||||
*[wait_for_sync(s, self._down_queue, frame, direction) for s in self._sources]
|
||||
)
|
||||
|
||||
seen_ids = set()
|
||||
while not self._up_queue.empty():
|
||||
frame = await self._up_queue.get()
|
||||
if frame and frame.id not in seen_ids:
|
||||
if frame.id not in seen_ids:
|
||||
await self.push_frame(frame, FrameDirection.UPSTREAM)
|
||||
seen_ids.add(frame.id)
|
||||
self._up_queue.task_done()
|
||||
@@ -112,7 +158,7 @@ class SyncParallelPipeline(BasePipeline):
|
||||
seen_ids = set()
|
||||
while not self._down_queue.empty():
|
||||
frame = await self._down_queue.get()
|
||||
if frame and frame.id not in seen_ids:
|
||||
if frame.id not in seen_ids:
|
||||
await self.push_frame(frame, FrameDirection.DOWNSTREAM)
|
||||
seen_ids.add(frame.id)
|
||||
self._down_queue.task_done()
|
||||
|
||||
@@ -69,6 +69,19 @@ class Source(FrameProcessor):
|
||||
await self._up_queue.put(StopTaskFrame())
|
||||
|
||||
|
||||
class Sink(FrameProcessor):
|
||||
def __init__(self, down_queue: asyncio.Queue):
|
||||
super().__init__()
|
||||
self._down_queue = down_queue
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# We really just want to know when the EndFrame reached the sink.
|
||||
if isinstance(frame, EndFrame):
|
||||
await self._down_queue.put(frame)
|
||||
|
||||
|
||||
class PipelineTask:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -84,12 +97,16 @@ class PipelineTask:
|
||||
self._params = params
|
||||
self._finished = False
|
||||
|
||||
self._down_queue = asyncio.Queue()
|
||||
self._up_queue = asyncio.Queue()
|
||||
self._down_queue = asyncio.Queue()
|
||||
self._push_queue = asyncio.Queue()
|
||||
|
||||
self._source = Source(self._up_queue)
|
||||
self._source.link(pipeline)
|
||||
|
||||
self._sink = Sink(self._down_queue)
|
||||
pipeline.link(self._sink)
|
||||
|
||||
def has_finished(self):
|
||||
return self._finished
|
||||
|
||||
@@ -103,19 +120,19 @@ class PipelineTask:
|
||||
# out-of-band from the main streaming task which is what we want since
|
||||
# we want to cancel right away.
|
||||
await self._source.push_frame(CancelFrame())
|
||||
self._process_down_task.cancel()
|
||||
self._process_push_task.cancel()
|
||||
self._process_up_task.cancel()
|
||||
await self._process_down_task
|
||||
await self._process_push_task
|
||||
await self._process_up_task
|
||||
|
||||
async def run(self):
|
||||
self._process_up_task = asyncio.create_task(self._process_up_queue())
|
||||
self._process_down_task = asyncio.create_task(self._process_down_queue())
|
||||
await asyncio.gather(self._process_up_task, self._process_down_task)
|
||||
self._process_push_task = asyncio.create_task(self._process_push_queue())
|
||||
await asyncio.gather(self._process_up_task, self._process_push_task)
|
||||
self._finished = True
|
||||
|
||||
async def queue_frame(self, frame: Frame):
|
||||
await self._down_queue.put(frame)
|
||||
await self._push_queue.put(frame)
|
||||
|
||||
async def queue_frames(self, frames: Iterable[Frame] | AsyncIterable[Frame]):
|
||||
if isinstance(frames, AsyncIterable):
|
||||
@@ -133,7 +150,7 @@ class PipelineTask:
|
||||
data.append(ProcessingMetricsData(processor=p.name, value=0.0))
|
||||
return MetricsFrame(data=data)
|
||||
|
||||
async def _process_down_queue(self):
|
||||
async def _process_push_queue(self):
|
||||
self._clock.start()
|
||||
|
||||
start_frame = StartFrame(
|
||||
@@ -154,11 +171,13 @@ class PipelineTask:
|
||||
should_cleanup = True
|
||||
while running:
|
||||
try:
|
||||
frame = await self._down_queue.get()
|
||||
frame = await self._push_queue.get()
|
||||
await self._source.process_frame(frame, FrameDirection.DOWNSTREAM)
|
||||
running = not (isinstance(frame, StopTaskFrame) or isinstance(frame, EndFrame))
|
||||
if isinstance(frame, EndFrame):
|
||||
await self._wait_for_endframe()
|
||||
running = not isinstance(frame, (StopTaskFrame, EndFrame))
|
||||
should_cleanup = not isinstance(frame, StopTaskFrame)
|
||||
self._down_queue.task_done()
|
||||
self._push_queue.task_done()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
# Cleanup only if we need to.
|
||||
@@ -169,6 +188,12 @@ class PipelineTask:
|
||||
self._process_up_task.cancel()
|
||||
await self._process_up_task
|
||||
|
||||
async def _wait_for_endframe(self):
|
||||
# NOTE(aleix): the Sink element just pushes EndFrames to the down queue,
|
||||
# so just wait for it. In the future we might do something else here,
|
||||
# but for now this is fine.
|
||||
await self._down_queue.get()
|
||||
|
||||
async def _process_up_queue(self):
|
||||
while True:
|
||||
try:
|
||||
|
||||
@@ -6,12 +6,6 @@
|
||||
|
||||
from typing import List, Type
|
||||
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContextFrame,
|
||||
OpenAILLMContext,
|
||||
)
|
||||
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
@@ -22,11 +16,16 @@ from pipecat.frames.frames import (
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMSetToolsFrame,
|
||||
StartInterruptionFrame,
|
||||
TranscriptionFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class LLMResponseAggregator(FrameProcessor):
|
||||
@@ -40,6 +39,7 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
accumulator_frame: Type[TextFrame],
|
||||
interim_accumulator_frame: Type[TextFrame] | None = None,
|
||||
handle_interruptions: bool = False,
|
||||
expect_stripped_words: bool = True, # if True, need to add spaces between words
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -50,6 +50,7 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
self._accumulator_frame = accumulator_frame
|
||||
self._interim_accumulator_frame = interim_accumulator_frame
|
||||
self._handle_interruptions = handle_interruptions
|
||||
self._expect_stripped_words = expect_stripped_words
|
||||
|
||||
# Reset our accumulator state.
|
||||
self._reset()
|
||||
@@ -111,7 +112,10 @@ class LLMResponseAggregator(FrameProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, self._accumulator_frame):
|
||||
if self._aggregating:
|
||||
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
|
||||
if self._expect_stripped_words:
|
||||
self._aggregation += f" {frame.text}" if self._aggregation else frame.text
|
||||
else:
|
||||
self._aggregation += frame.text
|
||||
# We have recevied a complete sentence, so if we have seen the
|
||||
# end frame and we were still aggregating, it means we should
|
||||
# send the aggregation.
|
||||
@@ -290,7 +294,7 @@ class LLMContextAggregator(LLMResponseAggregator):
|
||||
|
||||
|
||||
class LLMAssistantContextAggregator(LLMContextAggregator):
|
||||
def __init__(self, context: OpenAILLMContext):
|
||||
def __init__(self, context: OpenAILLMContext, *, expect_stripped_words: bool = True):
|
||||
super().__init__(
|
||||
messages=[],
|
||||
context=context,
|
||||
@@ -299,6 +303,7 @@ class LLMAssistantContextAggregator(LLMContextAggregator):
|
||||
end_frame=LLMFullResponseEndFrame,
|
||||
accumulator_frame=TextFrame,
|
||||
handle_interruptions=True,
|
||||
expect_stripped_words=expect_stripped_words,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import base64
|
||||
import copy
|
||||
import io
|
||||
import json
|
||||
|
||||
@@ -60,6 +62,7 @@ class OpenAILLMContext:
|
||||
self._messages: List[ChatCompletionMessageParam] = messages if messages else []
|
||||
self._tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = tool_choice
|
||||
self._tools: List[ChatCompletionToolParam] | NotGiven = tools
|
||||
self._user_image_request_context = {}
|
||||
|
||||
@staticmethod
|
||||
def from_messages(messages: List[dict]) -> "OpenAILLMContext":
|
||||
@@ -114,6 +117,21 @@ class OpenAILLMContext:
|
||||
def get_messages_json(self) -> str:
|
||||
return json.dumps(self._messages, cls=CustomEncoder)
|
||||
|
||||
def get_messages_for_logging(self) -> str:
|
||||
msgs = []
|
||||
for message in self.messages:
|
||||
msg = copy.deepcopy(message)
|
||||
if "content" in msg:
|
||||
if isinstance(msg["content"], list):
|
||||
for item in msg["content"]:
|
||||
if item["type"] == "image_url":
|
||||
if item["image_url"]["url"].startswith("data:image/"):
|
||||
item["image_url"]["url"] = "data:image/..."
|
||||
if "mime_type" in msg and msg["mime_type"].startswith("image/"):
|
||||
msg["data"] = "..."
|
||||
msgs.append(msg)
|
||||
return json.dumps(msgs)
|
||||
|
||||
def set_tool_choice(self, tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven):
|
||||
self._tool_choice = tool_choice
|
||||
|
||||
@@ -122,6 +140,21 @@ class OpenAILLMContext:
|
||||
tools = NOT_GIVEN
|
||||
self._tools = tools
|
||||
|
||||
def add_image_frame_message(
|
||||
self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
|
||||
):
|
||||
buffer = io.BytesIO()
|
||||
Image.frombytes(format, size, image).save(buffer, format="JPEG")
|
||||
encoded_image = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
content = [
|
||||
{"type": "text", "text": text},
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}},
|
||||
]
|
||||
if text:
|
||||
content.append({"type": "text", "text": text})
|
||||
self.add_message({"role": "user", "content": content})
|
||||
|
||||
async def call_function(
|
||||
self,
|
||||
f: Callable[
|
||||
@@ -133,6 +166,7 @@ class OpenAILLMContext:
|
||||
tool_call_id: str,
|
||||
arguments: str,
|
||||
llm: FrameProcessor,
|
||||
run_llm: bool = True,
|
||||
) -> None:
|
||||
# Push a SystemFrame downstream. This frame will let our assistant context aggregator
|
||||
# know that we are in the middle of a function call. Some contexts/aggregators may
|
||||
@@ -153,6 +187,7 @@ class OpenAILLMContext:
|
||||
tool_call_id=tool_call_id,
|
||||
arguments=arguments,
|
||||
result=result,
|
||||
run_llm=run_llm,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -37,7 +37,6 @@ class FrameProcessor:
|
||||
*,
|
||||
name: str | None = None,
|
||||
metrics: FrameProcessorMetrics | None = None,
|
||||
sync: bool = True,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
@@ -47,7 +46,6 @@ class FrameProcessor:
|
||||
self._prev: "FrameProcessor" | None = None
|
||||
self._next: "FrameProcessor" | None = None
|
||||
self._loop: asyncio.AbstractEventLoop = loop or asyncio.get_running_loop()
|
||||
self._sync = sync
|
||||
|
||||
self._event_handlers: dict = {}
|
||||
|
||||
@@ -66,11 +64,8 @@ class FrameProcessor:
|
||||
|
||||
# Every processor in Pipecat should only output frames from a single
|
||||
# task. This avoid problems like audio overlapping. System frames are
|
||||
# the exception to this rule.
|
||||
#
|
||||
# This create this task.
|
||||
if not self._sync:
|
||||
self.__create_push_task()
|
||||
# the exception to this rule. This create this task.
|
||||
self.__create_push_task()
|
||||
|
||||
@property
|
||||
def interruptions_allowed(self):
|
||||
@@ -167,7 +162,7 @@ class FrameProcessor:
|
||||
await self.push_frame(error, FrameDirection.UPSTREAM)
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
if self._sync or isinstance(frame, SystemFrame):
|
||||
if isinstance(frame, SystemFrame):
|
||||
await self.__internal_push_frame(frame, direction)
|
||||
else:
|
||||
await self.__push_queue.put((frame, direction))
|
||||
@@ -194,13 +189,12 @@ class FrameProcessor:
|
||||
#
|
||||
|
||||
async def _start_interruption(self):
|
||||
if not self._sync:
|
||||
# Cancel the task. This will stop pushing frames downstream.
|
||||
self.__push_frame_task.cancel()
|
||||
await self.__push_frame_task
|
||||
# Cancel the task. This will stop pushing frames downstream.
|
||||
self.__push_frame_task.cancel()
|
||||
await self.__push_frame_task
|
||||
|
||||
# Create a new queue and task.
|
||||
self.__create_push_task()
|
||||
# Create a new queue and task.
|
||||
self.__create_push_task()
|
||||
|
||||
async def _stop_interruption(self):
|
||||
# Nothing to do right now.
|
||||
|
||||
@@ -6,10 +6,11 @@
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Union
|
||||
from pydantic import BaseModel, Field, PrivateAttr, ValidationError
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Union
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field, PrivateAttr, ValidationError
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
@@ -20,27 +21,28 @@ from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
FunctionCallResultFrame,
|
||||
InterimTranscriptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
OutputAudioRawFrame,
|
||||
StartFrame,
|
||||
SystemFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
TransportMessageFrame,
|
||||
TransportMessageUrgentFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
FunctionCallResultFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
RTVI_PROTOCOL_VERSION = "0.2"
|
||||
|
||||
ActionResult = Union[bool, int, float, str, list, dict]
|
||||
@@ -291,22 +293,12 @@ class RTVIAudioMessageData(BaseModel):
|
||||
num_channels: int
|
||||
|
||||
|
||||
class RTVIBotAudioMessage(BaseModel):
|
||||
class RTVIBotTTSAudioMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["bot-audio"] = "bot-audio"
|
||||
type: Literal["bot-tts-audio"] = "bot-tts-audio"
|
||||
data: RTVIAudioMessageData
|
||||
|
||||
|
||||
class RTVIBotTranscriptionMessageData(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class RTVIBotTranscriptionMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["bot-transcription"] = "bot-transcription"
|
||||
data: RTVIBotTranscriptionMessageData
|
||||
|
||||
|
||||
class RTVIUserTranscriptionMessageData(BaseModel):
|
||||
text: str
|
||||
user_id: str
|
||||
@@ -320,6 +312,12 @@ class RTVIUserTranscriptionMessage(BaseModel):
|
||||
data: RTVIUserTranscriptionMessageData
|
||||
|
||||
|
||||
class RTVIUserLLMTextMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["user-llm-text"] = "user-llm-text"
|
||||
data: RTVITextMessageData
|
||||
|
||||
|
||||
class RTVIUserStartedSpeakingMessage(BaseModel):
|
||||
label: Literal["rtvi-ai"] = "rtvi-ai"
|
||||
type: Literal["user-started-speaking"] = "user-started-speaking"
|
||||
@@ -350,9 +348,11 @@ class RTVIFrameProcessor(FrameProcessor):
|
||||
self._direction = direction
|
||||
|
||||
async def _push_transport_message(self, model: BaseModel, exclude_none: bool = True):
|
||||
frame = TransportMessageFrame(
|
||||
message=model.model_dump(exclude_none=exclude_none), urgent=True
|
||||
)
|
||||
frame = TransportMessageFrame(message=model.model_dump(exclude_none=exclude_none))
|
||||
await self.push_frame(frame, self._direction)
|
||||
|
||||
async def _push_transport_message_urgent(self, model: BaseModel, exclude_none: bool = True):
|
||||
frame = TransportMessageUrgentFrame(message=model.model_dump(exclude_none=exclude_none))
|
||||
await self.push_frame(frame, self._direction)
|
||||
|
||||
|
||||
@@ -378,7 +378,7 @@ class RTVISpeakingProcessor(RTVIFrameProcessor):
|
||||
message = RTVIUserStoppedSpeakingMessage()
|
||||
|
||||
if message:
|
||||
await self._push_transport_message(message)
|
||||
await self._push_transport_message_urgent(message)
|
||||
|
||||
async def _handle_bot_speaking(self, frame: Frame):
|
||||
message = None
|
||||
@@ -388,7 +388,7 @@ class RTVISpeakingProcessor(RTVIFrameProcessor):
|
||||
message = RTVIBotStoppedSpeakingMessage()
|
||||
|
||||
if message:
|
||||
await self._push_transport_message(message)
|
||||
await self._push_transport_message_urgent(message)
|
||||
|
||||
|
||||
class RTVIUserTranscriptionProcessor(RTVIFrameProcessor):
|
||||
@@ -419,7 +419,36 @@ class RTVIUserTranscriptionProcessor(RTVIFrameProcessor):
|
||||
)
|
||||
|
||||
if message:
|
||||
await self._push_transport_message(message)
|
||||
await self._push_transport_message_urgent(message)
|
||||
|
||||
|
||||
class RTVIUserLLMTextProcessor(RTVIFrameProcessor):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
await self._handle_context(frame)
|
||||
|
||||
async def _handle_context(self, frame: OpenAILLMContextFrame):
|
||||
messages = frame.context.messages
|
||||
if len(messages) > 0:
|
||||
message = messages[-1]
|
||||
if message["role"] == "user":
|
||||
content = message["content"]
|
||||
if isinstance(content, list):
|
||||
print("LIST")
|
||||
text = " ".join(item["text"] for item in content if "text" in item)
|
||||
else:
|
||||
print("STRING")
|
||||
text = content
|
||||
|
||||
rtvi_message = RTVIUserLLMTextMessage(data=RTVITextMessageData(text=text))
|
||||
await self._push_transport_message_urgent(rtvi_message)
|
||||
|
||||
|
||||
class RTVIBotLLMProcessor(RTVIFrameProcessor):
|
||||
@@ -432,9 +461,9 @@ class RTVIBotLLMProcessor(RTVIFrameProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, LLMFullResponseStartFrame):
|
||||
await self._push_transport_message(RTVIBotLLMStartedMessage())
|
||||
await self._push_transport_message_urgent(RTVIBotLLMStartedMessage())
|
||||
elif isinstance(frame, LLMFullResponseEndFrame):
|
||||
await self._push_transport_message(RTVIBotLLMStoppedMessage())
|
||||
await self._push_transport_message_urgent(RTVIBotLLMStoppedMessage())
|
||||
|
||||
|
||||
class RTVIBotTTSProcessor(RTVIFrameProcessor):
|
||||
@@ -447,9 +476,9 @@ class RTVIBotTTSProcessor(RTVIFrameProcessor):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TTSStartedFrame):
|
||||
await self._push_transport_message(RTVIBotTTSStartedMessage())
|
||||
await self._push_transport_message_urgent(RTVIBotTTSStartedMessage())
|
||||
elif isinstance(frame, TTSStoppedFrame):
|
||||
await self._push_transport_message(RTVIBotTTSStoppedMessage())
|
||||
await self._push_transport_message_urgent(RTVIBotTTSStoppedMessage())
|
||||
|
||||
|
||||
class RTVIBotLLMTextProcessor(RTVIFrameProcessor):
|
||||
@@ -466,7 +495,7 @@ class RTVIBotLLMTextProcessor(RTVIFrameProcessor):
|
||||
|
||||
async def _handle_text(self, frame: TextFrame):
|
||||
message = RTVIBotLLMTextMessage(data=RTVITextMessageData(text=frame.text))
|
||||
await self._push_transport_message(message)
|
||||
await self._push_transport_message_urgent(message)
|
||||
|
||||
|
||||
class RTVIBotTTSTextProcessor(RTVIFrameProcessor):
|
||||
@@ -483,10 +512,10 @@ class RTVIBotTTSTextProcessor(RTVIFrameProcessor):
|
||||
|
||||
async def _handle_text(self, frame: TextFrame):
|
||||
message = RTVIBotTTSTextMessage(data=RTVITextMessageData(text=frame.text))
|
||||
await self._push_transport_message(message)
|
||||
await self._push_transport_message_urgent(message)
|
||||
|
||||
|
||||
class RTVIBotAudioProcessor(RTVIFrameProcessor):
|
||||
class RTVIBotTTSAudioProcessor(RTVIFrameProcessor):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -500,7 +529,7 @@ class RTVIBotAudioProcessor(RTVIFrameProcessor):
|
||||
|
||||
async def _handle_audio(self, frame: OutputAudioRawFrame):
|
||||
encoded = base64.b64encode(frame.audio).decode("utf-8")
|
||||
message = RTVIBotAudioMessage(
|
||||
message = RTVIBotTTSAudioMessage(
|
||||
data=RTVIAudioMessageData(
|
||||
audio=encoded, sample_rate=frame.sample_rate, num_channels=frame.num_channels
|
||||
)
|
||||
@@ -516,7 +545,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
params: RTVIProcessorParams = RTVIProcessorParams(),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sync=False, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
self._config = config
|
||||
self._params = params
|
||||
|
||||
@@ -647,9 +676,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
self._message_task = None
|
||||
|
||||
async def _push_transport_message(self, model: BaseModel, exclude_none: bool = True):
|
||||
frame = TransportMessageFrame(
|
||||
message=model.model_dump(exclude_none=exclude_none), urgent=True
|
||||
)
|
||||
frame = TransportMessageUrgentFrame(message=model.model_dump(exclude_none=exclude_none))
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def _action_task_handler(self):
|
||||
|
||||
@@ -44,7 +44,7 @@ class GStreamerPipelineSource(FrameProcessor):
|
||||
clock_sync: bool = True
|
||||
|
||||
def __init__(self, *, pipeline: str, out_params: OutputParams = OutputParams(), **kwargs):
|
||||
super().__init__(sync=False, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._out_params = out_params
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ class IdleFrameProcessor(FrameProcessor):
|
||||
types: List[type] = [],
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sync=False, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._callback = callback
|
||||
self._timeout = timeout
|
||||
|
||||
@@ -31,7 +31,7 @@ class UserIdleProcessor(FrameProcessor):
|
||||
timeout: float,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sync=False, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._callback = callback
|
||||
self._timeout = timeout
|
||||
|
||||
@@ -8,7 +8,7 @@ import asyncio
|
||||
import io
|
||||
import wave
|
||||
from abc import abstractmethod
|
||||
from typing import AsyncGenerator, List, Optional, Tuple, Union
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -37,6 +37,7 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.audio import calculate_audio_volume
|
||||
from pipecat.utils.string import match_endofsentence
|
||||
from pipecat.utils.text.base_text_filter import BaseTextFilter
|
||||
from pipecat.utils.time import seconds_to_nanoseconds
|
||||
from pipecat.utils.utils import exp_smoothing
|
||||
|
||||
@@ -45,6 +46,7 @@ class AIService(FrameProcessor):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._model_name: str = ""
|
||||
self._settings: Dict[str, Any] = {}
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
@@ -63,6 +65,16 @@ class AIService(FrameProcessor):
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
pass
|
||||
|
||||
async def _update_settings(self, settings: Dict[str, Any]):
|
||||
for key, value in settings.items():
|
||||
if key in self._settings:
|
||||
logger.debug(f"Updating setting {key} to: [{value}] for {self.name}")
|
||||
self._settings[key] = value
|
||||
elif key == "model":
|
||||
self.set_model_name(value)
|
||||
else:
|
||||
logger.warning(f"Unknown setting for {self.name} service: {key}")
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
@@ -110,7 +122,13 @@ class LLMService(AIService):
|
||||
return function_name in self._callbacks.keys()
|
||||
|
||||
async def call_function(
|
||||
self, *, context: OpenAILLMContext, tool_call_id: str, function_name: str, arguments: str
|
||||
self,
|
||||
*,
|
||||
context: OpenAILLMContext,
|
||||
tool_call_id: str,
|
||||
function_name: str,
|
||||
arguments: str,
|
||||
run_llm: bool = True,
|
||||
) -> None:
|
||||
f = None
|
||||
if function_name in self._callbacks.keys():
|
||||
@@ -120,7 +138,12 @@ class LLMService(AIService):
|
||||
else:
|
||||
return None
|
||||
await context.call_function(
|
||||
f, function_name=function_name, tool_call_id=tool_call_id, arguments=arguments, llm=self
|
||||
f,
|
||||
function_name=function_name,
|
||||
tool_call_id=tool_call_id,
|
||||
arguments=arguments,
|
||||
llm=self,
|
||||
run_llm=run_llm,
|
||||
)
|
||||
|
||||
# QUESTION FOR CB: maybe this isn't needed anymore?
|
||||
@@ -144,15 +167,29 @@ class TTSService(AIService):
|
||||
# if True, TTSService will push TextFrames and LLMFullResponseEndFrames,
|
||||
# otherwise subclass must do it
|
||||
push_text_frames: bool = True,
|
||||
# if True, TTSService will push TTSStoppedFrames, otherwise subclass must do it
|
||||
push_stop_frames: bool = False,
|
||||
# if push_stop_frames is True, wait for this idle period before pushing TTSStoppedFrame
|
||||
stop_frame_timeout_s: float = 1.0,
|
||||
# TTS output sample rate
|
||||
sample_rate: int = 16000,
|
||||
text_filter: Optional[BaseTextFilter] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._aggregate_sentences: bool = aggregate_sentences
|
||||
self._push_text_frames: bool = push_text_frames
|
||||
self._current_sentence: str = ""
|
||||
self._push_stop_frames: bool = push_stop_frames
|
||||
self._stop_frame_timeout_s: float = stop_frame_timeout_s
|
||||
self._sample_rate: int = sample_rate
|
||||
self._voice_id: str = ""
|
||||
self._settings: Dict[str, Any] = {}
|
||||
self._text_filter: Optional[BaseTextFilter] = text_filter
|
||||
|
||||
self._stop_frame_task: Optional[asyncio.Task] = None
|
||||
self._stop_frame_queue: asyncio.Queue = asyncio.Queue()
|
||||
|
||||
self._current_sentence: str = ""
|
||||
|
||||
@property
|
||||
def sample_rate(self) -> int:
|
||||
@@ -163,165 +200,20 @@ class TTSService(AIService):
|
||||
self.set_model_name(model)
|
||||
|
||||
@abstractmethod
|
||||
async def set_voice(self, voice: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_language(self, language: Language):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_speed(self, speed: Union[str, float]):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_emotion(self, emotion: List[str]):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_engine(self, engine: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_pitch(self, pitch: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_rate(self, rate: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_volume(self, volume: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_emphasis(self, emphasis: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_style(self, style: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_style_degree(self, style_degree: str):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_role(self, role: str):
|
||||
pass
|
||||
|
||||
# Converts the text to audio.
|
||||
@abstractmethod
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
pass
|
||||
|
||||
async def say(self, text: str):
|
||||
await self.process_frame(TextFrame(text=text), FrameDirection.DOWNSTREAM)
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
self._current_sentence = ""
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _process_text_frame(self, frame: TextFrame):
|
||||
text: str | None = None
|
||||
if not self._aggregate_sentences:
|
||||
text = frame.text
|
||||
else:
|
||||
self._current_sentence += frame.text
|
||||
if match_endofsentence(self._current_sentence):
|
||||
text = self._current_sentence
|
||||
self._current_sentence = ""
|
||||
|
||||
if text:
|
||||
await self._push_tts_frames(text)
|
||||
|
||||
async def _push_tts_frames(self, text: str):
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return
|
||||
|
||||
await self.start_processing_metrics()
|
||||
await self.process_generator(self.run_tts(text))
|
||||
await self.stop_processing_metrics()
|
||||
if self._push_text_frames:
|
||||
# We send the original text after the audio. This way, if we are
|
||||
# interrupted, the text is not added to the assistant context.
|
||||
await self.push_frame(TextFrame(text))
|
||||
|
||||
async def _update_tts_settings(self, frame: TTSUpdateSettingsFrame):
|
||||
if frame.model is not None:
|
||||
await self.set_model(frame.model)
|
||||
if frame.voice is not None:
|
||||
await self.set_voice(frame.voice)
|
||||
if frame.language is not None:
|
||||
await self.set_language(frame.language)
|
||||
if frame.speed is not None:
|
||||
await self.set_speed(frame.speed)
|
||||
if frame.emotion is not None:
|
||||
await self.set_emotion(frame.emotion)
|
||||
if frame.engine is not None:
|
||||
await self.set_engine(frame.engine)
|
||||
if frame.pitch is not None:
|
||||
await self.set_pitch(frame.pitch)
|
||||
if frame.rate is not None:
|
||||
await self.set_rate(frame.rate)
|
||||
if frame.volume is not None:
|
||||
await self.set_volume(frame.volume)
|
||||
if frame.emphasis is not None:
|
||||
await self.set_emphasis(frame.emphasis)
|
||||
if frame.style is not None:
|
||||
await self.set_style(frame.style)
|
||||
if frame.style_degree is not None:
|
||||
await self.set_style_degree(frame.style_degree)
|
||||
if frame.role is not None:
|
||||
await self.set_role(frame.role)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TextFrame):
|
||||
await self._process_text_frame(frame)
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
await self._handle_interruption(frame, direction)
|
||||
elif isinstance(frame, LLMFullResponseEndFrame) or isinstance(frame, EndFrame):
|
||||
sentence = self._current_sentence
|
||||
self._current_sentence = ""
|
||||
await self._push_tts_frames(sentence)
|
||||
if isinstance(frame, LLMFullResponseEndFrame):
|
||||
if self._push_text_frames:
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, TTSSpeakFrame):
|
||||
await self._push_tts_frames(frame.text)
|
||||
elif isinstance(frame, TTSUpdateSettingsFrame):
|
||||
await self._update_tts_settings(frame)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
class AsyncTTSService(TTSService):
|
||||
def __init__(
|
||||
self,
|
||||
# if True, TTSService will push TTSStoppedFrames, otherwise subclass must do it
|
||||
push_stop_frames: bool = False,
|
||||
# if push_stop_frames is True, wait for this idle period before pushing TTSStoppedFrame
|
||||
stop_frame_timeout_s: float = 1.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sync=False, **kwargs)
|
||||
self._push_stop_frames: bool = push_stop_frames
|
||||
self._stop_frame_timeout_s: float = stop_frame_timeout_s
|
||||
self._stop_frame_task: Optional[asyncio.Task] = None
|
||||
self._stop_frame_queue: asyncio.Queue = asyncio.Queue()
|
||||
def set_voice(self, voice: str):
|
||||
self._voice_id = voice
|
||||
|
||||
@abstractmethod
|
||||
async def flush_audio(self):
|
||||
pass
|
||||
|
||||
async def say(self, text: str):
|
||||
await super().say(text)
|
||||
await self.flush_audio()
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
return Language(language)
|
||||
|
||||
# Converts the text to audio.
|
||||
@abstractmethod
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
pass
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
@@ -342,10 +234,52 @@ class AsyncTTSService(TTSService):
|
||||
await self._stop_frame_task
|
||||
self._stop_frame_task = None
|
||||
|
||||
async def _update_settings(self, settings: Dict[str, Any]):
|
||||
for key, value in settings.items():
|
||||
if key in self._settings:
|
||||
logger.debug(f"Updating TTS setting {key} to: [{value}]")
|
||||
self._settings[key] = value
|
||||
if key == "language":
|
||||
self._settings[key] = self.language_to_service_language(value)
|
||||
elif key == "model":
|
||||
self.set_model_name(value)
|
||||
elif key == "voice":
|
||||
self.set_voice(value)
|
||||
elif key == "text_filter" and self._text_filter:
|
||||
self._text_filter.update_settings(value)
|
||||
else:
|
||||
logger.warning(f"Unknown setting for TTS service: {key}")
|
||||
|
||||
async def say(self, text: str):
|
||||
aggregate_sentences = self._aggregate_sentences
|
||||
self._aggregate_sentences = False
|
||||
await self.process_frame(TextFrame(text=text), FrameDirection.DOWNSTREAM)
|
||||
self._aggregate_sentences = aggregate_sentences
|
||||
await self.flush_audio()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
if isinstance(frame, TTSSpeakFrame):
|
||||
|
||||
if isinstance(frame, TextFrame):
|
||||
await self._process_text_frame(frame)
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
await self._handle_interruption(frame, direction)
|
||||
elif isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
|
||||
sentence = self._current_sentence
|
||||
self._current_sentence = ""
|
||||
await self._push_tts_frames(sentence)
|
||||
if isinstance(frame, LLMFullResponseEndFrame):
|
||||
if self._push_text_frames:
|
||||
await self.push_frame(frame, direction)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, TTSSpeakFrame):
|
||||
await self._push_tts_frames(frame.text)
|
||||
await self.flush_audio()
|
||||
elif isinstance(frame, TTSUpdateSettingsFrame):
|
||||
await self._update_settings(frame.settings)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def push_frame(self, frame: Frame, direction: FrameDirection = FrameDirection.DOWNSTREAM):
|
||||
await super().push_frame(frame, direction)
|
||||
@@ -358,6 +292,40 @@ class AsyncTTSService(TTSService):
|
||||
):
|
||||
await self._stop_frame_queue.put(frame)
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
self._current_sentence = ""
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
async def _process_text_frame(self, frame: TextFrame):
|
||||
text: str | None = None
|
||||
if not self._aggregate_sentences:
|
||||
text = frame.text
|
||||
else:
|
||||
self._current_sentence += frame.text
|
||||
eos_end_marker = match_endofsentence(self._current_sentence)
|
||||
if eos_end_marker:
|
||||
text = self._current_sentence[:eos_end_marker]
|
||||
self._current_sentence = self._current_sentence[eos_end_marker:]
|
||||
|
||||
if text:
|
||||
await self._push_tts_frames(text)
|
||||
|
||||
async def _push_tts_frames(self, text: str):
|
||||
# Don't send only whitespace. This causes problems for some TTS models. But also don't
|
||||
# strip all whitespace, as whitespace can influence prosody.
|
||||
if not text.strip():
|
||||
return
|
||||
|
||||
await self.start_processing_metrics()
|
||||
if self._text_filter:
|
||||
text = self._text_filter.filter(text)
|
||||
await self.process_generator(self.run_tts(text))
|
||||
await self.stop_processing_metrics()
|
||||
if self._push_text_frames:
|
||||
# We send the original text after the audio. This way, if we are
|
||||
# interrupted, the text is not added to the assistant context.
|
||||
await self.push_frame(TextFrame(text))
|
||||
|
||||
async def _stop_frame_handler(self):
|
||||
try:
|
||||
has_started = False
|
||||
@@ -378,7 +346,7 @@ class AsyncTTSService(TTSService):
|
||||
pass
|
||||
|
||||
|
||||
class AsyncWordTTSService(AsyncTTSService):
|
||||
class WordTTSService(TTSService):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._initial_word_timestamp = -1
|
||||
@@ -408,7 +376,7 @@ class AsyncWordTTSService(AsyncTTSService):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, LLMFullResponseEndFrame) or isinstance(frame, EndFrame):
|
||||
if isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
|
||||
await self.flush_audio()
|
||||
|
||||
async def _handle_interruption(self, frame: StartInterruptionFrame, direction: FrameDirection):
|
||||
@@ -443,6 +411,7 @@ class STTService(AIService):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._settings: Dict[str, Any] = {}
|
||||
|
||||
@abstractmethod
|
||||
async def set_model(self, model: str):
|
||||
@@ -457,11 +426,18 @@ class STTService(AIService):
|
||||
"""Returns transcript as a string"""
|
||||
pass
|
||||
|
||||
async def _update_stt_settings(self, frame: STTUpdateSettingsFrame):
|
||||
if frame.model is not None:
|
||||
await self.set_model(frame.model)
|
||||
if frame.language is not None:
|
||||
await self.set_language(frame.language)
|
||||
async def _update_settings(self, settings: Dict[str, Any]):
|
||||
logger.debug(f"Updating STT settings: {self._settings}")
|
||||
for key, value in settings.items():
|
||||
if key in self._settings:
|
||||
logger.debug(f"Updating STT setting {key} to: [{value}]")
|
||||
self._settings[key] = value
|
||||
if key == "language":
|
||||
await self.set_language(value)
|
||||
elif key == "model":
|
||||
self.set_model_name(value)
|
||||
else:
|
||||
logger.warning(f"Unknown setting for STT service: {key}")
|
||||
|
||||
async def process_audio_frame(self, frame: AudioRawFrame):
|
||||
await self.process_generator(self.run_stt(frame.audio))
|
||||
@@ -475,7 +451,7 @@ class STTService(AIService):
|
||||
# push a TextFrame. We don't really want to push audio frames down.
|
||||
await self.process_audio_frame(frame)
|
||||
elif isinstance(frame, STTUpdateSettingsFrame):
|
||||
await self._update_stt_settings(frame)
|
||||
await self._update_settings(frame.settings)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
@@ -55,6 +55,7 @@ except ModuleNotFoundError as e:
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
# internal use only -- todo: refactor
|
||||
@dataclass
|
||||
class AnthropicImageMessageFrame(Frame):
|
||||
user_image_raw_frame: UserImageRawFrame
|
||||
@@ -95,12 +96,14 @@ class AnthropicLLMService(LLMService):
|
||||
super().__init__(**kwargs)
|
||||
self._client = AsyncAnthropic(api_key=api_key)
|
||||
self.set_model_name(model)
|
||||
self._max_tokens = params.max_tokens
|
||||
self._enable_prompt_caching_beta: bool = params.enable_prompt_caching_beta or False
|
||||
self._temperature = params.temperature
|
||||
self._top_k = params.top_k
|
||||
self._top_p = params.top_p
|
||||
self._extra = params.extra if isinstance(params.extra, dict) else {}
|
||||
self._settings = {
|
||||
"max_tokens": params.max_tokens,
|
||||
"enable_prompt_caching_beta": params.enable_prompt_caching_beta or False,
|
||||
"temperature": params.temperature,
|
||||
"top_k": params.top_k,
|
||||
"top_p": params.top_p,
|
||||
"extra": params.extra if isinstance(params.extra, dict) else {},
|
||||
}
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
@@ -110,35 +113,15 @@ class AnthropicLLMService(LLMService):
|
||||
return self._enable_prompt_caching_beta
|
||||
|
||||
@staticmethod
|
||||
def create_context_aggregator(context: OpenAILLMContext) -> AnthropicContextAggregatorPair:
|
||||
def create_context_aggregator(
|
||||
context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
|
||||
) -> AnthropicContextAggregatorPair:
|
||||
user = AnthropicUserContextAggregator(context)
|
||||
assistant = AnthropicAssistantContextAggregator(user)
|
||||
assistant = AnthropicAssistantContextAggregator(
|
||||
user, expect_stripped_words=assistant_expect_stripped_words
|
||||
)
|
||||
return AnthropicContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
async def set_enable_prompt_caching_beta(self, enable_prompt_caching_beta: bool):
|
||||
logger.debug(f"Switching LLM enable_prompt_caching_beta to: [{enable_prompt_caching_beta}]")
|
||||
self._enable_prompt_caching_beta = enable_prompt_caching_beta
|
||||
|
||||
async def set_max_tokens(self, max_tokens: int):
|
||||
logger.debug(f"Switching LLM max_tokens to: [{max_tokens}]")
|
||||
self._max_tokens = max_tokens
|
||||
|
||||
async def set_temperature(self, temperature: float):
|
||||
logger.debug(f"Switching LLM temperature to: [{temperature}]")
|
||||
self._temperature = temperature
|
||||
|
||||
async def set_top_k(self, top_k: float):
|
||||
logger.debug(f"Switching LLM top_k to: [{top_k}]")
|
||||
self._top_k = top_k
|
||||
|
||||
async def set_top_p(self, top_p: float):
|
||||
logger.debug(f"Switching LLM top_p to: [{top_p}]")
|
||||
self._top_p = top_p
|
||||
|
||||
async def set_extra(self, extra: Dict[str, Any]):
|
||||
logger.debug(f"Switching LLM extra to: [{extra}]")
|
||||
self._extra = extra
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext):
|
||||
# Usage tracking. We track the usage reported by Anthropic in prompt_tokens and
|
||||
# completion_tokens. We also estimate the completion tokens from output text
|
||||
@@ -160,11 +143,11 @@ class AnthropicLLMService(LLMService):
|
||||
)
|
||||
|
||||
messages = context.messages
|
||||
if self._enable_prompt_caching_beta:
|
||||
if self._settings["enable_prompt_caching_beta"]:
|
||||
messages = context.get_messages_with_cache_control_markers()
|
||||
|
||||
api_call = self._client.messages.create
|
||||
if self._enable_prompt_caching_beta:
|
||||
if self._settings["enable_prompt_caching_beta"]:
|
||||
api_call = self._client.beta.prompt_caching.messages.create
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
@@ -174,14 +157,14 @@ class AnthropicLLMService(LLMService):
|
||||
"system": context.system,
|
||||
"messages": messages,
|
||||
"model": self.model_name,
|
||||
"max_tokens": self._max_tokens,
|
||||
"max_tokens": self._settings["max_tokens"],
|
||||
"stream": True,
|
||||
"temperature": self._temperature,
|
||||
"top_k": self._top_k,
|
||||
"top_p": self._top_p,
|
||||
"temperature": self._settings["temperature"],
|
||||
"top_k": self._settings["top_k"],
|
||||
"top_p": self._settings["top_p"],
|
||||
}
|
||||
|
||||
params.update(self._extra)
|
||||
params.update(self._settings["extra"])
|
||||
|
||||
response = await api_call(**params)
|
||||
|
||||
@@ -279,21 +262,6 @@ class AnthropicLLMService(LLMService):
|
||||
cache_read_input_tokens=cache_read_input_tokens,
|
||||
)
|
||||
|
||||
async def _update_settings(self, frame: LLMUpdateSettingsFrame):
|
||||
if frame.model is not None:
|
||||
logger.debug(f"Switching LLM model to: [{frame.model}]")
|
||||
self.set_model_name(frame.model)
|
||||
if frame.max_tokens is not None:
|
||||
await self.set_max_tokens(frame.max_tokens)
|
||||
if frame.temperature is not None:
|
||||
await self.set_temperature(frame.temperature)
|
||||
if frame.top_k is not None:
|
||||
await self.set_top_k(frame.top_k)
|
||||
if frame.top_p is not None:
|
||||
await self.set_top_p(frame.top_p)
|
||||
if frame.extra:
|
||||
await self.set_extra(frame.extra)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
@@ -309,10 +277,10 @@ class AnthropicLLMService(LLMService):
|
||||
# to the context.
|
||||
context = AnthropicLLMContext.from_image_frame(frame)
|
||||
elif isinstance(frame, LLMUpdateSettingsFrame):
|
||||
await self._update_settings(frame)
|
||||
await self._update_settings(frame.settings)
|
||||
elif isinstance(frame, LLMEnablePromptCachingFrame):
|
||||
logger.debug(f"Setting enable prompt caching to: [{frame.enable}]")
|
||||
self._enable_prompt_caching_beta = frame.enable
|
||||
self._settings["enable_prompt_caching_beta"] = frame.enable
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -355,7 +323,6 @@ class AnthropicLLMContext(OpenAILLMContext):
|
||||
system: str | NotGiven = NOT_GIVEN,
|
||||
):
|
||||
super().__init__(messages=messages, tools=tools, tool_choice=tool_choice)
|
||||
self._user_image_request_context = {}
|
||||
|
||||
# For beta prompt caching. This is a counter that tracks the number of turns
|
||||
# we've seen above the cache threshold. We reset this when we reset the
|
||||
@@ -541,8 +508,8 @@ class AnthropicUserContextAggregator(LLMUserContextAggregator):
|
||||
|
||||
|
||||
class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
def __init__(self, user_context_aggregator: AnthropicUserContextAggregator):
|
||||
super().__init__(context=user_context_aggregator._context)
|
||||
def __init__(self, user_context_aggregator: AnthropicUserContextAggregator, **kwargs):
|
||||
super().__init__(context=user_context_aggregator._context, **kwargs)
|
||||
self._user_context_aggregator = user_context_aggregator
|
||||
self._function_call_in_progress = None
|
||||
self._function_call_result = None
|
||||
@@ -579,7 +546,7 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
run_llm = False
|
||||
|
||||
aggregation = self._aggregation
|
||||
self._aggregation = ""
|
||||
self._reset()
|
||||
|
||||
try:
|
||||
if self._function_call_result:
|
||||
@@ -630,5 +597,8 @@ class AnthropicAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
if run_llm:
|
||||
await self._user_context_aggregator.push_context_frame()
|
||||
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing frame: {e}")
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
@@ -16,8 +17,7 @@ from pipecat.frames.frames import (
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import TTSService
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
try:
|
||||
import boto3
|
||||
@@ -33,7 +33,7 @@ except ModuleNotFoundError as e:
|
||||
class AWSTTSService(TTSService):
|
||||
class InputParams(BaseModel):
|
||||
engine: Optional[str] = None
|
||||
language: Optional[str] = None
|
||||
language: Optional[Language] = Language.EN
|
||||
pitch: Optional[str] = None
|
||||
rate: Optional[str] = None
|
||||
volume: Optional[str] = None
|
||||
@@ -57,28 +57,95 @@ class AWSTTSService(TTSService):
|
||||
aws_secret_access_key=api_key,
|
||||
region_name=region,
|
||||
)
|
||||
self._voice_id = voice_id
|
||||
self._sample_rate = sample_rate
|
||||
self._params = params
|
||||
self._settings = {
|
||||
"sample_rate": sample_rate,
|
||||
"engine": params.engine,
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else Language.EN,
|
||||
"pitch": params.pitch,
|
||||
"rate": params.rate,
|
||||
"volume": params.volume,
|
||||
}
|
||||
|
||||
self.set_voice(voice_id)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
match language:
|
||||
case Language.CA:
|
||||
return "ca-ES"
|
||||
case Language.ZH:
|
||||
return "cmn-CN"
|
||||
case Language.DA:
|
||||
return "da-DK"
|
||||
case Language.NL:
|
||||
return "nl-NL"
|
||||
case Language.NL_BE:
|
||||
return "nl-BE"
|
||||
case Language.EN | Language.EN_US:
|
||||
return "en-US"
|
||||
case Language.EN_AU:
|
||||
return "en-AU"
|
||||
case Language.EN_GB:
|
||||
return "en-GB"
|
||||
case Language.EN_NZ:
|
||||
return "en-NZ"
|
||||
case Language.EN_IN:
|
||||
return "en-IN"
|
||||
case Language.FI:
|
||||
return "fi-FI"
|
||||
case Language.FR:
|
||||
return "fr-FR"
|
||||
case Language.FR_CA:
|
||||
return "fr-CA"
|
||||
case Language.DE:
|
||||
return "de-DE"
|
||||
case Language.HI:
|
||||
return "hi-IN"
|
||||
case Language.IT:
|
||||
return "it-IT"
|
||||
case Language.JA:
|
||||
return "ja-JP"
|
||||
case Language.KO:
|
||||
return "ko-KR"
|
||||
case Language.NO:
|
||||
return "nb-NO"
|
||||
case Language.PL:
|
||||
return "pl-PL"
|
||||
case Language.PT:
|
||||
return "pt-PT"
|
||||
case Language.PT_BR:
|
||||
return "pt-BR"
|
||||
case Language.RO:
|
||||
return "ro-RO"
|
||||
case Language.RU:
|
||||
return "ru-RU"
|
||||
case Language.ES:
|
||||
return "es-ES"
|
||||
case Language.SV:
|
||||
return "sv-SE"
|
||||
case Language.TR:
|
||||
return "tr-TR"
|
||||
return None
|
||||
|
||||
def _construct_ssml(self, text: str) -> str:
|
||||
ssml = "<speak>"
|
||||
|
||||
if self._params.language:
|
||||
ssml += f"<lang xml:lang='{self._params.language}'>"
|
||||
language = self._settings["language"]
|
||||
ssml += f"<lang xml:lang='{language}'>"
|
||||
|
||||
prosody_attrs = []
|
||||
# Prosody tags are only supported for standard and neural engines
|
||||
if self._params.engine != "generative":
|
||||
if self._params.rate:
|
||||
prosody_attrs.append(f"rate='{self._params.rate}'")
|
||||
if self._params.pitch:
|
||||
prosody_attrs.append(f"pitch='{self._params.pitch}'")
|
||||
if self._params.volume:
|
||||
prosody_attrs.append(f"volume='{self._params.volume}'")
|
||||
if self._settings["engine"] != "generative":
|
||||
if self._settings["rate"]:
|
||||
prosody_attrs.append(f"rate='{self._settings['rate']}'")
|
||||
if self._settings["pitch"]:
|
||||
prosody_attrs.append(f"pitch='{self._settings['pitch']}'")
|
||||
if self._settings["volume"]:
|
||||
prosody_attrs.append(f"volume='{self._settings['volume']}'")
|
||||
|
||||
if prosody_attrs:
|
||||
ssml += f"<prosody {' '.join(prosody_attrs)}>"
|
||||
@@ -90,41 +157,12 @@ class AWSTTSService(TTSService):
|
||||
if prosody_attrs:
|
||||
ssml += "</prosody>"
|
||||
|
||||
if self._params.language:
|
||||
ssml += "</lang>"
|
||||
ssml += "</lang>"
|
||||
|
||||
ssml += "</speak>"
|
||||
|
||||
return ssml
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice_id = voice
|
||||
|
||||
async def set_engine(self, engine: str):
|
||||
logger.debug(f"Switching TTS engine to: [{engine}]")
|
||||
self._params.engine = engine
|
||||
|
||||
async def set_language(self, language: str):
|
||||
logger.debug(f"Switching TTS language to: [{language}]")
|
||||
self._params.language = language
|
||||
|
||||
async def set_pitch(self, pitch: str):
|
||||
logger.debug(f"Switching TTS pitch to: [{pitch}]")
|
||||
self._params.pitch = pitch
|
||||
|
||||
async def set_rate(self, rate: str):
|
||||
logger.debug(f"Switching TTS rate to: [{rate}]")
|
||||
self._params.rate = rate
|
||||
|
||||
async def set_volume(self, volume: str):
|
||||
logger.debug(f"Switching TTS volume to: [{volume}]")
|
||||
self._params.volume = volume
|
||||
|
||||
async def set_params(self, params: InputParams):
|
||||
logger.debug(f"Switching TTS params to: [{params}]")
|
||||
self._params = params
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
@@ -139,8 +177,8 @@ class AWSTTSService(TTSService):
|
||||
"TextType": "ssml",
|
||||
"OutputFormat": "pcm",
|
||||
"VoiceId": self._voice_id,
|
||||
"Engine": self._params.engine,
|
||||
"SampleRate": str(self._sample_rate),
|
||||
"Engine": self._settings["engine"],
|
||||
"SampleRate": str(self._settings["sample_rate"]),
|
||||
}
|
||||
|
||||
# Filter out None values
|
||||
@@ -150,7 +188,7 @@ class AWSTTSService(TTSService):
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
yield TTSStartedFrame()
|
||||
|
||||
if "AudioStream" in response:
|
||||
with response["AudioStream"] as stream:
|
||||
@@ -160,10 +198,10 @@ class AWSTTSService(TTSService):
|
||||
chunk = audio_data[i : i + chunk_size]
|
||||
if len(chunk) > 0:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(chunk, self._sample_rate, 1)
|
||||
frame = TTSAudioRawFrame(chunk, self._settings["sample_rate"], 1)
|
||||
yield frame
|
||||
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except (BotoCoreError, ClientError) as error:
|
||||
logger.exception(f"{self} error generating TTS: {error}")
|
||||
@@ -171,4 +209,4 @@ class AWSTTSService(TTSService):
|
||||
yield ErrorFrame(error=error_message)
|
||||
|
||||
finally:
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -4,12 +4,13 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import io
|
||||
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
@@ -26,12 +27,9 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.services.ai_services import ImageGenService, STTService, TTSService
|
||||
from pipecat.services.openai import BaseOpenAILLMService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from loguru import logger
|
||||
|
||||
# See .env.example for Azure configuration needed
|
||||
try:
|
||||
from azure.cognitiveservices.speech import (
|
||||
@@ -76,7 +74,7 @@ class AzureLLMService(BaseOpenAILLMService):
|
||||
class AzureTTSService(TTSService):
|
||||
class InputParams(BaseModel):
|
||||
emphasis: Optional[str] = None
|
||||
language: Optional[str] = "en-US"
|
||||
language: Optional[Language] = Language.EN_US
|
||||
pitch: Optional[str] = None
|
||||
rate: Optional[str] = "1.05"
|
||||
role: Optional[str] = None
|
||||
@@ -99,114 +97,158 @@ class AzureTTSService(TTSService):
|
||||
speech_config = SpeechConfig(subscription=api_key, region=region)
|
||||
self._speech_synthesizer = SpeechSynthesizer(speech_config=speech_config, audio_config=None)
|
||||
|
||||
self._voice = voice
|
||||
self._sample_rate = sample_rate
|
||||
self._params = params
|
||||
self._settings = {
|
||||
"sample_rate": sample_rate,
|
||||
"emphasis": params.emphasis,
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else Language.EN_US,
|
||||
"pitch": params.pitch,
|
||||
"rate": params.rate,
|
||||
"role": params.role,
|
||||
"style": params.style,
|
||||
"style_degree": params.style_degree,
|
||||
"volume": params.volume,
|
||||
}
|
||||
|
||||
self.set_voice(voice)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
match language:
|
||||
case Language.BG:
|
||||
return "bg-BG"
|
||||
case Language.CA:
|
||||
return "ca-ES"
|
||||
case Language.ZH:
|
||||
return "zh-CN"
|
||||
case Language.ZH_TW:
|
||||
return "zh-TW"
|
||||
case Language.CS:
|
||||
return "cs-CZ"
|
||||
case Language.DA:
|
||||
return "da-DK"
|
||||
case Language.NL:
|
||||
return "nl-NL"
|
||||
case Language.EN | Language.EN_US:
|
||||
return "en-US"
|
||||
case Language.EN_AU:
|
||||
return "en-AU"
|
||||
case Language.EN_GB:
|
||||
return "en-GB"
|
||||
case Language.EN_NZ:
|
||||
return "en-NZ"
|
||||
case Language.EN_IN:
|
||||
return "en-IN"
|
||||
case Language.ET:
|
||||
return "et-EE"
|
||||
case Language.FI:
|
||||
return "fi-FI"
|
||||
case Language.NL_BE:
|
||||
return "nl-BE"
|
||||
case Language.FR:
|
||||
return "fr-FR"
|
||||
case Language.FR_CA:
|
||||
return "fr-CA"
|
||||
case Language.DE:
|
||||
return "de-DE"
|
||||
case Language.DE_CH:
|
||||
return "de-CH"
|
||||
case Language.EL:
|
||||
return "el-GR"
|
||||
case Language.HI:
|
||||
return "hi-IN"
|
||||
case Language.HU:
|
||||
return "hu-HU"
|
||||
case Language.ID:
|
||||
return "id-ID"
|
||||
case Language.IT:
|
||||
return "it-IT"
|
||||
case Language.JA:
|
||||
return "ja-JP"
|
||||
case Language.KO:
|
||||
return "ko-KR"
|
||||
case Language.LV:
|
||||
return "lv-LV"
|
||||
case Language.LT:
|
||||
return "lt-LT"
|
||||
case Language.MS:
|
||||
return "ms-MY"
|
||||
case Language.NO:
|
||||
return "nb-NO"
|
||||
case Language.PL:
|
||||
return "pl-PL"
|
||||
case Language.PT:
|
||||
return "pt-PT"
|
||||
case Language.PT_BR:
|
||||
return "pt-BR"
|
||||
case Language.RO:
|
||||
return "ro-RO"
|
||||
case Language.RU:
|
||||
return "ru-RU"
|
||||
case Language.SK:
|
||||
return "sk-SK"
|
||||
case Language.ES:
|
||||
return "es-ES"
|
||||
case Language.SV:
|
||||
return "sv-SE"
|
||||
case Language.TH:
|
||||
return "th-TH"
|
||||
case Language.TR:
|
||||
return "tr-TR"
|
||||
case Language.UK:
|
||||
return "uk-UA"
|
||||
case Language.VI:
|
||||
return "vi-VN"
|
||||
return None
|
||||
|
||||
def _construct_ssml(self, text: str) -> str:
|
||||
language = self._settings["language"]
|
||||
ssml = (
|
||||
f"<speak version='1.0' xml:lang='{self._params.language}' "
|
||||
f"<speak version='1.0' xml:lang='{language}' "
|
||||
"xmlns='http://www.w3.org/2001/10/synthesis' "
|
||||
"xmlns:mstts='http://www.w3.org/2001/mstts'>"
|
||||
f"<voice name='{self._voice}'>"
|
||||
f"<voice name='{self._voice_id}'>"
|
||||
"<mstts:silence type='Sentenceboundary' value='20ms' />"
|
||||
)
|
||||
|
||||
if self._params.style:
|
||||
ssml += f"<mstts:express-as style='{self._params.style}'"
|
||||
if self._params.style_degree:
|
||||
ssml += f" styledegree='{self._params.style_degree}'"
|
||||
if self._params.role:
|
||||
ssml += f" role='{self._params.role}'"
|
||||
if self._settings["style"]:
|
||||
ssml += f"<mstts:express-as style='{self._settings['style']}'"
|
||||
if self._settings["style_degree"]:
|
||||
ssml += f" styledegree='{self._settings['style_degree']}'"
|
||||
if self._settings["role"]:
|
||||
ssml += f" role='{self._settings['role']}'"
|
||||
ssml += ">"
|
||||
|
||||
prosody_attrs = []
|
||||
if self._params.rate:
|
||||
prosody_attrs.append(f"rate='{self._params.rate}'")
|
||||
if self._params.pitch:
|
||||
prosody_attrs.append(f"pitch='{self._params.pitch}'")
|
||||
if self._params.volume:
|
||||
prosody_attrs.append(f"volume='{self._params.volume}'")
|
||||
if self._settings["rate"]:
|
||||
prosody_attrs.append(f"rate='{self._settings['rate']}'")
|
||||
if self._settings["pitch"]:
|
||||
prosody_attrs.append(f"pitch='{self._settings['pitch']}'")
|
||||
if self._settings["volume"]:
|
||||
prosody_attrs.append(f"volume='{self._settings['volume']}'")
|
||||
|
||||
ssml += f"<prosody {' '.join(prosody_attrs)}>"
|
||||
|
||||
if self._params.emphasis:
|
||||
ssml += f"<emphasis level='{self._params.emphasis}'>"
|
||||
if self._settings["emphasis"]:
|
||||
ssml += f"<emphasis level='{self._settings['emphasis']}'>"
|
||||
|
||||
ssml += text
|
||||
|
||||
if self._params.emphasis:
|
||||
if self._settings["emphasis"]:
|
||||
ssml += "</emphasis>"
|
||||
|
||||
ssml += "</prosody>"
|
||||
|
||||
if self._params.style:
|
||||
if self._settings["style"]:
|
||||
ssml += "</mstts:express-as>"
|
||||
|
||||
ssml += "</voice></speak>"
|
||||
|
||||
return ssml
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice = voice
|
||||
|
||||
async def set_emphasis(self, emphasis: str):
|
||||
logger.debug(f"Setting TTS emphasis to: [{emphasis}]")
|
||||
self._params.emphasis = emphasis
|
||||
|
||||
async def set_language(self, language: str):
|
||||
logger.debug(f"Setting TTS language code to: [{language}]")
|
||||
self._params.language = language
|
||||
|
||||
async def set_pitch(self, pitch: str):
|
||||
logger.debug(f"Setting TTS pitch to: [{pitch}]")
|
||||
self._params.pitch = pitch
|
||||
|
||||
async def set_rate(self, rate: str):
|
||||
logger.debug(f"Setting TTS rate to: [{rate}]")
|
||||
self._params.rate = rate
|
||||
|
||||
async def set_role(self, role: str):
|
||||
logger.debug(f"Setting TTS role to: [{role}]")
|
||||
self._params.role = role
|
||||
|
||||
async def set_style(self, style: str):
|
||||
logger.debug(f"Setting TTS style to: [{style}]")
|
||||
self._params.style = style
|
||||
|
||||
async def set_style_degree(self, style_degree: str):
|
||||
logger.debug(f"Setting TTS style degree to: [{style_degree}]")
|
||||
self._params.style_degree = style_degree
|
||||
|
||||
async def set_volume(self, volume: str):
|
||||
logger.debug(f"Setting TTS volume to: [{volume}]")
|
||||
self._params.volume = volume
|
||||
|
||||
async def set_params(self, **kwargs):
|
||||
valid_params = {
|
||||
"voice": self.set_voice,
|
||||
"emphasis": self.set_emphasis,
|
||||
"language_code": self.set_language,
|
||||
"pitch": self.set_pitch,
|
||||
"rate": self.set_rate,
|
||||
"role": self.set_role,
|
||||
"style": self.set_style,
|
||||
"style_degree": self.set_style_degree,
|
||||
"volume": self.set_volume,
|
||||
}
|
||||
|
||||
for param, value in kwargs.items():
|
||||
if param in valid_params:
|
||||
await valid_params[param](value)
|
||||
else:
|
||||
logger.warning(f"Ignoring unknown parameter: {param}")
|
||||
|
||||
logger.debug(f"Updated TTS parameters: {', '.join(kwargs.keys())}")
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
@@ -219,12 +261,14 @@ class AzureTTSService(TTSService):
|
||||
if result.reason == ResultReason.SynthesizingAudioCompleted:
|
||||
await self.start_tts_usage_metrics(text)
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
yield TTSStartedFrame()
|
||||
# Azure always sends a 44-byte header. Strip it off.
|
||||
yield TTSAudioRawFrame(
|
||||
audio=result.audio_data[44:], sample_rate=self._sample_rate, num_channels=1
|
||||
audio=result.audio_data[44:],
|
||||
sample_rate=self._settings["sample_rate"],
|
||||
num_channels=1,
|
||||
)
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
elif result.reason == ResultReason.Canceled:
|
||||
cancellation_details = result.cancellation_details
|
||||
logger.warning(f"Speech synthesis canceled: {cancellation_details.reason}")
|
||||
@@ -238,7 +282,7 @@ class AzureSTTService(STTService):
|
||||
*,
|
||||
api_key: str,
|
||||
region: str,
|
||||
language="en-US",
|
||||
language=Language.EN_US,
|
||||
sample_rate=16000,
|
||||
channels=1,
|
||||
**kwargs,
|
||||
|
||||
@@ -4,36 +4,35 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
import base64
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from typing import AsyncGenerator, Optional, Union, List
|
||||
from loguru import logger
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
StartInterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
StartFrame,
|
||||
EndFrame,
|
||||
StartInterruptionFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import TTSService, WordTTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.services.ai_services import AsyncWordTTSService, TTSService
|
||||
|
||||
from loguru import logger
|
||||
|
||||
# See .env.example for Cartesia configuration needed
|
||||
try:
|
||||
from cartesia import AsyncCartesia
|
||||
import websockets
|
||||
from cartesia import AsyncCartesia
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
@@ -46,27 +45,34 @@ def language_to_cartesia_language(language: Language) -> str | None:
|
||||
match language:
|
||||
case Language.DE:
|
||||
return "de"
|
||||
case Language.EN:
|
||||
case (
|
||||
Language.EN
|
||||
| Language.EN_US
|
||||
| Language.EN_GB
|
||||
| Language.EN_AU
|
||||
| Language.EN_NZ
|
||||
| Language.EN_IN
|
||||
):
|
||||
return "en"
|
||||
case Language.ES:
|
||||
return "es"
|
||||
case Language.FR:
|
||||
case Language.FR | Language.FR_CA:
|
||||
return "fr"
|
||||
case Language.JA:
|
||||
return "ja"
|
||||
case Language.PT:
|
||||
case Language.PT | Language.PT_BR:
|
||||
return "pt"
|
||||
case Language.ZH:
|
||||
case Language.ZH | Language.ZH_TW:
|
||||
return "zh"
|
||||
return None
|
||||
|
||||
|
||||
class CartesiaTTSService(AsyncWordTTSService):
|
||||
class CartesiaTTSService(WordTTSService):
|
||||
class InputParams(BaseModel):
|
||||
encoding: Optional[str] = "pcm_s16le"
|
||||
sample_rate: Optional[int] = 16000
|
||||
container: Optional[str] = "raw"
|
||||
language: Optional[str] = "en"
|
||||
language: Optional[Language] = Language.EN
|
||||
speed: Optional[Union[str, float]] = ""
|
||||
emotion: Optional[List[str]] = []
|
||||
|
||||
@@ -77,7 +83,7 @@ class CartesiaTTSService(AsyncWordTTSService):
|
||||
voice_id: str,
|
||||
cartesia_version: str = "2024-06-10",
|
||||
url: str = "wss://api.cartesia.ai/tts/websocket",
|
||||
model_id: str = "sonic-english",
|
||||
model: str = "sonic-english",
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
):
|
||||
@@ -101,17 +107,20 @@ class CartesiaTTSService(AsyncWordTTSService):
|
||||
self._api_key = api_key
|
||||
self._cartesia_version = cartesia_version
|
||||
self._url = url
|
||||
self._voice_id = voice_id
|
||||
self._model_id = model_id
|
||||
self.set_model_name(model_id)
|
||||
self._output_format = {
|
||||
"container": params.container,
|
||||
"encoding": params.encoding,
|
||||
"sample_rate": params.sample_rate,
|
||||
self._settings = {
|
||||
"output_format": {
|
||||
"container": params.container,
|
||||
"encoding": params.encoding,
|
||||
"sample_rate": params.sample_rate,
|
||||
},
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else Language.EN,
|
||||
"speed": params.speed,
|
||||
"emotion": params.emotion,
|
||||
}
|
||||
self._language = params.language
|
||||
self._speed = params.speed
|
||||
self._emotion = params.emotion
|
||||
self.set_model_name(model)
|
||||
self.set_voice(voice_id)
|
||||
|
||||
self._websocket = None
|
||||
self._context_id = None
|
||||
@@ -125,42 +134,31 @@ class CartesiaTTSService(AsyncWordTTSService):
|
||||
await super().set_model(model)
|
||||
logger.debug(f"Switching TTS model to: [{model}]")
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice_id = voice
|
||||
|
||||
async def set_speed(self, speed: str):
|
||||
logger.debug(f"Switching TTS speed to: [{speed}]")
|
||||
self._speed = speed
|
||||
|
||||
async def set_emotion(self, emotion: list[str]):
|
||||
logger.debug(f"Switching TTS emotion to: [{emotion}]")
|
||||
self._emotion = emotion
|
||||
|
||||
async def set_language(self, language: Language):
|
||||
logger.debug(f"Switching TTS language to: [{language}]")
|
||||
self._language = language_to_cartesia_language(language)
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
return language_to_cartesia_language(language)
|
||||
|
||||
def _build_msg(
|
||||
self, text: str = "", continue_transcript: bool = True, add_timestamps: bool = True
|
||||
):
|
||||
voice_config = {"mode": "id", "id": self._voice_id}
|
||||
voice_config = {}
|
||||
voice_config["mode"] = "id"
|
||||
voice_config["id"] = self._voice_id
|
||||
|
||||
if self._speed or self._emotion:
|
||||
if self._settings["speed"] or self._settings["emotion"]:
|
||||
voice_config["__experimental_controls"] = {}
|
||||
if self._speed:
|
||||
voice_config["__experimental_controls"]["speed"] = self._speed
|
||||
if self._emotion:
|
||||
voice_config["__experimental_controls"]["emotion"] = self._emotion
|
||||
if self._settings["speed"]:
|
||||
voice_config["__experimental_controls"]["speed"] = self._settings["speed"]
|
||||
if self._settings["emotion"]:
|
||||
voice_config["__experimental_controls"]["emotion"] = self._settings["emotion"]
|
||||
|
||||
msg = {
|
||||
"transcript": text,
|
||||
"continue": continue_transcript,
|
||||
"context_id": self._context_id,
|
||||
"model_id": self._model_name,
|
||||
"model_id": self.model_name,
|
||||
"voice": voice_config,
|
||||
"output_format": self._output_format,
|
||||
"language": self._language,
|
||||
"output_format": self._settings["output_format"],
|
||||
"language": self._settings["language"],
|
||||
"add_timestamps": add_timestamps,
|
||||
}
|
||||
return json.dumps(msg)
|
||||
@@ -245,7 +243,7 @@ class CartesiaTTSService(AsyncWordTTSService):
|
||||
self.start_word_timestamps()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=base64.b64decode(msg["data"]),
|
||||
sample_rate=self._output_format["sample_rate"],
|
||||
sample_rate=self._settings["output_format"]["sample_rate"],
|
||||
num_channels=1,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
@@ -269,8 +267,8 @@ class CartesiaTTSService(AsyncWordTTSService):
|
||||
await self._connect()
|
||||
|
||||
if not self._context_id:
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
self._context_id = str(uuid.uuid4())
|
||||
|
||||
msg = self._build_msg(text=text)
|
||||
@@ -280,7 +278,7 @@ class CartesiaTTSService(AsyncWordTTSService):
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
@@ -294,7 +292,7 @@ class CartesiaHttpTTSService(TTSService):
|
||||
encoding: Optional[str] = "pcm_s16le"
|
||||
sample_rate: Optional[int] = 16000
|
||||
container: Optional[str] = "raw"
|
||||
language: Optional[str] = "en"
|
||||
language: Optional[Language] = Language.EN
|
||||
speed: Optional[Union[str, float]] = ""
|
||||
emotion: Optional[List[str]] = []
|
||||
|
||||
@@ -303,7 +301,7 @@ class CartesiaHttpTTSService(TTSService):
|
||||
*,
|
||||
api_key: str,
|
||||
voice_id: str,
|
||||
model_id: str = "sonic-english",
|
||||
model: str = "sonic-english",
|
||||
base_url: str = "https://api.cartesia.ai",
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
@@ -311,43 +309,28 @@ class CartesiaHttpTTSService(TTSService):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._api_key = api_key
|
||||
self._voice_id = voice_id
|
||||
self._model_id = model_id
|
||||
self.set_model_name(model_id)
|
||||
self._output_format = {
|
||||
"container": params.container,
|
||||
"encoding": params.encoding,
|
||||
"sample_rate": params.sample_rate,
|
||||
self._settings = {
|
||||
"output_format": {
|
||||
"container": params.container,
|
||||
"encoding": params.encoding,
|
||||
"sample_rate": params.sample_rate,
|
||||
},
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else Language.EN,
|
||||
"speed": params.speed,
|
||||
"emotion": params.emotion,
|
||||
}
|
||||
self._language = params.language
|
||||
self._speed = params.speed
|
||||
self._emotion = params.emotion
|
||||
self.set_voice(voice_id)
|
||||
self.set_model_name(model)
|
||||
|
||||
self._client = AsyncCartesia(api_key=api_key, base_url=base_url)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def set_model(self, model: str):
|
||||
logger.debug(f"Switching TTS model to: [{model}]")
|
||||
self._model_id = model
|
||||
await super().set_model(model)
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice_id = voice
|
||||
|
||||
async def set_speed(self, speed: str):
|
||||
logger.debug(f"Switching TTS speed to: [{speed}]")
|
||||
self._speed = speed
|
||||
|
||||
async def set_emotion(self, emotion: list[str]):
|
||||
logger.debug(f"Switching TTS emotion to: [{emotion}]")
|
||||
self._emotion = emotion
|
||||
|
||||
async def set_language(self, language: Language):
|
||||
logger.debug(f"Switching TTS language to: [{language}]")
|
||||
self._language = language_to_cartesia_language(language)
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
return language_to_cartesia_language(language)
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
@@ -360,24 +343,24 @@ class CartesiaHttpTTSService(TTSService):
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
|
||||
try:
|
||||
voice_controls = None
|
||||
if self._speed or self._emotion:
|
||||
if self._settings["speed"] or self._settings["emotion"]:
|
||||
voice_controls = {}
|
||||
if self._speed:
|
||||
voice_controls["speed"] = self._speed
|
||||
if self._emotion:
|
||||
voice_controls["emotion"] = self._emotion
|
||||
if self._settings["speed"]:
|
||||
voice_controls["speed"] = self._settings["speed"]
|
||||
if self._settings["emotion"]:
|
||||
voice_controls["emotion"] = self._settings["emotion"]
|
||||
|
||||
output = await self._client.tts.sse(
|
||||
model_id=self._model_id,
|
||||
model_id=self._model_name,
|
||||
transcript=text,
|
||||
voice_id=self._voice_id,
|
||||
output_format=self._output_format,
|
||||
language=self._language,
|
||||
output_format=self._settings["output_format"],
|
||||
language=self._settings["language"],
|
||||
stream=False,
|
||||
_experimental_voice_controls=voice_controls,
|
||||
)
|
||||
@@ -386,7 +369,7 @@ class CartesiaHttpTTSService(TTSService):
|
||||
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=output["audio"],
|
||||
sample_rate=self._output_format["sample_rate"],
|
||||
sample_rate=self._settings["output_format"]["sample_rate"],
|
||||
num_channels=1,
|
||||
)
|
||||
yield frame
|
||||
@@ -394,4 +377,4 @@ class CartesiaHttpTTSService(TTSService):
|
||||
logger.error(f"{self} exception: {e}")
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -5,9 +5,10 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
@@ -24,8 +25,6 @@ from pipecat.services.ai_services import STTService, TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
from loguru import logger
|
||||
|
||||
# See .env.example for Deepgram configuration needed
|
||||
try:
|
||||
from deepgram import (
|
||||
@@ -57,25 +56,23 @@ class DeepgramTTSService(TTSService):
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._voice = voice
|
||||
self._sample_rate = sample_rate
|
||||
self._encoding = encoding
|
||||
self._settings = {
|
||||
"sample_rate": sample_rate,
|
||||
"encoding": encoding,
|
||||
}
|
||||
self.set_voice(voice)
|
||||
self._deepgram_client = DeepgramClient(api_key=api_key)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice = voice
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
options = SpeakOptions(
|
||||
model=self._voice,
|
||||
encoding=self._encoding,
|
||||
sample_rate=self._sample_rate,
|
||||
model=self._voice_id,
|
||||
encoding=self._settings["encoding"],
|
||||
sample_rate=self._settings["sample_rate"],
|
||||
container="none",
|
||||
)
|
||||
|
||||
@@ -87,7 +84,7 @@ class DeepgramTTSService(TTSService):
|
||||
)
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
yield TTSStartedFrame()
|
||||
|
||||
# The response.stream_memory is already a BytesIO object
|
||||
audio_buffer = response.stream_memory
|
||||
@@ -103,10 +100,12 @@ class DeepgramTTSService(TTSService):
|
||||
chunk = audio_buffer.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
frame = TTSAudioRawFrame(audio=chunk, sample_rate=self._sample_rate, num_channels=1)
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=chunk, sample_rate=self._settings["sample_rate"], num_channels=1
|
||||
)
|
||||
yield frame
|
||||
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
@@ -121,7 +120,7 @@ class DeepgramSTTService(STTService):
|
||||
url: str = "",
|
||||
live_options: LiveOptions = LiveOptions(
|
||||
encoding="linear16",
|
||||
language="en-US",
|
||||
language=Language.EN,
|
||||
model="nova-2-conversationalai",
|
||||
sample_rate=16000,
|
||||
channels=1,
|
||||
@@ -135,7 +134,7 @@ class DeepgramSTTService(STTService):
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._live_options = live_options
|
||||
self._settings = vars(live_options)
|
||||
|
||||
self._client = DeepgramClient(
|
||||
api_key, config=DeepgramClientOptions(url=url, options={"keepalive": "true"})
|
||||
@@ -147,7 +146,7 @@ class DeepgramSTTService(STTService):
|
||||
|
||||
@property
|
||||
def vad_enabled(self):
|
||||
return self._live_options.vad_events
|
||||
return self._settings["vad_events"]
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return self.vad_enabled
|
||||
@@ -155,13 +154,13 @@ class DeepgramSTTService(STTService):
|
||||
async def set_model(self, model: str):
|
||||
await super().set_model(model)
|
||||
logger.debug(f"Switching STT model to: [{model}]")
|
||||
self._live_options.model = model
|
||||
self._settings["model"] = model
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
async def set_language(self, language: Language):
|
||||
logger.debug(f"Switching STT language to: [{language}]")
|
||||
self._live_options.language = language
|
||||
self._settings["language"] = language
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
@@ -182,7 +181,7 @@ class DeepgramSTTService(STTService):
|
||||
yield None
|
||||
|
||||
async def _connect(self):
|
||||
if await self._connection.start(self._live_options):
|
||||
if await self._connection.start(self._settings):
|
||||
logger.debug(f"{self}: Connected to Deepgram")
|
||||
else:
|
||||
logger.error(f"{self}: Unable to connect to Deepgram")
|
||||
|
||||
@@ -23,7 +23,8 @@ from pipecat.frames.frames import (
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import AsyncWordTTSService
|
||||
from pipecat.services.ai_services import WordTTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
# See .env.example for ElevenLabs configuration needed
|
||||
try:
|
||||
@@ -70,9 +71,9 @@ def calculate_word_times(
|
||||
return word_times
|
||||
|
||||
|
||||
class ElevenLabsTTSService(AsyncWordTTSService):
|
||||
class ElevenLabsTTSService(WordTTSService):
|
||||
class InputParams(BaseModel):
|
||||
language: Optional[str] = None
|
||||
language: Optional[Language] = Language.EN
|
||||
output_format: Literal["pcm_16000", "pcm_22050", "pcm_24000", "pcm_44100"] = "pcm_16000"
|
||||
optimize_streaming_latency: Optional[str] = None
|
||||
stability: Optional[float] = None
|
||||
@@ -124,10 +125,21 @@ class ElevenLabsTTSService(AsyncWordTTSService):
|
||||
)
|
||||
|
||||
self._api_key = api_key
|
||||
self._voice_id = voice_id
|
||||
self.set_model_name(model)
|
||||
self._url = url
|
||||
self._params = params
|
||||
self._settings = {
|
||||
"sample_rate": sample_rate_from_output_format(params.output_format),
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else Language.EN,
|
||||
"output_format": params.output_format,
|
||||
"optimize_streaming_latency": params.optimize_streaming_latency,
|
||||
"stability": params.stability,
|
||||
"similarity_boost": params.similarity_boost,
|
||||
"style": params.style,
|
||||
"use_speaker_boost": params.use_speaker_boost,
|
||||
}
|
||||
self.set_model_name(model)
|
||||
self.set_voice(voice_id)
|
||||
self._voice_settings = self._set_voice_settings()
|
||||
|
||||
# Websocket connection to ElevenLabs.
|
||||
@@ -140,21 +152,93 @@ class ElevenLabsTTSService(AsyncWordTTSService):
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
match language:
|
||||
case Language.BG:
|
||||
return "bg"
|
||||
case Language.ZH:
|
||||
return "zh"
|
||||
case Language.CS:
|
||||
return "cs"
|
||||
case Language.DA:
|
||||
return "da"
|
||||
case Language.NL:
|
||||
return "nl"
|
||||
case (
|
||||
Language.EN
|
||||
| Language.EN_US
|
||||
| Language.EN_AU
|
||||
| Language.EN_GB
|
||||
| Language.EN_NZ
|
||||
| Language.EN_IN
|
||||
):
|
||||
return "en"
|
||||
case Language.FI:
|
||||
return "fi"
|
||||
case Language.FR | Language.FR_CA:
|
||||
return "fr"
|
||||
case Language.DE | Language.DE_CH:
|
||||
return "de"
|
||||
case Language.EL:
|
||||
return "el"
|
||||
case Language.HI:
|
||||
return "hi"
|
||||
case Language.HU:
|
||||
return "hu"
|
||||
case Language.ID:
|
||||
return "id"
|
||||
case Language.IT:
|
||||
return "it"
|
||||
case Language.JA:
|
||||
return "ja"
|
||||
case Language.KO:
|
||||
return "ko"
|
||||
case Language.MS:
|
||||
return "ms"
|
||||
case Language.NO:
|
||||
return "no"
|
||||
case Language.PL:
|
||||
return "pl"
|
||||
case Language.PT:
|
||||
return "pt-PT"
|
||||
case Language.PT_BR:
|
||||
return "pt-BR"
|
||||
case Language.RO:
|
||||
return "ro"
|
||||
case Language.RU:
|
||||
return "ru"
|
||||
case Language.SK:
|
||||
return "sk"
|
||||
case Language.ES:
|
||||
return "es"
|
||||
case Language.SV:
|
||||
return "sv"
|
||||
case Language.TR:
|
||||
return "tr"
|
||||
case Language.UK:
|
||||
return "uk"
|
||||
case Language.VI:
|
||||
return "vi"
|
||||
return None
|
||||
|
||||
def _set_voice_settings(self):
|
||||
voice_settings = {}
|
||||
if self._params.stability is not None and self._params.similarity_boost is not None:
|
||||
voice_settings["stability"] = self._params.stability
|
||||
voice_settings["similarity_boost"] = self._params.similarity_boost
|
||||
if self._params.style is not None:
|
||||
voice_settings["style"] = self._params.style
|
||||
if self._params.use_speaker_boost is not None:
|
||||
voice_settings["use_speaker_boost"] = self._params.use_speaker_boost
|
||||
if (
|
||||
self._settings["stability"] is not None
|
||||
and self._settings["similarity_boost"] is not None
|
||||
):
|
||||
voice_settings["stability"] = self._settings["stability"]
|
||||
voice_settings["similarity_boost"] = self._settings["similarity_boost"]
|
||||
if self._settings["style"] is not None:
|
||||
voice_settings["style"] = self._settings["style"]
|
||||
if self._settings["use_speaker_boost"] is not None:
|
||||
voice_settings["use_speaker_boost"] = self._settings["use_speaker_boost"]
|
||||
else:
|
||||
if self._params.style is not None:
|
||||
if self._settings["style"] is not None:
|
||||
logger.warning(
|
||||
"'style' is set but will not be applied because 'stability' and 'similarity_boost' are not both set."
|
||||
)
|
||||
if self._params.use_speaker_boost is not None:
|
||||
if self._settings["use_speaker_boost"] is not None:
|
||||
logger.warning(
|
||||
"'use_speaker_boost' is set but will not be applied because 'stability' and 'similarity_boost' are not both set."
|
||||
)
|
||||
@@ -167,33 +251,13 @@ class ElevenLabsTTSService(AsyncWordTTSService):
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice_id = voice
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
|
||||
async def set_voice_settings(
|
||||
self,
|
||||
stability: Optional[float] = None,
|
||||
similarity_boost: Optional[float] = None,
|
||||
style: Optional[float] = None,
|
||||
use_speaker_boost: Optional[bool] = None,
|
||||
):
|
||||
self._params.stability = stability if stability is not None else self._params.stability
|
||||
self._params.similarity_boost = (
|
||||
similarity_boost if similarity_boost is not None else self._params.similarity_boost
|
||||
)
|
||||
self._params.style = style if style is not None else self._params.style
|
||||
self._params.use_speaker_boost = (
|
||||
use_speaker_boost if use_speaker_boost is not None else self._params.use_speaker_boost
|
||||
)
|
||||
|
||||
self._set_voice_settings()
|
||||
|
||||
if self._websocket:
|
||||
msg = {"voice_settings": self._voice_settings}
|
||||
await self._websocket.send(json.dumps(msg))
|
||||
async def _update_settings(self, settings: Dict[str, Any]):
|
||||
prev_voice = self._voice_id
|
||||
await super()._update_settings(settings)
|
||||
if not prev_voice == self._voice_id:
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
logger.debug(f"Switching TTS voice to: [{self._voice_id}]")
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
@@ -223,20 +287,20 @@ class ElevenLabsTTSService(AsyncWordTTSService):
|
||||
try:
|
||||
voice_id = self._voice_id
|
||||
model = self.model_name
|
||||
output_format = self._params.output_format
|
||||
output_format = self._settings["output_format"]
|
||||
url = f"{self._url}/v1/text-to-speech/{voice_id}/stream-input?model_id={model}&output_format={output_format}"
|
||||
|
||||
if self._params.optimize_streaming_latency:
|
||||
url += f"&optimize_streaming_latency={self._params.optimize_streaming_latency}"
|
||||
if self._settings["optimize_streaming_latency"]:
|
||||
url += f"&optimize_streaming_latency={self._settings['optimize_streaming_latency']}"
|
||||
|
||||
# language can only be used with the 'eleven_turbo_v2_5' model
|
||||
if self._params.language:
|
||||
if model == "eleven_turbo_v2_5":
|
||||
url += f"&language_code={self._params.language}"
|
||||
else:
|
||||
logger.debug(
|
||||
f"Language code [{self._params.language}] not applied. Language codes can only be used with the 'eleven_turbo_v2_5' model."
|
||||
)
|
||||
# Language can only be used with the 'eleven_turbo_v2_5' model
|
||||
language = self._settings["language"]
|
||||
if model == "eleven_turbo_v2_5":
|
||||
url += f"&language_code={language}"
|
||||
else:
|
||||
logger.debug(
|
||||
f"Language code [{language}] not applied. Language codes can only be used with the 'eleven_turbo_v2_5' model."
|
||||
)
|
||||
|
||||
self._websocket = await websockets.connect(url)
|
||||
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
|
||||
@@ -286,7 +350,7 @@ class ElevenLabsTTSService(AsyncWordTTSService):
|
||||
self.start_word_timestamps()
|
||||
|
||||
audio = base64.b64decode(msg["audio"])
|
||||
frame = TTSAudioRawFrame(audio, self.sample_rate, 1)
|
||||
frame = TTSAudioRawFrame(audio, self._settings["sample_rate"], 1)
|
||||
await self.push_frame(frame)
|
||||
|
||||
if msg.get("alignment"):
|
||||
@@ -322,8 +386,8 @@ class ElevenLabsTTSService(AsyncWordTTSService):
|
||||
|
||||
try:
|
||||
if not self._started:
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
self._started = True
|
||||
self._cumulative_time = 0
|
||||
|
||||
@@ -331,7 +395,7 @@ class ElevenLabsTTSService(AsyncWordTTSService):
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
|
||||
@@ -6,8 +6,9 @@
|
||||
|
||||
import base64
|
||||
import json
|
||||
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
@@ -19,10 +20,9 @@ from pipecat.frames.frames import (
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import STTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
from loguru import logger
|
||||
|
||||
# See .env.example for Gladia configuration needed
|
||||
try:
|
||||
import websockets
|
||||
@@ -37,7 +37,7 @@ except ModuleNotFoundError as e:
|
||||
class GladiaSTTService(STTService):
|
||||
class InputParams(BaseModel):
|
||||
sample_rate: Optional[int] = 16000
|
||||
language: Optional[str] = "english"
|
||||
language: Optional[Language] = Language.EN
|
||||
transcription_hint: Optional[str] = None
|
||||
endpointing: Optional[int] = 200
|
||||
prosody: Optional[bool] = None
|
||||
@@ -51,13 +51,98 @@ class GladiaSTTService(STTService):
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sync=False, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._api_key = api_key
|
||||
self._url = url
|
||||
self._params = params
|
||||
self._settings = {
|
||||
"sample_rate": params.sample_rate,
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else Language.EN,
|
||||
"transcription_hint": params.transcription_hint,
|
||||
"endpointing": params.endpointing,
|
||||
"prosody": params.prosody,
|
||||
}
|
||||
self._confidence = confidence
|
||||
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
match language:
|
||||
case Language.BG:
|
||||
return "bulgarian"
|
||||
case Language.CA:
|
||||
return "catalan"
|
||||
case Language.ZH:
|
||||
return "chinese"
|
||||
case Language.CS:
|
||||
return "czech"
|
||||
case Language.DA:
|
||||
return "danish"
|
||||
case Language.NL:
|
||||
return "dutch"
|
||||
case (
|
||||
Language.EN
|
||||
| Language.EN_US
|
||||
| Language.EN_AU
|
||||
| Language.EN_GB
|
||||
| Language.EN_NZ
|
||||
| Language.EN_IN
|
||||
):
|
||||
return "english"
|
||||
case Language.ET:
|
||||
return "estonian"
|
||||
case Language.FI:
|
||||
return "finnish"
|
||||
case Language.FR | Language.FR_CA:
|
||||
return "french"
|
||||
case Language.DE | Language.DE_CH:
|
||||
return "german"
|
||||
case Language.EL:
|
||||
return "greek"
|
||||
case Language.HI:
|
||||
return "hindi"
|
||||
case Language.HU:
|
||||
return "hungarian"
|
||||
case Language.ID:
|
||||
return "indonesian"
|
||||
case Language.IT:
|
||||
return "italian"
|
||||
case Language.JA:
|
||||
return "japanese"
|
||||
case Language.KO:
|
||||
return "korean"
|
||||
case Language.LV:
|
||||
return "latvian"
|
||||
case Language.LT:
|
||||
return "lithuanian"
|
||||
case Language.MS:
|
||||
return "malay"
|
||||
case Language.NO:
|
||||
return "norwegian"
|
||||
case Language.PL:
|
||||
return "polish"
|
||||
case Language.PT | Language.PT_BR:
|
||||
return "portuguese"
|
||||
case Language.RO:
|
||||
return "romanian"
|
||||
case Language.RU:
|
||||
return "russian"
|
||||
case Language.SK:
|
||||
return "slovak"
|
||||
case Language.ES:
|
||||
return "spanish"
|
||||
case Language.SV:
|
||||
return "slovenian"
|
||||
case Language.TH:
|
||||
return "thai"
|
||||
case Language.TR:
|
||||
return "turkish"
|
||||
case Language.UK:
|
||||
return "ukrainian"
|
||||
case Language.VI:
|
||||
return "vietnamese"
|
||||
return None
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
self._websocket = await websockets.connect(self._url)
|
||||
@@ -84,7 +169,11 @@ class GladiaSTTService(STTService):
|
||||
"encoding": "WAV/PCM",
|
||||
"model_type": "fast",
|
||||
"language_behaviour": "manual",
|
||||
**self._params.model_dump(exclude_none=True),
|
||||
"sample_rate": self._settings["sample_rate"],
|
||||
"language": self._settings["language"],
|
||||
"transcription_hint": self._settings["transcription_hint"],
|
||||
"endpointing": self._settings["endpointing"],
|
||||
"prosody": self._settings["prosody"],
|
||||
}
|
||||
|
||||
await self._websocket.send(json.dumps(configuration))
|
||||
|
||||
@@ -30,6 +30,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import LLMService, TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
try:
|
||||
import google.ai.generativelanguage as glm
|
||||
@@ -39,7 +40,7 @@ try:
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
"In order to use Google AI, you need to `pip install pipecat-ai[google]`. Also, set `GOOGLE_API_KEY` environment variable."
|
||||
"In order to use Google AI, you need to `pip install pipecat-ai[google]`. Also, set the environment variable GOOGLE_API_KEY for the GoogleLLMService and GOOGLE_APPLICATION_CREDENTIALS for the GoogleTTSService`."
|
||||
)
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
@@ -137,9 +138,7 @@ class GoogleLLMService(LLMService):
|
||||
elif isinstance(frame, VisionImageRawFrame):
|
||||
context = OpenAILLMContext.from_image_frame(frame)
|
||||
elif isinstance(frame, LLMUpdateSettingsFrame):
|
||||
if frame.model is not None:
|
||||
logger.debug(f"Switching LLM model to: [{frame.model}]")
|
||||
self.set_model_name(frame.model)
|
||||
await self._update_settings(frame.settings)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -153,7 +152,7 @@ class GoogleTTSService(TTSService):
|
||||
rate: Optional[str] = None
|
||||
volume: Optional[str] = None
|
||||
emphasis: Optional[Literal["strong", "moderate", "reduced", "none"]] = None
|
||||
language: Optional[str] = None
|
||||
language: Optional[Language] = Language.EN
|
||||
gender: Optional[Literal["male", "female", "neutral"]] = None
|
||||
google_style: Optional[Literal["apologetic", "calm", "empathetic", "firm", "lively"]] = None
|
||||
|
||||
@@ -169,8 +168,19 @@ class GoogleTTSService(TTSService):
|
||||
):
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._voice_id: str = voice_id
|
||||
self._params = params
|
||||
self._settings = {
|
||||
"sample_rate": sample_rate,
|
||||
"pitch": params.pitch,
|
||||
"rate": params.rate,
|
||||
"volume": params.volume,
|
||||
"emphasis": params.emphasis,
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else Language.EN,
|
||||
"gender": params.gender,
|
||||
"google_style": params.google_style,
|
||||
}
|
||||
self.set_voice(voice_id)
|
||||
self._client: texttospeech_v1.TextToSpeechAsyncClient = self._create_client(
|
||||
credentials, credentials_path
|
||||
)
|
||||
@@ -190,51 +200,135 @@ class GoogleTTSService(TTSService):
|
||||
elif credentials_path:
|
||||
# Use service account JSON file if provided
|
||||
creds = service_account.Credentials.from_service_account_file(credentials_path)
|
||||
else:
|
||||
raise ValueError("Either 'credentials' or 'credentials_path' must be provided.")
|
||||
|
||||
return texttospeech_v1.TextToSpeechAsyncClient(credentials=creds)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
match language:
|
||||
case Language.BG:
|
||||
return "bg-BG"
|
||||
case Language.CA:
|
||||
return "ca-ES"
|
||||
case Language.ZH:
|
||||
return "cmn-CN"
|
||||
case Language.ZH_TW:
|
||||
return "cmn-TW"
|
||||
case Language.CS:
|
||||
return "cs-CZ"
|
||||
case Language.DA:
|
||||
return "da-DK"
|
||||
case Language.NL:
|
||||
return "nl-NL"
|
||||
case Language.EN | Language.EN_US:
|
||||
return "en-US"
|
||||
case Language.EN_AU:
|
||||
return "en-AU"
|
||||
case Language.EN_GB:
|
||||
return "en-GB"
|
||||
case Language.EN_IN:
|
||||
return "en-IN"
|
||||
case Language.ET:
|
||||
return "et-EE"
|
||||
case Language.FI:
|
||||
return "fi-FI"
|
||||
case Language.NL_BE:
|
||||
return "nl-BE"
|
||||
case Language.FR:
|
||||
return "fr-FR"
|
||||
case Language.FR_CA:
|
||||
return "fr-CA"
|
||||
case Language.DE:
|
||||
return "de-DE"
|
||||
case Language.EL:
|
||||
return "el-GR"
|
||||
case Language.HI:
|
||||
return "hi-IN"
|
||||
case Language.HU:
|
||||
return "hu-HU"
|
||||
case Language.ID:
|
||||
return "id-ID"
|
||||
case Language.IT:
|
||||
return "it-IT"
|
||||
case Language.JA:
|
||||
return "ja-JP"
|
||||
case Language.KO:
|
||||
return "ko-KR"
|
||||
case Language.LV:
|
||||
return "lv-LV"
|
||||
case Language.LT:
|
||||
return "lt-LT"
|
||||
case Language.MS:
|
||||
return "ms-MY"
|
||||
case Language.NO:
|
||||
return "nb-NO"
|
||||
case Language.PL:
|
||||
return "pl-PL"
|
||||
case Language.PT:
|
||||
return "pt-PT"
|
||||
case Language.PT_BR:
|
||||
return "pt-BR"
|
||||
case Language.RO:
|
||||
return "ro-RO"
|
||||
case Language.RU:
|
||||
return "ru-RU"
|
||||
case Language.SK:
|
||||
return "sk-SK"
|
||||
case Language.ES:
|
||||
return "es-ES"
|
||||
case Language.SV:
|
||||
return "sv-SE"
|
||||
case Language.TH:
|
||||
return "th-TH"
|
||||
case Language.TR:
|
||||
return "tr-TR"
|
||||
case Language.UK:
|
||||
return "uk-UA"
|
||||
case Language.VI:
|
||||
return "vi-VN"
|
||||
return None
|
||||
|
||||
def _construct_ssml(self, text: str) -> str:
|
||||
ssml = "<speak>"
|
||||
|
||||
# Voice tag
|
||||
voice_attrs = [f"name='{self._voice_id}'"]
|
||||
if self._params.language:
|
||||
voice_attrs.append(f"language='{self._params.language}'")
|
||||
if self._params.gender:
|
||||
voice_attrs.append(f"gender='{self._params.gender}'")
|
||||
|
||||
language = self._settings["language"]
|
||||
voice_attrs.append(f"language='{language}'")
|
||||
|
||||
if self._settings["gender"]:
|
||||
voice_attrs.append(f"gender='{self._settings['gender']}'")
|
||||
ssml += f"<voice {' '.join(voice_attrs)}>"
|
||||
|
||||
# Prosody tag
|
||||
prosody_attrs = []
|
||||
if self._params.pitch:
|
||||
prosody_attrs.append(f"pitch='{self._params.pitch}'")
|
||||
if self._params.rate:
|
||||
prosody_attrs.append(f"rate='{self._params.rate}'")
|
||||
if self._params.volume:
|
||||
prosody_attrs.append(f"volume='{self._params.volume}'")
|
||||
if self._settings["pitch"]:
|
||||
prosody_attrs.append(f"pitch='{self._settings['pitch']}'")
|
||||
if self._settings["rate"]:
|
||||
prosody_attrs.append(f"rate='{self._settings['rate']}'")
|
||||
if self._settings["volume"]:
|
||||
prosody_attrs.append(f"volume='{self._settings['volume']}'")
|
||||
|
||||
if prosody_attrs:
|
||||
ssml += f"<prosody {' '.join(prosody_attrs)}>"
|
||||
|
||||
# Emphasis tag
|
||||
if self._params.emphasis:
|
||||
ssml += f"<emphasis level='{self._params.emphasis}'>"
|
||||
if self._settings["emphasis"]:
|
||||
ssml += f"<emphasis level='{self._settings['emphasis']}'>"
|
||||
|
||||
# Google style tag
|
||||
if self._params.google_style:
|
||||
ssml += f"<google:style name='{self._params.google_style}'>"
|
||||
if self._settings["google_style"]:
|
||||
ssml += f"<google:style name='{self._settings['google_style']}'>"
|
||||
|
||||
ssml += text
|
||||
|
||||
# Close tags
|
||||
if self._params.google_style:
|
||||
if self._settings["google_style"]:
|
||||
ssml += "</google:style>"
|
||||
if self._params.emphasis:
|
||||
if self._settings["emphasis"]:
|
||||
ssml += "</emphasis>"
|
||||
if prosody_attrs:
|
||||
ssml += "</prosody>"
|
||||
@@ -242,46 +336,6 @@ class GoogleTTSService(TTSService):
|
||||
|
||||
return ssml
|
||||
|
||||
async def set_voice(self, voice: str) -> None:
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice_id = voice
|
||||
|
||||
async def set_language(self, language: str) -> None:
|
||||
logger.debug(f"Switching TTS language to: [{language}]")
|
||||
self._params.language = language
|
||||
|
||||
async def set_pitch(self, pitch: str) -> None:
|
||||
logger.debug(f"Switching TTS pitch to: [{pitch}]")
|
||||
self._params.pitch = pitch
|
||||
|
||||
async def set_rate(self, rate: str) -> None:
|
||||
logger.debug(f"Switching TTS rate to: [{rate}]")
|
||||
self._params.rate = rate
|
||||
|
||||
async def set_volume(self, volume: str) -> None:
|
||||
logger.debug(f"Switching TTS volume to: [{volume}]")
|
||||
self._params.volume = volume
|
||||
|
||||
async def set_emphasis(
|
||||
self, emphasis: Literal["strong", "moderate", "reduced", "none"]
|
||||
) -> None:
|
||||
logger.debug(f"Switching TTS emphasis to: [{emphasis}]")
|
||||
self._params.emphasis = emphasis
|
||||
|
||||
async def set_gender(self, gender: Literal["male", "female", "neutral"]) -> None:
|
||||
logger.debug(f"Switch TTS gender to [{gender}]")
|
||||
self._params.gender = gender
|
||||
|
||||
async def google_style(
|
||||
self, google_style: Literal["apologetic", "calm", "empathetic", "firm", "lively"]
|
||||
) -> None:
|
||||
logger.debug(f"Switching TTS google style to: [{google_style}]")
|
||||
self._params.google_style = google_style
|
||||
|
||||
async def set_params(self, params: InputParams) -> None:
|
||||
logger.debug(f"Switching TTS params to: [{params}]")
|
||||
self._params = params
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
@@ -291,11 +345,11 @@ class GoogleTTSService(TTSService):
|
||||
ssml = self._construct_ssml(text)
|
||||
synthesis_input = texttospeech_v1.SynthesisInput(ssml=ssml)
|
||||
voice = texttospeech_v1.VoiceSelectionParams(
|
||||
language_code=self._params.language, name=self._voice_id
|
||||
language_code=self._settings["language"], name=self._voice_id
|
||||
)
|
||||
audio_config = texttospeech_v1.AudioConfig(
|
||||
audio_encoding=texttospeech_v1.AudioEncoding.LINEAR16,
|
||||
sample_rate_hertz=self.sample_rate,
|
||||
sample_rate_hertz=self._settings["sample_rate"],
|
||||
)
|
||||
|
||||
request = texttospeech_v1.SynthesizeSpeechRequest(
|
||||
@@ -306,7 +360,7 @@ class GoogleTTSService(TTSService):
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
yield TTSStartedFrame()
|
||||
|
||||
# Skip the first 44 bytes to remove the WAV header
|
||||
audio_content = response.audio_content[44:]
|
||||
@@ -318,15 +372,15 @@ class GoogleTTSService(TTSService):
|
||||
if not chunk:
|
||||
break
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1)
|
||||
frame = TTSAudioRawFrame(chunk, self._settings["sample_rate"], 1)
|
||||
yield frame
|
||||
await asyncio.sleep(0) # Allow other tasks to run
|
||||
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
error_message = f"TTS generation error: {str(e)}"
|
||||
yield ErrorFrame(error=error_message)
|
||||
finally:
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -5,10 +5,10 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
@@ -20,9 +20,9 @@ from pipecat.frames.frames import (
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import AsyncTTSService
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
# See .env.example for LMNT configuration needed
|
||||
try:
|
||||
@@ -35,28 +35,31 @@ except ModuleNotFoundError as e:
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
class LmntTTSService(AsyncTTSService):
|
||||
class LmntTTSService(TTSService):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
voice_id: str,
|
||||
sample_rate: int = 24000,
|
||||
language: str = "en",
|
||||
language: Language = Language.EN,
|
||||
**kwargs,
|
||||
):
|
||||
# Let TTSService produce TTSStoppedFrames after a short delay of
|
||||
# no activity.
|
||||
super().__init__(sync=False, push_stop_frames=True, sample_rate=sample_rate, **kwargs)
|
||||
super().__init__(push_stop_frames=True, sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._api_key = api_key
|
||||
self._voice_id = voice_id
|
||||
self._output_format = {
|
||||
"container": "raw",
|
||||
"encoding": "pcm_s16le",
|
||||
"sample_rate": sample_rate,
|
||||
self._settings = {
|
||||
"output_format": {
|
||||
"container": "raw",
|
||||
"encoding": "pcm_s16le",
|
||||
"sample_rate": sample_rate,
|
||||
},
|
||||
"language": self.language_to_service_language(language),
|
||||
}
|
||||
self._language = language
|
||||
|
||||
self.set_voice(voice_id)
|
||||
|
||||
self._speech = None
|
||||
self._connection = None
|
||||
@@ -68,9 +71,30 @@ class LmntTTSService(AsyncTTSService):
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice_id = voice
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
match language:
|
||||
case Language.DE:
|
||||
return "de"
|
||||
case (
|
||||
Language.EN
|
||||
| Language.EN_US
|
||||
| Language.EN_AU
|
||||
| Language.EN_GB
|
||||
| Language.EN_NZ
|
||||
| Language.EN_IN
|
||||
):
|
||||
return "en"
|
||||
case Language.ES:
|
||||
return "es"
|
||||
case Language.FR | Language.FR_CA:
|
||||
return "fr"
|
||||
case Language.PT | Language.PT_BR:
|
||||
return "pt"
|
||||
case Language.ZH | Language.ZH_TW:
|
||||
return "zh"
|
||||
case Language.KO:
|
||||
return "ko"
|
||||
return None
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
@@ -93,7 +117,10 @@ class LmntTTSService(AsyncTTSService):
|
||||
try:
|
||||
self._speech = Speech()
|
||||
self._connection = await self._speech.synthesize_streaming(
|
||||
self._voice_id, format="raw", sample_rate=self._output_format["sample_rate"]
|
||||
self._voice_id,
|
||||
format="raw",
|
||||
sample_rate=self._settings["output_format"]["sample_rate"],
|
||||
language=self._settings["language"],
|
||||
)
|
||||
self._receive_task = self.get_event_loop().create_task(self._receive_task_handler())
|
||||
except Exception as e:
|
||||
@@ -130,7 +157,7 @@ class LmntTTSService(AsyncTTSService):
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=msg["audio"],
|
||||
sample_rate=self._output_format["sample_rate"],
|
||||
sample_rate=self._settings["output_format"]["sample_rate"],
|
||||
num_channels=1,
|
||||
)
|
||||
await self.push_frame(frame)
|
||||
@@ -149,8 +176,8 @@ class LmntTTSService(AsyncTTSService):
|
||||
await self._connect()
|
||||
|
||||
if not self._started:
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
self._started = True
|
||||
|
||||
try:
|
||||
@@ -159,7 +186,7 @@ class LmntTTSService(AsyncTTSService):
|
||||
await self.start_tts_usage_metrics(text)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} error sending message: {e}")
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
await self._disconnect()
|
||||
await self._connect()
|
||||
return
|
||||
|
||||
@@ -31,6 +31,8 @@ from pipecat.frames.frames import (
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
URLImageRawFrame,
|
||||
UserImageRawFrame,
|
||||
UserImageRequestFrame,
|
||||
VisionImageRawFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
@@ -109,14 +111,16 @@ class BaseOpenAILLMService(LLMService):
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._settings = {
|
||||
"frequency_penalty": params.frequency_penalty,
|
||||
"presence_penalty": params.presence_penalty,
|
||||
"seed": params.seed,
|
||||
"temperature": params.temperature,
|
||||
"top_p": params.top_p,
|
||||
"extra": params.extra if isinstance(params.extra, dict) else {},
|
||||
}
|
||||
self.set_model_name(model)
|
||||
self._client = self.create_client(api_key=api_key, base_url=base_url, **kwargs)
|
||||
self._frequency_penalty = params.frequency_penalty
|
||||
self._presence_penalty = params.presence_penalty
|
||||
self._seed = params.seed
|
||||
self._temperature = params.temperature
|
||||
self._top_p = params.top_p
|
||||
self._extra = params.extra if isinstance(params.extra, dict) else {}
|
||||
|
||||
def create_client(self, api_key=None, base_url=None, **kwargs):
|
||||
return AsyncOpenAI(
|
||||
@@ -132,30 +136,6 @@ class BaseOpenAILLMService(LLMService):
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def set_frequency_penalty(self, frequency_penalty: float):
|
||||
logger.debug(f"Switching LLM frequency_penalty to: [{frequency_penalty}]")
|
||||
self._frequency_penalty = frequency_penalty
|
||||
|
||||
async def set_presence_penalty(self, presence_penalty: float):
|
||||
logger.debug(f"Switching LLM presence_penalty to: [{presence_penalty}]")
|
||||
self._presence_penalty = presence_penalty
|
||||
|
||||
async def set_seed(self, seed: int):
|
||||
logger.debug(f"Switching LLM seed to: [{seed}]")
|
||||
self._seed = seed
|
||||
|
||||
async def set_temperature(self, temperature: float):
|
||||
logger.debug(f"Switching LLM temperature to: [{temperature}]")
|
||||
self._temperature = temperature
|
||||
|
||||
async def set_top_p(self, top_p: float):
|
||||
logger.debug(f"Switching LLM top_p to: [{top_p}]")
|
||||
self._top_p = top_p
|
||||
|
||||
async def set_extra(self, extra: Dict[str, Any]):
|
||||
logger.debug(f"Switching LLM extra to: [{extra}]")
|
||||
self._extra = extra
|
||||
|
||||
async def get_chat_completions(
|
||||
self, context: OpenAILLMContext, messages: List[ChatCompletionMessageParam]
|
||||
) -> AsyncStream[ChatCompletionChunk]:
|
||||
@@ -166,14 +146,14 @@ class BaseOpenAILLMService(LLMService):
|
||||
"tools": context.tools,
|
||||
"tool_choice": context.tool_choice,
|
||||
"stream_options": {"include_usage": True},
|
||||
"frequency_penalty": self._frequency_penalty,
|
||||
"presence_penalty": self._presence_penalty,
|
||||
"seed": self._seed,
|
||||
"temperature": self._temperature,
|
||||
"top_p": self._top_p,
|
||||
"frequency_penalty": self._settings["frequency_penalty"],
|
||||
"presence_penalty": self._settings["presence_penalty"],
|
||||
"seed": self._settings["seed"],
|
||||
"temperature": self._settings["temperature"],
|
||||
"top_p": self._settings["top_p"],
|
||||
}
|
||||
|
||||
params.update(self._extra)
|
||||
params.update(self._settings["extra"])
|
||||
|
||||
chunks = await self._client.chat.completions.create(**params)
|
||||
return chunks
|
||||
@@ -181,7 +161,7 @@ class BaseOpenAILLMService(LLMService):
|
||||
async def _stream_chat_completions(
|
||||
self, context: OpenAILLMContext
|
||||
) -> AsyncStream[ChatCompletionChunk]:
|
||||
logger.debug(f"Generating chat: {context.get_messages_json()}")
|
||||
logger.debug(f"Generating chat: {context.get_messages_for_logging()}")
|
||||
|
||||
messages: List[ChatCompletionMessageParam] = context.get_messages()
|
||||
|
||||
@@ -205,6 +185,10 @@ class BaseOpenAILLMService(LLMService):
|
||||
return chunks
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext):
|
||||
functions_list = []
|
||||
arguments_list = []
|
||||
tool_id_list = []
|
||||
func_idx = 0
|
||||
function_name = ""
|
||||
arguments = ""
|
||||
tool_call_id = ""
|
||||
@@ -242,6 +226,14 @@ class BaseOpenAILLMService(LLMService):
|
||||
# yield a frame containing the function name and the arguments.
|
||||
|
||||
tool_call = chunk.choices[0].delta.tool_calls[0]
|
||||
if tool_call.index != func_idx:
|
||||
functions_list.append(function_name)
|
||||
arguments_list.append(arguments)
|
||||
tool_id_list.append(tool_call_id)
|
||||
function_name = ""
|
||||
arguments = ""
|
||||
tool_call_id = ""
|
||||
func_idx += 1
|
||||
if tool_call.function and tool_call.function.name:
|
||||
function_name += tool_call.function.name
|
||||
tool_call_id = tool_call.id
|
||||
@@ -257,38 +249,29 @@ class BaseOpenAILLMService(LLMService):
|
||||
# the context, and re-prompt to get a chat answer. If we don't have a registered
|
||||
# handler, raise an exception.
|
||||
if function_name and arguments:
|
||||
if self.has_function(function_name):
|
||||
await self._handle_function_call(context, tool_call_id, function_name, arguments)
|
||||
else:
|
||||
raise OpenAIUnhandledFunctionException(
|
||||
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
|
||||
)
|
||||
# added to the list as last function name and arguments not added to the list
|
||||
functions_list.append(function_name)
|
||||
arguments_list.append(arguments)
|
||||
tool_id_list.append(tool_call_id)
|
||||
|
||||
async def _handle_function_call(self, context, tool_call_id, function_name, arguments):
|
||||
arguments = json.loads(arguments)
|
||||
await self.call_function(
|
||||
context=context,
|
||||
tool_call_id=tool_call_id,
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
)
|
||||
|
||||
async def _update_settings(self, frame: LLMUpdateSettingsFrame):
|
||||
if frame.model is not None:
|
||||
logger.debug(f"Switching LLM model to: [{frame.model}]")
|
||||
self.set_model_name(frame.model)
|
||||
if frame.frequency_penalty is not None:
|
||||
await self.set_frequency_penalty(frame.frequency_penalty)
|
||||
if frame.presence_penalty is not None:
|
||||
await self.set_presence_penalty(frame.presence_penalty)
|
||||
if frame.seed is not None:
|
||||
await self.set_seed(frame.seed)
|
||||
if frame.temperature is not None:
|
||||
await self.set_temperature(frame.temperature)
|
||||
if frame.top_p is not None:
|
||||
await self.set_top_p(frame.top_p)
|
||||
if frame.extra:
|
||||
await self.set_extra(frame.extra)
|
||||
total_items = len(functions_list)
|
||||
for index, (function_name, arguments, tool_id) in enumerate(
|
||||
zip(functions_list, arguments_list, tool_id_list), start=1
|
||||
):
|
||||
if self.has_function(function_name):
|
||||
run_llm = index == total_items
|
||||
arguments = json.loads(arguments)
|
||||
await self.call_function(
|
||||
context=context,
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
tool_call_id=tool_id,
|
||||
run_llm=run_llm,
|
||||
)
|
||||
else:
|
||||
raise OpenAIUnhandledFunctionException(
|
||||
f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
|
||||
)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
@@ -301,7 +284,7 @@ class BaseOpenAILLMService(LLMService):
|
||||
elif isinstance(frame, VisionImageRawFrame):
|
||||
context = OpenAILLMContext.from_image_frame(frame)
|
||||
elif isinstance(frame, LLMUpdateSettingsFrame):
|
||||
await self._update_settings(frame)
|
||||
await self._update_settings(frame.settings)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -336,9 +319,13 @@ class OpenAILLMService(BaseOpenAILLMService):
|
||||
super().__init__(model=model, params=params, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def create_context_aggregator(context: OpenAILLMContext) -> OpenAIContextAggregatorPair:
|
||||
def create_context_aggregator(
|
||||
context: OpenAILLMContext, *, assistant_expect_stripped_words: bool = True
|
||||
) -> OpenAIContextAggregatorPair:
|
||||
user = OpenAIUserContextAggregator(context)
|
||||
assistant = OpenAIAssistantContextAggregator(user)
|
||||
assistant = OpenAIAssistantContextAggregator(
|
||||
user, expect_stripped_words=assistant_expect_stripped_words
|
||||
)
|
||||
return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
|
||||
@@ -401,22 +388,20 @@ class OpenAITTSService(TTSService):
|
||||
):
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
self._voice: ValidVoice = VALID_VOICES.get(voice, "alloy")
|
||||
self._settings = {
|
||||
"sample_rate": sample_rate,
|
||||
}
|
||||
self.set_model_name(model)
|
||||
self._sample_rate = sample_rate
|
||||
self.set_voice(voice)
|
||||
|
||||
self._client = AsyncOpenAI(api_key=api_key)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice = VALID_VOICES.get(voice, self._voice)
|
||||
|
||||
async def set_model(self, model: str):
|
||||
logger.debug(f"Switching TTS model to: [{model}]")
|
||||
self._model = model
|
||||
self.set_model_name(model)
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
@@ -426,7 +411,7 @@ class OpenAITTSService(TTSService):
|
||||
async with self._client.audio.speech.with_streaming_response.create(
|
||||
input=text,
|
||||
model=self.model_name,
|
||||
voice=self._voice,
|
||||
voice=VALID_VOICES[self._voice_id],
|
||||
response_format="pcm",
|
||||
) as r:
|
||||
if r.status_code != 200:
|
||||
@@ -441,61 +426,102 @@ class OpenAITTSService(TTSService):
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
yield TTSStartedFrame()
|
||||
async for chunk in r.iter_bytes(8192):
|
||||
if len(chunk) > 0:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(chunk, self.sample_rate, 1)
|
||||
frame = TTSAudioRawFrame(chunk, self._settings["sample_rate"], 1)
|
||||
yield frame
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
except BadRequestError as e:
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
|
||||
|
||||
# internal use only -- todo: refactor
|
||||
@dataclass
|
||||
class OpenAIImageMessageFrame(Frame):
|
||||
user_image_raw_frame: UserImageRawFrame
|
||||
text: Optional[str] = None
|
||||
|
||||
|
||||
class OpenAIUserContextAggregator(LLMUserContextAggregator):
|
||||
def __init__(self, context: OpenAILLMContext):
|
||||
super().__init__(context=context)
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
await super().process_frame(frame, direction)
|
||||
# Our parent method has already called push_frame(). So we can't interrupt the
|
||||
# flow here and we don't need to call push_frame() ourselves.
|
||||
try:
|
||||
if isinstance(frame, UserImageRequestFrame):
|
||||
# The LLM sends a UserImageRequestFrame upstream. Cache any context provided with
|
||||
# that frame so we can use it when we assemble the image message in the assistant
|
||||
# context aggregator.
|
||||
if frame.context:
|
||||
if isinstance(frame.context, str):
|
||||
self._context._user_image_request_context[frame.user_id] = frame.context
|
||||
else:
|
||||
logger.error(
|
||||
f"Unexpected UserImageRequestFrame context type: {type(frame.context)}"
|
||||
)
|
||||
del self._context._user_image_request_context[frame.user_id]
|
||||
else:
|
||||
if frame.user_id in self._context._user_image_request_context:
|
||||
del self._context._user_image_request_context[frame.user_id]
|
||||
elif isinstance(frame, UserImageRawFrame):
|
||||
# Push a new AnthropicImageMessageFrame with the text context we cached
|
||||
# downstream to be handled by our assistant context aggregator. This is
|
||||
# necessary so that we add the message to the context in the right order.
|
||||
text = self._context._user_image_request_context.get(frame.user_id) or ""
|
||||
if text:
|
||||
del self._context._user_image_request_context[frame.user_id]
|
||||
frame = OpenAIImageMessageFrame(user_image_raw_frame=frame, text=text)
|
||||
await self.push_frame(frame)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing frame: {e}")
|
||||
|
||||
|
||||
class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
def __init__(self, user_context_aggregator: OpenAIUserContextAggregator):
|
||||
super().__init__(context=user_context_aggregator._context)
|
||||
def __init__(self, user_context_aggregator: OpenAIUserContextAggregator, **kwargs):
|
||||
super().__init__(context=user_context_aggregator._context, **kwargs)
|
||||
self._user_context_aggregator = user_context_aggregator
|
||||
self._function_call_in_progress = None
|
||||
self._function_calls_in_progress = {}
|
||||
self._function_call_result = None
|
||||
self._pending_image_frame_message = None
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
await super().process_frame(frame, direction)
|
||||
# See note above about not calling push_frame() here.
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
self._function_call_in_progress = None
|
||||
self._function_calls_in_progress.clear()
|
||||
self._function_call_finished = None
|
||||
elif isinstance(frame, FunctionCallInProgressFrame):
|
||||
self._function_call_in_progress = frame
|
||||
self._function_calls_in_progress[frame.tool_call_id] = frame
|
||||
elif isinstance(frame, FunctionCallResultFrame):
|
||||
if (
|
||||
self._function_call_in_progress
|
||||
and self._function_call_in_progress.tool_call_id == frame.tool_call_id
|
||||
):
|
||||
self._function_call_in_progress = None
|
||||
if frame.tool_call_id in self._function_calls_in_progress:
|
||||
del self._function_calls_in_progress[frame.tool_call_id]
|
||||
self._function_call_result = frame
|
||||
# TODO-CB: Kwin wants us to refactor this out of here but I REFUSE
|
||||
await self._push_aggregation()
|
||||
else:
|
||||
logger.warning(
|
||||
"FunctionCallResultFrame tool_call_id does not match FunctionCallInProgressFrame tool_call_id"
|
||||
"FunctionCallResultFrame tool_call_id does not match any function call in progress"
|
||||
)
|
||||
self._function_call_in_progress = None
|
||||
self._function_call_result = None
|
||||
elif isinstance(frame, OpenAIImageMessageFrame):
|
||||
self._pending_image_frame_message = frame
|
||||
await self._push_aggregation()
|
||||
|
||||
async def _push_aggregation(self):
|
||||
if not (self._aggregation or self._function_call_result):
|
||||
if not (
|
||||
self._aggregation or self._function_call_result or self._pending_image_frame_message
|
||||
):
|
||||
return
|
||||
|
||||
run_llm = False
|
||||
|
||||
aggregation = self._aggregation
|
||||
self._aggregation = ""
|
||||
self._reset()
|
||||
|
||||
try:
|
||||
if self._function_call_result:
|
||||
@@ -524,12 +550,26 @@ class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
"tool_call_id": frame.tool_call_id,
|
||||
}
|
||||
)
|
||||
run_llm = True
|
||||
run_llm = frame.run_llm
|
||||
else:
|
||||
self._context.add_message({"role": "assistant", "content": aggregation})
|
||||
|
||||
if self._pending_image_frame_message:
|
||||
frame = self._pending_image_frame_message
|
||||
self._pending_image_frame_message = None
|
||||
self._context.add_image_frame_message(
|
||||
format=frame.user_image_raw_frame.format,
|
||||
size=frame.user_image_raw_frame.size,
|
||||
image=frame.user_image_raw_frame.image,
|
||||
text=frame.text,
|
||||
)
|
||||
run_llm = True
|
||||
|
||||
if run_llm:
|
||||
await self._user_context_aggregator.push_context_frame()
|
||||
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing frame: {e}")
|
||||
|
||||
@@ -6,17 +6,21 @@
|
||||
|
||||
import io
|
||||
import struct
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from pipecat.frames.frames import Frame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame
|
||||
from pipecat.services.ai_services import TTSService
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import TTSService
|
||||
|
||||
try:
|
||||
from pyht.client import TTSOptions
|
||||
from pyht.async_client import AsyncClient
|
||||
from pyht.client import TTSOptions
|
||||
from pyht.protos.api_pb2 import Format
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
@@ -39,17 +43,23 @@ class PlayHTTTSService(TTSService):
|
||||
user_id=self._user_id,
|
||||
api_key=self._speech_key,
|
||||
)
|
||||
self._settings = {
|
||||
"sample_rate": sample_rate,
|
||||
"quality": "higher",
|
||||
"format": Format.FORMAT_WAV,
|
||||
"voice_engine": "PlayHT2.0-turbo",
|
||||
}
|
||||
self.set_voice(voice_url)
|
||||
self._options = TTSOptions(
|
||||
voice=voice_url, sample_rate=sample_rate, quality="higher", format=Format.FORMAT_WAV
|
||||
voice=self._voice_id,
|
||||
sample_rate=self._settings["sample_rate"],
|
||||
quality=self._settings["quality"],
|
||||
format=self._settings["format"],
|
||||
)
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._options.voice = voice
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
@@ -60,12 +70,12 @@ class PlayHTTTSService(TTSService):
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
playht_gen = self._client.tts(
|
||||
text, voice_engine="PlayHT2.0-turbo", options=self._options
|
||||
text, voice_engine=self._settings["voice_engine"], options=self._options
|
||||
)
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
yield TTSStartedFrame()
|
||||
async for chunk in playht_gen:
|
||||
# skip the RIFF header.
|
||||
if in_header:
|
||||
@@ -83,8 +93,8 @@ class PlayHTTTSService(TTSService):
|
||||
else:
|
||||
if len(chunk):
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(chunk, 16000, 1)
|
||||
frame = TTSAudioRawFrame(chunk, self._settings["sample_rate"], 1)
|
||||
yield frame
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} error generating TTS: {e}")
|
||||
|
||||
@@ -4,42 +4,18 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from asyncio import CancelledError
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
LLMMessagesFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
StartInterruptionFrame,
|
||||
TextFrame,
|
||||
UserImageRequestFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_response import (
|
||||
LLMAssistantContextAggregator,
|
||||
LLMUserContextAggregator,
|
||||
)
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
OpenAILLMContextFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.ai_services import LLMService
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
|
||||
try:
|
||||
from together import AsyncTogether
|
||||
# Together.ai is recommending OpenAI-compatible function calling, so we've switched over
|
||||
# to using the OpenAI client library here rather than the Together Python client library.
|
||||
from openai import AsyncOpenAI, DefaultAsyncHttpxClient
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
@@ -48,19 +24,7 @@ except ModuleNotFoundError as e:
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TogetherContextAggregatorPair:
|
||||
_user: "TogetherUserContextAggregator"
|
||||
_assistant: "TogetherAssistantContextAggregator"
|
||||
|
||||
def user(self) -> "TogetherUserContextAggregator":
|
||||
return self._user
|
||||
|
||||
def assistant(self) -> "TogetherAssistantContextAggregator":
|
||||
return self._assistant
|
||||
|
||||
|
||||
class TogetherLLMService(LLMService):
|
||||
class TogetherLLMService(OpenAILLMService):
|
||||
"""This class implements inference with Together's Llama 3.1 models"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
@@ -68,327 +32,45 @@ class TogetherLLMService(LLMService):
|
||||
max_tokens: Optional[int] = Field(default=4096, ge=1)
|
||||
presence_penalty: Optional[float] = Field(default=None, ge=-2.0, le=2.0)
|
||||
temperature: Optional[float] = Field(default=None, ge=0.0, le=1.0)
|
||||
# Note: top_k is currently not supported by the OpenAI client library,
|
||||
# so top_k is ignore right now.
|
||||
top_k: Optional[int] = Field(default=None, ge=0)
|
||||
top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
|
||||
extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
seed: Optional[int] = Field(default=None)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: str = "https://api.together.xyz/v1",
|
||||
model: str = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
|
||||
params: InputParams = InputParams(),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._client = AsyncTogether(api_key=api_key)
|
||||
super().__init__(api_key=api_key, base_url=base_url, model=model, params=params, **kwargs)
|
||||
self.set_model_name(model)
|
||||
self._max_tokens = params.max_tokens
|
||||
self._frequency_penalty = params.frequency_penalty
|
||||
self._presence_penalty = params.presence_penalty
|
||||
self._temperature = params.temperature
|
||||
self._top_k = params.top_k
|
||||
self._top_p = params.top_p
|
||||
self._extra = params.extra if isinstance(params.extra, dict) else {}
|
||||
self._settings = {
|
||||
"max_tokens": params.max_tokens,
|
||||
"frequency_penalty": params.frequency_penalty,
|
||||
"presence_penalty": params.presence_penalty,
|
||||
"seed": params.seed,
|
||||
"temperature": params.temperature,
|
||||
"top_p": params.top_p,
|
||||
"extra": params.extra if isinstance(params.extra, dict) else {},
|
||||
}
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def create_context_aggregator(context: OpenAILLMContext) -> TogetherContextAggregatorPair:
|
||||
user = TogetherUserContextAggregator(context)
|
||||
assistant = TogetherAssistantContextAggregator(user)
|
||||
return TogetherContextAggregatorPair(_user=user, _assistant=assistant)
|
||||
|
||||
async def set_frequency_penalty(self, frequency_penalty: float):
|
||||
logger.debug(f"Switching LLM frequency_penalty to: [{frequency_penalty}]")
|
||||
self._frequency_penalty = frequency_penalty
|
||||
|
||||
async def set_max_tokens(self, max_tokens: int):
|
||||
logger.debug(f"Switching LLM max_tokens to: [{max_tokens}]")
|
||||
self._max_tokens = max_tokens
|
||||
|
||||
async def set_presence_penalty(self, presence_penalty: float):
|
||||
logger.debug(f"Switching LLM presence_penalty to: [{presence_penalty}]")
|
||||
self._presence_penalty = presence_penalty
|
||||
|
||||
async def set_temperature(self, temperature: float):
|
||||
logger.debug(f"Switching LLM temperature to: [{temperature}]")
|
||||
self._temperature = temperature
|
||||
|
||||
async def set_top_k(self, top_k: float):
|
||||
logger.debug(f"Switching LLM top_k to: [{top_k}]")
|
||||
self._top_k = top_k
|
||||
|
||||
async def set_top_p(self, top_p: float):
|
||||
logger.debug(f"Switching LLM top_p to: [{top_p}]")
|
||||
self._top_p = top_p
|
||||
|
||||
async def set_extra(self, extra: Dict[str, Any]):
|
||||
logger.debug(f"Switching LLM extra to: [{extra}]")
|
||||
self._extra = extra
|
||||
|
||||
async def _update_settings(self, frame: LLMUpdateSettingsFrame):
|
||||
if frame.model is not None:
|
||||
logger.debug(f"Switching LLM model to: [{frame.model}]")
|
||||
self.set_model_name(frame.model)
|
||||
if frame.frequency_penalty is not None:
|
||||
await self.set_frequency_penalty(frame.frequency_penalty)
|
||||
if frame.max_tokens is not None:
|
||||
await self.set_max_tokens(frame.max_tokens)
|
||||
if frame.presence_penalty is not None:
|
||||
await self.set_presence_penalty(frame.presence_penalty)
|
||||
if frame.temperature is not None:
|
||||
await self.set_temperature(frame.temperature)
|
||||
if frame.top_k is not None:
|
||||
await self.set_top_k(frame.top_k)
|
||||
if frame.top_p is not None:
|
||||
await self.set_top_p(frame.top_p)
|
||||
if frame.extra:
|
||||
await self.set_extra(frame.extra)
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext):
|
||||
try:
|
||||
await self.push_frame(LLMFullResponseStartFrame())
|
||||
await self.start_processing_metrics()
|
||||
|
||||
logger.debug(f"Generating chat: {context.get_messages_for_logging()}")
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
params = {
|
||||
"messages": context.messages,
|
||||
"model": self.model_name,
|
||||
"max_tokens": self._max_tokens,
|
||||
"stream": True,
|
||||
"frequency_penalty": self._frequency_penalty,
|
||||
"presence_penalty": self._presence_penalty,
|
||||
"temperature": self._temperature,
|
||||
"top_k": self._top_k,
|
||||
"top_p": self._top_p,
|
||||
}
|
||||
|
||||
params.update(self._extra)
|
||||
|
||||
stream = await self._client.chat.completions.create(**params)
|
||||
|
||||
# Function calling
|
||||
got_first_chunk = False
|
||||
accumulating_function_call = False
|
||||
function_call_accumulator = ""
|
||||
|
||||
async for chunk in stream:
|
||||
# logger.debug(f"Together LLM event: {chunk}")
|
||||
if chunk.usage:
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=chunk.usage.prompt_tokens,
|
||||
completion_tokens=chunk.usage.completion_tokens,
|
||||
total_tokens=chunk.usage.total_tokens,
|
||||
)
|
||||
await self.start_llm_usage_metrics(tokens)
|
||||
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
|
||||
if not got_first_chunk:
|
||||
await self.stop_ttfb_metrics()
|
||||
if chunk.choices[0].delta.content:
|
||||
got_first_chunk = True
|
||||
if chunk.choices[0].delta.content[0] == "<":
|
||||
accumulating_function_call = True
|
||||
|
||||
if chunk.choices[0].delta.content:
|
||||
if accumulating_function_call:
|
||||
function_call_accumulator += chunk.choices[0].delta.content
|
||||
else:
|
||||
await self.push_frame(TextFrame(chunk.choices[0].delta.content))
|
||||
|
||||
if chunk.choices[0].finish_reason == "eos" and accumulating_function_call:
|
||||
await self._extract_function_call(context, function_call_accumulator)
|
||||
|
||||
except CancelledError:
|
||||
# todo: implement token counting estimates for use when the user interrupts a long generation
|
||||
# we do this in the anthropic.py service
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception(f"{self} exception: {e}")
|
||||
finally:
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
context = None
|
||||
if isinstance(frame, OpenAILLMContextFrame):
|
||||
context = frame.context
|
||||
elif isinstance(frame, LLMMessagesFrame):
|
||||
context = TogetherLLMContext.from_messages(frame.messages)
|
||||
elif isinstance(frame, LLMUpdateSettingsFrame):
|
||||
await self._update_settings(frame)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
if context:
|
||||
await self._process_context(context)
|
||||
|
||||
async def _extract_function_call(self, context, function_call_accumulator):
|
||||
context.add_message({"role": "assistant", "content": function_call_accumulator})
|
||||
|
||||
function_regex = r"<function=(\w+)>(.*?)</function>"
|
||||
match = re.search(function_regex, function_call_accumulator)
|
||||
if match:
|
||||
function_name, args_string = match.groups()
|
||||
try:
|
||||
arguments = json.loads(args_string)
|
||||
await self.call_function(
|
||||
context=context,
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
function_name=function_name,
|
||||
arguments=arguments,
|
||||
def create_client(self, api_key=None, base_url=None, **kwargs):
|
||||
logger.debug(f"Creating Together.ai client with api {base_url}")
|
||||
return AsyncOpenAI(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
http_client=DefaultAsyncHttpxClient(
|
||||
limits=httpx.Limits(
|
||||
max_keepalive_connections=100, max_connections=1000, keepalive_expiry=None
|
||||
)
|
||||
return
|
||||
except json.JSONDecodeError as error:
|
||||
# We get here if the LLM returns a function call with invalid JSON arguments. This could happen
|
||||
# because of LLM non-determinism, or maybe more often because of user error in the prompt.
|
||||
# Should we do anything more than log a warning?
|
||||
logger.debug(f"Error parsing function arguments: {error}")
|
||||
|
||||
|
||||
class TogetherLLMContext(OpenAILLMContext):
|
||||
def __init__(
|
||||
self,
|
||||
messages: list[dict] | None = None,
|
||||
):
|
||||
super().__init__(messages=messages)
|
||||
|
||||
@classmethod
|
||||
def from_openai_context(cls, openai_context: OpenAILLMContext):
|
||||
self = cls(
|
||||
messages=openai_context.messages,
|
||||
),
|
||||
)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_messages(cls, messages: List[dict]) -> "TogetherLLMContext":
|
||||
return cls(messages=messages)
|
||||
|
||||
def add_message(self, message):
|
||||
try:
|
||||
self.messages.append(message)
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding message: {e}")
|
||||
|
||||
def get_messages_for_logging(self) -> str:
|
||||
return json.dumps(self.messages)
|
||||
|
||||
|
||||
class TogetherUserContextAggregator(LLMUserContextAggregator):
|
||||
def __init__(self, context: OpenAILLMContext | TogetherLLMContext):
|
||||
super().__init__(context=context)
|
||||
|
||||
if isinstance(context, OpenAILLMContext):
|
||||
self._context = TogetherLLMContext.from_openai_context(context)
|
||||
|
||||
async def push_messages_frame(self):
|
||||
frame = OpenAILLMContextFrame(self._context)
|
||||
await self.push_frame(frame)
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
await super().process_frame(frame, direction)
|
||||
# Our parent method has already called push_frame(). So we can't interrupt the
|
||||
# flow here and we don't need to call push_frame() ourselves. Possibly something
|
||||
# to talk through (tagging @aleix). At some point we might need to refactor these
|
||||
# context aggregators.
|
||||
try:
|
||||
if isinstance(frame, UserImageRequestFrame):
|
||||
# The LLM sends a UserImageRequestFrame upstream. Cache any context provided with
|
||||
# that frame so we can use it when we assemble the image message in the assistant
|
||||
# context aggregator.
|
||||
if frame.context:
|
||||
if isinstance(frame.context, str):
|
||||
self._context._user_image_request_context[frame.user_id] = frame.context
|
||||
else:
|
||||
logger.error(
|
||||
f"Unexpected UserImageRequestFrame context type: {type(frame.context)}"
|
||||
)
|
||||
del self._context._user_image_request_context[frame.user_id]
|
||||
else:
|
||||
if frame.user_id in self._context._user_image_request_context:
|
||||
del self._context._user_image_request_context[frame.user_id]
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing frame: {e}")
|
||||
|
||||
|
||||
#
|
||||
# Claude returns a text content block along with a tool use content block. This works quite nicely
|
||||
# with streaming. We get the text first, so we can start streaming it right away. Then we get the
|
||||
# tool_use block. While the text is streaming to TTS and the transport, we can run the tool call.
|
||||
#
|
||||
# But Claude is verbose. It would be nice to come up with prompt language that suppresses Claude's
|
||||
# chattiness about it's tool thinking.
|
||||
#
|
||||
|
||||
|
||||
class TogetherAssistantContextAggregator(LLMAssistantContextAggregator):
|
||||
def __init__(self, user_context_aggregator: TogetherUserContextAggregator):
|
||||
super().__init__(context=user_context_aggregator._context)
|
||||
self._user_context_aggregator = user_context_aggregator
|
||||
self._function_call_in_progress = None
|
||||
self._function_call_result = None
|
||||
|
||||
async def process_frame(self, frame, direction):
|
||||
await super().process_frame(frame, direction)
|
||||
# See note above about not calling push_frame() here.
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
self._function_call_in_progress = None
|
||||
self._function_call_finished = None
|
||||
elif isinstance(frame, FunctionCallInProgressFrame):
|
||||
self._function_call_in_progress = frame
|
||||
elif isinstance(frame, FunctionCallResultFrame):
|
||||
if (
|
||||
self._function_call_in_progress
|
||||
and self._function_call_in_progress.tool_call_id == frame.tool_call_id
|
||||
):
|
||||
self._function_call_in_progress = None
|
||||
self._function_call_result = frame
|
||||
await self._push_aggregation()
|
||||
else:
|
||||
logger.warning(
|
||||
"FunctionCallResultFrame tool_call_id does not match FunctionCallInProgressFrame tool_call_id"
|
||||
)
|
||||
self._function_call_in_progress = None
|
||||
self._function_call_result = None
|
||||
|
||||
def add_message(self, message):
|
||||
self._user_context_aggregator.add_message(message)
|
||||
|
||||
async def _push_aggregation(self):
|
||||
if not (self._aggregation or self._function_call_result):
|
||||
return
|
||||
|
||||
run_llm = False
|
||||
|
||||
aggregation = self._aggregation
|
||||
self._aggregation = ""
|
||||
|
||||
try:
|
||||
if self._function_call_result:
|
||||
frame = self._function_call_result
|
||||
self._function_call_result = None
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "tool",
|
||||
# Together expects the content here to be a string, so stringify it
|
||||
"content": str(frame.result),
|
||||
}
|
||||
)
|
||||
run_llm = True
|
||||
else:
|
||||
self._context.add_message({"role": "assistant", "content": aggregation})
|
||||
|
||||
if run_llm:
|
||||
await self._user_context_aggregator.push_messages_frame()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing frame: {e}")
|
||||
|
||||
@@ -4,10 +4,12 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import aiohttp
|
||||
|
||||
from typing import Any, AsyncGenerator, Dict
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
@@ -17,10 +19,7 @@ from pipecat.frames.frames import (
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import TTSService
|
||||
|
||||
import numpy as np
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
try:
|
||||
import resampy
|
||||
@@ -43,25 +42,70 @@ class XTTSService(TTSService):
|
||||
self,
|
||||
*,
|
||||
voice_id: str,
|
||||
language: str,
|
||||
language: Language,
|
||||
base_url: str,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._voice_id = voice_id
|
||||
self._language = language
|
||||
self._base_url = base_url
|
||||
self._settings = {
|
||||
"language": self.language_to_service_language(language),
|
||||
"base_url": base_url,
|
||||
}
|
||||
self.set_voice(voice_id)
|
||||
self._studio_speakers: Dict[str, Any] | None = None
|
||||
self._aiohttp_session = aiohttp_session
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> str | None:
|
||||
match language:
|
||||
case Language.CS:
|
||||
return "cs"
|
||||
case Language.DE:
|
||||
return "de"
|
||||
case (
|
||||
Language.EN
|
||||
| Language.EN_US
|
||||
| Language.EN_AU
|
||||
| Language.EN_GB
|
||||
| Language.EN_NZ
|
||||
| Language.EN_IN
|
||||
):
|
||||
return "en"
|
||||
case Language.ES:
|
||||
return "es"
|
||||
case Language.FR:
|
||||
return "fr"
|
||||
case Language.HI:
|
||||
return "hi"
|
||||
case Language.HU:
|
||||
return "hu"
|
||||
case Language.IT:
|
||||
return "it"
|
||||
case Language.JA:
|
||||
return "ja"
|
||||
case Language.KO:
|
||||
return "ko"
|
||||
case Language.NL:
|
||||
return "nl"
|
||||
case Language.PL:
|
||||
return "pl"
|
||||
case Language.PT | Language.PT_BR:
|
||||
return "pt"
|
||||
case Language.RU:
|
||||
return "ru"
|
||||
case Language.TR:
|
||||
return "tr"
|
||||
case Language.ZH:
|
||||
return "zh-cn"
|
||||
return None
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
async with self._aiohttp_session.get(self._base_url + "/studio_speakers") as r:
|
||||
async with self._aiohttp_session.get(self._settings["base_url"] + "/studio_speakers") as r:
|
||||
if r.status != 200:
|
||||
text = await r.text()
|
||||
logger.error(
|
||||
@@ -75,10 +119,6 @@ class XTTSService(TTSService):
|
||||
return
|
||||
self._studio_speakers = await r.json()
|
||||
|
||||
async def set_voice(self, voice: str):
|
||||
logger.debug(f"Switching TTS voice to: [{voice}]")
|
||||
self._voice_id = voice
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
@@ -88,11 +128,11 @@ class XTTSService(TTSService):
|
||||
|
||||
embeddings = self._studio_speakers[self._voice_id]
|
||||
|
||||
url = self._base_url + "/tts_stream"
|
||||
url = self._settings["base_url"] + "/tts_stream"
|
||||
|
||||
payload = {
|
||||
"text": text.replace(".", "").replace("*", ""),
|
||||
"language": self._language,
|
||||
"language": self._settings["language"],
|
||||
"speaker_embedding": embeddings["speaker_embedding"],
|
||||
"gpt_cond_latent": embeddings["gpt_cond_latent"],
|
||||
"add_wav_header": False,
|
||||
@@ -110,7 +150,7 @@ class XTTSService(TTSService):
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
|
||||
await self.push_frame(TTSStartedFrame())
|
||||
yield TTSStartedFrame()
|
||||
|
||||
buffer = bytearray()
|
||||
async for chunk in r.content.iter_chunked(1024):
|
||||
@@ -146,4 +186,4 @@ class XTTSService(TTSService):
|
||||
frame = TTSAudioRawFrame(resampled_audio_bytes, 16000, 1)
|
||||
yield frame
|
||||
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -5,17 +5,17 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
BotInterruptionFrame,
|
||||
CancelFrame,
|
||||
InputAudioRawFrame,
|
||||
StartFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
StartFrame,
|
||||
StartInterruptionFrame,
|
||||
StopInterruptionFrame,
|
||||
SystemFrame,
|
||||
@@ -23,15 +23,14 @@ from pipecat.frames.frames import (
|
||||
UserStoppedSpeakingFrame,
|
||||
VADParamsUpdateFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
from pipecat.vad.vad_analyzer import VADAnalyzer, VADState
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class BaseInputTransport(FrameProcessor):
|
||||
def __init__(self, params: TransportParams, **kwargs):
|
||||
super().__init__(sync=False, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._params = params
|
||||
|
||||
@@ -87,6 +86,7 @@ class BaseInputTransport(FrameProcessor):
|
||||
elif isinstance(frame, BotInterruptionFrame):
|
||||
logger.debug("Bot interruption")
|
||||
await self._start_interruption()
|
||||
await self.push_frame(StartInterruptionFrame())
|
||||
# All other system frames
|
||||
elif isinstance(frame, SystemFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -33,6 +33,7 @@ from pipecat.frames.frames import (
|
||||
TTSStoppedFrame,
|
||||
TextFrame,
|
||||
TransportMessageFrame,
|
||||
TransportMessageUrgentFrame,
|
||||
)
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
|
||||
@@ -43,7 +44,7 @@ from pipecat.utils.time import nanoseconds_to_seconds
|
||||
|
||||
class BaseOutputTransport(FrameProcessor):
|
||||
def __init__(self, params: TransportParams, **kwargs):
|
||||
super().__init__(sync=False, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._params = params
|
||||
|
||||
@@ -148,7 +149,7 @@ class BaseOutputTransport(FrameProcessor):
|
||||
await self._audio_out_task
|
||||
self._audio_out_task = None
|
||||
|
||||
async def send_message(self, frame: TransportMessageFrame):
|
||||
async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
|
||||
pass
|
||||
|
||||
async def send_metrics(self, frame: MetricsFrame):
|
||||
@@ -180,12 +181,14 @@ class BaseOutputTransport(FrameProcessor):
|
||||
elif isinstance(frame, CancelFrame):
|
||||
await self.cancel(frame)
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, StartInterruptionFrame) or isinstance(frame, StopInterruptionFrame):
|
||||
elif isinstance(frame, (StartInterruptionFrame, StopInterruptionFrame)):
|
||||
await self.push_frame(frame, direction)
|
||||
await self._handle_interruptions(frame)
|
||||
elif isinstance(frame, MetricsFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
await self.send_metrics(frame)
|
||||
elif isinstance(frame, TransportMessageUrgentFrame):
|
||||
await self.send_message(frame)
|
||||
elif isinstance(frame, SystemFrame):
|
||||
await self.push_frame(frame, direction)
|
||||
# Control frames.
|
||||
@@ -196,10 +199,8 @@ class BaseOutputTransport(FrameProcessor):
|
||||
# Other frames.
|
||||
elif isinstance(frame, OutputAudioRawFrame):
|
||||
await self._handle_audio(frame)
|
||||
elif isinstance(frame, OutputImageRawFrame) or isinstance(frame, SpriteFrame):
|
||||
elif isinstance(frame, (OutputImageRawFrame, SpriteFrame)):
|
||||
await self._handle_image(frame)
|
||||
elif isinstance(frame, TransportMessageFrame) and frame.urgent:
|
||||
await self.send_message(frame)
|
||||
# TODO(aleix): Images and audio should support presentation timestamps.
|
||||
elif frame.pts:
|
||||
await self._sink_clock_queue.put((frame.pts, frame.id, frame))
|
||||
|
||||
@@ -35,6 +35,7 @@ from pipecat.frames.frames import (
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
TransportMessageFrame,
|
||||
TransportMessageUrgentFrame,
|
||||
UserImageRawFrame,
|
||||
UserImageRequestFrame,
|
||||
)
|
||||
@@ -70,6 +71,11 @@ class DailyTransportMessageFrame(TransportMessageFrame):
|
||||
participant_id: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DailyTransportMessageUrgentFrame(TransportMessageUrgentFrame):
|
||||
participant_id: str | None = None
|
||||
|
||||
|
||||
class WebRTCVADAnalyzer(VADAnalyzer):
|
||||
def __init__(self, *, sample_rate=16000, num_channels=1, params: VADParams = VADParams()):
|
||||
super().__init__(sample_rate=sample_rate, num_channels=num_channels, params=params)
|
||||
@@ -234,12 +240,12 @@ class DailyTransportClient(EventHandler):
|
||||
def set_callbacks(self, callbacks: DailyCallbacks):
|
||||
self._callbacks = callbacks
|
||||
|
||||
async def send_message(self, frame: TransportMessageFrame):
|
||||
async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
|
||||
if not self._client:
|
||||
return
|
||||
|
||||
participant_id = None
|
||||
if isinstance(frame, DailyTransportMessageFrame):
|
||||
if isinstance(frame, (DailyTransportMessageFrame, DailyTransportMessageUrgentFrame)):
|
||||
participant_id = frame.participant_id
|
||||
|
||||
future = self._loop.create_future()
|
||||
@@ -736,7 +742,7 @@ class DailyOutputTransport(BaseOutputTransport):
|
||||
await super().cleanup()
|
||||
await self._client.cleanup()
|
||||
|
||||
async def send_message(self, frame: TransportMessageFrame):
|
||||
async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
|
||||
await self._client.send_message(frame)
|
||||
|
||||
async def send_metrics(self, frame: MetricsFrame):
|
||||
|
||||
@@ -22,6 +22,7 @@ from pipecat.frames.frames import (
|
||||
MetricsFrame,
|
||||
StartFrame,
|
||||
TransportMessageFrame,
|
||||
TransportMessageUrgentFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import (
|
||||
LLMUsageMetricsData,
|
||||
@@ -51,6 +52,11 @@ class LiveKitTransportMessageFrame(TransportMessageFrame):
|
||||
participant_id: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LiveKitTransportMessageUrgentFrame(TransportMessageUrgentFrame):
|
||||
participant_id: str | None = None
|
||||
|
||||
|
||||
class LiveKitParams(TransportParams):
|
||||
audio_out_sample_rate: int = 48000
|
||||
audio_out_channels: int = 1
|
||||
@@ -420,8 +426,8 @@ class LiveKitOutputTransport(BaseOutputTransport):
|
||||
await super().cancel(frame)
|
||||
await self._client.disconnect()
|
||||
|
||||
async def send_message(self, frame: TransportMessageFrame):
|
||||
if isinstance(frame, LiveKitTransportMessageFrame):
|
||||
async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame):
|
||||
if isinstance(frame, (LiveKitTransportMessageFrame, LiveKitTransportMessageUrgentFrame)):
|
||||
await self._client.send_data(frame.message.encode(), frame.participant_id)
|
||||
else:
|
||||
await self._client.send_data(frame.message.encode())
|
||||
@@ -596,6 +602,13 @@ class LiveKitTransport(BaseTransport):
|
||||
frame = LiveKitTransportMessageFrame(message=message, participant_id=participant_id)
|
||||
await self._output.send_message(frame)
|
||||
|
||||
async def send_message_urgent(self, message: str, participant_id: str | None = None):
|
||||
if self._output:
|
||||
frame = LiveKitTransportMessageUrgentFrame(
|
||||
message=message, participant_id=participant_id
|
||||
)
|
||||
await self._output.send_message(frame)
|
||||
|
||||
async def cleanup(self):
|
||||
if self._input:
|
||||
await self._input.cleanup()
|
||||
|
||||
@@ -6,7 +6,6 @@
|
||||
|
||||
import re
|
||||
|
||||
|
||||
ENDOFSENTENCE_PATTERN_STR = r"""
|
||||
(?<![A-Z]) # Negative lookbehind: not preceded by an uppercase letter (e.g., "U.S.A.")
|
||||
(?<!\d) # Negative lookbehind: not preceded by a digit (e.g., "1. Let's start")
|
||||
@@ -21,5 +20,6 @@ ENDOFSENTENCE_PATTERN_STR = r"""
|
||||
ENDOFSENTENCE_PATTERN = re.compile(ENDOFSENTENCE_PATTERN_STR, re.VERBOSE)
|
||||
|
||||
|
||||
def match_endofsentence(text: str) -> bool:
|
||||
return ENDOFSENTENCE_PATTERN.search(text.rstrip()) is not None
|
||||
def match_endofsentence(text: str) -> int:
|
||||
match = ENDOFSENTENCE_PATTERN.search(text.rstrip())
|
||||
return match.end() if match else 0
|
||||
|
||||
0
src/pipecat/utils/text/__init__.py
Normal file
0
src/pipecat/utils/text/__init__.py
Normal file
18
src/pipecat/utils/text/base_text_filter.py
Normal file
18
src/pipecat/utils/text/base_text_filter.py
Normal file
@@ -0,0 +1,18 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Mapping
|
||||
|
||||
|
||||
class BaseTextFilter(ABC):
|
||||
@abstractmethod
|
||||
def update_settings(self, settings: Mapping[str, Any]):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def filter(self, text: str) -> str:
|
||||
pass
|
||||
84
src/pipecat/utils/text/markdown_text_filter.py
Normal file
84
src/pipecat/utils/text/markdown_text_filter.py
Normal file
@@ -0,0 +1,84 @@
|
||||
#
|
||||
# Copyright (c) 2024, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import re
|
||||
from typing import Any, Mapping
|
||||
|
||||
from markdown import Markdown
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.utils.text.base_text_filter import BaseTextFilter
|
||||
|
||||
|
||||
class MarkdownTextFilter(BaseTextFilter):
|
||||
"""Removes Markdown formatting from text in TextFrames.
|
||||
|
||||
Converts Markdown to plain text while preserving the overall structure,
|
||||
including leading and trailing spaces. Handles special cases like
|
||||
asterisks and table formatting.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
enable_text_filter: bool = True
|
||||
|
||||
def __init__(self, params: InputParams = InputParams(), **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._settings = params
|
||||
|
||||
def update_settings(self, settings: Mapping[str, Any]):
|
||||
for key, value in settings.items():
|
||||
if hasattr(self._settings, key):
|
||||
setattr(self._settings, key, value)
|
||||
|
||||
def filter(self, text: str) -> str:
|
||||
if self._settings.enable_text_filter:
|
||||
# Replace newlines with spaces only when there's no text before or after
|
||||
text = re.sub(r"^\s*\n", " ", text, flags=re.MULTILINE)
|
||||
|
||||
# Remove repeated sequences of 5 or more characters
|
||||
text = re.sub(r"(\S)(\1{4,})", "", text)
|
||||
|
||||
# Preserve numbered list items with a unique marker, §NUM§
|
||||
text = re.sub(r"^(\d+\.)\s", r"§NUM§\1 ", text)
|
||||
|
||||
# Preserve leading/trailing spaces with a unique marker, §
|
||||
# Critical for word-by-word streaming in bot-tts-text
|
||||
preserved_markdown = re.sub(
|
||||
r"^( +)|\s+$", lambda m: "§" * len(m.group(0)), text, flags=re.MULTILINE
|
||||
)
|
||||
|
||||
# Convert markdown to HTML
|
||||
md = Markdown()
|
||||
html = md.convert(preserved_markdown)
|
||||
|
||||
# Remove HTML tags
|
||||
filtered_text = re.sub("<[^<]+?>", "", html)
|
||||
|
||||
# Replace HTML entities
|
||||
filtered_text = filtered_text.replace(" ", " ")
|
||||
filtered_text = filtered_text.replace("<", "<")
|
||||
filtered_text = filtered_text.replace(">", ">")
|
||||
filtered_text = filtered_text.replace("&", "&")
|
||||
|
||||
# Remove double asterisks (consecutive without any exceptions)
|
||||
filtered_text = re.sub(r"\*\*", "", filtered_text)
|
||||
|
||||
# Remove single asterisks at the start or end of words
|
||||
filtered_text = re.sub(r"(^|\s)\*|\*($|\s)", r"\1\2", filtered_text)
|
||||
|
||||
# Remove Markdown table formatting
|
||||
filtered_text = re.sub(r"\|", "", filtered_text)
|
||||
filtered_text = re.sub(r"^\s*[-:]+\s*$", "", filtered_text, flags=re.MULTILINE)
|
||||
|
||||
# Restore numbered list items
|
||||
filtered_text = filtered_text.replace("§NUM§", "")
|
||||
|
||||
# Restore leading and trailing spaces
|
||||
filtered_text = re.sub("§", " ", filtered_text)
|
||||
|
||||
return filtered_text
|
||||
else:
|
||||
return text
|
||||
@@ -2,10 +2,10 @@ aiohttp~=3.10.3
|
||||
anthropic~=0.30.0
|
||||
azure-cognitiveservices-speech~=1.40.0
|
||||
boto3~=1.35.27
|
||||
daily-python~=0.10.1
|
||||
daily-python~=0.11.0
|
||||
deepgram-sdk~=3.5.0
|
||||
fal-client~=0.4.1
|
||||
fastapi~=0.112.1
|
||||
fastapi~=0.115.0
|
||||
faster-whisper~=1.0.3
|
||||
google-cloud-texttospeech~=2.17.2
|
||||
google-generativeai~=0.7.2
|
||||
|
||||
@@ -7,9 +7,9 @@
|
||||
import unittest
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMFullResponseStartFrame,
|
||||
StopTaskFrame,
|
||||
TextFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
@@ -32,6 +32,7 @@ from langchain_core.language_models import FakeStreamingListLLM
|
||||
class TestLangchain(unittest.IsolatedAsyncioTestCase):
|
||||
class MockProcessor(FrameProcessor):
|
||||
def __init__(self, name):
|
||||
super().__init__()
|
||||
self.name = name
|
||||
self.token: list[str] = []
|
||||
# Start collecting tokens when we see the start frame
|
||||
@@ -55,13 +56,13 @@ class TestLangchain(unittest.IsolatedAsyncioTestCase):
|
||||
def setUp(self):
|
||||
self.expected_response = "Hello dear human"
|
||||
self.fake_llm = FakeStreamingListLLM(responses=[self.expected_response])
|
||||
self.mock_proc = self.MockProcessor("token_collector")
|
||||
|
||||
async def test_langchain(self):
|
||||
messages = [("system", "Say hello to {name}"), ("human", "{input}")]
|
||||
prompt = ChatPromptTemplate.from_messages(messages).partial(name="Thomas")
|
||||
chain = prompt | self.fake_llm
|
||||
proc = LangchainProcessor(chain=chain)
|
||||
self.mock_proc = self.MockProcessor("token_collector")
|
||||
|
||||
tma_in = LLMUserResponseAggregator(messages)
|
||||
tma_out = LLMAssistantResponseAggregator(messages)
|
||||
@@ -81,7 +82,7 @@ class TestLangchain(unittest.IsolatedAsyncioTestCase):
|
||||
UserStartedSpeakingFrame(),
|
||||
TranscriptionFrame(text="Hi World", user_id="user", timestamp="now"),
|
||||
UserStoppedSpeakingFrame(),
|
||||
StopTaskFrame(),
|
||||
EndFrame(),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user