Merge pull request #1876 from m-ods/m-ods/assemblyai-universal-streaming

Update AssemblyAI Streaming STT
This commit is contained in:
Mark Backman
2025-05-30 08:55:43 -04:00
committed by GitHub
4 changed files with 271 additions and 116 deletions

View File

@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Updated AssemblyAI STT service to support their latest streaming
speech-to-text model with improved transcription latency and endpointing.
- You can now access STT service results through the new
`TranscriptionFrame.result` and `InterimTranscriptionFrame.result` field. This
is useful in case you use some specific settings for the STT and you want to

View File

@@ -41,7 +41,7 @@ Website = "https://pipecat.ai"
[project.optional-dependencies]
anthropic = [ "anthropic~=0.49.0" ]
assemblyai = [ "assemblyai~=0.37.0" ]
assemblyai = [ "websockets~=13.1" ]
aws = [ "boto3~=1.37.16", "websockets~=13.1" ]
aws-nova-sonic = [ "aws_sdk_bedrock_runtime~=0.0.2" ]
azure = [ "azure-cognitiveservices-speech~=1.42.0"]

View File

@@ -0,0 +1,61 @@
from typing import List, Literal, Optional
from pydantic import BaseModel, Field
class Word(BaseModel):
    """Represents a single word in a transcription with timing and confidence.

    Attributes:
        start: Start offset of the word (presumably milliseconds into the
            audio stream -- confirm against the AssemblyAI API reference).
        end: End offset of the word, in the same unit as ``start``.
        text: The transcribed word.
        confidence: Confidence value reported for the word.
        word_is_final: Flag reported by the service for this word
            (presumably "this word will no longer change" -- confirm).
    """

    start: int
    end: int
    text: str
    confidence: float
    # Declared as a plain required field: the previous
    # ``Field(..., alias="word_is_final")`` used an alias identical to the
    # field name, so it had no effect on validation or serialization.
    word_is_final: bool
class BaseMessage(BaseModel):
    """Common parent of all AssemblyAI WebSocket message models."""

    # Message discriminator; the concrete message classes in this module
    # narrow it to a single literal value each.
    type: str
class BeginMessage(BaseMessage):
    """Server message announcing that a new streaming session has begun."""

    type: Literal["Begin"] = "Begin"

    # Session identifier.
    id: str

    # Session expiry (presumably a Unix timestamp -- confirm against the
    # AssemblyAI API reference).
    expires_at: int
class TurnMessage(BaseMessage):
    """Server message carrying the transcription data for one turn of speech."""

    type: Literal["Turn"] = "Turn"

    # Position of this turn within the session (presumably a running
    # counter -- confirm against the AssemblyAI API reference).
    turn_order: int

    # Whether ``transcript`` is the formatted version of the turn.
    turn_is_formatted: bool

    # Whether the service considers the turn finished.
    end_of_turn: bool

    # Text transcribed for the turn so far.
    transcript: str

    # Confidence value the service reports for its end-of-turn decision.
    end_of_turn_confidence: float

    # Per-word details for the turn.
    words: List[Word]
class TerminationMessage(BaseMessage):
    """Server message reporting that the streaming session has terminated."""

    type: Literal["Termination"] = "Termination"

    # Amount of audio handled during the session, in seconds.
    audio_duration_seconds: float

    # Total length of the session, in seconds.
    session_duration_seconds: float
# Union type covering every concrete message model defined in this module
# (session begin, turn transcription, session termination).
AnyMessage = BeginMessage | TurnMessage | TerminationMessage
class AssemblyAIConnectionParams(BaseModel):
    """Connection settings for the AssemblyAI streaming WebSocket.

    The optional tuning fields default to ``None``. The exact meaning and
    units of the tuning fields are defined by the AssemblyAI streaming API
    (the silence/wait values are presumably milliseconds -- confirm there).
    """

    # Audio sample rate in Hz.
    sample_rate: int = 16000

    # Wire encoding of the audio sent to the service.
    encoding: Literal["pcm_s16le", "pcm_mulaw"] = "pcm_s16le"

    # Whether formatted final transcripts are requested.
    formatted_finals: bool = True

    # Optional endpointing / finalization tuning knobs.
    word_finalization_max_wait_time: Optional[int] = None
    end_of_turn_confidence_threshold: Optional[float] = None
    min_end_of_turn_silence_when_confident: Optional[int] = None
    max_turn_silence: Optional[int] = None

