This commit is contained in:
Xin Wang
2025-06-19 17:39:45 +08:00
commit e46f30c742
17 changed files with 3174 additions and 0 deletions

261
src/fastgpt_api.py Normal file
View File

@@ -0,0 +1,261 @@
import aiohttp
import asyncio
import time
from typing import List, Dict
from logger import log_info, log_debug, log_warning, log_error, log_performance
class ChatModel:
    """Async client for the FastGPT HTTP API.

    Wraps the chat-init, chat-completions and pagination-records endpoints.
    Every call logs timing/outcome through the project logger and raises
    ``Exception`` with a descriptive message on HTTP or network failure.
    """

    def __init__(self, api_key: str, api_url: str, appId: str, client_id: str = None):
        """Store connection settings for later requests.

        Args:
            api_key: Bearer token for the FastGPT API.
            api_url: Base URL of the FastGPT deployment (no trailing API path).
            appId: FastGPT application id.
            client_id: Optional identifier used purely for log correlation.
        """
        self._api_key = api_key
        self._api_url = api_url
        self._appId = appId
        self._client_id = client_id
        log_info(self._client_id, "ChatModel initialized",
                 api_url=self._api_url,
                 app_id=self._appId)

    async def get_welcome_text(self, chatId: str) -> str:
        """Get welcome text from FastGPT API.

        Args:
            chatId: Chat session identifier.

        Returns:
            The app's configured welcome text.

        Raises:
            Exception: On non-200 status, network error, or unexpected
                response shape.
        """
        start_time = time.perf_counter()
        url = f'{self._api_url}/api/core/chat/init'
        log_debug(self._client_id, "Requesting welcome text",
                  chat_id=chatId,
                  url=url)
        headers = {
            'Authorization': f'Bearer {self._api_key}'
        }
        params = {
            'appId': self._appId,
            'chatId': chatId
        }
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url, headers=headers, params=params) as response:
                    duration = time.perf_counter() - start_time
                    if response.status == 200:
                        response_data = await response.json()
                        # Raises KeyError (handled below as unexpected) if the
                        # response shape differs from the documented schema.
                        welcome_text = response_data['data']['app']['chatConfig']['welcomeText']
                        log_performance(self._client_id, "Welcome text request completed",
                                        duration=f"{duration:.3f}s",
                                        status_code=response.status,
                                        response_length=len(welcome_text))
                        log_debug(self._client_id, "Welcome text retrieved",
                                  chat_id=chatId,
                                  welcome_text_length=len(welcome_text))
                        return welcome_text
                    error_msg = f"Failed to get welcome text. Status code: {response.status}"
                    log_error(self._client_id, error_msg,
                              chat_id=chatId,
                              status_code=response.status,
                              url=url)
                    raise Exception(error_msg)
        except aiohttp.ClientError as e:
            duration = time.perf_counter() - start_time
            error_msg = f"Network error while getting welcome text: {e}"
            log_error(self._client_id, error_msg,
                      chat_id=chatId,
                      duration=f"{duration:.3f}s",
                      exception_type=type(e).__name__)
            # Chain the original error so the root cause survives in tracebacks.
            raise Exception(error_msg) from e
        except Exception as e:
            duration = time.perf_counter() - start_time
            error_msg = f"Unexpected error while getting welcome text: {e}"
            log_error(self._client_id, error_msg,
                      chat_id=chatId,
                      duration=f"{duration:.3f}s",
                      exception_type=type(e).__name__)
            raise

    async def generate_ai_response(self, chatId: str, content: str) -> str:
        """Generate AI response from FastGPT API.

        Args:
            chatId: Chat session identifier (FastGPT keeps history per chatId).
            content: The user's message text.

        Returns:
            The assistant's reply text.

        Raises:
            Exception: On non-200 status, network error, or unexpected
                response shape.
        """
        start_time = time.perf_counter()
        url = f'{self._api_url}/api/v1/chat/completions'
        log_debug(self._client_id, "Generating AI response",
                  chat_id=chatId,
                  content_length=len(content),
                  url=url)
        headers = {
            'Authorization': f'Bearer {self._api_key}',
            'Content-Type': 'application/json'
        }
        data = {
            'chatId': chatId,
            'messages': [
                {
                    'content': content,
                    'role': 'user'
                }
            ]
        }
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(url, headers=headers, json=data) as response:
                    duration = time.perf_counter() - start_time
                    if response.status == 200:
                        response_data = await response.json()
                        # OpenAI-compatible completion shape.
                        ai_response = response_data['choices'][0]['message']['content']
                        log_performance(self._client_id, "AI response generation completed",
                                        duration=f"{duration:.3f}s",
                                        status_code=response.status,
                                        input_length=len(content),
                                        output_length=len(ai_response))
                        log_debug(self._client_id, "AI response generated",
                                  chat_id=chatId,
                                  input_length=len(content),
                                  response_length=len(ai_response))
                        return ai_response
                    error_msg = f"Failed to generate AI response. Status code: {response.status}"
                    log_error(self._client_id, error_msg,
                              chat_id=chatId,
                              status_code=response.status,
                              url=url,
                              input_length=len(content))
                    raise Exception(error_msg)
        except aiohttp.ClientError as e:
            duration = time.perf_counter() - start_time
            error_msg = f"Network error while generating AI response: {e}"
            log_error(self._client_id, error_msg,
                      chat_id=chatId,
                      duration=f"{duration:.3f}s",
                      exception_type=type(e).__name__,
                      input_length=len(content))
            # Chain the original error so the root cause survives in tracebacks.
            raise Exception(error_msg) from e
        except Exception as e:
            duration = time.perf_counter() - start_time
            error_msg = f"Unexpected error while generating AI response: {e}"
            log_error(self._client_id, error_msg,
                      chat_id=chatId,
                      duration=f"{duration:.3f}s",
                      exception_type=type(e).__name__,
                      input_length=len(content))
            raise

    async def get_chat_history(self, chatId: str) -> List[Dict[str, str]]:
        """Get chat history from FastGPT API.

        Args:
            chatId: Chat session identifier.

        Returns:
            A list of ``{'role': 'user'|'assistant', 'content': str}`` dicts
            in API order. Records with other ``obj`` values are skipped.

        Raises:
            Exception: On non-200 status, network error, or unexpected
                response shape.
        """
        start_time = time.perf_counter()
        url = f'{self._api_url}/api/core/chat/getPaginationRecords'
        log_debug(self._client_id, "Fetching chat history",
                  chat_id=chatId,
                  url=url)
        headers = {
            'Authorization': f'Bearer {self._api_key}',
            'Content-Type': 'application/json'
        }
        data = {
            'appId': self._appId,
            'chatId': chatId,
            'loadCustomFeedbacks': False
        }
        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(url, headers=headers, json=data) as response:
                    duration = time.perf_counter() - start_time
                    if response.status == 200:
                        response_data = await response.json()
                        chat_history = []
                        # 'obj' distinguishes speaker: 'Human' vs 'AI'.
                        for element in response_data['data']['list']:
                            if element['obj'] == 'Human':
                                chat_history.append({'role': 'user', 'content': element['value'][0]['text']})
                            elif element['obj'] == 'AI':
                                chat_history.append({'role': 'assistant', 'content': element['value'][0]['text']})
                        log_performance(self._client_id, "Chat history fetch completed",
                                        duration=f"{duration:.3f}s",
                                        status_code=response.status,
                                        history_count=len(chat_history))
                        log_debug(self._client_id, "Chat history retrieved",
                                  chat_id=chatId,
                                  history_count=len(chat_history))
                        return chat_history
                    error_msg = f"Failed to fetch chat history. Status code: {response.status}"
                    log_error(self._client_id, error_msg,
                              chat_id=chatId,
                              status_code=response.status,
                              url=url)
                    raise Exception(error_msg)
        except aiohttp.ClientError as e:
            duration = time.perf_counter() - start_time
            error_msg = f"Network error while fetching chat history: {e}"
            log_error(self._client_id, error_msg,
                      chat_id=chatId,
                      duration=f"{duration:.3f}s",
                      exception_type=type(e).__name__)
            # Chain the original error so the root cause survives in tracebacks.
            raise Exception(error_msg) from e
        except Exception as e:
            duration = time.perf_counter() - start_time
            error_msg = f"Unexpected error while fetching chat history: {e}"
            log_error(self._client_id, error_msg,
                      chat_id=chatId,
                      duration=f"{duration:.3f}s",
                      exception_type=type(e).__name__)
            raise
