Compare commits

...

3 Commits

Author SHA1 Message Date
Mark Backman
5a6cc4d35c Replace assert-based type narrowing with local variables and guards
Use local variable narrowing and if-guards instead of assert statements
for type safety, since asserts are stripped with python -O.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-08 16:46:45 -05:00
Mark Backman
28be775740 Reduce type: ignore comments by fixing avoidable type mismatches
Replace ~20 type: ignore comments with proper type fixes:
- Widen set_tools() to accept List[dict] | ToolsSchema | NotGiven
- Widen create_task() to accept Coroutine | Awaitable
- Fix _turn_params to use BaseTurnParams instead of SmartTurnParams
- Make _thought_llm Optional[str] with assertion guard
- Add mixer assertion, websocket narrowing, ice_servers cast
- Use dict.get() in protobuf serializer
- Make remote_participants Optional in Daily transport

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-08 15:30:35 -05:00
Mark Backman
bc730e4069 Enable pyright basic type checking for core framework
Add pyright configuration (basic mode, Python 3.10) to pyproject.toml
and fix all 276 type errors in the core framework (everything except
services/ and adapters/). This establishes a CI-ready type checking
baseline as Pipecat approaches 1.0.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-08 15:30:35 -05:00
66 changed files with 418 additions and 205 deletions

1
changelog/3678.added.md Normal file
View File

@@ -0,0 +1 @@
- Added pyright basic type checking configuration for the core framework.

View File

