diff --git a/CHANGELOG.md b/CHANGELOG.md index de1a9c1f4..c971a1c23 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Updated AssemblyAI STT service to support their latest streaming + speech-to-text model with improved transcription latency and endpointing. + - You can now access STT service results through the new `TranscriptionFrame.result` and `InterimTranscriptionFrame.result` field. This is useful in case you use some specific settings for the STT and you want to diff --git a/pyproject.toml b/pyproject.toml index a0638a5b0..8f9ff9f9d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ Website = "https://pipecat.ai" [project.optional-dependencies] anthropic = [ "anthropic~=0.49.0" ] -assemblyai = [ "assemblyai~=0.37.0" ] +assemblyai = [ "websockets~=13.1" ] aws = [ "boto3~=1.37.16", "websockets~=13.1" ] aws-nova-sonic = [ "aws_sdk_bedrock_runtime~=0.0.2" ] azure = [ "azure-cognitiveservices-speech~=1.42.0"] diff --git a/src/pipecat/services/assemblyai/models.py b/src/pipecat/services/assemblyai/models.py new file mode 100644 index 000000000..58b69fdf5 --- /dev/null +++ b/src/pipecat/services/assemblyai/models.py @@ -0,0 +1,61 @@ +from typing import List, Literal, Optional + +from pydantic import BaseModel, Field + + +class Word(BaseModel): + """Represents a single word in a transcription with timing and confidence.""" + + start: int + end: int + text: str + confidence: float + word_is_final: bool = Field(..., alias="word_is_final") + + +class BaseMessage(BaseModel): + """Base class for all AssemblyAI WebSocket messages.""" + + type: str + + +class BeginMessage(BaseMessage): + """Message sent when a new session begins.""" + + type: Literal["Begin"] = "Begin" + id: str + expires_at: int + + +class TurnMessage(BaseMessage): + """Message containing transcription data for a turn of speech.""" + + type: Literal["Turn"] = "Turn" + turn_order: int + turn_is_formatted: bool + end_of_turn: bool + transcript: str + end_of_turn_confidence: float + words: List[Word] + + +class TerminationMessage(BaseMessage): + """Message sent when the session is terminated.""" + + type: Literal["Termination"] = "Termination" + audio_duration_seconds: float + session_duration_seconds: float + + +# Union type for all possible message types +AnyMessage = BeginMessage | TurnMessage | TerminationMessage + + +class AssemblyAIConnectionParams(BaseModel): + sample_rate: int = 16000 + encoding: Literal["pcm_s16le", "pcm_mulaw"] = "pcm_s16le" + formatted_finals: bool = True + word_finalization_max_wait_time: Optional[int] = None + end_of_turn_confidence_threshold: Optional[float] = None + min_end_of_turn_silence_when_confident: Optional[int] = None + max_turn_silence: Optional[int] = None diff --git a/src/pipecat/services/assemblyai/stt.py b/src/pipecat/services/assemblyai/stt.py index 50e16756e..c7e1d9e48 100644 --- a/src/pipecat/services/assemblyai/stt.py +++ b/src/pipecat/services/assemblyai/stt.py @@ -5,30 +5,42 @@ # import asyncio -from typing import AsyncGenerator, Optional +import json +from typing import Any, AsyncGenerator, Dict +from urllib.parse import urlencode from loguru import logger +from pipecat import __version__ as pipecat_version from pipecat.frames.frames import ( CancelFrame, EndFrame, - ErrorFrame, Frame, InterimTranscriptionFrame, StartFrame, TranscriptionFrame, + UserStartedSpeakingFrame, + UserStoppedSpeakingFrame, ) +from pipecat.processors.frame_processor import FrameDirection from pipecat.services.stt_service import STTService from pipecat.transcriptions.language import Language from pipecat.utils.time import time_now_iso8601 from pipecat.utils.tracing.service_decorators import traced_stt +from .models import ( + AssemblyAIConnectionParams, + BaseMessage, + BeginMessage, + TerminationMessage, + TurnMessage, +) + try: - import assemblyai as aai - from assemblyai import AudioEncoding + import websockets except ModuleNotFoundError as e: logger.error(f"Exception: {e}") - logger.error("In order to use AssemblyAI, you need to `pip install pipecat-ai[assemblyai]`.") + logger.error('In order to use AssemblyAI, you need to `pip install "pipecat-ai[assemblyai]"`.') raise Exception(f"Missing module: {e}") @@ -37,31 +49,37 @@ class AssemblyAISTTService(STTService): self, *, api_key: str, - sample_rate: Optional[int] = None, - encoding: Optional[AudioEncoding] = None, - language=Language.EN, # Only English is supported for Realtime + language: Language = Language.EN, # AssemblyAI only supports English + api_endpoint_base_url: str = "wss://streaming.assemblyai.com/v3/ws", + connection_params: AssemblyAIConnectionParams = AssemblyAIConnectionParams(), + vad_force_turn_endpoint: bool = True, **kwargs, ): - super().__init__(sample_rate=sample_rate, **kwargs) + self._api_key = api_key + self._language = language + self._api_endpoint_base_url = api_endpoint_base_url + self._connection_params = connection_params + self._vad_force_turn_endpoint = vad_force_turn_endpoint - encoding = encoding or AudioEncoding("pcm_s16le") - aai.settings.api_key = api_key - self._transcriber: Optional[aai.RealtimeTranscriber] = None + super().__init__(sample_rate=self._connection_params.sample_rate, **kwargs) - self._settings = { - "encoding": encoding, - "language": language, - } + self._websocket = None + self._termination_event = asyncio.Event() + self._received_termination = False + self._connected = False + + self._receive_task = None + + self._audio_buffer = bytearray() + self._chunk_size_ms = 50 + self._chunk_size_bytes = 0 def can_generate_metrics(self) -> bool: return True - async def set_language(self, language: Language): - logger.info(f"Switching STT language to: [{language}]") - self._settings["language"] = language - async def start(self, frame: StartFrame): await super().start(frame) + self._chunk_size_bytes = int(self._chunk_size_ms * self._sample_rate * 2 / 1000) await self._connect() async def stop(self, frame: EndFrame): @@ -73,109 +91,182 @@ class AssemblyAISTTService(STTService): await self._disconnect() async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: - """Process an audio chunk for STT transcription. + self._audio_buffer.extend(audio) - This method streams the audio data to AssemblyAI for real-time transcription. - Transcription results are handled asynchronously via callback functions. + while len(self._audio_buffer) >= self._chunk_size_bytes: + chunk = bytes(self._audio_buffer[: self._chunk_size_bytes]) + self._audio_buffer = self._audio_buffer[self._chunk_size_bytes :] + await self._websocket.send(chunk) - :param audio: Audio data as bytes - :yield: None (transcription frames are pushed via self.push_frame in callbacks) - """ - if self._transcriber: + yield Frame() + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + if isinstance(frame, UserStartedSpeakingFrame): await self.start_ttfb_metrics() + elif isinstance(frame, UserStoppedSpeakingFrame): + if self._vad_force_turn_endpoint: + await self._websocket.send(json.dumps({"type": "ForceEndpoint"})) await self.start_processing_metrics() - self._transcriber.stream(audio) - yield None @traced_stt - async def _handle_transcription( - self, transcript: str, is_final: bool, language: Optional[Language] = None - ): - """Handle a transcription result with tracing.""" - await self.stop_ttfb_metrics() - await self.stop_processing_metrics() + async def _trace_transcription(self, transcript: str, is_final: bool, language: Language): + """Record transcription event for tracing.""" + pass + + def _build_ws_url(self) -> str: + """Build WebSocket URL with query parameters using urllib.parse.urlencode.""" + params = { + k: str(v).lower() if isinstance(v, bool) else v + for k, v in self._connection_params.model_dump().items() + if v is not None + } + if params: + query_string = urlencode(params) + return f"{self._api_endpoint_base_url}?{query_string}" + return self._api_endpoint_base_url async def _connect(self): - """Establish a connection to the AssemblyAI real-time transcription service. - - This method sets up the necessary callback functions and initializes the - AssemblyAI transcriber. - """ - if self._transcriber: - return - - def on_open(session_opened: aai.RealtimeSessionOpened): - """Callback for when the connection to AssemblyAI is opened.""" - logger.info(f"{self}: Connected to AssemblyAI") - - def on_data(transcript: aai.RealtimeTranscript): - """Callback for handling incoming transcription data. - - This function runs in a separate thread from the main asyncio event loop. - It creates appropriate transcription frames and schedules them to be - pushed to the next stage of the pipeline in the main event loop. - """ - if not transcript.text: - return - - timestamp = time_now_iso8601() - is_final = isinstance(transcript, aai.RealtimeFinalTranscript) - language = self._settings["language"] - - if is_final: - frame = TranscriptionFrame( - transcript.text, - "", - timestamp, - language, - result=transcript, - ) - else: - frame = InterimTranscriptionFrame( - transcript.text, - "", - timestamp, - language, - result=transcript, - ) - - asyncio.run_coroutine_threadsafe( - self._handle_transcription(transcript.text, is_final, language), - self.get_event_loop(), + try: + ws_url = self._build_ws_url() + headers = { + "Authorization": self._api_key, + "User-Agent": f"AssemblyAI/1.0 (integration=Pipecat/{pipecat_version})", + } + self._websocket = await websockets.connect( + ws_url, + extra_headers=headers, ) - - # Schedule the coroutine to run in the main event loop - # This is necessary because this callback runs in a different thread - asyncio.run_coroutine_threadsafe(self.push_frame(frame), self.get_event_loop()) - - def on_error(error: aai.RealtimeError): - """Callback for handling errors from AssemblyAI. - - Like on_data, this runs in a separate thread and schedules error - handling in the main event loop. - """ - logger.error(f"{self}: An error occurred: {error}") - # Schedule the coroutine to run in the main event loop - asyncio.run_coroutine_threadsafe( - self.push_frame(ErrorFrame(str(error))), self.get_event_loop() - ) - - def on_close(): - """Callback for when the connection to AssemblyAI is closed.""" - logger.info(f"{self}: Disconnected from AssemblyAI") - - self._transcriber = aai.RealtimeTranscriber( - sample_rate=self.sample_rate, - encoding=self._settings["encoding"], - on_data=on_data, - on_error=on_error, - on_open=on_open, - on_close=on_close, - ) - self._transcriber.connect() + self._connected = True + self._receive_task = self.create_task(self._receive_task_handler()) + except Exception as e: + logger.error(f"Failed to connect to AssemblyAI: {e}") + self._connected = False + raise async def _disconnect(self): - """Disconnect from the AssemblyAI service and clean up resources.""" - if self._transcriber: - self._transcriber.close() - self._transcriber = None + """Disconnect from AssemblyAI WebSocket and wait for termination message.""" + if not self._connected or not self._websocket: + return + + try: + self._termination_event.clear() + self._received_termination = False + + if len(self._audio_buffer) > 0: + await self._websocket.send(bytes(self._audio_buffer)) + self._audio_buffer.clear() + + try: + await self._websocket.send(json.dumps({"type": "Terminate"})) + + try: + await asyncio.wait_for( + self._termination_event.wait(), + timeout=5.0, + ) + except asyncio.TimeoutError: + logger.warning("Timed out waiting for termination message from server") + + except Exception as e: + logger.warning(f"Error during termination handshake: {e}") + + if self._receive_task: + await self.cancel_task(self._receive_task) + + await self._websocket.close() + + except Exception as e: + logger.error(f"Error during disconnect: {e}") + + finally: + self._websocket = None + self._connected = False + self._receive_task = None + + async def _receive_task_handler(self): + """Handle incoming WebSocket messages.""" + try: + while self._connected: + try: + message = await self._websocket.recv() + data = json.loads(message) + await self._handle_message(data) + except websockets.exceptions.ConnectionClosedOK: + break + except Exception as e: + logger.error(f"Error processing WebSocket message: {e}") + break + + except Exception as e: + logger.error(f"Fatal error in receive handler: {e}") + + def _parse_message(self, message: Dict[str, Any]) -> BaseMessage: + """Parse a raw message into the appropriate message type.""" + msg_type = message.get("type") + + if msg_type == "Begin": + return BeginMessage.model_validate(message) + elif msg_type == "Turn": + return TurnMessage.model_validate(message) + elif msg_type == "Termination": + return TerminationMessage.model_validate(message) + else: + raise ValueError(f"Unknown message type: {msg_type}") + + async def _handle_message(self, message: Dict[str, Any]): + """Handle AssemblyAI WebSocket messages.""" + try: + parsed_message = self._parse_message(message) + + if isinstance(parsed_message, BeginMessage): + logger.debug( + f"Session Begin: {parsed_message.id} (expires at {parsed_message.expires_at})" + ) + elif isinstance(parsed_message, TurnMessage): + await self._handle_transcription(parsed_message) + elif isinstance(parsed_message, TerminationMessage): + await self._handle_termination(parsed_message) + except Exception as e: + logger.error(f"Error handling message: {e}") + + async def _handle_termination(self, message: TerminationMessage): + """Handle termination message.""" + self._received_termination = True + self._termination_event.set() + + logger.info( + f"Session Terminated: Audio Duration={message.audio_duration_seconds}s, " + f"Session Duration={message.session_duration_seconds}s" + ) + await self.push_frame(EndFrame()) + + async def _handle_transcription(self, message: TurnMessage): + """Handle transcription results.""" + if not message.transcript: + return + await self.stop_ttfb_metrics() + if message.end_of_turn and ( + not self._connection_params.formatted_finals or message.turn_is_formatted + ): + await self.push_frame( + TranscriptionFrame( + message.transcript, + "", # participant + time_now_iso8601(), + self._language, + message, + ) + ) + await self._trace_transcription(message.transcript, True, self._language) + await self.stop_processing_metrics() + else: + await self.push_frame( + InterimTranscriptionFrame( + message.transcript, + "", # participant + time_now_iso8601(), + self._language, + message, + ) + )