256 lines
8.4 KiB
Python
256 lines
8.4 KiB
Python
"""Conversation management for voice AI.
|
|
|
|
Handles conversation context, turn-taking, and message history
|
|
for multi-turn voice conversations.
|
|
"""
|
|
|
|
import asyncio
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional, Dict, Any, Callable, Awaitable

from loguru import logger

from services.base import LLMMessage
|
|
|
|
|
|
class ConversationState(Enum):
    """Lifecycle phase of a voice conversation.

    ``ConversationManager`` starts in IDLE and moves between these phases as
    user and assistant turns begin and end. Each member's value is its name
    in lowercase.
    """

    # Nothing in progress; waiting for the user to speak.
    IDLE = "idle"
    # The user is currently speaking.
    LISTENING = "listening"
    # The user's input is being processed (LLM).
    PROCESSING = "processing"
    # The bot is currently speaking.
    SPEAKING = "speaking"
    # The bot's speech was cut off by the user.
    INTERRUPTED = "interrupted"
|
|
|
|
|
|
@dataclass
class ConversationTurn:
    """A single turn in the conversation.

    Attributes:
        role: Who produced the turn; ``ConversationManager`` records
            ``"user"`` or ``"assistant"``.
        text: The turn's text (user transcript or assistant reply).
        audio_duration_ms: Optional length of the spoken audio in
            milliseconds. ``ConversationManager`` in this module never sets it.
        timestamp: Creation time in seconds from ``time.monotonic()``. Only
            meaningful relative to other timestamps taken in the same process;
            it is not wall-clock time.
        was_interrupted: True when an assistant turn was cut short by the user.
    """

    role: str  # "user" or "assistant"
    text: str
    audio_duration_ms: Optional[int] = None
    # time.monotonic() is the clock asyncio's built-in event loops use for
    # loop.time(), but reading it directly needs no event loop. The previous
    # asyncio.get_event_loop().time() is deprecated outside a running loop
    # (warns / implicitly creates a loop, and errors on newer Pythons), which
    # is exactly where turns are often constructed.
    timestamp: float = field(default_factory=time.monotonic)
    was_interrupted: bool = False
