Files
pipecat/src/dailyai/async_processor/async_processor.py
Moishe Lettvin 755059c358 a little cleanup
2024-01-16 19:58:11 -05:00

348 lines
11 KiB
Python

import json
import logging
import re
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
from queue import Queue, PriorityQueue, Empty
from threading import Event, Semaphore, Thread
from typing import Any, Generator, Iterator, Optional, Type
from dailyai.queue_frame import QueueFrame, FrameType
from dailyai.message_handler.message_handler import MessageHandler
from dailyai.services.ai_services import AIServiceConfig
class AsyncProcessorState:
    """Integer lifecycle states of a processor plus the legal moves between them.

    The states are plain ints numbered in declaration order, so they can be
    compared with ``<=`` / ``>=`` to ask "has the processor got at least this
    far?".  ``state_transitions`` maps each state to the list of states that
    may directly follow it.
    """

    (
        # Setting class variables, other synchronous activities
        INIT,
        # Making asynchronous requests to LLM and other services to render response
        PREPARING,
        # Ready to start presenting to user (but may not have all data yet)
        READY,
        # Playing response
        PLAYING,
        # An interrupt has been requested and the response is shutting down
        # in-flight processing
        INTERRUPTING,
        # An interrupt has been requested and the response is finished stopping
        # in-flight processing
        INTERRUPTED,
        # Response has been played or interrupted
        DONE,
        # Response is being finalized (updating records of speech, updating LLM
        # context, etc.)
        FINALIZING,
        # Response is complete. This could mean that everything is updated, or
        # that the response was interrupted.
        FINALIZED,
    ) = range(9)

    # current state -> states that may be entered next
    state_transitions = {
        INIT: [PREPARING, INTERRUPTING],
        PREPARING: [READY, INTERRUPTING],
        READY: [PLAYING, INTERRUPTING],
        PLAYING: [DONE, INTERRUPTING],
        INTERRUPTING: [INTERRUPTED],
        INTERRUPTED: [DONE],
        DONE: [FINALIZING],
        FINALIZING: [FINALIZED],
        FINALIZED: [FINALIZED],
    }
@dataclass(eq=True, order=True)
class StateTransitionItem:
    """One pending wait: the state being waited for and the event to signal.

    Instances are ordered and compared by ``state`` alone, which lets them be
    stored in a priority queue keyed on the awaited state.
    """

    # State value the waiter wants the processor to reach; the only field that
    # takes part in comparison and ordering.
    state: int
    # Event the waiter blocks on; excluded from comparisons because Event
    # objects are not orderable.
    evt: Event = field(compare=False)