View File

@@ -5,30 +5,42 @@
#
import asyncio
from typing import AsyncGenerator, Optional
import json
from typing import Any, AsyncGenerator, Dict
from urllib.parse import urlencode
from loguru import logger
from pipecat import __version__ as pipecat_version
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
InterimTranscriptionFrame,
StartFrame,
TranscriptionFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.stt_service import STTService
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
from pipecat.utils.tracing.service_decorators import traced_stt
from .models import (
AssemblyAIConnectionParams,
BaseMessage,
BeginMessage,
TerminationMessage,
TurnMessage,
)
try:
import assemblyai as aai
from assemblyai import AudioEncoding
import websockets
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error("In order to use AssemblyAI, you need to `pip install pipecat-ai[assemblyai]`.")
logger.error('In order to use AssemblyAI, you need to `pip install "pipecat-ai[assemblyai]"`.')
raise Exception(f"Missing module: {e}")
@@ -37,31 +49,37 @@ class AssemblyAISTTService(STTService):
self,
*,
api_key: str,
sample_rate: Optional[int] = None,
encoding: Optional[AudioEncoding] = None,
language=Language.EN, # Only English is supported for Realtime
language: Language = Language.EN, # AssemblyAI only supports English
api_endpoint_base_url: str = "wss://streaming.assemblyai.com/v3/ws",
connection_params: AssemblyAIConnectionParams = AssemblyAIConnectionParams(),
vad_force_turn_endpoint: bool = True,
**kwargs,
):
super().__init__(sample_rate=sample_rate, **kwargs)
self._api_key = api_key
self._language = language
self._api_endpoint_base_url = api_endpoint_base_url
self._connection_params = connection_params
self._vad_force_turn_endpoint = vad_force_turn_endpoint
encoding = encoding or AudioEncoding("pcm_s16le")
aai.settings.api_key = api_key
self._transcriber: Optional[aai.RealtimeTranscriber] = None
super().__init__(sample_rate=self._connection_params.sample_rate, **kwargs)
self._settings = {
"encoding": encoding,
"language": language,
}
self._websocket = None
self._termination_event = asyncio.Event()
self._received_termination = False
self._connected = False
self._receive_task = None
self._audio_buffer = bytearray()
self._chunk_size_ms = 50
self._chunk_size_bytes = 0
def can_generate_metrics(self) -> bool:
    """Report whether this service can produce metrics.

    Returns:
        Always ``True``.
    """
    return True
async def set_language(self, language: Language):
    """Store a new STT language in the service settings.

    Args:
        language: Language to record under the ``"language"`` settings key.
    """
    # NOTE(review): transcription frames elsewhere in this file are tagged
    # with ``self._language``, which this method does not touch -- confirm
    # whether that attribute should be updated here as well.
    logger.info(f"Switching STT language to: [{language}]")
    self._settings["language"] = language
async def start(self, frame: StartFrame):
    """Start the service, size the outgoing audio chunks, and connect.

    Args:
        frame: Start frame, forwarded to the parent class's ``start``.
    """
    await super().start(frame)

    # The factor 2 is presumably the byte width of one 16-bit sample.
    # NOTE(review): that matches "pcm_s16le"; confirm the intended chunk
    # duration when the connection is configured for "pcm_mulaw".
    bytes_per_second = self._sample_rate * 2
    self._chunk_size_bytes = int(self._chunk_size_ms * bytes_per_second / 1000)

    await self._connect()