async def main():
    """Example usage of the ChatModel class.

    Credentials are read from the environment when available so that real
    secrets do not have to live in source control.
    """
    import os
    # SECURITY: the literals below were previously hard-coded (and are now
    # leaked in VCS history) — rotate this API key and remove the fallbacks;
    # they remain only so the example still runs without configuration.
    chat_model = ChatModel(
        api_key=os.getenv("CHAT_MODEL_API_KEY",
                          "fastgpt-tgpSdDSE51cc6BPdb92ODfsm0apZRXOrc75YeaiZ8HmqlYplZKi5flvJUqjG5b"),
        api_url=os.getenv("CHAT_MODEL_API_URL", "http://101.89.151.141:3000/"),
        appId=os.getenv("CHAT_MODEL_APP_ID", "6846890686197e19f72036f9"),
        client_id="test_client"
    )
    try:
        log_info("test_client", "Starting FastGPT API tests")
        # Test welcome text
        welcome_text = await chat_model.get_welcome_text('welcome')
        log_info("test_client", "Welcome text test completed", welcome_text_length=len(welcome_text))
        # Test AI response generation
        response = await chat_model.generate_ai_response('chat0002', '我想问一下怎么用fastgpt')
        log_info("test_client", "AI response test completed", response_length=len(response))
        # Test chat history
        history = await chat_model.get_chat_history('chat0002')
        log_info("test_client", "Chat history test completed", history_count=len(history))
        log_info("test_client", "All FastGPT API tests completed successfully")
    except Exception as e:
        log_error("test_client", f"Test failed: {e}", exception_type=type(e).__name__)
        raise


if __name__ == "__main__":
    asyncio.run(main())

98
src/logger.py Normal file
View File

@@ -0,0 +1,98 @@
import datetime
from typing import Optional
# ANSI escape codes for colors
class LogColors:
    """ANSI escape sequences used to colorize console log output."""
    HEADER = '\033[95m'      # bright magenta
    OKBLUE = '\033[94m'      # bright blue
    OKCYAN = '\033[96m'      # bright cyan
    OKGREEN = '\033[92m'     # bright green
    WARNING = '\033[93m'     # bright yellow
    FAIL = '\033[91m'        # bright red
    ENDC = '\033[0m'         # reset all attributes
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
# Log levels and symbols
# Maps level name -> (symbol, ANSI color) pair consumed by app_log.
# NOTE(review): the INFO, ERROR and PERFORMANCE symbols are empty strings,
# which looks like emoji lost in an encoding pass — confirm and restore them.
LOG_LEVELS = {
    "INFO": ("", LogColors.OKGREEN),
    "DEBUG": ("🐛", LogColors.OKCYAN),
    "WARNING": ("⚠️", LogColors.WARNING),
    "ERROR": ("", LogColors.FAIL),
    "TIMEOUT": ("⏱️", LogColors.OKBLUE),
    "USER_INPUT": ("💬", LogColors.HEADER),
    "AI_RESPONSE": ("🤖", LogColors.OKBLUE),
    "SESSION": ("🔗", LogColors.BOLD),
    "MODEL": ("🧠", LogColors.OKCYAN),
    "PREDICT": ("🎯", LogColors.HEADER),
    "PERFORMANCE": ("", LogColors.OKGREEN),
    "CONNECTION": ("🌐", LogColors.OKBLUE)
}
def app_log(level: str, client_id: Optional[str], message: str, **kwargs):
    """Print one colorized log line to stdout.

    The line carries a millisecond-precision timestamp, the upper-cased
    level with its symbol, an optional client id, the message, and any
    extra key=value context pairs joined with " | ".

    Args:
        level: Log level name (INFO, DEBUG, WARNING, ERROR, ...).
        client_id: Client identifier for session tracking, or None.
        message: Main log message.
        **kwargs: Additional key-value pairs appended to the line.
    """
    stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]  # trim µs -> ms
    # Unknown levels fall back to a neutral symbol with no color.
    symbol, color = LOG_LEVELS.get(level.upper(), ("🔹", LogColors.ENDC))
    tag = f" ({client_id})" if client_id else ""
    context = "".join(f" | {key}={value}" for key, value in kwargs.items())
    print(f"{color}{stamp} [{level.upper()}] {symbol}{tag}: {message}{context}{LogColors.ENDC}")
# Thin convenience wrappers: one function per log level, all delegating
# straight to app_log with the level name filled in.
def log_info(client_id: Optional[str], message: str, **kwargs) -> None:
    """Log an info message."""
    app_log("INFO", client_id, message, **kwargs)
def log_debug(client_id: Optional[str], message: str, **kwargs) -> None:
    """Log a debug message."""
    app_log("DEBUG", client_id, message, **kwargs)
def log_warning(client_id: Optional[str], message: str, **kwargs) -> None:
    """Log a warning message."""
    app_log("WARNING", client_id, message, **kwargs)
def log_error(client_id: Optional[str], message: str, **kwargs) -> None:
    """Log an error message."""
    app_log("ERROR", client_id, message, **kwargs)
def log_model(client_id: Optional[str], message: str, **kwargs) -> None:
    """Log a model-related message."""
    app_log("MODEL", client_id, message, **kwargs)
def log_predict(client_id: Optional[str], message: str, **kwargs) -> None:
    """Log a prediction-related message."""
    app_log("PREDICT", client_id, message, **kwargs)
def log_performance(client_id: Optional[str], message: str, **kwargs) -> None:
    """Log a performance-related message."""
    app_log("PERFORMANCE", client_id, message, **kwargs)
def log_connection(client_id: Optional[str], message: str, **kwargs) -> None:
    """Log a connection-related message."""
    app_log("CONNECTION", client_id, message, **kwargs)
def log_timeout(client_id: Optional[str], message: str, **kwargs) -> None:
    """Log a timeout-related message."""
    app_log("TIMEOUT", client_id, message, **kwargs)
def log_user_input(client_id: Optional[str], message: str, **kwargs) -> None:
    """Log a user input message."""
    app_log("USER_INPUT", client_id, message, **kwargs)
def log_ai_response(client_id: Optional[str], message: str, **kwargs) -> None:
    """Log an AI response message."""
    app_log("AI_RESPONSE", client_id, message, **kwargs)
def log_session(client_id: Optional[str], message: str, **kwargs) -> None:
    """Log a session-related message."""
    app_log("SESSION", client_id, message, **kwargs)

376
src/main.py Normal file
View File

