a little cleanup, moving services that need to be updated to their own directory
This commit is contained in:
@@ -174,7 +174,7 @@ class AsyncProcessor:
|
||||
self.set_state(AsyncProcessorState.DONE)
|
||||
|
||||
def async_play(self) -> None:
|
||||
self.logger.info(f"starting to play")
|
||||
self.logger.info(f"Starting to play")
|
||||
if self.maybe_set_state(AsyncProcessorState.PLAYING):
|
||||
self.do_play()
|
||||
self.maybe_set_state(AsyncProcessorState.DONE)
|
||||
@@ -233,7 +233,7 @@ class Response(AsyncProcessor):
|
||||
|
||||
def get_preparation_iterator(self) -> Iterator:
|
||||
messages_for_llm = self.message_handler.get_llm_messages()
|
||||
self.logger.error(f"messages for llm: {json.dumps(messages_for_llm, indent=2)}")
|
||||
self.logger.debug(f"Messages for llm: {json.dumps(messages_for_llm, indent=2)}")
|
||||
return self.clauses_from_chunks(
|
||||
self.services.llm.run_llm_async(messages_for_llm)
|
||||
)
|
||||
|
||||
@@ -62,23 +62,23 @@ class IndexingMessageHandler(MessageHandler):
|
||||
self.search_indexer = indexer
|
||||
|
||||
self.last_written_idx = 0
|
||||
self.index_message_queue = Queue()
|
||||
self.storage_message_queue = Queue()
|
||||
|
||||
self.index_writer_thread = Thread(target=self.indexer_writer, daemon=True)
|
||||
self.index_writer_thread = Thread(target=self.storage_writer, daemon=True)
|
||||
self.index_writer_thread.start()
|
||||
|
||||
self.logger = logging.getLogger("bot-instance")
|
||||
|
||||
def shutdown(self):
|
||||
self.finalize_user_message()
|
||||
self.index_message_queue.put(None)
|
||||
self.storage_message_queue.put(None)
|
||||
self.index_writer_thread.join()
|
||||
|
||||
def indexer_writer(self) -> None:
|
||||
def storage_writer(self) -> None:
|
||||
while True:
|
||||
try:
|
||||
message_idx = self.index_message_queue.get()
|
||||
self.index_message_queue.task_done()
|
||||
message_idx = self.storage_message_queue.get()
|
||||
self.storage_message_queue.task_done()
|
||||
|
||||
if message_idx is None:
|
||||
return
|
||||
@@ -103,35 +103,19 @@ class IndexingMessageHandler(MessageHandler):
|
||||
pass
|
||||
|
||||
def cleanup_user_message(self, user_message) -> str:
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": """
|
||||
You are an assistant who is very good at making transcriptions
|
||||
of human speech into well-capitalized and punctuated text, without
|
||||
changing any words or the order of the words. Please change this
|
||||
transcription to something suitable for the printed page.
|
||||
""",
|
||||
},
|
||||
{"role": "user", "content": user_message},
|
||||
]
|
||||
result = self.services.llm.run_llm(messages)
|
||||
if result:
|
||||
user_message = result
|
||||
|
||||
return user_message
|
||||
|
||||
def finalize_user_message(self):
|
||||
super().finalize_user_message()
|
||||
self.write_messages_to_index()
|
||||
self.write_messages_to_storage()
|
||||
|
||||
def write_messages_to_index(self):
|
||||
def write_messages_to_storage(self):
|
||||
if self.finalized_user_message_idx is None:
|
||||
return
|
||||
|
||||
for idx in range(self.last_written_idx, len(self.messages)):
|
||||
self.logger.info(
|
||||
f"writing to index: {self.messages[idx].type} {self.messages[idx].message}"
|
||||
f"Writing to storage: {self.messages[idx].type} {self.messages[idx].message}"
|
||||
)
|
||||
if (
|
||||
self.messages[idx].type == "user"
|
||||
@@ -140,4 +124,4 @@ class IndexingMessageHandler(MessageHandler):
|
||||
break
|
||||
|
||||
if self.messages[idx].type != "system":
|
||||
self.index_message_queue.put(idx)
|
||||
self.storage_message_queue.put(idx)
|
||||
|
||||
@@ -78,7 +78,7 @@ class Orchestrator(EventHandler):
|
||||
intro.prepare()
|
||||
intro.set_state_callback(AsyncProcessorState.DONE, self.on_intro_played)
|
||||
intro.set_state_callback(AsyncProcessorState.FINALIZED, self.on_intro_finished)
|
||||
self.logger.info(f"Response is preparing")
|
||||
self.logger.info(f"Introduction is preparing")
|
||||
|
||||
self.current_response: AsyncProcessor = intro
|
||||
self.can_interrupt = False
|
||||
@@ -88,14 +88,14 @@ class Orchestrator(EventHandler):
|
||||
self.speech_timeout = None
|
||||
self.interrupt_time = None
|
||||
|
||||
self.logger.info("configuring daily")
|
||||
self.logger.info("Configuring daily")
|
||||
self.configure_daily()
|
||||
|
||||
def configure_daily(self):
|
||||
Daily.init()
|
||||
self.client = CallClient(event_handler=self)
|
||||
|
||||
self.logger.info(f"mic sample rate: {self.services.tts.get_mic_sample_rate()}")
|
||||
self.logger.info(f"Mic sample rate: {self.services.tts.get_mic_sample_rate()}")
|
||||
self.mic: VirtualMicrophoneDevice = Daily.create_microphone_device(
|
||||
"mic", sample_rate=self.services.tts.get_mic_sample_rate(), channels=1
|
||||
)
|
||||
@@ -168,23 +168,23 @@ class Orchestrator(EventHandler):
|
||||
self.client.leave()
|
||||
|
||||
def stop(self):
|
||||
self.logger.info("stop current response")
|
||||
self.logger.info("Stop current response")
|
||||
if self.current_response:
|
||||
if self.current_response.state < AsyncProcessorState.INTERRUPTED:
|
||||
self.current_response.interrupt()
|
||||
|
||||
self.logger.info("wait for state transition")
|
||||
self.logger.info("Wait for state transition")
|
||||
self.current_response.wait_for_state_transition(AsyncProcessorState.FINALIZED)
|
||||
|
||||
self.stop_threads.set()
|
||||
self.camera_thread.join()
|
||||
self.logger.info("camera thread stopped")
|
||||
self.logger.info("Camera thread stopped")
|
||||
|
||||
self.logger.info("put stop in output queue")
|
||||
self.logger.info("Put stop in output queue")
|
||||
self.output_queue.put({"type": "stop"})
|
||||
|
||||
self.frame_consumer_thread.join()
|
||||
self.logger.info("orchestrator stopped.")
|
||||
self.logger.info("Orchestrator stopped.")
|
||||
|
||||
def on_intro_played(self, intro):
|
||||
self.can_interrupt = True
|
||||
@@ -202,7 +202,7 @@ class Orchestrator(EventHandler):
|
||||
self.message_handler.finalize_user_message()
|
||||
|
||||
def call_joined(self, join_data, client_error):
|
||||
self.logger.info(f"call_joined: {join_data}, {client_error}")
|
||||
self.logger.info(f"Call_joined: {join_data}, {client_error}")
|
||||
self.client.start_transcription(
|
||||
{
|
||||
"language": "en",
|
||||
@@ -231,7 +231,7 @@ class Orchestrator(EventHandler):
|
||||
|
||||
def on_participant_left(self, participant, reason):
|
||||
if len(self.client.participants()) < 2:
|
||||
self.logger.info("participant left")
|
||||
self.logger.info(f"Participant {participant} left")
|
||||
self.participant_left = True
|
||||
|
||||
def on_app_message(self, message, sender):
|
||||
@@ -249,13 +249,13 @@ class Orchestrator(EventHandler):
|
||||
self.handle_transcription_fragment(message['text'])
|
||||
|
||||
def on_transcription_stopped(self, stopped_by, stopped_by_error):
|
||||
self.logger.info(f"transcription stopped {stopped_by}, {stopped_by_error}")
|
||||
self.logger.info(f"Transcription stopped {stopped_by}, {stopped_by_error}")
|
||||
|
||||
def on_transcription_error(self, message):
|
||||
self.logger.error(f"transcription error {message}")
|
||||
self.logger.error(f"Transcription error {message}")
|
||||
|
||||
def on_transcription_started(self, status):
|
||||
self.logger.info(f"transcription started {status}")
|
||||
self.logger.info(f"Transcription started {status}")
|
||||
|
||||
def set_image(self, image: bytes):
|
||||
self.image: bytes | None = image
|
||||
@@ -380,14 +380,16 @@ class Orchestrator(EventHandler):
|
||||
# self.display_images(thinking_images)
|
||||
|
||||
def action(self):
|
||||
self.logger.info("starting camera thread")
|
||||
self.logger.info("Starting camera thread")
|
||||
self.image: bytes | None = None
|
||||
self.camera_thread = Thread(target=self.run_camera, daemon=True)
|
||||
self.camera_thread.start()
|
||||
|
||||
self.logger.info("Starting frame consumer thread")
|
||||
self.frame_consumer_thread = Thread(target=self.frame_consumer, daemon=True)
|
||||
self.frame_consumer_thread.start()
|
||||
|
||||
self.logger.info("Playing introduction")
|
||||
self.can_interrupt = False
|
||||
self.current_response.play()
|
||||
|
||||
@@ -401,7 +403,7 @@ class Orchestrator(EventHandler):
|
||||
try:
|
||||
frame = self.output_queue.get()
|
||||
if frame["type"] == "stop":
|
||||
self.logger.info("🎬 Stopping frame consumer thread")
|
||||
self.logger.info("Stopping frame consumer thread")
|
||||
|
||||
if os.getenv("WRITE_BOT_AUDIO", False):
|
||||
filename = f"conversation-{len(all_audio_frames)}.wav"
|
||||
@@ -440,7 +442,7 @@ class Orchestrator(EventHandler):
|
||||
b = bytearray()
|
||||
else:
|
||||
if self.interrupt_time:
|
||||
self.logger.info(f"====== lag to stop stream ====== {time.perf_counter() - self.interrupt_time}")
|
||||
self.logger.info(f"Lag to stop stream after interruption {time.perf_counter() - self.interrupt_time}")
|
||||
self.interrupt_time = None
|
||||
|
||||
if frame["type"] == "start_stream":
|
||||
|
||||
@@ -45,7 +45,7 @@ class TTSService(AIService):
|
||||
class ImageGenService(AIService):
|
||||
# Renders the image. Returns an Image object.
|
||||
@abstractmethod
|
||||
def run_image_gen(self, sentence) -> Image.Image:
|
||||
def run_image_gen(self, sentence) -> tuple[str, Image.Image]:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ class AzureTTSService(TTSService):
|
||||
self.speech_synthesizer = SpeechSynthesizer(speech_config=self.speech_config, audio_config=None)
|
||||
|
||||
def run_tts(self, sentence) -> Generator[bytes, None, None]:
|
||||
self.logger.info("⌨️ running azure tts async")
|
||||
self.logger.info("Running azure tts")
|
||||
ssml = "<speak version='1.0' xml:lang='en-US' xmlns='http://www.w3.org/2001/10/synthesis' " \
|
||||
"xmlns:mstts='http://www.w3.org/2001/mstts'>" \
|
||||
"<voice name='en-US-SaraNeural'>" \
|
||||
@@ -33,9 +33,9 @@ class AzureTTSService(TTSService):
|
||||
f"{sentence}" \
|
||||
"</prosody></mstts:express-as></voice></speak> "
|
||||
result = self.speech_synthesizer.speak_ssml(ssml)
|
||||
self.logger.info("⌨️ got azure tts result")
|
||||
self.logger.info("Got azure tts result")
|
||||
if result.reason == ResultReason.SynthesizingAudioCompleted:
|
||||
self.logger.info("⌨️ returning result")
|
||||
self.logger.info("Returning result")
|
||||
# azure always sends a 44-byte header. Strip it off.
|
||||
yield result.audio_data[44:]
|
||||
elif result.reason == ResultReason.Canceled:
|
||||
@@ -60,7 +60,7 @@ class AzureLLMService(LLMService):
|
||||
def run_llm_async(self, messages) -> Generator[str, None, None]:
|
||||
local_messages = messages.copy()
|
||||
messages_for_log = json.dumps(local_messages)
|
||||
self.logger.info(f"==== generating chat via azure: {messages_for_log}")
|
||||
self.logger.debug(f"Generating chat via azure: {messages_for_log}")
|
||||
|
||||
response = self.get_response(local_messages, stream=True)
|
||||
|
||||
@@ -75,10 +75,10 @@ class AzureLLMService(LLMService):
|
||||
yield chunk["choices"][0]["delta"]["content"]
|
||||
|
||||
|
||||
def run_llm(self, messages) -> str or None:
|
||||
def run_llm(self, messages) -> str | None:
|
||||
local_messages = messages.copy()
|
||||
messages_for_log = json.dumps(local_messages)
|
||||
self.logger.info(f"==== generating chat via azure: {messages_for_log}")
|
||||
self.logger.debug(f"Generating chat via azure: {messages_for_log}")
|
||||
|
||||
response = self.get_response(local_messages, stream=False)
|
||||
if (
|
||||
@@ -93,8 +93,9 @@ class AzureLLMService(LLMService):
|
||||
|
||||
|
||||
class AzureImageGenService(ImageGenService):
|
||||
def run_image_gen(self, sentence) -> Image.Image:
|
||||
self.logger.info("generating azure image", sentence)
|
||||
|
||||
def run_image_gen(self, sentence) -> tuple[str, Image.Image]:
|
||||
self.logger.info("Generating azure image", sentence)
|
||||
|
||||
image = openai.Image.create(
|
||||
api_type = 'azure',
|
||||
|
||||
@@ -15,7 +15,7 @@ class DeepgramAIService(AIService):
|
||||
return 24000
|
||||
|
||||
def run_tts(self, sentence):
|
||||
self.logger.info(f"running deepgram tts for {sentence}")
|
||||
self.logger.info(f"Running deepgram tts for {sentence}")
|
||||
base_url = "https://api.beta.deepgram.com/v1/speak"
|
||||
voice = os.getenv("DEEPGRAM_VOICE") or "alpha-apollo-en-v1" # move this to an environment variable
|
||||
request_url = f"{base_url}?model={voice}&encoding=linear16&container=none"
|
||||
@@ -91,7 +91,7 @@ class MockImageService(ImageGenService):
|
||||
return None
|
||||
|
||||
|
||||
class TestIndexingMessageHandler(unittest.TestCase):
|
||||
class TestStorageMessageHandler(unittest.TestCase):
|
||||
def test_user_message_finalized(self):
|
||||
mock_tts_service = MockTTSService()
|
||||
mock_llm_service = MockLLMService()
|
||||
@@ -106,18 +106,20 @@ class TestIndexingMessageHandler(unittest.TestCase):
|
||||
message_handler = IndexingMessageHandler(
|
||||
"Hello world", service_config, mock_indexer
|
||||
)
|
||||
message_handler.cleanup_user_message = MagicMock(return_value="Parsed user message.")
|
||||
message_handler.add_user_message("User message")
|
||||
message_handler.add_assistant_message("Assistant message will be ignored")
|
||||
message_handler.add_user_message("User message plus something else")
|
||||
message_handler.add_user_message("plus something else")
|
||||
message_handler.finalize_user_message()
|
||||
message_handler.add_assistant_message(
|
||||
"New assistant message will not be ignored"
|
||||
)
|
||||
message_handler.add_user_message("User message second time")
|
||||
message_handler.add_assistant_message("Assistant message second time")
|
||||
message_handler.write_messages_to_index()
|
||||
message_handler.write_messages_to_storage()
|
||||
|
||||
time.sleep(0.5)
|
||||
message_handler.cleanup_user_message.assert_called_with("User message plus something else")
|
||||
self.assertEqual(
|
||||
mock_indexer.mock_calls,
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user