diff --git a/CHANGELOG.md b/CHANGELOG.md index 032fd9efe..ac1bdf586 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `OpenAIRealtimeLLMService` and `AzureRealtimeLLMService` which provide access to OpenAI Realtime. +### Fixed + +- Added further fixups to Mistral context messages to ensure they meet + Mistral-specific requirements, avoiding Mistral "invalid request" errors. + ### Deprecated - `NoisereduceFilter` is now deprecated and will be removed in a future diff --git a/src/pipecat/services/mistral/llm.py b/src/pipecat/services/mistral/llm.py index ce61a1a11..38ca3f049 100644 --- a/src/pipecat/services/mistral/llm.py +++ b/src/pipecat/services/mistral/llm.py @@ -57,16 +57,18 @@ class MistralLLMService(OpenAILLMService): logger.debug(f"Creating Mistral client with api {base_url}") return super().create_client(api_key, base_url, **kwargs) - def _apply_mistral_assistant_prefix( + def _apply_mistral_fixups( self, messages: List[ChatCompletionMessageParam] ) -> List[ChatCompletionMessageParam]: - """Apply Mistral's assistant message prefix requirement. + """Apply fixups to messages to meet Mistral-specific requirements. - Mistral requires assistant messages to have prefix=True when they - are the final message in a conversation. According to Mistral's API: - - Assistant messages with prefix=True MUST be the last message - - Only add prefix=True to the final assistant message when needed - - This allows assistant messages to be accepted as the last message + 1. A "tool"-role message must be followed by an assistant message. + + 2. "system"-role messages must only appear at the start of a + conversation. + + 3. Assistant messages must have prefix=True when they are the final + message in a conversation (but at no other point). Args: messages: The original list of messages. 
@@ -80,6 +82,23 @@ class MistralLLMService(OpenAILLMService): # Create a copy to avoid modifying the original fixed_messages = [dict(msg) for msg in messages] + # Ensure all tool responses are followed by an assistant message + assistant_insert_indices = [] + for i, msg in enumerate(fixed_messages[:-1]): + if msg.get("role") == "tool" and fixed_messages[i + 1].get("role") != "assistant": + assistant_insert_indices.append(i + 1) + for idx in reversed(assistant_insert_indices): + fixed_messages.insert(idx, {"role": "assistant", "content": " "}) + + # Convert any "system" messages that aren't at the start (i.e., after the initial contiguous block) to "user" + first_non_system_idx = next( + (i for i, msg in enumerate(fixed_messages) if msg.get("role") != "system"), + len(fixed_messages), + ) + for i, msg in enumerate(fixed_messages): + if msg.get("role") == "system" and i >= first_non_system_idx: + msg["role"] = "user" + # Get the last message last_message = fixed_messages[-1] @@ -88,6 +107,8 @@ class MistralLLMService(OpenAILLMService): if last_message.get("role") == "assistant" and "prefix" not in last_message: last_message["prefix"] = True + logger.debug(f"Fixed messages for Mistral: {fixed_messages}") + return fixed_messages async def run_function_calls(self, function_calls: Sequence[FunctionCallFromLLM]): @@ -158,7 +179,7 @@ class MistralLLMService(OpenAILLMService): - Core completion settings """ # Apply Mistral's assistant prefix requirement for API compatibility - fixed_messages = self._apply_mistral_assistant_prefix(params_from_context["messages"]) + fixed_messages = self._apply_mistral_fixups(params_from_context["messages"]) params = { "model": self.model_name,