diff --git a/CHANGELOG.md b/CHANGELOG.md index cbde030da..695586744 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,6 +47,26 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - It is now possible to tell whether `UserStartedSpeakingFrame` or `UserStoppedSpeakingFrame` have been generated because of emulation frames. +### Changed + +- Pipecat services have been reorganized into packages. Each package can have + one or more of the following modules (in the future new module names might be + needed) depending on the services implemented: + - image: for image generation services + - llm: for LLM services + - memory: for memory services + - stt: for Speech-To-Text services + - tts: for Text-To-Speech services + - video: for video generation services + - vision: for video recognition services + +### Deprecated + +- All Pipecat service imports have been deprecated and a warning will be shown + when using the old import. The new import should be + `pipecat.services.[service].[image,llm,memory,stt,tts,video,vision]`. For + example, `from pipecat.services.openai.llm import OpenAILLMService`. 
+ ### Fixed - Fixed an issue that would cause `SegmentedSTTService` based services diff --git a/src/pipecat/services/__init__.py b/src/pipecat/services/__init__.py index e69de29bb..d79c1793d 100644 --- a/src/pipecat/services/__init__.py +++ b/src/pipecat/services/__init__.py @@ -0,0 +1,32 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from typing import Any, Dict + + +def _warn_deprecated_access(globals: Dict[str, Any], attr, old: str, new: str): + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + f"Module `pipecat.services.{old}` is deprecated, use `pipecat.services.{new}` instead", + DeprecationWarning, + ) + + return globals[attr] + + +class DeprecatedModuleProxy: + def __init__(self, globals: Dict[str, Any], old: str, new: str): + self._globals = globals + self._old = old + self._new = new + + def __getattr__(self, attr): + if attr in self._globals: + return _warn_deprecated_access(self._globals, attr, self._old, self._new) + raise AttributeError(f"module 'pipecat.{self._old}' has no attribute '{attr}'") diff --git a/src/pipecat/services/anthropic/__init__.py b/src/pipecat/services/anthropic/__init__.py new file mode 100644 index 000000000..37ffe8d99 --- /dev/null +++ b/src/pipecat/services/anthropic/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .llm import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "anthropic", "anthropic.llm") diff --git a/src/pipecat/services/anthropic.py b/src/pipecat/services/anthropic/llm.py similarity index 100% rename from src/pipecat/services/anthropic.py rename to src/pipecat/services/anthropic/llm.py diff --git a/src/pipecat/services/assemblyai/__init__.py b/src/pipecat/services/assemblyai/__init__.py new file mode 100644 index 000000000..e31a2f393 --- /dev/null 
+++ b/src/pipecat/services/assemblyai/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .stt import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "assemblyai", "assemblyai.stt") diff --git a/src/pipecat/services/assemblyai.py b/src/pipecat/services/assemblyai/stt.py similarity index 96% rename from src/pipecat/services/assemblyai.py rename to src/pipecat/services/assemblyai/stt.py index 94f081c5b..87dd18bf4 100644 --- a/src/pipecat/services/assemblyai.py +++ b/src/pipecat/services/assemblyai/stt.py @@ -27,9 +27,7 @@ try: from assemblyai import AudioEncoding except ModuleNotFoundError as e: logger.error(f"Exception: {e}") - logger.error( - "In order to use AssemblyAI, you need to `pip install pipecat-ai[assemblyai]`. Also, set `ASSEMBLYAI_API_KEY` environment variable." - ) + logger.error("In order to use AssemblyAI, you need to `pip install pipecat-ai[assemblyai]`.") raise Exception(f"Missing module: {e}") diff --git a/src/pipecat/services/aws/__init__.py b/src/pipecat/services/aws/__init__.py new file mode 100644 index 000000000..b36c88499 --- /dev/null +++ b/src/pipecat/services/aws/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .tts import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "aws", "aws.tts") diff --git a/src/pipecat/services/aws.py b/src/pipecat/services/aws/tts.py similarity index 100% rename from src/pipecat/services/aws.py rename to src/pipecat/services/aws/tts.py diff --git a/src/pipecat/services/azure.py b/src/pipecat/services/azure.py deleted file mode 100644 index 9df1a8ef1..000000000 --- a/src/pipecat/services/azure.py +++ /dev/null @@ -1,813 +0,0 @@ -# -# Copyright (c) 2024–2025, Daily -# -# SPDX-License-Identifier: BSD 
2-Clause License -# - -import asyncio -import io -from typing import AsyncGenerator, Optional - -import aiohttp -from loguru import logger -from openai import AsyncAzureOpenAI -from PIL import Image -from pydantic import BaseModel - -from pipecat.frames.frames import ( - CancelFrame, - EndFrame, - ErrorFrame, - Frame, - StartFrame, - TranscriptionFrame, - TTSAudioRawFrame, - TTSStartedFrame, - TTSStoppedFrame, - URLImageRawFrame, -) -from pipecat.services.ai_services import ImageGenService, STTService, TTSService -from pipecat.services.openai import ( - OpenAILLMService, -) -from pipecat.transcriptions.language import Language -from pipecat.utils.time import time_now_iso8601 - -# See .env.example for Azure configuration needed -try: - from azure.cognitiveservices.speech import ( - CancellationReason, - ResultReason, - ServicePropertyChannel, - SpeechConfig, - SpeechRecognizer, - SpeechSynthesisOutputFormat, - SpeechSynthesizer, - ) - from azure.cognitiveservices.speech.audio import ( - AudioStreamFormat, - PushAudioInputStream, - ) - from azure.cognitiveservices.speech.dialog import AudioConfig -except ModuleNotFoundError as e: - logger.error(f"Exception: {e}") - logger.error( - "In order to use Azure, you need to `pip install pipecat-ai[azure]`. Also, set `AZURE_SPEECH_API_KEY` and `AZURE_SPEECH_REGION` environment variables." 
- ) - raise Exception(f"Missing module: {e}") - - -def language_to_azure_language(language: Language) -> Optional[str]: - language_map = { - # Afrikaans - Language.AF: "af-ZA", - Language.AF_ZA: "af-ZA", - # Amharic - Language.AM: "am-ET", - Language.AM_ET: "am-ET", - # Arabic - Language.AR: "ar-AE", # Default to UAE Arabic - Language.AR_AE: "ar-AE", - Language.AR_BH: "ar-BH", - Language.AR_DZ: "ar-DZ", - Language.AR_EG: "ar-EG", - Language.AR_IQ: "ar-IQ", - Language.AR_JO: "ar-JO", - Language.AR_KW: "ar-KW", - Language.AR_LB: "ar-LB", - Language.AR_LY: "ar-LY", - Language.AR_MA: "ar-MA", - Language.AR_OM: "ar-OM", - Language.AR_QA: "ar-QA", - Language.AR_SA: "ar-SA", - Language.AR_SY: "ar-SY", - Language.AR_TN: "ar-TN", - Language.AR_YE: "ar-YE", - # Assamese - Language.AS: "as-IN", - Language.AS_IN: "as-IN", - # Azerbaijani - Language.AZ: "az-AZ", - Language.AZ_AZ: "az-AZ", - # Bulgarian - Language.BG: "bg-BG", - Language.BG_BG: "bg-BG", - # Bengali - Language.BN: "bn-IN", # Default to Indian Bengali - Language.BN_BD: "bn-BD", - Language.BN_IN: "bn-IN", - # Bosnian - Language.BS: "bs-BA", - Language.BS_BA: "bs-BA", - # Catalan - Language.CA: "ca-ES", - Language.CA_ES: "ca-ES", - # Czech - Language.CS: "cs-CZ", - Language.CS_CZ: "cs-CZ", - # Welsh - Language.CY: "cy-GB", - Language.CY_GB: "cy-GB", - # Danish - Language.DA: "da-DK", - Language.DA_DK: "da-DK", - # German - Language.DE: "de-DE", - Language.DE_AT: "de-AT", - Language.DE_CH: "de-CH", - Language.DE_DE: "de-DE", - # Greek - Language.EL: "el-GR", - Language.EL_GR: "el-GR", - # English - Language.EN: "en-US", # Default to US English - Language.EN_AU: "en-AU", - Language.EN_CA: "en-CA", - Language.EN_GB: "en-GB", - Language.EN_HK: "en-HK", - Language.EN_IE: "en-IE", - Language.EN_IN: "en-IN", - Language.EN_KE: "en-KE", - Language.EN_NG: "en-NG", - Language.EN_NZ: "en-NZ", - Language.EN_PH: "en-PH", - Language.EN_SG: "en-SG", - Language.EN_TZ: "en-TZ", - Language.EN_US: "en-US", - Language.EN_ZA: "en-ZA", - 
# Spanish - Language.ES: "es-ES", # Default to Spain Spanish - Language.ES_AR: "es-AR", - Language.ES_BO: "es-BO", - Language.ES_CL: "es-CL", - Language.ES_CO: "es-CO", - Language.ES_CR: "es-CR", - Language.ES_CU: "es-CU", - Language.ES_DO: "es-DO", - Language.ES_EC: "es-EC", - Language.ES_ES: "es-ES", - Language.ES_GQ: "es-GQ", - Language.ES_GT: "es-GT", - Language.ES_HN: "es-HN", - Language.ES_MX: "es-MX", - Language.ES_NI: "es-NI", - Language.ES_PA: "es-PA", - Language.ES_PE: "es-PE", - Language.ES_PR: "es-PR", - Language.ES_PY: "es-PY", - Language.ES_SV: "es-SV", - Language.ES_US: "es-US", - Language.ES_UY: "es-UY", - Language.ES_VE: "es-VE", - # Estonian - Language.ET: "et-EE", - Language.ET_EE: "et-EE", - # Basque - Language.EU: "eu-ES", - Language.EU_ES: "eu-ES", - # Persian - Language.FA: "fa-IR", - Language.FA_IR: "fa-IR", - # Finnish - Language.FI: "fi-FI", - Language.FI_FI: "fi-FI", - # Filipino - Language.FIL: "fil-PH", - Language.FIL_PH: "fil-PH", - # French - Language.FR: "fr-FR", - Language.FR_BE: "fr-BE", - Language.FR_CA: "fr-CA", - Language.FR_CH: "fr-CH", - Language.FR_FR: "fr-FR", - # Irish - Language.GA: "ga-IE", - Language.GA_IE: "ga-IE", - # Galician - Language.GL: "gl-ES", - Language.GL_ES: "gl-ES", - # Gujarati - Language.GU: "gu-IN", - Language.GU_IN: "gu-IN", - # Hebrew - Language.HE: "he-IL", - Language.HE_IL: "he-IL", - # Hindi - Language.HI: "hi-IN", - Language.HI_IN: "hi-IN", - # Croatian - Language.HR: "hr-HR", - Language.HR_HR: "hr-HR", - # Hungarian - Language.HU: "hu-HU", - Language.HU_HU: "hu-HU", - # Armenian - Language.HY: "hy-AM", - Language.HY_AM: "hy-AM", - # Indonesian - Language.ID: "id-ID", - Language.ID_ID: "id-ID", - # Icelandic - Language.IS: "is-IS", - Language.IS_IS: "is-IS", - # Italian - Language.IT: "it-IT", - Language.IT_IT: "it-IT", - # Inuktitut - Language.IU_CANS_CA: "iu-Cans-CA", - Language.IU_LATN_CA: "iu-Latn-CA", - # Japanese - Language.JA: "ja-JP", - Language.JA_JP: "ja-JP", - # Javanese - Language.JV: 
"jv-ID", - Language.JV_ID: "jv-ID", - # Georgian - Language.KA: "ka-GE", - Language.KA_GE: "ka-GE", - # Kazakh - Language.KK: "kk-KZ", - Language.KK_KZ: "kk-KZ", - # Khmer - Language.KM: "km-KH", - Language.KM_KH: "km-KH", - # Kannada - Language.KN: "kn-IN", - Language.KN_IN: "kn-IN", - # Korean - Language.KO: "ko-KR", - Language.KO_KR: "ko-KR", - # Lao - Language.LO: "lo-LA", - Language.LO_LA: "lo-LA", - # Lithuanian - Language.LT: "lt-LT", - Language.LT_LT: "lt-LT", - # Latvian - Language.LV: "lv-LV", - Language.LV_LV: "lv-LV", - # Macedonian - Language.MK: "mk-MK", - Language.MK_MK: "mk-MK", - # Malayalam - Language.ML: "ml-IN", - Language.ML_IN: "ml-IN", - # Mongolian - Language.MN: "mn-MN", - Language.MN_MN: "mn-MN", - # Marathi - Language.MR: "mr-IN", - Language.MR_IN: "mr-IN", - # Malay - Language.MS: "ms-MY", - Language.MS_MY: "ms-MY", - # Maltese - Language.MT: "mt-MT", - Language.MT_MT: "mt-MT", - # Burmese - Language.MY: "my-MM", - Language.MY_MM: "my-MM", - # Norwegian - Language.NB: "nb-NO", - Language.NB_NO: "nb-NO", - Language.NO: "nb-NO", - # Nepali - Language.NE: "ne-NP", - Language.NE_NP: "ne-NP", - # Dutch - Language.NL: "nl-NL", - Language.NL_BE: "nl-BE", - Language.NL_NL: "nl-NL", - # Odia - Language.OR: "or-IN", - Language.OR_IN: "or-IN", - # Punjabi - Language.PA: "pa-IN", - Language.PA_IN: "pa-IN", - # Polish - Language.PL: "pl-PL", - Language.PL_PL: "pl-PL", - # Pashto - Language.PS: "ps-AF", - Language.PS_AF: "ps-AF", - # Portuguese - Language.PT: "pt-PT", - Language.PT_BR: "pt-BR", - Language.PT_PT: "pt-PT", - # Romanian - Language.RO: "ro-RO", - Language.RO_RO: "ro-RO", - # Russian - Language.RU: "ru-RU", - Language.RU_RU: "ru-RU", - # Sinhala - Language.SI: "si-LK", - Language.SI_LK: "si-LK", - # Slovak - Language.SK: "sk-SK", - Language.SK_SK: "sk-SK", - # Slovenian - Language.SL: "sl-SI", - Language.SL_SI: "sl-SI", - # Somali - Language.SO: "so-SO", - Language.SO_SO: "so-SO", - # Albanian - Language.SQ: "sq-AL", - Language.SQ_AL: 
"sq-AL", - # Serbian - Language.SR: "sr-RS", - Language.SR_RS: "sr-RS", - Language.SR_LATN: "sr-Latn-RS", - Language.SR_LATN_RS: "sr-Latn-RS", - # Sundanese - Language.SU: "su-ID", - Language.SU_ID: "su-ID", - # Swedish - Language.SV: "sv-SE", - Language.SV_SE: "sv-SE", - # Swahili - Language.SW: "sw-KE", - Language.SW_KE: "sw-KE", - Language.SW_TZ: "sw-TZ", - # Tamil - Language.TA: "ta-IN", - Language.TA_IN: "ta-IN", - Language.TA_LK: "ta-LK", - Language.TA_MY: "ta-MY", - Language.TA_SG: "ta-SG", - # Telugu - Language.TE: "te-IN", - Language.TE_IN: "te-IN", - # Thai - Language.TH: "th-TH", - Language.TH_TH: "th-TH", - # Turkish - Language.TR: "tr-TR", - Language.TR_TR: "tr-TR", - # Ukrainian - Language.UK: "uk-UA", - Language.UK_UA: "uk-UA", - # Urdu - Language.UR: "ur-IN", - Language.UR_IN: "ur-IN", - Language.UR_PK: "ur-PK", - # Uzbek - Language.UZ: "uz-UZ", - Language.UZ_UZ: "uz-UZ", - # Vietnamese - Language.VI: "vi-VN", - Language.VI_VN: "vi-VN", - # Wu Chinese - Language.WUU: "wuu-CN", - Language.WUU_CN: "wuu-CN", - # Yue Chinese - Language.YUE: "yue-CN", - Language.YUE_CN: "yue-CN", - # Chinese - Language.ZH: "zh-CN", - Language.ZH_CN: "zh-CN", - Language.ZH_CN_GUANGXI: "zh-CN-guangxi", - Language.ZH_CN_HENAN: "zh-CN-henan", - Language.ZH_CN_LIAONING: "zh-CN-liaoning", - Language.ZH_CN_SHAANXI: "zh-CN-shaanxi", - Language.ZH_CN_SHANDONG: "zh-CN-shandong", - Language.ZH_CN_SICHUAN: "zh-CN-sichuan", - Language.ZH_HK: "zh-HK", - Language.ZH_TW: "zh-TW", - # Zulu - Language.ZU: "zu-ZA", - Language.ZU_ZA: "zu-ZA", - } - return language_map.get(language) - - -def sample_rate_to_output_format(sample_rate: int) -> SpeechSynthesisOutputFormat: - sample_rate_map = { - 8000: SpeechSynthesisOutputFormat.Raw8Khz16BitMonoPcm, - 16000: SpeechSynthesisOutputFormat.Raw16Khz16BitMonoPcm, - 22050: SpeechSynthesisOutputFormat.Raw22050Hz16BitMonoPcm, - 24000: SpeechSynthesisOutputFormat.Raw24Khz16BitMonoPcm, - 44100: SpeechSynthesisOutputFormat.Raw44100Hz16BitMonoPcm, - 48000: 
SpeechSynthesisOutputFormat.Raw48Khz16BitMonoPcm, - } - return sample_rate_map.get(sample_rate, SpeechSynthesisOutputFormat.Raw24Khz16BitMonoPcm) - - -class AzureLLMService(OpenAILLMService): - """A service for interacting with Azure OpenAI using the OpenAI-compatible interface. - - This service extends OpenAILLMService to connect to Azure's OpenAI endpoint while - maintaining full compatibility with OpenAI's interface and functionality. - - Args: - api_key (str): The API key for accessing Azure OpenAI - endpoint (str): The Azure endpoint URL - model (str): The model identifier to use - api_version (str, optional): Azure API version. Defaults to "2024-09-01-preview" - **kwargs: Additional keyword arguments passed to OpenAILLMService - """ - - def __init__( - self, - *, - api_key: str, - endpoint: str, - model: str, - api_version: str = "2024-09-01-preview", - **kwargs, - ): - # Initialize variables before calling parent __init__() because that - # will call create_client() and we need those values there. 
- self._endpoint = endpoint - self._api_version = api_version - super().__init__(api_key=api_key, model=model, **kwargs) - - def create_client(self, api_key=None, base_url=None, **kwargs): - """Create OpenAI-compatible client for Azure OpenAI endpoint.""" - logger.debug(f"Creating Azure OpenAI client with endpoint {self._endpoint}") - return AsyncAzureOpenAI( - api_key=api_key, - azure_endpoint=self._endpoint, - api_version=self._api_version, - ) - - -class AzureBaseTTSService(TTSService): - class InputParams(BaseModel): - emphasis: Optional[str] = None - language: Optional[Language] = Language.EN_US - pitch: Optional[str] = None - rate: Optional[str] = "1.05" - role: Optional[str] = None - style: Optional[str] = None - style_degree: Optional[str] = None - volume: Optional[str] = None - - def __init__( - self, - *, - api_key: str, - region: str, - voice="en-US-SaraNeural", - sample_rate: Optional[int] = None, - params: InputParams = InputParams(), - **kwargs, - ): - super().__init__(sample_rate=sample_rate, **kwargs) - - self._settings = { - "emphasis": params.emphasis, - "language": self.language_to_service_language(params.language) - if params.language - else "en-US", - "pitch": params.pitch, - "rate": params.rate, - "role": params.role, - "style": params.style, - "style_degree": params.style_degree, - "volume": params.volume, - } - - self._api_key = api_key - self._region = region - self._voice_id = voice - self._speech_synthesizer = None - - def can_generate_metrics(self) -> bool: - return True - - def language_to_service_language(self, language: Language) -> Optional[str]: - return language_to_azure_language(language) - - def _construct_ssml(self, text: str) -> str: - language = self._settings["language"] - ssml = ( - f"" - f"" - "" - ) - - if self._settings["style"]: - ssml += f"" - - if self._settings["emphasis"]: - ssml += f"" - - ssml += text - - if self._settings["emphasis"]: - ssml += "" - - ssml += "" - - if self._settings["style"]: - ssml += "" - - 
ssml += "" - - return ssml - - -class AzureTTSService(AzureBaseTTSService): - def __init__(self, **kwargs): - super().__init__(**kwargs) - self._speech_config = None - self._speech_synthesizer = None - self._audio_queue = asyncio.Queue() - - async def start(self, frame: StartFrame): - await super().start(frame) - - if self._speech_config: - return - - # Now self.sample_rate is properly initialized - self._speech_config = SpeechConfig( - subscription=self._api_key, - region=self._region, - speech_recognition_language=self._settings["language"], - ) - self._speech_config.set_speech_synthesis_output_format( - sample_rate_to_output_format(self.sample_rate) - ) - self._speech_config.set_service_property( - "synthesizer.synthesis.connection.synthesisConnectionImpl", - "websocket", - ServicePropertyChannel.UriQueryParameter, - ) - - self._speech_synthesizer = SpeechSynthesizer( - speech_config=self._speech_config, audio_config=None - ) - - # Set up event handlers - self._speech_synthesizer.synthesizing.connect(self._handle_synthesizing) - self._speech_synthesizer.synthesis_completed.connect(self._handle_completed) - self._speech_synthesizer.synthesis_canceled.connect(self._handle_canceled) - - def _handle_synthesizing(self, evt): - """Handle audio chunks as they arrive""" - if evt.result and evt.result.audio_data: - self._audio_queue.put_nowait(evt.result.audio_data) - - def _handle_completed(self, evt): - """Handle synthesis completion""" - self._audio_queue.put_nowait(None) # Signal completion - - def _handle_canceled(self, evt): - """Handle synthesis cancellation""" - logger.error(f"Speech synthesis canceled: {evt.result.cancellation_details.reason}") - self._audio_queue.put_nowait(None) - - async def flush_audio(self): - logger.trace(f"{self}: flushing audio") - - async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: - logger.debug(f"{self}: Generating TTS [{text}]") - - try: - if self._speech_synthesizer is None: - error_msg = "Speech synthesizer not 
initialized." - logger.error(error_msg) - yield ErrorFrame(error_msg) - return - - try: - await self.start_ttfb_metrics() - yield TTSStartedFrame() - - ssml = self._construct_ssml(text) - self._speech_synthesizer.speak_ssml_async(ssml) - await self.start_tts_usage_metrics(text) - - # Stream audio chunks as they arrive - while True: - chunk = await self._audio_queue.get() - if chunk is None: # End of stream - break - - await self.stop_ttfb_metrics() - yield TTSAudioRawFrame( - audio=chunk, - sample_rate=self.sample_rate, - num_channels=1, - ) - - yield TTSStoppedFrame() - - except Exception as e: - logger.error(f"{self} error during synthesis: {e}") - yield TTSStoppedFrame() - # Could add reconnection logic here if needed - return - - except Exception as e: - logger.error(f"{self} exception: {e}") - - -class AzureHttpTTSService(AzureBaseTTSService): - def __init__(self, **kwargs): - super().__init__(**kwargs) - self._speech_config = None - self._speech_synthesizer = None - - async def start(self, frame: StartFrame): - await super().start(frame) - - if self._speech_config: - return - - self._speech_config = SpeechConfig( - subscription=self._api_key, - region=self._region, - speech_recognition_language=self._settings["language"], - ) - self._speech_config.set_speech_synthesis_output_format( - sample_rate_to_output_format(self.sample_rate) - ) - self._speech_synthesizer = SpeechSynthesizer( - speech_config=self._speech_config, audio_config=None - ) - - async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: - logger.debug(f"{self}: Generating TTS [{text}]") - - await self.start_ttfb_metrics() - - ssml = self._construct_ssml(text) - - result = await asyncio.to_thread(self._speech_synthesizer.speak_ssml, ssml) - - if result.reason == ResultReason.SynthesizingAudioCompleted: - await self.start_tts_usage_metrics(text) - await self.stop_ttfb_metrics() - yield TTSStartedFrame() - # Azure always sends a 44-byte header. Strip it off. 
- yield TTSAudioRawFrame( - audio=result.audio_data[44:], - sample_rate=self.sample_rate, - num_channels=1, - ) - yield TTSStoppedFrame() - elif result.reason == ResultReason.Canceled: - cancellation_details = result.cancellation_details - logger.warning(f"Speech synthesis canceled: {cancellation_details.reason}") - if cancellation_details.reason == CancellationReason.Error: - logger.error(f"{self} error: {cancellation_details.error_details}") - - -class AzureSTTService(STTService): - def __init__( - self, - *, - api_key: str, - region: str, - language: Language = Language.EN_US, - sample_rate: Optional[int] = None, - **kwargs, - ): - super().__init__(sample_rate=sample_rate, **kwargs) - - self._speech_config = SpeechConfig( - subscription=api_key, - region=region, - speech_recognition_language=language_to_azure_language(language), - ) - - self._audio_stream = None - self._speech_recognizer = None - - async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: - await self.start_processing_metrics() - if self._audio_stream: - self._audio_stream.write(audio) - await self.stop_processing_metrics() - yield None - - async def start(self, frame: StartFrame): - await super().start(frame) - - if self._audio_stream: - return - - stream_format = AudioStreamFormat(samples_per_second=self.sample_rate, channels=1) - self._audio_stream = PushAudioInputStream(stream_format) - - audio_config = AudioConfig(stream=self._audio_stream) - - self._speech_recognizer = SpeechRecognizer( - speech_config=self._speech_config, audio_config=audio_config - ) - self._speech_recognizer.recognized.connect(self._on_handle_recognized) - self._speech_recognizer.start_continuous_recognition_async() - - async def stop(self, frame: EndFrame): - await super().stop(frame) - - if self._speech_recognizer: - self._speech_recognizer.stop_continuous_recognition_async() - - if self._audio_stream: - self._audio_stream.close() - - async def cancel(self, frame: CancelFrame): - await 
super().cancel(frame) - - if self._speech_recognizer: - self._speech_recognizer.stop_continuous_recognition_async() - - if self._audio_stream: - self._audio_stream.close() - - def _on_handle_recognized(self, event): - if event.result.reason == ResultReason.RecognizedSpeech and len(event.result.text) > 0: - frame = TranscriptionFrame(event.result.text, "", time_now_iso8601()) - asyncio.run_coroutine_threadsafe(self.push_frame(frame), self.get_event_loop()) - - -class AzureImageGenServiceREST(ImageGenService): - def __init__( - self, - *, - image_size: str, - api_key: str, - endpoint: str, - model: str, - aiohttp_session: aiohttp.ClientSession, - api_version="2023-06-01-preview", - ): - super().__init__() - - self._api_key = api_key - self._azure_endpoint = endpoint - self._api_version = api_version - self.set_model_name(model) - self._image_size = image_size - self._aiohttp_session = aiohttp_session - - async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]: - url = f"{self._azure_endpoint}openai/images/generations:submit?api-version={self._api_version}" - - headers = {"api-key": self._api_key, "Content-Type": "application/json"} - - body = { - # Enter your prompt text here - "prompt": prompt, - "size": self._image_size, - "n": 1, - } - - async with self._aiohttp_session.post(url, headers=headers, json=body) as submission: - # We never get past this line, because this header isn't - # defined on a 429 response, but something is eating our - # exceptions! 
# Maps pipecat ``Language`` values to Azure locale codes. Built once at import
# time so lookups do not rebuild the whole table on every call.
_AZURE_LANGUAGE_MAP = {
    # Afrikaans
    Language.AF: "af-ZA",
    Language.AF_ZA: "af-ZA",
    # Amharic
    Language.AM: "am-ET",
    Language.AM_ET: "am-ET",
    # Arabic
    Language.AR: "ar-AE",  # Default to UAE Arabic
    Language.AR_AE: "ar-AE",
    Language.AR_BH: "ar-BH",
    Language.AR_DZ: "ar-DZ",
    Language.AR_EG: "ar-EG",
    Language.AR_IQ: "ar-IQ",
    Language.AR_JO: "ar-JO",
    Language.AR_KW: "ar-KW",
    Language.AR_LB: "ar-LB",
    Language.AR_LY: "ar-LY",
    Language.AR_MA: "ar-MA",
    Language.AR_OM: "ar-OM",
    Language.AR_QA: "ar-QA",
    Language.AR_SA: "ar-SA",
    Language.AR_SY: "ar-SY",
    Language.AR_TN: "ar-TN",
    Language.AR_YE: "ar-YE",
    # Assamese
    Language.AS: "as-IN",
    Language.AS_IN: "as-IN",
    # Azerbaijani
    Language.AZ: "az-AZ",
    Language.AZ_AZ: "az-AZ",
    # Bulgarian
    Language.BG: "bg-BG",
    Language.BG_BG: "bg-BG",
    # Bengali
    Language.BN: "bn-IN",  # Default to Indian Bengali
    Language.BN_BD: "bn-BD",
    Language.BN_IN: "bn-IN",
    # Bosnian
    Language.BS: "bs-BA",
    Language.BS_BA: "bs-BA",
    # Catalan
    Language.CA: "ca-ES",
    Language.CA_ES: "ca-ES",
    # Czech
    Language.CS: "cs-CZ",
    Language.CS_CZ: "cs-CZ",
    # Welsh
    Language.CY: "cy-GB",
    Language.CY_GB: "cy-GB",
    # Danish
    Language.DA: "da-DK",
    Language.DA_DK: "da-DK",
    # German
    Language.DE: "de-DE",
    Language.DE_AT: "de-AT",
    Language.DE_CH: "de-CH",
    Language.DE_DE: "de-DE",
    # Greek
    Language.EL: "el-GR",
    Language.EL_GR: "el-GR",
    # English
    Language.EN: "en-US",  # Default to US English
    Language.EN_AU: "en-AU",
    Language.EN_CA: "en-CA",
    Language.EN_GB: "en-GB",
    Language.EN_HK: "en-HK",
    Language.EN_IE: "en-IE",
    Language.EN_IN: "en-IN",
    Language.EN_KE: "en-KE",
    Language.EN_NG: "en-NG",
    Language.EN_NZ: "en-NZ",
    Language.EN_PH: "en-PH",
    Language.EN_SG: "en-SG",
    Language.EN_TZ: "en-TZ",
    Language.EN_US: "en-US",
    Language.EN_ZA: "en-ZA",
    # Spanish
    Language.ES: "es-ES",  # Default to Spain Spanish
    Language.ES_AR: "es-AR",
    Language.ES_BO: "es-BO",
    Language.ES_CL: "es-CL",
    Language.ES_CO: "es-CO",
    Language.ES_CR: "es-CR",
    Language.ES_CU: "es-CU",
    Language.ES_DO: "es-DO",
    Language.ES_EC: "es-EC",
    Language.ES_ES: "es-ES",
    Language.ES_GQ: "es-GQ",
    Language.ES_GT: "es-GT",
    Language.ES_HN: "es-HN",
    Language.ES_MX: "es-MX",
    Language.ES_NI: "es-NI",
    Language.ES_PA: "es-PA",
    Language.ES_PE: "es-PE",
    Language.ES_PR: "es-PR",
    Language.ES_PY: "es-PY",
    Language.ES_SV: "es-SV",
    Language.ES_US: "es-US",
    Language.ES_UY: "es-UY",
    Language.ES_VE: "es-VE",
    # Estonian
    Language.ET: "et-EE",
    Language.ET_EE: "et-EE",
    # Basque
    Language.EU: "eu-ES",
    Language.EU_ES: "eu-ES",
    # Persian
    Language.FA: "fa-IR",
    Language.FA_IR: "fa-IR",
    # Finnish
    Language.FI: "fi-FI",
    Language.FI_FI: "fi-FI",
    # Filipino
    Language.FIL: "fil-PH",
    Language.FIL_PH: "fil-PH",
    # French
    Language.FR: "fr-FR",
    Language.FR_BE: "fr-BE",
    Language.FR_CA: "fr-CA",
    Language.FR_CH: "fr-CH",
    Language.FR_FR: "fr-FR",
    # Irish
    Language.GA: "ga-IE",
    Language.GA_IE: "ga-IE",
    # Galician
    Language.GL: "gl-ES",
    Language.GL_ES: "gl-ES",
    # Gujarati
    Language.GU: "gu-IN",
    Language.GU_IN: "gu-IN",
    # Hebrew
    Language.HE: "he-IL",
    Language.HE_IL: "he-IL",
    # Hindi
    Language.HI: "hi-IN",
    Language.HI_IN: "hi-IN",
    # Croatian
    Language.HR: "hr-HR",
    Language.HR_HR: "hr-HR",
    # Hungarian
    Language.HU: "hu-HU",
    Language.HU_HU: "hu-HU",
    # Armenian
    Language.HY: "hy-AM",
    Language.HY_AM: "hy-AM",
    # Indonesian
    Language.ID: "id-ID",
    Language.ID_ID: "id-ID",
    # Icelandic
    Language.IS: "is-IS",
    Language.IS_IS: "is-IS",
    # Italian
    Language.IT: "it-IT",
    Language.IT_IT: "it-IT",
    # Inuktitut
    Language.IU_CANS_CA: "iu-Cans-CA",
    Language.IU_LATN_CA: "iu-Latn-CA",
    # Japanese
    Language.JA: "ja-JP",
    Language.JA_JP: "ja-JP",
    # Javanese
    Language.JV: "jv-ID",
    Language.JV_ID: "jv-ID",
    # Georgian
    Language.KA: "ka-GE",
    Language.KA_GE: "ka-GE",
    # Kazakh
    Language.KK: "kk-KZ",
    Language.KK_KZ: "kk-KZ",
    # Khmer
    Language.KM: "km-KH",
    Language.KM_KH: "km-KH",
    # Kannada
    Language.KN: "kn-IN",
    Language.KN_IN: "kn-IN",
    # Korean
    Language.KO: "ko-KR",
    Language.KO_KR: "ko-KR",
    # Lao
    Language.LO: "lo-LA",
    Language.LO_LA: "lo-LA",
    # Lithuanian
    Language.LT: "lt-LT",
    Language.LT_LT: "lt-LT",
    # Latvian
    Language.LV: "lv-LV",
    Language.LV_LV: "lv-LV",
    # Macedonian
    Language.MK: "mk-MK",
    Language.MK_MK: "mk-MK",
    # Malayalam
    Language.ML: "ml-IN",
    Language.ML_IN: "ml-IN",
    # Mongolian
    Language.MN: "mn-MN",
    Language.MN_MN: "mn-MN",
    # Marathi
    Language.MR: "mr-IN",
    Language.MR_IN: "mr-IN",
    # Malay
    Language.MS: "ms-MY",
    Language.MS_MY: "ms-MY",
    # Maltese
    Language.MT: "mt-MT",
    Language.MT_MT: "mt-MT",
    # Burmese
    Language.MY: "my-MM",
    Language.MY_MM: "my-MM",
    # Norwegian
    Language.NB: "nb-NO",
    Language.NB_NO: "nb-NO",
    Language.NO: "nb-NO",
    # Nepali
    Language.NE: "ne-NP",
    Language.NE_NP: "ne-NP",
    # Dutch
    Language.NL: "nl-NL",
    Language.NL_BE: "nl-BE",
    Language.NL_NL: "nl-NL",
    # Odia
    Language.OR: "or-IN",
    Language.OR_IN: "or-IN",
    # Punjabi
    Language.PA: "pa-IN",
    Language.PA_IN: "pa-IN",
    # Polish
    Language.PL: "pl-PL",
    Language.PL_PL: "pl-PL",
    # Pashto
    Language.PS: "ps-AF",
    Language.PS_AF: "ps-AF",
    # Portuguese
    Language.PT: "pt-PT",
    Language.PT_BR: "pt-BR",
    Language.PT_PT: "pt-PT",
    # Romanian
    Language.RO: "ro-RO",
    Language.RO_RO: "ro-RO",
    # Russian
    Language.RU: "ru-RU",
    Language.RU_RU: "ru-RU",
    # Sinhala
    Language.SI: "si-LK",
    Language.SI_LK: "si-LK",
    # Slovak
    Language.SK: "sk-SK",
    Language.SK_SK: "sk-SK",
    # Slovenian
    Language.SL: "sl-SI",
    Language.SL_SI: "sl-SI",
    # Somali
    Language.SO: "so-SO",
    Language.SO_SO: "so-SO",
    # Albanian
    Language.SQ: "sq-AL",
    Language.SQ_AL: "sq-AL",
    # Serbian
    Language.SR: "sr-RS",
    Language.SR_RS: "sr-RS",
    Language.SR_LATN: "sr-Latn-RS",
    Language.SR_LATN_RS: "sr-Latn-RS",
    # Sundanese
    Language.SU: "su-ID",
    Language.SU_ID: "su-ID",
    # Swedish
    Language.SV: "sv-SE",
    Language.SV_SE: "sv-SE",
    # Swahili
    Language.SW: "sw-KE",
    Language.SW_KE: "sw-KE",
    Language.SW_TZ: "sw-TZ",
    # Tamil
    Language.TA: "ta-IN",
    Language.TA_IN: "ta-IN",
    Language.TA_LK: "ta-LK",
    Language.TA_MY: "ta-MY",
    Language.TA_SG: "ta-SG",
    # Telugu
    Language.TE: "te-IN",
    Language.TE_IN: "te-IN",
    # Thai
    Language.TH: "th-TH",
    Language.TH_TH: "th-TH",
    # Turkish
    Language.TR: "tr-TR",
    Language.TR_TR: "tr-TR",
    # Ukrainian
    Language.UK: "uk-UA",
    Language.UK_UA: "uk-UA",
    # Urdu
    Language.UR: "ur-IN",
    Language.UR_IN: "ur-IN",
    Language.UR_PK: "ur-PK",
    # Uzbek
    Language.UZ: "uz-UZ",
    Language.UZ_UZ: "uz-UZ",
    # Vietnamese
    Language.VI: "vi-VN",
    Language.VI_VN: "vi-VN",
    # Wu Chinese
    Language.WUU: "wuu-CN",
    Language.WUU_CN: "wuu-CN",
    # Yue Chinese
    Language.YUE: "yue-CN",
    Language.YUE_CN: "yue-CN",
    # Chinese
    Language.ZH: "zh-CN",
    Language.ZH_CN: "zh-CN",
    Language.ZH_CN_GUANGXI: "zh-CN-guangxi",
    Language.ZH_CN_HENAN: "zh-CN-henan",
    Language.ZH_CN_LIAONING: "zh-CN-liaoning",
    Language.ZH_CN_SHAANXI: "zh-CN-shaanxi",
    Language.ZH_CN_SHANDONG: "zh-CN-shandong",
    Language.ZH_CN_SICHUAN: "zh-CN-sichuan",
    Language.ZH_HK: "zh-HK",
    Language.ZH_TW: "zh-TW",
    # Zulu
    Language.ZU: "zu-ZA",
    Language.ZU_ZA: "zu-ZA",
}


def language_to_azure_language(language: Language) -> Optional[str]:
    """Return the Azure locale code for a pipecat ``Language``.

    Args:
        language: The language to translate.

    Returns:
        The matching locale string (for example ``"en-US"``), or ``None``
        when the language has no entry in the table.
    """
    return _AZURE_LANGUAGE_MAP.get(language)
import ErrorFrame, Frame, URLImageRawFrame +from pipecat.services.ai_services import ImageGenService + + +class AzureImageGenServiceREST(ImageGenService): + def __init__( + self, + *, + image_size: str, + api_key: str, + endpoint: str, + model: str, + aiohttp_session: aiohttp.ClientSession, + api_version="2023-06-01-preview", + ): + super().__init__() + + self._api_key = api_key + self._azure_endpoint = endpoint + self._api_version = api_version + self.set_model_name(model) + self._image_size = image_size + self._aiohttp_session = aiohttp_session + + async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]: + url = f"{self._azure_endpoint}openai/images/generations:submit?api-version={self._api_version}" + + headers = {"api-key": self._api_key, "Content-Type": "application/json"} + + body = { + # Enter your prompt text here + "prompt": prompt, + "size": self._image_size, + "n": 1, + } + + async with self._aiohttp_session.post(url, headers=headers, json=body) as submission: + # We never get past this line, because this header isn't + # defined on a 429 response, but something is eating our + # exceptions! 
+ operation_location = submission.headers["operation-location"] + status = "" + attempts_left = 120 + json_response = None + while status != "succeeded": + attempts_left -= 1 + if attempts_left == 0: + logger.error(f"{self} error: image generation timed out") + yield ErrorFrame("Image generation timed out") + return + + await asyncio.sleep(1) + + response = await self._aiohttp_session.get(operation_location, headers=headers) + + json_response = await response.json() + status = json_response["status"] + + image_url = json_response["result"]["data"][0]["url"] if json_response else None + if not image_url: + logger.error(f"{self} error: image generation failed") + yield ErrorFrame("Image generation failed") + return + + # Load the image from the url + async with self._aiohttp_session.get(image_url) as response: + image_stream = io.BytesIO(await response.content.read()) + image = Image.open(image_stream) + frame = URLImageRawFrame( + url=image_url, image=image.tobytes(), size=image.size, format=image.format + ) + yield frame diff --git a/src/pipecat/services/azure/llm.py b/src/pipecat/services/azure/llm.py new file mode 100644 index 000000000..295a1f1c1 --- /dev/null +++ b/src/pipecat/services/azure/llm.py @@ -0,0 +1,49 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from loguru import logger +from openai import AsyncAzureOpenAI + +from pipecat.services.openai.llm import OpenAILLMService + + +class AzureLLMService(OpenAILLMService): + """A service for interacting with Azure OpenAI using the OpenAI-compatible interface. + + This service extends OpenAILLMService to connect to Azure's OpenAI endpoint while + maintaining full compatibility with OpenAI's interface and functionality. + + Args: + api_key (str): The API key for accessing Azure OpenAI + endpoint (str): The Azure endpoint URL + model (str): The model identifier to use + api_version (str, optional): Azure API version. 
Defaults to "2024-09-01-preview" + **kwargs: Additional keyword arguments passed to OpenAILLMService + """ + + def __init__( + self, + *, + api_key: str, + endpoint: str, + model: str, + api_version: str = "2024-09-01-preview", + **kwargs, + ): + # Initialize variables before calling parent __init__() because that + # will call create_client() and we need those values there. + self._endpoint = endpoint + self._api_version = api_version + super().__init__(api_key=api_key, model=model, **kwargs) + + def create_client(self, api_key=None, base_url=None, **kwargs): + """Create OpenAI-compatible client for Azure OpenAI endpoint.""" + logger.debug(f"Creating Azure OpenAI client with endpoint {self._endpoint}") + return AsyncAzureOpenAI( + api_key=api_key, + azure_endpoint=self._endpoint, + api_version=self._api_version, + ) diff --git a/src/pipecat/services/azure/stt.py b/src/pipecat/services/azure/stt.py new file mode 100644 index 000000000..3e1302029 --- /dev/null +++ b/src/pipecat/services/azure/stt.py @@ -0,0 +1,107 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +from typing import AsyncGenerator, Optional + +from loguru import logger + +from pipecat.frames.frames import ( + CancelFrame, + EndFrame, + Frame, + StartFrame, + TranscriptionFrame, +) +from pipecat.services.ai_services import STTService +from pipecat.services.azure.common import language_to_azure_language +from pipecat.transcriptions.language import Language +from pipecat.utils.time import time_now_iso8601 + +try: + from azure.cognitiveservices.speech import ( + ResultReason, + SpeechConfig, + SpeechRecognizer, + ) + from azure.cognitiveservices.speech.audio import ( + AudioStreamFormat, + PushAudioInputStream, + ) + from azure.cognitiveservices.speech.dialog import AudioConfig +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error("In order to use Azure, you need to `pip install pipecat-ai[azure]`.") + raise 
Exception(f"Missing module: {e}") + + +class AzureSTTService(STTService): + def __init__( + self, + *, + api_key: str, + region: str, + language: Language = Language.EN_US, + sample_rate: Optional[int] = None, + **kwargs, + ): + super().__init__(sample_rate=sample_rate, **kwargs) + + self._speech_config = SpeechConfig( + subscription=api_key, + region=region, + speech_recognition_language=language_to_azure_language(language), + ) + + self._audio_stream = None + self._speech_recognizer = None + + async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: + await self.start_processing_metrics() + if self._audio_stream: + self._audio_stream.write(audio) + await self.stop_processing_metrics() + yield None + + async def start(self, frame: StartFrame): + await super().start(frame) + + if self._audio_stream: + return + + stream_format = AudioStreamFormat(samples_per_second=self.sample_rate, channels=1) + self._audio_stream = PushAudioInputStream(stream_format) + + audio_config = AudioConfig(stream=self._audio_stream) + + self._speech_recognizer = SpeechRecognizer( + speech_config=self._speech_config, audio_config=audio_config + ) + self._speech_recognizer.recognized.connect(self._on_handle_recognized) + self._speech_recognizer.start_continuous_recognition_async() + + async def stop(self, frame: EndFrame): + await super().stop(frame) + + if self._speech_recognizer: + self._speech_recognizer.stop_continuous_recognition_async() + + if self._audio_stream: + self._audio_stream.close() + + async def cancel(self, frame: CancelFrame): + await super().cancel(frame) + + if self._speech_recognizer: + self._speech_recognizer.stop_continuous_recognition_async() + + if self._audio_stream: + self._audio_stream.close() + + def _on_handle_recognized(self, event): + if event.result.reason == ResultReason.RecognizedSpeech and len(event.result.text) > 0: + frame = TranscriptionFrame(event.result.text, "", time_now_iso8601()) + 
asyncio.run_coroutine_threadsafe(self.push_frame(frame), self.get_event_loop()) diff --git a/src/pipecat/services/azure/tts.py b/src/pipecat/services/azure/tts.py new file mode 100644 index 000000000..1227a4e96 --- /dev/null +++ b/src/pipecat/services/azure/tts.py @@ -0,0 +1,290 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +from typing import AsyncGenerator, Optional + +from loguru import logger +from pydantic import BaseModel + +from pipecat.frames.frames import ( + ErrorFrame, + Frame, + StartFrame, + TTSAudioRawFrame, + TTSStartedFrame, + TTSStoppedFrame, +) +from pipecat.services.ai_services import TTSService +from pipecat.services.azure.common import language_to_azure_language +from pipecat.transcriptions.language import Language + +try: + from azure.cognitiveservices.speech import ( + CancellationReason, + ResultReason, + ServicePropertyChannel, + SpeechConfig, + SpeechSynthesisOutputFormat, + SpeechSynthesizer, + ) +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error("In order to use Azure, you need to `pip install pipecat-ai[azure]`.") + raise Exception(f"Missing module: {e}") + + +def sample_rate_to_output_format(sample_rate: int) -> SpeechSynthesisOutputFormat: + sample_rate_map = { + 8000: SpeechSynthesisOutputFormat.Raw8Khz16BitMonoPcm, + 16000: SpeechSynthesisOutputFormat.Raw16Khz16BitMonoPcm, + 22050: SpeechSynthesisOutputFormat.Raw22050Hz16BitMonoPcm, + 24000: SpeechSynthesisOutputFormat.Raw24Khz16BitMonoPcm, + 44100: SpeechSynthesisOutputFormat.Raw44100Hz16BitMonoPcm, + 48000: SpeechSynthesisOutputFormat.Raw48Khz16BitMonoPcm, + } + return sample_rate_map.get(sample_rate, SpeechSynthesisOutputFormat.Raw24Khz16BitMonoPcm) + + +class AzureBaseTTSService(TTSService): + class InputParams(BaseModel): + emphasis: Optional[str] = None + language: Optional[Language] = Language.EN_US + pitch: Optional[str] = None + rate: Optional[str] = "1.05" + role: 
Optional[str] = None + style: Optional[str] = None + style_degree: Optional[str] = None + volume: Optional[str] = None + + def __init__( + self, + *, + api_key: str, + region: str, + voice="en-US-SaraNeural", + sample_rate: Optional[int] = None, + params: InputParams = InputParams(), + **kwargs, + ): + super().__init__(sample_rate=sample_rate, **kwargs) + + self._settings = { + "emphasis": params.emphasis, + "language": self.language_to_service_language(params.language) + if params.language + else "en-US", + "pitch": params.pitch, + "rate": params.rate, + "role": params.role, + "style": params.style, + "style_degree": params.style_degree, + "volume": params.volume, + } + + self._api_key = api_key + self._region = region + self._voice_id = voice + self._speech_synthesizer = None + + def can_generate_metrics(self) -> bool: + return True + + def language_to_service_language(self, language: Language) -> Optional[str]: + return language_to_azure_language(language) + + def _construct_ssml(self, text: str) -> str: + language = self._settings["language"] + ssml = ( + f"" + f"" + "" + ) + + if self._settings["style"]: + ssml += f"" + + if self._settings["emphasis"]: + ssml += f"" + + ssml += text + + if self._settings["emphasis"]: + ssml += "" + + ssml += "" + + if self._settings["style"]: + ssml += "" + + ssml += "" + + return ssml + + +class AzureTTSService(AzureBaseTTSService): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._speech_config = None + self._speech_synthesizer = None + self._audio_queue = asyncio.Queue() + + async def start(self, frame: StartFrame): + await super().start(frame) + + if self._speech_config: + return + + # Now self.sample_rate is properly initialized + self._speech_config = SpeechConfig( + subscription=self._api_key, + region=self._region, + speech_recognition_language=self._settings["language"], + ) + self._speech_config.set_speech_synthesis_output_format( + sample_rate_to_output_format(self.sample_rate) + ) + 
self._speech_config.set_service_property( + "synthesizer.synthesis.connection.synthesisConnectionImpl", + "websocket", + ServicePropertyChannel.UriQueryParameter, + ) + + self._speech_synthesizer = SpeechSynthesizer( + speech_config=self._speech_config, audio_config=None + ) + + # Set up event handlers + self._speech_synthesizer.synthesizing.connect(self._handle_synthesizing) + self._speech_synthesizer.synthesis_completed.connect(self._handle_completed) + self._speech_synthesizer.synthesis_canceled.connect(self._handle_canceled) + + def _handle_synthesizing(self, evt): + """Handle audio chunks as they arrive""" + if evt.result and evt.result.audio_data: + self._audio_queue.put_nowait(evt.result.audio_data) + + def _handle_completed(self, evt): + """Handle synthesis completion""" + self._audio_queue.put_nowait(None) # Signal completion + + def _handle_canceled(self, evt): + """Handle synthesis cancellation""" + logger.error(f"Speech synthesis canceled: {evt.result.cancellation_details.reason}") + self._audio_queue.put_nowait(None) + + async def flush_audio(self): + logger.trace(f"{self}: flushing audio") + + async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: + logger.debug(f"{self}: Generating TTS [{text}]") + + try: + if self._speech_synthesizer is None: + error_msg = "Speech synthesizer not initialized." 
+ logger.error(error_msg) + yield ErrorFrame(error_msg) + return + + try: + await self.start_ttfb_metrics() + yield TTSStartedFrame() + + ssml = self._construct_ssml(text) + self._speech_synthesizer.speak_ssml_async(ssml) + await self.start_tts_usage_metrics(text) + + # Stream audio chunks as they arrive + while True: + chunk = await self._audio_queue.get() + if chunk is None: # End of stream + break + + await self.stop_ttfb_metrics() + yield TTSAudioRawFrame( + audio=chunk, + sample_rate=self.sample_rate, + num_channels=1, + ) + + yield TTSStoppedFrame() + + except Exception as e: + logger.error(f"{self} error during synthesis: {e}") + yield TTSStoppedFrame() + # Could add reconnection logic here if needed + return + + except Exception as e: + logger.error(f"{self} exception: {e}") + + +class AzureHttpTTSService(AzureBaseTTSService): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._speech_config = None + self._speech_synthesizer = None + + async def start(self, frame: StartFrame): + await super().start(frame) + + if self._speech_config: + return + + self._speech_config = SpeechConfig( + subscription=self._api_key, + region=self._region, + speech_recognition_language=self._settings["language"], + ) + self._speech_config.set_speech_synthesis_output_format( + sample_rate_to_output_format(self.sample_rate) + ) + self._speech_synthesizer = SpeechSynthesizer( + speech_config=self._speech_config, audio_config=None + ) + + async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: + logger.debug(f"{self}: Generating TTS [{text}]") + + await self.start_ttfb_metrics() + + ssml = self._construct_ssml(text) + + result = await asyncio.to_thread(self._speech_synthesizer.speak_ssml, ssml) + + if result.reason == ResultReason.SynthesizingAudioCompleted: + await self.start_tts_usage_metrics(text) + await self.stop_ttfb_metrics() + yield TTSStartedFrame() + # Azure always sends a 44-byte header. Strip it off. 
+ yield TTSAudioRawFrame( + audio=result.audio_data[44:], + sample_rate=self.sample_rate, + num_channels=1, + ) + yield TTSStoppedFrame() + elif result.reason == ResultReason.Canceled: + cancellation_details = result.cancellation_details + logger.warning(f"Speech synthesis canceled: {cancellation_details.reason}") + if cancellation_details.reason == CancellationReason.Error: + logger.error(f"{self} error: {cancellation_details.error_details}") diff --git a/src/pipecat/services/canonical/__init__.py b/src/pipecat/services/canonical/__init__.py new file mode 100644 index 000000000..f47b99c4e --- /dev/null +++ b/src/pipecat/services/canonical/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .metrics import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "canonical", "canonical.metrics") diff --git a/src/pipecat/services/canonical.py b/src/pipecat/services/canonical/metrics.py similarity index 100% rename from src/pipecat/services/canonical.py rename to src/pipecat/services/canonical/metrics.py diff --git a/src/pipecat/services/cartesia/__init__.py b/src/pipecat/services/cartesia/__init__.py new file mode 100644 index 000000000..56c789743 --- /dev/null +++ b/src/pipecat/services/cartesia/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .tts import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "cartesia", "cartesia.tts") diff --git a/src/pipecat/services/cartesia.py b/src/pipecat/services/cartesia/tts.py similarity index 98% rename from src/pipecat/services/cartesia.py rename to src/pipecat/services/cartesia/tts.py index 11796464d..45c2fcaf8 100644 --- a/src/pipecat/services/cartesia.py +++ b/src/pipecat/services/cartesia/tts.py @@ -35,9 +35,7 @@ 
try: from cartesia import AsyncCartesia except ModuleNotFoundError as e: logger.error(f"Exception: {e}") - logger.error( - "In order to use Cartesia, you need to `pip install pipecat-ai[cartesia]`. Also, set `CARTESIA_API_KEY` environment variable." - ) + logger.error("In order to use Cartesia, you need to `pip install pipecat-ai[cartesia]`.") raise Exception(f"Missing module: {e}") diff --git a/src/pipecat/services/cerebras/__init__.py b/src/pipecat/services/cerebras/__init__.py new file mode 100644 index 000000000..726a227de --- /dev/null +++ b/src/pipecat/services/cerebras/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .llm import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "cerebras", "cerebras.llm") diff --git a/src/pipecat/services/cerebras.py b/src/pipecat/services/cerebras/llm.py similarity index 98% rename from src/pipecat/services/cerebras.py rename to src/pipecat/services/cerebras/llm.py index b5e34afe5..2217cc2f8 100644 --- a/src/pipecat/services/cerebras.py +++ b/src/pipecat/services/cerebras/llm.py @@ -11,7 +11,7 @@ from openai import AsyncStream from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext -from pipecat.services.openai import OpenAILLMService +from pipecat.services.openai.llm import OpenAILLMService class CerebrasLLMService(OpenAILLMService): diff --git a/src/pipecat/services/deepgram/__init__.py b/src/pipecat/services/deepgram/__init__.py new file mode 100644 index 000000000..bb7a3c1f8 --- /dev/null +++ b/src/pipecat/services/deepgram/__init__.py @@ -0,0 +1,14 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .stt import * +from .tts import * + 
+sys.modules[__name__] = DeprecatedModuleProxy(globals(), "deepgram", "deepgram.[stt,tts]") diff --git a/src/pipecat/services/deepgram.py b/src/pipecat/services/deepgram/stt.py similarity index 74% rename from src/pipecat/services/deepgram.py rename to src/pipecat/services/deepgram/stt.py index 689d55e9f..ae8b2318a 100644 --- a/src/pipecat/services/deepgram.py +++ b/src/pipecat/services/deepgram/stt.py @@ -4,7 +4,6 @@ # SPDX-License-Identifier: BSD 2-Clause License # -import asyncio from typing import AsyncGenerator, Dict, Optional from loguru import logger @@ -12,23 +11,18 @@ from loguru import logger from pipecat.frames.frames import ( CancelFrame, EndFrame, - ErrorFrame, Frame, InterimTranscriptionFrame, StartFrame, TranscriptionFrame, - TTSAudioRawFrame, - TTSStartedFrame, - TTSStoppedFrame, UserStartedSpeakingFrame, UserStoppedSpeakingFrame, ) from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.ai_services import STTService, TTSService +from pipecat.services.ai_services import STTService from pipecat.transcriptions.language import Language from pipecat.utils.time import time_now_iso8601 -# See .env.example for Deepgram configuration needed try: from deepgram import ( AsyncListenWebSocketClient, @@ -38,80 +32,13 @@ try: LiveOptions, LiveResultResponse, LiveTranscriptionEvents, - SpeakOptions, ) except ModuleNotFoundError as e: logger.error(f"Exception: {e}") - logger.error( - "In order to use Deepgram, you need to `pip install pipecat-ai[deepgram]`. Also, set `DEEPGRAM_API_KEY` environment variable." 
- ) + logger.error("In order to use Deepgram, you need to `pip install pipecat-ai[deepgram]`.") raise Exception(f"Missing module: {e}") -class DeepgramTTSService(TTSService): - def __init__( - self, - *, - api_key: str, - voice: str = "aura-helios-en", - sample_rate: Optional[int] = None, - encoding: str = "linear16", - **kwargs, - ): - super().__init__(sample_rate=sample_rate, **kwargs) - - self._settings = { - "encoding": encoding, - } - self.set_voice(voice) - self._deepgram_client = DeepgramClient(api_key=api_key) - - def can_generate_metrics(self) -> bool: - return True - - async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: - logger.debug(f"{self}: Generating TTS [{text}]") - - options = SpeakOptions( - model=self._voice_id, - encoding=self._settings["encoding"], - sample_rate=self.sample_rate, - container="none", - ) - - try: - await self.start_ttfb_metrics() - - response = await asyncio.to_thread( - self._deepgram_client.speak.v("1").stream, {"text": text}, options - ) - - await self.start_tts_usage_metrics(text) - yield TTSStartedFrame() - - # The response.stream_memory is already a BytesIO object - audio_buffer = response.stream_memory - - if audio_buffer is None: - raise ValueError("No audio data received from Deepgram") - - # Read and yield the audio data in chunks - audio_buffer.seek(0) # Ensure we're at the start of the buffer - chunk_size = 1024 # Use a fixed buffer size - while True: - await self.stop_ttfb_metrics() - chunk = audio_buffer.read(chunk_size) - if not chunk: - break - frame = TTSAudioRawFrame(audio=chunk, sample_rate=self.sample_rate, num_channels=1) - yield frame - yield TTSStoppedFrame() - - except Exception as e: - logger.exception(f"{self} exception: {e}") - yield ErrorFrame(f"Error getting audio: {str(e)}") - - class DeepgramSTTService(STTService): def __init__( self, diff --git a/src/pipecat/services/deepgram/tts.py b/src/pipecat/services/deepgram/tts.py new file mode 100644 index 000000000..5e05292d9 --- /dev/null 
+++ b/src/pipecat/services/deepgram/tts.py @@ -0,0 +1,90 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +from typing import AsyncGenerator, Optional + +from loguru import logger + +from pipecat.frames.frames import ( + ErrorFrame, + Frame, + TTSAudioRawFrame, + TTSStartedFrame, + TTSStoppedFrame, +) +from pipecat.services.ai_services import TTSService + +try: + from deepgram import DeepgramClient, SpeakOptions +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error("In order to use Deepgram, you need to `pip install pipecat-ai[deepgram]`.") + raise Exception(f"Missing module: {e}") + + +class DeepgramTTSService(TTSService): + def __init__( + self, + *, + api_key: str, + voice: str = "aura-helios-en", + sample_rate: Optional[int] = None, + encoding: str = "linear16", + **kwargs, + ): + super().__init__(sample_rate=sample_rate, **kwargs) + + self._settings = { + "encoding": encoding, + } + self.set_voice(voice) + self._deepgram_client = DeepgramClient(api_key=api_key) + + def can_generate_metrics(self) -> bool: + return True + + async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: + logger.debug(f"{self}: Generating TTS [{text}]") + + options = SpeakOptions( + model=self._voice_id, + encoding=self._settings["encoding"], + sample_rate=self.sample_rate, + container="none", + ) + + try: + await self.start_ttfb_metrics() + + response = await asyncio.to_thread( + self._deepgram_client.speak.v("1").stream, {"text": text}, options + ) + + await self.start_tts_usage_metrics(text) + yield TTSStartedFrame() + + # The response.stream_memory is already a BytesIO object + audio_buffer = response.stream_memory + + if audio_buffer is None: + raise ValueError("No audio data received from Deepgram") + + # Read and yield the audio data in chunks + audio_buffer.seek(0) # Ensure we're at the start of the buffer + chunk_size = 1024 # Use a fixed buffer size + while True: + await 
self.stop_ttfb_metrics() + chunk = audio_buffer.read(chunk_size) + if not chunk: + break + frame = TTSAudioRawFrame(audio=chunk, sample_rate=self.sample_rate, num_channels=1) + yield frame + yield TTSStoppedFrame() + + except Exception as e: + logger.exception(f"{self} exception: {e}") + yield ErrorFrame(f"Error getting audio: {str(e)}") diff --git a/src/pipecat/services/deepseek/__init__.py b/src/pipecat/services/deepseek/__init__.py new file mode 100644 index 000000000..1e483d95f --- /dev/null +++ b/src/pipecat/services/deepseek/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .llm import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "deepseek", "deepseek.llm") diff --git a/src/pipecat/services/deepseek.py b/src/pipecat/services/deepseek/llm.py similarity index 98% rename from src/pipecat/services/deepseek.py rename to src/pipecat/services/deepseek/llm.py index 2537f2f7c..7bed5d33b 100644 --- a/src/pipecat/services/deepseek.py +++ b/src/pipecat/services/deepseek/llm.py @@ -12,7 +12,7 @@ from openai import AsyncStream from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext -from pipecat.services.openai import OpenAILLMService +from pipecat.services.openai.llm import OpenAILLMService class DeepSeekLLMService(OpenAILLMService): diff --git a/src/pipecat/services/elevenlabs/__init__.py b/src/pipecat/services/elevenlabs/__init__.py new file mode 100644 index 000000000..e5a76e71a --- /dev/null +++ b/src/pipecat/services/elevenlabs/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .tts import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "elevenlabs", 
"elevenlabs.tts") diff --git a/src/pipecat/services/elevenlabs.py b/src/pipecat/services/elevenlabs/tts.py similarity index 99% rename from src/pipecat/services/elevenlabs.py rename to src/pipecat/services/elevenlabs/tts.py index 68c71a144..7b4e4f0dc 100644 --- a/src/pipecat/services/elevenlabs.py +++ b/src/pipecat/services/elevenlabs/tts.py @@ -33,9 +33,7 @@ try: import websockets except ModuleNotFoundError as e: logger.error(f"Exception: {e}") - logger.error( - "In order to use ElevenLabs, you need to `pip install pipecat-ai[elevenlabs]`. Also, set `ELEVENLABS_API_KEY` environment variable." - ) + logger.error("In order to use ElevenLabs, you need to `pip install pipecat-ai[elevenlabs]`.") raise Exception(f"Missing module: {e}") ElevenLabsOutputFormat = Literal["pcm_16000", "pcm_22050", "pcm_24000", "pcm_44100"] diff --git a/src/pipecat/services/fal/__init__.py b/src/pipecat/services/fal/__init__.py new file mode 100644 index 000000000..6235b4812 --- /dev/null +++ b/src/pipecat/services/fal/__init__.py @@ -0,0 +1,14 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .image import * +from .stt import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "fal", "fal.[image,stt]") diff --git a/src/pipecat/services/fal/image.py b/src/pipecat/services/fal/image.py new file mode 100644 index 000000000..6ba14caf9 --- /dev/null +++ b/src/pipecat/services/fal/image.py @@ -0,0 +1,84 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +import io +import os +from typing import AsyncGenerator, Dict, Optional, Union + +import aiohttp +from loguru import logger +from PIL import Image +from pydantic import BaseModel + +from pipecat.frames.frames import ErrorFrame, Frame, URLImageRawFrame +from pipecat.services.ai_services import ImageGenService + +try: + import fal_client +except 
ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error("In order to use Fal, you need to `pip install pipecat-ai[fal]`.") + raise Exception(f"Missing module: {e}") + + +class FalImageGenService(ImageGenService): + class InputParams(BaseModel): + seed: Optional[int] = None + num_inference_steps: int = 8 + num_images: int = 1 + image_size: Union[str, Dict[str, int]] = "square_hd" + expand_prompt: bool = False + enable_safety_checker: bool = True + format: str = "png" + + def __init__( + self, + *, + params: InputParams, + aiohttp_session: aiohttp.ClientSession, + model: str = "fal-ai/fast-sdxl", + key: Optional[str] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.set_model_name(model) + self._params = params + self._aiohttp_session = aiohttp_session + if key: + os.environ["FAL_KEY"] = key + + async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]: + def load_image_bytes(encoded_image: bytes): + buffer = io.BytesIO(encoded_image) + image = Image.open(buffer) + return (image.tobytes(), image.size, image.format) + + logger.debug(f"Generating image from prompt: {prompt}") + + response = await fal_client.run_async( + self.model_name, + arguments={"prompt": prompt, **self._params.model_dump(exclude_none=True)}, + ) + + image_url = response["images"][0]["url"] if response else None + + if not image_url: + logger.error(f"{self} error: image generation failed") + yield ErrorFrame("Image generation failed") + return + + logger.debug(f"Image generated at: {image_url}") + + # Load the image from the url + logger.debug(f"Downloading image {image_url} ...") + async with self._aiohttp_session.get(image_url) as response: + logger.debug(f"Downloaded image {image_url}") + encoded_image = await response.content.read() + (image_bytes, size, format) = await asyncio.to_thread(load_image_bytes, encoded_image) + + frame = URLImageRawFrame(url=image_url, image=image_bytes, size=size, format=format) + yield frame diff --git 
a/src/pipecat/services/fal.py b/src/pipecat/services/fal/stt.py similarity index 76% rename from src/pipecat/services/fal.py rename to src/pipecat/services/fal/stt.py index cb39da75f..4926e4718 100644 --- a/src/pipecat/services/fal.py +++ b/src/pipecat/services/fal/stt.py @@ -4,19 +4,14 @@ # SPDX-License-Identifier: BSD 2-Clause License # -import asyncio -import io import os -import wave -from typing import AsyncGenerator, Dict, Optional, Union +from typing import AsyncGenerator, Optional -import aiohttp from loguru import logger -from PIL import Image from pydantic import BaseModel -from pipecat.frames.frames import ErrorFrame, Frame, TranscriptionFrame, URLImageRawFrame -from pipecat.services.ai_services import ImageGenService, SegmentedSTTService +from pipecat.frames.frames import ErrorFrame, Frame, TranscriptionFrame +from pipecat.services.ai_services import SegmentedSTTService from pipecat.transcriptions.language import Language from pipecat.utils.time import time_now_iso8601 @@ -144,65 +139,6 @@ def language_to_fal_language(language: Language) -> Optional[str]: return result -class FalImageGenService(ImageGenService): - class InputParams(BaseModel): - seed: Optional[int] = None - num_inference_steps: int = 8 - num_images: int = 1 - image_size: Union[str, Dict[str, int]] = "square_hd" - expand_prompt: bool = False - enable_safety_checker: bool = True - format: str = "png" - - def __init__( - self, - *, - params: InputParams, - aiohttp_session: aiohttp.ClientSession, - model: str = "fal-ai/fast-sdxl", - key: Optional[str] = None, - **kwargs, - ): - super().__init__(**kwargs) - self.set_model_name(model) - self._params = params - self._aiohttp_session = aiohttp_session - if key: - os.environ["FAL_KEY"] = key - - async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]: - def load_image_bytes(encoded_image: bytes): - buffer = io.BytesIO(encoded_image) - image = Image.open(buffer) - return (image.tobytes(), image.size, image.format) - - 
logger.debug(f"Generating image from prompt: {prompt}") - - response = await fal_client.run_async( - self.model_name, - arguments={"prompt": prompt, **self._params.model_dump(exclude_none=True)}, - ) - - image_url = response["images"][0]["url"] if response else None - - if not image_url: - logger.error(f"{self} error: image generation failed") - yield ErrorFrame("Image generation failed") - return - - logger.debug(f"Image generated at: {image_url}") - - # Load the image from the url - logger.debug(f"Downloading image {image_url} ...") - async with self._aiohttp_session.get(image_url) as response: - logger.debug(f"Downloaded image {image_url}") - encoded_image = await response.content.read() - (image_bytes, size, format) = await asyncio.to_thread(load_image_bytes, encoded_image) - - frame = URLImageRawFrame(url=image_url, image=image_bytes, size=size, format=format) - yield frame - - class FalSTTService(SegmentedSTTService): """Speech-to-text service using Fal's Wizper API. diff --git a/src/pipecat/services/fireworks/__init__.py b/src/pipecat/services/fireworks/__init__.py new file mode 100644 index 000000000..c308ad967 --- /dev/null +++ b/src/pipecat/services/fireworks/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .llm import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "fireworks", "fireworks.llm") diff --git a/src/pipecat/services/fireworks.py b/src/pipecat/services/fireworks/llm.py similarity index 97% rename from src/pipecat/services/fireworks.py rename to src/pipecat/services/fireworks/llm.py index 8e40e1a7b..d4003f86f 100644 --- a/src/pipecat/services/fireworks.py +++ b/src/pipecat/services/fireworks/llm.py @@ -11,7 +11,7 @@ from loguru import logger from openai.types.chat import ChatCompletionMessageParam from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext -from 
pipecat.services.openai import OpenAILLMService +from pipecat.services.openai.llm import OpenAILLMService class FireworksLLMService(OpenAILLMService): diff --git a/src/pipecat/services/fish/__init__.py b/src/pipecat/services/fish/__init__.py new file mode 100644 index 000000000..a783d6224 --- /dev/null +++ b/src/pipecat/services/fish/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .tts import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "fish", "fish.tts") diff --git a/src/pipecat/services/fish.py b/src/pipecat/services/fish/tts.py similarity index 98% rename from src/pipecat/services/fish.py rename to src/pipecat/services/fish/tts.py index 9e6a8b91e..d4fe59635 100644 --- a/src/pipecat/services/fish.py +++ b/src/pipecat/services/fish/tts.py @@ -30,9 +30,7 @@ try: import websockets except ModuleNotFoundError as e: logger.error(f"Exception: {e}") - logger.error( - "In order to use Fish Audio, you need to `pip install pipecat-ai[fish]`. Also, set `FISH_API_KEY` environment variable." 
- ) + logger.error("In order to use Fish Audio, you need to `pip install pipecat-ai[fish]`.") raise Exception(f"Missing module: {e}") # FishAudio supports various output formats diff --git a/src/pipecat/services/gemini_multimodal_live/gemini.py b/src/pipecat/services/gemini_multimodal_live/gemini.py index 965648ada..3ecab0186 100644 --- a/src/pipecat/services/gemini_multimodal_live/gemini.py +++ b/src/pipecat/services/gemini_multimodal_live/gemini.py @@ -51,7 +51,7 @@ from pipecat.processors.aggregators.openai_llm_context import ( ) from pipecat.processors.frame_processor import FrameDirection from pipecat.services.ai_services import LLMService -from pipecat.services.openai import ( +from pipecat.services.openai.llm import ( OpenAIAssistantContextAggregator, OpenAIUserContextAggregator, ) diff --git a/src/pipecat/services/google/__init__.py b/src/pipecat/services/google/__init__.py index 3e63a3ba9..ec187000f 100644 --- a/src/pipecat/services/google/__init__.py +++ b/src/pipecat/services/google/__init__.py @@ -1,3 +1,22 @@ -from .frames import LLMSearchResponseFrame -from .google import * +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .frames import * +from .image import * +from .llm import * +from .llm_openai import * +from .llm_vertex import * from .rtvi import * +from .stt import * +from .tts import * + +sys.modules[__name__] = DeprecatedModuleProxy( + globals(), "google", "google.[frames,image,llm,llm_openai,llm_vertex,rtvi,stt,tts]" +) diff --git a/src/pipecat/services/google/google.py b/src/pipecat/services/google/google.py index 95c5a1edb..ec187000f 100644 --- a/src/pipecat/services/google/google.py +++ b/src/pipecat/services/google/google.py @@ -4,2089 +4,19 @@ # SPDX-License-Identifier: BSD 2-Clause License # -import asyncio -import base64 -import io -import json -import os -import time -import uuid +import sys -from 
google.api_core.exceptions import DeadlineExceeded -from openai import AsyncStream -from openai.types.chat import ChatCompletionChunk +from pipecat.services import DeprecatedModuleProxy -from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter +from .frames import * +from .image import * +from .llm import * +from .llm_openai import * +from .llm_vertex import * +from .rtvi import * +from .stt import * +from .tts import * -# Suppress gRPC fork warnings -os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false" - -from dataclasses import dataclass -from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional, Union - -from loguru import logger -from PIL import Image -from pydantic import BaseModel, Field, field_validator - -from pipecat.frames.frames import ( - AudioRawFrame, - CancelFrame, - EndFrame, - ErrorFrame, - Frame, - FunctionCallCancelFrame, - FunctionCallInProgressFrame, - FunctionCallResultFrame, - InterimTranscriptionFrame, - LLMFullResponseEndFrame, - LLMFullResponseStartFrame, - LLMMessagesFrame, - LLMTextFrame, - LLMUpdateSettingsFrame, - StartFrame, - TranscriptionFrame, - TTSAudioRawFrame, - TTSStartedFrame, - TTSStoppedFrame, - URLImageRawFrame, - UserImageRawFrame, - VisionImageRawFrame, +sys.modules[__name__] = DeprecatedModuleProxy( + globals(), "google", "google.[frames,image,llm,llm_openai,llm_vertex,rtvi,stt,tts]" ) -from pipecat.metrics.metrics import LLMTokenUsage -from pipecat.processors.aggregators.openai_llm_context import ( - OpenAILLMContext, - OpenAILLMContextFrame, -) -from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.ai_services import ImageGenService, LLMService, STTService, TTSService -from pipecat.services.google.frames import LLMSearchResponseFrame -from pipecat.services.openai import ( - OpenAIAssistantContextAggregator, - OpenAILLMService, - OpenAIUnhandledFunctionException, - OpenAIUserContextAggregator, -) -from pipecat.transcriptions.language import Language -from 
pipecat.utils.time import time_now_iso8601 - -try: - import google.ai.generativelanguage as glm - import google.generativeai as gai - from google import genai - from google.api_core.client_options import ClientOptions - from google.auth.transport.requests import Request - from google.cloud import speech_v2, texttospeech_v1 - from google.cloud.speech_v2.types import cloud_speech - from google.genai import types - from google.generativeai.types import GenerationConfig - from google.oauth2 import service_account - -except ModuleNotFoundError as e: - logger.error(f"Exception: {e}") - logger.error( - "In order to use Google AI, you need to `pip install pipecat-ai[google]`. Also, set the environment variable GOOGLE_API_KEY for the GoogleLLMService and GOOGLE_APPLICATION_CREDENTIALS for the GoogleTTSService`." - ) - raise Exception(f"Missing module: {e}") - - -def language_to_google_tts_language(language: Language) -> Optional[str]: - language_map = { - # Afrikaans - Language.AF: "af-ZA", - Language.AF_ZA: "af-ZA", - # Arabic - Language.AR: "ar-XA", - # Bengali - Language.BN: "bn-IN", - Language.BN_IN: "bn-IN", - # Bulgarian - Language.BG: "bg-BG", - Language.BG_BG: "bg-BG", - # Catalan - Language.CA: "ca-ES", - Language.CA_ES: "ca-ES", - # Chinese (Mandarin and Cantonese) - Language.ZH: "cmn-CN", - Language.ZH_CN: "cmn-CN", - Language.ZH_TW: "cmn-TW", - Language.ZH_HK: "yue-HK", - # Czech - Language.CS: "cs-CZ", - Language.CS_CZ: "cs-CZ", - # Danish - Language.DA: "da-DK", - Language.DA_DK: "da-DK", - # Dutch - Language.NL: "nl-NL", - Language.NL_BE: "nl-BE", - Language.NL_NL: "nl-NL", - # English - Language.EN: "en-US", - Language.EN_US: "en-US", - Language.EN_AU: "en-AU", - Language.EN_GB: "en-GB", - Language.EN_IN: "en-IN", - # Estonian - Language.ET: "et-EE", - Language.ET_EE: "et-EE", - # Filipino - Language.FIL: "fil-PH", - Language.FIL_PH: "fil-PH", - # Finnish - Language.FI: "fi-FI", - Language.FI_FI: "fi-FI", - # French - Language.FR: "fr-FR", - Language.FR_CA: 
"fr-CA", - Language.FR_FR: "fr-FR", - # Galician - Language.GL: "gl-ES", - Language.GL_ES: "gl-ES", - # German - Language.DE: "de-DE", - Language.DE_DE: "de-DE", - # Greek - Language.EL: "el-GR", - Language.EL_GR: "el-GR", - # Gujarati - Language.GU: "gu-IN", - Language.GU_IN: "gu-IN", - # Hebrew - Language.HE: "he-IL", - Language.HE_IL: "he-IL", - # Hindi - Language.HI: "hi-IN", - Language.HI_IN: "hi-IN", - # Hungarian - Language.HU: "hu-HU", - Language.HU_HU: "hu-HU", - # Icelandic - Language.IS: "is-IS", - Language.IS_IS: "is-IS", - # Indonesian - Language.ID: "id-ID", - Language.ID_ID: "id-ID", - # Italian - Language.IT: "it-IT", - Language.IT_IT: "it-IT", - # Japanese - Language.JA: "ja-JP", - Language.JA_JP: "ja-JP", - # Kannada - Language.KN: "kn-IN", - Language.KN_IN: "kn-IN", - # Korean - Language.KO: "ko-KR", - Language.KO_KR: "ko-KR", - # Latvian - Language.LV: "lv-LV", - Language.LV_LV: "lv-LV", - # Lithuanian - Language.LT: "lt-LT", - Language.LT_LT: "lt-LT", - # Malay - Language.MS: "ms-MY", - Language.MS_MY: "ms-MY", - # Malayalam - Language.ML: "ml-IN", - Language.ML_IN: "ml-IN", - # Marathi - Language.MR: "mr-IN", - Language.MR_IN: "mr-IN", - # Norwegian - Language.NO: "nb-NO", - Language.NB: "nb-NO", - Language.NB_NO: "nb-NO", - # Polish - Language.PL: "pl-PL", - Language.PL_PL: "pl-PL", - # Portuguese - Language.PT: "pt-PT", - Language.PT_BR: "pt-BR", - Language.PT_PT: "pt-PT", - # Punjabi - Language.PA: "pa-IN", - Language.PA_IN: "pa-IN", - # Romanian - Language.RO: "ro-RO", - Language.RO_RO: "ro-RO", - # Russian - Language.RU: "ru-RU", - Language.RU_RU: "ru-RU", - # Serbian - Language.SR: "sr-RS", - Language.SR_RS: "sr-RS", - # Slovak - Language.SK: "sk-SK", - Language.SK_SK: "sk-SK", - # Spanish - Language.ES: "es-ES", - Language.ES_ES: "es-ES", - Language.ES_US: "es-US", - # Swedish - Language.SV: "sv-SE", - Language.SV_SE: "sv-SE", - # Tamil - Language.TA: "ta-IN", - Language.TA_IN: "ta-IN", - # Telugu - Language.TE: "te-IN", - 
Language.TE_IN: "te-IN", - # Thai - Language.TH: "th-TH", - Language.TH_TH: "th-TH", - # Turkish - Language.TR: "tr-TR", - Language.TR_TR: "tr-TR", - # Ukrainian - Language.UK: "uk-UA", - Language.UK_UA: "uk-UA", - # Vietnamese - Language.VI: "vi-VN", - Language.VI_VN: "vi-VN", - } - - return language_map.get(language) - - -def language_to_google_stt_language(language: Language) -> Optional[str]: - """Maps Language enum to Google Speech-to-Text V2 language codes. - - Args: - language: Language enum value. - - Returns: - Optional[str]: Google STT language code or None if not supported. - """ - language_map = { - # Afrikaans - Language.AF: "af-ZA", - Language.AF_ZA: "af-ZA", - # Albanian - Language.SQ: "sq-AL", - Language.SQ_AL: "sq-AL", - # Amharic - Language.AM: "am-ET", - Language.AM_ET: "am-ET", - # Arabic - Language.AR: "ar-EG", # Default to Egypt - Language.AR_AE: "ar-AE", - Language.AR_BH: "ar-BH", - Language.AR_DZ: "ar-DZ", - Language.AR_EG: "ar-EG", - Language.AR_IQ: "ar-IQ", - Language.AR_JO: "ar-JO", - Language.AR_KW: "ar-KW", - Language.AR_LB: "ar-LB", - Language.AR_MA: "ar-MA", - Language.AR_OM: "ar-OM", - Language.AR_QA: "ar-QA", - Language.AR_SA: "ar-SA", - Language.AR_SY: "ar-SY", - Language.AR_TN: "ar-TN", - Language.AR_YE: "ar-YE", - # Armenian - Language.HY: "hy-AM", - Language.HY_AM: "hy-AM", - # Azerbaijani - Language.AZ: "az-AZ", - Language.AZ_AZ: "az-AZ", - # Basque - Language.EU: "eu-ES", - Language.EU_ES: "eu-ES", - # Bengali - Language.BN: "bn-IN", # Default to India - Language.BN_BD: "bn-BD", - Language.BN_IN: "bn-IN", - # Bosnian - Language.BS: "bs-BA", - Language.BS_BA: "bs-BA", - # Bulgarian - Language.BG: "bg-BG", - Language.BG_BG: "bg-BG", - # Burmese - Language.MY: "my-MM", - Language.MY_MM: "my-MM", - # Catalan - Language.CA: "ca-ES", - Language.CA_ES: "ca-ES", - # Chinese - Language.ZH: "cmn-Hans-CN", # Default to Simplified Chinese - Language.ZH_CN: "cmn-Hans-CN", - Language.ZH_HK: "cmn-Hans-HK", - Language.ZH_TW: "cmn-Hant-TW", - 
Language.YUE: "yue-Hant-HK", # Cantonese - Language.YUE_CN: "yue-Hant-HK", - # Croatian - Language.HR: "hr-HR", - Language.HR_HR: "hr-HR", - # Czech - Language.CS: "cs-CZ", - Language.CS_CZ: "cs-CZ", - # Danish - Language.DA: "da-DK", - Language.DA_DK: "da-DK", - # Dutch - Language.NL: "nl-NL", # Default to Netherlands - Language.NL_BE: "nl-BE", - Language.NL_NL: "nl-NL", - # English - Language.EN: "en-US", # Default to US - Language.EN_AU: "en-AU", - Language.EN_CA: "en-CA", - Language.EN_GB: "en-GB", - Language.EN_GH: "en-GH", - Language.EN_HK: "en-HK", - Language.EN_IN: "en-IN", - Language.EN_IE: "en-IE", - Language.EN_KE: "en-KE", - Language.EN_NG: "en-NG", - Language.EN_NZ: "en-NZ", - Language.EN_PH: "en-PH", - Language.EN_SG: "en-SG", - Language.EN_TZ: "en-TZ", - Language.EN_US: "en-US", - Language.EN_ZA: "en-ZA", - # Estonian - Language.ET: "et-EE", - Language.ET_EE: "et-EE", - # Filipino - Language.FIL: "fil-PH", - Language.FIL_PH: "fil-PH", - # Finnish - Language.FI: "fi-FI", - Language.FI_FI: "fi-FI", - # French - Language.FR: "fr-FR", # Default to France - Language.FR_BE: "fr-BE", - Language.FR_CA: "fr-CA", - Language.FR_CH: "fr-CH", - Language.FR_FR: "fr-FR", - # Galician - Language.GL: "gl-ES", - Language.GL_ES: "gl-ES", - # Georgian - Language.KA: "ka-GE", - Language.KA_GE: "ka-GE", - # German - Language.DE: "de-DE", # Default to Germany - Language.DE_AT: "de-AT", - Language.DE_CH: "de-CH", - Language.DE_DE: "de-DE", - # Greek - Language.EL: "el-GR", - Language.EL_GR: "el-GR", - # Gujarati - Language.GU: "gu-IN", - Language.GU_IN: "gu-IN", - # Hebrew - Language.HE: "iw-IL", - Language.HE_IL: "iw-IL", - # Hindi - Language.HI: "hi-IN", - Language.HI_IN: "hi-IN", - # Hungarian - Language.HU: "hu-HU", - Language.HU_HU: "hu-HU", - # Icelandic - Language.IS: "is-IS", - Language.IS_IS: "is-IS", - # Indonesian - Language.ID: "id-ID", - Language.ID_ID: "id-ID", - # Italian - Language.IT: "it-IT", - Language.IT_IT: "it-IT", - Language.IT_CH: "it-CH", - # 
Japanese - Language.JA: "ja-JP", - Language.JA_JP: "ja-JP", - # Javanese - Language.JV: "jv-ID", - Language.JV_ID: "jv-ID", - # Kannada - Language.KN: "kn-IN", - Language.KN_IN: "kn-IN", - # Kazakh - Language.KK: "kk-KZ", - Language.KK_KZ: "kk-KZ", - # Khmer - Language.KM: "km-KH", - Language.KM_KH: "km-KH", - # Korean - Language.KO: "ko-KR", - Language.KO_KR: "ko-KR", - # Lao - Language.LO: "lo-LA", - Language.LO_LA: "lo-LA", - # Latvian - Language.LV: "lv-LV", - Language.LV_LV: "lv-LV", - # Lithuanian - Language.LT: "lt-LT", - Language.LT_LT: "lt-LT", - # Macedonian - Language.MK: "mk-MK", - Language.MK_MK: "mk-MK", - # Malay - Language.MS: "ms-MY", - Language.MS_MY: "ms-MY", - # Malayalam - Language.ML: "ml-IN", - Language.ML_IN: "ml-IN", - # Marathi - Language.MR: "mr-IN", - Language.MR_IN: "mr-IN", - # Mongolian - Language.MN: "mn-MN", - Language.MN_MN: "mn-MN", - # Nepali - Language.NE: "ne-NP", - Language.NE_NP: "ne-NP", - # Norwegian - Language.NO: "no-NO", - Language.NB: "no-NO", - Language.NB_NO: "no-NO", - # Persian - Language.FA: "fa-IR", - Language.FA_IR: "fa-IR", - # Polish - Language.PL: "pl-PL", - Language.PL_PL: "pl-PL", - # Portuguese - Language.PT: "pt-PT", # Default to Portugal - Language.PT_BR: "pt-BR", - Language.PT_PT: "pt-PT", - # Punjabi - Language.PA: "pa-Guru-IN", - Language.PA_IN: "pa-Guru-IN", - # Romanian - Language.RO: "ro-RO", - Language.RO_RO: "ro-RO", - # Russian - Language.RU: "ru-RU", - Language.RU_RU: "ru-RU", - # Serbian - Language.SR: "sr-RS", - Language.SR_RS: "sr-RS", - # Sinhala - Language.SI: "si-LK", - Language.SI_LK: "si-LK", - # Slovak - Language.SK: "sk-SK", - Language.SK_SK: "sk-SK", - # Slovenian - Language.SL: "sl-SI", - Language.SL_SI: "sl-SI", - # Spanish - Language.ES: "es-ES", # Default to Spain - Language.ES_AR: "es-AR", - Language.ES_BO: "es-BO", - Language.ES_CL: "es-CL", - Language.ES_CO: "es-CO", - Language.ES_CR: "es-CR", - Language.ES_DO: "es-DO", - Language.ES_EC: "es-EC", - Language.ES_ES: "es-ES", - 
Language.ES_GT: "es-GT", - Language.ES_HN: "es-HN", - Language.ES_MX: "es-MX", - Language.ES_NI: "es-NI", - Language.ES_PA: "es-PA", - Language.ES_PE: "es-PE", - Language.ES_PR: "es-PR", - Language.ES_PY: "es-PY", - Language.ES_SV: "es-SV", - Language.ES_US: "es-US", - Language.ES_UY: "es-UY", - Language.ES_VE: "es-VE", - # Sundanese - Language.SU: "su-ID", - Language.SU_ID: "su-ID", - # Swahili - Language.SW: "sw-TZ", # Default to Tanzania - Language.SW_KE: "sw-KE", - Language.SW_TZ: "sw-TZ", - # Swedish - Language.SV: "sv-SE", - Language.SV_SE: "sv-SE", - # Tamil - Language.TA: "ta-IN", # Default to India - Language.TA_IN: "ta-IN", - Language.TA_MY: "ta-MY", - Language.TA_SG: "ta-SG", - Language.TA_LK: "ta-LK", - # Telugu - Language.TE: "te-IN", - Language.TE_IN: "te-IN", - # Thai - Language.TH: "th-TH", - Language.TH_TH: "th-TH", - # Turkish - Language.TR: "tr-TR", - Language.TR_TR: "tr-TR", - # Ukrainian - Language.UK: "uk-UA", - Language.UK_UA: "uk-UA", - # Urdu - Language.UR: "ur-IN", # Default to India - Language.UR_IN: "ur-IN", - Language.UR_PK: "ur-PK", - # Uzbek - Language.UZ: "uz-UZ", - Language.UZ_UZ: "uz-UZ", - # Vietnamese - Language.VI: "vi-VN", - Language.VI_VN: "vi-VN", - # Xhosa - Language.XH: "xh-ZA", - # Zulu - Language.ZU: "zu-ZA", - Language.ZU_ZA: "zu-ZA", - } - - return language_map.get(language) - - -class GoogleUserContextAggregator(OpenAIUserContextAggregator): - async def push_aggregation(self): - if len(self._aggregation) > 0: - self._context.add_message( - glm.Content(role="user", parts=[glm.Part(text=self._aggregation)]) - ) - - # Reset the aggregation. Reset it before pushing it down, otherwise - # if the tasks gets cancelled we won't be able to clear things up. - self._aggregation = "" - - # Push context frame - frame = OpenAILLMContextFrame(self._context) - await self.push_frame(frame) - - # Reset our accumulator state. 
- self.reset() - - -class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator): - async def handle_aggregation(self, aggregation: str): - self._context.add_message(glm.Content(role="model", parts=[glm.Part(text=aggregation)])) - - async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame): - self._context.add_message( - glm.Content( - role="model", - parts=[ - glm.Part( - function_call=glm.FunctionCall( - id=frame.tool_call_id, name=frame.function_name, args=frame.arguments - ) - ) - ], - ) - ) - self._context.add_message( - glm.Content( - role="user", - parts=[ - glm.Part( - function_response=glm.FunctionResponse( - id=frame.tool_call_id, - name=frame.function_name, - response={"response": "IN_PROGRESS"}, - ) - ) - ], - ) - ) - - async def handle_function_call_result(self, frame: FunctionCallResultFrame): - if frame.result: - await self._update_function_call_result( - frame.function_name, frame.tool_call_id, frame.result - ) - else: - await self._update_function_call_result( - frame.function_name, frame.tool_call_id, "COMPLETED" - ) - - async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame): - await self._update_function_call_result( - frame.function_name, frame.tool_call_id, "CANCELLED" - ) - - async def _update_function_call_result( - self, function_name: str, tool_call_id: str, result: Any - ): - for message in self._context.messages: - if message.role == "user": - for part in message.parts: - if part.function_response and part.function_response.id == tool_call_id: - part.function_response.response = {"value": json.dumps(result)} - - async def handle_user_image_frame(self, frame: UserImageRawFrame): - await self._update_function_call_result( - frame.request.function_name, frame.request.tool_call_id, "COMPLETED" - ) - self._context.add_image_frame_message( - format=frame.format, - size=frame.size, - image=frame.image, - text=frame.request.context, - ) - - -@dataclass -class GoogleContextAggregatorPair: 
- _user: "GoogleUserContextAggregator" - _assistant: "GoogleAssistantContextAggregator" - - def user(self) -> "GoogleUserContextAggregator": - return self._user - - def assistant(self) -> "GoogleAssistantContextAggregator": - return self._assistant - - -class GoogleLLMContext(OpenAILLMContext): - def __init__( - self, - messages: Optional[List[dict]] = None, - tools: Optional[List[dict]] = None, - tool_choice: Optional[dict] = None, - ): - super().__init__(messages=messages, tools=tools, tool_choice=tool_choice) - self.system_message = None - - @staticmethod - def upgrade_to_google(obj: OpenAILLMContext) -> "GoogleLLMContext": - if isinstance(obj, OpenAILLMContext) and not isinstance(obj, GoogleLLMContext): - logger.debug(f"Upgrading to Google: {obj}") - obj.__class__ = GoogleLLMContext - obj._restructure_from_openai_messages() - return obj - - def set_messages(self, messages: List): - self._messages[:] = messages - self._restructure_from_openai_messages() - - def add_messages(self, messages: List): - # Convert each message individually - converted_messages = [] - for msg in messages: - if isinstance(msg, glm.Content): - # Already in Gemini format - converted_messages.append(msg) - else: - # Convert from standard format to Gemini format - converted = self.from_standard_message(msg) - if converted is not None: - converted_messages.append(converted) - - # Add the converted messages to our existing messages - self._messages.extend(converted_messages) - - def get_messages_for_logging(self): - msgs = [] - for message in self.messages: - obj = glm.Content.to_dict(message) - try: - if "parts" in obj: - for part in obj["parts"]: - if "inline_data" in part: - part["inline_data"]["data"] = "..." 
- except Exception as e: - logger.debug(f"Error: {e}") - msgs.append(obj) - return msgs - - def add_image_frame_message( - self, *, format: str, size: tuple[int, int], image: bytes, text: str = None - ): - buffer = io.BytesIO() - Image.frombytes(format, size, image).save(buffer, format="JPEG") - - parts = [] - if text: - parts.append(glm.Part(text=text)) - parts.append(glm.Part(inline_data=glm.Blob(mime_type="image/jpeg", data=buffer.getvalue()))) - - self.add_message(glm.Content(role="user", parts=parts)) - - def add_audio_frames_message( - self, *, audio_frames: list[AudioRawFrame], text: str = "Audio follows" - ): - if not audio_frames: - return - - sample_rate = audio_frames[0].sample_rate - num_channels = audio_frames[0].num_channels - - parts = [] - data = b"".join(frame.audio for frame in audio_frames) - # NOTE(aleix): According to the docs only text or inline_data should be needed. - # (see https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference) - parts.append(glm.Part(text=text)) - parts.append( - glm.Part( - inline_data=glm.Blob( - mime_type="audio/wav", - data=( - bytes( - self.create_wav_header(sample_rate, num_channels, 16, len(data)) + data - ) - ), - ) - ), - ) - self.add_message(glm.Content(role="user", parts=parts)) - # message = {"mime_type": "audio/mp3", "data": bytes(data + create_wav_header(sample_rate, num_channels, 16, len(data)))} - # self.add_message(message) - - def from_standard_message(self, message): - """Convert standard format message to Google Content object. - - Handles conversion of text, images, and function calls to Google's format. - System messages are stored separately and return None. 
- - Args: - message: Message in standard format: - { - "role": "user/assistant/system/tool", - "content": str | [{"type": "text/image_url", ...}] | None, - "tool_calls": [{"function": {"name": str, "arguments": str}}] - } - - Returns: - glm.Content object with: - - role: "user" or "model" (converted from "assistant") - - parts: List[Part] containing text, inline_data, or function calls - Returns None for system messages. - """ - role = message["role"] - content = message.get("content", []) - if role == "system": - self.system_message = content - return None - elif role == "assistant": - role = "model" - - parts = [] - if message.get("tool_calls"): - for tc in message["tool_calls"]: - parts.append( - glm.Part( - function_call=glm.FunctionCall( - name=tc["function"]["name"], - args=json.loads(tc["function"]["arguments"]), - ) - ) - ) - elif role == "tool": - role = "model" - parts.append( - glm.Part( - function_response=glm.FunctionResponse( - name="tool_call_result", # seems to work to hard-code the same name every time - response=json.loads(message["content"]), - ) - ) - ) - elif isinstance(content, str): - parts.append(glm.Part(text=content)) - elif isinstance(content, list): - for c in content: - if c["type"] == "text": - parts.append(glm.Part(text=c["text"])) - elif c["type"] == "image_url": - parts.append( - glm.Part( - inline_data=glm.Blob( - mime_type="image/jpeg", - data=base64.b64decode(c["image_url"]["url"].split(",")[1]), - ) - ) - ) - - message = glm.Content(role=role, parts=parts) - return message - - def to_standard_messages(self, obj) -> list: - """Convert Google Content object to standard structured format. - - Handles text, images, and function calls from Google's Content/Part objects. 
- - Args: - obj: Google Content object with: - - role: "model" (converted to "assistant") or "user" - - parts: List[Part] containing text, inline_data, or function calls - - Returns: - List of messages in standard format: - [ - { - "role": "user/assistant/tool", - "content": [ - {"type": "text", "text": str} | - {"type": "image_url", "image_url": {"url": str}} - ] - } - ] - """ - msg = {"role": obj.role, "content": []} - if msg["role"] == "model": - msg["role"] = "assistant" - - for part in obj.parts: - if part.text: - msg["content"].append({"type": "text", "text": part.text}) - elif part.inline_data: - encoded = base64.b64encode(part.inline_data.data).decode("utf-8") - msg["content"].append( - { - "type": "image_url", - "image_url": {"url": f"data:{part.inline_data.mime_type};base64,{encoded}"}, - } - ) - elif part.function_call: - args = type(part.function_call).to_dict(part.function_call).get("args", {}) - msg["tool_calls"] = [ - { - "id": part.function_call.name, - "type": "function", - "function": { - "name": part.function_call.name, - "arguments": json.dumps(args), - }, - } - ] - - elif part.function_response: - msg["role"] = "tool" - resp = ( - type(part.function_response).to_dict(part.function_response).get("response", {}) - ) - msg["tool_call_id"] = part.function_response.name - msg["content"] = json.dumps(resp) - - # there might be no content parts for tool_calls messages - if not msg["content"]: - del msg["content"] - return [msg] - - def _restructure_from_openai_messages(self): - """Restructures messages to ensure proper Google format and message ordering. - - This method handles conversion of OpenAI-formatted messages to Google format, - with special handling for function calls, function responses, and system messages. - System messages are added back to the context as user messages when needed. - - The final message order is preserved as: - 1. Function calls (from model) - 2. Function responses (from user) - 3. 
Text messages (converted from system messages) - - Note: - System messages are only added back when there are no regular text - messages in the context, ensuring proper conversation continuity - after function calls. - """ - self.system_message = None - converted_messages = [] - - # Process each message, preserving Google-formatted messages and converting others - for message in self._messages: - if isinstance(message, glm.Content): - # Keep existing Google-formatted messages (e.g., function calls/responses) - converted_messages.append(message) - continue - - # Convert OpenAI format to Google format, system messages return None - converted = self.from_standard_message(message) - if converted is not None: - converted_messages.append(converted) - - # Update message list - self._messages[:] = converted_messages - - # Check if we only have function-related messages (no regular text) - has_regular_messages = any( - len(msg.parts) == 1 - and not getattr(msg.parts[0], "text", None) - and getattr(msg.parts[0], "function_call", None) - and getattr(msg.parts[0], "function_response", None) - for msg in self._messages - ) - - # Add system message back as a user message if we only have function messages - if self.system_message and not has_regular_messages: - self._messages.append( - glm.Content(role="user", parts=[glm.Part(text=self.system_message)]) - ) - - # Remove any empty messages - self._messages = [m for m in self._messages if m.parts] - - -class GoogleLLMService(LLMService): - """This class implements inference with Google's AI models. - - This service translates internally from OpenAILLMContext to the messages format - expected by the Google AI model. We are using the OpenAILLMContext as a lingua - franca for all LLM services, so that it is easy to switch between different LLMs. - """ - - # Overriding the default adapter to use the Gemini one. 
- adapter_class = GeminiLLMAdapter - - class InputParams(BaseModel): - max_tokens: Optional[int] = Field(default=4096, ge=1) - temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0) - top_k: Optional[int] = Field(default=None, ge=0) - top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0) - extra: Optional[Dict[str, Any]] = Field(default_factory=dict) - - def __init__( - self, - *, - api_key: str, - model: str = "gemini-2.0-flash-001", - params: InputParams = InputParams(), - system_instruction: Optional[str] = None, - tools: Optional[List[Dict[str, Any]]] = None, - tool_config: Optional[Dict[str, Any]] = None, - **kwargs, - ): - super().__init__(**kwargs) - gai.configure(api_key=api_key) - self.set_model_name(model) - self._system_instruction = system_instruction - self._create_client() - self._settings = { - "max_tokens": params.max_tokens, - "temperature": params.temperature, - "top_k": params.top_k, - "top_p": params.top_p, - "extra": params.extra if isinstance(params.extra, dict) else {}, - } - self._tools = tools - self._tool_config = tool_config - - def can_generate_metrics(self) -> bool: - return True - - def _create_client(self): - self._client = gai.GenerativeModel( - self._model_name, system_instruction=self._system_instruction - ) - - async def _process_context(self, context: OpenAILLMContext): - await self.push_frame(LLMFullResponseStartFrame()) - - prompt_tokens = 0 - completion_tokens = 0 - total_tokens = 0 - - grounding_metadata = None - search_result = "" - - try: - logger.debug( - # f"{self}: Generating chat [{self._system_instruction}] | [{context.get_messages_for_logging()}]" - f"{self}: Generating chat [{context.get_messages_for_logging()}]" - ) - - messages = context.messages - if context.system_message and self._system_instruction != context.system_message: - logger.debug(f"System instruction changed: {context.system_message}") - self._system_instruction = context.system_message - self._create_client() - - # Filter out None 
values and create GenerationConfig - generation_params = { - k: v - for k, v in { - "temperature": self._settings["temperature"], - "top_p": self._settings["top_p"], - "top_k": self._settings["top_k"], - "max_output_tokens": self._settings["max_tokens"], - }.items() - if v is not None - } - - generation_config = GenerationConfig(**generation_params) if generation_params else None - - await self.start_ttfb_metrics() - tools = [] - if context.tools: - tools = context.tools - elif self._tools: - tools = self._tools - tool_config = None - if self._tool_config: - tool_config = self._tool_config - response = await self._client.generate_content_async( - contents=messages, - tools=tools, - stream=True, - generation_config=generation_config, - tool_config=tool_config, - ) - await self.stop_ttfb_metrics() - - if response.usage_metadata: - # Use only the prompt token count from the response object - prompt_tokens = response.usage_metadata.prompt_token_count - total_tokens = prompt_tokens - - async for chunk in response: - if chunk.usage_metadata: - # Use only the completion_tokens from the chunks. Prompt tokens are already counted and - # are repeated here. - completion_tokens += chunk.usage_metadata.candidates_token_count - total_tokens += chunk.usage_metadata.candidates_token_count - try: - for c in chunk.parts: - if c.text: - search_result += c.text - await self.push_frame(LLMTextFrame(c.text)) - elif c.function_call: - logger.debug(f"Function call: {c.function_call}") - args = type(c.function_call).to_dict(c.function_call).get("args", {}) - await self.call_function( - context=context, - tool_call_id=str(uuid.uuid4()), - function_name=c.function_call.name, - arguments=args, - ) - # Handle grounding metadata - # It seems only the last chunk that we receive may contain this information - # If the response doesn't include groundingMetadata, this means the response wasn't grounded. 
- if chunk.candidates: - for candidate in chunk.candidates: - # logger.debug(f"candidate received: {candidate}") - # Extract grounding metadata - grounding_metadata = ( - { - "rendered_content": getattr( - getattr(candidate, "grounding_metadata", None), - "search_entry_point", - None, - ).rendered_content - if hasattr( - getattr(candidate, "grounding_metadata", None), - "search_entry_point", - ) - else None, - "origins": [ - { - "site_uri": getattr(grounding_chunk.web, "uri", None), - "site_title": getattr( - grounding_chunk.web, "title", None - ), - "results": [ - { - "text": getattr( - grounding_support.segment, "text", "" - ), - "confidence": getattr( - grounding_support, "confidence_scores", None - ), - } - for grounding_support in getattr( - getattr(candidate, "grounding_metadata", None), - "grounding_supports", - [], - ) - if index - in getattr( - grounding_support, "grounding_chunk_indices", [] - ) - ], - } - for index, grounding_chunk in enumerate( - getattr( - getattr(candidate, "grounding_metadata", None), - "grounding_chunks", - [], - ) - ) - ], - } - if getattr(candidate, "grounding_metadata", None) - else None - ) - except Exception as e: - # Google LLMs seem to flag safety issues a lot! - if chunk.candidates[0].finish_reason == 3: - logger.debug( - f"LLM refused to generate content for safety reasons - {messages}." 
- ) - else: - logger.exception(f"{self} error: {e}") - - except DeadlineExceeded: - await self._call_event_handler("on_completion_timeout") - except Exception as e: - logger.exception(f"{self} exception: {e}") - finally: - if grounding_metadata is not None and isinstance(grounding_metadata, dict): - llm_search_frame = LLMSearchResponseFrame( - search_result=search_result, - origins=grounding_metadata["origins"], - rendered_content=grounding_metadata["rendered_content"], - ) - await self.push_frame(llm_search_frame) - - await self.start_llm_usage_metrics( - LLMTokenUsage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens, - ) - ) - await self.push_frame(LLMFullResponseEndFrame()) - - async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - - context = None - - if isinstance(frame, OpenAILLMContextFrame): - context = GoogleLLMContext.upgrade_to_google(frame.context) - elif isinstance(frame, LLMMessagesFrame): - context = GoogleLLMContext(frame.messages) - elif isinstance(frame, VisionImageRawFrame): - context = GoogleLLMContext() - context.add_image_frame_message( - format=frame.format, size=frame.size, image=frame.image, text=frame.text - ) - elif isinstance(frame, LLMUpdateSettingsFrame): - await self._update_settings(frame.settings) - else: - await self.push_frame(frame, direction) - - if context: - await self._process_context(context) - - def create_context_aggregator( - self, - context: OpenAILLMContext, - *, - user_kwargs: Mapping[str, Any] = {}, - assistant_kwargs: Mapping[str, Any] = {}, - ) -> GoogleContextAggregatorPair: - """Create an instance of GoogleContextAggregatorPair from an - OpenAILLMContext. Constructor keyword arguments for both the user and - assistant aggregators can be provided. - - Args: - context (OpenAILLMContext): The LLM context. 
- user_kwargs (Mapping[str, Any], optional): Additional keyword - arguments for the user context aggregator constructor. Defaults - to an empty mapping. - assistant_kwargs (Mapping[str, Any], optional): Additional keyword - arguments for the assistant context aggregator - constructor. Defaults to an empty mapping. - - Returns: - GoogleContextAggregatorPair: A pair of context aggregators, one for - the user and one for the assistant, encapsulated in an - GoogleContextAggregatorPair. - - """ - context.set_llm_adapter(self.get_llm_adapter()) - - if isinstance(context, OpenAILLMContext): - context = GoogleLLMContext.upgrade_to_google(context) - user = GoogleUserContextAggregator(context, **user_kwargs) - assistant = GoogleAssistantContextAggregator(context, **assistant_kwargs) - return GoogleContextAggregatorPair(_user=user, _assistant=assistant) - - -class GoogleLLMOpenAIBetaService(OpenAILLMService): - """This class implements inference with Google's AI LLM models using the OpenAI format. 
- Ref - https://ai.google.dev/gemini-api/docs/openai - """ - - def __init__( - self, - *, - api_key: str, - base_url: str = "https://generativelanguage.googleapis.com/v1beta/openai/", - model: str = "gemini-2.0-flash", - **kwargs, - ): - super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs) - - async def _process_context(self, context: OpenAILLMContext): - functions_list = [] - arguments_list = [] - tool_id_list = [] - func_idx = 0 - function_name = "" - arguments = "" - tool_call_id = "" - - await self.start_ttfb_metrics() - - chunk_stream: AsyncStream[ChatCompletionChunk] = await self._stream_chat_completions( - context - ) - - async for chunk in chunk_stream: - if chunk.usage: - tokens = LLMTokenUsage( - prompt_tokens=chunk.usage.prompt_tokens, - completion_tokens=chunk.usage.completion_tokens, - total_tokens=chunk.usage.total_tokens, - ) - await self.start_llm_usage_metrics(tokens) - - if chunk.choices is None or len(chunk.choices) == 0: - continue - - await self.stop_ttfb_metrics() - - if not chunk.choices[0].delta: - continue - - if chunk.choices[0].delta.tool_calls: - # We're streaming the LLM response to enable the fastest response times. - # For text, we just yield each chunk as we receive it and count on consumers - # to do whatever coalescing they need (eg. to pass full sentences to TTS) - # - # If the LLM is a function call, we'll do some coalescing here. - # If the response contains a function name, we'll yield a frame to tell consumers - # that they can start preparing to call the function with that name. - # We accumulate all the arguments for the rest of the streamed response, then when - # the response is done, we package up all the arguments and the function name and - # yield a frame containing the function name and the arguments. 
- logger.debug(f"Tool call: {chunk.choices[0].delta.tool_calls}") - tool_call = chunk.choices[0].delta.tool_calls[0] - if tool_call.index != func_idx: - functions_list.append(function_name) - arguments_list.append(arguments) - tool_id_list.append(tool_call_id) - function_name = "" - arguments = "" - tool_call_id = "" - func_idx += 1 - if tool_call.function and tool_call.function.name: - function_name += tool_call.function.name - tool_call_id = tool_call.id - if tool_call.function and tool_call.function.arguments: - # Keep iterating through the response to collect all the argument fragments - arguments += tool_call.function.arguments - elif chunk.choices[0].delta.content: - await self.push_frame(LLMTextFrame(chunk.choices[0].delta.content)) - - # if we got a function name and arguments, check to see if it's a function with - # a registered handler. If so, run the registered callback, save the result to - # the context, and re-prompt to get a chat answer. If we don't have a registered - # handler, raise an exception. - if function_name and arguments: - # added to the list as last function name and arguments not added to the list - functions_list.append(function_name) - arguments_list.append(arguments) - tool_id_list.append(tool_call_id) - - logger.debug( - f"Function list: {functions_list}, Arguments list: {arguments_list}, Tool ID list: {tool_id_list}" - ) - for index, (function_name, arguments, tool_id) in enumerate( - zip(functions_list, arguments_list, tool_id_list), start=1 - ): - if function_name == "": - # TODO: Remove the _process_context method once Google resolves the bug - # where the index is incorrectly set to None instead of returning the actual index, - # which currently results in an empty function name(''). 
- continue - if self.has_function(function_name): - run_llm = False - arguments = json.loads(arguments) - await self.call_function( - context=context, - function_name=function_name, - arguments=arguments, - tool_call_id=tool_id, - run_llm=run_llm, - ) - else: - raise OpenAIUnhandledFunctionException( - f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function." - ) - - -class GoogleVertexLLMService(OpenAILLMService): - """Implements inference with Google's AI models via Vertex AI while - maintaining OpenAI API compatibility. - - Reference: - https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/call-vertex-using-openai-library - - """ - - class InputParams(OpenAILLMService.InputParams): - """Input parameters specific to Vertex AI.""" - - # https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations - location: str = "us-east4" - project_id: str - - def __init__( - self, - *, - credentials: Optional[str] = None, - credentials_path: Optional[str] = None, - model: str = "google/gemini-2.0-flash-001", - params: InputParams = OpenAILLMService.InputParams(), - **kwargs, - ): - """Initializes the VertexLLMService. - - Args: - credentials (Optional[str]): JSON string of service account credentials. - credentials_path (Optional[str]): Path to the service account JSON file. - model (str): Model identifier. Defaults to "google/gemini-2.0-flash-001". - params (InputParams): Vertex AI input parameters. - **kwargs: Additional arguments for OpenAILLMService. 
- """ - base_url = self._get_base_url(params) - self._api_key = self._get_api_token(credentials, credentials_path) - - super().__init__(api_key=self._api_key, base_url=base_url, model=model, **kwargs) - - @staticmethod - def _get_base_url(params: InputParams) -> str: - """Constructs the base URL for Vertex AI API.""" - return ( - f"https://{params.location}-aiplatform.googleapis.com/v1/" - f"projects/{params.project_id}/locations/{params.location}/endpoints/openapi" - ) - - @staticmethod - def _get_api_token(credentials: Optional[str], credentials_path: Optional[str]) -> str: - """Retrieves an authentication token using Google service account credentials. - - Args: - credentials (Optional[str]): JSON string of service account credentials. - credentials_path (Optional[str]): Path to the service account JSON file. - - Returns: - str: OAuth token for API authentication. - """ - creds: Optional[service_account.Credentials] = None - - if credentials: - # Parse and load credentials from JSON string - creds = service_account.Credentials.from_service_account_info( - json.loads(credentials), scopes=["https://www.googleapis.com/auth/cloud-platform"] - ) - elif credentials_path: - # Load credentials from JSON file - creds = service_account.Credentials.from_service_account_file( - credentials_path, scopes=["https://www.googleapis.com/auth/cloud-platform"] - ) - - if not creds: - raise ValueError("No valid credentials provided.") - - creds.refresh(Request()) # Ensure token is up-to-date, lifetime is 1 hour. 
- - return creds.token - - -class GoogleTTSService(TTSService): - class InputParams(BaseModel): - pitch: Optional[str] = None - rate: Optional[str] = None - volume: Optional[str] = None - emphasis: Optional[Literal["strong", "moderate", "reduced", "none"]] = None - language: Optional[Language] = Language.EN - gender: Optional[Literal["male", "female", "neutral"]] = None - google_style: Optional[Literal["apologetic", "calm", "empathetic", "firm", "lively"]] = None - - def __init__( - self, - *, - credentials: Optional[str] = None, - credentials_path: Optional[str] = None, - voice_id: str = "en-US-Neural2-A", - sample_rate: Optional[int] = None, - params: InputParams = InputParams(), - **kwargs, - ): - super().__init__(sample_rate=sample_rate, **kwargs) - - self._settings = { - "pitch": params.pitch, - "rate": params.rate, - "volume": params.volume, - "emphasis": params.emphasis, - "language": self.language_to_service_language(params.language) - if params.language - else "en-US", - "gender": params.gender, - "google_style": params.google_style, - } - self.set_voice(voice_id) - self._client: texttospeech_v1.TextToSpeechAsyncClient = self._create_client( - credentials, credentials_path - ) - - def _create_client( - self, credentials: Optional[str], credentials_path: Optional[str] - ) -> texttospeech_v1.TextToSpeechAsyncClient: - creds: Optional[service_account.Credentials] = None - - # Create a Google Cloud service account for the Cloud Text-to-Speech API - # Using either the provided credentials JSON string or the path to a service account JSON - # file, create a Google Cloud service account and use it to authenticate with the API. 
- if credentials: - # Use provided credentials JSON string - json_account_info = json.loads(credentials) - creds = service_account.Credentials.from_service_account_info(json_account_info) - elif credentials_path: - # Use service account JSON file if provided - creds = service_account.Credentials.from_service_account_file(credentials_path) - - return texttospeech_v1.TextToSpeechAsyncClient(credentials=creds) - - def can_generate_metrics(self) -> bool: - return True - - def language_to_service_language(self, language: Language) -> Optional[str]: - return language_to_google_tts_language(language) - - def _construct_ssml(self, text: str) -> str: - ssml = "" - - # Voice tag - voice_attrs = [f"name='{self._voice_id}'"] - - language = self._settings["language"] - voice_attrs.append(f"language='{language}'") - - if self._settings["gender"]: - voice_attrs.append(f"gender='{self._settings['gender']}'") - ssml += f"" - - # Prosody tag - prosody_attrs = [] - if self._settings["pitch"]: - prosody_attrs.append(f"pitch='{self._settings['pitch']}'") - if self._settings["rate"]: - prosody_attrs.append(f"rate='{self._settings['rate']}'") - if self._settings["volume"]: - prosody_attrs.append(f"volume='{self._settings['volume']}'") - - if prosody_attrs: - ssml += f"" - - # Emphasis tag - if self._settings["emphasis"]: - ssml += f"" - - # Google style tag - if self._settings["google_style"]: - ssml += f"" - - ssml += text - - # Close tags - if self._settings["google_style"]: - ssml += "" - if self._settings["emphasis"]: - ssml += "" - if prosody_attrs: - ssml += "" - ssml += "" - - return ssml - - async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: - logger.debug(f"{self}: Generating TTS [{text}]") - - try: - await self.start_ttfb_metrics() - - # Check if the voice is a Chirp voice (including Chirp 3) or Journey voice - is_chirp_voice = "chirp" in self._voice_id.lower() - is_journey_voice = "journey" in self._voice_id.lower() - - # Create synthesis input based on 
voice_id - if is_chirp_voice or is_journey_voice: - # Chirp and Journey voices don't support SSML, use plain text - synthesis_input = texttospeech_v1.SynthesisInput(text=text) - else: - ssml = self._construct_ssml(text) - synthesis_input = texttospeech_v1.SynthesisInput(ssml=ssml) - - voice = texttospeech_v1.VoiceSelectionParams( - language_code=self._settings["language"], name=self._voice_id - ) - audio_config = texttospeech_v1.AudioConfig( - audio_encoding=texttospeech_v1.AudioEncoding.LINEAR16, - sample_rate_hertz=self.sample_rate, - ) - - request = texttospeech_v1.SynthesizeSpeechRequest( - input=synthesis_input, voice=voice, audio_config=audio_config - ) - - response = await self._client.synthesize_speech(request=request) - - await self.start_tts_usage_metrics(text) - - yield TTSStartedFrame() - - # Skip the first 44 bytes to remove the WAV header - audio_content = response.audio_content[44:] - - # Read and yield audio data in chunks - chunk_size = 8192 - for i in range(0, len(audio_content), chunk_size): - chunk = audio_content[i : i + chunk_size] - if not chunk: - break - await self.stop_ttfb_metrics() - frame = TTSAudioRawFrame(chunk, self.sample_rate, 1) - yield frame - await asyncio.sleep(0) # Allow other tasks to run - - yield TTSStoppedFrame() - - except Exception as e: - logger.exception(f"{self} error generating TTS: {e}") - error_message = f"TTS generation error: {str(e)}" - yield ErrorFrame(error=error_message) - - -class GoogleImageGenService(ImageGenService): - class InputParams(BaseModel): - number_of_images: int = Field(default=1, ge=1, le=8) - model: str = Field(default="imagen-3.0-generate-002") - negative_prompt: str = Field(default=None) - - def __init__( - self, - *, - params: InputParams = InputParams(), - api_key: str, - **kwargs, - ): - super().__init__(**kwargs) - self.set_model_name(params.model) - self._params = params - self._client = genai.Client(api_key=api_key) - - def can_generate_metrics(self) -> bool: - return True - - async 
def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]: - """Generate images from a text prompt using Google's Imagen model. - - Args: - prompt (str): The text description to generate images from. - - Yields: - Frame: Generated image frames or error frames. - """ - logger.debug(f"Generating image from prompt: {prompt}") - await self.start_ttfb_metrics() - - try: - response = await self._client.aio.models.generate_images( - model=self._params.model, - prompt=prompt, - config=types.GenerateImagesConfig( - number_of_images=self._params.number_of_images, - negative_prompt=self._params.negative_prompt, - ), - ) - await self.stop_ttfb_metrics() - - if not response or not response.generated_images: - logger.error(f"{self} error: image generation failed") - yield ErrorFrame("Image generation failed") - return - - for img_response in response.generated_images: - # Google returns the image data directly - image_bytes = img_response.image.image_bytes - image = Image.open(io.BytesIO(image_bytes)) - - frame = URLImageRawFrame( - url=None, # Google doesn't provide URLs, only image data - image=image.tobytes(), - size=image.size, - format=image.format, - ) - yield frame - - except Exception as e: - logger.error(f"{self} error generating image: {e}") - yield ErrorFrame(f"Image generation error: {str(e)}") - - -class GoogleSTTService(STTService): - """Google Cloud Speech-to-Text V2 service implementation. - - Provides real-time speech recognition using Google Cloud's Speech-to-Text V2 API - with streaming support. Handles audio transcription and optional voice activity detection. - - Attributes: - InputParams: Configuration parameters for the STT service. - """ - - # Google Cloud's STT service has a connection time limit of 5 minutes per stream. 
- # They've shared an "endless streaming" example that guided this implementation: - # https://cloud.google.com/speech-to-text/docs/transcribe-streaming-audio#endless-streaming - - STREAMING_LIMIT = 240000 # 4 minutes in milliseconds - - class InputParams(BaseModel): - """Configuration parameters for Google Speech-to-Text. - - Attributes: - languages: Single language or list of recognition languages. First language is primary. - model: Speech recognition model to use. - use_separate_recognition_per_channel: Process each audio channel separately. - enable_automatic_punctuation: Add punctuation to transcripts. - enable_spoken_punctuation: Include spoken punctuation in transcript. - enable_spoken_emojis: Include spoken emojis in transcript. - profanity_filter: Filter profanity from transcript. - enable_word_time_offsets: Include timing information for each word. - enable_word_confidence: Include confidence scores for each word. - enable_interim_results: Stream partial recognition results. - enable_voice_activity_events: Detect voice activity in audio. 
- """ - - languages: Union[Language, List[Language]] = Field(default_factory=lambda: [Language.EN_US]) - model: Optional[str] = "latest_long" - use_separate_recognition_per_channel: Optional[bool] = False - enable_automatic_punctuation: Optional[bool] = True - enable_spoken_punctuation: Optional[bool] = False - enable_spoken_emojis: Optional[bool] = False - profanity_filter: Optional[bool] = False - enable_word_time_offsets: Optional[bool] = False - enable_word_confidence: Optional[bool] = False - enable_interim_results: Optional[bool] = True - enable_voice_activity_events: Optional[bool] = False - - @field_validator("languages", mode="before") - @classmethod - def validate_languages(cls, v) -> List[Language]: - if isinstance(v, Language): - return [v] - return v - - @property - def language_list(self) -> List[Language]: - """Get languages as a guaranteed list.""" - assert isinstance(self.languages, list) - return self.languages - - def __init__( - self, - *, - credentials: Optional[str] = None, - credentials_path: Optional[str] = None, - location: str = "global", - sample_rate: Optional[int] = None, - params: InputParams = InputParams(), - **kwargs, - ): - """Initialize the Google STT service. - - Args: - credentials: JSON string containing Google Cloud service account credentials. - credentials_path: Path to service account credentials JSON file. - location: Google Cloud location (e.g., "global", "us-central1"). - sample_rate: Audio sample rate in Hertz. - params: Configuration parameters for the service. - **kwargs: Additional arguments passed to STTService. - - Raises: - ValueError: If neither credentials nor credentials_path is provided. - ValueError: If project ID is not found in credentials. 
- """ - super().__init__(sample_rate=sample_rate, **kwargs) - - self._location = location - self._stream = None - self._config = None - self._request_queue = asyncio.Queue() - self._streaming_task = None - - # Used for keep-alive logic - self._stream_start_time = 0 - self._last_audio_input = [] - self._audio_input = [] - self._result_end_time = 0 - self._is_final_end_time = 0 - self._final_request_end_time = 0 - self._bridging_offset = 0 - self._last_transcript_was_final = False - self._new_stream = True - self._restart_counter = 0 - - # Configure client options based on location - client_options = None - if self._location != "global": - client_options = ClientOptions(api_endpoint=f"{self._location}-speech.googleapis.com") - - # Extract project ID and create client - if credentials: - json_account_info = json.loads(credentials) - self._project_id = json_account_info.get("project_id") - creds = service_account.Credentials.from_service_account_info(json_account_info) - elif credentials_path: - with open(credentials_path) as f: - json_account_info = json.load(f) - self._project_id = json_account_info.get("project_id") - creds = service_account.Credentials.from_service_account_file(credentials_path) - else: - raise ValueError("Either credentials or credentials_path must be provided") - - if not self._project_id: - raise ValueError("Project ID not found in credentials") - - self._client = speech_v2.SpeechAsyncClient(credentials=creds, client_options=client_options) - - self._settings = { - "language_codes": [ - self.language_to_service_language(lang) for lang in params.language_list - ], - "model": params.model, - "use_separate_recognition_per_channel": params.use_separate_recognition_per_channel, - "enable_automatic_punctuation": params.enable_automatic_punctuation, - "enable_spoken_punctuation": params.enable_spoken_punctuation, - "enable_spoken_emojis": params.enable_spoken_emojis, - "profanity_filter": params.profanity_filter, - "enable_word_time_offsets": 
params.enable_word_time_offsets, - "enable_word_confidence": params.enable_word_confidence, - "enable_interim_results": params.enable_interim_results, - "enable_voice_activity_events": params.enable_voice_activity_events, - } - - def language_to_service_language(self, language: Language | List[Language]) -> str | List[str]: - """Convert Language enum(s) to Google STT language code(s). - - Args: - language: Single Language enum or list of Language enums. - - Returns: - str | List[str]: Google STT language code(s). - """ - if isinstance(language, list): - return [language_to_google_stt_language(lang) or "en-US" for lang in language] - return language_to_google_stt_language(language) or "en-US" - - async def _reconnect_if_needed(self): - """Reconnect the stream if it's currently active.""" - if self._streaming_task: - logger.debug("Reconnecting stream due to configuration changes") - await self._disconnect() - await self._connect() - - async def set_language(self, language: Language): - """Update the service's recognition language. - - A convenience method for setting a single language. - - Args: - language: New language for recognition. - """ - logger.debug(f"Switching STT language to: {language}") - await self.set_languages([language]) - - async def set_languages(self, languages: List[Language]): - """Update the service's recognition languages. - - Args: - languages: List of languages for recognition. First language is primary. 
- """ - logger.debug(f"Switching STT languages to: {languages}") - self._settings["language_codes"] = [ - self.language_to_service_language(lang) for lang in languages - ] - # Recreate stream with new languages - await self._reconnect_if_needed() - - async def set_model(self, model: str): - """Update the service's recognition model.""" - logger.debug(f"Switching STT model to: {model}") - await super().set_model(model) - self._settings["model"] = model - # Recreate stream with new model - await self._reconnect_if_needed() - - async def start(self, frame: StartFrame): - await super().start(frame) - await self._connect() - - async def stop(self, frame: EndFrame): - await super().stop(frame) - await self._disconnect() - - async def cancel(self, frame: CancelFrame): - await super().cancel(frame) - await self._disconnect() - - async def update_options( - self, - *, - languages: Optional[List[Language]] = None, - model: Optional[str] = None, - enable_automatic_punctuation: Optional[bool] = None, - enable_spoken_punctuation: Optional[bool] = None, - enable_spoken_emojis: Optional[bool] = None, - profanity_filter: Optional[bool] = None, - enable_word_time_offsets: Optional[bool] = None, - enable_word_confidence: Optional[bool] = None, - enable_interim_results: Optional[bool] = None, - enable_voice_activity_events: Optional[bool] = None, - location: Optional[str] = None, - ) -> None: - """Update service options dynamically. - - Args: - languages: New list of recongition languages. - model: New recognition model. - enable_automatic_punctuation: Enable/disable automatic punctuation. - enable_spoken_punctuation: Enable/disable spoken punctuation. - enable_spoken_emojis: Enable/disable spoken emojis. - profanity_filter: Enable/disable profanity filter. - enable_word_time_offsets: Enable/disable word timing info. - enable_word_confidence: Enable/disable word confidence scores. - enable_interim_results: Enable/disable interim results. 
- enable_voice_activity_events: Enable/disable voice activity detection. - location: New Google Cloud location. - - Note: - Changes that affect the streaming configuration will cause - the stream to be reconnected. - """ - # Update settings with new values - if languages is not None: - logger.debug(f"Updating language to: {languages}") - self._settings["language_codes"] = [ - self.language_to_service_language(lang) for lang in languages - ] - - if model is not None: - logger.debug(f"Updating model to: {model}") - self._settings["model"] = model - - if enable_automatic_punctuation is not None: - logger.debug(f"Updating automatic punctuation to: {enable_automatic_punctuation}") - self._settings["enable_automatic_punctuation"] = enable_automatic_punctuation - - if enable_spoken_punctuation is not None: - logger.debug(f"Updating spoken punctuation to: {enable_spoken_punctuation}") - self._settings["enable_spoken_punctuation"] = enable_spoken_punctuation - - if enable_spoken_emojis is not None: - logger.debug(f"Updating spoken emojis to: {enable_spoken_emojis}") - self._settings["enable_spoken_emojis"] = enable_spoken_emojis - - if profanity_filter is not None: - logger.debug(f"Updating profanity filter to: {profanity_filter}") - self._settings["profanity_filter"] = profanity_filter - - if enable_word_time_offsets is not None: - logger.debug(f"Updating word time offsets to: {enable_word_time_offsets}") - self._settings["enable_word_time_offsets"] = enable_word_time_offsets - - if enable_word_confidence is not None: - logger.debug(f"Updating word confidence to: {enable_word_confidence}") - self._settings["enable_word_confidence"] = enable_word_confidence - - if enable_interim_results is not None: - logger.debug(f"Updating interim results to: {enable_interim_results}") - self._settings["enable_interim_results"] = enable_interim_results - - if enable_voice_activity_events is not None: - logger.debug(f"Updating voice activity events to: {enable_voice_activity_events}") - 
self._settings["enable_voice_activity_events"] = enable_voice_activity_events - - if location is not None: - logger.debug(f"Updating location to: {location}") - self._location = location - - # Reconnect the stream for updates - await self._reconnect_if_needed() - - async def _connect(self): - """Initialize streaming recognition config and stream.""" - logger.debug("Connecting to Google Speech-to-Text") - - # Set stream start time - self._stream_start_time = int(time.time() * 1000) - self._new_stream = True - - self._config = cloud_speech.StreamingRecognitionConfig( - config=cloud_speech.RecognitionConfig( - explicit_decoding_config=cloud_speech.ExplicitDecodingConfig( - encoding=cloud_speech.ExplicitDecodingConfig.AudioEncoding.LINEAR16, - sample_rate_hertz=self.sample_rate, - audio_channel_count=1, - ), - language_codes=self._settings["language_codes"], - model=self._settings["model"], - features=cloud_speech.RecognitionFeatures( - enable_automatic_punctuation=self._settings["enable_automatic_punctuation"], - enable_spoken_punctuation=self._settings["enable_spoken_punctuation"], - enable_spoken_emojis=self._settings["enable_spoken_emojis"], - profanity_filter=self._settings["profanity_filter"], - enable_word_time_offsets=self._settings["enable_word_time_offsets"], - enable_word_confidence=self._settings["enable_word_confidence"], - ), - ), - streaming_features=cloud_speech.StreamingRecognitionFeatures( - enable_voice_activity_events=self._settings["enable_voice_activity_events"], - interim_results=self._settings["enable_interim_results"], - ), - ) - - self._streaming_task = self.create_task(self._stream_audio()) - - async def _disconnect(self): - """Clean up streaming recognition resources.""" - if self._streaming_task: - logger.debug("Disconnecting from Google Speech-to-Text") - # Send sentinel value to stop request generator - await self._request_queue.put(None) - await self.cancel_task(self._streaming_task) - self._streaming_task = None - # Clear any remaining 
items in the queue - while not self._request_queue.empty(): - try: - self._request_queue.get_nowait() - self._request_queue.task_done() - except asyncio.QueueEmpty: - break - - async def _request_generator(self): - """Generates requests for the streaming recognize method.""" - recognizer_path = f"projects/{self._project_id}/locations/{self._location}/recognizers/_" - logger.trace(f"Using recognizer path: {recognizer_path}") - - try: - # Send initial config - yield cloud_speech.StreamingRecognizeRequest( - recognizer=recognizer_path, - streaming_config=self._config, - ) - - while True: - try: - audio_data = await self._request_queue.get() - if audio_data is None: # Sentinel value to stop - break - - # Check streaming limit - if (int(time.time() * 1000) - self._stream_start_time) > self.STREAMING_LIMIT: - logger.debug("Streaming limit reached, initiating graceful reconnection") - # Instead of immediate reconnection, we'll break and let the stream close naturally - self._last_audio_input = self._audio_input - self._audio_input = [] - self._restart_counter += 1 - # Put the current audio chunk back in the queue - await self._request_queue.put(audio_data) - break - - self._audio_input.append(audio_data) - yield cloud_speech.StreamingRecognizeRequest(audio=audio_data) - - except asyncio.CancelledError: - break - finally: - self._request_queue.task_done() - - except Exception as e: - logger.error(f"Error in request generator: {e}") - raise - - async def _stream_audio(self): - """Handle bi-directional streaming with Google STT.""" - try: - while True: - try: - # Start bi-directional streaming - streaming_recognize = await self._client.streaming_recognize( - requests=self._request_generator() - ) - - # Process responses - await self._process_responses(streaming_recognize) - - # If we're here, check if we need to reconnect - if (int(time.time() * 1000) - self._stream_start_time) > self.STREAMING_LIMIT: - logger.debug("Reconnecting stream after timeout") - # Reset stream start 
time - self._stream_start_time = int(time.time() * 1000) - continue - else: - # Normal stream end - break - - except Exception as e: - logger.warning(f"{self} Reconnecting: {e}") - - await asyncio.sleep(1) # Brief delay before reconnecting - self._stream_start_time = int(time.time() * 1000) - continue - - except Exception as e: - logger.error(f"Error in streaming task: {e}") - await self.push_frame(ErrorFrame(str(e))) - - async def run_stt(self, audio: bytes) -> AsyncGenerator[Frame, None]: - """Process an audio chunk for STT transcription.""" - if self._streaming_task: - # Queue the audio data - await self._request_queue.put(audio) - yield None - - async def _process_responses(self, streaming_recognize): - """Process streaming recognition responses.""" - try: - async for response in streaming_recognize: - # Check streaming limit - if (int(time.time() * 1000) - self._stream_start_time) > self.STREAMING_LIMIT: - logger.debug("Stream timeout reached in response processing") - break - - if not response.results: - continue - - for result in response.results: - if not result.alternatives: - continue - - transcript = result.alternatives[0].transcript - if not transcript: - continue - - primary_language = self._settings["language_codes"][0] - - if result.is_final: - self._last_transcript_was_final = True - await self.push_frame( - TranscriptionFrame(transcript, "", time_now_iso8601(), primary_language) - ) - else: - self._last_transcript_was_final = False - await self.push_frame( - InterimTranscriptionFrame( - transcript, "", time_now_iso8601(), primary_language - ) - ) - - except Exception as e: - logger.error(f"Error processing Google STT responses: {e}") - - # Re-raise the exception to let it propagate (e.g. 
in the case of a timeout, propagate to _stream_audio to reconnect) - raise diff --git a/src/pipecat/services/google/image.py b/src/pipecat/services/google/image.py new file mode 100644 index 000000000..f7a7764f2 --- /dev/null +++ b/src/pipecat/services/google/image.py @@ -0,0 +1,95 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import io +import os + +# Suppress gRPC fork warnings +os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false" + +from typing import AsyncGenerator + +from loguru import logger +from PIL import Image +from pydantic import BaseModel, Field + +from pipecat.frames.frames import ErrorFrame, Frame, URLImageRawFrame +from pipecat.services.ai_services import ImageGenService + +try: + from google import genai + from google.genai import types +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.") + raise Exception(f"Missing module: {e}") + + +class GoogleImageGenService(ImageGenService): + class InputParams(BaseModel): + number_of_images: int = Field(default=1, ge=1, le=8) + model: str = Field(default="imagen-3.0-generate-002") + negative_prompt: str = Field(default=None) + + def __init__( + self, + *, + params: InputParams = InputParams(), + api_key: str, + **kwargs, + ): + super().__init__(**kwargs) + self.set_model_name(params.model) + self._params = params + self._client = genai.Client(api_key=api_key) + + def can_generate_metrics(self) -> bool: + return True + + async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]: + """Generate images from a text prompt using Google's Imagen model. + + Args: + prompt (str): The text description to generate images from. + + Yields: + Frame: Generated image frames or error frames. 
+ """ + logger.debug(f"Generating image from prompt: {prompt}") + await self.start_ttfb_metrics() + + try: + response = await self._client.aio.models.generate_images( + model=self._params.model, + prompt=prompt, + config=types.GenerateImagesConfig( + number_of_images=self._params.number_of_images, + negative_prompt=self._params.negative_prompt, + ), + ) + await self.stop_ttfb_metrics() + + if not response or not response.generated_images: + logger.error(f"{self} error: image generation failed") + yield ErrorFrame("Image generation failed") + return + + for img_response in response.generated_images: + # Google returns the image data directly + image_bytes = img_response.image.image_bytes + image = Image.open(io.BytesIO(image_bytes)) + + frame = URLImageRawFrame( + url=None, # Google doesn't provide URLs, only image data + image=image.tobytes(), + size=image.size, + format=image.format, + ) + yield frame + + except Exception as e: + logger.error(f"{self} error generating image: {e}") + yield ErrorFrame(f"Image generation error: {str(e)}") diff --git a/src/pipecat/services/google/llm.py b/src/pipecat/services/google/llm.py new file mode 100644 index 000000000..e78e3949b --- /dev/null +++ b/src/pipecat/services/google/llm.py @@ -0,0 +1,717 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import base64 +import io +import json +import os +import uuid + +from google.api_core.exceptions import DeadlineExceeded + +from pipecat.adapters.services.gemini_adapter import GeminiLLMAdapter + +# Suppress gRPC fork warnings +os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false" + +from dataclasses import dataclass +from typing import Any, Dict, List, Mapping, Optional, Union + +from loguru import logger +from PIL import Image +from pydantic import BaseModel, Field + +from pipecat.frames.frames import ( + AudioRawFrame, + Frame, + FunctionCallCancelFrame, + FunctionCallInProgressFrame, + FunctionCallResultFrame, + LLMFullResponseEndFrame, + 
LLMFullResponseStartFrame, + LLMMessagesFrame, + LLMTextFrame, + LLMUpdateSettingsFrame, + UserImageRawFrame, + VisionImageRawFrame, +) +from pipecat.metrics.metrics import LLMTokenUsage +from pipecat.processors.aggregators.openai_llm_context import ( + OpenAILLMContext, + OpenAILLMContextFrame, +) +from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.ai_services import LLMService +from pipecat.services.google.frames import LLMSearchResponseFrame +from pipecat.services.openai.llm import ( + OpenAIAssistantContextAggregator, + OpenAIUserContextAggregator, +) + +try: + import google.ai.generativelanguage as glm + import google.generativeai as gai + from google.generativeai.types import GenerationConfig + +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error("In order to use Google AI, you need to `pip install pipecat-ai[google]`.") + raise Exception(f"Missing module: {e}") + + +class GoogleUserContextAggregator(OpenAIUserContextAggregator): + async def push_aggregation(self): + if len(self._aggregation) > 0: + self._context.add_message( + glm.Content(role="user", parts=[glm.Part(text=self._aggregation)]) + ) + + # Reset the aggregation. Reset it before pushing it down, otherwise + # if the tasks gets cancelled we won't be able to clear things up. + self._aggregation = "" + + # Push context frame + frame = OpenAILLMContextFrame(self._context) + await self.push_frame(frame) + + # Reset our accumulator state. 
+ self.reset() + + +class GoogleAssistantContextAggregator(OpenAIAssistantContextAggregator): + async def handle_aggregation(self, aggregation: str): + self._context.add_message(glm.Content(role="model", parts=[glm.Part(text=aggregation)])) + + async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame): + self._context.add_message( + glm.Content( + role="model", + parts=[ + glm.Part( + function_call=glm.FunctionCall( + id=frame.tool_call_id, name=frame.function_name, args=frame.arguments + ) + ) + ], + ) + ) + self._context.add_message( + glm.Content( + role="user", + parts=[ + glm.Part( + function_response=glm.FunctionResponse( + id=frame.tool_call_id, + name=frame.function_name, + response={"response": "IN_PROGRESS"}, + ) + ) + ], + ) + ) + + async def handle_function_call_result(self, frame: FunctionCallResultFrame): + if frame.result: + await self._update_function_call_result( + frame.function_name, frame.tool_call_id, frame.result + ) + else: + await self._update_function_call_result( + frame.function_name, frame.tool_call_id, "COMPLETED" + ) + + async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame): + await self._update_function_call_result( + frame.function_name, frame.tool_call_id, "CANCELLED" + ) + + async def _update_function_call_result( + self, function_name: str, tool_call_id: str, result: Any + ): + for message in self._context.messages: + if message.role == "user": + for part in message.parts: + if part.function_response and part.function_response.id == tool_call_id: + part.function_response.response = {"value": json.dumps(result)} + + async def handle_user_image_frame(self, frame: UserImageRawFrame): + await self._update_function_call_result( + frame.request.function_name, frame.request.tool_call_id, "COMPLETED" + ) + self._context.add_image_frame_message( + format=frame.format, + size=frame.size, + image=frame.image, + text=frame.request.context, + ) + + +@dataclass +class GoogleContextAggregatorPair: 
+ _user: GoogleUserContextAggregator + _assistant: GoogleAssistantContextAggregator + + def user(self) -> GoogleUserContextAggregator: + return self._user + + def assistant(self) -> GoogleAssistantContextAggregator: + return self._assistant + + +class GoogleLLMContext(OpenAILLMContext): + def __init__( + self, + messages: Optional[List[dict]] = None, + tools: Optional[List[dict]] = None, + tool_choice: Optional[dict] = None, + ): + super().__init__(messages=messages, tools=tools, tool_choice=tool_choice) + self.system_message = None + + @staticmethod + def upgrade_to_google(obj: OpenAILLMContext) -> "GoogleLLMContext": + if isinstance(obj, OpenAILLMContext) and not isinstance(obj, GoogleLLMContext): + logger.debug(f"Upgrading to Google: {obj}") + obj.__class__ = GoogleLLMContext + obj._restructure_from_openai_messages() + return obj + + def set_messages(self, messages: List): + self._messages[:] = messages + self._restructure_from_openai_messages() + + def add_messages(self, messages: List): + # Convert each message individually + converted_messages = [] + for msg in messages: + if isinstance(msg, glm.Content): + # Already in Gemini format + converted_messages.append(msg) + else: + # Convert from standard format to Gemini format + converted = self.from_standard_message(msg) + if converted is not None: + converted_messages.append(converted) + + # Add the converted messages to our existing messages + self._messages.extend(converted_messages) + + def get_messages_for_logging(self): + msgs = [] + for message in self.messages: + obj = glm.Content.to_dict(message) + try: + if "parts" in obj: + for part in obj["parts"]: + if "inline_data" in part: + part["inline_data"]["data"] = "..." 
+ except Exception as e: + logger.debug(f"Error: {e}") + msgs.append(obj) + return msgs + + def add_image_frame_message( + self, *, format: str, size: tuple[int, int], image: bytes, text: str = None + ): + buffer = io.BytesIO() + Image.frombytes(format, size, image).save(buffer, format="JPEG") + + parts = [] + if text: + parts.append(glm.Part(text=text)) + parts.append(glm.Part(inline_data=glm.Blob(mime_type="image/jpeg", data=buffer.getvalue()))) + + self.add_message(glm.Content(role="user", parts=parts)) + + def add_audio_frames_message( + self, *, audio_frames: list[AudioRawFrame], text: str = "Audio follows" + ): + if not audio_frames: + return + + sample_rate = audio_frames[0].sample_rate + num_channels = audio_frames[0].num_channels + + parts = [] + data = b"".join(frame.audio for frame in audio_frames) + # NOTE(aleix): According to the docs only text or inline_data should be needed. + # (see https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference) + parts.append(glm.Part(text=text)) + parts.append( + glm.Part( + inline_data=glm.Blob( + mime_type="audio/wav", + data=( + bytes( + self.create_wav_header(sample_rate, num_channels, 16, len(data)) + data + ) + ), + ) + ), + ) + self.add_message(glm.Content(role="user", parts=parts)) + # message = {"mime_type": "audio/mp3", "data": bytes(data + create_wav_header(sample_rate, num_channels, 16, len(data)))} + # self.add_message(message) + + def from_standard_message(self, message): + """Convert standard format message to Google Content object. + + Handles conversion of text, images, and function calls to Google's format. + System messages are stored separately and return None. 
+ + Args: + message: Message in standard format: + { + "role": "user/assistant/system/tool", + "content": str | [{"type": "text/image_url", ...}] | None, + "tool_calls": [{"function": {"name": str, "arguments": str}}] + } + + Returns: + glm.Content object with: + - role: "user" or "model" (converted from "assistant") + - parts: List[Part] containing text, inline_data, or function calls + Returns None for system messages. + """ + role = message["role"] + content = message.get("content", []) + if role == "system": + self.system_message = content + return None + elif role == "assistant": + role = "model" + + parts = [] + if message.get("tool_calls"): + for tc in message["tool_calls"]: + parts.append( + glm.Part( + function_call=glm.FunctionCall( + name=tc["function"]["name"], + args=json.loads(tc["function"]["arguments"]), + ) + ) + ) + elif role == "tool": + role = "model" + parts.append( + glm.Part( + function_response=glm.FunctionResponse( + name="tool_call_result", # seems to work to hard-code the same name every time + response=json.loads(message["content"]), + ) + ) + ) + elif isinstance(content, str): + parts.append(glm.Part(text=content)) + elif isinstance(content, list): + for c in content: + if c["type"] == "text": + parts.append(glm.Part(text=c["text"])) + elif c["type"] == "image_url": + parts.append( + glm.Part( + inline_data=glm.Blob( + mime_type="image/jpeg", + data=base64.b64decode(c["image_url"]["url"].split(",")[1]), + ) + ) + ) + + message = glm.Content(role=role, parts=parts) + return message + + def to_standard_messages(self, obj) -> list: + """Convert Google Content object to standard structured format. + + Handles text, images, and function calls from Google's Content/Part objects. 
+ + Args: + obj: Google Content object with: + - role: "model" (converted to "assistant") or "user" + - parts: List[Part] containing text, inline_data, or function calls + + Returns: + List of messages in standard format: + [ + { + "role": "user/assistant/tool", + "content": [ + {"type": "text", "text": str} | + {"type": "image_url", "image_url": {"url": str}} + ] + } + ] + """ + msg = {"role": obj.role, "content": []} + if msg["role"] == "model": + msg["role"] = "assistant" + + for part in obj.parts: + if part.text: + msg["content"].append({"type": "text", "text": part.text}) + elif part.inline_data: + encoded = base64.b64encode(part.inline_data.data).decode("utf-8") + msg["content"].append( + { + "type": "image_url", + "image_url": {"url": f"data:{part.inline_data.mime_type};base64,{encoded}"}, + } + ) + elif part.function_call: + args = type(part.function_call).to_dict(part.function_call).get("args", {}) + msg["tool_calls"] = [ + { + "id": part.function_call.name, + "type": "function", + "function": { + "name": part.function_call.name, + "arguments": json.dumps(args), + }, + } + ] + + elif part.function_response: + msg["role"] = "tool" + resp = ( + type(part.function_response).to_dict(part.function_response).get("response", {}) + ) + msg["tool_call_id"] = part.function_response.name + msg["content"] = json.dumps(resp) + + # there might be no content parts for tool_calls messages + if not msg["content"]: + del msg["content"] + return [msg] + + def _restructure_from_openai_messages(self): + """Restructures messages to ensure proper Google format and message ordering. + + This method handles conversion of OpenAI-formatted messages to Google format, + with special handling for function calls, function responses, and system messages. + System messages are added back to the context as user messages when needed. + + The final message order is preserved as: + 1. Function calls (from model) + 2. Function responses (from user) + 3. 
Text messages (converted from system messages) + + Note: + System messages are only added back when there are no regular text + messages in the context, ensuring proper conversation continuity + after function calls. + """ + self.system_message = None + converted_messages = [] + + # Process each message, preserving Google-formatted messages and converting others + for message in self._messages: + if isinstance(message, glm.Content): + # Keep existing Google-formatted messages (e.g., function calls/responses) + converted_messages.append(message) + continue + + # Convert OpenAI format to Google format, system messages return None + converted = self.from_standard_message(message) + if converted is not None: + converted_messages.append(converted) + + # Update message list + self._messages[:] = converted_messages + + # Check if we only have function-related messages (no regular text) + has_regular_messages = any( + len(msg.parts) == 1 + and not getattr(msg.parts[0], "text", None) + and getattr(msg.parts[0], "function_call", None) + and getattr(msg.parts[0], "function_response", None) + for msg in self._messages + ) + + # Add system message back as a user message if we only have function messages + if self.system_message and not has_regular_messages: + self._messages.append( + glm.Content(role="user", parts=[glm.Part(text=self.system_message)]) + ) + + # Remove any empty messages + self._messages = [m for m in self._messages if m.parts] + + +class GoogleLLMService(LLMService): + """This class implements inference with Google's AI models. + + This service translates internally from OpenAILLMContext to the messages format + expected by the Google AI model. We are using the OpenAILLMContext as a lingua + franca for all LLM services, so that it is easy to switch between different LLMs. + """ + + # Overriding the default adapter to use the Gemini one. 
+ adapter_class = GeminiLLMAdapter + + class InputParams(BaseModel): + max_tokens: Optional[int] = Field(default=4096, ge=1) + temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0) + top_k: Optional[int] = Field(default=None, ge=0) + top_p: Optional[float] = Field(default=None, ge=0.0, le=1.0) + extra: Optional[Dict[str, Any]] = Field(default_factory=dict) + + def __init__( + self, + *, + api_key: str, + model: str = "gemini-2.0-flash-001", + params: InputParams = InputParams(), + system_instruction: Optional[str] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_config: Optional[Dict[str, Any]] = None, + **kwargs, + ): + super().__init__(**kwargs) + gai.configure(api_key=api_key) + self.set_model_name(model) + self._system_instruction = system_instruction + self._create_client() + self._settings = { + "max_tokens": params.max_tokens, + "temperature": params.temperature, + "top_k": params.top_k, + "top_p": params.top_p, + "extra": params.extra if isinstance(params.extra, dict) else {}, + } + self._tools = tools + self._tool_config = tool_config + + def can_generate_metrics(self) -> bool: + return True + + def _create_client(self): + self._client = gai.GenerativeModel( + self._model_name, system_instruction=self._system_instruction + ) + + async def _process_context(self, context: OpenAILLMContext): + await self.push_frame(LLMFullResponseStartFrame()) + + prompt_tokens = 0 + completion_tokens = 0 + total_tokens = 0 + + grounding_metadata = None + search_result = "" + + try: + logger.debug( + # f"{self}: Generating chat [{self._system_instruction}] | [{context.get_messages_for_logging()}]" + f"{self}: Generating chat [{context.get_messages_for_logging()}]" + ) + + messages = context.messages + if context.system_message and self._system_instruction != context.system_message: + logger.debug(f"System instruction changed: {context.system_message}") + self._system_instruction = context.system_message + self._create_client() + + # Filter out None 
values and create GenerationConfig + generation_params = { + k: v + for k, v in { + "temperature": self._settings["temperature"], + "top_p": self._settings["top_p"], + "top_k": self._settings["top_k"], + "max_output_tokens": self._settings["max_tokens"], + }.items() + if v is not None + } + + generation_config = GenerationConfig(**generation_params) if generation_params else None + + await self.start_ttfb_metrics() + tools = [] + if context.tools: + tools = context.tools + elif self._tools: + tools = self._tools + tool_config = None + if self._tool_config: + tool_config = self._tool_config + response = await self._client.generate_content_async( + contents=messages, + tools=tools, + stream=True, + generation_config=generation_config, + tool_config=tool_config, + ) + await self.stop_ttfb_metrics() + + if response.usage_metadata: + # Use only the prompt token count from the response object + prompt_tokens = response.usage_metadata.prompt_token_count + total_tokens = prompt_tokens + + async for chunk in response: + if chunk.usage_metadata: + # Use only the completion_tokens from the chunks. Prompt tokens are already counted and + # are repeated here. + completion_tokens += chunk.usage_metadata.candidates_token_count + total_tokens += chunk.usage_metadata.candidates_token_count + try: + for c in chunk.parts: + if c.text: + search_result += c.text + await self.push_frame(LLMTextFrame(c.text)) + elif c.function_call: + logger.debug(f"Function call: {c.function_call}") + args = type(c.function_call).to_dict(c.function_call).get("args", {}) + await self.call_function( + context=context, + tool_call_id=str(uuid.uuid4()), + function_name=c.function_call.name, + arguments=args, + ) + # Handle grounding metadata + # It seems only the last chunk that we receive may contain this information + # If the response doesn't include groundingMetadata, this means the response wasn't grounded. 
+ if chunk.candidates: + for candidate in chunk.candidates: + # logger.debug(f"candidate received: {candidate}") + # Extract grounding metadata + grounding_metadata = ( + { + "rendered_content": getattr( + getattr(candidate, "grounding_metadata", None), + "search_entry_point", + None, + ).rendered_content + if hasattr( + getattr(candidate, "grounding_metadata", None), + "search_entry_point", + ) + else None, + "origins": [ + { + "site_uri": getattr(grounding_chunk.web, "uri", None), + "site_title": getattr( + grounding_chunk.web, "title", None + ), + "results": [ + { + "text": getattr( + grounding_support.segment, "text", "" + ), + "confidence": getattr( + grounding_support, "confidence_scores", None + ), + } + for grounding_support in getattr( + getattr(candidate, "grounding_metadata", None), + "grounding_supports", + [], + ) + if index + in getattr( + grounding_support, "grounding_chunk_indices", [] + ) + ], + } + for index, grounding_chunk in enumerate( + getattr( + getattr(candidate, "grounding_metadata", None), + "grounding_chunks", + [], + ) + ) + ], + } + if getattr(candidate, "grounding_metadata", None) + else None + ) + except Exception as e: + # Google LLMs seem to flag safety issues a lot! + if chunk.candidates[0].finish_reason == 3: + logger.debug( + f"LLM refused to generate content for safety reasons - {messages}." 
+ ) + else: + logger.exception(f"{self} error: {e}") + + except DeadlineExceeded: + await self._call_event_handler("on_completion_timeout") + except Exception as e: + logger.exception(f"{self} exception: {e}") + finally: + if grounding_metadata is not None and isinstance(grounding_metadata, dict): + llm_search_frame = LLMSearchResponseFrame( + search_result=search_result, + origins=grounding_metadata["origins"], + rendered_content=grounding_metadata["rendered_content"], + ) + await self.push_frame(llm_search_frame) + + await self.start_llm_usage_metrics( + LLMTokenUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + ) + ) + await self.push_frame(LLMFullResponseEndFrame()) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + context = None + + if isinstance(frame, OpenAILLMContextFrame): + context = GoogleLLMContext.upgrade_to_google(frame.context) + elif isinstance(frame, LLMMessagesFrame): + context = GoogleLLMContext(frame.messages) + elif isinstance(frame, VisionImageRawFrame): + context = GoogleLLMContext() + context.add_image_frame_message( + format=frame.format, size=frame.size, image=frame.image, text=frame.text + ) + elif isinstance(frame, LLMUpdateSettingsFrame): + await self._update_settings(frame.settings) + else: + await self.push_frame(frame, direction) + + if context: + await self._process_context(context) + + def create_context_aggregator( + self, + context: OpenAILLMContext, + *, + user_kwargs: Mapping[str, Any] = {}, + assistant_kwargs: Mapping[str, Any] = {}, + ) -> GoogleContextAggregatorPair: + """Create an instance of GoogleContextAggregatorPair from an + OpenAILLMContext. Constructor keyword arguments for both the user and + assistant aggregators can be provided. + + Args: + context (OpenAILLMContext): The LLM context. 
+ user_kwargs (Mapping[str, Any], optional): Additional keyword + arguments for the user context aggregator constructor. Defaults + to an empty mapping. + assistant_kwargs (Mapping[str, Any], optional): Additional keyword + arguments for the assistant context aggregator + constructor. Defaults to an empty mapping. + + Returns: + GoogleContextAggregatorPair: A pair of context aggregators, one for + the user and one for the assistant, encapsulated in an + GoogleContextAggregatorPair. + + """ + context.set_llm_adapter(self.get_llm_adapter()) + + if isinstance(context, OpenAILLMContext): + context = GoogleLLMContext.upgrade_to_google(context) + user = GoogleUserContextAggregator(context, **user_kwargs) + assistant = GoogleAssistantContextAggregator(context, **assistant_kwargs) + return GoogleContextAggregatorPair(_user=user, _assistant=assistant) diff --git a/src/pipecat/services/google/llm_openai.py b/src/pipecat/services/google/llm_openai.py new file mode 100644 index 000000000..94b99072a --- /dev/null +++ b/src/pipecat/services/google/llm_openai.py @@ -0,0 +1,136 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import json +import os + +from openai import AsyncStream +from openai.types.chat import ChatCompletionChunk + +# Suppress gRPC fork warnings +os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false" + +from loguru import logger + +from pipecat.frames.frames import LLMTextFrame +from pipecat.metrics.metrics import LLMTokenUsage +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext +from pipecat.services.openai.base_llm import OpenAIUnhandledFunctionException +from pipecat.services.openai.llm import OpenAILLMService + + +class GoogleLLMOpenAIBetaService(OpenAILLMService): + """This class implements inference with Google's AI LLM models using the OpenAI format. 
+ Ref - https://ai.google.dev/gemini-api/docs/openai + """ + + def __init__( + self, + *, + api_key: str, + base_url: str = "https://generativelanguage.googleapis.com/v1beta/openai/", + model: str = "gemini-2.0-flash", + **kwargs, + ): + super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs) + + async def _process_context(self, context: OpenAILLMContext): + functions_list = [] + arguments_list = [] + tool_id_list = [] + func_idx = 0 + function_name = "" + arguments = "" + tool_call_id = "" + + await self.start_ttfb_metrics() + + chunk_stream: AsyncStream[ChatCompletionChunk] = await self._stream_chat_completions( + context + ) + + async for chunk in chunk_stream: + if chunk.usage: + tokens = LLMTokenUsage( + prompt_tokens=chunk.usage.prompt_tokens, + completion_tokens=chunk.usage.completion_tokens, + total_tokens=chunk.usage.total_tokens, + ) + await self.start_llm_usage_metrics(tokens) + + if chunk.choices is None or len(chunk.choices) == 0: + continue + + await self.stop_ttfb_metrics() + + if not chunk.choices[0].delta: + continue + + if chunk.choices[0].delta.tool_calls: + # We're streaming the LLM response to enable the fastest response times. + # For text, we just yield each chunk as we receive it and count on consumers + # to do whatever coalescing they need (eg. to pass full sentences to TTS) + # + # If the LLM is a function call, we'll do some coalescing here. + # If the response contains a function name, we'll yield a frame to tell consumers + # that they can start preparing to call the function with that name. + # We accumulate all the arguments for the rest of the streamed response, then when + # the response is done, we package up all the arguments and the function name and + # yield a frame containing the function name and the arguments. 
+ logger.debug(f"Tool call: {chunk.choices[0].delta.tool_calls}") + tool_call = chunk.choices[0].delta.tool_calls[0] + if tool_call.index != func_idx: + functions_list.append(function_name) + arguments_list.append(arguments) + tool_id_list.append(tool_call_id) + function_name = "" + arguments = "" + tool_call_id = "" + func_idx += 1 + if tool_call.function and tool_call.function.name: + function_name += tool_call.function.name + tool_call_id = tool_call.id + if tool_call.function and tool_call.function.arguments: + # Keep iterating through the response to collect all the argument fragments + arguments += tool_call.function.arguments + elif chunk.choices[0].delta.content: + await self.push_frame(LLMTextFrame(chunk.choices[0].delta.content)) + + # if we got a function name and arguments, check to see if it's a function with + # a registered handler. If so, run the registered callback, save the result to + # the context, and re-prompt to get a chat answer. If we don't have a registered + # handler, raise an exception. + if function_name and arguments: + # added to the list as last function name and arguments not added to the list + functions_list.append(function_name) + arguments_list.append(arguments) + tool_id_list.append(tool_call_id) + + logger.debug( + f"Function list: {functions_list}, Arguments list: {arguments_list}, Tool ID list: {tool_id_list}" + ) + for index, (function_name, arguments, tool_id) in enumerate( + zip(functions_list, arguments_list, tool_id_list), start=1 + ): + if function_name == "": + # TODO: Remove the _process_context method once Google resolves the bug + # where the index is incorrectly set to None instead of returning the actual index, + # which currently results in an empty function name(''). 
class GoogleVertexLLMService(OpenAILLMService):
    """Run inference on Google models through Vertex AI's OpenAI-compatible API.

    The service exchanges Google service-account credentials for an OAuth
    access token, then configures the parent ``OpenAILLMService`` with that
    token as the API key and the Vertex AI ``endpoints/openapi`` base URL.

    Reference:
    https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/call-vertex-using-openai-library

    """

    # OAuth scope requested for the service-account token.
    _SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]

    class InputParams(OpenAILLMService.InputParams):
        """Input parameters specific to Vertex AI."""

        # https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations
        location: str = "us-east4"
        project_id: str

    def __init__(
        self,
        *,
        credentials: Optional[str] = None,
        credentials_path: Optional[str] = None,
        model: str = "google/gemini-2.0-flash-001",
        params: Optional[InputParams] = None,
        **kwargs,
    ):
        """Initializes the VertexLLMService.

        Args:
            credentials (Optional[str]): JSON string of service account credentials.
            credentials_path (Optional[str]): Path to the service account JSON file.
            model (str): Model identifier. Defaults to "google/gemini-2.0-flash-001".
            params (InputParams): Vertex AI input parameters. Required in
                practice, because ``project_id`` has no default value.
            **kwargs: Additional arguments for OpenAILLMService.

        Raises:
            ValueError: If ``params`` is missing or carries no ``project_id``,
                or if no usable credentials are supplied.
        """
        # The base URL cannot be built without a project id. Fail early with an
        # actionable message instead of an AttributeError from `_get_base_url`.
        if params is None or not getattr(params, "project_id", None):
            raise ValueError(
                "GoogleVertexLLMService requires `params` to be a "
                "GoogleVertexLLMService.InputParams with `project_id` set."
            )

        base_url = self._get_base_url(params)
        self._api_key = self._get_api_token(credentials, credentials_path)

        # Forward `params` so caller-supplied generation options reach the
        # parent service instead of being silently dropped.
        super().__init__(
            api_key=self._api_key, base_url=base_url, model=model, params=params, **kwargs
        )

    @staticmethod
    def _get_base_url(params: InputParams) -> str:
        """Constructs the base URL for Vertex AI API."""
        return (
            f"https://{params.location}-aiplatform.googleapis.com/v1/"
            f"projects/{params.project_id}/locations/{params.location}/endpoints/openapi"
        )

    @staticmethod
    def _get_api_token(credentials: Optional[str], credentials_path: Optional[str]) -> str:
        """Retrieves an authentication token using Google service account credentials.

        ``credentials`` takes precedence over ``credentials_path`` when both
        are given.

        Args:
            credentials (Optional[str]): JSON string of service account credentials.
            credentials_path (Optional[str]): Path to the service account JSON file.

        Returns:
            str: OAuth token for API authentication.

        Raises:
            ValueError: If neither argument yields credentials.
        """
        creds: Optional[service_account.Credentials] = None
        scopes = GoogleVertexLLMService._SCOPES

        if credentials:
            # Parse and load credentials from JSON string
            creds = service_account.Credentials.from_service_account_info(
                json.loads(credentials), scopes=scopes
            )
        elif credentials_path:
            # Load credentials from JSON file
            creds = service_account.Credentials.from_service_account_file(
                credentials_path, scopes=scopes
            )

        if not creds:
            raise ValueError("No valid credentials provided.")

        # Fetch a fresh token now; its lifetime is 1 hour and nothing in this
        # class refreshes it afterwards.
        creds.refresh(Request())

        return creds.token
