From 1dbf4ff27dfd5cca1de4e967ebd62fb38270d247 Mon Sep 17 00:00:00 2001
From: Mark Backman
Date: Sat, 19 Oct 2024 22:44:49 -0400
Subject: [PATCH] Add AssemblyAI STT service

---
Review amendments, all in src/pipecat/services/assemblyai.py: add the licence
header used by the other new files, capture the event loop in _connect()
instead of the constructor, make _connect()/_disconnect() safe to repeat, and
add docstrings. The blob id on that file's index line predates these amendments.

 .../07o-interruptible-assemblyai.py           |  97 +++++++++++
 .../foundational/13c-gladia-transcription.py  |  63 +++++++
 .../13d-assemblyai-transcription.py           |  62 +++++++
 pyproject.toml                                |   1 +
 src/pipecat/services/assemblyai.py            | 204 ++++++++++++++++++
 5 files changed, 427 insertions(+)
 create mode 100644 examples/foundational/07o-interruptible-assemblyai.py
 create mode 100644 examples/foundational/13c-gladia-transcription.py
 create mode 100644 examples/foundational/13d-assemblyai-transcription.py
 create mode 100644 src/pipecat/services/assemblyai.py

diff --git a/examples/foundational/07o-interruptible-assemblyai.py b/examples/foundational/07o-interruptible-assemblyai.py
new file mode 100644
index 000000000..76ade04c2
--- /dev/null
+++ b/examples/foundational/07o-interruptible-assemblyai.py
@@ -0,0 +1,97 @@
+#
+# Copyright (c) 2024, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import asyncio
+import os
+import sys
+
+import aiohttp
+from dotenv import load_dotenv
+from loguru import logger
+from runner import configure
+
+from pipecat.audio.vad.silero import SileroVADAnalyzer
+from pipecat.frames.frames import LLMMessagesFrame
+from pipecat.pipeline.pipeline import Pipeline
+from pipecat.pipeline.runner import PipelineRunner
+from pipecat.pipeline.task import PipelineParams, PipelineTask
+from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
+from pipecat.services.assemblyai import AssemblyAISTTService
+from pipecat.services.cartesia import CartesiaTTSService
+from pipecat.services.openai import OpenAILLMService
+from pipecat.transports.services.daily import DailyParams, DailyTransport
+
+load_dotenv(override=True)
+
+logger.remove(0)
+logger.add(sys.stderr, level="DEBUG")
+
+
+async def main():
+    async with aiohttp.ClientSession() as session:
+        (room_url, token) = await configure(session)
+
+        transport = DailyTransport(
+            room_url,
+            token,
+            "Respond bot",
+            DailyParams(
+                audio_out_enabled=True,
+                vad_enabled=True,
+                vad_analyzer=SileroVADAnalyzer(),
+                vad_audio_passthrough=True,
+            ),
+        )
+
+        stt = AssemblyAISTTService(
+            api_key=os.getenv("ASSEMBLYAI_API_KEY"),
+        )
+
+        tts = CartesiaTTSService(
+            api_key=os.getenv("CARTESIA_API_KEY"),
+            voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22",  # British Lady
+        )
+
+        llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
+
+        messages = [
+            {
+                "role": "system",
+                "content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
+            },
+        ]
+
+        context = OpenAILLMContext(messages)
+        context_aggregator = llm.create_context_aggregator(context)
+
+        pipeline = Pipeline(
+            [
+                transport.input(),  # Transport user input
+                stt,  # STT
+                context_aggregator.user(),  # User responses
+                llm,  # LLM
+                tts,  # TTS
+                transport.output(),  # Transport bot output
+                context_aggregator.assistant(),  # Assistant spoken responses
+            ]
+        )
+
+        task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True))
+
+        @transport.event_handler("on_first_participant_joined")
+        async def on_first_participant_joined(transport, participant):
+            transport.capture_participant_transcription(participant["id"])
+            # Kick off the conversation.
+            messages.append({"role": "system", "content": "Please introduce yourself to the user."})
+            await task.queue_frames([LLMMessagesFrame(messages)])
+
+        runner = PipelineRunner()
+
+        await runner.run(task)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/examples/foundational/13c-gladia-transcription.py b/examples/foundational/13c-gladia-transcription.py
new file mode 100644
index 000000000..acc21b6c2
--- /dev/null
+++ b/examples/foundational/13c-gladia-transcription.py
@@ -0,0 +1,63 @@
+#
+# Copyright (c) 2024, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import asyncio
+import os
+import sys
+
+import aiohttp
+from dotenv import load_dotenv
+from loguru import logger
+from runner import configure
+
+from pipecat.frames.frames import Frame, TranscriptionFrame
+from pipecat.pipeline.pipeline import Pipeline
+from pipecat.pipeline.runner import PipelineRunner
+from pipecat.pipeline.task import PipelineTask
+from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
+from pipecat.services.gladia import GladiaSTTService
+from pipecat.transports.services.daily import DailyParams, DailyTransport
+
+load_dotenv(override=True)
+
+logger.remove(0)
+logger.add(sys.stderr, level="DEBUG")
+
+
+class TranscriptionLogger(FrameProcessor):
+    async def process_frame(self, frame: Frame, direction: FrameDirection):
+        await super().process_frame(frame, direction)
+
+        if isinstance(frame, TranscriptionFrame):
+            print(f"Transcription: {frame.text}")
+
+
+async def main():
+    async with aiohttp.ClientSession() as session:
+        (room_url, _) = await configure(session)
+
+        transport = DailyTransport(
+            room_url, None, "Transcription bot", DailyParams(audio_in_enabled=True)
+        )
+
+        stt = GladiaSTTService(
+            api_key=os.getenv("GLADIA_API_KEY"),
+            # live_options=LiveOptions(language=Language.FR),
+        )
+
+        tl = TranscriptionLogger()
+
+        pipeline = Pipeline([transport.input(), stt, tl])
+
+        task = PipelineTask(pipeline)
+
+        runner = PipelineRunner()
+
+        await runner.run(task)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/examples/foundational/13d-assemblyai-transcription.py b/examples/foundational/13d-assemblyai-transcription.py
new file mode 100644
index 000000000..d10a80274
--- /dev/null
+++ b/examples/foundational/13d-assemblyai-transcription.py
@@ -0,0 +1,62 @@
+#
+# Copyright (c) 2024, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import asyncio
+import os
+import sys
+
+import aiohttp
+from dotenv import load_dotenv
+from loguru import logger
+from runner import configure
+
+from pipecat.frames.frames import Frame, TranscriptionFrame
+from pipecat.pipeline.pipeline import Pipeline
+from pipecat.pipeline.runner import PipelineRunner
+from pipecat.pipeline.task import PipelineTask
+from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
+from pipecat.services.assemblyai import AssemblyAISTTService
+from pipecat.transports.services.daily import DailyParams, DailyTransport
+
+load_dotenv(override=True)
+
+logger.remove(0)
+logger.add(sys.stderr, level="DEBUG")
+
+
+class TranscriptionLogger(FrameProcessor):
+    async def process_frame(self, frame: Frame, direction: FrameDirection):
+        await super().process_frame(frame, direction)
+
+        if isinstance(frame, TranscriptionFrame):
+            print(f"Transcription: {frame.text}")
+
+
+async def main():
+    async with aiohttp.ClientSession() as session:
+        (room_url, _) = await configure(session)
+
+        transport = DailyTransport(
+            room_url, None, "Transcription bot", DailyParams(audio_in_enabled=True)
+        )
+
+        stt = AssemblyAISTTService(
+            api_key=os.getenv("ASSEMBLYAI_API_KEY"),
+        )
+
+        tl = TranscriptionLogger()
+
+        pipeline = Pipeline([transport.input(), stt, tl])
+
+        task = PipelineTask(pipeline)
+
+        runner = PipelineRunner()
+
+        await runner.run(task)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/pyproject.toml b/pyproject.toml
index 7625ce68a..d4ee2edd4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -37,6 +37,7 @@ Website = "https://pipecat.ai"
 
 [project.optional-dependencies]
 anthropic = [ "anthropic~=0.34.0" ]