@@ -212,6 +212,14 @@ ignore = [
[tool.ruff.lint.pydocstyle]
convention = "google"
[tool.pyright]
typeCheckingMode = "basic"
pythonVersion = "3.10"
exclude = [
"src/pipecat/services/",
"src/pipecat/adapters/",
]
[tool.coverage.run]
command_line = "--module pytest"
source = ["src"]

View File

@@ -31,7 +31,7 @@ def version() -> str:
import asyncio
if sys.version_info < (3, 12):
import wait_for2
import wait_for2 # type: ignore[import-untyped]
# Replace asyncio.wait_for.
asyncio.wait_for = wait_for2.wait_for

View File

@@ -55,7 +55,7 @@ async def load_dtmf_audio(button: KeypadEntry, *, sample_rate: int = 8000) -> by
dtmf_file_name = __DTMF_FILE_NAME.get(button, f"dtmf-{button.value}.wav")
dtmf_file_path = files("pipecat.audio.dtmf").joinpath(dtmf_file_name)
async with aiofiles.open(dtmf_file_path, "rb") as f:
async with aiofiles.open(str(dtmf_file_path), "rb") as f:
data = await f.read()
with io.BytesIO(data) as buffer:

View File

@@ -18,16 +18,22 @@ from pathlib import Path
from typing import List, Optional
import numpy as np
from aic_sdk import (
Model,
ParameterFixedError,
ProcessorAsync,
ProcessorConfig,
ProcessorParameter,
set_sdk_id,
)
from loguru import logger
try:
from aic_sdk import ( # type: ignore[import-not-found,import-untyped]
Model,
ParameterFixedError,
ProcessorAsync,
ProcessorConfig,
ProcessorParameter,
set_sdk_id, # pyright: ignore[reportAttributeAccessIssue]
)
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use AICFilter, you need to install aic_sdk.")
raise Exception(f"Missing module: {e}")
from pipecat.audio.filters.base_audio_filter import BaseAudioFilter
from pipecat.audio.vad.aic_vad import AICVADAnalyzer
from pipecat.frames.frames import FilterControlFrame, FilterEnableFrame
@@ -167,6 +173,7 @@ class AICFilter(BaseAudioFilter):
logger.debug(f"Loading AIC model from: {self._model_path}")
self._model = Model.from_file(str(self._model_path))
else:
assert self._model_id is not None
logger.debug(f"Downloading AIC model: {self._model_id}")
self._model_download_dir.mkdir(parents=True, exist_ok=True)
model_path = await Model.download_async(self._model_id, str(self._model_download_dir))
@@ -200,6 +207,7 @@ class AICFilter(BaseAudioFilter):
return
# Get contexts for parameter control and VAD
assert self._processor is not None
self._processor_ctx = self._processor.get_processor_context()
self._vad_ctx = self._processor.get_vad_context()
@@ -287,6 +295,9 @@ class AICFilter(BaseAudioFilter):
mv = memoryview(self._audio_buffer)
block_size = self._frames_per_block * self._bytes_per_sample
assert self._in_f32 is not None
assert self._out_i16 is not None
for i in range(num_blocks):
start = i * block_size
block_i16 = np.frombuffer(mv[start : start + block_size], dtype=self._dtype)

View File

@@ -11,6 +11,7 @@ reduction technology to suppress background noise in audio streams.
"""
import os
from typing import Optional
import numpy as np
from loguru import logger
@@ -19,7 +20,9 @@ from pipecat.audio.filters.base_audio_filter import BaseAudioFilter
from pipecat.frames.frames import FilterControlFrame, FilterEnableFrame
try:
from pipecat_ai_krisp.audio.krisp_processor import KrispAudioProcessor
from pipecat_ai_krisp.audio.krisp_processor import ( # type: ignore[import-not-found]
KrispAudioProcessor,
)
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use the Krisp filter, you need to `pip install pipecat-ai[krisp]`.")
@@ -68,7 +71,7 @@ class KrispFilter(BaseAudioFilter):
"""
def __init__(
self, sample_type: str = "PCM_16", channels: int = 1, model_path: str = None
self, sample_type: str = "PCM_16", channels: int = 1, model_path: Optional[str] = None
) -> None:
"""Initialize the Krisp noise reduction filter.
@@ -115,6 +118,7 @@ class KrispFilter(BaseAudioFilter):
sample_rate: The sample rate of the input transport in Hz.
"""
self._sample_rate = sample_rate
assert self._model_path is not None
self._krisp_processor = KrispProcessorManager.get_processor(
self._sample_rate, self._sample_type, self._channels, self._model_path
)
@@ -154,6 +158,7 @@ class KrispFilter(BaseAudioFilter):
data = data.astype(np.float32) + epsilon
# Process the audio chunk to reduce noise
assert self._krisp_processor is not None
reduced_noise = self._krisp_processor.process(data)
# Clip and set processed audio back to frame

View File

@@ -10,6 +10,7 @@ This module provides an audio filter implementation using Krisp VIVA SDK.
"""
import os
from typing import Optional
import numpy as np
from loguru import logger
@@ -23,7 +24,7 @@ from pipecat.audio.krisp_instance import (
from pipecat.frames.frames import FilterControlFrame, FilterEnableFrame
try:
import krisp_audio
import krisp_audio # type: ignore[import-not-found]
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use KrispVivaFilter, you need to install krisp_audio.")
@@ -39,7 +40,10 @@ class KrispVivaFilter(BaseAudioFilter):
"""
def __init__(
self, model_path: str = None, frame_duration: int = 10, noise_suppression_level: int = 100
self,
model_path: Optional[str] = None,
frame_duration: int = 10,
noise_suppression_level: int = 100,
) -> None:
"""Initialize the Krisp noise reduction filter.
@@ -171,6 +175,9 @@ class KrispVivaFilter(BaseAudioFilter):
return audio
try:
if self._samples_per_frame is None or self._session is None:
return audio
# Add incoming audio to our buffer
self._audio_buffer.extend(audio)

View File

@@ -61,6 +61,8 @@ class RNNoiseFilter(BaseAudioFilter):
try:
# RNNoise always requires 48kHz
if RNNoise is None:
raise RuntimeError("RNNoise module is not available")
self._rnnoise = RNNoise(sample_rate=48000)
self._rnnoise_ready = True
except Exception as e:
@@ -126,6 +128,7 @@ class RNNoiseFilter(BaseAudioFilter):
# denoise_chunk handles buffering internally and yields (speech_prob, denoised_frame)
# denoised_frame is in float32 format normalized to [-1.0, 1.0]
filtered_frames = []
assert self._rnnoise is not None
for speech_prob, denoised_frame in self._rnnoise.denoise_chunk(audio_samples):
# Check if output is float (needs scaling) or int16 (ready)
if np.issubdtype(denoised_frame.dtype, np.floating):

View File

@@ -12,7 +12,7 @@ from threading import Lock
from loguru import logger
try:
import krisp_audio
import krisp_audio # type: ignore[import-not-found]
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use the Krisp instance, you need to install krisp_audio.")

View File

@@ -96,6 +96,7 @@ class SOXRStreamAudioResampler(BaseAudioResampler):
self._maybe_initialize_sox_stream(in_rate, out_rate)
audio_data = np.frombuffer(audio, dtype=np.int16)
assert self._soxr_stream is not None
resampled_audio = self._soxr_stream.resample_chunk(audio_data)
result = resampled_audio.astype(np.int16).tobytes()
return result

View File

@@ -29,7 +29,7 @@ from pipecat.audio.turn.base_turn_analyzer import BaseTurnAnalyzer, BaseTurnPara
from pipecat.metrics.metrics import MetricsData
try:
import krisp_audio
import krisp_audio # type: ignore[import-not-found]
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use KrispVivaTurn, you need to install krisp_audio.")

View File

@@ -74,7 +74,7 @@ class LocalSmartTurnAnalyzer(BaseSmartTurn):
# Set device to GPU if available, else CPU
self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move model to selected device and set it to evaluation mode
self._turn_model = self._turn_model.to(self._device)
self._turn_model = self._turn_model.to(self._device) # type: ignore[assignment]
self._turn_model.eval()
logger.debug("Loaded Local Smart Turn")

View File

@@ -69,7 +69,7 @@ class LocalSmartTurnAnalyzerV2(BaseSmartTurn):
elif torch.cuda.is_available():
self._device = "cuda"
# Move model to selected device and set it to evaluation mode
self._turn_model = self._turn_model.to(self._device)
self._turn_model = self._turn_model.to(self._device) # type: ignore[assignment]
self._turn_model.eval()
logger.debug("Loaded Local Smart Turn v2")
@@ -77,12 +77,12 @@ class LocalSmartTurnAnalyzerV2(BaseSmartTurn):
"""Predict end-of-turn using local PyTorch model."""
inputs = self._turn_processor(
audio_array,
sampling_rate=16000,
padding="max_length",
truncation=True,
max_length=16000 * 16, # 16 seconds at 16kHz
return_attention_mask=True,
return_tensors="pt",
sampling_rate=16000, # type: ignore[call-arg]
padding="max_length", # type: ignore[call-arg]
truncation=True, # type: ignore[call-arg]
max_length=16000 * 16, # type: ignore[call-arg]
return_attention_mask=True, # type: ignore[call-arg]
return_tensors="pt", # type: ignore[call-arg]
)
# Move inputs to device

View File

@@ -57,7 +57,7 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
package_path = "pipecat.audio.turn.smart_turn.data"
try:
import importlib_resources as impresources
import importlib_resources as impresources # type: ignore[import-not-found]
smart_turn_model_path = str(impresources.files(package_path).joinpath(model_name))
except BaseException:
@@ -65,7 +65,7 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
try:
with impresources.path(package_path, model_name) as f:
smart_turn_model_path = f
smart_turn_model_path = str(f)
except BaseException:
smart_turn_model_path = str(
impresources.files(package_path).joinpath(model_name)
@@ -80,6 +80,7 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
self._feature_extractor = WhisperFeatureExtractor(chunk_length=8)
assert smart_turn_model_path is not None
self._session = ort.InferenceSession(smart_turn_model_path, sess_options=so)
logger.debug("Loaded Local Smart Turn v3.x")
@@ -165,7 +166,7 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
outputs = self._session.run(None, {"input_features": input_features})
# Extract probability (ONNX model returns sigmoid probabilities)
probability = outputs[0][0].item()
probability = outputs[0][0].item() # type: ignore[index]
# Make prediction (1 for Complete, 0 for Incomplete)
prediction = 1 if probability > 0.5 else 0

View File

@@ -65,7 +65,7 @@ class SileroOnnxModel:
if np.ndim(x) == 1:
x = np.expand_dims(x, 0)
if np.ndim(x) > 2:
raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
raise ValueError(f"Too many dimensions for input audio chunk {np.ndim(x)}")
if sr not in self.sample_rates:
raise ValueError(
@@ -150,7 +150,7 @@ class SileroVADAnalyzer(VADAnalyzer):
package_path = "pipecat.audio.vad.data"
try:
import importlib_resources as impresources
import importlib_resources as impresources # type: ignore[import-not-found]
model_file_path = str(impresources.files(package_path).joinpath(model_name))
except BaseException:
@@ -158,7 +158,7 @@ class SileroVADAnalyzer(VADAnalyzer):
try:
with impresources.path(package_path, model_name) as f:
model_file_path = f
model_file_path = str(f)
except BaseException:
model_file_path = str(impresources.files(package_path).joinpath(model_name))
@@ -209,7 +209,7 @@ class SileroVADAnalyzer(VADAnalyzer):
audio_int16 = np.frombuffer(buffer, np.int16)
# Divide by 32768 because we have signed 16-bit data.
audio_float32 = np.frombuffer(audio_int16, dtype=np.int16).astype(np.float32) / 32768.0
new_confidence = self._model(audio_float32, self.sample_rate)[0]
new_confidence = self._model(audio_float32, self.sample_rate)[0] # type: ignore[index]
# We need to reset the model from time to time because it doesn't
# really need all the data and memory will keep growing otherwise.

View File

@@ -99,7 +99,7 @@ class IVRProcessor(FrameProcessor):
self._register_event_handler("on_conversation_detected")
self._register_event_handler("on_ivr_status_changed")
def update_saved_messages(self, messages: List[dict]) -> None:
def update_saved_messages(self, messages: List) -> None:
"""Update the saved context messages.
Sets the messages that are saved when switching between

View File

@@ -36,7 +36,7 @@ from pipecat.frames.frames import (
UserStoppedSpeakingFrame,
)
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_context import LLMContext, LLMContextMessage
from pipecat.processors.aggregators.llm_response_universal import (
LLMContextAggregatorPair,
LLMUserAggregatorParams,
@@ -618,7 +618,7 @@ VOICEMAIL SYSTEM (respond "VOICEMAIL"):
self._validate_prompt(custom_system_prompt)
# Set up the LLM context with the classification prompt
self._messages = [
self._messages: List[LLMContextMessage] = [
{
"role": "system",
"content": self._prompt,

View File

@@ -44,7 +44,7 @@ class UserBotLatencyObserver(BaseObserver):
"""
super().__init__(**kwargs)
self._user_stopped_time: Optional[float] = None
self._processed_frames: Set[str] = set()
self._processed_frames: Set[int] = set()
self._register_event_handler("on_latency_measured")

View File

@@ -6,11 +6,12 @@
"""LLM switcher for switching between different LLMs at runtime, with different switching strategies."""
from typing import Any, List, Optional, Type
from typing import Any, List, Optional, Type, cast
from pipecat.adapters.schemas.direct_function import DirectFunction
from pipecat.pipeline.service_switcher import ServiceSwitcher, StrategyType
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.frame_processor import FrameProcessor
from pipecat.services.llm_service import LLMService
@@ -32,7 +33,7 @@ class LLMSwitcher(ServiceSwitcher[StrategyType]):
llms: List of LLM services to switch between.
strategy_type: The strategy class to use for switching between LLMs.
"""
super().__init__(llms, strategy_type)
super().__init__(cast(List[FrameProcessor], llms), strategy_type)
@property
def llms(self) -> List[LLMService]:
@@ -41,7 +42,7 @@ class LLMSwitcher(ServiceSwitcher[StrategyType]):
Returns:
List of LLM services managed by this switcher.
"""
return self.services
return cast(List[LLMService], self.services)
@property
def active_llm(self) -> Optional[LLMService]:
@@ -50,7 +51,7 @@ class LLMSwitcher(ServiceSwitcher[StrategyType]):
Returns:
The currently active LLM service, or None if no LLM is active.
"""
return self.strategy.active_service
return cast(Optional[LLMService], self.strategy.active_service)
async def run_inference(self, context: LLMContext) -> Optional[str]:
"""Run a one-shot, out-of-band (i.e. out-of-pipeline) inference with the given LLM context, using the currently active LLM.

View File

@@ -128,7 +128,7 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
def __init__(
self,
wrapped_service: FrameProcessor,
active_service: FrameProcessor,
active_service: Optional[FrameProcessor],
direction: FrameDirection,
):
"""Initialize the service switcher filter with a strategy and direction.
@@ -163,7 +163,7 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
class ServiceSwitcherFilterFrame(ControlFrame):
"""An internal frame used by ServiceSwitcher to filter frames based on active service."""
active_service: FrameProcessor
active_service: Optional[FrameProcessor]
@staticmethod
def _make_pipeline_definitions(

View File

@@ -383,7 +383,7 @@ class PipelineTask(BasePipelineTask):
# allows us to receive and react to downstream frames.
source = PipelineSource(self._source_push_frame, name=f"{self}::Source")
sink = PipelineSink(self._sink_push_frame, name=f"{self}::Sink")
processors = [self._rtvi, pipeline] if self._rtvi else [pipeline]
processors: List[FrameProcessor] = [self._rtvi, pipeline] if self._rtvi else [pipeline]
self._pipeline = Pipeline(processors, source=source, sink=sink)
# The task observer acts as a proxy to the provided observers. This way,
@@ -786,7 +786,7 @@ class PipelineTask(BasePipelineTask):
await self._observer.cleanup()
# End conversation tracing if it's active - this will also close any active turn span
if self._enable_tracing and hasattr(self, "_turn_trace_observer"):
if self._enable_tracing and self._turn_trace_observer is not None:
self._turn_trace_observer.end_conversation_tracing()
# Cleanup pipeline processors.
@@ -1020,7 +1020,7 @@ class PipelineTask(BasePipelineTask):
path = Path(f).resolve()
module_name = path.stem
spec = importlib.util.spec_from_file_location(module_name, str(path))
if spec:
if spec and spec.loader:
logger.debug(f"{self} loading observers from {path}")
# Load module.

View File

@@ -164,6 +164,8 @@ class TaskObserver(BaseObserver):
return proxies
async def _send_to_proxy(self, data: Any):
if not self._proxies:
return
for proxy in self._proxies.values():
await proxy.queue.put(data)
@@ -188,7 +190,9 @@ class TaskObserver(BaseObserver):
if isinstance(data, FramePushed):
if on_push_frame_deprecated:
await observer.on_push_frame(
# Deprecated signature: on_push_frame(source, destination, frame, direction, timestamp)
handler: Any = observer.on_push_frame
await handler(
data.source, data.destination, data.frame, data.direction, data.timestamp
)
else:

View File

@@ -19,7 +19,7 @@ import base64
import io
import wave
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, List, Optional, TypeAlias, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeAlias, Union, cast
from loguru import logger
from openai._types import NOT_GIVEN as OPEN_AI_NOT_GIVEN
@@ -107,10 +107,13 @@ class LLMContext:
converted_tools = openai_context.tools
if isinstance(converted_tools, list):
converted_tools = ToolsSchema(
standard_tools=[], custom_tools={AdapterType.SHIM: converted_tools}
standard_tools=[],
custom_tools=cast(
Dict[AdapterType, List[Dict[str, Any]]], {AdapterType.SHIM: converted_tools}
),
)
return LLMContext(
messages=openai_context.get_messages(),
messages=cast(List[LLMContextMessage], openai_context.get_messages()),
tools=converted_tools,
tool_choice=openai_context.tool_choice,
)
@@ -152,7 +155,7 @@ class LLMContext:
content.append({"type": "image_url", "image_url": {"url": url}})
return {"role": role, "content": content}
return cast(LLMContextMessage, {"role": role, "content": content})
@staticmethod
async def create_image_message(
@@ -204,7 +207,7 @@ class LLMContext:
audio_frames: List of audio frame objects to include.
text: Optional text to include with the audio.
"""
content = [{"type": "text", "text": text}]
content: List[Dict[str, Any]] = [{"type": "text", "text": text}]
def encode_audio():
sample_rate = audio_frames[0].sample_rate
@@ -231,7 +234,7 @@ class LLMContext:
}
)
return {"role": role, "content": content}
return cast(LLMContextMessage, {"role": role, "content": content})
@property
def messages(self) -> List[LLMContextMessage]:

View File

@@ -15,12 +15,15 @@ import asyncio
import warnings
from abc import abstractmethod
from dataclasses import dataclass
from typing import Dict, List, Literal, Optional, Set
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Set, cast
if TYPE_CHECKING:
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolChoiceOptionParam
from loguru import logger
from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
from pipecat.audio.turn.base_turn_analyzer import BaseTurnParams
from pipecat.audio.vad.vad_analyzer import VADParams
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
@@ -336,7 +339,7 @@ class LLMContextResponseAggregator(BaseLLMResponseAggregator):
Returns:
List of message dictionaries from the context.
"""
return self._context.get_messages()
return cast(List[dict], self._context.get_messages())
@property
def role(self) -> str:
@@ -402,7 +405,7 @@ class LLMContextResponseAggregator(BaseLLMResponseAggregator):
"""
self._context.set_messages(messages)
def set_tools(self, tools: List):
def set_tools(self, tools):
"""Set tools in the context.
Args:
@@ -416,7 +419,7 @@ class LLMContextResponseAggregator(BaseLLMResponseAggregator):
Args:
tool_choice: Tool choice configuration for the context.
"""
self._context.set_tool_choice(tool_choice)
self._context.set_tool_choice(cast("ChatCompletionToolChoiceOptionParam", tool_choice))
async def reset(self):
"""Reset the aggregation state."""
@@ -467,7 +470,7 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
super().__init__(context=context, role="user", **kwargs)
self._params = params or LLMUserAggregatorParams()
self._vad_params: Optional[VADParams] = None
self._turn_params: Optional[SmartTurnParams] = None
self._turn_params: Optional[BaseTurnParams] = None
if "aggregation_timeout" in kwargs:
with warnings.catch_warnings():
@@ -503,7 +506,9 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
Args:
aggregation: The aggregated user text to add as a user message.
"""
self._context.add_message({"role": self.role, "content": aggregation})
self._context.add_message(
cast("ChatCompletionMessageParam", {"role": self.role, "content": aggregation})
)
async def process_frame(self, frame: Frame, direction: FrameDirection):
"""Process frames for user speech aggregation and context management.
@@ -851,7 +856,9 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
Args:
aggregation: The aggregated assistant text to add as an assistant message.
"""
self._context.add_message({"role": "assistant", "content": aggregation})
self._context.add_message(
cast("ChatCompletionMessageParam", {"role": "assistant", "content": aggregation})
)
async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
"""Handle a function call that is in progress.
@@ -1030,6 +1037,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
del self._function_calls_in_progress[frame.tool_call_id]
async def _handle_user_image_frame(self, frame: UserImageRawFrame):
assert frame.request is not None
logger.debug(
f"{self} UserImageRawFrame: [{frame.request.function_name}:{frame.request.tool_call_id}]"
)
@@ -1105,14 +1113,18 @@ class LLMUserResponseAggregator(LLMUserContextAggregator):
DeprecationWarning,
stacklevel=2,
)
super().__init__(context=OpenAILLMContext(messages), params=params, **kwargs)
super().__init__(
context=OpenAILLMContext(cast(List["ChatCompletionMessageParam"], messages)),
params=params,
**kwargs,
)
async def _process_aggregation(self):
"""Process the current aggregation and push it downstream."""
aggregation = self._aggregation
await self.reset()
await self.handle_aggregation(aggregation)
frame = LLMMessagesFrame(self._context.messages)
frame = LLMMessagesFrame(cast(List[dict], self._context.messages))
await self.push_frame(frame)
@@ -1150,7 +1162,11 @@ class LLMAssistantResponseAggregator(LLMAssistantContextAggregator):
DeprecationWarning,
stacklevel=2,
)
super().__init__(context=OpenAILLMContext(messages), params=params, **kwargs)
super().__init__(
context=OpenAILLMContext(cast(List["ChatCompletionMessageParam"], messages)),
params=params,
**kwargs,
)
async def push_aggregation(self):
"""Push the aggregated assistant response as an LLMMessagesFrame."""
@@ -1161,5 +1177,5 @@ class LLMAssistantResponseAggregator(LLMAssistantContextAggregator):
# if the tasks gets cancelled we won't be able to clear things up.
await self.reset()
frame = LLMMessagesFrame(self._context.messages)
frame = LLMMessagesFrame(cast(List[dict], self._context.messages))
await self.push_frame(frame)

View File

@@ -16,11 +16,11 @@ import json
import warnings
from abc import abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional, Set, Type
from typing import Any, Dict, List, Literal, Optional, Set, Type, cast
from loguru import logger
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema
from pipecat.audio.vad.vad_analyzer import VADAnalyzer
from pipecat.audio.vad.vad_controller import VADController
from pipecat.frames.frames import (
@@ -258,12 +258,20 @@ class LLMContextAggregator(FrameProcessor):
"""
self._context.set_messages(messages)
def set_tools(self, tools: ToolsSchema | NotGiven):
def set_tools(self, tools: ToolsSchema | List | NotGiven):
"""Set tools in the context.
Args:
tools: List of tool definitions to set in the context.
tools: Tool definitions to set in the context.
"""
if isinstance(tools, list):
tools = ToolsSchema(
standard_tools=[],
custom_tools=cast(
Dict[AdapterType, List[Dict[str, Any]]],
{AdapterType.SHIM: tools},
),
)
self._context.set_tools(tools)
def set_tool_choice(self, tool_choice: Literal["none", "auto", "required"] | dict):
@@ -272,7 +280,9 @@ class LLMContextAggregator(FrameProcessor):
Args:
tool_choice: Tool choice configuration for the context.
"""
self._context.set_tool_choice(tool_choice)
from pipecat.processors.aggregators.llm_context import LLMContextToolChoice
self._context.set_tool_choice(cast(LLMContextToolChoice, tool_choice))
async def reset(self):
"""Reset the aggregation state."""
@@ -485,7 +495,9 @@ class LLMUserAggregator(LLMContextAggregator):
aggregation = self.aggregation_string()
await self.reset()
self._context.add_message({"role": self.role, "content": aggregation})
self._context.add_message(
cast(LLMContextMessage, {"role": self.role, "content": aggregation})
)
await self.push_context_frame()
return aggregation
@@ -515,7 +527,11 @@ class LLMUserAggregator(LLMContextAggregator):
)
# Auto-inject turn completion instructions into context
self._context.add_message({"role": "system", "content": config.completion_instructions})
self._context.add_message(
cast(
LLMContextMessage, {"role": "system", "content": config.completion_instructions}
)
)
async def _stop(self, frame: EndFrame):
await self._maybe_emit_user_turn_stopped(on_session_end=True)
@@ -811,7 +827,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
self._assistant_turn_start_timestamp = ""
self._thought_append_to_context = False
self._thought_llm: str = ""
self._thought_llm: Optional[str] = ""
self._thought_aggregation: List[TextPartForConcatenation] = []
self._thought_start_time: str = ""
@@ -899,7 +915,9 @@ class LLMAssistantAggregator(LLMContextAggregator):
aggregation = self.aggregation_string()
await self.reset()
self._context.add_message({"role": "assistant", "content": aggregation})
self._context.add_message(
cast(LLMContextMessage, {"role": "assistant", "content": aggregation})
)
# Push context frame
await self.push_context_frame()
@@ -945,26 +963,32 @@ class LLMAssistantAggregator(LLMContextAggregator):
# Update context with the in-progress function call
self._context.add_message(
{
"role": "assistant",
"tool_calls": [
{
"id": frame.tool_call_id,
"function": {
"name": frame.function_name,
"arguments": json.dumps(frame.arguments, ensure_ascii=False),
},
"type": "function",
}
],
}
cast(
LLMContextMessage,
{
"role": "assistant",
"tool_calls": [
{
"id": frame.tool_call_id,
"function": {
"name": frame.function_name,
"arguments": json.dumps(frame.arguments, ensure_ascii=False),
},
"type": "function",
}
],
},
)
)
self._context.add_message(
{
"role": "tool",
"content": "IN_PROGRESS",
"tool_call_id": frame.tool_call_id,
}
cast(
LLMContextMessage,
{
"role": "tool",
"content": "IN_PROGRESS",
"tool_call_id": frame.tool_call_id,
},
)
)
self._function_calls_in_progress[frame.tool_call_id] = frame
@@ -1060,6 +1084,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
logger.debug(f"{self} Appending AssistantImageRawFrame to LLM context (size: {frame.size})")
if frame.original_data and frame.original_mime_type:
assert frame.original_mime_type is not None
await self._context.add_image_frame_message(
format=frame.original_mime_type,
size=frame.size, # Technically doesn't matter, since already encoded
@@ -1068,7 +1093,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
)
else:
await self._context.add_image_frame_message(
format=frame.format,
format=frame.format or "RGB",
size=frame.size,
image=frame.image,
role="assistant",
@@ -1125,11 +1150,10 @@ class LLMAssistantAggregator(LLMContextAggregator):
thought = concatenate_aggregated_text(self._thought_aggregation)
if self._thought_append_to_context:
llm = self._thought_llm
if self._thought_append_to_context and self._thought_llm is not None:
self._context.add_message(
LLMSpecificMessage(
llm=llm,
llm=self._thought_llm,
message={
"type": "thought",
"text": thought,
@@ -1151,7 +1175,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
logger.debug(f"{self} Appending UserImageRawFrame to LLM context (size: {frame.size})")
await self._context.add_image_frame_message(
format=frame.format,
format=frame.format or "RGB",
size=frame.size,
image=frame.image,
text=frame.text,

View File

@@ -21,7 +21,7 @@ import io
import json
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, cast
from openai._types import NOT_GIVEN, NotGiven
from openai.types.chat import (
@@ -144,7 +144,7 @@ class OpenAILLMContext:
context = OpenAILLMContext()
for message in messages:
context.add_message(message)
context.add_message(cast(ChatCompletionMessageParam, message))
return context
@property
@@ -157,7 +157,7 @@ class OpenAILLMContext:
return self._messages
@property
def tools(self) -> List[ChatCompletionToolParam] | NotGiven | List[Any]:
def tools(self) -> List[ChatCompletionToolParam] | NotGiven | ToolsSchema | List[Any]:
"""Get the tools list, converting through adapter if available.
Returns:
@@ -311,7 +311,7 @@ class OpenAILLMContext:
self._tools = tools
def add_image_frame_message(
self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
self, *, format: str, size: tuple[int, int], image: bytes, text: Optional[str] = None
):
"""Add a message containing an image frame.
@@ -333,7 +333,9 @@ class OpenAILLMContext:
)
self.add_message({"role": "user", "content": content})
def add_audio_frames_message(self, *, audio_frames: list[AudioRawFrame], text: str = None):
def add_audio_frames_message(
self, *, audio_frames: list[AudioRawFrame], text: Optional[str] = None
):
"""Add a message containing audio frames.
Args:

View File

@@ -14,6 +14,7 @@ in conversational pipelines.
from pipecat.frames.frames import TextFrame
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.processors.aggregators.llm_response_universal import LLMUserAggregator
from pipecat.utils.string import concatenate_aggregated_text
class UserResponseAggregator(LLMUserAggregator):
@@ -45,20 +46,27 @@ class UserResponseAggregator(LLMUserAggregator):
stacklevel=2,
)
async def push_aggregation(self):
async def push_aggregation(self) -> str:
"""Push the aggregated user response as a TextFrame.
Creates a TextFrame from the current aggregation if it contains content,
resets the aggregation state, and pushes the frame downstream.
Returns:
The pushed aggregation text, or empty string if nothing to push.
"""
if len(self._aggregation) > 0:
frame = TextFrame(self._aggregation.strip())
text = concatenate_aggregated_text(self._aggregation).strip()
frame = TextFrame(text)
# Reset the aggregation. Reset it before pushing it down, otherwise
# if the tasks gets cancelled we won't be able to clear things up.
self._aggregation = ""
self._aggregation = []
await self.push_frame(frame)
# Reset our accumulator state.
await self.reset()
return text
return ""

View File

@@ -72,7 +72,7 @@ class VisionImageFrameAggregator(FrameProcessor):
text=self._describe_text,
image=frame.image,
size=frame.size,
format=frame.format,
format=frame.format or "RGB",
)
frame = OpenAILLMContextFrame(context)
await self.push_frame(frame)

View File

@@ -103,7 +103,7 @@ class FrameProcessorQueue(asyncio.PriorityQueue):
self.__high_counter = 0
self.__low_counter = 0
async def put(self, item: Tuple[Frame, FrameDirection, FrameCallback]):
async def put(self, item: Tuple[Frame, FrameDirection, Optional[FrameCallback]]):
"""Put an item into the priority queue.
System frames (`SystemFrame`) have higher priority than any other
@@ -470,11 +470,13 @@ class FrameProcessor(BaseObject):
await self.stop_ttfb_metrics()
await self.stop_processing_metrics()
def create_task(self, coroutine: Coroutine, name: Optional[str] = None) -> asyncio.Task:
def create_task(
self, coroutine: Coroutine | Awaitable, name: Optional[str] = None
) -> asyncio.Task:
"""Create a new task managed by this processor.
Args:
coroutine: The coroutine to run in the task.
coroutine: The coroutine or awaitable to run in the task.
name: Optional name for the task.
Returns:
@@ -483,7 +485,8 @@ class FrameProcessor(BaseObject):
if name:
name = f"{self}::{name}"
else:
name = f"{self}::{coroutine.cr_code.co_name}"
cr_code = getattr(coroutine, "cr_code", None)
name = f"{self}::{cr_code.co_name if cr_code else 'unknown'}"
return self.task_manager.create_task(coroutine, name)
async def cancel_task(self, task: asyncio.Task, timeout: Optional[float] = 1.0):

View File

@@ -74,7 +74,13 @@ class LangchainProcessor(FrameProcessor):
if isinstance(frame, OpenAILLMContextFrame)
else frame.context.get_messages()
)
text: str = messages[-1]["content"]
last_msg = messages[-1]
content = (
last_msg.get("content", "")
if isinstance(last_msg, dict)
else getattr(last_msg, "content", "")
)
text: str = content if isinstance(content, str) else str(content)
await self._ainvoke(text.strip())
else:
@@ -94,7 +100,7 @@ class LangchainProcessor(FrameProcessor):
case str():
return text
case AIMessageChunk():
return text.content
return text.content if isinstance(text.content, str) else str(text.content)
case _:
return ""

View File

@@ -1253,31 +1253,33 @@ class RTVIObserver(BaseObserver):
report_level = self._get_function_call_report_level(function_call.function_name)
if report_level == RTVIFunctionCallReportLevel.DISABLED:
continue
data = RTVILLMFunctionCallStartMessageData()
fc_start_data = RTVILLMFunctionCallStartMessageData()
if report_level in (
RTVIFunctionCallReportLevel.NAME,
RTVIFunctionCallReportLevel.FULL,
):
data.function_name = function_call.function_name
message = RTVILLMFunctionCallStartMessage(data=data)
fc_start_data.function_name = function_call.function_name
message = RTVILLMFunctionCallStartMessage(data=fc_start_data)
await self.send_rtvi_message(message)
elif isinstance(frame, FunctionCallInProgressFrame):
report_level = self._get_function_call_report_level(frame.function_name)
if report_level != RTVIFunctionCallReportLevel.DISABLED:
data = RTVILLMFunctionCallInProgressMessageData(tool_call_id=frame.tool_call_id)
fc_progress_data = RTVILLMFunctionCallInProgressMessageData(
tool_call_id=frame.tool_call_id
)
if report_level in (
RTVIFunctionCallReportLevel.NAME,
RTVIFunctionCallReportLevel.FULL,
):
data.function_name = frame.function_name
fc_progress_data.function_name = frame.function_name
if report_level == RTVIFunctionCallReportLevel.FULL:
data.args = frame.arguments
message = RTVILLMFunctionCallInProgressMessage(data=data)
fc_progress_data.args = frame.arguments
message = RTVILLMFunctionCallInProgressMessage(data=fc_progress_data)
await self.send_rtvi_message(message)
elif isinstance(frame, FunctionCallCancelFrame):
report_level = self._get_function_call_report_level(frame.function_name)
if report_level != RTVIFunctionCallReportLevel.DISABLED:
data = RTVILLMFunctionCallStoppedMessageData(
fc_cancel_data = RTVILLMFunctionCallStoppedMessageData(
tool_call_id=frame.tool_call_id,
cancelled=True,
)
@@ -1285,13 +1287,13 @@ class RTVIObserver(BaseObserver):
RTVIFunctionCallReportLevel.NAME,
RTVIFunctionCallReportLevel.FULL,
):
data.function_name = frame.function_name
message = RTVILLMFunctionCallStoppedMessage(data=data)
fc_cancel_data.function_name = frame.function_name
message = RTVILLMFunctionCallStoppedMessage(data=fc_cancel_data)
await self.send_rtvi_message(message)
elif isinstance(frame, FunctionCallResultFrame):
report_level = self._get_function_call_report_level(frame.function_name)
if report_level != RTVIFunctionCallReportLevel.DISABLED:
data = RTVILLMFunctionCallStoppedMessageData(
fc_result_data = RTVILLMFunctionCallStoppedMessageData(
tool_call_id=frame.tool_call_id,
cancelled=False,
)
@@ -1299,10 +1301,10 @@ class RTVIObserver(BaseObserver):
RTVIFunctionCallReportLevel.NAME,
RTVIFunctionCallReportLevel.FULL,
):
data.function_name = frame.function_name
fc_result_data.function_name = frame.function_name
if report_level == RTVIFunctionCallReportLevel.FULL:
data.result = frame.result if frame.result else None
message = RTVILLMFunctionCallStoppedMessage(data=data)
fc_result_data.result = frame.result if frame.result else None
message = RTVILLMFunctionCallStoppedMessage(data=fc_result_data)
await self.send_rtvi_message(message)
elif isinstance(frame, RTVIServerMessageFrame):
message = RTVIServerMessage(data=frame.data)
@@ -1427,8 +1429,14 @@ class RTVIObserver(BaseObserver):
# Handle Google LLM format (protobuf objects with attributes)
# Note: not possible if frame is a universal LLMContextFrame
if hasattr(message, "role") and message.role == "user" and hasattr(message, "parts"):
text = "".join(part.text for part in message.parts if hasattr(part, "text"))
if (
hasattr(message, "role")
and getattr(message, "role", None) == "user"
and hasattr(message, "parts")
):
text = "".join(
part.text for part in getattr(message, "parts", []) if hasattr(part, "text")
)
if text:
rtvi_message = RTVIUserLLMTextMessage(data=RTVITextMessageData(text=text))
await self.send_rtvi_message(rtvi_message)
@@ -1439,8 +1447,10 @@ class RTVIObserver(BaseObserver):
content = message["content"]
if isinstance(content, list):
text = " ".join(item["text"] for item in content if "text" in item)
else:
elif isinstance(content, str):
text = content
else:
text = str(content) if content else ""
rtvi_message = RTVIUserLLMTextMessage(data=RTVITextMessageData(text=text))
await self.send_rtvi_message(rtvi_message)
@@ -1482,7 +1492,7 @@ class RTVIObserver(BaseObserver):
async def _send_error_response(self, frame: RTVIServerResponseFrame):
"""Send a response to the client for a specific request."""
message = RTVIErrorResponse(
id=str(frame.client_msg.msg_id), data=RTVIErrorResponseData(error=frame.error)
id=str(frame.client_msg.msg_id), data=RTVIErrorResponseData(error=frame.error or "")
)
await self.send_rtvi_message(message)
@@ -1593,7 +1603,7 @@ class RTVIProcessor(FrameProcessor):
self._client_ready = True
await self._call_event_handler("on_client_ready")
async def set_bot_ready(self, about: Mapping[str, Any] = None):
async def set_bot_ready(self, about: Optional[Mapping[str, Any]] = None):
"""Mark the bot as ready and send the bot-ready message.
Args:
@@ -2026,7 +2036,7 @@ class RTVIProcessor(FrameProcessor):
"""Handle an action execution request."""
action_id = self._action_id(data.service, data.action)
if action_id not in self._registered_actions:
await self._send_error_response(request_id, f"Action {action_id} not registered")
await self._send_error_response(request_id or "", f"Action {action_id} not registered")
return
action = self._registered_actions[action_id]
arguments = {}
@@ -2040,7 +2050,7 @@ class RTVIProcessor(FrameProcessor):
message = RTVIActionResponse(id=request_id, data=RTVIActionResponseData(result=result))
await self.push_transport_message(message)
async def _send_bot_ready(self, about: Mapping[str, Any] = None):
async def _send_bot_ready(self, about: Optional[Mapping[str, Any]] = None):
"""Send the bot-ready message to the client.
Args:

View File

@@ -76,7 +76,12 @@ class StrandsAgentsProcessor(FrameProcessor):
messages = frame.context.get_messages()
if messages:
last_message = messages[-1]
await self._ainvoke(str(last_message["content"]).strip())
content = (
last_message.get("content", "")
if isinstance(last_message, dict)
else getattr(last_message, "content", "")
)
await self._ainvoke(str(content).strip())
else:
await self.push_frame(frame, direction)
@@ -100,6 +105,7 @@ class StrandsAgentsProcessor(FrameProcessor):
await self.stop_ttfb_metrics()
ttfb_tracking = False
try:
assert self.graph_exit_node is not None
node_result = graph_result.results[self.graph_exit_node]
logger.debug(f"Node result: {node_result}")
for agent_result in node_result.get_agent_results():
@@ -119,6 +125,7 @@ class StrandsAgentsProcessor(FrameProcessor):
logger.warning(f"Failed to extract messages from GraphResult: {parse_err}")
else:
# Agent supports streaming events via async iterator
assert self.agent is not None
async for event in self.agent.stream_async(text):
# Push to TTS service
if isinstance(event, dict) and "data" in event:

View File

@@ -28,7 +28,7 @@ try:
gi.require_version("Gst", "1.0")
gi.require_version("GstApp", "1.0")
from gi.repository import Gst, GstApp
from gi.repository import Gst, GstApp # type: ignore[import-not-found]
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(

View File

@@ -65,6 +65,7 @@ class FrameProcessorMetrics(BaseObject):
Returns:
The task manager instance for async operations.
"""
assert self._task_manager is not None
return self._task_manager
@property

View File

@@ -10,7 +10,7 @@ This module provides processors that convert speech and text frames into structu
transcript messages with timestamps, enabling conversation history tracking and analysis.
"""
from typing import List, Optional
from typing import List, Optional, Union
from loguru import logger
@@ -47,10 +47,14 @@ class BaseTranscriptProcessor(FrameProcessor):
**kwargs: Additional arguments passed to parent class.
"""
super().__init__(**kwargs)
self._processed_messages: List[TranscriptionMessage] = []
self._processed_messages: List[
Union[TranscriptionMessage, ThoughtTranscriptionMessage]
] = []
self._register_event_handler("on_transcript_update")
async def _emit_update(self, messages: List[TranscriptionMessage]):
async def _emit_update(
self, messages: List[Union[TranscriptionMessage, ThoughtTranscriptionMessage]]
):
"""Emit transcript updates for new messages.
Args:

View File

@@ -79,11 +79,11 @@ async def configure(
aiohttp_session: aiohttp.ClientSession,
*,
api_key: Optional[str] = None,
room_exp_duration: Optional[float] = 2.0,
token_exp_duration: Optional[float] = 2.0,
room_exp_duration: float = 2.0,
token_exp_duration: float = 2.0,
sip_caller_phone: Optional[str] = None,
sip_enable_video: Optional[bool] = False,
sip_num_endpoints: Optional[int] = 1,
sip_enable_video: bool = False,
sip_num_endpoints: int = 1,
sip_codecs: Optional[Dict[str, List[str]]] = None,
room_properties: Optional[DailyRoomProperties] = None,
token_properties: Optional["DailyMeetingTokenProperties"] = None,
@@ -209,6 +209,7 @@ async def configure(
# Add SIP configuration if enabled
if sip_enabled:
assert sip_caller_phone is not None
sip_params = DailyRoomSipParams(
display_name=sip_caller_phone,
video=sip_enable_video,

View File

@@ -72,10 +72,23 @@ import os
import sys
import uuid
from contextlib import asynccontextmanager
from http import HTTPMethod
from pathlib import Path
from typing import Any, Dict, List, Optional, TypedDict, Union
if sys.version_info >= (3, 11):
from http import HTTPMethod
else:
# HTTPMethod was added in Python 3.11
from enum import Enum
class HTTPMethod(str, Enum):
"""HTTP method enum for Python < 3.11 compatibility."""
POST = "POST"
PATCH = "PATCH"
GET = "GET"
import aiohttp
from fastapi.responses import FileResponse, Response
from loguru import logger
@@ -140,6 +153,8 @@ def _get_bot_module():
spec = importlib.util.spec_from_file_location(
module_name, os.path.join(cwd, filename)
)
if not spec or not spec.loader:
continue
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
@@ -397,6 +412,11 @@ def _setup_whatsapp_routes(app: FastAPI, args: argparse.Namespace):
)
return
# Assertions after the all() check above ensures these are non-None
assert WHATSAPP_TOKEN is not None
assert WHATSAPP_PHONE_NUMBER_ID is not None
assert WHATSAPP_WEBHOOK_VERIFICATION_TOKEN is not None
try:
from pipecat.transports.smallwebrtc.connection import SmallWebRTCConnection
from pipecat.transports.whatsapp.api import WhatsAppWebhookRequest

View File

@@ -416,8 +416,8 @@ def _get_transport_params(transport_key: str, transport_params: Dict[str, Callab
async def _create_telephony_transport(
websocket: WebSocket,
params: Optional[Any] = None,
transport_type: str = None,
call_data: dict = None,
transport_type: Optional[str] = None,
call_data: Optional[dict] = None,
) -> BaseTransport:
"""Create a telephony transport with pre-parsed WebSocket data.
@@ -443,6 +443,9 @@ async def _create_telephony_transport(
logger.info(f"Using pre-detected telephony provider: {transport_type}")
if call_data is None:
raise ValueError("call_data must be provided for telephony transports.")
if transport_type == "twilio":
from pipecat.serializers.twilio import TwilioFrameSerializer
@@ -586,6 +589,8 @@ async def create_transport(
from pipecat.transports.livekit.transport import LiveKitTransport
if runner_args.token is None:
raise ValueError("LiveKit token is required")
return LiveKitTransport(
runner_args.url,
runner_args.token,

View File

@@ -62,6 +62,7 @@ class ExotelFrameSerializer(FrameSerializer):
params: Configuration parameters.
"""
super().__init__(params or ExotelFrameSerializer.InputParams())
self._params: ExotelFrameSerializer.InputParams
self._stream_sid = stream_sid
self._call_sid = call_sid

View File

@@ -169,6 +169,7 @@ class GenesysAudioHookSerializer(FrameSerializer):
**kwargs: Additional arguments passed to BaseObject (e.g., name).
"""
super().__init__(params or GenesysAudioHookSerializer.InputParams(), **kwargs)
self._params: GenesysAudioHookSerializer.InputParams
self._genesys_sample_rate = self._params.genesys_sample_rate
self._sample_rate = 0 # Pipeline input rate, set in setup()

View File

@@ -74,6 +74,7 @@ class PlivoFrameSerializer(FrameSerializer):
params: Configuration parameters.
"""
super().__init__(params or PlivoFrameSerializer.InputParams())
self._params: PlivoFrameSerializer.InputParams
self._stream_id = stream_id
self._call_id = call_id

View File

@@ -79,23 +79,23 @@ class ProtobufFrameSerializer(FrameSerializer):
Serialized frame as bytes, or None if frame type is not serializable.
"""
# Wrapping this messages as a JSONFrame to send
serializable_frame: Frame | MessageFrame = frame
if isinstance(frame, (OutputTransportMessageFrame, OutputTransportMessageUrgentFrame)):
if self.should_ignore_frame(frame):
return None
frame = MessageFrame(
serializable_frame = MessageFrame(
data=json.dumps(frame.message),
)
proto_frame = frame_protos.Frame()
if type(frame) not in self.SERIALIZABLE_TYPES:
proto_frame = frame_protos.Frame() # type: ignore[attr-defined]
proto_optional_name = self.SERIALIZABLE_TYPES.get(type(serializable_frame))
if proto_optional_name is None:
logger.warning(f"Frame type {type(frame)} is not serializable")
return None
# ignoring linter errors; we check that type(frame) is in this dict above
proto_optional_name = self.SERIALIZABLE_TYPES[type(frame)] # type: ignore
proto_attr = getattr(proto_frame, proto_optional_name)
for field in dataclasses.fields(frame): # type: ignore
value = getattr(frame, field.name)
for field in dataclasses.fields(serializable_frame): # type: ignore[arg-type]
value = getattr(serializable_frame, field.name)
if value and hasattr(proto_attr, field.name):
setattr(proto_attr, field.name, value)
@@ -110,7 +110,7 @@ class ProtobufFrameSerializer(FrameSerializer):
Returns:
Deserialized frame instance, or None if deserialization fails.
"""
proto = frame_protos.Frame.FromString(data)
proto = frame_protos.Frame.FromString(data) # type: ignore[attr-defined]
which = proto.WhichOneof("frame")
if which not in self.DESERIALIZABLE_FIELDS:
logger.error("Unable to deserialize a valid frame")

View File

@@ -78,6 +78,7 @@ class TwilioFrameSerializer(FrameSerializer):
params: Configuration parameters.
"""
super().__init__(params or TwilioFrameSerializer.InputParams())
self._params: TwilioFrameSerializer.InputParams
# Validate hangup-related parameters if auto_hang_up is enabled
if self._params.auto_hang_up:
@@ -193,6 +194,8 @@ class TwilioFrameSerializer(FrameSerializer):
endpoint = f"https://api.{edge_prefix}{region_prefix}twilio.com/2010-04-01/Accounts/{account_sid}/Calls/{call_sid}.json"
# Create basic auth from account_sid and auth_token
assert account_sid is not None
assert auth_token is not None
auth = aiohttp.BasicAuth(account_sid, auth_token)
# Parameters to set the call status to "completed" (hang up)

View File

@@ -57,6 +57,7 @@ class VonageFrameSerializer(FrameSerializer):
params: Configuration parameters.
"""
super().__init__(params or VonageFrameSerializer.InputParams())
self._params: VonageFrameSerializer.InputParams
self._vonage_sample_rate = self._params.vonage_sample_rate
self._sample_rate = 0 # Pipeline input rate

View File

@@ -592,7 +592,8 @@ class BaseInputTransport(FrameProcessor):
"""Handle end-of-turn analysis and generate prediction results."""
if self._params.turn_analyzer:
state, prediction = await self._params.turn_analyzer.analyze_end_of_turn()
await self._deprecated_handle_prediction_result(prediction)
if prediction is not None:
await self._deprecated_handle_prediction_result(prediction)
await self._deprecated_handle_end_of_turn_complete(state)
async def _deprecated_handle_end_of_turn_complete(self, state: EndOfTurnState):
@@ -610,6 +611,7 @@ class BaseInputTransport(FrameProcessor):
"""Run turn analysis on audio frame and handle results."""
is_speech = vad_state == VADState.SPEAKING or vad_state == VADState.STARTING
# If silence exceeds threshold, we are going to receive EndOfTurnState.COMPLETE
assert self._params.turn_analyzer is not None
end_of_turn_state = self._params.turn_analyzer.append_audio(frame.audio, is_speech)
if end_of_turn_state == EndOfTurnState.COMPLETE:
await self._deprecated_handle_end_of_turn_complete(end_of_turn_state)

View File

@@ -701,14 +701,16 @@ class BaseOutputTransport(FrameProcessor):
# Notify the bot stopped speaking upstream if necessary.
await self._bot_stopped_speaking()
async def with_mixer(vad_stop_secs: float) -> AsyncGenerator[Frame, None]:
async def with_mixer(
mixer: BaseAudioMixer, vad_stop_secs: float
) -> AsyncGenerator[Frame, None]:
last_frame_time = 0
silence = b"\x00" * self._audio_chunk_size
while True:
try:
frame = self._audio_queue.get_nowait()
if isinstance(frame, OutputAudioRawFrame):
frame.audio = await self._mixer.mix(frame.audio)
frame.audio = await mixer.mix(frame.audio)
last_frame_time = time.time()
yield frame
self._audio_queue.task_done()
@@ -719,7 +721,7 @@ class BaseOutputTransport(FrameProcessor):
await self._bot_stopped_speaking()
# Generate an audio frame with only the mixer's part.
frame = OutputAudioRawFrame(
audio=await self._mixer.mix(silence),
audio=await mixer.mix(silence),
sample_rate=self._sample_rate,
num_channels=self._params.audio_out_channels,
)
@@ -731,7 +733,7 @@ class BaseOutputTransport(FrameProcessor):
await asyncio.sleep(0)
if self._mixer:
return with_mixer(BOT_VAD_STOP_SECS)
return with_mixer(self._mixer, BOT_VAD_STOP_SECS)
else:
return without_mixer(BOT_VAD_STOP_SECS)
@@ -864,6 +866,7 @@ class BaseOutputTransport(FrameProcessor):
# which is kind of what happens in P2P connections.
# We need to add support for that inside the DailyTransport
if frame.size != desired_size:
assert frame.format is not None
image = Image.frombytes(frame.format, frame.size, frame.image)
resized_image = image.resize(desired_size)
# logger.warning(f"{frame} does not have the expected size {desired_size}, resizing")

View File

@@ -195,7 +195,7 @@ class DailyUpdateRemoteParticipantsFrame(ControlFrame):
remote_participants: See https://reference-python.daily.co/api_reference.html#daily.CallClient.update_remote_participants.
"""
remote_participants: Mapping[str, Any] = None
remote_participants: Optional[Mapping[str, Any]] = None
def __post_init__(self):
super().__post_init__()
@@ -751,7 +751,10 @@ class DailyTransportClient(EventHandler):
self._client.set_user_name(self._bot_name)
(data, error) = await self._join()
result = await self._join()
if result is None:
return
(data, error) = result
if not error:
self._joined = True
@@ -881,7 +884,7 @@ class DailyTransportClient(EventHandler):
"""Cleanup the Daily client instance."""
if self._client:
self._client.release()
self._client = None
self._client = None # type: ignore[assignment]
def participants(self) -> Mapping[str, Any]:
"""Get current participants in the room.
@@ -1956,7 +1959,8 @@ class DailyOutputTransport(BaseOutputTransport):
await super().process_frame(frame, direction)
if isinstance(frame, DailyUpdateRemoteParticipantsFrame):
await self._client.update_remote_participants(frame.remote_participants)
if frame.remote_participants is not None:
await self._client.update_remote_participants(frame.remote_participants)
async def send_message(
self, frame: OutputTransportMessageFrame | OutputTransportMessageUrgentFrame

View File

@@ -230,6 +230,7 @@ class HeyGenOutputTransport(BaseOutputTransport):
logger.warning("self._event_id is already defined!")
self._event_id = str(frame.id)
elif isinstance(frame, BotStoppedSpeakingFrame):
assert self._event_id is not None
await self._client.agent_speak_end(self._event_id)
self._event_id = None
await super().push_frame(frame, direction)
@@ -252,7 +253,8 @@ class HeyGenOutputTransport(BaseOutputTransport):
"""
await super().process_frame(frame, direction)
if isinstance(frame, InterruptionFrame):
await self._client.interrupt(self._event_id)
if self._event_id is not None:
await self._client.interrupt(self._event_id)
await self.push_frame(frame, direction)
if isinstance(frame, UserStartedSpeakingFrame):
await self._client.start_agent_listening()
@@ -267,6 +269,7 @@ class HeyGenOutputTransport(BaseOutputTransport):
Args:
frame: The audio frame to write.
"""
assert self._event_id is not None
await self._client.agent_speak(bytes(frame.audio), self._event_id)
return True

View File

@@ -412,7 +412,7 @@ class LiveKitTransportClient:
"id": participant.sid,
"name": participant.name,
"metadata": participant.metadata,
"is_speaking": participant.is_speaking,
"is_speaking": participant.is_speaking, # type: ignore[attr-defined]
}
return {}
@@ -432,7 +432,7 @@ class LiveKitTransportClient:
"""
participant = self.room.remote_participants.get(participant_id)
if participant:
for track in participant.tracks.values():
for track in participant.tracks.values(): # type: ignore[attr-defined]
if track.kind == "audio":
await track.set_enabled(False)
@@ -444,13 +444,14 @@ class LiveKitTransportClient:
"""
participant = self.room.remote_participants.get(participant_id)
if participant:
for track in participant.tracks.values():
for track in participant.tracks.values(): # type: ignore[attr-defined]
if track.kind == "audio":
await track.set_enabled(True)
# Wrapper methods for event handlers
def _on_participant_connected_wrapper(self, participant: rtc.RemoteParticipant):
"""Wrapper for participant connected events."""
assert self._task_manager is not None
self._task_manager.create_task(
self._async_on_participant_connected(participant),
f"{self}::_async_on_participant_connected",
@@ -458,6 +459,7 @@ class LiveKitTransportClient:
def _on_participant_disconnected_wrapper(self, participant: rtc.RemoteParticipant):
"""Wrapper for participant disconnected events."""
assert self._task_manager is not None
self._task_manager.create_task(
self._async_on_participant_disconnected(participant),
f"{self}::_async_on_participant_disconnected",
@@ -470,6 +472,7 @@ class LiveKitTransportClient:
participant: rtc.RemoteParticipant,
):
"""Wrapper for track subscribed events."""
assert self._task_manager is not None
self._task_manager.create_task(
self._async_on_track_subscribed(track, publication, participant),
f"{self}::_async_on_track_subscribed",
@@ -482,6 +485,7 @@ class LiveKitTransportClient:
participant: rtc.RemoteParticipant,
):
"""Wrapper for track unsubscribed events."""
assert self._task_manager is not None
self._task_manager.create_task(
self._async_on_track_unsubscribed(track, publication, participant),
f"{self}::_async_on_track_unsubscribed",
@@ -489,6 +493,7 @@ class LiveKitTransportClient:
def _on_data_received_wrapper(self, data: rtc.DataPacket):
"""Wrapper for data received events."""
assert self._task_manager is not None
self._task_manager.create_task(
self._async_on_data_received(data),
f"{self}::_async_on_data_received",
@@ -496,10 +501,12 @@ class LiveKitTransportClient:
def _on_connected_wrapper(self):
"""Wrapper for connected events."""
assert self._task_manager is not None
self._task_manager.create_task(self._async_on_connected(), f"{self}::_async_on_connected")
def _on_disconnected_wrapper(self):
"""Wrapper for disconnected events."""
assert self._task_manager is not None
self._task_manager.create_task(
self._async_on_disconnected(), f"{self}::_async_on_disconnected"
)
@@ -531,6 +538,7 @@ class LiveKitTransportClient:
logger.info(f"Audio track subscribed: {track.sid} from participant {participant.sid}")
self._audio_tracks[participant.sid] = track
audio_stream = rtc.AudioStream(track)
assert self._task_manager is not None
self._task_manager.create_task(
self._process_audio_stream(audio_stream, participant.sid),
f"{self}::_process_audio_stream",
@@ -543,6 +551,7 @@ class LiveKitTransportClient:
# unbounded queue growth when there is no consumer for video frames.
if self._params.video_in_enabled:
video_stream = rtc.VideoStream(track)
assert self._task_manager is not None
self._task_manager.create_task(
self._process_video_stream(video_stream, participant.sid),
f"{self}::_process_video_stream",
@@ -564,7 +573,9 @@ class LiveKitTransportClient:
async def _async_on_data_received(self, data: rtc.DataPacket):
"""Handle data received events."""
await self._callbacks.on_data_received(data.data, data.participant.sid)
participant = data.participant
assert participant is not None
await self._callbacks.on_data_received(data.data, participant.sid)
async def _async_on_connected(self):
"""Handle connected events."""
@@ -796,7 +807,7 @@ class LiveKitInputTransport(BaseInputTransport):
"""Convert LiveKit video frame to Pipecat video frame."""
rgb_frame = video_frame_event.frame.convert(proto_video_frame.VideoBufferType.RGB24)
image_frame = ImageRawFrame(
image=rgb_frame.data,
image=bytes(rgb_frame.data),
size=(rgb_frame.width, rgb_frame.height),
format="RGB",
)
@@ -1119,7 +1130,7 @@ class LiveKitTransport(BaseTransport):
await self._call_event_handler("on_audio_track_subscribed", participant_id)
participant = self._client.room.remote_participants.get(participant_id)
if participant:
for publication in participant.audio_tracks.values():
for publication in participant.audio_tracks.values(): # type: ignore[attr-defined]
self._client._on_track_subscribed_wrapper(
publication.track, publication, participant
)
@@ -1133,7 +1144,7 @@ class LiveKitTransport(BaseTransport):
await self._call_event_handler("on_video_track_subscribed", participant_id)
participant = self._client.room.remote_participants.get(participant_id)
if participant:
for publication in participant.video_tracks.values():
for publication in participant.video_tracks.values(): # type: ignore[attr-defined]
self._client._on_track_subscribed_wrapper(
publication.track, publication, participant
)

View File

@@ -229,7 +229,7 @@ class TkOutputTransport(BaseOutputTransport):
# This holds a reference to the photo, preventing it from being garbage
# collected.
self._image_label.image = photo
self._image_label.image = photo # type: ignore[assignment]
class TkLocalTransport(BaseTransport):

View File

@@ -15,7 +15,7 @@ import asyncio
import json
import time
import uuid
from typing import Any, List, Literal, Optional, Union
from typing import Any, List, Literal, Optional, Union, cast
from loguru import logger
from pydantic import BaseModel, TypeAdapter
@@ -224,9 +224,9 @@ class SmallWebRTCConnection(BaseObject):
if not ice_servers:
self.ice_servers: List[IceServer] = []
elif all(isinstance(s, IceServer) for s in ice_servers):
self.ice_servers = ice_servers
self.ice_servers = cast(List[IceServer], ice_servers)
elif all(isinstance(s, str) for s in ice_servers):
self.ice_servers = [IceServer(urls=s) for s in ice_servers]
self.ice_servers = [IceServer(urls=cast(str, s)) for s in ice_servers]
else:
raise TypeError("ice_servers must be either List[str] or List[RTCIceServer]")
self._connect_invoked = False
@@ -384,10 +384,10 @@ class SmallWebRTCConnection(BaseObject):
# and aiortc does not handle that pretty well.
video_input_track = self.video_input_track()
if video_input_track:
await self.video_input_track().discard_old_frames()
await video_input_track.discard_old_frames()
screen_video_input_track = self.screen_video_input_track()
if screen_video_input_track:
await self.screen_video_input_track().discard_old_frames()
await screen_video_input_track.discard_old_frames()
if video_input_track or screen_video_input_track:
# This prevents an issue where sometimes the WebRTC connection can be established
# before the bot is ready to receive video. When that happens, we can lose a couple

View File

@@ -223,6 +223,7 @@ class SmallWebRTCRequestHandler:
)
answer = pipecat_connection.get_answer()
assert answer is not None
if self._esp32_mode:
from pipecat.runner.utils import smallwebrtc_sdp_munging

View File

@@ -307,7 +307,7 @@ class SmallWebRTCClient:
if (
self._webrtc_connection.is_connected()
and video_track
and video_track.is_enabled()
and video_track.is_enabled() # type: ignore[attr-defined]
):
logger.warning("Timeout: No video frame received within the specified time.")
# self._webrtc_connection.ask_to_renegotiate()
@@ -362,7 +362,7 @@ class SmallWebRTCClient:
if (
self._webrtc_connection.is_connected()
and self._audio_input_track
and self._audio_input_track.is_enabled()
and self._audio_input_track.is_enabled() # type: ignore[attr-defined]
):
logger.warning("Timeout: No audio frame received within the specified time.")
frame = None
@@ -375,7 +375,7 @@ class SmallWebRTCClient:
await asyncio.sleep(0.01)
continue
if frame.sample_rate > self._in_sample_rate:
if self._in_sample_rate is not None and frame.sample_rate > self._in_sample_rate:
resampled_frames = self._pipecat_resampler.resample(frame)
for resampled_frame in resampled_frames:
# 16-bit PCM bytes
@@ -383,6 +383,7 @@ class SmallWebRTCClient:
pcm_bytes = pcm_array.tobytes()
del pcm_array # free NumPy array immediately
assert self._audio_in_channels is not None
audio_frame = InputAudioRawFrame(
audio=pcm_bytes,
sample_rate=resampled_frame.sample_rate,
@@ -398,6 +399,7 @@ class SmallWebRTCClient:
pcm_bytes = pcm_array.tobytes()
del pcm_array # free NumPy array immediately
assert self._audio_in_channels is not None
audio_frame = InputAudioRawFrame(
audio=pcm_bytes,
sample_rate=frame.sample_rate,
@@ -489,9 +491,9 @@ class SmallWebRTCClient:
if not self._params:
return
self._audio_input_track = self._webrtc_connection.audio_input_track()
self._video_input_track = self._webrtc_connection.video_input_track()
self._screen_video_track = self._webrtc_connection.screen_video_input_track()
self._audio_input_track = self._webrtc_connection.audio_input_track() # type: ignore[assignment]
self._video_input_track = self._webrtc_connection.video_input_track() # type: ignore[assignment]
self._screen_video_track = self._webrtc_connection.screen_video_input_track() # type: ignore[assignment]
if self._params.audio_out_enabled:
self._audio_output_track = RawAudioTrack(sample_rate=self._out_sample_rate)
self._webrtc_connection.replace_audio_track(self._audio_output_track)

View File

@@ -257,12 +257,14 @@ class TavusTransportClient:
await self._client.setup(setup)
except Exception as e:
logger.error(f"Failed to setup TavusTransportClient: {e}")
assert self._conversation_id is not None
await self._api.end_conversation(self._conversation_id)
self._conversation_id = None
async def cleanup(self):
"""Cleanup client resources."""
try:
assert self._client is not None
await self._client.cleanup()
except Exception as e:
logger.error(f"Exception during cleanup: {e}")
@@ -294,12 +296,15 @@ class TavusTransportClient:
frame: The start frame containing initialization parameters.
"""
logger.debug("TavusTransportClient start invoked!")
assert self._client is not None
await self._client.start(frame)
await self._client.join()
async def stop(self):
"""Stop the client and end the conversation."""
assert self._client is not None
await self._client.leave()
assert self._conversation_id is not None
await self._api.end_conversation(self._conversation_id)
self._conversation_id = None
@@ -320,6 +325,7 @@ class TavusTransportClient:
video_source: Video source to capture from.
color_format: Color format for video frames.
"""
assert self._client is not None
await self._client.capture_participant_video(
participant_id, callback, framerate, video_source, color_format
)
@@ -341,6 +347,7 @@ class TavusTransportClient:
sample_rate: Desired sample rate for audio capture.
callback_interval_ms: Interval between audio callbacks in milliseconds.
"""
assert self._client is not None
await self._client.capture_participant_audio(
participant_id, callback, audio_source, sample_rate, callback_interval_ms
)
@@ -353,6 +360,7 @@ class TavusTransportClient:
Args:
frame: The message frame to send.
"""
assert self._client is not None
await self._client.send_message(frame)
@property
@@ -362,6 +370,7 @@ class TavusTransportClient:
Returns:
The output sample rate in Hz.
"""
assert self._client is not None
return self._client.out_sample_rate
@property
@@ -371,6 +380,7 @@ class TavusTransportClient:
Returns:
The input sample rate in Hz.
"""
assert self._client is not None
return self._client.in_sample_rate
async def send_interrupt_message(self) -> None:

View File

@@ -132,7 +132,7 @@ class WebsocketClientSession:
return
try:
self._websocket = await websocket_connect(
self._websocket = await websocket_connect( # type: ignore[assignment]
uri=self._uri,
open_timeout=10,
additional_headers=self._params.additional_headers,
@@ -141,7 +141,8 @@ class WebsocketClientSession:
self._client_task_handler(),
f"{self._transport_name}::WebsocketClientSession::_client_task_handler",
)
await self._callbacks.on_connected(self._websocket)
if self._websocket:
await self._callbacks.on_connected(self._websocket)
except TimeoutError:
logger.error(f"Timeout connecting to {self._uri}")
@@ -179,7 +180,7 @@ class WebsocketClientSession:
Returns:
True if the WebSocket is in connected state.
"""
return self._websocket.state == websockets.State.OPEN if self._websocket else False
return self._websocket.state == websockets.State.OPEN if self._websocket else False # type: ignore[attr-defined]
@property
def is_closing(self) -> bool:
@@ -188,18 +189,22 @@ class WebsocketClientSession:
Returns:
True if the WebSocket is in the process of closing.
"""
return self._websocket.state == websockets.State.CLOSING if self._websocket else False
return self._websocket.state == websockets.State.CLOSING if self._websocket else False # type: ignore[attr-defined]
async def _client_task_handler(self):
"""Handle incoming messages from the WebSocket connection."""
if not self._websocket:
return
websocket = self._websocket
try:
# Handle incoming messages
async for message in self._websocket:
await self._callbacks.on_message(self._websocket, message)
async for message in websocket:
await self._callbacks.on_message(websocket, message)
except Exception as e:
logger.error(f"{self} exception receiving data: {e.__class__.__name__} ({e})")
await self._callbacks.on_disconnected(self._websocket)
await self._callbacks.on_disconnected(websocket)
def __str__(self):
"""String representation of the WebSocket client session."""

View File

@@ -326,6 +326,7 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
async def _monitor_websocket(self):
"""Wait for self._params.session_timeout seconds, if the websocket is still open, trigger timeout event."""
assert self._params.session_timeout is not None
await asyncio.sleep(self._params.session_timeout)
await self._client.trigger_client_timeout()

View File

@@ -181,7 +181,7 @@ class WebsocketServerInputTransport(BaseInputTransport):
async def _server_task_handler(self):
"""Handle WebSocket server startup and client connections."""
logger.info(f"Starting websocket server on {self._host}:{self._port}")
async with websocket_serve(self._client_handler, self._host, self._port) as server:
async with websocket_serve(self._client_handler, self._host, self._port) as server: # type: ignore[arg-type]
await self._callbacks.on_websocket_ready()
await self._stop_server_event.wait()

View File

@@ -157,6 +157,7 @@ class WhatsAppClient:
async def _validate_whatsapp_webhook_request(self, raw_body: bytes, sha256_signature: str):
"""Common handler for both /start and /connect endpoints."""
# Compute HMAC SHA256 using your App Secret
assert self._whatsapp_secret is not None
expected_signature = hmac.new(
key=self._whatsapp_secret.encode("utf-8"),
msg=raw_body,
@@ -205,6 +206,7 @@ class WhatsAppClient:
"""
try:
if self._whatsapp_secret:
assert raw_body is not None and sha256_signature is not None
await self._validate_whatsapp_webhook_request(raw_body, sha256_signature)
for entry in request.entry:
for change in entry.changes:
@@ -306,7 +308,10 @@ class WhatsAppClient:
# Create and initialize WebRTC connection
pipecat_connection = SmallWebRTCConnection(self._ice_servers)
await pipecat_connection.initialize(sdp=call.session.sdp, type=call.session.sdp_type)
sdp_answer = pipecat_connection.get_answer().get("sdp")
answer = pipecat_connection.get_answer()
assert answer is not None
sdp_answer = answer.get("sdp")
assert isinstance(sdp_answer, str)
sdp_answer = self._filter_sdp_for_whatsapp(sdp_answer)
logger.debug(f"SDP answer generated for call {call.id}")

View File

@@ -14,7 +14,7 @@ were interrupted mid-thought.
import asyncio
from dataclasses import dataclass
from typing import Literal, Optional
from typing import TYPE_CHECKING, Literal, Optional
from loguru import logger
@@ -28,6 +28,13 @@ from pipecat.frames.frames import (
)
from pipecat.processors.frame_processor import FrameDirection
if TYPE_CHECKING:
from pipecat.processors.frame_processor import FrameProcessor
_MixinBase = FrameProcessor
else:
_MixinBase = object
# Turn completion markers
USER_TURN_COMPLETE_MARKER = ""
USER_TURN_INCOMPLETE_SHORT_MARKER = "" # Short wait - user likely continues soon
@@ -178,7 +185,7 @@ class UserTurnCompletionConfig:
return self.incomplete_long_prompt or DEFAULT_INCOMPLETE_LONG_PROMPT
class UserTurnCompletionLLMServiceMixin:
class UserTurnCompletionLLMServiceMixin(_MixinBase):
"""Mixin that adds turn completion detection to LLM services.
This mixin provides methods to push LLM text with turn completion detection.

View File

@@ -15,7 +15,7 @@ import asyncio
import traceback
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Coroutine, Dict, Optional, Sequence
from typing import Awaitable, Coroutine, Dict, Optional, Sequence
from loguru import logger
@@ -56,13 +56,13 @@ class BaseTaskManager(ABC):
pass
@abstractmethod
def create_task(self, coroutine: Coroutine, name: str) -> asyncio.Task:
def create_task(self, coroutine: Coroutine | Awaitable, name: str) -> asyncio.Task:
"""Creates and schedules a new asyncio Task that runs the given coroutine.
The task is added to a global set of created tasks.
Args:
coroutine: The coroutine to be executed within the task.
coroutine: The coroutine or awaitable to be executed within the task.
name: The name to assign to the task for identification.
Returns:
@@ -139,13 +139,13 @@ class TaskManager(BaseTaskManager):
raise Exception("TaskManager is not setup: unable to get event loop")
return self._params.loop
def create_task(self, coroutine: Coroutine, name: str) -> asyncio.Task:
def create_task(self, coroutine: Coroutine | Awaitable, name: str) -> asyncio.Task:
"""Creates and schedules a new asyncio Task that runs the given coroutine.
The task is added to a global set of created tasks.
Args:
coroutine: The coroutine to be executed within the task.
coroutine: The coroutine or awaitable to be executed within the task.
name: The name to assign to the task for identification.
Returns:

View File

@@ -234,7 +234,7 @@ class TextPartForConcatenation:
includes_inter_part_spaces: bool
def __str__(self):
return f"{self.name}(text: [{self.text}], includes_inter_part_spaces: {self.includes_inter_part_spaces})"
return f"{type(self).__name__}(text: [{self.text}], includes_inter_part_spaces: {self.includes_inter_part_spaces})"
def concatenate_aggregated_text(text_parts: List[TextPartForConcatenation]) -> str:

View File

@@ -103,9 +103,9 @@ class BaseTextAggregator(ABC):
a string indicating the type of aggregation (e.g., 'sentence', 'word',
'token', 'my_custom_aggregation').
"""
pass
# Make this a generator to satisfy type checker
yield # pragma: no cover
yield Aggregation("", "") # pragma: no cover
return # pragma: no cover
@abstractmethod
async def flush(self) -> Optional[Aggregation]:

View File

@@ -310,7 +310,7 @@ class PatternPairAggregator(SimpleTextAggregator):
# Which is why we base the return on the first found.
if start_count > end_count:
start_index = text.find(start)
return [start_index, pattern_info]
return (start_index, pattern_info)
return None

View File

@@ -241,4 +241,4 @@ def traceable(cls: C) -> C:
else:
Traceable.__init__(self, cls.__name__)
return TracedClass
return TracedClass # type: ignore[return-value]

View File

@@ -439,7 +439,7 @@ def add_openai_realtime_span_attributes(
if isinstance(tool, dict) and "name" in tool:
tool_names.append(tool["name"])
elif hasattr(tool, "name"):
tool_names.append(tool.name)
tool_names.append(getattr(tool, "name"))
elif isinstance(tool, dict) and "function" in tool and "name" in tool["function"]:
tool_names.append(tool["function"]["name"])
@@ -454,7 +454,7 @@ def add_openai_realtime_span_attributes(
if function_calls:
call = function_calls[0]
if hasattr(call, "name"):
span.set_attribute("function_calls.first_name", call.name)
span.set_attribute("function_calls.first_name", getattr(call, "name"))
elif isinstance(call, dict) and "name" in call:
span.set_attribute("function_calls.first_name", call["name"])

View File

@@ -467,9 +467,9 @@ def traced_llm(func: Optional[Callable] = None, *, name: Optional[str] = None) -
# Handle system message for different services
system_message = None
if hasattr(context, "system"):
system_message = context.system
system_message = getattr(context, "system")
elif hasattr(context, "system_message"):
system_message = context.system_message
system_message = getattr(context, "system_message")
elif hasattr(self, "_system_instruction"):
system_message = self._system_instruction