|
|
|
|
|
|
class ConversationManager:
    """
    Manages conversation state and history.

    Provides:
    - Message history for LLM context
    - Turn management
    - State tracking
    - Event callbacks for state changes and completed turns

    Every completed turn is appended to ``turns`` and retained until
    ``reset()`` is called; ``max_history`` only limits how many of the most
    recent turns ``get_messages()`` forwards to the LLM.
    """

    def __init__(
        self,
        system_prompt: Optional[str] = None,
        max_history: int = 20,
        greeting: Optional[str] = None
    ):
        """
        Initialize conversation manager.

        Args:
            system_prompt: System prompt for the LLM. Any falsy value (None or
                "") selects the built-in voice-assistant prompt below.
            max_history: Maximum number of most recent turns that
                ``get_messages()`` includes in the LLM context. Values <= 0
                include no history.
            greeting: Optional greeting message. Stored as ``self.greeting``;
                this class does not use it itself.
        """
        self.system_prompt = system_prompt or (
            "You are a helpful, friendly voice assistant. "
            "Keep your responses concise and conversational. "
            "Respond naturally as if having a phone conversation. "
            "If you don't understand something, ask for clarification."
        )
        self.max_history = max_history
        self.greeting = greeting

        # State
        self.state = ConversationState.IDLE
        self.turns: List[ConversationTurn] = []

        # Callbacks, awaited sequentially in registration order
        self._state_callbacks: List[Callable[[ConversationState, ConversationState], Awaitable[None]]] = []
        self._turn_callbacks: List[Callable[[ConversationTurn], Awaitable[None]]] = []

        # In-progress text that has not yet been committed as a turn
        self._current_user_text: str = ""
        self._current_assistant_text: str = ""

        logger.info("ConversationManager initialized")

    @staticmethod
    def _preview(text: str, limit: int = 50) -> str:
        """Return ``text`` for logging, truncated with "..." only if longer than ``limit``."""
        return text if len(text) <= limit else f"{text[:limit]}..."

    async def _notify_turn(self, turn: ConversationTurn) -> None:
        """Await every turn-complete callback with ``turn``; log (don't raise) failures."""
        for callback in self._turn_callbacks:
            try:
                await callback(turn)
            except Exception as e:
                # A faulty listener must not break the conversation flow;
                # logger.exception keeps the traceback for debugging.
                logger.exception(f"Turn callback error: {e}")

    def on_state_change(
        self,
        callback: Callable[[ConversationState, ConversationState], Awaitable[None]]
    ) -> None:
        """Register an async callback invoked as ``callback(old_state, new_state)``."""
        self._state_callbacks.append(callback)

    def on_turn_complete(
        self,
        callback: Callable[[ConversationTurn], Awaitable[None]]
    ) -> None:
        """Register an async callback invoked with each turn added to history."""
        self._turn_callbacks.append(callback)

    async def set_state(self, new_state: ConversationState) -> None:
        """
        Set conversation state and notify listeners.

        Does nothing (and notifies no one) when ``new_state`` equals the
        current state. Callbacks are awaited one after another; an exception
        from one is logged and the remaining callbacks still run.
        """
        if new_state == self.state:
            return

        old_state = self.state
        self.state = new_state
        logger.debug(f"Conversation state: {old_state.value} -> {new_state.value}")

        for callback in self._state_callbacks:
            try:
                await callback(old_state, new_state)
            except Exception as e:
                logger.exception(f"State callback error: {e}")

    def get_messages(self) -> List[LLMMessage]:
        """
        Get conversation history as LLM messages.

        The list is: the system prompt, then at most ``max_history`` of the
        most recent completed turns, then the in-progress user transcript
        (if non-empty).

        Returns:
            List of LLMMessage objects including system prompt
        """
        messages = [LLMMessage(role="system", content=self.system_prompt)]

        # Guard max_history <= 0: ``turns[-0:]`` is the WHOLE list, so a limit
        # of 0 would otherwise send the entire history (and a negative limit
        # would drop the oldest turns instead of the newest).
        recent = self.turns[-self.max_history:] if self.max_history > 0 else []
        messages.extend(LLMMessage(role=turn.role, content=turn.text) for turn in recent)

        # Include a partial (not yet committed) user utterance, if any
        if self._current_user_text:
            messages.append(LLMMessage(role="user", content=self._current_user_text))

        return messages

    async def start_user_turn(self) -> None:
        """Signal that user has started speaking: enter LISTENING and clear partial user text."""
        await self.set_state(ConversationState.LISTENING)
        self._current_user_text = ""

    async def update_user_text(self, text: str, is_final: bool = False) -> None:
        """
        Update current user text (from ASR).

        The text replaces (does not append to) the previous partial transcript.

        Args:
            text: Transcribed text
            is_final: Whether this is the final transcript. Currently unused:
                the text is stored the same way either way.
        """
        self._current_user_text = text

    async def end_user_turn(self, text: str) -> None:
        """
        End user turn and add to history.

        Whitespace-only ``text`` is not recorded and fires no turn callbacks,
        but the state still moves to PROCESSING.

        Args:
            text: Final user text; stored with surrounding whitespace stripped.
        """
        cleaned = text.strip()
        if cleaned:
            turn = ConversationTurn(role="user", text=cleaned)
            self.turns.append(turn)
            await self._notify_turn(turn)
            logger.info(f"User: {self._preview(cleaned)}")

        self._current_user_text = ""
        await self.set_state(ConversationState.PROCESSING)

    async def start_assistant_turn(self) -> None:
        """Signal that assistant has started speaking: enter SPEAKING and clear buffered reply."""
        await self.set_state(ConversationState.SPEAKING)
        self._current_assistant_text = ""

    async def update_assistant_text(self, text: str) -> None:
        """
        Update current assistant text (streaming).

        Args:
            text: Text chunk from LLM; appended to the buffered reply.
        """
        self._current_assistant_text += text

    async def end_assistant_turn(self, was_interrupted: bool = False) -> None:
        """
        End assistant turn and add to history.

        The buffered reply (stripped) is recorded only if non-empty. The state
        becomes INTERRUPTED when ``was_interrupted`` is true, otherwise IDLE.

        Args:
            was_interrupted: Whether the turn was interrupted by user
        """
        text = self._current_assistant_text.strip()
        if text:
            turn = ConversationTurn(
                role="assistant",
                text=text,
                was_interrupted=was_interrupted
            )
            self.turns.append(turn)
            await self._notify_turn(turn)

            status = " (interrupted)" if was_interrupted else ""
            logger.info(f"Assistant{status}: {self._preview(text)}")

        self._current_assistant_text = ""

        if was_interrupted:
            await self.set_state(ConversationState.INTERRUPTED)
        else:
            await self.set_state(ConversationState.IDLE)

    async def interrupt(self) -> None:
        """
        Handle interruption (barge-in).

        Only acts while SPEAKING; in any other state (e.g. PROCESSING) this
        is a no-op.
        """
        if self.state == ConversationState.SPEAKING:
            await self.end_assistant_turn(was_interrupted=True)

    def reset(self) -> None:
        """
        Reset conversation history.

        Clears turns and in-progress text and sets the state to IDLE.
        Registered callbacks are kept. Note: the state is assigned directly,
        so state-change callbacks are NOT notified.
        """
        self.turns = []
        self._current_user_text = ""
        self._current_assistant_text = ""
        self.state = ConversationState.IDLE
        logger.info("Conversation reset")

    @property
    def turn_count(self) -> int:
        """Get number of turns recorded since creation or the last reset."""
        return len(self.turns)

    @property
    def last_user_text(self) -> Optional[str]:
        """Get the text of the most recent user turn, or None if there is none."""
        return next((turn.text for turn in reversed(self.turns) if turn.role == "user"), None)

    @property
    def last_assistant_text(self) -> Optional[str]:
        """Get the text of the most recent assistant turn, or None if there is none."""
        return next((turn.text for turn in reversed(self.turns) if turn.role == "assistant"), None)

    def get_context_summary(self) -> Dict[str, Any]:
        """Get a summary of conversation context; empty in-progress text is reported as None."""
        return {
            "state": self.state.value,
            "turn_count": self.turn_count,
            "last_user": self.last_user_text,
            "last_assistant": self.last_assistant_text,
            "current_user": self._current_user_text or None,
            "current_assistant": self._current_assistant_text or None
        }
|