Compare commits
1 Commits
hush/reset
...
mb/fal-whi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6f84a65a1a |
110
examples/foundational/07u-interruptible-fal.py
Normal file
110
examples/foundational/07u-interruptible-fal.py
Normal file
@@ -0,0 +1,110 @@
|
||||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
import aiohttp
|
||||
from dotenv import load_dotenv
|
||||
from loguru import logger
|
||||
from runner import configure
|
||||
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.pipeline.runner import PipelineRunner
|
||||
from pipecat.pipeline.task import PipelineParams, PipelineTask
|
||||
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
|
||||
from pipecat.services.cartesia import CartesiaTTSService
|
||||
from pipecat.services.fal import FalSTTService
|
||||
from pipecat.services.gladia import GladiaSTTService
|
||||
from pipecat.services.openai import OpenAILLMService
|
||||
from pipecat.transports.services.daily import DailyParams, DailyTransport
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
logger.remove(0)
|
||||
logger.add(sys.stderr, level="DEBUG")
|
||||
|
||||
|
||||
async def main():
|
||||
async with aiohttp.ClientSession() as session:
|
||||
(room_url, token) = await configure(session)
|
||||
|
||||
transport = DailyTransport(
|
||||
room_url,
|
||||
token,
|
||||
"Respond bot",
|
||||
DailyParams(
|
||||
audio_out_enabled=True,
|
||||
vad_enabled=True,
|
||||
vad_analyzer=SileroVADAnalyzer(),
|
||||
vad_audio_passthrough=True,
|
||||
),
|
||||
)
|
||||
|
||||
stt = FalSTTService(
|
||||
api_key=os.getenv("FAL_KEY"),
|
||||
)
|
||||
|
||||
tts = CartesiaTTSService(
|
||||
api_key=os.getenv("CARTESIA_API_KEY"),
|
||||
voice_id="79a125e8-cd45-4c13-8a67-188112f4dd22", # British Lady
|
||||
)
|
||||
|
||||
llm = OpenAILLMService(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o")
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful LLM in a WebRTC call. Your goal is to demonstrate your capabilities in a succinct way. Your output will be converted to audio so don't include special characters in your answers. Respond to what the user said in a creative and helpful way.",
|
||||
},
|
||||
]
|
||||
|
||||
context = OpenAILLMContext(messages)
|
||||
context_aggregator = llm.create_context_aggregator(context)
|
||||
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(), # Transport user input
|
||||
stt, # STT
|
||||
context_aggregator.user(), # User responses
|
||||
llm, # LLM
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
context_aggregator.assistant(), # Assistant spoken responses
|
||||
]
|
||||
)
|
||||
|
||||
task = PipelineTask(
|
||||
pipeline,
|
||||
PipelineParams(
|
||||
allow_interruptions=True,
|
||||
enable_metrics=True,
|
||||
enable_usage_metrics=True,
|
||||
report_only_initial_ttfb=True,
|
||||
),
|
||||
)
|
||||
|
||||
@transport.event_handler("on_first_participant_joined")
|
||||
async def on_first_participant_joined(transport, participant):
|
||||
await transport.capture_participant_transcription(participant["id"])
|
||||
# Kick off the conversation.
|
||||
messages.append({"role": "system", "content": "Please introduce yourself to the user."})
|
||||
await task.queue_frames([context_aggregator.user().get_context_frame()])
|
||||
|
||||
# Register an event handler to exit the application when the user leaves.
|
||||
@transport.event_handler("on_participant_left")
|
||||
async def on_participant_left(transport, participant, reason):
|
||||
await task.cancel()
|
||||
|
||||
runner = PipelineRunner()
|
||||
|
||||
await runner.run(task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
import io
|
||||
import os
|
||||
import wave
|
||||
from typing import AsyncGenerator, Dict, Optional, Union
|
||||
|
||||
import aiohttp
|
||||
@@ -13,8 +14,15 @@ from loguru import logger
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel
|
||||
|
||||
from pipecat.frames.frames import ErrorFrame, Frame, URLImageRawFrame
|
||||
from pipecat.services.ai_services import ImageGenService
|
||||
from pipecat.frames.frames import (
|
||||
ErrorFrame,
|
||||
Frame,
|
||||
TranscriptionFrame,
|
||||
URLImageRawFrame,
|
||||
)
|
||||
from pipecat.services.ai_services import ImageGenService, SegmentedSTTService
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.time import time_now_iso8601
|
||||
|
||||
try:
|
||||
import fal_client
|
||||
@@ -26,6 +34,120 @@ except ModuleNotFoundError as e:
|
||||
raise Exception(f"Missing module: {e}")
|
||||
|
||||
|
||||
def language_to_fal_language(language: Language) -> Optional[str]:
|
||||
"""Language support for Fal's Whisper API."""
|
||||
BASE_LANGUAGES = {
|
||||
Language.AF: "af",
|
||||
Language.AM: "am",
|
||||
Language.AR: "ar",
|
||||
Language.AS: "as",
|
||||
Language.AZ: "az",
|
||||
Language.BA: "ba",
|
||||
Language.BE: "be",
|
||||
Language.BG: "bg",
|
||||
Language.BN: "bn",
|
||||
Language.BO: "bo",
|
||||
Language.BR: "br",
|
||||
Language.BS: "bs",
|
||||
Language.CA: "ca",
|
||||
Language.CS: "cs",
|
||||
Language.CY: "cy",
|
||||
Language.DA: "da",
|
||||
Language.DE: "de",
|
||||
Language.EL: "el",
|
||||
Language.EN: "en",
|
||||
Language.ES: "es",
|
||||
Language.ET: "et",
|
||||
Language.EU: "eu",
|
||||
Language.FA: "fa",
|
||||
Language.FI: "fi",
|
||||
Language.FO: "fo",
|
||||
Language.FR: "fr",
|
||||
Language.GL: "gl",
|
||||
Language.GU: "gu",
|
||||
Language.HA: "ha",
|
||||
Language.HE: "he",
|
||||
Language.HI: "hi",
|
||||
Language.HR: "hr",
|
||||
Language.HT: "ht",
|
||||
Language.HU: "hu",
|
||||
Language.HY: "hy",
|
||||
Language.ID: "id",
|
||||
Language.IS: "is",
|
||||
Language.IT: "it",
|
||||
Language.JA: "ja",
|
||||
Language.JW: "jw",
|
||||
Language.KA: "ka",
|
||||
Language.KK: "kk",
|
||||
Language.KM: "km",
|
||||
Language.KN: "kn",
|
||||
Language.KO: "ko",
|
||||
Language.LA: "la",
|
||||
Language.LB: "lb",
|
||||
Language.LN: "ln",
|
||||
Language.LO: "lo",
|
||||
Language.LT: "lt",
|
||||
Language.LV: "lv",
|
||||
Language.MG: "mg",
|
||||
Language.MI: "mi",
|
||||
Language.MK: "mk",
|
||||
Language.ML: "ml",
|
||||
Language.MN: "mn",
|
||||
Language.MR: "mr",
|
||||
Language.MS: "ms",
|
||||
Language.MT: "mt",
|
||||
Language.MY: "my",
|
||||
Language.NE: "ne",
|
||||
Language.NL: "nl",
|
||||
Language.NN: "nn",
|
||||
Language.NO: "no",
|
||||
Language.OC: "oc",
|
||||
Language.PA: "pa",
|
||||
Language.PL: "pl",
|
||||
Language.PS: "ps",
|
||||
Language.PT: "pt",
|
||||
Language.RO: "ro",
|
||||
Language.RU: "ru",
|
||||
Language.SA: "sa",
|
||||
Language.SD: "sd",
|
||||
Language.SI: "si",
|
||||
Language.SK: "sk",
|
||||
Language.SL: "sl",
|
||||
Language.SN: "sn",
|
||||
Language.SO: "so",
|
||||
Language.SQ: "sq",
|
||||
Language.SR: "sr",
|
||||
Language.SU: "su",
|
||||
Language.SV: "sv",
|
||||
Language.SW: "sw",
|
||||
Language.TA: "ta",
|
||||
Language.TE: "te",
|
||||
Language.TG: "tg",
|
||||
Language.TH: "th",
|
||||
Language.TK: "tk",
|
||||
Language.TL: "tl",
|
||||
Language.TR: "tr",
|
||||
Language.TT: "tt",
|
||||
Language.UK: "uk",
|
||||
Language.UR: "ur",
|
||||
Language.UZ: "uz",
|
||||
Language.VI: "vi",
|
||||
Language.YI: "yi",
|
||||
Language.YO: "yo",
|
||||
Language.ZH: "zh",
|
||||
}
|
||||
|
||||
result = BASE_LANGUAGES.get(language)
|
||||
|
||||
# If not found in base languages, try to find the base language from a variant
|
||||
if not result:
|
||||
lang_str = str(language.value)
|
||||
base_code = lang_str.split("-")[0].lower()
|
||||
result = base_code if base_code in BASE_LANGUAGES.values() else None
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class FalImageGenService(ImageGenService):
|
||||
class InputParams(BaseModel):
|
||||
seed: Optional[int] = None
|
||||
@@ -49,6 +171,7 @@ class FalImageGenService(ImageGenService):
|
||||
self.set_model_name(model)
|
||||
self._params = params
|
||||
self._aiohttp_session = aiohttp_session
|
||||
self._fal_client = fal_client.AsyncClient()
|
||||
if key:
|
||||
os.environ["FAL_KEY"] = key
|
||||
|
||||
@@ -80,3 +203,127 @@ class FalImageGenService(ImageGenService):
|
||||
url=image_url, image=image.tobytes(), size=image.size, format=image.format
|
||||
)
|
||||
yield frame
|
||||
|
||||
|
||||
class FalSTTService(SegmentedSTTService):
|
||||
"""Speech-to-text service using Fal's Whisper API.
|
||||
|
||||
This service uses Fal's Whisper API to perform speech-to-text transcription on audio
|
||||
segments. It inherits from SegmentedSTTService to handle audio buffering and speech detection.
|
||||
|
||||
Args:
|
||||
api_key: Fal API key. If not provided, will check FAL_KEY environment variable.
|
||||
sample_rate: Audio sample rate in Hz. If not provided, uses the pipeline's rate.
|
||||
params: Configuration parameters for the Whisper API.
|
||||
**kwargs: Additional arguments passed to SegmentedSTTService.
|
||||
"""
|
||||
|
||||
class InputParams(BaseModel):
|
||||
language: Optional[Language] = Language.EN
|
||||
task: str = "transcribe"
|
||||
chunk_level: str = "segment"
|
||||
version: str = "3"
|
||||
"""Configuration parameters for Fal's Whisper API.
|
||||
|
||||
Attributes:
|
||||
language: Language of the audio input. Defaults to English.
|
||||
task: Task to perform ('transcribe' or 'translate'). Defaults to 'transcribe'.
|
||||
chunk_level: Level of chunking ('segment'). Defaults to 'segment'.
|
||||
version: Version of Whisper model to use. Defaults to '3'.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
api_key: Optional[str] = None,
|
||||
sample_rate: Optional[int] = None,
|
||||
params: InputParams = InputParams(),
|
||||
# min_volume: float = 0.6,
|
||||
# max_silence_secs: float = 0.3,
|
||||
# max_buffer_secs: float = 1.5,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
# min_volume=min_volume,
|
||||
# max_silence_secs=max_silence_secs,
|
||||
# max_buffer_secs=max_buffer_secs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if api_key:
|
||||
os.environ["FAL_KEY"] = api_key
|
||||
elif "FAL_KEY" not in os.environ:
|
||||
raise ValueError(
|
||||
"FAL_KEY must be provided either through api_key parameter or environment variable"
|
||||
)
|
||||
|
||||
self._fal_client = fal_client.AsyncClient(key=api_key or os.getenv("FAL_KEY"))
|
||||
self._settings = {
|
||||
"task": params.task,
|
||||
"language": self.language_to_service_language(params.language)
|
||||
if params.language
|
||||
else "en",
|
||||
"chunk_level": params.chunk_level,
|
||||
"version": params.version,
|
||||
}
|
||||
|
||||
def can_generate_metrics(self) -> bool:
|
||||
return True
|
||||
|
||||
def language_to_service_language(self, language: Language) -> Optional[str]:
|
||||
return language_to_fal_language(language)
|
||||
|
||||
async def set_language(self, language: Language):
|
||||
logger.info(f"Switching STT language to: [{language}]")
|
||||
self._settings["language"] = self.language_to_service_language(language)
|
||||
|
||||
async def set_model(self, model: str):
|
||||
await super().set_model(model)
|
||||
logger.info(f"Switching STT model to: [{model}]")
|
||||
|
||||
async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
|
||||
"""Transcribes an audio segment using Fal's Whisper API.
|
||||
|
||||
Args:
|
||||
audio: Raw audio bytes in 16-bit PCM format.
|
||||
|
||||
Yields:
|
||||
Frame: TranscriptionFrame containing the transcribed text.
|
||||
|
||||
Note:
|
||||
The audio is converted to WAV format before being sent to the API.
|
||||
Only non-empty transcriptions are yielded.
|
||||
"""
|
||||
try:
|
||||
# Convert PCM to WAV
|
||||
with io.BytesIO() as wav_buffer:
|
||||
with wave.open(wav_buffer, "wb") as wav_file:
|
||||
wav_file.setnchannels(1)
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setframerate(self.sample_rate)
|
||||
wav_file.writeframes(audio)
|
||||
|
||||
wav_buffer.seek(0)
|
||||
wav_data = wav_buffer.read()
|
||||
|
||||
# Send to Fal
|
||||
data_uri = fal_client.encode(wav_data, "audio/x-wav")
|
||||
response = await self._fal_client.run(
|
||||
"fal-ai/wizper",
|
||||
arguments={"audio_url": data_uri, **self._settings},
|
||||
)
|
||||
|
||||
# Log full response to understand what data we get
|
||||
logger.debug(f"Full Fal response: {response}")
|
||||
|
||||
if response and "text" in response:
|
||||
text = response["text"].strip()
|
||||
if text: # Only yield non-empty text
|
||||
logger.debug(f"Transcription: [{text}]")
|
||||
yield TranscriptionFrame(
|
||||
text, "", time_now_iso8601(), Language(self._settings["language"])
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Fal Wizper error: {e}")
|
||||
|
||||
@@ -54,6 +54,9 @@ class Language(StrEnum):
|
||||
AZ = "az"
|
||||
AZ_AZ = "az-AZ"
|
||||
|
||||
# Bashkir
|
||||
BA = "ba"
|
||||
|
||||
# Belarusian
|
||||
BE = "be"
|
||||
|
||||
@@ -66,6 +69,12 @@ class Language(StrEnum):
|
||||
BN_BD = "bn-BD"
|
||||
BN_IN = "bn-IN"
|
||||
|
||||
# Tibetan
|
||||
BO = "bo"
|
||||
|
||||
# Breton
|
||||
BR = "br"
|
||||
|
||||
# Bosnian
|
||||
BS = "bs"
|
||||
BS_BA = "bs-BA"
|
||||
@@ -158,6 +167,9 @@ class Language(StrEnum):
|
||||
FIL = "fil"
|
||||
FIL_PH = "fil-PH"
|
||||
|
||||
# Faroese
|
||||
FO = "fo"
|
||||
|
||||
# French
|
||||
FR = "fr"
|
||||
FR_BE = "fr-BE"
|
||||
@@ -177,6 +189,9 @@ class Language(StrEnum):
|
||||
GU = "gu"
|
||||
GU_IN = "gu-IN"
|
||||
|
||||
# Hausa
|
||||
HA = "ha"
|
||||
|
||||
# Hebrew
|
||||
HE = "he"
|
||||
HE_IL = "he-IL"
|
||||
@@ -189,6 +204,9 @@ class Language(StrEnum):
|
||||
HR = "hr"
|
||||
HR_HR = "hr-HR"
|
||||
|
||||
# Haitian Creole
|
||||
HT = "ht"
|
||||
|
||||
# Hungarian
|
||||
HU = "hu"
|
||||
HU_HU = "hu-HU"
|
||||
@@ -222,6 +240,7 @@ class Language(StrEnum):
|
||||
# Javanese
|
||||
JV = "jv"
|
||||
JV_ID = "jv-ID"
|
||||
JW = "jw" # Fal requires for Javanese
|
||||
|
||||
# Georgian
|
||||
KA = "ka"
|
||||
@@ -243,6 +262,15 @@ class Language(StrEnum):
|
||||
KO = "ko"
|
||||
KO_KR = "ko-KR"
|
||||
|
||||
# Latin
|
||||
LA = "la"
|
||||
|
||||
# Luxembourgish
|
||||
LB = "lb"
|
||||
|
||||
# Lingala
|
||||
LN = "ln"
|
||||
|
||||
# Lao
|
||||
LO = "lo"
|
||||
LO_LA = "lo-LA"
|
||||
@@ -255,6 +283,9 @@ class Language(StrEnum):
|
||||
LV = "lv"
|
||||
LV_LV = "lv-LV"
|
||||
|
||||
# Malagasy
|
||||
MG = "mg"
|
||||
|
||||
# Macedonian
|
||||
MK = "mk"
|
||||
MK_MK = "mk-MK"
|
||||
@@ -287,9 +318,10 @@ class Language(StrEnum):
|
||||
MY_MM = "my-MM"
|
||||
|
||||
# Norwegian
|
||||
NB = "nb"
|
||||
NB = "nb" # Norwegian Bokmål
|
||||
NB_NO = "nb-NO"
|
||||
NO = "no"
|
||||
NN = "nn" # Norwegian Nynorsk
|
||||
|
||||
# Nepali
|
||||
NE = "ne"
|
||||
@@ -300,6 +332,9 @@ class Language(StrEnum):
|
||||
NL_BE = "nl-BE"
|
||||
NL_NL = "nl-NL"
|
||||
|
||||
# Occitan
|
||||
OC = "oc"
|
||||
|
||||
# Odia
|
||||
OR = "or"
|
||||
OR_IN = "or-IN"
|
||||
@@ -329,6 +364,12 @@ class Language(StrEnum):
|
||||
RU = "ru"
|
||||
RU_RU = "ru-RU"
|
||||
|
||||
# Sanskrit
|
||||
SA = "sa"
|
||||
|
||||
# Sindhi
|
||||
SD = "sd"
|
||||
|
||||
# Sinhala
|
||||
SI = "si"
|
||||
SI_LK = "si-LK"
|
||||
@@ -341,6 +382,9 @@ class Language(StrEnum):
|
||||
SL = "sl"
|
||||
SL_SI = "sl-SI"
|
||||
|
||||
# Shona
|
||||
SN = "sn"
|
||||
|
||||
# Somali
|
||||
SO = "so"
|
||||
SO_SO = "so-SO"
|
||||
@@ -382,14 +426,23 @@ class Language(StrEnum):
|
||||
TE = "te"
|
||||
TE_IN = "te-IN"
|
||||
|
||||
# Tajik
|
||||
TG = "tg"
|
||||
|
||||
# Thai
|
||||
TH = "th"
|
||||
TH_TH = "th-TH"
|
||||
|
||||
# Turkmen
|
||||
TK = "tk"
|
||||
|
||||
# Turkish
|
||||
TR = "tr"
|
||||
TR_TR = "tr-TR"
|
||||
|
||||
# Tatar
|
||||
TT = "tt"
|
||||
|
||||
# Ukrainian
|
||||
UK = "uk"
|
||||
UK_UA = "uk-UA"
|
||||
@@ -411,6 +464,12 @@ class Language(StrEnum):
|
||||
WUU = "wuu"
|
||||
WUU_CN = "wuu-CN"
|
||||
|
||||
# Yiddish
|
||||
YI = "yi"
|
||||
|
||||
# Yoruba
|
||||
YO = "yo"
|
||||
|
||||
# Yue Chinese
|
||||
YUE = "yue"
|
||||
YUE_CN = "yue-CN"
|
||||
|
||||
Reference in New Issue
Block a user