Compare commits

...

1 Commits

Author SHA1 Message Date
Mark Backman
6f84a65a1a Add FalSTTService 2025-02-11 13:00:25 -05:00
3 changed files with 419 additions and 3 deletions

View File

@@ -0,0 +1,110 @@
#
# Copyright (c) 20242025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#
import asyncio
import os
import sys
import aiohttp
from dotenv import load_dotenv
from loguru import logger
from runner import configure
from pipecat.audio.vad.silero import SileroVADAnalyzer
from pipecat.pipeline.pipeline import Pipeline
from pipecat.pipeline.runner import PipelineRunner
from pipecat.pipeline.task import PipelineParams, PipelineTask
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.services.cartesia import CartesiaTTSService
from pipecat.services.fal import FalSTTService
from pipecat.services.gladia import GladiaSTTService
from pipecat.services.openai import OpenAILLMService
from pipecat.transports.services.daily import DailyParams, DailyTransport
load_dotenv(override=True)
logger.remove(0)
logger.add(sys.stderr, level="DEBUG")
async def main():
async with aiohttp.ClientSession() as session:
(room_url, token) = await configure(session)
transport = DailyTransport(
room_url,
token,
"Respond bot",
DailyParams(
audio_out_enabled=True,
vad_enabled=True,
vad_analyzer=SileroVADAnalyzer(),
vad_audio_passthrough=True,
),
)
stt = FalSTTService(
api_key=os.getenv("FAL_KEY"),
)
tts = CartesiaTTSService(
api_key=os.getenv("CARTESIA_API_KEY"),
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
)
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
messages = [
{
"role": "system",
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
},
]
context = OpenAILLMContext(messages)
context_aggregator = llm.create_context_aggregator(context)
pipeline = Pipeline(
[
transport.input(), # Transport user input
stt, # STT
context_aggregator.user(), # User responses
llm, # LLM
tts, # TTS
transport.output(), # Transport bot output
context_aggregator.assistant(), # Assistant spoken responses
]
)
task = PipelineTask(
pipeline,
PipelineParams(
allow_interruptions=True,
enable_metrics=True,
enable_usage_metrics=True,
report_only_initial_ttfb=True,
),
)
@transport.event_handler("on_first_participant_joined")
async def on_first_participant_joined(transport, participant):
await transport.capture_participant_transcription(participant["id"])
# Kick off the conversation.
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
await task.queue_frames([context_aggregator.user().get_context_frame()])
# Register an event handler to exit the application when the user leaves.
@transport.event_handler("on_participant_left")
async def on_participant_left(transport, participant, reason):
await task.cancel()
runner = PipelineRunner()
await runner.run(task)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -6,6 +6,7 @@
import io
import os
import wave
from typing import AsyncGenerator, Dict, Optional, Union
import aiohttp
@@ -13,8 +14,15 @@ from loguru import logger
from PIL import Image
from pydantic import BaseModel
from pipecat.frames.frames import ErrorFrame, Frame, URLImageRawFrame
from pipecat.services.ai_services import ImageGenService
from pipecat.frames.frames import (
ErrorFrame,
Frame,
TranscriptionFrame,
URLImageRawFrame,
)
from pipecat.services.ai_services import ImageGenService, SegmentedSTTService
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
try:
import fal_client
@@ -26,6 +34,120 @@ except ModuleNotFoundError as e:
raise Exception(f"Missing module: {e}")
def language_to_fal_language(language: Language) -> Optional[str]:
"""Language support for Fal's Whisper API."""
BASE_LANGUAGES = {
Language.AF: "af",
Language.AM: "am",
Language.AR: "ar",
Language.AS: "as",
Language.AZ: "az",
Language.BA: "ba",
Language.BE: "be",
Language.BG: "bg",
Language.BN: "bn",
Language.BO: "bo",
Language.BR: "br",
Language.BS: "bs",
Language.CA: "ca",
Language.CS: "cs",
Language.CY: "cy",
Language.DA: "da",
Language.DE: "de",
Language.EL: "el",
Language.EN: "en",
Language.ES: "es",
Language.ET: "et",
Language.EU: "eu",
Language.FA: "fa",
Language.FI: "fi",
Language.FO: "fo",
Language.FR: "fr",
Language.GL: "gl",
Language.GU: "gu",
Language.HA: "ha",
Language.HE: "he",
Language.HI: "hi",
Language.HR: "hr",
Language.HT: "ht",
Language.HU: "hu",
Language.HY: "hy",
Language.ID: "id",
Language.IS: "is",
Language.IT: "it",
Language.JA: "ja",
Language.JW: "jw",
Language.KA: "ka",
Language.KK: "kk",
Language.KM: "km",
Language.KN: "kn",
Language.KO: "ko",
Language.LA: "la",
Language.LB: "lb",
Language.LN: "ln",
Language.LO: "lo",
Language.LT: "lt",
Language.LV: "lv",
Language.MG: "mg",
Language.MI: "mi",
Language.MK: "mk",
Language.ML: "ml",
Language.MN: "mn",
Language.MR: "mr",
Language.MS: "ms",
Language.MT: "mt",
Language.MY: "my",
Language.NE: "ne",
Language.NL: "nl",
Language.NN: "nn",
Language.NO: "no",
Language.OC: "oc",
Language.PA: "pa",
Language.PL: "pl",
Language.PS: "ps",
Language.PT: "pt",
Language.RO: "ro",
Language.RU: "ru",
Language.SA: "sa",
Language.SD: "sd",
Language.SI: "si",
Language.SK: "sk",
Language.SL: "sl",
Language.SN: "sn",
Language.SO: "so",
Language.SQ: "sq",
Language.SR: "sr",
Language.SU: "su",
Language.SV: "sv",
Language.SW: "sw",
Language.TA: "ta",
Language.TE: "te",
Language.TG: "tg",
Language.TH: "th",
Language.TK: "tk",
Language.TL: "tl",
Language.TR: "tr",
Language.TT: "tt",
Language.UK: "uk",
Language.UR: "ur",
Language.UZ: "uz",
Language.VI: "vi",
Language.YI: "yi",
Language.YO: "yo",
Language.ZH: "zh",
}
result = BASE_LANGUAGES.get(language)
# If not found in base languages, try to find the base language from a variant
if not result:
lang_str = str(language.value)
base_code = lang_str.split("-")[0].lower()
result = base_code if base_code in BASE_LANGUAGES.values() else None
return result
class FalImageGenService(ImageGenService):
class InputParams(BaseModel):
seed: Optional[int] = None
@@ -49,6 +171,7 @@ class FalImageGenService(ImageGenService):
self.set_model_name(model)
self._params = params
self._aiohttp_session = aiohttp_session
self._fal_client = fal_client.AsyncClient()
if key:
os.environ["FAL_KEY"] = key
@@ -80,3 +203,127 @@ class FalImageGenService(ImageGenService):
url=image_url, image=image.tobytes(), size=image.size, format=image.format
)
yield frame
class FalSTTService(SegmentedSTTService):
"""Speech-to-text service using Fal's Whisper API.
This service uses Fal's Whisper API to perform speech-to-text transcription on audio
segments. It inherits from SegmentedSTTService to handle audio buffering and speech detection.
Args:
api_key: Fal API key. If not provided, will check FAL_KEY environment variable.
sample_rate: Audio sample rate in Hz. If not provided, uses the pipeline's rate.
params: Configuration parameters for the Whisper API.
**kwargs: Additional arguments passed to SegmentedSTTService.
"""
class InputParams(BaseModel):
language: Optional[Language] = Language.EN
task: str = "transcribe"
chunk_level: str = "segment"
version: str = "3"
"""Configuration parameters for Fal's Whisper API.
Attributes:
language: Language of the audio input. Defaults to English.
task: Task to perform ('transcribe' or 'translate'). Defaults to 'transcribe'.
chunk_level: Level of chunking ('segment'). Defaults to 'segment'.
version: Version of Whisper model to use. Defaults to '3'.
"""
def __init__(
self,
*,
api_key: Optional[str] = None,
sample_rate: Optional[int] = None,
params: InputParams = InputParams(),
# min_volume: float = 0.6,
# max_silence_secs: float = 0.3,
# max_buffer_secs: float = 1.5,
**kwargs,
):
super().__init__(
sample_rate=sample_rate,
# min_volume=min_volume,
# max_silence_secs=max_silence_secs,
# max_buffer_secs=max_buffer_secs,
**kwargs,
)
if api_key:
os.environ["FAL_KEY"] = api_key
elif "FAL_KEY" not in os.environ:
raise ValueError(
"FAL_KEY must be provided either through api_key parameter or environment variable"
)
self._fal_client = fal_client.AsyncClient(key=api_key or os.getenv("FAL_KEY"))
self._settings = {
"task": params.task,
"language": self.language_to_service_language(params.language)
if params.language
else "en",
"chunk_level": params.chunk_level,
"version": params.version,
}
def can_generate_metrics(self) -> bool:
return True
def language_to_service_language(self, language: Language) -> Optional[str]:
return language_to_fal_language(language)
async def set_language(self, language: Language):
logger.info(f"Switching STT language to: [{language}]")
self._settings["language"] = self.language_to_service_language(language)
async def set_model(self, model: str):
await super().set_model(model)
logger.info(f"Switching STT model to: [{model}]")
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
"""Transcribes an audio segment using Fal's Whisper API.
Args:
audio: Raw audio bytes in 16-bit PCM format.
Yields:
Frame: TranscriptionFrame containing the transcribed text.
Note:
The audio is converted to WAV format before being sent to the API.
Only non-empty transcriptions are yielded.
"""
try:
# Convert PCM to WAV
with io.BytesIO() as wav_buffer:
with wave.open(wav_buffer, "wb") as wav_file:
wav_file.setnchannels(1)
wav_file.setsampwidth(2)
wav_file.setframerate(self.sample_rate)
wav_file.writeframes(audio)
wav_buffer.seek(0)
wav_data = wav_buffer.read()
# Send to Fal
data_uri = fal_client.encode(wav_data, "audio/x-wav")
response = await self._fal_client.run(
"fal-ai/wizper",
arguments={"audio_url": data_uri, **self._settings},
)
# Log full response to understand what data we get
logger.debug(f"Full Fal response: {response}")
if response and "text" in response:
text = response["text"].strip()
if text: # Only yield non-empty text
logger.debug(f"Transcription: [{text}]")
yield TranscriptionFrame(
text, "", time_now_iso8601(), Language(self._settings["language"])
)
except Exception as e:
logger.error(f"Fal Wizper error: {e}")

