It works
This commit is contained in:
261
src/fastgpt_api.py
Normal file
261
src/fastgpt_api.py
Normal file
@@ -0,0 +1,261 @@
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import time
|
||||
from typing import List, Dict
|
||||
from logger import log_info, log_debug, log_warning, log_error, log_performance
|
||||
|
||||
class ChatModel:
    """Async client for a FastGPT deployment.

    Wraps three endpoints: chat init (welcome text), chat completions
    (AI responses), and paginated chat records (history). All methods log
    request timing via the project logger and raise ``Exception`` on
    failure so callers can catch a single type.
    """

    def __init__(self, api_key: str, api_url: str, appId: str, client_id: str = None):
        """
        Args:
            api_key: Bearer token used for the Authorization header.
            api_url: Base URL of the FastGPT server (scheme + host + port).
            appId: FastGPT application identifier.
            client_id: Optional identifier used purely for log correlation.
        """
        self._api_key = api_key
        self._api_url = api_url
        self._appId = appId
        self._client_id = client_id

        log_info(self._client_id, "ChatModel initialized",
                 api_url=self._api_url,
                 app_id=self._appId)

    async def get_welcome_text(self, chatId: str) -> str:
        """Get welcome text from FastGPT API.

        Args:
            chatId: Conversation identifier passed through to FastGPT.

        Returns:
            The app's configured welcome text.

        Raises:
            Exception: on network failure (chained from the aiohttp error),
                non-200 response, or a payload missing the expected keys.
        """
        start_time = time.perf_counter()
        url = f'{self._api_url}/api/core/chat/init'

        log_debug(self._client_id, "Requesting welcome text",
                  chat_id=chatId,
                  url=url)

        headers = {
            'Authorization': f'Bearer {self._api_key}'
        }
        params = {
            'appId': self._appId,
            'chatId': chatId
        }

        status = None
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url, headers=headers, params=params) as response:
                    duration = time.perf_counter() - start_time
                    status = response.status
                    if status == 200:
                        response_data = await response.json()
                        welcome_text = response_data['data']['app']['chatConfig']['welcomeText']
        except aiohttp.ClientError as e:
            duration = time.perf_counter() - start_time
            error_msg = f"Network error while getting welcome text: {e}"
            log_error(self._client_id, error_msg,
                      chat_id=chatId,
                      duration=f"{duration:.3f}s",
                      exception_type=type(e).__name__)
            # Chain the aiohttp error explicitly so the original cause survives.
            raise Exception(error_msg) from e
        except Exception as e:
            # JSON decoding / missing-key problems on an otherwise OK response.
            duration = time.perf_counter() - start_time
            error_msg = f"Unexpected error while getting welcome text: {e}"
            log_error(self._client_id, error_msg,
                      chat_id=chatId,
                      duration=f"{duration:.3f}s",
                      exception_type=type(e).__name__)
            raise

        if status != 200:
            # Raised outside the try block so the HTTP failure is logged once,
            # not re-caught and re-logged as an "unexpected" error.
            error_msg = f"Failed to get welcome text. Status code: {status}"
            log_error(self._client_id, error_msg,
                      chat_id=chatId,
                      status_code=status,
                      url=url)
            raise Exception(error_msg)

        log_performance(self._client_id, "Welcome text request completed",
                        duration=f"{duration:.3f}s",
                        status_code=status,
                        response_length=len(welcome_text))
        log_debug(self._client_id, "Welcome text retrieved",
                  chat_id=chatId,
                  welcome_text_length=len(welcome_text))
        return welcome_text

    async def generate_ai_response(self, chatId: str, content: str) -> str:
        """Generate AI response from FastGPT API.

        Args:
            chatId: Conversation identifier; FastGPT threads context by it.
            content: The user's message text.

        Returns:
            The assistant's reply text (first completion choice).

        Raises:
            Exception: on network failure (chained), non-200 response, or a
                payload missing the expected keys.
        """
        start_time = time.perf_counter()
        url = f'{self._api_url}/api/v1/chat/completions'

        log_debug(self._client_id, "Generating AI response",
                  chat_id=chatId,
                  content_length=len(content),
                  url=url)

        headers = {
            'Authorization': f'Bearer {self._api_key}',
            'Content-Type': 'application/json'
        }
        data = {
            'chatId': chatId,
            'messages': [
                {
                    'content': content,
                    'role': 'user'
                }
            ]
        }

        status = None
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(url, headers=headers, json=data) as response:
                    duration = time.perf_counter() - start_time
                    status = response.status
                    if status == 200:
                        response_data = await response.json()
                        # OpenAI-compatible completion shape.
                        ai_response = response_data['choices'][0]['message']['content']
        except aiohttp.ClientError as e:
            duration = time.perf_counter() - start_time
            error_msg = f"Network error while generating AI response: {e}"
            log_error(self._client_id, error_msg,
                      chat_id=chatId,
                      duration=f"{duration:.3f}s",
                      exception_type=type(e).__name__,
                      input_length=len(content))
            raise Exception(error_msg) from e
        except Exception as e:
            duration = time.perf_counter() - start_time
            error_msg = f"Unexpected error while generating AI response: {e}"
            log_error(self._client_id, error_msg,
                      chat_id=chatId,
                      duration=f"{duration:.3f}s",
                      exception_type=type(e).__name__,
                      input_length=len(content))
            raise

        if status != 200:
            # Logged exactly once (see get_welcome_text for rationale).
            error_msg = f"Failed to generate AI response. Status code: {status}"
            log_error(self._client_id, error_msg,
                      chat_id=chatId,
                      status_code=status,
                      url=url,
                      input_length=len(content))
            raise Exception(error_msg)

        log_performance(self._client_id, "AI response generation completed",
                        duration=f"{duration:.3f}s",
                        status_code=status,
                        input_length=len(content),
                        output_length=len(ai_response))
        log_debug(self._client_id, "AI response generated",
                  chat_id=chatId,
                  input_length=len(content),
                  response_length=len(ai_response))
        return ai_response

    async def get_chat_history(self, chatId: str) -> List[Dict[str, str]]:
        """Get chat history from FastGPT API.

        Args:
            chatId: Conversation identifier.

        Returns:
            A list of ``{'role': 'user'|'assistant', 'content': str}`` dicts
            in the order returned by the API. Records whose ``obj`` is neither
            'Human' nor 'AI' are skipped.

        Raises:
            Exception: on network failure (chained), non-200 response, or a
                payload missing the expected keys.
        """
        start_time = time.perf_counter()
        url = f'{self._api_url}/api/core/chat/getPaginationRecords'

        log_debug(self._client_id, "Fetching chat history",
                  chat_id=chatId,
                  url=url)

        headers = {
            'Authorization': f'Bearer {self._api_key}',
            'Content-Type': 'application/json'
        }
        data = {
            'appId': self._appId,
            'chatId': chatId,
            'loadCustomFeedbacks': False
        }

        status = None
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(url, headers=headers, json=data) as response:
                    duration = time.perf_counter() - start_time
                    status = response.status
                    if status == 200:
                        response_data = await response.json()
                        chat_history = []
                        # Map FastGPT record objects onto OpenAI-style roles.
                        for element in response_data['data']['list']:
                            if element['obj'] == 'Human':
                                chat_history.append({'role': 'user', 'content': element['value'][0]['text']})
                            elif element['obj'] == 'AI':
                                chat_history.append({'role': 'assistant', 'content': element['value'][0]['text']})
        except aiohttp.ClientError as e:
            duration = time.perf_counter() - start_time
            error_msg = f"Network error while fetching chat history: {e}"
            log_error(self._client_id, error_msg,
                      chat_id=chatId,
                      duration=f"{duration:.3f}s",
                      exception_type=type(e).__name__)
            raise Exception(error_msg) from e
        except Exception as e:
            duration = time.perf_counter() - start_time
            error_msg = f"Unexpected error while fetching chat history: {e}"
            log_error(self._client_id, error_msg,
                      chat_id=chatId,
                      duration=f"{duration:.3f}s",
                      exception_type=type(e).__name__)
            raise

        if status != 200:
            # Logged exactly once (see get_welcome_text for rationale).
            error_msg = f"Failed to fetch chat history. Status code: {status}"
            log_error(self._client_id, error_msg,
                      chat_id=chatId,
                      status_code=status,
                      url=url)
            raise Exception(error_msg)

        log_performance(self._client_id, "Chat history fetch completed",
                        duration=f"{duration:.3f}s",
                        status_code=status,
                        history_count=len(chat_history))
        log_debug(self._client_id, "Chat history retrieved",
                  chat_id=chatId,
                  history_count=len(chat_history))
        return chat_history
