Add backend api and engine
This commit is contained in:
255
engine/core/conversation.py
Normal file
255
engine/core/conversation.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""Conversation management for voice AI.
|
||||
|
||||
Handles conversation context, turn-taking, and message history
|
||||
for multi-turn voice conversations.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import List, Optional, Dict, Any, Callable, Awaitable
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from loguru import logger
|
||||
|
||||
from services.base import LLMMessage
|
||||
|
||||
|
||||
class ConversationState(Enum):
|
||||
"""State of the conversation."""
|
||||
IDLE = "idle" # Waiting for user input
|
||||
LISTENING = "listening" # User is speaking
|
||||
PROCESSING = "processing" # Processing user input (LLM)
|
||||
SPEAKING = "speaking" # Bot is speaking
|
||||
INTERRUPTED = "interrupted" # Bot was interrupted
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConversationTurn:
|
||||
"""A single turn in the conversation."""
|
||||
role: str # "user" or "assistant"
|
||||
text: str
|
||||
audio_duration_ms: Optional[int] = None
|
||||
timestamp: float = field(default_factory=lambda: asyncio.get_event_loop().time())
|
||||
was_interrupted: bool = False
|
||||
|
||||
|
||||
class ConversationManager:
|
||||
"""
|
||||
Manages conversation state and history.
|
||||
|
||||
Provides:
|
||||
- Message history for LLM context
|
||||
- Turn management
|
||||
- State tracking
|
||||
- Event callbacks for state changes
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
system_prompt: Optional[str] = None,
|
||||
max_history: int = 20,
|
||||
greeting: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Initialize conversation manager.
|
||||
|
||||
Args:
|
||||
system_prompt: System prompt for LLM
|
||||
max_history: Maximum number of turns to keep
|
||||
greeting: Optional greeting message when conversation starts
|
||||
"""
|
||||
self.system_prompt = system_prompt or (
|
||||
"You are a helpful, friendly voice assistant. "
|
||||
"Keep your responses concise and conversational. "
|
||||
"Respond naturally as if having a phone conversation. "
|
||||
"If you don't understand something, ask for clarification."
|
||||
)
|
||||
self.max_history = max_history
|
||||
self.greeting = greeting
|
||||
|
||||
# State
|
||||
self.state = ConversationState.IDLE
|
||||
self.turns: List[ConversationTurn] = []
|
||||
|
||||
# Callbacks
|
||||
self._state_callbacks: List[Callable[[ConversationState, ConversationState], Awaitable[None]]] = []
|
||||
self._turn_callbacks: List[Callable[[ConversationTurn], Awaitable[None]]] = []
|
||||
|
||||
# Current turn tracking
|
||||
self._current_user_text: str = ""
|
||||
self._current_assistant_text: str = ""
|
||||
|
||||
logger.info("ConversationManager initialized")
|
||||
|
||||
def on_state_change(
|
||||
self,
|
||||
callback: Callable[[ConversationState, ConversationState], Awaitable[None]]
|
||||
) -> None:
|
||||
"""Register callback for state changes."""
|
||||
self._state_callbacks.append(callback)
|
||||
|
||||
def on_turn_complete(
|
||||
self,
|
||||
callback: Callable[[ConversationTurn], Awaitable[None]]
|
||||
) -> None:
|
||||
"""Register callback for turn completion."""
|
||||
self._turn_callbacks.append(callback)
|
||||
|
||||
async def set_state(self, new_state: ConversationState) -> None:
|
||||
"""Set conversation state and notify listeners."""
|
||||
if new_state != self.state:
|
||||
old_state = self.state
|
||||
self.state = new_state
|
||||
logger.debug(f"Conversation state: {old_state.value} -> {new_state.value}")
|
||||
|
||||
for callback in self._state_callbacks:
|
||||
try:
|
||||
await callback(old_state, new_state)
|
||||
except Exception as e:
|
||||
logger.error(f"State callback error: {e}")
|
||||
|
||||
def get_messages(self) -> List[LLMMessage]:
|
||||
"""
|
||||
Get conversation history as LLM messages.
|
||||
|
||||
Returns:
|
||||
List of LLMMessage objects including system prompt
|
||||
"""
|
||||
messages = [LLMMessage(role="system", content=self.system_prompt)]
|
||||
|
||||
# Add conversation history
|
||||
for turn in self.turns[-self.max_history:]:
|
||||
messages.append(LLMMessage(role=turn.role, content=turn.text))
|
||||
|
||||
# Add current user text if any
|
||||
if self._current_user_text:
|
||||
messages.append(LLMMessage(role="user", content=self._current_user_text))
|
||||
|
||||
return messages
|
||||
|
||||
async def start_user_turn(self) -> None:
|
||||
"""Signal that user has started speaking."""
|
||||
await self.set_state(ConversationState.LISTENING)
|
||||
self._current_user_text = ""
|
||||
|
||||
async def update_user_text(self, text: str, is_final: bool = False) -> None:
|
||||
"""
|
||||
Update current user text (from ASR).
|
||||
|
||||
Args:
|
||||
text: Transcribed text
|
||||
is_final: Whether this is the final transcript
|
||||
"""
|
||||
self._current_user_text = text
|
||||
|
||||
async def end_user_turn(self, text: str) -> None:
|
||||
"""
|
||||
End user turn and add to history.
|
||||
|
||||
Args:
|
||||
text: Final user text
|
||||
"""
|
||||
if text.strip():
|
||||
turn = ConversationTurn(role="user", text=text.strip())
|
||||
self.turns.append(turn)
|
||||
|
||||
for callback in self._turn_callbacks:
|
||||
try:
|
||||
await callback(turn)
|
||||
except Exception as e:
|
||||
logger.error(f"Turn callback error: {e}")
|
||||
|
||||
logger.info(f"User: {text[:50]}...")
|
||||
|
||||
self._current_user_text = ""
|
||||
await self.set_state(ConversationState.PROCESSING)
|
||||
|
||||
async def start_assistant_turn(self) -> None:
|
||||
"""Signal that assistant has started speaking."""
|
||||
await self.set_state(ConversationState.SPEAKING)
|
||||
self._current_assistant_text = ""
|
||||
|
||||
async def update_assistant_text(self, text: str) -> None:
|
||||
"""
|
||||
Update current assistant text (streaming).
|
||||
|
||||
Args:
|
||||
text: Text chunk from LLM
|
||||
"""
|
||||
self._current_assistant_text += text
|
||||
|
||||
async def end_assistant_turn(self, was_interrupted: bool = False) -> None:
|
||||
"""
|
||||
End assistant turn and add to history.
|
||||
|
||||
Args:
|
||||
was_interrupted: Whether the turn was interrupted by user
|
||||
"""
|
||||
text = self._current_assistant_text.strip()
|
||||
if text:
|
||||
turn = ConversationTurn(
|
||||
role="assistant",
|
||||
text=text,
|
||||
was_interrupted=was_interrupted
|
||||
)
|
||||
self.turns.append(turn)
|
||||
|
||||
for callback in self._turn_callbacks:
|
||||
try:
|
||||
await callback(turn)
|
||||
except Exception as e:
|
||||
logger.error(f"Turn callback error: {e}")
|
||||
|
||||
status = " (interrupted)" if was_interrupted else ""
|
||||
logger.info(f"Assistant{status}: {text[:50]}...")
|
||||
|
||||
self._current_assistant_text = ""
|
||||
|
||||
if was_interrupted:
|
||||
await self.set_state(ConversationState.INTERRUPTED)
|
||||
else:
|
||||
await self.set_state(ConversationState.IDLE)
|
||||
|
||||
async def interrupt(self) -> None:
|
||||
"""Handle interruption (barge-in)."""
|
||||
if self.state == ConversationState.SPEAKING:
|
||||
await self.end_assistant_turn(was_interrupted=True)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset conversation history."""
|
||||
self.turns = []
|
||||
self._current_user_text = ""
|
||||
self._current_assistant_text = ""
|
||||
self.state = ConversationState.IDLE
|
||||
logger.info("Conversation reset")
|
||||
|
||||
@property
|
||||
def turn_count(self) -> int:
|
||||
"""Get number of turns in conversation."""
|
||||
return len(self.turns)
|
||||
|
||||
@property
|
||||
def last_user_text(self) -> Optional[str]:
|
||||
"""Get last user text."""
|
||||
for turn in reversed(self.turns):
|
||||
if turn.role == "user":
|
||||
return turn.text
|
||||
return None
|
||||
|
||||
@property
|
||||
def last_assistant_text(self) -> Optional[str]:
|
||||
"""Get last assistant text."""
|
||||
for turn in reversed(self.turns):
|
||||
if turn.role == "assistant":
|
||||
return turn.text
|
||||
return None
|
||||
|
||||
def get_context_summary(self) -> Dict[str, Any]:
|
||||
"""Get a summary of conversation context."""
|
||||
return {
|
||||
"state": self.state.value,
|
||||
"turn_count": self.turn_count,
|
||||
"last_user": self.last_user_text,
|
||||
"last_assistant": self.last_assistant_text,
|
||||
"current_user": self._current_user_text or None,
|
||||
"current_assistant": self._current_assistant_text or None
|
||||
}
|
||||
Reference in New Issue
Block a user