# Lookup table from Language enum members to Google Speech-to-Text V2 language
# codes. Built once at import time rather than on every lookup.
_GOOGLE_STT_LANGUAGE_MAP = {
    # Afrikaans
    Language.AF: "af-ZA",
    Language.AF_ZA: "af-ZA",
    # Albanian
    Language.SQ: "sq-AL",
    Language.SQ_AL: "sq-AL",
    # Amharic
    Language.AM: "am-ET",
    Language.AM_ET: "am-ET",
    # Arabic
    Language.AR: "ar-EG",  # Default to Egypt
    Language.AR_AE: "ar-AE",
    Language.AR_BH: "ar-BH",
    Language.AR_DZ: "ar-DZ",
    Language.AR_EG: "ar-EG",
    Language.AR_IQ: "ar-IQ",
    Language.AR_JO: "ar-JO",
    Language.AR_KW: "ar-KW",
    Language.AR_LB: "ar-LB",
    Language.AR_MA: "ar-MA",
    Language.AR_OM: "ar-OM",
    Language.AR_QA: "ar-QA",
    Language.AR_SA: "ar-SA",
    Language.AR_SY: "ar-SY",
    Language.AR_TN: "ar-TN",
    Language.AR_YE: "ar-YE",
    # Armenian
    Language.HY: "hy-AM",
    Language.HY_AM: "hy-AM",
    # Azerbaijani
    Language.AZ: "az-AZ",
    Language.AZ_AZ: "az-AZ",
    # Basque
    Language.EU: "eu-ES",
    Language.EU_ES: "eu-ES",
    # Bengali
    Language.BN: "bn-IN",  # Default to India
    Language.BN_BD: "bn-BD",
    Language.BN_IN: "bn-IN",
    # Bosnian
    Language.BS: "bs-BA",
    Language.BS_BA: "bs-BA",
    # Bulgarian
    Language.BG: "bg-BG",
    Language.BG_BG: "bg-BG",
    # Burmese
    Language.MY: "my-MM",
    Language.MY_MM: "my-MM",
    # Catalan
    Language.CA: "ca-ES",
    Language.CA_ES: "ca-ES",
    # Chinese
    Language.ZH: "cmn-Hans-CN",  # Default to Simplified Chinese
    Language.ZH_CN: "cmn-Hans-CN",
    Language.ZH_HK: "cmn-Hans-HK",
    Language.ZH_TW: "cmn-Hant-TW",
    Language.YUE: "yue-Hant-HK",  # Cantonese
    Language.YUE_CN: "yue-Hant-HK",
    # Croatian
    Language.HR: "hr-HR",
    Language.HR_HR: "hr-HR",
    # Czech
    Language.CS: "cs-CZ",
    Language.CS_CZ: "cs-CZ",
    # Danish
    Language.DA: "da-DK",
    Language.DA_DK: "da-DK",
    # Dutch
    Language.NL: "nl-NL",  # Default to Netherlands
    Language.NL_BE: "nl-BE",
    Language.NL_NL: "nl-NL",
    # English
    Language.EN: "en-US",  # Default to US
    Language.EN_AU: "en-AU",
    Language.EN_CA: "en-CA",
    Language.EN_GB: "en-GB",
    Language.EN_GH: "en-GH",
    Language.EN_HK: "en-HK",
    Language.EN_IN: "en-IN",
    Language.EN_IE: "en-IE",
    Language.EN_KE: "en-KE",
    Language.EN_NG: "en-NG",
    Language.EN_NZ: "en-NZ",
    Language.EN_PH: "en-PH",
    Language.EN_SG: "en-SG",
    Language.EN_TZ: "en-TZ",
    Language.EN_US: "en-US",
    Language.EN_ZA: "en-ZA",
    # Estonian
    Language.ET: "et-EE",
    Language.ET_EE: "et-EE",
    # Filipino
    Language.FIL: "fil-PH",
    Language.FIL_PH: "fil-PH",
    # Finnish
    Language.FI: "fi-FI",
    Language.FI_FI: "fi-FI",
    # French
    Language.FR: "fr-FR",  # Default to France
    Language.FR_BE: "fr-BE",
    Language.FR_CA: "fr-CA",
    Language.FR_CH: "fr-CH",
    Language.FR_FR: "fr-FR",
    # Galician
    Language.GL: "gl-ES",
    Language.GL_ES: "gl-ES",
    # Georgian
    Language.KA: "ka-GE",
    Language.KA_GE: "ka-GE",
    # German
    Language.DE: "de-DE",  # Default to Germany
    Language.DE_AT: "de-AT",
    Language.DE_CH: "de-CH",
    Language.DE_DE: "de-DE",
    # Greek
    Language.EL: "el-GR",
    Language.EL_GR: "el-GR",
    # Gujarati
    Language.GU: "gu-IN",
    Language.GU_IN: "gu-IN",
    # Hebrew
    Language.HE: "iw-IL",
    Language.HE_IL: "iw-IL",
    # Hindi
    Language.HI: "hi-IN",
    Language.HI_IN: "hi-IN",
    # Hungarian
    Language.HU: "hu-HU",
    Language.HU_HU: "hu-HU",
    # Icelandic
    Language.IS: "is-IS",
    Language.IS_IS: "is-IS",
    # Indonesian
    Language.ID: "id-ID",
    Language.ID_ID: "id-ID",
    # Italian
    Language.IT: "it-IT",
    Language.IT_IT: "it-IT",
    Language.IT_CH: "it-CH",
    # Japanese
    Language.JA: "ja-JP",
    Language.JA_JP: "ja-JP",
    # Javanese
    Language.JV: "jv-ID",
    Language.JV_ID: "jv-ID",
    # Kannada
    Language.KN: "kn-IN",
    Language.KN_IN: "kn-IN",
    # Kazakh
    Language.KK: "kk-KZ",
    Language.KK_KZ: "kk-KZ",
    # Khmer
    Language.KM: "km-KH",
    Language.KM_KH: "km-KH",
    # Korean
    Language.KO: "ko-KR",
    Language.KO_KR: "ko-KR",
    # Lao
    Language.LO: "lo-LA",
    Language.LO_LA: "lo-LA",
    # Latvian
    Language.LV: "lv-LV",
    Language.LV_LV: "lv-LV",
    # Lithuanian
    Language.LT: "lt-LT",
    Language.LT_LT: "lt-LT",
    # Macedonian
    Language.MK: "mk-MK",
    Language.MK_MK: "mk-MK",
    # Malay
    Language.MS: "ms-MY",
    Language.MS_MY: "ms-MY",
    # Malayalam
    Language.ML: "ml-IN",
    Language.ML_IN: "ml-IN",
    # Marathi
    Language.MR: "mr-IN",
    Language.MR_IN: "mr-IN",
    # Mongolian
    Language.MN: "mn-MN",
    Language.MN_MN: "mn-MN",
    # Nepali
    Language.NE: "ne-NP",
    Language.NE_NP: "ne-NP",
    # Norwegian
    Language.NO: "no-NO",
    Language.NB: "no-NO",
    Language.NB_NO: "no-NO",
    # Persian
    Language.FA: "fa-IR",
    Language.FA_IR: "fa-IR",
    # Polish
    Language.PL: "pl-PL",
    Language.PL_PL: "pl-PL",
    # Portuguese
    Language.PT: "pt-PT",  # Default to Portugal
    Language.PT_BR: "pt-BR",
    Language.PT_PT: "pt-PT",
    # Punjabi
    Language.PA: "pa-Guru-IN",
    Language.PA_IN: "pa-Guru-IN",
    # Romanian
    Language.RO: "ro-RO",
    Language.RO_RO: "ro-RO",
    # Russian
    Language.RU: "ru-RU",
    Language.RU_RU: "ru-RU",
    # Serbian
    Language.SR: "sr-RS",
    Language.SR_RS: "sr-RS",
    # Sinhala
    Language.SI: "si-LK",
    Language.SI_LK: "si-LK",
    # Slovak
    Language.SK: "sk-SK",
    Language.SK_SK: "sk-SK",
    # Slovenian
    Language.SL: "sl-SI",
    Language.SL_SI: "sl-SI",
    # Spanish
    Language.ES: "es-ES",  # Default to Spain
    Language.ES_AR: "es-AR",
    Language.ES_BO: "es-BO",
    Language.ES_CL: "es-CL",
    Language.ES_CO: "es-CO",
    Language.ES_CR: "es-CR",
    Language.ES_DO: "es-DO",
    Language.ES_EC: "es-EC",
    Language.ES_ES: "es-ES",
    Language.ES_GT: "es-GT",
    Language.ES_HN: "es-HN",
    Language.ES_MX: "es-MX",
    Language.ES_NI: "es-NI",
    Language.ES_PA: "es-PA",
    Language.ES_PE: "es-PE",
    Language.ES_PR: "es-PR",
    Language.ES_PY: "es-PY",
    Language.ES_SV: "es-SV",
    Language.ES_US: "es-US",
    Language.ES_UY: "es-UY",
    Language.ES_VE: "es-VE",
    # Sundanese
    Language.SU: "su-ID",
    Language.SU_ID: "su-ID",
    # Swahili
    Language.SW: "sw-TZ",  # Default to Tanzania
    Language.SW_KE: "sw-KE",
    Language.SW_TZ: "sw-TZ",
    # Swedish
    Language.SV: "sv-SE",
    Language.SV_SE: "sv-SE",
    # Tamil
    Language.TA: "ta-IN",  # Default to India
    Language.TA_IN: "ta-IN",
    Language.TA_MY: "ta-MY",
    Language.TA_SG: "ta-SG",
    Language.TA_LK: "ta-LK",
    # Telugu
    Language.TE: "te-IN",
    Language.TE_IN: "te-IN",
    # Thai
    Language.TH: "th-TH",
    Language.TH_TH: "th-TH",
    # Turkish
    Language.TR: "tr-TR",
    Language.TR_TR: "tr-TR",
    # Ukrainian
    Language.UK: "uk-UA",
    Language.UK_UA: "uk-UA",
    # Urdu
    Language.UR: "ur-IN",  # Default to India
    Language.UR_IN: "ur-IN",
    Language.UR_PK: "ur-PK",
    # Uzbek
    Language.UZ: "uz-UZ",
    Language.UZ_UZ: "uz-UZ",
    # Vietnamese
    Language.VI: "vi-VN",
    Language.VI_VN: "vi-VN",
    # Xhosa
    Language.XH: "xh-ZA",
    # Zulu
    Language.ZU: "zu-ZA",
    Language.ZU_ZA: "zu-ZA",
}