@@ -0,0 +1,376 @@
import os
import asyncio
import json
import time
import datetime # Added for timestamp
import dotenv
import urllib.parse # For parsing query parameters
import websockets # Make sure it's imported at the top
from turn_detection import ChatMessage, TurnDetectorFactory, ONNX_AVAILABLE, FASTGPT_AVAILABLE
from fastgpt_api import ChatModel
from logger import (app_log, log_info, log_debug, log_warning, log_error,
log_timeout, log_user_input, log_ai_response, log_session)
dotenv.load_dotenv()
# Force-process a buffered turn once this many incomplete sentences accumulate.
MAX_INCOMPLETE_SENTENCES = int(os.getenv("MAX_INCOMPLETE_SENTENCES", 3))
# Seconds of user silence (after AI playback ends) before flushing the buffer.
MAX_RESPONSE_TIMEOUT = int(os.getenv("MAX_RESPONSE_TIMEOUT", 5))
# FastGPT connection settings; None when unset (requests will then fail at call time).
CHAT_MODEL_API_URL = os.getenv("CHAT_MODEL_API_URL", None)
CHAT_MODEL_API_KEY = os.getenv("CHAT_MODEL_API_KEY", None)
CHAT_MODEL_APP_ID = os.getenv("CHAT_MODEL_APP_ID", None)
# Turn Detection Configuration
TURN_DETECTION_MODEL = os.getenv("TURN_DETECTION_MODEL", "onnx").lower()  # "onnx", "fastgpt", "always_true"
ONNX_UNLIKELY_THRESHOLD = float(os.getenv("ONNX_UNLIKELY_THRESHOLD", 0.0009))
def estimate_tts_playtime(text: str, chars_per_second: float = 5.6) -> float:
    """Estimate how long TTS playback of *text* will take, in seconds.

    Args:
        text: The text to be spoken.
        chars_per_second: Speaking rate used for the estimate. The 5.6
            default is the rate previously hard-coded for the deployed voice.

    Returns:
        0.0 for empty text; otherwise at least 0.5 seconds.
    """
    if not text:
        return 0.0
    # Floor at 0.5s so very short utterances still get a playback window.
    return max(0.5, len(text) / chars_per_second)
def create_turn_detector_with_fallback():
    """
    Create a turn detector, falling back when the configured mode is unavailable.

    Fallback preference when TURN_DETECTION_MODEL cannot be used:
    fastgpt, then onnx, then always_true.

    Returns:
        Turn detector instance
    """
    available = TurnDetectorFactory.get_available_detectors()
    mode = TURN_DETECTION_MODEL
    if not available.get(mode, False):
        # Requested mode is unavailable: report why, then pick a fallback.
        log_warning(None, f"Requested turn detection mode '{TURN_DETECTION_MODEL}' is not available")
        log_info(None, "Available turn detectors", available_detectors=available)
        errors = TurnDetectorFactory.get_import_errors()
        if errors:
            log_warning(None, "Import errors for unavailable detectors", import_errors=errors)
        if available.get("fastgpt", False):
            mode = "fastgpt"
            log_info(None, f"Falling back to FastGPT turn detector")
        elif available.get("onnx", False):
            mode = "onnx"
            log_info(None, f"Falling back to ONNX turn detector")
        else:
            mode = "always_true"
            log_info(None, f"Falling back to AlwaysTrue turn detector (no ML models available)")
    # Only the ONNX detector takes the threshold keyword.
    if mode == "onnx":
        return TurnDetectorFactory.create_turn_detector(
            mode,
            unlikely_threshold=ONNX_UNLIKELY_THRESHOLD
        )
    return TurnDetectorFactory.create_turn_detector(mode)
class SessionData:
    """Per-client conversation state, kept across websocket reconnects."""
    def __init__(self, client_id):
        self.client_id = client_id
        # User fragments buffered while a turn is judged incomplete.
        self.incomplete_sentences = []
        # ChatMessage objects for the user/assistant turns so far.
        self.conversation_history = []
        # Wall-clock time of the most recent user input.
        self.last_input_time = time.time()
        # Pending asyncio.Task running handle_input_timeout, if any.
        self.timeout_task = None
        # Estimated wall-clock time when AI TTS playback finishes, or None.
        self.ai_response_playback_ends_at: float | None = None
# Global instances shared by all connections.
turn_detection_model = create_turn_detector_with_fallback()
# NOTE(review): `ai_model` and `chat_model` alias the same ChatModel
# instance; the handlers below reference `ai_model`.
ai_model = chat_model = ChatModel(
    api_key=CHAT_MODEL_API_KEY,
    api_url=CHAT_MODEL_API_URL,
    appId=CHAT_MODEL_APP_ID
)
# client_id -> SessionData; entries persist after disconnect (no eviction here).
sessions = {}
async def handle_input_timeout(websocket, session: SessionData):
    """Wait out AI playback plus the user-silence window, then flush the buffer.

    If the AI is still "speaking" (estimated TTS playback), first sleep until
    playback ends, then wait MAX_RESPONSE_TIMEOUT seconds of user silence.
    Any input buffered in the meantime is processed as one complete turn.
    New user input cancels this task (CancelledError is expected).
    """
    client_id = session.client_id
    try:
        if session.ai_response_playback_ends_at:
            remaining_ai_playtime = session.ai_response_playback_ends_at - time.time()
            if remaining_ai_playtime > 0:
                log_timeout(client_id, f"Waiting for AI playback to finish", remaining_playtime=f"{remaining_ai_playtime:.2f}s")
                await asyncio.sleep(remaining_ai_playtime)
        log_timeout(client_id, f"AI playback done. Starting user inactivity", timeout_seconds=MAX_RESPONSE_TIMEOUT)
        await asyncio.sleep(MAX_RESPONSE_TIMEOUT)
        # Reaching here means MAX_RESPONSE_TIMEOUT seconds of user silence
        # have passed *after* the AI finished. Flush any buffered input.
        if session.incomplete_sentences:
            # Join once — the original built the identical string twice.
            full_turn_text = ' '.join(session.incomplete_sentences)
            log_timeout(client_id, f"Processing buffered input after silence", buffer_content=f"'{full_turn_text}'")
            await process_complete_turn(websocket, session, full_turn_text)
        else:
            log_timeout(client_id, f"No buffered input after silence")
        session.timeout_task = None  # Clear the task reference
    except asyncio.CancelledError:
        # Expected: fresh user input cancels the pending timeout.
        log_info(client_id, f"Timeout task was cancelled", task_details=str(session.timeout_task))
    except Exception as e:
        log_error(client_id, f"Error in timeout handler: {e}", exception_type=type(e).__name__)
        session.timeout_task = None
async def handle_user_input(websocket, client_id: str, incoming_text: str):
    """Handle one piece of user input for an existing session.

    Buffers the input while the AI is estimated to still be speaking;
    otherwise runs turn detection over the buffered fragments plus this
    input and either processes a complete turn or re-arms the silence
    timeout.
    """
    # Strip the Chinese full stop, which can skew turn prediction.
    # BUGFIX: the original called .strip('') — a no-op — because the '。'
    # character was evidently lost from the literal (its comment read
    # "chinese period could affect prediction").
    incoming_text = incoming_text.strip('。')
    # client_id is passed from chat_handler and is known to exist in sessions.
    session = sessions[client_id]
    session.last_input_time = time.time()  # Update on EVERY user input
    # New input invalidates any pending timeout, whether it was waiting on
    # AI playback or on user silence.
    if session.timeout_task and not session.timeout_task.done():
        session.timeout_task.cancel()
        session.timeout_task = None
    ai_is_speaking_now = False
    if session.ai_response_playback_ends_at and time.time() < session.ai_response_playback_ends_at:
        ai_is_speaking_now = True
        log_user_input(client_id, f"AI speaking. Buffering: '{incoming_text}'", current_buffer_size=len(session.incomplete_sentences))
    if ai_is_speaking_now:
        # Buffer and re-arm the timeout; the turn will be evaluated later.
        session.incomplete_sentences.append(incoming_text)
        log_user_input(client_id, f"AI speaking. Scheduling new timeout", new_buffer_size=len(session.incomplete_sentences))
        session.timeout_task = asyncio.create_task(handle_input_timeout(websocket, session))
        return
    # AI is NOT speaking: run turn detection over buffered + current input.
    current_potential_turn_text = " ".join(session.incomplete_sentences + [incoming_text])
    context_for_turn_detection = session.conversation_history + [ChatMessage(role='user', content=current_potential_turn_text)]
    is_complete = await turn_detection_model.predict(
        context_for_turn_detection,
        client_id=client_id
    )
    log_debug(client_id, "Turn detection result",
              mode=TURN_DETECTION_MODEL,
              is_complete=is_complete,
              text_checked=current_potential_turn_text)
    if is_complete:
        await process_complete_turn(websocket, session, current_potential_turn_text)
        return
    session.incomplete_sentences.append(incoming_text)
    if len(session.incomplete_sentences) >= MAX_INCOMPLETE_SENTENCES:
        # Safety valve: don't let the buffer grow without bound.
        log_user_input(client_id, f"Max incomplete sentences limit reached. Processing", limit=MAX_INCOMPLETE_SENTENCES, current_count=len(session.incomplete_sentences))
        full_turn_text = " ".join(session.incomplete_sentences)
        await process_complete_turn(websocket, session, full_turn_text)
    else:
        log_user_input(client_id, f"Turn incomplete. Scheduling new timeout", current_buffer_size=len(session.incomplete_sentences))
        session.timeout_task = asyncio.create_task(handle_input_timeout(websocket, session))
