diff --git a/examples/transcription/transcription-mistral.py b/examples/transcription/transcription-mistral.py new file mode 100644 index 000000000..b040b457c --- /dev/null +++ b/examples/transcription/transcription-mistral.py @@ -0,0 +1,93 @@ +# +# Copyright (c) 2024-2026, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import os + +from dotenv import load_dotenv +from loguru import logger + +from pipecat.audio.vad.silero import SileroVADAnalyzer +from pipecat.frames.frames import Frame, TranscriptionFrame +from pipecat.pipeline.pipeline import Pipeline +from pipecat.pipeline.runner import PipelineRunner +from pipecat.pipeline.task import PipelineParams, PipelineTask +from pipecat.processors.audio.vad_processor import VADProcessor +from pipecat.processors.frame_processor import FrameDirection, FrameProcessor +from pipecat.runner.types import RunnerArguments +from pipecat.runner.utils import create_transport +from pipecat.services.mistral.stt import MistralSTTService +from pipecat.transports.base_transport import BaseTransport, TransportParams +from pipecat.transports.daily.transport import DailyParams +from pipecat.transports.websocket.fastapi import FastAPIWebsocketParams + +load_dotenv(override=True) + + +class TranscriptionLogger(FrameProcessor): + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + if isinstance(frame, TranscriptionFrame): + print(f"Transcription: {frame.text}") + + # Push all frames through + await self.push_frame(frame, direction) + + +transport_params = { + "daily": lambda: DailyParams( + audio_in_enabled=True, + ), + "twilio": lambda: FastAPIWebsocketParams( + audio_in_enabled=True, + ), + "webrtc": lambda: TransportParams( + audio_in_enabled=True, + ), +} + + +async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): + logger.info(f"Starting bot") + + stt = MistralSTTService( + api_key=os.getenv("MISTRAL_API_KEY"), + ) + + tl = TranscriptionLogger() + vad_processor = VADProcessor(vad_analyzer=SileroVADAnalyzer()) + + pipeline = Pipeline([transport.input(), vad_processor, stt, tl]) + + task = PipelineTask( + pipeline, + params=PipelineParams( + enable_metrics=True, + enable_usage_metrics=True, + ), + idle_timeout_secs=runner_args.pipeline_idle_timeout_secs, + ) + + @transport.event_handler("on_client_disconnected") + async def on_client_disconnected(transport, client): + logger.info(f"Client disconnected") + await task.cancel() + + runner = PipelineRunner(handle_sigint=runner_args.handle_sigint) + + await runner.run(task) + + +async def bot(runner_args: RunnerArguments): + """Main bot entry point compatible with Pipecat Cloud.""" + transport = await create_transport(runner_args, transport_params) + await run_bot(transport, runner_args) + + +if __name__ == "__main__": + from pipecat.runner.run import main + + main() diff --git a/examples/voice/voice-mistral.py b/examples/voice/voice-mistral.py index 9c8c8789b..440039b65 100644 --- a/examples/voice/voice-mistral.py +++ b/examples/voice/voice-mistral.py @@ -22,7 +22,7 @@ from pipecat.processors.aggregators.llm_response_universal import ( ) from pipecat.runner.types import RunnerArguments from pipecat.runner.utils import create_transport -from pipecat.services.deepgram.stt import DeepgramSTTService +from pipecat.services.mistral.stt import MistralSTTService from pipecat.services.mistral.tts import MistralTTSService from pipecat.services.openai.llm import OpenAILLMService from pipecat.transports.base_transport import BaseTransport, TransportParams @@ -53,7 +53,7 @@ transport_params = { async def run_bot(transport: BaseTransport, runner_args: RunnerArguments): logger.info(f"Starting bot") - stt = DeepgramSTTService(api_key=os.getenv("DEEPGRAM_API_KEY")) + stt = MistralSTTService(api_key=os.getenv("MISTRAL_API_KEY")) tts = MistralTTSService( api_key=os.getenv("MISTRAL_API_KEY"), diff --git a/src/pipecat/services/mistral/stt.py b/src/pipecat/services/mistral/stt.py new file mode 100644 index 000000000..c41768d15 --- /dev/null +++ b/src/pipecat/services/mistral/stt.py @@ -0,0 +1,315 @@ +# +# Copyright (c) 2024-2026, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +"""Mistral Speech-to-Text service implementation. + +This module provides a real-time STT service that integrates with Mistral's +Voxtral Realtime transcription API using the Mistral SDK's RealtimeConnection. +""" + +from dataclasses import dataclass +from typing import Any, AsyncGenerator, Optional + +from loguru import logger + +from pipecat.frames.frames import ( + CancelFrame, + EndFrame, + Frame, + InterimTranscriptionFrame, + StartFrame, + TranscriptionFrame, + VADUserStartedSpeakingFrame, + VADUserStoppedSpeakingFrame, +) +from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.settings import STTSettings +from pipecat.services.stt_latency import MISTRAL_TTFS_P99 +from pipecat.services.stt_service import STTService +from pipecat.utils.time import time_now_iso8601 +from pipecat.utils.tracing.service_decorators import traced_stt + +try: + from mistralai.client import Mistral + from mistralai.client.models import ( + AudioFormat, + RealtimeTranscriptionError, + RealtimeTranscriptionSessionCreated, + TranscriptionStreamDone, + TranscriptionStreamLanguage, + TranscriptionStreamTextDelta, + ) + from mistralai.extra.realtime import RealtimeConnection, UnknownRealtimeEvent +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error("In order to use Mistral STT, you need to `pip install pipecat-ai[mistral]`.") + raise Exception(f"Missing module: {e}") + + +@dataclass +class MistralSTTSettings(STTSettings): + """Settings for MistralSTTService. + + Parameters: + model: STT model identifier. + language: Language hint for transcription. + """ + + pass + + +class MistralSTTService(STTService): + """Mistral Speech-to-Text service using the Voxtral Realtime API. + + This service uses the Mistral SDK's RealtimeConnection to stream audio + and receive transcription events over WebSocket. It extends STTService + directly (rather than WebsocketSTTService) because the SDK manages + the WebSocket connection internally. + + Event handlers available: + + - on_connected: Called when a transcription session is created. + - on_disconnected: Called when the connection is closed. + - on_connection_error: Called when a transcription error occurs. + + Example:: + + @stt.event_handler("on_connected") + async def on_connected(stt): + logger.info("Mistral STT connected") + """ + + Settings = MistralSTTSettings + _settings: Settings + + def __init__( + self, + *, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + sample_rate: Optional[int] = None, + target_streaming_delay_ms: Optional[int] = None, + ttfs_p99_latency: Optional[float] = MISTRAL_TTFS_P99, + settings: Optional[Settings] = None, + **kwargs, + ): + """Initialize Mistral STT service. + + Args: + api_key: Mistral API key for authentication. + base_url: Custom API endpoint URL. + sample_rate: Audio sample rate in Hz. If None, uses the pipeline + sample rate. + target_streaming_delay_ms: Streaming delay for accuracy/latency + tradeoff. Higher values may improve accuracy at the cost of + latency. + ttfs_p99_latency: P99 latency from speech end to final transcript + in seconds. Override for your deployment. + settings: Runtime-updatable settings. + **kwargs: Additional keyword arguments passed to STTService. + """ + default_settings = self.Settings( + model="voxtral-mini-transcribe-realtime-2602", + language=None, + ) + + if settings is not None: + default_settings.apply_update(settings) + + super().__init__( + sample_rate=sample_rate, + ttfs_p99_latency=ttfs_p99_latency, + settings=default_settings, + **kwargs, + ) + + self._client = Mistral(api_key=api_key, server_url=base_url) + self._target_streaming_delay_ms = target_streaming_delay_ms + self._connection: Optional[RealtimeConnection] = None + self._receive_task = None + self._accumulated_text = "" + self._detected_language: Optional[str] = None + + def can_generate_metrics(self) -> bool: + """Check if the service can generate processing metrics. + + Returns: + True, indicating metrics are supported. + """ + return True + + async def start(self, frame: StartFrame): + """Start the STT service and establish connection. + + Args: + frame: Frame indicating service should start. + """ + await super().start(frame) + await self._connect() + + async def stop(self, frame: EndFrame): + """Stop the STT service and close connection. + + Args: + frame: Frame indicating service should stop. + """ + await super().stop(frame) + await self._disconnect() + + async def cancel(self, frame: CancelFrame): + """Cancel the STT service and close connection. + + Args: + frame: Frame indicating service should be cancelled. + """ + await super().cancel(frame) + await self._disconnect() + + async def process_frame(self, frame: Frame, direction: FrameDirection): + """Process incoming frames and handle speech events. + + Args: + frame: The frame to process. + direction: Direction of frame flow in the pipeline. + """ + await super().process_frame(frame, direction) + + if isinstance(frame, VADUserStartedSpeakingFrame): + self._accumulated_text = "" + await self._start_metrics() + elif isinstance(frame, VADUserStoppedSpeakingFrame): + if self._connection and not self._connection.is_closed: + await self._connection.flush_audio() + + async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: + """Send audio data to Mistral for transcription. + + Args: + audio: Raw audio bytes to transcribe. + + Yields: + None - transcription results arrive via the receive events task. + """ + if not self._connection or self._connection.is_closed: + await self._connect() + + await self._connection.send_audio(audio) + yield None + + async def _start_metrics(self): + """Start performance metrics collection for transcription processing.""" + await self.start_processing_metrics() + + async def _connect(self): + """Establish a connection to the Mistral Realtime API.""" + try: + logger.debug(f"{self}: Connecting to Mistral STT") + + audio_format = AudioFormat( + encoding="pcm_s16le", + sample_rate=self.sample_rate, + ) + + self._connection = await self._client.audio.realtime.connect( + model=self._settings.model, + audio_format=audio_format, + target_streaming_delay_ms=self._target_streaming_delay_ms, + ) + + self._receive_task = self.create_task( + self._receive_events(), name="mistral_stt_receive" + ) + except Exception as e: + await self.push_error(error_msg=f"Error connecting to Mistral STT: {e}", exception=e) + + async def _disconnect(self): + """Close the connection and cancel the receive task.""" + if self._receive_task: + await self.cancel_task(self._receive_task) + self._receive_task = None + + if self._connection and not self._connection.is_closed: + try: + logger.debug(f"{self}: Disconnecting from Mistral STT") + await self._connection.close() + except Exception as e: + logger.warning(f"{self}: Error closing connection: {e}") + finally: + self._connection = None + await self._call_event_handler("on_disconnected") + + async def _receive_events(self): + """Background task: iterate connection events and handle them.""" + try: + async for event in self._connection.events(): + if isinstance(event, RealtimeTranscriptionSessionCreated): + logger.debug(f"{self}: Session created: {event.session}") + await self._call_event_handler("on_connected") + + elif isinstance(event, TranscriptionStreamTextDelta): + self._accumulated_text += event.text + await self.push_frame( + InterimTranscriptionFrame( + self._accumulated_text, + self._user_id, + time_now_iso8601(), + ) + ) + + elif isinstance(event, TranscriptionStreamDone): + if event.text: + await self.push_frame( + TranscriptionFrame( + event.text, + self._user_id, + time_now_iso8601(), + language=self._detected_language, + ) + ) + await self._handle_transcription(event.text, True, self._detected_language) + await self.stop_processing_metrics() + self._accumulated_text = "" + + elif isinstance(event, TranscriptionStreamLanguage): + self._detected_language = event.audio_language + + elif isinstance(event, RealtimeTranscriptionError): + error_msg = event.error.message if event.error else "Unknown error" + await self.push_error(error_msg=f"Mistral STT error: {error_msg}") + await self._call_event_handler("on_connection_error", error_msg) + + elif isinstance(event, UnknownRealtimeEvent): + logger.warning(f"{self}: Unknown realtime event: {event}") + + except Exception as e: + await self.push_error(error_msg=f"Mistral STT receive error: {e}", exception=e) + await self._call_event_handler("on_connection_error", str(e)) + finally: + self._connection = None + + @traced_stt + async def _handle_transcription( + self, transcript: str, is_final: bool, language: Optional[str] = None + ): + """Handle a transcription result with tracing.""" + pass + + async def _update_settings(self, delta: STTSettings) -> dict[str, Any]: + """Apply a settings delta, reconnecting if model or language changes. + + Args: + delta: An STT settings delta. + + Returns: + Dict mapping changed field names to their previous values. + """ + changed = await super()._update_settings(delta) + + if changed: + await self._disconnect() + await self._connect() + + return changed diff --git a/src/pipecat/services/stt_latency.py b/src/pipecat/services/stt_latency.py index 403902379..5ffd798de 100644 --- a/src/pipecat/services/stt_latency.py +++ b/src/pipecat/services/stt_latency.py @@ -55,4 +55,5 @@ NVIDIA_TTFS_P99: float = DEFAULT_TTFS_P99 WHISPER_TTFS_P99: float = DEFAULT_TTFS_P99 # No benchmark available yet; using conservative default +MISTRAL_TTFS_P99: float = DEFAULT_TTFS_P99 SMALLEST_TTFS_P99: float = DEFAULT_TTFS_P99