class AsyncProcessor:
    """Base class for a response that is prepared, played and finalized on threads.

    The processor moves through the states declared in ``AsyncProcessorState``;
    only the moves listed in ``AsyncProcessorState.state_transitions`` are
    allowed.  Subclasses supply the real work by overriding the hook methods at
    the bottom of the class (``start_preparation``, ``continue_preparation``,
    ``preparation_done``, ``do_play``, ``do_interruption``, ``do_finalization``),
    all of which are no-ops here.
    """

    def __init__(self, services: AIServiceConfig) -> None:
        """Create a processor in the ``INIT`` state.

        Args:
            services: Service configuration, stored as ``self.services`` for
                use by subclasses.
        """
        self.state = AsyncProcessorState.INIT
        self.prepare_thread = None
        self.play_thread = None
        self.finalize_thread = None

        self.services: AIServiceConfig = services

        # Binary semaphore used as a mutex around state changes and waiter
        # registration.  It is NOT re-entrant: never call set_state(),
        # maybe_set_state() or wait_for_state_transition() while holding it.
        self.state_transition_semaphore = Semaphore()
        # StateTransitionItem entries, ordered by the state they wait for.
        self.waiting_for_state_changes = PriorityQueue()
        self.state_queue = Queue()
        # state -> list of callables invoked (with this processor) on entry.
        self.state_change_callbacks = defaultdict(list)
        self.was_interrupted = False

        self.logger: logging.Logger = logging.getLogger("dailyai")

    def _try_transition(self, state: int) -> bool:
        """Atomically move to ``state`` if that move is currently legal.

        Validation, the state update and the waking of waiters all happen while
        holding the state semaphore, so a concurrent transition cannot slip in
        between the check and the update, and a waiter registering concurrently
        cannot miss its wake-up.

        Returns:
            True if the transition was performed, False if it is not allowed
            from the current state (in which case nothing changes).
        """
        with self.state_transition_semaphore:
            if state not in AsyncProcessorState.state_transitions[self.state]:
                return False
            self.state: int = state

            # Wake up any threads waiting for this state (or an earlier one).
            # Waiters come out in ascending order of awaited state, so the
            # first one that is still in the future ends the scan.
            try:
                while True:
                    waiter = self.waiting_for_state_changes.get_nowait()
                    if waiter.state <= state:
                        waiter.evt.set()
                    else:
                        self.waiting_for_state_changes.put(waiter)
                        break
            except Empty:
                pass

        # Callbacks run outside the lock so that they are free to trigger
        # further state transitions themselves.
        for callback in self.state_change_callbacks[state]:
            callback(self)
        return True

    def set_state(self, state: int) -> None:
        """Move to ``state``, waking waiters and running registered callbacks.

        Raises:
            Exception: if ``state`` is not a legal successor of the current
                state.  The failure is also logged at error level.
        """
        if not self._try_transition(state):
            self.logger.error(
                f"Invalid state transition from {self.state} to {state} in {self.__class__.__name__}"
            )
            raise Exception(f"Invalid state transition from {self.state} to {state}")

    def maybe_set_state(self, state: int) -> bool:
        """Move to ``state`` if that is legal right now; otherwise do nothing.

        This is used for state transitions that could be blocked by an
        interruption.  If we are interrupted, we silently fail this call.  Use
        only if you know that this state transition should fail if the
        processor has been interrupted.

        Returns:
            True if the transition happened, False if it was not allowed.
        """
        return self._try_transition(state)

    def wait_for_state_transition(self, state: int) -> None:
        """Block until the processor has reached ``state`` or any later state.

        Returns immediately if the current state is already numerically at or
        past ``state``.  Otherwise waits up to 120 seconds; a timeout is logged
        at error level and the method then returns normally.
        """
        # The "already there?" check and the registration must be one atomic
        # step with respect to transitions, otherwise a transition occurring in
        # between would wake nobody and this call would wait for the full
        # timeout.
        with self.state_transition_semaphore:
            if self.state >= state:
                return
            evt = Event()
            self.waiting_for_state_changes.put(StateTransitionItem(state, evt))

        if not evt.wait(120.0):
            self.logger.error(
                f"Timed out waiting for state transition to {state} from {self.state}"
            )

    def set_state_callback(self, state: int, callback: callable) -> None:
        """Register ``callback`` to be called as ``callback(processor)`` on entry to ``state``."""
        self.state_change_callbacks[state].append(callback)

    def prepare(self) -> None:
        """Start preparation on a daemon thread and block until READY (or later)."""
        self.prepare_thread = Thread(target=self.async_prepare, daemon=True)
        self.prepare_thread.start()
        self.wait_for_state_transition(AsyncProcessorState.READY)

    def play(self) -> None:
        """Wait for READY, start playback on a daemon thread, then wait for PLAYING (or later)."""
        self.wait_for_state_transition(AsyncProcessorState.READY)
        self.play_thread = Thread(target=self.async_play, daemon=True)
        self.play_thread.start()
        self.wait_for_state_transition(AsyncProcessorState.PLAYING)

    def finalize(self) -> None:
        """Run finalization once the processor is DONE.

        Raises:
            Exception: if the FINALIZING / FINALIZED transitions are not legal
                from the state the processor is in after the wait.
        """
        # don't finalize until we're done playing.
        self.wait_for_state_transition(AsyncProcessorState.DONE)
        self.set_state(AsyncProcessorState.FINALIZING)
        self.do_finalization()
        self.set_state(AsyncProcessorState.FINALIZED)

    def interrupt(self) -> None:
        """Stop in-flight work and drive the processor to DONE.

        This is a no-op when there is nothing left to interrupt: the processor
        is already being interrupted, is DONE, or is finalizing / finalized.
        Those are exactly the states from which INTERRUPTING is not a legal
        transition, so one atomic attempt to enter INTERRUPTING decides whether
        to proceed.  Note that a second caller arriving during an interruption
        returns immediately without waiting for the first to complete.
        """
        if not self.maybe_set_state(AsyncProcessorState.INTERRUPTING):
            return

        self.was_interrupted = True
        self.do_interruption()
        self.set_state(AsyncProcessorState.INTERRUPTED)
        self.set_state(AsyncProcessorState.DONE)

    def async_play(self) -> None:
        """Thread target for playback: PLAYING -> ``do_play()`` -> DONE.

        Both transitions are best-effort so that an interruption makes this
        thread wind down quietly rather than raise.
        """
        self.logger.info("Starting to play")
        if self.maybe_set_state(AsyncProcessorState.PLAYING):
            self.do_play()
            self.maybe_set_state(AsyncProcessorState.DONE)

    def async_prepare(self) -> None:
        """Thread target for preparation.

        Sequence: PREPARING -> ``start_preparation()`` -> READY ->
        ``continue_preparation()`` -> ``preparation_done()``.  If either
        transition is no longer legal (typically because the processor was
        interrupted meanwhile) the thread stops quietly instead of raising.
        """
        if not self.maybe_set_state(AsyncProcessorState.PREPARING):
            self.logger.info(
                f"Preparation skipped for {self.__class__.__name__}: state is {self.state}"
            )
            return
        self.start_preparation()

        if not self.maybe_set_state(AsyncProcessorState.READY):
            self.logger.info(
                f"Preparation abandoned for {self.__class__.__name__}: state is {self.state}"
            )
            return
        self.continue_preparation()

        self.logger.info(f"Preparation done for {self.__class__.__name__}")
        self.preparation_done()

    # ---- Hooks for subclasses; every default below does nothing. ----

    def start_preparation(self) -> None:
        """Hook called by ``async_prepare`` after entering PREPARING, before READY."""
        pass

    def continue_preparation(self) -> None:
        """Hook called by ``async_prepare`` after the READY transition."""
        pass

    def preparation_done(self):
        """Hook called by ``async_prepare`` once ``continue_preparation`` has returned."""
        pass

    def get_preparation_iterator(self) -> Iterator:
        """Return an iterator of prepared items; the default yields a single ``None``."""
        yield None

    def process_chunk(self, chunk) -> None:
        """Hook for handling one prepared chunk."""
        pass

    def do_interruption(self) -> None:
        """Hook called by ``interrupt`` while in the INTERRUPTING state."""
        pass

    def do_play(self) -> None:
        """Hook called by ``async_play`` while in the PLAYING state."""
        pass

    def do_finalization(self) -> None:
        """Hook called by ``finalize`` while in the FINALIZING state."""
        pass
