Compare commits
33 Commits
update-mod
...
v0.0.65
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b23ca5a4a8 | ||
|
|
63a6697a90 | ||
|
|
f1e45d0f02 | ||
|
|
4ad227ca2d | ||
|
|
66cc18194b | ||
|
|
7d65132c93 | ||
|
|
db7d7a4204 | ||
|
|
7bbac11084 | ||
|
|
76c8322b57 | ||
|
|
7b1cd3523d | ||
|
|
6bd821ac9a | ||
|
|
b91780ced2 | ||
|
|
8ded666958 | ||
|
|
2490c804a5 | ||
|
|
dd8856a673 | ||
|
|
e7da08dab1 | ||
|
|
ae60d42016 | ||
|
|
50e8d82ece | ||
|
|
cc9901a82f | ||
|
|
1fd43e8a3f | ||
|
|
fdc508a1a5 | ||
|
|
37269db247 | ||
|
|
51269aabbd | ||
|
|
74ecc19e09 | ||
|
|
c6d48c16df | ||
|
|
873d84aa09 | ||
|
|
7360866c97 | ||
|
|
81f4768661 | ||
|
|
972d65f61b | ||
|
|
1da9d398e3 | ||
|
|
7358bc6428 | ||
|
|
a6af499f84 | ||
|
|
a9b551d73e |
45
CHANGELOG.md
45
CHANGELOG.md
@@ -5,10 +5,41 @@ All notable changes to **Pipecat** will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
## [0.0.65] - 2025-04-23 "Sant Jordi's release"
|
||||
|
||||
https://en.wikipedia.org/wiki/Saint_George%27s_Day_in_Catalonia
|
||||
|
||||
### Added
|
||||
|
||||
- Added automatic hangup logic to the Telnyx serializer. This feature hangs up
|
||||
the Telnyx call when an `EndFrame` or `CancelFrame` is received. It is
|
||||
enabled by default and is configurable via the `auto_hang_up` `InputParam`.
|
||||
|
||||
- Added a keepalive task to `GladiaSTTService` to prevent the websocket from
|
||||
disconnecting after 30 seconds of no audio input.
|
||||
|
||||
### Changed
|
||||
|
||||
- The `InputParams` for `ElevenLabsTTSService` and `ElevenLabsHttpTTSService`
|
||||
no longer require that `stability` and `similarity_boost` be set. You can
|
||||
individually set each param.
|
||||
|
||||
- In `TwilioFrameSerializer`, `call_sid` is Optional so as to avoid a breaking
|
||||
changed. `call_sid` is required to automatically hang up.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed an issue where `TwilioFrameSerializer` would send two hang up commands:
|
||||
one for the `EndFrame` and one for the `CancelFrame`.
|
||||
|
||||
## [0.0.64] - 2025-04-22
|
||||
|
||||
### Added
|
||||
|
||||
- Added automatic hangup logic to the Twilio serializer. This feature hangs up
|
||||
the Twilio call when an `EndFrame` or `CancelFrame` is received. It is
|
||||
enabled by default and is configurable via the `auto_hang_up` `InputParam`.
|
||||
|
||||
- Added `SmartTurnMetricsData`, which contains end-of-turn prediction metrics,
|
||||
to the `MetricsFrame`. Using `MetricsFrame`, you can now retrieve prediction
|
||||
confidence scores and processing time metrics from the smart turn analyzers.
|
||||
@@ -17,9 +48,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
`GoogleSTTService`, `GoogleTTSService`, and `GoogleVertexLLMService`.
|
||||
|
||||
- Added support for Smart Turn Detection via the `turn_analyzer` transport
|
||||
parameter. You can now choose between `SmartTurnAnalyzer()` for remote
|
||||
inference or `LocalCoreMLSmartTurnAnalyzer()` for on-device inference using
|
||||
Core ML.
|
||||
parameter. You can now choose between `HttpSmartTurnAnalyzer()` or
|
||||
`FalSmartTurnAnalyzer()` for remote inference or
|
||||
`LocalCoreMLSmartTurnAnalyzer()` for on-device inference using Core ML.
|
||||
|
||||
- `DeepgramTTSService` accepts `base_url` argument again, allowing you to
|
||||
connect to an on-prem service.
|
||||
@@ -44,6 +75,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Changed
|
||||
|
||||
- `GrokLLMService` now uses `grok-3-beta` as its default model.
|
||||
|
||||
- Daily's REST helpers now include an `eject_at_token_exp` param, which ejects
|
||||
the user when their token expires. This new parameter defaults to False.
|
||||
Also, the default value for `enable_prejoin_ui` changed to False and
|
||||
@@ -78,6 +111,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Fixed an issue where LLM input parameters were not working and applied correctly in `GoogleVertexLLMService`, causing
|
||||
unexpected behavior during inference.
|
||||
|
||||
### Other
|
||||
|
||||
- Updated the `twilio-chatbot` example to use the auto-hangup feature.
|
||||
|
||||
## [0.0.63] - 2025-04-11
|
||||
|
||||
### Added
|
||||
|
||||
@@ -96,4 +96,8 @@ PIPER_BASE_URL=...
|
||||
|
||||
# Smart turn
|
||||
LOCAL_SMART_TURN_MODEL_PATH=
|
||||
REMOTE_SMART_TURN_URL=
|
||||
FAL_SMART_TURN_API_KEY=...
|
||||
|
||||
# Twilio
|
||||
TWILIO_ACCOUNT_SID=
|
||||
TWILIO_AUTH_TOKEN=
|
||||
113
examples/foundational/38-smart-turn-fal.py
Normal file
113
examples/foundational/38-smart-turn-fal.py
Normal file
@@ -0,0 +1,113 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import os
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.smart_turn.fal_smart_turn import FalSmartTurnAnalyzer
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
|
||||
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
async def run_bot(webrtc_connection: SmallWebRTCConnection):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
transport = SmallWebRTCTransport(
|
||||
webrtc_connection=webrtc_connection,
|
||||
params=TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
vad_audio_passthrough=True,
|
||||
turn_analyzer=FalSmartTurnAnalyzer(
|
||||
api_key=os.getenv("FAL_SMART_TURN_API_KEY"), aiohttp_session=session
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
context_aggregator.user(), # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
allow_interruptions=True,
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
report_only_initial_ttfb=True,
|
||||
),
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
|
||||
@transport.event_handler("on_client_closed")
|
||||
async def on_client_closed(transport, client):
|
||||
logger.info(f"Client closed connection")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=False)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from run import main
|
||||
|
||||
main()
|
||||
@@ -1,111 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.smart_turn import SmartTurnAnalyzer
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.transports.base_transport import TransportParams
|
||||
from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
|
||||
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
async def run_bot(webrtc_connection: SmallWebRTCConnection):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
remote_smart_turn_url = os.getenv("REMOTE_SMART_TURN_URL")
|
||||
|
||||
transport = SmallWebRTCTransport(
|
||||
webrtc_connection=webrtc_connection,
|
||||
params=TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
vad_audio_passthrough=True,
|
||||
turn_analyzer=SmartTurnAnalyzer(url=remote_smart_turn_url),
|
||||
),
|
||||
)
|
||||
|
||||
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="71a7ad14-091c-4e8e-a314-022ece01c121", # British Reading Lady
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt,
|
||||
context_aggregator.user(), # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
allow_interruptions=True,
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
report_only_initial_ttfb=True,
|
||||
),
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
|
||||
@transport.event_handler("on_client_closed")
|
||||
async def on_client_closed(transport, client):
|
||||
logger.info(f"Client closed connection")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=False)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from run import main
|
||||
|
||||
main()
|
||||
@@ -9,8 +9,8 @@ import os
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.local_smart_turn import LocalCoreMLSmartTurnAnalyzer
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_coreml_smart_turn import LocalCoreMLSmartTurnAnalyzer
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
@@ -7,7 +7,6 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
@@ -18,6 +17,7 @@ import uvicorn
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import BackgroundTasks, FastAPI
|
||||
from fastapi.responses import RedirectResponse
|
||||
from loguru import logger
|
||||
from pipecat_ai_small_webrtc_prebuilt.frontend import SmallWebRTCPrebuiltUI
|
||||
|
||||
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||
@@ -25,14 +25,6 @@ from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||
# Load environment variables
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Configure logger
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(message)s",
|
||||
handlers=[logging.StreamHandler()],
|
||||
)
|
||||
logger = logging.getLogger("pipecat-server")
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Store connections by pc_id
|
||||
@@ -162,10 +154,11 @@ def main():
|
||||
parser.add_argument("--verbose", "-v", action="count", default=0)
|
||||
args = parser.parse_args()
|
||||
|
||||
logger.remove(0)
|
||||
if args.verbose:
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger.add(sys.stderr, level="TRACE")
|
||||
else:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
# Infer the bot file from the caller if not provided explicitly
|
||||
bot_file = args.bot_file
|
||||
|
||||
@@ -26,9 +26,6 @@ from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
class MirrorProcessor(FrameProcessor):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Dict
|
||||
|
||||
@@ -9,6 +15,7 @@ from bot import run_bot
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import BackgroundTasks, FastAPI
|
||||
from fastapi.responses import RedirectResponse
|
||||
from loguru import logger
|
||||
from pipecat_ai_small_webrtc_prebuilt.frontend import SmallWebRTCPrebuiltUI
|
||||
|
||||
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||
@@ -16,8 +23,6 @@ from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||
# Load environment variables
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger = logging.getLogger("pc")
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Store connections by pc_id
|
||||
@@ -81,9 +86,10 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--verbose", "-v", action="count")
|
||||
args = parser.parse_args()
|
||||
|
||||
logger.remove(0)
|
||||
if args.verbose:
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger.add(sys.stderr, level="TRACE")
|
||||
else:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
uvicorn.run(app, host=args.host, port=args.port)
|
||||
|
||||
@@ -25,9 +25,6 @@ from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
class EdgeDetectionProcessor(FrameProcessor):
|
||||
def __init__(self, camera_out_width, camera_out_height: int):
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Dict
|
||||
|
||||
@@ -9,6 +15,7 @@ from bot import run_bot
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import BackgroundTasks, FastAPI
|
||||
from fastapi.responses import RedirectResponse
|
||||
from loguru import logger
|
||||
from pipecat_ai_small_webrtc_prebuilt.frontend import SmallWebRTCPrebuiltUI
|
||||
|
||||
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||
@@ -16,8 +23,6 @@ from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||
# Load environment variables
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger = logging.getLogger("pc")
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Store connections by pc_id
|
||||
@@ -81,9 +86,10 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--verbose", "-v", action="count")
|
||||
args = parser.parse_args()
|
||||
|
||||
logger.remove(0)
|
||||
if args.verbose:
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger.add(sys.stderr, level="TRACE")
|
||||
else:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
uvicorn.run(app, host=args.host, port=args.port)
|
||||
|
||||
@@ -20,10 +20,6 @@ from pipecat.transports.network.small_webrtc import SmallWebRTCTransport
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
SYSTEM_INSTRUCTION = f"""
|
||||
"You are Gemini Chatbot, a friendly, helpful robot.
|
||||
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Dict
|
||||
|
||||
@@ -9,14 +15,13 @@ from bot import run_bot
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import BackgroundTasks, FastAPI
|
||||
from fastapi.responses import FileResponse
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.transports.network.webrtc_connection import SmallWebRTCConnection
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger = logging.getLogger("pc")
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Store connections by pc_id
|
||||
@@ -73,9 +78,10 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--verbose", "-v", action="count")
|
||||
args = parser.parse_args()
|
||||
|
||||
logger.remove(0)
|
||||
if args.verbose:
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger.add(sys.stderr, level="TRACE")
|
||||
else:
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
uvicorn.run(app, host=args.host, port=args.port)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
DAILY_SAMPLE_ROOM_URL=https://yourdomain.daily.co/yourroom # (optional: for joining the bot to the same room repeatedly for local dev)
|
||||
DAILY_API_KEY=
|
||||
DAILY_API_URL=api.daily.co/v1
|
||||
DAILY_API_URL=https://api.daily.co/v1
|
||||
DEEPGRAM_API_KEY=
|
||||
OPENAI_API_KEY=
|
||||
GOOGLE_API_KEY
|
||||
CARTESIA_API_KEY=
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
pipecat-ai[daily,cartesia,openai,google,silero]
|
||||
fastapi==3.11.12
|
||||
pipecat-ai[daily,cartesia,deepgram,openai,google,silero]
|
||||
fastapi==0.115.6
|
||||
uvicorn
|
||||
python-dotenv
|
||||
twilio
|
||||
|
||||
@@ -63,20 +63,35 @@ This project is a FastAPI-based chatbot that integrates with Telnyx to handle We
|
||||
ngrok http 8765
|
||||
```
|
||||
|
||||
2. **Update the Telnyx TeXML applications Webhook**:
|
||||
2. **Purchase a number**
|
||||
|
||||
- Go to your TeXML configuration page
|
||||
- Provide the ngrok URL to the Webhook URL field and ensure the POST method is selected
|
||||
- Click Save at the bottom of the page
|
||||
If you haven't already, purchase a number from Telnyx.
|
||||
|
||||
3. **Configure streams.xml**:
|
||||
- Log in to the Telnyx developer portal: https://portal.telnyx.com/
|
||||
- Buy a number: https://portal.telnyx.com/#/numbers/buy-numbers
|
||||
|
||||
3. **Update the Telnyx TeXML applications Webhook**:
|
||||
|
||||
- Go to your TeXML configuration page: https://portal.telnyx.com/#/call-control/texml
|
||||
- Create a new TeXML app, if one doesn't exist already:
|
||||
- Add an application name
|
||||
- Under Webhooks, select POST as the "Voice Method"
|
||||
- Select "Custom URL" under Webhook URL Method
|
||||
- Enter your ngrok URL in the "Webhook URL" field (e.g. https://your-name.ngrok.io)
|
||||
- Click "Create" to save
|
||||
Note: You'll see subsequent pages to set up SIP and Outbound, both are not required, so just skip.
|
||||
- Navigate to "Manage Numbers" (https://portal.telnyx.com/#/numbers/my-numbers) and under SIP connection, select the pencil icon to edit and select the TeXML application that you just created.
|
||||
|
||||
Now your number is ready to call.
|
||||
|
||||
4. **Configure streams.xml**:
|
||||
- Copy the template file to create your local version:
|
||||
```sh
|
||||
cp templates/streams.xml.template templates/streams.xml
|
||||
```
|
||||
- In `templates/streams.xml`, replace `<your server url>` with your ngrok URL (without `https://`)
|
||||
- The final URL should look like: `wss://abc123.ngrok.io/ws`
|
||||
- The encoding (`bidirectionalCodec`) should be `PCMU` or `PCMA` depending on your needs. Based on selected encoding, set the outbound_encoding in `server.py` when the bot is initialized.
|
||||
- The final URL should look like: `wss://abc123.ngrok.io/ws`. This needs to be the same URL that you added to your TeXML app above.
|
||||
- The encoding (`bidirectionalCodec`) should be `PCMU` or `PCMA` depending on your needs. Based on selected encoding, set the outbound_encoding in `server.py` when the bot is initialized. (No changes are required by default.)
|
||||
- The inbound encoding can be controlled from the application configuration for inbound calls and dial/transfer commands for outbound calls.
|
||||
|
||||
## Running the Application
|
||||
|
||||
@@ -33,9 +33,18 @@ logger.add(sys.stderr, level="DEBUG")
|
||||
async def run_bot(
|
||||
websocket_client,
|
||||
stream_id: str,
|
||||
call_control_id: str,
|
||||
outbound_encoding: str,
|
||||
inbound_encoding: str,
|
||||
):
|
||||
serializer = TelnyxFrameSerializer(
|
||||
stream_id=stream_id,
|
||||
outbound_encoding=outbound_encoding,
|
||||
inbound_encoding=inbound_encoding,
|
||||
call_control_id=call_control_id,
|
||||
api_key=os.getenv("TELNYX_API_KEY"),
|
||||
)
|
||||
|
||||
transport = FastAPIWebsocketTransport(
|
||||
websocket=websocket_client,
|
||||
params=FastAPIWebsocketParams(
|
||||
@@ -44,7 +53,7 @@ async def run_bot(
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
vad_audio_passthrough=True,
|
||||
serializer=TelnyxFrameSerializer(stream_id, outbound_encoding, inbound_encoding),
|
||||
serializer=serializer,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -37,9 +37,10 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
call_data = json.loads(await start_data.__anext__())
|
||||
print(call_data, flush=True)
|
||||
stream_id = call_data["stream_id"]
|
||||
call_control_id = call_data["start"]["call_control_id"]
|
||||
outbound_encoding = call_data["start"]["media_format"]["encoding"]
|
||||
print("WebSocket connection accepted")
|
||||
await run_bot(websocket, stream_id, outbound_encoding, "PCMU")
|
||||
await run_bot(websocket, stream_id, call_control_id, outbound_encoding, "PCMU")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -54,7 +54,14 @@ async def save_audio(server_name: str, audio: bytes, sample_rate: int, num_chann
|
||||
logger.info("No audio data to save")
|
||||
|
||||
|
||||
async def run_bot(websocket_client: WebSocket, stream_sid: str, testing: bool):
|
||||
async def run_bot(websocket_client: WebSocket, stream_sid: str, call_sid: str, testing: bool):
|
||||
serializer = TwilioFrameSerializer(
|
||||
stream_sid=stream_sid,
|
||||
call_sid=call_sid,
|
||||
account_sid=os.getenv("TWILIO_ACCOUNT_SID", ""),
|
||||
auth_token=os.getenv("TWILIO_AUTH_TOKEN", ""),
|
||||
)
|
||||
|
||||
transport = FastAPIWebsocketTransport(
|
||||
websocket=websocket_client,
|
||||
params=FastAPIWebsocketParams(
|
||||
@@ -64,7 +71,7 @@ async def run_bot(websocket_client: WebSocket, stream_sid: str, testing: bool):
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
vad_audio_passthrough=True,
|
||||
serializer=TwilioFrameSerializer(stream_sid),
|
||||
serializer=serializer,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -38,8 +38,9 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
call_data = json.loads(await start_data.__anext__())
|
||||
print(call_data, flush=True)
|
||||
stream_sid = call_data["start"]["streamSid"]
|
||||
call_sid = call_data["start"]["callSid"]
|
||||
print("WebSocket connection accepted")
|
||||
await run_bot(websocket, stream_sid, app.state.testing)
|
||||
await run_bot(websocket, stream_sid, call_sid, app.state.testing)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -71,7 +71,7 @@ class BaseTurnAnalyzer(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def analyze_end_of_turn(self) -> Tuple[EndOfTurnState, Optional[MetricsData]]:
|
||||
async def analyze_end_of_turn(self) -> Tuple[EndOfTurnState, Optional[MetricsData]]:
|
||||
"""Analyzes if an end of turn has occurred based on the audio input.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -1,75 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
import io
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.base_smart_turn import BaseSmartTurn
|
||||
|
||||
|
||||
class SmartTurnAnalyzer(BaseSmartTurn):
|
||||
def __init__(self, url: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.remote_smart_turn_url = url
|
||||
|
||||
if not self.remote_smart_turn_url:
|
||||
logger.error("remote_smart_turn_url is not set.")
|
||||
raise Exception("remote_smart_turn_url must be provided.")
|
||||
|
||||
# Use a session to reuse connections (keep-alive)
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update({"Connection": "keep-alive"})
|
||||
|
||||
def _serialize_array(self, audio_array: np.ndarray) -> bytes:
|
||||
logger.trace("Serializing NumPy array to bytes...")
|
||||
buffer = io.BytesIO()
|
||||
np.save(buffer, audio_array)
|
||||
serialized_bytes = buffer.getvalue()
|
||||
logger.trace(f"Serialized size: {len(serialized_bytes)} bytes")
|
||||
return serialized_bytes
|
||||
|
||||
def _send_raw_request(self, data_bytes: bytes):
|
||||
headers = {"Content-Type": "application/octet-stream"}
|
||||
logger.trace(
|
||||
f"Sending {len(data_bytes)} bytes as raw body to {self.remote_smart_turn_url}..."
|
||||
)
|
||||
try:
|
||||
response = self.session.post(
|
||||
self.remote_smart_turn_url,
|
||||
data=data_bytes,
|
||||
headers=headers,
|
||||
timeout=60,
|
||||
)
|
||||
|
||||
logger.trace("\n--- Response ---")
|
||||
logger.trace(f"Status Code: {response.status_code}")
|
||||
|
||||
if response.ok:
|
||||
try:
|
||||
logger.trace("Response JSON:")
|
||||
logger.trace(response.json())
|
||||
return response.json()
|
||||
except requests.exceptions.JSONDecodeError:
|
||||
logger.trace("Response Content (non-JSON):")
|
||||
logger.trace(response.text)
|
||||
else:
|
||||
logger.trace("Response Content (Error):")
|
||||
logger.trace(response.text)
|
||||
response.raise_for_status()
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"Failed to send raw request to Daily Smart Turn: {e}")
|
||||
raise Exception("Failed to send raw request to Daily Smart Turn.")
|
||||
|
||||
def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, any]:
|
||||
serialized_array = self._serialize_array(audio_array)
|
||||
return self._send_raw_request(serialized_array)
|
||||
0
src/pipecat/audio/turn/smart_turn/__init__.py
Normal file
0
src/pipecat/audio/turn/smart_turn/__init__.py
Normal file
@@ -30,6 +30,10 @@ class SmartTurnParams(BaseModel):
|
||||
# use_only_last_vad_segment: bool = USE_ONLY_LAST_VAD_SEGMENT
|
||||
|
||||
|
||||
class SmartTurnTimeoutException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class BaseSmartTurn(BaseTurnAnalyzer):
|
||||
def __init__(
|
||||
self, *, sample_rate: Optional[int] = None, params: SmartTurnParams = SmartTurnParams()
|
||||
@@ -42,7 +46,7 @@ class BaseSmartTurn(BaseTurnAnalyzer):
|
||||
self._audio_buffer = []
|
||||
self._speech_triggered = False
|
||||
self._silence_ms = 0
|
||||
self._speech_start_time = None
|
||||
self._speech_start_time = 0
|
||||
|
||||
@property
|
||||
def speech_triggered(self) -> bool:
|
||||
@@ -60,7 +64,7 @@ class BaseSmartTurn(BaseTurnAnalyzer):
|
||||
# Reset silence tracking on speech
|
||||
self._silence_ms = 0
|
||||
self._speech_triggered = True
|
||||
if self._speech_start_time is None:
|
||||
if self._speech_start_time == 0:
|
||||
self._speech_start_time = time.time()
|
||||
else:
|
||||
if self._speech_triggered:
|
||||
@@ -87,8 +91,8 @@ class BaseSmartTurn(BaseTurnAnalyzer):
|
||||
|
||||
return state
|
||||
|
||||
def analyze_end_of_turn(self) -> Tuple[EndOfTurnState, Optional[MetricsData]]:
|
||||
state, result = self._process_speech_segment(self._audio_buffer)
|
||||
async def analyze_end_of_turn(self) -> Tuple[EndOfTurnState, Optional[MetricsData]]:
|
||||
state, result = await self._process_speech_segment(self._audio_buffer)
|
||||
if state == EndOfTurnState.COMPLETE or USE_ONLY_LAST_VAD_SEGMENT:
|
||||
self._clear(state)
|
||||
logger.debug(f"End of Turn result: {state}")
|
||||
@@ -98,10 +102,12 @@ class BaseSmartTurn(BaseTurnAnalyzer):
|
||||
# If the state is still incomplete, keep the _speech_triggered as True
|
||||
self._speech_triggered = turn_state == EndOfTurnState.INCOMPLETE
|
||||
self._audio_buffer = []
|
||||
self._speech_start_time = None
|
||||
self._speech_start_time = 0
|
||||
self._silence_ms = 0
|
||||
|
||||
def _process_speech_segment(self, audio_buffer) -> Tuple[EndOfTurnState, Optional[MetricsData]]:
|
||||
async def _process_speech_segment(
|
||||
self, audio_buffer
|
||||
) -> Tuple[EndOfTurnState, Optional[MetricsData]]:
|
||||
state = EndOfTurnState.INCOMPLETE
|
||||
|
||||
if not audio_buffer:
|
||||
@@ -131,30 +137,41 @@ class BaseSmartTurn(BaseTurnAnalyzer):
|
||||
|
||||
if len(segment_audio) > 0:
|
||||
start_time = time.perf_counter()
|
||||
result = self._predict_endpoint(segment_audio)
|
||||
state = (
|
||||
EndOfTurnState.COMPLETE if result["prediction"] == 1 else EndOfTurnState.INCOMPLETE
|
||||
)
|
||||
end_time = time.perf_counter()
|
||||
try:
|
||||
result = await self._predict_endpoint(segment_audio)
|
||||
state = (
|
||||
EndOfTurnState.COMPLETE
|
||||
if result["prediction"] == 1
|
||||
else EndOfTurnState.INCOMPLETE
|
||||
)
|
||||
end_time = time.perf_counter()
|
||||
|
||||
# Calculate processing time
|
||||
e2e_processing_time_ms = (end_time - start_time) * 1000
|
||||
# Calculate processing time
|
||||
e2e_processing_time_ms = (end_time - start_time) * 1000
|
||||
|
||||
# Prepare the result data
|
||||
result_data = SmartTurnMetricsData(
|
||||
processor="BaseSmartTurn",
|
||||
is_complete=result["prediction"] == 1,
|
||||
probability=result["probability"],
|
||||
inference_time_ms=result.get("inference_time", 0) * 1000,
|
||||
server_total_time_ms=result.get("total_time", 0) * 1000,
|
||||
e2e_processing_time_ms=e2e_processing_time_ms,
|
||||
)
|
||||
# Prepare the result data
|
||||
result_data = SmartTurnMetricsData(
|
||||
processor="BaseSmartTurn",
|
||||
is_complete=result["prediction"] == 1,
|
||||
probability=result["probability"],
|
||||
inference_time_ms=result.get("inference_time", 0) * 1000,
|
||||
server_total_time_ms=result.get("total_time", 0) * 1000,
|
||||
e2e_processing_time_ms=e2e_processing_time_ms,
|
||||
)
|
||||
|
||||
logger.trace(
|
||||
f"Prediction: {'Complete' if result_data.is_complete else 'Incomplete'}"
|
||||
)
|
||||
logger.trace(f"Probability of complete: {result_data.probability:.4f}")
|
||||
logger.trace(f"Inference time: {result_data.inference_time_ms:.2f}ms")
|
||||
logger.trace(f"Server total time: {result_data.server_total_time_ms:.2f}ms")
|
||||
logger.trace(f"E2E processing time: {result_data.e2e_processing_time_ms:.2f}ms")
|
||||
except SmartTurnTimeoutException:
|
||||
logger.debug(
|
||||
f"End of Turn complete due to stop_secs. Silence in ms: {self._silence_ms}"
|
||||
)
|
||||
state = EndOfTurnState.COMPLETE
|
||||
|
||||
logger.trace(f"Prediction: {'Complete' if result_data.is_complete else 'Incomplete'}")
|
||||
logger.trace(f"Probability of complete: {result_data.probability:.4f}")
|
||||
logger.trace(f"Inference time: {result_data.inference_time_ms:.2f}ms")
|
||||
logger.trace(f"Server total time: {result_data.server_total_time_ms:.2f}ms")
|
||||
logger.trace(f"E2E processing time: {result_data.e2e_processing_time_ms:.2f}ms")
|
||||
else:
|
||||
logger.trace(f"params: {self._params}, stop_ms: {self._stop_ms}")
|
||||
logger.trace("Captured empty audio segment, skipping prediction.")
|
||||
@@ -162,11 +179,11 @@ class BaseSmartTurn(BaseTurnAnalyzer):
|
||||
return state, result_data
|
||||
|
||||
@abstractmethod
|
||||
def _predict_endpoint(self, buffer: np.ndarray) -> Dict[str, Any]:
|
||||
async def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]:
|
||||
"""Abstract method to predict if a turn has ended based on audio.
|
||||
|
||||
Args:
|
||||
buffer: Float32 numpy array of audio samples at 16kHz.
|
||||
audio_array: Float32 numpy array of audio samples at 16kHz.
|
||||
|
||||
Returns:
|
||||
Dictionary with:
|
||||
26
src/pipecat/audio/turn/smart_turn/fal_smart_turn.py
Normal file
26
src/pipecat/audio/turn/smart_turn/fal_smart_turn.py
Normal file
@@ -0,0 +1,26 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
|
||||
from pipecat.audio.turn.smart_turn.http_smart_turn import HttpSmartTurnAnalyzer
|
||||
|
||||
|
||||
class FalSmartTurnAnalyzer(HttpSmartTurnAnalyzer):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
url: str = "https://fal.run/fal-ai/smart-turn/raw",
|
||||
api_key: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
headers = {}
|
||||
if api_key:
|
||||
headers = {"Authorization": f"Key {api_key}"}
|
||||
super().__init__(url=url, aiohttp_session=aiohttp_session, headers=headers, **kwargs)
|
||||
80
src/pipecat/audio/turn/smart_turn/http_smart_turn.py
Normal file
80
src/pipecat/audio/turn/smart_turn/http_smart_turn.py
Normal file
@@ -0,0 +1,80 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
from typing import Any, Dict
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import BaseSmartTurn, SmartTurnTimeoutException
|
||||
|
||||
|
||||
class HttpSmartTurnAnalyzer(BaseSmartTurn):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
url: str,
|
||||
aiohttp_session: aiohttp.ClientSession,
|
||||
headers: Dict[str, str] = {},
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._url = url
|
||||
self._headers = headers
|
||||
self._aiohttp_session = aiohttp_session
|
||||
|
||||
def _serialize_array(self, audio_array: np.ndarray) -> bytes:
|
||||
logger.trace("Serializing NumPy array to bytes...")
|
||||
buffer = io.BytesIO()
|
||||
np.save(buffer, audio_array)
|
||||
serialized_bytes = buffer.getvalue()
|
||||
logger.trace(f"Serialized size: {len(serialized_bytes)} bytes")
|
||||
return serialized_bytes
|
||||
|
||||
async def _send_raw_request(self, data_bytes: bytes) -> Dict[str, Any]:
|
||||
headers = {"Content-Type": "application/octet-stream"}
|
||||
headers.update(self._headers)
|
||||
logger.trace(f"Sending {len(data_bytes)} bytes as raw body to {self._url}...")
|
||||
try:
|
||||
timeout = aiohttp.ClientTimeout(total=self._params.stop_secs)
|
||||
|
||||
async with self._aiohttp_session.post(
|
||||
self._url, data=data_bytes, headers=headers, timeout=timeout
|
||||
) as response:
|
||||
logger.trace("\n--- Response ---")
|
||||
logger.trace(f"Status Code: {response.status}")
|
||||
|
||||
if response.status == 200:
|
||||
try:
|
||||
json_data = await response.json()
|
||||
logger.trace("Response JSON:")
|
||||
logger.trace(json_data)
|
||||
return json_data
|
||||
except aiohttp.ContentTypeError:
|
||||
# Non-JSON response
|
||||
text = await response.text()
|
||||
logger.trace("Response Content (non-JSON):")
|
||||
logger.trace(text)
|
||||
raise Exception(f"Non-JSON response: {text}")
|
||||
else:
|
||||
error_text = await response.text()
|
||||
logger.trace("Response Content (Error):")
|
||||
logger.trace(error_text)
|
||||
response.raise_for_status()
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Request timed out after {self._params.stop_secs} seconds")
|
||||
raise SmartTurnTimeoutException(f"Request exceeded {self._params.stop_secs} seconds.")
|
||||
except aiohttp.ClientError as e:
|
||||
logger.error(f"Failed to send raw request to Daily Smart Turn: {e}")
|
||||
raise Exception("Failed to send raw request to Daily Smart Turn.")
|
||||
|
||||
async def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]:
|
||||
serialized_array = self._serialize_array(audio_array)
|
||||
return await self._send_raw_request(serialized_array)
|
||||
@@ -5,17 +5,16 @@
|
||||
#
|
||||
|
||||
|
||||
import os
|
||||
from typing import Dict
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.base_smart_turn import BaseSmartTurn
|
||||
from pipecat.audio.turn.smart_turn.base_smart_turn import BaseSmartTurn
|
||||
|
||||
try:
|
||||
import coremltools as ct
|
||||
import torch
|
||||
from transformers import AutoFeatureExtractor
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
@@ -26,7 +25,7 @@ except ModuleNotFoundError as e:
|
||||
|
||||
|
||||
class LocalCoreMLSmartTurnAnalyzer(BaseSmartTurn):
|
||||
def __init__(self, smart_turn_model_path: str, **kwargs):
|
||||
def __init__(self, *, smart_turn_model_path: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if not smart_turn_model_path:
|
||||
@@ -41,7 +40,7 @@ class LocalCoreMLSmartTurnAnalyzer(BaseSmartTurn):
|
||||
self._turn_model = ct.models.MLModel(core_ml_model_path)
|
||||
logger.debug("Loaded Local Smart Turn")
|
||||
|
||||
def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, any]:
|
||||
async def _predict_endpoint(self, audio_array: np.ndarray) -> Dict[str, Any]:
|
||||
inputs = self._turn_processor(
|
||||
audio_array,
|
||||
sampling_rate=16000,
|
||||
@@ -8,6 +8,8 @@ import base64
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.audio.utils import (
|
||||
@@ -19,6 +21,8 @@ from pipecat.audio.utils import (
|
||||
)
|
||||
from pipecat.frames.frames import (
|
||||
AudioRawFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InputDTMFFrame,
|
||||
@@ -30,38 +34,120 @@ from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializer
|
||||
|
||||
|
||||
class TelnyxFrameSerializer(FrameSerializer):
|
||||
"""Serializer for Telnyx WebSocket protocol.
|
||||
|
||||
This serializer handles converting between Pipecat frames and Telnyx's WebSocket
|
||||
media streams protocol. It supports audio conversion, DTMF events, and automatic
|
||||
call termination.
|
||||
|
||||
When auto_hang_up is enabled (default), the serializer will automatically terminate
|
||||
the Telnyx call when an EndFrame or CancelFrame is processed, but requires Telnyx
|
||||
credentials to be provided.
|
||||
|
||||
Attributes:
|
||||
_stream_id: The Telnyx Stream ID.
|
||||
_call_control_id: The associated Telnyx Call Control ID.
|
||||
_api_key: Telnyx API key for API access.
|
||||
_params: Configuration parameters.
|
||||
_telnyx_sample_rate: Sample rate used by Telnyx (typically 8kHz).
|
||||
_sample_rate: Input sample rate for the pipeline.
|
||||
_resampler: Audio resampler for format conversion.
|
||||
_hangup_attempted: Flag to track if hang-up has been attempted.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
telnyx_sample_rate: int = 8000 # Default Telnyx rate (8kHz)
|
||||
sample_rate: Optional[int] = None # Pipeline input rate
|
||||
"""Configuration parameters for TelnyxFrameSerializer.
|
||||
|
||||
Attributes:
|
||||
telnyx_sample_rate: Sample rate used by Telnyx, defaults to 8000 Hz.
|
||||
sample_rate: Optional override for pipeline input sample rate.
|
||||
inbound_encoding: Audio encoding for data sent to Telnyx (e.g., "PCMU").
|
||||
outbound_encoding: Audio encoding for data received from Telnyx (e.g., "PCMU").
|
||||
auto_hang_up: Whether to automatically terminate call on EndFrame.
|
||||
"""
|
||||
|
||||
telnyx_sample_rate: int = 8000
|
||||
sample_rate: Optional[int] = None
|
||||
inbound_encoding: str = "PCMU"
|
||||
outbound_encoding: str = "PCMU"
|
||||
auto_hang_up: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stream_id: str,
|
||||
outbound_encoding: str,
|
||||
inbound_encoding: str,
|
||||
call_control_id: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
params: InputParams = InputParams(),
|
||||
):
|
||||
"""Initialize the TelnyxFrameSerializer.
|
||||
|
||||
Args:
|
||||
stream_id: The Stream ID for Telnyx.
|
||||
outbound_encoding: The encoding type for outbound audio (e.g., "PCMU").
|
||||
inbound_encoding: The encoding type for inbound audio (e.g., "PCMU").
|
||||
call_control_id: The Call Control ID for the Telnyx call (optional, but required for auto hang-up).
|
||||
api_key: Your Telnyx API key (required for auto hang-up).
|
||||
params: Configuration parameters.
|
||||
"""
|
||||
self._stream_id = stream_id
|
||||
params.outbound_encoding = outbound_encoding
|
||||
params.inbound_encoding = inbound_encoding
|
||||
self._call_control_id = call_control_id
|
||||
self._api_key = api_key
|
||||
self._params = params
|
||||
|
||||
self._telnyx_sample_rate = self._params.telnyx_sample_rate
|
||||
self._sample_rate = 0 # Pipeline input rate
|
||||
|
||||
self._resampler = create_default_resampler()
|
||||
self._hangup_attempted = False
|
||||
|
||||
@property
|
||||
def type(self) -> FrameSerializerType:
|
||||
"""Gets the serializer type.
|
||||
|
||||
Returns:
|
||||
The serializer type, either TEXT or BINARY.
|
||||
"""
|
||||
return FrameSerializerType.TEXT
|
||||
|
||||
async def setup(self, frame: StartFrame):
|
||||
"""Sets up the serializer with pipeline configuration.
|
||||
|
||||
Args:
|
||||
frame: The StartFrame containing pipeline configuration.
|
||||
"""
|
||||
self._sample_rate = self._params.sample_rate or frame.audio_in_sample_rate
|
||||
|
||||
async def serialize(self, frame: Frame) -> str | bytes | None:
|
||||
if isinstance(frame, AudioRawFrame):
|
||||
"""Serializes a Pipecat frame to Telnyx WebSocket format.
|
||||
|
||||
Handles conversion of various frame types to Telnyx WebSocket messages.
|
||||
For EndFrames and CancelFrames, initiates call termination if auto_hang_up is enabled.
|
||||
|
||||
Args:
|
||||
frame: The Pipecat frame to serialize.
|
||||
|
||||
Returns:
|
||||
Serialized data as string or bytes, or None if the frame isn't handled.
|
||||
|
||||
Raises:
|
||||
ValueError: If an unsupported encoding is specified.
|
||||
"""
|
||||
if (
|
||||
self._params.auto_hang_up
|
||||
and not self._hangup_attempted
|
||||
and isinstance(frame, (EndFrame, CancelFrame))
|
||||
):
|
||||
self._hangup_attempted = True
|
||||
await self._hang_up_call()
|
||||
return None
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
answer = {"event": "clear"}
|
||||
return json.dumps(answer)
|
||||
elif isinstance(frame, AudioRawFrame):
|
||||
data = frame.audio
|
||||
|
||||
# Output: Convert PCM at frame's rate to 8kHz encoded for Telnyx
|
||||
@@ -84,11 +170,58 @@ class TelnyxFrameSerializer(FrameSerializer):
|
||||
|
||||
return json.dumps(answer)
|
||||
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
answer = {"event": "clear"}
|
||||
return json.dumps(answer)
|
||||
# Return None for unhandled frames
|
||||
return None
|
||||
|
||||
async def _hang_up_call(self):
|
||||
"""Hang up the Telnyx call using Telnyx's REST API."""
|
||||
try:
|
||||
call_control_id = self._call_control_id
|
||||
api_key = self._api_key
|
||||
|
||||
if not call_control_id or not api_key:
|
||||
logger.warning(
|
||||
"Cannot hang up Telnyx call: call_control_id and api_key must be provided"
|
||||
)
|
||||
return
|
||||
|
||||
# Telnyx API endpoint for hanging up a call
|
||||
endpoint = f"https://api.telnyx.com/v2/calls/{call_control_id}/actions/hangup"
|
||||
|
||||
# Set headers with API key
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
|
||||
|
||||
# Make the POST request to hang up the call
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(endpoint, headers=headers) as response:
|
||||
if response.status == 200:
|
||||
logger.info(f"Successfully terminated Telnyx call {call_control_id}")
|
||||
else:
|
||||
# Get the error details for better debugging
|
||||
error_text = await response.text()
|
||||
logger.error(
|
||||
f"Failed to terminate Telnyx call {call_control_id}: "
|
||||
f"Status {response.status}, Response: {error_text}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to hang up Telnyx call: {e}")
|
||||
|
||||
async def deserialize(self, data: str | bytes) -> Frame | None:
|
||||
"""Deserializes Telnyx WebSocket data to Pipecat frames.
|
||||
|
||||
Handles conversion of Telnyx media events to appropriate Pipecat frames,
|
||||
including audio data and DTMF keypresses.
|
||||
|
||||
Args:
|
||||
data: The raw WebSocket data from Telnyx.
|
||||
|
||||
Returns:
|
||||
A Pipecat frame corresponding to the Telnyx event, or None if unhandled.
|
||||
|
||||
Raises:
|
||||
ValueError: If an unsupported encoding is specified.
|
||||
"""
|
||||
message = json.loads(data)
|
||||
|
||||
if message["event"] == "media":
|
||||
|
||||
@@ -8,11 +8,14 @@ import base64
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.audio.utils import create_default_resampler, pcm_to_ulaw, ulaw_to_pcm
|
||||
from pipecat.frames.frames import (
|
||||
AudioRawFrame,
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
InputDTMFFrame,
|
||||
@@ -26,28 +29,107 @@ from pipecat.serializers.base_serializer import FrameSerializer, FrameSerializer
|
||||
|
||||
|
||||
class TwilioFrameSerializer(FrameSerializer):
|
||||
class InputParams(BaseModel):
|
||||
twilio_sample_rate: int = 8000 # Default Twilio rate (8kHz)
|
||||
sample_rate: Optional[int] = None # Pipeline input rate
|
||||
"""Serializer for Twilio Media Streams WebSocket protocol.
|
||||
|
||||
def __init__(self, stream_sid: str, params: InputParams = InputParams()):
|
||||
This serializer handles converting between Pipecat frames and Twilio's WebSocket
|
||||
media streams protocol. It supports audio conversion, DTMF events, and automatic
|
||||
call termination.
|
||||
|
||||
When auto_hang_up is enabled (default), the serializer will automatically terminate
|
||||
the Twilio call when an EndFrame or CancelFrame is processed, but requires Twilio
|
||||
credentials to be provided.
|
||||
|
||||
Attributes:
|
||||
_stream_sid: The Twilio Media Stream SID.
|
||||
_call_sid: The associated Twilio Call SID.
|
||||
_account_sid: Twilio account SID for API access.
|
||||
_auth_token: Twilio authentication token for API access.
|
||||
_params: Configuration parameters.
|
||||
_twilio_sample_rate: Sample rate used by Twilio (typically 8kHz).
|
||||
_sample_rate: Input sample rate for the pipeline.
|
||||
_resampler: Audio resampler for format conversion.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
"""Configuration parameters for TwilioFrameSerializer.
|
||||
|
||||
Attributes:
|
||||
twilio_sample_rate: Sample rate used by Twilio, defaults to 8000 Hz.
|
||||
sample_rate: Optional override for pipeline input sample rate.
|
||||
auto_hang_up: Whether to automatically terminate call on EndFrame.
|
||||
"""
|
||||
|
||||
twilio_sample_rate: int = 8000
|
||||
sample_rate: Optional[int] = None
|
||||
auto_hang_up: bool = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stream_sid: str,
|
||||
call_sid: Optional[str] = None,
|
||||
account_sid: Optional[str] = None,
|
||||
auth_token: Optional[str] = None,
|
||||
params: InputParams = InputParams(),
|
||||
):
|
||||
"""Initialize the TwilioFrameSerializer.
|
||||
|
||||
Args:
|
||||
stream_sid: The Twilio Media Stream SID.
|
||||
call_sid: The associated Twilio Call SID (optional, but required for auto hang-up).
|
||||
account_sid: Twilio account SID (required for auto hang-up).
|
||||
auth_token: Twilio auth token (required for auto hang-up).
|
||||
params: Configuration parameters.
|
||||
"""
|
||||
self._stream_sid = stream_sid
|
||||
self._call_sid = call_sid
|
||||
self._account_sid = account_sid
|
||||
self._auth_token = auth_token
|
||||
self._params = params
|
||||
|
||||
self._twilio_sample_rate = self._params.twilio_sample_rate
|
||||
self._sample_rate = 0 # Pipeline input rate
|
||||
|
||||
self._resampler = create_default_resampler()
|
||||
self._hangup_attempted = False
|
||||
|
||||
@property
|
||||
def type(self) -> FrameSerializerType:
|
||||
"""Gets the serializer type.
|
||||
|
||||
Returns:
|
||||
The serializer type, either TEXT or BINARY.
|
||||
"""
|
||||
return FrameSerializerType.TEXT
|
||||
|
||||
async def setup(self, frame: StartFrame):
|
||||
"""Sets up the serializer with pipeline configuration.
|
||||
|
||||
Args:
|
||||
frame: The StartFrame containing pipeline configuration.
|
||||
"""
|
||||
self._sample_rate = self._params.sample_rate or frame.audio_in_sample_rate
|
||||
|
||||
async def serialize(self, frame: Frame) -> str | bytes | None:
|
||||
if isinstance(frame, StartInterruptionFrame):
|
||||
"""Serializes a Pipecat frame to Twilio WebSocket format.
|
||||
|
||||
Handles conversion of various frame types to Twilio WebSocket messages.
|
||||
For EndFrames, initiates call termination if auto_hang_up is enabled.
|
||||
|
||||
Args:
|
||||
frame: The Pipecat frame to serialize.
|
||||
|
||||
Returns:
|
||||
Serialized data as string or bytes, or None if the frame isn't handled.
|
||||
"""
|
||||
if (
|
||||
self._params.auto_hang_up
|
||||
and not self._hangup_attempted
|
||||
and isinstance(frame, (EndFrame, CancelFrame))
|
||||
):
|
||||
self._hangup_attempted = True
|
||||
await self._hang_up_call()
|
||||
return None
|
||||
elif isinstance(frame, StartInterruptionFrame):
|
||||
answer = {"event": "clear", "streamSid": self._stream_sid}
|
||||
return json.dumps(answer)
|
||||
elif isinstance(frame, AudioRawFrame):
|
||||
@@ -68,7 +150,70 @@ class TwilioFrameSerializer(FrameSerializer):
|
||||
elif isinstance(frame, (TransportMessageFrame, TransportMessageUrgentFrame)):
|
||||
return json.dumps(frame.message)
|
||||
|
||||
# Return None for unhandled frames
|
||||
return None
|
||||
|
||||
async def _hang_up_call(self):
|
||||
"""Hang up the Twilio call using Twilio's REST API."""
|
||||
try:
|
||||
import aiohttp
|
||||
|
||||
account_sid = self._account_sid
|
||||
auth_token = self._auth_token
|
||||
call_sid = self._call_sid
|
||||
|
||||
if not call_sid or not account_sid or not auth_token:
|
||||
missing = []
|
||||
if not call_sid:
|
||||
missing.append("call_sid")
|
||||
if not account_sid:
|
||||
missing.append("account_sid")
|
||||
if not auth_token:
|
||||
missing.append("auth_token")
|
||||
|
||||
logger.warning(
|
||||
f"Cannot hang up Twilio call: missing required parameters: {', '.join(missing)}"
|
||||
)
|
||||
return
|
||||
|
||||
# Twilio API endpoint for updating calls
|
||||
endpoint = (
|
||||
f"https://api.twilio.com/2010-04-01/Accounts/{account_sid}/Calls/{call_sid}.json"
|
||||
)
|
||||
|
||||
# Create basic auth from account_sid and auth_token
|
||||
auth = aiohttp.BasicAuth(account_sid, auth_token)
|
||||
|
||||
# Parameters to set the call status to "completed" (hang up)
|
||||
params = {"Status": "completed"}
|
||||
|
||||
# Make the POST request to update the call
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(endpoint, auth=auth, data=params) as response:
|
||||
if response.status == 200:
|
||||
logger.info(f"Successfully terminated Twilio call {call_sid}")
|
||||
else:
|
||||
# Get the error details for better debugging
|
||||
error_text = await response.text()
|
||||
logger.error(
|
||||
f"Failed to terminate Twilio call {call_sid}: "
|
||||
f"Status {response.status}, Response: {error_text}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to hang up Twilio call: {e}")
|
||||
|
||||
async def deserialize(self, data: str | bytes) -> Frame | None:
|
||||
"""Deserializes Twilio WebSocket data to Pipecat frames.
|
||||
|
||||
Handles conversion of Twilio media events to appropriate Pipecat frames.
|
||||
|
||||
Args:
|
||||
data: The raw WebSocket data from Twilio.
|
||||
|
||||
Returns:
|
||||
A Pipecat frame corresponding to the Twilio event, or None if unhandled.
|
||||
"""
|
||||
message = json.loads(data)
|
||||
|
||||
if message["event"] == "media":
|
||||
|
||||
@@ -126,31 +126,14 @@ def build_elevenlabs_voice_settings(
|
||||
settings: Dictionary containing voice settings parameters
|
||||
|
||||
Returns:
|
||||
Dictionary of voice settings or None if required parameters are missing
|
||||
Dictionary of voice settings or None if no valid settings are provided
|
||||
"""
|
||||
voice_setting_keys = ["stability", "similarity_boost", "style", "use_speaker_boost", "speed"]
|
||||
|
||||
voice_settings = {}
|
||||
if settings["stability"] is not None and settings["similarity_boost"] is not None:
|
||||
voice_settings["stability"] = settings["stability"]
|
||||
voice_settings["similarity_boost"] = settings["similarity_boost"]
|
||||
if settings["style"] is not None:
|
||||
voice_settings["style"] = settings["style"]
|
||||
if settings["use_speaker_boost"] is not None:
|
||||
voice_settings["use_speaker_boost"] = settings["use_speaker_boost"]
|
||||
if settings["speed"] is not None:
|
||||
voice_settings["speed"] = settings["speed"]
|
||||
else:
|
||||
if settings["style"] is not None:
|
||||
logger.warning(
|
||||
"'style' is set but will not be applied because 'stability' and 'similarity_boost' are not both set."
|
||||
)
|
||||
if settings["use_speaker_boost"] is not None:
|
||||
logger.warning(
|
||||
"'use_speaker_boost' is set but will not be applied because 'stability' and 'similarity_boost' are not both set."
|
||||
)
|
||||
if settings["speed"] is not None:
|
||||
logger.warning(
|
||||
"'speed' is set but will not be applied because 'stability' and 'similarity_boost' are not both set."
|
||||
)
|
||||
for key in voice_setting_keys:
|
||||
if key in settings and settings[key] is not None:
|
||||
voice_settings[key] = settings[key]
|
||||
|
||||
return voice_settings or None
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import warnings
|
||||
@@ -224,6 +225,7 @@ class GladiaSTTService(STTService):
|
||||
self._params = params
|
||||
self._websocket = None
|
||||
self._receive_task = None
|
||||
self._keepalive_task = None
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
"""Convert pipecat Language enum to Gladia's language code."""
|
||||
@@ -287,14 +289,22 @@ class GladiaSTTService(STTService):
|
||||
self._websocket = await websockets.connect(response["url"])
|
||||
if self._websocket and not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler())
|
||||
if self._websocket and not self._keepalive_task:
|
||||
self._keepalive_task = self.create_task(self._keepalive_task_handler())
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the Gladia STT websocket connection."""
|
||||
await super().stop(frame)
|
||||
await self._send_stop_recording()
|
||||
|
||||
if self._keepalive_task:
|
||||
await self.cancel_task(self._keepalive_task)
|
||||
self._keepalive_task = None
|
||||
|
||||
if self._websocket:
|
||||
await self._websocket.close()
|
||||
self._websocket = None
|
||||
|
||||
if self._receive_task:
|
||||
await self.wait_for_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
@@ -302,7 +312,15 @@ class GladiaSTTService(STTService):
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the Gladia STT websocket connection."""
|
||||
await super().cancel(frame)
|
||||
await self._websocket.close()
|
||||
|
||||
if self._keepalive_task:
|
||||
await self.cancel_task(self._keepalive_task)
|
||||
self._keepalive_task = None
|
||||
|
||||
if self._websocket:
|
||||
await self._websocket.close()
|
||||
self._websocket = None
|
||||
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
self._receive_task = None
|
||||
@@ -341,6 +359,24 @@ class GladiaSTTService(STTService):
|
||||
if self._websocket and not self._websocket.closed:
|
||||
await self._websocket.send(json.dumps({"type": "stop_recording"}))
|
||||
|
||||
async def _keepalive_task_handler(self):
|
||||
"""Send periodic empty audio chunks to keep the connection alive."""
|
||||
try:
|
||||
while True:
|
||||
# Send keepalive every 20 seconds (Gladia times out after 30 seconds)
|
||||
await asyncio.sleep(20)
|
||||
if self._websocket and not self._websocket.closed:
|
||||
# Send an empty audio chunk as keepalive
|
||||
empty_audio = b""
|
||||
await self._send_audio(empty_audio)
|
||||
else:
|
||||
logger.debug("Websocket closed, stopping keepalive")
|
||||
break
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
logger.debug("Connection closed during keepalive")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Gladia keepalive task: {e}")
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
try:
|
||||
async for message in self._websocket:
|
||||
|
||||
@@ -42,7 +42,7 @@ class GrokLLMService(OpenAILLMService):
|
||||
Args:
|
||||
api_key (str): The API key for accessing Grok's API
|
||||
base_url (str, optional): The base URL for Grok API. Defaults to "https://api.x.ai/v1"
|
||||
model (str, optional): The model identifier to use. Defaults to "grok-2"
|
||||
model (str, optional): The model identifier to use. Defaults to "grok-3-beta"
|
||||
**kwargs: Additional keyword arguments passed to OpenAILLMService
|
||||
"""
|
||||
|
||||
@@ -51,7 +51,7 @@ class GrokLLMService(OpenAILLMService):
|
||||
*,
|
||||
api_key: str,
|
||||
base_url: str = "https://api.x.ai/v1",
|
||||
model: str = "grok-2",
|
||||
model: str = "grok-3-beta",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
|
||||
|
||||
@@ -222,12 +222,8 @@ class BaseInputTransport(FrameProcessor):
|
||||
|
||||
async def _handle_end_of_turn(self):
|
||||
if self.turn_analyzer:
|
||||
state, prediction = await self.get_event_loop().run_in_executor(
|
||||
self._executor, self.turn_analyzer.analyze_end_of_turn
|
||||
)
|
||||
|
||||
state, prediction = await self.turn_analyzer.analyze_end_of_turn()
|
||||
await self._handle_prediction_result(prediction)
|
||||
|
||||
await self._handle_end_of_turn_complete(state)
|
||||
|
||||
async def _handle_end_of_turn_complete(self, state: EndOfTurnState):
|
||||
|
||||
@@ -207,10 +207,12 @@ class FastAPIWebsocketOutputTransport(BaseOutputTransport):
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._write_frame(frame)
|
||||
await self._client.disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._write_frame(frame)
|
||||
await self._client.disconnect()
|
||||
|
||||
async def cleanup(self):
|
||||
|
||||
@@ -157,7 +157,8 @@ class WebsocketServerInputTransport(BaseInputTransport):
|
||||
self, websocket: websockets.WebSocketServerProtocol, session_timeout: int
|
||||
):
|
||||
"""Wait for session_timeout seconds, if the websocket is still open,
|
||||
trigger timeout event."""
|
||||
trigger timeout event.
|
||||
"""
|
||||
try:
|
||||
await asyncio.sleep(session_timeout)
|
||||
if not websocket.closed:
|
||||
@@ -195,6 +196,14 @@ class WebsocketServerOutputTransport(BaseOutputTransport):
|
||||
await self._params.serializer.setup(frame)
|
||||
self._send_interval = (self._audio_chunk_size / self.sample_rate) / 2
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
await super().stop(frame)
|
||||
await self._write_frame(frame)
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
await super().cancel(frame)
|
||||
await self._write_frame(frame)
|
||||
|
||||
async def cleanup(self):
|
||||
await super().cleanup()
|
||||
await self._transport.cleanup()
|
||||
|
||||
Reference in New Issue
Block a user