View File

@@ -54,6 +54,9 @@ class Language(StrEnum):
AZ = "az"
AZ_AZ = "az-AZ"
# Bashkir
BA = "ba"
# Belarusian
BE = "be"
@@ -66,6 +69,12 @@ class Language(StrEnum):
BN_BD = "bn-BD"
BN_IN = "bn-IN"
# Tibetan
BO = "bo"
# Breton
BR = "br"
# Bosnian
BS = "bs"
BS_BA = "bs-BA"
@@ -158,6 +167,9 @@ class Language(StrEnum):
FIL = "fil"
FIL_PH = "fil-PH"
# Faroese
FO = "fo"
# French
FR = "fr"
FR_BE = "fr-BE"
@@ -177,6 +189,9 @@ class Language(StrEnum):
GU = "gu"
GU_IN = "gu-IN"
# Hausa
HA = "ha"
# Hebrew
HE = "he"
HE_IL = "he-IL"
@@ -189,6 +204,9 @@ class Language(StrEnum):
HR = "hr"
HR_HR = "hr-HR"
# Haitian Creole
HT = "ht"
# Hungarian
HU = "hu"
HU_HU = "hu-HU"
@@ -222,6 +240,7 @@ class Language(StrEnum):
# Javanese
JV = "jv"
JV_ID = "jv-ID"
JW = "jw" # Fal requires for Javanese
# Georgian
KA = "ka"
@@ -243,6 +262,15 @@ class Language(StrEnum):
KO = "ko"
KO_KR = "ko-KR"
# Latin
LA = "la"
# Luxembourgish
LB = "lb"
# Lingala
LN = "ln"
# Lao
LO = "lo"
LO_LA = "lo-LA"
@@ -255,6 +283,9 @@ class Language(StrEnum):
LV = "lv"
LV_LV = "lv-LV"
# Malagasy
MG = "mg"
# Macedonian
MK = "mk"
MK_MK = "mk-MK"
@@ -287,9 +318,10 @@ class Language(StrEnum):
MY_MM = "my-MM"
# Norwegian
NB = "nb"
NB = "nb" # Norwegian Bokmål
NB_NO = "nb-NO"
NO = "no"
NN = "nn" # Norwegian Nynorsk
# Nepali
NE = "ne"
@@ -300,6 +332,9 @@ class Language(StrEnum):
NL_BE = "nl-BE"
NL_NL = "nl-NL"
# Occitan
OC = "oc"
# Odia
OR = "or"
OR_IN = "or-IN"
@@ -329,6 +364,12 @@ class Language(StrEnum):
RU = "ru"
RU_RU = "ru-RU"
# Sanskrit
SA = "sa"
# Sindhi
SD = "sd"
# Sinhala
SI = "si"
SI_LK = "si-LK"
@@ -341,6 +382,9 @@ class Language(StrEnum):
SL = "sl"
SL_SI = "sl-SI"
# Shona
SN = "sn"
# Somali
SO = "so"
SO_SO = "so-SO"
@@ -382,14 +426,23 @@ class Language(StrEnum):
TE = "te"
TE_IN = "te-IN"
# Tajik
TG = "tg"
# Thai
TH = "th"
TH_TH = "th-TH"
# Turkmen
TK = "tk"
# Turkish
TR = "tr"
TR_TR = "tr-TR"
# Tatar
TT = "tt"
# Ukrainian
UK = "uk"
UK_UA = "uk-UA"
@@ -411,6 +464,12 @@ class Language(StrEnum):
WUU = "wuu"
WUU_CN = "wuu-CN"
# Yiddish
YI = "yi"
# Yoruba
YO = "yo"
# Yue Chinese
YUE = "yue"
YUE_CN = "yue-CN"