Merge pull request #1876 from m-ods/m-ods/assemblyai-universal-streaming
Update AssemblyAI Streaming STT
This commit is contained in:
@@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
### Added
|
||||
|
||||
- Updated AssemblyAI STT service to support their latest streaming
|
||||
speech-to-text model with improved transcription latency and endpointing.
|
||||
|
||||
- You can now access STT service results through the new
|
||||
`TranscriptionFrame.result` and `InterimTranscriptionFrame.result` field. This
|
||||
is useful in case you use some specific settings for the STT and you want to
|
||||
|
||||
@@ -41,7 +41,7 @@ Website = "https://pipecat.ai"
|
||||
|
||||
[project.optional-dependencies]
|
||||
anthropic = [ "anthropic~=0.49.0" ]
|
||||
assemblyai = [ "assemblyai~=0.37.0" ]
|
||||
assemblyai = [ "websockets~=13.1" ]
|
||||
aws = [ "boto3~=1.37.16", "websockets~=13.1" ]
|
||||
aws-nova-sonic = [ "aws_sdk_bedrock_runtime~=0.0.2" ]
|
||||
azure = [ "azure-cognitiveservices-speech~=1.42.0"]
|
||||
|
||||
61
src/pipecat/services/assemblyai/models.py
Normal file
61
src/pipecat/services/assemblyai/models.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Word(BaseModel):
|
||||
"""Represents a single word in a transcription with timing and confidence."""
|
||||
|
||||
start: int
|
||||
end: int
|
||||
text: str
|
||||
confidence: float
|
||||
word_is_final: bool = Field(..., alias="word_is_final")
|
||||
|
||||
|
||||
class BaseMessage(BaseModel):
|
||||
"""Base class for all AssemblyAI WebSocket messages."""
|
||||
|
||||
type: str
|
||||
|
||||
|
||||
class BeginMessage(BaseMessage):
|
||||
"""Message sent when a new session begins."""
|
||||
|
||||
type: Literal["Begin"] = "Begin"
|
||||
id: str
|
||||
expires_at: int
|
||||
|
||||
|
||||
class TurnMessage(BaseMessage):
|
||||
"""Message containing transcription data for a turn of speech."""
|
||||
|
||||
type: Literal["Turn"] = "Turn"
|
||||
turn_order: int
|
||||
turn_is_formatted: bool
|
||||
end_of_turn: bool
|
||||
transcript: str
|
||||
end_of_turn_confidence: float
|
||||
words: List[Word]
|
||||
|
||||
|
||||
class TerminationMessage(BaseMessage):
|
||||
"""Message sent when the session is terminated."""
|
||||
|
||||
type: Literal["Termination"] = "Termination"
|
||||
audio_duration_seconds: float
|
||||
session_duration_seconds: float
|
||||
|
||||
|
||||
# Union type for all possible message types
|
||||
AnyMessage = BeginMessage | TurnMessage | TerminationMessage
|
||||
|
||||
|
||||
class AssemblyAIConnectionParams(BaseModel):
|
||||
sample_rate: int = 16000
|
||||
encoding: Literal["pcm_s16le", "pcm_mulaw"] = "pcm_s16le"
|
||||
formatted_finals: bool = True
|
||||
word_finalization_max_wait_time: Optional[int] = None
|
||||
end_of_turn_confidence_threshold: Optional[float] = None
|
||||
min_end_of_turn_silence_when_confident: Optional[int] = None
|
||||
max_turn_silence: Optional[int] = None
|
||||
@@ -5,30 +5,42 @@
|
||||
#
|
||||
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Optional
|
||||
import json
|
||||
from typing import Any, AsyncGenerator, Dict
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from pipecat import __version__ as pipecat_version
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
InterimTranscriptionFrame,
|
||||
StartFrame,
|
||||
TranscriptionFrame,
|
||||
UserStartedSpeakingFrame,
|
||||
UserStoppedSpeakingFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.services.stt_service import STTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
from pipecat.utils.tracing.service_decorators import traced_stt
|
||||
|
||||
from .models import (
|
||||
AssemblyAIConnectionParams,
|
||||
BaseMessage,
|
||||
BeginMessage,
|
||||
TerminationMessage,
|
||||
TurnMessage,
|
||||
)
|
||||
|
||||
try:
|
||||
import assemblyai as aai
|
||||
from assemblyai import AudioEncoding
|
||||
import websockets
|
||||
except ModuleNotFoundError as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
logger.error("In order to use AssemblyAI, you need to `pip install pipecat-ai[assemblyai]`.")
|
||||
logger.error('In order to use AssemblyAI, you need to `pip install "pipecat-ai[assemblyai]"`.')
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
@@ -37,31 +49,37 @@ class AssemblyAISTTService(STTService):
|
||||
self,
|
||||
*,
|
||||
api_key: str,
|
||||
sample_rate: Optional[int] = None,
|
||||
encoding: Optional[AudioEncoding] = None,
|
||||
language=Language.EN, # Only English is supported for Realtime
|
||||
language: Language = Language.EN, # AssemblyAI only supports English
|
||||
api_endpoint_base_url: str = "wss://streaming.assemblyai.com/v3/ws",
|
||||
connection_params: AssemblyAIConnectionParams = AssemblyAIConnectionParams(),
|
||||
vad_force_turn_endpoint: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(sample_rate=sample_rate, **kwargs)
|
||||
self._api_key = api_key
|
||||
self._language = language
|
||||
self._api_endpoint_base_url = api_endpoint_base_url
|
||||
self._connection_params = connection_params
|
||||
self._vad_force_turn_endpoint = vad_force_turn_endpoint
|
||||
|
||||
encoding = encoding or AudioEncoding("pcm_s16le")
|
||||
aai.settings.api_key = api_key
|
||||
self._transcriber: Optional[aai.RealtimeTranscriber] = None
|
||||
super().__init__(sample_rate=self._connection_params.sample_rate, **kwargs)
|
||||
|
||||
self._settings = {
|
||||
"encoding": encoding,
|
||||
"language": language,
|
||||
}
|
||||
self._websocket = None
|
||||
self._termination_event = asyncio.Event()
|
||||
self._received_termination = False
|
||||
self._connected = False
|
||||
|
||||
self._receive_task = None
|
||||
|
||||
self._audio_buffer = bytearray()
|
||||
self._chunk_size_ms = 50
|
||||
self._chunk_size_bytes = 0
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
async def set_language(self, language: Language):
|
||||
logger.info(f"Switching STT language to: [{language}]")
|
||||
self._settings["language"] = language
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
await super().start(frame)
|
||||
self._chunk_size_bytes = int(self._chunk_size_ms * self._sample_rate * 2 / 1000)
|
||||
await self._connect()
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
@@ -73,109 +91,182 @@ class AssemblyAISTTService(STTService):
|
||||
await self._disconnect()
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Process an audio chunk for STT transcription.
|
||||
self._audio_buffer.extend(audio)
|
||||
|
||||
This method streams the audio data to AssemblyAI for real-time transcription.
|
||||
Transcription results are handled asynchronously via callback functions.
|
||||
while len(self._audio_buffer) >= self._chunk_size_bytes:
|
||||
chunk = bytes(self._audio_buffer[: self._chunk_size_bytes])
|
||||
self._audio_buffer = self._audio_buffer[self._chunk_size_bytes :]
|
||||
await self._websocket.send(chunk)
|
||||
|
||||
:param audio: Audio data as bytes
|
||||
:yield: None (transcription frames are pushed via self.push_frame in callbacks)
|
||||
"""
|
||||
if self._transcriber:
|
||||
yield Frame()
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
if isinstance(frame, UserStartedSpeakingFrame):
|
||||
await self.start_ttfb_metrics()
|
||||
elif isinstance(frame, UserStoppedSpeakingFrame):
|
||||
if self._vad_force_turn_endpoint:
|
||||
await self._websocket.send(json.dumps({"type": "ForceEndpoint"}))
|
||||
await self.start_processing_metrics()
|
||||
self._transcriber.stream(audio)
|
||||
yield None
|
||||
|
||||
@traced_stt
|
||||
async def _handle_transcription(
|
||||
self, transcript: str, is_final: bool, language: Optional[Language] = None
|
||||
):
|
||||
"""Handle a transcription result with tracing."""
|
||||
await self.stop_ttfb_metrics()
|
||||
await self.stop_processing_metrics()
|
||||
async def _trace_transcription(self, transcript: str, is_final: bool, language: Language):
|
||||
"""Record transcription event for tracing."""
|
||||
pass
|
||||
|
||||
def _build_ws_url(self) -> str:
|
||||
"""Build WebSocket URL with query parameters using urllib.parse.urlencode."""
|
||||
params = {
|
||||
k: str(v).lower() if isinstance(v, bool) else v
|
||||
for k, v in self._connection_params.model_dump().items()
|
||||
if v is not None
|
||||
}
|
||||
if params:
|
||||
query_string = urlencode(params)
|
||||
return f"{self._api_endpoint_base_url}?{query_string}"
|
||||
return self._api_endpoint_base_url
|
||||
|
||||
async def _connect(self):
|
||||
"""Establish a connection to the AssemblyAI real-time transcription service.
|
||||
|
||||
This method sets up the necessary callback functions and initializes the
|
||||
AssemblyAI transcriber.
|
||||
"""
|
||||
if self._transcriber:
|
||||
return
|
||||
|
||||
def on_open(session_opened: aai.RealtimeSessionOpened):
|
||||
"""Callback for when the connection to AssemblyAI is opened."""
|
||||
logger.info(f"{self}: Connected to AssemblyAI")
|
||||
|
||||
def on_data(transcript: aai.RealtimeTranscript):
|
||||
"""Callback for handling incoming transcription data.
|
||||
|
||||
This function runs in a separate thread from the main asyncio event loop.
|
||||
It creates appropriate transcription frames and schedules them to be
|
||||
pushed to the next stage of the pipeline in the main event loop.
|
||||
"""
|
||||
if not transcript.text:
|
||||
return
|
||||
|
||||
timestamp = time_now_iso8601()
|
||||
is_final = isinstance(transcript, aai.RealtimeFinalTranscript)
|
||||
language = self._settings["language"]
|
||||
|
||||
if is_final:
|
||||
frame = TranscriptionFrame(
|
||||
transcript.text,
|
||||
"",
|
||||
timestamp,
|
||||
language,
|
||||
result=transcript,
|
||||
)
|
||||
else:
|
||||
frame = InterimTranscriptionFrame(
|
||||
transcript.text,
|
||||
"",
|
||||
timestamp,
|
||||
language,
|
||||
result=transcript,
|
||||
)
|
||||
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._handle_transcription(transcript.text, is_final, language),
|
||||
self.get_event_loop(),
|
||||
try:
|
||||
ws_url = self._build_ws_url()
|
||||
headers = {
|
||||
"Authorization": self._api_key,
|
||||
"User-Agent": f"AssemblyAI/1.0 (integration=Pipecat/{pipecat_version})",
|
||||
}
|
||||
self._websocket = await websockets.connect(
|
||||
ws_url,
|
||||
extra_headers=headers,
|
||||
)
|
||||
|
||||
# Schedule the coroutine to run in the main event loop
|
||||
# This is necessary because this callback runs in a different thread
|
||||
asyncio.run_coroutine_threadsafe(self.push_frame(frame), self.get_event_loop())
|
||||
|
||||
def on_error(error: aai.RealtimeError):
|
||||
"""Callback for handling errors from AssemblyAI.
|
||||
|
||||
Like on_data, this runs in a separate thread and schedules error
|
||||
handling in the main event loop.
|
||||
"""
|
||||
logger.error(f"{self}: An error occurred: {error}")
|
||||
# Schedule the coroutine to run in the main event loop
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.push_frame(ErrorFrame(str(error))), self.get_event_loop()
|
||||
)
|
||||
|
||||
def on_close():
|
||||
"""Callback for when the connection to AssemblyAI is closed."""
|
||||
logger.info(f"{self}: Disconnected from AssemblyAI")
|
||||
|
||||
self._transcriber = aai.RealtimeTranscriber(
|
||||
sample_rate=self.sample_rate,
|
||||
encoding=self._settings["encoding"],
|
||||
on_data=on_data,
|
||||
on_error=on_error,
|
||||
on_open=on_open,
|
||||
on_close=on_close,
|
||||
)
|
||||
self._transcriber.connect()
|
||||
self._connected = True
|
||||
self._receive_task = self.create_task(self._receive_task_handler())
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to AssemblyAI: {e}")
|
||||
self._connected = False
|
||||
raise
|
||||
|
||||
async def _disconnect(self):
|
||||
"""Disconnect from the AssemblyAI service and clean up resources."""
|
||||
if self._transcriber:
|
||||
self._transcriber.close()
|
||||
self._transcriber = None
|
||||
"""Disconnect from AssemblyAI WebSocket and wait for termination message."""
|
||||
if not self._connected or not self._websocket:
|
||||
return
|
||||
|
||||
try:
|
||||
self._termination_event.clear()
|
||||
self._received_termination = False
|
||||
|
||||
if len(self._audio_buffer) > 0:
|
||||
await self._websocket.send(bytes(self._audio_buffer))
|
||||
self._audio_buffer.clear()
|
||||
|
||||
try:
|
||||
await self._websocket.send(json.dumps({"type": "Terminate"}))
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._termination_event.wait(),
|
||||
timeout=5.0,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Timed out waiting for termination message from server")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error during termination handshake: {e}")
|
||||
|
||||
if self._receive_task:
|
||||
await self.cancel_task(self._receive_task)
|
||||
|
||||
await self._websocket.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during disconnect: {e}")
|
||||
|
||||
finally:
|
||||
self._websocket = None
|
||||
self._connected = False
|
||||
self._receive_task = None
|
||||
|
||||
async def _receive_task_handler(self):
|
||||
"""Handle incoming WebSocket messages."""
|
||||
try:
|
||||
while self._connected:
|
||||
try:
|
||||
message = await self._websocket.recv()
|
||||
data = json.loads(message)
|
||||
await self._handle_message(data)
|
||||
except websockets.exceptions.ConnectionClosedOK:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing WebSocket message: {e}")
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Fatal error in receive handler: {e}")
|
||||
|
||||
def _parse_message(self, message: Dict[str, Any]) -> BaseMessage:
|
||||
"""Parse a raw message into the appropriate message type."""
|
||||
msg_type = message.get("type")
|
||||
|
||||
if msg_type == "Begin":
|
||||
return BeginMessage.model_validate(message)
|
||||
elif msg_type == "Turn":
|
||||
return TurnMessage.model_validate(message)
|
||||
elif msg_type == "Termination":
|
||||
return TerminationMessage.model_validate(message)
|
||||
else:
|
||||
raise ValueError(f"Unknown message type: {msg_type}")
|
||||
|
||||
async def _handle_message(self, message: Dict[str, Any]):
|
||||
"""Handle AssemblyAI WebSocket messages."""
|
||||
try:
|
||||
parsed_message = self._parse_message(message)
|
||||
|
||||
if isinstance(parsed_message, BeginMessage):
|
||||
logger.debug(
|
||||
f"Session Begin: {parsed_message.id} (expires at {parsed_message.expires_at})"
|
||||
)
|
||||
elif isinstance(parsed_message, TurnMessage):
|
||||
await self._handle_transcription(parsed_message)
|
||||
elif isinstance(parsed_message, TerminationMessage):
|
||||
await self._handle_termination(parsed_message)
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling message: {e}")
|
||||
|
||||
async def _handle_termination(self, message: TerminationMessage):
|
||||
"""Handle termination message."""
|
||||
self._received_termination = True
|
||||
self._termination_event.set()
|
||||
|
||||
logger.info(
|
||||
f"Session Terminated: Audio Duration={message.audio_duration_seconds}s, "
|
||||
f"Session Duration={message.session_duration_seconds}s"
|
||||
)
|
||||
await self.push_frame(EndFrame())
|
||||
|
||||
async def _handle_transcription(self, message: TurnMessage):
|
||||
"""Handle transcription results."""
|
||||
if not message.transcript:
|
||||
return
|
||||
await self.stop_ttfb_metrics()
|
||||
if message.end_of_turn and (
|
||||
not self._connection_params.formatted_finals or message.turn_is_formatted
|
||||
):
|
||||
await self.push_frame(
|
||||
TranscriptionFrame(
|
||||
message.transcript,
|
||||
"", # participant
|
||||
time_now_iso8601(),
|
||||
self._language,
|
||||
message,
|
||||
)
|
||||
)
|
||||
await self._trace_transcription(message.transcript, True, self._language)
|
||||
await self.stop_processing_metrics()
|
||||
else:
|
||||
await self.push_frame(
|
||||
InterimTranscriptionFrame(
|
||||
message.transcript,
|
||||
"", # participant
|
||||
time_now_iso8601(),
|
||||
self._language,
|
||||
message,
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user