|
||||
|
||||
async def main():
    """Example usage of the ChatModel class.

    Exercises all three API calls (welcome text, completion, history)
    against a live FastGPT server and logs the results.
    """
    # SECURITY NOTE(review): a real-looking API key, host, and appId are
    # hard-coded and committed here — rotate this key and load credentials
    # from the environment (os.getenv / dotenv) instead.
    # NOTE(review): the trailing slash on api_url produces '//api/...' in
    # request URLs; most servers tolerate it, but confirm.
    chat_model = ChatModel(
        api_key="fastgpt-tgpSdDSE51cc6BPdb92ODfsm0apZRXOrc75YeaiZ8HmqlYplZKi5flvJUqjG5b",
        api_url="http://101.89.151.141:3000/",
        appId="6846890686197e19f72036f9",
        client_id="test_client"
    )

    try:
        log_info("test_client", "Starting FastGPT API tests")

        # Test welcome text
        welcome_text = await chat_model.get_welcome_text('welcome')
        log_info("test_client", "Welcome text test completed", welcome_text_length=len(welcome_text))

        # Test AI response generation
        response = await chat_model.generate_ai_response('chat0002', '我想问一下怎么用fastgpt')
        log_info("test_client", "AI response test completed", response_length=len(response))

        # Test chat history
        history = await chat_model.get_chat_history('chat0002')
        log_info("test_client", "Chat history test completed", history_count=len(history))

        log_info("test_client", "All FastGPT API tests completed successfully")

    except Exception as e:
        # Any API failure aborts the remaining tests; re-raise after logging.
        log_error("test_client", f"Test failed: {e}", exception_type=type(e).__name__)
        raise

if __name__ == "__main__":
    asyncio.run(main())
|
||||
98
src/logger.py
Normal file
98
src/logger.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import datetime
|
||||
from typing import Optional
|
||||
|
||||
# ANSI escape codes for colors
|
||||
class LogColors:
    """ANSI SGR escape sequences used to colorize console log output."""
    HEADER = '\033[95m'     # bright magenta
    OKBLUE = '\033[94m'     # bright blue
    OKCYAN = '\033[96m'     # bright cyan
    OKGREEN = '\033[92m'    # bright green
    WARNING = '\033[93m'    # bright yellow
    FAIL = '\033[91m'       # bright red
    ENDC = '\033[0m'        # reset all attributes
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
|
||||
|
||||
# Log levels and symbols
|
||||
# Maps a level name to the (emoji symbol, ANSI color) pair rendered by
# app_log(). Unknown levels get a neutral default there.
LOG_LEVELS = {
    "INFO": ("ℹ️", LogColors.OKGREEN),
    "DEBUG": ("🐛", LogColors.OKCYAN),
    "WARNING": ("⚠️", LogColors.WARNING),
    "ERROR": ("❌", LogColors.FAIL),
    "TIMEOUT": ("⏱️", LogColors.OKBLUE),
    "USER_INPUT": ("💬", LogColors.HEADER),
    "AI_RESPONSE": ("🤖", LogColors.OKBLUE),
    "SESSION": ("🔗", LogColors.BOLD),
    "MODEL": ("🧠", LogColors.OKCYAN),
    "PREDICT": ("🎯", LogColors.HEADER),
    "PERFORMANCE": ("⚡", LogColors.OKGREEN),
    "CONNECTION": ("🌐", LogColors.OKBLUE)
}
|
||||
|
||||
def app_log(level: str, client_id: Optional[str], message: str, **kwargs):
    """Print one colorized, timestamped log line to stdout.

    Line layout: ``<ts> [<LEVEL>] <symbol> (<client_id>): <message> | k=v | ...``
    Unknown levels fall back to a neutral symbol and no color.

    Args:
        level: Log level name (case-insensitive), e.g. INFO, ERROR.
        client_id: Optional session identifier, shown in parentheses if given.
        message: Main log message.
        **kwargs: Extra context rendered as ``key=value`` pairs.
    """
    stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]  # trim µs to ms
    level_name = level.upper()
    symbol, color = LOG_LEVELS.get(level_name, ("🔹", LogColors.ENDC))

    parts = [f"{stamp} [{level_name}] {symbol}"]
    if client_id:
        parts.append(f" ({client_id})")
    parts.append(f": {message}")
    for key, value in kwargs.items():
        parts.append(f" | {key}={value}")

    print(f"{color}{''.join(parts)}{LogColors.ENDC}")
|
||||
|
||||
# Thin convenience wrappers: each fixes the level argument of app_log().
def log_info(client_id: Optional[str], message: str, **kwargs):
    """Log an INFO-level message via app_log."""
    app_log("INFO", client_id, message, **kwargs)

def log_debug(client_id: Optional[str], message: str, **kwargs):
    """Log a DEBUG-level message via app_log."""
    app_log("DEBUG", client_id, message, **kwargs)

def log_warning(client_id: Optional[str], message: str, **kwargs):
    """Log a WARNING-level message via app_log."""
    app_log("WARNING", client_id, message, **kwargs)

def log_error(client_id: Optional[str], message: str, **kwargs):
    """Log an ERROR-level message via app_log."""
    app_log("ERROR", client_id, message, **kwargs)

def log_model(client_id: Optional[str], message: str, **kwargs):
    """Log a model-related message via app_log."""
    app_log("MODEL", client_id, message, **kwargs)

def log_predict(client_id: Optional[str], message: str, **kwargs):
    """Log a prediction-related message via app_log."""
    app_log("PREDICT", client_id, message, **kwargs)

def log_performance(client_id: Optional[str], message: str, **kwargs):
    """Log a performance/timing message via app_log."""
    app_log("PERFORMANCE", client_id, message, **kwargs)

def log_connection(client_id: Optional[str], message: str, **kwargs):
    """Log a connection-related message via app_log."""
    app_log("CONNECTION", client_id, message, **kwargs)

def log_timeout(client_id: Optional[str], message: str, **kwargs):
    """Log a timeout-related message via app_log."""
    app_log("TIMEOUT", client_id, message, **kwargs)

def log_user_input(client_id: Optional[str], message: str, **kwargs):
    """Log a user-input message via app_log."""
    app_log("USER_INPUT", client_id, message, **kwargs)

def log_ai_response(client_id: Optional[str], message: str, **kwargs):
    """Log an AI-response message via app_log."""
    app_log("AI_RESPONSE", client_id, message, **kwargs)

def log_session(client_id: Optional[str], message: str, **kwargs):
    """Log a session-lifecycle message via app_log."""
    app_log("SESSION", client_id, message, **kwargs)
|
||||
376
src/main.py
Normal file
376
src/main.py
Normal file
@@ -0,0 +1,376 @@
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import datetime # Added for timestamp
|
||||
import dotenv
|
||||
import urllib.parse # For parsing query parameters
|
||||
import websockets # Make sure it's imported at the top
|
||||
|
||||
from turn_detection import ChatMessage, TurnDetectorFactory, ONNX_AVAILABLE, FASTGPT_AVAILABLE
|
||||
from fastgpt_api import ChatModel
|
||||
from logger import (app_log, log_info, log_debug, log_warning, log_error,
|
||||
log_timeout, log_user_input, log_ai_response, log_session)
|
||||
|
||||
dotenv.load_dotenv()
# Max buffered incomplete utterances before a turn is force-processed.
MAX_INCOMPLETE_SENTENCES = int(os.getenv("MAX_INCOMPLETE_SENTENCES", 3))
# Seconds of user silence (after AI playback ends) before buffered input is processed.
MAX_RESPONSE_TIMEOUT = int(os.getenv("MAX_RESPONSE_TIMEOUT", 5))
# FastGPT connection settings; None when unset (requests will then fail at call time).
CHAT_MODEL_API_URL = os.getenv("CHAT_MODEL_API_URL", None)
CHAT_MODEL_API_KEY = os.getenv("CHAT_MODEL_API_KEY", None)
CHAT_MODEL_APP_ID = os.getenv("CHAT_MODEL_APP_ID", None)

# Turn Detection Configuration
TURN_DETECTION_MODEL = os.getenv("TURN_DETECTION_MODEL", "onnx").lower()  # "onnx", "fastgpt", "always_true"
# Probability threshold for the ONNX detector (see create_turn_detector_with_fallback).
ONNX_UNLIKELY_THRESHOLD = float(os.getenv("ONNX_UNLIKELY_THRESHOLD", 0.0009))
|
||||
|
||||
def estimate_tts_playtime(text: str, chars_per_second: float = 5.6) -> float:
    """Estimate how long (in seconds) TTS playback of *text* will take.

    Args:
        text: Text to be spoken; empty or falsy input yields 0.0.
        chars_per_second: Assumed speaking rate. Default 5.6 matches the
            previously hard-coded value (presumably tuned for Chinese
            character pacing — TODO confirm against the actual TTS engine).

    Returns:
        Estimated duration in seconds, floored at 0.5s for non-empty text
        so very short utterances are not cut off.
    """
    if not text:
        return 0.0
    return max(0.5, len(text) / chars_per_second)
