"""LLM (Large Language Model) Service implementations.
|
|
|
|
Provides OpenAI-compatible LLM integration with streaming support
|
|
for real-time voice conversation.
|
|
"""
|
|
|
|
import os
|
|
import asyncio
|
|
from typing import AsyncIterator, Optional, List, Dict, Any
|
|
from loguru import logger
|
|
|
|
from services.base import BaseLLMService, LLMMessage, ServiceState
|
|
|
|
# Try to import openai
|
|
try:
|
|
from openai import AsyncOpenAI
|
|
OPENAI_AVAILABLE = True
|
|
except ImportError:
|
|
OPENAI_AVAILABLE = False
|
|
logger.warning("openai package not available - LLM service will be disabled")
|
|
|
|
|
|
class OpenAILLMService(BaseLLMService):
    """
    OpenAI-compatible LLM service.

    Supports streaming responses for low-latency voice conversation.
    Works with OpenAI API, Azure OpenAI, and compatible APIs.
    """

    def __init__(
        self,
        model: str = "gpt-4o-mini",
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        system_prompt: Optional[str] = None
    ):
        """
        Initialize OpenAI LLM service.

        Args:
            model: Model name (e.g., "gpt-4o-mini", "gpt-4o")
            api_key: OpenAI API key (defaults to OPENAI_API_KEY env var)
            base_url: Custom API base URL (for Azure or compatible APIs)
            system_prompt: Default system prompt for conversations
        """
        super().__init__(model=model)

        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        self.base_url = base_url or os.getenv("OPENAI_API_URL")
        self.system_prompt = system_prompt or (
            "You are a helpful, friendly voice assistant. "
            "Keep your responses concise and conversational. "
            "Respond naturally as if having a phone conversation."
        )

        self.client: Optional[AsyncOpenAI] = None
        self._cancel_event = asyncio.Event()

    async def connect(self) -> None:
        """Initialize OpenAI client."""
        if not OPENAI_AVAILABLE:
            raise RuntimeError("openai package not installed")

        if not self.api_key:
            raise ValueError("OpenAI API key not provided")

        self.client = AsyncOpenAI(
            api_key=self.api_key,
            base_url=self.base_url
        )
        self.state = ServiceState.CONNECTED
        logger.info(f"OpenAI LLM service connected: model={self.model}")

    async def disconnect(self) -> None:
        """Close OpenAI client."""
        if self.client:
            await self.client.close()
            self.client = None
        self.state = ServiceState.DISCONNECTED
        logger.info("OpenAI LLM service disconnected")

    def _prepare_messages(self, messages: List[LLMMessage]) -> List[Dict[str, Any]]:
        """Prepare messages list with system prompt."""
        result = []

        # Add system prompt if not already present
        has_system = any(m.role == "system" for m in messages)
        if not has_system and self.system_prompt:
            result.append({"role": "system", "content": self.system_prompt})

        # Add all messages
        for msg in messages:
            result.append(msg.to_dict())

        return result

    async def generate(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> str:
        """
        Generate a complete response.

        Args:
            messages: Conversation history
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate

        Returns:
            Complete response text
        """
        if not self.client:
            raise RuntimeError("LLM service not connected")

        prepared = self._prepare_messages(messages)

        try:
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=prepared,
                temperature=temperature,
                max_tokens=max_tokens
            )

            content = response.choices[0].message.content or ""
            logger.debug(f"LLM response: {content[:100]}...")
            return content

        except Exception as e:
            logger.error(f"LLM generation error: {e}")
            raise

    async def generate_stream(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> AsyncIterator[str]:
        """
        Generate response in streaming mode.

        Args:
            messages: Conversation history
            temperature: Sampling temperature
            max_tokens: Maximum tokens to generate

        Yields:
            Text chunks as they are generated
        """
        if not self.client:
            raise RuntimeError("LLM service not connected")

        prepared = self._prepare_messages(messages)
        self._cancel_event.clear()

        try:
            stream = await self.client.chat.completions.create(
                model=self.model,
                messages=prepared,
                temperature=temperature,
                max_tokens=max_tokens,
                stream=True
            )

            async for chunk in stream:
                # Check for cancellation
                if self._cancel_event.is_set():
                    logger.info("LLM stream cancelled")
                    break

                if chunk.choices and chunk.choices[0].delta.content:
                    content = chunk.choices[0].delta.content
                    yield content

        except asyncio.CancelledError:
            logger.info("LLM stream cancelled via asyncio")
            raise
        except Exception as e:
            logger.error(f"LLM streaming error: {e}")
            raise

    def cancel(self) -> None:
        """Cancel ongoing generation."""
        self._cancel_event.set()


class MockLLMService(BaseLLMService):
    """
    Mock LLM service for testing without API calls.
    """

    def __init__(self, response_delay: float = 0.5):
        super().__init__(model="mock")
        self.response_delay = response_delay
        self.responses = [
            "Hello! How can I help you today?",
            "That's an interesting question. Let me think about it.",
            "I understand. Is there anything else you'd like to know?",
            "Great! I'm here if you need anything else.",
        ]
        self._response_index = 0

    async def connect(self) -> None:
        self.state = ServiceState.CONNECTED
        logger.info("Mock LLM service connected")

    async def disconnect(self) -> None:
        self.state = ServiceState.DISCONNECTED
        logger.info("Mock LLM service disconnected")

    async def generate(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> str:
        await asyncio.sleep(self.response_delay)
        response = self.responses[self._response_index % len(self.responses)]
        self._response_index += 1
        return response

    async def generate_stream(
        self,
        messages: List[LLMMessage],
        temperature: float = 0.7,
        max_tokens: Optional[int] = None
    ) -> AsyncIterator[str]:
        response = await self.generate(messages, temperature, max_tokens)

        # Stream word by word
        words = response.split()
        for i, word in enumerate(words):
            if i > 0:
                yield " "
            yield word
            await asyncio.sleep(0.05)  # Simulate streaming delay
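

# ---------------------------------------------------------------------------
# Minimal local smoke test (sketch, not part of the service API).
#
# Streams a reply from MockLLMService so the generate_stream() flow can be
# exercised without an API key; OpenAILLMService is driven the same way once
# OPENAI_API_KEY is set. This assumes LLMMessage can be constructed as
# LLMMessage(role=..., content=...) -- check services.base for the actual
# signature and adjust if it differs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    async def _demo() -> None:
        llm = MockLLMService(response_delay=0.1)
        await llm.connect()
        try:
            # Assumed constructor; see services.base.LLMMessage.
            history = [LLMMessage(role="user", content="Hello there!")]
            async for chunk in llm.generate_stream(history):
                print(chunk, end="", flush=True)
            print()
        finally:
            await llm.disconnect()

    asyncio.run(_demo())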