async def process_complete_turn(websocket, session: SessionData, full_user_turn_text: str, is_welcome_message_context=False):
    """Generate and send the AI response for a completed user turn.

    For the initial welcome (is_welcome_message_context=True) the user text
    is not appended to history and the welcome endpoint is used instead of
    chat completion. On AI failure the just-appended user message is rolled
    back and an ERROR frame is sent instead of an AI_RESPONSE.
    """
    # For a welcome message, full_user_turn_text may be empty or a system prompt.
    if not is_welcome_message_context:  # Only add user message if it's not the initial welcome context
        session.conversation_history.append(ChatMessage(role="user", content=full_user_turn_text))
    session.incomplete_sentences = []
    try:
        # For the welcome path the history may be empty or hold a system seed.
        if not is_welcome_message_context:
            ai_response_text = await ai_model.generate_ai_response(session.client_id, full_user_turn_text)
        else:
            ai_response_text = await ai_model.get_welcome_text(session.client_id)
        log_debug(session.client_id, "AI model interaction", is_welcome=is_welcome_message_context, user_turn_length=len(full_user_turn_text) if not is_welcome_message_context else 0)
    except Exception as e:
        log_error(session.client_id, f"AI response generation failed: {e}", is_welcome=is_welcome_message_context, exception_type=type(e).__name__)
        # The AI failed on a real user turn: roll the user message back so
        # history stays consistent with what was actually answered.
        if not is_welcome_message_context and session.conversation_history and session.conversation_history[-1].role == "user":
            session.conversation_history.pop()
        await websocket.send(json.dumps({
            "type": "ERROR", "payload": {"message": "AI failed", "client_id": session.client_id}
        }))
        return
    session.conversation_history.append(ChatMessage(role="assistant", content=ai_response_text))
    tts_duration = estimate_tts_playtime(ai_response_text)
    # KEY for the timeout logic: when AI response playback is expected to end.
    session.ai_response_playback_ends_at = time.time() + tts_duration
    log_ai_response(session.client_id, f"Response sent: '{ai_response_text}'", tts_duration=f"{tts_duration:.2f}s", playback_ends_at=f"{session.ai_response_playback_ends_at:.2f}")
    await websocket.send(json.dumps({
        "type": "AI_RESPONSE",
        "payload": {
            "text": ai_response_text,
            "client_id": session.client_id,
            "estimated_tts_duration": tts_duration
        }
    }))
    # A completed turn supersedes any timeout that was still pending.
    if session.timeout_task and not session.timeout_task.done():
        session.timeout_task.cancel()
        session.timeout_task = None
# NOTE(review): the original annotated `websocket: websockets` — the module,
# not a type. A string annotation avoids a hard dependency on a specific
# websockets version exposing ServerConnection at import time.
async def chat_handler(websocket: "websockets.ServerConnection"):
    """
    Handles new WebSocket connections.

    Extracts the clientId query parameter from the connection path, creates
    or reuses the per-client session, sends a welcome message for brand-new
    sessions, then routes incoming JSON messages until the socket closes.
    """
    path = websocket.request.path
    parsed_path = urllib.parse.urlparse(path)
    query_params = urllib.parse.parse_qs(parsed_path.query)
    raw_client_id_values = query_params.get('clientId')  # None or a list of strings
    client_id: str | None = None
    if raw_client_id_values and raw_client_id_values[0].strip():
        client_id = raw_client_id_values[0].strip()
    if client_id is None:
        # 1008 = policy violation; the clientId query parameter is mandatory.
        log_warning(None, f"Connection from {websocket.remote_address} missing or empty clientId in path: {path}. Closing.")
        await websocket.close(code=1008, reason="clientId parameter is required and cannot be empty.")
        return
    # client_id is guaranteed to be a non-empty string from here on.
    log_info(client_id, f"Connection attempt from {websocket.remote_address}, Path: {path}")
    # --- Session creation and welcome message ---
    is_new_session = False
    if client_id not in sessions:
        log_session(client_id, f"NEW SESSION: Creating session", total_sessions_before=len(sessions))
        sessions[client_id] = SessionData(client_id)
        is_new_session = True
    else:
        # Reconnect path. We assume one active websocket per client_id, so
        # cancel any timeout task left over from the previous connection,
        # refresh last_input_time, and drop the stale AI-playback estimate.
        existing_session = sessions[client_id]
        if existing_session.timeout_task and not existing_session.timeout_task.done():
            log_info(client_id, f"RECONNECT: Cancelling old timeout task from previous connection")
            existing_session.timeout_task.cancel()
            existing_session.timeout_task = None
        existing_session.last_input_time = time.time()
        existing_session.ai_response_playback_ends_at = None
        log_session(client_id, f"EXISTING SESSION: Client reconnected or new connection")
    session = sessions[client_id]  # Get the session (new or existing)
    if is_new_session:
        log_session(client_id, f"NEW SESSION: Sending welcome message")
        # A system prompt could be seeded into the history here if needed.
        await process_complete_turn(websocket, session, "", is_welcome_message_context=True)
        # The welcome message has TTS, so ai_response_playback_ends_at is set.
    # --- Message loop ---
    try:
        async for message_str in websocket:
            try:
                message_data = json.loads(message_str)
                msg_type = message_data.get("type")
                payload = message_data.get("payload")
                if msg_type == "USER_INPUT":
                    # client_id comes from the URL; a payload client_id is
                    # validated against it but never trusted over it.
                    payload_client_id = payload.get("client_id")
                    if payload_client_id and payload_client_id != client_id:
                        log_warning(client_id, f"Mismatch! URL clientId='{client_id}', Payload clientId='{payload_client_id}'. Using URL clientId.")
                    text_input = payload.get("text")
                    if text_input is None:  # Ensure text is present
                        await websocket.send(json.dumps({"type": "ERROR", "payload": {"message": "USER_INPUT missing 'text'", "client_id": client_id}}))
                        continue
                    await handle_user_input(websocket, client_id, text_input)
                else:
                    await websocket.send(json.dumps({"type": "ERROR", "payload": {"message": f"Unknown msg type: {msg_type}", "client_id": client_id}}))
            except json.JSONDecodeError:
                await websocket.send(json.dumps({"type": "ERROR", "payload": {"message": "Invalid JSON", "client_id": client_id}}))
            except Exception as e:
                log_error(client_id, f"Error processing message: {e}")
                import traceback
                traceback.print_exc()
                await websocket.send(json.dumps({"type": "ERROR", "payload": {"message": f"Server error: {str(e)}", "client_id": client_id}}))
    except websockets.exceptions.ConnectionClosedError as e:
        log_error(client_id, f"Connection closed with error: {e.code} {e.reason}")
    except websockets.exceptions.ConnectionClosedOK:
        log_info(client_id, f"Connection closed gracefully")
    except Exception as e:
        log_error(client_id, f"Unexpected error in handler: {e}")
        import traceback
        traceback.print_exc()
    finally:
        log_info(client_id, f"Connection ended. Cleaning up resources.")
        # The SessionData stays in `sessions` so a reconnect can reuse it.
        # Its timeout_task is tied to the session, not this websocket, so we
        # deliberately do NOT cancel it here: the reconnect branch above
        # cancels stale tasks when the next connection arrives. Sessions that
        # never reconnect would need a separate eviction pass (e.g. keyed on
        # last_input_time). NOTE(review): without per-session websocket
        # tracking there is no safe way to cancel from this handler.
        if client_id in sessions:  # Check if session still exists (it should)
            active_session = sessions[client_id]
            log_info(client_id, f"Client disconnected. Session data remains. Next connection will reuse/manage timeout.")