def language_to_google_stt_language(language: Language) -> Optional[str]:
    """Maps Language enum to Google Speech-to-Text V2 language codes.

    Region-less members (e.g. ``Language.EN``) resolve to the default region
    noted next to their entry in the lookup table.

    Args:
        language: Language enum value.

    Returns:
        Optional[str]: Google STT language code or None if not supported.
    """
    return _GOOGLE_STT_LANGUAGE_MAP.get(language)
def __init__(
    self,
    *,
    credentials: Optional[str] = None,
    credentials_path: Optional[str] = None,
    location: str = "global",
    sample_rate: Optional[int] = None,
    params: InputParams = InputParams(),
    **kwargs,
):
    """Set up state, credentials and the async Speech client for Google STT.

    Args:
        credentials: JSON string containing Google Cloud service account credentials.
            Used in preference to ``credentials_path`` when both are given.
        credentials_path: Path to service account credentials JSON file.
        location: Google Cloud location (e.g., "global", "us-central1").
        sample_rate: Audio sample rate in Hertz.
        params: Configuration parameters for the service.
        **kwargs: Additional arguments passed to STTService.

    Raises:
        ValueError: If neither credentials nor credentials_path is provided.
        ValueError: If project ID is not found in credentials.
    """
    super().__init__(sample_rate=sample_rate, **kwargs)

    # Connection state
    self._location = location
    self._stream = None
    self._config = None
    self._request_queue = asyncio.Queue()
    self._streaming_task = None

    # Bookkeeping used by the keep-alive / stream restart logic
    self._stream_start_time = 0
    self._last_audio_input = []
    self._audio_input = []
    self._result_end_time = 0
    self._is_final_end_time = 0
    self._final_request_end_time = 0
    self._bridging_offset = 0
    self._last_transcript_was_final = False
    self._new_stream = True
    self._restart_counter = 0

    # Only non-global locations need an explicit regional endpoint
    client_options = (
        ClientOptions(api_endpoint=f"{self._location}-speech.googleapis.com")
        if self._location != "global"
        else None
    )

    # Read the service account info first so the project ID can be extracted
    if credentials:
        account_info = json.loads(credentials)
    elif credentials_path:
        with open(credentials_path) as creds_file:
            account_info = json.load(creds_file)
    else:
        raise ValueError("Either credentials or credentials_path must be provided")

    self._project_id = account_info.get("project_id")

    # Then build the credentials object from the same source
    if credentials:
        creds = service_account.Credentials.from_service_account_info(account_info)
    else:
        creds = service_account.Credentials.from_service_account_file(credentials_path)

    if not self._project_id:
        raise ValueError("Project ID not found in credentials")

    self._client = speech_v2.SpeechAsyncClient(credentials=creds, client_options=client_options)

    # Settings copied verbatim from `params`; language codes are converted below
    copied_fields = (
        "model",
        "use_separate_recognition_per_channel",
        "enable_automatic_punctuation",
        "enable_spoken_punctuation",
        "enable_spoken_emojis",
        "profanity_filter",
        "enable_word_time_offsets",
        "enable_word_confidence",
        "enable_interim_results",
        "enable_voice_activity_events",
    )
    self._settings = {
        "language_codes": [
            self.language_to_service_language(lang) for lang in params.language_list
        ],
        **{field_name: getattr(params, field_name) for field_name in copied_fields},
    }
