Apply additional fixups to context messages to meet Mistral-specific requirements

This commit is contained in:
Paul Kompfner
2025-09-08 11:23:25 -04:00
parent 1cccb97ccf
commit daee1ddf3b
2 changed files with 34 additions and 8 deletions

View File

@@ -12,6 +12,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `OpenAIRealtimeLLMService` and `AzureRealtimeLLMService` which provide
access to OpenAI Realtime.
### Fixed
- Added additional fixups to Mistral context messages to ensure they meet
Mistral-specific requirements, avoiding Mistral "invalid request" errors.
### Deprecated
- `NoisereduceFilter` is now deprecated and will be removed in a future

View File

@@ -57,16 +57,18 @@ class MistralLLMService(OpenAILLMService):
logger.debug(f"Creating Mistral client with api {base_url}")
return super().create_client(api_key, base_url, **kwargs)
def _apply_mistral_assistant_prefix(
def _apply_mistral_fixups(
self, messages: List[ChatCompletionMessageParam]
) -> List[ChatCompletionMessageParam]:
"""Apply Mistral's assistant message prefix requirement.
"""Apply fixups to messages to meet Mistral-specific requirements.
Mistral requires assistant messages to have prefix=True when they
are the final message in a conversation. According to Mistral's API:
- Assistant messages with prefix=True MUST be the last message
- Only add prefix=True to the final assistant message when needed
- This allows assistant messages to be accepted as the last message
1. A "tool"-role message must be followed by an assistant message.
2. "system"-role messages must only appear at the start of a
conversation.
3. Assistant messages must have prefix=True when they are the final
message in a conversation (but at no other point).
Args:
messages: The original list of messages.
@@ -80,6 +82,23 @@ class MistralLLMService(OpenAILLMService):
# Create a copy to avoid modifying the original
fixed_messages = [dict(msg) for msg in messages]
# Ensure all tool responses are followed by an assistant message
assistant_insert_indices = []
for i, msg in enumerate(fixed_messages[:-1]):
if msg.get("role") == "tool" and not fixed_messages[i + 1].get("role") == "assistant":
assistant_insert_indices.append(i + 1)
for idx in reversed(assistant_insert_indices):
fixed_messages.insert(idx, {"role": "assistant", "content": " "})
# Convert any "system" messages that aren't at the start (i.e., after the initial contiguous block) to "user"
first_non_system_idx = next(
(i for i, msg in enumerate(fixed_messages) if msg.get("role") != "system"),
len(fixed_messages),
)
for i, msg in enumerate(fixed_messages):
if msg.get("role") == "system" and i >= first_non_system_idx:
msg["role"] = "user"
# Get the last message
last_message = fixed_messages[-1]
@@ -88,6 +107,8 @@ class MistralLLMService(OpenAILLMService):
if last_message.get("role") == "assistant" and "prefix" not in last_message:
last_message["prefix"] = True
print(f"Fixed messages for Mistral: {fixed_messages}")
return fixed_messages
async def run_function_calls(self, function_calls: Sequence[FunctionCallFromLLM]):
@@ -158,7 +179,7 @@ class MistralLLMService(OpenAILLMService):
- Core completion settings
"""
# Apply Mistral-specific message fixups for API compatibility
fixed_messages = self._apply_mistral_assistant_prefix(params_from_context["messages"])
fixed_messages = self._apply_mistral_fixups(params_from_context["messages"])
params = {
"model": self.model_name,