async def main():
    """Start the WebSocket chat server on ws://0.0.0.0:9000."""
    log_info(None, f"Chat server starting with turn detection mode: {TURN_DETECTION_MODEL}")
    # Report which detectors can actually be instantiated.
    available_detectors = TurnDetectorFactory.get_available_detectors()
    log_info(None, "Available turn detectors", available_detectors=available_detectors)
    if TURN_DETECTION_MODEL == "onnx" and ONNX_AVAILABLE:
        log_info(None, f"ONNX threshold: {ONNX_UNLIKELY_THRESHOLD}")
    server = await websockets.serve(chat_handler, "0.0.0.0", 9000)
    log_info(None, "Chat server started (clientId from URL, welcome msg)")
    # Clients connect to ws://<host>:9000/?clientId=<id>.
    await server.wait_closed()
if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -0,0 +1,166 @@
# Turn Detection Package
This package provides multiple turn detection implementations for conversational AI systems. Turn detection determines when a user has finished speaking and it's appropriate for the AI to respond.
## Package Structure
```
turn_detection/
├── __init__.py # Package exports and backward compatibility
├── base.py # Base classes and common data structures
├── factory.py # Factory for creating turn detectors
├── onnx_detector.py # ONNX-based turn detector
├── fastgpt_detector.py # FastGPT API-based turn detector
├── always_true_detector.py # Simple always-true detector for testing
└── README.md # This file
```
## Available Turn Detectors
### 1. ONNXTurnDetector
- **File**: `onnx_detector.py`
- **Description**: Uses a pre-trained ONNX model with Hugging Face tokenizer
- **Use Case**: Production-ready, offline turn detection
- **Dependencies**: `onnxruntime`, `transformers`, `huggingface_hub`
### 2. FastGPTTurnDetector
- **File**: `fastgpt_detector.py`
- **Description**: Uses FastGPT API for turn detection
- **Use Case**: Cloud-based turn detection with API access
- **Dependencies**: `fastgpt_api`
### 3. AlwaysTrueTurnDetector
- **File**: `always_true_detector.py`
- **Description**: Always returns True (considers all turns complete)
- **Use Case**: Testing, debugging, or when turn detection is not needed
- **Dependencies**: None
## Usage
### Basic Usage
```python
from turn_detection import ChatMessage, TurnDetectorFactory
# Create a turn detector using the factory
detector = TurnDetectorFactory.create_turn_detector(
mode="onnx", # "onnx", "fastgpt", or "always_true"
unlikely_threshold=0.005 # For ONNX detector
)
# Prepare chat context
chat_context = [
ChatMessage(role='assistant', content='Hello, how can I help you?'),
ChatMessage(role='user', content='I need help with my order')
]
# Predict if the turn is complete
is_complete = await detector.predict(chat_context, client_id="user123")
print(f"Turn complete: {is_complete}")
# Get probability
probability = await detector.predict_probability(chat_context, client_id="user123")
print(f"Completion probability: {probability}")
```
### Direct Class Usage
```python
from turn_detection import ONNXTurnDetector, FastGPTTurnDetector, AlwaysTrueTurnDetector
# ONNX detector
onnx_detector = ONNXTurnDetector(unlikely_threshold=0.005)
# FastGPT detector
fastgpt_detector = FastGPTTurnDetector(
api_url="http://your-api-url",
api_key="your-api-key",
appId="your-app-id"
)
# Always true detector
always_true_detector = AlwaysTrueTurnDetector()
```
### Factory Configuration
The factory supports different configuration options for each detector type:
```python
# ONNX detector with custom settings
onnx_detector = TurnDetectorFactory.create_turn_detector(
mode="onnx",
unlikely_threshold=0.001,
max_history_tokens=256,
max_history_turns=8
)
# FastGPT detector with custom settings
fastgpt_detector = TurnDetectorFactory.create_turn_detector(
mode="fastgpt",
api_url="http://custom-api-url",
api_key="custom-api-key",
appId="custom-app-id"
)
# Always true detector (no configuration needed)
always_true_detector = TurnDetectorFactory.create_turn_detector(mode="always_true")
```
## Data Structures
### ChatMessage
```python
@dataclass
class ChatMessage:
role: ChatRole # "system", "user", "assistant", "tool"
content: str | list[str] | None = None
```
### ChatRole
```python
ChatRole = Literal["system", "user", "assistant", "tool"]
```
## Base Class Interface
All turn detectors implement the `BaseTurnDetector` interface:
```python
class BaseTurnDetector(ABC):
@abstractmethod
async def predict(self, chat_context: List[ChatMessage], client_id: str = None) -> bool:
"""Predicts whether the current utterance is complete."""
pass
@abstractmethod
async def predict_probability(self, chat_context: List[ChatMessage], client_id: str = None) -> float:
"""Predicts the probability that the current utterance is complete."""
pass
```
## Environment Variables
The following environment variables can be used to configure the detectors:
- `TURN_DETECTION_MODEL`: Turn detection mode ("onnx", "fastgpt", "always_true")
- `ONNX_UNLIKELY_THRESHOLD`: Threshold for ONNX detector (default: 0.0009, as set in `src/main.py`)
- `CHAT_MODEL_API_URL`: FastGPT API URL
- `CHAT_MODEL_API_KEY`: FastGPT API key
- `CHAT_MODEL_APP_ID`: FastGPT app ID
## Backward Compatibility
For backward compatibility, the original `TurnDetector` name still refers to `ONNXTurnDetector`:
```python
from turn_detection import TurnDetector # Same as ONNXTurnDetector
```
## Examples
See the individual detector files for complete usage examples:
- `onnx_detector.py` - ONNX detector example
- `fastgpt_detector.py` - FastGPT detector example
- `always_true_detector.py` - Always true detector example

View File

@@ -0,0 +1,49 @@
"""
Turn Detection Package
This package provides multiple turn detection implementations for conversational AI systems.
"""
from .base import ChatMessage, ChatRole, BaseTurnDetector
# Try to import ONNX detector, but handle import failures gracefully
try:
from .onnx_detector import TurnDetector as ONNXTurnDetector
ONNX_AVAILABLE = True
except ImportError as e:
ONNX_AVAILABLE = False
ONNXTurnDetector = None
_onnx_import_error = str(e)
# Try to import FastGPT detector
try:
from .fastgpt_detector import TurnDetector as FastGPTTurnDetector
FASTGPT_AVAILABLE = True
except ImportError as e:
FASTGPT_AVAILABLE = False
FastGPTTurnDetector = None
_fastgpt_import_error = str(e)
# Always true detector should always be available
from .always_true_detector import AlwaysTrueTurnDetector
from .factory import TurnDetectorFactory
# Export the main classes
__all__ = [
'ChatMessage',
'ChatRole',
'BaseTurnDetector',
'ONNXTurnDetector',
'FastGPTTurnDetector',
'AlwaysTrueTurnDetector',
'TurnDetectorFactory',
'ONNX_AVAILABLE',
'FASTGPT_AVAILABLE'
]
# For backward compatibility, keep the original names
# Only set TurnDetector if ONNX is available
if ONNX_AVAILABLE:
TurnDetector = ONNXTurnDetector
else:
TurnDetector = None

View File

