services(anthropic): fix interruptions with anthropic
This commit is contained in:
@@ -4,8 +4,6 @@
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import time
|
||||
import base64
|
||||
|
||||
@@ -80,8 +78,20 @@ class AnthropicLLMService(LLMService):
|
||||
}]
|
||||
})
|
||||
else:
|
||||
# text frame
|
||||
anthropic_messages.append({"role": role, "content": content})
|
||||
# Text frame. Anthropic needs the roles to alternate. This will
|
||||
# cause an issue with interruptions. So, if we detect we are the
|
||||
# ones asking again it probably means we were interrupted.
|
||||
if role == "user" and len(anthropic_messages) > 1:
|
||||
last_message = anthropic_messages[-1]
|
||||
if last_message["role"] == "user":
|
||||
anthropic_messages = anthropic_messages[:-1]
|
||||
content = last_message["content"]
|
||||
anthropic_messages.append(
|
||||
{"role": "user", "content": f"Sorry, I just asked you about [{content}] but now I would like to know [{text}]."})
|
||||
else:
|
||||
anthropic_messages.append({"role": role, "content": text})
|
||||
else:
|
||||
anthropic_messages.append({"role": role, "content": text})
|
||||
|
||||
return anthropic_messages
|
||||
|
||||
@@ -107,7 +117,7 @@ class AnthropicLLMService(LLMService):
|
||||
await self.push_frame(LLMResponseEndFrame())
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error(f"Anthrophic exception: {e}")
|
||||
finally:
|
||||
await self.push_frame(LLMFullResponseEndFrame())
|
||||
|
||||
@@ -125,22 +135,3 @@ class AnthropicLLMService(LLMService):
|
||||
|
||||
if context:
|
||||
await self._process_context(context)
|
||||
|
||||
async def x_process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
if isinstance(frame, LLMMessagesFrame):
|
||||
stream = await self.client.messages.create(
|
||||
max_tokens=self.max_tokens,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello, Claude",
|
||||
}
|
||||
],
|
||||
model=self.model,
|
||||
stream=True,
|
||||
)
|
||||
async for event in stream:
|
||||
if event.type == "content_block_delta":
|
||||
await self.push_frame(TextFrame(event.delta.text))
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
@@ -36,7 +36,7 @@ class CartesiaTTSService(TTSService):
|
||||
logger.error(f"Cartesia initialization error: {e}")
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.debug(f"Transcribing text: [{text}]")
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
try:
|
||||
chunk_generator = await self._client.generate(
|
||||
@@ -50,4 +50,4 @@ class CartesiaTTSService(TTSService):
|
||||
async for chunk in chunk_generator:
|
||||
yield AudioRawFrame(chunk['audio'], 16000, 1)
|
||||
except Exception as e:
|
||||
logger.error(f"Cartesia error: {e}")
|
||||
logger.error(f"Cartesia exception: {e}")
|
||||
|
||||
@@ -30,7 +30,8 @@ class DeepgramTTSService(TTSService):
|
||||
self._aiohttp_session = aiohttp_session
|
||||
|
||||
async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
|
||||
logger.info(f"Running Deepgram TTS for {text}")
|
||||
logger.debug(f"Generating TTS: [{text}]")
|
||||
|
||||
base_url = "https://api.deepgram.com/v1/speak"
|
||||
request_url = f"{base_url}?model={self._voice}&encoding=linear16&container=none&sample_rate=16000"
|
||||
headers = {"authorization": f"token {self._api_key}"}
|
||||
@@ -48,4 +49,4 @@ class DeepgramTTSService(TTSService):
|
||||
frame = AudioRawFrame(audio=data, sample_rate=16000, num_channels=1)
|
||||
yield frame
|
||||
except Exception as e:
|
||||
logger.error(f"Exception {e}")
|
||||
logger.error(f"Deepgram exception: {e}")
|
||||
|
||||
Reference in New Issue
Block a user