"""Base classes for AI services.
|
|
|
|
Defines abstract interfaces for ASR, LLM, and TTS services,
|
|
inspired by pipecat's service architecture and active-call's
|
|
StreamEngine pattern.
|
|
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, AsyncIterator, Dict, List, Optional

class ServiceState(Enum):
    """Lifecycle states for a service connection."""

    # Not yet connected, or cleanly shut down.
    DISCONNECTED = "disconnected"
    # Connection attempt in progress.
    CONNECTING = "connecting"
    # Connected and ready for traffic.
    CONNECTED = "connected"
    # Connection failed or was lost.
    ERROR = "error"

@dataclass
class ASRResult:
    """A single transcription result from an ASR service.

    Partial hypotheses arrive with ``is_final=False`` and are later
    superseded by a final result for the same utterance.
    """

    text: str                           # transcribed text
    is_final: bool = False              # True once the hypothesis is stable
    confidence: float = 1.0             # recognizer confidence score
    language: Optional[str] = None      # detected/declared language, if any
    start_time: Optional[float] = None  # utterance start (seconds), if reported
    end_time: Optional[float] = None    # utterance end (seconds), if reported

    def __str__(self) -> str:
        """Render as ``[FINAL] text`` or ``[PARTIAL] text`` for logging."""
        return f"[{'FINAL' if self.is_final else 'PARTIAL'}] {self.text}"

@dataclass
class LLMMessage:
    """One message in an LLM conversation history."""

    role: str  # "system", "user", "assistant", "function"
    content: str  # message text
    name: Optional[str] = None  # For function calls
    function_call: Optional[Dict[str, Any]] = None  # structured call payload

    def to_dict(self) -> Dict[str, Any]:
        """Convert to API-compatible dict.

        ``name`` and ``function_call`` are included only when truthy,
        matching the chat-completions message shape.
        """
        payload: Dict[str, Any] = {"role": self.role, "content": self.content}
        for key, value in (("name", self.name), ("function_call", self.function_call)):
            if value:
                payload[key] = value
        return payload

@dataclass
class TTSChunk:
    """TTS audio chunk emitted by a (streaming) synthesis service."""
    audio: bytes  # PCM audio data
    sample_rate: int = 16000  # samples per second
    channels: int = 1  # mono by default
    bits_per_sample: int = 16  # sample width of the PCM data
    is_final: bool = False  # True on the last chunk of an utterance
    text_offset: Optional[int] = None  # Character offset in original text

class BaseASRService(ABC):
    """Abstract base class for ASR (Speech-to-Text) services.

    Supports both streaming and non-streaming transcription.  Concrete
    subclasses implement the connection lifecycle plus the streaming
    send/receive pair; a default non-streaming ``transcribe`` is layered
    on top of those primitives.
    """

    def __init__(self, sample_rate: int = 16000, language: str = "en"):
        self.sample_rate = sample_rate
        self.language = language
        self.state = ServiceState.DISCONNECTED

    @abstractmethod
    async def connect(self) -> None:
        """Establish connection to the ASR service."""

    @abstractmethod
    async def disconnect(self) -> None:
        """Close connection to the ASR service."""

    @abstractmethod
    async def send_audio(self, audio: bytes) -> None:
        """Send an audio chunk for transcription.

        Args:
            audio: PCM audio data (16-bit, mono).
        """

    @abstractmethod
    async def receive_transcripts(self) -> AsyncIterator[ASRResult]:
        """Receive transcription results.

        Yields:
            ASRResult objects as they become available.
        """

    async def transcribe(self, audio: bytes) -> ASRResult:
        """Transcribe a complete audio buffer (non-streaming).

        Default implementation pushes the whole buffer through the
        streaming interface and returns the first final result; if the
        stream ends without one, an empty final result is returned.

        Args:
            audio: Complete PCM audio data.

        Returns:
            Final ASRResult.
        """
        await self.send_audio(audio)
        async for hypothesis in self.receive_transcripts():
            if hypothesis.is_final:
                return hypothesis
        return ASRResult(text="", is_final=True)

class BaseLLMService(ABC):
    """Abstract base class for LLM (Language Model) services.

    Supports streaming responses for real-time conversation.
    """

    def __init__(self, model: str = "gpt-4"):
        self.model = model
        self.state = ServiceState.DISCONNECTED

    @abstractmethod
    async def connect(self) -> None:
        """Initialize the LLM service connection."""

    @abstractmethod
    async def disconnect(self) -> None:
        """Close the LLM service connection."""

    @abstractmethod
    async def generate(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> str:
        """Generate a complete response.

        Args:
            messages: Conversation history.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate (None = provider default).

        Returns:
            Complete response text.
        """

    @abstractmethod
    async def generate_stream(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> AsyncIterator[str]:
        """Generate a response in streaming mode.

        Args:
            messages: Conversation history.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens to generate (None = provider default).

        Yields:
            Text chunks as they are generated.
        """

class BaseTTSService(ABC):
    """Abstract base class for TTS (Text-to-Speech) services.

    Supports streaming audio synthesis for low-latency playback.
    """

    def __init__(
        self,
        voice: str = "default",
        sample_rate: int = 16000,
        speed: float = 1.0
    ):
        self.voice = voice
        self.sample_rate = sample_rate
        self.speed = speed
        self.state = ServiceState.DISCONNECTED

    @abstractmethod
    async def connect(self) -> None:
        """Initialize the TTS service connection."""

    @abstractmethod
    async def disconnect(self) -> None:
        """Close the TTS service connection."""

    @abstractmethod
    async def synthesize(self, text: str) -> bytes:
        """Synthesize complete audio for text (non-streaming).

        Args:
            text: Text to synthesize.

        Returns:
            Complete PCM audio data.
        """

    @abstractmethod
    async def synthesize_stream(self, text: str) -> AsyncIterator[TTSChunk]:
        """Synthesize audio in streaming mode.

        Args:
            text: Text to synthesize.

        Yields:
            TTSChunk objects as audio is generated.
        """

    async def cancel(self) -> None:
        """Cancel ongoing synthesis (for barge-in support).

        Default implementation is a no-op; subclasses that support
        barge-in override it.
        """
