"""OpenAI Realtime API Service.
|
|
|
|
Provides true duplex voice conversation using OpenAI's Realtime API,
|
|
similar to active-call's RealtimeProcessor. This bypasses the need for
|
|
separate ASR/LLM/TTS services by handling everything server-side.
|
|
|
|
The Realtime API provides:
|
|
- Server-side VAD with turn detection
|
|
- Streaming speech-to-text
|
|
- Streaming LLM responses
|
|
- Streaming text-to-speech
|
|
- Function calling support
|
|
- Barge-in/interruption handling
|
|
"""

import os
import asyncio
import json
import base64
import time
from typing import Optional, Dict, Any, Callable, Awaitable, List
from dataclasses import dataclass, field
from enum import Enum
from loguru import logger

try:
    import websockets
    WEBSOCKETS_AVAILABLE = True
except ImportError:
    WEBSOCKETS_AVAILABLE = False
    logger.warning("websockets not available - Realtime API will be disabled")


class RealtimeState(Enum):
    """Realtime API connection state."""
    DISCONNECTED = "disconnected"
    CONNECTING = "connecting"
    CONNECTED = "connected"
    ERROR = "error"


@dataclass
class RealtimeConfig:
    """Configuration for OpenAI Realtime API."""

    # API Configuration
    api_key: Optional[str] = None
    model: str = "gpt-4o-realtime-preview"
    endpoint: Optional[str] = None  # For Azure or custom endpoints

    # Voice Configuration
    voice: str = "alloy"  # alloy, echo, shimmer, etc.
    instructions: str = (
        "You are a helpful, friendly voice assistant. "
        "Keep your responses concise and conversational."
    )

    # Turn Detection (Server-side VAD)
    turn_detection: Optional[Dict[str, Any]] = field(default_factory=lambda: {
        "type": "server_vad",
        "threshold": 0.5,
        "prefix_padding_ms": 300,
        "silence_duration_ms": 500
    })

    # Audio Configuration
    input_audio_format: str = "pcm16"
    output_audio_format: str = "pcm16"

    # Tools/Functions
    tools: List[Dict[str, Any]] = field(default_factory=list)


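# A sketch of a config carrying one function tool. The field layout
# ("type"/"name"/"description"/"parameters" holding a JSON Schema object)
# follows the Realtime API's session tool format; the weather tool itself
# is purely illustrative, not something this module provides:
#
#   config = RealtimeConfig(
#       voice="alloy",
#       tools=[{
#           "type": "function",
#           "name": "get_weather",
#           "description": "Look up the current weather for a city.",
#           "parameters": {
#               "type": "object",
#               "properties": {"city": {"type": "string"}},
#               "required": ["city"],
#           },
#       }],
#   )

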
class RealtimeService:
    """
    OpenAI Realtime API service for true duplex voice conversation.

    This service handles the entire voice conversation pipeline:
    1. Audio input → Server-side VAD → Speech-to-text
    2. Text → LLM processing → Response generation
    3. Response → Text-to-speech → Audio output

    Events emitted:
    - on_audio: Audio output from the assistant
    - on_transcript: Text transcript (user or assistant)
    - on_speech_started: User started speaking
    - on_speech_stopped: User stopped speaking
    - on_response_started: Assistant started responding
    - on_response_done: Assistant finished responding
    - on_function_call: Function call requested
    - on_error: Error occurred
    """

    def __init__(self, config: Optional[RealtimeConfig] = None):
        """
        Initialize Realtime API service.

        Args:
            config: Realtime configuration (uses defaults if not provided)
        """
        self.config = config or RealtimeConfig()
        self.config.api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")

        self.state = RealtimeState.DISCONNECTED
        self._ws = None
        self._receive_task: Optional[asyncio.Task] = None
        self._cancel_event = asyncio.Event()

        # Event callbacks
        self._callbacks: Dict[str, List[Callable]] = {
            "on_audio": [],
            "on_transcript": [],
            "on_speech_started": [],
            "on_speech_stopped": [],
            "on_response_started": [],
            "on_response_done": [],
            "on_function_call": [],
            "on_error": [],
            "on_interrupted": [],  # registered but not yet emitted by this service
        }

        logger.debug(f"RealtimeService initialized with model={self.config.model}")

    def on(self, event: str, callback: Callable[..., Awaitable[None]]) -> None:
        """
        Register event callback.

        Args:
            event: Event name (unknown names are silently ignored)
            callback: Async callback function
        """
        if event in self._callbacks:
            self._callbacks[event].append(callback)

    async def _emit(self, event: str, *args, **kwargs) -> None:
        """Emit event to all registered callbacks."""
        for callback in self._callbacks.get(event, []):
            try:
                await callback(*args, **kwargs)
            except Exception as e:
                logger.error(f"Event callback error ({event}): {e}")
    async def connect(self) -> None:
        """Connect to OpenAI Realtime API."""
        if not WEBSOCKETS_AVAILABLE:
            raise RuntimeError("websockets package not installed")

        if not self.config.api_key:
            raise ValueError("OpenAI API key not provided")

        self.state = RealtimeState.CONNECTING

        # Build URL
        if self.config.endpoint:
            # Azure or custom endpoint
            url = f"{self.config.endpoint}/openai/realtime?api-version=2024-10-01-preview&deployment={self.config.model}"
        else:
            # OpenAI endpoint
            url = f"wss://api.openai.com/v1/realtime?model={self.config.model}"

        # Build headers: Azure uses an "api-key" header, OpenAI uses a
        # Bearer token plus the realtime beta header
        headers = {}
        if self.config.endpoint:
            headers["api-key"] = self.config.api_key
        else:
            headers["Authorization"] = f"Bearer {self.config.api_key}"
            headers["OpenAI-Beta"] = "realtime=v1"

        try:
            logger.info(f"Connecting to Realtime API: {url}")
            # NOTE: websockets < 14 takes `extra_headers`; newer releases
            # renamed it to `additional_headers`. Pin the dependency or
            # adapt this call accordingly.
            self._ws = await websockets.connect(url, extra_headers=headers)

            # Send session configuration
            await self._configure_session()

            # Start receive loop
            self._receive_task = asyncio.create_task(self._receive_loop())

            self.state = RealtimeState.CONNECTED
            logger.info("Realtime API connected successfully")

        except Exception as e:
            self.state = RealtimeState.ERROR
            logger.error(f"Realtime API connection failed: {e}")
            raise
    async def _configure_session(self) -> None:
        """Send session configuration to server."""
        session_config = {
            "type": "session.update",
            "session": {
                "modalities": ["text", "audio"],
                "instructions": self.config.instructions,
                "voice": self.config.voice,
                "input_audio_format": self.config.input_audio_format,
                "output_audio_format": self.config.output_audio_format,
                "turn_detection": self.config.turn_detection,
                # NOTE: user-speech transcripts are only delivered if
                # input_audio_transcription is enabled here, e.g.
                # {"model": "whisper-1"}; without it the
                # conversation.item.input_audio_transcription.completed
                # event handled below never fires.
            }
        }

        if self.config.tools:
            session_config["session"]["tools"] = self.config.tools

        await self._send(session_config)
        logger.debug("Session configuration sent")
    async def _send(self, data: Dict[str, Any]) -> None:
        """Send JSON data to server."""
        if self._ws:
            await self._ws.send(json.dumps(data))

    async def send_audio(self, audio_bytes: bytes) -> None:
        """
        Send audio to the Realtime API.

        Args:
            audio_bytes: PCM audio data (16-bit, mono, 24kHz by default)
        """
        if self.state != RealtimeState.CONNECTED:
            return

        # Encode audio as base64
        audio_b64 = base64.standard_b64encode(audio_bytes).decode()

        await self._send({
            "type": "input_audio_buffer.append",
            "audio": audio_b64
        })

    async def send_text(self, text: str) -> None:
        """
        Send text input (bypassing audio).

        Args:
            text: User text input
        """
        if self.state != RealtimeState.CONNECTED:
            return

        # Create a conversation item with user text
        await self._send({
            "type": "conversation.item.create",
            "item": {
                "type": "message",
                "role": "user",
                "content": [{"type": "input_text", "text": text}]
            }
        })

        # Trigger response
        await self._send({"type": "response.create"})
    async def cancel_response(self) -> None:
        """Cancel the current response (for barge-in)."""
        if self.state != RealtimeState.CONNECTED:
            return

        await self._send({"type": "response.cancel"})
        logger.debug("Response cancelled")

    async def commit_audio(self) -> None:
        """Commit the audio buffer and trigger response.

        Only needed for push-to-talk style flows; with server-side VAD
        enabled (the default config) the server commits turns on its own.
        """
        if self.state != RealtimeState.CONNECTED:
            return

        await self._send({"type": "input_audio_buffer.commit"})
        await self._send({"type": "response.create"})
    async def clear_audio_buffer(self) -> None:
        """Clear the input audio buffer."""
        if self.state != RealtimeState.CONNECTED:
            return

        await self._send({"type": "input_audio_buffer.clear"})

    async def submit_function_result(self, call_id: str, result: str) -> None:
        """
        Submit function call result.

        Args:
            call_id: The function call ID
            result: JSON string result
        """
        if self.state != RealtimeState.CONNECTED:
            return

        await self._send({
            "type": "conversation.item.create",
            "item": {
                "type": "function_call_output",
                "call_id": call_id,
                "output": result
            }
        })

        # Trigger response with the function result
        await self._send({"type": "response.create"})

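    # Sketch of the full function-call round trip from the caller's side
    # (handler and dispatch names are illustrative, not part of this module):
    #
    #   async def handle_call(call_id: str, name: str, arguments: str) -> None:
    #       args = json.loads(arguments)
    #       result = await my_tool_dispatch(name, args)  # your own dispatcher
    #       await service.submit_function_result(call_id, json.dumps(result))
    #
    #   service.on("on_function_call", handle_call)
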
    async def _receive_loop(self) -> None:
        """Receive and process messages from the Realtime API."""
        if not self._ws:
            return

        try:
            async for message in self._ws:
                try:
                    data = json.loads(message)
                    await self._handle_event(data)
                except json.JSONDecodeError:
                    logger.warning(f"Invalid JSON received: {message[:100]}")

        except asyncio.CancelledError:
            logger.debug("Receive loop cancelled")
        except websockets.ConnectionClosed as e:
            logger.info(f"WebSocket closed: {e}")
            self.state = RealtimeState.DISCONNECTED
        except Exception as e:
            logger.error(f"Receive loop error: {e}")
            self.state = RealtimeState.ERROR
    async def _handle_event(self, data: Dict[str, Any]) -> None:
        """Handle incoming event from Realtime API."""
        event_type = data.get("type", "unknown")

        # Audio delta - streaming audio output
        if event_type == "response.audio.delta":
            if "delta" in data:
                audio_bytes = base64.standard_b64decode(data["delta"])
                await self._emit("on_audio", audio_bytes)

        # Audio transcript delta - streaming text
        elif event_type == "response.audio_transcript.delta":
            if "delta" in data:
                await self._emit("on_transcript", data["delta"], "assistant", False)

        # Audio transcript done
        elif event_type == "response.audio_transcript.done":
            if "transcript" in data:
                await self._emit("on_transcript", data["transcript"], "assistant", True)

        # Input audio transcript (user speech); requires
        # input_audio_transcription to be enabled in the session config
        elif event_type == "conversation.item.input_audio_transcription.completed":
            if "transcript" in data:
                await self._emit("on_transcript", data["transcript"], "user", True)

        # Speech started (server VAD detected speech)
        elif event_type == "input_audio_buffer.speech_started":
            await self._emit("on_speech_started", data.get("audio_start_ms", 0))

        # Speech stopped
        elif event_type == "input_audio_buffer.speech_stopped":
            await self._emit("on_speech_stopped", data.get("audio_end_ms", 0))

        # Response started
        elif event_type == "response.created":
            await self._emit("on_response_started", data.get("response", {}))

        # Response done
        elif event_type == "response.done":
            await self._emit("on_response_done", data.get("response", {}))

        # Function call
        elif event_type == "response.function_call_arguments.done":
            call_id = data.get("call_id")
            name = data.get("name")
            arguments = data.get("arguments", "{}")
            await self._emit("on_function_call", call_id, name, arguments)

        # Error
        elif event_type == "error":
            error = data.get("error", {})
            logger.error(f"Realtime API error: {error}")
            await self._emit("on_error", error)

        # Session events
        elif event_type == "session.created":
            logger.info("Session created")
        elif event_type == "session.updated":
            logger.debug("Session updated")

        else:
            logger.debug(f"Unhandled event type: {event_type}")
    async def disconnect(self) -> None:
        """Disconnect from Realtime API."""
        self._cancel_event.set()

        if self._receive_task:
            self._receive_task.cancel()
            try:
                await self._receive_task
            except asyncio.CancelledError:
                pass

        if self._ws:
            await self._ws.close()
            self._ws = None

        self.state = RealtimeState.DISCONNECTED
        logger.info("Realtime API disconnected")


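# Minimal linear-interpolation resampler for the 16 kHz → 24 kHz gap noted
# in RealtimePipeline.process_audio below. A sketch only: a production
# pipeline would typically use a proper polyphase resampler (e.g. soxr or
# scipy.signal.resample_poly) for better audio quality.
def resample_pcm16(pcm: bytes, src_rate: int = 16000, dst_rate: int = 24000) -> bytes:
    """Resample 16-bit mono PCM from src_rate to dst_rate."""
    import array

    samples = array.array("h")
    samples.frombytes(pcm)
    if src_rate == dst_rate or not samples:
        return pcm

    ratio = src_rate / dst_rate
    n_out = int(len(samples) / ratio)
    out = array.array("h", bytes(2 * n_out))
    for i in range(n_out):
        pos = i * ratio
        j = int(pos)
        frac = pos - j
        a = samples[j]
        b = samples[j + 1] if j + 1 < len(samples) else a
        out[i] = int(a + (b - a) * frac)  # interpolate between neighbours
    return out.tobytes()

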
class RealtimePipeline:
    """
    Pipeline adapter for RealtimeService.

    Provides a compatible interface with DuplexPipeline but uses
    the OpenAI Realtime API for all processing.
    """

    def __init__(
        self,
        transport,
        session_id: str,
        config: Optional[RealtimeConfig] = None
    ):
        """
        Initialize Realtime pipeline.

        Args:
            transport: Transport for sending audio/events
            session_id: Session identifier
            config: Realtime configuration
        """
        self.transport = transport
        self.session_id = session_id

        self.service = RealtimeService(config)

        # Register callbacks
        self.service.on("on_audio", self._on_audio)
        self.service.on("on_transcript", self._on_transcript)
        self.service.on("on_speech_started", self._on_speech_started)
        self.service.on("on_speech_stopped", self._on_speech_stopped)
        self.service.on("on_response_started", self._on_response_started)
        self.service.on("on_response_done", self._on_response_done)
        self.service.on("on_error", self._on_error)

        self._is_speaking = False
        self._running = True

        logger.info(f"RealtimePipeline initialized for session {session_id}")

    async def start(self) -> None:
        """Start the pipeline."""
        await self.service.connect()
    async def process_audio(self, pcm_bytes: bytes) -> None:
        """
        Process incoming audio.

        Note: the Realtime API expects 24kHz audio by default, so audio
        arriving at 16kHz must be resampled first.
        """
        if not self._running:
            return

        # TODO: resample from 16kHz to 24kHz if the transport delivers
        # 16kHz audio (see resample_pcm16 above for a minimal sketch)
        await self.service.send_audio(pcm_bytes)
    async def process_text(self, text: str) -> None:
        """Process text input."""
        if not self._running:
            return

        await self.service.send_text(text)

    async def interrupt(self) -> None:
        """Interrupt current response."""
        await self.service.cancel_response()
        await self.transport.send_event({
            "event": "interrupt",
            "trackId": self.session_id,
            "timestamp": self._get_timestamp_ms()
        })

    async def cleanup(self) -> None:
        """Cleanup resources."""
        self._running = False
        await self.service.disconnect()

    # Event handlers

    async def _on_audio(self, audio_bytes: bytes) -> None:
        """Handle audio output."""
        await self.transport.send_audio(audio_bytes)

    async def _on_transcript(self, text: str, role: str, is_final: bool) -> None:
        """Handle transcript."""
        preview = text if len(text) <= 50 else f"{text[:50]}..."
        logger.info(f"[{role.upper()}] {preview}")
    async def _on_speech_started(self, start_ms: int) -> None:
        """Handle user speech start."""
        self._is_speaking = True
        await self.transport.send_event({
            "event": "speaking",
            "trackId": self.session_id,
            "timestamp": self._get_timestamp_ms(),
            "startTime": start_ms
        })

        # Cancel any ongoing response (barge-in)
        await self.service.cancel_response()

    async def _on_speech_stopped(self, end_ms: int) -> None:
        """Handle user speech stop."""
        self._is_speaking = False
        await self.transport.send_event({
            "event": "silence",
            "trackId": self.session_id,
            "timestamp": self._get_timestamp_ms(),
            "duration": end_ms
        })

    async def _on_response_started(self, response: Dict) -> None:
        """Handle response start."""
        await self.transport.send_event({
            "event": "trackStart",
            "trackId": self.session_id,
            "timestamp": self._get_timestamp_ms()
        })

    async def _on_response_done(self, response: Dict) -> None:
        """Handle response complete."""
        await self.transport.send_event({
            "event": "trackEnd",
            "trackId": self.session_id,
            "timestamp": self._get_timestamp_ms()
        })

    async def _on_error(self, error: Dict) -> None:
        """Handle error."""
        await self.transport.send_event({
            "event": "error",
            "trackId": self.session_id,
            "timestamp": self._get_timestamp_ms(),
            "sender": "realtime",
            "error": str(error)
        })
    def _get_timestamp_ms(self) -> int:
        """Get current timestamp in milliseconds."""
        return int(time.time() * 1000)

    @property
    def is_speaking(self) -> bool:
        """Check if user is speaking."""
        return self._is_speaking
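

if __name__ == "__main__":
    # Minimal smoke test: connect, send one text turn, and print final
    # transcripts. A sketch for manual testing only - it needs
    # OPENAI_API_KEY in the environment and waits a fixed 10 s rather
    # than tracking response.done.
    async def _demo() -> None:
        service = RealtimeService()

        async def print_transcript(text: str, role: str, is_final: bool) -> None:
            if is_final:
                print(f"[{role}] {text}")

        service.on("on_transcript", print_transcript)
        await service.connect()
        await service.send_text("Say hello in one short sentence.")
        await asyncio.sleep(10)  # crude wait for the streamed reply
        await service.disconnect()

    asyncio.run(_demo())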