prefer Optional over to "| None"

This commit is contained in:
Aleix Conchillo Flaqué
2025-02-06 11:11:37 -08:00
parent 684764fece
commit c4dbe92b30
43 changed files with 139 additions and 136 deletions

View File

@@ -6,6 +6,7 @@
import argparse
import os
from typing import Optional
import aiohttp
@@ -18,7 +19,7 @@ async def configure(aiohttp_session: aiohttp.ClientSession):
async def configure_with_args(
aiohttp_session: aiohttp.ClientSession, parser: argparse.ArgumentParser | None = None
aiohttp_session: aiohttp.ClientSession, parser: Optional[argparse.ArgumentParser] = None
):
if not parser:
parser = argparse.ArgumentParser(description="Daily AI SDK Bot Sample")

View File

@@ -7,6 +7,7 @@
import asyncio
import os
import sys
from typing import Optional
import aiohttp
from dotenv import load_dotenv
@@ -32,7 +33,7 @@ logger.add(sys.stderr, level="DEBUG")
class UserImageRequester(FrameProcessor):
def __init__(self, participant_id: str | None = None):
def __init__(self, participant_id: Optional[str] = None):
super().__init__()
self._participant_id = participant_id

View File

@@ -7,6 +7,7 @@
import asyncio
import os
import sys
from typing import Optional
import aiohttp
from dotenv import load_dotenv
@@ -32,7 +33,7 @@ logger.add(sys.stderr, level="DEBUG")
class UserImageRequester(FrameProcessor):
def __init__(self, participant_id: str | None = None):
def __init__(self, participant_id: Optional[str] = None):
super().__init__()
self._participant_id = participant_id

View File

@@ -7,6 +7,7 @@
import asyncio
import os
import sys
from typing import Optional
import aiohttp
from dotenv import load_dotenv
@@ -32,7 +33,7 @@ logger.add(sys.stderr, level="DEBUG")
class UserImageRequester(FrameProcessor):
def __init__(self, participant_id: str | None = None):
def __init__(self, participant_id: Optional[str] = None):
super().__init__()
self._participant_id = participant_id

View File

@@ -7,6 +7,7 @@
import asyncio
import os
import sys
from typing import Optional
import aiohttp
from dotenv import load_dotenv
@@ -32,7 +33,7 @@ logger.add(sys.stderr, level="DEBUG")
class UserImageRequester(FrameProcessor):
def __init__(self, participant_id: str | None = None):
def __init__(self, participant_id: Optional[str] = None):
super().__init__()
self._participant_id = participant_id

View File

@@ -6,6 +6,7 @@
import argparse
import os
from typing import Optional
import aiohttp
@@ -18,7 +19,7 @@ async def configure(aiohttp_session: aiohttp.ClientSession):
async def configure_with_args(
aiohttp_session: aiohttp.ClientSession, parser: argparse.ArgumentParser | None = None
aiohttp_session: aiohttp.ClientSession, parser: Optional[argparse.ArgumentParser] = None
):
if not parser:
parser = argparse.ArgumentParser(description="Daily AI SDK Bot Sample")

View File

@@ -6,6 +6,7 @@
import argparse
import os
from typing import Optional
import aiohttp
@@ -18,7 +19,7 @@ async def configure(aiohttp_session: aiohttp.ClientSession):
async def configure_with_args(
aiohttp_session: aiohttp.ClientSession, parser: argparse.ArgumentParser | None = None
aiohttp_session: aiohttp.ClientSession, parser: Optional[argparse.ArgumentParser] = None
):
if not parser:
parser = argparse.ArgumentParser(description="Daily AI SDK Bot Sample")

View File

@@ -2,6 +2,7 @@ import argparse
import asyncio
import os
import sys
from typing import Optional
from dotenv import load_dotenv
from loguru import logger
@@ -42,7 +43,7 @@ async def main(
callId: str,
callDomain: str,
detect_voicemail: bool,
dialout_number: str | None,
dialout_number: Optional[str],
):
# dialin_settings are only needed if Daily's SIP URI is used
# If you are handling this via Twilio, Telnyx, set this to None

View File

@@ -6,6 +6,7 @@
import argparse
import os
from typing import Optional
import aiohttp
@@ -18,7 +19,7 @@ async def configure(aiohttp_session: aiohttp.ClientSession):
async def configure_with_args(
aiohttp_session: aiohttp.ClientSession, parser: argparse.ArgumentParser | None = None
aiohttp_session: aiohttp.ClientSession, parser: Optional[argparse.ArgumentParser] = None
):
if not parser:
parser = argparse.ArgumentParser(description="Daily AI SDK Bot Sample")