async def update_options(
    self,
    *,
    languages: Optional[List[Language]] = None,
    model: Optional[str] = None,
    enable_automatic_punctuation: Optional[bool] = None,
    enable_spoken_punctuation: Optional[bool] = None,
    enable_spoken_emojis: Optional[bool] = None,
    profanity_filter: Optional[bool] = None,
    enable_word_time_offsets: Optional[bool] = None,
    enable_word_confidence: Optional[bool] = None,
    enable_interim_results: Optional[bool] = None,
    enable_voice_activity_events: Optional[bool] = None,
    location: Optional[str] = None,
) -> None:
    """Apply new option values, then reconnect the stream if one is active.

    Any argument left as ``None`` keeps its current value.

    Args:
        languages: New list of recognition languages.
        model: New recognition model.
        enable_automatic_punctuation: Enable/disable automatic punctuation.
        enable_spoken_punctuation: Enable/disable spoken punctuation.
        enable_spoken_emojis: Enable/disable spoken emojis.
        profanity_filter: Enable/disable profanity filter.
        enable_word_time_offsets: Enable/disable word timing info.
        enable_word_confidence: Enable/disable word confidence scores.
        enable_interim_results: Enable/disable interim results.
        enable_voice_activity_events: Enable/disable voice activity detection.
        location: New Google Cloud location.

    Note:
        The reconnect helper is invoked unconditionally at the end, even when
        every argument is ``None``.
    """
    # Languages need converting to service codes, so they are handled apart
    if languages is not None:
        logger.debug(f"Updating language to: {languages}")
        self._settings["language_codes"] = [
            self.language_to_service_language(lang) for lang in languages
        ]

    # (settings key, new value, log line) for options stored as-is
    plain_updates = (
        ("model", model, f"Updating model to: {model}"),
        (
            "enable_automatic_punctuation",
            enable_automatic_punctuation,
            f"Updating automatic punctuation to: {enable_automatic_punctuation}",
        ),
        (
            "enable_spoken_punctuation",
            enable_spoken_punctuation,
            f"Updating spoken punctuation to: {enable_spoken_punctuation}",
        ),
        (
            "enable_spoken_emojis",
            enable_spoken_emojis,
            f"Updating spoken emojis to: {enable_spoken_emojis}",
        ),
        (
            "profanity_filter",
            profanity_filter,
            f"Updating profanity filter to: {profanity_filter}",
        ),
        (
            "enable_word_time_offsets",
            enable_word_time_offsets,
            f"Updating word time offsets to: {enable_word_time_offsets}",
        ),
        (
            "enable_word_confidence",
            enable_word_confidence,
            f"Updating word confidence to: {enable_word_confidence}",
        ),
        (
            "enable_interim_results",
            enable_interim_results,
            f"Updating interim results to: {enable_interim_results}",
        ),
        (
            "enable_voice_activity_events",
            enable_voice_activity_events,
            f"Updating voice activity events to: {enable_voice_activity_events}",
        ),
    )
    for settings_key, new_value, log_line in plain_updates:
        if new_value is None:
            continue
        logger.debug(log_line)
        self._settings[settings_key] = new_value

    # NOTE(review): only `self._location` changes here; the client object is
    # not rebuilt in this method — confirm that is sufficient for a new region.
    if location is not None:
        logger.debug(f"Updating location to: {location}")
        self._location = location

    # Reconnect the stream for updates
    await self._reconnect_if_needed()
