diff --git a/.github/workflows/format.yaml b/.github/workflows/format.yaml index 444e24338..7d1219411 100644 --- a/.github/workflows/format.yaml +++ b/.github/workflows/format.yaml @@ -17,7 +17,7 @@ concurrency: jobs: ruff-format: - name: "Formatting checker" + name: "Code quality checks" runs-on: ubuntu-latest steps: - name: Checkout repo @@ -39,8 +39,8 @@ jobs: run: | source .venv/bin/activate ruff format --diff - - name: Ruff import linter + - name: Ruff linter (all rules) id: ruff-check run: | source .venv/bin/activate - ruff check --select I + ruff check diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1e0594d64..bee07b8e0 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -71,6 +71,21 @@ We follow Google-style docstrings with these specific conventions: - Use `Parameters:` section to document each enum value and its meaning - No `__init__` docstring (Enums don't have custom constructors) +**Code Examples in Docstrings:** + +- Use `Examples:` as a section header for multiple examples +- Use descriptive text followed by double colons (`::`) for each example +- **Always include a blank line after the `::`** +- Indent all code consistently within each block +- Separate multiple examples with blank lines for readability + +**Lists and Bullets in Docstrings:** + +- Use dashes (`-`) for bullet points, not asterisks (`*`) +- **Add a blank line before bullet lists** when they follow a colon +- Use section headers like "Supported features:" or "Behavior:" before lists +- For complex nested information, consider using paragraph format instead + #### Examples: ```python @@ -80,6 +95,12 @@ class MyService(BaseService): Provides detailed explanation of the service's functionality, key features, and usage patterns. 
+ + Supported features: + + - Feature one with detailed explanation + - Feature two with additional context + - Feature three for advanced use cases """ def __init__(self, param1: str, param2: bool = True, **kwargs): @@ -127,6 +148,34 @@ class ConfigParams: port: int = 8080 timeout: float = 30.0 +# Dataclass with code examples +@dataclass +class MessageFrame: + """Frame containing messages in OpenAI format. + + Supports both simple and content list message formats. + + Examples: + Simple format:: + + [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"} + ] + + Content list format:: + + [ + {"role": "user", "content": [{"type": "text", "text": "Hello"}]}, + {"role": "assistant", "content": [{"type": "text", "text": "Hi there!"}]} + ] + + Parameters: + messages: List of messages in OpenAI format. + """ + + messages: List[dict] + # Enum class class Status(Enum): """Status codes for processing operations. diff --git a/docs/api/conf.py b/docs/api/conf.py index b69c62bbb..31c9fac25 100644 --- a/docs/api/conf.py +++ b/docs/api/conf.py @@ -26,6 +26,10 @@ extensions = [ "sphinx.ext.intersphinx", ] +suppress_warnings = [ + "autodoc.mocked_object", +] + # Napoleon settings napoleon_google_docstring = True napoleon_include_init_with_doc = True @@ -71,7 +75,6 @@ autodoc_mock_imports = [ "langchain", "lmnt", "noisereduce", - "openai", "openpipe", "simli", "soundfile", @@ -81,10 +84,6 @@ autodoc_mock_imports = [ "tkinter", "daily", "daily_python", - "pydantic.BaseModel", - "pydantic.Field", - "pydantic._internal._model_construction", - "pydantic._internal._fields", # Moondream dependencies "torch", "transformers", @@ -167,6 +166,19 @@ autodoc_mock_imports = [ "mcp.client.stdio", "mcp.ClientSession", "mcp.StdioServerParameters", + # gstreamer + "gi", + "gi.require_version", + "gi.repository", + # Protobuf mocks + "pipecat.frames.protobufs.frames_pb2", + "pipecat.serializers.protobuf", + "google.protobuf", + "google.protobuf.descriptor", + 
"google.protobuf.descriptor_pool", + "google.protobuf.runtime_version", + "google.protobuf.symbol_database", + "google.protobuf.internal.builder", ] # HTML output settings @@ -176,76 +188,32 @@ autodoc_typehints = "signature" # Show type hints in the signature only, not in html_show_sphinx = False -def verify_modules(): - """Verify that required modules are available.""" - required_modules = { - "services": [ - "assemblyai", - "aws", - "cartesia", - "deepgram", - "google", - "lmnt", - "riva", - "simli", - ], - "serializers": ["livekit"], - "vad": ["silero", "vad_analyzer"], - "transports": { - "services": ["daily", "livekit"], - "local": ["audio", "tk"], - "network": ["fastapi_websocket", "websocket_server"], - }, - } +def import_core_modules(): + """Import core pipecat modules for autodoc to discover.""" + core_modules = [ + "pipecat", + "pipecat.frames", + "pipecat.pipeline", + "pipecat.processors", + "pipecat.services", + "pipecat.transports", + "pipecat.audio", + "pipecat.adapters", + "pipecat.clocks", + "pipecat.metrics", + "pipecat.observers", + "pipecat.serializers", + "pipecat.sync", + "pipecat.transcriptions", + "pipecat.utils", + ] - # Skip importing modules that are in autodoc_mock_imports - skipped_modules = set(autodoc_mock_imports) - - missing = [] - for category, modules in required_modules.items(): - if isinstance(modules, dict): - # Handle nested structure - for subcategory, submodules in modules.items(): - for module in submodules: - # Check if module is in autodoc_mock_imports - if ( - f"pipecat.{category}.{subcategory}.{module}" in skipped_modules - or module in skipped_modules - ): - logger.info( - f"Skipping import of mocked module: pipecat.{category}.{subcategory}.{module}" - ) - continue - - try: - __import__(f"pipecat.{category}.{subcategory}.{module}") - logger.info( - f"Successfully imported pipecat.{category}.{subcategory}.{module}" - ) - except (ImportError, TypeError, NameError) as e: - 
missing.append(f"pipecat.{category}.{subcategory}.{module}") - logger.warning( - f"Optional module not available: pipecat.{category}.{subcategory}.{module} - {str(e)}" - ) - else: - # Handle flat structure - for module in modules: - # Check if module is in autodoc_mock_imports - if f"pipecat.{category}.{module}" in skipped_modules or module in skipped_modules: - logger.info(f"Skipping import of mocked module: pipecat.{category}.{module}") - continue - - try: - __import__(f"pipecat.{category}.{module}") - logger.info(f"Successfully imported pipecat.{category}.{module}") - except (ImportError, TypeError, NameError) as e: - missing.append(f"pipecat.{category}.{module}") - logger.warning( - f"Optional module not available: pipecat.{category}.{module} - {str(e)}" - ) - - if missing: - logger.warning(f"Some optional modules are not available: {missing}") + for module_name in core_modules: + try: + __import__(module_name) + logger.info(f"Successfully imported {module_name}") + except ImportError as e: + logger.warning(f"Failed to import {module_name}: {e}") def clean_title(title: str) -> str: @@ -257,40 +225,7 @@ def clean_title(title: str) -> str: parts = title.split(".") title = parts[-1] - # Special cases for service names and common acronyms - special_cases = { - "ai": "AI", - "aws": "AWS", - "api": "API", - "vad": "VAD", - "assemblyai": "AssemblyAI", - "deepgram": "Deepgram", - "elevenlabs": "ElevenLabs", - "openai": "OpenAI", - "openpipe": "OpenPipe", - "playht": "PlayHT", - "xtts": "XTTS", - "lmnt": "LMNT", - "stt": "STT", - "tts": "TTS", - "llm": "LLM", - "rtvi": "RTVI", - } - - # Check if the entire title is a special case - if title.lower() in special_cases: - return special_cases[title.lower()] - - # Otherwise, capitalize each word - words = title.split("_") - cleaned_words = [] - for word in words: - if word.lower() in special_cases: - cleaned_words.append(special_cases[word.lower()]) - else: - cleaned_words.append(word.capitalize()) - - return " 
".join(cleaned_words) + return title def setup(app): @@ -315,9 +250,8 @@ def setup(app): excludes = [ str(project_root / "src/pipecat/pipeline/to_be_updated"), - str(project_root / "src/pipecat/processors/gstreamer"), - str(project_root / "src/pipecat/services/to_be_updated"), - str(project_root / "src/pipecat/vad"), # deprecated + str(project_root / "src/pipecat/examples"), + str(project_root / "src/pipecat/tests"), "**/test_*.py", "**/tests/*.py", ] @@ -358,5 +292,4 @@ def setup(app): logger.error(f"Error generating API documentation: {e}", exc_info=True) -# Run module verification -verify_modules() +import_core_modules() diff --git a/docs/api/index.rst b/docs/api/index.rst index 199aed1dd..344cf3ec6 100644 --- a/docs/api/index.rst +++ b/docs/api/index.rst @@ -1,57 +1,17 @@ -Pipecat API Reference Docs -========================== +Pipecat API Reference +===================== -Welcome to Pipecat's API reference documentation! +Welcome to the Pipecat API reference. -Pipecat is an open source framework for building voice and multimodal assistants. -It provides a flexible pipeline architecture for connecting various AI services, -audio processing, and transport layers. +Use the navigation on the left to browse modules, or search using the search box. + +**New to Pipecat?** Check out the `main documentation `_ for tutorials, guides, and client SDK information. 
Quick Links ----------- * `GitHub Repository `_ -* `Website `_ - -API Reference -------------- - -Core Components -~~~~~~~~~~~~~~~ - -* :mod:`Frames ` -* :mod:`Processors ` -* :mod:`Pipeline ` - -Audio Processing -~~~~~~~~~~~~~~~~ - -* :mod:`Audio ` - -Services -~~~~~~~~ - -* :mod:`Services ` - -Transport & Serialization -~~~~~~~~~~~~~~~~~~~~~~~~~ - -* :mod:`Transports ` - * :mod:`Local ` - * :mod:`Network ` - * :mod:`Services ` -* :mod:`Serializers ` - -Utilities -~~~~~~~~~ - -* :mod:`Adapters ` -* :mod:`Clocks ` -* :mod:`Metrics ` -* :mod:`Observers ` -* :mod:`Sync ` -* :mod:`Transcriptions ` -* :mod:`Utils ` +* `Join our Community `_ .. toctree:: :maxdepth: 3 @@ -71,11 +31,4 @@ Utilities Sync Transcriptions Transports - Utils - -Indices and tables -================== - -* :ref:`genindex` -* :ref:`modindex` -* :ref:`search` \ No newline at end of file + Utils \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index f402ddfd1..9dfd84ec6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,9 +123,21 @@ select = [ "D", # Docstring rules "I", # Import rules ] +ignore = [ + "D105", # Missing docstring in magic methods (__str__, __repr__, etc.) +] [tool.ruff.lint.per-file-ignores] +# Skip docstring checks for non-source code +"examples/**/*.py" = ["D"] +"tests/**/*.py" = ["D"] +"scripts/**/*.py" = ["D"] +"docs/**/*.py" = ["D"] +# Skip D104 (missing docstring in public package) for __init__.py files "**/__init__.py" = ["D104"] +# Skip specific rules for generated protobuf files +"**/*_pb2.py" = ["D"] +"src/pipecat/services/__init__.py" = ["D"] [tool.ruff.lint.pydocstyle] convention = "google" diff --git a/scripts/pre-commit.sh b/scripts/pre-commit.sh index 22f00313d..1341a8575 100755 --- a/scripts/pre-commit.sh +++ b/scripts/pre-commit.sh @@ -1,3 +1,27 @@ -#!/bin/sh +#!/bin/bash -NO_COLOR=1 ruff format --diff +# Color codes for output +RED='\033[0;31m' +GREEN='\033[0;32m' +NC='\033[0m' # No Color + +echo "🔍 Running pre-commit checks..." 
+ +# Change to project root (one level up from scripts/) +cd "$(dirname "$0")/.." + +# Format check +echo "📝 Checking code formatting..." +if ! NO_COLOR=1 ruff format --diff --check; then + echo -e "${RED}❌ Code formatting issues found. Run 'ruff format' to fix.${NC}" + exit 1 +fi + +# Lint check +echo "🔍 Running linter..." +if ! ruff check; then + echo -e "${RED}❌ Linting issues found.${NC}" + exit 1 +fi + +echo -e "${GREEN}✅ All pre-commit checks passed!${NC}" \ No newline at end of file diff --git a/src/pipecat/adapters/base_llm_adapter.py b/src/pipecat/adapters/base_llm_adapter.py index c26722604..6a957c267 100644 --- a/src/pipecat/adapters/base_llm_adapter.py +++ b/src/pipecat/adapters/base_llm_adapter.py @@ -1,3 +1,15 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""Base adapter for LLM provider integration. + +This module provides the abstract base class for implementing LLM provider-specific +adapters that handle tool format conversion and standardization. +""" + from abc import ABC, abstractmethod from typing import Any, List, Union, cast @@ -7,12 +19,35 @@ from pipecat.adapters.schemas.tools_schema import ToolsSchema class BaseLLMAdapter(ABC): + """Abstract base class for LLM provider adapters. + + Provides a standard interface for converting between Pipecat's standardized + tool schemas and provider-specific tool formats. Subclasses must implement + provider-specific conversion logic. + """ + @abstractmethod def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Any]: - """Converts tools to the provider's format.""" + """Convert tools schema to the provider's specific format. + + Args: + tools_schema: The standardized tools schema to convert. + + Returns: + List of tools in the provider's expected format. + """ pass def from_standard_tools(self, tools: Any) -> List[Any]: + """Convert tools from standard format to provider format. 
+ + Args: + tools: Tools in standard format or provider-specific format. + + Returns: + List of tools converted to provider format, or original tools + if not in standard format. + """ if isinstance(tools, ToolsSchema): logger.debug(f"Retrieving the tools using the adapter: {type(self)}") return self.to_provider_tools_format(tools) diff --git a/src/pipecat/adapters/schemas/function_schema.py b/src/pipecat/adapters/schemas/function_schema.py index 55a070cf9..2b8753e58 100644 --- a/src/pipecat/adapters/schemas/function_schema.py +++ b/src/pipecat/adapters/schemas/function_schema.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Function schema utilities for AI tool definitions. + +This module provides standardized function schema representation for defining +tools and functions used with AI models, ensuring consistent formatting +across different AI service providers. +""" + from typing import Any, Dict, List @@ -13,17 +20,19 @@ class FunctionSchema: Provides a structured way to define function tools used with AI models like OpenAI. This schema defines the function's name, description, parameter properties, and required parameters, following specifications required by AI service providers. - - Args: - name: Name of the function to be called. - description: Description of what the function does. - properties: Dictionary defining parameter types, descriptions, and constraints. - required: List of property names that are required parameters. """ def __init__( self, name: str, description: str, properties: Dict[str, Any], required: List[str] ) -> None: + """Initialize the function schema. + + Args: + name: Name of the function to be called. + description: Description of what the function does. + properties: Dictionary defining parameter types, descriptions, and constraints. + required: List of property names that are required parameters. 
+ """ self._name = name self._description = description self._properties = properties diff --git a/src/pipecat/adapters/schemas/tools_schema.py b/src/pipecat/adapters/schemas/tools_schema.py index 2489e0b4d..7ef582f7e 100644 --- a/src/pipecat/adapters/schemas/tools_schema.py +++ b/src/pipecat/adapters/schemas/tools_schema.py @@ -4,6 +4,12 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Tools schema definitions for function calling adapters. + +This module provides schemas for managing both standardized function tools +and custom adapter-specific tools in the Pipecat framework. +""" + from enum import Enum from typing import Any, Dict, List, Optional @@ -11,33 +17,61 @@ from pipecat.adapters.schemas.function_schema import FunctionSchema class AdapterType(Enum): + """Supported adapter types for custom tools. + + Parameters: + GEMINI: Google Gemini adapter - currently the only service supporting custom tools. + """ + GEMINI = "gemini" # that is the only service where we are able to add custom tools for now class ToolsSchema: + """Schema for managing both standard and custom function calling tools. + + This class provides a unified interface for handling standardized function + schemas alongside custom tools that may not follow the standard format, + such as adapter-specific search tools. + """ + def __init__( self, standard_tools: List[FunctionSchema], custom_tools: Optional[Dict[AdapterType, List[Dict[str, Any]]]] = None, ) -> None: - """ - A schema for tools that includes both standardized function schemas - and custom tools that do not follow the FunctionSchema format. + """Initialize the tools schema. - :param standard_tools: List of tools following FunctionSchema. - :param custom_tools: List of tools in a custom format (e.g., search_tool). + Args: + standard_tools: List of tools following the standardized FunctionSchema format. + custom_tools: Dictionary mapping adapter types to their custom tool definitions. 
+ These tools may not follow the FunctionSchema format (e.g., search_tool). """ self._standard_tools = standard_tools self._custom_tools = custom_tools @property def standard_tools(self) -> List[FunctionSchema]: + """Get the list of standard function schema tools. + + Returns: + List of tools following the FunctionSchema format. + """ return self._standard_tools @property def custom_tools(self) -> Dict[AdapterType, List[Dict[str, Any]]]: + """Get the custom tools dictionary. + + Returns: + Dictionary mapping adapter types to their custom tool definitions. + """ return self._custom_tools @custom_tools.setter def custom_tools(self, value: Dict[AdapterType, List[Dict[str, Any]]]) -> None: + """Set the custom tools dictionary. + + Args: + value: Dictionary mapping adapter types to their custom tool definitions. + """ self._custom_tools = value diff --git a/src/pipecat/adapters/services/anthropic_adapter.py b/src/pipecat/adapters/services/anthropic_adapter.py index 23197d3a8..fb5abe108 100644 --- a/src/pipecat/adapters/services/anthropic_adapter.py +++ b/src/pipecat/adapters/services/anthropic_adapter.py @@ -4,6 +4,8 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Anthropic LLM adapter for Pipecat.""" + from typing import Any, Dict, List from pipecat.adapters.base_llm_adapter import BaseLLMAdapter @@ -12,8 +14,22 @@ from pipecat.adapters.schemas.tools_schema import ToolsSchema class AnthropicLLMAdapter(BaseLLMAdapter): + """Adapter for converting tool schemas to Anthropic's function-calling format. + + This adapter handles the conversion of Pipecat's standard function schemas + to the specific format required by Anthropic's Claude models for function calling. + """ + @staticmethod def _to_anthropic_function_format(function: FunctionSchema) -> Dict[str, Any]: + """Convert a single function schema to Anthropic's format. + + Args: + function: The function schema to convert. + + Returns: + Dictionary containing the function definition in Anthropic's format. 
+ """ return { "name": function.name, "description": function.description, @@ -25,10 +41,13 @@ class AnthropicLLMAdapter(BaseLLMAdapter): } def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Dict[str, Any]]: - """Converts function schemas to Anthropic's function-calling format. + """Convert function schemas to Anthropic's function-calling format. - :return: Anthropic formatted function call definition. + Args: + tools_schema: The tools schema containing functions to convert. + + Returns: + List of function definitions formatted for Anthropic's API. """ - functions_schema = tools_schema.standard_tools return [self._to_anthropic_function_format(func) for func in functions_schema] diff --git a/src/pipecat/adapters/services/aws_nova_sonic_adapter.py b/src/pipecat/adapters/services/aws_nova_sonic_adapter.py index dc7eef92d..2875f8272 100644 --- a/src/pipecat/adapters/services/aws_nova_sonic_adapter.py +++ b/src/pipecat/adapters/services/aws_nova_sonic_adapter.py @@ -3,6 +3,9 @@ # # SPDX-License-Identifier: BSD 2-Clause License # + +"""AWS Nova Sonic LLM adapter for Pipecat.""" + import json from typing import Any, Dict, List @@ -12,8 +15,22 @@ from pipecat.adapters.schemas.tools_schema import ToolsSchema class AWSNovaSonicLLMAdapter(BaseLLMAdapter): + """Adapter for AWS Nova Sonic language models. + + Converts Pipecat's standard function schemas into AWS Nova Sonic's + specific function-calling format, enabling tool use with Nova Sonic models. + """ + @staticmethod def _to_aws_nova_sonic_function_format(function: FunctionSchema) -> Dict[str, Any]: + """Convert a function schema to AWS Nova Sonic format. + + Args: + function: The function schema to convert. + + Returns: + Dictionary in AWS Nova Sonic function format with toolSpec structure. 
+ """ return { "toolSpec": { "name": function.name, @@ -31,10 +48,13 @@ class AWSNovaSonicLLMAdapter(BaseLLMAdapter): } def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Dict[str, Any]]: - """Converts function schemas to AWS Nova Sonic function-calling format. + """Convert tools schema to AWS Nova Sonic function-calling format. - :return: AWS Nova Sonic formatted function call definition. + Args: + tools_schema: The tools schema containing function definitions to convert. + + Returns: + List of dictionaries in AWS Nova Sonic function format. """ - functions_schema = tools_schema.standard_tools return [self._to_aws_nova_sonic_function_format(func) for func in functions_schema] diff --git a/src/pipecat/adapters/services/bedrock_adapter.py b/src/pipecat/adapters/services/bedrock_adapter.py index 113a6938d..364ad87d2 100644 --- a/src/pipecat/adapters/services/bedrock_adapter.py +++ b/src/pipecat/adapters/services/bedrock_adapter.py @@ -4,6 +4,8 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""AWS Bedrock LLM adapter for Pipecat.""" + from typing import Any, Dict, List from pipecat.adapters.base_llm_adapter import BaseLLMAdapter @@ -12,8 +14,22 @@ from pipecat.adapters.schemas.tools_schema import ToolsSchema class AWSBedrockLLMAdapter(BaseLLMAdapter): + """Adapter for AWS Bedrock LLM integration with Pipecat. + + Provides conversion utilities for transforming Pipecat function schemas + into AWS Bedrock's expected tool format for function calling capabilities. + """ + @staticmethod def _to_bedrock_function_format(function: FunctionSchema) -> Dict[str, Any]: + """Convert a function schema to Bedrock's tool format. + + Args: + function: The function schema to convert. + + Returns: + Dictionary formatted for Bedrock's tool specification. 
+ """ return { "toolSpec": { "name": function.name, @@ -29,10 +45,13 @@ class AWSBedrockLLMAdapter(BaseLLMAdapter): } def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Dict[str, Any]]: - """Converts function schemas to Bedrock's function-calling format. + """Convert function schemas to Bedrock's function-calling format. - :return: Bedrock formatted function call definition. + Args: + tools_schema: The tools schema containing functions to convert. + + Returns: + List of Bedrock formatted function call definitions. """ - functions_schema = tools_schema.standard_tools return [self._to_bedrock_function_format(func) for func in functions_schema] diff --git a/src/pipecat/adapters/services/gemini_adapter.py b/src/pipecat/adapters/services/gemini_adapter.py index 8efca5189..2139e0057 100644 --- a/src/pipecat/adapters/services/gemini_adapter.py +++ b/src/pipecat/adapters/services/gemini_adapter.py @@ -4,6 +4,8 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Gemini LLM adapter for Pipecat.""" + from typing import Any, Dict, List, Union from pipecat.adapters.base_llm_adapter import BaseLLMAdapter @@ -11,12 +13,23 @@ from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema class GeminiLLMAdapter(BaseLLMAdapter): + """LLM adapter for Google's Gemini service. + + Provides tool schema conversion functionality to transform standard tool + definitions into Gemini's specific function-calling format for use with + Gemini LLM models. + """ + def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Dict[str, Any]]: - """Converts function schemas to Gemini's function-calling format. + """Convert tool schemas to Gemini's function-calling format. - :return: Gemini formatted function call definition. + Args: + tools_schema: The tools schema containing standard and custom tool definitions. + + Returns: + List of tool definitions formatted for Gemini's function-calling API. 
+ Includes both converted standard tools and any custom Gemini-specific tools. """ - functions_schema = tools_schema.standard_tools formatted_standard_tools = [ {"function_declarations": [func.to_default_dict() for func in functions_schema]} diff --git a/src/pipecat/adapters/services/open_ai_adapter.py b/src/pipecat/adapters/services/open_ai_adapter.py index 909e5103a..59d70aa1e 100644 --- a/src/pipecat/adapters/services/open_ai_adapter.py +++ b/src/pipecat/adapters/services/open_ai_adapter.py @@ -3,6 +3,9 @@ # # SPDX-License-Identifier: BSD 2-Clause License # + +"""OpenAI LLM adapter for Pipecat.""" + from typing import List from openai.types.chat import ChatCompletionToolParam @@ -12,10 +15,22 @@ from pipecat.adapters.schemas.tools_schema import ToolsSchema class OpenAILLMAdapter(BaseLLMAdapter): - def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[ChatCompletionToolParam]: - """Converts function schemas to OpenAI's function-calling format. + """Adapter for converting tool schemas to OpenAI's format. - :return: OpenAI formatted function call definition. + Provides conversion utilities for transforming Pipecat's standard tool + schemas into the format expected by OpenAI's ChatCompletion API for + function calling capabilities. + """ + + def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[ChatCompletionToolParam]: + """Convert function schemas to OpenAI's function-calling format. + + Args: + tools_schema: The Pipecat tools schema to convert. + + Returns: + List of OpenAI formatted function call definitions ready for use + with ChatCompletion API. 
""" functions_schema = tools_schema.standard_tools return [ diff --git a/src/pipecat/adapters/services/open_ai_realtime_adapter.py b/src/pipecat/adapters/services/open_ai_realtime_adapter.py index b7eafaa81..58aea5a9a 100644 --- a/src/pipecat/adapters/services/open_ai_realtime_adapter.py +++ b/src/pipecat/adapters/services/open_ai_realtime_adapter.py @@ -3,6 +3,9 @@ # # SPDX-License-Identifier: BSD 2-Clause License # + +"""OpenAI Realtime LLM adapter for Pipecat.""" + from typing import Any, Dict, List, Union from pipecat.adapters.base_llm_adapter import BaseLLMAdapter @@ -11,8 +14,22 @@ from pipecat.adapters.schemas.tools_schema import ToolsSchema class OpenAIRealtimeLLMAdapter(BaseLLMAdapter): + """LLM adapter for OpenAI Realtime API function calling. + + Converts Pipecat's tool schemas into the specific format required by + OpenAI's Realtime API for function calling capabilities. + """ + @staticmethod def _to_openai_realtime_function_format(function: FunctionSchema) -> Dict[str, Any]: + """Convert a function schema to OpenAI Realtime format. + + Args: + function: The function schema to convert. + + Returns: + Dictionary in OpenAI Realtime function format. + """ return { "type": "function", "name": function.name, @@ -25,10 +42,13 @@ class OpenAIRealtimeLLMAdapter(BaseLLMAdapter): } def to_provider_tools_format(self, tools_schema: ToolsSchema) -> List[Dict[str, Any]]: - """Converts function schemas to Openai Realtime function-calling format. + """Convert tool schemas to OpenAI Realtime function-calling format. - :return: Openai Realtime formatted function call definition. + Args: + tools_schema: The tools schema containing functions to convert. + + Returns: + List of function definitions in OpenAI Realtime format. 
""" - functions_schema = tools_schema.standard_tools return [self._to_openai_realtime_function_format(func) for func in functions_schema] diff --git a/src/pipecat/audio/filters/base_audio_filter.py b/src/pipecat/audio/filters/base_audio_filter.py index c956ffe16..1724fd859 100644 --- a/src/pipecat/audio/filters/base_audio_filter.py +++ b/src/pipecat/audio/filters/base_audio_filter.py @@ -4,44 +4,68 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Base audio filter interface for input transport audio processing. + +This module provides the abstract base class for implementing audio filters +that process audio data before VAD and downstream processing in input transports. +""" + from abc import ABC, abstractmethod from pipecat.frames.frames import FilterControlFrame class BaseAudioFilter(ABC): - """This is a base class for input transport audio filters. If an audio + """Base class for input transport audio filters. + + This is a base class for input transport audio filters. If an audio filter is provided to the input transport it will be used to process audio before VAD and before pushing it downstream. There are control frames to update filter settings or to enable or disable the filter at runtime. - """ @abstractmethod async def start(self, sample_rate: int): - """This will be called from the input transport when the transport is + """Initialize the filter when the input transport starts. + + This will be called from the input transport when the transport is started. It can be used to initialize the filter. The input transport sample rate is provided so the filter can adjust to that sample rate. + Args: + sample_rate: The sample rate of the input transport in Hz. """ pass @abstractmethod async def stop(self): - """This will be called from the input transport when the transport is - stopping. + """Clean up the filter when the input transport stops. + This will be called from the input transport when the transport is + stopping. 
""" pass @abstractmethod async def process_frame(self, frame: FilterControlFrame): - """This will be called when the input transport receives a + """Process control frames for runtime filter configuration. + + This will be called when the input transport receives a FilterControlFrame. + Args: + frame: The control frame containing filter commands or settings. """ pass @abstractmethod async def filter(self, audio: bytes) -> bytes: + """Apply the audio filter to the provided audio data. + + Args: + audio: Raw audio data as bytes to be filtered. + + Returns: + Filtered audio data as bytes. + """ pass diff --git a/src/pipecat/audio/filters/koala_filter.py b/src/pipecat/audio/filters/koala_filter.py index c1e539e37..64314428a 100644 --- a/src/pipecat/audio/filters/koala_filter.py +++ b/src/pipecat/audio/filters/koala_filter.py @@ -4,6 +4,12 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Koala noise suppression audio filter for Pipecat. + +This module provides an audio filter implementation using PicoVoice's Koala +Noise Suppression engine to reduce background noise in audio streams. +""" + from typing import Sequence import numpy as np @@ -21,12 +27,19 @@ except ModuleNotFoundError as e: class KoalaFilter(BaseAudioFilter): - """This is an audio filter that uses Koala Noise Suppression (from - PicoVoice). + """Audio filter using Koala Noise Suppression from PicoVoice. + Provides real-time noise suppression for audio streams using PicoVoice's + Koala engine. The filter buffers audio data to match Koala's required + frame length and processes it in chunks. """ def __init__(self, *, access_key: str) -> None: + """Initialize the Koala noise suppression filter. + + Args: + access_key: PicoVoice access key for Koala engine authentication. 
+ """ self._access_key = access_key self._filtering = True @@ -36,6 +49,11 @@ class KoalaFilter(BaseAudioFilter): self._audio_buffer = bytearray() async def start(self, sample_rate: int): + """Initialize the filter with the transport's sample rate. + + Args: + sample_rate: The sample rate of the input transport in Hz. + """ self._sample_rate = sample_rate if self._sample_rate != self._koala.sample_rate: logger.warning( @@ -44,13 +62,30 @@ class KoalaFilter(BaseAudioFilter): self._koala_ready = False async def stop(self): + """Clean up the Koala engine when stopping.""" self._koala.reset() async def process_frame(self, frame: FilterControlFrame): + """Process control frames to enable/disable filtering. + + Args: + frame: The control frame containing filter commands. + """ if isinstance(frame, FilterEnableFrame): self._filtering = frame.enable async def filter(self, audio: bytes) -> bytes: + """Apply Koala noise suppression to audio data. + + Buffers incoming audio and processes it in chunks that match Koala's + required frame length. Returns filtered audio data. + + Args: + audio: Raw audio data as bytes to be filtered. + + Returns: + Noise-suppressed audio data as bytes. + """ if not self._koala_ready or not self._filtering: return audio diff --git a/src/pipecat/audio/filters/krisp_filter.py b/src/pipecat/audio/filters/krisp_filter.py index b23a46b65..267d2f2ea 100644 --- a/src/pipecat/audio/filters/krisp_filter.py +++ b/src/pipecat/audio/filters/krisp_filter.py @@ -4,6 +4,12 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Krisp noise reduction audio filter for Pipecat. + +This module provides an audio filter implementation using Krisp's noise +reduction technology to suppress background noise in audio streams. +""" + import os import numpy as np @@ -21,14 +27,27 @@ except ModuleNotFoundError as e: class KrispProcessorManager: - """ - Ensures that only one KrispAudioProcessor instance exists for the entire program. 
+ """Singleton manager for KrispAudioProcessor instances. + + Ensures that only one KrispAudioProcessor instance exists for the entire + program. """ _krisp_instance = None @classmethod def get_processor(cls, sample_rate: int, sample_type: str, channels: int, model_path: str): + """Get or create a KrispAudioProcessor instance. + + Args: + sample_rate: Audio sample rate in Hz. + sample_type: Audio sample type (e.g., "PCM_16"). + channels: Number of audio channels. + model_path: Path to the Krisp model file. + + Returns: + Shared KrispAudioProcessor instance. + """ if cls._krisp_instance is None: cls._krisp_instance = KrispAudioProcessor( sample_rate, sample_type, channels, model_path @@ -37,14 +56,26 @@ class KrispProcessorManager: class KrispFilter(BaseAudioFilter): + """Audio filter using Krisp noise reduction technology. + + Provides real-time noise reduction for audio streams using Krisp's + proprietary noise suppression algorithms. Requires a Krisp model file + for operation. + """ + def __init__( self, sample_type: str = "PCM_16", channels: int = 1, model_path: str = None ) -> None: - """Initializes the KrispAudioProcessor with customizable audio processing settings. + """Initialize the Krisp noise reduction filter. - :param sample_type: The type of audio sample, default is 'PCM_16'. - :param channels: Number of audio channels, default is 1. - :param model_path: Path to the Krisp model; defaults to environment variable KRISP_MODEL_PATH if not provided. + Args: + sample_type: The audio sample format. Defaults to "PCM_16". + channels: Number of audio channels. Defaults to 1. + model_path: Path to the Krisp model file. If None, uses KRISP_MODEL_PATH + environment variable. + + Raises: + ValueError: If model_path is not provided and KRISP_MODEL_PATH is not set. 
""" super().__init__() @@ -63,19 +94,41 @@ class KrispFilter(BaseAudioFilter): self._krisp_processor = None async def start(self, sample_rate: int): + """Initialize the Krisp processor with the transport's sample rate. + + Args: + sample_rate: The sample rate of the input transport in Hz. + """ self._sample_rate = sample_rate self._krisp_processor = KrispProcessorManager.get_processor( self._sample_rate, self._sample_type, self._channels, self._model_path ) async def stop(self): + """Clean up the Krisp processor when stopping.""" self._krisp_processor = None async def process_frame(self, frame: FilterControlFrame): + """Process control frames to enable/disable filtering. + + Args: + frame: The control frame containing filter commands. + """ if isinstance(frame, FilterEnableFrame): self._filtering = frame.enable async def filter(self, audio: bytes) -> bytes: + """Apply Krisp noise reduction to audio data. + + Converts audio to float32, applies Krisp noise reduction processing, + and returns the filtered audio clipped to int16 range. + + Args: + audio: Raw audio data as bytes to be filtered. + + Returns: + Noise-reduced audio data as bytes. + """ if not self._filtering: return audio diff --git a/src/pipecat/audio/filters/noisereduce_filter.py b/src/pipecat/audio/filters/noisereduce_filter.py index 7a4b18395..550153b56 100644 --- a/src/pipecat/audio/filters/noisereduce_filter.py +++ b/src/pipecat/audio/filters/noisereduce_filter.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Noisereduce audio filter for Pipecat. + +This module provides an audio filter implementation using the noisereduce +library to reduce background noise in audio streams through spectral +gating algorithms. +""" + import numpy as np from loguru import logger @@ -21,21 +28,51 @@ except ModuleNotFoundError as e: class NoisereduceFilter(BaseAudioFilter): + """Audio filter using the noisereduce library for noise suppression. 
+ + Applies spectral gating noise reduction algorithms to suppress background + noise in audio streams. Uses the noisereduce library's default noise + reduction parameters. + """ + def __init__(self) -> None: + """Initialize the noisereduce filter.""" self._filtering = True self._sample_rate = 0 async def start(self, sample_rate: int): + """Initialize the filter with the transport's sample rate. + + Args: + sample_rate: The sample rate of the input transport in Hz. + """ self._sample_rate = sample_rate async def stop(self): + """Clean up the filter when stopping.""" pass async def process_frame(self, frame: FilterControlFrame): + """Process control frames to enable/disable filtering. + + Args: + frame: The control frame containing filter commands. + """ if isinstance(frame, FilterEnableFrame): self._filtering = frame.enable async def filter(self, audio: bytes) -> bytes: + """Apply noise reduction to audio data using spectral gating. + + Converts audio to float32, applies noisereduce processing, and returns + the filtered audio clipped to int16 range. + + Args: + audio: Raw audio data as bytes to be filtered. + + Returns: + Noise-reduced audio data as bytes. + """ if not self._filtering: return audio diff --git a/src/pipecat/audio/interruptions/base_interruption_strategy.py b/src/pipecat/audio/interruptions/base_interruption_strategy.py index 7811e8418..83b2ff280 100644 --- a/src/pipecat/audio/interruptions/base_interruption_strategy.py +++ b/src/pipecat/audio/interruptions/base_interruption_strategy.py @@ -4,31 +4,51 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Base interruption strategy for determining when users can interrupt bot speech.""" + from abc import ABC, abstractmethod class BaseInterruptionStrategy(ABC): - """This is a base class for interruption strategies. Interruption strategies + """Base class for interruption strategies. + + This is a base class for interruption strategies. 
Interruption strategies decide when the user can interrupt the bot while the bot is speaking. For example, there could be strategies based on audio volume or strategies based on the number of words the user spoke. - """ async def append_audio(self, audio: bytes, sample_rate: int): - """Appends audio to the strategy. Not all strategies handle audio.""" + """Append audio data to the strategy for analysis. + + Not all strategies handle audio. Default implementation does nothing. + + Args: + audio: Raw audio bytes to append. + sample_rate: Sample rate of the audio data in Hz. + """ pass async def append_text(self, text: str): - """Appends text to the strategy. Not all strategies handle text.""" + """Append text data to the strategy for analysis. + + Not all strategies handle text. Default implementation does nothing. + + Args: + text: Text string to append for analysis. + """ pass @abstractmethod async def should_interrupt(self) -> bool: - """This is called when the user stops speaking and it's time to decide + """Determine if the user should interrupt the bot. + + This is called when the user stops speaking and it's time to decide whether the user should interrupt the bot. The decision will be based on the aggregated audio and/or text. + Returns: + True if the user should interrupt the bot, False otherwise. 
""" pass diff --git a/src/pipecat/audio/interruptions/min_words_interruption_strategy.py b/src/pipecat/audio/interruptions/min_words_interruption_strategy.py index f9f7595ab..3f2dd5825 100644 --- a/src/pipecat/audio/interruptions/min_words_interruption_strategy.py +++ b/src/pipecat/audio/interruptions/min_words_interruption_strategy.py @@ -4,31 +4,47 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Minimum words interruption strategy for word count-based interruptions.""" + from loguru import logger from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy class MinWordsInterruptionStrategy(BaseInterruptionStrategy): - """This is an interruption strategy based on a minimum number of words said + """Interruption strategy based on minimum number of words spoken. + + This is an interruption strategy based on a minimum number of words said by the user. That is, the strategy will be true if the user has said at least that amount of words. - """ def __init__(self, *, min_words: int): + """Initialize the minimum words interruption strategy. + + Args: + min_words: Minimum number of words required to trigger an interruption. + """ super().__init__() self._min_words = min_words self._text = "" async def append_text(self, text: str): - """Appends text for later analysis. Not all strategies need to handle - text. + """Append text for word count analysis. + Args: + text: Text string to append to the accumulated text. + + Note: Not all strategies need to handle text. """ self._text += text async def should_interrupt(self) -> bool: + """Check if the minimum word count has been reached. + + Returns: + True if the user has spoken at least the minimum number of words. 
+ """ word_count = len(self._text.split()) interrupt = word_count >= self._min_words logger.debug( @@ -37,4 +53,5 @@ class MinWordsInterruptionStrategy(BaseInterruptionStrategy): return interrupt async def reset(self): + """Reset the accumulated text for the next analysis cycle.""" self._text = "" diff --git a/src/pipecat/audio/mixers/base_audio_mixer.py b/src/pipecat/audio/mixers/base_audio_mixer.py index 9b7b12163..4ba5938d2 100644 --- a/src/pipecat/audio/mixers/base_audio_mixer.py +++ b/src/pipecat/audio/mixers/base_audio_mixer.py @@ -4,50 +4,73 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Base audio mixer for output transport integration. + +Provides the abstract base class for audio mixers that can be integrated with +output transports to mix incoming audio with generated audio from the mixer. +""" + from abc import ABC, abstractmethod from pipecat.frames.frames import MixerControlFrame class BaseAudioMixer(ABC): - """This is a base class for output transport audio mixers. If an audio mixer + """Base class for output transport audio mixers. + + This is a base class for output transport audio mixers. If an audio mixer is provided to the output transport it will be used to mix the audio frames coming into to the transport with the audio generated from the mixer. There are control frames to update mixer settings or to enable or disable the mixer at runtime. - """ @abstractmethod async def start(self, sample_rate: int): - """This will be called from the output transport when the transport is + """Initialize the mixer when the output transport starts. + + This will be called from the output transport when the transport is started. It can be used to initialize the mixer. The output transport sample rate is provided so the mixer can adjust to that sample rate. + Args: + sample_rate: The sample rate of the output transport in Hz. 
""" pass @abstractmethod async def stop(self): - """This will be called from the output transport when the transport is - stopping. + """Clean up the mixer when the output transport stops. + This will be called from the output transport when the transport is + stopping. """ pass @abstractmethod async def process_frame(self, frame: MixerControlFrame): - """This will be called when the output transport receives a + """Process mixer control frames from the transport. + + This will be called when the output transport receives a MixerControlFrame. + Args: + frame: The mixer control frame to process. """ pass @abstractmethod async def mix(self, audio: bytes) -> bytes: - """This is called with the audio that is about to be sent from the + """Mix transport audio with mixer-generated audio. + + This is called with the audio that is about to be sent from the output transport and that should be mixed with the mixer audio if the mixer is enabled. + Args: + audio: Raw audio bytes from the transport to mix. + + Returns: + Mixed audio bytes combining transport and mixer audio. """ pass diff --git a/src/pipecat/audio/mixers/soundfile_mixer.py b/src/pipecat/audio/mixers/soundfile_mixer.py index 1628c4d8a..c3664012c 100644 --- a/src/pipecat/audio/mixers/soundfile_mixer.py +++ b/src/pipecat/audio/mixers/soundfile_mixer.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Soundfile-based audio mixer for file playback integration. + +Provides an audio mixer that combines incoming audio with audio loaded from +files using the soundfile library. Supports multiple audio formats and +runtime configuration changes. +""" + import asyncio from typing import Any, Dict, Mapping @@ -24,7 +31,9 @@ except ModuleNotFoundError as e: class SoundfileMixer(BaseAudioMixer): - """This is an audio mixer that mixes incoming audio with audio from a + """Audio mixer that combines incoming audio with file-based audio. 
+ + This is an audio mixer that mixes incoming audio with audio from a file. It uses the soundfile library to load files so it supports multiple formats. The audio files need to only have one channel (mono) and it needs to match the sample rate of the output transport. @@ -33,7 +42,6 @@ class SoundfileMixer(BaseAudioMixer): `MixerUpdateSettingsFrame` has the following settings available: `sound` (str) and `volume` (float) to be able to update to a different sound file or to change the volume at runtime. - """ def __init__( @@ -46,6 +54,16 @@ class SoundfileMixer(BaseAudioMixer): loop: bool = True, **kwargs, ): + """Initialize the soundfile mixer. + + Args: + sound_files: Mapping of sound names to file paths for loading. + default_sound: Name of the default sound to play initially. + volume: Mixing volume level (0.0 to 1.0). Defaults to 0.4. + mixing: Whether mixing is initially enabled. Defaults to True. + loop: Whether to loop audio files when they end. Defaults to True. + **kwargs: Additional arguments passed to parent class. + """ super().__init__(**kwargs) self._sound_files = sound_files self._volume = volume @@ -58,14 +76,28 @@ class SoundfileMixer(BaseAudioMixer): self._loop = loop async def start(self, sample_rate: int): + """Initialize the mixer and load all sound files. + + Args: + sample_rate: The sample rate of the output transport in Hz. + """ self._sample_rate = sample_rate for sound_name, file_name in self._sound_files.items(): await asyncio.to_thread(self._load_sound_file, sound_name, file_name) async def stop(self): + """Clean up mixer resources. + + Currently performs no cleanup as sound data is managed by garbage collection. + """ pass async def process_frame(self, frame: MixerControlFrame): + """Process mixer control frames to update settings or enable/disable mixing. + + Args: + frame: The mixer control frame to process. 
+ """ if isinstance(frame, MixerUpdateSettingsFrame): await self._update_settings(frame) elif isinstance(frame, MixerEnableFrame): @@ -73,12 +105,22 @@ class SoundfileMixer(BaseAudioMixer): pass async def mix(self, audio: bytes) -> bytes: + """Mix transport audio with the current sound file. + + Args: + audio: Raw audio bytes from the transport to mix. + + Returns: + Mixed audio bytes combining transport and file audio. + """ return self._mix_with_sound(audio) async def _enable_mixing(self, enable: bool): + """Enable or disable audio mixing.""" self._mixing = enable async def _update_settings(self, frame: MixerUpdateSettingsFrame): + """Update mixer settings from a control frame.""" for setting, value in frame.settings.items(): match setting: case "sound": @@ -89,6 +131,11 @@ class SoundfileMixer(BaseAudioMixer): await self._update_loop(value) async def _change_sound(self, sound: str): + """Change the currently playing sound file. + + Args: + sound: Name of the sound file to switch to. + """ if sound in self._sound_files: self._current_sound = sound self._sound_pos = 0 @@ -96,12 +143,15 @@ class SoundfileMixer(BaseAudioMixer): logger.error(f"Sound {sound} is not available") async def _update_volume(self, volume: float): + """Update the mixing volume level.""" self._volume = volume async def _update_loop(self, loop: bool): + """Update the looping behavior.""" self._loop = loop def _load_sound_file(self, sound_name: str, file_name: str): + """Load an audio file into memory for mixing.""" try: logger.debug(f"Loading mixer sound from {file_name}") sound, sample_rate = sf.read(file_name, dtype="int16") @@ -118,10 +168,7 @@ class SoundfileMixer(BaseAudioMixer): logger.error(f"Unable to open file {file_name}: {e}") def _mix_with_sound(self, audio: bytes): - """Mixes raw audio frames with chunks of the same length from the sound - file. 
- - """ + """Mix raw audio frames with chunks of the same length from the sound file.""" if not self._mixing or not self._current_sound in self._sounds: return audio diff --git a/src/pipecat/audio/resamplers/base_audio_resampler.py b/src/pipecat/audio/resamplers/base_audio_resampler.py index 6afbbcbfe..42854cd2c 100644 --- a/src/pipecat/audio/resamplers/base_audio_resampler.py +++ b/src/pipecat/audio/resamplers/base_audio_resampler.py @@ -4,27 +4,35 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Base audio resampler interface for Pipecat. + +This module defines the abstract base class for audio resampling implementations, +providing a common interface for converting audio between different sample rates. +""" + from abc import ABC, abstractmethod class BaseAudioResampler(ABC): - """Abstract base class for audio resampling. This class defines an - interface for audio resampling implementations. + """Abstract base class for audio resampling implementations. + + This class defines the interface that all audio resampling implementations + must follow, providing a standardized way to convert audio data between + different sample rates. """ @abstractmethod async def resample(self, audio: bytes, in_rate: int, out_rate: int) -> bytes: - """ - Resamples the given audio data to a different sample rate. + """Resamples the given audio data to a different sample rate. This is an abstract method that must be implemented in subclasses. - Parameters: - audio (bytes): The audio data to be resampled, represented as a byte string. - in_rate (int): The original sample rate of the audio data (in Hz). - out_rate (int): The desired sample rate for the resampled audio data (in Hz). + Args: + audio: The audio data to be resampled, as raw bytes. + in_rate: The original sample rate of the audio data in Hz. + out_rate: The desired sample rate for the output audio in Hz. Returns: - bytes: The resampled audio data as a byte string. + The resampled audio data as raw bytes. 
""" pass diff --git a/src/pipecat/audio/resamplers/resampy_resampler.py b/src/pipecat/audio/resamplers/resampy_resampler.py index 8c053fc3b..b427bd3e8 100644 --- a/src/pipecat/audio/resamplers/resampy_resampler.py +++ b/src/pipecat/audio/resamplers/resampy_resampler.py @@ -4,6 +4,12 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Resampy-based audio resampler implementation. + +This module provides an audio resampler that uses the resampy library +for high-quality audio sample rate conversion. +""" + import numpy as np import resampy @@ -11,12 +17,31 @@ from pipecat.audio.resamplers.base_audio_resampler import BaseAudioResampler class ResampyResampler(BaseAudioResampler): - """Audio resampler implementation using the resampy library.""" + """Audio resampler implementation using the resampy library. + + This resampler uses the resampy library's Kaiser windowing filter + for high-quality audio resampling with good performance characteristics. + """ def __init__(self, **kwargs): + """Initialize the resampy resampler. + + Args: + **kwargs: Additional keyword arguments (currently unused). + """ pass async def resample(self, audio: bytes, in_rate: int, out_rate: int) -> bytes: + """Resample audio data using resampy library. + + Args: + audio: Input audio data as raw bytes (16-bit signed integers). + in_rate: Original sample rate in Hz. + out_rate: Target sample rate in Hz. + + Returns: + Resampled audio data as raw bytes (16-bit signed integers). + """ if in_rate == out_rate: return audio audio_data = np.frombuffer(audio, dtype=np.int16) diff --git a/src/pipecat/audio/resamplers/soxr_resampler.py b/src/pipecat/audio/resamplers/soxr_resampler.py index 88edb84eb..9f285069f 100644 --- a/src/pipecat/audio/resamplers/soxr_resampler.py +++ b/src/pipecat/audio/resamplers/soxr_resampler.py @@ -4,6 +4,12 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""SoX-based audio resampler implementation. 
+ +This module provides an audio resampler that uses the SoX resampler library +for very high quality audio sample rate conversion. +""" + import numpy as np import soxr @@ -11,12 +17,32 @@ from pipecat.audio.resamplers.base_audio_resampler import BaseAudioResampler class SOXRAudioResampler(BaseAudioResampler): - """Audio resampler implementation using the SoX resampler library.""" + """Audio resampler implementation using the SoX resampler library. + + This resampler uses the SoX resampler library configured for very high + quality (VHQ) resampling, providing excellent audio quality at the cost + of additional computational overhead. + """ def __init__(self, **kwargs): + """Initialize the SoX audio resampler. + + Args: + **kwargs: Additional keyword arguments (currently unused). + """ pass async def resample(self, audio: bytes, in_rate: int, out_rate: int) -> bytes: + """Resample audio data using SoX resampler library. + + Args: + audio: Input audio data as raw bytes (16-bit signed integers). + in_rate: Original sample rate in Hz. + out_rate: Target sample rate in Hz. + + Returns: + Resampled audio data as raw bytes (16-bit signed integers). + """ if in_rate == out_rate: return audio audio_data = np.frombuffer(audio, dtype=np.int16) diff --git a/src/pipecat/audio/turn/base_turn_analyzer.py b/src/pipecat/audio/turn/base_turn_analyzer.py index 301dbf7bf..642173852 100644 --- a/src/pipecat/audio/turn/base_turn_analyzer.py +++ b/src/pipecat/audio/turn/base_turn_analyzer.py @@ -4,6 +4,12 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Base turn analyzer for determining end-of-turn in audio conversations. + +This module provides the abstract base class and enumeration for analyzing +when a user has finished speaking in a conversation. 
+""" + from abc import ABC, abstractmethod from enum import Enum from typing import Optional, Tuple @@ -12,6 +18,13 @@ from pipecat.metrics.metrics import MetricsData class EndOfTurnState(Enum): + """State enumeration for end-of-turn analysis results. + + Parameters: + COMPLETE: The user has finished their turn and stopped speaking. + INCOMPLETE: The user is still speaking or may continue speaking. + """ + COMPLETE = 1 INCOMPLETE = 2 @@ -24,6 +37,12 @@ class BaseTurnAnalyzer(ABC): """ def __init__(self, *, sample_rate: Optional[int] = None): + """Initialize the turn analyzer. + + Args: + sample_rate: Optional initial sample rate for audio processing. + If provided, this will be used as the fixed sample rate. + """ self._init_sample_rate = sample_rate self._sample_rate = 0 diff --git a/src/pipecat/audio/turn/smart_turn/base_smart_turn.py b/src/pipecat/audio/turn/smart_turn/base_smart_turn.py index 0b577028b..38b5410ec 100644 --- a/src/pipecat/audio/turn/smart_turn/base_smart_turn.py +++ b/src/pipecat/audio/turn/smart_turn/base_smart_turn.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Smart turn analyzer base class using ML models for end-of-turn detection. + +This module provides the base implementation for smart turn analyzers that use +machine learning models to determine when a user has finished speaking, going +beyond simple silence-based detection. +""" + import time from abc import abstractmethod from typing import Any, Dict, Optional, Tuple @@ -23,6 +30,14 @@ USE_ONLY_LAST_VAD_SEGMENT = True class SmartTurnParams(BaseModel): + """Configuration parameters for smart turn analysis. + + Parameters: + stop_secs: Maximum silence duration in seconds before ending turn. + pre_speech_ms: Milliseconds of audio to include before speech starts. + max_duration_secs: Maximum duration in seconds for audio segments. 
+ """ + stop_secs: float = STOP_SECS pre_speech_ms: float = PRE_SPEECH_MS max_duration_secs: float = MAX_DURATION_SECONDS @@ -31,13 +46,28 @@ class SmartTurnParams(BaseModel): class SmartTurnTimeoutException(Exception): + """Exception raised when smart turn analysis times out.""" + pass class BaseSmartTurn(BaseTurnAnalyzer): + """Base class for smart turn analyzers using ML models. + + Provides common functionality for smart turn detection including audio + buffering, speech tracking, and ML model integration. Subclasses must + implement the specific model prediction logic. + """ + def __init__( self, *, sample_rate: Optional[int] = None, params: Optional[SmartTurnParams] = None ): + """Initialize the smart turn analyzer. + + Args: + sample_rate: Optional sample rate for audio processing. + params: Configuration parameters for turn analysis behavior. + """ super().__init__(sample_rate=sample_rate) self._params = params or SmartTurnParams() # Configuration @@ -50,9 +80,23 @@ class BaseSmartTurn(BaseTurnAnalyzer): @property def speech_triggered(self) -> bool: + """Check if speech has been detected and triggered analysis. + + Returns: + True if speech has been detected and turn analysis is active. + """ return self._speech_triggered def append_audio(self, buffer: bytes, is_speech: bool) -> EndOfTurnState: + """Append audio data for turn analysis. + + Args: + buffer: Raw audio data bytes to append for analysis. + is_speech: Whether the audio buffer contains detected speech. + + Returns: + Current end-of-turn state after processing the audio. + """ # Convert raw audio to float32 format and append to the buffer audio_int16 = np.frombuffer(buffer, dtype=np.int16) audio_float32 = np.frombuffer(audio_int16, dtype=np.int16).astype(np.float32) / 32768.0 @@ -92,6 +136,12 @@ class BaseSmartTurn(BaseTurnAnalyzer): return state async def analyze_end_of_turn(self) -> Tuple[EndOfTurnState, Optional[MetricsData]]: + """Analyze the current audio state to determine if turn has ended. 
+ + Returns: + Tuple containing the end-of-turn state and optional metrics data + from the ML model analysis. + """ state, result = await self._process_speech_segment(self._audio_buffer) if state == EndOfTurnState.COMPLETE or USE_ONLY_LAST_VAD_SEGMENT: self._clear(state) @@ -99,9 +149,11 @@ class BaseSmartTurn(BaseTurnAnalyzer): return state, result def clear(self): + """Reset the turn analyzer to its initial state.""" self._clear(EndOfTurnState.COMPLETE) def _clear(self, turn_state: EndOfTurnState): + """Clear internal state based on turn completion status.""" # If the state is still incomplete, keep the _speech_triggered as True self._speech_triggered = turn_state == EndOfTurnState.INCOMPLETE self._audio_buffer = [] @@ -111,6 +163,7 @@ class BaseSmartTurn(BaseTurnAnalyzer): async def _process_speech_segment( self, audio_buffer ) -> Tuple[EndOfTurnState, Optional[MetricsData]]: + """Process accumulated audio segment using ML model.""" state = EndOfTurnState.INCOMPLETE if not audio_buffer: @@ -188,14 +241,5 @@ class BaseSmartTurn(BaseTurnAnalyzer): @abstractmethod async def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]: - """Abstract method to predict if a turn has ended based on audio. - - Args: - audio_array: Float32 numpy array of audio samples at 16kHz. - - Returns: - Dictionary with: - - prediction: 1 if turn is complete, else 0 - - probability: Confidence of the prediction - """ + """Predict end-of-turn using ML model from audio data.""" pass diff --git a/src/pipecat/audio/turn/smart_turn/fal_smart_turn.py b/src/pipecat/audio/turn/smart_turn/fal_smart_turn.py index 9e3a85b56..d627eca72 100644 --- a/src/pipecat/audio/turn/smart_turn/fal_smart_turn.py +++ b/src/pipecat/audio/turn/smart_turn/fal_smart_turn.py @@ -4,6 +4,16 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Fal.ai smart turn analyzer implementation. 
+ +This module provides a smart turn analyzer that uses Fal.ai's hosted smart-turn model +for end-of-turn detection in conversations. + +Note: To learn more about the smart-turn model, visit: + - https://fal.ai/models/fal-ai/smart-turn/playground + - https://github.com/pipecat-ai/smart-turn +""" + from typing import Optional import aiohttp @@ -12,6 +22,12 @@ from pipecat.audio.turn.smart_turn.http_smart_turn import HttpSmartTurnAnalyzer class FalSmartTurnAnalyzer(HttpSmartTurnAnalyzer): + """Smart turn analyzer using Fal.ai's hosted smart-turn model. + + Extends HttpSmartTurnAnalyzer to provide integration with Fal.ai's + smart turn detection API endpoint with proper authentication. + """ + def __init__( self, *, @@ -20,6 +36,14 @@ class FalSmartTurnAnalyzer(HttpSmartTurnAnalyzer): api_key: Optional[str] = None, **kwargs, ): + """Initialize the Fal.ai smart turn analyzer. + + Args: + aiohttp_session: HTTP client session for making API requests. + url: Fal.ai API endpoint URL for smart turn detection. + api_key: API key for authenticating with Fal.ai service. + **kwargs: Additional arguments passed to parent HttpSmartTurnAnalyzer. + """ headers = {} if api_key: headers = {"Authorization": f"Key {api_key}"} diff --git a/src/pipecat/audio/turn/smart_turn/http_smart_turn.py b/src/pipecat/audio/turn/smart_turn/http_smart_turn.py index bf9f086a3..c28727f78 100644 --- a/src/pipecat/audio/turn/smart_turn/http_smart_turn.py +++ b/src/pipecat/audio/turn/smart_turn/http_smart_turn.py @@ -4,6 +4,12 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""HTTP-based smart turn analyzer for remote ML inference. + +This module provides a smart turn analyzer that sends audio data to remote +HTTP endpoints for ML-based end-of-turn detection. 
+""" + import asyncio import io from typing import Any, Dict, Optional @@ -16,6 +22,12 @@ from pipecat.audio.turn.smart_turn.base_smart_turn import BaseSmartTurn, SmartTu class HttpSmartTurnAnalyzer(BaseSmartTurn): + """Smart turn analyzer using HTTP-based ML inference. + + Sends audio data to remote HTTP endpoints for ML-based end-of-turn + prediction. Handles serialization, HTTP communication, and error recovery. + """ + def __init__( self, *, @@ -24,12 +36,21 @@ class HttpSmartTurnAnalyzer(BaseSmartTurn): headers: Optional[Dict[str, str]] = None, **kwargs, ): + """Initialize the HTTP smart turn analyzer. + + Args: + url: HTTP endpoint URL for the smart turn ML service. + aiohttp_session: HTTP client session for making requests. + headers: Optional HTTP headers to include in requests. + **kwargs: Additional arguments passed to BaseSmartTurn. + """ super().__init__(**kwargs) self._url = url self._headers = headers or {} self._aiohttp_session = aiohttp_session def _serialize_array(self, audio_array: np.ndarray) -> bytes: + """Serialize NumPy audio array to bytes for HTTP transmission.""" logger.trace("Serializing NumPy array to bytes...") buffer = io.BytesIO() np.save(buffer, audio_array) @@ -38,6 +59,7 @@ class HttpSmartTurnAnalyzer(BaseSmartTurn): return serialized_bytes async def _send_raw_request(self, data_bytes: bytes) -> Dict[str, Any]: + """Send raw audio data to the HTTP endpoint for prediction.""" headers = {"Content-Type": "application/octet-stream"} headers.update(self._headers) @@ -83,6 +105,7 @@ class HttpSmartTurnAnalyzer(BaseSmartTurn): raise Exception("Failed to send raw request to Daily Smart Turn.") async def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]: + """Predict end-of-turn using remote HTTP ML service.""" try: serialized_array = self._serialize_array(audio_array) return await self._send_raw_request(serialized_array) diff --git a/src/pipecat/audio/turn/smart_turn/local_coreml_smart_turn.py 
b/src/pipecat/audio/turn/smart_turn/local_coreml_smart_turn.py index 88d6530bd..dd35449de 100644 --- a/src/pipecat/audio/turn/smart_turn/local_coreml_smart_turn.py +++ b/src/pipecat/audio/turn/smart_turn/local_coreml_smart_turn.py @@ -4,6 +4,11 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Local CoreML smart turn analyzer for on-device ML inference. + +This module provides a smart turn analyzer that uses CoreML models for +local end-of-turn detection without requiring network connectivity. +""" from typing import Any, Dict @@ -25,7 +30,24 @@ except ModuleNotFoundError as e: class LocalCoreMLSmartTurnAnalyzer(BaseSmartTurn): + """Local smart turn analyzer using CoreML models. + + Provides end-of-turn detection using locally-stored CoreML models, + enabling offline operation without network dependencies. Optimized + for Apple Silicon and other CoreML-compatible hardware. + """ + def __init__(self, *, smart_turn_model_path: str, **kwargs): + """Initialize the local CoreML smart turn analyzer. + + Args: + smart_turn_model_path: Path to directory containing the CoreML model + and feature extractor files. + **kwargs: Additional arguments passed to BaseSmartTurn. + + Raises: + Exception: If smart_turn_model_path is not provided or model loading fails. 
+ """ super().__init__(**kwargs) if not smart_turn_model_path: @@ -41,6 +63,7 @@ class LocalCoreMLSmartTurnAnalyzer(BaseSmartTurn): logger.debug("Loaded Local Smart Turn") async def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]: + """Predict end-of-turn using local CoreML model.""" inputs = self._turn_processor( audio_array, sampling_rate=16000, diff --git a/src/pipecat/audio/turn/smart_turn/local_smart_turn.py b/src/pipecat/audio/turn/smart_turn/local_smart_turn.py index a3ee7ebf9..ed67dad12 100644 --- a/src/pipecat/audio/turn/smart_turn/local_smart_turn.py +++ b/src/pipecat/audio/turn/smart_turn/local_smart_turn.py @@ -4,6 +4,11 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Local PyTorch smart turn analyzer for on-device ML inference. + +This module provides a smart turn analyzer that uses PyTorch models for +local end-of-turn detection without requiring network connectivity. +""" from typing import Any, Dict @@ -24,7 +29,21 @@ except ModuleNotFoundError as e: class LocalSmartTurnAnalyzer(BaseSmartTurn): + """Local smart turn analyzer using PyTorch models. + + Provides end-of-turn detection using locally-stored PyTorch models, + enabling offline operation without network dependencies. Uses + Wav2Vec2-BERT architecture for audio sequence classification. + """ + def __init__(self, *, smart_turn_model_path: str, **kwargs): + """Initialize the local PyTorch smart turn analyzer. + + Args: + smart_turn_model_path: Path to directory containing the PyTorch model + and feature extractor files. If empty, uses default HuggingFace model. + **kwargs: Additional arguments passed to BaseSmartTurn. 
+ """ super().__init__(**kwargs) if not smart_turn_model_path: @@ -46,6 +65,7 @@ class LocalSmartTurnAnalyzer(BaseSmartTurn): logger.debug("Loaded Local Smart Turn") async def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]: + """Predict end-of-turn using local PyTorch model.""" inputs = self._turn_processor( audio_array, sampling_rate=16000, diff --git a/src/pipecat/audio/utils.py b/src/pipecat/audio/utils.py index 1f7db648f..7b72be1a9 100644 --- a/src/pipecat/audio/utils.py +++ b/src/pipecat/audio/utils.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Audio utility functions for Pipecat. + +This module provides common audio processing utilities including mixing, +format conversion, volume calculation, and codec transformations for +various audio formats used in Pipecat pipelines. +""" + import audioop import numpy as np @@ -15,10 +22,31 @@ from pipecat.audio.resamplers.soxr_resampler import SOXRAudioResampler def create_default_resampler(**kwargs) -> BaseAudioResampler: + """Create a default audio resampler instance. + + Args: + **kwargs: Additional keyword arguments passed to the resampler constructor. + + Returns: + A configured SOXRAudioResampler instance. + """ return SOXRAudioResampler(**kwargs) def mix_audio(audio1: bytes, audio2: bytes) -> bytes: + """Mix two audio streams together by adding their samples. + + Both audio streams are assumed to be 16-bit signed integer PCM data. + If the streams have different lengths, the shorter one is zero-padded + to match the longer stream. + + Args: + audio1: First audio stream as raw bytes (16-bit signed integers). + audio2: Second audio stream as raw bytes (16-bit signed integers). + + Returns: + Mixed audio data as raw bytes with samples clipped to 16-bit range. 
+ """ data1 = np.frombuffer(audio1, dtype=np.int16) data2 = np.frombuffer(audio2, dtype=np.int16) @@ -37,6 +65,19 @@ def mix_audio(audio1: bytes, audio2: bytes) -> bytes: def interleave_stereo_audio(left_audio: bytes, right_audio: bytes) -> bytes: + """Interleave left and right mono audio channels into stereo audio. + + Takes two mono audio streams and combines them into a single stereo + stream by interleaving the samples (L, R, L, R, ...). If the channels + have different lengths, both are truncated to the shorter length. + + Args: + left_audio: Left channel audio as raw bytes (16-bit signed integers). + right_audio: Right channel audio as raw bytes (16-bit signed integers). + + Returns: + Interleaved stereo audio data as raw bytes. + """ left = np.frombuffer(left_audio, dtype=np.int16) right = np.frombuffer(right_audio, dtype=np.int16) @@ -50,12 +91,34 @@ def interleave_stereo_audio(left_audio: bytes, right_audio: bytes) -> bytes: def normalize_value(value, min_value, max_value): + """Normalize a value to the range [0, 1] and clamp it to bounds. + + Args: + value: The value to normalize. + min_value: The minimum value of the input range. + max_value: The maximum value of the input range. + + Returns: + Normalized value clamped to the range [0, 1]. + """ normalized = (value - min_value) / (max_value - min_value) normalized_clamped = max(0, min(1, normalized)) return normalized_clamped def calculate_audio_volume(audio: bytes, sample_rate: int) -> float: + """Calculate the loudness level of audio data using EBU R128 standard. + + Uses the pyloudnorm library to calculate integrated loudness according + to the EBU R128 recommendation, then normalizes the result to [0, 1]. + + Args: + audio: Audio data as raw bytes (16-bit signed integers). + sample_rate: Sample rate of the audio in Hz. + + Returns: + Normalized loudness value between 0 (quiet) and 1 (loud). 
+ """ audio_np = np.frombuffer(audio, dtype=np.int16) audio_float = audio_np.astype(np.float64) @@ -71,12 +134,37 @@ def calculate_audio_volume(audio: bytes, sample_rate: int) -> float: def exp_smoothing(value: float, prev_value: float, factor: float) -> float: + """Apply exponential smoothing to a value. + + Exponential smoothing is used to reduce noise in time-series data by + giving more weight to recent values while still considering historical data. + + Args: + value: The new value to incorporate. + prev_value: The previous smoothed value. + factor: Smoothing factor between 0 and 1. Higher values give more + weight to the new value. + + Returns: + The exponentially smoothed value. + """ return prev_value + factor * (value - prev_value) async def ulaw_to_pcm( ulaw_bytes: bytes, in_rate: int, out_rate: int, resampler: BaseAudioResampler ): + """Convert μ-law encoded audio to PCM and optionally resample. + + Args: + ulaw_bytes: μ-law encoded audio data as raw bytes. + in_rate: Original sample rate of the μ-law audio in Hz. + out_rate: Desired output sample rate in Hz. + resampler: Audio resampler instance for rate conversion. + + Returns: + PCM audio data as raw bytes at the specified output rate. + """ # Convert μ-law to PCM in_pcm_bytes = audioop.ulaw2lin(ulaw_bytes, 2) @@ -87,6 +175,17 @@ async def ulaw_to_pcm( async def pcm_to_ulaw(pcm_bytes: bytes, in_rate: int, out_rate: int, resampler: BaseAudioResampler): + """Convert PCM audio to μ-law encoding and optionally resample. + + Args: + pcm_bytes: PCM audio data as raw bytes (16-bit signed integers). + in_rate: Original sample rate of the PCM audio in Hz. + out_rate: Desired output sample rate in Hz. + resampler: Audio resampler instance for rate conversion. + + Returns: + μ-law encoded audio data as raw bytes at the specified output rate. 
+ """ # Resample in_pcm_bytes = await resampler.resample(pcm_bytes, in_rate, out_rate) @@ -99,6 +198,17 @@ async def pcm_to_ulaw(pcm_bytes: bytes, in_rate: int, out_rate: int, resampler: async def alaw_to_pcm( alaw_bytes: bytes, in_rate: int, out_rate: int, resampler: BaseAudioResampler ) -> bytes: + """Convert A-law encoded audio to PCM and optionally resample. + + Args: + alaw_bytes: A-law encoded audio data as raw bytes. + in_rate: Original sample rate of the A-law audio in Hz. + out_rate: Desired output sample rate in Hz. + resampler: Audio resampler instance for rate conversion. + + Returns: + PCM audio data as raw bytes at the specified output rate. + """ # Convert a-law to PCM in_pcm_bytes = audioop.alaw2lin(alaw_bytes, 2) @@ -109,6 +219,17 @@ async def alaw_to_pcm( async def pcm_to_alaw(pcm_bytes: bytes, in_rate: int, out_rate: int, resampler: BaseAudioResampler): + """Convert PCM audio to A-law encoding and optionally resample. + + Args: + pcm_bytes: PCM audio data as raw bytes (16-bit signed integers). + in_rate: Original sample rate of the PCM audio in Hz. + out_rate: Desired output sample rate in Hz. + resampler: Audio resampler instance for rate conversion. + + Returns: + A-law encoded audio data as raw bytes at the specified output rate. + """ # Resample in_pcm_bytes = await resampler.resample(pcm_bytes, in_rate, out_rate) diff --git a/src/pipecat/audio/vad/silero.py b/src/pipecat/audio/vad/silero.py index cb1dc7631..d97e9b1df 100644 --- a/src/pipecat/audio/vad/silero.py +++ b/src/pipecat/audio/vad/silero.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Silero Voice Activity Detection (VAD) implementation for Pipecat. + +This module provides a VAD analyzer based on the Silero VAD ONNX model, +which can detect voice activity in audio streams with high accuracy. +Supports 8kHz and 16kHz sample rates. 
+""" + import time from typing import Optional @@ -25,7 +32,20 @@ except ModuleNotFoundError as e: class SileroOnnxModel: + """ONNX runtime wrapper for the Silero VAD model. + + Provides voice activity detection using the pre-trained Silero VAD model + with ONNX runtime for efficient inference. Handles model state management + and input validation for audio processing. + """ + def __init__(self, path, force_onnx_cpu=True): + """Initialize the Silero ONNX model. + + Args: + path: Path to the ONNX model file. + force_onnx_cpu: Whether to force CPU execution provider. + """ import numpy as np global np @@ -45,6 +65,7 @@ class SileroOnnxModel: self.sample_rates = [8000, 16000] def _validate_input(self, x, sr: int): + """Validate and preprocess input audio data.""" if np.ndim(x) == 1: x = np.expand_dims(x, 0) if np.ndim(x) > 2: @@ -60,12 +81,18 @@ class SileroOnnxModel: return x, sr def reset_states(self, batch_size=1): + """Reset the internal model states. + + Args: + batch_size: Batch size for state initialization. Defaults to 1. + """ self._state = np.zeros((2, batch_size, 128), dtype="float32") self._context = np.zeros((batch_size, 0), dtype="float32") self._last_sr = 0 self._last_batch_size = 0 def __call__(self, x, sr: int): + """Process audio input through the VAD model.""" x, sr = self._validate_input(x, sr) num_samples = 512 if sr == 16000 else 256 @@ -105,7 +132,20 @@ class SileroOnnxModel: class SileroVADAnalyzer(VADAnalyzer): + """Voice Activity Detection analyzer using the Silero VAD model. + + Implements VAD analysis using the pre-trained Silero ONNX model for + accurate voice activity detection. Supports 8kHz and 16kHz sample rates + with automatic model state management and periodic resets. + """ + def __init__(self, *, sample_rate: Optional[int] = None, params: Optional[VADParams] = None): + """Initialize the Silero VAD analyzer. + + Args: + sample_rate: Audio sample rate (8000 or 16000 Hz). If None, will be set later. 
+ params: VAD parameters for detection thresholds and timing. + """ super().__init__(sample_rate=sample_rate, params=params) logger.debug("Loading Silero VAD model...") @@ -137,6 +177,14 @@ class SileroVADAnalyzer(VADAnalyzer): # def set_sample_rate(self, sample_rate: int): + """Set the sample rate for audio processing. + + Args: + sample_rate: Audio sample rate (must be 8000 or 16000 Hz). + + Raises: + ValueError: If sample rate is not 8000 or 16000 Hz. + """ if sample_rate != 16000 and sample_rate != 8000: raise ValueError( f"Silero VAD sample rate needs to be 16000 or 8000 (sample rate: {sample_rate})" @@ -145,9 +193,22 @@ class SileroVADAnalyzer(VADAnalyzer): super().set_sample_rate(sample_rate) def num_frames_required(self) -> int: + """Get the number of audio frames required for VAD analysis. + + Returns: + Number of frames required (512 for 16kHz, 256 for 8kHz). + """ return 512 if self.sample_rate == 16000 else 256 def voice_confidence(self, buffer) -> float: + """Calculate voice activity confidence for the given audio buffer. + + Args: + buffer: Audio buffer to analyze. + + Returns: + Voice confidence score between 0.0 and 1.0. + """ try: audio_int16 = np.frombuffer(buffer, np.int16) # Divide by 32768 because we have signed 16-bit data. diff --git a/src/pipecat/audio/vad/vad_analyzer.py b/src/pipecat/audio/vad/vad_analyzer.py index 073e61712..5c92d390c 100644 --- a/src/pipecat/audio/vad/vad_analyzer.py +++ b/src/pipecat/audio/vad/vad_analyzer.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Voice Activity Detection (VAD) analyzer base classes and utilities. + +This module provides the abstract base class for VAD analyzers and associated +data structures for voice activity detection in audio streams. Includes state +management, parameter configuration, and audio analysis framework. 
+""" + from abc import ABC, abstractmethod from enum import Enum from typing import Optional @@ -20,6 +27,15 @@ VAD_MIN_VOLUME = 0.6 class VADState(Enum): + """Voice Activity Detection states. + + Parameters: + QUIET: No voice activity detected. + STARTING: Voice activity beginning, transitioning from quiet. + SPEAKING: Active voice detected and confirmed. + STOPPING: Voice activity ending, transitioning to quiet. + """ + QUIET = 1 STARTING = 2 SPEAKING = 3 @@ -27,6 +43,15 @@ class VADState(Enum): class VADParams(BaseModel): + """Configuration parameters for Voice Activity Detection. + + Parameters: + confidence: Minimum confidence threshold for voice detection. + start_secs: Duration to wait before confirming voice start. + stop_secs: Duration to wait before confirming voice stop. + min_volume: Minimum audio volume threshold for voice detection. + """ + confidence: float = VAD_CONFIDENCE start_secs: float = VAD_START_SECS stop_secs: float = VAD_STOP_SECS @@ -34,7 +59,20 @@ class VADParams(BaseModel): class VADAnalyzer(ABC): + """Abstract base class for Voice Activity Detection analyzers. + + Provides the framework for implementing VAD analysis with configurable + parameters, state management, and audio processing capabilities. + Subclasses must implement the core voice confidence calculation. + """ + def __init__(self, *, sample_rate: Optional[int] = None, params: Optional[VADParams] = None): + """Initialize the VAD analyzer. + + Args: + sample_rate: Audio sample rate in Hz. If None, will be set later. + params: VAD parameters for detection configuration. + """ self._init_sample_rate = sample_rate self._sample_rate = 0 self._params = params or VADParams() @@ -48,29 +86,67 @@ class VADAnalyzer(ABC): @property def sample_rate(self) -> int: + """Get the current sample rate. + + Returns: + Current audio sample rate in Hz. + """ return self._sample_rate @property def num_channels(self) -> int: + """Get the number of audio channels. 
+ + Returns: + Number of audio channels (always 1 for mono). + """ return self._num_channels @property def params(self) -> VADParams: + """Get the current VAD parameters. + + Returns: + Current VAD configuration parameters. + """ return self._params @abstractmethod def num_frames_required(self) -> int: + """Get the number of audio frames required for analysis. + + Returns: + Number of frames needed for VAD processing. + """ pass @abstractmethod def voice_confidence(self, buffer) -> float: + """Calculate voice activity confidence for the given audio buffer. + + Args: + buffer: Audio buffer to analyze. + + Returns: + Voice confidence score between 0.0 and 1.0. + """ pass def set_sample_rate(self, sample_rate: int): + """Set the sample rate for audio processing. + + Args: + sample_rate: Audio sample rate in Hz. + """ self._sample_rate = self._init_sample_rate or sample_rate self.set_params(self._params) def set_params(self, params: VADParams): + """Set VAD parameters and recalculate internal values. + + Args: + params: VAD parameters for detection configuration. + """ logger.debug(f"Setting VAD params to: {params}") self._params = params self._vad_frames = self.num_frames_required() @@ -85,10 +161,22 @@ class VADAnalyzer(ABC): self._vad_state: VADState = VADState.QUIET def _get_smoothed_volume(self, audio: bytes) -> float: + """Calculate smoothed audio volume using exponential smoothing.""" volume = calculate_audio_volume(audio, self.sample_rate) return exp_smoothing(volume, self._prev_volume, self._smoothing_factor) def analyze_audio(self, buffer) -> VADState: + """Analyze audio buffer and return current VAD state. + + Processes incoming audio data, maintains internal state, and determines + voice activity status based on confidence and volume thresholds. + + Args: + buffer: Audio buffer to analyze. + + Returns: + Current VAD state after processing the buffer. 
+ """ self._vad_buffer += buffer num_required_bytes = self._vad_frames_num_bytes diff --git a/src/pipecat/clocks/base_clock.py b/src/pipecat/clocks/base_clock.py index 184c82f00..7efe3c457 100644 --- a/src/pipecat/clocks/base_clock.py +++ b/src/pipecat/clocks/base_clock.py @@ -4,14 +4,33 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Base clock interface for Pipecat timing operations.""" + from abc import ABC, abstractmethod class BaseClock(ABC): + """Abstract base class for clock implementations. + + Provides a common interface for timing operations used in Pipecat + for synchronization, scheduling, and time-based processing. + """ + @abstractmethod def get_time(self) -> int: + """Get the current time value. + + Returns: + The current time as an integer value. The specific unit and + reference point depend on the concrete implementation. + """ pass @abstractmethod def start(self): + """Start or initialize the clock. + + Performs any necessary initialization or starts the timing mechanism. + This method should be called before using get_time(). + """ pass diff --git a/src/pipecat/clocks/system_clock.py b/src/pipecat/clocks/system_clock.py index ed6e81ad7..87f2a722b 100644 --- a/src/pipecat/clocks/system_clock.py +++ b/src/pipecat/clocks/system_clock.py @@ -4,17 +4,42 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""System clock implementation for Pipecat.""" + import time from pipecat.clocks.base_clock import BaseClock class SystemClock(BaseClock): + """A monotonic clock implementation using system time. + + Provides high-precision timing using the system's monotonic clock, + which is not affected by system clock adjustments and is suitable + for measuring elapsed time in real-time applications. + """ + def __init__(self): + """Initialize the system clock. + + The clock starts in an uninitialized state and must be started + explicitly using the start() method before time measurement begins. 
+ """ self._time = 0 def get_time(self) -> int: + """Get the elapsed time since the clock was started. + + Returns: + The elapsed time in nanoseconds since start() was called. + Returns 0 if the clock has not been started yet. + """ return time.monotonic_ns() - self._time if self._time > 0 else 0 def start(self): + """Start the clock and begin time measurement. + + Records the current monotonic time as the reference point + for all subsequent get_time() calls. + """ self._time = time.monotonic_ns() diff --git a/src/pipecat/examples/daily_runner.py b/src/pipecat/examples/daily_runner.py index 04157d549..d8d6acd81 100644 --- a/src/pipecat/examples/daily_runner.py +++ b/src/pipecat/examples/daily_runner.py @@ -4,6 +4,8 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Daily.co room configuration utilities for Pipecat examples.""" + import argparse import os from typing import Optional @@ -14,6 +16,17 @@ from pipecat.transports.services.helpers.daily_rest import DailyRESTHelper async def configure(aiohttp_session: aiohttp.ClientSession): + """Configure Daily.co room URL and token from arguments or environment. + + Args: + aiohttp_session: HTTP session for making API requests. + + Returns: + Tuple containing the room URL and authentication token. + + Raises: + Exception: If room URL or API key are not provided. + """ (url, token, _) = await configure_with_args(aiohttp_session) return (url, token) @@ -21,6 +34,18 @@ async def configure(aiohttp_session: aiohttp.ClientSession): async def configure_with_args( aiohttp_session: aiohttp.ClientSession, parser: Optional[argparse.ArgumentParser] = None ): + """Configure Daily.co room with command-line argument parsing. + + Args: + aiohttp_session: HTTP session for making API requests. + parser: Optional argument parser. If None, creates a default one. + + Returns: + Tuple containing room URL, authentication token, and parsed arguments. 
+ + Raises: + Exception: If room URL or API key are not provided via arguments or environment. + """ if not parser: parser = argparse.ArgumentParser(description="Daily AI SDK Bot Sample") parser.add_argument( diff --git a/src/pipecat/examples/run.py b/src/pipecat/examples/run.py index 24f673e06..edc8ba058 100644 --- a/src/pipecat/examples/run.py +++ b/src/pipecat/examples/run.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Pipecat example runner with support for multiple transport types. + +This module provides a unified interface for running Pipecat examples across +different transport types including Daily.co, WebRTC, and Twilio. It handles +setup, configuration, and lifecycle management for each transport type. +""" + import argparse import asyncio import json @@ -35,6 +42,15 @@ load_dotenv(override=True) def get_transport_client_id(transport: BaseTransport, client: Any) -> str: + """Get client identifier from transport-specific client object. + + Args: + transport: The transport instance. + client: Transport-specific client object. + + Returns: + Client identifier string, empty if transport not supported. + """ if isinstance(transport, SmallWebRTCTransport): return client.pc_id elif isinstance(transport, DailyTransport): @@ -46,6 +62,13 @@ def get_transport_client_id(transport: BaseTransport, client: Any) -> str: async def maybe_capture_participant_camera( transport: BaseTransport, client: Any, framerate: int = 0 ): + """Capture participant camera video if transport supports it. + + Args: + transport: The transport instance. + client: Transport-specific client object. + framerate: Video capture framerate. Defaults to 0 (auto). 
+ """ if isinstance(transport, DailyTransport): await transport.capture_participant_video( client["id"], framerate=framerate, video_source="camera" @@ -55,6 +78,13 @@ async def maybe_capture_participant_camera( async def maybe_capture_participant_screen( transport: BaseTransport, client: Any, framerate: int = 0 ): + """Capture participant screen video if transport supports it. + + Args: + transport: The transport instance. + client: Transport-specific client object. + framerate: Video capture framerate. Defaults to 0 (auto). + """ if isinstance(transport, DailyTransport): await transport.capture_participant_video( client["id"], framerate=framerate, video_source="screenVideo" @@ -66,6 +96,13 @@ def run_example_daily( args: argparse.Namespace, transport_params: Mapping[str, Callable] = {}, ): + """Run example using Daily.co transport. + + Args: + run_example: The example function to run. + args: Parsed command-line arguments. + transport_params: Mapping of transport names to parameter factory functions. + """ logger.info("Running example with DailyTransport...") from pipecat.examples.daily_runner import configure @@ -87,6 +124,13 @@ def run_example_webrtc( args: argparse.Namespace, transport_params: Mapping[str, Callable] = {}, ): + """Run example using WebRTC transport with FastAPI server. + + Args: + run_example: The example function to run. + args: Parsed command-line arguments. + transport_params: Mapping of transport names to parameter factory functions. + """ logger.info("Running example with SmallWebRTCTransport...") from pipecat_ai_small_webrtc_prebuilt.frontend import SmallWebRTCPrebuiltUI @@ -107,10 +151,20 @@ def run_example_webrtc( @app.get("/", include_in_schema=False) async def root_redirect(): + """Redirect root requests to client interface.""" return RedirectResponse(url="/client/") @app.post("/api/offer") async def offer(request: dict, background_tasks: BackgroundTasks): + """Handle WebRTC offer requests and manage peer connections. 
+ + Args: + request: WebRTC offer request containing SDP and connection details. + background_tasks: FastAPI background tasks for running examples. + + Returns: + WebRTC answer with connection details. + """ pc_id = request.get("pc_id") if pc_id and pc_id in pcs_map: @@ -127,6 +181,11 @@ def run_example_webrtc( @pipecat_connection.event_handler("closed") async def handle_disconnected(webrtc_connection: SmallWebRTCConnection): + """Handle WebRTC connection closure and cleanup. + + Args: + webrtc_connection: The closed WebRTC connection. + """ logger.info(f"Discarding peer connection for pc_id: {webrtc_connection.pc_id}") pcs_map.pop(webrtc_connection.pc_id, None) @@ -143,6 +202,11 @@ def run_example_webrtc( @asynccontextmanager async def lifespan(app: FastAPI): + """Manage FastAPI application lifecycle and cleanup connections. + + Args: + app: The FastAPI application instance. + """ yield # Run app coros = [pc.disconnect() for pc in pcs_map.values()] await asyncio.gather(*coros) @@ -156,6 +220,13 @@ def run_example_twilio( args: argparse.Namespace, transport_params: Mapping[str, Callable] = {}, ): + """Run example using Twilio transport with FastAPI WebSocket server. + + Args: + run_example: The example function to run. + args: Parsed command-line arguments. + transport_params: Mapping of transport names to parameter factory functions. + """ logger.info("Running example with FastAPIWebsocketTransport (Twilio)...") app = FastAPI() @@ -170,6 +241,11 @@ def run_example_twilio( @app.post("/") async def start_call(): + """Handle Twilio webhook and return TwiML response. + + Returns: + TwiML XML response directing call to WebSocket stream. + """ logger.debug("POST TwiML") xml_content = f""" @@ -184,6 +260,11 @@ def run_example_twilio( @app.websocket("/ws") async def websocket_endpoint(websocket: WebSocket): + """Handle Twilio WebSocket connections for voice streaming. + + Args: + websocket: The WebSocket connection from Twilio. 
+ """ await websocket.accept() logger.debug("WebSocket connection accepted") @@ -216,6 +297,13 @@ def run_main( args: argparse.Namespace, transport_params: Mapping[str, Callable] = {}, ): + """Run the example with the specified transport type. + + Args: + run_example: The example function to run. + args: Parsed command-line arguments. + transport_params: Mapping of transport names to parameter factory functions. + """ if args.transport not in transport_params: logger.error(f"Transport '{args.transport}' not supported by this example") return @@ -235,6 +323,13 @@ def main( parser: Optional[argparse.ArgumentParser] = None, transport_params: Mapping[str, Callable] = {}, ): + """Main entry point for running Pipecat examples with transport selection. + + Args: + run_example: The example function to run. + parser: Optional argument parser. If None, creates a default one. + transport_params: Mapping of transport names to parameter factory functions. + """ if not parser: parser = argparse.ArgumentParser(description="Pipecat Bot Runner") parser.add_argument( diff --git a/src/pipecat/frames/frames.py b/src/pipecat/frames/frames.py index 3a602a3e2..59e862818 100644 --- a/src/pipecat/frames/frames.py +++ b/src/pipecat/frames/frames.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Core frame definitions for the Pipecat AI framework. + +This module contains all frame types used throughout the Pipecat pipeline system, +including data frames, system frames, and control frames for audio, video, text, +and LLM processing. +""" + from dataclasses import dataclass, field from enum import Enum from typing import ( @@ -32,7 +39,22 @@ if TYPE_CHECKING: class KeypadEntry(str, Enum): - """DTMF entries.""" + """DTMF keypad entries for phone system integration. + + Parameters: + ONE: Number key 1. + TWO: Number key 2. + THREE: Number key 3. + FOUR: Number key 4. + FIVE: Number key 5. + SIX: Number key 6. + SEVEN: Number key 7. + EIGHT: Number key 8. 
+ NINE: Number key 9. + ZERO: Number key 0. + POUND: Pound/hash key (#). + STAR: Star/asterisk key (*). + """ ONE = "1" TWO = "2" @@ -49,12 +71,31 @@ class KeypadEntry(str, Enum): def format_pts(pts: Optional[int]): + """Format presentation timestamp (PTS) in nanoseconds to a human-readable string. + + Converts a PTS value in nanoseconds to a string representation. + + Args: + pts: Presentation timestamp in nanoseconds, or None if not set. + """ return nanoseconds_to_str(pts) if pts else None @dataclass class Frame: - """Base frame class.""" + """Base frame class for all frames in the Pipecat pipeline. + + All frames inherit from this base class and automatically receive + unique identifiers, names, and metadata support. + + Parameters: + id: Unique identifier for the frame instance. + name: Human-readable name combining class name and instance count. + pts: Presentation timestamp in nanoseconds. + metadata: Dictionary for arbitrary frame metadata. + transport_source: Name of the transport source that created this frame. + transport_destination: Name of the transport destination for this frame. + """ id: int = field(init=False) name: str = field(init=False) @@ -77,9 +118,10 @@ class Frame: @dataclass class SystemFrame(Frame): - """System frames are frames that are not internally queued by any of the - frame processors and should be processed immediately. + """System frame class for immediate processing. + System frames are frames that are not internally queued by any of the + frame processors and should be processed immediately. """ pass @@ -87,9 +129,10 @@ class SystemFrame(Frame): @dataclass class DataFrame(Frame): - """Data frames are frames that will be processed in order and usually - contain data such as LLM context, text, audio or images. + """Data frame class for processing data in order. + Data frames are frames that will be processed in order and usually + contain data such as LLM context, text, audio or images. 
""" pass @@ -97,10 +140,11 @@ class DataFrame(Frame): @dataclass class ControlFrame(Frame): - """Control frames are frames that, similar to data frames, will be processed + """Control frame class for processing control information in order. + + Control frames are frames that, similar to data frames, will be processed in order and usually contain control information such as frames to update settings or to end the pipeline. - """ pass @@ -113,7 +157,14 @@ class ControlFrame(Frame): @dataclass class AudioRawFrame: - """A chunk of audio.""" + """A frame containing a chunk of raw audio. + + Parameters: + audio: Raw audio bytes in PCM format. + sample_rate: Audio sample rate in Hz. + num_channels: Number of audio channels. + num_frames: Number of audio frames (calculated automatically). + """ audio: bytes sample_rate: int @@ -126,7 +177,13 @@ class AudioRawFrame: @dataclass class ImageRawFrame: - """A raw image.""" + """A frame containing a raw image. + + Parameters: + image: Raw image bytes. + size: Image dimensions as (width, height) tuple. + format: Image format (e.g., 'JPEG', 'PNG'). + """ image: bytes size: Tuple[int, int] @@ -140,10 +197,11 @@ class ImageRawFrame: @dataclass class OutputAudioRawFrame(DataFrame, AudioRawFrame): - """A chunk of audio. Will be played by the output transport. If the - transport supports multiple audio destinations (e.g. multiple audio tracks) the - destination name can be specified. + """Audio data frame for output to transport. + A chunk of raw audio that will be played by the output transport. If the + transport supports multiple audio destinations (e.g. multiple audio tracks) + the destination name can be specified in transport_destination. """ def __post_init__(self): @@ -157,10 +215,11 @@ class OutputAudioRawFrame(DataFrame, AudioRawFrame): @dataclass class OutputImageRawFrame(DataFrame, ImageRawFrame): - """An image that will be shown by the transport. If the transport supports - multiple video destinations (e.g. 
multiple video tracks) the destination - name can be specified. + """Image data frame for output to transport. + An image that will be shown by the transport. If the transport supports + multiple video destinations (e.g. multiple video tracks) the destination + name can be specified in transport_destination. """ def __str__(self): @@ -170,16 +229,23 @@ class OutputImageRawFrame(DataFrame, ImageRawFrame): @dataclass class TTSAudioRawFrame(OutputAudioRawFrame): - """A chunk of output audio generated by a TTS service.""" + """Audio data frame generated by Text-to-Speech services. + + A chunk of output audio generated by a TTS service, ready for playback. + """ pass @dataclass class URLImageRawFrame(OutputImageRawFrame): - """An output image with an associated URL. These images are usually + """Image frame with an associated URL. + + An output image with an associated URL. These images are usually generated by third-party services that provide a URL to download the image. + Parameters: + url: URL where the image can be downloaded from. """ url: Optional[str] = None @@ -191,10 +257,14 @@ class URLImageRawFrame(OutputImageRawFrame): @dataclass class SpriteFrame(DataFrame): - """An animated sprite. Will be shown by the transport if the transport's + """Animated sprite frame containing multiple images. + + An animated sprite that will be shown by the transport if the transport's camera is enabled. Will play at the framerate specified in the transport's `camera_out_framerate` constructor parameter. + Parameters: + images: List of image frames that make up the sprite animation. """ images: List[OutputImageRawFrame] @@ -206,9 +276,14 @@ class SpriteFrame(DataFrame): @dataclass class TextFrame(DataFrame): - """A chunk of text. Emitted by LLM services, consumed by TTS services, can - be used to send text through processors. + """Text data frame for passing text through the pipeline. + A chunk of text. 
Emitted by LLM services, consumed by context + aggregators, TTS services and more. Can be used to send text + through processors. + + Parameters: + text: The text content. """ text: str @@ -220,23 +295,30 @@ class TextFrame(DataFrame): @dataclass class LLMTextFrame(TextFrame): - """A text frame generated by LLM services.""" + """Text frame generated by LLM services.""" pass @dataclass class TTSTextFrame(TextFrame): - """A text frame generated by TTS services.""" + """Text frame generated by Text-to-Speech services.""" pass @dataclass class TranscriptionFrame(TextFrame): - """A text frame with transcription-specific data. The `result` field + """Text frame containing speech transcription data. + + A text frame with transcription-specific data. The `result` field contains the result from the STT service if available. + Parameters: + user_id: Identifier for the user who spoke. + timestamp: When the transcription occurred. + language: Detected or specified language of the speech. + result: Raw result from the STT service. """ user_id: str @@ -250,9 +332,17 @@ class TranscriptionFrame(TextFrame): @dataclass class InterimTranscriptionFrame(TextFrame): - """A text frame with interim transcription-specific data. The `result` field + """Text frame containing partial/interim transcription data. + + A text frame with interim transcription-specific data that represents + partial results before final transcription. The `result` field contains the result from the STT service if available. + Parameters: + user_id: Identifier for the user who spoke. + timestamp: When the interim transcription occurred. + language: Detected or specified language of the speech. + result: Raw result from the STT service. """ text: str @@ -267,10 +357,15 @@ class InterimTranscriptionFrame(TextFrame): @dataclass class TranslationFrame(TextFrame): - """A text frame with translated transcription data. + """Text frame containing translated transcription data. 
- Will be placed in the transport's receive queue when a participant speaks. + A text frame with translated transcription data that will be placed + in the transport's receive queue when a participant speaks. + Parameters: + user_id: Identifier for the user who spoke. + timestamp: When the translation occurred. + language: Target language of the translation. """ user_id: str @@ -283,16 +378,27 @@ class TranslationFrame(TextFrame): @dataclass class OpenAILLMContextAssistantTimestampFrame(DataFrame): - """Timestamp information for assistant message in LLM context.""" + """Timestamp information for assistant messages in LLM context. + + Parameters: + timestamp: Timestamp when the assistant message was created. + """ timestamp: str @dataclass class TranscriptionMessage: - """A message in a conversation transcript containing the role and content. + """A message in a conversation transcript. + A message in a conversation transcript containing the role and content. Messages are in standard format with roles normalized to user/assistant. + + Parameters: + role: The role of the message sender (user or assistant). + content: The message content/text. + user_id: Optional identifier for the user. + timestamp: Optional timestamp when the message was created. """ role: Literal["user", "assistant"] @@ -303,39 +409,46 @@ class TranscriptionMessage: @dataclass class TranscriptionUpdateFrame(DataFrame): - """A frame containing new messages added to the conversation transcript. + """Frame containing new messages added to conversation transcript. + A frame containing new messages added to the conversation transcript. This frame is emitted when new messages are added to the conversation history, containing only the newly added messages rather than the full transcript. Messages have normalized roles (user/assistant) regardless of the LLM service used. 
Messages are always in the OpenAI standard message format, which supports both: - Simple format: - [ - { - "role": "user", - "content": "Hi, how are you?" - }, - { - "role": "assistant", - "content": "Great! And you?" - } - ] + Examples: + Simple format:: - Content list format: - [ - { - "role": "user", - "content": [{"type": "text", "text": "Hi, how are you?"}] - }, - { - "role": "assistant", - "content": [{"type": "text", "text": "Great! And you?"}] - } - ] + [ + { + "role": "user", + "content": "Hi, how are you?" + }, + { + "role": "assistant", + "content": "Great! And you?" + } + ] + + Content list format:: + + [ + { + "role": "user", + "content": [{"type": "text", "text": "Hi, how are you?"}] + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "Great! And you?"}] + } + ] OpenAI supports both formats. Anthropic and Google messages are converted to the content list format. + + Parameters: + messages: List of new transcript messages that were added. """ messages: List[TranscriptionMessage] @@ -347,12 +460,16 @@ class TranscriptionUpdateFrame(DataFrame): @dataclass class LLMMessagesFrame(DataFrame): - """A frame containing a list of LLM messages. Used to signal that an LLM + """Frame containing LLM messages for chat completion. + + A frame containing a list of LLM messages. Used to signal that an LLM service should run a chat completion and emit an LLMFullResponseStartFrame, TextFrames and an LLMFullResponseEndFrame. Note that the `messages` - property in this class is mutable, and will be be updated by various + property in this class is mutable, and will be updated by various aggregators. + Parameters: + messages: List of message dictionaries in LLM format. """ messages: List[dict] @@ -360,9 +477,13 @@ class LLMMessagesFrame(DataFrame): @dataclass class LLMMessagesAppendFrame(DataFrame): - """A frame containing a list of LLM messages that need to be added to the + """Frame containing LLM messages to append to current context. 
+ + A frame containing a list of LLM messages that need to be added to the current context. + Parameters: + messages: List of message dictionaries to append. """ messages: List[dict] @@ -370,10 +491,14 @@ class LLMMessagesAppendFrame(DataFrame): @dataclass class LLMMessagesUpdateFrame(DataFrame): - """A frame containing a list of new LLM messages. These messages will + """Frame containing LLM messages to replace current context. + + A frame containing a list of new LLM messages. These messages will replace the current context LLM messages and should generate a new LLMMessagesFrame. + Parameters: + messages: List of message dictionaries to replace current context. """ messages: List[dict] @@ -381,9 +506,14 @@ class LLMMessagesUpdateFrame(DataFrame): @dataclass class LLMSetToolsFrame(DataFrame): - """A frame containing a list of tools for an LLM to use for function calling. + """Frame containing tools for LLM function calling. + + A frame containing a list of tools for an LLM to use for function calling. The specific format depends on the LLM being used, but it should typically contain JSON Schema objects. + + Parameters: + tools: List of tool/function definitions for the LLM. """ tools: List[dict] @@ -391,23 +521,35 @@ class LLMSetToolsFrame(DataFrame): @dataclass class LLMSetToolChoiceFrame(DataFrame): - """A frame containing a tool choice for an LLM to use for function calling.""" + """Frame containing tool choice configuration for LLM function calling. + + Parameters: + tool_choice: Tool choice setting - 'none', 'auto', 'required', or specific tool dict. + """ tool_choice: Literal["none", "auto", "required"] | dict @dataclass class LLMEnablePromptCachingFrame(DataFrame): - """A frame to enable/disable prompt caching in certain LLMs.""" + """Frame to enable/disable prompt caching in LLMs. + + Parameters: + enable: Whether to enable prompt caching. 
+ """ enable: bool @dataclass class TTSSpeakFrame(DataFrame): - """A frame that contains a text that should be spoken by the TTS in the - pipeline (if any). + """Frame containing text that should be spoken by TTS. + A frame that contains text that should be spoken by the TTS service + in the pipeline (if any). + + Parameters: + text: The text to be spoken. """ text: str @@ -415,6 +557,12 @@ class TTSSpeakFrame(DataFrame): @dataclass class TransportMessageFrame(DataFrame): + """Frame containing transport-specific message data. + + Parameters: + message: The transport message payload. + """ + message: Any def __str__(self): @@ -423,17 +571,22 @@ class TransportMessageFrame(DataFrame): @dataclass class DTMFFrame: - """A DTMF button frame""" + """Base class for DTMF (Dual-Tone Multi-Frequency) keypad frames. + + Parameters: + button: The DTMF keypad entry that was pressed. + """ button: KeypadEntry @dataclass class OutputDTMFFrame(DTMFFrame, DataFrame): - """A DTMF keypress output that will be queued. If your transport supports + """DTMF keypress output frame for transport queuing. + + A DTMF keypress output that will be queued. If your transport supports multiple dial-out destinations, use the `transport_destination` field to specify where the DTMF keypress should be sent. - """ pass @@ -446,7 +599,20 @@ class OutputDTMFFrame(DTMFFrame, DataFrame): @dataclass class StartFrame(SystemFrame): - """This is the first frame that should be pushed down a pipeline.""" + """Initial frame to start pipeline processing. + + This is the first frame that should be pushed down a pipeline to + initialize all processors with their configuration parameters. + + Parameters: + audio_in_sample_rate: Input audio sample rate in Hz. + audio_out_sample_rate: Output audio sample rate in Hz. + allow_interruptions: Whether to allow user interruptions. + enable_metrics: Whether to enable performance metrics collection. + enable_usage_metrics: Whether to enable usage metrics collection. 
+ interruption_strategies: List of interruption handling strategies. + report_only_initial_ttfb: Whether to report only initial time-to-first-byte. + """ audio_in_sample_rate: int = 16000 audio_out_sample_rate: int = 24000 @@ -459,17 +625,26 @@ class StartFrame(SystemFrame): @dataclass class CancelFrame(SystemFrame): - """Indicates that a pipeline needs to stop right away.""" + """Frame indicating pipeline should stop immediately. + + Indicates that a pipeline needs to stop right away without + processing remaining queued frames. + """ pass @dataclass class ErrorFrame(SystemFrame): - """This is used notify upstream that an error has occurred downstream the - pipeline. A fatal error indicates the error is unrecoverable and that the + """Frame notifying of errors in the pipeline. + + This is used to notify upstream that an error has occurred downstream in + the pipeline. A fatal error indicates the error is unrecoverable and that the bot should exit. + Parameters: + error: Description of the error that occurred. + fatal: Whether the error is fatal and requires bot shutdown. """ error: str @@ -481,9 +656,13 @@ class ErrorFrame(SystemFrame): @dataclass class FatalErrorFrame(ErrorFrame): - """This is used notify upstream that an unrecoverable error has occurred and - that the bot should exit. + """Frame notifying of unrecoverable errors requiring bot shutdown. + This is used to notify upstream that an unrecoverable error has occurred and + that the bot should exit immediately. + + Parameters: + fatal: Always True for fatal errors. """ fatal: bool = field(default=True, init=False) @@ -491,10 +670,11 @@ class FatalErrorFrame(ErrorFrame): @dataclass class EndTaskFrame(SystemFrame): - """This is used to notify the pipeline task that the pipeline should be - closed nicely (flushing all the queued frames) by pushing an EndFrame - downstream. + """Frame to request graceful pipeline task closure. 
+ This is used to notify the pipeline task that the pipeline should be + closed nicely (flushing all the queued frames) by pushing an EndFrame + downstream. This frame should be pushed upstream. """ pass @@ -502,9 +682,11 @@ class EndTaskFrame(SystemFrame): @dataclass class CancelTaskFrame(SystemFrame): - """This is used to notify the pipeline task that the pipeline should be - stopped immediately by pushing a CancelFrame downstream. + """Frame to request immediate pipeline task cancellation. + This is used to notify the pipeline task that the pipeline should be + stopped immediately by pushing a CancelFrame downstream. This frame + should be pushed upstream. """ pass @@ -512,10 +694,12 @@ class CancelTaskFrame(SystemFrame): @dataclass class StopTaskFrame(SystemFrame): - """This is used to notify the pipeline task that it should be stopped as - soon as possible (flushing all the queued frames) but that the pipeline - processors should be kept in a running state. + """Frame to request pipeline task stop while keeping processors running. + This is used to notify the pipeline task that it should be stopped as + soon as possible (flushing all the queued frames) but that the pipeline + processors should be kept in a running state. This frame should be pushed + upstream. """ pass @@ -523,11 +707,15 @@ class StopTaskFrame(SystemFrame): @dataclass class FrameProcessorPauseUrgentFrame(SystemFrame): - """This frame is used to pause frame processing for the given processor as + """Frame to pause frame processing immediately. + + This frame is used to pause frame processing for the given processor as fast as possible. Pausing frame processing will keep frames in the internal queue which will then be processed when frame processing is resumed with `FrameProcessorResumeFrame`. + Parameters: + processor: The frame processor to pause. 
""" processor: "FrameProcessor" @@ -535,10 +723,14 @@ class FrameProcessorPauseUrgentFrame(SystemFrame): @dataclass class FrameProcessorResumeUrgentFrame(SystemFrame): - """This frame is used to resume frame processing for the given processor + """Frame to resume frame processing immediately. + + This frame is used to resume frame processing for the given processor if it was previously paused as fast as possible. After resuming frame processing all queued frames will be processed in the order received. + Parameters: + processor: The frame processor to resume. """ processor: "FrameProcessor" @@ -546,11 +738,12 @@ class FrameProcessorResumeUrgentFrame(SystemFrame): @dataclass class StartInterruptionFrame(SystemFrame): - """Emitted by VAD to indicate that a user has started speaking (i.e. is - interruption). This is similar to UserStartedSpeakingFrame except that it - should be pushed concurrently with other frames (so the order is not - guaranteed). + """Frame indicating user started speaking (interruption detected). + Emitted by the BaseInputTransport to indicate that a user has started + speaking (i.e. is interrupting). This is similar to + UserStartedSpeakingFrame except that it should be pushed concurrently + with other frames (so the order is not guaranteed). """ pass @@ -558,11 +751,12 @@ class StartInterruptionFrame(SystemFrame): @dataclass class StopInterruptionFrame(SystemFrame): - """Emitted by VAD to indicate that a user has stopped speaking (i.e. no more - interruptions). This is similar to UserStoppedSpeakingFrame except that it - should be pushed concurrently with other frames (so the order is not - guaranteed). + """Frame indicating user stopped speaking (interruption ended). + Emitted by the BaseInputTransport to indicate that a user has stopped + speaking (i.e. no more interruptions). This is similar to + UserStoppedSpeakingFrame except that it should be pushed concurrently + with other frames (so the order is not guaranteed). 
""" pass @@ -570,11 +764,15 @@ class StopInterruptionFrame(SystemFrame): @dataclass class UserStartedSpeakingFrame(SystemFrame): - """Emitted by VAD to indicate that a user has started speaking. This can be + """Frame indicating user has started speaking. + + Emitted by VAD to indicate that a user has started speaking. This can be used for interruptions or other times when detecting that someone is speaking is more important than knowing what they're saying (as you will - with a TranscriptionFrame) + get with a TranscriptionFrame). + Parameters: + emulated: Whether this event was emulated rather than detected by VAD. """ emulated: bool = False @@ -582,14 +780,22 @@ class UserStartedSpeakingFrame(SystemFrame): @dataclass class UserStoppedSpeakingFrame(SystemFrame): - """Emitted by the VAD to indicate that a user stopped speaking.""" + """Frame indicating user has stopped speaking. + + Emitted by the VAD to indicate that a user stopped speaking. + + Parameters: + emulated: Whether this event was emulated rather than detected by VAD. + """ emulated: bool = False @dataclass class EmulateUserStartedSpeakingFrame(SystemFrame): - """Emitted by internal processors upstream to emulate VAD behavior when a + """Frame to emulate user started speaking behavior. + + Emitted by internal processors upstream to emulate VAD behavior when a user starts speaking. """ @@ -598,7 +804,9 @@ class EmulateUserStartedSpeakingFrame(SystemFrame): @dataclass class EmulateUserStoppedSpeakingFrame(SystemFrame): - """Emitted by internal processors upstream to emulate VAD behavior when a + """Frame to emulate user stopped speaking behavior. + + Emitted by internal processors upstream to emulate VAD behavior when a user stops speaking. 
""" @@ -607,24 +815,27 @@ class EmulateUserStoppedSpeakingFrame(SystemFrame): @dataclass class VADUserStartedSpeakingFrame(SystemFrame): - """Frame emitted when VAD detects the user has definitively started speaking.""" + """Frame emitted when VAD definitively detects user started speaking.""" pass @dataclass class VADUserStoppedSpeakingFrame(SystemFrame): - """Frame emitted when VAD detects the user has definitively stopped speaking.""" + """Frame emitted when VAD definitively detects user stopped speaking.""" pass @dataclass class BotInterruptionFrame(SystemFrame): - """Emitted by when the bot should be interrupted. This will mainly cause the + """Frame indicating the bot should be interrupted. + + Emitted when the bot should be interrupted. This will mainly cause the same actions as if the user interrupted except that the UserStartedSpeakingFrame and UserStoppedSpeakingFrame won't be generated. - + This frame should be pushed upstream. It results in the BaseInputTransport + starting an interruption by pushing a StartInterruptionFrame downstream. """ pass @@ -632,25 +843,34 @@ class BotInterruptionFrame(SystemFrame): @dataclass class BotStartedSpeakingFrame(SystemFrame): - """Emitted upstream by transport outputs to indicate the bot started speaking.""" + """Frame indicating the bot started speaking. + + Emitted upstream and downstream by the BaseOutputTransport to indicate the + bot started speaking. + """ pass @dataclass class BotStoppedSpeakingFrame(SystemFrame): - """Emitted upstream by transport outputs to indicate the bot stopped speaking.""" + """Frame indicating the bot stopped speaking. + + Emitted upstream and downstream by the BaseOutputTransport to indicate the + bot stopped speaking. + """ pass @dataclass class BotSpeakingFrame(SystemFrame): - """Emitted upstream by transport outputs while the bot is still - speaking. This can be used, for example, to detect when a user is idle. 
That - is, while the bot is speaking we don't want to trigger any user idle timeout - since the user might be listening. + """Frame indicating the bot is currently speaking. + Emitted upstream and downstream by the BaseOutputTransport while the bot is + still speaking. This can be used, for example, to detect when a user is + idle. That is, while the bot is speaking we don't want to trigger any user + idle timeout since the user might be listening. """ pass @@ -658,21 +878,28 @@ class BotSpeakingFrame(SystemFrame): @dataclass class MetricsFrame(SystemFrame): - """Emitted by processor that can compute metrics like latencies.""" + """Frame containing performance metrics data. + + Emitted by processors that can compute metrics like latencies. + + Parameters: + data: List of metrics data collected by the processor. + """ data: List[MetricsData] @dataclass class FunctionCallFromLLM: - """Represents a function call returned by the LLM to be registered for execution. + """Represents a function call returned by the LLM. - Attributes: - function_name (str): The name of the function. - tool_call_id (str): A unique identifier for the function call. - arguments (Mapping[str, Any]): The arguments for the function. - context (OpenAILLMContext): The LLM context. + Represents a function call returned by the LLM to be registered for execution. + Parameters: + function_name: The name of the function to call. + tool_call_id: A unique identifier for the function call. + arguments: The arguments to pass to the function. + context: The LLM context when the function call was made. """ function_name: str @@ -683,15 +910,28 @@ class FunctionCallFromLLM: @dataclass class FunctionCallsStartedFrame(SystemFrame): - """A frame signaling that one or more function call execution is going to - start.""" + """Frame signaling that function call execution is starting. + + A frame signaling that one or more function call executions are going to + start. 
+ + Parameters: + function_calls: Sequence of function calls that will be executed. + """ function_calls: Sequence[FunctionCallFromLLM] @dataclass class FunctionCallInProgressFrame(SystemFrame): - """A frame signaling that a function call is in progress.""" + """Frame signaling that a function call is currently executing. + + Parameters: + function_name: Name of the function being executed. + tool_call_id: Unique identifier for this function call. + arguments: Arguments passed to the function. + cancel_on_interruption: Whether to cancel this call if interrupted. + """ function_name: str tool_call_id: str @@ -701,7 +941,12 @@ class FunctionCallInProgressFrame(SystemFrame): @dataclass class FunctionCallCancelFrame(SystemFrame): - """A frame to signal a function call has been cancelled.""" + """Frame signaling that a function call has been cancelled. + + Parameters: + function_name: Name of the function that was cancelled. + tool_call_id: Unique identifier for the cancelled function call. + """ function_name: str tool_call_id: str @@ -709,7 +954,12 @@ class FunctionCallCancelFrame(SystemFrame): @dataclass class FunctionCallResultProperties: - """Properties for a function call result frame.""" + """Properties for configuring function call result behavior. + + Parameters: + run_llm: Whether to run the LLM after receiving this result. + on_context_updated: Callback to execute when context is updated. + """ run_llm: Optional[bool] = None on_context_updated: Optional[Callable[[], Awaitable[None]]] = None @@ -717,7 +967,16 @@ class FunctionCallResultProperties: @dataclass class FunctionCallResultFrame(SystemFrame): - """A frame containing the result of an LLM function (tool) call.""" + """Frame containing the result of an LLM function call. + + Parameters: + function_name: Name of the function that was executed. + tool_call_id: Unique identifier for the function call. + arguments: Arguments that were passed to the function. + result: The result returned by the function. 
+ run_llm: Whether to run the LLM after this result. + properties: Additional properties for result handling. + """ function_name: str tool_call_id: str @@ -729,13 +988,23 @@ class FunctionCallResultFrame(SystemFrame): @dataclass class STTMuteFrame(SystemFrame): - """System frame to mute/unmute the STT service.""" + """Frame to mute/unmute the Speech-to-Text service. + + Parameters: + mute: Whether to mute (True) or unmute (False) the STT service. + """ mute: bool @dataclass class TransportMessageUrgentFrame(SystemFrame): + """Frame for urgent transport messages that need immediate processing. + + Parameters: + message: The urgent transport message payload. + """ + message: Any def __str__(self): @@ -744,10 +1013,18 @@ class TransportMessageUrgentFrame(SystemFrame): @dataclass class UserImageRequestFrame(SystemFrame): - """A frame to request an image from the given user. The frame might be + """Frame requesting an image from a specific user. + + A frame to request an image from the given user. The frame might be generated by a function call in which case the corresponding fields will be properly set. + Parameters: + user_id: Identifier of the user to request image from. + context: Optional context for the image request. + function_name: Name of function that generated this request (if any). + tool_call_id: Tool call ID if generated by function call. + video_source: Specific video source to capture from. """ user_id: str @@ -762,10 +1039,11 @@ class UserImageRequestFrame(SystemFrame): @dataclass class InputAudioRawFrame(SystemFrame, AudioRawFrame): - """A chunk of audio usually coming from an input transport. If the transport - supports multiple audio sources (e.g. multiple audio tracks) the source name - will be specified. + """Raw audio input frame from transport. + A chunk of audio usually coming from an input transport. If the transport + supports multiple audio sources (e.g. multiple audio tracks) the source name + will be specified in transport_source. 
""" def __post_init__(self): @@ -779,10 +1057,11 @@ class InputAudioRawFrame(SystemFrame, AudioRawFrame): @dataclass class InputImageRawFrame(SystemFrame, ImageRawFrame): - """An image usually coming from an input transport. If the transport - supports multiple video sources (e.g. multiple video tracks) the source name - will be specified. + """Raw image input frame from transport. + An image usually coming from an input transport. If the transport + supports multiple video sources (e.g. multiple video tracks) the source name + will be specified in transport_source. """ def __str__(self): @@ -792,7 +1071,13 @@ class InputImageRawFrame(SystemFrame, ImageRawFrame): @dataclass class UserAudioRawFrame(InputAudioRawFrame): - """A chunk of audio, usually coming from an input transport, associated to a user.""" + """Raw audio input frame associated with a specific user. + + A chunk of audio, usually coming from an input transport, associated with a user. + + Parameters: + user_id: Identifier of the user who provided this audio. + """ user_id: str = "" @@ -803,7 +1088,14 @@ class UserAudioRawFrame(InputAudioRawFrame): @dataclass class UserImageRawFrame(InputImageRawFrame): - """An image associated to a user.""" + """Raw image input frame associated with a specific user. + + An image associated with a user, potentially in response to an image request. + + Parameters: + user_id: Identifier of the user who provided this image. + request: The original image request frame if this is a response. + """ user_id: str = "" request: Optional[UserImageRequestFrame] = None @@ -815,7 +1107,13 @@ class UserImageRawFrame(InputImageRawFrame): @dataclass class VisionImageRawFrame(InputImageRawFrame): - """An image with an associated text to ask for a description of it.""" + """Image frame for vision/image analysis with associated text prompt. + + An image with an associated text to ask for a description of it. 
+ + Parameters: + text: Optional text prompt describing what to analyze in the image. 
+ """ text: Optional[str] = None @@ -826,17 +1124,18 @@ class VisionImageRawFrame(InputImageRawFrame): @dataclass class InputDTMFFrame(DTMFFrame, SystemFrame): - """A DTMF keypress input.""" + """DTMF keypress input frame from transport.""" pass @dataclass class OutputDTMFUrgentFrame(DTMFFrame, SystemFrame): - """A DTMF keypress output that will be sent right away. If your transport + """DTMF keypress output frame for immediate sending. + + A DTMF keypress output that will be sent right away. If your transport supports multiple dial-out destinations, use the `transport_destination` field to specify where the DTMF keypress should be sent. - """ pass @@ -849,12 +1148,13 @@ class OutputDTMFUrgentFrame(DTMFFrame, SystemFrame): @dataclass class EndFrame(ControlFrame): - """Indicates that a pipeline has ended and frame processors and pipelines + """Frame indicating pipeline has ended and should shut down. + + Indicates that a pipeline has ended and frame processors and pipelines should be shut down. If the transport receives this frame, it will stop sending frames to its output channel(s) and close all its threads. Note, - that this is a control frame, which means it will received in the order it - was sent (unline system frames). - + that this is a control frame, which means it will be received in the order it + was sent (unlike system frames). """ pass @@ -862,10 +1162,11 @@ class EndFrame(ControlFrame): @dataclass class StopFrame(ControlFrame): - """Indicates that a pipeline should be stopped but that the pipeline + """Frame indicating pipeline should stop but keep processors running. + + Indicates that a pipeline should be stopped but that the pipeline processors should be kept in a running state. This is normally queued from the pipeline task. 
- """ pass @@ -873,9 +1174,13 @@ class StopFrame(ControlFrame): @dataclass class HeartbeatFrame(ControlFrame): - """This frame is used by the pipeline task as a mechanism to know if the + """Frame used by pipeline task to monitor pipeline health. + + This frame is used by the pipeline task as a mechanism to know if the pipeline is running properly. + Parameters: + timestamp: Timestamp when the heartbeat was generated. """ timestamp: int @@ -883,11 +1188,15 @@ class HeartbeatFrame(ControlFrame): @dataclass class FrameProcessorPauseFrame(ControlFrame): - """This frame is used to pause frame processing for the given + """Frame to pause frame processing for a specific processor. + + This frame is used to pause frame processing for the given processor. Pausing frame processing will keep frames in the internal queue which will then be processed when frame processing is resumed with `FrameProcessorResumeFrame`. + Parameters: + processor: The frame processor to pause. """ processor: "FrameProcessor" @@ -895,10 +1204,14 @@ class FrameProcessorPauseFrame(ControlFrame): @dataclass class FrameProcessorResumeFrame(ControlFrame): - """This frame is used to resume frame processing for the given processor if + """Frame to resume frame processing for a specific processor. + + This frame is used to resume frame processing for the given processor if it was previously paused. After resuming frame processing all queued frames will be processed in the order received. + Parameters: + processor: The frame processor to resume. """ processor: "FrameProcessor" @@ -906,8 +1219,10 @@ class FrameProcessorResumeFrame(ControlFrame): @dataclass class LLMFullResponseStartFrame(ControlFrame): - """Used to indicate the beginning of an LLM response. Following by one or - more TextFrame and a final LLMFullResponseEndFrame. + """Frame indicating the beginning of an LLM response. + + Used to indicate the beginning of an LLM response. 
Followed by one or + more TextFrames and a final LLMFullResponseEndFrame. """ pass @@ -915,19 +1230,20 @@ class LLMFullResponseStartFrame(ControlFrame): @dataclass class LLMFullResponseEndFrame(ControlFrame): - """Indicates the end of an LLM response.""" + """Frame indicating the end of an LLM response.""" pass @dataclass class TTSStartedFrame(ControlFrame): - """Used to indicate the beginning of a TTS response. Following - TTSAudioRawFrames are part of the TTS response until an + """Frame indicating the beginning of a TTS response. + + Used to indicate the beginning of a TTS response. Following + TTSAudioRawFrames are part of the TTS response until a TTSStoppedFrame. These frames can be used for aggregating audio frames in a transport to optimize the size of frames sent to the session, without needing to control this in the TTS service. - """ pass @@ -935,37 +1251,54 @@ class TTSStartedFrame(ControlFrame): @dataclass class TTSStoppedFrame(ControlFrame): - """Indicates the end of a TTS response.""" + """Frame indicating the end of a TTS response.""" pass @dataclass class ServiceUpdateSettingsFrame(ControlFrame): - """A control frame containing a request to update service settings.""" + """Base frame for updating service settings. + + A control frame containing a request to update service settings. + + Parameters: + settings: Dictionary of setting name to value mappings. + """ settings: Mapping[str, Any] @dataclass class LLMUpdateSettingsFrame(ServiceUpdateSettingsFrame): + """Frame for updating LLM service settings.""" + pass @dataclass class TTSUpdateSettingsFrame(ServiceUpdateSettingsFrame): + """Frame for updating TTS service settings.""" + pass @dataclass class STTUpdateSettingsFrame(ServiceUpdateSettingsFrame): + """Frame for updating STT service settings.""" + pass @dataclass class VADParamsUpdateFrame(ControlFrame): - """A control frame containing a request to update VAD params. Intended + """Frame for updating VAD parameters. 
+ + A control frame containing a request to update VAD params. Intended to be pushed upstream from RTVI processor. + + Parameters: + params: New VAD parameters to apply. """ params: VADParams @@ -973,41 +1306,57 @@ class VADParamsUpdateFrame(ControlFrame): @dataclass class FilterControlFrame(ControlFrame): - """Base control frame for other audio filter frames.""" + """Base control frame for audio filter operations.""" pass @dataclass class FilterUpdateSettingsFrame(FilterControlFrame): - """Control frame to update filter settings.""" + """Frame for updating audio filter settings. + + Parameters: + settings: Dictionary of filter setting name to value mappings. + """ settings: Mapping[str, Any] @dataclass class FilterEnableFrame(FilterControlFrame): - """Control frame to enable or disable the filter at runtime.""" + """Frame for enabling/disabling audio filters at runtime. + + Parameters: + enable: Whether to enable (True) or disable (False) the filter. + """ enable: bool @dataclass class MixerControlFrame(ControlFrame): - """Base control frame for other audio mixer frames.""" + """Base control frame for audio mixer operations.""" pass @dataclass class MixerUpdateSettingsFrame(MixerControlFrame): - """Control frame to update mixer settings.""" + """Frame for updating audio mixer settings. + + Parameters: + settings: Dictionary of mixer setting name to value mappings. + """ settings: Mapping[str, Any] @dataclass class MixerEnableFrame(MixerControlFrame): - """Control frame to enable or disable the mixer at runtime.""" + """Frame for enabling/disabling audio mixer at runtime. + + Parameters: + enable: Whether to enable (True) or disable (False) the mixer. 
+ """ enable: bool diff --git a/src/pipecat/metrics/metrics.py b/src/pipecat/metrics/metrics.py index fbd5f9c8c..4884d14f6 100644 --- a/src/pipecat/metrics/metrics.py +++ b/src/pipecat/metrics/metrics.py @@ -1,22 +1,64 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""Metrics data models for Pipecat framework. + +This module defines Pydantic models for various types of metrics data +collected throughout the pipeline, including timing, token usage, and +processing statistics. +""" + from typing import Optional from pydantic import BaseModel class MetricsData(BaseModel): + """Base class for all metrics data. + + Parameters: + processor: Name of the processor generating the metrics. + model: Optional model name associated with the metrics. + """ + processor: str model: Optional[str] = None class TTFBMetricsData(MetricsData): + """Time To First Byte (TTFB) metrics data. + + Parameters: + value: TTFB measurement in seconds. + """ + value: float class ProcessingMetricsData(MetricsData): + """General processing time metrics data. + + Parameters: + value: Processing time measurement in seconds. + """ + value: float class LLMTokenUsage(BaseModel): + """Token usage statistics for LLM operations. + + Parameters: + prompt_tokens: Number of tokens in the input prompt. + completion_tokens: Number of tokens in the generated completion. + total_tokens: Total number of tokens used (prompt + completion). + cache_read_input_tokens: Number of tokens read from cache, if applicable. + cache_creation_input_tokens: Number of tokens used to create cache entries, if applicable. + """ + prompt_tokens: int completion_tokens: int total_tokens: int @@ -26,15 +68,35 @@ class LLMTokenUsage(BaseModel): class LLMUsageMetricsData(MetricsData): + """LLM token usage metrics data. + + Parameters: + value: Token usage statistics for the LLM operation. 
+ """ + value: LLMTokenUsage class TTSUsageMetricsData(MetricsData): + """Text-to-Speech usage metrics data. + + Parameters: + value: Number of characters processed by TTS. + """ + value: int class SmartTurnMetricsData(MetricsData): - """Metrics data for smart turn predictions.""" + """Metrics data for smart turn predictions. + + Parameters: + is_complete: Whether the turn is predicted to be complete. + probability: Confidence probability of the turn completion prediction. + inference_time_ms: Time taken for inference in milliseconds. + server_total_time_ms: Total server processing time in milliseconds. + e2e_processing_time_ms: End-to-end processing time in milliseconds. + """ is_complete: bool probability: float diff --git a/src/pipecat/observers/base_observer.py b/src/pipecat/observers/base_observer.py index 077f6986b..d8a67a373 100644 --- a/src/pipecat/observers/base_observer.py +++ b/src/pipecat/observers/base_observer.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Base observer classes for monitoring frame flow in the Pipecat pipeline. + +This module provides the foundation for observing frame transfers between +processors without modifying the pipeline structure. Observers can be used +for logging, debugging, analytics, and monitoring pipeline behavior. +""" + from abc import abstractmethod from dataclasses import dataclass @@ -18,19 +25,19 @@ if TYPE_CHECKING: @dataclass class FramePushed: - """Represents an event where a frame is pushed from one processor to another - within the pipeline. + """Event data for frame transfers between processors in the pipeline. - This data structure is typically used by observers to track the flow of - frames through the pipeline for logging, debugging, or analytics purposes. - - Attributes: - source (FrameProcessor): The processor sending the frame. - destination (FrameProcessor): The processor receiving the frame. - frame (Frame): The frame being transferred. 
- direction (FrameDirection): The direction of the transfer (e.g., downstream or upstream). - timestamp (int): The time when the frame was pushed, based on the pipeline clock. + Represents an event where a frame is pushed from one processor to another + within the pipeline. This data structure is typically used by observers + to track the flow of frames through the pipeline for logging, debugging, + or analytics purposes. + Parameters: + source: The processor sending the frame. + destination: The processor receiving the frame. + frame: The frame being transferred. + direction: The direction of the transfer (e.g., downstream or upstream). + timestamp: The time when the frame was pushed, based on the pipeline clock. """ source: "FrameProcessor" @@ -41,11 +48,12 @@ class FramePushed: class BaseObserver(BaseObject): - """This is the base class for pipeline frame observers. Observers can view - all the frames that go through the pipeline without the need to inject - processors in the pipeline. This can be useful, for example, to implement - frame loggers or debuggers among other things. + """Base class for pipeline frame observers. + Observers can view all frames that flow through the pipeline without + needing to inject processors into the pipeline structure. This enables + non-intrusive monitoring capabilities such as frame logging, debugging, + performance analysis, and analytics collection. """ @abstractmethod @@ -57,7 +65,6 @@ class BaseObserver(BaseObject): transferred through the pipeline. Args: - data (FramePushed): The event data containing details about the frame transfer. - + data: The event data containing details about the frame transfer. 
""" pass diff --git a/src/pipecat/observers/loggers/debug_log_observer.py b/src/pipecat/observers/loggers/debug_log_observer.py index 1b75a3f7a..a8cf7d2d3 100644 --- a/src/pipecat/observers/loggers/debug_log_observer.py +++ b/src/pipecat/observers/loggers/debug_log_observer.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Debug logging observer for frame activity monitoring. + +This module provides a debug observer that logs detailed frame activity +to the console, making it useful for debugging pipeline behavior and +understanding frame flow between processors. +""" + from dataclasses import fields, is_dataclass from enum import Enum, auto from typing import Dict, Optional, Set, Tuple, Type, Union @@ -16,7 +23,12 @@ from pipecat.processors.frame_processor import FrameDirection class FrameEndpoint(Enum): - """Specifies which endpoint (source or destination) to filter on.""" + """Specifies which endpoint (source or destination) to filter on. + + Parameters: + SOURCE: Filter on the source component that is pushing the frame. + DESTINATION: Filter on the destination component receiving the frame. + """ SOURCE = auto() DESTINATION = auto() @@ -28,44 +40,36 @@ class DebugLogObserver(BaseObserver): Automatically extracts and formats data from any frame type, making it useful for debugging pipeline behavior without needing frame-specific observers. - Args: - frame_types: Optional tuple of frame types to log, or a dict with frame type - filters. If None, logs all frame types. - exclude_fields: Optional set of field names to exclude from logging. 
- Examples: - Log all frames from all services: - ```python - observers = DebugLogObserver() - ``` + Log all frames from all services:: - Log specific frame types from any source/destination: - ```python - from pipecat.frames.frames import TranscriptionFrame, InterimTranscriptionFrame - observers=[ - DebugLogObserver(frame_types=(LLMTextFrame,TranscriptionFrame,)), - ], - ``` + observers = DebugLogObserver() - Log frames with specific source/destination filters: - ```python - from pipecat.frames.frames import StartInterruptionFrame, UserStartedSpeakingFrame, LLMTextFrame - from pipecat.transports.base_output_transport import BaseOutputTransport - from pipecat.services.stt_service import STTService + Log specific frame types from any source/destination:: - observers=[ - DebugLogObserver( - frame_types={ - # Only log StartInterruptionFrame when source is BaseOutputTransport - StartInterruptionFrame: (BaseOutputTransport, FrameEndpoint.SOURCE), - # Only log UserStartedSpeakingFrame when destination is STTService - UserStartedSpeakingFrame: (STTService, FrameEndpoint.DESTINATION), - # Log LLMTextFrame regardless of source or destination type - LLMTextFrame: None, - } - ), - ], - ``` + from pipecat.frames.frames import LLMTextFrame, TranscriptionFrame + observers=[ + DebugLogObserver(frame_types=(LLMTextFrame,TranscriptionFrame,)), + ] + + Log frames with specific source/destination filters:: + + from pipecat.frames.frames import StartInterruptionFrame, UserStartedSpeakingFrame, LLMTextFrame + from pipecat.transports.base_output_transport import BaseOutputTransport + from pipecat.services.stt_service import STTService + + observers=[ + DebugLogObserver( + frame_types={ + # Only log StartInterruptionFrame when source is BaseOutputTransport + StartInterruptionFrame: (BaseOutputTransport, FrameEndpoint.SOURCE), + # Only log UserStartedSpeakingFrame when destination is STTService + UserStartedSpeakingFrame: (STTService, FrameEndpoint.DESTINATION), + # Log
LLMTextFrame regardless of source or destination type + LLMTextFrame: None, + } + ), + ] """ def __init__( @@ -79,14 +83,17 @@ class DebugLogObserver(BaseObserver): """Initialize the debug log observer. Args: - frame_types: Tuple of frame types to log, or a dict mapping frame types to - filter configurations. Filter configs can be: - - None to log all instances of the frame type - - A tuple of (service_type, endpoint) to filter on a specific service - and endpoint (SOURCE or DESTINATION) - If None is provided instead of a tuple/dict, log all frames. - exclude_fields: Set of field names to exclude from logging. If None, only binary - data fields are excluded. + frame_types: Frame types to log. Can be: + + - Tuple of frame types to log all instances + - Dict mapping frame types to filter configurations + - None to log all frames + + Filter configurations can be None (log all instances) or a tuple + of (service_type, endpoint) to filter on specific services. + exclude_fields: Field names to exclude from logging. Defaults to + excluding binary data fields like 'audio', 'image', 'images'. + **kwargs: Additional arguments passed to parent class. """ super().__init__(**kwargs) @@ -113,14 +120,7 @@ class DebugLogObserver(BaseObserver): ) def _format_value(self, value): - """Format a value for logging. - - Args: - value: The value to format. - - Returns: - str: A string representation of the value suitable for logging. - """ + """Format a value for logging.""" if value is None: return "None" elif isinstance(value, str): @@ -143,16 +143,7 @@ class DebugLogObserver(BaseObserver): return str(value) def _should_log_frame(self, frame, src, dst): - """Determine if a frame should be logged based on filters. 
- - Args: - frame: The frame being processed - src: The source component - dst: The destination component - - Returns: - bool: True if the frame should be logged, False otherwise - """ + """Determine if a frame should be logged based on filters.""" # If no filters, log all frames if not self.frame_filters: return True diff --git a/src/pipecat/observers/loggers/llm_log_observer.py b/src/pipecat/observers/loggers/llm_log_observer.py index a6675b5c0..53a0ac484 100644 --- a/src/pipecat/observers/loggers/llm_log_observer.py +++ b/src/pipecat/observers/loggers/llm_log_observer.py @@ -4,6 +4,8 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""LLM logging observer for Pipecat.""" + from loguru import logger from pipecat.frames.frames import ( @@ -34,10 +36,15 @@ class LLMLogObserver(BaseObserver): This allows you to track when the LLM starts responding, what it generates, and when it finishes. - """ async def on_push_frame(self, data: FramePushed): + """Handle frame push events and log LLM-related activities. + + Args: + data: The frame push event data containing source, destination, + frame, direction, and timestamp information. + """ src = data.source dst = data.destination frame = data.frame diff --git a/src/pipecat/observers/loggers/transcription_log_observer.py b/src/pipecat/observers/loggers/transcription_log_observer.py index 8ca1d9c9b..1f1a74388 100644 --- a/src/pipecat/observers/loggers/transcription_log_observer.py +++ b/src/pipecat/observers/loggers/transcription_log_observer.py @@ -4,6 +4,12 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Transcription logging observer for Pipecat. + +This module provides an observer that logs transcription frames to the console, +allowing developers to monitor speech-to-text activity in real-time. 
+""" + from loguru import logger from pipecat.frames.frames import ( @@ -17,17 +23,23 @@ from pipecat.services.stt_service import STTService class TranscriptionLogObserver(BaseObserver): """Observer to log transcription activity to the console. - Logs all frame instances (only from STT service) of: - - - TranscriptionFrame - - InterimTranscriptionFrame - - This allows you to track when the LLM starts responding, what it generates, - and when it finishes. + Monitors and logs all transcription frames from STT services, including + both final transcriptions and interim results. This allows developers + to track speech recognition activity and debug transcription issues. + Only processes frames from STTService instances to avoid logging + unrelated transcription frames from other sources. """ async def on_push_frame(self, data: FramePushed): + """Handle frame push events and log transcription frames. + + Logs TranscriptionFrame and InterimTranscriptionFrame instances + with timestamps and user information for debugging purposes. + + Args: + data: Frame push event data containing source, frame, and timestamp. + """ src = data.source frame = data.frame timestamp = data.timestamp diff --git a/src/pipecat/observers/loggers/user_bot_latency_log_observer.py b/src/pipecat/observers/loggers/user_bot_latency_log_observer.py index f601a9e9e..eb699f2bb 100644 --- a/src/pipecat/observers/loggers/user_bot_latency_log_observer.py +++ b/src/pipecat/observers/loggers/user_bot_latency_log_observer.py @@ -4,6 +4,8 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Observer for measuring user-to-bot response latency.""" + import time from loguru import logger @@ -18,19 +20,28 @@ from pipecat.processors.frame_processor import FrameDirection class UserBotLatencyLogObserver(BaseObserver): - """Observer that logs the latency between when the user stops speaking and - when the bot starts speaking. - - This helps measure how quickly the AI services respond. 
+ """Observer that measures time between user stopping speech and bot starting speech. + This helps measure how quickly the AI services respond by tracking + conversation turn timing and logging latency metrics. """ def __init__(self): + """Initialize the latency observer. + + Sets up tracking for processed frames and user speech timing + to calculate response latencies. + """ super().__init__() self._processed_frames = set() self._user_stopped_time = 0 async def on_push_frame(self, data: FramePushed): + """Process frames to track speech timing and calculate latency. + + Args: + data: Frame push event containing the frame and direction information. + """ # Only process downstream frames if data.direction != FrameDirection.DOWNSTREAM: return diff --git a/src/pipecat/observers/turn_tracking_observer.py b/src/pipecat/observers/turn_tracking_observer.py index 04b5ad92b..525101f57 100644 --- a/src/pipecat/observers/turn_tracking_observer.py +++ b/src/pipecat/observers/turn_tracking_observer.py @@ -4,6 +4,12 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Turn tracking observer for conversation flow monitoring. + +This module provides an observer that monitors conversation turns in a pipeline, +tracking when turns start and end based on user and bot speech patterns. +""" + import asyncio from collections import deque @@ -23,15 +29,30 @@ from pipecat.observers.base_observer import BaseObserver, FramePushed class TurnTrackingObserver(BaseObserver): """Observer that tracks conversation turns in a pipeline. + This observer monitors the flow of conversation by tracking when turns + start and end based on user and bot speaking patterns. It handles + interruptions, timeouts, and maintains turn state throughout the pipeline. 
+ Turn tracking logic: + - The first turn starts immediately when the pipeline starts (StartFrame) - Subsequent turns start when the user starts speaking - A turn ends when the bot stops speaking and either: + - The user starts speaking again - A timeout period elapses with no more bot speech """ def __init__(self, max_frames=100, turn_end_timeout_secs=2.5, **kwargs): + """Initialize the turn tracking observer. + + Args: + max_frames: Maximum number of frame IDs to keep in history for + duplicate detection. Defaults to 100. + turn_end_timeout_secs: Timeout in seconds after bot stops speaking + before automatically ending the turn. Defaults to 2.5. + **kwargs: Additional arguments passed to the parent observer. + """ super().__init__(**kwargs) self._turn_count = 0 self._is_turn_active = False @@ -49,7 +70,11 @@ class TurnTrackingObserver(BaseObserver): self._register_event_handler("on_turn_ended") async def on_push_frame(self, data: FramePushed): - """Process frame events for turn tracking.""" + """Process frame events for turn tracking. + + Args: + data: Frame push event data containing the frame and metadata. + """ # Skip already processed frames if data.frame.id in self._processed_frames: return diff --git a/src/pipecat/pipeline/base_pipeline.py b/src/pipecat/pipeline/base_pipeline.py index b0a08da80..a3b1e0b3e 100644 --- a/src/pipecat/pipeline/base_pipeline.py +++ b/src/pipecat/pipeline/base_pipeline.py @@ -4,6 +4,8 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Base pipeline implementation for frame processing.""" + from abc import abstractmethod from typing import List @@ -11,9 +13,24 @@ from pipecat.processors.frame_processor import FrameProcessor class BasePipeline(FrameProcessor): + """Base class for all pipeline implementations. + + Provides the foundation for pipeline processors that need to support + metrics collection from their contained processors. 
+ """ + def __init__(self): + """Initialize the base pipeline.""" super().__init__() @abstractmethod def processors_with_metrics(self) -> List[FrameProcessor]: + """Return processors that can generate metrics. + + Implementing classes should collect and return all processors within + their pipeline that support metrics generation. + + Returns: + List of frame processors that support metrics collection. + """ pass diff --git a/src/pipecat/pipeline/base_task.py b/src/pipecat/pipeline/base_task.py index 6cdd88c9b..2751f2b7b 100644 --- a/src/pipecat/pipeline/base_task.py +++ b/src/pipecat/pipeline/base_task.py @@ -4,6 +4,12 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Base pipeline task implementation for managing pipeline execution. + +This module provides the abstract base class and configuration for pipeline +tasks that manage the lifecycle and execution of frame processing pipelines. +""" + import asyncio from abc import abstractmethod from dataclasses import dataclass @@ -15,44 +21,81 @@ from pipecat.utils.base_object import BaseObject @dataclass class PipelineTaskParams: - """Specific configuration for the pipeline task.""" + """Configuration parameters for pipeline task execution. + + Parameters: + loop: The asyncio event loop to use for task execution. + """ loop: asyncio.AbstractEventLoop class BasePipelineTask(BaseObject): + """Abstract base class for pipeline task implementations. + + Defines the interface for managing pipeline execution lifecycle, + including starting, stopping, and frame queuing operations. + """ + @abstractmethod def has_finished(self) -> bool: - """Indicates whether the tasks has finished. That is, all processors - have stopped. + """Check if the pipeline task has finished execution. + Returns: + True if all processors have stopped and the task is complete. 
""" pass @abstractmethod async def stop_when_done(self): - """This is a helper function that sends an EndFrame to the pipeline in - order to stop the task after everything in it has been processed. + """Schedule the pipeline to stop after processing all queued frames. + Implementing classes should send an EndFrame or equivalent signal to + gracefully terminate the pipeline once all current processing is complete. """ pass @abstractmethod async def cancel(self): - """Stops the running pipeline immediately.""" + """Immediately stop the running pipeline. + + Implementing classes should cancel all running tasks and stop frame + processing without waiting for completion. + """ pass @abstractmethod async def run(self, params: PipelineTaskParams): - """Starts running the given pipeline.""" + """Start and run the pipeline with the given parameters. + + Implementing classes should initialize and execute the pipeline using + the provided configuration parameters. + + Args: + params: Configuration parameters for pipeline execution. + """ pass @abstractmethod async def queue_frame(self, frame: Frame): - """Queue a frame to be pushed down the pipeline.""" + """Queue a single frame for processing by the pipeline. + + Implementing classes should add the frame to their processing queue + for downstream handling. + + Args: + frame: The frame to be processed. + """ pass @abstractmethod async def queue_frames(self, frames: Iterable[Frame] | AsyncIterable[Frame]): - """Queues multiple frames to be pushed down the pipeline.""" + """Queue multiple frames for processing by the pipeline. + + Implementing classes should process the iterable/async iterable and + add all frames to their processing queue. + + Args: + frames: An iterable or async iterable of frames to be processed. 
+ """ pass diff --git a/src/pipecat/pipeline/parallel_pipeline.py b/src/pipecat/pipeline/parallel_pipeline.py index 300492cc5..57b3a82c3 100644 --- a/src/pipecat/pipeline/parallel_pipeline.py +++ b/src/pipecat/pipeline/parallel_pipeline.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Parallel pipeline implementation for concurrent frame processing. + +This module provides a parallel pipeline that processes frames through multiple +sub-pipelines concurrently, with coordination for system frames and proper +handling of pipeline lifecycle events. +""" + import asyncio from itertools import chain from typing import Awaitable, Callable, Dict, List @@ -25,16 +32,34 @@ from pipecat.utils.asyncio.watchdog_queue import WatchdogQueue class ParallelPipelineSource(FrameProcessor): + """Source processor for parallel pipeline branches. + + Handles frame routing for parallel pipeline inputs, directing system frames + to the parent push function and other upstream frames to a queue for processing. + """ + def __init__( self, upstream_queue: asyncio.Queue, push_frame_func: Callable[[Frame, FrameDirection], Awaitable[None]], ): + """Initialize the parallel pipeline source. + + Args: + upstream_queue: Queue for collecting upstream frames from this branch. + push_frame_func: Function to push frames to the parent parallel pipeline. + """ super().__init__() self._up_queue = upstream_queue self._push_frame_func = push_frame_func async def process_frame(self, frame: Frame, direction: FrameDirection): + """Process frames with special handling for system frames. + + Args: + frame: The frame to process. + direction: The direction of frame flow. + """ await super().process_frame(frame, direction) match direction: @@ -48,16 +73,34 @@ class ParallelPipelineSource(FrameProcessor): class ParallelPipelineSink(FrameProcessor): + """Sink processor for parallel pipeline branches. 
+ + Handles frame routing for parallel pipeline outputs, directing system frames + to the parent push function and other downstream frames to a queue for coordination. + """ + def __init__( self, downstream_queue: asyncio.Queue, push_frame_func: Callable[[Frame, FrameDirection], Awaitable[None]], ): + """Initialize the parallel pipeline sink. + + Args: + downstream_queue: Queue for collecting downstream frames from this branch. + push_frame_func: Function to push frames to the parent parallel pipeline. + """ super().__init__() self._down_queue = downstream_queue self._push_frame_func = push_frame_func async def process_frame(self, frame: Frame, direction: FrameDirection): + """Process frames with special handling for system frames. + + Args: + frame: The frame to process. + direction: The direction of frame flow. + """ await super().process_frame(frame, direction) match direction: @@ -71,7 +114,24 @@ class ParallelPipelineSink(FrameProcessor): class ParallelPipeline(BasePipeline): + """Pipeline that processes frames through multiple sub-pipelines concurrently. + + Creates multiple parallel processing branches from the provided processor lists, + coordinating frame flow and ensuring proper synchronization of lifecycle events + like EndFrames. Each branch runs independently while system frames are handled + specially to maintain pipeline coordination. + """ + def __init__(self, *args): + """Initialize the parallel pipeline with processor lists. + + Args: + *args: Variable number of processor lists, each becoming a parallel branch. + + Raises: + Exception: If no processor lists are provided. + TypeError: If any argument is not a list of processors. + """ super().__init__() if len(args) == 0: @@ -93,6 +153,11 @@ class ParallelPipeline(BasePipeline): # def processors_with_metrics(self) -> List[FrameProcessor]: + """Collect processors that can generate metrics from all parallel branches. 
+ + Returns: + List of frame processors that support metrics collection from all branches. + """ return list(chain.from_iterable(p.processors_with_metrics() for p in self._pipelines)) # @@ -100,6 +165,14 @@ class ParallelPipeline(BasePipeline): # async def setup(self, setup: FrameProcessorSetup): + """Set up the parallel pipeline and all its branches. + + Args: + setup: Configuration for frame processor setup. + + Raises: + TypeError: If any processor list argument is not actually a list. + """ await super().setup(setup) self._up_queue = WatchdogQueue(setup.task_manager) @@ -129,12 +202,19 @@ class ParallelPipeline(BasePipeline): await asyncio.gather(*[s.setup(setup) for s in self._sinks]) async def cleanup(self): + """Clean up the parallel pipeline and all its branches.""" await super().cleanup() await asyncio.gather(*[s.cleanup() for s in self._sources]) await asyncio.gather(*[p.cleanup() for p in self._pipelines]) await asyncio.gather(*[s.cleanup() for s in self._sinks]) async def process_frame(self, frame: Frame, direction: FrameDirection): + """Process frames through all parallel branches with lifecycle coordination. + + Args: + frame: The frame to process. + direction: The direction of frame flow. + """ await super().process_frame(frame, direction) if isinstance(frame, StartFrame): @@ -159,9 +239,11 @@ class ParallelPipeline(BasePipeline): await self._stop() async def _start(self, frame: StartFrame): + """Start the parallel pipeline processing tasks.""" await self._create_tasks() async def _stop(self): + """Stop all parallel pipeline processing tasks.""" if self._up_task: # The up task doesn't receive an EndFrame, so we just cancel it. 
await self.cancel_task(self._up_task) @@ -174,6 +256,7 @@ class ParallelPipeline(BasePipeline): self._down_task = None async def _cancel(self): + """Cancel all parallel pipeline processing tasks.""" if self._up_task: await self.cancel_task(self._up_task) self._up_task = None @@ -182,34 +265,44 @@ class ParallelPipeline(BasePipeline): self._down_task = None async def _create_tasks(self): + """Create upstream and downstream processing tasks if not already running.""" if not self._up_task: self._up_task = self.create_task(self._process_up_queue()) if not self._down_task: self._down_task = self.create_task(self._process_down_queue()) async def _drain_queues(self): + """Drain all frames from upstream and downstream queues.""" while not self._up_queue.empty: await self._up_queue.get() while not self._down_queue.empty: await self._down_queue.get() async def _handle_interruption(self): + """Handle interruption by cancelling tasks, draining queues, and restarting.""" await self._cancel() await self._drain_queues() await self._create_tasks() async def _parallel_push_frame(self, frame: Frame, direction: FrameDirection): + """Push frames while avoiding duplicates using frame ID tracking.""" if frame.id not in self._seen_ids: self._seen_ids.add(frame.id) await self.push_frame(frame, direction) async def _process_up_queue(self): + """Process upstream frames from all parallel branches.""" while True: frame = await self._up_queue.get() await self._parallel_push_frame(frame, FrameDirection.UPSTREAM) self._up_queue.task_done() async def _process_down_queue(self): + """Process downstream frames with EndFrame coordination. + + Coordinates EndFrames to ensure they are only pushed downstream once + all parallel branches have completed processing them.
+ """ running = True while running: frame = await self._down_queue.get() diff --git a/src/pipecat/pipeline/pipeline.py b/src/pipecat/pipeline/pipeline.py index c10a32e0a..ee96dd7fc 100644 --- a/src/pipecat/pipeline/pipeline.py +++ b/src/pipecat/pipeline/pipeline.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Pipeline implementation for connecting and managing frame processors. + +This module provides the main Pipeline class that connects frame processors +in sequence and manages frame flow between them, along with helper classes +for pipeline source and sink operations. +""" + from typing import Callable, Coroutine, List from pipecat.frames.frames import Frame @@ -12,11 +19,29 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor, F class PipelineSource(FrameProcessor): + """Source processor that forwards frames to an upstream handler. + + This processor acts as the entry point for a pipeline, forwarding + downstream frames to the next processor and upstream frames to a + provided upstream handler function. + """ + def __init__(self, upstream_push_frame: Callable[[Frame, FrameDirection], Coroutine]): + """Initialize the pipeline source. + + Args: + upstream_push_frame: Coroutine function to handle upstream frames. + """ super().__init__() self._upstream_push_frame = upstream_push_frame async def process_frame(self, frame: Frame, direction: FrameDirection): + """Process frames and route them based on direction. + + Args: + frame: The frame to process. + direction: The direction of frame flow. + """ await super().process_frame(frame, direction) match direction: @@ -27,11 +52,29 @@ class PipelineSource(FrameProcessor): class PipelineSink(FrameProcessor): + """Sink processor that forwards frames to a downstream handler. + + This processor acts as the exit point for a pipeline, forwarding + upstream frames to the previous processor and downstream frames to a + provided downstream handler function. 
+ """ + def __init__(self, downstream_push_frame: Callable[[Frame, FrameDirection], Coroutine]): + """Initialize the pipeline sink. + + Args: + downstream_push_frame: Coroutine function to handle downstream frames. + """ super().__init__() self._downstream_push_frame = downstream_push_frame async def process_frame(self, frame: Frame, direction: FrameDirection): + """Process frames and route them based on direction. + + Args: + frame: The frame to process. + direction: The direction of frame flow. + """ await super().process_frame(frame, direction) match direction: @@ -42,7 +85,19 @@ class PipelineSink(FrameProcessor): class Pipeline(BasePipeline): + """Main pipeline implementation that connects frame processors in sequence. + + Creates a linear chain of frame processors with automatic source and sink + processors for external frame handling. Manages processor lifecycle and + provides metrics collection from contained processors. + """ + def __init__(self, processors: List[FrameProcessor]): + """Initialize the pipeline with a list of processors. + + Args: + processors: List of frame processors to connect in sequence. + """ super().__init__() # Add a source and a sink queue so we can forward frames upstream and @@ -58,6 +113,14 @@ class Pipeline(BasePipeline): # def processors_with_metrics(self): + """Return processors that can generate metrics. + + Recursively collects all processors that support metrics generation, + including those from nested pipelines. + + Returns: + List of frame processors that can generate metrics. + """ services = [] for p in self._processors: if isinstance(p, BasePipeline): @@ -71,14 +134,26 @@ class Pipeline(BasePipeline): # async def setup(self, setup: FrameProcessorSetup): + """Set up the pipeline and all contained processors. + + Args: + setup: Configuration for frame processor setup. 
+ """ await super().setup(setup) await self._setup_processors(setup) async def cleanup(self): + """Clean up the pipeline and all contained processors.""" await super().cleanup() await self._cleanup_processors() async def process_frame(self, frame: Frame, direction: FrameDirection): + """Process frames by routing them through the pipeline. + + Args: + frame: The frame to process. + direction: The direction of frame flow. + """ await super().process_frame(frame, direction) if direction == FrameDirection.DOWNSTREAM: @@ -87,14 +162,17 @@ class Pipeline(BasePipeline): await self._sink.queue_frame(frame, FrameDirection.UPSTREAM) async def _setup_processors(self, setup: FrameProcessorSetup): + """Set up all processors in the pipeline.""" for p in self._processors: await p.setup(setup) async def _cleanup_processors(self): + """Clean up all processors in the pipeline.""" for p in self._processors: await p.cleanup() def _link_processors(self): + """Link all processors in sequence and set their parent.""" prev = self._processors[0] for curr in self._processors[1:]: prev.set_parent(self) diff --git a/src/pipecat/pipeline/runner.py b/src/pipecat/pipeline/runner.py index b789fc7ba..d64bf9495 100644 --- a/src/pipecat/pipeline/runner.py +++ b/src/pipecat/pipeline/runner.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Pipeline runner for managing pipeline task execution. + +This module provides the PipelineRunner class that handles the execution +of pipeline tasks with signal handling, garbage collection, and lifecycle +management. +""" + import asyncio import gc import signal @@ -17,6 +24,13 @@ from pipecat.utils.base_object import BaseObject class PipelineRunner(BaseObject): + """Manages the execution of pipeline tasks with lifecycle and signal handling. + + Provides a high-level interface for running pipeline tasks with automatic + signal handling (SIGINT/SIGTERM), optional garbage collection, and proper + cleanup of resources. 
+ """ + def __init__( self, *, @@ -25,6 +39,14 @@ class PipelineRunner(BaseObject): force_gc: bool = False, loop: Optional[asyncio.AbstractEventLoop] = None, ): + """Initialize the pipeline runner. + + Args: + name: Optional name for the runner instance. + handle_sigint: Whether to automatically handle SIGINT/SIGTERM signals. + force_gc: Whether to force garbage collection after task completion. + loop: Event loop to use. If None, uses the current running loop. + """ super().__init__(name=name) self._tasks = {} @@ -36,6 +58,11 @@ class PipelineRunner(BaseObject): self._setup_sigint() async def run(self, task: PipelineTask): + """Run a pipeline task to completion. + + Args: + task: The pipeline task to execute. + """ logger.debug(f"Runner {self} started running {task}") self._tasks[task.name] = task params = PipelineTaskParams(loop=self._loop) @@ -56,27 +83,33 @@ class PipelineRunner(BaseObject): logger.debug(f"Runner {self} finished running {task}") async def stop_when_done(self): + """Schedule all running tasks to stop when their current processing is complete.""" logger.debug(f"Runner {self} scheduled to stop when all tasks are done") await asyncio.gather(*[t.stop_when_done() for t in self._tasks.values()]) async def cancel(self): + """Cancel all running tasks immediately.""" logger.debug(f"Cancelling runner {self}") await asyncio.gather(*[t.cancel() for t in self._tasks.values()]) def _setup_sigint(self): + """Set up signal handlers for graceful shutdown.""" loop = asyncio.get_running_loop() loop.add_signal_handler(signal.SIGINT, lambda *args: self._sig_handler()) loop.add_signal_handler(signal.SIGTERM, lambda *args: self._sig_handler()) def _sig_handler(self): + """Handle interrupt signals by cancelling all tasks.""" if not self._sig_task: self._sig_task = asyncio.create_task(self._sig_cancel()) async def _sig_cancel(self): + """Cancel all running tasks due to signal interruption.""" logger.warning(f"Interruption detected. 
Cancelling runner {self}") await self.cancel() def _gc_collect(self): + """Force garbage collection and log results.""" collected = gc.collect() logger.debug(f"Garbage collector: collected {collected} objects.") logger.debug(f"Garbage collector: uncollectable objects {gc.garbage}") diff --git a/src/pipecat/pipeline/sync_parallel_pipeline.py b/src/pipecat/pipeline/sync_parallel_pipeline.py index 006290710..6454c3929 100644 --- a/src/pipecat/pipeline/sync_parallel_pipeline.py +++ b/src/pipecat/pipeline/sync_parallel_pipeline.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Synchronous parallel pipeline implementation for concurrent frame processing. + +This module provides a pipeline that processes frames through multiple parallel +pipelines simultaneously, synchronizing their output to maintain frame ordering +and prevent duplicate processing. +""" + import asyncio from dataclasses import dataclass from itertools import chain @@ -20,17 +27,38 @@ from pipecat.utils.asyncio.watchdog_queue import WatchdogQueue @dataclass class SyncFrame(ControlFrame): - """This frame is used to know when the internal pipelines have finished.""" + """Control frame used to synchronize parallel pipeline processing. + + This frame is sent through parallel pipelines to determine when the + internal pipelines have finished processing a batch of frames. + """ pass class SyncParallelPipelineSource(FrameProcessor): + """Source processor for synchronous parallel pipeline processing. + + Routes frames to parallel pipelines and collects upstream responses + for synchronization purposes. + """ + def __init__(self, upstream_queue: asyncio.Queue): + """Initialize the sync parallel pipeline source. + + Args: + upstream_queue: Queue for collecting upstream frames from the pipeline. + """ super().__init__() self._up_queue = upstream_queue async def process_frame(self, frame: Frame, direction: FrameDirection): + """Process frames and route them based on direction. 
+ + Args: + frame: The frame to process. + direction: The direction of frame flow. + """ await super().process_frame(frame, direction) match direction: @@ -41,11 +69,28 @@ class SyncParallelPipelineSource(FrameProcessor): class SyncParallelPipelineSink(FrameProcessor): + """Sink processor for synchronous parallel pipeline processing. + + Collects downstream frames from parallel pipelines and routes + upstream frames back through the pipeline. + """ + def __init__(self, downstream_queue: asyncio.Queue): + """Initialize the sync parallel pipeline sink. + + Args: + downstream_queue: Queue for collecting downstream frames from the pipeline. + """ super().__init__() self._down_queue = downstream_queue async def process_frame(self, frame: Frame, direction: FrameDirection): + """Process frames and route them based on direction. + + Args: + frame: The frame to process. + direction: The direction of frame flow. + """ await super().process_frame(frame, direction) match direction: @@ -56,7 +101,28 @@ class SyncParallelPipelineSink(FrameProcessor): class SyncParallelPipeline(BasePipeline): + """Pipeline that processes frames through multiple parallel pipelines synchronously. + + Creates multiple parallel processing paths that all receive the same input frames + and produces synchronized output. Each parallel path is a separate pipeline that + processes frames independently, with synchronization points to ensure consistent + ordering and prevent duplicate frame processing. + + The pipeline uses SyncFrame control frames to coordinate between parallel paths + and ensure all paths have completed processing before moving to the next frame. + """ + def __init__(self, *args): + """Initialize the synchronous parallel pipeline. + + Args: + *args: Variable number of processor lists, each representing a parallel pipeline path. + Each argument should be a list of FrameProcessor instances. + + Raises: + Exception: If no arguments are provided. 
+ TypeError: If any argument is not a list of processors. + """ super().__init__() if len(args) == 0: @@ -72,6 +138,11 @@ class SyncParallelPipeline(BasePipeline): # def processors_with_metrics(self) -> List[FrameProcessor]: + """Collect processors that can generate metrics from all parallel pipelines. + + Returns: + List of frame processors that support metrics collection from all parallel paths. + """ return list(chain.from_iterable(p.processors_with_metrics() for p in self._pipelines)) # @@ -79,6 +150,11 @@ class SyncParallelPipeline(BasePipeline): # async def setup(self, setup: FrameProcessorSetup): + """Set up the parallel pipeline and all contained processors. + + Args: + setup: Configuration for frame processor setup. + """ await super().setup(setup) self._up_queue = WatchdogQueue(setup.task_manager) @@ -113,12 +189,23 @@ class SyncParallelPipeline(BasePipeline): await asyncio.gather(*[s["processor"].setup(setup) for s in self._sinks]) async def cleanup(self): + """Clean up the parallel pipeline and all contained processors.""" await super().cleanup() await asyncio.gather(*[s["processor"].cleanup() for s in self._sources]) await asyncio.gather(*[p.cleanup() for p in self._pipelines]) await asyncio.gather(*[s["processor"].cleanup() for s in self._sinks]) async def process_frame(self, frame: Frame, direction: FrameDirection): + """Process frames through all parallel pipelines with synchronization. + + Distributes frames to all parallel pipelines and synchronizes their output + to maintain proper ordering and prevent duplicate processing. Uses SyncFrame + control frames to coordinate between parallel paths. + + Args: + frame: The frame to process. + direction: The direction of frame flow. 
+ """ await super().process_frame(frame, direction) # The last processor of each pipeline needs to be synchronous otherwise diff --git a/src/pipecat/pipeline/task.py b/src/pipecat/pipeline/task.py index 0cf294b05..a8b55e77c 100644 --- a/src/pipecat/pipeline/task.py +++ b/src/pipecat/pipeline/task.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Pipeline task implementation for managing frame processing pipelines. + +This module provides the main PipelineTask class that orchestrates pipeline +execution, frame routing, lifecycle management, and monitoring capabilities +including heartbeats, idle detection, and observer integration. +""" + import asyncio import time from collections import deque @@ -53,12 +60,13 @@ HEARTBEAT_MONITOR_SECONDS = HEARTBEAT_SECONDS * 10 class PipelineParams(BaseModel): - """Configuration parameters for pipeline execution. These parameters are - usually passed to all frame processors using through `StartFrame`. For other - generic pipeline task parameters use `PipelineTask` constructor arguments - instead. + """Configuration parameters for pipeline execution. - Attributes: + These parameters are usually passed to all frame processors through + StartFrame. For other generic pipeline task parameters use PipelineTask + constructor arguments instead. + + Parameters: allow_interruptions: Whether to allow pipeline interruptions. audio_in_sample_rate: Input audio sample rate in Hz. audio_out_sample_rate: Output audio sample rate in Hz. @@ -66,12 +74,11 @@ class PipelineParams(BaseModel): enable_metrics: Whether to enable metrics collection. enable_usage_metrics: Whether to enable usage metrics. heartbeats_period_secs: Period between heartbeats in seconds. + interruption_strategies: Strategies for bot interruption behavior. observers: [deprecated] Use `observers` arg in `PipelineTask` class. report_only_initial_ttfb: Whether to report only initial time to first byte. 
send_initial_empty_metrics: Whether to send initial empty metrics. start_metadata: Additional metadata for pipeline start. - interruption_strategies: Strategies for bot interruption behavior. - """ model_config = ConfigDict(arbitrary_types_allowed=True) @@ -97,17 +104,25 @@ class PipelineTaskSource(FrameProcessor): pipeline given to the pipeline task. It allows us to easily push frames downstream to the pipeline and also receive upstream frames coming from the pipeline. - - Args: - up_queue: Queue for upstream frame processing. - """ def __init__(self, up_queue: asyncio.Queue, **kwargs): + """Initialize the pipeline task source. + + Args: + up_queue: Queue for upstream frame processing. + **kwargs: Additional arguments passed to the parent class. + """ super().__init__(**kwargs) self._up_queue = up_queue async def process_frame(self, frame: Frame, direction: FrameDirection): + """Process frames and route them based on direction. + + Args: + frame: The frame to process. + direction: The direction of frame flow. + """ await super().process_frame(frame, direction) match direction: @@ -123,16 +138,25 @@ class PipelineTaskSink(FrameProcessor): This is the sink processor that is linked at the end of the pipeline given to the pipeline task. It allows us to receive downstream frames and act on them, for example, waiting to receive an EndFrame. - - Args: - down_queue: Queue for downstream frame processing. """ def __init__(self, down_queue: asyncio.Queue, **kwargs): + """Initialize the pipeline task sink. + + Args: + down_queue: Queue for downstream frame processing. + **kwargs: Additional arguments passed to the parent class. + """ super().__init__(**kwargs) self._down_queue = down_queue async def process_frame(self, frame: Frame, direction: FrameDirection): + """Process frames and route them to the downstream queue. + + Args: + frame: The frame to process. + direction: The direction of frame flow. 
+ """ await super().process_frame(frame, direction) await self._down_queue.put(frame) @@ -140,69 +164,30 @@ class PipelineTaskSink(FrameProcessor): class PipelineTask(BasePipelineTask): """Manages the execution of a pipeline, handling frame processing and task lifecycle. - It has a couple of event handlers `on_frame_reached_upstream` and - `on_frame_reached_downstream` that are called when upstream frames or - downstream frames reach both ends of pipeline. By default, the events - handlers will not be called unless some filters are set using - `set_reached_upstream_filter` and `set_reached_downstream_filter`. + This class orchestrates pipeline execution with comprehensive monitoring, + event handling, and lifecycle management. It provides event handlers for + various pipeline states and frame types, idle detection, heartbeat monitoring, + and observer integration. - @task.event_handler("on_frame_reached_upstream") - async def on_frame_reached_upstream(task, frame): - ... + Event handlers available: - @task.event_handler("on_frame_reached_downstream") - async def on_frame_reached_downstream(task, frame): - ... + - on_frame_reached_upstream: Called when upstream frames reach the source + - on_frame_reached_downstream: Called when downstream frames reach the sink + - on_idle_timeout: Called when pipeline is idle beyond timeout threshold + - on_pipeline_started: Called when pipeline starts with StartFrame + - on_pipeline_stopped: Called when pipeline stops with StopFrame + - on_pipeline_ended: Called when pipeline ends with EndFrame + - on_pipeline_cancelled: Called when pipeline is cancelled - It also has an event handler that detects when the pipeline is idle. By - default, a pipeline is idle if no `BotSpeakingFrame` or - `LLMFullResponseEndFrame` are received within `idle_timeout_secs`. + Example:: - @task.event_handler("on_idle_timeout") - async def on_pipeline_idle_timeout(task): - ... 
+ @task.event_handler("on_frame_reached_upstream") + async def on_frame_reached_upstream(task, frame): + ... - There are also events to know if a pipeline has been started, stopped, ended - or cancelled. - - @task.event_handler("on_pipeline_started") - async def on_pipeline_started(task, frame: StartFrame): - ... - - @task.event_handler("on_pipeline_stopped") - async def on_pipeline_stopped(task, frame: StopFrame): - ... - - @task.event_handler("on_pipeline_ended") - async def on_pipeline_ended(task, frame: EndFrame): - ... - - @task.event_handler("on_pipeline_cancelled") - async def on_pipeline_cancelled(task, frame: CancelFrame): - ... - - Args: - pipeline: The pipeline to execute. - params: Configuration parameters for the pipeline. - additional_span_attributes: Optional dictionary of attributes to propagate as - OpenTelemetry conversation span attributes. - cancel_on_idle_timeout: Whether the pipeline task should be cancelled if - the idle timeout is reached. - check_dangling_tasks: Whether to check for processors' tasks finishing properly. - clock: Clock implementation for timing operations. - conversation_id: Optional custom ID for the conversation. - enable_tracing: Whether to enable tracing. - enable_turn_tracking: Whether to enable turn tracking. - enable_watchdog_logging: Whether to print task processing times. - enable_watchdog_timers: Whether to enable task watchdog timers. - idle_timeout_frames: A tuple with the frames that should trigger an idle - timeout if not received withing `idle_timeout_seconds`. - idle_timeout_secs: Timeout (in seconds) to consider pipeline idle or - None. If a pipeline is idle the pipeline task will be cancelled - automatically. - observers: List of observers for monitoring pipeline execution. - watchdog_timeout_secs: Watchdog timer timeout (in seconds). A warning - will be logged if the watchdog timer is not reset before this timeout. + @task.event_handler("on_idle_timeout") + async def on_pipeline_idle_timeout(task): + ... 
""" def __init__( @@ -228,6 +213,32 @@ class PipelineTask(BasePipelineTask): task_manager: Optional[BaseTaskManager] = None, watchdog_timeout_secs: float = WATCHDOG_TIMEOUT, ): + """Initialize the PipelineTask. + + Args: + pipeline: The pipeline to execute. + params: Configuration parameters for the pipeline. + additional_span_attributes: Optional dictionary of attributes to propagate as + OpenTelemetry conversation span attributes. + cancel_on_idle_timeout: Whether the pipeline task should be cancelled if + the idle timeout is reached. + check_dangling_tasks: Whether to check for processors' tasks finishing properly. + clock: Clock implementation for timing operations. + conversation_id: Optional custom ID for the conversation. + enable_tracing: Whether to enable tracing. + enable_turn_tracking: Whether to enable turn tracking. + enable_watchdog_logging: Whether to print task processing times. + enable_watchdog_timers: Whether to enable task watchdog timers. + idle_timeout_frames: A tuple with the frames that should trigger an idle + timeout if not received within `idle_timeout_seconds`. + idle_timeout_secs: Timeout (in seconds) to consider pipeline idle or + None. If a pipeline is idle the pipeline task will be cancelled + automatically. + observers: List of observers for monitoring pipeline execution. + task_manager: Optional task manager for handling asyncio tasks. + watchdog_timeout_secs: Watchdog timer timeout (in seconds). A warning + will be logged if the watchdog timer is not reset before this timeout. + """ super().__init__() self._pipeline = pipeline self._params = params or PipelineParams() @@ -331,60 +342,97 @@ class PipelineTask(BasePipelineTask): @property def params(self) -> PipelineParams: - """Returns the pipeline parameters of this task.""" + """Get the pipeline parameters for this task. + + Returns: + The pipeline parameters configuration. 
+ """ return self._params @property def turn_tracking_observer(self) -> Optional[TurnTrackingObserver]: - """Return the turn tracking observer if enabled.""" + """Get the turn tracking observer if enabled. + + Returns: + The turn tracking observer instance or None if not enabled. + """ return self._turn_tracking_observer @property def turn_trace_observer(self) -> Optional[TurnTraceObserver]: - """Return the turn trace observer if enabled.""" + """Get the turn trace observer if enabled. + + Returns: + The turn trace observer instance or None if not enabled. + """ return self._turn_trace_observer def add_observer(self, observer: BaseObserver): + """Add an observer to monitor pipeline execution. + + Args: + observer: The observer to add to the pipeline monitoring. + """ self._observer.add_observer(observer) async def remove_observer(self, observer: BaseObserver): + """Remove an observer from pipeline monitoring. + + Args: + observer: The observer to remove from pipeline monitoring. + """ await self._observer.remove_observer(observer) def set_reached_upstream_filter(self, types: Tuple[Type[Frame], ...]): - """Sets which frames will be checked before calling the - on_frame_reached_upstream event handler. + """Set which frame types trigger the on_frame_reached_upstream event. + Args: + types: Tuple of frame types to monitor for upstream events. """ self._reached_upstream_types = types def set_reached_downstream_filter(self, types: Tuple[Type[Frame], ...]): - """Sets which frames will be checked before calling the - on_frame_reached_downstream event handler. + """Set which frame types trigger the on_frame_reached_downstream event. + Args: + types: Tuple of frame types to monitor for downstream events. """ self._reached_downstream_types = types def has_finished(self) -> bool: - """Indicates whether the tasks has finished. That is, all processors + """Check if the pipeline task has finished execution. 
+ + This indicates whether the task has finished, meaning all processors have stopped. + Returns: + True if all processors have stopped and the task is complete. """ return self._finished async def stop_when_done(self): - """This is a helper function that sends an EndFrame to the pipeline in - order to stop the task after everything in it has been processed. + """Schedule the pipeline to stop after processing all queued frames. + Sends an EndFrame to gracefully terminate the pipeline once all + current processing is complete. """ logger.debug(f"Task {self} scheduled to stop when done") await self.queue_frame(EndFrame()) async def cancel(self): - """Stops the running pipeline immediately.""" + """Immediately stop the running pipeline. + + Cancels all running tasks and stops frame processing without + waiting for completion. + """ await self._cancel() async def run(self, params: PipelineTaskParams): - """Starts and manages the pipeline execution until completion or cancellation.""" + """Start and manage the pipeline execution until completion or cancellation. + + Args: + params: Configuration parameters for pipeline execution. 
+ """ if self.has_finished(): return cleanup_pipeline = True @@ -440,6 +488,7 @@ class PipelineTask(BasePipelineTask): await self.queue_frame(frame) async def _cancel(self): + """Internal cancellation logic for the pipeline task.""" if not self._cancelled: logger.debug(f"Canceling pipeline task {self}") self._cancelled = True @@ -453,6 +502,7 @@ class PipelineTask(BasePipelineTask): self._process_push_task = None async def _create_tasks(self): + """Create and start all pipeline processing tasks.""" self._process_up_task = self._task_manager.create_task( self._process_up_queue(), f"{self}::_process_up_queue" ) @@ -468,6 +518,7 @@ class PipelineTask(BasePipelineTask): return self._process_push_task def _maybe_start_heartbeat_tasks(self): + """Start heartbeat tasks if heartbeats are enabled and not already running.""" if self._params.enable_heartbeats and self._heartbeat_push_task is None: self._heartbeat_push_task = self._task_manager.create_task( self._heartbeat_push_handler(), f"{self}::_heartbeat_push_handler" @@ -477,12 +528,14 @@ class PipelineTask(BasePipelineTask): ) def _maybe_start_idle_task(self): + """Start idle monitoring task if idle timeout is configured.""" if self._idle_timeout_secs: self._idle_monitor_task = self._task_manager.create_task( self._idle_monitor_handler(), f"{self}::_idle_monitor_handler" ) async def _cancel_tasks(self): + """Cancel all running pipeline tasks.""" await self._observer.stop() if self._process_up_task: @@ -497,6 +550,7 @@ class PipelineTask(BasePipelineTask): await self._maybe_cancel_idle_task() async def _maybe_cancel_heartbeat_tasks(self): + """Cancel heartbeat tasks if they are running.""" if not self._params.enable_heartbeats: return @@ -509,11 +563,13 @@ class PipelineTask(BasePipelineTask): self._heartbeat_monitor_task = None async def _maybe_cancel_idle_task(self): + """Cancel idle monitoring task if it is running.""" if self._idle_timeout_secs and self._idle_monitor_task: await 
self._task_manager.cancel_task(self._idle_monitor_task) self._idle_monitor_task = None def _initial_metrics_frame(self) -> MetricsFrame: + """Create an initial metrics frame with zero values for all processors.""" processors = self._pipeline.processors_with_metrics() data = [] for p in processors: @@ -522,10 +578,12 @@ class PipelineTask(BasePipelineTask): return MetricsFrame(data=data) async def _wait_for_pipeline_end(self): + """Wait for the pipeline to signal completion.""" await self._pipeline_end_event.wait() self._pipeline_end_event.clear() async def _setup(self, params: PipelineTaskParams): + """Set up the pipeline task and all processors.""" mgr_params = TaskManagerParams( loop=params.loop, enable_watchdog_logging=self._enable_watchdog_logging, @@ -545,6 +603,7 @@ class PipelineTask(BasePipelineTask): await self._sink.setup(setup) async def _cleanup(self, cleanup_pipeline: bool): + """Clean up the pipeline task and processors.""" # Cleanup base object. await self.cleanup() @@ -559,10 +618,11 @@ class PipelineTask(BasePipelineTask): await self._sink.cleanup() async def _process_push_queue(self): - """This is the task that runs the pipeline for the first time by sending + """Process frames from the push queue and send them through the pipeline. + + This is the task that runs the pipeline for the first time by sending a StartFrame and by pushing any other frames queued by the user. It runs until the tasks is cancelled or stopped (e.g. with an EndFrame). - """ self._clock.start() @@ -596,11 +656,12 @@ class PipelineTask(BasePipelineTask): await self._cleanup(cleanup_pipeline) async def _process_up_queue(self): - """This is the task that processes frames coming upstream from the + """Process frames coming upstream from the pipeline. + + This is the task that processes frames coming upstream from the pipeline. These frames might indicate, for example, that we want the pipeline to be stopped (e.g. 
EndTaskFrame) in which case we would send an EndFrame down the pipeline. - """ while True: frame = await self._up_queue.get() @@ -629,11 +690,12 @@ class PipelineTask(BasePipelineTask): self._up_queue.task_done() async def _process_down_queue(self): - """This tasks process frames coming downstream from the pipeline. For + """Process frames coming downstream from the pipeline. + + This task processes frames coming downstream from the pipeline. For example, heartbeat frames or an EndFrame which would indicate all processors have handled the EndFrame and therefore we can exit the task cleanly. - """ while True: frame = await self._down_queue.get() @@ -664,7 +726,7 @@ class PipelineTask(BasePipelineTask): self._down_queue.task_done() async def _heartbeat_push_handler(self): - """This tasks pushes a heartbeat frame every heartbeat period.""" + """Push heartbeat frames at regular intervals.""" while True: # Don't use `queue_frame()` because if an EndFrame is queued the # task will just stop waiting for the pipeline to finish not @@ -673,11 +735,12 @@ class PipelineTask(BasePipelineTask): await asyncio.sleep(self._params.heartbeats_period_secs) async def _heartbeat_monitor_handler(self): - """This tasks monitors heartbeat frames. If a heartbeat frame has not + """Monitor heartbeat frames for processing time and timeout detection. + + This task monitors heartbeat frames. If a heartbeat frame has not been received for a long period a warning will be logged. It also logs the time that a heartbeat frame takes to processes, that is how long it takes for the heartbeat frame to traverse all the pipeline. - """ wait_time = HEARTBEAT_MONITOR_SECONDS while True: @@ -692,9 +755,12 @@ class PipelineTask(BasePipelineTask): ) async def _idle_monitor_handler(self): - """This tasks monitors activity in the pipeline. If no frames are - received (heartbeats don't count) the pipeline is considered idle. + """Monitor pipeline activity and detect idle conditions. 
+ Tracks frame activity and triggers idle timeout events when the + pipeline hasn't received relevant frames within the timeout period. + + Note: Heartbeats are excluded from idle detection. """ running = True last_frame_time = 0 @@ -732,10 +798,13 @@ class PipelineTask(BasePipelineTask): running = await self._idle_timeout_detected(frame_buffer) async def _idle_timeout_detected(self, last_frames: Deque[Frame]) -> bool: - """Logic for when the pipeline is idle. + """Handle idle timeout detection and optional cancellation. + + Args: + last_frames: Recent frames received before timeout for debugging. Returns: - bool: Whther the pipeline task is being cancelled or not. + Whether the pipeline task should continue running. """ logger.warning("Idle timeout detected. Last 10 frames received:") for i, frame in enumerate(last_frames, 1): @@ -749,6 +818,7 @@ class PipelineTask(BasePipelineTask): return True def _print_dangling_tasks(self): + """Log any dangling tasks that haven't been properly cleaned up.""" tasks = [t.get_name() for t in self._task_manager.current_tasks()] if tasks: logger.warning(f"Dangling tasks detected: {tasks}") diff --git a/src/pipecat/pipeline/task_observer.py b/src/pipecat/pipeline/task_observer.py index 40f4c953c..cd46f85ef 100644 --- a/src/pipecat/pipeline/task_observer.py +++ b/src/pipecat/pipeline/task_observer.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Task observer for managing pipeline frame observers. + +This module provides a proxy observer system that manages multiple observers +for pipeline frame events, ensuring that observer processing doesn't block +the main pipeline execution. +""" + import asyncio import inspect from typing import Dict, List, Optional @@ -17,9 +24,15 @@ from pipecat.utils.asyncio.watchdog_queue import WatchdogQueue @dataclass class Proxy: - """This is the data we receive from the main observer and that we put into - a queue for later processing. 
+ """Proxy data for managing observer tasks and queues. + This represents the data received from the main observer that + is queued for later processing. + + Parameters: + queue: Queue for frame data awaiting observer processing. + task: Asyncio task running the observer's frame processing loop. + observer: The actual observer instance being proxied. """ queue: asyncio.Queue @@ -28,7 +41,9 @@ class Proxy: class TaskObserver(BaseObserver): - """This is a pipeline frame observer that is meant to be used as a proxy to + """Proxy observer that manages multiple observers without blocking the pipeline. + + This is a pipeline frame observer that is meant to be used as a proxy to the user provided observers. That is, this is the observer that should be passed to the frame processors. Then, every time a frame is pushed this observer will call all the observers registered to the pipeline task. @@ -37,7 +52,6 @@ class TaskObserver(BaseObserver): pipeline by creating a queue and a task for each user observer. When a frame is received, it will be put in a queue for efficiency and later processed by each task. - """ def __init__( @@ -47,6 +61,13 @@ class TaskObserver(BaseObserver): task_manager: BaseTaskManager, **kwargs, ): + """Initialize the TaskObserver. + + Args: + observers: List of observers to manage. Defaults to empty list. + task_manager: Task manager for creating and managing observer tasks. + **kwargs: Additional arguments passed to the base observer. + """ super().__init__(**kwargs) self._observers = observers or [] self._task_manager = task_manager @@ -55,6 +76,11 @@ class TaskObserver(BaseObserver): ) def add_observer(self, observer: BaseObserver): + """Add a new observer to the managed list. + + Args: + observer: The observer to add. + """ # Add the observer to the list.
self._observers.append(observer) @@ -65,6 +91,11 @@ class TaskObserver(BaseObserver): self._proxies[observer] = proxy async def remove_observer(self, observer: BaseObserver): + """Remove an observer and clean up its resources. + + Args: + observer: The observer to remove. + """ # If the observer has a proxy, remove it. if observer in self._proxies: proxy = self._proxies[observer] @@ -78,11 +109,11 @@ class TaskObserver(BaseObserver): self._observers.remove(observer) async def start(self): - """Starts all proxy observer tasks.""" + """Start all proxy observer tasks.""" self._proxies = self._create_proxies(self._observers) async def stop(self): - """Stops all proxy observer tasks.""" + """Stop all proxy observer tasks.""" if not self._proxies: return @@ -90,13 +121,20 @@ class TaskObserver(BaseObserver): await self._task_manager.cancel_task(proxy.task) async def on_push_frame(self, data: FramePushed): + """Queue frame data for all managed observers. + + Args: + data: The frame push event data to distribute to observers. 
+ """ for proxy in self._proxies.values(): await proxy.queue.put(data) def _started(self) -> bool: + """Check if the task observer has been started.""" return self._proxies is not None def _create_proxy(self, observer: BaseObserver) -> Proxy: + """Create a proxy for a single observer.""" queue = WatchdogQueue(self._task_manager) task = self._task_manager.create_task( self._proxy_task_handler(queue, observer), @@ -106,6 +144,7 @@ class TaskObserver(BaseObserver): return proxy def _create_proxies(self, observers: List[BaseObserver]) -> Dict[BaseObserver, Proxy]: + """Create proxies for all observers.""" proxies = {} for observer in observers: proxy = self._create_proxy(observer) @@ -113,6 +152,7 @@ class TaskObserver(BaseObserver): return proxies async def _proxy_task_handler(self, queue: asyncio.Queue, observer: BaseObserver): + """Handle frame processing for a single observer.""" warning_reported = False while True: data = await queue.get() diff --git a/src/pipecat/pipeline/to_be_updated/merge_pipeline.py b/src/pipecat/pipeline/to_be_updated/merge_pipeline.py index 27a52894b..4cd825a97 100644 --- a/src/pipecat/pipeline/to_be_updated/merge_pipeline.py +++ b/src/pipecat/pipeline/to_be_updated/merge_pipeline.py @@ -1,3 +1,16 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""Sequential pipeline merging for Pipecat. + +This module provides a pipeline implementation that sequentially merges +the output from multiple pipelines, processing them one after another +in a specified order. +""" + from typing import List from pipecat.frames.frames import EndFrame, EndPipeFrame @@ -5,14 +18,31 @@ from pipecat.pipeline.pipeline import Pipeline class SequentialMergePipeline(Pipeline): - """This class merges the sink queues from a list of pipelines. Frames from - each pipeline's sink are merged in the order of pipelines in the list.""" + """Pipeline that sequentially merges output from multiple pipelines. 
+ + This pipeline merges the sink queues from a list of pipelines by processing + frames from each pipeline's sink sequentially in the order specified. Each + pipeline runs to completion before the next one begins processing. + """ def __init__(self, pipelines: List[Pipeline]): + """Initialize the sequential merge pipeline. + + Args: + pipelines: List of pipelines to merge sequentially. Pipelines will + be processed in the order they appear in this list. + """ super().__init__([]) self.pipelines = pipelines async def run_pipeline(self): + """Run all pipelines sequentially and merge their output. + + Processes each pipeline in order, consuming all frames from each + pipeline's sink until an EndFrame or EndPipeFrame is encountered, + then moves to the next pipeline. After all pipelines complete, + sends a final EndFrame to signal completion. + """ for idx, pipeline in enumerate(self.pipelines): while True: frame = await pipeline.sink.get() diff --git a/src/pipecat/processors/aggregators/dtmf_aggregator.py b/src/pipecat/processors/aggregators/dtmf_aggregator.py index 006a7c181..f3485245c 100644 --- a/src/pipecat/processors/aggregators/dtmf_aggregator.py +++ b/src/pipecat/processors/aggregators/dtmf_aggregator.py @@ -33,6 +33,7 @@ class DTMFAggregator(FrameProcessor): The aggregator accumulates digits from InputDTMFFrame instances and flushes when: + - Timeout occurs (configurable idle period) - Termination digit is received (default: '#') - EndFrame or CancelFrame is received diff --git a/src/pipecat/processors/aggregators/llm_response.py b/src/pipecat/processors/aggregators/llm_response.py index 6d27d1ddf..85c4a3b06 100644 --- a/src/pipecat/processors/aggregators/llm_response.py +++ b/src/pipecat/processors/aggregators/llm_response.py @@ -92,7 +92,7 @@ class LLMFullResponseAggregator(FrameProcessor): the complete response via an event handler. 
The aggregator provides an "on_completion" event that fires when a full - completion is available: + completion is available:: @aggregator.event_handler("on_completion") async def on_completion( @@ -363,6 +363,7 @@ class LLMUserContextAggregator(LLMContextResponseAggregator): This aggregator handles the complex logic of aggregating user speech transcriptions from STT services. It manages multiple scenarios including: + - Transcriptions received between VAD events - Transcriptions received outside VAD events - Interim vs final transcriptions @@ -654,6 +655,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator): """Assistant LLM aggregator that processes bot responses and function calls. This aggregator handles the complex logic of processing assistant responses including: + - Text frame aggregation between response start/end markers - Function call lifecycle management - Context updates with timestamps diff --git a/src/pipecat/processors/aggregators/openai_llm_context.py b/src/pipecat/processors/aggregators/openai_llm_context.py index a8520546a..82edcadb0 100644 --- a/src/pipecat/processors/aggregators/openai_llm_context.py +++ b/src/pipecat/processors/aggregators/openai_llm_context.py @@ -210,9 +210,10 @@ class OpenAILLMContext: def from_standard_message(self, message): """Convert from OpenAI message format to OpenAI message format (passthrough). - OpenAI's format allows both simple string content and structured content: - - Simple: {"role": "user", "content": "Hello"} - - Structured: {"role": "user", "content": [{"type": "text", "text": "Hello"}]} + OpenAI's format allows both simple string content and structured content:: + + Simple: {"role": "user", "content": "Hello"} + Structured: {"role": "user", "content": [{"type": "text", "text": "Hello"}]} Since OpenAI is our standard format, this is a passthrough function. 
diff --git a/src/pipecat/processors/audio/audio_buffer_processor.py b/src/pipecat/processors/audio/audio_buffer_processor.py index f48d2d6a8..78561b2f9 100644 --- a/src/pipecat/processors/audio/audio_buffer_processor.py +++ b/src/pipecat/processors/audio/audio_buffer_processor.py @@ -39,17 +39,19 @@ class AudioBufferProcessor(FrameProcessor): including sample rate conversion and mono/stereo output. Events: - on_audio_data: Triggered when buffer_size is reached, providing merged audio - on_track_audio_data: Triggered when buffer_size is reached, providing separate tracks - on_user_turn_audio_data: Triggered when user turn has ended, providing that user turn's audio - on_bot_turn_audio_data: Triggered when bot turn has ended, providing that bot turn's audio + + - on_audio_data: Triggered when buffer_size is reached, providing merged audio + - on_track_audio_data: Triggered when buffer_size is reached, providing separate tracks + - on_user_turn_audio_data: Triggered when user turn has ended, providing that user turn's audio + - on_bot_turn_audio_data: Triggered when bot turn has ended, providing that bot turn's audio Audio handling: - - Mono output (num_channels=1): User and bot audio are mixed - - Stereo output (num_channels=2): User audio on left, bot audio on right - - Automatic resampling of incoming audio to match desired sample_rate - - Silence insertion for non-continuous audio streams - - Buffer synchronization between user and bot audio + + - Mono output (num_channels=1): User and bot audio are mixed + - Stereo output (num_channels=2): User audio on left, bot audio on right + - Automatic resampling of incoming audio to match desired sample_rate + - Silence insertion for non-continuous audio streams + - Buffer synchronization between user and bot audio """ def __init__( diff --git a/src/pipecat/processors/transcript_processor.py b/src/pipecat/processors/transcript_processor.py index 856311392..df6dd469b 100644 --- 
a/src/pipecat/processors/transcript_processor.py +++ b/src/pipecat/processors/transcript_processor.py @@ -84,6 +84,7 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor): This processor aggregates TTS text frames into complete utterances and emits them as transcript messages. Utterances are completed when: + - The bot stops speaking (BotStoppedSpeakingFrame) - The bot is interrupted (StartInterruptionFrame) - The pipeline ends (EndFrame) @@ -108,34 +109,34 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor): TTS services with different formatting patterns. Examples: - Fragments with embedded spacing (concatenated): - ``` + Fragments with embedded spacing (concatenated):: + TTSTextFrame: ["Hello"] TTSTextFrame: [" there"] # Leading space TTSTextFrame: ["!"] TTSTextFrame: [" How"] # Leading space TTSTextFrame: ["'s"] TTSTextFrame: [" it"] # Leading space - ``` + Result: "Hello there! How's it" - Fragments with trailing spaces (concatenated): - ``` + Fragments with trailing spaces (concatenated):: + TTSTextFrame: ["Hel"] TTSTextFrame: ["lo "] # Trailing space TTSTextFrame: ["to "] # Trailing space TTSTextFrame: ["you"] - ``` + Result: "Hello to you" - Word-by-word fragments without spacing (joined with spaces): - ``` + Word-by-word fragments without spacing (joined with spaces):: + TTSTextFrame: ["Hello"] TTSTextFrame: ["there"] TTSTextFrame: ["how"] TTSTextFrame: ["are"] TTSTextFrame: ["you"] - ``` + Result: "Hello there how are you" """ if self._current_text_parts and self._aggregation_start_time: @@ -179,6 +180,7 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor): """Process frames into assistant conversation messages. 
Handles different frame types: + - TTSTextFrame: Aggregates text for current utterance - BotStoppedSpeakingFrame: Completes current utterance - StartInterruptionFrame: Completes current utterance due to interruption @@ -221,8 +223,8 @@ class TranscriptProcessor: Provides unified access to user and assistant transcript processors with shared event handling. - Example: - ```python + Example:: + transcript = TranscriptProcessor() pipeline = Pipeline( @@ -242,7 +244,6 @@ class TranscriptProcessor: @transcript.event_handler("on_transcript_update") async def handle_update(processor, frame): print(f"New messages: {frame.messages}") - ``` """ def __init__(self): diff --git a/src/pipecat/processors/user_idle_processor.py b/src/pipecat/processors/user_idle_processor.py index c692642cc..1a442fd8a 100644 --- a/src/pipecat/processors/user_idle_processor.py +++ b/src/pipecat/processors/user_idle_processor.py @@ -28,8 +28,8 @@ class UserIdleProcessor(FrameProcessor): users become idle. It starts monitoring only after the first conversation activity and supports both basic and retry-based callback patterns. - Example: - ``` + Example:: + # Retry callback: async def handle_idle(processor: "UserIdleProcessor", retry_count: int) -> bool: if retry_count < 3: @@ -45,7 +45,6 @@ class UserIdleProcessor(FrameProcessor): callback=handle_idle, timeout=5.0 ) - ``` """ def __init__( @@ -61,11 +60,10 @@ class UserIdleProcessor(FrameProcessor): """Initialize the user idle processor. Args: - callback: Function to call when user is idle. Can be either: - - Basic callback(processor) -> None - - Retry callback(processor, retry_count) -> bool - Return True to continue monitoring for idle events, - Return False to stop the idle monitoring task + callback: Function to call when user is idle. Can be either a basic + callback taking only the processor, or a retry callback taking + the processor and retry count. Retry callbacks should return + True to continue monitoring or False to stop. 
timeout: Seconds to wait before considering user idle. **kwargs: Additional arguments passed to FrameProcessor. """ diff --git a/src/pipecat/serializers/base_serializer.py b/src/pipecat/serializers/base_serializer.py index c70b1408e..f57165034 100644 --- a/src/pipecat/serializers/base_serializer.py +++ b/src/pipecat/serializers/base_serializer.py @@ -4,6 +4,8 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Frame serialization interfaces for Pipecat.""" + from abc import ABC, abstractmethod from enum import Enum @@ -11,23 +13,63 @@ from pipecat.frames.frames import Frame, StartFrame class FrameSerializerType(Enum): + """Enumeration of supported frame serialization formats. + + Parameters: + BINARY: Binary serialization format for compact representation. + TEXT: Text-based serialization format for human-readable output. + """ + BINARY = "binary" TEXT = "text" class FrameSerializer(ABC): + """Abstract base class for frame serialization implementations. + + Defines the interface for converting frames to/from serialized formats + for transmission or storage. Subclasses must implement serialization + type detection and the core serialize/deserialize methods. + """ + @property @abstractmethod def type(self) -> FrameSerializerType: + """Get the serialization type supported by this serializer. + + Returns: + The FrameSerializerType indicating binary or text format. + """ pass async def setup(self, frame: StartFrame): + """Initialize the serializer with startup configuration. + + Args: + frame: StartFrame containing initialization parameters. + """ pass @abstractmethod async def serialize(self, frame: Frame) -> str | bytes | None: + """Convert a frame to its serialized representation. + + Args: + frame: The frame to serialize. + + Returns: + Serialized frame data as string, bytes, or None if serialization fails. + """ pass @abstractmethod async def deserialize(self, data: str | bytes) -> Frame | None: + """Convert serialized data back to a frame object. 
+ + Args: + data: Serialized frame data as string or bytes. + + Returns: + Reconstructed Frame object, or None if deserialization fails. + """ pass diff --git a/src/pipecat/serializers/exotel.py b/src/pipecat/serializers/exotel.py index 4e546e442..960cbe74a 100644 --- a/src/pipecat/serializers/exotel.py +++ b/src/pipecat/serializers/exotel.py @@ -4,6 +4,8 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Exotel Media Streams serializer for Pipecat.""" + import base64 import json from typing import Optional @@ -33,13 +35,14 @@ class ExotelFrameSerializer(FrameSerializer): media streams protocol. It supports audio conversion, DTMF events, and automatic call termination. - Ref Doc for events - https://support.exotel.com/support/solutions/articles/3000108630-working-with-the-stream-and-voicebot-applet + Note: Ref docs for events: + https://support.exotel.com/support/solutions/articles/3000108630-working-with-the-stream-and-voicebot-applet """ class InputParams(BaseModel): """Configuration parameters for ExotelFrameSerializer. - Attributes: + Parameters: exotel_sample_rate: Sample rate used by Exotel, defaults to 8000 Hz. sample_rate: Optional override for pipeline input sample rate. """ diff --git a/src/pipecat/serializers/livekit.py b/src/pipecat/serializers/livekit.py index d856a7a56..3d4188960 100644 --- a/src/pipecat/serializers/livekit.py +++ b/src/pipecat/serializers/livekit.py @@ -4,6 +4,8 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""LiveKit frame serializer for Pipecat.""" + import ctypes import pickle @@ -21,11 +23,33 @@ except ModuleNotFoundError as e: class LivekitFrameSerializer(FrameSerializer): + """Serializer for converting between Pipecat frames and LiveKit audio frames. + + This serializer handles the conversion of Pipecat's OutputAudioRawFrame objects + to LiveKit AudioFrame objects for transmission, and the reverse conversion + for received audio data. 
+ """ + @property def type(self) -> FrameSerializerType: + """Get the serializer type. + + Returns: + The serializer type indicating binary serialization. + """ return FrameSerializerType.BINARY async def serialize(self, frame: Frame) -> str | bytes | None: + """Serialize a Pipecat frame to LiveKit AudioFrame format. + + Args: + frame: The Pipecat frame to serialize. Only OutputAudioRawFrame + instances are supported. + + Returns: + Pickled LiveKit AudioFrame bytes if frame is OutputAudioRawFrame, + None otherwise. + """ if not isinstance(frame, OutputAudioRawFrame): return None audio_frame = AudioFrame( @@ -37,6 +61,15 @@ class LivekitFrameSerializer(FrameSerializer): return pickle.dumps(audio_frame) async def deserialize(self, data: str | bytes) -> Frame | None: + """Deserialize LiveKit AudioFrame data to a Pipecat frame. + + Args: + data: Pickled data containing a LiveKit AudioFrame. + + Returns: + InputAudioRawFrame containing the deserialized audio data, + or None if deserialization fails. + """ audio_frame: AudioFrame = pickle.loads(data)["frame"] return InputAudioRawFrame( audio=bytes(audio_frame.data), diff --git a/src/pipecat/serializers/plivo.py b/src/pipecat/serializers/plivo.py index 7fcd951d4..d0c19be0a 100644 --- a/src/pipecat/serializers/plivo.py +++ b/src/pipecat/serializers/plivo.py @@ -4,6 +4,8 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Plivo WebSocket frame serializer for audio streaming.""" + import base64 import json from typing import Optional @@ -38,22 +40,12 @@ class PlivoFrameSerializer(FrameSerializer): When auto_hang_up is enabled (default), the serializer will automatically terminate the Plivo call when an EndFrame or CancelFrame is processed, but requires Plivo credentials to be provided. - - Attributes: - _stream_id: The Plivo Stream ID. - _call_id: The associated Plivo Call ID. - _auth_id: Plivo auth ID for API access. - _auth_token: Plivo authentication token for API access. - _params: Configuration parameters. 
- _plivo_sample_rate: Sample rate used by Plivo (typically 8kHz). - _sample_rate: Input sample rate for the pipeline. - _resampler: Audio resampler for format conversion. """ class InputParams(BaseModel): """Configuration parameters for PlivoFrameSerializer. - Attributes: + Parameters: plivo_sample_rate: Sample rate used by Plivo, defaults to 8000 Hz. sample_rate: Optional override for pipeline input sample rate. auto_hang_up: Whether to automatically terminate call on EndFrame. diff --git a/src/pipecat/serializers/protobuf.py b/src/pipecat/serializers/protobuf.py index c91c6661e..867fa0674 100644 --- a/src/pipecat/serializers/protobuf.py +++ b/src/pipecat/serializers/protobuf.py @@ -4,6 +4,8 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Protobuf frame serialization for Pipecat.""" + import dataclasses import json @@ -22,13 +24,25 @@ from pipecat.frames.frames import ( from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializerType -# Data class for converting transport messages into Protobuf format. @dataclasses.dataclass class MessageFrame: + """Data class for converting transport messages into Protobuf format. + + Parameters: + data: JSON-encoded message data for transport. + """ + data: str class ProtobufFrameSerializer(FrameSerializer): + """Serializer for converting Pipecat frames to/from Protocol Buffer format. + + Provides efficient binary serialization for frame transport over network + connections. Supports text, audio, transcription, and message frames with + automatic conversion between transport message types. + """ + SERIALIZABLE_TYPES = { TextFrame: "text", OutputAudioRawFrame: "audio", @@ -46,13 +60,27 @@ class ProtobufFrameSerializer(FrameSerializer): DESERIALIZABLE_FIELDS = {v: k for k, v in DESERIALIZABLE_TYPES.items()} def __init__(self): + """Initialize the Protobuf frame serializer.""" pass @property def type(self) -> FrameSerializerType: + """Get the serializer type. 
+ + Returns: + FrameSerializerType.BINARY indicating binary serialization format. + """ return FrameSerializerType.BINARY async def serialize(self, frame: Frame) -> str | bytes | None: + """Serialize a frame to Protocol Buffer binary format. + + Args: + frame: The frame to serialize. + + Returns: + Serialized frame as bytes, or None if frame type is not serializable. + """ # Wrapping this messages as a JSONFrame to send if isinstance(frame, (TransportMessageFrame, TransportMessageUrgentFrame)): frame = MessageFrame( @@ -75,6 +103,14 @@ class ProtobufFrameSerializer(FrameSerializer): return proto_frame.SerializeToString() async def deserialize(self, data: str | bytes) -> Frame | None: + """Deserialize Protocol Buffer binary data to a frame. + + Args: + data: Binary protobuf data to deserialize. + + Returns: + Deserialized frame instance, or None if deserialization fails. + """ proto = frame_protos.Frame.FromString(data) which = proto.WhichOneof("frame") if which not in self.DESERIALIZABLE_FIELDS: diff --git a/src/pipecat/serializers/telnyx.py b/src/pipecat/serializers/telnyx.py index 5f8a09252..adc235555 100644 --- a/src/pipecat/serializers/telnyx.py +++ b/src/pipecat/serializers/telnyx.py @@ -4,6 +4,8 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Telnyx WebSocket frame serializer for Pipecat.""" + import base64 import json from typing import Optional @@ -43,22 +45,12 @@ class TelnyxFrameSerializer(FrameSerializer): When auto_hang_up is enabled (default), the serializer will automatically terminate the Telnyx call when an EndFrame or CancelFrame is processed, but requires Telnyx credentials to be provided. - - Attributes: - _stream_id: The Telnyx Stream ID. - _call_control_id: The associated Telnyx Call Control ID. - _api_key: Telnyx API key for API access. - _params: Configuration parameters. - _telnyx_sample_rate: Sample rate used by Telnyx (typically 8kHz). - _sample_rate: Input sample rate for the pipeline. 
- _resampler: Audio resampler for format conversion. - _hangup_attempted: Flag to track if hang-up has been attempted. """ class InputParams(BaseModel): """Configuration parameters for TelnyxFrameSerializer. - Attributes: + Parameters: telnyx_sample_rate: Sample rate used by Telnyx, defaults to 8000 Hz. sample_rate: Optional override for pipeline input sample rate. inbound_encoding: Audio encoding for data sent to Telnyx (e.g., "PCMU"). diff --git a/src/pipecat/serializers/twilio.py b/src/pipecat/serializers/twilio.py index 26a3fea5e..ae4d54e4d 100644 --- a/src/pipecat/serializers/twilio.py +++ b/src/pipecat/serializers/twilio.py @@ -4,6 +4,8 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Twilio Media Streams WebSocket protocol serializer for Pipecat.""" + import base64 import json from typing import Optional @@ -38,22 +40,12 @@ class TwilioFrameSerializer(FrameSerializer): When auto_hang_up is enabled (default), the serializer will automatically terminate the Twilio call when an EndFrame or CancelFrame is processed, but requires Twilio credentials to be provided. - - Attributes: - _stream_sid: The Twilio Media Stream SID. - _call_sid: The associated Twilio Call SID. - _account_sid: Twilio account SID for API access. - _auth_token: Twilio authentication token for API access. - _params: Configuration parameters. - _twilio_sample_rate: Sample rate used by Twilio (typically 8kHz). - _sample_rate: Input sample rate for the pipeline. - _resampler: Audio resampler for format conversion. """ class InputParams(BaseModel): """Configuration parameters for TwilioFrameSerializer. - Attributes: + Parameters: twilio_sample_rate: Sample rate used by Twilio, defaults to 8000 Hz. sample_rate: Optional override for pipeline input sample rate. auto_hang_up: Whether to automatically terminate call on EndFrame. 
diff --git a/src/pipecat/services/anthropic/llm.py b/src/pipecat/services/anthropic/llm.py index 33b9c9e30..d7c703ebb 100644 --- a/src/pipecat/services/anthropic/llm.py +++ b/src/pipecat/services/anthropic/llm.py @@ -538,20 +538,37 @@ class AnthropicLLMContext(OpenAILLMContext): Handles text content and function calls for both user and assistant messages. Args: - obj: Message in Anthropic format: - { - "role": "user/assistant", - "content": str | [{"type": "text/tool_use/tool_result", ...}] - } + obj: Message in Anthropic format. Returns: - List of messages in standard format: - [ + List of messages in standard format. + + Examples: + Input Anthropic format:: + { - "role": "user/assistant/tool", - "content": [{"type": "text", "text": str}] + "role": "assistant", + "content": [ + {"type": "text", "text": "Hello"}, + {"type": "tool_use", "id": "123", "name": "search", "input": {"q": "test"}} + ] } - ] + + Output standard format:: + + [ + {"role": "assistant", "content": [{"type": "text", "text": "Hello"}]}, + { + "role": "assistant", + "tool_calls": [ + { + "type": "function", + "id": "123", + "function": {"name": "search", "arguments": '{"q": "test"}'} + } + ] + } + ] """ # todo: image format (?) # tool_use @@ -613,23 +630,37 @@ class AnthropicLLMContext(OpenAILLMContext): Empty text content is converted to "(empty)". Args: - message: Message in standard format: - { - "role": "user/assistant/tool", - "content": str | [{"type": "text", ...}], - "tool_calls": [{"id": str, "function": {"name": str, "arguments": str}}] - } + message: Message in standard format. Returns: - Message in Anthropic format: - { - "role": "user/assistant", - "content": str | [ - {"type": "text", "text": str} | - {"type": "tool_use", "id": str, "name": str, "input": dict} | - {"type": "tool_result", "tool_use_id": str, "content": str} - ] - } + Message in Anthropic format. 
+ + Examples: + Input standard format:: + + { + "role": "assistant", + "tool_calls": [ + { + "id": "123", + "function": {"name": "search", "arguments": '{"q": "test"}'} + } + ] + } + + Output Anthropic format:: + + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "123", + "name": "search", + "input": {"q": "test"} + } + ] + } """ # todo: image messages (?) if message["role"] == "tool": diff --git a/src/pipecat/services/aws/llm.py b/src/pipecat/services/aws/llm.py index 283cea601..43015b2bd 100644 --- a/src/pipecat/services/aws/llm.py +++ b/src/pipecat/services/aws/llm.py @@ -207,20 +207,37 @@ class AWSBedrockLLMContext(OpenAILLMContext): Handles text content and function calls for both user and assistant messages. Args: - obj: Message in AWS Bedrock format: - { - "role": "user/assistant", - "content": [{"text": str} | {"toolUse": {...}} | {"toolResult": {...}}] - } + obj: Message in AWS Bedrock format. Returns: - List of messages in standard format: - [ + List of messages in standard format. + + Examples: + AWS Bedrock format input:: + { - "role": "user/assistant/tool", - "content": [{"type": "text", "text": str}] + "role": "assistant", + "content": [ + {"text": "Hello"}, + {"toolUse": {"toolUseId": "123", "name": "search", "input": {"q": "test"}}} + ] } - ] + + Standard format output:: + + [ + {"role": "assistant", "content": [{"type": "text", "text": "Hello"}]}, + { + "role": "assistant", + "tool_calls": [ + { + "type": "function", + "id": "123", + "function": {"name": "search", "arguments": '{"q": "test"}'} + } + ] + } + ] """ role = obj.get("role") content = obj.get("content") @@ -294,23 +311,38 @@ class AWSBedrockLLMContext(OpenAILLMContext): Empty text content is converted to "(empty)". Args: - message: Message in standard format: - { - "role": "user/assistant/tool", - "content": str | [{"type": "text", ...}], - "tool_calls": [{"id": str, "function": {"name": str, "arguments": str}}] - } + message: Message in standard format. 
Returns: - Message in AWS Bedrock format: - { - "role": "user/assistant", - "content": [ - {"text": str} | - {"toolUse": {"toolUseId": str, "name": str, "input": dict}} | - {"toolResult": {"toolUseId": str, "content": [...], "status": str}} - ] - } + Message in AWS Bedrock format. + + Examples: + Standard format input:: + + { + "role": "assistant", + "tool_calls": [ + { + "id": "123", + "function": {"name": "search", "arguments": '{"q": "test"}'} + } + ] + } + + AWS Bedrock format output:: + + { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "123", + "name": "search", + "input": {"q": "test"} + } + } + ] + } """ if message["role"] == "tool": # Try to parse the content as JSON if it looks like JSON diff --git a/src/pipecat/services/aws/utils.py b/src/pipecat/services/aws/utils.py index cfd36b417..a2559d8bc 100644 --- a/src/pipecat/services/aws/utils.py +++ b/src/pipecat/services/aws/utils.py @@ -39,10 +39,8 @@ def get_presigned_url( Args: region: AWS region for the service. - credentials: Dictionary containing AWS credentials with keys: - - access_key: AWS access key ID - - secret_key: AWS secret access key - - session_token: AWS session token (optional) + credentials: Dictionary containing AWS credentials. Must include + 'access_key' and 'secret_key', with optional 'session_token'. language_code: Language code for transcription (e.g., "en-US"). media_encoding: Audio encoding format. Defaults to "pcm". sample_rate: Audio sample rate in Hz. Defaults to 16000. @@ -325,9 +323,10 @@ def decode_event(message): message: Raw event stream message bytes received from AWS. Returns: - Tuple containing: - - Dictionary of parsed headers - - Dictionary of parsed JSON payload + A tuple of (headers, payload) where: + + - headers: Dictionary of parsed headers + - payload: Dictionary of parsed JSON payload Raises: AssertionError: If CRC checksum verification fails. 
diff --git a/src/pipecat/services/elevenlabs/tts.py b/src/pipecat/services/elevenlabs/tts.py index 7153c39c5..13cb08255 100644 --- a/src/pipecat/services/elevenlabs/tts.py +++ b/src/pipecat/services/elevenlabs/tts.py @@ -764,21 +764,23 @@ class ElevenLabsHttpTTSService(WordTTSService): def calculate_word_times(self, alignment_info: Mapping[str, Any]) -> List[Tuple[str, float]]: """Calculate word timing from character alignment data. - Example input data: - { - "characters": [" ", "H", "e", "l", "l", "o", " ", "w", "o", "r", "l", "d"], - "character_start_times_seconds": [0.0, 0.1, 0.15, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], - "character_end_times_seconds": [0.1, 0.15, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] - } - - Would produce word times (with cumulative_time=0): - [("Hello", 0.1), ("world", 0.5)] - Args: alignment_info: Character timing data from ElevenLabs. Returns: List of (word, timestamp) pairs. + + Example input data:: + + { + "characters": [" ", "H", "e", "l", "l", "o", " ", "w", "o", "r", "l", "d"], + "character_start_times_seconds": [0.0, 0.1, 0.15, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + "character_end_times_seconds": [0.1, 0.15, 0.2, 0.25, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0] + } + + Would produce word times (with cumulative_time=0):: + + [("Hello", 0.1), ("world", 0.5)] """ chars = alignment_info.get("characters", []) char_start_times = alignment_info.get("character_start_times_seconds", []) diff --git a/src/pipecat/services/google/google.py b/src/pipecat/services/google/google.py index ec187000f..3b4f03a86 100644 --- a/src/pipecat/services/google/google.py +++ b/src/pipecat/services/google/google.py @@ -4,6 +4,8 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Google services module for Pipecat.""" + import sys from pipecat.services import DeprecatedModuleProxy diff --git a/src/pipecat/services/google/llm.py b/src/pipecat/services/google/llm.py index 86ed4dd88..b4ddb3c00 100644 --- 
a/src/pipecat/services/google/llm.py +++ b/src/pipecat/services/google/llm.py @@ -380,18 +380,48 @@ class GoogleLLMContext(OpenAILLMContext): System messages are stored separately and return None. Args: - message: Message in standard format: - { - "role": "user/assistant/system/tool", - "content": str | [{"type": "text/image_url", ...}] | None, - "tool_calls": [{"function": {"name": str, "arguments": str}}] - } + message: Message in standard format. Returns: - Content object with: - - role: "user" or "model" (converted from "assistant") - - parts: List[Part] containing text, inline_data, or function calls - Returns None for system messages. + Content object with role and parts, or None for system messages. + + Examples: + Standard text message:: + + { + "role": "user", + "content": "Hello there" + } + + Converts to Google Content with:: + + Content( + role="user", + parts=[Part(text="Hello there")] + ) + + Standard function call message:: + + { + "role": "assistant", + "tool_calls": [ + { + "function": { + "name": "search", + "arguments": '{"query": "test"}' + } + } + ] + } + + Converts to Google Content with:: + + Content( + role="model", + parts=[Part(function_call=FunctionCall(name="search", args={"query": "test"}))] + ) + + System message returns None and stores content in self.system_message. """ role = message["role"] content = message.get("content", []) @@ -447,21 +477,73 @@ class GoogleLLMContext(OpenAILLMContext): Handles text, images, and function calls from Google's Content/Part objects. Args: - obj: Google Content object with: - - role: "model" (converted to "assistant") or "user" - - parts: List[Part] containing text, inline_data, or function calls + obj: Google Content object with role and parts. Returns: - List of messages in standard format: - [ - { - "role": "user/assistant/tool", - "content": [ - {"type": "text", "text": str} | - {"type": "image_url", "image_url": {"url": str}} - ] - } - ] + List containing a single message in standard format. 
+ + Examples: + Google Content with text:: + + Content( + role="user", + parts=[Part(text="Hello")] + ) + + Converts to:: + + [ + { + "role": "user", + "content": [{"type": "text", "text": "Hello"}] + } + ] + + Google Content with function call:: + + Content( + role="model", + parts=[Part(function_call=FunctionCall(name="search", args={"q": "test"}))] + ) + + Converts to:: + + [ + { + "role": "assistant", + "tool_calls": [ + { + "id": "search", + "type": "function", + "function": { + "name": "search", + "arguments": '{"q": "test"}' + } + } + ] + } + ] + + Google Content with image:: + + Content( + role="user", + parts=[Part(inline_data=Blob(mime_type="image/jpeg", data=bytes_data))] + ) + + Converts to:: + + [ + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": "data:image/jpeg;base64,"} + } + ] + } + ] """ msg = {"role": obj.role, "content": []} if msg["role"] == "model": diff --git a/src/pipecat/services/google/tts.py b/src/pipecat/services/google/tts.py index 40de8afff..0bd201c4b 100644 --- a/src/pipecat/services/google/tts.py +++ b/src/pipecat/services/google/tts.py @@ -471,8 +471,8 @@ class GoogleTTSService(TTSService): default application credentials (GOOGLE_APPLICATION_CREDENTIALS env var). Only Chirp 3 HD and Journey voices are supported. Use GoogleHttpTTSService for other voices. - Example: - ```python + Example:: + tts = GoogleTTSService( credentials_path="/path/to/service-account.json", voice_id="en-US-Chirp3-HD-Charon", @@ -480,7 +480,6 @@ class GoogleTTSService(TTSService): language=Language.EN_US, ) ) - ``` """ class InputParams(BaseModel): diff --git a/src/pipecat/services/llm_service.py b/src/pipecat/services/llm_service.py index 6ef85951a..68a74335b 100644 --- a/src/pipecat/services/llm_service.py +++ b/src/pipecat/services/llm_service.py @@ -128,13 +128,14 @@ class LLMService(AIService): parallel and sequential execution modes. Provides event handlers for completion timeouts and function call lifecycle events. 
- Event handlers: - on_completion_timeout: Called when an LLM completion timeout occurs. - on_function_calls_started: Called when function calls are received and - execution is about to start. + The service supports the following event handlers: + + - on_completion_timeout: Called when an LLM completion timeout occurs + - on_function_calls_started: Called when function calls are received and + execution is about to start + + Example:: - Example: - ```python @task.event_handler("on_completion_timeout") async def on_completion_timeout(service): logger.warning("LLM completion timed out") @@ -142,7 +143,6 @@ class LLMService(AIService): @task.event_handler("on_function_calls_started") async def on_function_calls_started(service, function_calls): logger.info(f"Starting {len(function_calls)} function calls") - ``` """ # OpenAILLMAdapter is used as the default adapter since it aligns with most LLM implementations. diff --git a/src/pipecat/services/openai_realtime_beta/openai.py b/src/pipecat/services/openai_realtime_beta/openai.py index ce21e33ad..eb8925760 100644 --- a/src/pipecat/services/openai_realtime_beta/openai.py +++ b/src/pipecat/services/openai_realtime_beta/openai.py @@ -639,6 +639,7 @@ class OpenAIRealtimeBetaLLMService(LLMService): """Maybe handle an error event related to retrieving a conversation item. If the given error event is an error retrieving a conversation item: + - set an exception on the future that retrieve_conversation_item() is waiting on - return true Otherwise: diff --git a/src/pipecat/services/sarvam/tts.py b/src/pipecat/services/sarvam/tts.py index eee4048cb..51fd4f07b 100644 --- a/src/pipecat/services/sarvam/tts.py +++ b/src/pipecat/services/sarvam/tts.py @@ -59,8 +59,8 @@ class SarvamTTSService(TTSService): Indian languages. Provides control over voice characteristics like pitch, pace, and loudness. 
- Example: - ```python + Example:: + tts = SarvamTTSService( api_key="your-api-key", voice_id="anushka", @@ -72,7 +72,6 @@ class SarvamTTSService(TTSService): pace=1.2 ) ) - ``` """ class InputParams(BaseModel): diff --git a/src/pipecat/services/tavus/video.py b/src/pipecat/services/tavus/video.py index 999d712d0..9f40b709d 100644 --- a/src/pipecat/services/tavus/video.py +++ b/src/pipecat/services/tavus/video.py @@ -42,6 +42,7 @@ class TavusVideoService(AIService): are routed through Pipecat's media pipeline. In use cases with DailyTransport, this creates two distinct virtual rooms: + - Tavus room: Contains the Tavus Avatar and the Pipecat Bot - User room: Contains the Pipecat Bot and the user """ diff --git a/src/pipecat/services/tts_service.py b/src/pipecat/services/tts_service.py index 183ef1f19..1d97045fe 100644 --- a/src/pipecat/services/tts_service.py +++ b/src/pipecat/services/tts_service.py @@ -549,12 +549,11 @@ class WebsocketTTSService(TTSService, WebsocketService): Event handlers: on_connection_error: Called when a websocket connection error occurs. - Example: - ```python + Example:: + @tts.event_handler("on_connection_error") async def on_connection_error(tts: TTSService, error: str): logger.error(f"TTS connection error: {error}") - ``` """ def __init__(self, *, reconnect_on_error: bool = True, **kwargs): @@ -622,12 +621,11 @@ class WebsocketWordTTSService(WordTTSService, WebsocketService): Event handlers: on_connection_error: Called when a websocket connection error occurs. 
- Example: - ```python + Example:: + @tts.event_handler("on_connection_error") async def on_connection_error(tts: TTSService, error: str): logger.error(f"TTS connection error: {error}") - ``` """ def __init__(self, *, reconnect_on_error: bool = True, **kwargs): diff --git a/src/pipecat/sync/base_notifier.py b/src/pipecat/sync/base_notifier.py index 69a27b445..60321e282 100644 --- a/src/pipecat/sync/base_notifier.py +++ b/src/pipecat/sync/base_notifier.py @@ -4,14 +4,33 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Base notifier interface for Pipecat.""" + from abc import ABC, abstractmethod class BaseNotifier(ABC): + """Abstract base class for notification mechanisms. + + Provides a standard interface for implementing notification and waiting + patterns used for event coordination and signaling between components + in the Pipecat framework. + """ + @abstractmethod async def notify(self): + """Send a notification signal. + + Implementations should trigger any waiting coroutines or processes + that are blocked on this notifier. + """ pass @abstractmethod async def wait(self): + """Wait for a notification signal. + + Implementations should block until a notification is received + from the corresponding notify() call. + """ pass diff --git a/src/pipecat/sync/event_notifier.py b/src/pipecat/sync/event_notifier.py index f62ba4f6f..6708c2404 100644 --- a/src/pipecat/sync/event_notifier.py +++ b/src/pipecat/sync/event_notifier.py @@ -4,18 +4,42 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Event-based notifier implementation using asyncio Event primitives.""" + import asyncio from pipecat.sync.base_notifier import BaseNotifier class EventNotifier(BaseNotifier): + """Event-based notifier using asyncio.Event for task synchronization. + + Provides a simple notification mechanism where one task can signal + an event and other tasks can wait for that event to occur. The event + is automatically cleared after each wait operation. 
+ """ + def __init__(self): + """Initialize the event notifier. + + Creates an internal asyncio.Event for managing notifications. + """ self._event = asyncio.Event() async def notify(self): + """Signal the event to notify waiting tasks. + + Sets the internal event, causing any tasks waiting on this + notifier to be awakened. + """ self._event.set() async def wait(self): + """Wait for the event to be signaled. + + Blocks until another task calls notify(). Automatically clears + the event after being awakened so subsequent calls will wait + for the next notification. + """ await self._event.wait() self._event.clear() diff --git a/src/pipecat/tests/utils.py b/src/pipecat/tests/utils.py index 3ea52bf26..2c0f65b2d 100644 --- a/src/pipecat/tests/utils.py +++ b/src/pipecat/tests/utils.py @@ -4,6 +4,8 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Testing utilities for Pipecat pipeline components.""" + import asyncio from dataclasses import dataclass from typing import Any, Awaitable, Callable, Dict, List, Optional, Sequence, Tuple @@ -24,15 +26,27 @@ from pipecat.processors.frame_processor import FrameDirection, FrameProcessor @dataclass class SleepFrame(SystemFrame): - """This frame is used by test framework to introduce some sleep time before - the next frame is pushed. This is useful to control system frames vs data or - control frames. + """A system frame that introduces a sleep delay in the test pipeline. + + This frame is used by the test framework to control timing between + frame processing, allowing tests to separate system frames from + data or control frames. + + Parameters: + sleep: Duration to sleep in seconds before processing the next frame. """ sleep: float = 0.1 class HeartbeatsObserver(BaseObserver): + """Observer that monitors heartbeat frames from a specific processor. + + This observer watches for HeartbeatFrames from a target processor and + invokes a callback when they are detected, useful for testing timing + and lifecycle events. 
+ """ + def __init__( self, *, @@ -40,11 +54,23 @@ class HeartbeatsObserver(BaseObserver): heartbeat_callback: Callable[[FrameProcessor, HeartbeatFrame], Awaitable[None]], **kwargs, ): + """Initialize the heartbeats observer. + + Args: + target: The frame processor to monitor for heartbeat frames. + heartbeat_callback: Async callback function to invoke when heartbeats are detected. + **kwargs: Additional arguments passed to the parent observer. + """ super().__init__(**kwargs) self._target = target self._callback = heartbeat_callback async def on_push_frame(self, data: FramePushed): + """Handle frame push events and detect heartbeats from target processor. + + Args: + data: The frame push event data containing source and frame information. + """ src = data.source frame = data.frame @@ -53,6 +79,13 @@ class HeartbeatsObserver(BaseObserver): class QueuedFrameProcessor(FrameProcessor): + """A processor that captures frames in a queue for testing purposes. + + This processor intercepts frames flowing in a specific direction and + stores them in a queue for later inspection during testing, while + still allowing the frames to continue through the pipeline. + """ + def __init__( self, *, @@ -60,12 +93,25 @@ class QueuedFrameProcessor(FrameProcessor): queue_direction: FrameDirection, ignore_start: bool = True, ): + """Initialize the queued frame processor. + + Args: + queue: The asyncio queue to store captured frames. + queue_direction: The direction of frames to capture (UPSTREAM or DOWNSTREAM). + ignore_start: Whether to ignore StartFrames when capturing. + """ super().__init__() self._queue = queue self._queue_direction = queue_direction self._ignore_start = ignore_start async def process_frame(self, frame: Frame, direction: FrameDirection): + """Process frames and capture them in the queue if they match the direction. + + Args: + frame: The frame to process. + direction: The direction the frame is flowing. 
+ """ await super().process_frame(frame, direction) if direction == self._queue_direction: @@ -85,6 +131,28 @@ async def run_test( start_metadata: Optional[Dict[str, Any]] = None, send_end_frame: bool = True, ) -> Tuple[Sequence[Frame], Sequence[Frame]]: + """Run a test pipeline with the specified processor and validate frame flow. + + This function creates a test pipeline with the given processor, sends the + specified frames through it, and validates that the expected frames are + received in both upstream and downstream directions. + + Args: + processor: The frame processor to test. + frames_to_send: Sequence of frames to send through the processor. + expected_down_frames: Expected frame types flowing downstream (optional). + expected_up_frames: Expected frame types flowing upstream (optional). + ignore_start: Whether to ignore StartFrames in frame validation. + observers: Optional list of observers to attach to the pipeline. + start_metadata: Optional metadata to include with the StartFrame. + send_end_frame: Whether to send an EndFrame at the end of the test. + + Returns: + Tuple containing (downstream_frames, upstream_frames) that were received. + + Raises: + AssertionError: If the received frames don't match the expected frame types. + """ observers = observers or [] start_metadata = start_metadata or {} diff --git a/src/pipecat/transports/base_input.py b/src/pipecat/transports/base_input.py index 93cd90825..5b71ba69c 100644 --- a/src/pipecat/transports/base_input.py +++ b/src/pipecat/transports/base_input.py @@ -4,6 +4,12 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Base input transport implementation for Pipecat. + +This module provides the BaseInputTransport class which handles audio and video +input processing, including VAD, turn analysis, and interruption management. 
+""" + import asyncio from concurrent.futures import ThreadPoolExecutor from typing import Optional @@ -47,7 +53,20 @@ AUDIO_INPUT_TIMEOUT_SECS = 0.5 class BaseInputTransport(FrameProcessor): + """Base class for input transport implementations. + + Handles audio and video input processing including Voice Activity Detection, + turn analysis, audio filtering, and user interaction management. Supports + interruption handling and provides hooks for transport-specific implementations. + """ + def __init__(self, params: TransportParams, **kwargs): + """Initialize the base input transport. + + Args: + params: Transport configuration parameters. + **kwargs: Additional arguments passed to parent class. + """ super().__init__(**kwargs) self._params = params @@ -115,25 +134,54 @@ class BaseInputTransport(FrameProcessor): self._params.video_out_color_format = self._params.camera_out_color_format def enable_audio_in_stream_on_start(self, enabled: bool) -> None: + """Enable or disable audio streaming on transport start. + + Args: + enabled: Whether to start audio streaming immediately on transport start. + """ logger.debug(f"Enabling audio on start. {enabled}") self._params.audio_in_stream_on_start = enabled async def start_audio_in_streaming(self): + """Start audio input streaming. + + Override in subclasses to implement transport-specific audio streaming. + """ pass @property def sample_rate(self) -> int: + """Get the current audio sample rate. + + Returns: + The sample rate in Hz. + """ return self._sample_rate @property def vad_analyzer(self) -> Optional[VADAnalyzer]: + """Get the Voice Activity Detection analyzer. + + Returns: + The VAD analyzer instance if configured, None otherwise. + """ return self._params.vad_analyzer @property def turn_analyzer(self) -> Optional[BaseTurnAnalyzer]: + """Get the turn-taking analyzer. + + Returns: + The turn analyzer instance if configured, None otherwise. 
+ """ return self._params.turn_analyzer async def start(self, frame: StartFrame): + """Start the input transport and initialize components. + + Args: + frame: The start frame containing initialization parameters. + """ self._paused = False self._user_speaking = False @@ -152,6 +200,11 @@ class BaseInputTransport(FrameProcessor): await self._params.audio_in_filter.start(self._sample_rate) async def stop(self, frame: EndFrame): + """Stop the input transport and cleanup resources. + + Args: + frame: The end frame signaling transport shutdown. + """ # Cancel and wait for the audio input task to finish. await self._cancel_audio_task() # Stop audio filter. @@ -159,6 +212,11 @@ class BaseInputTransport(FrameProcessor): await self._params.audio_in_filter.stop() async def pause(self, frame: StopFrame): + """Pause the input transport temporarily. + + Args: + frame: The stop frame signaling transport pause. + """ self._paused = True # Cancel task so we clear the queue await self._cancel_audio_task() @@ -166,19 +224,38 @@ class BaseInputTransport(FrameProcessor): self._create_audio_task() async def cancel(self, frame: CancelFrame): + """Cancel the input transport and stop all processing. + + Args: + frame: The cancel frame signaling immediate cancellation. + """ # Cancel and wait for the audio input task to finish. await self._cancel_audio_task() async def set_transport_ready(self, frame: StartFrame): - """To be called when the transport is ready to stream.""" + """Called when the transport is ready to stream. + + Args: + frame: The start frame containing initialization parameters. + """ # Create audio input queue and task if needed. self._create_audio_task() async def push_video_frame(self, frame: InputImageRawFrame): + """Push a video frame downstream if video input is enabled. + + Args: + frame: The input video frame to process. 
+ """ if self._params.video_in_enabled and not self._paused: await self.push_frame(frame) async def push_audio_frame(self, frame: InputAudioRawFrame): + """Push an audio frame to the processing queue if audio input is enabled. + + Args: + frame: The input audio frame to process. + """ if self._params.audio_in_enabled and not self._paused: await self._audio_in_queue.put(frame) @@ -187,6 +264,12 @@ class BaseInputTransport(FrameProcessor): # async def process_frame(self, frame: Frame, direction: FrameDirection): + """Process incoming frames and handle transport-specific logic. + + Args: + frame: The frame to process. + direction: The direction of frame flow in the pipeline. + """ await super().process_frame(frame, direction) # Specific system frames @@ -238,12 +321,14 @@ class BaseInputTransport(FrameProcessor): # async def _handle_bot_interruption(self, frame: BotInterruptionFrame): + """Handle bot interruption frames.""" logger.debug("Bot interruption") if self.interruptions_allowed: await self._start_interruption() await self.push_frame(StartInterruptionFrame()) async def _handle_user_interruption(self, frame: Frame): + """Handle user interruption events based on speaking state.""" if isinstance(frame, UserStartedSpeakingFrame): logger.debug("User started speaking") self._user_speaking = True @@ -281,9 +366,11 @@ class BaseInputTransport(FrameProcessor): # async def _handle_bot_started_speaking(self, frame: BotStartedSpeakingFrame): + """Update bot speaking state when bot starts speaking.""" self._bot_speaking = True async def _handle_bot_stopped_speaking(self, frame: BotStoppedSpeakingFrame): + """Update bot speaking state when bot stops speaking.""" self._bot_speaking = False # @@ -291,16 +378,19 @@ class BaseInputTransport(FrameProcessor): # def _create_audio_task(self): + """Create the audio processing task if audio input is enabled.""" if not self._audio_task and self._params.audio_in_enabled: self._audio_in_queue = asyncio.Queue() self._audio_task = 
self.create_task(self._audio_task_handler()) async def _cancel_audio_task(self): + """Cancel and cleanup the audio processing task.""" if self._audio_task: await self.cancel_task(self._audio_task) self._audio_task = None async def _vad_analyze(self, audio_frame: InputAudioRawFrame) -> VADState: + """Analyze audio frame for voice activity.""" state = VADState.QUIET if self.vad_analyzer: state = await self.get_event_loop().run_in_executor( @@ -309,6 +399,7 @@ class BaseInputTransport(FrameProcessor): return state async def _handle_vad(self, audio_frame: InputAudioRawFrame, vad_state: VADState): + """Handle Voice Activity Detection results and generate appropriate frames.""" new_vad_state = await self._vad_analyze(audio_frame) if ( new_vad_state != vad_state @@ -339,18 +430,21 @@ class BaseInputTransport(FrameProcessor): return vad_state async def _handle_end_of_turn(self): + """Handle end-of-turn analysis and generate prediction results.""" if self.turn_analyzer: state, prediction = await self.turn_analyzer.analyze_end_of_turn() await self._handle_prediction_result(prediction) await self._handle_end_of_turn_complete(state) async def _handle_end_of_turn_complete(self, state: EndOfTurnState): + """Handle completion of end-of-turn analysis.""" if state == EndOfTurnState.COMPLETE: await self._handle_user_interruption(UserStoppedSpeakingFrame()) async def _run_turn_analyzer( self, frame: InputAudioRawFrame, vad_state: VADState, previous_vad_state: VADState ): + """Run turn analysis on audio frame and handle results.""" is_speech = vad_state == VADState.SPEAKING or vad_state == VADState.STARTING # If silence exceeds threshold, we are going to receive EndOfTurnState.COMPLETE end_of_turn_state = self._params.turn_analyzer.append_audio(frame.audio, is_speech) @@ -361,6 +455,7 @@ class BaseInputTransport(FrameProcessor): await self._handle_end_of_turn() async def _audio_task_handler(self): + """Main audio processing task handler for VAD and turn analysis.""" vad_state: 
VADState = VADState.QUIET while True: try: @@ -399,9 +494,5 @@ class BaseInputTransport(FrameProcessor): self.reset_watchdog() async def _handle_prediction_result(self, result: MetricsData): - """Handle a prediction result event from the turn analyzer. - - Args: - result: The prediction result MetricsData. - """ + """Handle a prediction result event from the turn analyzer.""" await self.push_frame(MetricsFrame(data=[result])) diff --git a/src/pipecat/transports/base_output.py b/src/pipecat/transports/base_output.py index 36d0536d7..f90b4c553 100644 --- a/src/pipecat/transports/base_output.py +++ b/src/pipecat/transports/base_output.py @@ -4,6 +4,12 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Base output transport implementation for Pipecat. + +This module provides the BaseOutputTransport class which handles audio and video +output processing, including frame buffering, mixing, timing, and media streaming. +""" + import asyncio import itertools import sys @@ -46,7 +52,20 @@ BOT_VAD_STOP_SECS = 0.35 class BaseOutputTransport(FrameProcessor): + """Base class for output transport implementations. + + Handles audio and video output processing including frame buffering, audio mixing, + timing coordination, and media streaming. Supports multiple output destinations + and provides interruption handling for real-time communication. + """ + def __init__(self, params: TransportParams, **kwargs): + """Initialize the base output transport. + + Args: + params: Transport configuration parameters. + **kwargs: Additional arguments passed to parent class. + """ super().__init__(**kwargs) self._params = params @@ -67,13 +86,28 @@ class BaseOutputTransport(FrameProcessor): @property def sample_rate(self) -> int: + """Get the current audio sample rate. + + Returns: + The sample rate in Hz. + """ return self._sample_rate @property def audio_chunk_size(self) -> int: + """Get the audio chunk size for output processing. + + Returns: + The size of audio chunks in bytes. 
+ """ return self._audio_chunk_size async def start(self, frame: StartFrame): + """Start the output transport and initialize components. + + Args: + frame: The start frame containing initialization parameters. + """ self._sample_rate = self._params.audio_out_sample_rate or frame.audio_out_sample_rate # We will write 10ms*CHUNKS of audio at a time (where CHUNKS is the @@ -83,15 +117,29 @@ class BaseOutputTransport(FrameProcessor): self._audio_chunk_size = audio_bytes_10ms * self._params.audio_out_10ms_chunks async def stop(self, frame: EndFrame): + """Stop the output transport and cleanup resources. + + Args: + frame: The end frame signaling transport shutdown. + """ for _, sender in self._media_senders.items(): await sender.stop(frame) async def cancel(self, frame: CancelFrame): + """Cancel the output transport and stop all processing. + + Args: + frame: The cancel frame signaling immediate cancellation. + """ for _, sender in self._media_senders.items(): await sender.cancel(frame) async def set_transport_ready(self, frame: StartFrame): - """To be called when the transport is ready to stream.""" + """Called when the transport is ready to stream. + + Args: + frame: The start frame containing initialization parameters. + """ # Register destinations. for destination in self._params.audio_out_destinations: await self.register_audio_destination(destination) @@ -127,27 +175,67 @@ class BaseOutputTransport(FrameProcessor): await self._media_senders[destination].start(frame) async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame): + """Send a transport message. + + Args: + frame: The transport message frame to send. + """ pass async def register_video_destination(self, destination: str): + """Register a video output destination. + + Args: + destination: The destination identifier to register. + """ pass async def register_audio_destination(self, destination: str): + """Register an audio output destination. 
+ + Args: + destination: The destination identifier to register. + """ pass async def write_video_frame(self, frame: OutputImageRawFrame): + """Write a video frame to the transport. + + Args: + frame: The output video frame to write. + """ pass async def write_audio_frame(self, frame: OutputAudioRawFrame): + """Write an audio frame to the transport. + + Args: + frame: The output audio frame to write. + """ pass async def write_dtmf(self, frame: OutputDTMFFrame | OutputDTMFUrgentFrame): + """Write a DTMF tone to the transport. + + Args: + frame: The DTMF frame to write. + """ pass async def send_audio(self, frame: OutputAudioRawFrame): + """Send an audio frame downstream. + + Args: + frame: The audio frame to send. + """ await self.queue_frame(frame, FrameDirection.DOWNSTREAM) async def send_image(self, frame: OutputImageRawFrame | SpriteFrame): + """Send an image frame downstream. + + Args: + frame: The image frame to send. + """ await self.queue_frame(frame, FrameDirection.DOWNSTREAM) # @@ -155,6 +243,12 @@ class BaseOutputTransport(FrameProcessor): # async def process_frame(self, frame: Frame, direction: FrameDirection): + """Process incoming frames and handle transport-specific logic. + + Args: + frame: The frame to process. + direction: The direction of frame flow in the pipeline. + """ await super().process_frame(frame, direction) # @@ -200,6 +294,7 @@ class BaseOutputTransport(FrameProcessor): await self._handle_frame(frame) async def _handle_frame(self, frame: Frame): + """Handle frames by routing them to appropriate media senders.""" if frame.transport_destination not in self._media_senders: logger.warning( f"{self} destination [{frame.transport_destination}] not registered for frame {frame}" @@ -226,6 +321,12 @@ class BaseOutputTransport(FrameProcessor): # class MediaSender: + """Handles media streaming for a specific destination. 
+ + Manages audio and video output processing including buffering, timing, + mixing, and frame delivery for a single output destination. + """ + def __init__( self, transport: "BaseOutputTransport", @@ -235,6 +336,15 @@ class BaseOutputTransport(FrameProcessor): audio_chunk_size: int, params: TransportParams, ): + """Initialize the media sender. + + Args: + transport: The parent transport instance. + destination: The destination identifier for this sender. + sample_rate: The audio sample rate in Hz. + audio_chunk_size: The size of audio chunks in bytes. + params: Transport configuration parameters. + """ self._transport = transport self._destination = destination self._sample_rate = sample_rate @@ -266,13 +376,28 @@ class BaseOutputTransport(FrameProcessor): @property def sample_rate(self) -> int: + """Get the audio sample rate. + + Returns: + The sample rate in Hz. + """ return self._sample_rate @property def audio_chunk_size(self) -> int: + """Get the audio chunk size. + + Returns: + The size of audio chunks in bytes. + """ return self._audio_chunk_size async def start(self, frame: StartFrame): + """Start the media sender and initialize components. + + Args: + frame: The start frame containing initialization parameters. + """ self._audio_buffer = bytearray() # Create all tasks. @@ -293,6 +418,11 @@ class BaseOutputTransport(FrameProcessor): await self._mixer.start(self._sample_rate) async def stop(self, frame: EndFrame): + """Stop the media sender and cleanup resources. + + Args: + frame: The end frame signaling sender shutdown. + """ # Let the sink tasks process the queue until they reach this EndFrame. await self._clock_queue.put((sys.maxsize, frame.id, frame)) await self._audio_queue.put(frame) @@ -314,12 +444,22 @@ class BaseOutputTransport(FrameProcessor): await self._cancel_video_task() async def cancel(self, frame: CancelFrame): + """Cancel the media sender and stop all processing. + + Args: + frame: The cancel frame signaling immediate cancellation. 
+ """ # Since we are cancelling everything it doesn't matter what task we cancel first. await self._cancel_audio_task() await self._cancel_clock_task() await self._cancel_video_task() async def handle_interruptions(self, _: StartInterruptionFrame): + """Handle interruption events by restarting tasks and clearing buffers. + + Args: + _: The start interruption frame (unused). + """ if not self._transport.interruptions_allowed: return @@ -335,6 +475,11 @@ class BaseOutputTransport(FrameProcessor): await self._bot_stopped_speaking() async def handle_audio_frame(self, frame: OutputAudioRawFrame): + """Handle incoming audio frames by buffering and chunking. + + Args: + frame: The output audio frame to handle. + """ if not self._params.audio_out_enabled: return @@ -357,6 +502,11 @@ class BaseOutputTransport(FrameProcessor): self._audio_buffer = self._audio_buffer[self._audio_chunk_size :] async def handle_image_frame(self, frame: OutputImageRawFrame | SpriteFrame): + """Handle incoming image frames for video output. + + Args: + frame: The output image or sprite frame to handle. + """ if not self._params.video_out_enabled: return @@ -368,12 +518,27 @@ class BaseOutputTransport(FrameProcessor): await self._set_video_images(frame.images) async def handle_timed_frame(self, frame: Frame): + """Handle frames with presentation timestamps. + + Args: + frame: The frame with timing information to handle. + """ await self._clock_queue.put((frame.pts, frame.id, frame)) async def handle_sync_frame(self, frame: Frame): + """Handle frames that need synchronized processing. + + Args: + frame: The frame to handle synchronously. + """ await self._audio_queue.put(frame) async def handle_mixer_control_frame(self, frame: MixerControlFrame): + """Handle audio mixer control frames. + + Args: + frame: The mixer control frame to handle. 
+ """ if self._mixer: await self._mixer.process_frame(frame) @@ -382,16 +547,19 @@ class BaseOutputTransport(FrameProcessor): # def _create_audio_task(self): + """Create the audio processing task.""" if not self._audio_task: self._audio_queue = asyncio.Queue() self._audio_task = self._transport.create_task(self._audio_task_handler()) async def _cancel_audio_task(self): + """Cancel and cleanup the audio processing task.""" if self._audio_task: await self._transport.cancel_task(self._audio_task) self._audio_task = None async def _bot_started_speaking(self): + """Handle bot started speaking event.""" if not self._bot_speaking: logger.debug( f"Bot{f' [{self._destination}]' if self._destination else ''} started speaking" @@ -407,6 +575,7 @@ class BaseOutputTransport(FrameProcessor): self._bot_speaking = True async def _bot_stopped_speaking(self): + """Handle bot stopped speaking event.""" if self._bot_speaking: logger.debug( f"Bot{f' [{self._destination}]' if self._destination else ''} stopped speaking" @@ -426,6 +595,11 @@ class BaseOutputTransport(FrameProcessor): self._audio_buffer = bytearray() async def _handle_frame(self, frame: Frame): + """Handle various frame types with appropriate processing. + + Args: + frame: The frame to handle. + """ if isinstance(frame, OutputImageRawFrame): await self._set_video_image(frame) elif isinstance(frame, SpriteFrame): @@ -436,6 +610,12 @@ class BaseOutputTransport(FrameProcessor): await self._transport.write_dtmf(frame) def _next_frame(self) -> AsyncGenerator[Frame, None]: + """Generate the next frame for audio processing. + + Returns: + An async generator yielding frames for processing. 
+ """ + async def without_mixer(vad_stop_secs: float) -> AsyncGenerator[Frame, None]: while True: try: @@ -480,6 +660,7 @@ class BaseOutputTransport(FrameProcessor): return without_mixer(BOT_VAD_STOP_SECS) async def _audio_task_handler(self): + """Main audio processing task handler.""" # Push a BotSpeakingFrame every 200ms, we don't really need to push it # at every audio chunk. If the audio chunk is bigger than 200ms, push at # every audio chunk. @@ -518,23 +699,36 @@ class BaseOutputTransport(FrameProcessor): # def _create_video_task(self): + """Create the video processing task if video output is enabled.""" if not self._video_task and self._params.video_out_enabled: self._video_queue = asyncio.Queue() self._video_task = self._transport.create_task(self._video_task_handler()) async def _cancel_video_task(self): + """Cancel and cleanup the video processing task.""" # Stop video output task. if self._video_task: await self._transport.cancel_task(self._video_task) self._video_task = None async def _set_video_image(self, image: OutputImageRawFrame): + """Set a single video image for cycling output. + + Args: + image: The image frame to cycle for video output. + """ self._video_images = itertools.cycle([image]) async def _set_video_images(self, images: List[OutputImageRawFrame]): + """Set multiple video images for cycling output. + + Args: + images: The list of image frames to cycle for video output. + """ self._video_images = itertools.cycle(images) async def _video_task_handler(self): + """Main video processing task handler.""" self._video_start_time = None self._video_frame_index = 0 self._video_frame_duration = 1 / self._params.video_out_framerate @@ -550,6 +744,7 @@ class BaseOutputTransport(FrameProcessor): await asyncio.sleep(self._video_frame_duration) async def _video_is_live_handler(self): + """Handle live video streaming with frame timing.""" image = await self._video_queue.get() # We get the start time as soon as we get the first image. 
@@ -575,6 +770,12 @@ class BaseOutputTransport(FrameProcessor): self._video_queue.task_done() async def _draw_image(self, frame: OutputImageRawFrame): + """Draw/render an image frame with resizing if needed. + + Args: + frame: The image frame to draw. + """ + def resize_frame(frame: OutputImageRawFrame) -> OutputImageRawFrame: desired_size = (self._params.video_out_width, self._params.video_out_height) @@ -601,16 +802,19 @@ class BaseOutputTransport(FrameProcessor): # def _create_clock_task(self): + """Create the clock/timing processing task.""" if not self._clock_task: self._clock_queue = WatchdogPriorityQueue(self._transport.task_manager) self._clock_task = self._transport.create_task(self._clock_task_handler()) async def _cancel_clock_task(self): + """Cancel and clean up the clock processing task.""" if self._clock_task: await self._transport.cancel_task(self._clock_task) self._clock_task = None async def _clock_task_handler(self): + """Main clock/timing task handler for timed frame delivery.""" running = True while running: timestamp, _, frame = await self._clock_queue.get() diff --git a/src/pipecat/transports/base_transport.py b/src/pipecat/transports/base_transport.py index c634babb8..8e722127f 100644 --- a/src/pipecat/transports/base_transport.py +++ b/src/pipecat/transports/base_transport.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Base transport classes for Pipecat. + +This module provides the foundation for transport implementations including +parameter configuration and abstract base classes for input/output transport +functionality. +""" + from abc import abstractmethod from typing import List, Mapping, Optional @@ -18,6 +25,45 @@ from pipecat.utils.base_object import BaseObject class TransportParams(BaseModel): + """Configuration parameters for transport implementations. + + Parameters: + camera_in_enabled: Enable camera input (deprecated, use video_in_enabled).
+ camera_out_enabled: Enable camera output (deprecated, use video_out_enabled). + camera_out_is_live: Enable real-time camera output (deprecated). + camera_out_width: Camera output width in pixels (deprecated). + camera_out_height: Camera output height in pixels (deprecated). + camera_out_bitrate: Camera output bitrate in bits per second (deprecated). + camera_out_framerate: Camera output frame rate in FPS (deprecated). + camera_out_color_format: Camera output color format string (deprecated). + audio_out_enabled: Enable audio output streaming. + audio_out_sample_rate: Output audio sample rate in Hz. + audio_out_channels: Number of output audio channels. + audio_out_bitrate: Output audio bitrate in bits per second. + audio_out_10ms_chunks: Number of 10ms chunks to buffer for output. + audio_out_mixer: Audio mixer instance or destination mapping. + audio_out_destinations: List of audio output destination identifiers. + audio_in_enabled: Enable audio input streaming. + audio_in_sample_rate: Input audio sample rate in Hz. + audio_in_channels: Number of input audio channels. + audio_in_filter: Audio filter to apply to input audio. + audio_in_stream_on_start: Start audio streaming immediately on transport start. + audio_in_passthrough: Pass through input audio frames downstream. + video_in_enabled: Enable video input streaming. + video_out_enabled: Enable video output streaming. + video_out_is_live: Enable real-time video output streaming. + video_out_width: Video output width in pixels. + video_out_height: Video output height in pixels. + video_out_bitrate: Video output bitrate in bits per second. + video_out_framerate: Video output frame rate in FPS. + video_out_color_format: Video output color format string. + video_out_destinations: List of video output destination identifiers. + vad_enabled: Enable Voice Activity Detection (deprecated). + vad_audio_passthrough: Enable VAD audio passthrough (deprecated). + vad_analyzer: Voice Activity Detection analyzer instance. 
+ turn_analyzer: Turn-taking analyzer instance for conversation management. + """ + model_config = ConfigDict(arbitrary_types_allowed=True) camera_in_enabled: bool = False @@ -57,6 +103,12 @@ class TransportParams(BaseModel): class BaseTransport(BaseObject): + """Base class for transport implementations. + + Provides the foundation for transport classes that handle media streaming, + including input and output frame processors for audio and video data. + """ + def __init__( self, *, @@ -64,14 +116,31 @@ class BaseTransport(BaseObject): input_name: Optional[str] = None, output_name: Optional[str] = None, ): + """Initialize the base transport. + + Args: + name: Optional name for the transport instance. + input_name: Optional name for the input processor. + output_name: Optional name for the output processor. + """ super().__init__(name=name) self._input_name = input_name self._output_name = output_name @abstractmethod def input(self) -> FrameProcessor: + """Get the input frame processor for this transport. + + Returns: + The frame processor that handles incoming frames. + """ pass @abstractmethod def output(self) -> FrameProcessor: + """Get the output frame processor for this transport. + + Returns: + The frame processor that handles outgoing frames. + """ pass diff --git a/src/pipecat/transports/local/audio.py b/src/pipecat/transports/local/audio.py index 1d5e80158..b38d1dd8e 100644 --- a/src/pipecat/transports/local/audio.py +++ b/src/pipecat/transports/local/audio.py @@ -4,6 +4,12 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Local audio transport implementation for Pipecat. + +This module provides a local audio transport that uses PyAudio for real-time +audio input and output through the system's default audio devices. 
+""" + import asyncio from concurrent.futures import ThreadPoolExecutor from typing import Optional @@ -27,14 +33,33 @@ except ModuleNotFoundError as e: class LocalAudioTransportParams(TransportParams): + """Configuration parameters for local audio transport. + + Parameters: + input_device_index: PyAudio device index for audio input. If None, uses default. + output_device_index: PyAudio device index for audio output. If None, uses default. + """ + input_device_index: Optional[int] = None output_device_index: Optional[int] = None class LocalAudioInputTransport(BaseInputTransport): + """Local audio input transport using PyAudio. + + Captures audio from the system's audio input device and converts it to + InputAudioRawFrame objects for processing in the pipeline. + """ + _params: LocalAudioTransportParams def __init__(self, py_audio: pyaudio.PyAudio, params: LocalAudioTransportParams): + """Initialize the local audio input transport. + + Args: + py_audio: PyAudio instance for audio device management. + params: Transport configuration parameters. + """ super().__init__(params) self._py_audio = py_audio @@ -42,6 +67,11 @@ class LocalAudioInputTransport(BaseInputTransport): self._sample_rate = 0 async def start(self, frame: StartFrame): + """Start the audio input stream. + + Args: + frame: The start frame containing initialization parameters. 
+ """ await super().start(frame) if self._in_stream: @@ -64,6 +94,7 @@ class LocalAudioInputTransport(BaseInputTransport): await self.set_transport_ready(frame) async def cleanup(self): + """Stop and cleanup the audio input stream.""" await super().cleanup() if self._in_stream: self._in_stream.stop_stream() @@ -71,6 +102,7 @@ class LocalAudioInputTransport(BaseInputTransport): self._in_stream = None def _audio_in_callback(self, in_data, frame_count, time_info, status): + """Callback function for PyAudio input stream.""" frame = InputAudioRawFrame( audio=in_data, sample_rate=self._sample_rate, @@ -83,9 +115,21 @@ class LocalAudioInputTransport(BaseInputTransport): class LocalAudioOutputTransport(BaseOutputTransport): + """Local audio output transport using PyAudio. + + Plays audio frames through the system's audio output device by converting + OutputAudioRawFrame objects to playable audio data. + """ + _params: LocalAudioTransportParams def __init__(self, py_audio: pyaudio.PyAudio, params: LocalAudioTransportParams): + """Initialize the local audio output transport. + + Args: + py_audio: PyAudio instance for audio device management. + params: Transport configuration parameters. + """ super().__init__(params) self._py_audio = py_audio @@ -97,6 +141,11 @@ class LocalAudioOutputTransport(BaseOutputTransport): self._executor = ThreadPoolExecutor(max_workers=1) async def start(self, frame: StartFrame): + """Start the audio output stream. + + Args: + frame: The start frame containing initialization parameters. 
+ """ await super().start(frame) if self._out_stream: @@ -116,6 +165,7 @@ class LocalAudioOutputTransport(BaseOutputTransport): await self.set_transport_ready(frame) async def cleanup(self): + """Stop and cleanup the audio output stream.""" await super().cleanup() if self._out_stream: self._out_stream.stop_stream() @@ -123,6 +173,11 @@ class LocalAudioOutputTransport(BaseOutputTransport): self._out_stream = None async def write_audio_frame(self, frame: OutputAudioRawFrame): + """Write an audio frame to the output stream. + + Args: + frame: The audio frame to write to the output device. + """ if self._out_stream: await self.get_event_loop().run_in_executor( self._executor, self._out_stream.write, frame.audio @@ -130,7 +185,18 @@ class LocalAudioOutputTransport(BaseOutputTransport): class LocalAudioTransport(BaseTransport): + """Complete local audio transport with input and output capabilities. + + Provides a unified interface for local audio I/O using PyAudio, supporting + both audio capture and playback through the system's audio devices. + """ + def __init__(self, params: LocalAudioTransportParams): + """Initialize the local audio transport. + + Args: + params: Transport configuration parameters. + """ super().__init__() self._params = params self._pyaudio = pyaudio.PyAudio() @@ -143,11 +209,21 @@ class LocalAudioTransport(BaseTransport): # def input(self) -> FrameProcessor: + """Get the input frame processor for this transport. + + Returns: + The audio input transport processor. + """ if not self._input: self._input = LocalAudioInputTransport(self._pyaudio, self._params) return self._input def output(self) -> FrameProcessor: + """Get the output frame processor for this transport. + + Returns: + The audio output transport processor. 
+ """ if not self._output: self._output = LocalAudioOutputTransport(self._pyaudio, self._params) return self._output diff --git a/src/pipecat/transports/local/tk.py b/src/pipecat/transports/local/tk.py index 73c17853a..c687370ce 100644 --- a/src/pipecat/transports/local/tk.py +++ b/src/pipecat/transports/local/tk.py @@ -4,6 +4,12 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Tkinter-based local transport implementation for Pipecat. + +This module provides a local transport using Tkinter for video display and +PyAudio for audio I/O, suitable for desktop applications and testing. +""" + import asyncio import tkinter as tk from concurrent.futures import ThreadPoolExecutor @@ -40,20 +46,44 @@ except ModuleNotFoundError as e: class TkTransportParams(TransportParams): + """Configuration parameters for Tkinter transport. + + Parameters: + audio_input_device_index: PyAudio device index for audio input. If None, uses default. + audio_output_device_index: PyAudio device index for audio output. If None, uses default. + """ + audio_input_device_index: Optional[int] = None audio_output_device_index: Optional[int] = None class TkInputTransport(BaseInputTransport): + """Tkinter-based audio input transport. + + Captures audio from the system's audio input device using PyAudio and + converts it to InputAudioRawFrame objects for pipeline processing. + """ + _params: TkTransportParams def __init__(self, py_audio: pyaudio.PyAudio, params: TkTransportParams): + """Initialize the Tkinter input transport. + + Args: + py_audio: PyAudio instance for audio device management. + params: Transport configuration parameters. + """ super().__init__(params) self._py_audio = py_audio self._in_stream = None self._sample_rate = 0 async def start(self, frame: StartFrame): + """Start the audio input stream. + + Args: + frame: The start frame containing initialization parameters. 
+ """ await super().start(frame) if self._in_stream: @@ -76,6 +106,7 @@ class TkInputTransport(BaseInputTransport): await self.set_transport_ready(frame) async def cleanup(self): + """Stop and cleanup the audio input stream.""" await super().cleanup() if self._in_stream: self._in_stream.stop_stream() @@ -83,6 +114,7 @@ class TkInputTransport(BaseInputTransport): self._in_stream = None def _audio_in_callback(self, in_data, frame_count, time_info, status): + """Callback function for PyAudio input stream.""" frame = InputAudioRawFrame( audio=in_data, sample_rate=self._sample_rate, @@ -95,9 +127,22 @@ class TkInputTransport(BaseInputTransport): class TkOutputTransport(BaseOutputTransport): + """Tkinter-based audio and video output transport. + + Plays audio through PyAudio and displays video frames in a Tkinter window, + providing a complete multimedia output solution for desktop applications. + """ + _params: TkTransportParams def __init__(self, tk_root: tk.Tk, py_audio: pyaudio.PyAudio, params: TkTransportParams): + """Initialize the Tkinter output transport. + + Args: + tk_root: The root Tkinter window for video display. + py_audio: PyAudio instance for audio device management. + params: Transport configuration parameters. + """ super().__init__(params) self._py_audio = py_audio self._out_stream = None @@ -115,6 +160,11 @@ class TkOutputTransport(BaseOutputTransport): self._image_label.pack() async def start(self, frame: StartFrame): + """Start the audio output stream. + + Args: + frame: The start frame containing initialization parameters. 
+ """ await super().start(frame) if self._out_stream: @@ -134,6 +184,7 @@ class TkOutputTransport(BaseOutputTransport): await self.set_transport_ready(frame) async def cleanup(self): + """Stop and cleanup the audio output stream.""" await super().cleanup() if self._out_stream: self._out_stream.stop_stream() @@ -141,15 +192,26 @@ class TkOutputTransport(BaseOutputTransport): self._out_stream = None async def write_audio_frame(self, frame: OutputAudioRawFrame): + """Write an audio frame to the output stream. + + Args: + frame: The audio frame to write to the output device. + """ if self._out_stream: await self.get_event_loop().run_in_executor( self._executor, self._out_stream.write, frame.audio ) async def write_video_frame(self, frame: OutputImageRawFrame): + """Write a video frame to the Tkinter display. + + Args: + frame: The video frame to display in the Tkinter window. + """ self.get_event_loop().call_soon(self._write_frame_to_tk, frame) def _write_frame_to_tk(self, frame: OutputImageRawFrame): + """Write frame data to the Tkinter image label.""" width = frame.size[0] height = frame.size[1] data = f"P6 {width} {height} 255 ".encode() + frame.image @@ -162,7 +224,19 @@ class TkOutputTransport(BaseOutputTransport): class TkLocalTransport(BaseTransport): + """Complete Tkinter-based local transport with audio and video capabilities. + + Provides a unified interface for local multimedia I/O using Tkinter for video + display and PyAudio for audio, suitable for desktop applications and testing. + """ + def __init__(self, tk_root: tk.Tk, params: TkTransportParams): + """Initialize the Tkinter local transport. + + Args: + tk_root: The root Tkinter window for video display. + params: Transport configuration parameters. + """ super().__init__() self._tk_root = tk_root self._params = params @@ -176,11 +250,21 @@ class TkLocalTransport(BaseTransport): # def input(self) -> TkInputTransport: + """Get the input frame processor for this transport. 
+ + Returns: + The Tkinter input transport processor. + """ if not self._input: self._input = TkInputTransport(self._pyaudio, self._params) return self._input def output(self) -> TkOutputTransport: + """Get the output frame processor for this transport. + + Returns: + The Tkinter output transport processor. + """ if not self._output: self._output = TkOutputTransport(self._tk_root, self._pyaudio, self._params) return self._output diff --git a/src/pipecat/transports/network/fastapi_websocket.py b/src/pipecat/transports/network/fastapi_websocket.py index 5ddaacff7..c3f3a933b 100644 --- a/src/pipecat/transports/network/fastapi_websocket.py +++ b/src/pipecat/transports/network/fastapi_websocket.py @@ -4,6 +4,12 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""FastAPI WebSocket transport implementation for Pipecat. + +This module provides WebSocket-based transport for real-time audio/video streaming +using FastAPI and WebSocket connections. Supports binary and text serialization +with configurable session timeouts and WAV header generation. +""" import asyncio import io @@ -45,19 +51,48 @@ except ModuleNotFoundError as e: class FastAPIWebsocketParams(TransportParams): + """Configuration parameters for FastAPI WebSocket transport. + + Parameters: + add_wav_header: Whether to add WAV headers to audio frames. + serializer: Frame serializer for encoding/decoding messages. + session_timeout: Session timeout in seconds, None for no timeout. + """ + add_wav_header: bool = False serializer: Optional[FrameSerializer] = None session_timeout: Optional[int] = None class FastAPIWebsocketCallbacks(BaseModel): + """Callback functions for WebSocket events. + + Parameters: + on_client_connected: Called when a client connects to the WebSocket. + on_client_disconnected: Called when a client disconnects from the WebSocket. + on_session_timeout: Called when a session timeout occurs. 
+ """ + on_client_connected: Callable[[WebSocket], Awaitable[None]] on_client_disconnected: Callable[[WebSocket], Awaitable[None]] on_session_timeout: Callable[[WebSocket], Awaitable[None]] class FastAPIWebsocketClient: + """WebSocket client wrapper for handling connections and message passing. + + Manages WebSocket state, message sending/receiving, and connection lifecycle + with support for both binary and text message types. + """ + def __init__(self, websocket: WebSocket, is_binary: bool, callbacks: FastAPIWebsocketCallbacks): + """Initialize the WebSocket client. + + Args: + websocket: The FastAPI WebSocket connection. + is_binary: Whether to use binary message format. + callbacks: Event callback functions. + """ self._websocket = websocket self._closing = False self._is_binary = is_binary @@ -65,12 +100,27 @@ class FastAPIWebsocketClient: self._leave_counter = 0 async def setup(self, _: StartFrame): + """Set up the WebSocket client. + + Args: + _: The start frame (unused). + """ self._leave_counter += 1 def receive(self) -> typing.AsyncIterator[bytes | str]: + """Get an async iterator for receiving WebSocket messages. + + Returns: + An async iterator yielding bytes or strings based on message type. + """ return self._websocket.iter_bytes() if self._is_binary else self._websocket.iter_text() async def send(self, data: str | bytes): + """Send data through the WebSocket connection. + + Args: + data: The data to send (string or bytes). 
+ """ try: if self._can_send(): if self._is_binary: @@ -89,6 +139,7 @@ class FastAPIWebsocketClient: await self.trigger_client_disconnected() async def disconnect(self): + """Disconnect the WebSocket client.""" self._leave_counter -= 1 if self._leave_counter > 0: return @@ -99,27 +150,47 @@ class FastAPIWebsocketClient: await self.trigger_client_disconnected() async def trigger_client_disconnected(self): + """Trigger the client disconnected callback.""" await self._callbacks.on_client_disconnected(self._websocket) async def trigger_client_connected(self): + """Trigger the client connected callback.""" await self._callbacks.on_client_connected(self._websocket) async def trigger_client_timeout(self): + """Trigger the client timeout callback.""" await self._callbacks.on_session_timeout(self._websocket) def _can_send(self): + """Check if data can be sent through the WebSocket.""" return self.is_connected and not self.is_closing @property def is_connected(self) -> bool: + """Check if the WebSocket is currently connected. + + Returns: + True if the WebSocket is in connected state. + """ return self._websocket.client_state == WebSocketState.CONNECTED @property def is_closing(self) -> bool: + """Check if the WebSocket is currently closing. + + Returns: + True if the WebSocket is in the process of closing. + """ return self._closing class FastAPIWebsocketInputTransport(BaseInputTransport): + """Input transport for FastAPI WebSocket connections. + + Handles incoming WebSocket messages, deserializes frames, and manages + connection monitoring with optional session timeouts. + """ + def __init__( self, transport: BaseTransport, @@ -127,6 +198,14 @@ class FastAPIWebsocketInputTransport(BaseInputTransport): params: FastAPIWebsocketParams, **kwargs, ): + """Initialize the WebSocket input transport. + + Args: + transport: The parent transport instance. + client: The WebSocket client wrapper. + params: Transport configuration parameters. 
+ **kwargs: Additional arguments passed to parent class. + """ super().__init__(params, **kwargs) self._transport = transport self._client = client @@ -138,6 +217,11 @@ class FastAPIWebsocketInputTransport(BaseInputTransport): self._initialized = False async def start(self, frame: StartFrame): + """Start the input transport and begin message processing. + + Args: + frame: The start frame containing initialization parameters. + """ await super().start(frame) if self._initialized: @@ -156,6 +240,7 @@ class FastAPIWebsocketInputTransport(BaseInputTransport): await self.set_transport_ready(frame) async def _stop_tasks(self): + """Stop all running tasks.""" if self._monitor_websocket_task: await self.cancel_task(self._monitor_websocket_task) self._monitor_websocket_task = None @@ -164,20 +249,32 @@ class FastAPIWebsocketInputTransport(BaseInputTransport): self._receive_task = None async def stop(self, frame: EndFrame): + """Stop the input transport and clean up resources. + + Args: + frame: The end frame signaling transport shutdown. + """ await super().stop(frame) await self._stop_tasks() await self._client.disconnect() async def cancel(self, frame: CancelFrame): + """Cancel the input transport and stop all processing. + + Args: + frame: The cancel frame signaling immediate cancellation. + """ await super().cancel(frame) await self._stop_tasks() await self._client.disconnect() async def cleanup(self): + """Clean up transport resources.""" await super().cleanup() await self._transport.cleanup() async def _receive_messages(self): + """Main message receiving loop for WebSocket messages.""" try: async for message in WatchdogAsyncIterator( self._client.receive(), manager=self.task_manager @@ -206,6 +303,12 @@ class FastAPIWebsocketInputTransport(BaseInputTransport): class FastAPIWebsocketOutputTransport(BaseOutputTransport): + """Output transport for FastAPI WebSocket connections.
+ + Handles outgoing frame serialization, audio streaming with timing simulation, + and WebSocket message transmission with optional WAV header generation. + """ + def __init__( self, transport: BaseTransport, @@ -213,6 +316,14 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport): params: FastAPIWebsocketParams, **kwargs, ): + """Initialize the WebSocket output transport. + + Args: + transport: The parent transport instance. + client: The WebSocket client wrapper. + params: Transport configuration parameters. + **kwargs: Additional arguments passed to parent class. + """ super().__init__(params, **kwargs) self._transport = transport @@ -231,6 +342,11 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport): self._initialized = False async def start(self, frame: StartFrame): + """Start the output transport and initialize timing. + + Args: + frame: The start frame containing initialization parameters. + """ await super().start(frame) if self._initialized: @@ -245,20 +361,37 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport): await self.set_transport_ready(frame) async def stop(self, frame: EndFrame): + """Stop the output transport and clean up resources. + + Args: + frame: The end frame signaling transport shutdown. + """ await super().stop(frame) await self._write_frame(frame) await self._client.disconnect() async def cancel(self, frame: CancelFrame): + """Cancel the output transport and stop all processing. + + Args: + frame: The cancel frame signaling immediate cancellation. + """ await super().cancel(frame) await self._write_frame(frame) await self._client.disconnect() async def cleanup(self): + """Clean up transport resources.""" await super().cleanup() await self._transport.cleanup() async def process_frame(self, frame: Frame, direction: FrameDirection): + """Process outgoing frames with special handling for interruptions. + + Args: + frame: The frame to process. + direction: The direction of frame flow in the pipeline.
+ """ await super().process_frame(frame, direction) if isinstance(frame, StartInterruptionFrame): @@ -266,9 +399,19 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport): self._next_send_time = 0 async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame): + """Send a transport message frame. + + Args: + frame: The transport message frame to send. + """ await self._write_frame(frame) async def write_audio_frame(self, frame: OutputAudioRawFrame): + """Write an audio frame to the WebSocket with timing simulation. + + Args: + frame: The output audio frame to write. + """ if self._client.is_closing: return @@ -303,6 +446,7 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport): await self._write_audio_sleep() async def _write_frame(self, frame: Frame): + """Serialize and send a frame through the WebSocket.""" if not self._params.serializer: return @@ -314,6 +458,7 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport): logger.error(f"{self} exception sending data: {e.__class__.__name__} ({e})") async def _write_audio_sleep(self): + """Simulate audio playback timing with appropriate delays.""" # Simulate a clock. current_time = time.monotonic() sleep_duration = max(0, self._next_send_time - current_time) @@ -325,6 +470,12 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport): class FastAPIWebsocketTransport(BaseTransport): + """FastAPI WebSocket transport for real-time audio/video streaming. + + Provides bidirectional WebSocket communication with frame serialization, + session management, and event handling for client connections and timeouts. + """ + def __init__( self, websocket: WebSocket, @@ -332,6 +483,14 @@ class FastAPIWebsocketTransport(BaseTransport): input_name: Optional[str] = None, output_name: Optional[str] = None, ): + """Initialize the FastAPI WebSocket transport. + + Args: + websocket: The FastAPI WebSocket connection. + params: Transport configuration parameters. 
+ input_name: Optional name for the input processor. + output_name: Optional name for the output processor. + """ super().__init__(input_name=input_name, output_name=output_name) self._params = params @@ -361,16 +520,29 @@ class FastAPIWebsocketTransport(BaseTransport): self._register_event_handler("on_session_timeout") def input(self) -> FastAPIWebsocketInputTransport: + """Get the input transport processor. + + Returns: + The WebSocket input transport instance. + """ return self._input def output(self) -> FastAPIWebsocketOutputTransport: + """Get the output transport processor. + + Returns: + The WebSocket output transport instance. + """ return self._output async def _on_client_connected(self, websocket): + """Handle client connected event.""" await self._call_event_handler("on_client_connected", websocket) async def _on_client_disconnected(self, websocket): + """Handle client disconnected event.""" await self._call_event_handler("on_client_disconnected", websocket) async def _on_session_timeout(self, websocket): + """Handle session timeout event.""" await self._call_event_handler("on_session_timeout", websocket) diff --git a/src/pipecat/transports/network/small_webrtc.py b/src/pipecat/transports/network/small_webrtc.py index 9eedd7d95..140a7b537 100644 --- a/src/pipecat/transports/network/small_webrtc.py +++ b/src/pipecat/transports/network/small_webrtc.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Small WebRTC transport implementation for Pipecat. + +This module provides a WebRTC transport implementation using aiortc for +real-time audio and video communication. It supports bidirectional media +streaming, application messaging, and client connection management. +""" + import asyncio import fractions import time @@ -47,13 +54,32 @@ except ModuleNotFoundError as e: class SmallWebRTCCallbacks(BaseModel): + """Callback handlers for SmallWebRTC events. + + Parameters: + on_app_message: Called when an application message is received. 
+ on_client_connected: Called when a client establishes connection. + on_client_disconnected: Called when a client disconnects. + """ + on_app_message: Callable[[Any], Awaitable[None]] on_client_connected: Callable[[SmallWebRTCConnection], Awaitable[None]] on_client_disconnected: Callable[[SmallWebRTCConnection], Awaitable[None]] class RawAudioTrack(AudioStreamTrack): + """Custom audio stream track for WebRTC output. + + Handles audio frame generation and timing for WebRTC transmission, + supporting queued audio data with proper synchronization. + """ + def __init__(self, sample_rate): + """Initialize the raw audio track. + + Args: + sample_rate: The audio sample rate in Hz. + """ super().__init__() self._sample_rate = sample_rate self._samples_per_10ms = sample_rate * 10 // 1000 @@ -64,7 +90,17 @@ class RawAudioTrack(AudioStreamTrack): self._chunk_queue = deque() def add_audio_bytes(self, audio_bytes: bytes): - """Adds bytes to the audio buffer and returns a Future that completes when the data is processed.""" + """Add audio bytes to the buffer for transmission. + + Args: + audio_bytes: Raw audio data to queue for transmission. + + Returns: + A Future that completes when the data is processed. + + Raises: + ValueError: If audio bytes are not a multiple of 10ms size. + """ if len(audio_bytes) % self._bytes_per_10ms != 0: raise ValueError("Audio bytes must be a multiple of 10ms size.") future = asyncio.get_running_loop().create_future() @@ -79,7 +115,11 @@ class RawAudioTrack(AudioStreamTrack): return future async def recv(self): - """Returns the next audio frame, generating silence if needed.""" + """Return the next audio frame for WebRTC transmission. + + Returns: + An AudioFrame containing the next audio data or silence. 
+ """ # Compute required wait time for synchronization if self._timestamp > 0: wait = self._start + (self._timestamp / self._sample_rate) - time.time() @@ -106,18 +146,37 @@ class RawAudioTrack(AudioStreamTrack): class RawVideoTrack(VideoStreamTrack): + """Custom video stream track for WebRTC output. + + Handles video frame queuing and conversion for WebRTC transmission. + """ + def __init__(self, width, height): + """Initialize the raw video track. + + Args: + width: Video frame width in pixels. + height: Video frame height in pixels. + """ super().__init__() self._width = width self._height = height self._video_buffer = asyncio.Queue() def add_video_frame(self, frame): - """Adds a raw video frame to the buffer.""" + """Add a video frame to the transmission buffer. + + Args: + frame: The video frame to queue for transmission. + """ self._video_buffer.put_nowait(frame) async def recv(self): - """Returns the next video frame, waiting if the buffer is empty.""" + """Return the next video frame for WebRTC transmission. + + Returns: + A VideoFrame ready for WebRTC transmission. + """ raw_frame = await self._video_buffer.get() # Convert bytes to NumPy array @@ -134,6 +193,12 @@ class RawVideoTrack(VideoStreamTrack): class SmallWebRTCClient: + """WebRTC client implementation for handling connections and media streams. + + Manages WebRTC peer connections, audio/video streaming, and application + messaging through the SmallWebRTCConnection interface. + """ + FORMAT_CONVERSIONS = { "yuv420p": cv2.COLOR_YUV2RGB_I420, "yuvj420p": cv2.COLOR_YUV2RGB_I420, # OpenCV treats both the same @@ -142,6 +207,12 @@ class SmallWebRTCClient: } def __init__(self, webrtc_connection: SmallWebRTCConnection, callbacks: SmallWebRTCCallbacks): + """Initialize the WebRTC client. + + Args: + webrtc_connection: The underlying WebRTC connection handler. + callbacks: Event callbacks for connection and message handling. 
+ """ self._webrtc_connection = webrtc_connection self._closing = False self._callbacks = callbacks @@ -180,14 +251,14 @@ class SmallWebRTCClient: await self._handle_app_message(message) def _convert_frame(self, frame_array: np.ndarray, format_name: str) -> np.ndarray: - """Convert a given frame to RGB format based on the input format. + """Convert a video frame to RGB format based on the input format. Args: - frame_array (np.ndarray): The input frame. - format_name (str): The format of the input frame. + frame_array: The input frame as a NumPy array. + format_name: The format of the input frame. Returns: - np.ndarray: The converted RGB frame. + The converted RGB frame as a NumPy array. Raises: ValueError: If the format is unsupported. @@ -203,8 +274,13 @@ class SmallWebRTCClient: return cv2.cvtColor(frame_array, conversion_code) async def read_video_frame(self): - """Reads a video frame from the given MediaStreamTrack, converts it to RGB, + """Read video frames from the WebRTC connection. + + Reads a video frame from the given MediaStreamTrack, converts it to RGB, and creates an InputImageRawFrame. + + Yields: + UserImageRawFrame objects containing video data from the peer. """ while True: if self._video_input_track is None: @@ -242,7 +318,13 @@ class SmallWebRTCClient: yield image_frame async def read_audio_frame(self): - """Reads 20ms of audio from the given MediaStreamTrack and creates an InputAudioRawFrame.""" + """Read audio frames from the WebRTC connection. + + Reads 20ms of audio from the given MediaStreamTrack and creates an InputAudioRawFrame. + + Yields: + InputAudioRawFrame objects containing audio data from the peer. + """ while True: if self._audio_input_track is None: await asyncio.sleep(0.01) @@ -285,20 +367,37 @@ class SmallWebRTCClient: yield audio_frame async def write_audio_frame(self, frame: OutputAudioRawFrame): + """Write an audio frame to the WebRTC connection. + + Args: + frame: The audio frame to transmit. 
+ """ if self._can_send() and self._audio_output_track: await self._audio_output_track.add_audio_bytes(frame.audio) async def write_video_frame(self, frame: OutputImageRawFrame): + """Write a video frame to the WebRTC connection. + + Args: + frame: The video frame to transmit. + """ if self._can_send() and self._video_output_track: self._video_output_track.add_video_frame(frame) async def setup(self, _params: TransportParams, frame): + """Set up the client with transport parameters. + + Args: + _params: Transport configuration parameters. + frame: The initialization frame containing setup data. + """ self._audio_in_channels = _params.audio_in_channels self._in_sample_rate = _params.audio_in_sample_rate or frame.audio_in_sample_rate self._out_sample_rate = _params.audio_out_sample_rate or frame.audio_out_sample_rate self._params = _params async def connect(self): + """Establish the WebRTC connection.""" if self._webrtc_connection.is_connected(): # already initialized return @@ -307,6 +406,7 @@ class SmallWebRTCClient: await self._webrtc_connection.connect() async def disconnect(self): + """Disconnect from the WebRTC peer.""" if self.is_connected and not self.is_closing: logger.info(f"Disconnecting to Small WebRTC") self._closing = True @@ -314,10 +414,16 @@ class SmallWebRTCClient: await self._handle_peer_disconnected() async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame): + """Send an application message through the WebRTC connection. + + Args: + frame: The message frame to send. 
+ """ if self._can_send(): self._webrtc_connection.send_app_message(frame.message) async def _handle_client_connected(self): + """Handle client connection establishment.""" # There is nothing to do here yet, the pipeline is still not ready if not self._params: return @@ -337,12 +443,14 @@ class SmallWebRTCClient: await self._callbacks.on_client_connected(self._webrtc_connection) async def _handle_peer_disconnected(self): + """Handle peer disconnection cleanup.""" self._audio_input_track = None self._video_input_track = None self._audio_output_track = None self._video_output_track = None async def _handle_client_closed(self): + """Handle client connection closure.""" self._audio_input_track = None self._video_input_track = None self._audio_output_track = None @@ -350,27 +458,52 @@ class SmallWebRTCClient: await self._callbacks.on_client_disconnected(self._webrtc_connection) async def _handle_app_message(self, message: Any): + """Handle incoming application messages.""" await self._callbacks.on_app_message(message) def _can_send(self): + """Check if the connection is ready for sending data.""" return self.is_connected and not self.is_closing @property def is_connected(self) -> bool: + """Check if the WebRTC connection is established. + + Returns: + True if connected to the peer. + """ return self._webrtc_connection.is_connected() @property def is_closing(self) -> bool: + """Check if the connection is in the process of closing. + + Returns: + True if the connection is closing. + """ return self._closing class SmallWebRTCInputTransport(BaseInputTransport): + """Input transport implementation for SmallWebRTC. + + Handles incoming audio and video streams from WebRTC peers, + including user image requests and application message handling. + """ + def __init__( self, client: SmallWebRTCClient, params: TransportParams, **kwargs, ): + """Initialize the WebRTC input transport. + + Args: + client: The WebRTC client instance. + params: Transport configuration parameters. 
+ **kwargs: Additional arguments passed to parent class. + """ super().__init__(params, **kwargs) self._client = client self._params = params @@ -382,12 +515,23 @@ class SmallWebRTCInputTransport(BaseInputTransport): self._initialized = False async def process_frame(self, frame: Frame, direction: FrameDirection): + """Process incoming frames including user image requests. + + Args: + frame: The frame to process. + direction: The direction of frame flow in the pipeline. + """ await super().process_frame(frame, direction) if isinstance(frame, UserImageRequestFrame): await self.request_participant_image(frame) async def start(self, frame: StartFrame): + """Start the input transport and establish WebRTC connection. + + Args: + frame: The start frame containing initialization parameters. + """ await super().start(frame) if self._initialized: @@ -404,6 +548,7 @@ class SmallWebRTCInputTransport(BaseInputTransport): await self.set_transport_ready(frame) async def _stop_tasks(self): + """Stop all background tasks.""" if self._receive_audio_task: await self.cancel_task(self._receive_audio_task) self._receive_audio_task = None @@ -412,16 +557,27 @@ class SmallWebRTCInputTransport(BaseInputTransport): self._receive_video_task = None async def stop(self, frame: EndFrame): + """Stop the input transport and disconnect from WebRTC. + + Args: + frame: The end frame signaling transport shutdown. + """ await super().stop(frame) await self._stop_tasks() await self._client.disconnect() async def cancel(self, frame: CancelFrame): + """Cancel the input transport and disconnect immediately. + + Args: + frame: The cancel frame signaling immediate cancellation. 
+ """ await super().cancel(frame) await self._stop_tasks() await self._client.disconnect() async def _receive_audio(self): + """Background task for receiving audio frames from WebRTC.""" try: audio_iterator = self._client.read_audio_frame() async for audio_frame in WatchdogAsyncIterator( @@ -434,6 +590,7 @@ class SmallWebRTCInputTransport(BaseInputTransport): logger.error(f"{self} exception receiving data: {e.__class__.__name__} ({e})") async def _receive_video(self): + """Background task for receiving video frames from WebRTC.""" try: video_iterator = self._client.read_video_frame() async for video_frame in WatchdogAsyncIterator( @@ -462,16 +619,24 @@ class SmallWebRTCInputTransport(BaseInputTransport): logger.error(f"{self} exception receiving data: {e.__class__.__name__} ({e})") async def push_app_message(self, message: Any): + """Push an application message into the pipeline. + + Args: + message: The application message to process. + """ logger.debug(f"Received app message inside SmallWebRTCInputTransport {message}") frame = TransportMessageUrgentFrame(message=message) await self.push_frame(frame) # Add this method similar to DailyInputTransport.request_participant_image async def request_participant_image(self, frame: UserImageRequestFrame): - """Requests an image frame from the participant's video stream. + """Request an image frame from the participant's video stream. When a UserImageRequestFrame is received, this method will store the request and the next video frame received will be converted to a UserImageRawFrame. + + Args: + frame: The user image request frame. """ logger.debug(f"Requesting image from participant: {frame.user_id}") @@ -486,12 +651,25 @@ class SmallWebRTCInputTransport(BaseInputTransport): class SmallWebRTCOutputTransport(BaseOutputTransport): + """Output transport implementation for SmallWebRTC. + + Handles outgoing audio and video streams to WebRTC peers, + including transport message sending. 
+ """ + def __init__( self, client: SmallWebRTCClient, params: TransportParams, **kwargs, ): + """Initialize the WebRTC output transport. + + Args: + client: The WebRTC client instance. + params: Transport configuration parameters. + **kwargs: Additional arguments passed to parent class. + """ super().__init__(params, **kwargs) self._client = client self._params = params @@ -500,6 +678,11 @@ class SmallWebRTCOutputTransport(BaseOutputTransport): self._initialized = False async def start(self, frame: StartFrame): + """Start the output transport and establish WebRTC connection. + + Args: + frame: The start frame containing initialization parameters. + """ await super().start(frame) if self._initialized: @@ -512,24 +695,55 @@ class SmallWebRTCOutputTransport(BaseOutputTransport): await self.set_transport_ready(frame) async def stop(self, frame: EndFrame): + """Stop the output transport and disconnect from WebRTC. + + Args: + frame: The end frame signaling transport shutdown. + """ await super().stop(frame) await self._client.disconnect() async def cancel(self, frame: CancelFrame): + """Cancel the output transport and disconnect immediately. + + Args: + frame: The cancel frame signaling immediate cancellation. + """ await super().cancel(frame) await self._client.disconnect() async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame): + """Send a transport message through the WebRTC connection. + + Args: + frame: The transport message frame to send. + """ await self._client.send_message(frame) async def write_audio_frame(self, frame: OutputAudioRawFrame): + """Write an audio frame to the WebRTC connection. + + Args: + frame: The output audio frame to transmit. + """ await self._client.write_audio_frame(frame) async def write_video_frame(self, frame: OutputImageRawFrame): + """Write a video frame to the WebRTC connection. + + Args: + frame: The output video frame to transmit. 
+ """ await self._client.write_video_frame(frame) class SmallWebRTCTransport(BaseTransport): + """WebRTC transport implementation for real-time communication. + + Provides bidirectional audio and video streaming over WebRTC connections + with support for application messaging and connection event handling. + """ + def __init__( self, webrtc_connection: SmallWebRTCConnection, @@ -537,6 +751,14 @@ class SmallWebRTCTransport(BaseTransport): input_name: Optional[str] = None, output_name: Optional[str] = None, ): + """Initialize the WebRTC transport. + + Args: + webrtc_connection: The underlying WebRTC connection handler. + params: Transport configuration parameters. + input_name: Optional name for the input processor. + output_name: Optional name for the output processor. + """ super().__init__(input_name=input_name, output_name=output_name) self._params = params @@ -558,6 +780,11 @@ class SmallWebRTCTransport(BaseTransport): self._register_event_handler("on_client_disconnected") def input(self) -> SmallWebRTCInputTransport: + """Get the input transport processor. + + Returns: + The input transport for handling incoming media streams. + """ if not self._input: self._input = SmallWebRTCInputTransport( self._client, self._params, name=self._input_name @@ -565,6 +792,11 @@ class SmallWebRTCTransport(BaseTransport): return self._input def output(self) -> SmallWebRTCOutputTransport: + """Get the output transport processor. + + Returns: + The output transport for handling outgoing media streams. + """ if not self._output: self._output = SmallWebRTCOutputTransport( self._client, self._params, name=self._input_name @@ -572,20 +804,33 @@ class SmallWebRTCTransport(BaseTransport): return self._output async def send_image(self, frame: OutputImageRawFrame | SpriteFrame): + """Send an image frame through the transport. + + Args: + frame: The image frame to send. 
+ """ if self._output: await self._output.queue_frame(frame, FrameDirection.DOWNSTREAM) async def send_audio(self, frame: OutputAudioRawFrame): + """Send an audio frame through the transport. + + Args: + frame: The audio frame to send. + """ if self._output: await self._output.queue_frame(frame, FrameDirection.DOWNSTREAM) async def _on_app_message(self, message: Any): + """Handle incoming application messages.""" if self._input: await self._input.push_app_message(message) await self._call_event_handler("on_app_message", message) async def _on_client_connected(self, webrtc_connection): + """Handle client connection events.""" await self._call_event_handler("on_client_connected", webrtc_connection) async def _on_client_disconnected(self, webrtc_connection): + """Handle client disconnection events.""" await self._call_event_handler("on_client_disconnected", webrtc_connection) diff --git a/src/pipecat/transports/network/webrtc_connection.py b/src/pipecat/transports/network/webrtc_connection.py index 49aa2b1da..e08d678e5 100644 --- a/src/pipecat/transports/network/webrtc_connection.py +++ b/src/pipecat/transports/network/webrtc_connection.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Small WebRTC connection implementation for Pipecat. + +This module provides a WebRTC connection implementation using aiortc, +with support for audio/video tracks, data channels, and signaling +for real-time communication applications. +""" + import asyncio import json import time @@ -35,36 +42,85 @@ VIDEO_TRANSCEIVER_INDEX = 1 class TrackStatusMessage(BaseModel): + """Message for updating track enabled/disabled status. + + Parameters: + type: Message type identifier. + receiver_index: Index of the track receiver to update. + enabled: Whether the track should be enabled or disabled. + """ + type: Literal["trackStatus"] receiver_index: int enabled: bool class RenegotiateMessage(BaseModel): + """Message requesting WebRTC renegotiation. 
+ + Parameters: + type: Message type identifier for renegotiation requests. + """ + type: Literal["renegotiate"] = "renegotiate" class PeerLeftMessage(BaseModel): + """Message indicating a peer has left the connection. + + Parameters: + type: Message type identifier for peer departure. + """ + type: Literal["peerLeft"] = "peerLeft" class SignallingMessage: + """Union types for signaling message handling. + + Parameters: + Inbound: Types of messages that can be received from peers. + outbound: Types of messages that can be sent to peers. + """ + Inbound = Union[TrackStatusMessage] # in case we need to add new messages in the future outbound = Union[RenegotiateMessage] class SmallWebRTCTrack: + """Wrapper for WebRTC media tracks with enabled/disabled state management. + + Provides additional functionality on top of aiortc MediaStreamTrack including + enable/disable control and frame discarding for audio and video streams. + """ + def __init__(self, track: MediaStreamTrack): + """Initialize the WebRTC track wrapper. + + Args: + track: The underlying MediaStreamTrack to wrap. + """ self._track = track self._enabled = True def set_enabled(self, enabled: bool) -> None: + """Enable or disable the track. + + Args: + enabled: Whether the track should be enabled for receiving frames. + """ self._enabled = enabled def is_enabled(self) -> bool: + """Check if the track is currently enabled. + + Returns: + True if the track is enabled for receiving frames. + """ return self._enabled async def discard_old_frames(self): + """Discard old frames from the track queue to reduce latency.""" remote_track = self._track if isinstance(remote_track, RemoteStreamTrack): if not hasattr(remote_track, "_queue") or not isinstance( @@ -78,11 +134,24 @@ class SmallWebRTCTrack: remote_track._queue.task_done() async def recv(self) -> Optional[Frame]: + """Receive the next frame from the track. + + Returns: + The next frame if the track is enabled, None otherwise. 
+ """ if not self._enabled: return None return await self._track.recv() def __getattr__(self, name): + """Forward attribute access to the underlying track. + + Args: + name: The attribute name to access. + + Returns: + The attribute value from the underlying track. + """ # Forward other attribute/method calls to the underlying track return getattr(self._track, name) @@ -92,7 +161,22 @@ IceServer = RTCIceServer class SmallWebRTCConnection(BaseObject): + """WebRTC connection implementation using aiortc. + + Provides WebRTC peer connection functionality including ICE server configuration, + track management, data channel communication, and connection state handling + for real-time audio/video communication. + """ + def __init__(self, ice_servers: Optional[Union[List[str], List[IceServer]]] = None): + """Initialize the WebRTC connection. + + Args: + ice_servers: List of ICE servers as URLs or IceServer objects. + + Raises: + TypeError: If ice_servers contains mixed types or unsupported types. + """ super().__init__() if not ice_servers: self.ice_servers: List[IceServer] = [] @@ -126,13 +210,24 @@ class SmallWebRTCConnection(BaseObject): @property def pc(self) -> RTCPeerConnection: + """Get the underlying RTCPeerConnection. + + Returns: + The aiortc RTCPeerConnection instance. + """ return self._pc @property def pc_id(self) -> str: + """Get the peer connection identifier. + + Returns: + The unique identifier for this peer connection. 
+ """ return self._pc_id def _initialize(self): + """Initialize the peer connection and associated components.""" logger.debug("Initializing new peer connection") rtc_config = RTCConfiguration(iceServers=self.ice_servers) @@ -147,6 +242,8 @@ class SmallWebRTCConnection(BaseObject): self._pending_app_messages = [] def _setup_listeners(self): + """Set up event listeners for the peer connection.""" + @self._pc.on("datachannel") def on_datachannel(channel): self._data_channel = channel @@ -208,6 +305,7 @@ class SmallWebRTCConnection(BaseObject): await self._call_event_handler("track-ended", track) async def _create_answer(self, sdp: str, type: str): + """Create an SDP answer for the given offer.""" offer = RTCSessionDescription(sdp=sdp, type=type) await self._pc.setRemoteDescription(offer) @@ -223,9 +321,16 @@ class SmallWebRTCConnection(BaseObject): self._answer = self._pc.localDescription async def initialize(self, sdp: str, type: str): + """Initialize the connection with an SDP offer. + + Args: + sdp: The SDP offer string. + type: The SDP type (usually "offer"). + """ await self._create_answer(sdp, type) async def connect(self): + """Connect the WebRTC peer connection and handle initial setup.""" self._connect_invoked = True # If we already connected, trigger again the connected event if self.is_connected(): @@ -241,6 +346,13 @@ class SmallWebRTCConnection(BaseObject): self.ask_to_renegotiate() async def renegotiate(self, sdp: str, type: str, restart_pc: bool = False): + """Renegotiate the WebRTC connection with new parameters. + + Args: + sdp: The new SDP offer string. + type: The SDP type (usually "offer"). + restart_pc: Whether to restart the peer connection entirely. 
+ """ logger.debug(f"Renegotiating {self._pc_id}") if restart_pc: @@ -264,6 +376,7 @@ class SmallWebRTCConnection(BaseObject): asyncio.create_task(delayed_task()) def force_transceivers_to_send_recv(self): + """Force all transceivers to bidirectional send/receive mode.""" for transceiver in self._pc.getTransceivers(): transceiver.direction = "sendrecv" # logger.debug( @@ -272,6 +385,11 @@ class SmallWebRTCConnection(BaseObject): # logger.debug(f"Sender track: {transceiver.sender.track}") def replace_audio_track(self, track): + """Replace the audio track in the first transceiver. + + Args: + track: The new audio track to use for sending. + """ logger.debug(f"Replacing audio track {track.kind}") # Transceivers always appear in creation-order for both peers # For now we are only considering that we are going to have 02 transceivers, @@ -283,6 +401,11 @@ class SmallWebRTCConnection(BaseObject): logger.warning("Audio transceiver not found. Cannot replace audio track.") def replace_video_track(self, track): + """Replace the video track in the second transceiver. + + Args: + track: The new video track to use for sending. + """ logger.debug(f"Replacing video track {track.kind}") # Transceivers always appear in creation-order for both peers # For now we are only considering that we are going to have 02 transceivers, @@ -294,10 +417,12 @@ class SmallWebRTCConnection(BaseObject): logger.warning("Video transceiver not found. Cannot replace video track.") async def disconnect(self): + """Disconnect from the WebRTC peer connection.""" self.send_app_message({"type": SIGNALLING_TYPE, "message": PeerLeftMessage().model_dump()}) await self._close() async def _close(self): + """Close the peer connection and cleanup resources.""" if self._pc: await self._pc.close() self._message_queue.clear() @@ -305,6 +430,12 @@ class SmallWebRTCConnection(BaseObject): self._track_map = {} def get_answer(self): + """Get the SDP answer for the current connection. 
+ + Returns: + Dictionary containing SDP answer, type, and peer connection ID, + or None if no answer is available. + """ if not self._answer: return None @@ -315,6 +446,7 @@ class SmallWebRTCConnection(BaseObject): } async def _handle_new_connection_state(self): + """Handle changes in the peer connection state.""" state = self._pc.connectionState if state == "connected" and not self._connect_invoked: # We are going to wait until the pipeline is ready before triggering the event @@ -328,7 +460,12 @@ class SmallWebRTCConnection(BaseObject): # Despite the fact that aiortc provides this listener, they don't have a status for "disconnected" # So, there is no advantage in looking at self._pc.connectionState # That is why we are trying to keep our own state - def is_connected(self): + def is_connected(self) -> bool: + """Check if the WebRTC connection is currently active. + + Returns: + True if the connection is active and receiving data. + """ # If the small webrtc transport has never invoked to connect # we are acting like if we are not connected if not self._connect_invoked: @@ -342,6 +479,11 @@ class SmallWebRTCConnection(BaseObject): return (time.time() - self._last_received_time) < 3 def audio_input_track(self): + """Get the audio input track wrapper. + + Returns: + SmallWebRTCTrack wrapper for the audio track, or None if unavailable. + """ if self._track_map.get(AUDIO_TRANSCEIVER_INDEX): return self._track_map[AUDIO_TRANSCEIVER_INDEX] @@ -359,6 +501,11 @@ class SmallWebRTCConnection(BaseObject): return audio_track def video_input_track(self): + """Get the video input track wrapper. + + Returns: + SmallWebRTCTrack wrapper for the video track, or None if unavailable. + """ if self._track_map.get(VIDEO_TRANSCEIVER_INDEX): return self._track_map[VIDEO_TRANSCEIVER_INDEX] @@ -376,6 +523,11 @@ class SmallWebRTCConnection(BaseObject): return video_track def send_app_message(self, message: Any): + """Send an application message through the data channel. 
+ + Args: + message: The message to send (will be JSON serialized). + """ json_message = json.dumps(message) if self._data_channel and self._data_channel.readyState == "open": self._data_channel.send(json_message) @@ -384,6 +536,7 @@ class SmallWebRTCConnection(BaseObject): self._message_queue.append(json_message) def ask_to_renegotiate(self): + """Request renegotiation of the WebRTC connection.""" if self._renegotiation_in_progress: return @@ -393,6 +546,7 @@ class SmallWebRTCConnection(BaseObject): ) def _handle_signalling_message(self, message): + """Handle incoming signaling messages.""" logger.debug(f"Signalling message received: {message}") inbound_adapter = TypeAdapter(SignallingMessage.Inbound) signalling_message = inbound_adapter.validate_python(message) diff --git a/src/pipecat/transports/network/websocket_client.py b/src/pipecat/transports/network/websocket_client.py index 98c7f9e2d..f0746a589 100644 --- a/src/pipecat/transports/network/websocket_client.py +++ b/src/pipecat/transports/network/websocket_client.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""WebSocket client transport implementation for Pipecat. + +This module provides a WebSocket client transport that enables bidirectional +communication over WebSocket connections, with support for audio streaming, +frame serialization, and connection management. +""" + import asyncio import io import time @@ -34,17 +41,38 @@ from pipecat.utils.asyncio.task_manager import BaseTaskManager class WebsocketClientParams(TransportParams): + """Configuration parameters for WebSocket client transport. + + Parameters: + add_wav_header: Whether to add WAV headers to audio frames. + serializer: Frame serializer for encoding/decoding messages. + """ + add_wav_header: bool = True serializer: Optional[FrameSerializer] = None class WebsocketClientCallbacks(BaseModel): + """Callback functions for WebSocket client events. 
+ + Parameters: + on_connected: Called when WebSocket connection is established. + on_disconnected: Called when WebSocket connection is closed. + on_message: Called when a message is received from the WebSocket. + """ + on_connected: Callable[[websockets.WebSocketClientProtocol], Awaitable[None]] on_disconnected: Callable[[websockets.WebSocketClientProtocol], Awaitable[None]] on_message: Callable[[websockets.WebSocketClientProtocol, websockets.Data], Awaitable[None]] class WebsocketClientSession: + """Manages a WebSocket client connection session. + + Handles connection lifecycle, message sending/receiving, and provides + callback mechanisms for connection events. + """ + def __init__( self, uri: str, @@ -52,6 +80,14 @@ class WebsocketClientSession: callbacks: WebsocketClientCallbacks, transport_name: str, ): + """Initialize the WebSocket client session. + + Args: + uri: The WebSocket URI to connect to. + params: Configuration parameters for the session. + callbacks: Callback functions for session events. + transport_name: Name of the parent transport for logging. + """ self._uri = uri self._params = params self._callbacks = callbacks @@ -63,6 +99,14 @@ class WebsocketClientSession: @property def task_manager(self) -> BaseTaskManager: + """Get the task manager for this session. + + Returns: + The task manager instance. + + Raises: + Exception: If task manager is not initialized. + """ if not self._task_manager: raise Exception( f"{self._transport_name}::WebsocketClientSession: TaskManager not initialized (pipeline not started?)" @@ -70,11 +114,17 @@ class WebsocketClientSession: return self._task_manager async def setup(self, task_manager: BaseTaskManager): + """Set up the session with a task manager. + + Args: + task_manager: The task manager to use for session tasks. 
+ """ self._leave_counter += 1 if not self._task_manager: self._task_manager = task_manager async def connect(self): + """Connect to the WebSocket server.""" if self._websocket: return @@ -89,6 +139,7 @@ class WebsocketClientSession: logger.error(f"Timeout connecting to {self._uri}") async def disconnect(self): + """Disconnect from the WebSocket server.""" self._leave_counter -= 1 if not self._websocket or self._leave_counter > 0: return @@ -99,6 +150,11 @@ class WebsocketClientSession: self._websocket = None async def send(self, message: websockets.Data): + """Send a message through the WebSocket connection. + + Args: + message: The message data to send. + """ try: if self._websocket: await self._websocket.send(message) @@ -106,6 +162,7 @@ class WebsocketClientSession: logger.error(f"{self} exception sending data: {e.__class__.__name__} ({e})") async def _client_task_handler(self): + """Handle incoming messages from the WebSocket connection.""" try: # Handle incoming messages async for message in self._websocket: @@ -116,16 +173,30 @@ class WebsocketClientSession: await self._callbacks.on_disconnected(self._websocket) def __str__(self): + """String representation of the WebSocket client session.""" return f"{self._transport_name}::WebsocketClientSession" class WebsocketClientInputTransport(BaseInputTransport): + """WebSocket client input transport for receiving frames. + + Handles incoming WebSocket messages, deserializes them to frames, + and pushes them downstream in the processing pipeline. + """ + def __init__( self, transport: BaseTransport, session: WebsocketClientSession, params: WebsocketClientParams, ): + """Initialize the WebSocket client input transport. + + Args: + transport: The parent transport instance. + session: The WebSocket session to use for communication. + params: Configuration parameters for the transport. 
+ """ super().__init__(params) self._transport = transport @@ -136,10 +207,20 @@ class WebsocketClientInputTransport(BaseInputTransport): self._initialized = False async def setup(self, setup: FrameProcessorSetup): + """Set up the input transport with the frame processor setup. + + Args: + setup: The frame processor setup configuration. + """ await super().setup(setup) await self._session.setup(setup.task_manager) async def start(self, frame: StartFrame): + """Start the input transport and initialize the WebSocket connection. + + Args: + frame: The start frame containing initialization parameters. + """ await super().start(frame) if self._initialized: @@ -153,18 +234,35 @@ class WebsocketClientInputTransport(BaseInputTransport): await self.set_transport_ready(frame) async def stop(self, frame: EndFrame): + """Stop the input transport and disconnect from WebSocket. + + Args: + frame: The end frame signaling transport shutdown. + """ await super().stop(frame) await self._session.disconnect() async def cancel(self, frame: CancelFrame): + """Cancel the input transport and disconnect from WebSocket. + + Args: + frame: The cancel frame signaling immediate cancellation. + """ await super().cancel(frame) await self._session.disconnect() async def cleanup(self): + """Clean up the input transport resources.""" await super().cleanup() await self._transport.cleanup() async def on_message(self, websocket, message): + """Handle incoming WebSocket messages. + + Args: + websocket: The WebSocket connection that received the message. + message: The received message data. + """ if not self._params.serializer: return frame = await self._params.serializer.deserialize(message) @@ -177,12 +275,25 @@ class WebsocketClientInputTransport(BaseInputTransport): class WebsocketClientOutputTransport(BaseOutputTransport): + """WebSocket client output transport for sending frames. 
+ + Handles outgoing frames, serializes them for WebSocket transmission, + and manages audio streaming with proper timing simulation. + """ + def __init__( self, transport: BaseTransport, session: WebsocketClientSession, params: WebsocketClientParams, ): + """Initialize the WebSocket client output transport. + + Args: + transport: The parent transport instance. + session: The WebSocket session to use for communication. + params: Configuration parameters for the transport. + """ super().__init__(params) self._transport = transport @@ -201,10 +312,20 @@ class WebsocketClientOutputTransport(BaseOutputTransport): self._initialized = False async def setup(self, setup: FrameProcessorSetup): + """Set up the output transport with the frame processor setup. + + Args: + setup: The frame processor setup configuration. + """ await super().setup(setup) await self._session.setup(setup.task_manager) async def start(self, frame: StartFrame): + """Start the output transport and initialize the WebSocket connection. + + Args: + frame: The start frame containing initialization parameters. + """ await super().start(frame) if self._initialized: @@ -219,21 +340,42 @@ class WebsocketClientOutputTransport(BaseOutputTransport): await self.set_transport_ready(frame) async def stop(self, frame: EndFrame): + """Stop the output transport and disconnect from WebSocket. + + Args: + frame: The end frame signaling transport shutdown. + """ await super().stop(frame) await self._session.disconnect() async def cancel(self, frame: CancelFrame): + """Cancel the output transport and disconnect from WebSocket. + + Args: + frame: The cancel frame signaling immediate cancellation. 
+ """ await super().cancel(frame) await self._session.disconnect() async def cleanup(self): + """Clean up the output transport resources.""" await super().cleanup() await self._transport.cleanup() async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame): + """Send a transport message through the WebSocket. + + Args: + frame: The transport message frame to send. + """ await self._write_frame(frame) async def write_audio_frame(self, frame: OutputAudioRawFrame): + """Write an audio frame to the WebSocket with optional WAV header. + + Args: + frame: The output audio frame to write. + """ frame = OutputAudioRawFrame( audio=frame.audio, sample_rate=self.sample_rate, @@ -260,6 +402,7 @@ class WebsocketClientOutputTransport(BaseOutputTransport): await self._write_audio_sleep() async def _write_frame(self, frame: Frame): + """Write a frame to the WebSocket after serialization.""" if not self._params.serializer: return payload = await self._params.serializer.serialize(frame) @@ -267,6 +410,7 @@ class WebsocketClientOutputTransport(BaseOutputTransport): await self._session.send(payload) async def _write_audio_sleep(self): + """Simulate audio playback timing with sleep delays.""" # Simulate a clock. current_time = time.monotonic() sleep_duration = max(0, self._next_send_time - current_time) @@ -278,11 +422,23 @@ class WebsocketClientOutputTransport(BaseOutputTransport): class WebsocketClientTransport(BaseTransport): + """WebSocket client transport for bidirectional communication. + + Provides a complete WebSocket client transport implementation with + input and output capabilities, connection management, and event handling. + """ + def __init__( self, uri: str, params: Optional[WebsocketClientParams] = None, ): + """Initialize the WebSocket client transport. + + Args: + uri: The WebSocket URI to connect to. + params: Optional configuration parameters for the transport. 
+ """ super().__init__() self._params = params or WebsocketClientParams() @@ -304,21 +460,34 @@ class WebsocketClientTransport(BaseTransport): self._register_event_handler("on_disconnected") def input(self) -> WebsocketClientInputTransport: + """Get the input transport for receiving frames. + + Returns: + The WebSocket client input transport instance. + """ if not self._input: self._input = WebsocketClientInputTransport(self, self._session, self._params) return self._input def output(self) -> WebsocketClientOutputTransport: + """Get the output transport for sending frames. + + Returns: + The WebSocket client output transport instance. + """ if not self._output: self._output = WebsocketClientOutputTransport(self, self._session, self._params) return self._output async def _on_connected(self, websocket): + """Handle WebSocket connection established event.""" await self._call_event_handler("on_connected", websocket) async def _on_disconnected(self, websocket): + """Handle WebSocket connection closed event.""" await self._call_event_handler("on_disconnected", websocket) async def _on_message(self, websocket, message): + """Handle incoming WebSocket message.""" if self._input: await self._input.on_message(websocket, message) diff --git a/src/pipecat/transports/network/websocket_server.py b/src/pipecat/transports/network/websocket_server.py index 1fe33f4a6..dbe418d3b 100644 --- a/src/pipecat/transports/network/websocket_server.py +++ b/src/pipecat/transports/network/websocket_server.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""WebSocket server transport implementation for Pipecat. + +This module provides WebSocket server transport functionality for real-time +audio and data streaming, including client connection management, session +handling, and frame serialization. 
+""" + import asyncio import io import time @@ -39,12 +46,29 @@ except ModuleNotFoundError as e: class WebsocketServerParams(TransportParams): + """Configuration parameters for WebSocket server transport. + + Parameters: + add_wav_header: Whether to add WAV headers to audio frames. + serializer: Frame serializer for message encoding/decoding. + session_timeout: Timeout in seconds for client sessions. + """ + add_wav_header: bool = False serializer: Optional[FrameSerializer] = None session_timeout: Optional[int] = None class WebsocketServerCallbacks(BaseModel): + """Callback functions for WebSocket server events. + + Parameters: + on_client_connected: Called when a client connects to the server. + on_client_disconnected: Called when a client disconnects from the server. + on_session_timeout: Called when a client session times out. + on_websocket_ready: Called when the WebSocket server is ready to accept connections. + """ + on_client_connected: Callable[[websockets.WebSocketServerProtocol], Awaitable[None]] on_client_disconnected: Callable[[websockets.WebSocketServerProtocol], Awaitable[None]] on_session_timeout: Callable[[websockets.WebSocketServerProtocol], Awaitable[None]] @@ -52,6 +76,12 @@ class WebsocketServerCallbacks(BaseModel): class WebsocketServerInputTransport(BaseInputTransport): + """WebSocket server input transport for receiving client data. + + Handles incoming WebSocket connections, message processing, and client + session management including timeout monitoring and connection lifecycle. + """ + def __init__( self, transport: BaseTransport, @@ -61,6 +91,16 @@ class WebsocketServerInputTransport(BaseInputTransport): callbacks: WebsocketServerCallbacks, **kwargs, ): + """Initialize the WebSocket server input transport. + + Args: + transport: The parent transport instance. + host: Host address to bind the WebSocket server to. + port: Port number to bind the WebSocket server to. + params: WebSocket server configuration parameters. 
+ callbacks: Callback functions for WebSocket events. + **kwargs: Additional arguments passed to parent class. + """ super().__init__(params, **kwargs) self._transport = transport @@ -82,6 +122,11 @@ class WebsocketServerInputTransport(BaseInputTransport): self._initialized = False async def start(self, frame: StartFrame): + """Start the WebSocket server and initialize components. + + Args: + frame: The start frame containing initialization parameters. + """ await super().start(frame) if self._initialized: @@ -96,6 +141,11 @@ class WebsocketServerInputTransport(BaseInputTransport): await self.set_transport_ready(frame) async def stop(self, frame: EndFrame): + """Stop the WebSocket server and cleanup resources. + + Args: + frame: The end frame signaling transport shutdown. + """ await super().stop(frame) self._stop_server_event.set() if self._monitor_task: @@ -106,6 +156,11 @@ class WebsocketServerInputTransport(BaseInputTransport): self._server_task = None async def cancel(self, frame: CancelFrame): + """Cancel the WebSocket server and stop all processing. + + Args: + frame: The cancel frame signaling immediate cancellation. 
+ """ await super().cancel(frame) if self._monitor_task: await self.cancel_task(self._monitor_task) @@ -115,16 +170,19 @@ class WebsocketServerInputTransport(BaseInputTransport): self._server_task = None async def cleanup(self): + """Cleanup resources and parent transport.""" await super().cleanup() await self._transport.cleanup() async def _server_task_handler(self): + """Handle WebSocket server startup and client connections.""" logger.info(f"Starting websocket server on {self._host}:{self._port}") async with websockets.serve(self._client_handler, self._host, self._port) as server: await self._callbacks.on_websocket_ready() await self._stop_server_event.wait() async def _client_handler(self, websocket: websockets.WebSocketServerProtocol, path): + """Handle individual client connections and message processing.""" logger.info(f"New client connection from {websocket.remote_address}") if self._websocket: await self._websocket.close() @@ -170,9 +228,7 @@ class WebsocketServerInputTransport(BaseInputTransport): async def _monitor_websocket( self, websocket: websockets.WebSocketServerProtocol, session_timeout: int ): - """Wait for session_timeout seconds, if the websocket is still open, - trigger timeout event. - """ + """Monitor WebSocket connection for session timeout.""" try: await asyncio.sleep(session_timeout) if not websocket.closed: @@ -183,7 +239,20 @@ class WebsocketServerInputTransport(BaseInputTransport): class WebsocketServerOutputTransport(BaseOutputTransport): + """WebSocket server output transport for sending data to clients. + + Handles outgoing frame serialization, audio streaming with timing control, + and client connection management for WebSocket communication. + """ + def __init__(self, transport: BaseTransport, params: WebsocketServerParams, **kwargs): + """Initialize the WebSocket server output transport. + + Args: + transport: The parent transport instance. + params: WebSocket server configuration parameters. 
+ **kwargs: Additional arguments passed to parent class. + """ super().__init__(params, **kwargs) self._transport = transport @@ -203,12 +272,22 @@ class WebsocketServerOutputTransport(BaseOutputTransport): self._initialized = False async def set_client_connection(self, websocket: Optional[websockets.WebSocketServerProtocol]): + """Set the active client WebSocket connection. + + Args: + websocket: The WebSocket connection to set as active, or None to clear. + """ if self._websocket: await self._websocket.close() logger.warning("Only one client allowed, using new connection") self._websocket = websocket async def start(self, frame: StartFrame): + """Start the output transport and initialize components. + + Args: + frame: The start frame containing initialization parameters. + """ await super().start(frame) if self._initialized: @@ -222,18 +301,35 @@ class WebsocketServerOutputTransport(BaseOutputTransport): await self.set_transport_ready(frame) async def stop(self, frame: EndFrame): + """Stop the output transport and send final frame. + + Args: + frame: The end frame signaling transport shutdown. + """ await super().stop(frame) await self._write_frame(frame) async def cancel(self, frame: CancelFrame): + """Cancel the output transport and send cancellation frame. + + Args: + frame: The cancel frame signaling immediate cancellation. + """ await super().cancel(frame) await self._write_frame(frame) async def cleanup(self): + """Cleanup resources and parent transport.""" await super().cleanup() await self._transport.cleanup() async def process_frame(self, frame: Frame, direction: FrameDirection): + """Process frames and handle interruption timing. + + Args: + frame: The frame to process. + direction: The direction of frame flow in the pipeline. 
+ """ await super().process_frame(frame, direction) if isinstance(frame, StartInterruptionFrame): @@ -241,9 +337,19 @@ class WebsocketServerOutputTransport(BaseOutputTransport): self._next_send_time = 0 async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame): + """Send a transport message frame to the client. + + Args: + frame: The transport message frame to send. + """ await self._write_frame(frame) async def write_audio_frame(self, frame: OutputAudioRawFrame): + """Write an audio frame to the WebSocket client with timing control. + + Args: + frame: The output audio frame to write. + """ if not self._websocket: # Simulate audio playback with a sleep. await self._write_audio_sleep() @@ -275,6 +381,7 @@ class WebsocketServerOutputTransport(BaseOutputTransport): await self._write_audio_sleep() async def _write_frame(self, frame: Frame): + """Serialize and send a frame to the WebSocket client.""" if not self._params.serializer: return @@ -286,6 +393,7 @@ class WebsocketServerOutputTransport(BaseOutputTransport): logger.error(f"{self} exception sending data: {e.__class__.__name__} ({e})") async def _write_audio_sleep(self): + """Simulate audio device timing by sleeping between audio chunks.""" # Simulate a clock. current_time = time.monotonic() sleep_duration = max(0, self._next_send_time - current_time) @@ -297,6 +405,13 @@ class WebsocketServerOutputTransport(BaseOutputTransport): class WebsocketServerTransport(BaseTransport): + """WebSocket server transport for bidirectional real-time communication. + + Provides a complete WebSocket server implementation with separate input and + output transports, client connection management, and event handling for + real-time audio and data streaming applications. 
+ """ + def __init__( self, params: WebsocketServerParams, @@ -305,6 +420,15 @@ class WebsocketServerTransport(BaseTransport): input_name: Optional[str] = None, output_name: Optional[str] = None, ): + """Initialize the WebSocket server transport. + + Args: + params: WebSocket server configuration parameters. + host: Host address to bind the server to. Defaults to "localhost". + port: Port number to bind the server to. Defaults to 8765. + input_name: Optional name for the input processor. + output_name: Optional name for the output processor. + """ super().__init__(input_name=input_name, output_name=output_name) self._host = host self._port = port @@ -328,6 +452,11 @@ class WebsocketServerTransport(BaseTransport): self._register_event_handler("on_websocket_ready") def input(self) -> WebsocketServerInputTransport: + """Get the input transport for receiving client data. + + Returns: + The WebSocket server input transport instance. + """ if not self._input: self._input = WebsocketServerInputTransport( self, self._host, self._port, self._params, self._callbacks, name=self._input_name @@ -335,6 +464,11 @@ class WebsocketServerTransport(BaseTransport): return self._input def output(self) -> WebsocketServerOutputTransport: + """Get the output transport for sending data to clients. + + Returns: + The WebSocket server output transport instance. 
+ """ if not self._output: self._output = WebsocketServerOutputTransport( self, self._params, name=self._output_name @@ -342,6 +476,7 @@ class WebsocketServerTransport(BaseTransport): return self._output async def _on_client_connected(self, websocket): + """Handle client connection events.""" if self._output: await self._output.set_client_connection(websocket) await self._call_event_handler("on_client_connected", websocket) @@ -349,6 +484,7 @@ class WebsocketServerTransport(BaseTransport): logger.error("A WebsocketServerTransport output is missing in the pipeline") async def _on_client_disconnected(self, websocket): + """Handle client disconnection events.""" if self._output: await self._output.set_client_connection(None) await self._call_event_handler("on_client_disconnected", websocket) @@ -356,7 +492,9 @@ class WebsocketServerTransport(BaseTransport): logger.error("A WebsocketServerTransport output is missing in the pipeline") async def _on_session_timeout(self, websocket): + """Handle client session timeout events.""" await self._call_event_handler("on_session_timeout", websocket) async def _on_websocket_ready(self): + """Handle WebSocket server ready events.""" await self._call_event_handler("on_websocket_ready") diff --git a/src/pipecat/transports/services/daily.py b/src/pipecat/transports/services/daily.py index 4c00fa44c..b62aa4b6e 100644 --- a/src/pipecat/transports/services/daily.py +++ b/src/pipecat/transports/services/daily.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Daily transport implementation for Pipecat. + +This module provides comprehensive Daily video conferencing integration including +audio/video streaming, transcription, recording, dial-in/out functionality, and +real-time communication features. 
+""" + import asyncio import time from concurrent.futures import ThreadPoolExecutor @@ -67,7 +74,7 @@ VAD_RESET_PERIOD_MS = 2000 class DailyTransportMessageFrame(TransportMessageFrame): """Frame for transport messages in Daily calls. - Attributes: + Parameters: participant_id: Optional ID of the participant this message is for/from. """ @@ -78,7 +85,7 @@ class DailyTransportMessageFrame(TransportMessageFrame): class DailyTransportMessageUrgentFrame(TransportMessageUrgentFrame): """Frame for urgent transport messages in Daily calls. - Attributes: + Parameters: participant_id: Optional ID of the participant this message is for/from. """ @@ -89,13 +96,15 @@ class WebRTCVADAnalyzer(VADAnalyzer): """Voice Activity Detection analyzer using WebRTC. Implements voice activity detection using Daily's native WebRTC VAD. - - Args: - sample_rate: Audio sample rate in Hz. - params: VAD configuration parameters (VADParams). """ def __init__(self, *, sample_rate: Optional[int] = None, params: Optional[VADParams] = None): + """Initialize the WebRTC VAD analyzer. + + Args: + sample_rate: Audio sample rate in Hz. + params: VAD configuration parameters. + """ super().__init__(sample_rate=sample_rate, params=params) self._webrtc_vad = Daily.create_native_vad( @@ -104,9 +113,22 @@ class WebRTCVADAnalyzer(VADAnalyzer): logger.debug("Loaded native WebRTC VAD") def num_frames_required(self) -> int: + """Get the number of audio frames required for VAD analysis. + + Returns: + The number of frames needed (equivalent to 10ms of audio). + """ return int(self.sample_rate / 100.0) def voice_confidence(self, buffer) -> float: + """Analyze audio buffer and return voice confidence score. + + Args: + buffer: Audio buffer to analyze. + + Returns: + Voice confidence score between 0.0 and 1.0. 
+ """ confidence = 0 if len(buffer) > 0: confidence = self._webrtc_vad.analyze_frames(buffer) @@ -116,7 +138,7 @@ class WebRTCVADAnalyzer(VADAnalyzer): class DailyDialinSettings(BaseModel): """Settings for Daily's dial-in functionality. - Attributes: + Parameters: call_id: CallId is represented by UUID and represents the sessionId in the SIP Network. call_domain: Call Domain is represented by UUID and represents your Daily Domain on the SIP Network. """ @@ -128,7 +150,7 @@ class DailyDialinSettings(BaseModel): class DailyTranscriptionSettings(BaseModel): """Configuration settings for Daily's transcription service. - Attributes: + Parameters: language: ISO language code for transcription (e.g. "en"). model: Transcription model to use (e.g. "nova-2-general"). profanity_filter: Whether to filter profanity from transcripts. @@ -152,14 +174,14 @@ class DailyTranscriptionSettings(BaseModel): class DailyParams(TransportParams): """Configuration parameters for Daily transport. - Args: - api_url: Daily API base URL - api_key: Daily API authentication key - dialin_settings: Optional settings for dial-in functionality - camera_out_enabled: Whether to enable the main camera output track. If enabled, it still needs `video_out_enabled=True` - microphone_out_enabled: Whether to enable the main microphone track. If enabled, it still needs `audio_out_enabled=True` - transcription_enabled: Whether to enable speech transcription - transcription_settings: Configuration for transcription service + Parameters: + api_url: Daily API base URL. + api_key: Daily API authentication key. + dialin_settings: Optional settings for dial-in functionality. + camera_out_enabled: Whether to enable the main camera output track. + microphone_out_enabled: Whether to enable the main microphone track. + transcription_enabled: Whether to enable speech transcription. + transcription_settings: Configuration for transcription service. 
""" api_url: str = "https://api.daily.co/v1" @@ -174,7 +196,7 @@ class DailyParams(TransportParams): class DailyCallbacks(BaseModel): """Callback handlers for Daily events. - Attributes: + Parameters: on_active_speaker_changed: Called when the active speaker of the call has changed. on_joined: Called when bot successfully joined a room. on_left: Called when bot left a room. @@ -230,6 +252,15 @@ class DailyCallbacks(BaseModel): def completion_callback(future): + """Create a completion callback for Daily API calls. + + Args: + future: The asyncio Future to set the result on. + + Returns: + A callback function that sets the future result. + """ + def _callback(*args): def set_result(future, *args): try: @@ -247,6 +278,13 @@ def completion_callback(future): @dataclass class DailyAudioTrack: + """Container for Daily audio track components. + + Parameters: + source: The custom audio source for the track. + track: The custom audio track instance. + """ + source: CustomAudioSource track: CustomAudioTrack @@ -254,21 +292,14 @@ class DailyAudioTrack: class DailyTransportClient(EventHandler): """Core client for interacting with Daily's API. - Manages the connection to Daily rooms and handles all low-level API interactions. - - Args: - room_url: URL of the Daily room to connect to. - token: Optional authentication token for the room. - bot_name: Display name for the bot in the call. - params: Configuration parameters (DailyParams). - callbacks: Event callback handlers (DailyCallbacks). - transport_name: Name identifier for the transport. + Manages the connection to Daily rooms and handles all low-level API interactions + including room management, media streaming, transcription, and event handling. """ _daily_initialized: bool = False - # This is necessary to override EventHandler's __new__ method. 
def __new__(cls, *args, **kwargs): + """Override EventHandler's __new__ so constructor arguments are not forwarded to it.""" return super().__new__(cls) def __init__( @@ -280,6 +311,16 @@ class DailyTransportClient(EventHandler): callbacks: DailyCallbacks, transport_name: str, ): + """Initialize the Daily transport client. + + Args: + room_url: URL of the Daily room to connect to. + token: Optional authentication token for the room. + bot_name: Display name for the bot in the call. + params: Configuration parameters for the transport. + callbacks: Event callback handlers. + transport_name: Name identifier for the transport. + """ super().__init__() if not DailyTransportClient._daily_initialized: @@ -335,25 +376,51 @@ class DailyTransportClient(EventHandler): self._custom_audio_tracks: Dict[str, DailyAudioTrack] = {} def _camera_name(self): + """Generate a unique camera name for this client instance.""" return f"camera-{self}" @property def room_url(self) -> str: + """Get the Daily room URL. + + Returns: + The room URL this client is connected to. + """ return self._room_url @property def participant_id(self) -> str: + """Get the participant ID for this client. + + Returns: + The participant ID assigned by Daily. + """ return self._participant_id @property def in_sample_rate(self) -> int: + """Get the input audio sample rate. + + Returns: + The input sample rate in Hz. + """ return self._in_sample_rate @property def out_sample_rate(self) -> int: + """Get the output audio sample rate. + + Returns: + The output sample rate in Hz. + """ return self._out_sample_rate async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame): + """Send an application message to participants. + + Args: + frame: The message frame to send.
+ """ if not self._joined: return @@ -368,10 +435,20 @@ class DailyTransportClient(EventHandler): await future async def register_audio_destination(self, destination: str): + """Register a custom audio destination for multi-track output. + + Args: + destination: The destination identifier to register. + """ self._custom_audio_tracks[destination] = await self.add_custom_audio_track(destination) self._client.update_publishing({"customAudio": {destination: True}}) async def write_audio_frame(self, frame: OutputAudioRawFrame): + """Write an audio frame to the appropriate audio track. + + Args: + frame: The audio frame to write. + """ future = self._get_event_loop().create_future() destination = frame.transport_destination @@ -391,10 +468,20 @@ class DailyTransportClient(EventHandler): await future async def write_video_frame(self, frame: OutputImageRawFrame): + """Write a video frame to the camera device. + + Args: + frame: The image frame to write. + """ if not frame.transport_destination and self._camera: self._camera.write_frame(frame.image) async def setup(self, setup: FrameProcessorSetup): + """Setup the client with task manager and event queues. + + Args: + setup: The frame processor setup configuration. + """ if self._task_manager: return @@ -408,6 +495,7 @@ class DailyTransportClient(EventHandler): ) async def cleanup(self): + """Cleanup client resources and cancel tasks.""" if self._event_task and self._task_manager: await self._task_manager.cancel_task(self._event_task) self._event_task = None @@ -422,6 +510,11 @@ class DailyTransportClient(EventHandler): await self._get_event_loop().run_in_executor(self._executor, self._cleanup) async def start(self, frame: StartFrame): + """Start the client and initialize audio/video components. + + Args: + frame: The start frame containing initialization parameters. 
+ """ self._in_sample_rate = self._params.audio_in_sample_rate or frame.audio_in_sample_rate self._out_sample_rate = self._params.audio_out_sample_rate or frame.audio_out_sample_rate @@ -452,6 +545,7 @@ class DailyTransportClient(EventHandler): self._microphone_track = DailyAudioTrack(source=audio_source, track=audio_track) async def join(self): + """Join the Daily room with configured settings.""" # Transport already joined or joining, ignore. if self._joined or self._joining: # Increment leave counter if we already joined. @@ -497,6 +591,7 @@ class DailyTransportClient(EventHandler): await self._callbacks.on_error(error_msg) async def _join(self): + """Execute the actual room join operation.""" future = self._get_event_loop().create_future() camera_enabled = self._params.video_out_enabled and self._params.camera_out_enabled @@ -552,6 +647,7 @@ class DailyTransportClient(EventHandler): return await asyncio.wait_for(future, timeout=10) async def leave(self): + """Leave the Daily room and cleanup resources.""" # Decrement leave counter when leaving. self._leave_counter -= 1 @@ -586,22 +682,39 @@ class DailyTransportClient(EventHandler): await self._callbacks.on_error(error_msg) async def _leave(self): + """Execute the actual room leave operation.""" future = self._get_event_loop().create_future() self._client.leave(completion=completion_callback(future)) return await asyncio.wait_for(future, timeout=10) def _cleanup(self): + """Cleanup the Daily client instance.""" if self._client: self._client.release() self._client = None def participants(self): + """Get current participants in the room. + + Returns: + Dictionary of participants keyed by participant ID. + """ return self._client.participants() def participant_counts(self): + """Get participant count information. + + Returns: + Dictionary with participant count details. + """ return self._client.participant_counts() async def start_dialout(self, settings): + """Start a dial-out call to a phone number. 
+ + Args: + settings: Dial-out configuration settings. + """ logger.debug(f"Starting dialout: settings={settings}") future = self._get_event_loop().create_future() @@ -611,6 +724,11 @@ class DailyTransportClient(EventHandler): logger.error(f"Unable to start dialout: {error}") async def stop_dialout(self, participant_id): + """Stop a dial-out call for a specific participant. + + Args: + participant_id: ID of the participant to stop dial-out for. + """ logger.debug(f"Stopping dialout: participant_id={participant_id}") future = self._get_event_loop().create_future() @@ -620,21 +738,43 @@ class DailyTransportClient(EventHandler): logger.error(f"Unable to stop dialout: {error}") async def send_dtmf(self, settings): + """Send DTMF tones during a call. + + Args: + settings: DTMF settings including tones and target session. + """ future = self._get_event_loop().create_future() self._client.send_dtmf(settings, completion=completion_callback(future)) await future async def sip_call_transfer(self, settings): + """Transfer a SIP call to another destination. + + Args: + settings: SIP call transfer settings. + """ future = self._get_event_loop().create_future() self._client.sip_call_transfer(settings, completion=completion_callback(future)) await future async def sip_refer(self, settings): + """Send a SIP REFER request. + + Args: + settings: SIP REFER settings. + """ future = self._get_event_loop().create_future() self._client.sip_refer(settings, completion=completion_callback(future)) await future async def start_recording(self, streaming_settings, stream_id, force_new): + """Start recording the call. + + Args: + streaming_settings: Recording configuration settings. + stream_id: Unique identifier for the recording stream. + force_new: Whether to force a new recording session. 
+ """ logger.debug( f"Starting recording: stream_id={stream_id} force_new={force_new} settings={streaming_settings}" ) @@ -648,6 +788,11 @@ class DailyTransportClient(EventHandler): logger.error(f"Unable to start recording: {error}") async def stop_recording(self, stream_id): + """Stop recording the call. + + Args: + stream_id: Unique identifier for the recording stream to stop. + """ logger.debug(f"Stopping recording: stream_id={stream_id}") future = self._get_event_loop().create_future() @@ -657,6 +802,11 @@ class DailyTransportClient(EventHandler): logger.error(f"Unable to stop recording: {error}") async def start_transcription(self, settings): + """Start transcription for the call. + + Args: + settings: Transcription configuration settings. + """ if not self._token: logger.warning("Transcription can't be started without a room token") return @@ -673,6 +823,7 @@ class DailyTransportClient(EventHandler): logger.error(f"Unable to start transcription: {error}") async def stop_transcription(self): + """Stop transcription for the call.""" if not self._token: return @@ -685,6 +836,12 @@ class DailyTransportClient(EventHandler): logger.error(f"Unable to stop transcription: {error}") async def send_prebuilt_chat_message(self, message: str, user_name: Optional[str] = None): + """Send a chat message to Daily's Prebuilt main room. + + Args: + message: The chat message to send. + user_name: Optional user name that will appear as sender of the message. + """ if not self._joined: return @@ -695,6 +852,11 @@ class DailyTransportClient(EventHandler): await future async def capture_participant_transcription(self, participant_id: str): + """Enable transcription capture for a specific participant. + + Args: + participant_id: ID of the participant to capture transcription for. 
+ """ if not self._params.transcription_enabled: return @@ -710,6 +872,15 @@ class DailyTransportClient(EventHandler): sample_rate: int = 16000, callback_interval_ms: int = 20, ): + """Capture audio from a specific participant. + + Args: + participant_id: ID of the participant to capture audio from. + callback: Callback function to handle audio data. + audio_source: Audio source to capture (microphone, screenAudio, or custom). + sample_rate: Desired sample rate for audio capture. + callback_interval_ms: Interval between audio callbacks in milliseconds. + """ # Only enable the desired audio source subscription on this participant. if audio_source in ("microphone", "screenAudio"): media = {"media": {audio_source: "subscribed"}} @@ -740,6 +911,15 @@ class DailyTransportClient(EventHandler): video_source: str = "camera", color_format: str = "RGB", ): + """Capture video from a specific participant. + + Args: + participant_id: ID of the participant to capture video from. + callback: Callback function to handle video frames. + framerate: Desired framerate for video capture. + video_source: Video source to capture (camera, screenVideo, or custom). + color_format: Color format for video frames. + """ # Only enable the desired audio source subscription on this participant. if video_source in ("camera", "screenVideo"): media = {"media": {video_source: "subscribed"}} @@ -762,6 +942,14 @@ class DailyTransportClient(EventHandler): ) async def add_custom_audio_track(self, track_name: str) -> DailyAudioTrack: + """Add a custom audio track for multi-stream output. + + Args: + track_name: Name for the custom audio track. + + Returns: + The created DailyAudioTrack instance. + """ future = self._get_event_loop().create_future() audio_source = CustomAudioSource(self._out_sample_rate, 1) @@ -782,6 +970,11 @@ class DailyTransportClient(EventHandler): return track async def remove_custom_audio_track(self, track_name: str): + """Remove a custom audio track. 
+ + Args: + track_name: Name of the custom audio track to remove. + """ future = self._get_event_loop().create_future() self._client.remove_custom_audio_track( track_name=track_name, @@ -790,6 +983,12 @@ class DailyTransportClient(EventHandler): await future async def update_transcription(self, participants=None, instance_id=None): + """Update transcription settings for specific participants. + + Args: + participants: List of participant IDs to enable transcription for. + instance_id: Optional transcription instance ID. + """ future = self._get_event_loop().create_future() self._client.update_transcription( participants, instance_id, completion=completion_callback(future) @@ -797,6 +996,12 @@ class DailyTransportClient(EventHandler): await future async def update_subscriptions(self, participant_settings=None, profile_settings=None): + """Update media subscription settings. + + Args: + participant_settings: Per-participant subscription settings. + profile_settings: Global subscription profile settings. + """ future = self._get_event_loop().create_future() self._client.update_subscriptions( participant_settings=participant_settings, @@ -806,6 +1011,11 @@ class DailyTransportClient(EventHandler): await future async def update_publishing(self, publishing_settings: Mapping[str, Any]): + """Update media publishing settings. + + Args: + publishing_settings: Publishing configuration settings. + """ future = self._get_event_loop().create_future() self._client.update_publishing( publishing_settings=publishing_settings, @@ -814,6 +1024,11 @@ class DailyTransportClient(EventHandler): await future async def update_remote_participants(self, remote_participants: Mapping[str, Any]): + """Update settings for remote participants. + + Args: + remote_participants: Remote participant configuration settings. 
+ """ future = self._get_event_loop().create_future() self._client.update_remote_participants( remote_participants=remote_participants, completion=completion_callback(future) @@ -826,76 +1041,195 @@ class DailyTransportClient(EventHandler): # def on_active_speaker_changed(self, participant): + """Handle active speaker change events. + + Args: + participant: The new active speaker participant info. + """ self._call_event_callback(self._callbacks.on_active_speaker_changed, participant) def on_app_message(self, message: Any, sender: str): + """Handle application message events. + + Args: + message: The received message data. + sender: ID of the message sender. + """ self._call_event_callback(self._callbacks.on_app_message, message, sender) def on_call_state_updated(self, state: str): + """Handle call state update events. + + Args: + state: The new call state. + """ self._call_event_callback(self._callbacks.on_call_state_updated, state) def on_dialin_connected(self, data: Any): + """Handle dial-in connected events. + + Args: + data: Dial-in connection data. + """ self._call_event_callback(self._callbacks.on_dialin_connected, data) def on_dialin_ready(self, sip_endpoint: str): + """Handle dial-in ready events. + + Args: + sip_endpoint: The SIP endpoint for dial-in. + """ self._call_event_callback(self._callbacks.on_dialin_ready, sip_endpoint) def on_dialin_stopped(self, data: Any): + """Handle dial-in stopped events. + + Args: + data: Dial-in stop data. + """ self._call_event_callback(self._callbacks.on_dialin_stopped, data) def on_dialin_error(self, data: Any): + """Handle dial-in error events. + + Args: + data: Dial-in error data. + """ self._call_event_callback(self._callbacks.on_dialin_error, data) def on_dialin_warning(self, data: Any): + """Handle dial-in warning events. + + Args: + data: Dial-in warning data. + """ self._call_event_callback(self._callbacks.on_dialin_warning, data) def on_dialout_answered(self, data: Any): + """Handle dial-out answered events. 
+ + Args: + data: Dial-out answered data. + """ self._call_event_callback(self._callbacks.on_dialout_answered, data) def on_dialout_connected(self, data: Any): + """Handle dial-out connected events. + + Args: + data: Dial-out connection data. + """ self._call_event_callback(self._callbacks.on_dialout_connected, data) def on_dialout_stopped(self, data: Any): + """Handle dial-out stopped events. + + Args: + data: Dial-out stop data. + """ self._call_event_callback(self._callbacks.on_dialout_stopped, data) def on_dialout_error(self, data: Any): + """Handle dial-out error events. + + Args: + data: Dial-out error data. + """ self._call_event_callback(self._callbacks.on_dialout_error, data) def on_dialout_warning(self, data: Any): + """Handle dial-out warning events. + + Args: + data: Dial-out warning data. + """ self._call_event_callback(self._callbacks.on_dialout_warning, data) def on_participant_joined(self, participant): + """Handle participant joined events. + + Args: + participant: The participant that joined. + """ self._call_event_callback(self._callbacks.on_participant_joined, participant) def on_participant_left(self, participant, reason): + """Handle participant left events. + + Args: + participant: The participant that left. + reason: Reason for leaving. + """ self._call_event_callback(self._callbacks.on_participant_left, participant, reason) def on_participant_updated(self, participant): + """Handle participant updated events. + + Args: + participant: The updated participant info. + """ self._call_event_callback(self._callbacks.on_participant_updated, participant) def on_transcription_started(self, status): + """Handle transcription started events. + + Args: + status: Transcription start status. 
+ """ logger.debug(f"Transcription started: {status}") self._transcription_status = status self._call_event_callback(self.update_transcription, self._transcription_ids) def on_transcription_stopped(self, stopped_by, stopped_by_error): + """Handle transcription stopped events. + + Args: + stopped_by: Who stopped the transcription. + stopped_by_error: Whether stopped due to error. + """ logger.debug("Transcription stopped") def on_transcription_error(self, message): + """Handle transcription error events. + + Args: + message: Error message. + """ logger.error(f"Transcription error: {message}") def on_transcription_message(self, message): + """Handle transcription message events. + + Args: + message: The transcription message data. + """ self._call_event_callback(self._callbacks.on_transcription_message, message) def on_recording_started(self, status): + """Handle recording started events. + + Args: + status: Recording start status. + """ logger.debug(f"Recording started: {status}") self._call_event_callback(self._callbacks.on_recording_started, status) def on_recording_stopped(self, stream_id): + """Handle recording stopped events. + + Args: + stream_id: ID of the stopped recording stream. + """ logger.debug(f"Recording stopped: {stream_id}") self._call_event_callback(self._callbacks.on_recording_stopped, stream_id) def on_recording_error(self, stream_id, message): + """Handle recording error events. + + Args: + stream_id: ID of the recording stream with error. + message: Error message. 
+ """ logger.error(f"Recording error for {stream_id}: {message}") self._call_event_callback(self._callbacks.on_recording_error, stream_id, message) @@ -904,12 +1238,14 @@ class DailyTransportClient(EventHandler): # def _audio_data_received(self, participant_id: str, audio_data: AudioData, audio_source: str): + """Handle received audio data from participants.""" callback = self._audio_renderers[participant_id][audio_source] self._call_audio_callback(callback, participant_id, audio_data, audio_source) def _video_frame_received( self, participant_id: str, video_frame: VideoFrame, video_source: str ): + """Handle received video frames from participants.""" callback = self._video_renderers[participant_id][video_source] self._call_video_callback(callback, participant_id, video_frame, video_source) @@ -918,21 +1254,26 @@ class DailyTransportClient(EventHandler): # def _call_audio_callback(self, callback, *args): + """Queue an audio callback for async execution.""" self._call_async_callback(self._audio_queue, callback, *args) def _call_video_callback(self, callback, *args): + """Queue a video callback for async execution.""" self._call_async_callback(self._video_queue, callback, *args) def _call_event_callback(self, callback, *args): + """Queue an event callback for async execution.""" self._call_async_callback(self._event_queue, callback, *args) def _call_async_callback(self, queue: asyncio.Queue, callback, *args): + """Queue a callback for async execution on the event loop.""" future = asyncio.run_coroutine_threadsafe( queue.put((callback, *args)), self._get_event_loop() ) future.result() async def _callback_task_handler(self, queue: asyncio.Queue): + """Handle queued callbacks from the specified queue.""" while True: # Wait to process any callback until we are joined. 
await self._joined_event.wait() @@ -941,22 +1282,21 @@ class DailyTransportClient(EventHandler): queue.task_done() def _get_event_loop(self) -> asyncio.AbstractEventLoop: + """Get the event loop from the task manager.""" if not self._task_manager: raise Exception(f"{self}: missing task manager (pipeline not started?)") return self._task_manager.get_event_loop() def __str__(self): + """String representation of the DailyTransportClient.""" return f"{self._transport_name}::DailyTransportClient" class DailyInputTransport(BaseInputTransport): """Handles incoming media streams and events from Daily calls. - Processes incoming audio, video, transcriptions and other events from Daily. - - Args: - client: DailyTransportClient instance. - params: Configuration parameters. + Processes incoming audio, video, transcriptions and other events from Daily + room participants, including participant media capture and event forwarding. """ def __init__( @@ -966,6 +1306,14 @@ class DailyInputTransport(BaseInputTransport): params: DailyParams, **kwargs, ): + """Initialize the Daily input transport. + + Args: + transport: The parent transport instance. + client: DailyTransportClient instance. + params: Configuration parameters. + **kwargs: Additional arguments passed to parent class. + """ super().__init__(params, **kwargs) self._transport = transport @@ -988,9 +1336,15 @@ class DailyInputTransport(BaseInputTransport): @property def vad_analyzer(self) -> Optional[VADAnalyzer]: + """Get the Voice Activity Detection analyzer. + + Returns: + The VAD analyzer instance if configured. + """ return self._vad_analyzer async def start_audio_in_streaming(self): + """Start receiving audio from participants.""" if not self._params.audio_in_enabled: return @@ -1003,15 +1357,26 @@ class DailyInputTransport(BaseInputTransport): self._streaming_started = True async def setup(self, setup: FrameProcessorSetup): + """Setup the input transport with shared client setup. 
+ + Args: + setup: The frame processor setup configuration. + """ await super().setup(setup) await self._client.setup(setup) async def cleanup(self): + """Cleanup input transport and shared resources.""" await super().cleanup() await self._client.cleanup() await self._transport.cleanup() async def start(self, frame: StartFrame): + """Start the input transport and join the Daily room. + + Args: + frame: The start frame containing initialization parameters. + """ # Parent start. await super().start(frame) @@ -1033,12 +1398,22 @@ class DailyInputTransport(BaseInputTransport): await self.start_audio_in_streaming() async def stop(self, frame: EndFrame): + """Stop the input transport and leave the Daily room. + + Args: + frame: The end frame signaling transport shutdown. + """ # Parent stop. await super().stop(frame) # Leave the room. await self._client.leave() async def cancel(self, frame: CancelFrame): + """Cancel the input transport and leave the Daily room. + + Args: + frame: The cancel frame signaling immediate cancellation. + """ # Parent stop. await super().cancel(frame) # Leave the room. @@ -1049,6 +1424,12 @@ class DailyInputTransport(BaseInputTransport): # async def process_frame(self, frame: Frame, direction: FrameDirection): + """Process incoming frames, including user image requests. + + Args: + frame: The frame to process. + direction: The direction of frame flow in the pipeline. + """ await super().process_frame(frame, direction) if isinstance(frame, UserImageRequestFrame): @@ -1059,9 +1440,20 @@ class DailyInputTransport(BaseInputTransport): # async def push_transcription_frame(self, frame: TranscriptionFrame | InterimTranscriptionFrame): + """Push a transcription frame downstream. + + Args: + frame: The transcription frame to push. + """ await self.push_frame(frame) async def push_app_message(self, message: Any, sender: str): + """Push an application message as an urgent transport frame. + + Args: + message: The message data to send. 
+ sender: ID of the message sender. + """ frame = DailyTransportMessageUrgentFrame(message=message, participant_id=sender) await self.push_frame(frame) @@ -1075,6 +1467,13 @@ class DailyInputTransport(BaseInputTransport): audio_source: str = "microphone", sample_rate: int = 16000, ): + """Capture audio from a specific participant. + + Args: + participant_id: ID of the participant to capture audio from. + audio_source: Audio source to capture from. + sample_rate: Desired sample rate for audio capture. + """ if self._streaming_started: await self._client.capture_participant_audio( participant_id, self._on_participant_audio_data, audio_source, sample_rate @@ -1085,6 +1484,7 @@ class DailyInputTransport(BaseInputTransport): async def _on_participant_audio_data( self, participant_id: str, audio: AudioData, audio_source: str ): + """Handle received participant audio data.""" frame = UserAudioRawFrame( user_id=participant_id, audio=audio.audio_frames, @@ -1105,6 +1505,14 @@ class DailyInputTransport(BaseInputTransport): video_source: str = "camera", color_format: str = "RGB", ): + """Capture video from a specific participant. + + Args: + participant_id: ID of the participant to capture video from. + framerate: Desired framerate for video capture. + video_source: Video source to capture from. + color_format: Color format for video frames. + """ if participant_id not in self._video_renderers: self._video_renderers[participant_id] = {} @@ -1119,6 +1527,11 @@ class DailyInputTransport(BaseInputTransport): ) async def request_participant_image(self, frame: UserImageRequestFrame): + """Request a video frame from a specific participant. + + Args: + frame: The user image request frame. 
+ """ if frame.user_id in self._video_renderers: video_source = frame.video_source if frame.video_source else "camera" self._video_renderers[frame.user_id][video_source]["render_next_frame"].append(frame) @@ -1126,6 +1539,7 @@ class DailyInputTransport(BaseInputTransport): async def _on_participant_video_frame( self, participant_id: str, video_frame: VideoFrame, video_source: str ): + """Handle received participant video frames.""" render_frame = False curr_time = time.time() @@ -1161,16 +1575,21 @@ class DailyInputTransport(BaseInputTransport): class DailyOutputTransport(BaseOutputTransport): """Handles outgoing media streams and events to Daily calls. - Manages sending audio, video and other data to Daily calls. - - Args: - client: DailyTransportClient instance. - params: Configuration parameters. + Manages sending audio, video, DTMF tones, and other data to Daily calls, + including audio destination registration and message transmission. """ def __init__( self, transport: BaseTransport, client: DailyTransportClient, params: DailyParams, **kwargs ): + """Initialize the Daily output transport. + + Args: + transport: The parent transport instance. + client: DailyTransportClient instance. + params: Configuration parameters. + **kwargs: Additional arguments passed to parent class. + """ super().__init__(params, **kwargs) self._transport = transport @@ -1180,15 +1599,26 @@ class DailyOutputTransport(BaseOutputTransport): self._initialized = False async def setup(self, setup: FrameProcessorSetup): + """Setup the output transport with shared client setup. + + Args: + setup: The frame processor setup configuration. + """ await super().setup(setup) await self._client.setup(setup) async def cleanup(self): + """Cleanup output transport and shared resources.""" await super().cleanup() await self._client.cleanup() await self._transport.cleanup() async def start(self, frame: StartFrame): + """Start the output transport and join the Daily room. 
+ + Args: + frame: The start frame containing initialization parameters. + """ # Parent start. await super().start(frame) @@ -1207,27 +1637,57 @@ class DailyOutputTransport(BaseOutputTransport): await self.set_transport_ready(frame) async def stop(self, frame: EndFrame): + """Stop the output transport and leave the Daily room. + + Args: + frame: The end frame signaling transport shutdown. + """ # Parent stop. await super().stop(frame) # Leave the room. await self._client.leave() async def cancel(self, frame: CancelFrame): + """Cancel the output transport and leave the Daily room. + + Args: + frame: The cancel frame signaling immediate cancellation. + """ # Parent stop. await super().cancel(frame) # Leave the room. await self._client.leave() async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame): + """Send a transport message to participants. + + Args: + frame: The transport message frame to send. + """ await self._client.send_message(frame) async def register_video_destination(self, destination: str): + """Register a video output destination. + + Args: + destination: The destination identifier to register. + """ logger.warning(f"{self} registering video destinations is not supported yet") async def register_audio_destination(self, destination: str): + """Register an audio output destination. + + Args: + destination: The destination identifier to register. + """ await self._client.register_audio_destination(destination) async def write_dtmf(self, frame: OutputDTMFFrame | OutputDTMFUrgentFrame): + """Write DTMF tones to the call. + + Args: + frame: The DTMF frame containing tone information. + """ await self._client.send_dtmf( { "sessionId": frame.transport_destination, @@ -1236,25 +1696,28 @@ class DailyOutputTransport(BaseOutputTransport): ) async def write_audio_frame(self, frame: OutputAudioRawFrame): + """Write an audio frame to the Daily call. + + Args: + frame: The audio frame to write. 
+ """ await self._client.write_audio_frame(frame) async def write_video_frame(self, frame: OutputImageRawFrame): + """Write a video frame to the Daily call. + + Args: + frame: The video frame to write. + """ await self._client.write_video_frame(frame) class DailyTransport(BaseTransport): """Transport implementation for Daily audio and video calls. - Handles audio/video streaming, transcription, recordings, dial-in, - dial-out, and call management through Daily's API. - - Args: - room_url: URL of the Daily room to connect to. - token: Optional authentication token for the room. - bot_name: Display name for the bot in the call. - params: Configuration parameters (DailyParams) for the transport. - input_name: Optional name for the input transport. - output_name: Optional name for the output transport. + Provides comprehensive Daily integration including audio/video streaming, + transcription, recording, dial-in/out functionality, and real-time communication + features for conversational AI applications. """ def __init__( @@ -1266,6 +1729,16 @@ class DailyTransport(BaseTransport): input_name: Optional[str] = None, output_name: Optional[str] = None, ): + """Initialize the Daily transport. + + Args: + room_url: URL of the Daily room to connect to. + token: Optional authentication token for the room. + bot_name: Display name for the bot in the call. + params: Configuration parameters for the transport. + input_name: Optional name for the input transport. + output_name: Optional name for the output transport. + """ super().__init__(input_name=input_name, output_name=output_name) callbacks = DailyCallbacks( @@ -1339,6 +1812,11 @@ class DailyTransport(BaseTransport): # def input(self) -> DailyInputTransport: + """Get the input transport for receiving media and events. + + Returns: + The Daily input transport instance. 
+ """ if not self._input: self._input = DailyInputTransport( self, self._client, self._params, name=self._input_name @@ -1346,6 +1824,11 @@ class DailyTransport(BaseTransport): return self._input def output(self) -> DailyOutputTransport: + """Get the output transport for sending media and events. + + Returns: + The Daily output transport instance. + """ if not self._output: self._output = DailyOutputTransport( self, self._client, self._params, name=self._output_name @@ -1358,33 +1841,78 @@ class DailyTransport(BaseTransport): @property def room_url(self) -> str: + """Get the Daily room URL. + + Returns: + The room URL this transport is connected to. + """ return self._client.room_url @property def participant_id(self) -> str: + """Get the participant ID for this transport. + + Returns: + The participant ID assigned by Daily. + """ return self._client.participant_id async def send_image(self, frame: OutputImageRawFrame | SpriteFrame): + """Send an image frame to the Daily call. + + Args: + frame: The image frame to send. + """ if self._output: await self._output.queue_frame(frame, FrameDirection.DOWNSTREAM) async def send_audio(self, frame: OutputAudioRawFrame): + """Send an audio frame to the Daily call. + + Args: + frame: The audio frame to send. + """ if self._output: await self._output.queue_frame(frame, FrameDirection.DOWNSTREAM) def participants(self): + """Get current participants in the room. + + Returns: + Dictionary of participants keyed by participant ID. + """ return self._client.participants() def participant_counts(self): + """Get participant count information. + + Returns: + Dictionary with participant count details. + """ return self._client.participant_counts() async def start_dialout(self, settings=None): + """Start a dial-out call to a phone number. + + Args: + settings: Dial-out configuration settings. 
+ """ await self._client.start_dialout(settings) async def stop_dialout(self, participant_id): + """Stop a dial-out call for a specific participant. + + Args: + participant_id: ID of the participant to stop dial-out for. + """ await self._client.stop_dialout(participant_id) async def send_dtmf(self, settings): + """Send DTMF tones during a call (deprecated). + + Args: + settings: DTMF settings including tones and target session. + """ import warnings with warnings.catch_warnings(): @@ -1396,33 +1924,66 @@ class DailyTransport(BaseTransport): await self._client.send_dtmf(settings) async def sip_call_transfer(self, settings): + """Transfer a SIP call to another destination. + + Args: + settings: SIP call transfer settings. + """ await self._client.sip_call_transfer(settings) async def sip_refer(self, settings): + """Send a SIP REFER request. + + Args: + settings: SIP REFER settings. + """ await self._client.sip_refer(settings) async def start_recording(self, streaming_settings=None, stream_id=None, force_new=None): + """Start recording the call. + + Args: + streaming_settings: Recording configuration settings. + stream_id: Unique identifier for the recording stream. + force_new: Whether to force a new recording session. + """ await self._client.start_recording(streaming_settings, stream_id, force_new) async def stop_recording(self, stream_id=None): + """Stop recording the call. + + Args: + stream_id: Unique identifier for the recording stream to stop. + """ await self._client.stop_recording(stream_id) async def start_transcription(self, settings=None): + """Start transcription for the call. + + Args: + settings: Transcription configuration settings. + """ await self._client.start_transcription(settings) async def stop_transcription(self): + """Stop transcription for the call.""" await self._client.stop_transcription() async def send_prebuilt_chat_message(self, message: str, user_name: Optional[str] = None): - """Sends a chat message to Daily's Prebuilt main room. 
+ """Send a chat message to Daily's Prebuilt main room. Args: - message: The chat message to send - user_name: Optional user name that will appear as sender of the message + message: The chat message to send. + user_name: Optional user name that will appear as sender of the message. """ await self._client.send_prebuilt_chat_message(message, user_name) async def capture_participant_transcription(self, participant_id: str): + """Enable transcription capture for a specific participant. + + Args: + participant_id: ID of the participant to capture transcription for. + """ await self._client.capture_participant_transcription(participant_id) async def capture_participant_audio( @@ -1431,6 +1992,13 @@ class DailyTransport(BaseTransport): audio_source: str = "microphone", sample_rate: int = 16000, ): + """Capture audio from a specific participant. + + Args: + participant_id: ID of the participant to capture audio from. + audio_source: Audio source to capture from. + sample_rate: Desired sample rate for audio capture. + """ if self._input: await self._input.capture_participant_audio(participant_id, audio_source, sample_rate) @@ -1441,32 +2009,60 @@ class DailyTransport(BaseTransport): video_source: str = "camera", color_format: str = "RGB", ): + """Capture video from a specific participant. + + Args: + participant_id: ID of the participant to capture video from. + framerate: Desired framerate for video capture. + video_source: Video source to capture from. + color_format: Color format for video frames. + """ if self._input: await self._input.capture_participant_video( participant_id, framerate, video_source, color_format ) async def update_publishing(self, publishing_settings: Mapping[str, Any]): + """Update media publishing settings. + + Args: + publishing_settings: Publishing configuration settings. 
+ """ await self._client.update_publishing(publishing_settings=publishing_settings) async def update_subscriptions(self, participant_settings=None, profile_settings=None): + """Update media subscription settings. + + Args: + participant_settings: Per-participant subscription settings. + profile_settings: Global subscription profile settings. + """ await self._client.update_subscriptions( participant_settings=participant_settings, profile_settings=profile_settings ) async def update_remote_participants(self, remote_participants: Mapping[str, Any]): + """Update settings for remote participants. + + Args: + remote_participants: Remote participant configuration settings. + """ await self._client.update_remote_participants(remote_participants=remote_participants) async def _on_active_speaker_changed(self, participant: Any): + """Handle active speaker change events.""" await self._call_event_handler("on_active_speaker_changed", participant) async def _on_joined(self, data): + """Handle room joined events.""" await self._call_event_handler("on_joined", data) async def _on_left(self): + """Handle room left events.""" await self._call_event_handler("on_left") async def _on_error(self, error): + """Handle error events and push error frames.""" await self._call_event_handler("on_error", error) # Push error frame to notify the pipeline error_frame = ErrorFrame(error) @@ -1480,20 +2076,25 @@ class DailyTransport(BaseTransport): raise Exception("No valid input or output channel to push error") async def _on_app_message(self, message: Any, sender: str): + """Handle application message events.""" if self._input: await self._input.push_app_message(message, sender) await self._call_event_handler("on_app_message", message, sender) async def _on_call_state_updated(self, state: str): + """Handle call state update events.""" await self._call_event_handler("on_call_state_updated", state) async def _on_client_connected(self, participant: Any): + """Handle client connected events.""" await 
self._call_event_handler("on_client_connected", participant) async def _on_client_disconnected(self, participant: Any): + """Handle client disconnected events.""" await self._call_event_handler("on_client_disconnected", participant) async def _handle_dialin_ready(self, sip_endpoint: str): + """Handle dial-in ready events by updating SIP configuration.""" if not self._params.dialin_settings: return @@ -1528,38 +2129,49 @@ class DailyTransport(BaseTransport): logger.exception(f"Error handling dialin-ready event ({url}): {e}") async def _on_dialin_connected(self, data): + """Handle dial-in connected events.""" await self._call_event_handler("on_dialin_connected", data) async def _on_dialin_ready(self, sip_endpoint): + """Handle dial-in ready events.""" if self._params.dialin_settings: await self._handle_dialin_ready(sip_endpoint) await self._call_event_handler("on_dialin_ready", sip_endpoint) async def _on_dialin_stopped(self, data): + """Handle dial-in stopped events.""" await self._call_event_handler("on_dialin_stopped", data) async def _on_dialin_error(self, data): + """Handle dial-in error events.""" await self._call_event_handler("on_dialin_error", data) async def _on_dialin_warning(self, data): + """Handle dial-in warning events.""" await self._call_event_handler("on_dialin_warning", data) async def _on_dialout_answered(self, data): + """Handle dial-out answered events.""" await self._call_event_handler("on_dialout_answered", data) async def _on_dialout_connected(self, data): + """Handle dial-out connected events.""" await self._call_event_handler("on_dialout_connected", data) async def _on_dialout_stopped(self, data): + """Handle dial-out stopped events.""" await self._call_event_handler("on_dialout_stopped", data) async def _on_dialout_error(self, data): + """Handle dial-out error events.""" await self._call_event_handler("on_dialout_error", data) async def _on_dialout_warning(self, data): + """Handle dial-out warning events.""" await 
self._call_event_handler("on_dialout_warning", data) async def _on_participant_joined(self, participant): + """Handle participant joined events.""" id = participant["id"] logger.info(f"Participant joined {id}") @@ -1577,6 +2189,7 @@ class DailyTransport(BaseTransport): await self._call_event_handler("on_client_connected", participant) async def _on_participant_left(self, participant, reason): + """Handle participant left events.""" id = participant["id"] logger.info(f"Participant left {id}") await self._call_event_handler("on_participant_left", participant, reason) @@ -1584,9 +2197,11 @@ class DailyTransport(BaseTransport): await self._call_event_handler("on_client_disconnected", participant) async def _on_participant_updated(self, participant): + """Handle participant updated events.""" await self._call_event_handler("on_participant_updated", participant) async def _on_transcription_message(self, message): + """Handle transcription message events.""" await self._call_event_handler("on_transcription_message", message) participant_id = "" @@ -1619,10 +2234,13 @@ class DailyTransport(BaseTransport): await self._input.push_transcription_frame(frame) async def _on_recording_started(self, status): + """Handle recording started events.""" await self._call_event_handler("on_recording_started", status) async def _on_recording_stopped(self, stream_id): + """Handle recording stopped events.""" await self._call_event_handler("on_recording_stopped", stream_id) async def _on_recording_error(self, stream_id, message): + """Handle recording error events.""" await self._call_event_handler("on_recording_error", stream_id, message) diff --git a/src/pipecat/transports/services/helpers/daily_rest.py b/src/pipecat/transports/services/helpers/daily_rest.py index 796022920..a283b4dfc 100644 --- a/src/pipecat/transports/services/helpers/daily_rest.py +++ b/src/pipecat/transports/services/helpers/daily_rest.py @@ -20,11 +20,11 @@ from pydantic import BaseModel, Field, ValidationError class 
DailyRoomSipParams(BaseModel): """SIP configuration parameters for Daily rooms. - Attributes: - display_name: Name shown for the SIP endpoint - video: Whether video is enabled for SIP - sip_mode: SIP connection mode, typically 'dial-in' - num_endpoints: Number of allowed SIP endpoints + Parameters: + display_name: Name shown for the SIP endpoint. + video: Whether video is enabled for SIP. + sip_mode: SIP connection mode, typically 'dial-in'. + num_endpoints: Number of allowed SIP endpoints. """ display_name: str = "sw-sip-dialin" @@ -38,6 +38,12 @@ class RecordingsBucketConfig(BaseModel): Refer to the Daily API documentation for more information: https://docs.daily.co/guides/products/live-streaming-recording/storing-recordings-in-a-custom-s3-bucket + + Parameters: + bucket_name: Name of the S3 bucket for storing recordings. + bucket_region: AWS region where the S3 bucket is located. + assume_role_arn: ARN of the IAM role to assume for S3 access. + allow_api_access: Whether to allow API access to the recordings. """ bucket_name: str @@ -49,21 +55,22 @@ class RecordingsBucketConfig(BaseModel): class DailyRoomProperties(BaseModel, extra="allow"): """Properties for configuring a Daily room. 
- Attributes: - exp: Optional Unix epoch timestamp for room expiration (e.g., time.time() + 300 for 5 minutes) - enable_chat: Whether chat is enabled in the room - enable_prejoin_ui: Whether the pre-join UI is enabled - enable_emoji_reactions: Whether emoji reactions are enabled - eject_at_room_exp: Whether to remove participants when room expires - enable_dialout: Whether SIP dial-out is enabled - enable_recording: Recording settings ('cloud', 'local', 'raw-tracks') - geo: Geographic region for room - max_participants: Maximum number of participants allowed in the room - sip: SIP configuration parameters - sip_uri: SIP URI information returned by Daily - start_video_off: Whether video is off by default - Reference: https://docs.daily.co/reference/rest-api/rooms/create-room#properties + + Parameters: + exp: Optional Unix epoch timestamp for room expiration (e.g., time.time() + 300 for 5 minutes). + enable_chat: Whether chat is enabled in the room. + enable_prejoin_ui: Whether the pre-join UI is enabled. + enable_emoji_reactions: Whether emoji reactions are enabled. + eject_at_room_exp: Whether to remove participants when room expires. + enable_dialout: Whether SIP dial-out is enabled. + enable_recording: Recording settings ('cloud', 'local', 'raw-tracks'). + geo: Geographic region for room. + max_participants: Maximum number of participants allowed in the room. + recordings_bucket: Configuration for custom S3 bucket recordings. + sip: SIP configuration parameters. + sip_uri: SIP URI information returned by Daily. + start_video_off: Whether video is off by default. """ exp: Optional[float] = None @@ -85,7 +92,7 @@ class DailyRoomProperties(BaseModel, extra="allow"): """Get the SIP endpoint URI if available. Returns: - str: SIP endpoint URI or empty string if not available + SIP endpoint URI or empty string if not available. 
""" if not self.sip_uri: return "" @@ -96,10 +103,10 @@ class DailyRoomProperties(BaseModel, extra="allow"): class DailyRoomParams(BaseModel): """Parameters for creating a Daily room. - Attributes: - name: Optional custom name for the room - privacy: Room privacy setting ('private' or 'public') - properties: Room configuration properties + Parameters: + name: Optional custom name for the room. + privacy: Room privacy setting ('private' or 'public'). + properties: Room configuration properties. """ name: Optional[str] = None @@ -110,14 +117,14 @@ class DailyRoomParams(BaseModel): class DailyRoomObject(BaseModel): """Represents a Daily room returned by the API. - Attributes: - id: Unique room identifier - name: Room name - api_created: Whether room was created via API - privacy: Room privacy setting ('private' or 'public') - url: Full URL for joining the room + Parameters: + id: Unique room identifier. + name: Room name. + api_created: Whether room was created via API. + privacy: Room privacy setting ('private' or 'public'). + url: Full URL for joining the room. created_at: Timestamp of room creation in ISO 8601 format (e.g., "2019-01-26T09:01:22.000Z"). - config: Room configuration properties + config: Room configuration properties. """ id: str @@ -134,71 +141,40 @@ class DailyMeetingTokenProperties(BaseModel): Refer to the Daily API documentation for more information: https://docs.daily.co/reference/rest-api/meeting-tokens/create-meeting-token#properties + + Parameters: + room_name: The room for which this token is valid. If not set, the token is valid for all rooms in your domain. + eject_at_token_exp: If True, the user will be ejected from the room when the token expires. + eject_after_elapsed: The number of seconds after which the user will be ejected from the room. + nbf: Not before timestamp - users cannot join with this token before this time. + exp: Expiration time (unix timestamp in seconds). Strongly recommended for security. 
+ is_owner: If True, the token will grant owner privileges in the room. + user_name: The name of the user. This will be added to the token payload. + user_id: A unique identifier for the user. This will be added to the token payload. + enable_screenshare: If True, the user will be able to share their screen. + start_video_off: If True, the user's video will be turned off when they join the room. + start_audio_off: If True, the user's audio will be turned off when they join the room. + enable_recording: Recording settings for the token. Must be one of 'cloud', 'local' or 'raw-tracks'. + enable_prejoin_ui: If True, the user will see the prejoin UI before joining the room. + start_cloud_recording: Start cloud recording when the user joins the room. + permissions: Specifies the initial default permissions for a non-meeting-owner participant. """ - room_name: Optional[str] = Field( - default=None, - description="The room for which this token is valid. If not set, the token is valid for all rooms in your domain. You should always set room_name if using this token to control meeting access.", - ) - - eject_at_token_exp: Optional[bool] = Field( - default=None, - description="If `true`, the user will be ejected from the room when the token expires. Defaults to `false`.", - ) - eject_after_elapsed: Optional[int] = Field( - default=None, - description="The number of seconds after which the user will be ejected from the room. If not provided, the user will not be ejected based on elapsed time.", - ) - - nbf: Optional[int] = Field( - default=None, - description="Not before. This is a unix timestamp (seconds since the epoch.) Users cannot join a meeting in with this token before this time.", - ) - - exp: Optional[int] = Field( - default=None, - description="Expiration time (unix timestamp in seconds). We strongly recommend setting this value for security. If not set, the token will not expire. 
Refer docs for more info.", - ) - is_owner: Optional[bool] = Field( - default=None, - description="If `true`, the token will grant owner privileges in the room. Defaults to `false`.", - ) - user_name: Optional[str] = Field( - default=None, - description="The name of the user. This will be added to the token payload.", - ) - user_id: Optional[str] = Field( - default=None, - description="A unique identifier for the user. This will be added to the token payload.", - ) - enable_screenshare: Optional[bool] = Field( - default=None, - description="If `true`, the user will be able to share their screen. Defaults to `true`.", - ) - start_video_off: Optional[bool] = Field( - default=None, - description="If `true`, the user's video will be turned off when they join the room. Defaults to `false`.", - ) - start_audio_off: Optional[bool] = Field( - default=None, - description="If `true`, the user's audio will be turned off when they join the room. Defaults to `false`.", - ) - enable_recording: Optional[Literal["cloud", "local", "raw-tracks"]] = Field( - default=None, - description="Recording settings for the token. Must be one of `cloud`, `local` or `raw-tracks`.", - ) - enable_prejoin_ui: Optional[bool] = Field( - default=None, - description="If `true`, the user will see the prejoin UI before joining the room.", - ) - start_cloud_recording: Optional[bool] = Field( - default=None, - description="Start cloud recording when the user joins the room. 
This can be used to always record and archive meetings, for example in a customer support context.", - ) - permissions: Optional[dict] = Field( - default=None, - description="Specifies the initial default permissions for a non-meeting-owner participant joining a call.", - ) + room_name: Optional[str] = None + eject_at_token_exp: Optional[bool] = None + eject_after_elapsed: Optional[int] = None + nbf: Optional[int] = None + exp: Optional[int] = None + is_owner: Optional[bool] = None + user_name: Optional[str] = None + user_id: Optional[str] = None + enable_screenshare: Optional[bool] = None + start_video_off: Optional[bool] = None + start_audio_off: Optional[bool] = None + enable_recording: Optional[Literal["cloud", "local", "raw-tracks"]] = None + enable_prejoin_ui: Optional[bool] = None + start_cloud_recording: Optional[bool] = None + permissions: Optional[dict] = None class DailyMeetingTokenParams(BaseModel): @@ -206,6 +182,9 @@ class DailyMeetingTokenParams(BaseModel): Refer to the Daily API documentation for more information: https://docs.daily.co/reference/rest-api/meeting-tokens/create-meeting-token#body-params + + Parameters: + properties: Meeting token configuration properties. """ properties: DailyMeetingTokenProperties = Field(default_factory=DailyMeetingTokenProperties) @@ -215,11 +194,6 @@ class DailyRESTHelper: """Helper class for interacting with Daily's REST API. Provides methods for creating, managing, and accessing Daily rooms. - - Args: - daily_api_key: Your Daily API key - daily_api_url: Daily API base URL (e.g. "https://api.daily.co/v1") - aiohttp_session: Async HTTP session for making requests """ def __init__( @@ -229,7 +203,13 @@ class DailyRESTHelper: daily_api_url: str = "https://api.daily.co/v1", aiohttp_session: aiohttp.ClientSession, ): - """Initialize the Daily REST helper.""" + """Initialize the Daily REST helper. + + Args: + daily_api_key: Your Daily API key. + daily_api_url: Daily API base URL (e.g. "https://api.daily.co/v1"). 
+ aiohttp_session: Async HTTP session for making requests. + """ self.daily_api_key = daily_api_key self.daily_api_url = daily_api_url self.aiohttp_session = aiohttp_session @@ -238,10 +218,10 @@ class DailyRESTHelper: """Extract room name from a Daily room URL. Args: - room_url: Full Daily room URL + room_url: Full Daily room URL. Returns: - str: Room name portion of the URL + Room name portion of the URL. """ return urlparse(room_url).path[1:] @@ -249,10 +229,10 @@ class DailyRESTHelper: """Get room details from a Daily room URL. Args: - room_url: Full Daily room URL + room_url: Full Daily room URL. Returns: - DailyRoomObject: DailyRoomObject instance for the room + DailyRoomObject instance for the room. """ room_name = self.get_name_from_url(room_url) return await self._get_room_from_name(room_name) @@ -261,13 +241,13 @@ class DailyRESTHelper: """Create a new Daily room. Args: - params: Room configuration parameters + params: Room configuration parameters. Returns: - DailyRoomObject: DailyRoomObject instance for the created room + DailyRoomObject instance for the created room. Raises: - Exception: If room creation fails or response is invalid + Exception: If room creation fails or response is invalid. """ headers = {"Authorization": f"Bearer {self.daily_api_key}"} json = params.model_dump(exclude_none=True) @@ -298,19 +278,19 @@ class DailyRESTHelper: """Generate a meeting token for user to join a Daily room. Args: - room_url: Daily room URL - expiry_time: Token validity duration in seconds (default: 1 hour) - eject_at_token_exp: Whether to eject user when token expires - owner: Whether token has owner privileges + room_url: Daily room URL. + expiry_time: Token validity duration in seconds (default: 1 hour). + eject_at_token_exp: Whether to eject user when token expires. + owner: Whether token has owner privileges. params: Optional additional token properties. 
Note that room_name, exp, and is_owner will be set based on the other function parameters regardless of values in params. Returns: - str: Meeting token + Meeting token. Raises: - Exception: If token generation fails or room URL is missing + Exception: If token generation fails or room URL is missing. """ if not room_url: raise Exception( @@ -355,10 +335,10 @@ class DailyRESTHelper: """Delete a room using its URL. Args: - room_url: Daily room URL + room_url: Daily room URL. Returns: - bool: True if deletion was successful + True if deletion was successful. """ room_name = self.get_name_from_url(room_url) return await self.delete_room_by_name(room_name) @@ -367,13 +347,13 @@ class DailyRESTHelper: """Delete a room using its name. Args: - room_name: Name of the room to delete + room_name: Name of the room to delete. Returns: - bool: True if deletion was successful + True if deletion was successful. Raises: - Exception: If deletion fails (excluding 404 Not Found) + Exception: If deletion fails (excluding 404 Not Found). """ headers = {"Authorization": f"Bearer {self.daily_api_key}"} async with self.aiohttp_session.delete( @@ -386,17 +366,7 @@ class DailyRESTHelper: return True async def _get_room_from_name(self, room_name: str) -> DailyRoomObject: - """Internal method to get room details by name. 
- - Args: - room_name: Name of the room - - Returns: - DailyRoomObject: DailyRoomObject instance for the room - - Raises: - Exception: If room is not found or response is invalid - """ + """Internal method to get room details by name.""" headers = {"Authorization": f"Bearer {self.daily_api_key}"} async with self.aiohttp_session.get( f"{self.daily_api_url}/rooms/{room_name}", headers=headers diff --git a/src/pipecat/transports/services/livekit.py b/src/pipecat/transports/services/livekit.py index 53dd091ef..d7363b9e7 100644 --- a/src/pipecat/transports/services/livekit.py +++ b/src/pipecat/transports/services/livekit.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""LiveKit transport implementation for Pipecat. + +This module provides comprehensive LiveKit real-time communication integration +including audio streaming, data messaging, participant management, and room +event handling for conversational AI applications. +""" + import asyncio from dataclasses import dataclass from typing import Any, Awaitable, Callable, List, Optional @@ -41,19 +48,49 @@ except ModuleNotFoundError as e: @dataclass class LiveKitTransportMessageFrame(TransportMessageFrame): + """Frame for transport messages in LiveKit rooms. + + Parameters: + participant_id: Optional ID of the participant this message is for/from. + """ + participant_id: Optional[str] = None @dataclass class LiveKitTransportMessageUrgentFrame(TransportMessageUrgentFrame): + """Frame for urgent transport messages in LiveKit rooms. + + Parameters: + participant_id: Optional ID of the participant this message is for/from. + """ + participant_id: Optional[str] = None class LiveKitParams(TransportParams): + """Configuration parameters for LiveKit transport. + + Inherits all parameters from TransportParams without additional configuration. + """ + pass class LiveKitCallbacks(BaseModel): + """Callback handlers for LiveKit events. + + Parameters: + on_connected: Called when connected to the LiveKit room. 
+ on_disconnected: Called when disconnected from the LiveKit room. + on_participant_connected: Called when a participant joins the room. + on_participant_disconnected: Called when a participant leaves the room. + on_audio_track_subscribed: Called when an audio track is subscribed. + on_audio_track_unsubscribed: Called when an audio track is unsubscribed. + on_data_received: Called when data is received from a participant. + on_first_participant_joined: Called when the first participant joins. + """ + on_connected: Callable[[], Awaitable[None]] on_disconnected: Callable[[], Awaitable[None]] on_participant_connected: Callable[[str], Awaitable[None]] @@ -65,6 +102,12 @@ class LiveKitCallbacks(BaseModel): class LiveKitTransportClient: + """Core client for interacting with LiveKit rooms. + + Manages the connection to LiveKit rooms and handles all low-level API interactions + including room management, audio streaming, data messaging, and event handling. + """ + def __init__( self, url: str, @@ -74,6 +117,16 @@ class LiveKitTransportClient: callbacks: LiveKitCallbacks, transport_name: str, ): + """Initialize the LiveKit transport client. + + Args: + url: LiveKit server URL to connect to. + token: Authentication token for the room. + room_name: Name of the LiveKit room to join. + params: Configuration parameters for the transport. + callbacks: Event callback handlers. + transport_name: Name identifier for the transport. + """ self._url = url self._token = token self._room_name = room_name @@ -93,15 +146,33 @@ class LiveKitTransportClient: @property def participant_id(self) -> str: + """Get the participant ID for this client. + + Returns: + The participant ID assigned by LiveKit. + """ return self._participant_id @property def room(self) -> rtc.Room: + """Get the LiveKit room instance. + + Returns: + The LiveKit room object. + + Raises: + Exception: If room object is not available. 
+ """ if not self._room: raise Exception(f"{self}: missing room object (pipeline not started?)") return self._room async def setup(self, setup: FrameProcessorSetup): + """Set up the client with task manager and room initialization. + + Args: + setup: The frame processor setup configuration. + """ if self._task_manager: return @@ -118,13 +189,20 @@ class LiveKitTransportClient: self.room.on("disconnected")(self._on_disconnected_wrapper) async def cleanup(self): + """Clean up client resources.""" await self.disconnect() async def start(self, frame: StartFrame): + """Start the client and initialize audio components. + + Args: + frame: The start frame containing initialization parameters. + """ self._out_sample_rate = self._params.audio_out_sample_rate or frame.audio_out_sample_rate @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10)) async def connect(self): + """Connect to the LiveKit room with retry logic.""" if self._connected: # Increment disconnect counter if already connected. self._disconnect_counter += 1 @@ -168,6 +246,7 @@ class LiveKitTransportClient: raise async def disconnect(self): + """Disconnect from the LiveKit room.""" # Decrement leave counter when leaving. self._disconnect_counter -= 1 @@ -181,6 +260,12 @@ class LiveKitTransportClient: await self._callbacks.on_disconnected() async def send_data(self, data: bytes, participant_id: Optional[str] = None): + """Send data to participants in the room. + + Args: + data: The data bytes to send. + participant_id: Optional specific participant to send to. + """ if not self._connected: return @@ -195,6 +280,11 @@ class LiveKitTransportClient: logger.error(f"Error sending data: {e}") async def publish_audio(self, audio_frame: rtc.AudioFrame): + """Publish an audio frame to the room. + + Args: + audio_frame: The LiveKit audio frame to publish.
+ """ if not self._connected or not self._audio_source: return @@ -204,9 +294,22 @@ class LiveKitTransportClient: logger.error(f"Error publishing audio: {e}") def get_participants(self) -> List[str]: + """Get list of participant IDs in the room. + + Returns: + List of participant IDs. + """ return [p.sid for p in self.room.remote_participants.values()] async def get_participant_metadata(self, participant_id: str) -> dict: + """Get metadata for a specific participant. + + Args: + participant_id: ID of the participant to get metadata for. + + Returns: + Dictionary containing participant metadata. + """ participant = self.room.remote_participants.get(participant_id) if participant: return { @@ -218,9 +321,19 @@ class LiveKitTransportClient: return {} async def set_participant_metadata(self, metadata: str): + """Set metadata for the local participant. + + Args: + metadata: Metadata string to set. + """ await self.room.local_participant.set_metadata(metadata) async def mute_participant(self, participant_id: str): + """Mute a specific participant's audio tracks. + + Args: + participant_id: ID of the participant to mute. + """ participant = self.room.remote_participants.get(participant_id) if participant: for track in participant.tracks.values(): @@ -228,6 +341,11 @@ class LiveKitTransportClient: await track.set_enabled(False) async def unmute_participant(self, participant_id: str): + """Unmute a specific participant's audio tracks. + + Args: + participant_id: ID of the participant to unmute. 
+ """ participant = self.room.remote_participants.get(participant_id) if participant: for track in participant.tracks.values(): @@ -236,12 +354,14 @@ class LiveKitTransportClient: # Wrapper methods for event handlers def _on_participant_connected_wrapper(self, participant: rtc.RemoteParticipant): + """Wrapper for participant connected events.""" self._task_manager.create_task( self._async_on_participant_connected(participant), f"{self}::_async_on_participant_connected", ) def _on_participant_disconnected_wrapper(self, participant: rtc.RemoteParticipant): + """Wrapper for participant disconnected events.""" self._task_manager.create_task( self._async_on_participant_disconnected(participant), f"{self}::_async_on_participant_disconnected", @@ -253,6 +373,7 @@ class LiveKitTransportClient: publication: rtc.RemoteTrackPublication, participant: rtc.RemoteParticipant, ): + """Wrapper for track subscribed events.""" self._task_manager.create_task( self._async_on_track_subscribed(track, publication, participant), f"{self}::_async_on_track_subscribed", @@ -264,27 +385,32 @@ class LiveKitTransportClient: publication: rtc.RemoteTrackPublication, participant: rtc.RemoteParticipant, ): + """Wrapper for track unsubscribed events.""" self._task_manager.create_task( self._async_on_track_unsubscribed(track, publication, participant), f"{self}::_async_on_track_unsubscribed", ) def _on_data_received_wrapper(self, data: rtc.DataPacket): + """Wrapper for data received events.""" self._task_manager.create_task( self._async_on_data_received(data), f"{self}::_async_on_data_received", ) def _on_connected_wrapper(self): + """Wrapper for connected events.""" self._task_manager.create_task(self._async_on_connected(), f"{self}::_async_on_connected") def _on_disconnected_wrapper(self): + """Wrapper for disconnected events.""" self._task_manager.create_task( self._async_on_disconnected(), f"{self}::_async_on_disconnected" ) # Async methods for event handling async def 
_async_on_participant_connected(self, participant: rtc.RemoteParticipant): + """Handle participant connected events.""" logger.info(f"Participant connected: {participant.identity}") await self._callbacks.on_participant_connected(participant.sid) if not self._other_participant_has_joined: @@ -292,6 +418,7 @@ class LiveKitTransportClient: await self._callbacks.on_first_participant_joined(participant.sid) async def _async_on_participant_disconnected(self, participant: rtc.RemoteParticipant): + """Handle participant disconnected events.""" logger.info(f"Participant disconnected: {participant.identity}") await self._callbacks.on_participant_disconnected(participant.sid) if len(self.get_participants()) == 0: @@ -303,6 +430,7 @@ class LiveKitTransportClient: publication: rtc.RemoteTrackPublication, participant: rtc.RemoteParticipant, ): + """Handle track subscribed events.""" if track.kind == rtc.TrackKind.KIND_AUDIO: logger.info(f"Audio track subscribed: {track.sid} from participant {participant.sid}") self._audio_tracks[participant.sid] = track @@ -318,22 +446,27 @@ class LiveKitTransportClient: publication: rtc.RemoteTrackPublication, participant: rtc.RemoteParticipant, ): + """Handle track unsubscribed events.""" logger.info(f"Track unsubscribed: {publication.sid} from {participant.identity}") if track.kind == rtc.TrackKind.KIND_AUDIO: await self._callbacks.on_audio_track_unsubscribed(participant.sid) async def _async_on_data_received(self, data: rtc.DataPacket): + """Handle data received events.""" await self._callbacks.on_data_received(data.data, data.participant.sid) async def _async_on_connected(self): + """Handle connected events.""" await self._callbacks.on_connected() async def _async_on_disconnected(self, reason=None): + """Handle disconnected events.""" self._connected = False logger.info(f"Disconnected from {self._room_name}. 
Reason: {reason}") await self._callbacks.on_disconnected() async def _process_audio_stream(self, audio_stream: rtc.AudioStream, participant_id: str): + """Process incoming audio stream from a participant.""" logger.info(f"Started processing audio stream for participant {participant_id}") async for event in audio_stream: if isinstance(event, rtc.AudioFrameEvent): @@ -342,15 +475,23 @@ class LiveKitTransportClient: logger.warning(f"Received unexpected event type: {type(event)}") async def get_next_audio_frame(self): + """Get the next audio frame from the queue.""" while True: frame, participant_id = await self._audio_queue.get() yield frame, participant_id def __str__(self): + """String representation of the LiveKit transport client.""" return f"{self._transport_name}::LiveKitTransportClient" class LiveKitInputTransport(BaseInputTransport): + """Handles incoming media streams and events from LiveKit rooms. + + Processes incoming audio streams from room participants and forwards them + as Pipecat frames, including audio resampling and VAD integration. + """ + def __init__( self, transport: BaseTransport, @@ -358,6 +499,14 @@ class LiveKitInputTransport(BaseInputTransport): params: LiveKitParams, **kwargs, ): + """Initialize the LiveKit input transport. + + Args: + transport: The parent transport instance. + client: LiveKitTransportClient instance. + params: Configuration parameters. + **kwargs: Additional arguments passed to parent class. + """ super().__init__(params, **kwargs) self._transport = transport self._client = client @@ -371,9 +520,19 @@ class LiveKitInputTransport(BaseInputTransport): @property def vad_analyzer(self) -> Optional[VADAnalyzer]: + """Get the Voice Activity Detection analyzer. + + Returns: + The VAD analyzer instance if configured. + """ return self._vad_analyzer async def start(self, frame: StartFrame): + """Start the input transport and connect to LiveKit room. + + Args: + frame: The start frame containing initialization parameters. 
+ """ await super().start(frame) if self._initialized: @@ -389,6 +548,11 @@ class LiveKitInputTransport(BaseInputTransport): logger.info("LiveKitInputTransport started") async def stop(self, frame: EndFrame): + """Stop the input transport and disconnect from LiveKit room. + + Args: + frame: The end frame signaling transport shutdown. + """ await super().stop(frame) await self._client.disconnect() if self._audio_in_task: @@ -396,24 +560,42 @@ class LiveKitInputTransport(BaseInputTransport): logger.info("LiveKitInputTransport stopped") async def cancel(self, frame: CancelFrame): + """Cancel the input transport and disconnect from LiveKit room. + + Args: + frame: The cancel frame signaling immediate cancellation. + """ await super().cancel(frame) await self._client.disconnect() if self._audio_in_task and self._params.audio_in_enabled: await self.cancel_task(self._audio_in_task) async def setup(self, setup: FrameProcessorSetup): + """Setup the input transport with shared client setup. + + Args: + setup: The frame processor setup configuration. + """ await super().setup(setup) await self._client.setup(setup) async def cleanup(self): + """Cleanup input transport and shared resources.""" await super().cleanup() await self._transport.cleanup() async def push_app_message(self, message: Any, sender: str): + """Push an application message as an urgent transport frame. + + Args: + message: The message data to send. + sender: ID of the message sender. 
+ """ frame = LiveKitTransportMessageUrgentFrame(message=message, participant_id=sender) await self.push_frame(frame) async def _audio_in_task_handler(self): + """Handle incoming audio frames from participants.""" logger.info("Audio input task started") audio_iterator = self._client.get_next_audio_frame() async for audio_data in WatchdogAsyncIterator(audio_iterator, manager=self.task_manager): @@ -433,6 +615,7 @@ class LiveKitInputTransport(BaseInputTransport): async def _convert_livekit_audio_to_pipecat( self, audio_frame_event: rtc.AudioFrameEvent ) -> AudioRawFrame: + """Convert LiveKit audio frame to Pipecat audio frame.""" audio_frame = audio_frame_event.frame audio_data = await self._resampler.resample( @@ -447,6 +630,12 @@ class LiveKitInputTransport(BaseInputTransport): class LiveKitOutputTransport(BaseOutputTransport): + """Handles outgoing media streams and events to LiveKit rooms. + + Manages sending audio frames and data messages to LiveKit room participants, + including audio format conversion for LiveKit compatibility. + """ + def __init__( self, transport: BaseTransport, @@ -454,6 +643,14 @@ class LiveKitOutputTransport(BaseOutputTransport): params: LiveKitParams, **kwargs, ): + """Initialize the LiveKit output transport. + + Args: + transport: The parent transport instance. + client: LiveKitTransportClient instance. + params: Configuration parameters. + **kwargs: Additional arguments passed to parent class. + """ super().__init__(params, **kwargs) self._transport = transport self._client = client @@ -462,6 +659,11 @@ class LiveKitOutputTransport(BaseOutputTransport): self._initialized = False async def start(self, frame: StartFrame): + """Start the output transport and connect to LiveKit room. + + Args: + frame: The start frame containing initialization parameters. 
+ """ await super().start(frame) if self._initialized: @@ -475,33 +677,60 @@ class LiveKitOutputTransport(BaseOutputTransport): logger.info("LiveKitOutputTransport started") async def stop(self, frame: EndFrame): + """Stop the output transport and disconnect from LiveKit room. + + Args: + frame: The end frame signaling transport shutdown. + """ await super().stop(frame) await self._client.disconnect() logger.info("LiveKitOutputTransport stopped") async def cancel(self, frame: CancelFrame): + """Cancel the output transport and disconnect from LiveKit room. + + Args: + frame: The cancel frame signaling immediate cancellation. + """ await super().cancel(frame) await self._client.disconnect() async def setup(self, setup: FrameProcessorSetup): + """Setup the output transport with shared client setup. + + Args: + setup: The frame processor setup configuration. + """ await super().setup(setup) await self._client.setup(setup) async def cleanup(self): + """Cleanup output transport and shared resources.""" await super().cleanup() await self._transport.cleanup() async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame): + """Send a transport message to participants. + + Args: + frame: The transport message frame to send. + """ if isinstance(frame, (LiveKitTransportMessageFrame, LiveKitTransportMessageUrgentFrame)): await self._client.send_data(frame.message.encode(), frame.participant_id) else: await self._client.send_data(frame.message.encode()) async def write_audio_frame(self, frame: OutputAudioRawFrame): + """Write an audio frame to the LiveKit room. + + Args: + frame: The audio frame to write. 
+ """ livekit_audio = self._convert_pipecat_audio_to_livekit(frame.audio) await self._client.publish_audio(livekit_audio) def _convert_pipecat_audio_to_livekit(self, pipecat_audio: bytes) -> rtc.AudioFrame: + """Convert Pipecat audio data to LiveKit audio frame.""" bytes_per_sample = 2 # Assuming 16-bit audio total_samples = len(pipecat_audio) // bytes_per_sample samples_per_channel = total_samples // self._params.audio_out_channels @@ -515,6 +744,13 @@ class LiveKitOutputTransport(BaseOutputTransport): class LiveKitTransport(BaseTransport): + """Transport implementation for LiveKit real-time communication. + + Provides comprehensive LiveKit integration including audio streaming, data + messaging, participant management, and room event handling for conversational + AI applications. + """ + def __init__( self, url: str, @@ -524,6 +760,16 @@ class LiveKitTransport(BaseTransport): input_name: Optional[str] = None, output_name: Optional[str] = None, ): + """Initialize the LiveKit transport. + + Args: + url: LiveKit server URL to connect to. + token: Authentication token for the room. + room_name: Name of the LiveKit room to join. + params: Configuration parameters for the transport. + input_name: Optional name for the input transport. + output_name: Optional name for the output transport. + """ super().__init__(input_name=input_name, output_name=output_name) callbacks = LiveKitCallbacks( @@ -556,6 +802,11 @@ class LiveKitTransport(BaseTransport): self._register_event_handler("on_call_state_updated") def input(self) -> LiveKitInputTransport: + """Get the input transport for receiving media and events. + + Returns: + The LiveKit input transport instance. + """ if not self._input: self._input = LiveKitInputTransport( self, self._client, self._params, name=self._input_name @@ -563,6 +814,11 @@ class LiveKitTransport(BaseTransport): return self._input def output(self) -> LiveKitOutputTransport: + """Get the output transport for sending media and events. 
+ + Returns: + The LiveKit output transport instance. + """ if not self._output: self._output = LiveKitOutputTransport( self, self._client, self._params, name=self._output_name @@ -571,41 +827,84 @@ class LiveKitTransport(BaseTransport): @property def participant_id(self) -> str: + """Get the participant ID for this transport. + + Returns: + The participant ID assigned by LiveKit. + """ return self._client.participant_id async def send_audio(self, frame: OutputAudioRawFrame): + """Send an audio frame to the LiveKit room. + + Args: + frame: The audio frame to send. + """ if self._output: await self._output.queue_frame(frame, FrameDirection.DOWNSTREAM) def get_participants(self) -> List[str]: + """Get list of participant IDs in the room. + + Returns: + List of participant IDs. + """ return self._client.get_participants() async def get_participant_metadata(self, participant_id: str) -> dict: + """Get metadata for a specific participant. + + Args: + participant_id: ID of the participant to get metadata for. + + Returns: + Dictionary containing participant metadata. + """ return await self._client.get_participant_metadata(participant_id) async def set_metadata(self, metadata: str): + """Set metadata for the local participant. + + Args: + metadata: Metadata string to set. + """ await self._client.set_participant_metadata(metadata) async def mute_participant(self, participant_id: str): + """Mute a specific participant's audio tracks. + + Args: + participant_id: ID of the participant to mute. + """ await self._client.mute_participant(participant_id) async def unmute_participant(self, participant_id: str): + """Unmute a specific participant's audio tracks. + + Args: + participant_id: ID of the participant to unmute. 
+ """ await self._client.unmute_participant(participant_id) async def _on_connected(self): + """Handle room connected events.""" await self._call_event_handler("on_connected") async def _on_disconnected(self): + """Handle room disconnected events.""" await self._call_event_handler("on_disconnected") async def _on_participant_connected(self, participant_id: str): + """Handle participant connected events.""" await self._call_event_handler("on_participant_connected", participant_id) async def _on_participant_disconnected(self, participant_id: str): + """Handle participant disconnected events.""" await self._call_event_handler("on_participant_disconnected", participant_id) await self._call_event_handler("on_participant_left", participant_id, "disconnected") async def _on_audio_track_subscribed(self, participant_id: str): + """Handle audio track subscribed events.""" await self._call_event_handler("on_audio_track_subscribed", participant_id) participant = self._client.room.remote_participants.get(participant_id) if participant: @@ -615,19 +914,33 @@ class LiveKitTransport(BaseTransport): ) async def _on_audio_track_unsubscribed(self, participant_id: str): + """Handle audio track unsubscribed events.""" await self._call_event_handler("on_audio_track_unsubscribed", participant_id) async def _on_data_received(self, data: bytes, participant_id: str): + """Handle data received events.""" if self._input: await self._input.push_app_message(data.decode(), participant_id) await self._call_event_handler("on_data_received", data, participant_id) async def send_message(self, message: str, participant_id: Optional[str] = None): + """Send a message to participants in the room. + + Args: + message: The message string to send. + participant_id: Optional specific participant to send to. 
+ """ if self._output: frame = LiveKitTransportMessageFrame(message=message, participant_id=participant_id) await self._output.send_message(frame) async def send_message_urgent(self, message: str, participant_id: Optional[str] = None): + """Send an urgent message to participants in the room. + + Args: + message: The urgent message string to send. + participant_id: Optional specific participant to send to. + """ if self._output: frame = LiveKitTransportMessageUrgentFrame( message=message, participant_id=participant_id @@ -635,19 +948,36 @@ class LiveKitTransport(BaseTransport): await self._output.send_message(frame) async def on_room_event(self, event): + """Handle room events. + + Args: + event: The room event to handle. + """ # Handle room events pass async def on_participant_event(self, event): + """Handle participant events. + + Args: + event: The participant event to handle. + """ # Handle participant events pass async def on_track_event(self, event): + """Handle track events. + + Args: + event: The track event to handle. + """ # Handle track events pass async def _on_call_state_updated(self, state: str): + """Handle call state update events.""" await self._call_event_handler("on_call_state_updated", self, state) async def _on_first_participant_joined(self, participant_id: str): + """Handle first participant joined events.""" await self._call_event_handler("on_first_participant_joined", participant_id) diff --git a/src/pipecat/transports/services/tavus.py b/src/pipecat/transports/services/tavus.py index ff70416d2..e83dc6a20 100644 --- a/src/pipecat/transports/services/tavus.py +++ b/src/pipecat/transports/services/tavus.py @@ -1,3 +1,16 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""Tavus transport implementation for Pipecat. + +This module provides integration with the Tavus platform for creating conversational +AI applications with avatars. 
It manages conversation sessions and provides real-time +audio/video streaming capabilities through the Tavus API. +""" + import os from functools import partial from typing import Any, Awaitable, Callable, Mapping, Optional @@ -31,8 +44,10 @@ from pipecat.transports.services.daily import ( class TavusApi: - """ - A helper class for interacting with the Tavus API (v2). + """Helper class for interacting with the Tavus API (v2). + + Provides methods for creating and managing conversations with Tavus avatars, + including conversation lifecycle management and persona information retrieval. """ BASE_URL = "https://tavusapi.com/v2" @@ -40,12 +55,11 @@ class TavusApi: MOCK_PERSONA_NAME = "TestTavusTransport" def __init__(self, api_key: str, session: aiohttp.ClientSession): - """ - Initialize the TavusApi client. + """Initialize the TavusApi client. Args: - api_key (str): Tavus API key. - session (aiohttp.ClientSession): An aiohttp session for making HTTP requests. + api_key: Tavus API key for authentication. + session: An aiohttp session for making HTTP requests. """ self._api_key = api_key self._session = session @@ -54,6 +68,15 @@ class TavusApi: self._dev_room_url = os.getenv("TAVUS_SAMPLE_ROOM_URL") async def create_conversation(self, replica_id: str, persona_id: str) -> dict: + """Create a new conversation with the specified replica and persona. + + Args: + replica_id: ID of the replica to use in the conversation. + persona_id: ID of the persona to use in the conversation. + + Returns: + Dictionary containing conversation_id and conversation_url. + """ if self._dev_room_url: return { "conversation_id": self.MOCK_CONVERSATION_ID, @@ -73,6 +96,11 @@ class TavusApi: return response async def end_conversation(self, conversation_id: str): + """End an existing conversation. + + Args: + conversation_id: ID of the conversation to end. 
+ """ if conversation_id is None or conversation_id == self.MOCK_CONVERSATION_ID: return @@ -82,6 +110,14 @@ class TavusApi: logger.debug(f"Ended Tavus conversation {conversation_id}") async def get_persona_name(self, persona_id: str) -> str: + """Get the name of a persona by ID. + + Args: + persona_id: ID of the persona to retrieve. + + Returns: + The name of the persona. + """ if self._dev_room_url is not None: return self.MOCK_PERSONA_NAME @@ -94,11 +130,11 @@ class TavusApi: class TavusCallbacks(BaseModel): - """Callback handlers for the Tavus events. + """Callback handlers for Tavus events. - Attributes: - on_participant_joined: Called when a participant joins. - on_participant_left: Called when a participant leaves. + Parameters: + on_participant_joined: Called when a participant joins the conversation. + on_participant_left: Called when a participant leaves the conversation. """ on_participant_joined: Callable[[Mapping[str, Any]], Awaitable[None]] @@ -106,7 +142,13 @@ class TavusCallbacks(BaseModel): class TavusParams(DailyParams): - """Configuration parameters for the Tavus transport.""" + """Configuration parameters for the Tavus transport. + + Parameters: + audio_in_enabled: Whether to enable audio input from participants. + audio_out_enabled: Whether to enable audio output to participants. + microphone_out_enabled: Whether to enable microphone output track. + """ audio_in_enabled: bool = True audio_out_enabled: bool = True @@ -114,24 +156,14 @@ class TavusParams(DailyParams): class TavusTransportClient: - """ + """Transport client that integrates Pipecat with the Tavus platform. + A transport client that integrates a Pipecat Bot with the Tavus platform by managing conversation sessions using the Tavus API. This client uses `TavusApi` to interact with the Tavus backend services. When a conversation is started via `TavusApi`, Tavus provides a `roomURL` that can be used to connect the Pipecat Bot into the same virtual room where the TavusBot is operating. 
- - Args: - bot_name (str): The name of the Pipecat bot instance. - params (TavusParams): Optional parameters for Tavus operation. Defaults to `TavusParams()`. - callbacks (TavusCallbacks): Callback handlers for Tavus-related events. - api_key (str): API key for authenticating with Tavus API. - replica_id (str): ID of the replica to use in the Tavus conversation. - persona_id (str): ID of the Tavus persona. Defaults to "pipecat-stream", which signals Tavus to use - the TTS voice of the Pipecat bot instead of a Tavus persona voice. - session (aiohttp.ClientSession): The aiohttp session for making async HTTP requests. - sample_rate: Audio sample rate to be used by the client. """ def __init__( @@ -145,6 +177,19 @@ class TavusTransportClient: persona_id: str = "pipecat-stream", session: aiohttp.ClientSession, ) -> None: + """Initialize the Tavus transport client. + + Args: + bot_name: The name of the Pipecat bot instance. + params: Optional parameters for Tavus operation. + callbacks: Callback handlers for Tavus-related events. + api_key: API key for authenticating with Tavus API. + replica_id: ID of the replica to use in the Tavus conversation. + persona_id: ID of the Tavus persona. Defaults to "pipecat-stream", + which signals Tavus to use the TTS voice of the Pipecat bot + instead of a Tavus persona voice. + session: The aiohttp session for making async HTTP requests. + """ self._bot_name = bot_name self._api = TavusApi(api_key, session) self._replica_id = replica_id @@ -155,11 +200,17 @@ class TavusTransportClient: self._params = params async def _initialize(self) -> str: + """Initialize the conversation and return the room URL.""" response = await self._api.create_conversation(self._replica_id, self._persona_id) self._conversation_id = response["conversation_id"] return response["conversation_url"] async def setup(self, setup: FrameProcessorSetup): + """Setup the client and initialize the conversation. + + Args: + setup: The frame processor setup configuration. 
+ """ if self._conversation_id is not None: logger.debug(f"Conversation ID already defined: {self._conversation_id}") return @@ -206,29 +257,44 @@ class TavusTransportClient: self._conversation_id = None async def cleanup(self): + """Cleanup client resources.""" try: await self._client.cleanup() except Exception as e: logger.exception(f"Exception during cleanup: {e}") async def _on_joined(self, data): + """Handle joined event.""" logger.debug("TavusTransportClient joined!") async def _on_left(self): + """Handle left event.""" logger.debug("TavusTransportClient left!") async def _on_handle_callback(self, event_name, *args, **kwargs): + """Handle generic callback events.""" logger.trace(f"[Callback] {event_name} called with args={args}, kwargs={kwargs}") async def get_persona_name(self) -> str: + """Get the persona name from the API. + + Returns: + The name of the current persona. + """ return await self._api.get_persona_name(self._persona_id) async def start(self, frame: StartFrame): + """Start the client and join the room. + + Args: + frame: The start frame containing initialization parameters. + """ logger.debug("TavusTransportClient start invoked!") await self._client.start(frame) await self._client.join() async def stop(self): + """Stop the client and end the conversation.""" await self._client.leave() await self._api.end_conversation(self._conversation_id) self._conversation_id = None @@ -241,6 +307,15 @@ class TavusTransportClient: video_source: str = "camera", color_format: str = "RGB", ): + """Capture video from a participant. + + Args: + participant_id: ID of the participant to capture video from. + callback: Callback function to handle video frames. + framerate: Desired framerate for video capture. + video_source: Video source to capture from. + color_format: Color format for video frames. 
+ """ await self._client.capture_participant_video( participant_id, callback, framerate, video_source, color_format ) @@ -253,22 +328,47 @@ class TavusTransportClient: sample_rate: int = 16000, callback_interval_ms: int = 20, ): + """Capture audio from a participant. + + Args: + participant_id: ID of the participant to capture audio from. + callback: Callback function to handle audio data. + audio_source: Audio source to capture from. + sample_rate: Desired sample rate for audio capture. + callback_interval_ms: Interval between audio callbacks in milliseconds. + """ await self._client.capture_participant_audio( participant_id, callback, audio_source, sample_rate, callback_interval_ms ) async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame): + """Send a message to participants. + + Args: + frame: The message frame to send. + """ await self._client.send_message(frame) @property def out_sample_rate(self) -> int: + """Get the output sample rate. + + Returns: + The output sample rate in Hz. + """ return self._client.out_sample_rate @property def in_sample_rate(self) -> int: + """Get the input sample rate. + + Returns: + The input sample rate in Hz. + """ return self._client.in_sample_rate async def send_interrupt_message(self) -> None: + """Send an interrupt message to the conversation.""" transport_frame = TransportMessageUrgentFrame( message={ "message_type": "conversation", @@ -279,6 +379,12 @@ class TavusTransportClient: await self.send_message(transport_frame) async def update_subscriptions(self, participant_settings=None, profile_settings=None): + """Update subscription settings for participants. + + Args: + participant_settings: Per-participant subscription settings. + profile_settings: Global subscription profile settings. + """ if not self._client: return @@ -287,11 +393,21 @@ class TavusTransportClient: ) async def write_audio_frame(self, frame: OutputAudioRawFrame): + """Write an audio frame to the transport. 
+ + Args: + frame: The audio frame to write. + """ if not self._client: return await self._client.write_audio_frame(frame) async def register_audio_destination(self, destination: str): + """Register an audio destination for output. + + Args: + destination: The destination identifier to register. + """ if not self._client: return @@ -299,12 +415,25 @@ class TavusTransportClient: class TavusInputTransport(BaseInputTransport): + """Input transport for receiving audio and events from Tavus conversations. + + Handles incoming audio streams from participants and manages audio capture + from the Daily room connected to the Tavus conversation. + """ + def __init__( self, client: TavusTransportClient, params: TransportParams, **kwargs, ): + """Initialize the Tavus input transport. + + Args: + client: The Tavus transport client instance. + params: Transport configuration parameters. + **kwargs: Additional arguments passed to parent class. + """ super().__init__(params, **kwargs) self._client = client self._params = params @@ -314,14 +443,25 @@ class TavusInputTransport(BaseInputTransport): self._initialized = False async def setup(self, setup: FrameProcessorSetup): + """Setup the input transport. + + Args: + setup: The frame processor setup configuration. + """ await super().setup(setup) await self._client.setup(setup) async def cleanup(self): + """Cleanup input transport resources.""" await super().cleanup() await self._client.cleanup() async def start(self, frame: StartFrame): + """Start the input transport. + + Args: + frame: The start frame containing initialization parameters. + """ await super().start(frame) if self._initialized: @@ -333,14 +473,29 @@ class TavusInputTransport(BaseInputTransport): await self.set_transport_ready(frame) async def stop(self, frame: EndFrame): + """Stop the input transport. + + Args: + frame: The end frame signaling transport shutdown. 
+ """ await super().stop(frame) await self._client.stop() async def cancel(self, frame: CancelFrame): + """Cancel the input transport. + + Args: + frame: The cancel frame signaling immediate cancellation. + """ await super().cancel(frame) await self._client.stop() async def start_capturing_audio(self, participant): + """Start capturing audio from a participant. + + Args: + participant: The participant to capture audio from. + """ if self._params.audio_in_enabled: logger.info( f"TavusTransportClient start capturing audio for participant {participant['id']}" @@ -354,6 +509,7 @@ class TavusInputTransport(BaseInputTransport): async def _on_participant_audio_data( self, participant_id: str, audio: AudioData, audio_source: str ): + """Handle received participant audio data.""" frame = InputAudioRawFrame( audio=audio.audio_frames, sample_rate=audio.audio_frames, @@ -364,12 +520,25 @@ class TavusInputTransport(BaseInputTransport): class TavusOutputTransport(BaseOutputTransport): + """Output transport for sending audio and events to Tavus conversations. + + Handles outgoing audio streams to participants and manages the custom + audio track expected by the Tavus platform. + """ + def __init__( self, client: TavusTransportClient, params: TransportParams, **kwargs, ): + """Initialize the Tavus output transport. + + Args: + client: The Tavus transport client instance. + params: Transport configuration parameters. + **kwargs: Additional arguments passed to parent class. + """ super().__init__(params, **kwargs) self._client = client self._params = params @@ -380,14 +549,25 @@ class TavusOutputTransport(BaseOutputTransport): self._transport_destination: Optional[str] = "stream" async def setup(self, setup: FrameProcessorSetup): + """Setup the output transport. + + Args: + setup: The frame processor setup configuration. 
+ """ await super().setup(setup) await self._client.setup(setup) async def cleanup(self): + """Cleanup output transport resources.""" await super().cleanup() await self._client.cleanup() async def start(self, frame: StartFrame): + """Start the output transport. + + Args: + frame: The start frame containing initialization parameters. + """ await super().start(frame) if self._initialized: @@ -403,51 +583,72 @@ class TavusOutputTransport(BaseOutputTransport): await self.set_transport_ready(frame) async def stop(self, frame: EndFrame): + """Stop the output transport. + + Args: + frame: The end frame signaling transport shutdown. + """ await super().stop(frame) await self._client.stop() async def cancel(self, frame: CancelFrame): + """Cancel the output transport. + + Args: + frame: The cancel frame signaling immediate cancellation. + """ await super().cancel(frame) await self._client.stop() async def send_message(self, frame: TransportMessageFrame | TransportMessageUrgentFrame): + """Send a message to participants. + + Args: + frame: The message frame to send. + """ logger.info(f"TavusOutputTransport sending message {frame}") await self._client.send_message(frame) async def process_frame(self, frame: Frame, direction: FrameDirection): + """Process frames and handle interruptions. + + Args: + frame: The frame to process. + direction: The direction of frame flow in the pipeline. + """ await super().process_frame(frame, direction) if isinstance(frame, StartInterruptionFrame): await self._handle_interruptions() async def _handle_interruptions(self): + """Handle interruption events by sending interrupt message.""" await self._client.send_interrupt_message() async def write_audio_frame(self, frame: OutputAudioRawFrame): + """Write an audio frame to the Tavus transport. + + Args: + frame: The audio frame to write. 
+ """ # This is the custom track destination expected by Tavus frame.transport_destination = self._transport_destination await self._client.write_audio_frame(frame) async def register_audio_destination(self, destination: str): + """Register an audio destination. + + Args: + destination: The destination identifier to register. + """ await self._client.register_audio_destination(destination) class TavusTransport(BaseTransport): - """ - Transport implementation for Tavus video calls. + """Transport implementation for Tavus video calls. When used, the Pipecat bot joins the same virtual room as the Tavus Avatar and the user. This is achieved by using `TavusTransportClient`, which initiates the conversation via `TavusApi` and obtains a room URL that all participants connect to. - - Args: - bot_name (str): The name of the Pipecat bot. - session (aiohttp.ClientSession): aiohttp session used for async HTTP requests. - api_key (str): Tavus API key for authentication. - replica_id (str): ID of the replica model used for voice generation. - persona_id (str): ID of the Tavus persona. Defaults to "pipecat-stream" to use the Pipecat TTS voice. - params (TavusParams): Optional Tavus-specific configuration parameters. - input_name (Optional[str]): Optional name for the input transport. - output_name (Optional[str]): Optional name for the output transport. """ def __init__( @@ -461,6 +662,19 @@ class TavusTransport(BaseTransport): input_name: Optional[str] = None, output_name: Optional[str] = None, ): + """Initialize the Tavus transport. + + Args: + bot_name: The name of the Pipecat bot. + session: aiohttp session used for async HTTP requests. + api_key: Tavus API key for authentication. + replica_id: ID of the replica model used for voice generation. + persona_id: ID of the Tavus persona. Defaults to "pipecat-stream" + to use the Pipecat TTS voice. + params: Optional Tavus-specific configuration parameters. + input_name: Optional name for the input transport. 
+ output_name: Optional name for the output transport. + """ super().__init__(input_name=input_name, output_name=output_name) self._params = params @@ -487,11 +701,13 @@ class TavusTransport(BaseTransport): self._register_event_handler("on_client_disconnected") async def _on_participant_left(self, participant, reason): + """Handle participant left events.""" persona_name = await self._client.get_persona_name() if participant.get("info", {}).get("userName", "") != persona_name: await self._on_client_disconnected(participant) async def _on_participant_joined(self, participant): + """Handle participant joined events.""" # get persona, look up persona_name, set this as the bot name to ignore persona_name = await self._client.get_persona_name() @@ -513,23 +729,41 @@ class TavusTransport(BaseTransport): await self._input.start_capturing_audio(participant) async def update_subscriptions(self, participant_settings=None, profile_settings=None): + """Update subscription settings for participants. + + Args: + participant_settings: Per-participant subscription settings. + profile_settings: Global subscription profile settings. + """ await self._client.update_subscriptions( participant_settings=participant_settings, profile_settings=profile_settings, ) def input(self) -> FrameProcessor: + """Get the input transport for receiving media and events. + + Returns: + The Tavus input transport instance. + """ if not self._input: self._input = TavusInputTransport(client=self._client, params=self._params) return self._input def output(self) -> FrameProcessor: + """Get the output transport for sending media and events. + + Returns: + The Tavus output transport instance. 
+ """ if not self._output: self._output = TavusOutputTransport(client=self._client, params=self._params) return self._output async def _on_client_connected(self, participant: Any): + """Handle client connected events.""" await self._call_event_handler("on_client_connected", participant) async def _on_client_disconnected(self, participant: Any): + """Handle client disconnected events.""" await self._call_event_handler("on_client_disconnected", participant) diff --git a/src/pipecat/utils/asyncio/task_manager.py b/src/pipecat/utils/asyncio/task_manager.py index 844536186..aaa340399 100644 --- a/src/pipecat/utils/asyncio/task_manager.py +++ b/src/pipecat/utils/asyncio/task_manager.py @@ -4,6 +4,14 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Asyncio task management with watchdog monitoring capabilities. + +This module provides task management functionality with optional watchdog timers +to monitor task execution and prevent hanging operations. Includes both abstract +base classes and concrete implementations for managing asyncio tasks with +comprehensive monitoring and cleanup capabilities. +""" + import asyncio import time from abc import ABC, abstractmethod @@ -17,6 +25,15 @@ WATCHDOG_TIMEOUT = 5.0 @dataclass class TaskManagerParams: + """Configuration parameters for task manager initialization. + + Parameters: + loop: The asyncio event loop to use for task management. + enable_watchdog_timers: Whether to enable watchdog timers for tasks. + enable_watchdog_logging: Whether to log watchdog timing information. + watchdog_timeout: Default timeout in seconds for watchdog timers. + """ + loop: asyncio.AbstractEventLoop enable_watchdog_timers: bool = False enable_watchdog_logging: bool = False @@ -24,12 +41,28 @@ class TaskManagerParams: class BaseTaskManager(ABC): + """Abstract base class for asyncio task management with watchdog support. 
+ + Provides the interface for creating, monitoring, and managing asyncio tasks + with optional watchdog timer functionality to detect stalled operations. + """ + @abstractmethod def setup(self, params: TaskManagerParams): + """Initialize the task manager with configuration parameters. + + Args: + params: Configuration parameters for task management. + """ pass @abstractmethod def get_event_loop(self) -> asyncio.AbstractEventLoop: + """Get the event loop used by this task manager. + + Returns: + The asyncio event loop instance. + """ pass @abstractmethod @@ -42,21 +75,19 @@ class BaseTaskManager(ABC): enable_watchdog_timers: Optional[bool] = None, watchdog_timeout: Optional[float] = None, ) -> asyncio.Task: - """ - Creates and schedules a new asyncio Task that runs the given coroutine. + """Creates and schedules a new asyncio Task that runs the given coroutine. The task is added to a global set of created tasks. Args: - loop (asyncio.AbstractEventLoop): The event loop to use for creating the task. - coroutine (Coroutine): The coroutine to be executed within the task. - name (str): The name to assign to the task for identification. - enable_watchdog_logging(bool): whether this task should log watchdog processing times. - enable_watchdog_timers(bool): whether this task should have a watchdog timer. - watchdog_timeout(float): watchdog timer timeout for this task. + coroutine: The coroutine to be executed within the task. + name: The name to assign to the task for identification. + enable_watchdog_logging: Whether this task should log watchdog processing times. + enable_watchdog_timers: Whether this task should have a watchdog timer. + watchdog_timeout: Watchdog timer timeout for this task. Returns: - asyncio.Task: The created task object. + The created task object. """ pass @@ -69,50 +100,67 @@ class BaseTaskManager(ABC): is removed from the set of registered tasks upon completion or failure. Args: - task (asyncio.Task): The asyncio Task to wait for. 
- timeout (Optional[float], optional): The maximum number of seconds - to wait for the task to complete. If None, waits indefinitely. - Defaults to None. + task: The asyncio Task to wait for. + timeout: The maximum number of seconds to wait for the task to complete. + If None, waits indefinitely. """ pass @abstractmethod async def cancel_task(self, task: asyncio.Task, timeout: Optional[float] = None): - """Cancels the given asyncio Task and awaits its completion with an - optional timeout. + """Cancels the given asyncio Task and awaits its completion with an optional timeout. This function removes the task from the set of registered tasks upon completion or failure. Args: - task (asyncio.Task): The task to be cancelled. - timeout (Optional[float]): The optional timeout in seconds to wait for the task to cancel. - + task: The task to be cancelled. + timeout: The optional timeout in seconds to wait for the task to cancel. """ pass @abstractmethod def current_tasks(self) -> Sequence[asyncio.Task]: - """Returns the list of currently created/registered tasks.""" + """Returns the list of currently created/registered tasks. + + Returns: + Sequence of currently managed asyncio tasks. + """ pass @abstractmethod def task_reset_watchdog(self): - """Resets the running task watchdog timer. If not reset, a warning will - be logged indicating the task is stalling. + """Task reset watchdog timer. + Resets the running task watchdog timer. If not reset, a warning will + be logged indicating the task is stalling. """ pass @property @abstractmethod def task_watchdog_enabled(self) -> bool: - """Whether the current running task has a watchdog timer enabled.""" + """Whether the current running task has a watchdog timer enabled. + + Returns: + True if the current task has watchdog monitoring active. + """ pass @dataclass class TaskData: + """Internal data structure for tracking task metadata and watchdog state. + + Parameters: + task: The asyncio Task being managed. 
+ watchdog_timer: Event used to reset the watchdog timer. + enable_watchdog_logging: Whether to log watchdog timing information. + enable_watchdog_timers: Whether watchdog timers are enabled for this task. + watchdog_timeout: Timeout in seconds for watchdog warnings. + watchdog_task: Optional background task monitoring the watchdog timer. + """ + task: asyncio.Task watchdog_timer: asyncio.Event enable_watchdog_logging: bool @@ -122,15 +170,36 @@ class TaskData: class TaskManager(BaseTaskManager): + """Concrete implementation of BaseTaskManager with full watchdog support. + + Manages asyncio tasks with optional watchdog monitoring to detect stalled + operations. Provides comprehensive task lifecycle management including + creation, monitoring, cancellation, and cleanup. + """ + def __init__(self) -> None: + """Initialize the task manager with empty task registry.""" self._tasks: Dict[str, TaskData] = {} self._params: Optional[TaskManagerParams] = None def setup(self, params: TaskManagerParams): + """Initialize the task manager with configuration parameters. + + Args: + params: Configuration parameters for task management. + """ if not self._params: self._params = params def get_event_loop(self) -> asyncio.AbstractEventLoop: + """Get the event loop used by this task manager. + + Returns: + The asyncio event loop instance. + + Raises: + Exception: If the task manager is not properly set up. + """ if not self._params: raise Exception("TaskManager is not setup: unable to get event loop") return self._params.loop @@ -144,21 +213,22 @@ class TaskManager(BaseTaskManager): enable_watchdog_timers: Optional[bool] = None, watchdog_timeout: Optional[float] = None, ) -> asyncio.Task: - """ - Creates and schedules a new asyncio Task that runs the given coroutine. + """Creates and schedules a new asyncio Task that runs the given coroutine. The task is added to a global set of created tasks. Args: - loop (asyncio.AbstractEventLoop): The event loop to use for creating the task. 
- coroutine (Coroutine): The coroutine to be executed within the task. - name (str): The name to assign to the task for identification. - enable_watchdog_logging(bool): whether this task should log watchdog processing time. - enable_watchdog_timers(bool): whether this task should have a watchdog timer. - watchdog_timeout(float): watchdog timer timeout for this task. + coroutine: The coroutine to be executed within the task. + name: The name to assign to the task for identification. + enable_watchdog_logging: Whether this task should log watchdog processing time. + enable_watchdog_timers: Whether this task should have a watchdog timer. + watchdog_timeout: Watchdog timer timeout for this task. Returns: - asyncio.Task: The created task object. + The created task object. + + Raises: + Exception: If the task manager is not properly set up. """ async def run_coroutine(): @@ -208,10 +278,9 @@ class TaskManager(BaseTaskManager): is removed from the set of registered tasks upon completion or failure. Args: - task (asyncio.Task): The asyncio Task to wait for. - timeout (Optional[float], optional): The maximum number of seconds - to wait for the task to complete. If None, waits indefinitely. - Defaults to None. + task: The asyncio Task to wait for. + timeout: The maximum number of seconds to wait for the task to complete. + If None, waits indefinitely. """ name = task.get_name() try: @@ -228,16 +297,14 @@ class TaskManager(BaseTaskManager): logger.exception(f"{name}: unexpected exception while stopping task: {e}") async def cancel_task(self, task: asyncio.Task, timeout: Optional[float] = None): - """Cancels the given asyncio Task and awaits its completion with an - optional timeout. + """Cancels the given asyncio Task and awaits its completion with an optional timeout. This function removes the task from the set of registered tasks upon completion or failure. Args: - task (asyncio.Task): The task to be cancelled. 
- timeout (Optional[float]): The optional timeout in seconds to wait for the task to cancel. - + task: The task to be cancelled. + timeout: The optional timeout in seconds to wait for the task to cancel. """ name = task.get_name() task.cancel() @@ -260,18 +327,28 @@ class TaskManager(BaseTaskManager): raise def reset_watchdog(self, task: asyncio.Task): + """Reset the watchdog timer for a specific task. + + Args: + task: The task whose watchdog timer should be reset. + """ name = task.get_name() if name in self._tasks and self._tasks[name].enable_watchdog_timers: self._tasks[name].watchdog_timer.set() def current_tasks(self) -> Sequence[asyncio.Task]: - """Returns the list of currently created/registered tasks.""" + """Returns the list of currently created/registered tasks. + + Returns: + Sequence of currently managed asyncio tasks. + """ return [data.task for data in self._tasks.values()] def task_reset_watchdog(self): - """Resets the running task watchdog timer. If not reset on time, a warning - will be logged indicating the task is stalling. + """Task reset watchdog timer. + Resets the running task watchdog timer. If not reset on time, a warning + will be logged indicating the task is stalling. """ task = asyncio.current_task() if task: @@ -279,6 +356,11 @@ class TaskManager(BaseTaskManager): @property def task_watchdog_enabled(self) -> bool: + """Whether the current running task has a watchdog timer enabled. + + Returns: + True if the current task has watchdog monitoring active. + """ task = asyncio.current_task() if not task: return False @@ -286,6 +368,11 @@ class TaskManager(BaseTaskManager): return name in self._tasks and self._tasks[name].enable_watchdog_timers def _add_task(self, task_data: TaskData): + """Add a task to the internal registry and start watchdog if enabled. + + Args: + task_data: The task data containing task and watchdog configuration. 
+ """ name = task_data.task.get_name() self._tasks[name] = task_data if self._params and task_data.enable_watchdog_timers: @@ -295,6 +382,11 @@ class TaskManager(BaseTaskManager): task_data.watchdog_task = watchdog_task async def _watchdog_task_handler(self, task_data: TaskData): + """Background task that monitors watchdog timer for a specific task. + + Args: + task_data: The task data containing watchdog configuration. + """ name = task_data.task.get_name() timer = task_data.watchdog_timer enable_watchdog_logging = task_data.enable_watchdog_logging @@ -315,6 +407,11 @@ class TaskManager(BaseTaskManager): timer.clear() def _task_done_handler(self, task: asyncio.Task): + """Handle task completion by cleaning up watchdog and removing from registry. + + Args: + task: The completed asyncio task. + """ name = task.get_name() try: task_data = self._tasks[name] diff --git a/src/pipecat/utils/asyncio/watchdog_async_iterator.py b/src/pipecat/utils/asyncio/watchdog_async_iterator.py index d9d3e2f79..c9db0ba7e 100644 --- a/src/pipecat/utils/asyncio/watchdog_async_iterator.py +++ b/src/pipecat/utils/asyncio/watchdog_async_iterator.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Watchdog-enabled async iterator wrapper for task monitoring. + +This module provides an async iterator wrapper that automatically resets +watchdog timers while waiting for iterator items, preventing false positive +watchdog timeouts during legitimate waiting periods. +""" + import asyncio from typing import AsyncIterator, Optional @@ -11,10 +18,11 @@ from pipecat.utils.asyncio.task_manager import BaseTaskManager class WatchdogAsyncIterator: - """An asynchronous iterator that monitors activity and resets the current + """Watchdog async iterator wrapper. + + An asynchronous iterator that monitors activity and resets the current task watchdog timer. This is necessary to prevent task watchdog timers from expiring while we are waiting to get an item from the iterator.
- """ def __init__( @@ -24,6 +32,13 @@ class WatchdogAsyncIterator: manager: BaseTaskManager, timeout: float = 2.0, ): + """Initialize the watchdog async iterator. + + Args: + async_iterable: The async iterable to wrap with watchdog monitoring. + manager: The task manager for watchdog timer control. + timeout: Timeout in seconds between watchdog resets while waiting. + """ self._async_iterable = async_iterable self._manager = manager self._timeout = timeout @@ -31,9 +46,22 @@ class WatchdogAsyncIterator: self._current_anext_task: Optional[asyncio.Task] = None def __aiter__(self): + """Return self as the async iterator. + + Returns: + This iterator instance. + """ return self async def __anext__(self): + """Get the next item from the iterator with watchdog monitoring. + + Returns: + The next item from the wrapped async iterator. + + Raises: + StopAsyncIteration: When the iterator is exhausted. + """ if not self._iter: self._iter = await self._ensure_async_iterator(self._async_iterable) @@ -43,6 +71,7 @@ class WatchdogAsyncIterator: return await self._iter.__anext__() async def _watchdog_anext(self): + """Get next item while periodically resetting watchdog timer.""" while True: try: if not self._current_anext_task: @@ -67,6 +96,7 @@ class WatchdogAsyncIterator: raise async def _ensure_async_iterator(self, obj) -> AsyncIterator: + """Ensure the object is an async iterator, awaiting if necessary.""" aiter = obj.__aiter__() if asyncio.iscoroutine(aiter): aiter = await aiter diff --git a/src/pipecat/utils/asyncio/watchdog_coroutine.py b/src/pipecat/utils/asyncio/watchdog_coroutine.py index 84855c3e6..234776548 100644 --- a/src/pipecat/utils/asyncio/watchdog_coroutine.py +++ b/src/pipecat/utils/asyncio/watchdog_coroutine.py @@ -4,17 +4,26 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Watchdog-enabled coroutine wrapper for task monitoring. 
+ +This module provides a coroutine wrapper that automatically resets watchdog +timers while waiting for coroutine completion, preventing false positive +watchdog timeouts during legitimate operations. +""" + import asyncio from typing import Optional +from pipecat.pipeline import task from pipecat.utils.asyncio.task_manager import BaseTaskManager class WatchdogCoroutine: - """An asynchronous iterator that monitors activity and resets the current + """Watchdog-enabled coroutine wrapper. + + A coroutine wrapper that monitors activity and resets the current task watchdog timer. This is necessary to prevent task watchdog timers from expiring while we are waiting for the wrapped coroutine to complete. - """ def __init__( @@ -24,18 +33,27 @@ class WatchdogCoroutine: manager: BaseTaskManager, timeout: float = 2.0, ): + """Initialize the watchdog coroutine wrapper. + + Args: + coroutine: The coroutine to wrap with watchdog monitoring. + manager: The task manager for watchdog timer control. + timeout: Timeout in seconds between watchdog resets while waiting. + """ self._coroutine = coroutine self._manager = manager self._timeout = timeout self._current_coro_task: Optional[asyncio.Task] = None async def __call__(self): + """Execute the wrapped coroutine with watchdog monitoring.""" if self._manager.task_watchdog_enabled: return await self._watchdog_call() else: return await self._coroutine async def _watchdog_call(self): + """Execute coroutine while periodically resetting watchdog timer.""" while True: try: if not self._current_coro_task: @@ -57,5 +75,15 @@ class WatchdogCoroutine: async def watchdog_coroutine(coroutine, *, manager: BaseTaskManager, timeout: float = 2.0): + """Execute a coroutine with watchdog monitoring support. + + Args: + coroutine: The coroutine to execute with watchdog monitoring. + manager: The task manager for watchdog timer control. + timeout: Timeout in seconds between watchdog resets while waiting. + + Returns: + The result of the coroutine execution.
+ """ watchdog_coro = WatchdogCoroutine(coroutine, manager=manager, timeout=timeout) return await watchdog_coro() diff --git a/src/pipecat/utils/asyncio/watchdog_event.py b/src/pipecat/utils/asyncio/watchdog_event.py index 65453f6ec..b2b306618 100644 --- a/src/pipecat/utils/asyncio/watchdog_event.py +++ b/src/pipecat/utils/asyncio/watchdog_event.py @@ -4,16 +4,24 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Watchdog-enabled asyncio Event for task monitoring. + +This module provides an asyncio Event subclass that automatically resets +watchdog timers while waiting for the event, preventing false positive +watchdog timeouts during legitimate waiting periods. +""" + import asyncio from pipecat.utils.asyncio.task_manager import BaseTaskManager class WatchdogEvent(asyncio.Event): - """An asynchronous event that resets the current task watchdog timer. This + """Watchdog-enabled asyncio Event. + + An asynchronous event that resets the current task watchdog timer. This is necessary to prevent task watchdog timers from expiring while we are waiting on the event. - """ def __init__( @@ -22,17 +30,29 @@ class WatchdogEvent(asyncio.Event): *, timeout: float = 2.0, ) -> None: + """Initialize the watchdog event. + + Args: + manager: The task manager for watchdog timer control. + timeout: Timeout in seconds between watchdog resets while waiting. + """ super().__init__() self._manager = manager self._timeout = timeout async def wait(self): + """Wait for the event to be set with watchdog monitoring. + + Returns: + True when the event is set.
+ """ if self._manager.task_watchdog_enabled: return await self._watchdog_wait() else: return await super().wait() async def _watchdog_wait(self): + """Wait for event while periodically resetting watchdog timer.""" while True: try: await asyncio.wait_for(super().wait(), timeout=self._timeout) diff --git a/src/pipecat/utils/asyncio/watchdog_priority_queue.py b/src/pipecat/utils/asyncio/watchdog_priority_queue.py index 31d358fc7..46c6adf3d 100644 --- a/src/pipecat/utils/asyncio/watchdog_priority_queue.py +++ b/src/pipecat/utils/asyncio/watchdog_priority_queue.py @@ -4,16 +4,24 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Watchdog-enabled asyncio PriorityQueue for task monitoring. + +This module provides an asyncio PriorityQueue subclass that automatically resets +watchdog timers while waiting for items, preventing false positive watchdog +timeouts during legitimate queue operations. +""" + import asyncio from pipecat.utils.asyncio.task_manager import BaseTaskManager class WatchdogPriorityQueue(asyncio.PriorityQueue): - """An asynchronous priority queue that resets the current task watchdog + """Watchdog-enabled asyncio PriorityQueue. + + An asynchronous priority queue that resets the current task watchdog timer. This is necessary to prevent task watchdog timers from expiring while we are waiting to get an item from the queue. - """ def __init__( @@ -23,22 +31,39 @@ class WatchdogPriorityQueue(asyncio.PriorityQueue): maxsize: int = 0, timeout: float = 2.0, ) -> None: + """Initialize the watchdog priority queue. + + Args: + manager: The task manager for watchdog timer control. + maxsize: Maximum queue size. 0 means unlimited. + timeout: Timeout in seconds between watchdog resets while waiting. + """ super().__init__(maxsize) self._manager = manager self._timeout = timeout async def get(self): + """Get an item from the queue with watchdog monitoring. + + Returns: + The next item from the priority queue.
+ """ if self._manager.task_watchdog_enabled: return await self._watchdog_get() else: return await super().get() def task_done(self): + """Mark a task as done and reset watchdog if enabled. + + Should be called after processing each item retrieved from the queue. + """ if self._manager.task_watchdog_enabled: self._manager.task_reset_watchdog() super().task_done() async def _watchdog_get(self): + """Get item from queue while periodically resetting watchdog timer.""" while True: try: item = await asyncio.wait_for(super().get(), timeout=self._timeout) diff --git a/src/pipecat/utils/asyncio/watchdog_queue.py b/src/pipecat/utils/asyncio/watchdog_queue.py index 961324b7b..4a92497f4 100644 --- a/src/pipecat/utils/asyncio/watchdog_queue.py +++ b/src/pipecat/utils/asyncio/watchdog_queue.py @@ -4,16 +4,24 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Watchdog-enabled asyncio Queue for task monitoring. + +This module provides an asyncio Queue subclass that automatically resets +watchdog timers while waiting for items, preventing false positive watchdog +timeouts during legitimate queue operations. +""" + import asyncio from pipecat.utils.asyncio.task_manager import BaseTaskManager class WatchdogQueue(asyncio.Queue): - """An asynchronous queue that resets the current task watchdog timer. This + """Watchdog-enabled asyncio Queue. + + An asynchronous queue that resets the current task watchdog timer. This is necessary to prevent task watchdog timers from expiring while we are waiting to get an item from the queue. - """ def __init__( @@ -23,22 +31,39 @@ class WatchdogQueue(asyncio.Queue): maxsize: int = 0, timeout: float = 2.0, ) -> None: + """Initialize the watchdog queue. + + Args: + manager: The task manager for watchdog timer control. + maxsize: Maximum queue size. 0 means unlimited. + timeout: Timeout in seconds between watchdog resets while waiting.
+ """ super().__init__(maxsize) self._manager = manager self._timeout = timeout async def get(self): + """Get an item from the queue with watchdog monitoring. + + Returns: + The next item from the queue. + """ if self._manager.task_watchdog_enabled: return await self._watchdog_get() else: return await super().get() def task_done(self): + """Mark a task as done and reset watchdog if enabled. + + Should be called after processing each item retrieved from the queue. + """ if self._manager.task_watchdog_enabled: self._manager.task_reset_watchdog() super().task_done() async def _watchdog_get(self): + """Get item from queue while periodically resetting watchdog timer.""" while True: try: item = await asyncio.wait_for(super().get(), timeout=self._timeout) diff --git a/src/pipecat/utils/base_object.py b/src/pipecat/utils/base_object.py index 03b42ade0..51eb1195b 100644 --- a/src/pipecat/utils/base_object.py +++ b/src/pipecat/utils/base_object.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Base object class providing event handling and lifecycle management. + +This module provides the foundational BaseObject class that offers common +functionality including unique identification, naming, event handling, +and async cleanup for all Pipecat components. +""" + import asyncio import inspect from abc import ABC @@ -15,7 +22,20 @@ from pipecat.utils.utils import obj_count, obj_id class BaseObject(ABC): + """Abstract base class providing common functionality for Pipecat objects. + + Provides unique identification, naming, event handling capabilities, + and async lifecycle management for all Pipecat components. All major + classes in the framework should inherit from this base class. + """ + def __init__(self, *, name: Optional[str] = None): + """Initialize the base object. + + Args: + name: Optional custom name for the object. If not provided, + generates a name using the class name and instance count. 
+ """ self._id: int = obj_id() self._name = name or f"{self.__class__.__name__}#{obj_count(self)}" @@ -29,19 +49,44 @@ class BaseObject(ABC): @property def id(self) -> int: + """Get the unique identifier for this object. + + Returns: + The unique integer ID assigned to this object instance. + """ return self._id @property def name(self) -> str: + """Get the name of this object. + + Returns: + The object's name, either custom-provided or auto-generated. + """ return self._name async def cleanup(self): + """Clean up resources and wait for running event handlers to complete. + + This method should be called when the object is no longer needed. + It waits for all currently executing event handler tasks to finish + before returning. + """ if self._event_tasks: event_names, tasks = zip(*self._event_tasks) logger.debug(f"{self} waiting on event handlers to finish {list(event_names)}...") await asyncio.wait(tasks) def event_handler(self, event_name: str): + """Decorator for registering event handlers. + + Args: + event_name: The name of the event to handle. + + Returns: + The decorator function that registers the handler. + """ + def decorator(handler): self.add_event_handler(event_name, handler) return handler @@ -49,18 +94,37 @@ class BaseObject(ABC): return decorator def add_event_handler(self, event_name: str, handler): + """Add an event handler for the specified event. + + Args: + event_name: The name of the event to handle. + handler: The function to call when the event occurs. + Can be sync or async. + """ if event_name in self._event_handlers: self._event_handlers[event_name].append(handler) else: logger.warning(f"Event handler {event_name} not registered") def _register_event_handler(self, event_name: str): + """Register an event handler type. + + Args: + event_name: The name of the event type to register. 
+ """ if event_name not in self._event_handlers: self._event_handlers[event_name] = [] else: logger.warning(f"Event handler {event_name} not registered") async def _call_event_handler(self, event_name: str, *args, **kwargs): + """Call all registered handlers for the specified event. + + Args: + event_name: The name of the event to trigger. + *args: Positional arguments to pass to event handlers. + **kwargs: Keyword arguments to pass to event handlers. + """ # If we haven't registered an event handler, we don't need to do # anything. if not self._event_handlers.get(event_name): @@ -76,6 +140,13 @@ class BaseObject(ABC): task.add_done_callback(self._event_task_finished) async def _run_task(self, event_name: str, *args, **kwargs): + """Execute all handlers for an event. + + Args: + event_name: The name of the event being handled. + *args: Positional arguments to pass to handlers. + **kwargs: Keyword arguments to pass to handlers. + """ try: for handler in self._event_handlers[event_name]: if inspect.iscoroutinefunction(handler): @@ -86,9 +157,19 @@ class BaseObject(ABC): logger.exception(f"Exception in event handler {event_name}: {e}") def _event_task_finished(self, task: asyncio.Task): + """Clean up completed event handler tasks. + + Args: + task: The completed asyncio Task to remove from tracking. + """ tuple_to_remove = next((t for t in self._event_tasks if t[1] == task), None) if tuple_to_remove: self._event_tasks.discard(tuple_to_remove) def __str__(self): + """Return the string representation of this object. + + Returns: + The object's name as its string representation. 
+ """ return self.name diff --git a/src/pipecat/utils/network.py b/src/pipecat/utils/network.py index 27bec990c..3aef43988 100644 --- a/src/pipecat/utils/network.py +++ b/src/pipecat/utils/network.py @@ -4,6 +4,8 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Network utilities providing exponential backoff time calculation.""" + def exponential_backoff_time( attempt: int, min_wait: float = 4, max_wait: float = 10, multiplier: float = 1 diff --git a/src/pipecat/utils/string.py b/src/pipecat/utils/string.py index 69036a665..21449a3ab 100644 --- a/src/pipecat/utils/string.py +++ b/src/pipecat/utils/string.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Text processing utilities for sentence boundary detection and tag parsing. + +This module provides utilities for natural language text processing including +sentence boundary detection, email and number pattern handling, and XML-style +tag parsing for structured text content. +""" + import re from typing import Optional, Sequence, Tuple @@ -30,18 +37,16 @@ StartEndTags = Tuple[str, str] def replace_match(text: str, match: re.Match, old: str, new: str) -> str: - """Replace occurrences of a substring within a matched section of a given - text. + """Replace occurrences of a substring within a matched section of text. Args: - text (str): The input text in which replacements will be made. - match (re.Match): A regex match object representing the section of text to modify. - old (str): The substring to be replaced. - new (str): The substring to replace `old` with. + text: The input text in which replacements will be made. + match: A regex match object representing the section of text to modify. + old: The substring to be replaced. + new: The substring to replace `old` with. Returns: - str: The modified text with the specified replacements made within the matched section. - + The modified text with the specified replacements made within the matched section.
""" start = match.start() end = match.end() @@ -51,7 +56,7 @@ def replace_match(text: str, match: re.Match, old: str, new: str) -> str: def match_endofsentence(text: str) -> int: - """Finds the position of the end of a sentence in the provided text string. + """Find the position of the end of a sentence in the provided text. This function processes the input text by replacing periods in email addresses and numbers with ampersands to prevent them from being @@ -59,11 +64,10 @@ def match_endofsentence(text: str) -> int: sentence using a specified regex pattern. Args: - text (str): The input text in which to find the end of the sentence. + text: The input text in which to find the end of the sentence. Returns: - int: The position of the end of the sentence if found, otherwise 0. - + The position of the end of the sentence if found, otherwise 0. """ text = text.rstrip() @@ -90,24 +94,22 @@ def parse_start_end_tags( current_tag: Optional[StartEndTags], current_tag_index: int, ) -> Tuple[Optional[StartEndTags], int]: - """Parses the given text to identify a pair of start/end tags. + """Parse text to identify start and end tag pairs. - If a start tag was previously found (i.e. current_tags is valid), wait for - the corresponding end tag. Otherwise, wait for a start tag. + If a start tag was previously found (i.e., current_tag is valid), wait for + the corresponding end tag. Otherwise, wait for a start tag. - This function will return the index in the text that we should start parsing + This function returns the index in the text where parsing should continue in the next call and the current or new tags. - Parameters: - - text (str): The text to be parsed. - - tags (Sequence[StartEndTags]): List of tuples containing start and end tags. - - current_tags (Optional[StartEndTags]): The currently active tags, if any. - - current_tags_index (int): The current index in the text. + Args: + text: The text to be parsed. + tags: List of tuples containing start and end tags. 
+ current_tag: The currently active tags, if any. + current_tag_index: The current index in the text. Returns: - Tuple[Optional[StartEndTags], int]: A tuple containing None or the current - tag and the index of the text. - + A tuple containing None or the current tag and the index of the text. """ # If we are already inside a tag, check if the end tag is in the text. if current_tag: diff --git a/src/pipecat/utils/text/base_text_aggregator.py b/src/pipecat/utils/text/base_text_aggregator.py index 01fd2ba9e..27e50fff5 100644 --- a/src/pipecat/utils/text/base_text_aggregator.py +++ b/src/pipecat/utils/text/base_text_aggregator.py @@ -4,54 +4,85 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Base text aggregator interface for Pipecat text processing. + +This module defines the abstract base class for text aggregators that accumulate +and process text tokens, typically used by TTS services to determine when +aggregated text should be sent for speech synthesis. +""" + from abc import ABC, abstractmethod from typing import Optional class BaseTextAggregator(ABC): - """This is the base class for text aggregators. Text aggregators are usually - used by the TTS service to aggregate LLM tokens and decide when the - aggregated text should be pushed to the TTS service. + """Base class for text aggregators in the Pipecat framework. + + Text aggregators are usually used by the TTS service to aggregate LLM tokens + and decide when the aggregated text should be pushed to the TTS service. Text aggregators can also be used to manipulate text while it's being aggregated (e.g. reasoning blocks can be removed). + Subclasses must implement all abstract methods to define specific aggregation + logic, text manipulation behavior, and state management for interruptions. """ @property @abstractmethod def text(self) -> str: - """Returns the currently aggregated text.""" + """Get the currently aggregated text. 
+ + Subclasses must implement this property to return the text that has + been accumulated so far in their internal buffer or storage. + + Returns: + The text that has been accumulated so far. + """ pass @abstractmethod async def aggregate(self, text: str) -> Optional[str]: - """Aggregates the specified text with the currently accumulated text. + """Aggregate the specified text with the currently accumulated text. This method should be implemented to define how the new text contributes to the aggregation process. It returns the updated aggregated text if it's ready to be processed, or None otherwise. + Subclasses should implement their specific logic for: + + - How to combine new text with existing accumulated text + - When to consider the aggregated text ready for processing + - What criteria determine text completion (e.g., sentence boundaries) + Args: - text (str): The text to be aggregated. + text: The text to be aggregated. Returns: - Optional[str]: The updated aggregated text or None if aggregated - text is not ready. - + The updated aggregated text if ready for processing, or None if more + text is needed before the aggregated content is ready. """ pass @abstractmethod async def handle_interruption(self): - """Handles interruptions. When an interruption occurs it is possible - that we might want to discard the aggregated text or do some internal - modifications to the aggregated text. + """Handle interruptions in the text aggregation process. + When an interruption occurs it is possible that we might want to discard + the aggregated text or do some internal modifications to the aggregated text. + + Subclasses should implement this method to define how they respond to + interruptions, such as clearing buffers, resetting state, or preserving + partial content. """ pass @abstractmethod async def reset(self): - """Clears the internally aggregated text.""" + """Clear the internally aggregated text and reset to initial state. 
+ + Subclasses should implement this method to return the aggregator to its + initial state, discarding any previously accumulated text content and + resetting any internal tracking variables. + """ pass diff --git a/src/pipecat/utils/text/base_text_filter.py b/src/pipecat/utils/text/base_text_filter.py index 787a1a9da..1a18a38a6 100644 --- a/src/pipecat/utils/text/base_text_filter.py +++ b/src/pipecat/utils/text/base_text_filter.py @@ -4,23 +4,69 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Base text filter interface for Pipecat text processing. + +This module defines the abstract base class for text filters that can modify +text content in the processing pipeline, including support for settings updates +and interruption handling. +""" + from abc import ABC, abstractmethod from typing import Any, Mapping class BaseTextFilter(ABC): + """Abstract base class for text filters in the Pipecat framework. + + Text filters are responsible for modifying text content as it flows through + the processing pipeline. They support dynamic settings updates and can handle + interruptions to reset their internal state. + + Subclasses must implement all abstract methods to define specific filtering + behavior, settings management, and interruption handling logic. + """ + @abstractmethod async def update_settings(self, settings: Mapping[str, Any]): + """Update the filter's configuration settings. + + Subclasses should implement this method to handle dynamic configuration + updates during runtime, updating internal state as needed. + + Args: + settings: Dictionary of setting names to values for configuration. + """ pass @abstractmethod async def filter(self, text: str) -> str: + """Apply filtering transformations to the input text. + + Subclasses must implement this method to define the specific text + transformations that should be applied to the input. + + Args: + text: The input text to be filtered. + + Returns: + The filtered text after applying transformations. 
+ """ pass @abstractmethod async def handle_interruption(self): + """Handle interruption events in the processing pipeline. + + Subclasses should implement this method to reset internal state, + clear buffers, or perform other cleanup when an interruption occurs. + """ pass @abstractmethod async def reset_interruption(self): + """Reset the filter state after an interruption has been handled. + + Subclasses should implement this method to restore the filter to normal + operation after an interruption has been processed and resolved. + """ pass diff --git a/src/pipecat/utils/text/markdown_text_filter.py b/src/pipecat/utils/text/markdown_text_filter.py index 5ec960ad2..49b56b47a 100644 --- a/src/pipecat/utils/text/markdown_text_filter.py +++ b/src/pipecat/utils/text/markdown_text_filter.py @@ -4,6 +4,12 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Markdown text filter for removing Markdown formatting from text. + +This module provides a text filter that converts Markdown content to plain text +while preserving structure and handling special cases like code blocks and tables. +""" + import re from typing import Any, Mapping, Optional @@ -14,19 +20,34 @@ from pipecat.utils.text.base_text_filter import BaseTextFilter class MarkdownTextFilter(BaseTextFilter): - """Removes Markdown formatting from text in TextFrames. + """Text filter that removes Markdown formatting from text content. Converts Markdown to plain text while preserving the overall structure, including leading and trailing spaces. Handles special cases like - asterisks and table formatting. + asterisks and table formatting. Supports selective filtering of code + blocks and tables based on configuration. """ class InputParams(BaseModel): + """Configuration parameters for Markdown text filtering. + + Parameters: + enable_text_filter: Whether to apply Markdown filtering. Defaults to True. + filter_code: Whether to remove code blocks from the text. Defaults to False. 
+ filter_tables: Whether to remove table content from the text. Defaults to False. + """ + enable_text_filter: Optional[bool] = True filter_code: Optional[bool] = False filter_tables: Optional[bool] = False def __init__(self, params: Optional[InputParams] = None, **kwargs): + """Initialize the Markdown text filter. + + Args: + params: Configuration parameters for filtering behavior. + **kwargs: Additional keyword arguments passed to parent class. + """ super().__init__(**kwargs) self._settings = params or MarkdownTextFilter.InputParams() self._in_code_block = False @@ -34,11 +55,24 @@ class MarkdownTextFilter(BaseTextFilter): self._interrupted = False async def update_settings(self, settings: Mapping[str, Any]): + """Update the filter's configuration settings. + + Args: + settings: Dictionary of setting names to values for configuration. + """ for key, value in settings.items(): if hasattr(self._settings, key): setattr(self._settings, key, value) async def filter(self, text: str) -> str: + """Apply Markdown filtering transformations to the input text. + + Args: + text: The input text containing Markdown formatting to be filtered. + + Returns: + The filtered text with Markdown formatting removed or converted. + """ if self._settings.enable_text_filter: # Remove newlines and replace with a space only when there's no text before or after filtered_text = re.sub(r"^\s*\n", " ", text, flags=re.MULTILINE) @@ -108,11 +142,20 @@ class MarkdownTextFilter(BaseTextFilter): return text async def handle_interruption(self): + """Handle interruption events in the processing pipeline. + + Resets the filter state and clears any tracking variables for + code blocks and tables. + """ self._interrupted = True self._in_code_block = False self._in_table = False async def reset_interruption(self): + """Reset the filter state after an interruption has been handled. + + Clears the interrupted flag to restore normal operation. 
+ """ self._interrupted = False # @@ -120,8 +163,10 @@ class MarkdownTextFilter(BaseTextFilter): # def _remove_code_blocks(self, text: str) -> str: - """Main method to remove code blocks from the input text. - Handles interruptions and delegates to specific methods based on the current state. + """Remove code blocks from the input text. + + Handles interruptions and delegates to specific methods based on the + current state. """ if self._interrupted: self._in_code_block = False @@ -137,8 +182,10 @@ class MarkdownTextFilter(BaseTextFilter): return self._handle_not_in_code_block(match, text, code_block_pattern) def _handle_in_code_block(self, match, text): - """Handle text when we're currently inside a code block. - If we find the end of the block, return text after it. Otherwise, skip the content. + """Handle text when currently inside a code block. + + If we find the end of the block, return text after it. Otherwise, skip + the content. """ if match: self._in_code_block = False @@ -147,9 +194,7 @@ class MarkdownTextFilter(BaseTextFilter): return "" # Skip content inside code block def _handle_not_in_code_block(self, match, text, code_block_pattern): - """Handle text when we're not currently inside a code block. - Delegate to specific methods based on whether we find a code block delimiter. - """ + """Handle text when not currently inside a code block.""" if not match: return text # No code block found, return original text @@ -159,14 +204,17 @@ class MarkdownTextFilter(BaseTextFilter): return self._handle_code_block_within_text(text, code_block_pattern) def _handle_start_of_code_block(self, text, start_index): - """Handle the case where we find the start of a code block. - Return any text before the code block and set the state to inside a code block. + """Handle the case where a code block starts. + + Return any text before the code block and set the state to inside a + code block.
""" self._in_code_block = True return text[:start_index].strip() def _handle_code_block_within_text(self, text, code_block_pattern): - """Handle the case where we find a code block within the text. + """Handle code blocks found within text content. + If it's a complete code block, remove it and return surrounding text. If it's the start of a code block, return text before it and set state. """ @@ -180,8 +228,16 @@ class MarkdownTextFilter(BaseTextFilter): # Filter tables # def remove_tables(self, text: str) -> str: - """Remove tables from the input text, handling cases where - both start and end tags are in the same input. + """Remove HTML tables from the input text. + + Handles cases where both start and end tags are in the same input, + as well as tables that span multiple text chunks. + + Args: + text: The text containing HTML tables to remove. + + Returns: + The text with tables removed. """ if self._interrupted: self._in_table = False diff --git a/src/pipecat/utils/text/pattern_pair_aggregator.py b/src/pipecat/utils/text/pattern_pair_aggregator.py index dbc985774..ac074f2de 100644 --- a/src/pipecat/utils/text/pattern_pair_aggregator.py +++ b/src/pipecat/utils/text/pattern_pair_aggregator.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Pattern pair aggregator for processing structured content in streaming text. + +This module provides an aggregator that identifies and processes content between +pattern pairs (like XML tags or custom delimiters) in streaming text, with +support for custom handlers and configurable pattern removal. +""" + import re from typing import Awaitable, Callable, Optional, Tuple @@ -20,20 +27,15 @@ class PatternMatch: in the text. It contains information about which pattern was matched, the full matched text (including start and end patterns), and the content between the patterns. - - Attributes: - pattern_id: The identifier of the matched pattern pair. 
- full_match: The complete text including start and end patterns. - content: The text content between the start and end patterns. """ def __init__(self, pattern_id: str, full_match: str, content: str): """Initialize a pattern match. Args: - pattern_id: ID of the pattern pair. - full_match: Complete matched text including start and end patterns. - content: Content between the start and end patterns. + pattern_id: The identifier of the matched pattern pair. + full_match: The complete text including start and end patterns. + content: The text content between the start and end patterns. """ self.pattern_id = pattern_id self.full_match = full_match @@ -43,7 +45,7 @@ class PatternMatch: """Return a string representation of the pattern match. Returns: - A string describing the pattern match. + A descriptive string showing the pattern ID and content. """ return f"PatternMatch(id={self.pattern_id}, content={self.content})" @@ -66,6 +68,7 @@ class PatternPairAggregator(BaseTextAggregator): """Initialize the pattern pair aggregator. Creates an empty aggregator with no patterns or handlers registered. + Text buffering and pattern detection will begin when text is aggregated. """ self._text = "" self._patterns = {} @@ -76,7 +79,7 @@ class PatternPairAggregator(BaseTextAggregator): """Get the currently buffered text. Returns: - The current text buffer content. + The current text buffer content that hasn't been processed yet. """ return self._text @@ -115,7 +118,7 @@ class PatternPairAggregator(BaseTextAggregator): Args: pattern_id: ID of the pattern pair to match. - handler: Function to call when pattern is matched. + handler: Async function to call when pattern is matched. The function should accept a PatternMatch object. Returns: @@ -131,10 +134,11 @@ class PatternPairAggregator(BaseTextAggregator): appropriate handlers, and optionally removes the matches. Args: - text: The text to process. + text: The text to process for pattern matches. 
 Returns: Tuple of (processed_text, was_modified) where: + - processed_text is the text after processing patterns - was_modified indicates whether any changes were made """ @@ -185,7 +189,7 @@ class PatternPairAggregator(BaseTextAggregator): matching end patterns, which would indicate incomplete content. Args: - text: The text to check. + text: The text to check for incomplete patterns. Returns: True if there are incomplete patterns, False otherwise. @@ -257,6 +261,6 @@ class PatternPairAggregator(BaseTextAggregator): """Clear the internally aggregated text. Resets the aggregator to its initial state, discarding any - buffered text. + buffered text; registered patterns and handlers are kept. """ self._text = "" diff --git a/src/pipecat/utils/text/simple_text_aggregator.py b/src/pipecat/utils/text/simple_text_aggregator.py index 791844f73..f9eb7d83a 100644 --- a/src/pipecat/utils/text/simple_text_aggregator.py +++ b/src/pipecat/utils/text/simple_text_aggregator.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Simple text aggregator for basic sentence-boundary text processing. + +This module provides a straightforward text aggregator that accumulates text +until it finds an end-of-sentence marker, making it suitable for basic TTS +text processing scenarios. +""" + from typing import Optional from pipecat.utils.string import match_endofsentence @@ -11,19 +18,43 @@ from pipecat.utils.text.base_text_aggregator import BaseTextAggregator class SimpleTextAggregator(BaseTextAggregator): - """This is a simple text aggregator. It aggregates text until an end of - sentence is found. + """Simple text aggregator that accumulates text until sentence boundaries. + This aggregator provides basic functionality for accumulating text tokens + and releasing them when an end-of-sentence marker is detected. It's the + most straightforward implementation of text aggregation for TTS processing. """ def __init__(self): + """Initialize the simple text aggregator.
+ + Creates an empty text buffer ready to begin accumulating text tokens. + """ self._text = "" @property def text(self) -> str: + """Get the currently aggregated text. + + Returns: + The text that has been accumulated in the buffer. + """ return self._text async def aggregate(self, text: str) -> Optional[str]: + """Aggregate text and return completed sentences. + + Adds the new text to the buffer and checks for end-of-sentence markers. + When a sentence boundary is found, returns the completed sentence and + removes it from the buffer. + + Args: + text: New text to add to the aggregation buffer. + + Returns: + A complete sentence if an end-of-sentence marker is found, + or None if more text is needed to complete a sentence. + """ result: Optional[str] = None self._text += text @@ -36,7 +67,17 @@ class SimpleTextAggregator(BaseTextAggregator): return result async def handle_interruption(self): + """Handle interruptions by clearing the text buffer. + + Called when an interruption occurs in the processing pipeline, + discarding any partially accumulated text. + """ self._text = "" async def reset(self): + """Clear the internally aggregated text. + + Resets the aggregator to its initial empty state, discarding + any accumulated text content. + """ self._text = "" diff --git a/src/pipecat/utils/text/skip_tags_aggregator.py b/src/pipecat/utils/text/skip_tags_aggregator.py index 81bbb9a96..6f6f8455c 100644 --- a/src/pipecat/utils/text/skip_tags_aggregator.py +++ b/src/pipecat/utils/text/skip_tags_aggregator.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Skip tags aggregator for preventing sentence boundaries within tagged content. + +This module provides a text aggregator that prevents end-of-sentence matching +between specified start/end tag pairs, ensuring that tagged content is processed +as a unit regardless of internal punctuation. 
+""" + from typing import Optional, Sequence from pipecat.utils.string import StartEndTags, match_endofsentence, parse_start_end_tags @@ -17,17 +24,18 @@ class SkipTagsAggregator(BaseTextAggregator): tag. If a start tag is found the aggregator will keep aggregating text unconditionally until the corresponding end tag is found. It's particularly useful for processing content with custom delimiters that should prevent - text from being considered for end of sentence matching.. + text from being considered for end of sentence matching. The aggregator ensures that tags spanning multiple text chunks are correctly - identified. - + identified and that content within tags is never split at sentence boundaries. """ def __init__(self, tags: Sequence[StartEndTags]): - """Initialize the pattern pair aggregator. + """Initialize the skip tags aggregator. - Creates an empty aggregator with no patterns or handlers registered. + Args: + tags: Sequence of StartEndTags objects defining the tag pairs + that should prevent sentence boundary detection. """ self._text = "" self._tags = tags @@ -39,24 +47,24 @@ class SkipTagsAggregator(BaseTextAggregator): """Get the currently buffered text. Returns: - The current text buffer content. + The current text buffer content that hasn't been processed yet. """ return self._text async def aggregate(self, text: str) -> Optional[str]: - """Aggregate text and process pattern pairs. + """Aggregate text while respecting tag boundaries. - This method adds the new text to the buffer, processes any complete pattern - pairs, and returns processed text up to sentence boundaries if possible. - If there are incomplete patterns (start without matching end), it will - continue buffering text. + This method adds the new text to the buffer, processes any complete + pattern pairs, and returns processed text up to sentence boundaries if + possible. If there are incomplete patterns (start without matching + end), it will continue buffering text. 
Args: text: New text to add to the buffer. Returns: - Processed text up to a sentence boundary, or None if more - text is needed to form a complete sentence or pattern. + Processed text up to a sentence boundary (when not within tags), + or None if more text is needed to complete a sentence or close tags. """ # Add new text to buffer self._text += text diff --git a/src/pipecat/utils/time.py b/src/pipecat/utils/time.py index b1e95f895..36156b1ca 100644 --- a/src/pipecat/utils/time.py +++ b/src/pipecat/utils/time.py @@ -4,22 +4,58 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Time utilities for the Pipecat framework. + +This module provides utility functions for time handling including +ISO8601 formatting, nanosecond conversions, and human-readable +time string formatting. +""" + import datetime def time_now_iso8601() -> str: + """Get the current UTC time as an ISO8601 formatted string. + + Returns: + The current UTC time in ISO8601 format with millisecond precision. + """ return datetime.datetime.now(datetime.timezone.utc).isoformat(timespec="milliseconds") def seconds_to_nanoseconds(seconds: float) -> int: + """Convert seconds to nanoseconds. + + Args: + seconds: The number of seconds to convert. + + Returns: + The equivalent number of nanoseconds as an integer. + """ return int(seconds * 1_000_000_000) def nanoseconds_to_seconds(nanoseconds: int) -> float: + """Convert nanoseconds to seconds. + + Args: + nanoseconds: The number of nanoseconds to convert. + + Returns: + The equivalent number of seconds as a float. + """ return nanoseconds / 1_000_000_000 def nanoseconds_to_str(nanoseconds: int) -> str: + """Convert nanoseconds to a human-readable time string. + + Args: + nanoseconds: The number of nanoseconds to convert. + + Returns: + A formatted time string in "H:MM:SS.microseconds" format. 
+ """ total_seconds = nanoseconds_to_seconds(nanoseconds) hours = int(total_seconds // 3600) minutes = int((total_seconds % 3600) // 60) diff --git a/src/pipecat/utils/tracing/class_decorators.py b/src/pipecat/utils/tracing/class_decorators.py index 98da7211e..ba997b275 100644 --- a/src/pipecat/utils/tracing/class_decorators.py +++ b/src/pipecat/utils/tracing/class_decorators.py @@ -33,7 +33,7 @@ C = TypeVar("C", bound=type) class AttachmentStrategy(enum.Enum): """Controls how spans are attached to the trace hierarchy. - Attributes: + Parameters: CHILD: Attached to class span if no parent, otherwise to parent. LINK: Attached to class span with link to parent. NONE: Always attached to class span regardless of context. @@ -71,10 +71,10 @@ class Traceable: @property def meter(self): - """Returns the OpenTelemetry meter instance. + """Get the OpenTelemetry meter instance. Returns: - Meter: The OpenTelemetry meter instance for this object. + The OpenTelemetry meter instance for this object. """ return self._meter @@ -83,7 +83,17 @@ class Traceable: def __traced_context_manager( self: Traceable, func: Callable, name: str | None, attachment_strategy: AttachmentStrategy ): - """Internal context manager for the traced decorator.""" + """Internal context manager for the traced decorator. + + Args: + self: The Traceable instance. + func: The function being traced. + name: Custom span name or None to use function name. + attachment_strategy: How to attach this span to the trace hierarchy. + + Raises: + RuntimeError: If used in a class not inheriting from Traceable. + """ if not isinstance(self, Traceable): raise RuntimeError( "@traced annotation can only be used in classes inheriting from Traceable" @@ -124,7 +134,16 @@ def __traced_context_manager( def __traced_decorator(func, name, attachment_strategy: AttachmentStrategy): - """Implementation of the traced decorator.""" + """Implementation of the traced decorator. + + Args: + func: The function to trace. 
+ name: Custom span name. + attachment_strategy: How to attach this span. + + Returns: + The wrapped function with tracing capabilities. + """ @functools.wraps(func) async def coroutine_wrapper(self: Traceable, *args, **kwargs): @@ -163,7 +182,7 @@ def traced( name: Optional[str] = None, attachment_strategy: AttachmentStrategy = AttachmentStrategy.CHILD, ) -> Callable: - """Adds tracing to an async function in a Traceable class. + """Add tracing to an async function in a Traceable class. Args: func: The async function to trace. @@ -193,7 +212,7 @@ def traced( def traceable(cls: C) -> C: - """Makes a class traceable for OpenTelemetry. + """Make a class traceable for OpenTelemetry. Creates a new class that inherits from both the original class and Traceable, enabling tracing for class methods. @@ -210,6 +229,12 @@ def traceable(cls: C) -> C: @functools.wraps(cls, updated=()) class TracedClass(cls, Traceable): def __init__(self, *args, **kwargs): + """Initialize the traced class instance. + + Args: + *args: Positional arguments passed to parent classes. + **kwargs: Keyword arguments passed to parent classes. + """ cls.__init__(self, *args, **kwargs) if hasattr(self, "name"): Traceable.__init__(self, self.name) diff --git a/src/pipecat/utils/tracing/conversation_context_provider.py b/src/pipecat/utils/tracing/conversation_context_provider.py index 611e59650..995776ff5 100644 --- a/src/pipecat/utils/tracing/conversation_context_provider.py +++ b/src/pipecat/utils/tracing/conversation_context_provider.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Conversation context provider for OpenTelemetry tracing in Pipecat. + +This module provides a singleton context provider that manages the current +conversation's tracing context, allowing services to create child spans +that are properly associated with the conversation. 
+""" + import uuid from typing import TYPE_CHECKING, Optional @@ -32,7 +39,11 @@ class ConversationContextProvider: @classmethod def get_instance(cls): - """Get the singleton instance.""" + """Get the singleton instance. + + Returns: + The singleton ConversationContextProvider instance. + """ if cls._instance is None: cls._instance = ConversationContextProvider() return cls._instance @@ -83,7 +94,6 @@ class ConversationContextProvider: return str(uuid.uuid4()) -# Create a simple helper function to get the current conversation context def get_current_conversation_context() -> Optional["Context"]: """Get the OpenTelemetry context for the current conversation. diff --git a/src/pipecat/utils/tracing/service_attributes.py b/src/pipecat/utils/tracing/service_attributes.py index 8f90ad3be..3896bd028 100644 --- a/src/pipecat/utils/tracing/service_attributes.py +++ b/src/pipecat/utils/tracing/service_attributes.py @@ -4,7 +4,12 @@ # SPDX-License-Identifier: BSD 2-Clause License # -"""Functions for adding attributes to OpenTelemetry spans.""" +"""Functions for adding attributes to OpenTelemetry spans. + +This module provides specialized functions for adding service-specific +attributes to OpenTelemetry spans, following standard semantic conventions +where applicable and Pipecat-specific conventions for additional context. +""" from typing import TYPE_CHECKING, Any, Dict, List, Optional @@ -26,6 +31,12 @@ def _get_gen_ai_system_from_service_name(service_name: str) -> str: Uses standard OTel names where possible, with special case mappings for service names that don't follow the pattern. + + Args: + service_name: The service class name to extract system name from. + + Returns: + The standardized gen_ai.system value. """ SPECIAL_CASE_MAPPINGS = { # AWS @@ -66,16 +77,16 @@ def add_tts_span_attributes( """Add TTS-specific attributes to a span. 
Args: - span: The span to add attributes to - service_name: Name of the TTS service (e.g., "cartesia") - model: Model name/identifier - voice_id: Voice identifier - text: The text being synthesized - settings: Service configuration settings - character_count: Number of characters in the text - operation_name: Name of the operation (default: "tts") - ttfb: Time to first byte in seconds - **kwargs: Additional attributes to add + span: The span to add attributes to. + service_name: Name of the TTS service (e.g., "cartesia"). + model: Model name/identifier. + voice_id: Voice identifier. + text: The text being synthesized. + settings: Service configuration settings. + character_count: Number of characters in the text. + operation_name: Name of the operation (default: "tts"). + ttfb: Time to first byte in seconds. + **kwargs: Additional attributes to add. """ # Add standard attributes span.set_attribute("gen_ai.system", service_name.replace("TTSService", "").lower()) @@ -122,17 +133,17 @@ def add_stt_span_attributes( """Add STT-specific attributes to a span. Args: - span: The span to add attributes to - service_name: Name of the STT service (e.g., "deepgram") - model: Model name/identifier - operation_name: Name of the operation (default: "stt") - transcript: The transcribed text - is_final: Whether this is a final transcript - language: Detected or configured language - settings: Service configuration settings - vad_enabled: Whether voice activity detection is enabled - ttfb: Time to first byte in seconds - **kwargs: Additional attributes to add + span: The span to add attributes to. + service_name: Name of the STT service (e.g., "deepgram"). + model: Model name/identifier. + operation_name: Name of the operation (default: "stt"). + transcript: The transcribed text. + is_final: Whether this is a final transcript. + language: Detected or configured language. + settings: Service configuration settings. + vad_enabled: Whether voice activity detection is enabled. 
+ ttfb: Time to first byte in seconds. + **kwargs: Additional attributes to add. """ # Add standard attributes span.set_attribute("gen_ai.system", service_name.replace("STTService", "").lower()) @@ -184,20 +195,20 @@ def add_llm_span_attributes( """Add LLM-specific attributes to a span. Args: - span: The span to add attributes to - service_name: Name of the LLM service (e.g., "openai") - model: Model name/identifier - stream: Whether streaming is enabled - messages: JSON-serialized messages - output: Aggregated output text from the LLM - tools: JSON-serialized tools configuration - tool_count: Number of tools available - tool_choice: Tool selection configuration - system: System message - parameters: Service parameters - extra_parameters: Additional parameters - ttfb: Time to first byte in seconds - **kwargs: Additional attributes to add + span: The span to add attributes to. + service_name: Name of the LLM service (e.g., "openai"). + model: Model name/identifier. + stream: Whether streaming is enabled. + messages: JSON-serialized messages. + output: Aggregated output text from the LLM. + tools: JSON-serialized tools configuration. + tool_count: Number of tools available. + tool_choice: Tool selection configuration. + system: System message. + parameters: Service parameters. + extra_parameters: Additional parameters. + ttfb: Time to first byte in seconds. + **kwargs: Additional attributes to add. """ # Add standard attributes span.set_attribute("gen_ai.system", _get_gen_ai_system_from_service_name(service_name)) @@ -278,21 +289,21 @@ def add_gemini_live_span_attributes( """Add Gemini Live specific attributes to a span. Args: - span: The span to add attributes to - service_name: Name of the service - model: Model name/identifier - operation_name: Name of the operation (setup, model_turn, tool_call, etc.) 
- voice_id: Voice identifier used for output - language: Language code for the session - modalities: Supported modalities (e.g., "AUDIO", "TEXT") - settings: Service configuration settings - tools: Available tools/functions list - tools_serialized: JSON-serialized tools for detailed inspection - transcript: Transcription text - is_input: Whether transcript is input (True) or output (False) - text_output: Text output from model - audio_data_size: Size of audio data in bytes - **kwargs: Additional attributes to add + span: The span to add attributes to. + service_name: Name of the service. + model: Model name/identifier. + operation_name: Name of the operation (setup, model_turn, tool_call, etc.). + voice_id: Voice identifier used for output. + language: Language code for the session. + modalities: Supported modalities (e.g., "AUDIO", "TEXT"). + settings: Service configuration settings. + tools: Available tools/functions list. + tools_serialized: JSON-serialized tools for detailed inspection. + transcript: Transcription text. + is_input: Whether transcript is input (True) or output (False). + text_output: Text output from model. + audio_data_size: Size of audio data in bytes. + **kwargs: Additional attributes to add. """ # Add standard attributes span.set_attribute("gen_ai.system", "gcp.gemini") @@ -381,19 +392,19 @@ def add_openai_realtime_span_attributes( """Add OpenAI Realtime specific attributes to a span. Args: - span: The span to add attributes to - service_name: Name of the service - model: Model name/identifier - operation_name: Name of the operation (setup, transcription, response, etc.) 
- session_properties: Session configuration properties - transcript: Transcription text - is_input: Whether transcript is input (True) or output (False) - context_messages: JSON-serialized context messages - function_calls: Function calls being made - tools: Available tools/functions list - tools_serialized: JSON-serialized tools for detailed inspection - audio_data_size: Size of audio data in bytes - **kwargs: Additional attributes to add + span: The span to add attributes to. + service_name: Name of the service. + model: Model name/identifier. + operation_name: Name of the operation (setup, transcription, response, etc.). + session_properties: Session configuration properties. + transcript: Transcription text. + is_input: Whether transcript is input (True) or output (False). + context_messages: JSON-serialized context messages. + function_calls: Function calls being made. + tools: Available tools/functions list. + tools_serialized: JSON-serialized tools for detailed inspection. + audio_data_size: Size of audio data in bytes. + **kwargs: Additional attributes to add. """ # Add standard attributes span.set_attribute("gen_ai.system", "openai") diff --git a/src/pipecat/utils/tracing/service_decorators.py b/src/pipecat/utils/tracing/service_decorators.py index c016827d6..9078bba15 100644 --- a/src/pipecat/utils/tracing/service_decorators.py +++ b/src/pipecat/utils/tracing/service_decorators.py @@ -41,9 +41,15 @@ T = TypeVar("T") R = TypeVar("R") -# Internal helper functions def _noop_decorator(func): - """No-op fallback decorator when tracing is unavailable.""" + """No-op fallback decorator when tracing is unavailable. + + Args: + func: The function to pass through unchanged. + + Returns: + The original function unchanged. + """ return func @@ -53,10 +59,10 @@ def _get_parent_service_context(self): This looks for the service span that was created when the service was initialized. Args: - self: The service instance + self: The service instance. 
Returns: - Context or None: The parent service context, or None if unavailable + The parent service context, or None if unavailable. """ if not is_tracing_available(): return None @@ -73,8 +79,8 @@ def _add_token_usage_to_span(span, token_usage): """Add token usage metrics to a span (internal use only). Args: - span: The span to add token metrics to - token_usage: Dictionary or object containing token usage information + span: The span to add token metrics to. + token_usage: Dictionary or object containing token usage information. """ if not is_tracing_available() or not token_usage: return @@ -93,9 +99,10 @@ def _add_token_usage_to_span(span, token_usage): def traced_tts(func: Optional[Callable] = None, *, name: Optional[str] = None) -> Callable: - """Traces TTS service methods with TTS-specific attributes. + """Trace TTS service methods with TTS-specific attributes. Automatically captures and records: + - Service name and model information - Voice ID and settings - Character count and text content @@ -118,7 +125,15 @@ def traced_tts(func: Optional[Callable] = None, *, name: Optional[str] = None) - @contextlib.asynccontextmanager async def tracing_context(self, text): - """Async context manager for TTS tracing.""" + """Async context manager for TTS tracing. + + Args: + self: The TTS service instance. + text: The text being synthesized. + + Yields: + The active span for the TTS operation. + """ if not is_tracing_available(): yield None return @@ -201,9 +216,10 @@ def traced_tts(func: Optional[Callable] = None, *, name: Optional[str] = None) - def traced_stt(func: Optional[Callable] = None, *, name: Optional[str] = None) -> Callable: - """Traces STT service methods with transcription attributes. + """Trace STT service methods with transcription attributes. 
Automatically captures and records: + - Service name and model information - Transcription text and final status - Language information @@ -278,9 +294,10 @@ def traced_stt(func: Optional[Callable] = None, *, name: Optional[str] = None) - def traced_llm(func: Optional[Callable] = None, *, name: Optional[str] = None) -> Callable: - """Traces LLM service methods with LLM-specific attributes. + """Trace LLM service methods with LLM-specific attributes. Automatically captures and records: + - Service name and model information - Context content and messages - Tool configurations @@ -482,16 +499,17 @@ def traced_llm(func: Optional[Callable] = None, *, name: Optional[str] = None) - def traced_gemini_live(operation: str) -> Callable: - """Traces Gemini Live service methods with operation-specific attributes. + """Trace Gemini Live service methods with operation-specific attributes. This decorator automatically captures relevant information based on the operation type: + - llm_setup: Configuration, tools definitions, and system instructions - llm_tool_call: Function call information - llm_tool_result: Function execution results - llm_response: Complete LLM response with usage and output Args: - operation: The operation name (matches the event type being handled) + operation: The operation name (matches the event type being handled). Returns: Wrapped method with Gemini Live specific tracing. @@ -786,15 +804,16 @@ def traced_gemini_live(operation: str) -> Callable: def traced_openai_realtime(operation: str) -> Callable: - """Traces OpenAI Realtime service methods with operation-specific attributes. + """Trace OpenAI Realtime service methods with operation-specific attributes. 
This decorator automatically captures relevant information based on the operation type: + - llm_setup: Session configuration and tools - llm_request: Context and input messages - llm_response: Usage metadata, output, and function calls Args: - operation: The operation name (matches the event type being handled) + operation: The operation name (matches the event type being handled). Returns: Wrapped method with OpenAI Realtime specific tracing. diff --git a/src/pipecat/utils/tracing/setup.py b/src/pipecat/utils/tracing/setup.py index ab74530cb..f2dfa0c88 100644 --- a/src/pipecat/utils/tracing/setup.py +++ b/src/pipecat/utils/tracing/setup.py @@ -4,7 +4,12 @@ # SPDX-License-Identifier: BSD 2-Clause License # -"""Core OpenTelemetry tracing utilities and setup for Pipecat.""" +"""Core OpenTelemetry tracing utilities and setup for Pipecat. + +This module provides functions to check availability and configure OpenTelemetry +tracing for Pipecat applications. It handles the optional nature of OpenTelemetry +dependencies and provides a safe setup process. +""" import os @@ -21,10 +26,10 @@ except ImportError: def is_tracing_available() -> bool: - """Returns True if OpenTelemetry tracing is available and configured. + """Check if OpenTelemetry tracing is available and configured. Returns: - bool: True if tracing is available, False otherwise. + True if tracing is available, False otherwise. """ return OPENTELEMETRY_AVAILABLE @@ -37,15 +42,16 @@ def setup_tracing( """Set up OpenTelemetry tracing with a user-provided exporter. Args: - service_name: The name of the service for traces + service_name: The name of the service for traces. exporter: A pre-configured OpenTelemetry span exporter instance. If None, only console export will be available if enabled. - console_export: Whether to also export traces to console (useful for debugging) + console_export: Whether to also export traces to console (useful for debugging). 
Returns: - bool: True if setup was successful, False otherwise + True if setup was successful, False otherwise. + + Example:: - Example: # With OTLP exporter from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter diff --git a/src/pipecat/utils/tracing/turn_context_provider.py b/src/pipecat/utils/tracing/turn_context_provider.py index 7af3d694a..f02d92d45 100644 --- a/src/pipecat/utils/tracing/turn_context_provider.py +++ b/src/pipecat/utils/tracing/turn_context_provider.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Turn context provider for OpenTelemetry tracing in Pipecat. + +This module provides a singleton context provider that manages the current +turn's tracing context, allowing services to create child spans that are +properly associated with the conversation turn. +""" + from typing import TYPE_CHECKING, Optional # Import types for type checking only @@ -30,7 +37,11 @@ class TurnContextProvider: @classmethod def get_instance(cls): - """Get the singleton instance.""" + """Get the singleton instance. + + Returns: + The singleton TurnContextProvider instance. + """ if cls._instance is None: cls._instance = TurnContextProvider() return cls._instance @@ -60,7 +71,6 @@ class TurnContextProvider: return self._current_turn_context -# Create a simple helper function to get the current turn context def get_current_turn_context() -> Optional["Context"]: """Get the OpenTelemetry context for the current turn. diff --git a/src/pipecat/utils/tracing/turn_trace_observer.py b/src/pipecat/utils/tracing/turn_trace_observer.py index f67ca3b28..31b4d0130 100644 --- a/src/pipecat/utils/tracing/turn_trace_observer.py +++ b/src/pipecat/utils/tracing/turn_trace_observer.py @@ -4,6 +4,13 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Turn trace observer for OpenTelemetry tracing in Pipecat. 
+ +This module provides an observer that creates trace spans for each conversation +turn, integrating with the turn tracking system to provide hierarchical tracing +of conversation flows. +""" + from typing import TYPE_CHECKING, Dict, Optional from loguru import logger @@ -41,6 +48,14 @@ class TurnTraceObserver(BaseObserver): additional_span_attributes: Optional[dict] = None, **kwargs, ): + """Initialize the turn trace observer. + + Args: + turn_tracker: The turn tracking observer to monitor. + conversation_id: Optional conversation ID for grouping turns. + additional_span_attributes: Additional attributes to add to spans. + **kwargs: Additional arguments passed to parent class. + """ super().__init__(**kwargs) self._turn_tracker = turn_tracker self._current_span: Optional["Span"] = None @@ -68,6 +83,9 @@ class TurnTraceObserver(BaseObserver): This observer doesn't need to process individual frames as it relies on turn start/end events from the turn tracker. + + Args: + data: The frame push event data. """ pass @@ -198,6 +216,9 @@ class TurnTraceObserver(BaseObserver): """Get the span context for the current turn. This can be used by services to create child spans. + + Returns: + The current turn's span context or None if not available. """ if not is_tracing_available() or not self._current_span: return None @@ -208,6 +229,12 @@ class TurnTraceObserver(BaseObserver): """Get the span context for a specific turn. This can be used by services to create child spans. + + Args: + turn_number: The turn number to get context for. + + Returns: + The specified turn's span context or None if not available. """ if not is_tracing_available(): return None diff --git a/src/pipecat/utils/utils.py b/src/pipecat/utils/utils.py index 0f4801f35..f741bc05b 100644 --- a/src/pipecat/utils/utils.py +++ b/src/pipecat/utils/utils.py @@ -4,6 +4,12 @@ # SPDX-License-Identifier: BSD 2-Clause License # +"""Utility functions for object identification and counting. 
+ +This module provides thread-safe utilities for generating unique identifiers +and maintaining per-class instance counts across the Pipecat framework. +""" + import collections import itertools import threading @@ -17,27 +23,40 @@ _ID_LOCK = threading.Lock() def obj_id() -> int: """Generate a unique id for an object. - >>> obj_id() - 0 - >>> obj_id() - 1 - >>> obj_id() - 2 + Returns: + A unique integer identifier that increments globally across all objects. + + Examples:: + + >>> obj_id() + 0 + >>> obj_id() + 1 + >>> obj_id() + 2 """ with _ID_LOCK: return next(_ID) def obj_count(obj) -> int: - """Generate a unique id for an object. + """Generate a unique count for an object based on its class. - >>> obj_count(object()) - 0 - >>> obj_count(object()) - 1 - >>> new_type = type('NewType', (object,), {}) - >>> obj_count(new_type()) - 0 + Args: + obj: The object instance to count. + + Returns: + A unique integer count that increments per class type. + + Examples:: + + >>> obj_count(object()) + 0 + >>> obj_count(object()) + 1 + >>> new_type = type('NewType', (object,), {}) + >>> obj_count(new_type()) + 0 """ with _COUNTS_LOCK: return next(_COUNTS[obj.__class__.__name__])