Compare commits
35 Commits
aleix/clau
...
hush/conte
| Author | SHA1 | Date | |
|---|---|---|---|
| | 77c82c64c0 | | |
| | b5e79f9dc5 | | |
| | 613b96819f | | |
| | 57c24670ea | | |
| | d79dd94019 | | |
| | fa8e7458e1 | | |
| | 4d66191963 | | |
| | 7e9d67002e | | |
| | ffbb6e5937 | | |
| | 535b85cf90 | | |
| | 8dc9872ed5 | | |
| | f37a53cc25 | | |
| | 9cce28c64c | | |
| | 3ca94363ec | | |
| | 050f287ec4 | | |
| | e6f5561785 | | |
| | 2df91f4b37 | | |
| | 7db49b9067 | | |
| | 7c497bdc89 | | |
| | 1aa4247d2b | | |
| | acba544e6f | | |
| | 5d93c64ee5 | | |
| | de10bc8803 | | |
| | 36f5c1722d | | |
| | a8280522e5 | | |
| | 05d65dfdd3 | | |
| | a3962e3b47 | | |
| | cd231cf829 | | |
| | 9fafc1692d | | |
| | 7648d0436c | | |
| | bff8747e38 | | |
| | d227c0c097 | | |
| | 9ccde60521 | | |
| | b84a40666c | | |
| | e72b135a4c | | |
52
CHANGELOG.md
@@ -7,24 +7,76 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added

- Added `wait_for_all` argument to the base `LLMService`. When enabled, this
  ensures all function calls complete before returning results to the LLM (i.e.,
before running a new inference with those results).
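
The behaviour described for `wait_for_all` can be sketched with plain `asyncio`. This is a minimal illustration, not the `LLMService` implementation: `fake_function_call` and `run_and_wait_for_all` are hypothetical names, and the delays stand in for real function call work.

```python
import asyncio
from typing import List, Tuple


async def fake_function_call(name: str, delay: float) -> str:
    # Hypothetical stand-in for a registered function call handler.
    await asyncio.sleep(delay)
    return f"{name} done"


async def run_and_wait_for_all(calls: List[Tuple[str, float]]) -> list:
    # Start every call as its own task so they run concurrently.
    tasks = [asyncio.create_task(fake_function_call(n, d)) for n, d in calls]
    # gather() resolves only once every task has finished; results keep the
    # order of `calls`, and exceptions are returned instead of raised.
    return await asyncio.gather(*tasks, return_exceptions=True)


results = asyncio.run(run_and_wait_for_all([("get_weather", 0.02), ("get_time", 0.01)]))
print(results)
```

Even though the second call finishes first, nothing is handed back until both are complete, which is the point at which a new inference could safely run.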

### Changed

- Improved interruption handling to prevent bots from repeating themselves.
  Responses from LLM services that return multiple sentences at once (e.g.,
  `GoogleLLMService`) are now split into individual sentences before being sent
  to TTS. This ensures interruptions occur at sentence boundaries, preventing
  the bot from repeating content after being interrupted during long responses.

- Text Aggregation Improvements:

  - **Breaking Change**: `BaseTextAggregator.aggregate()` now returns
    `AsyncIterator[Aggregation]` instead of `Optional[Aggregation]`. This
    enables the aggregator to return multiple results based on the provided
    text.
  - Refactored text aggregators to use inheritance: `SkipTagsAggregator` and
    `PatternPairAggregator` now inherit from `SimpleTextAggregator`, reusing
    the base class's sentence detection logic.

- Updated `AICFilter` to use Quail STT as the default model
  (`AICModelType.QUAIL_STT`). Quail STT is optimized for human-to-machine
  interaction (e.g., voice agents, speech-to-text) and operates at a native
  sample rate of 16 kHz with fixed enhancement parameters.

- Updated Deepgram logging to include Deepgram request IDs for improved debugging.
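
The breaking change to `aggregate()` listed above means callers iterate with `async for` instead of checking a single optional result. The sketch below shows the calling pattern with a hypothetical `SentenceSketchAggregator`; it is not pipecat's `BaseTextAggregator`, it yields plain strings rather than `Aggregation` objects, and its sentence detection is deliberately naive (it would split after "Mr.", for example).

```python
import asyncio
import re
from typing import AsyncIterator, List


class SentenceSketchAggregator:
    """Hypothetical stand-in used only to show the async-iterator shape."""

    def __init__(self) -> None:
        self._buffer = ""

    async def aggregate(self, text: str) -> AsyncIterator[str]:
        # Append the new chunk, then yield every complete sentence it contains.
        self._buffer += text
        while True:
            match = re.search(r"[.!?]\s+", self._buffer)
            if not match:
                break
            sentence = self._buffer[: match.end()].strip()
            self._buffer = self._buffer[match.end():]
            yield sentence

    async def flush(self) -> str:
        # Return whatever is left once the response has ended.
        remaining, self._buffer = self._buffer.strip(), ""
        return remaining


async def collect(chunks: List[str]) -> List[str]:
    aggregator = SentenceSketchAggregator()
    out: List[str] = []
    for chunk in chunks:
        # One chunk may now produce zero, one, or several results.
        async for sentence in aggregator.aggregate(chunk):
            out.append(sentence)
    remaining = await aggregator.flush()
    if remaining:
        out.append(remaining)
    return out


sentences = asyncio.run(collect(["Hello there. How are", " you? Fine"]))
print(sentences)
```

Code that previously did `result = await aggregator.aggregate(text)` followed by `if result:` would need the same loop-based rewrite.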

### Deprecated

- Package `pipecat.sync` is deprecated, use `pipecat.utils.sync` instead.

- The `noise_gate_enable` parameter in `AICFilter` is deprecated and no longer
  has any effect. Noise gating is now handled automatically by the AIC VAD
  system. Use `AICFilter.create_vad_analyzer()` for VAD functionality instead.

- NVIDIA Services name changes (all functionality is unchanged):

  - `NimLLMService` is now deprecated, use `NvidiaLLMService` instead.
  - `RivaSTTService` is now deprecated, use `NvidiaSTTService` instead.
  - `RivaTTSService` is now deprecated, use `NvidiaTTSService` instead.
  - Use `uv pip install pipecat-ai[nvidia]` instead of
`uv pip install pipecat-ai[riva]`
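
One common way to keep an old class name working while steering users to the new one is a thin subclass that warns on construction. The sketch below illustrates that general pattern only; the `...Sketch` classes are hypothetical stand-ins, and the renamed services may be wired differently in the actual codebase.

```python
import warnings


class NvidiaLLMServiceSketch:
    """Hypothetical stand-in for the renamed service class."""

    def __init__(self, model: str) -> None:
        self.model = model


class NimLLMServiceSketch(NvidiaLLMServiceSketch):
    """Old name kept as an alias that emits a DeprecationWarning."""

    def __init__(self, *args, **kwargs) -> None:
        warnings.warn(
            "NimLLMServiceSketch is deprecated, use NvidiaLLMServiceSketch instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)


# Capture warnings so the example can show what a caller would see.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old = NimLLMServiceSketch(model="example-model")

print(type(old).__name__, [str(w.message) for w in caught])
```

Because the alias inherits everything from the new class, existing code keeps running while the warning points at the replacement name.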

### Fixed

- Fixed an issue where `LLMTextFrame.skip_tts` was being overwritten by LLM
  services.

- Fixed sentence aggregation to correctly handle ambiguous punctuation in
  streaming text, such as currency ("$29.95") and abbreviations ("Mr. Smith").

- Fixed bug in `PatternPairAggregator` where pattern handlers could be called
  multiple times for `KEEP` or `AGGREGATE` patterns.

- Fixed an issue in `SarvamTTSService` where the last sentence was not being
  spoken. Now, audio is flushed when the TTS service receives the
  `LLMFullResponseEndFrame` or `EndFrame`.

- Fixed an issue in `AWSTranscribeSTTService` where the `region` arg was
  always set to `us-east-1` when providing an `AWS_REGION` env var.

- Fixed an issue in `DeepgramTTSService` where a `TTSStoppedFrame` was
  incorrectly pushed after a function call. This caused an issue with the
  voice-ui-kit's conversational panel rendering of the LLM output after a
function call.
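
The `skip_tts` fix in the list above relies on treating the flag as tri-state: `None` means "the service has no opinion", so a value already set on the frame survives. The sketch below mirrors that idea with hypothetical names (`TextFrameSketch`, `apply_service_setting`); it is a simplified illustration, not the frame or service classes themselves.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class TextFrameSketch:
    """Hypothetical frame with a tri-state skip_tts flag."""

    text: str
    skip_tts: Optional[bool] = field(init=False, default=None)


def apply_service_setting(frame: TextFrameSketch, service_skip_tts: Optional[bool]) -> TextFrameSketch:
    # Only override the frame when the service has an explicit setting.
    if service_skip_tts is not None:
        frame.skip_tts = service_skip_tts
    return frame


# A value set upstream is preserved when the service setting is None.
upstream = TextFrameSketch("hello")
upstream.skip_tts = True
kept = apply_service_setting(upstream, None).skip_tts

# An explicit service setting still wins.
forced = apply_service_setting(TextFrameSketch("hi"), False).skip_tts

print(kept, forced)
```

With a plain `bool` defaulting to `False`, the service could not tell "unset" from "explicitly off", which is how an upstream value could end up overwritten.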

## [0.0.96] - 2025-11-26 🦃 "Happy Thanksgiving!" 🦃

### Added

@@ -79,7 +79,7 @@ Once your PR is submitted, post in the `#community-integrations` Discord channel
|
||||
|
||||
**Examples:**
|
||||
|
||||
- [RivaSTTService](https://github.com/pipecat-ai/pipecat/blob/main/src/pipecat/services/riva/stt.py)
|
||||
- [NvidiaSTTService](https://github.com/pipecat-ai/pipecat/blob/main/src/pipecat/services/nvidia/stt.py)
|
||||
- [FalSTTService](https://github.com/pipecat-ai/pipecat/blob/main/src/pipecat/services/fal/stt.py)
|
||||
|
||||
#### Key requirements:
|
||||
|
||||
@@ -119,7 +119,6 @@ def import_core_modules():
|
||||
"pipecat.observers",
|
||||
"pipecat.runner",
|
||||
"pipecat.serializers",
|
||||
"pipecat.sync",
|
||||
"pipecat.transcriptions",
|
||||
"pipecat.utils",
|
||||
]
|
||||
|
||||
@@ -30,7 +30,6 @@ Quick Links
|
||||
Runner <api/pipecat.runner>
|
||||
Serializers <api/pipecat.serializers>
|
||||
Services <api/pipecat.services>
|
||||
Sync <api/pipecat.sync>
|
||||
Transcriptions <api/pipecat.transcriptions>
|
||||
Transports <api/pipecat.transports>
|
||||
Utils <api/pipecat.utils>
|
||||
Utils <api/pipecat.utils>
|
||||
|
||||
@@ -15,7 +15,7 @@ from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.riva.tts import FastPitchTTSService
|
||||
from pipecat.services.nvidia.tts import NvidiaTTSService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
@@ -36,7 +36,7 @@ transport_params = {
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
tts = FastPitchTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
tts = NvidiaTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
|
||||
task = PipelineTask(
|
||||
Pipeline([tts, transport.output()]),
|
||||
@@ -13,12 +13,13 @@ from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.frames.frames import Frame, LLMContextFrame, LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
@@ -30,6 +31,44 @@ from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
FILTERED_WORDS = ["apple", "banana", "car"]


class ContentFilterProcessor(FrameProcessor):
    """Processor that filters LLMContextFrames containing specific words.

    If the user's message contains any of the filtered words, the context
    is replaced with a message indicating the assistant cannot respond.
    """

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        await super().process_frame(frame, direction)

        if isinstance(frame, LLMContextFrame):
            # Check the last user message for filtered words
            messages = frame.context.messages
            if messages:
                last_message = messages[-1]
                content = last_message.get("content", "")
                if isinstance(content, str):
                    content_lower = content.lower()
                    if any(word in content_lower for word in FILTERED_WORDS):
                        logger.info(f"Filtered content detected: {content}")
                        # Create a new context with a filtered response instruction
                        filtered_context = LLMContext(
                            messages=[
                                {
                                    "role": "system",
                                    "content": "The user is asking about something you cannot give an answer about. Tell them you don't know how to respond.",
                                }
                            ]
                        )
                        await self.push_frame(LLMContextFrame(filtered_context), direction)
                        return

        await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
# We store functions so objects (e.g. SileroVADAnalyzer) don't get
|
||||
# instantiated. The function will be called when the desired transport gets
|
||||
# selected.
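
The word check inside `ContentFilterProcessor` can be exercised on its own. The sketch below reproduces the substring test from the diff and, as an assumption of what a stricter filter might look like, adds a hypothetical whole-word variant (`contains_filtered_word`) for comparison; neither function is part of the example file.

```python
import re

FILTERED_WORDS = ["apple", "banana", "car"]


def contains_filtered_substring(content: str) -> bool:
    # Same test as in the diff: case-insensitive substring match.
    content_lower = content.lower()
    return any(word in content_lower for word in FILTERED_WORDS)


def contains_filtered_word(content: str) -> bool:
    # Hypothetical stricter variant: match whole words only.
    pattern = r"\b(" + "|".join(map(re.escape, FILTERED_WORDS)) + r")\b"
    return re.search(pattern, content, flags=re.IGNORECASE) is not None


for text in ["I bought a Car", "That was scary", "Nice weather today"]:
    print(text, contains_filtered_substring(text), contains_filtered_word(text))
```

The substring form flags "scary" because it contains "car"; that may be acceptable for a demo, but a whole-word match avoids such false positives.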
|
||||
@@ -76,12 +115,14 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
|
||||
context = LLMContext(messages)
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
content_filter = ContentFilterProcessor()
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
context_aggregator.user(), # User responses
|
||||
content_filter, # Content filter
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
|
||||
@@ -22,9 +22,9 @@ from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.nim.llm import NimLLMService
|
||||
from pipecat.services.riva.stt import RivaSTTService
|
||||
from pipecat.services.riva.tts import RivaTTSService
|
||||
from pipecat.services.nvidia.llm import NvidiaLLMService
|
||||
from pipecat.services.nvidia.stt import NvidiaSTTService
|
||||
from pipecat.services.nvidia.tts import NvidiaTTSService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
@@ -59,11 +59,13 @@ transport_params = {
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = RivaSTTService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
stt = NvidiaSTTService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
|
||||
llm = NimLLMService(api_key=os.getenv("NVIDIA_API_KEY"), model="meta/llama-3.1-405b-instruct")
|
||||
llm = NvidiaLLMService(
|
||||
api_key=os.getenv("NVIDIA_API_KEY"), model="meta/llama-3.1-405b-instruct"
|
||||
)
|
||||
|
||||
tts = RivaTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
tts = NvidiaTTSService(api_key=os.getenv("NVIDIA_API_KEY"))
|
||||
|
||||
messages = [
|
||||
{
|
||||
@@ -27,7 +27,7 @@ from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.nim.llm import NimLLMService
|
||||
from pipecat.services.nvidia.llm import NvidiaLLMService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
@@ -75,11 +75,11 @@ async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
# text_filters=[MarkdownTextFilter()],
|
||||
)
|
||||
|
||||
llm = NimLLMService(
|
||||
llm = NvidiaLLMService(
|
||||
api_key=os.getenv("NVIDIA_API_KEY"),
|
||||
model="nvidia/llama-3.3-nemotron-super-49b-v1.5",
|
||||
# Recommended when turning thinking off
|
||||
params=NimLLMService.InputParams(temperature=0.0),
|
||||
params=NvidiaLLMService.InputParams(temperature=0.0),
|
||||
)
|
||||
# You can also register a function_name of None to get all functions
|
||||
# sent to the same callback with an additional function_name parameter.
|
||||
@@ -14,20 +14,13 @@ from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import (
|
||||
LLMRunFrame,
|
||||
LLMSetToolsFrame,
|
||||
LLMUpdateSettingsFrame,
|
||||
TranscriptionMessage,
|
||||
)
|
||||
from pipecat.frames.frames import LLMRunFrame, LLMSetToolsFrame, TranscriptionMessage
|
||||
from pipecat.observers.loggers.transcription_log_observer import TranscriptionLogObserver
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.transcript_processor import TranscriptProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
|
||||
@@ -19,7 +19,6 @@ from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response import LLMAssistantAggregatorParams
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
|
||||
@@ -28,10 +28,10 @@ from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.services.openai.llm import OpenAIContextAggregatorPair, OpenAILLMService
|
||||
from pipecat.sync.event_notifier import EventNotifier
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.utils.sync.event_notifier import EventNotifier
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
@@ -45,11 +45,11 @@ from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams, LLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
from pipecat.sync.event_notifier import EventNotifier
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
from pipecat.utils.sync.event_notifier import EventNotifier
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -46,11 +46,11 @@ from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.llm_service import FunctionCallParams, LLMService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
from pipecat.sync.event_notifier import EventNotifier
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
from pipecat.utils.sync.event_notifier import EventNotifier
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -47,11 +47,11 @@ from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.google.llm import GoogleLLMService
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
from pipecat.sync.event_notifier import EventNotifier
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
from pipecat.utils.sync.event_notifier import EventNotifier
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -83,8 +83,8 @@ mistral = []
|
||||
mlx-whisper = [ "mlx-whisper~=0.4.2" ]
|
||||
moondream = [ "accelerate~=1.10.0", "einops~=0.8.0", "pyvips[binary]~=3.0.0", "timm~=1.0.13", "transformers>=4.48.0" ]
|
||||
neuphonic = [ "pipecat-ai[websockets-base]" ]
|
||||
nim = []
|
||||
noisereduce = [ "noisereduce~=3.0.3" ]
|
||||
nvidia = [ "nvidia-riva-client~=2.21.1" ]
|
||||
openai = [ "pipecat-ai[websockets-base]" ]
|
||||
openpipe = [ "openpipe>=4.50.0,<6" ]
|
||||
openrouter = []
|
||||
@@ -93,7 +93,7 @@ playht = [ "pipecat-ai[websockets-base]" ]
|
||||
qwen = []
|
||||
remote-smart-turn = []
|
||||
rime = [ "pipecat-ai[websockets-base]" ]
|
||||
riva = [ "nvidia-riva-client~=2.21.1" ]
|
||||
riva = [ "pipecat-ai[nvidia]" ]
|
||||
runner = [ "python-dotenv>=1.0.0,<2.0.0", "uvicorn>=0.32.0,<1.0.0", "fastapi>=0.115.6,<0.122.0", "pipecat-ai-small-webrtc-prebuilt>=1.0.0"]
|
||||
sagemaker = ["aws_sdk_sagemaker_runtime_http2; python_version>='3.12'"]
|
||||
sambanova = []
|
||||
|
||||
@@ -103,7 +103,7 @@ TESTS_07 = [
|
||||
("07o-interruptible-assemblyai.py", EVAL_SIMPLE_MATH),
|
||||
("07q-interruptible-rime.py", EVAL_SIMPLE_MATH),
|
||||
("07q-interruptible-rime-http.py", EVAL_SIMPLE_MATH),
|
||||
("07r-interruptible-riva-nim.py", EVAL_SIMPLE_MATH),
|
||||
("07r-interruptible-nvidia.py", EVAL_SIMPLE_MATH),
|
||||
("07s-interruptible-google-audio-in.py", EVAL_SIMPLE_MATH),
|
||||
("07t-interruptible-fish.py", EVAL_SIMPLE_MATH),
|
||||
("07v-interruptible-neuphonic.py", EVAL_SIMPLE_MATH),
|
||||
@@ -136,7 +136,7 @@ TESTS_14 = [
|
||||
("14g-function-calling-grok.py", EVAL_WEATHER),
|
||||
("14h-function-calling-azure.py", EVAL_WEATHER),
|
||||
("14i-function-calling-fireworks.py", EVAL_WEATHER),
|
||||
("14j-function-calling-nim.py", EVAL_WEATHER),
|
||||
("14j-function-calling-nvidia.py", EVAL_WEATHER),
|
||||
("14k-function-calling-cerebras.py", EVAL_WEATHER),
|
||||
("14m-function-calling-openrouter.py", EVAL_WEATHER),
|
||||
("14n-function-calling-perplexity.py", EVAL_WEATHER),
|
||||
|
||||
@@ -18,8 +18,10 @@ from loguru import logger
|
||||
from pipecat.audio.dtmf.types import KeypadEntry
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import (
|
||||
EndFrame,
|
||||
Frame,
|
||||
LLMContextFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMMessagesUpdateFrame,
|
||||
LLMTextFrame,
|
||||
OutputDTMFUrgentFrame,
|
||||
@@ -149,11 +151,18 @@ class IVRProcessor(FrameProcessor):
|
||||
|
||||
elif isinstance(frame, LLMTextFrame):
|
||||
# Process text through the pattern aggregator
|
||||
result = await self._aggregator.aggregate(frame.text)
|
||||
if result:
|
||||
async for result in self._aggregator.aggregate(frame.text):
|
||||
# Push aggregated text that doesn't contain XML patterns
|
||||
await self.push_frame(LLMTextFrame(result.text), direction)
|
||||
|
||||
elif isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
|
||||
# Flush any remaining text from the aggregator
|
||||
remaining = await self._aggregator.flush()
|
||||
if remaining:
|
||||
await self.push_frame(LLMTextFrame(remaining.text), direction)
|
||||
# Push the end frame
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
@@ -40,8 +40,8 @@ from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, FrameProcessorSetup
|
||||
from pipecat.services.llm_service import LLMService
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
from pipecat.sync.event_notifier import EventNotifier
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
from pipecat.utils.sync.event_notifier import EventNotifier
|
||||
|
||||
|
||||
class NotifierGate(FrameProcessor):
|
||||
|
||||
@@ -330,7 +330,7 @@ class TextFrame(DataFrame):
|
||||
"""
|
||||
|
||||
text: str
|
||||
skip_tts: bool = field(init=False)
|
||||
skip_tts: Optional[bool] = field(init=False)
|
||||
# Whether any necessary inter-frame (leading/trailing) spaces are already
|
||||
# included in the text.
|
||||
# NOTE: Ideally this would be available at init time with a default value,
|
||||
@@ -343,7 +343,7 @@ class TextFrame(DataFrame):
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.skip_tts = False
|
||||
self.skip_tts = None
|
||||
self.includes_inter_frame_spaces = False
|
||||
self.append_to_context = True
|
||||
|
||||
@@ -1632,22 +1632,22 @@ class LLMFullResponseStartFrame(ControlFrame):
|
||||
more TextFrames and a final LLMFullResponseEndFrame.
|
||||
"""
|
||||
|
||||
skip_tts: bool = field(init=False)
|
||||
skip_tts: Optional[bool] = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.skip_tts = False
|
||||
self.skip_tts = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMFullResponseEndFrame(ControlFrame):
|
||||
"""Frame indicating the end of an LLM response."""
|
||||
|
||||
skip_tts: bool = field(init=False)
|
||||
skip_tts: Optional[bool] = field(init=False)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.skip_tts = False
|
||||
self.skip_tts = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
from pipecat.frames.frames import CancelFrame, EndFrame, Frame, LLMContextFrame, StartFrame
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
|
||||
|
||||
class GatedLLMContextAggregator(FrameProcessor):
|
||||
|
||||
@@ -83,8 +83,7 @@ class LLMTextProcessor(FrameProcessor):
|
||||
await self._text_aggregator.reset()
|
||||
|
||||
async def _handle_llm_text(self, in_frame: LLMTextFrame):
|
||||
aggregation = await self._text_aggregator.aggregate(in_frame.text)
|
||||
if aggregation:
|
||||
async for aggregation in self._text_aggregator.aggregate(in_frame.text):
|
||||
out_frame = AggregatedTextFrame(
|
||||
text=aggregation.text,
|
||||
aggregated_by=aggregation.type,
|
||||
@@ -92,15 +91,13 @@ class LLMTextProcessor(FrameProcessor):
|
||||
out_frame.skip_tts = in_frame.skip_tts
|
||||
await self.push_frame(out_frame)
|
||||
|
||||
async def _handle_llm_end(self, skip_tts: bool = False):
|
||||
# Flush any remaining aggregated text at the end of the LLM response
|
||||
aggregation = self._text_aggregator.text
|
||||
await self._text_aggregator.reset()
|
||||
text = aggregation.text.strip()
|
||||
if text:
|
||||
async def _handle_llm_end(self, skip_tts: Optional[bool] = None):
|
||||
# Flush any remaining text
|
||||
remaining = await self._text_aggregator.flush()
|
||||
if remaining:
|
||||
out_frame = AggregatedTextFrame(
|
||||
text=text,
|
||||
aggregated_by=aggregation.type,
|
||||
text=remaining.text,
|
||||
aggregated_by=remaining.type,
|
||||
)
|
||||
out_frame.skip_tts = skip_tts
|
||||
await self.push_frame(out_frame)
|
||||
|
||||
@@ -10,7 +10,7 @@ from typing import Awaitable, Callable, Tuple, Type
|
||||
|
||||
from pipecat.frames.frames import Frame
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
|
||||
|
||||
class WakeNotifierFilter(FrameProcessor):
|
||||
|
||||
@@ -244,6 +244,11 @@ class DeepgramFluxSTTService(WebsocketSTTService):
|
||||
additional_headers={"Authorization": f"Token {self._api_key}"},
|
||||
)
|
||||
|
||||
headers = {
|
||||
k: v for k, v in self._websocket.response.headers.items() if k.startswith("dg-")
|
||||
}
|
||||
logger.debug(f'{self}: Websocket connection initialized: {{"headers": {headers}}}')
|
||||
|
||||
# Creating the receiver task
|
||||
if not self._receive_task:
|
||||
self._receive_task = self.create_task(
|
||||
|
||||
@@ -234,6 +234,13 @@ class DeepgramSTTService(STTService):
|
||||
|
||||
if not await self._connection.start(options=self._settings, addons=self._addons):
|
||||
await self.push_error(error_msg=f"Unable to connect to Deepgram")
|
||||
else:
|
||||
headers = {
|
||||
k: v
|
||||
for k, v in self._connection._socket.response.headers.items()
|
||||
if k.startswith("dg-")
|
||||
}
|
||||
logger.debug(f'{self}: Websocket connection initialized: {{"headers": {headers}}}')
|
||||
|
||||
async def _disconnect(self):
|
||||
if await self._connection.is_connected():
|
||||
|
||||
@@ -71,7 +71,12 @@ class DeepgramTTSService(WebsocketTTSService):
|
||||
encoding: Audio encoding format. Defaults to "linear16".
|
||||
**kwargs: Additional arguments passed to parent InterruptibleTTSService class.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
pause_frame_processing=True,
|
||||
push_stop_frames=True,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self._api_key = api_key
|
||||
self._base_url = base_url
|
||||
@@ -165,6 +170,11 @@ class DeepgramTTSService(WebsocketTTSService):
|
||||
|
||||
self._websocket = await websocket_connect(url, additional_headers=headers)
|
||||
|
||||
headers = {
|
||||
k: v for k, v in self._websocket.response.headers.items() if k.startswith("dg-")
|
||||
}
|
||||
logger.debug(f'{self}: Websocket connection initialized: {{"headers": {headers}}}')
|
||||
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
@@ -231,7 +241,6 @@ class DeepgramTTSService(WebsocketTTSService):
|
||||
logger.trace(f"Received Flushed: {msg}")
|
||||
# Flushed indicates the end of audio generation for the current buffer
|
||||
# This happens after flush_audio() is called
|
||||
await self.push_frame(TTSStoppedFrame())
|
||||
elif msg_type == "Cleared":
|
||||
logger.trace(f"Received Cleared: {msg}")
|
||||
# Buffer has been cleared after interruption
|
||||
@@ -286,7 +295,7 @@ class DeepgramTTSService(WebsocketTTSService):
|
||||
speak_msg = {"type": "Speak", "text": text}
|
||||
await self._get_websocket().send(json.dumps(speak_msg))
|
||||
|
||||
# The actual audio frames will be handled in _receive_messages
|
||||
# The audio frames will be handled in _receive_messages
|
||||
yield None
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -166,23 +166,27 @@ class LLMService(AIService):
|
||||
# However, subclasses should override this with a more specific adapter when necessary.
|
||||
adapter_class: Type[BaseLLMAdapter] = OpenAILLMAdapter
|
||||
|
||||
def __init__(self, run_in_parallel: bool = True, **kwargs):
|
||||
def __init__(self, run_in_parallel: bool = True, wait_for_all: bool = False, **kwargs):
|
||||
"""Initialize the LLM service.
|
||||
|
||||
Args:
|
||||
run_in_parallel: Whether to run function calls in parallel or sequentially.
|
||||
Defaults to True.
|
||||
wait_for_all: Whether to wait for all function calls (parallel or
|
||||
sequential) to complete. Defaults to False.
|
||||
**kwargs: Additional arguments passed to the parent AIService.
|
||||
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._run_in_parallel = run_in_parallel
|
||||
self._wait_for_all = wait_for_all
|
||||
self._start_callbacks = {}
|
||||
self._adapter = self.adapter_class()
|
||||
self._functions: Dict[Optional[str], FunctionCallRegistryItem] = {}
|
||||
self._function_call_tasks: Dict[asyncio.Task, FunctionCallRunnerItem] = {}
|
||||
self._function_call_tasks: Dict[Optional[asyncio.Task], FunctionCallRunnerItem] = {}
|
||||
self._sequential_runner_task: Optional[asyncio.Task] = None
|
||||
self._tracing_enabled: bool = False
|
||||
self._skip_tts: bool = False
|
||||
self._skip_tts: Optional[bool] = None
|
||||
|
||||
self._register_event_handler("on_function_calls_started")
|
||||
self._register_event_handler("on_completion_timeout")
|
||||
@@ -293,7 +297,8 @@ class LLMService(AIService):
|
||||
direction: The direction of frame pushing.
|
||||
"""
|
||||
if isinstance(frame, (LLMTextFrame, LLMFullResponseStartFrame, LLMFullResponseEndFrame)):
|
||||
frame.skip_tts = self._skip_tts
|
||||
if self._skip_tts is not None:
|
||||
frame.skip_tts = self._skip_tts
|
||||
|
||||
await super().push_frame(frame, direction)
|
||||
|
||||
@@ -435,6 +440,7 @@ class LLMService(AIService):
|
||||
|
||||
await self.broadcast_frame(FunctionCallsStartedFrame, function_calls=function_calls)
|
||||
|
||||
runner_items = []
|
||||
for function_call in function_calls:
|
||||
if function_call.function_name in self._functions.keys():
|
||||
item = self._functions[function_call.function_name]
|
||||
@@ -446,28 +452,20 @@ class LLMService(AIService):
|
||||
)
|
||||
continue
|
||||
|
||||
runner_item = FunctionCallRunnerItem(
|
||||
registry_item=item,
|
||||
function_name=function_call.function_name,
|
||||
tool_call_id=function_call.tool_call_id,
|
||||
arguments=function_call.arguments,
|
||||
context=function_call.context,
|
||||
runner_items.append(
|
||||
FunctionCallRunnerItem(
|
||||
registry_item=item,
|
||||
function_name=function_call.function_name,
|
||||
tool_call_id=function_call.tool_call_id,
|
||||
arguments=function_call.arguments,
|
||||
context=function_call.context,
|
||||
)
|
||||
)
|
||||
|
||||
if self._run_in_parallel:
|
||||
task = self.create_task(self._run_function_call(runner_item))
|
||||
self._function_call_tasks[task] = runner_item
|
||||
task.add_done_callback(self._function_call_task_finished)
|
||||
else:
|
||||
await self._sequential_runner_queue.put(runner_item)
|
||||
|
||||
async def _call_start_function(
|
||||
self, context: OpenAILLMContext | LLMContext, function_name: str
|
||||
):
|
||||
if function_name in self._start_callbacks.keys():
|
||||
await self._start_callbacks[function_name](function_name, self, context)
|
||||
elif None in self._start_callbacks.keys():
|
||||
return await self._start_callbacks[None](function_name, self, context)
|
||||
if self._run_in_parallel:
|
||||
await self._run_parallel_function_calls(runner_items)
|
||||
else:
|
||||
await self._run_sequential_function_calls(runner_items)
|
||||
|
||||
async def request_image_frame(
|
||||
self,
|
||||
@@ -540,6 +538,46 @@ class LLMService(AIService):
|
||||
await task
|
||||
del self._function_call_tasks[task]
|
||||
|
||||
async def _run_parallel_function_calls(self, runner_items: Sequence[FunctionCallRunnerItem]):
|
||||
tasks = []
|
||||
for runner_item in runner_items:
|
||||
task = self.create_task(self._run_function_call(runner_item))
|
||||
tasks.append(task)
|
||||
self._function_call_tasks[task] = runner_item
|
||||
task.add_done_callback(self._function_call_task_finished)
|
||||
|
||||
if self._wait_for_all:
|
||||
# Protect gather from being cancelled. This will protect all tasks
|
||||
# from being cancelled. That is fine, because we cancel them
|
||||
# explicitly when handling the interruption (InterruptionFrame). We
|
||||
# need to set `return_exceptions=True` because `asyncio.shield()`
|
||||
# will get cancelled (from FrameProcessor process task), then
|
||||
# `asyncio.gather()` will keep running (because it was protected by
|
||||
# the shield). Then, individual function call tasks will be
|
||||
# cancelled by us and we don't need to propagate those
|
||||
# CancelledErrors at that point.
|
||||
await asyncio.shield(asyncio.gather(*tasks, return_exceptions=True))
|
||||
|
||||
async def _run_sequential_function_calls(self, runner_items: Sequence[FunctionCallRunnerItem]):
|
||||
if self._wait_for_all:
|
||||
# Run each function call sequentially, waiting for each to complete.
|
||||
for runner_item in runner_items:
|
||||
self._function_call_tasks[None] = runner_item
|
||||
await self._run_function_call(runner_item)
|
||||
del self._function_call_tasks[None]
|
||||
else:
|
||||
# Enqueue all function calls for background execution.
|
||||
for runner_item in runner_items:
|
||||
await self._sequential_runner_queue.put(runner_item)
|
||||
|
||||
async def _call_start_function(
|
||||
self, context: OpenAILLMContext | LLMContext, function_name: str
|
||||
):
|
||||
if function_name in self._start_callbacks.keys():
|
||||
await self._start_callbacks[function_name](function_name, self, context)
|
||||
elif None in self._start_callbacks.keys():
|
||||
return await self._start_callbacks[None](function_name, self, context)
|
||||
|
||||
async def _run_function_call(self, runner_item: FunctionCallRunnerItem):
|
||||
if runner_item.function_name in self._functions.keys():
|
||||
item = self._functions[runner_item.function_name]
|
||||
@@ -623,20 +661,19 @@ class LLMService(AIService):
|
||||
name = runner_item.function_name
|
||||
tool_call_id = runner_item.tool_call_id
|
||||
|
||||
# We remove the callback because we are going to cancel the task
|
||||
# now, otherwise we will be removing it from the set while we
|
||||
# are iterating.
|
||||
task.remove_done_callback(self._function_call_task_finished)
|
||||
|
||||
logger.debug(f"{self} Cancelling function call [{name}:{tool_call_id}]...")
|
||||
|
||||
await self.cancel_task(task)
|
||||
if task:
|
||||
# We remove the callback because we are going to cancel the
|
||||
# task next, otherwise we will be removing it from the set
|
||||
# while we are iterating.
|
||||
task.remove_done_callback(self._function_call_task_finished)
|
||||
await self.cancel_task(task)
|
||||
cancelled_tasks.add(task)
|
||||
|
||||
frame = FunctionCallCancelFrame(function_name=name, tool_call_id=tool_call_id)
|
||||
await self.push_frame(frame)
|
||||
|
||||
cancelled_tasks.add(task)
|
||||
|
||||
logger.debug(f"{self} Function call [{name}:{tool_call_id}] has been cancelled")
|
||||
|
||||
# Remove all cancelled tasks from our set.
|
||||
|
||||
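The comment in `_run_parallel_function_calls` describes how `asyncio.shield()` and `asyncio.gather(..., return_exceptions=True)` interact when the awaiting task is cancelled. The following is a minimal standalone sketch of that behaviour, not the pipecat implementation: `fake_function_call`, `wait_for_all` and `main` are hypothetical names, and plain `asyncio.create_task` stands in for the service's own task helpers.

```python
import asyncio


async def fake_function_call(name: str, delay: float, results: list) -> None:
    # Hypothetical stand-in for a registered function call handler.
    await asyncio.sleep(delay)
    results.append(name)


async def wait_for_all(tasks: list) -> list:
    # If the coroutine awaiting the shield is cancelled, only the outer
    # future is cancelled; the inner gather and its tasks keep running.
    # return_exceptions=True keeps a later CancelledError of an individual
    # task from propagating out of the gather.
    return await asyncio.shield(asyncio.gather(*tasks, return_exceptions=True))


async def main() -> list:
    results: list = []
    tasks = [
        asyncio.create_task(fake_function_call("a", 0.01, results)),
        asyncio.create_task(fake_function_call("b", 0.02, results)),
    ]
    waiter = asyncio.create_task(wait_for_all(tasks))
    await asyncio.sleep(0)  # give the waiter a chance to start
    waiter.cancel()  # simulates the processing task being cancelled
    try:
        await waiter
    except asyncio.CancelledError:
        pass
    # The function call tasks survived the waiter's cancellation.
    await asyncio.gather(*tasks)
    return results


if __name__ == "__main__":
    print(asyncio.run(main()))
```

In this sketch the function call tasks only stop if something cancels them explicitly, which matches the comment's note that cancellation is handled separately when an interruption arrives.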
@@ -8,98 +8,23 @@
|
||||
|
||||
This module provides a service for interacting with NVIDIA's NIM (NVIDIA Inference
|
||||
Microservice) API while maintaining compatibility with the OpenAI-style interface.
|
||||
|
||||
.. deprecated:: 0.0.96
|
||||
This module is deprecated. Please use NvidiaLLMService from
|
||||
pipecat.services.nvidia.llm instead.
|
||||
"""
|
||||
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
import warnings
|
||||
|
||||
from pipecat.services.nvidia.llm import NvidiaLLMService
|
||||
|
||||
class NimLLMService(OpenAILLMService):
|
||||
"""A service for interacting with NVIDIA's NIM (NVIDIA Inference Microservice) API.
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"NimLLMService from pipecat.services.nim.llm is deprecated. "
|
||||
"Please use NvidiaLLMService from pipecat.services.nvidia.llm instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
This service extends OpenAILLMService to work with NVIDIA's NIM API while maintaining
|
||||
compatibility with the OpenAI-style interface. It specifically handles the difference
|
||||
in token usage reporting between NIM (incremental) and OpenAI (final summary).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: str = "https://integrate.api.nvidia.com/v1",
|
||||
model: str = "nvidia/llama-3.1-nemotron-70b-instruct",
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the NimLLMService.
|
||||
|
||||
Args:
|
||||
api_key: The API key for accessing NVIDIA's NIM API.
|
||||
base_url: The base URL for NIM API. Defaults to "https://integrate.api.nvidia.com/v1".
|
||||
model: The model identifier to use. Defaults to "nvidia/llama-3.1-nemotron-70b-instruct".
|
||||
**kwargs: Additional keyword arguments passed to OpenAILLMService.
|
||||
"""
|
||||
super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
|
||||
# Counters for accumulating token usage metrics
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
self._total_tokens = 0
|
||||
self._has_reported_prompt_tokens = False
|
||||
self._is_processing = False
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext | LLMContext):
|
||||
"""Process a context through the LLM and accumulate token usage metrics.
|
||||
|
||||
This method overrides the parent class implementation to handle NVIDIA's
|
||||
incremental token reporting style, accumulating the counts and reporting
|
||||
them once at the end of processing.
|
||||
|
||||
Args:
|
||||
context: The context to process, containing messages and other information
|
||||
needed for the LLM interaction.
|
||||
"""
|
||||
# Reset all counters and flags at the start of processing
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
self._total_tokens = 0
|
||||
self._has_reported_prompt_tokens = False
|
||||
self._is_processing = True
|
||||
|
||||
try:
|
||||
await super()._process_context(context)
|
||||
finally:
|
||||
self._is_processing = False
|
||||
# Report final accumulated token usage at the end of processing
|
||||
if self._prompt_tokens > 0 or self._completion_tokens > 0:
|
||||
self._total_tokens = self._prompt_tokens + self._completion_tokens
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=self._prompt_tokens,
|
||||
completion_tokens=self._completion_tokens,
|
||||
total_tokens=self._total_tokens,
|
||||
)
|
||||
await super().start_llm_usage_metrics(tokens)
|
||||
|
||||
async def start_llm_usage_metrics(self, tokens: LLMTokenUsage):
|
||||
"""Accumulate token usage metrics during processing.
|
||||
|
||||
This method intercepts the incremental token updates from NVIDIA's API
|
||||
and accumulates them instead of passing each update to the metrics system.
|
||||
The final accumulated totals are reported at the end of processing.
|
||||
|
||||
Args:
|
||||
tokens: The token usage metrics for the current chunk of processing,
|
||||
containing prompt_tokens and completion_tokens counts.
|
||||
"""
|
||||
# Only accumulate metrics during active processing
|
||||
if not self._is_processing:
|
||||
return
|
||||
|
||||
# Record prompt tokens the first time we see them
|
||||
if not self._has_reported_prompt_tokens and tokens.prompt_tokens > 0:
|
||||
self._prompt_tokens = tokens.prompt_tokens
|
||||
self._has_reported_prompt_tokens = True
|
||||
|
||||
# Update completion tokens count if it has increased
|
||||
if tokens.completion_tokens > self._completion_tokens:
|
||||
self._completion_tokens = tokens.completion_tokens
|
||||
NimLLMService = NvidiaLLMService
|
||||
|
||||
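The deprecated module above emits a `DeprecationWarning` and then re-exports the new class under the old name, so existing imports keep working. A minimal sketch of that shim pattern follows, under stated assumptions: `NewService` and `load_deprecated_alias` are hypothetical, and a function call emulates what happens when the deprecated module is imported.

```python
import warnings


class NewService:
    """Hypothetical replacement class (stands in for the new service)."""


def load_deprecated_alias():
    # Emulates importing the deprecated module: warn, then hand back the
    # new class so that the old name keeps resolving to something usable.
    with warnings.catch_warnings():
        # "always" makes the warning visible even if it was already shown
        # once or would be hidden by the default filters.
        warnings.simplefilter("always")
        warnings.warn(
            "OldService is deprecated. Please use NewService instead.",
            DeprecationWarning,
            stacklevel=2,
        )
    return NewService


if __name__ == "__main__":
    OldService = load_deprecated_alias()
    print(OldService is NewService)
```

Because the alias is the same class object, `isinstance` checks and subclassing against the old name behave the same as against the new one.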
0
src/pipecat/services/nvidia/__init__.py
Normal file
105
src/pipecat/services/nvidia/llm.py
Normal file
@@ -0,0 +1,105 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""NVIDIA NIM API service implementation.
|
||||
|
||||
This module provides a service for interacting with NVIDIA's NIM (NVIDIA Inference
|
||||
Microservice) API while maintaining compatibility with the OpenAI-style interface.
|
||||
"""
|
||||
|
||||
from pipecat.metrics.metrics import LLMTokenUsage
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
|
||||
class NvidiaLLMService(OpenAILLMService):
|
||||
"""A service for interacting with NVIDIA's NIM (NVIDIA Inference Microservice) API.
|
||||
|
||||
This service extends OpenAILLMService to work with NVIDIA's NIM API while maintaining
|
||||
compatibility with the OpenAI-style interface. It specifically handles the difference
|
||||
in token usage reporting between NIM (incremental) and OpenAI (final summary).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: str = "https://integrate.api.nvidia.com/v1",
|
||||
model: str = "nvidia/llama-3.1-nemotron-70b-instruct",
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the NvidiaLLMService.
|
||||
|
||||
Args:
|
||||
api_key: The API key for accessing NVIDIA's NIM API.
|
||||
base_url: The base URL for NIM API. Defaults to "https://integrate.api.nvidia.com/v1".
|
||||
model: The model identifier to use. Defaults to "nvidia/llama-3.1-nemotron-70b-instruct".
|
||||
**kwargs: Additional keyword arguments passed to OpenAILLMService.
|
||||
"""
|
||||
super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
|
||||
# Counters for accumulating token usage metrics
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
self._total_tokens = 0
|
||||
self._has_reported_prompt_tokens = False
|
||||
self._is_processing = False
|
||||
|
||||
async def _process_context(self, context: OpenAILLMContext | LLMContext):
|
||||
"""Process a context through the LLM and accumulate token usage metrics.
|
||||
|
||||
This method overrides the parent class implementation to handle NVIDIA's
|
||||
incremental token reporting style, accumulating the counts and reporting
|
||||
them once at the end of processing.
|
||||
|
||||
Args:
|
||||
context: The context to process, containing messages and other information
|
||||
needed for the LLM interaction.
|
||||
"""
|
||||
# Reset all counters and flags at the start of processing
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
self._total_tokens = 0
|
||||
self._has_reported_prompt_tokens = False
|
||||
self._is_processing = True
|
||||
|
||||
try:
|
||||
await super()._process_context(context)
|
||||
finally:
|
||||
self._is_processing = False
|
||||
# Report final accumulated token usage at the end of processing
|
||||
if self._prompt_tokens > 0 or self._completion_tokens > 0:
|
||||
self._total_tokens = self._prompt_tokens + self._completion_tokens
|
||||
tokens = LLMTokenUsage(
|
||||
prompt_tokens=self._prompt_tokens,
|
||||
completion_tokens=self._completion_tokens,
|
||||
total_tokens=self._total_tokens,
|
||||
)
|
||||
await super().start_llm_usage_metrics(tokens)
|
||||
|
||||
async def start_llm_usage_metrics(self, tokens: LLMTokenUsage):
|
||||
"""Accumulate token usage metrics during processing.
|
||||
|
||||
This method intercepts the incremental token updates from NVIDIA's API
|
||||
and accumulates them instead of passing each update to the metrics system.
|
||||
The final accumulated totals are reported at the end of processing.
|
||||
|
||||
Args:
|
||||
tokens: The token usage metrics for the current chunk of processing,
|
||||
containing prompt_tokens and completion_tokens counts.
|
||||
"""
|
||||
# Only accumulate metrics during active processing
|
||||
if not self._is_processing:
|
||||
return
|
||||
|
||||
# Record prompt tokens the first time we see them
|
||||
if not self._has_reported_prompt_tokens and tokens.prompt_tokens > 0:
|
||||
self._prompt_tokens = tokens.prompt_tokens
|
||||
self._has_reported_prompt_tokens = True
|
||||
|
||||
# Update completion tokens count if it has increased
|
||||
if tokens.completion_tokens > self._completion_tokens:
|
||||
self._completion_tokens = tokens.completion_tokens
|
||||
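`NvidiaLLMService` accumulates incremental usage updates and reports them once at the end of processing. The accumulation rule can be sketched in isolation as below. This is an illustration only: `TokenUsage` is a hypothetical stand-in for `LLMTokenUsage`, and `UsageAccumulator` is not part of pipecat.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TokenUsage:
    """Hypothetical stand-in for LLMTokenUsage."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class UsageAccumulator:
    """Collects incremental usage updates and produces one final report."""

    def __init__(self) -> None:
        self._prompt_tokens = 0
        self._completion_tokens = 0
        self._has_prompt_tokens = False

    def update(self, prompt_tokens: int, completion_tokens: int) -> None:
        # Record prompt tokens only the first time a non-zero value appears.
        if not self._has_prompt_tokens and prompt_tokens > 0:
            self._prompt_tokens = prompt_tokens
            self._has_prompt_tokens = True
        # Completion counts are treated as running totals: keep the largest.
        if completion_tokens > self._completion_tokens:
            self._completion_tokens = completion_tokens

    def final(self) -> Optional[TokenUsage]:
        # Nothing to report if no update carried any tokens.
        if self._prompt_tokens == 0 and self._completion_tokens == 0:
            return None
        total = self._prompt_tokens + self._completion_tokens
        return TokenUsage(self._prompt_tokens, self._completion_tokens, total)


if __name__ == "__main__":
    acc = UsageAccumulator()
    for prompt, completion in [(12, 1), (12, 5), (0, 9)]:
        acc.update(prompt, completion)
    print(acc.final())
```

Keeping the maximum rather than summing the completion counts mirrors the `>` comparison in `start_llm_usage_metrics`, which treats each update as a running total.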
663
src/pipecat/services/nvidia/stt.py
Normal file
@@ -0,0 +1,663 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""NVIDIA Riva Speech-to-Text service implementations for real-time and batch transcription."""
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import CancelledError as FuturesCancelledError
|
||||
from typing import AsyncGenerator, List, Mapping, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
)
|
||||
from pipecat.services.stt_service import SegmentedSTTService, STTService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
try:
|
||||
import riva.client
|
||||
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use NVIDIA Riva STT, you need to `pip install pipecat-ai[nvidia]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
def language_to_nvidia_riva_language(language: Language) -> Optional[str]:
|
||||
"""Maps Language enum to NVIDIA Riva ASR language codes.
|
||||
|
||||
Source:
|
||||
https://docs.nvidia.com/deeplearning/riva/user-guide/docs/asr/asr-riva-build-table.html?highlight=fr%20fr
|
||||
|
||||
Args:
|
||||
language: Language enum value.
|
||||
|
||||
Returns:
|
||||
Optional[str]: NVIDIA Riva language code or None if not supported.
|
||||
"""
|
||||
LANGUAGE_MAP = {
|
||||
# Arabic
|
||||
Language.AR: "ar-AR",
|
||||
# English
|
||||
Language.EN: "en-US", # Default to US
|
||||
Language.EN_US: "en-US",
|
||||
Language.EN_GB: "en-GB",
|
||||
# French
|
||||
Language.FR: "fr-FR",
|
||||
Language.FR_FR: "fr-FR",
|
||||
# German
|
||||
Language.DE: "de-DE",
|
||||
Language.DE_DE: "de-DE",
|
||||
# Hindi
|
||||
Language.HI: "hi-IN",
|
||||
Language.HI_IN: "hi-IN",
|
||||
# Italian
|
||||
Language.IT: "it-IT",
|
||||
Language.IT_IT: "it-IT",
|
||||
# Japanese
|
||||
Language.JA: "ja-JP",
|
||||
Language.JA_JP: "ja-JP",
|
||||
# Korean
|
||||
Language.KO: "ko-KR",
|
||||
Language.KO_KR: "ko-KR",
|
||||
# Portuguese
|
||||
Language.PT: "pt-BR", # Default to Brazilian
|
||||
Language.PT_BR: "pt-BR",
|
||||
# Russian
|
||||
Language.RU: "ru-RU",
|
||||
Language.RU_RU: "ru-RU",
|
||||
# Spanish
|
||||
Language.ES: "es-ES", # Default to Spain
|
||||
Language.ES_ES: "es-ES",
|
||||
Language.ES_US: "es-US", # US Spanish
|
||||
}
|
||||
|
||||
return resolve_language(language, LANGUAGE_MAP, use_base_code=False)
|
||||
|
||||
|
||||
class NvidiaSTTService(STTService):
|
||||
"""Real-time speech-to-text service using NVIDIA Riva streaming ASR.
|
||||
|
||||
Provides real-time transcription capabilities using NVIDIA's Riva ASR models
|
||||
through streaming recognition. Supports interim results and continuous audio
|
||||
processing for low-latency applications.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Configuration parameters for NVIDIA Riva STT service.
|
||||
|
||||
Parameters:
|
||||
language: Target language for transcription. Defaults to EN_US.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN_US
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
model_function_map: Mapping[str, str] = {
|
||||
"function_id": "1598d209-5e27-4d3c-8079-4751568b1081",
|
||||
"model_name": "parakeet-ctc-1.1b-asr",
|
||||
},
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the NVIDIA Riva STT service.
|
||||
|
||||
Args:
|
||||
api_key: NVIDIA API key for authentication.
|
||||
server: NVIDIA Riva server address. Defaults to NVIDIA Cloud Function endpoint.
|
||||
model_function_map: Mapping containing 'function_id' and 'model_name' for the ASR model.
|
||||
sample_rate: Audio sample rate in Hz. If None, uses pipeline default.
|
||||
params: Additional configuration parameters for NVIDIA Riva.
|
||||
**kwargs: Additional arguments passed to STTService.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
params = params or NvidiaSTTService.InputParams()
|
||||
|
||||
self._api_key = api_key
|
||||
self._profanity_filter = False
|
||||
self._automatic_punctuation = True
|
||||
self._no_verbatim_transcripts = False
|
||||
self._language_code = params.language
|
||||
self._boosted_lm_words = None
|
||||
self._boosted_lm_score = 4.0
|
||||
self._start_history = -1
|
||||
self._start_threshold = -1.0
|
||||
self._stop_history = -1
|
||||
self._stop_threshold = -1.0
|
||||
self._stop_history_eou = -1
|
||||
self._stop_threshold_eou = -1.0
|
||||
self._custom_configuration = ""
|
||||
self._function_id = model_function_map.get("function_id")
|
||||
|
||||
self._settings = {
|
||||
"language": str(params.language),
|
||||
"profanity_filter": self._profanity_filter,
|
||||
"automatic_punctuation": self._automatic_punctuation,
|
||||
"verbatim_transcripts": not self._no_verbatim_transcripts,
|
||||
"boosted_lm_words": self._boosted_lm_words,
|
||||
"boosted_lm_score": self._boosted_lm_score,
|
||||
}
|
||||
|
||||
self.set_model_name(model_function_map.get("model_name"))
|
||||
|
||||
metadata = [
|
||||
["function-id", self._function_id],
|
||||
["authorization", f"Bearer {api_key}"],
|
||||
]
|
||||
auth = riva.client.Auth(None, True, server, metadata)
|
||||
|
||||
self._asr_service = riva.client.ASRService(auth)
|
||||
|
||||
self._queue = None
|
||||
self._config = None
|
||||
self._thread_task = None
|
||||
self._response_task = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
False - this service does not support metrics generation.
|
||||
"""
|
||||
return False
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Set the ASR model for transcription.
|
||||
|
||||
Args:
|
||||
model: Model name to set.
|
||||
|
||||
Note:
|
||||
Model cannot be changed after initialization. Use model_function_map
|
||||
parameter in constructor instead.
|
||||
"""
|
||||
logger.warning(f"Cannot set model after initialization. Set model and function id like so:")
|
||||
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
|
||||
)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the NVIDIA Riva STT service and initialize streaming configuration.
|
||||
|
||||
Args:
|
||||
frame: StartFrame indicating pipeline start.
|
||||
"""
|
||||
await super().start(frame)
|
||||
|
||||
if self._config:
|
||||
return
|
||||
|
||||
config = riva.client.StreamingRecognitionConfig(
|
||||
config=riva.client.RecognitionConfig(
|
||||
encoding=riva.client.AudioEncoding.LINEAR_PCM,
|
||||
language_code=self._language_code,
|
||||
model="",
|
||||
max_alternatives=1,
|
||||
profanity_filter=self._profanity_filter,
|
||||
enable_automatic_punctuation=self._automatic_punctuation,
|
||||
verbatim_transcripts=not self._no_verbatim_transcripts,
|
||||
sample_rate_hertz=self.sample_rate,
|
||||
audio_channel_count=1,
|
||||
),
|
||||
interim_results=True,
|
||||
)
|
||||
|
||||
riva.client.add_word_boosting_to_config(
|
||||
config, self._boosted_lm_words, self._boosted_lm_score
|
||||
)
|
||||
|
||||
riva.client.add_endpoint_parameters_to_config(
|
||||
config,
|
||||
self._start_history,
|
||||
self._start_threshold,
|
||||
self._stop_history,
|
||||
self._stop_history_eou,
|
||||
self._stop_threshold,
|
||||
self._stop_threshold_eou,
|
||||
)
|
||||
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
|
||||
|
||||
self._config = config
|
||||
self._queue = asyncio.Queue()
|
||||
|
||||
if not self._thread_task:
|
||||
self._thread_task = self.create_task(self._thread_task_handler())
|
||||
|
||||
if not self._response_task:
|
||||
self._response_queue = asyncio.Queue()
|
||||
self._response_task = self.create_task(self._response_task_handler())
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the NVIDIA Riva STT service and clean up resources.
|
||||
|
||||
Args:
|
||||
frame: EndFrame indicating pipeline stop.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._stop_tasks()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the NVIDIA Riva STT service operation.
|
||||
|
||||
Args:
|
||||
frame: CancelFrame indicating operation cancellation.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._stop_tasks()
|
||||
|
||||
async def _stop_tasks(self):
|
||||
if self._thread_task:
|
||||
await self.cancel_task(self._thread_task)
|
||||
self._thread_task = None
|
||||
|
||||
if self._response_task:
|
||||
await self.cancel_task(self._response_task)
|
||||
self._response_task = None
|
||||
|
||||
def _response_handler(self):
|
||||
responses = self._asr_service.streaming_response_generator(
|
||||
audio_chunks=self,
|
||||
streaming_config=self._config,
|
||||
)
|
||||
for response in responses:
|
||||
if not response.results:
|
||||
continue
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._response_queue.put(response), self.get_event_loop()
|
||||
)
|
||||
|
||||
async def _thread_task_handler(self):
|
||||
try:
|
||||
self._thread_running = True
|
||||
await asyncio.to_thread(self._response_handler)
|
||||
except asyncio.CancelledError:
|
||||
self._thread_running = False
|
||||
raise
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(
|
||||
self, transcript: str, is_final: bool, language: Optional[Language] = None
|
||||
):
|
||||
"""Handle a transcription result with tracing."""
|
||||
pass
|
||||
|
||||
async def _handle_response(self, response):
|
||||
for result in response.results:
|
||||
if result and not result.alternatives:
|
||||
continue
|
||||
|
||||
transcript = result.alternatives[0].transcript
|
||||
if transcript and len(transcript) > 0:
|
||||
await self.stop_ttfb_metrics()
|
||||
if result.is_final:
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
self._language_code,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
await self._handle_transcription(
|
||||
transcript=transcript,
|
||||
is_final=result.is_final,
|
||||
language=self._language_code,
|
||||
)
|
||||
else:
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
self._language_code,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
|
||||
async def _response_task_handler(self):
|
||||
while True:
|
||||
response = await self._response_queue.get()
|
||||
await self._handle_response(response)
|
||||
self._response_queue.task_done()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Process audio data for speech-to-text transcription.
|
||||
|
||||
Args:
|
||||
audio: Raw audio bytes to transcribe.
|
||||
|
||||
Yields:
|
||||
None - transcription results are pushed to the pipeline via frames.
|
||||
"""
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
await self._queue.put(audio)
|
||||
yield None
|
||||
|
||||
def __next__(self) -> bytes:
|
||||
"""Get the next audio chunk for NVIDIA Riva processing.
|
||||
|
||||
Returns:
|
||||
Audio bytes from the queue.
|
||||
|
||||
Raises:
|
||||
StopIteration: When the thread is no longer running.
|
||||
"""
|
||||
if not self._thread_running:
|
||||
raise StopIteration
|
||||
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(self._queue.get(), self.get_event_loop())
|
||||
return future.result()
|
||||
except FuturesCancelledError:
|
||||
raise StopIteration
|
||||
|
||||
def __iter__(self):
|
||||
"""Return iterator for audio chunk processing.
|
||||
|
||||
Returns:
|
||||
Self as iterator.
|
||||
"""
|
||||
return self
|
||||
|
||||
|
||||
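`NvidiaSTTService` acts as a blocking iterator (`__iter__` / `__next__`) that a worker thread consumes while audio arrives on an `asyncio.Queue` owned by the event loop. The sketch below shows that bridge on its own. It rests on assumptions: `QueueIterator` and `consume` are hypothetical, and a `None` sentinel ends iteration here, whereas the service above checks a running flag instead.

```python
import asyncio


class QueueIterator:
    """Blocking iterator over an asyncio.Queue, for use from another thread."""

    def __init__(self, queue: asyncio.Queue, loop: asyncio.AbstractEventLoop):
        self._queue = queue
        self._loop = loop

    def __iter__(self):
        return self

    def __next__(self) -> bytes:
        # Schedule queue.get() on the event loop and block this thread
        # until the loop delivers an item.
        future = asyncio.run_coroutine_threadsafe(self._queue.get(), self._loop)
        item = future.result()
        if item is None:  # sentinel (assumption made for this sketch)
            raise StopIteration
        return item


def consume(chunks) -> list:
    # Blocking consumer; stands in for a synchronous streaming client.
    return [len(chunk) for chunk in chunks]


async def main() -> list:
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()
    worker = asyncio.create_task(
        asyncio.to_thread(consume, QueueIterator(queue, loop))
    )
    for chunk in (b"ab", b"cde"):
        await queue.put(chunk)
    await queue.put(None)  # tell the iterator to stop
    return await worker


if __name__ == "__main__":
    print(asyncio.run(main()))
```

The queue is only ever touched from the event loop thread; the worker thread interacts with it solely through `run_coroutine_threadsafe`, which is what makes the hand-off safe.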
class NvidiaSegmentedSTTService(SegmentedSTTService):
|
||||
"""Speech-to-text service using NVIDIA Riva's offline/batch models.
|
||||
|
||||
By default, this service uses NVIDIA's Riva Canary ASR API to perform speech-to-text
|
||||
transcription on audio segments. It inherits from SegmentedSTTService to handle
|
||||
audio buffering and speech detection.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Configuration parameters for NVIDIA Riva segmented STT service.
|
||||
|
||||
Parameters:
|
||||
language: Target language for transcription. Defaults to EN_US.
|
||||
profanity_filter: Whether to filter profanity from results.
|
||||
automatic_punctuation: Whether to add automatic punctuation.
|
||||
verbatim_transcripts: Whether to return verbatim transcripts.
|
||||
boosted_lm_words: List of words to boost in language model.
|
||||
boosted_lm_score: Score boost for specified words.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN_US
|
||||
profanity_filter: bool = False
|
||||
automatic_punctuation: bool = True
|
||||
verbatim_transcripts: bool = False
|
||||
boosted_lm_words: Optional[List[str]] = None
|
||||
boosted_lm_score: float = 4.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
model_function_map: Mapping[str, str] = {
|
||||
"function_id": "ee8dc628-76de-4acc-8595-1836e7e857bd",
|
||||
"model_name": "canary-1b-asr",
|
||||
},
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the NVIDIA Riva segmented STT service.
|
||||
|
||||
Args:
|
||||
api_key: NVIDIA API key for authentication
|
||||
server: NVIDIA Riva server address (defaults to NVIDIA Cloud Function endpoint)
|
||||
model_function_map: Mapping of model name and its corresponding NVIDIA Cloud Function ID
|
||||
sample_rate: Audio sample rate in Hz. If not provided, uses the pipeline's rate
|
||||
params: Additional configuration parameters for NVIDIA Riva
|
||||
**kwargs: Additional arguments passed to SegmentedSTTService
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
params = params or NvidiaSegmentedSTTService.InputParams()
|
||||
|
||||
# Set model name
|
||||
self.set_model_name(model_function_map.get("model_name"))
|
||||
|
||||
# Initialize NVIDIA Riva settings
|
||||
self._api_key = api_key
|
||||
self._server = server
|
||||
self._function_id = model_function_map.get("function_id")
|
||||
self._model_name = model_function_map.get("model_name")
|
||||
|
||||
# Store the language as a Language enum and as a string
|
||||
self._language_enum = params.language or Language.EN_US
|
||||
self._language = self.language_to_service_language(self._language_enum) or "en-US"
|
||||
|
||||
# Configure transcription parameters
|
||||
self._profanity_filter = params.profanity_filter
|
||||
self._automatic_punctuation = params.automatic_punctuation
|
||||
self._verbatim_transcripts = params.verbatim_transcripts
|
||||
self._boosted_lm_words = params.boosted_lm_words
|
||||
self._boosted_lm_score = params.boosted_lm_score
|
||||
|
||||
# Voice activity detection thresholds (use NVIDIA Riva defaults)
|
||||
self._start_history = -1
|
||||
self._start_threshold = -1.0
|
||||
self._stop_history = -1
|
||||
self._stop_threshold = -1.0
|
||||
self._stop_history_eou = -1
|
||||
self._stop_threshold_eou = -1.0
|
||||
self._custom_configuration = ""
|
||||
|
||||
# Create NVIDIA Riva client
|
||||
self._config = None
|
||||
self._asr_service = None
|
||||
self._settings = {"language": self._language_enum}
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert pipecat Language enum to NVIDIA Riva's language code.
|
||||
|
||||
Args:
|
||||
language: Language enum value.
|
||||
|
||||
Returns:
|
||||
NVIDIA Riva language code or None if not supported.
|
||||
"""
|
||||
return language_to_nvidia_riva_language(language)
|
||||
|
||||
def _initialize_client(self):
|
||||
"""Initialize the NVIDIA Riva ASR client with authentication metadata."""
|
||||
if self._asr_service is not None:
|
||||
return
|
||||
|
||||
# Set up authentication metadata for NVIDIA Cloud Functions
|
||||
metadata = [
|
||||
["function-id", self._function_id],
|
||||
["authorization", f"Bearer {self._api_key}"],
|
||||
]
|
||||
|
||||
# Create authenticated client
|
||||
auth = riva.client.Auth(None, True, self._server, metadata)
|
||||
self._asr_service = riva.client.ASRService(auth)
|
||||
|
||||
logger.info(f"Initialized NvidiaSegmentedSTTService with model: {self.model_name}")
|
||||
|
||||
def _create_recognition_config(self):
|
||||
"""Create the NVIDIA Riva ASR recognition configuration."""
|
||||
# Create base configuration
|
||||
config = riva.client.RecognitionConfig(
|
||||
language_code=self._language, # Now using the string, not a tuple
|
||||
max_alternatives=1,
|
||||
profanity_filter=self._profanity_filter,
|
||||
enable_automatic_punctuation=self._automatic_punctuation,
|
||||
verbatim_transcripts=self._verbatim_transcripts,
|
||||
)
|
||||
|
||||
# Add word boosting if specified
|
||||
if self._boosted_lm_words:
|
||||
riva.client.add_word_boosting_to_config(
|
||||
config, self._boosted_lm_words, self._boosted_lm_score
|
||||
)
|
||||
|
||||
# Add voice activity detection parameters
|
||||
riva.client.add_endpoint_parameters_to_config(
|
||||
config,
|
||||
self._start_history,
|
||||
self._start_threshold,
|
||||
self._stop_history,
|
||||
self._stop_history_eou,
|
||||
self._stop_threshold,
|
||||
self._stop_threshold_eou,
|
||||
)
|
||||
|
||||
# Add any custom configuration
|
||||
if self._custom_configuration:
|
||||
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
|
||||
|
||||
return config
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
True - this service supports metrics generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Set the ASR model for transcription.
|
||||
|
||||
Args:
|
||||
model: Model name to set.
|
||||
|
||||
Note:
|
||||
Model cannot be changed after initialization. Use model_function_map
|
||||
parameter in constructor instead.
|
||||
"""
|
||||
logger.warning("Cannot set model after initialization. Set model and function id like so:")
|
||||
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
|
||||
)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Initialize the service when the pipeline starts.
|
||||
|
||||
Args:
|
||||
frame: StartFrame indicating pipeline start.
|
||||
"""
|
||||
await super().start(frame)
|
||||
self._initialize_client()
|
||||
self._config = self._create_recognition_config()
|
||||
|
||||
async def set_language(self, language: Language):
|
||||
"""Set the language for the STT service.
|
||||
|
||||
Args:
|
||||
language: Target language for transcription.
|
||||
"""
|
||||
logger.info(f"Switching STT language to: [{language}]")
|
||||
self._language_enum = language
|
||||
self._language = self.language_to_service_language(language) or "en-US"
|
||||
self._settings["language"] = language
|
||||
|
||||
# Update configuration with new language
|
||||
if self._config:
|
||||
self._config.language_code = self._language
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(
|
||||
self, transcript: str, is_final: bool, language: Optional[Language] = None
|
||||
):
|
||||
"""Handle a transcription result with tracing."""
|
||||
pass
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Transcribe an audio segment.
|
||||
|
||||
Args:
|
||||
audio: Raw audio bytes in WAV format (already converted by base class).
|
||||
|
||||
Yields:
|
||||
Frame: TranscriptionFrame containing the transcribed text.
|
||||
"""
|
||||
try:
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Make sure the client is initialized
|
||||
if self._asr_service is None:
|
||||
self._initialize_client()
|
||||
|
||||
# Make sure the config is created
|
||||
if self._config is None:
|
||||
self._config = self._create_recognition_config()
|
||||
|
||||
# Type assertion to satisfy the IDE
|
||||
assert self._asr_service is not None, "ASR service not initialized"
|
||||
assert self._config is not None, "Recognition config not created"
|
||||
|
||||
# Process audio with NVIDIA Riva ASR - explicitly request non-future response
|
||||
raw_response = self._asr_service.offline_recognize(audio, self._config, future=False)
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
# Process the response - handle different possible return types
|
||||
try:
|
||||
# If it's a future-like object, get the result
|
||||
if hasattr(raw_response, "result"):
|
||||
response = raw_response.result()
|
||||
else:
|
||||
response = raw_response
|
||||
|
||||
# Process transcription results
|
||||
transcription_found = False
|
||||
|
||||
# Now we can safely check results
|
||||
# Type hint for the IDE
|
||||
results = getattr(response, "results", [])
|
||||
|
||||
for result in results:
|
||||
alternatives = getattr(result, "alternatives", [])
|
||||
if alternatives:
|
||||
text = alternatives[0].transcript.strip()
|
||||
if text:
|
||||
logger.debug(f"Transcription: [{text}]")
|
||||
yield TranscriptionFrame(
|
||||
text,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
self._language_enum,
|
||||
)
|
||||
transcription_found = True
|
||||
|
||||
await self._handle_transcription(text, True, self._language_enum)
|
||||
|
||||
if not transcription_found:
|
||||
logger.debug("No transcription results found in NVIDIA Riva response")
|
||||
|
||||
except AttributeError as ae:
|
||||
logger.error(f"Unexpected response structure from NVIDIA Riva: {ae}")
|
||||
yield ErrorFrame(f"Unexpected NVIDIA Riva response format: {str(ae)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
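The response handling in `run_stt` above reads `results` and `alternatives` defensively with `getattr` and skips empty transcripts. A minimal standalone sketch of that extraction step follows. The helper name `extract_transcripts` and the `SimpleNamespace` stand-ins are hypothetical and are not part of the Riva client or of this change; they only mimic the attribute shape the method expects.

```python
from types import SimpleNamespace
from typing import Any, List


def extract_transcripts(response: Any) -> List[str]:
    """Collect non-empty top-alternative transcripts from a response-like object.

    Hypothetical helper: mirrors the getattr-based checks used in run_stt,
    but returns plain strings instead of yielding frames.
    """
    transcripts: List[str] = []
    # A response without a `results` attribute is treated as empty.
    for result in getattr(response, "results", []):
        alternatives = getattr(result, "alternatives", [])
        if not alternatives:
            continue
        # Only the first (best) alternative is used, as with max_alternatives=1.
        text = alternatives[0].transcript.strip()
        if text:
            transcripts.append(text)
    return transcripts


# Stand-in objects shaped like a recognition response (assumption, for illustration only).
fake_response = SimpleNamespace(
    results=[
        SimpleNamespace(alternatives=[SimpleNamespace(transcript="  hello world ")]),
        SimpleNamespace(alternatives=[]),
        SimpleNamespace(alternatives=[SimpleNamespace(transcript="   ")]),
    ]
)

print(extract_transcripts(fake_response))
```

Keeping the extraction in a pure function like this would make the empty-result and whitespace-only cases easy to unit test without a gRPC connection.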
187
src/pipecat/services/nvidia/tts.py
Normal file
@@ -0,0 +1,187 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""NVIDIA Riva text-to-speech service implementation.
|
||||
|
||||
This module provides integration with NVIDIA Riva's TTS services through
|
||||
gRPC API for high-quality speech synthesis.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import AsyncGenerator, Mapping, Optional
|
||||
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
# Suppress gRPC fork warnings
|
||||
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
try:
|
||||
import riva.client
|
||||
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use NVIDIA Riva TTS, you need to `pip install pipecat-ai[nvidia]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
NVIDIA_TTS_TIMEOUT_SECS = 5
|
||||
|
||||
|
||||
class NvidiaTTSService(TTSService):
|
||||
"""NVIDIA Riva text-to-speech service.
|
||||
|
||||
Provides high-quality text-to-speech synthesis using NVIDIA Riva's
|
||||
cloud-based TTS models. Supports multiple voices, languages, and
|
||||
configurable quality settings.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Input parameters for Riva TTS configuration.
|
||||
|
||||
Parameters:
|
||||
language: Language code for synthesis. Defaults to US English.
|
||||
quality: Audio quality setting (0-100). Defaults to 20.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN_US
|
||||
quality: Optional[int] = 20
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
voice_id: str = "Magpie-Multilingual.EN-US.Aria",
|
||||
sample_rate: Optional[int] = None,
|
||||
model_function_map: Mapping[str, str] = {
|
||||
"function_id": "877104f7-e885-42b9-8de8-f6e4c6303969",
|
||||
"model_name": "magpie-tts-multilingual",
|
||||
},
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the NVIDIA Riva TTS service.
|
||||
|
||||
Args:
|
||||
api_key: NVIDIA API key for authentication.
|
||||
server: gRPC server endpoint. Defaults to NVIDIA's cloud endpoint.
|
||||
voice_id: Voice model identifier. Defaults to the multilingual Aria voice.
|
||||
sample_rate: Audio sample rate. If None, uses service default.
|
||||
model_function_map: Dictionary containing function_id and model_name for the TTS model.
|
||||
params: Additional configuration parameters for TTS synthesis.
|
||||
**kwargs: Additional arguments passed to parent TTSService.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
params = params or NvidiaTTSService.InputParams()
|
||||
|
||||
self._api_key = api_key
|
||||
self._voice_id = voice_id
|
||||
self._language_code = params.language
|
||||
self._quality = params.quality
|
||||
self._function_id = model_function_map.get("function_id")
|
||||
|
||||
self.set_model_name(model_function_map.get("model_name"))
|
||||
self.set_voice(voice_id)
|
||||
|
||||
metadata = [
|
||||
["function-id", self._function_id],
|
||||
["authorization", f"Bearer {api_key}"],
|
||||
]
|
||||
auth = riva.client.Auth(None, True, server, metadata)
|
||||
|
||||
self._service = riva.client.SpeechSynthesisService(auth)
|
||||
|
||||
# warm up the service
|
||||
config_response = self._service.stub.GetRivaSynthesisConfig(
|
||||
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
|
||||
)
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Attempt to set the TTS model.
|
||||
|
||||
Note: Model cannot be changed after initialization for Riva service.
|
||||
|
||||
Args:
|
||||
model: The model name to set (operation not supported).
|
||||
"""
|
||||
logger.warning("Cannot set model after initialization. Set model and function id like so:")
|
||||
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
|
||||
)
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using NVIDIA Riva TTS.
|
||||
|
||||
Args:
|
||||
text: The text to synthesize into speech.
|
||||
|
||||
Yields:
|
||||
Frame: Audio frames containing the synthesized speech data.
|
||||
"""
|
||||
|
||||
def read_audio_responses(queue: asyncio.Queue):
|
||||
def add_response(r):
|
||||
asyncio.run_coroutine_threadsafe(queue.put(r), self.get_event_loop())
|
||||
|
||||
try:
|
||||
responses = self._service.synthesize_online(
|
||||
text,
|
||||
self._voice_id,
|
||||
self._language_code,
|
||||
sample_rate_hz=self.sample_rate,
|
||||
zero_shot_audio_prompt_file=None,
|
||||
zero_shot_quality=self._quality,
|
||||
custom_dictionary={},
|
||||
)
|
||||
for r in responses:
|
||||
add_response(r)
|
||||
add_response(None)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
add_response(None)
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
queue = asyncio.Queue()
|
||||
await asyncio.to_thread(read_audio_responses, queue)
|
||||
|
||||
# The worker thread has finished at this point; drain the queued responses.
|
||||
resp = await asyncio.wait_for(queue.get(), timeout=NVIDIA_TTS_TIMEOUT_SECS)
|
||||
while resp:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=resp.audio,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
yield frame
|
||||
resp = await asyncio.wait_for(queue.get(), timeout=NVIDIA_TTS_TIMEOUT_SECS)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"{self} timeout waiting for audio response")
|
||||
yield ErrorFrame(error=f"{self} error: timeout waiting for audio response")
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStoppedFrame()
|
||||
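`run_tts` above bridges a blocking gRPC iterator into async code by pushing responses onto an `asyncio.Queue` from a worker thread, ending with a `None` sentinel and reading with a timeout. The sketch below shows that bridging pattern in isolation. It is a variation rather than the code in this change: the worker is started as a background task so chunks can be consumed while the thread is still producing, whereas the diff awaits the thread before reading. The name `stream_from_thread` and its parameters are hypothetical.

```python
import asyncio
from typing import AsyncIterator, Callable, Iterable


async def stream_from_thread(
    produce: Callable[[], Iterable[bytes]],
    timeout_secs: float = 5.0,
) -> AsyncIterator[bytes]:
    """Yield chunks from a blocking iterable that runs in a worker thread."""
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()

    def worker() -> None:
        # Runs in a thread: hand each chunk to the event loop thread-safely.
        try:
            for chunk in produce():
                asyncio.run_coroutine_threadsafe(queue.put(chunk), loop)
        finally:
            # The sentinel tells the consumer that no more chunks will arrive,
            # even if the producer raised.
            asyncio.run_coroutine_threadsafe(queue.put(None), loop)

    # Start the worker without awaiting it, so reading overlaps with producing.
    task = asyncio.create_task(asyncio.to_thread(worker))
    try:
        while True:
            chunk = await asyncio.wait_for(queue.get(), timeout=timeout_secs)
            if chunk is None:
                break
            yield chunk
    finally:
        # Surface any exception raised inside the worker thread.
        await task


async def main() -> None:
    async for chunk in stream_from_thread(lambda: [b"a", b"b", b"c"]):
        print(chunk)


asyncio.run(main())
```

One trade-off to be aware of: on a timeout, the `finally` clause still waits for the worker thread to finish, because a running thread cannot be cancelled from asyncio.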
@@ -4,707 +4,32 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""NVIDIA Riva Speech-to-Text service implementations for real-time and batch transcription."""
|
||||
"""NVIDIA Riva Speech-to-Text service implementations for real-time and batch transcription.
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import CancelledError as FuturesCancelledError
|
||||
from typing import AsyncGenerator, List, Mapping, Optional
|
||||
.. deprecated:: 0.0.96
|
||||
This module is deprecated. Please use NvidiaSTTService from
|
||||
pipecat.services.nvidia.stt instead.
|
||||
"""
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
import warnings
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
from pipecat.services.nvidia.stt import (
|
||||
NvidiaSegmentedSTTService,
|
||||
NvidiaSTTService,
|
||||
language_to_nvidia_riva_language,
|
||||
)
|
||||
from pipecat.services.stt_service import SegmentedSTTService, STTService
|
||||
from pipecat.transcriptions.language import Language, resolve_language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
try:
|
||||
import riva.client
|
||||
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use NVIDIA Riva STT, you need to `pip install pipecat-ai[riva]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
def language_to_riva_language(language: Language) -> Optional[str]:
|
||||
"""Maps Language enum to Riva ASR language codes.
|
||||
|
||||
Source:
|
||||
https://docs.nvidia.com/deeplearning/riva/user-guide/docs/asr/asr-riva-build-table.html?highlight=fr%20fr
|
||||
|
||||
Args:
|
||||
language: Language enum value.
|
||||
|
||||
Returns:
|
||||
Optional[str]: Riva language code or None if not supported.
|
||||
"""
|
||||
LANGUAGE_MAP = {
|
||||
# Arabic
|
||||
Language.AR: "ar-AR",
|
||||
# English
|
||||
Language.EN: "en-US", # Default to US
|
||||
Language.EN_US: "en-US",
|
||||
Language.EN_GB: "en-GB",
|
||||
# French
|
||||
Language.FR: "fr-FR",
|
||||
Language.FR_FR: "fr-FR",
|
||||
# German
|
||||
Language.DE: "de-DE",
|
||||
Language.DE_DE: "de-DE",
|
||||
# Hindi
|
||||
Language.HI: "hi-IN",
|
||||
Language.HI_IN: "hi-IN",
|
||||
# Italian
|
||||
Language.IT: "it-IT",
|
||||
Language.IT_IT: "it-IT",
|
||||
# Japanese
|
||||
Language.JA: "ja-JP",
|
||||
Language.JA_JP: "ja-JP",
|
||||
# Korean
|
||||
Language.KO: "ko-KR",
|
||||
Language.KO_KR: "ko-KR",
|
||||
# Portuguese
|
||||
Language.PT: "pt-BR", # Default to Brazilian
|
||||
Language.PT_BR: "pt-BR",
|
||||
# Russian
|
||||
Language.RU: "ru-RU",
|
||||
Language.RU_RU: "ru-RU",
|
||||
# Spanish
|
||||
Language.ES: "es-ES", # Default to Spain
|
||||
Language.ES_ES: "es-ES",
|
||||
Language.ES_US: "es-US", # US Spanish
|
||||
}
|
||||
|
||||
return resolve_language(language, LANGUAGE_MAP, use_base_code=False)
|
||||
|
||||
|
||||
class RivaSTTService(STTService):
|
||||
"""Real-time speech-to-text service using NVIDIA Riva streaming ASR.
|
||||
|
||||
Provides real-time transcription capabilities using NVIDIA's Riva ASR models
|
||||
through streaming recognition. Supports interim results and continuous audio
|
||||
processing for low-latency applications.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Configuration parameters for Riva STT service.
|
||||
|
||||
Parameters:
|
||||
language: Target language for transcription. Defaults to EN_US.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN_US
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
model_function_map: Mapping[str, str] = {
|
||||
"function_id": "1598d209-5e27-4d3c-8079-4751568b1081",
|
||||
"model_name": "parakeet-ctc-1.1b-asr",
|
||||
},
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Riva STT service.
|
||||
|
||||
Args:
|
||||
api_key: NVIDIA API key for authentication.
|
||||
server: Riva server address. Defaults to NVIDIA Cloud Function endpoint.
|
||||
model_function_map: Mapping containing 'function_id' and 'model_name' for the ASR model.
|
||||
sample_rate: Audio sample rate in Hz. If None, uses pipeline default.
|
||||
params: Additional configuration parameters for Riva.
|
||||
**kwargs: Additional arguments passed to STTService.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
params = params or RivaSTTService.InputParams()
|
||||
|
||||
self._api_key = api_key
|
||||
self._profanity_filter = False
|
||||
self._automatic_punctuation = True
|
||||
self._no_verbatim_transcripts = False
|
||||
self._language_code = params.language
|
||||
self._boosted_lm_words = None
|
||||
self._boosted_lm_score = 4.0
|
||||
self._start_history = -1
|
||||
self._start_threshold = -1.0
|
||||
self._stop_history = -1
|
||||
self._stop_threshold = -1.0
|
||||
self._stop_history_eou = -1
|
||||
self._stop_threshold_eou = -1.0
|
||||
self._custom_configuration = ""
|
||||
self._function_id = model_function_map.get("function_id")
|
||||
|
||||
self._settings = {
|
||||
"language": str(params.language),
|
||||
"profanity_filter": self._profanity_filter,
|
||||
"automatic_punctuation": self._automatic_punctuation,
|
||||
"verbatim_transcripts": not self._no_verbatim_transcripts,
|
||||
"boosted_lm_words": self._boosted_lm_words,
|
||||
"boosted_lm_score": self._boosted_lm_score,
|
||||
}
|
||||
|
||||
self.set_model_name(model_function_map.get("model_name"))
|
||||
|
||||
metadata = [
|
||||
["function-id", self._function_id],
|
||||
["authorization", f"Bearer {api_key}"],
|
||||
]
|
||||
auth = riva.client.Auth(None, True, server, metadata)
|
||||
|
||||
self._asr_service = riva.client.ASRService(auth)
|
||||
|
||||
self._queue = None
|
||||
self._config = None
|
||||
self._thread_task = None
|
||||
self._response_task = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
False - this service does not support metrics generation.
|
||||
"""
|
||||
return False
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Set the ASR model for transcription.
|
||||
|
||||
Args:
|
||||
model: Model name to set.
|
||||
|
||||
Note:
|
||||
Model cannot be changed after initialization. Use model_function_map
|
||||
parameter in constructor instead.
|
||||
"""
|
||||
logger.warning(f"Cannot set model after initialization. Set model and function id like so:")
|
||||
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
|
||||
)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the Riva STT service and initialize streaming configuration.
|
||||
|
||||
Args:
|
||||
frame: StartFrame indicating pipeline start.
|
||||
"""
|
||||
await super().start(frame)
|
||||
|
||||
if self._config:
|
||||
return
|
||||
|
||||
config = riva.client.StreamingRecognitionConfig(
|
||||
config=riva.client.RecognitionConfig(
|
||||
encoding=riva.client.AudioEncoding.LINEAR_PCM,
|
||||
language_code=self._language_code,
|
||||
model="",
|
||||
max_alternatives=1,
|
||||
profanity_filter=self._profanity_filter,
|
||||
enable_automatic_punctuation=self._automatic_punctuation,
|
||||
verbatim_transcripts=not self._no_verbatim_transcripts,
|
||||
sample_rate_hertz=self.sample_rate,
|
||||
audio_channel_count=1,
|
||||
),
|
||||
interim_results=True,
|
||||
)
|
||||
|
||||
riva.client.add_word_boosting_to_config(
|
||||
config, self._boosted_lm_words, self._boosted_lm_score
|
||||
)
|
||||
|
||||
riva.client.add_endpoint_parameters_to_config(
|
||||
config,
|
||||
self._start_history,
|
||||
self._start_threshold,
|
||||
self._stop_history,
|
||||
self._stop_history_eou,
|
||||
self._stop_threshold,
|
||||
self._stop_threshold_eou,
|
||||
)
|
||||
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
|
||||
|
||||
self._config = config
|
||||
self._queue = asyncio.Queue()
|
||||
|
||||
if not self._thread_task:
|
||||
self._thread_task = self.create_task(self._thread_task_handler())
|
||||
|
||||
if not self._response_task:
|
||||
self._response_queue = asyncio.Queue()
|
||||
self._response_task = self.create_task(self._response_task_handler())
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the Riva STT service and clean up resources.
|
||||
|
||||
Args:
|
||||
frame: EndFrame indicating pipeline stop.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._stop_tasks()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the Riva STT service operation.
|
||||
|
||||
Args:
|
||||
frame: CancelFrame indicating operation cancellation.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._stop_tasks()
|
||||
|
||||
async def _stop_tasks(self):
|
||||
if self._thread_task:
|
||||
await self.cancel_task(self._thread_task)
|
||||
self._thread_task = None
|
||||
|
||||
if self._response_task:
|
||||
await self.cancel_task(self._response_task)
|
||||
self._response_task = None
|
||||
|
||||
def _response_handler(self):
|
||||
responses = self._asr_service.streaming_response_generator(
|
||||
audio_chunks=self,
|
||||
streaming_config=self._config,
|
||||
)
|
||||
for response in responses:
|
||||
if not response.results:
|
||||
continue
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._response_queue.put(response), self.get_event_loop()
|
||||
)
|
||||
|
||||
async def _thread_task_handler(self):
|
||||
try:
|
||||
self._thread_running = True
|
||||
await asyncio.to_thread(self._response_handler)
|
||||
except asyncio.CancelledError:
|
||||
self._thread_running = False
|
||||
raise
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(
|
||||
self, transcript: str, is_final: bool, language: Optional[Language] = None
|
||||
):
|
||||
"""Handle a transcription result with tracing."""
|
||||
pass
|
||||
|
||||
async def _handle_response(self, response):
|
||||
for result in response.results:
|
||||
if result and not result.alternatives:
|
||||
continue
|
||||
|
||||
transcript = result.alternatives[0].transcript
|
||||
if transcript and len(transcript) > 0:
|
||||
await self.stop_ttfb_metrics()
|
||||
if result.is_final:
|
||||
await self.stop_processing_metrics()
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
self._language_code,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
await self._handle_transcription(
|
||||
transcript=transcript,
|
||||
is_final=result.is_final,
|
||||
language=self._language_code,
|
||||
)
|
||||
else:
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
self._language_code,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
|
||||
async def _response_task_handler(self):
|
||||
while True:
|
||||
response = await self._response_queue.get()
|
||||
await self._handle_response(response)
|
||||
self._response_queue.task_done()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Process audio data for speech-to-text transcription.
|
||||
|
||||
Args:
|
||||
audio: Raw audio bytes to transcribe.
|
||||
|
||||
Yields:
|
||||
None - transcription results are pushed to the pipeline via frames.
|
||||
"""
|
||||
await self.start_ttfb_metrics()
|
||||
await self.start_processing_metrics()
|
||||
await self._queue.put(audio)
|
||||
yield None
|
||||
|
||||
def __next__(self) -> bytes:
|
||||
"""Get the next audio chunk for Riva processing.
|
||||
|
||||
Returns:
|
||||
Audio bytes from the queue.
|
||||
|
||||
Raises:
|
||||
StopIteration: When the thread is no longer running.
|
||||
"""
|
||||
if not self._thread_running:
|
||||
raise StopIteration
|
||||
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(self._queue.get(), self.get_event_loop())
|
||||
return future.result()
|
||||
except FuturesCancelledError:
|
||||
raise StopIteration
|
||||
|
||||
def __iter__(self):
|
||||
"""Return iterator for audio chunk processing.
|
||||
|
||||
Returns:
|
||||
Self as iterator.
|
||||
"""
|
||||
return self
|
||||
|
||||
|
||||
class RivaSegmentedSTTService(SegmentedSTTService):
|
||||
"""Speech-to-text service using NVIDIA Riva's offline/batch models.
|
||||
|
||||
By default, this service uses NVIDIA's Riva Canary ASR API to perform speech-to-text
|
||||
transcription on audio segments. It inherits from SegmentedSTTService to handle
|
||||
audio buffering and speech detection.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Configuration parameters for Riva segmented STT service.
|
||||
|
||||
Parameters:
|
||||
language: Target language for transcription. Defaults to EN_US.
|
||||
profanity_filter: Whether to filter profanity from results.
|
||||
automatic_punctuation: Whether to add automatic punctuation.
|
||||
verbatim_transcripts: Whether to return verbatim transcripts.
|
||||
boosted_lm_words: List of words to boost in language model.
|
||||
boosted_lm_score: Score boost for specified words.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN_US
|
||||
profanity_filter: bool = False
|
||||
automatic_punctuation: bool = True
|
||||
verbatim_transcripts: bool = False
|
||||
boosted_lm_words: Optional[List[str]] = None
|
||||
boosted_lm_score: float = 4.0
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
model_function_map: Mapping[str, str] = {
|
||||
"function_id": "ee8dc628-76de-4acc-8595-1836e7e857bd",
|
||||
"model_name": "canary-1b-asr",
|
||||
},
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Riva segmented STT service.
|
||||
|
||||
Args:
|
||||
api_key: NVIDIA API key for authentication
|
||||
server: Riva server address (defaults to NVIDIA Cloud Function endpoint)
|
||||
model_function_map: Mapping of model name and its corresponding NVIDIA Cloud Function ID
|
||||
sample_rate: Audio sample rate in Hz. If not provided, uses the pipeline's rate
|
||||
params: Additional configuration parameters for Riva
|
||||
**kwargs: Additional arguments passed to SegmentedSTTService
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
params = params or RivaSegmentedSTTService.InputParams()
|
||||
|
||||
# Set model name
|
||||
self.set_model_name(model_function_map.get("model_name"))
|
||||
|
||||
# Initialize Riva settings
|
||||
self._api_key = api_key
|
||||
self._server = server
|
||||
self._function_id = model_function_map.get("function_id")
|
||||
self._model_name = model_function_map.get("model_name")
|
||||
|
||||
# Store the language as a Language enum and as a string
|
||||
self._language_enum = params.language or Language.EN_US
|
||||
self._language = self.language_to_service_language(self._language_enum) or "en-US"
|
||||
|
||||
# Configure transcription parameters
|
||||
self._profanity_filter = params.profanity_filter
|
||||
self._automatic_punctuation = params.automatic_punctuation
|
||||
self._verbatim_transcripts = params.verbatim_transcripts
|
||||
self._boosted_lm_words = params.boosted_lm_words
|
||||
self._boosted_lm_score = params.boosted_lm_score
|
||||
|
||||
# Voice activity detection thresholds (use Riva defaults)
|
||||
self._start_history = -1
|
||||
self._start_threshold = -1.0
|
||||
self._stop_history = -1
|
||||
self._stop_threshold = -1.0
|
||||
self._stop_history_eou = -1
|
||||
self._stop_threshold_eou = -1.0
|
||||
self._custom_configuration = ""
|
||||
|
||||
# Create Riva client
|
||||
self._config = None
|
||||
self._asr_service = None
|
||||
self._settings = {"language": self._language_enum}
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert pipecat Language enum to Riva's language code.
|
||||
|
||||
Args:
|
||||
language: Language enum value.
|
||||
|
||||
Returns:
|
||||
Riva language code or None if not supported.
|
||||
"""
|
||||
return language_to_riva_language(language)
|
||||
|
||||
def _initialize_client(self):
|
||||
"""Initialize the Riva ASR client with authentication metadata."""
|
||||
if self._asr_service is not None:
|
||||
return
|
||||
|
||||
# Set up authentication metadata for NVIDIA Cloud Functions
|
||||
metadata = [
|
||||
["function-id", self._function_id],
|
||||
["authorization", f"Bearer {self._api_key}"],
|
||||
]
|
||||
|
||||
# Create authenticated client
|
||||
auth = riva.client.Auth(None, True, self._server, metadata)
|
||||
self._asr_service = riva.client.ASRService(auth)
|
||||
|
||||
logger.info(f"Initialized RivaSegmentedSTTService with model: {self.model_name}")
|
||||
|
||||
def _create_recognition_config(self):
|
||||
"""Create the Riva ASR recognition configuration."""
|
||||
# Create base configuration
|
||||
config = riva.client.RecognitionConfig(
|
||||
language_code=self._language, # Now using the string, not a tuple
|
||||
max_alternatives=1,
|
||||
profanity_filter=self._profanity_filter,
|
||||
enable_automatic_punctuation=self._automatic_punctuation,
|
||||
verbatim_transcripts=self._verbatim_transcripts,
|
||||
)
|
||||
|
||||
# Add word boosting if specified
|
||||
if self._boosted_lm_words:
|
||||
riva.client.add_word_boosting_to_config(
|
||||
config, self._boosted_lm_words, self._boosted_lm_score
|
||||
)
|
||||
|
||||
# Add voice activity detection parameters
|
||||
riva.client.add_endpoint_parameters_to_config(
|
||||
config,
|
||||
self._start_history,
|
||||
self._start_threshold,
|
||||
self._stop_history,
|
||||
self._stop_history_eou,
|
||||
self._stop_threshold,
|
||||
self._stop_threshold_eou,
|
||||
)
|
||||
|
||||
# Add any custom configuration
|
||||
if self._custom_configuration:
|
||||
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
|
||||
|
||||
return config
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
True - this service supports metrics generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Set the ASR model for transcription.
|
||||
|
||||
Args:
|
||||
model: Model name to set.
|
||||
|
||||
Note:
|
||||
Model cannot be changed after initialization. Use model_function_map
|
||||
parameter in constructor instead.
|
||||
"""
|
||||
logger.warning(f"Cannot set model after initialization. Set model and function id like so:")
|
||||
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
|
||||
)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Initialize the service when the pipeline starts.
|
||||
|
||||
Args:
|
||||
frame: StartFrame indicating pipeline start.
|
||||
"""
|
||||
await super().start(frame)
|
||||
self._initialize_client()
|
||||
self._config = self._create_recognition_config()
|
||||
|
||||
async def set_language(self, language: Language):
|
||||
"""Set the language for the STT service.
|
||||
|
||||
Args:
|
||||
language: Target language for transcription.
|
||||
"""
|
||||
logger.info(f"Switching STT language to: [{language}]")
|
||||
self._language_enum = language
|
||||
self._language = self.language_to_service_language(language) or "en-US"
|
||||
self._settings["language"] = language
|
||||
|
||||
# Update configuration with new language
|
||||
if self._config:
|
||||
self._config.language_code = self._language
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(
|
||||
self, transcript: str, is_final: bool, language: Optional[Language] = None
|
||||
):
|
||||
"""Handle a transcription result with tracing."""
|
||||
pass
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Transcribe an audio segment.
|
||||
|
||||
Args:
|
||||
audio: Raw audio bytes in WAV format (already converted by base class).
|
||||
|
||||
Yields:
|
||||
Frame: TranscriptionFrame containing the transcribed text.
|
||||
"""
|
||||
try:
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Make sure the client is initialized
|
||||
if self._asr_service is None:
|
||||
self._initialize_client()
|
||||
|
||||
# Make sure the config is created
|
||||
if self._config is None:
|
||||
self._config = self._create_recognition_config()
|
||||
|
||||
# Type assertion to satisfy the IDE
|
||||
assert self._asr_service is not None, "ASR service not initialized"
|
||||
assert self._config is not None, "Recognition config not created"
|
||||
|
||||
# Process audio with Riva ASR - explicitly request non-future response
|
||||
raw_response = self._asr_service.offline_recognize(audio, self._config, future=False)
|
||||
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
# Process the response - handle different possible return types
|
||||
try:
|
||||
# If it's a future-like object, get the result
|
||||
if hasattr(raw_response, "result"):
|
||||
response = raw_response.result()
|
||||
else:
|
||||
response = raw_response
|
||||
|
||||
# Process transcription results
|
||||
transcription_found = False
|
||||
|
||||
# Now we can safely check results
|
||||
# Type hint for the IDE
|
||||
results = getattr(response, "results", [])
|
||||
|
||||
for result in results:
|
||||
alternatives = getattr(result, "alternatives", [])
|
||||
if alternatives:
|
||||
text = alternatives[0].transcript.strip()
|
||||
if text:
|
||||
logger.debug(f"Transcription: [{text}]")
|
||||
yield TranscriptionFrame(
|
||||
text,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
self._language_enum,
|
||||
)
|
||||
transcription_found = True
|
||||
|
||||
await self._handle_transcription(text, True, self._language_enum)
|
||||
|
||||
if not transcription_found:
|
||||
logger.debug("No transcription results found in Riva response")
|
||||
|
||||
except AttributeError as ae:
|
||||
yield ErrorFrame(f"Unexpected Riva response format: {str(ae)}")
|
||||
|
||||
except Exception as e:
|
||||
yield ErrorFrame(error=f"Unknown error occurred: {e}")
|
||||
|
||||
|
||||
class ParakeetSTTService(RivaSTTService):
|
||||
"""Deprecated speech-to-text service using NVIDIA Parakeet models.
|
||||
|
||||
.. deprecated:: 0.0.66
|
||||
This class is deprecated. Use `RivaSTTService` instead for equivalent functionality
|
||||
with Parakeet models by specifying the appropriate model_function_map.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
model_function_map: Mapping[str, str] = {
|
||||
"function_id": "1598d209-5e27-4d3c-8079-4751568b1081",
|
||||
"model_name": "parakeet-ctc-1.1b-asr",
|
||||
},
|
||||
sample_rate: Optional[int] = None,
|
||||
params: Optional[RivaSTTService.InputParams] = None, # Use parent class's type
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Parakeet STT service.
|
||||
|
||||
Args:
|
||||
api_key: NVIDIA API key for authentication.
|
||||
server: Riva server address. Defaults to NVIDIA Cloud Function endpoint.
|
||||
model_function_map: Mapping containing 'function_id' and 'model_name' for Parakeet model.
|
||||
sample_rate: Audio sample rate in Hz. If None, uses pipeline default.
|
||||
params: Additional configuration parameters for Riva.
|
||||
**kwargs: Additional arguments passed to RivaSTTService.
|
||||
"""
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
server=server,
|
||||
model_function_map=model_function_map,
|
||||
sample_rate=sample_rate,
|
||||
params=params,
|
||||
**kwargs,
|
||||
)
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"`ParakeetSTTService` is deprecated, use `RivaSTTService` instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"RivaSTTService and ParakeetSTTService "
|
||||
"from pipecat.services.riva.stt are deprecated. "
|
||||
"Please use NvidiaSTTService from pipecat.services.nvidia.stt instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
RivaSTTService = NvidiaSTTService
|
||||
language_to_riva_language = language_to_nvidia_riva_language
|
||||
RivaSegmentedSTTService = NvidiaSegmentedSTTService
|
||||
ParakeetSTTService = NvidiaSTTService
|
||||
|
||||
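The shim above keeps old import paths working by warning once and then re-exporting the new classes under their old names. A minimal sketch of that pattern follows; `NewService`, `OldService`, and `load_legacy_alias` are hypothetical names used only for illustration and are not part of Pipecat.

```python
import warnings


class NewService:
    """Hypothetical replacement class (stands in for the new service)."""


def load_legacy_alias():
    """Emulate importing a deprecated shim module: warn, then hand back the alias."""
    with warnings.catch_warnings():
        # "always" makes the warning visible even if DeprecationWarning
        # is hidden by the default filters.
        warnings.simplefilter("always")
        warnings.warn(
            "OldService is deprecated. Please use NewService instead.",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller, not this function
        )
    return NewService


# The old name is simply bound to the new class, so existing code keeps working.
OldService = load_legacy_alias()
```

Because the old name is a plain alias, `isinstance` checks and subclassing against it behave exactly as they do with the new class.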
@@ -8,231 +8,26 @@
|
||||
|
||||
This module provides integration with NVIDIA Riva's TTS services through
|
||||
gRPC API for high-quality speech synthesis.
|
||||
|
||||
.. deprecated:: 0.0.96
|
||||
This module is deprecated. Please use NvidiaTTSService from
|
||||
pipecat.services.nvidia.tts instead.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import AsyncGenerator, Mapping, Optional
|
||||
import warnings
|
||||
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
from pipecat.services.nvidia.tts import NVIDIA_TTS_TIMEOUT_SECS, NvidiaTTSService
|
||||
|
||||
# Suppress gRPC fork warnings
|
||||
os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"FastPitchTTSService and RivaTTSService "
|
||||
"from pipecat.services.riva.tts are deprecated. "
|
||||
"Please use NvidiaTTSService from pipecat.services.nvidia.tts instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.services.tts_service import TTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
|
||||
try:
|
||||
import riva.client
|
||||
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use NVIDIA Riva TTS, you need to `pip install pipecat-ai[riva]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
RIVA_TTS_TIMEOUT_SECS = 5
|
||||
|
||||
|
||||
class RivaTTSService(TTSService):
|
||||
"""NVIDIA Riva text-to-speech service.
|
||||
|
||||
Provides high-quality text-to-speech synthesis using NVIDIA Riva's
|
||||
cloud-based TTS models. Supports multiple voices, languages, and
|
||||
configurable quality settings.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Input parameters for Riva TTS configuration.
|
||||
|
||||
Parameters:
|
||||
language: Language code for synthesis. Defaults to US English.
|
||||
quality: Audio quality setting (0-100). Defaults to 20.
|
||||
"""
|
||||
|
||||
language: Optional[Language] = Language.EN_US
|
||||
quality: Optional[int] = 20
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
voice_id: str = "Magpie-Multilingual.EN-US.Aria",
|
||||
sample_rate: Optional[int] = None,
|
||||
model_function_map: Mapping[str, str] = {
|
||||
"function_id": "877104f7-e885-42b9-8de8-f6e4c6303969",
|
||||
"model_name": "magpie-tts-multilingual",
|
||||
},
|
||||
params: Optional[InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the NVIDIA Riva TTS service.
|
||||
|
||||
Args:
|
||||
api_key: NVIDIA API key for authentication.
|
||||
server: gRPC server endpoint. Defaults to NVIDIA's cloud endpoint.
|
||||
voice_id: Voice model identifier. Defaults to the Magpie multilingual Aria voice.
|
||||
sample_rate: Audio sample rate. If None, uses service default.
|
||||
model_function_map: Dictionary containing function_id and model_name for the TTS model.
|
||||
params: Additional configuration parameters for TTS synthesis.
|
||||
**kwargs: Additional arguments passed to parent TTSService.
|
||||
"""
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
|
||||
params = params or RivaTTSService.InputParams()
|
||||
|
||||
self._api_key = api_key
|
||||
self._voice_id = voice_id
|
||||
self._language_code = params.language
|
||||
self._quality = params.quality
|
||||
self._function_id = model_function_map.get("function_id")
|
||||
|
||||
self.set_model_name(model_function_map.get("model_name"))
|
||||
self.set_voice(voice_id)
|
||||
|
||||
metadata = [
|
||||
["function-id", self._function_id],
|
||||
["authorization", f"Bearer {api_key}"],
|
||||
]
|
||||
auth = riva.client.Auth(None, True, server, metadata)
|
||||
|
||||
self._service = riva.client.SpeechSynthesisService(auth)
|
||||
|
||||
# warm up the service
|
||||
config_response = self._service.stub.GetRivaSynthesisConfig(
|
||||
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
|
||||
)
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Attempt to set the TTS model.
|
||||
|
||||
Note: Model cannot be changed after initialization for Riva service.
|
||||
|
||||
Args:
|
||||
model: The model name to set (operation not supported).
|
||||
"""
|
||||
logger.warning("Cannot set model after initialization. Set model and function id like so:")
|
||||
example = {"function_id": "<UUID>", "model_name": "<model_name>"}
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
|
||||
)
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using NVIDIA Riva TTS.
|
||||
|
||||
Args:
|
||||
text: The text to synthesize into speech.
|
||||
|
||||
Yields:
|
||||
Frame: Audio frames containing the synthesized speech data.
|
||||
"""
|
||||
|
||||
def read_audio_responses(queue: asyncio.Queue):
|
||||
def add_response(r):
|
||||
asyncio.run_coroutine_threadsafe(queue.put(r), self.get_event_loop())
|
||||
|
||||
try:
|
||||
responses = self._service.synthesize_online(
|
||||
text,
|
||||
self._voice_id,
|
||||
self._language_code,
|
||||
sample_rate_hz=self.sample_rate,
|
||||
zero_shot_audio_prompt_file=None,
|
||||
zero_shot_quality=self._quality,
|
||||
custom_dictionary={},
|
||||
)
|
||||
for r in responses:
|
||||
add_response(r)
|
||||
add_response(None)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
add_response(None)
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
queue = asyncio.Queue()
|
||||
await asyncio.to_thread(read_audio_responses, queue)
|
||||
|
||||
# Wait for the thread to start.
|
||||
resp = await asyncio.wait_for(queue.get(), timeout=RIVA_TTS_TIMEOUT_SECS)
|
||||
while resp:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=resp.audio,
|
||||
sample_rate=self.sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
yield frame
|
||||
resp = await asyncio.wait_for(queue.get(), timeout=RIVA_TTS_TIMEOUT_SECS)
|
||||
except asyncio.TimeoutError:
|
||||
yield ErrorFrame(error=f"{self}: timed out waiting for TTS audio after {RIVA_TTS_TIMEOUT_SECS}s")
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
|
||||
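The `run_tts` method above bridges a blocking streaming call into asyncio by pushing each response onto an `asyncio.Queue` from a worker thread and ending the stream with a `None` sentinel. The sketch below shows that general technique in isolation. `blocking_chunks` is a hypothetical stand-in for the blocking client call, and unlike the method above it runs the worker concurrently with the consumer, so treat it as an illustration rather than the service's implementation.

```python
import asyncio
from typing import Iterable, List


def blocking_chunks() -> Iterable[bytes]:
    """Hypothetical stand-in for a blocking, streaming client call."""
    yield b"ab"
    yield b"cd"


async def stream_chunks(timeout_secs: float = 5.0) -> List[bytes]:
    loop = asyncio.get_running_loop()
    queue: asyncio.Queue = asyncio.Queue()

    def producer():
        # Runs in a worker thread; hand each item back to the event loop.
        try:
            for chunk in blocking_chunks():
                asyncio.run_coroutine_threadsafe(queue.put(chunk), loop)
        finally:
            # Sentinel: tells the consumer the stream is finished.
            asyncio.run_coroutine_threadsafe(queue.put(None), loop)

    worker = asyncio.create_task(asyncio.to_thread(producer))
    received = []
    while True:
        # A timeout guards against a producer that never delivers.
        item = await asyncio.wait_for(queue.get(), timeout=timeout_secs)
        if item is None:
            break
        received.append(item)
    await worker
    return received


chunks = asyncio.run(stream_chunks())
```

The sentinel is sent from a `finally` block so the consumer stops waiting even when the blocking call raises.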
class FastPitchTTSService(RivaTTSService):
|
||||
"""Deprecated FastPitch TTS service.
|
||||
|
||||
.. deprecated:: 0.0.66
|
||||
This class is deprecated. Use RivaTTSService instead for new implementations.
|
||||
Provides backward compatibility for existing FastPitch TTS integrations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
server: str = "grpc.nvcf.nvidia.com:443",
|
||||
voice_id: str = "English-US.Female-1",
|
||||
sample_rate: Optional[int] = None,
|
||||
model_function_map: Mapping[str, str] = {
|
||||
"function_id": "0149dedb-2be8-4195-b9a0-e57e0e14f972",
|
||||
"model_name": "fastpitch-hifigan-tts",
|
||||
},
|
||||
params: Optional[RivaTTSService.InputParams] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the deprecated FastPitch TTS service.
|
||||
|
||||
Args:
|
||||
api_key: NVIDIA API key for authentication.
|
||||
server: gRPC server endpoint. Defaults to NVIDIA's cloud endpoint.
|
||||
voice_id: Voice model identifier. Defaults to Female-1 voice.
|
||||
sample_rate: Audio sample rate. If None, uses service default.
|
||||
model_function_map: Dictionary containing function_id and model_name for FastPitch model.
|
||||
params: Additional configuration parameters for TTS synthesis.
|
||||
**kwargs: Additional arguments passed to parent RivaTTSService.
|
||||
"""
|
||||
super().__init__(
|
||||
api_key=api_key,
|
||||
server=server,
|
||||
voice_id=voice_id,
|
||||
sample_rate=sample_rate,
|
||||
model_function_map=model_function_map,
|
||||
params=params,
|
||||
**kwargs,
|
||||
)
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"`FastPitchTTSService` is deprecated, use `RivaTTSService` instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
RivaTTSService = NvidiaTTSService
|
||||
FastPitchTTSService = NvidiaTTSService
|
||||
RIVA_TTS_TIMEOUT_SECS = NVIDIA_TTS_TIMEOUT_SECS
|
||||
|
||||
@@ -514,9 +514,11 @@ class SarvamTTSService(InterruptibleTTSService):
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process a frame and flush audio if it's the end of a full response."""
|
||||
if isinstance(frame, LLMFullResponseEndFrame):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# When the LLM finishes responding, flush any remaining text in Sarvam's buffer
|
||||
if isinstance(frame, (LLMFullResponseEndFrame, EndFrame)):
|
||||
await self.flush_audio()
|
||||
return await super().process_frame(frame, direction)
|
||||
|
||||
async def _update_settings(self, settings: Mapping[str, Any]):
|
||||
"""Update service settings and reconnect if voice changed."""
|
||||
|
||||
@@ -425,16 +425,13 @@ class TTSService(AIService):
|
||||
# pause to avoid audio overlapping.
|
||||
await self._maybe_pause_frame_processing()
|
||||
|
||||
pending_aggregation = self._text_aggregator.text
|
||||
# Flush any remaining text (including text waiting for lookahead)
|
||||
remaining = await self._text_aggregator.flush()
|
||||
if remaining:
|
||||
await self._push_tts_frames(AggregatedTextFrame(remaining.text, remaining.type))
|
||||
|
||||
# Reset aggregator state
|
||||
await self._text_aggregator.reset()
|
||||
self._processing_text = False
|
||||
|
||||
if pending_aggregation.text:
|
||||
await self._push_tts_frames(
|
||||
AggregatedTextFrame(pending_aggregation.text, pending_aggregation.type)
|
||||
)
|
||||
if isinstance(frame, LLMFullResponseEndFrame):
|
||||
if self._push_text_frames:
|
||||
await self.push_frame(frame, direction)
|
||||
@@ -539,17 +536,20 @@ class TTSService(AIService):
|
||||
text = frame.text
|
||||
includes_inter_frame_spaces = frame.includes_inter_frame_spaces
|
||||
aggregated_by = "token"
|
||||
|
||||
if text:
|
||||
logger.trace(f"Pushing TTS frames for text: {text}, {aggregated_by}")
|
||||
await self._push_tts_frames(
|
||||
AggregatedTextFrame(text, aggregated_by), includes_inter_frame_spaces
|
||||
)
|
||||
else:
|
||||
aggregate = await self._text_aggregator.aggregate(frame.text)
|
||||
if aggregate:
|
||||
async for aggregate in self._text_aggregator.aggregate(frame.text):
|
||||
text = aggregate.text
|
||||
aggregated_by = aggregate.type
|
||||
|
||||
if text:
|
||||
logger.trace(f"Pushing TTS frames for text: {text}, {aggregated_by}")
|
||||
await self._push_tts_frames(
|
||||
AggregatedTextFrame(text, aggregated_by), includes_inter_frame_spaces
|
||||
)
|
||||
logger.trace(f"Pushing TTS frames for text: {text}, {aggregated_by}")
|
||||
await self._push_tts_frames(
|
||||
AggregatedTextFrame(text, aggregated_by), includes_inter_frame_spaces
|
||||
)
|
||||
|
||||
async def _push_tts_frames(
|
||||
self, src_frame: AggregatedTextFrame, includes_inter_frame_spaces: Optional[bool] = False
|
||||
|
||||
@@ -6,31 +6,14 @@
|
||||
|
||||
"""Base notifier interface for Pipecat."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import warnings
|
||||
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
|
||||
class BaseNotifier(ABC):
|
||||
"""Abstract base class for notification mechanisms.
|
||||
|
||||
Provides a standard interface for implementing notification and waiting
|
||||
patterns used for event coordination and signaling between components
|
||||
in the Pipecat framework.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def notify(self):
|
||||
"""Send a notification signal.
|
||||
|
||||
Implementations should trigger any waiting coroutines or processes
|
||||
that are blocked on this notifier.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def wait(self):
|
||||
"""Wait for a notification signal.
|
||||
|
||||
Implementations should block until a notification is received
|
||||
from the corresponding notify() call.
|
||||
"""
|
||||
pass
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Package pipecat.sync is deprecated, use pipecat.utils.sync instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
@@ -6,40 +6,14 @@
|
||||
|
||||
"""Event-based notifier implementation using asyncio Event primitives."""
|
||||
|
||||
import asyncio
|
||||
import warnings
|
||||
|
||||
from pipecat.sync.base_notifier import BaseNotifier
|
||||
from pipecat.utils.sync.event_notifier import EventNotifier
|
||||
|
||||
|
||||
class EventNotifier(BaseNotifier):
|
||||
"""Event-based notifier using asyncio.Event for task synchronization.
|
||||
|
||||
Provides a simple notification mechanism where one task can signal
|
||||
an event and other tasks can wait for that event to occur. The event
|
||||
is automatically cleared after each wait operation.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the event notifier.
|
||||
|
||||
Creates an internal asyncio.Event for managing notifications.
|
||||
"""
|
||||
self._event = asyncio.Event()
|
||||
|
||||
async def notify(self):
|
||||
"""Signal the event to notify waiting tasks.
|
||||
|
||||
Sets the internal event, causing any tasks waiting on this
|
||||
notifier to be awakened.
|
||||
"""
|
||||
self._event.set()
|
||||
|
||||
async def wait(self):
|
||||
"""Wait for the event to be signaled.
|
||||
|
||||
Blocks until another task calls notify(). Automatically clears
|
||||
the event after being awakened so subsequent calls will wait
|
||||
for the next notification.
|
||||
"""
|
||||
await self._event.wait()
|
||||
self._event.clear()
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
"Package pipecat.sync is deprecated, use pipecat.utils.sync instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
0
src/pipecat/utils/sync/__init__.py
Normal file
36
src/pipecat/utils/sync/base_notifier.py
Normal file
@@ -0,0 +1,36 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Base notifier interface for Pipecat."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseNotifier(ABC):
|
||||
"""Abstract base class for notification mechanisms.
|
||||
|
||||
Provides a standard interface for implementing notification and waiting
|
||||
patterns used for event coordination and signaling between components
|
||||
in the Pipecat framework.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def notify(self):
|
||||
"""Send a notification signal.
|
||||
|
||||
Implementations should trigger any waiting coroutines or processes
|
||||
that are blocked on this notifier.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def wait(self):
|
||||
"""Wait for a notification signal.
|
||||
|
||||
Implementations should block until a notification is received
|
||||
from the corresponding notify() call.
|
||||
"""
|
||||
pass
|
||||
45
src/pipecat/utils/sync/event_notifier.py
Normal file
@@ -0,0 +1,45 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Event-based notifier implementation using asyncio Event primitives."""
|
||||
|
||||
import asyncio
|
||||
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
|
||||
|
||||
class EventNotifier(BaseNotifier):
|
||||
"""Event-based notifier using asyncio.Event for task synchronization.
|
||||
|
||||
Provides a simple notification mechanism where one task can signal
|
||||
an event and other tasks can wait for that event to occur. The event
|
||||
is automatically cleared after each wait operation.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the event notifier.
|
||||
|
||||
Creates an internal asyncio.Event for managing notifications.
|
||||
"""
|
||||
self._event = asyncio.Event()
|
||||
|
||||
async def notify(self):
|
||||
"""Signal the event to notify waiting tasks.
|
||||
|
||||
Sets the internal event, causing any tasks waiting on this
|
||||
notifier to be awakened.
|
||||
"""
|
||||
self._event.set()
|
||||
|
||||
async def wait(self):
|
||||
"""Wait for the event to be signaled.
|
||||
|
||||
Blocks until another task calls notify(). Automatically clears
|
||||
the event after being awakened so subsequent calls will wait
|
||||
for the next notification.
|
||||
"""
|
||||
await self._event.wait()
|
||||
self._event.clear()
|
||||
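As a rough illustration of how the notifier above coordinates two tasks, the following sketch re-declares a stand-in `EventNotifier` with the same three methods so that it runs without Pipecat installed. The `demo` coroutine and its log list are assumptions added for the example.

```python
import asyncio


class EventNotifier:
    """Stand-in mirroring the EventNotifier shown in the diff."""

    def __init__(self):
        self._event = asyncio.Event()

    async def notify(self):
        self._event.set()

    async def wait(self):
        await self._event.wait()
        self._event.clear()  # re-arm so the next wait() blocks again


async def demo() -> list:
    notifier = EventNotifier()
    log = []

    async def waiter():
        await notifier.wait()
        log.append("woke")

    task = asyncio.create_task(waiter())
    await asyncio.sleep(0)  # let the waiter start and block on wait()
    log.append("notifying")
    await notifier.notify()
    await task
    return log


result = asyncio.run(demo())
```

Here `result` is `["notifying", "woke"]`: the waiter only proceeds after `notify()` is called, and the event is cleared afterwards so a later `wait()` blocks until the next notification.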
@@ -14,7 +14,7 @@ aggregated text should be sent for speech synthesis.
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import AsyncIterator, Optional
|
||||
|
||||
|
||||
class AggregationType(str, Enum):
|
||||
@@ -80,33 +80,43 @@ class BaseTextAggregator(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def aggregate(self, text: str) -> Optional[Aggregation]:
|
||||
"""Aggregate the specified text with the currently accumulated text.
|
||||
async def aggregate(self, text: str) -> AsyncIterator[Aggregation]:
|
||||
"""Aggregate the specified text and yield completed aggregations.
|
||||
|
||||
This method should be implemented to define how the new text contributes
|
||||
to the aggregation process. It returns the aggregated text and a string
|
||||
describing how it was aggregated if it's ready to be processed,
|
||||
or None otherwise.
|
||||
This method processes the input text character-by-character internally
|
||||
and yields Aggregation objects as they complete.
|
||||
|
||||
Subclasses should implement their specific logic for:
|
||||
|
||||
- How to combine new text with existing accumulated text
|
||||
- How to process text character-by-character
|
||||
- When to consider the aggregated text ready for processing
|
||||
- What criteria determine text completion (e.g., sentence boundaries)
|
||||
- When a completion occurs, the method should return an Aggregation object
|
||||
containing the aggregated text and its type. The text should be stripped
|
||||
of leading/trailing whitespace so that consumers can rely on a consistent
|
||||
format.
|
||||
- When a completion occurs, yield an Aggregation object containing the
|
||||
aggregated text (stripped of leading/trailing whitespace) and its type
|
||||
|
||||
Args:
|
||||
text: The text to be aggregated.
|
||||
|
||||
Yields:
|
||||
Aggregation objects as they complete. Each Aggregation consists of
|
||||
the aggregated text (stripped of leading/trailing whitespace) and
|
||||
a string indicating the type of aggregation (e.g., 'sentence', 'word',
|
||||
'token', 'my_custom_aggregation').
|
||||
"""
|
||||
pass
|
||||
# Make this a generator to satisfy type checker
|
||||
yield # pragma: no cover
|
||||
|
||||
@abstractmethod
|
||||
async def flush(self) -> Optional[Aggregation]:
|
||||
"""Flush any pending aggregation.
|
||||
|
||||
This method is called at the end of a stream (e.g., when receiving
|
||||
LLMFullResponseEndFrame) to return any text that was buffered.
|
||||
|
||||
Returns:
|
||||
An Aggregation object if ready for processing, or None if more
|
||||
text is needed before the aggregated content is ready. If an Aggregation
|
||||
object is returned, it should consist of the updated aggregated text,
|
||||
stripped of leading/trailing whitespace, and a string indicating the
|
||||
type of aggregation (e.g., 'sentence', 'word', 'token', 'my_custom_aggregation').
|
||||
An Aggregation object if there is pending text, or None if there
|
||||
is no pending text.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
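Since `aggregate()` is now an async generator, callers iterate over it with `async for` and call `flush()` at the end of the stream. The sketch below shows that calling convention with a deliberately simplified aggregator. `ToySentenceAggregator` and the local `Aggregation` dataclass are hypothetical stand-ins, and the toy splits on `.`, `!`, or `?` without the lookahead logic the real aggregators use.

```python
import asyncio
from dataclasses import dataclass
from typing import AsyncIterator, List, Optional


@dataclass
class Aggregation:
    """Stand-in for the Aggregation type: aggregated text plus its type."""

    text: str
    type: str


class ToySentenceAggregator:
    """Hypothetical aggregator that ends a sentence at '.', '!' or '?'."""

    def __init__(self):
        self._text = ""

    async def aggregate(self, text: str) -> AsyncIterator[Aggregation]:
        # Process character by character and yield each completed sentence.
        for char in text:
            self._text += char
            if char in ".!?":
                sentence, self._text = self._text.strip(), ""
                if sentence:
                    yield Aggregation(sentence, "sentence")

    async def flush(self) -> Optional[Aggregation]:
        # Return whatever is still buffered, or None if nothing is pending.
        pending, self._text = self._text.strip(), ""
        return Aggregation(pending, "sentence") if pending else None


async def collect(chunk: str) -> List[str]:
    aggregator = ToySentenceAggregator()
    out = [a.text async for a in aggregator.aggregate(chunk)]
    remaining = await aggregator.flush()
    if remaining:
        out.append(remaining.text)
    return out


sentences = asyncio.run(collect("Hello there. How are you? I am"))
```

A single input chunk yields two sentences from `aggregate()`, and `flush()` returns the trailing `"I am"`, which the previous `Optional[Aggregation]` return type could not express in one call.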
@@ -13,12 +13,12 @@ support for custom handlers and configurable actions for when a pattern is found
|
||||
|
||||
import re
|
||||
from enum import Enum
|
||||
from typing import Awaitable, Callable, List, Optional, Tuple
|
||||
from typing import AsyncIterator, Awaitable, Callable, List, Optional, Tuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.utils.string import match_endofsentence
|
||||
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
|
||||
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType
|
||||
from pipecat.utils.text.simple_text_aggregator import SimpleTextAggregator
|
||||
|
||||
|
||||
class MatchAction(Enum):
|
||||
@@ -72,7 +72,7 @@ class PatternMatch(Aggregation):
|
||||
return f"PatternMatch(type={self.type}, text={self.text}, full_match={self.full_match})"
|
||||
|
||||
|
||||
class PatternPairAggregator(BaseTextAggregator):
|
||||
class PatternPairAggregator(SimpleTextAggregator):
|
||||
"""Aggregator that identifies and processes content between pattern pairs.
|
||||
|
||||
This aggregator buffers text until it can identify complete pattern pairs
|
||||
@@ -97,9 +97,10 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
Creates an empty aggregator with no patterns or handlers registered.
|
||||
Text buffering and pattern detection will begin when text is aggregated.
|
||||
"""
|
||||
self._text = ""
|
||||
super().__init__()
|
||||
self._patterns = {}
|
||||
self._handlers = {}
|
||||
self._last_processed_position = 0 # Track where we last checked for complete patterns
|
||||
|
||||
@property
|
||||
def text(self) -> Aggregation:
|
||||
@@ -218,14 +219,18 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
self._handlers[type] = handler
|
||||
return self
|
||||
|
||||
async def _process_complete_patterns(self, text: str) -> Tuple[List[PatternMatch], str]:
|
||||
"""Process all complete pattern pairs in the text.
|
||||
async def _process_complete_patterns(
|
||||
self, text: str, last_processed_position: int = 0
|
||||
) -> Tuple[List[PatternMatch], str]:
|
||||
"""Process newly complete pattern pairs in the text.
|
||||
|
||||
Searches for all complete pattern pairs in the text, calls the
|
||||
appropriate handlers, and optionally removes the matches.
|
||||
Searches for pattern pairs that have been completed since last_processed_position,
|
||||
calls the appropriate handlers, and optionally removes the matches.
|
||||
|
||||
Args:
|
||||
text: The text to process for pattern matches.
|
||||
last_processed_position: The position in text that was already processed.
|
||||
Only patterns that end at or after this position will be processed.
|
||||
|
||||
Returns:
|
||||
Tuple of (all_matches, processed_text) where:
|
||||
@@ -259,17 +264,23 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
content=content.strip(), type=type, full_match=full_match
|
||||
)
|
||||
|
||||
# Call the appropriate handler if registered
|
||||
if type in self._handlers:
|
||||
# Check if this pattern was already processed
|
||||
already_processed = match.end() <= last_processed_position
|
||||
|
||||
# Only call handler for newly completed patterns
|
||||
if not already_processed and type in self._handlers:
|
||||
try:
|
||||
await self._handlers[type](pattern_match)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in pattern handler for {type}: {e}")
|
||||
|
||||
# Remove the pattern from the text if configured
|
||||
# Handle pattern based on action
|
||||
if action == MatchAction.REMOVE:
|
||||
processed_text = processed_text.replace(full_match, "", 1)
|
||||
# Remove patterns are only removed once (when newly completed)
|
||||
if not already_processed:
|
||||
processed_text = processed_text.replace(full_match, "", 1)
|
||||
else:
|
||||
# KEEP/AGGREGATE patterns stay in all_matches
|
||||
all_matches.append(pattern_match)
|
||||
|
||||
return all_matches, processed_text
|
||||
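The `already_processed` check above exists because the buffer is re-scanned after every character, so a pattern that stays in the buffer would otherwise trigger its handler on each scan. A minimal sketch of the idea, assuming a single hypothetical `<note>...</note>` pattern and plain functions in place of the aggregator's methods:

```python
import re
from typing import List

# Hypothetical pattern pair, used only for this illustration.
PATTERN = re.compile(r"<note>(.*?)</note>", re.DOTALL)


def newly_completed(buffer: str, last_processed_position: int) -> List[str]:
    """Return contents of matches that end after the already-processed prefix."""
    fresh = []
    for match in PATTERN.finditer(buffer):
        already_processed = match.end() <= last_processed_position
        if not already_processed:
            fresh.append(match.group(1).strip())
    return fresh


def feed(chunks: List[str]) -> List[str]:
    buffer, position, handled = "", 0, []
    for chunk in chunks:
        for char in chunk:
            buffer += char
            # Only matches completed by this character count as new.
            handled.extend(newly_completed(buffer, position))
            position = len(buffer)
    return handled


calls = feed(["Hi <note>a</no", "te> and <note>b</note>!"])
```

Each pattern is reported once, at the character that completes it, even though both matches remain in the buffer for every later scan and the first one spans two input chunks.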
@@ -305,76 +316,84 @@ class PatternPairAggregator(BaseTextAggregator):
|
||||
|
||||
return None
|
||||
|
||||
async def aggregate(self, text: str) -> Optional[PatternMatch]:
|
||||
async def aggregate(self, text: str) -> AsyncIterator[PatternMatch]:
|
||||
"""Aggregate text and process pattern pairs.
|
||||
|
||||
This method adds the new text to the buffer, processes any complete pattern
|
||||
pairs, and returns processed text up to sentence boundaries if possible.
|
||||
If there are incomplete patterns (start without matching end), it will
|
||||
continue buffering text.
|
||||
Processes the input text character-by-character, handles pattern pairs,
|
||||
and uses the parent's lookahead logic for sentence detection when no
|
||||
patterns are active.
|
||||
|
||||
Args:
|
||||
text: New text to add to the buffer.
|
||||
text: Text to aggregate.
|
||||
|
||||
Returns:
|
||||
Processed text up to a sentence boundary, or None if more
|
||||
text is needed to form a complete sentence or pattern.
|
||||
Yields:
|
||||
PatternMatch objects as patterns complete or sentences are detected.
|
||||
"""
|
||||
# Add new text to buffer
|
||||
self._text += text
|
||||
# Process text character by character
|
||||
for char in text:
|
||||
self._text += char
|
||||
|
||||
# Process any complete patterns in the buffer
|
||||
patterns, processed_text = await self._process_complete_patterns(self._text)
|
||||
# Process any newly complete patterns in the buffer
|
||||
# Only patterns that complete after _last_processed_position will trigger handlers
|
||||
patterns, processed_text = await self._process_complete_patterns(
|
||||
self._text, self._last_processed_position
|
||||
)
|
||||
|
||||
self._text = processed_text
|
||||
# Update the last processed position to prevent re-processing patterns
|
||||
# This tracks where in the buffer we've already called handlers, so we
|
||||
# only trigger handlers once when a pattern completes
|
||||
self._last_processed_position = len(self._text)
|
||||
|
||||
if len(patterns) > 0:
|
||||
if len(patterns) > 1:
|
||||
logger.warning(
|
||||
f"Multiple patterns matched: {[p.type for p in patterns]}. Only the first pattern will be returned."
|
||||
self._text = processed_text
|
||||
|
||||
if len(patterns) > 0:
|
||||
if len(patterns) > 1:
|
||||
logger.warning(
|
||||
f"Multiple patterns matched: {[p.type for p in patterns]}. Only the first pattern will be returned."
|
||||
)
|
||||
# If the pattern found is set to be aggregated, return it
|
||||
action = self._patterns[patterns[0].type].get("action", MatchAction.REMOVE)
|
||||
if action == MatchAction.AGGREGATE:
|
||||
self._text = ""
|
||||
yield patterns[0]
|
||||
continue
|
||||
|
||||
# Check if we have incomplete patterns
|
||||
pattern_start = self._match_start_of_pattern(self._text)
|
||||
if pattern_start is not None:
|
||||
# If the start pattern is at the beginning or should not be separately aggregated, continue
|
||||
if (
|
||||
pattern_start[0] == 0
|
||||
or pattern_start[1].get("action", MatchAction.REMOVE) != MatchAction.AGGREGATE
|
||||
):
|
||||
continue
|
||||
# For AGGREGATE patterns: yield any text before the pattern starts
|
||||
# This ensures text doesn't get stuck in the buffer waiting for sentence
|
||||
# boundaries when a pattern begins (e.g., "Here is code <code>..." yields "Here is code")
|
||||
result = self._text[: pattern_start[0]]
|
||||
self._text = self._text[pattern_start[0] :]
|
||||
yield PatternMatch(
|
||||
content=result.strip(), type=AggregationType.SENTENCE, full_match=result
|
||||
)
|
||||
# If the pattern found is set to be aggregated, return it
|
||||
action = self._patterns[patterns[0].type].get("action", MatchAction.REMOVE)
|
||||
if action == MatchAction.AGGREGATE:
|
||||
self._text = ""
|
||||
return patterns[0]
|
||||
continue
|
||||
|
||||
# Check if we have incomplete patterns
|
||||
pattern_start = self._match_start_of_pattern(self._text)
|
||||
if pattern_start is not None:
|
||||
# If the start pattern is at the beginning or should not be separately aggregated, return None
|
||||
if (
|
||||
pattern_start[0] == 0
|
||||
or pattern_start[1].get("action", MatchAction.REMOVE) != MatchAction.AGGREGATE
|
||||
):
|
||||
return None
|
||||
# Otherwise, strip the text up to the start pattern and return it
|
||||
result = self._text[: pattern_start[0]]
|
||||
self._text = self._text[pattern_start[0] :]
|
||||
return PatternMatch(
|
||||
content=result.strip(), type=AggregationType.SENTENCE, full_match=result
|
||||
)
|
||||
|
||||
# Find sentence boundary if no incomplete patterns
|
||||
eos_marker = match_endofsentence(self._text)
|
||||
if eos_marker:
|
||||
# Extract text up to the sentence boundary
|
||||
result = self._text[:eos_marker]
|
||||
self._text = self._text[eos_marker:]
|
||||
return PatternMatch(
|
||||
content=result.strip(), type=AggregationType.SENTENCE, full_match=result
|
||||
)
|
||||
|
||||
# No complete sentence found yet
|
||||
return None
|
||||
# Use parent's lookahead logic for sentence detection
|
||||
aggregation = await super()._check_sentence_with_lookahead(char)
|
||||
if aggregation:
|
||||
# Convert to PatternMatch for consistency with return type
|
||||
yield PatternMatch(
|
||||
content=aggregation.text, type=aggregation.type, full_match=aggregation.text
|
||||
)
|
||||
|
||||
     async def handle_interruption(self):
-        """Handle interruptions by clearing the buffer.
+        """Handle interruptions by clearing the buffer and pattern state.
 
         Called when an interruption occurs in the processing pipeline,
         to reset the state and discard any partially aggregated text.
         """
-        self._text = ""
+        await super().handle_interruption()
+        self._last_processed_position = 0
+        # Pattern and handler state persists across interruptions
 
     async def reset(self):
         """Clear the internally aggregated text.
@@ -382,4 +401,6 @@ class PatternPairAggregator(BaseTextAggregator):
         Resets the aggregator to its initial state, discarding any
         buffered text and clearing pattern tracking state.
         """
-        self._text = ""
+        await super().reset()
+        self._last_processed_position = 0
+        # Pattern and handler state persists across resets
|
||||
@@ -11,9 +11,9 @@ until it finds an end-of-sentence marker, making it suitable for basic TTS
 text processing scenarios.
 """
 
-from typing import Optional
+from typing import AsyncIterator, Optional
 
-from pipecat.utils.string import match_endofsentence
+from pipecat.utils.string import SENTENCE_ENDING_PUNCTUATION, match_endofsentence
 from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
 
 
@@ -31,6 +31,7 @@ class SimpleTextAggregator(BaseTextAggregator):
         Creates an empty text buffer ready to begin accumulating text tokens.
         """
         self._text = ""
+        self._needs_lookahead: bool = False
 
     @property
     def text(self) -> Aggregation:
@@ -41,30 +42,87 @@ class SimpleTextAggregator(BaseTextAggregator):
|
||||
"""
|
||||
return Aggregation(text=self._text.strip(), type=AggregationType.SENTENCE)
|
||||
|
||||
async def aggregate(self, text: str) -> Optional[Aggregation]:
|
||||
"""Aggregate text and return completed sentences.
|
||||
async def aggregate(self, text: str) -> AsyncIterator[Aggregation]:
|
||||
"""Aggregate text and yield completed sentences.
|
||||
|
||||
Adds the new text to the buffer and checks for end-of-sentence markers.
|
||||
When a sentence boundary is found, returns the completed sentence and
|
||||
removes it from the buffer.
|
||||
Processes the input text character-by-character. When sentence-ending
|
||||
punctuation is detected, it waits for non-whitespace lookahead before
|
||||
calling NLTK. This prevents false positives like "$29." being detected
|
||||
as a sentence when it's actually "$29.95".
|
||||
|
||||
Args:
|
||||
text: New text to add to the aggregation buffer.
|
||||
text: Text to aggregate.
|
||||
|
||||
Yields:
|
||||
Complete sentences as Aggregation objects.
|
||||
"""
|
||||
# Process text character by character
|
||||
for char in text:
|
||||
self._text += char
|
||||
|
||||
# Check for sentence with lookahead
|
||||
result = await self._check_sentence_with_lookahead(char)
|
||||
if result:
|
||||
yield result
|
||||
|
||||
async def _check_sentence_with_lookahead(self, char: str) -> Optional[Aggregation]:
|
||||
"""Check for sentence boundaries using lookahead logic.
|
||||
|
||||
This method implements the core sentence detection logic with lookahead.
|
||||
When sentence-ending punctuation is detected, it waits for the next
|
||||
non-whitespace character before calling NLTK. This disambiguates cases
|
||||
like "$29." (not a sentence) vs "$29. Next" (sentence ends at period).
|
||||
Whitespace alone is not meaningful lookahead since it appears in both
|
||||
cases. Instead, the first non-whitespace character after the punctuation
|
||||
is used to confirm the sentence boundary.
|
||||
|
||||
Subclasses can call this via super() to reuse the lookahead behavior
|
||||
while adding their own logic (e.g., tag handling, pattern matching).
|
||||
|
||||
Args:
|
||||
char: The most recently added character (used for lookahead check).
|
||||
|
||||
Returns:
|
||||
A complete sentence if an end-of-sentence marker is found,
|
||||
or None if more text is needed to complete a sentence.
|
||||
Aggregation if sentence found, None otherwise.
|
||||
"""
|
||||
result: Optional[str] = None
|
||||
# If we need lookahead, check if we now have non-whitespace
|
||||
if self._needs_lookahead:
|
||||
# Check if the new character is non-whitespace
|
||||
if char.strip():
|
||||
# We have meaningful lookahead, call NLTK
|
||||
self._needs_lookahead = False
|
||||
eos_marker = match_endofsentence(self._text)
|
||||
|
||||
self._text += text
|
||||
if eos_marker:
|
||||
# NLTK confirmed a sentence - return it
|
||||
result = self._text[:eos_marker]
|
||||
self._text = self._text[eos_marker:]
|
||||
return Aggregation(text=result, type=AggregationType.SENTENCE)
|
||||
# No sentence found - keep accumulating
|
||||
return None
|
||||
# Still whitespace, keep waiting
|
||||
return None
|
||||
|
||||
eos_end_marker = match_endofsentence(self._text)
|
||||
if eos_end_marker:
|
||||
result = self._text[:eos_end_marker]
|
||||
self._text = self._text[eos_end_marker:]
|
||||
# Check if we just added sentence-ending punctuation
|
||||
if self._text and self._text[-1] in SENTENCE_ENDING_PUNCTUATION:
|
||||
# Mark that we need lookahead (don't call NLTK yet)
|
||||
self._needs_lookahead = True
|
||||
|
||||
if result:
|
||||
return None
|
||||
|
||||
async def flush(self) -> Optional[Aggregation]:
|
||||
"""Flush any remaining text in the buffer.
|
||||
|
||||
Returns any text remaining in the buffer. This is called at the end
|
||||
of a stream to ensure all text is processed.
|
||||
|
||||
Returns:
|
||||
Any remaining text as a sentence, or None if buffer is empty.
|
||||
"""
|
||||
if self._text:
|
||||
# Return whatever we have in the buffer
|
||||
result = self._text
|
||||
await self.reset()
|
||||
return Aggregation(text=result.strip(), type=AggregationType.SENTENCE)
|
||||
return None
|
||||
|
||||
@@ -75,6 +133,7 @@ class SimpleTextAggregator(BaseTextAggregator):
         discarding any partially accumulated text.
         """
         self._text = ""
+        self._needs_lookahead = False
 
     async def reset(self):
         """Clear the internally aggregated text.
@@ -83,3 +142,4 @@ class SimpleTextAggregator(BaseTextAggregator):
         any accumulated text content.
         """
         self._text = ""
+        self._needs_lookahead = False
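The lookahead behavior described in the `_check_sentence_with_lookahead` docstring can be illustrated with a small standalone sketch. It is not the pipecat implementation: `LookaheadSplitter`, `naive_end_of_sentence`, and `SENTENCE_END` are hypothetical names, and the boundary check is a simple "punctuation followed by whitespace" rule standing in for the NLTK-based `match_endofsentence`, so it does not handle abbreviations such as "Mr.".

```python
from typing import Iterator, Optional

# Assumption: a simplified stand-in for SENTENCE_ENDING_PUNCTUATION.
SENTENCE_END = ".!?"


def naive_end_of_sentence(text: str) -> int:
    """Hypothetical stand-in for match_endofsentence().

    Returns the index just past the first punctuation mark that is followed
    by whitespace, or 0 if no such boundary exists.
    """
    for i, ch in enumerate(text[:-1]):
        if ch in SENTENCE_END and text[i + 1].isspace():
            return i + 1
    return 0


class LookaheadSplitter:
    """Toy splitter that defers boundary checks until lookahead arrives."""

    def __init__(self) -> None:
        self._text = ""
        self._needs_lookahead = False

    def feed(self, text: str) -> Iterator[str]:
        """Consume text character by character, yielding confirmed sentences."""
        for char in text:
            self._text += char
            sentence = self._check(char)
            if sentence:
                yield sentence

    def _check(self, char: str) -> Optional[str]:
        if self._needs_lookahead:
            if not char.strip():
                # Whitespace alone cannot confirm a boundary; keep waiting.
                return None
            self._needs_lookahead = False
            end = naive_end_of_sentence(self._text)
            if end:
                sentence, self._text = self._text[:end], self._text[end:]
                return sentence.strip()
        # Punctuation just arrived: wait for the next non-whitespace character.
        if self._text[-1] in SENTENCE_END:
            self._needs_lookahead = True
        return None

    def flush(self) -> Optional[str]:
        """Return whatever is left in the buffer at end of stream."""
        rest, self._text = self._text.strip(), ""
        self._needs_lookahead = False
        return rest or None
```

With this sketch, feeding `"Hello world. Next sentence"` yields `"Hello world."` only once the `N` arrives, while `"Ask me for only $29.95/month."` yields nothing until `flush()` is called, which mirrors the behavior the tests in this diff expect.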
|
||||
@@ -11,13 +11,14 @@ between specified start/end tag pairs, ensuring that tagged content is processed
 as a unit regardless of internal punctuation.
 """
 
-from typing import Optional, Sequence
+from typing import AsyncIterator, Optional, Sequence
 
-from pipecat.utils.string import StartEndTags, match_endofsentence, parse_start_end_tags
-from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType, BaseTextAggregator
+from pipecat.utils.string import StartEndTags, parse_start_end_tags
+from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType
+from pipecat.utils.text.simple_text_aggregator import SimpleTextAggregator
 
 
-class SkipTagsAggregator(BaseTextAggregator):
+class SkipTagsAggregator(SimpleTextAggregator):
     """Aggregator that prevents end of sentence matching between start/end tags.
 
     This aggregator buffers text until it finds an end of sentence or a start
@@ -37,67 +38,59 @@ class SkipTagsAggregator(BaseTextAggregator):
|
||||
tags: Sequence of StartEndTags objects defining the tag pairs
|
||||
that should prevent sentence boundary detection.
|
||||
"""
|
||||
self._text = ""
|
||||
super().__init__()
|
||||
self._tags = tags
|
||||
self._current_tag: Optional[StartEndTags] = None
|
||||
self._current_tag_index: int = 0
|
||||
|
||||
@property
|
||||
def text(self) -> Aggregation:
|
||||
"""Get the currently buffered text.
|
||||
|
||||
Returns:
|
||||
The current text buffer content that hasn't been processed yet.
|
||||
"""
|
||||
return Aggregation(text=self._text.strip(), type=AggregationType.SENTENCE)
|
||||
|
||||
async def aggregate(self, text: str) -> Optional[Aggregation]:
|
||||
async def aggregate(self, text: str) -> AsyncIterator[Aggregation]:
|
||||
"""Aggregate text while respecting tag boundaries.
|
||||
|
||||
This method adds the new text to the buffer, processes any complete
|
||||
pattern pairs, and returns processed text up to sentence boundaries if
|
||||
possible. If there are incomplete patterns (start without matching
|
||||
end), it will continue buffering text.
|
||||
Processes the input text character-by-character, updates tag state, and
|
||||
uses the parent's lookahead logic for sentence detection when not
|
||||
inside tags.
|
||||
|
||||
Args:
|
||||
text: New text to add to the buffer.
|
||||
text: Text to aggregate.
|
||||
|
||||
Returns:
|
||||
An Aggregation object containing text up to a sentence boundary and
|
||||
marked as SENTENCE type or None if more text is needed to complete a
|
||||
sentence or close tags.
|
||||
Yields:
|
||||
Aggregation objects containing text up to a sentence boundary,
|
||||
marked as SENTENCE type.
|
||||
"""
|
||||
# Add new text to buffer
|
||||
self._text += text
|
||||
# Process text character by character
|
||||
for char in text:
|
||||
self._text += char
|
||||
|
||||
(self._current_tag, self._current_tag_index) = parse_start_end_tags(
|
||||
self._text, self._tags, self._current_tag, self._current_tag_index
|
||||
)
|
||||
# Update tag state
|
||||
(self._current_tag, self._current_tag_index) = parse_start_end_tags(
|
||||
self._text, self._tags, self._current_tag, self._current_tag_index
|
||||
)
|
||||
|
||||
# Find sentence boundary if no incomplete patterns
|
||||
if not self._current_tag:
|
||||
eos_marker = match_endofsentence(self._text)
|
||||
if eos_marker:
|
||||
# Extract text up to the sentence boundary
|
||||
result = self._text[:eos_marker]
|
||||
self._text = self._text[eos_marker:]
|
||||
return Aggregation(text=result.strip(), type=AggregationType.SENTENCE)
|
||||
# If inside tags, don't check for sentences
|
||||
if self._current_tag:
|
||||
continue
|
||||
|
||||
# No complete sentence found yet
|
||||
return None
|
||||
# Otherwise, use parent's lookahead logic for sentence detection
|
||||
result = await super()._check_sentence_with_lookahead(char)
|
||||
if result:
|
||||
yield result
|
||||
|
||||
     async def handle_interruption(self):
-        """Handle interruptions by clearing the buffer.
+        """Handle interruptions by clearing the buffer and tag state.
 
         Called when an interruption occurs in the processing pipeline,
         to reset the state and discard any partially aggregated text.
         """
-        self._text = ""
+        await super().handle_interruption()
         self._current_tag = None
         self._current_tag_index = 0
 
     async def reset(self):
-        """Clear the internally aggregated text.
+        """Clear the internally aggregated text and tag state.
 
         Resets the aggregator to its initial state, discarding any
         buffered text.
         """
-        self._text = ""
+        await super().reset()
         self._current_tag = None
         self._current_tag_index = 0
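The idea behind `SkipTagsAggregator`, suppressing sentence detection while inside a start/end tag pair, can be sketched in isolation. The function below is a hypothetical illustration, not pipecat code: `split_outside_tags` is an invented name, tags are plain `(start, end)` string tuples rather than `StartEndTags`, and it splits immediately on punctuation without the lookahead step the real class inherits from its parent.

```python
from typing import List, Optional, Tuple


def split_outside_tags(text: str, tags: List[Tuple[str, str]]) -> List[str]:
    """Split text at sentence punctuation, but never inside a tag pair."""
    sentences: List[str] = []
    buf = ""
    open_end: Optional[str] = None  # end tag we are waiting for, if any

    for ch in text:
        buf += ch
        if open_end is None:
            # Not inside a tag: see whether a start tag was just completed.
            for start, end in tags:
                if buf.endswith(start):
                    open_end = end
                    break
        elif buf.endswith(open_end):
            # The matching end tag was just completed.
            open_end = None

        # Only treat punctuation as a boundary when outside all tags.
        if open_end is None and ch in ".!?":
            sentences.append(buf.strip())
            buf = ""

    if buf.strip():
        sentences.append(buf.strip())
    return sentences
```

For the input `"My email is <spell>foo.bar@pipecat.ai</spell>. Bye!"` with the tag pair `("<spell>", "</spell>")`, the periods inside the tags are ignored and the text is split only after the closing tag.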
|
||||
@@ -10,6 +10,7 @@ from unittest.mock import AsyncMock
 from pipecat.audio.vad.vad_analyzer import VADParams
 from pipecat.extensions.ivr.ivr_navigator import IVRProcessor
 from pipecat.frames.frames import (
+    LLMFullResponseEndFrame,
     LLMMessagesUpdateFrame,
     LLMTextFrame,
     OutputDTMFUrgentFrame,
@@ -334,10 +335,12 @@ class TestIVRNavigation(unittest.IsolatedAsyncioTestCase):
 
         frames_to_send = [
             LLMTextFrame(text="Hello, I'm trying to reach billing."),
+            LLMFullResponseEndFrame(),
         ]
 
         expected_down_frames = [
             LLMTextFrame,  # Should pass through unchanged
+            LLMFullResponseEndFrame,
         ]
 
         expected_up_frames = [
|
||||
@@ -38,14 +38,8 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.aggregator.on_pattern_match("code_pattern", self.code_handler)
|
||||
|
||||
async def test_pattern_match_and_removal(self):
|
||||
# First part doesn't complete the pattern
|
||||
result = await self.aggregator.aggregate("Hello <test>pattern")
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(self.aggregator.text.text, "Hello <test>pattern")
|
||||
self.assertEqual(self.aggregator.text.type, "test_pattern")
|
||||
|
||||
# Second part completes the pattern and includes an exclamation point
|
||||
result = await self.aggregator.aggregate(" content</test>!")
|
||||
text = "Hello <test>pattern content</test>!"
|
||||
results = [result async for result in self.aggregator.aggregate(text)]
|
||||
|
||||
# Verify the handler was called with correct PatternMatch object
|
||||
self.test_handler.assert_called_once()
|
||||
@@ -55,28 +49,37 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(call_args.full_match, "<test>pattern content</test>")
|
||||
self.assertEqual(call_args.text, "pattern content")
|
||||
|
||||
# The exclamation point should be treated as a sentence boundary,
|
||||
# so the result should include just text up to and including "!"
|
||||
self.assertEqual(result.text, "Hello !")
|
||||
self.assertEqual(result.type, "sentence")
|
||||
# No results yet (waiting for lookahead after "!")
|
||||
self.assertEqual(len(results), 0)
|
||||
|
||||
# Next sentence should be processed separately. Spaces around the sentence
|
||||
# should be stripped in the returned Aggregation.
|
||||
result = await self.aggregator.aggregate(" This is another sentence.")
|
||||
# Next sentence should provide the lookahead and trigger the previous sentence
|
||||
async for result in self.aggregator.aggregate(" This is another sentence."):
|
||||
results.append(result)
|
||||
|
||||
# First result should be "Hello !" triggered by the space lookahead
|
||||
self.assertEqual(len(results), 1)
|
||||
self.assertEqual(results[0].text, "Hello !")
|
||||
self.assertEqual(results[0].type, "sentence")
|
||||
|
||||
# Now flush to get the remaining sentence
|
||||
result = await self.aggregator.flush()
|
||||
self.assertEqual(result.text, "This is another sentence.")
|
||||
|
||||
# Buffer should be empty after returning a complete sentence
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
async def test_pattern_match_and_aggregate(self):
|
||||
# First part doesn't complete the pattern
|
||||
result = await self.aggregator.aggregate("Here is code <code>pattern")
|
||||
self.assertEqual(result.text, "Here is code")
|
||||
self.assertEqual(self.aggregator.text.text, "<code>pattern")
|
||||
self.assertEqual(self.aggregator.text.type, "code_pattern")
|
||||
text = "Here is code <code>pattern content</code> This is another sentence."
|
||||
|
||||
# Second part completes the pattern and includes an exclamation point
|
||||
result = await self.aggregator.aggregate(" content</code>")
|
||||
results = [result async for result in self.aggregator.aggregate(text)]
|
||||
|
||||
# First result should be "Here is code" when pattern starts
|
||||
self.assertEqual(results[0].text, "Here is code")
|
||||
self.assertEqual(results[0].type, "sentence")
|
||||
|
||||
# Second result should be the code pattern content
|
||||
self.assertEqual(results[1].text, "pattern content")
|
||||
self.assertEqual(results[1].type, "code_pattern")
|
||||
|
||||
# Verify the handler was called with correct PatternMatch object
|
||||
self.code_handler.assert_called_once()
|
||||
@@ -85,11 +88,9 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(call_args.type, "code_pattern")
|
||||
self.assertEqual(call_args.full_match, "<code>pattern content</code>")
|
||||
self.assertEqual(call_args.text, "pattern content")
|
||||
self.assertEqual(result.text, "pattern content")
|
||||
self.assertEqual(result.type, "code_pattern")
|
||||
|
||||
# Next sentence should be processed separately
|
||||
result = await self.aggregator.aggregate(" This is another sentence.")
|
||||
# Last sentence needs flush (waiting for lookahead after ".")
|
||||
result = await self.aggregator.flush()
|
||||
self.assertEqual(result.text, "This is another sentence.")
|
||||
self.assertEqual(result.type, "sentence")
|
||||
|
||||
@@ -97,11 +98,10 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
async def test_incomplete_pattern(self):
|
||||
# Add text with incomplete pattern
|
||||
result = await self.aggregator.aggregate("Hello <test>pattern content")
|
||||
|
||||
text = "Hello <test>pattern content"
|
||||
results = [result async for result in self.aggregator.aggregate(text)]
|
||||
# No complete pattern yet, so nothing should be returned
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(len(results), 0)
|
||||
|
||||
# The handler should not be called yet
|
||||
self.test_handler.assert_not_called()
|
||||
@@ -136,9 +136,8 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.aggregator.on_pattern_match("voice", voice_handler)
|
||||
self.aggregator.on_pattern_match("emphasis", emphasis_handler)
|
||||
|
||||
# Test with multiple patterns in one text block
|
||||
text = "Hello <voice>female</voice> I am <em>very</em> excited to meet you!"
|
||||
result = await self.aggregator.aggregate(text)
|
||||
results = [result async for result in self.aggregator.aggregate(text)]
|
||||
|
||||
# Both handlers should be called with correct data
|
||||
voice_handler.assert_called_once()
|
||||
@@ -151,6 +150,10 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(emphasis_match.type, "emphasis")
|
||||
self.assertEqual(emphasis_match.text, "very")
|
||||
|
||||
# With lookahead, we need to flush to get the final sentence
|
||||
self.assertEqual(len(results), 0) # Waiting for lookahead after "!"
|
||||
|
||||
result = await self.aggregator.flush()
|
||||
# Voice pattern should be removed, emphasis pattern should remain
|
||||
self.assertEqual(result.text, "Hello I am <em>very</em> excited to meet you!")
|
||||
|
||||
@@ -158,9 +161,9 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
|
||||
async def test_handle_interruption(self):
|
||||
# Start with incomplete pattern
|
||||
result = await self.aggregator.aggregate("Hello <test>pattern")
|
||||
self.assertIsNone(result)
|
||||
text = "Hello <test>pattern"
|
||||
results = [result async for result in self.aggregator.aggregate(text)]
|
||||
self.assertEqual(len(results), 0)
|
||||
|
||||
# Simulate interruption
|
||||
await self.aggregator.handle_interruption()
|
||||
@@ -172,20 +175,18 @@ class TestPatternPairAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.test_handler.assert_not_called()
|
||||
|
||||
async def test_pattern_across_sentences(self):
|
||||
# Test pattern that spans multiple sentences
|
||||
result = await self.aggregator.aggregate("Hello <test>This is sentence one.")
|
||||
|
||||
# First sentence contains start of pattern but no end, so no complete pattern yet
|
||||
self.assertIsNone(result)
|
||||
|
||||
# Add second part with pattern end
|
||||
result = await self.aggregator.aggregate(" This is sentence two.</test> Final sentence.")
|
||||
text = "Hello <test>This is sentence one. This is sentence two.</test> Final sentence."
|
||||
results = [result async for result in self.aggregator.aggregate(text)]
|
||||
|
||||
# Handler should be called with entire content
|
||||
self.test_handler.assert_called_once()
|
||||
call_args = self.test_handler.call_args[0][0]
|
||||
self.assertEqual(call_args.text, "This is sentence one. This is sentence two.")
|
||||
|
||||
# With lookahead, we need to flush to get the final sentence
|
||||
self.assertEqual(len(results), 0) # Waiting for lookahead after "."
|
||||
|
||||
result = await self.aggregator.flush()
|
||||
# Pattern should be removed, resulting in text with sentences merged
|
||||
self.assertEqual(result.text, "Hello Final sentence.")
|
||||
|
||||
|
||||
@@ -14,22 +14,112 @@ class TestSimpleTextAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
self.aggregator = SimpleTextAggregator()
|
||||
|
||||
async def test_reset_aggregations(self):
|
||||
assert await self.aggregator.aggregate("Hello ") == None
|
||||
text = "Hello "
|
||||
results = [agg async for agg in self.aggregator.aggregate(text)]
|
||||
|
||||
# No complete sentences yet
|
||||
assert len(results) == 0
|
||||
assert self.aggregator.text.text == "Hello"
|
||||
await self.aggregator.reset()
|
||||
assert self.aggregator.text.text == ""
|
||||
|
||||
async def test_simple_sentence(self):
|
||||
assert await self.aggregator.aggregate("Hello ") == None
|
||||
aggregate = await self.aggregator.aggregate("Pipecat!")
|
||||
text = "Hello Pipecat!"
|
||||
results = [agg async for agg in self.aggregator.aggregate(text)]
|
||||
|
||||
# No complete sentences yet (waiting for lookahead after "!")
|
||||
assert len(results) == 0
|
||||
|
||||
# Flush to get the pending sentence
|
||||
aggregate = await self.aggregator.flush()
|
||||
assert aggregate.text == "Hello Pipecat!"
|
||||
assert aggregate.type == "sentence"
|
||||
assert self.aggregator.text.text == ""
|
||||
|
||||
async def test_multiple_sentences(self):
|
||||
aggregate = await self.aggregator.aggregate("Hello Pipecat! How are ")
|
||||
assert aggregate.text == "Hello Pipecat!"
|
||||
# Aggregators should strip leading/trailing spaces when returning text
|
||||
assert self.aggregator.text.text == "How are"
|
||||
aggregate = await self.aggregator.aggregate("you?")
|
||||
assert aggregate.text == "How are you?"
|
||||
text = "Hello Pipecat! How are you?"
|
||||
results = [agg async for agg in self.aggregator.aggregate(text)]
|
||||
|
||||
# First sentence should be complete (lookahead from "H" confirmed it)
|
||||
assert len(results) == 1
|
||||
assert results[0].text == "Hello Pipecat!"
|
||||
|
||||
# Flush to get the pending sentence
|
||||
result = await self.aggregator.flush()
|
||||
assert result.text == "How are you?"
|
||||
|
||||
async def test_lookahead_decimal_number(self):
|
||||
"""Test that $29.95 is not split at $29."""
|
||||
text = "Ask me for only $29.95/month."
|
||||
results = [agg async for agg in self.aggregator.aggregate(text)]
|
||||
|
||||
# No complete sentences yet (waiting for lookahead after final ".")
|
||||
assert len(results) == 0
|
||||
|
||||
# Can use flush() to get the pending sentence at end of stream
|
||||
result = await self.aggregator.flush()
|
||||
assert result.text == "Ask me for only $29.95/month."
|
||||
|
||||
async def test_lookahead_abbreviation(self):
|
||||
"""Test that Mr. Smith is not split at Mr."""
|
||||
text = "Hello Mr. Smith."
|
||||
results = [agg async for agg in self.aggregator.aggregate(text)]
|
||||
|
||||
# No complete sentences yet (waiting for lookahead after final ".")
|
||||
assert len(results) == 0
|
||||
|
||||
# Can use flush() to get the pending sentence at end of stream
|
||||
result = await self.aggregator.flush()
|
||||
assert result.text == "Hello Mr. Smith."
|
||||
|
||||
async def test_lookahead_actual_sentence_end(self):
|
||||
"""Test that a real sentence end is detected after lookahead."""
|
||||
text = "Hello world. Next sentence"
|
||||
results = [agg async for agg in self.aggregator.aggregate(text)]
|
||||
|
||||
# First sentence should be complete (lookahead from "N" confirmed it)
|
||||
assert len(results) == 1
|
||||
assert results[0].text == "Hello world."
|
||||
|
||||
async def test_flush_pending_sentence(self):
|
||||
"""Test that flush() returns pending sentence waiting for lookahead."""
|
||||
text = "Hello world."
|
||||
results = [agg async for agg in self.aggregator.aggregate(text)]
|
||||
|
||||
# No complete sentences yet (waiting for lookahead)
|
||||
assert len(results) == 0
|
||||
|
||||
# Call flush to get it
|
||||
result = await self.aggregator.flush()
|
||||
assert result is not None
|
||||
assert result.text == "Hello world."
|
||||
# Flush again should return None
|
||||
assert await self.aggregator.flush() == None
|
||||
|
||||
async def test_flush_with_no_pending(self):
|
||||
"""Test that flush() returns any remaining text in buffer."""
|
||||
text = "Hello"
|
||||
results = [agg async for agg in self.aggregator.aggregate(text)]
|
||||
|
||||
# No complete sentences
|
||||
assert len(results) == 0
|
||||
|
||||
result = await self.aggregator.flush()
|
||||
# flush() now returns any remaining text, not just pending lookahead
|
||||
assert result is not None
|
||||
assert result.text == "Hello"
|
||||
# Buffer should be empty after flush
|
||||
assert self.aggregator.text.text == ""
|
||||
|
||||
async def test_flush_after_lookahead_confirmed(self):
|
||||
"""Test flush after lookahead has already confirmed sentence."""
|
||||
text = "Hello. W"
|
||||
results = [agg async for agg in self.aggregator.aggregate(text)]
|
||||
|
||||
# First sentence should be complete (lookahead from "W" confirmed it)
|
||||
assert len(results) == 1
|
||||
assert results[0].text == "Hello."
|
||||
|
||||
# flush() returns any remaining text (the "W" in this case)
|
||||
result = await self.aggregator.flush()
|
||||
assert result.text == "W"
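The tests above all follow the same calling pattern: iterate over `aggregate()` with `async for` (it may yield zero, one, or several results per call), then call `flush()` at end of stream to collect whatever is still buffered. A minimal sketch of that pattern, using a hypothetical `ToyAggregator` that splits on newlines instead of a real pipecat aggregator, might look like this:

```python
import asyncio
from typing import AsyncIterator, List, Optional


class ToyAggregator:
    """Hypothetical aggregator: yields buffered text each time a newline arrives."""

    def __init__(self) -> None:
        self._buf = ""

    async def aggregate(self, text: str) -> AsyncIterator[str]:
        # An async generator: may yield several results for one input chunk.
        for char in text:
            self._buf += char
            if char == "\n":
                line, self._buf = self._buf.strip(), ""
                yield line

    async def flush(self) -> Optional[str]:
        # Return any remaining text, or None if the buffer is empty.
        rest, self._buf = self._buf.strip(), ""
        return rest or None


async def collect(aggregator: ToyAggregator, chunks: List[str]) -> List[str]:
    """Drain every chunk through the aggregator, then flush the tail."""
    out: List[str] = []
    for chunk in chunks:
        async for item in aggregator.aggregate(chunk):
            out.append(item)
    tail = await aggregator.flush()
    if tail is not None:
        out.append(tail)
    return out


if __name__ == "__main__":
    print(asyncio.run(collect(ToyAggregator(), ["one\ntw", "o\nthree"])))
```

Code written against the old `Optional[Aggregation]` return type (`result = await aggregator.aggregate(text)`) needs to move to this loop form, since awaiting an async generator directly raises a `TypeError`.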
|
||||
|
||||
@@ -17,7 +17,14 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
await self.aggregator.reset()
|
||||
|
||||
# No tags involved, aggregate at end of sentence.
|
||||
result = await self.aggregator.aggregate("Hello Pipecat!")
|
||||
text = "Hello Pipecat!"
|
||||
results = [agg async for agg in self.aggregator.aggregate(text)]
|
||||
|
||||
# Should still be waiting for lookahead after "!"
|
||||
self.assertEqual(len(results), 0)
|
||||
|
||||
# Flush to get the pending sentence
|
||||
result = await self.aggregator.flush()
|
||||
self.assertEqual(result.text, "Hello Pipecat!")
|
||||
self.assertEqual(result.type, "sentence")
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
@@ -26,7 +33,14 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
await self.aggregator.reset()
|
||||
|
||||
# Tags involved, avoid aggregation during tags.
|
||||
result = await self.aggregator.aggregate("My email is <spell>foo@pipecat.ai</spell>.")
|
||||
text = "My email is <spell>foo@pipecat.ai</spell>."
|
||||
results = [agg async for agg in self.aggregator.aggregate(text)]
|
||||
|
||||
# Should still be waiting for lookahead after "."
|
||||
self.assertEqual(len(results), 0)
|
||||
|
||||
# Flush to get the pending sentence
|
||||
result = await self.aggregator.flush()
|
||||
self.assertEqual(result.text, "My email is <spell>foo@pipecat.ai</spell>.")
|
||||
self.assertEqual(result.type, "sentence")
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
@@ -34,25 +48,17 @@ class TestSkipTagsAggregator(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_streaming_tags(self):
|
||||
await self.aggregator.reset()
|
||||
|
||||
# Tags involved, stream small chunk of texts.
|
||||
result = await self.aggregator.aggregate("My email is <sp")
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(self.aggregator.text.text, "My email is <sp")
|
||||
# Tags involved
|
||||
text = "My email is <spell>foo.bar@pipecat.ai</spell>."
|
||||
results = [agg async for agg in self.aggregator.aggregate(text)]
|
||||
|
||||
result = await self.aggregator.aggregate("ell>foo.")
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.")
|
||||
|
||||
result = await self.aggregator.aggregate("bar@pipecat.")
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.bar@pipecat.")
|
||||
|
||||
result = await self.aggregator.aggregate("ai</spe")
|
||||
self.assertIsNone(result)
|
||||
self.assertEqual(self.aggregator.text.text, "My email is <spell>foo.bar@pipecat.ai</spe")
|
||||
# Should still be waiting for lookahead after "."
|
||||
self.assertEqual(len(results), 0)
|
||||
self.assertEqual(self.aggregator.text.text, text)
|
||||
self.assertEqual(self.aggregator.text.type, "sentence")
|
||||
|
||||
result = await self.aggregator.aggregate("ll>.")
|
||||
self.assertEqual(result.text, "My email is <spell>foo.bar@pipecat.ai</spell>.")
|
||||
# Flush to get the pending sentence
|
||||
result = await self.aggregator.flush()
|
||||
self.assertEqual(result.text, text)
|
||||
self.assertEqual(self.aggregator.text.text, "")
|
||||
self.assertEqual(self.aggregator.text.type, "sentence")
|
||||
|
||||
14
uv.lock
generated
@@ -36,12 +36,12 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "aic-sdk"
|
||||
version = "1.1.0"
|
||||
version = "1.2.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "numpy" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/99/83/bf38b95d98c67b8ebc574fb4a4f23c07a3740b51992d7522976173d30b98/aic_sdk-1.1.0.tar.gz", hash = "sha256:04e08df695581c8cb4db8acca20e73815e9f449e7bd08e0162fd55518c727963", size = 34954, upload-time = "2025-11-11T20:45:24.25Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f9/ba/3ebe31b91e03d42437ec864e9d2af3a52b7ccc73a1a0c1026275956270b0/aic_sdk-1.2.0.tar.gz", hash = "sha256:eeda9a181c679f175dbe6f0efc0c67ec98ff3d84cfe01541fef7fa12ecd505ca", size = 35606, upload-time = "2025-11-20T14:42:14.333Z" }
|
||||
|
||||
[[package]]
|
||||
name = "aioboto3"
|
||||
@@ -4564,6 +4564,9 @@ neuphonic = [
 noisereduce = [
     { name = "noisereduce" },
 ]
+nvidia = [
+    { name = "nvidia-riva-client" },
+]
 openai = [
     { name = "websockets" },
 ]
@@ -4666,7 +4669,7 @@ docs = [
[package.metadata]
requires-dist = [
{ name = "accelerate", marker = "extra == 'moondream'", specifier = "~=1.10.0" },
{ name = "aic-sdk", marker = "extra == 'aic'", specifier = "~=1.1.0" },
{ name = "aic-sdk", marker = "extra == 'aic'", specifier = "~=1.2.0" },
{ name = "aioboto3", marker = "extra == 'aws'", specifier = "~=15.5.0" },
{ name = "aiofiles", specifier = ">=24.1.0,<25" },
{ name = "aiohttp", specifier = ">=3.11.12,<4" },
@@ -4706,7 +4709,7 @@ requires-dist = [
{ name = "noisereduce", marker = "extra == 'noisereduce'", specifier = "~=3.0.3" },
{ name = "numba", specifier = "==0.61.2" },
{ name = "numpy", specifier = ">=1.26.4,<3" },
{ name = "nvidia-riva-client", marker = "extra == 'riva'", specifier = "~=2.21.1" },
{ name = "nvidia-riva-client", marker = "extra == 'nvidia'", specifier = "~=2.21.1" },
{ name = "onnxruntime", marker = "extra == 'local-smart-turn-v3'", specifier = ">=1.20.1,<2" },
{ name = "onnxruntime", marker = "extra == 'silero'", specifier = ">=1.20.1,<2" },
{ name = "openai", specifier = ">=1.74.0,<3" },
@@ -4717,6 +4720,7 @@ requires-dist = [
{ name = "opentelemetry-sdk", marker = "extra == 'tracing'", specifier = ">=1.33.0" },
{ name = "ormsgpack", marker = "extra == 'fish'", specifier = "~=1.7.0" },
{ name = "pillow", specifier = ">=11.1.0,<12" },
{ name = "pipecat-ai", extras = ["nvidia"], marker = "extra == 'riva'" },
{ name = "pipecat-ai", extras = ["websockets-base"], marker = "extra == 'assemblyai'" },
{ name = "pipecat-ai", extras = ["websockets-base"], marker = "extra == 'asyncai'" },
{ name = "pipecat-ai", extras = ["websockets-base"], marker = "extra == 'aws'" },
@@ -4767,7 +4771,7 @@ requires-dist = [
{ name = "wait-for2", marker = "python_full_version < '3.12'", specifier = ">=0.4.1" },
{ name = "websockets", marker = "extra == 'websockets-base'", specifier = ">=13.1,<16.0" },
]
provides-extras = ["aic", "anthropic", "assemblyai", "asyncai", "aws", "aws-nova-sonic", "azure", "cartesia", "cerebras", "daily", "deepgram", "deepseek", "elevenlabs", "fal", "fireworks", "fish", "gladia", "google", "grok", "groq", "gstreamer", "heygen", "hume", "inworld", "koala", "krisp", "langchain", "livekit", "lmnt", "local", "local-smart-turn", "local-smart-turn-v3", "mcp", "mem0", "mistral", "mlx-whisper", "moondream", "neuphonic", "nim", "noisereduce", "openai", "openpipe", "openrouter", "perplexity", "playht", "qwen", "remote-smart-turn", "rime", "riva", "runner", "sagemaker", "sambanova", "sarvam", "sentry", "silero", "simli", "soniox", "soundfile", "speechmatics", "strands", "tavus", "together", "tracing", "ultravox", "webrtc", "websocket", "websockets-base", "whisper"]
provides-extras = ["aic", "anthropic", "assemblyai", "asyncai", "aws", "aws-nova-sonic", "azure", "cartesia", "cerebras", "daily", "deepgram", "deepseek", "elevenlabs", "fal", "fireworks", "fish", "gladia", "google", "grok", "groq", "gstreamer", "heygen", "hume", "inworld", "koala", "krisp", "langchain", "livekit", "lmnt", "local", "local-smart-turn", "local-smart-turn-v3", "mcp", "mem0", "mistral", "mlx-whisper", "moondream", "neuphonic", "noisereduce", "nvidia", "openai", "openpipe", "openrouter", "perplexity", "playht", "qwen", "remote-smart-turn", "rime", "riva", "runner", "sagemaker", "sambanova", "sarvam", "sentry", "silero", "simli", "soniox", "soundfile", "speechmatics", "strands", "tavus", "together", "tracing", "ultravox", "webrtc", "websocket", "websockets-base", "whisper"]

[package.metadata.requires-dev]
dev = [
