Add AssemblyAI STT service

This commit is contained in:
Mark Backman
2024-10-19 22:44:49 -04:00
parent 4f1b2dce9b
commit 1dbf4ff27d
5 changed files with 377 additions and 0 deletions

View File

@@ -0,0 +1,97 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import os
import sys
import aiohttp
from dotenv import load_dotenv
from loguru import logger
from runner import configure
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.frames.frames import LLMMessagesFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.assemblyai import AssemblyAISTTService
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.services.openai import OpenAILLMService
from pipecat.transports.services.daily import DailyParams, DailyTransport
load_dotenv(override=True)
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
async def main():
async with aiohttp.ClientSession() as session:
(room_url, token) = await configure(session)
transport = DailyTransport(
room_url,
token,
"Respond bot",
DailyParams(
audio_out_enabled=True,
vad_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
vad_audio_passthrough=True,
),
)
stt = AssemblyAISTTService(
api_key=os.getenv("ASSEMBLYAI_API_KEY"),
)
tts = CartesiaTTSService(
api_key=os.getenv("CARTESIA_API_KEY"),
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
)
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
messages = [
{
"role": "system",
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
},
]
context = OpenAILLMContext(messages)
context_aggregator = llm.create_context_aggregator(context)
pipeline = Pipeline(
[
transport.input(), # Transport user input
stt, # STT
context_aggregator.user(), # User responses
llm, # LLM
tts, # TTS
transport.output(), # Transport bot output
context_aggregator.assistant(), # Assistant spoken responses
]
)
task = PipelineTask(pipeline, PipelineParams(allow_interruptions=True))
@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
transport.capture_participant_transcription(participant["id"])
# Kick off the conversation.
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
await task.queue_frames([LLMMessagesFrame(messages)])
runner = PipelineRunner()
await runner.run(task)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,63 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import os
import sys
import aiohttp
from dotenv import load_dotenv
from loguru import logger
from runner import configure
from pipecat.frames.frames import Frame, TranscriptionFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.gladia import GladiaSTTService
from pipecat.transports.services.daily import DailyParams, DailyTransport
load_dotenv(override=True)
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
class TranscriptionLogger(FrameProcessor):
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, TranscriptionFrame):
print(f"Transcription: {frame.text}")
async def main():
async with aiohttp.ClientSession() as session:
(room_url, _) = await configure(session)
transport = DailyTransport(
room_url, None, "Transcription bot", DailyParams(audio_in_enabled=True)
)
stt = GladiaSTTService(
api_key=os.getenv("GLADIA_API_KEY"),
# live_options=LiveOptions(language=Language.FR),
)
tl = TranscriptionLogger()
pipeline = Pipeline([transport.input(), stt, tl])
task = PipelineTask(pipeline)
runner = PipelineRunner()
await runner.run(task)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,62 @@
#
# Copyright (c) 2024, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import os
import sys
import aiohttp
from dotenv import load_dotenv
from loguru import logger
from runner import configure
from pipecat.frames.frames import Frame, TranscriptionFrame
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
from pipecat.services.assemblyai import AssemblyAISTTService
from pipecat.transports.services.daily import DailyParams, DailyTransport
load_dotenv(override=True)
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
class TranscriptionLogger(FrameProcessor):
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, TranscriptionFrame):
print(f"Transcription: {frame.text}")
async def main():
async with aiohttp.ClientSession() as session:
(room_url, _) = await configure(session)
transport = DailyTransport(
room_url, None, "Transcription bot", DailyParams(audio_in_enabled=True)
)
stt = AssemblyAISTTService(
api_key=os.getenv("ASSEMBLYAI_API_KEY"),
)
tl = TranscriptionLogger()
pipeline = Pipeline([transport.input(), stt, tl])
task = PipelineTask(pipeline)
runner = PipelineRunner()
await runner.run(task)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -37,6 +37,7 @@ Website = "https://pipecat.ai"
[project.optional-dependencies]
anthropic = [ "anthropic~=0.34.0" ]
assemblyai = [ "assemblyai~=0.34.0" ]
aws = [ "boto3~=1.35.27" ]
azure = [ "azure-cognitiveservices-speech~=1.40.0" ]
canonical = [ "aiofiles~=24.1.0" ]