# A common class for responses that use a message queue and
# an output queue.
class OrchestratorResponse(AsyncProcessor):
    """Processor that also carries a message handler and an output queue."""

    def __init__(self, services, message_handler, output_queue) -> None:
        """Initialise the base processor, then remember the two collaborators.

        Args:
            services: Passed unchanged to ``AsyncProcessor.__init__``.
            message_handler: Stored as ``self.message_handler``.
            output_queue: Stored as ``self.output_queue``.
        """
        super().__init__(services)
        self.output_queue: Queue = output_queue
        self.message_handler: MessageHandler = message_handler
class LLMResponse(OrchestratorResponse):
    """Response that streams LLM text, turns it into TTS audio and queues the frames.

    The prepare thread splits the LLM output into clauses and places
    ``(clause, tts_frame_generator)`` pairs on ``chunks_in_preparation``; the
    play thread takes them off and pushes the resulting frames onto
    ``output_queue``.  A ``(None, None)`` pair on the internal queue marks the
    end of the stream.
    """

    def __init__(self, services, message_handler, output_queue) -> None:
        """Set up the hand-off queue between the prepare and play threads."""
        super().__init__(services, message_handler, output_queue)
        # START_STREAM is emitted once, just before the first audio frame.
        self.has_sent_first_frame = False
        # Items are (clause, generator of frame lists); (None, None) is the
        # end-of-stream sentinel.
        self.chunks_in_preparation = Queue()
        # Clauses handed to playback, reported to the message handler on
        # finalization.
        self.llm_responses: list[str] = []

    def _is_active(self) -> bool:
        """Return True while the processor is READY or PLAYING."""
        return self.state in (
            AsyncProcessorState.READY,
            AsyncProcessorState.PLAYING,
        )

    def get_preparation_iterator(self) -> Iterator:
        """Return an iterator of clauses built from the LLM's streamed reply."""
        messages_for_llm = self.message_handler.get_llm_messages()
        # Serialising the whole prompt is wasted work unless it will be logged.
        if self.logger.isEnabledFor(logging.DEBUG):
            self.logger.debug(f"Messages for llm: {json.dumps(messages_for_llm, indent=2)}")
        return self.clauses_from_chunks(
            self.services.llm.run_llm_async(messages_for_llm)
        )

    def clauses_from_chunks(self, chunks) -> Iterator:
        """Group streamed text chunks into sentence-like clauses.

        Chunks are concatenated until the buffer, ignoring trailing whitespace,
        ends in ``.``, ``!`` or ``?``; the stripped buffer is then yielded and
        reset.  Consumption stops as soon as the processor is no longer READY
        or PLAYING.  Any non-blank remainder is yielded at the end.

        Args:
            chunks: Iterable of text fragments (assumed to be ``str``).

        Yields:
            Stripped, non-empty clause strings.
        """
        out = ""
        for chunk in chunks:
            if not self._is_active():
                break

            out += chunk

            # A plain suffix test is used rather than a "^.*[.!?]$" regex:
            # "." does not match a newline, so that pattern stops recognising
            # sentence ends for good once the buffer contains a line break.
            if out.rstrip().endswith((".", "!", "?")):  # it looks like a sentence
                yield out.strip()
                out = ""

        if out.strip():
            yield out.strip()

    def get_frames_from_tts_response(self, audio_frame) -> list[QueueFrame]:
        """Wrap one TTS audio frame in a single-element list of AUDIO frames."""
        return [QueueFrame(FrameType.AUDIO, audio_frame)]

    def get_frames_from_chunk(self, chunk) -> Generator[list[QueueFrame], Any, None]:
        """Lazily run TTS on ``chunk``, yielding a list of frames per audio frame."""
        for audio_frame in self.services.tts.run_tts(chunk):
            yield self.get_frames_from_tts_response(audio_frame)

    def start_preparation(self) -> None:
        """Create the clause iterator (runs before the READY transition)."""
        self.preparation_iterator = self.get_preparation_iterator()

    def continue_preparation(self) -> None:
        """Queue every clause for playback until exhausted or no longer active."""
        for chunk in self.preparation_iterator:
            if not self._is_active():
                break

            self.process_chunk(chunk)

    def process_chunk(self, chunk) -> None:
        """Queue ``chunk`` with its TTS frame generator (not yet started)."""
        self.chunks_in_preparation.put((chunk, self.get_frames_from_chunk(chunk)))

    def preparation_done(self):
        """Signal end-of-stream to the play loop."""
        self.chunks_in_preparation.put((None, None))

    def do_play(self) -> None:
        """Play queued chunks until the sentinel arrives or the processor stops being active."""
        while self._is_active():
            # Blocks until the prepare thread (or do_interruption) supplies an
            # item.
            prepared_chunk = self.chunks_in_preparation.get()
            if prepared_chunk[0] is None:
                return

            self.play_prepared_chunk(prepared_chunk)

    def play_prepared_chunk(self, prepared_chunk) -> None:
        """Send one chunk's frames to the output queue and record the chunk text.

        Args:
            prepared_chunk: ``(clause, frame_generator)`` pair as queued by
                ``process_chunk``.
        """
        chunk, tts_generator = prepared_chunk
        for frames in tts_generator:
            if not self._is_active():
                break

            if not self.has_sent_first_frame:
                self.output_queue.put(QueueFrame(FrameType.START_STREAM, None))
                self.has_sent_first_frame = True

            for frame in frames:
                self.output_queue.put(frame)

        # Wait for the consumer to mark every queued frame as done (relies on
        # the consumer calling task_done() -- confirm against the consumer).
        self.output_queue.join()
        # Recorded even if the loop above stopped early on an interruption.
        self.llm_responses.append(chunk)

    def do_finalization(self) -> None:
        """Report the played clauses to the message handler."""
        self.message_handler.add_assistant_messages(self.llm_responses)

    def do_interruption(self) -> None:
        """Unblock the play loop and wait for both worker threads to exit."""
        # The sentinel releases a play thread blocked in Queue.get().
        self.chunks_in_preparation.put((None, None))
        if self.prepare_thread and self.prepare_thread.is_alive():
            self.prepare_thread.join()
        if self.play_thread and self.play_thread.is_alive():
            self.play_thread.join()
@dataclass(eq=True, frozen=True)
class ConversationProcessorCollection:
    """Immutable set of ``OrchestratorResponse`` classes, one per named slot.

    Each field holds a class (not an instance) or ``None`` when no processor
    is configured for that slot.  The slot names suggest conversation stages;
    how they are used is decided by the caller.
    """

    introduction: Optional[Type[OrchestratorResponse]] = None
    waiting: Optional[Type[OrchestratorResponse]] = None
    response: Optional[Type[OrchestratorResponse]] = None
    goodbye: Optional[Type[OrchestratorResponse]] = None