@@ -0,0 +1,26 @@
"""
AlwaysTrueTurnDetector - A simple turn detector that always returns True.
"""
from typing import List
from .base import BaseTurnDetector, ChatMessage
from logger import log_info, log_debug
class AlwaysTrueTurnDetector(BaseTurnDetector):
    """Trivial detector that treats every utterance as a completed turn.

    Useful as a no-op stand-in during testing, or when turn detection
    should be disabled entirely.
    """

    def __init__(self):
        # Announce once, at construction time, that detection is effectively bypassed.
        log_info(None, "AlwaysTrueTurnDetector initialized - all turns will be considered complete")

    async def predict(self, chat_context: List[ChatMessage], client_id: str = None) -> bool:
        """Unconditionally report the turn as complete."""
        context_size = len(chat_context)
        log_debug(client_id, "AlwaysTrueTurnDetector: Turn considered complete",
                  context_length=context_size)
        return True

    async def predict_probability(self, chat_context: List[ChatMessage], client_id: str = None) -> float:
        """Return a fixed completion probability of 1.0."""
        return 1.0

View File

@@ -0,0 +1,55 @@
"""
Base classes and data structures for turn detection.
"""
from typing import Any, Literal, Union, List
from dataclasses import dataclass
from abc import ABC, abstractmethod
# --- Data Structures ---
# Allowed speaker roles for a chat message.
ChatRole = Literal["system", "user", "assistant", "tool"]
@dataclass
class ChatMessage:
    """Represents a single message in a chat conversation."""
    # Who produced the message: "system", "user", "assistant", or "tool".
    role: ChatRole
    # Message payload; a plain string, a list of string parts, or None.
    content: str | list[str] | None = None
# --- Abstract Base Class ---
class BaseTurnDetector(ABC):
    """
    Abstract base class for all turn detectors.
    All turn detectors should inherit from this class and implement
    the required methods.
    Both methods are coroutines; implementations should avoid blocking
    the event loop (e.g. run model inference in an executor).
    """
    @abstractmethod
    async def predict(self, chat_context: List[ChatMessage], client_id: str = None) -> bool:
        """
        Predicts whether the current utterance is complete.
        Args:
            chat_context: A list of ChatMessage objects representing the conversation history.
            client_id: Client identifier for logging purposes.
        Returns:
            True if the utterance is complete, False otherwise.
        """
        pass
    @abstractmethod
    async def predict_probability(self, chat_context: List[ChatMessage], client_id: str = None) -> float:
        """
        Predicts the probability that the current utterance is complete.
        Args:
            chat_context: A list of ChatMessage objects representing the conversation history.
            client_id: Client identifier for logging purposes.
        Returns:
            A float representing the probability that the utterance is complete.
        """
        pass

View File

@@ -0,0 +1,102 @@
"""
Turn Detector Factory
Factory class for creating turn detectors based on configuration.
"""
from .base import BaseTurnDetector
from .always_true_detector import AlwaysTrueTurnDetector
from logger import log_info, log_warning, log_error
# Try to import ONNX detector
try:
from .onnx_detector import TurnDetector as ONNXTurnDetector
ONNX_AVAILABLE = True
except ImportError as e:
ONNX_AVAILABLE = False
ONNXTurnDetector = None
_onnx_import_error = str(e)
# Try to import FastGPT detector
try:
from .fastgpt_detector import TurnDetector as FastGPTTurnDetector
FASTGPT_AVAILABLE = True
except ImportError as e:
FASTGPT_AVAILABLE = False
FastGPTTurnDetector = None
_fastgpt_import_error = str(e)
class TurnDetectorFactory:
    """Factory class to create turn detectors based on configuration."""

    @staticmethod
    def create_turn_detector(mode: str, **kwargs):
        """
        Build a turn detector for the requested mode.
        Args:
            mode: Turn detection mode ("onnx", "fastgpt", "always_true")
            **kwargs: Additional arguments forwarded to the chosen detector
        Returns:
            Turn detector instance
        Raises:
            ImportError: If the requested detector is not available due to missing dependencies
        """
        if mode == "onnx":
            if not ONNX_AVAILABLE:
                error_msg = f"ONNX turn detector is not available. Import error: {_onnx_import_error}"
                log_error(None, error_msg)
                raise ImportError(error_msg)
            # Pop so the remaining kwargs can be forwarded untouched.
            unlikely_threshold = kwargs.pop('unlikely_threshold', 0.005)
            log_info(None, f"Creating ONNX turn detector with threshold {unlikely_threshold}")
            return ONNXTurnDetector(unlikely_threshold=unlikely_threshold, **kwargs)
        if mode == "fastgpt":
            if not FASTGPT_AVAILABLE:
                error_msg = f"FastGPT turn detector is not available. Import error: {_fastgpt_import_error}"
                log_error(None, error_msg)
                raise ImportError(error_msg)
            log_info(None, "Creating FastGPT turn detector")
            return FastGPTTurnDetector(**kwargs)
        if mode == "always_true":
            log_info(None, "Creating AlwaysTrue turn detector")
            return AlwaysTrueTurnDetector()
        # Unknown mode: warn and fall back to the detector that is always available.
        log_warning(None, f"Unknown turn detection mode '{mode}', defaulting to AlwaysTrue")
        log_info(None, "Creating AlwaysTrue turn detector as fallback")
        return AlwaysTrueTurnDetector()

    @staticmethod
    def get_available_detectors():
        """
        Get a list of available turn detector modes.
        Returns:
            dict: Dictionary with detector modes as keys and availability as boolean values
        """
        return dict(onnx=ONNX_AVAILABLE,
                    fastgpt=FASTGPT_AVAILABLE,
                    always_true=True)  # always_true needs no optional dependencies

    @staticmethod
    def get_import_errors():
        """
        Get import error messages for unavailable detectors.
        Returns:
            dict: Dictionary with detector modes as keys and error messages as values
        """
        # Only consult the module-level error strings for detectors that actually
        # failed to import; the variables do not exist otherwise.
        failed = []
        if not ONNX_AVAILABLE:
            failed.append(("onnx", _onnx_import_error))
        if not FASTGPT_AVAILABLE:
            failed.append(("fastgpt", _fastgpt_import_error))
        return dict(failed)

View File