View File

@@ -0,0 +1,154 @@
import asyncio
from typing import AsyncGenerator
from loguru import logger
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
InterimTranscriptionFrame,
StartFrame,
TranscriptionFrame,
)
from pipecat.services.ai_services import STTService
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
try:
import assemblyai as aai
from assemblyai import AudioEncoding
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use AssemblyAI, you need to `pip install pipecat-ai[assemblyai]`. Also, set `ASSEMBLYAI_API_KEY` environment variable."
)
raise Exception(f"Missing module: {e}")
class AssemblyAISTTService(STTService):
def __init__(
self,
*,
api_key: str,
sample_rate: int = 16000,
encoding: AudioEncoding = AudioEncoding("pcm_s16le"),
language=Language.EN, # Only English is supported for Realtime
**kwargs,
):
super().__init__(**kwargs)
aai.settings.api_key = api_key
self._transcriber: aai.RealtimeTranscriber | None = None
# Store reference to the main event loop for use in callback functions
self._loop = asyncio.get_event_loop()
self._settings = {
"sample_rate": sample_rate,
"encoding": encoding,
"language": language,
}
async def set_language(self, language: Language):
logger.info(f"Switching STT language to: [{language}]")
self._settings["language"] = language
await self._disconnect()
await self._connect()
async def start(self, frame: StartFrame):
await super().start(frame)
await self._connect()
async def stop(self, frame: EndFrame):
await super().stop(frame)
await self._disconnect()
async def cancel(self, frame: CancelFrame):
await super().cancel(frame)
await self._disconnect()
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""
Process an audio chunk for STT transcription.
This method streams the audio data to AssemblyAI for real-time transcription.
Transcription results are handled asynchronously via callback functions.
:param audio: Audio data as bytes
:yield: None (transcription frames are pushed via self.push_frame in callbacks)
"""
if self._transcriber:
await self.start_processing_metrics()
self._transcriber.stream(audio)
await self.stop_processing_metrics()
yield None
async def _connect(self):
"""
Establish a connection to the AssemblyAI real-time transcription service.
This method sets up the necessary callback functions and initializes the
AssemblyAI transcriber.
"""
def on_open(session_opened: aai.RealtimeSessionOpened):
"""Callback for when the connection to AssemblyAI is opened."""
logger.info(f"{self}: Connected to AssemblyAI")
def on_data(transcript: aai.RealtimeTranscript):
"""
Callback for handling incoming transcription data.
This function runs in a separate thread from the main asyncio event loop.
It creates appropriate transcription frames and schedules them to be
pushed to the next stage of the pipeline in the main event loop.
"""
if not transcript.text:
return
timestamp = time_now_iso8601()
if isinstance(transcript, aai.RealtimeFinalTranscript):
frame = TranscriptionFrame(
transcript.text, "", timestamp, self._settings["language"]
)
else:
frame = InterimTranscriptionFrame(
transcript.text, "", timestamp, self._settings["language"]
)
# Schedule the coroutine to run in the main event loop
# This is necessary because this callback runs in a different thread
asyncio.run_coroutine_threadsafe(self.push_frame(frame), self._loop)
def on_error(error: aai.RealtimeError):
"""
Callback for handling errors from AssemblyAI.
Like on_data, this runs in a separate thread and schedules error
handling in the main event loop.
"""
logger.error(f"{self}: An error occurred: {error}")
# Schedule the coroutine to run in the main event loop
asyncio.run_coroutine_threadsafe(self.push_frame(ErrorFrame(str(error))), self._loop)
def on_close():
"""Callback for when the connection to AssemblyAI is closed."""
logger.info(f"{self}: Disconnected from AssemblyAI")
self._transcriber = aai.RealtimeTranscriber(
sample_rate=self._settings["sample_rate"],
encoding=self._settings["encoding"],
on_data=on_data,
on_error=on_error,
on_open=on_open,
on_close=on_close,
)
self._transcriber.connect()
async def _disconnect(self):
"""Disconnect from the AssemblyAI service and clean up resources."""
if self._transcriber:
self._transcriber.close()
self._transcriber = None