View File

@@ -48,7 +48,7 @@ class KeypadEntry(str, Enum):
STAR = "*"
def format_pts(pts: int | None):
def format_pts(pts: Optional[int]):
return nanoseconds_to_str(pts) if pts else None
@@ -126,7 +126,7 @@ class ImageRawFrame:
image: bytes
size: Tuple[int, int]
format: str | None
format: Optional[str]
#
@@ -176,7 +176,7 @@ class URLImageRawFrame(OutputImageRawFrame):
"""
url: str | None
url: Optional[str]
def __str__(self):
pts = format_pts(self.pts)
@@ -235,7 +235,7 @@ class TranscriptionFrame(TextFrame):
user_id: str
timestamp: str
language: Language | None = None
language: Optional[Language] = None
def __str__(self):
return f"{self.name}(user: {self.user_id}, text: [{self.text}], language: {self.language}, timestamp: {self.timestamp})"
@@ -250,7 +250,7 @@ class InterimTranscriptionFrame(TextFrame):
text: str
user_id: str
timestamp: str
language: Language | None = None
language: Optional[Language] = None
def __str__(self):
return f"{self.name}(user: {self.user_id}, text: [{self.text}], language: {self.language}, timestamp: {self.timestamp})"
@@ -272,7 +272,7 @@ class TranscriptionMessage:
role: Literal["user", "assistant"]
content: str
timestamp: str | None = None
timestamp: Optional[str] = None
@dataclass
@@ -674,7 +674,7 @@ class UserImageRawFrame(InputImageRawFrame):
class VisionImageRawFrame(InputImageRawFrame):
"""An image with an associated text to ask for a description of it."""
text: str | None
text: Optional[str]
def __str__(self):
pts = format_pts(self.pts)

View File

@@ -19,7 +19,7 @@ class PipelineRunner:
def __init__(
self,
*,
name: str | None = None,
name: Optional[str] = None,
handle_sigint: bool = True,
force_gc: bool = False,
loop: Optional[asyncio.AbstractEventLoop] = None,

View File

@@ -4,7 +4,7 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
from typing import List, Type
from typing import List, Optional, Type
from pipecat.frames.frames import (
Frame,
@@ -37,7 +37,7 @@ class LLMResponseAggregator(FrameProcessor):
start_frame,
end_frame,
accumulator_frame: Type[TextFrame],
interim_accumulator_frame: Type[TextFrame] | None = None,
interim_accumulator_frame: Optional[Type[TextFrame]] = None,
handle_interruptions: bool = False,
expect_stripped_words: bool = True, # if True, need to add spaces between words
):

View File

@@ -51,7 +51,7 @@ class CustomEncoder(json.JSONEncoder):
class OpenAILLMContext:
def __init__(
self,
messages: List[ChatCompletionMessageParam] | None = None,
messages: Optional[List[ChatCompletionMessageParam]] = None,
tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
):

View File

@@ -4,6 +4,8 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
from typing import Optional
from pipecat.frames.frames import (
Frame,
InterimTranscriptionFrame,
@@ -50,7 +52,7 @@ class ResponseAggregator(FrameProcessor):
start_frame,
end_frame,
accumulator_frame: TextFrame,
interim_accumulator_frame: TextFrame | None = None,
interim_accumulator_frame: Optional[TextFrame] = None,
):
super().__init__()

View File

@@ -4,7 +4,7 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
from typing import Union
from typing import Optional, Union
from loguru import logger
@@ -30,7 +30,7 @@ class LangchainProcessor(FrameProcessor):
super().__init__()
self._chain = chain
self._transcript_key = transcript_key
self._participant_id: str | None = None
self._participant_id: Optional[str] = None
def set_participant_id(self, participant_id: str):
self._participant_id = participant_id

View File

@@ -753,7 +753,7 @@ class RTVIProcessor(FrameProcessor):
super().__init__(**kwargs)
self._config = config
self._pipeline: FrameProcessor | None = None
self._pipeline: Optional[FrameProcessor] = None
self._bot_ready = False
self._client_ready = False
@@ -999,7 +999,7 @@ class RTVIProcessor(FrameProcessor):
)
await self.push_frame(frame)
async def _handle_action(self, request_id: str | None, data: RTVIActionRun):
async def _handle_action(self, request_id: Optional[str], data: RTVIActionRun):
action_id = self._action_id(data.service, data.action)
if action_id not in self._registered_actions:
await self._send_error_response(request_id, f"Action {action_id} not registered")

View File

@@ -87,7 +87,7 @@ class AssistantTranscriptProcessor(BaseTranscriptProcessor):
"""Initialize processor with aggregation state."""
super().__init__(**kwargs)
self._current_text_parts: List[str] = []
self._aggregation_start_time: Optional[str] | None = None
self._aggregation_start_time: Optional[str] = None
async def _emit_aggregated_text(self):
"""Emit aggregated text as a transcript message."""

View File

@@ -140,7 +140,7 @@ class LLMService(AIService):
self._start_callbacks = {}
# TODO-CB: callback function type
def register_function(self, function_name: str | None, callback, start_callback=None):
def register_function(self, function_name: Optional[str], callback, start_callback=None):
# Registering a function with the function_name set to None will run that callback
# for all functions
self._callbacks[function_name] = callback
@@ -148,7 +148,7 @@ class LLMService(AIService):
if start_callback:
self._start_callbacks[function_name] = start_callback
def unregister_function(self, function_name: str | None):
def unregister_function(self, function_name: Optional[str]):
del self._callbacks[function_name]
if self._start_callbacks[function_name]:
del self._start_callbacks[function_name]
@@ -190,7 +190,7 @@ class LLMService(AIService):
elif None in self._start_callbacks.keys():
return await self._start_callbacks[None](function_name, self, context)
async def request_image_frame(self, user_id: str, *, text_content: str | None = None):
async def request_image_frame(self, user_id: str, *, text_content: Optional[str] = None):
await self.push_frame(
UserImageRequestFrame(user_id=user_id, context=text_content), FrameDirection.UPSTREAM
)
@@ -254,7 +254,7 @@ class TTSService(AIService):
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
pass
def language_to_service_language(self, language: Language) -> str | None:
def language_to_service_language(self, language: Language) -> Optional[str]:
return Language(language)
async def update_setting(self, key: str, value: Any):
@@ -352,7 +352,7 @@ class TTSService(AIService):
await self.push_frame(frame, direction)
async def _process_text_frame(self, frame: TextFrame):
text: str | None = None
text: Optional[str] = None
if not self._aggregate_sentences:
text = frame.text
else:

View File

@@ -326,9 +326,9 @@ class AnthropicLLMService(LLMService):
class AnthropicLLMContext(OpenAILLMContext):
def __init__(
self,
messages: list[dict] | None = None,
tools: list[dict] | None = None,
tool_choice: dict | None = None,
messages: Optional[List[dict]] = None,
tools: Optional[List[dict]] = None,
tool_choice: Optional[dict] = None,
*,
system: Union[str, NotGiven] = NOT_GIVEN,
):

View File

@@ -46,7 +46,7 @@ class AssemblyAISTTService(STTService):
super().__init__(sample_rate=sample_rate, **kwargs)
aai.settings.api_key = api_key
self._transcriber: aai.RealtimeTranscriber | None = None
self._transcriber: Optional[aai.RealtimeTranscriber] = None
self._settings = {
"encoding": encoding,

View File

@@ -32,7 +32,7 @@ except ModuleNotFoundError as e:
raise Exception(f"Missing module: {e}")
def language_to_aws_language(language: Language) -> str | None:
def language_to_aws_language(language: Language) -> Optional[str]:
language_map = {
# Arabic
Language.AR: "arb",
@@ -154,7 +154,7 @@ class PollyTTSService(TTSService):
def can_generate_metrics(self) -> bool:
return True
def language_to_service_language(self, language: Language) -> str | None:
def language_to_service_language(self, language: Language) -> Optional[str]:
return language_to_aws_language(language)
def _construct_ssml(self, text: str) -> str:

View File

@@ -57,7 +57,7 @@ except ModuleNotFoundError as e:
raise Exception(f"Missing module: {e}")
def language_to_azure_language(language: Language) -> str | None:
def language_to_azure_language(language: Language) -> Optional[str]:
language_map = {
# Afrikaans
Language.AF: "af-ZA",
@@ -477,7 +477,7 @@ class AzureBaseTTSService(TTSService):
def can_generate_metrics(self) -> bool:
return True
def language_to_service_language(self, language: Language) -> str | None:
def language_to_service_language(self, language: Language) -> Optional[str]:
return language_to_azure_language(language)
def _construct_ssml(self, text: str) -> str:

View File

@@ -9,7 +9,7 @@ import os
import uuid
import wave
from datetime import datetime
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
import aiohttp
from loguru import logger
@@ -69,7 +69,7 @@ class CanonicalMetricsService(AIService):
api_url: str = "https://voiceapp.canonical.chat/api/v1",
assistant_speaks_first: bool = True,
output_dir: str = "recordings",
context: OpenAILLMContext | None = None,
context: Optional[OpenAILLMContext] = None,
**kwargs,
):
super().__init__(**kwargs)

View File

@@ -43,7 +43,7 @@ except ModuleNotFoundError as e:
raise Exception(f"Missing module: {e}")
def language_to_cartesia_language(language: Language) -> str | None:
def language_to_cartesia_language(language: Language) -> Optional[str]:
BASE_LANGUAGES = {
Language.DE: "de",
Language.EN: "en",
@@ -143,7 +143,7 @@ class CartesiaTTSService(WordTTSService, WebsocketService):
await super().set_model(model)
logger.info(f"Switching TTS model to: [{model}]")
def language_to_service_language(self, language: Language) -> str | None:
def language_to_service_language(self, language: Language) -> Optional[str]:
return language_to_cartesia_language(language)
def _build_msg(
@@ -358,7 +358,7 @@ class CartesiaHttpTTSService(TTSService):
def can_generate_metrics(self) -> bool:
return True
def language_to_service_language(self, language: Language) -> str | None:
def language_to_service_language(self, language: Language) -> Optional[str]:
return language_to_cartesia_language(language)
async def start(self, frame: StartFrame):

View File

@@ -55,7 +55,7 @@ ELEVENLABS_MULTILINGUAL_MODELS = {
}
def language_to_elevenlabs_language(language: Language) -> str | None:
def language_to_elevenlabs_language(language: Language) -> Optional[str]:
BASE_LANGUAGES = {
Language.AR: "ar",
Language.BG: "bg",
@@ -223,7 +223,7 @@ class ElevenLabsTTSService(WordTTSService, WebsocketService):
def can_generate_metrics(self) -> bool:
return True
def language_to_service_language(self, language: Language) -> str | None:
def language_to_service_language(self, language: Language) -> Optional[str]:
return language_to_elevenlabs_language(language)
def _set_voice_settings(self):

View File

@@ -42,7 +42,7 @@ class FalImageGenService(ImageGenService):
params: InputParams,
aiohttp_session: aiohttp.ClientSession,
model: str = "fal-ai/fast-sdxl",
key: str | None = None,
key: Optional[str] = None,
**kwargs,
):
super().__init__(**kwargs)

View File

@@ -34,7 +34,7 @@ except ModuleNotFoundError as e:
raise Exception(f"Missing module: {e}")
def language_to_gladia_language(language: Language) -> str | None:
def language_to_gladia_language(language: Language) -> Optional[str]:
BASE_LANGUAGES = {
Language.AF: "af",
Language.AM: "am",
@@ -173,7 +173,7 @@ class GladiaSTTService(STTService):
}
self._confidence = confidence
def language_to_service_language(self, language: Language) -> str | None:
def language_to_service_language(self, language: Language) -> Optional[str]:
return language_to_gladia_language(language)
async def start(self, frame: StartFrame):

View File

@@ -63,7 +63,7 @@ except ModuleNotFoundError as e:
raise Exception(f"Missing module: {e}")
def language_to_google_language(language: Language) -> str | None:
def language_to_google_language(language: Language) -> Optional[str]:
language_map = {
# Afrikaans
Language.AF: "af-ZA",
@@ -346,9 +346,9 @@ class GoogleContextAggregatorPair:
class GoogleLLMContext(OpenAILLMContext):
def __init__(
self,
messages: list[dict] | None = None,
tools: list[dict] | None = None,
tool_choice: dict | None = None,
messages: Optional[List[dict]] = None,
tools: Optional[List[dict]] = None,
tool_choice: Optional[dict] = None,
):
super().__init__(messages=messages, tools=tools, tool_choice=tool_choice)
self.system_message = None
@@ -926,7 +926,7 @@ class GoogleTTSService(TTSService):
def can_generate_metrics(self) -> bool:
return True
def language_to_service_language(self, language: Language) -> str | None:
def language_to_service_language(self, language: Language) -> Optional[str]:
return language_to_google_language(language)
def _construct_ssml(self, text: str) -> str:

View File

@@ -36,7 +36,7 @@ except ModuleNotFoundError as e:
raise Exception(f"Missing module: {e}")
def language_to_lmnt_language(language: Language) -> str | None:
def language_to_lmnt_language(language: Language) -> Optional[str]:
BASE_LANGUAGES = {
Language.DE: "de",
Language.EN: "en",
@@ -89,7 +89,7 @@ class LmntTTSService(TTSService, WebsocketService):
def can_generate_metrics(self) -> bool:
return True
def language_to_service_language(self, language: Language) -> str | None:
def language_to_service_language(self, language: Language) -> Optional[str]:
return language_to_lmnt_language(language)
async def start(self, frame: StartFrame):

View File

@@ -4,7 +4,7 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
from typing import Dict, List
from typing import Dict, List, Optional
from loguru import logger
@@ -28,11 +28,11 @@ class OpenPipeLLMService(OpenAILLMService):
self,
*,
model: str = "gpt-4o",
api_key: str | None = None,
base_url: str | None = None,
openpipe_api_key: str | None = None,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
openpipe_api_key: Optional[str] = None,
openpipe_base_url: str = "https://app.openpipe.ai/api/v1",
tags: Dict[str, str] | None = None,
tags: Optional[Dict[str, str]] = None,
**kwargs,
):
super().__init__(

View File

@@ -4,23 +4,12 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
from typing import Dict, List
from typing import Optional
from loguru import logger
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.openai import OpenAILLMService
try:
from openai import AsyncStream, OpenAI
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use OpenRouter, you need to `pip install pipecat-ai[openrouter]`. Also, set `OPENROUTER_API_KEY` environment variable."
)
raise Exception(f"Missing module: {e}")
class OpenRouterLLMService(OpenAILLMService):
"""A service for interacting with OpenRouter's API using the OpenAI-compatible interface.
@@ -38,7 +27,7 @@ class OpenRouterLLMService(OpenAILLMService):
def __init__(
self,
*,
api_key: str | None = None,
api_key: Optional[str] = None,
model: str = "openai/gpt-4o-2024-11-20",
base_url: str = "https://openrouter.ai/api/v1",
**kwargs,

View File

@@ -46,7 +46,7 @@ except ModuleNotFoundError as e:
raise Exception(f"Missing module: {e}")
def language_to_playht_language(language: Language) -> str | None:
def language_to_playht_language(language: Language) -> Optional[str]:
BASE_LANGUAGES = {
Language.AF: "afrikans",
Language.AM: "amharic",
@@ -146,7 +146,7 @@ class PlayHTTTSService(TTSService, WebsocketService):
def can_generate_metrics(self) -> bool:
return True
def language_to_service_language(self, language: Language) -> str | None:
def language_to_service_language(self, language: Language) -> Optional[str]:
return language_to_playht_language(language)
async def start(self, frame: StartFrame):
@@ -389,7 +389,7 @@ class PlayHTHttpTTSService(TTSService):
def can_generate_metrics(self) -> bool:
return True
def language_to_service_language(self, language: Language) -> str | None:
def language_to_service_language(self, language: Language) -> Optional[str]:
return language_to_playht_language(language)
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:

View File

@@ -8,7 +8,7 @@
import asyncio
from enum import Enum
from typing import AsyncGenerator
from typing import AsyncGenerator, Optional
import numpy as np
from loguru import logger
@@ -53,7 +53,7 @@ class WhisperSTTService(SegmentedSTTService):
self._compute_type = compute_type
self.set_model_name(model if isinstance(model, str) else model.value)
self._no_speech_prob = no_speech_prob
self._model: WhisperModel | None = None
self._model: Optional[WhisperModel] = None
self._load()
def can_generate_metrics(self) -> bool:

View File

@@ -29,7 +29,7 @@ from pipecat.transcriptions.language import Language
# https://github.com/coqui-ai/xtts-streaming-server
def language_to_xtts_language(language: Language) -> str | None:
def language_to_xtts_language(language: Language) -> Optional[str]:
BASE_LANGUAGES = {
Language.CS: "cs",
Language.DE: "de",
@@ -86,7 +86,7 @@ class XTTSService(TTSService):
"base_url": base_url,
}
self.set_voice(voice_id)
self._studio_speakers: Dict[str, Any] | None = None
self._studio_speakers: Optional[Dict[str, Any]] = None
self._aiohttp_session = aiohttp_session
self._resampler = create_default_resampler()
@@ -94,7 +94,7 @@ class XTTSService(TTSService):
def can_generate_metrics(self) -> bool:
return True
def language_to_service_language(self, language: Language) -> str | None:
def language_to_service_language(self, language: Language) -> Optional[str]:
return language_to_xtts_language(language)
async def start(self, frame: StartFrame):

View File

@@ -6,6 +6,7 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from loguru import logger
@@ -51,7 +52,7 @@ class BaseInputTransport(FrameProcessor):
return self._sample_rate
@property
def vad_analyzer(self) -> VADAnalyzer | None:
def vad_analyzer(self) -> Optional[VADAnalyzer]:
return self._params.vad_analyzer
async def start(self, frame: StartFrame):

View File

@@ -41,7 +41,7 @@ class TransportParams(BaseModel):
audio_in_filter: Optional[BaseAudioFilter] = None
vad_enabled: bool = False
vad_audio_passthrough: bool = False
vad_analyzer: VADAnalyzer | None = None
vad_analyzer: Optional[VADAnalyzer] = None
class BaseTransport(ABC):

View File

@@ -6,6 +6,7 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from loguru import logger
@@ -116,8 +117,8 @@ class LocalAudioTransport(BaseTransport):
self._params = params
self._pyaudio = pyaudio.PyAudio()
self._input: LocalAudioInputTransport | None = None
self._output: LocalAudioOutputTransport | None = None
self._input: Optional[LocalAudioInputTransport] = None
self._output: Optional[LocalAudioOutputTransport] = None
#
# BaseTransport

View File

@@ -7,6 +7,7 @@
import asyncio
import tkinter as tk
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
import numpy as np
from loguru import logger
@@ -145,8 +146,8 @@ class TkLocalTransport(BaseTransport):
self._params = params
self._pyaudio = pyaudio.PyAudio()
self._input: TkInputTransport | None = None
self._output: TkOutputTransport | None = None
self._input: Optional[TkInputTransport] = None
self._output: Optional[TkOutputTransport] = None
#
# BaseTransport

View File

@@ -10,7 +10,7 @@ import io
import time
import typing
import wave
from typing import Awaitable, Callable
from typing import Awaitable, Callable, Optional
from loguru import logger
from pydantic import BaseModel
@@ -44,7 +44,7 @@ except ModuleNotFoundError as e:
class FastAPIWebsocketParams(TransportParams):
add_wav_header: bool = False
serializer: FrameSerializer
session_timeout: int | None = None
session_timeout: Optional[int] = None
class FastAPIWebsocketCallbacks(BaseModel):
@@ -202,8 +202,8 @@ class FastAPIWebsocketTransport(BaseTransport):
self,
websocket: WebSocket,
params: FastAPIWebsocketParams,
input_name: str | None = None,
output_name: str | None = None,
input_name: Optional[str] = None,
output_name: Optional[str] = None,
):
super().__init__(input_name=input_name, output_name=output_name)
self._params = params

View File

@@ -59,7 +59,7 @@ class WebsocketClientSession:
self._task_manager: Optional[TaskManager] = None
self._websocket: websockets.WebSocketClientProtocol | None = None
self._websocket: Optional[websockets.WebSocketClientProtocol] = None
@property
def task_manager(self) -> TaskManager:
@@ -240,8 +240,8 @@ class WebsocketClientTransport(BaseTransport):
)
self._session = WebsocketClientSession(uri, params, callbacks, self.name)
self._input: WebsocketClientInputTransport | None = None
self._output: WebsocketClientOutputTransport | None = None
self._input: Optional[WebsocketClientInputTransport] = None
self._output: Optional[WebsocketClientOutputTransport] = None
# Register supported handlers. The user will only be able to register
# these handlers.

