services(anthropic): fix interruptions with anthropic

This commit is contained in:
Aleix Conchillo Flaqué
2024-06-04 12:13:29 -07:00
parent af202d4fe5
commit 571e10f83e
3 changed files with 20 additions and 28 deletions

View File

@@ -4,8 +4,6 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
import os
import asyncio
import time
import base64
@@ -80,8 +78,20 @@ class AnthropicLLMService(LLMService):
}]
})
else:
# text frame
anthropic_messages.append({"role": role, "content": content})
# Text frame. Anthropic needs the roles to alternate. This will
# cause an issue with interruptions. So, if we detect we are the
# ones asking again it probably means we were interrupted.
if role == "user" and len(anthropic_messages) > 1:
last_message = anthropic_messages[-1]
if last_message["role"] == "user":
anthropic_messages = anthropic_messages[:-1]
content = last_message["content"]
anthropic_messages.append(
{"role": "user", "content": f"Sorry, I just asked you about [{content}] but now I would like to know [{text}]."})
else:
anthropic_messages.append({"role": role, "content": text})
else:
anthropic_messages.append({"role": role, "content": text})
return anthropic_messages
@@ -107,7 +117,7 @@ class AnthropicLLMService(LLMService):
await self.push_frame(LLMResponseEndFrame())
except Exception as e:
logger.error(f"Exception: {e}")
logger.error(f"Anthrophic exception: {e}")
finally:
await self.push_frame(LLMFullResponseEndFrame())
@@ -125,22 +135,3 @@ class AnthropicLLMService(LLMService):
if context:
await self._process_context(context)
async def x_process_frame(self, frame: Frame, direction: FrameDirection):
if isinstance(frame, LLMMessagesFrame):
stream = await self.client.messages.create(
max_tokens=self.max_tokens,
messages=[
{
"role": "user",
"content": "Hello, Claude",
}
],
model=self.model,
stream=True,
)
async for event in stream:
if event.type == "content_block_delta":
await self.push_frame(TextFrame(event.delta.text))
else:
await self.push_frame(frame, direction)

View File

@@ -36,7 +36,7 @@ class CartesiaTTSService(TTSService):
logger.error(f"Cartesia initialization error: {e}")
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"Transcribing text: [{text}]")
logger.debug(f"Generating TTS: [{text}]")
try:
chunk_generator = await self._client.generate(
@@ -50,4 +50,4 @@ class CartesiaTTSService(TTSService):
async for chunk in chunk_generator:
yield AudioRawFrame(chunk['audio'], 16000, 1)
except Exception as e:
logger.error(f"Cartesia error: {e}")
logger.error(f"Cartesia exception: {e}")

View File

@@ -30,7 +30,8 @@ class DeepgramTTSService(TTSService):
self._aiohttp_session = aiohttp_session
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.info(f"Running Deepgram TTS for {text}")
logger.debug(f"Generating TTS: [{text}]")
base_url = "https://api.deepgram.com/v1/speak"
request_url = f"{base_url}?model={self._voice}&encoding=linear16&container=none&sample_rate=16000"
headers = {"authorization": f"token {self._api_key}"}
@@ -48,4 +49,4 @@ class DeepgramTTSService(TTSService):
frame = AudioRawFrame(audio=data, sample_rate=16000, num_channels=1)
yield frame
except Exception as e:
logger.error(f"Exception {e}")
logger.error(f"Deepgram exception: {e}")