async def _request_generator(self):
    """Yield the requests fed to the streaming recognize call.

    The first request carries the recognizer path and streaming config; each
    later request carries one audio chunk taken from ``self._request_queue``.
    The generator ends when it dequeues the ``None`` sentinel, when the stream
    has been open longer than ``STREAMING_LIMIT`` (the chunk in hand is put
    back on the queue first), or when it is cancelled.

    Raises:
        Exception: Any non-cancellation error is logged and re-raised.
    """
    recognizer_path = f"projects/{self._project_id}/locations/{self._location}/recognizers/_"
    logger.trace(f"Using recognizer path: {recognizer_path}")

    try:
        # Send initial config
        yield cloud_speech.StreamingRecognizeRequest(
            recognizer=recognizer_path,
            streaming_config=self._config,
        )

        while True:
            try:
                audio_data = await self._request_queue.get()
            except asyncio.CancelledError:
                # Cancelled while waiting: nothing was dequeued, so there is
                # no task_done() to balance. Calling it here would unbalance
                # the queue counter and can raise ValueError.
                break

            # An item was dequeued; from here on it gets exactly one task_done().
            try:
                if audio_data is None:  # Sentinel value to stop
                    break

                # Check streaming limit
                if (int(time.time() * 1000) - self._stream_start_time) > self.STREAMING_LIMIT:
                    logger.debug("Streaming limit reached, initiating graceful reconnection")
                    # Instead of immediate reconnection, we'll break and let the stream close naturally
                    self._last_audio_input = self._audio_input
                    self._audio_input = []
                    self._restart_counter += 1
                    # Put the current audio chunk back in the queue
                    await self._request_queue.put(audio_data)
                    break

                self._audio_input.append(audio_data)
                yield cloud_speech.StreamingRecognizeRequest(audio=audio_data)

            except asyncio.CancelledError:
                break
            finally:
                self._request_queue.task_done()

    except Exception as e:
        logger.error(f"Error in request generator: {e}")
        raise
