Compare commits
v0.0.100...fix/event-
1 commit

| Author | SHA1 | Date | |
|---|---|---|---|
| | 6f2ffa8fed | | |
.github/workflows/coverage.yaml (2 changed lines, vendored)
@@ -33,7 +33,7 @@ jobs:

       - name: Install dependencies
         run: |
-          uv sync --group dev --extra anthropic --extra aws --extra google --extra langchain --extra livekit --extra websocket
+          uv sync --group dev --extra anthropic --extra aws --extra google --extra langchain --extra websocket

       - name: Run tests with coverage
         run: |
.github/workflows/tests.yaml (2 changed lines, vendored)
@@ -37,7 +37,7 @@ jobs:

       - name: Install dependencies
         run: |
-          uv sync --group dev --extra anthropic --extra aws --extra google --extra langchain --extra livekit --extra websocket
+          uv sync --group dev --extra anthropic --extra aws --extra google --extra langchain --extra websocket

       - name: Test with pytest
         run: |
CHANGELOG.md (123 changed lines)
@@ -7,129 +7,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 <!-- towncrier release notes start -->

-## [0.0.100] - 2026-01-20
-
-### Added
-
-- Added Hathora service to support Hathora-hosted TTS and STT models (only
-  non-streaming)
-  (PR [#3169](https://github.com/pipecat-ai/pipecat/pull/3169))
-
-- Added `CambTTSService`, using Camb.ai's TTS integration with MARS models
-  (mars-flash, mars-pro, mars-instruct) for high-quality text-to-speech
-  synthesis.
-  (PR [#3349](https://github.com/pipecat-ai/pipecat/pull/3349))
-
-- Added the `additional_headers` param to `WebsocketClientParams`, allowing
-  `WebsocketClientTransport` to send custom headers on connect, for cases such
-  as authentication.
-  (PR [#3461](https://github.com/pipecat-ai/pipecat/pull/3461))
-
-- Added `UserIdleController` for detecting user idle state, integrated into
-  `LLMUserAggregator` and `UserTurnProcessor` via optional `user_idle_timeout`
-  parameter. Emits `on_user_turn_idle` event for application-level handling.
-  Deprecated `UserIdleProcessor` in favor of the new compositional approach.
-  (PR [#3482](https://github.com/pipecat-ai/pipecat/pull/3482))
-
-- Added `on_user_mute_started` and `on_user_mute_stopped` event handlers to
-  `LLMUserAggregator` for tracking user mute state changes.
-  (PR [#3490](https://github.com/pipecat-ai/pipecat/pull/3490))
-
-### Changed
-
-- Enhanced interruption handling in `AsyncAITTSService` by supporting
-  multi-context WebSocket sessions for more robust context management.
-  (PR [#3287](https://github.com/pipecat-ai/pipecat/pull/3287))
-
-- Throttle `UserSpeakingFrame` to broadcast at most every 200ms instead of on
-  every audio chunk, reducing frame processing overhead during user speech.
-  (PR [#3483](https://github.com/pipecat-ai/pipecat/pull/3483))
-
-### Deprecated
-
-- For consistency with other package names, we just deprecated
-  `pipecat.turns.mute` (introduced in Pipecat 0.0.99) in favor of
-  `pipecat.turns.user_mute`.
-  (PR [#3479](https://github.com/pipecat-ai/pipecat/pull/3479))
-
-### Fixed
-
-- Corrected TTFB metric calculation in `AsyncAIHttpTTSService`.
-  (PR [#3287](https://github.com/pipecat-ai/pipecat/pull/3287))
-
-- Fixed an issue where the "bot-llm-text" RTVI event would not fire for
-  realtime (speech-to-speech) services:
-
-  - `AWSNovaSonicLLMService`
-  - `GeminiLiveLLMService`
-  - `OpenAIRealtimeLLMService`
-  - `GrokRealtimeLLMService`
-
-  The issue was that these services weren't pushing `LLMTextFrame`s. Now
-  they do.
-  (PR [#3446](https://github.com/pipecat-ai/pipecat/pull/3446))
-
-- Fixed an issue where `on_user_turn_stop_timeout` could fire while a user is
-  talking when using `ExternalUserTurnStrategies`.
-  (PR [#3454](https://github.com/pipecat-ai/pipecat/pull/3454))
-
-- Fixed an issue where user turn start strategies were not being reset after a
-  user turn started, causing incorrect strategy behavior.
-  (PR [#3455](https://github.com/pipecat-ai/pipecat/pull/3455))
-
-- Fixed `MinWordsUserTurnStartStrategy` to not aggregate transcriptions,
-  preventing incorrect turn starts when words are spoken with pauses between
-  them.
-  (PR [#3462](https://github.com/pipecat-ai/pipecat/pull/3462))
-
-- Fixed an issue where Grok Realtime would error out when running with
-  SmallWebRTC transport.
-  (PR [#3480](https://github.com/pipecat-ai/pipecat/pull/3480))
-
-- Fixed a `Mem0MemoryService` issue where passing `async_mode: true` was
-  causing an error. See
-  https://docs.mem0.ai/platform/features/async-mode-default-change.
-  (PR [#3484](https://github.com/pipecat-ai/pipecat/pull/3484))
-
-- Fixed `AWSNovaSonicLLMService.reset_conversation()`, which would previously
-  error out. Now it successfully reconnects and "rehydrates" from the context
-  object.
-  (PR [#3486](https://github.com/pipecat-ai/pipecat/pull/3486))
-
-- Fixed `AzureTTSService` transcript formatting issues:
-  - Punctuation now appears without extra spaces (e.g., "Hello!" instead of
-    "Hello !")
-  - CJK languages (Chinese, Japanese, Korean) no longer have unwanted spaces
-    between characters
-  (PR [#3489](https://github.com/pipecat-ai/pipecat/pull/3489))
-
-- Fixed an issue where `UninterruptibleFrame` frames would not be preserved in
-  some cases.
-  (PR [#3494](https://github.com/pipecat-ai/pipecat/pull/3494))
-
-- Fixed memory leak in `LiveKitTransport` when `video_in_enabled` is `False`.
-  (PR [#3499](https://github.com/pipecat-ai/pipecat/pull/3499))
-
-- Fixed an issue in `AIService` where unhandled exceptions in `start()`,
-  `stop()`, or `cancel()` implementations would prevent `process_frame()` to
-  continue and therefore `StartFrame`, `EndFrame`, or `CancelFrame` from being
-  pushed downstream, causing the pipeline to not start or stop properly.
-  (PR [#3503](https://github.com/pipecat-ai/pipecat/pull/3503))
-
-- Moved `NVIDIATTSService` and `NVIDIASTTService` client initialization from
-  constructor to `start()` for better error handling.
-  (PR [#3504](https://github.com/pipecat-ai/pipecat/pull/3504))
-
-- Optimized `NVIDIATTSService` to process incoming audio frames immediately.
-  (PR [#3509](https://github.com/pipecat-ai/pipecat/pull/3509))
-
-- Optimized `NVIDIASTTService` by removing unnecessary queue and task.
-  (PR [#3509](https://github.com/pipecat-ai/pipecat/pull/3509))
-
-- Fixed a `CambTTSService` issue where client was being initialized in the
-  constructor which wouldn't allow for proper Pipeline error handling.
-  (PR [#3511](https://github.com/pipecat-ai/pipecat/pull/3511))
-
 ## [0.0.99] - 2026-01-13

 ### Added
changelog/3169.added.md (Normal file, 1 changed line)
@@ -0,0 +1 @@
+- Added Hathora service to support Hathora-hosted TTS and STT models (only non-streaming)
changelog/3287.changed.md (Normal file, 1 changed line)
@@ -0,0 +1 @@
+- Enhanced interruption handling in `AsyncAITTSService` by supporting multi-context WebSocket sessions for more robust context management.
changelog/3287.fixed.md (Normal file, 1 changed line)
@@ -0,0 +1 @@
+- Corrected TTFB metric calculation in `AsyncAIHttpTTSService`.
changelog/3349.added.md (Normal file, 1 changed line)
@@ -0,0 +1 @@
+- Added `CambTTSService`, using Camb.ai's TTS integration with MARS models (mars-flash, mars-pro, mars-instruct) for high-quality text-to-speech synthesis.
changelog/3446.fixed.md (Normal file, 8 changed lines)
@@ -0,0 +1,8 @@
+- Fixed an issue where the "bot-llm-text" RTVI event would not fire for realtime (speech-to-speech) services:
+
+  - `AWSNovaSonicLLMService`
+  - `GeminiLiveLLMService`
+  - `OpenAIRealtimeLLMService`
+  - `GrokRealtimeLLMService`
+
+  The issue was that these services weren't pushing `LLMTextFrame`s. Now they do.
changelog/3454.fixed.md (Normal file, 1 changed line)
@@ -0,0 +1 @@
+- Fixed an issue where `on_user_turn_stop_timeout` could fire while a user is talking when using `ExternalUserTurnStrategies`.
changelog/3455.fixed.md (Normal file, 1 changed line)
@@ -0,0 +1 @@
+- Fixed an issue where user turn start strategies were not being reset after a user turn started, causing incorrect strategy behavior.
changelog/3461.added.md (Normal file, 1 changed line)
@@ -0,0 +1 @@
+- Added the `additional_headers` param to `WebsocketClientParams`, allowing `WebsocketClientTransport` to send custom headers on connect, for cases such as authentication.
changelog/3462.fixed.md (Normal file, 1 changed line)
@@ -0,0 +1 @@
+- Fixed `MinWordsUserTurnStartStrategy` to not aggregate transcriptions, preventing incorrect turn starts when words are spoken with pauses between them.
changelog/3479.deprecated.md (Normal file, 1 changed line)
@@ -0,0 +1 @@
+- For consistency with other package names, we just deprecated `pipecat.turns.mute` (introduced in Pipecat 0.0.99) in favor of `pipecat.turns.user_mute`.
changelog/3480.fixed.md (Normal file, 1 changed line)
@@ -0,0 +1 @@
+- Fixed an issue where Grok Realtime would error out when running with SmallWebRTC transport.
changelog/3482.added.md (Normal file, 1 changed line)
@@ -0,0 +1 @@
+- Added `UserIdleController` for detecting user idle state, integrated into `LLMUserAggregator` and `UserTurnProcessor` via optional `user_idle_timeout` parameter. Emits `on_user_turn_idle` event for application-level handling. Deprecated `UserIdleProcessor` in favor of the new compositional approach.
changelog/3483.changed.md (Normal file, 1 changed line)
@@ -0,0 +1 @@
+- Throttle `UserSpeakingFrame` to broadcast at most every 200ms instead of on every audio chunk, reducing frame processing overhead during user speech.
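The throttling described in `changelog/3483.changed.md` can be sketched roughly as follows. This is an illustrative sketch only, not Pipecat's code: `Throttle`, `should_emit`, and `count_emitted` are hypothetical names, and the clock is injected so the example is deterministic.

```python
import time
from typing import Callable, List, Optional


class Throttle:
    """Allow an event at most once per `interval` time units (hypothetical helper)."""

    def __init__(self, interval: float = 0.2, clock: Callable[[], float] = time.monotonic):
        self._interval = interval
        self._clock = clock
        self._last_emit: Optional[float] = None

    def should_emit(self) -> bool:
        """Return True if at least `interval` has elapsed since the last emitted event."""
        now = self._clock()
        if self._last_emit is None or now - self._last_emit >= self._interval:
            self._last_emit = now
            return True
        return False


def count_emitted(chunk_times: List[float], interval: float = 0.2) -> int:
    """Count how many audio chunks arriving at `chunk_times` would trigger a broadcast."""
    current = {"t": 0.0}
    throttle = Throttle(interval, clock=lambda: current["t"])
    emitted = 0
    for t in chunk_times:
        current["t"] = t  # advance the fake clock to the chunk's arrival time
        if throttle.should_emit():
            emitted += 1
    return emitted
```

With chunks arriving every 20 ms, a 200 ms interval lets roughly one chunk in ten through, which is where the reduction in frame processing comes from.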
changelog/3484.fixed.md (Normal file, 1 changed line)
@@ -0,0 +1 @@
+- Fixed a `Mem0MemoryService` issue where passing `async_mode: true` was causing an error. See https://docs.mem0.ai/platform/features/async-mode-default-change.
changelog/3489.fixed.md (Normal file, 3 changed lines)
@@ -0,0 +1,3 @@
+- Fixed `AzureTTSService` transcript formatting issues:
+  - Punctuation now appears without extra spaces (e.g., "Hello!" instead of "Hello !")
+  - CJK languages (Chinese, Japanese, Korean) no longer have unwanted spaces between characters
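The two formatting rules in `changelog/3489.fixed.md` can be illustrated with a small token joiner. This is a sketch under stated assumptions, not the `AzureTTSService` implementation: `join_tokens` and `is_cjk` are hypothetical helpers, the punctuation set is a guess, and the CJK check covers only a few common Unicode blocks.

```python
from typing import List

# Assumed set of punctuation that should attach to the preceding word.
PUNCTUATION = set(".,!?;:。！？，、")

# A few common CJK code point ranges (ideographs, kana, Hangul syllables).
CJK_RANGES = [(0x4E00, 0x9FFF), (0x3040, 0x309F), (0x30A0, 0x30FF), (0xAC00, 0xD7AF)]


def is_cjk(ch: str) -> bool:
    """Return True if the character falls in one of the assumed CJK ranges."""
    code = ord(ch)
    return any(lo <= code <= hi for lo, hi in CJK_RANGES)


def join_tokens(tokens: List[str]) -> str:
    """Join word tokens, omitting the space before punctuation and between CJK characters."""
    out = ""
    for token in tokens:
        if not token:
            continue
        if not out:
            out = token
        elif token[0] in PUNCTUATION or (is_cjk(out[-1]) and is_cjk(token[0])):
            out += token  # attach directly, no separating space
        else:
            out += " " + token
    return out
```

Calling `join_tokens(["Hello", "!"])` attaches the punctuation to the word, while ordinary Latin words are still separated by a single space.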
changelog/3490.added.md (Normal file, 1 changed line)
@@ -0,0 +1 @@
+- Added `on_user_mute_started` and `on_user_mute_stopped` event handlers to `LLMUserAggregator` for tracking user mute state changes.
@@ -10,6 +10,7 @@ import os
 from dotenv import load_dotenv
 from loguru import logger

+from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
 from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
 from pipecat.audio.vad.silero import SileroVADAnalyzer
 from pipecat.audio.vad.vad_analyzer import VADParams
@@ -45,6 +45,7 @@ from pipecat.services.google.tts import GoogleTTSService
 from pipecat.transcriptions.language import Language
 from pipecat.transports.base_transport import BaseTransport, TransportParams
 from pipecat.transports.daily.transport import DailyParams
+from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
 from pipecat.turns.user_stop import TurnAnalyzerUserTurnStopStrategy
 from pipecat.turns.user_turn_strategies import UserTurnStrategies

@@ -28,7 +28,7 @@ from dotenv import load_dotenv
 from loguru import logger

 from pipecat.audio.filters.krisp_viva_filter import KrispVivaFilter
-from pipecat.audio.turn.krisp_viva_turn import KrispVivaTurn
+from pipecat.audio.turn.krisp_viva_turn import KrispTurnParams, KrispVivaTurn
 from pipecat.audio.vad.silero import SileroVADAnalyzer
 from pipecat.audio.vad.vad_analyzer import VADParams
 from pipecat.frames.frames import LLMRunFrame
@@ -22,7 +22,7 @@ from pipecat.frames.frames import (
 )
 from pipecat.pipeline.pipeline import Pipeline
 from pipecat.pipeline.runner import PipelineRunner
-from pipecat.pipeline.task import PipelineTask
+from pipecat.pipeline.task import PipelineParams, PipelineTask
 from pipecat.processors.aggregators.llm_context import LLMContext
 from pipecat.processors.aggregators.llm_response_universal import (
     LLMContextAggregatorPair,
@@ -17,7 +17,7 @@ from pipecat.adapters.schemas.tools_schema import ToolsSchema
 from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
 from pipecat.audio.vad.silero import SileroVADAnalyzer
 from pipecat.audio.vad.vad_analyzer import VADParams
-from pipecat.frames.frames import LLMRunFrame, UserImageRequestFrame
+from pipecat.frames.frames import LLMRunFrame, TTSSpeakFrame, UserImageRequestFrame
 from pipecat.pipeline.pipeline import Pipeline
 from pipecat.pipeline.runner import PipelineRunner
 from pipecat.pipeline.task import PipelineParams, PipelineTask
@@ -22,6 +22,7 @@ from pipecat.pipeline.runner import PipelineRunner
 from pipecat.pipeline.task import PipelineParams, PipelineTask
 from pipecat.processors.aggregators.llm_context import LLMContext
 from pipecat.processors.aggregators.llm_response_universal import LLMContextAggregatorPair
+from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
 from pipecat.runner.types import RunnerArguments
 from pipecat.runner.utils import create_transport
 from pipecat.services.aws.nova_sonic.llm import AWSNovaSonicLLMService
@@ -113,14 +114,6 @@ async def load_conversation(params: FunctionCallParams):
     # "content": f"{AWSNovaSonicLLMService.AWAIT_TRIGGER_ASSISTANT_RESPONSE_INSTRUCTION}",
     # }
     # )
-    # If the last message isn't from the user, add a message asking for a recap
-    if messages and messages[-1].get("role") != "user":
-        messages.append(
-            {
-                "role": "user",
-                "content": "Can you catch me up on what we were talking about?",
-            }
-        )
     params.context.set_messages(messages)
     await params.llm.reset_conversation()
     # await params.llm.trigger_assistant_response()
@@ -9,6 +9,7 @@ import asyncio
 import io
 import json
 import os
+import re
 import shutil

 import aiohttp
@@ -13,7 +13,7 @@ from pipecat.adapters.schemas.tools_schema import ToolsSchema
 from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
 from pipecat.audio.vad.silero import SileroVADAnalyzer
 from pipecat.audio.vad.vad_analyzer import VADParams
-from pipecat.frames.frames import LLMRunFrame
+from pipecat.frames.frames import LLMRunFrame, ThoughtTranscriptionMessage, TranscriptionMessage
 from pipecat.pipeline.pipeline import Pipeline
 from pipecat.pipeline.runner import PipelineRunner
 from pipecat.pipeline.task import PipelineParams, PipelineTask
@@ -53,6 +53,8 @@ from pipecat.runner.types import RunnerArguments
 from pipecat.runner.utils import create_transport
 from pipecat.services.grok.realtime.events import (
     SessionProperties,
+    WebSearchTool,
+    XSearchTool,
 )
 from pipecat.services.grok.realtime.llm import GrokRealtimeLLMService
 from pipecat.services.llm_service import FunctionCallParams
@@ -1,11 +1,11 @@
-agent_name = "quickstart-test"
-image = "markatdaily/quickstart-test:latest"
-secret_set = "quickstart-test-secrets"
+agent_name = "quickstart"
+image = "your_username/quickstart:0.1"
+secret_set = "quickstart-secrets"
 agent_profile = "agent-1x"

 # RECOMMENDED: Set an image pull secret:
 # https://docs.pipecat.ai/deployment/pipecat-cloud/fundamentals/secrets#image-pull-secrets
-image_credentials = "dockerhub-access"
+# image_credentials = "your_image_pull_secret"

 [scaling]
 min_agents = 1
@@ -293,13 +293,12 @@ async def run_eval_pipeline(
         "You should only call the eval function if:\n"
         "- The user explicitly attempts to answer the question, AND\n"
         f"- Their answer can be cleanly evaluated using: {eval_config.eval}\n"
-        "Ignore greetings, comments, non-answers, or requests for clarification.\n"
-        "Numerical word answers are allowed (e.g., 'five' is the same as '5').\n"
+        "Ignore greetings, comments, non-answers, or requests for clarification."
     )
     if eval_config.eval_speaks_first:
-        system_prompt = f"You are an evaluation agent, be extremly brief. You will start the conversation by saying: '{example_prompt}'. {common_system_prompt}"
+        system_prompt = f"You are an evaluation agent, be extremly brief. Numerical word answers are allowed. You will start the conversation by saying: '{example_prompt}'. {common_system_prompt}"
     else:
-        system_prompt = f"You are an evaluation agent, be extremly brief. First, ask one question: {example_prompt}. {common_system_prompt}"
+        system_prompt = f"You are an evaluation agent, be extremly brief. Numerical word answers are allowed. First, ask one question: {example_prompt}. {common_system_prompt}"

     messages = [
         {
@@ -137,7 +137,6 @@ TESTS_07 = [
     # ("07zd-interruptible-aicoustics.py", EVAL_SIMPLE_MATH),
     ("07ze-interruptible-hume.py", EVAL_SIMPLE_MATH),
     ("07zf-interruptible-gradium.py", EVAL_SIMPLE_MATH),
-    ("07zg-interruptible-camb.py", EVAL_SIMPLE_MATH),
     ("07zh-interruptible-hathora.py", EVAL_SIMPLE_MATH),
     # Needs a local XTTS docker instance running.
     # ("07i-interruptible-xtts.py", EVAL_SIMPLE_MATH),
@@ -22,7 +22,7 @@ from pathlib import Path

 try:
     import numpy as np
-    import soundfile as sf  # noqa: F401
+    import soundfile as sf
     from audio_file_utils import calculate_audio_stats, read_audio_file, write_audio_file
 except ImportError as e:
     print(f"Error: Missing required dependencies: {e}")
@@ -23,7 +23,7 @@ from pathlib import Path

 try:
     import numpy as np
-    import soundfile as sf  # noqa: F401
+    import soundfile as sf
     from audio_file_utils import read_audio_file
 except ImportError as e:
     print(f"Error: Missing required dependencies: {e}")
@@ -10,7 +10,7 @@ import base64
 import copy
 import json
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, TypedDict
+from typing import Any, Dict, List, Literal, Optional, TypedDict

 from loguru import logger

@@ -9,7 +9,7 @@
 import base64
 import json
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional, TypedDict
+from typing import Any, Dict, List, Optional, Tuple, TypedDict

 from loguru import logger
 from openai import NotGiven
@@ -7,8 +7,10 @@
 """OpenAI LLM adapter for Pipecat."""

 import copy
+import json
 from typing import Any, Dict, List, TypedDict

+from openai._types import NOT_GIVEN as OPEN_AI_NOT_GIVEN
 from openai._types import NotGiven as OpenAINotGiven
 from openai.types.chat import (
     ChatCompletionMessageParam,
@@ -9,6 +9,7 @@
 This module provides an audio filter implementation using Krisp VIVA SDK.
 """

+import asyncio
 import os

 import numpy as np
@@ -34,6 +34,7 @@ from PIL import Image
 from pipecat.adapters.base_llm_adapter import BaseLLMAdapter
 from pipecat.adapters.schemas.tools_schema import ToolsSchema
 from pipecat.frames.frames import AudioRawFrame, Frame
+from pipecat.processors.frame_processor import FrameDirection, FrameProcessor

 # JSON custom encoder to handle bytes arrays so that we can log contexts
 # with images to the console.
@@ -18,7 +18,7 @@ from typing import List

 from loguru import logger

-from pipecat.frames.frames import Frame, TranscriptionFrame
+from pipecat.frames.frames import ErrorFrame, Frame, TranscriptionFrame
 from pipecat.processors.frame_processor import FrameDirection, FrameProcessor


@@ -950,8 +950,7 @@ class FrameProcessor(BaseObject):
         # Process current queue and keep UninterruptibleFrame frames.
         while not self.__process_queue.empty():
             item = self.__process_queue.get_nowait()
-            frame = item[0]
-            if isinstance(frame, UninterruptibleFrame):
+            if isinstance(item, UninterruptibleFrame):
                 new_queue.put_nowait(item)
             self.__process_queue.task_done()

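The two sides of the `FrameProcessor` hunk differ in what they test with `isinstance`. Assuming, as `frame = item[0]` suggests, that each queued item is a tuple whose first element is the frame, a check on the tuple itself can never match. The sketch below uses hypothetical stand-in classes (`Frame`, `UninterruptibleFrame`, `ImportantFrame`) and a hypothetical helper `keep_uninterruptible`; none of these are Pipecat's actual types.

```python
import asyncio


class Frame:
    """Stand-in for a generic frame (hypothetical)."""


class UninterruptibleFrame:
    """Stand-in marker class for frames that must survive an interruption (hypothetical)."""


class ImportantFrame(Frame, UninterruptibleFrame):
    """A frame that carries the uninterruptible marker."""


def keep_uninterruptible(queue: asyncio.Queue) -> asyncio.Queue:
    """Drain `queue` and return a new queue holding only uninterruptible items."""
    new_queue: asyncio.Queue = asyncio.Queue()
    while not queue.empty():
        item = queue.get_nowait()
        frame = item[0]  # assumption: items are (frame, direction) tuples
        if isinstance(frame, UninterruptibleFrame):
            new_queue.put_nowait(item)
        queue.task_done()
    return new_queue
```

Checking `isinstance(item, UninterruptibleFrame)` on the tuple would drop every item, including the ones that were meant to be preserved.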
@@ -263,7 +263,7 @@ def _setup_webrtc_routes(
         """Handle WebRTC offer requests via SmallWebRTCRequestHandler."""

         # Prepare runner arguments with the callback to run your bot
-        async def webrtc_connection_callback(connection: SmallWebRTCConnection):
+        async def webrtc_connection_callback(connection):
             bot_module = _get_bot_module()

             runner_args = SmallWebRTCRunnerArguments(
@@ -406,7 +406,13 @@ def _setup_whatsapp_routes(app: FastAPI):
         return

     try:
+        from pipecat_ai_small_webrtc_prebuilt.frontend import SmallWebRTCPrebuiltUI
+
         from pipecat.transports.smallwebrtc.connection import SmallWebRTCConnection
+        from pipecat.transports.smallwebrtc.request_handler import (
+            SmallWebRTCRequest,
+            SmallWebRTCRequestHandler,
+        )
         from pipecat.transports.whatsapp.api import WhatsAppWebhookRequest
         from pipecat.transports.whatsapp.client import WhatsAppClient
     except ImportError as e:
@@ -148,11 +148,11 @@ class AIService(FrameProcessor):
         await super().process_frame(frame, direction)

         if isinstance(frame, StartFrame):
-            await self._start(frame)
-        elif isinstance(frame, EndFrame):
-            await self._stop(frame)
+            await self.start(frame)
         elif isinstance(frame, CancelFrame):
-            await self._cancel(frame)
+            await self.cancel(frame)
+        elif isinstance(frame, EndFrame):
+            await self.stop(frame)

     async def process_generator(self, generator: AsyncGenerator[Frame | None, None]):
         """Process frames from an async generator.
@@ -169,21 +169,3 @@ class AIService(FrameProcessor):
                     await self.push_error_frame(f)
                 else:
                     await self.push_frame(f)
-
-    async def _start(self, frame: StartFrame):
-        try:
-            await self.start(frame)
-        except Exception as e:
-            logger.error(f"{self}: exception processing {frame}: {e}")
-
-    async def _stop(self, frame: EndFrame):
-        try:
-            await self.stop(frame)
-        except Exception as e:
-            logger.error(f"{self}: exception processing {frame}: {e}")
-
-    async def _cancel(self, frame: CancelFrame):
-        try:
-            await self.cancel(frame)
-        except Exception as e:
-            logger.error(f"{self}: exception processing {frame}: {e}")
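The `_start`/`_stop`/`_cancel` wrappers in this file diff exist so that an exception in a lifecycle hook does not stop `process_frame()` before the frame is forwarded. The sketch below shows the idea in isolation. It is a simplification with hypothetical names (`FailingService`, a string standing in for `StartFrame`), and it assumes the frame is pushed from within `process_frame()`, which is not necessarily how Pipecat structures it.

```python
import asyncio
import logging

logger = logging.getLogger(__name__)


class FailingService:
    """Toy service whose start() hook always raises (hypothetical stand-in)."""

    def __init__(self) -> None:
        self.pushed = []

    async def start(self, frame: str) -> None:
        raise RuntimeError("client initialization failed")

    async def _start(self, frame: str) -> None:
        # Guarded wrapper: log the failure instead of letting it propagate.
        try:
            await self.start(frame)
        except Exception as e:
            logger.error(f"{self}: exception processing {frame}: {e}")

    async def push_frame(self, frame: str) -> None:
        self.pushed.append(frame)

    async def process_frame(self, frame: str) -> None:
        if frame == "StartFrame":
            await self._start(frame)
        # Reached even though start() raised, so downstream still sees the frame.
        await self.push_frame(frame)
```

Calling `self.start(frame)` directly in `process_frame()` would let the `RuntimeError` escape, and the `push_frame` call would never run.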
@@ -296,7 +296,6 @@ class AWSNovaSonicLLMService(LLMService):
         self._user_text_buffer = ""
         self._assistant_text_buffer = ""
         self._completed_tool_calls = set()
-        self._audio_input_started = False

         file_path = files("pipecat.services.aws.nova_sonic").joinpath("ready.wav")
         with wave.open(file_path.open("rb"), "rb") as wav_file:
@@ -533,30 +532,14 @@ class AWSNovaSonicLLMService(LLMService):
             if system_instruction:
                 await self._send_text_event(text=system_instruction, role=Role.SYSTEM)

-            # Send conversation history (except for the last message if it's from the
-            # user, which we'll send as interactive after starting audio input)
-            messages = llm_connection_params["messages"]
-            last_user_message = None
-            for i, message in enumerate(messages):
+            # Send conversation history
+            for message in llm_connection_params["messages"]:
                 # logger.debug(f"Seeding conversation history with message: {message}")
-                is_last_message = i == len(messages) - 1
-                if is_last_message and message.role == Role.USER:
-                    # Save for sending after audio input starts
-                    last_user_message = message
-                else:
-                    await self._send_text_event(text=message.text, role=message.role)
+                await self._send_text_event(text=message.text, role=message.role)

             # Start audio input
             await self._send_audio_input_start_event()

-            # Now send the last user message as interactive to trigger bot response
-            if last_user_message:
-                # logger.debug(
-                # f"Sending last user message as interactive to trigger bot response: {last_user_message}")
-                await self._send_text_event(
-                    text=last_user_message.text, role=last_user_message.role, interactive=True
-                )
-
             # Start receiving events
             self._receive_task = self.create_task(self._receive_task_handler())

@@ -619,7 +602,6 @@ class AWSNovaSonicLLMService(LLMService):
             self._user_text_buffer = ""
             self._assistant_text_buffer = ""
             self._completed_tool_calls = set()
-            self._audio_input_started = False

             logger.info("Finished disconnecting")
         except Exception as e:
@@ -745,18 +727,8 @@ class AWSNovaSonicLLMService(LLMService):
         }}
         '''
         await self._send_client_event(audio_content_start)
-        self._audio_input_started = True

-    async def _send_text_event(self, text: str, role: Role, interactive: bool = False):
-        """Send a text event to the LLM.
-
-        Args:
-            text: The text content to send.
-            role: The role associated with the text (e.g., USER, ASSISTANT, SYSTEM).
-            interactive: Whether the content is interactive. Defaults to False.
-                False: conversation history or system instruction, sent prior to interactive audio
-                True: text input sent during (or at the start of) interactive audio
-        """
+    async def _send_text_event(self, text: str, role: Role):
         if not self._stream or not self._prompt_name or not text:
             return

@@ -769,7 +741,7 @@ class AWSNovaSonicLLMService(LLMService):
             "promptName": "{self._prompt_name}",
             "contentName": "{content_name}",
             "type": "TEXT",
-            "interactive": {json.dumps(interactive)},
+            "interactive": true,
             "role": "{role.value}",
             "textInputConfiguration": {{
                 "mediaType": "text/plain"
@@ -807,7 +779,7 @@ class AWSNovaSonicLLMService(LLMService):
         await self._send_client_event(text_content_end)

     async def _send_user_audio_event(self, audio: bytes):
-        if not self._stream or not self._audio_input_started:
+        if not self._stream:
             return

         blob = base64.b64encode(audio)
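One side of the `@@ -533,30 +532,14 @@` hunk holds back a trailing user message while seeding conversation history, then sends it as interactive text once audio input has started. The selection step on its own can be sketched as below. `Message` and `split_history` are hypothetical names used for illustration, and roles are plain strings here rather than the service's `Role` enum.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class Message:
    """Minimal stand-in for a conversation message (hypothetical)."""

    role: str
    text: str


def split_history(messages: List[Message]) -> Tuple[List[Message], Optional[Message]]:
    """Separate a trailing user message from the history that precedes it.

    Returns (history, last_user_message). The second element is None when the
    conversation is empty or does not end with a user message.
    """
    if messages and messages[-1].role == "user":
        return messages[:-1], messages[-1]
    return list(messages), None
```

The history part would be sent as non-interactive text before audio input starts, and the held-back message, if any, afterwards as interactive text so that it triggers a response.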
@@ -10,6 +10,7 @@ This module provides a WebSocket-based connection to AWS Transcribe for real-tim
 speech-to-text transcription with support for multiple languages and audio formats.
 """

+import asyncio
 import json
 import os
 import random
@@ -10,6 +10,7 @@ This module provides integration with Amazon Polly for text-to-speech synthesis,
 supporting multiple languages, voices, and SSML features.
 """

+import asyncio
 import os
 from typing import AsyncGenerator, List, Optional

@@ -17,8 +17,3 @@ with warnings.catch_warnings():
         DeprecationWarning,
         stacklevel=2,
     )
-
-__all__ = [
-    "AWSNovaSonicLLMService",
-    "Params",
-]
@@ -8,6 +8,8 @@

 from typing import Optional

+from loguru import logger
+
 from pipecat.transcriptions.language import Language, resolve_language


@@ -15,6 +15,7 @@ import io
 from typing import AsyncGenerator

 import aiohttp
+from loguru import logger
 from PIL import Image

 from pipecat.frames.frames import ErrorFrame, Frame, URLImageRawFrame
@@ -199,11 +199,10 @@ class CambTTSService(TTSService):
         """
         super().__init__(sample_rate=sample_rate, **kwargs)

-        self._api_key = api_key
-        self._timeout = timeout
-
         params = params or CambTTSService.InputParams()

+        self._client = AsyncCambAI(api_key=api_key, timeout=timeout)
+
         # Warn if sample rate doesn't match model's supported rate
         if sample_rate and sample_rate != MODEL_SAMPLE_RATES.get(model):
             logger.warning(
@@ -223,8 +222,6 @@ class CambTTSService(TTSService):
         self.set_voice(str(voice_id))
         self._voice_id = voice_id

-        self._client = None
-
     def can_generate_metrics(self) -> bool:
         """Check if this service can generate processing metrics.

@@ -252,8 +249,6 @@ class CambTTSService(TTSService):
         """
         await super().start(frame)

-        self._client = AsyncCambAI(api_key=self._api_key, timeout=self._timeout)
-
         # Use model-specific sample rate if not explicitly specified
         if not self._init_sample_rate:
             self._sample_rate = MODEL_SAMPLE_RATES.get(self.model_name, 22050)
@@ -294,8 +289,6 @@ class CambTTSService(TTSService):
             await self.start_tts_usage_metrics(text)
             yield TTSStartedFrame()

-            assert self._client is not None, "Camb.ai TTS service not initialized"
-
             # Buffer for aligning chunks to 2-byte boundaries (16-bit PCM)
             audio_buffer = b""

@@ -6,6 +6,8 @@

 """Cerebras LLM service implementation using OpenAI-compatible interface."""

+from typing import List
+
 from loguru import logger

 from pipecat.adapters.services.open_ai_adapter import OpenAILLMInvocationParams
@@ -27,6 +27,7 @@ from pipecat.frames.frames import (
     UserStartedSpeakingFrame,
     UserStoppedSpeakingFrame,
 )
+from pipecat.processors.frame_processor import FrameDirection
 from pipecat.services.stt_service import WebsocketSTTService
 from pipecat.transcriptions.language import Language
 from pipecat.utils.time import time_now_iso8601
@@ -6,6 +6,8 @@

 """DeepSeek LLM service implementation using OpenAI-compatible interface."""

+from typing import List
+
 from loguru import logger

 from pipecat.adapters.services.open_ai_adapter import OpenAILLMInvocationParams
@@ -6,6 +6,8 @@

 """Fireworks AI service implementation using OpenAI-compatible interface."""

+from typing import List
+
 from loguru import logger

 from pipecat.adapters.services.open_ai_adapter import OpenAILLMInvocationParams
@@ -1,7 +1,2 @@
 from .file_api import GeminiFileAPI
 from .gemini import GeminiMultimodalLiveLLMService
-
-__all__ = [
-    "GeminiFileAPI",
-    "GeminiMultimodalLiveLLMService",
-]
@@ -1,9 +1,3 @@
 from .file_api import GeminiFileAPI
 from .llm import GeminiLiveLLMService
 from .llm_vertex import GeminiLiveVertexLLMService
-
-__all__ = [
-    "GeminiFileAPI",
-    "GeminiLiveLLMService",
-    "GeminiLiveVertexLLMService",
-]
@@ -40,6 +40,7 @@ from pipecat.frames.frames import (
     LLMThoughtStartFrame,
     LLMThoughtTextFrame,
     LLMUpdateSettingsFrame,
+    OutputImageRawFrame,
     UserImageRawFrame,
 )
 from pipecat.metrics.metrics import LLMTokenUsage
@@ -15,7 +15,9 @@ from typing import List, Literal, Optional

 from pydantic import BaseModel

+from pipecat.frames.frames import Frame
 from pipecat.observers.base_observer import FramePushed
+from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
 from pipecat.processors.frameworks.rtvi import RTVIObserver, RTVIProcessor
 from pipecat.services.google.frames import LLMSearchOrigin, LLMSearchResponseFrame

@@ -29,6 +29,7 @@ from pydantic import BaseModel, Field, field_validator
 from pipecat.frames.frames import (
     CancelFrame,
     EndFrame,
+    ErrorFrame,
     Frame,
     InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Mapping, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -16,6 +16,7 @@ from pipecat import version as pipecat_version
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
StartFrame,
|
||||
|
||||
@@ -16,7 +16,7 @@ from typing import Any, Dict, List, Optional
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.frames.frames import Frame, LLMContextFrame, LLMMessagesFrame
|
||||
from pipecat.frames.frames import ErrorFrame, Frame, LLMContextFrame, LLMMessagesFrame
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.openai_llm_context import (
|
||||
OpenAILLMContext,
|
||||
|
||||
@@ -9,10 +9,12 @@
|
||||
from typing import List, Sequence
|
||||
|
||||
from loguru import logger
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
|
||||
|
||||
from pipecat.adapters.services.open_ai_adapter import OpenAILLMInvocationParams
|
||||
from pipecat.frames.frames import FunctionCallFromLLM
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ def detect_device():
|
||||
and dtype is the recommended torch data type for that device.
|
||||
"""
|
||||
try:
|
||||
import intel_extension_for_pytorch # noqa: F401
|
||||
import intel_extension_for_pytorch
|
||||
|
||||
if torch.xpu.is_available():
|
||||
return torch.device("xpu"), torch.float32
|
||||
|
||||
@@ -134,7 +134,6 @@ class NvidiaSTTService(STTService):
|
||||
|
||||
params = params or NvidiaSTTService.InputParams()
|
||||
|
||||
self._server = server
|
||||
self._api_key = api_key
|
||||
self._use_ssl = use_ssl
|
||||
self._profanity_filter = False
|
||||
@@ -163,53 +162,18 @@ class NvidiaSTTService(STTService):
|
||||
|
||||
self.set_model_name(model_function_map.get("model_name"))
|
||||
|
||||
self._asr_service = None
|
||||
self._queue = None
|
||||
self._config = None
|
||||
self._thread_task = None
|
||||
|
||||
def _initialize_client(self):
|
||||
metadata = [
|
||||
["function-id", self._function_id],
|
||||
["authorization", f"Bearer {self._api_key}"],
|
||||
["authorization", f"Bearer {api_key}"],
|
||||
]
|
||||
auth = riva.client.Auth(None, self._use_ssl, self._server, metadata)
|
||||
auth = riva.client.Auth(None, self._use_ssl, server, metadata)
|
||||
|
||||
self._asr_service = riva.client.ASRService(auth)
|
||||
|
||||
def _create_recognition_config(self):
|
||||
"""Create the NVIDIA Riva ASR recognition configuration."""
|
||||
config = riva.client.StreamingRecognitionConfig(
|
||||
config=riva.client.RecognitionConfig(
|
||||
encoding=riva.client.AudioEncoding.LINEAR_PCM,
|
||||
language_code=self._language_code,
|
||||
model="",
|
||||
max_alternatives=1,
|
||||
profanity_filter=self._profanity_filter,
|
||||
enable_automatic_punctuation=self._automatic_punctuation,
|
||||
verbatim_transcripts=not self._no_verbatim_transcripts,
|
||||
sample_rate_hertz=self.sample_rate,
|
||||
audio_channel_count=1,
|
||||
),
|
||||
interim_results=True,
|
||||
)
|
||||
|
||||
riva.client.add_word_boosting_to_config(
|
||||
config, self._boosted_lm_words, self._boosted_lm_score
|
||||
)
|
||||
|
||||
riva.client.add_endpoint_parameters_to_config(
|
||||
config,
|
||||
self._start_history,
|
||||
self._start_threshold,
|
||||
self._stop_history,
|
||||
self._stop_history_eou,
|
||||
self._stop_threshold,
|
||||
self._stop_threshold_eou,
|
||||
)
|
||||
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
|
||||
|
||||
return config
|
||||
self._queue = None
|
||||
self._config = None
|
||||
self._thread_task = None
|
||||
self._response_task = None
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if this service can generate processing metrics.
|
||||
@@ -242,15 +206,49 @@ class NvidiaSTTService(STTService):
|
||||
frame: StartFrame indicating pipeline start.
|
||||
"""
|
||||
await super().start(frame)
|
||||
self._initialize_client()
|
||||
self._config = self._create_recognition_config()
|
||||
|
||||
if self._config:
|
||||
return
|
||||
|
||||
config = riva.client.StreamingRecognitionConfig(
|
||||
config=riva.client.RecognitionConfig(
|
||||
encoding=riva.client.AudioEncoding.LINEAR_PCM,
|
||||
language_code=self._language_code,
|
||||
model="",
|
||||
max_alternatives=1,
|
||||
profanity_filter=self._profanity_filter,
|
||||
enable_automatic_punctuation=self._automatic_punctuation,
|
||||
verbatim_transcripts=not self._no_verbatim_transcripts,
|
||||
sample_rate_hertz=self.sample_rate,
|
||||
audio_channel_count=1,
|
||||
),
|
||||
interim_results=True,
|
||||
)
|
||||
|
||||
riva.client.add_word_boosting_to_config(
|
||||
config, self._boosted_lm_words, self._boosted_lm_score
|
||||
)
|
||||
|
||||
riva.client.add_endpoint_parameters_to_config(
|
||||
config,
|
||||
self._start_history,
|
||||
self._start_threshold,
|
||||
self._stop_history,
|
||||
self._stop_history_eou,
|
||||
self._stop_threshold,
|
||||
self._stop_threshold_eou,
|
||||
)
|
||||
riva.client.add_custom_configuration_to_config(config, self._custom_configuration)
|
||||
|
||||
self._config = config
|
||||
self._queue = asyncio.Queue()
|
||||
|
||||
if not self._thread_task:
|
||||
self._thread_task = self.create_task(self._thread_task_handler())
|
||||
|
||||
logger.debug(f"Initialized NvidiaSTTService with model: {self.model_name}")
|
||||
if not self._response_task:
|
||||
self._response_queue = asyncio.Queue()
|
||||
self._response_task = self.create_task(self._response_task_handler())
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the NVIDIA Riva STT service and clean up resources.
|
||||
@@ -275,6 +273,10 @@ class NvidiaSTTService(STTService):
|
||||
await self.cancel_task(self._thread_task)
|
||||
self._thread_task = None
|
||||
|
||||
if self._response_task:
|
||||
await self.cancel_task(self._response_task)
|
||||
self._response_task = None
|
||||
|
||||
def _response_handler(self):
|
||||
responses = self._asr_service.streaming_response_generator(
|
||||
audio_chunks=self,
|
||||
@@ -283,7 +285,9 @@ class NvidiaSTTService(STTService):
|
||||
for response in responses:
|
||||
if not response.results:
|
||||
continue
|
||||
asyncio.run_coroutine_threadsafe(self._handle_response(response), self.get_event_loop())
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._response_queue.put(response), self.get_event_loop()
|
||||
)
|
||||
|
||||
async def _thread_task_handler(self):
|
||||
try:
|
||||
@@ -335,6 +339,12 @@ class NvidiaSTTService(STTService):
|
||||
)
|
||||
)
|
||||
|
||||
async def _response_task_handler(self):
|
||||
while True:
|
||||
response = await self._response_queue.get()
|
||||
await self._handle_response(response)
|
||||
self._response_queue.task_done()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Process audio data for speech-to-text transcription.
|
||||
|
||||
@@ -493,6 +503,8 @@ class NvidiaSegmentedSTTService(SegmentedSTTService):
|
||||
auth = riva.client.Auth(None, self._use_ssl, self._server, metadata)
|
||||
self._asr_service = riva.client.ASRService(auth)
|
||||
|
||||
logger.info(f"Initialized NvidiaSegmentedSTTService with model: {self.model_name}")
|
||||
|
||||
def _create_recognition_config(self):
|
||||
"""Create the NVIDIA Riva ASR recognition configuration."""
|
||||
# Create base configuration
|
||||
@@ -560,7 +572,6 @@ class NvidiaSegmentedSTTService(SegmentedSTTService):
|
||||
await super().start(frame)
|
||||
self._initialize_client()
|
||||
self._config = self._create_recognition_config()
|
||||
logger.debug(f"Initialized NvidiaSegmentedSTTService with model: {self.model_name}")
|
||||
|
||||
async def set_language(self, language: Language):
|
||||
"""Set the language for the STT service.
|
||||
@@ -594,12 +605,21 @@ class NvidiaSegmentedSTTService(SegmentedSTTService):
|
||||
Frame: TranscriptionFrame containing the transcribed text.
|
||||
"""
|
||||
try:
|
||||
assert self._asr_service is not None, "ASR service not initialized"
|
||||
assert self._config is not None, "Recognition config not created"
|
||||
|
||||
await self.start_processing_metrics()
|
||||
await self.start_ttfb_metrics()
|
||||
|
||||
# Make sure the client is initialized
|
||||
if self._asr_service is None:
|
||||
self._initialize_client()
|
||||
|
||||
# Make sure the config is created
|
||||
if self._config is None:
|
||||
self._config = self._create_recognition_config()
|
||||
|
||||
# Type assertion to satisfy the IDE
|
||||
assert self._asr_service is not None, "ASR service not initialized"
|
||||
assert self._config is not None, "Recognition config not created"
|
||||
|
||||
# Process audio with NVIDIA Riva ASR - explicitly request non-future response
|
||||
raw_response = self._asr_service.offline_recognize(audio, self._config, future=False)
|
||||
|
||||
@@ -607,40 +627,43 @@ class NvidiaSegmentedSTTService(SegmentedSTTService):
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
# Process the response - handle different possible return types
|
||||
# If it's a future-like object, get the result
|
||||
if hasattr(raw_response, "result"):
|
||||
response = raw_response.result()
|
||||
else:
|
||||
response = raw_response
|
||||
try:
|
||||
# If it's a future-like object, get the result
|
||||
if hasattr(raw_response, "result"):
|
||||
response = raw_response.result()
|
||||
else:
|
||||
response = raw_response
|
||||
|
||||
# Process transcription results
|
||||
transcription_found = False
|
||||
# Process transcription results
|
||||
transcription_found = False
|
||||
|
||||
# Now we can safely check results
|
||||
# Type hint for the IDE
|
||||
results = getattr(response, "results", [])
|
||||
# Now we can safely check results
|
||||
# Type hint for the IDE
|
||||
results = getattr(response, "results", [])
|
||||
|
||||
for result in results:
|
||||
alternatives = getattr(result, "alternatives", [])
|
||||
if alternatives:
|
||||
text = alternatives[0].transcript.strip()
|
||||
if text:
|
||||
logger.debug(f"Transcription: [{text}]")
|
||||
yield TranscriptionFrame(
|
||||
text,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
self._language_enum,
|
||||
)
|
||||
transcription_found = True
|
||||
for result in results:
|
||||
alternatives = getattr(result, "alternatives", [])
|
||||
if alternatives:
|
||||
text = alternatives[0].transcript.strip()
|
||||
if text:
|
||||
logger.debug(f"Transcription: [{text}]")
|
||||
yield TranscriptionFrame(
|
||||
text,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
self._language_enum,
|
||||
)
|
||||
transcription_found = True
|
||||
|
||||
await self._handle_transcription(text, True, self._language_enum)
|
||||
await self._handle_transcription(text, True, self._language_enum)
|
||||
|
||||
if not transcription_found:
|
||||
logger.debug("No transcription results found in NVIDIA Riva response")
|
||||
|
||||
except AttributeError as ae:
|
||||
logger.error(f"Unexpected response structure from NVIDIA Riva: {ae}")
|
||||
yield ErrorFrame(f"Unexpected NVIDIA Riva response format: {str(ae)}")
|
||||
|
||||
if not transcription_found:
|
||||
logger.debug(f"{self}: No transcription results found in NVIDIA Riva response")
|
||||
except AttributeError as ae:
|
||||
logger.error(f"{self}: Unexpected response structure from NVIDIA Riva: {ae}")
|
||||
yield ErrorFrame(f"{self}: Unexpected NVIDIA Riva response format: {str(ae)}")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
@@ -12,7 +12,7 @@ gRPC API for high-quality speech synthesis.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import AsyncGenerator, AsyncIterable, Generator, Mapping, Optional
|
||||
from typing import AsyncGenerator, Mapping, Optional
|
||||
|
||||
from pipecat.utils.tracing.service_decorators import traced_tts
|
||||
|
||||
@@ -25,7 +25,6 @@ from pydantic import BaseModel
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
StartFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
@@ -35,12 +34,14 @@ from pipecat.transcriptions.language import Language
|
||||
|
||||
try:
|
||||
import riva.client
|
||||
import riva.client.proto.riva_tts_pb2 as rtts
|
||||
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use NVIDIA Riva TTS, you need to `pip install pipecat-ai[nvidia]`.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
NVIDIA_TTS_TIMEOUT_SECS = 5
|
||||
|
||||
|
||||
class NvidiaTTSService(TTSService):
|
||||
"""NVIDIA Riva text-to-speech service.
|
||||
@@ -92,7 +93,6 @@ class NvidiaTTSService(TTSService):
|
||||
|
||||
params = params or NvidiaTTSService.InputParams()
|
||||
|
||||
self._server = server
|
||||
self._api_key = api_key
|
||||
self._voice_id = voice_id
|
||||
self._language_code = params.language
|
||||
@@ -102,8 +102,18 @@ class NvidiaTTSService(TTSService):
|
||||
self.set_model_name(model_function_map.get("model_name"))
|
||||
self.set_voice(voice_id)
|
||||
|
||||
self._service = None
|
||||
self._config = None
|
||||
metadata = [
|
||||
["function-id", self._function_id],
|
||||
["authorization", f"Bearer {api_key}"],
|
||||
]
|
||||
auth = riva.client.Auth(None, self._use_ssl, server, metadata)
|
||||
|
||||
self._service = riva.client.SpeechSynthesisService(auth)
|
||||
|
||||
# warm up the service
|
||||
config_response = self._service.stub.GetRivaSynthesisConfig(
|
||||
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
|
||||
)
|
||||
|
||||
async def set_model(self, model: str):
|
||||
"""Attempt to set the TTS model.
|
||||
@@ -119,39 +129,6 @@ class NvidiaTTSService(TTSService):
|
||||
f"{self.__class__.__name__}(api_key=<api_key>, model_function_map={example})"
|
||||
)
|
||||
|
||||
def _initialize_client(self):
|
||||
if self._service is not None:
|
||||
return
|
||||
|
||||
metadata = [
|
||||
["function-id", self._function_id],
|
||||
["authorization", f"Bearer {self._api_key}"],
|
||||
]
|
||||
auth = riva.client.Auth(None, self._use_ssl, self._server, metadata)
|
||||
|
||||
self._service = riva.client.SpeechSynthesisService(auth)
|
||||
|
||||
def _create_synthesis_config(self):
|
||||
if not self._service:
|
||||
return
|
||||
|
||||
# warm up the service
|
||||
config = self._service.stub.GetRivaSynthesisConfig(
|
||||
riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
|
||||
)
|
||||
return config
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the NVIDIA Riva TTS service.
|
||||
|
||||
Args:
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
await super().start(frame)
|
||||
self._initialize_client()
|
||||
self._config = self._create_synthesis_config()
|
||||
logger.debug(f"Initialized NvidiaTTSService with model: {self.model_name}")
|
||||
|
||||
@traced_tts
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
"""Generate speech from text using NVIDIA Riva TTS.
|
||||
@@ -163,43 +140,39 @@ class NvidiaTTSService(TTSService):
|
||||
Frame: Audio frames containing the synthesized speech data.
|
||||
"""
|
||||
|
||||
def read_audio_responses() -> Generator[rtts.SynthesizeSpeechResponse, None, None]:
|
||||
responses = self._service.synthesize_online(
|
||||
text,
|
||||
self._voice_id,
|
||||
self._language_code,
|
||||
sample_rate_hz=self.sample_rate,
|
||||
zero_shot_audio_prompt_file=None,
|
||||
zero_shot_quality=self._quality,
|
||||
custom_dictionary={},
|
||||
)
|
||||
return responses
|
||||
def read_audio_responses(queue: asyncio.Queue):
|
||||
def add_response(r):
|
||||
asyncio.run_coroutine_threadsafe(queue.put(r), self.get_event_loop())
|
||||
|
||||
def async_next(it):
|
||||
try:
|
||||
return next(it)
|
||||
except StopIteration:
|
||||
return None
|
||||
responses = self._service.synthesize_online(
|
||||
text,
|
||||
self._voice_id,
|
||||
self._language_code,
|
||||
sample_rate_hz=self.sample_rate,
|
||||
zero_shot_audio_prompt_file=None,
|
||||
zero_shot_quality=self._quality,
|
||||
custom_dictionary={},
|
||||
)
|
||||
for r in responses:
|
||||
add_response(r)
|
||||
add_response(None)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
add_response(None)
|
||||
|
||||
async def async_iterator(iterator) -> AsyncIterable[rtts.SynthesizeSpeechResponse]:
|
||||
while True:
|
||||
item = await asyncio.to_thread(async_next, iterator)
|
||||
if item is None:
|
||||
return
|
||||
yield item
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
try:
|
||||
assert self._service is not None, "TTS service not initialized"
|
||||
assert self._config is not None, "Synthesis configuration not created"
|
||||
queue = asyncio.Queue()
|
||||
await asyncio.to_thread(read_audio_responses, queue)
|
||||
|
||||
await self.start_ttfb_metrics()
|
||||
yield TTSStartedFrame()
|
||||
|
||||
logger.debug(f"{self}: Generating TTS [{text}]")
|
||||
|
||||
responses = await asyncio.to_thread(read_audio_responses)
|
||||
|
||||
async for resp in async_iterator(responses):
|
||||
# Wait for the thread to start.
|
||||
resp = await asyncio.wait_for(queue.get(), timeout=NVIDIA_TTS_TIMEOUT_SECS)
|
||||
while resp:
|
||||
await self.stop_ttfb_metrics()
|
||||
frame = TTSAudioRawFrame(
|
||||
audio=resp.audio,
|
||||
@@ -207,12 +180,10 @@ class NvidiaTTSService(TTSService):
|
||||
num_channels=1,
|
||||
)
|
||||
yield frame
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStoppedFrame()
|
||||
resp = await asyncio.wait_for(queue.get(), timeout=NVIDIA_TTS_TIMEOUT_SECS)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"{self} timeout waiting for audio response")
|
||||
yield ErrorFrame(error=f"{self} error: timeout waiting for audio response")
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception: {e}")
|
||||
yield ErrorFrame(error=f"{self} error: {e}")
|
||||
|
||||
await self.start_tts_usage_metrics(text)
|
||||
yield TTSStoppedFrame()
|
||||
|
||||
@@ -25,13 +25,3 @@ with warnings.catch_warnings():
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AzureRealtimeLLMService",
|
||||
"InputAudioNoiseReduction",
|
||||
"InputAudioTranscription",
|
||||
"SemanticTurnDetection",
|
||||
"SessionProperties",
|
||||
"TurnDetection",
|
||||
"OpenAIRealtimeLLMService",
|
||||
]
|
||||
|
||||
@@ -7,13 +7,3 @@ from .events import (
|
||||
TurnDetection,
|
||||
)
|
||||
from .openai import OpenAIRealtimeBetaLLMService
|
||||
|
||||
__all__ = [
|
||||
"AzureRealtimeBetaLLMService",
|
||||
"InputAudioNoiseReduction",
|
||||
"InputAudioTranscription",
|
||||
"SemanticTurnDetection",
|
||||
"SessionProperties",
|
||||
"TurnDetection",
|
||||
"OpenAIRealtimeBetaLLMService",
|
||||
]
|
||||
|
||||
@@ -10,7 +10,7 @@ This module provides an OpenPipe-specific implementation of the OpenAI LLM servi
|
||||
enabling integration with OpenPipe's fine-tuning and monitoring capabilities.
|
||||
"""
|
||||
|
||||
from typing import Dict, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from openai.types import chat as openai_chat_types
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
|
||||
@@ -33,7 +33,7 @@ if TYPE_CHECKING:
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
try:
|
||||
import mlx_whisper # noqa: F401
|
||||
import mlx_whisper
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use Whisper, you need to `pip install pipecat-ai[mlx-whisper]`.")
|
||||
|
||||
@@ -8,6 +8,8 @@
|
||||
|
||||
import warnings
|
||||
|
||||
from pipecat.utils.sync.base_notifier import BaseNotifier
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
|
||||
@@ -8,6 +8,8 @@
|
||||
|
||||
import warnings
|
||||
|
||||
from pipecat.utils.sync.event_notifier import EventNotifier
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("always")
|
||||
warnings.warn(
|
||||
|
||||
@@ -759,11 +759,7 @@ class DailyTransportClient(EventHandler):
|
||||
# Increment leave counter if we successfully joined.
|
||||
self._leave_counter += 1
|
||||
|
||||
participant_id = data.get("participants", {}).get("local", {}).get("id")
|
||||
meeting_id = data.get("meetingSession", {}).get("id")
|
||||
logger.info(
|
||||
f"Joined {self._room_url}. Participant ID: {participant_id}, Meeting ID: {meeting_id}"
|
||||
)
|
||||
logger.info(f"Joined {self._room_url}")
|
||||
|
||||
await self._callbacks.on_joined(data)
|
||||
|
||||
|
||||
@@ -539,14 +539,11 @@ class LiveKitTransportClient:
|
||||
elif track.kind == rtc.TrackKind.KIND_VIDEO:
|
||||
logger.info(f"Video track subscribed: {track.sid} from participant {participant.sid}")
|
||||
self._video_tracks[participant.sid] = track
|
||||
# Only process video stream if video input is enabled to prevent
|
||||
# unbounded queue growth when there is no consumer for video frames.
|
||||
if self._params.video_in_enabled:
|
||||
video_stream = rtc.VideoStream(track)
|
||||
self._task_manager.create_task(
|
||||
self._process_video_stream(video_stream, participant.sid),
|
||||
f"{self}::_process_video_stream",
|
||||
)
|
||||
video_stream = rtc.VideoStream(track)
|
||||
self._task_manager.create_task(
|
||||
self._process_video_stream(video_stream, participant.sid),
|
||||
f"{self}::_process_video_stream",
|
||||
)
|
||||
await self._callbacks.on_video_track_subscribed(participant.sid)
|
||||
|
||||
async def _async_on_track_unsubscribed(
|
||||
|
||||
@@ -14,7 +14,6 @@ for real-time communication applications.
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, List, Literal, Optional, Union
|
||||
|
||||
from loguru import logger
|
||||
@@ -24,6 +23,7 @@ from pipecat.utils.base_object import BaseObject
|
||||
|
||||
try:
|
||||
from aiortc import (
|
||||
MediaStreamTrack,
|
||||
RTCConfiguration,
|
||||
RTCIceServer,
|
||||
RTCPeerConnection,
|
||||
@@ -278,7 +278,7 @@ class SmallWebRTCConnection(BaseObject):
|
||||
|
||||
self._answer: Optional[RTCSessionDescription] = None
|
||||
self._pc = RTCPeerConnection(rtc_config)
|
||||
self._pc_id = f"{self.name}-{uuid.uuid4().hex}"
|
||||
self._pc_id = self.name
|
||||
self._setup_listeners()
|
||||
self._data_channel = None
|
||||
self._renegotiation_in_progress = False
|
||||
|
||||
@@ -22,11 +22,3 @@ with warnings.catch_warnings():
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AlwaysUserMuteStrategy",
|
||||
"BaseUserMuteStrategy",
|
||||
"FirstSpeechUserMuteStrategy",
|
||||
"FunctionCallUserMuteStrategy",
|
||||
"MuteUntilFirstBotCompleteUserMuteStrategy",
|
||||
]
|
||||
|
||||
@@ -10,6 +10,7 @@ import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from pipecat.audio.turn.base_turn_analyzer import BaseTurnAnalyzer, EndOfTurnState
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
|
||||
@@ -16,12 +16,15 @@ import inspect
|
||||
import traceback
|
||||
from abc import ABC
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.utils.utils import obj_count, obj_id
|
||||
|
||||
# TypeVar for preserving function signatures in decorators
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
@dataclass
|
||||
class EventHandler:
|
||||
@@ -99,7 +102,7 @@ class BaseObject(ABC):
|
||||
logger.debug(f"{self}: waiting on event handlers to finish {list(event_names)}...")
|
||||
await asyncio.wait(tasks)
|
||||
|
||||
def event_handler(self, event_name: str):
|
||||
def event_handler(self, event_name: str) -> Callable[[F], F]:
|
||||
"""Decorator for registering event handlers.
|
||||
|
||||
Args:
|
||||
@@ -109,7 +112,7 @@ class BaseObject(ABC):
|
||||
The decorator function that registers the handler.
|
||||
"""
|
||||
|
||||
def decorator(handler):
|
||||
def decorator(handler: F) -> F:
|
||||
self.add_event_handler(event_name, handler)
|
||||
return handler
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ as a unit regardless of internal punctuation.
|
||||
from typing import AsyncIterator, Optional, Sequence
|
||||
|
||||
from pipecat.utils.string import StartEndTags, parse_start_end_tags
|
||||
from pipecat.utils.text.base_text_aggregator import Aggregation
|
||||
from pipecat.utils.text.base_text_aggregator import Aggregation, AggregationType
|
||||
from pipecat.utils.text.simple_text_aggregator import SimpleTextAggregator
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#
|
||||
|
||||
import os
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
|
||||
@@ -43,6 +43,7 @@ For AWS Bedrock adapter:
|
||||
import unittest
|
||||
|
||||
from google.genai.types import Content, Part
|
||||
from openai.types.chat import ChatCompletionMessage
|
||||
|
||||
from pipecat.adapters.services.anthropic_adapter import AnthropicLLMAdapter
|
||||
from pipecat.adapters.services.bedrock_adapter import AWSBedrockLLMAdapter
|
||||
@@ -50,6 +51,7 @@ from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
|
||||
from pipecat.adapters.services.open_ai_adapter import OpenAILLMAdapter
|
||||
from pipecat.processors.aggregators.llm_context import (
|
||||
LLMContext,
|
||||
LLMSpecificMessage,
|
||||
LLMStandardMessage,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,124 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Tests for LiveKit transport video stream handling.
|
||||
|
||||
Regression tests for issue #3116: Memory leak when video_in_enabled=False
|
||||
but video tracks are subscribed. The fix ensures video stream processing
|
||||
only starts when there is a consumer for the frames.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
try:
|
||||
from livekit import rtc
|
||||
|
||||
from pipecat.transports.livekit.transport import (
|
||||
LiveKitCallbacks,
|
||||
LiveKitParams,
|
||||
LiveKitTransportClient,
|
||||
)
|
||||
|
||||
LIVEKIT_AVAILABLE = True
|
||||
except ImportError:
|
||||
LIVEKIT_AVAILABLE = False
|
||||
|
||||
|
||||
@unittest.skipUnless(LIVEKIT_AVAILABLE, "livekit package not installed")
|
||||
class TestLiveKitVideoStreamMemoryLeak(unittest.IsolatedAsyncioTestCase):
|
||||
"""Regression tests for video queue memory leak (#3116).
|
||||
|
||||
The bug: When video_in_enabled=False, subscribing to a video track would
|
||||
start a producer that fills _video_queue, but no consumer would drain it,
|
||||
causing unbounded memory growth (~3GB/min).
|
||||
|
||||
The fix: Only start video stream processing when video_in_enabled=True.
|
||||
"""
|
||||
|
||||
    def _create_client(self, video_in_enabled: bool) -> LiveKitTransportClient:
        """Create a client with the specified video input setting."""
        params = LiveKitParams(video_in_enabled=video_in_enabled)
        callbacks = LiveKitCallbacks(
            on_connected=AsyncMock(),
            on_disconnected=AsyncMock(),
            on_before_disconnect=AsyncMock(),
            on_participant_connected=AsyncMock(),
            on_participant_disconnected=AsyncMock(),
            on_audio_track_subscribed=AsyncMock(),
            on_audio_track_unsubscribed=AsyncMock(),
            on_video_track_subscribed=AsyncMock(),
            on_video_track_unsubscribed=AsyncMock(),
            on_data_received=AsyncMock(),
            on_first_participant_joined=AsyncMock(),
        )
        client = LiveKitTransportClient(
            url="wss://test.livekit.cloud",
            token="test-token",
            room_name="test-room",
            params=params,
            callbacks=callbacks,
            transport_name="test-transport",
        )
        client._task_manager = MagicMock()
        return client

    def _create_mock_video_track(self):
        """Create a mock video track subscription event."""
        track = MagicMock()
        track.kind = rtc.TrackKind.KIND_VIDEO
        track.sid = "video-track-123"
        publication = MagicMock()
        participant = MagicMock()
        participant.sid = "participant-456"
        return track, publication, participant

    async def test_disabled_video_input_does_not_start_queue_producer(self):
        """When video input is disabled, no producer should fill the queue.

        This prevents the memory leak where frames accumulate with no consumer.
        """
        client = self._create_client(video_in_enabled=False)
        track, publication, participant = self._create_mock_video_track()

        await client._async_on_track_subscribed(track, publication, participant)

        # Verify no video processing task was started
        task_names = [call[0][1] for call in client._task_manager.create_task.call_args_list]
        video_tasks = [name for name in task_names if "video" in name.lower()]
        self.assertEqual(video_tasks, [], "No video processing task should be started")

        # Queue should remain empty
        self.assertEqual(client._video_queue.qsize(), 0)

        # Track metadata should still be recorded
        self.assertIn(participant.sid, client._video_tracks)

        # Callback should still fire for user code
        client._callbacks.on_video_track_subscribed.assert_called_once()

    async def test_enabled_video_input_starts_queue_producer(self):
        """When video input is enabled, the producer should start."""
        client = self._create_client(video_in_enabled=True)
        track, publication, participant = self._create_mock_video_track()

        with patch.object(rtc, "VideoStream"):
            await client._async_on_track_subscribed(track, publication, participant)

        # Verify video processing task was started
        task_names = [call[0][1] for call in client._task_manager.create_task.call_args_list]
        video_tasks = [name for name in task_names if "video" in name.lower()]
        self.assertEqual(len(video_tasks), 1, "Video processing task should be started")

        # Track metadata should be recorded
        self.assertIn(participant.sid, client._video_tracks)

        # Callback should fire
        client._callbacks.on_video_track_subscribed.assert_called_once()


if __name__ == "__main__":
    unittest.main()
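The two tests above pin down one behaviour: a track subscription always records metadata, but the queue producer is only started when video input is enabled, so frames cannot pile up in a queue nobody reads. As a rough standalone illustration of that gating pattern, here is a minimal sketch. The `FrameSource` class, its attributes, and the `run` helper are hypothetical names invented for this example and are not pipecat's `LiveKitTransportClient` implementation.

```python
import asyncio
from typing import Optional


class FrameSource:
    """Hypothetical stand-in for a transport client that receives video frames."""

    def __init__(self, video_in_enabled: bool):
        self.video_in_enabled = video_in_enabled
        self.queue = asyncio.Queue()  # unbounded, so an unread queue grows forever
        self.tracks = {}  # participant id -> track id
        self.producer: Optional[asyncio.Task] = None

    async def _produce(self, frames):
        # Copy every incoming frame into the queue for a downstream consumer.
        for frame in frames:
            await self.queue.put(frame)

    async def on_track_subscribed(self, participant_id, track_id, frames):
        # Track metadata is recorded regardless of the input setting.
        self.tracks[participant_id] = track_id
        # Only start the producer when something will consume the queue.
        if self.video_in_enabled:
            self.producer = asyncio.create_task(self._produce(frames))


async def run(video_in_enabled: bool) -> int:
    """Subscribe one track carrying 100 frames and return the final queue size."""
    source = FrameSource(video_in_enabled)
    await source.on_track_subscribed("participant-456", "video-track-123", range(100))
    if source.producer is not None:
        await source.producer
    return source.queue.qsize()


disabled_size = asyncio.run(run(False))
enabled_size = asyncio.run(run(True))
```

With input disabled the queue stays empty because no producer task exists. With input enabled all 100 frames are queued and wait for a consumer.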
@@ -5,7 +5,7 @@
#

import unittest
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, MagicMock, patch

import numpy as np


@@ -4,6 +4,7 @@
# SPDX-License-Identifier: BSD 2-Clause License
#

import asyncio
import sys
import unittest
from unittest.mock import MagicMock, patch

@@ -10,6 +10,7 @@ from pipecat.frames.frames import (
    BotStartedSpeakingFrame,
    BotStoppedSpeakingFrame,
    FunctionCallFromLLM,
    FunctionCallInProgressFrame,
    FunctionCallResultFrame,
    FunctionCallsStartedFrame,
    InputAudioRawFrame,