View File

@@ -8,7 +8,7 @@ import asyncio
import io
import time
import wave
from typing import Awaitable, Callable
from typing import Awaitable, Callable, Optional
from loguru import logger
from pydantic import BaseModel
@@ -39,7 +39,7 @@ except ModuleNotFoundError as e:
class WebsocketServerParams(TransportParams):
add_wav_header: bool = False
serializer: FrameSerializer
session_timeout: int | None = None
session_timeout: Optional[int] = None
class WebsocketServerCallbacks(BaseModel):
@@ -64,7 +64,7 @@ class WebsocketServerInputTransport(BaseInputTransport):
self._params = params
self._callbacks = callbacks
self._websocket: websockets.WebSocketServerProtocol | None = None
self._websocket: Optional[websockets.WebSocketServerProtocol] = None
self._server_task = None
@@ -158,7 +158,7 @@ class WebsocketServerOutputTransport(BaseOutputTransport):
self._params = params
self._websocket: websockets.WebSocketServerProtocol | None = None
self._websocket: Optional[websockets.WebSocketServerProtocol] = None
# write_raw_audio_frames() is called quickly, as soon as we get audio
# (e.g. from the TTS), and since this is just a network connection we
@@ -168,7 +168,7 @@ class WebsocketServerOutputTransport(BaseOutputTransport):
self._send_interval = 0
self._next_send_time = 0
async def set_client_connection(self, websocket: websockets.WebSocketServerProtocol | None):
async def set_client_connection(self, websocket: Optional[websockets.WebSocketServerProtocol]):
if self._websocket:
await self._websocket.close()
logger.warning("Only one client allowed, using new connection")
@@ -242,8 +242,8 @@ class WebsocketServerTransport(BaseTransport):
params: WebsocketServerParams,
host: str = "localhost",
port: int = 8765,
input_name: str | None = None,
output_name: str | None = None,
input_name: Optional[str] = None,
output_name: Optional[str] = None,
):
super().__init__(input_name=input_name, output_name=output_name)
self._host = host
@@ -255,9 +255,9 @@ class WebsocketServerTransport(BaseTransport):
on_client_disconnected=self._on_client_disconnected,
on_session_timeout=self._on_session_timeout,
)
self._input: WebsocketServerInputTransport | None = None
self._output: WebsocketServerOutputTransport | None = None
self._websocket: websockets.WebSocketServerProtocol | None = None
self._input: Optional[WebsocketServerInputTransport] = None
self._output: Optional[WebsocketServerOutputTransport] = None
self._websocket: Optional[websockets.WebSocketServerProtocol] = None
# Register supported handlers. The user will only be able to register
# these handlers.