|
||||
|
||||
def create_turn_detector_with_fallback():
    """
    Create a turn detector, falling back if the requested mode is unavailable.

    Fallback order when TURN_DETECTION_MODEL is unavailable:
    fastgpt -> onnx -> always_true.

    Returns:
        Turn detector instance from TurnDetectorFactory.
    """
    available_detectors = TurnDetectorFactory.get_available_detectors()
    mode = TURN_DETECTION_MODEL

    # Falsy/missing entry means the mode cannot be used.
    if not available_detectors.get(mode, False):
        log_warning(None, f"Requested turn detection mode '{TURN_DETECTION_MODEL}' is not available")
        log_info(None, "Available turn detectors", available_detectors=available_detectors)

        # Surface why the unavailable detectors failed to import.
        import_errors = TurnDetectorFactory.get_import_errors()
        if import_errors:
            log_warning(None, "Import errors for unavailable detectors", import_errors=import_errors)

        # Choose fallback based on availability.
        if available_detectors.get("fastgpt", False):
            mode = "fastgpt"
            log_info(None, f"Falling back to FastGPT turn detector")
        elif available_detectors.get("onnx", False):
            mode = "onnx"
            log_info(None, f"Falling back to ONNX turn detector")
        else:
            mode = "always_true"
            log_info(None, f"Falling back to AlwaysTrue turn detector (no ML models available)")

    # Single creation point (was duplicated in fallback and normal paths):
    # only the ONNX detector takes a threshold argument.
    if mode == "onnx":
        return TurnDetectorFactory.create_turn_detector(
            mode,
            unlikely_threshold=ONNX_UNLIKELY_THRESHOLD
        )
    return TurnDetectorFactory.create_turn_detector(mode)
|
||||
|
||||
class SessionData:
    """Per-client conversational state, reused across websocket reconnects."""

    def __init__(self, client_id):
        # Identifier from the connection URL's clientId query parameter.
        self.client_id = client_id
        # User utterances buffered while the current turn is incomplete.
        self.incomplete_sentences = []
        # Full dialogue as ChatMessage objects (user and assistant turns).
        self.conversation_history = []
        # Unix timestamp of the most recent user input (or reconnect).
        self.last_input_time = time.time()
        # asyncio.Task running handle_input_timeout; None when no timer pending.
        self.timeout_task = None
        # Unix timestamp when the current AI response's estimated TTS playback
        # ends; None when no playback is pending.
        self.ai_response_playback_ends_at: float | None = None
|
||||
|
||||
# Global instances
|
||||
turn_detection_model = create_turn_detector_with_fallback()
|
||||
ai_model = chat_model = ChatModel(
|
||||
api_key=CHAT_MODEL_API_KEY,
|
||||
api_url=CHAT_MODEL_API_URL,
|
||||
appId=CHAT_MODEL_APP_ID
|
||||
)
|
||||
sessions = {}
|
||||
|
||||
async def handle_input_timeout(websocket, session: SessionData):
    """Wait out AI playback plus a silence window, then flush buffered input.

    Scheduled as an asyncio task whenever a turn is left incomplete; it is
    cancelled (asyncio.CancelledError) whenever new user input arrives.

    Args:
        websocket: Connection used to deliver any resulting AI response.
        session: The client's session state.
    """
    client_id = session.client_id
    try:
        # First wait for any in-flight AI TTS playback to finish.
        if session.ai_response_playback_ends_at:
            remaining_ai_playtime = session.ai_response_playback_ends_at - time.time()
            if remaining_ai_playtime > 0:
                log_timeout(client_id, f"Waiting for AI playback to finish", remaining_playtime=f"{remaining_ai_playtime:.2f}s")
                await asyncio.sleep(remaining_ai_playtime)

        # Then wait for the user-inactivity window.
        log_timeout(client_id, f"AI playback done. Starting user inactivity", timeout_seconds=MAX_RESPONSE_TIMEOUT)
        await asyncio.sleep(MAX_RESPONSE_TIMEOUT)
        # Reaching here means the user stayed silent after AI playback ended.

        # Process buffered input if any.
        if session.incomplete_sentences:
            # Join once and reuse (previously the same list was joined twice).
            full_turn_text = " ".join(session.incomplete_sentences)
            log_timeout(client_id, f"Processing buffered input after silence", buffer_content=f"'{full_turn_text}'")
            await process_complete_turn(websocket, session, full_turn_text)
        else:
            log_timeout(client_id, f"No buffered input after silence")

        session.timeout_task = None  # Clear the task reference
    except asyncio.CancelledError:
        # Expected path: new user input cancels the pending timeout.
        log_info(client_id, f"Timeout task was cancelled", task_details=str(session.timeout_task))
    except Exception as e:
        log_error(client_id, f"Error in timeout handler: {e}", exception_type=type(e).__name__)
        if session: session.timeout_task = None
|
||||
|
||||
|
||||
async def handle_user_input(websocket, client_id: str, incoming_text: str):
    """Route one user utterance: buffer it, run turn detection, and trigger
    a full AI turn when the utterance (plus buffered fragments) is complete.

    Args:
        websocket: Connection for sending responses/errors.
        client_id: Session key; guaranteed by chat_handler to exist in sessions.
        incoming_text: The raw user utterance for this message.
    """
    incoming_text = incoming_text.strip('。')  # chinese period could affect prediction
    # client_id is passed directly from chat_handler and is known to exist in sessions
    session = sessions[client_id]
    session.last_input_time = time.time()  # Update on EVERY user input

    # CRITICAL: Cancel any existing timeout task because new input has arrived.
    # This covers cancellations during both AI-playback wait and silence wait.
    if session.timeout_task and not session.timeout_task.done():
        session.timeout_task.cancel()
        session.timeout_task = None

    # While the AI's estimated TTS playback is still running, input is
    # buffered rather than processed; a fresh timeout task will flush it.
    ai_is_speaking_now = False
    if session.ai_response_playback_ends_at and time.time() < session.ai_response_playback_ends_at:
        ai_is_speaking_now = True
        log_user_input(client_id, f"AI speaking. Buffering: '{incoming_text}'", current_buffer_size=len(session.incomplete_sentences))

    if ai_is_speaking_now:
        session.incomplete_sentences.append(incoming_text)
        log_user_input(client_id, f"AI speaking. Scheduling new timeout", new_buffer_size=len(session.incomplete_sentences))
        session.timeout_task = asyncio.create_task(handle_input_timeout(websocket, session))
        return

    # AI is NOT speaking: run turn detection on buffered fragments + new text.
    current_potential_turn_parts = session.incomplete_sentences + [incoming_text]
    current_potential_turn_text = " ".join(current_potential_turn_parts)
    # Detector sees the full conversation plus the candidate user turn.
    context_for_turn_detection = session.conversation_history + [ChatMessage(role='user', content=current_potential_turn_text)]

    # Use the configured turn detector.
    is_complete = await turn_detection_model.predict(
        context_for_turn_detection,
        client_id=client_id
    )
    log_debug(client_id, "Turn detection result",
              mode=TURN_DETECTION_MODEL,
              is_complete=is_complete,
              text_checked=current_potential_turn_text)

    if is_complete:
        await process_complete_turn(websocket, session, current_potential_turn_text)
    else:
        session.incomplete_sentences.append(incoming_text)
        # Safety valve: force-process once the buffer grows too large.
        if len(session.incomplete_sentences) >= MAX_INCOMPLETE_SENTENCES:
            log_user_input(client_id, f"Max incomplete sentences limit reached. Processing", limit=MAX_INCOMPLETE_SENTENCES, current_count=len(session.incomplete_sentences))
            full_turn_text = " ".join(session.incomplete_sentences)
            await process_complete_turn(websocket, session, full_turn_text)
        else:
            # Otherwise wait for more input or for the silence timeout to fire.
            log_user_input(client_id, f"Turn incomplete. Scheduling new timeout", current_buffer_size=len(session.incomplete_sentences))
            session.timeout_task = asyncio.create_task(handle_input_timeout(websocket, session))
|
||||
|
||||
|
||||
async def process_complete_turn(websocket, session: SessionData, full_user_turn_text: str, is_welcome_message_context=False):
    """Generate and send one AI response for a completed user turn.

    Args:
        websocket: Connection to send the AI_RESPONSE (or ERROR) message on.
        session: The client's session state; history and buffers are mutated.
        full_user_turn_text: The complete user turn (may be empty for welcome).
        is_welcome_message_context: When True, fetch the app's welcome text
            instead of generating a completion, and don't record a user turn.
    """
    # Only record a user message for real turns, not the initial welcome.
    if not is_welcome_message_context:
        session.conversation_history.append(ChatMessage(role="user", content=full_user_turn_text))

    # The buffered fragments are consumed by this turn.
    session.incomplete_sentences = []

    try:
        if not is_welcome_message_context:
            ai_response_text = await ai_model.generate_ai_response(session.client_id, full_user_turn_text)
        else:
            ai_response_text = await ai_model.get_welcome_text(session.client_id)
        log_debug(session.client_id, "AI model interaction", is_welcome=is_welcome_message_context, user_turn_length=len(full_user_turn_text) if not is_welcome_message_context else 0)

    except Exception as e:
        log_error(session.client_id, f"AI response generation failed: {e}", is_welcome=is_welcome_message_context, exception_type=type(e).__name__)
        # Revert the just-appended user message so history stays consistent
        # with what the AI actually saw.
        if not is_welcome_message_context and session.conversation_history and session.conversation_history[-1].role == "user":
            session.conversation_history.pop()
        await websocket.send(json.dumps({
            "type": "ERROR", "payload": {"message": "AI failed", "client_id": session.client_id}
        }))
        return

    session.conversation_history.append(ChatMessage(role="assistant", content=ai_response_text))

    tts_duration = estimate_tts_playtime(ai_response_text)
    # Set when AI response playback is expected to end. THIS IS THE KEY for
    # the timeout logic in handle_input_timeout / handle_user_input.
    session.ai_response_playback_ends_at = time.time() + tts_duration

    log_ai_response(session.client_id, f"Response sent: '{ai_response_text}'", tts_duration=f"{tts_duration:.2f}s", playback_ends_at=f"{session.ai_response_playback_ends_at:.2f}")

    await websocket.send(json.dumps({
        "type": "AI_RESPONSE",
        "payload": {
            "text": ai_response_text,
            "client_id": session.client_id,
            "estimated_tts_duration": tts_duration
        }
    }))

    # A turn was just completed; any pending silence-timeout is now stale.
    if session.timeout_task and not session.timeout_task.done():
        session.timeout_task.cancel()
        session.timeout_task = None