@@ -0,0 +1,163 @@
"""
FastGPT-based Turn Detector
A turn detector implementation using FastGPT API for turn detection.
"""
import time
import asyncio
from typing import List
from .base import BaseTurnDetector, ChatMessage
from fastgpt_api import ChatModel
from logger import log_info, log_debug, log_warning, log_performance
class TurnDetector(BaseTurnDetector):
    """
    A class to detect the end of an utterance (turn) in a conversation
    using FastGPT API for turn detection.
    """
    # --- Class Constants (Default Configuration) ---
    # These can be overridden during instantiation if needed
    MAX_HISTORY_TOKENS: int = 128  # NOTE(review): stored in __init__ but never read by this detector
    MAX_HISTORY_TURNS: int = 6 # Note: This constant wasn't used in the original logic, keeping for completeness
    # SECURITY NOTE(review): live credentials are hard-coded in source. Move them to the
    # CHAT_MODEL_API_URL / CHAT_MODEL_API_KEY / CHAT_MODEL_APP_ID environment variables
    # and rotate this key — anyone with repository access can read it.
    API_URL="http://101.89.151.141:3000/"
    API_KEY="fastgpt-opfE4uKlw6I1EFIY55iWh1dfVPfaQGH2wXvFaCixaZDaZHU1mA61"
    APP_ID="6850f14486197e19f721b80d"
    def __init__(self,
                 max_history_tokens: int = None,
                 max_history_turns: int = None,
                 api_url: str = None,
                 api_key: str = None,
                 appId: str = None):
        """
        Initializes the TurnDetector with FastGPT API configuration.
        Args:
            max_history_tokens: Maximum number of tokens for the input sequence. Defaults to MAX_HISTORY_TOKENS.
            max_history_turns: Maximum number of turns to consider in history. Defaults to MAX_HISTORY_TURNS.
            api_url: API URL for the FastGPT model. Defaults to API_URL.
            api_key: API key for authentication. Defaults to API_KEY.
            appId: Application ID for the FastGPT model. Defaults to APP_ID.
        """
        # Store configuration, using provided args or class defaults.
        # NOTE(review): the `or` pattern makes falsy values (0, "") fall back to defaults.
        self._api_url = api_url or self.API_URL
        self._api_key = api_key or self.API_KEY
        self._appId = appId or self.APP_ID
        self._max_history_tokens = max_history_tokens or self.MAX_HISTORY_TOKENS
        self._max_history_turns = max_history_turns or self.MAX_HISTORY_TURNS
        log_info(None, "FastGPT TurnDetector initialized",
                 api_url=self._api_url,
                 app_id=self._appId)
        # The ChatModel wraps the FastGPT HTTP API (see fastgpt_api.py).
        self._chat_model = ChatModel(
            api_url=self._api_url,
            api_key=self._api_key,
            appId=self._appId
        )
    def _format_chat_ctx(self, chat_context: List[ChatMessage]) -> str:
        """
        Formats the chat context into a string for model input.
        Each message becomes one line, prefixed with "客服:" (agent) for
        assistant messages or "用户:" (user) for user messages. Messages with
        any other role (system/tool) are silently dropped.
        Args:
            chat_context: A list of ChatMessage objects representing the conversation history.
        Returns:
            A string containing the formatted conversation history.
        """
        lst = []
        for message in chat_context:
            if message.role == 'assistant':
                lst.append(f"客服: {message.content}")
            elif message.role == 'user':
                lst.append(f"用户: {message.content}")
        return "\n".join(lst)
    async def predict(self, chat_context: List[ChatMessage], client_id: str = None) -> bool:
        """
        Predicts whether the current utterance is complete using FastGPT API.
        Args:
            chat_context: A list of ChatMessage objects representing the conversation history.
            client_id: Client identifier for logging purposes.
        Returns:
            True if the utterance is complete, False otherwise.
        """
        if not chat_context:
            log_warning(client_id, "Empty chat context provided, returning False")
            return False
        start_time = time.perf_counter()
        # Only the most recent turns are sent to the API.
        text = self._format_chat_ctx(chat_context[-self._max_history_turns:])
        log_debug(client_id, "FastGPT turn detection processing",
                  context_length=len(chat_context),
                  text_length=len(text))
        # Generate a unique chat ID for this prediction (millisecond timestamp).
        chat_id = f"turn_detection_{int(time.time() * 1000)}"
        try:
            # NOTE(review): generate_ai_response lives on ChatModel (fastgpt_api.py) and is
            # presumably a FastGPT app configured to answer "完整" ("complete") — confirm.
            output = await self._chat_model.generate_ai_response(chat_id, text)
            # Any reply other than the exact string "完整" counts as incomplete.
            result = output == '完整'
            end_time = time.perf_counter()
            duration = end_time - start_time
            log_performance(client_id, "FastGPT turn detection completed",
                            duration=f"{duration:.3f}s",
                            output=output,
                            result=result)
            log_debug(client_id, "FastGPT turn detection result",
                      output=output,
                      is_complete=result)
            return result
        except Exception as e:
            end_time = time.perf_counter()
            duration = end_time - start_time
            log_warning(client_id, f"FastGPT turn detection failed: {e}",
                        duration=f"{duration:.3f}s",
                        exception_type=type(e).__name__)
            # Default to True (complete) on error to avoid blocking
            return True
    async def predict_probability(self, chat_context: List[ChatMessage], client_id: str = None) -> float:
        """
        Predicts the probability that the current utterance is complete.
        For FastGPT turn detector, this is a simplified implementation: the
        boolean predict() result is mapped onto the two extreme probabilities.
        Args:
            chat_context: A list of ChatMessage objects representing the conversation history.
            client_id: Client identifier for logging purposes.
        Returns:
            A float representing the probability (1.0 for complete, 0.0 for incomplete).
        """
        is_complete = await self.predict(chat_context, client_id)
        return 1.0 if is_complete else 0.0
async def main():
    """Demo entry point: run the FastGPT detector on a short sample dialogue."""
    sample_dialogue = [
        ChatMessage(role='assistant', content='目前人工坐席繁忙我是12345智能客服。请详细说出您要反映的事项如事件发生的时间、地址、具体的经过以及您期望的解决方案等'),
        ChatMessage(role='user', content='喂,喂'),
        ChatMessage(role='assistant', content='您好,请问有什么可以帮到您?'),
        ChatMessage(role='user', content='嗯,我想问一下,就是我在那个网上买那个迪士尼门票快。过期了,然后找不到。找不到客服退货怎么办'),
    ]
    detector = TurnDetector()
    result = await detector.predict(sample_dialogue, client_id="test_client")
    log_info("test_client", f"FastGPT turn detection result: {result}")
    return result