async def stop(self, frame: EndFrame):
@@ -73,109 +91,182 @@ class AssemblyAISTTService(STTService):
await self._disconnect()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Process an audio chunk for STT transcription.
self._audio_buffer.extend(audio)
This method streams the audio data to AssemblyAI for real-time transcription.
Transcription results are handled asynchronously via callback functions.
while len(self._audio_buffer) >= self._chunk_size_bytes:
chunk = bytes(self._audio_buffer[: self._chunk_size_bytes])
self._audio_buffer = self._audio_buffer[self._chunk_size_bytes :]
await self._websocket.send(chunk)
:param audio: Audio data as bytes
:yield: None (transcription frames are pushed via self.push_frame in callbacks)
"""
if self._transcriber:
yield Frame()
async def process_frame(self, frame: Frame, direction: FrameDirection):
    """Process a frame and react to user speaking start/stop events.

    The frame is first handed to the parent class. Then:

    * ``UserStartedSpeakingFrame`` starts the TTFB metrics.
    * ``UserStoppedSpeakingFrame`` optionally sends a ``ForceEndpoint``
      message to the service (when ``vad_force_turn_endpoint`` is enabled)
      and starts the processing metrics.

    Args:
        frame: The frame to process.
        direction: Direction in which the frame is travelling.
    """
    await super().process_frame(frame, direction)

    if isinstance(frame, UserStartedSpeakingFrame):
        await self.start_ttfb_metrics()
    elif isinstance(frame, UserStoppedSpeakingFrame):
        # ``_websocket`` is None until a connection is established and is
        # reset to None on disconnect, so only send when a live connection
        # exists (same guard that ``_disconnect`` uses). Without it, a VAD
        # event arriving outside a session raises AttributeError.
        can_send = self._connected and self._websocket is not None
        if self._vad_force_turn_endpoint and can_send:
            await self._websocket.send(json.dumps({"type": "ForceEndpoint"}))
        await self.start_processing_metrics()
self._transcriber.stream(audio)
yield None
@traced_stt
async def _handle_transcription(
self, transcript: str, is_final: bool, language: Optional[Language] = None
):
"""Handle a transcription result with tracing."""
await self.stop_ttfb_metrics()
await self.stop_processing_metrics()
async def _trace_transcription(self, transcript: str, is_final: bool, language: Language):
"""Record transcription event for tracing."""
pass
def _build_ws_url(self) -> str:
"""Build WebSocket URL with query parameters using urllib.parse.urlencode."""
params = {
k: str(v).lower() if isinstance(v, bool) else v
for k, v in self._connection_params.model_dump().items()
if v is not None
}
if params:
query_string = urlencode(params)
return f"{self._api_endpoint_base_url}?{query_string}"
return self._api_endpoint_base_url
async def _connect(self):
"""Establish a connection to the AssemblyAI real-time transcription service.
This method sets up the necessary callback functions and initializes the
AssemblyAI transcriber.
"""
if self._transcriber:
return
def on_open(session_opened: aai.RealtimeSessionOpened):
"""Callback for when the connection to AssemblyAI is opened."""
logger.info(f"{self}: Connected to AssemblyAI")
def on_data(transcript: aai.RealtimeTranscript):
"""Callback for handling incoming transcription data.
This function runs in a separate thread from the main asyncio event loop.
It creates appropriate transcription frames and schedules them to be
pushed to the next stage of the pipeline in the main event loop.
"""
if not transcript.text:
return
timestamp = time_now_iso8601()
is_final = isinstance(transcript, aai.RealtimeFinalTranscript)
language = self._settings["language"]
if is_final:
frame = TranscriptionFrame(
transcript.text,
"",
timestamp,
language,
result=transcript,
)
else:
frame = InterimTranscriptionFrame(
transcript.text,
"",
timestamp,
language,
result=transcript,
)
asyncio.run_coroutine_threadsafe(
self._handle_transcription(transcript.text, is_final, language),
self.get_event_loop(),
try:
ws_url = self._build_ws_url()
headers = {
"Authorization": self._api_key,
"User-Agent": f"AssemblyAI/1.0 (integration=Pipecat/{pipecat_version})",
}
self._websocket = await websockets.connect(
ws_url,
extra_headers=headers,
)
# Schedule the coroutine to run in the main event loop
# This is necessary because this callback runs in a different thread
asyncio.run_coroutine_threadsafe(self.push_frame(frame), self.get_event_loop())
def on_error(error: aai.RealtimeError):
"""Callback for handling errors from AssemblyAI.
Like on_data, this runs in a separate thread and schedules error
handling in the main event loop.
"""
logger.error(f"{self}: An error occurred: {error}")
# Schedule the coroutine to run in the main event loop
asyncio.run_coroutine_threadsafe(
self.push_frame(ErrorFrame(str(error))), self.get_event_loop()
)
def on_close():
"""Callback for when the connection to AssemblyAI is closed."""
logger.info(f"{self}: Disconnected from AssemblyAI")
self._transcriber = aai.RealtimeTranscriber(
sample_rate=self.sample_rate,
encoding=self._settings["encoding"],
on_data=on_data,
on_error=on_error,
on_open=on_open,
on_close=on_close,
)
self._transcriber.connect()
self._connected = True
self._receive_task = self.create_task(self._receive_task_handler())
except Exception as e:
logger.error(f"Failed to connect to AssemblyAI: {e}")
self._connected = False
raise
async def _disconnect(self):
"""Disconnect from the AssemblyAI service and clean up resources."""
if self._transcriber:
self._transcriber.close()
self._transcriber = None
"""Disconnect from AssemblyAI WebSocket and wait for termination message."""
if not self._connected or not self._websocket:
return
try:
self._termination_event.clear()
self._received_termination = False
if len(self._audio_buffer) > 0:
await self._websocket.send(bytes(self._audio_buffer))
self._audio_buffer.clear()
try:
await self._websocket.send(json.dumps({"type": "Terminate"}))
try:
await asyncio.wait_for(
self._termination_event.wait(),
timeout=5.0,
)
except asyncio.TimeoutError:
logger.warning("Timed out waiting for termination message from server")
except Exception as e:
logger.warning(f"Error during termination handshake: {e}")
if self._receive_task:
await self.cancel_task(self._receive_task)
await self._websocket.close()
except Exception as e:
logger.error(f"Error during disconnect: {e}")
finally:
self._websocket = None
self._connected = False
self._receive_task = None
async def _receive_task_handler(self):
"""Handle incoming WebSocket messages."""
try:
while self._connected:
try:
message = await self._websocket.recv()
data = json.loads(message)
await self._handle_message(data)
except websockets.exceptions.ConnectionClosedOK:
break
except Exception as e:
logger.error(f"Error processing WebSocket message: {e}")
break
except Exception as e:
logger.error(f"Fatal error in receive handler: {e}")
def _parse_message(self, message: Dict[str, Any]) -> BaseMessage:
"""Parse a raw message into the appropriate message type."""
msg_type = message.get("type")
if msg_type == "Begin":
return BeginMessage.model_validate(message)
elif msg_type == "Turn":
return TurnMessage.model_validate(message)
elif msg_type == "Termination":
return TerminationMessage.model_validate(message)
else:
raise ValueError(f"Unknown message type: {msg_type}")
async def _handle_message(self, message: Dict[str, Any]):
    """Parse one decoded WebSocket payload and route it by message kind.

    Begin messages are only logged, turn messages go to
    ``_handle_transcription`` and termination messages go to
    ``_handle_termination``. Any error raised while parsing or handling is
    logged and not propagated.

    Args:
        message: The decoded JSON payload received from the service.
    """
    try:
        parsed = self._parse_message(message)

        if isinstance(parsed, BeginMessage):
            logger.debug(f"Session Begin: {parsed.id} (expires at {parsed.expires_at})")
            return

        if isinstance(parsed, TurnMessage):
            await self._handle_transcription(parsed)
            return

        if isinstance(parsed, TerminationMessage):
            await self._handle_termination(parsed)
    except Exception as e:
        logger.error(f"Error handling message: {e}")
async def _handle_termination(self, message: TerminationMessage):
    """Record that the service ended the session and push an ``EndFrame``.

    Args:
        message: The termination message, whose audio and session durations
            are written to the log.
    """
    self._received_termination = True
    # Wakes up ``_disconnect``, which waits on this event after sending
    # the "Terminate" request.
    self._termination_event.set()

    summary = (
        f"Session Terminated: Audio Duration={message.audio_duration_seconds}s, "
        f"Session Duration={message.session_duration_seconds}s"
    )
    logger.info(summary)

    await self.push_frame(EndFrame())
async def _handle_transcription(self, message: TurnMessage):
    """Convert a turn message into a transcription frame and push it.

    Messages with an empty transcript are ignored. Otherwise the TTFB
    metrics are stopped and a frame is pushed:

    * a ``TranscriptionFrame`` when the turn has ended and either formatted
      finals are not requested or this message is the formatted one; the
      transcription is then traced and the processing metrics are stopped;
    * an ``InterimTranscriptionFrame`` in every other case.

    Args:
        message: The turn message received from the service; it is also
            attached to the pushed frame as its last argument.
    """
    if not message.transcript:
        return

    await self.stop_ttfb_metrics()

    is_final = message.end_of_turn and (
        not self._connection_params.formatted_finals or message.turn_is_formatted
    )
    frame_cls = TranscriptionFrame if is_final else InterimTranscriptionFrame

    await self.push_frame(
        frame_cls(
            message.transcript,
            "",  # participant
            time_now_iso8601(),
            self._language,
            message,
        )
    )

    if is_final:
        await self._trace_transcription(message.transcript, True, self._language)
        await self.stop_processing_metrics()