Apply additional fixups to context messages to meet Mistral-specific requirements
This commit is contained in:
@@ -12,6 +12,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Added `OpenAIRealtimeLLMService` and `AzureRealtimeLLMService` which provide
|
||||
access to OpenAI Realtime.
|
||||
|
||||
### Fixed
|
||||
|
||||
- Add additional fixups to Mistral context messages to ensure they meet
|
||||
Mistral-specific requirements, avoiding Mistral "invalid request" errors.
|
||||
|
||||
### Deprecated
|
||||
|
||||
- `NoisereduceFilter` is now deprecated and will be removed in a future
|
||||
|
||||
@@ -57,16 +57,18 @@ class MistralLLMService(OpenAILLMService):
|
||||
logger.debug(f"Creating Mistral client with api {base_url}")
|
||||
return super().create_client(api_key, base_url, **kwargs)
|
||||
|
||||
def _apply_mistral_assistant_prefix(
|
||||
def _apply_mistral_fixups(
|
||||
self, messages: List[ChatCompletionMessageParam]
|
||||
) -> List[ChatCompletionMessageParam]:
|
||||
"""Apply Mistral's assistant message prefix requirement.
|
||||
"""Apply fixups to messages to meet Mistral-specific requirements.
|
||||
|
||||
Mistral requires assistant messages to have prefix=True when they
|
||||
are the final message in a conversation. According to Mistral's API:
|
||||
- Assistant messages with prefix=True MUST be the last message
|
||||
- Only add prefix=True to the final assistant message when needed
|
||||
- This allows assistant messages to be accepted as the last message
|
||||
1. A "tool"-role message must be followed by an assistant message.
|
||||
|
||||
2. "system"-role messages must only appear at the start of a
|
||||
conversation.
|
||||
|
||||
3. Assistant messages must have prefix=True when they are the final
|
||||
message in a conversation (but at no other point).
|
||||
|
||||
Args:
|
||||
messages: The original list of messages.
|
||||
@@ -80,6 +82,23 @@ class MistralLLMService(OpenAILLMService):
|
||||
# Create a copy to avoid modifying the original
|
||||
fixed_messages = [dict(msg) for msg in messages]
|
||||
|
||||
# Ensure all tool responses are followed by an assistant message
|
||||
assistant_insert_indices = []
|
||||
for i, msg in enumerate(fixed_messages[:-1]):
|
||||
if msg.get("role") == "tool" and not fixed_messages[i + 1].get("role") == "assistant":
|
||||
assistant_insert_indices.append(i + 1)
|
||||
for idx in reversed(assistant_insert_indices):
|
||||
fixed_messages.insert(idx, {"role": "assistant", "content": " "})
|
||||
|
||||
# Convert any "system" messages that aren't at the start (i.e., after the initial contiguous block) to "user"
|
||||
first_non_system_idx = next(
|
||||
(i for i, msg in enumerate(fixed_messages) if msg.get("role") != "system"),
|
||||
len(fixed_messages),
|
||||
)
|
||||
for i, msg in enumerate(fixed_messages):
|
||||
if msg.get("role") == "system" and i >= first_non_system_idx:
|
||||
msg["role"] = "user"
|
||||
|
||||
# Get the last message
|
||||
last_message = fixed_messages[-1]
|
||||
|
||||
@@ -88,6 +107,8 @@ class MistralLLMService(OpenAILLMService):
|
||||
if last_message.get("role") == "assistant" and "prefix" not in last_message:
|
||||
last_message["prefix"] = True
|
||||
|
||||
print(f"Fixed messages for Mistral: {fixed_messages}")
|
||||
|
||||
return fixed_messages
|
||||
|
||||
async def run_function_calls(self, function_calls: Sequence[FunctionCallFromLLM]):
|
||||
@@ -158,7 +179,7 @@ class MistralLLMService(OpenAILLMService):
|
||||
- Core completion settings
|
||||
"""
|
||||
# Apply Mistral's assistant prefix requirement for API compatibility
|
||||
fixed_messages = self._apply_mistral_assistant_prefix(params_from_context["messages"])
|
||||
fixed_messages = self._apply_mistral_fixups(params_from_context["messages"])
|
||||
|
||||
params = {
|
||||
"model": self.model_name,
|
||||
|
||||
Reference in New Issue
Block a user