Compare commits
3 Commits
pk/optiona
...
mb/static-
| Author | SHA1 | Date |
|---|---|---|
| | 5a6cc4d35c | |
| | 28be775740 | |
| | bc730e4069 | |
1
changelog/3678.added.md
Normal file
1
changelog/3678.added.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added pyright basic type checking configuration for the core framework.
|
||||
@@ -212,6 +212,14 @@ ignore = [
|
||||
[tool.ruff.lint.pydocstyle]
|
||||
convention = "google"
|
||||
|
||||
[tool.pyright]
|
||||
typeCheckingMode = "basic"
|
||||
pythonVersion = "3.10"
|
||||
exclude = [
|
||||
"src/pipecat/services/",
|
||||
"src/pipecat/adapters/",
|
||||
]
|
||||
|
||||
[tool.coverage.run]
|
||||
command_line = "--module pytest"
|
||||
source = ["src"]
|
||||
|
||||
@@ -31,7 +31,7 @@ def version() -> str:
|
||||
import asyncio
|
||||
|
||||
if sys.version_info < (3, 12):
|
||||
import wait_for2
|
||||
import wait_for2 # type: ignore[import-untyped]
|
||||
|
||||
# Replace asyncio.wait_for.
|
||||
asyncio.wait_for = wait_for2.wait_for
|
||||
|
||||
@@ -55,7 +55,7 @@ async def load_dtmf_audio(button: KeypadEntry, *, sample_rate: int = 8000) -> by
|
||||
dtmf_file_name = __DTMF_FILE_NAME.get(button, f"dtmf-{button.value}.wav")
|
||||
dtmf_file_path = files("pipecat.audio.dtmf").joinpath(dtmf_file_name)
|
||||
|
||||
async with aiofiles.open(dtmf_file_path, "rb") as f:
|
||||
async with aiofiles.open(str(dtmf_file_path), "rb") as f:
|
||||
data = await f.read()
|
||||
|
||||
with io.BytesIO(data) as buffer:
|
||||
|
||||
@@ -18,16 +18,22 @@ from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
from aic_sdk import (
|
||||
Model,
|
||||
ParameterFixedError,
|
||||
ProcessorAsync,
|
||||
ProcessorConfig,
|
||||
ProcessorParameter,
|
||||
set_sdk_id,
|
||||
)
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
from aic_sdk import ( # type: ignore[import-not-found,import-untyped]
|
||||
Model,
|
||||
ParameterFixedError,
|
||||
ProcessorAsync,
|
||||
ProcessorConfig,
|
||||
ProcessorParameter,
|
||||
set_sdk_id, # pyright: ignore[reportAttributeAccessIssue]
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use AICFilter, you need to install aic_sdk.")
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
from pipecat.audio.filters.base_audio_filter import BaseAudioFilter
|
||||
from pipecat.audio.vad.aic_vad import AICVADAnalyzer
|
||||
from pipecat.frames.frames import FilterControlFrame, FilterEnableFrame
|
||||
@@ -167,6 +173,7 @@ class AICFilter(BaseAudioFilter):
|
||||
logger.debug(f"Loading AIC model from: {self._model_path}")
|
||||
self._model = Model.from_file(str(self._model_path))
|
||||
else:
|
||||
assert self._model_id is not None
|
||||
logger.debug(f"Downloading AIC model: {self._model_id}")
|
||||
self._model_download_dir.mkdir(parents=True, exist_ok=True)
|
||||
model_path = await Model.download_async(self._model_id, str(self._model_download_dir))
|
||||
@@ -200,6 +207,7 @@ class AICFilter(BaseAudioFilter):
|
||||
return
|
||||
|
||||
# Get contexts for parameter control and VAD
|
||||
assert self._processor is not None
|
||||
self._processor_ctx = self._processor.get_processor_context()
|
||||
self._vad_ctx = self._processor.get_vad_context()
|
||||
|
||||
@@ -287,6 +295,9 @@ class AICFilter(BaseAudioFilter):
|
||||
mv = memoryview(self._audio_buffer)
|
||||
block_size = self._frames_per_block * self._bytes_per_sample
|
||||
|
||||
assert self._in_f32 is not None
|
||||
assert self._out_i16 is not None
|
||||
|
||||
for i in range(num_blocks):
|
||||
start = i * block_size
|
||||
block_i16 = np.frombuffer(mv[start : start + block_size], dtype=self._dtype)
|
||||
|
||||
@@ -11,6 +11,7 @@ reduction technology to suppress background noise in audio streams.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
@@ -19,7 +20,9 @@ from pipecat.audio.filters.base_audio_filter import BaseAudioFilter
|
||||
from pipecat.frames.frames import FilterControlFrame, FilterEnableFrame
|
||||
|
||||
try:
|
||||
from pipecat_ai_krisp.audio.krisp_processor import KrispAudioProcessor
|
||||
from pipecat_ai_krisp.audio.krisp_processor import ( # type: ignore[import-not-found]
|
||||
KrispAudioProcessor,
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use the Krisp filter, you need to `pip install pipecat-ai[krisp]`.")
|
||||
@@ -68,7 +71,7 @@ class KrispFilter(BaseAudioFilter):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, sample_type: str = "PCM_16", channels: int = 1, model_path: str = None
|
||||
self, sample_type: str = "PCM_16", channels: int = 1, model_path: Optional[str] = None
|
||||
) -> None:
|
||||
"""Initialize the Krisp noise reduction filter.
|
||||
|
||||
@@ -115,6 +118,7 @@ class KrispFilter(BaseAudioFilter):
|
||||
sample_rate: The sample rate of the input transport in Hz.
|
||||
"""
|
||||
self._sample_rate = sample_rate
|
||||
assert self._model_path is not None
|
||||
self._krisp_processor = KrispProcessorManager.get_processor(
|
||||
self._sample_rate, self._sample_type, self._channels, self._model_path
|
||||
)
|
||||
@@ -154,6 +158,7 @@ class KrispFilter(BaseAudioFilter):
|
||||
data = data.astype(np.float32) + epsilon
|
||||
|
||||
# Process the audio chunk to reduce noise
|
||||
assert self._krisp_processor is not None
|
||||
reduced_noise = self._krisp_processor.process(data)
|
||||
|
||||
# Clip and set processed audio back to frame
|
||||
|
||||
@@ -10,6 +10,7 @@ This module provides an audio filter implementation using Krisp VIVA SDK.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
@@ -23,7 +24,7 @@ from pipecat.audio.krisp_instance import (
|
||||
from pipecat.frames.frames import FilterControlFrame, FilterEnableFrame
|
||||
|
||||
try:
|
||||
import krisp_audio
|
||||
import krisp_audio # type: ignore[import-not-found]
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use KrispVivaFilter, you need to install krisp_audio.")
|
||||
@@ -39,7 +40,10 @@ class KrispVivaFilter(BaseAudioFilter):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, model_path: str = None, frame_duration: int = 10, noise_suppression_level: int = 100
|
||||
self,
|
||||
model_path: Optional[str] = None,
|
||||
frame_duration: int = 10,
|
||||
noise_suppression_level: int = 100,
|
||||
) -> None:
|
||||
"""Initialize the Krisp noise reduction filter.
|
||||
|
||||
@@ -171,6 +175,9 @@ class KrispVivaFilter(BaseAudioFilter):
|
||||
return audio
|
||||
|
||||
try:
|
||||
if self._samples_per_frame is None or self._session is None:
|
||||
return audio
|
||||
|
||||
# Add incoming audio to our buffer
|
||||
self._audio_buffer.extend(audio)
|
||||
|
||||
|
||||
@@ -61,6 +61,8 @@ class RNNoiseFilter(BaseAudioFilter):
|
||||
|
||||
try:
|
||||
# RNNoise always requires 48kHz
|
||||
if RNNoise is None:
|
||||
raise RuntimeError("RNNoise module is not available")
|
||||
self._rnnoise = RNNoise(sample_rate=48000)
|
||||
self._rnnoise_ready = True
|
||||
except Exception as e:
|
||||
@@ -126,6 +128,7 @@ class RNNoiseFilter(BaseAudioFilter):
|
||||
# denoise_chunk handles buffering internally and yields (speech_prob, denoised_frame)
|
||||
# denoised_frame is in float32 format normalized to [-1.0, 1.0]
|
||||
filtered_frames = []
|
||||
assert self._rnnoise is not None
|
||||
for speech_prob, denoised_frame in self._rnnoise.denoise_chunk(audio_samples):
|
||||
# Check if output is float (needs scaling) or int16 (ready)
|
||||
if np.issubdtype(denoised_frame.dtype, np.floating):
|
||||
|
||||
@@ -12,7 +12,7 @@ from threading import Lock
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
import krisp_audio
|
||||
import krisp_audio # type: ignore[import-not-found]
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use the Krisp instance, you need to install krisp_audio.")
|
||||
|
||||
@@ -96,6 +96,7 @@ class SOXRStreamAudioResampler(BaseAudioResampler):
|
||||
|
||||
self._maybe_initialize_sox_stream(in_rate, out_rate)
|
||||
audio_data = np.frombuffer(audio, dtype=np.int16)
|
||||
assert self._soxr_stream is not None
|
||||
resampled_audio = self._soxr_stream.resample_chunk(audio_data)
|
||||
result = resampled_audio.astype(np.int16).tobytes()
|
||||
return result
|
||||
|
||||
@@ -29,7 +29,7 @@ from pipecat.audio.turn.base_turn_analyzer import BaseTurnAnalyzer, BaseTurnPara
|
||||
from pipecat.metrics.metrics import MetricsData
|
||||
|
||||
try:
|
||||
import krisp_audio
|
||||
import krisp_audio # type: ignore[import-not-found]
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use KrispVivaTurn, you need to install krisp_audio.")
|
||||
|
||||
@@ -74,7 +74,7 @@ class LocalSmartTurnAnalyzer(BaseSmartTurn):
|
||||
# Set device to GPU if available, else CPU
|
||||
self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
# Move model to selected device and set it to evaluation mode
|
||||
self._turn_model = self._turn_model.to(self._device)
|
||||
self._turn_model = self._turn_model.to(self._device) # type: ignore[assignment]
|
||||
self._turn_model.eval()
|
||||
logger.debug("Loaded Local Smart Turn")
|
||||
|
||||
|
||||
@@ -69,7 +69,7 @@ class LocalSmartTurnAnalyzerV2(BaseSmartTurn):
|
||||
elif torch.cuda.is_available():
|
||||
self._device = "cuda"
|
||||
# Move model to selected device and set it to evaluation mode
|
||||
self._turn_model = self._turn_model.to(self._device)
|
||||
self._turn_model = self._turn_model.to(self._device) # type: ignore[assignment]
|
||||
self._turn_model.eval()
|
||||
logger.debug("Loaded Local Smart Turn v2")
|
||||
|
||||
@@ -77,12 +77,12 @@ class LocalSmartTurnAnalyzerV2(BaseSmartTurn):
|
||||
"""Predict end-of-turn using local PyTorch model."""
|
||||
inputs = self._turn_processor(
|
||||
audio_array,
|
||||
sampling_rate=16000,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=16000 * 16, # 16 seconds at 16kHz
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
sampling_rate=16000, # type: ignore[call-arg]
|
||||
padding="max_length", # type: ignore[call-arg]
|
||||
truncation=True, # type: ignore[call-arg]
|
||||
max_length=16000 * 16, # type: ignore[call-arg]
|
||||
return_attention_mask=True, # type: ignore[call-arg]
|
||||
return_tensors="pt", # type: ignore[call-arg]
|
||||
)
|
||||
|
||||
# Move inputs to device
|
||||
|
||||
@@ -57,7 +57,7 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
package_path = "pipecat.audio.turn.smart_turn.data"
|
||||
|
||||
try:
|
||||
import importlib_resources as impresources
|
||||
import importlib_resources as impresources # type: ignore[import-not-found]
|
||||
|
||||
smart_turn_model_path = str(impresources.files(package_path).joinpath(model_name))
|
||||
except BaseException:
|
||||
@@ -65,7 +65,7 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
|
||||
try:
|
||||
with impresources.path(package_path, model_name) as f:
|
||||
smart_turn_model_path = f
|
||||
smart_turn_model_path = str(f)
|
||||
except BaseException:
|
||||
smart_turn_model_path = str(
|
||||
impresources.files(package_path).joinpath(model_name)
|
||||
@@ -80,6 +80,7 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
|
||||
self._feature_extractor = WhisperFeatureExtractor(chunk_length=8)
|
||||
assert smart_turn_model_path is not None
|
||||
self._session = ort.InferenceSession(smart_turn_model_path, sess_options=so)
|
||||
|
||||
logger.debug("Loaded Local Smart Turn v3.x")
|
||||
@@ -165,7 +166,7 @@ class LocalSmartTurnAnalyzerV3(BaseSmartTurn):
|
||||
outputs = self._session.run(None, {"input_features": input_features})
|
||||
|
||||
# Extract probability (ONNX model returns sigmoid probabilities)
|
||||
probability = outputs[0][0].item()
|
||||
probability = outputs[0][0].item() # type: ignore[index]
|
||||
|
||||
# Make prediction (1 for Complete, 0 for Incomplete)
|
||||
prediction = 1 if probability > 0.5 else 0
|
||||
|
||||
@@ -65,7 +65,7 @@ class SileroOnnxModel:
|
||||
if np.ndim(x) == 1:
|
||||
x = np.expand_dims(x, 0)
|
||||
if np.ndim(x) > 2:
|
||||
raise ValueError(f"Too many dimensions for input audio chunk {x.dim()}")
|
||||
raise ValueError(f"Too many dimensions for input audio chunk {np.ndim(x)}")
|
||||
|
||||
if sr not in self.sample_rates:
|
||||
raise ValueError(
|
||||
@@ -150,7 +150,7 @@ class SileroVADAnalyzer(VADAnalyzer):
|
||||
package_path = "pipecat.audio.vad.data"
|
||||
|
||||
try:
|
||||
import importlib_resources as impresources
|
||||
import importlib_resources as impresources # type: ignore[import-not-found]
|
||||
|
||||
model_file_path = str(impresources.files(package_path).joinpath(model_name))
|
||||
except BaseException:
|
||||
@@ -158,7 +158,7 @@ class SileroVADAnalyzer(VADAnalyzer):
|
||||
|
||||
try:
|
||||
with impresources.path(package_path, model_name) as f:
|
||||
model_file_path = f
|
||||
model_file_path = str(f)
|
||||
except BaseException:
|
||||
model_file_path = str(impresources.files(package_path).joinpath(model_name))
|
||||
|
||||
@@ -209,7 +209,7 @@ class SileroVADAnalyzer(VADAnalyzer):
|
||||
audio_int16 = np.frombuffer(buffer, np.int16)
|
||||
# Divide by 32768 because we have signed 16-bit data.
|
||||
audio_float32 = np.frombuffer(audio_int16, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
new_confidence = self._model(audio_float32, self.sample_rate)[0]
|
||||
new_confidence = self._model(audio_float32, self.sample_rate)[0] # type: ignore[index]
|
||||
|
||||
# We need to reset the model from time to time because it doesn't
|
||||
# really need all the data and memory will keep growing otherwise.
|
||||
|
||||
@@ -99,7 +99,7 @@ class IVRProcessor(FrameProcessor):
|
||||
self._register_event_handler("on_conversation_detected")
|
||||
self._register_event_handler("on_ivr_status_changed")
|
||||
|
||||
def update_saved_messages(self, messages: List[dict]) -> None:
|
||||
def update_saved_messages(self, messages: List) -> None:
|
||||
"""Update the saved context messages.
|
||||
|
||||
Sets the messages that are saved when switching between
|
||||
|
||||
@@ -36,7 +36,7 @@ from pipecat.frames.frames import (
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.pipeline.parallel_pipeline import ParallelPipeline
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext, LLMContextMessage
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
@@ -618,7 +618,7 @@ VOICEMAIL SYSTEM (respond "VOICEMAIL"):
|
||||
self._validate_prompt(custom_system_prompt)
|
||||
|
||||
# Set up the LLM context with the classification prompt
|
||||
self._messages = [
|
||||
self._messages: List[LLMContextMessage] = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": self._prompt,
|
||||
|
||||
@@ -44,7 +44,7 @@ class UserBotLatencyObserver(BaseObserver):
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._user_stopped_time: Optional[float] = None
|
||||
self._processed_frames: Set[str] = set()
|
||||
self._processed_frames: Set[int] = set()
|
||||
|
||||
self._register_event_handler("on_latency_measured")
|
||||
|
||||
|
||||
@@ -6,11 +6,12 @@
|
||||
|
||||
"""LLM switcher for switching between different LLMs at runtime, with different switching strategies."""
|
||||
|
||||
from typing import Any, List, Optional, Type
|
||||
from typing import Any, List, Optional, Type, cast
|
||||
|
||||
from pipecat.adapters.schemas.direct_function import DirectFunction
|
||||
from pipecat.pipeline.service_switcher import ServiceSwitcher, StrategyType
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
from pipecat.services.llm_service import LLMService
|
||||
|
||||
|
||||
@@ -32,7 +33,7 @@ class LLMSwitcher(ServiceSwitcher[StrategyType]):
|
||||
llms: List of LLM services to switch between.
|
||||
strategy_type: The strategy class to use for switching between LLMs.
|
||||
"""
|
||||
super().__init__(llms, strategy_type)
|
||||
super().__init__(cast(List[FrameProcessor], llms), strategy_type)
|
||||
|
||||
@property
|
||||
def llms(self) -> List[LLMService]:
|
||||
@@ -41,7 +42,7 @@ class LLMSwitcher(ServiceSwitcher[StrategyType]):
|
||||
Returns:
|
||||
List of LLM services managed by this switcher.
|
||||
"""
|
||||
return self.services
|
||||
return cast(List[LLMService], self.services)
|
||||
|
||||
@property
|
||||
def active_llm(self) -> Optional[LLMService]:
|
||||
@@ -50,7 +51,7 @@ class LLMSwitcher(ServiceSwitcher[StrategyType]):
|
||||
Returns:
|
||||
The currently active LLM service, or None if no LLM is active.
|
||||
"""
|
||||
return self.strategy.active_service
|
||||
return cast(Optional[LLMService], self.strategy.active_service)
|
||||
|
||||
async def run_inference(self, context: LLMContext) -> Optional[str]:
|
||||
"""Run a one-shot, out-of-band (i.e. out-of-pipeline) inference with the given LLM context, using the currently active LLM.
|
||||
|
||||
@@ -128,7 +128,7 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
|
||||
def __init__(
|
||||
self,
|
||||
wrapped_service: FrameProcessor,
|
||||
active_service: FrameProcessor,
|
||||
active_service: Optional[FrameProcessor],
|
||||
direction: FrameDirection,
|
||||
):
|
||||
"""Initialize the service switcher filter with a strategy and direction.
|
||||
@@ -163,7 +163,7 @@ class ServiceSwitcher(ParallelPipeline, Generic[StrategyType]):
|
||||
class ServiceSwitcherFilterFrame(ControlFrame):
|
||||
"""An internal frame used by ServiceSwitcher to filter frames based on active service."""
|
||||
|
||||
active_service: FrameProcessor
|
||||
active_service: Optional[FrameProcessor]
|
||||
|
||||
@staticmethod
|
||||
def _make_pipeline_definitions(
|
||||
|
||||
@@ -383,7 +383,7 @@ class PipelineTask(BasePipelineTask):
|
||||
# allows us to receive and react to downstream frames.
|
||||
source = PipelineSource(self._source_push_frame, name=f"{self}::Source")
|
||||
sink = PipelineSink(self._sink_push_frame, name=f"{self}::Sink")
|
||||
processors = [self._rtvi, pipeline] if self._rtvi else [pipeline]
|
||||
processors: List[FrameProcessor] = [self._rtvi, pipeline] if self._rtvi else [pipeline]
|
||||
self._pipeline = Pipeline(processors, source=source, sink=sink)
|
||||
|
||||
# The task observer acts as a proxy to the provided observers. This way,
|
||||
@@ -786,7 +786,7 @@ class PipelineTask(BasePipelineTask):
|
||||
await self._observer.cleanup()
|
||||
|
||||
# End conversation tracing if it's active - this will also close any active turn span
|
||||
if self._enable_tracing and hasattr(self, "_turn_trace_observer"):
|
||||
if self._enable_tracing and self._turn_trace_observer is not None:
|
||||
self._turn_trace_observer.end_conversation_tracing()
|
||||
|
||||
# Cleanup pipeline processors.
|
||||
@@ -1020,7 +1020,7 @@ class PipelineTask(BasePipelineTask):
|
||||
path = Path(f).resolve()
|
||||
module_name = path.stem
|
||||
spec = importlib.util.spec_from_file_location(module_name, str(path))
|
||||
if spec:
|
||||
if spec and spec.loader:
|
||||
logger.debug(f"{self} loading observers from {path}")
|
||||
|
||||
# Load module.
|
||||
|
||||
@@ -164,6 +164,8 @@ class TaskObserver(BaseObserver):
|
||||
return proxies
|
||||
|
||||
async def _send_to_proxy(self, data: Any):
|
||||
if not self._proxies:
|
||||
return
|
||||
for proxy in self._proxies.values():
|
||||
await proxy.queue.put(data)
|
||||
|
||||
@@ -188,7 +190,9 @@ class TaskObserver(BaseObserver):
|
||||
|
||||
if isinstance(data, FramePushed):
|
||||
if on_push_frame_deprecated:
|
||||
await observer.on_push_frame(
|
||||
# Deprecated signature: on_push_frame(source, destination, frame, direction, timestamp)
|
||||
handler: Any = observer.on_push_frame
|
||||
await handler(
|
||||
data.source, data.destination, data.frame, data.direction, data.timestamp
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -19,7 +19,7 @@ import base64
|
||||
import io
|
||||
import wave
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, TypeAlias, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeAlias, Union, cast
|
||||
|
||||
from loguru import logger
|
||||
from openai._types import NOT_GIVEN as OPEN_AI_NOT_GIVEN
|
||||
@@ -107,10 +107,13 @@ class LLMContext:
|
||||
converted_tools = openai_context.tools
|
||||
if isinstance(converted_tools, list):
|
||||
converted_tools = ToolsSchema(
|
||||
standard_tools=[], custom_tools={AdapterType.SHIM: converted_tools}
|
||||
standard_tools=[],
|
||||
custom_tools=cast(
|
||||
Dict[AdapterType, List[Dict[str, Any]]], {AdapterType.SHIM: converted_tools}
|
||||
),
|
||||
)
|
||||
return LLMContext(
|
||||
messages=openai_context.get_messages(),
|
||||
messages=cast(List[LLMContextMessage], openai_context.get_messages()),
|
||||
tools=converted_tools,
|
||||
tool_choice=openai_context.tool_choice,
|
||||
)
|
||||
@@ -152,7 +155,7 @@ class LLMContext:
|
||||
|
||||
content.append({"type": "image_url", "image_url": {"url": url}})
|
||||
|
||||
return {"role": role, "content": content}
|
||||
return cast(LLMContextMessage, {"role": role, "content": content})
|
||||
|
||||
@staticmethod
|
||||
async def create_image_message(
|
||||
@@ -204,7 +207,7 @@ class LLMContext:
|
||||
audio_frames: List of audio frame objects to include.
|
||||
text: Optional text to include with the audio.
|
||||
"""
|
||||
content = [{"type": "text", "text": text}]
|
||||
content: List[Dict[str, Any]] = [{"type": "text", "text": text}]
|
||||
|
||||
def encode_audio():
|
||||
sample_rate = audio_frames[0].sample_rate
|
||||
@@ -231,7 +234,7 @@ class LLMContext:
|
||||
}
|
||||
)
|
||||
|
||||
return {"role": role, "content": content}
|
||||
return cast(LLMContextMessage, {"role": role, "content": content})
|
||||
|
||||
@property
|
||||
def messages(self) -> List[LLMContextMessage]:
|
||||
|
||||
@@ -15,12 +15,15 @@ import asyncio
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Literal, Optional, Set
|
||||
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Set, cast
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolChoiceOptionParam
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.interruptions.base_interruption_strategy import BaseInterruptionStrategy
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.base_turn_analyzer import BaseTurnParams
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
@@ -336,7 +339,7 @@ class LLMContextResponseAggregator(BaseLLMResponseAggregator):
|
||||
Returns:
|
||||
List of message dictionaries from the context.
|
||||
"""
|
||||
return self._context.get_messages()
|
||||
return cast(List[dict], self._context.get_messages())
|
||||
|
||||
@property
|
||||
def role(self) -> str:
|
||||
@@ -402,7 +405,7 @@ class LLMContextResponseAggregator(BaseLLMResponseAggregator):
|
||||
"""
|
||||
self._context.set_messages(messages)
|
||||
|
||||
def set_tools(self, tools: List):
|
||||
def set_tools(self, tools):
|
||||
"""Set tools in the context.
|
||||
|
||||
Args:
|
||||
@@ -416,7 +419,7 @@ class LLMContextResponseAggregator(BaseLLMResponseAggregator):
|
||||
Args:
|
||||
tool_choice: Tool choice configuration for the context.
|
||||
"""
|
||||
self._context.set_tool_choice(tool_choice)
|
||||
self._context.set_tool_choice(cast("ChatCompletionToolChoiceOptionParam", tool_choice))
|
||||
|
||||
async def reset(self):
|
||||
"""Reset the aggregation state."""
|
||||
@@ -467,7 +470,7 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
super().__init__(context=context, role="user", **kwargs)
|
||||
self._params = params or LLMUserAggregatorParams()
|
||||
self._vad_params: Optional[VADParams] = None
|
||||
self._turn_params: Optional[SmartTurnParams] = None
|
||||
self._turn_params: Optional[BaseTurnParams] = None
|
||||
|
||||
if "aggregation_timeout" in kwargs:
|
||||
with warnings.catch_warnings():
|
||||
@@ -503,7 +506,9 @@ class LLMUserContextAggregator(LLMContextResponseAggregator):
|
||||
Args:
|
||||
aggregation: The aggregated user text to add as a user message.
|
||||
"""
|
||||
self._context.add_message({"role": self.role, "content": aggregation})
|
||||
self._context.add_message(
|
||||
cast("ChatCompletionMessageParam", {"role": self.role, "content": aggregation})
|
||||
)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames for user speech aggregation and context management.
|
||||
@@ -851,7 +856,9 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
Args:
|
||||
aggregation: The aggregated assistant text to add as an assistant message.
|
||||
"""
|
||||
self._context.add_message({"role": "assistant", "content": aggregation})
|
||||
self._context.add_message(
|
||||
cast("ChatCompletionMessageParam", {"role": "assistant", "content": aggregation})
|
||||
)
|
||||
|
||||
async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
|
||||
"""Handle a function call that is in progress.
|
||||
@@ -1030,6 +1037,7 @@ class LLMAssistantContextAggregator(LLMContextResponseAggregator):
|
||||
del self._function_calls_in_progress[frame.tool_call_id]
|
||||
|
||||
async def _handle_user_image_frame(self, frame: UserImageRawFrame):
|
||||
assert frame.request is not None
|
||||
logger.debug(
|
||||
f"{self} UserImageRawFrame: [{frame.request.function_name}:{frame.request.tool_call_id}]"
|
||||
)
|
||||
@@ -1105,14 +1113,18 @@ class LLMUserResponseAggregator(LLMUserContextAggregator):
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
super().__init__(context=OpenAILLMContext(messages), params=params, **kwargs)
|
||||
super().__init__(
|
||||
context=OpenAILLMContext(cast(List["ChatCompletionMessageParam"], messages)),
|
||||
params=params,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def _process_aggregation(self):
|
||||
"""Process the current aggregation and push it downstream."""
|
||||
aggregation = self._aggregation
|
||||
await self.reset()
|
||||
await self.handle_aggregation(aggregation)
|
||||
frame = LLMMessagesFrame(self._context.messages)
|
||||
frame = LLMMessagesFrame(cast(List[dict], self._context.messages))
|
||||
await self.push_frame(frame)
|
||||
|
||||
|
||||
@@ -1150,7 +1162,11 @@ class LLMAssistantResponseAggregator(LLMAssistantContextAggregator):
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
super().__init__(context=OpenAILLMContext(messages), params=params, **kwargs)
|
||||
super().__init__(
|
||||
context=OpenAILLMContext(cast(List["ChatCompletionMessageParam"], messages)),
|
||||
params=params,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def push_aggregation(self):
|
||||
"""Push the aggregated assistant response as an LLMMessagesFrame."""
|
||||
@@ -1161,5 +1177,5 @@ class LLMAssistantResponseAggregator(LLMAssistantContextAggregator):
|
||||
# if the task gets cancelled we won't be able to clear things up.
|
||||
await self.reset()
|
||||
|
||||
frame = LLMMessagesFrame(self._context.messages)
|
||||
frame = LLMMessagesFrame(cast(List[dict], self._context.messages))
|
||||
await self.push_frame(frame)
|
||||
|
||||
@@ -16,11 +16,11 @@ import json
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Type
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Type, cast
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.adapters.schemas.tools_schema import AdapterType, ToolsSchema
|
||||
from pipecat.audio.vad.vad_analyzer import VADAnalyzer
|
||||
from pipecat.audio.vad.vad_controller import VADController
|
||||
from pipecat.frames.frames import (
|
||||
@@ -258,12 +258,20 @@ class LLMContextAggregator(FrameProcessor):
|
||||
"""
|
||||
self._context.set_messages(messages)
|
||||
|
||||
def set_tools(self, tools: ToolsSchema | NotGiven):
|
||||
def set_tools(self, tools: ToolsSchema | List | NotGiven):
|
||||
"""Set tools in the context.
|
||||
|
||||
Args:
|
||||
tools: List of tool definitions to set in the context.
|
||||
tools: Tool definitions to set in the context.
|
||||
"""
|
||||
if isinstance(tools, list):
|
||||
tools = ToolsSchema(
|
||||
standard_tools=[],
|
||||
custom_tools=cast(
|
||||
Dict[AdapterType, List[Dict[str, Any]]],
|
||||
{AdapterType.SHIM: tools},
|
||||
),
|
||||
)
|
||||
self._context.set_tools(tools)
|
||||
|
||||
def set_tool_choice(self, tool_choice: Literal["none", "auto", "required"] | dict):
|
||||
@@ -272,7 +280,9 @@ class LLMContextAggregator(FrameProcessor):
|
||||
Args:
|
||||
tool_choice: Tool choice configuration for the context.
|
||||
"""
|
||||
self._context.set_tool_choice(tool_choice)
|
||||
from pipecat.processors.aggregators.llm_context import LLMContextToolChoice
|
||||
|
||||
self._context.set_tool_choice(cast(LLMContextToolChoice, tool_choice))
|
||||
|
||||
async def reset(self):
|
||||
"""Reset the aggregation state."""
|
||||
@@ -485,7 +495,9 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
|
||||
aggregation = self.aggregation_string()
|
||||
await self.reset()
|
||||
self._context.add_message({"role": self.role, "content": aggregation})
|
||||
self._context.add_message(
|
||||
cast(LLMContextMessage, {"role": self.role, "content": aggregation})
|
||||
)
|
||||
await self.push_context_frame()
|
||||
|
||||
return aggregation
|
||||
@@ -515,7 +527,11 @@ class LLMUserAggregator(LLMContextAggregator):
|
||||
)
|
||||
|
||||
# Auto-inject turn completion instructions into context
|
||||
self._context.add_message({"role": "system", "content": config.completion_instructions})
|
||||
self._context.add_message(
|
||||
cast(
|
||||
LLMContextMessage, {"role": "system", "content": config.completion_instructions}
|
||||
)
|
||||
)
|
||||
|
||||
async def _stop(self, frame: EndFrame):
|
||||
await self._maybe_emit_user_turn_stopped(on_session_end=True)
|
||||
@@ -811,7 +827,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
self._assistant_turn_start_timestamp = ""
|
||||
|
||||
self._thought_append_to_context = False
|
||||
self._thought_llm: str = ""
|
||||
self._thought_llm: Optional[str] = ""
|
||||
self._thought_aggregation: List[TextPartForConcatenation] = []
|
||||
self._thought_start_time: str = ""
|
||||
|
||||
@@ -899,7 +915,9 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
aggregation = self.aggregation_string()
|
||||
await self.reset()
|
||||
|
||||
self._context.add_message({"role": "assistant", "content": aggregation})
|
||||
self._context.add_message(
|
||||
cast(LLMContextMessage, {"role": "assistant", "content": aggregation})
|
||||
)
|
||||
|
||||
# Push context frame
|
||||
await self.push_context_frame()
|
||||
@@ -945,26 +963,32 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
|
||||
# Update context with the in-progress function call
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": frame.tool_call_id,
|
||||
"function": {
|
||||
"name": frame.function_name,
|
||||
"arguments": json.dumps(frame.arguments, ensure_ascii=False),
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
],
|
||||
}
|
||||
cast(
|
||||
LLMContextMessage,
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": frame.tool_call_id,
|
||||
"function": {
|
||||
"name": frame.function_name,
|
||||
"arguments": json.dumps(frame.arguments, ensure_ascii=False),
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
)
|
||||
self._context.add_message(
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "IN_PROGRESS",
|
||||
"tool_call_id": frame.tool_call_id,
|
||||
}
|
||||
cast(
|
||||
LLMContextMessage,
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "IN_PROGRESS",
|
||||
"tool_call_id": frame.tool_call_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
self._function_calls_in_progress[frame.tool_call_id] = frame
|
||||
@@ -1060,6 +1084,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
logger.debug(f"{self} Appending AssistantImageRawFrame to LLM context (size: {frame.size})")
|
||||
|
||||
if frame.original_data and frame.original_mime_type:
|
||||
assert frame.original_mime_type is not None
|
||||
await self._context.add_image_frame_message(
|
||||
format=frame.original_mime_type,
|
||||
size=frame.size, # Technically doesn't matter, since already encoded
|
||||
@@ -1068,7 +1093,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
)
|
||||
else:
|
||||
await self._context.add_image_frame_message(
|
||||
format=frame.format,
|
||||
format=frame.format or "RGB",
|
||||
size=frame.size,
|
||||
image=frame.image,
|
||||
role="assistant",
|
||||
@@ -1125,11 +1150,10 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
|
||||
thought = concatenate_aggregated_text(self._thought_aggregation)
|
||||
|
||||
if self._thought_append_to_context:
|
||||
llm = self._thought_llm
|
||||
if self._thought_append_to_context and self._thought_llm is not None:
|
||||
self._context.add_message(
|
||||
LLMSpecificMessage(
|
||||
llm=llm,
|
||||
llm=self._thought_llm,
|
||||
message={
|
||||
"type": "thought",
|
||||
"text": thought,
|
||||
@@ -1151,7 +1175,7 @@ class LLMAssistantAggregator(LLMContextAggregator):
|
||||
logger.debug(f"{self} Appending UserImageRawFrame to LLM context (size: {frame.size})")
|
||||
|
||||
await self._context.add_image_frame_message(
|
||||
format=frame.format,
|
||||
format=frame.format or "RGB",
|
||||
size=frame.size,
|
||||
image=frame.image,
|
||||
text=frame.text,
|
||||
|
||||
@@ -21,7 +21,7 @@ import io
|
||||
import json
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
from openai._types import NOT_GIVEN, NotGiven
|
||||
from openai.types.chat import (
|
||||
@@ -144,7 +144,7 @@ class OpenAILLMContext:
|
||||
context = OpenAILLMContext()
|
||||
|
||||
for message in messages:
|
||||
context.add_message(message)
|
||||
context.add_message(cast(ChatCompletionMessageParam, message))
|
||||
return context
|
||||
|
||||
@property
|
||||
@@ -157,7 +157,7 @@ class OpenAILLMContext:
|
||||
return self._messages
|
||||
|
||||
@property
|
||||
def tools(self) -> List[ChatCompletionToolParam] | NotGiven | List[Any]:
|
||||
def tools(self) -> List[ChatCompletionToolParam] | NotGiven | ToolsSchema | List[Any]:
|
||||
"""Get the tools list, converting through adapter if available.
|
||||
|
||||
Returns:
|
||||
@@ -311,7 +311,7 @@ class OpenAILLMContext:
|
||||
self._tools = tools
|
||||
|
||||
def add_image_frame_message(
|
||||
self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
|
||||
self, *, format: str, size: tuple[int, int], image: bytes, text: Optional[str] = None
|
||||
):
|
||||
"""Add a message containing an image frame.
|
||||
|
||||
@@ -333,7 +333,9 @@ class OpenAILLMContext:
|
||||
)
|
||||
self.add_message({"role": "user", "content": content})
|
||||
|
||||
def add_audio_frames_message(self, *, audio_frames: list[AudioRawFrame], text: str = None):
|
||||
def add_audio_frames_message(
|
||||
self, *, audio_frames: list[AudioRawFrame], text: Optional[str] = None
|
||||
):
|
||||
"""Add a message containing audio frames.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -14,6 +14,7 @@ in conversational pipelines.
|
||||
from pipecat.frames.frames import TextFrame
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import LLMUserAggregator
|
||||
from pipecat.utils.string import concatenate_aggregated_text
|
||||
|
||||
|
||||
class UserResponseAggregator(LLMUserAggregator):
|
||||
@@ -45,20 +46,27 @@ class UserResponseAggregator(LLMUserAggregator):
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
async def push_aggregation(self):
|
||||
async def push_aggregation(self) -> str:
|
||||
"""Push the aggregated user response as a TextFrame.
|
||||
|
||||
Creates a TextFrame from the current aggregation if it contains content,
|
||||
resets the aggregation state, and pushes the frame downstream.
|
||||
|
||||
Returns:
|
||||
The pushed aggregation text, or empty string if nothing to push.
|
||||
"""
|
||||
if len(self._aggregation) > 0:
|
||||
frame = TextFrame(self._aggregation.strip())
|
||||
text = concatenate_aggregated_text(self._aggregation).strip()
|
||||
frame = TextFrame(text)
|
||||
|
||||
# Reset the aggregation. Reset it before pushing it down, otherwise
|
||||
# if the task gets cancelled we won't be able to clear things up.
|
||||
self._aggregation = ""
|
||||
self._aggregation = []
|
||||
|
||||
await self.push_frame(frame)
|
||||
|
||||
# Reset our accumulator state.
|
||||
await self.reset()
|
||||
|
||||
return text
|
||||
return ""
|
||||
|
||||
@@ -72,7 +72,7 @@ class VisionImageFrameAggregator(FrameProcessor):
|
||||
text=self._describe_text,
|
||||
image=frame.image,
|
||||
size=frame.size,
|
||||
format=frame.format,
|
||||
format=frame.format or "RGB",
|
||||
)
|
||||
frame = OpenAILLMContextFrame(context)
|
||||
await self.push_frame(frame)
|
||||
|
||||
@@ -103,7 +103,7 @@ class FrameProcessorQueue(asyncio.PriorityQueue):
|
||||
self.__high_counter = 0
|
||||
self.__low_counter = 0
|
||||
|
||||
async def put(self, item: Tuple[Frame, FrameDirection, FrameCallback]):
|
||||
async def put(self, item: Tuple[Frame, FrameDirection, Optional[FrameCallback]]):
|
||||
"""Put an item into the priority queue.
|
||||
|
||||
System frames (`SystemFrame`) have higher priority than any other
|
||||
@@ -470,11 +470,13 @@ class FrameProcessor(BaseObject):
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
def create_task(self, coroutine: Coroutine, name: Optional[str] = None) -> asyncio.Task:
|
||||
def create_task(
|
||||
self, coroutine: Coroutine | Awaitable, name: Optional[str] = None
|
||||
) -> asyncio.Task:
|
||||
"""Create a new task managed by this processor.
|
||||
|
||||
Args:
|
||||
coroutine: The coroutine to run in the task.
|
||||
coroutine: The coroutine or awaitable to run in the task.
|
||||
name: Optional name for the task.
|
||||
|
||||
Returns:
|
||||
@@ -483,7 +485,8 @@ class FrameProcessor(BaseObject):
|
||||
if name:
|
||||
name = f"{self}::{name}"
|
||||
else:
|
||||
name = f"{self}::{coroutine.cr_code.co_name}"
|
||||
cr_code = getattr(coroutine, "cr_code", None)
|
||||
name = f"{self}::{cr_code.co_name if cr_code else 'unknown'}"
|
||||
return self.task_manager.create_task(coroutine, name)
|
||||
|
||||
async def cancel_task(self, task: asyncio.Task, timeout: Optional[float] = 1.0):
|
||||
|
||||
@@ -74,7 +74,13 @@ class LangchainProcessor(FrameProcessor):
|
||||
if isinstance(frame, OpenAILLMContextFrame)
|
||||
else frame.context.get_messages()
|
||||
)
|
||||
text: str = messages[-1]["content"]
|
||||
last_msg = messages[-1]
|
||||
content = (
|
||||
last_msg.get("content", "")
|
||||
if isinstance(last_msg, dict)
|
||||
else getattr(last_msg, "content", "")
|
||||
)
|
||||
text: str = content if isinstance(content, str) else str(content)
|
||||
|
||||
await self._ainvoke(text.strip())
|
||||
else:
|
||||
@@ -94,7 +100,7 @@ class LangchainProcessor(FrameProcessor):
|
||||
case str():
|
||||
return text
|
||||
case AIMessageChunk():
|
||||
return text.content
|
||||
return text.content if isinstance(text.content, str) else str(text.content)
|
||||
case _:
|
||||
return ""
|
||||
|
||||
|
||||
@@ -1253,31 +1253,33 @@ class RTVIObserver(BaseObserver):
|
||||
report_level = self._get_function_call_report_level(function_call.function_name)
|
||||
if report_level == RTVIFunctionCallReportLevel.DISABLED:
|
||||
continue
|
||||
data = RTVILLMFunctionCallStartMessageData()
|
||||
fc_start_data = RTVILLMFunctionCallStartMessageData()
|
||||
if report_level in (
|
||||
RTVIFunctionCallReportLevel.NAME,
|
||||
RTVIFunctionCallReportLevel.FULL,
|
||||
):
|
||||
data.function_name = function_call.function_name
|
||||
message = RTVILLMFunctionCallStartMessage(data=data)
|
||||
fc_start_data.function_name = function_call.function_name
|
||||
message = RTVILLMFunctionCallStartMessage(data=fc_start_data)
|
||||
await self.send_rtvi_message(message)
|
||||
elif isinstance(frame, FunctionCallInProgressFrame):
|
||||
report_level = self._get_function_call_report_level(frame.function_name)
|
||||
if report_level != RTVIFunctionCallReportLevel.DISABLED:
|
||||
data = RTVILLMFunctionCallInProgressMessageData(tool_call_id=frame.tool_call_id)
|
||||
fc_progress_data = RTVILLMFunctionCallInProgressMessageData(
|
||||
tool_call_id=frame.tool_call_id
|
||||
)
|
||||
if report_level in (
|
||||
RTVIFunctionCallReportLevel.NAME,
|
||||
RTVIFunctionCallReportLevel.FULL,
|
||||
):
|
||||
data.function_name = frame.function_name
|
||||
fc_progress_data.function_name = frame.function_name
|
||||
if report_level == RTVIFunctionCallReportLevel.FULL:
|
||||
data.args = frame.arguments
|
||||
message = RTVILLMFunctionCallInProgressMessage(data=data)
|
||||
fc_progress_data.args = frame.arguments
|
||||
message = RTVILLMFunctionCallInProgressMessage(data=fc_progress_data)
|
||||
await self.send_rtvi_message(message)
|
||||
elif isinstance(frame, FunctionCallCancelFrame):
|
||||
report_level = self._get_function_call_report_level(frame.function_name)
|
||||
if report_level != RTVIFunctionCallReportLevel.DISABLED:
|
||||
data = RTVILLMFunctionCallStoppedMessageData(
|
||||
fc_cancel_data = RTVILLMFunctionCallStoppedMessageData(
|
||||
tool_call_id=frame.tool_call_id,
|
||||
cancelled=True,
|
||||
)
|
||||
@@ -1285,13 +1287,13 @@ class RTVIObserver(BaseObserver):
|
||||
RTVIFunctionCallReportLevel.NAME,
|
||||
RTVIFunctionCallReportLevel.FULL,
|
||||
):
|
||||
data.function_name = frame.function_name
|
||||
message = RTVILLMFunctionCallStoppedMessage(data=data)
|
||||
fc_cancel_data.function_name = frame.function_name
|
||||
message = RTVILLMFunctionCallStoppedMessage(data=fc_cancel_data)
|
||||
await self.send_rtvi_message(message)
|
||||
elif isinstance(frame, FunctionCallResultFrame):
|
||||
report_level = self._get_function_call_report_level(frame.function_name)
|
||||
if report_level != RTVIFunctionCallReportLevel.DISABLED:
|
||||
data = RTVILLMFunctionCallStoppedMessageData(
|
||||
fc_result_data = RTVILLMFunctionCallStoppedMessageData(
|
||||
tool_call_id=frame.tool_call_id,
|
||||
cancelled=False,
|
||||
)
|
||||
@@ -1299,10 +1301,10 @@ class RTVIObserver(BaseObserver):
|
||||
RTVIFunctionCallReportLevel.NAME,
|
||||
RTVIFunctionCallReportLevel.FULL,
|
||||
):
|
||||
data.function_name = frame.function_name
|
||||
fc_result_data.function_name = frame.function_name
|
||||
if report_level == RTVIFunctionCallReportLevel.FULL:
|
||||
data.result = frame.result if frame.result else None
|
||||
message = RTVILLMFunctionCallStoppedMessage(data=data)
|
||||
fc_result_data.result = frame.result if frame.result else None
|
||||
message = RTVILLMFunctionCallStoppedMessage(data=fc_result_data)
|
||||
await self.send_rtvi_message(message)
|
||||
elif isinstance(frame, RTVIServerMessageFrame):
|
||||
message = RTVIServerMessage(data=frame.data)
|
||||
@@ -1427,8 +1429,14 @@ class RTVIObserver(BaseObserver):
|
||||
|
||||
# Handle Google LLM format (protobuf objects with attributes)
|
||||
# Note: not possible if frame is a universal LLMContextFrame
|
||||
if hasattr(message, "role") and message.role == "user" and hasattr(message, "parts"):
|
||||
text = "".join(part.text for part in message.parts if hasattr(part, "text"))
|
||||
if (
|
||||
hasattr(message, "role")
|
||||
and getattr(message, "role", None) == "user"
|
||||
and hasattr(message, "parts")
|
||||
):
|
||||
text = "".join(
|
||||
part.text for part in getattr(message, "parts", []) if hasattr(part, "text")
|
||||
)
|
||||
if text:
|
||||
rtvi_message = RTVIUserLLMTextMessage(data=RTVITextMessageData(text=text))
|
||||
await self.send_rtvi_message(rtvi_message)
|
||||
@@ -1439,8 +1447,10 @@ class RTVIObserver(BaseObserver):
|
||||
content = message["content"]
|
||||
if isinstance(content, list):
|
||||
text = " ".join(item["text"] for item in content if "text" in item)
|
||||
else:
|
||||
elif isinstance(content, str):
|
||||
text = content
|
||||
else:
|
||||
text = str(content) if content else ""
|
||||
rtvi_message = RTVIUserLLMTextMessage(data=RTVITextMessageData(text=text))
|
||||
await self.send_rtvi_message(rtvi_message)
|
||||
|
||||
@@ -1482,7 +1492,7 @@ class RTVIObserver(BaseObserver):
|
||||
async def _send_error_response(self, frame: RTVIServerResponseFrame):
|
||||
"""Send a response to the client for a specific request."""
|
||||
message = RTVIErrorResponse(
|
||||
id=str(frame.client_msg.msg_id), data=RTVIErrorResponseData(error=frame.error)
|
||||
id=str(frame.client_msg.msg_id), data=RTVIErrorResponseData(error=frame.error or "")
|
||||
)
|
||||
await self.send_rtvi_message(message)
|
||||
|
||||
@@ -1593,7 +1603,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
self._client_ready = True
|
||||
await self._call_event_handler("on_client_ready")
|
||||
|
||||
async def set_bot_ready(self, about: Mapping[str, Any] = None):
|
||||
async def set_bot_ready(self, about: Optional[Mapping[str, Any]] = None):
|
||||
"""Mark the bot as ready and send the bot-ready message.
|
||||
|
||||
Args:
|
||||
@@ -2026,7 +2036,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
"""Handle an action execution request."""
|
||||
action_id = self._action_id(data.service, data.action)
|
||||
if action_id not in self._registered_actions:
|
||||
await self._send_error_response(request_id, f"Action {action_id} not registered")
|
||||
await self._send_error_response(request_id or "", f"Action {action_id} not registered")
|
||||
return
|
||||
action = self._registered_actions[action_id]
|
||||
arguments = {}
|
||||
@@ -2040,7 +2050,7 @@ class RTVIProcessor(FrameProcessor):
|
||||
message = RTVIActionResponse(id=request_id, data=RTVIActionResponseData(result=result))
|
||||
await self.push_transport_message(message)
|
||||
|
||||
async def _send_bot_ready(self, about: Mapping[str, Any] = None):
|
||||
async def _send_bot_ready(self, about: Optional[Mapping[str, Any]] = None):
|
||||
"""Send the bot-ready message to the client.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -76,7 +76,12 @@ class StrandsAgentsProcessor(FrameProcessor):
|
||||
messages = frame.context.get_messages()
|
||||
if messages:
|
||||
last_message = messages[-1]
|
||||
await self._ainvoke(str(last_message["content"]).strip())
|
||||
content = (
|
||||
last_message.get("content", "")
|
||||
if isinstance(last_message, dict)
|
||||
else getattr(last_message, "content", "")
|
||||
)
|
||||
await self._ainvoke(str(content).strip())
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -100,6 +105,7 @@ class StrandsAgentsProcessor(FrameProcessor):
|
||||
await self.stop_ttfb_metrics()
|
||||
ttfb_tracking = False
|
||||
try:
|
||||
assert self.graph_exit_node is not None
|
||||
node_result = graph_result.results[self.graph_exit_node]
|
||||
logger.debug(f"Node result: {node_result}")
|
||||
for agent_result in node_result.get_agent_results():
|
||||
@@ -119,6 +125,7 @@ class StrandsAgentsProcessor(FrameProcessor):
|
||||
logger.warning(f"Failed to extract messages from GraphResult: {parse_err}")
|
||||
else:
|
||||
# Agent supports streaming events via async iterator
|
||||
assert self.agent is not None
|
||||
async for event in self.agent.stream_async(text):
|
||||
# Push to TTS service
|
||||
if isinstance(event, dict) and "data" in event:
|
||||
|
||||
@@ -28,7 +28,7 @@ try:
|
||||
|
||||
gi.require_version("Gst", "1.0")
|
||||
gi.require_version("GstApp", "1.0")
|
||||
from gi.repository import Gst, GstApp
|
||||
from gi.repository import Gst, GstApp # type: ignore[import-not-found]
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(
|
||||
|
||||
@@ -65,6 +65,7 @@ class FrameProcessorMetrics(BaseObject):
|
||||
Returns:
|
||||
The task manager instance for async operations.
|
||||
"""
|
||||
assert self._task_manager is not None
|
||||
return self._task_manager
|
||||
|
||||
@property
|
||||
|
||||
@@ -10,7 +10,7 @@ This module provides processors that convert speech and text frames into structu
|
||||
transcript messages with timestamps, enabling conversation history tracking and analysis.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -47,10 +47,14 @@ class BaseTranscriptProcessor(FrameProcessor):
|
||||
**kwargs: Additional arguments passed to parent class.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self._processed_messages: List[TranscriptionMessage] = []
|
||||
self._processed_messages: List[
|
||||
Union[TranscriptionMessage, ThoughtTranscriptionMessage]
|
||||
] = []
|
||||
self._register_event_handler("on_transcript_update")
|
||||
|
||||
async def _emit_update(self, messages: List[TranscriptionMessage]):
|
||||
async def _emit_update(
|
||||
self, messages: List[Union[TranscriptionMessage, ThoughtTranscriptionMessage]]
|
||||
):
|
||||
"""Emit transcript updates for new messages.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -79,11 +79,11 @@ async def configure(
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
*,
|
||||
api_key: Optional[str] = None,
|
||||
room_exp_duration: Optional[float] = 2.0,
|
||||
token_exp_duration: Optional[float] = 2.0,
|
||||
room_exp_duration: float = 2.0,
|
||||
token_exp_duration: float = 2.0,
|
||||
sip_caller_phone: Optional[str] = None,
|
||||
sip_enable_video: Optional[bool] = False,
|
||||
sip_num_endpoints: Optional[int] = 1,
|
||||
sip_enable_video: bool = False,
|
||||
sip_num_endpoints: int = 1,
|
||||
sip_codecs: Optional[Dict[str, List[str]]] = None,
|
||||
room_properties: Optional[DailyRoomProperties] = None,
|
||||
token_properties: Optional["DailyMeetingTokenProperties"] = None,
|
||||
@@ -209,6 +209,7 @@ async def configure(
|
||||
|
||||
# Add SIP configuration if enabled
|
||||
if sip_enabled:
|
||||
assert sip_caller_phone is not None
|
||||
sip_params = DailyRoomSipParams(
|
||||
display_name=sip_caller_phone,
|
||||
video=sip_enable_video,
|
||||
|
||||
@@ -72,10 +72,23 @@ import os
|
||||
import sys
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
from http import HTTPMethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, TypedDict, Union
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from http import HTTPMethod
|
||||
else:
|
||||
# HTTPMethod was added in Python 3.11
|
||||
from enum import Enum
|
||||
|
||||
class HTTPMethod(str, Enum):
|
||||
"""HTTP method enum for Python < 3.11 compatibility."""
|
||||
|
||||
POST = "POST"
|
||||
PATCH = "PATCH"
|
||||
GET = "GET"
|
||||
|
||||
|
||||
import aiohttp
|
||||
from fastapi.responses import FileResponse, Response
|
||||
from loguru import logger
|
||||
@@ -140,6 +153,8 @@ def _get_bot_module():
|
||||
spec = importlib.util.spec_from_file_location(
|
||||
module_name, os.path.join(cwd, filename)
|
||||
)
|
||||
if not spec or not spec.loader:
|
||||
continue
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
@@ -397,6 +412,11 @@ def _setup_whatsapp_routes(app: FastAPI, args: argparse.Namespace):
|
||||
)
|
||||
return
|
||||
|
||||
# Assertions: the all() check above ensures these are non-None
|
||||
assert WHATSAPP_TOKEN is not None
|
||||
assert WHATSAPP_PHONE_NUMBER_ID is not None
|
||||
assert WHATSAPP_WEBHOOK_VERIFICATION_TOKEN is not None
|
||||
|
||||
try:
|
||||
from pipecat.transports.smallwebrtc.connection import SmallWebRTCConnection
|
||||
from pipecat.transports.whatsapp.api import WhatsAppWebhookRequest
|
||||
|
||||
@@ -416,8 +416,8 @@ def _get_transport_params(transport_key: str, transport_params: Dict[str, Callab
|
||||
async def _create_telephony_transport(
|
||||
websocket: WebSocket,
|
||||
params: Optional[Any] = None,
|
||||
transport_type: str = None,
|
||||
call_data: dict = None,
|
||||
transport_type: Optional[str] = None,
|
||||
call_data: Optional[dict] = None,
|
||||
) -> BaseTransport:
|
||||
"""Create a telephony transport with pre-parsed WebSocket data.
|
||||
|
||||
@@ -443,6 +443,9 @@ async def _create_telephony_transport(
|
||||
|
||||
logger.info(f"Using pre-detected telephony provider: {transport_type}")
|
||||
|
||||
if call_data is None:
|
||||
raise ValueError("call_data must be provided for telephony transports.")
|
||||
|
||||
if transport_type == "twilio":
|
||||
from pipecat.serializers.twilio import TwilioFrameSerializer
|
||||
|
||||
@@ -586,6 +589,8 @@ async def create_transport(
|
||||
|
||||
from pipecat.transports.livekit.transport import LiveKitTransport
|
||||
|
||||
if runner_args.token is None:
|
||||
raise ValueError("LiveKit token is required")
|
||||
return LiveKitTransport(
|
||||
runner_args.url,
|
||||
runner_args.token,
|
||||
|
||||
@@ -62,6 +62,7 @@ class ExotelFrameSerializer(FrameSerializer):
|
||||
params: Configuration parameters.
|
||||
"""
|
||||
super().__init__(params or ExotelFrameSerializer.InputParams())
|
||||
self._params: ExotelFrameSerializer.InputParams
|
||||
|
||||
self._stream_sid = stream_sid
|
||||
self._call_sid = call_sid
|
||||
|
||||
@@ -169,6 +169,7 @@ class GenesysAudioHookSerializer(FrameSerializer):
|
||||
**kwargs: Additional arguments passed to BaseObject (e.g., name).
|
||||
"""
|
||||
super().__init__(params or GenesysAudioHookSerializer.InputParams(), **kwargs)
|
||||
self._params: GenesysAudioHookSerializer.InputParams
|
||||
|
||||
self._genesys_sample_rate = self._params.genesys_sample_rate
|
||||
self._sample_rate = 0 # Pipeline input rate, set in setup()
|
||||
|
||||
@@ -74,6 +74,7 @@ class PlivoFrameSerializer(FrameSerializer):
|
||||
params: Configuration parameters.
|
||||
"""
|
||||
super().__init__(params or PlivoFrameSerializer.InputParams())
|
||||
self._params: PlivoFrameSerializer.InputParams
|
||||
|
||||
self._stream_id = stream_id
|
||||
self._call_id = call_id
|
||||
|
||||
@@ -79,23 +79,23 @@ class ProtobufFrameSerializer(FrameSerializer):
|
||||
Serialized frame as bytes, or None if frame type is not serializable.
|
||||
"""
|
||||
# Wrapping this message as a JSONFrame to send
|
||||
serializable_frame: Frame | MessageFrame = frame
|
||||
if isinstance(frame, (OutputTransportMessageFrame, OutputTransportMessageUrgentFrame)):
|
||||
if self.should_ignore_frame(frame):
|
||||
return None
|
||||
frame = MessageFrame(
|
||||
serializable_frame = MessageFrame(
|
||||
data=json.dumps(frame.message),
|
||||
)
|
||||
|
||||
proto_frame = frame_protos.Frame()
|
||||
if type(frame) not in self.SERIALIZABLE_TYPES:
|
||||
proto_frame = frame_protos.Frame() # type: ignore[attr-defined]
|
||||
proto_optional_name = self.SERIALIZABLE_TYPES.get(type(serializable_frame))
|
||||
if proto_optional_name is None:
|
||||
logger.warning(f"Frame type {type(frame)} is not serializable")
|
||||
return None
|
||||
|
||||
# ignoring linter errors; we check that type(frame) is in this dict above
|
||||
proto_optional_name = self.SERIALIZABLE_TYPES[type(frame)] # type: ignore
|
||||
proto_attr = getattr(proto_frame, proto_optional_name)
|
||||
for field in dataclasses.fields(frame): # type: ignore
|
||||
value = getattr(frame, field.name)
|
||||
for field in dataclasses.fields(serializable_frame): # type: ignore[arg-type]
|
||||
value = getattr(serializable_frame, field.name)
|
||||
if value and hasattr(proto_attr, field.name):
|
||||
setattr(proto_attr, field.name, value)
|
||||
|
||||
@@ -110,7 +110,7 @@ class ProtobufFrameSerializer(FrameSerializer):
|
||||
Returns:
|
||||
Deserialized frame instance, or None if deserialization fails.
|
||||
"""
|
||||
proto = frame_protos.Frame.FromString(data)
|
||||
proto = frame_protos.Frame.FromString(data) # type: ignore[attr-defined]
|
||||
which = proto.WhichOneof("frame")
|
||||
if which not in self.DESERIALIZABLE_FIELDS:
|
||||
logger.error("Unable to deserialize a valid frame")
|
||||
|
||||
@@ -78,6 +78,7 @@ class TwilioFrameSerializer(FrameSerializer):
|
||||
params: Configuration parameters.
|
||||
"""
|
||||
super().__init__(params or TwilioFrameSerializer.InputParams())
|
||||
self._params: TwilioFrameSerializer.InputParams
|
||||
|
||||
# Validate hangup-related parameters if auto_hang_up is enabled
|
||||
if self._params.auto_hang_up:
|
||||
@@ -193,6 +194,8 @@ class TwilioFrameSerializer(FrameSerializer):
|
||||
endpoint = f"https://api.{edge_prefix}{region_prefix}twilio.com/2010-04-01/Accounts/{account_sid}/Calls/{call_sid}.json"
|
||||
|
||||
# Create basic auth from account_sid and auth_token
|
||||
assert account_sid is not None
|
||||
assert auth_token is not None
|
||||
auth = aiohttp.BasicAuth(account_sid, auth_token)
|
||||
|
||||
# Parameters to set the call status to "completed" (hang up)
|
||||
|
||||
@@ -57,6 +57,7 @@ class VonageFrameSerializer(FrameSerializer):
|
||||
params: Configuration parameters.
|
||||
"""
|
||||
super().__init__(params or VonageFrameSerializer.InputParams())
|
||||
self._params: VonageFrameSerializer.InputParams
|
||||
|
||||
self._vonage_sample_rate = self._params.vonage_sample_rate
|
||||
self._sample_rate = 0 # Pipeline input rate
|
||||
|
||||
@@ -592,7 +592,8 @@ class BaseInputTransport(FrameProcessor):
|
||||
"""Handle end-of-turn analysis and generate prediction results."""
|
||||
if self._params.turn_analyzer:
|
||||
state, prediction = await self._params.turn_analyzer.analyze_end_of_turn()
|
||||
await self._deprecated_handle_prediction_result(prediction)
|
||||
if prediction is not None:
|
||||
await self._deprecated_handle_prediction_result(prediction)
|
||||
await self._deprecated_handle_end_of_turn_complete(state)
|
||||
|
||||
async def _deprecated_handle_end_of_turn_complete(self, state: EndOfTurnState):
|
||||
@@ -610,6 +611,7 @@ class BaseInputTransport(FrameProcessor):
|
||||
"""Run turn analysis on audio frame and handle results."""
|
||||
is_speech = vad_state == VADState.SPEAKING or vad_state == VADState.STARTING
|
||||
# If silence exceeds threshold, we are going to receive EndOfTurnState.COMPLETE
|
||||
assert self._params.turn_analyzer is not None
|
||||
end_of_turn_state = self._params.turn_analyzer.append_audio(frame.audio, is_speech)
|
||||
if end_of_turn_state == EndOfTurnState.COMPLETE:
|
||||
await self._deprecated_handle_end_of_turn_complete(end_of_turn_state)
|
||||
|
||||
@@ -701,14 +701,16 @@ class BaseOutputTransport(FrameProcessor):
|
||||
# Notify the bot stopped speaking upstream if necessary.
|
||||
await self._bot_stopped_speaking()
|
||||
|
||||
async def with_mixer(vad_stop_secs: float) -> AsyncGenerator[Frame, None]:
|
||||
async def with_mixer(
|
||||
mixer: BaseAudioMixer, vad_stop_secs: float
|
||||
) -> AsyncGenerator[Frame, None]:
|
||||
last_frame_time = 0
|
||||
silence = b"\x00" * self._audio_chunk_size
|
||||
while True:
|
||||
try:
|
||||
frame = self._audio_queue.get_nowait()
|
||||
if isinstance(frame, OutputAudioRawFrame):
|
||||
frame.audio = await self._mixer.mix(frame.audio)
|
||||
frame.audio = await mixer.mix(frame.audio)
|
||||
last_frame_time = time.time()
|
||||
yield frame
|
||||
self._audio_queue.task_done()
|
||||
@@ -719,7 +721,7 @@ class BaseOutputTransport(FrameProcessor):
|
||||
await self._bot_stopped_speaking()
|
||||
# Generate an audio frame with only the mixer's part.
|
||||
frame = OutputAudioRawFrame(
|
||||
audio=await self._mixer.mix(silence),
|
||||
audio=await mixer.mix(silence),
|
||||
sample_rate=self._sample_rate,
|
||||
num_channels=self._params.audio_out_channels,
|
||||
)
|
||||
@@ -731,7 +733,7 @@ class BaseOutputTransport(FrameProcessor):
|
||||
await asyncio.sleep(0)
|
||||
|
||||
if self._mixer:
|
||||
return with_mixer(BOT_VAD_STOP_SECS)
|
||||
return with_mixer(self._mixer, BOT_VAD_STOP_SECS)
|
||||
else:
|
||||
return without_mixer(BOT_VAD_STOP_SECS)
|
||||
|
||||
@@ -864,6 +866,7 @@ class BaseOutputTransport(FrameProcessor):
|
||||
# which is kind of what happens in P2P connections.
|
||||
# We need to add support for that inside the DailyTransport
|
||||
if frame.size != desired_size:
|
||||
assert frame.format is not None
|
||||
image = Image.frombytes(frame.format, frame.size, frame.image)
|
||||
resized_image = image.resize(desired_size)
|
||||
# logger.warning(f"{frame} does not have the expected size {desired_size}, resizing")
|
||||
|
||||
@@ -195,7 +195,7 @@ class DailyUpdateRemoteParticipantsFrame(ControlFrame):
|
||||
remote_participants: See https://reference-python.daily.co/api_reference.html#daily.CallClient.update_remote_participants.
|
||||
"""
|
||||
|
||||
remote_participants: Mapping[str, Any] = None
|
||||
remote_participants: Optional[Mapping[str, Any]] = None
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
@@ -751,7 +751,10 @@ class DailyTransportClient(EventHandler):
|
||||
|
||||
self._client.set_user_name(self._bot_name)
|
||||
|
||||
(data, error) = await self._join()
|
||||
result = await self._join()
|
||||
if result is None:
|
||||
return
|
||||
(data, error) = result
|
||||
|
||||
if not error:
|
||||
self._joined = True
|
||||
@@ -881,7 +884,7 @@ class DailyTransportClient(EventHandler):
|
||||
"""Cleanup the Daily client instance."""
|
||||
if self._client:
|
||||
self._client.release()
|
||||
self._client = None
|
||||
self._client = None # type: ignore[assignment]
|
||||
|
||||
def participants(self) -> Mapping[str, Any]:
|
||||
"""Get current participants in the room.
|
||||
@@ -1956,7 +1959,8 @@ class DailyOutputTransport(BaseOutputTransport):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, DailyUpdateRemoteParticipantsFrame):
|
||||
await self._client.update_remote_participants(frame.remote_participants)
|
||||
if frame.remote_participants is not None:
|
||||
await self._client.update_remote_participants(frame.remote_participants)
|
||||
|
||||
async def send_message(
|
||||
self, frame: OutputTransportMessageFrame | OutputTransportMessageUrgentFrame
|
||||
|
||||
@@ -230,6 +230,7 @@ class HeyGenOutputTransport(BaseOutputTransport):
|
||||
logger.warning("self._event_id is already defined!")
|
||||
self._event_id = str(frame.id)
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
assert self._event_id is not None
|
||||
await self._client.agent_speak_end(self._event_id)
|
||||
self._event_id = None
|
||||
await super().push_frame(frame, direction)
|
||||
@@ -252,7 +253,8 @@ class HeyGenOutputTransport(BaseOutputTransport):
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
await self._client.interrupt(self._event_id)
|
||||
if self._event_id is not None:
|
||||
await self._client.interrupt(self._event_id)
|
||||
await self.push_frame(frame, direction)
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
await self._client.start_agent_listening()
|
||||
@@ -267,6 +269,7 @@ class HeyGenOutputTransport(BaseOutputTransport):
|
||||
Args:
|
||||
frame: The audio frame to write.
|
||||
"""
|
||||
assert self._event_id is not None
|
||||
await self._client.agent_speak(bytes(frame.audio), self._event_id)
|
||||
return True
|
||||
|
||||
|
||||
@@ -412,7 +412,7 @@ class LiveKitTransportClient:
|
||||
"id": participant.sid,
|
||||
"name": participant.name,
|
||||
"metadata": participant.metadata,
|
||||
"is_speaking": participant.is_speaking,
|
||||
"is_speaking": participant.is_speaking, # type: ignore[attr-defined]
|
||||
}
|
||||
return {}
|
||||
|
||||
@@ -432,7 +432,7 @@ class LiveKitTransportClient:
|
||||
"""
|
||||
participant = self.room.remote_participants.get(participant_id)
|
||||
if participant:
|
||||
for track in participant.tracks.values():
|
||||
for track in participant.tracks.values(): # type: ignore[attr-defined]
|
||||
if track.kind == "audio":
|
||||
await track.set_enabled(False)
|
||||
|
||||
@@ -444,13 +444,14 @@ class LiveKitTransportClient:
|
||||
"""
|
||||
participant = self.room.remote_participants.get(participant_id)
|
||||
if participant:
|
||||
for track in participant.tracks.values():
|
||||
for track in participant.tracks.values(): # type: ignore[attr-defined]
|
||||
if track.kind == "audio":
|
||||
await track.set_enabled(True)
|
||||
|
||||
# Wrapper methods for event handlers
|
||||
def _on_participant_connected_wrapper(self, participant: rtc.RemoteParticipant):
|
||||
"""Wrapper for participant connected events."""
|
||||
assert self._task_manager is not None
|
||||
self._task_manager.create_task(
|
||||
self._async_on_participant_connected(participant),
|
||||
f"{self}::_async_on_participant_connected",
|
||||
@@ -458,6 +459,7 @@ class LiveKitTransportClient:
|
||||
|
||||
def _on_participant_disconnected_wrapper(self, participant: rtc.RemoteParticipant):
|
||||
"""Wrapper for participant disconnected events."""
|
||||
assert self._task_manager is not None
|
||||
self._task_manager.create_task(
|
||||
self._async_on_participant_disconnected(participant),
|
||||
f"{self}::_async_on_participant_disconnected",
|
||||
@@ -470,6 +472,7 @@ class LiveKitTransportClient:
|
||||
participant: rtc.RemoteParticipant,
|
||||
):
|
||||
"""Wrapper for track subscribed events."""
|
||||
assert self._task_manager is not None
|
||||
self._task_manager.create_task(
|
||||
self._async_on_track_subscribed(track, publication, participant),
|
||||
f"{self}::_async_on_track_subscribed",
|
||||
@@ -482,6 +485,7 @@ class LiveKitTransportClient:
|
||||
participant: rtc.RemoteParticipant,
|
||||
):
|
||||
"""Wrapper for track unsubscribed events."""
|
||||
assert self._task_manager is not None
|
||||
self._task_manager.create_task(
|
||||
self._async_on_track_unsubscribed(track, publication, participant),
|
||||
f"{self}::_async_on_track_unsubscribed",
|
||||
@@ -489,6 +493,7 @@ class LiveKitTransportClient:
|
||||
|
||||
def _on_data_received_wrapper(self, data: rtc.DataPacket):
|
||||
"""Wrapper for data received events."""
|
||||
assert self._task_manager is not None
|
||||
self._task_manager.create_task(
|
||||
self._async_on_data_received(data),
|
||||
f"{self}::_async_on_data_received",
|
||||
@@ -496,10 +501,12 @@ class LiveKitTransportClient:
|
||||
|
||||
def _on_connected_wrapper(self):
|
||||
"""Wrapper for connected events."""
|
||||
assert self._task_manager is not None
|
||||
self._task_manager.create_task(self._async_on_connected(), f"{self}::_async_on_connected")
|
||||
|
||||
def _on_disconnected_wrapper(self):
|
||||
"""Wrapper for disconnected events."""
|
||||
assert self._task_manager is not None
|
||||
self._task_manager.create_task(
|
||||
self._async_on_disconnected(), f"{self}::_async_on_disconnected"
|
||||
)
|
||||
@@ -531,6 +538,7 @@ class LiveKitTransportClient:
|
||||
logger.info(f"Audio track subscribed: {track.sid} from participant {participant.sid}")
|
||||
self._audio_tracks[participant.sid] = track
|
||||
audio_stream = rtc.AudioStream(track)
|
||||
assert self._task_manager is not None
|
||||
self._task_manager.create_task(
|
||||
self._process_audio_stream(audio_stream, participant.sid),
|
||||
f"{self}::_process_audio_stream",
|
||||
@@ -543,6 +551,7 @@ class LiveKitTransportClient:
|
||||
# unbounded queue growth when there is no consumer for video frames.
|
||||
if self._params.video_in_enabled:
|
||||
video_stream = rtc.VideoStream(track)
|
||||
assert self._task_manager is not None
|
||||
self._task_manager.create_task(
|
||||
self._process_video_stream(video_stream, participant.sid),
|
||||
f"{self}::_process_video_stream",
|
||||
@@ -564,7 +573,9 @@ class LiveKitTransportClient:
|
||||
|
||||
async def _async_on_data_received(self, data: rtc.DataPacket):
|
||||
"""Handle data received events."""
|
||||
await self._callbacks.on_data_received(data.data, data.participant.sid)
|
||||
participant = data.participant
|
||||
assert participant is not None
|
||||
await self._callbacks.on_data_received(data.data, participant.sid)
|
||||
|
||||
async def _async_on_connected(self):
|
||||
"""Handle connected events."""
|
||||
@@ -796,7 +807,7 @@ class LiveKitInputTransport(BaseInputTransport):
|
||||
"""Convert LiveKit video frame to Pipecat video frame."""
|
||||
rgb_frame = video_frame_event.frame.convert(proto_video_frame.VideoBufferType.RGB24)
|
||||
image_frame = ImageRawFrame(
|
||||
image=rgb_frame.data,
|
||||
image=bytes(rgb_frame.data),
|
||||
size=(rgb_frame.width, rgb_frame.height),
|
||||
format="RGB",
|
||||
)
|
||||
@@ -1119,7 +1130,7 @@ class LiveKitTransport(BaseTransport):
|
||||
await self._call_event_handler("on_audio_track_subscribed", participant_id)
|
||||
participant = self._client.room.remote_participants.get(participant_id)
|
||||
if participant:
|
||||
for publication in participant.audio_tracks.values():
|
||||
for publication in participant.audio_tracks.values(): # type: ignore[attr-defined]
|
||||
self._client._on_track_subscribed_wrapper(
|
||||
publication.track, publication, participant
|
||||
)
|
||||
@@ -1133,7 +1144,7 @@ class LiveKitTransport(BaseTransport):
|
||||
await self._call_event_handler("on_video_track_subscribed", participant_id)
|
||||
participant = self._client.room.remote_participants.get(participant_id)
|
||||
if participant:
|
||||
for publication in participant.video_tracks.values():
|
||||
for publication in participant.video_tracks.values(): # type: ignore[attr-defined]
|
||||
self._client._on_track_subscribed_wrapper(
|
||||
publication.track, publication, participant
|
||||
)
|
||||
|
||||
@@ -229,7 +229,7 @@ class TkOutputTransport(BaseOutputTransport):
|
||||
|
||||
# This holds a reference to the photo, preventing it from being garbage
|
||||
# collected.
|
||||
self._image_label.image = photo
|
||||
self._image_label.image = photo # type: ignore[assignment]
|
||||
|
||||
|
||||
class TkLocalTransport(BaseTransport):
|
||||
|
||||
@@ -15,7 +15,7 @@ import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, List, Literal, Optional, Union
|
||||
from typing import Any, List, Literal, Optional, Union, cast
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
@@ -224,9 +224,9 @@ class SmallWebRTCConnection(BaseObject):
|
||||
if not ice_servers:
|
||||
self.ice_servers: List[IceServer] = []
|
||||
elif all(isinstance(s, IceServer) for s in ice_servers):
|
||||
self.ice_servers = ice_servers
|
||||
self.ice_servers = cast(List[IceServer], ice_servers)
|
||||
elif all(isinstance(s, str) for s in ice_servers):
|
||||
self.ice_servers = [IceServer(urls=s) for s in ice_servers]
|
||||
self.ice_servers = [IceServer(urls=cast(str, s)) for s in ice_servers]
|
||||
else:
|
||||
raise TypeError("ice_servers must be either List[str] or List[RTCIceServer]")
|
||||
self._connect_invoked = False
|
||||
@@ -384,10 +384,10 @@ class SmallWebRTCConnection(BaseObject):
|
||||
# and aiortc does not handle that pretty well.
|
||||
video_input_track = self.video_input_track()
|
||||
if video_input_track:
|
||||
await self.video_input_track().discard_old_frames()
|
||||
await video_input_track.discard_old_frames()
|
||||
screen_video_input_track = self.screen_video_input_track()
|
||||
if screen_video_input_track:
|
||||
await self.screen_video_input_track().discard_old_frames()
|
||||
await screen_video_input_track.discard_old_frames()
|
||||
if video_input_track or screen_video_input_track:
|
||||
# This prevents an issue where sometimes the WebRTC connection can be established
|
||||
# before the bot is ready to receive video. When that happens, we can lose a couple
|
||||
|
||||
@@ -223,6 +223,7 @@ class SmallWebRTCRequestHandler:
|
||||
)
|
||||
|
||||
answer = pipecat_connection.get_answer()
|
||||
assert answer is not None
|
||||
|
||||
if self._esp32_mode:
|
||||
from pipecat.runner.utils import smallwebrtc_sdp_munging
|
||||
|
||||
@@ -307,7 +307,7 @@ class SmallWebRTCClient:
|
||||
if (
|
||||
self._webrtc_connection.is_connected()
|
||||
and video_track
|
||||
and video_track.is_enabled()
|
||||
and video_track.is_enabled() # type: ignore[attr-defined]
|
||||
):
|
||||
logger.warning("Timeout: No video frame received within the specified time.")
|
||||
# self._webrtc_connection.ask_to_renegotiate()
|
||||
@@ -362,7 +362,7 @@ class SmallWebRTCClient:
|
||||
if (
|
||||
self._webrtc_connection.is_connected()
|
||||
and self._audio_input_track
|
||||
and self._audio_input_track.is_enabled()
|
||||
and self._audio_input_track.is_enabled() # type: ignore[attr-defined]
|
||||
):
|
||||
logger.warning("Timeout: No audio frame received within the specified time.")
|
||||
frame = None
|
||||
@@ -375,7 +375,7 @@ class SmallWebRTCClient:
|
||||
await asyncio.sleep(0.01)
|
||||
continue
|
||||
|
||||
if frame.sample_rate > self._in_sample_rate:
|
||||
if self._in_sample_rate is not None and frame.sample_rate > self._in_sample_rate:
|
||||
resampled_frames = self._pipecat_resampler.resample(frame)
|
||||
for resampled_frame in resampled_frames:
|
||||
# 16-bit PCM bytes
|
||||
@@ -383,6 +383,7 @@ class SmallWebRTCClient:
|
||||
pcm_bytes = pcm_array.tobytes()
|
||||
del pcm_array # free NumPy array immediately
|
||||
|
||||
assert self._audio_in_channels is not None
|
||||
audio_frame = InputAudioRawFrame(
|
||||
audio=pcm_bytes,
|
||||
sample_rate=resampled_frame.sample_rate,
|
||||
@@ -398,6 +399,7 @@ class SmallWebRTCClient:
|
||||
pcm_bytes = pcm_array.tobytes()
|
||||
del pcm_array # free NumPy array immediately
|
||||
|
||||
assert self._audio_in_channels is not None
|
||||
audio_frame = InputAudioRawFrame(
|
||||
audio=pcm_bytes,
|
||||
sample_rate=frame.sample_rate,
|
||||
@@ -489,9 +491,9 @@ class SmallWebRTCClient:
|
||||
if not self._params:
|
||||
return
|
||||
|
||||
self._audio_input_track = self._webrtc_connection.audio_input_track()
|
||||
self._video_input_track = self._webrtc_connection.video_input_track()
|
||||
self._screen_video_track = self._webrtc_connection.screen_video_input_track()
|
||||
self._audio_input_track = self._webrtc_connection.audio_input_track() # type: ignore[assignment]
|
||||
self._video_input_track = self._webrtc_connection.video_input_track() # type: ignore[assignment]
|
||||
self._screen_video_track = self._webrtc_connection.screen_video_input_track() # type: ignore[assignment]
|
||||
if self._params.audio_out_enabled:
|
||||
self._audio_output_track = RawAudioTrack(sample_rate=self._out_sample_rate)
|
||||
self._webrtc_connection.replace_audio_track(self._audio_output_track)
|
||||
|
||||
@@ -257,12 +257,14 @@ class TavusTransportClient:
|
||||
await self._client.setup(setup)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to setup TavusTransportClient: {e}")
|
||||
assert self._conversation_id is not None
|
||||
await self._api.end_conversation(self._conversation_id)
|
||||
self._conversation_id = None
|
||||
|
||||
async def cleanup(self):
|
||||
"""Cleanup client resources."""
|
||||
try:
|
||||
assert self._client is not None
|
||||
await self._client.cleanup()
|
||||
except Exception as e:
|
||||
logger.error(f"Exception during cleanup: {e}")
|
||||
@@ -294,12 +296,15 @@ class TavusTransportClient:
|
||||
frame: The start frame containing initialization parameters.
|
||||
"""
|
||||
logger.debug("TavusTransportClient start invoked!")
|
||||
assert self._client is not None
|
||||
await self._client.start(frame)
|
||||
await self._client.join()
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the client and end the conversation."""
|
||||
assert self._client is not None
|
||||
await self._client.leave()
|
||||
assert self._conversation_id is not None
|
||||
await self._api.end_conversation(self._conversation_id)
|
||||
self._conversation_id = None
|
||||
|
||||
@@ -320,6 +325,7 @@ class TavusTransportClient:
|
||||
video_source: Video source to capture from.
|
||||
color_format: Color format for video frames.
|
||||
"""
|
||||
assert self._client is not None
|
||||
await self._client.capture_participant_video(
|
||||
participant_id, callback, framerate, video_source, color_format
|
||||
)
|
||||
@@ -341,6 +347,7 @@ class TavusTransportClient:
|
||||
sample_rate: Desired sample rate for audio capture.
|
||||
callback_interval_ms: Interval between audio callbacks in milliseconds.
|
||||
"""
|
||||
assert self._client is not None
|
||||
await self._client.capture_participant_audio(
|
||||
participant_id, callback, audio_source, sample_rate, callback_interval_ms
|
||||
)
|
||||
@@ -353,6 +360,7 @@ class TavusTransportClient:
|
||||
Args:
|
||||
frame: The message frame to send.
|
||||
"""
|
||||
assert self._client is not None
|
||||
await self._client.send_message(frame)
|
||||
|
||||
@property
|
||||
@@ -362,6 +370,7 @@ class TavusTransportClient:
|
||||
Returns:
|
||||
The output sample rate in Hz.
|
||||
"""
|
||||
assert self._client is not None
|
||||
return self._client.out_sample_rate
|
||||
|
||||
@property
|
||||
@@ -371,6 +380,7 @@ class TavusTransportClient:
|
||||
Returns:
|
||||
The input sample rate in Hz.
|
||||
"""
|
||||
assert self._client is not None
|
||||
return self._client.in_sample_rate
|
||||
|
||||
async def send_interrupt_message(self) -> None:
|
||||
|
||||
@@ -132,7 +132,7 @@ class WebsocketClientSession:
|
||||
return
|
||||
|
||||
try:
|
||||
self._websocket = await websocket_connect(
|
||||
self._websocket = await websocket_connect( # type: ignore[assignment]
|
||||
uri=self._uri,
|
||||
open_timeout=10,
|
||||
additional_headers=self._params.additional_headers,
|
||||
@@ -141,7 +141,8 @@ class WebsocketClientSession:
|
||||
self._client_task_handler(),
|
||||
f"{self._transport_name}::WebsocketClientSession::_client_task_handler",
|
||||
)
|
||||
await self._callbacks.on_connected(self._websocket)
|
||||
if self._websocket:
|
||||
await self._callbacks.on_connected(self._websocket)
|
||||
except TimeoutError:
|
||||
logger.error(f"Timeout connecting to {self._uri}")
|
||||
|
||||
@@ -179,7 +180,7 @@ class WebsocketClientSession:
|
||||
Returns:
|
||||
True if the WebSocket is in connected state.
|
||||
"""
|
||||
return self._websocket.state == websockets.State.OPEN if self._websocket else False
|
||||
return self._websocket.state == websockets.State.OPEN if self._websocket else False # type: ignore[attr-defined]
|
||||
|
||||
@property
|
||||
def is_closing(self) -> bool:
|
||||
@@ -188,18 +189,22 @@ class WebsocketClientSession:
|
||||
Returns:
|
||||
True if the WebSocket is in the process of closing.
|
||||
"""
|
||||
return self._websocket.state == websockets.State.CLOSING if self._websocket else False
|
||||
return self._websocket.state == websockets.State.CLOSING if self._websocket else False # type: ignore[attr-defined]
|
||||
|
||||
async def _client_task_handler(self):
|
||||
"""Handle incoming messages from the WebSocket connection."""
|
||||
if not self._websocket:
|
||||
return
|
||||
|
||||
websocket = self._websocket
|
||||
try:
|
||||
# Handle incoming messages
|
||||
async for message in self._websocket:
|
||||
await self._callbacks.on_message(self._websocket, message)
|
||||
async for message in websocket:
|
||||
await self._callbacks.on_message(websocket, message)
|
||||
except Exception as e:
|
||||
logger.error(f"{self} exception receiving data: {e.__class__.__name__} ({e})")
|
||||
|
||||
await self._callbacks.on_disconnected(self._websocket)
|
||||
await self._callbacks.on_disconnected(websocket)
|
||||
|
||||
def __str__(self):
|
||||
"""String representation of the WebSocket client session."""
|
||||
|
||||
@@ -326,6 +326,7 @@ class FastAPIWebsocketInputTransport(BaseInputTransport):
|
||||
|
||||
async def _monitor_websocket(self):
|
||||
"""Wait for self._params.session_timeout seconds, if the websocket is still open, trigger timeout event."""
|
||||
assert self._params.session_timeout is not None
|
||||
await asyncio.sleep(self._params.session_timeout)
|
||||
await self._client.trigger_client_timeout()
|
||||
|
||||
|
||||
@@ -181,7 +181,7 @@ class WebsocketServerInputTransport(BaseInputTransport):
|
||||
async def _server_task_handler(self):
|
||||
"""Handle WebSocket server startup and client connections."""
|
||||
logger.info(f"Starting websocket server on {self._host}:{self._port}")
|
||||
async with websocket_serve(self._client_handler, self._host, self._port) as server:
|
||||
async with websocket_serve(self._client_handler, self._host, self._port) as server: # type: ignore[arg-type]
|
||||
await self._callbacks.on_websocket_ready()
|
||||
await self._stop_server_event.wait()
|
||||
|
||||
|
||||
@@ -157,6 +157,7 @@ class WhatsAppClient:
|
||||
async def _validate_whatsapp_webhook_request(self, raw_body: bytes, sha256_signature: str):
|
||||
"""Common handler for both /start and /connect endpoints."""
|
||||
# Compute HMAC SHA256 using your App Secret
|
||||
assert self._whatsapp_secret is not None
|
||||
expected_signature = hmac.new(
|
||||
key=self._whatsapp_secret.encode("utf-8"),
|
||||
msg=raw_body,
|
||||
@@ -205,6 +206,7 @@ class WhatsAppClient:
|
||||
"""
|
||||
try:
|
||||
if self._whatsapp_secret:
|
||||
assert raw_body is not None and sha256_signature is not None
|
||||
await self._validate_whatsapp_webhook_request(raw_body, sha256_signature)
|
||||
for entry in request.entry:
|
||||
for change in entry.changes:
|
||||
@@ -306,7 +308,10 @@ class WhatsAppClient:
|
||||
# Create and initialize WebRTC connection
|
||||
pipecat_connection = SmallWebRTCConnection(self._ice_servers)
|
||||
await pipecat_connection.initialize(sdp=call.session.sdp, type=call.session.sdp_type)
|
||||
sdp_answer = pipecat_connection.get_answer().get("sdp")
|
||||
answer = pipecat_connection.get_answer()
|
||||
assert answer is not None
|
||||
sdp_answer = answer.get("sdp")
|
||||
assert isinstance(sdp_answer, str)
|
||||
sdp_answer = self._filter_sdp_for_whatsapp(sdp_answer)
|
||||
|
||||
logger.debug(f"SDP answer generated for call {call.id}")
|
||||
|
||||
@@ -14,7 +14,7 @@ were interrupted mid-thought.
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Optional
|
||||
from typing import TYPE_CHECKING, Literal, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -28,6 +28,13 @@ from pipecat.frames.frames import (
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pipecat.processors.frame_processor import FrameProcessor
|
||||
|
||||
_MixinBase = FrameProcessor
|
||||
else:
|
||||
_MixinBase = object
|
||||
|
||||
# Turn completion markers
|
||||
USER_TURN_COMPLETE_MARKER = "✓"
|
||||
USER_TURN_INCOMPLETE_SHORT_MARKER = "○" # Short wait - user likely continues soon
|
||||
@@ -178,7 +185,7 @@ class UserTurnCompletionConfig:
|
||||
return self.incomplete_long_prompt or DEFAULT_INCOMPLETE_LONG_PROMPT
|
||||
|
||||
|
||||
class UserTurnCompletionLLMServiceMixin:
|
||||
class UserTurnCompletionLLMServiceMixin(_MixinBase):
|
||||
"""Mixin that adds turn completion detection to LLM services.
|
||||
|
||||
This mixin provides methods to push LLM text with turn completion detection.
|
||||
|
||||
@@ -15,7 +15,7 @@ import asyncio
|
||||
import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Coroutine, Dict, Optional, Sequence
|
||||
from typing import Awaitable, Coroutine, Dict, Optional, Sequence
|
||||
|
||||
from loguru import logger
|
||||
|
||||
@@ -56,13 +56,13 @@ class BaseTaskManager(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def create_task(self, coroutine: Coroutine, name: str) -> asyncio.Task:
|
||||
def create_task(self, coroutine: Coroutine | Awaitable, name: str) -> asyncio.Task:
|
||||
"""Creates and schedules a new asyncio Task that runs the given coroutine.
|
||||
|
||||
The task is added to a global set of created tasks.
|
||||
|
||||
Args:
|
||||
coroutine: The coroutine to be executed within the task.
|
||||
coroutine: The coroutine or awaitable to be executed within the task.
|
||||
name: The name to assign to the task for identification.
|
||||
|
||||
Returns:
|
||||
@@ -139,13 +139,13 @@ class TaskManager(BaseTaskManager):
|
||||
raise Exception("TaskManager is not setup: unable to get event loop")
|
||||
return self._params.loop
|
||||
|
||||
def create_task(self, coroutine: Coroutine, name: str) -> asyncio.Task:
|
||||
def create_task(self, coroutine: Coroutine | Awaitable, name: str) -> asyncio.Task:
|
||||
"""Creates and schedules a new asyncio Task that runs the given coroutine.
|
||||
|
||||
The task is added to a global set of created tasks.
|
||||
|
||||
Args:
|
||||
coroutine: The coroutine to be executed within the task.
|
||||
coroutine: The coroutine or awaitable to be executed within the task.
|
||||
name: The name to assign to the task for identification.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -234,7 +234,7 @@ class TextPartForConcatenation:
|
||||
includes_inter_part_spaces: bool
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}(text: [{self.text}], includes_inter_part_spaces: {self.includes_inter_part_spaces})"
|
||||
return f"{type(self).__name__}(text: [{self.text}], includes_inter_part_spaces: {self.includes_inter_part_spaces})"
|
||||
|
||||
|
||||
def concatenate_aggregated_text(text_parts: List[TextPartForConcatenation]) -> str:
|
||||
|
||||
@@ -103,9 +103,9 @@ class BaseTextAggregator(ABC):
|
||||
a string indicating the type of aggregation (e.g., 'sentence', 'word',
|
||||
'token', 'my_custom_aggregation').
|
||||
"""
|
||||
pass
|
||||
# Make this a generator to satisfy type checker
|
||||
yield # pragma: no cover
|
||||
yield Aggregation("", "") # pragma: no cover
|
||||
return # pragma: no cover
|
||||
|
||||
@abstractmethod
|
||||
async def flush(self) -> Optional[Aggregation]:
|
||||
|
||||
@@ -310,7 +310,7 @@ class PatternPairAggregator(SimpleTextAggregator):
|
||||
# Which is why we base the return on the first found.
|
||||
if start_count > end_count:
|
||||
start_index = text.find(start)
|
||||
return [start_index, pattern_info]
|
||||
return (start_index, pattern_info)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -241,4 +241,4 @@ def traceable(cls: C) -> C:
|
||||
else:
|
||||
Traceable.__init__(self, cls.__name__)
|
||||
|
||||
return TracedClass
|
||||
return TracedClass # type: ignore[return-value]
|
||||
|
||||
@@ -439,7 +439,7 @@ def add_openai_realtime_span_attributes(
|
||||
if isinstance(tool, dict) and "name" in tool:
|
||||
tool_names.append(tool["name"])
|
||||
elif hasattr(tool, "name"):
|
||||
tool_names.append(tool.name)
|
||||
tool_names.append(getattr(tool, "name"))
|
||||
elif isinstance(tool, dict) and "function" in tool and "name" in tool["function"]:
|
||||
tool_names.append(tool["function"]["name"])
|
||||
|
||||
@@ -454,7 +454,7 @@ def add_openai_realtime_span_attributes(
|
||||
if function_calls:
|
||||
call = function_calls[0]
|
||||
if hasattr(call, "name"):
|
||||
span.set_attribute("function_calls.first_name", call.name)
|
||||
span.set_attribute("function_calls.first_name", getattr(call, "name"))
|
||||
elif isinstance(call, dict) and "name" in call:
|
||||
span.set_attribute("function_calls.first_name", call["name"])
|
||||
|
||||
|
||||
@@ -467,9 +467,9 @@ def traced_llm(func: Optional[Callable] = None, *, name: Optional[str] = None) -
|
||||
# Handle system message for different services
|
||||
system_message = None
|
||||
if hasattr(context, "system"):
|
||||
system_message = context.system
|
||||
system_message = getattr(context, "system")
|
||||
elif hasattr(context, "system_message"):
|
||||
system_message = context.system_message
|
||||
system_message = getattr(context, "system_message")
|
||||
elif hasattr(self, "_system_instruction"):
|
||||
system_message = self._system_instruction
|
||||
|
||||
|
||||
Reference in New Issue
Block a user