Add backend API and engine
engine/services/realtime.py
548 lines · Normal file
@@ -0,0 +1,548 @@
"""OpenAI Realtime API Service.

Provides true duplex voice conversation using OpenAI's Realtime API,
similar to active-call's RealtimeProcessor. This bypasses the need for
separate ASR/LLM/TTS services by handling everything server-side.

The Realtime API provides:
- Server-side VAD with turn detection
- Streaming speech-to-text
- Streaming LLM responses
- Streaming text-to-speech
- Function calling support
- Barge-in/interruption handling
"""

import os
import asyncio
import json
import base64
import time
from typing import Optional, Dict, Any, Callable, Awaitable, List
from dataclasses import dataclass, field
from enum import Enum
from loguru import logger

try:
    import websockets
    WEBSOCKETS_AVAILABLE = True
except ImportError:
    WEBSOCKETS_AVAILABLE = False
    logger.warning("websockets not available - Realtime API will be disabled")


class RealtimeState(Enum):
    """Realtime API connection state."""
    DISCONNECTED = "disconnected"
    CONNECTING = "connecting"
    CONNECTED = "connected"
    ERROR = "error"


@dataclass
class RealtimeConfig:
    """Configuration for OpenAI Realtime API."""

    # API Configuration
    api_key: Optional[str] = None
    model: str = "gpt-4o-realtime-preview"
    endpoint: Optional[str] = None  # For Azure or custom endpoints

    # Voice Configuration
    voice: str = "alloy"  # alloy, echo, shimmer, etc.
    instructions: str = (
        "You are a helpful, friendly voice assistant. "
        "Keep your responses concise and conversational."
    )

    # Turn Detection (Server-side VAD)
    turn_detection: Optional[Dict[str, Any]] = field(default_factory=lambda: {
        "type": "server_vad",
        "threshold": 0.5,
        "prefix_padding_ms": 300,
        "silence_duration_ms": 500
    })

    # Audio Configuration
    input_audio_format: str = "pcm16"
    output_audio_format: str = "pcm16"

    # Tools/Functions
    tools: List[Dict[str, Any]] = field(default_factory=list)

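# A minimal configuration sketch (illustrative; the override values below are
# assumptions, not project defaults):
#
#     config = RealtimeConfig(
#         voice="echo",
#         instructions="You are a terse support agent.",
#         turn_detection={
#             "type": "server_vad",
#             "threshold": 0.6,
#             "prefix_padding_ms": 300,
#             "silence_duration_ms": 700,
#         },
#     )
#
# Leaving api_key unset falls back to the OPENAI_API_KEY environment variable
# (see RealtimeService.__init__ below).
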
class RealtimeService:
    """
    OpenAI Realtime API service for true duplex voice conversation.

    This service handles the entire voice conversation pipeline:
    1. Audio input → Server-side VAD → Speech-to-text
    2. Text → LLM processing → Response generation
    3. Response → Text-to-speech → Audio output

    Events emitted:
    - on_audio: Audio output from the assistant
    - on_transcript: Text transcript (user or assistant)
    - on_speech_started: User started speaking
    - on_speech_stopped: User stopped speaking
    - on_response_started: Assistant started responding
    - on_response_done: Assistant finished responding
    - on_function_call: Function call requested
    - on_error: Error occurred
    """

    def __init__(self, config: Optional[RealtimeConfig] = None):
        """
        Initialize Realtime API service.

        Args:
            config: Realtime configuration (uses defaults if not provided)
        """
        self.config = config or RealtimeConfig()
        self.config.api_key = self.config.api_key or os.getenv("OPENAI_API_KEY")

        self.state = RealtimeState.DISCONNECTED
        self._ws = None
        self._receive_task: Optional[asyncio.Task] = None
        self._cancel_event = asyncio.Event()

        # Event callbacks
        self._callbacks: Dict[str, List[Callable]] = {
            "on_audio": [],
            "on_transcript": [],
            "on_speech_started": [],
            "on_speech_stopped": [],
            "on_response_started": [],
            "on_response_done": [],
            "on_function_call": [],
            "on_error": [],
            "on_interrupted": [],  # reserved; not yet emitted by _handle_event
        }

        logger.debug(f"RealtimeService initialized with model={self.config.model}")

    def on(self, event: str, callback: Callable[..., Awaitable[None]]) -> None:
        """
        Register an event callback.

        Args:
            event: Event name (must be one of the keys in self._callbacks)
            callback: Async callback function
        """
        if event in self._callbacks:
            self._callbacks[event].append(callback)
        else:
            logger.warning(f"Unknown event name: {event}")

    async def _emit(self, event: str, *args, **kwargs) -> None:
        """Emit an event to all registered callbacks."""
        for callback in self._callbacks.get(event, []):
            try:
                await callback(*args, **kwargs)
            except Exception as e:
                logger.error(f"Event callback error ({event}): {e}")

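    # Usage sketch for the callback API (the handler name is illustrative):
    #
    #     async def print_transcript(text: str, role: str, is_final: bool) -> None:
    #         if is_final:
    #             print(f"{role}: {text}")
    #
    #     service.on("on_transcript", print_transcript)
    #
    # Callbacks must be async; _emit awaits each one and logs, rather than
    # propagates, any exception it raises.
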
    async def connect(self) -> None:
        """Connect to the OpenAI Realtime API."""
        if not WEBSOCKETS_AVAILABLE:
            raise RuntimeError("websockets package not installed")

        if not self.config.api_key:
            raise ValueError("OpenAI API key not provided")

        self.state = RealtimeState.CONNECTING

        # Build URL
        if self.config.endpoint:
            # Azure or custom endpoint
            url = (
                f"{self.config.endpoint}/openai/realtime"
                f"?api-version=2024-10-01-preview&deployment={self.config.model}"
            )
        else:
            # OpenAI endpoint
            url = f"wss://api.openai.com/v1/realtime?model={self.config.model}"

        # Build headers (Azure uses api-key; OpenAI uses a bearer token)
        headers = {}
        if self.config.endpoint:
            headers["api-key"] = self.config.api_key
        else:
            headers["Authorization"] = f"Bearer {self.config.api_key}"
            headers["OpenAI-Beta"] = "realtime=v1"

        try:
            logger.info(f"Connecting to Realtime API: {url}")
            # Note: websockets < 14 takes extra_headers; newer releases
            # renamed the argument to additional_headers.
            self._ws = await websockets.connect(url, extra_headers=headers)

            # Send session configuration
            await self._configure_session()

            # Start receive loop
            self._receive_task = asyncio.create_task(self._receive_loop())

            self.state = RealtimeState.CONNECTED
            logger.info("Realtime API connected successfully")

        except Exception as e:
            self.state = RealtimeState.ERROR
            logger.error(f"Realtime API connection failed: {e}")
            raise

    async def _configure_session(self) -> None:
        """Send session configuration to the server."""
        session_config = {
            "type": "session.update",
            "session": {
                "modalities": ["text", "audio"],
                "instructions": self.config.instructions,
                "voice": self.config.voice,
                "input_audio_format": self.config.input_audio_format,
                "output_audio_format": self.config.output_audio_format,
                "turn_detection": self.config.turn_detection,
            }
        }

        if self.config.tools:
            session_config["session"]["tools"] = self.config.tools

        await self._send(session_config)
        logger.debug("Session configuration sent")

    async def _send(self, data: Dict[str, Any]) -> None:
        """Send JSON data to the server."""
        if self._ws:
            await self._ws.send(json.dumps(data))

    async def send_audio(self, audio_bytes: bytes) -> None:
        """
        Send audio to the Realtime API.

        Args:
            audio_bytes: PCM audio data (16-bit, mono, 24kHz by default)
        """
        if self.state != RealtimeState.CONNECTED:
            return

        # Encode audio as base64
        audio_b64 = base64.standard_b64encode(audio_bytes).decode()

        await self._send({
            "type": "input_audio_buffer.append",
            "audio": audio_b64
        })

    async def send_text(self, text: str) -> None:
        """
        Send text input (bypassing audio).

        Args:
            text: User text input
        """
        if self.state != RealtimeState.CONNECTED:
            return

        # Create a conversation item with the user text
        await self._send({
            "type": "conversation.item.create",
            "item": {
                "type": "message",
                "role": "user",
                "content": [{"type": "input_text", "text": text}]
            }
        })

        # Trigger a response
        await self._send({"type": "response.create"})

    async def cancel_response(self) -> None:
        """Cancel the current response (for barge-in)."""
        if self.state != RealtimeState.CONNECTED:
            return

        await self._send({"type": "response.cancel"})
        logger.debug("Response cancelled")

    async def commit_audio(self) -> None:
        """Commit the audio buffer and trigger a response."""
        if self.state != RealtimeState.CONNECTED:
            return

        await self._send({"type": "input_audio_buffer.commit"})
        await self._send({"type": "response.create"})

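    # Manual turn sketch, for when server VAD is disabled (turn_detection=None);
    # the chunking loop is illustrative:
    #
    #     for chunk in pcm_chunks:           # 16-bit PCM frames
    #         await service.send_audio(chunk)
    #     await service.commit_audio()       # commit buffer + request response
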
    async def clear_audio_buffer(self) -> None:
        """Clear the input audio buffer."""
        if self.state != RealtimeState.CONNECTED:
            return

        await self._send({"type": "input_audio_buffer.clear"})

    async def submit_function_result(self, call_id: str, result: str) -> None:
        """
        Submit a function call result.

        Args:
            call_id: The function call ID
            result: JSON string result
        """
        if self.state != RealtimeState.CONNECTED:
            return

        await self._send({
            "type": "conversation.item.create",
            "item": {
                "type": "function_call_output",
                "call_id": call_id,
                "output": result
            }
        })

        # Trigger a response that incorporates the function result
        await self._send({"type": "response.create"})

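    # Function-call round trip sketch (the tool name and handler are
    # hypothetical):
    #
    #     async def handle_call(call_id: str, name: str, arguments: str) -> None:
    #         args = json.loads(arguments)
    #         if name == "get_weather":
    #             result = json.dumps({"temp_c": 21})
    #         else:
    #             result = json.dumps({"error": f"unknown tool: {name}"})
    #         await service.submit_function_result(call_id, result)
    #
    #     service.on("on_function_call", handle_call)
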
    async def _receive_loop(self) -> None:
        """Receive and process messages from the Realtime API."""
        if not self._ws:
            return

        try:
            async for message in self._ws:
                try:
                    data = json.loads(message)
                    await self._handle_event(data)
                except json.JSONDecodeError:
                    logger.warning(f"Invalid JSON received: {message[:100]}")

        except asyncio.CancelledError:
            logger.debug("Receive loop cancelled")
            raise
        except websockets.ConnectionClosed as e:
            logger.info(f"WebSocket closed: {e}")
            self.state = RealtimeState.DISCONNECTED
        except Exception as e:
            logger.error(f"Receive loop error: {e}")
            self.state = RealtimeState.ERROR

    async def _handle_event(self, data: Dict[str, Any]) -> None:
        """Handle an incoming event from the Realtime API."""
        event_type = data.get("type", "unknown")

        # Audio delta - streaming audio output
        if event_type == "response.audio.delta":
            if "delta" in data:
                audio_bytes = base64.standard_b64decode(data["delta"])
                await self._emit("on_audio", audio_bytes)

        # Audio transcript delta - streaming text
        elif event_type == "response.audio_transcript.delta":
            if "delta" in data:
                await self._emit("on_transcript", data["delta"], "assistant", False)

        # Audio transcript done
        elif event_type == "response.audio_transcript.done":
            if "transcript" in data:
                await self._emit("on_transcript", data["transcript"], "assistant", True)

        # Input audio transcript (user speech)
        elif event_type == "conversation.item.input_audio_transcription.completed":
            if "transcript" in data:
                await self._emit("on_transcript", data["transcript"], "user", True)

        # Speech started (server VAD detected speech)
        elif event_type == "input_audio_buffer.speech_started":
            await self._emit("on_speech_started", data.get("audio_start_ms", 0))

        # Speech stopped
        elif event_type == "input_audio_buffer.speech_stopped":
            await self._emit("on_speech_stopped", data.get("audio_end_ms", 0))

        # Response started
        elif event_type == "response.created":
            await self._emit("on_response_started", data.get("response", {}))

        # Response done
        elif event_type == "response.done":
            await self._emit("on_response_done", data.get("response", {}))

        # Function call
        elif event_type == "response.function_call_arguments.done":
            call_id = data.get("call_id")
            name = data.get("name")
            arguments = data.get("arguments", "{}")
            await self._emit("on_function_call", call_id, name, arguments)

        # Error
        elif event_type == "error":
            error = data.get("error", {})
            logger.error(f"Realtime API error: {error}")
            await self._emit("on_error", error)

        # Session events
        elif event_type == "session.created":
            logger.info("Session created")
        elif event_type == "session.updated":
            logger.debug("Session updated")

        else:
            logger.debug(f"Unhandled event type: {event_type}")

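    # For reference, an audio delta event has roughly this shape (abbreviated;
    # see the dispatch above for the fields actually consumed):
    #
    #     {"type": "response.audio.delta", "delta": "<base64 pcm16>", ...}
    #
    # Errors arrive as {"type": "error", "error": {...}}.
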
    async def disconnect(self) -> None:
        """Disconnect from the Realtime API."""
        self._cancel_event.set()

        if self._receive_task:
            self._receive_task.cancel()
            try:
                await self._receive_task
            except asyncio.CancelledError:
                pass

        if self._ws:
            await self._ws.close()
            self._ws = None

        self.state = RealtimeState.DISCONNECTED
        logger.info("Realtime API disconnected")

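# End-to-end usage sketch for RealtimeService (play_audio and log_line are
# hypothetical async callbacks):
#
#     service = RealtimeService(RealtimeConfig(voice="alloy"))
#     service.on("on_audio", play_audio)
#     service.on("on_transcript", log_line)
#     await service.connect()
#     await service.send_text("Hello!")
#     ...
#     await service.disconnect()
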
class RealtimePipeline:
    """
    Pipeline adapter for RealtimeService.

    Provides an interface compatible with DuplexPipeline but uses the
    OpenAI Realtime API for all processing.
    """

    def __init__(
        self,
        transport,
        session_id: str,
        config: Optional[RealtimeConfig] = None
    ):
        """
        Initialize the Realtime pipeline.

        Args:
            transport: Transport for sending audio/events
            session_id: Session identifier
            config: Realtime configuration
        """
        self.transport = transport
        self.session_id = session_id

        self.service = RealtimeService(config)

        # Register callbacks
        self.service.on("on_audio", self._on_audio)
        self.service.on("on_transcript", self._on_transcript)
        self.service.on("on_speech_started", self._on_speech_started)
        self.service.on("on_speech_stopped", self._on_speech_stopped)
        self.service.on("on_response_started", self._on_response_started)
        self.service.on("on_response_done", self._on_response_done)
        self.service.on("on_error", self._on_error)

        self._is_speaking = False
        self._running = True

        logger.info(f"RealtimePipeline initialized for session {session_id}")

    async def start(self) -> None:
        """Start the pipeline."""
        await self.service.connect()

    async def process_audio(self, pcm_bytes: bytes) -> None:
        """
        Process incoming audio.

        Note: the Realtime API expects 24kHz audio by default.
        You may need to resample from 16kHz.
        """
        if not self._running:
            return

        # TODO: Resample from 16kHz to 24kHz if needed
        await self.service.send_audio(pcm_bytes)

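    # A minimal resampling sketch for the TODO above: linear interpolation
    # from 16 kHz to the API's default 24 kHz with numpy. This helper is an
    # assumption, not part of the pipeline; production code would prefer a
    # polyphase resampler such as scipy.signal.resample_poly.
    #
    #     import numpy as np
    #
    #     def resample_16k_to_24k(pcm_bytes: bytes) -> bytes:
    #         samples = np.frombuffer(pcm_bytes, dtype=np.int16)
    #         n_out = int(len(samples) * 24000 / 16000)
    #         positions = np.linspace(0, len(samples) - 1, n_out)
    #         out = np.interp(positions, np.arange(len(samples)), samples)
    #         return out.astype(np.int16).tobytes()
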
    async def process_text(self, text: str) -> None:
        """Process text input."""
        if not self._running:
            return

        await self.service.send_text(text)

    async def interrupt(self) -> None:
        """Interrupt the current response."""
        await self.service.cancel_response()
        await self.transport.send_event({
            "event": "interrupt",
            "trackId": self.session_id,
            "timestamp": self._get_timestamp_ms()
        })

    async def cleanup(self) -> None:
        """Clean up resources."""
        self._running = False
        await self.service.disconnect()

    # Event handlers

    async def _on_audio(self, audio_bytes: bytes) -> None:
        """Handle audio output."""
        await self.transport.send_audio(audio_bytes)

    async def _on_transcript(self, text: str, role: str, is_final: bool) -> None:
        """Handle a transcript."""
        preview = f"{text[:50]}..." if len(text) > 50 else text
        logger.info(f"[{role.upper()}] {preview}")

    async def _on_speech_started(self, start_ms: int) -> None:
        """Handle user speech start."""
        self._is_speaking = True
        await self.transport.send_event({
            "event": "speaking",
            "trackId": self.session_id,
            "timestamp": self._get_timestamp_ms(),
            "startTime": start_ms
        })

        # Cancel any ongoing response (barge-in)
        await self.service.cancel_response()

    async def _on_speech_stopped(self, end_ms: int) -> None:
        """Handle user speech stop."""
        self._is_speaking = False
        await self.transport.send_event({
            "event": "silence",
            "trackId": self.session_id,
            "timestamp": self._get_timestamp_ms(),
            "duration": end_ms
        })

    async def _on_response_started(self, response: Dict) -> None:
        """Handle response start."""
        await self.transport.send_event({
            "event": "trackStart",
            "trackId": self.session_id,
            "timestamp": self._get_timestamp_ms()
        })

    async def _on_response_done(self, response: Dict) -> None:
        """Handle response completion."""
        await self.transport.send_event({
            "event": "trackEnd",
            "trackId": self.session_id,
            "timestamp": self._get_timestamp_ms()
        })

    async def _on_error(self, error: Dict) -> None:
        """Handle an error."""
        await self.transport.send_event({
            "event": "error",
            "trackId": self.session_id,
            "timestamp": self._get_timestamp_ms(),
            "sender": "realtime",
            "error": str(error)
        })

    def _get_timestamp_ms(self) -> int:
        """Get the current timestamp in milliseconds."""
        return int(time.time() * 1000)

    @property
    def is_speaking(self) -> bool:
        """Check whether the user is speaking."""
        return self._is_speaking
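
# Pipeline usage sketch (transport is any object exposing async send_audio and
# send_event; WebSocketTransport and incoming_pcm_frames are hypothetical):
#
#     pipeline = RealtimePipeline(WebSocketTransport(ws), session_id="abc123")
#     await pipeline.start()
#     async for frame in incoming_pcm_frames():
#         await pipeline.process_audio(frame)
#     await pipeline.cleanup()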