async def _process_responses(self, streaming_recognize):
    """Turn streaming recognition responses into transcription frames.

    For every result with a non-empty top transcript, pushes a
    ``TranscriptionFrame`` (final) or an ``InterimTranscriptionFrame``
    (interim), tagged with the first configured language code. Stops reading
    once the stream has been open longer than ``STREAMING_LIMIT``.

    Args:
        streaming_recognize: Async iterable of recognition responses.

    Raises:
        Exception: Any error is logged and re-raised to the caller.
    """
    try:
        async for response in streaming_recognize:
            # Check streaming limit
            elapsed_ms = int(time.time() * 1000) - self._stream_start_time
            if elapsed_ms > self.STREAMING_LIMIT:
                logger.debug("Stream timeout reached in response processing")
                break

            for result in response.results or ():
                alternatives = result.alternatives
                if not alternatives:
                    continue

                transcript = alternatives[0].transcript
                if not transcript:
                    continue

                primary_language = self._settings["language_codes"][0]

                is_final = bool(result.is_final)
                self._last_transcript_was_final = is_final
                frame_type = TranscriptionFrame if is_final else InterimTranscriptionFrame
                await self.push_frame(
                    frame_type(transcript, "", time_now_iso8601(), primary_language)
                )

    except Exception as e:
        logger.error(f"Error processing Google STT responses: {e}")

        # Let the error propagate so the caller can decide how to recover
        # (e.g. reconnect after a timeout).
        raise