View File

@@ -62,12 +62,12 @@ VAD_RESET_PERIOD_MS = 2000
@dataclass
class DailyTransportMessageFrame(TransportMessageFrame):
participant_id: str | None = None
participant_id: Optional[str] = None
@dataclass
class DailyTransportMessageUrgentFrame(TransportMessageUrgentFrame):
participant_id: str | None = None
participant_id: Optional[str] = None
class WebRTCVADAnalyzer(VADAnalyzer):
@@ -175,7 +175,7 @@ class DailyTransportClient(EventHandler):
def __init__(
self,
room_url: str,
token: str | None,
token: Optional[str],
bot_name: str,
params: DailyParams,
callbacks: DailyCallbacks,
@@ -188,7 +188,7 @@ class DailyTransportClient(EventHandler):
Daily.init()
self._room_url: str = room_url
self._token: str | None = token
self._token: Optional[str] = token
self._bot_name: str = bot_name
self._params: DailyParams = params
self._callbacks = callbacks
@@ -226,9 +226,9 @@ class DailyTransportClient(EventHandler):
self._in_sample_rate = 0
self._out_sample_rate = 0
self._camera: VirtualCameraDevice | None = None
self._mic: VirtualMicrophoneDevice | None = None
self._speaker: VirtualSpeakerDevice | None = None
self._camera: Optional[VirtualCameraDevice] = None
self._mic: Optional[VirtualMicrophoneDevice] = None
self._speaker: Optional[VirtualSpeakerDevice] = None
def _camera_name(self):
return f"camera-{self}"
@@ -257,7 +257,7 @@ class DailyTransportClient(EventHandler):
)
await future
async def read_next_audio_frame(self) -> InputAudioRawFrame | None:
async def read_next_audio_frame(self) -> Optional[InputAudioRawFrame]:
if not self._speaker:
return None
@@ -542,7 +542,7 @@ class DailyTransportClient(EventHandler):
self._client.stop_recording(stream_id, completion=completion_callback(future))
await future
async def send_prebuilt_chat_message(self, message: str, user_name: str | None = None):
async def send_prebuilt_chat_message(self, message: str, user_name: Optional[str] = None):
if not self._joined:
return
@@ -723,10 +723,10 @@ class DailyInputTransport(BaseInputTransport):
# internally to be processed.
self._audio_in_task = None
self._vad_analyzer: VADAnalyzer | None = params.vad_analyzer
self._vad_analyzer: Optional[VADAnalyzer] = params.vad_analyzer
@property
def vad_analyzer(self) -> VADAnalyzer | None:
def vad_analyzer(self) -> Optional[VADAnalyzer]:
return self._vad_analyzer
async def start(self, frame: StartFrame):
@@ -891,11 +891,11 @@ class DailyTransport(BaseTransport):
def __init__(
self,
room_url: str,
token: str | None,
token: Optional[str],
bot_name: str,
params: DailyParams = DailyParams(),
input_name: str | None = None,
output_name: str | None = None,
input_name: Optional[str] = None,
output_name: Optional[str] = None,
):
super().__init__(input_name=input_name, output_name=output_name)
@@ -926,8 +926,8 @@ class DailyTransport(BaseTransport):
self._params = params
self._client = DailyTransportClient(room_url, token, bot_name, params, callbacks, self.name)
self._input: DailyInputTransport | None = None
self._output: DailyOutputTransport | None = None
self._input: Optional[DailyInputTransport] = None
self._output: Optional[DailyOutputTransport] = None
self._other_participant_has_joined = False
@@ -1014,7 +1014,7 @@ class DailyTransport(BaseTransport):
async def stop_recording(self, stream_id=None):
await self._client.stop_recording(stream_id)
async def send_prebuilt_chat_message(self, message: str, user_name: str | None = None):
async def send_prebuilt_chat_message(self, message: str, user_name: Optional[str] = None):
"""Sends a chat message to Daily's Prebuilt main room.
Args:

View File

@@ -40,12 +40,12 @@ except ModuleNotFoundError as e:
@dataclass
class LiveKitTransportMessageFrame(TransportMessageFrame):
participant_id: str | None = None
participant_id: Optional[str] = None
@dataclass
class LiveKitTransportMessageUrgentFrame(TransportMessageUrgentFrame):
participant_id: str | None = None
participant_id: Optional[str] = None
class LiveKitParams(TransportParams):
@@ -79,12 +79,12 @@ class LiveKitTransportClient:
self._params = params
self._callbacks = callbacks
self._transport_name = transport_name
self._room: rtc.Room | None = None
self._room: Optional[rtc.Room] = None
self._participant_id: str = ""
self._connected = False
self._disconnect_counter = 0
self._audio_source: rtc.AudioSource | None = None
self._audio_track: rtc.LocalAudioTrack | None = None
self._audio_source: Optional[rtc.AudioSource] = None
self._audio_track: Optional[rtc.LocalAudioTrack] = None
self._audio_tracks = {}
self._audio_queue = asyncio.Queue()
self._other_participant_has_joined = False
@@ -172,7 +172,7 @@ class LiveKitTransportClient:
logger.info(f"Disconnected from {self._room_name}")
await self._callbacks.on_disconnected()
async def send_data(self, data: bytes, participant_id: str | None = None):
async def send_data(self, data: bytes, participant_id: Optional[str] = None):
if not self._connected:
return
@@ -349,11 +349,11 @@ class LiveKitInputTransport(BaseInputTransport):
super().__init__(params, **kwargs)
self._client = client
self._audio_in_task = None
self._vad_analyzer: VADAnalyzer | None = params.vad_analyzer
self._vad_analyzer: Optional[VADAnalyzer] = params.vad_analyzer
self._resampler = create_default_resampler()
@property
def vad_analyzer(self) -> VADAnalyzer | None:
def vad_analyzer(self) -> Optional[VADAnalyzer]:
return self._vad_analyzer
async def start(self, frame: StartFrame):
@@ -463,8 +463,8 @@ class LiveKitTransport(BaseTransport):
token: str,
room_name: str,
params: LiveKitParams = LiveKitParams(),
input_name: str | None = None,
output_name: str | None = None,
input_name: Optional[str] = None,
output_name: Optional[str] = None,
):
super().__init__(input_name=input_name, output_name=output_name)
@@ -483,8 +483,8 @@ class LiveKitTransport(BaseTransport):
self._client = LiveKitTransportClient(
url, token, room_name, self._params, callbacks, self.name
)
self._input: LiveKitInputTransport | None = None
self._output: LiveKitOutputTransport | None = None
self._input: Optional[LiveKitInputTransport] = None
self._output: Optional[LiveKitOutputTransport] = None
self._register_event_handler("on_connected")
self._register_event_handler("on_disconnected")
@@ -562,12 +562,12 @@ class LiveKitTransport(BaseTransport):
await self._input.push_app_message(data.decode(), participant_id)
await self._call_event_handler("on_data_received", data, participant_id)
async def send_message(self, message: str, participant_id: str | None = None):
async def send_message(self, message: str, participant_id: Optional[str] = None):
if self._output:
frame = LiveKitTransportMessageFrame(message=message, participant_id=participant_id)
await self._output.send_message(frame)
async def send_message_urgent(self, message: str, participant_id: str | None = None):
async def send_message_urgent(self, message: str, participant_id: Optional[str] = None):
if self._output:
frame = LiveKitTransportMessageUrgentFrame(
message=message, participant_id=participant_id