a little cleanup, moving services that need to be updated to their own directory

This commit is contained in:
Moishe Lettvin
2023-12-28 11:33:37 -05:00
parent 3f1bb7cac1
commit 512cc71073
14 changed files with 46 additions and 57 deletions

View File

@@ -174,7 +174,7 @@ class AsyncProcessor:
self.set_state(AsyncProcessorState.DONE)
def async_play(self) -> None:
self.logger.info(f"starting to play")
self.logger.info(f"Starting to play")
if self.maybe_set_state(AsyncProcessorState.PLAYING):
self.do_play()
self.maybe_set_state(AsyncProcessorState.DONE)
@@ -233,7 +233,7 @@ class Response(AsyncProcessor):
def get_preparation_iterator(self) -> Iterator:
messages_for_llm = self.message_handler.get_llm_messages()
self.logger.error(f"messages for llm: {json.dumps(messages_for_llm, indent=2)}")
self.logger.debug(f"Messages for llm: {json.dumps(messages_for_llm, indent=2)}")
return self.clauses_from_chunks(
self.services.llm.run_llm_async(messages_for_llm)
)

View File

@@ -62,23 +62,23 @@ class IndexingMessageHandler(MessageHandler):
self.search_indexer = indexer
self.last_written_idx = 0
self.index_message_queue = Queue()
self.storage_message_queue = Queue()
self.index_writer_thread = Thread(target=self.indexer_writer, daemon=True)
self.index_writer_thread = Thread(target=self.storage_writer, daemon=True)
self.index_writer_thread.start()
self.logger = logging.getLogger("bot-instance")
def shutdown(self):
self.finalize_user_message()
self.index_message_queue.put(None)
self.storage_message_queue.put(None)
self.index_writer_thread.join()
def indexer_writer(self) -> None:
def storage_writer(self) -> None:
while True:
try:
message_idx = self.index_message_queue.get()
self.index_message_queue.task_done()
message_idx = self.storage_message_queue.get()
self.storage_message_queue.task_done()
if message_idx is None:
return
@@ -103,35 +103,19 @@ class IndexingMessageHandler(MessageHandler):
pass
def cleanup_user_message(self, user_message) -> str:
messages = [
{
"role": "system",
"content": """
You are an assistant who is very good at making transcriptions
of human speech into well-capitalized and punctuated text, without
changing any words or the order of the words. Please change this
transcription to something suitable for the printed page.
""",
},
{"role": "user", "content": user_message},
]
result = self.services.llm.run_llm(messages)
if result:
user_message = result
return user_message
def finalize_user_message(self):
super().finalize_user_message()
self.write_messages_to_index()
self.write_messages_to_storage()
def write_messages_to_index(self):
def write_messages_to_storage(self):
if self.finalized_user_message_idx is None:
return
for idx in range(self.last_written_idx, len(self.messages)):
self.logger.info(
f"writing to index: {self.messages[idx].type} {self.messages[idx].message}"
f"Writing to storage: {self.messages[idx].type} {self.messages[idx].message}"
)
if (
self.messages[idx].type == "user"
@@ -140,4 +124,4 @@ class IndexingMessageHandler(MessageHandler):
break
if self.messages[idx].type != "system":
self.index_message_queue.put(idx)
self.storage_message_queue.put(idx)

View File

@@ -78,7 +78,7 @@ class Orchestrator(EventHandler):
intro.prepare()
intro.set_state_callback(AsyncProcessorState.DONE, self.on_intro_played)
intro.set_state_callback(AsyncProcessorState.FINALIZED, self.on_intro_finished)
self.logger.info(f"Response is preparing")
self.logger.info(f"Introduction is preparing")
self.current_response: AsyncProcessor = intro
self.can_interrupt = False
@@ -88,14 +88,14 @@ class Orchestrator(EventHandler):
self.speech_timeout = None
self.interrupt_time = None
self.logger.info("configuring daily")
self.logger.info("Configuring daily")
self.configure_daily()
def configure_daily(self):
Daily.init()
self.client = CallClient(event_handler=self)
self.logger.info(f"mic sample rate: {self.services.tts.get_mic_sample_rate()}")
self.logger.info(f"Mic sample rate: {self.services.tts.get_mic_sample_rate()}")
self.mic: VirtualMicrophoneDevice = Daily.create_microphone_device(
"mic", sample_rate=self.services.tts.get_mic_sample_rate(), channels=1
)
@@ -168,23 +168,23 @@ class Orchestrator(EventHandler):
self.client.leave()
def stop(self):
self.logger.info("stop current response")
self.logger.info("Stop current response")
if self.current_response:
if self.current_response.state < AsyncProcessorState.INTERRUPTED:
self.current_response.interrupt()
self.logger.info("wait for state transition")
self.logger.info("Wait for state transition")
self.current_response.wait_for_state_transition(AsyncProcessorState.FINALIZED)
self.stop_threads.set()
self.camera_thread.join()
self.logger.info("camera thread stopped")
self.logger.info("Camera thread stopped")
self.logger.info("put stop in output queue")
self.logger.info("Put stop in output queue")
self.output_queue.put({"type": "stop"})
self.frame_consumer_thread.join()
self.logger.info("orchestrator stopped.")
self.logger.info("Orchestrator stopped.")
def on_intro_played(self, intro):
self.can_interrupt = True
@@ -202,7 +202,7 @@ class Orchestrator(EventHandler):
self.message_handler.finalize_user_message()
def call_joined(self, join_data, client_error):
self.logger.info(f"call_joined: {join_data}, {client_error}")
self.logger.info(f"Call_joined: {join_data}, {client_error}")
self.client.start_transcription(
{
"language": "en",
@@ -231,7 +231,7 @@ class Orchestrator(EventHandler):
def on_participant_left(self, participant, reason):
if len(self.client.participants()) < 2:
self.logger.info("participant left")
self.logger.info(f"Participant {participant} left")
self.participant_left = True
def on_app_message(self, message, sender):
@@ -249,13 +249,13 @@ class Orchestrator(EventHandler):
self.handle_transcription_fragment(message['text'])
def on_transcription_stopped(self, stopped_by, stopped_by_error):
self.logger.info(f"transcription stopped {stopped_by}, {stopped_by_error}")
self.logger.info(f"Transcription stopped {stopped_by}, {stopped_by_error}")
def on_transcription_error(self, message):
self.logger.error(f"transcription error {message}")
self.logger.error(f"Transcription error {message}")
def on_transcription_started(self, status):
self.logger.info(f"transcription started {status}")
self.logger.info(f"Transcription started {status}")
def set_image(self, image: bytes):
self.image: bytes | None = image
@@ -380,14 +380,16 @@ class Orchestrator(EventHandler):
# self.display_images(thinking_images)
def action(self):
self.logger.info("starting camera thread")
self.logger.info("Starting camera thread")
self.image: bytes | None = None
self.camera_thread = Thread(target=self.run_camera, daemon=True)
self.camera_thread.start()
self.logger.info("Starting frame consumer thread")
self.frame_consumer_thread = Thread(target=self.frame_consumer, daemon=True)
self.frame_consumer_thread.start()
self.logger.info("Playing introduction")
self.can_interrupt = False
self.current_response.play()
@@ -401,7 +403,7 @@ class Orchestrator(EventHandler):
try:
frame = self.output_queue.get()
if frame["type"] == "stop":
self.logger.info("🎬 Stopping frame consumer thread")
self.logger.info("Stopping frame consumer thread")
if os.getenv("WRITE_BOT_AUDIO", False):
filename = f"conversation-{len(all_audio_frames)}.wav"
@@ -440,7 +442,7 @@ class Orchestrator(EventHandler):
b = bytearray()
else:
if self.interrupt_time:
self.logger.info(f"====== lag to stop stream ====== {time.perf_counter() - self.interrupt_time}")
self.logger.info(f"Lag to stop stream after interruption {time.perf_counter() - self.interrupt_time}")
self.interrupt_time = None
if frame["type"] == "start_stream":

View File

@@ -45,7 +45,7 @@ class TTSService(AIService):
class ImageGenService(AIService):
# Renders the image. Returns an Image object.
@abstractmethod
def run_image_gen(self, sentence) -> Image.Image:
def run_image_gen(self, sentence) -> tuple[str, Image.Image]:
pass

View File

@@ -23,7 +23,7 @@ class AzureTTSService(TTSService):
self.speech_synthesizer = SpeechSynthesizer(speech_config=self.speech_config, audio_config=None)
def run_tts(self, sentence) -> Generator[bytes, None, None]:
self.logger.info("⌨️ running azure tts async")
self.logger.info("Running azure tts")
ssml = "<speak version='1.0' xml:lang='en-US' xmlns='http://www.w3.org/2001/10/synthesis' " \
"xmlns:mstts='http://www.w3.org/2001/mstts'>" \
"<voice name='en-US-SaraNeural'>" \
@@ -33,9 +33,9 @@ class AzureTTSService(TTSService):
f"{sentence}" \
"</prosody></mstts:express-as></voice></speak> "
result = self.speech_synthesizer.speak_ssml(ssml)
self.logger.info("⌨️ got azure tts result")
self.logger.info("Got azure tts result")
if result.reason == ResultReason.SynthesizingAudioCompleted:
self.logger.info("⌨️ returning result")
self.logger.info("Returning result")
# azure always sends a 44-byte header. Strip it off.
yield result.audio_data[44:]
elif result.reason == ResultReason.Canceled:
@@ -60,7 +60,7 @@ class AzureLLMService(LLMService):
def run_llm_async(self, messages) -> Generator[str, None, None]:
local_messages = messages.copy()
messages_for_log = json.dumps(local_messages)
self.logger.info(f"==== generating chat via azure: {messages_for_log}")
self.logger.debug(f"Generating chat via azure: {messages_for_log}")
response = self.get_response(local_messages, stream=True)
@@ -75,10 +75,10 @@ class AzureLLMService(LLMService):
yield chunk["choices"][0]["delta"]["content"]
def run_llm(self, messages) -> str or None:
def run_llm(self, messages) -> str | None:
local_messages = messages.copy()
messages_for_log = json.dumps(local_messages)
self.logger.info(f"==== generating chat via azure: {messages_for_log}")
self.logger.debug(f"Generating chat via azure: {messages_for_log}")
response = self.get_response(local_messages, stream=False)
if (
@@ -93,8 +93,9 @@ class AzureLLMService(LLMService):
class AzureImageGenService(ImageGenService):
def run_image_gen(self, sentence) -> Image.Image:
self.logger.info("generating azure image", sentence)
def run_image_gen(self, sentence) -> tuple[str, Image.Image]:
self.logger.info("Generating azure image", sentence)
image = openai.Image.create(
api_type = 'azure',

View File

@@ -15,7 +15,7 @@ class DeepgramAIService(AIService):
return 24000
def run_tts(self, sentence):
self.logger.info(f"running deepgram tts for {sentence}")
self.logger.info(f"Running deepgram tts for {sentence}")
base_url = "https://api.beta.deepgram.com/v1/speak"
voice = os.getenv("DEEPGRAM_VOICE") or "alpha-apollo-en-v1" # move this to an environment variable
request_url = f"{base_url}?model={voice}&encoding=linear16&container=none"

View File

@@ -91,7 +91,7 @@ class MockImageService(ImageGenService):
return None
class TestIndexingMessageHandler(unittest.TestCase):
class TestStorageMessageHandler(unittest.TestCase):
def test_user_message_finalized(self):
mock_tts_service = MockTTSService()
mock_llm_service = MockLLMService()
@@ -106,18 +106,20 @@ class TestIndexingMessageHandler(unittest.TestCase):
message_handler = IndexingMessageHandler(
"Hello world", service_config, mock_indexer
)
message_handler.cleanup_user_message = MagicMock(return_value="Parsed user message.")
message_handler.add_user_message("User message")
message_handler.add_assistant_message("Assistant message will be ignored")
message_handler.add_user_message("User message plus something else")
message_handler.add_user_message("plus something else")
message_handler.finalize_user_message()
message_handler.add_assistant_message(
"New assistant message will not be ignored"
)
message_handler.add_user_message("User message second time")
message_handler.add_assistant_message("Assistant message second time")
message_handler.write_messages_to_index()
message_handler.write_messages_to_storage()
time.sleep(0.5)
message_handler.cleanup_user_message.assert_called_with("User message plus something else")
self.assertEqual(
mock_indexer.mock_calls,
[