+ ) + raise Exception(f"Missing module: {e}") + + +def language_to_google_tts_language(language: Language) -> Optional[str]: + language_map = { + # Afrikaans + Language.AF: "af-ZA", + Language.AF_ZA: "af-ZA", + # Arabic + Language.AR: "ar-XA", + # Bengali + Language.BN: "bn-IN", + Language.BN_IN: "bn-IN", + # Bulgarian + Language.BG: "bg-BG", + Language.BG_BG: "bg-BG", + # Catalan + Language.CA: "ca-ES", + Language.CA_ES: "ca-ES", + # Chinese (Mandarin and Cantonese) + Language.ZH: "cmn-CN", + Language.ZH_CN: "cmn-CN", + Language.ZH_TW: "cmn-TW", + Language.ZH_HK: "yue-HK", + # Czech + Language.CS: "cs-CZ", + Language.CS_CZ: "cs-CZ", + # Danish + Language.DA: "da-DK", + Language.DA_DK: "da-DK", + # Dutch + Language.NL: "nl-NL", + Language.NL_BE: "nl-BE", + Language.NL_NL: "nl-NL", + # English + Language.EN: "en-US", + Language.EN_US: "en-US", + Language.EN_AU: "en-AU", + Language.EN_GB: "en-GB", + Language.EN_IN: "en-IN", + # Estonian + Language.ET: "et-EE", + Language.ET_EE: "et-EE", + # Filipino + Language.FIL: "fil-PH", + Language.FIL_PH: "fil-PH", + # Finnish + Language.FI: "fi-FI", + Language.FI_FI: "fi-FI", + # French + Language.FR: "fr-FR", + Language.FR_CA: "fr-CA", + Language.FR_FR: "fr-FR", + # Galician + Language.GL: "gl-ES", + Language.GL_ES: "gl-ES", + # German + Language.DE: "de-DE", + Language.DE_DE: "de-DE", + # Greek + Language.EL: "el-GR", + Language.EL_GR: "el-GR", + # Gujarati + Language.GU: "gu-IN", + Language.GU_IN: "gu-IN", + # Hebrew + Language.HE: "he-IL", + Language.HE_IL: "he-IL", + # Hindi + Language.HI: "hi-IN", + Language.HI_IN: "hi-IN", + # Hungarian + Language.HU: "hu-HU", + Language.HU_HU: "hu-HU", + # Icelandic + Language.IS: "is-IS", + Language.IS_IS: "is-IS", + # Indonesian + Language.ID: "id-ID", + Language.ID_ID: "id-ID", + # Italian + Language.IT: "it-IT", + Language.IT_IT: "it-IT", + # Japanese + Language.JA: "ja-JP", + Language.JA_JP: "ja-JP", + # Kannada + Language.KN: "kn-IN", + Language.KN_IN: "kn-IN", + # Korean + 
def __init__(
    self,
    *,
    credentials: Optional[str] = None,
    credentials_path: Optional[str] = None,
    voice_id: str = "en-US-Neural2-A",
    sample_rate: Optional[int] = None,
    params: InputParams = InputParams(),
    **kwargs,
):
    """Store synthesis settings, select the voice and create the TTS client.

    Args:
        credentials: JSON string with service account credentials.
        credentials_path: Path to a service account JSON file.
        voice_id: Name of the Google voice to use.
        sample_rate: Output audio sample rate in Hertz.
        params: Voice and prosody options.
        **kwargs: Additional arguments passed to TTSService.
    """
    super().__init__(sample_rate=sample_rate, **kwargs)

    # Fall back to US English when no language is configured
    if params.language:
        language_code = self.language_to_service_language(params.language)
    else:
        language_code = "en-US"

    self._settings = {
        "pitch": params.pitch,
        "rate": params.rate,
        "volume": params.volume,
        "emphasis": params.emphasis,
        "language": language_code,
        "gender": params.gender,
        "google_style": params.google_style,
    }
    self.set_voice(voice_id)

    tts_client = self._create_client(credentials, credentials_path)
    self._client: texttospeech_v1.TextToSpeechAsyncClient = tts_client
def _construct_ssml(self, text: str) -> str:
    """Wrap ``text`` in SSML markup built from the voice and current settings.

    The result is always enclosed in ``<speak><voice ...>``. ``<prosody>``,
    ``<emphasis>`` and ``<google:style>`` elements are added only when the
    matching settings are set, nested in that order around the text.

    Args:
        text: Text to speak. It is inserted as-is (no XML escaping).

    Returns:
        The SSML document as a string.
    """
    ssml = "<speak>"

    # Voice tag
    voice_attrs = [f"name='{self._voice_id}'"]

    language = self._settings["language"]
    voice_attrs.append(f"language='{language}'")

    if self._settings["gender"]:
        voice_attrs.append(f"gender='{self._settings['gender']}'")
    ssml += f"<voice {' '.join(voice_attrs)}>"

    # Prosody tag
    prosody_attrs = []
    if self._settings["pitch"]:
        prosody_attrs.append(f"pitch='{self._settings['pitch']}'")
    if self._settings["rate"]:
        prosody_attrs.append(f"rate='{self._settings['rate']}'")
    if self._settings["volume"]:
        prosody_attrs.append(f"volume='{self._settings['volume']}'")

    if prosody_attrs:
        ssml += f"<prosody {' '.join(prosody_attrs)}>"

    # Emphasis tag
    if self._settings["emphasis"]:
        ssml += f"<emphasis level='{self._settings['emphasis']}'>"

    # Google style tag
    if self._settings["google_style"]:
        ssml += f"<google:style name='{self._settings['google_style']}'>"

    ssml += text

    # Close tags in reverse order of opening
    if self._settings["google_style"]:
        ssml += "</google:style>"
    if self._settings["emphasis"]:
        ssml += "</emphasis>"
    if prosody_attrs:
        ssml += "</prosody>"
    ssml += "</voice></speak>"

    return ssml
voice_id + if is_chirp_voice or is_journey_voice: + # Chirp and Journey voices don't support SSML, use plain text + synthesis_input = texttospeech_v1.SynthesisInput(text=text) + else: + ssml = self._construct_ssml(text) + synthesis_input = texttospeech_v1.SynthesisInput(ssml=ssml) + + voice = texttospeech_v1.VoiceSelectionParams( + language_code=self._settings["language"], name=self._voice_id + ) + audio_config = texttospeech_v1.AudioConfig( + audio_encoding=texttospeech_v1.AudioEncoding.LINEAR16, + sample_rate_hertz=self.sample_rate, + ) + + request = texttospeech_v1.SynthesizeSpeechRequest( + input=synthesis_input, voice=voice, audio_config=audio_config + ) + + response = await self._client.synthesize_speech(request=request) + + await self.start_tts_usage_metrics(text) + + yield TTSStartedFrame() + + # Skip the first 44 bytes to remove the WAV header + audio_content = response.audio_content[44:] + + # Read and yield audio data in chunks + chunk_size = 8192 + for i in range(0, len(audio_content), chunk_size): + chunk = audio_content[i : i + chunk_size] + if not chunk: + break + await self.stop_ttfb_metrics() + frame = TTSAudioRawFrame(chunk, self.sample_rate, 1) + yield frame + await asyncio.sleep(0) # Allow other tasks to run + + yield TTSStoppedFrame() + + except Exception as e: + logger.exception(f"{self} error generating TTS: {e}") + error_message = f"TTS generation error: {str(e)}" + yield ErrorFrame(error=error_message) diff --git a/src/pipecat/services/grok/__init__.py b/src/pipecat/services/grok/__init__.py new file mode 100644 index 000000000..9eebcfccf --- /dev/null +++ b/src/pipecat/services/grok/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .llm import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "grok", "grok.llm") diff --git a/src/pipecat/services/grok.py 
b/src/pipecat/services/grok/llm.py similarity index 91% rename from src/pipecat/services/grok.py rename to src/pipecat/services/grok/llm.py index faed13050..57517eb3e 100644 --- a/src/pipecat/services/grok.py +++ b/src/pipecat/services/grok/llm.py @@ -4,21 +4,14 @@ # SPDX-License-Identifier: BSD 2-Clause License # - -import json from dataclasses import dataclass -from typing import Any, Mapping, Optional +from typing import Any, Mapping from loguru import logger -from pipecat.frames.frames import FunctionCallResultProperties from pipecat.metrics.metrics import LLMTokenUsage -from pipecat.processors.aggregators.openai_llm_context import ( - OpenAILLMContext, - OpenAILLMContextFrame, -) -from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.openai import ( +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext +from pipecat.services.openai.llm import ( OpenAIAssistantContextAggregator, OpenAILLMService, OpenAIUserContextAggregator, @@ -27,13 +20,13 @@ from pipecat.services.openai import ( @dataclass class GrokContextAggregatorPair: - _user: "OpenAIUserContextAggregator" - _assistant: "OpenAIAssistantContextAggregator" + _user: OpenAIUserContextAggregator + _assistant: OpenAIAssistantContextAggregator - def user(self) -> "OpenAIUserContextAggregator": + def user(self) -> OpenAIUserContextAggregator: return self._user - def assistant(self) -> "OpenAIAssistantContextAggregator": + def assistant(self) -> OpenAIAssistantContextAggregator: return self._assistant diff --git a/src/pipecat/services/groq.py b/src/pipecat/services/groq.py deleted file mode 100644 index bf0304df2..000000000 --- a/src/pipecat/services/groq.py +++ /dev/null @@ -1,177 +0,0 @@ -# -# Copyright (c) 2024–2025, Daily -# -# SPDX-License-Identifier: BSD 2-Clause License -# - - -from typing import AsyncGenerator, Optional - -from loguru import logger -from pydantic import BaseModel - -from pipecat.frames.frames import Frame, TTSAudioRawFrame, 
TTSStartedFrame, TTSStoppedFrame -from pipecat.services.ai_services import TTSService -from pipecat.services.base_whisper import BaseWhisperSTTService, Transcription -from pipecat.services.openai import OpenAILLMService -from pipecat.transcriptions.language import Language - -try: - from groq import AsyncGroq -except ModuleNotFoundError as e: - logger.error(f"Exception: {e}") - logger.error( - "In order to use Groq, you need to `pip install pipecat-ai[groq]`. Also, set a `GROQ_API_KEY` environment variable." - ) - raise Exception(f"Missing module: {e}") - - -class GroqLLMService(OpenAILLMService): - """A service for interacting with Groq's API using the OpenAI-compatible interface. - - This service extends OpenAILLMService to connect to Groq's API endpoint while - maintaining full compatibility with OpenAI's interface and functionality. - - Args: - api_key (str): The API key for accessing Groq's API - base_url (str, optional): The base URL for Groq API. Defaults to "https://api.groq.com/openai/v1" - model (str, optional): The model identifier to use. Defaults to "llama-3.3-70b-versatile" - **kwargs: Additional keyword arguments passed to OpenAILLMService - """ - - def __init__( - self, - *, - api_key: str, - base_url: str = "https://api.groq.com/openai/v1", - model: str = "llama-3.3-70b-versatile", - **kwargs, - ): - super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs) - - def create_client(self, api_key=None, base_url=None, **kwargs): - """Create OpenAI-compatible client for Groq API endpoint.""" - logger.debug(f"Creating Groq client with api {base_url}") - return super().create_client(api_key, base_url, **kwargs) - - -class GroqSTTService(BaseWhisperSTTService): - """Groq Whisper speech-to-text service. - - Uses Groq's Whisper API to convert audio to text. Requires a Groq API key - set via the api_key parameter or GROQ_API_KEY environment variable. - - Args: - model: Whisper model to use. Defaults to "whisper-large-v3-turbo". 
- api_key: Groq API key. Defaults to None. - base_url: API base URL. Defaults to "https://api.groq.com/openai/v1". - language: Language of the audio input. Defaults to English. - prompt: Optional text to guide the model's style or continue a previous segment. - temperature: Optional sampling temperature between 0 and 1. Defaults to 0.0. - **kwargs: Additional arguments passed to BaseWhisperSTTService. - """ - - def __init__( - self, - *, - model: str = "whisper-large-v3-turbo", - api_key: Optional[str] = None, - base_url: str = "https://api.groq.com/openai/v1", - language: Optional[Language] = Language.EN, - prompt: Optional[str] = None, - temperature: Optional[float] = None, - **kwargs, - ): - super().__init__( - model=model, - api_key=api_key, - base_url=base_url, - language=language, - prompt=prompt, - temperature=temperature, - **kwargs, - ) - - async def _transcribe(self, audio: bytes) -> Transcription: - assert self._language is not None # Assigned in the BaseWhisperSTTService class - - # Build kwargs dict with only set parameters - kwargs = { - "file": ("audio.wav", audio, "audio/wav"), - "model": self.model_name, - "response_format": "json", - "language": self._language, - } - - if self._prompt is not None: - kwargs["prompt"] = self._prompt - - if self._temperature is not None: - kwargs["temperature"] = self._temperature - - return await self._client.audio.transcriptions.create(**kwargs) - - -class GroqTTSService(TTSService): - class InputParams(BaseModel): - language: Optional[Language] = Language.EN - speed: Optional[float] = 1.0 - seed: Optional[int] = None - - GROQ_SAMPLE_RATE = 48000 # Groq TTS only supports 48kHz sample rate - - def __init__( - self, - *, - api_key: str, - output_format: str = "wav", - params: InputParams = InputParams(), - model_name: str = "playai-tts", - voice_id: str = "Celeste-PlayAI", - sample_rate: Optional[int] = GROQ_SAMPLE_RATE, - **kwargs, - ): - if sample_rate != self.GROQ_SAMPLE_RATE: - logger.warning(f"Groq TTS only 
supports {self.GROQ_SAMPLE_RATE}Hz sample rate. ") - super().__init__( - pause_frame_processing=True, - sample_rate=sample_rate, - **kwargs, - ) - - self._api_key = api_key - self._model_name = model_name - self._output_format = output_format - self._voice_id = voice_id - self._params = params - - self._client = AsyncGroq(api_key=self._api_key) - - def can_generate_metrics(self) -> bool: - return True - - async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: - logger.debug(f"{self}: Generating TTS [{text}]") - measuring_ttfb = True - await self.start_ttfb_metrics() - yield TTSStartedFrame() - - response = await self._client.audio.speech.create( - model=self._model_name, - voice=self._voice_id, - response_format=self._output_format, - input=text, - ) - - async for data in response.iter_bytes(): - if measuring_ttfb: - await self.stop_ttfb_metrics() - measuring_ttfb = False - # remove wav header if present - if data.startswith(b"RIFF"): - data = data[44:] - if len(data) == 0: - continue - yield TTSAudioRawFrame(data, self.sample_rate, 1) - - yield TTSStoppedFrame() diff --git a/src/pipecat/services/groq/__init__.py b/src/pipecat/services/groq/__init__.py new file mode 100644 index 000000000..216853830 --- /dev/null +++ b/src/pipecat/services/groq/__init__.py @@ -0,0 +1,15 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .llm import * +from .stt import * +from .tts import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "groq", "groq.[llm,stt,tts]") diff --git a/src/pipecat/services/groq/llm.py b/src/pipecat/services/groq/llm.py new file mode 100644 index 000000000..be2ed5e72 --- /dev/null +++ b/src/pipecat/services/groq/llm.py @@ -0,0 +1,38 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from loguru import logger + +from pipecat.services.openai.llm import OpenAILLMService + 
+ +class GroqLLMService(OpenAILLMService): + """A service for interacting with Groq's API using the OpenAI-compatible interface. + + This service extends OpenAILLMService to connect to Groq's API endpoint while + maintaining full compatibility with OpenAI's interface and functionality. + + Args: + api_key (str): The API key for accessing Groq's API + base_url (str, optional): The base URL for Groq API. Defaults to "https://api.groq.com/openai/v1" + model (str, optional): The model identifier to use. Defaults to "llama-3.3-70b-versatile" + **kwargs: Additional keyword arguments passed to OpenAILLMService + """ + + def __init__( + self, + *, + api_key: str, + base_url: str = "https://api.groq.com/openai/v1", + model: str = "llama-3.3-70b-versatile", + **kwargs, + ): + super().__init__(api_key=api_key, base_url=base_url, model=model, **kwargs) + + def create_client(self, api_key=None, base_url=None, **kwargs): + """Create OpenAI-compatible client for Groq API endpoint.""" + logger.debug(f"Creating Groq client with api {base_url}") + return super().create_client(api_key, base_url, **kwargs) diff --git a/src/pipecat/services/groq/stt.py b/src/pipecat/services/groq/stt.py new file mode 100644 index 000000000..5852bedfd --- /dev/null +++ b/src/pipecat/services/groq/stt.py @@ -0,0 +1,67 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from typing import Optional + +from pipecat.services.whisper.base_stt import BaseWhisperSTTService, Transcription +from pipecat.transcriptions.language import Language + + +class GroqSTTService(BaseWhisperSTTService): + """Groq Whisper speech-to-text service. + + Uses Groq's Whisper API to convert audio to text. Requires a Groq API key + set via the api_key parameter or GROQ_API_KEY environment variable. + + Args: + model: Whisper model to use. Defaults to "whisper-large-v3-turbo". + api_key: Groq API key. Defaults to None. + base_url: API base URL. Defaults to "https://api.groq.com/openai/v1". 
+ language: Language of the audio input. Defaults to English. + prompt: Optional text to guide the model's style or continue a previous segment. + temperature: Optional sampling temperature between 0 and 1. Defaults to 0.0. + **kwargs: Additional arguments passed to BaseWhisperSTTService. + """ + + def __init__( + self, + *, + model: str = "whisper-large-v3-turbo", + api_key: Optional[str] = None, + base_url: str = "https://api.groq.com/openai/v1", + language: Optional[Language] = Language.EN, + prompt: Optional[str] = None, + temperature: Optional[float] = None, + **kwargs, + ): + super().__init__( + model=model, + api_key=api_key, + base_url=base_url, + language=language, + prompt=prompt, + temperature=temperature, + **kwargs, + ) + + async def _transcribe(self, audio: bytes) -> Transcription: + assert self._language is not None # Assigned in the BaseWhisperSTTService class + + # Build kwargs dict with only set parameters + kwargs = { + "file": ("audio.wav", audio, "audio/wav"), + "model": self.model_name, + "response_format": "json", + "language": self._language, + } + + if self._prompt is not None: + kwargs["prompt"] = self._prompt + + if self._temperature is not None: + kwargs["temperature"] = self._temperature + + return await self._client.audio.transcriptions.create(**kwargs) diff --git a/src/pipecat/services/groq/tts.py b/src/pipecat/services/groq/tts.py new file mode 100644 index 000000000..69429424a --- /dev/null +++ b/src/pipecat/services/groq/tts.py @@ -0,0 +1,86 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from typing import AsyncGenerator, Optional + +from loguru import logger +from pydantic import BaseModel + +from pipecat.frames.frames import Frame, TTSAudioRawFrame, TTSStartedFrame, TTSStoppedFrame +from pipecat.services.ai_services import TTSService +from pipecat.transcriptions.language import Language + +try: + from groq import AsyncGroq +except ModuleNotFoundError as e: + 
logger.error(f"Exception: {e}") + logger.error("In order to use Groq, you need to `pip install pipecat-ai[groq]`.") + raise Exception(f"Missing module: {e}") + + +class GroqTTSService(TTSService): + class InputParams(BaseModel): + language: Optional[Language] = Language.EN + speed: Optional[float] = 1.0 + seed: Optional[int] = None + + GROQ_SAMPLE_RATE = 48000 # Groq TTS only supports 48kHz sample rate + + def __init__( + self, + *, + api_key: str, + output_format: str = "wav", + params: InputParams = InputParams(), + model_name: str = "playai-tts", + voice_id: str = "Celeste-PlayAI", + sample_rate: Optional[int] = GROQ_SAMPLE_RATE, + **kwargs, + ): + if sample_rate != self.GROQ_SAMPLE_RATE: + logger.warning(f"Groq TTS only supports {self.GROQ_SAMPLE_RATE}Hz sample rate. ") + super().__init__( + pause_frame_processing=True, + sample_rate=sample_rate, + **kwargs, + ) + + self._api_key = api_key + self._model_name = model_name + self._output_format = output_format + self._voice_id = voice_id + self._params = params + + self._client = AsyncGroq(api_key=self._api_key) + + def can_generate_metrics(self) -> bool: + return True + + async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: + logger.debug(f"{self}: Generating TTS [{text}]") + measuring_ttfb = True + await self.start_ttfb_metrics() + yield TTSStartedFrame() + + response = await self._client.audio.speech.create( + model=self._model_name, + voice=self._voice_id, + response_format=self._output_format, + input=text, + ) + + async for data in response.iter_bytes(): + if measuring_ttfb: + await self.stop_ttfb_metrics() + measuring_ttfb = False + # remove wav header if present + if data.startswith(b"RIFF"): + data = data[44:] + if len(data) == 0: + continue + yield TTSAudioRawFrame(data, self.sample_rate, 1) + + yield TTSStoppedFrame() diff --git a/src/pipecat/services/lmnt/__init__.py b/src/pipecat/services/lmnt/__init__.py new file mode 100644 index 000000000..0f55aa55c --- /dev/null +++ 
b/src/pipecat/services/lmnt/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .tts import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "lmnt", "lmnt.tts") diff --git a/src/pipecat/services/lmnt.py b/src/pipecat/services/lmnt/tts.py similarity index 98% rename from src/pipecat/services/lmnt.py rename to src/pipecat/services/lmnt/tts.py index d3cc92603..040d526f9 100644 --- a/src/pipecat/services/lmnt.py +++ b/src/pipecat/services/lmnt/tts.py @@ -29,9 +29,7 @@ try: import websockets except ModuleNotFoundError as e: logger.error(f"Exception: {e}") - logger.error( - "In order to use LMNT, you need to `pip install pipecat-ai[lmnt]`. Also, set `LMNT_API_KEY` environment variable." - ) + logger.error("In order to use LMNT, you need to `pip install pipecat-ai[lmnt]`.") raise Exception(f"Missing module: {e}") diff --git a/src/pipecat/services/mem0/__init__.py b/src/pipecat/services/mem0/__init__.py new file mode 100644 index 000000000..8459054d2 --- /dev/null +++ b/src/pipecat/services/mem0/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .memory import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "mem0", "mem0.memory") diff --git a/src/pipecat/services/mem0.py b/src/pipecat/services/mem0/memory.py similarity index 100% rename from src/pipecat/services/mem0.py rename to src/pipecat/services/mem0/memory.py diff --git a/src/pipecat/services/moondream/__init__.py b/src/pipecat/services/moondream/__init__.py new file mode 100644 index 000000000..2b9994f08 --- /dev/null +++ b/src/pipecat/services/moondream/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from 
pipecat.services import DeprecatedModuleProxy + +from .vision import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "moondream", "moondream.vision") diff --git a/src/pipecat/services/moondream.py b/src/pipecat/services/moondream/vision.py similarity index 100% rename from src/pipecat/services/moondream.py rename to src/pipecat/services/moondream/vision.py diff --git a/src/pipecat/services/neuphonic/__init__.py b/src/pipecat/services/neuphonic/__init__.py new file mode 100644 index 000000000..5b2d9c34a --- /dev/null +++ b/src/pipecat/services/neuphonic/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .tts import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "neuphonic", "neuphonic.tts") diff --git a/src/pipecat/services/neuphonic.py b/src/pipecat/services/neuphonic/tts.py similarity index 98% rename from src/pipecat/services/neuphonic.py rename to src/pipecat/services/neuphonic/tts.py index 407e54a83..7ac005afc 100644 --- a/src/pipecat/services/neuphonic.py +++ b/src/pipecat/services/neuphonic/tts.py @@ -30,15 +30,12 @@ from pipecat.processors.frame_processor import FrameDirection from pipecat.services.ai_services import InterruptibleTTSService, TTSService from pipecat.transcriptions.language import Language -# See .env.example for Neuphonic configuration needed try: import websockets from pyneuphonic import Neuphonic, TTSConfig except ModuleNotFoundError as e: logger.error(f"Exception: {e}") - logger.error( - "In order to use Neuphonic, you need to `pip install pipecat-ai[neuphonic]`. Also, set `NEUPHONIC_API_KEY` environment variable." 
- ) + logger.error("In order to use Neuphonic, you need to `pip install pipecat-ai[neuphonic]`.") raise Exception(f"Missing module: {e}") diff --git a/src/pipecat/services/nim/__init__.py b/src/pipecat/services/nim/__init__.py new file mode 100644 index 000000000..05788d5ce --- /dev/null +++ b/src/pipecat/services/nim/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .llm import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "nim", "nim.llm") diff --git a/src/pipecat/services/nim.py b/src/pipecat/services/nim/llm.py similarity index 98% rename from src/pipecat/services/nim.py rename to src/pipecat/services/nim/llm.py index 7146e01b3..d57fa8d4c 100644 --- a/src/pipecat/services/nim.py +++ b/src/pipecat/services/nim/llm.py @@ -4,10 +4,9 @@ # SPDX-License-Identifier: BSD 2-Clause License # - from pipecat.metrics.metrics import LLMTokenUsage from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext -from pipecat.services.openai import OpenAILLMService +from pipecat.services.openai.llm import OpenAILLMService class NimLLMService(OpenAILLMService): diff --git a/src/pipecat/services/ollama/__init__.py b/src/pipecat/services/ollama/__init__.py new file mode 100644 index 000000000..9103ee851 --- /dev/null +++ b/src/pipecat/services/ollama/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .llm import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "ollama", "ollama.llm") diff --git a/src/pipecat/services/ollama.py b/src/pipecat/services/ollama/llm.py similarity index 71% rename from src/pipecat/services/ollama.py rename to src/pipecat/services/ollama/llm.py index 7d23fe128..bd1ac0d0d 100644 --- a/src/pipecat/services/ollama.py +++ 
b/src/pipecat/services/ollama/llm.py @@ -4,9 +4,9 @@ # SPDX-License-Identifier: BSD 2-Clause License # -from pipecat.services.openai import BaseOpenAILLMService +from pipecat.services.openai.llm import OpenAILLMService -class OLLamaLLMService(BaseOpenAILLMService): +class OLLamaLLMService(OpenAILLMService): def __init__(self, *, model: str = "llama2", base_url: str = "http://localhost:11434/v1"): super().__init__(model=model, base_url=base_url, api_key="ollama") diff --git a/src/pipecat/services/openai.py b/src/pipecat/services/openai.py deleted file mode 100644 index 3f85d917c..000000000 --- a/src/pipecat/services/openai.py +++ /dev/null @@ -1,644 +0,0 @@ -# -# Copyright (c) 2024–2025, Daily -# -# SPDX-License-Identifier: BSD 2-Clause License -# - -import base64 -import io -import json -from dataclasses import dataclass -from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional - -import aiohttp -import httpx -from loguru import logger -from openai import ( - NOT_GIVEN, - AsyncOpenAI, - AsyncStream, - BadRequestError, - DefaultAsyncHttpxClient, -) -from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam -from PIL import Image -from pydantic import BaseModel, Field - -from pipecat.frames.frames import ( - ErrorFrame, - Frame, - FunctionCallCancelFrame, - FunctionCallInProgressFrame, - FunctionCallResultFrame, - LLMFullResponseEndFrame, - LLMFullResponseStartFrame, - LLMMessagesFrame, - LLMTextFrame, - LLMUpdateSettingsFrame, - StartFrame, - TTSAudioRawFrame, - TTSStartedFrame, - TTSStoppedFrame, - URLImageRawFrame, - UserImageRawFrame, - VisionImageRawFrame, -) -from pipecat.metrics.metrics import LLMTokenUsage -from pipecat.processors.aggregators.llm_response import ( - LLMAssistantContextAggregator, - LLMUserContextAggregator, -) -from pipecat.processors.aggregators.openai_llm_context import ( - OpenAILLMContext, - OpenAILLMContextFrame, -) -from pipecat.processors.frame_processor import FrameDirection -from 
pipecat.services.ai_services import ( - ImageGenService, - LLMService, - TTSService, -) -from pipecat.services.base_whisper import BaseWhisperSTTService, Transcription -from pipecat.transcriptions.language import Language - -ValidVoice = Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] - -VALID_VOICES: Dict[str, ValidVoice] = { - "alloy": "alloy", - "echo": "echo", - "fable": "fable", - "onyx": "onyx", - "nova": "nova", - "shimmer": "shimmer", -} - - -class OpenAIUnhandledFunctionException(Exception): - pass - - -class BaseOpenAILLMService(LLMService): - """This is the base for all services that use the AsyncOpenAI client. - - This service consumes OpenAILLMContextFrame frames, which contain a reference - to an OpenAILLMContext frame. The OpenAILLMContext object defines the context - sent to the LLM for a completion. This includes user, assistant and system messages - as well as tool choices and the tool, which is used if requesting function - calls from the LLM. - """ - - class InputParams(BaseModel): - frequency_penalty: Optional[float] = Field( - default_factory=lambda: NOT_GIVEN, ge=-2.0, le=2.0 - ) - presence_penalty: Optional[float] = Field( - default_factory=lambda: NOT_GIVEN, ge=-2.0, le=2.0 - ) - seed: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=0) - temperature: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=2.0) - # Note: top_k is currently not supported by the OpenAI client library, - # so top_k is ignored right now. 
- top_k: Optional[int] = Field(default=None, ge=0) - top_p: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0) - max_tokens: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=1) - max_completion_tokens: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=1) - extra: Optional[Dict[str, Any]] = Field(default_factory=dict) - - def __init__( - self, - *, - model: str, - api_key=None, - base_url=None, - organization=None, - project=None, - default_headers: Mapping[str, str] | None = None, - params: InputParams = InputParams(), - **kwargs, - ): - super().__init__(**kwargs) - self._settings = { - "frequency_penalty": params.frequency_penalty, - "presence_penalty": params.presence_penalty, - "seed": params.seed, - "temperature": params.temperature, - "top_p": params.top_p, - "max_tokens": params.max_tokens, - "max_completion_tokens": params.max_completion_tokens, - "extra": params.extra if isinstance(params.extra, dict) else {}, - } - self.set_model_name(model) - self._client = self.create_client( - api_key=api_key, - base_url=base_url, - organization=organization, - project=project, - default_headers=default_headers, - **kwargs, - ) - - def create_client( - self, - api_key=None, - base_url=None, - organization=None, - project=None, - default_headers=None, - **kwargs, - ): - return AsyncOpenAI( - api_key=api_key, - base_url=base_url, - organization=organization, - project=project, - http_client=DefaultAsyncHttpxClient( - limits=httpx.Limits( - max_keepalive_connections=100, max_connections=1000, keepalive_expiry=None - ) - ), - default_headers=default_headers, - ) - - def can_generate_metrics(self) -> bool: - return True - - async def get_chat_completions( - self, context: OpenAILLMContext, messages: List[ChatCompletionMessageParam] - ) -> AsyncStream[ChatCompletionChunk]: - params = { - "model": self.model_name, - "stream": True, - "messages": messages, - "tools": context.tools, - "tool_choice": context.tool_choice, - 
"stream_options": {"include_usage": True}, - "frequency_penalty": self._settings["frequency_penalty"], - "presence_penalty": self._settings["presence_penalty"], - "seed": self._settings["seed"], - "temperature": self._settings["temperature"], - "top_p": self._settings["top_p"], - "max_tokens": self._settings["max_tokens"], - "max_completion_tokens": self._settings["max_completion_tokens"], - } - - params.update(self._settings["extra"]) - - chunks = await self._client.chat.completions.create(**params) - return chunks - - async def _stream_chat_completions( - self, context: OpenAILLMContext - ) -> AsyncStream[ChatCompletionChunk]: - logger.debug(f"{self}: Generating chat [{context.get_messages_for_logging()}]") - - messages: List[ChatCompletionMessageParam] = context.get_messages() - - # base64 encode any images - for message in messages: - if message.get("mime_type") == "image/jpeg": - encoded_image = base64.b64encode(message["data"].getvalue()).decode("utf-8") - text = message["content"] - message["content"] = [ - {"type": "text", "text": text}, - { - "type": "image_url", - "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}, - }, - ] - del message["data"] - del message["mime_type"] - - chunks = await self.get_chat_completions(context, messages) - - return chunks - - async def _process_context(self, context: OpenAILLMContext): - functions_list = [] - arguments_list = [] - tool_id_list = [] - func_idx = 0 - function_name = "" - arguments = "" - tool_call_id = "" - - await self.start_ttfb_metrics() - - chunk_stream: AsyncStream[ChatCompletionChunk] = await self._stream_chat_completions( - context - ) - - async for chunk in chunk_stream: - if chunk.usage: - tokens = LLMTokenUsage( - prompt_tokens=chunk.usage.prompt_tokens, - completion_tokens=chunk.usage.completion_tokens, - total_tokens=chunk.usage.total_tokens, - ) - await self.start_llm_usage_metrics(tokens) - - if chunk.choices is None or len(chunk.choices) == 0: - continue - - await 
self.stop_ttfb_metrics() - - if not chunk.choices[0].delta: - continue - - if chunk.choices[0].delta.tool_calls: - # We're streaming the LLM response to enable the fastest response times. - # For text, we just yield each chunk as we receive it and count on consumers - # to do whatever coalescing they need (eg. to pass full sentences to TTS) - # - # If the LLM is a function call, we'll do some coalescing here. - # If the response contains a function name, we'll yield a frame to tell consumers - # that they can start preparing to call the function with that name. - # We accumulate all the arguments for the rest of the streamed response, then when - # the response is done, we package up all the arguments and the function name and - # yield a frame containing the function name and the arguments. - - tool_call = chunk.choices[0].delta.tool_calls[0] - if tool_call.index != func_idx: - functions_list.append(function_name) - arguments_list.append(arguments) - tool_id_list.append(tool_call_id) - function_name = "" - arguments = "" - tool_call_id = "" - func_idx += 1 - if tool_call.function and tool_call.function.name: - function_name += tool_call.function.name - tool_call_id = tool_call.id - if tool_call.function and tool_call.function.arguments: - # Keep iterating through the response to collect all the argument fragments - arguments += tool_call.function.arguments - elif chunk.choices[0].delta.content: - await self.push_frame(LLMTextFrame(chunk.choices[0].delta.content)) - - # if we got a function name and arguments, check to see if it's a function with - # a registered handler. If so, run the registered callback, save the result to - # the context, and re-prompt to get a chat answer. If we don't have a registered - # handler, raise an exception. 
- if function_name and arguments: - # added to the list as last function name and arguments not added to the list - functions_list.append(function_name) - arguments_list.append(arguments) - tool_id_list.append(tool_call_id) - - for index, (function_name, arguments, tool_id) in enumerate( - zip(functions_list, arguments_list, tool_id_list), start=1 - ): - if self.has_function(function_name): - run_llm = False - arguments = json.loads(arguments) - await self.call_function( - context=context, - function_name=function_name, - arguments=arguments, - tool_call_id=tool_id, - run_llm=run_llm, - ) - else: - raise OpenAIUnhandledFunctionException( - f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function." - ) - - async def process_frame(self, frame: Frame, direction: FrameDirection): - await super().process_frame(frame, direction) - - context = None - if isinstance(frame, OpenAILLMContextFrame): - context: OpenAILLMContext = frame.context - elif isinstance(frame, LLMMessagesFrame): - context = OpenAILLMContext.from_messages(frame.messages) - elif isinstance(frame, VisionImageRawFrame): - context = OpenAILLMContext() - context.add_image_frame_message( - format=frame.format, size=frame.size, image=frame.image, text=frame.text - ) - elif isinstance(frame, LLMUpdateSettingsFrame): - await self._update_settings(frame.settings) - else: - await self.push_frame(frame, direction) - - if context: - try: - await self.push_frame(LLMFullResponseStartFrame()) - await self.start_processing_metrics() - await self._process_context(context) - except httpx.TimeoutException: - await self._call_event_handler("on_completion_timeout") - finally: - await self.stop_processing_metrics() - await self.push_frame(LLMFullResponseEndFrame()) - - -@dataclass -class OpenAIContextAggregatorPair: - _user: "OpenAIUserContextAggregator" - _assistant: "OpenAIAssistantContextAggregator" - - def user(self) -> "OpenAIUserContextAggregator": - return 
self._user - - def assistant(self) -> "OpenAIAssistantContextAggregator": - return self._assistant - - -class OpenAILLMService(BaseOpenAILLMService): - def __init__( - self, - *, - model: str = "gpt-4o", - params: BaseOpenAILLMService.InputParams = BaseOpenAILLMService.InputParams(), - **kwargs, - ): - super().__init__(model=model, params=params, **kwargs) - - def create_context_aggregator( - self, - context: OpenAILLMContext, - *, - user_kwargs: Mapping[str, Any] = {}, - assistant_kwargs: Mapping[str, Any] = {}, - ) -> OpenAIContextAggregatorPair: - """Create an instance of OpenAIContextAggregatorPair from an - OpenAILLMContext. Constructor keyword arguments for both the user and - assistant aggregators can be provided. - - Args: - context (OpenAILLMContext): The LLM context. - user_kwargs (Mapping[str, Any], optional): Additional keyword - arguments for the user context aggregator constructor. Defaults - to an empty mapping. - assistant_kwargs (Mapping[str, Any], optional): Additional keyword - arguments for the assistant context aggregator - constructor. Defaults to an empty mapping. - - Returns: - OpenAIContextAggregatorPair: A pair of context aggregators, one for - the user and one for the assistant, encapsulated in an - OpenAIContextAggregatorPair. 
- - """ - context.set_llm_adapter(self.get_llm_adapter()) - user = OpenAIUserContextAggregator(context, **user_kwargs) - assistant = OpenAIAssistantContextAggregator(context, **assistant_kwargs) - return OpenAIContextAggregatorPair(_user=user, _assistant=assistant) - - -class OpenAIImageGenService(ImageGenService): - def __init__( - self, - *, - api_key: str, - base_url: Optional[str] = None, - aiohttp_session: aiohttp.ClientSession, - image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"], - model: str = "dall-e-3", - ): - super().__init__() - self.set_model_name(model) - self._image_size = image_size - self._client = AsyncOpenAI(api_key=api_key, base_url=base_url) - self._aiohttp_session = aiohttp_session - - async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]: - logger.debug(f"Generating image from prompt: {prompt}") - - image = await self._client.images.generate( - prompt=prompt, model=self.model_name, n=1, size=self._image_size - ) - - image_url = image.data[0].url - - if not image_url: - logger.error(f"{self} No image provided in response: {image}") - yield ErrorFrame("Image generation failed") - return - - # Load the image from the url - async with self._aiohttp_session.get(image_url) as response: - image_stream = io.BytesIO(await response.content.read()) - image = Image.open(image_stream) - frame = URLImageRawFrame(image_url, image.tobytes(), image.size, image.format) - yield frame - - -class OpenAISTTService(BaseWhisperSTTService): - """OpenAI Speech-to-Text service that generates text from audio. - - Uses OpenAI's transcription API to convert audio to text. Requires an OpenAI API key - set via the api_key parameter or OPENAI_API_KEY environment variable. - - Args: - model: Model to use — either gpt-4o or Whisper. Defaults to "gpt-4o-transcribe". - api_key: OpenAI API key. Defaults to None. - base_url: API base URL. Defaults to None. - language: Language of the audio input. Defaults to English. 
- prompt: Optional text to guide the model's style or continue a previous segment. - temperature: Optional sampling temperature between 0 and 1. Defaults to 0.0. - **kwargs: Additional arguments passed to BaseWhisperSTTService. - """ - - def __init__( - self, - *, - model: str = "gpt-4o-transcribe", - api_key: Optional[str] = None, - base_url: Optional[str] = None, - language: Optional[Language] = Language.EN, - prompt: Optional[str] = None, - temperature: Optional[float] = None, - **kwargs, - ): - super().__init__( - model=model, - api_key=api_key, - base_url=base_url, - language=language, - prompt=prompt, - temperature=temperature, - **kwargs, - ) - - async def _transcribe(self, audio: bytes) -> Transcription: - assert self._language is not None # Assigned in the BaseWhisperSTTService class - - # Build kwargs dict with only set parameters - kwargs = { - "file": ("audio.wav", audio, "audio/wav"), - "model": self.model_name, - "language": self._language, - } - - if self._prompt is not None: - kwargs["prompt"] = self._prompt - - if self._temperature is not None: - kwargs["temperature"] = self._temperature - - return await self._client.audio.transcriptions.create(**kwargs) - - -class OpenAITTSService(TTSService): - """OpenAI Text-to-Speech service that generates audio from text. - - This service uses the OpenAI TTS API to generate PCM-encoded audio at 24kHz. - - Args: - api_key: OpenAI API key. Defaults to None. - voice: Voice ID to use. Defaults to "alloy". - model: TTS model to use. Defaults to "gpt-4o-mini-tts". - sample_rate: Output audio sample rate in Hz. Defaults to None. - **kwargs: Additional keyword arguments passed to TTSService. - - The service returns PCM-encoded audio at the specified sample rate. 
- - """ - - OPENAI_SAMPLE_RATE = 24000 # OpenAI TTS always outputs at 24kHz - - def __init__( - self, - *, - api_key: Optional[str] = None, - base_url: Optional[str] = None, - voice: str = "alloy", - model: str = "gpt-4o-mini-tts", - sample_rate: Optional[int] = None, - instructions: Optional[str] = None, - **kwargs, - ): - if sample_rate and sample_rate != self.OPENAI_SAMPLE_RATE: - logger.warning( - f"OpenAI TTS only supports {self.OPENAI_SAMPLE_RATE}Hz sample rate. " - f"Current rate of {self.sample_rate}Hz may cause issues." - ) - super().__init__(sample_rate=sample_rate, **kwargs) - - self.set_model_name(model) - self.set_voice(voice) - self._instructions = instructions - self._client = AsyncOpenAI(api_key=api_key, base_url=base_url) - - def can_generate_metrics(self) -> bool: - return True - - async def set_model(self, model: str): - logger.info(f"Switching TTS model to: [{model}]") - self.set_model_name(model) - - async def start(self, frame: StartFrame): - await super().start(frame) - if self.sample_rate != self.OPENAI_SAMPLE_RATE: - logger.warning( - f"OpenAI TTS requires {self.OPENAI_SAMPLE_RATE}Hz sample rate. " - f"Current rate of {self.sample_rate}Hz may cause issues." 
- ) - - async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: - logger.debug(f"{self}: Generating TTS [{text}]") - try: - await self.start_ttfb_metrics() - - # Setup extra body parameters - extra_body = {} - if self._instructions: - extra_body["instructions"] = self._instructions - - async with self._client.audio.speech.with_streaming_response.create( - input=text or " ", # Text must contain at least one character - model=self.model_name, - voice=VALID_VOICES[self._voice_id], - response_format="pcm", - extra_body=extra_body, - ) as r: - if r.status_code != 200: - error = await r.text() - logger.error( - f"{self} error getting audio (status: {r.status_code}, error: {error})" - ) - yield ErrorFrame( - f"Error getting audio (status: {r.status_code}, error: {error})" - ) - return - - await self.start_tts_usage_metrics(text) - - CHUNK_SIZE = 1024 - - yield TTSStartedFrame() - async for chunk in r.iter_bytes(CHUNK_SIZE): - if len(chunk) > 0: - await self.stop_ttfb_metrics() - frame = TTSAudioRawFrame(chunk, self.sample_rate, 1) - yield frame - yield TTSStoppedFrame() - except BadRequestError as e: - logger.exception(f"{self} error generating TTS: {e}") - - -class OpenAIUserContextAggregator(LLMUserContextAggregator): - pass - - -class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator): - async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame): - self._context.add_message( - { - "role": "assistant", - "tool_calls": [ - { - "id": frame.tool_call_id, - "function": { - "name": frame.function_name, - "arguments": json.dumps(frame.arguments), - }, - "type": "function", - } - ], - } - ) - self._context.add_message( - { - "role": "tool", - "content": "IN_PROGRESS", - "tool_call_id": frame.tool_call_id, - } - ) - - async def handle_function_call_result(self, frame: FunctionCallResultFrame): - if frame.result: - result = json.dumps(frame.result) - await self._update_function_call_result(frame.function_name, frame.tool_call_id, 
result) - else: - await self._update_function_call_result( - frame.function_name, frame.tool_call_id, "COMPLETED" - ) - - async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame): - await self._update_function_call_result( - frame.function_name, frame.tool_call_id, "CANCELLED" - ) - - async def _update_function_call_result( - self, function_name: str, tool_call_id: str, result: Any - ): - for message in self._context.messages: - if ( - message["role"] == "tool" - and message["tool_call_id"] - and message["tool_call_id"] == tool_call_id - ): - message["content"] = result - - async def handle_user_image_frame(self, frame: UserImageRawFrame): - await self._update_function_call_result( - frame.request.function_name, frame.request.tool_call_id, "COMPLETED" - ) - self._context.add_image_frame_message( - format=frame.format, - size=frame.size, - image=frame.image, - text=frame.request.context, - ) diff --git a/src/pipecat/services/openai/__init__.py b/src/pipecat/services/openai/__init__.py new file mode 100644 index 000000000..4decac126 --- /dev/null +++ b/src/pipecat/services/openai/__init__.py @@ -0,0 +1,16 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .image import * +from .llm import * +from .stt import * +from .tts import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "openai", "openai.[image,llm,stt,tts]") diff --git a/src/pipecat/services/openai/base_llm.py b/src/pipecat/services/openai/base_llm.py new file mode 100644 index 000000000..5343c1eb7 --- /dev/null +++ b/src/pipecat/services/openai/base_llm.py @@ -0,0 +1,296 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import base64 +import json +from typing import Any, Dict, List, Mapping, Optional + +import httpx +from loguru import logger +from openai import ( + NOT_GIVEN, + AsyncOpenAI, + AsyncStream, + 
DefaultAsyncHttpxClient, +) +from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam +from pydantic import BaseModel, Field + +from pipecat.frames.frames import ( + Frame, + LLMFullResponseEndFrame, + LLMFullResponseStartFrame, + LLMMessagesFrame, + LLMTextFrame, + LLMUpdateSettingsFrame, + VisionImageRawFrame, +) +from pipecat.metrics.metrics import LLMTokenUsage +from pipecat.processors.aggregators.openai_llm_context import ( + OpenAILLMContext, + OpenAILLMContextFrame, +) +from pipecat.processors.frame_processor import FrameDirection +from pipecat.services.ai_services import LLMService + + +class OpenAIUnhandledFunctionException(Exception): + pass + + +class BaseOpenAILLMService(LLMService): + """This is the base for all services that use the AsyncOpenAI client. + + This service consumes OpenAILLMContextFrame frames, which contain a reference + to an OpenAILLMContext frame. The OpenAILLMContext object defines the context + sent to the LLM for a completion. This includes user, assistant and system messages + as well as tool choices and the tool, which is used if requesting function + calls from the LLM. + """ + + class InputParams(BaseModel): + frequency_penalty: Optional[float] = Field( + default_factory=lambda: NOT_GIVEN, ge=-2.0, le=2.0 + ) + presence_penalty: Optional[float] = Field( + default_factory=lambda: NOT_GIVEN, ge=-2.0, le=2.0 + ) + seed: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=0) + temperature: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=2.0) + # Note: top_k is currently not supported by the OpenAI client library, + # so top_k is ignored right now. 
+ top_k: Optional[int] = Field(default=None, ge=0) + top_p: Optional[float] = Field(default_factory=lambda: NOT_GIVEN, ge=0.0, le=1.0) + max_tokens: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=1) + max_completion_tokens: Optional[int] = Field(default_factory=lambda: NOT_GIVEN, ge=1) + extra: Optional[Dict[str, Any]] = Field(default_factory=dict) + + def __init__( + self, + *, + model: str, + api_key=None, + base_url=None, + organization=None, + project=None, + default_headers: Mapping[str, str] | None = None, + params: InputParams = InputParams(), + **kwargs, + ): + super().__init__(**kwargs) + self._settings = { + "frequency_penalty": params.frequency_penalty, + "presence_penalty": params.presence_penalty, + "seed": params.seed, + "temperature": params.temperature, + "top_p": params.top_p, + "max_tokens": params.max_tokens, + "max_completion_tokens": params.max_completion_tokens, + "extra": params.extra if isinstance(params.extra, dict) else {}, + } + self.set_model_name(model) + self._client = self.create_client( + api_key=api_key, + base_url=base_url, + organization=organization, + project=project, + default_headers=default_headers, + **kwargs, + ) + + def create_client( + self, + api_key=None, + base_url=None, + organization=None, + project=None, + default_headers=None, + **kwargs, + ): + return AsyncOpenAI( + api_key=api_key, + base_url=base_url, + organization=organization, + project=project, + http_client=DefaultAsyncHttpxClient( + limits=httpx.Limits( + max_keepalive_connections=100, max_connections=1000, keepalive_expiry=None + ) + ), + default_headers=default_headers, + ) + + def can_generate_metrics(self) -> bool: + return True + + async def get_chat_completions( + self, context: OpenAILLMContext, messages: List[ChatCompletionMessageParam] + ) -> AsyncStream[ChatCompletionChunk]: + params = { + "model": self.model_name, + "stream": True, + "messages": messages, + "tools": context.tools, + "tool_choice": context.tool_choice, + 
"stream_options": {"include_usage": True}, + "frequency_penalty": self._settings["frequency_penalty"], + "presence_penalty": self._settings["presence_penalty"], + "seed": self._settings["seed"], + "temperature": self._settings["temperature"], + "top_p": self._settings["top_p"], + "max_tokens": self._settings["max_tokens"], + "max_completion_tokens": self._settings["max_completion_tokens"], + } + + params.update(self._settings["extra"]) + + chunks = await self._client.chat.completions.create(**params) + return chunks + + async def _stream_chat_completions( + self, context: OpenAILLMContext + ) -> AsyncStream[ChatCompletionChunk]: + logger.debug(f"{self}: Generating chat [{context.get_messages_for_logging()}]") + + messages: List[ChatCompletionMessageParam] = context.get_messages() + + # base64 encode any images + for message in messages: + if message.get("mime_type") == "image/jpeg": + encoded_image = base64.b64encode(message["data"].getvalue()).decode("utf-8") + text = message["content"] + message["content"] = [ + {"type": "text", "text": text}, + { + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"}, + }, + ] + del message["data"] + del message["mime_type"] + + chunks = await self.get_chat_completions(context, messages) + + return chunks + + async def _process_context(self, context: OpenAILLMContext): + functions_list = [] + arguments_list = [] + tool_id_list = [] + func_idx = 0 + function_name = "" + arguments = "" + tool_call_id = "" + + await self.start_ttfb_metrics() + + chunk_stream: AsyncStream[ChatCompletionChunk] = await self._stream_chat_completions( + context + ) + + async for chunk in chunk_stream: + if chunk.usage: + tokens = LLMTokenUsage( + prompt_tokens=chunk.usage.prompt_tokens, + completion_tokens=chunk.usage.completion_tokens, + total_tokens=chunk.usage.total_tokens, + ) + await self.start_llm_usage_metrics(tokens) + + if chunk.choices is None or len(chunk.choices) == 0: + continue + + await 
self.stop_ttfb_metrics() + + if not chunk.choices[0].delta: + continue + + if chunk.choices[0].delta.tool_calls: + # We're streaming the LLM response to enable the fastest response times. + # For text, we just yield each chunk as we receive it and count on consumers + # to do whatever coalescing they need (eg. to pass full sentences to TTS) + # + # If the LLM is a function call, we'll do some coalescing here. + # If the response contains a function name, we'll yield a frame to tell consumers + # that they can start preparing to call the function with that name. + # We accumulate all the arguments for the rest of the streamed response, then when + # the response is done, we package up all the arguments and the function name and + # yield a frame containing the function name and the arguments. + + tool_call = chunk.choices[0].delta.tool_calls[0] + if tool_call.index != func_idx: + functions_list.append(function_name) + arguments_list.append(arguments) + tool_id_list.append(tool_call_id) + function_name = "" + arguments = "" + tool_call_id = "" + func_idx += 1 + if tool_call.function and tool_call.function.name: + function_name += tool_call.function.name + tool_call_id = tool_call.id + if tool_call.function and tool_call.function.arguments: + # Keep iterating through the response to collect all the argument fragments + arguments += tool_call.function.arguments + elif chunk.choices[0].delta.content: + await self.push_frame(LLMTextFrame(chunk.choices[0].delta.content)) + + # if we got a function name and arguments, check to see if it's a function with + # a registered handler. If so, run the registered callback, save the result to + # the context, and re-prompt to get a chat answer. If we don't have a registered + # handler, raise an exception. 
+ if function_name and arguments: + # added to the list as last function name and arguments not added to the list + functions_list.append(function_name) + arguments_list.append(arguments) + tool_id_list.append(tool_call_id) + + for index, (function_name, arguments, tool_id) in enumerate( + zip(functions_list, arguments_list, tool_id_list), start=1 + ): + if self.has_function(function_name): + run_llm = False + arguments = json.loads(arguments) + await self.call_function( + context=context, + function_name=function_name, + arguments=arguments, + tool_call_id=tool_id, + run_llm=run_llm, + ) + else: + raise OpenAIUnhandledFunctionException( + f"The LLM tried to call a function named '{function_name}', but there isn't a callback registered for that function." + ) + + async def process_frame(self, frame: Frame, direction: FrameDirection): + await super().process_frame(frame, direction) + + context = None + if isinstance(frame, OpenAILLMContextFrame): + context: OpenAILLMContext = frame.context + elif isinstance(frame, LLMMessagesFrame): + context = OpenAILLMContext.from_messages(frame.messages) + elif isinstance(frame, VisionImageRawFrame): + context = OpenAILLMContext() + context.add_image_frame_message( + format=frame.format, size=frame.size, image=frame.image, text=frame.text + ) + elif isinstance(frame, LLMUpdateSettingsFrame): + await self._update_settings(frame.settings) + else: + await self.push_frame(frame, direction) + + if context: + try: + await self.push_frame(LLMFullResponseStartFrame()) + await self.start_processing_metrics() + await self._process_context(context) + except httpx.TimeoutException: + await self._call_event_handler("on_completion_timeout") + finally: + await self.stop_processing_metrics() + await self.push_frame(LLMFullResponseEndFrame()) diff --git a/src/pipecat/services/openai/image.py b/src/pipecat/services/openai/image.py new file mode 100644 index 000000000..fc0c475f9 --- /dev/null +++ b/src/pipecat/services/openai/image.py @@ -0,0 +1,58 
@@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import io +from typing import AsyncGenerator, Literal, Optional + +import aiohttp +from loguru import logger +from openai import AsyncOpenAI +from PIL import Image + +from pipecat.frames.frames import ( + ErrorFrame, + Frame, + URLImageRawFrame, +) +from pipecat.services.ai_services import ImageGenService + + +class OpenAIImageGenService(ImageGenService): + def __init__( + self, + *, + api_key: str, + base_url: Optional[str] = None, + aiohttp_session: aiohttp.ClientSession, + image_size: Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"], + model: str = "dall-e-3", + ): + super().__init__() + self.set_model_name(model) + self._image_size = image_size + self._client = AsyncOpenAI(api_key=api_key, base_url=base_url) + self._aiohttp_session = aiohttp_session + + async def run_image_gen(self, prompt: str) -> AsyncGenerator[Frame, None]: + logger.debug(f"Generating image from prompt: {prompt}") + + image = await self._client.images.generate( + prompt=prompt, model=self.model_name, n=1, size=self._image_size + ) + + image_url = image.data[0].url + + if not image_url: + logger.error(f"{self} No image provided in response: {image}") + yield ErrorFrame("Image generation failed") + return + + # Load the image from the url + async with self._aiohttp_session.get(image_url) as response: + image_stream = io.BytesIO(await response.content.read()) + image = Image.open(image_stream) + frame = URLImageRawFrame(image_url, image.tobytes(), image.size, image.format) + yield frame diff --git a/src/pipecat/services/openai/llm.py b/src/pipecat/services/openai/llm.py new file mode 100644 index 000000000..5ca51a479 --- /dev/null +++ b/src/pipecat/services/openai/llm.py @@ -0,0 +1,142 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import json +from dataclasses import dataclass +from typing import Any, Mapping + +from 
pipecat.frames.frames import ( + FunctionCallCancelFrame, + FunctionCallInProgressFrame, + FunctionCallResultFrame, + UserImageRawFrame, +) +from pipecat.processors.aggregators.llm_response import ( + LLMAssistantContextAggregator, + LLMUserContextAggregator, +) +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext +from pipecat.services.openai.base_llm import BaseOpenAILLMService + + +@dataclass +class OpenAIContextAggregatorPair: + _user: "OpenAIUserContextAggregator" + _assistant: "OpenAIAssistantContextAggregator" + + def user(self) -> "OpenAIUserContextAggregator": + return self._user + + def assistant(self) -> "OpenAIAssistantContextAggregator": + return self._assistant + + +class OpenAILLMService(BaseOpenAILLMService): + def __init__( + self, + *, + model: str = "gpt-4o", + params: BaseOpenAILLMService.InputParams = BaseOpenAILLMService.InputParams(), + **kwargs, + ): + super().__init__(model=model, params=params, **kwargs) + + def create_context_aggregator( + self, + context: OpenAILLMContext, + *, + user_kwargs: Mapping[str, Any] = {}, + assistant_kwargs: Mapping[str, Any] = {}, + ) -> OpenAIContextAggregatorPair: + """Create an instance of OpenAIContextAggregatorPair from an + OpenAILLMContext. Constructor keyword arguments for both the user and + assistant aggregators can be provided. + + Args: + context (OpenAILLMContext): The LLM context. + user_kwargs (Mapping[str, Any], optional): Additional keyword + arguments for the user context aggregator constructor. Defaults + to an empty mapping. + assistant_kwargs (Mapping[str, Any], optional): Additional keyword + arguments for the assistant context aggregator + constructor. Defaults to an empty mapping. + + Returns: + OpenAIContextAggregatorPair: A pair of context aggregators, one for + the user and one for the assistant, encapsulated in an + OpenAIContextAggregatorPair. 
+ + """ + context.set_llm_adapter(self.get_llm_adapter()) + user = OpenAIUserContextAggregator(context, **user_kwargs) + assistant = OpenAIAssistantContextAggregator(context, **assistant_kwargs) + return OpenAIContextAggregatorPair(_user=user, _assistant=assistant) + + +class OpenAIUserContextAggregator(LLMUserContextAggregator): + pass + + +class OpenAIAssistantContextAggregator(LLMAssistantContextAggregator): + async def handle_function_call_in_progress(self, frame: FunctionCallInProgressFrame): + self._context.add_message( + { + "role": "assistant", + "tool_calls": [ + { + "id": frame.tool_call_id, + "function": { + "name": frame.function_name, + "arguments": json.dumps(frame.arguments), + }, + "type": "function", + } + ], + } + ) + self._context.add_message( + { + "role": "tool", + "content": "IN_PROGRESS", + "tool_call_id": frame.tool_call_id, + } + ) + + async def handle_function_call_result(self, frame: FunctionCallResultFrame): + if frame.result: + result = json.dumps(frame.result) + await self._update_function_call_result(frame.function_name, frame.tool_call_id, result) + else: + await self._update_function_call_result( + frame.function_name, frame.tool_call_id, "COMPLETED" + ) + + async def handle_function_call_cancel(self, frame: FunctionCallCancelFrame): + await self._update_function_call_result( + frame.function_name, frame.tool_call_id, "CANCELLED" + ) + + async def _update_function_call_result( + self, function_name: str, tool_call_id: str, result: Any + ): + for message in self._context.messages: + if ( + message["role"] == "tool" + and message["tool_call_id"] + and message["tool_call_id"] == tool_call_id + ): + message["content"] = result + + async def handle_user_image_frame(self, frame: UserImageRawFrame): + await self._update_function_call_result( + frame.request.function_name, frame.request.tool_call_id, "COMPLETED" + ) + self._context.add_image_frame_message( + format=frame.format, + size=frame.size, + image=frame.image, + 
text=frame.request.context, + ) diff --git a/src/pipecat/services/openai/stt.py b/src/pipecat/services/openai/stt.py new file mode 100644 index 000000000..173205aa0 --- /dev/null +++ b/src/pipecat/services/openai/stt.py @@ -0,0 +1,66 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from typing import Optional + +from pipecat.services.whisper.base_stt import BaseWhisperSTTService, Transcription +from pipecat.transcriptions.language import Language + + +class OpenAISTTService(BaseWhisperSTTService): + """OpenAI Speech-to-Text service that generates text from audio. + + Uses OpenAI's transcription API to convert audio to text. Requires an OpenAI API key + set via the api_key parameter or OPENAI_API_KEY environment variable. + + Args: + model: Model to use — either gpt-4o or Whisper. Defaults to "gpt-4o-transcribe". + api_key: OpenAI API key. Defaults to None. + base_url: API base URL. Defaults to None. + language: Language of the audio input. Defaults to English. + prompt: Optional text to guide the model's style or continue a previous segment. + temperature: Optional sampling temperature between 0 and 1. Defaults to 0.0. + **kwargs: Additional arguments passed to BaseWhisperSTTService. 
+ """ + + def __init__( + self, + *, + model: str = "gpt-4o-transcribe", + api_key: Optional[str] = None, + base_url: Optional[str] = None, + language: Optional[Language] = Language.EN, + prompt: Optional[str] = None, + temperature: Optional[float] = None, + **kwargs, + ): + super().__init__( + model=model, + api_key=api_key, + base_url=base_url, + language=language, + prompt=prompt, + temperature=temperature, + **kwargs, + ) + + async def _transcribe(self, audio: bytes) -> Transcription: + assert self._language is not None # Assigned in the BaseWhisperSTTService class + + # Build kwargs dict with only set parameters + kwargs = { + "file": ("audio.wav", audio, "audio/wav"), + "model": self.model_name, + "language": self._language, + } + + if self._prompt is not None: + kwargs["prompt"] = self._prompt + + if self._temperature is not None: + kwargs["temperature"] = self._temperature + + return await self._client.audio.transcriptions.create(**kwargs) diff --git a/src/pipecat/services/openai/tts.py b/src/pipecat/services/openai/tts.py new file mode 100644 index 000000000..68644f147 --- /dev/null +++ b/src/pipecat/services/openai/tts.py @@ -0,0 +1,129 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +from typing import AsyncGenerator, Dict, Literal, Optional + +from loguru import logger +from openai import AsyncOpenAI, BadRequestError + +from pipecat.frames.frames import ( + ErrorFrame, + Frame, + StartFrame, + TTSAudioRawFrame, + TTSStartedFrame, + TTSStoppedFrame, +) +from pipecat.services.ai_services import TTSService + +ValidVoice = Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] + +VALID_VOICES: Dict[str, ValidVoice] = { + "alloy": "alloy", + "echo": "echo", + "fable": "fable", + "onyx": "onyx", + "nova": "nova", + "shimmer": "shimmer", +} + + +class OpenAITTSService(TTSService): + """OpenAI Text-to-Speech service that generates audio from text. 
+ + This service uses the OpenAI TTS API to generate PCM-encoded audio at 24kHz. + + Args: + api_key: OpenAI API key. Defaults to None. + voice: Voice ID to use. Defaults to "alloy". + model: TTS model to use. Defaults to "gpt-4o-mini-tts". + sample_rate: Output audio sample rate in Hz. Defaults to None. + **kwargs: Additional keyword arguments passed to TTSService. + + The service returns PCM-encoded audio at the specified sample rate. + + """ + + OPENAI_SAMPLE_RATE = 24000 # OpenAI TTS always outputs at 24kHz + + def __init__( + self, + *, + api_key: Optional[str] = None, + base_url: Optional[str] = None, + voice: str = "alloy", + model: str = "gpt-4o-mini-tts", + sample_rate: Optional[int] = None, + instructions: Optional[str] = None, + **kwargs, + ): + if sample_rate and sample_rate != self.OPENAI_SAMPLE_RATE: + logger.warning( + f"OpenAI TTS only supports {self.OPENAI_SAMPLE_RATE}Hz sample rate. " + f"Current rate of {self.sample_rate}Hz may cause issues." + ) + super().__init__(sample_rate=sample_rate, **kwargs) + + self.set_model_name(model) + self.set_voice(voice) + self._instructions = instructions + self._client = AsyncOpenAI(api_key=api_key, base_url=base_url) + + def can_generate_metrics(self) -> bool: + return True + + async def set_model(self, model: str): + logger.info(f"Switching TTS model to: [{model}]") + self.set_model_name(model) + + async def start(self, frame: StartFrame): + await super().start(frame) + if self.sample_rate != self.OPENAI_SAMPLE_RATE: + logger.warning( + f"OpenAI TTS requires {self.OPENAI_SAMPLE_RATE}Hz sample rate. " + f"Current rate of {self.sample_rate}Hz may cause issues." 
+ ) + + async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: + logger.debug(f"{self}: Generating TTS [{text}]") + try: + await self.start_ttfb_metrics() + + # Setup extra body parameters + extra_body = {} + if self._instructions: + extra_body["instructions"] = self._instructions + + async with self._client.audio.speech.with_streaming_response.create( + input=text or " ", # Text must contain at least one character + model=self.model_name, + voice=VALID_VOICES[self._voice_id], + response_format="pcm", + extra_body=extra_body, + ) as r: + if r.status_code != 200: + error = await r.text() + logger.error( + f"{self} error getting audio (status: {r.status_code}, error: {error})" + ) + yield ErrorFrame( + f"Error getting audio (status: {r.status_code}, error: {error})" + ) + return + + await self.start_tts_usage_metrics(text) + + CHUNK_SIZE = 1024 + + yield TTSStartedFrame() + async for chunk in r.iter_bytes(CHUNK_SIZE): + if len(chunk) > 0: + await self.stop_ttfb_metrics() + frame = TTSAudioRawFrame(chunk, self.sample_rate, 1) + yield frame + yield TTSStoppedFrame() + except BadRequestError as e: + logger.exception(f"{self} error generating TTS: {e}") diff --git a/src/pipecat/services/openai_realtime_beta/azure.py b/src/pipecat/services/openai_realtime_beta/azure.py index 5f046b8b0..799c5e686 100644 --- a/src/pipecat/services/openai_realtime_beta/azure.py +++ b/src/pipecat/services/openai_realtime_beta/azure.py @@ -4,8 +4,6 @@ # SPDX-License-Identifier: BSD 2-Clause License # -import os - from loguru import logger from .openai import OpenAIRealtimeBetaLLMService diff --git a/src/pipecat/services/openai_realtime_beta/context.py b/src/pipecat/services/openai_realtime_beta/context.py index c8381976f..80cffbef2 100644 --- a/src/pipecat/services/openai_realtime_beta/context.py +++ b/src/pipecat/services/openai_realtime_beta/context.py @@ -6,23 +6,18 @@ import copy import json -from typing import Optional from loguru import logger from pipecat.frames.frames 
import ( Frame, FunctionCallResultFrame, - FunctionCallResultProperties, LLMMessagesUpdateFrame, LLMSetToolsFrame, ) -from pipecat.processors.aggregators.openai_llm_context import ( - OpenAILLMContext, - OpenAILLMContextFrame, -) +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext from pipecat.processors.frame_processor import FrameDirection -from pipecat.services.openai import ( +from pipecat.services.openai.llm import ( OpenAIAssistantContextAggregator, OpenAIUserContextAggregator, ) diff --git a/src/pipecat/services/openai_realtime_beta/openai.py b/src/pipecat/services/openai_realtime_beta/openai.py index a1f4bc731..7b10b83fc 100644 --- a/src/pipecat/services/openai_realtime_beta/openai.py +++ b/src/pipecat/services/openai_realtime_beta/openai.py @@ -12,8 +12,6 @@ from typing import Any, Mapping from loguru import logger -from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter - try: import websockets except ModuleNotFoundError as e: @@ -23,6 +21,7 @@ except ModuleNotFoundError as e: ) raise Exception(f"Missing module: {e}") +from pipecat.adapters.services.open_ai_realtime_adapter import OpenAIRealtimeLLMAdapter from pipecat.frames.frames import ( BotStoppedSpeakingFrame, CancelFrame, @@ -55,7 +54,7 @@ from pipecat.processors.aggregators.openai_llm_context import ( ) from pipecat.processors.frame_processor import FrameDirection from pipecat.services.ai_services import LLMService -from pipecat.services.openai import OpenAIContextAggregatorPair +from pipecat.services.openai.llm import OpenAIContextAggregatorPair from pipecat.utils.time import time_now_iso8601 from . 
import events diff --git a/src/pipecat/services/openpipe/__init__.py b/src/pipecat/services/openpipe/__init__.py new file mode 100644 index 000000000..c6e439ba9 --- /dev/null +++ b/src/pipecat/services/openpipe/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .llm import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "openpipe", "openpipe.llm") diff --git a/src/pipecat/services/openpipe.py b/src/pipecat/services/openpipe/llm.py similarity index 89% rename from src/pipecat/services/openpipe.py rename to src/pipecat/services/openpipe/llm.py index c89c9d6a3..63e499714 100644 --- a/src/pipecat/services/openpipe.py +++ b/src/pipecat/services/openpipe/llm.py @@ -10,16 +10,14 @@ from loguru import logger from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext -from pipecat.services.openai import OpenAILLMService +from pipecat.services.openai.llm import OpenAILLMService try: from openpipe import AsyncOpenAI as OpenPipeAI from openpipe import AsyncStream except ModuleNotFoundError as e: logger.error(f"Exception: {e}") - logger.error( - "In order to use OpenPipe, you need to `pip install pipecat-ai[openpipe]`. Also, set `OPENPIPE_API_KEY` and `OPENAI_API_KEY` environment variables." 
- ) + logger.error("In order to use OpenPipe, you need to `pip install pipecat-ai[openpipe]`.") raise Exception(f"Missing module: {e}") diff --git a/src/pipecat/services/openrouter/__init__.py b/src/pipecat/services/openrouter/__init__.py new file mode 100644 index 000000000..12c4f2ea3 --- /dev/null +++ b/src/pipecat/services/openrouter/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .llm import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "openrouter", "openrouter.llm") diff --git a/src/pipecat/services/openrouter.py b/src/pipecat/services/openrouter/llm.py similarity index 96% rename from src/pipecat/services/openrouter.py rename to src/pipecat/services/openrouter/llm.py index d25990d2e..431724f94 100644 --- a/src/pipecat/services/openrouter.py +++ b/src/pipecat/services/openrouter/llm.py @@ -8,7 +8,7 @@ from typing import Optional from loguru import logger -from pipecat.services.openai import OpenAILLMService +from pipecat.services.openai.llm import OpenAILLMService class OpenRouterLLMService(OpenAILLMService): diff --git a/src/pipecat/services/perplexity/__init__.py b/src/pipecat/services/perplexity/__init__.py new file mode 100644 index 000000000..b1cf42e12 --- /dev/null +++ b/src/pipecat/services/perplexity/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .llm import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "perplexity", "perplexity.llm") diff --git a/src/pipecat/services/perplexity.py b/src/pipecat/services/perplexity/llm.py similarity index 98% rename from src/pipecat/services/perplexity.py rename to src/pipecat/services/perplexity/llm.py index b0c560ce4..ff9f82bdb 100644 --- a/src/pipecat/services/perplexity.py +++ 
b/src/pipecat/services/perplexity/llm.py @@ -6,13 +6,12 @@ from typing import List -from loguru import logger from openai import NOT_GIVEN, AsyncStream from openai.types.chat import ChatCompletionChunk, ChatCompletionMessageParam from pipecat.metrics.metrics import LLMTokenUsage from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContext -from pipecat.services.openai import OpenAILLMService +from pipecat.services.openai.llm import OpenAILLMService class PerplexityLLMService(OpenAILLMService): diff --git a/src/pipecat/services/piper/__init__.py b/src/pipecat/services/piper/__init__.py new file mode 100644 index 000000000..2cb7e790b --- /dev/null +++ b/src/pipecat/services/piper/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .tts import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "piper", "piper.tts") diff --git a/src/pipecat/services/piper.py b/src/pipecat/services/piper/tts.py similarity index 100% rename from src/pipecat/services/piper.py rename to src/pipecat/services/piper/tts.py diff --git a/src/pipecat/services/playht/__init__.py b/src/pipecat/services/playht/__init__.py new file mode 100644 index 000000000..3d2247f5e --- /dev/null +++ b/src/pipecat/services/playht/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .tts import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "playht", "playht.tts") diff --git a/src/pipecat/services/playht.py b/src/pipecat/services/playht/tts.py similarity index 98% rename from src/pipecat/services/playht.py rename to src/pipecat/services/playht/tts.py index 75677876f..f7157a950 100644 --- a/src/pipecat/services/playht.py +++ b/src/pipecat/services/playht/tts.py @@ -36,9 +36,7 @@ try: 
from pyht.client import Language as PlayHTLanguage except ModuleNotFoundError as e: logger.error(f"Exception: {e}") - logger.error( - "In order to use PlayHT, you need to `pip install pipecat-ai[playht]`. Also, set `PLAY_HT_USER_ID` and `PLAY_HT_API_KEY` environment variables." - ) + logger.error("In order to use PlayHT, you need to `pip install pipecat-ai[playht]`.") raise Exception(f"Missing module: {e}") diff --git a/src/pipecat/services/qwen/__init__.py b/src/pipecat/services/qwen/__init__.py new file mode 100644 index 000000000..0b398c1eb --- /dev/null +++ b/src/pipecat/services/qwen/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .llm import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "qwen", "qwen.llm") diff --git a/src/pipecat/services/qwen.py b/src/pipecat/services/qwen/llm.py similarity index 93% rename from src/pipecat/services/qwen.py rename to src/pipecat/services/qwen/llm.py index 0ed5ca593..de910a741 100644 --- a/src/pipecat/services/qwen.py +++ b/src/pipecat/services/qwen/llm.py @@ -4,11 +4,9 @@ # SPDX-License-Identifier: BSD 2-Clause License # -from typing import Optional - from loguru import logger -from pipecat.services.openai import OpenAILLMService, OpenAISTTService +from pipecat.services.openai.llm import OpenAILLMService class QwenLLMService(OpenAILLMService): diff --git a/src/pipecat/services/rime/__init__.py b/src/pipecat/services/rime/__init__.py new file mode 100644 index 000000000..842fbd971 --- /dev/null +++ b/src/pipecat/services/rime/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .tts import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "rime", "rime.tts") diff --git a/src/pipecat/services/rime.py 
b/src/pipecat/services/rime/tts.py similarity index 99% rename from src/pipecat/services/rime.py rename to src/pipecat/services/rime/tts.py index a1a455eb5..f5b5da2a6 100644 --- a/src/pipecat/services/rime.py +++ b/src/pipecat/services/rime/tts.py @@ -34,9 +34,7 @@ try: import websockets except ModuleNotFoundError as e: logger.error(f"Exception: {e}") - logger.error( - "In order to use Rime, you need to `pip install pipecat-ai[rime]`. Also, set `RIME_API_KEY` environment variable." - ) + logger.error("In order to use Rime, you need to `pip install pipecat-ai[rime]`.") raise Exception(f"Missing module: {e}") diff --git a/src/pipecat/services/riva/__init__.py b/src/pipecat/services/riva/__init__.py new file mode 100644 index 000000000..29fe0c003 --- /dev/null +++ b/src/pipecat/services/riva/__init__.py @@ -0,0 +1,14 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .stt import * +from .tts import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "riva", "riva.[stt,tts]") diff --git a/src/pipecat/services/riva.py b/src/pipecat/services/riva/stt.py similarity index 64% rename from src/pipecat/services/riva.py rename to src/pipecat/services/riva/stt.py index de065aef4..63eea8230 100644 --- a/src/pipecat/services/riva.py +++ b/src/pipecat/services/riva/stt.py @@ -17,11 +17,8 @@ from pipecat.frames.frames import ( InterimTranscriptionFrame, StartFrame, TranscriptionFrame, - TTSAudioRawFrame, - TTSStartedFrame, - TTSStoppedFrame, ) -from pipecat.services.ai_services import STTService, TTSService +from pipecat.services.ai_services import STTService from pipecat.transcriptions.language import Language from pipecat.utils.time import time_now_iso8601 @@ -30,100 +27,9 @@ try: except ModuleNotFoundError as e: logger.error(f"Exception: {e}") - logger.error( - "In order to use nvidia riva TTS or STT, you need to `pip install pipecat-ai[riva]`. 
Also, set `NVIDIA_API_KEY` environment variable." - ) + logger.error("In order to use NVIDIA Riva STT, you need to `pip install pipecat-ai[riva]`.") raise Exception(f"Missing module: {e}") -FASTPITCH_TIMEOUT_SECS = 5 - - -class FastPitchTTSService(TTSService): - class InputParams(BaseModel): - language: Optional[Language] = Language.EN_US - quality: Optional[int] = 20 - - def __init__( - self, - *, - api_key: str, - server: str = "grpc.nvcf.nvidia.com:443", - voice_id: str = "English-US.Female-1", - sample_rate: Optional[int] = None, - function_id: str = "0149dedb-2be8-4195-b9a0-e57e0e14f972", - params: InputParams = InputParams(), - **kwargs, - ): - super().__init__(sample_rate=sample_rate, **kwargs) - self._api_key = api_key - self._voice_id = voice_id - self._language_code = params.language - self._quality = params.quality - - self.set_model_name("fastpitch-hifigan-tts") - self.set_voice(voice_id) - - metadata = [ - ["function-id", function_id], - ["authorization", f"Bearer {api_key}"], - ] - auth = riva.client.Auth(None, True, server, metadata) - - self._service = riva.client.SpeechSynthesisService(auth) - - # warm up the service - config_response = self._service.stub.GetRivaSynthesisConfig( - riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest() - ) - - async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: - def read_audio_responses(queue: asyncio.Queue): - def add_response(r): - asyncio.run_coroutine_threadsafe(queue.put(r), self.get_event_loop()) - - try: - responses = self._service.synthesize_online( - text, - self._voice_id, - self._language_code, - sample_rate_hz=self.sample_rate, - audio_prompt_file=None, - quality=self._quality, - custom_dictionary={}, - ) - for r in responses: - add_response(r) - add_response(None) - except Exception as e: - logger.error(f"{self} exception: {e}") - add_response(None) - - await self.start_ttfb_metrics() - yield TTSStartedFrame() - - logger.debug(f"{self}: Generating TTS [{text}]") - - try: - queue = 
asyncio.Queue() - await asyncio.to_thread(read_audio_responses, queue) - - # Wait for the thread to start. - resp = await asyncio.wait_for(queue.get(), FASTPITCH_TIMEOUT_SECS) - while resp: - await self.stop_ttfb_metrics() - frame = TTSAudioRawFrame( - audio=resp.audio, - sample_rate=self.sample_rate, - num_channels=1, - ) - yield frame - resp = await asyncio.wait_for(queue.get(), FASTPITCH_TIMEOUT_SECS) - except asyncio.TimeoutError: - logger.error(f"{self} timeout waiting for audio response") - - await self.start_tts_usage_metrics(text) - yield TTSStoppedFrame() - class ParakeetSTTService(STTService): class InputParams(BaseModel): diff --git a/src/pipecat/services/riva/tts.py b/src/pipecat/services/riva/tts.py new file mode 100644 index 000000000..e0da3ab98 --- /dev/null +++ b/src/pipecat/services/riva/tts.py @@ -0,0 +1,117 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import asyncio +from typing import AsyncGenerator, Optional + +from loguru import logger +from pydantic import BaseModel + +from pipecat.frames.frames import ( + Frame, + TTSAudioRawFrame, + TTSStartedFrame, + TTSStoppedFrame, +) +from pipecat.services.ai_services import TTSService +from pipecat.transcriptions.language import Language + +try: + import riva.client + +except ModuleNotFoundError as e: + logger.error(f"Exception: {e}") + logger.error("In order to use NVIDIA Riva TTS, you need to `pip install pipecat-ai[riva]`.") + raise Exception(f"Missing module: {e}") + +FASTPITCH_TIMEOUT_SECS = 5 + + +class FastPitchTTSService(TTSService): + class InputParams(BaseModel): + language: Optional[Language] = Language.EN_US + quality: Optional[int] = 20 + + def __init__( + self, + *, + api_key: str, + server: str = "grpc.nvcf.nvidia.com:443", + voice_id: str = "English-US.Female-1", + sample_rate: Optional[int] = None, + function_id: str = "0149dedb-2be8-4195-b9a0-e57e0e14f972", + params: InputParams = InputParams(), + **kwargs, + ): + 
super().__init__(sample_rate=sample_rate, **kwargs) + self._api_key = api_key + self._voice_id = voice_id + self._language_code = params.language + self._quality = params.quality + + self.set_model_name("fastpitch-hifigan-tts") + self.set_voice(voice_id) + + metadata = [ + ["function-id", function_id], + ["authorization", f"Bearer {api_key}"], + ] + auth = riva.client.Auth(None, True, server, metadata) + + self._service = riva.client.SpeechSynthesisService(auth) + + # warm up the service + config_response = self._service.stub.GetRivaSynthesisConfig( + riva.client.proto.riva_tts_pb2.RivaSynthesisConfigRequest() + ) + + async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]: + def read_audio_responses(queue: asyncio.Queue): + def add_response(r): + asyncio.run_coroutine_threadsafe(queue.put(r), self.get_event_loop()) + + try: + responses = self._service.synthesize_online( + text, + self._voice_id, + self._language_code, + sample_rate_hz=self.sample_rate, + audio_prompt_file=None, + quality=self._quality, + custom_dictionary={}, + ) + for r in responses: + add_response(r) + add_response(None) + except Exception as e: + logger.error(f"{self} exception: {e}") + add_response(None) + + await self.start_ttfb_metrics() + yield TTSStartedFrame() + + logger.debug(f"{self}: Generating TTS [{text}]") + + try: + queue = asyncio.Queue() + await asyncio.to_thread(read_audio_responses, queue) + + # Wait for the thread to start. 
+ resp = await asyncio.wait_for(queue.get(), FASTPITCH_TIMEOUT_SECS) + while resp: + await self.stop_ttfb_metrics() + frame = TTSAudioRawFrame( + audio=resp.audio, + sample_rate=self.sample_rate, + num_channels=1, + ) + yield frame + resp = await asyncio.wait_for(queue.get(), FASTPITCH_TIMEOUT_SECS) + except asyncio.TimeoutError: + logger.error(f"{self} timeout waiting for audio response") + + await self.start_tts_usage_metrics(text) + yield TTSStoppedFrame() diff --git a/src/pipecat/services/simli/__init__.py b/src/pipecat/services/simli/__init__.py new file mode 100644 index 000000000..7e7cd7ed7 --- /dev/null +++ b/src/pipecat/services/simli/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .video import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "simli", "simli.video") diff --git a/src/pipecat/services/simli.py b/src/pipecat/services/simli/video.py similarity index 100% rename from src/pipecat/services/simli.py rename to src/pipecat/services/simli/video.py diff --git a/src/pipecat/services/tavus/__init__.py b/src/pipecat/services/tavus/__init__.py new file mode 100644 index 000000000..97811e467 --- /dev/null +++ b/src/pipecat/services/tavus/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .video import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "tavus", "tavus.video") diff --git a/src/pipecat/services/tavus.py b/src/pipecat/services/tavus/video.py similarity index 99% rename from src/pipecat/services/tavus.py rename to src/pipecat/services/tavus/video.py index 9083dd53d..27b1d7155 100644 --- a/src/pipecat/services/tavus.py +++ b/src/pipecat/services/tavus/video.py @@ -4,7 +4,6 @@ # SPDX-License-Identifier: BSD 2-Clause License # - 
"""This module implements Tavus as a sink transport layer""" import base64 diff --git a/src/pipecat/services/to_be_updated/__init__.py b/src/pipecat/services/to_be_updated/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/pipecat/services/to_be_updated/cloudflare_ai_service.py b/src/pipecat/services/to_be_updated/cloudflare_ai_service.py deleted file mode 100644 index ff637ff1a..000000000 --- a/src/pipecat/services/to_be_updated/cloudflare_ai_service.py +++ /dev/null @@ -1,69 +0,0 @@ -import os - -import requests -from services.ai_service import AIService - -# Note that Cloudflare's AI workers are still in beta. -# https://developers.cloudflare.com/workers-ai/ - - -class CloudflareAIService(AIService): - def __init__(self): - super().__init__() - self.cloudflare_account_id = os.getenv("CLOUDFLARE_ACCOUNT_ID") - self.cloudflare_api_token = os.getenv("CLOUDFLARE_API_TOKEN") - - self.api_base_url = ( - f"https://api.cloudflare.com/client/v4/accounts/{self.cloudflare_account_id}/ai/run/" - ) - self.headers = {"Authorization": f"Bearer {self.cloudflare_api_token}"} - - # base endpoint, used by the others - def run(self, model, input): - response = requests.post(f"{self.api_base_url}{model}", headers=self.headers, json=input) - return response.json() - - # https://developers.cloudflare.com/workers-ai/models/llm/ - def run_llm(self, messages, latest_user_message=None, stream=True): - input = { - "messages": [ - {"role": "system", "content": "You are a friendly assistant"}, - {"role": "user", "content": sentence}, - ] - } - - return self.run("@cf/meta/llama-2-7b-chat-int8", input) - - # https://developers.cloudflare.com/workers-ai/models/translation/ - def run_text_translation(self, sentence, source_language, target_language): - return self.run( - "@cf/meta/m2m100-1.2b", - {"text": sentence, "source_lang": source_language, "target_lang": target_language}, - ) - - # https://developers.cloudflare.com/workers-ai/models/sentiment-analysis/ - def 
run_text_sentiment(self, sentence): - return self.run("@cf/huggingface/distilbert-sst-2-int8", {"text": sentence}) - - # https://developers.cloudflare.com/workers-ai/models/image-classification/ - def run_image_classification(self, image_url): - response = requests.get(image_url) - - if response.status_code != 200: - return {"error": "There was a problem downloading the image."} - - if response.status_code == 200: - data = response.content - inputs = {"image": list(data)} - - return self.run("@cf/microsoft/resnet-50", inputs) - - # https://developers.cloudflare.com/workers-ai/models/embedding/ - def run_embeddings(self, texts, size="medium"): - models = { - "small": "@cf/baai/bge-small-en-v1.5", # 384 output dimensions - "medium": "@cf/baai/bge-base-en-v1.5", # 768 output dimensions - "large": "@cf/baai/bge-large-en-v1.5", # 1024 output dimensions - } - - return self.run(models[size], {"text": texts}) diff --git a/src/pipecat/services/to_be_updated/google_ai_service.py b/src/pipecat/services/to_be_updated/google_ai_service.py deleted file mode 100644 index 3ca688750..000000000 --- a/src/pipecat/services/to_be_updated/google_ai_service.py +++ /dev/null @@ -1,30 +0,0 @@ -import os - -import openai - -# To use Google Cloud's AI products, you'll need to install Google Cloud -# CLI and enable the TTS and in your project: -# https://cloud.google.com/sdk/docs/install -from google.cloud import texttospeech -from services.ai_service import AIService - - -class GoogleAIService(AIService): - def __init__(self): - super().__init__() - - self.client = texttospeech.TextToSpeechClient() - self.voice = texttospeech.VoiceSelectionParams( - language_code="en-GB", name="en-GB-Neural2-F" - ) - - self.audio_config = texttospeech.AudioConfig( - audio_encoding=texttospeech.AudioEncoding.LINEAR16, sample_rate_hertz=16000 - ) - - def run_tts(self, sentence): - synthesis_input = texttospeech.SynthesisInput(text=sentence.strip()) - result = self.client.synthesize_speech( - 
input=synthesis_input, voice=self.voice, audio_config=self.audio_config - ) - return result diff --git a/src/pipecat/services/to_be_updated/huggingface_ai_service.py b/src/pipecat/services/to_be_updated/huggingface_ai_service.py deleted file mode 100644 index 09f0b8248..000000000 --- a/src/pipecat/services/to_be_updated/huggingface_ai_service.py +++ /dev/null @@ -1,33 +0,0 @@ -from services.ai_service import AIService -from transformers import pipeline - -# These functions are just intended for testing, not production use. If -# you'd like to use HuggingFace, you should use your own models, or do -# some research into the specific models that will work best for your use -# case. - - -class HuggingFaceAIService(AIService): - def __init__(self): - super().__init__() - - def run_text_sentiment(self, sentence): - classifier = pipeline("sentiment-analysis") - return classifier(sentence) - - # available models at https://huggingface.co/Helsinki-NLP (**not all - # models use 2-character language codes**) - def run_text_translation(self, sentence, source_language, target_language): - translator = pipeline( - f"translation", model=f"Helsinki-NLP/opus-mt-{source_language}-{target_language}" - ) - - return translator(sentence)[0]["translation_text"] - - def run_text_summarization(self, sentence): - summarizer = pipeline("summarization") - return summarizer(sentence) - - def run_image_classification(self, image_path): - classifier = pipeline("image-classification") - return classifier(image_path) diff --git a/src/pipecat/services/to_be_updated/mock_ai_service.py b/src/pipecat/services/to_be_updated/mock_ai_service.py deleted file mode 100644 index 0825cde33..000000000 --- a/src/pipecat/services/to_be_updated/mock_ai_service.py +++ /dev/null @@ -1,28 +0,0 @@ -import io -import time - -import requests -from PIL import Image -from services.ai_service import AIService - - -class MockAIService(AIService): - def __init__(self): - super().__init__() - - def run_tts(self, sentence): - 
print("running tts", sentence) - time.sleep(2) - - def run_image_gen(self, sentence): - image_url = "https://d3d00swyhr67nd.cloudfront.net/w800h800/collection/ASH/ASHM/ASH_ASHM_WA1940_2_22-001.jpg" - response = requests.get(image_url) - image_stream = io.BytesIO(response.content) - image = Image.open(image_stream) - time.sleep(1) - return (image_url, image.tobytes(), image.size) - - def run_llm(self, messages, latest_user_message=None, stream=True): - for i in range(5): - time.sleep(1) - yield ({"choices": [{"delta": {"content": f"hello {i}!"}}]}) diff --git a/src/pipecat/services/together/__init__.py b/src/pipecat/services/together/__init__.py new file mode 100644 index 000000000..b7e3a779b --- /dev/null +++ b/src/pipecat/services/together/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .llm import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "together", "together.llm") diff --git a/src/pipecat/services/together.py b/src/pipecat/services/together/llm.py similarity index 96% rename from src/pipecat/services/together.py rename to src/pipecat/services/together/llm.py index 3b83cad83..31b15ae73 100644 --- a/src/pipecat/services/together.py +++ b/src/pipecat/services/together/llm.py @@ -4,10 +4,9 @@ # SPDX-License-Identifier: BSD 2-Clause License # - from loguru import logger -from pipecat.services.openai import OpenAILLMService +from pipecat.services.openai.llm import OpenAILLMService class TogetherLLMService(OpenAILLMService): diff --git a/src/pipecat/services/ultravox/__init__.py b/src/pipecat/services/ultravox/__init__.py new file mode 100644 index 000000000..b05e930eb --- /dev/null +++ b/src/pipecat/services/ultravox/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import 
DeprecatedModuleProxy + +from .stt import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "ultravox", "ultravox.stt") diff --git a/src/pipecat/services/ultravox.py b/src/pipecat/services/ultravox/stt.py similarity index 100% rename from src/pipecat/services/ultravox.py rename to src/pipecat/services/ultravox/stt.py diff --git a/src/pipecat/services/whisper/__init__.py b/src/pipecat/services/whisper/__init__.py new file mode 100644 index 000000000..3c82f3bfb --- /dev/null +++ b/src/pipecat/services/whisper/__init__.py @@ -0,0 +1,14 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .base_stt import * +from .stt import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "whisper", "whisper.stt") diff --git a/src/pipecat/services/base_whisper.py b/src/pipecat/services/whisper/base_stt.py similarity index 100% rename from src/pipecat/services/base_whisper.py rename to src/pipecat/services/whisper/base_stt.py diff --git a/src/pipecat/services/whisper.py b/src/pipecat/services/whisper/stt.py similarity index 100% rename from src/pipecat/services/whisper.py rename to src/pipecat/services/whisper/stt.py diff --git a/src/pipecat/services/xtts/__init__.py b/src/pipecat/services/xtts/__init__.py new file mode 100644 index 000000000..8f5a87892 --- /dev/null +++ b/src/pipecat/services/xtts/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2024–2025, Daily +# +# SPDX-License-Identifier: BSD 2-Clause License +# + +import sys + +from pipecat.services import DeprecatedModuleProxy + +from .tts import * + +sys.modules[__name__] = DeprecatedModuleProxy(globals(), "xtts", "xtts.tts") diff --git a/src/pipecat/services/xtts.py b/src/pipecat/services/xtts/tts.py similarity index 100% rename from src/pipecat/services/xtts.py rename to src/pipecat/services/xtts/tts.py diff --git a/tests/integration/test_integration_unified_function_calling.py 
b/tests/integration/test_integration_unified_function_calling.py index 0696c531a..f4652905d 100644 --- a/tests/integration/test_integration_unified_function_calling.py +++ b/tests/integration/test_integration_unified_function_calling.py @@ -13,10 +13,11 @@ from dotenv import load_dotenv from pipecat.adapters.schemas.function_schema import FunctionSchema from pipecat.adapters.schemas.tools_schema import ToolsSchema from pipecat.pipeline.pipeline import Pipeline +from pipecat.processors.aggregators.openai_llm_context import OpenAILLMContextFrame from pipecat.services.ai_services import LLMService from pipecat.services.anthropic import AnthropicLLMService from pipecat.services.google import GoogleLLMService -from pipecat.services.openai import OpenAILLMContext, OpenAILLMContextFrame, OpenAILLMService +from pipecat.services.openai import OpenAILLMContext, OpenAILLMService from pipecat.tests.utils import run_test load_dotenv(override=True)