Add OpenAIRealtimeSTTService
This commit is contained in:
1
changelog/3656.added.md
Normal file
1
changelog/3656.added.md
Normal file
@@ -0,0 +1 @@
|
||||
- Added `OpenAIRealtimeSTTService` for real-time streaming speech-to-text using OpenAI's Realtime API WebSocket transcription sessions. Supports local VAD and server-side VAD modes, noise reduction, and automatic reconnection.
|
||||
135
examples/foundational/07g-interruptible-openai-http.py
Normal file
135
examples/foundational/07g-interruptible-openai-http.py
Normal file
@@ -0,0 +1,135 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
|
||||
import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.audio.vad.vad_analyzer import VADParams
|
||||
from pipecat.frames.frames import LLMRunFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
LLMUserAggregatorParams,
|
||||
)
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.services.openai.stt import OpenAISTTService
|
||||
from pipecat.services.openai.tts import OpenAITTSService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
from pipecat.turns.user_stop import TurnAnalyzerUserTurnStopStrategy
|
||||
from pipecat.turns.user_turn_strategies import UserTurnStrategies
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
# We use lambdas to defer transport parameter creation until the transport
|
||||
# type is selected at runtime.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
"webrtc": lambda: TransportParams(
|
||||
audio_in_enabled=True,
|
||||
audio_out_enabled=True,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = OpenAISTTService(
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
model="gpt-4o-transcribe",
|
||||
prompt="Expect words related to dogs, such as breed names.",
|
||||
)
|
||||
|
||||
tts = OpenAITTSService(api_key=os.getenv("OPENAI_API_KEY"), voice="ballad")
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are very knowledgable about dogs. Your output will be spoken aloud, so avoid special characters that can't easily be spoken, such as emojis or bullet points. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = LLMContext(messages)
|
||||
user_aggregator, assistant_aggregator = LLMContextAggregatorPair(
|
||||
context,
|
||||
user_params=LLMUserAggregatorParams(
|
||||
user_turn_strategies=UserTurnStrategies(
|
||||
stop=[TurnAnalyzerUserTurnStopStrategy(turn_analyzer=LocalSmartTurnAnalyzerV3())]
|
||||
),
|
||||
vad_analyzer=SileroVADAnalyzer(params=VADParams(stop_secs=0.2)),
|
||||
),
|
||||
)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt, # STT
|
||||
user_aggregator, # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
assistant_aggregator, # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
params=PipelineParams(
|
||||
audio_out_sample_rate=24000,
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
),
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, client):
|
||||
logger.info(f"Client connected")
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([LLMRunFrame()])
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -25,8 +25,9 @@ from pipecat.processors.aggregators.llm_response_universal import (
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.services.openai.stt import OpenAISTTService
|
||||
from pipecat.services.openai.stt import OpenAIRealtimeSTTService
|
||||
from pipecat.services.openai.tts import OpenAITTSService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
@@ -56,10 +57,15 @@ transport_params = {
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = OpenAISTTService(
|
||||
stt = OpenAIRealtimeSTTService(
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
model="gpt-4o-transcribe",
|
||||
prompt="Expect words related to dogs, such as breed names.",
|
||||
language=Language.EN,
|
||||
# Uses local VAD by default.
|
||||
# To enable server-side VAD, set turn_detection=None or
|
||||
# a dict with server_vad settings.
|
||||
# turn_detection={"type": "server_vad", "threshold": 0.5},
|
||||
)
|
||||
|
||||
tts = OpenAITTSService(api_key=os.getenv("OPENAI_API_KEY"), voice="ballad")
|
||||
|
||||
88
examples/foundational/13m-openai-transcription.py
Normal file
88
examples/foundational/13m-openai-transcription.py
Normal file
@@ -0,0 +1,88 @@
|
||||
#
|
||||
# Copyright (c) 2024-2026, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import os
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.frames.frames import Frame, TranscriptionFrame
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
from pipecat.runner.types import RunnerArguments
|
||||
from pipecat.runner.utils import create_transport
|
||||
from pipecat.services.openai.stt import OpenAIRealtimeSTTService
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
from pipecat.transports.daily.transport import DailyParams
|
||||
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
class TranscriptionLogger(FrameProcessor):
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, TranscriptionFrame):
|
||||
print(f"Transcription: {frame.text}")
|
||||
|
||||
# Push all frames through
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
# We use lambdas to defer transport parameter creation until the transport
|
||||
# type is selected at runtime.
|
||||
transport_params = {
|
||||
"daily": lambda: DailyParams(audio_in_enabled=True, vad_analyzer=SileroVADAnalyzer()),
|
||||
"twilio": lambda: FastAPIWebsocketParams(
|
||||
audio_in_enabled=True, vad_analyzer=SileroVADAnalyzer()
|
||||
),
|
||||
"webrtc": lambda: TransportParams(audio_in_enabled=True, vad_analyzer=SileroVADAnalyzer()),
|
||||
}
|
||||
|
||||
|
||||
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
|
||||
logger.info(f"Starting bot")
|
||||
|
||||
stt = OpenAIRealtimeSTTService(
|
||||
api_key=os.getenv("OPENAI_API_KEY"),
|
||||
model="gpt-4o-transcribe",
|
||||
prompt="Expect words related to dogs, such as breed names.",
|
||||
)
|
||||
|
||||
tl = TranscriptionLogger()
|
||||
|
||||
pipeline = Pipeline([transport.input(), stt, tl])
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
|
||||
)
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, client):
|
||||
logger.info(f"Client disconnected")
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner(handle_sigint=runner_args.handle_sigint)
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
async def bot(runner_args: RunnerArguments):
|
||||
"""Main bot entry point compatible with Pipecat Cloud."""
|
||||
transport = await create_transport(runner_args, transport_params)
|
||||
await run_bot(transport, runner_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from pipecat.runner.run import main
|
||||
|
||||
main()
|
||||
@@ -111,6 +111,7 @@ TESTS_07 = [
|
||||
("07f-interruptible-azure.py", EVAL_SIMPLE_MATH),
|
||||
("07f-interruptible-azure-http.py", EVAL_SIMPLE_MATH),
|
||||
("07g-interruptible-openai.py", EVAL_SIMPLE_MATH),
|
||||
("07g-interruptible-openai-http.py", EVAL_SIMPLE_MATH),
|
||||
("07h-interruptible-openpipe.py", EVAL_SIMPLE_MATH),
|
||||
("07j-interruptible-gladia.py", EVAL_SIMPLE_MATH),
|
||||
("07j-interruptible-gladia-vad.py", EVAL_SIMPLE_MATH),
|
||||
|
||||
@@ -4,12 +4,48 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""OpenAI Speech-to-Text service implementation using OpenAI's transcription API."""
|
||||
"""OpenAI Speech-to-Text service implementations.
|
||||
|
||||
from typing import Optional
|
||||
Provides two STT services:
|
||||
|
||||
- ``OpenAISTTService``: REST-based transcription using the Audio API
|
||||
(Whisper / GPT-4o).
|
||||
- ``OpenAIRealtimeSTTService``: WebSocket-based streaming transcription
|
||||
using the Realtime API in transcription-only mode.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
from typing import AsyncGenerator, Literal, Optional, Union
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat.audio.utils import create_stream_resampler
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
VADUserStartedSpeakingFrame,
|
||||
VADUserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.stt_service import WebsocketSTTService
|
||||
from pipecat.services.whisper.base_stt import BaseWhisperSTTService, Transcription
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
try:
|
||||
from websockets.asyncio.client import connect as websocket_connect
|
||||
from websockets.protocol import State
|
||||
except ModuleNotFoundError:
|
||||
websocket_connect = None
|
||||
State = None
|
||||
|
||||
|
||||
class OpenAISTTService(BaseWhisperSTTService):
|
||||
@@ -77,3 +113,528 @@ class OpenAISTTService(BaseWhisperSTTService):
|
||||
kwargs["temperature"] = self._temperature
|
||||
|
||||
return await self._client.audio.transcriptions.create(**kwargs)
|
||||
|
||||
|
||||
_OPENAI_SAMPLE_RATE = 24000
|
||||
|
||||
|
||||
class OpenAIRealtimeSTTService(WebsocketSTTService):
|
||||
"""OpenAI Realtime Speech-to-Text service using WebSocket transcription sessions.
|
||||
|
||||
Uses OpenAI's Realtime API in transcription-only mode for real-time streaming
|
||||
speech recognition with optional server-side VAD and noise reduction. The model
|
||||
does not generate conversational responses — only transcription output.
|
||||
|
||||
This service supports two VAD modes:
|
||||
|
||||
**Local VAD** (default): Disable server-side VAD and use
|
||||
a local VAD processor in the pipeline instead. When a
|
||||
``VADUserStoppedSpeakingFrame`` is received, the service commits the
|
||||
audio buffer so that the server begins transcription for the completed
|
||||
speech segment.
|
||||
|
||||
**Server-side VAD** (``turn_detection=None``): The OpenAI server performs voice-activity
|
||||
detection. The service broadcasts ``UserStartedSpeakingFrame`` and
|
||||
``UserStoppedSpeakingFrame`` when the server detects speech boundaries.
|
||||
Do **not** use a separate VAD processor in the pipeline in this mode.
|
||||
|
||||
Audio is sent as 24 kHz 16-bit mono PCM as required by the OpenAI Realtime
|
||||
API. If the pipeline runs at a different sample rate (e.g. 16 kHz for Silero
|
||||
VAD compatibility), audio is automatically upsampled before sending.
|
||||
|
||||
Example::
|
||||
|
||||
stt = OpenAIRealtimeSTTService(
|
||||
api_key="sk-...",
|
||||
model="gpt-4o-transcribe",
|
||||
noise_reduction="near_field",
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
model: str = "gpt-4o-transcribe",
|
||||
base_url: str = "wss://api.openai.com/v1/realtime",
|
||||
language: Optional[Language] = Language.EN,
|
||||
prompt: Optional[str] = None,
|
||||
turn_detection: Optional[Union[dict, Literal[False]]] = False,
|
||||
noise_reduction: Optional[Literal["near_field", "far_field"]] = None,
|
||||
should_interrupt: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the OpenAI Realtime STT service.
|
||||
|
||||
Args:
|
||||
api_key: OpenAI API key for authentication.
|
||||
model: Transcription model. Supported values are
|
||||
``"gpt-4o-transcribe"`` and ``"gpt-4o-mini-transcribe"``.
|
||||
Defaults to ``"gpt-4o-transcribe"``.
|
||||
base_url: WebSocket base URL for the Realtime API.
|
||||
Defaults to ``"wss://api.openai.com/v1/realtime"``.
|
||||
language: Language of the audio input. Defaults to English.
|
||||
prompt: Optional prompt text to guide transcription style
|
||||
or provide keyword hints.
|
||||
turn_detection: Server-side VAD configuration. Defaults to
|
||||
``False`` (disabled), which relies on a local VAD
|
||||
processor in the pipeline. Pass ``None`` to use server
|
||||
defaults (``server_vad``), or a dict with custom
|
||||
settings (e.g. ``{"type": "server_vad", "threshold": 0.5}``).
|
||||
noise_reduction: Noise reduction mode. ``"near_field"`` for
|
||||
close microphones, ``"far_field"`` for distant
|
||||
microphones, or ``None`` to disable.
|
||||
should_interrupt: Whether to interrupt bot output when
|
||||
speech is detected by server-side VAD. Only applies when
|
||||
turn detection is enabled. Defaults to True.
|
||||
**kwargs: Additional arguments passed to parent
|
||||
WebsocketSTTService.
|
||||
"""
|
||||
if websocket_connect is None:
|
||||
raise ImportError(
|
||||
"websockets is required for OpenAIRealtimeSTTService. "
|
||||
"Install it with: pip install pipecat-ai[openai]"
|
||||
)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._api_key = api_key
|
||||
self._base_url = base_url
|
||||
self.set_model_name(model)
|
||||
|
||||
self._language_code = self._language_to_code(language) if language else None
|
||||
self._prompt = prompt
|
||||
self._turn_detection = turn_detection
|
||||
self._noise_reduction = noise_reduction
|
||||
self._should_interrupt = should_interrupt
|
||||
|
||||
self._receive_task = None
|
||||
self._session_ready = False
|
||||
self._resampler = create_stream_resampler()
|
||||
|
||||
# Server-side VAD is disabled by default (turn_detection=False).
|
||||
# Set to None or a dict to enable server-side VAD.
|
||||
self._server_vad_enabled = turn_detection is not False
|
||||
|
||||
@staticmethod
|
||||
def _language_to_code(language: Language) -> str:
|
||||
"""Convert a Language enum value to an ISO-639-1 code.
|
||||
|
||||
Args:
|
||||
language: The Language enum value.
|
||||
|
||||
Returns:
|
||||
Two-letter ISO-639-1 language code.
|
||||
"""
|
||||
# Language.value is e.g. "en", "en-US", "fr", "zh".
|
||||
return language.value.split("-")[0].lower()
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
"""Check if the service can generate processing metrics.
|
||||
|
||||
Returns:
|
||||
True, as this service supports metrics generation.
|
||||
"""
|
||||
return True
|
||||
|
||||
async def set_language(self, language: Language):
|
||||
"""Set the language for speech recognition.
|
||||
|
||||
If the session is already active, sends an updated configuration
|
||||
to the server.
|
||||
|
||||
Args:
|
||||
language: The language to use for speech recognition.
|
||||
"""
|
||||
self._language_code = self._language_to_code(language)
|
||||
if self._session_ready:
|
||||
await self._send_session_update()
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the service and establish WebSocket connection.
|
||||
|
||||
Args:
|
||||
frame: The start frame triggering service initialization.
|
||||
"""
|
||||
await super().start(frame)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the service and close WebSocket connection.
|
||||
|
||||
Args:
|
||||
frame: The end frame triggering service shutdown.
|
||||
"""
|
||||
await super().stop(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the service and close WebSocket connection.
|
||||
|
||||
Args:
|
||||
frame: The cancel frame triggering service cancellation.
|
||||
"""
|
||||
await super().cancel(frame)
|
||||
await self._disconnect()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Send audio data to the transcription session.
|
||||
|
||||
Audio is streamed over the WebSocket. Transcription results arrive
|
||||
asynchronously via the receive task and are pushed as
|
||||
``InterimTranscriptionFrame`` or ``TranscriptionFrame``.
|
||||
|
||||
Args:
|
||||
audio: Raw audio bytes (16-bit mono PCM at the pipeline
|
||||
sample rate). Automatically resampled to 24 kHz.
|
||||
|
||||
Yields:
|
||||
None — results are delivered via the WebSocket receive task.
|
||||
"""
|
||||
await self._send_audio(audio)
|
||||
yield None
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process frames from the pipeline.
|
||||
|
||||
Extends the base STT service to handle local VAD events when
|
||||
server-side VAD is disabled. On ``VADUserStoppedSpeakingFrame``,
|
||||
commits the audio buffer so the server begins transcription for
|
||||
the completed speech segment.
|
||||
|
||||
Args:
|
||||
frame: The frame to process.
|
||||
direction: The direction of frame flow in the pipeline.
|
||||
"""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# Handle local VAD events when server-side VAD is disabled.
|
||||
if not self._server_vad_enabled:
|
||||
if isinstance(frame, VADUserStartedSpeakingFrame):
|
||||
await self.start_processing_metrics()
|
||||
elif isinstance(frame, VADUserStoppedSpeakingFrame):
|
||||
await self._commit_audio_buffer()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# WebSocket connection management
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _connect(self):
|
||||
"""Connect to the transcription endpoint and start receiving."""
|
||||
await super()._connect()
|
||||
await self._connect_websocket()
|
||||
if self._websocket and not self._receive_task:
|
||||
self._receive_task = self.create_task(self._receive_task_handler(self._report_error))
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect and clean up background tasks."""
|
||||
await super()._disconnect()
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task, timeout=1.0)
|
||||
self._receive_task = None
|
||||
await self._disconnect_websocket()
|
||||
|
||||
async def _connect_websocket(self):
|
||||
"""Establish the WebSocket connection to the transcription endpoint."""
|
||||
try:
|
||||
if self._websocket and self._websocket.state is State.OPEN:
|
||||
return
|
||||
|
||||
self._session_ready = False
|
||||
url = f"{self._base_url}?intent=transcription"
|
||||
self._websocket = await websocket_connect(
|
||||
uri=url,
|
||||
additional_headers={
|
||||
"Authorization": f"Bearer {self._api_key}",
|
||||
},
|
||||
)
|
||||
await self._call_event_handler("on_connected")
|
||||
except Exception as e:
|
||||
await self.push_error(
|
||||
error_msg=f"Error connecting to OpenAI Realtime STT: {e}",
|
||||
exception=e,
|
||||
)
|
||||
self._websocket = None
|
||||
|
||||
async def _disconnect_websocket(self):
|
||||
"""Close the WebSocket connection."""
|
||||
try:
|
||||
self._session_ready = False
|
||||
if self._websocket:
|
||||
await self._websocket.close()
|
||||
except Exception as e:
|
||||
await self.push_error(
|
||||
error_msg=f"Error disconnecting: {e}",
|
||||
exception=e,
|
||||
)
|
||||
finally:
|
||||
self._websocket = None
|
||||
await self._call_event_handler("on_disconnected")
|
||||
|
||||
async def _ws_send(self, message: dict):
|
||||
"""Send a JSON message over the WebSocket.
|
||||
|
||||
Args:
|
||||
message: The message dict to serialize and send.
|
||||
"""
|
||||
try:
|
||||
if not self._disconnecting and self._websocket:
|
||||
await self._websocket.send(json.dumps(message))
|
||||
except Exception as e:
|
||||
if self._disconnecting or not self._websocket:
|
||||
return
|
||||
await self.push_error(
|
||||
error_msg=f"Error sending message: {e}",
|
||||
exception=e,
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Client events
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _send_session_update(self):
|
||||
"""Send ``session.update`` to configure the transcription session."""
|
||||
transcription: dict = {"model": self.model_name}
|
||||
|
||||
if self._language_code:
|
||||
transcription["language"] = self._language_code
|
||||
|
||||
if self._prompt:
|
||||
transcription["prompt"] = self._prompt
|
||||
|
||||
input_audio: dict = {
|
||||
"format": {
|
||||
"type": "audio/pcm",
|
||||
"rate": _OPENAI_SAMPLE_RATE,
|
||||
},
|
||||
"transcription": transcription,
|
||||
}
|
||||
|
||||
# Turn detection
|
||||
if self._turn_detection is False:
|
||||
input_audio["turn_detection"] = None
|
||||
elif self._turn_detection is not None:
|
||||
input_audio["turn_detection"] = self._turn_detection
|
||||
|
||||
# Noise reduction
|
||||
if self._noise_reduction:
|
||||
input_audio["noise_reduction"] = {
|
||||
"type": self._noise_reduction,
|
||||
}
|
||||
|
||||
await self._ws_send(
|
||||
{
|
||||
"type": "session.update",
|
||||
"session": {
|
||||
"type": "transcription",
|
||||
"audio": {
|
||||
"input": input_audio,
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
async def _send_audio(self, audio: bytes):
|
||||
"""Send audio data via ``input_audio_buffer.append``.
|
||||
|
||||
Resamples from the pipeline sample rate to 24 kHz if needed.
|
||||
|
||||
Args:
|
||||
audio: Raw audio bytes at the pipeline sample rate.
|
||||
"""
|
||||
audio = await self._resampler.resample(audio, self.sample_rate, _OPENAI_SAMPLE_RATE)
|
||||
if not audio:
|
||||
return
|
||||
payload = base64.b64encode(audio).decode("utf-8")
|
||||
await self._ws_send(
|
||||
{
|
||||
"type": "input_audio_buffer.append",
|
||||
"audio": payload,
|
||||
}
|
||||
)
|
||||
|
||||
async def _commit_audio_buffer(self):
|
||||
"""Commit the current audio buffer for transcription."""
|
||||
await self._ws_send({"type": "input_audio_buffer.commit"})
|
||||
|
||||
async def _clear_audio_buffer(self):
|
||||
"""Clear the current audio buffer."""
|
||||
await self._ws_send({"type": "input_audio_buffer.clear"})
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Server event handling
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _receive_messages(self):
|
||||
"""Receive and dispatch server events from the transcription session.
|
||||
|
||||
Called by ``WebsocketService._receive_task_handler`` which wraps
|
||||
this method with automatic reconnection on connection errors.
|
||||
"""
|
||||
async for message in self._websocket:
|
||||
try:
|
||||
evt = json.loads(message)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Failed to parse WebSocket message")
|
||||
continue
|
||||
|
||||
evt_type = evt.get("type", "")
|
||||
|
||||
if evt_type == "session.created":
|
||||
await self._handle_session_created(evt)
|
||||
elif evt_type == "session.updated":
|
||||
await self._handle_session_updated(evt)
|
||||
elif evt_type == "conversation.item.input_audio_transcription.delta":
|
||||
await self._handle_transcription_delta(evt)
|
||||
elif evt_type == "conversation.item.input_audio_transcription.completed":
|
||||
await self._handle_transcription_completed(evt)
|
||||
elif evt_type == "conversation.item.input_audio_transcription.failed":
|
||||
await self._handle_transcription_failed(evt)
|
||||
elif evt_type == "input_audio_buffer.speech_started":
|
||||
await self._handle_speech_started(evt)
|
||||
elif evt_type == "input_audio_buffer.speech_stopped":
|
||||
await self._handle_speech_stopped(evt)
|
||||
elif evt_type == "input_audio_buffer.committed":
|
||||
logger.trace(f"Audio buffer committed: item_id={evt.get('item_id', '')}")
|
||||
elif evt_type == "error":
|
||||
await self._handle_error(evt)
|
||||
else:
|
||||
logger.trace(f"Unhandled event: {evt_type}")
|
||||
|
||||
async def _handle_session_created(self, evt: dict):
|
||||
"""Handle ``session.created``.
|
||||
|
||||
Sent immediately after connecting. We respond by configuring the
|
||||
session with our desired settings.
|
||||
|
||||
Args:
|
||||
evt: The session created event from the server.
|
||||
"""
|
||||
logger.debug("Transcription session created, sending configuration")
|
||||
await self._send_session_update()
|
||||
|
||||
async def _handle_session_updated(self, evt: dict):
|
||||
"""Handle ``session.updated``.
|
||||
|
||||
The session is now fully configured and ready to transcribe.
|
||||
|
||||
Args:
|
||||
evt: The session updated event from the server.
|
||||
"""
|
||||
logger.debug("Transcription session configured and ready")
|
||||
self._session_ready = True
|
||||
|
||||
async def _handle_transcription_delta(self, evt: dict):
|
||||
"""Handle incremental transcription text.
|
||||
|
||||
For ``gpt-4o-transcribe`` and ``gpt-4o-mini-transcribe``, deltas
|
||||
contain streaming partial text. For ``whisper-1``, each delta
|
||||
contains the full turn transcript.
|
||||
|
||||
Args:
|
||||
evt: The delta event from the server.
|
||||
"""
|
||||
delta = evt.get("delta", "")
|
||||
if delta:
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
delta,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
result=evt,
|
||||
)
|
||||
)
|
||||
|
||||
async def _handle_transcription_completed(self, evt: dict):
|
||||
"""Handle a completed transcription for a speech segment.
|
||||
|
||||
Pushes a ``TranscriptionFrame`` and records the result for
|
||||
tracing.
|
||||
|
||||
Args:
|
||||
evt: The completed event containing the full transcript.
|
||||
"""
|
||||
transcript = evt.get("transcript", "")
|
||||
if transcript:
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
transcript,
|
||||
self._user_id,
|
||||
time_now_iso8601(),
|
||||
result=evt,
|
||||
)
|
||||
)
|
||||
await self._handle_transcription_trace(transcript, True)
|
||||
await self.stop_processing_metrics()
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription_trace(
|
||||
self,
|
||||
transcript: str,
|
||||
is_final: bool,
|
||||
language: Optional[Language] = None,
|
||||
):
|
||||
"""Record transcription result for tracing.
|
||||
|
||||
Args:
|
||||
transcript: The transcribed text.
|
||||
is_final: Whether this is a final transcription result.
|
||||
language: Optional language of the transcription.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def _handle_speech_started(self, evt: dict):
|
||||
"""Handle server-side VAD speech start.
|
||||
|
||||
Broadcasts ``UserStartedSpeakingFrame`` and optionally triggers
|
||||
interruption of current bot output.
|
||||
|
||||
Args:
|
||||
evt: The ``input_audio_buffer.speech_started`` event.
|
||||
"""
|
||||
logger.debug("Server VAD: speech started")
|
||||
await self.broadcast_frame(UserStartedSpeakingFrame)
|
||||
if self._should_interrupt:
|
||||
await self.push_interruption_task_frame_and_wait()
|
||||
await self.start_processing_metrics()
|
||||
|
||||
async def _handle_speech_stopped(self, evt: dict):
|
||||
"""Handle server-side VAD speech stop.
|
||||
|
||||
Broadcasts ``UserStoppedSpeakingFrame``. The audio buffer is
|
||||
automatically committed by the server when VAD is enabled.
|
||||
|
||||
Args:
|
||||
evt: The ``input_audio_buffer.speech_stopped`` event.
|
||||
"""
|
||||
logger.debug("Server VAD: speech stopped")
|
||||
await self.broadcast_frame(UserStoppedSpeakingFrame)
|
||||
|
||||
async def _handle_transcription_failed(self, evt: dict):
|
||||
"""Handle a transcription failure for a speech segment.
|
||||
|
||||
Logs the error but does not treat it as fatal — the session
|
||||
remains active for subsequent turns.
|
||||
|
||||
Args:
|
||||
evt: The failed event containing error details.
|
||||
"""
|
||||
error = evt.get("error", {})
|
||||
error_msg = error.get("message", "Transcription failed")
|
||||
await self.push_error(error_msg=f"OpenAI Realtime STT error: {error_msg}")
|
||||
|
||||
async def _handle_error(self, evt: dict):
|
||||
"""Handle a fatal error from the transcription session.
|
||||
|
||||
Raises an exception so that ``WebsocketService`` can decide
|
||||
whether to attempt reconnection.
|
||||
|
||||
Args:
|
||||
evt: The error event.
|
||||
"""
|
||||
error = evt.get("error", {})
|
||||
error_msg = error.get("message", "Unknown error")
|
||||
error_code = error.get("code", "")
|
||||
msg = f"OpenAI Realtime STT error [{error_code}]: {error_msg}"
|
||||
await self.push_error(error_msg=msg)
|
||||
raise Exception(msg)
|
||||
|
||||
Reference in New Issue
Block a user