|
||||
|
||||
|
||||
# --- MODIFIED chat_handler ---
|
||||
async def chat_handler(websocket: websockets):
    """
    Handles new WebSocket connections.

    Extracts client_id from the URL query string ('clientId'), creates or
    reuses the session, sends a welcome message for new sessions, then routes
    incoming JSON messages until the connection closes.

    NOTE(review): the ``websockets`` annotation is the module object, not a
    connection type — effectively unannotated; consider replacing it with the
    library's concrete connection class.
    """
    # Parse the clientId out of the connection URL's query string.
    path = websocket.request.path
    parsed_path = urllib.parse.urlparse(path)
    query_params = urllib.parse.parse_qs(parsed_path.query)

    raw_client_id_values = query_params.get('clientId')  # None or a list of strings

    client_id: str | None = None
    if raw_client_id_values and raw_client_id_values[0].strip():
        client_id = raw_client_id_values[0].strip()

    if client_id is None:
        # Reject connections without a usable clientId (1008 = policy violation).
        log_warning(None, f"Connection from {websocket.remote_address} missing or empty clientId in path: {path}. Closing.")
        await websocket.close(code=1008, reason="clientId parameter is required and cannot be empty.")
        return

    # client_id is guaranteed to be a non-empty string from here on.
    log_info(client_id, f"Connection attempt from {websocket.remote_address}, Path: {path}")

    # --- Session Creation and Welcome Message ---
    is_new_session = False
    if client_id not in sessions:
        log_session(client_id, f"NEW SESSION: Creating session", total_sessions_before=len(sessions))
        sessions[client_id] = SessionData(client_id)
        is_new_session = True
    else:
        # Reconnect (or a second connection with the same id). We assume one
        # active websocket per client_id, so cancel any timeout task the
        # previous connection left behind and reset its playback state.
        existing_session = sessions[client_id]
        if existing_session.timeout_task and not existing_session.timeout_task.done():
            log_info(client_id, f"RECONNECT: Cancelling old timeout task from previous connection")
            existing_session.timeout_task.cancel()
            existing_session.timeout_task = None
        # Treat the reconnect itself as fresh activity.
        existing_session.last_input_time = time.time()
        # Playback state belonged to the previous connection's AI responses.
        existing_session.ai_response_playback_ends_at = None
        log_session(client_id, f"EXISTING SESSION: Client reconnected or new connection")

    session = sessions[client_id]  # Get the session (new or existing)

    if is_new_session:
        # Send a welcome message; its TTS estimate sets ai_response_playback_ends_at.
        log_session(client_id, f"NEW SESSION: Sending welcome message")
        await process_complete_turn(websocket, session, "", is_welcome_message_context=True)

    # --- Message Loop ---
    try:
        async for message_str in websocket:
            try:
                message_data = json.loads(message_str)
                msg_type = message_data.get("type")
                payload = message_data.get("payload")

                if msg_type == "USER_INPUT":
                    # clientId comes from the URL; a payload client_id is only
                    # validated against it, never trusted over it.
                    payload_client_id = payload.get("client_id")
                    if payload_client_id and payload_client_id != client_id:
                        log_warning(client_id, f"Mismatch! URL clientId='{client_id}', Payload clientId='{payload_client_id}'. Using URL clientId.")

                    text_input = payload.get("text")
                    if text_input is None:  # Ensure text is present
                        await websocket.send(json.dumps({"type": "ERROR", "payload": {"message": "USER_INPUT missing 'text'", "client_id": client_id}}))
                        continue

                    await handle_user_input(websocket, client_id, text_input)
                else:
                    await websocket.send(json.dumps({"type": "ERROR", "payload": {"message": f"Unknown msg type: {msg_type}", "client_id": client_id}}))

            except json.JSONDecodeError:
                await websocket.send(json.dumps({"type": "ERROR", "payload": {"message": "Invalid JSON", "client_id": client_id}}))
            except Exception as e:
                # Per-message failures are reported to the client but do not
                # terminate the connection.
                log_error(client_id, f"Error processing message: {e}")
                import traceback
                traceback.print_exc()
                await websocket.send(json.dumps({"type": "ERROR", "payload": {"message": f"Server error: {str(e)}", "client_id": client_id}}))

    except websockets.exceptions.ConnectionClosedError as e:
        log_error(client_id, f"Connection closed with error: {e.code} {e.reason}")
    except websockets.exceptions.ConnectionClosedOK:
        log_info(client_id, f"Connection closed gracefully")
    except Exception as e:
        log_error(client_id, f"Unexpected error in handler: {e}")
        import traceback
        traceback.print_exc()
    finally:
        log_info(client_id, f"Connection ended. Cleaning up resources.")
        # The SessionData stays in `sessions` so a reconnect can reuse it; any
        # lingering timeout task is cancelled by the reconnect path above.
        # Stale sessions are never evicted here — a separate cleanup mechanism
        # (e.g. keyed on last_input_time) would be needed. Robustly stopping a
        # session's timeout task when its last websocket disconnects would
        # require tracking active websockets per session.
        if client_id in sessions:  # Check if session still exists (it should)
            active_session = sessions[client_id]  # NOTE(review): bound but unused — kept for future cleanup logic
            log_info(client_id, f"Client disconnected. Session data remains. Next connection will reuse/manage timeout.")
|
||||
|
||||
|
||||
async def main():
    """Start the websocket chat server on 0.0.0.0:9000 and block until shutdown."""
    log_info(None, f"Chat server starting with turn detection mode: {TURN_DETECTION_MODEL}")

    # Report which detector implementations can actually be constructed.
    detectors = TurnDetectorFactory.get_available_detectors()
    log_info(None, "Available turn detectors", available_detectors=detectors)

    if ONNX_AVAILABLE and TURN_DETECTION_MODEL == "onnx":
        log_info(None, f"ONNX threshold: {ONNX_UNLIKELY_THRESHOLD}")

    # Listen on all interfaces; clientId is carried in the connection URL.
    server = await websockets.serve(chat_handler, "0.0.0.0", 9000)
    log_info(None, "Chat server started (clientId from URL, welcome msg)")
    await server.wait_closed()


if __name__ == "__main__":
    asyncio.run(main())