if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -0,0 +1,376 @@
"""
ONNX-based Turn Detector
A turn detector implementation using a pre-trained ONNX model and Hugging Face tokenizer.
"""
import psutil
import math
import json
import time
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer
import asyncio
from typing import List
from .base import BaseTurnDetector, ChatMessage
from logger import log_model, log_predict, log_performance, log_warning
class TurnDetector(BaseTurnDetector):
    """
    A class to detect the end of an utterance (turn) in a conversation
    using a pre-trained ONNX model and Hugging Face tokenizer.
    """
    # --- Class Constants (Default Configuration) ---
    # These can be overridden during instantiation if needed
    HG_MODEL: str = "livekit/turn-detector"
    ONNX_FILENAME: str = "model_q8.onnx"
    MODEL_REVISION: str = "v0.2.0-intl"
    MAX_HISTORY_TOKENS: int = 128
    MAX_HISTORY_TURNS: int = 6
    INFERENCE_METHOD: str = "lk_end_of_utterance_multilingual"
    UNLIKELY_THRESHOLD: float = 0.005

    def __init__(self,
                 max_history_tokens: int = None,
                 max_history_turns: int = None,
                 hg_model: str = None,
                 onnx_filename: str = None,
                 model_revision: str = None,
                 inference_method: str = None,
                 unlikely_threshold: float = None):
        """
        Initializes the TurnDetector by downloading and loading the necessary
        model files, tokenizer, and configuration.
        Args:
            max_history_tokens: Maximum number of tokens for the input sequence. Defaults to MAX_HISTORY_TOKENS.
            max_history_turns: Maximum number of turns to consider in history. Defaults to MAX_HISTORY_TURNS.
            hg_model: Hugging Face model identifier. Defaults to HG_MODEL.
            onnx_filename: ONNX model filename. Defaults to ONNX_FILENAME.
            model_revision: Model revision/tag. Defaults to MODEL_REVISION.
            inference_method: Inference method name. Defaults to INFERENCE_METHOD.
            unlikely_threshold: Threshold for determining if utterance is complete. Defaults to UNLIKELY_THRESHOLD.
        """
        # Store configuration; explicit `is None` checks so legitimate falsy
        # values (e.g. unlikely_threshold=0.0) are honored instead of silently
        # replaced by the class defaults.
        self._max_history_tokens = self.MAX_HISTORY_TOKENS if max_history_tokens is None else max_history_tokens
        self._max_history_turns = self.MAX_HISTORY_TURNS if max_history_turns is None else max_history_turns
        self._hg_model = hg_model if hg_model is not None else self.HG_MODEL
        self._onnx_filename = onnx_filename if onnx_filename is not None else self.ONNX_FILENAME
        self._model_revision = model_revision if model_revision is not None else self.MODEL_REVISION
        # NOTE(review): _inference_method is stored but not used by this class — confirm intent.
        self._inference_method = inference_method if inference_method is not None else self.INFERENCE_METHOD
        # Initialize model components (populated by _load_model_components)
        self._languages = None
        self._session = None
        self._tokenizer = None
        self._unlikely_threshold = self.UNLIKELY_THRESHOLD if unlikely_threshold is None else unlikely_threshold
        log_model(None, "Initializing TurnDetector",
                  model=self._hg_model,
                  revision=self._model_revision,
                  threshold=self._unlikely_threshold)
        # Load model components (blocking: downloads files on first use)
        self._load_model_components()

    async def _download_from_hf_hub_async(self, repo_id: str, filename: str, **kwargs) -> str:
        """
        Downloads a file from Hugging Face Hub asynchronously.
        Args:
            repo_id: Repository ID on Hugging Face Hub.
            filename: Name of the file to download.
            **kwargs: Additional arguments for hf_hub_download.
        Returns:
            Local path to the downloaded file.
        """
        # Run the synchronous download in a thread pool so the event loop stays free.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
        )

    def _download_from_hf_hub(self, repo_id: str, filename: str, **kwargs) -> str:
        """
        Downloads a file from Hugging Face Hub (synchronous version).
        Args:
            repo_id: Repository ID on Hugging Face Hub.
            filename: Name of the file to download.
            **kwargs: Additional arguments for hf_hub_download.
        Returns:
            Local path to the downloaded file.
        """
        return hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)

    async def _load_model_components_async(self):
        """
        Loads and initializes the model, tokenizer, and configuration asynchronously.
        Delegates to the synchronous loader in a worker thread so downloads,
        session creation, and tokenizer setup do not block the event loop.
        (Previously this method duplicated the synchronous loader line-for-line.)
        """
        log_model(None, "Loading model components asynchronously")
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._load_model_components)

    def _load_model_components(self):
        """Loads and initializes the model, tokenizer, and configuration."""
        log_model(None, "Loading model components")
        # Load languages configuration
        config_fname = self._download_from_hf_hub(
            self._hg_model,
            "languages.json",
            revision=self._model_revision,
            local_files_only=False
        )
        with open(config_fname) as f:
            self._languages = json.load(f)
        log_model(None, "Languages configuration loaded", languages_count=len(self._languages))
        # Load ONNX model
        local_path_onnx = self._download_from_hf_hub(
            self._hg_model,
            self._onnx_filename,
            subfolder="onnx",
            revision=self._model_revision,
            local_files_only=False,
        )
        # Configure ONNX session: use half the cores for intra-op parallelism.
        # cpu_count() may return None in restricted environments, so guard it.
        sess_options = ort.SessionOptions()
        sess_options.intra_op_num_threads = max(1, (psutil.cpu_count() or 2) // 2)
        sess_options.inter_op_num_threads = 1
        sess_options.add_session_config_entry("session.dynamic_block_base", "4")
        self._session = ort.InferenceSession(
            local_path_onnx, providers=["CPUExecutionProvider"], sess_options=sess_options
        )
        # Load tokenizer; truncation_side="left" drops the OLDEST tokens when
        # the history exceeds the token budget, keeping the most recent turns.
        self._tokenizer = AutoTokenizer.from_pretrained(
            self._hg_model,
            revision=self._model_revision,
            local_files_only=False,
            truncation_side="left",
        )
        log_model(None, "Model components loaded successfully",
                  onnx_path=local_path_onnx,
                  intra_threads=sess_options.intra_op_num_threads)

    def _format_chat_ctx(self, chat_context: List[ChatMessage]) -> str:
        """
        Formats the chat context into a string for model input.
        Args:
            chat_context: A list of ChatMessage objects representing the conversation history.
        Returns:
            A string containing the formatted conversation history.
        """
        convo_text = self._tokenizer.apply_chat_template(
            list(chat_context),
            add_generation_prompt=False,
            add_special_tokens=False,
            tokenize=False,
        )
        # Remove the trailing EOU token so the model scores the still-open
        # current utterance instead of a formally closed one.
        ix = convo_text.rfind("<|im_end|>")
        return convo_text[:ix]

    async def _infer_eou_probability(self, text: str):
        """
        Tokenize *text* and run the ONNX session, both off the event loop.
        Shared by predict() and predict_probability() (previously duplicated).
        Args:
            text: Formatted conversation text.
        Returns:
            Tuple of (eou_probability, input_token_count).
        """
        loop = asyncio.get_running_loop()
        # Run tokenization in thread pool to avoid blocking
        inputs = await loop.run_in_executor(
            None,
            lambda: self._tokenizer(
                text,
                add_special_tokens=False,
                return_tensors="np",
                max_length=self._max_history_tokens,
                truncation=True,
            )
        )
        # Run inference in thread pool
        outputs = await loop.run_in_executor(
            None,
            lambda: self._session.run(
                None, {"input_ids": inputs["input_ids"].astype("int64")}
            )
        )
        # The model emits one score per position; the last one scores the
        # end of the current utterance.
        eou_probability = outputs[0].flatten()[-1]
        return eou_probability, inputs["input_ids"].shape[1]

    async def predict(self, chat_context: List[ChatMessage], client_id: str = None) -> bool:
        """
        Predicts whether the current utterance is complete.
        Args:
            chat_context: A list of ChatMessage objects representing the conversation history.
            client_id: Client identifier for logging purposes.
        Returns:
            True if the utterance is complete, False otherwise.
        """
        if not chat_context:
            log_warning(client_id, "Empty chat context provided, returning False")
            return False
        start_time = time.perf_counter()
        text = self._format_chat_ctx(chat_context[-self._max_history_turns:])
        log_predict(client_id, "Processing turn detection",
                    context_length=len(chat_context),
                    text_length=len(text))
        eou_probability, input_tokens = await self._infer_eou_probability(text)
        duration = time.perf_counter() - start_time
        log_predict(client_id, "Turn detection completed",
                    probability=f"{eou_probability:.6f}",
                    threshold=self._unlikely_threshold,
                    is_complete=eou_probability > self._unlikely_threshold)
        log_performance(client_id, "Prediction performance",
                        duration=f"{duration:.3f}s",
                        input_tokens=input_tokens)
        # bool() so callers get a Python bool, not a numpy.bool_.
        return bool(eou_probability > self._unlikely_threshold)

    async def predict_probability(self, chat_context: List[ChatMessage], client_id: str = None) -> float:
        """
        Predicts the probability that the current utterance is complete.
        Args:
            chat_context: A list of ChatMessage objects representing the conversation history.
            client_id: Client identifier for logging purposes.
        Returns:
            A float representing the probability that the utterance is complete.
        """
        if not chat_context:
            log_warning(client_id, "Empty chat context provided, returning 0.0 probability")
            return 0.0
        start_time = time.perf_counter()
        text = self._format_chat_ctx(chat_context[-self._max_history_turns:])
        log_predict(client_id, "Processing probability prediction",
                    context_length=len(chat_context),
                    text_length=len(text))
        eou_probability, input_tokens = await self._infer_eou_probability(text)
        duration = time.perf_counter() - start_time
        log_predict(client_id, "Probability prediction completed",
                    probability=f"{eou_probability:.6f}")
        log_performance(client_id, "Prediction performance",
                        duration=f"{duration:.3f}s",
                        input_tokens=input_tokens)
        return float(eou_probability)
async def main():
    """Demo entry point: run the ONNX detector on a short, unfinished user utterance."""
    from logger import log_info
    sample_dialogue = [
        ChatMessage(role='assistant', content='您好,请问有什么可以帮到您?'),
        ChatMessage(role='user', content='我想')
    ]
    detector = TurnDetector()
    result = await detector.predict(sample_dialogue, client_id="test_client")
    log_info("test_client", f"Final prediction result: {result}")
    # Also test the probability method
    probability = await detector.predict_probability(sample_dialogue, client_id="test_client")
    log_info("test_client", f"Probability result: {probability}")
    return result
if __name__ == "__main__":
    asyncio.run(main())