Unify db api
This commit is contained in:
@@ -7,10 +7,10 @@ for real-time voice conversation.
|
||||
import os
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import AsyncIterator, Optional, List, Dict, Any
|
||||
from typing import AsyncIterator, Optional, List, Dict, Any, Callable, Awaitable
|
||||
from loguru import logger
|
||||
|
||||
from app.backend_client import search_knowledge_context
|
||||
from app.backend_adapters import build_backend_adapter_from_settings
|
||||
from services.base import BaseLLMService, LLMMessage, LLMStreamEvent, ServiceState
|
||||
|
||||
# Try to import openai
|
||||
@@ -37,20 +37,21 @@ class OpenAILLMService(BaseLLMService):
|
||||
base_url: Optional[str] = None,
|
||||
system_prompt: Optional[str] = None,
|
||||
knowledge_config: Optional[Dict[str, Any]] = None,
|
||||
knowledge_searcher: Optional[Callable[..., Awaitable[List[Dict[str, Any]]]]] = None,
|
||||
):
|
||||
"""
|
||||
Initialize OpenAI LLM service.
|
||||
|
||||
Args:
|
||||
model: Model name (e.g., "gpt-4o-mini", "gpt-4o")
|
||||
api_key: OpenAI API key (defaults to OPENAI_API_KEY env var)
|
||||
api_key: Provider API key (defaults to LLM_API_KEY/OPENAI_API_KEY env vars)
|
||||
base_url: Custom API base URL (for Azure or compatible APIs)
|
||||
system_prompt: Default system prompt for conversations
|
||||
"""
|
||||
super().__init__(model=model)
|
||||
|
||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||
self.base_url = base_url or os.getenv("OPENAI_API_URL")
|
||||
self.api_key = api_key or os.getenv("LLM_API_KEY") or os.getenv("OPENAI_API_KEY")
|
||||
self.base_url = base_url or os.getenv("LLM_API_URL") or os.getenv("OPENAI_API_URL")
|
||||
self.system_prompt = system_prompt or (
|
||||
"You are a helpful, friendly voice assistant. "
|
||||
"Keep your responses concise and conversational. "
|
||||
@@ -60,6 +61,11 @@ class OpenAILLMService(BaseLLMService):
|
||||
self.client: Optional[AsyncOpenAI] = None
|
||||
self._cancel_event = asyncio.Event()
|
||||
self._knowledge_config: Dict[str, Any] = knowledge_config or {}
|
||||
if knowledge_searcher is None:
|
||||
adapter = build_backend_adapter_from_settings()
|
||||
self._knowledge_searcher = adapter.search_knowledge_context
|
||||
else:
|
||||
self._knowledge_searcher = knowledge_searcher
|
||||
self._tool_schemas: List[Dict[str, Any]] = []
|
||||
|
||||
_RAG_DEFAULT_RESULTS = 5
|
||||
@@ -224,7 +230,7 @@ class OpenAILLMService(BaseLLMService):
|
||||
n_results = self._coerce_int(cfg.get("nResults"), self._RAG_DEFAULT_RESULTS)
|
||||
n_results = max(1, min(n_results, self._RAG_MAX_RESULTS))
|
||||
|
||||
results = await search_knowledge_context(
|
||||
results = await self._knowledge_searcher(
|
||||
kb_id=kb_id,
|
||||
query=latest_user,
|
||||
n_results=n_results,
|
||||
|
||||
@@ -6,6 +6,7 @@ API: https://docs.siliconflow.cn/cn/api-reference/audio/create-audio-transcripti
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
import os
|
||||
import wave
|
||||
from typing import AsyncIterator, Optional, Callable, Awaitable
|
||||
from loguru import logger
|
||||
@@ -46,7 +47,8 @@ class OpenAICompatibleASRService(BaseASRService):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
api_key: Optional[str] = None,
|
||||
api_url: Optional[str] = None,
|
||||
model: str = "FunAudioLLM/SenseVoiceSmall",
|
||||
sample_rate: int = 16000,
|
||||
language: str = "auto",
|
||||
@@ -59,6 +61,7 @@ class OpenAICompatibleASRService(BaseASRService):
|
||||
|
||||
Args:
|
||||
api_key: Provider API key
|
||||
api_url: Provider API URL (defaults to SiliconFlow endpoint)
|
||||
model: ASR model name or alias
|
||||
sample_rate: Audio sample rate (16000 recommended)
|
||||
language: Language code (auto for automatic detection)
|
||||
@@ -71,7 +74,8 @@ class OpenAICompatibleASRService(BaseASRService):
|
||||
if not AIOHTTP_AVAILABLE:
|
||||
raise RuntimeError("aiohttp is required for OpenAICompatibleASRService")
|
||||
|
||||
self.api_key = api_key
|
||||
self.api_key = api_key or os.getenv("ASR_API_KEY") or os.getenv("SILICONFLOW_API_KEY")
|
||||
self.api_url = api_url or os.getenv("ASR_API_URL") or self.API_URL
|
||||
self.model = self.MODELS.get(model.lower(), model)
|
||||
self.interim_interval_ms = interim_interval_ms
|
||||
self.min_audio_for_interim_ms = min_audio_for_interim_ms
|
||||
@@ -96,6 +100,8 @@ class OpenAICompatibleASRService(BaseASRService):
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""Connect to the service."""
|
||||
if not self.api_key:
|
||||
raise ValueError("ASR API key not provided. Configure agent.asr.api_key in YAML.")
|
||||
self._session = aiohttp.ClientSession(
|
||||
headers={
|
||||
"Authorization": f"Bearer {self.api_key}"
|
||||
@@ -180,7 +186,7 @@ class OpenAICompatibleASRService(BaseASRService):
|
||||
)
|
||||
form_data.add_field('model', self.model)
|
||||
|
||||
async with self._session.post(self.API_URL, data=form_data) as response:
|
||||
async with self._session.post(self.api_url, data=form_data) as response:
|
||||
if response.status == 200:
|
||||
result = await response.json()
|
||||
text = result.get("text", "").strip()
|
||||
|
||||
@@ -38,6 +38,7 @@ class OpenAICompatibleTTSService(BaseTTSService):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
api_url: Optional[str] = None,
|
||||
voice: str = "anna",
|
||||
model: str = "FunAudioLLM/CosyVoice2-0.5B",
|
||||
sample_rate: int = 16000,
|
||||
@@ -47,7 +48,8 @@ class OpenAICompatibleTTSService(BaseTTSService):
|
||||
Initialize OpenAI-compatible TTS service.
|
||||
|
||||
Args:
|
||||
api_key: Provider API key (defaults to SILICONFLOW_API_KEY env var)
|
||||
api_key: Provider API key (defaults to TTS_API_KEY/SILICONFLOW_API_KEY env vars)
|
||||
api_url: Provider API URL (defaults to SiliconFlow endpoint)
|
||||
voice: Voice name (alex, anna, bella, benjamin, charles, claire, david, diana)
|
||||
model: Model name
|
||||
sample_rate: Output sample rate (8000, 16000, 24000, 32000, 44100)
|
||||
@@ -70,9 +72,9 @@ class OpenAICompatibleTTSService(BaseTTSService):
|
||||
|
||||
super().__init__(voice=full_voice, sample_rate=sample_rate, speed=speed)
|
||||
|
||||
self.api_key = api_key or os.getenv("SILICONFLOW_API_KEY")
|
||||
self.api_key = api_key or os.getenv("TTS_API_KEY") or os.getenv("SILICONFLOW_API_KEY")
|
||||
self.model = model
|
||||
self.api_url = "https://api.siliconflow.cn/v1/audio/speech"
|
||||
self.api_url = api_url or os.getenv("TTS_API_URL") or "https://api.siliconflow.cn/v1/audio/speech"
|
||||
|
||||
self._session: Optional[aiohttp.ClientSession] = None
|
||||
self._cancel_event = asyncio.Event()
|
||||
@@ -80,7 +82,7 @@ class OpenAICompatibleTTSService(BaseTTSService):
|
||||
async def connect(self) -> None:
|
||||
"""Initialize HTTP session."""
|
||||
if not self.api_key:
|
||||
raise ValueError("SiliconFlow API key not provided. Set SILICONFLOW_API_KEY env var.")
|
||||
raise ValueError("TTS API key not provided. Configure agent.tts.api_key in YAML.")
|
||||
|
||||
self._session = aiohttp.ClientSession(
|
||||
headers={
|
||||
|
||||
Reference in New Issue
Block a user