Add Mistral Voxtral Realtime STT service

This commit is contained in:
Mark Backman
2026-04-07 15:26:56 -04:00
parent a7bf9f538c
commit 68a3070ad4
4 changed files with 411 additions and 2 deletions

View File

@@ -0,0 +1,93 @@
#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import os
from dotenv import load_dotenv
from loguru import logger
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import Frame, TranscriptionFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.audio.vad_processor import VADProcessor
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.mistral.stt import MistralSTTService
from pipecat.transports.base_transport import BaseTransport, TransportParams
from pipecat.transports.daily.transport import DailyParams
from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams
load_dotenv(override=True)
class TranscriptionLogger(FrameProcessor):
    """Pass-through processor that prints the text of transcription frames."""

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Print ``TranscriptionFrame`` text, then forward the frame unchanged.

        Args:
            frame: The frame travelling through the pipeline.
            direction: The direction the frame is travelling in.
        """
        await super().process_frame(frame, direction)

        is_transcription = isinstance(frame, TranscriptionFrame)
        if is_transcription:
            print(f"Transcription: {frame.text}")

        # Every frame, transcription or not, continues in its original direction.
        await self.push_frame(frame, direction)
# Zero-argument factories keyed by transport name. Each one builds that
# transport's parameter object with only audio input switched on.
transport_params = dict(
    daily=lambda: DailyParams(audio_in_enabled=True),
    twilio=lambda: FastAPIWebsocketParams(audio_in_enabled=True),
    webrtc=lambda: TransportParams(audio_in_enabled=True),
)
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
    """Build and run a transcription-only pipeline on the given transport.

    The pipeline reads audio from the transport, runs it through a Silero
    based VAD processor and the Mistral STT service, and prints the
    resulting transcriptions.

    Args:
        transport: Transport whose input feeds the pipeline.
        runner_args: Runner arguments; supplies the pipeline idle timeout and
            the SIGINT handling flag.
    """
    logger.info("Starting bot")

    mistral_stt = MistralSTTService(api_key=os.getenv("MISTRAL_API_KEY"))
    transcript_printer = TranscriptionLogger()
    vad = VADProcessor(vad_analyzer=SileroVADAnalyzer())

    # Audio in -> voice activity detection -> speech-to-text -> printing.
    stages = [transport.input(), vad, mistral_stt, transcript_printer]
    metrics_params = PipelineParams(enable_metrics=True, enable_usage_metrics=True)
    task = PipelineTask(
        Pipeline(stages),
        params=metrics_params,
        idle_timeout_secs=runner_args.pipeline_idle_timeout_secs,
    )

    # Registered before the runner starts so a disconnect cancels the task.
    @transport.event_handler("on_client_disconnected")
    async def on_client_disconnected(transport, client):
        logger.info("Client disconnected")
        await task.cancel()

    await PipelineRunner(handle_sigint=runner_args.handle_sigint).run(task)
async def bot(runner_args: RunnerArguments):
    """Main bot entry point compatible with Pipecat Cloud.

    Args:
        runner_args: Runner arguments used both to create the transport and
            to run the bot.
    """
    # Build the transport from the per-transport factories, then hand off.
    selected_transport = await create_transport(runner_args, transport_params)
    await run_bot(selected_transport, runner_args)
if __name__ == "__main__":
    # Imported here so the runner is only loaded when executed as a script.
    from pipecat.runner.run import main

    main()

View File

@@ -22,7 +22,7 @@ from pipecat.processors.aggregators.llm_response_universal import (
)
from pipecat.runner.types import RunnerArguments
from pipecat.runner.utils import create_transport
from pipecat.services.deepgram.stt import DeepgramSTTService
from pipecat.services.mistral.stt import MistralSTTService
from pipecat.services.mistral.tts import MistralTTSService
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.transports.base_transport import BaseTransport, TransportParams
@@ -53,7 +53,7 @@ transport_params = {
async def run_bot(transport: BaseTransport, runner_args: RunnerArguments):
logger.info(f"Starting bot")
stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY"))
stt = MistralSTTService(api_key=os.getenv("MISTRAL_API_KEY"))
tts = MistralTTSService(
api_key=os.getenv("MISTRAL_API_KEY"),

View File

@@ -0,0 +1,315 @@
#
# Copyright (c) 2024-2026, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
"""Mistral Speech-to-Text service implementation.
This module provides a real-time STT service that integrates with Mistral's
Voxtral Realtime transcription API using the Mistral SDK's RealtimeConnection.
"""
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Optional
from loguru import logger
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
Frame,
InterimTranscriptionFrame,
StartFrame,
TranscriptionFrame,
VADUserStartedSpeakingFrame,
VADUserStoppedSpeakingFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.settings import STTSettings
from pipecat.services.stt_latency import MISTRAL_TTFS_P99
from pipecat.services.stt_service import STTService
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_stt
try:
from mistralai.client import Mistral
from mistralai.client.models import (
AudioFormat,
RealtimeTranscriptionError,
RealtimeTranscriptionSessionCreated,
TranscriptionStreamDone,
TranscriptionStreamLanguage,
TranscriptionStreamTextDelta,
)
from mistralai.extra.realtime import RealtimeConnection, UnknownRealtimeEvent
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use Mistral STT, you need to `pip install pipecat-ai[mistral]`.")
raise Exception(f"Missing module: {e}")
@dataclass
class MistralSTTSettings(STTSettings):
    """Runtime settings for ``MistralSTTService``.

    Declares no fields of its own; it exists so the service has a dedicated
    settings type.

    Parameters:
        model: STT model identifier.
        language: Language hint for transcription.
    """
class MistralSTTService(STTService):
    """Mistral Speech-to-Text service using the Voxtral Realtime API.

    This service uses the Mistral SDK's RealtimeConnection to stream audio
    and receive transcription events over WebSocket. It extends STTService
    directly (rather than WebsocketSTTService) because the SDK manages
    the WebSocket connection internally.

    Event handlers available:

    - on_connected: Called when a transcription session is created.
    - on_disconnected: Called when the connection is closed.
    - on_connection_error: Called when a transcription error occurs.

    Example::

        @stt.event_handler("on_connected")
        async def on_connected(stt):
            logger.info("Mistral STT connected")
    """

    Settings = MistralSTTSettings
    _settings: Settings

    def __init__(
        self,
        *,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        sample_rate: Optional[int] = None,
        target_streaming_delay_ms: Optional[int] = None,
        ttfs_p99_latency: Optional[float] = MISTRAL_TTFS_P99,
        settings: Optional[Settings] = None,
        **kwargs,
    ):
        """Initialize Mistral STT service.

        Args:
            api_key: Mistral API key for authentication.
            base_url: Custom API endpoint URL.
            sample_rate: Audio sample rate in Hz. If None, uses the pipeline
                sample rate.
            target_streaming_delay_ms: Streaming delay for accuracy/latency
                tradeoff. Higher values may improve accuracy at the cost of
                latency.
            ttfs_p99_latency: P99 latency from speech end to final transcript
                in seconds. Override for your deployment.
            settings: Runtime-updatable settings.
            **kwargs: Additional keyword arguments passed to STTService.
        """
        # Start from the service defaults and overlay any caller overrides.
        default_settings = self.Settings(
            model="voxtral-mini-transcribe-realtime-2602",
            language=None,
        )
        if settings is not None:
            default_settings.apply_update(settings)

        super().__init__(
            sample_rate=sample_rate,
            ttfs_p99_latency=ttfs_p99_latency,
            settings=default_settings,
            **kwargs,
        )

        self._client = Mistral(api_key=api_key, server_url=base_url)
        self._target_streaming_delay_ms = target_streaming_delay_ms
        self._connection: Optional[RealtimeConnection] = None
        self._receive_task = None
        # Text deltas accumulated for the current utterance.
        self._accumulated_text = ""
        self._detected_language: Optional[str] = None

    def can_generate_metrics(self) -> bool:
        """Check if the service can generate processing metrics.

        Returns:
            True, indicating metrics are supported.
        """
        return True

    async def start(self, frame: StartFrame):
        """Start the STT service and establish connection.

        Args:
            frame: Frame indicating service should start.
        """
        await super().start(frame)
        await self._connect()

    async def stop(self, frame: EndFrame):
        """Stop the STT service and close connection.

        Args:
            frame: Frame indicating service should stop.
        """
        await super().stop(frame)
        await self._disconnect()

    async def cancel(self, frame: CancelFrame):
        """Cancel the STT service and close connection.

        Args:
            frame: Frame indicating service should be cancelled.
        """
        await super().cancel(frame)
        await self._disconnect()

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Process incoming frames and handle speech events.

        A VAD "started speaking" frame resets the accumulated interim text and
        starts processing metrics; a VAD "stopped speaking" frame flushes the
        audio buffered on the open connection.

        Args:
            frame: The frame to process.
            direction: Direction of frame flow in the pipeline.
        """
        await super().process_frame(frame, direction)

        if isinstance(frame, VADUserStartedSpeakingFrame):
            self._accumulated_text = ""
            await self._start_metrics()
        elif isinstance(frame, VADUserStoppedSpeakingFrame):
            connection = self._connection
            if connection and not connection.is_closed:
                await connection.flush_audio()

    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
        """Send audio data to Mistral for transcription.

        Reconnects first when there is no open connection. If the reconnect
        fails, the audio chunk is dropped; ``_connect`` has already reported
        the error.

        Args:
            audio: Raw audio bytes to transcribe.

        Yields:
            None - transcription results arrive via the receive events task.
        """
        connection = self._connection
        if connection is None or connection.is_closed:
            await self._connect()
            connection = self._connection

        # _connect() reports failures instead of raising, so the connection
        # may still be missing here; sending on None would raise.
        if connection is not None and not connection.is_closed:
            await connection.send_audio(audio)
        yield None

    async def _start_metrics(self):
        """Start performance metrics collection for transcription processing."""
        await self.start_processing_metrics()

    async def _connect(self):
        """Establish a connection to the Mistral Realtime API.

        On success, stores the connection and starts the background task that
        consumes its events. Failures are reported with ``push_error`` rather
        than raised.
        """
        try:
            logger.debug(f"{self}: Connecting to Mistral STT")

            audio_format = AudioFormat(
                encoding="pcm_s16le",
                sample_rate=self.sample_rate,
            )

            self._connection = await self._client.audio.realtime.connect(
                model=self._settings.model,
                audio_format=audio_format,
                target_streaming_delay_ms=self._target_streaming_delay_ms,
            )

            self._receive_task = self.create_task(
                self._receive_events(), name="mistral_stt_receive"
            )
        except Exception as e:
            await self.push_error(error_msg=f"Error connecting to Mistral STT: {e}", exception=e)

    async def _disconnect(self):
        """Close the connection and cancel the receive task.

        Calls the ``on_disconnected`` event handler after closing an open
        connection. Does nothing when there is no open connection.
        """
        # Take ownership of the connection before cancelling the receive task:
        # that task's `finally` clears self._connection, which would otherwise
        # make this method skip the close and leak the connection.
        connection = self._connection
        self._connection = None

        if self._receive_task:
            await self.cancel_task(self._receive_task)
            self._receive_task = None

        if connection is not None and not connection.is_closed:
            try:
                logger.debug(f"{self}: Disconnecting from Mistral STT")
                await connection.close()
            except Exception as e:
                # Best effort: a failed close must not abort shutdown.
                logger.warning(f"{self}: Error closing connection: {e}")
            finally:
                await self._call_event_handler("on_disconnected")

    async def _receive_events(self):
        """Background task: iterate connection events and handle them."""
        # Bind to the connection this task was started for, so a later
        # reconnect cannot change what is being read or cleaned up.
        connection = self._connection
        if connection is None:
            return

        try:
            async for event in connection.events():
                if isinstance(event, RealtimeTranscriptionSessionCreated):
                    logger.debug(f"{self}: Session created: {event.session}")
                    await self._call_event_handler("on_connected")

                elif isinstance(event, TranscriptionStreamTextDelta):
                    # Deltas are incremental; interim frames carry the full
                    # text accumulated so far.
                    self._accumulated_text += event.text
                    await self.push_frame(
                        InterimTranscriptionFrame(
                            self._accumulated_text,
                            self._user_id,
                            time_now_iso8601(),
                        )
                    )

                elif isinstance(event, TranscriptionStreamDone):
                    if event.text:
                        await self.push_frame(
                            TranscriptionFrame(
                                event.text,
                                self._user_id,
                                time_now_iso8601(),
                                language=self._detected_language,
                            )
                        )
                        await self._handle_transcription(
                            event.text, True, self._detected_language
                        )
                    await self.stop_processing_metrics()
                    self._accumulated_text = ""

                elif isinstance(event, TranscriptionStreamLanguage):
                    self._detected_language = event.audio_language

                elif isinstance(event, RealtimeTranscriptionError):
                    error_msg = event.error.message if event.error else "Unknown error"
                    await self.push_error(error_msg=f"Mistral STT error: {error_msg}")
                    await self._call_event_handler("on_connection_error", error_msg)

                elif isinstance(event, UnknownRealtimeEvent):
                    logger.warning(f"{self}: Unknown realtime event: {event}")

        except Exception as e:
            await self.push_error(error_msg=f"Mistral STT receive error: {e}", exception=e)
            await self._call_event_handler("on_connection_error", str(e))
        finally:
            # Only clear the reference if it is still ours; a newer connection
            # may already have replaced it.
            if self._connection is connection:
                self._connection = None

    @traced_stt
    async def _handle_transcription(
        self, transcript: str, is_final: bool, language: Optional[str] = None
    ):
        """Handle a transcription result with tracing.

        The body is intentionally empty; the method exists so the
        ``traced_stt`` decorator is invoked with the transcription details.
        """
        pass

    async def _update_settings(self, delta: STTSettings) -> dict[str, Any]:
        """Apply a settings delta, reconnecting if anything changed.

        NOTE(review): ``_connect`` only forwards ``model`` to the API; the
        ``language`` setting does not appear to be sent. Confirm intended.

        Args:
            delta: An STT settings delta.

        Returns:
            Dict mapping changed field names to their previous values.
        """
        changed = await super()._update_settings(delta)

        if changed:
            await self._disconnect()
            await self._connect()

        return changed

View File

@@ -55,4 +55,5 @@ NVIDIA_TTFS_P99: float = DEFAULT_TTFS_P99
WHISPER_TTFS_P99: float = DEFAULT_TTFS_P99
# No benchmark available yet; using conservative default
MISTRAL_TTFS_P99: float = DEFAULT_TTFS_P99
SMALLEST_TTFS_P99: float = DEFAULT_TTFS_P99