Getting started
This commit is contained in:
373
async_processor/async_processor.py
Normal file
373
async_processor/async_processor.py
Normal file
@@ -0,0 +1,373 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from queue import Queue, PriorityQueue, Empty
|
||||
from threading import Event, Semaphore, Thread
|
||||
from typing import Iterator, Optional, Type, TypedDict
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from services.ai_services import AIServiceConfig
|
||||
from message_handler.message_handler import MessageHandler
|
||||
|
||||
frame_idx = 0
|
||||
|
||||
|
||||
class AsyncProcessorState:
|
||||
# Setting class variables, other synchronous activities
|
||||
INIT = 0
|
||||
|
||||
# Making asynchronous requests to LLM and other services to render response
|
||||
PREPARING = 1
|
||||
|
||||
# Ready to start presenting to user (but may not have all data yet)
|
||||
READY = 2
|
||||
|
||||
# Playing response
|
||||
PLAYING = 3
|
||||
|
||||
# An interrupt has been requested and the response is shutting down in-flight processing
|
||||
INTERRUPTING = 4
|
||||
|
||||
# An interrupt has been requested and the response is finished stopping in-flight processing
|
||||
INTERRUPTED = 5
|
||||
|
||||
# Response has been played or interrupted
|
||||
DONE = 6
|
||||
|
||||
# Response is being finalized (updating records of speech, updating LLM context, etc.)
|
||||
FINALIZING = 7
|
||||
|
||||
# Response is complete. This could mean that everything is updated, or that the response
|
||||
# was interrupted.
|
||||
FINALIZED = 8
|
||||
|
||||
state_transitions = {
|
||||
INIT: [PREPARING, INTERRUPTING],
|
||||
PREPARING: [READY, INTERRUPTING],
|
||||
READY: [PLAYING, INTERRUPTING],
|
||||
PLAYING: [DONE, INTERRUPTING],
|
||||
INTERRUPTING: [INTERRUPTED],
|
||||
INTERRUPTED: [DONE],
|
||||
DONE: [FINALIZING],
|
||||
FINALIZING: [FINALIZED],
|
||||
FINALIZED: [FINALIZED],
|
||||
}
|
||||
|
||||
|
||||
@dataclass(order=True)
|
||||
class StateTransitionItem:
|
||||
state: int
|
||||
evt: Event = field(compare=False)
|
||||
|
||||
class AsyncProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
services: AIServiceConfig
|
||||
) -> None:
|
||||
self.state = AsyncProcessorState.INIT
|
||||
self.prepare_thread = None
|
||||
self.play_thread = None
|
||||
self.finalize_thread = None
|
||||
|
||||
self.services: AIServiceConfig = services
|
||||
|
||||
self.state_transition_semaphore = Semaphore()
|
||||
self.waiting_for_state_changes = PriorityQueue()
|
||||
self.state_queue = Queue()
|
||||
|
||||
self.state_change_callbacks = defaultdict(list)
|
||||
|
||||
self.was_interrupted = False
|
||||
|
||||
self.logger = logging.getLogger("bot-instance")
|
||||
|
||||
def set_state(self, state: int) -> None:
|
||||
if state in AsyncProcessorState.state_transitions[self.state]:
|
||||
self.state_transition_semaphore.acquire()
|
||||
|
||||
self.state = state
|
||||
self.state_transition_semaphore.release()
|
||||
|
||||
# wake up any threads waiting for this state transition
|
||||
try:
|
||||
while True:
|
||||
waiter = self.waiting_for_state_changes.get_nowait()
|
||||
if waiter.state <= state:
|
||||
waiter.evt.set()
|
||||
else:
|
||||
self.waiting_for_state_changes.put(waiter)
|
||||
break
|
||||
except Empty:
|
||||
pass
|
||||
|
||||
# make all the callbacks for this state
|
||||
for callback in self.state_change_callbacks[state]:
|
||||
callback(self)
|
||||
else:
|
||||
self.logger.error(
|
||||
f"Invalid state transition from {self.state} to {state} in {self.__class__.__name__}"
|
||||
)
|
||||
raise Exception(f"Invalid state transition from {self.state} to {state}")
|
||||
|
||||
#
|
||||
# This is used for state transitions that could be blocked by an interruption.
|
||||
# If we are interrupted, we silently fail this call. Use only if you know that
|
||||
# this state transition should fail if the processor has been interrupted.
|
||||
#
|
||||
|
||||
def maybe_set_state(self, state: int) -> bool:
|
||||
if state in AsyncProcessorState.state_transitions[self.state]:
|
||||
self.set_state(state)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def wait_for_state_transition(self, state: int) -> None:
|
||||
if self.state >= state:
|
||||
return
|
||||
|
||||
self.state_transition_semaphore.acquire()
|
||||
|
||||
evt = Event()
|
||||
self.waiting_for_state_changes.put(StateTransitionItem(state, evt))
|
||||
self.state_transition_semaphore.release()
|
||||
result = evt.wait(120.0)
|
||||
if not result:
|
||||
self.logger.error(
|
||||
f"Timed out waiting for state transition to {state} from {self.state}"
|
||||
)
|
||||
|
||||
def set_state_callback(self, state: int, callback: callable) -> None:
|
||||
self.state_change_callbacks[state].append(callback)
|
||||
|
||||
def prepare(self) -> None:
|
||||
self.prepare_thread = Thread(target=self.async_prepare, daemon=True)
|
||||
self.prepare_thread.start()
|
||||
self.wait_for_state_transition(AsyncProcessorState.READY)
|
||||
|
||||
def play(self) -> None:
|
||||
self.wait_for_state_transition(AsyncProcessorState.READY)
|
||||
self.play_thread = Thread(target=self.async_play, daemon=True)
|
||||
self.play_thread.start()
|
||||
self.wait_for_state_transition(AsyncProcessorState.PLAYING)
|
||||
|
||||
def finalize(self) -> None:
|
||||
# don't finalize until we're done playing.
|
||||
self.wait_for_state_transition(AsyncProcessorState.DONE)
|
||||
self.set_state(AsyncProcessorState.FINALIZING)
|
||||
self.do_finalization()
|
||||
self.set_state(AsyncProcessorState.FINALIZED)
|
||||
|
||||
def interrupt(self) -> None:
|
||||
# nothing to interrupt if we're already finalizing or finalized, no-op
|
||||
if self.state in [
|
||||
AsyncProcessorState.FINALIZING,
|
||||
AsyncProcessorState.FINALIZED,
|
||||
]:
|
||||
return
|
||||
|
||||
self.set_state(AsyncProcessorState.INTERRUPTING)
|
||||
self.was_interrupted = True
|
||||
self.do_interruption()
|
||||
self.set_state(AsyncProcessorState.INTERRUPTED)
|
||||
self.set_state(AsyncProcessorState.DONE)
|
||||
|
||||
def async_play(self) -> None:
|
||||
self.logger.info(f"starting to play")
|
||||
if self.maybe_set_state(AsyncProcessorState.PLAYING):
|
||||
self.do_play()
|
||||
self.maybe_set_state(AsyncProcessorState.DONE)
|
||||
|
||||
def async_prepare(self) -> None:
|
||||
self.set_state(AsyncProcessorState.PREPARING)
|
||||
self.preparation_iterator = self.get_preparation_iterator()
|
||||
self.set_state(AsyncProcessorState.READY)
|
||||
for chunk in self.preparation_iterator:
|
||||
if self.state not in [
|
||||
AsyncProcessorState.READY,
|
||||
AsyncProcessorState.PLAYING,
|
||||
]:
|
||||
break
|
||||
|
||||
self.process_chunk(chunk)
|
||||
|
||||
self.logger.info(f"Preparation done for {self.__class__.__name__}")
|
||||
self.preparation_done()
|
||||
|
||||
def preparation_done(self):
|
||||
pass
|
||||
|
||||
def get_preparation_iterator(self) -> Iterator:
|
||||
yield None
|
||||
|
||||
def process_chunk(self, chunk) -> None:
|
||||
pass
|
||||
|
||||
def do_interruption(self) -> None:
|
||||
pass
|
||||
|
||||
def do_play(self) -> None:
|
||||
pass
|
||||
|
||||
def do_finalization(self) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class ResponseArgs(TypedDict):
|
||||
services: AIServiceConfig
|
||||
message_handler: MessageHandler
|
||||
output_queue: Queue
|
||||
|
||||
class Response(AsyncProcessor):
|
||||
def __init__(
|
||||
self,
|
||||
services,
|
||||
message_handler,
|
||||
output_queue,
|
||||
) -> None:
|
||||
super().__init__(services)
|
||||
|
||||
self.message_handler: MessageHandler = message_handler
|
||||
self.output_queue: Queue = output_queue
|
||||
self.has_sent_first_frame = False
|
||||
|
||||
self.chunks_in_preparation = Queue()
|
||||
|
||||
#self.sprite_loader = sprite_loader.SpriteLoader()
|
||||
|
||||
self.llm_responses: list[str] = []
|
||||
|
||||
def get_preparation_iterator(self) -> Iterator:
|
||||
messages_for_llm = self.message_handler.get_llm_messages()
|
||||
self.logger.error(f"messages for llm: {json.dumps(messages_for_llm, indent=2)}")
|
||||
return self.clauses_from_chunks(
|
||||
self.services.llm.run_llm_async(messages_for_llm)
|
||||
)
|
||||
|
||||
def clauses_from_chunks(self, chunks) -> Iterator:
|
||||
out = ""
|
||||
for chunk in chunks:
|
||||
if self.state not in [
|
||||
AsyncProcessorState.READY,
|
||||
AsyncProcessorState.PLAYING,
|
||||
]:
|
||||
break
|
||||
|
||||
out += chunk
|
||||
|
||||
if re.match(r"^.*[.!?]$", out): # it looks like a sentence
|
||||
yield out.strip()
|
||||
out = ""
|
||||
|
||||
if out.strip():
|
||||
yield out.strip()
|
||||
|
||||
def process_chunk(self, chunk) -> None:
|
||||
# could also put other generators in this tuple
|
||||
self.logger.info(f"putting chunk in preparation queue {chunk}")
|
||||
|
||||
def get_frames_from_chunk(chunk):
|
||||
image_list = [
|
||||
"sc-talk",
|
||||
"sc-default",
|
||||
"sc-default",
|
||||
"sc-default",
|
||||
"sc-talk",
|
||||
"sc-default",
|
||||
"sc-default",
|
||||
"sc-default",
|
||||
"sc-default",
|
||||
"sc-talk",
|
||||
"sc-talk",
|
||||
"sc-default",
|
||||
"sc-default",
|
||||
"sc-talk",
|
||||
"sc-talk",
|
||||
"sc-default",
|
||||
"sc-talk",
|
||||
"sc-default",
|
||||
"sc-default",
|
||||
"sc-default",
|
||||
"sc-talk",
|
||||
"sc-talk",
|
||||
"sc-talk",
|
||||
"sc-talk",
|
||||
"sc-talk",
|
||||
"sc-talk",
|
||||
"sc-default",
|
||||
"sc-default",
|
||||
"sc-talk",
|
||||
"sc-talk",
|
||||
]
|
||||
image_list_idx = 0
|
||||
for frame in self.services.tts.run_tts(chunk):
|
||||
yield (bytearray(frame), None) #self.sprite_loader.get_sprite_bytes(image_list[image_list_idx]))
|
||||
image_list_idx = (image_list_idx + 1) % len(image_list)
|
||||
|
||||
self.chunks_in_preparation.put((chunk, get_frames_from_chunk(chunk)))
|
||||
|
||||
def preparation_done(self):
|
||||
self.chunks_in_preparation.put((None, None))
|
||||
|
||||
def do_play(self) -> None:
|
||||
while True:
|
||||
if self.state not in [
|
||||
AsyncProcessorState.READY,
|
||||
AsyncProcessorState.PLAYING,
|
||||
]:
|
||||
break
|
||||
prepared_chunk = self.chunks_in_preparation.get()
|
||||
if prepared_chunk[0] is None:
|
||||
return
|
||||
|
||||
self.play_prepared_chunk(prepared_chunk)
|
||||
|
||||
def play_prepared_chunk(self, prepared_chunk) -> None:
|
||||
chunk, tts_generator = prepared_chunk
|
||||
global frame_idx
|
||||
for tts_chunk in tts_generator:
|
||||
if self.state not in [
|
||||
AsyncProcessorState.READY,
|
||||
AsyncProcessorState.PLAYING,
|
||||
]:
|
||||
break
|
||||
|
||||
if not self.has_sent_first_frame:
|
||||
self.output_queue.put({"type": "start_stream", "idx": frame_idx})
|
||||
frame_idx += 1
|
||||
self.has_sent_first_frame = True
|
||||
|
||||
(audio_frame, video_frame) = tts_chunk
|
||||
self.output_queue.put(
|
||||
{"type": "image_frame", "data": video_frame, "idx": frame_idx}
|
||||
)
|
||||
self.output_queue.put(
|
||||
{"type": "audio_frame", "data": audio_frame, "idx": frame_idx + 1}
|
||||
)
|
||||
frame_idx += 2
|
||||
|
||||
self.output_queue.join()
|
||||
self.llm_responses.append(chunk)
|
||||
|
||||
def do_finalization(self) -> None:
|
||||
self.message_handler.add_assistant_messages(self.llm_responses)
|
||||
|
||||
def do_interruption(self) -> None:
|
||||
self.chunks_in_preparation.put((None, None))
|
||||
|
||||
if self.prepare_thread and self.prepare_thread.is_alive():
|
||||
self.prepare_thread.join()
|
||||
|
||||
if self.play_thread and self.play_thread.is_alive():
|
||||
self.play_thread.join()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConversationProcessorCollection:
|
||||
introduction: Optional[Type[Response]] = None
|
||||
waiting: Optional[Type[Response]] = None
|
||||
response: Optional[Type[Response]] = None
|
||||
goodbye: Optional[Type[Response]] = None
|
||||
Reference in New Issue
Block a user