diff --git a/CHANGELOG.md b/CHANGELOG.md
index cbde030da..695586744 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -47,6 +47,26 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- It is now possible to tell whether `UserStartedSpeakingFrame` or
`UserStoppedSpeakingFrame` have been generated because of emulation frames.
+### Changed
+
+- Pipecat services have been reorganized into packages. Each package can have
+ one or more of the following modules (in the future new module names might be
+ needed) depending on the services implemented:
+ - image: for image generation services
+ - llm: for LLM services
+ - memory: for memory services
+ - stt: for Speech-To-Text services
+ - tts: for Text-To-Speech services
+ - video: for video generation services
+ - vision: for video recognition services
+
+### Deprecated
+
+- The old flat Pipecat service imports (e.g. `pipecat.services.openai`) have
+  been deprecated and a warning will be shown when using them. The new import
+  path is `pipecat.services.[service].[image,llm,memory,stt,tts,video,vision]`.
+  For example, `from pipecat.services.openai.llm import OpenAILLMService`.
+
### Fixed
- Fixed an issue that would cause `SegmentedSTTService` based services
diff --git a/src/pipecat/services/__init__.py b/src/pipecat/services/__init__.py
index e69de29bb..d79c1793d 100644
--- a/src/pipecat/services/__init__.py
+++ b/src/pipecat/services/__init__.py
@@ -0,0 +1,70 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+from typing import Any, Dict
+
+
+def _warn_deprecated_access(globals: Dict[str, Any], attr, old: str, new: str):
+    """Warn that `pipecat.services.{old}` is deprecated and return `globals[attr]`.
+
+    Args:
+        globals: Namespace of the deprecated module, mapping names to objects.
+        attr: Name being looked up in `globals`.
+        old: Deprecated module path, relative to `pipecat.services`.
+        new: Replacement module path, relative to `pipecat.services`.
+
+    Returns:
+        The object stored under `attr` in `globals`.
+
+    Raises:
+        KeyError: If `attr` is not a key of `globals`.
+    """
+    import warnings
+
+    # Temporarily force the "always" action so the warning is issued on every
+    # access; the previous filters are restored when the `with` block exits.
+    with warnings.catch_warnings():
+        warnings.simplefilter("always")
+        warnings.warn(
+            f"Module `pipecat.services.{old}` is deprecated, use `pipecat.services.{new}` instead",
+            DeprecationWarning,
+            # Attribute the warning to the code that touched the deprecated
+            # module rather than to this helper: frame 1 is this function,
+            # frame 2 is `DeprecatedModuleProxy.__getattr__`, frame 3 its caller.
+            stacklevel=3,
+        )
+
+    return globals[attr]
+
+
+class DeprecatedModuleProxy:
+    """Stand-in for a deprecated module that warns whenever an attribute is read.
+
+    Service packages install an instance in `sys.modules` under their own name,
+    so names imported through the old path still resolve but emit a
+    `DeprecationWarning` that points at the new module path.
+
+    Args:
+        globals: Namespace of the module being replaced.
+        old: Deprecated module path, relative to `pipecat.services`.
+        new: Replacement module path, relative to `pipecat.services`.
+    """
+
+    def __init__(self, globals: Dict[str, Any], old: str, new: str):
+        self._globals = globals
+        self._old = old
+        self._new = new
+
+    def __getattr__(self, attr):
+        """Return `attr` from the wrapped namespace, warning about the deprecation.
+
+        Raises:
+            AttributeError: If the wrapped namespace has no such name.
+        """
+        if attr in self._globals:
+            return _warn_deprecated_access(self._globals, attr, self._old, self._new)
+        # Use the same fully-qualified module name as the deprecation warning.
+        raise AttributeError(f"module 'pipecat.services.{self._old}' has no attribute '{attr}'")
diff --git a/src/pipecat/services/anthropic/__init__.py b/src/pipecat/services/anthropic/__init__.py
new file mode 100644
index 000000000..37ffe8d99
--- /dev/null
+++ b/src/pipecat/services/anthropic/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .llm import *  # re-exported so the deprecated flat import path keeps resolving
+
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "anthropic", "anthropic.llm")
diff --git a/src/pipecat/services/anthropic.py b/src/pipecat/services/anthropic/llm.py
similarity index 100%
rename from src/pipecat/services/anthropic.py
rename to src/pipecat/services/anthropic/llm.py
diff --git a/src/pipecat/services/assemblyai/__init__.py b/src/pipecat/services/assemblyai/__init__.py
new file mode 100644
index 000000000..e31a2f393
--- /dev/null
+++ b/src/pipecat/services/assemblyai/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .stt import *  # re-exported so the deprecated flat import path keeps resolving
+
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "assemblyai", "assemblyai.stt")
diff --git a/src/pipecat/services/assemblyai.py b/src/pipecat/services/assemblyai/stt.py
similarity index 96%
rename from src/pipecat/services/assemblyai.py
rename to src/pipecat/services/assemblyai/stt.py
index 94f081c5b..87dd18bf4 100644
--- a/src/pipecat/services/assemblyai.py
+++ b/src/pipecat/services/assemblyai/stt.py
@@ -27,9 +27,7 @@ try:
from assemblyai import AudioEncoding
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
- logger.error(
- "In order to use AssemblyAI, you need to `pip install pipecat-ai[assemblyai]`. Also, set `ASSEMBLYAI_API_KEY` environment variable."
- )
+ logger.error("In order to use AssemblyAI, you need to `pip install pipecat-ai[assemblyai]`.")
raise Exception(f"Missing module: {e}")
diff --git a/src/pipecat/services/aws/__init__.py b/src/pipecat/services/aws/__init__.py
new file mode 100644
index 000000000..b36c88499
--- /dev/null
+++ b/src/pipecat/services/aws/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .tts import *  # re-exported so the deprecated flat import path keeps resolving
+
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "aws", "aws.tts")
diff --git a/src/pipecat/services/aws.py b/src/pipecat/services/aws/tts.py
similarity index 100%
rename from src/pipecat/services/aws.py
rename to src/pipecat/services/aws/tts.py
diff --git a/src/pipecat/services/azure.py b/src/pipecat/services/azure.py
deleted file mode 100644
index 9df1a8ef1..000000000
--- a/src/pipecat/services/azure.py
+++ /dev/null
@@ -1,813 +0,0 @@
-#
-# Copyright (c) 2024–2025, Daily
-#
-# SPDX-License-Identifier: BSD 2-Clause License
-#
-
-import asyncio
-import io
-from typing import AsyncGenerator, Optional
-
-import aiohttp
-from loguru import logger
-from openai import AsyncAzureOpenAI
-from PIL import Image
-from pydantic import BaseModel
-
-from pipecat.frames.frames import (
- CancelFrame,
- EndFrame,
- ErrorFrame,
- Frame,
- StartFrame,
- TranscriptionFrame,
- TTSAudioRawFrame,
- TTSStartedFrame,
- TTSStoppedFrame,
- URLImageRawFrame,
-)
-from pipecat.services.ai_services import ImageGenService, STTService, TTSService
-from pipecat.services.openai import (
- OpenAILLMService,
-)
-from pipecat.transcriptions.language import Language
-from pipecat.utils.time import time_now_iso8601
-
-# See .env.example for Azure configuration needed
-try:
- from azure.cognitiveservices.speech import (
- CancellationReason,
- ResultReason,
- ServicePropertyChannel,
- SpeechConfig,
- SpeechRecognizer,
- SpeechSynthesisOutputFormat,
- SpeechSynthesizer,
- )
- from azure.cognitiveservices.speech.audio import (
- AudioStreamFormat,
- PushAudioInputStream,
- )
- from azure.cognitiveservices.speech.dialog import AudioConfig
-except ModuleNotFoundError as e:
- logger.error(f"Exception: {e}")
- logger.error(
- "In order to use Azure, you need to `pip install pipecat-ai[azure]`. Also, set `AZURE_SPEECH_API_KEY` and `AZURE_SPEECH_REGION` environment variables."
- )
- raise Exception(f"Missing module: {e}")
-
-
-def language_to_azure_language(language: Language) -> Optional[str]:
- language_map = {
- # Afrikaans
- Language.AF: "af-ZA",
- Language.AF_ZA: "af-ZA",
- # Amharic
- Language.AM: "am-ET",
- Language.AM_ET: "am-ET",
- # Arabic
- Language.AR: "ar-AE", # Default to UAE Arabic
- Language.AR_AE: "ar-AE",
- Language.AR_BH: "ar-BH",
- Language.AR_DZ: "ar-DZ",
- Language.AR_EG: "ar-EG",
- Language.AR_IQ: "ar-IQ",
- Language.AR_JO: "ar-JO",
- Language.AR_KW: "ar-KW",
- Language.AR_LB: "ar-LB",
- Language.AR_LY: "ar-LY",
- Language.AR_MA: "ar-MA",
- Language.AR_OM: "ar-OM",
- Language.AR_QA: "ar-QA",
- Language.AR_SA: "ar-SA",
- Language.AR_SY: "ar-SY",
- Language.AR_TN: "ar-TN",
- Language.AR_YE: "ar-YE",
- # Assamese
- Language.AS: "as-IN",
- Language.AS_IN: "as-IN",
- # Azerbaijani
- Language.AZ: "az-AZ",
- Language.AZ_AZ: "az-AZ",
- # Bulgarian
- Language.BG: "bg-BG",
- Language.BG_BG: "bg-BG",
- # Bengali
- Language.BN: "bn-IN", # Default to Indian Bengali
- Language.BN_BD: "bn-BD",
- Language.BN_IN: "bn-IN",
- # Bosnian
- Language.BS: "bs-BA",
- Language.BS_BA: "bs-BA",
- # Catalan
- Language.CA: "ca-ES",
- Language.CA_ES: "ca-ES",
- # Czech
- Language.CS: "cs-CZ",
- Language.CS_CZ: "cs-CZ",
- # Welsh
- Language.CY: "cy-GB",
- Language.CY_GB: "cy-GB",
- # Danish
- Language.DA: "da-DK",
- Language.DA_DK: "da-DK",
- # German
- Language.DE: "de-DE",
- Language.DE_AT: "de-AT",
- Language.DE_CH: "de-CH",
- Language.DE_DE: "de-DE",
- # Greek
- Language.EL: "el-GR",
- Language.EL_GR: "el-GR",
- # English
- Language.EN: "en-US", # Default to US English
- Language.EN_AU: "en-AU",
- Language.EN_CA: "en-CA",
- Language.EN_GB: "en-GB",
- Language.EN_HK: "en-HK",
- Language.EN_IE: "en-IE",
- Language.EN_IN: "en-IN",
- Language.EN_KE: "en-KE",
- Language.EN_NG: "en-NG",
- Language.EN_NZ: "en-NZ",
- Language.EN_PH: "en-PH",
- Language.EN_SG: "en-SG",
- Language.EN_TZ: "en-TZ",
- Language.EN_US: "en-US",
- Language.EN_ZA: "en-ZA",
- # Spanish
- Language.ES: "es-ES", # Default to Spain Spanish
- Language.ES_AR: "es-AR",
- Language.ES_BO: "es-BO",
- Language.ES_CL: "es-CL",
- Language.ES_CO: "es-CO",
- Language.ES_CR: "es-CR",
- Language.ES_CU: "es-CU",
- Language.ES_DO: "es-DO",
- Language.ES_EC: "es-EC",
- Language.ES_ES: "es-ES",
- Language.ES_GQ: "es-GQ",
- Language.ES_GT: "es-GT",
- Language.ES_HN: "es-HN",
- Language.ES_MX: "es-MX",
- Language.ES_NI: "es-NI",
- Language.ES_PA: "es-PA",
- Language.ES_PE: "es-PE",
- Language.ES_PR: "es-PR",
- Language.ES_PY: "es-PY",
- Language.ES_SV: "es-SV",
- Language.ES_US: "es-US",
- Language.ES_UY: "es-UY",
- Language.ES_VE: "es-VE",
- # Estonian
- Language.ET: "et-EE",
- Language.ET_EE: "et-EE",
- # Basque
- Language.EU: "eu-ES",
- Language.EU_ES: "eu-ES",
- # Persian
- Language.FA: "fa-IR",
- Language.FA_IR: "fa-IR",
- # Finnish
- Language.FI: "fi-FI",
- Language.FI_FI: "fi-FI",
- # Filipino
- Language.FIL: "fil-PH",
- Language.FIL_PH: "fil-PH",
- # French
- Language.FR: "fr-FR",
- Language.FR_BE: "fr-BE",
- Language.FR_CA: "fr-CA",
- Language.FR_CH: "fr-CH",
- Language.FR_FR: "fr-FR",
- # Irish
- Language.GA: "ga-IE",
- Language.GA_IE: "ga-IE",
- # Galician
- Language.GL: "gl-ES",
- Language.GL_ES: "gl-ES",
- # Gujarati
- Language.GU: "gu-IN",
- Language.GU_IN: "gu-IN",
- # Hebrew
- Language.HE: "he-IL",
- Language.HE_IL: "he-IL",
- # Hindi
- Language.HI: "hi-IN",
- Language.HI_IN: "hi-IN",
- # Croatian
- Language.HR: "hr-HR",
- Language.HR_HR: "hr-HR",
- # Hungarian
- Language.HU: "hu-HU",
- Language.HU_HU: "hu-HU",
- # Armenian
- Language.HY: "hy-AM",
- Language.HY_AM: "hy-AM",
- # Indonesian
- Language.ID: "id-ID",
- Language.ID_ID: "id-ID",
- # Icelandic
- Language.IS: "is-IS",
- Language.IS_IS: "is-IS",
- # Italian
- Language.IT: "it-IT",
- Language.IT_IT: "it-IT",
- # Inuktitut
- Language.IU_CANS_CA: "iu-Cans-CA",
- Language.IU_LATN_CA: "iu-Latn-CA",
- # Japanese
- Language.JA: "ja-JP",
- Language.JA_JP: "ja-JP",
- # Javanese
- Language.JV: "jv-ID",
- Language.JV_ID: "jv-ID",
- # Georgian
- Language.KA: "ka-GE",
- Language.KA_GE: "ka-GE",
- # Kazakh
- Language.KK: "kk-KZ",
- Language.KK_KZ: "kk-KZ",
- # Khmer
- Language.KM: "km-KH",
- Language.KM_KH: "km-KH",
- # Kannada
- Language.KN: "kn-IN",
- Language.KN_IN: "kn-IN",
- # Korean
- Language.KO: "ko-KR",
- Language.KO_KR: "ko-KR",
- # Lao
- Language.LO: "lo-LA",
- Language.LO_LA: "lo-LA",
- # Lithuanian
- Language.LT: "lt-LT",
- Language.LT_LT: "lt-LT",
- # Latvian
- Language.LV: "lv-LV",
- Language.LV_LV: "lv-LV",
- # Macedonian
- Language.MK: "mk-MK",
- Language.MK_MK: "mk-MK",
- # Malayalam
- Language.ML: "ml-IN",
- Language.ML_IN: "ml-IN",
- # Mongolian
- Language.MN: "mn-MN",
- Language.MN_MN: "mn-MN",
- # Marathi
- Language.MR: "mr-IN",
- Language.MR_IN: "mr-IN",
- # Malay
- Language.MS: "ms-MY",
- Language.MS_MY: "ms-MY",
- # Maltese
- Language.MT: "mt-MT",
- Language.MT_MT: "mt-MT",
- # Burmese
- Language.MY: "my-MM",
- Language.MY_MM: "my-MM",
- # Norwegian
- Language.NB: "nb-NO",
- Language.NB_NO: "nb-NO",
- Language.NO: "nb-NO",
- # Nepali
- Language.NE: "ne-NP",
- Language.NE_NP: "ne-NP",
- # Dutch
- Language.NL: "nl-NL",
- Language.NL_BE: "nl-BE",
- Language.NL_NL: "nl-NL",
- # Odia
- Language.OR: "or-IN",
- Language.OR_IN: "or-IN",
- # Punjabi
- Language.PA: "pa-IN",
- Language.PA_IN: "pa-IN",
- # Polish
- Language.PL: "pl-PL",
- Language.PL_PL: "pl-PL",
- # Pashto
- Language.PS: "ps-AF",
- Language.PS_AF: "ps-AF",
- # Portuguese
- Language.PT: "pt-PT",
- Language.PT_BR: "pt-BR",
- Language.PT_PT: "pt-PT",
- # Romanian
- Language.RO: "ro-RO",
- Language.RO_RO: "ro-RO",
- # Russian
- Language.RU: "ru-RU",
- Language.RU_RU: "ru-RU",
- # Sinhala
- Language.SI: "si-LK",
- Language.SI_LK: "si-LK",
- # Slovak
- Language.SK: "sk-SK",
- Language.SK_SK: "sk-SK",
- # Slovenian
- Language.SL: "sl-SI",
- Language.SL_SI: "sl-SI",
- # Somali
- Language.SO: "so-SO",
- Language.SO_SO: "so-SO",
- # Albanian
- Language.SQ: "sq-AL",
- Language.SQ_AL: "sq-AL",
- # Serbian
- Language.SR: "sr-RS",
- Language.SR_RS: "sr-RS",
- Language.SR_LATN: "sr-Latn-RS",
- Language.SR_LATN_RS: "sr-Latn-RS",
- # Sundanese
- Language.SU: "su-ID",
- Language.SU_ID: "su-ID",
- # Swedish
- Language.SV: "sv-SE",
- Language.SV_SE: "sv-SE",
- # Swahili
- Language.SW: "sw-KE",
- Language.SW_KE: "sw-KE",
- Language.SW_TZ: "sw-TZ",
- # Tamil
- Language.TA: "ta-IN",
- Language.TA_IN: "ta-IN",
- Language.TA_LK: "ta-LK",
- Language.TA_MY: "ta-MY",
- Language.TA_SG: "ta-SG",
- # Telugu
- Language.TE: "te-IN",
- Language.TE_IN: "te-IN",
- # Thai
- Language.TH: "th-TH",
- Language.TH_TH: "th-TH",
- # Turkish
- Language.TR: "tr-TR",
- Language.TR_TR: "tr-TR",
- # Ukrainian
- Language.UK: "uk-UA",
- Language.UK_UA: "uk-UA",
- # Urdu
- Language.UR: "ur-IN",
- Language.UR_IN: "ur-IN",
- Language.UR_PK: "ur-PK",
- # Uzbek
- Language.UZ: "uz-UZ",
- Language.UZ_UZ: "uz-UZ",
- # Vietnamese
- Language.VI: "vi-VN",
- Language.VI_VN: "vi-VN",
- # Wu Chinese
- Language.WUU: "wuu-CN",
- Language.WUU_CN: "wuu-CN",
- # Yue Chinese
- Language.YUE: "yue-CN",
- Language.YUE_CN: "yue-CN",
- # Chinese
- Language.ZH: "zh-CN",
- Language.ZH_CN: "zh-CN",
- Language.ZH_CN_GUANGXI: "zh-CN-guangxi",
- Language.ZH_CN_HENAN: "zh-CN-henan",
- Language.ZH_CN_LIAONING: "zh-CN-liaoning",
- Language.ZH_CN_SHAANXI: "zh-CN-shaanxi",
- Language.ZH_CN_SHANDONG: "zh-CN-shandong",
- Language.ZH_CN_SICHUAN: "zh-CN-sichuan",
- Language.ZH_HK: "zh-HK",
- Language.ZH_TW: "zh-TW",
- # Zulu
- Language.ZU: "zu-ZA",
- Language.ZU_ZA: "zu-ZA",
- }
- return language_map.get(language)
-
-
-def sample_rate_to_output_format(sample_rate: int) -> SpeechSynthesisOutputFormat:
- sample_rate_map = {
- 8000: SpeechSynthesisOutputFormat.Raw8Khz16BitMonoPcm,
- 16000: SpeechSynthesisOutputFormat.Raw16Khz16BitMonoPcm,
- 22050: SpeechSynthesisOutputFormat.Raw22050Hz16BitMonoPcm,
- 24000: SpeechSynthesisOutputFormat.Raw24Khz16BitMonoPcm,
- 44100: SpeechSynthesisOutputFormat.Raw44100Hz16BitMonoPcm,
- 48000: SpeechSynthesisOutputFormat.Raw48Khz16BitMonoPcm,
- }
- return sample_rate_map.get(sample_rate, SpeechSynthesisOutputFormat.Raw24Khz16BitMonoPcm)
-
-
-class AzureLLMService(OpenAILLMService):
- """A service for interacting with Azure OpenAI using the OpenAI-compatible interface.
-
- This service extends OpenAILLMService to connect to Azure's OpenAI endpoint while
- maintaining full compatibility with OpenAI's interface and functionality.
-
- Args:
- api_key (str): The API key for accessing Azure OpenAI
- endpoint (str): The Azure endpoint URL
- model (str): The model identifier to use
- api_version (str, optional): Azure API version. Defaults to "2024-09-01-preview"
- **kwargs: Additional keyword arguments passed to OpenAILLMService
- """
-
- def __init__(
- self,
- *,
- api_key: str,
- endpoint: str,
- model: str,
- api_version: str = "2024-09-01-preview",
- **kwargs,
- ):
- # Initialize variables before calling parent __init__() because that
- # will call create_client() and we need those values there.
- self._endpoint = endpoint
- self._api_version = api_version
- super().__init__(api_key=api_key, model=model, **kwargs)
-
- def create_client(self, api_key=None, base_url=None, **kwargs):
- """Create OpenAI-compatible client for Azure OpenAI endpoint."""
- logger.debug(f"Creating Azure OpenAI client with endpoint {self._endpoint}")
- return AsyncAzureOpenAI(
- api_key=api_key,
- azure_endpoint=self._endpoint,
- api_version=self._api_version,
- )
-
-
-class AzureBaseTTSService(TTSService):
- class InputParams(BaseModel):
- emphasis: Optional[str] = None
- language: Optional[Language] = Language.EN_US
- pitch: Optional[str] = None
- rate: Optional[str] = "1.05"
- role: Optional[str] = None
- style: Optional[str] = None
- style_degree: Optional[str] = None
- volume: Optional[str] = None
-
- def __init__(
- self,
- *,
- api_key: str,
- region: str,
- voice="en-US-SaraNeural",
- sample_rate: Optional[int] = None,
- params: InputParams = InputParams(),
- **kwargs,
- ):
- super().__init__(sample_rate=sample_rate, **kwargs)
-
- self._settings = {
- "emphasis": params.emphasis,
- "language": self.language_to_service_language(params.language)
- if params.language
- else "en-US",
- "pitch": params.pitch,
- "rate": params.rate,
- "role": params.role,
- "style": params.style,
- "style_degree": params.style_degree,
- "volume": params.volume,
- }
-
- self._api_key = api_key
- self._region = region
- self._voice_id = voice
- self._speech_synthesizer = None
-
- def can_generate_metrics(self) -> bool:
- return True
-
- def language_to_service_language(self, language: Language) -> Optional[str]:
- return language_to_azure_language(language)
-
- def _construct_ssml(self, text: str) -> str:
- language = self._settings["language"]
- ssml = (
- f""
- f""
- ""
- )
-
- if self._settings["style"]:
- ssml += f""
-
- prosody_attrs = []
- if self._settings["rate"]:
- prosody_attrs.append(f"rate='{self._settings['rate']}'")
- if self._settings["pitch"]:
- prosody_attrs.append(f"pitch='{self._settings['pitch']}'")
- if self._settings["volume"]:
- prosody_attrs.append(f"volume='{self._settings['volume']}'")
-
- ssml += f""
-
- if self._settings["emphasis"]:
- ssml += f""
-
- ssml += text
-
- if self._settings["emphasis"]:
- ssml += ""
-
- ssml += ""
-
- if self._settings["style"]:
- ssml += ""
-
- ssml += ""
-
- return ssml
-
-
-class AzureTTSService(AzureBaseTTSService):
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
- self._speech_config = None
- self._speech_synthesizer = None
- self._audio_queue = asyncio.Queue()
-
- async def start(self, frame: StartFrame):
- await super().start(frame)
-
- if self._speech_config:
- return
-
- # Now self.sample_rate is properly initialized
- self._speech_config = SpeechConfig(
- subscription=self._api_key,
- region=self._region,
- speech_recognition_language=self._settings["language"],
- )
- self._speech_config.set_speech_synthesis_output_format(
- sample_rate_to_output_format(self.sample_rate)
- )
- self._speech_config.set_service_property(
- "synthesizer.synthesis.connection.synthesisConnectionImpl",
- "websocket",
- ServicePropertyChannel.UriQueryParameter,
- )
-
- self._speech_synthesizer = SpeechSynthesizer(
- speech_config=self._speech_config, audio_config=None
- )
-
- # Set up event handlers
- self._speech_synthesizer.synthesizing.connect(self._handle_synthesizing)
- self._speech_synthesizer.synthesis_completed.connect(self._handle_completed)
- self._speech_synthesizer.synthesis_canceled.connect(self._handle_canceled)
-
- def _handle_synthesizing(self, evt):
- """Handle audio chunks as they arrive"""
- if evt.result and evt.result.audio_data:
- self._audio_queue.put_nowait(evt.result.audio_data)
-
- def _handle_completed(self, evt):
- """Handle synthesis completion"""
- self._audio_queue.put_nowait(None) # Signal completion
-
- def _handle_canceled(self, evt):
- """Handle synthesis cancellation"""
- logger.error(f"Speech synthesis canceled: {evt.result.cancellation_details.reason}")
- self._audio_queue.put_nowait(None)
-
- async def flush_audio(self):
- logger.trace(f"{self}: flushing audio")
-
- async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
- logger.debug(f"{self}: Generating TTS [{text}]")
-
- try:
- if self._speech_synthesizer is None:
- error_msg = "Speech synthesizer not initialized."
- logger.error(error_msg)
- yield ErrorFrame(error_msg)
- return
-
- try:
- await self.start_ttfb_metrics()
- yield TTSStartedFrame()
-
- ssml = self._construct_ssml(text)
- self._speech_synthesizer.speak_ssml_async(ssml)
- await self.start_tts_usage_metrics(text)
-
- # Stream audio chunks as they arrive
- while True:
- chunk = await self._audio_queue.get()
- if chunk is None: # End of stream
- break
-
- await self.stop_ttfb_metrics()
- yield TTSAudioRawFrame(
- audio=chunk,
- sample_rate=self.sample_rate,
- num_channels=1,
- )
-
- yield TTSStoppedFrame()
-
- except Exception as e:
- logger.error(f"{self} error during synthesis: {e}")
- yield TTSStoppedFrame()
- # Could add reconnection logic here if needed
- return
-
- except Exception as e:
- logger.error(f"{self} exception: {e}")
-
-
-class AzureHttpTTSService(AzureBaseTTSService):
- def __init__(self, **kwargs):
- super().__init__(**kwargs)
- self._speech_config = None
- self._speech_synthesizer = None
-
- async def start(self, frame: StartFrame):
- await super().start(frame)
-
- if self._speech_config:
- return
-
- self._speech_config = SpeechConfig(
- subscription=self._api_key,
- region=self._region,
- speech_recognition_language=self._settings["language"],
- )
- self._speech_config.set_speech_synthesis_output_format(
- sample_rate_to_output_format(self.sample_rate)
- )
- self._speech_synthesizer = SpeechSynthesizer(
- speech_config=self._speech_config, audio_config=None
- )
-
- async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
- logger.debug(f"{self}: Generating TTS [{text}]")
-
- await self.start_ttfb_metrics()
-
- ssml = self._construct_ssml(text)
-
- result = await asyncio.to_thread(self._speech_synthesizer.speak_ssml, ssml)
-
- if result.reason == ResultReason.SynthesizingAudioCompleted:
- await self.start_tts_usage_metrics(text)
- await self.stop_ttfb_metrics()
- yield TTSStartedFrame()
- # Azure always sends a 44-byte header. Strip it off.
- yield TTSAudioRawFrame(
- audio=result.audio_data[44:],
- sample_rate=self.sample_rate,
- num_channels=1,
- )
- yield TTSStoppedFrame()
- elif result.reason == ResultReason.Canceled:
- cancellation_details = result.cancellation_details
- logger.warning(f"Speech synthesis canceled: {cancellation_details.reason}")
- if cancellation_details.reason == CancellationReason.Error:
- logger.error(f"{self} error: {cancellation_details.error_details}")
-
-
-class AzureSTTService(STTService):
- def __init__(
- self,
- *,
- api_key: str,
- region: str,
- language: Language = Language.EN_US,
- sample_rate: Optional[int] = None,
- **kwargs,
- ):
- super().__init__(sample_rate=sample_rate, **kwargs)
-
- self._speech_config = SpeechConfig(
- subscription=api_key,
- region=region,
- speech_recognition_language=language_to_azure_language(language),
- )
-
- self._audio_stream = None
- self._speech_recognizer = None
-
- async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
- await self.start_processing_metrics()
- if self._audio_stream:
- self._audio_stream.write(audio)
- await self.stop_processing_metrics()
- yield None
-
- async def start(self, frame: StartFrame):
- await super().start(frame)
-
- if self._audio_stream:
- return
-
- stream_format = AudioStreamFormat(samples_per_second=self.sample_rate, channels=1)
- self._audio_stream = PushAudioInputStream(stream_format)
-
- audio_config = AudioConfig(stream=self._audio_stream)
-
- self._speech_recognizer = SpeechRecognizer(
- speech_config=self._speech_config, audio_config=audio_config
- )
- self._speech_recognizer.recognized.connect(self._on_handle_recognized)
- self._speech_recognizer.start_continuous_recognition_async()
-
- async def stop(self, frame: EndFrame):
- await super().stop(frame)
-
- if self._speech_recognizer:
- self._speech_recognizer.stop_continuous_recognition_async()
-
- if self._audio_stream:
- self._audio_stream.close()
-
- async def cancel(self, frame: CancelFrame):
- await super().cancel(frame)
-
- if self._speech_recognizer:
- self._speech_recognizer.stop_continuous_recognition_async()
-
- if self._audio_stream:
- self._audio_stream.close()
-
- def _on_handle_recognized(self, event):
- if event.result.reason == ResultReason.RecognizedSpeech and len(event.result.text) > 0:
- frame = TranscriptionFrame(event.result.text, "", time_now_iso8601())
- asyncio.run_coroutine_threadsafe(self.push_frame(frame), self.get_event_loop())
-
-
-class AzureImageGenServiceREST(ImageGenService):
- def __init__(
- self,
- *,
- image_size: str,
- api_key: str,
- endpoint: str,
- model: str,
- aiohttp_session: aiohttp.ClientSession,
- api_version="2023-06-01-preview",
- ):
- super().__init__()
-
- self._api_key = api_key
- self._azure_endpoint = endpoint
- self._api_version = api_version
- self.set_model_name(model)
- self._image_size = image_size
- self._aiohttp_session = aiohttp_session
-
- async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
- url = f"{self._azure_endpoint}openai/images/generations:submit?api-version={self._api_version}"
-
- headers = {"api-key": self._api_key, "Content-Type": "application/json"}
-
- body = {
- # Enter your prompt text here
- "prompt": prompt,
- "size": self._image_size,
- "n": 1,
- }
-
- async with self._aiohttp_session.post(url, headers=headers, json=body) as submission:
- # We never get past this line, because this header isn't
- # defined on a 429 response, but something is eating our
- # exceptions!
- operation_location = submission.headers["operation-location"]
- status = ""
- attempts_left = 120
- json_response = None
- while status != "succeeded":
- attempts_left -= 1
- if attempts_left == 0:
- logger.error(f"{self} error: image generation timed out")
- yield ErrorFrame("Image generation timed out")
- return
-
- await asyncio.sleep(1)
-
- response = await self._aiohttp_session.get(operation_location, headers=headers)
-
- json_response = await response.json()
- status = json_response["status"]
-
- image_url = json_response["result"]["data"][0]["url"] if json_response else None
- if not image_url:
- logger.error(f"{self} error: image generation failed")
- yield ErrorFrame("Image generation failed")
- return
-
- # Load the image from the url
- async with self._aiohttp_session.get(image_url) as response:
- image_stream = io.BytesIO(await response.content.read())
- image = Image.open(image_stream)
- frame = URLImageRawFrame(
- url=image_url, image=image.tobytes(), size=image.size, format=image.format
- )
- yield frame
diff --git a/src/pipecat/services/azure/__init__.py b/src/pipecat/services/azure/__init__.py
new file mode 100644
index 000000000..26e96993d
--- /dev/null
+++ b/src/pipecat/services/azure/__init__.py
@@ -0,0 +1,18 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .image import *
+from .llm import *
+from .stt import *
+from .tts import *
+
+# The flat `pipecat.services.azure` module also exposed the image generation
+# service, so re-export every submodule before swapping in the warning proxy.
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "azure", "azure.[image,llm,stt,tts]")
diff --git a/src/pipecat/services/azure/common.py b/src/pipecat/services/azure/common.py
new file mode 100644
index 000000000..054463257
--- /dev/null
+++ b/src/pipecat/services/azure/common.py
@@ -0,0 +1,336 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+from typing import Optional
+
+from loguru import logger
+
+from pipecat.transcriptions.language import Language
+
+
+def language_to_azure_language(language: Language) -> Optional[str]:
+    """Return the Azure locale tag (e.g. "en-US") for `language`, or None if unmapped."""
+    return {
+        # Afrikaans
+        Language.AF: "af-ZA",
+        Language.AF_ZA: "af-ZA",
+        # Amharic
+        Language.AM: "am-ET",
+        Language.AM_ET: "am-ET",
+        # Arabic
+        Language.AR: "ar-AE",  # Default to UAE Arabic
+        Language.AR_AE: "ar-AE",
+        Language.AR_BH: "ar-BH",
+        Language.AR_DZ: "ar-DZ",
+        Language.AR_EG: "ar-EG",
+        Language.AR_IQ: "ar-IQ",
+        Language.AR_JO: "ar-JO",
+        Language.AR_KW: "ar-KW",
+        Language.AR_LB: "ar-LB",
+        Language.AR_LY: "ar-LY",
+        Language.AR_MA: "ar-MA",
+        Language.AR_OM: "ar-OM",
+        Language.AR_QA: "ar-QA",
+        Language.AR_SA: "ar-SA",
+        Language.AR_SY: "ar-SY",
+        Language.AR_TN: "ar-TN",
+        Language.AR_YE: "ar-YE",
+        # Assamese
+        Language.AS: "as-IN",
+        Language.AS_IN: "as-IN",
+        # Azerbaijani
+        Language.AZ: "az-AZ",
+        Language.AZ_AZ: "az-AZ",
+        # Bulgarian
+        Language.BG: "bg-BG",
+        Language.BG_BG: "bg-BG",
+        # Bengali
+        Language.BN: "bn-IN",  # Default to Indian Bengali
+        Language.BN_BD: "bn-BD",
+        Language.BN_IN: "bn-IN",
+        # Bosnian
+        Language.BS: "bs-BA",
+        Language.BS_BA: "bs-BA",
+        # Catalan
+        Language.CA: "ca-ES",
+        Language.CA_ES: "ca-ES",
+        # Czech
+        Language.CS: "cs-CZ",
+        Language.CS_CZ: "cs-CZ",
+        # Welsh
+        Language.CY: "cy-GB",
+        Language.CY_GB: "cy-GB",
+        # Danish
+        Language.DA: "da-DK",
+        Language.DA_DK: "da-DK",
+        # German
+        Language.DE: "de-DE",
+        Language.DE_AT: "de-AT",
+        Language.DE_CH: "de-CH",
+        Language.DE_DE: "de-DE",
+        # Greek
+        Language.EL: "el-GR",
+        Language.EL_GR: "el-GR",
+        # English
+        Language.EN: "en-US",  # Default to US English
+        Language.EN_AU: "en-AU",
+        Language.EN_CA: "en-CA",
+        Language.EN_GB: "en-GB",
+        Language.EN_HK: "en-HK",
+        Language.EN_IE: "en-IE",
+        Language.EN_IN: "en-IN",
+        Language.EN_KE: "en-KE",
+        Language.EN_NG: "en-NG",
+        Language.EN_NZ: "en-NZ",
+        Language.EN_PH: "en-PH",
+        Language.EN_SG: "en-SG",
+        Language.EN_TZ: "en-TZ",
+        Language.EN_US: "en-US",
+        Language.EN_ZA: "en-ZA",
+        # Spanish
+        Language.ES: "es-ES",  # Default to Spain Spanish
+        Language.ES_AR: "es-AR",
+        Language.ES_BO: "es-BO",
+        Language.ES_CL: "es-CL",
+        Language.ES_CO: "es-CO",
+        Language.ES_CR: "es-CR",
+        Language.ES_CU: "es-CU",
+        Language.ES_DO: "es-DO",
+        Language.ES_EC: "es-EC",
+        Language.ES_ES: "es-ES",
+        Language.ES_GQ: "es-GQ",
+        Language.ES_GT: "es-GT",
+        Language.ES_HN: "es-HN",
+        Language.ES_MX: "es-MX",
+        Language.ES_NI: "es-NI",
+        Language.ES_PA: "es-PA",
+        Language.ES_PE: "es-PE",
+        Language.ES_PR: "es-PR",
+        Language.ES_PY: "es-PY",
+        Language.ES_SV: "es-SV",
+        Language.ES_US: "es-US",
+        Language.ES_UY: "es-UY",
+        Language.ES_VE: "es-VE",
+        # Estonian
+        Language.ET: "et-EE",
+        Language.ET_EE: "et-EE",
+        # Basque
+        Language.EU: "eu-ES",
+        Language.EU_ES: "eu-ES",
+        # Persian
+        Language.FA: "fa-IR",
+        Language.FA_IR: "fa-IR",
+        # Finnish
+        Language.FI: "fi-FI",
+        Language.FI_FI: "fi-FI",
+        # Filipino
+        Language.FIL: "fil-PH",
+        Language.FIL_PH: "fil-PH",
+        # French
+        Language.FR: "fr-FR",
+        Language.FR_BE: "fr-BE",
+        Language.FR_CA: "fr-CA",
+        Language.FR_CH: "fr-CH",
+        Language.FR_FR: "fr-FR",
+        # Irish
+        Language.GA: "ga-IE",
+        Language.GA_IE: "ga-IE",
+        # Galician
+        Language.GL: "gl-ES",
+        Language.GL_ES: "gl-ES",
+        # Gujarati
+        Language.GU: "gu-IN",
+        Language.GU_IN: "gu-IN",
+        # Hebrew
+        Language.HE: "he-IL",
+        Language.HE_IL: "he-IL",
+        # Hindi
+        Language.HI: "hi-IN",
+        Language.HI_IN: "hi-IN",
+        # Croatian
+        Language.HR: "hr-HR",
+        Language.HR_HR: "hr-HR",
+        # Hungarian
+        Language.HU: "hu-HU",
+        Language.HU_HU: "hu-HU",
+        # Armenian
+        Language.HY: "hy-AM",
+        Language.HY_AM: "hy-AM",
+        # Indonesian
+        Language.ID: "id-ID",
+        Language.ID_ID: "id-ID",
+        # Icelandic
+        Language.IS: "is-IS",
+        Language.IS_IS: "is-IS",
+        # Italian
+        Language.IT: "it-IT",
+        Language.IT_IT: "it-IT",
+        # Inuktitut
+        Language.IU_CANS_CA: "iu-Cans-CA",
+        Language.IU_LATN_CA: "iu-Latn-CA",
+        # Japanese
+        Language.JA: "ja-JP",
+        Language.JA_JP: "ja-JP",
+        # Javanese
+        Language.JV: "jv-ID",
+        Language.JV_ID: "jv-ID",
+        # Georgian
+        Language.KA: "ka-GE",
+        Language.KA_GE: "ka-GE",
+        # Kazakh
+        Language.KK: "kk-KZ",
+        Language.KK_KZ: "kk-KZ",
+        # Khmer
+        Language.KM: "km-KH",
+        Language.KM_KH: "km-KH",
+        # Kannada
+        Language.KN: "kn-IN",
+        Language.KN_IN: "kn-IN",
+        # Korean
+        Language.KO: "ko-KR",
+        Language.KO_KR: "ko-KR",
+        # Lao
+        Language.LO: "lo-LA",
+        Language.LO_LA: "lo-LA",
+        # Lithuanian
+        Language.LT: "lt-LT",
+        Language.LT_LT: "lt-LT",
+        # Latvian
+        Language.LV: "lv-LV",
+        Language.LV_LV: "lv-LV",
+        # Macedonian
+        Language.MK: "mk-MK",
+        Language.MK_MK: "mk-MK",
+        # Malayalam
+        Language.ML: "ml-IN",
+        Language.ML_IN: "ml-IN",
+        # Mongolian
+        Language.MN: "mn-MN",
+        Language.MN_MN: "mn-MN",
+        # Marathi
+        Language.MR: "mr-IN",
+        Language.MR_IN: "mr-IN",
+        # Malay
+        Language.MS: "ms-MY",
+        Language.MS_MY: "ms-MY",
+        # Maltese
+        Language.MT: "mt-MT",
+        Language.MT_MT: "mt-MT",
+        # Burmese
+        Language.MY: "my-MM",
+        Language.MY_MM: "my-MM",
+        # Norwegian
+        Language.NB: "nb-NO",
+        Language.NB_NO: "nb-NO",
+        Language.NO: "nb-NO",
+        # Nepali
+        Language.NE: "ne-NP",
+        Language.NE_NP: "ne-NP",
+        # Dutch
+        Language.NL: "nl-NL",
+        Language.NL_BE: "nl-BE",
+        Language.NL_NL: "nl-NL",
+        # Odia
+        Language.OR: "or-IN",
+        Language.OR_IN: "or-IN",
+        # Punjabi
+        Language.PA: "pa-IN",
+        Language.PA_IN: "pa-IN",
+        # Polish
+        Language.PL: "pl-PL",
+        Language.PL_PL: "pl-PL",
+        # Pashto
+        Language.PS: "ps-AF",
+        Language.PS_AF: "ps-AF",
+        # Portuguese
+        Language.PT: "pt-PT",
+        Language.PT_BR: "pt-BR",
+        Language.PT_PT: "pt-PT",
+        # Romanian
+        Language.RO: "ro-RO",
+        Language.RO_RO: "ro-RO",
+        # Russian
+        Language.RU: "ru-RU",
+        Language.RU_RU: "ru-RU",
+        # Sinhala
+        Language.SI: "si-LK",
+        Language.SI_LK: "si-LK",
+        # Slovak
+        Language.SK: "sk-SK",
+        Language.SK_SK: "sk-SK",
+        # Slovenian
+        Language.SL: "sl-SI",
+        Language.SL_SI: "sl-SI",
+        # Somali
+        Language.SO: "so-SO",
+        Language.SO_SO: "so-SO",
+        # Albanian
+        Language.SQ: "sq-AL",
+        Language.SQ_AL: "sq-AL",
+        # Serbian
+        Language.SR: "sr-RS",
+        Language.SR_RS: "sr-RS",
+        Language.SR_LATN: "sr-Latn-RS",
+        Language.SR_LATN_RS: "sr-Latn-RS",
+        # Sundanese
+        Language.SU: "su-ID",
+        Language.SU_ID: "su-ID",
+        # Swedish
+        Language.SV: "sv-SE",
+        Language.SV_SE: "sv-SE",
+        # Swahili
+        Language.SW: "sw-KE",
+        Language.SW_KE: "sw-KE",
+        Language.SW_TZ: "sw-TZ",
+        # Tamil
+        Language.TA: "ta-IN",
+        Language.TA_IN: "ta-IN",
+        Language.TA_LK: "ta-LK",
+        Language.TA_MY: "ta-MY",
+        Language.TA_SG: "ta-SG",
+        # Telugu
+        Language.TE: "te-IN",
+        Language.TE_IN: "te-IN",
+        # Thai
+        Language.TH: "th-TH",
+        Language.TH_TH: "th-TH",
+        # Turkish
+        Language.TR: "tr-TR",
+        Language.TR_TR: "tr-TR",
+        # Ukrainian
+        Language.UK: "uk-UA",
+        Language.UK_UA: "uk-UA",
+        # Urdu
+        Language.UR: "ur-IN",
+        Language.UR_IN: "ur-IN",
+        Language.UR_PK: "ur-PK",
+        # Uzbek
+        Language.UZ: "uz-UZ",
+        Language.UZ_UZ: "uz-UZ",
+        # Vietnamese
+        Language.VI: "vi-VN",
+        Language.VI_VN: "vi-VN",
+        # Wu Chinese
+        Language.WUU: "wuu-CN",
+        Language.WUU_CN: "wuu-CN",
+        # Yue Chinese
+        Language.YUE: "yue-CN",
+        Language.YUE_CN: "yue-CN",
+        # Chinese
+        Language.ZH: "zh-CN",
+        Language.ZH_CN: "zh-CN",
+        Language.ZH_CN_GUANGXI: "zh-CN-guangxi",
+        Language.ZH_CN_HENAN: "zh-CN-henan",
+        Language.ZH_CN_LIAONING: "zh-CN-liaoning",
+        Language.ZH_CN_SHAANXI: "zh-CN-shaanxi",
+        Language.ZH_CN_SHANDONG: "zh-CN-shandong",
+        Language.ZH_CN_SICHUAN: "zh-CN-sichuan",
+        Language.ZH_HK: "zh-HK",
+        Language.ZH_TW: "zh-TW",
+        # Zulu
+        Language.ZU: "zu-ZA",
+        Language.ZU_ZA: "zu-ZA",
+    }.get(language)
diff --git a/src/pipecat/services/azure/image.py b/src/pipecat/services/azure/image.py
new file mode 100644
index 000000000..d86c6075b
--- /dev/null
+++ b/src/pipecat/services/azure/image.py
@@ -0,0 +1,86 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import asyncio
+import io
+from typing import AsyncGenerator
+
+import aiohttp
+from loguru import logger
+from PIL import Image
+
+from pipecat.frames.frames import ErrorFrame, Frame, URLImageRawFrame
+from pipecat.services.ai_services import ImageGenService
+
+
+class AzureImageGenServiceREST(ImageGenService):
+    """Image generation service that talks to the Azure OpenAI REST API.
+
+    A generation job is submitted with a POST request, the operation URL
+    returned in the `operation-location` response header is polled once per
+    second until the job succeeds, and the resulting image is downloaded and
+    emitted as a `URLImageRawFrame`.
+
+    Args:
+        image_size: Value sent as the `size` field of the generation request.
+        api_key: Key sent in the `api-key` request header.
+        endpoint: Azure endpoint URL. The request path is appended to it
+            verbatim, so it is expected to end with a slash.
+        model: Model name registered through `set_model_name()`.
+        aiohttp_session: Session used for every HTTP request.
+        api_version: Value of the `api-version` query parameter.
+    """
+
+    def __init__(
+        self,
+        *,
+        image_size: str,
+        api_key: str,
+        endpoint: str,
+        model: str,
+        aiohttp_session: aiohttp.ClientSession,
+        api_version="2023-06-01-preview",
+    ):
+        super().__init__()
+
+        self._api_key = api_key
+        self._azure_endpoint = endpoint
+        self._api_version = api_version
+        self.set_model_name(model)
+        self._image_size = image_size
+        self._aiohttp_session = aiohttp_session
+
+    async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
+        """Generate one image for `prompt`.
+
+        Yields:
+            A `URLImageRawFrame` holding the generated image, or an
+            `ErrorFrame` if the submission is rejected, the job ends without
+            succeeding, or the job does not finish in time.
+        """
+        url = f"{self._azure_endpoint}openai/images/generations:submit?api-version={self._api_version}"
+
+        headers = {"api-key": self._api_key, "Content-Type": "application/json"}
+
+        body = {
+            "prompt": prompt,
+            "size": self._image_size,
+            "n": 1,
+        }
+
+        async with self._aiohttp_session.post(url, headers=headers, json=body) as submission:
+            # A rejected submission (e.g. a 429 response) carries no
+            # operation-location header; report it instead of raising KeyError.
+            operation_location = submission.headers.get("operation-location")
+            if not operation_location:
+                logger.error(
+                    f"{self} error: image generation request rejected (status {submission.status})"
+                )
+                yield ErrorFrame("Image generation failed")
+                return
+
+            status = ""
+            attempts_left = 120
+            json_response = None
+            while status != "succeeded":
+                attempts_left -= 1
+                if attempts_left == 0:
+                    logger.error(f"{self} error: image generation timed out")
+                    yield ErrorFrame("Image generation timed out")
+                    return
+
+                await asyncio.sleep(1)
+
+                # Use a context manager so every polling response is released
+                # back to the session's connection pool.
+                async with self._aiohttp_session.get(
+                    operation_location, headers=headers
+                ) as response:
+                    json_response = await response.json()
+
+                status = json_response["status"]
+                # Stop polling once the job reports a terminal non-success
+                # state; otherwise we would spin until the timeout above.
+                if status in ("failed", "canceled", "deleted"):
+                    logger.error(f"{self} error: image generation ended with status {status}")
+                    yield ErrorFrame("Image generation failed")
+                    return
+
+            image_url = json_response["result"]["data"][0]["url"] if json_response else None
+            if not image_url:
+                logger.error(f"{self} error: image generation failed")
+                yield ErrorFrame("Image generation failed")
+                return
+
+            # Load the image from the url
+            async with self._aiohttp_session.get(image_url) as response:
+                image_stream = io.BytesIO(await response.content.read())
+                image = Image.open(image_stream)
+                frame = URLImageRawFrame(
+                    url=image_url, image=image.tobytes(), size=image.size, format=image.format
+                )
+                yield frame
diff --git a/src/pipecat/services/azure/llm.py b/src/pipecat/services/azure/llm.py
new file mode 100644
index 000000000..295a1f1c1
--- /dev/null
+++ b/src/pipecat/services/azure/llm.py
@@ -0,0 +1,49 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+from loguru import logger
+from openai import AsyncAzureOpenAI
+
+from pipecat.services.openai.llm import OpenAILLMService
+
+
+class AzureLLMService(OpenAILLMService):
+    """LLM service that targets Azure OpenAI through the OpenAI-compatible API.
+
+    Everything is inherited from OpenAILLMService except client creation,
+    which builds an `AsyncAzureOpenAI` client bound to the configured Azure
+    endpoint and API version.
+
+    Args:
+        api_key (str): The API key for accessing Azure OpenAI
+        endpoint (str): The Azure endpoint URL
+        model (str): The model identifier to use
+        api_version (str, optional): Azure API version. Defaults to "2024-09-01-preview"
+        **kwargs: Additional keyword arguments passed to OpenAILLMService
+    """
+
+    def __init__(
+        self,
+        *,
+        api_key: str,
+        endpoint: str,
+        model: str,
+        api_version: str = "2024-09-01-preview",
+        **kwargs,
+    ):
+        # The parent __init__() calls create_client(), which reads both of
+        # these attributes, so they must be set before delegating upwards.
+        self._endpoint = endpoint
+        self._api_version = api_version
+        super().__init__(api_key=api_key, model=model, **kwargs)
+
+    def create_client(self, api_key=None, base_url=None, **kwargs):
+        """Build the Azure OpenAI client.
+
+        Only `api_key` is used; `base_url` and any extra keyword arguments are
+        accepted for signature compatibility and ignored.
+        """
+        logger.debug(f"Creating Azure OpenAI client with endpoint {self._endpoint}")
+        client_options = {
+            "api_key": api_key,
+            "azure_endpoint": self._endpoint,
+            "api_version": self._api_version,
+        }
+        return AsyncAzureOpenAI(**client_options)
diff --git a/src/pipecat/services/azure/stt.py b/src/pipecat/services/azure/stt.py
new file mode 100644
index 000000000..3e1302029
--- /dev/null
+++ b/src/pipecat/services/azure/stt.py
@@ -0,0 +1,107 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import asyncio
+from typing import AsyncGenerator, Optional
+
+from loguru import logger
+
+from pipecat.frames.frames import (
+ CancelFrame,
+ EndFrame,
+ Frame,
+ StartFrame,
+ TranscriptionFrame,
+)
+from pipecat.services.ai_services import STTService
+from pipecat.services.azure.common import language_to_azure_language
+from pipecat.transcriptions.language import Language
+from pipecat.utils.time import time_now_iso8601
+
+try:
+ from azure.cognitiveservices.speech import (
+ ResultReason,
+ SpeechConfig,
+ SpeechRecognizer,
+ )
+ from azure.cognitiveservices.speech.audio import (
+ AudioStreamFormat,
+ PushAudioInputStream,
+ )
+ from azure.cognitiveservices.speech.dialog import AudioConfig
+except ModuleNotFoundError as e:
+ logger.error(f"Exception: {e}")
+ logger.error("In order to use Azure, you need to `pip install pipecat-ai[azure]`.")
+ raise Exception(f"Missing module: {e}")
+
+
+class AzureSTTService(STTService):
+    """Speech-to-text service built on the Azure Speech SDK recognizer.
+
+    Audio passed to `run_stt()` is written to a push stream. Recognized text is
+    delivered through the recognizer's `recognized` callback and pushed as a
+    `TranscriptionFrame`.
+
+    Args:
+        api_key: Azure subscription key.
+        region: Azure region.
+        language: Recognition language, converted with
+            `language_to_azure_language()`.
+        sample_rate: Optional sample rate forwarded to the base service.
+        **kwargs: Additional keyword arguments passed to STTService.
+    """
+
+    def __init__(
+        self,
+        *,
+        api_key: str,
+        region: str,
+        language: Language = Language.EN_US,
+        sample_rate: Optional[int] = None,
+        **kwargs,
+    ):
+        super().__init__(sample_rate=sample_rate, **kwargs)
+
+        azure_language = language_to_azure_language(language)
+        self._speech_config = SpeechConfig(
+            subscription=api_key,
+            region=region,
+            speech_recognition_language=azure_language,
+        )
+
+        # Both are created lazily in start().
+        self._audio_stream = None
+        self._speech_recognizer = None
+
+    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
+        """Feed `audio` to the push stream; results arrive via the callback."""
+        await self.start_processing_metrics()
+        stream = self._audio_stream
+        if stream:
+            stream.write(audio)
+        await self.stop_processing_metrics()
+        yield None
+
+    async def start(self, frame: StartFrame):
+        """Create the push stream and start continuous recognition (once)."""
+        await super().start(frame)
+
+        if self._audio_stream:
+            return
+
+        self._audio_stream = PushAudioInputStream(
+            AudioStreamFormat(samples_per_second=self.sample_rate, channels=1)
+        )
+
+        recognizer = SpeechRecognizer(
+            speech_config=self._speech_config,
+            audio_config=AudioConfig(stream=self._audio_stream),
+        )
+        self._speech_recognizer = recognizer
+        recognizer.recognized.connect(self._on_handle_recognized)
+        recognizer.start_continuous_recognition_async()
+
+    async def stop(self, frame: EndFrame):
+        """Stop recognition and close the audio stream."""
+        await super().stop(frame)
+        self._shutdown_recognition()
+
+    async def cancel(self, frame: CancelFrame):
+        """Stop recognition and close the audio stream."""
+        await super().cancel(frame)
+        self._shutdown_recognition()
+
+    def _shutdown_recognition(self):
+        # Teardown shared by stop() and cancel().
+        if self._speech_recognizer:
+            self._speech_recognizer.stop_continuous_recognition_async()
+
+        if self._audio_stream:
+            self._audio_stream.close()
+
+    def _on_handle_recognized(self, event):
+        """Recognizer callback: push non-empty recognized text as a frame."""
+        result = event.result
+        if result.reason != ResultReason.RecognizedSpeech or len(result.text) == 0:
+            return
+
+        frame = TranscriptionFrame(result.text, "", time_now_iso8601())
+        # Schedule the push on the service's event loop.
+        asyncio.run_coroutine_threadsafe(self.push_frame(frame), self.get_event_loop())
diff --git a/src/pipecat/services/azure/tts.py b/src/pipecat/services/azure/tts.py
new file mode 100644
index 000000000..1227a4e96
--- /dev/null
+++ b/src/pipecat/services/azure/tts.py
@@ -0,0 +1,290 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import asyncio
+from typing import AsyncGenerator, Optional
+
+from loguru import logger
+from pydantic import BaseModel
+
+from pipecat.frames.frames import (
+ ErrorFrame,
+ Frame,
+ StartFrame,
+ TTSAudioRawFrame,
+ TTSStartedFrame,
+ TTSStoppedFrame,
+)
+from pipecat.services.ai_services import TTSService
+from pipecat.services.azure.common import language_to_azure_language
+from pipecat.transcriptions.language import Language
+
+try:
+ from azure.cognitiveservices.speech import (
+ CancellationReason,
+ ResultReason,
+ ServicePropertyChannel,
+ SpeechConfig,
+ SpeechSynthesisOutputFormat,
+ SpeechSynthesizer,
+ )
+except ModuleNotFoundError as e:
+ logger.error(f"Exception: {e}")
+ logger.error("In order to use Azure, you need to `pip install pipecat-ai[azure]`.")
+ raise Exception(f"Missing module: {e}")
+
+
+def sample_rate_to_output_format(sample_rate: int) -> SpeechSynthesisOutputFormat:
+    """Return the raw 16-bit mono PCM output format for `sample_rate`.
+
+    Sample rates without a dedicated entry fall back to the 24 kHz format.
+    """
+    formats = SpeechSynthesisOutputFormat
+    supported = {
+        8000: formats.Raw8Khz16BitMonoPcm,
+        16000: formats.Raw16Khz16BitMonoPcm,
+        22050: formats.Raw22050Hz16BitMonoPcm,
+        24000: formats.Raw24Khz16BitMonoPcm,
+        44100: formats.Raw44100Hz16BitMonoPcm,
+        48000: formats.Raw48Khz16BitMonoPcm,
+    }
+    try:
+        return supported[sample_rate]
+    except KeyError:
+        return formats.Raw24Khz16BitMonoPcm
+
+
+class AzureBaseTTSService(TTSService):
+    """Shared configuration and SSML construction for the Azure TTS services.
+
+    Args:
+        api_key: Azure subscription key.
+        region: Azure region.
+        voice: Voice name placed in the SSML `voice` element.
+        sample_rate: Optional sample rate forwarded to the base service.
+        params: Voice/prosody settings, see `InputParams`.
+        **kwargs: Additional keyword arguments passed to TTSService.
+    """
+
+    class InputParams(BaseModel):
+        # Each field, when set, is copied into the SSML built by
+        # `_construct_ssml()`.
+        emphasis: Optional[str] = None
+        language: Optional[Language] = Language.EN_US
+        pitch: Optional[str] = None
+        rate: Optional[str] = "1.05"
+        role: Optional[str] = None
+        style: Optional[str] = None
+        style_degree: Optional[str] = None
+        volume: Optional[str] = None
+
+    def __init__(
+        self,
+        *,
+        api_key: str,
+        region: str,
+        voice="en-US-SaraNeural",
+        sample_rate: Optional[int] = None,
+        params: InputParams = InputParams(),
+        **kwargs,
+    ):
+        super().__init__(sample_rate=sample_rate, **kwargs)
+
+        self._settings = {
+            "emphasis": params.emphasis,
+            "language": self.language_to_service_language(params.language)
+            if params.language
+            else "en-US",
+            "pitch": params.pitch,
+            "rate": params.rate,
+            "role": params.role,
+            "style": params.style,
+            "style_degree": params.style_degree,
+            "volume": params.volume,
+        }
+
+        self._api_key = api_key
+        self._region = region
+        self._voice_id = voice
+        self._speech_synthesizer = None
+
+    def can_generate_metrics(self) -> bool:
+        return True
+
+    def language_to_service_language(self, language: Language) -> Optional[str]:
+        """Convert a `Language` to the Azure locale string, if one is mapped."""
+        return language_to_azure_language(language)
+
+    def _construct_ssml(self, text: str) -> str:
+        """Wrap `text` in an SSML document reflecting the current settings.
+
+        The markup nests as speak > voice > [mstts:express-as] > prosody >
+        [emphasis] > text; the bracketed elements are only emitted when the
+        corresponding setting is present. `text` is inserted verbatim.
+        """
+        # NOTE(review): the markup literals below were reconstructed (they were
+        # empty strings, which made this method return `text` unchanged and left
+        # `prosody_attrs` unused) — confirm against the Azure SSML reference.
+        language = self._settings["language"]
+        ssml = (
+            f"<speak version='1.0' xml:lang='{language}' "
+            "xmlns='http://www.w3.org/2001/10/synthesis' "
+            "xmlns:mstts='http://www.w3.org/2001/mstts'>"
+            f"<voice name='{self._voice_id}'>"
+            "<mstts:silence type='Sentenceboundary' value='20ms' />"
+        )
+
+        if self._settings["style"]:
+            ssml += f"<mstts:express-as style='{self._settings['style']}'"
+            if self._settings["style_degree"]:
+                ssml += f" styledegree='{self._settings['style_degree']}'"
+            if self._settings["role"]:
+                ssml += f" role='{self._settings['role']}'"
+            ssml += ">"
+
+        prosody_attrs = []
+        if self._settings["rate"]:
+            prosody_attrs.append(f"rate='{self._settings['rate']}'")
+        if self._settings["pitch"]:
+            prosody_attrs.append(f"pitch='{self._settings['pitch']}'")
+        if self._settings["volume"]:
+            prosody_attrs.append(f"volume='{self._settings['volume']}'")
+
+        ssml += f"<prosody {' '.join(prosody_attrs)}>"
+
+        if self._settings["emphasis"]:
+            ssml += f"<emphasis level='{self._settings['emphasis']}'>"
+
+        ssml += text
+
+        # Close the elements in the reverse order they were opened.
+        if self._settings["emphasis"]:
+            ssml += "</emphasis>"
+
+        ssml += "</prosody>"
+
+        if self._settings["style"]:
+            ssml += "</mstts:express-as>"
+
+        ssml += "</voice></speak>"
+
+        return ssml
+
+
+class AzureTTSService(AzureBaseTTSService):
+ """Azure TTS service that emits synthesized audio chunk by chunk.
+
+ Synthesizer event handlers put audio data on an asyncio queue and
+ `run_tts()` drains that queue, yielding one `TTSAudioRawFrame` per chunk.
+ """
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ # Both are created in start().
+ self._speech_config = None
+ self._speech_synthesizer = None
+ # Holds audio chunks from the synthesizer handlers; a `None` item marks
+ # the end of the current synthesis.
+ self._audio_queue = asyncio.Queue()
+
+ async def start(self, frame: StartFrame):
+ """Create the speech config and synthesizer and connect the handlers.
+
+ Does nothing beyond the parent `start()` if the config already exists.
+ """
+ await super().start(frame)
+
+ if self._speech_config:
+ return
+
+ # Now self.sample_rate is properly initialized
+ self._speech_config = SpeechConfig(
+ subscription=self._api_key,
+ region=self._region,
+ speech_recognition_language=self._settings["language"],
+ )
+ self._speech_config.set_speech_synthesis_output_format(
+ sample_rate_to_output_format(self.sample_rate)
+ )
+ # Presumably selects a websocket-based synthesis connection — confirm
+ # against the Azure Speech SDK documentation.
+ self._speech_config.set_service_property(
+ "synthesizer.synthesis.connection.synthesisConnectionImpl",
+ "websocket",
+ ServicePropertyChannel.UriQueryParameter,
+ )
+
+ self._speech_synthesizer = SpeechSynthesizer(
+ speech_config=self._speech_config, audio_config=None
+ )
+
+ # Set up event handlers
+ # NOTE(review): these handlers call asyncio.Queue.put_nowait directly; if
+ # the SDK invokes them from its own thread this is not thread-safe —
+ # confirm which thread delivers the events.
+ self._speech_synthesizer.synthesizing.connect(self._handle_synthesizing)
+ self._speech_synthesizer.synthesis_completed.connect(self._handle_completed)
+ self._speech_synthesizer.synthesis_canceled.connect(self._handle_canceled)
+
+ def _handle_synthesizing(self, evt):
+ """Handle audio chunks as they arrive"""
+ if evt.result and evt.result.audio_data:
+ self._audio_queue.put_nowait(evt.result.audio_data)
+
+ def _handle_completed(self, evt):
+ """Handle synthesis completion"""
+ self._audio_queue.put_nowait(None) # Signal completion
+
+ def _handle_canceled(self, evt):
+ """Handle synthesis cancellation"""
+ logger.error(f"Speech synthesis canceled: {evt.result.cancellation_details.reason}")
+ # Unblock run_tts() the same way a normal completion does.
+ self._audio_queue.put_nowait(None)
+
+ async def flush_audio(self):
+ """Log the flush request; nothing else is done here."""
+ logger.trace(f"{self}: flushing audio")
+
+ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
+ """Synthesize `text` and yield the resulting frames.
+
+ Yields `TTSStartedFrame`, then one `TTSAudioRawFrame` per queued chunk,
+ then `TTSStoppedFrame`. Yields a single `ErrorFrame` instead if the
+ synthesizer has not been created yet.
+ """
+ logger.debug(f"{self}: Generating TTS [{text}]")
+
+ try:
+ if self._speech_synthesizer is None:
+ error_msg = "Speech synthesizer not initialized."
+ logger.error(error_msg)
+ yield ErrorFrame(error_msg)
+ return
+
+ try:
+ await self.start_ttfb_metrics()
+ yield TTSStartedFrame()
+
+ ssml = self._construct_ssml(text)
+ # The returned future is not awaited; audio is consumed from the
+ # queue filled by the event handlers.
+ self._speech_synthesizer.speak_ssml_async(ssml)
+ await self.start_tts_usage_metrics(text)
+
+ # Stream audio chunks as they arrive
+ while True:
+ chunk = await self._audio_queue.get()
+ if chunk is None: # End of stream
+ break
+
+ await self.stop_ttfb_metrics()
+ yield TTSAudioRawFrame(
+ audio=chunk,
+ sample_rate=self.sample_rate,
+ num_channels=1,
+ )
+
+ yield TTSStoppedFrame()
+
+ except Exception as e:
+ logger.error(f"{self} error during synthesis: {e}")
+ # Still close the TTS segment that TTSStartedFrame opened.
+ yield TTSStoppedFrame()
+ # Could add reconnection logic here if needed
+ return
+
+ except Exception as e:
+ logger.error(f"{self} exception: {e}")
+
+
+class AzureHttpTTSService(AzureBaseTTSService):
+    """Azure TTS service that synthesizes each utterance in one blocking call.
+
+    `speak_ssml()` is run in a worker thread and the complete audio result is
+    emitted as a single `TTSAudioRawFrame`.
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        # Both are created in start().
+        self._speech_config = None
+        self._speech_synthesizer = None
+
+    async def start(self, frame: StartFrame):
+        """Create the speech config and synthesizer on first start."""
+        await super().start(frame)
+
+        if self._speech_config:
+            return
+
+        self._speech_config = SpeechConfig(
+            subscription=self._api_key,
+            region=self._region,
+            speech_recognition_language=self._settings["language"],
+        )
+        self._speech_config.set_speech_synthesis_output_format(
+            sample_rate_to_output_format(self.sample_rate)
+        )
+        self._speech_synthesizer = SpeechSynthesizer(
+            speech_config=self._speech_config, audio_config=None
+        )
+
+    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
+        """Synthesize `text` and yield the resulting frames.
+
+        Yields `TTSStartedFrame`, one `TTSAudioRawFrame` and `TTSStoppedFrame`
+        on success, or a single `ErrorFrame` if the synthesizer has not been
+        created yet. A canceled synthesis is logged and yields nothing.
+        """
+        logger.debug(f"{self}: Generating TTS [{text}]")
+
+        # The synthesizer only exists after start(); without this guard the
+        # call below would fail with AttributeError on None.
+        if self._speech_synthesizer is None:
+            error_msg = "Speech synthesizer not initialized."
+            logger.error(error_msg)
+            yield ErrorFrame(error_msg)
+            return
+
+        await self.start_ttfb_metrics()
+
+        ssml = self._construct_ssml(text)
+
+        # speak_ssml() blocks, so keep it off the event loop.
+        result = await asyncio.to_thread(self._speech_synthesizer.speak_ssml, ssml)
+
+        if result.reason == ResultReason.SynthesizingAudioCompleted:
+            await self.start_tts_usage_metrics(text)
+            await self.stop_ttfb_metrics()
+            yield TTSStartedFrame()
+            # Azure always sends a 44-byte header. Strip it off.
+            yield TTSAudioRawFrame(
+                audio=result.audio_data[44:],
+                sample_rate=self.sample_rate,
+                num_channels=1,
+            )
+            yield TTSStoppedFrame()
+        elif result.reason == ResultReason.Canceled:
+            cancellation_details = result.cancellation_details
+            logger.warning(f"Speech synthesis canceled: {cancellation_details.reason}")
+            if cancellation_details.reason == CancellationReason.Error:
+                logger.error(f"{self} error: {cancellation_details.error_details}")
diff --git a/src/pipecat/services/canonical/__init__.py b/src/pipecat/services/canonical/__init__.py
new file mode 100644
index 000000000..f47b99c4e
--- /dev/null
+++ b/src/pipecat/services/canonical/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .metrics import *
+
+# Replace this package's sys.modules entry with a DeprecatedModuleProxy built
+# from the names star-imported above; the proxy is given the old and new module
+# names, presumably to warn when the old import path is used.
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "canonical", "canonical.metrics")
diff --git a/src/pipecat/services/canonical.py b/src/pipecat/services/canonical/metrics.py
similarity index 100%
rename from src/pipecat/services/canonical.py
rename to src/pipecat/services/canonical/metrics.py
diff --git a/src/pipecat/services/cartesia/__init__.py b/src/pipecat/services/cartesia/__init__.py
new file mode 100644
index 000000000..56c789743
--- /dev/null
+++ b/src/pipecat/services/cartesia/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .tts import *
+
+# Replace this package's sys.modules entry with a DeprecatedModuleProxy built
+# from the names star-imported above; the proxy is given the old and new module
+# names, presumably to warn when the old import path is used.
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "cartesia", "cartesia.tts")
diff --git a/src/pipecat/services/cartesia.py b/src/pipecat/services/cartesia/tts.py
similarity index 98%
rename from src/pipecat/services/cartesia.py
rename to src/pipecat/services/cartesia/tts.py
index 11796464d..45c2fcaf8 100644
--- a/src/pipecat/services/cartesia.py
+++ b/src/pipecat/services/cartesia/tts.py
@@ -35,9 +35,7 @@ try:
from cartesia import AsyncCartesia
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
- logger.error(
- "In order to use Cartesia, you need to `pip install pipecat-ai[cartesia]`. Also, set `CARTESIA_API_KEY` environment variable."
- )
+ logger.error("In order to use Cartesia, you need to `pip install pipecat-ai[cartesia]`.")
raise Exception(f"Missing module: {e}")
diff --git a/src/pipecat/services/cerebras/__init__.py b/src/pipecat/services/cerebras/__init__.py
new file mode 100644
index 000000000..726a227de
--- /dev/null
+++ b/src/pipecat/services/cerebras/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .llm import *
+
+# Replace this package's sys.modules entry with a DeprecatedModuleProxy built
+# from the names star-imported above; the proxy is given the old and new module
+# names, presumably to warn when the old import path is used.
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "cerebras", "cerebras.llm")
diff --git a/src/pipecat/services/cerebras.py b/src/pipecat/services/cerebras/llm.py
similarity index 98%
rename from src/pipecat/services/cerebras.py
rename to src/pipecat/services/cerebras/llm.py
index b5e34afe5..2217cc2f8 100644
--- a/src/pipecat/services/cerebras.py
+++ b/src/pipecat/services/cerebras/llm.py
@@ -11,7 +11,7 @@ from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
-from pipecat.services.openai import OpenAILLMService
+from pipecat.services.openai.llm import OpenAILLMService
class CerebrasLLMService(OpenAILLMService):
diff --git a/src/pipecat/services/deepgram/__init__.py b/src/pipecat/services/deepgram/__init__.py
new file mode 100644
index 000000000..bb7a3c1f8
--- /dev/null
+++ b/src/pipecat/services/deepgram/__init__.py
@@ -0,0 +1,14 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .stt import *
+from .tts import *
+
+# Replace this package's sys.modules entry with a DeprecatedModuleProxy built
+# from the names star-imported above; the proxy is given the old and new module
+# names, presumably to warn when the old import path is used.
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "deepgram", "deepgram.[stt,tts]")
diff --git a/src/pipecat/services/deepgram.py b/src/pipecat/services/deepgram/stt.py
similarity index 74%
rename from src/pipecat/services/deepgram.py
rename to src/pipecat/services/deepgram/stt.py
index 689d55e9f..ae8b2318a 100644
--- a/src/pipecat/services/deepgram.py
+++ b/src/pipecat/services/deepgram/stt.py
@@ -4,7 +4,6 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
-import asyncio
from typing import AsyncGenerator, Dict, Optional
from loguru import logger
@@ -12,23 +11,18 @@ from loguru import logger
from pipecat.frames.frames import (
CancelFrame,
EndFrame,
- ErrorFrame,
Frame,
InterimTranscriptionFrame,
StartFrame,
TranscriptionFrame,
- TTSAudioRawFrame,
- TTSStartedFrame,
- TTSStoppedFrame,
UserStartedSpeakingFrame,
UserStoppedSpeakingFrame,
)
from pipecat.processors.frame_processor import FrameDirection
-from pipecat.services.ai_services import STTService, TTSService
+from pipecat.services.ai_services import STTService
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
-# See .env.example for Deepgram configuration needed
try:
from deepgram import (
AsyncListenWebSocketClient,
@@ -38,80 +32,13 @@ try:
LiveOptions,
LiveResultResponse,
LiveTranscriptionEvents,
- SpeakOptions,
)
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
- logger.error(
- "In order to use Deepgram, you need to `pip install pipecat-ai[deepgram]`. Also, set `DEEPGRAM_API_KEY` environment variable."
- )
+ logger.error("In order to use Deepgram, you need to `pip install pipecat-ai[deepgram]`.")
raise Exception(f"Missing module: {e}")
-class DeepgramTTSService(TTSService):
- def __init__(
- self,
- *,
- api_key: str,
- voice: str = "aura-helios-en",
- sample_rate: Optional[int] = None,
- encoding: str = "linear16",
- **kwargs,
- ):
- super().__init__(sample_rate=sample_rate, **kwargs)
-
- self._settings = {
- "encoding": encoding,
- }
- self.set_voice(voice)
- self._deepgram_client = DeepgramClient(api_key=api_key)
-
- def can_generate_metrics(self) -> bool:
- return True
-
- async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
- logger.debug(f"{self}: Generating TTS [{text}]")
-
- options = SpeakOptions(
- model=self._voice_id,
- encoding=self._settings["encoding"],
- sample_rate=self.sample_rate,
- container="none",
- )
-
- try:
- await self.start_ttfb_metrics()
-
- response = await asyncio.to_thread(
- self._deepgram_client.speak.v("1").stream, {"text": text}, options
- )
-
- await self.start_tts_usage_metrics(text)
- yield TTSStartedFrame()
-
- # The response.stream_memory is already a BytesIO object
- audio_buffer = response.stream_memory
-
- if audio_buffer is None:
- raise ValueError("No audio data received from Deepgram")
-
- # Read and yield the audio data in chunks
- audio_buffer.seek(0) # Ensure we're at the start of the buffer
- chunk_size = 1024 # Use a fixed buffer size
- while True:
- await self.stop_ttfb_metrics()
- chunk = audio_buffer.read(chunk_size)
- if not chunk:
- break
- frame = TTSAudioRawFrame(audio=chunk, sample_rate=self.sample_rate, num_channels=1)
- yield frame
- yield TTSStoppedFrame()
-
- except Exception as e:
- logger.exception(f"{self} exception: {e}")
- yield ErrorFrame(f"Error getting audio: {str(e)}")
-
-
class DeepgramSTTService(STTService):
def __init__(
self,
diff --git a/src/pipecat/services/deepgram/tts.py b/src/pipecat/services/deepgram/tts.py
new file mode 100644
index 000000000..5e05292d9
--- /dev/null
+++ b/src/pipecat/services/deepgram/tts.py
@@ -0,0 +1,90 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import asyncio
+from typing import AsyncGenerator, Optional
+
+from loguru import logger
+
+from pipecat.frames.frames import (
+ ErrorFrame,
+ Frame,
+ TTSAudioRawFrame,
+ TTSStartedFrame,
+ TTSStoppedFrame,
+)
+from pipecat.services.ai_services import TTSService
+
+try:
+ from deepgram import DeepgramClient, SpeakOptions
+except ModuleNotFoundError as e:
+ logger.error(f"Exception: {e}")
+ logger.error("In order to use Deepgram, you need to `pip install pipecat-ai[deepgram]`.")
+ raise Exception(f"Missing module: {e}")
+
+
+class DeepgramTTSService(TTSService):
+    """Text-to-speech service backed by the Deepgram speak API.
+
+    Args:
+        api_key: Deepgram API key.
+        voice: Voice registered through `set_voice()` and sent as the model.
+        sample_rate: Optional sample rate forwarded to the base service.
+        encoding: Audio encoding requested from Deepgram.
+        **kwargs: Additional keyword arguments passed to TTSService.
+    """
+
+    def __init__(
+        self,
+        *,
+        api_key: str,
+        voice: str = "aura-helios-en",
+        sample_rate: Optional[int] = None,
+        encoding: str = "linear16",
+        **kwargs,
+    ):
+        super().__init__(sample_rate=sample_rate, **kwargs)
+
+        self._settings = {"encoding": encoding}
+        self.set_voice(voice)
+        self._deepgram_client = DeepgramClient(api_key=api_key)
+
+    def can_generate_metrics(self) -> bool:
+        return True
+
+    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
+        """Synthesize `text` and yield it as 1024-byte audio frames.
+
+        Yields `TTSStartedFrame`, the audio frames and `TTSStoppedFrame`; any
+        exception is logged and reported as an `ErrorFrame`.
+        """
+        logger.debug(f"{self}: Generating TTS [{text}]")
+
+        speak_options = SpeakOptions(
+            model=self._voice_id,
+            encoding=self._settings["encoding"],
+            sample_rate=self.sample_rate,
+            container="none",
+        )
+
+        try:
+            await self.start_ttfb_metrics()
+
+            # The SDK call blocks, so it runs in a worker thread.
+            stream_speech = self._deepgram_client.speak.v("1").stream
+            response = await asyncio.to_thread(stream_speech, {"text": text}, speak_options)
+
+            await self.start_tts_usage_metrics(text)
+            yield TTSStartedFrame()
+
+            # stream_memory is already an in-memory BytesIO buffer.
+            audio = response.stream_memory
+            if audio is None:
+                raise ValueError("No audio data received from Deepgram")
+
+            # Rewind, then hand the audio out in fixed-size pieces.
+            audio.seek(0)
+            read_size = 1024
+            while True:
+                await self.stop_ttfb_metrics()
+                data = audio.read(read_size)
+                if not data:
+                    break
+                yield TTSAudioRawFrame(audio=data, sample_rate=self.sample_rate, num_channels=1)
+            yield TTSStoppedFrame()
+
+        except Exception as e:
+            logger.exception(f"{self} exception: {e}")
+            yield ErrorFrame(f"Error getting audio: {str(e)}")
diff --git a/src/pipecat/services/deepseek/__init__.py b/src/pipecat/services/deepseek/__init__.py
new file mode 100644
index 000000000..1e483d95f
--- /dev/null
+++ b/src/pipecat/services/deepseek/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .llm import *
+
+# Replace this package's sys.modules entry with a DeprecatedModuleProxy built
+# from the names star-imported above; the proxy is given the old and new module
+# names, presumably to warn when the old import path is used.
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "deepseek", "deepseek.llm")
diff --git a/src/pipecat/services/deepseek.py b/src/pipecat/services/deepseek/llm.py
similarity index 98%
rename from src/pipecat/services/deepseek.py
rename to src/pipecat/services/deepseek/llm.py
index 2537f2f7c..7bed5d33b 100644
--- a/src/pipecat/services/deepseek.py
+++ b/src/pipecat/services/deepseek/llm.py
@@ -12,7 +12,7 @@ from openai import AsyncStream
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
-from pipecat.services.openai import OpenAILLMService
+from pipecat.services.openai.llm import OpenAILLMService
class DeepSeekLLMService(OpenAILLMService):
diff --git a/src/pipecat/services/elevenlabs/__init__.py b/src/pipecat/services/elevenlabs/__init__.py
new file mode 100644
index 000000000..e5a76e71a
--- /dev/null
+++ b/src/pipecat/services/elevenlabs/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .tts import *
+
+# Replace this package's sys.modules entry with a DeprecatedModuleProxy built
+# from the names star-imported above; the proxy is given the old and new module
+# names, presumably to warn when the old import path is used.
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "elevenlabs", "elevenlabs.tts")
diff --git a/src/pipecat/services/elevenlabs.py b/src/pipecat/services/elevenlabs/tts.py
similarity index 99%
rename from src/pipecat/services/elevenlabs.py
rename to src/pipecat/services/elevenlabs/tts.py
index 68c71a144..7b4e4f0dc 100644
--- a/src/pipecat/services/elevenlabs.py
+++ b/src/pipecat/services/elevenlabs/tts.py
@@ -33,9 +33,7 @@ try:
import websockets
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
- logger.error(
- "In order to use ElevenLabs, you need to `pip install pipecat-ai[elevenlabs]`. Also, set `ELEVENLABS_API_KEY` environment variable."
- )
+ logger.error("In order to use ElevenLabs, you need to `pip install pipecat-ai[elevenlabs]`.")
raise Exception(f"Missing module: {e}")
ElevenLabsOutputFormat = Literal["pcm_16000", "pcm_22050", "pcm_24000", "pcm_44100"]
diff --git a/src/pipecat/services/fal/__init__.py b/src/pipecat/services/fal/__init__.py
new file mode 100644
index 000000000..6235b4812
--- /dev/null
+++ b/src/pipecat/services/fal/__init__.py
@@ -0,0 +1,14 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .image import *
+from .stt import *
+
+# Replace this package's sys.modules entry with a DeprecatedModuleProxy built
+# from the names star-imported above; the proxy is given the old and new module
+# names, presumably to warn when the old import path is used.
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "fal", "fal.[image,stt]")
diff --git a/src/pipecat/services/fal/image.py b/src/pipecat/services/fal/image.py
new file mode 100644
index 000000000..6ba14caf9
--- /dev/null
+++ b/src/pipecat/services/fal/image.py
@@ -0,0 +1,84 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import asyncio
+import io
+import os
+from typing import AsyncGenerator, Dict, Optional, Union
+
+import aiohttp
+from loguru import logger
+from PIL import Image
+from pydantic import BaseModel
+
+from pipecat.frames.frames import ErrorFrame, Frame, URLImageRawFrame
+from pipecat.services.ai_services import ImageGenService
+
+try:
+ import fal_client
+except ModuleNotFoundError as e:
+ logger.error(f"Exception: {e}")
+ logger.error("In order to use Fal, you need to `pip install pipecat-ai[fal]`.")
+ raise Exception(f"Missing module: {e}")
+
+
+class FalImageGenService(ImageGenService):
+    """Image generation service backed by fal.
+
+    Args:
+        params: Generation options sent with every request.
+        aiohttp_session: Session used to download the generated image.
+        model: fal model identifier, registered through `set_model_name()`.
+        key: Optional API key; when given it is exported as `FAL_KEY`.
+        **kwargs: Additional keyword arguments passed to ImageGenService.
+    """
+
+    class InputParams(BaseModel):
+        # Fields left as None are dropped from the request arguments.
+        seed: Optional[int] = None
+        num_inference_steps: int = 8
+        num_images: int = 1
+        image_size: Union[str, Dict[str, int]] = "square_hd"
+        expand_prompt: bool = False
+        enable_safety_checker: bool = True
+        format: str = "png"
+
+    def __init__(
+        self,
+        *,
+        params: InputParams,
+        aiohttp_session: aiohttp.ClientSession,
+        model: str = "fal-ai/fast-sdxl",
+        key: Optional[str] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.set_model_name(model)
+        self._params = params
+        self._aiohttp_session = aiohttp_session
+        if key:
+            # Process-wide: the key is published through the environment.
+            os.environ["FAL_KEY"] = key
+
+    async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
+        """Generate an image for `prompt` and yield it as a frame.
+
+        Yields a `URLImageRawFrame` with the downloaded image, or an
+        `ErrorFrame` when the response carries no image URL.
+        """
+
+        def load_image_bytes(encoded_image: bytes):
+            # Decoding happens in a worker thread (see to_thread below).
+            image = Image.open(io.BytesIO(encoded_image))
+            return (image.tobytes(), image.size, image.format)
+
+        logger.debug(f"Generating image from prompt: {prompt}")
+
+        model = self.model_name
+        arguments = {"prompt": prompt, **self._params.model_dump(exclude_none=True)}
+        result = await fal_client.run_async(model, arguments=arguments)
+
+        image_url = None
+        if result:
+            image_url = result["images"][0]["url"]
+
+        if not image_url:
+            logger.error(f"{self} error: image generation failed")
+            yield ErrorFrame("Image generation failed")
+            return
+
+        logger.debug(f"Image generated at: {image_url}")
+
+        # Fetch the generated image from its URL.
+        logger.debug(f"Downloading image {image_url} ...")
+        async with self._aiohttp_session.get(image_url) as download:
+            logger.debug(f"Downloaded image {image_url}")
+            encoded_image = await download.content.read()
+            decoded = await asyncio.to_thread(load_image_bytes, encoded_image)
+            image_bytes, size, image_format = decoded
+
+            yield URLImageRawFrame(
+                url=image_url, image=image_bytes, size=size, format=image_format
+            )
diff --git a/src/pipecat/services/fal.py b/src/pipecat/services/fal/stt.py
similarity index 76%
rename from src/pipecat/services/fal.py
rename to src/pipecat/services/fal/stt.py
index cb39da75f..4926e4718 100644
--- a/src/pipecat/services/fal.py
+++ b/src/pipecat/services/fal/stt.py
@@ -4,19 +4,14 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
-import asyncio
-import io
import os
-import wave
-from typing import AsyncGenerator, Dict, Optional, Union
+from typing import AsyncGenerator, Optional
-import aiohttp
from loguru import logger
-from PIL import Image
from pydantic import BaseModel
-from pipecat.frames.frames import ErrorFrame, Frame, TranscriptionFrame, URLImageRawFrame
-from pipecat.services.ai_services import ImageGenService, SegmentedSTTService
+from pipecat.frames.frames import ErrorFrame, Frame, TranscriptionFrame
+from pipecat.services.ai_services import SegmentedSTTService
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
@@ -144,65 +139,6 @@ def language_to_fal_language(language: Language) -> Optional[str]:
return result
-class FalImageGenService(ImageGenService):
- class InputParams(BaseModel):
- seed: Optional[int] = None
- num_inference_steps: int = 8
- num_images: int = 1
- image_size: Union[str, Dict[str, int]] = "square_hd"
- expand_prompt: bool = False
- enable_safety_checker: bool = True
- format: str = "png"
-
- def __init__(
- self,
- *,
- params: InputParams,
- aiohttp_session: aiohttp.ClientSession,
- model: str = "fal-ai/fast-sdxl",
- key: Optional[str] = None,
- **kwargs,
- ):
- super().__init__(**kwargs)
- self.set_model_name(model)
- self._params = params
- self._aiohttp_session = aiohttp_session
- if key:
- os.environ["FAL_KEY"] = key
-
- async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
- def load_image_bytes(encoded_image: bytes):
- buffer = io.BytesIO(encoded_image)
- image = Image.open(buffer)
- return (image.tobytes(), image.size, image.format)
-
- logger.debug(f"Generating image from prompt: {prompt}")
-
- response = await fal_client.run_async(
- self.model_name,
- arguments={"prompt": prompt, **self._params.model_dump(exclude_none=True)},
- )
-
- image_url = response["images"][0]["url"] if response else None
-
- if not image_url:
- logger.error(f"{self} error: image generation failed")
- yield ErrorFrame("Image generation failed")
- return
-
- logger.debug(f"Image generated at: {image_url}")
-
- # Load the image from the url
- logger.debug(f"Downloading image {image_url} ...")
- async with self._aiohttp_session.get(image_url) as response:
- logger.debug(f"Downloaded image {image_url}")
- encoded_image = await response.content.read()
- (image_bytes, size, format) = await asyncio.to_thread(load_image_bytes, encoded_image)
-
- frame = URLImageRawFrame(url=image_url, image=image_bytes, size=size, format=format)
- yield frame
-
-
class FalSTTService(SegmentedSTTService):
"""Speech-to-text service using Fal's Wizper API.
diff --git a/src/pipecat/services/fireworks/__init__.py b/src/pipecat/services/fireworks/__init__.py
new file mode 100644
index 000000000..c308ad967
--- /dev/null
+++ b/src/pipecat/services/fireworks/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .llm import *
+
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "fireworks", "fireworks.llm")
diff --git a/src/pipecat/services/fireworks.py b/src/pipecat/services/fireworks/llm.py
similarity index 97%
rename from src/pipecat/services/fireworks.py
rename to src/pipecat/services/fireworks/llm.py
index 8e40e1a7b..d4003f86f 100644
--- a/src/pipecat/services/fireworks.py
+++ b/src/pipecat/services/fireworks/llm.py
@@ -11,7 +11,7 @@ from loguru import logger
from openai.types.chat import ChatCompletionMessageParam
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
-from pipecat.services.openai import OpenAILLMService
+from pipecat.services.openai.llm import OpenAILLMService
class FireworksLLMService(OpenAILLMService):
diff --git a/src/pipecat/services/fish/__init__.py b/src/pipecat/services/fish/__init__.py
new file mode 100644
index 000000000..a783d6224
--- /dev/null
+++ b/src/pipecat/services/fish/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .tts import *
+
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "fish", "fish.tts")
diff --git a/src/pipecat/services/fish.py b/src/pipecat/services/fish/tts.py
similarity index 98%
rename from src/pipecat/services/fish.py
rename to src/pipecat/services/fish/tts.py
index 9e6a8b91e..d4fe59635 100644
--- a/src/pipecat/services/fish.py
+++ b/src/pipecat/services/fish/tts.py
@@ -30,9 +30,7 @@ try:
import websockets
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
- logger.error(
- "In order to use Fish Audio, you need to `pip install pipecat-ai[fish]`. Also, set `FISH_API_KEY` environment variable."
- )
+ logger.error("In order to use Fish Audio, you need to `pip install pipecat-ai[fish]`.")
raise Exception(f"Missing module: {e}")
# FishAudio supports various output formats
diff --git a/src/pipecat/services/gemini_multimodal_live/gemini.py b/src/pipecat/services/gemini_multimodal_live/gemini.py
index 965648ada..3ecab0186 100644
--- a/src/pipecat/services/gemini_multimodal_live/gemini.py
+++ b/src/pipecat/services/gemini_multimodal_live/gemini.py
@@ -51,7 +51,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import LLMService
-from pipecat.services.openai import (
+from pipecat.services.openai.llm import (
OpenAIAssistantContextAggregator,
OpenAIUserContextAggregator,
)
diff --git a/src/pipecat/services/google/__init__.py b/src/pipecat/services/google/__init__.py
index 3e63a3ba9..ec187000f 100644
--- a/src/pipecat/services/google/__init__.py
+++ b/src/pipecat/services/google/__init__.py
@@ -1,3 +1,22 @@
-from .frames import LLMSearchResponseFrame
-from .google import *
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .frames import *
+from .image import *
+from .llm import *
+from .llm_openai import *
+from .llm_vertex import *
from .rtvi import *
+from .stt import *
+from .tts import *
+
+sys.modules[__name__] = DeprecatedModuleProxy(
+ globals(), "google", "google.[frames,image,llm,llm_openai,llm_vertex,rtvi,stt,tts]"
+)
diff --git a/src/pipecat/services/google/google.py b/src/pipecat/services/google/google.py
index 95c5a1edb..ec187000f 100644
--- a/src/pipecat/services/google/google.py
+++ b/src/pipecat/services/google/google.py
@@ -4,2089 +4,19 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
-import asyncio
-import base64
-import io
-import json
-import os
-import time
-import uuid
+import sys
-from google.api_core.exceptions import DeadlineExceeded
-from openai import AsyncStream
-from openai.types.chat import ChatCompletionChunk
+from pipecat.services import DeprecatedModuleProxy
-from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
+from .frames import *
+from .image import *
+from .llm import *
+from .llm_openai import *
+from .llm_vertex import *
+from .rtvi import *
+from .stt import *
+from .tts import *
-# Suppress gRPC fork warnings
-os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
-
-from dataclasses import dataclass
-from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional, Union
-
-from loguru import logger
-from PIL import Image
-from pydantic import BaseModel, Field, field_validator
-
-from pipecat.frames.frames import (
- AudioRawFrame,
- CancelFrame,
- EndFrame,
- ErrorFrame,
- Frame,
- FunctionCallCancelFrame,
- FunctionCallInProgressFrame,
- FunctionCallResultFrame,
- InterimTranscriptionFrame,
- LLMFullResponseEndFrame,
- LLMFullResponseStartFrame,
- LLMMessagesFrame,
- LLMTextFrame,
- LLMUpdateSettingsFrame,
- StartFrame,
- TranscriptionFrame,
- TTSAudioRawFrame,
- TTSStartedFrame,
- TTSStoppedFrame,
- URLImageRawFrame,
- UserImageRawFrame,
- VisionImageRawFrame,
+sys.modules[__name__] = DeprecatedModuleProxy(
+ globals(), "google", "google.[frames,image,llm,llm_openai,llm_vertex,rtvi,stt,tts]"
)
-from pipecat.metrics.metrics import LLMTokenUsage
-from pipecat.processors.aggregators.openai_llm_context import (
- OpenAILLMContext,
- OpenAILLMContextFrame,
-)
-from pipecat.processors.frame_processor import FrameDirection
-from pipecat.services.ai_services import ImageGenService, LLMService, STTService, TTSService
-from pipecat.services.google.frames import LLMSearchResponseFrame
-from pipecat.services.openai import (
- OpenAIAssistantContextAggregator,
- OpenAILLMService,
- OpenAIUnhandledFunctionException,
- OpenAIUserContextAggregator,
-)
-from pipecat.transcriptions.language import Language
-from pipecat.utils.time import time_now_iso8601
-
-try:
- import google.ai.generativelanguage as glm
- import google.generativeai as gai
- from google import genai
- from google.api_core.client_options import ClientOptions
- from google.auth.transport.requests import Request
- from google.cloud import speech_v2, texttospeech_v1
- from google.cloud.speech_v2.types import cloud_speech
- from google.genai import types
- from google.generativeai.types import GenerationConfig
- from google.oauth2 import service_account
-
-except ModuleNotFoundError as e:
- logger.error(f"Exception: {e}")
- logger.error(
- "In order to use Google AI, you need to `pip install pipecat-ai[google]`. Also, set the environment variable GOOGLE_API_KEY for the GoogleLLMService and GOOGLE_APPLICATION_CREDENTIALS for the GoogleTTSService`."
- )
- raise Exception(f"Missing module: {e}")
-
-
-def language_to_google_tts_language(language: Language) -> Optional[str]:
- language_map = {
- # Afrikaans
- Language.AF: "af-ZA",
- Language.AF_ZA: "af-ZA",
- # Arabic
- Language.AR: "ar-XA",
- # Bengali
- Language.BN: "bn-IN",
- Language.BN_IN: "bn-IN",
- # Bulgarian
- Language.BG: "bg-BG",
- Language.BG_BG: "bg-BG",
- # Catalan
- Language.CA: "ca-ES",
- Language.CA_ES: "ca-ES",
- # Chinese (Mandarin and Cantonese)
- Language.ZH: "cmn-CN",
- Language.ZH_CN: "cmn-CN",
- Language.ZH_TW: "cmn-TW",
- Language.ZH_HK: "yue-HK",
- # Czech
- Language.CS: "cs-CZ",
- Language.CS_CZ: "cs-CZ",
- # Danish
- Language.DA: "da-DK",
- Language.DA_DK: "da-DK",
- # Dutch
- Language.NL: "nl-NL",
- Language.NL_BE: "nl-BE",
- Language.NL_NL: "nl-NL",
- # English
- Language.EN: "en-US",
- Language.EN_US: "en-US",
- Language.EN_AU: "en-AU",
- Language.EN_GB: "en-GB",
- Language.EN_IN: "en-IN",
- # Estonian
- Language.ET: "et-EE",
- Language.ET_EE: "et-EE",
- # Filipino
- Language.FIL: "fil-PH",
- Language.FIL_PH: "fil-PH",
- # Finnish
- Language.FI: "fi-FI",
- Language.FI_FI: "fi-FI",
- # French
- Language.FR: "fr-FR",
- Language.FR_CA: "fr-CA",
- Language.FR_FR: "fr-FR",
- # Galician
- Language.GL: "gl-ES",
- Language.GL_ES: "gl-ES",
- # German
- Language.DE: "de-DE",
- Language.DE_DE: "de-DE",
- # Greek
- Language.EL: "el-GR",
- Language.EL_GR: "el-GR",
- # Gujarati
- Language.GU: "gu-IN",
- Language.GU_IN: "gu-IN",
- # Hebrew
- Language.HE: "he-IL",
- Language.HE_IL: "he-IL",
- # Hindi
- Language.HI: "hi-IN",
- Language.HI_IN: "hi-IN",
- # Hungarian
- Language.HU: "hu-HU",
- Language.HU_HU: "hu-HU",
- # Icelandic
- Language.IS: "is-IS",
- Language.IS_IS: "is-IS",
- # Indonesian
- Language.ID: "id-ID",
- Language.ID_ID: "id-ID",
- # Italian
- Language.IT: "it-IT",
- Language.IT_IT: "it-IT",
- # Japanese
- Language.JA: "ja-JP",
- Language.JA_JP: "ja-JP",
- # Kannada
- Language.KN: "kn-IN",
- Language.KN_IN: "kn-IN",
- # Korean
- Language.KO: "ko-KR",
- Language.KO_KR: "ko-KR",
- # Latvian
- Language.LV: "lv-LV",
- Language.LV_LV: "lv-LV",
- # Lithuanian
- Language.LT: "lt-LT",
- Language.LT_LT: "lt-LT",
- # Malay
- Language.MS: "ms-MY",
- Language.MS_MY: "ms-MY",
- # Malayalam
- Language.ML: "ml-IN",
- Language.ML_IN: "ml-IN",
- # Marathi
- Language.MR: "mr-IN",
- Language.MR_IN: "mr-IN",
- # Norwegian
- Language.NO: "nb-NO",
- Language.NB: "nb-NO",
- Language.NB_NO: "nb-NO",
- # Polish
- Language.PL: "pl-PL",
- Language.PL_PL: "pl-PL",
- # Portuguese
- Language.PT: "pt-PT",
- Language.PT_BR: "pt-BR",
- Language.PT_PT: "pt-PT",
- # Punjabi
- Language.PA: "pa-IN",
- Language.PA_IN: "pa-IN",
- # Romanian
- Language.RO: "ro-RO",
- Language.RO_RO: "ro-RO",
- # Russian
- Language.RU: "ru-RU",
- Language.RU_RU: "ru-RU",
- # Serbian
- Language.SR: "sr-RS",
- Language.SR_RS: "sr-RS",
- # Slovak
- Language.SK: "sk-SK",
- Language.SK_SK: "sk-SK",
- # Spanish
- Language.ES: "es-ES",
- Language.ES_ES: "es-ES",
- Language.ES_US: "es-US",
- # Swedish
- Language.SV: "sv-SE",
- Language.SV_SE: "sv-SE",
- # Tamil
- Language.TA: "ta-IN",
- Language.TA_IN: "ta-IN",
- # Telugu
- Language.TE: "te-IN",
- Language.TE_IN: "te-IN",
- # Thai
- Language.TH: "th-TH",
- Language.TH_TH: "th-TH",
- # Turkish
- Language.TR: "tr-TR",
- Language.TR_TR: "tr-TR",
- # Ukrainian
- Language.UK: "uk-UA",
- Language.UK_UA: "uk-UA",
- # Vietnamese
- Language.VI: "vi-VN",
- Language.VI_VN: "vi-VN",
- }
-
- return language_map.get(language)
-
-
-def language_to_google_stt_language(language: Language) -> Optional[str]:
- """Maps Language enum to Google Speech-to-Text V2 language codes.
-
- Args:
- language: Language enum value.
-
- Returns:
- Optional[str]: Google STT language code or None if not supported.
- """
- language_map = {
- # Afrikaans
- Language.AF: "af-ZA",
- Language.AF_ZA: "af-ZA",
- # Albanian
- Language.SQ: "sq-AL",
- Language.SQ_AL: "sq-AL",
- # Amharic
- Language.AM: "am-ET",
- Language.AM_ET: "am-ET",
- # Arabic
- Language.AR: "ar-EG", # Default to Egypt
- Language.AR_AE: "ar-AE",
- Language.AR_BH: "ar-BH",
- Language.AR_DZ: "ar-DZ",
- Language.AR_EG: "ar-EG",
- Language.AR_IQ: "ar-IQ",
- Language.AR_JO: "ar-JO",
- Language.AR_KW: "ar-KW",
- Language.AR_LB: "ar-LB",
- Language.AR_MA: "ar-MA",
- Language.AR_OM: "ar-OM",
- Language.AR_QA: "ar-QA",
- Language.AR_SA: "ar-SA",
- Language.AR_SY: "ar-SY",
- Language.AR_TN: "ar-TN",
- Language.AR_YE: "ar-YE",
- # Armenian
- Language.HY: "hy-AM",
- Language.HY_AM: "hy-AM",
- # Azerbaijani
- Language.AZ: "az-AZ",
- Language.AZ_AZ: "az-AZ",
- # Basque
- Language.EU: "eu-ES",
- Language.EU_ES: "eu-ES",
- # Bengali
- Language.BN: "bn-IN", # Default to India
- Language.BN_BD: "bn-BD",
- Language.BN_IN: "bn-IN",
- # Bosnian
- Language.BS: "bs-BA",
- Language.BS_BA: "bs-BA",
- # Bulgarian
- Language.BG: "bg-BG",
- Language.BG_BG: "bg-BG",
- # Burmese
- Language.MY: "my-MM",
- Language.MY_MM: "my-MM",
- # Catalan
- Language.CA: "ca-ES",
- Language.CA_ES: "ca-ES",
- # Chinese
- Language.ZH: "cmn-Hans-CN", # Default to Simplified Chinese
- Language.ZH_CN: "cmn-Hans-CN",
- Language.ZH_HK: "cmn-Hans-HK",
- Language.ZH_TW: "cmn-Hant-TW",
- Language.YUE: "yue-Hant-HK", # Cantonese
- Language.YUE_CN: "yue-Hant-HK",
- # Croatian
- Language.HR: "hr-HR",
- Language.HR_HR: "hr-HR",
- # Czech
- Language.CS: "cs-CZ",
- Language.CS_CZ: "cs-CZ",
- # Danish
- Language.DA: "da-DK",
- Language.DA_DK: "da-DK",
- # Dutch
- Language.NL: "nl-NL", # Default to Netherlands
- Language.NL_BE: "nl-BE",
- Language.NL_NL: "nl-NL",
- # English
- Language.EN: "en-US", # Default to US
- Language.EN_AU: "en-AU",
- Language.EN_CA: "en-CA",
- Language.EN_GB: "en-GB",
- Language.EN_GH: "en-GH",
- Language.EN_HK: "en-HK",
- Language.EN_IN: "en-IN",
- Language.EN_IE: "en-IE",
- Language.EN_KE: "en-KE",
- Language.EN_NG: "en-NG",
- Language.EN_NZ: "en-NZ",
- Language.EN_PH: "en-PH",
- Language.EN_SG: "en-SG",
- Language.EN_TZ: "en-TZ",
- Language.EN_US: "en-US",
- Language.EN_ZA: "en-ZA",
- # Estonian
- Language.ET: "et-EE",
- Language.ET_EE: "et-EE",
- # Filipino
- Language.FIL: "fil-PH",
- Language.FIL_PH: "fil-PH",
- # Finnish
- Language.FI: "fi-FI",
- Language.FI_FI: "fi-FI",
- # French
- Language.FR: "fr-FR", # Default to France
- Language.FR_BE: "fr-BE",
- Language.FR_CA: "fr-CA",
- Language.FR_CH: "fr-CH",
- Language.FR_FR: "fr-FR",
- # Galician
- Language.GL: "gl-ES",
- Language.GL_ES: "gl-ES",
- # Georgian
- Language.KA: "ka-GE",
- Language.KA_GE: "ka-GE",
- # German
- Language.DE: "de-DE", # Default to Germany
- Language.DE_AT: "de-AT",
- Language.DE_CH: "de-CH",
- Language.DE_DE: "de-DE",
- # Greek
- Language.EL: "el-GR",
- Language.EL_GR: "el-GR",
- # Gujarati
- Language.GU: "gu-IN",
- Language.GU_IN: "gu-IN",
- # Hebrew
- Language.HE: "iw-IL",
- Language.HE_IL: "iw-IL",
- # Hindi
- Language.HI: "hi-IN",
- Language.HI_IN: "hi-IN",
- # Hungarian
- Language.HU: "hu-HU",
- Language.HU_HU: "hu-HU",
- # Icelandic
- Language.IS: "is-IS",
- Language.IS_IS: "is-IS",
- # Indonesian
- Language.ID: "id-ID",
- Language.ID_ID: "id-ID",
- # Italian
- Language.IT: "it-IT",
- Language.IT_IT: "it-IT",
- Language.IT_CH: "it-CH",
- # Japanese
- Language.JA: "ja-JP",
- Language.JA_JP: "ja-JP",
- # Javanese
- Language.JV: "jv-ID",
- Language.JV_ID: "jv-ID",
- # Kannada
- Language.KN: "kn-IN",
- Language.KN_IN: "kn-IN",
- # Kazakh
- Language.KK: "kk-KZ",
- Language.KK_KZ: "kk-KZ",
- # Khmer
- Language.KM: "km-KH",
- Language.KM_KH: "km-KH",
- # Korean
- Language.KO: "ko-KR",
- Language.KO_KR: "ko-KR",
- # Lao
- Language.LO: "lo-LA",
- Language.LO_LA: "lo-LA",
- # Latvian
- Language.LV: "lv-LV",
- Language.LV_LV: "lv-LV",
- # Lithuanian
- Language.LT: "lt-LT",
- Language.LT_LT: "lt-LT",
- # Macedonian
- Language.MK: "mk-MK",
- Language.MK_MK: "mk-MK",
- # Malay
- Language.MS: "ms-MY",
- Language.MS_MY: "ms-MY",
- # Malayalam
- Language.ML: "ml-IN",
- Language.ML_IN: "ml-IN",
- # Marathi
- Language.MR: "mr-IN",
- Language.MR_IN: "mr-IN",
- # Mongolian
- Language.MN: "mn-MN",
- Language.MN_MN: "mn-MN",
- # Nepali
- Language.NE: "ne-NP",
- Language.NE_NP: "ne-NP",
- # Norwegian
- Language.NO: "no-NO",
- Language.NB: "no-NO",
- Language.NB_NO: "no-NO",
- # Persian
- Language.FA: "fa-IR",
- Language.FA_IR: "fa-IR",
- # Polish
- Language.PL: "pl-PL",
- Language.PL_PL: "pl-PL",
- # Portuguese
- Language.PT: "pt-PT", # Default to Portugal
- Language.PT_BR: "pt-BR",
- Language.PT_PT: "pt-PT",
- # Punjabi
- Language.PA: "pa-Guru-IN",
- Language.PA_IN: "pa-Guru-IN",
- # Romanian
- Language.RO: "ro-RO",
- Language.RO_RO: "ro-RO",
- # Russian
- Language.RU: "ru-RU",
- Language.RU_RU: "ru-RU",
- # Serbian
- Language.SR: "sr-RS",
- Language.SR_RS: "sr-RS",
- # Sinhala
- Language.SI: "si-LK",
- Language.SI_LK: "si-LK",
- # Slovak
- Language.SK: "sk-SK",
- Language.SK_SK: "sk-SK",
- # Slovenian
- Language.SL: "sl-SI",
- Language.SL_SI: "sl-SI",
- # Spanish
- Language.ES: "es-ES", # Default to Spain
- Language.ES_AR: "es-AR",
- Language.ES_BO: "es-BO",
- Language.ES_CL: "es-CL",
- Language.ES_CO: "es-CO",
- Language.ES_CR: "es-CR",
- Language.ES_DO: "es-DO",
- Language.ES_EC: "es-EC",
- Language.ES_ES: "es-ES",
- Language.ES_GT: "es-GT",
- Language.ES_HN: "es-HN",
- Language.ES_MX: "es-MX",
- Language.ES_NI: "es-NI",
- Language.ES_PA: "es-PA",
- Language.ES_PE: "es-PE",
- Language.ES_PR: "es-PR",
- Language.ES_PY: "es-PY",
- Language.ES_SV: "es-SV",
- Language.ES_US: "es-US",
- Language.ES_UY: "es-UY",
- Language.ES_VE: "es-VE",
- # Sundanese
- Language.SU: "su-ID",
- Language.SU_ID: "su-ID",
- # Swahili
- Language.SW: "sw-TZ", # Default to Tanzania
- Language.SW_KE: "sw-KE",
- Language.SW_TZ: "sw-TZ",
- # Swedish
- Language.SV: "sv-SE",
- Language.SV_SE: "sv-SE",
- # Tamil
- Language.TA: "ta-IN", # Default to India
- Language.TA_IN: "ta-IN",
- Language.TA_MY: "ta-MY",
- Language.TA_SG: "ta-SG",
- Language.TA_LK: "ta-LK",
- # Telugu
- Language.TE: "te-IN",
- Language.TE_IN: "te-IN",
- # Thai
- Language.TH: "th-TH",
- Language.TH_TH: "th-TH",
- # Turkish
- Language.TR: "tr-TR",
- Language.TR_TR: "tr-TR",
- # Ukrainian
- Language.UK: "uk-UA",
- Language.UK_UA: "uk-UA",
- # Urdu
- Language.UR: "ur-IN", # Default to India
- Language.UR_IN: "ur-IN",
- Language.UR_PK: "ur-PK",
- # Uzbek
- Language.UZ: "uz-UZ",
- Language.UZ_UZ: "uz-UZ",
- # Vietnamese
- Language.VI: "vi-VN",
- Language.VI_VN: "vi-VN",
- # Xhosa
- Language.XH: "xh-ZA",
- # Zulu
- Language.ZU: "zu-ZA",
- Language.ZU_ZA: "zu-ZA",
- }
-
- return language_map.get(language)
-
-
-class GoogleUserContextAggregator(OpenAIUserContextAggregator):
- async def push_aggregation(self):
- if len(self._aggregation) > 0:
- self._context.add_message(
- glm.Content(role="user", parts=[glm.Part(text=self._aggregation)])
- )
-
- # Reset the aggregation. Reset it before pushing it down, otherwise
- # if the tasks gets cancelled we won't be able to clear things up.
- self._aggregation = ""
-
- # Push context frame
- frame = OpenAILLMContextFrame(self._context)
- await self.push_frame(frame)
-
- # Reset our accumulator state.
- self.reset()
-
-
-class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
- async def handle_aggregation(self, aggregation: str):
- self._context.add_message(glm.Content(role="model", parts=[glm.Part(text=aggregation)]))
-
- async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
- self._context.add_message(
- glm.Content(
- role="model",
- parts=[
- glm.Part(
- function_call=glm.FunctionCall(
- id=frame.tool_call_id, name=frame.function_name, args=frame.arguments
- )
- )
- ],
- )
- )
- self._context.add_message(
- glm.Content(
- role="user",
- parts=[
- glm.Part(
- function_response=glm.FunctionResponse(
- id=frame.tool_call_id,
- name=frame.function_name,
- response={"response": "IN_PROGRESS"},
- )
- )
- ],
- )
- )
-
- async def handle_function_call_result(self, frame: FunctionCallResultFrame):
- if frame.result:
- await self._update_function_call_result(
- frame.function_name, frame.tool_call_id, frame.result
- )
- else:
- await self._update_function_call_result(
- frame.function_name, frame.tool_call_id, "COMPLETED"
- )
-
- async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
- await self._update_function_call_result(
- frame.function_name, frame.tool_call_id, "CANCELLED"
- )
-
- async def _update_function_call_result(
- self, function_name: str, tool_call_id: str, result: Any
- ):
- for message in self._context.messages:
- if message.role == "user":
- for part in message.parts:
- if part.function_response and part.function_response.id == tool_call_id:
- part.function_response.response = {"value": json.dumps(result)}
-
- async def handle_user_image_frame(self, frame: UserImageRawFrame):
- await self._update_function_call_result(
- frame.request.function_name, frame.request.tool_call_id, "COMPLETED"
- )
- self._context.add_image_frame_message(
- format=frame.format,
- size=frame.size,
- image=frame.image,
- text=frame.request.context,
- )
-
-
-@dataclass
-class GoogleContextAggregatorPair:
- _user: "GoogleUserContextAggregator"
- _assistant: "GoogleAssistantContextAggregator"
-
- def user(self) -> "GoogleUserContextAggregator":
- return self._user
-
- def assistant(self) -> "GoogleAssistantContextAggregator":
- return self._assistant
-
-
-class GoogleLLMContext(OpenAILLMContext):
- def __init__(
- self,
- messages: Optional[List[dict]] = None,
- tools: Optional[List[dict]] = None,
- tool_choice: Optional[dict] = None,
- ):
- super().__init__(messages=messages, tools=tools, tool_choice=tool_choice)
- self.system_message = None
-
- @staticmethod
- def upgrade_to_google(obj: OpenAILLMContext) -> "GoogleLLMContext":
- if isinstance(obj, OpenAILLMContext) and not isinstance(obj, GoogleLLMContext):
- logger.debug(f"Upgrading to Google: {obj}")
- obj.__class__ = GoogleLLMContext
- obj._restructure_from_openai_messages()
- return obj
-
- def set_messages(self, messages: List):
- self._messages[:] = messages
- self._restructure_from_openai_messages()
-
- def add_messages(self, messages: List):
- # Convert each message individually
- converted_messages = []
- for msg in messages:
- if isinstance(msg, glm.Content):
- # Already in Gemini format
- converted_messages.append(msg)
- else:
- # Convert from standard format to Gemini format
- converted = self.from_standard_message(msg)
- if converted is not None:
- converted_messages.append(converted)
-
- # Add the converted messages to our existing messages
- self._messages.extend(converted_messages)
-
- def get_messages_for_logging(self):
- msgs = []
- for message in self.messages:
- obj = glm.Content.to_dict(message)
- try:
- if "parts" in obj:
- for part in obj["parts"]:
- if "inline_data" in part:
- part["inline_data"]["data"] = "..."
- except Exception as e:
- logger.debug(f"Error: {e}")
- msgs.append(obj)
- return msgs
-
- def add_image_frame_message(
- self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
- ):
- buffer = io.BytesIO()
- Image.frombytes(format, size, image).save(buffer, format="JPEG")
-
- parts = []
- if text:
- parts.append(glm.Part(text=text))
- parts.append(glm.Part(inline_data=glm.Blob(mime_type="image/jpeg", data=buffer.getvalue())))
-
- self.add_message(glm.Content(role="user", parts=parts))
-
- def add_audio_frames_message(
- self, *, audio_frames: list[AudioRawFrame], text: str = "Audio follows"
- ):
- if not audio_frames:
- return
-
- sample_rate = audio_frames[0].sample_rate
- num_channels = audio_frames[0].num_channels
-
- parts = []
- data = b"".join(frame.audio for frame in audio_frames)
- # NOTE(aleix): According to the docs only text or inline_data should be needed.
- # (see https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference)
- parts.append(glm.Part(text=text))
- parts.append(
- glm.Part(
- inline_data=glm.Blob(
- mime_type="audio/wav",
- data=(
- bytes(
- self.create_wav_header(sample_rate, num_channels, 16, len(data)) + data
- )
- ),
- )
- ),
- )
- self.add_message(glm.Content(role="user", parts=parts))
- # message = {"mime_type": "audio/mp3", "data": bytes(data + create_wav_header(sample_rate, num_channels, 16, len(data)))}
- # self.add_message(message)
-
- def from_standard_message(self, message):
- """Convert standard format message to Google Content object.
-
- Handles conversion of text, images, and function calls to Google's format.
- System messages are stored separately and return None.
-
- Args:
- message: Message in standard format:
- {
- "role": "user/assistant/system/tool",
- "content": str | [{"type": "text/image_url", ...}] | None,
- "tool_calls": [{"function": {"name": str, "arguments": str}}]
- }
-
- Returns:
- glm.Content object with:
- - role: "user" or "model" (converted from "assistant")
- - parts: List[Part] containing text, inline_data, or function calls
- Returns None for system messages.
- """
- role = message["role"]
- content = message.get("content", [])
- if role == "system":
- self.system_message = content
- return None
- elif role == "assistant":
- role = "model"
-
- parts = []
- if message.get("tool_calls"):
- for tc in message["tool_calls"]:
- parts.append(
- glm.Part(
- function_call=glm.FunctionCall(
- name=tc["function"]["name"],
- args=json.loads(tc["function"]["arguments"]),
- )
- )
- )
- elif role == "tool":
- role = "model"
- parts.append(
- glm.Part(
- function_response=glm.FunctionResponse(
- name="tool_call_result", # seems to work to hard-code the same name every time
- response=json.loads(message["content"]),
- )
- )
- )
- elif isinstance(content, str):
- parts.append(glm.Part(text=content))
- elif isinstance(content, list):
- for c in content:
- if c["type"] == "text":
- parts.append(glm.Part(text=c["text"]))
- elif c["type"] == "image_url":
- parts.append(
- glm.Part(
- inline_data=glm.Blob(
- mime_type="image/jpeg",
- data=base64.b64decode(c["image_url"]["url"].split(",")[1]),
- )
- )
- )
-
- message = glm.Content(role=role, parts=parts)
- return message
-
- def to_standard_messages(self, obj) -> list:
- """Convert Google Content object to standard structured format.
-
- Handles text, images, and function calls from Google's Content/Part objects.
-
- Args:
- obj: Google Content object with:
- - role: "model" (converted to "assistant") or "user"
- - parts: List[Part] containing text, inline_data, or function calls
-
- Returns:
- List of messages in standard format:
- [
- {
- "role": "user/assistant/tool",
- "content": [
- {"type": "text", "text": str} |
- {"type": "image_url", "image_url": {"url": str}}
- ]
- }
- ]
- """
- msg = {"role": obj.role, "content": []}
- if msg["role"] == "model":
- msg["role"] = "assistant"
-
- for part in obj.parts:
- if part.text:
- msg["content"].append({"type": "text", "text": part.text})
- elif part.inline_data:
- encoded = base64.b64encode(part.inline_data.data).decode("utf-8")
- msg["content"].append(
- {
- "type": "image_url",
- "image_url": {"url": f"data:{part.inline_data.mime_type};base64,{encoded}"},
- }
- )
- elif part.function_call:
- args = type(part.function_call).to_dict(part.function_call).get("args", {})
- msg["tool_calls"] = [
- {
- "id": part.function_call.name,
- "type": "function",
- "function": {
- "name": part.function_call.name,
- "arguments": json.dumps(args),
- },
- }
- ]
-
- elif part.function_response:
- msg["role"] = "tool"
- resp = (
- type(part.function_response).to_dict(part.function_response).get("response", {})
- )
- msg["tool_call_id"] = part.function_response.name
- msg["content"] = json.dumps(resp)
-
- # there might be no content parts for tool_calls messages
- if not msg["content"]:
- del msg["content"]
- return [msg]
-
- def _restructure_from_openai_messages(self):
- """Restructures messages to ensure proper Google format and message ordering.
-
- This method handles conversion of OpenAI-formatted messages to Google format,
- with special handling for function calls, function responses, and system messages.
- System messages are added back to the context as user messages when needed.
-
- The final message order is preserved as:
- 1. Function calls (from model)
- 2. Function responses (from user)
- 3. Text messages (converted from system messages)
-
- Note:
- System messages are only added back when there are no regular text
- messages in the context, ensuring proper conversation continuity
- after function calls.
- """
- self.system_message = None
- converted_messages = []
-
- # Process each message, preserving Google-formatted messages and converting others
- for message in self._messages:
- if isinstance(message, glm.Content):
- # Keep existing Google-formatted messages (e.g., function calls/responses)
- converted_messages.append(message)
- continue
-
- # Convert OpenAI format to Google format, system messages return None
- converted = self.from_standard_message(message)
- if converted is not None:
- converted_messages.append(converted)
-
- # Update message list
- self._messages[:] = converted_messages
-
- # Check if we only have function-related messages (no regular text)
- has_regular_messages = any(
- len(msg.parts) == 1
- and not getattr(msg.parts[0], "text", None)
- and getattr(msg.parts[0], "function_call", None)
- and getattr(msg.parts[0], "function_response", None)
- for msg in self._messages
- )
-
- # Add system message back as a user message if we only have function messages
- if self.system_message and not has_regular_messages:
- self._messages.append(
- glm.Content(role="user", parts=[glm.Part(text=self.system_message)])
- )
-
- # Remove any empty messages
- self._messages = [m for m in self._messages if m.parts]
-
-
-class GoogleLLMService(LLMService):
- """This class implements inference with Google's AI models.
-
- This service translates internally from OpenAILLMContext to the messages format
- expected by the Google AI model. We are using the OpenAILLMContext as a lingua
- franca for all LLM services, so that it is easy to switch between different LLMs.
- """
-
- # Overriding the default adapter to use the Gemini one.
- adapter_class = GeminiLLMAdapter
-
- class InputParams(BaseModel):
- max_tokens: Optional[int] = Field(default=4096, ge=1)
- temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
- top_k: Optional[int] = Field(default=None, ge=0)
- top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
- extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
-
- def __init__(
- self,
- *,
- api_key: str,
- model: str = "gemini-2.0-flash-001",
- params: InputParams = InputParams(),
- system_instruction: Optional[str] = None,
- tools: Optional[List[Dict[str, Any]]] = None,
- tool_config: Optional[Dict[str, Any]] = None,
- **kwargs,
- ):
- super().__init__(**kwargs)
- gai.configure(api_key=api_key)
- self.set_model_name(model)
- self._system_instruction = system_instruction
- self._create_client()
- self._settings = {
- "max_tokens": params.max_tokens,
- "temperature": params.temperature,
- "top_k": params.top_k,
- "top_p": params.top_p,
- "extra": params.extra if isinstance(params.extra, dict) else {},
- }
- self._tools = tools
- self._tool_config = tool_config
-
- def can_generate_metrics(self) -> bool:
- return True
-
- def _create_client(self):
- self._client = gai.GenerativeModel(
- self._model_name, system_instruction=self._system_instruction
- )
-
- async def _process_context(self, context: OpenAILLMContext):
- await self.push_frame(LLMFullResponseStartFrame())
-
- prompt_tokens = 0
- completion_tokens = 0
- total_tokens = 0
-
- grounding_metadata = None
- search_result = ""
-
- try:
- logger.debug(
- # f"{self}: Generating chat [{self._system_instruction}] | [{context.get_messages_for_logging()}]"
- f"{self}: Generating chat [{context.get_messages_for_logging()}]"
- )
-
- messages = context.messages
- if context.system_message and self._system_instruction != context.system_message:
- logger.debug(f"System instruction changed: {context.system_message}")
- self._system_instruction = context.system_message
- self._create_client()
-
- # Filter out None values and create GenerationConfig
- generation_params = {
- k: v
- for k, v in {
- "temperature": self._settings["temperature"],
- "top_p": self._settings["top_p"],
- "top_k": self._settings["top_k"],
- "max_output_tokens": self._settings["max_tokens"],
- }.items()
- if v is not None
- }
-
- generation_config = GenerationConfig(**generation_params) if generation_params else None
-
- await self.start_ttfb_metrics()
- tools = []
- if context.tools:
- tools = context.tools
- elif self._tools:
- tools = self._tools
- tool_config = None
- if self._tool_config:
- tool_config = self._tool_config
- response = await self._client.generate_content_async(
- contents=messages,
- tools=tools,
- stream=True,
- generation_config=generation_config,
- tool_config=tool_config,
- )
- await self.stop_ttfb_metrics()
-
- if response.usage_metadata:
- # Use only the prompt token count from the response object
- prompt_tokens = response.usage_metadata.prompt_token_count
- total_tokens = prompt_tokens
-
- async for chunk in response:
- if chunk.usage_metadata:
- # Use only the completion_tokens from the chunks. Prompt tokens are already counted and
- # are repeated here.
- completion_tokens += chunk.usage_metadata.candidates_token_count
- total_tokens += chunk.usage_metadata.candidates_token_count
- try:
- for c in chunk.parts:
- if c.text:
- search_result += c.text
- await self.push_frame(LLMTextFrame(c.text))
- elif c.function_call:
- logger.debug(f"Function call: {c.function_call}")
- args = type(c.function_call).to_dict(c.function_call).get("args", {})
- await self.call_function(
- context=context,
- tool_call_id=str(uuid.uuid4()),
- function_name=c.function_call.name,
- arguments=args,
- )
- # Handle grounding metadata
- # It seems only the last chunk that we receive may contain this information
- # If the response doesn't include groundingMetadata, this means the response wasn't grounded.
- if chunk.candidates:
- for candidate in chunk.candidates:
- # logger.debug(f"candidate received: {candidate}")
- # Extract grounding metadata
- grounding_metadata = (
- {
- "rendered_content": getattr(
- getattr(candidate, "grounding_metadata", None),
- "search_entry_point",
- None,
- ).rendered_content
- if hasattr(
- getattr(candidate, "grounding_metadata", None),
- "search_entry_point",
- )
- else None,
- "origins": [
- {
- "site_uri": getattr(grounding_chunk.web, "uri", None),
- "site_title": getattr(
- grounding_chunk.web, "title", None
- ),
- "results": [
- {
- "text": getattr(
- grounding_support.segment, "text", ""
- ),
- "confidence": getattr(
- grounding_support, "confidence_scores", None
- ),
- }
- for grounding_support in getattr(
- getattr(candidate, "grounding_metadata", None),
- "grounding_supports",
- [],
- )
- if index
- in getattr(
- grounding_support, "grounding_chunk_indices", []
- )
- ],
- }
- for index, grounding_chunk in enumerate(
- getattr(
- getattr(candidate, "grounding_metadata", None),
- "grounding_chunks",
- [],
- )
- )
- ],
- }
- if getattr(candidate, "grounding_metadata", None)
- else None
- )
- except Exception as e:
- # Google LLMs seem to flag safety issues a lot!
- if chunk.candidates[0].finish_reason == 3:
- logger.debug(
- f"LLM refused to generate content for safety reasons - {messages}."
- )
- else:
- logger.exception(f"{self} error: {e}")
-
- except DeadlineExceeded:
- await self._call_event_handler("on_completion_timeout")
- except Exception as e:
- logger.exception(f"{self} exception: {e}")
- finally:
- if grounding_metadata is not None and isinstance(grounding_metadata, dict):
- llm_search_frame = LLMSearchResponseFrame(
- search_result=search_result,
- origins=grounding_metadata["origins"],
- rendered_content=grounding_metadata["rendered_content"],
- )
- await self.push_frame(llm_search_frame)
-
- await self.start_llm_usage_metrics(
- LLMTokenUsage(
- prompt_tokens=prompt_tokens,
- completion_tokens=completion_tokens,
- total_tokens=total_tokens,
- )
- )
- await self.push_frame(LLMFullResponseEndFrame())
-
- async def process_frame(self, frame: Frame, direction: FrameDirection):
- await super().process_frame(frame, direction)
-
- context = None
-
- if isinstance(frame, OpenAILLMContextFrame):
- context = GoogleLLMContext.upgrade_to_google(frame.context)
- elif isinstance(frame, LLMMessagesFrame):
- context = GoogleLLMContext(frame.messages)
- elif isinstance(frame, VisionImageRawFrame):
- context = GoogleLLMContext()
- context.add_image_frame_message(
- format=frame.format, size=frame.size, image=frame.image, text=frame.text
- )
- elif isinstance(frame, LLMUpdateSettingsFrame):
- await self._update_settings(frame.settings)
- else:
- await self.push_frame(frame, direction)
-
- if context:
- await self._process_context(context)
-
- def create_context_aggregator(
- self,
- context: OpenAILLMContext,
- *,
- user_kwargs: Mapping[str, Any] = {},
- assistant_kwargs: Mapping[str, Any] = {},
- ) -> GoogleContextAggregatorPair:
- """Create an instance of GoogleContextAggregatorPair from an
- OpenAILLMContext. Constructor keyword arguments for both the user and
- assistant aggregators can be provided.
-
- Args:
- context (OpenAILLMContext): The LLM context.
- user_kwargs (Mapping[str, Any], optional): Additional keyword
- arguments for the user context aggregator constructor. Defaults
- to an empty mapping.
- assistant_kwargs (Mapping[str, Any], optional): Additional keyword
- arguments for the assistant context aggregator
- constructor. Defaults to an empty mapping.
-
- Returns:
- GoogleContextAggregatorPair: A pair of context aggregators, one for
- the user and one for the assistant, encapsulated in an
- GoogleContextAggregatorPair.
-
- """
- context.set_llm_adapter(self.get_llm_adapter())
-
- if isinstance(context, OpenAILLMContext):
- context = GoogleLLMContext.upgrade_to_google(context)
- user = GoogleUserContextAggregator(context, **user_kwargs)
- assistant = GoogleAssistantContextAggregator(context, **assistant_kwargs)
- return GoogleContextAggregatorPair(_user=user, _assistant=assistant)
-
-
-class GoogleLLMOpenAIBetaService(OpenAILLMService):
- """This class implements inference with Google's AI LLM models using the OpenAI format.
- Ref - https://ai.google.dev/gemini-api/docs/openai
- """
-
- def __init__(
- self,
- *,
- api_key: str,
- base_url: str = "https://generativelanguage.googleapis.com/v1beta/openai/",
- model: str = "gemini-2.0-flash",
- **kwargs,
- ):
- super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
-
- async def _process_context(self, context: OpenAILLMContext):
- functions_list = []
- arguments_list = []
- tool_id_list = []
- func_idx = 0
- function_name = ""
- arguments = ""
- tool_call_id = ""
-
- await self.start_ttfb_metrics()
-
- chunk_stream: AsyncStream[ChatCompletionChunk] = await self._stream_chat_completions(
- context
- )
-
- async for chunk in chunk_stream:
- if chunk.usage:
- tokens = LLMTokenUsage(
- prompt_tokens=chunk.usage.prompt_tokens,
- completion_tokens=chunk.usage.completion_tokens,
- total_tokens=chunk.usage.total_tokens,
- )
- await self.start_llm_usage_metrics(tokens)
-
- if chunk.choices is None or len(chunk.choices) == 0:
- continue
-
- await self.stop_ttfb_metrics()
-
- if not chunk.choices[0].delta:
- continue
-
- if chunk.choices[0].delta.tool_calls:
- # We're streaming the LLM response to enable the fastest response times.
- # For text, we just yield each chunk as we receive it and count on consumers
- # to do whatever coalescing they need (eg. to pass full sentences to TTS)
- #
- # If the LLM is a function call, we'll do some coalescing here.
- # If the response contains a function name, we'll yield a frame to tell consumers
- # that they can start preparing to call the function with that name.
- # We accumulate all the arguments for the rest of the streamed response, then when
- # the response is done, we package up all the arguments and the function name and
- # yield a frame containing the function name and the arguments.
- logger.debug(f"Tool call: {chunk.choices[0].delta.tool_calls}")
- tool_call = chunk.choices[0].delta.tool_calls[0]
- if tool_call.index != func_idx:
- functions_list.append(function_name)
- arguments_list.append(arguments)
- tool_id_list.append(tool_call_id)
- function_name = ""
- arguments = ""
- tool_call_id = ""
- func_idx += 1
- if tool_call.function and tool_call.function.name:
- function_name += tool_call.function.name
- tool_call_id = tool_call.id
- if tool_call.function and tool_call.function.arguments:
- # Keep iterating through the response to collect all the argument fragments
- arguments += tool_call.function.arguments
- elif chunk.choices[0].delta.content:
- await self.push_frame(LLMTextFrame(chunk.choices[0].delta.content))
-
- # if we got a function name and arguments, check to see if it's a function with
- # a registered handler. If so, run the registered callback, save the result to
- # the context, and re-prompt to get a chat answer. If we don't have a registered
- # handler, raise an exception.
- if function_name and arguments:
- # added to the list as last function name and arguments not added to the list
- functions_list.append(function_name)
- arguments_list.append(arguments)
- tool_id_list.append(tool_call_id)
-
- logger.debug(
- f"Function list: {functions_list}, Arguments list: {arguments_list}, Tool ID list: {tool_id_list}"
- )
- for index, (function_name, arguments, tool_id) in enumerate(
- zip(functions_list, arguments_list, tool_id_list), start=1
- ):
- if function_name == "":
- # TODO: Remove the _process_context method once Google resolves the bug
- # where the index is incorrectly set to None instead of returning the actual index,
- # which currently results in an empty function name('').
- continue
- if self.has_function(function_name):
- run_llm = False
- arguments = json.loads(arguments)
- await self.call_function(
- context=context,
- function_name=function_name,
- arguments=arguments,
- tool_call_id=tool_id,
- run_llm=run_llm,
- )
- else:
- raise OpenAIUnhandledFunctionException(
- f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
- )
-
-
-class GoogleVertexLLMService(OpenAILLMService):
- """Implements inference with Google's AI models via Vertex AI while
- maintaining OpenAI API compatibility.
-
- Reference:
- https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/call-vertex-using-openai-library
-
- """
-
- class InputParams(OpenAILLMService.InputParams):
- """Input parameters specific to Vertex AI."""
-
- # https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations
- location: str = "us-east4"
- project_id: str
-
- def __init__(
- self,
- *,
- credentials: Optional[str] = None,
- credentials_path: Optional[str] = None,
- model: str = "google/gemini-2.0-flash-001",
- params: InputParams = OpenAILLMService.InputParams(),
- **kwargs,
- ):
- """Initializes the VertexLLMService.
-
- Args:
- credentials (Optional[str]): JSON string of service account credentials.
- credentials_path (Optional[str]): Path to the service account JSON file.
- model (str): Model identifier. Defaults to "google/gemini-2.0-flash-001".
- params (InputParams): Vertex AI input parameters.
- **kwargs: Additional arguments for OpenAILLMService.
- """
- base_url = self._get_base_url(params)
- self._api_key = self._get_api_token(credentials, credentials_path)
-
- super().__init__(api_key=self._api_key, base_url=base_url, model=model, **kwargs)
-
- @staticmethod
- def _get_base_url(params: InputParams) -> str:
- """Constructs the base URL for Vertex AI API."""
- return (
- f"https://{params.location}-aiplatform.googleapis.com/v1/"
- f"projects/{params.project_id}/locations/{params.location}/endpoints/openapi"
- )
-
- @staticmethod
- def _get_api_token(credentials: Optional[str], credentials_path: Optional[str]) -> str:
- """Retrieves an authentication token using Google service account credentials.
-
- Args:
- credentials (Optional[str]): JSON string of service account credentials.
- credentials_path (Optional[str]): Path to the service account JSON file.
-
- Returns:
- str: OAuth token for API authentication.
- """
- creds: Optional[service_account.Credentials] = None
-
- if credentials:
- # Parse and load credentials from JSON string
- creds = service_account.Credentials.from_service_account_info(
- json.loads(credentials), scopes=["https://www.googleapis.com/auth/cloud-platform"]
- )
- elif credentials_path:
- # Load credentials from JSON file
- creds = service_account.Credentials.from_service_account_file(
- credentials_path, scopes=["https://www.googleapis.com/auth/cloud-platform"]
- )
-
- if not creds:
- raise ValueError("No valid credentials provided.")
-
- creds.refresh(Request()) # Ensure token is up-to-date, lifetime is 1 hour.
-
- return creds.token
-
-
-class GoogleTTSService(TTSService):
- class InputParams(BaseModel):
- pitch: Optional[str] = None
- rate: Optional[str] = None
- volume: Optional[str] = None
- emphasis: Optional[Literal["strong", "moderate", "reduced", "none"]] = None
- language: Optional[Language] = Language.EN
- gender: Optional[Literal["male", "female", "neutral"]] = None
- google_style: Optional[Literal["apologetic", "calm", "empathetic", "firm", "lively"]] = None
-
- def __init__(
- self,
- *,
- credentials: Optional[str] = None,
- credentials_path: Optional[str] = None,
- voice_id: str = "en-US-Neural2-A",
- sample_rate: Optional[int] = None,
- params: InputParams = InputParams(),
- **kwargs,
- ):
- super().__init__(sample_rate=sample_rate, **kwargs)
-
- self._settings = {
- "pitch": params.pitch,
- "rate": params.rate,
- "volume": params.volume,
- "emphasis": params.emphasis,
- "language": self.language_to_service_language(params.language)
- if params.language
- else "en-US",
- "gender": params.gender,
- "google_style": params.google_style,
- }
- self.set_voice(voice_id)
- self._client: texttospeech_v1.TextToSpeechAsyncClient = self._create_client(
- credentials, credentials_path
- )
-
- def _create_client(
- self, credentials: Optional[str], credentials_path: Optional[str]
- ) -> texttospeech_v1.TextToSpeechAsyncClient:
- creds: Optional[service_account.Credentials] = None
-
- # Create a Google Cloud service account for the Cloud Text-to-Speech API
- # Using either the provided credentials JSON string or the path to a service account JSON
- # file, create a Google Cloud service account and use it to authenticate with the API.
- if credentials:
- # Use provided credentials JSON string
- json_account_info = json.loads(credentials)
- creds = service_account.Credentials.from_service_account_info(json_account_info)
- elif credentials_path:
- # Use service account JSON file if provided
- creds = service_account.Credentials.from_service_account_file(credentials_path)
-
- return texttospeech_v1.TextToSpeechAsyncClient(credentials=creds)
-
- def can_generate_metrics(self) -> bool:
- return True
-
- def language_to_service_language(self, language: Language) -> Optional[str]:
- return language_to_google_tts_language(language)
-
- def _construct_ssml(self, text: str) -> str:
- ssml = ""
-
- # Voice tag
- voice_attrs = [f"name='{self._voice_id}'"]
-
- language = self._settings["language"]
- voice_attrs.append(f"language='{language}'")
-
- if self._settings["gender"]:
- voice_attrs.append(f"gender='{self._settings['gender']}'")
- ssml += f""
-
- # Prosody tag
- prosody_attrs = []
- if self._settings["pitch"]:
- prosody_attrs.append(f"pitch='{self._settings['pitch']}'")
- if self._settings["rate"]:
- prosody_attrs.append(f"rate='{self._settings['rate']}'")
- if self._settings["volume"]:
- prosody_attrs.append(f"volume='{self._settings['volume']}'")
-
- if prosody_attrs:
- ssml += f""
-
- # Emphasis tag
- if self._settings["emphasis"]:
- ssml += f""
-
- # Google style tag
- if self._settings["google_style"]:
- ssml += f""
-
- ssml += text
-
- # Close tags
- if self._settings["google_style"]:
- ssml += ""
- if self._settings["emphasis"]:
- ssml += ""
- if prosody_attrs:
- ssml += ""
- ssml += ""
-
- return ssml
-
- async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
- logger.debug(f"{self}: Generating TTS [{text}]")
-
- try:
- await self.start_ttfb_metrics()
-
- # Check if the voice is a Chirp voice (including Chirp 3) or Journey voice
- is_chirp_voice = "chirp" in self._voice_id.lower()
- is_journey_voice = "journey" in self._voice_id.lower()
-
- # Create synthesis input based on voice_id
- if is_chirp_voice or is_journey_voice:
- # Chirp and Journey voices don't support SSML, use plain text
- synthesis_input = texttospeech_v1.SynthesisInput(text=text)
- else:
- ssml = self._construct_ssml(text)
- synthesis_input = texttospeech_v1.SynthesisInput(ssml=ssml)
-
- voice = texttospeech_v1.VoiceSelectionParams(
- language_code=self._settings["language"], name=self._voice_id
- )
- audio_config = texttospeech_v1.AudioConfig(
- audio_encoding=texttospeech_v1.AudioEncoding.LINEAR16,
- sample_rate_hertz=self.sample_rate,
- )
-
- request = texttospeech_v1.SynthesizeSpeechRequest(
- input=synthesis_input, voice=voice, audio_config=audio_config
- )
-
- response = await self._client.synthesize_speech(request=request)
-
- await self.start_tts_usage_metrics(text)
-
- yield TTSStartedFrame()
-
- # Skip the first 44 bytes to remove the WAV header
- audio_content = response.audio_content[44:]
-
- # Read and yield audio data in chunks
- chunk_size = 8192
- for i in range(0, len(audio_content), chunk_size):
- chunk = audio_content[i : i + chunk_size]
- if not chunk:
- break
- await self.stop_ttfb_metrics()
- frame = TTSAudioRawFrame(chunk, self.sample_rate, 1)
- yield frame
- await asyncio.sleep(0) # Allow other tasks to run
-
- yield TTSStoppedFrame()
-
- except Exception as e:
- logger.exception(f"{self} error generating TTS: {e}")
- error_message = f"TTS generation error: {str(e)}"
- yield ErrorFrame(error=error_message)
-
-
-class GoogleImageGenService(ImageGenService):
- class InputParams(BaseModel):
- number_of_images: int = Field(default=1, ge=1, le=8)
- model: str = Field(default="imagen-3.0-generate-002")
- negative_prompt: str = Field(default=None)
-
- def __init__(
- self,
- *,
- params: InputParams = InputParams(),
- api_key: str,
- **kwargs,
- ):
- super().__init__(**kwargs)
- self.set_model_name(params.model)
- self._params = params
- self._client = genai.Client(api_key=api_key)
-
- def can_generate_metrics(self) -> bool:
- return True
-
- async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
- """Generate images from a text prompt using Google's Imagen model.
-
- Args:
- prompt (str): The text description to generate images from.
-
- Yields:
- Frame: Generated image frames or error frames.
- """
- logger.debug(f"Generating image from prompt: {prompt}")
- await self.start_ttfb_metrics()
-
- try:
- response = await self._client.aio.models.generate_images(
- model=self._params.model,
- prompt=prompt,
- config=types.GenerateImagesConfig(
- number_of_images=self._params.number_of_images,
- negative_prompt=self._params.negative_prompt,
- ),
- )
- await self.stop_ttfb_metrics()
-
- if not response or not response.generated_images:
- logger.error(f"{self} error: image generation failed")
- yield ErrorFrame("Image generation failed")
- return
-
- for img_response in response.generated_images:
- # Google returns the image data directly
- image_bytes = img_response.image.image_bytes
- image = Image.open(io.BytesIO(image_bytes))
-
- frame = URLImageRawFrame(
- url=None, # Google doesn't provide URLs, only image data
- image=image.tobytes(),
- size=image.size,
- format=image.format,
- )
- yield frame
-
- except Exception as e:
- logger.error(f"{self} error generating image: {e}")
- yield ErrorFrame(f"Image generation error: {str(e)}")
-
-
-class GoogleSTTService(STTService):
- """Google Cloud Speech-to-Text V2 service implementation.
-
- Provides real-time speech recognition using Google Cloud's Speech-to-Text V2 API
- with streaming support. Handles audio transcription and optional voice activity detection.
-
- Attributes:
- InputParams: Configuration parameters for the STT service.
- """
-
- # Google Cloud's STT service has a connection time limit of 5 minutes per stream.
- # They've shared an "endless streaming" example that guided this implementation:
- # https://cloud.google.com/speech-to-text/docs/transcribe-streaming-audio#endless-streaming
-
- STREAMING_LIMIT = 240000 # 4 minutes in milliseconds
-
- class InputParams(BaseModel):
- """Configuration parameters for Google Speech-to-Text.
-
- Attributes:
- languages: Single language or list of recognition languages. First language is primary.
- model: Speech recognition model to use.
- use_separate_recognition_per_channel: Process each audio channel separately.
- enable_automatic_punctuation: Add punctuation to transcripts.
- enable_spoken_punctuation: Include spoken punctuation in transcript.
- enable_spoken_emojis: Include spoken emojis in transcript.
- profanity_filter: Filter profanity from transcript.
- enable_word_time_offsets: Include timing information for each word.
- enable_word_confidence: Include confidence scores for each word.
- enable_interim_results: Stream partial recognition results.
- enable_voice_activity_events: Detect voice activity in audio.
- """
-
- languages: Union[Language, List[Language]] = Field(default_factory=lambda: [Language.EN_US])
- model: Optional[str] = "latest_long"
- use_separate_recognition_per_channel: Optional[bool] = False
- enable_automatic_punctuation: Optional[bool] = True
- enable_spoken_punctuation: Optional[bool] = False
- enable_spoken_emojis: Optional[bool] = False
- profanity_filter: Optional[bool] = False
- enable_word_time_offsets: Optional[bool] = False
- enable_word_confidence: Optional[bool] = False
- enable_interim_results: Optional[bool] = True
- enable_voice_activity_events: Optional[bool] = False
-
- @field_validator("languages", mode="before")
- @classmethod
- def validate_languages(cls, v) -> List[Language]:
- if isinstance(v, Language):
- return [v]
- return v
-
- @property
- def language_list(self) -> List[Language]:
- """Get languages as a guaranteed list."""
- assert isinstance(self.languages, list)
- return self.languages
-
- def __init__(
- self,
- *,
- credentials: Optional[str] = None,
- credentials_path: Optional[str] = None,
- location: str = "global",
- sample_rate: Optional[int] = None,
- params: InputParams = InputParams(),
- **kwargs,
- ):
- """Initialize the Google STT service.
-
- Args:
- credentials: JSON string containing Google Cloud service account credentials.
- credentials_path: Path to service account credentials JSON file.
- location: Google Cloud location (e.g., "global", "us-central1").
- sample_rate: Audio sample rate in Hertz.
- params: Configuration parameters for the service.
- **kwargs: Additional arguments passed to STTService.
-
- Raises:
- ValueError: If neither credentials nor credentials_path is provided.
- ValueError: If project ID is not found in credentials.
- """
- super().__init__(sample_rate=sample_rate, **kwargs)
-
- self._location = location
- self._stream = None
- self._config = None
- self._request_queue = asyncio.Queue()
- self._streaming_task = None
-
- # Used for keep-alive logic
- self._stream_start_time = 0
- self._last_audio_input = []
- self._audio_input = []
- self._result_end_time = 0
- self._is_final_end_time = 0
- self._final_request_end_time = 0
- self._bridging_offset = 0
- self._last_transcript_was_final = False
- self._new_stream = True
- self._restart_counter = 0
-
- # Configure client options based on location
- client_options = None
- if self._location != "global":
- client_options = ClientOptions(api_endpoint=f"{self._location}-speech.googleapis.com")
-
- # Extract project ID and create client
- if credentials:
- json_account_info = json.loads(credentials)
- self._project_id = json_account_info.get("project_id")
- creds = service_account.Credentials.from_service_account_info(json_account_info)
- elif credentials_path:
- with open(credentials_path) as f:
- json_account_info = json.load(f)
- self._project_id = json_account_info.get("project_id")
- creds = service_account.Credentials.from_service_account_file(credentials_path)
- else:
- raise ValueError("Either credentials or credentials_path must be provided")
-
- if not self._project_id:
- raise ValueError("Project ID not found in credentials")
-
- self._client = speech_v2.SpeechAsyncClient(credentials=creds, client_options=client_options)
-
- self._settings = {
- "language_codes": [
- self.language_to_service_language(lang) for lang in params.language_list
- ],
- "model": params.model,
- "use_separate_recognition_per_channel": params.use_separate_recognition_per_channel,
- "enable_automatic_punctuation": params.enable_automatic_punctuation,
- "enable_spoken_punctuation": params.enable_spoken_punctuation,
- "enable_spoken_emojis": params.enable_spoken_emojis,
- "profanity_filter": params.profanity_filter,
- "enable_word_time_offsets": params.enable_word_time_offsets,
- "enable_word_confidence": params.enable_word_confidence,
- "enable_interim_results": params.enable_interim_results,
- "enable_voice_activity_events": params.enable_voice_activity_events,
- }
-
- def language_to_service_language(self, language: Language | List[Language]) -> str | List[str]:
- """Convert Language enum(s) to Google STT language code(s).
-
- Args:
- language: Single Language enum or list of Language enums.
-
- Returns:
- str | List[str]: Google STT language code(s).
- """
- if isinstance(language, list):
- return [language_to_google_stt_language(lang) or "en-US" for lang in language]
- return language_to_google_stt_language(language) or "en-US"
-
- async def _reconnect_if_needed(self):
- """Reconnect the stream if it's currently active."""
- if self._streaming_task:
- logger.debug("Reconnecting stream due to configuration changes")
- await self._disconnect()
- await self._connect()
-
- async def set_language(self, language: Language):
- """Update the service's recognition language.
-
- A convenience method for setting a single language.
-
- Args:
- language: New language for recognition.
- """
- logger.debug(f"Switching STT language to: {language}")
- await self.set_languages([language])
-
- async def set_languages(self, languages: List[Language]):
- """Update the service's recognition languages.
-
- Args:
- languages: List of languages for recognition. First language is primary.
- """
- logger.debug(f"Switching STT languages to: {languages}")
- self._settings["language_codes"] = [
- self.language_to_service_language(lang) for lang in languages
- ]
- # Recreate stream with new languages
- await self._reconnect_if_needed()
-
- async def set_model(self, model: str):
- """Update the service's recognition model."""
- logger.debug(f"Switching STT model to: {model}")
- await super().set_model(model)
- self._settings["model"] = model
- # Recreate stream with new model
- await self._reconnect_if_needed()
-
- async def start(self, frame: StartFrame):
- await super().start(frame)
- await self._connect()
-
- async def stop(self, frame: EndFrame):
- await super().stop(frame)
- await self._disconnect()
-
- async def cancel(self, frame: CancelFrame):
- await super().cancel(frame)
- await self._disconnect()
-
- async def update_options(
- self,
- *,
- languages: Optional[List[Language]] = None,
- model: Optional[str] = None,
- enable_automatic_punctuation: Optional[bool] = None,
- enable_spoken_punctuation: Optional[bool] = None,
- enable_spoken_emojis: Optional[bool] = None,
- profanity_filter: Optional[bool] = None,
- enable_word_time_offsets: Optional[bool] = None,
- enable_word_confidence: Optional[bool] = None,
- enable_interim_results: Optional[bool] = None,
- enable_voice_activity_events: Optional[bool] = None,
- location: Optional[str] = None,
- ) -> None:
- """Update service options dynamically.
-
- Args:
- languages: New list of recongition languages.
- model: New recognition model.
- enable_automatic_punctuation: Enable/disable automatic punctuation.
- enable_spoken_punctuation: Enable/disable spoken punctuation.
- enable_spoken_emojis: Enable/disable spoken emojis.
- profanity_filter: Enable/disable profanity filter.
- enable_word_time_offsets: Enable/disable word timing info.
- enable_word_confidence: Enable/disable word confidence scores.
- enable_interim_results: Enable/disable interim results.
- enable_voice_activity_events: Enable/disable voice activity detection.
- location: New Google Cloud location.
-
- Note:
- Changes that affect the streaming configuration will cause
- the stream to be reconnected.
- """
- # Update settings with new values
- if languages is not None:
- logger.debug(f"Updating language to: {languages}")
- self._settings["language_codes"] = [
- self.language_to_service_language(lang) for lang in languages
- ]
-
- if model is not None:
- logger.debug(f"Updating model to: {model}")
- self._settings["model"] = model
-
- if enable_automatic_punctuation is not None:
- logger.debug(f"Updating automatic punctuation to: {enable_automatic_punctuation}")
- self._settings["enable_automatic_punctuation"] = enable_automatic_punctuation
-
- if enable_spoken_punctuation is not None:
- logger.debug(f"Updating spoken punctuation to: {enable_spoken_punctuation}")
- self._settings["enable_spoken_punctuation"] = enable_spoken_punctuation
-
- if enable_spoken_emojis is not None:
- logger.debug(f"Updating spoken emojis to: {enable_spoken_emojis}")
- self._settings["enable_spoken_emojis"] = enable_spoken_emojis
-
- if profanity_filter is not None:
- logger.debug(f"Updating profanity filter to: {profanity_filter}")
- self._settings["profanity_filter"] = profanity_filter
-
- if enable_word_time_offsets is not None:
- logger.debug(f"Updating word time offsets to: {enable_word_time_offsets}")
- self._settings["enable_word_time_offsets"] = enable_word_time_offsets
-
- if enable_word_confidence is not None:
- logger.debug(f"Updating word confidence to: {enable_word_confidence}")
- self._settings["enable_word_confidence"] = enable_word_confidence
-
- if enable_interim_results is not None:
- logger.debug(f"Updating interim results to: {enable_interim_results}")
- self._settings["enable_interim_results"] = enable_interim_results
-
- if enable_voice_activity_events is not None:
- logger.debug(f"Updating voice activity events to: {enable_voice_activity_events}")
- self._settings["enable_voice_activity_events"] = enable_voice_activity_events
-
- if location is not None:
- logger.debug(f"Updating location to: {location}")
- self._location = location
-
- # Reconnect the stream for updates
- await self._reconnect_if_needed()
-
- async def _connect(self):
- """Initialize streaming recognition config and stream."""
- logger.debug("Connecting to Google Speech-to-Text")
-
- # Set stream start time
- self._stream_start_time = int(time.time() * 1000)
- self._new_stream = True
-
- self._config = cloud_speech.StreamingRecognitionConfig(
- config=cloud_speech.RecognitionConfig(
- explicit_decoding_config=cloud_speech.ExplicitDecodingConfig(
- encoding=cloud_speech.ExplicitDecodingConfig.AudioEncoding.LINEAR16,
- sample_rate_hertz=self.sample_rate,
- audio_channel_count=1,
- ),
- language_codes=self._settings["language_codes"],
- model=self._settings["model"],
- features=cloud_speech.RecognitionFeatures(
- enable_automatic_punctuation=self._settings["enable_automatic_punctuation"],
- enable_spoken_punctuation=self._settings["enable_spoken_punctuation"],
- enable_spoken_emojis=self._settings["enable_spoken_emojis"],
- profanity_filter=self._settings["profanity_filter"],
- enable_word_time_offsets=self._settings["enable_word_time_offsets"],
- enable_word_confidence=self._settings["enable_word_confidence"],
- ),
- ),
- streaming_features=cloud_speech.StreamingRecognitionFeatures(
- enable_voice_activity_events=self._settings["enable_voice_activity_events"],
- interim_results=self._settings["enable_interim_results"],
- ),
- )
-
- self._streaming_task = self.create_task(self._stream_audio())
-
- async def _disconnect(self):
- """Clean up streaming recognition resources."""
- if self._streaming_task:
- logger.debug("Disconnecting from Google Speech-to-Text")
- # Send sentinel value to stop request generator
- await self._request_queue.put(None)
- await self.cancel_task(self._streaming_task)
- self._streaming_task = None
- # Clear any remaining items in the queue
- while not self._request_queue.empty():
- try:
- self._request_queue.get_nowait()
- self._request_queue.task_done()
- except asyncio.QueueEmpty:
- break
-
- async def _request_generator(self):
- """Generates requests for the streaming recognize method."""
- recognizer_path = f"projects/{self._project_id}/locations/{self._location}/recognizers/_"
- logger.trace(f"Using recognizer path: {recognizer_path}")
-
- try:
- # Send initial config
- yield cloud_speech.StreamingRecognizeRequest(
- recognizer=recognizer_path,
- streaming_config=self._config,
- )
-
- while True:
- try:
- audio_data = await self._request_queue.get()
- if audio_data is None: # Sentinel value to stop
- break
-
- # Check streaming limit
- if (int(time.time() * 1000) - self._stream_start_time) > self.STREAMING_LIMIT:
- logger.debug("Streaming limit reached, initiating graceful reconnection")
- # Instead of immediate reconnection, we'll break and let the stream close naturally
- self._last_audio_input = self._audio_input
- self._audio_input = []
- self._restart_counter += 1
- # Put the current audio chunk back in the queue
- await self._request_queue.put(audio_data)
- break
-
- self._audio_input.append(audio_data)
- yield cloud_speech.StreamingRecognizeRequest(audio=audio_data)
-
- except asyncio.CancelledError:
- break
- finally:
- self._request_queue.task_done()
-
- except Exception as e:
- logger.error(f"Error in request generator: {e}")
- raise
-
- async def _stream_audio(self):
- """Handle bi-directional streaming with Google STT."""
- try:
- while True:
- try:
- # Start bi-directional streaming
- streaming_recognize = await self._client.streaming_recognize(
- requests=self._request_generator()
- )
-
- # Process responses
- await self._process_responses(streaming_recognize)
-
- # If we're here, check if we need to reconnect
- if (int(time.time() * 1000) - self._stream_start_time) > self.STREAMING_LIMIT:
- logger.debug("Reconnecting stream after timeout")
- # Reset stream start time
- self._stream_start_time = int(time.time() * 1000)
- continue
- else:
- # Normal stream end
- break
-
- except Exception as e:
- logger.warning(f"{self} Reconnecting: {e}")
-
- await asyncio.sleep(1) # Brief delay before reconnecting
- self._stream_start_time = int(time.time() * 1000)
- continue
-
- except Exception as e:
- logger.error(f"Error in streaming task: {e}")
- await self.push_frame(ErrorFrame(str(e)))
-
- async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
- """Process an audio chunk for STT transcription."""
- if self._streaming_task:
- # Queue the audio data
- await self._request_queue.put(audio)
- yield None
-
- async def _process_responses(self, streaming_recognize):
- """Process streaming recognition responses."""
- try:
- async for response in streaming_recognize:
- # Check streaming limit
- if (int(time.time() * 1000) - self._stream_start_time) > self.STREAMING_LIMIT:
- logger.debug("Stream timeout reached in response processing")
- break
-
- if not response.results:
- continue
-
- for result in response.results:
- if not result.alternatives:
- continue
-
- transcript = result.alternatives[0].transcript
- if not transcript:
- continue
-
- primary_language = self._settings["language_codes"][0]
-
- if result.is_final:
- self._last_transcript_was_final = True
- await self.push_frame(
- TranscriptionFrame(transcript, "", time_now_iso8601(), primary_language)
- )
- else:
- self._last_transcript_was_final = False
- await self.push_frame(
- InterimTranscriptionFrame(
- transcript, "", time_now_iso8601(), primary_language
- )
- )
-
- except Exception as e:
- logger.error(f"Error processing Google STT responses: {e}")
-
- # Re-raise the exception to let it propagate (e.g. in the case of a timeout, propagate to _stream_audio to reconnect)
- raise
diff --git a/src/pipecat/services/google/image.py b/src/pipecat/services/google/image.py
new file mode 100644
index 000000000..f7a7764f2
--- /dev/null
+++ b/src/pipecat/services/google/image.py
@@ -0,0 +1,95 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import io
+import os
+
+# Suppress gRPC fork warnings
+os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
+
+from typing import AsyncGenerator, Optional
+
+from loguru import logger
+from PIL import Image
+from pydantic import BaseModel, Field
+
+from pipecat.frames.frames import ErrorFrame, Frame, URLImageRawFrame
+from pipecat.services.ai_services import ImageGenService
+
+try:
+ from google import genai
+ from google.genai import types
+except ModuleNotFoundError as e:
+ logger.error(f"Exception: {e}")
+ logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.")
+ raise Exception(f"Missing module: {e}")
+
+
+class GoogleImageGenService(ImageGenService):
+    class InputParams(BaseModel):  # Options forwarded to the image generation request.
+        number_of_images: int = Field(default=1, ge=1, le=8)
+        model: str = Field(default="imagen-3.0-generate-002")
+        negative_prompt: Optional[str] = Field(default=None)  # None is the valid default
+
+ def __init__(
+ self,
+ *,
+ params: InputParams = InputParams(),
+ api_key: str,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.set_model_name(params.model)
+ self._params = params
+ self._client = genai.Client(api_key=api_key)
+
+ def can_generate_metrics(self) -> bool:
+ return True
+
+ async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
+ """Generate images from a text prompt using Google's Imagen model.
+
+ Args:
+ prompt (str): The text description to generate images from.
+
+ Yields:
+ Frame: Generated image frames or error frames.
+ """
+ logger.debug(f"Generating image from prompt: {prompt}")
+ await self.start_ttfb_metrics()
+
+ try:
+ response = await self._client.aio.models.generate_images(
+ model=self._params.model,
+ prompt=prompt,
+ config=types.GenerateImagesConfig(
+ number_of_images=self._params.number_of_images,
+ negative_prompt=self._params.negative_prompt,
+ ),
+ )
+ await self.stop_ttfb_metrics()
+
+ if not response or not response.generated_images:
+ logger.error(f"{self} error: image generation failed")
+ yield ErrorFrame("Image generation failed")
+ return
+
+ for img_response in response.generated_images:
+ # Google returns the image data directly
+ image_bytes = img_response.image.image_bytes
+ image = Image.open(io.BytesIO(image_bytes))
+
+ frame = URLImageRawFrame(
+ url=None, # Google doesn't provide URLs, only image data
+ image=image.tobytes(),
+ size=image.size,
+ format=image.format,
+ )
+ yield frame
+
+ except Exception as e:
+ logger.error(f"{self} error generating image: {e}")
+ yield ErrorFrame(f"Image generation error: {str(e)}")
diff --git a/src/pipecat/services/google/llm.py b/src/pipecat/services/google/llm.py
new file mode 100644
index 000000000..e78e3949b
--- /dev/null
+++ b/src/pipecat/services/google/llm.py
@@ -0,0 +1,717 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import base64
+import io
+import json
+import os
+import uuid
+
+from google.api_core.exceptions import DeadlineExceeded
+
+from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter
+
+# Suppress gRPC fork warnings
+os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Mapping, Optional, Union
+
+from loguru import logger
+from PIL import Image
+from pydantic import BaseModel, Field
+
+from pipecat.frames.frames import (
+ AudioRawFrame,
+ Frame,
+ FunctionCallCancelFrame,
+ FunctionCallInProgressFrame,
+ FunctionCallResultFrame,
+ LLMFullResponseEndFrame,
+ LLMFullResponseStartFrame,
+ LLMMessagesFrame,
+ LLMTextFrame,
+ LLMUpdateSettingsFrame,
+ UserImageRawFrame,
+ VisionImageRawFrame,
+)
+from pipecat.metrics.metrics import LLMTokenUsage
+from pipecat.processors.aggregators.openai_llm_context import (
+ OpenAILLMContext,
+ OpenAILLMContextFrame,
+)
+from pipecat.processors.frame_processor import FrameDirection
+from pipecat.services.ai_services import LLMService
+from pipecat.services.google.frames import LLMSearchResponseFrame
+from pipecat.services.openai.llm import (
+ OpenAIAssistantContextAggregator,
+ OpenAIUserContextAggregator,
+)
+
+try:
+ import google.ai.generativelanguage as glm
+ import google.generativeai as gai
+ from google.generativeai.types import GenerationConfig
+
+except ModuleNotFoundError as e:
+ logger.error(f"Exception: {e}")
+ logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.")
+ raise Exception(f"Missing module: {e}")
+
+
+class GoogleUserContextAggregator(OpenAIUserContextAggregator):
+ async def push_aggregation(self):
+ if len(self._aggregation) > 0:
+ self._context.add_message(
+ glm.Content(role="user", parts=[glm.Part(text=self._aggregation)])
+ )
+
+ # Reset the aggregation. Reset it before pushing it down, otherwise
+ # if the tasks gets cancelled we won't be able to clear things up.
+ self._aggregation = ""
+
+ # Push context frame
+ frame = OpenAILLMContextFrame(self._context)
+ await self.push_frame(frame)
+
+ # Reset our accumulator state.
+ self.reset()
+
+
+class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator):
+ async def handle_aggregation(self, aggregation: str):
+ self._context.add_message(glm.Content(role="model", parts=[glm.Part(text=aggregation)]))
+
+ async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
+ self._context.add_message(
+ glm.Content(
+ role="model",
+ parts=[
+ glm.Part(
+ function_call=glm.FunctionCall(
+ id=frame.tool_call_id, name=frame.function_name, args=frame.arguments
+ )
+ )
+ ],
+ )
+ )
+ self._context.add_message(
+ glm.Content(
+ role="user",
+ parts=[
+ glm.Part(
+ function_response=glm.FunctionResponse(
+ id=frame.tool_call_id,
+ name=frame.function_name,
+ response={"response": "IN_PROGRESS"},
+ )
+ )
+ ],
+ )
+ )
+
+ async def handle_function_call_result(self, frame: FunctionCallResultFrame):
+ if frame.result:
+ await self._update_function_call_result(
+ frame.function_name, frame.tool_call_id, frame.result
+ )
+ else:
+ await self._update_function_call_result(
+ frame.function_name, frame.tool_call_id, "COMPLETED"
+ )
+
+ async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
+ await self._update_function_call_result(
+ frame.function_name, frame.tool_call_id, "CANCELLED"
+ )
+
+ async def _update_function_call_result(
+ self, function_name: str, tool_call_id: str, result: Any
+ ):
+ for message in self._context.messages:
+ if message.role == "user":
+ for part in message.parts:
+ if part.function_response and part.function_response.id == tool_call_id:
+ part.function_response.response = {"value": json.dumps(result)}
+
+ async def handle_user_image_frame(self, frame: UserImageRawFrame):
+ await self._update_function_call_result(
+ frame.request.function_name, frame.request.tool_call_id, "COMPLETED"
+ )
+ self._context.add_image_frame_message(
+ format=frame.format,
+ size=frame.size,
+ image=frame.image,
+ text=frame.request.context,
+ )
+
+
+@dataclass
+class GoogleContextAggregatorPair:
+ _user: GoogleUserContextAggregator
+ _assistant: GoogleAssistantContextAggregator
+
+ def user(self) -> GoogleUserContextAggregator:
+ return self._user
+
+ def assistant(self) -> GoogleAssistantContextAggregator:
+ return self._assistant
+
+
+class GoogleLLMContext(OpenAILLMContext):
+ def __init__(
+ self,
+ messages: Optional[List[dict]] = None,
+ tools: Optional[List[dict]] = None,
+ tool_choice: Optional[dict] = None,
+ ):
+ super().__init__(messages=messages, tools=tools, tool_choice=tool_choice)
+ self.system_message = None
+
+ @staticmethod
+ def upgrade_to_google(obj: OpenAILLMContext) -> "GoogleLLMContext":
+ if isinstance(obj, OpenAILLMContext) and not isinstance(obj, GoogleLLMContext):
+ logger.debug(f"Upgrading to Google: {obj}")
+ obj.__class__ = GoogleLLMContext
+ obj._restructure_from_openai_messages()
+ return obj
+
+ def set_messages(self, messages: List):
+ self._messages[:] = messages
+ self._restructure_from_openai_messages()
+
+ def add_messages(self, messages: List):
+ # Convert each message individually
+ converted_messages = []
+ for msg in messages:
+ if isinstance(msg, glm.Content):
+ # Already in Gemini format
+ converted_messages.append(msg)
+ else:
+ # Convert from standard format to Gemini format
+ converted = self.from_standard_message(msg)
+ if converted is not None:
+ converted_messages.append(converted)
+
+ # Add the converted messages to our existing messages
+ self._messages.extend(converted_messages)
+
+ def get_messages_for_logging(self):
+ msgs = []
+ for message in self.messages:
+ obj = glm.Content.to_dict(message)
+ try:
+ if "parts" in obj:
+ for part in obj["parts"]:
+ if "inline_data" in part:
+ part["inline_data"]["data"] = "..."
+ except Exception as e:
+ logger.debug(f"Error: {e}")
+ msgs.append(obj)
+ return msgs
+
+ def add_image_frame_message(
+ self, *, format: str, size: tuple[int, int], image: bytes, text: str = None
+ ):
+ buffer = io.BytesIO()
+ Image.frombytes(format, size, image).save(buffer, format="JPEG")
+
+ parts = []
+ if text:
+ parts.append(glm.Part(text=text))
+ parts.append(glm.Part(inline_data=glm.Blob(mime_type="image/jpeg", data=buffer.getvalue())))
+
+ self.add_message(glm.Content(role="user", parts=parts))
+
+ def add_audio_frames_message(
+ self, *, audio_frames: list[AudioRawFrame], text: str = "Audio follows"
+ ):
+ if not audio_frames:
+ return
+
+ sample_rate = audio_frames[0].sample_rate
+ num_channels = audio_frames[0].num_channels
+
+ parts = []
+ data = b"".join(frame.audio for frame in audio_frames)
+ # NOTE(aleix): According to the docs only text or inline_data should be needed.
+ # (see https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference)
+ parts.append(glm.Part(text=text))
+ parts.append(
+ glm.Part(
+ inline_data=glm.Blob(
+ mime_type="audio/wav",
+ data=(
+ bytes(
+ self.create_wav_header(sample_rate, num_channels, 16, len(data)) + data
+ )
+ ),
+ )
+ ),
+ )
+ self.add_message(glm.Content(role="user", parts=parts))
+ # message = {"mime_type": "audio/mp3", "data": bytes(data + create_wav_header(sample_rate, num_channels, 16, len(data)))}
+ # self.add_message(message)
+
+ def from_standard_message(self, message):
+ """Convert standard format message to Google Content object.
+
+ Handles conversion of text, images, and function calls to Google's format.
+ System messages are stored separately and return None.
+
+ Args:
+ message: Message in standard format:
+ {
+ "role": "user/assistant/system/tool",
+ "content": str | [{"type": "text/image_url", ...}] | None,
+ "tool_calls": [{"function": {"name": str, "arguments": str}}]
+ }
+
+ Returns:
+ glm.Content object with:
+ - role: "user" or "model" (converted from "assistant")
+ - parts: List[Part] containing text, inline_data, or function calls
+ Returns None for system messages.
+ """
+ role = message["role"]
+ content = message.get("content", [])
+ if role == "system":
+ self.system_message = content
+ return None
+ elif role == "assistant":
+ role = "model"
+
+ parts = []
+ if message.get("tool_calls"):
+ for tc in message["tool_calls"]:
+ parts.append(
+ glm.Part(
+ function_call=glm.FunctionCall(
+ name=tc["function"]["name"],
+ args=json.loads(tc["function"]["arguments"]),
+ )
+ )
+ )
+ elif role == "tool":
+ role = "model"
+ parts.append(
+ glm.Part(
+ function_response=glm.FunctionResponse(
+ name="tool_call_result", # seems to work to hard-code the same name every time
+ response=json.loads(message["content"]),
+ )
+ )
+ )
+ elif isinstance(content, str):
+ parts.append(glm.Part(text=content))
+ elif isinstance(content, list):
+ for c in content:
+ if c["type"] == "text":
+ parts.append(glm.Part(text=c["text"]))
+ elif c["type"] == "image_url":
+ parts.append(
+ glm.Part(
+ inline_data=glm.Blob(
+ mime_type="image/jpeg",
+ data=base64.b64decode(c["image_url"]["url"].split(",")[1]),
+ )
+ )
+ )
+
+ message = glm.Content(role=role, parts=parts)
+ return message
+
+    def to_standard_messages(self, obj) -> list:
+        """Convert Google Content object to standard structured format.
+
+        Handles text, images, and function calls (each call part is added to "tool_calls").
+
+        Args:
+            obj: Google Content object with:
+                - role: "model" (converted to "assistant") or "user"
+                - parts: List[Part] containing text, inline_data, or function calls
+
+        Returns:
+            List of messages in standard format:
+            [
+                {
+                    "role": "user/assistant/tool",
+                    "content": [
+                        {"type": "text", "text": str} |
+                        {"type": "image_url", "image_url": {"url": str}}
+                    ]
+                }
+            ]
+        """
+        msg = {"role": obj.role, "content": []}
+        if msg["role"] == "model":
+            msg["role"] = "assistant"
+
+        for part in obj.parts:
+            if part.text:
+                msg["content"].append({"type": "text", "text": part.text})
+            elif part.inline_data:
+                encoded = base64.b64encode(part.inline_data.data).decode("utf-8")
+                msg["content"].append(
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": f"data:{part.inline_data.mime_type};base64,{encoded}"},
+                    }
+                )
+            elif part.function_call:
+                # Append instead of assigning so several call parts in one message all survive.
+                args = type(part.function_call).to_dict(part.function_call).get("args", {})
+                msg.setdefault("tool_calls", []).append(
+                    {
+                        "id": part.function_call.name,
+                        "type": "function",
+                        "function": {
+                            "name": part.function_call.name,
+                            "arguments": json.dumps(args),
+                        },
+                    }
+                )
+            elif part.function_response:
+                msg["role"] = "tool"
+                resp = (
+                    type(part.function_response).to_dict(part.function_response).get("response", {})
+                )
+                msg["tool_call_id"] = part.function_response.name
+                msg["content"] = json.dumps(resp)
+
+        # there might be no content parts for tool_calls messages
+        if not msg["content"]:
+            del msg["content"]
+        return [msg]
+
+    def _restructure_from_openai_messages(self):
+        """Restructures messages to ensure proper Google format and message ordering.
+
+        This method handles conversion of OpenAI-formatted messages to Google format,
+        with special handling for function calls, function responses, and system messages.
+        System messages are added back to the context as user messages when needed.
+
+        The final message order is preserved as:
+        1. Function calls (from model)
+        2. Function responses (from user)
+        3. Text messages (converted from system messages)
+
+        Note:
+            System messages are only added back when there are no regular text
+            messages in the context, ensuring proper conversation continuity
+            after function calls.
+        """
+        self.system_message = None
+        converted_messages = []
+
+        # Process each message, preserving Google-formatted messages and converting others
+        for message in self._messages:
+            if isinstance(message, glm.Content):
+                # Keep existing Google-formatted messages (e.g., function calls/responses)
+                converted_messages.append(message)
+                continue
+
+            # Convert OpenAI format to Google format, system messages return None
+            converted = self.from_standard_message(message)
+            if converted is not None:
+                converted_messages.append(converted)
+
+        # Update message list
+        self._messages[:] = converted_messages
+
+        # A regular message is a single text part that is neither a function call nor a response
+        has_regular_messages = any(
+            len(msg.parts) == 1
+            and getattr(msg.parts[0], "text", None)
+            and not getattr(msg.parts[0], "function_call", None)
+            and not getattr(msg.parts[0], "function_response", None)
+            for msg in self._messages
+        )
+
+        # Add system message back as a user message if we only have function messages
+        if self.system_message and not has_regular_messages:
+            self._messages.append(
+                glm.Content(role="user", parts=[glm.Part(text=self.system_message)])
+            )
+
+        # Remove any empty messages (in place, like the update above, to keep the same list)
+        self._messages[:] = [m for m in self._messages if m.parts]
+
+
+class GoogleLLMService(LLMService):
+ """This class implements inference with Google's AI models.
+
+ This service translates internally from OpenAILLMContext to the messages format
+ expected by the Google AI model. We are using the OpenAILLMContext as a lingua
+ franca for all LLM services, so that it is easy to switch between different LLMs.
+ """
+
+ # Overriding the default adapter to use the Gemini one.
+ adapter_class = GeminiLLMAdapter
+
+ class InputParams(BaseModel):
+ max_tokens: Optional[int] = Field(default=4096, ge=1)
+ temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0)
+ top_k: Optional[int] = Field(default=None, ge=0)
+ top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0)
+ extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
+
+ def __init__(
+ self,
+ *,
+ api_key: str,
+ model: str = "gemini-2.0-flash-001",
+ params: InputParams = InputParams(),
+ system_instruction: Optional[str] = None,
+ tools: Optional[List[Dict[str, Any]]] = None,
+ tool_config: Optional[Dict[str, Any]] = None,
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+ gai.configure(api_key=api_key)
+ self.set_model_name(model)
+ self._system_instruction = system_instruction
+ self._create_client()
+ self._settings = {
+ "max_tokens": params.max_tokens,
+ "temperature": params.temperature,
+ "top_k": params.top_k,
+ "top_p": params.top_p,
+ "extra": params.extra if isinstance(params.extra, dict) else {},
+ }
+ self._tools = tools
+ self._tool_config = tool_config
+
+ def can_generate_metrics(self) -> bool:
+ return True
+
+ def _create_client(self):
+ self._client = gai.GenerativeModel(
+ self._model_name, system_instruction=self._system_instruction
+ )
+
+    async def _process_context(self, context: OpenAILLMContext):
+        """Stream one completion for `context` and push the resulting frames.
+
+        Text parts are pushed as LLMTextFrame; function-call parts go to `call_function`.
+        """
+        await self.push_frame(LLMFullResponseStartFrame())
+
+        prompt_tokens = 0
+        completion_tokens = 0
+        total_tokens = 0
+
+        grounding_metadata = None
+        search_result = ""
+
+        try:
+            logger.debug(f"{self}: Generating chat [{context.get_messages_for_logging()}]")
+
+            messages = context.messages
+            if context.system_message and self._system_instruction != context.system_message:
+                logger.debug(f"System instruction changed: {context.system_message}")
+                self._system_instruction = context.system_message
+                self._create_client()
+
+            # Filter out None values and create GenerationConfig
+            generation_params = {
+                k: v
+                for k, v in {
+                    "temperature": self._settings["temperature"],
+                    "top_p": self._settings["top_p"],
+                    "top_k": self._settings["top_k"],
+                    "max_output_tokens": self._settings["max_tokens"],
+                }.items()
+                if v is not None
+            }
+
+            generation_config = GenerationConfig(**generation_params) if generation_params else None
+
+            await self.start_ttfb_metrics()
+            tools = []
+            if context.tools:
+                tools = context.tools
+            elif self._tools:
+                tools = self._tools
+            tool_config = self._tool_config or None
+            response = await self._client.generate_content_async(
+                contents=messages,
+                tools=tools,
+                stream=True,
+                generation_config=generation_config,
+                tool_config=tool_config,
+            )
+            await self.stop_ttfb_metrics()
+
+            if response.usage_metadata:
+                # Use only the prompt token count from the response object
+                prompt_tokens = response.usage_metadata.prompt_token_count
+                total_tokens = prompt_tokens
+
+            async for chunk in response:
+                if chunk.usage_metadata:
+                    # Use only the completion_tokens from the chunks. Prompt tokens are already counted and
+                    # are repeated here.
+                    completion_tokens += chunk.usage_metadata.candidates_token_count
+                    total_tokens += chunk.usage_metadata.candidates_token_count
+                try:
+                    for c in chunk.parts:
+                        if c.text:
+                            search_result += c.text
+                            await self.push_frame(LLMTextFrame(c.text))
+                        elif c.function_call:
+                            logger.debug(f"Function call: {c.function_call}")
+                            args = type(c.function_call).to_dict(c.function_call).get("args", {})
+                            await self.call_function(
+                                context=context,
+                                tool_call_id=str(uuid.uuid4()),
+                                function_name=c.function_call.name,
+                                arguments=args,
+                            )
+                    # Handle grounding metadata
+                    # It seems only the last chunk that we receive may contain this information
+                    # If the response doesn't include groundingMetadata, this means the response wasn't grounded.
+                    if chunk.candidates:
+                        for candidate in chunk.candidates:
+                            # logger.debug(f"candidate received: {candidate}")
+                            # Extract grounding metadata
+                            grounding_metadata = (
+                                {
+                                    "rendered_content": getattr(
+                                        getattr(candidate, "grounding_metadata", None),
+                                        "search_entry_point",
+                                        None,
+                                    ).rendered_content
+                                    if hasattr(
+                                        getattr(candidate, "grounding_metadata", None),
+                                        "search_entry_point",
+                                    )
+                                    else None,
+                                    "origins": [
+                                        {
+                                            "site_uri": getattr(grounding_chunk.web, "uri", None),
+                                            "site_title": getattr(
+                                                grounding_chunk.web, "title", None
+                                            ),
+                                            "results": [
+                                                {
+                                                    "text": getattr(
+                                                        grounding_support.segment, "text", ""
+                                                    ),
+                                                    "confidence": getattr(
+                                                        grounding_support, "confidence_scores", None
+                                                    ),
+                                                }
+                                                for grounding_support in getattr(
+                                                    getattr(candidate, "grounding_metadata", None),
+                                                    "grounding_supports",
+                                                    [],
+                                                )
+                                                if index
+                                                in getattr(
+                                                    grounding_support, "grounding_chunk_indices", []
+                                                )
+                                            ],
+                                        }
+                                        for index, grounding_chunk in enumerate(
+                                            getattr(
+                                                getattr(candidate, "grounding_metadata", None),
+                                                "grounding_chunks",
+                                                [],
+                                            )
+                                        )
+                                    ],
+                                }
+                                if getattr(candidate, "grounding_metadata", None)
+                                else None
+                            )
+                except Exception as e:
+                    # Google LLMs seem to flag safety issues a lot!
+                    # Candidates may be empty; an IndexError raised here would end the stream.
+                    if chunk.candidates and chunk.candidates[0].finish_reason == 3:
+                        logger.debug(
+                            f"LLM refused to generate content for safety reasons - {messages}."
+                        )
+                    else:
+                        logger.exception(f"{self} error: {e}")
+
+        except DeadlineExceeded:
+            await self._call_event_handler("on_completion_timeout")
+        except Exception as e:
+            logger.exception(f"{self} exception: {e}")
+        finally:
+            if grounding_metadata is not None and isinstance(grounding_metadata, dict):
+                llm_search_frame = LLMSearchResponseFrame(
+                    search_result=search_result,
+                    origins=grounding_metadata["origins"],
+                    rendered_content=grounding_metadata["rendered_content"],
+                )
+                await self.push_frame(llm_search_frame)
+
+            await self.start_llm_usage_metrics(
+                LLMTokenUsage(
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                    total_tokens=total_tokens,
+                )
+            )
+            await self.push_frame(LLMFullResponseEndFrame())
+
+ async def process_frame(self, frame: Frame, direction: FrameDirection):
+ await super().process_frame(frame, direction)
+
+ context = None
+
+ if isinstance(frame, OpenAILLMContextFrame):
+ context = GoogleLLMContext.upgrade_to_google(frame.context)
+ elif isinstance(frame, LLMMessagesFrame):
+ context = GoogleLLMContext(frame.messages)
+ elif isinstance(frame, VisionImageRawFrame):
+ context = GoogleLLMContext()
+ context.add_image_frame_message(
+ format=frame.format, size=frame.size, image=frame.image, text=frame.text
+ )
+ elif isinstance(frame, LLMUpdateSettingsFrame):
+ await self._update_settings(frame.settings)
+ else:
+ await self.push_frame(frame, direction)
+
+ if context:
+ await self._process_context(context)
+
+ def create_context_aggregator(
+ self,
+ context: OpenAILLMContext,
+ *,
+ user_kwargs: Mapping[str, Any] = {},
+ assistant_kwargs: Mapping[str, Any] = {},
+ ) -> GoogleContextAggregatorPair:
+ """Create an instance of GoogleContextAggregatorPair from an
+ OpenAILLMContext. Constructor keyword arguments for both the user and
+ assistant aggregators can be provided.
+
+ Args:
+ context (OpenAILLMContext): The LLM context.
+ user_kwargs (Mapping[str, Any], optional): Additional keyword
+ arguments for the user context aggregator constructor. Defaults
+ to an empty mapping.
+ assistant_kwargs (Mapping[str, Any], optional): Additional keyword
+ arguments for the assistant context aggregator
+ constructor. Defaults to an empty mapping.
+
+ Returns:
+ GoogleContextAggregatorPair: A pair of context aggregators, one for
+ the user and one for the assistant, encapsulated in an
+ GoogleContextAggregatorPair.
+
+ """
+ context.set_llm_adapter(self.get_llm_adapter())
+
+ if isinstance(context, OpenAILLMContext):
+ context = GoogleLLMContext.upgrade_to_google(context)
+ user = GoogleUserContextAggregator(context, **user_kwargs)
+ assistant = GoogleAssistantContextAggregator(context, **assistant_kwargs)
+ return GoogleContextAggregatorPair(_user=user, _assistant=assistant)
diff --git a/src/pipecat/services/google/llm_openai.py b/src/pipecat/services/google/llm_openai.py
new file mode 100644
index 000000000..94b99072a
--- /dev/null
+++ b/src/pipecat/services/google/llm_openai.py
@@ -0,0 +1,136 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import json
+import os
+
+from openai import AsyncStream
+from openai.types.chat import ChatCompletionChunk
+
+# Suppress gRPC fork warnings
+os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
+
+from loguru import logger
+
+from pipecat.frames.frames import LLMTextFrame
+from pipecat.metrics.metrics import LLMTokenUsage
+from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
+from pipecat.services.openai.base_llm import OpenAIUnhandledFunctionException
+from pipecat.services.openai.llm import OpenAILLMService
+
+
+class GoogleLLMOpenAIBetaService(OpenAILLMService):
+ """This class implements inference with Google's AI LLM models using the OpenAI format.
+ Ref - https://ai.google.dev/gemini-api/docs/openai
+ """
+
+ def __init__(
+ self,
+ *,
+ api_key: str,
+ base_url: str = "https://generativelanguage.googleapis.com/v1beta/openai/",
+ model: str = "gemini-2.0-flash",
+ **kwargs,
+ ):
+ super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
+
+ async def _process_context(self, context: OpenAILLMContext):
+ functions_list = []
+ arguments_list = []
+ tool_id_list = []
+ func_idx = 0
+ function_name = ""
+ arguments = ""
+ tool_call_id = ""
+
+ await self.start_ttfb_metrics()
+
+ chunk_stream: AsyncStream[ChatCompletionChunk] = await self._stream_chat_completions(
+ context
+ )
+
+ async for chunk in chunk_stream:
+ if chunk.usage:
+ tokens = LLMTokenUsage(
+ prompt_tokens=chunk.usage.prompt_tokens,
+ completion_tokens=chunk.usage.completion_tokens,
+ total_tokens=chunk.usage.total_tokens,
+ )
+ await self.start_llm_usage_metrics(tokens)
+
+ if chunk.choices is None or len(chunk.choices) == 0:
+ continue
+
+ await self.stop_ttfb_metrics()
+
+ if not chunk.choices[0].delta:
+ continue
+
+ if chunk.choices[0].delta.tool_calls:
+ # We're streaming the LLM response to enable the fastest response times.
+ # For text, we just yield each chunk as we receive it and count on consumers
+ # to do whatever coalescing they need (eg. to pass full sentences to TTS)
+ #
+ # If the LLM is a function call, we'll do some coalescing here.
+ # If the response contains a function name, we'll yield a frame to tell consumers
+ # that they can start preparing to call the function with that name.
+ # We accumulate all the arguments for the rest of the streamed response, then when
+ # the response is done, we package up all the arguments and the function name and
+ # yield a frame containing the function name and the arguments.
+ logger.debug(f"Tool call: {chunk.choices[0].delta.tool_calls}")
+ tool_call = chunk.choices[0].delta.tool_calls[0]
+ if tool_call.index != func_idx:
+ functions_list.append(function_name)
+ arguments_list.append(arguments)
+ tool_id_list.append(tool_call_id)
+ function_name = ""
+ arguments = ""
+ tool_call_id = ""
+ func_idx += 1
+ if tool_call.function and tool_call.function.name:
+ function_name += tool_call.function.name
+ tool_call_id = tool_call.id
+ if tool_call.function and tool_call.function.arguments:
+ # Keep iterating through the response to collect all the argument fragments
+ arguments += tool_call.function.arguments
+ elif chunk.choices[0].delta.content:
+ await self.push_frame(LLMTextFrame(chunk.choices[0].delta.content))
+
+ # if we got a function name and arguments, check to see if it's a function with
+ # a registered handler. If so, run the registered callback, save the result to
+ # the context, and re-prompt to get a chat answer. If we don't have a registered
+ # handler, raise an exception.
+ if function_name and arguments:
+ # added to the list as last function name and arguments not added to the list
+ functions_list.append(function_name)
+ arguments_list.append(arguments)
+ tool_id_list.append(tool_call_id)
+
+ logger.debug(
+ f"Function list: {functions_list}, Arguments list: {arguments_list}, Tool ID list: {tool_id_list}"
+ )
+ for index, (function_name, arguments, tool_id) in enumerate(
+ zip(functions_list, arguments_list, tool_id_list), start=1
+ ):
+ if function_name == "":
+ # TODO: Remove the _process_context method once Google resolves the bug
+ # where the index is incorrectly set to None instead of returning the actual index,
+ # which currently results in an empty function name('').
+ continue
+ if self.has_function(function_name):
+ run_llm = False
+ arguments = json.loads(arguments)
+ await self.call_function(
+ context=context,
+ function_name=function_name,
+ arguments=arguments,
+ tool_call_id=tool_id,
+ run_llm=run_llm,
+ )
+ else:
+ raise OpenAIUnhandledFunctionException(
+ f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
+ )
diff --git a/src/pipecat/services/google/llm_vertex.py b/src/pipecat/services/google/llm_vertex.py
new file mode 100644
index 000000000..1a23fe7d5
--- /dev/null
+++ b/src/pipecat/services/google/llm_vertex.py
@@ -0,0 +1,107 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import json
+import os
+
+# Suppress gRPC fork warnings
+os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
+
+from typing import Optional
+
+from loguru import logger
+
+from pipecat.services.openai.llm import OpenAILLMService
+
+try:
+ from google.auth.transport.requests import Request
+ from google.oauth2 import service_account
+
+except ModuleNotFoundError as e:
+ logger.error(f"Exception: {e}")
+ logger.error(
+ "In order to use Google AI, you need to `pip install pipecat-ai[google]`. Also, set `GOOGLE_APPLICATION_CREDENTIALS` environment variable."
+ )
+ raise Exception(f"Missing module: {e}")
+
+
+class GoogleVertexLLMService(OpenAILLMService):
+    """Implements inference with Google's AI models via Vertex AI while
+    maintaining OpenAI API compatibility.
+
+    A Google service account is exchanged for an OAuth access token, and that
+    token is handed to ``OpenAILLMService`` as its API key together with the
+    Vertex AI OpenAI-compatible base URL.
+
+    Reference:
+    https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/call-vertex-using-openai-library
+
+    """
+
+    class InputParams(OpenAILLMService.InputParams):
+        """Input parameters specific to Vertex AI."""
+
+        # https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations
+        location: str = "us-east4"
+        project_id: str
+
+    def __init__(
+        self,
+        *,
+        credentials: Optional[str] = None,
+        credentials_path: Optional[str] = None,
+        model: str = "google/gemini-2.0-flash-001",
+        params: Optional[InputParams] = None,
+        **kwargs,
+    ):
+        """Initializes the VertexLLMService.
+
+        Args:
+            credentials (Optional[str]): JSON string of service account credentials.
+            credentials_path (Optional[str]): Path to the service account JSON file.
+            model (str): Model identifier. Defaults to "google/gemini-2.0-flash-001".
+            params (InputParams): Vertex AI input parameters. Required in practice,
+                because ``project_id`` has no default.
+            **kwargs: Additional arguments for OpenAILLMService.
+
+        Raises:
+            ValueError: If ``params`` is missing or carries no ``project_id``, or if
+                no usable credentials are provided.
+        """
+        # A plain ``OpenAILLMService.InputParams`` (the former default) has neither
+        # ``location`` nor ``project_id``, so building the URL could only fail with
+        # an opaque AttributeError. Fail early with an actionable message instead.
+        if params is None or not getattr(params, "project_id", None):
+            raise ValueError(
+                "GoogleVertexLLMService requires `params` with a `project_id` "
+                "(use GoogleVertexLLMService.InputParams)."
+            )
+
+        base_url = self._get_base_url(params)
+        self._api_key = self._get_api_token(credentials, credentials_path)
+
+        # Forward `params` so the caller's settings are not silently dropped.
+        super().__init__(
+            api_key=self._api_key, base_url=base_url, model=model, params=params, **kwargs
+        )
+
+    @staticmethod
+    def _get_base_url(params: InputParams) -> str:
+        """Constructs the base URL for Vertex AI API."""
+        return (
+            f"https://{params.location}-aiplatform.googleapis.com/v1/"
+            f"projects/{params.project_id}/locations/{params.location}/endpoints/openapi"
+        )
+
+    @staticmethod
+    def _get_api_token(credentials: Optional[str], credentials_path: Optional[str]) -> str:
+        """Retrieves an authentication token using Google service account credentials.
+
+        ``credentials`` takes precedence over ``credentials_path`` when both are given.
+
+        Args:
+            credentials (Optional[str]): JSON string of service account credentials.
+            credentials_path (Optional[str]): Path to the service account JSON file.
+
+        Returns:
+            str: OAuth token for API authentication.
+
+        Raises:
+            ValueError: If neither argument yields credentials.
+        """
+        scopes = ["https://www.googleapis.com/auth/cloud-platform"]
+        creds: Optional[service_account.Credentials] = None
+
+        if credentials:
+            # Parse and load credentials from JSON string
+            creds = service_account.Credentials.from_service_account_info(
+                json.loads(credentials), scopes=scopes
+            )
+        elif credentials_path:
+            # Load credentials from JSON file
+            creds = service_account.Credentials.from_service_account_file(
+                credentials_path, scopes=scopes
+            )
+
+        if not creds:
+            raise ValueError("No valid credentials provided.")
+
+        # NOTE(review): the token is fetched once, at construction time, and is
+        # never refreshed here; sessions outliving the token lifetime (1 hour)
+        # will presumably start failing authentication - confirm and add refresh.
+        creds.refresh(Request())  # Ensure token is up-to-date, lifetime is 1 hour.
+
+        return creds.token
diff --git a/src/pipecat/services/google/stt.py b/src/pipecat/services/google/stt.py
new file mode 100644
index 000000000..7af994bb2
--- /dev/null
+++ b/src/pipecat/services/google/stt.py
@@ -0,0 +1,806 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import asyncio
+import json
+import os
+import time
+
+# Suppress gRPC fork warnings
+os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
+
+from typing import AsyncGenerator, List, Optional, Union
+
+from loguru import logger
+from pydantic import BaseModel, Field, field_validator
+
+from pipecat.frames.frames import (
+ CancelFrame,
+ EndFrame,
+ ErrorFrame,
+ Frame,
+ InterimTranscriptionFrame,
+ StartFrame,
+ TranscriptionFrame,
+)
+from pipecat.services.ai_services import STTService
+from pipecat.transcriptions.language import Language
+from pipecat.utils.time import time_now_iso8601
+
+try:
+ from google.api_core.client_options import ClientOptions
+ from google.cloud import speech_v2
+ from google.cloud.speech_v2.types import cloud_speech
+ from google.oauth2 import service_account
+
+except ModuleNotFoundError as e:
+ logger.error(f"Exception: {e}")
+ logger.error(
+ "In order to use Google AI, you need to `pip install pipecat-ai[google]`. Also, set `GOOGLE_APPLICATION_CREDENTIALS` environment variable."
+ )
+ raise Exception(f"Missing module: {e}")
+
+
+# Language enum -> Google Speech-to-Text V2 language code.
+# Built once at import time instead of on every lookup.
+_GOOGLE_STT_LANGUAGE_MAP = {
+    # Afrikaans
+    Language.AF: "af-ZA",
+    Language.AF_ZA: "af-ZA",
+    # Albanian
+    Language.SQ: "sq-AL",
+    Language.SQ_AL: "sq-AL",
+    # Amharic
+    Language.AM: "am-ET",
+    Language.AM_ET: "am-ET",
+    # Arabic
+    Language.AR: "ar-EG",  # Default to Egypt
+    Language.AR_AE: "ar-AE",
+    Language.AR_BH: "ar-BH",
+    Language.AR_DZ: "ar-DZ",
+    Language.AR_EG: "ar-EG",
+    Language.AR_IQ: "ar-IQ",
+    Language.AR_JO: "ar-JO",
+    Language.AR_KW: "ar-KW",
+    Language.AR_LB: "ar-LB",
+    Language.AR_MA: "ar-MA",
+    Language.AR_OM: "ar-OM",
+    Language.AR_QA: "ar-QA",
+    Language.AR_SA: "ar-SA",
+    Language.AR_SY: "ar-SY",
+    Language.AR_TN: "ar-TN",
+    Language.AR_YE: "ar-YE",
+    # Armenian
+    Language.HY: "hy-AM",
+    Language.HY_AM: "hy-AM",
+    # Azerbaijani
+    Language.AZ: "az-AZ",
+    Language.AZ_AZ: "az-AZ",
+    # Basque
+    Language.EU: "eu-ES",
+    Language.EU_ES: "eu-ES",
+    # Bengali
+    Language.BN: "bn-IN",  # Default to India
+    Language.BN_BD: "bn-BD",
+    Language.BN_IN: "bn-IN",
+    # Bosnian
+    Language.BS: "bs-BA",
+    Language.BS_BA: "bs-BA",
+    # Bulgarian
+    Language.BG: "bg-BG",
+    Language.BG_BG: "bg-BG",
+    # Burmese
+    Language.MY: "my-MM",
+    Language.MY_MM: "my-MM",
+    # Catalan
+    Language.CA: "ca-ES",
+    Language.CA_ES: "ca-ES",
+    # Chinese
+    Language.ZH: "cmn-Hans-CN",  # Default to Simplified Chinese
+    Language.ZH_CN: "cmn-Hans-CN",
+    Language.ZH_HK: "cmn-Hans-HK",
+    Language.ZH_TW: "cmn-Hant-TW",
+    Language.YUE: "yue-Hant-HK",  # Cantonese
+    Language.YUE_CN: "yue-Hant-HK",
+    # Croatian
+    Language.HR: "hr-HR",
+    Language.HR_HR: "hr-HR",
+    # Czech
+    Language.CS: "cs-CZ",
+    Language.CS_CZ: "cs-CZ",
+    # Danish
+    Language.DA: "da-DK",
+    Language.DA_DK: "da-DK",
+    # Dutch
+    Language.NL: "nl-NL",  # Default to Netherlands
+    Language.NL_BE: "nl-BE",
+    Language.NL_NL: "nl-NL",
+    # English
+    Language.EN: "en-US",  # Default to US
+    Language.EN_AU: "en-AU",
+    Language.EN_CA: "en-CA",
+    Language.EN_GB: "en-GB",
+    Language.EN_GH: "en-GH",
+    Language.EN_HK: "en-HK",
+    Language.EN_IN: "en-IN",
+    Language.EN_IE: "en-IE",
+    Language.EN_KE: "en-KE",
+    Language.EN_NG: "en-NG",
+    Language.EN_NZ: "en-NZ",
+    Language.EN_PH: "en-PH",
+    Language.EN_SG: "en-SG",
+    Language.EN_TZ: "en-TZ",
+    Language.EN_US: "en-US",
+    Language.EN_ZA: "en-ZA",
+    # Estonian
+    Language.ET: "et-EE",
+    Language.ET_EE: "et-EE",
+    # Filipino
+    Language.FIL: "fil-PH",
+    Language.FIL_PH: "fil-PH",
+    # Finnish
+    Language.FI: "fi-FI",
+    Language.FI_FI: "fi-FI",
+    # French
+    Language.FR: "fr-FR",  # Default to France
+    Language.FR_BE: "fr-BE",
+    Language.FR_CA: "fr-CA",
+    Language.FR_CH: "fr-CH",
+    Language.FR_FR: "fr-FR",
+    # Galician
+    Language.GL: "gl-ES",
+    Language.GL_ES: "gl-ES",
+    # Georgian
+    Language.KA: "ka-GE",
+    Language.KA_GE: "ka-GE",
+    # German
+    Language.DE: "de-DE",  # Default to Germany
+    Language.DE_AT: "de-AT",
+    Language.DE_CH: "de-CH",
+    Language.DE_DE: "de-DE",
+    # Greek
+    Language.EL: "el-GR",
+    Language.EL_GR: "el-GR",
+    # Gujarati
+    Language.GU: "gu-IN",
+    Language.GU_IN: "gu-IN",
+    # Hebrew
+    Language.HE: "iw-IL",
+    Language.HE_IL: "iw-IL",
+    # Hindi
+    Language.HI: "hi-IN",
+    Language.HI_IN: "hi-IN",
+    # Hungarian
+    Language.HU: "hu-HU",
+    Language.HU_HU: "hu-HU",
+    # Icelandic
+    Language.IS: "is-IS",
+    Language.IS_IS: "is-IS",
+    # Indonesian
+    Language.ID: "id-ID",
+    Language.ID_ID: "id-ID",
+    # Italian
+    Language.IT: "it-IT",
+    Language.IT_IT: "it-IT",
+    Language.IT_CH: "it-CH",
+    # Japanese
+    Language.JA: "ja-JP",
+    Language.JA_JP: "ja-JP",
+    # Javanese
+    Language.JV: "jv-ID",
+    Language.JV_ID: "jv-ID",
+    # Kannada
+    Language.KN: "kn-IN",
+    Language.KN_IN: "kn-IN",
+    # Kazakh
+    Language.KK: "kk-KZ",
+    Language.KK_KZ: "kk-KZ",
+    # Khmer
+    Language.KM: "km-KH",
+    Language.KM_KH: "km-KH",
+    # Korean
+    Language.KO: "ko-KR",
+    Language.KO_KR: "ko-KR",
+    # Lao
+    Language.LO: "lo-LA",
+    Language.LO_LA: "lo-LA",
+    # Latvian
+    Language.LV: "lv-LV",
+    Language.LV_LV: "lv-LV",
+    # Lithuanian
+    Language.LT: "lt-LT",
+    Language.LT_LT: "lt-LT",
+    # Macedonian
+    Language.MK: "mk-MK",
+    Language.MK_MK: "mk-MK",
+    # Malay
+    Language.MS: "ms-MY",
+    Language.MS_MY: "ms-MY",
+    # Malayalam
+    Language.ML: "ml-IN",
+    Language.ML_IN: "ml-IN",
+    # Marathi
+    Language.MR: "mr-IN",
+    Language.MR_IN: "mr-IN",
+    # Mongolian
+    Language.MN: "mn-MN",
+    Language.MN_MN: "mn-MN",
+    # Nepali
+    Language.NE: "ne-NP",
+    Language.NE_NP: "ne-NP",
+    # Norwegian
+    Language.NO: "no-NO",
+    Language.NB: "no-NO",
+    Language.NB_NO: "no-NO",
+    # Persian
+    Language.FA: "fa-IR",
+    Language.FA_IR: "fa-IR",
+    # Polish
+    Language.PL: "pl-PL",
+    Language.PL_PL: "pl-PL",
+    # Portuguese
+    Language.PT: "pt-PT",  # Default to Portugal
+    Language.PT_BR: "pt-BR",
+    Language.PT_PT: "pt-PT",
+    # Punjabi
+    Language.PA: "pa-Guru-IN",
+    Language.PA_IN: "pa-Guru-IN",
+    # Romanian
+    Language.RO: "ro-RO",
+    Language.RO_RO: "ro-RO",
+    # Russian
+    Language.RU: "ru-RU",
+    Language.RU_RU: "ru-RU",
+    # Serbian
+    Language.SR: "sr-RS",
+    Language.SR_RS: "sr-RS",
+    # Sinhala
+    Language.SI: "si-LK",
+    Language.SI_LK: "si-LK",
+    # Slovak
+    Language.SK: "sk-SK",
+    Language.SK_SK: "sk-SK",
+    # Slovenian
+    Language.SL: "sl-SI",
+    Language.SL_SI: "sl-SI",
+    # Spanish
+    Language.ES: "es-ES",  # Default to Spain
+    Language.ES_AR: "es-AR",
+    Language.ES_BO: "es-BO",
+    Language.ES_CL: "es-CL",
+    Language.ES_CO: "es-CO",
+    Language.ES_CR: "es-CR",
+    Language.ES_DO: "es-DO",
+    Language.ES_EC: "es-EC",
+    Language.ES_ES: "es-ES",
+    Language.ES_GT: "es-GT",
+    Language.ES_HN: "es-HN",
+    Language.ES_MX: "es-MX",
+    Language.ES_NI: "es-NI",
+    Language.ES_PA: "es-PA",
+    Language.ES_PE: "es-PE",
+    Language.ES_PR: "es-PR",
+    Language.ES_PY: "es-PY",
+    Language.ES_SV: "es-SV",
+    Language.ES_US: "es-US",
+    Language.ES_UY: "es-UY",
+    Language.ES_VE: "es-VE",
+    # Sundanese
+    Language.SU: "su-ID",
+    Language.SU_ID: "su-ID",
+    # Swahili
+    Language.SW: "sw-TZ",  # Default to Tanzania
+    Language.SW_KE: "sw-KE",
+    Language.SW_TZ: "sw-TZ",
+    # Swedish
+    Language.SV: "sv-SE",
+    Language.SV_SE: "sv-SE",
+    # Tamil
+    Language.TA: "ta-IN",  # Default to India
+    Language.TA_IN: "ta-IN",
+    Language.TA_MY: "ta-MY",
+    Language.TA_SG: "ta-SG",
+    Language.TA_LK: "ta-LK",
+    # Telugu
+    Language.TE: "te-IN",
+    Language.TE_IN: "te-IN",
+    # Thai
+    Language.TH: "th-TH",
+    Language.TH_TH: "th-TH",
+    # Turkish
+    Language.TR: "tr-TR",
+    Language.TR_TR: "tr-TR",
+    # Ukrainian
+    Language.UK: "uk-UA",
+    Language.UK_UA: "uk-UA",
+    # Urdu
+    Language.UR: "ur-IN",  # Default to India
+    Language.UR_IN: "ur-IN",
+    Language.UR_PK: "ur-PK",
+    # Uzbek
+    Language.UZ: "uz-UZ",
+    Language.UZ_UZ: "uz-UZ",
+    # Vietnamese
+    Language.VI: "vi-VN",
+    Language.VI_VN: "vi-VN",
+    # Xhosa
+    Language.XH: "xh-ZA",
+    # Zulu
+    Language.ZU: "zu-ZA",
+    Language.ZU_ZA: "zu-ZA",
+}
+
+
+def language_to_google_stt_language(language: Language) -> Optional[str]:
+    """Maps Language enum to Google Speech-to-Text V2 language codes.
+
+    Bare language members (for example ``Language.EN``) resolve to the regional
+    default recorded in the table above.
+
+    Args:
+        language: Language enum value.
+
+    Returns:
+        Optional[str]: Google STT language code or None if not supported.
+    """
+    return _GOOGLE_STT_LANGUAGE_MAP.get(language)
+
+
+class GoogleSTTService(STTService):
+    """Google Cloud Speech-to-Text V2 service implementation.
+
+    Provides real-time speech recognition using Google Cloud's Speech-to-Text V2 API
+    with streaming support. Handles audio transcription and optional voice activity detection.
+
+    Audio passed to ``run_stt`` is queued and forwarded to Google by a background
+    streaming task; transcripts are pushed downstream as ``TranscriptionFrame`` /
+    ``InterimTranscriptionFrame``.
+
+    Attributes:
+        InputParams: Configuration parameters for the STT service.
+    """
+
+    # Google Cloud's STT service has a connection time limit of 5 minutes per stream.
+    # They've shared an "endless streaming" example that guided this implementation:
+    # https://cloud.google.com/speech-to-text/docs/transcribe-streaming-audio#endless-streaming
+
+    STREAMING_LIMIT = 240000  # 4 minutes in milliseconds
+
+    class InputParams(BaseModel):
+        """Configuration parameters for Google Speech-to-Text.
+
+        Attributes:
+            languages: Single language or list of recognition languages. First language is primary.
+            model: Speech recognition model to use.
+            use_separate_recognition_per_channel: Process each audio channel separately.
+            enable_automatic_punctuation: Add punctuation to transcripts.
+            enable_spoken_punctuation: Include spoken punctuation in transcript.
+            enable_spoken_emojis: Include spoken emojis in transcript.
+            profanity_filter: Filter profanity from transcript.
+            enable_word_time_offsets: Include timing information for each word.
+            enable_word_confidence: Include confidence scores for each word.
+            enable_interim_results: Stream partial recognition results.
+            enable_voice_activity_events: Detect voice activity in audio.
+        """
+
+        languages: Union[Language, List[Language]] = Field(default_factory=lambda: [Language.EN_US])
+        model: Optional[str] = "latest_long"
+        use_separate_recognition_per_channel: Optional[bool] = False
+        enable_automatic_punctuation: Optional[bool] = True
+        enable_spoken_punctuation: Optional[bool] = False
+        enable_spoken_emojis: Optional[bool] = False
+        profanity_filter: Optional[bool] = False
+        enable_word_time_offsets: Optional[bool] = False
+        enable_word_confidence: Optional[bool] = False
+        enable_interim_results: Optional[bool] = True
+        enable_voice_activity_events: Optional[bool] = False
+
+        @field_validator("languages", mode="before")
+        @classmethod
+        def validate_languages(cls, v) -> List[Language]:
+            """Wrap a single ``Language`` in a list; pass any other value through."""
+            if isinstance(v, Language):
+                return [v]
+            return v
+
+        @property
+        def language_list(self) -> List[Language]:
+            """Get languages as a guaranteed list."""
+            assert isinstance(self.languages, list)
+            return self.languages
+
+    def __init__(
+        self,
+        *,
+        credentials: Optional[str] = None,
+        credentials_path: Optional[str] = None,
+        location: str = "global",
+        sample_rate: Optional[int] = None,
+        params: InputParams = InputParams(),
+        **kwargs,
+    ):
+        """Initialize the Google STT service.
+
+        Args:
+            credentials: JSON string containing Google Cloud service account credentials.
+            credentials_path: Path to service account credentials JSON file.
+            location: Google Cloud location (e.g., "global", "us-central1").
+            sample_rate: Audio sample rate in Hertz.
+            params: Configuration parameters for the service.
+            **kwargs: Additional arguments passed to STTService.
+
+        Raises:
+            ValueError: If neither credentials nor credentials_path is provided.
+            ValueError: If project ID is not found in credentials.
+        """
+        super().__init__(sample_rate=sample_rate, **kwargs)
+
+        self._location = location
+        self._stream = None
+        self._config = None
+        self._request_queue = asyncio.Queue()
+        self._streaming_task = None
+
+        # Used for keep-alive logic
+        self._stream_start_time = 0
+        self._last_audio_input = []
+        self._audio_input = []
+        self._result_end_time = 0
+        self._is_final_end_time = 0
+        self._final_request_end_time = 0
+        self._bridging_offset = 0
+        self._last_transcript_was_final = False
+        self._new_stream = True
+        self._restart_counter = 0
+
+        # Configure client options based on location
+        client_options = None
+        if self._location != "global":
+            client_options = ClientOptions(api_endpoint=f"{self._location}-speech.googleapis.com")
+
+        # Extract project ID and create client
+        if credentials:
+            json_account_info = json.loads(credentials)
+            self._project_id = json_account_info.get("project_id")
+            creds = service_account.Credentials.from_service_account_info(json_account_info)
+        elif credentials_path:
+            with open(credentials_path) as f:
+                json_account_info = json.load(f)
+                self._project_id = json_account_info.get("project_id")
+            creds = service_account.Credentials.from_service_account_file(credentials_path)
+        else:
+            raise ValueError("Either credentials or credentials_path must be provided")
+
+        if not self._project_id:
+            raise ValueError("Project ID not found in credentials")
+
+        self._client = speech_v2.SpeechAsyncClient(credentials=creds, client_options=client_options)
+
+        self._settings = {
+            "language_codes": [
+                self.language_to_service_language(lang) for lang in params.language_list
+            ],
+            "model": params.model,
+            "use_separate_recognition_per_channel": params.use_separate_recognition_per_channel,
+            "enable_automatic_punctuation": params.enable_automatic_punctuation,
+            "enable_spoken_punctuation": params.enable_spoken_punctuation,
+            "enable_spoken_emojis": params.enable_spoken_emojis,
+            "profanity_filter": params.profanity_filter,
+            "enable_word_time_offsets": params.enable_word_time_offsets,
+            "enable_word_confidence": params.enable_word_confidence,
+            "enable_interim_results": params.enable_interim_results,
+            "enable_voice_activity_events": params.enable_voice_activity_events,
+        }
+
+    def language_to_service_language(self, language: Language | List[Language]) -> str | List[str]:
+        """Convert Language enum(s) to Google STT language code(s).
+
+        Languages without a Google mapping fall back to "en-US".
+
+        Args:
+            language: Single Language enum or list of Language enums.
+
+        Returns:
+            str | List[str]: Google STT language code(s).
+        """
+        if isinstance(language, list):
+            return [language_to_google_stt_language(lang) or "en-US" for lang in language]
+        return language_to_google_stt_language(language) or "en-US"
+
+    async def _reconnect_if_needed(self):
+        """Reconnect the stream if it's currently active."""
+        if self._streaming_task:
+            logger.debug("Reconnecting stream due to configuration changes")
+            await self._disconnect()
+            await self._connect()
+
+    async def set_language(self, language: Language):
+        """Update the service's recognition language.
+
+        A convenience method for setting a single language.
+
+        Args:
+            language: New language for recognition.
+        """
+        logger.debug(f"Switching STT language to: {language}")
+        await self.set_languages([language])
+
+    async def set_languages(self, languages: List[Language]):
+        """Update the service's recognition languages.
+
+        Args:
+            languages: List of languages for recognition. First language is primary.
+        """
+        logger.debug(f"Switching STT languages to: {languages}")
+        self._settings["language_codes"] = [
+            self.language_to_service_language(lang) for lang in languages
+        ]
+        # Recreate stream with new languages
+        await self._reconnect_if_needed()
+
+    async def set_model(self, model: str):
+        """Update the service's recognition model."""
+        logger.debug(f"Switching STT model to: {model}")
+        await super().set_model(model)
+        self._settings["model"] = model
+        # Recreate stream with new model
+        await self._reconnect_if_needed()
+
+    async def start(self, frame: StartFrame):
+        """Start the service and open the recognition stream."""
+        await super().start(frame)
+        await self._connect()
+
+    async def stop(self, frame: EndFrame):
+        """Stop the service and close the recognition stream."""
+        await super().stop(frame)
+        await self._disconnect()
+
+    async def cancel(self, frame: CancelFrame):
+        """Cancel the service and close the recognition stream."""
+        await super().cancel(frame)
+        await self._disconnect()
+
+    async def update_options(
+        self,
+        *,
+        languages: Optional[List[Language]] = None,
+        model: Optional[str] = None,
+        enable_automatic_punctuation: Optional[bool] = None,
+        enable_spoken_punctuation: Optional[bool] = None,
+        enable_spoken_emojis: Optional[bool] = None,
+        profanity_filter: Optional[bool] = None,
+        enable_word_time_offsets: Optional[bool] = None,
+        enable_word_confidence: Optional[bool] = None,
+        enable_interim_results: Optional[bool] = None,
+        enable_voice_activity_events: Optional[bool] = None,
+        location: Optional[str] = None,
+    ) -> None:
+        """Update service options dynamically.
+
+        Only arguments that are not None are applied.
+
+        Args:
+            languages: New list of recognition languages.
+            model: New recognition model.
+            enable_automatic_punctuation: Enable/disable automatic punctuation.
+            enable_spoken_punctuation: Enable/disable spoken punctuation.
+            enable_spoken_emojis: Enable/disable spoken emojis.
+            profanity_filter: Enable/disable profanity filter.
+            enable_word_time_offsets: Enable/disable word timing info.
+            enable_word_confidence: Enable/disable word confidence scores.
+            enable_interim_results: Enable/disable interim results.
+            enable_voice_activity_events: Enable/disable voice activity detection.
+            location: New Google Cloud location.
+
+        Note:
+            Changes that affect the streaming configuration will cause
+            the stream to be reconnected.
+        """
+        # Update settings with new values
+        if languages is not None:
+            logger.debug(f"Updating language to: {languages}")
+            self._settings["language_codes"] = [
+                self.language_to_service_language(lang) for lang in languages
+            ]
+
+        if model is not None:
+            logger.debug(f"Updating model to: {model}")
+            self._settings["model"] = model
+
+        if enable_automatic_punctuation is not None:
+            logger.debug(f"Updating automatic punctuation to: {enable_automatic_punctuation}")
+            self._settings["enable_automatic_punctuation"] = enable_automatic_punctuation
+
+        if enable_spoken_punctuation is not None:
+            logger.debug(f"Updating spoken punctuation to: {enable_spoken_punctuation}")
+            self._settings["enable_spoken_punctuation"] = enable_spoken_punctuation
+
+        if enable_spoken_emojis is not None:
+            logger.debug(f"Updating spoken emojis to: {enable_spoken_emojis}")
+            self._settings["enable_spoken_emojis"] = enable_spoken_emojis
+
+        if profanity_filter is not None:
+            logger.debug(f"Updating profanity filter to: {profanity_filter}")
+            self._settings["profanity_filter"] = profanity_filter
+
+        if enable_word_time_offsets is not None:
+            logger.debug(f"Updating word time offsets to: {enable_word_time_offsets}")
+            self._settings["enable_word_time_offsets"] = enable_word_time_offsets
+
+        if enable_word_confidence is not None:
+            logger.debug(f"Updating word confidence to: {enable_word_confidence}")
+            self._settings["enable_word_confidence"] = enable_word_confidence
+
+        if enable_interim_results is not None:
+            logger.debug(f"Updating interim results to: {enable_interim_results}")
+            self._settings["enable_interim_results"] = enable_interim_results
+
+        if enable_voice_activity_events is not None:
+            logger.debug(f"Updating voice activity events to: {enable_voice_activity_events}")
+            self._settings["enable_voice_activity_events"] = enable_voice_activity_events
+
+        if location is not None:
+            # NOTE(review): the client's API endpoint is chosen once in __init__ from
+            # the initial location; only the recognizer path picks up this new value.
+            # Switching regions at runtime presumably needs a new client - confirm.
+            logger.debug(f"Updating location to: {location}")
+            self._location = location
+
+        # Reconnect the stream for updates
+        await self._reconnect_if_needed()
+
+    async def _connect(self):
+        """Initialize streaming recognition config and stream."""
+        logger.debug("Connecting to Google Speech-to-Text")
+
+        # Set stream start time
+        self._stream_start_time = int(time.time() * 1000)
+        self._new_stream = True
+
+        self._config = cloud_speech.StreamingRecognitionConfig(
+            config=cloud_speech.RecognitionConfig(
+                explicit_decoding_config=cloud_speech.ExplicitDecodingConfig(
+                    encoding=cloud_speech.ExplicitDecodingConfig.AudioEncoding.LINEAR16,
+                    sample_rate_hertz=self.sample_rate,
+                    audio_channel_count=1,
+                ),
+                language_codes=self._settings["language_codes"],
+                model=self._settings["model"],
+                features=cloud_speech.RecognitionFeatures(
+                    enable_automatic_punctuation=self._settings["enable_automatic_punctuation"],
+                    enable_spoken_punctuation=self._settings["enable_spoken_punctuation"],
+                    enable_spoken_emojis=self._settings["enable_spoken_emojis"],
+                    profanity_filter=self._settings["profanity_filter"],
+                    enable_word_time_offsets=self._settings["enable_word_time_offsets"],
+                    enable_word_confidence=self._settings["enable_word_confidence"],
+                ),
+            ),
+            streaming_features=cloud_speech.StreamingRecognitionFeatures(
+                enable_voice_activity_events=self._settings["enable_voice_activity_events"],
+                interim_results=self._settings["enable_interim_results"],
+            ),
+        )
+
+        self._streaming_task = self.create_task(self._stream_audio())
+
+    async def _disconnect(self):
+        """Clean up streaming recognition resources."""
+        if self._streaming_task:
+            logger.debug("Disconnecting from Google Speech-to-Text")
+            # Send sentinel value to stop request generator
+            await self._request_queue.put(None)
+            await self.cancel_task(self._streaming_task)
+            self._streaming_task = None
+            # Clear any remaining items in the queue
+            while not self._request_queue.empty():
+                try:
+                    self._request_queue.get_nowait()
+                    self._request_queue.task_done()
+                except asyncio.QueueEmpty:
+                    break
+
+    async def _request_generator(self):
+        """Generates requests for the streaming recognize method.
+
+        Yields the streaming config first, then one request per queued audio
+        chunk. Stops on the ``None`` sentinel, on cancellation, or once the
+        stream has been open longer than ``STREAMING_LIMIT``.
+        """
+        recognizer_path = f"projects/{self._project_id}/locations/{self._location}/recognizers/_"
+        logger.trace(f"Using recognizer path: {recognizer_path}")
+
+        try:
+            # Send initial config
+            yield cloud_speech.StreamingRecognizeRequest(
+                recognizer=recognizer_path,
+                streaming_config=self._config,
+            )
+
+            while True:
+                try:
+                    audio_data = await self._request_queue.get()
+                    # Acknowledge only an item that was actually retrieved. Calling
+                    # task_done() from a `finally` also ran when get() was cancelled
+                    # before returning anything, which raises
+                    # "ValueError: task_done() called too many times".
+                    self._request_queue.task_done()
+
+                    if audio_data is None:  # Sentinel value to stop
+                        break
+
+                    # Check streaming limit
+                    if (int(time.time() * 1000) - self._stream_start_time) > self.STREAMING_LIMIT:
+                        logger.debug("Streaming limit reached, initiating graceful reconnection")
+                        # Instead of immediate reconnection, we'll break and let the stream close naturally
+                        self._last_audio_input = self._audio_input
+                        self._audio_input = []
+                        self._restart_counter += 1
+                        # Put the current audio chunk back in the queue
+                        await self._request_queue.put(audio_data)
+                        break
+
+                    # NOTE(review): `_audio_input` / `_last_audio_input` are written
+                    # here but never read in this class - confirm they are needed.
+                    self._audio_input.append(audio_data)
+                    yield cloud_speech.StreamingRecognizeRequest(audio=audio_data)
+
+                except asyncio.CancelledError:
+                    break
+
+        except Exception as e:
+            logger.error(f"Error in request generator: {e}")
+            raise
+
+    async def _stream_audio(self):
+        """Handle bi-directional streaming with Google STT.
+
+        Loops so that a stream ended by the streaming limit, or by an error,
+        is replaced with a fresh one; a stream that ends for any other reason
+        terminates the task.
+        """
+        try:
+            while True:
+                try:
+                    # Start bi-directional streaming
+                    streaming_recognize = await self._client.streaming_recognize(
+                        requests=self._request_generator()
+                    )
+
+                    # Process responses
+                    await self._process_responses(streaming_recognize)
+
+                    # If we're here, check if we need to reconnect
+                    if (int(time.time() * 1000) - self._stream_start_time) > self.STREAMING_LIMIT:
+                        logger.debug("Reconnecting stream after timeout")
+                        # Reset stream start time
+                        self._stream_start_time = int(time.time() * 1000)
+                        continue
+                    else:
+                        # Normal stream end
+                        break
+
+                except Exception as e:
+                    logger.warning(f"{self} Reconnecting: {e}")
+
+                    await asyncio.sleep(1)  # Brief delay before reconnecting
+                    self._stream_start_time = int(time.time() * 1000)
+                    continue
+
+        except Exception as e:
+            logger.error(f"Error in streaming task: {e}")
+            await self.push_frame(ErrorFrame(str(e)))
+
+    async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]:
+        """Process an audio chunk for STT transcription.
+
+        The chunk is only queued for the streaming task (and dropped when no
+        stream is active); transcripts are pushed from ``_process_responses``,
+        so this generator yields a single ``None``.
+        """
+        if self._streaming_task:
+            # Queue the audio data
+            await self._request_queue.put(audio)
+        yield None
+
+    async def _process_responses(self, streaming_recognize):
+        """Process streaming recognition responses.
+
+        Pushes a ``TranscriptionFrame`` for final results and an
+        ``InterimTranscriptionFrame`` otherwise, tagged with the primary
+        (first) configured language code.
+        """
+        try:
+            async for response in streaming_recognize:
+                # Check streaming limit
+                if (int(time.time() * 1000) - self._stream_start_time) > self.STREAMING_LIMIT:
+                    logger.debug("Stream timeout reached in response processing")
+                    break
+
+                if not response.results:
+                    continue
+
+                for result in response.results:
+                    if not result.alternatives:
+                        continue
+
+                    transcript = result.alternatives[0].transcript
+                    if not transcript:
+                        continue
+
+                    primary_language = self._settings["language_codes"][0]
+
+                    if result.is_final:
+                        self._last_transcript_was_final = True
+                        await self.push_frame(
+                            TranscriptionFrame(transcript, "", time_now_iso8601(), primary_language)
+                        )
+                    else:
+                        self._last_transcript_was_final = False
+                        await self.push_frame(
+                            InterimTranscriptionFrame(
+                                transcript, "", time_now_iso8601(), primary_language
+                            )
+                        )
+
+        except Exception as e:
+            logger.error(f"Error processing Google STT responses: {e}")
+
+            # Re-raise the exception to let it propagate (e.g. in the case of a timeout, propagate to _stream_audio to reconnect)
+            raise
diff --git a/src/pipecat/services/google/tts.py b/src/pipecat/services/google/tts.py
new file mode 100644
index 000000000..36bd27a51
--- /dev/null
+++ b/src/pipecat/services/google/tts.py
@@ -0,0 +1,364 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import asyncio
+import json
+import os
+
+# Suppress gRPC fork warnings
+os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false"
+
+from typing import AsyncGenerator, Literal, Optional
+
+from loguru import logger
+from pydantic import BaseModel
+
+from pipecat.frames.frames import (
+ ErrorFrame,
+ Frame,
+ TTSAudioRawFrame,
+ TTSStartedFrame,
+ TTSStoppedFrame,
+)
+from pipecat.services.ai_services import TTSService
+from pipecat.transcriptions.language import Language
+
+try:
+ from google.cloud import texttospeech_v1
+ from google.oauth2 import service_account
+
+except ModuleNotFoundError as e:
+ logger.error(f"Exception: {e}")
+ logger.error(
+ "In order to use Google AI, you need to `pip install pipecat-ai[google]`. Also, set `GOOGLE_APPLICATION_CREDENTIALS` environment variable."
+ )
+ raise Exception(f"Missing module: {e}")
+
+
+def language_to_google_tts_language(language: Language) -> Optional[str]:
+ language_map = {
+ # Afrikaans
+ Language.AF: "af-ZA",
+ Language.AF_ZA: "af-ZA",
+ # Arabic
+ Language.AR: "ar-XA",
+ # Bengali
+ Language.BN: "bn-IN",
+ Language.BN_IN: "bn-IN",
+ # Bulgarian
+ Language.BG: "bg-BG",
+ Language.BG_BG: "bg-BG",
+ # Catalan
+ Language.CA: "ca-ES",
+ Language.CA_ES: "ca-ES",
+ # Chinese (Mandarin and Cantonese)
+ Language.ZH: "cmn-CN",
+ Language.ZH_CN: "cmn-CN",
+ Language.ZH_TW: "cmn-TW",
+ Language.ZH_HK: "yue-HK",
+ # Czech
+ Language.CS: "cs-CZ",
+ Language.CS_CZ: "cs-CZ",
+ # Danish
+ Language.DA: "da-DK",
+ Language.DA_DK: "da-DK",
+ # Dutch
+ Language.NL: "nl-NL",
+ Language.NL_BE: "nl-BE",
+ Language.NL_NL: "nl-NL",
+ # English
+ Language.EN: "en-US",
+ Language.EN_US: "en-US",
+ Language.EN_AU: "en-AU",
+ Language.EN_GB: "en-GB",
+ Language.EN_IN: "en-IN",
+ # Estonian
+ Language.ET: "et-EE",
+ Language.ET_EE: "et-EE",
+ # Filipino
+ Language.FIL: "fil-PH",
+ Language.FIL_PH: "fil-PH",
+ # Finnish
+ Language.FI: "fi-FI",
+ Language.FI_FI: "fi-FI",
+ # French
+ Language.FR: "fr-FR",
+ Language.FR_CA: "fr-CA",
+ Language.FR_FR: "fr-FR",
+ # Galician
+ Language.GL: "gl-ES",
+ Language.GL_ES: "gl-ES",
+ # German
+ Language.DE: "de-DE",
+ Language.DE_DE: "de-DE",
+ # Greek
+ Language.EL: "el-GR",
+ Language.EL_GR: "el-GR",
+ # Gujarati
+ Language.GU: "gu-IN",
+ Language.GU_IN: "gu-IN",
+ # Hebrew
+ Language.HE: "he-IL",
+ Language.HE_IL: "he-IL",
+ # Hindi
+ Language.HI: "hi-IN",
+ Language.HI_IN: "hi-IN",
+ # Hungarian
+ Language.HU: "hu-HU",
+ Language.HU_HU: "hu-HU",
+ # Icelandic
+ Language.IS: "is-IS",
+ Language.IS_IS: "is-IS",
+ # Indonesian
+ Language.ID: "id-ID",
+ Language.ID_ID: "id-ID",
+ # Italian
+ Language.IT: "it-IT",
+ Language.IT_IT: "it-IT",
+ # Japanese
+ Language.JA: "ja-JP",
+ Language.JA_JP: "ja-JP",
+ # Kannada
+ Language.KN: "kn-IN",
+ Language.KN_IN: "kn-IN",
+ # Korean
+ Language.KO: "ko-KR",
+ Language.KO_KR: "ko-KR",
+ # Latvian
+ Language.LV: "lv-LV",
+ Language.LV_LV: "lv-LV",
+ # Lithuanian
+ Language.LT: "lt-LT",
+ Language.LT_LT: "lt-LT",
+ # Malay
+ Language.MS: "ms-MY",
+ Language.MS_MY: "ms-MY",
+ # Malayalam
+ Language.ML: "ml-IN",
+ Language.ML_IN: "ml-IN",
+ # Marathi
+ Language.MR: "mr-IN",
+ Language.MR_IN: "mr-IN",
+ # Norwegian
+ Language.NO: "nb-NO",
+ Language.NB: "nb-NO",
+ Language.NB_NO: "nb-NO",
+ # Polish
+ Language.PL: "pl-PL",
+ Language.PL_PL: "pl-PL",
+ # Portuguese
+ Language.PT: "pt-PT",
+ Language.PT_BR: "pt-BR",
+ Language.PT_PT: "pt-PT",
+ # Punjabi
+ Language.PA: "pa-IN",
+ Language.PA_IN: "pa-IN",
+ # Romanian
+ Language.RO: "ro-RO",
+ Language.RO_RO: "ro-RO",
+ # Russian
+ Language.RU: "ru-RU",
+ Language.RU_RU: "ru-RU",
+ # Serbian
+ Language.SR: "sr-RS",
+ Language.SR_RS: "sr-RS",
+ # Slovak
+ Language.SK: "sk-SK",
+ Language.SK_SK: "sk-SK",
+ # Spanish
+ Language.ES: "es-ES",
+ Language.ES_ES: "es-ES",
+ Language.ES_US: "es-US",
+ # Swedish
+ Language.SV: "sv-SE",
+ Language.SV_SE: "sv-SE",
+ # Tamil
+ Language.TA: "ta-IN",
+ Language.TA_IN: "ta-IN",
+ # Telugu
+ Language.TE: "te-IN",
+ Language.TE_IN: "te-IN",
+ # Thai
+ Language.TH: "th-TH",
+ Language.TH_TH: "th-TH",
+ # Turkish
+ Language.TR: "tr-TR",
+ Language.TR_TR: "tr-TR",
+ # Ukrainian
+ Language.UK: "uk-UA",
+ Language.UK_UA: "uk-UA",
+ # Vietnamese
+ Language.VI: "vi-VN",
+ Language.VI_VN: "vi-VN",
+ }
+
+ return language_map.get(language)
+
+
+class GoogleTTSService(TTSService):
+ class InputParams(BaseModel):
+ pitch: Optional[str] = None
+ rate: Optional[str] = None
+ volume: Optional[str] = None
+ emphasis: Optional[Literal["strong", "moderate", "reduced", "none"]] = None
+ language: Optional[Language] = Language.EN
+ gender: Optional[Literal["male", "female", "neutral"]] = None
+ google_style: Optional[Literal["apologetic", "calm", "empathetic", "firm", "lively"]] = None
+
+ def __init__(
+ self,
+ *,
+ credentials: Optional[str] = None,
+ credentials_path: Optional[str] = None,
+ voice_id: str = "en-US-Neural2-A",
+ sample_rate: Optional[int] = None,
+ params: InputParams = InputParams(),
+ **kwargs,
+ ):
+ super().__init__(sample_rate=sample_rate, **kwargs)
+
+ self._settings = {
+ "pitch": params.pitch,
+ "rate": params.rate,
+ "volume": params.volume,
+ "emphasis": params.emphasis,
+ "language": self.language_to_service_language(params.language)
+ if params.language
+ else "en-US",
+ "gender": params.gender,
+ "google_style": params.google_style,
+ }
+ self.set_voice(voice_id)
+ self._client: texttospeech_v1.TextToSpeechAsyncClient = self._create_client(
+ credentials, credentials_path
+ )
+
+ def _create_client(
+ self, credentials: Optional[str], credentials_path: Optional[str]
+ ) -> texttospeech_v1.TextToSpeechAsyncClient:
+ creds: Optional[service_account.Credentials] = None
+
+ # Build service-account credentials for the Cloud Text-to-Speech API from
+ # either the provided credentials JSON string or the path to a service account
+ # JSON file. If neither is given, creds stays None and is passed to the client as-is.
+ if credentials:
+ # Use provided credentials JSON string
+ json_account_info = json.loads(credentials)
+ creds = service_account.Credentials.from_service_account_info(json_account_info)
+ elif credentials_path:
+ # Use service account JSON file if provided
+ creds = service_account.Credentials.from_service_account_file(credentials_path)
+
+ return texttospeech_v1.TextToSpeechAsyncClient(credentials=creds)
+
+ def can_generate_metrics(self) -> bool:
+ return True
+
+ def language_to_service_language(self, language: Language) -> Optional[str]:
+ return language_to_google_tts_language(language)
+
+    def _construct_ssml(self, text: str) -> str:
+        """Wrap ``text`` in the SSML tags implied by the voice and prosody settings.
+
+        The voice tag is always emitted; prosody, emphasis and google:style are added
+        only when the matching setting is set. ``text`` is inserted without escaping.
+        """
+        ssml = "<speak>"
+        # Voice tag
+        voice_attrs = [f"name='{self._voice_id}'", f"language='{self._settings['language']}'"]
+        if self._settings["gender"]:
+            voice_attrs.append(f"gender='{self._settings['gender']}'")
+        ssml += f"<voice {' '.join(voice_attrs)}>"
+
+        # Prosody tag
+        prosody_attrs = []
+        if self._settings["pitch"]:
+            prosody_attrs.append(f"pitch='{self._settings['pitch']}'")
+        if self._settings["rate"]:
+            prosody_attrs.append(f"rate='{self._settings['rate']}'")
+        if self._settings["volume"]:
+            prosody_attrs.append(f"volume='{self._settings['volume']}'")
+
+        if prosody_attrs:
+            ssml += f"<prosody {' '.join(prosody_attrs)}>"
+
+        # Emphasis tag
+        if self._settings["emphasis"]:
+            ssml += f"<emphasis level='{self._settings['emphasis']}'>"
+
+        # Google style tag
+        if self._settings["google_style"]:
+            ssml += f"<google:style name='{self._settings['google_style']}'>"
+
+        ssml += text
+
+        # Close tags
+        if self._settings["google_style"]:
+            ssml += "</google:style>"
+        if self._settings["emphasis"]:
+            ssml += "</emphasis>"
+        if prosody_attrs:
+            ssml += "</prosody>"
+        ssml += "</voice></speak>"
+
+        return ssml
+
+ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
+ logger.debug(f"{self}: Generating TTS [{text}]")
+
+ try:
+ await self.start_ttfb_metrics()
+
+ # Check if the voice is a Chirp voice (including Chirp 3) or Journey voice
+ is_chirp_voice = "chirp" in self._voice_id.lower()
+ is_journey_voice = "journey" in self._voice_id.lower()
+
+ # Create synthesis input based on voice_id
+ if is_chirp_voice or is_journey_voice:
+ # Chirp and Journey voices don't support SSML, use plain text
+ synthesis_input = texttospeech_v1.SynthesisInput(text=text)
+ else:
+ ssml = self._construct_ssml(text)
+ synthesis_input = texttospeech_v1.SynthesisInput(ssml=ssml)
+
+ voice = texttospeech_v1.VoiceSelectionParams(
+ language_code=self._settings["language"], name=self._voice_id
+ )
+ audio_config = texttospeech_v1.AudioConfig(
+ audio_encoding=texttospeech_v1.AudioEncoding.LINEAR16,
+ sample_rate_hertz=self.sample_rate,
+ )
+
+ request = texttospeech_v1.SynthesizeSpeechRequest(
+ input=synthesis_input, voice=voice, audio_config=audio_config
+ )
+
+ response = await self._client.synthesize_speech(request=request)
+
+ await self.start_tts_usage_metrics(text)
+
+ yield TTSStartedFrame()
+
+ # Skip the first 44 bytes to remove the WAV header
+ audio_content = response.audio_content[44:]
+
+ # Read and yield audio data in chunks
+ chunk_size = 8192
+ for i in range(0, len(audio_content), chunk_size):
+ chunk = audio_content[i : i + chunk_size]
+ if not chunk:
+ break
+ await self.stop_ttfb_metrics()
+ frame = TTSAudioRawFrame(chunk, self.sample_rate, 1)
+ yield frame
+ await asyncio.sleep(0) # Allow other tasks to run
+
+ yield TTSStoppedFrame()
+
+ except Exception as e:
+ logger.exception(f"{self} error generating TTS: {e}")
+ error_message = f"TTS generation error: {str(e)}"
+ yield ErrorFrame(error=error_message)
diff --git a/src/pipecat/services/grok/__init__.py b/src/pipecat/services/grok/__init__.py
new file mode 100644
index 000000000..9eebcfccf
--- /dev/null
+++ b/src/pipecat/services/grok/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .llm import *
+
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "grok", "grok.llm")
diff --git a/src/pipecat/services/grok.py b/src/pipecat/services/grok/llm.py
similarity index 91%
rename from src/pipecat/services/grok.py
rename to src/pipecat/services/grok/llm.py
index faed13050..57517eb3e 100644
--- a/src/pipecat/services/grok.py
+++ b/src/pipecat/services/grok/llm.py
@@ -4,21 +4,14 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
-
-import json
from dataclasses import dataclass
-from typing import Any, Mapping, Optional
+from typing import Any, Mapping
from loguru import logger
-from pipecat.frames.frames import FunctionCallResultProperties
from pipecat.metrics.metrics import LLMTokenUsage
-from pipecat.processors.aggregators.openai_llm_context import (
- OpenAILLMContext,
- OpenAILLMContextFrame,
-)
-from pipecat.processors.frame_processor import FrameDirection
-from pipecat.services.openai import (
+from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
+from pipecat.services.openai.llm import (
OpenAIAssistantContextAggregator,
OpenAILLMService,
OpenAIUserContextAggregator,
@@ -27,13 +20,13 @@ from pipecat.services.openai import (
@dataclass
class GrokContextAggregatorPair:
- _user: "OpenAIUserContextAggregator"
- _assistant: "OpenAIAssistantContextAggregator"
+ _user: OpenAIUserContextAggregator
+ _assistant: OpenAIAssistantContextAggregator
- def user(self) -> "OpenAIUserContextAggregator":
+ def user(self) -> OpenAIUserContextAggregator:
return self._user
- def assistant(self) -> "OpenAIAssistantContextAggregator":
+ def assistant(self) -> OpenAIAssistantContextAggregator:
return self._assistant
diff --git a/src/pipecat/services/groq.py b/src/pipecat/services/groq.py
deleted file mode 100644
index bf0304df2..000000000
--- a/src/pipecat/services/groq.py
+++ /dev/null
@@ -1,177 +0,0 @@
-#
-# Copyright (c) 2024–2025, Daily
-#
-# SPDX-License-Identifier: BSD 2-Clause License
-#
-
-
-from typing import AsyncGenerator, Optional
-
-from loguru import logger
-from pydantic import BaseModel
-
-from pipecat.frames.frames import Frame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame
-from pipecat.services.ai_services import TTSService
-from pipecat.services.base_whisper import BaseWhisperSTTService, Transcription
-from pipecat.services.openai import OpenAILLMService
-from pipecat.transcriptions.language import Language
-
-try:
- from groq import AsyncGroq
-except ModuleNotFoundError as e:
- logger.error(f"Exception: {e}")
- logger.error(
- "In order to use Groq, you need to `pip install pipecat-ai[groq]`. Also, set a `GROQ_API_KEY` environment variable."
- )
- raise Exception(f"Missing module: {e}")
-
-
-class GroqLLMService(OpenAILLMService):
- """A service for interacting with Groq's API using the OpenAI-compatible interface.
-
- This service extends OpenAILLMService to connect to Groq's API endpoint while
- maintaining full compatibility with OpenAI's interface and functionality.
-
- Args:
- api_key (str): The API key for accessing Groq's API
- base_url (str, optional): The base URL for Groq API. Defaults to "https://api.groq.com/openai/v1"
- model (str, optional): The model identifier to use. Defaults to "llama-3.3-70b-versatile"
- **kwargs: Additional keyword arguments passed to OpenAILLMService
- """
-
- def __init__(
- self,
- *,
- api_key: str,
- base_url: str = "https://api.groq.com/openai/v1",
- model: str = "llama-3.3-70b-versatile",
- **kwargs,
- ):
- super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
-
- def create_client(self, api_key=None, base_url=None, **kwargs):
- """Create OpenAI-compatible client for Groq API endpoint."""
- logger.debug(f"Creating Groq client with api {base_url}")
- return super().create_client(api_key, base_url, **kwargs)
-
-
-class GroqSTTService(BaseWhisperSTTService):
- """Groq Whisper speech-to-text service.
-
- Uses Groq's Whisper API to convert audio to text. Requires a Groq API key
- set via the api_key parameter or GROQ_API_KEY environment variable.
-
- Args:
- model: Whisper model to use. Defaults to "whisper-large-v3-turbo".
- api_key: Groq API key. Defaults to None.
- base_url: API base URL. Defaults to "https://api.groq.com/openai/v1".
- language: Language of the audio input. Defaults to English.
- prompt: Optional text to guide the model's style or continue a previous segment.
- temperature: Optional sampling temperature between 0 and 1. Defaults to 0.0.
- **kwargs: Additional arguments passed to BaseWhisperSTTService.
- """
-
- def __init__(
- self,
- *,
- model: str = "whisper-large-v3-turbo",
- api_key: Optional[str] = None,
- base_url: str = "https://api.groq.com/openai/v1",
- language: Optional[Language] = Language.EN,
- prompt: Optional[str] = None,
- temperature: Optional[float] = None,
- **kwargs,
- ):
- super().__init__(
- model=model,
- api_key=api_key,
- base_url=base_url,
- language=language,
- prompt=prompt,
- temperature=temperature,
- **kwargs,
- )
-
- async def _transcribe(self, audio: bytes) -> Transcription:
- assert self._language is not None # Assigned in the BaseWhisperSTTService class
-
- # Build kwargs dict with only set parameters
- kwargs = {
- "file": ("audio.wav", audio, "audio/wav"),
- "model": self.model_name,
- "response_format": "json",
- "language": self._language,
- }
-
- if self._prompt is not None:
- kwargs["prompt"] = self._prompt
-
- if self._temperature is not None:
- kwargs["temperature"] = self._temperature
-
- return await self._client.audio.transcriptions.create(**kwargs)
-
-
-class GroqTTSService(TTSService):
- class InputParams(BaseModel):
- language: Optional[Language] = Language.EN
- speed: Optional[float] = 1.0
- seed: Optional[int] = None
-
- GROQ_SAMPLE_RATE = 48000 # Groq TTS only supports 48kHz sample rate
-
- def __init__(
- self,
- *,
- api_key: str,
- output_format: str = "wav",
- params: InputParams = InputParams(),
- model_name: str = "playai-tts",
- voice_id: str = "Celeste-PlayAI",
- sample_rate: Optional[int] = GROQ_SAMPLE_RATE,
- **kwargs,
- ):
- if sample_rate != self.GROQ_SAMPLE_RATE:
- logger.warning(f"Groq TTS only supports {self.GROQ_SAMPLE_RATE}Hz sample rate. ")
- super().__init__(
- pause_frame_processing=True,
- sample_rate=sample_rate,
- **kwargs,
- )
-
- self._api_key = api_key
- self._model_name = model_name
- self._output_format = output_format
- self._voice_id = voice_id
- self._params = params
-
- self._client = AsyncGroq(api_key=self._api_key)
-
- def can_generate_metrics(self) -> bool:
- return True
-
- async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
- logger.debug(f"{self}: Generating TTS [{text}]")
- measuring_ttfb = True
- await self.start_ttfb_metrics()
- yield TTSStartedFrame()
-
- response = await self._client.audio.speech.create(
- model=self._model_name,
- voice=self._voice_id,
- response_format=self._output_format,
- input=text,
- )
-
- async for data in response.iter_bytes():
- if measuring_ttfb:
- await self.stop_ttfb_metrics()
- measuring_ttfb = False
- # remove wav header if present
- if data.startswith(b"RIFF"):
- data = data[44:]
- if len(data) == 0:
- continue
- yield TTSAudioRawFrame(data, self.sample_rate, 1)
-
- yield TTSStoppedFrame()
diff --git a/src/pipecat/services/groq/__init__.py b/src/pipecat/services/groq/__init__.py
new file mode 100644
index 000000000..216853830
--- /dev/null
+++ b/src/pipecat/services/groq/__init__.py
@@ -0,0 +1,15 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .llm import *
+from .stt import *
+from .tts import *
+
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "groq", "groq.[llm,stt,tts]")
diff --git a/src/pipecat/services/groq/llm.py b/src/pipecat/services/groq/llm.py
new file mode 100644
index 000000000..be2ed5e72
--- /dev/null
+++ b/src/pipecat/services/groq/llm.py
@@ -0,0 +1,38 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+from loguru import logger
+
+from pipecat.services.openai.llm import OpenAILLMService
+
+
+class GroqLLMService(OpenAILLMService):
+ """A service for interacting with Groq's API using the OpenAI-compatible interface.
+
+ This service extends OpenAILLMService to connect to Groq's API endpoint while
+ maintaining full compatibility with OpenAI's interface and functionality.
+
+ Args:
+ api_key (str): The API key for accessing Groq's API
+ base_url (str, optional): The base URL for Groq API. Defaults to "https://api.groq.com/openai/v1"
+ model (str, optional): The model identifier to use. Defaults to "llama-3.3-70b-versatile"
+ **kwargs: Additional keyword arguments passed to OpenAILLMService
+ """
+
+ def __init__(
+ self,
+ *,
+ api_key: str,
+ base_url: str = "https://api.groq.com/openai/v1",
+ model: str = "llama-3.3-70b-versatile",
+ **kwargs,
+ ):
+ super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs)
+
+ def create_client(self, api_key=None, base_url=None, **kwargs):
+ """Create OpenAI-compatible client for Groq API endpoint."""
+ logger.debug(f"Creating Groq client with api {base_url}")
+ return super().create_client(api_key, base_url, **kwargs)
diff --git a/src/pipecat/services/groq/stt.py b/src/pipecat/services/groq/stt.py
new file mode 100644
index 000000000..5852bedfd
--- /dev/null
+++ b/src/pipecat/services/groq/stt.py
@@ -0,0 +1,67 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+from typing import Optional
+
+from pipecat.services.whisper.base_stt import BaseWhisperSTTService, Transcription
+from pipecat.transcriptions.language import Language
+
+
+class GroqSTTService(BaseWhisperSTTService):
+ """Groq Whisper speech-to-text service.
+
+ Uses Groq's Whisper API to convert audio to text. Requires a Groq API key
+ set via the api_key parameter or GROQ_API_KEY environment variable.
+
+ Args:
+ model: Whisper model to use. Defaults to "whisper-large-v3-turbo".
+ api_key: Groq API key. Defaults to None.
+ base_url: API base URL. Defaults to "https://api.groq.com/openai/v1".
+ language: Language of the audio input. Defaults to English.
+ prompt: Optional text to guide the model's style or continue a previous segment.
+ temperature: Optional sampling temperature between 0 and 1. Defaults to None, in which case no temperature is sent with the request.
+ **kwargs: Additional arguments passed to BaseWhisperSTTService.
+ """
+
+ def __init__(
+ self,
+ *,
+ model: str = "whisper-large-v3-turbo",
+ api_key: Optional[str] = None,
+ base_url: str = "https://api.groq.com/openai/v1",
+ language: Optional[Language] = Language.EN,
+ prompt: Optional[str] = None,
+ temperature: Optional[float] = None,
+ **kwargs,
+ ):
+ super().__init__(
+ model=model,
+ api_key=api_key,
+ base_url=base_url,
+ language=language,
+ prompt=prompt,
+ temperature=temperature,
+ **kwargs,
+ )
+
+ async def _transcribe(self, audio: bytes) -> Transcription:
+ assert self._language is not None # Assigned in the BaseWhisperSTTService class
+
+ # Build kwargs dict with only set parameters
+ kwargs = {
+ "file": ("audio.wav", audio, "audio/wav"),
+ "model": self.model_name,
+ "response_format": "json",
+ "language": self._language,
+ }
+
+ if self._prompt is not None:
+ kwargs["prompt"] = self._prompt
+
+ if self._temperature is not None:
+ kwargs["temperature"] = self._temperature
+
+ return await self._client.audio.transcriptions.create(**kwargs)
diff --git a/src/pipecat/services/groq/tts.py b/src/pipecat/services/groq/tts.py
new file mode 100644
index 000000000..69429424a
--- /dev/null
+++ b/src/pipecat/services/groq/tts.py
@@ -0,0 +1,86 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+from typing import AsyncGenerator, Optional
+
+from loguru import logger
+from pydantic import BaseModel
+
+from pipecat.frames.frames import Frame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame
+from pipecat.services.ai_services import TTSService
+from pipecat.transcriptions.language import Language
+
+try:
+ from groq import AsyncGroq
+except ModuleNotFoundError as e:
+ logger.error(f"Exception: {e}")
+ logger.error("In order to use Groq, you need to `pip install pipecat-ai[groq]`.")
+ raise Exception(f"Missing module: {e}")
+
+
+class GroqTTSService(TTSService):
+ class InputParams(BaseModel):
+ language: Optional[Language] = Language.EN
+ speed: Optional[float] = 1.0
+ seed: Optional[int] = None
+
+ GROQ_SAMPLE_RATE = 48000 # Groq TTS only supports 48kHz sample rate
+
+ def __init__(
+ self,
+ *,
+ api_key: str,
+ output_format: str = "wav",
+ params: InputParams = InputParams(),
+ model_name: str = "playai-tts",
+ voice_id: str = "Celeste-PlayAI",
+ sample_rate: Optional[int] = GROQ_SAMPLE_RATE,
+ **kwargs,
+ ):
+ if sample_rate != self.GROQ_SAMPLE_RATE:
+ logger.warning(f"Groq TTS only supports {self.GROQ_SAMPLE_RATE}Hz sample rate. ")
+ super().__init__(
+ pause_frame_processing=True,
+ sample_rate=sample_rate,
+ **kwargs,
+ )
+
+ self._api_key = api_key
+ self._model_name = model_name
+ self._output_format = output_format
+ self._voice_id = voice_id
+ self._params = params
+
+ self._client = AsyncGroq(api_key=self._api_key)
+
+ def can_generate_metrics(self) -> bool:
+ return True
+
+ async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
+ logger.debug(f"{self}: Generating TTS [{text}]")
+ measuring_ttfb = True
+ await self.start_ttfb_metrics()
+ yield TTSStartedFrame()
+
+ response = await self._client.audio.speech.create(
+ model=self._model_name,
+ voice=self._voice_id,
+ response_format=self._output_format,
+ input=text,
+ )
+
+ async for data in response.iter_bytes():
+ if measuring_ttfb:
+ await self.stop_ttfb_metrics()
+ measuring_ttfb = False
+ # remove wav header if present
+ if data.startswith(b"RIFF"):
+ data = data[44:]
+ if len(data) == 0:
+ continue
+ yield TTSAudioRawFrame(data, self.sample_rate, 1)
+
+ yield TTSStoppedFrame()
diff --git a/src/pipecat/services/lmnt/__init__.py b/src/pipecat/services/lmnt/__init__.py
new file mode 100644
index 000000000..0f55aa55c
--- /dev/null
+++ b/src/pipecat/services/lmnt/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .tts import *
+
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "lmnt", "lmnt.tts")
diff --git a/src/pipecat/services/lmnt.py b/src/pipecat/services/lmnt/tts.py
similarity index 98%
rename from src/pipecat/services/lmnt.py
rename to src/pipecat/services/lmnt/tts.py
index d3cc92603..040d526f9 100644
--- a/src/pipecat/services/lmnt.py
+++ b/src/pipecat/services/lmnt/tts.py
@@ -29,9 +29,7 @@ try:
import websockets
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
- logger.error(
- "In order to use LMNT, you need to `pip install pipecat-ai[lmnt]`. Also, set `LMNT_API_KEY` environment variable."
- )
+ logger.error("In order to use LMNT, you need to `pip install pipecat-ai[lmnt]`.")
raise Exception(f"Missing module: {e}")
diff --git a/src/pipecat/services/mem0/__init__.py b/src/pipecat/services/mem0/__init__.py
new file mode 100644
index 000000000..8459054d2
--- /dev/null
+++ b/src/pipecat/services/mem0/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .memory import *
+
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "mem0", "mem0.memory")
diff --git a/src/pipecat/services/mem0.py b/src/pipecat/services/mem0/memory.py
similarity index 100%
rename from src/pipecat/services/mem0.py
rename to src/pipecat/services/mem0/memory.py
diff --git a/src/pipecat/services/moondream/__init__.py b/src/pipecat/services/moondream/__init__.py
new file mode 100644
index 000000000..2b9994f08
--- /dev/null
+++ b/src/pipecat/services/moondream/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .vision import *
+
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "moondream", "moondream.vision")
diff --git a/src/pipecat/services/moondream.py b/src/pipecat/services/moondream/vision.py
similarity index 100%
rename from src/pipecat/services/moondream.py
rename to src/pipecat/services/moondream/vision.py
diff --git a/src/pipecat/services/neuphonic/__init__.py b/src/pipecat/services/neuphonic/__init__.py
new file mode 100644
index 000000000..5b2d9c34a
--- /dev/null
+++ b/src/pipecat/services/neuphonic/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .tts import *
+
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "neuphonic", "neuphonic.tts")
diff --git a/src/pipecat/services/neuphonic.py b/src/pipecat/services/neuphonic/tts.py
similarity index 98%
rename from src/pipecat/services/neuphonic.py
rename to src/pipecat/services/neuphonic/tts.py
index 407e54a83..7ac005afc 100644
--- a/src/pipecat/services/neuphonic.py
+++ b/src/pipecat/services/neuphonic/tts.py
@@ -30,15 +30,12 @@ from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import InterruptibleTTSService, TTSService
from pipecat.transcriptions.language import Language
-# See .env.example for Neuphonic configuration needed
try:
import websockets
from pyneuphonic import Neuphonic, TTSConfig
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
- logger.error(
- "In order to use Neuphonic, you need to `pip install pipecat-ai[neuphonic]`. Also, set `NEUPHONIC_API_KEY` environment variable."
- )
+ logger.error("In order to use Neuphonic, you need to `pip install pipecat-ai[neuphonic]`.")
raise Exception(f"Missing module: {e}")
diff --git a/src/pipecat/services/nim/__init__.py b/src/pipecat/services/nim/__init__.py
new file mode 100644
index 000000000..05788d5ce
--- /dev/null
+++ b/src/pipecat/services/nim/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .llm import *
+
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "nim", "nim.llm")
diff --git a/src/pipecat/services/nim.py b/src/pipecat/services/nim/llm.py
similarity index 98%
rename from src/pipecat/services/nim.py
rename to src/pipecat/services/nim/llm.py
index 7146e01b3..d57fa8d4c 100644
--- a/src/pipecat/services/nim.py
+++ b/src/pipecat/services/nim/llm.py
@@ -4,10 +4,9 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
-
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
-from pipecat.services.openai import OpenAILLMService
+from pipecat.services.openai.llm import OpenAILLMService
class NimLLMService(OpenAILLMService):
diff --git a/src/pipecat/services/ollama/__init__.py b/src/pipecat/services/ollama/__init__.py
new file mode 100644
index 000000000..9103ee851
--- /dev/null
+++ b/src/pipecat/services/ollama/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .llm import *
+
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "ollama", "ollama.llm")
diff --git a/src/pipecat/services/ollama.py b/src/pipecat/services/ollama/llm.py
similarity index 71%
rename from src/pipecat/services/ollama.py
rename to src/pipecat/services/ollama/llm.py
index 7d23fe128..bd1ac0d0d 100644
--- a/src/pipecat/services/ollama.py
+++ b/src/pipecat/services/ollama/llm.py
@@ -4,9 +4,9 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
-from pipecat.services.openai import BaseOpenAILLMService
+from pipecat.services.openai.llm import OpenAILLMService
-class OLLamaLLMService(BaseOpenAILLMService):
+class OLLamaLLMService(OpenAILLMService):
def __init__(self, *, model: str = "llama2", base_url: str = "http://localhost:11434/v1"):
super().__init__(model=model, base_url=base_url, api_key="ollama")
diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py
deleted file mode 100644
index 3f85d917c..000000000
--- a/src/pipecat/services/openai.py
+++ /dev/null
@@ -1,644 +0,0 @@
-#
-# Copyright (c) 2024–2025, Daily
-#
-# SPDX-License-Identifier: BSD 2-Clause License
-#
-
-import base64
-import io
-import json
-from dataclasses import dataclass
-from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional
-
-import aiohttp
-import httpx
-from loguru import logger
-from openai import (
- NOT_GIVEN,
- AsyncOpenAI,
- AsyncStream,
- BadRequestError,
- DefaultAsyncHttpxClient,
-)
-from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
-from PIL import Image
-from pydantic import BaseModel, Field
-
-from pipecat.frames.frames import (
- ErrorFrame,
- Frame,
- FunctionCallCancelFrame,
- FunctionCallInProgressFrame,
- FunctionCallResultFrame,
- LLMFullResponseEndFrame,
- LLMFullResponseStartFrame,
- LLMMessagesFrame,
- LLMTextFrame,
- LLMUpdateSettingsFrame,
- StartFrame,
- TTSAudioRawFrame,
- TTSStartedFrame,
- TTSStoppedFrame,
- URLImageRawFrame,
- UserImageRawFrame,
- VisionImageRawFrame,
-)
-from pipecat.metrics.metrics import LLMTokenUsage
-from pipecat.processors.aggregators.llm_response import (
- LLMAssistantContextAggregator,
- LLMUserContextAggregator,
-)
-from pipecat.processors.aggregators.openai_llm_context import (
- OpenAILLMContext,
- OpenAILLMContextFrame,
-)
-from pipecat.processors.frame_processor import FrameDirection
-from pipecat.services.ai_services import (
- ImageGenService,
- LLMService,
- TTSService,
-)
-from pipecat.services.base_whisper import BaseWhisperSTTService, Transcription
-from pipecat.transcriptions.language import Language
-
-ValidVoice = Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
-
-VALID_VOICES: Dict[str, ValidVoice] = {
- "alloy": "alloy",
- "echo": "echo",
- "fable": "fable",
- "onyx": "onyx",
- "nova": "nova",
- "shimmer": "shimmer",
-}
-
-
-class OpenAIUnhandledFunctionException(Exception):
- pass
-
-
-class BaseOpenAILLMService(LLMService):
- """This is the base for all services that use the AsyncOpenAI client.
-
- This service consumes OpenAILLMContextFrame frames, which contain a reference
- to an OpenAILLMContext frame. The OpenAILLMContext object defines the context
- sent to the LLM for a completion. This includes user, assistant and system messages
- as well as tool choices and the tool, which is used if requesting function
- calls from the LLM.
- """
-
- class InputParams(BaseModel):
- frequency_penalty: Optional[float] = Field(
- default_factory=lambda: NOT_GIVEN, ge=-2.0, le=2.0
- )
- presence_penalty: Optional[float] = Field(
- default_factory=lambda: NOT_GIVEN, ge=-2.0, le=2.0
- )
- seed: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=0)
- temperature: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=2.0)
- # Note: top_k is currently not supported by the OpenAI client library,
- # so top_k is ignored right now.
- top_k: Optional[int] = Field(default=None, ge=0)
- top_p: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)
- max_tokens: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=1)
- max_completion_tokens: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=1)
- extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
-
- def __init__(
- self,
- *,
- model: str,
- api_key=None,
- base_url=None,
- organization=None,
- project=None,
- default_headers: Mapping[str, str] | None = None,
- params: InputParams = InputParams(),
- **kwargs,
- ):
- super().__init__(**kwargs)
- self._settings = {
- "frequency_penalty": params.frequency_penalty,
- "presence_penalty": params.presence_penalty,
- "seed": params.seed,
- "temperature": params.temperature,
- "top_p": params.top_p,
- "max_tokens": params.max_tokens,
- "max_completion_tokens": params.max_completion_tokens,
- "extra": params.extra if isinstance(params.extra, dict) else {},
- }
- self.set_model_name(model)
- self._client = self.create_client(
- api_key=api_key,
- base_url=base_url,
- organization=organization,
- project=project,
- default_headers=default_headers,
- **kwargs,
- )
-
- def create_client(
- self,
- api_key=None,
- base_url=None,
- organization=None,
- project=None,
- default_headers=None,
- **kwargs,
- ):
- return AsyncOpenAI(
- api_key=api_key,
- base_url=base_url,
- organization=organization,
- project=project,
- http_client=DefaultAsyncHttpxClient(
- limits=httpx.Limits(
- max_keepalive_connections=100, max_connections=1000, keepalive_expiry=None
- )
- ),
- default_headers=default_headers,
- )
-
- def can_generate_metrics(self) -> bool:
- return True
-
- async def get_chat_completions(
- self, context: OpenAILLMContext, messages: List[ChatCompletionMessageParam]
- ) -> AsyncStream[ChatCompletionChunk]:
- params = {
- "model": self.model_name,
- "stream": True,
- "messages": messages,
- "tools": context.tools,
- "tool_choice": context.tool_choice,
- "stream_options": {"include_usage": True},
- "frequency_penalty": self._settings["frequency_penalty"],
- "presence_penalty": self._settings["presence_penalty"],
- "seed": self._settings["seed"],
- "temperature": self._settings["temperature"],
- "top_p": self._settings["top_p"],
- "max_tokens": self._settings["max_tokens"],
- "max_completion_tokens": self._settings["max_completion_tokens"],
- }
-
- params.update(self._settings["extra"])
-
- chunks = await self._client.chat.completions.create(**params)
- return chunks
-
- async def _stream_chat_completions(
- self, context: OpenAILLMContext
- ) -> AsyncStream[ChatCompletionChunk]:
- logger.debug(f"{self}: Generating chat [{context.get_messages_for_logging()}]")
-
- messages: List[ChatCompletionMessageParam] = context.get_messages()
-
- # base64 encode any images
- for message in messages:
- if message.get("mime_type") == "image/jpeg":
- encoded_image = base64.b64encode(message["data"].getvalue()).decode("utf-8")
- text = message["content"]
- message["content"] = [
- {"type": "text", "text": text},
- {
- "type": "image_url",
- "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
- },
- ]
- del message["data"]
- del message["mime_type"]
-
- chunks = await self.get_chat_completions(context, messages)
-
- return chunks
-
- async def _process_context(self, context: OpenAILLMContext):
- functions_list = []
- arguments_list = []
- tool_id_list = []
- func_idx = 0
- function_name = ""
- arguments = ""
- tool_call_id = ""
-
- await self.start_ttfb_metrics()
-
- chunk_stream: AsyncStream[ChatCompletionChunk] = await self._stream_chat_completions(
- context
- )
-
- async for chunk in chunk_stream:
- if chunk.usage:
- tokens = LLMTokenUsage(
- prompt_tokens=chunk.usage.prompt_tokens,
- completion_tokens=chunk.usage.completion_tokens,
- total_tokens=chunk.usage.total_tokens,
- )
- await self.start_llm_usage_metrics(tokens)
-
- if chunk.choices is None or len(chunk.choices) == 0:
- continue
-
- await self.stop_ttfb_metrics()
-
- if not chunk.choices[0].delta:
- continue
-
- if chunk.choices[0].delta.tool_calls:
- # We're streaming the LLM response to enable the fastest response times.
- # For text, we just yield each chunk as we receive it and count on consumers
- # to do whatever coalescing they need (eg. to pass full sentences to TTS)
- #
- # If the LLM is a function call, we'll do some coalescing here.
- # If the response contains a function name, we'll yield a frame to tell consumers
- # that they can start preparing to call the function with that name.
- # We accumulate all the arguments for the rest of the streamed response, then when
- # the response is done, we package up all the arguments and the function name and
- # yield a frame containing the function name and the arguments.
-
- tool_call = chunk.choices[0].delta.tool_calls[0]
- if tool_call.index != func_idx:
- functions_list.append(function_name)
- arguments_list.append(arguments)
- tool_id_list.append(tool_call_id)
- function_name = ""
- arguments = ""
- tool_call_id = ""
- func_idx += 1
- if tool_call.function and tool_call.function.name:
- function_name += tool_call.function.name
- tool_call_id = tool_call.id
- if tool_call.function and tool_call.function.arguments:
- # Keep iterating through the response to collect all the argument fragments
- arguments += tool_call.function.arguments
- elif chunk.choices[0].delta.content:
- await self.push_frame(LLMTextFrame(chunk.choices[0].delta.content))
-
- # if we got a function name and arguments, check to see if it's a function with
- # a registered handler. If so, run the registered callback, save the result to
- # the context, and re-prompt to get a chat answer. If we don't have a registered
- # handler, raise an exception.
- if function_name and arguments:
- # added to the list as last function name and arguments not added to the list
- functions_list.append(function_name)
- arguments_list.append(arguments)
- tool_id_list.append(tool_call_id)
-
- for index, (function_name, arguments, tool_id) in enumerate(
- zip(functions_list, arguments_list, tool_id_list), start=1
- ):
- if self.has_function(function_name):
- run_llm = False
- arguments = json.loads(arguments)
- await self.call_function(
- context=context,
- function_name=function_name,
- arguments=arguments,
- tool_call_id=tool_id,
- run_llm=run_llm,
- )
- else:
- raise OpenAIUnhandledFunctionException(
- f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
- )
-
- async def process_frame(self, frame: Frame, direction: FrameDirection):
- await super().process_frame(frame, direction)
-
- context = None
- if isinstance(frame, OpenAILLMContextFrame):
- context: OpenAILLMContext = frame.context
- elif isinstance(frame, LLMMessagesFrame):
- context = OpenAILLMContext.from_messages(frame.messages)
- elif isinstance(frame, VisionImageRawFrame):
- context = OpenAILLMContext()
- context.add_image_frame_message(
- format=frame.format, size=frame.size, image=frame.image, text=frame.text
- )
- elif isinstance(frame, LLMUpdateSettingsFrame):
- await self._update_settings(frame.settings)
- else:
- await self.push_frame(frame, direction)
-
- if context:
- try:
- await self.push_frame(LLMFullResponseStartFrame())
- await self.start_processing_metrics()
- await self._process_context(context)
- except httpx.TimeoutException:
- await self._call_event_handler("on_completion_timeout")
- finally:
- await self.stop_processing_metrics()
- await self.push_frame(LLMFullResponseEndFrame())
-
-
-@dataclass
-class OpenAIContextAggregatorPair:
- _user: "OpenAIUserContextAggregator"
- _assistant: "OpenAIAssistantContextAggregator"
-
- def user(self) -> "OpenAIUserContextAggregator":
- return self._user
-
- def assistant(self) -> "OpenAIAssistantContextAggregator":
- return self._assistant
-
-
-class OpenAILLMService(BaseOpenAILLMService):
- def __init__(
- self,
- *,
- model: str = "gpt-4o",
- params: BaseOpenAILLMService.InputParams = BaseOpenAILLMService.InputParams(),
- **kwargs,
- ):
- super().__init__(model=model, params=params, **kwargs)
-
- def create_context_aggregator(
- self,
- context: OpenAILLMContext,
- *,
- user_kwargs: Mapping[str, Any] = {},
- assistant_kwargs: Mapping[str, Any] = {},
- ) -> OpenAIContextAggregatorPair:
- """Create an instance of OpenAIContextAggregatorPair from an
- OpenAILLMContext. Constructor keyword arguments for both the user and
- assistant aggregators can be provided.
-
- Args:
- context (OpenAILLMContext): The LLM context.
- user_kwargs (Mapping[str, Any], optional): Additional keyword
- arguments for the user context aggregator constructor. Defaults
- to an empty mapping.
- assistant_kwargs (Mapping[str, Any], optional): Additional keyword
- arguments for the assistant context aggregator
- constructor. Defaults to an empty mapping.
-
- Returns:
- OpenAIContextAggregatorPair: A pair of context aggregators, one for
- the user and one for the assistant, encapsulated in an
- OpenAIContextAggregatorPair.
-
- """
- context.set_llm_adapter(self.get_llm_adapter())
- user = OpenAIUserContextAggregator(context, **user_kwargs)
- assistant = OpenAIAssistantContextAggregator(context, **assistant_kwargs)
- return OpenAIContextAggregatorPair(_user=user, _assistant=assistant)
-
-
-class OpenAIImageGenService(ImageGenService):
- def __init__(
- self,
- *,
- api_key: str,
- base_url: Optional[str] = None,
- aiohttp_session: aiohttp.ClientSession,
- image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
- model: str = "dall-e-3",
- ):
- super().__init__()
- self.set_model_name(model)
- self._image_size = image_size
- self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
- self._aiohttp_session = aiohttp_session
-
- async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
- logger.debug(f"Generating image from prompt: {prompt}")
-
- image = await self._client.images.generate(
- prompt=prompt, model=self.model_name, n=1, size=self._image_size
- )
-
- image_url = image.data[0].url
-
- if not image_url:
- logger.error(f"{self} No image provided in response: {image}")
- yield ErrorFrame("Image generation failed")
- return
-
- # Load the image from the url
- async with self._aiohttp_session.get(image_url) as response:
- image_stream = io.BytesIO(await response.content.read())
- image = Image.open(image_stream)
- frame = URLImageRawFrame(image_url, image.tobytes(), image.size, image.format)
- yield frame
-
-
-class OpenAISTTService(BaseWhisperSTTService):
- """OpenAI Speech-to-Text service that generates text from audio.
-
- Uses OpenAI's transcription API to convert audio to text. Requires an OpenAI API key
- set via the api_key parameter or OPENAI_API_KEY environment variable.
-
- Args:
- model: Model to use — either gpt-4o or Whisper. Defaults to "gpt-4o-transcribe".
- api_key: OpenAI API key. Defaults to None.
- base_url: API base URL. Defaults to None.
- language: Language of the audio input. Defaults to English.
- prompt: Optional text to guide the model's style or continue a previous segment.
- temperature: Optional sampling temperature between 0 and 1. Defaults to 0.0.
- **kwargs: Additional arguments passed to BaseWhisperSTTService.
- """
-
- def __init__(
- self,
- *,
- model: str = "gpt-4o-transcribe",
- api_key: Optional[str] = None,
- base_url: Optional[str] = None,
- language: Optional[Language] = Language.EN,
- prompt: Optional[str] = None,
- temperature: Optional[float] = None,
- **kwargs,
- ):
- super().__init__(
- model=model,
- api_key=api_key,
- base_url=base_url,
- language=language,
- prompt=prompt,
- temperature=temperature,
- **kwargs,
- )
-
- async def _transcribe(self, audio: bytes) -> Transcription:
- assert self._language is not None # Assigned in the BaseWhisperSTTService class
-
- # Build kwargs dict with only set parameters
- kwargs = {
- "file": ("audio.wav", audio, "audio/wav"),
- "model": self.model_name,
- "language": self._language,
- }
-
- if self._prompt is not None:
- kwargs["prompt"] = self._prompt
-
- if self._temperature is not None:
- kwargs["temperature"] = self._temperature
-
- return await self._client.audio.transcriptions.create(**kwargs)
-
-
-class OpenAITTSService(TTSService):
- """OpenAI Text-to-Speech service that generates audio from text.
-
- This service uses the OpenAI TTS API to generate PCM-encoded audio at 24kHz.
-
- Args:
- api_key: OpenAI API key. Defaults to None.
- voice: Voice ID to use. Defaults to "alloy".
- model: TTS model to use. Defaults to "gpt-4o-mini-tts".
- sample_rate: Output audio sample rate in Hz. Defaults to None.
- **kwargs: Additional keyword arguments passed to TTSService.
-
- The service returns PCM-encoded audio at the specified sample rate.
-
- """
-
- OPENAI_SAMPLE_RATE = 24000 # OpenAI TTS always outputs at 24kHz
-
- def __init__(
- self,
- *,
- api_key: Optional[str] = None,
- base_url: Optional[str] = None,
- voice: str = "alloy",
- model: str = "gpt-4o-mini-tts",
- sample_rate: Optional[int] = None,
- instructions: Optional[str] = None,
- **kwargs,
- ):
- if sample_rate and sample_rate != self.OPENAI_SAMPLE_RATE:
- logger.warning(
- f"OpenAI TTS only supports {self.OPENAI_SAMPLE_RATE}Hz sample rate. "
- f"Current rate of {self.sample_rate}Hz may cause issues."
- )
- super().__init__(sample_rate=sample_rate, **kwargs)
-
- self.set_model_name(model)
- self.set_voice(voice)
- self._instructions = instructions
- self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
-
- def can_generate_metrics(self) -> bool:
- return True
-
- async def set_model(self, model: str):
- logger.info(f"Switching TTS model to: [{model}]")
- self.set_model_name(model)
-
- async def start(self, frame: StartFrame):
- await super().start(frame)
- if self.sample_rate != self.OPENAI_SAMPLE_RATE:
- logger.warning(
- f"OpenAI TTS requires {self.OPENAI_SAMPLE_RATE}Hz sample rate. "
- f"Current rate of {self.sample_rate}Hz may cause issues."
- )
-
- async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
- logger.debug(f"{self}: Generating TTS [{text}]")
- try:
- await self.start_ttfb_metrics()
-
- # Setup extra body parameters
- extra_body = {}
- if self._instructions:
- extra_body["instructions"] = self._instructions
-
- async with self._client.audio.speech.with_streaming_response.create(
- input=text or " ", # Text must contain at least one character
- model=self.model_name,
- voice=VALID_VOICES[self._voice_id],
- response_format="pcm",
- extra_body=extra_body,
- ) as r:
- if r.status_code != 200:
- error = await r.text()
- logger.error(
- f"{self} error getting audio (status: {r.status_code}, error: {error})"
- )
- yield ErrorFrame(
- f"Error getting audio (status: {r.status_code}, error: {error})"
- )
- return
-
- await self.start_tts_usage_metrics(text)
-
- CHUNK_SIZE = 1024
-
- yield TTSStartedFrame()
- async for chunk in r.iter_bytes(CHUNK_SIZE):
- if len(chunk) > 0:
- await self.stop_ttfb_metrics()
- frame = TTSAudioRawFrame(chunk, self.sample_rate, 1)
- yield frame
- yield TTSStoppedFrame()
- except BadRequestError as e:
- logger.exception(f"{self} error generating TTS: {e}")
-
-
-class OpenAIUserContextAggregator(LLMUserContextAggregator):
- pass
-
-
-class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
- async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
- self._context.add_message(
- {
- "role": "assistant",
- "tool_calls": [
- {
- "id": frame.tool_call_id,
- "function": {
- "name": frame.function_name,
- "arguments": json.dumps(frame.arguments),
- },
- "type": "function",
- }
- ],
- }
- )
- self._context.add_message(
- {
- "role": "tool",
- "content": "IN_PROGRESS",
- "tool_call_id": frame.tool_call_id,
- }
- )
-
- async def handle_function_call_result(self, frame: FunctionCallResultFrame):
- if frame.result:
- result = json.dumps(frame.result)
- await self._update_function_call_result(frame.function_name, frame.tool_call_id, result)
- else:
- await self._update_function_call_result(
- frame.function_name, frame.tool_call_id, "COMPLETED"
- )
-
- async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
- await self._update_function_call_result(
- frame.function_name, frame.tool_call_id, "CANCELLED"
- )
-
- async def _update_function_call_result(
- self, function_name: str, tool_call_id: str, result: Any
- ):
- for message in self._context.messages:
- if (
- message["role"] == "tool"
- and message["tool_call_id"]
- and message["tool_call_id"] == tool_call_id
- ):
- message["content"] = result
-
- async def handle_user_image_frame(self, frame: UserImageRawFrame):
- await self._update_function_call_result(
- frame.request.function_name, frame.request.tool_call_id, "COMPLETED"
- )
- self._context.add_image_frame_message(
- format=frame.format,
- size=frame.size,
- image=frame.image,
- text=frame.request.context,
- )
diff --git a/src/pipecat/services/openai/__init__.py b/src/pipecat/services/openai/__init__.py
new file mode 100644
index 000000000..4decac126
--- /dev/null
+++ b/src/pipecat/services/openai/__init__.py
@@ -0,0 +1,16 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .image import *
+from .llm import *
+from .stt import *
+from .tts import *
+
+# Replace this package's entry in `sys.modules` with a DeprecatedModuleProxy
+# wrapping the names star-imported from the image/llm/stt/tts submodules above.
+# Presumably the proxy emits a DeprecationWarning pointing at
+# `pipecat.services.openai.[image,llm,stt,tts]` when a name is accessed through
+# the old `pipecat.services.openai` import path.
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "openai", "openai.[image,llm,stt,tts]")
diff --git a/src/pipecat/services/openai/base_llm.py b/src/pipecat/services/openai/base_llm.py
new file mode 100644
index 000000000..5343c1eb7
--- /dev/null
+++ b/src/pipecat/services/openai/base_llm.py
@@ -0,0 +1,296 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import base64
+import json
+from typing import Any, Dict, List, Mapping, Optional
+
+import httpx
+from loguru import logger
+from openai import (
+ NOT_GIVEN,
+ AsyncOpenAI,
+ AsyncStream,
+ DefaultAsyncHttpxClient,
+)
+from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
+from pydantic import BaseModel, Field
+
+from pipecat.frames.frames import (
+ Frame,
+ LLMFullResponseEndFrame,
+ LLMFullResponseStartFrame,
+ LLMMessagesFrame,
+ LLMTextFrame,
+ LLMUpdateSettingsFrame,
+ VisionImageRawFrame,
+)
+from pipecat.metrics.metrics import LLMTokenUsage
+from pipecat.processors.aggregators.openai_llm_context import (
+ OpenAILLMContext,
+ OpenAILLMContextFrame,
+)
+from pipecat.processors.frame_processor import FrameDirection
+from pipecat.services.ai_services import LLMService
+
+
+class OpenAIUnhandledFunctionException(Exception):
+    """Raised when the LLM requests a function that has no registered callback."""
+
+
+class BaseOpenAILLMService(LLMService):
+    """Base class for all services that use the AsyncOpenAI client.
+
+    This service consumes OpenAILLMContextFrame frames, which contain a reference
+    to an OpenAILLMContext object. The OpenAILLMContext object defines the context
+    sent to the LLM for a completion. This includes user, assistant and system
+    messages as well as the tools and tool choice, which are used when requesting
+    function calls from the LLM.
+    """
+
+    class InputParams(BaseModel):
+        """Optional generation settings forwarded to the chat completions call.
+
+        Every field except ``top_k`` and ``extra`` defaults to ``NOT_GIVEN``.
+        """
+
+        frequency_penalty: Optional[float] = Field(
+            default_factory=lambda: NOT_GIVEN, ge=-2.0, le=2.0
+        )
+        presence_penalty: Optional[float] = Field(
+            default_factory=lambda: NOT_GIVEN, ge=-2.0, le=2.0
+        )
+        seed: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=0)
+        temperature: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=2.0)
+        # Note: top_k is currently not supported by the OpenAI client library,
+        # so top_k is ignored right now.
+        top_k: Optional[int] = Field(default=None, ge=0)
+        top_p: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0)
+        max_tokens: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=1)
+        max_completion_tokens: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=1)
+        extra: Optional[Dict[str, Any]] = Field(default_factory=dict)
+
+    def __init__(
+        self,
+        *,
+        model: str,
+        api_key=None,
+        base_url=None,
+        organization=None,
+        project=None,
+        default_headers: Mapping[str, str] | None = None,
+        params: InputParams = InputParams(),
+        **kwargs,
+    ):
+        """Store the generation settings and build the API client.
+
+        Args:
+            model: Model name, registered through ``set_model_name``.
+            api_key: Passed to ``create_client``.
+            base_url: Passed to ``create_client``.
+            organization: Passed to ``create_client``.
+            project: Passed to ``create_client``.
+            default_headers: Passed to ``create_client``.
+            params: Generation settings copied into ``self._settings``.
+            **kwargs: Forwarded to both the parent constructor and
+                ``create_client``.
+        """
+        super().__init__(**kwargs)
+        self._settings = {
+            "frequency_penalty": params.frequency_penalty,
+            "presence_penalty": params.presence_penalty,
+            "seed": params.seed,
+            "temperature": params.temperature,
+            "top_p": params.top_p,
+            "max_tokens": params.max_tokens,
+            "max_completion_tokens": params.max_completion_tokens,
+            # Copy the dict: the default `params` instance is created once at
+            # definition time, so storing its `extra` dict directly would share
+            # one mutable dict between every service built with the defaults.
+            "extra": dict(params.extra) if isinstance(params.extra, dict) else {},
+        }
+        self.set_model_name(model)
+        self._client = self.create_client(
+            api_key=api_key,
+            base_url=base_url,
+            organization=organization,
+            project=project,
+            default_headers=default_headers,
+            **kwargs,
+        )
+
+    def create_client(
+        self,
+        api_key=None,
+        base_url=None,
+        organization=None,
+        project=None,
+        default_headers=None,
+        **kwargs,
+    ):
+        """Build the ``AsyncOpenAI`` client used for completions.
+
+        The client uses an httpx client configured with explicit connection
+        limits. ``**kwargs`` is accepted but not used by this implementation.
+
+        Returns:
+            The configured ``AsyncOpenAI`` instance.
+        """
+        return AsyncOpenAI(
+            api_key=api_key,
+            base_url=base_url,
+            organization=organization,
+            project=project,
+            http_client=DefaultAsyncHttpxClient(
+                limits=httpx.Limits(
+                    max_keepalive_connections=100, max_connections=1000, keepalive_expiry=None
+                )
+            ),
+            default_headers=default_headers,
+        )
+
+    def can_generate_metrics(self) -> bool:
+        """Report that this service generates metrics."""
+        return True
+
+    async def get_chat_completions(
+        self, context: OpenAILLMContext, messages: List[ChatCompletionMessageParam]
+    ) -> AsyncStream[ChatCompletionChunk]:
+        """Request a streaming chat completion for ``messages``.
+
+        Args:
+            context: Supplies ``tools`` and ``tool_choice`` for the request.
+            messages: The messages to send.
+
+        Returns:
+            The chunk stream returned by the client.
+        """
+        params = {
+            "model": self.model_name,
+            "stream": True,
+            "messages": messages,
+            "tools": context.tools,
+            "tool_choice": context.tool_choice,
+            "stream_options": {"include_usage": True},
+            "frequency_penalty": self._settings["frequency_penalty"],
+            "presence_penalty": self._settings["presence_penalty"],
+            "seed": self._settings["seed"],
+            "temperature": self._settings["temperature"],
+            "top_p": self._settings["top_p"],
+            "max_tokens": self._settings["max_tokens"],
+            "max_completion_tokens": self._settings["max_completion_tokens"],
+        }
+
+        # Applied last, so entries in "extra" override the values above.
+        params.update(self._settings["extra"])
+
+        chunks = await self._client.chat.completions.create(**params)
+        return chunks
+
+    async def _stream_chat_completions(
+        self, context: OpenAILLMContext
+    ) -> AsyncStream[ChatCompletionChunk]:
+        """Prepare the context's messages and start a streaming completion.
+
+        Messages tagged with ``mime_type == "image/jpeg"`` are rewritten in
+        place into a text part plus a base64 ``image_url`` part, and their
+        ``data`` / ``mime_type`` keys are removed.
+        """
+        logger.debug(f"{self}: Generating chat [{context.get_messages_for_logging()}]")
+
+        messages: List[ChatCompletionMessageParam] = context.get_messages()
+
+        # base64 encode any images
+        for message in messages:
+            if message.get("mime_type") == "image/jpeg":
+                encoded_image = base64.b64encode(message["data"].getvalue()).decode("utf-8")
+                text = message["content"]
+                message["content"] = [
+                    {"type": "text", "text": text},
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
+                    },
+                ]
+                del message["data"]
+                del message["mime_type"]
+
+        chunks = await self.get_chat_completions(context, messages)
+
+        return chunks
+
+    async def _process_context(self, context: OpenAILLMContext):
+        """Stream one completion, pushing text frames and running tool calls.
+
+        Text deltas are pushed as ``LLMTextFrame`` as they arrive. Tool-call
+        deltas are accumulated and, once the stream ends, each collected call
+        is dispatched through ``call_function``.
+
+        Raises:
+            OpenAIUnhandledFunctionException: If the LLM requested a function
+                with no registered callback.
+        """
+        functions_list = []
+        arguments_list = []
+        tool_id_list = []
+        func_idx = 0
+        function_name = ""
+        arguments = ""
+        tool_call_id = ""
+
+        await self.start_ttfb_metrics()
+
+        chunk_stream: AsyncStream[ChatCompletionChunk] = await self._stream_chat_completions(
+            context
+        )
+
+        async for chunk in chunk_stream:
+            if chunk.usage:
+                tokens = LLMTokenUsage(
+                    prompt_tokens=chunk.usage.prompt_tokens,
+                    completion_tokens=chunk.usage.completion_tokens,
+                    total_tokens=chunk.usage.total_tokens,
+                )
+                await self.start_llm_usage_metrics(tokens)
+
+            if chunk.choices is None or len(chunk.choices) == 0:
+                continue
+
+            await self.stop_ttfb_metrics()
+
+            delta = chunk.choices[0].delta
+            if not delta:
+                continue
+
+            if delta.tool_calls:
+                # We're streaming the LLM response to enable the fastest response times.
+                # For text, we push each chunk as we receive it and count on consumers
+                # to do whatever coalescing they need (eg. to pass full sentences to TTS).
+                #
+                # For function calls we coalesce here: the function name and all the
+                # argument fragments are accumulated for the rest of the streamed
+                # response, and the calls are dispatched once the response is done.
+
+                tool_call = delta.tool_calls[0]
+                if tool_call.index != func_idx:
+                    # A new tool call started: store the finished one and reset.
+                    functions_list.append(function_name)
+                    arguments_list.append(arguments)
+                    tool_id_list.append(tool_call_id)
+                    function_name = ""
+                    arguments = ""
+                    tool_call_id = ""
+                    func_idx += 1
+                if tool_call.function and tool_call.function.name:
+                    function_name += tool_call.function.name
+                    tool_call_id = tool_call.id
+                if tool_call.function and tool_call.function.arguments:
+                    # Keep iterating through the response to collect all the argument fragments
+                    arguments += tool_call.function.arguments
+            elif delta.content:
+                await self.push_frame(LLMTextFrame(delta.content))
+
+        # If we got a function name and arguments, check to see if it's a function with
+        # a registered handler. If so, run the registered callback. If we don't have a
+        # registered handler, raise an exception.
+        # NOTE(review): when the last tool call streams an empty argument string, this
+        # condition skips every collected call -- confirm that is intended.
+        if function_name and arguments:
+            # The last function name and arguments have not been added to the lists yet.
+            functions_list.append(function_name)
+            arguments_list.append(arguments)
+            tool_id_list.append(tool_call_id)
+
+            for function_name, arguments, tool_id in zip(
+                functions_list, arguments_list, tool_id_list
+            ):
+                if self.has_function(function_name):
+                    arguments = json.loads(arguments)
+                    await self.call_function(
+                        context=context,
+                        function_name=function_name,
+                        arguments=arguments,
+                        tool_call_id=tool_id,
+                        run_llm=False,
+                    )
+                else:
+                    raise OpenAIUnhandledFunctionException(
+                        f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function."
+                    )
+
+    async def process_frame(self, frame: Frame, direction: FrameDirection):
+        """Turn context-bearing frames into completions; pass other frames on.
+
+        ``OpenAILLMContextFrame``, ``LLMMessagesFrame`` and
+        ``VisionImageRawFrame`` each yield a context to run. An
+        ``LLMUpdateSettingsFrame`` updates the settings. Any other frame is
+        pushed on unchanged in ``direction``.
+        """
+        await super().process_frame(frame, direction)
+
+        context: Optional[OpenAILLMContext] = None
+        if isinstance(frame, OpenAILLMContextFrame):
+            context = frame.context
+        elif isinstance(frame, LLMMessagesFrame):
+            context = OpenAILLMContext.from_messages(frame.messages)
+        elif isinstance(frame, VisionImageRawFrame):
+            context = OpenAILLMContext()
+            context.add_image_frame_message(
+                format=frame.format, size=frame.size, image=frame.image, text=frame.text
+            )
+        elif isinstance(frame, LLMUpdateSettingsFrame):
+            await self._update_settings(frame.settings)
+        else:
+            await self.push_frame(frame, direction)
+
+        if context:
+            try:
+                await self.push_frame(LLMFullResponseStartFrame())
+                await self.start_processing_metrics()
+                await self._process_context(context)
+            except httpx.TimeoutException:
+                await self._call_event_handler("on_completion_timeout")
+            finally:
+                # Always close the response, even on timeout or error.
+                await self.stop_processing_metrics()
+                await self.push_frame(LLMFullResponseEndFrame())
diff --git a/src/pipecat/services/openai/image.py b/src/pipecat/services/openai/image.py
new file mode 100644
index 000000000..fc0c475f9
--- /dev/null
+++ b/src/pipecat/services/openai/image.py
@@ -0,0 +1,58 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import io
+from typing import AsyncGenerator, Literal, Optional
+
+import aiohttp
+from loguru import logger
+from openai import AsyncOpenAI
+from PIL import Image
+
+from pipecat.frames.frames import (
+ ErrorFrame,
+ Frame,
+ URLImageRawFrame,
+)
+from pipecat.services.ai_services import ImageGenService
+
+
+class OpenAIImageGenService(ImageGenService):
+    """Image generation service that uses the OpenAI images API.
+
+    One image is requested per prompt, then downloaded from the returned URL
+    through the supplied aiohttp session.
+    """
+
+    def __init__(
+        self,
+        *,
+        api_key: str,
+        base_url: Optional[str] = None,
+        aiohttp_session: aiohttp.ClientSession,
+        image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"],
+        model: str = "dall-e-3",
+    ):
+        """Create the service.
+
+        Args:
+            api_key: API key given to the ``AsyncOpenAI`` client.
+            base_url: Optional base URL given to the ``AsyncOpenAI`` client.
+            aiohttp_session: Session used to download the generated image.
+            image_size: Size string sent with each generation request.
+            model: Model name, registered through ``set_model_name``.
+        """
+        super().__init__()
+        self.set_model_name(model)
+        self._aiohttp_session = aiohttp_session
+        self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+        self._image_size = image_size
+
+    async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]:
+        """Generate an image for ``prompt`` and yield it as a frame.
+
+        Yields:
+            A ``URLImageRawFrame`` with the downloaded image, or an
+            ``ErrorFrame`` when the response carries no image URL.
+        """
+        logger.debug(f"Generating image from prompt: {prompt}")
+
+        generation = await self._client.images.generate(
+            prompt=prompt, model=self.model_name, n=1, size=self._image_size
+        )
+        image_url = generation.data[0].url
+
+        if not image_url:
+            logger.error(f"{self} No image provided in response: {generation}")
+            yield ErrorFrame("Image generation failed")
+            return
+
+        # Download the image from the URL and decode it.
+        async with self._aiohttp_session.get(image_url) as download:
+            payload = io.BytesIO(await download.content.read())
+            picture = Image.open(payload)
+            yield URLImageRawFrame(image_url, picture.tobytes(), picture.size, picture.format)
diff --git a/src/pipecat/services/openai/llm.py b/src/pipecat/services/openai/llm.py
new file mode 100644
index 000000000..5ca51a479
--- /dev/null
+++ b/src/pipecat/services/openai/llm.py
@@ -0,0 +1,142 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import json
+from dataclasses import dataclass
+from typing import Any, Mapping
+
+from pipecat.frames.frames import (
+ FunctionCallCancelFrame,
+ FunctionCallInProgressFrame,
+ FunctionCallResultFrame,
+ UserImageRawFrame,
+)
+from pipecat.processors.aggregators.llm_response import (
+ LLMAssistantContextAggregator,
+ LLMUserContextAggregator,
+)
+from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
+from pipecat.services.openai.base_llm import BaseOpenAILLMService
+
+
+@dataclass
+class OpenAIContextAggregatorPair:
+    _user: "OpenAIUserContextAggregator"  # exposed through user()
+    _assistant: "OpenAIAssistantContextAggregator"  # exposed through assistant()
+
+    def user(self) -> "OpenAIUserContextAggregator":
+        return self._user
+
+    def assistant(self) -> "OpenAIAssistantContextAggregator":
+        return self._assistant
+
+
+class OpenAILLMService(BaseOpenAILLMService):
+    """OpenAI LLM service built on BaseOpenAILLMService; the default model is "gpt-4o"."""
+
+    def __init__(
+        self,
+        *,
+        model: str = "gpt-4o",
+        params: BaseOpenAILLMService.InputParams = BaseOpenAILLMService.InputParams(),
+        **kwargs,
+    ):
+        super().__init__(model=model, params=params, **kwargs)
+
+    def create_context_aggregator(
+        self,
+        context: OpenAILLMContext,
+        *,
+        user_kwargs: Mapping[str, Any] = {},
+        assistant_kwargs: Mapping[str, Any] = {},
+    ) -> OpenAIContextAggregatorPair:
+        """Build the user/assistant context aggregator pair for an OpenAILLMContext.
+
+        The service's LLM adapter is set on `context` before the aggregators are built.
+
+        Args:
+            context (OpenAILLMContext): The LLM context shared by both aggregators.
+            user_kwargs (Mapping[str, Any], optional): Extra keyword arguments for the
+                user context aggregator constructor. Defaults to an empty mapping.
+            assistant_kwargs (Mapping[str, Any], optional): Extra keyword arguments for
+                the assistant context aggregator constructor. Defaults to an empty mapping.
+
+        Returns:
+            OpenAIContextAggregatorPair: The user aggregator and the assistant
+            aggregator created for `context`.
+        """
+        context.set_llm_adapter(self.get_llm_adapter())
+        # The user aggregator is constructed first, then the assistant aggregator.
+        return OpenAIContextAggregatorPair(
+            _user=OpenAIUserContextAggregator(context, **user_kwargs),
+            _assistant=OpenAIAssistantContextAggregator(context, **assistant_kwargs),
+        )
+
+
+class OpenAIUserContextAggregator(LLMUserContextAggregator):
+    """User-side context aggregator for OpenAI; adds nothing to LLMUserContextAggregator."""
+
+
+class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator):
+    """Assistant-side context aggregator that tracks tool calls as OpenAI-style messages.
+
+    A call is recorded as an assistant `tool_calls` message plus a `tool` message whose
+    content starts as "IN_PROGRESS" and is later overwritten with the outcome.
+    """
+
+    async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame):
+        """Add the assistant tool call and its "IN_PROGRESS" placeholder to the context."""
+        tool_call = {
+            "id": frame.tool_call_id,
+            "function": {
+                "name": frame.function_name,
+                "arguments": json.dumps(frame.arguments),
+            },
+            "type": "function",
+        }
+        self._context.add_message({"role": "assistant", "tool_calls": [tool_call]})
+        placeholder = {
+            "role": "tool",
+            "content": "IN_PROGRESS",
+            "tool_call_id": frame.tool_call_id,
+        }
+        self._context.add_message(placeholder)
+
+    async def handle_function_call_result(self, frame: FunctionCallResultFrame):
+        """Store the JSON-encoded result, or "COMPLETED" when the result is falsy."""
+        outcome = json.dumps(frame.result) if frame.result else "COMPLETED"
+        await self._update_function_call_result(frame.function_name, frame.tool_call_id, outcome)
+
+    async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame):
+        """Mark the matching tool message as "CANCELLED"."""
+        await self._update_function_call_result(
+            frame.function_name, frame.tool_call_id, "CANCELLED"
+        )
+
+    async def _update_function_call_result(
+        self, function_name: str, tool_call_id: str, result: Any
+    ):
+        """Overwrite the content of every `tool` message whose id equals `tool_call_id`."""
+        for message in self._context.messages:
+            if message["role"] != "tool":
+                continue
+            message_call_id = message["tool_call_id"]
+            # No break: every matching tool message is updated.
+            if message_call_id and message_call_id == tool_call_id:
+                message["content"] = result
+
+    async def handle_user_image_frame(self, frame: UserImageRawFrame):
+        """Complete the originating tool call, then add the image to the context."""
+        request = frame.request
+        await self._update_function_call_result(
+            request.function_name, request.tool_call_id, "COMPLETED"
+        )
+        self._context.add_image_frame_message(
+            format=frame.format,
+            size=frame.size,
+            image=frame.image,
+            text=request.context,
+        )
diff --git a/src/pipecat/services/openai/stt.py b/src/pipecat/services/openai/stt.py
new file mode 100644
index 000000000..173205aa0
--- /dev/null
+++ b/src/pipecat/services/openai/stt.py
@@ -0,0 +1,66 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+from typing import Optional
+
+from pipecat.services.whisper.base_stt import BaseWhisperSTTService, Transcription
+from pipecat.transcriptions.language import Language
+
+
+class OpenAISTTService(BaseWhisperSTTService):
+    """Speech-to-Text service backed by OpenAI's audio transcription API.
+
+    Sends audio to OpenAI's transcription endpoint and returns the result. Requires an
+    OpenAI API key set via the api_key parameter or OPENAI_API_KEY environment variable.
+
+    Args:
+        model: Model to use — either gpt-4o or Whisper. Defaults to "gpt-4o-transcribe".
+        api_key: OpenAI API key. Defaults to None.
+        base_url: API base URL. Defaults to None.
+        language: Language of the audio input. Defaults to English.
+        prompt: Optional text to guide the model's style or continue a previous segment.
+        temperature: Optional sampling temperature between 0 and 1. Defaults to None.
+        **kwargs: Additional arguments passed to BaseWhisperSTTService.
+    """
+
+    def __init__(
+        self,
+        *,
+        model: str = "gpt-4o-transcribe",
+        api_key: Optional[str] = None,
+        base_url: Optional[str] = None,
+        language: Optional[Language] = Language.EN,
+        prompt: Optional[str] = None,
+        temperature: Optional[float] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            model=model,
+            api_key=api_key,
+            base_url=base_url,
+            language=language,
+            prompt=prompt,
+            temperature=temperature,
+            **kwargs,
+        )
+
+    async def _transcribe(self, audio: bytes) -> Transcription:
+        """Send `audio` (as "audio.wav") to the transcription endpoint and return the result.
+
+        `prompt` and `temperature` are included only when they are not None.
+        """
+        assert self._language is not None  # Assigned in the BaseWhisperSTTService class
+
+        request = {
+            "file": ("audio.wav", audio, "audio/wav"),
+            "model": self.model_name,
+            "language": self._language,
+        }
+        optional = {"prompt": self._prompt, "temperature": self._temperature}
+        # Forward only the optional settings that were actually provided.
+        request.update({key: value for key, value in optional.items() if value is not None})
+
+        return await self._client.audio.transcriptions.create(**request)
diff --git a/src/pipecat/services/openai/tts.py b/src/pipecat/services/openai/tts.py
new file mode 100644
index 000000000..68644f147
--- /dev/null
+++ b/src/pipecat/services/openai/tts.py
@@ -0,0 +1,129 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+from typing import AsyncGenerator, Dict, Literal, Optional
+
+from loguru import logger
+from openai import AsyncOpenAI, BadRequestError
+
+from pipecat.frames.frames import (
+ ErrorFrame,
+ Frame,
+ StartFrame,
+ TTSAudioRawFrame,
+ TTSStartedFrame,
+ TTSStoppedFrame,
+)
+from pipecat.services.ai_services import TTSService
+
+ValidVoice = Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"]
+# Maps each voice name to itself, typed as ValidVoice; run_tts indexes it by voice id.
+VALID_VOICES: Dict[str, ValidVoice] = {
+    "alloy": "alloy",
+    "echo": "echo",
+    "fable": "fable",
+    "onyx": "onyx",
+    "nova": "nova",
+    "shimmer": "shimmer",
+}
+
+
+class OpenAITTSService(TTSService):
+    """OpenAI Text-to-Speech service that generates audio from text.
+
+    This service uses the OpenAI TTS API to generate PCM-encoded audio at 24kHz.
+
+    Args:
+        api_key: OpenAI API key. Defaults to None.
+        base_url: API base URL. Defaults to None.
+        voice: Voice ID to use; looked up in VALID_VOICES. Defaults to "alloy".
+        model: TTS model to use. Defaults to "gpt-4o-mini-tts".
+        sample_rate: Output audio sample rate in Hz. Defaults to None.
+        instructions: Optional instructions sent in the request's extra body.
+        **kwargs: Additional keyword arguments passed to TTSService.
+    """
+
+    OPENAI_SAMPLE_RATE = 24000  # OpenAI TTS always outputs at 24kHz
+
+    def __init__(
+        self,
+        *,
+        api_key: Optional[str] = None,
+        base_url: Optional[str] = None,
+        voice: str = "alloy",
+        model: str = "gpt-4o-mini-tts",
+        sample_rate: Optional[int] = None,
+        instructions: Optional[str] = None,
+        **kwargs,
+    ):
+        # Log the requested rate: `self.sample_rate` belongs to the base class, initialised below.
+        if sample_rate and sample_rate != self.OPENAI_SAMPLE_RATE:
+            logger.warning(
+                f"OpenAI TTS only supports {self.OPENAI_SAMPLE_RATE}Hz sample rate. "
+                f"Current rate of {sample_rate}Hz may cause issues."
+            )
+        super().__init__(sample_rate=sample_rate, **kwargs)
+
+        self.set_model_name(model)
+        self.set_voice(voice)
+        self._instructions = instructions
+        self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)
+
+    def can_generate_metrics(self) -> bool:
+        """Report that this service can generate metrics."""
+        return True
+
+    async def set_model(self, model: str):
+        """Log the switch and store `model` as the current model name."""
+        logger.info(f"Switching TTS model to: [{model}]")
+        self.set_model_name(model)
+
+    async def start(self, frame: StartFrame):
+        """Run the base start, then warn if `self.sample_rate` is not OPENAI_SAMPLE_RATE."""
+        await super().start(frame)
+        if self.sample_rate != self.OPENAI_SAMPLE_RATE:
+            logger.warning(
+                f"OpenAI TTS requires {self.OPENAI_SAMPLE_RATE}Hz sample rate. "
+                f"Current rate of {self.sample_rate}Hz may cause issues."
+            )
+
+    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
+        """Stream speech for `text` as audio frames; yield an ErrorFrame on a non-200 reply."""
+        logger.debug(f"{self}: Generating TTS [{text}]")
+        try:
+            await self.start_ttfb_metrics()
+            # Setup extra body parameters
+            extra_body = {}
+            if self._instructions:
+                extra_body["instructions"] = self._instructions
+
+            async with self._client.audio.speech.with_streaming_response.create(
+                input=text or " ",  # Text must contain at least one character
+                model=self.model_name,
+                voice=VALID_VOICES[self._voice_id],
+                response_format="pcm",
+                extra_body=extra_body,
+            ) as r:
+                if r.status_code != 200:
+                    error = await r.text()
+                    logger.error(
+                        f"{self} error getting audio (status: {r.status_code}, error: {error})"
+                    )
+                    yield ErrorFrame(
+                        f"Error getting audio (status: {r.status_code}, error: {error})"
+                    )
+                    return
+
+                await self.start_tts_usage_metrics(text)
+                CHUNK_SIZE = 1024
+                yield TTSStartedFrame()
+                async for chunk in r.iter_bytes(CHUNK_SIZE):
+                    if len(chunk) > 0:
+                        await self.stop_ttfb_metrics()
+                        yield TTSAudioRawFrame(chunk, self.sample_rate, 1)
+                yield TTSStoppedFrame()
+        except BadRequestError as e:
+            logger.exception(f"{self} error generating TTS: {e}")
diff --git a/src/pipecat/services/openai_realtime_beta/azure.py b/src/pipecat/services/openai_realtime_beta/azure.py
index 5f046b8b0..799c5e686 100644
--- a/src/pipecat/services/openai_realtime_beta/azure.py
+++ b/src/pipecat/services/openai_realtime_beta/azure.py
@@ -4,8 +4,6 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
-import os
-
from loguru import logger
from .openai import OpenAIRealtimeBetaLLMService
diff --git a/src/pipecat/services/openai_realtime_beta/context.py b/src/pipecat/services/openai_realtime_beta/context.py
index c8381976f..80cffbef2 100644
--- a/src/pipecat/services/openai_realtime_beta/context.py
+++ b/src/pipecat/services/openai_realtime_beta/context.py
@@ -6,23 +6,18 @@
import copy
import json
-from typing import Optional
from loguru import logger
from pipecat.frames.frames import (
Frame,
FunctionCallResultFrame,
- FunctionCallResultProperties,
LLMMessagesUpdateFrame,
LLMSetToolsFrame,
)
-from pipecat.processors.aggregators.openai_llm_context import (
- OpenAILLMContext,
- OpenAILLMContextFrame,
-)
+from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
from pipecat.processors.frame_processor import FrameDirection
-from pipecat.services.openai import (
+from pipecat.services.openai.llm import (
OpenAIAssistantContextAggregator,
OpenAIUserContextAggregator,
)
diff --git a/src/pipecat/services/openai_realtime_beta/openai.py b/src/pipecat/services/openai_realtime_beta/openai.py
index a1f4bc731..7b10b83fc 100644
--- a/src/pipecat/services/openai_realtime_beta/openai.py
+++ b/src/pipecat/services/openai_realtime_beta/openai.py
@@ -12,8 +12,6 @@ from typing import Any, Mapping
from loguru import logger
-from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter
-
try:
import websockets
except ModuleNotFoundError as e:
@@ -23,6 +21,7 @@ except ModuleNotFoundError as e:
)
raise Exception(f"Missing module: {e}")
+from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter
from pipecat.frames.frames import (
BotStoppedSpeakingFrame,
CancelFrame,
@@ -55,7 +54,7 @@ from pipecat.processors.aggregators.openai_llm_context import (
)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import LLMService
-from pipecat.services.openai import OpenAIContextAggregatorPair
+from pipecat.services.openai.llm import OpenAIContextAggregatorPair
from pipecat.utils.time import time_now_iso8601
from . import events
diff --git a/src/pipecat/services/openpipe/__init__.py b/src/pipecat/services/openpipe/__init__.py
new file mode 100644
index 000000000..c6e439ba9
--- /dev/null
+++ b/src/pipecat/services/openpipe/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .llm import *
+
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "openpipe", "openpipe.llm")
diff --git a/src/pipecat/services/openpipe.py b/src/pipecat/services/openpipe/llm.py
similarity index 89%
rename from src/pipecat/services/openpipe.py
rename to src/pipecat/services/openpipe/llm.py
index c89c9d6a3..63e499714 100644
--- a/src/pipecat/services/openpipe.py
+++ b/src/pipecat/services/openpipe/llm.py
@@ -10,16 +10,14 @@ from loguru import logger
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
-from pipecat.services.openai import OpenAILLMService
+from pipecat.services.openai.llm import OpenAILLMService
try:
from openpipe import AsyncOpenAI as OpenPipeAI
from openpipe import AsyncStream
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
- logger.error(
- "In order to use OpenPipe, you need to `pip install pipecat-ai[openpipe]`. Also, set `OPENPIPE_API_KEY` and `OPENAI_API_KEY` environment variables."
- )
+ logger.error("In order to use OpenPipe, you need to `pip install pipecat-ai[openpipe]`.")
raise Exception(f"Missing module: {e}")
diff --git a/src/pipecat/services/openrouter/__init__.py b/src/pipecat/services/openrouter/__init__.py
new file mode 100644
index 000000000..12c4f2ea3
--- /dev/null
+++ b/src/pipecat/services/openrouter/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .llm import *
+
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "openrouter", "openrouter.llm")
diff --git a/src/pipecat/services/openrouter.py b/src/pipecat/services/openrouter/llm.py
similarity index 96%
rename from src/pipecat/services/openrouter.py
rename to src/pipecat/services/openrouter/llm.py
index d25990d2e..431724f94 100644
--- a/src/pipecat/services/openrouter.py
+++ b/src/pipecat/services/openrouter/llm.py
@@ -8,7 +8,7 @@ from typing import Optional
from loguru import logger
-from pipecat.services.openai import OpenAILLMService
+from pipecat.services.openai.llm import OpenAILLMService
class OpenRouterLLMService(OpenAILLMService):
diff --git a/src/pipecat/services/perplexity/__init__.py b/src/pipecat/services/perplexity/__init__.py
new file mode 100644
index 000000000..b1cf42e12
--- /dev/null
+++ b/src/pipecat/services/perplexity/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .llm import *
+
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "perplexity", "perplexity.llm")
diff --git a/src/pipecat/services/perplexity.py b/src/pipecat/services/perplexity/llm.py
similarity index 98%
rename from src/pipecat/services/perplexity.py
rename to src/pipecat/services/perplexity/llm.py
index b0c560ce4..ff9f82bdb 100644
--- a/src/pipecat/services/perplexity.py
+++ b/src/pipecat/services/perplexity/llm.py
@@ -6,13 +6,12 @@
from typing import List
-from loguru import logger
from openai import NOT_GIVEN, AsyncStream
from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam
from pipecat.metrics.metrics import LLMTokenUsage
from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext
-from pipecat.services.openai import OpenAILLMService
+from pipecat.services.openai.llm import OpenAILLMService
class PerplexityLLMService(OpenAILLMService):
diff --git a/src/pipecat/services/piper/__init__.py b/src/pipecat/services/piper/__init__.py
new file mode 100644
index 000000000..2cb7e790b
--- /dev/null
+++ b/src/pipecat/services/piper/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .tts import *
+
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "piper", "piper.tts")
diff --git a/src/pipecat/services/piper.py b/src/pipecat/services/piper/tts.py
similarity index 100%
rename from src/pipecat/services/piper.py
rename to src/pipecat/services/piper/tts.py
diff --git a/src/pipecat/services/playht/__init__.py b/src/pipecat/services/playht/__init__.py
new file mode 100644
index 000000000..3d2247f5e
--- /dev/null
+++ b/src/pipecat/services/playht/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .tts import *
+
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "playht", "playht.tts")
diff --git a/src/pipecat/services/playht.py b/src/pipecat/services/playht/tts.py
similarity index 98%
rename from src/pipecat/services/playht.py
rename to src/pipecat/services/playht/tts.py
index 75677876f..f7157a950 100644
--- a/src/pipecat/services/playht.py
+++ b/src/pipecat/services/playht/tts.py
@@ -36,9 +36,7 @@ try:
from pyht.client import Language as PlayHTLanguage
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
- logger.error(
- "In order to use PlayHT, you need to `pip install pipecat-ai[playht]`. Also, set `PLAY_HT_USER_ID` and `PLAY_HT_API_KEY` environment variables."
- )
+ logger.error("In order to use PlayHT, you need to `pip install pipecat-ai[playht]`.")
raise Exception(f"Missing module: {e}")
diff --git a/src/pipecat/services/qwen/__init__.py b/src/pipecat/services/qwen/__init__.py
new file mode 100644
index 000000000..0b398c1eb
--- /dev/null
+++ b/src/pipecat/services/qwen/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .llm import *
+
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "qwen", "qwen.llm")
diff --git a/src/pipecat/services/qwen.py b/src/pipecat/services/qwen/llm.py
similarity index 93%
rename from src/pipecat/services/qwen.py
rename to src/pipecat/services/qwen/llm.py
index 0ed5ca593..de910a741 100644
--- a/src/pipecat/services/qwen.py
+++ b/src/pipecat/services/qwen/llm.py
@@ -4,11 +4,9 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
-from typing import Optional
-
from loguru import logger
-from pipecat.services.openai import OpenAILLMService, OpenAISTTService
+from pipecat.services.openai.llm import OpenAILLMService
class QwenLLMService(OpenAILLMService):
diff --git a/src/pipecat/services/rime/__init__.py b/src/pipecat/services/rime/__init__.py
new file mode 100644
index 000000000..842fbd971
--- /dev/null
+++ b/src/pipecat/services/rime/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .tts import *
+
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "rime", "rime.tts")
diff --git a/src/pipecat/services/rime.py b/src/pipecat/services/rime/tts.py
similarity index 99%
rename from src/pipecat/services/rime.py
rename to src/pipecat/services/rime/tts.py
index a1a455eb5..f5b5da2a6 100644
--- a/src/pipecat/services/rime.py
+++ b/src/pipecat/services/rime/tts.py
@@ -34,9 +34,7 @@ try:
import websockets
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
- logger.error(
- "In order to use Rime, you need to `pip install pipecat-ai[rime]`. Also, set `RIME_API_KEY` environment variable."
- )
+ logger.error("In order to use Rime, you need to `pip install pipecat-ai[rime]`.")
raise Exception(f"Missing module: {e}")
diff --git a/src/pipecat/services/riva/__init__.py b/src/pipecat/services/riva/__init__.py
new file mode 100644
index 000000000..29fe0c003
--- /dev/null
+++ b/src/pipecat/services/riva/__init__.py
@@ -0,0 +1,14 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .stt import *
+from .tts import *
+
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "riva", "riva.[stt,tts]")
diff --git a/src/pipecat/services/riva.py b/src/pipecat/services/riva/stt.py
similarity index 64%
rename from src/pipecat/services/riva.py
rename to src/pipecat/services/riva/stt.py
index de065aef4..63eea8230 100644
--- a/src/pipecat/services/riva.py
+++ b/src/pipecat/services/riva/stt.py
@@ -17,11 +17,8 @@ from pipecat.frames.frames import (
InterimTranscriptionFrame,
StartFrame,
TranscriptionFrame,
- TTSAudioRawFrame,
- TTSStartedFrame,
- TTSStoppedFrame,
)
-from pipecat.services.ai_services import STTService, TTSService
+from pipecat.services.ai_services import STTService
from pipecat.transcriptions.language import Language
from pipecat.utils.time import time_now_iso8601
@@ -30,100 +27,9 @@ try:
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
- logger.error(
- "In order to use nvidia riva TTS or STT, you need to `pip install pipecat-ai[riva]`. Also, set `NVIDIA_API_KEY` environment variable."
- )
+ logger.error("In order to use NVIDIA Riva STT, you need to `pip install pipecat-ai[riva]`.")
raise Exception(f"Missing module: {e}")
-FASTPITCH_TIMEOUT_SECS = 5
-
-
-class FastPitchTTSService(TTSService):
- class InputParams(BaseModel):
- language: Optional[Language] = Language.EN_US
- quality: Optional[int] = 20
-
- def __init__(
- self,
- *,
- api_key: str,
- server: str = "grpc.nvcf.nvidia.com:443",
- voice_id: str = "English-US.Female-1",
- sample_rate: Optional[int] = None,
- function_id: str = "0149dedb-2be8-4195-b9a0-e57e0e14f972",
- params: InputParams = InputParams(),
- **kwargs,
- ):
- super().__init__(sample_rate=sample_rate, **kwargs)
- self._api_key = api_key
- self._voice_id = voice_id
- self._language_code = params.language
- self._quality = params.quality
-
- self.set_model_name("fastpitch-hifigan-tts")
- self.set_voice(voice_id)
-
- metadata = [
- ["function-id", function_id],
- ["authorization", f"Bearer {api_key}"],
- ]
- auth = riva.client.Auth(None, True, server, metadata)
-
- self._service = riva.client.SpeechSynthesisService(auth)
-
- # warm up the service
- config_response = self._service.stub.GetRivaSynthesisConfig(
- riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
- )
-
- async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
- def read_audio_responses(queue: asyncio.Queue):
- def add_response(r):
- asyncio.run_coroutine_threadsafe(queue.put(r), self.get_event_loop())
-
- try:
- responses = self._service.synthesize_online(
- text,
- self._voice_id,
- self._language_code,
- sample_rate_hz=self.sample_rate,
- audio_prompt_file=None,
- quality=self._quality,
- custom_dictionary={},
- )
- for r in responses:
- add_response(r)
- add_response(None)
- except Exception as e:
- logger.error(f"{self} exception: {e}")
- add_response(None)
-
- await self.start_ttfb_metrics()
- yield TTSStartedFrame()
-
- logger.debug(f"{self}: Generating TTS [{text}]")
-
- try:
- queue = asyncio.Queue()
- await asyncio.to_thread(read_audio_responses, queue)
-
- # Wait for the thread to start.
- resp = await asyncio.wait_for(queue.get(), FASTPITCH_TIMEOUT_SECS)
- while resp:
- await self.stop_ttfb_metrics()
- frame = TTSAudioRawFrame(
- audio=resp.audio,
- sample_rate=self.sample_rate,
- num_channels=1,
- )
- yield frame
- resp = await asyncio.wait_for(queue.get(), FASTPITCH_TIMEOUT_SECS)
- except asyncio.TimeoutError:
- logger.error(f"{self} timeout waiting for audio response")
-
- await self.start_tts_usage_metrics(text)
- yield TTSStoppedFrame()
-
class ParakeetSTTService(STTService):
class InputParams(BaseModel):
diff --git a/src/pipecat/services/riva/tts.py b/src/pipecat/services/riva/tts.py
new file mode 100644
index 000000000..e0da3ab98
--- /dev/null
+++ b/src/pipecat/services/riva/tts.py
@@ -0,0 +1,117 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import asyncio
+from typing import AsyncGenerator, Optional
+
+from loguru import logger
+from pydantic import BaseModel
+
+from pipecat.frames.frames import (
+ Frame,
+ TTSAudioRawFrame,
+ TTSStartedFrame,
+ TTSStoppedFrame,
+)
+from pipecat.services.ai_services import TTSService
+from pipecat.transcriptions.language import Language
+
+try:
+ import riva.client
+
+except ModuleNotFoundError as e:
+ logger.error(f"Exception: {e}")
+ logger.error("In order to use NVIDIA Riva TTS, you need to `pip install pipecat-ai[riva]`.")
+ raise Exception(f"Missing module: {e}")
+
+FASTPITCH_TIMEOUT_SECS = 5
+
+
+class FastPitchTTSService(TTSService):
+    """Text-to-Speech service using the NVIDIA Riva speech synthesis client (FastPitch)."""
+
+    class InputParams(BaseModel):
+        language: Optional[Language] = Language.EN_US
+        quality: Optional[int] = 20
+
+    def __init__(
+        self,
+        *,
+        api_key: str,
+        server: str = "grpc.nvcf.nvidia.com:443",
+        voice_id: str = "English-US.Female-1",
+        sample_rate: Optional[int] = None,
+        function_id: str = "0149dedb-2be8-4195-b9a0-e57e0e14f972",
+        params: InputParams = InputParams(),
+        **kwargs,
+    ):
+        super().__init__(sample_rate=sample_rate, **kwargs)
+        self._api_key = api_key
+        self._voice_id = voice_id
+        self._language_code = params.language
+        self._quality = params.quality
+
+        self.set_model_name("fastpitch-hifigan-tts")
+        self.set_voice(voice_id)
+
+        metadata = [
+            ["function-id", function_id],
+            ["authorization", f"Bearer {api_key}"],
+        ]
+        auth = riva.client.Auth(None, True, server, metadata)
+        self._service = riva.client.SpeechSynthesisService(auth)
+
+        # warm up the service (the returned config is not used)
+        self._service.stub.GetRivaSynthesisConfig(
+            riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest()
+        )
+
+    async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
+        """Synthesize `text`, yielding TTSStartedFrame, audio frames and TTSStoppedFrame."""
+
+        def read_audio_responses(queue: asyncio.Queue):
+            def add_response(r):
+                asyncio.run_coroutine_threadsafe(queue.put(r), self.get_event_loop())
+
+            try:
+                responses = self._service.synthesize_online(
+                    text,
+                    self._voice_id,
+                    self._language_code,
+                    sample_rate_hz=self.sample_rate,
+                    audio_prompt_file=None,
+                    quality=self._quality,
+                    custom_dictionary={},
+                )
+                for r in responses:
+                    add_response(r)
+                add_response(None)
+            except Exception as e:
+                logger.error(f"{self} exception: {e}")
+                add_response(None)
+
+        await self.start_ttfb_metrics()
+        yield TTSStartedFrame()
+        logger.debug(f"{self}: Generating TTS [{text}]")
+
+        try:
+            queue = asyncio.Queue()
+            await asyncio.to_thread(read_audio_responses, queue)
+
+            # The worker has returned by now; drain the queue until the None sentinel.
+            resp = await asyncio.wait_for(queue.get(), FASTPITCH_TIMEOUT_SECS)
+            while resp:
+                await self.stop_ttfb_metrics()
+                yield TTSAudioRawFrame(
+                    audio=resp.audio,
+                    sample_rate=self.sample_rate,
+                    num_channels=1,
+                )
+                resp = await asyncio.wait_for(queue.get(), FASTPITCH_TIMEOUT_SECS)
+        except asyncio.TimeoutError:
+            logger.error(f"{self} timeout waiting for audio response")
+
+        await self.start_tts_usage_metrics(text)
+        yield TTSStoppedFrame()
diff --git a/src/pipecat/services/simli/__init__.py b/src/pipecat/services/simli/__init__.py
new file mode 100644
index 000000000..7e7cd7ed7
--- /dev/null
+++ b/src/pipecat/services/simli/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .video import *
+
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "simli", "simli.video")
diff --git a/src/pipecat/services/simli.py b/src/pipecat/services/simli/video.py
similarity index 100%
rename from src/pipecat/services/simli.py
rename to src/pipecat/services/simli/video.py
diff --git a/src/pipecat/services/tavus/__init__.py b/src/pipecat/services/tavus/__init__.py
new file mode 100644
index 000000000..97811e467
--- /dev/null
+++ b/src/pipecat/services/tavus/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .video import *
+
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "tavus", "tavus.video")
diff --git a/src/pipecat/services/tavus.py b/src/pipecat/services/tavus/video.py
similarity index 99%
rename from src/pipecat/services/tavus.py
rename to src/pipecat/services/tavus/video.py
index 9083dd53d..27b1d7155 100644
--- a/src/pipecat/services/tavus.py
+++ b/src/pipecat/services/tavus/video.py
@@ -4,7 +4,6 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
-
"""This module implements Tavus as a sink transport layer"""
import base64
diff --git a/src/pipecat/services/to_be_updated/__init__.py b/src/pipecat/services/to_be_updated/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/src/pipecat/services/to_be_updated/cloudflare_ai_service.py b/src/pipecat/services/to_be_updated/cloudflare_ai_service.py
deleted file mode 100644
index ff637ff1a..000000000
--- a/src/pipecat/services/to_be_updated/cloudflare_ai_service.py
+++ /dev/null
@@ -1,69 +0,0 @@
-import os
-
-import requests
-from services.ai_service import AIService
-
-# Note that Cloudflare's AI workers are still in beta.
-# https://developers.cloudflare.com/workers-ai/
-
-
-class CloudflareAIService(AIService):
- def __init__(self):
- super().__init__()
- self.cloudflare_account_id = os.getenv("CLOUDFLARE_ACCOUNT_ID")
- self.cloudflare_api_token = os.getenv("CLOUDFLARE_API_TOKEN")
-
- self.api_base_url = (
- f"https://api.cloudflare.com/client/v4/accounts/{self.cloudflare_account_id}/ai/run/"
- )
- self.headers = {"Authorization": f"Bearer {self.cloudflare_api_token}"}
-
- # base endpoint, used by the others
- def run(self, model, input):
- response = requests.post(f"{self.api_base_url}{model}", headers=self.headers, json=input)
- return response.json()
-
- # https://developers.cloudflare.com/workers-ai/models/llm/
- def run_llm(self, messages, latest_user_message=None, stream=True):
- input = {
- "messages": [
- {"role": "system", "content": "You are a friendly assistant"},
- {"role": "user", "content": sentence},
- ]
- }
-
- return self.run("@cf/meta/llama-2-7b-chat-int8", input)
-
- # https://developers.cloudflare.com/workers-ai/models/translation/
- def run_text_translation(self, sentence, source_language, target_language):
- return self.run(
- "@cf/meta/m2m100-1.2b",
- {"text": sentence, "source_lang": source_language, "target_lang": target_language},
- )
-
- # https://developers.cloudflare.com/workers-ai/models/sentiment-analysis/
- def run_text_sentiment(self, sentence):
- return self.run("@cf/huggingface/distilbert-sst-2-int8", {"text": sentence})
-
- # https://developers.cloudflare.com/workers-ai/models/image-classification/
- def run_image_classification(self, image_url):
- response = requests.get(image_url)
-
- if response.status_code != 200:
- return {"error": "There was a problem downloading the image."}
-
- if response.status_code == 200:
- data = response.content
- inputs = {"image": list(data)}
-
- return self.run("@cf/microsoft/resnet-50", inputs)
-
- # https://developers.cloudflare.com/workers-ai/models/embedding/
- def run_embeddings(self, texts, size="medium"):
- models = {
- "small": "@cf/baai/bge-small-en-v1.5", # 384 output dimensions
- "medium": "@cf/baai/bge-base-en-v1.5", # 768 output dimensions
- "large": "@cf/baai/bge-large-en-v1.5", # 1024 output dimensions
- }
-
- return self.run(models[size], {"text": texts})
diff --git a/src/pipecat/services/to_be_updated/google_ai_service.py b/src/pipecat/services/to_be_updated/google_ai_service.py
deleted file mode 100644
index 3ca688750..000000000
--- a/src/pipecat/services/to_be_updated/google_ai_service.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import os
-
-import openai
-
-# To use Google Cloud's AI products, you'll need to install Google Cloud
-# CLI and enable the TTS and in your project:
-# https://cloud.google.com/sdk/docs/install
-from google.cloud import texttospeech
-from services.ai_service import AIService
-
-
-class GoogleAIService(AIService):
- def __init__(self):
- super().__init__()
-
- self.client = texttospeech.TextToSpeechClient()
- self.voice = texttospeech.VoiceSelectionParams(
- language_code="en-GB", name="en-GB-Neural2-F"
- )
-
- self.audio_config = texttospeech.AudioConfig(
- audio_encoding=texttospeech.AudioEncoding.LINEAR16, sample_rate_hertz=16000
- )
-
- def run_tts(self, sentence):
- synthesis_input = texttospeech.SynthesisInput(text=sentence.strip())
- result = self.client.synthesize_speech(
- input=synthesis_input, voice=self.voice, audio_config=self.audio_config
- )
- return result
diff --git a/src/pipecat/services/to_be_updated/huggingface_ai_service.py b/src/pipecat/services/to_be_updated/huggingface_ai_service.py
deleted file mode 100644
index 09f0b8248..000000000
--- a/src/pipecat/services/to_be_updated/huggingface_ai_service.py
+++ /dev/null
@@ -1,33 +0,0 @@
-from services.ai_service import AIService
-from transformers import pipeline
-
-# These functions are just intended for testing, not production use. If
-# you'd like to use HuggingFace, you should use your own models, or do
-# some research into the specific models that will work best for your use
-# case.
-
-
-class HuggingFaceAIService(AIService):
- def __init__(self):
- super().__init__()
-
- def run_text_sentiment(self, sentence):
- classifier = pipeline("sentiment-analysis")
- return classifier(sentence)
-
- # available models at https://huggingface.co/Helsinki-NLP (**not all
- # models use 2-character language codes**)
- def run_text_translation(self, sentence, source_language, target_language):
- translator = pipeline(
- f"translation", model=f"Helsinki-NLP/opus-mt-{source_language}-{target_language}"
- )
-
- return translator(sentence)[0]["translation_text"]
-
- def run_text_summarization(self, sentence):
- summarizer = pipeline("summarization")
- return summarizer(sentence)
-
- def run_image_classification(self, image_path):
- classifier = pipeline("image-classification")
- return classifier(image_path)
diff --git a/src/pipecat/services/to_be_updated/mock_ai_service.py b/src/pipecat/services/to_be_updated/mock_ai_service.py
deleted file mode 100644
index 0825cde33..000000000
--- a/src/pipecat/services/to_be_updated/mock_ai_service.py
+++ /dev/null
@@ -1,28 +0,0 @@
-import io
-import time
-
-import requests
-from PIL import Image
-from services.ai_service import AIService
-
-
-class MockAIService(AIService):
- def __init__(self):
- super().__init__()
-
- def run_tts(self, sentence):
- print("running tts", sentence)
- time.sleep(2)
-
- def run_image_gen(self, sentence):
- image_url = "https://d3d00swyhr67nd.cloudfront.net/w800h800/collection/ASH/ASHM/ASH_ASHM_WA1940_2_22-001.jpg"
- response = requests.get(image_url)
- image_stream = io.BytesIO(response.content)
- image = Image.open(image_stream)
- time.sleep(1)
- return (image_url, image.tobytes(), image.size)
-
- def run_llm(self, messages, latest_user_message=None, stream=True):
- for i in range(5):
- time.sleep(1)
- yield ({"choices": [{"delta": {"content": f"hello {i}!"}}]})
diff --git a/src/pipecat/services/together/__init__.py b/src/pipecat/services/together/__init__.py
new file mode 100644
index 000000000..b7e3a779b
--- /dev/null
+++ b/src/pipecat/services/together/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .llm import *
+
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "together", "together.llm")
diff --git a/src/pipecat/services/together.py b/src/pipecat/services/together/llm.py
similarity index 96%
rename from src/pipecat/services/together.py
rename to src/pipecat/services/together/llm.py
index 3b83cad83..31b15ae73 100644
--- a/src/pipecat/services/together.py
+++ b/src/pipecat/services/together/llm.py
@@ -4,10 +4,9 @@
# SPDX-License-Identifier: BSD 2-Clause License
#
-
from loguru import logger
-from pipecat.services.openai import OpenAILLMService
+from pipecat.services.openai.llm import OpenAILLMService
class TogetherLLMService(OpenAILLMService):
diff --git a/src/pipecat/services/ultravox/__init__.py b/src/pipecat/services/ultravox/__init__.py
new file mode 100644
index 000000000..b05e930eb
--- /dev/null
+++ b/src/pipecat/services/ultravox/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .stt import *
+
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "ultravox", "ultravox.stt")
diff --git a/src/pipecat/services/ultravox.py b/src/pipecat/services/ultravox/stt.py
similarity index 100%
rename from src/pipecat/services/ultravox.py
rename to src/pipecat/services/ultravox/stt.py
diff --git a/src/pipecat/services/whisper/__init__.py b/src/pipecat/services/whisper/__init__.py
new file mode 100644
index 000000000..3c82f3bfb
--- /dev/null
+++ b/src/pipecat/services/whisper/__init__.py
@@ -0,0 +1,14 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .base_stt import *
+from .stt import *
+
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "whisper", "whisper.stt")
diff --git a/src/pipecat/services/base_whisper.py b/src/pipecat/services/whisper/base_stt.py
similarity index 100%
rename from src/pipecat/services/base_whisper.py
rename to src/pipecat/services/whisper/base_stt.py
diff --git a/src/pipecat/services/whisper.py b/src/pipecat/services/whisper/stt.py
similarity index 100%
rename from src/pipecat/services/whisper.py
rename to src/pipecat/services/whisper/stt.py
diff --git a/src/pipecat/services/xtts/__init__.py b/src/pipecat/services/xtts/__init__.py
new file mode 100644
index 000000000..8f5a87892
--- /dev/null
+++ b/src/pipecat/services/xtts/__init__.py
@@ -0,0 +1,13 @@
+#
+# Copyright (c) 2024–2025, Daily
+#
+# SPDX-License-Identifier: BSD 2-Clause License
+#
+
+import sys
+
+from pipecat.services import DeprecatedModuleProxy
+
+from .tts import *
+
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "xtts", "xtts.tts")
diff --git a/src/pipecat/services/xtts.py b/src/pipecat/services/xtts/tts.py
similarity index 100%
rename from src/pipecat/services/xtts.py
rename to src/pipecat/services/xtts/tts.py
diff --git a/tests/integration/test_integration_unified_function_calling.py b/tests/integration/test_integration_unified_function_calling.py
index 0696c531a..f4652905d 100644
--- a/tests/integration/test_integration_unified_function_calling.py
+++ b/tests/integration/test_integration_unified_function_calling.py
@@ -13,10 +13,11 @@ from dotenv import load_dotenv
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.pipeline.pipeline import Pipeline
+from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame
from pipecat.services.ai_services import LLMService
from pipecat.services.anthropic import AnthropicLLMService
from pipecat.services.google import GoogleLLMService
-from pipecat.services.openai import OpenAILLMContext, OpenAILLMContextFrame, OpenAILLMService
+from pipecat.services.openai import OpenAILLMContext, OpenAILLMService
from pipecat.tests.utils import run_test
load_dotenv(override=True)