+assemblyai = [ "assemblyai~=0.34.0" ]
 aws = [ "boto3~=1.35.27" ]
 azure = [ "azure-cognitiveservices-speech~=1.40.0" ]
 canonical = [ "aiofiles~=24.1.0" ]
diff --git a/src/pipecat/services/assemblyai.py b/src/pipecat/services/assemblyai.py
new file mode 100644
index 000000000..9fae27069
--- /dev/null
+++ b/src/pipecat/services/assemblyai.py
@@ -0,0 +1,204 @@
+#
+# Copyright (c) 2024, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import asyncio
+from typing import AsyncGenerator
+
+from loguru import logger
+
+from pipecat.frames.frames import (
+    CancelFrame,
+    EndFrame,
+    ErrorFrame,
+    Frame,
+    InterimTranscriptionFrame,
+    StartFrame,
+    TranscriptionFrame,
+)
+from pipecat.services.ai_services import STTService
+from pipecat.transcriptions.language import Language
+from pipecat.utils.time import time_now_iso8601
+
+try:
+    import assemblyai as aai
+    from assemblyai import AudioEncoding
+except ModuleNotFoundError as e:
+    logger.error(f"Exception: {e}")
+    logger.error(
+        "In order to use AssemblyAI, you need to `pip install pipecat-ai[assemblyai]`. Also, set `ASSEMBLYAI_API_KEY` environment variable."
+    )
+    raise Exception(f"Missing module: {e}")
+
+
+class AssemblyAISTTService(STTService):
+    """
+    Speech-to-text service backed by AssemblyAI's real-time transcriber.
+
+    Audio is streamed to AssemblyAI from run_stt(); results come back through
+    callbacks and are pushed as TranscriptionFrame, InterimTranscriptionFrame
+    or ErrorFrame.
+    """
+
+    def __init__(
+        self,
+        *,
+        api_key: str,
+        sample_rate: int = 16000,
+        encoding: AudioEncoding = AudioEncoding("pcm_s16le"),
+        language=Language.EN,  # Only English is supported for Realtime
+        **kwargs,
+    ):
+        """
+        Store the AssemblyAI settings. No connection is opened here.
+
+        :param api_key: AssemblyAI API key, assigned to the SDK-wide settings
+        :param sample_rate: Sample rate given to the real-time transcriber
+        :param encoding: Audio encoding given to the real-time transcriber
+        :param language: Language attached to the emitted transcription frames
+        :param kwargs: Forwarded to STTService
+        """
+        super().__init__(**kwargs)
+
+        aai.settings.api_key = api_key
+        self._transcriber: aai.RealtimeTranscriber | None = None
+        # The event loop is captured in _connect(), which always runs inside it.
+        # Calling asyncio.get_event_loop() here could return a loop that is not
+        # the one the pipeline later runs on when the service is constructed
+        # before the loop is started.
+        self._loop: asyncio.AbstractEventLoop | None = None
+
+        self._settings = {
+            "sample_rate": sample_rate,
+            "encoding": encoding,
+            "language": language,
+        }
+
+    async def set_language(self, language: Language):
+        """
+        Record the new language and reconnect to AssemblyAI.
+
+        :param language: Language attached to subsequent transcription frames
+        """
+        logger.info(f"Switching STT language to: [{language}]")
+        self._settings["language"] = language
+        await self._disconnect()
+        await self._connect()
+
+    async def start(self, frame: StartFrame):
+        """Start the service and connect to AssemblyAI."""
+        await super().start(frame)
+        await self._connect()
+
+    async def stop(self, frame: EndFrame):
+        """Stop the service and disconnect from AssemblyAI."""
+        await super().stop(frame)
+        await self._disconnect()
+
+    async def cancel(self, frame: CancelFrame):
+        """Cancel the service and disconnect from AssemblyAI."""
+        await super().cancel(frame)
+        await self._disconnect()
+
+    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
+        """
+        Process an audio chunk for STT transcription.
+
+        This method streams the audio data to AssemblyAI for real-time transcription.
+        Transcription results are handled asynchronously via callback functions.
+
+        :param audio: Audio data as bytes
+        :yield: None (transcription frames are pushed via self.push_frame in callbacks)
+        """
+        if self._transcriber:
+            await self.start_processing_metrics()
+            self._transcriber.stream(audio)
+            await self.stop_processing_metrics()
+        yield None
+
+    async def _connect(self):
+        """
+        Establish a connection to the AssemblyAI real-time transcription service.
+
+        This method sets up the necessary callback functions and initializes the
+        AssemblyAI transcriber. It returns immediately when a transcriber already
+        exists, so a repeated call cannot orphan an open session.
+        """
+        if self._transcriber:
+            return
+
+        # We are inside a coroutine, so this is the loop the service runs on.
+        # The callbacks close over this local so they keep targeting the loop
+        # that was running when their connection was created.
+        loop = asyncio.get_running_loop()
+        self._loop = loop
+
+        def on_open(session_opened: aai.RealtimeSessionOpened):
+            """Callback for when the connection to AssemblyAI is opened."""
+            logger.info(f"{self}: Connected to AssemblyAI")
+
+        def on_data(transcript: aai.RealtimeTranscript):
+            """
+            Callback for handling incoming transcription data.
+
+            This function runs in a separate thread from the main asyncio event loop.
+            It creates appropriate transcription frames and schedules them to be
+            pushed to the next stage of the pipeline in the main event loop.
+            """
+            if not transcript.text:
+                return
+
+            timestamp = time_now_iso8601()
+
+            if isinstance(transcript, aai.RealtimeFinalTranscript):
+                frame = TranscriptionFrame(
+                    transcript.text, "", timestamp, self._settings["language"]
+                )
+            else:
+                frame = InterimTranscriptionFrame(
+                    transcript.text, "", timestamp, self._settings["language"]
+                )
+
+            # Schedule the coroutine to run in the main event loop
+            # This is necessary because this callback runs in a different thread
+            asyncio.run_coroutine_threadsafe(self.push_frame(frame), loop)
+
+        def on_error(error: aai.RealtimeError):
+            """
+            Callback for handling errors from AssemblyAI.
+
+            Like on_data, this runs in a separate thread and schedules error
+            handling in the main event loop.
+            """
+            logger.error(f"{self}: An error occurred: {error}")
+            # Schedule the coroutine to run in the main event loop
+            asyncio.run_coroutine_threadsafe(self.push_frame(ErrorFrame(str(error))), loop)
+
+        def on_close():
+            """Callback for when the connection to AssemblyAI is closed."""
+            logger.info(f"{self}: Disconnected from AssemblyAI")
+
+        transcriber = aai.RealtimeTranscriber(
+            sample_rate=self._settings["sample_rate"],
+            encoding=self._settings["encoding"],
+            on_data=on_data,
+            on_error=on_error,
+            on_open=on_open,
+            on_close=on_close,
+        )
+        # NOTE(review): connect() is a plain synchronous call; if the SDK blocks
+        # here during the handshake the event loop stalls -- confirm.
+        transcriber.connect()
+        # Publish only after connect() returned, so a connect() that raises does
+        # not leave a half-initialised transcriber behind the guard above.
+        self._transcriber = transcriber
+
+    async def _disconnect(self):
+        """Disconnect from the AssemblyAI service and clean up resources."""
+        # Drop our reference before closing so run_stt() stops streaming and a
+        # failing close() cannot leave a dead transcriber behind.
+        transcriber, self._transcriber = self._transcriber, None
+        if transcriber:
+            transcriber.close()