|
||||
166
src/turn_detection/README.md
Normal file
166
src/turn_detection/README.md
Normal file
@@ -0,0 +1,166 @@
|
||||
# Turn Detection Package
|
||||
|
||||
This package provides multiple turn detection implementations for conversational AI systems. Turn detection determines when a user has finished speaking and it's appropriate for the AI to respond.
|
||||
|
||||
## Package Structure
|
||||
|
||||
```
|
||||
turn_detection/
|
||||
├── __init__.py # Package exports and backward compatibility
|
||||
├── base.py # Base classes and common data structures
|
||||
├── factory.py # Factory for creating turn detectors
|
||||
├── onnx_detector.py # ONNX-based turn detector
|
||||
├── fastgpt_detector.py # FastGPT API-based turn detector
|
||||
├── always_true_detector.py # Simple always-true detector for testing
|
||||
└── README.md # This file
|
||||
```
|
||||
|
||||
## Available Turn Detectors
|
||||
|
||||
### 1. ONNXTurnDetector
|
||||
- **File**: `onnx_detector.py`
|
||||
- **Description**: Uses a pre-trained ONNX model with Hugging Face tokenizer
|
||||
- **Use Case**: Production-ready, offline turn detection
|
||||
- **Dependencies**: `onnxruntime`, `transformers`, `huggingface_hub`
|
||||
|
||||
### 2. FastGPTTurnDetector
|
||||
- **File**: `fastgpt_detector.py`
|
||||
- **Description**: Uses FastGPT API for turn detection
|
||||
- **Use Case**: Cloud-based turn detection with API access
|
||||
- **Dependencies**: `fastgpt_api`
|
||||
|
||||
### 3. AlwaysTrueTurnDetector
|
||||
- **File**: `always_true_detector.py`
|
||||
- **Description**: Always returns True (considers all turns complete)
|
||||
- **Use Case**: Testing, debugging, or when turn detection is not needed
|
||||
- **Dependencies**: None
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```python
|
||||
from turn_detection import ChatMessage, TurnDetectorFactory
|
||||
|
||||
# Create a turn detector using the factory
|
||||
detector = TurnDetectorFactory.create_turn_detector(
|
||||
mode="onnx", # "onnx", "fastgpt", or "always_true"
|
||||
unlikely_threshold=0.005 # For ONNX detector
|
||||
)
|
||||
|
||||
# Prepare chat context
|
||||
chat_context = [
|
||||
ChatMessage(role='assistant', content='Hello, how can I help you?'),
|
||||
ChatMessage(role='user', content='I need help with my order')
|
||||
]
|
||||
|
||||
# Predict if the turn is complete
|
||||
is_complete = await detector.predict(chat_context, client_id="user123")
|
||||
print(f"Turn complete: {is_complete}")
|
||||
|
||||
# Get probability
|
||||
probability = await detector.predict_probability(chat_context, client_id="user123")
|
||||
print(f"Completion probability: {probability}")
|
||||
```
|
||||
|
||||
### Direct Class Usage
|
||||
|
||||
```python
|
||||
from turn_detection import ONNXTurnDetector, FastGPTTurnDetector, AlwaysTrueTurnDetector
|
||||
|
||||
# ONNX detector
|
||||
onnx_detector = ONNXTurnDetector(unlikely_threshold=0.005)
|
||||
|
||||
# FastGPT detector
|
||||
fastgpt_detector = FastGPTTurnDetector(
|
||||
api_url="http://your-api-url",
|
||||
api_key="your-api-key",
|
||||
appId="your-app-id"
|
||||
)
|
||||
|
||||
# Always true detector
|
||||
always_true_detector = AlwaysTrueTurnDetector()
|
||||
```
|
||||
|
||||
### Factory Configuration
|
||||
|
||||
The factory supports different configuration options for each detector type:
|
||||
|
||||
```python
|
||||
# ONNX detector with custom settings
|
||||
onnx_detector = TurnDetectorFactory.create_turn_detector(
|
||||
mode="onnx",
|
||||
unlikely_threshold=0.001,
|
||||
max_history_tokens=256,
|
||||
max_history_turns=8
|
||||
)
|
||||
|
||||
# FastGPT detector with custom settings
|
||||
fastgpt_detector = TurnDetectorFactory.create_turn_detector(
|
||||
mode="fastgpt",
|
||||
api_url="http://custom-api-url",
|
||||
api_key="custom-api-key",
|
||||
appId="custom-app-id"
|
||||
)
|
||||
|
||||
# Always true detector (no configuration needed)
|
||||
always_true_detector = TurnDetectorFactory.create_turn_detector(mode="always_true")
|
||||
```
|
||||
|
||||
## Data Structures
|
||||
|
||||
### ChatMessage
|
||||
```python
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
role: ChatRole # "system", "user", "assistant", "tool"
|
||||
content: str | list[str] | None = None
|
||||
```
|
||||
|
||||
### ChatRole
|
||||
```python
|
||||
ChatRole = Literal["system", "user", "assistant", "tool"]
|
||||
```
|
||||
|
||||
## Base Class Interface
|
||||
|
||||
All turn detectors implement the `BaseTurnDetector` interface:
|
||||
|
||||
```python
|
||||
class BaseTurnDetector(ABC):
|
||||
@abstractmethod
|
||||
async def predict(self, chat_context: List[ChatMessage], client_id: str = None) -> bool:
|
||||
"""Predicts whether the current utterance is complete."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def predict_probability(self, chat_context: List[ChatMessage], client_id: str = None) -> float:
|
||||
"""Predicts the probability that the current utterance is complete."""
|
||||
pass
|
||||
```
|
||||
|
||||
## Environment Variables
|
||||
|
||||
The following environment variables can be used to configure the detectors:
|
||||
|
||||
- `TURN_DETECTION_MODEL`: Turn detection mode ("onnx", "fastgpt", "always_true")
|
||||
- `ONNX_UNLIKELY_THRESHOLD`: Threshold for ONNX detector (default: 0.005)
|
||||
- `CHAT_MODEL_API_URL`: FastGPT API URL
|
||||
- `CHAT_MODEL_API_KEY`: FastGPT API key
|
||||
- `CHAT_MODEL_APP_ID`: FastGPT app ID
|
||||
|
||||
## Backward Compatibility
|
||||
|
||||
For backward compatibility, the original `TurnDetector` name still refers to `ONNXTurnDetector`:
|
||||
|
||||
```python
|
||||
from turn_detection import TurnDetector # Same as ONNXTurnDetector
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
See the individual detector files for complete usage examples:
|
||||
|
||||
- `onnx_detector.py` - ONNX detector example
|
||||
- `fastgpt_detector.py` - FastGPT detector example
|
||||
- `always_true_detector.py` - Always true detector example
|
||||
49
src/turn_detection/__init__.py
Normal file
49
src/turn_detection/__init__.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
Turn Detection Package
|
||||
|
||||
This package provides multiple turn detection implementations for conversational AI systems.
|
||||
"""
|
||||
|
||||
from .base import ChatMessage, ChatRole, BaseTurnDetector
|
||||
|
||||
# Try to import ONNX detector, but handle import failures gracefully
|
||||
try:
|
||||
from .onnx_detector import TurnDetector as ONNXTurnDetector
|
||||
ONNX_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
ONNX_AVAILABLE = False
|
||||
ONNXTurnDetector = None
|
||||
_onnx_import_error = str(e)
|
||||
|
||||
# Try to import FastGPT detector
|
||||
try:
|
||||
from .fastgpt_detector import TurnDetector as FastGPTTurnDetector
|
||||
FASTGPT_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
FASTGPT_AVAILABLE = False
|
||||
FastGPTTurnDetector = None
|
||||
_fastgpt_import_error = str(e)
|
||||
|
||||
# Always true detector should always be available
|
||||
from .always_true_detector import AlwaysTrueTurnDetector
|
||||
from .factory import TurnDetectorFactory
|
||||
|
||||
# Export the main classes
|
||||
__all__ = [
|
||||
'ChatMessage',
|
||||
'ChatRole',
|
||||
'BaseTurnDetector',
|
||||
'ONNXTurnDetector',
|
||||
'FastGPTTurnDetector',
|
||||
'AlwaysTrueTurnDetector',
|
||||
'TurnDetectorFactory',
|
||||
'ONNX_AVAILABLE',
|
||||
'FASTGPT_AVAILABLE'
|
||||
]
|
||||
|
||||
# For backward compatibility, keep the original names
|
||||
# Only set TurnDetector if ONNX is available
|
||||
if ONNX_AVAILABLE:
|
||||
TurnDetector = ONNXTurnDetector
|
||||
else:
|
||||
TurnDetector = None
|
||||
26
src/turn_detection/always_true_detector.py
Normal file
26
src/turn_detection/always_true_detector.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
AlwaysTrueTurnDetector - A simple turn detector that always returns True.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from .base import BaseTurnDetector, ChatMessage
|
||||
from logger import log_info, log_debug
|
||||
|
||||
class AlwaysTrueTurnDetector(BaseTurnDetector):
    """Trivial detector that treats every utterance as a finished turn.

    Useful for tests, debugging, or deployments where turn detection is
    intentionally disabled.
    """

    def __init__(self):
        log_info(None, "AlwaysTrueTurnDetector initialized - all turns will be considered complete")

    async def predict(self, chat_context: List[ChatMessage], client_id: str = None) -> bool:
        """Report the turn as complete regardless of the conversation state."""
        log_debug(client_id, "AlwaysTrueTurnDetector: Turn considered complete",
                  context_length=len(chat_context))
        return True

    async def predict_probability(self, chat_context: List[ChatMessage], client_id: str = None) -> float:
        """Report full certainty (1.0) that the turn is complete."""
        return 1.0
|
||||
55
src/turn_detection/base.py
Normal file
55
src/turn_detection/base.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
Base classes and data structures for turn detection.
|
||||
"""
|
||||
|
||||
from typing import Any, Literal, Union, List
|
||||
from dataclasses import dataclass
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
# --- Data Structures ---
|
||||
|
||||
# The closed set of speaker roles a chat message may carry.
ChatRole = Literal["system", "user", "assistant", "tool"]

@dataclass
class ChatMessage:
    """Represents a single message in a chat conversation."""
    # Who produced the message: system, user, assistant, or tool.
    role: ChatRole
    # Message payload: a single string, a list of string parts, or None.
    content: str | list[str] | None = None
|
||||
|
||||
# --- Abstract Base Class ---
|
||||
|
||||
class BaseTurnDetector(ABC):
    """
    Abstract base class for all turn detectors.

    All turn detectors should inherit from this class and implement
    the required methods.
    """

    @abstractmethod
    async def predict(self, chat_context: List[ChatMessage], client_id: str | None = None) -> bool:
        """
        Predicts whether the current utterance is complete.

        Args:
            chat_context: A list of ChatMessage objects representing the conversation history.
            client_id: Client identifier for logging purposes.

        Returns:
            True if the utterance is complete, False otherwise.
        """
        pass

    @abstractmethod
    async def predict_probability(self, chat_context: List[ChatMessage], client_id: str | None = None) -> float:
        """
        Predicts the probability that the current utterance is complete.

        Args:
            chat_context: A list of ChatMessage objects representing the conversation history.
            client_id: Client identifier for logging purposes.

        Returns:
            A float representing the probability that the utterance is complete.
        """
        pass
|
||||
102
src/turn_detection/factory.py
Normal file
102
src/turn_detection/factory.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""
|
||||
Turn Detector Factory
|
||||
|
||||
Factory class for creating turn detectors based on configuration.
|
||||
"""
|
||||
|
||||
from .base import BaseTurnDetector
|
||||
from .always_true_detector import AlwaysTrueTurnDetector
|
||||
from logger import log_info, log_warning, log_error
|
||||
|
||||
# Try to import ONNX detector
|
||||
try:
|
||||
from .onnx_detector import TurnDetector as ONNXTurnDetector
|
||||
ONNX_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
ONNX_AVAILABLE = False
|
||||
ONNXTurnDetector = None
|
||||
_onnx_import_error = str(e)
|
||||
|
||||
# Try to import FastGPT detector
|
||||
try:
|
||||
from .fastgpt_detector import TurnDetector as FastGPTTurnDetector
|
||||
FASTGPT_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
FASTGPT_AVAILABLE = False
|
||||
FastGPTTurnDetector = None
|
||||
_fastgpt_import_error = str(e)
|
||||
|
||||
class TurnDetectorFactory:
    """Factory class to create turn detectors based on configuration."""

    @staticmethod
    def create_turn_detector(mode: str, **kwargs):
        """
        Build a turn detector for the requested mode.

        Args:
            mode: Turn detection mode ("onnx", "fastgpt", "always_true")
            **kwargs: Extra keyword arguments forwarded to the chosen detector

        Returns:
            A turn detector instance

        Raises:
            ImportError: If the requested detector is not available due to missing dependencies
        """
        if mode == "onnx":
            if not ONNX_AVAILABLE:
                error_msg = f"ONNX turn detector is not available. Import error: {_onnx_import_error}"
                log_error(None, error_msg)
                raise ImportError(error_msg)

            # Pop the threshold so it is not forwarded twice.
            threshold = kwargs.pop('unlikely_threshold', 0.005)
            log_info(None, f"Creating ONNX turn detector with threshold {threshold}")
            return ONNXTurnDetector(unlikely_threshold=threshold, **kwargs)

        if mode == "fastgpt":
            if not FASTGPT_AVAILABLE:
                error_msg = f"FastGPT turn detector is not available. Import error: {_fastgpt_import_error}"
                log_error(None, error_msg)
                raise ImportError(error_msg)

            log_info(None, "Creating FastGPT turn detector")
            return FastGPTTurnDetector(**kwargs)

        if mode == "always_true":
            log_info(None, "Creating AlwaysTrue turn detector")
            return AlwaysTrueTurnDetector()

        # Unknown mode: degrade gracefully instead of failing hard.
        log_warning(None, f"Unknown turn detection mode '{mode}', defaulting to AlwaysTrue")
        log_info(None, "Creating AlwaysTrue turn detector as fallback")
        return AlwaysTrueTurnDetector()

    @staticmethod
    def get_available_detectors():
        """
        Get a list of available turn detector modes.

        Returns:
            dict: Dictionary with detector modes as keys and availability as boolean values
        """
        return {
            "onnx": ONNX_AVAILABLE,
            "fastgpt": FASTGPT_AVAILABLE,
            "always_true": True,  # has no optional dependencies
        }

    @staticmethod
    def get_import_errors():
        """
        Get import error messages for unavailable detectors.

        Returns:
            dict: Dictionary with detector modes as keys and error messages as values
        """
        errors = {}
        if not ONNX_AVAILABLE:
            errors["onnx"] = _onnx_import_error
        if not FASTGPT_AVAILABLE:
            errors["fastgpt"] = _fastgpt_import_error
        return errors
|
||||
163
src/turn_detection/fastgpt_detector.py
Normal file
163
src/turn_detection/fastgpt_detector.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
FastGPT-based Turn Detector
|
||||
|
||||
A turn detector implementation using FastGPT API for turn detection.
|
||||
"""
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
from typing import List
|
||||
|
||||
from .base import BaseTurnDetector, ChatMessage
|
||||
from fastgpt_api import ChatModel
|
||||
from logger import log_info, log_debug, log_warning, log_performance
|
||||
|
||||
class TurnDetector(BaseTurnDetector):
    """
    A class to detect the end of an utterance (turn) in a conversation
    using FastGPT API for turn detection.
    """

    # --- Class Constants (Default Configuration) ---
    # These can be overridden during instantiation if needed
    MAX_HISTORY_TOKENS: int = 128  # stored for parity with the ONNX detector; not used by the API path
    MAX_HISTORY_TURNS: int = 6  # limits how many trailing messages are sent to the API (see predict)
    # SECURITY NOTE(review): hard-coded credentials were checked into source.
    # They remain only as a last-resort fallback for backward compatibility;
    # prefer the CHAT_MODEL_API_URL / CHAT_MODEL_API_KEY / CHAT_MODEL_APP_ID
    # environment variables, and rotate this key.
    API_URL = "http://101.89.151.141:3000/"
    API_KEY = "fastgpt-opfE4uKlw6I1EFIY55iWh1dfVPfaQGH2wXvFaCixaZDaZHU1mA61"
    APP_ID = "6850f14486197e19f721b80d"

    def __init__(self,
                 max_history_tokens: int = None,
                 max_history_turns: int = None,
                 api_url: str = None,
                 api_key: str = None,
                 appId: str = None):
        """
        Initializes the TurnDetector with FastGPT API configuration.

        Configuration precedence for the API settings:
        explicit argument > environment variable (CHAT_MODEL_API_URL /
        CHAT_MODEL_API_KEY / CHAT_MODEL_APP_ID) > class-level default.

        Args:
            max_history_tokens: Maximum number of tokens for the input sequence. Defaults to MAX_HISTORY_TOKENS.
            max_history_turns: Maximum number of turns to consider in history. Defaults to MAX_HISTORY_TURNS.
            api_url: API URL for the FastGPT model. Defaults to CHAT_MODEL_API_URL env var, then API_URL.
            api_key: API key for authentication. Defaults to CHAT_MODEL_API_KEY env var, then API_KEY.
            appId: Application ID for the FastGPT model. Defaults to CHAT_MODEL_APP_ID env var, then APP_ID.
        """
        # Local import keeps this module's top-level import list unchanged.
        import os

        # Store configuration, using provided args, env vars, or class defaults.
        self._api_url = api_url or os.environ.get('CHAT_MODEL_API_URL') or self.API_URL
        self._api_key = api_key or os.environ.get('CHAT_MODEL_API_KEY') or self.API_KEY
        self._appId = appId or os.environ.get('CHAT_MODEL_APP_ID') or self.APP_ID
        self._max_history_tokens = max_history_tokens or self.MAX_HISTORY_TOKENS
        self._max_history_turns = max_history_turns or self.MAX_HISTORY_TURNS

        log_info(None, "FastGPT TurnDetector initialized",
                 api_url=self._api_url,
                 app_id=self._appId)

        self._chat_model = ChatModel(
            api_url=self._api_url,
            api_key=self._api_key,
            appId=self._appId
        )

    def _format_chat_ctx(self, chat_context: List[ChatMessage]) -> str:
        """
        Formats the chat context into a string for model input.

        Only user and assistant messages are included; each line is prefixed
        with a Chinese speaker label ("客服" = agent, "用户" = user) as
        expected by the FastGPT prompt.

        Args:
            chat_context: A list of ChatMessage objects representing the conversation history.

        Returns:
            A string containing the formatted conversation history.
        """
        lst = []
        for message in chat_context:
            if message.role == 'assistant':
                lst.append(f"客服: {message.content}")
            elif message.role == 'user':
                lst.append(f"用户: {message.content}")
        return "\n".join(lst)

    async def predict(self, chat_context: List[ChatMessage], client_id: str = None) -> bool:
        """
        Predicts whether the current utterance is complete using FastGPT API.

        Args:
            chat_context: A list of ChatMessage objects representing the conversation history.
            client_id: Client identifier for logging purposes.

        Returns:
            True if the utterance is complete, False otherwise. On API
            failure, defaults to True so callers are never blocked.
        """
        if not chat_context:
            log_warning(client_id, "Empty chat context provided, returning False")
            return False

        start_time = time.perf_counter()
        # Only the trailing turns are sent, to keep the prompt short.
        text = self._format_chat_ctx(chat_context[-self._max_history_turns:])

        log_debug(client_id, "FastGPT turn detection processing",
                  context_length=len(chat_context),
                  text_length=len(text))

        # Generate a unique chat ID for this prediction
        chat_id = f"turn_detection_{int(time.time() * 1000)}"

        try:
            output = await self._chat_model.generate_ai_response(chat_id, text)
            # The FastGPT app answers '完整' ("complete") when the turn is done.
            result = output == '完整'

            duration = time.perf_counter() - start_time

            log_performance(client_id, "FastGPT turn detection completed",
                            duration=f"{duration:.3f}s",
                            output=output,
                            result=result)

            log_debug(client_id, "FastGPT turn detection result",
                      output=output,
                      is_complete=result)

            return result

        except Exception as e:
            duration = time.perf_counter() - start_time

            log_warning(client_id, f"FastGPT turn detection failed: {e}",
                        duration=f"{duration:.3f}s",
                        exception_type=type(e).__name__)
            # Default to True (complete) on error to avoid blocking
            return True

    async def predict_probability(self, chat_context: List[ChatMessage], client_id: str = None) -> float:
        """
        Predicts the probability that the current utterance is complete.
        For FastGPT turn detector, this is a simplified implementation
        derived from the boolean prediction.

        Args:
            chat_context: A list of ChatMessage objects representing the conversation history.
            client_id: Client identifier for logging purposes.

        Returns:
            A float representing the probability (1.0 for complete, 0.0 for incomplete).
        """
        is_complete = await self.predict(chat_context, client_id)
        return 1.0 if is_complete else 0.0
|
||||
|
||||
async def main():
    """Example usage of the FastGPT TurnDetector class."""
    # Sample conversation from a Chinese 12345-hotline transcript. The final
    # user message trails off mid-complaint, exercising the detector on a
    # borderline "is the caller done speaking?" case.
    chat_ctx = [
        ChatMessage(role='assistant', content='目前人工坐席繁忙,我是12345智能客服。请详细说出您要反映的事项,如事件发生的时间、地址、具体的经过以及您期望的解决方案等'),
        ChatMessage(role='user', content='喂,喂'),
        ChatMessage(role='assistant', content='您好,请问有什么可以帮到您?'),
        ChatMessage(role='user', content='嗯,我想问一下,就是我在那个网上买那个迪士尼门票快。过期了,然后找不到。找不到客服退货怎么办'),
    ]

    turn_detection = TurnDetector()
    result = await turn_detection.predict(chat_ctx, client_id="test_client")
    log_info("test_client", f"FastGPT turn detection result: {result}")
    return result


if __name__ == "__main__":
    asyncio.run(main())
|
||||
376
src/turn_detection/onnx_detector.py
Normal file
376
src/turn_detection/onnx_detector.py
Normal file
@@ -0,0 +1,376 @@
|
||||
"""
|
||||
ONNX-based Turn Detector
|
||||
|
||||
A turn detector implementation using a pre-trained ONNX model and Hugging Face tokenizer.
|
||||
"""
|
||||
|
||||
import psutil
|
||||
import math
|
||||
import json
|
||||
import time
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoTokenizer
|
||||
import asyncio
|
||||
from typing import List
|
||||
|
||||
from .base import BaseTurnDetector, ChatMessage
|
||||
from logger import log_model, log_predict, log_performance, log_warning
|
||||
|
||||
class TurnDetector(BaseTurnDetector):
|
||||
"""
|
||||
A class to detect the end of an utterance (turn) in a conversation
|
||||
using a pre-trained ONNX model and Hugging Face tokenizer.
|
||||
"""
|
||||
|
||||
# --- Class Constants (Default Configuration) ---
|
||||
# These can be overridden during instantiation if needed
|
||||
HG_MODEL: str = "livekit/turn-detector"
|
||||
ONNX_FILENAME: str = "model_q8.onnx"
|
||||
MODEL_REVISION: str = "v0.2.0-intl"
|
||||
MAX_HISTORY_TOKENS: int = 128
|
||||
MAX_HISTORY_TURNS: int = 6
|
||||
INFERENCE_METHOD: str = "lk_end_of_utterance_multilingual"
|
||||
UNLIKELY_THRESHOLD: float = 0.005
|
||||
|
||||
def __init__(self,
|
||||
max_history_tokens: int = None,
|
||||
max_history_turns: int = None,
|
||||
hg_model: str = None,
|
||||
onnx_filename: str = None,
|
||||
model_revision: str = None,
|
||||
inference_method: str = None,
|
||||
unlikely_threshold: float = None):
|
||||
"""
|
||||
Initializes the TurnDetector by downloading and loading the necessary
|
||||
model files, tokenizer, and configuration.
|
||||
|
||||
Args:
|
||||
max_history_tokens: Maximum number of tokens for the input sequence. Defaults to MAX_HISTORY_TOKENS.
|
||||
max_history_turns: Maximum number of turns to consider in history. Defaults to MAX_HISTORY_TURNS.
|
||||
hg_model: Hugging Face model identifier. Defaults to HG_MODEL.
|
||||
onnx_filename: ONNX model filename. Defaults to ONNX_FILENAME.
|
||||
model_revision: Model revision/tag. Defaults to MODEL_REVISION.
|
||||
inference_method: Inference method name. Defaults to INFERENCE_METHOD.
|
||||
unlikely_threshold: Threshold for determining if utterance is complete. Defaults to UNLIKELY_THRESHOLD.
|
||||
"""
|
||||
# Store configuration, using provided args or class defaults
|
||||
self._max_history_tokens = max_history_tokens or self.MAX_HISTORY_TOKENS
|
||||
self._max_history_turns = max_history_turns or self.MAX_HISTORY_TURNS
|
||||
self._hg_model = hg_model or self.HG_MODEL
|
||||
self._onnx_filename = onnx_filename or self.ONNX_FILENAME
|
||||
self._model_revision = model_revision or self.MODEL_REVISION
|
||||
self._inference_method = inference_method or self.INFERENCE_METHOD
|
||||
|
||||
# Initialize model components
|
||||
self._languages = None
|
||||
self._session = None
|
||||
self._tokenizer = None
|
||||
self._unlikely_threshold = unlikely_threshold or self.UNLIKELY_THRESHOLD
|
||||
|
||||
log_model(None, "Initializing TurnDetector",
|
||||
model=self._hg_model,
|
||||
revision=self._model_revision,
|
||||
threshold=self._unlikely_threshold)
|
||||
|
||||
# Load model components
|
||||
self._load_model_components()
|
||||
|
||||
async def _download_from_hf_hub_async(self, repo_id: str, filename: str, **kwargs) -> str:
|
||||
"""
|
||||
Downloads a file from Hugging Face Hub asynchronously.
|
||||
|
||||
Args:
|
||||
repo_id: Repository ID on Hugging Face Hub.
|
||||
filename: Name of the file to download.
|
||||
**kwargs: Additional arguments for hf_hub_download.
|
||||
|
||||
Returns:
|
||||
Local path to the downloaded file.
|
||||
"""
|
||||
# Run the synchronous download in a thread pool to make it async
|
||||
loop = asyncio.get_event_loop()
|
||||
local_path = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
|
||||
)
|
||||
return local_path
|
||||
|
||||
def _download_from_hf_hub(self, repo_id: str, filename: str, **kwargs) -> str:
|
||||
"""
|
||||
Downloads a file from Hugging Face Hub (synchronous version).
|
||||
|
||||
Args:
|
||||
repo_id: Repository ID on Hugging Face Hub.
|
||||
filename: Name of the file to download.
|
||||
**kwargs: Additional arguments for hf_hub_download.
|
||||
|
||||
Returns:
|
||||
Local path to the downloaded file.
|
||||
"""
|
||||
local_path = hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
|
||||
return local_path
|
||||
|
||||
async def _load_model_components_async(self):
|
||||
"""Loads and initializes the model, tokenizer, and configuration asynchronously."""
|
||||
log_model(None, "Loading model components asynchronously")
|
||||
|
||||
# Load languages configuration
|
||||
config_fname = await self._download_from_hf_hub_async(
|
||||
self._hg_model,
|
||||
"languages.json",
|
||||
revision=self._model_revision,
|
||||
local_files_only=False
|
||||
)
|
||||
|
||||
# Read file asynchronously
|
||||
loop = asyncio.get_event_loop()
|
||||
with open(config_fname) as f:
|
||||
self._languages = json.load(f)
|
||||
log_model(None, "Languages configuration loaded", languages_count=len(self._languages))
|
||||
|
||||
# Load ONNX model
|
||||
local_path_onnx = await self._download_from_hf_hub_async(
|
||||
self._hg_model,
|
||||
self._onnx_filename,
|
||||
subfolder="onnx",
|
||||
revision=self._model_revision,
|
||||
local_files_only=False,
|
||||
)
|
||||
|
||||
# Configure ONNX session
|
||||
sess_options = ort.SessionOptions()
|
||||
sess_options.intra_op_num_threads = max(
|
||||
1, math.ceil(psutil.cpu_count()) // 2
|
||||
)
|
||||
sess_options.inter_op_num_threads = 1
|
||||
sess_options.add_session_config_entry("session.dynamic_block_base", "4")
|
||||
|
||||
self._session = ort.InferenceSession(
|
||||
local_path_onnx, providers=["CPUExecutionProvider"], sess_options=sess_options
|
||||
)
|
||||
|
||||
# Load tokenizer
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(
|
||||
self._hg_model,
|
||||
revision=self._model_revision,
|
||||
local_files_only=False,
|
||||
truncation_side="left",
|
||||
)
|
||||
|
||||
log_model(None, "Model components loaded successfully",
|
||||
onnx_path=local_path_onnx,
|
||||
intra_threads=sess_options.intra_op_num_threads)
|
||||
|
||||
def _load_model_components(self):
|
||||
"""Loads and initializes the model, tokenizer, and configuration."""
|
||||
log_model(None, "Loading model components")
|
||||
|
||||
# Load languages configuration
|
||||
config_fname = self._download_from_hf_hub(
|
||||
self._hg_model,
|
||||
"languages.json",
|
||||
revision=self._model_revision,
|
||||
local_files_only=False
|
||||
)
|
||||
with open(config_fname) as f:
|
||||
self._languages = json.load(f)
|
||||
log_model(None, "Languages configuration loaded", languages_count=len(self._languages))
|
||||
|
||||
# Load ONNX model
|
||||
local_path_onnx = self._download_from_hf_hub(
|
||||
self._hg_model,
|
||||
self._onnx_filename,
|
||||
subfolder="onnx",
|
||||
revision=self._model_revision,
|
||||
local_files_only=False,
|
||||
)
|
||||
|
||||
# Configure ONNX session
|
||||
sess_options = ort.SessionOptions()
|
||||
sess_options.intra_op_num_threads = max(
|
||||
1, math.ceil(psutil.cpu_count()) // 2
|
||||
)
|
||||
sess_options.inter_op_num_threads = 1
|
||||
sess_options.add_session_config_entry("session.dynamic_block_base", "4")
|
||||
|
||||
self._session = ort.InferenceSession(
|
||||
local_path_onnx, providers=["CPUExecutionProvider"], sess_options=sess_options
|
||||
)
|
||||
|
||||
# Load tokenizer
|
||||
self._tokenizer = AutoTokenizer.from_pretrained(
|
||||
self._hg_model,
|
||||
revision=self._model_revision,
|
||||
local_files_only=False,
|
||||
truncation_side="left",
|
||||
)
|
||||
|
||||
log_model(None, "Model components loaded successfully",
|
||||
onnx_path=local_path_onnx,
|
||||
intra_threads=sess_options.intra_op_num_threads)
|
||||
|
||||
def _format_chat_ctx(self, chat_context: List[ChatMessage]) -> str:
|
||||
"""
|
||||
Formats the chat context into a string for model input.
|
||||
|
||||
Args:
|
||||
chat_context: A list of ChatMessage objects representing the conversation history.
|
||||
|
||||
Returns:
|
||||
A string containing the formatted conversation history.
|
||||
"""
|
||||
new_chat_ctx = []
|
||||
for msg in chat_context:
|
||||
new_chat_ctx.append(msg)
|
||||
|
||||
convo_text = self._tokenizer.apply_chat_template(
|
||||
new_chat_ctx,
|
||||
add_generation_prompt=False,
|
||||
add_special_tokens=False,
|
||||
tokenize=False,
|
||||
)
|
||||
|
||||
# remove the EOU token from current utterance
|
||||
ix = convo_text.rfind("<|im_end|>")
|
||||
text = convo_text[:ix]
|
||||
return text
|
||||
|
||||
async def predict(self, chat_context: List[ChatMessage], client_id: str = None) -> bool:
|
||||
"""
|
||||
Predicts the probability that the current utterance is complete.
|
||||
|
||||
Args:
|
||||
chat_context: A list of ChatMessage objects representing the conversation history.
|
||||
client_id: Client identifier for logging purposes.
|
||||
|
||||
Returns:
|
||||
is_complete: True if the utterance is complete, False otherwise.
|
||||
"""
|
||||
if not chat_context:
|
||||
log_warning(client_id, "Empty chat context provided, returning False")
|
||||
return False
|
||||
|
||||
start_time = time.perf_counter()
|
||||
text = self._format_chat_ctx(chat_context[-self._max_history_turns:])
|
||||
log_predict(client_id, "Processing turn detection",
|
||||
context_length=len(chat_context),
|
||||
text_length=len(text))
|
||||
|
||||
# Run tokenization in thread pool to avoid blocking
|
||||
loop = asyncio.get_event_loop()
|
||||
inputs = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self._tokenizer(
|
||||
text,
|
||||
add_special_tokens=False,
|
||||
return_tensors="np",
|
||||
max_length=self._max_history_tokens,
|
||||
truncation=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Run inference in thread pool
|
||||
outputs = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self._session.run(
|
||||
None, {"input_ids": inputs["input_ids"].astype("int64")}
|
||||
)
|
||||
)
|
||||
eou_probability = outputs[0].flatten()[-1]
|
||||
|
||||
end_time = time.perf_counter()
|
||||
duration = end_time - start_time
|
||||
|
||||
log_predict(client_id, "Turn detection completed",
|
||||
probability=f"{eou_probability:.6f}",
|
||||
threshold=self._unlikely_threshold,
|
||||
is_complete=eou_probability > self._unlikely_threshold)
|
||||
|
||||
log_performance(client_id, "Prediction performance",
|
||||
duration=f"{duration:.3f}s",
|
||||
input_tokens=inputs["input_ids"].shape[1])
|
||||
|
||||
if eou_probability > self._unlikely_threshold:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
async def predict_probability(self, chat_context: List[ChatMessage], client_id: str = None) -> float:
    """
    Predict the probability that the current utterance is complete.

    Args:
        chat_context: Conversation history as a list of ChatMessage objects.
        client_id: Optional client identifier used only for logging.

    Returns:
        A float in the model's output range giving the end-of-utterance
        probability; 0.0 when the context is empty.
    """
    if not chat_context:
        log_warning(client_id, "Empty chat context provided, returning 0.0 probability")
        return 0.0

    start_time = time.perf_counter()
    # Only the most recent turns are given to the model.
    text = self._format_chat_ctx(chat_context[-self._max_history_turns:])

    log_predict(client_id, "Processing probability prediction",
                context_length=len(chat_context),
                text_length=len(text))

    # get_running_loop() is the supported way to obtain the loop from inside
    # a coroutine; get_event_loop() is deprecated for this use (3.10+).
    loop = asyncio.get_running_loop()

    # Tokenization is CPU-bound; run it in the default executor so the
    # event loop is not blocked.
    inputs = await loop.run_in_executor(
        None,
        lambda: self._tokenizer(
            text,
            add_special_tokens=False,
            return_tensors="np",
            max_length=self._max_history_tokens,
            truncation=True,
        ),
    )

    # ONNX inference is also CPU-bound; run it off the event loop too.
    outputs = await loop.run_in_executor(
        None,
        lambda: self._session.run(
            None, {"input_ids": inputs["input_ids"].astype("int64")}
        ),
    )
    # The model emits one score per token; the last one corresponds to the
    # end of the current utterance.
    eou_probability = outputs[0].flatten()[-1]

    duration = time.perf_counter() - start_time

    log_predict(client_id, "Probability prediction completed",
                probability=f"{eou_probability:.6f}")

    log_performance(client_id, "Prediction performance",
                    duration=f"{duration:.3f}s",
                    input_tokens=inputs["input_ids"].shape[1])

    # Convert numpy scalar to a plain Python float for callers.
    return float(eou_probability)
|
||||
|
||||
async def main():
    """Example usage of the TurnDetector class.

    Builds a small conversation ending in a truncated user turn ("我想"),
    then runs both the boolean prediction and the raw-probability API.
    """
    chat_ctx = [
        ChatMessage(role='assistant', content='您好,请问有什么可以帮到您?'),
        # Swap the truncated turn below for a complete sentence such as
        # '我想咨询一下退票的问题。' to observe a high completion probability.
        ChatMessage(role='user', content='我想')
    ]

    turn_detection = TurnDetector()
    result = await turn_detection.predict(chat_ctx, client_id="test_client")
    # log_info is already imported at module level; no local import needed.
    log_info("test_client", f"Final prediction result: {result}")

    # Also exercise the raw-probability method.
    probability = await turn_detection.predict_probability(chat_ctx, client_id="test_client")
    log_info("test_client", f"Probability result: {probability}")

    return result
|
||||
|
||||
# Allow running this module directly as a quick manual smoke test.
if __name__ == "__main__":
    asyncio.run(main())
|
||||
|
||||
Reference in New Issue
Block a user